未验证 提交 8a15a4df 编写于 作者: T Taylor Robie 提交者: GitHub

Keras-ify NCF TPU embedding lookup (#5641)

* Keras-ify TPU embedding lookup

* delint

* pull get_variable() out of keras lambda

* delint

* move get_variable under variable scope
上级 b8318fd3
......@@ -204,24 +204,34 @@ def construct_model(users, items, params):
name="embeddings_mf_user",
shape=[num_users, mf_dim + model_layers[0] // 2],
initializer=tf.glorot_uniform_initializer())
cmb_embedding_item = tf.get_variable(
name="embeddings_mf_item",
shape=[num_items, mf_dim + model_layers[0] // 2],
initializer=tf.glorot_uniform_initializer())
cmb_user_latent = tf.gather(cmb_embedding_user, user_input)
cmb_item_latent = tf.gather(cmb_embedding_item, item_input)
mlp_user_latent = tf.slice(cmb_user_latent, [0, 0],
[batch_size, model_layers[0] // 2])
mlp_item_latent = tf.slice(cmb_item_latent, [0, 0],
[batch_size, model_layers[0] // 2])
mlp_vector = tf.keras.layers.concatenate([mlp_user_latent,
mlp_item_latent])
mf_user_latent = tf.slice(cmb_user_latent, [0, model_layers[0] // 2],
[batch_size, mf_dim])
mf_item_latent = tf.slice(cmb_item_latent, [0, model_layers[0] // 2],
[batch_size, mf_dim])
cmb_user_latent = tf.keras.layers.Lambda(lambda ids: tf.gather(
cmb_embedding_user, ids))(user_input)
cmb_item_latent = tf.keras.layers.Lambda(lambda ids: tf.gather(
cmb_embedding_item, ids))(item_input)
mlp_user_latent = tf.keras.layers.Lambda(
lambda x: tf.slice(x, [0, 0], [batch_size, model_layers[0] // 2])
)(cmb_user_latent)
mlp_item_latent = tf.keras.layers.Lambda(
lambda x: tf.slice(x, [0, 0], [batch_size, model_layers[0] // 2])
)(cmb_item_latent)
mf_user_latent = tf.keras.layers.Lambda(
lambda x: tf.slice(x, [0, model_layers[0] // 2], [batch_size, mf_dim])
)(cmb_user_latent)
mf_item_latent = tf.keras.layers.Lambda(
lambda x: tf.slice(x, [0, model_layers[0] // 2], [batch_size, mf_dim])
)(cmb_item_latent)
else:
# Initializer for embedding layers
embedding_initializer = "glorot_uniform"
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册