Database兼容 Python2 / Python3 适配编码的文件型数据容器

Posted 糖果天王

tags:

篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了Database兼容 Python2 / Python3 适配编码的文件型数据容器相关的知识,希望对你有一定的参考价值。

0x00 前言

训练内存轻量化
最近又在训练模型(炼丹),以前老抱怨,区区2万 samples 也好意思叫大数据,近期的任务似乎听到了我这个抱怨,纷纷都是什么“1700万个句子”,“4000个文档”的数据,对服务器内存一次次的进行着冲击。
虽说我之前已经写过一个CIR(CorpusIterationReader)类实现的文章用来解决类似问题(哎?我那篇文章哪去了,被吃了么……emmmm,以后再重发一次吧。)但是那个类也只能让 pivot 以 “文件指针+instance指针” 的方式进行顺序存取,不是很好处理 “shuffle后随机存取” 的情况,再者,“每个文件中包括多个samples” 的设计在多进程中容易产生冲突。
经 cyx 学长提醒,可以考虑每个 sample 单独作为一个文件。(我觉得吧,这个也会有个小问题,就是这个文件夹里千万别一不小心按一下 ls -all 不然要等好半天了哈哈哈)于是基于学长 OneFileDB 的设计,重构并实现了一些这种处理方案的工具类及工具函数,便于我基于 PyTorch 的模型得以正常训练。

跨版本编码兼容
实现过程中,由于 python2 的老项目和 python3 的新项目都需要使用,编码也成了一个大难题。我参考了 Bert 里的 convert_to_unicode,研究了 ujson 和 json.JSONEncoder 是如何将不同编码的数据处理成 unicode 并存为 json 格式的,并把相关的实现也放了进去。

0x01 用法介绍

对于一个 json 文件而言,其内容通常是一个 list,其中每个 sample 以 dict 的形式存储。
对于模型而言,我们需要的是,在 sample 的数量足够多时,还要能够较快地通过下标(或者key)来获取到对应的 sample 喂给模型。

# JSON EXAMPLE
# A list of samples; each sample is a dict whose 'info' dict carries the
# sample id ('sid'), which the databases use as the per-sample key.
j = [{'info': {'sid': 'test1'},
      'words': [{'id': 'w0', 'word': u'电'.encode('utf-8')},
                {'id': 'w1', 'word': u'话'.encode('utf-8')},
                {'id': 'w2', 'word': '[unused10]'},
                {'id': 'w3', 'word': '0'},
                {'id': 'w4', 'word': '2'},
                {'id': 'w5', 'word': '1'},
                {'id': 'w6', 'word': '-'},
                {'id': 'w7', 'word': '3'},
                {'id': 'w27', 'word': '0'}],
      'entities': [], 'relations': []},
     {'info': {'sid': 'test2'},
      'words': [{'id': 'w0', 'word': u'地'.encode('utf-8')},
                {'id': 'w1', 'word': u'址'.encode('utf-8')}],
      'entities': [], 'relations': []}]

我们有三种方式来进行存储:

  • OneFileDB,即单文件存储,和我们平时直接读一个文件进来没有两样
  • FolderDB,文件夹存储,文件夹中的每一个文件是一个 sample
  • CFolderDB,压缩文件夹存储,是 FolderDB 的继承类,不同点在于每个 sample 文件都经过 zlib 压缩(代码中称为 encrypt/decrypt,但实际上只是压缩,并非真正的加密)
# In particular, a json file can be read in as a OneFileDB and then turned
# into a FolderDB via the member function `transfer_to_folderdb(path=<folder_path>)`.
# The backend type is chosen from the path suffix: '.json' -> OneFileDB,
# '.cfolder' -> CFolderDB, anything else -> FolderDB.
db_of = Database('./test.json')  # OneFileDB
db_f1 = db_of.transfer_to_folderdb('./test')  # FolderDB
db_f2 = Database('./test')  # FolderDB
db_cf = Database('./test.cfolder')  # CFolderDB

这几种 DB 的使用,也是通常的写入,下标读取,遍历,获得 samples 长度等。
而对于 Folder 类的 DB 来说,还有额外的 append 函数,方便其增加新的 samples。

