热门标签 | HotTags
当前位置:  开发笔记 > 编程语言 > 正文

仅需24小时,带你基于PaddleRec复现经典CTR预估算法

项目背景偶然看到了【飞桨论文复现挑战赛】,抱着划水提升自己的态度,报名了一个推荐赛道的赛题。因为本身已经参加工作了,实际空闲时间不是太多,只能晚上下班或者周末和各位参赛大佬卷上一卷

a4e9e04a2f7fecf64d0e56f5638ff2e8.png

项目背景

偶然看到了【飞桨论文复现挑战赛】,抱着 划水 提升自己的态度,报名了一个推荐赛道的赛题。因为本身已经参加工作了,实际空闲时间不是太多,只能晚上下班或者周末和各位参赛大佬卷上一卷,划划水~

工欲善其事,必先利其器!在实际推荐算法开发工作中,一般也都有自己的开发项目框架,包含了「数据加载」「特征处理」「模型构建」等模块,可以快速完成一个新算法的开发,类似GitHub上开源的DeepCTR包。因此,首先找了一下飞桨的相关套件,所幸飞桨团队开源了PaddleRec飞桨推荐模型库。

工具有了,下面就是比拼对论文的理解了,所以论文复现赛一定要 熟读论文!熟读论文!熟读论文! 重要的事情说3遍.在此前提下,基于PaddleRec复现开发就十分方便了,甚至都不用题目里提到的24小时。下面我们根据复现挑战赛的93号题目DLRM复现进行介绍,主要包括以下几个部分:

  1. DLRM算法原理

  2. PaddleRec介绍

  3. 如何基于PaddleRec快速复现

  4. 项目总结

  5. 参考资料

DLRM算法原理

1.模型结构

DeepLearningRecommendationModelforPersonalizationandRecommendationSystems,DLRM是FaceBook于2019年提出的CTR预估算法,推荐或广告相关同学可以阅读一下原论文,也是非常经典的一篇。

论文链接https://arxiv.org/pdf/1906.00091v1.pdf,

除了DLRM模型本身的经典结构,FaceBook还对线上推断做了非常多的工程方面的优化,感兴趣的同学可以去找一下相关博客。

f76b3ba942715d54d69093dc373ad3a7.png

推荐rank模型网络结构一般较为简单,如上图DLRM的网络结构看着和DNN就没啥区别,主要由四个基础模块构成,EmbeddingsMatrixFactorizationFactorizationMachineMultilayerPerceptrons

DLRM模型的特征输入,主要包括dense数值型和sparse类别型两种特征。

densefeatures直接连接MLP(上图中的蓝色三角形),sparsefeatures经由Embedding层(上图红色模块)查找得到相应的embedding向量.Interactions层(上图云状模块)进行特征交叉,包括densefeatures和sparsefeatures的交叉以及sparsefeatures内部之间的交叉等,该部分与因子分解机FM有些类似。

DLRM模型中所有的sparsefeautres的embedding向量长度均是相等的,且densefeatures经由MLP也转化成相同的维度。这点是理解该模型代码的关键。

总结一下,DLRM模型的步骤如下:

  1. Densefeatures经过MLP(论文中称为bottom-MLP)处理为同样维度的向量;

  2. Sparsefeatures经由lookup获得统一维度的embedding向量(可选择每一个特征对应的embedding是否经过MLP处理);

  3. Densefeatures&sparsefeatures的向量两两之间进行dotproduct交叉;

  4. 交叉结果再和dense向量concat一起输入到顶层MLP(top-MLP);

  5. 经过sigmoid函数激活得到点击概率。

2.实验部分

不得不说,Facebook大佬发文章就NB,DLRM网络结构简单干净,没有任何调参,简简单单的SGD+lr=0.1就打败了DCN。原文所说,“DLRMvsDCNwithoutextensivetuningandnoregularizationisused.”太强了!

d8b38dd5a44bd9ba6a9d0557a00fb973.png

3.原论文repo

