Database兼容 Python2 / Python3 适配编码的文件型数据容器
Posted 糖果天王
tags:
篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了Database兼容 Python2 / Python3 适配编码的文件型数据容器相关的知识,希望对你有一定的参考价值。
0x00 前言
训练内存轻量化
最近又在训练模型(炼丹),以前老抱怨,区区2万 samples 也好意思叫大数据,近期的任务似乎听到了我这个抱怨,纷纷都是什么“1700万个句子”,“4000个文档”的数据,对服务器内存一次次的进行着冲击。
虽说我之前已经写过一个CIR(CorpusIterationReader)类实现的文章用来解决类似问题(哎?我那篇文章哪去了,被吃了么……emmmm,以后再重发一次吧。)但是那个类也只能让 pivot 以 “文件指针+instance指针” 的方式进行顺序存取,不是很好处理 “shuffle后随机存取” 的情况,再者,“每个文件中包括多个samples” 的设计在多进程中容易产生冲突。
经 cyx 学长提醒,可以考虑每个 sample 单独作为一个文件。(我觉得吧,这个也会有个小问题,就是在这个文件夹里千万别一不小心执行一下 ls -all,不然要等好半天了哈哈哈)于是基于学长 OneFileDB 的设计,重构并实现了一些这种处理方案的工具类及工具函数,便于我基于 PyTorch 的模型得以正常训练。
跨版本编码兼容
实现过程中,由于 python2 的老项目和 python3 的新项目都需要使用,编码也成了一个大难题。于是参考 Bert 里的 convert_to_unicode,研究了 ujson 和 json.JSONEncoder 是如何将不同编码处理成 unicode 并存为 json 格式的,并把相关的实现也放了进去。
0x01 用法介绍
对于一个 json 文件而言,通常是一个 list,里面包含多个dict的形式存储的 samples
对于模型而言,我们需要的是,在 sample 的数量足够多时,还要能够较快地通过下标(或者key)来获取到对应的 sample 喂给模型。
# JSON EXAMPLE
# A list of samples; each sample is a dict whose id lives at sample['info']['sid'].
# Some 'word' values are deliberately UTF-8 encoded bytes to exercise the
# bytes/unicode handling of the JSON helpers.
j = [{'info': {'sid': 'test1'},
      'words': [{'id': 'w0', 'word': u'电'.encode('utf-8')},
                {'id': 'w1', 'word': u'话'.encode('utf-8')},
                {'id': 'w2', 'word': '[unused10]'},
                {'id': 'w3', 'word': '0'},
                {'id': 'w4', 'word': '2'},
                {'id': 'w5', 'word': '1'},
                {'id': 'w6', 'word': '-'},
                {'id': 'w7', 'word': '3'},
                {'id': 'w27', 'word': '0'}],
      'entities': [], 'relations': []},
     {'info': {'sid': 'test2'},
      'words': [{'id': 'w0', 'word': u'地'.encode('utf-8')},
                {'id': 'w1', 'word': u'址'.encode('utf-8')}],
      'entities': [], 'relations': []}]
我们有三种方式来进行存储:
- OneFileDB,即单文件存储,和我们平时直接读一个文件进来没有两样
- FolderDB,文件夹存储,文件夹中的每一个文件是一个 sample
- CFolderDB,压缩文件夹存储,是 FolderDB 的继承类,不同点在于每个 sample 经 zlib 压缩后再存储(代码中沿用了 encrypt / decrypt 的命名,但实际只是压缩,并非真正的加密)
# Note: once a json file has been read in as a OneFileDB,
# its member function `transfer_to_folderdb(path=<folder_path>)` generates a FolderDB.
db_of = Database('./test.json') # OneFileDB
db_f1 = db_of.transfer_to_folderdb('./test') # FolderDB
db_f2 = Database('./test') # FolderDB
db_cf = Database('./test.cfolder') # CFolderDB
这几种 DB 的使用,也是通常的写入,下标读取,遍历,获得 samples 长度等。
而对于 Folder 类的 DB 来说,还有额外的 append 函数,方便其增加新的 samples。
# `db` stands for any database object; `db_f` for a folder-type one
# (only folder-type databases provide `append`).
db.write(samples=j)
db_f.append(samples=j)
for idx, item in enumerate(db):
    print(idx, item)
