diff --git a/core/model.py b/core/model.py index 265f5311d2a49601fb21addc9031358170a287fd..22e742374d4ac2d4ef079b6cb4157759ef3ffd51 100755 --- a/core/model.py +++ b/core/model.py @@ -177,6 +177,13 @@ class ModelBase(object): opt_name = envs.get_global_env("hyper_parameters.optimizer.class") opt_lr = envs.get_global_env( "hyper_parameters.optimizer.learning_rate") + if not isinstance(opt_lr, (float, Variable)): + try: + opt_lr = float(opt_lr) + except ValueError: + raise ValueError( + "In your config yaml, 'learning_rate': %s must be written as a floating piont number,such as 0.001 or 1e-3" + % opt_lr) opt_strategy = envs.get_global_env( "hyper_parameters.optimizer.strategy") diff --git a/core/trainers/framework/dataset.py b/core/trainers/framework/dataset.py index 239b568be34793c5ddb0830e9cca06951da143f4..3861813cdd7d8d5e0b64e8c568a9c70ede2b9c05 100644 --- a/core/trainers/framework/dataset.py +++ b/core/trainers/framework/dataset.py @@ -143,6 +143,8 @@ class QueueDataset(DatasetBase): if need_split_files: file_list = split_files(file_list, context["fleet"].worker_index(), context["fleet"].worker_num()) + + context["file_list"] = file_list print("File_list: {}".format(file_list)) dataset.set_filelist(file_list) diff --git a/core/trainers/framework/runner.py b/core/trainers/framework/runner.py index 4375b7267359e50e8cf9d739ba9dc1f58529e36a..0229a3a8c6e245df5d6530c48dcca0d8a0638306 100644 --- a/core/trainers/framework/runner.py +++ b/core/trainers/framework/runner.py @@ -18,10 +18,12 @@ import os import time import warnings import numpy as np +import random import logging import paddle.fluid as fluid from paddlerec.core.utils import envs +from paddlerec.core.utils.util import shuffle_files from paddlerec.core.metric import Metric logging.basicConfig( @@ -92,7 +94,6 @@ class RunnerBase(object): reader_name = model_dict["dataset_name"] model_name = model_dict["name"] model_class = context["model"][model_dict["name"]]["model"] - fetch_vars = [] fetch_alias = [] fetch_period = int( @@ -395,7 +396,12 @@ class SingleRunner(RunnerBase): for model_dict in context["phases"]: model_class = context["model"][model_dict["name"]]["model"] metrics = model_class._metrics - + if "shuffle_filelist" in model_dict: + need_shuffle_files = model_dict.get("shuffle_filelist", + None) + filelist = context["file_list"] + context["file_list"] = shuffle_files(need_shuffle_files, + filelist) begin_time = time.time() result = self._run(context, model_dict) end_time = time.time() @@ -439,6 +445,11 @@ class PSRunner(RunnerBase): model_class = context["model"][model_dict["name"]]["model"] metrics = model_class._metrics for epoch in range(epochs): + if "shuffle_filelist" in model_dict: + need_shuffle_files = model_dict.get("shuffle_filelist", None) + filelist = context["file_list"] + context["file_list"] = shuffle_files(need_shuffle_files, + filelist) begin_time = time.time() result = self._run(context, model_dict) end_time = time.time() @@ -484,6 +495,11 @@ class CollectiveRunner(RunnerBase): ".epochs")) model_dict = context["env"]["phase"][0] for epoch in range(epochs): + if "shuffle_filelist" in model_dict: + need_shuffle_files = model_dict.get("shuffle_filelist", None) + filelist = context["file_list"] + context["file_list"] = shuffle_files(need_shuffle_files, + filelist) begin_time = time.time() self._run(context, model_dict) end_time = time.time() @@ -512,6 +528,11 @@ class PslibRunner(RunnerBase): envs.get_global_env("runner." + context["runner_name"] + ".epochs")) for epoch in range(epochs): + if "shuffle_filelist" in model_dict: + need_shuffle_files = model_dict.get("shuffle_filelist", None) + filelist = context["file_list"] + context["file_list"] = shuffle_files(need_shuffle_files, + filelist) begin_time = time.time() self._run(context, model_dict) end_time = time.time() @@ -574,6 +595,12 @@ class SingleInferRunner(RunnerBase): metrics = model_class._infer_results self._load(context, model_dict, self.epoch_model_path_list[index]) + if "shuffle_filelist" in model_dict: + need_shuffle_files = model_dict.get("shuffle_filelist", + None) + filelist = context["file_list"] + context["file_list"] = shuffle_files(need_shuffle_files, + filelist) begin_time = time.time() result = self._run(context, model_dict) end_time = time.time() diff --git a/core/utils/dataloader_instance.py b/core/utils/dataloader_instance.py index 03e6f0a67884917e9af2d02d13eb86576620ceef..0de193d1c2e02f58610dfad2ec1f09989513a4b2 100755 --- a/core/utils/dataloader_instance.py +++ b/core/utils/dataloader_instance.py @@ -59,7 +59,7 @@ def dataloader_by_name(readerclass, if need_split_files: files = split_files(files, context["fleet"].worker_index(), context["fleet"].worker_num()) - + context["file_list"] = files reader = reader_class(yaml_file) reader.init() @@ -121,7 +121,7 @@ def slotdataloader_by_name(readerclass, dataset_name, yaml_file, context): if need_split_files: files = split_files(files, context["fleet"].worker_index(), context["fleet"].worker_num()) - + context["file_list"] = files sparse = get_global_env(name + "sparse_slots", "#") if sparse == "": sparse = "#" @@ -191,7 +191,7 @@ def slotdataloader(readerclass, train, yaml_file, context): if need_split_files: files = split_files(files, context["fleet"].worker_index(), context["fleet"].worker_num()) - + context["file_list"] = files sparse = get_global_env("sparse_slots", "#", namespace) if sparse == "": sparse = "#" diff --git a/core/utils/util.py b/core/utils/util.py index f6acfe203612326a77f41326581583278dac4183..09aece5e899c7eab3f71a5ac84d430c54274bf06 100755 --- a/core/utils/util.py +++ b/core/utils/util.py @@ -16,6 +16,8 @@ import datetime import os import sys import time +import warnings +import random import numpy as np from paddle import fluid @@ -223,6 +225,16 @@ def check_filelist(hidden_file_list, data_file_list, train_data_path): return hidden_file_list, data_file_list +def shuffle_files(need_shuffle_files, filelist): + if not isinstance(need_shuffle_files, bool): + raise ValueError( + "In your config yaml, 'shuffle_filelist': %s must be written as a boolean type,such as True or False" + % need_shuffle_files) + elif need_shuffle_files: + random.shuffle(filelist) + return filelist + + class CostPrinter(object): """ For count cost time && print cost log