热门标签 | HotTags
当前位置:  开发笔记 > 编程语言 > 正文

TFRS之上下文特征

在推荐模型中,有很多因素会影响id之外的特性是否有用:上下文的重要性:如果用户的首选项跨上下文和时间相对稳定,那么上下文特性可能不会提供

在推荐模型中,有很多因素会影响id之外的特性是否有用:


  • 上下文的重要性:如果用户的首选项跨上下文和时间相对稳定,那么上下文特性可能不会提供太多好处。然而,如果用户偏好是高度上下文相关的,那么添加上下文将显著改善模型。例如,在决定是推荐一个短片还是一部电影时,一周的哪一天可能是一个重要的特征:用户在一周中可能只有时间看短内容,但可以在周末放松并享受一部完整长度的电影。类似地,查询时间戳可能在建模流行动态中扮演重要角色:一部电影可能在周围非常流行
  • 数据稀疏:如果数据稀疏,使用非id特征可能是关键。由于对给定用户或项目的观察很少,模型可能难以估计每个用户或每个项目的良好表示。为了建立一个准确的模型,其他的特征,如项目类别、描述和图像,必须被用来帮助模型泛化训练数据以外的数据。这在冷启动的情况下尤其重要,因为在冷启动的情况下,一些项目或用户的数据相对较少。

import os
import tempfileimport numpy as np
import tensorflow as tf
import tensorflow_datasets as tfdsimport tensorflow_recommenders as tfrs

特征数据

ratings = tfds.load("movielens/100k-ratings", split="train")
movies = tfds.load("movielens/100k-movies", split="train")ratings = ratings.map(lambda x: {"movie_title": x["movie_title"],"user_id": x["user_id"],"timestamp": x["timestamp"],
})
movies = movies.map(lambda x: x["movie_title"])

特征词汇表

timestamps = np.concatenate(list(ratings.map(lambda x: x["timestamp"]).batch(100)))max_timestamp = timestamps.max()
min_timestamp = timestamps.min()timestamp_buckets = np.linspace(min_timestamp, max_timestamp, num=1000,
)unique_movie_titles = np.unique(np.concatenate(list(movies.batch(1000))))
unique_user_ids = np.unique(np.concatenate(list(ratings.batch(1_000).map(lambda x: x["user_id"]))))

定义模型

user model
代码中改动:增加是否使用时间戳这个特征的选择。

class UserModel(tf.keras.Model):def __init__(self, use_timestamps):super().__init__()self._use_timestamps = use_timestampsself.user_embedding = tf.keras.Sequential([tf.keras.layers.StringLookup(vocabulary=unique_user_ids, mask_token=None),tf.keras.layers.Embedding(len(unique_user_ids) + 1, 32),])if use_timestamps:self.timestamp_embedding = tf.keras.Sequential([tf.keras.layers.Discretization(timestamp_buckets.tolist()),tf.keras.layers.Embedding(len(timestamp_buckets) + 1, 32),])self.normalized_timestamp = tf.keras.layers.Normalization(axis=None)self.normalized_timestamp.adapt(timestamps)def call(self, inputs):if not self._use_timestamps:return self.user_embedding(inputs["user_id"])return tf.concat([self.user_embedding(inputs["user_id"]),self.timestamp_embedding(inputs["timestamp"]),tf.reshape(self.normalized_timestamp(inputs["timestamp"]), (-1, 1)),], axis=1)

movie model

class MovieModel(tf.keras.Model):def __init__(self):super().__init__()max_tokens = 10_000self.title_embedding = tf.keras.Sequential([tf.keras.layers.StringLookup(vocabulary=unique_movie_titles, mask_token=None),tf.keras.layers.Embedding(len(unique_movie_titles) + 1, 32)])self.title_vectorizer = tf.keras.layers.TextVectorization(max_tokens=max_tokens)self.title_text_embedding = tf.keras.Sequential([self.title_vectorizer,tf.keras.layers.Embedding(max_tokens, 32, mask_zero=True),tf.keras.layers.GlobalAveragePooling1D(),])self.title_vectorizer.adapt(movies)def call(self, titles):return tf.concat([self.title_embedding(titles),self.title_text_embedding(titles),], axis=1)

组合模型

class MovielensModel(tfrs.models.Model):def __init__(self, use_timestamps):super().__init__()self.query_model = tf.keras.Sequential([UserModel(use_timestamps),tf.keras.layers.Dense(32)])self.candidate_model = tf.keras.Sequential([MovieModel(),tf.keras.layers.Dense(32)])self.task = tfrs.tasks.Retrieval(metrics=tfrs.metrics.FactorizedTopK(candidates=movies.batch(128).map(self.candidate_model),),)def compute_loss(self, features, training=False):# We only pass the user id and timestamp features into the query model. This# is to ensure that the training inputs would have the same keys as the# query inputs. Otherwise the discrepancy in input structure would cause an# error when loading the query model after saving it.query_embeddings = self.query_model({"user_id": features["user_id"],"timestamp": features["timestamp"],})movie_embeddings = self.candidate_model(features["movie_title"])return self.task(query_embeddings, movie_embeddings)

实验

准备数据集

tf.random.set_seed(42)
shuffled = ratings.shuffle(100_000, seed=42, reshuffle_each_iteration=False)train = shuffled.take(80_000)
test = shuffled.skip(80_000).take(20_000)cached_train = train.shuffle(100_000).batch(2048)
cached_test = test.batch(4096).cache()

Baseline :没有时间戳特性