print(len(db))
print(db[1])
0x02 Source Code
Database 主类
# coding: utf-8
# ==========================================================================
# Copyright (C) 2016-2020 All rights reserved.
#
# filename : training_dbs_new.py
# origin : cyx / caoyixuan
# author : chendian / okcd00@qq.com
# date : 2020-07-21
# desc : An alternative to the original database class (multi-json).
# can be called as a dict or a list.
# ==========================================================================
class Database(object):
    """Unified wrapper around OneFileDB, FolderDB and CFolderDB.

    The backend is chosen from ``path`` (see ``determine_mode``) unless
    ``samples`` is given, in which case a OneFileDB holding those samples is
    created. Indexing, slicing, ``len()``, iteration and writing are all
    forwarded to the backend stored in ``self.db``.
    """

    def __init__(self, path, samples=None, n_samples=None, read_only=True, load_now=False):
        """Create the backend database.

        Args:
            path: file or folder path; its suffix selects the backend.
            samples: optional in-memory samples; forces a OneFileDB.
            n_samples: forwarded to the backend, which keeps only the first
                ``n_samples`` entries when it is truthy.
            read_only: forwarded to the backend.
            load_now: forwarded to the backend; load eagerly when true.

        Raises:
            ValueError: if ``determine_mode`` returns an unknown mode.
        """
        if samples is not None:
            db = OneFileDB(path, samples, n_samples=n_samples)
        else:
            mode = self.determine_mode(path)
            logging.info('database mode: {}'.format(mode))
            if mode == 'all_samples_one_file':
                db = OneFileDB(path, samples=None, n_samples=n_samples,
                               read_only=read_only, load_now=load_now)
            elif mode == 'one_sample_per_file':
                db = FolderDB(path, n_samples=n_samples,
                              read_only=read_only, load_now=load_now)
            elif mode == 'cfolder':
                db = CFolderDB(path, n_samples=n_samples,
                               read_only=read_only, load_now=load_now)
            else:
                raise ValueError("Unknown mode: {}".format(mode))
        self.db = db

    @property
    def sids(self):
        """Sample ids of the backend.

        Read from the backend on every access: backends fill ``sids`` lazily,
        so a copy taken in ``__init__`` would stay ``None`` forever.
        """
        return self.db.sids

    @sids.setter
    def sids(self, value):
        # Kept so that code assigning to ``sids`` continues to work.
        self.db.sids = value

    @staticmethod
    def determine_mode(label_path):
        """Map a path to a backend mode name based on its suffix."""
        if label_path.endswith('.json'):
            return 'all_samples_one_file'
        if label_path.endswith(('.cfolder', '.cfolder/')):
            return 'cfolder'
        # directory path without postfix
        return 'one_sample_per_file'

    def write(self, samples):
        """Save ``samples`` through the backend."""
        return self.db.write(samples)

    def append(self, samples):
        """Append ``samples`` through the backend.

        Only backends that define ``append`` support this; for the others the
        attribute lookup raises AttributeError.
        """
        return self.db.append(samples)

    def get_by_sid(self, sid):
        """Return the sample whose id is ``sid``."""
        return self.db.get_by_sid(sid)

    def __getitem__(self, item):
        """Return one sample for an index, or a lazy generator for a slice."""
        if isinstance(item, slice):
            return self.sl(item)
        return self.db[item]

    def sl(self, key):
        """Yield the samples selected by the slice ``key`` one at a time."""
        start, stop, step = key.indices(len(self))
        for i in range(start, stop, step):
            yield self.db[i]

    def __len__(self):
        return self.db.__len__()

    def __iter__(self):
        return self.db.__iter__()

    def next(self):
        """Python 2 iterator protocol; advances the backend's cursor."""
        return self.db.next()

    def __next__(self):
        """Python 3 counterpart of ``next``."""
        return self.next()

    @property
    def all_samples(self):
        """All samples of the backend."""
        return self.db.all_samples