# `db` stands for any database created above; `db_f` must be a folder-type
# database (FolderDB / CFolderDB), since only those provide append().
db.write(samples=j)
db_f.append(samples=j)
for position, sample in enumerate(db):
    print(position, sample)
print(len(db))
print(db[1])

0x02 Source Code

Database 主类

# coding: utf-8
# ==========================================================================
#   Copyright (C) 2016-2020 All rights reserved.
#
#   filename : training_dbs_new.py
#   origin   : cyx / caoyixuan
#   author   : chendian / okcd00@qq.com
#   date     : 2020-07-21
#   desc     : An alternative to the original database class (multi-json).
#              can be called as a dict or a list.
# ==========================================================================

class Database(object):
    """
    A unified wrapper for OneFileDB, FolderDB and CFolderDB.

    If ``samples`` is given, an in-memory OneFileDB bound to ``path`` is
    created. Otherwise the backend is picked from ``path`` by
    ``determine_mode``. The wrapper supports ``len()``, integer indexing,
    slicing (which yields a generator), iteration and lookup by sample id.
    """

    def __init__(self, path, samples=None, n_samples=None, read_only=True, load_now=False):
        if samples is not None:
            db = OneFileDB(path, samples, n_samples=n_samples)
        else:
            mode = self.determine_mode(path)
            logging.info('database mode: %s', mode)
            if mode == 'all_samples_one_file':
                db = OneFileDB(path, samples=None, n_samples=n_samples,
                               read_only=read_only, load_now=load_now)
            elif mode == 'one_sample_per_file':
                db = FolderDB(path, n_samples=n_samples,
                              read_only=read_only, load_now=load_now)
            elif mode == 'cfolder':
                db = CFolderDB(path, n_samples=n_samples,
                               read_only=read_only, load_now=load_now)
            else:
                raise ValueError("Unknown mode: {}".format(mode))

        self.db = db

    @property
    def sids(self):
        """Sample ids of the backend (None until the backend has loaded them)."""
        # Read through to the backend on every access. Backends load lazily
        # and fill in their ``sids`` after construction, so a copy taken in
        # __init__ would stay None forever.
        return self.db.sids

    @sids.setter
    def sids(self, value):
        self.db.sids = value

    @staticmethod
    def determine_mode(label_path):
        """Map a path to a backend mode based on its suffix."""
        if label_path.endswith('.json'):
            mode = 'all_samples_one_file'
        elif label_path.endswith('.cfolder') or label_path.endswith('.cfolder/'):
            mode = 'cfolder'
        else:  # directory path without postfix
            mode = 'one_sample_per_file'
        return mode

    def write(self, samples):
        return self.db.write(samples)

    def get_by_sid(self, sid):
        return self.db.get_by_sid(sid)

    def __getitem__(self, item):
        if isinstance(item, slice):
            return self.sl(item)
        return self.db[item]

    def sl(self, key):
        """Lazily yield the samples selected by the slice ``key``."""
        start, stop, step = key.indices(len(self))
        for i in range(start, stop, step):
            yield self.db[i]

    def __len__(self):
        return self.db.__len__()

    def __iter__(self):
        return self.db.__iter__()

    def next(self):
        return self.db.next()

    @property
    def all_samples(self):
        return self.db.all_samples


if __name__ == "__main__":
    # Smoke test: './test' has neither a '.json' nor a '.cfolder' suffix,
    # so this opens it as a FolderDB.
    sd = Database('./test')

DB基类与三种衍生

