From 3c6b733d14c0db61eb70208aa79c3999f29efc1d Mon Sep 17 00:00:00 2001
From: Qiao Longfei
Date: Mon, 4 Mar 2019 12:11:21 +0800
Subject: [PATCH] remove exe context

---
 .../operators/distributed/parameter_recv.cc  |  9 +++---
 .../operators/distributed/parameter_recv.h   |  4 +--
 .../operators/distributed/parameter_send.cc  | 29 ++++++++++---------
 .../operators/distributed/parameter_send.h   |  5 ++--
 .../operators/distributed_ops/recv_op.cc     |  2 +-
 .../operators/distributed_ops/send_op.cc     |  2 +-
 6 files changed, 24 insertions(+), 27 deletions(-)

diff --git a/paddle/fluid/operators/distributed/parameter_recv.cc b/paddle/fluid/operators/distributed/parameter_recv.cc
index 00956d8e6d..fecc76955d 100644
--- a/paddle/fluid/operators/distributed/parameter_recv.cc
+++ b/paddle/fluid/operators/distributed/parameter_recv.cc
@@ -40,7 +40,6 @@ using DDim = framework::DDim;
 
 template <typename T>
 void ParameterRecv<T>::operator()(const RpcContext &rpc_ctx,
-                                  const framework::ExecutionContext &ctx,
                                   const framework::Scope &scope) {
   framework::Scope *local_scope = scope.NewTmpScope();
 
@@ -48,8 +47,7 @@ void ParameterRecv<T>::operator()(const RpcContext &rpc_ctx,
   auto &cpu_ctx = *pool.Get(platform::CPUPlace());
 
   distributed::RPCClient *rpc_client =
-      distributed::RPCClient::GetInstance<RPCCLIENT_T>(
-          ctx.Attr<int>("trainer_id"));
+      distributed::RPCClient::GetInstance<RPCCLIENT_T>(0);
 
   auto *recv_var = scope.FindVar(rpc_ctx.var_name);
 
@@ -80,12 +78,13 @@ void ParameterRecv<T>::operator()(const RpcContext &rpc_ctx,
     size_t output_offset = 0;
     framework::Tensor *recv_tensor =
         recv_var->GetMutable<framework::LoDTensor>();
+    auto dev_ctx = paddle::platform::CPUDeviceContext();
     for (auto *in : recved_tensors) {
       auto in_stride = framework::stride_numel(in->dims());
       auto out_stride = framework::stride_numel(recv_tensor->dims());
       StridedNumelCopyWithAxis<T>(
-          ctx.device_context(), 0, recv_tensor->data<T>() + output_offset,
-          out_stride, in->data<T>(), in_stride, in_stride[0]);
+          dev_ctx, 0, recv_tensor->data<T>() + output_offset, out_stride,
+          in->data<T>(), in_stride, in_stride[0]);
       output_offset += in_stride[0];
     }
   }
diff --git a/paddle/fluid/operators/distributed/parameter_recv.h b/paddle/fluid/operators/distributed/parameter_recv.h
index e25594024a..e955fca725 100644
--- a/paddle/fluid/operators/distributed/parameter_recv.h
+++ b/paddle/fluid/operators/distributed/parameter_recv.h
@@ -26,9 +26,7 @@ namespace distributed {
 
 template <typename T>
 struct ParameterRecv {
-  void operator()(const RpcContext &rpc_ctx,
-                  const framework::ExecutionContext &context,
-                  const framework::Scope &scope);
+  void operator()(const RpcContext &rpc_ctx, const framework::Scope &scope);
 };
 
 };  // namespace distributed
diff --git a/paddle/fluid/operators/distributed/parameter_send.cc b/paddle/fluid/operators/distributed/parameter_send.cc
index eaa1c3ae8e..3fe3be193a 100644
--- a/paddle/fluid/operators/distributed/parameter_send.cc
+++ b/paddle/fluid/operators/distributed/parameter_send.cc
@@ -39,7 +39,6 @@ using DDim = framework::DDim;
 
 template <typename T>
 void ParameterSend<T>::operator()(const RpcContext &rpc_ctx,
-                                  const framework::ExecutionContext &ctx,
                                   const framework::Scope &scope, bool sync) {
   framework::Scope *local_scope = scope.NewTmpScope();
 
@@ -47,8 +46,7 @@ void ParameterSend<T>::operator()(const RpcContext &rpc_ctx,
   auto &cpu_ctx = *pool.Get(platform::CPUPlace());
 
   distributed::RPCClient *rpc_client =
-      distributed::RPCClient::GetInstance<RPCCLIENT_T>(
-          ctx.Attr<int>("trainer_id"));
+      distributed::RPCClient::GetInstance<RPCCLIENT_T>(0);
 
   auto *send_var = scope.FindVar(rpc_ctx.var_name);
   size_t out_num = rpc_ctx.splited_var_names.size();
@@ -105,7 +103,7 @@ void ParameterSend<T>::operator()(const RpcContext &rpc_ctx,
       outs_rows_idx[out_idx].push_back(send_rows[i]);
       outs_dense_idx[out_idx].push_back(i);
     }
-    auto place = ctx.GetPlace();
+    auto place = platform::CPUPlace();
 
     for (size_t i = 0; i < outs_rows_idx.size(); ++i) {
       auto rows_idx = outs_rows_idx[i];
@@ -118,22 +116,25 @@ void ParameterSend<T>::operator()(const RpcContext &rpc_ctx,
       for (auto idx : rows_idx) {
         outs[i]->mutable_rows()->push_back(idx - abs_sections[i]);
       }
-      auto dst = outs[i]->mutable_value()->mutable_data<T>(ctx.GetPlace());
+      auto dst = outs[i]->mutable_value()->mutable_data<T>(place);
       for (size_t j = 0; j < rows_idx.size(); j++) {
         if (platform::is_cpu_place(place)) {
           memory::Copy(
               platform::CPUPlace(), dst + j * row_numel, platform::CPUPlace(),
               src + outs_dense_idx[i][j] * row_numel, sizeof(T) * row_numel);
         } else {
-#ifdef PADDLE_WITH_CUDA
-          auto stream = ctx.cuda_device_context().stream();
-          memory::Copy(platform::CUDAPlace(), dst + j * row_numel,
-                       platform::CUDAPlace(),
-                       src + outs_dense_idx[i][j] * row_numel,
-                       sizeof(T) * row_numel, stream);
-#else
-          PADDLE_THROW("Paddle is not compiled with GPU");
-#endif
+          PADDLE_THROW("do not support GPU now");
+          /*
+          #ifdef PADDLE_WITH_CUDA
+          auto stream = ctx.cuda_device_context().stream();
+          memory::Copy(platform::CUDAPlace(), dst + j * row_numel,
+                       platform::CUDAPlace(),
+                       src + outs_dense_idx[i][j] * row_numel,
+                       sizeof(T) * row_numel, stream);
+          #else
+          PADDLE_THROW("Paddle is not compiled with GPU");
+          #endif
+          */
         }
       }
     }
diff --git a/paddle/fluid/operators/distributed/parameter_send.h b/paddle/fluid/operators/distributed/parameter_send.h
index 4500497163..9077f4a4fb 100644
--- a/paddle/fluid/operators/distributed/parameter_send.h
+++ b/paddle/fluid/operators/distributed/parameter_send.h
@@ -26,9 +26,8 @@ namespace distributed {
 
 template <typename T>
 struct ParameterSend {
-  void operator()(const RpcContext &rpc_ctx,
-                  const framework::ExecutionContext &context,
-                  const framework::Scope &scope, bool sync);
+  void operator()(const RpcContext &rpc_ctx, const framework::Scope &scope,
+                  bool sync);
 };
 
 };  // namespace distributed
diff --git a/paddle/fluid/operators/distributed_ops/recv_op.cc b/paddle/fluid/operators/distributed_ops/recv_op.cc
index a4a5ab89a7..41701d3a3e 100644
--- a/paddle/fluid/operators/distributed_ops/recv_op.cc
+++ b/paddle/fluid/operators/distributed_ops/recv_op.cc
@@ -62,7 +62,7 @@ class RecvOp : public framework::OperatorBase {
         framework::ExecutionContext(*this, scope, *dev_ctx, ctx, nullptr);
     auto recv_functor = distributed::ParameterRecv<float>();
     auto rpc_ctx = distributed::RpcContext(outs[0], recv_varnames, epmap, {});
-    recv_functor(rpc_ctx, exe_ctx, scope);
+    recv_functor(rpc_ctx, scope);
   } else {
     if (with_barrier) {
       std::vector<distributed::VarHandlePtr> rets;
diff --git a/paddle/fluid/operators/distributed_ops/send_op.cc b/paddle/fluid/operators/distributed_ops/send_op.cc
index 1823d89897..5585ad21ce 100644
--- a/paddle/fluid/operators/distributed_ops/send_op.cc
+++ b/paddle/fluid/operators/distributed_ops/send_op.cc
@@ -56,7 +56,7 @@ class SendOp : public framework::OperatorBase {
       auto send_functor = distributed::ParameterSend<float>();
       auto rpc_ctx = distributed::RpcContext(ins[0], send_varnames, epmap,
                                              height_sections);
-      send_functor(rpc_ctx, exe_ctx, scope, static_cast<bool>(sync_send));
+      send_functor(rpc_ctx, scope, static_cast<bool>(sync_send));
     } else {
       platform::DeviceContextPool& pool =
           platform::DeviceContextPool::Instance();
-- 
GitLab
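
Usage note: after this patch, ParameterSend/ParameterRecv take only the
RpcContext and the Scope. The trainer id passed to RPCClient::GetInstance
is hard-coded to 0 and the sparse GPU copy path now raises PADDLE_THROW,
so the functors are effectively CPU-only for now. A minimal sketch of the
new call sites, assuming ins/outs, send_varnames/recv_varnames, epmap,
height_sections and sync_send were parsed from the op attributes as in
send_op.cc/recv_op.cc:

    // Send: split the variable named by ins[0] and push the pieces out.
    auto send_functor = distributed::ParameterSend<float>();
    auto send_ctx = distributed::RpcContext(ins[0], send_varnames, epmap,
                                            height_sections);
    send_functor(send_ctx, scope, static_cast<bool>(sync_send));

    // Recv: fetch the splits and concatenate them into outs[0] on CPU.
    auto recv_functor = distributed::ParameterRecv<float>();
    auto recv_ctx = distributed::RpcContext(outs[0], recv_varnames, epmap, {});
    recv_functor(recv_ctx, scope);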