未验证 提交 c26b0e75 编写于 作者: L liuyuhui 提交者: GitHub

add support for file_list shuffle each epoch and fix float learning rate bug (#197)

* add support for file_list shuffle each epoch, test=develop

* fix float learning rate bug

* optimized code for shuffle files
Co-authored-by: Ntangwei12 <tangwei12@baidu.com>
上级 770693ab
......@@ -177,6 +177,13 @@ class ModelBase(object):
opt_name = envs.get_global_env("hyper_parameters.optimizer.class")
opt_lr = envs.get_global_env(
"hyper_parameters.optimizer.learning_rate")
if not isinstance(opt_lr, (float, Variable)):
try:
opt_lr = float(opt_lr)
except ValueError:
raise ValueError(
"In your config yaml, 'learning_rate': %s must be written as a floating piont number,such as 0.001 or 1e-3"
% opt_lr)
opt_strategy = envs.get_global_env(
"hyper_parameters.optimizer.strategy")
......
......@@ -143,6 +143,8 @@ class QueueDataset(DatasetBase):
if need_split_files:
file_list = split_files(file_list, context["fleet"].worker_index(),
context["fleet"].worker_num())
context["file_list"] = file_list
print("File_list: {}".format(file_list))
dataset.set_filelist(file_list)
......
......@@ -18,10 +18,12 @@ import os
import time
import warnings
import numpy as np
import random
import logging
import paddle.fluid as fluid
from paddlerec.core.utils import envs
from paddlerec.core.utils.util import shuffle_files
from paddlerec.core.metric import Metric
logging.basicConfig(
......@@ -92,7 +94,6 @@ class RunnerBase(object):
reader_name = model_dict["dataset_name"]
model_name = model_dict["name"]
model_class = context["model"][model_dict["name"]]["model"]
fetch_vars = []
fetch_alias = []
fetch_period = int(
......@@ -395,7 +396,12 @@ class SingleRunner(RunnerBase):
for model_dict in context["phases"]:
model_class = context["model"][model_dict["name"]]["model"]
metrics = model_class._metrics
if "shuffle_filelist" in model_dict:
need_shuffle_files = model_dict.get("shuffle_filelist",
None)
filelist = context["file_list"]
context["file_list"] = shuffle_files(need_shuffle_files,
filelist)
begin_time = time.time()
result = self._run(context, model_dict)
end_time = time.time()
......@@ -439,6 +445,11 @@ class PSRunner(RunnerBase):
model_class = context["model"][model_dict["name"]]["model"]
metrics = model_class._metrics
for epoch in range(epochs):
if "shuffle_filelist" in model_dict:
need_shuffle_files = model_dict.get("shuffle_filelist", None)
filelist = context["file_list"]
context["file_list"] = shuffle_files(need_shuffle_files,
filelist)
begin_time = time.time()
result = self._run(context, model_dict)
end_time = time.time()
......@@ -484,6 +495,11 @@ class CollectiveRunner(RunnerBase):
".epochs"))
model_dict = context["env"]["phase"][0]
for epoch in range(epochs):
if "shuffle_filelist" in model_dict:
need_shuffle_files = model_dict.get("shuffle_filelist", None)
filelist = context["file_list"]
context["file_list"] = shuffle_files(need_shuffle_files,
filelist)
begin_time = time.time()
self._run(context, model_dict)
end_time = time.time()
......@@ -512,6 +528,11 @@ class PslibRunner(RunnerBase):
envs.get_global_env("runner." + context["runner_name"] +
".epochs"))
for epoch in range(epochs):
if "shuffle_filelist" in model_dict:
need_shuffle_files = model_dict.get("shuffle_filelist", None)
filelist = context["file_list"]
context["file_list"] = shuffle_files(need_shuffle_files,
filelist)
begin_time = time.time()
self._run(context, model_dict)
end_time = time.time()
......@@ -574,6 +595,12 @@ class SingleInferRunner(RunnerBase):
metrics = model_class._infer_results
self._load(context, model_dict,
self.epoch_model_path_list[index])
if "shuffle_filelist" in model_dict:
need_shuffle_files = model_dict.get("shuffle_filelist",
None)
filelist = context["file_list"]
context["file_list"] = shuffle_files(need_shuffle_files,
filelist)
begin_time = time.time()
result = self._run(context, model_dict)
end_time = time.time()
......
......@@ -59,7 +59,7 @@ def dataloader_by_name(readerclass,
if need_split_files:
files = split_files(files, context["fleet"].worker_index(),
context["fleet"].worker_num())
context["file_list"] = files
reader = reader_class(yaml_file)
reader.init()
......@@ -121,7 +121,7 @@ def slotdataloader_by_name(readerclass, dataset_name, yaml_file, context):
if need_split_files:
files = split_files(files, context["fleet"].worker_index(),
context["fleet"].worker_num())
context["file_list"] = files
sparse = get_global_env(name + "sparse_slots", "#")
if sparse == "":
sparse = "#"
......@@ -191,7 +191,7 @@ def slotdataloader(readerclass, train, yaml_file, context):
if need_split_files:
files = split_files(files, context["fleet"].worker_index(),
context["fleet"].worker_num())
context["file_list"] = files
sparse = get_global_env("sparse_slots", "#", namespace)
if sparse == "":
sparse = "#"
......
......@@ -16,6 +16,8 @@ import datetime
import os
import sys
import time
import warnings
import random
import numpy as np
from paddle import fluid
......@@ -223,6 +225,16 @@ def check_filelist(hidden_file_list, data_file_list, train_data_path):
return hidden_file_list, data_file_list
def shuffle_files(need_shuffle_files, filelist):
if not isinstance(need_shuffle_files, bool):
raise ValueError(
"In your config yaml, 'shuffle_filelist': %s must be written as a boolean type,such as True or False"
% need_shuffle_files)
elif need_shuffle_files:
random.shuffle(filelist)
return filelist
class CostPrinter(object):
"""
For count cost time && print cost log
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册