class TrainDBBase(object):
    """
    Base class of the training databases: an immutable dataset once written.

    Subclasses implement ``write``, ``get_by_sid``, ``__getitem__`` and
    ``__len__``. Iteration and ``all_samples`` are built on top of
    ``__getitem__`` and ``__len__``.
    """

    def write(self, samples):
        """Save samples."""
        raise NotImplementedError()

    def get_by_sid(self, sid):
        """Get a sample by its sample id."""
        raise NotImplementedError()

    def __getitem__(self, item):
        """Get a sample by its index in the dataset."""
        raise NotImplementedError()

    def __len__(self):
        """Return the number of samples in this dataset."""
        raise NotImplementedError()

    def __iter__(self):
        # The database is its own iterator; resetting the cursor here makes
        # every new ``for`` loop start from the first sample.
        # NOTE: nested loops over the same object share this single cursor.
        self.n = 0
        return self

    def next(self):
        """Return the next sample (Python 2 iterator protocol)."""
        # Default the cursor to 0 so next() also works when called without
        # a preceding __iter__() (e.g. through a wrapper's next()).
        n = getattr(self, 'n', 0)
        if n >= len(self):
            raise StopIteration
        self.n = n + 1
        return self[n]

    def __next__(self):
        # Python 3 iterator protocol.
        return self.next()

    @property
    def all_samples(self):
        """Return all samples in this dataset as a list."""
        return [self[i] for i in range(len(self))]


class FolderDB(TrainDBBase):
    """
    A database stored as a folder where each sample is written to its own file.

    Samples can be fetched by file name (the sample id) with ``get_by_sid``,
    or by position with ``db[i]``. Positional access uses the register of
    sample ids returned by ``load_register(folder)``.
    """

    def __init__(self, folder, n_samples=None, read_only=True, load_now=False):
        self.folder = folder
        self.compress = False
        # Optional cap on how many sample ids are exposed (e.g. for small test runs).
        self.n_samples = n_samples
        # Cached register of sample ids; None until load_register() runs.
        self.sids = None
        # NOTE(review): read_only is accepted for interface compatibility but not enforced.
        if load_now:
            self.load_register()

    def write(self, samples):
        write_one_sample_per_file(samples, self.folder)
        # Folder contents changed: drop the cached register so it is re-read.
        self.sids = None

    def append(self, samples):
        append_write_one_sample_per_file(samples, self.folder)
        # Otherwise len()/indexing would keep reporting the pre-append register.
        self.sids = None

    def get_by_sid(self, sid):
        """Load and return the sample stored in the file named ``sid``."""
        file_path = path_join(self.folder, sid)
        # Use a context manager so the file handle is closed, and read it as
        # UTF-8 instead of relying on the platform's locale encoding.
        with open(file_path, 'r', encoding='utf-8') as f:
            return json.load(f)

    def __getitem__(self, index):
        self.load_register()
        sid = self.sids[index]
        return self.get_by_sid(sid)

    def __len__(self):
        self.load_register()
        return len(self.sids)

    def load_register(self):
        """Populate ``self.sids`` from the folder once; later calls are no-ops."""
        if self.sids is not None:
            return
        sids = load_register(self.folder)
        if self.n_samples:
            sids = sids[: self.n_samples]
        self.sids = sids
        assert len(self.sids) == len(set(self.sids)), 'exist duplicated sids'


class CFolderDB(FolderDB):
    """
    A FolderDB whose sample files are stored compressed.

    Samples are read back with ``json_load(..., decrypt=True)``, i.e. each
    file is zlib-decompressed before JSON parsing. Despite the
    encrypt/decrypt naming, this is compression, not encryption.
    """

    def __init__(self, folder, n_samples=None, read_only=True, load_now=False):
        super(CFolderDB, self).__init__(
            folder, n_samples=n_samples, read_only=read_only, load_now=load_now)
        # FolderDB.__init__ sets compress=False, but this DB's files are compressed.
        self.compress = True

    def write(self, samples):
        write_one_sample_per_file(samples, self.folder, compress=True)
        # Folder contents changed: drop the cached register so it is re-read.
        self.sids = None

    # NOTE(review): append() is inherited from FolderDB, which calls
    # append_write_one_sample_per_file without compress=True. If that helper
    # writes uncompressed files, get_by_sid below cannot read them back.
    # Confirm the helper's signature before relying on append() here.

    def get_by_sid(self, sid):
        """Load and decompress the sample stored in the file named ``sid``."""
        file_path = path_join(self.folder, sid)
        return json_load(path=file_path, mode='r', decrypt=True)


