From 8228c894ca534cba6ca5c273b4a51931dfa674ec Mon Sep 17 00:00:00 2001 From: yaoxuefeng Date: Mon, 1 Jun 2020 16:37:48 +0800 Subject: [PATCH] update din yaml --- models/rank/din/config.yaml | 5 ++++- models/rank/din/data/config.txt | 3 --- models/rank/din/model.py | 28 ++++++++++------------------ 3 files changed, 14 insertions(+), 22 deletions(-) delete mode 100644 models/rank/din/data/config.txt diff --git a/models/rank/din/config.yaml b/models/rank/din/config.yaml index e61e4636..2885ba7a 100755 --- a/models/rank/din/config.yaml +++ b/models/rank/din/config.yaml @@ -36,7 +36,9 @@ hyper_parameters: item_emb_size: 64 cat_emb_size: 64 is_sparse: False - config_path: "{workspace}/data/config.txt" + item_count: 63001 + cat_count: 801 + act: "sigmoid" @@ -52,6 +54,7 @@ runner: save_inference_interval: 1 save_checkpoint_path: "increment" save_inference_path: "inference" + print_interval: 1 - name: infer_runner trainer_class: single_infer epochs: 1 diff --git a/models/rank/din/data/config.txt b/models/rank/din/data/config.txt deleted file mode 100644 index 8552fb4c..00000000 --- a/models/rank/din/data/config.txt +++ /dev/null @@ -1,3 +0,0 @@ -192403 -63001 -801 diff --git a/models/rank/din/model.py b/models/rank/din/model.py index 7f3d4801..4f609911 100755 --- a/models/rank/din/model.py +++ b/models/rank/din/model.py @@ -31,10 +31,11 @@ class Model(ModelBase): self.is_sparse = envs.get_global_env("hyper_parameters.is_sparse", False) #significant for speeding up the training process - self.config_path = envs.get_global_env("hyper_parameters.config_path", - "data/config.txt") self.use_DataLoader = envs.get_global_env( "hyper_parameters.use_DataLoader", False) + self.item_count = envs.get_global_env("hyper_parameters.item_count", + 63001) + self.cat_count = envs.get_global_env("hyper_parameters.cat_count", 801) def input_data(self, is_infer=False, **kwargs): seq_len = -1 @@ -74,13 +75,6 @@ class Model(ModelBase): ] + [label] + [mask] + [target_item_seq] + [target_cat_seq] return train_inputs - def config_read(self, config_path): - with open(config_path, "r") as fin: - user_count = int(fin.readline().strip()) - item_count = int(fin.readline().strip()) - cat_count = int(fin.readline().strip()) - return user_count, item_count, cat_count - def din_attention(self, hist, target_expand, mask): """activation weight""" @@ -121,50 +115,48 @@ class Model(ModelBase): target_item_seq = inputs[6] target_cat_seq = inputs[7] - user_count, item_count, cat_count = self.config_read(self.config_path) - item_emb_attr = fluid.ParamAttr(name="item_emb") cat_emb_attr = fluid.ParamAttr(name="cat_emb") hist_item_emb = fluid.embedding( input=hist_item_seq, - size=[item_count, self.item_emb_size], + size=[self.item_count, self.item_emb_size], param_attr=item_emb_attr, is_sparse=self.is_sparse) hist_cat_emb = fluid.embedding( input=hist_cat_seq, - size=[cat_count, self.cat_emb_size], + size=[self.cat_count, self.cat_emb_size], param_attr=cat_emb_attr, is_sparse=self.is_sparse) target_item_emb = fluid.embedding( input=target_item, - size=[item_count, self.item_emb_size], + size=[self.item_count, self.item_emb_size], param_attr=item_emb_attr, is_sparse=self.is_sparse) target_cat_emb = fluid.embedding( input=target_cat, - size=[cat_count, self.cat_emb_size], + size=[self.cat_count, self.cat_emb_size], param_attr=cat_emb_attr, is_sparse=self.is_sparse) target_item_seq_emb = fluid.embedding( input=target_item_seq, - size=[item_count, self.item_emb_size], + size=[self.item_count, self.item_emb_size], param_attr=item_emb_attr, is_sparse=self.is_sparse) target_cat_seq_emb = fluid.embedding( input=target_cat_seq, - size=[cat_count, self.cat_emb_size], + size=[self.cat_count, self.cat_emb_size], param_attr=cat_emb_attr, is_sparse=self.is_sparse) item_b = fluid.embedding( input=target_item, - size=[item_count, 1], + size=[self.item_count, 1], param_attr=fluid.initializer.Constant(value=0.0)) hist_seq_concat = fluid.layers.concat( -- GitLab