diff --git a/fleetrec/core/trainer.py b/fleetrec/core/trainer.py index b3a6cc2ee101ca8b38c84ff6c707ccbfc7120d10..06b5800ee49fdf51be51d9ca1e9e304c46d174cd 100755 --- a/fleetrec/core/trainer.py +++ b/fleetrec/core/trainer.py @@ -89,7 +89,7 @@ def user_define_engine(engine_yaml): _config = yaml.load(rb.read(), Loader=yaml.FullLoader) assert _config is not None - envs.set_runtime_envions(_config) + envs.set_runtime_environs(_config) train_location = envs.get_global_env("engine.file") train_dirname = os.path.dirname(train_location) diff --git a/fleetrec/core/utils/envs.py b/fleetrec/core/utils/envs.py index e4a04440f403fa9c7a9357fd7424bdb6a058e9f1..a4bc20f2e367e59071981525c3ccf81b0d4070be 100644 --- a/fleetrec/core/utils/envs.py +++ b/fleetrec/core/utils/envs.py @@ -18,13 +18,14 @@ import copy global_envs = {} -def set_runtime_envions(envs): +def flatten_environs(envs): + flatten_dict = {} assert isinstance(envs, dict) def fatten_env_namespace(namespace_nests, local_envs): if not isinstance(local_envs, dict): global_k = ".".join(namespace_nests) - os.environ[global_k] = str(local_envs) + flatten_dict[global_k] = str(local_envs) else: for k, v in local_envs.items(): if isinstance(v, dict): @@ -33,18 +34,24 @@ def set_runtime_envions(envs): fatten_env_namespace(nests, v) else: global_k = ".".join(namespace_nests + [k]) - os.environ[global_k] = str(v) + flatten_dict[global_k] = str(v) for k, v in envs.items(): fatten_env_namespace([k], v) + return flatten_dict -def get_runtime_envion(key): + +def set_runtime_environs(environs): + for k, v in environs.items(): + os.environ[k] = v + +def get_runtime_environ(key): return os.getenv(key, None) def get_trainer(): - train_mode = get_runtime_envion("trainer.trainer") + train_mode = get_runtime_environ("trainer.trainer") return train_mode diff --git a/fleetrec/examples/ctr-dnn_train.yaml b/fleetrec/examples/ctr-dnn_train.yaml index 3e91d52a62e9ed46a2b199ffebd6cfd03448d3aa..35cd1aeb4e8e0a63016316b53a6adb378172194d 100644 --- a/fleetrec/examples/ctr-dnn_train.yaml +++ b/fleetrec/examples/ctr-dnn_train.yaml @@ -13,6 +13,17 @@ # limitations under the License. train: + trainer: + trainer: "/root/FleetRec/fleetrec/examples/user_define_trainer.py" + threads: 4 + # for cluster training + strategy: "async" + communicator: + send_queue_size: 4 + min_send_grad_num_before_recv: 4 + thread_pool_size: 5 + max_merge_var_num: 4 + epochs: 10 reader: diff --git a/fleetrec/examples/runtime.yaml b/fleetrec/examples/runtime.yaml deleted file mode 100644 index 70fa27c3f2cea6db19d10ffdddc2edfd7fe72983..0000000000000000000000000000000000000000 --- a/fleetrec/examples/runtime.yaml +++ /dev/null @@ -1,13 +0,0 @@ -trainer: - trainer: "/root/FleetRec/fleetrec/examples/user_define_trainer.py" - threads: 4 - - # for cluster training - strategy: "async" - communicator: - send_queue_size: 4 - min_send_grad_num_before_recv: 4 - thread_pool_size: 5 - max_merge_var_num: 4 - - diff --git a/fleetrec/run.py b/fleetrec/run.py index e0afe5a1b041bc05fad0143bfaec6299fc853713..9fbb1bbaaa3cf6846b084a4f1f4ef71d63a361fe 100644 --- a/fleetrec/run.py +++ b/fleetrec/run.py @@ -13,16 +13,23 @@ clusters = ["SINGLE", "LOCAL_CLUSTER", "CLUSTER"] def set_runtime_envs(cluster_envs, engine_yaml): - if engine_yaml is not None: + def get_engine_extras(): with open(engine_yaml, 'r') as rb: _envs = yaml.load(rb.read(), Loader=yaml.FullLoader) - else: - _envs = {} + + flattens = envs.flatten_environs(_envs) + + engine_extras = {} + for k, v in flattens.items(): + if k.startswith("train.trainer."): + engine_extras[k] = v + return engine_extras if cluster_envs is None: cluster_envs = {} - envs.set_runtime_envions(cluster_envs) - envs.set_runtime_envions(_envs) + + envs.set_runtime_environs(cluster_envs) + envs.set_runtime_environs(get_engine_extras()) need_print = {} for k, v in os.environ.items(): @@ -62,9 +69,8 @@ def cluster_engine(args): cluster_envs = {} cluster_envs["trainer.trainer"] = "ClusterTrainer" cluster_envs["trainer.engine"] = "cluster" - set_runtime_envs(cluster_envs, args.engine_extras) + set_runtime_envs(cluster_envs, args.model) - envs.set_runtime_envions(cluster_envs) trainer = TrainerFactory.create(args.model) return trainer @@ -146,11 +152,6 @@ if __name__ == "__main__": if args.engine.upper() not in clusters: raise ValueError("argument engine: {} error, must in {}".format(args.engine, clusters)) - if args.engine_extras is not None: - if not os.path.exists(args.engine_extras) or not os.path.isfile(args.engine_extras): - raise ValueError( - "argument engine_extras: {} error, must specify an existed YAML file".format(args.engine_extras)) - which_engine = get_engine(args.engine) engine = which_engine(args) engine.run()