class OneFileDB(TrainDBBase):
    """
    A database stored as a single JSON file holding a list of samples.

    Each sample is expected to be a dict with its id at
    ``sample['info']['sid']``. The file is read lazily on first access
    unless ``samples`` are provided or ``load_now`` is set.
    """

    def __init__(self, file_path, samples=None, n_samples=None, read_only=True, load_now=False):
        self.file_path = file_path
        self.sids = None
        self.samples = None
        self.compress = False
        self.sid_to_sample = None
        # Optional cap on how many samples are kept (e.g. for small test runs).
        self.n_samples = n_samples
        # NOTE(review): read_only is accepted for interface compatibility but not enforced.
        if samples is not None:
            self.set_samples(samples)
        elif load_now:
            self.load()

    def write(self, samples):
        json_dump(
            obj_=samples, path=self.file_path,
            mode='w', encrypt=self.compress)
        # Drop cached samples so later reads reflect the file just written.
        self.samples = None
        self.sids = None
        self.sid_to_sample = None

    def get_by_sid(self, sid):
        self.load()
        return self.sid_to_sample[sid]

    def load(self):
        """Read the samples from ``file_path`` once; later calls are no-ops."""
        if self.samples is not None:
            return
        samples = json_load(
            path=self.file_path, mode='r',
            decrypt=self.compress)
        self.set_samples(samples)

    def set_samples(self, samples):
        """Cache ``samples`` (truncated to ``n_samples`` if set) and index them by sid."""
        if self.n_samples:
            samples = samples[: self.n_samples]
        self.samples = samples
        self.sids = [s['info']['sid'] for s in self.samples]
        self.sid_to_sample = {s['info']['sid']: s for s in self.samples}

    def transfer_to_folderdb(self, path):
        """Write every sample to its own file under ``path``; return a Database for it."""
        # Ensure samples are in memory; an unloaded DB would otherwise hand None to the writer.
        self.load()
        write_one_sample_per_file(
            answers=self.samples,
            folder=path,
            compress=self.compress)
        return Database(path=path)

    def __getitem__(self, item):
        self.load()
        return self.samples[item]

    def __len__(self):
        self.load()
        return len(self.samples)

Magic Tools

这种任务,最麻烦的就是 Python2 和 Python3 之间的兼容性,而兼容性问题中最麻烦的又体现在编码上:Python2 的 unicode 类型对应 Python3 的 str 类型,Python2 的 str 类型对应 Python3 的 bytes 类型。因此需要下面这组在两者之间统一转换的工具函数。

头文件及依赖

from __future__ import unicode_literals
from six import PY2, PY3
import logging
import os
import zlib
import numpy as np
from io import open
JSON_MODULE = None

JSON编码相关

try:
    # if you have ujson, it will be faster
    # but the calling method is different.
    import ujson as json
    JSON_MODULE = 'ujson'
except ImportError:
    import json
    JSON_MODULE = 'json'


    class JsonBytesEncoder(json.JSONEncoder):
        # json.dumps
        def default(self, obj):
            # if isinstance(obj, np.ndarray):
            #     return obj.tolist()  # for further support.
            if isinstance(obj, bytes):
                return convert_to_unicode(obj)
                # return str(obj, encoding='utf-8')
            return json.JSONEncoder.default(self, obj)


def json_dumps(obj_, encrypt=False):
    """Serialise ``obj_`` to a JSON string, tolerating bytes values.

    Args:
        obj_: JSON-compatible data (bytes values are converted to text).
        encrypt: if True, return the zlib-compressed result of zlib_encrypt
            instead of the plain string.

    Returns:
        The JSON string, or compressed bytes when ``encrypt`` is True.
    """
    if JSON_MODULE == 'json':
        _json_str = json.dumps(obj_, cls=JsonBytesEncoder)
    elif JSON_MODULE == 'ujson':
        # Compare the full major version. ``__version__[0]`` would read e.g.
        # '10.0.0' as major 1 and skip the reject_bytes argument.
        major_version = int(json.__version__.split('.')[0])
        if major_version < 2:
            # standard ujson-1.35 for python2.7
            _json_str = json.dumps(obj_)
        else:  # newer ujson (e.g. 3.0.0 for python3.6) rejects bytes unless told not to
            _json_str = json.dumps(obj_, reject_bytes=False)
    else:
        _json_str = json.dumps(obj_)
    if encrypt:
        return zlib_encrypt(_json_str)
    return _json_str


