NLP 实战 | 我发现的飞桨(paddlepaddle)大坑

Posted 幻灰龙

tags:

篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了NLP 实战 | 我发现的飞桨(paddlepaddle)大坑相关的知识,希望对你有一定的参考价值。


上一篇 我们介绍了数据集和模型的上传/下载管理。解决数据集和模型的管理问题,在我们的新成员加入时就体现了优势,新成员克隆仓库代码、根据文档执行命令下载相关数据集、下载相关模型、启动服务、执行测试,以最快的时间跑通全流程,进而获取新任务,达成 first commit 的目标。本节我们分析一个实战问题诊断的过程。

分离阶段:以交付为目标

我们反复强调,从数据集入手、训练模型、服务化,最终我们是要达成集成或交付的目标。所以事实上核心要解决的有两个方面的事情:

  • 清洗数据输入、选择算法和预训练模型,训练出有足够精度和召回率的模型,这是基础,这个步骤算法工程师们已经需要付出极大的努力。项目开发中要为算法工程师提供一定的时间集中精力解决这个部分。
  • 提供服务,这里面有大量的工程问题,例如前一个步骤不需要考虑代码如何组织的问题,对内存占用、性能也可以暂时不做要求。但是到了提供服务阶段,则必须考虑这些问题。

尽早集成:暴露内存和性能问题

实际上工程师们在第一步已经付出极大的努力,这第二步当然也可以通过学习获得快速成长。本节我们以一个实际的例子说明工程上的模块化和问题诊断的基本手法,我们展示实际的代码和问题只是为了示例。

当我们做了集成,提供了API给上层的时候,事实上是信心不太足的:

  • “该接口会耗时,它是异步的”
  • “那么它的平均耗时多少”
  • “嗯,确实需要测一下”

当我这么回复的时候,实际上说明我们的模块缺乏对响应的基本 profile,这是一个信号。但是为了尽早集成这个目标,是可以先把 API 推出来。上层这边做集成里还有一个压测环节,测试时立刻暴露了问题。

我们的工程师经过努力做了一个打标签的接口,这个接口内部有一组策略,其中有一个环节用到了paddlehub。在批量跑测试中立刻暴露了第1个问题:

  • 在4G内存的机器上接口跑着跑着就会报内存不足崩溃

我们的工程师咨询 paddlepaddle 的技术人员,认为应该跑在8G内存上才可以。我其实是存疑的,但是官方这么说,我们就试一下,在8G内存的服务器上提供服务确实可以把批量测试跑完而不出问题。

接着的一个问题是接口的平均耗时问题。实际上使用预训练模型做微调训练后的模型还是比较大,这种模式我心里对内存占用总有点没底。另外一个没底的是 NLP 预测是否一定会是耗时的?不过在多次集成模型之后,包括第2节我们提到的模型加载单例模式,我们还是有一些经验:

  • 进程内应该用单例模式管理模型的加载,避免同进程内反复加载同一个模型,这对内存占用和耗时都是不必要的浪费。
  • 如果有词向量计算,应该尽可能把能预先计算的做好预计算,提供服务的接口内应该做最小计算。
  • 如果是需要动态加载的,应该最小化每次加载的数据量。解决单次内存占用最小化和加载耗时之间的平衡。

回头来说,批量测试暴露的性能问题:

  • 该接口平均耗时5秒

这个数据并不好,我们立刻调整了优先级,必须解决性能问题。

重构代码:做好模块化

这又回到上一节 里提到的“Hackable Project” 的主题。我们希望问题出现的时候,代码是可以明显看出问题可能在哪。这里我花了一些时间重构已有的代码,原来的代码如下:

