presto sql输入表输入字段limitjoin操作解析
Posted scx_white
tags:
篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了presto sql输入表输入字段limitjoin操作解析相关的知识,希望对你有一定的参考价值。
前言
一段时间没有写文章了,写下最近做的事情。目前我们这边有一个
metabase
查询平台供运营、分析师、产品等人员使用,我们的查询都是使用presto
引擎。并且我们的大数据组件都使用的是emr
组件,并且涉及到中国、美西、美东、印度、欧洲、西欧等多个区域,表的权限管理就特别困难。所以就需要一个统一的权限管理来维护某些人拥有那些表的权限,避免隐私的数据泄漏。于是我们就需要一款sql解析工具来解析presto sql
的输入表。另外还有一点,由于使用的人较多,资源较少,为了避免长查询,我们还会对含有 join 操作查询、 select * 的查询直接拒绝
。
sql 解析
第一种方法
presto
本身也是用的 antlr
进行 sql
语法的编辑,如果你clone了presto的源码,会在 presto-parse
模块中发现 presto/presto-parser/src/main/antlr4/com/facebook/presto/sql/parser/SqlBase.g4
文件,也就是说我们可以通过直接使用该文件生成解析的配置文件1,然后进行 sql
解析
,但是这种方法太过复杂,我尝试了下放弃了,因为从语法树中获取某些值时比较混乱,容错较小,还需要再遍历其儿子、兄弟节点,并且通过节点的 getText
方法获得节点值。
第二种方法
我们肯定很容易的就想到,presto
源码肯定也对 sql
进行了解析,何不直接使用 presto
的解析类呢?
功夫不负有心人,我在源码中发现了 SqlParser
这个类,该类在 presto-parser
模块中,通过调用 createStatement(String sql)
方法会返回一个Statement
2,后面我们只需要对 Statement
进行遍历即可
去掉注释
在 sql执行之前,我们需要进行一些预操作,比如去掉注释,分号分割多行代码
/**
* 替换sql注释
*
* @param sqlText sql
* @return 替换后的sl
*/
protected String replaceNotes(String sqlText)
StringBuilder newSql = new StringBuilder();
String lineBreak = "\\n";
String empty = "";
String trimLine;
for (String line : sqlText.split(lineBreak))
trimLine = line.trim();
if (!trimLine.startsWith("--") && !trimLine.startsWith("download"))
//过滤掉行内注释
line = line.replaceAll("/\\\\*.*\\\\*/", empty);
if (org.apache.commons.lang3.StringUtils.isNotBlank(line))
newSql.append(line).append(lineBreak);
return newSql.toString();
分号分割多段 sql
/**
* ;分割多段sql
*
* @param sqlText sql
* @return
*/
protected ArrayList<String> splitSql(String sqlText)
String[] sqlArray = sqlText.split(Constants.SEMICOLON);
ArrayList<String> newSqlArray = new ArrayList<>(sqlArray.length);
String command = "";
int arrayLen = sqlArray.length;
String oneCmd;
for (int i = 0; i < arrayLen; i++)
oneCmd = sqlArray[i];
boolean keepSemicolon = (oneCmd.endsWith("'") && i + 1 < arrayLen && sqlArray[i + 1].startsWith("'"))
|| (oneCmd.endsWith("\\"") && i + 1 < arrayLen && sqlArray[i + 1].startsWith("\\""));
if (oneCmd.endsWith("\\\\"))
command += org.apache.commons.lang.StringUtils.chop(oneCmd) + Constants.SEMICOLON;
continue;
else if (keepSemicolon)
command += oneCmd + Constants.SEMICOLON;
continue;
else
command += oneCmd;
if (org.apache.commons.lang3.StringUtils.isBlank(command))
continue;
newSqlArray.add(command);
command = "";
return newSqlArray;
sql解析
经过预处理之后,就需要对 sql
进行解析。inputTables、outputTables、tempTables
分别表示输入表、输出表、临时表
@Override
protected Tuple3<HashSet<TableInfo>, HashSet<TableInfo>, HashSet<TableInfo>> parseInternal(String sqlText) throws SqlParseException
this.inputTables = new HashSet<>();
this.outputTables = new HashSet<>();
this.tempTables = new HashSet<>();
try
//ParsingOptions.DecimalLiteralTreatment.AS_DECIMAL 表示数字以DECIMAL类型解析
check(new SqlParser().createStatement(sqlText, new ParsingOptions(ParsingOptions.DecimalLiteralTreatment.AS_DECIMAL)));
catch (ParsingException e)
throw new SqlParseException("parse sql exception:" + e.getMessage(), e);
return new Tuple3<>(inputTables, outputTables, tempTables);
根节点识别
进入 check
方法进行 Statement
的遍历
/**
* statement 过滤 只识别select 语句
*
* @param statement
* @throws SqlParseException
*/
private void check(Statement statement) throws SqlParseException
//如果根节点是查询节点 获取所有的孩子节点,深度优先搜索遍历
if (statement instanceof Query)
Query query = (Query) statement;
List<Node> children = query.getChildren();
for (Node child : children)
checkNode(child);
else if (statement instanceof Use)
Use use = (Use) statement;
this.currentDb = use.getSchema().getValue();
else if (statement instanceof ShowColumns)
ShowColumns show = (ShowColumns) statement;
String allName = show.getTable().toString().replace("hive.", "");
inputTables.add(buildTableInfo(allName, OperatorType.READ));
else if (statement instanceof ShowTables)
ShowTables show = (ShowTables) statement;
QualifiedName qualifiedName = show.getSchema().orElseThrow(() -> new SqlParseException("unkonw table name or db name" + statement.toString()));
String allName = qualifiedName.toString().replace("hive.", "");
if (allName.contains(Constants.POINT))
allName += Constants.POINT + "*";
inputTables.add(buildTableInfo(allName, OperatorType.READ));
else
throw new SqlParseException("sorry,only support read statement,unSupport statement:" + statement.getClass().getName());
- 如果根节点是
Query
查询节点 获取所有的孩子节点,深度优先搜索遍历 - 如果根节点是
Use
切换数据库的节点,修改当前的数据库名称 - 如果根节点是
ShowColumns
查看表字段的节点,将该表加入输入表 - 如果根节点是
ShowTables
查看表结构的节点,将该表加入输入表 - 否则抛出无法解析的异常
子节点遍历
主要进入 checkNode
方法,进行查询语句所有孩子节点的遍历
/**
* node 节点的遍历
*
* @param node
*/
private void checkNode(Node node) throws SqlParseException
//查询子句
if (node instanceof QuerySpecification)
QuerySpecification query = (QuerySpecification) node;
//如果查询包含limit语句 直接将limit入栈
query.getLimit().ifPresent(limit -> limitStack.push(limit));
//遍历子节点
loopNode(query.getChildren());
else if (node instanceof TableSubquery)
loopNode(node.getChildren());
else if (node instanceof AliasedRelation)
// 表的别名 需要放到tableAliaMap供别别名的字段解析使用
AliasedRelation alias = (AliasedRelation) node;
String value = alias.getAlias().getValue();
if (alias.getChildren().size() == 1 && alias.getChildren().get(0) instanceof Table)
Table table = (Table) alias.getChildren().get(0);
tableAliaMap.put(value, table.getName().toString());
else
tempTables.add(buildTableInfo(value, OperatorType.READ));
loopNode(node.getChildren());
else if (node instanceof Query || node instanceof SubqueryExpression
|| node instanceof Union || node instanceof With
|| node instanceof LogicalBinaryExpression || node instanceof InPredicate)
loopNode(node.getChildren());
else if (node instanceof Join)
//发现join操作 设置hasJoin 为true
hasJoin = true;
loopNode(node.getChildren());
//基本都是where条件,过滤掉,如果需要,可以调用getColumn解析字段
else if (node instanceof LikePredicate || node instanceof NotExpression
|| node instanceof IfExpression
|| node instanceof ComparisonExpression || node instanceof GroupBy
|| node instanceof OrderBy || node instanceof Identifier
|| node instanceof InListExpression || node instanceof DereferenceExpression
|| node instanceof IsNotNullPredicate || node instanceof IsNullPredicate
|| node instanceof FunctionCall)
print(node.getClass().getName());
else if (node instanceof WithQuery)
//with 子句的临时表
WithQuery withQuery = (WithQuery) node;
tempTables.add(buildTableInfo(withQuery.getName().getValue(), OperatorType.READ));
loopNode(withQuery.getChildren());
else if (node instanceof Table)
//发现table节点 放入输入表
Table table = (Table) node;
inputTables.add(buildTableInfo(table.getName().toString(), OperatorType.READ));
loopNode(table.getChildren());
else if (node instanceof Select)
//发现select 子句,需要调用getColumn方法从selectItems中获取select的字段
Select select = (Select) node;
List<SelectItem> selectItems = select.getSelectItems();
HashSet<String> columns = new HashSet<>();
for (SelectItem item : selectItems)
if (item instanceof SingleColumn)
columns.add(getColumn(((SingleColumn) item).getExpression()));
else if (item instanceof AllColumns)
columns.add(item.toString());
else
throw new SqlParseException("unknow column type:" + item.getClass().getName());
//将字段入栈
columnsStack.push(columns);
else
throw new SqlParseException("unknow node type:" + node.getClass().getName());
上面需要注意的是,每次想输入表、临时表中添加表时都对应一个 column
的集合从 columnsStack
出栈。
后面看从 selectItems
中获取字段的方法 getColumn
.
/**
* select 字段表达式中获取字段
*
* @param expression
* @return
*/
private String getColumn(Expression expression) throws SqlParseException
if (expression instanceof IfExpression)
IfExpression ifExpression = (IfExpression) expression;
List<Expression> list = new ArrayList<>();
list.add(ifExpression.getCondition());
list.add(ifExpression.getTrueValue());
ifExpression.getFalseValue().ifPresent(list::add);
return getString(list);
else if (expression instanceof Identifier)
Identifier identifier = (Identifier) expression;
return identifier.getValue();
else if (expression instanceof FunctionCall)
FunctionCall call = (FunctionCall) expression;
StringBuilder columns = new StringBuilder();
List<Expression> arguments = call.getArguments();
int size = arguments.size();
for (int i = 0; i < size; i++)
Expression exp = arguments.get(i);
if (i == 0)
columns.append(getColumn(exp));
else
columns.append(getColumn(exp)).append(columnSplit);
return columns.toString();
else if (expression instanceof ComparisonExpression)
ComparisonExpression compare = (ComparisonExpression) expression;
return getString(compare.getLeft(), compare.getRight());
else if (expression instanceof Literal || expression instanceof ArithmeticUnaryExpression)
return "";
else if (expression instanceof Cast)
Cast cast = (Cast) expression;
return getColumn(cast.getExpression());
else if (expression instanceof DereferenceExpression)
DereferenceExpression reference = (DereferenceExpression) expression;
return reference.toString();
else if (expression instanceof ArithmeticBinaryExpression)
ArithmeticBinaryExpression binaryExpression = (ArithmeticBinaryExpression) expression;
return getString(binaryExpression.getLeft(), binaryExpression.getRight());
else if (expression instanceof SearchedCaseExpression)
SearchedCaseExpression caseExpression = (SearchedCaseExpression如何从 Django 页面的文本字段中获取输入,然后使用响应更新 SQL 表?