提交 00b2de4b 编写于 作者: F frankwhzhang

fix listwise

上级 018a2916
...@@ -12,44 +12,56 @@ ...@@ -12,44 +12,56 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
evaluate:
reader:
batch_size: 1
class: "{workspace}/random_infer_reader.py"
test_data_path: "{workspace}/data/train"
train: workspace: "paddlerec.models.rerank.listwise"
trainer:
# for cluster training
strategy: "async"
epochs: 3 dataset:
workspace: "paddlerec.models.rerank.listwise" - name: dataset_train
device: cpu type: DataLoader
data_path: "{workspace}/data/train"
data_converter: "{workspace}/random_reader.py"
- name: dataset_infer
type: DataLoader
data_path: "{workspace}/data/test"
data_converter: "{workspace}/random_reader.py"
reader: hyper_parameters:
batch_size: 2 hidden_size: 128
class: "{workspace}/random_reader.py" user_vocab: 200
train_data_path: "{workspace}/data/train" item_vocab: 1000
dataset_class: "DataLoader" item_len: 5
embed_size: 16
batch_size: 1
optimizer:
class: sgd
learning_rate: 0.01
strategy: async
model: #use infer_runner mode and modify 'phase' below if infer
models: "{workspace}/model.py" mode: train_runner
hyper_parameters: #mode: infer_runner
hidden_size: 128
user_vocab: 200 runner:
item_vocab: 1000 - name: train_runner
item_len: 5 class: single_train
embed_size: 16 device: cpu
learning_rate: 0.01 epochs: 3
optimizer: sgd save_checkpoint_interval: 2
save_inference_interval: 4
save_checkpoint_path: "increment"
save_inference_path: "inference"
- name: infer_runner
class: single_infer
init_model_path: "increment/0"
device: cpu
epochs: 3
save: phase:
increment: - name: train
dirname: "increment" model: "{workspace}/model.py"
epoch_interval: 2 dataset_name: dataset_train
save_last: True thread_num: 1
inference: #- name: infer
dirname: "inference" # model: "{workspace}/model.py"
epoch_interval: 4 # dataset_name: dataset_infer
save_last: True # thread_num: 1
...@@ -25,18 +25,13 @@ class Model(ModelBase): ...@@ -25,18 +25,13 @@ class Model(ModelBase):
ModelBase.__init__(self, config) ModelBase.__init__(self, config)
def _init_hyper_parameters(self): def _init_hyper_parameters(self):
self.item_len = envs.get_global_env("hyper_parameters.self.item_len", self.item_len = envs.get_global_env("hyper_parameters.self.item_len")
None, self._namespace) self.hidden_size = envs.get_global_env("hyper_parameters.hidden_size")
self.hidden_size = envs.get_global_env("hyper_parameters.hidden_size", self.user_vocab = envs.get_global_env("hyper_parameters.user_vocab")
None, self._namespace) self.item_vocab = envs.get_global_env("hyper_parameters.item_vocab")
self.user_vocab = envs.get_global_env("hyper_parameters.user_vocab", self.embed_size = envs.get_global_env("hyper_parameters.embed_size")
None, self._namespace)
self.item_vocab = envs.get_global_env("hyper_parameters.item_vocab", def input_data(self, is_infer=False, **kwargs):
None, self._namespace)
self.embed_size = envs.get_global_env("hyper_parameters.embed_size",
None, self._namespace)
def input_data(self, is_infer=False):
user_slot_names = fluid.data( user_slot_names = fluid.data(
name='user_slot_names', name='user_slot_names',
shape=[None, 1], shape=[None, 1],
......
...@@ -23,14 +23,10 @@ from collections import defaultdict ...@@ -23,14 +23,10 @@ from collections import defaultdict
class TrainReader(Reader): class TrainReader(Reader):
def init(self): def init(self):
self.user_vocab = envs.get_global_env("hyper_parameters.user_vocab", self.user_vocab = envs.get_global_env("hyper_parameters.user_vocab")
None, "train.model") self.item_vocab = envs.get_global_env("hyper_parameters.item_vocab")
self.item_vocab = envs.get_global_env("hyper_parameters.item_vocab", self.item_len = envs.get_global_env("hyper_parameters.item_len")
None, "train.model") self.batch_size = envs.get_global_env("hyper_parameters.batch_size")
self.item_len = envs.get_global_env("hyper_parameters.item_len", None,
"train.model")
self.batch_size = envs.get_global_env("batch_size", None,
"train.reader")
def reader_creator(self): def reader_creator(self):
def reader(): def reader():
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册