Commit 2ebef2b7 authored by T tangwei

add local cluster trainer

Parent 5234589f
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
train:
  threads: 12
  epochs: 10
  trainer: "LocalClusterTraining"

  pserver_num: 2
  trainer_num: 2
  start_port: 36001
  log_dirname: "logs"

  strategy:
    mode: "async"

  reader:
    mode: "dataset"
    batch_size: 2
    pipe_command: "python /paddle/eleps/models/ctr_dnn/dataset.py"
    train_data_path: "/paddle/eleps/models/ctr_dnn/data/train"

  model:
    models: "eleps.models.ctr_dnn.model"
    hyper_parameters:
      sparse_inputs_slots: 27
      sparse_feature_number: 1000001
      sparse_feature_dim: 8
      dense_input_dim: 13
      fc_sizes: [512, 256, 128, 32]
      learning_rate: 0.001

  save:
    increment:
      dirname: "models_for_increment"
      epoch_interval: 2
      save_last: True
    inference:
      dirname: "models_for_inference"
      epoch_interval: 4
      feed_varnames: ["C1", "C2", "C3"]
      fetch_varnames: "predict"
      save_last: True

evaluate:
  batch_size: 32
  train_thread_num: 12
  reader: "reader.py"
@@ -28,7 +28,7 @@ train:
   threads: 12
   epochs: 10
   trainer: "SingleTraining"
-  role_maler: "PaddleCloudRoleMaker"
   strategy:
     mode: "async"
...
@@ -33,7 +33,7 @@ if __name__ == "__main__":
     abs_dir = os.path.dirname(os.path.abspath(__file__))

-    with open(os.path.join(abs_dir, 'ctr-dnn_train.yaml'), 'r') as rb:
+    with open(os.path.join(abs_dir, 'ctr-dnn_train_single.yaml'), 'r') as rb:
         global_config = yaml.load(rb.read(), Loader=yaml.FullLoader)

         trainer = TrainerFactory.create(global_config)
...
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
@@ -18,16 +18,12 @@ Training use fluid with one node only.
 from __future__ import print_function

 import os
-import time
-import numpy as np
 import logging

 import paddle.fluid as fluid

 from .trainer import Trainer
 from ..utils import envs
-from ..reader import dataset
 from paddle.fluid.incubate.fleet.parameter_server.distribute_transpiler import fleet
 from paddle.fluid.incubate.fleet.parameter_server.distribute_transpiler.distributed_strategy import StrategyFactory
 from paddle.fluid.incubate.fleet.base.role_maker import PaddleCloudRoleMaker
@@ -36,36 +32,29 @@ logging.basicConfig(format="%(asctime)s - %(levelname)s - %(message)s")
 logger = logging.getLogger("fluid")
 logger.setLevel(logging.INFO)

+from .transpiler_trainer import TranspileTrainer

-def need_save(epoch_id, epoch_interval, is_last=False):
-    if is_last:
-        return True
-    return epoch_id % epoch_interval == 0

-class ClusterTrainer(Trainer):
-    def __init__(self, config=None, yaml_file=None):
-        Trainer.__init__(self, config, yaml_file)
-        self.exe = fluid.Executor(fluid.CPUPlace())
-
-        self.regist_context_processor('uninit', self.instance)
-        self.regist_context_processor('init_pass', self.init)
-        self.regist_context_processor('server_pass', self.server)
-        self.regist_context_processor('train_pass', self.train)
-        self.regist_context_processor('terminal_pass', self.terminal)
+class ClusterTrainerWithDataloader(TranspileTrainer):
+    pass

-    def build_role_maker(self):
-        role_maker = envs.get_global_env("train.role_maker")
-
-        if role_maker == "PaddleCloudRoleMaker":
-            role = PaddleCloudRoleMaker()
-            return role
-        else:
-            raise ValueError("only support PaddleCloudRoleMaker now")
+class ClusterTrainerWithDataset(TranspileTrainer):
+    def processor_register(self):
+        role = PaddleCloudRoleMaker()
+        fleet.init(role)
+
+        if role.is_server():
+            self.regist_context_processor('uninit', self.instance)
+            self.regist_context_processor('init_pass', self.init)
+            self.regist_context_processor('server_pass', self.server)
+        else:
+            self.regist_context_processor('uninit', self.instance)
+            self.regist_context_processor('init_pass', self.init)
+            self.regist_context_processor('train_pass', self.train)
+            self.regist_context_processor('terminal_pass', self.terminal)

     def build_strategy(self):
         mode = envs.get_global_env("train.strategy.mode")
         strategy = None
@@ -80,29 +69,22 @@ class ClusterTrainer(Trainer):
         elif mode == "half_async":
             strategy = StrategyFactory.create_half_async_strategy()

-        return strategy
-
-    def instance(self, context):
-        model_package = __import__(envs.get_global_env("train.model.models"))
-        train_model = getattr(model_package, 'Train')
-        self.model = train_model()
-        context['status'] = 'init_pass'
+        assert strategy is not None
+        return strategy

     def init(self, context):
-        fleet.init(self.build_role_maker())
         self.model.input()
         self.model.net()
-        self.model.loss()
         self.metrics = self.model.metrics()
-        self.loss = self.model.avg_loss()
-        optimizer = self.model.get_optimizer()
+        self.metric_extras = self.model.metric_extras()
+        loss = self.model.avg_loss()
+        optimizer = self.model.optimizer()
         strategy = self.build_strategy()
         optimizer = fleet.distributed_optimizer(optimizer, strategy)
-        optimizer.minimize(self.loss)
+        optimizer.minimize(loss)

         if fleet.is_server():
             context['status'] = 'server_pass'
@@ -112,102 +94,33 @@ class ClusterTrainer(Trainer):
     def server(self, context):
         fleet.init_server()
         fleet.run_server()
-        context['status'] = 'wait'
+        context['is_exit'] = True

     def terminal(self, context):
         fleet.stop_worker()
         context['is_exit'] = True

     def train(self, context):
-        print("Need to be implement")
-        context['is_exit'] = True
-
-class ClusterTrainerWithDataloader(ClusterTrainer):
-    pass
-
-class ClusterTrainerWithDataset(ClusterTrainer):
-    def _get_dataset(self, inputs, threads, batch_size, pipe_command, train_files_path):
-        dataset = fluid.DatasetFactory().create_dataset()
-        dataset.set_use_var(inputs)
-        dataset.set_pipe_command(pipe_command)
-        dataset.set_batch_size(batch_size)
-        dataset.set_thread(threads)
-        file_list = [
-            os.path.join(train_files_path, x)
-            for x in os.listdir(train_files_path)
-        ]
-        dataset.set_filelist(file_list)
-        return dataset
-
-    def save(self, epoch_id):
-        def save_inference_model():
-            is_save_inference = envs.get_global_env("save.inference", False)
-            if not is_save_inference:
-                return
-
-            save_interval = envs.get_global_env("save.inference.epoch_interval", 1)
-            if not need_save(epoch_id, save_interval, False):
-                return
-
-            feed_varnames = envs.get_global_env("save.inference.feed_varnames", None)
-            fetch_varnames = envs.get_global_env("save.inference.fetch_varnames", None)
-            fetch_vars = [fluid.global_scope().vars[varname] for varname in fetch_varnames]
-            dirname = envs.get_global_env("save.inference.dirname", None)
-
-            assert dirname is not None
-            dirname = os.path.join(dirname, str(epoch_id))
-            fluid.io.save_inference_model(dirname, feed_varnames, fetch_vars, self.exe)
-
-        def save_persistables():
-            is_save_increment = envs.get_global_env("save.increment", False)
-            if not is_save_increment:
-                return
-
-            save_interval = envs.get_global_env("save.increment.epoch_interval", 1)
-            if not need_save(epoch_id, save_interval, False):
-                return
-
-            dirname = envs.get_global_env("save.inference.dirname", None)
-
-            assert dirname is not None
-            dirname = os.path.join(dirname, str(epoch_id))
-            fluid.io.save_persistables(self.exe, dirname)
-
-        is_save = envs.get_global_env("save", False)
-        if not is_save:
-            return
-
-        save_persistables()
-        save_inference_model()
-
-    def train(self, context):
-        inputs = self.model.input_vars()
-
-        threads = envs.get_global_env("threads")
-        batch_size = envs.get_global_env("batch_size")
-        pipe_command = envs.get_global_env("pipe_command")
-        train_data_path = envs.get_global_env("train_data_path")
-
-        dataset = self._get_dataset(inputs, threads, batch_size, pipe_command, train_data_path)
-
-        fleet.init_worker()
         self.exe.run(fleet.startup_program)
+        fleet.init_worker()

-        epochs = envs.get_global_env("epochs")
+        dataset = self._get_dataset()
+        epochs = envs.get_global_env("train.epochs")

         for i in range(epochs):
             self.exe.train_from_dataset(program=fluid.default_main_program(),
                                         dataset=dataset,
-                                        fetch_list=[self.metrics],
-                                        fetch_info=["epoch {} auc ".format(i)],
-                                        print_period=100)
-            self.save(i)
+                                        fetch_list=self.metric_extras[0],
+                                        fetch_info=self.metric_extras[1],
+                                        print_period=self.metric_extras[2])
+            self.save(i, "train", is_fleet=True)

         context['status'] = 'infer_pass'
-        fleet.stop_worker()

     def infer(self, context):
         context['status'] = 'terminal_pass'

-    def terminal(self, context):
-        for model in self.increment_models:
-            print("epoch :{}, dir: {}".format(model[0], model[1]))
-        context['is_exit'] = True
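The regist_context_processor calls above drive a small status-keyed state machine. The Trainer base class is not part of this diff, so the following is only a sketch of the dispatch loop its usage implies, not the actual implementation:

class Trainer(object):
    """Assumed shape of the context-processor loop (sketch only)."""

    def __init__(self, config=None):
        self._processors = {}

    def regist_context_processor(self, status, processor):
        # map a status name to the handler that runs in that pass
        self._processors[status] = processor

    def run(self):
        # every trainer starts in 'uninit'; each processor does its work
        # and advances context['status'] (e.g. 'uninit' -> 'init_pass' ->
        # 'train_pass' or 'server_pass') until one of them sets is_exit
        context = {'status': 'uninit', 'is_exit': False}
        while not context['is_exit']:
            self._processors[context['status']](context)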
@@ -25,25 +25,41 @@
 # limitations under the License.

 import os
+import sys

 import yaml

-from eleps.trainer.single_train import SingleTrainerWithDataloader
-from eleps.trainer.single_train import SingleTrainerWithDataset
-from eleps.trainer.cluster_train import ClusterTrainerWithDataloader
-from eleps.trainer.cluster_train import ClusterTrainerWithDataset
+from eleps.trainer.single_trainer import SingleTrainerWithDataloader
+from eleps.trainer.single_trainer import SingleTrainerWithDataset
+from eleps.trainer.cluster_trainer import ClusterTrainerWithDataloader
+from eleps.trainer.cluster_trainer import ClusterTrainerWithDataset
+from eleps.trainer.local_engine import local_launch

 from eleps.trainer.ctr_trainer import CtrPaddleTrainer
 from eleps.utils import envs

+
+def str2bool(v):
+    if isinstance(v, bool):
+        return v
+    if v.lower() in ('yes', 'true', 't', 'y', '1'):
+        return True
+    elif v.lower() in ('no', 'false', 'f', 'n', '0'):
+        return False
+    else:
+        raise ValueError('Boolean value expected.')
+
+
 class TrainerFactory(object):
     def __init__(self):
         pass

     @staticmethod
     def _build_trainer(config):
+        print(envs.pretty_print_envs(envs.get_global_envs()))
         train_mode = envs.get_global_env("train.trainer")
         reader_mode = envs.get_global_env("train.reader.mode")
         if train_mode == "SingleTraining":
@@ -67,23 +83,41 @@ class TrainerFactory(object):
         return trainer

+    @staticmethod
+    def _build_engine(yaml_config):
+        cluster_envs = {}
+        cluster_envs["server_num"] = envs.get_global_env("train.pserver_num")
+        cluster_envs["worker_num"] = envs.get_global_env("train.trainer_num")
+        cluster_envs["start_port"] = envs.get_global_env("train.start_port")
+        cluster_envs["log_dir"] = envs.get_global_env("train.log_dirname")
+
+        print(envs.pretty_print_envs(cluster_envs, ("Cluster Global Envs", "Value")))
+
+        local_launch(cluster_envs, yaml_config)

     @staticmethod
     def create(config):
         _config = None
-        if os.path.exists(config) and os.path.isfile(config):
-            with open(config, 'r') as rb:
-                _config = yaml.load(rb.read())
+        if isinstance(config, dict):
+            _config = config
+        elif isinstance(config, str):
+            if os.path.exists(config) and os.path.isfile(config):
+                with open(config, 'r') as rb:
+                    _config = yaml.load(rb.read())
         else:
-            raise ValueError("unknown config about eleps")
+            raise ValueError("eleps's config only support yaml")

         envs.set_global_envs(_config)

-        print(envs.pretty_print_envs())
-        trainer = TrainerFactory._build_trainer(_config)
+        train_mode = envs.get_global_env("train.trainer")
+        instance = str2bool(os.getenv("CLUSTER_INSTANCE", "0"))
+
+        if train_mode == "LocalClusterTraining" and not instance:
+            trainer = TrainerFactory._build_engine(config)
+        else:
+            trainer = TrainerFactory._build_trainer(_config)

         return trainer

+
+# server num, worker num
+if __name__ == "__main__":
+    if len(sys.argv) != 2:
+        raise ValueError("need a yaml file path argv")
+
+    TrainerFactory.create(sys.argv[1])
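Taken together, create() now gives LocalClusterTraining a two-stage launch: the launching process (CLUSTER_INSTANCE unset) builds the engine and spawns the cluster, and each spawned process (CLUSTER_INSTANCE=1) falls through to _build_trainer(). A usage sketch; the YAML file name is illustrative:

from eleps.trainer.factory import TrainerFactory

# Stage 1: run from a plain shell, CLUSTER_INSTANCE is unset, so this
# call routes to _build_engine(), which launches 2 pservers and
# 2 trainers via local_launch() and waits for the trainers to exit.
#
# Stage 2: each launched process re-runs factory.py with
# CLUSTER_INSTANCE=1 and TRAINING_ROLE set, so the same call routes
# to _build_trainer() and returns a ClusterTrainerWithDataset.
TrainerFactory.create("ctr-dnn_train_cluster.yaml")  # illustrative path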
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
from __future__ import unicode_literals
import subprocess
import sys
import os
import copy
def start_procs(args, yaml):
    worker_num = args["worker_num"]
    server_num = args["server_num"]
    start_port = args["start_port"]
    logs_dir = args["log_dir"]

    default_env = os.environ.copy()
    current_env = copy.copy(default_env)
    current_env["CLUSTER_INSTANCE"] = "1"
    current_env.pop("http_proxy", None)
    current_env.pop("https_proxy", None)

    procs = []
    log_fns = []
    ports = range(start_port, start_port + server_num, 1)
    user_endpoints = ",".join(["127.0.0.1:" + str(x) for x in ports])
    user_endpoints_ips = [x.split(":")[0] for x in user_endpoints.split(",")]
    user_endpoints_port = [x.split(":")[1] for x in user_endpoints.split(",")]
    factory = os.path.join(os.path.abspath(os.path.dirname(__file__)), "factory.py")

    for i in range(server_num):
        current_env.update({
            "PADDLE_PSERVERS_IP_PORT_LIST": user_endpoints,
            "PADDLE_PORT": user_endpoints_port[i],
            "TRAINING_ROLE": "PSERVER",
            "PADDLE_TRAINERS_NUM": str(worker_num),
            "POD_IP": user_endpoints_ips[i]
        })

        cmd = [sys.executable, "-u", factory, yaml]

        if logs_dir is not None:
            os.system("mkdir -p {}".format(logs_dir))
            fn = open("%s/server.%d" % (logs_dir, i), "w")
            log_fns.append(fn)
            proc = subprocess.Popen(cmd, env=current_env, stdout=fn, stderr=fn)
        else:
            proc = subprocess.Popen(cmd, env=current_env)
        procs.append(proc)

    for i in range(worker_num):
        current_env.update({
            "PADDLE_PSERVERS_IP_PORT_LIST": user_endpoints,
            "PADDLE_TRAINERS_NUM": str(worker_num),
            "TRAINING_ROLE": "TRAINER",
            "PADDLE_TRAINER_ID": str(i)
        })

        cmd = [sys.executable, "-u", factory, yaml]

        if logs_dir is not None:
            os.system("mkdir -p {}".format(logs_dir))
            fn = open("%s/worker.%d" % (logs_dir, i), "w")
            log_fns.append(fn)
            proc = subprocess.Popen(cmd, env=current_env, stdout=fn, stderr=fn)
        else:
            proc = subprocess.Popen(cmd, env=current_env)
        procs.append(proc)

    # only wait for workers to finish here
    for i, proc in enumerate(procs):
        if i < server_num:
            continue
        procs[i].wait()
        if len(log_fns) > 0:
            log_fns[i].close()

    print("all workers exit, going to finish parameter server", file=sys.stderr)
    for i in range(server_num):
        if len(log_fns) > 0:
            log_fns[i].close()
        procs[i].terminate()
    print("all parameter servers are killed", file=sys.stderr)


def local_launch(envs, trainer):
    start_procs(envs, trainer)
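A usage sketch of local_launch, mirroring the cluster_envs dict that TrainerFactory._build_engine assembles (values taken from the YAML at the top; the YAML path is illustrative):

from eleps.trainer.local_engine import local_launch

cluster_envs = {
    "server_num": 2,      # train.pserver_num
    "worker_num": 2,      # train.trainer_num
    "start_port": 36001,  # train.start_port
    "log_dir": "logs",    # train.log_dirname
}

# spawns server_num pserver processes and worker_num trainer processes,
# each re-running factory.py on the given YAML with CLUSTER_INSTANCE=1
# and role-specific PADDLE_* environment variables set
local_launch(cluster_envs, "ctr-dnn_train_cluster.yaml")  # illustrative path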
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Training use fluid with one node only.
"""
from __future__ import print_function
import os
import time
import numpy as np
import logging
import paddle.fluid as fluid
from .transpiler_trainer import TranspileTrainer
from ..utils import envs
logging.basicConfig(format="%(asctime)s - %(levelname)s - %(message)s")
logger = logging.getLogger("fluid")
logger.setLevel(logging.INFO)
class SingleTrainerWithDataloader(TranspileTrainer):
    pass


class SingleTrainerWithDataset(TranspileTrainer):
    def processor_register(self):
        self.regist_context_processor('uninit', self.instance)
        self.regist_context_processor('init_pass', self.init)
        self.regist_context_processor('train_pass', self.train)
        self.regist_context_processor('infer_pass', self.infer)
        self.regist_context_processor('terminal_pass', self.terminal)

    def init(self, context):
        self.model.input()
        self.model.net()
        self.metrics = self.model.metrics()
        self.metric_extras = self.model.metric_extras()

        loss = self.model.avg_loss()
        optimizer = self.model.optimizer()
        optimizer.minimize(loss)

        context['status'] = 'train_pass'

    def train(self, context):
        # run startup program at once
        self.exe.run(fluid.default_startup_program())

        dataset = self._get_dataset()
        epochs = envs.get_global_env("train.epochs")

        for i in range(epochs):
            self.exe.train_from_dataset(program=fluid.default_main_program(),
                                        dataset=dataset,
                                        fetch_list=self.metric_extras[0],
                                        fetch_info=self.metric_extras[1],
                                        print_period=self.metric_extras[2])
            self.save(i, "train", is_fleet=False)
        context['status'] = 'infer_pass'

    def infer(self, context):
        context['status'] = 'terminal_pass'

    def terminal(self, context):
        for model in self.increment_models:
            print("epoch :{}, dir: {}".format(model[0], model[1]))
        context['is_exit'] = True
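Both trainer families program against a model object resolved from train.model.models via instance(). The real eleps.models.ctr_dnn.model.Train is not part of this diff; below is a hypothetical skeleton of the interface that the init() methods above assume:

import paddle.fluid as fluid


class Train(object):
    """Hypothetical model skeleton; method names match the trainer calls."""

    def input(self):
        # declare input variables; the dataset pipeline feeds these slots
        self.dense_input = fluid.layers.data(
            name="dense_input", shape=[13], dtype="float32")
        self.label = fluid.layers.data(name="label", shape=[1], dtype="int64")
        self.inputs = [self.dense_input, self.label]

    def net(self):
        # toy forward network standing in for the CTR-DNN tower
        fc = fluid.layers.fc(input=self.dense_input, size=2)
        self.predict = fluid.layers.softmax(fc)
        cost = fluid.layers.cross_entropy(input=self.predict, label=self.label)
        self.avg_cost = fluid.layers.mean(cost)

    def metrics(self):
        return [self.predict]

    def metric_extras(self):
        # (fetch_list, fetch_info, print_period) for train_from_dataset
        return ([self.avg_cost], ["avg_cost"], 100)

    def avg_loss(self):
        return self.avg_cost

    def optimizer(self):
        learning_rate = 0.001  # train.model.hyper_parameters.learning_rate
        return fluid.optimizer.SGD(learning_rate=learning_rate)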
@@ -13,82 +13,30 @@
 # limitations under the License.

 """
-Training use fluid with one node only.
+Training use fluid with DistributeTranspiler
 """
-from __future__ import print_function

 import os
-import time
-import numpy as np
-import logging

 import paddle.fluid as fluid
+from paddle.fluid.incubate.fleet.parameter_server.distribute_transpiler import fleet

 from .trainer import Trainer
 from ..utils import envs

-logging.basicConfig(format="%(asctime)s - %(levelname)s - %(message)s")
-logger = logging.getLogger("fluid")
-logger.setLevel(logging.INFO)

-class SingleTrainer(Trainer):
+class TranspileTrainer(Trainer):
     def __init__(self, config=None):
         Trainer.__init__(self, config)
+        self.exe = fluid.Executor(fluid.CPUPlace())
+        self.processor_register()
         self.inference_models = []
         self.increment_models = []
-        self.exe = fluid.Executor(fluid.CPUPlace())

-        self.regist_context_processor('uninit', self.instance)
-        self.regist_context_processor('init_pass', self.init)
-        self.regist_context_processor('train_pass', self.train)
-        self.regist_context_processor('infer_pass', self.infer)
-        self.regist_context_processor('terminal_pass', self.terminal)
+    def processor_register(self):
+        print("Need implement by trainer, `self.regist_context_processor('uninit', self.instance)` must be the first")

-    def instance(self, context):
-        models = envs.get_global_env("train.model.models")
-        model_package = __import__(models, globals(), locals(), models.split("."))
-        train_model = getattr(model_package, 'Train')
-        self.model = train_model()
-        context['status'] = 'init_pass'
-
-    def init(self, context):
-        self.model.input()
-        self.model.net()
-        self.metrics = self.model.metrics()
-        self.metric_extras = self.model.metric_extras()
-
-        loss = self.model.avg_loss()
-        optimizer = self.model.optimizer()
-        optimizer.minimize(loss)
-
-        # run startup program at once
-        self.exe.run(fluid.default_startup_program())
-
-        context['status'] = 'train_pass'
-
-    def train(self, context):
-        print("Need to be implement")
-        context['is_exit'] = True
-
-    def infer(self, context):
-        context['is_exit'] = True
-
-    def terminal(self, context):
-        print("clean up and exit")
-        context['is_exit'] = True
-
-class SingleTrainerWithDataloader(SingleTrainer):
-    pass
-
-class SingleTrainerWithDataset(SingleTrainer):
     def _get_dataset(self):
         namespace = "train.reader"
@@ -98,7 +46,6 @@
         pipe_command = envs.get_global_env("pipe_command", None, namespace)
         train_data_path = envs.get_global_env("train_data_path", None, namespace)

         dataset = fluid.DatasetFactory().create_dataset()
         dataset.set_use_var(inputs)
         dataset.set_pipe_command(pipe_command)
@@ -112,7 +59,7 @@
         dataset.set_filelist(file_list)
         return dataset

-    def save(self, epoch_id, namespace):
+    def save(self, epoch_id, namespace, is_fleet=False):
         def need_save(epoch_id, epoch_interval, is_last=False):
             if is_last:
                 return True
@@ -138,10 +85,13 @@
             assert dirname is not None
             dirname = os.path.join(dirname, str(epoch_id))

-            fluid.io.save_inference_model(dirname, feed_varnames, fetch_vars, self.exe)
+            if is_fleet:
+                fleet.save_inference_model(dirname, feed_varnames, fetch_vars, self.exe)
+            else:
+                fluid.io.save_inference_model(dirname, feed_varnames, fetch_vars, self.exe)
             self.inference_models.append((epoch_id, dirname))

         def save_persistables():
             save_interval = envs.get_global_env("save.increment.epoch_interval", -1, namespace)
@@ -152,31 +102,34 @@
             assert dirname is not None
             dirname = os.path.join(dirname, str(epoch_id))

-            fluid.io.save_persistables(self.exe, dirname)
+            if is_fleet:
+                fleet.save_persistables(self.exe, dirname)
+            else:
+                fluid.io.save_persistables(self.exe, dirname)
             self.increment_models.append((epoch_id, dirname))

         save_persistables()
         save_inference_model()

-    def train(self, context):
-        dataset = self._get_dataset()
-
-        epochs = envs.get_global_env("train.epochs")
-
-        for i in range(epochs):
-            self.exe.train_from_dataset(program=fluid.default_main_program(),
-                                        dataset=dataset,
-                                        fetch_list=self.metric_extras[0],
-                                        fetch_info=self.metric_extras[1],
-                                        print_period=self.metric_extras[2])
-            self.save(i, "train")
-
-        context['status'] = 'infer_pass'
+    def instance(self, context):
+        models = envs.get_global_env("train.model.models")
+        model_package = __import__(models, globals(), locals(), models.split("."))
+        train_model = getattr(model_package, 'Train')
+        self.model = train_model()
+        context['status'] = 'init_pass'

+    def init(self, context):
+        print("Need to be implement")
+        context['is_exit'] = True
+
+    def train(self, context):
+        print("Need to be implement")
+        context['is_exit'] = True

     def infer(self, context):
-        context['status'] = 'terminal_pass'
+        context['is_exit'] = True

     def terminal(self, context):
-        for model in self.increment_models:
-            print("epoch :{}, dir: {}".format(model[0], model[1]))
+        print("clean up and exit")
         context['is_exit'] = True
@@ -44,12 +44,16 @@ def get_global_env(env_name, default_value=None, namespace=None):
     return global_envs.get(_env_name, default_value)

-def pretty_print_envs():
+def get_global_envs():
+    return global_envs
+
+
+def pretty_print_envs(envs, header=None):
     spacing = 5
     max_k = 45
     max_v = 20

-    for k, v in global_envs.items():
+    for k, v in envs.items():
         max_k = max(max_k, len(k))
         max_v = max(max_v, len(str(v)))
@@ -62,14 +66,18 @@ def pretty_print_envs():
     draws = ""
     draws += border + "\n"
-    draws += h_format.format("Eleps Global Envs", "Value")
+
+    if header:
+        draws += h_format.format(header[0], header[1])
+    else:
+        draws += h_format.format("Eleps Global Envs", "Value")
+
     draws += line + "\n"

-    for k, v in global_envs.items():
+    for k, v in envs.items():
         draws += l_format.format(k, " " * spacing, str(v))

     draws += border
     _str = "\n{}\n".format(draws)
     return _str
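A small usage sketch of the reworked helpers; the header tuple supplies the two column titles of the boxed table, and header falls back to the Eleps defaults when omitted:

from eleps.utils import envs

# dump everything registered via set_global_envs with default titles
print(envs.pretty_print_envs(envs.get_global_envs()))

# dump an arbitrary dict with custom titles, as _build_engine does
print(envs.pretty_print_envs({"server_num": 2, "worker_num": 2},
                             ("Cluster Global Envs", "Value")))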