未验证 提交 c26b0e75 编写于 作者: L liuyuhui 提交者: GitHub

add support for file_list shuffle each epoch and fix float learning rate bug (#197)

* add support for file_list shuffle each epoch, test=develop

* fix float learning rate bug

* optimized code for shuffle files
Co-authored-by: tangwei12 <tangwei12@baidu.com>
上级 770693ab
...@@ -177,6 +177,13 @@ class ModelBase(object): ...@@ -177,6 +177,13 @@ class ModelBase(object):
opt_name = envs.get_global_env("hyper_parameters.optimizer.class") opt_name = envs.get_global_env("hyper_parameters.optimizer.class")
opt_lr = envs.get_global_env( opt_lr = envs.get_global_env(
"hyper_parameters.optimizer.learning_rate") "hyper_parameters.optimizer.learning_rate")
if not isinstance(opt_lr, (float, Variable)):
try:
opt_lr = float(opt_lr)
except ValueError:
raise ValueError(
"In your config yaml, 'learning_rate': %s must be written as a floating piont number,such as 0.001 or 1e-3"
% opt_lr)
opt_strategy = envs.get_global_env( opt_strategy = envs.get_global_env(
"hyper_parameters.optimizer.strategy") "hyper_parameters.optimizer.strategy")
......
...@@ -143,6 +143,8 @@ class QueueDataset(DatasetBase): ...@@ -143,6 +143,8 @@ class QueueDataset(DatasetBase):
if need_split_files: if need_split_files:
file_list = split_files(file_list, context["fleet"].worker_index(), file_list = split_files(file_list, context["fleet"].worker_index(),
context["fleet"].worker_num()) context["fleet"].worker_num())
context["file_list"] = file_list
print("File_list: {}".format(file_list)) print("File_list: {}".format(file_list))
dataset.set_filelist(file_list) dataset.set_filelist(file_list)
......
...@@ -18,10 +18,12 @@ import os ...@@ -18,10 +18,12 @@ import os
import time import time
import warnings import warnings
import numpy as np import numpy as np
import random
import logging import logging
import paddle.fluid as fluid import paddle.fluid as fluid
from paddlerec.core.utils import envs from paddlerec.core.utils import envs
from paddlerec.core.utils.util import shuffle_files
from paddlerec.core.metric import Metric from paddlerec.core.metric import Metric
logging.basicConfig( logging.basicConfig(
...@@ -92,7 +94,6 @@ class RunnerBase(object): ...@@ -92,7 +94,6 @@ class RunnerBase(object):
reader_name = model_dict["dataset_name"] reader_name = model_dict["dataset_name"]
model_name = model_dict["name"] model_name = model_dict["name"]
model_class = context["model"][model_dict["name"]]["model"] model_class = context["model"][model_dict["name"]]["model"]
fetch_vars = [] fetch_vars = []
fetch_alias = [] fetch_alias = []
fetch_period = int( fetch_period = int(
...@@ -395,7 +396,12 @@ class SingleRunner(RunnerBase): ...@@ -395,7 +396,12 @@ class SingleRunner(RunnerBase):
for model_dict in context["phases"]: for model_dict in context["phases"]:
model_class = context["model"][model_dict["name"]]["model"] model_class = context["model"][model_dict["name"]]["model"]
metrics = model_class._metrics metrics = model_class._metrics
if "shuffle_filelist" in model_dict:
need_shuffle_files = model_dict.get("shuffle_filelist",
None)
filelist = context["file_list"]
context["file_list"] = shuffle_files(need_shuffle_files,
filelist)
begin_time = time.time() begin_time = time.time()
result = self._run(context, model_dict) result = self._run(context, model_dict)
end_time = time.time() end_time = time.time()
...@@ -439,6 +445,11 @@ class PSRunner(RunnerBase): ...@@ -439,6 +445,11 @@ class PSRunner(RunnerBase):
model_class = context["model"][model_dict["name"]]["model"] model_class = context["model"][model_dict["name"]]["model"]
metrics = model_class._metrics metrics = model_class._metrics
for epoch in range(epochs): for epoch in range(epochs):
if "shuffle_filelist" in model_dict:
need_shuffle_files = model_dict.get("shuffle_filelist", None)
filelist = context["file_list"]
context["file_list"] = shuffle_files(need_shuffle_files,
filelist)
begin_time = time.time() begin_time = time.time()
result = self._run(context, model_dict) result = self._run(context, model_dict)
end_time = time.time() end_time = time.time()
...@@ -484,6 +495,11 @@ class CollectiveRunner(RunnerBase): ...@@ -484,6 +495,11 @@ class CollectiveRunner(RunnerBase):
".epochs")) ".epochs"))
model_dict = context["env"]["phase"][0] model_dict = context["env"]["phase"][0]
for epoch in range(epochs): for epoch in range(epochs):
if "shuffle_filelist" in model_dict:
need_shuffle_files = model_dict.get("shuffle_filelist", None)
filelist = context["file_list"]
context["file_list"] = shuffle_files(need_shuffle_files,
filelist)
begin_time = time.time() begin_time = time.time()
self._run(context, model_dict) self._run(context, model_dict)
end_time = time.time() end_time = time.time()
...@@ -512,6 +528,11 @@ class PslibRunner(RunnerBase): ...@@ -512,6 +528,11 @@ class PslibRunner(RunnerBase):
envs.get_global_env("runner." + context["runner_name"] + envs.get_global_env("runner." + context["runner_name"] +
".epochs")) ".epochs"))
for epoch in range(epochs): for epoch in range(epochs):
if "shuffle_filelist" in model_dict:
need_shuffle_files = model_dict.get("shuffle_filelist", None)
filelist = context["file_list"]
context["file_list"] = shuffle_files(need_shuffle_files,
filelist)
begin_time = time.time() begin_time = time.time()
self._run(context, model_dict) self._run(context, model_dict)
end_time = time.time() end_time = time.time()
...@@ -574,6 +595,12 @@ class SingleInferRunner(RunnerBase): ...@@ -574,6 +595,12 @@ class SingleInferRunner(RunnerBase):
metrics = model_class._infer_results metrics = model_class._infer_results
self._load(context, model_dict, self._load(context, model_dict,
self.epoch_model_path_list[index]) self.epoch_model_path_list[index])
if "shuffle_filelist" in model_dict:
need_shuffle_files = model_dict.get("shuffle_filelist",
None)
filelist = context["file_list"]
context["file_list"] = shuffle_files(need_shuffle_files,
filelist)
begin_time = time.time() begin_time = time.time()
result = self._run(context, model_dict) result = self._run(context, model_dict)
end_time = time.time() end_time = time.time()
......
...@@ -59,7 +59,7 @@ def dataloader_by_name(readerclass, ...@@ -59,7 +59,7 @@ def dataloader_by_name(readerclass,
if need_split_files: if need_split_files:
files = split_files(files, context["fleet"].worker_index(), files = split_files(files, context["fleet"].worker_index(),
context["fleet"].worker_num()) context["fleet"].worker_num())
context["file_list"] = files
reader = reader_class(yaml_file) reader = reader_class(yaml_file)
reader.init() reader.init()
...@@ -121,7 +121,7 @@ def slotdataloader_by_name(readerclass, dataset_name, yaml_file, context): ...@@ -121,7 +121,7 @@ def slotdataloader_by_name(readerclass, dataset_name, yaml_file, context):
if need_split_files: if need_split_files:
files = split_files(files, context["fleet"].worker_index(), files = split_files(files, context["fleet"].worker_index(),
context["fleet"].worker_num()) context["fleet"].worker_num())
context["file_list"] = files
sparse = get_global_env(name + "sparse_slots", "#") sparse = get_global_env(name + "sparse_slots", "#")
if sparse == "": if sparse == "":
sparse = "#" sparse = "#"
...@@ -191,7 +191,7 @@ def slotdataloader(readerclass, train, yaml_file, context): ...@@ -191,7 +191,7 @@ def slotdataloader(readerclass, train, yaml_file, context):
if need_split_files: if need_split_files:
files = split_files(files, context["fleet"].worker_index(), files = split_files(files, context["fleet"].worker_index(),
context["fleet"].worker_num()) context["fleet"].worker_num())
context["file_list"] = files
sparse = get_global_env("sparse_slots", "#", namespace) sparse = get_global_env("sparse_slots", "#", namespace)
if sparse == "": if sparse == "":
sparse = "#" sparse = "#"
......
...@@ -16,6 +16,8 @@ import datetime ...@@ -16,6 +16,8 @@ import datetime
import os import os
import sys import sys
import time import time
import warnings
import random
import numpy as np import numpy as np
from paddle import fluid from paddle import fluid
...@@ -223,6 +225,16 @@ def check_filelist(hidden_file_list, data_file_list, train_data_path): ...@@ -223,6 +225,16 @@ def check_filelist(hidden_file_list, data_file_list, train_data_path):
return hidden_file_list, data_file_list return hidden_file_list, data_file_list
def shuffle_files(need_shuffle_files, filelist):
    """Optionally shuffle ``filelist`` in place and return it.

    Args:
        need_shuffle_files: the ``shuffle_filelist`` switch; must be a
            real ``bool`` (truthy/falsy non-bool values are rejected).
        filelist: the list of files to reorder.

    Returns:
        The same ``filelist`` object that was passed in. It has been
        reordered in place with ``random.shuffle`` when
        ``need_shuffle_files`` is True, and is untouched when it is False.

    Raises:
        ValueError: if ``need_shuffle_files`` is not a ``bool``.
    """
    if isinstance(need_shuffle_files, bool):
        if need_shuffle_files:
            # random.shuffle reorders the list in place; the caller gets
            # back the very same list object.
            random.shuffle(filelist)
        return filelist

    # Anything that is not a bool is reported as a config error.
    raise ValueError(
        "In your config yaml, 'shuffle_filelist': %s must be written as a boolean type,such as True or False"
        % need_shuffle_files)
class CostPrinter(object): class CostPrinter(object):
""" """
For count cost time && print cost log For count cost time && print cost log
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册