From 52007915ebb6fdd15553ef924deb08ac2dffb6a6 Mon Sep 17 00:00:00 2001 From: ShenLiang Date: Wed, 9 Jun 2021 13:26:03 +0800 Subject: [PATCH] [HybridParallel] Add ParallelCrossEntropy for TensorParallel (#33401) * add parallel_cross_entropy * add grad for crossentropy * fix cross entropy --- .../operators/collective/c_embedding_op.cc | 13 +- .../c_softmax_with_cross_entropy_op.cc | 194 +++++++++++++ .../c_softmax_with_cross_entropy_op.cu | 262 ++++++++++++++++++ .../c_softmax_with_cross_entropy_op.h | 41 +++ python/paddle/distributed/collective.py | 29 ++ .../fleet/meta_parallel/__init__.py | 1 + .../meta_parallel/parallel_layers/__init__.py | 1 + .../parallel_layers/mp_layers.py | 17 ++ .../unittests/hybrid_parallel_mp_layers.py | 57 ++++ 9 files changed, 608 insertions(+), 7 deletions(-) create mode 100644 paddle/fluid/operators/collective/c_softmax_with_cross_entropy_op.cc create mode 100644 paddle/fluid/operators/collective/c_softmax_with_cross_entropy_op.cu create mode 100644 paddle/fluid/operators/collective/c_softmax_with_cross_entropy_op.h diff --git a/paddle/fluid/operators/collective/c_embedding_op.cc b/paddle/fluid/operators/collective/c_embedding_op.cc index 094ef9c8d4e..3055e2ceb23 100644 --- a/paddle/fluid/operators/collective/c_embedding_op.cc +++ b/paddle/fluid/operators/collective/c_embedding_op.cc @@ -31,13 +31,12 @@ class CEmbeddingOp : public framework::OperatorWithKernel { int ids_rank = ids_dims.size(); VLOG(5) << "ids rank is " << ids_rank << std::endl; - PADDLE_ENFORCE_EQ( - table_dims.size(), 2, - platform::errors::InvalidArgument( - "ShapeError: The dimensions of the 'c_embedding' must be 2. " - "But received c_embedding's dimensions = %d, " - "c_embedding's shape = [%s].", - table_dims.size(), table_dims)); + PADDLE_ENFORCE_EQ(table_dims.size(), 2, + platform::errors::InvalidArgument( + "The dimensions of the 'c_embedding' must be 2. " + "But received c_embedding's dimensions = %d, " + "c_embedding's shape = [%s].", + table_dims.size(), table_dims)); auto output_dims = framework::vectorize(ids_dims); output_dims.push_back(table_dims[1]); diff --git a/paddle/fluid/operators/collective/c_softmax_with_cross_entropy_op.cc b/paddle/fluid/operators/collective/c_softmax_with_cross_entropy_op.cc new file mode 100644 index 00000000000..f75e1b3c7ae --- /dev/null +++ b/paddle/fluid/operators/collective/c_softmax_with_cross_entropy_op.cc @@ -0,0 +1,194 @@ +/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
*/ + +#include "paddle/fluid/operators/collective/c_softmax_with_cross_entropy_op.h" + +namespace paddle { +namespace operators { + +class CSoftmaxWithCrossEntropyOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + void InferShape(framework::InferShapeContext* ctx) const override { + OP_INOUT_CHECK(ctx->HasInput("Logits"), "Input", "Logits", + "CSoftmaxWithCrossEntropyOp"); + OP_INOUT_CHECK(ctx->HasInput("Label"), "Input", "Label", + "CSoftmaxWithCrossEntropyOp"); + + OP_INOUT_CHECK(ctx->HasOutput("Softmax"), "Output", "Softmax", + "CSoftmaxWithCrossEntropyOp"); + OP_INOUT_CHECK(ctx->HasOutput("Loss"), "Output", "Loss", + "CSoftmaxWithCrossEntropyOp"); + + auto logits_dims = ctx->GetInputDim("Logits"); + auto labels_dims = ctx->GetInputDim("Label"); + + auto logits_rank = logits_dims.size(); + auto axis = logits_rank - 1; + for (int i = 0; i < logits_rank; i++) { + if (i != axis) { + if (ctx->IsRuntime() || (logits_dims[i] > 0 && labels_dims[i] > 0)) { + PADDLE_ENFORCE_EQ(logits_dims[i], labels_dims[i], + platform::errors::InvalidArgument( + "Input(Logits) and Input(Label) should in " + "same shape in dimensions except axis.")); + } + } + } + + PADDLE_ENFORCE_EQ( + labels_dims[logits_rank - 1], 1UL, + platform::errors::InvalidArgument( + "the last dimension of Input(Label) should be 1." + "But received: the last dimension of Input(Label) is [%d]," + "the last dimension is [%d]", + labels_dims[logits_rank - 1], logits_rank - 1)); + + ctx->SetOutputDim("Softmax", logits_dims); + + logits_dims[axis] = 1; + ctx->SetOutputDim("Loss", logits_dims); + + ctx->ShareLoD("Logits", /*->*/ "Softmax"); + ctx->ShareLoD("Logits", /*->*/ "Loss"); + } + + protected: + framework::OpKernelType GetExpectedKernelType( + const framework::ExecutionContext& ctx) const override { + return framework::OpKernelType( + OperatorWithKernel::IndicateVarDataType(ctx, "Logits"), + ctx.device_context()); + } +}; + +class CSoftmaxWithCrossEntropyOpMaker + : public framework::OpProtoAndCheckerMaker { + public: + void Make() { + AddInput("Logits", + "(Tensor, default: Tensor), The input tensor of unscaled " + "log probabilities, whose dimension :attr:`axis` should be scaled " + "by softmax."); + AddInput( + "Label", + "(Tensor) The input tensor of groud truth label. If :attr:`soft_label` " + "is set to false, Label is a Tensor in same shape with " + "Input(Logits) except the shape in dimension :attr:`axis` as 1. If " + "soft_label is set to true, Label is a Tensor in same " + "shape with Input(Logits)."); + AddOutput( + "Softmax", + "(Tensor, default: Tensor), A tensor in same shape with " + "Input(Logits). " + "The outputs value of softmax activation by given the input batch, " + "which will be used in backward calculation."); + AddOutput("Loss", + "(Tensor, default: Tensor), A tensor in same shape with " + "Input(Logits) " + "except the shape in dimension :attr:`axis` as 1. 
The cross " + "entropy loss."); + AddAttr("ring_id", "(int default 0) nccl communication ring id.") + .SetDefault(0); + AddAttr("rank", + "(int default 0) rank id for CSoftmaxWithCrossEntropy.") + .SetDefault(0); + AddAttr("nranks", + "(int default 1) nranks id for CSoftmaxWithCrossEntropy.") + .SetDefault(0); + AddComment(R"DOC( +CSoftmaxWithCrossEntropy Operator + +)DOC"); + } +}; + +class CSoftmaxWithCrossEntropyOpGrad : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + void InferShape(framework::InferShapeContext* ctx) const override { + PADDLE_ENFORCE_EQ(ctx->HasInput(framework::GradVarName("Loss")), true, + platform::errors::InvalidArgument( + "Input(Loss@Grad) should not be null.")); + PADDLE_ENFORCE_EQ(ctx->HasInput("Softmax"), true, + platform::errors::InvalidArgument( + "Input(Softmax) should be not null.")); + PADDLE_ENFORCE_EQ( + ctx->HasInput("Label"), true, + platform::errors::InvalidArgument("Input(Label) should be not null.")); + + PADDLE_ENFORCE_EQ(ctx->HasOutput(framework::GradVarName("Logits")), true, + platform::errors::InvalidArgument( + "Output(Logits@Grad) should be not null.")); + + ctx->SetOutputDim(framework::GradVarName("Logits"), + ctx->GetInputDim("Softmax")); + } + + protected: + framework::OpKernelType GetExpectedKernelType( + const framework::ExecutionContext& ctx) const override { + return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType( + ctx, framework::GradVarName("Loss")), + ctx.device_context()); + } +}; + +template +class CSoftmaxWithCrossEntropyOpGradMaker + : public framework::SingleGradOpMaker { + public: + using framework::SingleGradOpMaker::SingleGradOpMaker; + + protected: + void Apply(GradOpPtr op) const override { + op->SetType("c_softmax_with_cross_entropy_grad"); + + op->SetInput("Softmax", this->Output("Softmax")); + op->SetInput("Label", this->Input("Label")); + op->SetInput(framework::GradVarName("Loss"), this->OutputGrad("Loss")); + op->SetAttrMap(this->Attrs()); + op->SetOutput(framework::GradVarName("Logits"), this->InputGrad("Logits")); + } +}; + +DECLARE_INPLACE_OP_INFERER(CSoftmaxWithCrossEntropyInplaceInferer, + {"Logits", "Softmax"}); + +DECLARE_INPLACE_OP_INFERER(CSoftmaxWithCrossEntropyGradInplaceInferer, + {"Softmax", framework::GradVarName("Logits")}); + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +namespace plat = paddle::platform; + +REGISTER_OPERATOR( + c_softmax_with_cross_entropy, ops::CSoftmaxWithCrossEntropyOp, + ops::CSoftmaxWithCrossEntropyOpMaker, + ops::CSoftmaxWithCrossEntropyOpGradMaker, + ops::CSoftmaxWithCrossEntropyOpGradMaker, + ops::CSoftmaxWithCrossEntropyInplaceInferer); + +REGISTER_OPERATOR(c_softmax_with_cross_entropy_grad, + ops::CSoftmaxWithCrossEntropyOpGrad, + ops::CSoftmaxWithCrossEntropyGradInplaceInferer); + +REGISTER_OP_CPU_KERNEL(c_softmax_with_cross_entropy, + ops::CSoftmaxWithCrossEntropyOpCPUKernel, + ops::CSoftmaxWithCrossEntropyOpCPUKernel, + ops::CSoftmaxWithCrossEntropyOpCPUKernel); diff --git a/paddle/fluid/operators/collective/c_softmax_with_cross_entropy_op.cu b/paddle/fluid/operators/collective/c_softmax_with_cross_entropy_op.cu new file mode 100644 index 00000000000..77db86e7111 --- /dev/null +++ b/paddle/fluid/operators/collective/c_softmax_with_cross_entropy_op.cu @@ -0,0 +1,262 @@ +/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/operators/collective/c_softmax_with_cross_entropy_op.h" +#include "paddle/fluid/operators/math/cross_entropy.h" +#include "paddle/fluid/operators/math/softmax_impl.h" +#include "paddle/fluid/platform/collective_helper.h" +#include "paddle/fluid/platform/nccl_helper.h" +#include "paddle/fluid/string/string_helper.h" + +namespace paddle { +namespace operators { + +using Tensor = framework::Tensor; + +static constexpr int kNumCUDAThreads = 512; +static constexpr int kNumMaxinumNumBlocks = 4096; + +static inline int NumBlocks(const int N) { + return std::min((N + kNumCUDAThreads - 1) / kNumCUDAThreads, + kNumMaxinumNumBlocks); +} + +template +__global__ void MaskLabelByIndex(T* predicted_logits, const T* logit, + const IndexT* label, const int start_index, + const int end_index, const int64_t N, + const int64_t D, const int nranks) { + CUDA_KERNEL_LOOP(i, N) { + auto real_label = label[i]; + PADDLE_ENFORCE((real_label < D * nranks) && (real_label >= 0), + "The index is out of bounds, " + "please check whether the value of label and " + "input meet the class number. It should " + "be less than [%d], but received [%d]", + D * nranks, real_label); + + if (real_label >= start_index && real_label < end_index) { + predicted_logits[i] = logit[i * D + real_label - start_index]; + } + } +} + +template +__global__ void MaskLabelByIndexGrad(T* logits_grad, const T* loss_grad, + const IndexT* labels, + const int start_index, const int end_index, + const int64_t N, const int64_t D) { + CUDA_KERNEL_LOOP(i, N * D) { + auto row = i / D; + auto col = i % D; + if ((col + start_index) == labels[row]) { + logits_grad[i] = (logits_grad[i] - static_cast(1.0)) * loss_grad[row]; + } else { + logits_grad[i] *= loss_grad[row]; + } + } +} + +template +class CSoftmaxWithCrossEntropyOpCUDAKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + const Tensor* logits = ctx.Input("Logits"); + const Tensor* labels = ctx.Input("Label"); + Tensor* softmax = ctx.Output("Softmax"); + Tensor* loss = ctx.Output("Loss"); + + const int rid = ctx.Attr("ring_id"); + const int nranks = ctx.Attr("nranks"); + const int rank = ctx.Attr("rank"); + + const auto& place = ctx.GetPlace(); + const auto& comm = platform::NCCLCommContext::Instance().Get(rid, place); + auto& dev_ctx = ctx.template device_context(); + + // use global calculate stream + const auto stream = static_cast( + platform::DeviceContextPool::Instance().Get(place)) + ->stream(); + + // allocate memory on device. 
+ softmax->mutable_data(place); + loss->mutable_data(place); + + const auto& logits_dims = logits->dims(); + const auto& labels_dims = labels->dims(); + + const int axis = logits_dims.size() - 1; + const int N = SizeToAxis(axis, logits_dims); + const int D = SizeFromAxis(axis, logits_dims); + + Tensor logits_2d, softmax_2d, loss_2d; + logits_2d.ShareDataWith(*logits).Resize({N, D}); + softmax_2d.ShareDataWith(*softmax).Resize({N, D}); + loss_2d.ShareDataWith(*loss).Resize({N, 1}); + + auto eigen_logits = math::EigenMatrix::From(logits_2d); + auto eigen_softmax = math::EigenMatrix::From(softmax_2d); + + // step 1, obtain logit_max + Tensor logits_max; + logits_max = + ctx.AllocateTmpTensor({N, 1}, dev_ctx); + void* logits_max_buff = logits_max.mutable_data(place); + + auto eigen_logits_max = math::EigenMatrix::From(logits_max); + Eigen::DSizes along_axis(1); + eigen_logits_max.device(*dev_ctx.eigen_device()) = + eigen_logits.maximum(along_axis); + PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::ncclAllReduce( + logits_max_buff, logits_max_buff, logits_max.numel(), + platform::ToNCCLDataType(logits_max.type()), ncclMax, comm->comm(), + stream)); + + // step 2, obtain logit - logit_max + Eigen::DSizes batch_by_one(N, 1); + Eigen::DSizes one_by_class(1, D); + + eigen_softmax.device(*dev_ctx.eigen_device()) = + (eigen_logits - + eigen_logits_max.reshape(batch_by_one).broadcast(one_by_class)) + .unaryExpr(math::ValueClip()); + + // step 3, obtain predict target + Tensor predicted_logits; + predicted_logits = + ctx.AllocateTmpTensor({N, 1}, dev_ctx); + predicted_logits.mutable_data(place); + + auto t = framework::EigenVector::Flatten(predicted_logits); + t.device(*dev_ctx.eigen_device()) = t.constant(static_cast(0)); + + const int start_index = rank * D; + const int end_index = start_index + D; + + int blocks = NumBlocks(N); + int threads = kNumCUDAThreads; + const auto& label_type = labels->type(); + + if (label_type == framework::proto::VarType::INT32) { + MaskLabelByIndex<<>>( + predicted_logits.data(), softmax_2d.data(), + labels->data(), start_index, end_index, N, D, nranks); + } else if (label_type == framework::proto::VarType::INT64) { + MaskLabelByIndex<<>>( + predicted_logits.data(), softmax_2d.data(), + labels->data(), start_index, end_index, N, D, nranks); + } + + void* predict_logits_buff = predicted_logits.mutable_data(place); + PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::ncclAllReduce( + predict_logits_buff, predict_logits_buff, predicted_logits.numel(), + platform::ToNCCLDataType(predicted_logits.type()), ncclSum, + comm->comm(), stream)); + + // step 4, obtain exp(logit) + eigen_softmax.device(*dev_ctx.eigen_device()) = eigen_softmax.exp(); + + // step 5, obtain sum_exp_logits + Tensor sum_exp_logits; + sum_exp_logits = + ctx.AllocateTmpTensor({N, 1}, dev_ctx); + void* sum_exp_logits_buff = sum_exp_logits.mutable_data(place); + + auto eigen_sum_exp_logits = math::EigenMatrix::From(sum_exp_logits); + eigen_sum_exp_logits.device(*dev_ctx.eigen_device()) = + eigen_softmax.sum(along_axis); + + PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::ncclAllReduce( + sum_exp_logits_buff, sum_exp_logits_buff, sum_exp_logits.numel(), + platform::ToNCCLDataType(sum_exp_logits.type()), ncclSum, comm->comm(), + stream)); + + auto eigen_loss = math::EigenMatrix::From(loss_2d); + auto eigen_predicted_logits = math::EigenMatrix::From(predicted_logits); + + eigen_loss.device(*dev_ctx.eigen_device()) = + (eigen_sum_exp_logits.log().unaryExpr(math::TolerableValue()) - + eigen_predicted_logits) + 
.unaryExpr(math::TolerableValue()); + + eigen_softmax.device(*dev_ctx.eigen_device()) = + (eigen_softmax * + eigen_sum_exp_logits.inverse().broadcast(one_by_class)); + } +}; + +template +class CSoftmaxWithCrossEntropyGradCUDAKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& context) const override { + const Tensor* labels = context.Input("Label"); + const Tensor* loss_grad = + context.Input(framework::GradVarName("Loss")); + Tensor* logit_grad = + context.Output(framework::GradVarName("Logits")); + const Tensor* softmax = context.Input("Softmax"); + const int rank = context.Attr("rank"); + auto& dev_ctx = + context.template device_context(); + + if (logit_grad != softmax) { + framework::TensorCopy(*softmax, context.GetPlace(), + context.device_context(), logit_grad); + } + const auto sofrmax_dims = softmax->dims(); + const int axis = sofrmax_dims.size() - 1; + const int N = SizeToAxis(axis, sofrmax_dims); + const int D = SizeFromAxis(axis, sofrmax_dims); + + Tensor logit_grad_2d; + logit_grad_2d.ShareDataWith(*logit_grad).Resize({N, D}); + + int blocks = NumBlocks(N * D); + int threads = kNumCUDAThreads; + const auto& label_type = labels->type(); + const int start_index = rank * D; + const int end_index = start_index + D; + + if (label_type == framework::proto::VarType::INT32) { + MaskLabelByIndexGrad<<>>( + logit_grad_2d.data(), loss_grad->data(), + labels->data(), start_index, end_index, N, D); + } else if (label_type == framework::proto::VarType::INT64) { + MaskLabelByIndexGrad<<>>( + logit_grad_2d.data(), loss_grad->data(), + labels->data(), start_index, end_index, N, D); + } + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +namespace plat = paddle::platform; + +REGISTER_OP_CUDA_KERNEL( + c_softmax_with_cross_entropy, + ops::CSoftmaxWithCrossEntropyOpCUDAKernel, + ops::CSoftmaxWithCrossEntropyOpCUDAKernel, + ops::CSoftmaxWithCrossEntropyOpCUDAKernel); + +REGISTER_OP_CUDA_KERNEL( + c_softmax_with_cross_entropy_grad, + ops::CSoftmaxWithCrossEntropyGradCUDAKernel, + ops::CSoftmaxWithCrossEntropyGradCUDAKernel, + ops::CSoftmaxWithCrossEntropyGradCUDAKernel); diff --git a/paddle/fluid/operators/collective/c_softmax_with_cross_entropy_op.h b/paddle/fluid/operators/collective/c_softmax_with_cross_entropy_op.h new file mode 100644 index 00000000000..c7cfd41fa25 --- /dev/null +++ b/paddle/fluid/operators/collective/c_softmax_with_cross_entropy_op.h @@ -0,0 +1,41 @@ +/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
*/ + +#pragma once + +#include +#include +#include + +#include "paddle/fluid/framework/data_type.h" +#include "paddle/fluid/framework/lod_tensor.h" +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/operators/math/cross_entropy.h" +#include "paddle/fluid/operators/math/softmax.h" +#include "paddle/fluid/operators/softmax_op.h" + +namespace paddle { +namespace operators { + +template +class CSoftmaxWithCrossEntropyOpCPUKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + PADDLE_THROW(platform::errors::Unavailable( + "Do not support c_embedding for cpu kernel now.")); + } +}; + +} // namespace operators +} // namespace paddle diff --git a/python/paddle/distributed/collective.py b/python/paddle/distributed/collective.py index e3b8d783b2e..f10b0736ef9 100644 --- a/python/paddle/distributed/collective.py +++ b/python/paddle/distributed/collective.py @@ -954,6 +954,35 @@ class _Linear(layers.Layer): self.weight.shape[0], self.weight.shape[1], self._dtype, name_str) +def _c_softmax_with_cross_entropy(logits, + label, + group=None, + return_softmax=False): + if group is not None and not group.is_member(): + return + ring_id = 0 if group is None else group.id + global_rank = _get_global_env().rank + rank = global_rank if group is None else group.get_group_rank(global_rank) + nranks = _get_global_env().world_size if group is None else group.nranks + + input_dims = len(list(logits.shape)) + label_dims = len(list(label.shape)) + if input_dims - 1 != label_dims and input_dims != label_dims: + raise ValueError( + 'Expected nput_dims - 1 = label_dims or input_dims == label_dims\ + (got nput_dims{}, label_dims{})'.format(input_dims, label_dims)) + if input_dims - 1 == label_dims: + label = paddle.unsqueeze(label, axis=-1) + + if in_dygraph_mode(): + softmax, loss = core.ops.c_softmax_with_cross_entropy( + logits, label, 'ring_id', ring_id, 'rank', rank, 'nranks', nranks) + if not return_softmax: + return loss + else: + return loss, softmax + + def _linear(x, weight, bias=None, name=None): """ Fuction Linear diff --git a/python/paddle/distributed/fleet/meta_parallel/__init__.py b/python/paddle/distributed/fleet/meta_parallel/__init__.py index 894771a3d50..0750c2c250e 100644 --- a/python/paddle/distributed/fleet/meta_parallel/__init__.py +++ b/python/paddle/distributed/fleet/meta_parallel/__init__.py @@ -15,6 +15,7 @@ from .parallel_layers import VocabParallelEmbedding # noqa: F401 from .parallel_layers import ColumnParallelLinear # noqa: F401 from .parallel_layers import RowParallelLinear # noqa: F401 +from .parallel_layers import ParallelCrossEntropy # noqa: F401 from .parallel_layers import LayerDesc # noqa: F401 from .parallel_layers import PipelineLayer # noqa: F401 from .parallel_layers import RNGStatesTracker # noqa: F401 diff --git a/python/paddle/distributed/fleet/meta_parallel/parallel_layers/__init__.py b/python/paddle/distributed/fleet/meta_parallel/parallel_layers/__init__.py index 6a33611403a..72da962b891 100644 --- a/python/paddle/distributed/fleet/meta_parallel/parallel_layers/__init__.py +++ b/python/paddle/distributed/fleet/meta_parallel/parallel_layers/__init__.py @@ -15,6 +15,7 @@ from .mp_layers import VocabParallelEmbedding # noqa: F401 from .mp_layers import ColumnParallelLinear # noqa: F401 from .mp_layers import RowParallelLinear # noqa: F401 +from .mp_layers import ParallelCrossEntropy # noqa: F401 from .pp_layers import LayerDesc # noqa: F401 from .pp_layers import PipelineLayer # noqa: F401 
from .random import RNGStatesTracker # noqa: F401 diff --git a/python/paddle/distributed/fleet/meta_parallel/parallel_layers/mp_layers.py b/python/paddle/distributed/fleet/meta_parallel/parallel_layers/mp_layers.py index 91f9868f96e..f091c890f68 100644 --- a/python/paddle/distributed/fleet/meta_parallel/parallel_layers/mp_layers.py +++ b/python/paddle/distributed/fleet/meta_parallel/parallel_layers/mp_layers.py @@ -18,6 +18,7 @@ from .random import get_rng_state_tracker from paddle.nn import functional as F from paddle import framework from ...base import topology as tp +from paddle.autograd import PyLayer __all__ = [] @@ -243,3 +244,19 @@ class RowParallelLinear(Layer): output = output_ + self.bias if self.bias is not None else output_ return output + + +class ParallelCrossEntropy(Layer): + def __init__(self, name=None): + super(ParallelCrossEntropy, self).__init__() + self.name = name + self.model_parallel_group = tp._HYBRID_PARALLEL_GROUP.get_model_parallel_group( + ) + self.world_size = tp._HYBRID_PARALLEL_GROUP.get_model_parallel_world_size( + ) + self.rank = tp._HYBRID_PARALLEL_GROUP.get_model_parallel_rank() + + def forward(self, input, label): + loss = paddle.distributed.collective._c_softmax_with_cross_entropy( + input, label, group=self.model_parallel_group) + return loss diff --git a/python/paddle/fluid/tests/unittests/hybrid_parallel_mp_layers.py b/python/paddle/fluid/tests/unittests/hybrid_parallel_mp_layers.py index e69cf7d267b..23dae317386 100644 --- a/python/paddle/fluid/tests/unittests/hybrid_parallel_mp_layers.py +++ b/python/paddle/fluid/tests/unittests/hybrid_parallel_mp_layers.py @@ -269,6 +269,63 @@ class TestDistTraning(unittest.TestCase): np.testing.assert_allclose(loss_a.numpy(), loss_b.numpy()) + def test_parallel_cross_entropy(self): + batch_size = 2 + seq_length = 1 + class_size_per_card = 2 + vocab_size = class_size_per_card * self.model_parallel_size + seed = 1025 + + set_random_seed(seed) + rank_id = dist.get_rank() + + # model_a + model_a = fleet.meta_parallel.ParallelCrossEntropy() + + model_b = paddle.nn.CrossEntropyLoss(reduction="none") + + paddle.seed(rank_id * 10) + random.seed(seed) + np.random.seed(seed) + + for _ in range(5): + np_label = np.random.randint(0, vocab_size, + (batch_size, seq_length)) + label = paddle.to_tensor(np_label, dtype="int64") + + data = paddle.randn( + shape=[batch_size, seq_length, class_size_per_card], + dtype='float32') + data.stop_gradient = False + + check_group = dist.new_group(list(range(self.model_parallel_size))) + integral_data = [] + partial_data = data.clone().detach() + paddle.distributed.all_gather( + integral_data, partial_data, group=check_group) + integral_data = paddle.concat(integral_data, axis=-1) + integral_data = integral_data.detach().clone() + integral_data.stop_gradient = False + + loss_a = model_a(data, label).sum() / batch_size + loss_b = model_b(integral_data, label).sum() / batch_size + print("loss_a: ", loss_a.numpy(), "loss_b: ", loss_b.numpy()) + + np.testing.assert_allclose( + loss_a.numpy(), loss_b.numpy(), rtol=1e-6) + + loss_a.backward() + loss_b.backward() + + integral_grad = [] + partial_grad = data.grad.clone().detach() + paddle.distributed.all_gather( + integral_grad, partial_grad, group=check_group) + integral_grad = paddle.concat(integral_grad, axis=-1) + + np.testing.assert_allclose( + integral_data.grad.numpy(), integral_grad.numpy(), rtol=1e-6) + if __name__ == '__main__': unittest.main() -- GitLab
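
The patch adds the operator, the Python wrapper, and the ParallelCrossEntropy layer, but contains no prose walk-through of the math the new CUDA kernel performs. The sketch below is a minimal single-process NumPy model of that computation, not Paddle code: the list of per-rank shards stands in for the NCCL group, plain reductions over that list stand in for the three ncclAllReduce calls, and the ValueClip/TolerableValue guards of the kernel are omitted. The function names (parallel_softmax_cross_entropy, parallel_grad) are illustrative only and not part of any Paddle API.

# Hedged sketch of the tensor-parallel softmax cross entropy that
# c_softmax_with_cross_entropy_op.cu implements. Each "rank" owns a
# contiguous [N, D] slice of the class dimension of the full
# [N, D * nranks] logits; only [N, 1] quantities are ever reduced.
import numpy as np

def parallel_softmax_cross_entropy(logit_shards, labels):
    """logit_shards: list of [N, D] arrays, one per rank.
    labels: [N] integer labels over the full vocabulary of size D * nranks."""
    nranks = len(logit_shards)
    N, D = logit_shards[0].shape

    # step 1: per-shard row max, then "allreduce(max)" across ranks
    logits_max = np.max(np.stack([s.max(axis=1) for s in logit_shards]), axis=0)

    # step 2: subtract the global max for numerical stability
    shifted = [s - logits_max[:, None] for s in logit_shards]

    # step 3: the rank that owns the label contributes its shifted logit,
    # all others contribute zero; "allreduce(sum)" recovers the full value
    # (this is what MaskLabelByIndex does on the GPU)
    predicted = np.zeros(N)
    for rank, s in enumerate(shifted):
        start, end = rank * D, (rank + 1) * D
        for i, lbl in enumerate(labels):
            if start <= lbl < end:
                predicted[i] += s[i, lbl - start]

    # steps 4-5: exponentiate, sum each row locally, "allreduce(sum)"
    exps = [np.exp(s) for s in shifted]
    sum_exp = np.sum(np.stack([e.sum(axis=1) for e in exps]), axis=0)

    # loss = log(sum_exp) - predicted; each rank normalizes its own shard
    loss = np.log(sum_exp) - predicted
    softmax_shards = [e / sum_exp[:, None] for e in exps]
    return loss, softmax_shards

def parallel_grad(softmax_shards, labels, loss_grad):
    """Backward rule used by MaskLabelByIndexGrad, computed shard-locally:
    dlogits = (softmax - one_hot(label)) * dloss."""
    grads = []
    D = softmax_shards[0].shape[1]
    for rank, sm in enumerate(softmax_shards):
        g = sm.copy()
        start = rank * D
        for i, lbl in enumerate(labels):
            if start <= lbl < start + D:
                g[i, lbl - start] -= 1.0
        grads.append(g * loss_grad[:, None])
    return grads

# quick check against the undistributed formula
rng = np.random.default_rng(0)
nranks, N, D = 2, 4, 3
full = rng.normal(size=(N, D * nranks))
labels = rng.integers(0, D * nranks, size=N)
shards = [full[:, r * D:(r + 1) * D] for r in range(nranks)]

loss, sm_shards = parallel_softmax_cross_entropy(shards, labels)
ref_loss = np.log(np.exp(full).sum(axis=1)) - full[np.arange(N), labels]
assert np.allclose(loss, ref_loss)

dloss = np.ones(N)
grads = parallel_grad(sm_shards, labels, dloss)
full_sm = np.exp(full) / np.exp(full).sum(axis=1, keepdims=True)
one_hot = np.eye(D * nranks)[labels]
assert np.allclose(np.concatenate(grads, axis=1),
                   (full_sm - one_hot) * dloss[:, None])

The point of the three reductions is that no rank ever materializes the full [N, D * nranks] logits or softmax; only [N, 1] row statistics cross the network, which is what makes the op attractive for vocabulary-sharded tensor parallelism. The unit test in hybrid_parallel_mp_layers.py only uses all_gather to rebuild the full logits as a reference for paddle.nn.CrossEntropyLoss; the operator itself never needs that gather.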