diff --git a/models/multitask/esmm/config.yaml b/models/multitask/esmm/config.yaml index 255037523350aeeb66adc91a9f04ea7935f5dc4d..c6e79a52a6053c4441f253aca447c55307afa470 100644 --- a/models/multitask/esmm/config.yaml +++ b/models/multitask/esmm/config.yaml @@ -24,7 +24,7 @@ dataset: - name: dataset_infer batch_size: 1 type: QueueDataset - data_path: "{workspace}/data/test" + data_path: "{workspace}/data/train" data_converter: "{workspace}/esmm_reader.py" hyper_parameters: @@ -36,12 +36,12 @@ hyper_parameters: strategy: async #use infer_runner mode and modify 'phase' below if infer -mode: [train_runner] +mode: [train_runner, infer_runner] #mode: infer_runner runner: - name: train_runner - class: single_train + class: train device: cpu epochs: 3 save_checkpoint_interval: 2 @@ -51,7 +51,7 @@ runner: print_interval: 10 phases: [train] - name: infer_runner - class: single_infer + class: infer init_model_path: "increment/0" device: cpu epochs: 1