Paper Reproduction | Implementing Text2SQL on ModelArts

Posted by Huawei Cloud Developer Alliance (华为云开发者联盟)

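Abstract: The paper proposes a new neural network architecture based on pre-trained BERT, called M-SQL. Column-based value extraction is split into two modules: value extraction and value-column matching.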

This article is shared from the Huawei Cloud Community post "基于ModelArts实现Text2SQL" (Implementing Text2SQL on ModelArts), by HWCloudAI.

M-SQL: Multi-Task Representation Learning for Single-Table Text2sql Generation

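Although earlier Text2SQL research offers some workable solutions, most of it extracts values from column representations. When a query contains several values that belong to different columns, these column-representation-based methods cannot extract the values accurately. The paper proposes a new neural network architecture based on pre-trained BERT, called M-SQL, which splits column-based value extraction into two modules: value extraction and value-column matching.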
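To make the two-module idea concrete, here is a small pure-Python sketch, not the paper's or the repository's code. It assumes a simple binary tagging scheme (1 = the token is part of a value) and an already computed value-to-column score matrix; in M-SQL both come out of the BERT-based model. The helper names `extract_value_spans` and `match_values_to_columns` are hypothetical.

```python
from typing import List, Sequence, Tuple


def extract_value_spans(tags: Sequence[int]) -> List[Tuple[int, int]]:
    """Group consecutive tokens tagged 1 into (start, end) spans, end exclusive."""
    spans: List[Tuple[int, int]] = []
    start = None
    for i, tag in enumerate(tags):
        if tag == 1 and start is None:
            start = i                    # a new value begins at token i
        elif tag != 1 and start is not None:
            spans.append((start, i))     # the current value ends before token i
            start = None
    if start is not None:                # a value runs to the end of the question
        spans.append((start, len(tags)))
    return spans


def match_values_to_columns(spans: Sequence[Tuple[int, int]],
                            scores: Sequence[Sequence[float]]) -> List[int]:
    """For each value span, pick the column index with the highest score.

    scores[k][j] is an assumed matching score between value span k and column j.
    """
    if len(scores) != len(spans):
        raise ValueError("expected one row of column scores per value span")
    return [max(range(len(row)), key=row.__getitem__) for row in scores]


tokens = ["volume", "<", "3574", "and", "ratio", "<", "69.7", "%"]
tags = [0, 0, 1, 0, 0, 0, 1, 1]
spans = extract_value_spans(tags)
values = ["".join(tokens[s:e]) for s, e in spans]
columns = match_values_to_columns(spans, [[0.9, 0.1], [0.2, 0.8]])
print(list(zip(values, columns)))  # [('3574', 0), ('69.7%', 1)]
```

Because values are found first and only then assigned to columns, two values in the same question can land in different columns, which is the case the paper says column-based extraction handles poorly.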

Paper: M-SQL: Multi-Task Representation Learning for Single-Table Text2sql Generation | IEEE Journals & Magazine | IEEE Xplore

Algorithm details: AI Gallery (Algorithms, Models, Cloud Marketplace, Huawei Cloud)

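Notes:

1. Framework used in this case: PyTorch 1.4.0
2. Hardware used in this case: GPU: 1*NVIDIA-V100NV32 (32GB) | CPU: 8 cores, 64GB
3. How to run the code: click the triangle Run button in the menu bar at the top of this page, or press Ctrl+Enter, to run the code in each cell
4. Detailed JupyterLab usage: see the "ModelArts JupyterLab User Guide"
5. Troubleshooting: see "ModelArts JupyterLab Common Issues and Solutions"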

1. Download the code and dataset

Run the code below to download and unzip the data and code.

The TableQA dataset is used; the data is located in m-sql/TableQA/.

import os
# Download the data and code
!wget https://obs-aigallery-zc.obs.cn-north-4.myhuaweicloud.com/algorithm/m-sql.zip
# Unzip
os.system('unzip m-sql.zip -d ./')
os.chdir('./m-sql')
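Once unzipped, the TableQA data sits in m-sql/TableQA/. The training code below reads each question's SQL in a canonical form (`sql_i`), and the following sketch renders such a query as readable SQL text. It assumes the field names (`sel`, `agg`, `cond_conn_op`, `conds`) and integer codes of the public TableQA release; these are not shown in this article, so check them against the files before relying on them. The function `sql_dict_to_string`, the table name `Table` and the example column names are hypothetical.

```python
from typing import Dict, List

# Assumed integer codes from the public TableQA release (verify against the data files).
AGG_OPS = ["", "AVG", "MAX", "MIN", "COUNT", "SUM"]  # agg id -> aggregate function
COND_OPS = [">", "<", "=", "!="]                     # condition op id -> comparison
CONN_OPS = ["", "AND", "OR"]                         # cond_conn_op id -> connector


def sql_dict_to_string(sql: Dict, header: List[str], table: str = "Table") -> str:
    """Render a canonical query dict as SQL text (illustration only)."""
    select_parts = []
    for col, agg in zip(sql["sel"], sql["agg"]):
        name = header[col]
        select_parts.append(f"{AGG_OPS[agg]}({name})" if agg else name)
    query = f"SELECT {', '.join(select_parts)} FROM {table}"
    conds = [f"{header[col]} {COND_OPS[op]} '{val}'" for col, op, val in sql["conds"]]
    if conds:
        # Code 0 is assumed to mean "no connector" (single condition); fall back to AND.
        connector = CONN_OPS[sql.get("cond_conn_op", 0)] or "AND"
        query += " WHERE " + f" {connector} ".join(conds)
    return query


example = {"sel": [0], "agg": [4], "cond_conn_op": 1,
           "conds": [[1, 1, "3574"], [2, 1, "69.7"]]}
print(sql_dict_to_string(example, ["city", "volume_4w", "ratio"]))
# SELECT COUNT(city) FROM Table WHERE volume_4w < '3574' AND ratio < '69.7'
```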

2. Training

2.1 Install dependencies

!pip install -r pip-requirements.txt

2.2 Parameters and functions needed for training

import os
import argparse
import shutil
import sqlite3
import time
import tqdm
import torch
import random as python_random
from transformers import BertTokenizer, BertModel
import logging
import numpy as np
from model import Loss_sw_se, Seq2SQL_v1
# import moxing as mox  # ModelArts OBS SDK; only needed when --load_weight is an obs:// path
from sql_utils.utils_tableqa import load_tableqa, get_loader, get_fields, get_g, get_g_wvi, get_wemb_bert, \
    pred_sw_se, convert_pr_wvi_to_string, generate_sql_i, extract_val, normalize_sql, get_acc, get_acc_x, \
    save_for_evaluation, load_tableqa_data
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
def construct_hyper_param(parser):
    parser.add_argument("--eval", default='False', type=str)
    parser.add_argument("--no_save", default='False', type=str)
    parser.add_argument("--toy_model", default='False', type=str)
    parser.add_argument("--toy_size", default=16, type=int)
    parser.add_argument('--tepoch', default=1, type=int)
    parser.add_argument('--print_per_step', default=50, type=int)
    parser.add_argument("--bS", default=32, type=int,
                        help="Batch size")
    parser.add_argument("--accumulate_gradients", default=1, type=int,
                        help="The number of accumulation of backpropagation to effectively increase the batch size.")
    parser.add_argument('--fine_tune',
                        default='False', type=str,
                        help="If 'True', BERT is trained as well.")
    parser.add_argument("--data_url", default='./TableQA', type=str,
                        help="Path of the input data directory.")
    parser.add_argument("--train_url", default='./data_and_model/', type=str,
                        help="Saving path of model file, logfile and result file.")
    parser.add_argument("--vocab_file",
                        default='vocab.txt', type=str,
                        help="The vocabulary file that the BERT model was trained on.")
    parser.add_argument("--max_seq_length",
                        default=512, type=int,
                        help="The maximum total input sequence length after WordPiece tokenization. Sequences ")
    parser.add_argument("--num_target_layers",
                        default=1, type=int,
                        help="The number of final layers of BERT to be used in downstream task.")
    parser.add_argument('--lr_bert', default=1e-5, type=float, help='BERT model learning rate.')
    parser.add_argument('--seed',
                        type=int,
                        default=1,
                        help="random seed for initialization")
    parser.add_argument('--do_lower_case', default='False', type=str, help='whether to use lower case.')
    parser.add_argument("--bert_url", default='./pre-trained_weights/chinese_wwm_ext_pytorch/', type=str,
                        help="Path or model name of BERT")
    parser.add_argument("--load_weight", default='./trained_model/model/best_model.pth', type=str,
                        help="model path to load")
    parser.add_argument('--dr', default=0, type=float, help="Dropout rate.")
    parser.add_argument('--lr', default=1e-3, type=float, help="Learning rate.")
    parser.add_argument('--num_warmup_steps', default=-1, type=int,
                        help="Number of linear warm-up steps (values <= 0 disable warm-up).")
    parser.add_argument("--split", default='val', type=str, help='prefix of jsonl and db files')
    # parse_known_args ignores unknown arguments, e.g. those injected by Jupyter
    args, _ = parser.parse_known_args()
    python_random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(args.seed)
    # Boolean options are passed as the strings 'True' / 'False'
    args.do_lower_case = args.do_lower_case == 'True'
    args.fine_tune = args.fine_tune == 'True'
    args.no_save = args.no_save == 'True'
    args.eval = args.eval == 'True'
    args.toy_model = args.toy_model == 'True'
    return args
def get_bert(bert_path):
    tokenizer = BertTokenizer.from_pretrained(bert_path)
    model_bert = BertModel.from_pretrained(bert_path)
    bert_config = model_bert.config
    model_bert.to(device)
    return model_bert, tokenizer, bert_config
def update_lr(param_groups, current_step, num_warmup_steps, start_lr):
    # Linear warm-up: scale start_lr by current_step / num_warmup_steps.
    # After warm-up (or if num_warmup_steps <= 0) the learning rate is left unchanged.
    if current_step <= num_warmup_steps:
        warmup_frac_done = current_step / num_warmup_steps
        new_lr = start_lr * warmup_frac_done
        for param_group in param_groups:
            param_group['lr'] = new_lr
def get_opt(model, model_bert, fine_tune):
    # Note: reads the learning rates from the global `args` created in the main block
    if fine_tune:
        opt = torch.optim.Adam(filter(lambda p: p.requires_grad, model.parameters()),
                               lr=args.lr, weight_decay=0)
        opt_bert = torch.optim.Adam(filter(lambda p: p.requires_grad, model_bert.parameters()),
                                    lr=args.lr_bert, weight_decay=0)
    else:
        opt = torch.optim.Adam(filter(lambda p: p.requires_grad, model.parameters()),
                               lr=args.lr, weight_decay=0)
        opt_bert = None
    return opt, opt_bert
def get_models(args, logger, bert_model, trained=False, path_model=None, eval=False):
    # some constants
    if not eval:
        logger.info(f"Batch_size = {args.bS * args.accumulate_gradients}")
        logger.info(f"BERT parameters:")
        logger.info(f"learning rate: {args.lr_bert}")
        logger.info(f"Fine-tune BERT: {args.fine_tune}")
    # Get BERT
    model_bert, tokenizer, bert_config = get_bert(bert_model)
    iS = bert_config.hidden_size * args.num_target_layers
    logger.info(bert_config.to_json_string())
    # Get Seq-to-SQL
    if not eval:
        logger.info(f"Seq-to-SQL: the number of final BERT layers to be used: {args.num_target_layers}")
        logger.info(f"Seq-to-SQL: learning rate = {args.lr}")
    model = Seq2SQL_v1(iS, args.dr)
    model = model.to(device)
    if trained:
        assert path_model is not None
        # map_location='cpu' lets a checkpoint saved on GPU be loaded on a CPU-only machine
        if torch.cuda.is_available():
            res = torch.load(path_model)
        else:
            res = torch.load(path_model, map_location='cpu')
        model_bert.load_state_dict(res['model_bert'])
        model_bert.to(device)
        model.load_state_dict(res['model'])
        model.to(device)
    return model, model_bert, tokenizer, bert_config
def get_data(path_wikisql, args):
    train_data, train_table, dev_data, dev_table = load_tableqa(path_wikisql, args.toy_model, args.toy_size,
                                                                no_hs_tok=True)
    train_loader, dev_loader = get_loader(train_data, dev_data, args.bS, shuffle_train=True)
    return train_data, train_table, dev_data, dev_table, train_loader, dev_loader
def train(train_loader, train_table, model, model_bert, opt, bert_config, tokenizer,
          max_seq_length, num_target_layers, accumulate_gradients, print_per_step, logger,
          current_step, st_pos=0, opt_bert=None):
    # Note: reads num_warmup_steps, lr and lr_bert from the global `args`
    model.train()
    model_bert.train()
    torch.autograd.set_detect_anomaly(True)
    ave_loss = 0
    cnt = 0
    for iB, t in enumerate(train_loader):
        cnt += len(t)
        if cnt < st_pos:
            continue
        # Get fields
        nlu, nlu_t, sql_i, sql_q, sql_t, tb, hs_t, hds = get_fields(t, train_table, no_hs_t=True, no_sql_t=True)
        # nlu  : natural language utterance
        # nlu_t: tokenized nlu
        # sql_i: canonical form of SQL query
        # sql_q: full SQL query text. Not used.
        # sql_t: tokenized SQL query
        # tb   : table
        # hs_t : tokenized headers. Not used.
        g_sn, g_sc, g_sa, g_wnop, g_wc, g_wo, g_wv = get_g(sql_i)
        g_wvi, g_tags, g_value_match = get_g_wvi(t, g_wc)
        wemb_cls, wemb_n, wemb_h, l_n, l_hpu, l_hs, \
            nlu_tt, t_to_tt_idx, tt_to_t_idx \
            = get_wemb_bert(bert_config, model_bert, tokenizer, nlu_t, hds, max_seq_length,
                            num_out_layers_n=num_target_layers, num_out_layers_h=num_target_layers)
        # Length of each question in original (pre-WordPiece) tokens
        l_n_t = [len(idx) for idx in t_to_tt_idx]
        # wemb_n: natural language embedding
        # wemb_h: header embedding
        # l_n: token lengths of each question
        # l_hpu: header token lengths
        # l_hs: the number of columns (headers) of the tables.
        # score
        s_sn, s_sc, s_sa, s_wnop, s_wc, \
            s_wo, s_tags, s_match = model(wemb_cls, wemb_n, l_n_t, wemb_h, l_hpu, l_hs,
                                          t_to_tt_idx=t_to_tt_idx,
                                          g_sn=g_sn, g_sc=g_sc, g_sa=g_sa, g_wo=g_wo,
                                          g_wnop=g_wnop, g_wc=g_wc, g_wvi=g_wvi,
                                          g_tags=g_tags, g_vm=g_value_match)
        # Calculate loss & step
        loss = Loss_sw_se(s_sn, s_sc, s_sa, s_wnop, s_wc, s_wo, s_tags, s_match,
                          g_sn, g_sc, g_sa, g_wnop, g_wc, g_wo, g_tags, g_value_match)
        if iB % accumulate_gradients == 0:
            # First batch of an accumulation window: clear the old gradients
            opt.zero_grad()
            if opt_bert:
                opt_bert.zero_grad()
            loss.backward()
            if accumulate_gradients == 1:
                update_lr(opt.param_groups, current_step, args.num_warmup_steps, args.lr)
                opt.step()
                if opt_bert:
                    update_lr(opt_bert.param_groups, current_step, args.num_warmup_steps, args.lr_bert)
                    opt_bert.step()
                current_step += 1
        elif iB % accumulate_gradients == (accumulate_gradients - 1):
            # Last batch of an accumulation window: update the parameters
            loss.backward()
            update_lr(opt.param_groups, current_step, args.num_warmup_steps, args.lr)
            opt.step()
            if opt_bert:
                update_lr(opt_bert.param_groups, current_step, args.num_warmup_steps, args.lr_bert)
                opt_bert.step()
            current_step += 1
        else:
            loss.backward()
        # statistics
        ave_loss += loss.item()
        if iB % print_per_step == 0:
            log = f'[Train Batch {iB}] '
            logs = []
            logs.append(f'average loss: {ave_loss / cnt:.4f}')
            logger.info(log + ', '.join(logs))
        if iB == 150:
            logger.info('Training stopped early; delete this if-branch to run the full training.')
            break
    ave_loss /= cnt
    return ave_loss, current_step
def test(data_loader, data_table, model, model_bert, bert_config, tokenizer, max_seq_length,
         num_target_layers, print_per_step, logger, path_db, st_pos=0):
    model.eval()
    model_bert.eval()
    cnt = 0
    cnt_sn = 0
    cnt_sc = 0
    cnt_sa = 0
    cnt_wnop = 0
    cnt_wc = 0
    cnt_wo = 0
    cnt_wv = 0
    cnt_lx = 0
    cnt_x = 0
    db_conn = sqlite3.connect(path_db)
    cursor = db_conn.cursor()
    results = []
    for iB, t in enumerate(data_loader):
        cnt += len(t)
        if cnt < st_pos:
            continue
        # Get fields
        nlu, nlu_t, sql_i, sql_q, sql_t, tb, hs_t, hds = get_fields(t, data_table, no_hs_t=True, no_sql_t=True)
        wemb_cls, wemb_n, wemb_h, l_n, l_hpu, l_hs, \
            nlu_tt, t_to_tt_idx, tt_to_t_idx \
            = get_wemb_bert(bert_config, model_bert, tokenizer, nlu_t, hds, max_seq_length,
                            num_out_layers_n=num_target_layers, num_out_layers_h=num_target_layers)
        l_n_t = [len(idx) for idx in t_to_tt_idx]
        # score
        s_sn, s_sc, s_sa, s_wnop, s_wc, \
            s_wo, s_tags, s_match = model(wemb_cls, wemb_n, l_n_t, wemb_h, l_hpu, l_hs, t_to_tt_idx)
        # prediction
        pr_sn, pr_sc, pr_sa, pr_wn, pr_conn_op, \
            pr_wc, pr_wo, pr_tags, pr_wvi = pred_sw_se(s_sn, s_sc, s_sa, s_wnop, s_wc, s_wo, s_tags, s_match, l_n_t)
        pr_wv_str = convert_pr_wvi_to_string(pr_wvi, nlu_t)
        pr_sql_i = generate_sql_i(pr_sc, pr_sa, pr_conn_op, pr_wc, pr_wo, pr_wv_str, nlu)
        value_indexes, value_nums = extract_val(pr_tags, l_n_t)
        # Saving for the official evaluation later.
        for b, pr_sql_i1 in enumerate(pr_sql_i):
            normalize_sql(pr_sql_i1, tb[b])
            results1 = {}
            results1["sql"] = pr_sql_i1
            results1["gold_sql"] = sql_i[b]
            results1["table_id"] = tb[b]["id"]
            results1["nlu"] = nlu[b]
            results1['value_indexes'] = value_indexes[b]
            results1['value_nums'] = value_nums[b]
            results1['pr_wc'] = pr_wc[b]
            sn, sc, sa, co, wn, wc, wo, wv, cond, sql = \
                get_acc(sql_i[b], pr_sql_i1, pr_wc[b], pr_wo[b], tb[b], normalized=True)
            cnt_sn += sn
            cnt_sc += sc
            cnt_sa += sa
            cnt_wnop += (co and wn)
            cnt_wc += wc
            cnt_wo += wo
            cnt_wv += wv
            cnt_lx += sql
            results1['correct'] = sql
            execution, res = get_acc_x(sql_i[b], pr_sql_i1, tb[b], cursor)
            cnt_x += execution
            results1['ex_correct'] = execution
            results1['result'] = res
            results.append(results1)
        # print acc
        cnts = [cnt_sn, cnt_sc, cnt_sa, cnt_wnop, cnt_wc,
                cnt_wo, cnt_wv, cnt_lx, cnt_x, (cnt_lx + cnt_x) / 2]
        cnt_desc = [
            's-num', 's-col', 's-col-agg', 'w-num-op', 'w-col',
            'w-col-op', 'w-col-value', 'acc_lx', 'acc_x', 'acc_mx'
        ]
        if iB % print_per_step == 0:
            log = f'[Test Batch {iB}] '
            logs = []
            for k, metric in enumerate(cnts):
                logs.append(cnt_desc[k] + ': ' + '%.4f' % (metric / cnt,))
            logger.info(log + ', '.join(logs))
        if iB == 150:
            logger.info('Evaluation stopped early; delete this if-branch to evaluate the full set.')
            break
    db_conn.close()
    acc_sn = cnt_sn / cnt
    acc_sc = cnt_sc / cnt
    acc_sa = cnt_sa / cnt
    acc_wnop = cnt_wnop / cnt
    acc_wc = cnt_wc / cnt
    acc_wo = cnt_wo / cnt
    acc_wv = cnt_wv / cnt
    acc_lx = cnt_lx / cnt
    acc_x = cnt_x / cnt
    acc_mx = (acc_lx + acc_x) / 2
    acc = [acc_sn, acc_sc, acc_sa, acc_wnop, acc_wc,
           acc_wo, acc_wv, acc_lx, acc_x, acc_mx]
    return acc, results, acc_lx
def print_result(epoch, acc, dname, logger=None):
    if logger:
        logger.info(f'------------ {dname} results ------------')
        if dname == 'dev':
            acc_sn, acc_sc, acc_sa, acc_wnop, acc_wc, \
                acc_wo, acc_wv, acc_lx, acc_x, acc_mx = acc
            logger.info(
                f" Epoch: {epoch}, s-num: {acc_sn:.4f}, s-col: {acc_sc:.4f},"
                f" s-col-agg: {acc_sa:.4f}, w-num-op: {acc_wnop:.4f},"
                f" w-col: {acc_wc:.4f}, w-col-op: {acc_wo:.4f}, w-col-value: {acc_wv:.4f},"
                f" acc_lx: {acc_lx:.4f}, acc_x: {acc_x:.4f}, acc_mx: {acc_mx:.4f}"
            )
        else:
            logger.info(f" Epoch: {epoch}, average loss: {acc}")
def get_logger(log_fp=None):
    logging.basicConfig(level=logging.INFO,
                        format='[%(asctime)s] %(message)s')
    logger = logging.getLogger(__name__)
    if log_fp:
        handler = logging.FileHandler(log_fp)
        handler.setLevel(logging.INFO)
        formatter = logging.Formatter('[%(asctime)s] %(message)s')
        handler.setFormatter(formatter)
        logger.addHandler(handler)
    return logger
def predict(data_loader, data_table, model, model_bert, bert_config, tokenizer,
            max_seq_length, num_target_layers, path_db):
    # Note: logs through the global `logger` created in the main block
    model.eval()
    model_bert.eval()
    results = []
    cnt = 0
    cnt_sn = 0
    cnt_sc = 0
    cnt_sa = 0
    cnt_wnop = 0
    cnt_wc = 0
    cnt_wo = 0
    cnt_wv = 0
    cnt_lx = 0
    cnt_x = 0
    db_conn = sqlite3.connect(path_db)
    cursor = db_conn.cursor()
    for iB, t in tqdm.tqdm(enumerate(data_loader)):
        nlu, nlu_t, sql_i, sql_q, sql_t, tb, hs_t, hds = get_fields(t, data_table, no_hs_t=True, no_sql_t=True)
        wemb_cls, wemb_n, wemb_h, l_n, l_hpu, l_hs, \
            nlu_tt, t_to_tt_idx, tt_to_t_idx \
            = get_wemb_bert(bert_config, model_bert, tokenizer, nlu_t, hds, max_seq_length,
                            num_out_layers_n=num_target_layers, num_out_layers_h=num_target_layers)
        l_n_t = [len(idx) for idx in t_to_tt_idx]
        s_sn, s_sc, s_sa, s_wnop, s_wc, \
            s_wo, s_tags, s_match = model(wemb_cls, wemb_n, l_n_t, wemb_h, l_hpu, l_hs, t_to_tt_idx)
        # prediction
        pr_sn, pr_sc, pr_sa, pr_wn, pr_conn_op, \
            pr_wc, pr_wo, pr_tags, pr_wvi = pred_sw_se(s_sn, s_sc, s_sa, s_wnop, s_wc, s_wo, s_tags, s_match, l_n_t)
        pr_wv_str = convert_pr_wvi_to_string(pr_wvi, nlu_t)
        pr_sql_i = generate_sql_i(pr_sc, pr_sa, pr_conn_op, pr_wc, pr_wo, pr_wv_str, nlu)
        value_indexes, value_nums = extract_val(pr_tags, l_n_t)
        for b, pr_sql_i1 in enumerate(pr_sql_i):
            cnt += 1
            results1 = {}
            normalize_sql(pr_sql_i1, tb[b])
            results1["table_id"] = tb[b]["id"]
            results1["nlu"] = nlu[b]
            results1["sql"] = pr_sql_i1
            if sql_i[b]:
                results1["gold_sql"] = sql_i[b]
            results1['value_indexes'] = value_indexes[b]
            results1['value_nums'] = value_nums[b]
            results1['pr_wc'] = pr_wc[b]
            # Accuracy can only be computed when a gold SQL query is available
            if sql_i[b]:
                sn, sc, sa, co, wn, wc, wo, wv, cond, sql = \
                    get_acc(sql_i[b], pr_sql_i1, pr_wc[b], pr_wo[b], tb[b], normalized=True)
                cnt_sn += sn
                cnt_sc += sc
                cnt_sa += sa
                cnt_wnop += (wn and co)
                cnt_wc += wc
                cnt_wo += wo
                cnt_wv += wv
                cnt_lx += sql
                results1['correct'] = sql
                execution, res = get_acc_x(sql_i[b], pr_sql_i1, tb[b], cursor)
                cnt_x += execution
                results1['ex_correct'] = execution
                results1['result'] = res
            results.append(results1)
    db_conn.close()
    cnts = [cnt_sn, cnt_sc, cnt_sa, cnt_wnop, cnt_wc,
            cnt_wo, cnt_wv, cnt_lx, cnt_x, (cnt_x + cnt_lx) / 2]
    if sum(cnts) > 0:
        cnt_desc = [
            's-num', 's-col', 's-col-agg', 'w-num-op', 'w-col',
            'w-col-op', 'w-col-value', 'acc_lx', 'acc_x', 'acc_mx'
        ]
        logger.info('--------- eval result ---------')
        for k, metric in enumerate(cnts):
            logger.info(cnt_desc[k] + ': ' + '%.4f' % (metric / cnt,))
    else:
        cnts = None
        cnt_desc = None
    return results, cnt, cnts, cnt_desc
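`update_lr` above implements a linear learning-rate warm-up, and `train` calls it once per optimizer update, so with `--accumulate_gradients` greater than 1, `current_step` counts updates rather than batches. The pure-Python sketch below reproduces the rate in effect at each step; `warmup_lr` is a hypothetical helper for illustration, not part of the repository. Note that the default `--num_warmup_steps=-1` never satisfies `current_step <= num_warmup_steps` (steps start at 1), so warm-up is off unless a positive value is passed.

```python
def warmup_lr(current_step: int, num_warmup_steps: int, start_lr: float) -> float:
    """Learning rate in effect at optimizer step `current_step` (counted from 1).

    Mirrors update_lr(): during warm-up the rate grows linearly from
    start_lr / num_warmup_steps up to start_lr. Afterwards update_lr() no longer
    touches the optimizer, so the rate stays at start_lr (the optimizer's
    initial rate, which is also the last warm-up value).
    """
    if current_step <= num_warmup_steps:
        return start_lr * (current_step / num_warmup_steps)
    return start_lr


# Default --num_warmup_steps=-1: warm-up disabled, every step uses --lr.
schedule_off = [warmup_lr(step, -1, 1e-3) for step in range(1, 4)]
# --num_warmup_steps=4: the rate ramps up over the first 4 optimizer steps.
schedule_on = [warmup_lr(step, 4, 1e-3) for step in range(1, 7)]
print(schedule_off)
print(schedule_on)
```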

2.3 Start training

if __name__ == '__main__':
    # Hyper parameters
    parser = argparse.ArgumentParser()
    args = construct_hyper_param(parser)
    save_path = args.train_url
    if not os.path.exists(save_path):
        os.makedirs(save_path)
    if not args.eval:
        _model_path = './trained_model/model/'
        _model_dst = os.path.join(save_path, 'model')
        # copytree fails if the destination already exists, e.g. when the cell is re-run
        if not os.path.exists(_model_dst):
            shutil.copytree(_model_path, _model_dst)
    t = time.strftime("%Y-%m-%d-%H-%M-%S", time.localtime())
    log_fp = os.path.join(save_path, f'{t}.log')
    logger = get_logger(log_fp)
    logger.info(f"BERT-Model: {args.bert_url}")
    trained = args.load_weight is not None and args.load_weight != 'None'
    load_path = None
    if trained:
        load_path = '/home/work/modelarts/inputs/best_model.pt'
        if args.load_weight and args.load_weight.startswith('obs://'):
            # Weights stored on OBS are copied locally with moxing (available on ModelArts)
            import moxing as mox
            if not os.path.exists(load_path):
                mox.file.copy_parallel(args.load_weight, load_path)
                print('copy %s to %s' % (args.load_weight, load_path))
            else:
                print(load_path, 'already exists')
        else:
            load_path = args.load_weight
    train_input_dir = args.data_url
    bert_model = args.bert_url
    # Paths
    path_wikisql = train_input_dir
    path_val_db = os.path.join(train_input_dir, 'val.db')
    path_save_for_evaluation = save_path
    # Build & Load models
    if args.eval and not trained:
        print('in eval mode, "--load_weight" must be provided!')
        exit(-1)
    if not trained:
        model, model_bert, tokenizer, bert_config = get_models(args, logger, bert_model, eval=args.eval)
    else:
        path_model = load_path
        model, model_bert, tokenizer, bert_config = get_models(args, logger, bert_model,
                                                               trained=True, path_model=path_model,
                                                               eval=args.eval)
    if not args.eval:
        train_data, train_table, dev_data, dev_table, train_loader, dev_loader = get_data(path_wikisql, args)
        opt, opt_bert = get_opt(model, model_bert, args.fine_tune)
        acc_lx_t_best = -1
        epoch_best = -1
        current_step = 1
        for epoch in range(args.tepoch):
            # train
            logger.info(f'Training Epoch {epoch}')
            ave_loss_train, current_step = train(train_loader,
                                                 train_table,
                                                 model,
                                                 model_bert,
                                                 opt,
                                                 bert_config,
                                                 tokenizer,
                                                 args.max_seq_length,
                                                 args.num_target_layers,
                                                 args.accumulate_gradients,
                                                 args.print_per_step,
                                                 logger=logger,
                                                 current_step=current_step,
                                                 opt_bert=opt_bert,
                                                 st_pos=0)
            # check DEV
            with torch.no_grad():
                logger.info(f'Testing on dev Epoch {epoch}:')
                acc_dev, results_dev, \
                    dev_acc_lx = test(dev_loader,
                                      dev_table,
                                      model,
                                      model_bert,
                                      bert_config,
                                      tokenizer,
                                      args.max_seq_length,
                                      args.num_target_layers,
                                      args.print_per_step,
                                      logger=logger,
                                      path_db=path_val_db,
                                      st_pos=0)
            print_result(epoch, ave_loss_train, 'train', logger=logger)
            print_result(epoch, acc_dev, 'dev', logger=logger)
            # save results for the official evaluation
            path_save_file = save_for_evaluation(path_save_for_evaluation,
                                                 results_dev, 'dev', epoch=epoch)
            # mox.file.copy_parallel(path_save_file,
            #                        args.train_url + f'results_dev-{epoch}.jsonl')
            # save best model
            # Based on Dev Set logical accuracy lx
            if dev_acc_lx > acc_lx_t_best:
                acc_lx_t_best = dev_acc_lx
                epoch_best = epoch
                # save model
                if not args.no_save:
                    state = {'model': model.state_dict(),
                             'model_bert': model_bert.state_dict()}
                    torch.save(state, os.path.join(save_path, 'model', 'best_model.pth'))
            logger.info(f" Best Dev lx acc: {acc_lx_t_best} at epoch: {epoch_best}")
    else:
        try:
            dev_data, dev_table = load_tableqa_data(path_wikisql, mode=args.split, no_hs_tok=True)
        except Exception:
            logger.error('Input file not found!')
            exit(-1)
        dev_loader = torch.utils.data.DataLoader(
            batch_size=args.bS,
            dataset=dev_data,
            shuffle=False,
            num_workers=1,
            collate_fn=lambda x: x  # keep each batch as a plain list of examples
        )
        with torch.no_grad():
            results, cnt, cnts, cnt_desc \
                = predict(dev_loader,
                          dev_table,
                          model,
                          model_bert,
                          bert_config,
                          tokenizer,
                          args.max_seq_length,
                          args.num_target_layers,
                          os.path.join(train_input_dir, args.split + '.db'))
        save_for_evaluation(os.path.join(save_path, 'pred_results.jsonl'),
                            results, args.split, 'pred', use_filename=True)
        if cnts:
            with open(os.path.join(save_path, 'eval_result.txt'), 'w') as f_eval:
                f_eval.write('--------- eval result ---------\n')
                for k, metric in enumerate(cnts):
                    f_eval.write(cnt_desc[k] + ': ' + '%.4f' % (metric / cnt,) + '\n')
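Both the dev check during training and `--eval True` mode report `acc_lx` (presumably, the predicted query matches the gold query in canonical form), `acc_x` (executing both queries against the split's SQLite database gives the same result) and `acc_mx`, which the code computes as the mean of the two. The sqlite3 sketch below illustrates the execution check under a simplifying assumption: results are compared as unordered lists of rows. `execution_match` and `mixed_accuracy` are hypothetical helpers, not the repository's `get_acc_x`, whose exact comparison rules are not shown in this article.

```python
import sqlite3
from typing import Sequence, Tuple


def execution_match(cursor: sqlite3.Cursor, gold_sql: str, pred_sql: str) -> bool:
    """True if the predicted query runs and returns the same rows as the gold query."""
    gold_rows = cursor.execute(gold_sql).fetchall()
    try:
        pred_rows = cursor.execute(pred_sql).fetchall()
    except sqlite3.Error:
        return False  # a predicted query that fails to execute counts as wrong
    # Compare as unordered collections of rows (assumption for this sketch)
    return sorted(gold_rows, key=repr) == sorted(pred_rows, key=repr)


def mixed_accuracy(lx_flags: Sequence[bool], x_flags: Sequence[bool]) -> Tuple[float, float, float]:
    """Return (acc_lx, acc_x, acc_mx), with acc_mx the mean of the other two."""
    n = len(lx_flags)
    acc_lx = sum(lx_flags) / n
    acc_x = sum(x_flags) / n
    return acc_lx, acc_x, (acc_lx + acc_x) / 2


conn = sqlite3.connect(":memory:")
cur = conn.cursor()
cur.execute("CREATE TABLE t (city TEXT, volume REAL)")
cur.executemany("INSERT INTO t VALUES (?, ?)", [("A", 1000), ("B", 5000), ("C", 2000)])
gold = "SELECT COUNT(city) FROM t WHERE volume < 3574"
pred = "SELECT COUNT(*) FROM t WHERE volume < 3574"
print(execution_match(cur, gold, pred))  # True: different text, same result
print(mixed_accuracy([False, True], [True, True]))  # (0.5, 1.0, 0.75)
```

The example shows why the two metrics can disagree: a query that differs from the gold query in form can still execute to the same answer.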

3. Model testing

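import json
from trained_model.model.customize_service import *
if __name__ == '__main__':
    # Checkpoint to load. With the default --train_url ('./data_and_model/'), training saves it
    # to ./data_and_model/model/best_model.pth, so adjust this path to match your run.
    model_path = r'./outputs/model/best_model.pth'
    my_model = ModelClass('', model_path)
    data = {
        # "How many cities have a four-week transaction volume below 3574 units
        #  and a period-over-period change below 69.7%?"
        "question": "近四周成交量小于3574套并且环比低于69.7%的城市有几个",
        "table_id": "252c7b6b302e11e995ee542696d6e445"
    }
    data = my_model._preprocess(data)
    result = my_model._inference(data)
    print(json.dumps(dict(result), ensure_ascii=False, indent=2))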
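The request only names a `table_id`, so it can help to check which columns that table actually has before writing test questions. The sketch below builds an id-to-headers map from a tables file. It assumes a JSON-lines layout with one table object per line carrying `"id"` and `"header"` keys (the code above does read `tb[b]["id"]`, but the header field name and the file name are assumptions); `load_table_headers` and `tables_path` are hypothetical.

```python
import json
from typing import Dict, List


def load_table_headers(tables_path: str) -> Dict[str, List[str]]:
    """Map each table id to its column headers.

    Assumes one JSON table object per line with "id" and "header" keys;
    check this against the files shipped in m-sql/TableQA/.
    """
    headers: Dict[str, List[str]] = {}
    with open(tables_path, encoding="utf-8") as f:
        for line in f:
            line = line.strip()
            if line:  # skip blank lines
                table = json.loads(line)
                headers[table["id"]] = table["header"]
    return headers


# Usage (tables_path is a placeholder for the tables file of the split you query):
# headers = load_table_headers(tables_path)
# print(headers["252c7b6b302e11e995ee542696d6e445"])
```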
