提交 2cba27a6 编写于 作者: T tangwei

fix import

上级 2e91f58f
......@@ -28,10 +28,7 @@ class TrainerFactory(object):
def _build_trainer(yaml_path):
print(envs.pretty_print_envs(envs.get_global_envs()))
train_mode = envs.get_global_env("train.trainer")
if train_mode is None:
train_mode = envs.get_runtime_envion("train.trainer")
train_mode = envs.get_training_mode()
if train_mode == "SingleTraining":
from fleetrec.core.trainers.single_trainer import SingleTrainer
......
......@@ -12,10 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import sys
import time
import json
import datetime
import numpy as np
import paddle.fluid as fluid
......
......@@ -29,6 +29,14 @@ def get_runtime_envion(key):
return os.getenv(key, None)
def get_training_mode():
train_mode = get_global_env("train.trainer")
if train_mode is None:
train_mode = get_runtime_envion("train.trainer")
return train_mode
def set_global_envs(envs):
assert isinstance(envs, dict)
......
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
examples for user
"""
\ No newline at end of file
trainer: "MPIClusterTraining"
pserver_num: 2
trainer_num: 2
start_port: 36001
log_dirname: "logs"
strategy:
mode: "async"
......@@ -27,7 +27,7 @@ train:
hyper_parameters:
sparse_inputs_slots: 27
sparse_feature_number: 1000001
sparse_feature_dim: 8
sparse_feature_dim: 9
dense_input_dim: 13
fc_sizes: [512, 256, 128, 32]
learning_rate: 0.001
......
......@@ -60,13 +60,18 @@ class Model(ModelBase):
self._data_var.append(self.label_input)
def net(self):
def embedding_layer(input):
train_mode = envs.get_training_mode()
is_distributed = True if train_mode == "CtrTraining" else False
sparse_feature_number = envs.get_global_env("hyper_parameters.sparse_feature_number", None, self.namespace)
sparse_feature_dim = envs.get_global_env("hyper_parameters.sparse_feature_dim", None, self.namespace)
sparse_feature_dim = 9 if train_mode == "CtrTraining" else sparse_feature_dim
def embedding_layer(input):
emb = fluid.layers.embedding(
input=input,
is_sparse=True,
is_distributed=is_distributed,
size=[sparse_feature_number, sparse_feature_dim],
param_attr=fluid.ParamAttr(
name="SparseFeatFactors",
......
......@@ -103,6 +103,15 @@ if __name__ == "__main__":
local_mpi_engine(cluster_envs, args.model)
elif args.engine.upper() == "CLUSTER":
print("launch ClusterTraining with cluster to run model: {}".format(args.model))
if version.is_transpiler():
print("use ClusterTraining to run model: {}".format(args.model))
cluster_envs = {"train.trainer": "ClusterTraining"}
envs.set_runtime_envions(cluster_envs)
else:
cluster_envs = {"train.trainer": "CtrTraining"}
envs.set_runtime_envions(cluster_envs)
run(args.model)
elif args.engine.upper() == "USER_DEFINE":
engine_file = args.engine_extras
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册