gpt-2 finetune 代码解读

Posted _刘文凯_

tags:

篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了gpt-2 finetune 代码解读相关的知识,希望对你有一定的参考价值。

在学习使用 GPT-2 的过程中获得了 finetune 的代码,阅读之后加上了中文注释,在这里发布出来:

源码:

源码来自github:https://github.com/nshepperd/gpt-2
使用说明:https://medium.com/ai-innovation/beginners-guide-to-retrain-gpt-2-117m-to-generate-custom-text-content-8bb5363d8b7f

源码注释:

encode.py(注意不是 encoder.py,这两个文件不一样)

注:model_name 的默认值可以改成 124M(原来的 117M 有问题;下面的代码中已经是 124M),请根据自己下载的模型来设置。

#!/usr/bin/env python3
# Usage:
#  PYTHONPATH=src ./encode.py <file|directory|glob> /path/to/output.npz
#  PYTHONPATH=src ./train --dataset /path/to/output.npz

import argparse
import numpy as np

import encoder
from load_dataset import load_dataset

# Command-line interface of encode.py: where to find the BPE vocabulary
# (--model_name/--models_dir), how to read and chunk the input text, and the
# two positional paths (input text, output .npz).
parser = argparse.ArgumentParser(
    formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    description='Pre-encode text files into tokenized training set.',
)
parser.add_argument(
    '--model_name',
    type=str,
    default='124M',
    metavar='MODEL',
    help='Pretrained model name',
)
parser.add_argument(
    '--models_dir',
    type=str,
    default='models',
    metavar='PATH',
    help='Path to models directory',
)
parser.add_argument(
    '--combine',
    type=int,
    default=50000,
    metavar='CHARS',
    help='Concatenate files with <|endoftext|> separator into chunks of this minimum size',
)
parser.add_argument(
    '--encoding',
    type=str,
    default='utf-8',
    help='Set the encoding for reading and writing files.',
)
parser.add_argument(
    'in_text',
    type=str,
    metavar='PATH',
    help='Input file, directory, or glob pattern (utf-8 text).',
)
parser.add_argument(
    'out_npz',
    type=str,
    metavar='OUT.npz',
    help='Output file path',
)

def main():
    """Tokenise the input text(s) with the GPT-2 encoder and save them to an .npz file.

    Reads the paths and options from the module-level ``parser``, loads the
    encoder for ``--model_name`` from ``--models_dir``, turns the input into
    token chunks with ``load_dataset`` and writes them with
    ``np.savez_compressed`` (one array per chunk).
    """
    cli = parser.parse_args()
    bpe = encoder.get_encoder(cli.model_name, models_dir=cli.models_dir)

    print('Reading files')
    # load_dataset inspects the input path itself (file / directory / glob);
    # --combine controls the minimum chunk size when files are concatenated.
    token_chunks = load_dataset(bpe, cli.in_text, cli.combine, encoding=cli.encoding)

    print('Writing', cli.out_npz)
    # Each chunk is stored as its own positional array inside the archive.
    np.savez_compressed(cli.out_npz, *token_chunks)


if __name__ == '__main__':
    main()


train.py

参数加载:需要注意的是 --dataset 是必需参数,其它参数都可以省略;model_name 可以改成 124M(原来的 117M 有问题)。另外,下面的代码片段省略了部分 import(代码中用到了 argparse、os、json、time、numpy、tensorflow、tqdm),运行前需要补上。

import model, sample, encoder
from load_dataset import load_dataset, Sampler

# Root directories for checkpoints and generated samples; each run writes
# into a sub-directory named after --run_name.
CHECKPOINT_DIR = 'checkpoint'
SAMPLE_DIR = 'samples'


# Command-line interface of train.py.  Only --dataset is required.
parser = argparse.ArgumentParser(
    formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    description='Fine-tune GPT-2 on your custom dataset.',
)

