GPT-2 finetune 代码解读
Posted _刘文凯_
tags:
篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了 GPT-2 finetune 代码解读相关的知识,希望对你有一定的参考价值。
在学习使用 GPT-2 时获得了 finetune(微调)的代码,阅读之后加上了中文注释,在这里发布出来:
源码:
源码来自github:https://github.com/nshepperd/gpt-2
使用说明:https://medium.com/ai-innovation/beginners-guide-to-retrain-gpt-2-117m-to-generate-custom-text-content-8bb5363d8b7f
源码注释:
encode.py (不是encoder.py, 这两个不一样)
注:model_name 的默认值可改为 124M(原来的 117M 有问题,下面的代码已经改好),请根据自己实际下载的模型来设置。
#!/usr/bin/env python3
# Usage:
# PYTHONPATH=src ./encode.py <file|directory|glob> /path/to/output.npz
# PYTHONPATH=src ./train --dataset /path/to/output.npz
import argparse
import numpy as np
import encoder
from load_dataset import load_dataset
# Command-line interface for pre-encoding raw text into a token .npz archive
# that train.py can consume directly via --dataset.
parser = argparse.ArgumentParser(
    formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    description='Pre-encode text files into tokenized training set.',
)

# Which pretrained model's BPE vocabulary to tokenize with, and where it lives.
parser.add_argument(
    '--model_name',
    metavar='MODEL',
    type=str,
    default='124M',
    help='Pretrained model name',
)
parser.add_argument(
    '--models_dir',
    metavar='PATH',
    type=str,
    default='models',
    help='Path to models directory',
)

# How input files are grouped and decoded.
parser.add_argument(
    '--combine',
    metavar='CHARS',
    type=int,
    default=50000,
    help='Concatenate files with <|endoftext|> separator into chunks of this minimum size',
)
parser.add_argument(
    '--encoding',
    type=str,
    default='utf-8',
    help='Set the encoding for reading and writing files.',
)

# Positional arguments, in this order: input location, then output archive.
parser.add_argument(
    'in_text',
    metavar='PATH',
    type=str,
    help='Input file, directory, or glob pattern (utf-8 text).',
)
parser.add_argument(
    'out_npz',
    metavar='OUT.npz',
    type=str,
    help='Output file path',
)
def main():
    """Tokenize the input text with a GPT-2 BPE encoder and save it as .npz.

    The positional ``in_text`` argument may be a file, directory or glob;
    ``load_dataset`` reads it and returns the encoded chunks.  Each chunk is
    stored as a separate array in the compressed archive ``out_npz``
    (numpy names them arr_0, arr_1, ...).
    """
    opts = parser.parse_args()
    tokenizer = encoder.get_encoder(opts.model_name, models_dir=opts.models_dir)

    print('Reading files')
    # Per the original annotation, load_dataset inspects each input to decide
    # how to read it (raw text vs. pre-encoded data).
    token_chunks = load_dataset(
        tokenizer, opts.in_text, opts.combine, encoding=opts.encoding)

    print('Writing', opts.out_npz)
    # No further processing: the encoded chunks are just persisted as-is.
    np.savez_compressed(opts.out_npz, *token_chunks)


if __name__ == '__main__':
    main()
train.py
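注:参数加载部分中,--dataset 是必填参数,其余参数都可以省略;model_name 的默认值可改为 124M(原来的 117M 有问题,下面的代码已经改好)。另外,下面代码用到的 argparse、os、json、time、numpy(np)、tensorflow(tf)、tqdm 等模块的导入语句在摘录中被省略了,使用时需要自行补上。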
import model, sample, encoder
from load_dataset import load_dataset, Sampler
# Checkpoints (and the step counter) are written to CHECKPOINT_DIR/<run_name>,
# generated samples to SAMPLE_DIR/<run_name>.
CHECKPOINT_DIR = 'checkpoint'
SAMPLE_DIR = 'samples'

# Command-line interface. Only --dataset is required; every other option has a
# default, which ArgumentDefaultsHelpFormatter shows in --help.
parser = argparse.ArgumentParser(
    description='Fine-tune GPT-2 on your custom dataset.',
    formatter_class=argparse.ArgumentDefaultsHelpFormatter)

parser.add_argument('--dataset', metavar='PATH', type=str, required=True, help='Input file, directory, or glob pattern (utf-8 text, or preencoded .npz files).')
parser.add_argument('--model_name', metavar='MODEL', type=str, default='124M', help='Pretrained model name')
parser.add_argument('--models_dir', metavar='PATH', type=str, default='models', help='Path to models directory')
parser.add_argument('--combine', metavar='CHARS', type=int, default=50000, help='Concatenate input files with <|endoftext|> separator into chunks of this minimum size')
parser.add_argument('--encoding', type=str, default='utf-8', help='Set the encoding for reading and writing files.')

# Optimization settings.
parser.add_argument('--batch_size', metavar='SIZE', type=int, default=1, help='Batch size')
parser.add_argument('--learning_rate', metavar='LR', type=float, default=0.00002, help='Learning rate for Adam')
parser.add_argument('--accumulate_gradients', metavar='N', type=int, default=1, help='Accumulate gradients across N minibatches.')
parser.add_argument('--memory_saving_gradients', default=False, action='store_true', help='Use gradient checkpointing to reduce vram usage.')
parser.add_argument('--twremat', default=False, action='store_true', help='Use tensor rematerialization (better than memory_saving_gradients and works with tensorflow 2.0).')
parser.add_argument('--twremat_memlimit', type=str, default='12G', help='Memory usage limit/target for twremat. Can be an integer, or an integer suffixed with K/M/G for kilo/mega/giga-bytes.')
parser.add_argument('--only_train_transformer_layers', default=False, action='store_true', help='Restrict training to the transformer blocks.')
# Only 'adam' and 'sgd' are supported; `choices` rejects anything else at
# parse time with a usage message instead of failing later inside main().
parser.add_argument('--optimizer', type=str, default='adam', choices=('adam', 'sgd'), help='Optimizer. <adam|sgd>.')
parser.add_argument('--noise', type=float, default=0.0, help='Add noise to input training data to regularize against typos.')

# Sampling settings.
parser.add_argument('--top_k', type=int, default=40, help='K for top-k sampling.')
parser.add_argument('--top_p', type=float, default=0.0, help='P for top-p sampling. Overrides top_k if set > 0.')

# Run bookkeeping: restore point, run name, and save/sample cadence.
parser.add_argument('--restore_from', type=str, default='latest', help='Either "latest", "fresh", or a path to a checkpoint file')
parser.add_argument('--run_name', type=str, default='run1', help='Run id. Name of subdirectory in checkpoint/ and samples/')
parser.add_argument('--sample_every', metavar='N', type=int, default=100, help='Generate samples every N steps')
parser.add_argument('--sample_length', metavar='TOKENS', type=int, default=1023, help='Sample this many tokens')
parser.add_argument('--sample_num', metavar='N', type=int, default=1, help='Generate this many samples')
parser.add_argument('--save_every', metavar='N', type=int, default=1000, help='Write a checkpoint every N steps')

# Validation settings (validation is disabled while --val_every is 0).
parser.add_argument('--val_dataset', metavar='PATH', type=str, default=None, help='Dataset for validation loss, defaults to --dataset.')
parser.add_argument('--val_batch_size', metavar='SIZE', type=int, default=2, help='Batch size for validation.')
parser.add_argument('--val_batch_count', metavar='N', type=int, default=40, help='Number of batches for validation.')
parser.add_argument('--val_every', metavar='STEPS', type=int, default=0, help='Calculate validation loss every STEPS steps.')
两个小函数
def maketree(path):
    """Create directory *path*, including missing parents, if it does not exist.

    Calling it on an existing directory is a no-op, so it is safe to call
    before every checkpoint/sample write.
    """
    # exist_ok=True replaces the old bare `except: pass`, which also hid real
    # failures (e.g. permission errors) and even swallowed KeyboardInterrupt.
    os.makedirs(path, exist_ok=True)
def randomize(context, hparams, p):
    """Replace each token of *context* by a random token id with probability *p*.

    This implements the --noise option (input noise to regularize against
    typos). When ``p`` is not positive, *context* is returned unchanged and
    no ops are added to the graph.

    Args:
        context: integer token-id tensor; it must be int32 for ``tf.where`` to
            combine it with the int32 replacement ids.
        hparams: model hyperparameters; only ``n_vocab`` is read, as the
            exclusive upper bound of the replacement ids.
        p: per-token replacement probability.
    """
    if not p > 0:
        return context
    # Boolean mask: True where a uniform [0, 1) draw falls below p.
    replace = tf.random.uniform(shape=tf.shape(context)) < p
    # Candidate replacement ids drawn uniformly from [0, n_vocab).
    random_tokens = tf.random.uniform(shape=tf.shape(context), minval=0, maxval=hparams.n_vocab, dtype=tf.int32)
    return tf.where(replace, random_tokens, context)
下面是main函数,代码较多,我分几个部分进行展示:
def main():
    """Fine-tune a pretrained GPT-2 model on a custom dataset.

    Parses the command line and builds the training graph: the training loss,
    an optional validation loss and a sampling graph. It then restores weights
    according to ``--restore_from`` and trains in an endless loop:

    * every ``--save_every`` steps a checkpoint plus the step counter is
      written to ``CHECKPOINT_DIR/<run_name>``;
    * every ``--sample_every`` steps samples are generated, printed and
      written to ``SAMPLE_DIR/<run_name>/samples-<step>``;
    * if ``--val_every`` > 0, the validation loss is computed on step 1 and
      every ``--val_every`` steps.

    The loop only ends on Ctrl+C (KeyboardInterrupt). A final checkpoint is
    saved before returning.

    NOTE(review): ``--accumulate_gradients`` is parsed but not implemented in
    this version of the script, so it has no effect here.

    Raises:
        ValueError: if ``--sample_length`` exceeds the model's context size,
            or if no checkpoint can be found to restore from.
    """
    args = parser.parse_args()
    enc = encoder.get_encoder(args.model_name, models_dir=args.models_dir)
    # Start from the default hyperparameters, then override them with the
    # pretrained model's own hparams.json so the graph matches the checkpoint.
    hparams = model.default_hparams()
    # Use --models_dir (as get_encoder does above) rather than a hard-coded
    # 'models' directory, so a custom models location works consistently.
    with open(os.path.join(args.models_dir, args.model_name, 'hparams.json')) as f:
        hparams.override_from_dict(json.load(f))

    # Generated samples cannot be longer than the model's context window.
    if args.sample_length > hparams.n_ctx:
        raise ValueError(
            "Can't get samples longer than window size: %s" % hparams.n_ctx)

    # Every training/validation example is one full context window.
    n_ctx = hparams.n_ctx

    # Key part: build the graph used to continue training the model.
    with tf.Session() as sess:
        # Fully static shape required to make memory accounting in
        # twremat accurate.
        train_context = tf.placeholder(tf.int32, [args.batch_size, n_ctx])
        # Optionally corrupt some input tokens (--noise); the labels below
        # still come from the clean train_context.
        train_context_in = randomize(train_context, hparams, args.noise)
        train_output = model.model(hparams=hparams, X=train_context_in)
        # Next-token prediction: the logits at position t are scored against
        # the token at position t + 1.
        train_loss = tf.reduce_mean(
            tf.nn.sparse_softmax_cross_entropy_with_logits(
                labels=train_context[:, 1:], logits=train_output['logits'][:, :-1]))

        # Validation graph (only built when validation is enabled).
        if args.val_every > 0:
            val_context = tf.placeholder(tf.int32, [args.val_batch_size, None])
            val_output = model.model(hparams=hparams, X=val_context)
            val_loss = tf.reduce_mean(
                tf.nn.sparse_softmax_cross_entropy_with_logits(
                    labels=val_context[:, 1:], logits=val_output['logits'][:, :-1]))
            val_loss_summary = tf.summary.scalar('val_loss', val_loss)

        # Sampling graph used to print example generations during training.
        sample_context = tf.placeholder(tf.int32, [args.batch_size, None])
        tf_sample = sample.sample_sequence(
            hparams=hparams,
            length=args.sample_length,
            context=sample_context,
            batch_size=args.batch_size,
            temperature=1.0,
            top_k=args.top_k,
            top_p=args.top_p)

        # Trainable variables whose name contains 'model'; these are also the
        # variables the Saver below saves and restores.
        all_vars = [v for v in tf.trainable_variables() if 'model' in v.name]
        # With --only_train_transformer_layers, update only the variables whose
        # name contains '/h' (the transformer blocks, per the flag's help).
        train_vars = [v for v in all_vars if '/h' in v.name] if args.only_train_transformer_layers else all_vars

        # The optimizer selected by --optimizer, at --learning_rate.
        if args.optimizer == 'adam':
            opt = tf.train.AdamOptimizer(learning_rate=args.learning_rate)
        elif args.optimizer == 'sgd':
            opt = tf.train.GradientDescentOptimizer(learning_rate=args.learning_rate)
        else:
            exit('Bad optimizer: %s' % args.optimizer)

        # Choose how gradients of train_loss w.r.t. train_vars are computed.
        if args.memory_saving_gradients:
            if tf.VERSION >= '2':
                exit('Memory saving gradients are not supported in tensorflow 2.x')
            import memory_saving_gradients
            opt_grads = memory_saving_gradients.gradients(train_loss, train_vars)
        elif args.twremat:
            import tfremat
            opt_grads = tf.gradients(train_loss, train_vars)
            (train_loss, opt_grads) = tfremat.tf_remat((train_loss, opt_grads), memlimit=args.twremat_memlimit)
        else:
            opt_grads = tf.gradients(train_loss, train_vars)
        # apply_gradients expects (gradient, variable) pairs.
        opt_grads = list(zip(opt_grads, train_vars))
        opt_apply = opt.apply_gradients(opt_grads)

        # Scalar summaries for the training loss, written to the run's
        # checkpoint directory.
        summary_loss = tf.summary.scalar('loss', train_loss)
        summaries = tf.summary.merge([summary_loss])
        summary_log = tf.summary.FileWriter(
            os.path.join(CHECKPOINT_DIR, args.run_name))

        saver = tf.train.Saver(
            var_list=all_vars,
            max_to_keep=5,
            keep_checkpoint_every_n_hours=2)
        sess.run(tf.global_variables_initializer())

        # Decide which checkpoint to restore model weights from.
        if args.restore_from == 'latest':
            ckpt = tf.train.latest_checkpoint(
                os.path.join(CHECKPOINT_DIR, args.run_name))
            if ckpt is None:
                # Get fresh GPT weights if new run.
                ckpt = tf.train.latest_checkpoint(
                    os.path.join(args.models_dir, args.model_name))
        elif args.restore_from == 'fresh':
            # Start from the pretrained model.
            ckpt = tf.train.latest_checkpoint(
                os.path.join(args.models_dir, args.model_name))
        elif os.path.isdir(args.restore_from):
            # A checkpoint directory: use its most recent checkpoint.
            ckpt = tf.train.latest_checkpoint(args.restore_from)
        else:
            # The help text promises a path to a checkpoint, which
            # latest_checkpoint (which expects a directory) cannot resolve.
            ckpt = args.restore_from
        if ckpt is None:
            raise ValueError(
                'No checkpoint found to restore from (--restore_from=%s)' % args.restore_from)
        print('Loading checkpoint', ckpt)
        saver.restore(sess, ckpt)

        print('Loading dataset...')
        chunks = load_dataset(enc, args.dataset, args.combine, encoding=args.encoding)
        data_sampler = Sampler(chunks)
        if args.val_every > 0:
            if args.val_dataset:
                val_chunks = load_dataset(enc, args.val_dataset, args.combine, encoding=args.encoding)
            else:
                val_chunks = chunks
        print('dataset has', data_sampler.total_size, 'tokens')
        print('Training...')

        if args.val_every > 0:
            # Sample from validation set once with fixed seed to make
            # it deterministic during training as well as across runs.
            val_data_sampler = Sampler(val_chunks, seed=1)
            val_batches = [[val_data_sampler.sample(n_ctx) for _ in range(args.val_batch_size)]
                           for _ in range(args.val_batch_count)]

        # Step counter, persisted next to the checkpoints.
        counter = 1
        counter_path = os.path.join(CHECKPOINT_DIR, args.run_name, 'counter')
        if os.path.exists(counter_path):
            # Load the step number if we're resuming a run
            # Add 1 so we don't immediately try to save again
            with open(counter_path, 'r') as fp:
                counter = int(fp.read()) + 1

        def save():
            """Write a checkpoint and the current step counter."""
            maketree(os.path.join(CHECKPOINT_DIR, args.run_name))
            print(
                'Saving',
                os.path.join(CHECKPOINT_DIR, args.run_name,
                             'model-{}').format(counter))
            saver.save(
                sess,
                os.path.join(CHECKPOINT_DIR, args.run_name, 'model'),
                global_step=counter)
            with open(counter_path, 'w') as fp:
                fp.write(str(counter) + '\n')

        def generate_samples():
            """Generate --sample_num samples, print them and save them to a file."""
            print('Generating samples...')
            # Condition every sample on a single token drawn from the dataset.
            context_tokens = data_sampler.sample(1)
            all_text = []
            index = 0
            while index < args.sample_num:
                out = sess.run(
                    tf_sample,
                    feed_dict={sample_context: args.batch_size * [context_tokens]})
                for i in range(min(args.sample_num - index, args.batch_size)):
                    text = enc.decode(out[i])
                    text = '======== SAMPLE {} ========\n{}\n'.format(
                        index + 1, text)
                    all_text.append(text)
                    index += 1
                    print(text)
            maketree(os.path.join(SAMPLE_DIR, args.run_name))
            with open(
                    os.path.join(SAMPLE_DIR, args.run_name,
                                 'samples-{}').format(counter), 'w', encoding=args.encoding) as fp:
                fp.write('\n'.join(all_text))

        def validation():
            """Compute, log and print the mean loss over the fixed validation batches."""
            print('Calculating validation loss...')
            losses = []
            for batch in tqdm.tqdm(val_batches):
                losses.append(sess.run(val_loss, feed_dict={val_context: batch}))
            v_val_loss = np.mean(losses)
            # Feed the averaged value through the summary op to log it.
            v_summary = sess.run(val_loss_summary, feed_dict={val_loss: v_val_loss})
            summary_log.add_summary(v_summary, counter)
            summary_log.flush()
            print(
                '[{counter} | {time:2.2f}] validation loss = {loss:2.2f}'
                .format(
                    counter=counter,
                    time=time.time() - start_time,
                    loss=v_val_loss))

        def sample_batch():
            """Draw one training batch of full-context-window token sequences."""
            return [data_sampler.sample(n_ctx) for _ in range(args.batch_size)]

        # Training loop.
        # (numerator, denominator) of an exponentially weighted moving average
        # of the loss with decay 0.99; their ratio is the reported average.
        avg_loss = (0.0, 0.0)
        start_time = time.time()

        try:
            while True:
                if counter % args.save_every == 0:
                    save()
                if counter % args.sample_every == 0:
                    generate_samples()
                if args.val_every > 0 and (counter % args.val_every == 0 or counter == 1):
                    validation()

                (_, v_loss, v_summary) = sess.run(
                    (opt_apply, train_loss, summaries),
                    feed_dict={train_context: sample_batch()})

                summary_log.add_summary(v_summary, counter)

                avg_loss = (avg_loss[0] * 0.99 + v_loss,
                            avg_loss[1] * 0.99 + 1.0)

                print(
                    '[{counter} | {time:2.2f}] loss={loss:2.2f} avg={avg:2.2f}'
                    .format(
                        counter=counter,
                        time=time.time() - start_time,
                        loss=v_loss,
                        avg=avg_loss[0] / avg_loss[1]))

                counter += 1
        except KeyboardInterrupt:
            # Ctrl+C is the only way to stop training; keep the progress.
            print('interrupted')
            save()
注意:训练循环会一直运行下去,只有按下 Ctrl+C 才会结束;中断时程序会先保存一次 checkpoint 再退出。
代码版权:
@article{radford2019language,
  title={Language Models are Unsupervised Multitask Learners},
  author={Radford, Alec and Wu, Jeff and Child, Rewon and Luan, David and Amodei, Dario and Sutskever, Ilya},
  year={2019}
}
以上是关于 GPT-2 finetune 代码解读的主要内容,如果未能解决你的问题,请参考以下文章
SpringMVC源码解读--HandlerMapping代码解读