Seq2SQL :使用强化学习通过自然语言生成SQL
Posted 梳下鱼
tags:
篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了Seq2SQL :使用强化学习通过自然语言生成SQL相关的知识,希望对你有一定的参考价值。
Seq2SQL属于natural language interface (NLI)的领域,方便普通用户接入并查询数据库中的内容,即用户不需要了解SQL语句,只需要通过自然语言,就可查询所需内容。
Seq2SQL借鉴的是Seq2Seq的思想,与Seq2Seq应用于机器翻译与Chatbot类似,Seq2SQL将输入的语句encode后再decode成结构化的SQL语言输出,强化学习是在Seq2SQL中的最后一个模块中应用。同时,这篇论文还推出一个数据集WikiSQL,数据集内有人工标注好的问句及其对应SQL语句。
试验结果显示,Seq2SQL的准确率也不是特别的高,只有60.3%
Seq2SQL结构:
Seq2SQL由三部分组成:
![](https://image.cha138.com/20210514/dd48fa7895d04526900f4f190f36b3e9.jpg)
第一部分: Aggregation classifier 这一部分其实是一个分类器,将用户输入的语句分类成是select count/max/min 等统计相关的约束条件
在此处采用的Augmented Pointer Network,Augmented Pointer Network总体而言也是ecoder-to-decoder的结构,
encoder采用的是两层的bi-LSTM, decoder 采用的是两层的unidirectional LSTM,
encoder输出h,ht对应的是第t个词的输出状态
decoder的每一步是,输入y s-1,输出状态gs,接着,decoder为每个位置t生成一个attention的score
![](https://image.cha138.com/20210514/cff63f11e4184d3982834dd7d9603ef3.jpg)
![](https://image.cha138.com/20210514/15048a96ac0e45e8b8cdce18030c40fb.jpg)
在Seq2SQL中,首先为input生成一个表征向量
(agg:aggregation clasifier, inp:input,enc:encoder) 首先为Augmented Pointer Network类似,计算出一个attention的分数,
,![](https://image.cha138.com/20210514/1a42255a939c4888b28a658556a9136f.jpg)
![](https://image.cha138.com/20210514/694460f30a8746aba98a51673a171a10.jpg)
![](https://image.cha138.com/20210514/06b867803ccf472991b2867cb0e5c9da.jpg)
![](https://image.cha138.com/20210514/1a42255a939c4888b28a658556a9136f.jpg)
量化后,通过softmax函数 ![](https://image.cha138.com/20210514/24907866214f42408adce8d760c85e32.jpg)
![](https://image.cha138.com/20210514/24907866214f42408adce8d760c85e32.jpg)
input的表征向量 ![](https://image.cha138.com/20210514/c8e05d6071f44b31bc1ccb427e095597.jpg)
![](https://image.cha138.com/20210514/c8e05d6071f44b31bc1ccb427e095597.jpg)
通过一个多层的网络和softmax完成分类任务
![](https://image.cha138.com/20210514/b0bd737e49ad40ada2994774732dc0ca.jpg)
![](https://image.cha138.com/20210514/2414005b3a864665922900bc21a1242e.jpg)
第二部分: select column 这一部分是看用户输入的问句命中了哪个column
首先将每个column name 通过LSTM encode
![](https://image.cha138.com/20210514/77478c269d984bc7b265215503085785.jpg)
WikiSQL: ![](https://image.cha138.com/20210514/7484d6684a234495a421096dbd4041f3.jpg)
![](https://image.cha138.com/20210514/8b7eab132c634026bc0a110049d28568.jpg)
将用户输入encode成与第一部分
类似的![](https://image.cha138.com/20210514/fac89131f8784572904fc5bbce8ac82a.jpg)
![](https://image.cha138.com/20210514/ac30cddb55c94e8cb7498f53adb4ee82.jpg)
![](https://image.cha138.com/20210514/fac89131f8784572904fc5bbce8ac82a.jpg)
最终通过一个多层的神经元和softmax确定是命中哪一行
![](https://image.cha138.com/20210514/ac2ec49c019d4f67b8d8b16895a2f563.jpg)
![](https://image.cha138.com/20210514/77478c269d984bc7b265215503085785.jpg)
第三部分:where clause 确定约束条件,因为最终生成的SQL可能与标注中的不太一样,但是依旧有一样的结果,所以不能像前两部分一样使用交叉熵作为loss训练,因此使用强化训练中reward函数 (g: ground-truth), loss使用梯度
![](https://image.cha138.com/20210514/fbbcb4cd31cc42a394b55c7dcbede140.jpg)
WikiSQL包含一系列与SQL相关的问题集以及SQL table
![](https://image.cha138.com/20210514/7484d6684a234495a421096dbd4041f3.jpg)
![](https://image.cha138.com/20210514/fe2e037e49e34a77b9e0fe2ff2795915.jpg)
以上是关于Seq2SQL :使用强化学习通过自然语言生成SQL的主要内容,如果未能解决你的问题,请参考以下文章
论文泛读175迈向通过基于文本的自然语言进行交流的协作强化学习代理
论文泛读175迈向通过基于文本的自然语言进行交流的协作强化学习代理