作者原论文开源代码是基于Pytorch实现的,https://github.com/facebookresearch/dlrm,代码逻辑可能有点儿复杂,参考本项目之后再去理解,可能会事半功倍。

4.数据集

原论文采用KaggleCriteo数据集,为常用的CTR预估任务基准数据集。单条样本包括13列densefeatures、26列sparsefeatures及label。

7515ff587aca523dd902673ba85c8ebb.png

本项目采用PaddleRec所提供的Criteo数据集进行复现。

PaddleRec介绍

PaddleRec涵盖了推荐系统的各个阶段,包括内容理解、匹配、召回、排序、多任务、重排序等,但这里我们只关注CTR预估,即排序阶段.该部分在models/rank/路径下,已经实现了deepfmdnnffmfm等经典CTR算法,每类算法包含静态图和动态图两种训练方式。我们一般选择动态图复现,因为和PyTorch及Tensorflow2等语法上更接近,调试也更方便。

我们在models/rank/路径下定义dataset加载和模型组网方式之后,便可以通过PaddleRec下tools类进行模型的训练及预测。一个简单的DNN算法训练和推断就是下面简单的两行命令:

# Step 1, 训练模型  
python -u tools/trainer.py -m models/rank/dnn/config.yaml

# Step 2, 预测推断  
python -u tools/infer.py -m models/rank/dnn/config.yaml

以上trainer.py和infer.py都是PaddleRec预先实现的训练类和预测类,我们不需要关心细节,只需关注数据加载及模型组网等就行,通过上述的配置文件config.yaml去调用我们实现的数据读取类和模型。

|--models|--rank|--dlrm                   # 本项目核心代码|--data                 # 采样小数据集|--config.yaml          # 采样小数据集模型配置|--config_bigdata.yaml  # Kaggle Criteo 全量数据集模型配置|--criteo_reader.py     # dataset加载类            |--dygraph_model.py     # PaddleRec 动态图模型训练类|--net.py               # dlrm 核心算法代码,包括 dlrm 组网等
|--tools                      # PaddleRec 工具类

总结一下,基于PaddleRecCTR模型快速复现只需要我们在models/rank/路径下,新建自己的模型文件夹,比如我这里的dlrm/.其中,最重要的三个是:

-config.yaml数据、特征、模型等配置

-xxxx_reader.py数据集加载方式

-net.py模型组网

因为DLRM复现要求的是Criteo数据集,甚至这个reader都不用自己去写,PaddleRec帮你做好了。更多关于PaddleRec的介绍,可以参考这里https://github.com/PaddlePaddle/PaddleRec

如何基于PaddleRec

快速复现

上文提到,基于PaddleRec快速复现的关键是net.py模型组网。这里介绍一下net.py代码:

下面实现MLP层,可以看到和PyTorch、Tensorflow2的语法非常接近,几乎可以无缝切换到PaddlePaddle。

官网API文档中有一张映射表,可以参考:PyTorch2PaddlePaddlehttps://www.paddlepaddle.org.cn/documentation/docs/zh/guides/08_api_mapping/pytorch_api_mapping_cn.html

class MLPLayer(nn.Layer):def __init__(self, input_shape, units_list=None, l2=0.01, last_action=None, **kwargs):super(MLPLayer, self).__init__(**kwargs)if units_list is None:units_list = [128, 128, 64]units_list = [input_shape] + units_listself.units_list = units_listself.l2 = l2self.mlp = []self.last_action = last_action
# 堆叠多层 dense 层for i, unit in enumerate(units_list[:-1]):if i != len(units_list) - 1:dense = paddle.nn.Linear(in_features=unit,out_features=units_list[i + 1],weight_attr=paddle.ParamAttr(initializer=paddle.nn.initializer.Normal(std=1.0 / math.sqrt(unit))))self.mlp.append(dense)
# ReLU激活函数relu = paddle.nn.ReLU()self.mlp.append(relu)# BatchNorm加速训练norm = paddle.nn.BatchNorm1D(units_list[i + 1])self.mlp.append(norm)else:dense = paddle.nn.Linear(in_features=unit,out_features=units_list[i + 1],weight_attr=paddle.nn.initializer.Normal(std=1.0 / math.sqrt(unit)))self.mlp.append(dense)if last_action is not None:relu = paddle.nn.ReLU()self.mlp.append(relu)def forward(self, inputs):outputs = inputsfor n_layer in self.mlp:outputs = n_layer(outputs)return outputs

