Commit 2f1f76f3 authored by chengmo

simple code

Parent a82671c4
@@ -18,6 +18,7 @@ Training use fluid with one node only.
 from __future__ import print_function

+import os
 import paddle.fluid as fluid
 from paddle.fluid.incubate.fleet.parameter_server.distribute_transpiler import fleet
 from paddle.fluid.incubate.fleet.parameter_server.distribute_transpiler.distributed_strategy import StrategyFactory
@@ -39,7 +40,7 @@ class ClusterTrainer(TranspileTrainer):
         else:
             self.regist_context_processor('uninit', self.instance)
             self.regist_context_processor('init_pass', self.init)
+            self.regist_context_processor('startup_pass', self.startup)
             if envs.get_platform() == "LINUX":
                 self.regist_context_processor('train_pass', self.dataset_train)
             else:
@@ -71,9 +72,11 @@ class ClusterTrainer(TranspileTrainer):
     def init(self, context):
         self.model.train_net()
         optimizer = self.model.optimizer()
-        optimizer_name = envs.get_global_env("hyper_parameters.optimizer")
-        if optimizer_name in ['adam', 'ADAM', 'Adagrad', 'ADAGRAD']:
-            os.environ["FLAGS_communicator_is_sgd_optimizer"] = 0
+        optimizer_name = envs.get_global_env(
+            "hyper_parameters.optimizer", None, "train.model")
+        if optimizer_name not in ["", "sgd", "SGD", "Sgd"]:
+            os.environ["FLAGS_communicator_is_sgd_optimizer"] = '0'
         strategy = self.build_strategy()
         optimizer = fleet.distributed_optimizer(optimizer, strategy)
         optimizer.minimize(self.model.get_cost_op())
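Note the quoting in the new right-hand side: the old code assigned the integer 0, which `os.environ` rejects because environment variable values must be strings. A minimal reproduction of the bug this hunk fixes:

import os

try:
    os.environ["FLAGS_communicator_is_sgd_optimizer"] = 0  # old code: int value
except TypeError as e:
    print(e)  # e.g. "str expected, not int" on Python 3
os.environ["FLAGS_communicator_is_sgd_optimizer"] = '0'    # fixed: string value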
@@ -89,16 +92,18 @@ class ClusterTrainer(TranspileTrainer):
             if metrics:
                 self.fetch_vars = metrics.values()
                 self.fetch_alias = metrics.keys()
-            context['status'] = 'train_pass'
+            context['status'] = 'startup_pass'

     def server(self, context):
         fleet.init_server()
         fleet.run_server()
         context['is_exit'] = True

+    def startup(self, context):
+        self._exe.run(fleet.startup_program)
+        context['status'] = 'train_pass'
+
     def dataloader_train(self, context):
         fleet.init_worker()
         reader = self._get_dataloader()
@@ -144,7 +149,6 @@ class ClusterTrainer(TranspileTrainer):
         context['status'] = 'terminal_pass'

     def dataset_train(self, context):
-        self._exe.run(fleet.startup_program)
         fleet.init_worker()
         dataset = self._get_dataset()
...
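The pattern behind these hunks: each trainer is a small state machine. `regist_context_processor` binds a status name to a handler, each handler sets `context['status']` to hand off to the next pass, and this commit splits executor startup out of the train passes into a dedicated 'startup_pass'. A minimal sketch of the dispatch loop this implies (the real loop presumably lives in the Trainer base class, which is not part of this diff):

class MiniTrainer(object):
    def __init__(self):
        self.processors = {}

    def regist_context_processor(self, status, processor):
        self.processors[status] = processor

    def run(self):
        context = {'status': 'uninit', 'is_exit': False}
        while not context['is_exit']:
            # Dispatch on the current status; handlers advance the status.
            self.processors[context['status']](context)

# Hypothetical handlers chaining uninit -> init_pass -> startup_pass -> train_pass:
trainer = MiniTrainer()
trainer.regist_context_processor('uninit', lambda c: c.update(status='init_pass'))
trainer.regist_context_processor('init_pass', lambda c: c.update(status='startup_pass'))
trainer.regist_context_processor('startup_pass', lambda c: c.update(status='train_pass'))
trainer.regist_context_processor('train_pass', lambda c: c.update(is_exit=True))
trainer.run()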
@@ -33,7 +33,7 @@ class SingleTrainer(TranspileTrainer):
     def processor_register(self):
         self.regist_context_processor('uninit', self.instance)
         self.regist_context_processor('init_pass', self.init)
+        self.regist_context_processor('startup_pass', self.startup)
         if envs.get_platform() == "LINUX":
             self.regist_context_processor('train_pass', self.dataset_train)
         else:
@@ -55,10 +55,14 @@ class SingleTrainer(TranspileTrainer):
         if metrics:
             self.fetch_vars = metrics.values()
             self.fetch_alias = metrics.keys()
+        context['status'] = 'startup_pass'
+
+    def startup(self, context):
+        self._exe.run(fluid.default_startup_program())
         context['status'] = 'train_pass'

     def dataloader_train(self, context):
-        self._exe.run(fluid.default_startup_program())
         self.model.custom_preprocess()
         reader = self._get_dataloader()
...
@@ -18,7 +18,7 @@ Training use fluid with one node only.
 """

 from __future__ import print_function
+import logging
 import paddle.fluid as fluid
 from paddle.fluid.incubate.fleet.parameter_server.distribute_transpiler import fleet
 from paddle.fluid.incubate.fleet.parameter_server.distribute_transpiler.distributed_strategy import StrategyFactory
@@ -26,85 +26,28 @@ from paddle.fluid.incubate.fleet.base.role_maker import PaddleCloudRoleMaker
 from fleetrec.core.utils import envs
 from fleetrec.core.trainers.transpiler_trainer import TranspileTrainer
-special_param = ["TDM_Tree_Travel", "TDM_Tree_Layer", "TDM_Tree_Info"]
+logging.basicConfig(format="%(asctime)s - %(levelname)s - %(message)s")
+logger = logging.getLogger("fluid")
+logger.setLevel(logging.INFO)
+special_param = ["TDM_Tree_Travel", "TDM_Tree_Layer", "TDM_Tree_Info"]

-class TDMClusterTrainer(TranspileTrainer):
-    def processor_register(self):
-        role = PaddleCloudRoleMaker()
-        fleet.init(role)
-        if fleet.is_server():
-            self.regist_context_processor('uninit', self.instance)
-            self.regist_context_processor('init_pass', self.init)
-            self.regist_context_processor('server_pass', self.server)
-        else:
-            self.regist_context_processor('uninit', self.instance)
-            self.regist_context_processor('init_pass', self.init)
-            self.regist_context_processor(
-                'trainer_startup_pass', self.trainer_startup)
-            if envs.get_platform() == "LINUX":
-                self.regist_context_processor('train_pass', self.dataset_train)
-            else:
-                self.regist_context_processor(
-                    'train_pass', self.dataloader_train)
-            self.regist_context_processor('terminal_pass', self.terminal)
-
-    def build_strategy(self):
-        mode = envs.get_runtime_environ("train.trainer.strategy")
-        assert mode in ["async", "geo", "sync", "half_async"]
-        strategy = None
-        if mode == "async":
-            strategy = StrategyFactory.create_async_strategy()
-        elif mode == "geo":
-            push_num = envs.get_global_env("train.strategy.mode.push_num", 100)
-            strategy = StrategyFactory.create_geo_strategy(push_num)
-        elif mode == "sync":
-            strategy = StrategyFactory.create_sync_strategy()
-        elif mode == "half_async":
-            strategy = StrategyFactory.create_half_async_strategy()
-        assert strategy is not None
-        self.strategy = strategy
-        return strategy
-
-    def init(self, context):
-        self.model.train_net()
-        optimizer = self.model.optimizer()
-        optimizer_name = envs.get_global_env("hyper_parameters.optimizer")
-        if optimizer_name in ['adam', 'ADAM', 'Adagrad', 'ADAGRAD']:
-            os.environ["FLAGS_communicator_is_sgd_optimizer"] = 0
-        strategy = self.build_strategy()
-        optimizer = fleet.distributed_optimizer(optimizer, strategy)
-        optimizer.minimize(self.model.get_cost_op())
-
-        if fleet.is_server():
-            context['status'] = 'server_pass'
-        else:
-            self.fetch_vars = []
-            self.fetch_alias = []
-            self.fetch_period = self.model.get_fetch_period()
-            metrics = self.model.get_metrics()
-            if metrics:
-                self.fetch_vars = metrics.values()
-                self.fetch_alias = metrics.keys()
-            context['status'] = 'trainer_startup_pass'

+class TDMClusterTrainer(TranspileTrainer):
     def server(self, context):
         namespace = "train.startup"
-        model_path = envs.get_global_env(
-            "cluster.model_path", "", namespace)
-        assert not model_path, "Cluster train must has init_model for TDM"
-        fleet.init_server(model_path)
+        init_model_path = envs.get_global_env(
+            "cluster.init_model_path", "", namespace)
+        assert init_model_path != "", "Cluster train must has init_model for TDM"
+        fleet.init_server(init_model_path)
+        logger.info("TDM: load model from {}".format(init_model_path))
         fleet.run_server()
         context['is_exit'] = True

-    def trainer_startup(self, context):
+    def startup(self, context):
+        self._exe.run(fleet.startup_program)
         namespace = "train.startup"
         load_tree = envs.get_global_env(
             "cluster.load_tree", True, namespace)
@@ -119,7 +62,6 @@ class TDMClusterTrainer(TranspileTrainer):
             "cluster.save_init_model", False, namespace)
         init_model_path = envs.get_global_env(
             "cluster.init_model_path", "", namespace)
-        self._exe.run(fluid.default_startup_program())

         if load_tree:
             # Set the plaintext tree structure and data into the network's Variables
@@ -137,74 +79,47 @@ class TDMClusterTrainer(TranspileTrainer):
         context['status'] = 'train_pass'

-    def dataloader_train(self, context):
-        self._exe.run(fleet.startup_program)
-
-        fleet.init_worker()
-
-        reader = self._get_dataloader()
-        epochs = envs.get_global_env("train.epochs")
-
-        program = fluid.compiler.CompiledProgram(
-            fleet.main_program).with_data_parallel(
-            loss_name=self.model.get_cost_op().name,
-            build_strategy=self.strategy.get_build_strategy(),
-            exec_strategy=self.strategy.get_execute_strategy())
-
-        metrics_varnames = []
-        metrics_format = []
-
-        metrics_format.append("{}: {{}}".format("epoch"))
-        metrics_format.append("{}: {{}}".format("batch"))
-
-        for name, var in self.model.get_metrics().items():
-            metrics_varnames.append(var.name)
-            metrics_format.append("{}: {{}}".format(name))
-
-        metrics_format = ", ".join(metrics_format)
-
-        for epoch in range(epochs):
-            reader.start()
-            batch_id = 0
-            try:
-                while True:
-                    metrics_rets = self._exe.run(
-                        program=program,
-                        fetch_list=metrics_varnames)
-
-                    metrics = [epoch, batch_id]
-                    metrics.extend(metrics_rets)
-
-                    if batch_id % 10 == 0 and batch_id != 0:
-                        print(metrics_format.format(*metrics))
-                    batch_id += 1
-            except fluid.core.EOFException:
-                reader.reset()
-
-        fleet.stop_worker()
-        context['status'] = 'terminal_pass'
-
-    def dataset_train(self, context):
-        self._exe.run(fleet.startup_program)
-        fleet.init_worker()
-
-        dataset = self._get_dataset()
-        epochs = envs.get_global_env("train.epochs")
-
-        for i in range(epochs):
-            self._exe.train_from_dataset(program=fluid.default_main_program(),
-                                         dataset=dataset,
-                                         fetch_list=self.fetch_vars,
-                                         fetch_info=self.fetch_alias,
-                                         print_period=self.fetch_period)
-            self.save(i, "train", is_fleet=True)
-        fleet.stop_worker()
-        context['status'] = 'terminal_pass'
-
-    def infer(self, context):
-        context['status'] = 'terminal_pass'
-
-    def terminal(self, context):
-        for model in self.increment_models:
-            print("epoch :{}, dir: {}".format(model[0], model[1]))
-        context['is_exit'] = True
+    def tdm_prepare(self, param_name):
+        if param_name == "TDM_Tree_Travel":
+            travel_array = self.tdm_travel_prepare()
+            return travel_array
+        elif param_name == "TDM_Tree_Layer":
+            layer_array, _ = self.tdm_layer_prepare()
+            return layer_array
+        elif param_name == "TDM_Tree_Info":
+            info_array = self.tdm_info_prepare()
+            return info_array
+        else:
+            raise " {} is not a special tdm param name".format(param_name)
+
+    def tdm_travel_prepare(self):
+        """load tdm tree param from npy/list file"""
+        travel_array = np.load(self.tree_travel_path)
+        logger.info("TDM Tree leaf node nums: {}".format(
+            travel_array.shape[0]))
+        return travel_array
+
+    def tdm_layer_prepare(self):
+        """load tdm tree param from npy/list file"""
+        layer_list = []
+        layer_list_flat = []
+        with open(self.tree_layer_path, 'r') as fin:
+            for line in fin.readlines():
+                l = []
+                layer = (line.split('\n'))[0].split(',')
+                for node in layer:
+                    if node:
+                        layer_list_flat.append(node)
+                        l.append(node)
+                layer_list.append(l)
+        layer_array = np.array(layer_list_flat)
+        layer_array = layer_array.reshape([-1, 1])
+        logger.info("TDM Tree max layer: {}".format(len(layer_list)))
+        logger.info("TDM Tree layer_node_num_list: {}".format(
+            [len(i) for i in layer_list]))
+        return layer_array, layer_list
+
+    def tdm_info_prepare(self):
+        """load tdm tree param from list file"""
+        info_array = np.load(self.tree_info_path)
+        return info_array
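Two details in the new helpers are worth noting. First, `raise "..."` is itself broken: in Python 3 exceptions must derive from BaseException, so raising a string throws a TypeError; `raise ValueError(...)` would be the idiomatic form. Second, `tdm_layer_prepare` expects one comma-separated line of node ids per tree layer. A minimal sketch of that parsing, with hypothetical `layer_list.txt` contents:

import numpy as np

# Hypothetical layer_list.txt: one comma-separated line of node ids per layer.
lines = ["1,2", "3,4,5,6"]

layer_list = [[node for node in line.strip().split(',') if node] for line in lines]
layer_array = np.array([n for layer in layer_list for n in layer]).reshape([-1, 1])

print(len(layer_list))               # max layer: 2
print([len(l) for l in layer_list])  # nodes per layer: [2, 4]
print(layer_array.shape)             # (6, 1), the same flattened reshape as above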
@@ -34,34 +34,6 @@ special_param = ["TDM_Tree_Travel", "TDM_Tree_Layer",

 class TDMSingleTrainer(SingleTrainer):
-    def processor_register(self):
-        self.regist_context_processor('uninit', self.instance)
-        self.regist_context_processor('init_pass', self.init)
-        self.regist_context_processor('startup_pass', self.startup)
-
-        if envs.get_platform() == "LINUX":
-            self.regist_context_processor('train_pass', self.dataset_train)
-        else:
-            self.regist_context_processor('train_pass', self.dataloader_train)
-
-        self.regist_context_processor('infer_pass', self.infer)
-        self.regist_context_processor('terminal_pass', self.terminal)
-
-    def init(self, context):
-        self.model.train_net()
-        optimizer = self.model.optimizer()
-        optimizer.minimize((self.model.get_cost_op()))
-
-        self.fetch_vars = []
-        self.fetch_alias = []
-        self.fetch_period = self.model.get_fetch_period()
-
-        metrics = self.model.get_metrics()
-        if metrics:
-            self.fetch_vars = metrics.values()
-            self.fetch_alias = metrics.keys()
-        context['status'] = 'startup_pass'
-
     def startup(self, context):
         namespace = "train.startup"
         load_persistables = envs.get_global_env(
@@ -114,67 +86,6 @@ class TDMSingleTrainer(SingleTrainer):
         context['status'] = 'train_pass'

-    def dataloader_train(self, context):
-        reader = self._get_dataloader()
-        epochs = envs.get_global_env("train.epochs")
-
-        program = fluid.compiler.CompiledProgram(
-            fluid.default_main_program()).with_data_parallel(
-            loss_name=self.model.get_cost_op().name)
-
-        metrics_varnames = []
-        metrics_format = []
-
-        metrics_format.append("{}: {{}}".format("epoch"))
-        metrics_format.append("{}: {{}}".format("batch"))
-
-        for name, var in self.model.get_metrics().items():
-            metrics_varnames.append(var.name)
-            metrics_format.append("{}: {{}}".format(name))
-
-        metrics_format = ", ".join(metrics_format)
-
-        for epoch in range(epochs):
-            reader.start()
-            batch_id = 0
-            try:
-                while True:
-                    metrics_rets = self._exe.run(
-                        program=program,
-                        fetch_list=metrics_varnames)
-
-                    metrics = [epoch, batch_id]
-                    metrics.extend(metrics_rets)
-
-                    if batch_id % 10 == 0 and batch_id != 0:
-                        print(metrics_format.format(*metrics))
-                    batch_id += 1
-            except fluid.core.EOFException:
-                reader.reset()
-        context['status'] = 'infer_pass'
-
-    def dataset_train(self, context):
-        dataset = self._get_dataset()
-        epochs = envs.get_global_env("train.epochs")
-
-        for i in range(epochs):
-            self._exe.train_from_dataset(program=fluid.default_main_program(),
-                                         dataset=dataset,
-                                         fetch_list=self.fetch_vars,
-                                         fetch_info=self.fetch_alias,
-                                         print_period=self.fetch_period)
-            self.save(i, "train", is_fleet=False)
-        context['status'] = 'infer_pass'
-
-    def infer(self, context):
-        context['status'] = 'terminal_pass'
-
-    def terminal(self, context):
-        for model in self.increment_models:
-            print("epoch :{}, dir: {}".format(model[0], model[1]))
-        context['is_exit'] = True
-
     def tdm_prepare(self, param_name):
         if param_name == "TDM_Tree_Travel":
             travel_array = self.tdm_travel_prepare()
...
@@ -10,6 +10,7 @@ from fleetrec.core.utils import util
 engines = {}
 device = ["CPU", "GPU"]
 clusters = ["SINGLE", "LOCAL_CLUSTER", "CLUSTER"]
+custom_model = ['tdm']


 def engine_registry():
@@ -31,9 +32,12 @@ def engine_registry():
     engines["GPU"] = gpu


-def get_engine(engine, device):
+def get_engine(args):
+    device = args.device
     d_engine = engines[device]
     transpiler = get_transpiler()
+    engine = get_custom_model_engine(args)
     run_engine = d_engine[transpiler].get(engine, None)

     if run_engine is None:
@@ -42,6 +46,16 @@ def get_engine(engine, device):
     return run_engine


+def get_custom_model_engine(args):
+    model = args.model
+    model_name = model.split('.')[1]
+    if model_name in custom_model:
+        engine = "_".join((model_name.upper(), args.engine))
+    else:
+        engine = args.engine
+    return engine
+
+
 def get_transpiler():
     FNULL = open(os.devnull, 'w')
     cmd = ["python", "-c",
@@ -81,30 +95,23 @@ def set_runtime_envs(cluster_envs, engine_yaml):
     print(envs.pretty_print_envs(need_print, ("Runtime Envs", "Value")))


-def single_engine(args):
-    print("use single engine to run model: {}".format(args.model))
-
-    single_envs = {}
-    single_envs["train.trainer.trainer"] = "SingleTrainer"
-    single_envs["train.trainer.threads"] = "2"
-    single_envs["train.trainer.engine"] = "single"
-    single_envs["train.trainer.device"] = args.device
-    single_envs["train.trainer.platform"] = envs.get_platform()
-
-    set_runtime_envs(single_envs, args.model)
-    trainer = TrainerFactory.create(args.model)
-    return trainer
+def get_trainer_prefix(args):
+    model = args.model
+    model_name = model.split('.')[1]
+    if model_name in custom_model:
+        return model_name.upper()
+    return ""


-def tdm_single_engine(args):
-    print("use tdm single engine to run model: {}".format(args.model))
-
+def single_engine(args):
+    trainer = get_trainer_prefix(args) + "SingleTrainer"
     single_envs = {}
-    single_envs["train.trainer.trainer"] = "TDMSingleTrainer"
+    single_envs["train.trainer.trainer"] = trainer
     single_envs["train.trainer.threads"] = "2"
     single_envs["train.trainer.engine"] = "single"
     single_envs["train.trainer.device"] = args.device
     single_envs["train.trainer.platform"] = envs.get_platform()
+    print("use {} engine to run model: {}".format(trainer, args.model))

     set_runtime_envs(single_envs, args.model)
     trainer = TrainerFactory.create(args.model)
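This refactor collapses the duplicated `tdm_*_engine` functions into one path: the trainer class name is composed from a per-model prefix instead of being hard-coded. A minimal sketch of the naming convention, using hypothetical model identifiers where the second dotted segment selects the prefix (the real `args` comes from argparse and `args.model` passes through `get_abs_model` first, which this diff does not show):

from argparse import Namespace

custom_model = ['tdm']

def get_trainer_prefix(args):
    model_name = args.model.split('.')[1]
    return model_name.upper() if model_name in custom_model else ""

print(get_trainer_prefix(Namespace(model="models.tdm.config")) + "SingleTrainer")      # TDMSingleTrainer
print(get_trainer_prefix(Namespace(model="models.ctr_dnn.config")) + "SingleTrainer")  # SingleTrainer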
@@ -112,31 +119,15 @@ def tdm_single_engine(args):


 def cluster_engine(args):
-    print("launch cluster engine with cluster to run model: {}".format(args.model))
-
-    cluster_envs = {}
-    cluster_envs["train.trainer.trainer"] = "ClusterTrainer"
-    cluster_envs["train.trainer.engine"] = "cluster"
-    cluster_envs["train.trainer.device"] = args.device
-    cluster_envs["train.trainer.platform"] = envs.get_platform()
-
-    set_runtime_envs(cluster_envs, args.model)
-
-    trainer = TrainerFactory.create(args.model)
-    return trainer
-
-
-def tdm_cluster_engine(args):
-    print("launch tdm cluster engine with cluster to run model: {}".format(args.model))
+    trainer = get_trainer_prefix(args) + "ClusterTrainer"
     cluster_envs = {}
-    cluster_envs["train.trainer.trainer"] = "TDMClusterTrainer"
+    cluster_envs["train.trainer.trainer"] = trainer
     cluster_envs["train.trainer.engine"] = "cluster"
     cluster_envs["train.trainer.device"] = args.device
     cluster_envs["train.trainer.platform"] = envs.get_platform()
+    print("launch {} engine with cluster to run model: {}".format(trainer, args.model))

     set_runtime_envs(cluster_envs, args.model)
     trainer = TrainerFactory.create(args.model)
     return trainer
@@ -156,40 +147,15 @@ def cluster_mpi_engine(args):


 def local_cluster_engine(args):
-    print("launch cluster engine with cluster to run model: {}".format(args.model))
-
-    from fleetrec.core.engine.local_cluster_engine import LocalClusterEngine
-
-    cluster_envs = {}
-    cluster_envs["server_num"] = 1
-    cluster_envs["worker_num"] = 1
-    cluster_envs["start_port"] = 36001
-    cluster_envs["log_dir"] = "logs"
-    cluster_envs["train.trainer.trainer"] = "ClusterTrainer"
-    cluster_envs["train.trainer.strategy"] = "async"
-    cluster_envs["train.trainer.threads"] = "2"
-    cluster_envs["train.trainer.engine"] = "local_cluster"
-    cluster_envs["train.trainer.device"] = args.device
-    cluster_envs["train.trainer.platform"] = envs.get_platform()
-
-    cluster_envs["CPU_NUM"] = "2"
-
-    set_runtime_envs(cluster_envs, args.model)
-
-    launch = LocalClusterEngine(cluster_envs, args.model)
-    return launch
-
-
-def tdm_local_cluster_engine(args):
-    print("launch tdm cluster engine with cluster to run model: {}".format(args.model))
-
     from fleetrec.core.engine.local_cluster_engine import LocalClusterEngine
+    trainer = get_trainer_prefix(args) + "ClusterTrainer"
     cluster_envs = {}
     cluster_envs["server_num"] = 1
     cluster_envs["worker_num"] = 1
     cluster_envs["start_port"] = 36001
     cluster_envs["log_dir"] = "logs"
-    cluster_envs["train.trainer.trainer"] = "TDMClusterTrainer"
+    cluster_envs["train.trainer.trainer"] = trainer
     cluster_envs["train.trainer.strategy"] = "async"
     cluster_envs["train.trainer.threads"] = "2"
     cluster_envs["train.trainer.engine"] = "local_cluster"
@@ -198,9 +164,9 @@ def tdm_local_cluster_engine(args):
     cluster_envs["train.trainer.platform"] = envs.get_platform()
     cluster_envs["CPU_NUM"] = "2"
+    print("launch {} engine with cluster to run model: {}".format(trainer, args.model))

     set_runtime_envs(cluster_envs, args.model)
     launch = LocalClusterEngine(cluster_envs, args.model)
     return launch
@@ -258,7 +224,7 @@ if __name__ == "__main__":
     args.model = get_abs_model(args.model)
     engine_registry()

-    which_engine = get_engine(args.engine, args.device)
+    which_engine = get_engine(args)

     engine = which_engine(args)
     engine.run()
@@ -46,22 +46,22 @@ train:
     startup:
-        single:
-            # It is recommended to load the tree only once, save it as a paddle tensor, and warm-start from the paddle model afterwards
-            load_persistables: False
-            persistables_model_path: ""
+        tree:
+            # For single-machine training, it is recommended to load the tree only once, save it as a paddle tensor, and warm-start from the paddle model afterwards
+            # For distributed training, each trainer needs to load it independently
             load_tree: True
             tree_layer_path: "{workspace}/tree/layer_list.txt"
             tree_travel_path: "{workspace}/tree/travel_list.npy"
             tree_info_path: "{workspace}/tree/tree_info.npy"
             tree_emb_path: "{workspace}/tree/tree_emb.npy"
+        single:
+            load_persistables: False
+            persistables_model_path: ""
             save_init_model: True
-            init_model_path: ""
+            init_model_path: "{workspace}/init_model"
         cluster:
-            load_persistables: True
-            persistables_model_path: ""
+            init_model_path: "{workspace}/init_model"

     save:
         increment:
...
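The keys above line up with the lookups in the trainers: `envs.get_global_env("cluster.init_model_path", "", "train.startup")` resolves `train: startup: cluster: init_model_path`. A minimal sketch of that namespace lookup, assuming `fleetrec.core.utils.envs` walks the parsed YAML roughly like this (its real implementation is not part of this diff):

config = {
    "train": {
        "startup": {
            "cluster": {"init_model_path": "{workspace}/init_model"},
        }
    }
}

def get_global_env(key, default=None, namespace=""):
    # Walk "train.startup" first, then "cluster.init_model_path", through nested dicts.
    node = config
    for part in (namespace.split('.') if namespace else []) + key.split('.'):
        if not isinstance(node, dict) or part not in node:
            return default
        node = node[part]
    return node

print(get_global_env("cluster.init_model_path", "", "train.startup"))  # {workspace}/init_model
print(get_global_env("cluster.load_tree", True, "train.startup"))      # True (falls back to default)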