提交 00b2de4b 编写于 作者: F frankwhzhang

fix listwise

上级 018a2916
......@@ -12,44 +12,56 @@
# See the License for the specific language governing permissions and
# limitations under the License.
evaluate:
reader:
batch_size: 1
class: "{workspace}/random_infer_reader.py"
test_data_path: "{workspace}/data/train"
train:
trainer:
# for cluster training
strategy: "async"
epochs: 3
workspace: "paddlerec.models.rerank.listwise"
device: cpu
workspace: "paddlerec.models.rerank.listwise"
reader:
batch_size: 2
class: "{workspace}/random_reader.py"
train_data_path: "{workspace}/data/train"
dataset_class: "DataLoader"
dataset:
- name: dataset_train
type: DataLoader
data_path: "{workspace}/data/train"
data_converter: "{workspace}/random_reader.py"
- name: dataset_infer
type: DataLoader
data_path: "{workspace}/data/test"
data_converter: "{workspace}/random_reader.py"
model:
models: "{workspace}/model.py"
hyper_parameters:
hyper_parameters:
hidden_size: 128
user_vocab: 200
item_vocab: 1000
item_len: 5
embed_size: 16
batch_size: 1
optimizer:
class: sgd
learning_rate: 0.01
optimizer: sgd
strategy: async
#use infer_runner mode and modify 'phase' below if infer
mode: train_runner
#mode: infer_runner
runner:
- name: train_runner
class: single_train
device: cpu
epochs: 3
save_checkpoint_interval: 2
save_inference_interval: 4
save_checkpoint_path: "increment"
save_inference_path: "inference"
- name: infer_runner
class: single_infer
init_model_path: "increment/0"
device: cpu
epochs: 3
save:
increment:
dirname: "increment"
epoch_interval: 2
save_last: True
inference:
dirname: "inference"
epoch_interval: 4
save_last: True
phase:
- name: train
model: "{workspace}/model.py"
dataset_name: dataset_train
thread_num: 1
#- name: infer
# model: "{workspace}/model.py"
# dataset_name: dataset_infer
# thread_num: 1
......@@ -25,18 +25,13 @@ class Model(ModelBase):
ModelBase.__init__(self, config)
def _init_hyper_parameters(self):
self.item_len = envs.get_global_env("hyper_parameters.self.item_len",
None, self._namespace)
self.hidden_size = envs.get_global_env("hyper_parameters.hidden_size",
None, self._namespace)
self.user_vocab = envs.get_global_env("hyper_parameters.user_vocab",
None, self._namespace)
self.item_vocab = envs.get_global_env("hyper_parameters.item_vocab",
None, self._namespace)
self.embed_size = envs.get_global_env("hyper_parameters.embed_size",
None, self._namespace)
def input_data(self, is_infer=False):
self.item_len = envs.get_global_env("hyper_parameters.self.item_len")
self.hidden_size = envs.get_global_env("hyper_parameters.hidden_size")
self.user_vocab = envs.get_global_env("hyper_parameters.user_vocab")
self.item_vocab = envs.get_global_env("hyper_parameters.item_vocab")
self.embed_size = envs.get_global_env("hyper_parameters.embed_size")
def input_data(self, is_infer=False, **kwargs):
user_slot_names = fluid.data(
name='user_slot_names',
shape=[None, 1],
......
......@@ -23,14 +23,10 @@ from collections import defaultdict
class TrainReader(Reader):
def init(self):
self.user_vocab = envs.get_global_env("hyper_parameters.user_vocab",
None, "train.model")
self.item_vocab = envs.get_global_env("hyper_parameters.item_vocab",
None, "train.model")
self.item_len = envs.get_global_env("hyper_parameters.item_len", None,
"train.model")
self.batch_size = envs.get_global_env("batch_size", None,
"train.reader")
self.user_vocab = envs.get_global_env("hyper_parameters.user_vocab")
self.item_vocab = envs.get_global_env("hyper_parameters.item_vocab")
self.item_len = envs.get_global_env("hyper_parameters.item_len")
self.batch_size = envs.get_global_env("hyper_parameters.batch_size")
def reader_creator(self):
def reader():
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册