提交 3c6b733d 编写于 作者: Q Qiao Longfei

remove exe context

上级 9573d610
...@@ -40,7 +40,6 @@ using DDim = framework::DDim; ...@@ -40,7 +40,6 @@ using DDim = framework::DDim;
template <typename T> template <typename T>
void ParameterRecv<T>::operator()(const RpcContext &rpc_ctx, void ParameterRecv<T>::operator()(const RpcContext &rpc_ctx,
const framework::ExecutionContext &ctx,
const framework::Scope &scope) { const framework::Scope &scope) {
framework::Scope *local_scope = scope.NewTmpScope(); framework::Scope *local_scope = scope.NewTmpScope();
...@@ -48,8 +47,7 @@ void ParameterRecv<T>::operator()(const RpcContext &rpc_ctx, ...@@ -48,8 +47,7 @@ void ParameterRecv<T>::operator()(const RpcContext &rpc_ctx,
auto &cpu_ctx = *pool.Get(platform::CPUPlace()); auto &cpu_ctx = *pool.Get(platform::CPUPlace());
distributed::RPCClient *rpc_client = distributed::RPCClient *rpc_client =
distributed::RPCClient::GetInstance<RPCCLIENT_T>( distributed::RPCClient::GetInstance<RPCCLIENT_T>(0);
ctx.Attr<int>("trainer_id"));
auto *recv_var = scope.FindVar(rpc_ctx.var_name); auto *recv_var = scope.FindVar(rpc_ctx.var_name);
...@@ -80,12 +78,13 @@ void ParameterRecv<T>::operator()(const RpcContext &rpc_ctx, ...@@ -80,12 +78,13 @@ void ParameterRecv<T>::operator()(const RpcContext &rpc_ctx,
size_t output_offset = 0; size_t output_offset = 0;
framework::Tensor *recv_tensor = framework::Tensor *recv_tensor =
recv_var->GetMutable<framework::LoDTensor>(); recv_var->GetMutable<framework::LoDTensor>();
auto dev_ctx = paddle::platform::CPUDeviceContext();
for (auto *in : recved_tensors) { for (auto *in : recved_tensors) {
auto in_stride = framework::stride_numel(in->dims()); auto in_stride = framework::stride_numel(in->dims());
auto out_stride = framework::stride_numel(recv_tensor->dims()); auto out_stride = framework::stride_numel(recv_tensor->dims());
StridedNumelCopyWithAxis<T>( StridedNumelCopyWithAxis<T>(
ctx.device_context(), 0, recv_tensor->data<T>() + output_offset, dev_ctx, 0, recv_tensor->data<T>() + output_offset, out_stride,
out_stride, in->data<T>(), in_stride, in_stride[0]); in->data<T>(), in_stride, in_stride[0]);
output_offset += in_stride[0]; output_offset += in_stride[0];
} }
} }
......
...@@ -26,9 +26,7 @@ namespace distributed { ...@@ -26,9 +26,7 @@ namespace distributed {
template <typename T> template <typename T>
struct ParameterRecv { struct ParameterRecv {
void operator()(const RpcContext &rpc_ctx, void operator()(const RpcContext &rpc_ctx, const framework::Scope &scope);
const framework::ExecutionContext &context,
const framework::Scope &scope);
}; };
}; // namespace distributed }; // namespace distributed
......
...@@ -39,7 +39,6 @@ using DDim = framework::DDim; ...@@ -39,7 +39,6 @@ using DDim = framework::DDim;
template <typename T> template <typename T>
void ParameterSend<T>::operator()(const RpcContext &rpc_ctx, void ParameterSend<T>::operator()(const RpcContext &rpc_ctx,
const framework::ExecutionContext &ctx,
const framework::Scope &scope, bool sync) { const framework::Scope &scope, bool sync) {
framework::Scope *local_scope = scope.NewTmpScope(); framework::Scope *local_scope = scope.NewTmpScope();
...@@ -47,8 +46,7 @@ void ParameterSend<T>::operator()(const RpcContext &rpc_ctx, ...@@ -47,8 +46,7 @@ void ParameterSend<T>::operator()(const RpcContext &rpc_ctx,
auto &cpu_ctx = *pool.Get(platform::CPUPlace()); auto &cpu_ctx = *pool.Get(platform::CPUPlace());
distributed::RPCClient *rpc_client = distributed::RPCClient *rpc_client =
distributed::RPCClient::GetInstance<RPCCLIENT_T>( distributed::RPCClient::GetInstance<RPCCLIENT_T>(0);
ctx.Attr<int>("trainer_id"));
auto *send_var = scope.FindVar(rpc_ctx.var_name); auto *send_var = scope.FindVar(rpc_ctx.var_name);
size_t out_num = rpc_ctx.splited_var_names.size(); size_t out_num = rpc_ctx.splited_var_names.size();
...@@ -105,7 +103,7 @@ void ParameterSend<T>::operator()(const RpcContext &rpc_ctx, ...@@ -105,7 +103,7 @@ void ParameterSend<T>::operator()(const RpcContext &rpc_ctx,
outs_rows_idx[out_idx].push_back(send_rows[i]); outs_rows_idx[out_idx].push_back(send_rows[i]);
outs_dense_idx[out_idx].push_back(i); outs_dense_idx[out_idx].push_back(i);
} }
auto place = ctx.GetPlace(); auto place = platform::CPUPlace();
for (size_t i = 0; i < outs_rows_idx.size(); ++i) { for (size_t i = 0; i < outs_rows_idx.size(); ++i) {
auto rows_idx = outs_rows_idx[i]; auto rows_idx = outs_rows_idx[i];
...@@ -118,22 +116,25 @@ void ParameterSend<T>::operator()(const RpcContext &rpc_ctx, ...@@ -118,22 +116,25 @@ void ParameterSend<T>::operator()(const RpcContext &rpc_ctx,
for (auto idx : rows_idx) { for (auto idx : rows_idx) {
outs[i]->mutable_rows()->push_back(idx - abs_sections[i]); outs[i]->mutable_rows()->push_back(idx - abs_sections[i]);
} }
auto dst = outs[i]->mutable_value()->mutable_data<T>(ctx.GetPlace()); auto dst = outs[i]->mutable_value()->mutable_data<T>(place);
for (size_t j = 0; j < rows_idx.size(); j++) { for (size_t j = 0; j < rows_idx.size(); j++) {
if (platform::is_cpu_place(place)) { if (platform::is_cpu_place(place)) {
memory::Copy( memory::Copy(
platform::CPUPlace(), dst + j * row_numel, platform::CPUPlace(), platform::CPUPlace(), dst + j * row_numel, platform::CPUPlace(),
src + outs_dense_idx[i][j] * row_numel, sizeof(T) * row_numel); src + outs_dense_idx[i][j] * row_numel, sizeof(T) * row_numel);
} else { } else {
#ifdef PADDLE_WITH_CUDA PADDLE_THROW("do not support GPU now");
/*
#ifdef PADDLE_WITH_CUDA
auto stream = ctx.cuda_device_context().stream(); auto stream = ctx.cuda_device_context().stream();
memory::Copy(platform::CUDAPlace(), dst + j * row_numel, memory::Copy(platform::CUDAPlace(), dst + j * row_numel,
platform::CUDAPlace(), platform::CUDAPlace(),
src + outs_dense_idx[i][j] * row_numel, src + outs_dense_idx[i][j] * row_numel,
sizeof(T) * row_numel, stream); sizeof(T) * row_numel, stream);
#else #else
PADDLE_THROW("Paddle is not compiled with GPU"); PADDLE_THROW("Paddle is not compiled with GPU");
#endif #endif
*/
} }
} }
} }
......
...@@ -26,9 +26,8 @@ namespace distributed { ...@@ -26,9 +26,8 @@ namespace distributed {
template <typename T> template <typename T>
struct ParameterSend { struct ParameterSend {
void operator()(const RpcContext &rpc_ctx, void operator()(const RpcContext &rpc_ctx, const framework::Scope &scope,
const framework::ExecutionContext &context, bool sync);
const framework::Scope &scope, bool sync);
}; };
}; // namespace distributed }; // namespace distributed
......
...@@ -62,7 +62,7 @@ class RecvOp : public framework::OperatorBase { ...@@ -62,7 +62,7 @@ class RecvOp : public framework::OperatorBase {
framework::ExecutionContext(*this, scope, *dev_ctx, ctx, nullptr); framework::ExecutionContext(*this, scope, *dev_ctx, ctx, nullptr);
auto recv_functor = distributed::ParameterRecv<float>(); auto recv_functor = distributed::ParameterRecv<float>();
auto rpc_ctx = distributed::RpcContext(outs[0], recv_varnames, epmap, {}); auto rpc_ctx = distributed::RpcContext(outs[0], recv_varnames, epmap, {});
recv_functor(rpc_ctx, exe_ctx, scope); recv_functor(rpc_ctx, scope);
} else { } else {
if (with_barrier) { if (with_barrier) {
std::vector<distributed::VarHandlePtr> rets; std::vector<distributed::VarHandlePtr> rets;
......
...@@ -56,7 +56,7 @@ class SendOp : public framework::OperatorBase { ...@@ -56,7 +56,7 @@ class SendOp : public framework::OperatorBase {
auto send_functor = distributed::ParameterSend<float>(); auto send_functor = distributed::ParameterSend<float>();
auto rpc_ctx = distributed::RpcContext(ins[0], send_varnames, epmap, auto rpc_ctx = distributed::RpcContext(ins[0], send_varnames, epmap,
height_sections); height_sections);
send_functor(rpc_ctx, exe_ctx, scope, static_cast<bool>(sync_send)); send_functor(rpc_ctx, scope, static_cast<bool>(sync_send));
} else { } else {
platform::DeviceContextPool& pool = platform::DeviceContextPool& pool =
platform::DeviceContextPool::Instance(); platform::DeviceContextPool::Instance();
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册