class TagService:
    def __init__(self, config, options):
        self.config = config
        self.options = options

    def load(self):
        self.inner_classifier = SGDText2PL()
        self.tag_label = TagLabel()
        self.tag_label.load_label()
        self.inner_classifier.load()
        self.ocr_client = OCRClient(self.config, self.options)
        self.code_extract = CodeExtractService(self.config, self.options)
        self.tag_score = TagScoreService(
            self.tag_label.kg_list, self.tag_label.catalog_xy_list, self.tag_label.kg_position_list)
        self.hub = PaddleHubPL()
        self.hub.load()
        self.ocr_client.load()

    def catalog_predict(self, title, content):
        # import paddlehub as hub
        # sentence = []
        # sentence.append(title + content)
        # sentences = []
        # catalog_id = ''
        # sentences.append(sentence)
        # # model = hub.Module(
        # #     name='ernie_tiny',
        # #     version='2.0.1',
        # #     task='seq-cls',
        # #     load_checkpoint= get_tag_model_path()+'model.pdparams',
        # #     label_map=LABEL_MAP)
        # results = self.hub.model.predict(sentences, max_seq_len=128, batch_size=1, use_gpu=False)
        # for idx, text in enumerate(sentences):
        #     catalog_id = results[idx]
        catalog_id = self.hub.predict(title, content)

        return catalog_id

    def predict(self, title, content):

        # print(question)
        # title = question.get('title')  # 获取标题
        # content = question.get('body')  # 获取内容

        code_title = get_en_character(title)  # 找标题代码
        code_content = self.code_extract.extract_code_for_title(content)[
            'code']  # 找内容代码
        checked = 0
        pre_result = []
        catalog_id = ''
        code_id = ''
        status = 0
        img_list = []
        ocr_text = ''

        try:
            img_list = get_img_url2(content)

            if len(img_list) > 0:

                # print(img_list[0])
                checked = 4
                ocr = self.ocr_client.extract(img_list[0])
                ocr_text = '\\n'.join(ocr['code_text'])
                ocr_code_content = get_code_character(ocr_text)
                pre_result = self.inner_classifier.classify(ocr_code_content)
                code_id = pre_result.get('language')
                if code_id == 'text':
                    code_id = ''
            else:
                if code_title != '' and code_content != '':  # 当标题和内容都有代码的情况
                    checked = 0
                if code_title == '' and code_content != '':  # 当标题没有代码,但内容有代码的情况
                    checked = 1
                if code_title != '' and code_content == '':  # 当标题和代码都没有代码的情况
                    checked = 2
                if code_title == '' and code_content == '':  # 当标题和代码都没有代码的情况
                    checked = 3

            # if checked == 4:
            #     ocr = self.ocr_client.extract(img_list[0])
            #     ocr_text = '\\n'.join(ocr['code_text'])
            #     ocr_code_content = get_code_character(ocr_text)
            #     pre_result = self.inner_classifier.classify(ocr_code_content)
            #     code_id = pre_result.get('language')
            #     print('#$$$$$$#'+code_id)

            if checked == 0 and code_id == '':
                pre_result = self.inner_classifier.classify(code_content)
                code_id = pre_result.get('language')

            if checked == 1 and code_id == '':
                pre_result = self.inner_classifier.classify(code_content)
                code_id = pre_result.get('language')

            if checked == 2 and code_id == '':
                datatitle = []
                datatitle.append(code_title)
                pre_result = self.inner_classifier.classify(datatitle)
                code_id = pre_result.get('language')

            if checked == 3 and code_id == '':
                cn_title = get_cn_character(title)
                cn_content = get_cn_character(content)
                code_id = self.catalog_predict(cn_title, cn_content)

                # print('**************'+code_id)

        except Exception as e:
            temp = str(e)
            code_id = '其他'
            status = -1

        if code_id == 'jar':
            catalog_id == 'jar'

        if checked != 3 and status != -1:
            cn_title = get_cn_character(title)
            cn_content = get_cn_character(content)
            character = {}
            if len(cn_title.strip()) == 0 and len(cn_content.strip()) == 0:
                character = {}
                status = 1
            else:
                catalog_id = self.catalog_predict(cn_title, cn_content)
                # print(catalog_id + '#################')
                status = 2
        if status == 0:
            catalog_id = code_id

        if catalog_id == '':
            catalog_id = '其他'

        if catalog_id == 'text':
            catalog_id = '其他'
        if code_id == 'text':
            catalog_id = '其他'

        return {
            'title': title,
            'content': content,
            'code_id': code_id,
            'catalog_id': catalog_id,
            'status': status
        }

这段代码存在两个典型的问题:

  • 方法 catalog_predict 存在大段注释代码。不要的代码不应该提交到 git,不要用注释的方式保留大段的“备用”代码,应该毫不留情地删除它,如果想看该文件历史上的代码片段,直接看 git 的历史即可。很多工程师不能理解这点,实际上有了 git ,你可以查看该代码文件历史上的任何提交过的代码,没有必要用注释的方式保留“备用代码”,例如上面这个代码片段,就是从 git 的历史 commit 里拷贝的。
  • 方法 predict 里存在一种典型的用flag变量,做代码分支逻辑判断依据的实现方式,而且存在两个交叉的 flag 变量:checkedstatus

第一个问题好解决,删除代码提交即可。第二个问题则让代码不好诊断问题。例如:

  • 几个连续的 if code_title != '' and code_content != '' 需要很费劲才能知道checked flag 的含义,以及它确实在每种情况下只会出现一个唯一的值,这样的代码一不小心就会挂。
  • 几个连续的 if checked == 0 and code_id == '' 需要很费劲才能知道这个分支的含义,以及它确实和其他 if 分支只会被执行一次。至少应该用if elif elif也比全部及格不做闭环的if好理解。
  • 需要很费劲才能理解 if checked == 3 and code_id == ''catalog_idcode_id 不存在时用来赋予了 code_id 的值

