diff --git a/paddle/fluid/operators/assign_op.cc b/paddle/fluid/operators/assign_op.cc
index d9294048a9e89662958fd5c6af4fcbe5da3814c2..871dfe67343dbe58296095e1df8668e0ebd11d45 100644
--- a/paddle/fluid/operators/assign_op.cc
+++ b/paddle/fluid/operators/assign_op.cc
@@ -66,27 +66,47 @@ class AssignFunctor {
   const platform::DeviceContext &dev_ctx_;
 };
 
-class AssignOp : public framework::OperatorBase {
+class AssignOp : public framework::OperatorWithKernel {
  public:
   AssignOp(const std::string &type, const framework::VariableNameMap &inputs,
            const framework::VariableNameMap &outputs,
            const framework::AttributeMap &attrs)
-      : OperatorBase(type, inputs, outputs, attrs) {}
+      : OperatorWithKernel(type, inputs, outputs, attrs) {}
 
- private:
-  void RunImpl(const framework::Scope &scope,
-               const platform::Place &place) const override {
-    auto *x = scope.FindVar(Input("X"));
+  void InferShape(framework::InferShapeContext *ctx) const override {
+    if (ctx->HasInput("X")) {
+      auto type = ctx->GetInputsVarType("X")[0];
+      if (type == framework::proto::VarType::SELECTED_ROWS ||
+          type == framework::proto::VarType::LOD_TENSOR) {
+        ctx->SetOutputDim("Out", ctx->GetInputDim("X"));
+        if (type == framework::proto::VarType::LOD_TENSOR) {
+          ctx->ShareLoD("X", /*->*/ "Out");
+        }
+      }
+    }
+  }
+
+ protected:
+  framework::OpKernelType GetExpectedKernelType(
+      const framework::ExecutionContext &ctx) const override {
+    return framework::OpKernelType(ctx.Input<framework::LoDTensor>("X")->type(),
+                                   ctx.device_context());
+  }
+};
+
+class AssignKernel {
+ public:
+  void operator()(const framework::ExecutionContext &ctx) const {
+    auto *x = ctx.InputVar("X");
     if (x == nullptr) {
       return;
     }
-    auto *out = scope.FindVar(Output("Out"));
+    auto *out = ctx.OutputVar("Out");
     PADDLE_ENFORCE(
         out != nullptr,
         "The Output(Out) should not be null if the Input(X) is set.");
-
     platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
-    auto &dev_ctx = *pool.Get(place);
+    auto &dev_ctx = *pool.Get(ctx.GetPlace());
 
     framework::VisitVarType(*x, AssignFunctor(out, dev_ctx));
   }
@@ -110,19 +130,6 @@ raise error if the type is not listed above.
   }
 };
 
-class AssignInferShape : public framework::InferShapeBase {
- public:
-  void operator()(framework::InferShapeContext *context) const override {
-    if (context->HasInput("X")) {
-      auto type = context->GetInputsVarType("X")[0];
-      if (type == framework::proto::VarType::SELECTED_ROWS ||
-          type == framework::proto::VarType::LOD_TENSOR) {
-        context->SetOutputDim("Out", context->GetInputDim("X"));
-      }
-    }
-  }
-};
-
 class AssignGradMaker : public framework::SingleGradOpDescMaker {
  public:
   using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;
@@ -142,4 +149,13 @@ class AssignGradMaker : public framework::SingleGradOpDescMaker {
 namespace ops = paddle::operators;
 
 REGISTER_OPERATOR(assign, ops::AssignOp, ops::AssignGradMaker,
-                  ops::AssignInferShape, ops::AssignOpProtoMaker);
+                  ops::AssignOpProtoMaker);
+REGISTER_OP_CPU_KERNEL_FUNCTOR(assign, float, ops::AssignKernel, double,
+                               ops::AssignKernel, int, ops::AssignKernel,
+                               int64_t, ops::AssignKernel);
+
+#ifdef PADDLE_WITH_CUDA
+REGISTER_OP_CUDA_KERNEL_FUNCTOR(assign, float, ops::AssignKernel, double,
+                                ops::AssignKernel, int, ops::AssignKernel,
+                                int64_t, ops::AssignKernel);
+#endif
diff --git a/paddle/fluid/pybind/pybind.cc b/paddle/fluid/pybind/pybind.cc
index fb0882e3533252460e3dd2546e9af8d50f053db6..7e37b3c68da855b8358445e504e855e76c3364be 100644
--- a/paddle/fluid/pybind/pybind.cc
+++ b/paddle/fluid/pybind/pybind.cc
@@ -61,13 +61,13 @@ limitations under the License. */
 #ifndef _WIN32
 #include "paddle/fluid/pybind/nccl_wrapper_py.h"
 #endif
+#include "paddle/fluid/framework/data_type.h"
 #include "paddle/fluid/pybind/protobuf.h"
 #include "paddle/fluid/pybind/pybind.h"  // NOLINT
 #include "paddle/fluid/pybind/reader_py.h"
 #include "paddle/fluid/pybind/recordio.h"
 #include "paddle/fluid/pybind/tensor_py.h"
 #include "paddle/fluid/string/to_string.h"
-
 #ifdef PADDLE_WITH_CUDA
 #ifndef _WIN32
 #include "paddle/fluid/operators/nccl/nccl_gpu_common.h"
@@ -1106,6 +1106,8 @@ All parameter, weight, gradient are variables in Paddle.
         return std::shared_ptr<ir::Pass>(std::move(pass));
       });
 
+  m.def("size_of_dtype", framework::SizeOfType);
+
   py::class_<ir::Pass, std::shared_ptr<ir::Pass>> pass(m, "Pass");
   pass.def(py::init())
       .def("has", &ir::Pass::Has)
diff --git a/python/paddle/fluid/dygraph/parallel.py b/python/paddle/fluid/dygraph/parallel.py
index 37716cea14c016fca055790d4bbe65c37f058839..c17cfc73de7b5767f842701aba62cf9b29ecd156 100644
--- a/python/paddle/fluid/dygraph/parallel.py
+++ b/python/paddle/fluid/dygraph/parallel.py
@@ -14,7 +14,7 @@
 import os
 import six
 import numpy as np
-
+from collections import OrderedDict
 from .. import core
 from . import layers
 from . import parallel_helper
@@ -36,7 +36,7 @@ def prepare_context(strategy=None):
         strategy.current_endpoint = Env().current_endpoint
     if strategy.nranks < 2:
         return
-    assert framework.in_dygraph_mode() is True,\
+    assert framework.in_dygraph_mode() is True, \
         "dygraph.parallel.prepare_context should be used with dygrahp mode."
     place = framework._current_expected_place()
     assert place is not None, \
@@ -168,6 +168,37 @@ class DataParallel(layers.Layer):
         loss = loss / loss_scale
         return loss
 
+    def _coalesce_tensors(self, var_groups):
+        from ..layers import nn
+        coalesced_grads_and_grad_vars = []
+        for group_id, grad_vars in var_groups.items():
+            flattened_vars = []
+            g_var_shapes = []
+            for g_var in grad_vars:
+                g_var_shapes.append(g_var.shape)
+                flattened_vars.append(
+                    nn.reshape(
+                        x=g_var, shape=[np.prod(g_var.shape)], inplace=True))
+            coalesced_grad = nn.concat(flattened_vars)
+            coalesced_grads_and_grad_vars.append(
+                [coalesced_grad, grad_vars, g_var_shapes])
+        return coalesced_grads_and_grad_vars
+
+    def _split_tensors(self, coalesced_grads_and_grad_vars):
+        from ..layers import nn
+        for coalesced_grad, origin_grad_vars, grad_shapes in coalesced_grads_and_grad_vars:
+            grad_var_len = [np.prod(g_shape) for g_shape in grad_shapes]
+            splited_vars = nn.split(
+                coalesced_grad, num_or_sections=grad_var_len, dim=0)
+            reshaped_grad_vars = []
+            for g_var, g_shape in zip(splited_vars, grad_shapes):
+                reshaped_grad_vars.append(
+                    nn.reshape(
+                        x=g_var, shape=g_shape, inplace=True))
+            for origin_g_var, reshaped_g_var in zip(origin_grad_vars,
+                                                    reshaped_grad_vars):
+                nn.assign(input=reshaped_g_var, output=origin_g_var)
+
     def apply_collective_grads(self):
         """
         AllReduce the Parameters' gradient.
@@ -175,6 +206,8 @@ class DataParallel(layers.Layer):
         if not self._is_data_parallel_mode():
             return
 
+        grad_var_set = set()
+        grad_vars = []
         for param in self._layers.parameters():
             # NOTE(zcd): The grad_ivar maybe no generated.
             if param.trainable and param._ivar._grad_ivar():
@@ -183,7 +216,36 @@ class DataParallel(layers.Layer):
                     name=param._ivar._grad_name(),
                     stop_gradient=True,
                     ivar=param._ivar._grad_ivar())
-                collective._allreduce(g_var, g_var, sync_mode=True)
+                grad_vars.append(g_var)
+                assert g_var not in grad_var_set
+                grad_var_set.add(g_var)
+
+        # FIXME(zcd): the type of the var should be LoDTensor, i.e
+        # the gradients should be dense, otherwise, the following
+        # logic should be updated.
+        # 128 MB as a group
+        mega_bytes = 128 * 1024 * 1024
+        group_idx = 0
+        memory_counter = 0
+        grad_var_groups = OrderedDict()
+        dtype = grad_vars[0].dtype
+        for g_var in grad_vars:
+            # Note: the dtype of the same group should be the same.
+            bytes = np.prod(g_var.shape) * core.size_of_dtype(g_var.dtype)
+            if memory_counter < mega_bytes and dtype == g_var.dtype:
+                memory_counter += bytes
+            else:
+                memory_counter = bytes
+                group_idx += 1
+            grad_var_groups.setdefault(group_idx, []).append(g_var)
+
+        coalesced_grads_and_vars = self._coalesce_tensors(grad_var_groups)
+
+        for coalesced_grad, g_vars, g_shapes in coalesced_grads_and_vars:
+            collective._allreduce(
+                coalesced_grad, coalesced_grad, sync_mode=False)
+
+        self._split_tensors(coalesced_grads_and_vars)
 
     def _is_data_parallel_mode(self):
         return self._strategy.nranks > 1
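
For reference, a hedged usage sketch of the new core.size_of_dtype binding exposed in pybind.cc. It is assumed here to take a core.VarDesc.VarType dtype enum (as returned by a Variable's .dtype) and return the element width in bytes, which is what apply_collective_grads relies on when estimating each gradient's memory footprint; the shape and itemsize names are illustrative and not part of the patch.

import numpy as np
from paddle.fluid import core

shape = [1024, 1024]
# Assumed semantics: FP32 maps to 4 bytes per element.
itemsize = core.size_of_dtype(core.VarDesc.VarType.FP32)
print(np.prod(shape) * itemsize)  # ~4 MB for one [1024, 1024] FP32 gradient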
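
The grouping loop added to apply_collective_grads packs gradients into buckets of roughly 128 MB, starting a new bucket when the byte budget is exceeded or when a gradient's dtype differs from the first gradient's dtype. Below is a minimal numpy-only sketch of that policy; group_grads is an illustrative helper, not part of the fluid API.

from collections import OrderedDict

import numpy as np


def group_grads(grads, group_bytes=128 * 1024 * 1024):
    groups = OrderedDict()
    group_idx = 0
    memory_counter = 0
    dtype = grads[0].dtype
    for g in grads:
        nbytes = g.size * g.dtype.itemsize
        if memory_counter < group_bytes and dtype == g.dtype:
            memory_counter += nbytes
        else:
            # Start a new group when the byte budget is exceeded or the
            # dtype no longer matches the first gradient's dtype.
            memory_counter = nbytes
            group_idx += 1
        groups.setdefault(group_idx, []).append(g)
    return groups


# 64 gradients of 4 MB each (256 MB total) split into two 128 MB groups.
grads = [np.zeros((1024, 1024), dtype=np.float32) for _ in range(64)]
print([len(v) for v in group_grads(grads).values()])  # -> [32, 32]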
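
And a numpy-only sketch of the round trip that _coalesce_tensors / _split_tensors perform per bucket: flatten and concatenate the gradients so a single allreduce covers the whole bucket, then split and reshape the buffer back onto the original gradients. The in-place multiply stands in for the one collective._allreduce call, and coalesce/split are illustrative names rather than the fluid API.

import numpy as np


def coalesce(grads):
    # Record shapes, flatten each gradient, and pack them into one buffer.
    shapes = [g.shape for g in grads]
    flat = np.concatenate([g.reshape(-1) for g in grads])
    return flat, shapes


def split(flat, shapes):
    # Cut the buffer back into per-gradient chunks and restore the shapes.
    sizes = [int(np.prod(s)) for s in shapes]
    chunks = np.split(flat, np.cumsum(sizes)[:-1])
    return [c.reshape(s) for c, s in zip(chunks, shapes)]


grads = [np.random.rand(3, 4), np.random.rand(5), np.random.rand(2, 2, 2)]
flat, shapes = coalesce(grads)
flat *= 2.0  # stands in for the single allreduce over the coalesced buffer
for g, new_g in zip(grads, split(flat, shapes)):
    assert np.allclose(new_g, g * 2.0)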