if __name__ == "__main__":
    # Manual smoke check: open the ./test path as a database.
    sd = Database(path='./test')
DB基类与三种衍生
class TrainDBBase(object):
    """Base class of the sample databases (meant to be immutable once written).

    Subclasses implement ``write``, ``get_by_sid``, ``__getitem__`` and
    ``__len__``; iteration and ``all_samples`` are built on top of those.
    The object is its own iterator, so it keeps a single cursor ``n``.
    """

    # Iteration cursor. The class-level default lets ``next()`` work even when
    # it is called before ``__iter__``.
    n = 0

    def write(self, samples):
        """Save samples."""
        raise NotImplementedError()

    def get_by_sid(self, sid):
        """Get a sample by its sid."""
        raise NotImplementedError()

    def __getitem__(self, item):
        """Get a sample by its index in the dataset."""
        raise NotImplementedError()

    def __len__(self):
        """Return the number of samples in this dataset."""
        raise NotImplementedError()

    def __iter__(self):
        """Reset the cursor and return ``self`` as the iterator."""
        self.n = 0
        return self

    def next(self):
        """Return the sample under the cursor and advance it (Python 2 protocol).

        Raises:
            StopIteration: when the cursor has reached the end.
        """
        # ``>=`` rather than ``==`` so a cursor that is already past the end
        # cannot run into out-of-range indexing.
        if self.n >= self.__len__():
            raise StopIteration
        n = self.n
        self.n += 1
        return self[n]

    def __next__(self):
        """Python 3 counterpart of ``next``."""
        return self.next()

    @property
    def all_samples(self):
        """Return all samples in this dataset as a list."""
        # Index-based on purpose: does not disturb the iteration cursor.
        return [self[i] for i in range(len(self))]
class FolderDB(TrainDBBase):
    """One sample per file; the database is a folder.

    Samples are addressed by file name (their sid). They can also be accessed
    by position once the register of sids has been loaded.
    """

    def __init__(self, folder, n_samples=None, read_only=True, load_now=False):
        """Remember the folder and optionally load the sid register right away.

        Args:
            folder: path of the folder holding the sample files.
            n_samples: when truthy, only the first ``n_samples`` sids are kept.
            read_only: accepted for signature parity; not used by this class.
            load_now: load the sid register immediately instead of lazily.
        """
        self.folder = folder
        self.compress = False
        self.n_samples = n_samples
        self.sids = None
        if load_now:
            self.load_register()

    def write(self, samples):
        """Write ``samples`` into the folder, one file per sample."""
        write_one_sample_per_file(samples, self.folder)

    def append(self, samples):
        """Add ``samples`` to the folder, one file per sample."""
        append_write_one_sample_per_file(samples, self.folder)

    def get_by_sid(self, sid):
        """Load and return the sample stored in the file named ``sid``."""
        file_path = path_join(self.folder, sid)
        # Context manager so the handle is closed right after parsing.
        with open(file_path) as f:
            return json.load(f)

    def __getitem__(self, index):
        """Return the sample at position ``index`` of the sid register."""
        self.load_register()
        sid = self.sids[index]
        return self.get_by_sid(sid)

    def __len__(self):
        """Return the number of registered sids."""
        self.load_register()
        return len(self.sids)

    def load_register(self):
        """Load the list of sids once; later calls are no-ops."""
        if self.sids is not None:
            return
        sids = load_register(self.folder)
        if self.n_samples:
            sids = sids[: self.n_samples]
        self.sids = sids
        assert len(self.sids) == len(set(self.sids)), 'exist duplicated sids'
