diff --git a/models/rerank/listwise/config.yaml b/models/rerank/listwise/config.yaml index 18b018026634e461257d167fa543f2d81a25436c..2ddfa32fe08aa8bece00727aefc46bb893b4d090 100644 --- a/models/rerank/listwise/config.yaml +++ b/models/rerank/listwise/config.yaml @@ -12,44 +12,56 @@ # See the License for the specific language governing permissions and # limitations under the License. -evaluate: - reader: - batch_size: 1 - class: "{workspace}/random_infer_reader.py" - test_data_path: "{workspace}/data/train" -train: - trainer: - # for cluster training - strategy: "async" +workspace: "paddlerec.models.rerank.listwise" - epochs: 3 - workspace: "paddlerec.models.rerank.listwise" - device: cpu +dataset: +- name: dataset_train + type: DataLoader + data_path: "{workspace}/data/train" + data_converter: "{workspace}/random_reader.py" +- name: dataset_infer + type: DataLoader + data_path: "{workspace}/data/test" + data_converter: "{workspace}/random_reader.py" - reader: - batch_size: 2 - class: "{workspace}/random_reader.py" - train_data_path: "{workspace}/data/train" - dataset_class: "DataLoader" +hyper_parameters: + hidden_size: 128 + user_vocab: 200 + item_vocab: 1000 + item_len: 5 + embed_size: 16 + batch_size: 1 + optimizer: + class: sgd + learning_rate: 0.01 + strategy: async - model: - models: "{workspace}/model.py" - hyper_parameters: - hidden_size: 128 - user_vocab: 200 - item_vocab: 1000 - item_len: 5 - embed_size: 16 - learning_rate: 0.01 - optimizer: sgd +#use infer_runner mode and modify 'phase' below if infer +mode: train_runner +#mode: infer_runner + +runner: +- name: train_runner + class: single_train + device: cpu + epochs: 3 + save_checkpoint_interval: 2 + save_inference_interval: 4 + save_checkpoint_path: "increment" + save_inference_path: "inference" +- name: infer_runner + class: single_infer + init_model_path: "increment/0" + device: cpu + epochs: 3 - save: - increment: - dirname: "increment" - epoch_interval: 2 - save_last: True - inference: - dirname: "inference" - epoch_interval: 4 - save_last: True +phase: +- name: train + model: "{workspace}/model.py" + dataset_name: dataset_train + thread_num: 1 + #- name: infer + # model: "{workspace}/model.py" + # dataset_name: dataset_infer + # thread_num: 1 diff --git a/models/rerank/listwise/model.py b/models/rerank/listwise/model.py index d4cf9d8ed1a669d6d1ff3339008605f1aa26f4cd..d588db0629439eec9396ec9b1f81f1988e99d51e 100644 --- a/models/rerank/listwise/model.py +++ b/models/rerank/listwise/model.py @@ -25,18 +25,13 @@ class Model(ModelBase): ModelBase.__init__(self, config) def _init_hyper_parameters(self): - self.item_len = envs.get_global_env("hyper_parameters.self.item_len", - None, self._namespace) - self.hidden_size = envs.get_global_env("hyper_parameters.hidden_size", - None, self._namespace) - self.user_vocab = envs.get_global_env("hyper_parameters.user_vocab", - None, self._namespace) - self.item_vocab = envs.get_global_env("hyper_parameters.item_vocab", - None, self._namespace) - self.embed_size = envs.get_global_env("hyper_parameters.embed_size", - None, self._namespace) - - def input_data(self, is_infer=False): + self.item_len = envs.get_global_env("hyper_parameters.self.item_len") + self.hidden_size = envs.get_global_env("hyper_parameters.hidden_size") + self.user_vocab = envs.get_global_env("hyper_parameters.user_vocab") + self.item_vocab = envs.get_global_env("hyper_parameters.item_vocab") + self.embed_size = envs.get_global_env("hyper_parameters.embed_size") + + def input_data(self, is_infer=False, **kwargs): user_slot_names = fluid.data( name='user_slot_names', shape=[None, 1], diff --git a/models/rerank/listwise/random_reader.py b/models/rerank/listwise/random_reader.py index 41cf14b79285efe8f2d80e01bba74da3501cc504..aa7af3f083c720d35e9f11f5f5ec1bddd107cabc 100644 --- a/models/rerank/listwise/random_reader.py +++ b/models/rerank/listwise/random_reader.py @@ -23,14 +23,10 @@ from collections import defaultdict class TrainReader(Reader): def init(self): - self.user_vocab = envs.get_global_env("hyper_parameters.user_vocab", - None, "train.model") - self.item_vocab = envs.get_global_env("hyper_parameters.item_vocab", - None, "train.model") - self.item_len = envs.get_global_env("hyper_parameters.item_len", None, - "train.model") - self.batch_size = envs.get_global_env("batch_size", None, - "train.reader") + self.user_vocab = envs.get_global_env("hyper_parameters.user_vocab") + self.item_vocab = envs.get_global_env("hyper_parameters.item_vocab") + self.item_len = envs.get_global_env("hyper_parameters.item_len") + self.batch_size = envs.get_global_env("hyper_parameters.batch_size") def reader_creator(self): def reader():