Unverified · Commit 20859c08 · Author: chengduo · Committer: GitHub

[DyGraph] Make multi-card program faster (#18892)

* update parallel.py
test=develop
Parent 24f85431
@@ -66,27 +66,47 @@ class AssignFunctor {
   const platform::DeviceContext &dev_ctx_;
 };
 
-class AssignOp : public framework::OperatorBase {
+class AssignOp : public framework::OperatorWithKernel {
  public:
   AssignOp(const std::string &type, const framework::VariableNameMap &inputs,
            const framework::VariableNameMap &outputs,
            const framework::AttributeMap &attrs)
-      : OperatorBase(type, inputs, outputs, attrs) {}
+      : OperatorWithKernel(type, inputs, outputs, attrs) {}
 
- private:
-  void RunImpl(const framework::Scope &scope,
-               const platform::Place &place) const override {
-    auto *x = scope.FindVar(Input("X"));
+  void InferShape(framework::InferShapeContext *ctx) const override {
+    if (ctx->HasInput("X")) {
+      auto type = ctx->GetInputsVarType("X")[0];
+      if (type == framework::proto::VarType::SELECTED_ROWS ||
+          type == framework::proto::VarType::LOD_TENSOR) {
+        ctx->SetOutputDim("Out", ctx->GetInputDim("X"));
+        if (type == framework::proto::VarType::LOD_TENSOR) {
+          ctx->ShareLoD("X", /*->*/ "Out");
+        }
+      }
+    }
+  }
+
+ protected:
+  framework::OpKernelType GetExpectedKernelType(
+      const framework::ExecutionContext &ctx) const override {
+    return framework::OpKernelType(ctx.Input<framework::LoDTensor>("X")->type(),
+                                   ctx.device_context());
+  }
+};
+
+class AssignKernel {
+ public:
+  void operator()(const framework::ExecutionContext &ctx) const {
+    auto *x = ctx.InputVar("X");
     if (x == nullptr) {
       return;
     }
-    auto *out = scope.FindVar(Output("Out"));
+    auto *out = ctx.OutputVar("Out");
     PADDLE_ENFORCE(
         out != nullptr,
         "The Output(Out) should not be null if the Input(X) is set.");
     platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
-    auto &dev_ctx = *pool.Get(place);
+    auto &dev_ctx = *pool.Get(ctx.GetPlace());
     framework::VisitVarType(*x, AssignFunctor(out, dev_ctx));
   }
@@ -110,19 +130,6 @@ raise error if the type is not listed above.
   }
 };
 
-class AssignInferShape : public framework::InferShapeBase {
- public:
-  void operator()(framework::InferShapeContext *context) const override {
-    if (context->HasInput("X")) {
-      auto type = context->GetInputsVarType("X")[0];
-      if (type == framework::proto::VarType::SELECTED_ROWS ||
-          type == framework::proto::VarType::LOD_TENSOR) {
-        context->SetOutputDim("Out", context->GetInputDim("X"));
-      }
-    }
-  }
-};
-
 class AssignGradMaker : public framework::SingleGradOpDescMaker {
  public:
   using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;
@@ -142,4 +149,13 @@ class AssignGradMaker : public framework::SingleGradOpDescMaker {
 namespace ops = paddle::operators;
 REGISTER_OPERATOR(assign, ops::AssignOp, ops::AssignGradMaker,
-                  ops::AssignInferShape, ops::AssignOpProtoMaker);
+                  ops::AssignOpProtoMaker);
+
+REGISTER_OP_CPU_KERNEL_FUNCTOR(assign, float, ops::AssignKernel, double,
+                               ops::AssignKernel, int, ops::AssignKernel,
+                               int64_t, ops::AssignKernel);
+
+#ifdef PADDLE_WITH_CUDA
+REGISTER_OP_CUDA_KERNEL_FUNCTOR(assign, float, ops::AssignKernel, double,
+                                ops::AssignKernel, int, ops::AssignKernel,
+                                int64_t, ops::AssignKernel);
+#endif
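
The net effect of the assign_op.cc change: `assign` moves from an eagerly-run `OperatorBase` to a regular `OperatorWithKernel`. Shape and LoD inference are folded into `AssignOp::InferShape` (replacing the deleted `AssignInferShape` functor, and adding `ShareLoD`), and the copy itself becomes a per-dtype kernel registered through `REGISTER_OP_CPU_KERNEL_FUNCTOR` / `REGISTER_OP_CUDA_KERNEL_FUNCTOR`. The Python-facing API should be unaffected; the sketch below is a minimal, hypothetical usage through the 1.5-era fluid API (not part of this commit), just to show what the new kernel serves.

```python
# Minimal sketch (assumed 1.5-era fluid API, not from this commit): the assign
# op copies X into Out, with Out's shape and LoD inferred from X.
import numpy as np
import paddle.fluid as fluid

x = fluid.layers.data(name='x', shape=[2, 3], dtype='float32',
                      append_batch_size=False)
out = fluid.layers.create_tensor(dtype='float32')
fluid.layers.assign(x, output=out)  # dispatches to ops::AssignKernel for float

exe = fluid.Executor(fluid.CPUPlace())
x_np = np.arange(6, dtype='float32').reshape(2, 3)
res, = exe.run(feed={'x': x_np}, fetch_list=[out])
np.testing.assert_allclose(res, x_np)
```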
@@ -61,13 +61,13 @@ limitations under the License. */
 #ifndef _WIN32
 #include "paddle/fluid/pybind/nccl_wrapper_py.h"
 #endif
+#include "paddle/fluid/framework/data_type.h"
 #include "paddle/fluid/pybind/protobuf.h"
 #include "paddle/fluid/pybind/pybind.h"  // NOLINT
 #include "paddle/fluid/pybind/reader_py.h"
 #include "paddle/fluid/pybind/recordio.h"
 #include "paddle/fluid/pybind/tensor_py.h"
 #include "paddle/fluid/string/to_string.h"
 #ifdef PADDLE_WITH_CUDA
 #ifndef _WIN32
 #include "paddle/fluid/operators/nccl/nccl_gpu_common.h"
@@ -1106,6 +1106,8 @@ All parameter, weight, gradient are variables in Paddle.
         return std::shared_ptr<framework::ir::Pass>(std::move(pass));
       });
 
+  m.def("size_of_dtype", framework::SizeOfType);
+
   py::class_<ir::Pass, std::shared_ptr<ir::Pass>> pass(m, "Pass");
   pass.def(py::init())
       .def("has", &ir::Pass::Has)
......
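
The new `size_of_dtype` binding simply exposes `framework::SizeOfType` to Python; `parallel.py` below uses it to measure each gradient in bytes when building the roughly 128 MB coalescing groups. A rough sketch of the intended call (the 4-byte FP32 element size is an assumption, not taken from this diff):

```python
# Hedged sketch: core.size_of_dtype maps a VarType to its element size in bytes.
import numpy as np
from paddle.fluid import core

elem_bytes = core.size_of_dtype(core.VarDesc.VarType.FP32)  # expected: 4
grad_shape = [1024, 1024]
# bytes this gradient would contribute to one coalescing group
print(int(np.prod(grad_shape)) * elem_bytes)
```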
@@ -14,7 +14,7 @@
 import os
 import six
 import numpy as np
+from collections import OrderedDict
 
 from .. import core
 from . import layers
 from . import parallel_helper
@@ -36,7 +36,7 @@ def prepare_context(strategy=None):
     strategy.current_endpoint = Env().current_endpoint
     if strategy.nranks < 2:
         return
-    assert framework.in_dygraph_mode() is True,\
+    assert framework.in_dygraph_mode() is True, \
         "dygraph.parallel.prepare_context should be used with dygrahp mode."
     place = framework._current_expected_place()
     assert place is not None, \
@@ -168,6 +168,37 @@ class DataParallel(layers.Layer):
         loss = loss / loss_scale
         return loss
 
+    def _coalesce_tensors(self, var_groups):
+        from ..layers import nn
+        coalesced_grads_and_grad_vars = []
+        for group_id, grad_vars in var_groups.items():
+            flattened_vars = []
+            g_var_shapes = []
+            for g_var in grad_vars:
+                g_var_shapes.append(g_var.shape)
+                flattened_vars.append(
+                    nn.reshape(
+                        x=g_var, shape=[np.prod(g_var.shape)], inplace=True))
+            coalesced_grad = nn.concat(flattened_vars)
+            coalesced_grads_and_grad_vars.append(
+                [coalesced_grad, grad_vars, g_var_shapes])
+        return coalesced_grads_and_grad_vars
+
+    def _split_tensors(self, coalesced_grads_and_grad_vars):
+        from ..layers import nn
+        for coalesced_grad, origin_grad_vars, grad_shapes in coalesced_grads_and_grad_vars:
+            grad_var_len = [np.prod(g_shape) for g_shape in grad_shapes]
+            splited_vars = nn.split(
+                coalesced_grad, num_or_sections=grad_var_len, dim=0)
+            reshaped_grad_vars = []
+            for g_var, g_shape in zip(splited_vars, grad_shapes):
+                reshaped_grad_vars.append(
+                    nn.reshape(
+                        x=g_var, shape=g_shape, inplace=True))
+            for origin_g_var, reshaped_g_var in zip(origin_grad_vars,
+                                                    reshaped_grad_vars):
+                nn.assign(input=reshaped_g_var, output=origin_g_var)
+
     def apply_collective_grads(self):
         """
         AllReduce the Parameters' gradient.
@@ -175,6 +206,8 @@ class DataParallel(layers.Layer):
         if not self._is_data_parallel_mode():
             return
 
+        grad_var_set = set()
+        grad_vars = []
         for param in self._layers.parameters():
             # NOTE(zcd): The grad_ivar maybe no generated.
             if param.trainable and param._ivar._grad_ivar():
@@ -183,7 +216,36 @@ class DataParallel(layers.Layer):
                     name=param._ivar._grad_name(),
                     stop_gradient=True,
                     ivar=param._ivar._grad_ivar())
-                collective._allreduce(g_var, g_var, sync_mode=True)
+                grad_vars.append(g_var)
+                assert g_var not in grad_var_set
+                grad_var_set.add(g_var)
+
+        # FIXME(zcd): the type of the var should be LoDTensor, i.e
+        # the gradients should be dense, otherwise, the following
+        # logic should be updated.
+        # 128 MB as a group
+        mega_bytes = 128 * 1024 * 1024
+        group_idx = 0
+        memory_counter = 0
+        grad_var_groups = OrderedDict()
+        dtype = grad_vars[0].dtype
+        for g_var in grad_vars:
+            # Note: the dtype of the same group should be the same.
+            bytes = np.prod(g_var.shape) * core.size_of_dtype(g_var.dtype)
+            if memory_counter < mega_bytes and dtype == g_var.dtype:
+                memory_counter += bytes
+            else:
+                memory_counter = bytes
+                group_idx += 1
+            grad_var_groups.setdefault(group_idx, []).append(g_var)
+
+        coalesced_grads_and_vars = self._coalesce_tensors(grad_var_groups)
+
+        for coalesced_grad, g_vars, g_shapes in coalesced_grads_and_vars:
+            collective._allreduce(
+                coalesced_grad, coalesced_grad, sync_mode=False)
+
+        self._split_tensors(coalesced_grads_and_vars)
 
     def _is_data_parallel_mode(self):
         return self._strategy.nranks > 1
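
Taken together, `apply_collective_grads` now (1) collects every generated dense gradient, (2) buckets the gradients into same-dtype groups of roughly 128 MB, (3) flattens and concatenates each bucket into one coalesced tensor, (4) issues a single allreduce per bucket with `sync_mode=False`, and (5) splits, reshapes, and assigns the results back into the original gradient variables. Replacing one synchronous allreduce per parameter with a few large, non-blocking ones is the intended source of the multi-card speedup. The bucketing rule itself is plain bookkeeping; below is a standalone sketch of the same logic using a hypothetical `FakeGrad` stand-in (NumPy only, no Paddle required), where `elem_bytes` plays the role of `core.size_of_dtype`.

```python
# Standalone sketch of the ~128 MB, same-dtype bucketing used in
# apply_collective_grads. FakeGrad is a hypothetical stand-in for the
# dygraph gradient variables.
from collections import OrderedDict, namedtuple
import numpy as np

FakeGrad = namedtuple('FakeGrad', ['shape', 'dtype', 'elem_bytes'])


def group_grads(grad_vars, mega_bytes=128 * 1024 * 1024):
    groups = OrderedDict()
    group_idx, memory_counter = 0, 0
    dtype = grad_vars[0].dtype  # the dtype within one group must not change
    for g in grad_vars:
        nbytes = np.prod(g.shape) * g.elem_bytes
        if memory_counter < mega_bytes and dtype == g.dtype:
            memory_counter += nbytes
        else:
            # size overflow (or dtype change): start a new group
            memory_counter = nbytes
            group_idx += 1
        groups.setdefault(group_idx, []).append(g)
    return groups


# 600 gradients of 4 MB each -> buckets of ~32 tensors (~128 MB per bucket)
grads = [FakeGrad(shape=(1024, 1024), dtype='float32', elem_bytes=4)
         for _ in range(600)]
print({idx: len(vs) for idx, vs in group_grads(grads).items()})
```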