class CFolderDB(FolderDB):
    """FolderDB variant whose sample files are written with ``compress=True``
    and read back with ``decrypt=True``."""

    def write(self, samples):
        """Write ``samples`` into the folder, one compressed file per sample."""
        write_one_sample_per_file(samples, self.folder, compress=True)

    def get_by_sid(self, sid):
        """Read back the sample stored in the file named ``sid``."""
        target = path_join(self.folder, sid)
        return json_load(path=target, mode='r', decrypt=True)
class OneFileDB(TrainDBBase):
    """A single JSON file used as a database.

    The file holds the whole list of samples. It is read lazily on first
    access unless ``samples`` is supplied or ``load_now`` is true.
    """

    def __init__(self, file_path, samples=None, n_samples=None, read_only=True, load_now=False):
        """Set up the database.

        Args:
            file_path: path of the JSON file.
            samples: optional in-memory samples used instead of reading the file.
            n_samples: when truthy, only the first ``n_samples`` samples are kept.
            read_only: accepted for signature parity; not used by this class.
            load_now: read the file immediately instead of lazily.
        """
        self.file_path = file_path
        self.sids = None
        self.samples = None
        self.compress = False
        self.sid_to_sample = None
        self.n_samples = n_samples
        if samples is not None:
            self.set_samples(samples)
        elif load_now:
            self.load()

    def write(self, samples):
        """Dump ``samples`` to ``file_path`` (in-memory state is left untouched)."""
        json_dump(
            obj_=samples, path=self.file_path,
            mode='w', encrypt=self.compress)

    def get_by_sid(self, sid):
        """Return the sample whose ``info.sid`` equals ``sid``."""
        self.load()
        return self.sid_to_sample[sid]

    def load(self):
        """Read the file once; later calls are no-ops."""
        if self.samples is not None:
            return
        samples = json_load(
            path=self.file_path, mode='r',
            decrypt=self.compress)
        self.set_samples(samples)

    def set_samples(self, samples):
        """Install ``samples`` and rebuild the sid list and the sid lookup."""
        # A truthy n_samples makes a smaller database, e.g. for testing.
        if self.n_samples:
            samples = samples[: self.n_samples]
        self.samples = samples
        self.sids = [s['info']['sid'] for s in self.samples]
        self.sid_to_sample = {s['info']['sid']: s for s in self.samples}

    def transfer_to_folderdb(self, path):
        """Write every sample to its own file under ``path`` and open the result.

        Returns:
            A ``Database`` opened on ``path``.
        """
        # Make sure the samples are in memory; a lazily created instance
        # would otherwise hand ``None`` to the writer.
        self.load()
        write_one_sample_per_file(
            answers=self.samples,
            folder=path,
            compress=self.compress)
        return Database(path=path)

    def __getitem__(self, item):
        """Return the sample(s) at ``item`` (index or slice of the list)."""
        self.load()
        return self.samples[item]

    def __len__(self):
        """Return the number of samples."""
        self.load()
        return len(self.samples)
Magic Tools
这种任务,最麻烦的就是 Python2 和 Python3 之间的兼容性,而兼容性最麻烦的地方又体现在编码上:Python2 的 unicode 类型即 Python3 的 str 类型,Python2 的 str 类型即 Python3 的 bytes 类型,于是有了下面这些用于统一处理编码的工具函数。
头文件及依赖
from __future__ import unicode_literals
from six import PY2, PY3
import logging
import os
import zlib
import numpy as np
from io import open
JSON_MODULE = None
JSON编码相关
try:
# if you have ujson, it will be faster
# but the calling method is different.
import ujson as json
JSON_MODULE = 'ujson'
except ImportError:
import json
JSON_MODULE = 'json'
# The name ``json`` may be bound to ujson, which provides no JSONEncoder
# class, so the encoder always derives from the standard-library one.
import json as _stdlib_json


class JsonBytesEncoder(_stdlib_json.JSONEncoder):
    """JSON encoder (for the stdlib ``json.dumps``) that also accepts bytes.

    Bytes values are turned into unicode text via ``convert_to_unicode``;
    every other unsupported type falls through to the default behaviour,
    which raises TypeError.
    """

    def default(self, obj):
        """Return a JSON-serializable stand-in for ``obj``."""
        # NOTE: np.ndarray could be supported here via obj.tolist() if needed.
        if isinstance(obj, bytes):
            return convert_to_unicode(obj)
        return _stdlib_json.JSONEncoder.default(self, obj)