# --- Data and model location ---
parser.add_argument(
    '--dataset',
    type=str,
    required=True,
    metavar='PATH',
    help='Input file, directory, or glob pattern (utf-8 text, or preencoded .npz files).',
)
parser.add_argument(
    '--model_name',
    type=str,
    default='124M',
    metavar='MODEL',
    help='Pretrained model name',
)
parser.add_argument(
    '--models_dir',
    type=str,
    default='models',
    metavar='PATH',
    help='Path to models directory',
)
parser.add_argument(
    '--combine',
    type=int,
    default=50000,
    metavar='CHARS',
    help='Concatenate input files with <|endoftext|> separator into chunks of this minimum size',
)
parser.add_argument(
    '--encoding',
    type=str,
    default='utf-8',
    help='Set the encoding for reading and writing files.',
)

# --- Optimisation ---
parser.add_argument(
    '--batch_size',
    type=int,
    default=1,
    metavar='SIZE',
    help='Batch size',
)
parser.add_argument(
    '--learning_rate',
    type=float,
    default=0.00002,
    metavar='LR',
    help='Learning rate for Adam',
)
parser.add_argument(
    '--accumulate_gradients',
    type=int,
    default=1,
    metavar='N',
    help='Accumulate gradients across N minibatches.',
)
parser.add_argument(
    '--memory_saving_gradients',
    action='store_true',
    default=False,
    help='Use gradient checkpointing to reduce vram usage.',
)
parser.add_argument(
    '--twremat',
    action='store_true',
    default=False,
    help='Use tensor rematerialization (better than memory_saving_gradients and works with tensorflow 2.0).',
)
parser.add_argument(
    '--twremat_memlimit',
    type=str,
    default='12G',
    help='Memory usage limit/target for twremat. Can be an integer, or an integer suffixed with K/M/G for kilo/mega/giga-bytes.',
)
parser.add_argument(
    '--only_train_transformer_layers',
    action='store_true',
    default=False,
    help='Restrict training to the transformer blocks.',
)
parser.add_argument(
    '--optimizer',
    type=str,
    default='adam',
    help='Optimizer. <adam|sgd>.',
)
parser.add_argument(
    '--noise',
    type=float,
    default=0.0,
    help='Add noise to input training data to regularize against typos.',
)

# --- Sampling ---
parser.add_argument(
    '--top_k',
    type=int,
    default=40,
    help='K for top-k sampling.',
)
parser.add_argument(
    '--top_p',
    type=float,
    default=0.0,
    help='P for top-p sampling. Overrides top_k if set > 0.',
)

# --- Run management ---
parser.add_argument(
    '--restore_from',
    type=str,
    default='latest',
    help='Either "latest", "fresh", or a path to a checkpoint file',
)
parser.add_argument(
    '--run_name',
    type=str,
    default='run1',
    help='Run id. Name of subdirectory in checkpoint/ and samples/',
)
parser.add_argument(
    '--sample_every',
    type=int,
    default=100,
    metavar='N',
    help='Generate samples every N steps',
)
parser.add_argument(
    '--sample_length',
    type=int,
    default=1023,
    metavar='TOKENS',
    help='Sample this many tokens',
)
parser.add_argument(
    '--sample_num',
    type=int,
    default=1,
    metavar='N',
    help='Generate this many samples',
)
parser.add_argument(
    '--save_every',
    type=int,
    default=1000,
    metavar='N',
    help='Write a checkpoint every N steps',
)

# --- Validation ---
parser.add_argument(
    '--val_dataset',
    type=str,
    default=None,
    metavar='PATH',
    help='Dataset for validation loss, defaults to --dataset.',
)
parser.add_argument(
    '--val_batch_size',
    type=int,
    default=2,
    metavar='SIZE',
    help='Batch size for validation.',
)
parser.add_argument(
    '--val_batch_count',
    type=int,
    default=40,
    metavar='N',
    help='Number of batches for validation.',
)
parser.add_argument(
    '--val_every',
    type=int,
    default=0,
    metavar='STEPS',
    help='Calculate validation loss every STEPS steps.',
)

两个小函数

def maketree(path):
    """Create directory ``path`` (including missing parents) if it does not exist.

    Calling it on an existing directory is a no-op.  Real failures, such as
    permission denied or ``path`` already existing as a regular file, raise
    ``OSError``.  Previously a bare ``except: pass`` hid them, so the error
    only surfaced later when writing into the directory.
    """
    os.makedirs(path, exist_ok=True)