下面是DLRM模型的核心组网,代码中有注释,结合第二部分算法原理很容易理解。

__init__初始化函数中,定义bottom-MLP模块处理数值型特征,定义Embedding层完成稀疏特征到Embedding向量的映射.定义top-MLP模块处理交叉特征的进一步泛化,得到CTR预测值.

forward中,对输入的densefeatures和sparsefeatures进行处理,分别得到的embedding向量拼接在一起.经过vector-wise特征交叉后,输入top-MLP得到预测值.

class DLRMLayer(nn.Layer):def __init__(self,dense_feature_dim,bot_layer_sizes,sparse_feature_number,sparse_feature_dim,top_layer_sizes,num_field,sync_mode=None):super(DLRMLayer, self).__init__()self.dense_feature_dim = dense_feature_dimself.bot_layer_sizes = bot_layer_sizesself.sparse_feature_number = sparse_feature_numberself.sparse_feature_dim = sparse_feature_dimself.top_layer_sizes = top_layer_sizesself.num_field = num_field# 定义 DLRM 模型的 Bot-MLP 层self.bot_mlp = MLPLayer(input_shape=dense_feature_dim,units_list=bot_layer_sizes,last_action="relu")# 定义 DLRM 模型的 Top-MLP 层self.top_mlp = MLPLayer(input_shape=int(num_field * (num_field + 1) / 2) + sparse_feature_dim,units_list=top_layer_sizes)# 定义 DLRM 模型的 Embedding 层self.embedding = paddle.nn.Embedding(num_embeddings=self.sparse_feature_number,embedding_dim=self.sparse_feature_dim,sparse=True,weight_attr=paddle.ParamAttr(name="SparseFeatFactors",initializer=paddle.nn.initializer.Uniform()))def forward(self, sparse_inputs, dense_inputs):# (batch_size, sparse_feature_dim)x = self.bot_mlp(dense_inputs)# interact dense and sparse featurebatch_size, d = x.shapesparse_embs = []for s_input in sparse_inputs:emb = self.embedding(s_input)emb = paddle.reshape(emb, shape=[-1, self.sparse_feature_dim])sparse_embs.append(emb)# 拼接数值型特征和 Embedding 特征T = paddle.reshape(paddle.concat(x=sparse_embs + [x], axis=1), (batch_size, -1, d))# 进行 vector-wise 特征交叉Z = paddle.bmm(T, paddle.transpose(T, perm=[0, 2, 1]))Zflat = paddle.triu(Z, 1) + paddle.tril(paddle.ones_like(Z) * MIN_FLOAT, 0)Zflat = paddle.reshape(paddle.masked_select(Zflat,paddle.greater_than(Zflat, paddle.ones_like(Zflat) * MIN_FLOAT)),(batch_size, -1))R = paddle.concat([x] + [Zflat], axis=1)# 交叉特征输入 Top-MLP 进行 CTR 预测y = self.top_mlp(R)return y

本项目DLRM代码已经提交PR,合入到PaddleRec套件中,可以从GitHub上clone代码.源码在PaddleRec/models/rank/dlrm路径中,参考readme.md运行代码。也可以在AIStudio的NoteBook上clone代码,直接上手跑跑看,步骤如下:

-Step1,gitclonecode

-Step2,downloaddata

-Step3,trainmodel&infer

################# Step 1, git clone code ################
# 当前处于 /home/aistudio 目录, 代码存放在 /home/work/rank/DLRM-Paddle 中import os
if not os.path.isdir('work/rank/DLRM-Paddle'):if not os.path.isdir('work/rank'):!mkdir work/rank# 国内访问或 git clone 较慢, 利用 hub.fastgit.org 加速!cd work/rank && git clone https://hub.fastgit.org/Andy1314Chen/DLRM-Paddle.git