def json_dumps(obj_, encrypt=False):
    """Serialize ``obj_`` to a JSON string with whichever json module is active.

    Args:
        obj_: the object to serialize.
        encrypt: when true, return the result of ``zlib_encrypt`` on the JSON
            string instead of the string itself.

    Returns:
        The JSON string, or the output of ``zlib_encrypt`` when ``encrypt`` is set.
    """
    if JSON_MODULE == 'json':
        _json_str = json.dumps(
            obj_, cls=JsonBytesEncoder)
    elif JSON_MODULE == 'ujson':
        # Compare the whole major component: taking only the first character
        # would read a version such as "10.0.0" as major version 1.
        major = int(str(json.__version__).split('.')[0])
        if major < 2:
            # standard ujson-1.35 for python2.7
            _json_str = json.dumps(obj_)
        else:  # standard ujson-3.0.0 for python3.6
            _json_str = json.dumps(
                obj_, reject_bytes=False)
    else:
        _json_str = json.dumps(obj_)
    if encrypt:
        return zlib_encrypt(_json_str)
    return _json_str
def json_dump(obj_, path=None, mode='w', stream=None, encrypt=False):
    """Serialize ``obj_`` with ``json_dumps`` and write the result out.

    The output goes to ``stream`` when one is given (it already carries its
    own path and mode); otherwise the file at ``path`` is opened with ``mode``.
    With ``encrypt`` set the payload is bytes, so the file mode is forced
    to ``'wb'``; plain text output keeps the caller's mode.
    """
    if encrypt:
        mode = 'wb'
    if stream is None:
        with open(path, mode) as out:
            out.write(json_dumps(obj_, encrypt))
        return
    stream.write(json_dumps(obj_, encrypt))
def json_loads(str_, decrypt=False):
    """Parse JSON text, first passing it through ``zlib_decrypt`` when
    ``decrypt`` is set. ``loads()`` is called the same way for every
    json module."""
    payload = zlib_decrypt(str_) if decrypt else str_
    return json.loads(payload)
def json_load(path, mode='r', decrypt=False):
    """Read the file at ``path`` and parse it with ``json_loads``.

    With ``decrypt`` set the content is bytes, so the file is opened
    with ``'rb'`` regardless of ``mode``.
    """
    read_mode = 'rb' if decrypt else mode
    with open(path, read_mode) as src:
        return json_loads(src.read(), decrypt)
def zlib_encrypt(data):
    """Return ``data`` as zlib-compressed bytes.

    Lists, dicts and tuples are first serialized with ``json_dumps``; anything
    else is converted to unicode text (py2 unicode / py3 str).
    """
    is_container = isinstance(data, (list, dict, tuple))
    text = json_dumps(data) if is_container else convert_to_unicode(data)
    # zlib only accepts bytes-like input
    raw = convert_to_bytes(text)
    return zlib.compress(raw)
def zlib_decrypt(str_):
    """Decompress zlib data and return the content as unicode text."""
    return convert_to_unicode(zlib.decompress(str_))
def path_join(*args):
    """Join path components into one unicode path.

    Every component is converted with ``convert_to_unicode`` and the parts are
    combined with ``os.path.join``. A separator is inserted only where one is
    missing: ``path_join('./test', 'sid')`` and ``path_join('./test/', 'sid')``
    both yield ``'./test/sid'``. Plain concatenation glued the file name onto
    the folder name whenever the folder had no trailing slash.

    Returns:
        The joined path, or an empty string when called without arguments.
    """
    parts = [convert_to_unicode(each) for each in args]
    if not parts:
        # os.path.join() requires at least one component.
        return ''
    return os.path.join(*parts)
def write_data(stream, text, encoding=ubantu的python2与python3的相关兼容更新问题