总之,需要很费劲才能分析这段代码的分支处理逻辑,以及多个处理情况之间是否有交叉,谁的优先级更高。

经过协调,我决定自己上手改这段代码。我觉的只在类内部用多个函数也能写好,不过我决定拆分下,让每个小类只做一件事。

首先分析 predict 要解决的问题,核心思路应该是:

  • 识别代码标签:
    • 如果内容里有图片,走OCR识别代码类型
    • 否则,如果内容里有代码,识别内容里的代码类型
    • 否则,如果标题里有代码,识别标题里的代码类型
  • 识别大类标签:
    • 提取代码和内容里的中文,用 paddlepaddle 模型来对标签分类

因此,建立一个子文件夹,把上面四个叶子结点的识别分别独立一个类,每个类只做一件事:

ocr_predict.py

class OCRPredict:
    def __init__(self, config, options, code_classifier):
        self.config = config
        self.options = options
        self.code_classifier = code_classifier
        self.ocr_client = None

    def load(self):
        if self.ocr_client is not None:
            return
        self.ocr_client = OCRClient(self.config, self.options)
        self.ocr_client.load()

    def predict(self, content):
        # 查找并处理图片,TODO:查找图片遍历内容和代码提取遍历重复了!
        img_list = get_img_url2(content)
        if len(img_list) == 0:
            return {
                'err': ErrorCode.NOT_FOUND
            }

        # OCR 识别
        ocr = self.ocr_client.extract(img_list[0])
        ocr_text = '\\n'.join(ocr['code_text'])

        # TODO: get_code_character 这个步骤未必要,直接丢给 code_classifier 也是可以的
        ocr_code_content = get_code_character(ocr_text)
        pre_result = self.code_classifier.classify(ocr_code_content)
        code_name = pre_result.get('language')
        if code_name == 'text':
            return {
                'err': ErrorCode.NOT_FOUND
            }
        else:
            return {
                'err': ErrorCode.SUCCESS,
                'code_name': code_name,
            }

code_predict.py

class CodePredict:
    def __init__(self, config, options, code_classifier):
        self.config = config
        self.options = options
        self.code_classifier = code_classifier

    def load(self):
        pass

    def predict(self, code_content):
        pre_result = self.code_classifier.classify(code_content)
        code_name = pre_result.get('language')
        return {
            'err': ErrorCode.SUCCESS,
            'code_name': code_name
        }

category_predict.py

class CategoryPredict:
    def __init__(self, config, options):
        self.config = config
        self.options = options
        self.hub = None

    def load(self):
        if self.hub is not None:
            return
        self.hub = PaddleHubPL()
        self.hub.load()

    def predict(self, title, content):
        category_name = self.hub.predict(title, content)
        return {
            'err': ErrorCode.SUCCESS,
            'category_name': category_name,
        }

这里,OCRPredictCodePredict 都使用依赖注入的方式让外层传入code_classifier。每个类做的事情很简单:load and predict

有了上述三个叶子结点,我们提供一个 策略类 来组织管道的复合逻辑:

class ComposePredict:
    def __init__(self, config, options):
        self.config = config
        self.options = options

        self.code_classifier = None
        self.code_predict = None
        self.ocr_predict = None
        self.category_predict = None
        self.has_load = False

    def load(self):
        if self.has_load:
            return
        self.code_classifier = SGDText2PL()
        self.code_classifier.load()

        self.ocr_predict = OCRPredict(
            self.config, self.options, self.code_classifier)
        self.ocr_predict.load()

        self.code_predict = CodePredict(
            self.config, self.options, self.code_classifier)
        self.code_predict.load()

        self.category_predict = CategoryPredict(self.config, self.options)
        self.category_predict.load()
        self.has_load = True

    def predict(self, title, content, code_title, code_content, cn_title, cn_content):
        # 识别 code_name
        code_name = None
        code_ret = self.predict_code_name(
            title, content, code_title, code_content)

        if code_ret['err'] == ErrorCode.SUCCESS:
            code_name = code_ret['code_name']
            if code_name == 'text' or code_name == 'scheme' or code_name == '':
                code_name = '其他'
            if code_name == 'c':
                code_name = 'c语言'
            if code_name == 'go':
                code_name = 'golang'

        # 识别 category_name
        category_name = None
        category_ret = self.category_predict.predict(cn_title, cn_content)

        if category_ret['err'] != ErrorCode.SUCCESS:
            return category_ret
        else:
            category_name = category_ret['category_name']
            if category_name == 'text' or category_name == '':
                category_name = '其他'

        return {
            'err': ErrorCode.SUCCESS,
            'code_name': code_name,  # 可空
            'category_name': category_name,
        }

    def predict_code_name(self, title, content, code_title, code_content):
        # 内容有代码,尝试识别内容里的代码(内容比标题优先级高)
        if code_content != '':
            ret = self.code_predict.predict(code_content)
            if ret['err'] == ErrorCode.SUCCESS:
                return ret

        # 标题有代码,尝试识别标题里的代码
        if code_title != '':
            ret = self.code_predict.predict([code_title])
            if ret['err'] == ErrorCode.SUCCESS:
                return ret

        # 标题和内容都没有代码,尝试识别图片里的代码(成本最高,放在最后)
        ret = self.ocr_predict.predict(content)
        if ret['err'] == ErrorCode.SUCCESS:
            return ret

        # 识别失败
        return {
            'err': ErrorCode.NOT_FOUND
        }