################# Step 2, download data ################
# 当前处于 /home/aistudio 目录,数据存放在 /home/data/criteo 中import os
os.makedirs('data/criteo', exist_ok=True)# Download  data
if not os.path.exists('data/criteo/slot_test_data_full.tar.gz') or not os.path.exists('data/criteo/slot_train_data_full.tar.gz'):!cd data/criteo && wget https://paddlerec.bj.bcebos.com/datasets/criteo/slot_test_data_full.tar.gz!cd data/criteo && tar xzvf slot_test_data_full.tar.gz!cd data/criteo && wget https://paddlerec.bj.bcebos.com/datasets/criteo/slot_train_data_full.tar.gz!cd data/criteo && tar xzvf slot_train_data_full.tar.gz

################## Step 3, train model ##################
# 启动训练脚本 (需注意当前是否是 GPU 环境, 非 GPU 环境请修改 config_bigdata.yaml 配置中 use_gpu 为 False)
!cd work/rank/DLRM-Paddle && sh run.sh config_bigdata

项目总结

1.基于PaddleRec可以快速进行推荐算法的复现,让你更加专注模型的细节,提升复现效率。

2.PaddleRec提供了通用的训练/推理逻辑,如需增加一些特殊功能,例如,如何提高数据加载速度?如何在训练过程中设置easy_stopping?等。可以直接修改tools/trainer.py和tools/infer.py。

3.有了PaddleRec,论文复现更加强调熟读论文、读懂论文,知道创新点在哪里?核心参数是什么?

参考资料

1.DeepLearningRecommendationModelforPersonalizationandRecommendationSystems,https://arxiv.org/pdf/1906.00091v1.pdf

2.Facebook开源代码

https://github.com/facebookresearch/dlrm

3.PaddleRec

https://github.com/PaddlePaddle/PaddleRec

4.飞桨论文复现打卡营https://aistudio.baidu.com/aistudio/education/group/info/24681

b3610f7614dfb819df8ea0b0dad8fc40.png

关注公众号,获取更多技术内容~


