Commit 9438cb36 authored by tangwei12

bug fix

Parent 77a2da6d
...
@@ -23,6 +23,7 @@ import paddle.fluid as fluid
 from paddle.fluid.incubate.fleet.parameter_server.distribute_transpiler import fleet
 from paddle.fluid.incubate.fleet.parameter_server.distribute_transpiler.distributed_strategy import StrategyFactory
 from paddle.fluid.incubate.fleet.base.role_maker import PaddleCloudRoleMaker
+from paddle.fluid.incubate.fleet.base.role_maker import MPISymetricRoleMaker
 from fleetrec.core.utils import envs
 from fleetrec.core.trainers.transpiler_trainer import TranspileTrainer
@@ -30,7 +31,8 @@ from fleetrec.core.trainers.transpiler_trainer import TranspileTrainer
 class ClusterTrainer(TranspileTrainer):
     def processor_register(self):
-        role = PaddleCloudRoleMaker()
+        #role = PaddleCloudRoleMaker()
+        role = MPISymetricRoleMaker()
         fleet.init(role)
         if fleet.is_server():
...
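The hunks above swap the role maker used by ClusterTrainer from PaddleCloudRoleMaker to MPISymetricRoleMaker. For reference, a minimal sketch of how the two role makers plug into the Paddle 1.x fleet API; the selection helper below is illustrative and not part of the commit:

    # Illustrative only: pick a role maker before fleet.init().
    from paddle.fluid.incubate.fleet.parameter_server.distribute_transpiler import fleet
    from paddle.fluid.incubate.fleet.base.role_maker import PaddleCloudRoleMaker
    from paddle.fluid.incubate.fleet.base.role_maker import MPISymetricRoleMaker

    def init_fleet(use_mpi=True):
        # MPISymetricRoleMaker derives pserver/trainer roles from the MPI job,
        # so it expects to be launched via mpirun; PaddleCloudRoleMaker instead
        # reads the environment variables set by the cluster scheduler.
        role = MPISymetricRoleMaker() if use_mpi else PaddleCloudRoleMaker()
        fleet.init(role)
        return fleet.is_server(), fleet.is_worker()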
@@ -72,7 +72,8 @@ class TranspileTrainer(Trainer):
         train_data_path = envs.get_global_env(
             "test_data_path", None, namespace)
-        threads = int(envs.get_runtime_environ("train.trainer.threads"))
+        #threads = int(envs.get_runtime_environ("train.trainer.threads"))
+        threads = 2
         batch_size = envs.get_global_env("batch_size", None, namespace)
         reader_class = envs.get_global_env("class", None, namespace)
         abs_dir = os.path.dirname(os.path.abspath(__file__))
...
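In the hunk above, the thread count that used to come from envs.get_runtime_environ("train.trainer.threads") is hard-coded to 2. If the runtime key may be absent, a guarded fallback along these lines would keep both behaviours; this is an illustrative sketch only, reusing the key name and default from the diff:

    # Illustrative only: prefer the runtime setting, fall back to 2 threads.
    raw = envs.get_runtime_environ("train.trainer.threads")
    threads = int(raw) if raw else 2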
@@ -110,7 +110,6 @@ def single_engine(args):
 def cluster_engine(args):
-    from fleetrec.core.engine.cluster.cluster import ClusterEngine
     def update_workspace(cluster_envs):
         workspace = cluster_envs.get("engine_workspace", None)
@@ -131,6 +130,7 @@ def cluster_engine(args):
             cluster_envs[name] = value
     def master():
+        from fleetrec.core.engine.cluster.cluster import ClusterEngine
         with open(args.backend, 'r') as rb:
             _envs = yaml.load(rb.read(), Loader=yaml.FullLoader)
@@ -155,10 +155,10 @@ def cluster_engine(args):
         print("launch {} engine with cluster to with model: {}".format(trainer, args.model))
         set_runtime_envs(cluster_envs, args.model)
-        launch = LocalClusterEngine(cluster_envs, args.model)
-        return launch
+        trainer = TrainerFactory.create(args.model)
+        return trainer
-    if args.role == "worker":
+    if args.role == "WORKER":
         return worker()
     else:
         return master()
...
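Taken together, the run.py hunks make the worker path build a trainer through TrainerFactory, compare args.role against the upper-case literal "WORKER", and defer the ClusterEngine import until master() actually needs it. A condensed, illustrative sketch of the resulting dispatch; the function bodies are reduced to the calls visible in the diff, and the TrainerFactory import path is an assumption, not shown in this commit:

    # Condensed sketch, not a verbatim copy of run.py.
    from fleetrec.core.factory import TrainerFactory  # assumed import path

    def cluster_engine_sketch(args):
        def master():
            # ClusterEngine is only imported on the master path.
            from fleetrec.core.engine.cluster.cluster import ClusterEngine
            ...  # build cluster_envs from args.backend, then launch the engine

        def worker():
            # Workers no longer launch LocalClusterEngine; they return a trainer.
            return TrainerFactory.create(args.model)

        return worker() if args.role == "WORKER" else master()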
@@ -29,8 +29,8 @@ function package() {
     cp ${engine_submit_qconf} ${temp}
     echo "copy job.sh from " ${engine_worker} " to " ${temp}
-    mkdir -p ${temp}/package/python
-    cp -r ${engine_package_python}/* ${temp}/package/python/
+    mkdir -p ${temp}/package
+    cp -r ${engine_package_python} ${temp}/package/
     echo "copy python from " ${engine_package_python} " to " ${temp}
     mkdir ${temp}/package/whl
...
@@ -16,10 +16,10 @@ declare g_run_stage=""
 # ---------------------------------------------------------------------------- #
 #                                 const define                                 #
 # ---------------------------------------------------------------------------- #
-declare -r FLAGS_communicator_thread_pool_size=5
-declare -r FLAGS_communicator_send_queue_size=18
-declare -r FLAGS_communicator_thread_pool_size=20
-declare -r FLAGS_communicator_max_merge_var_num=18
+export FLAGS_communicator_thread_pool_size=5
+export FLAGS_communicator_send_queue_size=18
+export FLAGS_communicator_thread_pool_size=20
+export FLAGS_communicator_max_merge_var_num=18
 ################################################################################
 #-----------------------------------------------------------------------------------------------------------------
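The job-script hunk replaces `declare -r` with `export` for the communicator flags: `declare -r` only creates a read-only shell variable, while `export` places the FLAGS_* values into the environment of child processes, which is where Paddle reads FLAGS_* settings from. A quick, illustrative check from Python that the flags actually reached the training process (variable names taken from the script above; the check itself is not part of the commit):

    # Illustrative only: confirm the communicator flags are visible to Python.
    import os

    for name in ("FLAGS_communicator_thread_pool_size",
                 "FLAGS_communicator_send_queue_size",
                 "FLAGS_communicator_max_merge_var_num"):
        print(name, "=", os.environ.get(name, "<not set>"))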
@@ -44,9 +44,20 @@ function env_prepare() {
     WORKDIR=$(pwd)
     mpirun -npernode 1 mv package/* ./
     echo "current:"$WORKDIR
-    export LIBRARY_PATH=$WORKDIR/python/lib:$LIBRARY_PATH
-    mpirun -npernode 1 python/bin/python -m pip install whl/fleet_rec-0.0.2-py2-none-any.whl --index-url=http://pip.baidu.com/pypi/simple --trusted-host pip.baidu.com >/dev/null
+    mpirun -npernode 1 tar -zxvf python.tar.gz > /dev/null
+    export PYTHONPATH=$WORKDIR/python/
+    export PYTHONROOT=$WORKDIR/python/
+    export LIBRARY_PATH=$PYTHONPATH/lib:$LIBRARY_PATH
+    export LD_LIBRARY_PATH=$PYTHONPATH/lib:$LD_LIBRARY_PATH
+    export PATH=$PYTHONPATH/bin:$PATH
+    export LIBRARY_PATH=$PYTHONROOT/lib:$LIBRARY_PATH
+    python -c "print('heheda')"
+    mpirun -npernode 1 python/bin/python -m pip uninstall -y fleet-rec
+    mpirun -npernode 1 python/bin/python -m pip install whl/fleet_rec-0.0.2-py2-none-any.whl --index-url=http://pip.baidu.com/pypi/simple --trusted-host pip.baidu.com
     check_error
 }
...
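The reworked env_prepare flow unpacks a bundled interpreter (python.tar.gz), points PATH, PYTHONPATH, LIBRARY_PATH and LD_LIBRARY_PATH at it, runs a trivial print as a smoke test, and then force-reinstalls the fleet_rec wheel. As an illustrative alternative to python -c "print('heheda')", a slightly more informative smoke test could report which interpreter is on PATH and whether the freshly installed package imports; this is a sketch, not part of the commit:

    # Illustrative smoke test: show the active interpreter and try the package.
    import sys
    print("interpreter:", sys.executable)
    try:
        import fleetrec
        print("fleetrec found at:", fleetrec.__file__)
    except ImportError as exc:
        print("fleetrec not importable:", exc)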