Java程序员学深度学习 DJL上手8 使用风格迁移学习
Posted 编程圈子
tags:
篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了Java程序员学深度学习 DJL上手8 使用风格迁移学习相关的知识,希望对你有一定的参考价值。
一、风格迁移学习简介
1. 风格迁移学习
风格迁移,英文名称:Transfer learning,是机器学习的一种,它是在有一定的额外数据和存在一个已有模型的前提下,来生成目标数据,典型应用如生成新的画作,2015年由Gatys等人发表了文章《A Neural Algorithm of Artistic Style》,首次使用深度学习进行艺术画风格学习。
2. BERT
BERT的全称为Bidirectional Encoder Representation from Transformers,是一个预训练的语言表征模型,主要用来作自然语言的分词。
3. DistilBERT
BERT的参数量巨大,运行中需要巨大的空间、消耗大量资源,而DistilBERT 则对Bert进行瘦身。
二、实现过程
1. 说明
这里使用亚马逊的评论数据集,商品品类是数码软件,包含10.2万条有效的评论。选择的预训练模型DistilBERT是一个轻量级的BERT模型 , 已经使用维基百科的超过一百分的文本语料库进行了训练。DistilBERT作为基本层添加到了分类模型用来输出结评论结果星级,星级范围是1-5.
评论数据将作为数据传入,而评分则作为标签。
亚马逊的评论示例:
2. 准备数据集
首先是要准备数据集,原始数据是TSV格式,这里使用CSVDataset来作为数据容器,使用Featurizer接口来对原始数据的行/列进行预处理,以实现特征提取。
final class BertFeaturizer implements CsvDataset.Featurizer
private final BertFullTokenizer tokenizer;
private final int maxLength; // the cut-off length
public BertFeaturizer(BertFullTokenizer tokenizer, int maxLength)
this.tokenizer = tokenizer;
this.maxLength = maxLength;
/** @inheritDoc */
@Override
public void featurize(DynamicBuffer buf, String input)
SimpleVocabulary vocab = tokenizer.getVocabulary();
// convert sentence to tokens (toLowerCase for uncased model)
List<String> tokens = tokenizer.tokenize(input.toLowerCase());
// 超出maxLength的进行截取
tokens = tokens.size() > maxLength ? tokens.subList(0, maxLength) : tokens;
// BERT embedding convention "[CLS] Your Sentence [SEP]"
buf.put(vocab.getIndex("[CLS]"));
tokens.forEach(token -> buf.put(vocab.getIndex(token)));
buf.put(vocab.getIndex("[SEP]"));
对于BERT模型,我们构造一个BertFeaturizer 对象,实现 CsvDataset.Featurizer 方法来进行特征提取。本示例里对数据进行简单的清理。
3. 把 BertFeaturizer 应用在数据集上
CsvDataset getDataset(int batchSize, BertFullTokenizer tokenizer, int maxLength, int limit)
String amazonReview =
"https://s3.amazonaws.com/amazon-reviews-pds/tsv/amazon_reviews_us_Digital_Software_v1_00.tsv.gz";
float paddingToken = tokenizer.getVocabulary().getIndex("[PAD]");
return CsvDataset.builder()
.optCsvUrl(amazonReview) // load from Url
.setCsvFormat(CSVFormat.TDF.withQuote(null).withHeader()) // Setting TSV loading format
.setSampling(batchSize, true) // make sample size and random access
.optLimit(limit)
.addFeature(new CsvDataset.Feature("review_body", new BertFeaturizer(tokenizer, maxLength)))
.addLabel(new CsvDataset.Feature("star_rating", (buf, data) -> buf.put(Float.parseFloat(data) - 1.0f)))
.optDataBatchifier(
PaddingStackBatchifier.builder()
.optIncludeValidLengths(false)
.addPad(0, 0, (m) -> m.ones(new Shape(1)).mul(paddingToken))
.build()) // define how to pad dataset to a fix length
.build();
在列上应用上面定义的 BertFeaturizer,评分作为标签集。另外一句提取的词比我们的定义长度小的时候,还定义了数据填充方法。
4. 构造模型
先下载DistiledBERT模型,再下载预训练的权重。下载的模型没有包含分类层,我们还需要在构造模型的最后加上分类层然后再训练。对块完成修改后,使用.criteria loadModel setBlock
把模型。
2.4.1 加载模型
// MXNet base model
String modelUrls = "https://resources.djl.ai/test-models/distilbert.zip";
if ("PyTorch".equals(Engine.getInstance().getEngineName()))
modelUrls = "https://resources.djl.ai/test-models/traced_distilbert_wikipedia_uncased.zip";
Criteria<NDList, NDList> criteria = Criteria.builder()
.optApplication(Application.NLP.WORD_EMBEDDING)
.setTypes(NDList.class, NDList.class)
.optModelUrls(modelUrls)
.optProgress(new ProgressBar())
.build();
ZooModel<NDList, NDList> embedding = criteria.loadModel();
2.4.2 创建分类层
这里创建一个简单的MLP层用来对评论级别分类,最后一个全连接层输出5个数值,用来对应评价的5个级别。
层的最前面还会对内嵌文本进行处理。
之后把块加载到模型里。
Predictor<NDList, NDList> embedder = embedding.newPredictor();
Block classifier = new SequentialBlock()
// text embedding layer
.add(
ndList ->
NDArray data = ndList.singletonOrThrow();
NDList inputs = new NDList();
long batchSize = data.getShape().get(0);
float maxLength = data.getShape().get(1);
if ("PyTorch".equals(Engine.getInstance().getEngineName()))
inputs.add(data.toType(DataType.INT64, false));
inputs.add(data.getManager().full(data.getShape(), 1, DataType.INT64));
inputs.add(data.getManager().arange(maxLength)
.toType(DataType.INT64, false)
.broadcast(data.getShape()));
else
inputs.add(data);
inputs.add(data.getManager().full(new Shape(batchSize), maxLength));
// run embedding
try
return embedder.predict(inputs);
catch (TranslateException e)
throw new IllegalArgumentException("embedding error", e);
)
// classification layer
.add(Linear.builder().setUnits(768).build()) // pre classifier
.add(Activation::relu) // 激活函数
.add(Dropout.builder().optRate(0.2f).build())
.add(Linear.builder().setUnits(5).build()) // 5 star rating
.addSingleton(nd -> nd.get(":,0")); // Take [CLS] as the head
Model model = Model.newInstance("AmazonReviewRatingClassification");
model.setBlock(classifier);
5. 开始训练
2.5.1 创建训练集和测试集
首先建立一个单词表,把单词转到数字。然后把字母表喂给tokenizer特征提取器。
最后,要把数据集按比例进行拆分成训练集和测试集。
tokens长度最大设置为64,这意味着评论里只有64个特征分词会被用到。
// Prepare the vocabulary
SimpleVocabulary vocabulary = SimpleVocabulary.builder()
.optMinFrequency(1)
.addFromTextFile(embedding.getArtifact("vocab.txt"))
.optUnknownToken("[UNK]")
.build();
// Prepare dataset
int maxTokenLength = 64; // cutoff tokens length
int batchSize = 8;
int limit = Integer.MAX_VALUE;
// int limit = 512; // uncomment for quick testing
BertFullTokenizer tokenizer = new BertFullTokenizer(vocabulary, true);
CsvDataset amazonReviewDataset = getDataset(batchSize, tokenizer, maxTokenLength, limit);
// split data with 7:3 train:valid ratio
RandomAccessDataset[] datasets = amazonReviewDataset.randomSplit(7, 3);
RandomAccessDataset trainingSet = datasets[0];
RandomAccessDataset validationSet = datasets[1];
2.5.2 创建训练监听器跟踪训练过程
这里要注意设置的精确度、损失函数。训练日志会保存到 build/model1里。
SaveModelTrainingListener listener = new SaveModelTrainingListener("build/model");
listener.setSaveModelCallback(
trainer ->
TrainingResult result = trainer.getTrainingResult();
Model model = trainer.getModel();
// track for accuracy and loss
float accuracy = result.getValidateEvaluation("Accuracy");
model.setProperty("Accuracy", String.format("%.5f", accuracy));
model.setProperty("Loss", String.format("%.5f", result.getValidateLoss()));
);
DefaultTrainingConfig config = new DefaultTrainingConfig(Loss.softmaxCrossEntropyLoss()) // loss type
.addEvaluator(new Accuracy())
.optDevices(Device.getDevices(1)) // train using single GPU
.addTrainingListeners(TrainingListener.Defaults.logging("build/model"))
.addTrainingListeners(listener);
2.5.3 训练
int epoch = 2;
Trainer trainer = model.newTrainer(config);
trainer.setMetrics(new Metrics());
Shape encoderInputShape = new Shape(batchSize, maxTokenLength);
// initialize trainer with proper input shape
trainer.initialize(encoderInputShape);
EasyTrain.fit(trainer, epoch, trainingSet, validationSet);
System.out.println(trainer.getTrainingResult());
2.5.4 保存模型
model.save(Paths.get("build/model"), "amazon-review.param");
2.5.5 验证模型
从模型创建一个预测器,然后使用自己的数据进行训练,来验证模型效果。
class MyTranslator implements Translator<String, Classifications>
private BertFullTokenizer tokenizer;
private SimpleVocabulary vocab;
private List<String> ranks;
public MyTranslator(BertFullTokenizer tokenizer)
this.tokenizer = tokenizer;
vocab = tokenizer.getVocabulary();
ranks = Arrays.asList("1", "2", "3", "4", "5");
@Override
public Batchifier getBatchifier() return new StackBatchifier();
@Override
public NDList processInput(TranslatorContext ctx, String input)
List<String> tokens = tokenizer.tokenize(input);
float[] indices = new float[tokens.size() + 2];
indices[0] = vocab.getIndex("[CLS]");
for (int i = 0; i < tokens.size(); i++)
indices[i+1] = vocab.getIndex(tokens.get(i));
indices[indices.length - 1] = vocab.getIndex("[SEP]");
return new NDList(ctx.getNDManager().create(indices));
@Override
public Classifications processOutput(TranslatorContext ctx, NDList list)
return new Classifications(ranks, list.singletonOrThrow().softmax(0));
创建一个预测器:
String review = "It works great, but it takes too long to update itself and slows the system";
Predictor<String, Classifications> predictor = model.newPredictor(new MyTranslator(tokenizer));
System.out.println(predictor.predict(review));
三、源程序
PyTorchLearn
package com.xundh;
import ai.djl.Application;
import ai.djl.Device;
import ai.djl.MalformedModelException;
import ai.djl.Model;
import ai.djl.basicdataset.tabular.CsvDataset;
import ai.djl.engine.Engine;
import ai.djl.inference.Predictor;
import ai.djl.metric.Metrics;
import ai.djl.modality.nlp.SimpleVocabulary;
import ai.djl.modality.nlp.bert.BertFullTokenizer;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.types.DataType;
import ai.djl.ndarray.types.Shape;
import ai.djl.nn.Activation;
import ai.djl.nn.Block;
import ai.djl.nn.SequentialBlock;
import ai.djl.nn.core.Linear;
import ai.djl.nn.norm.Dropout;
import ai.djl.repository.zoo.Criteria;
import ai.djl.repository.zoo.ModelNotFoundException;
import ai.djl.repository.zoo.ZooModel;
import ai.djl.training.DefaultTrainingConfig;
import ai.djl.training.EasyTrain;
import ai.djl.training.Java程序员学深度学习 DJL上手9 在CIFAR-10数据集使用风格迁移学习
Java程序员学深度学习 DJL上手9 在CIFAR-10数据集使用风格迁移学习