推荐阅读
  • 本文提供了PyTorch框架中常用的预训练模型的下载链接及详细使用指南,涵盖ResNet、Inception、DenseNet、AlexNet、VGGNet等六大分类模型。每种模型的预训练参数均经过精心调优,适用于多种计算机视觉任务。文章不仅介绍了模型的下载方式,还详细说明了如何在实际项目中高效地加载和使用这些模型,为开发者提供全面的技术支持。 ... [详细]
  • 表面缺陷检测数据集综述及GitHub开源项目推荐
    本文综述了表面缺陷检测领域的数据集,并推荐了多个GitHub上的开源项目。通过对现有文献和数据集的系统整理,为研究人员提供了全面的资源参考,有助于推动该领域的发展和技术进步。 ... [详细]
  • 谷歌工程师:TensorFlow已重获新生;网友:我还是用PyTorch
    乾明发自凹非寺量子位报道|公众号QbitAI道友留步!TensorFlow已重获新生。在“PyTorch真香”的潮流中,有人站出来为TensorFlow说话了。这次来自谷歌的工程师 ... [详细]
  • [TensorFlow系列3]:初学者是选择Tensorflow2.x还是1.x? 2.x与1.x的主要区别?
    作者主页(文火冰糖的硅基工坊):https:blog.csdn.netHiWangWenBing本文网址:https:blog.csdn.netHiW ... [详细]
  • 通过整合JavaFX与Swing,我们成功地将现有的Swing应用程序组件进行了现代化改造。此次升级不仅提升了用户界面的美观性和交互性,还确保了与原有Swing应用程序的无缝集成,为开发高质量的Java桌面应用提供了坚实的基础。 ... [详细]
  • 使用 Python 中的 Matplotlib Axes 获取标签方法详解 ... [详细]
  • Android目录遍历工具 | AppCrawler自动化测试进阶(第二部分):个性化配置详解
    终于迎来了“足不出户也能为社会贡献力量”的时刻,但有追求的测试工程师绝不会让自己的生活变得乏味。与其在家消磨时光,不如利用这段时间深入研究和提升自己的技术能力,特别是对AppCrawler自动化测试工具的个性化配置进行详细探索。这不仅能够提高测试效率,还能为项目带来更多的价值。 ... [详细]
  • 本文深入探讨了 C# 中 `SqlCommand` 和 `SqlDataAdapter` 的核心差异及其应用场景。`SqlCommand` 主要用于执行单一的 SQL 命令,并通过 `DataReader` 获取结果,具有较高的执行效率,但灵活性较低。相比之下,`SqlDataAdapter` 则适用于复杂的数据操作,通过 `DataSet` 提供了更多的数据处理功能,如数据填充、更新和批量操作,更适合需要频繁数据交互的场景。 ... [详细]
  • 使用PyQt5与OpenCV实现电脑摄像头的图像捕捉功能
    本文介绍了如何使用Python中的PyQt5和OpenCV库来实现电脑摄像头的图像捕捉功能。通过结合这两个强大的工具,用户可以轻松地打开摄像头并进行实时图像采集和处理。代码示例展示了如何初始化摄像头、捕获图像并将其显示在PyQt5的图形界面中。此外,还提供了详细的步骤说明和代码注释,帮助开发者快速上手并实现相关功能。 ... [详细]
  • 如何在datetimebox中进行赋值与取值操作
    在 datetimebox 中进行赋值和取值操作时,可以通过以下方法实现:使用 `$('#j_dateStart').datebox('setValue', '指定日期')` 进行赋值,而通过 `$('#j_dateStart').datebox('getValue')` 获取当前选中的日期值。若需要清空日期值,可以使用 `$('#j_dateStart').datebox('clear')` 方法。这些操作能够确保日期控件的准确性和灵活性,适用于各种前端应用场景。 ... [详细]
  • REST API 时代落幕,GraphQL 持续引领未来
    尽管REST API已广泛使用多年,但在深入了解GraphQL及其解决的核心问题后,我深感其将引领未来的API设计趋势。GraphQL不仅提高了数据查询的效率,还增强了灵活性和性能,有望成为API开发的新标准。 ... [详细]
  • HTML5 Web存储技术是许多开发者青睐本地应用程序的重要原因之一,因为它能够实现在客户端本地存储数据。HTML5通过引入Web Storage API,使得Web应用程序能够在浏览器中高效地存储数据,从而提升了应用的性能和用户体验。相较于传统的Cookie机制,Web Storage不仅提供了更大的存储容量,还简化了数据管理和访问的方式。本文将从基础概念、关键技术到实际应用,全面解析HTML5 Web存储技术,帮助读者深入了解其工作原理和应用场景。 ... [详细]
  • 基于TensorFlow的鸢尾花数据集神经网络模型深度解析
    基于TensorFlow的鸢尾花数据集神经网络模型深度解析 ... [详细]
  • 深度学习分位数回归实现区间预测
    深度学习分位数回归实现区间预测 ... [详细]
  • 算法和数据结构是计算机科学中最基础和最重要的两个主题,在软件开发中无处不在。我坚信,对这两个主题的充分了解对于成为一名更好的程序员也很关键, ... [详细]
author-avatar
机敏的柑桔hs5
这个家伙很懒,什么也没留下!
PHP1.CN | 中国最专业的PHP中文社区 | DevBox开发工具箱 | json解析格式化 |PHP资讯 | PHP教程 | 数据库技术 | 服务器技术 | 前端开发技术 | PHP框架 | 开发工具 | 在线工具
Copyright © 1998 - 2020 PHP1.CN. All Rights Reserved | 京公网安备 11010802041100号 | 京ICP备19059560号-4 | PHP1.CN 第一PHP社区 版权所有