def randomize(context, hparams, p):
    """Randomly corrupt input tokens as a regulariser (used for ``--noise``).

    Each position of ``context`` is independently replaced with probability
    ``p`` (a uniform [0, 1) draw compared against ``p``).  The replacement is
    a token id drawn uniformly from ``[0, hparams.n_vocab)``.  When ``p`` is
    not positive, ``context`` is returned unchanged.

    Args:
        context: token-id tensor; presumably int32, since ``tf.where`` must
            match the int32 replacement tensor (the caller feeds an int32
            placeholder).
        hparams: model hyper-parameters; only ``n_vocab`` is read.
        p: replacement probability.
    """
    if not p > 0:
        return context
    shape = tf.shape(context)
    replace = tf.random.uniform(shape=shape) < p
    random_tokens = tf.random.uniform(
        shape=shape, minval=0, maxval=hparams.n_vocab, dtype=tf.int32)
    return tf.where(replace, random_tokens, context)

下面是main函数,代码较多,我分几个部分进行展示:

def main():
    """Fine-tune a pretrained GPT-2 model on a custom dataset.

    Parses the command line and builds the TF1 graph: training loss, an
    optional validation loss and a sampling op.  It then restores weights
    from a checkpoint and trains in an endless loop:

    * every ``--save_every`` steps a checkpoint is written;
    * every ``--sample_every`` steps text samples are generated;
    * when ``--val_every > 0``, the validation loss is computed every
      ``--val_every`` steps and at step 1.

    The loop only stops on Ctrl+C, which triggers a final save.

    Raises:
        ValueError: if ``--sample_length`` exceeds the model's context size,
            or if no checkpoint can be found to restore from.
    """
    args = parser.parse_args()
    enc = encoder.get_encoder(args.model_name, models_dir=args.models_dir)
    # Start from the default hyper-parameters, then override them with the
    # ones shipped with the pretrained model.  --models_dir is honoured here
    # (it used to be hard-coded to 'models').
    pretrained_dir = os.path.join(args.models_dir, args.model_name)
    hparams = model.default_hparams()
    with open(os.path.join(pretrained_dir, 'hparams.json')) as f:
        hparams.override_from_dict(json.load(f))

    # Generated samples cannot be longer than the model's context window.
    if args.sample_length > hparams.n_ctx:
        raise ValueError(
            "Can't get samples longer than window size: %s" % hparams.n_ctx)

    # Key part: build the graph used to continue training.
    with tf.Session() as sess:
        # Fully static shape required to make memory accounting in
        # twremat accurate.
        train_context = tf.placeholder(tf.int32, [args.batch_size, 1024])
        # What the model actually sees: the batch, optionally noised.
        train_context_in = randomize(train_context, hparams, args.noise)
        train_output = model.model(hparams=hparams, X=train_context_in)
        # Next-token loss: logits at position t are scored against the
        # original (un-noised) token at position t + 1.
        train_loss = tf.reduce_mean(
            tf.nn.sparse_softmax_cross_entropy_with_logits(
                labels=train_context[:, 1:], logits=train_output['logits'][:, :-1]))
        if args.val_every > 0:
            # Separate validation graph, fed without noise.
            val_context = tf.placeholder(tf.int32, [args.val_batch_size, None])
            val_output = model.model(hparams=hparams, X=val_context)
            val_loss = tf.reduce_mean(
                tf.nn.sparse_softmax_cross_entropy_with_logits(
                    labels=val_context[:, 1:], logits=val_output['logits'][:, :-1]))
            val_loss_summary = tf.summary.scalar('val_loss', val_loss)

        # Sampling op used to generate example text during training.
        sample_context = tf.placeholder(tf.int32, [args.batch_size, None])
        tf_sample = sample.sample_sequence(
            hparams=hparams,
            length=args.sample_length,
            context=sample_context,
            batch_size=args.batch_size,
            temperature=1.0,
            top_k=args.top_k,
            top_p=args.top_p)

        # Variables that are saved/restored (name contains 'model'), and the
        # subset that is trained (only '/h' variables when
        # --only_train_transformer_layers is set).
        all_vars = [v for v in tf.trainable_variables() if 'model' in v.name]
        train_vars = [v for v in all_vars if '/h' in v.name] if args.only_train_transformer_layers else all_vars

        # Optimizer selected with --optimizer.
        if args.optimizer == 'adam':
            opt = tf.train.AdamOptimizer(learning_rate=args.learning_rate)
        elif args.optimizer == 'sgd':
            opt = tf.train.GradientDescentOptimizer(learning_rate=args.learning_rate)
        else:
            exit('Bad optimizer: %s' % args.optimizer)
        if args.accumulate_gradients > 1:
            # NOTE(review): gradient accumulation is not implemented in this
            # training loop; warn rather than silently ignore the flag.
            print('Warning: --accumulate_gradients is not supported here and is ignored')

        # Choose how the gradients of train_loss w.r.t. train_vars are computed.
        if args.memory_saving_gradients:
            if tf.VERSION >= '2':
                exit('Memory saving gradients are not supported in tensorflow 2.x')
            import memory_saving_gradients
            opt_grads = memory_saving_gradients.gradients(train_loss, train_vars)
        elif args.twremat:
            import tfremat
            opt_grads = tf.gradients(train_loss, train_vars)
            (train_loss, opt_grads) = tfremat.tf_remat((train_loss, opt_grads), memlimit=args.twremat_memlimit)
        else:
            opt_grads = tf.gradients(train_loss, train_vars)

        # Pair each gradient with its variable and build the update op.
        opt_grads = list(zip(opt_grads, train_vars))
        opt_apply = opt.apply_gradients(opt_grads)

        # TensorBoard scalars, written to checkpoint/<run_name>.
        summary_loss = tf.summary.scalar('loss', train_loss)
        summary_lr = tf.summary.scalar('learning_rate', args.learning_rate)
        summaries = tf.summary.merge([summary_lr, summary_loss])
        summary_log = tf.summary.FileWriter(
            os.path.join(CHECKPOINT_DIR, args.run_name))

        saver = tf.train.Saver(
            var_list=all_vars,
            max_to_keep=5,
            keep_checkpoint_every_n_hours=2)
        sess.run(tf.global_variables_initializer())

        # Pick the checkpoint to restore.
        if args.restore_from == 'latest':
            ckpt = tf.train.latest_checkpoint(
                os.path.join(CHECKPOINT_DIR, args.run_name))
            if ckpt is None:
                # Get fresh GPT weights if new run.
                ckpt = tf.train.latest_checkpoint(pretrained_dir)
        elif args.restore_from == 'fresh':
            # Ignore previous runs and start from the pretrained weights.
            ckpt = tf.train.latest_checkpoint(pretrained_dir)
        elif os.path.isdir(args.restore_from):
            ckpt = tf.train.latest_checkpoint(args.restore_from)
        else:
            # A checkpoint path/prefix given directly (as the --help text says).
            ckpt = args.restore_from
        if ckpt is None:
            raise ValueError(
                'No checkpoint found to restore from (--restore_from=%s)' % args.restore_from)

        print('Loading checkpoint', ckpt)
        saver.restore(sess, ckpt)

        print('Loading dataset...')
        chunks = load_dataset(enc, args.dataset, args.combine, encoding=args.encoding)
        data_sampler = Sampler(chunks)
        if args.val_every > 0:
            # Validation data defaults to the training data.
            if args.val_dataset:
                val_chunks = load_dataset(enc, args.val_dataset, args.combine, encoding=args.encoding)
            else:
                val_chunks = chunks
        print('dataset has', data_sampler.total_size, 'tokens')
        print('Training...')
        if args.val_every > 0:
            # Sample from validation set once with fixed seed to make
            # it deterministic during training as well as across runs.
            val_data_sampler = Sampler(val_chunks, seed=1)
            val_batches = [[val_data_sampler.sample(1024) for _ in range(args.val_batch_size)]
                           for _ in range(args.val_batch_count)]

        # Step counter, resumed from checkpoint/<run_name>/counter if present.
        counter = 1
        counter_path = os.path.join(CHECKPOINT_DIR, args.run_name, 'counter')
        if os.path.exists(counter_path):
            # Load the step number if we're resuming a run
            # Add 1 so we don't immediately try to save again
            with open(counter_path, 'r') as fp:
                counter = int(fp.read()) + 1

        def save():
            """Write a model checkpoint and the current step counter."""
            maketree(os.path.join(CHECKPOINT_DIR, args.run_name))
            print(
                'Saving',
                os.path.join(CHECKPOINT_DIR, args.run_name,
                             'model-{}').format(counter))
            saver.save(
                sess,
                os.path.join(CHECKPOINT_DIR, args.run_name, 'model'),
                global_step=counter)
            with open(counter_path, 'w') as fp:
                fp.write(str(counter) + '\n')

        def generate_samples():
            """Generate --sample_num samples from a 1-token prompt, print and save them."""
            print('Generating samples...')
            context_tokens = data_sampler.sample(1)
            all_text = []
            index = 0
            while index < args.sample_num:
                out = sess.run(
                    tf_sample,
                    feed_dict={sample_context: args.batch_size * [context_tokens]})
                for i in range(min(args.sample_num - index, args.batch_size)):
                    text = enc.decode(out[i])
                    text = '======== SAMPLE {} ========\n{}\n'.format(
                        index + 1, text)
                    all_text.append(text)
                    index += 1
            # Print every sample (previously only the last one was printed,
            # and --sample_num 0 raised UnboundLocalError).
            print('\n'.join(all_text))
            maketree(os.path.join(SAMPLE_DIR, args.run_name))
            with open(
                    os.path.join(SAMPLE_DIR, args.run_name,
                                 'samples-{}').format(counter), 'w', encoding=args.encoding) as fp:
                fp.write('\n'.join(all_text))

        def validation():
            """Average the loss over the fixed validation batches and log it."""
            print('Calculating validation loss...')
            losses = []
            for batch in tqdm.tqdm(val_batches):
                losses.append(sess.run(val_loss, feed_dict={val_context: batch}))
            v_val_loss = np.mean(losses)
            # Feed the averaged value back into val_loss to produce its summary.
            v_summary = sess.run(val_loss_summary, feed_dict={val_loss: v_val_loss})
            summary_log.add_summary(v_summary, counter)
            summary_log.flush()
            print(
                '[{counter} | {time:2.2f}] validation loss = {loss:2.2f}'
                .format(
                    counter=counter,
                    time=time.time() - start_time,
                    loss=v_val_loss))

        def sample_batch():
            """Draw one training batch of 1024-token windows."""
            return [data_sampler.sample(1024) for _ in range(args.batch_size)]

        # Training loop: runs until interrupted with Ctrl+C.
        avg_loss = (0.0, 0.0)
        start_time = time.time()

        try:
            while True:
                if counter % args.save_every == 0:
                    save()
                if counter % args.sample_every == 0:
                    generate_samples()
                if args.val_every > 0 and (counter % args.val_every == 0 or counter == 1):
                    validation()

                (_, v_loss, v_summary) = sess.run(
                    (opt_apply, train_loss, summaries),
                    feed_dict={train_context: sample_batch()})

                summary_log.add_summary(v_summary, counter)

                # Exponentially-decayed running average of the loss
                # (weighted sum / sum of weights, decay 0.99).
                avg_loss = (avg_loss[0] * 0.99 + v_loss,
                            avg_loss[1] * 0.99 + 1.0)

                print(
                    '[{counter} | {time:2.2f}] loss={loss:2.2f} avg={avg:2.2f}'
                    .format(
                        counter=counter,
                        time=time.time() - start_time,
                        loss=v_loss,
                        avg=avg_loss[0] / avg_loss[1]))

                counter += 1
        except KeyboardInterrupt:
            print('interrupted')
            save()

注意:训练循环不会自动结束,只能按下 Ctrl+C 来停止;中断时程序会自动保存一次模型(checkpoint)。



论文引用:
@article{radford2019language,
  title={Language Models are Unsupervised Multitask Learners},
  author={Radford, Alec and Wu, Jeff and Child, Rewon and Luan, David and Amodei, Dario and Sutskever, Ilya},
  year={2019}
}

以上是关于gpt-2 fintune 代码解读的主要内容,如果未能解决你的问题,请参考以下文章

SpringMVC源码解读--HandlerMapping代码解读

SpringMVC源码解读--HandlerMapping代码解读

DDIM代码详细解读:核心采样代码超分辨率重建

Transformer解析与tensorflow代码解读

AlexNet论文解读与代码实现

DCGAN 代码简单解读