def json_dump(obj_, path=None, mode='w', stream=None, encrypt=False):
    """Serialise ``obj_`` with json_dumps and write it to ``stream`` or ``path``.

    Args:
        obj_: data to serialise.
        path: file to write when ``stream`` is None.
        mode: file mode for ``path``; forced to 'wb' when ``encrypt`` is set,
            because the compressed payload is bytes.
        stream: an already-open writable object. If given, ``path`` and
            ``mode`` are ignored.
        encrypt: zlib-compress the JSON payload.
    """
    if encrypt:  # zlib.compress produces bytes
        mode = 'wb'
    # Serialise before opening the file, so a serialisation error does not
    # leave behind a truncated/empty file.
    payload = json_dumps(obj_, encrypt)
    if stream is not None:
        stream.write(payload)
        return
    if 'b' not in mode:
        # io.open text mode only accepts unicode. Python 2's json.dumps can
        # return a byte str, so normalise it (a no-op for Python 3 str).
        payload = convert_to_unicode(payload)
    with open(path, mode) as f:
        f.write(payload)


def json_loads(str_, decrypt=False):
    """Parse a JSON document, optionally zlib-decompressing it first.

    Args:
        str_: the JSON text, or zlib-compressed bytes when ``decrypt`` is True.
        decrypt: run the input through zlib_decrypt before parsing.

    Returns:
        The decoded Python object.
    """
    # json and ujson expose the same loads() signature.
    document = zlib_decrypt(str_) if decrypt else str_
    return json.loads(document)


def json_load(path, mode='r', decrypt=False):
    """Read and parse the JSON file at ``path``.

    Args:
        path: file to read.
        mode: file mode; forced to 'rb' when ``decrypt`` is set, because the
            zlib-compressed payload is bytes.
        decrypt: zlib-decompress the file content before parsing.

    Returns:
        The decoded Python object.
    """
    if decrypt:  # the zlib-compressed payload is bytes
        mode = 'rb'
    # JSON text is UTF-8 (RFC 8259). Don't depend on the platform's locale
    # encoding, which breaks non-ASCII files on e.g. GBK/cp1252 systems.
    # Binary mode must not receive an encoding argument.
    encoding = None if 'b' in mode else 'utf-8'
    with open(path, mode, encoding=encoding) as f:
        obj_ = json_loads(f.read(), decrypt)
    return obj_


def zlib_encrypt(data):
    """zlib-compress ``data`` and return the compressed bytes.

    Containers (list/dict/tuple) are serialised with json_dumps first. Any
    other value is converted to unicode text. Despite the name, this is
    compression, not encryption.
    """
    if not isinstance(data, (list, dict, tuple)):
        text = convert_to_unicode(data)  # py2-unicode / py3-str
    else:
        text = json_dumps(data)  # data structure -> JSON text
    # zlib only accepts bytes-like input.
    raw = convert_to_bytes(text)
    return zlib.compress(raw)


def zlib_decrypt(str_):
    """Decompress zlib bytes and return the result as unicode text."""
    return convert_to_unicode(zlib.decompress(str_))


def path_join(*args):
    """Join path components with the OS separator, normalising each to unicode.

    ``''.join`` would glue components together without a separator
    (``'./test' + 'sid'`` -> ``'./testsid'``); os.path.join inserts one only
    where needed, so folders with a trailing '/' behave as before.
    """
    if not args:
        # os.path.join needs at least one argument; keep the old '' result.
        return ''
    return os.path.join(*[convert_to_unicode(each) for each in args])


def write_data(stream, text, encoding=ubantu的python2与python3的相关兼容更新问题

使用Conda [Anaconda]批量更新包裹

python2.9与python2.11兼容不

python2.6升级2.7

升级python2.7

Python2与Python3兼容