From 654ba717db45b7c8327fad188844b440444cc742 Mon Sep 17 00:00:00 2001
From: xjqbest <173596896@qq.com>
Date: Thu, 28 May 2020 10:22:41 +0800
Subject: [PATCH] fix

---
 core/trainers/single_trainer.py |  6 ++---
 models/rank/dnn/config.yaml     | 39 ++++++++++++++++++++-------------
 2 files changed, 26 insertions(+), 19 deletions(-)

diff --git a/core/trainers/single_trainer.py b/core/trainers/single_trainer.py
index 9b58ca45..3ddc48e2 100755
--- a/core/trainers/single_trainer.py
+++ b/core/trainers/single_trainer.py
@@ -132,7 +132,7 @@ class SingleTrainer(TranspileTrainer):
         padding = 0
 
         if type_name == "DataLoader":
-            return None#self._get_dataloader(dataset_name)
+            return None
         else:
             return self._get_dataset(dataset_name)
 
@@ -243,7 +243,6 @@ class SingleTrainer(TranspileTrainer):
         metrics_varnames = []
         metrics_format = []
         fetch_period = 20
-        #metrics_format.append("{}: {{}}".format("epoch"))
         metrics_format.append("{}: {{}}".format("batch"))
         for name, var in model_class.get_metrics().items():
             metrics_varnames.append(var.name)
@@ -259,7 +258,7 @@ class SingleTrainer(TranspileTrainer):
             while True:
                 metrics_rets = self._exe.run(program=program,
                                              fetch_list=metrics_varnames)
-                metrics = [batch_id]#[epoch, batch_id]
+                metrics = [batch_id]
                 metrics.extend(metrics_rets)
 
                 if batch_id % fetch_period == 0 and batch_id != 0:
@@ -275,7 +274,6 @@ class SingleTrainer(TranspileTrainer):
         dirname = envs.get_global_env("epoch.init_model_path", None)
         if dirname is None:
             return
-        dirname = os.path.join(dirname, str(epoch_id))
         if is_fleet:
             fleet.load_persistables(self._exe, dirname)
         else:
diff --git a/models/rank/dnn/config.yaml b/models/rank/dnn/config.yaml
index a51647d0..9db1f9bb 100755
--- a/models/rank/dnn/config.yaml
+++ b/models/rank/dnn/config.yaml
@@ -12,43 +12,52 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-debug: false
-cold_start: true
+# number of epochs
 epochs: 10
+# device
 device: cpu
+# working directory
 workspace: "paddlerec.models.rank.dnn"
 
+# list of datasets
 dataset:
-- name: dataset_2
+- name: dataset_2 # name, used to distinguish different datasets
   batch_size: 2
-  #type: QueueDataset
-  type: DataLoader
-  data_path: "{workspace}/data/sample_data/train"
+  type: DataLoader # or QueueDataset
+  data_path: "{workspace}/data/sample_data/train" # data path
   sparse_slots: "click 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26"
   dense_slots: "dense_var:13"
 
+# hyper parameters
 hyper_parameters:
+  # optimizer
   optimizer:
     class: Adam
     learning_rate: 0.001
     strategy: async
+  # user defined
   sparse_inputs_slots: 27
   sparse_feature_number: 1000001
   sparse_feature_dim: 9
   dense_input_dim: 13
   fc_sizes: [512, 256, 128, 32]
 
+# executor configuration
 epoch:
-  name: 
+  name:
   trainer_class: single
-  save_checkpoint_interval: 2
-  save_inference_interval: 4
-  save_checkpoint_path: "increment"
-  save_inference_path: "inference"
+  save_checkpoint_interval: 2 # save checkpoint model
+  save_inference_interval: 4 # save inference model
+  save_checkpoint_path: "increment" # checkpoint save path
+  save_inference_path: "inference" # inference model save path
+  #save_inference_feed_varnames: [] # feed vars of the inference model
+  #save_inference_fetch_varnames: [] # fetch vars of the inference model
+  #init_model_path: "xxxx" # load model from this path
+# executors: all models to run in each epoch
   executor:
   - name: train
-    model: "{workspace}/model.py"
-    dataset_name: dataset_2
-    thread_num: 1
-    is_infer: False
+    model: "{workspace}/model.py" # model path
+    dataset_name: dataset_2 # name, used to distinguish different stages
+    thread_num: 1 # number of threads
+    is_infer: False # whether this is an infer stage
 
--
GitLab
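
Note on the last single_trainer.py hunk (not part of the patch): with the os.path.join(dirname, str(epoch_id)) line removed, the value of "epoch.init_model_path" from config.yaml is used verbatim as the checkpoint directory, so it must point directly at a saved checkpoint rather than at a parent folder of per-epoch subdirectories. Below is a minimal sketch of the resulting load logic for the non-fleet case, assuming the Paddle 1.x fluid.io API; the helper name load_initial_model and the example path are hypothetical, not part of this change:

    import paddle.fluid as fluid

    def load_initial_model(exe, init_model_path):
        # init_model_path corresponds to "epoch.init_model_path" in config.yaml
        # and is used as-is after this patch (no epoch_id subdirectory appended).
        if init_model_path is None:
            return
        # Loads all persistable variables of the default main program from the
        # given directory; distributed runs use fleet.load_persistables instead,
        # as shown in the hunk above.
        fluid.io.load_persistables(exe, init_model_path)

    # Example (hypothetical path): point at a checkpoint previously written under
    # the configured save_checkpoint_path, e.g.
    #   load_initial_model(exe, "increment/0")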