Unverified commit a4f85074, authored by 123malin, committed by GitHub

【paddle.fleet】bug fix for parameter_recv (#27838)

* test=develop, bug fix for parameter_recv
* test=develop, for unittest, test_fleet_rolemaker_new
Parent b19b01af
......@@ -106,9 +106,8 @@ class RequestSend final : public RequestBase {
               ::grpc::ServerCompletionQueue* cq,
               RequestHandler* request_handler, int req_id)
       : RequestBase(service, cq, request_handler, req_id), responder_(&ctx_) {
-    request_.reset(new GRPCVariableResponse(
-        request_handler->scope(), request_handler->dev_ctx(),
-        request_handler->distributed_mode()));
+    request_.reset(new GRPCVariableResponse(request_handler->scope(),
+                                            request_handler->dev_ctx(), true));
     int method_id = static_cast<int>(distributed::GrpcMethod::kSendVariable);
     service_->RequestAsyncUnary(
         method_id, &ctx_, request_.get(), &responder_, cq_, cq_,
......@@ -420,9 +419,8 @@ class RequestNotify final : public RequestBase {
               ::grpc::ServerCompletionQueue* cq,
               RequestHandler* request_handler, int req_id)
       : RequestBase(service, cq, request_handler, req_id), responder_(&ctx_) {
-    request_.reset(new GRPCVariableResponse(
-        request_handler->scope(), request_handler->dev_ctx(),
-        request_handler->distributed_mode()));
+    request_.reset(new GRPCVariableResponse(request_handler->scope(),
+                                            request_handler->dev_ctx(), true));
     int method_id = static_cast<int>(distributed::GrpcMethod::kRequestNotify);
     service_->RequestAsyncUnary(
         method_id, &ctx_, request_.get(), &responder_, cq_, cq_,
......@@ -455,9 +453,8 @@ class RequestSendAndRecv final : public RequestBase {
               ::grpc::ServerCompletionQueue* cq,
               RequestHandler* request_handler, int req_id)
       : RequestBase(service, cq, request_handler, req_id), responder_(&ctx_) {
-    request_.reset(new GRPCVariableResponse(
-        request_handler->scope(), request_handler->dev_ctx(),
-        request_handler->distributed_mode()));
+    request_.reset(new GRPCVariableResponse(request_handler->scope(),
+                                            request_handler->dev_ctx(), true));
     int method_id =
         static_cast<int>(distributed::GrpcMethod::kRequestSendAndRecv);
......
......@@ -52,22 +52,25 @@ void RecvSparseLodTensor(const CommContext &rpc_ctx,
   std::unique_ptr<framework::Scope> local_scope = scope.NewTmpScope();
   std::vector<const float *> tensors;
   std::vector<distributed::VarHandlePtr> rets;
+  std::vector<std::string> recv_varnames;
   for (size_t i = 0; i < rpc_ctx.splited_varnames.size(); i++) {
     auto &recv_var_name = rpc_ctx.splited_varnames[i];
-    auto *local_var = local_scope->Var(recv_var_name);
     VLOG(4) << "recv " << recv_var_name << " from " << rpc_ctx.epmap[i];
+    local_scope->Var(recv_var_name);
     // sparse param in recv_scope is LoDTensor
     rets.push_back(rpc_client->AsyncGetVarNoBarrier(
         rpc_ctx.epmap[i], cpu_ctx, *local_scope.get(), recv_var_name,
         recv_var_name));
-    const auto *value = local_var->Get<framework::LoDTensor>().data<float>();
-    tensors.push_back(value);
+    recv_varnames.push_back(recv_var_name);
   }
   for (size_t i = 0; i < rets.size(); i++) {
     PADDLE_ENFORCE_NE(rets[i]->Wait(), 0U, platform::errors::ExecutionTimeout(
                                                "internal error in RPCClient"));
+    auto &recv_var_name = recv_varnames[i];
+    auto *local_var = local_scope->FindVar(recv_var_name);
+    const auto *value = local_var->Get<framework::LoDTensor>().data<float>();
+    tensors.push_back(value);
   }
   auto *merged_var = scope.FindVar(rpc_ctx.var_name);
......
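Note on the hunk above: the bug was that RecvSparseLodTensor dereferenced each received tensor's buffer right after issuing the async RPC, before the corresponding Wait() had confirmed completion. The fix records the split variable names and reads the buffers only in the second loop, once every request has returned. A minimal Python sketch of that ordering, with concurrent.futures standing in for Paddle's RPC client (fetch and recv_sparse are illustrative names, not Paddle APIs):

# Illustrative sketch only: concurrent.futures stands in for Paddle's
# async RPC client; fetch() is a hypothetical request function.
from concurrent.futures import ThreadPoolExecutor

def fetch(endpoint, var_name):
    # Stand-in for AsyncGetVarNoBarrier: returns the received data.
    return (endpoint, var_name)

def recv_sparse(splited_varnames, epmap):
    rets, recv_varnames, tensors = [], [], []
    with ThreadPoolExecutor() as pool:
        # Issue every request first, recording the variable names.
        for name, ep in zip(splited_varnames, epmap):
            rets.append(pool.submit(fetch, ep, name))
            recv_varnames.append(name)
        # Read results only after each request completes; reading
        # before the Wait() was the race the commit fixes.
        for ret in rets:
            tensors.append(ret.result())
    return recv_varnames, tensors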
......@@ -825,7 +825,7 @@ class PaddleCloudRoleMaker(RoleMakerBase):
             if self._is_first_worker():
                 start_http_server = True
         else:
-            ep_rank_0 = self._server_endpoints[0]
+            ep_rank_0 = os.getenv("PADDLE_GLOO_HTTP_ENDPOINT", "")
             if self._server_index() == 0:
                 start_http_server = True
         ip, port = ep_rank_0.split(':')
......
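With this hunk, the role maker no longer assumes the Gloo HTTP server lives at the first server endpoint; it reads the endpoint the launcher publishes via PADDLE_GLOO_HTTP_ENDPOINT. A minimal sketch of that lookup (parse_gloo_endpoint is a hypothetical helper, not a Paddle API):

import os

def parse_gloo_endpoint():
    # Reads the endpoint published by the launcher, e.g. "127.0.0.1:30019",
    # and splits it the same way the role maker does.
    ep = os.getenv("PADDLE_GLOO_HTTP_ENDPOINT", "")
    ip, port = ep.split(':')
    return ip, int(port)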
......@@ -141,6 +141,7 @@ see: http://www.paddlepaddle.org/documentation/docs/zh/1.6/user_guides/howto/tra
     ps_group.add_argument("--server_num", type=int, help="number of servers")
     ps_group.add_argument(
         "--heter_worker_num", type=int, help="number of heter_workers")
+    ps_group.add_argument("--http_port", type=int, help="Gloo http Port")
     return parser.parse_args()
......@@ -249,12 +250,8 @@ def launch_ps(args, distribute_mode):
 def which_distributed_mode(args):
     ps_args = [
-        '--worker_num',
-        '--server_num',
-        '--heter_worker_num',
-        '--servers',
-        '--workers',
-        '--heter_workers',
+        '--worker_num', '--server_num', '--heter_worker_num', '--servers',
+        '--workers', '--heter_workers', '--http_port'
     ]
     collective_args = ['--ips']
......@@ -292,9 +289,16 @@ def which_distributed_mode(args):
                 format(has_collective_args, cuda_device_num))
         return DistributeMode.COLLECTIVE
     else:
-        logger.warning(
-            "Not found distinct arguments. Default use gpu collective mode")
-        return DistributeMode.COLLECTIVE
+        if not fluid.core.is_compiled_with_cuda():
+            logger.warning(
+                "Not found distinct arguments and not compiled with cuda. Default use ps mode"
+            )
+            return DistributeMode.PS
+        else:
+            logger.warning(
+                "Not found distinct arguments and compiled with cuda. Default use collective mode"
+            )
+            return DistributeMode.COLLECTIVE
 
 
 def launch():
......
......@@ -713,6 +713,14 @@ class ParameterServerLauncher(object):
         else:
             self.worker_endpoints = args.workers
 
+        # get http_port
+        if args.http_port:
+            self.http_port = args.http_port
+        else:
+            http_port = get_ports(1, self.server_num + self.worker_num)
+            http_ip = self.server_endpoints.split(",")[0].split(":")[0]
+            self.http_port = http_ip + ":" + str(http_port[0])
+
         # get heter worker envs
         if self.distribute_mode == DistributeMode.PS_HETER:
             if args.heter_worker_num:
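The hunk above makes the launcher honor an explicit --http_port and otherwise derive an endpoint by joining a free port with the first server's IP into an "ip:port" string. A rough standalone sketch of that fallback, where get_free_ports is a hypothetical stand-in for the launcher's get_ports helper:

import socket

def get_free_ports(num):
    # Hypothetical stand-in for get_ports(): let the OS pick free ports.
    ports = []
    for _ in range(num):
        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
            s.bind(("", 0))
            ports.append(s.getsockname()[1])
    return ports

def resolve_http_endpoint(http_port_arg, server_endpoints):
    # Mirrors the diff: pass through an explicit port, else derive
    # "ip:port" from the first server endpoint and a free port.
    if http_port_arg:
        return http_port_arg
    ip = server_endpoints.split(",")[0].split(":")[0]
    return ip + ":" + str(get_free_ports(1)[0])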
......@@ -827,7 +835,8 @@ class ParameterServerLauncher(object):
         self.start_pod_server(self.args, pod)
         self.start_pod_worker(self.args, pod)
-        self.start_pod_heter_worker(self.args, pod)
+        if self.distribute_mode == DistributeMode.PS_HETER:
+            self.start_pod_heter_worker(self.args, pod)
 
         logger.info(
             "Please check servers, workers and heter_worker logs in {}/workerlog.*, {}/serverlog.* and {}/heterlog.*".
......@@ -887,7 +896,8 @@ class ParameterServerLauncher(object):
                 "POD_IP": cur_server.endpoint.split(":")[0],
                 "PADDLE_WITH_GLOO": "1",
                 "PADDLE_GLOO_RENDEZVOUS": "2",
-                "PADDLE_GLOO_FS_PATH": self.gloo_rendezvous_dir
+                "PADDLE_GLOO_FS_PATH": self.gloo_rendezvous_dir,
+                "PADDLE_GLOO_HTTP_ENDPOINT": self.http_port
             }
             current_env.update(proc_env)
......@@ -938,7 +948,8 @@ class ParameterServerLauncher(object):
             device_list = [str(x) for x in range(0, heter_device_num)]
         for idx, cur_worker in enumerate(pod.workers):
-            device_id = str(device_list[idx % heter_device_num])
+            device_id = "0" if heter_device_num == 0 else str(device_list[
+                idx % heter_device_num])
             proc_env = {
                 "PADDLE_PSERVERS_IP_PORT_LIST": self.server_endpoints,
                 "PADDLE_TRAINER_ENDPOINTS": self.worker_endpoints,
......@@ -954,6 +965,7 @@ class ParameterServerLauncher(object):
                 "FLAGS_selected_xpus": "0",
                 "CUDA_VISIBLE_DEVICES": device_id,
                 "XPU_VISIBLE_DEVICES": device_id,
+                "PADDLE_GLOO_HTTP_ENDPOINT": self.http_port
             }
             current_env.update(proc_env)
......@@ -1022,6 +1034,7 @@ class ParameterServerLauncher(object):
                 "FLAGS_selected_xpus": "0",
                 "CUDA_VISIBLE_DEVICES": device_id,
                 "XPU_VISIBLE_DEVICES": device_id,
+                "PADDLE_GLOO_HTTP_ENDPOINT": self.http_port
             }
             current_env.update(proc_env)
......
......@@ -282,8 +282,7 @@ class TestGlooWithCloudRoleMaker(unittest.TestCase):
         os.environ["SYS_JOB_ID"] = "gloo_for_cluster"
         os.environ["PADDLE_WITH_GLOO"] = "1"
         os.environ["PADDLE_GLOO_RENDEZVOUS"] = "3"
-        os.environ["PADDLE_GLOO_HTTP_HOST"] = "127.0.0.1"
-        os.environ["PADDLE_GLOO_HTTP_PORT"] = "30019"
+        os.environ["PADDLE_GLOO_HTTP_ENDPOINT"] = "127.0.0.1:30019"
         role = role_maker.PaddleCloudRoleMaker()
         role._generate_role()
......@@ -541,8 +540,7 @@ class TestGlooWithCloudRoleMaker(unittest.TestCase):
         os.environ["SYS_JOB_ID"] = "gloo_for_cluster"
         os.environ["PADDLE_WITH_GLOO"] = "1"
         os.environ["PADDLE_GLOO_RENDEZVOUS"] = "3"
-        os.environ["PADDLE_GLOO_HTTP_HOST"] = "127.0.0.1"
-        os.environ["PADDLE_GLOO_HTTP_PORT"] = "30019"
+        os.environ["PADDLE_GLOO_HTTP_ENDPOINT"] = "127.0.0.1:30019"
         role = role_maker.PaddleCloudRoleMaker()
         role._generate_role()
......@@ -673,8 +671,7 @@ class TestGlooWithCloudRoleMaker(unittest.TestCase):
         os.environ["SYS_JOB_ID"] = "gloo_for_cluster"
         os.environ["PADDLE_WITH_GLOO"] = "1"
         os.environ["PADDLE_GLOO_RENDEZVOUS"] = "3"
-        os.environ["PADDLE_GLOO_HTTP_HOST"] = ""
-        os.environ["PADDLE_GLOO_HTTP_PORT"] = ""
+        os.environ["PADDLE_GLOO_HTTP_ENDPOINT"] = ""
         role = role_maker.PaddleCloudRoleMaker()
......