Unverified commit 52007915 authored by ShenLiang, committed by GitHub

[HybridParallel] Add ParallelCrossEntropy for TensorParallel (#33401)

* add parallel_cross_entropy

* add grad for crossentropy

* fix cross entropy
Parent: cdd6437a
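For orientation, a minimal usage sketch of the new ParallelCrossEntropy layer (adapted from the unit test added in this commit; it assumes a hybrid-parallel fleet has already been initialized with a model-parallel degree mp_degree > 1 and that the script runs on one GPU per model-parallel rank, so the shapes below are per-rank shards):

import paddle
import paddle.distributed.fleet as fleet

# Assumes fleet.init(is_collective=True, strategy=...) with mp_degree > 1 has
# already been called by the training script.
loss_fn = fleet.meta_parallel.ParallelCrossEntropy()

# logits: this rank's shard of the class dimension,
#         shape [batch, seq_len, num_classes // mp_degree]
# label:  global class ids, shape [batch, seq_len]
logits = paddle.randn([2, 1, 2], dtype='float32')
label = paddle.randint(0, 4, shape=[2, 1], dtype='int64')

loss = loss_fn(logits, label)  # per-token loss, shape [batch, seq_len, 1]
loss.mean().backward()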
...
@@ -31,10 +31,9 @@ class CEmbeddingOp : public framework::OperatorWithKernel {
    int ids_rank = ids_dims.size();
    VLOG(5) << "ids rank is " << ids_rank << std::endl;
-   PADDLE_ENFORCE_EQ(
-       table_dims.size(), 2,
-       platform::errors::InvalidArgument(
-           "ShapeError: The dimensions of the 'c_embedding' must be 2. "
+   PADDLE_ENFORCE_EQ(table_dims.size(), 2,
+                     platform::errors::InvalidArgument(
+                         "The dimensions of the 'c_embedding' must be 2. "
            "But received c_embedding's dimensions = %d, "
            "c_embedding's shape = [%s].",
            table_dims.size(), table_dims));
...
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/collective/c_softmax_with_cross_entropy_op.h"
namespace paddle {
namespace operators {
class CSoftmaxWithCrossEntropyOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
OP_INOUT_CHECK(ctx->HasInput("Logits"), "Input", "Logits",
"CSoftmaxWithCrossEntropyOp");
OP_INOUT_CHECK(ctx->HasInput("Label"), "Input", "Label",
"CSoftmaxWithCrossEntropyOp");
OP_INOUT_CHECK(ctx->HasOutput("Softmax"), "Output", "Softmax",
"CSoftmaxWithCrossEntropyOp");
OP_INOUT_CHECK(ctx->HasOutput("Loss"), "Output", "Loss",
"CSoftmaxWithCrossEntropyOp");
auto logits_dims = ctx->GetInputDim("Logits");
auto labels_dims = ctx->GetInputDim("Label");
auto logits_rank = logits_dims.size();
auto axis = logits_rank - 1;
for (int i = 0; i < logits_rank; i++) {
if (i != axis) {
if (ctx->IsRuntime() || (logits_dims[i] > 0 && labels_dims[i] > 0)) {
PADDLE_ENFORCE_EQ(logits_dims[i], labels_dims[i],
platform::errors::InvalidArgument(
"Input(Logits) and Input(Label) should in "
"same shape in dimensions except axis."));
}
}
}
PADDLE_ENFORCE_EQ(
labels_dims[logits_rank - 1], 1UL,
platform::errors::InvalidArgument(
"the last dimension of Input(Label) should be 1."
"But received: the last dimension of Input(Label) is [%d],"
"the last dimension is [%d]",
labels_dims[logits_rank - 1], logits_rank - 1));
ctx->SetOutputDim("Softmax", logits_dims);
logits_dims[axis] = 1;
ctx->SetOutputDim("Loss", logits_dims);
ctx->ShareLoD("Logits", /*->*/ "Softmax");
ctx->ShareLoD("Logits", /*->*/ "Loss");
}
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(
OperatorWithKernel::IndicateVarDataType(ctx, "Logits"),
ctx.device_context());
}
};
class CSoftmaxWithCrossEntropyOpMaker
: public framework::OpProtoAndCheckerMaker {
public:
void Make() {
AddInput("Logits",
"(Tensor, default: Tensor<float>), The input tensor of unscaled "
"log probabilities, whose dimension :attr:`axis` should be scaled "
"by softmax.");
AddInput(
"Label",
"(Tensor) The input tensor of ground truth label. Label is a "
"Tensor<int32/int64> in the same shape as Input(Logits), except that "
"the shape in dimension :attr:`axis` is 1.");
AddOutput(
"Softmax",
"(Tensor, default: Tensor<float>), A tensor in same shape with "
"Input(Logits). "
"The outputs value of softmax activation by given the input batch, "
"which will be used in backward calculation.");
AddOutput("Loss",
"(Tensor, default: Tensor<float>), A tensor in same shape with "
"Input(Logits) "
"except the shape in dimension :attr:`axis` as 1. The cross "
"entropy loss.");
AddAttr<int>("ring_id", "(int default 0) nccl communication ring id.")
.SetDefault(0);
AddAttr<int>("rank",
"(int default 0) rank id for CSoftmaxWithCrossEntropy.")
.SetDefault(0);
AddAttr<int>("nranks",
"(int default 1) nranks id for CSoftmaxWithCrossEntropy.")
.SetDefault(0);
AddComment(R"DOC(
CSoftmaxWithCrossEntropy Operator
)DOC");
}
};
class CSoftmaxWithCrossEntropyOpGrad : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE_EQ(ctx->HasInput(framework::GradVarName("Loss")), true,
platform::errors::InvalidArgument(
"Input(Loss@Grad) should not be null."));
PADDLE_ENFORCE_EQ(ctx->HasInput("Softmax"), true,
platform::errors::InvalidArgument(
"Input(Softmax) should be not null."));
PADDLE_ENFORCE_EQ(
ctx->HasInput("Label"), true,
platform::errors::InvalidArgument("Input(Label) should be not null."));
PADDLE_ENFORCE_EQ(ctx->HasOutput(framework::GradVarName("Logits")), true,
platform::errors::InvalidArgument(
"Output(Logits@Grad) should be not null."));
ctx->SetOutputDim(framework::GradVarName("Logits"),
ctx->GetInputDim("Softmax"));
}
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType(
ctx, framework::GradVarName("Loss")),
ctx.device_context());
}
};
template <typename T>
class CSoftmaxWithCrossEntropyOpGradMaker
: public framework::SingleGradOpMaker<T> {
public:
using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
protected:
void Apply(GradOpPtr<T> op) const override {
op->SetType("c_softmax_with_cross_entropy_grad");
op->SetInput("Softmax", this->Output("Softmax"));
op->SetInput("Label", this->Input("Label"));
op->SetInput(framework::GradVarName("Loss"), this->OutputGrad("Loss"));
op->SetAttrMap(this->Attrs());
op->SetOutput(framework::GradVarName("Logits"), this->InputGrad("Logits"));
}
};
DECLARE_INPLACE_OP_INFERER(CSoftmaxWithCrossEntropyInplaceInferer,
{"Logits", "Softmax"});
DECLARE_INPLACE_OP_INFERER(CSoftmaxWithCrossEntropyGradInplaceInferer,
{"Softmax", framework::GradVarName("Logits")});
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
namespace plat = paddle::platform;
REGISTER_OPERATOR(
c_softmax_with_cross_entropy, ops::CSoftmaxWithCrossEntropyOp,
ops::CSoftmaxWithCrossEntropyOpMaker,
ops::CSoftmaxWithCrossEntropyOpGradMaker<paddle::framework::OpDesc>,
ops::CSoftmaxWithCrossEntropyOpGradMaker<paddle::imperative::OpBase>,
ops::CSoftmaxWithCrossEntropyInplaceInferer);
REGISTER_OPERATOR(c_softmax_with_cross_entropy_grad,
ops::CSoftmaxWithCrossEntropyOpGrad,
ops::CSoftmaxWithCrossEntropyGradInplaceInferer);
REGISTER_OP_CPU_KERNEL(c_softmax_with_cross_entropy,
ops::CSoftmaxWithCrossEntropyOpCPUKernel<float>,
ops::CSoftmaxWithCrossEntropyOpCPUKernel<double>,
ops::CSoftmaxWithCrossEntropyOpCPUKernel<plat::float16>);
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/collective/c_softmax_with_cross_entropy_op.h"
#include "paddle/fluid/operators/math/cross_entropy.h"
#include "paddle/fluid/operators/math/softmax_impl.h"
#include "paddle/fluid/platform/collective_helper.h"
#include "paddle/fluid/platform/nccl_helper.h"
#include "paddle/fluid/string/string_helper.h"
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
static constexpr int kNumCUDAThreads = 512;
static constexpr int kNumMaximumNumBlocks = 4096;
static inline int NumBlocks(const int N) {
return std::min((N + kNumCUDAThreads - 1) / kNumCUDAThreads,
kNumMaximumNumBlocks);
}
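// Gathers the logit of the ground-truth class for each row of the local
// shard. A rank only owns the class columns [start_index, end_index); rows
// whose label falls outside this range keep the zero that predicted_logits
// was initialized with, and the caller all-reduces (ncclSum) the partial
// results to recover the full x[i, label[i]].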
template <typename T, typename IndexT>
__global__ void MaskLabelByIndex(T* predicted_logits, const T* logit,
const IndexT* label, const int start_index,
const int end_index, const int64_t N,
const int64_t D, const int nranks) {
CUDA_KERNEL_LOOP(i, N) {
auto real_label = label[i];
PADDLE_ENFORCE((real_label < D * nranks) && (real_label >= 0),
"The index is out of bounds, "
"please check whether the value of label and "
"input meet the class number. It should "
"be less than [%d], but received [%d]",
D * nranks, real_label);
if (real_label >= start_index && real_label < end_index) {
predicted_logits[i] = logit[i * D + real_label - start_index];
}
}
}
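// Backward counterpart: logits_grad enters holding the saved softmax values
// and leaves holding this rank's shard of
//   d_logits[i][j] = (softmax[i][j] - 1{start_index + j == label[i]}) * d_loss[i].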
template <typename T, typename IndexT>
__global__ void MaskLabelByIndexGrad(T* logits_grad, const T* loss_grad,
const IndexT* labels,
const int start_index, const int end_index,
const int64_t N, const int64_t D) {
CUDA_KERNEL_LOOP(i, N * D) {
auto row = i / D;
auto col = i % D;
if ((col + start_index) == labels[row]) {
logits_grad[i] = (logits_grad[i] - static_cast<T>(1.0)) * loss_grad[row];
} else {
logits_grad[i] *= loss_grad[row];
}
}
}
template <typename T>
class CSoftmaxWithCrossEntropyOpCUDAKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
const Tensor* logits = ctx.Input<Tensor>("Logits");
const Tensor* labels = ctx.Input<Tensor>("Label");
Tensor* softmax = ctx.Output<Tensor>("Softmax");
Tensor* loss = ctx.Output<Tensor>("Loss");
const int rid = ctx.Attr<int>("ring_id");
const int nranks = ctx.Attr<int>("nranks");
const int rank = ctx.Attr<int>("rank");
const auto& place = ctx.GetPlace();
const auto& comm = platform::NCCLCommContext::Instance().Get(rid, place);
auto& dev_ctx = ctx.template device_context<platform::CUDADeviceContext>();
// use global calculate stream
const auto stream = static_cast<platform::CUDADeviceContext*>(
platform::DeviceContextPool::Instance().Get(place))
->stream();
// allocate memory on device.
softmax->mutable_data<T>(place);
loss->mutable_data<T>(place);
const auto& logits_dims = logits->dims();
const auto& labels_dims = labels->dims();
const int axis = logits_dims.size() - 1;
const int N = SizeToAxis(axis, logits_dims);
const int D = SizeFromAxis(axis, logits_dims);
Tensor logits_2d, softmax_2d, loss_2d;
logits_2d.ShareDataWith(*logits).Resize({N, D});
softmax_2d.ShareDataWith(*softmax).Resize({N, D});
loss_2d.ShareDataWith(*loss).Resize({N, 1});
auto eigen_logits = math::EigenMatrix<T>::From(logits_2d);
auto eigen_softmax = math::EigenMatrix<T>::From(softmax_2d);
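// Overview of the computation below (notation for reference; x denotes the
// logically concatenated logits across all nranks shards, y_i the label):
//   m_i = max_j x_ij                       (step 1, ncclMax all-reduce)
//   s_i = sum_j exp(x_ij - m_i)            (steps 2, 4, 5, ncclSum all-reduce)
//   loss_i = log(s_i) - (x_{i,y_i} - m_i)  (step 3 and the final loss update)
//   softmax_ij = exp(x_ij - m_i) / s_i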
// step 1, obtain logit_max
Tensor logits_max;
logits_max =
ctx.AllocateTmpTensor<T, platform::CUDADeviceContext>({N, 1}, dev_ctx);
void* logits_max_buff = logits_max.mutable_data<T>(place);
auto eigen_logits_max = math::EigenMatrix<T>::From(logits_max);
Eigen::DSizes<int, 1> along_axis(1);
eigen_logits_max.device(*dev_ctx.eigen_device()) =
eigen_logits.maximum(along_axis);
PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::ncclAllReduce(
logits_max_buff, logits_max_buff, logits_max.numel(),
platform::ToNCCLDataType(logits_max.type()), ncclMax, comm->comm(),
stream));
// step 2, obtain logit - logit_max
Eigen::DSizes<int, 2> batch_by_one(N, 1);
Eigen::DSizes<int, 2> one_by_class(1, D);
eigen_softmax.device(*dev_ctx.eigen_device()) =
(eigen_logits -
eigen_logits_max.reshape(batch_by_one).broadcast(one_by_class))
.unaryExpr(math::ValueClip<T>());
// step 3, obtain predict target
Tensor predicted_logits;
predicted_logits =
ctx.AllocateTmpTensor<T, platform::CUDADeviceContext>({N, 1}, dev_ctx);
predicted_logits.mutable_data<T>(place);
auto t = framework::EigenVector<T>::Flatten(predicted_logits);
t.device(*dev_ctx.eigen_device()) = t.constant(static_cast<T>(0));
const int start_index = rank * D;
const int end_index = start_index + D;
int blocks = NumBlocks(N);
int threads = kNumCUDAThreads;
const auto& label_type = labels->type();
if (label_type == framework::proto::VarType::INT32) {
MaskLabelByIndex<T, int32_t><<<blocks, threads, 0, dev_ctx.stream()>>>(
predicted_logits.data<T>(), softmax_2d.data<T>(),
labels->data<int32_t>(), start_index, end_index, N, D, nranks);
} else if (label_type == framework::proto::VarType::INT64) {
MaskLabelByIndex<T, int64_t><<<blocks, threads, 0, dev_ctx.stream()>>>(
predicted_logits.data<T>(), softmax_2d.data<T>(),
labels->data<int64_t>(), start_index, end_index, N, D, nranks);
}
void* predict_logits_buff = predicted_logits.mutable_data<T>(place);
PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::ncclAllReduce(
predict_logits_buff, predict_logits_buff, predicted_logits.numel(),
platform::ToNCCLDataType(predicted_logits.type()), ncclSum,
comm->comm(), stream));
// step 4, obtain exp(logit)
eigen_softmax.device(*dev_ctx.eigen_device()) = eigen_softmax.exp();
// step 5, obtain sum_exp_logits
Tensor sum_exp_logits;
sum_exp_logits =
ctx.AllocateTmpTensor<T, platform::CUDADeviceContext>({N, 1}, dev_ctx);
void* sum_exp_logits_buff = sum_exp_logits.mutable_data<T>(place);
auto eigen_sum_exp_logits = math::EigenMatrix<T>::From(sum_exp_logits);
eigen_sum_exp_logits.device(*dev_ctx.eigen_device()) =
eigen_softmax.sum(along_axis);
PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::ncclAllReduce(
sum_exp_logits_buff, sum_exp_logits_buff, sum_exp_logits.numel(),
platform::ToNCCLDataType(sum_exp_logits.type()), ncclSum, comm->comm(),
stream));
auto eigen_loss = math::EigenMatrix<T>::From(loss_2d);
auto eigen_predicted_logits = math::EigenMatrix<T>::From(predicted_logits);
eigen_loss.device(*dev_ctx.eigen_device()) =
(eigen_sum_exp_logits.log().unaryExpr(math::TolerableValue<T>()) -
eigen_predicted_logits)
.unaryExpr(math::TolerableValue<T>());
eigen_softmax.device(*dev_ctx.eigen_device()) =
(eigen_softmax *
eigen_sum_exp_logits.inverse().broadcast(one_by_class));
}
};
template <typename T>
class CSoftmaxWithCrossEntropyGradCUDAKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
const Tensor* labels = context.Input<Tensor>("Label");
const Tensor* loss_grad =
context.Input<Tensor>(framework::GradVarName("Loss"));
Tensor* logit_grad =
context.Output<Tensor>(framework::GradVarName("Logits"));
const Tensor* softmax = context.Input<Tensor>("Softmax");
const int rank = context.Attr<int>("rank");
auto& dev_ctx =
context.template device_context<platform::CUDADeviceContext>();
if (logit_grad != softmax) {
framework::TensorCopy(*softmax, context.GetPlace(),
context.device_context(), logit_grad);
}
const auto softmax_dims = softmax->dims();
const int axis = softmax_dims.size() - 1;
const int N = SizeToAxis(axis, softmax_dims);
const int D = SizeFromAxis(axis, softmax_dims);
Tensor logit_grad_2d;
logit_grad_2d.ShareDataWith(*logit_grad).Resize({N, D});
int blocks = NumBlocks(N * D);
int threads = kNumCUDAThreads;
const auto& label_type = labels->type();
const int start_index = rank * D;
const int end_index = start_index + D;
if (label_type == framework::proto::VarType::INT32) {
MaskLabelByIndexGrad<T,
int32_t><<<blocks, threads, 0, dev_ctx.stream()>>>(
logit_grad_2d.data<T>(), loss_grad->data<T>(),
labels->data<int32_t>(), start_index, end_index, N, D);
} else if (label_type == framework::proto::VarType::INT64) {
MaskLabelByIndexGrad<T,
int64_t><<<blocks, threads, 0, dev_ctx.stream()>>>(
logit_grad_2d.data<T>(), loss_grad->data<T>(),
labels->data<int64_t>(), start_index, end_index, N, D);
}
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
namespace plat = paddle::platform;
REGISTER_OP_CUDA_KERNEL(
c_softmax_with_cross_entropy,
ops::CSoftmaxWithCrossEntropyOpCUDAKernel<float>,
ops::CSoftmaxWithCrossEntropyOpCUDAKernel<double>,
ops::CSoftmaxWithCrossEntropyOpCUDAKernel<plat::float16>);
REGISTER_OP_CUDA_KERNEL(
c_softmax_with_cross_entropy_grad,
ops::CSoftmaxWithCrossEntropyGradCUDAKernel<float>,
ops::CSoftmaxWithCrossEntropyGradCUDAKernel<paddle::platform::float16>,
ops::CSoftmaxWithCrossEntropyGradCUDAKernel<double>);
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <algorithm>
#include <utility>
#include <vector>
#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/math/cross_entropy.h"
#include "paddle/fluid/operators/math/softmax.h"
#include "paddle/fluid/operators/softmax_op.h"
namespace paddle {
namespace operators {
template <typename T>
class CSoftmaxWithCrossEntropyOpCPUKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
PADDLE_THROW(platform::errors::Unavailable(
    "Do not support c_softmax_with_cross_entropy for cpu kernel now."));
}
};
} // namespace operators
} // namespace paddle
...
@@ -954,6 +954,35 @@ class _Linear(layers.Layer):
        self.weight.shape[0], self.weight.shape[1], self._dtype, name_str)
def _c_softmax_with_cross_entropy(logits,
label,
group=None,
return_softmax=False):
if group is not None and not group.is_member():
return
ring_id = 0 if group is None else group.id
global_rank = _get_global_env().rank
rank = global_rank if group is None else group.get_group_rank(global_rank)
nranks = _get_global_env().world_size if group is None else group.nranks
input_dims = len(list(logits.shape))
label_dims = len(list(label.shape))
if input_dims - 1 != label_dims and input_dims != label_dims:
raise ValueError(
    'Expected input_dims - 1 == label_dims or input_dims == label_dims '
    '(got input_dims {}, label_dims {})'.format(input_dims, label_dims))
if input_dims - 1 == label_dims:
label = paddle.unsqueeze(label, axis=-1)
if in_dygraph_mode():
softmax, loss = core.ops.c_softmax_with_cross_entropy(
logits, label, 'ring_id', ring_id, 'rank', rank, 'nranks', nranks)
if not return_softmax:
return loss
else:
return loss, softmax
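As a concrete example of the sharding convention the operator assumes: with nranks = 2 and 4 classes in total, each rank holds D = 2 logit columns, so rank 0 owns classes {0, 1} and rank 1 owns classes {2, 3}; a label value of 3 therefore contributes its predicted logit only on rank 1, and the all-reduces inside the operator recombine the per-rank partial max, predicted-logit, and exp-sum terms into the full loss.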
def _linear(x, weight, bias=None, name=None):
    """
    Function Linear
...
...
@@ -15,6 +15,7 @@
from .parallel_layers import VocabParallelEmbedding  # noqa: F401
from .parallel_layers import ColumnParallelLinear  # noqa: F401
from .parallel_layers import RowParallelLinear  # noqa: F401
+from .parallel_layers import ParallelCrossEntropy  # noqa: F401
from .parallel_layers import LayerDesc  # noqa: F401
from .parallel_layers import PipelineLayer  # noqa: F401
from .parallel_layers import RNGStatesTracker  # noqa: F401
...
...
@@ -15,6 +15,7 @@
from .mp_layers import VocabParallelEmbedding  # noqa: F401
from .mp_layers import ColumnParallelLinear  # noqa: F401
from .mp_layers import RowParallelLinear  # noqa: F401
+from .mp_layers import ParallelCrossEntropy  # noqa: F401
from .pp_layers import LayerDesc  # noqa: F401
from .pp_layers import PipelineLayer  # noqa: F401
from .random import RNGStatesTracker  # noqa: F401
...
...
@@ -18,6 +18,7 @@ from .random import get_rng_state_tracker
from paddle.nn import functional as F
from paddle import framework
from ...base import topology as tp
+from paddle.autograd import PyLayer

__all__ = []

@@ -243,3 +244,19 @@ class RowParallelLinear(Layer):
        output = output_ + self.bias if self.bias is not None else output_
        return output

class ParallelCrossEntropy(Layer):
def __init__(self, name=None):
super(ParallelCrossEntropy, self).__init__()
self.name = name
self.model_parallel_group = tp._HYBRID_PARALLEL_GROUP.get_model_parallel_group(
)
self.world_size = tp._HYBRID_PARALLEL_GROUP.get_model_parallel_world_size(
)
self.rank = tp._HYBRID_PARALLEL_GROUP.get_model_parallel_rank()
def forward(self, input, label):
loss = paddle.distributed.collective._c_softmax_with_cross_entropy(
input, label, group=self.model_parallel_group)
return loss
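For context, a hedged sketch of how this layer is typically combined with a column-parallel output projection so that the logits stay sharded over the vocabulary (the gather_output argument of ColumnParallelLinear is assumed from the existing mp_layers API; verify against its actual signature):

import paddle
import paddle.distributed.fleet as fleet

class ParallelLMHead(paddle.nn.Layer):
    # Sketch only: assumes a hybrid-parallel fleet is already initialized and
    # that ColumnParallelLinear supports gather_output=False, so each rank
    # keeps its own vocabulary shard of the logits.
    def __init__(self, hidden_size, vocab_size):
        super(ParallelLMHead, self).__init__()
        self.proj = fleet.meta_parallel.ColumnParallelLinear(
            hidden_size, vocab_size, has_bias=True, gather_output=False)
        self.loss_fn = fleet.meta_parallel.ParallelCrossEntropy()

    def forward(self, hidden, label):
        logits_shard = self.proj(hidden)          # [b, s, vocab // mp_degree]
        return self.loss_fn(logits_shard, label)  # [b, s, 1]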
...
@@ -269,6 +269,35 @@ class TestDistTraning(unittest.TestCase):
        np.testing.assert_allclose(loss_a.numpy(), loss_b.numpy())
def test_parallel_cross_entropy(self):
batch_size = 2
seq_length = 1
class_size_per_card = 2
vocab_size = class_size_per_card * self.model_parallel_size
seed = 1025
set_random_seed(seed)
rank_id = dist.get_rank()
# model_a: the parallel cross entropy under test; model_b: single-card reference
model_a = fleet.meta_parallel.ParallelCrossEntropy()
model_b = paddle.nn.CrossEntropyLoss(reduction="none")
paddle.seed(rank_id * 10)
random.seed(seed)
np.random.seed(seed)
for _ in range(5):
np_label = np.random.randint(0, vocab_size,
(batch_size, seq_length))
label = paddle.to_tensor(np_label, dtype="int64")
data = paddle.randn(
shape=[batch_size, seq_length, class_size_per_card],
dtype='float32')
data.stop_gradient = False
check_group = dist.new_group(list(range(self.model_parallel_size)))
integral_data = []
partial_data = data.clone().detach()
paddle.distributed.all_gather(
integral_data, partial_data, group=check_group)
integral_data = paddle.concat(integral_data, axis=-1)
integral_data = integral_data.detach().clone()
integral_data.stop_gradient = False
loss_a = model_a(data, label).sum() / batch_size
loss_b = model_b(integral_data, label).sum() / batch_size
print("loss_a: ", loss_a.numpy(), "loss_b: ", loss_b.numpy())
np.testing.assert_allclose(
loss_a.numpy(), loss_b.numpy(), rtol=1e-6)
loss_a.backward()
loss_b.backward()
integral_grad = []
partial_grad = data.grad.clone().detach()
paddle.distributed.all_gather(
integral_grad, partial_grad, group=check_group)
integral_grad = paddle.concat(integral_grad, axis=-1)
np.testing.assert_allclose(
integral_data.grad.numpy(), integral_grad.numpy(), rtol=1e-6)
if __name__ == '__main__':
    unittest.main()