Commit 291e1594 authored by chengmo

revert performance code

Parent d3cda7f7
@@ -97,7 +97,6 @@ message AsyncConfig {
   optional int32 thread_pool_size = 6 [ default = 1 ];
   optional int32 send_wait_times = 7 [ default = 1 ];
   optional bool runtime_split_send_recv = 8 [ default = false ];
-  optional string heter_worker_device = 9 [ default = 'cpu' ];
 }
 
 message PipelineConfig { optional int32 micro_batch = 1 [ default = 1 ]; }
......
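The AsyncConfig fields above surface in Python as keys of DistributedStrategy.a_sync_configs; after this revert there is no heter_worker_device key anymore. A minimal usage sketch, assuming the paddle.distributed.fleet API of this release line:

import paddle.distributed.fleet as fleet

strategy = fleet.DistributedStrategy()
strategy.a_sync = True                 # asynchronous parameter-server training
strategy.a_sync_configs = {
    "k_steps": 0,                      # 0 keeps pure async; >0 switches to geo mode
    "send_wait_times": 1,              # maps to AsyncConfig.send_wait_times
    "thread_pool_size": 1,             # maps to AsyncConfig.thread_pool_size
}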
@@ -112,6 +112,10 @@ void RecvSelectedRows(const CommContext &rpc_ctx,
 template <typename T>
 void RecvLodTensor(const CommContext &rpc_ctx, const framework::Scope &scope) {
+  platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
+  auto cpu_place = platform::CPUPlace();
+  auto &cpu_ctx = *pool.Get(cpu_place);
+
   distributed::RPCClient *rpc_client =
       distributed::RPCClient::GetInstance<RPCCLIENT_T>(rpc_ctx.trainer_id);
@@ -121,14 +125,10 @@ void RecvLodTensor(const CommContext &rpc_ctx, const framework::Scope &scope) {
   if (rpc_ctx.origin_varnames.size() == 1 &&
       rpc_ctx.splited_varnames.size() == 1) {
     auto varname = rpc_ctx.origin_varnames[0];
-    const auto place =
-        scope.FindVar(varname)->Get<framework::LoDTensor>().place();
-    platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
-    auto &ctx = *pool.Get(place);
-    VLOG(4) << "recv " << varname << " from " << rpc_ctx.epmap[0] << " in gpu? "
-            << platform::is_gpu_place(place);
-    rets.push_back(rpc_client->AsyncGetVarNoBarrier(rpc_ctx.epmap[0], ctx,
+    VLOG(4) << "recv " << varname << " from " << rpc_ctx.epmap[0];
+    rets.push_back(rpc_client->AsyncGetVarNoBarrier(rpc_ctx.epmap[0], cpu_ctx,
                                                     scope, varname, varname));
     for (size_t i = 0; i < rets.size(); i++) {
       PADDLE_ENFORCE_NE(
           rets[i]->Wait(), 0U,
......
@@ -259,7 +259,7 @@ class DistributedStrategy(object):
     def a_sync(self, flag):
         if isinstance(flag, bool):
             self.strategy.a_sync = flag
-            self.a_sync_configs = {"k_steps": 0, "worker_device": 'cpu'}
+            self.a_sync_configs = {"k_steps": 0}
         else:
             raise ValueError(
                 "The type of `flag` is invalid, expected type is bool, but received %s".
......
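A hedged behavior sketch of the setter above after the revert, assuming the accessors behave as this file defines them: enabling a_sync resets the config dict to {"k_steps": 0} (no worker_device key), and a non-bool flag is rejected.

import paddle.distributed.fleet as fleet

strategy = fleet.DistributedStrategy()
strategy.a_sync = True
assert strategy.a_sync_configs["k_steps"] == 0   # worker_device is no longer present
try:
    strategy.a_sync = "yes"                      # not a bool
except ValueError as e:
    print(e)                                     # "The type of `flag` is invalid, ..."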
@@ -31,10 +31,6 @@ class ParameterServerGraphOptimizer(ParameterServerOptimizer):
         if k_steps < 0:
             return False
 
-        device = self.user_defined_strategy.a_sync_configs["worker_device"]
-        if device.upper() != 'CPU':
-            return False
-
         if self.role_maker._is_server():
             return False
......
@@ -13,7 +13,6 @@
 
 from paddle import fluid
 from .meta_optimizer_base import MetaOptimizerBase
-from ..base.private_helper_function import wait_server_ready
 from paddle.fluid import core
 import subprocess
 import re
@@ -75,8 +74,6 @@ class ParameterServerOptimizer(MetaOptimizerBase):
             _startup = worker.delet_extra_optimizes_pass(_startup,
                                                          compiled_config)
 
-            compiled_config.set_origin_ps_main_program(_main)
-            compiled_config.set_origin_ps_startup_program(_startup)
-
             # for heter program
             if self.role_maker._is_heter_parameter_server_mode:
                 from paddle.fluid.incubate.fleet.parameter_server.ir import heter_trainer_pass as heter_worker
@@ -94,16 +91,6 @@ class ParameterServerOptimizer(MetaOptimizerBase):
         else:
             _main = worker.append_send_ops_pass(_main, compiled_config)
             _startup = _startup
-            compiled_config.set_origin_ps_main_program(_main)
-            compiled_config.set_origin_ps_startup_program(_startup)
-
-        # for trainer wait server ready
-        wait_server_ready(self.role_maker._get_pserver_endpoints())
-
-        # for ps-heter mode, wait heter worker ready
-        if self.role_maker._is_heter_parameter_server_mode and self.role_maker._is_worker(
-        ):
-            wait_server_ready(self.role_maker._get_heter_worker_endpoints())
 
         return _main, _startup
......
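The removed block was the trainers' readiness barrier: before returning the programs, each trainer waited until every pserver (and, in ps-heter mode, every heter worker) endpoint was reachable. An illustrative sketch only, not the actual private_helper_function.wait_server_ready implementation; the polling helper below is hypothetical:

import socket
import time

def wait_endpoints_ready(endpoints, retry_interval=3):
    # Block until every "ip:port" endpoint accepts a TCP connection.
    pending = list(endpoints)
    while pending:
        still_waiting = []
        for ep in pending:
            ip, port = ep.split(":")
            with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
                sock.settimeout(retry_interval)
                try:
                    sock.connect((ip, int(port)))        # server is listening
                except OSError:
                    still_waiting.append(ep)             # not up yet, retry later
        pending = still_waiting
        if pending:
            time.sleep(retry_interval)

# e.g. wait_endpoints_ready(["127.0.0.1:6170", "127.0.0.1:6171"])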
@@ -198,21 +198,16 @@ class ParameterServerRuntime(RuntimeBase):
             warnings.warn("communicator has been initialized, skip")
 
     def _get_executor(self):
         if self.role_maker._is_heter_worker():
-            heter_worker_device = self.context["valid_strategy"].a_sync_configs[
-                "heter_worker_device"].upper()
-            if heter_worker_device == "GPU":
+            if self.role_maker._get_heter_worker_device() == "GPU":
                 gpu_id = int(os.getenv("FLAGS_selected_gpus", "0"))
                 executor = Executor(fluid.CUDAPlace(gpu_id))
-            elif heter_worker_device == "XPU":
+            elif self.role_maker._get_heter_worker_device() == "XPU":
                 xpu_id = int(os.getenv("FLAGS_selected_xpus", "0"))
                 executor = Executor(fluid.XPUPlace(xpu_id))
-            elif heter_worker_device == "CPU":
-                fluid.Executor(fluid.CPUPlace())
             else:
-                raise ValueError("Heter Worker Not Support Device {}".format(
-                    heter_worker_device))
+                raise ValueError("Not Support Device {}".format(
+                    self.role_maker._get_heter_worker_device()))
         else:
             executor = fluid.Executor(fluid.CPUPlace())
         return executor
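Post-revert, the heter worker's device comes from the role maker rather than from a_sync_configs["heter_worker_device"], and a plain CPU heter worker is no longer special-cased. A condensed standalone sketch of the selection logic, assuming the paddle.fluid API of this era:

import os
from paddle import fluid

def make_executor(device):
    # device is expected to be "GPU" or "XPU"; anything else is rejected,
    # mirroring the branch structure above.
    if device == "GPU":
        gpu_id = int(os.getenv("FLAGS_selected_gpus", "0"))
        return fluid.Executor(fluid.CUDAPlace(gpu_id))
    if device == "XPU":
        xpu_id = int(os.getenv("FLAGS_selected_xpus", "0"))
        return fluid.Executor(fluid.XPUPlace(xpu_id))
    raise ValueError("Not Support Device {}".format(device))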
@@ -463,13 +458,13 @@ class ParameterServerRuntime(RuntimeBase):
     def _save_distributed_persistables(self, executor, dirname, main_program):
         dense_ctx = self.compiled_strategy.get_communicator_recv_context(
-            recv_type=1, use_origin_program=True)
+            recv_type=1)
 
         sparse_ctx = self.compiled_strategy.get_communicator_recv_context(
-            recv_type=2, use_origin_program=True)
+            recv_type=2)
 
         distributed_ctx = self.compiled_strategy.get_communicator_recv_context(
-            recv_type=3, use_origin_program=True)
+            recv_type=3)
 
         recv_dense_varnames = self._save_dense_params(executor, dirname,
                                                       dense_ctx, main_program)
@@ -521,7 +516,7 @@ class ParameterServerRuntime(RuntimeBase):
             )
 
         if main_program is None:
-            main_program = self.compiled_strategy.get_origin_ps_main_program()
+            main_program = fluid.default_main_program()
 
         if isinstance(main_program, CompiledProgram):
             raise TypeError(
......
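With the origin-ps program cache removed (see the CompileTimeStrategy hunks below), saving now falls back to the default program when the caller passes nothing. A small hedged sketch of that fallback, assuming paddle.fluid of this era:

from paddle import fluid
from paddle.fluid.compiler import CompiledProgram

def resolve_save_program(main_program=None):
    # None -> save against the process-wide default program.
    if main_program is None:
        main_program = fluid.default_main_program()
    if isinstance(main_program, CompiledProgram):
        raise TypeError("a CompiledProgram cannot be saved; pass a Program")
    return main_program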
@@ -133,8 +133,6 @@ class CompileTimeStrategy(object):
         self.origin_main_program = main_program
         self.origin_startup_program = startup_program
-        self.origin_ps_main_program = main_program
-        self.origin_ps_startup_program = startup_program
 
         self.strategy = strategy
         self.role_maker = role_maker
@@ -155,11 +153,6 @@ class CompileTimeStrategy(object):
         self._build_var_distributed()
 
-        # for heter-ps save variables
-        self.origin_merged_variables_pairs = list(self.merged_variables_pairs)
-        self.origin_merged_dense_pairs = list(self.merged_dense_pairs)
-        self.origin_merged_sparse_pairs = list(self.merged_sparse_pairs)
-
     def get_distributed_mode(self):
         trainer = self.strategy.get_trainer_runtime_config()
         return trainer.mode
@@ -221,18 +214,6 @@ class CompileTimeStrategy(object):
     def get_origin_startup_program(self):
         return self.origin_startup_program
 
-    def set_origin_ps_main_program(self, program):
-        self.origin_ps_main_program = program
-
-    def set_origin_ps_startup_program(self, program):
-        self.origin_ps_startup_program = program
-
-    def get_origin_ps_main_program(self):
-        return self.origin_ps_main_program
-
-    def get_origin_ps_startup_program(self):
-        return self.origin_ps_startup_program
-
     def get_sparse_varname_on_ps(self, is_distributed, endpoint=None):
         if not endpoint:
             endpoint = self.get_ps_endpoint()
@@ -397,9 +378,7 @@ class CompileTimeStrategy(object):
             send_ctx[name] = ctx
 
         return send_ctx
 
-    def get_communicator_recv_context(self,
-                                      recv_type=1,
-                                      use_origin_program=False):
+    def get_communicator_recv_context(self, recv_type=1):
         # recv_type
         # 1 : DENSE 2. SPARSE 3. DISTRIBUTED 4. ALL
         distibuted_varnames = get_sparse_tablenames(self.origin_main_program,
@@ -413,8 +392,7 @@ class CompileTimeStrategy(object):
         sparse_recv_ctx = {}
         distributed_recv_ctx = {}
 
-        variables_pairs = self.merged_variables_pairs if not use_origin_program else self.origin_merged_variables_pairs
-        for merged in variables_pairs:
+        for merged in self.merged_variables_pairs:
             params = merged[0]
             if params.merged_var.name in sparse_varnames:
                 continue
......
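The recv_type codes passed to get_communicator_recv_context are documented only by the inline comment above (1: DENSE, 2: SPARSE, 3: DISTRIBUTED, 4: ALL). A readability sketch with hypothetical named constants (they are not part of the Paddle source):

# Hypothetical names for the recv_type codes; call sites after this revert
# pass only the code, with no use_origin_program argument.
RECV_TYPE_DENSE = 1
RECV_TYPE_SPARSE = 2
RECV_TYPE_DISTRIBUTED = 3
RECV_TYPE_ALL = 4

# e.g. compiled_strategy.get_communicator_recv_context(recv_type=RECV_TYPE_DENSE)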