Unverified commit e6fb6599, authored by Haohongxiang, committed by GitHub

[Dygraph] Refactor Model Parallel in eager mode (#41761)

* refactor mp in eager mode

* update

* update

* add uts
Parent: ff818c77
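
This commit routes the model-parallel collective kernels through the eager-mode ProcessGroup API whenever a ProcessGroup is registered for the ring id, and falls back to the legacy ring-id/NCCL path otherwise. A minimal user-level sketch of the eager path these kernels serve, assuming two GPUs launched with `python -m paddle.distributed.launch`:

```python
# Sketch only: in eager mode, collectives issued on a communication group are
# expected to go through the ProcessGroup bound to that group, which is the
# branch the C++ kernels in this diff check first via ProcessGroupMapFromGid.
import paddle
import paddle.distributed as dist

dist.init_parallel_env()
group = dist.new_group(ranks=[0, 1])

x = paddle.to_tensor([float(dist.get_rank()) + 1.0])
dist.all_reduce(x, op=dist.ReduceOp.SUM, group=group)
print(x.numpy())  # [3.] on both ranks
```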
......@@ -27,8 +27,10 @@ namespace cub = hipcub;
#include <iterator>
#include <random>
#include "paddle/fluid/operators/class_center_sample_op.h"
#include "paddle/phi/api/include/tensor.h"
#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
#include "paddle/fluid/distributed/collective/ProcessGroup.h"
#include "paddle/fluid/platform/collective_helper.h"
#include "paddle/fluid/platform/device/gpu/nccl_helper.h"
#endif
......@@ -328,19 +330,34 @@ class ClassCenterSampleCUDAKernel : public framework::OpKernel<T> {
#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
if (nranks > 1) {
const auto& comm =
platform::NCCLCommContext::Instance().Get(rid, ctx.GetPlace());
// use global calculate stream
const auto calcu_stream =
static_cast<platform::CUDADeviceContext*>(
platform::DeviceContextPool::Instance().Get(ctx.GetPlace()))
->stream();
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclAllReduce(
num_classes_per_device_ptr, num_classes_per_device_ptr,
num_classes_per_device.numel(),
platform::ToNCCLDataType(
framework::TransToProtoVarType(num_classes_per_device.dtype())),
ncclSum, comm->comm(), calcu_stream));
auto map = distributed::ProcessGroupMapFromGid::getInstance();
if (map->has(rid)) {
// Use ProcessGroup
distributed::ProcessGroup* pg = map->get(rid);
std::vector<phi::DenseTensor> in_tensor;
std::vector<phi::DenseTensor> out_tensor;
in_tensor.push_back(num_classes_per_device);
out_tensor.push_back(num_classes_per_device);
distributed::AllreduceOptions opts;
opts.reduce_op = distributed::ReduceOp::SUM;
auto task = pg->AllReduce(in_tensor, out_tensor, opts);
task->Wait();
} else {
const auto& comm =
platform::NCCLCommContext::Instance().Get(rid, ctx.GetPlace());
// use global calculate stream
const auto calcu_stream =
static_cast<platform::CUDADeviceContext*>(
platform::DeviceContextPool::Instance().Get(ctx.GetPlace()))
->stream();
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclAllReduce(
num_classes_per_device_ptr, num_classes_per_device_ptr,
num_classes_per_device.numel(),
platform::ToNCCLDataType(
framework::TransToProtoVarType(num_classes_per_device.dtype())),
ncclSum, comm->comm(), calcu_stream));
}
}
#endif
......
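
For reference, the operator patched above is exposed as `paddle.nn.functional.class_center_sample`; a single-card usage sketch follows (the multi-card case additionally all-reduces the per-rank class counts, now via the ProcessGroup when running in eager mode). Shapes and the `num_classes`/`num_samples` values are illustrative only:

```python
# Minimal single-card sketch; the nranks > 1 path is what the CUDA kernel
# above changes.
import paddle
import paddle.nn.functional as F

label = paddle.randint(low=0, high=20, shape=[16], dtype='int64')
remapped_label, sampled_classes = F.class_center_sample(
    label, num_classes=20, num_samples=6)
print(remapped_label.shape, sampled_classes.shape)
```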
......@@ -16,12 +16,14 @@ limitations under the License. */
#include <string>
#include "paddle/fluid/distributed/collective/ProcessGroup.h"
#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/memory/memcpy.h"
#include "paddle/fluid/memory/memory.h"
#include "paddle/fluid/platform/device/npu/npu_op_runner.h"
#include "paddle/phi/api/include/tensor.h"
#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL) || \
defined(PADDLE_WITH_ASCEND_CL) || defined(PADDLE_WITH_XPU_BKCL) || \
......@@ -351,6 +353,7 @@ class CAllReduceOpCUDAKernel : public framework::OpKernel<T> {
#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
auto in = ctx.Input<framework::Tensor>("X");
auto out = ctx.Output<framework::Tensor>("Out");
int rid = ctx.Attr<int>("ring_id");
auto place = ctx.GetPlace();
ncclDataType_t dtype =
......@@ -360,7 +363,43 @@ class CAllReduceOpCUDAKernel : public framework::OpKernel<T> {
out->Resize(in->dims());
void* recvbuff = out->mutable_data<T>(place);
int rid = ctx.Attr<int>("ring_id");
auto map = distributed::ProcessGroupMapFromGid::getInstance();
if (map->has(rid)) {
// Use ProcessGroup
distributed::ProcessGroup* pg = map->get(rid);
std::vector<phi::DenseTensor> in_tensor;
std::vector<phi::DenseTensor> out_tensor;
in_tensor.push_back(*in);
out_tensor.push_back(*out);
distributed::AllreduceOptions opts;
switch (red_type) {
case kRedSum:
opts.reduce_op = distributed::ReduceOp::SUM;
break;
case kRedMax:
opts.reduce_op = distributed::ReduceOp::MAX;
break;
case kRedMin:
opts.reduce_op = distributed::ReduceOp::MIN;
break;
case kRedProd:
opts.reduce_op = distributed::ReduceOp::PRODUCT;
break;
default:
PADDLE_THROW(platform::errors::InvalidArgument(
"Invalid reduce type: %d", red_type));
}
auto task = pg->AllReduce(in_tensor, out_tensor, opts);
task->Wait();
return;
}
auto comm = platform::NCCLCommContext::Instance().Get(rid, place);
gpuStream_t stream = nullptr;
......
......@@ -16,8 +16,10 @@ limitations under the License. */
#include "paddle/fluid/operators/collective/c_concat_op.h"
#include "paddle/fluid/operators/math/concat_and_split.h"
#include "paddle/phi/api/include/tensor.h"
#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
#include "paddle/fluid/distributed/collective/ProcessGroup.h"
#include "paddle/fluid/platform/collective_helper.h"
#include "paddle/fluid/platform/device/gpu/nccl_helper.h"
#endif
......@@ -55,26 +57,39 @@ class CConcatOpCUDAKernel : public framework::OpKernel<T> {
rank, nranks));
#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
auto comm = platform::NCCLCommContext::Instance().Get(rid, place);
PADDLE_ENFORCE_EQ(
nranks, comm->nranks(),
platform::errors::InvalidArgument("nranks: %s should equal to %s",
nranks, comm->nranks()));
framework::Tensor temp_out;
framework::DDim temp_out_dims = x->dims();
temp_out_dims[0] *= nranks;
temp_out.mutable_data<T>(temp_out_dims, place);
int64_t send_numel = x->numel();
const T* send_buff = x->data<T>();
T* recv_buff = temp_out.data<T>();
gpuStream_t stream = nullptr;
auto dev_ctx = platform::DeviceContextPool::Instance().Get(place);
stream = static_cast<platform::CUDADeviceContext*>(dev_ctx)->stream();
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclAllGather(
send_buff, recv_buff, send_numel, static_cast<ncclDataType_t>(dtype),
comm->comm(), stream));
auto map = distributed::ProcessGroupMapFromGid::getInstance();
if (map->has(rid)) {
// Use ProcessGroup
distributed::ProcessGroup* pg = map->get(rid);
std::vector<phi::DenseTensor> in_tensor;
std::vector<phi::DenseTensor> out_tensor;
in_tensor.push_back(*x);
out_tensor.push_back(temp_out);
auto task = pg->AllGather(in_tensor, out_tensor);
task->Wait();
} else {
auto comm = platform::NCCLCommContext::Instance().Get(rid, place);
PADDLE_ENFORCE_EQ(
nranks, comm->nranks(),
platform::errors::InvalidArgument("nranks: %s should equal to %s",
nranks, comm->nranks()));
int64_t send_numel = x->numel();
const T* send_buff = x->data<T>();
T* recv_buff = temp_out.data<T>();
gpuStream_t stream = nullptr;
auto dev_ctx = platform::DeviceContextPool::Instance().Get(place);
stream = static_cast<platform::CUDADeviceContext*>(dev_ctx)->stream();
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclAllGather(
send_buff, recv_buff, send_numel, static_cast<ncclDataType_t>(dtype),
comm->comm(), stream));
}
std::vector<framework::Tensor> inputs;
int axis = x->dims().size() - 1;
......
......@@ -18,6 +18,7 @@ limitations under the License. */
#include "paddle/fluid/platform/collective_helper.h"
#include "paddle/fluid/platform/device/gpu/nccl_helper.h"
#include "paddle/fluid/string/string_helper.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/kernels/funcs/axis_utils.h"
namespace paddle {
......@@ -73,6 +74,21 @@ template <typename T>
class CSoftmaxWithCrossEntropyOpCUDAKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
const int rid = ctx.Attr<int>("ring_id");
auto map = distributed::ProcessGroupMapFromGid::getInstance();
if (map->has(rid)) {
CSoftmaxWithCrossEntropyProcessGroupFunctor<phi::GPUContext, T> functor_;
functor_(ctx);
} else {
CSoftmaxWithCrossEntropyFunctor<phi::GPUContext, T> functor_;
functor_(ctx);
}
}
};
template <typename T>
struct CSoftmaxWithCrossEntropyFunctor<phi::GPUContext, T> {
void operator()(const framework::ExecutionContext& ctx) {
const Tensor* logits = ctx.Input<Tensor>("Logits");
const Tensor* labels = ctx.Input<Tensor>("Label");
Tensor* softmax = ctx.Output<Tensor>("Softmax");
......@@ -201,6 +217,129 @@ class CSoftmaxWithCrossEntropyOpCUDAKernel : public framework::OpKernel<T> {
}
};
template <typename T>
struct CSoftmaxWithCrossEntropyProcessGroupFunctor<phi::GPUContext, T> {
void operator()(const framework::ExecutionContext& ctx) {
const Tensor* logits = ctx.Input<Tensor>("Logits");
const Tensor* labels = ctx.Input<Tensor>("Label");
Tensor* softmax = ctx.Output<Tensor>("Softmax");
Tensor* loss = ctx.Output<Tensor>("Loss");
const int rid = ctx.Attr<int>("ring_id");
const int nranks = ctx.Attr<int>("nranks");
const int rank = ctx.Attr<int>("rank");
const auto& place = ctx.GetPlace();
auto& dev_ctx = ctx.template device_context<platform::CUDADeviceContext>();
auto map = distributed::ProcessGroupMapFromGid::getInstance();
distributed::ProcessGroup* pg = map->get(rid);
distributed::AllreduceOptions opts;
opts.reduce_op = distributed::ReduceOp::SUM;
// allocate memory on device.
softmax->mutable_data<T>(place);
loss->mutable_data<T>(place);
const auto& logits_dims = logits->dims();
const auto& labels_dims = labels->dims();
const int axis = logits_dims.size() - 1;
const int N = phi::funcs::SizeToAxis(axis, logits_dims);
const int D = phi::funcs::SizeFromAxis(axis, logits_dims);
Tensor logits_2d, softmax_2d, loss_2d;
logits_2d.ShareDataWith(*logits).Resize({N, D});
softmax_2d.ShareDataWith(*softmax).Resize({N, D});
loss_2d.ShareDataWith(*loss).Resize({N, 1});
auto eigen_logits = math::EigenMatrix<T>::From(logits_2d);
auto eigen_softmax = math::EigenMatrix<T>::From(softmax_2d);
// step 1, obtain logit_max
Tensor logits_max;
logits_max =
ctx.AllocateTmpTensor<T, platform::CUDADeviceContext>({N, 1}, dev_ctx);
auto eigen_logits_max = math::EigenMatrix<T>::From(logits_max);
Eigen::DSizes<int, 1> along_axis(1);
eigen_logits_max.device(*dev_ctx.eigen_device()) =
eigen_logits.maximum(along_axis);
std::vector<phi::DenseTensor> in_out;
in_out.push_back(logits_max);
pg->AllReduce(in_out, in_out, opts)->Synchronize();
// step 2, obtain logit - logit_max
Eigen::DSizes<int, 2> batch_by_one(N, 1);
Eigen::DSizes<int, 2> one_by_class(1, D);
eigen_softmax.device(*dev_ctx.eigen_device()) =
(eigen_logits -
eigen_logits_max.reshape(batch_by_one).broadcast(one_by_class))
.unaryExpr(math::ValueClip<T>());
// step 3, obtain predict target
Tensor predicted_logits;
predicted_logits =
ctx.AllocateTmpTensor<T, platform::CUDADeviceContext>({N, 1}, dev_ctx);
predicted_logits.mutable_data<T>(place);
auto t = framework::EigenVector<T>::Flatten(predicted_logits);
t.device(*dev_ctx.eigen_device()) = t.constant(static_cast<T>(0));
const int start_index = rank * D;
const int end_index = start_index + D;
int blocks = NumBlocks(N);
int threads = kNumCUDAThreads;
const auto& label_type = framework::TransToProtoVarType(labels->dtype());
if (label_type == framework::proto::VarType::INT32) {
MaskLabelByIndex<T, int32_t><<<blocks, threads, 0, dev_ctx.stream()>>>(
predicted_logits.data<T>(), softmax_2d.data<T>(),
labels->data<int32_t>(), start_index, end_index, N, D, nranks);
} else if (label_type == framework::proto::VarType::INT64) {
MaskLabelByIndex<T, int64_t><<<blocks, threads, 0, dev_ctx.stream()>>>(
predicted_logits.data<T>(), softmax_2d.data<T>(),
labels->data<int64_t>(), start_index, end_index, N, D, nranks);
}
in_out.clear();
in_out.push_back(predicted_logits);
pg->AllReduce(in_out, in_out, opts)->Synchronize();
// step 4, obtain exp(logit)
eigen_softmax.device(*dev_ctx.eigen_device()) = eigen_softmax.exp();
// step 5, obtain sum_exp_logits
Tensor sum_exp_logits;
sum_exp_logits =
ctx.AllocateTmpTensor<T, platform::CUDADeviceContext>({N, 1}, dev_ctx);
void* sum_exp_logits_buff = sum_exp_logits.mutable_data<T>(place);
auto eigen_sum_exp_logits = math::EigenMatrix<T>::From(sum_exp_logits);
eigen_sum_exp_logits.device(*dev_ctx.eigen_device()) =
eigen_softmax.sum(along_axis);
in_out.clear();
in_out.push_back(sum_exp_logits);
pg->AllReduce(in_out, in_out, opts)->Synchronize();
auto eigen_loss = math::EigenMatrix<T>::From(loss_2d);
auto eigen_predicted_logits = math::EigenMatrix<T>::From(predicted_logits);
eigen_loss.device(*dev_ctx.eigen_device()) =
(eigen_sum_exp_logits.log().unaryExpr(math::TolerableValue<T>()) -
eigen_predicted_logits)
.unaryExpr(math::TolerableValue<T>());
eigen_softmax.device(*dev_ctx.eigen_device()) =
(eigen_softmax *
eigen_sum_exp_logits.inverse().broadcast(one_by_class));
}
};
template <typename T>
class CSoftmaxWithCrossEntropyGradCUDAKernel : public framework::OpKernel<T> {
public:
......
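
The ProcessGroup functor above computes model-parallel softmax cross-entropy with three all-reduces over the class-sharded logits. A NumPy sketch of the same math (not Paddle code; `allreduce` is a stand-in for `pg->AllReduce`, and the global row max is used here as the stabilizing constant):

```python
import numpy as np

def mp_softmax_cross_entropy(logits_shard, label, rank, nranks, allreduce):
    # logits_shard: [N, D] local slice of the class dimension; label: [N] global ids
    N, D = logits_shard.shape
    start, end = rank * D, (rank + 1) * D

    # steps 1-2: subtract a per-row constant before exponentiating; any constant
    # cancels in the loss, and the global max keeps exp() from overflowing
    row_max = allreduce(logits_shard.max(axis=1, keepdims=True), op="max")
    shifted = logits_shard - row_max

    # step 3: the true-class logit lives on exactly one shard; the others add 0
    predicted = np.zeros((N, 1))
    for i, cls in enumerate(label):
        if start <= cls < end:
            predicted[i, 0] = shifted[i, cls - start]
    predicted = allreduce(predicted, op="sum")

    # steps 4-5: global normalizer sum(exp(logit - max))
    exp_shifted = np.exp(shifted)
    sum_exp = allreduce(exp_shifted.sum(axis=1, keepdims=True), op="sum")

    loss = np.log(sum_exp) - predicted        # -log softmax at the true class
    softmax_shard = exp_shifted / sum_exp     # local slice of the softmax
    return loss, softmax_shard

# single-rank check: allreduce degenerates to identity
logits = np.random.randn(4, 8)
label = np.array([1, 5, 0, 7])
loss, sm = mp_softmax_cross_entropy(logits, label, 0, 1, lambda x, op: x)
```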
......@@ -18,11 +18,13 @@ limitations under the License. */
#include <utility>
#include <vector>
#include "paddle/fluid/distributed/collective/ProcessGroup.h"
#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/math/cross_entropy.h"
#include "paddle/fluid/operators/math/softmax.h"
#include "paddle/phi/api/include/tensor.h"
namespace paddle {
namespace operators {
......@@ -36,5 +38,15 @@ class CSoftmaxWithCrossEntropyOpCPUKernel : public framework::OpKernel<T> {
}
};
template <typename Context, typename T>
struct CSoftmaxWithCrossEntropyFunctor {
void operator()(const framework::ExecutionContext& ctx);
};
template <typename Context, typename T>
struct CSoftmaxWithCrossEntropyProcessGroupFunctor {
void operator()(const framework::ExecutionContext& ctx);
};
} // namespace operators
} // namespace paddle
......@@ -20,6 +20,7 @@ from ..fluid.framework import Variable
from ..fluid.framework import in_dygraph_mode
from ..fluid.framework import OpProtoHolder
from ..fluid.framework import _non_static_mode
from ..fluid.framework import _in_legacy_dygraph
from ..fluid.framework import convert_np_dtype_to_dtype_
from ..fluid.framework import _varbase_creator
from ..fluid.data_feeder import convert_dtype
......@@ -1132,13 +1133,36 @@ def _mp_allreduce(tensor,
group=None,
use_calc_stream=True,
use_model_parallel=True):
"""[it is same as allreduce above, but it suuports model parallel. And it support inplace startegy]
"""[it is same as allreduce above, but it supports model parallel. And it support inplace startegy]
"""
if group is not None and not group.is_member():
return
ring_id = 0 if group is None else group.id
if _non_static_mode():
if in_dygraph_mode():
assert op == ReduceOp.SUM, "Unknown parameter: {}.".format(op)
from paddle.autograd import EagerPyLayer
class mp_allreduce_eager(EagerPyLayer):
@staticmethod
def forward(ctx, tensor, use_calc_stream, ring_id,
use_model_parallel):
ctx.ring_id = ring_id
return _C_ops.c_allreduce_sum_(
tensor, 'use_calc_stream', use_calc_stream, 'ring_id',
ring_id, "use_model_parallel", use_model_parallel)
@staticmethod
def backward(ctx, dy):
return _C_ops.c_identity(dy, 'use_calc_stream', True, 'ring_id',
ctx.ring_id, 'use_model_parallel',
True)
return mp_allreduce_eager.apply(tensor, use_calc_stream, ring_id,
use_model_parallel)
elif _in_legacy_dygraph():
if op == ReduceOp.SUM:
return _C_ops.c_allreduce_sum_(
tensor, 'use_calc_stream', use_calc_stream, 'ring_id', ring_id,
......
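
The eager branch above wraps the in-place `c_allreduce_sum_` in an `EagerPyLayer` whose backward is `c_identity`, i.e. the gradient passes through unchanged. A rough standalone sketch of that pattern using the public `PyLayer` API (the diff itself uses the internal `_C_ops` calls):

```python
import paddle
import paddle.distributed as dist
from paddle.autograd import PyLayer

class MPAllReduce(PyLayer):
    @staticmethod
    def forward(ctx, tensor):
        # forward: sum the partial results across the model-parallel group
        dist.all_reduce(tensor, op=dist.ReduceOp.SUM)
        return tensor

    @staticmethod
    def backward(ctx, dy):
        # backward: identity, mirroring c_identity in the diff
        return dy

# usage inside a script launched with paddle.distributed.launch:
#   y = MPAllReduce.apply(x)
```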
......@@ -378,7 +378,7 @@ def sync_params_buffers(model,
param.name)
# is_distributed param not need to sync when in mp mode
if isinstance(param, ParamBase):
if isinstance(param, (ParamBase, core.eager.Tensor)):
if is_model_parallel and param.is_distributed:
continue
......
......@@ -329,7 +329,9 @@ def concat(input, axis=0, name=None):
axis = axis.item(0)
if not isinstance(input, Variable):
input = [t for t in input if t.shape.count(0) == 0]
return _C_ops.final_state_concat(input, axis)
out = _varbase_creator()
_C_ops.concat(input, out, 'axis', axis)
return out
if _in_legacy_dygraph():
if isinstance(axis, Variable):
......
......@@ -14,16 +14,20 @@
from __future__ import print_function
import os
import unittest
import paddle.fluid as fluid
from test_parallel_dygraph_dataparallel import TestMultipleGpus
from paddle.fluid.framework import _test_eager_guard
class TestParallelClassCenterSample(TestMultipleGpus):
def test_parallel_class_center_sample(self):
self.run_mnist_2gpu('parallel_class_center_sample.py')
self.run_mnist_2gpu('parallel_class_center_sample.py', eager_mode=False)
if __name__ == "__main__":
os.environ["FLAGS_enable_eager_mode"] = "1"
unittest.main()
......@@ -100,6 +100,7 @@ def start_local_trainers(cluster,
pod,
training_script,
training_script_args,
eager_mode=True,
log_dir=None):
current_env = copy.copy(os.environ.copy())
#paddle broadcast ncclUniqueId use socket, and
......@@ -119,6 +120,9 @@ def start_local_trainers(cluster,
"PADDLE_TRAINER_ENDPOINTS": ",".join(cluster.trainers_endpoints())
}
if not eager_mode:
proc_env["FLAGS_enable_eager_mode"] = "%d" % 0
current_env.update(proc_env)
print("trainer proc env:{}".format(current_env))
......@@ -145,15 +149,8 @@ def start_local_trainers(cluster,
return procs
def get_dist_port_from_flags():
DIST_UT_PORT = 6175
if os.getenv("PADDLE_DIST_UT_PORT"):
DIST_UT_PORT = int(os.getenv("PADDLE_DIST_UT_PORT"))
return DIST_UT_PORT
class TestMultipleGpus(unittest.TestCase):
def run_mnist_2gpu(self, target_file_name):
def run_mnist_2gpu(self, target_file_name, eager_mode=True):
if not fluid.core.is_compiled_with_cuda(
) or fluid.core.get_cuda_device_count() == 0:
return
......@@ -167,6 +164,7 @@ class TestMultipleGpus(unittest.TestCase):
procs = start_local_trainers(
cluster,
pod,
eager_mode=eager_mode,
training_script=target_file_name,
training_script_args=[])
......@@ -206,9 +204,9 @@ class TestDataParallelGradientCheck(TestMultipleGpus):
class TestDataParallelWithPyLayer(TestMultipleGpus):
def test_parallel_dygraph_dataparallel_with_pylayer(self):
with _test_eager_guard():
self.run_mnist_2gpu('parallel_dygraph_dataparallel_with_pylayer.py')
self.run_mnist_2gpu('parallel_dygraph_dataparallel_with_pylayer.py')
self.run_mnist_2gpu(
'parallel_dygraph_dataparallel_with_pylayer.py', eager_mode=False)
class TestGradientCheckInEagerMode(TestMultipleGpus):
......
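
The `eager_mode` flag added above reaches the trainers through their environment: when it is False, `start_local_trainers` puts `FLAGS_enable_eager_mode=0` into each child's `proc_env`, so the launched script runs the legacy dygraph path. Roughly equivalent to launching a trainer script by hand (the script name is taken from the tests below and shown only to illustrate the flag):

```python
import os
import subprocess

env = dict(os.environ)
env["FLAGS_enable_eager_mode"] = "0"   # 0 = legacy dygraph path, 1 = eager mode
subprocess.check_call(
    ["python", "hybrid_parallel_mp_model.py"], env=env)
```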
......@@ -14,6 +14,7 @@
from __future__ import print_function
import os
import unittest
import paddle.fluid as fluid
......@@ -23,7 +24,9 @@ from test_parallel_dygraph_dataparallel import TestMultipleGpus
class TestModelParallelLayer(TestMultipleGpus):
def test_hybrid_parallel_mp_layer(self):
self.run_mnist_2gpu('hybrid_parallel_mp_layers.py')
self.run_mnist_2gpu('hybrid_parallel_mp_layers.py', eager_mode=False)
if __name__ == "__main__":
os.environ["FLAGS_enable_eager_mode"] = "1"
unittest.main()
......@@ -14,6 +14,7 @@
from __future__ import print_function
import os
import unittest
import paddle.fluid as fluid
......@@ -22,20 +23,26 @@ from test_parallel_dygraph_dataparallel import TestMultipleGpus
class TestHybridParallel(TestMultipleGpus):
def test_hybrid_parallel_mp_random(self):
self.run_mnist_2gpu('hybrid_parallel_mp_random.py')
# self.run_mnist_2gpu('hybrid_parallel_mp_random.py')
self.run_mnist_2gpu('hybrid_parallel_mp_random.py', eager_mode=False)
def test_hybrid_parallel_mp_model(self):
self.run_mnist_2gpu('hybrid_parallel_mp_model.py')
self.run_mnist_2gpu('hybrid_parallel_mp_model.py', eager_mode=False)
def test_hybrid_parallel_mp_amp(self):
self.run_mnist_2gpu('hybrid_parallel_mp_amp.py')
self.run_mnist_2gpu('hybrid_parallel_mp_amp.py', eager_mode=False)
def test_hybrid_parallel_mp_fp16(self):
self.run_mnist_2gpu('hybrid_parallel_mp_fp16.py')
self.run_mnist_2gpu('hybrid_parallel_mp_fp16.py', eager_mode=False)
def test_hybrid_parallel_mp_clip_grad(self):
self.run_mnist_2gpu('hybrid_parallel_mp_clip_grad.py')
self.run_mnist_2gpu('hybrid_parallel_mp_clip_grad.py', eager_mode=False)
if __name__ == "__main__":
os.environ["FLAGS_enable_eager_mode"] = "1"
unittest.main()