可以看到,这个类 聚合 了前面的三个功能简单的类, ComposePredict 的使用方式同样是 load and predict。但是我们重点看下区别:

  • predict_code_name 里面使用 快速短路 的方式,从上往下组织管道处理:
    • 如果内容有代码,尝试识别内容里的代码(内容比标题优先级高),成功就直接返回
    • 如果标题有代码,尝试识别标题里的代码,成功就直接返回
    • 如果标题和内容都没有代码,尝试识别图片里的代码(成本最高,放在最后),成功就直接返回
    • 否则,返回失败

当你有一个 管道处理 流程时,用这种方式可以良好的组织管道过程和优先级编排,代码也不会很乱。事实上它是经典设计模式 职责链 模式。不过我日常并不记得它的名字叫什么,如果一个代码组织适合这样写,我们就这样写了。这里给它们起名字只是我在写博客的时候方便说明才起的而已。

再上面的一层 predict 内部,则是拆解了原始代码的意图:

  • 无论怎样 category_name 都是要识别的
  • code_name 可能不存在
  • 当然,这里顺手的变动是,code_namecategory_name 比原来的 code_idcategory_id 更符合含义,它们是名字,不是id。

好了,到了这里,核心的代码重构就完成了,其他还有一些细节的地方只是同理。

耗时分析:找到性能瓶颈

我们的目标是诊断性能瓶颈,最原始的方法就是对代码的每个环节做耗时统计,看哪部分耗时最多。先提供两个AK-47小函数:

def time_start(name):
    '''开始计时,返回计时器上下文'''
    return {
        'name': name,
        'start': round(time.time() * 1000)
    }

def time_end(ctx):
    '''结束计时,返回耗时统计'''
    end = round(time.time() * 1000)
    ctx['end'] = end
    ctx['elapse_mill_secs'] = end - ctx['start']
    ctx['elapse_secs'] = ctx['elapse_mill_secs']/1000
    print("{}耗时:{}毫秒".format(ctx['name'], ctx['elapse_mill_secs']))
    return ctx

于是,我们只需在代码的不同环节加上耗时统计:

timer = time_start()
...
time_end(timer)

通过这种方式,我们很快找到最耗时的地方是 CategoryPredict 类的 predict 方法。而这个类的实现其实是委托给 PaddleHubPL 类,我们看下这个类:

class PaddleHubPL:
    def __init__(self) -> None:
        # 使用 g_model_manager 做单例
        self.model_key = 'paddlehub_tag_svm'
        g_model_manager.register(self.model_key, lambda: PaddleHubPLImpl())
        self.model = None

    def load(self):
        try:
            self.model = g_model_manager.load(self.model_key)
            return {
                'err': ErrorCode.SUCCESS
            }

        except Exception as e:
            logger.error('load SGDText2PL model failed:', str(e))
            logger.error(traceback.format_exc())
            return {
                'err': ErrorCode.NOT_FOUND
            }

    def predict(self, title, content):
        ret = self.model.predict(title, content)
        return ret

由于 CategoryPredict 内部没有别的逻辑,它可以直接被 PaddleHubPL 替代,不过这个我们可以先不管。 PaddleHubPL 内部使用 g_model_manager 来单例化 PaddleHubPLImpl那既然已经单例化了,至少 PaddleHubPLImplload 应该最多只会

以上是关于NLP 实战 | 我发现的飞桨(paddlepaddle)大坑的主要内容,如果未能解决你的问题,请参考以下文章

320万开发者在用的飞桨,全新发布推理部署导航图:打通AI应用最后一公里

百所高校,万人参与,飞桨校园AI Day收官!

越学越有趣:『手把手带你学NLP』系列项目07 ——机器翻译的那些事儿

深度学习核心技术精讲100篇(八十一)-NLP预训练模型ERNIE实战应用案例

深度学习核心技术精讲100篇(八十一)-NLP预训练模型ERNIE实战应用案例

想聊天?自己搭建个聊天机器人吧!