Unverified commit a4f85074, authored by 123malin, committed by GitHub

【paddle.fleet】bug fix for parameter_recv (#27838)

* test=develop, bug fix for parameter_recv
* test=develop, for unittest, test_fleet_rolemaker_new
Parent b19b01af
@@ -106,9 +106,8 @@ class RequestSend final : public RequestBase {
                ::grpc::ServerCompletionQueue* cq,
                RequestHandler* request_handler, int req_id)
       : RequestBase(service, cq, request_handler, req_id), responder_(&ctx_) {
-    request_.reset(new GRPCVariableResponse(
-        request_handler->scope(), request_handler->dev_ctx(),
-        request_handler->distributed_mode()));
+    request_.reset(new GRPCVariableResponse(request_handler->scope(),
+                                            request_handler->dev_ctx(), true));
     int method_id = static_cast<int>(distributed::GrpcMethod::kSendVariable);
     service_->RequestAsyncUnary(
         method_id, &ctx_, request_.get(), &responder_, cq_, cq_,
@@ -420,9 +419,8 @@ class RequestNotify final : public RequestBase {
                ::grpc::ServerCompletionQueue* cq,
                RequestHandler* request_handler, int req_id)
       : RequestBase(service, cq, request_handler, req_id), responder_(&ctx_) {
-    request_.reset(new GRPCVariableResponse(
-        request_handler->scope(), request_handler->dev_ctx(),
-        request_handler->distributed_mode()));
+    request_.reset(new GRPCVariableResponse(request_handler->scope(),
+                                            request_handler->dev_ctx(), true));
     int method_id = static_cast<int>(distributed::GrpcMethod::kRequestNotify);
     service_->RequestAsyncUnary(
         method_id, &ctx_, request_.get(), &responder_, cq_, cq_,
@@ -455,9 +453,8 @@ class RequestSendAndRecv final : public RequestBase {
                ::grpc::ServerCompletionQueue* cq,
                RequestHandler* request_handler, int req_id)
       : RequestBase(service, cq, request_handler, req_id), responder_(&ctx_) {
-    request_.reset(new GRPCVariableResponse(
-        request_handler->scope(), request_handler->dev_ctx(),
-        request_handler->distributed_mode()));
+    request_.reset(new GRPCVariableResponse(request_handler->scope(),
+                                            request_handler->dev_ctx(), true));
     int method_id =
         static_cast<int>(distributed::GrpcMethod::kRequestSendAndRecv);
...
@@ -52,22 +52,25 @@ void RecvSparseLodTensor(const CommContext &rpc_ctx,
   std::unique_ptr<framework::Scope> local_scope = scope.NewTmpScope();
   std::vector<const float *> tensors;
   std::vector<distributed::VarHandlePtr> rets;
+  std::vector<std::string> recv_varnames;
   for (size_t i = 0; i < rpc_ctx.splited_varnames.size(); i++) {
     auto &recv_var_name = rpc_ctx.splited_varnames[i];
-    auto *local_var = local_scope->Var(recv_var_name);
     VLOG(4) << "recv " << recv_var_name << " from " << rpc_ctx.epmap[i];
+    local_scope->Var(recv_var_name);
     // sparse param in recv_scope is LoDTensor
     rets.push_back(rpc_client->AsyncGetVarNoBarrier(
         rpc_ctx.epmap[i], cpu_ctx, *local_scope.get(), recv_var_name,
         recv_var_name));
+    recv_varnames.push_back(recv_var_name);
-    const auto *value = local_var->Get<framework::LoDTensor>().data<float>();
-    tensors.push_back(value);
   }
   for (size_t i = 0; i < rets.size(); i++) {
     PADDLE_ENFORCE_NE(rets[i]->Wait(), 0U, platform::errors::ExecutionTimeout(
                                                "internal error in RPCClient"));
+    auto &recv_var_name = recv_varnames[i];
+    auto *local_var = local_scope->FindVar(recv_var_name);
+    const auto *value = local_var->Get<framework::LoDTensor>().data<float>();
+    tensors.push_back(value);
   }
   auto *merged_var = scope.FindVar(rpc_ctx.var_name);
...
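For context: the old code took each received tensor's data pointer right after the asynchronous fetch was issued, i.e. before the RPC had completed, while the fix records the variable names and reads the tensors only after every Wait() succeeds. Below is a minimal Python sketch of that wait-then-read pattern; the names and helpers are hypothetical stand-ins for illustration only, the real code is the C++ hunk above.

from concurrent.futures import Future, ThreadPoolExecutor
from typing import Callable, Dict, List


def recv_sparse(var_names: List[str],
                fetch_one: Callable[[str], List[float]]) -> List[List[float]]:
    # local_store plays the role of the temporary local scope; fetch_one
    # stands in for the async RPC (AsyncGetVarNoBarrier in the hunk above).
    local_store: Dict[str, List[float]] = {}
    handles: List[Future] = []
    recv_varnames: List[str] = []
    with ThreadPoolExecutor() as pool:
        for name in var_names:
            handles.append(
                pool.submit(lambda n=name: local_store.update({n: fetch_one(n)})))
            recv_varnames.append(name)
        tensors: List[List[float]] = []
        for handle, name in zip(handles, recv_varnames):
            handle.result()                    # analogous to rets[i]->Wait()
            tensors.append(local_store[name])  # read only after the wait
    return tensors


print(recv_sparse(["embedding@0", "embedding@1"], lambda n: [0.0, 1.0]))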
@@ -825,7 +825,7 @@ class PaddleCloudRoleMaker(RoleMakerBase):
             if self._is_first_worker():
                 start_http_server = True
         else:
-            ep_rank_0 = self._server_endpoints[0]
+            ep_rank_0 = os.getenv("PADDLE_GLOO_HTTP_ENDPOINT", "")
             if self._server_index() == 0:
                 start_http_server = True
         ip, port = ep_rank_0.split(':')
...
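The role maker now takes the Gloo HTTP rendezvous endpoint from a single PADDLE_GLOO_HTTP_ENDPOINT environment variable (exported by the launcher changes below) instead of reusing the first server endpoint. A small standalone sketch of how such an "ip:port" value is consumed, assuming the variable is set the way the launcher sets it:

import os

# Hypothetical standalone illustration; the real logic lives in
# PaddleCloudRoleMaker. The launcher exports PADDLE_GLOO_HTTP_ENDPOINT as
# "ip:port", and the role maker splits it to find the Gloo HTTP server.
os.environ.setdefault("PADDLE_GLOO_HTTP_ENDPOINT", "127.0.0.1:30019")

ep_rank_0 = os.getenv("PADDLE_GLOO_HTTP_ENDPOINT", "")
if ep_rank_0:
    ip, port = ep_rank_0.split(':')
    print("Gloo HTTP rendezvous at ip=%s port=%d" % (ip, int(port)))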
@@ -141,6 +141,7 @@ see: http://www.paddlepaddle.org/documentation/docs/zh/1.6/user_guides/howto/tra
     ps_group.add_argument("--server_num", type=int, help="number of servers")
     ps_group.add_argument(
         "--heter_worker_num", type=int, help="number of heter_workers")
+    ps_group.add_argument("--http_port", type=int, help="Gloo http Port")
     return parser.parse_args()
@@ -249,12 +250,8 @@ def launch_ps(args, distribute_mode):
 def which_distributed_mode(args):
     ps_args = [
-        '--worker_num',
-        '--server_num',
-        '--heter_worker_num',
-        '--servers',
-        '--workers',
-        '--heter_workers',
+        '--worker_num', '--server_num', '--heter_worker_num', '--servers',
+        '--workers', '--heter_workers', '--http_port'
     ]
     collective_args = ['--ips']
@@ -292,9 +289,16 @@ def which_distributed_mode(args):
             format(has_collective_args, cuda_device_num))
         return DistributeMode.COLLECTIVE
     else:
-        logger.warning(
-            "Not found distinct arguments. Default use gpu collective mode")
-        return DistributeMode.COLLECTIVE
+        if not fluid.core.is_compiled_with_cuda():
+            logger.warning(
+                "Not found distinct arguments and not compiled with cuda. Default use ps mode"
+            )
+            return DistributeMode.PS
+        else:
+            logger.warning(
+                "Not found distinct arguments and compiled with cuda. Default use collective mode"
+            )
+            return DistributeMode.COLLECTIVE

 def launch():
...
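A standalone sketch of the new fallback decision, with DistributeMode and the CUDA check stubbed out for illustration (in the real code they come from the launch utilities and fluid.core):

from enum import Enum


class DistributeMode(Enum):
    # Stub of the launcher's DistributeMode enum for this illustration.
    PS = 0
    COLLECTIVE = 1


def default_mode(compiled_with_cuda: bool) -> DistributeMode:
    # Mirrors the new fallback: without CUDA support, default to parameter
    # server mode instead of unconditionally choosing collective mode.
    return DistributeMode.COLLECTIVE if compiled_with_cuda else DistributeMode.PS


assert default_mode(False) is DistributeMode.PS
assert default_mode(True) is DistributeMode.COLLECTIVE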
@@ -713,6 +713,14 @@ class ParameterServerLauncher(object):
         else:
             self.worker_endpoints = args.workers

+        # get http_port
+        if args.http_port:
+            self.http_port = args.http_port
+        else:
+            http_port = get_ports(1, self.server_num + self.worker_num)
+            http_ip = self.server_endpoints.split(",")[0].split(":")[0]
+            self.http_port = http_ip + ":" + str(http_port[0])
+
         # get heter worker envs
         if self.distribute_mode == DistributeMode.PS_HETER:
             if args.heter_worker_num:
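When --http_port is not passed, the launcher derives an endpoint from a free port on the first server's IP. A rough Python sketch of that fallback follows, with a simple stand-in for the repo's get_ports helper (whose real signature may differ) and made-up example values:

import socket


def get_free_ports(num):
    # Stand-in for launch_utils.get_ports: ask the OS for free ports.
    socks, ports = [], []
    for _ in range(num):
        s = socket.socket()
        s.bind(("", 0))
        socks.append(s)
        ports.append(s.getsockname()[1])
    for s in socks:
        s.close()
    return ports


server_endpoints = "10.0.0.1:6170,10.0.0.2:6170"  # example value
http_port_arg = None                              # i.e. --http_port was not passed

if http_port_arg:
    http_endpoint = str(http_port_arg)
else:
    http_ip = server_endpoints.split(",")[0].split(":")[0]
    http_endpoint = http_ip + ":" + str(get_free_ports(1)[0])

print(http_endpoint)  # later exported as PADDLE_GLOO_HTTP_ENDPOINT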
@@ -827,7 +835,8 @@ class ParameterServerLauncher(object):
         self.start_pod_server(self.args, pod)
         self.start_pod_worker(self.args, pod)
-        self.start_pod_heter_worker(self.args, pod)
+        if self.distribute_mode == DistributeMode.PS_HETER:
+            self.start_pod_heter_worker(self.args, pod)
         logger.info(
             "Please check servers, workers and heter_worker logs in {}/workerlog.*, {}/serverlog.* and {}/heterlog.*".
@@ -887,7 +896,8 @@ class ParameterServerLauncher(object):
                 "POD_IP": cur_server.endpoint.split(":")[0],
                 "PADDLE_WITH_GLOO": "1",
                 "PADDLE_GLOO_RENDEZVOUS": "2",
-                "PADDLE_GLOO_FS_PATH": self.gloo_rendezvous_dir
+                "PADDLE_GLOO_FS_PATH": self.gloo_rendezvous_dir,
+                "PADDLE_GLOO_HTTP_ENDPOINT": self.http_port
             }
             current_env.update(proc_env)
@@ -938,7 +948,8 @@ class ParameterServerLauncher(object):
             device_list = [str(x) for x in range(0, heter_device_num)]
         for idx, cur_worker in enumerate(pod.workers):
-            device_id = str(device_list[idx % heter_device_num])
+            device_id = "0" if heter_device_num == 0 else str(device_list[
+                idx % heter_device_num])
             proc_env = {
                 "PADDLE_PSERVERS_IP_PORT_LIST": self.server_endpoints,
                 "PADDLE_TRAINER_ENDPOINTS": self.worker_endpoints,
@@ -954,6 +965,7 @@ class ParameterServerLauncher(object):
                 "FLAGS_selected_xpus": "0",
                 "CUDA_VISIBLE_DEVICES": device_id,
                 "XPU_VISIBLE_DEVICES": device_id,
+                "PADDLE_GLOO_HTTP_ENDPOINT": self.http_port
             }
             current_env.update(proc_env)
@@ -1022,6 +1034,7 @@ class ParameterServerLauncher(object):
                 "FLAGS_selected_xpus": "0",
                 "CUDA_VISIBLE_DEVICES": device_id,
                 "XPU_VISIBLE_DEVICES": device_id,
+                "PADDLE_GLOO_HTTP_ENDPOINT": self.http_port
             }
             current_env.update(proc_env)
...
@@ -282,8 +282,7 @@ class TestGlooWithCloudRoleMaker(unittest.TestCase):
         os.environ["SYS_JOB_ID"] = "gloo_for_cluster"
         os.environ["PADDLE_WITH_GLOO"] = "1"
         os.environ["PADDLE_GLOO_RENDEZVOUS"] = "3"
-        os.environ["PADDLE_GLOO_HTTP_HOST"] = "127.0.0.1"
-        os.environ["PADDLE_GLOO_HTTP_PORT"] = "30019"
+        os.environ["PADDLE_GLOO_HTTP_ENDPOINT"] = "127.0.0.1:30019"
         role = role_maker.PaddleCloudRoleMaker()
         role._generate_role()
@@ -541,8 +540,7 @@ class TestGlooWithCloudRoleMaker(unittest.TestCase):
         os.environ["SYS_JOB_ID"] = "gloo_for_cluster"
         os.environ["PADDLE_WITH_GLOO"] = "1"
         os.environ["PADDLE_GLOO_RENDEZVOUS"] = "3"
-        os.environ["PADDLE_GLOO_HTTP_HOST"] = "127.0.0.1"
-        os.environ["PADDLE_GLOO_HTTP_PORT"] = "30019"
+        os.environ["PADDLE_GLOO_HTTP_ENDPOINT"] = "127.0.0.1:30019"
         role = role_maker.PaddleCloudRoleMaker()
         role._generate_role()
@@ -673,8 +671,7 @@ class TestGlooWithCloudRoleMaker(unittest.TestCase):
         os.environ["SYS_JOB_ID"] = "gloo_for_cluster"
         os.environ["PADDLE_WITH_GLOO"] = "1"
         os.environ["PADDLE_GLOO_RENDEZVOUS"] = "3"
-        os.environ["PADDLE_GLOO_HTTP_HOST"] = ""
-        os.environ["PADDLE_GLOO_HTTP_PORT"] = ""
+        os.environ["PADDLE_GLOO_HTTP_ENDPOINT"] = ""
         role = role_maker.PaddleCloudRoleMaker()
...