Entire Space Multi-Task Model (ESMM): Reading Notes
Posted by 一杯敬朝阳一杯敬月光
Published in 2018; paper: https://arxiv.org/pdf/1804.07931.pdf
The paper also released its dataset: Dataset - 阿里云天池 (Aliyun Tianchi)
Alibaba's open-source code: x-deeplearning/xdl-algorithm-solution/ESMM/script at master · alibaba/x-deeplearning · GitHub
ABSTRACT
The abstract identifies two shortcomings of conventional CVR models:
- Sample selection bias (SSB): the model is trained with samples of clicked impressions, but at serving time it makes inference on the entire space of all impressions. Clicked impressions are only a part of that inference space, so the training and serving distributions differ, and the SSB problem hurts the generalization performance of the trained model.
- Data sparsity (DS): the CVR training space is only a subset of the CTR sample space, so the data gathered for training a CVR model is generally far less than for the CTR task. Moreover, features constructed only on this small space may be statistically unreliable because of the sparsity.
ESMM fully exploits the sequential pattern of user actions, impression → click → conversion, and addresses both shortcomings of conventional CVR models:
- It models CVR directly over the entire impression space.
- It employs a feature-representation transfer-learning strategy: the CTR network and the CVR network share feature representations.
INTRODUCTION
User actions follow the sequential pattern impression → click → conversion; CVR modeling refers to the task of estimating the post-click conversion rate.
ESMM introduces two auxiliary tasks: predicting the post-view click-through rate (CTR) and the post-view click-through & conversion rate (CTCVR). ESMM treats pCVR as an intermediate variable which, multiplied by pCTR, equals pCTCVR (i.e. pCTCVR = pCTR × pCVR). Both pCTCVR and pCTR are estimated over the entire space with samples of all impressions, so the derived pCVR also applies over the entire space, which eliminates the sample selection bias problem. The feature-representation parameters of the CVR network are shared with the CTR network, and the latter is trained with much richer samples; this kind of parameter transfer learning helps alleviate the DS problem remarkably. The full dataset consists of 8.9 billion samples with sequential labels of click and conversion.
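The structure described above, a shared embedding feeding a CTR tower and a CVR tower whose positive-class outputs are multiplied, can be sketched in plain NumPy. This is my own minimal illustration, not the paper's network; all layer sizes and names are invented:

```python
import numpy as np

rng = np.random.default_rng(0)

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

# Shared embedding table: both towers look up the same parameters, so CTR
# gradients (computed on all impressions) would also train the CVR inputs.
vocab, embed_dim, hidden = 100, 8, 16
embedding = rng.normal(size=(vocab, embed_dim))

def tower(x, w1, w2):
    h = np.maximum(x @ w1, 0.0)   # one hidden ReLU layer (illustrative)
    return sigmoid(h @ w2)        # scalar probability per sample

w1_ctr, w2_ctr = rng.normal(size=(embed_dim, hidden)), rng.normal(size=(hidden, 1))
w1_cvr, w2_cvr = rng.normal(size=(embed_dim, hidden)), rng.normal(size=(hidden, 1))

features = rng.integers(0, vocab, size=(4,))  # 4 toy impressions, 1 field each
x = embedding[features]                       # shared representation

pctr = tower(x, w1_ctr, w2_ctr)
pcvr = tower(x, w1_cvr, w2_cvr)   # intermediate variable
pctcvr = pctr * pcvr              # supervised over the entire impression space
```

Because the CVR tower's output is a sigmoid in (0, 1), the derived pCTCVR can never exceed pCTR.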
THE PROPOSED APPROACH
Samples are drawn from the impression space, denoted S = {(x_i, y_i → z_i)}, i = 1..N, where N is the number of impressions. x is the feature vector of an impression; these vectors are typically high-dimensional and sparse, composed of multiple fields (e.g. a user field and an item field). y and z are binary labels: y = 1 indicates a click, z = 1 indicates a conversion (purchase). The notation y → z reflects the dependency between click and conversion: a conversion is normally preceded by a click. Note: the released dataset likewise specifies that y = 0, z = 1 is an illegal state.
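As a quick sanity check on the label space: the only legal (click, conversion) pairs are (0, 0), (1, 0) and (1, 1). A minimal sketch (the helper name is my own):

```python
# Legal (y, z) label pairs in ESMM's sample space:
# y = click, z = conversion; a conversion without a click is illegal.
def is_legal(y, z):
    return not (y == 0 and z == 1)

assert is_legal(0, 0)       # no click, no conversion
assert is_legal(1, 0)       # click, no conversion
assert is_legal(1, 1)       # click and conversion
assert not is_legal(0, 1)   # conversion without click: illegal state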
Post-click CVR modeling estimates pCVR = p(z=1 | y=1, x). It involves two auxiliary probabilities: the post-view click-through rate pCTR = p(y=1 | x) and the post-view click & conversion rate pCTCVR = p(y=1, z=1 | x). They are related as follows:

pCTCVR = p(y=1, z=1 | x) = p(y=1 | x) × p(z=1 | y=1, x) = pCTR × pCVR    (Eq. 1)
From Eq. (1), both p(y=1 | x) and p(y=1, z=1 | x), i.e. pCTR and pCTCVR, are modeled over the entire impression space, and in this way pCVR can also be derived over the entire space. Why not simply train two separate models for pCTCVR and pCTR and obtain pCVR by division? Because the predicted pCTR is usually very small, and dividing by it causes numerical instability. (Also, separately trained models cannot guarantee pCTR ≥ pCTCVR, so the quotient may not be a valid probability.) ESMM avoids this with the multiplication form: pCVR is just an intermediate variable constrained by Eq. (1), while pCTR and pCTCVR are the main factors ESMM actually estimates over the entire space. The multiplication form enables the three associated, co-trained estimators to exploit the sequential pattern of the data and to exchange information during training. Besides, it ensures the estimated pCVR stays in the range [0, 1], whereas in the DIVISION method it might exceed 1.
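A toy numeric illustration of the point above (all numbers invented): two independently trained models can produce pCTCVR > pCTR, so division yields an invalid "probability", while the multiplication form cannot:

```python
# Toy illustration (numbers invented): DIVISION vs. ESMM's multiplication form.
pctr = 0.001             # output of a separately trained CTR model
pctcvr = 0.0015          # nothing forces pctcvr <= pctr for separate models

pcvr_division = pctcvr / pctr
assert pcvr_division > 1.0   # exceeds 1: not a valid probability

# ESMM parameterizes pCVR directly (a network output in [0, 1]) and derives
# pCTCVR = pCTR * pCVR, so the implied pCVR can never leave [0, 1].
pcvr = 0.8               # hypothetical CVR-tower output, guaranteed in [0, 1]
pctcvr_esmm = pctr * pcvr
assert pctcvr_esmm <= pctr   # multiplication preserves pCTCVR <= pCTR
```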
The loss function consists of the CTR and CTCVR terms only; it does not use a CVR loss. Both terms are calculated over samples of all impressions:

L(θ_cvr, θ_ctr) = Σ_{i=1..N} l(y_i, f(x_i; θ_ctr)) + Σ_{i=1..N} l(y_i & z_i, f(x_i; θ_ctr) × f(x_i; θ_cvr))    (Eq. 3)
where l(·) denotes the cross-entropy loss. In ESMM, the embedding dictionary of the CVR network is shared with that of the CTR network, following a feature-representation transfer-learning paradigm. Training samples of all impressions for the CTR task are far richer than those for the CVR task, so this parameter-sharing mechanism enables the CVR network in ESMM to learn from un-clicked impressions and greatly helps alleviate the data sparsity problem.
Part of Alibaba's open-source code:
The CTR network and the CVR network share the input:
indicator = mx.sym.BlockGrad(indicator)
din_ad = mx.sym.concat(*ad_embs)
din_user = mx.sym.concat(*user_embs)
din_user = mx.sym.take(din_user, indicator)
din = mx.sym.concat(din_user, din_ad)
############## ctr
act = 'prelu'
ctr_fc1 = fc('ctr_fc1', din, feature_size*embed_size, 200, act)
ctr_fc2 = fc('ctr_fc2', ctr_fc1, 200, 80, act)
ctr_out = fc('ctr_out', ctr_fc2, 80, 2, '')
############## cvr
cvr_fc1 = fc('cvr_fc1', din, feature_size*embed_size, 200, act)
cvr_fc2 = fc('cvr_fc2', cvr_fc1, 200, 80, act)
cvr_out = fc('cvr_out', cvr_fc2, 80, 2, '')
Labeled tasks: CTR and CTCVR.
The model's CTCVR output = CTR output × CVR output.
# ctr and cvr come directly from the corresponding network outputs
ctr_clk = mx.symbol.slice_axis(data=label, axis=1, begin=0, end=1)
ctr_label = mx.symbol.concat(*[1 - ctr_clk, ctr_clk], dim=1)
ctcvr_buy = mx.symbol.slice_axis(data=label, axis=1, begin=1, end=2)
ctcvr_label = mx.symbol.concat(*[1 - ctcvr_buy, ctcvr_buy], dim=1)
ctr_prop = mx.symbol.softmax(data=ctr_out, axis=1)
cvr_prop = mx.symbol.softmax(data=cvr_out, axis=1)
ctr_prop_one = mx.symbol.slice_axis(data=ctr_prop, axis=1, begin=1, end=2)
cvr_prop_one = mx.symbol.slice_axis(data=cvr_prop, axis=1, begin=1, end=2)
############## new version: multiplication form
ctcvr_prop_one = ctr_prop_one * cvr_prop_one
ctcvr_prop = mx.symbol.concat(*[1 - ctcvr_prop_one, ctcvr_prop_one], dim=1)
Loss function: CTR loss + CTCVR loss.
loss_r, ctr_loss and ctcvr_loss are all divided by batch_size: since CTR and CTCVR are modeled over the entire sample space, their losses (and their sum) are also computed over the entire space, hence the normalization by the full batch size.
loss_r = mx.symbol.MakeLoss(
    -mx.symbol.sum_axis(
        mx.symbol.log(ctr_prop) * ctr_label +
        mx.symbol.log(ctcvr_prop) * ctcvr_label,
        axis=1, keepdims=True),
    normalization="null",
    grad_scale=1.0 / batch_size,
    name="esmm_loss")
ctr_loss = -mx.symbol.sum(mx.symbol.log(ctr_prop) * ctr_label) / batch_size
ctcvr_loss = -mx.symbol.sum(mx.symbol.log(ctcvr_prop) * ctcvr_label) / batch_size
The CVR loss is not included in the model's loss function. By definition, CVR applies only to clicked samples, so the loss is multiplied by ctr_clk: only clicked samples contribute, un-clicked samples incur no loss, and the sum is divided by cnt_cvr_sample, the number of clicked samples. The label used is ctcvr_label, because in the dataset ctcvr = 1 only when ctr = 1 and cvr = 1 (otherwise ctcvr = 0); and since cvr = 1 implies ctr = 1 (anything else is illegal data), cvr = 1 is equivalent to ctcvr = 1.
cnt_cvr_sample = mx.symbol.sum_axis(ctr_clk)
cnt_ctcvr_sample = mx.symbol.sum_axis(ctcvr_buy)
cvr_loss = -mx.symbol.sum(
    mx.symbol.sum_axis(
        mx.symbol.log(cvr_prop) * ctcvr_label,
        axis=1, keepdims=True) * ctr_clk) / cnt_cvr_sample
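The losses above can be mirrored in plain NumPy to check that the ESMM training loss is exactly the CTR loss plus the CTCVR loss, and that the monitoring-only CVR loss is masked to clicked samples and averaged over their count. This is my own sketch with a made-up toy batch, not the Alibaba code:

```python
import numpy as np

# Toy batch: columns = [click, buy]; the pair (0, 1) never appears (illegal).
label = np.array([[1, 1], [1, 0], [0, 0], [1, 0]], dtype=float)
ctr_clk, ctcvr_buy = label[:, :1], label[:, 1:]
batch_size = label.shape[0]

# Pretend softmax outputs of the two towers: [P(class=0), P(class=1)].
ctr_prop = np.array([[0.2, 0.8], [0.4, 0.6], [0.9, 0.1], [0.3, 0.7]])
cvr_prop = np.array([[0.5, 0.5], [0.7, 0.3], [0.8, 0.2], [0.6, 0.4]])

ctr_label = np.concatenate([1 - ctr_clk, ctr_clk], axis=1)
ctcvr_label = np.concatenate([1 - ctcvr_buy, ctcvr_buy], axis=1)

# Multiplication form on the positive class: pCTCVR = pCTR * pCVR.
ctcvr_one = ctr_prop[:, 1:] * cvr_prop[:, 1:]
ctcvr_prop = np.concatenate([1 - ctcvr_one, ctcvr_one], axis=1)

# Cross-entropy losses, each normalized by the full batch size.
ctr_loss = -np.sum(np.log(ctr_prop) * ctr_label) / batch_size
ctcvr_loss = -np.sum(np.log(ctcvr_prop) * ctcvr_label) / batch_size
esmm_loss = -np.sum(np.log(ctr_prop) * ctr_label
                    + np.log(ctcvr_prop) * ctcvr_label) / batch_size
assert np.isclose(esmm_loss, ctr_loss + ctcvr_loss)

# CVR loss (monitoring only): masked by ctr_clk, averaged over clicked samples.
cnt_cvr_sample = ctr_clk.sum()
cvr_loss = -np.sum(np.sum(np.log(cvr_prop) * ctcvr_label,
                          axis=1, keepdims=True) * ctr_clk) / cnt_cvr_sample
```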