model = MovielensModel(use_timestamps=False)
model.compile(optimizer=tf.keras.optimizers.Adagrad(0.1))model.fit(cached_train, epochs=3)train_accuracy = model.evaluate(cached_train, return_dict=True)["factorized_top_k/top_100_categorical_accuracy"]
test_accuracy = model.evaluate(cached_test, return_dict=True)["factorized_top_k/top_100_categorical_accuracy"]print(f"Top-100 accuracy (train): {train_accuracy:.2f}.")
print(f"Top-100 accuracy (test): {test_accuracy:.2f}.")

利用时间特征捕捉时间动态

model = MovielensModel(use_timestamps=True)
model.compile(optimizer=tf.keras.optimizers.Adagrad(0.1))model.fit(cached_train, epochs=3)train_accuracy = model.evaluate(cached_train, return_dict=True)["factorized_top_k/top_100_categorical_accuracy"]
test_accuracy = model.evaluate(cached_test, return_dict=True)["factorized_top_k/top_100_categorical_accuracy"]print(f"Top-100 accuracy (train): {train_accuracy:.2f}.")
print(f"Top-100 accuracy (test): {test_accuracy:.2f}.")

推荐阅读
  • java解析json转Map前段时间在做json报文处理的时候,写了一个针对不同格式json转map的处理工具方法,总结记录如下:1、单节点单层级、单节点多层级json转mapim ... [详细]
  • 普通树(每个节点可以有任意数量的子节点)级序遍历 ... [详细]
  • 机器学习算法:SVM(支持向量机)
    SVM算法(SupportVectorMachine,支持向量机)的核心思想有2点:1、如果数据线性可分,那么基于最大间隔的方式来确定超平面,以确保全局最优, ... [详细]
  • 本文节选自《NLTK基础教程——用NLTK和Python库构建机器学习应用》一书的第1章第1.2节,作者Nitin Hardeniya。本文将带领读者快速了解Python的基础知识,为后续的机器学习应用打下坚实的基础。 ... [详细]
  • 本文介绍了如何利用ObjectMapper实现JSON与JavaBean之间的高效转换。ObjectMapper是Jackson库的核心组件,能够便捷地将Java对象序列化为JSON格式,并支持从JSON、XML以及文件等多种数据源反序列化为Java对象。此外,还探讨了在实际应用中如何优化转换性能,以提升系统整体效率。 ... [详细]
  • 在尝试将 mysqldump 文件加载到新的 MySQL 服务器时,遇到因使用保留关键字 'table' 导致的语法错误。 ... [详细]
  • 目录预备知识导包构建数据集神经网络结构训练测试精度可视化计算模型精度损失可视化输出网络结构信息训练神经网络定义参数载入数据载入神经网络结构、损失及优化训练及测试损失、精度可视化qu ... [详细]
  • 利用python爬取豆瓣电影Top250的相关信息,包括电影详情链接,图片链接,影片中文名,影片外国名,评分,评价数,概况,导演,主演,年份,地区,类别这12项内容,然后将爬取的信息写入Exce ... [详细]
  • com.sun.javadoc.PackageDoc.exceptions()方法的使用及代码示例 ... [详细]
  • Ihavetwomethodsofgeneratingmdistinctrandomnumbersintherange[0..n-1]我有两种方法在范围[0.n-1]中生 ... [详细]
  • 本文介绍如何使用 Python 的 DOM 和 SAX 方法解析 XML 文件,并通过示例展示了如何动态创建数据库表和处理大量数据的实时插入。 ... [详细]
  • 本文详细介绍了如何使用Python中的smtplib库来发送带有附件的邮件,并提供了完整的代码示例。作者:多测师_王sir,时间:2020年5月20日 17:24,微信:15367499889,公司:上海多测师信息有限公司。 ... [详细]
  • 解决问题:1、批量读取点云las数据2、点云数据读与写出3、csf滤波分类参考:https:github.comsuyunzzzCSF论文题目ÿ ... [详细]
  • 本项目通过Python编程实现了一个简单的汇率转换器v1.02。主要内容包括:1. Python的基本语法元素:(1)缩进:用于表示代码的层次结构,是Python中定义程序框架的唯一方式;(2)注释:提供开发者说明信息,不参与实际运行,通常每个代码块添加一个注释;(3)常量和变量:用于存储和操作数据,是程序执行过程中的重要组成部分。此外,项目还涉及了函数定义、用户输入处理和异常捕获等高级特性,以确保程序的健壮性和易用性。 ... [详细]
  • 【问题】在Android开发中,当为EditText添加TextWatcher并实现onTextChanged方法时,会遇到一个问题:即使只对EditText进行一次修改(例如使用删除键删除一个字符),该方法也会被频繁触发。这不仅影响性能,还可能导致逻辑错误。本文将探讨这一问题的原因,并提供有效的解决方案,包括使用Handler或计时器来限制方法的调用频率,以及通过自定义TextWatcher来优化事件处理,从而提高应用的稳定性和用户体验。 ... [详细]
author-avatar
爱你不愿放cwy
这个家伙很懒,什么也没留下!
PHP1.CN | 中国最专业的PHP中文社区 | DevBox开发工具箱 | json解析格式化 |PHP资讯 | PHP教程 | 数据库技术 | 服务器技术 | 前端开发技术 | PHP框架 | 开发工具 | 在线工具
Copyright © 1998 - 2020 PHP1.CN. All Rights Reserved | 京公网安备 11010802041100号 | 京ICP备19059560号-4 | PHP1.CN 第一PHP社区 版权所有