Unverified commit e6ec98fe authored by: Chen Weihang, committed by: GitHub

[Phi] Move softmax with cross entropy kernel into phi (#40832)

* add cross_entropy_with_softmax phi kernel

* remove softmax_with_cross_entropy kernel

* add softmax_with_cross_entropy grad kernel

* remove original op kernel

* refine cross entropy impl

* fix pointer error

* revert kernel cu change

* fix xpu failed

* fix cinn failed

* fix npu failed

* add forward sig

* add check_nan_inf for pt kernel

* remove repeat cmake item

* fix unittest error
Parent d65a7a46
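The "add forward sig" item in the message above refers to an op-to-phi argument mapping. A minimal sketch of what such a signature file plausibly looks like, with the kernel and argument names inferred from this diff rather than copied from the repository:

#include "paddle/phi/core/compat/op_utils.h"

namespace phi {

// Hedged sketch: route the legacy op's inputs/attrs/outputs to the new phi
// kernel name introduced by this commit.
KernelSignature SoftmaxWithCrossEntropyOpArgumentMapping(
    const ArgumentMappingContext& ctx) {
  return KernelSignature(
      "cross_entropy_with_softmax",
      {"Logits", "Label"},
      {"soft_label", "use_softmax", "numeric_stable_mode", "ignore_index",
       "axis"},
      {"Softmax", "Loss"});
}

}  // namespace phi

PD_REGISTER_ARG_MAPPING_FN(softmax_with_cross_entropy,
                           phi::SoftmaxWithCrossEntropyOpArgumentMapping);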
......@@ -35,7 +35,7 @@ USE_OP_ITSELF(elementwise_add);
USE_OP_ITSELF(sigmoid);
USE_OP_ITSELF(tanh);
USE_OP_ITSELF(elementwise_mul);
USE_OP(softmax_with_cross_entropy);
USE_OP_ITSELF(softmax_with_cross_entropy);
USE_OP_ITSELF(reduce_mean);
USE_OP_ITSELF(reduce_sum);
USE_OP_ITSELF(reduce_sum_grad);
......@@ -83,6 +83,8 @@ PD_DECLARE_KERNEL(max_raw, GPU, ALL_LAYOUT);
PD_DECLARE_KERNEL(sgd, GPU, ALL_LAYOUT);
PD_DECLARE_KERNEL(slice, GPU, ALL_LAYOUT);
PD_DECLARE_KERNEL(slice_grad, GPU, ALL_LAYOUT);
PD_DECLARE_KERNEL(cross_entropy_with_softmax, GPU, ALL_LAYOUT);
PD_DECLARE_KERNEL(cross_entropy_with_softmax_grad, GPU, ALL_LAYOUT);
PD_DECLARE_KERNEL(sqrt, GPU, ALL_LAYOUT);
DECLARE_double(eager_delete_tensor_gb);
......
......@@ -87,7 +87,7 @@ phi::KernelKey TransOpKernelTypeToPhiKernelKey(
} else if (kernel_type.library_type_ == LibraryType::kKP) {
backend = phi::Backend::KPS;
} else {
// do
// do nothing
}
paddle::experimental::DataLayout layout = kernel_type.data_layout_;
paddle::experimental::DataType dtype =
......
......@@ -484,6 +484,11 @@ static void PreparedOpRunPtImpl(
pt_kernel(&pt_kernel_context);
}
if (FLAGS_check_nan_inf) {
framework::details::CheckOpHasNanOrInfInDygraph<VarType>(
op.Type(), outs, dev_ctx->GetPlace());
}
if (FLAGS_benchmark) {
dev_ctx->Wait();
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
......
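The hunk above extends FLAGS_check_nan_inf coverage to phi ("pt") kernels in dygraph. A hedged sketch of toggling the flag from C++ (gflags-style; the flag's definition site is elsewhere in Paddle and not shown in this diff):

#include "gflags/gflags.h"

DECLARE_bool(check_nan_inf);  // assumed declared by Paddle's flag machinery

void EnableNanInfChecking() {
  // With this set, the branch added above scans every kernel output.
  FLAGS_check_nan_inf = true;
}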
......@@ -14,6 +14,7 @@ limitations under the License. */
#include "paddle/fluid/operators/math/cross_entropy.h"
#include "paddle/fluid/framework/convert_utils.h"
#include "paddle/phi/backends/cpu/cpu_context.h"
namespace paddle {
namespace platform {
......@@ -89,13 +90,11 @@ struct HardLabelCrossEntropyCPUFunctorImpl {
const int axis_dim_;
};
template <typename T>
class CrossEntropyFunctor<platform::CPUDeviceContext, T> {
public:
void operator()(const platform::CPUDeviceContext& ctx, framework::Tensor* out,
const framework::Tensor* prob,
const framework::Tensor* labels, const bool softLabel,
const int ignore_index, const int axis_dim) {
template <typename DeviceContext, typename T>
void CrossEntropyFunctor<DeviceContext, T>::operator()(
const DeviceContext& ctx, framework::Tensor* out,
const framework::Tensor* prob, const framework::Tensor* labels,
const bool softLabel, const int ignore_index, const int axis_dim) {
if (softLabel) {
const int batch_size = prob->dims()[0];
const int num_classes = prob->dims()[1];
......@@ -111,16 +110,18 @@ class CrossEntropyFunctor<platform::CPUDeviceContext, T> {
.reshape(batch_axis_remain)
.sum(Eigen::DSizes<int, 1>(1)));
} else {
HardLabelCrossEntropyCPUFunctorImpl<T> functor_impl(
out, prob, labels, ignore_index, axis_dim);
framework::VisitIntDataType(
framework::TransToProtoVarType(labels->dtype()), functor_impl);
}
HardLabelCrossEntropyCPUFunctorImpl<T> functor_impl(out, prob, labels,
ignore_index, axis_dim);
framework::VisitIntDataType(framework::TransToProtoVarType(labels->dtype()),
functor_impl);
}
};
}
template class CrossEntropyFunctor<platform::CPUDeviceContext, float>;
template class CrossEntropyFunctor<platform::CPUDeviceContext, double>;
template class CrossEntropyFunctor<phi::CPUContext, float>;
template class CrossEntropyFunctor<phi::CPUContext, double>;
} // namespace math
} // namespace operators
} // namespace paddle
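The refactor above is the recurring pattern of this commit: the functor body is written once over a generic DeviceContext, then explicitly instantiated for both the legacy fluid context and the new phi context. A self-contained sketch of the pattern with stand-in context types (names hypothetical):

struct LegacyCPUContext {};  // stand-in for platform::CPUDeviceContext
struct PhiCPUContext {};     // stand-in for phi::CPUContext

template <typename DeviceContext, typename T>
class Functor {
 public:
  void operator()(const DeviceContext& ctx, T* out);
};

// One shared definition replaces the per-context class specialization.
template <typename DeviceContext, typename T>
void Functor<DeviceContext, T>::operator()(const DeviceContext& ctx, T* out) {
  *out = static_cast<T>(0);  // shared implementation lives here
}

// Explicit instantiation keeps both context types available to the linker.
template class Functor<LegacyCPUContext, float>;
template class Functor<PhiCPUContext, float>;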
......@@ -17,6 +17,7 @@ limitations under the License. */
#include "paddle/fluid/operators/math/cross_entropy.h"
#include "paddle/fluid/platform/device/gpu/gpu_device_function.h"
#include "paddle/fluid/platform/device/gpu/gpu_primitives.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
namespace paddle {
namespace operators {
......@@ -93,13 +94,11 @@ struct HardLabelCrossEntropyCUDAFunctorImpl {
gpuStream_t stream_;
};
template <typename T>
class CrossEntropyFunctor<platform::CUDADeviceContext, T> {
public:
void operator()(const platform::CUDADeviceContext& ctx,
framework::Tensor* out, const framework::Tensor* prob,
const framework::Tensor* labels, const bool softLabel,
const int ignore_index, const int axis_dim) {
template <typename DeviceContext, typename T>
void CrossEntropyFunctor<DeviceContext, T>::operator()(
const DeviceContext& ctx, framework::Tensor* out,
const framework::Tensor* prob, const framework::Tensor* labels,
const bool softLabel, const int ignore_index, const int axis_dim) {
const T* prob_data = prob->data<T>();
T* loss_data = out->mutable_data<T>(ctx.GetPlace());
......@@ -126,13 +125,17 @@ class CrossEntropyFunctor<platform::CUDADeviceContext, T> {
framework::VisitDataType(framework::TransToProtoVarType(labels->dtype()),
functor);
}
}
};
}
template class CrossEntropyFunctor<platform::CUDADeviceContext, float>;
template class CrossEntropyFunctor<platform::CUDADeviceContext, double>;
template class CrossEntropyFunctor<platform::CUDADeviceContext,
platform::float16>;
template class CrossEntropyFunctor<phi::GPUContext, float>;
template class CrossEntropyFunctor<phi::GPUContext, double>;
template class CrossEntropyFunctor<phi::GPUContext, platform::float16>;
} // namespace math
} // namespace operators
} // namespace paddle
......@@ -29,9 +29,9 @@ using DataLayout = platform::DataLayout;
template <typename T>
using CudnnDataType = platform::CudnnDataType<T>;
template <typename T>
void SoftmaxCUDNNFunctor<T>::operator()(
const platform::CUDADeviceContext& context, const framework::Tensor* X,
template <typename T, typename DeviceContext>
void SoftmaxCUDNNFunctor<T, DeviceContext>::operator()(
const DeviceContext& context, const framework::Tensor* X,
framework::Tensor* Y) {
// ------------------- cudnn descriptors ---------------------
ScopedTensorDescriptor xDesc;
......@@ -69,9 +69,9 @@ void SoftmaxCUDNNFunctor<T>::operator()(
#endif
}
template <typename T>
void SoftmaxGradCUDNNFunctor<T>::operator()(
const platform::CUDADeviceContext& context, const framework::Tensor* Y,
template <typename T, typename DeviceContext>
void SoftmaxGradCUDNNFunctor<T, DeviceContext>::operator()(
const DeviceContext& context, const framework::Tensor* Y,
const framework::Tensor* YGrad, framework::Tensor* XGrad) {
// ------------------- cudnn descriptors ---------------------
ScopedTensorDescriptor yDesc;
......@@ -116,19 +116,31 @@ void SoftmaxGradCUDNNFunctor<T>::operator()(
#endif
}
template class SoftmaxCUDNNFunctor<float>;
template class SoftmaxCUDNNFunctor<platform::float16>;
template class SoftmaxGradCUDNNFunctor<float>;
template class SoftmaxGradCUDNNFunctor<platform::float16>;
template class SoftmaxCUDNNFunctor<float, platform::CUDADeviceContext>;
template class SoftmaxCUDNNFunctor<platform::float16,
platform::CUDADeviceContext>;
template class SoftmaxGradCUDNNFunctor<float, platform::CUDADeviceContext>;
template class SoftmaxGradCUDNNFunctor<platform::float16,
platform::CUDADeviceContext>;
template class SoftmaxCUDNNFunctor<float, phi::GPUContext>;
template class SoftmaxCUDNNFunctor<platform::float16, phi::GPUContext>;
template class SoftmaxGradCUDNNFunctor<float, phi::GPUContext>;
template class SoftmaxGradCUDNNFunctor<platform::float16, phi::GPUContext>;
#if CUDNN_VERSION_MIN(8, 1, 0)
template class SoftmaxCUDNNFunctor<platform::bfloat16>;
template class SoftmaxGradCUDNNFunctor<platform::bfloat16>;
template class SoftmaxCUDNNFunctor<platform::bfloat16,
platform::CUDADeviceContext>;
template class SoftmaxGradCUDNNFunctor<platform::bfloat16,
platform::CUDADeviceContext>;
template class SoftmaxCUDNNFunctor<platform::bfloat16, phi::GPUContext>;
template class SoftmaxGradCUDNNFunctor<platform::bfloat16, phi::GPUContext>;
#endif
// MIOPEN does not support double
#ifndef PADDLE_WITH_HIP
template class SoftmaxCUDNNFunctor<double>;
template class SoftmaxGradCUDNNFunctor<double>;
template class SoftmaxCUDNNFunctor<double, platform::CUDADeviceContext>;
template class SoftmaxGradCUDNNFunctor<double, platform::CUDADeviceContext>;
template class SoftmaxCUDNNFunctor<double, phi::GPUContext>;
template class SoftmaxGradCUDNNFunctor<double, phi::GPUContext>;
#endif
template class SoftmaxFunctor<platform::CUDADeviceContext, platform::float16,
......
......@@ -36,19 +36,18 @@ class SoftmaxGradFunctor {
};
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
template <typename T>
template <typename T, typename DeviceContext>
class SoftmaxCUDNNFunctor {
public:
void operator()(const platform::CUDADeviceContext& context,
const framework::Tensor* X, framework::Tensor* Y);
void operator()(const DeviceContext& context, const framework::Tensor* X,
framework::Tensor* Y);
};
template <typename T>
template <typename T, typename DeviceContext>
class SoftmaxGradCUDNNFunctor {
public:
void operator()(const platform::CUDADeviceContext& context,
const framework::Tensor* Y, const framework::Tensor* y_grad,
framework::Tensor* x_grad);
void operator()(const DeviceContext& context, const framework::Tensor* Y,
const framework::Tensor* y_grad, framework::Tensor* x_grad);
};
#endif
......
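With the new template parameter, call sites name the device context explicitly. A hedged call-site sketch (mirroring the sequence_softmax change just below; x and y are tensors prepared elsewhere):

math::SoftmaxCUDNNFunctor<float, platform::CUDADeviceContext>()(
    ctx.template device_context<platform::CUDADeviceContext>(), &x, &y);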
......@@ -58,7 +58,7 @@ class SequenceSoftmaxCUDNNKernel : public framework::OpKernel<T> {
phi::make_ddim({1UL, end_pos - start_pos});
x_i.Resize(dims_i);
out_i.Resize(dims_i);
math::SoftmaxCUDNNFunctor<T>()(
math::SoftmaxCUDNNFunctor<T, platform::CUDADeviceContext>()(
ctx.template device_context<platform::CUDADeviceContext>(), &x_i,
&out_i);
}
......@@ -93,7 +93,7 @@ class SequenceSoftmaxGradCUDNNKernel : public framework::OpKernel<T> {
out_i.Resize(dims_i);
out_grad_i.Resize(dims_i);
x_grad_i.Resize(dims_i);
math::SoftmaxGradCUDNNFunctor<T>()(
math::SoftmaxGradCUDNNFunctor<T, platform::CUDADeviceContext>()(
ctx.template device_context<platform::CUDADeviceContext>(), &out_i,
&out_grad_i, &x_grad_i);
}
......
......@@ -12,8 +12,9 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/softmax_with_cross_entropy_op.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/phi/kernels/funcs/axis_utils.h"
namespace paddle {
namespace operators {
......@@ -335,12 +336,6 @@ REGISTER_OPERATOR(softmax_with_cross_entropy, ops::SoftmaxWithCrossEntropyOp,
REGISTER_OPERATOR(softmax_with_cross_entropy_grad,
ops::SoftmaxWithCrossEntropyOpGrad,
ops::SoftmaxWithCrossEntropyGradInplaceInferer);
REGISTER_OP_CPU_KERNEL(softmax_with_cross_entropy,
ops::SoftmaxWithCrossEntropyKernel<float>,
ops::SoftmaxWithCrossEntropyKernel<double>);
REGISTER_OP_CPU_KERNEL(softmax_with_cross_entropy_grad,
ops::SoftmaxWithCrossEntropyGradKernel<float>,
ops::SoftmaxWithCrossEntropyGradKernel<double>);
REGISTER_OP_VERSION(softmax_with_cross_entropy)
#if defined(PADDLE_WITH_ASCEND_CL) || defined(PADDLE_WITH_MLU)
......
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/fluid/framework/convert_utils.h"
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/math/cross_entropy.h"
#include "paddle/fluid/operators/math/softmax.h"
#include "paddle/phi/kernels/funcs/axis_utils.h"
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
template <typename T, typename Visitor>
struct SoftmaxWithCrossEntropyFunctor {
public:
SoftmaxWithCrossEntropyFunctor(const framework::ExecutionContext& context,
const framework::Tensor& labels,
const bool soft_label, const Visitor& visitor)
: context_(context),
labels_(labels),
soft_label_(soft_label),
visitor_(visitor) {}
template <typename U>
void apply() const {
visitor_.template Apply<U>(context_, labels_, soft_label_);
}
private:
const framework::ExecutionContext& context_;
const framework::Tensor& labels_;
const bool soft_label_;
const Visitor& visitor_;
};
template <typename T, typename Visitor>
static void RunSoftmaxWithCrossEntropyFunctor(
const framework::ExecutionContext& context, const Visitor& visitor) {
const auto* labels = context.Input<framework::Tensor>("Label");
const bool soft_label = context.Attr<bool>("soft_label");
SoftmaxWithCrossEntropyFunctor<T, Visitor> functor(context, *labels,
soft_label, visitor);
auto dtype = framework::TransToProtoVarType(labels->dtype());
if (soft_label) {
PADDLE_ENFORCE_EQ(
dtype, framework::DataTypeTrait<T>::DataType(),
platform::errors::InvalidArgument("The Input(Label) should be with the "
"same data type as Input(Logits)."));
functor.template apply<T>();
} else {
framework::VisitIntDataType(dtype, functor);
}
}
template <typename T>
class SoftmaxWithCrossEntropyKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
PADDLE_ENFORCE_EQ(
platform::is_cpu_place(context.GetPlace()), true,
platform::errors::Unimplemented("This kernel only runs on CPU."));
const bool use_softmax = context.Attr<bool>("use_softmax");
const Tensor* labels = context.Input<Tensor>("Label");
const bool soft_label = context.Attr<bool>("soft_label");
// not used with the softmax op; the input is already the softmax
if (!use_softmax) {
const Tensor* softmax = context.Input<Tensor>("Logits");
Tensor* softmax_out = context.Output<Tensor>("Softmax");
Tensor* loss = context.Output<Tensor>("Loss");
const int rank = softmax->dims().size();
const int axis =
phi::funcs::CanonicalAxis(context.Attr<int>("axis"), rank);
int axis_dim = softmax->dims()[axis];
PADDLE_ENFORCE_GT(
axis_dim, 0,
platform::errors::InvalidArgument(
"The axis dimention should be larger than 0, but received "
"axis dimention is %d.",
axis_dim));
softmax_out->mutable_data<T>(context.GetPlace());
loss->mutable_data<T>(context.GetPlace());
const int n = phi::funcs::SizeToAxis(axis, softmax->dims());
PADDLE_ENFORCE_GT(
n, 0, platform::errors::InvalidArgument(
"The size of axis should be larger than 0, but received "
"SizeToAxis of softmax is %d.",
n));
const int d = phi::funcs::SizeFromAxis(axis, softmax->dims());
Tensor softmax_2d, labels_2d, loss_2d, softmax_out_2d;
softmax_2d.ShareDataWith(*softmax).Resize({n, d});
labels_2d.ShareDataWith(*labels).Resize({n, labels->numel() / n});
loss_2d.ShareDataWith(*loss).Resize({n, d / axis_dim});
softmax_out_2d.ShareDataWith(*softmax_out).Resize({n, d});
auto& dev_ctx =
context.template device_context<platform::CPUDeviceContext>();
math::CrossEntropyFunctor<platform::CPUDeviceContext, T>()(
dev_ctx, &loss_2d, &softmax_2d, &labels_2d, soft_label,
context.Attr<int>("ignore_index"), axis_dim);
// since the input is already the softmax,
// copy it to the output softmax directly
framework::TensorCopy(*softmax, context.GetPlace(),
context.device_context(), softmax_out);
return;
}
const Tensor* logits = context.Input<Tensor>("Logits");
Tensor* softmax = context.Output<Tensor>("Softmax");
Tensor* loss = context.Output<Tensor>("Loss");
const int rank = logits->dims().size();
const int axis = phi::funcs::CanonicalAxis(context.Attr<int>("axis"), rank);
int axis_dim = logits->dims()[axis];
PADDLE_ENFORCE_GT(
axis_dim, 0,
platform::errors::InvalidArgument(
"The axis dimention should be larger than 0, but received "
"axis dimention is %d.",
axis_dim));
softmax->mutable_data<T>(context.GetPlace());
loss->mutable_data<T>(context.GetPlace());
const int n = phi::funcs::SizeToAxis(axis, logits->dims());
PADDLE_ENFORCE_GT(
n, 0, platform::errors::InvalidArgument(
"The size of axis should be larger than 0, but received "
"SizeToAxis of logits is %d.",
n));
const int d = phi::funcs::SizeFromAxis(axis, logits->dims());
Tensor logits_2d, softmax_2d, labels_2d, loss_2d;
logits_2d.ShareDataWith(*logits).Resize({n, d});
softmax_2d.ShareDataWith(*softmax).Resize({n, d});
labels_2d.ShareDataWith(*labels).Resize({n, labels->numel() / n});
loss_2d.ShareDataWith(*loss).Resize({n, d / axis_dim});
auto& dev_ctx =
context.template device_context<platform::CPUDeviceContext>();
math::SoftmaxFunctor<platform::CPUDeviceContext, T, false>()(
dev_ctx, axis_dim, &logits_2d, &softmax_2d);
math::CrossEntropyFunctor<platform::CPUDeviceContext, T>()(
dev_ctx, &loss_2d, &softmax_2d, &labels_2d, soft_label,
context.Attr<int>("ignore_index"), axis_dim);
}
};
template <typename T>
class SoftmaxWithCrossEntropyGradKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
RunSoftmaxWithCrossEntropyFunctor<T>(context, *this);
}
template <typename LabelT>
static void Apply(const framework::ExecutionContext& context,
const framework::Tensor& labels, const bool soft_label) {
const Tensor* out_grad =
context.Input<Tensor>(framework::GradVarName("Loss"));
Tensor* logit_grad =
context.Output<Tensor>(framework::GradVarName("Logits"));
const Tensor* softmax = context.Input<Tensor>("Softmax");
const bool use_softmax = context.Attr<bool>("use_softmax");
if (logit_grad != softmax || !use_softmax) {
framework::TensorCopy(*softmax, context.GetPlace(),
context.device_context(), logit_grad);
}
auto ignore_index = context.Attr<int>("ignore_index");
const int rank = logit_grad->dims().size();
const int axis = phi::funcs::CanonicalAxis(context.Attr<int>("axis"), rank);
int axis_dim = logit_grad->dims()[axis];
PADDLE_ENFORCE_GT(
axis_dim, 0,
platform::errors::InvalidArgument(
"The axis dimention should be larger than 0, but received "
"axis dimention is %d.",
axis_dim));
const int n = phi::funcs::SizeToAxis(axis, logit_grad->dims());
PADDLE_ENFORCE_GT(
n, 0, platform::errors::InvalidArgument(
"The size of axis should be larger than 0, but received "
"SizeToAxis of logit_grad is %d.",
n));
const int d = phi::funcs::SizeFromAxis(axis, logit_grad->dims());
Tensor logit_grad_2d, labels_2d, out_grad_2d;
logit_grad_2d.ShareDataWith(*logit_grad).Resize({n, d});
labels_2d.ShareDataWith(labels).Resize({n, labels.numel() / n});
out_grad_2d.ShareDataWith(*out_grad).Resize({n, d / axis_dim});
auto out_grad_mat = framework::EigenMatrix<T>::From(out_grad_2d);
auto logit_grad_mat = framework::EigenMatrix<T>::From(logit_grad_2d);
auto& place = *context.template device_context<platform::CPUDeviceContext>()
.eigen_device();
if (!use_softmax) {
// use_softmax == false, step 1: soft labels
if (soft_label) {
auto lbl_mat = framework::EigenMatrix<T>::From(labels_2d);
logit_grad_mat.device(place) =
(-lbl_mat / logit_grad_mat); // for each sample, i is the sample id
logit_grad_mat.device(place) =
out_grad_mat.broadcast(Eigen::DSizes<int, 2>(1, axis_dim)) *
logit_grad_mat;
} else {
// use_softmax == false, step 2: hard labels
const auto* label_data = labels.template data<LabelT>();
T* logit_grad_data = logit_grad->template data<T>();
const T* out_grad_data = out_grad->template data<T>();
const int remain = d / axis_dim;
for (int i = 0; i < n; ++i) { // for each sample_1_dim
for (int j = 0; j < remain; j++) { // for each sample_other_dims
int idx = i * remain + j; // this sample's label_idx. for 1d case,
// remain=1 and j=0, so, idx = i
auto lbl = static_cast<int64_t>(label_data[idx]);
if (lbl == ignore_index) {
for (int k = 0; k < axis_dim; ++k) { // for each class id's label
logit_grad_data[i * d + k * remain + j] = 0;
}
} else {
// only at this sample's label_idx is the label 1; the others are 0,
// so, only compute this label_idx's class
logit_grad_data[i * d + lbl * remain + j] =
(-1 / logit_grad_data[i * d + lbl * remain + j]) *
out_grad_data[idx];
for (int k = 0; k < axis_dim; ++k) { // for each class id's label
if (k !=
label_data[idx]) { // label_data[idx]: this sample's label
logit_grad_data[i * d + k * remain + j] = 0;
}
}
}
}
}
}
return;
}
// for use_softmax=False, continue
if (soft_label) {
// when soft_label = True, ignore_index is not supported
auto lbl_mat = framework::EigenMatrix<T>::From(labels_2d);
logit_grad_mat.device(place) =
out_grad_mat.broadcast(Eigen::DSizes<int, 2>(1, axis_dim)) *
(logit_grad_mat - lbl_mat); // for each sample, i is the sample id
// 1) compute dy/dx by p_j - y_j or P-Y, where j is class id,
// P=logit_grad_mat[i] is all class's probs, Y=lbl_mat[i] is
// all class's labels
// 2) compute dy * dy/dx by Chain rule, dy=out_grad_mat[i]
// for high dims, e.g. (n,c) or (n,d1,...,dm, c), compute grad by matrix
// operation
} else {
logit_grad_mat.device(place) =
logit_grad_mat * // element_wise multiply
out_grad_mat.broadcast(Eigen::DSizes<int, 2>(1, axis_dim));
const auto* label_data = labels.template data<LabelT>();
T* logit_grad_data = logit_grad->template data<T>();
const T* out_grad_data = out_grad->template data<T>();
const int remain = d / axis_dim;
for (int i = 0; i < n; ++i) { // for each sample_1_dim
for (int j = 0; j < remain; j++) { // for each sample_other_dims
int idx = i * remain + j; // this sample's label_idx. for 1d case,
// remain=1 and j=0, so, idx = i
auto lbl = static_cast<int64_t>(label_data[idx]);
if (lbl == ignore_index) {
for (int k = 0; k < axis_dim; ++k) { // for each class id's label
logit_grad_data[i * d + k * remain + j] = 0;
}
} else {
// only at this sample's label_idx is the label 1; the others are 0,
// so, only compute this label_idx's class
// for 1d case, remain=1 and j=0, so, [i * d + label_data[idx] *
// remain + j] = [i * d + label_data[idx]]
// let idx_x = i * d + label_data[idx] * remain + j,
// logit_grad_data[idx_x] = logit_grad_data[idx_x] -
// out_grad_data[idx]
// note: logit_grad_mat = logit_grad_mat * out_grad_mat
// so: logit_grad_data[idx_x] = (logit_grad_data[idx_x] - 1) *
// out_grad_data[idx]
// means: dy/dp * dy= ( p - y ) * dy
logit_grad_data[i * d + lbl * remain + j] -= out_grad_data[idx];
}
}
}
}
}
};
} // namespace operators
} // namespace paddle
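The gradient comments in Apply above compress several steps; written out, these are the standard softmax cross-entropy identities (a restatement, not new to this diff), with $\bar{L}_i$ denoting out_grad:

\[
L_i = -\sum_j y_{ij}\log p_{ij},\qquad
\frac{\partial L_i}{\partial z_{ij}} = p_{ij} - y_{ij},\qquad
\texttt{logit\_grad}_{ij} = (p_{ij} - y_{ij})\,\bar{L}_i .
\]

For hard labels, $y_{ij} = 1$ only at $j = \text{label}_i$, so after the elementwise multiply the update reduces to the single subtraction performed in the loop above.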
......@@ -12,8 +12,9 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/softmax_with_cross_entropy_op.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/mlu/mlu_baseop.h"
#include "paddle/phi/kernels/funcs/axis_utils.h"
namespace paddle {
namespace operators {
......
......@@ -12,13 +12,14 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/softmax_with_cross_entropy_op.h"
#include <memory>
#include <string>
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/math/cross_entropy.h"
#include "paddle/fluid/operators/math/softmax.h"
#include "paddle/fluid/platform/device/npu/npu_op_runner.h"
#include "paddle/phi/kernels/funcs/axis_utils.h"
namespace paddle {
namespace operators {
......
......@@ -12,20 +12,23 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/softmax_with_cross_entropy_op.h"
#ifdef PADDLE_WITH_XPU
#include <memory>
#include <string>
#include <unordered_map>
#include <vector>
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/platform/device/device_wrapper.h"
#include "paddle/phi/kernels/funcs/axis_utils.h"
#include "xpu/refactor/math.h"
#include "xpu/refactor/nn.h"
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
template <typename T>
class SoftmaxWithCrossEntropyXPUKernel : public framework::OpKernel<T> {
using XPUType = typename XPUTypeTrait<T>::Type;
......
......@@ -33,6 +33,8 @@ Backend TransToPhiBackend(const phi::Place& place) {
return Backend::GPU;
} else if (allocation_type == phi::AllocationType::XPU) {
return Backend::XPU;
} else if (allocation_type == phi::AllocationType::NPU) {
return Backend::NPU;
} else if (allocation_type == phi::AllocationType::CUSTOM) {
return static_cast<Backend>(
static_cast<size_t>(Backend::NUM_BACKENDS) +
......
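A hedged sketch of what the new branch enables: an NPU place now round-trips to its own backend instead of falling through to the default path (TransToPhiBackend as declared in this hunk; exact header path assumed):

#include "paddle/phi/common/place.h"

phi::Backend BackendForNPU() {
  phi::Place npu_place(phi::AllocationType::NPU);
  return phi::TransToPhiBackend(npu_place);  // now Backend::NPU
}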
......@@ -27,7 +27,7 @@ kernel_library(full_kernel DEPS ${COMMON_KERNEL_DEPS} empty_kernel)
# Some kernels depend on some targets that are not commonly used.
# These targets are not suitable for common dependencies.
# In this case, you need to manually generate them here.
set(MANUAL_BUILD_KERNELS adam_kernel adamw_kernel deformable_conv_kernel deformable_conv_grad_kernel eigh_kernel
set(MANUAL_BUILD_KERNELS cross_entropy_kernel adam_kernel adamw_kernel deformable_conv_kernel deformable_conv_grad_kernel eigh_kernel
gumbel_softmax_kernel gumbel_softmax_grad_kernel hierarchical_sigmoid_kernel hierarchical_sigmoid_grad_kernel
matrix_power_kernel matrix_power_grad_kernel maxout_kernel maxout_grad_kernel pool_kernel
put_along_axis_kernel put_along_axis_grad_kernel segment_pool_kernel segment_pool_grad_kernel
......@@ -35,8 +35,10 @@ set(MANUAL_BUILD_KERNELS adam_kernel adamw_kernel deformable_conv_kernel deforma
triangular_solve_grad_kernel determinant_grad_kernel reduce_kernel rnn_kernel rnn_grad_kernel warpctc_kernel warpctc_grad_kernel)
kernel_library(adam_kernel DEPS gflags glog flags ${COMMON_KERNEL_DEPS} selected_rows_functor threadpool jit_kernel_helper)
kernel_library(adamw_kernel DEPS ${COMMON_KERNEL_DEPS} adam_kernel)
kernel_library(cross_entropy_kernel DEPS ${COMMON_KERNEL_DEPS} softmax cross_entropy)
kernel_library(deformable_conv_kernel DEPS ${COMMON_KERNEL_DEPS} deformable_conv_functor)
kernel_library(deformable_conv_grad_kernel DEPS ${COMMON_KERNEL_DEPS} deformable_conv_functor)
kernel_library(determinant_grad_kernel DEPS ${COMMON_KERNEL_DEPS} matrix_inverse)
kernel_library(eigh_kernel DEPS ${COMMON_KERNEL_DEPS} lapack_function)
kernel_library(hierarchical_sigmoid_kernel DEPS ${COMMON_KERNEL_DEPS} matrix_bit_code)
kernel_library(hierarchical_sigmoid_grad_kernel DEPS ${COMMON_KERNEL_DEPS} matrix_bit_code)
......@@ -57,7 +59,6 @@ kernel_library(softmax_grad_kernel DEPS ${COMMON_KERNEL_DEPS} softmax)
kernel_library(take_along_axis_kernel DEPS ${COMMON_KERNEL_DEPS} gather_scatter_kernel)
kernel_library(take_along_axis_grad_kernel DEPS ${COMMON_KERNEL_DEPS} gather_scatter_kernel)
kernel_library(triangular_solve_grad_kernel DEPS ${COMMON_KERNEL_DEPS} matrix_reduce)
kernel_library(determinant_grad_kernel DEPS ${COMMON_KERNEL_DEPS} matrix_inverse)
kernel_library(rnn_kernel DEPS ${COMMON_KERNEL_DEPS} concat_and_split_functor lstm_compute gru_compute)
kernel_library(rnn_grad_kernel DEPS ${COMMON_KERNEL_DEPS} concat_and_split_functor lstm_compute gru_compute)
kernel_library(warpctc_kernel DEPS ${COMMON_KERNEL_DEPS} phi_dynload_warpctc sequence_padding sequence_scale)
......
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/phi/kernels/cross_entropy_grad_kernel.h"
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/copy_kernel.h"
#include "paddle/phi/kernels/funcs/axis_utils.h"
#include "paddle/phi/kernels/funcs/eigen/common.h"
// TODO(chenweihang): move dispatch.h into phi/core
#include "paddle/phi/api/ext/dispatch.h"
namespace phi {
template <typename T, typename LabelT>
void CrossEntropyWithSoftmaxGradCPUKernel(const CPUContext& dev_ctx,
const DenseTensor& label,
const DenseTensor& softmax,
const DenseTensor& loss_grad,
bool soft_label,
bool use_softmax,
bool numeric_stable_mode,
int ignore_index,
int axis,
DenseTensor* logits_grad) {
const DenseTensor* out_grad = &loss_grad;
DenseTensor* logit_grad = logits_grad;
if (logit_grad != &softmax || !use_softmax) {
phi::Copy(dev_ctx, softmax, dev_ctx.GetPlace(), false, logit_grad);
}
const int rank = logit_grad->dims().size();
const int axis_v = phi::funcs::CanonicalAxis(axis, rank);
int axis_dim = logit_grad->dims()[axis_v];
PADDLE_ENFORCE_GT(
axis_dim,
0,
phi::errors::InvalidArgument(
"The axis dimention should be larger than 0, but received "
"axis dimention is %d.",
axis_dim));
const int n = phi::funcs::SizeToAxis(axis_v, logit_grad->dims());
PADDLE_ENFORCE_GT(
n,
0,
phi::errors::InvalidArgument(
"The size of axis should be larger than 0, but received "
"SizeToAxis of logit_grad is %d.",
n));
const int d = phi::funcs::SizeFromAxis(axis_v, logit_grad->dims());
DenseTensor logit_grad_2d(*logit_grad);
logit_grad_2d.Resize({n, d});
DenseTensor labels_2d(label);
labels_2d.Resize({n, label.numel() / n});
DenseTensor out_grad_2d(*out_grad);
out_grad_2d.Resize({n, d / axis_dim});
auto out_grad_mat = EigenMatrix<T>::From(out_grad_2d);
auto logit_grad_mat = EigenMatrix<T>::From(logit_grad_2d);
auto& place = *dev_ctx.eigen_device();
if (!use_softmax) {
// use_softmax == false, step 1: soft labels
if (soft_label) {
auto lbl_mat = EigenMatrix<T>::From(labels_2d);
logit_grad_mat.device(place) =
(-lbl_mat / logit_grad_mat); // for each sample, i is the sample id
logit_grad_mat.device(place) =
out_grad_mat.broadcast(Eigen::DSizes<int, 2>(1, axis_dim)) *
logit_grad_mat;
} else {
// use_softmax == false, step 2: hard labels
const auto* label_data = label.data<LabelT>();
T* logit_grad_data = logit_grad->data<T>();
const T* out_grad_data = out_grad->data<T>();
const int remain = d / axis_dim;
for (int i = 0; i < n; ++i) { // for each sample_1_dim
for (int j = 0; j < remain; j++) { // for each sample_other_dims
int idx = i * remain + j; // this sample's label_idx. for 1d case,
// remain=1 and j=0, so, idx = i
auto lbl = static_cast<int64_t>(label_data[idx]);
if (lbl == ignore_index) {
for (int k = 0; k < axis_dim; ++k) { // for each class id's label
logit_grad_data[i * d + k * remain + j] = 0;
}
} else {
// only at this sample's label_idx is the label 1; the others are 0,
// so, only compute this label_idx's class
logit_grad_data[i * d + lbl * remain + j] =
(-1 / logit_grad_data[i * d + lbl * remain + j]) *
out_grad_data[idx];
for (int k = 0; k < axis_dim; ++k) { // for each class id's label
if (k !=
label_data[idx]) { // label_data[idx]: this sample's label
logit_grad_data[i * d + k * remain + j] = 0;
}
}
}
}
}
}
return;
}
// for use_softmax=False, continue
if (soft_label) {
// when soft_label = True, ignore_index is not supported
auto lbl_mat = EigenMatrix<T>::From(labels_2d);
logit_grad_mat.device(place) =
out_grad_mat.broadcast(Eigen::DSizes<int, 2>(1, axis_dim)) *
(logit_grad_mat - lbl_mat);
// for each sample, i is sample id
// 1) compute dy/dx by p_j - y_j or P-Y, where j is class id,
// P=logit_grad_mat[i] is all class's probs, Y=lbl_mat[i] is
// all class's label
// 2) compute dy * dy/dx by Chain rule, dy=out_grad_mat[i]
// for high dims, e.g. (n,c) or (n,d1,...,dm, c), compute grad by matrix
// operation
} else {
logit_grad_mat.device(place) =
logit_grad_mat * // element_wise multiply
out_grad_mat.broadcast(Eigen::DSizes<int, 2>(1, axis_dim));
const auto* label_data = label.data<LabelT>();
T* logit_grad_data = logit_grad->data<T>();
const T* out_grad_data = out_grad->data<T>();
const int remain = d / axis_dim;
for (int i = 0; i < n; ++i) { // for each sample_1_dim
for (int j = 0; j < remain; j++) { // for each sample_other_dims
int idx = i * remain + j; // this sample's label_idx. for 1d case,
// remain=1 and j=0, so, idx = i
auto lbl = static_cast<int64_t>(label_data[idx]);
if (lbl == ignore_index) {
for (int k = 0; k < axis_dim; ++k) { // for each class id's label
logit_grad_data[i * d + k * remain + j] = 0;
}
} else {
// only at this sample's label_idx is the label 1; the others are 0,
// so, only compute this label_idx's class
// for 1d case, remain=1 and j=0, so, [i * d + label_data[idx] *
// remain + j] = [i * d + label_data[idx]]
// let idx_x = i * d + label_data[idx] * remain + j,
// logit_grad_data[idx_x] = logit_grad_data[idx_x] -
// out_grad_data[idx]
// note: logit_grad_mat = logit_grad_mat * out_grad_mat
// so: logit_grad_data[idx_x] = (logit_grad_data[idx_x] - 1) *
// out_grad_data[idx]
// means: dy/dp * dy= ( p - y ) * dy
logit_grad_data[i * d + lbl * remain + j] -= out_grad_data[idx];
}
}
}
}
}
template <typename T, typename Context>
void CrossEntropyWithSoftmaxGradKernel(const Context& dev_ctx,
const DenseTensor& label,
const DenseTensor& softmax,
const DenseTensor& loss_grad,
bool soft_label,
bool use_softmax,
bool numeric_stable_mode,
int ignore_index,
int axis,
DenseTensor* logits_grad) {
auto dtype = label.dtype();
if (soft_label) {
PADDLE_ENFORCE_EQ(
dtype,
paddle::experimental::CppTypeToDataType<T>::Type(),
phi::errors::InvalidArgument("The Input(Label) should be with the "
"same data type as kernel data type."));
CrossEntropyWithSoftmaxGradCPUKernel<T, T>(dev_ctx,
label,
softmax,
loss_grad,
soft_label,
use_softmax,
numeric_stable_mode,
ignore_index,
axis,
logits_grad);
} else {
PD_DISPATCH_INTEGRAL_TYPES(
dtype, "CrossEntropyWithSoftmaxGradCPUKernel", ([&] {
CrossEntropyWithSoftmaxGradCPUKernel<T, data_t>(dev_ctx,
label,
softmax,
loss_grad,
soft_label,
use_softmax,
numeric_stable_mode,
ignore_index,
axis,
logits_grad);
}));
}
}
} // namespace phi
PD_REGISTER_KERNEL(cross_entropy_with_softmax_grad,
CPU,
ALL_LAYOUT,
phi::CrossEntropyWithSoftmaxGradKernel,
float,
double) {}
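The registration above pairs with the PD_DISPATCH_INTEGRAL_TYPES idiom used in the kernel: the macro instantiates the lambda once per integral label type, binding data_t to the concrete element type. A stripped-down sketch of the call shape (label is a phi::DenseTensor; the body is elided):

PD_DISPATCH_INTEGRAL_TYPES(
    label.dtype(), "CrossEntropyWithSoftmaxGradCPUKernel", ([&] {
      // data_t is int8_t/int16_t/int32_t/int64_t per instantiation.
      const auto* label_data = label.data<data_t>();
      (void)label_data;  // real code walks the labels per sample here
    }));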
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/phi/kernels/cross_entropy_kernel.h"
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/copy_kernel.h"
#include "paddle/phi/kernels/funcs/axis_utils.h"
#include "paddle/phi/kernels/funcs/math_function.h"
#include "paddle/phi/kernels/softmax_kernel.h"
#include "paddle/fluid/operators/math/cross_entropy.h"
namespace phi {
template <typename T>
void CrossEntropy(const CPUContext& dev_ctx,
const DenseTensor& x,
const DenseTensor& label,
bool soft_label,
int ignore_index,
int axis,
DenseTensor* out) {
const int rank = x.dims().size();
const int axis_v = phi::funcs::CanonicalAxis(axis, rank);
int axis_dim = x.dims()[axis_v];
PADDLE_ENFORCE_GT(
axis_dim,
0,
phi::errors::InvalidArgument(
"The axis dimention should be larger than 0, but received "
"axis dimention is %d.",
axis_dim));
dev_ctx.template Alloc<T>(out);
const int n = phi::funcs::SizeToAxis(axis_v, x.dims());
PADDLE_ENFORCE_GT(
n,
0,
phi::errors::InvalidArgument(
"The size of axis should be larger than 0, but received "
"SizeToAxis of softmax is %d.",
n));
const int d = phi::funcs::SizeFromAxis(axis_v, x.dims());
DenseTensor x_2d(x);
x_2d.Resize({n, d});
DenseTensor label_2d(label);
label_2d.Resize({n, label.numel() / n});
DenseTensor out_2d(*out);
out_2d.Resize({n, d / axis_dim});
paddle::operators::math::CrossEntropyFunctor<CPUContext, T>()(
dev_ctx, &out_2d, &x_2d, &label_2d, soft_label, ignore_index, axis_dim);
}
template <typename T, typename Context>
void CrossEntropyWithSoftmaxKernel(const Context& dev_ctx,
const DenseTensor& logits,
const DenseTensor& label,
bool soft_label,
bool use_softmax,
bool numeric_stable_mode,
int ignore_index,
int axis,
DenseTensor* softmax,
DenseTensor* loss) {
// not used with the softmax op; the input is already the softmax
if (!use_softmax) {
CrossEntropy<T>(
dev_ctx, logits, label, soft_label, ignore_index, axis, loss);
// since the input is already the softmax, copy it to the output directly
phi::Copy<Context>(dev_ctx, logits, dev_ctx.GetPlace(), false, softmax);
return;
}
phi::SoftmaxKernel<T, Context>(dev_ctx, logits, axis, softmax);
CrossEntropy<T>(
dev_ctx, *softmax, label, soft_label, ignore_index, axis, loss);
}
} // namespace phi
PD_REGISTER_KERNEL(cross_entropy_with_softmax,
CPU,
ALL_LAYOUT,
phi::CrossEntropyWithSoftmaxKernel,
float,
double) {}
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/phi/core/dense_tensor.h"
namespace phi {
template <typename T, typename Context>
void CrossEntropyWithSoftmaxGradKernel(const Context& dev_ctx,
const DenseTensor& label,
const DenseTensor& softmax,
const DenseTensor& loss_grad,
bool soft_label,
bool use_softmax,
bool numeric_stable_mode,
int ignore_index,
int axis,
DenseTensor* logits_grad);
} // namespace phi
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/phi/core/dense_tensor.h"
namespace phi {
// This kernel is a by-product of the operator's iterative upgrades; no strict
// 2.0 API corresponds to it. In the 2.0 API
// paddle.nn.functional.cross_entropy, use_softmax has become an optional
// argument, so CrossEntropyWithSoftmax may be the more accurate name. Here we
// keep the kernel arguments the same as the original OpMaker, and if a
// CrossEntropyKernel like paddle.nn.functional.cross_entropy is needed, this
// kernel can be reused.
template <typename T, typename Context>
void CrossEntropyWithSoftmaxKernel(const Context& dev_ctx,
const DenseTensor& logits,
const DenseTensor& label,
bool soft_label,
bool use_softmax,
bool numeric_stable_mode,
int ignore_index,
int axis,
DenseTensor* softmax,
DenseTensor* loss);
} // namespace phi
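Given the declaration above, a hedged sketch of a direct invocation (the signature is copied from this header; tensor construction and filling are elided):

void RunForward(const phi::CPUContext& dev_ctx,
                const phi::DenseTensor& logits,
                const phi::DenseTensor& label,
                phi::DenseTensor* softmax,
                phi::DenseTensor* loss) {
  phi::CrossEntropyWithSoftmaxKernel<float, phi::CPUContext>(
      dev_ctx, logits, label,
      /*soft_label=*/false, /*use_softmax=*/true,
      /*numeric_stable_mode=*/true, /*ignore_index=*/-1, /*axis=*/-1,
      softmax, loss);
}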
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/phi/kernels/cross_entropy_grad_kernel.h"
#ifdef __NVCC__
#include "cub/cub.cuh"
#endif
#ifdef __HIPCC__
#include <hipcub/hipcub.hpp>
namespace cub = hipcub;
#endif
#include "paddle/phi/common/amp_type_traits.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/copy_kernel.h"
#include "paddle/phi/kernels/funcs/axis_utils.h"
#include "paddle/phi/kernels/funcs/for_range.h"
#include "paddle/phi/kernels/funcs/math_function.h"
#include "paddle/phi/kernels/gpudnn/softmax_gpudnn.h"
// TODO(chenweihang): move dispatch.h into phi/core
#include "paddle/phi/api/ext/dispatch.h"
#include "paddle/fluid/operators/math/cross_entropy.h"
#include "paddle/fluid/operators/math/softmax.h"
#include "paddle/fluid/platform/device/gpu/gpu_device_function.h"
#include "paddle/fluid/platform/device/gpu/gpu_dnn.h"
namespace phi {
template <typename T>
__global__ void SoftLabelCrossEntropyGradientKernel(T* logit_grad,
const T* loss_grad,
const T* labels,
const int n,
const int d,
const int remain) {
int ids = blockIdx.x * blockDim.x + threadIdx.x;
if (ids < n * d) {
int idx_n = ids / d;
int idx_remain = ids % remain;
int idx_loss = idx_n * remain + idx_remain;
logit_grad[ids] = loss_grad[idx_loss] * (-labels[ids] / logit_grad[ids]);
}
}
template <typename T, typename LabelT>
__global__ void HardLabelCrossEntropyGradientKernel(T* logit_grad,
const LabelT* labels,
const int n,
const int d,
const int remain,
const int ignore_index) {
CUDA_KERNEL_LOOP(index, n * remain) {
int idx_n = index / remain;
int idx_remain = index % remain;
int tmp = static_cast<int>(labels[index]);
int idx = idx_n * d + tmp * remain + idx_remain;
if (ignore_index != tmp) {
logit_grad[idx] = -static_cast<T>(1.) / logit_grad[idx];
}
}
}
template <typename T, typename LabelT>
__global__ void ScaleCrossEntropyGradient(T* logit_grad,
const T* loss_grad,
const int num,
const int d,
const int remain,
const LabelT* labels,
const int ignore_index) {
CUDA_KERNEL_LOOP(index, num) {
int idx_n = index / d;
int idx_remain = index % remain;
int idx_lbl = idx_n * remain + idx_remain;
int k = (index % d) / remain;
auto lbl = static_cast<int64_t>(labels[idx_lbl]);
if (lbl == ignore_index || lbl != k) {
logit_grad[index] = static_cast<T>(0.);
} else {
logit_grad[index] *= loss_grad[idx_lbl];
}
}
}
template <typename T>
__global__ void SoftCrossEntropyGradientKernel(T* logit_grad,
const T* loss_grad,
const T* labels,
const int64_t n,
const int64_t d,
const int64_t remain) {
int64_t ids = blockIdx.x * blockDim.x + threadIdx.x;
if (ids < n * d) {
int64_t idx_n = ids / d;
int64_t idx_remain = ids % remain;
int64_t idx_loss = idx_n * remain + idx_remain;
logit_grad[ids] = loss_grad[idx_loss] * (logit_grad[ids] - labels[ids]);
}
}
/*
Wrapper of the softmax-with-cross-entropy gradient for hard labels.
*/
template <typename T, typename LabelT>
__global__ void SoftmaxWithCrossEntropyGradHardLabel(T* logits_grad,
const T* loss_grad,
const T* softmax,
const LabelT* labels,
const int64_t n,
const int64_t dim,
const int64_t d,
const int ignore_index) {
int64_t idx = blockIdx.x * blockDim.x + threadIdx.x;
int64_t idx_n = idx / (d * dim);
int64_t idx_dim = (idx / d) % dim;
int64_t idx_d = idx % d;
int64_t ids = idx_n * d + idx_d;
if (idx < n * dim * d) {
auto lbl = static_cast<int64_t>(labels[ids]);
if (lbl == ignore_index) {
logits_grad[idx] = static_cast<T>(0.0);
} else if (lbl == idx_dim) {
logits_grad[idx] = (softmax[idx] - static_cast<T>(1.0)) * loss_grad[ids];
} else {
logits_grad[idx] = softmax[idx] * loss_grad[ids];
}
}
}
template <typename T, typename LabelT>
void CrossEntropyWithSoftmaxGradGPUKernel(const GPUContext& dev_ctx,
const DenseTensor& label,
const DenseTensor& softmax,
const DenseTensor& loss_grad,
bool soft_label,
bool use_softmax,
bool numeric_stable_mode,
int ignore_index,
int axis,
DenseTensor* logits_grad) {
PADDLE_ENFORCE_EQ(
dev_ctx.GetPlace().GetType(),
phi::AllocationType::GPU,
phi::errors::Unavailable("softmax_with_cross_entropy operator's "
"CUDA kernel only runs on GPU device."));
const T* loss_grad_data = loss_grad.data<T>();
DenseTensor* logit_grad = logits_grad;
T* logit_grad_data = nullptr;
bool copy_flag = (logit_grad != &softmax && (!use_softmax || soft_label));
if (copy_flag) {
phi::Copy(dev_ctx, softmax, dev_ctx.GetPlace(), false, logit_grad);
logit_grad_data = logit_grad->data<T>();
} else {
logit_grad_data = dev_ctx.template Alloc<T>(logit_grad);
}
const int rank = logit_grad->dims().size();
const int axis_v = phi::funcs::CanonicalAxis(axis, rank);
int axis_dim = logit_grad->dims()[axis_v];
const int64_t n = phi::funcs::SizeToAxis(axis_v, logit_grad->dims());
const int64_t d = phi::funcs::SizeFromAxis(axis_v, logit_grad->dims());
const int64_t remain = d / axis_dim;
#ifdef __HIPCC__
int block = 256;
#else
int block = 512;
#endif
auto stream = dev_ctx.stream();
// not used with the softmax op; the input is already the softmax
if (!use_softmax) {
if (soft_label) {
int grid = (n * d + block - 1) / block;
const T* label_data = label.data<T>();
SoftLabelCrossEntropyGradientKernel<T><<<grid, block, 0, stream>>>(
logit_grad_data, loss_grad_data, label_data, n, d, remain);
} else {
DenseTensor logits_grad_2d(*logit_grad);
logits_grad_2d.Resize({n, d});
int grid = (n * remain + block - 1) / block;
const auto* label_data = label.data<LabelT>();
HardLabelCrossEntropyGradientKernel<T,
LabelT><<<grid, block, 0, stream>>>(
logit_grad_data, label_data, n, d, remain, ignore_index);
int num = n * d;
grid = (num + block - 1) / block;
ScaleCrossEntropyGradient<T, LabelT><<<grid, block, 0, stream>>>(
logit_grad_data,
loss_grad_data,
num,
d,
remain,
label_data,
ignore_index);
}
return;
}
// with softmax, continue
if (soft_label) {
int64_t grid = (n * d + block - 1) / block;
const T* label_data = label.data<T>();
SoftCrossEntropyGradientKernel<T><<<grid, block, 0, stream>>>(
logit_grad_data, loss_grad_data, label_data, n, d, remain);
} else {
const T* softmax_data = softmax.data<T>();
const auto* label_data = label.data<LabelT>();
int grid = (n * d + block - 1) / block;
SoftmaxWithCrossEntropyGradHardLabel<T><<<grid, block, 0, stream>>>(
logit_grad_data,
loss_grad_data,
softmax_data,
label_data,
n,
d / remain,
remain,
ignore_index);
}
}
template <typename T, typename Context>
void CrossEntropyWithSoftmaxGradKernel(const Context& dev_ctx,
const DenseTensor& label,
const DenseTensor& softmax,
const DenseTensor& loss_grad,
bool soft_label,
bool use_softmax,
bool numeric_stable_mode,
int ignore_index,
int axis,
DenseTensor* logits_grad) {
auto dtype = label.dtype();
if (soft_label) {
PADDLE_ENFORCE_EQ(
dtype,
paddle::experimental::CppTypeToDataType<T>::Type(),
phi::errors::InvalidArgument("The Input(Label) should be with the "
"same data type as kernel data type."));
CrossEntropyWithSoftmaxGradGPUKernel<T, T>(dev_ctx,
label,
softmax,
loss_grad,
soft_label,
use_softmax,
numeric_stable_mode,
ignore_index,
axis,
logits_grad);
} else {
PD_DISPATCH_INTEGRAL_TYPES(
dtype, "CrossEntropyWithSoftmaxGradGPUKernel", ([&] {
CrossEntropyWithSoftmaxGradGPUKernel<T, data_t>(dev_ctx,
label,
softmax,
loss_grad,
soft_label,
use_softmax,
numeric_stable_mode,
ignore_index,
axis,
logits_grad);
}));
}
}
} // namespace phi
PD_REGISTER_KERNEL(cross_entropy_with_softmax_grad,
GPU,
ALL_LAYOUT,
phi::CrossEntropyWithSoftmaxGradKernel,
float,
double,
phi::dtype::float16) {}
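Two pieces of arithmetic the GPU kernels above lean on, restated as a hedged standalone sketch: the [n, dim, remain] flattening (with d = dim * remain) and the ceil-division launch sizing.

#include <cstdint>

// Logits are viewed as [n, dim, remain]; labels and loss as [n, remain].
inline int64_t LogitIndex(int64_t i, int64_t k, int64_t j, int64_t d,
                          int64_t remain) {
  return i * d + k * remain + j;  // sample i, class k, inner position j
}

inline int64_t LabelIndex(int64_t i, int64_t j, int64_t remain) {
  return i * remain + j;
}

// Ceil division used for every launch above; e.g. n*d = 1000, block = 512
// gives grid = 2, so all elements are covered.
inline int64_t GridFor(int64_t elements, int64_t block) {
  return (elements + block - 1) / block;
}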
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/phi/kernels/cross_entropy_kernel.h"
#ifdef __NVCC__
#include "cub/cub.cuh"
#endif
......@@ -15,39 +21,43 @@ limitations under the License. */
#include <hipcub/hipcub.hpp>
namespace cub = hipcub;
#endif
#include "paddle/fluid/operators/amp/fp16_type_traits.h"
#include "paddle/phi/common/amp_type_traits.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/copy_kernel.h"
#include "paddle/phi/kernels/funcs/axis_utils.h"
#include "paddle/phi/kernels/funcs/for_range.h"
#include "paddle/phi/kernels/funcs/math_function.h"
#include "paddle/phi/kernels/gpudnn/softmax_gpudnn.h"
// TODO(chenweihang): move dispatch.h into phi/core
#include "paddle/phi/api/ext/dispatch.h"
#include "paddle/fluid/operators/math/cross_entropy.h"
#include "paddle/fluid/operators/softmax_with_cross_entropy_op.h"
#include "paddle/fluid/operators/math/softmax.h"
#include "paddle/fluid/platform/device/gpu/gpu_device_function.h"
#include "paddle/fluid/platform/device/gpu/gpu_dnn.h"
#include "paddle/fluid/platform/for_range.h"
#include "paddle/phi/kernels/funcs/math_function.h"
#include "paddle/phi/kernels/gpudnn/softmax_gpudnn.h"
namespace paddle {
namespace operators {
namespace phi {
#define ALIGN_BYTES 16
using ScopedTensorDescriptor = platform::ScopedTensorDescriptor;
using DataLayout = platform::DataLayout;
using Tensor = framework::Tensor;
namespace kps = phi::kps;
enum class SoftmaxMode { kSoftmax, kLogSoftmax, kCrossEntropy };
// Wrapper of log function. Use log(float32) for float16
template <typename T>
static __device__ __forceinline__ T Log(T x) {
using AccT = typename details::MPTypeTrait<T>::Type;
using AccT = typename dtype::MPTypeTrait<T>::Type;
AccT logx = std::log(static_cast<AccT>(x));
return math::TolerableValue<T>()(static_cast<T>(logx));
return paddle::operators::math::TolerableValue<T>()(static_cast<T>(logx));
}
// Wrapper of exp function. Use exp(float32) for float16
template <typename T>
static __device__ __forceinline__ T Exp(T x) {
using AccT = typename details::MPTypeTrait<T>::Type;
using AccT = typename dtype::MPTypeTrait<T>::Type;
AccT expx = std::exp(static_cast<AccT>(x));
return math::TolerableValue<T>()(static_cast<T>(expx));
return paddle::operators::math::TolerableValue<T>()(static_cast<T>(expx));
}
template <typename Tx, typename Ty = Tx>
......@@ -62,22 +72,114 @@ struct ExpAddFunctor {
Tx max;
};
// log2(value)
static inline int Log2Ceil(int value) {
int log2_value = 0;
while ((1 << log2_value) < value) ++log2_value;
return log2_value;
}
/*
Cross entropy soft label with dynamic size on axis (log2_elements is
variable).
- if the input is softmax, compute loss with softmax
- if the input is log_softmax, compute loss with log_softmax and update
softmax
*/
template <typename T, typename VecT, bool InLogMode = false>
__global__ void CrossEntropySoftLabel(T* loss,
T* softmaxwrt,
const T* softmax,
const T* labels,
const int n,
const int dim,
const int d,
int log2_elements) {
const int kDimCeil = 1 << log2_elements;
const int kVSize = sizeof(VecT) / sizeof(T);
enum class SoftmaxMode { kSoftmax, kLogSoftmax, kCrossEntropy };
#ifdef __HIPCC__
const int kThreadPerBlock = 256;
#else
const int kThreadPerBlock = 512;
#endif
const int kBatchPerBlock = 1;
const int kWarpSize = 32; // (dim < 32) ? dim : 32;
const int kBatchSize = 1;
const int kThreadPerBatch = kThreadPerBlock / kBatchPerBlock;
const int kWarpPerBatch = kThreadPerBatch / kWarpSize;
const int kIterations = (dim + kThreadPerBatch - 1) / kThreadPerBatch;
const int kIterationsV = (kIterations >= kVSize) ? (kIterations / kVSize) : 1;
const int first_batch = (blockDim.y * blockIdx.x + threadIdx.y) * kBatchSize;
T sum[kBatchSize]{static_cast<T>(0.0)};
#pragma unroll
for (int i = 0; i < kBatchSize; ++i) {
int ids = first_batch + i;
if (ids >= n * d) break;
int idx_n = ids / d;
int idx_d = ids % d;
#pragma unroll
for (int it = 0; it < kIterations; ++it) {
int idx_dim = it * kThreadPerBatch + threadIdx.x;
int idx = idx_n * dim * d + idx_dim * d + idx_d;
if (idx_n < n && idx_dim < dim) {
VecT softmaxdata;
if (InLogMode) {
softmaxdata = reinterpret_cast<VecT*>(&softmaxwrt[idx])[0];
} else {
softmaxdata = reinterpret_cast<const VecT*>(&softmax[idx])[0];
}
VecT labelsdata = reinterpret_cast<const VecT*>(&labels[idx])[0];
T* softmaxptr = reinterpret_cast<T*>(&softmaxdata);
T* labelsptr = reinterpret_cast<T*>(&labelsdata);
#pragma unroll
for (int s = 0; s < kVSize; s++) {
if (InLogMode) {
sum[i] -= softmaxptr[s] * labelsptr[s];
softmaxptr[s] = Exp(softmaxptr[s]);
} else {
sum[i] -= Log(softmaxptr[s]) * labelsptr[s];
}
}
if (InLogMode) {
reinterpret_cast<VecT*>(&softmaxwrt[idx])[0] = softmaxdata;
}
}
}
}
phi::WarpReduceSum<T, kBatchSize, kWarpSize>(sum);
__syncthreads();
__shared__ T sumshare[kWarpPerBatch][kBatchPerBlock][kBatchSize];
if (threadIdx.x % kWarpSize == 0) {
#pragma unroll
for (int i = 0; i < kBatchSize; i++) {
sumshare[threadIdx.x / kWarpSize][threadIdx.y][i] = sum[i];
}
}
__syncthreads();
// write
if (threadIdx.x == 0) {
for (int i = 0; i < kBatchSize; i++) {
int ids = first_batch + i;
if (ids < n * d) {
loss[ids] = sumshare[0][threadIdx.y][i];
for (int s = 1; s < kWarpPerBatch; s++) {
loss[ids] += sumshare[s][threadIdx.y][i];
}
}
}
}
}
/*
Hard label cross entropy.
*/
template <typename T, typename LabelT, bool IgnoreIndex>
__global__ void CrossEntropyHardLabel(T* loss, const T* softmax,
const LabelT* labels, const int n,
const int dim, const int d,
__global__ void CrossEntropyHardLabel(T* loss,
const T* softmax,
const LabelT* labels,
const int n,
const int dim,
const int d,
const int ignore_idx) {
int64_t ids = blockIdx.x * blockDim.x + threadIdx.x;
int64_t idx_n = ids / d;
......@@ -111,9 +213,12 @@ __global__ void CrossEntropyHardLabel(T* loss, const T* softmax,
Output: loss and exp(input)
*/
template <typename T, typename LabelT, bool IgnoreIndex>
__global__ void CrossEntropyExpHardLabel(T* loss, T* softmax,
const LabelT* labels, const int n,
const int dim, const int d,
__global__ void CrossEntropyExpHardLabel(T* loss,
T* softmax,
const LabelT* labels,
const int n,
const int dim,
const int d,
const int ignore_idx) {
int64_t idx = blockIdx.x * blockDim.x + threadIdx.x;
int64_t idx_n = idx / (d * dim);
......@@ -146,402 +251,131 @@ __global__ void CrossEntropyExpHardLabel(T* loss, T* softmax,
}
}
/*
Core function of softmax with cross entropy forward
- softmax, SoftmaxMode=kSoftmax
- log softmax, SoftmaxMode=kLogSoftmax
- softmax with cross entropy hard label, SoftmaxMode=kCrossEntropy
The computation includes
- Compute max value: maxvalue_{i} = max_j src_{i,j}
- Compute sum of exp: s_{i} = sum_{j}{e^{src_{i,j} - maxvalue_{i}}}
- Compute: softmax_{i,j} = e^{src_{i,j} - maxvalue_{i}} / s_{i}
- Compute: logsoftmax_{i,j} = src_{i,j} - maxvalue_{i} - log(s_{i})
- Compute: loss_{i} = -logsoftmax[i,label[i]] (Hard label)
This computation results from following formula:
softmax_{i,j} = e^{src_{i,j}} / sum_{j}{e^{src_{i,j}}}
= e^{src_{i,j} - maxvalue_{i}}
/ sum_{j}{e^{src_{i,j} - maxvalue_{i}}}
= e^{src_{i,j} - maxvalue_{i}} / s_{i}
logsoftmax_{i,j} = log(softmax_{i,j})
= src_{i,j} - maxvalue_{i} - log(s_{i})
One warp (32 threads) is used to compute 1 or 2 batches (kBatchSize).
For the max (sum) reduction, each thread first reduces its own elements
within the warp, then the warp-wide max (sum) is computed with shuffle
intrinsics.
*/
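// Added note (worked example, not from the original comment): subtracting the
// row max first is what keeps exp() finite. For x = {1000, 1001}, exp(x_j)
// overflows in float, but with m = 1001, exp(x_j - m) = {exp(-1), exp(0)} =
// {0.3679, 1.0}, so softmax = {0.2689, 0.7311} and log(s) stays finite too.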
template <typename T, typename LabelT, typename VecT, typename AccT,
int Log2Elements, SoftmaxMode mode, bool IgnoreIndex>
__global__ void WarpSoftmaxForward(T* loss, T* softmax, const T* src,
const LabelT* label, const int batch_size,
const int stride, const int element_count,
const int ignore_index) {
constexpr int kDimCeil = 1 << Log2Elements;
constexpr int kWarpSize = (kDimCeil < 32) ? kDimCeil : 32;
constexpr int kVSize = sizeof(VecT) / sizeof(T);
constexpr int kIterations = kDimCeil / kWarpSize;
constexpr int kIterationsV =
(kIterations >= kVSize) ? (kIterations / kVSize) : 1;
constexpr int kBatchSize = (kDimCeil <= 128) ? 2 : 1;
template <typename T, typename AccT, int VecSize, class ReduceFunctor>
__device__ __forceinline__ AccT ThreadReduce(const T* input,
int size,
const int offset,
AccT init,
ReduceFunctor reducer) {
using VecT = kps::details::VectorType<T, VecSize>;
int tid = threadIdx.x;
AccT val = init;
int first_batch = (blockDim.y * blockIdx.x + threadIdx.y) * kBatchSize;
if (offset > 0) {
input -= offset;
size += offset;
if (tid >= offset) {
val = reducer(val, input[tid]);
}
size -= blockDim.x;
input += blockDim.x;
}
int remain = size % (VecSize * blockDim.x);
T ins[VecSize];
VecT* ins_vec = reinterpret_cast<VecT*>(&ins);
// vector part
for (; VecSize * tid < (size - remain); tid += blockDim.x) {
*ins_vec = reinterpret_cast<const VecT*>(input)[tid];
// max index to read
int idx_max_v[kBatchSize];
#pragma unroll
for (int i = 0; i < kBatchSize; i++) {
int idx_max = ((i + first_batch) < batch_size) ? element_count : 0;
idx_max_v[i] = idx_max / kVSize;
for (int i = 0; i < VecSize; ++i) {
val = reducer(val, ins[i]);
}
}
// read data from global memory
AccT srcdata[kBatchSize][kIterationsV][kVSize];
// scalar part
tid = size - remain + threadIdx.x;
for (; tid < size; tid += blockDim.x) {
val = reducer(val, input[tid]);
}
return val;
}
#pragma unroll
for (int i = 0; i < kBatchSize; ++i) {
// read data to srcdata: - KVSize==1, - KVSize>1
#pragma unroll
for (int it = 0; it < kIterationsV; ++it) {
int src_idx = threadIdx.x + it * kWarpSize;
if (kVSize == 1) {
if (src_idx < idx_max_v[i]) {
srcdata[i][it][0] =
static_cast<AccT>(src[(first_batch + i) * stride + src_idx]);
template <typename T, bool IgnoreIndex>
__device__ __forceinline__ void ComputeLoss(T* loss,
const T loss_value,
const int label_id,
const int64_t label_value,
const int tid,
const int vec_size,
const int offset,
const int ignore_index) {
int loss_id = vec_size * tid + offset;
if (IgnoreIndex) {
if (label_value == loss_id) {
if (label_value == ignore_index) {
loss[label_id] = static_cast<T>(0.0f);
} else {
srcdata[i][it][0] = -std::numeric_limits<AccT>::infinity();
loss[label_id] = loss_value;
}
} else {
const VecT* src_v =
reinterpret_cast<const VecT*>(&src[(first_batch + i) * stride]);
if (src_idx < idx_max_v[i]) {
VecT srctmp = src_v[src_idx];
const T* srcinptr = reinterpret_cast<const T*>(&srctmp);
#pragma unroll
for (int s = 0; s < kVSize; s++) {
srcdata[i][it][s] = static_cast<AccT>(srcinptr[s]);
}
} else {
#pragma unroll
for (int s = 0; s < kVSize; s++) {
srcdata[i][it][s] = -std::numeric_limits<AccT>::infinity();
if (label_value == loss_id) {
loss[label_id] = loss_value;
}
}
}
template <typename T,
          typename AccT,
          typename LabelT,
          int VecSize,
          bool IgnoreIndex>
__device__ __forceinline__ void VectorizedSoftmaxForwardImpl(
    T* loss,
    T* softmax,
    const T* logits,
    const LabelT* label,
    int size,
    const int offset,
    const phi::LogSoftmaxForwardFunctor<AccT>& func,
    const int ignore_index) {
  using VecT = kps::details::VectorType<T, VecSize>;
  int tid = threadIdx.x;
  int label_id = blockIdx.x;
  auto label_value = static_cast<int64_t>(label[label_id]);
  const bool label_valid = label_value >= 0 && label_value < size;
  int loss_id_offset = 0;

  if (offset > 0) {
    logits -= offset;
    softmax -= offset;
    size += offset;
    loss_id_offset -= offset;
    if (tid >= offset) {
      AccT log_softmax = func(static_cast<AccT>(logits[tid]));
      softmax[tid] = static_cast<T>(std::exp(log_softmax));
      // loss
      if (label_valid) {
        ComputeLoss<T, IgnoreIndex>(loss,
                                    static_cast<T>(-log_softmax),
                                    label_id,
                                    label_value,
                                    tid,
                                    1,
                                    loss_id_offset,
                                    ignore_index);
      }
    }
    size -= blockDim.x;
    logits += blockDim.x;
    softmax += blockDim.x;
    loss_id_offset += blockDim.x;
  }

  int remain = size % (VecSize * blockDim.x);

  T ins[VecSize];
  T outs[VecSize];
  VecT* ins_vec = reinterpret_cast<VecT*>(&ins);
  VecT* outs_vec = reinterpret_cast<VecT*>(&outs);

  // vector part
  for (; VecSize * tid < (size - remain); tid += blockDim.x) {
    // read
    *ins_vec = reinterpret_cast<const VecT*>(logits)[tid];

#pragma unroll
    // compute
    for (int i = 0; i < VecSize; ++i) {
      AccT log_softmax = func(static_cast<AccT>(ins[i]));
      outs[i] = static_cast<T>(std::exp(log_softmax));

      // loss
      if (label_valid) {
        ComputeLoss<T, IgnoreIndex>(loss,
                                    static_cast<T>(-log_softmax),
                                    label_id,
                                    label_value,
                                    tid,
                                    VecSize,
                                    loss_id_offset + i,
                                    ignore_index);
      }
    }

    // write
    reinterpret_cast<VecT*>(softmax)[tid] = *outs_vec;
  }

  // scalar part
  tid = size - remain + threadIdx.x;
  for (; tid < size; tid += blockDim.x) {
    AccT log_softmax = func(static_cast<AccT>(logits[tid]));
    softmax[tid] = static_cast<T>(std::exp(log_softmax));

    // loss
    if (label_valid) {
      ComputeLoss<T, IgnoreIndex>(loss,
                                  static_cast<T>(-log_softmax),
                                  label_id,
                                  label_value,
                                  tid,
                                  1,
                                  loss_id_offset,
                                  ignore_index);
    }
  }

  // invalid label, write once
  if (!label_valid && threadIdx.x == 0) {
    loss[label_id] = static_cast<T>(0.0f);
  }
}
template <typename T,
          typename AccT,
          typename LabelT,
          int VecSize,
          bool IgnoreIndex>
__device__ __forceinline__ void ScalarSoftmaxForwardImpl(
    T* loss,
    T* softmax,
    const T* logits,
    const LabelT* label,
    const int size,
    const phi::LogSoftmaxForwardFunctor<AccT>& func,
    const int ignore_index) {
int tid = threadIdx.x;
int remain = size % (VecSize * blockDim.x);
int label_id = blockIdx.x;
......@@ -605,295 +457,440 @@ __device__ __forceinline__ void ScalarSoftmaxForwardImpl(
softmax[tid + i * blockDim.x] = static_cast<T>(std::exp(log_softmax));
// loss
if (label_valid) {
ComputeLoss<T, IgnoreIndex>(loss,
static_cast<T>(-log_softmax),
label_id,
label_value,
tid,
VecSize,
i,
ignore_index);
}
}
}
// tail part
for (; tid < size; tid += blockDim.x) {
AccT log_softmax = func(static_cast<AccT>(logits[tid]));
softmax[tid] = static_cast<T>(std::exp(log_softmax));
// loss
if (label_valid) {
ComputeLoss<T, IgnoreIndex>(loss,
static_cast<T>(-log_softmax),
label_id,
label_value,
tid,
1,
0,
ignore_index);
}
}
// invalid label, write once
if (!label_valid && threadIdx.x == 0) {
loss[label_id] = static_cast<T>(0.0f);
}
}
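// VectorizedSoftmaxForwardImpl and ScalarSoftmaxForwardImpl above are two
// bodies for the same computation; which one runs is decided by pointer
// alignment in VectorizedSoftmaxForward below. A host-side sketch of that
// check (illustrative only, not part of this file's API; ALIGN_BYTES is the
// constant already used below):
template <typename T>
bool CanUseVectorizedPath(const T* logits, const T* softmax) {
  const int input_offset = static_cast<int>(
      reinterpret_cast<uint64_t>(logits) % ALIGN_BYTES / sizeof(T));
  const int output_offset = static_cast<int>(
      reinterpret_cast<uint64_t>(softmax) % ALIGN_BYTES / sizeof(T));
  // vector loads/stores are only safe when input and output are misaligned
  // by the same number of elements
  return input_offset == output_offset;
}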
template <typename T,
typename AccT,
typename LabelT,
int VecSize,
bool IgnoreIndex>
__global__ void VectorizedSoftmaxForward(T* loss,
T* softmax,
const T* logits,
const LabelT* label,
const int high_dim,
const int mid_dim,
const int ignore_index) {
using VecT = kps::details::VectorType<T, VecSize>;
// each block deals with one batch
logits += blockIdx.x * mid_dim;
softmax += blockIdx.x * mid_dim;
const int input_offset = ((uint64_t)logits) % ALIGN_BYTES / sizeof(T);
const int output_offset = ((uint64_t)softmax) % ALIGN_BYTES / sizeof(T);
// 1. reduce max
AccT max = ThreadReduce<T, AccT, VecSize, kps::MaxFunctor<AccT>>(
logits,
mid_dim,
input_offset,
-std::numeric_limits<AccT>::infinity(),
kps::MaxFunctor<AccT>());
max = kps::details::BlockXReduce<AccT, kps::MaxFunctor<AccT>>(
max, kps::MaxFunctor<AccT>());
// 2. reduce sum
AccT sum = ThreadReduce<T, AccT, VecSize, ExpAddFunctor<AccT>>(
logits,
mid_dim,
input_offset,
static_cast<AccT>(0),
ExpAddFunctor<AccT>(max));
sum = kps::details::BlockXReduce<AccT, kps::AddFunctor<AccT>>(
sum, kps::AddFunctor<AccT>());
// 3. softmax
phi::LogSoftmaxForwardFunctor<AccT> func(max, sum);
if (input_offset == output_offset) {
VectorizedSoftmaxForwardImpl<T, AccT, LabelT, VecSize, IgnoreIndex>(
loss,
softmax,
logits,
label,
mid_dim,
input_offset,
func,
ignore_index);
} else {
ScalarSoftmaxForwardImpl<T, AccT, LabelT, VecSize, IgnoreIndex>(
loss, softmax, logits, label, mid_dim, func, ignore_index);
}
}
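// Worked example of the two block-wide passes above (illustrative): for one
// row {1, 2, 3}, pass 1 gives max = 3; pass 2 gives
// sum = e^(1-3) + e^(2-3) + e^(3-3) ~= 0.1353 + 0.3679 + 1 = 1.5032, so
// log(sum) ~= 0.4076 and log_softmax = {-2.4076, -1.4076, -0.4076}; the
// hard-label loss for label 2 is then 0.4076.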
/*
Core function of softmax with cross entropy forward soft label.
The computation includes
- Compute maximum of batch: maxvalue_{i} = max_j src_{i,j}
- Compute sum of exp batch: s_{i} = sum_{j}{exp(src_{i,j} - maxvalue_{i})}
- Compute loss: loss_{i} = -sum_{j}{label_{i,j} * (src_{i,j} - maxvalue_{i} -
  log(s_{i}))}
One warp (32 threads) is used to compute 1 or 2 batches (kBatchSize).
For reduction max (sum), firstly compute max (sum) to one warp, then use shuffle
api to compute max (sum) in one warp.
*/
template <typename T, typename VecT, typename AccT, int Log2Elements>
__global__ void WarpSoftmaxForwardSoftLabel(T* loss,
T* softmax,
const T* src,
const T* label,
const int batch_size,
const int stride,
const int element_count) {
const bool LogMode = true;
constexpr int kDimCeil = 1 << Log2Elements;
constexpr int kWarpSize = (kDimCeil < 32) ? kDimCeil : 32;
constexpr int kVSize = sizeof(VecT) / sizeof(T);
constexpr int kIterations = kDimCeil / kWarpSize;
constexpr int kIterationsV =
(kIterations >= kVSize) ? (kIterations / kVSize) : 1;
constexpr int kBatchSize = (kDimCeil <= 128) ? 2 : 1;
int first_batch = (blockDim.y * blockIdx.x + threadIdx.y) * kBatchSize;
int local_batches = batch_size - first_batch;
if (local_batches > kBatchSize) {
local_batches = kBatchSize;
}
// read data from global memory
VecT srcdata[kBatchSize][kIterationsV];
VecT labeldata[kBatchSize][kIterationsV];
for (int i = 0; i < kBatchSize; ++i) {
const VecT* src_v =
reinterpret_cast<const VecT*>(&src[(first_batch + i) * stride]);
const VecT* label_v =
reinterpret_cast<const VecT*>(&label[(first_batch + i) * stride]);
// max index to read
int idx_max = (i < local_batches) ? element_count : 0;
int idx_max_v = idx_max / kVSize;
// read data
for (int it = 0; it < kIterationsV; ++it) {
int src_idx = threadIdx.x + it * kWarpSize;
if (src_idx < idx_max_v) {
srcdata[i][it] = src_v[src_idx];
labeldata[i][it] = label_v[src_idx];
} else {
#pragma unroll
for (int s = 0; s < kVSize; s++) {
reinterpret_cast<T*>(&srcdata[i][it])[s] =
-std::numeric_limits<AccT>::max();
reinterpret_cast<T*>(&labeldata[i][it])[s] = 0.0;
}
}
}
}
// compute max value
AccT max_value[kBatchSize];
#pragma unroll
for (int i = 0; i < kBatchSize; ++i) {
max_value[i] = -std::numeric_limits<AccT>::infinity();
#pragma unroll
for (int it = 0; it < kIterationsV; ++it) {
T* srcptr_v = reinterpret_cast<T*>(&srcdata[i][it]);
T valmax = srcptr_v[0];
#pragma unroll
for (int s = 1; s < kVSize; ++s) {
valmax = (valmax > srcptr_v[s]) ? valmax : srcptr_v[s];
}
max_value[i] = (max_value[i] > static_cast<AccT>(valmax))
? max_value[i]
: static_cast<AccT>(valmax);
}
}
phi::WarpReduceMax<AccT, kBatchSize, kWarpSize>(max_value);
  // compute sum
  AccT sum[kBatchSize]{0.0};
#pragma unroll
  for (int i = 0; i < kBatchSize; ++i) {
#pragma unroll
    for (int it = 0; it < kIterationsV; ++it) {
      T* srcptr_v = reinterpret_cast<T*>(&srcdata[i][it]);
#pragma unroll
      for (int s = 0; s < kVSize; ++s) {
        if (LogMode) {
          sum[i] += std::exp(static_cast<AccT>(srcptr_v[s]) - max_value[i]);
        } else {
          srcptr_v[s] = std::exp(static_cast<AccT>(srcptr_v[s]) - max_value[i]);
          sum[i] += static_cast<AccT>(srcptr_v[s]);
        }
      }
    }
  }
  phi::WarpReduceSum<AccT, kBatchSize, kWarpSize>(sum);

  // log_softmax and loss
  AccT sumloss[kBatchSize]{0.0};
#pragma unroll
  for (int i = 0; i < kBatchSize; ++i) {
    if (i >= local_batches) break;

    VecT* softmax_v =
        reinterpret_cast<VecT*>(&softmax[(first_batch + i) * stride]);

    // max index to write
    int idx_max = (i < local_batches) ? element_count : 0;
    int idx_max_v = idx_max / kVSize;

    if (LogMode) {
      sum[i] = std::log(sum[i]);
    }
#pragma unroll
    for (int it = 0; it < kIterationsV; ++it) {
      T* srcvp = reinterpret_cast<T*>(&srcdata[i][it]);
      T* labelvp = reinterpret_cast<T*>(&labeldata[i][it]);
      VecT tmpv;
      T* tmpvp = reinterpret_cast<T*>(&tmpv);
#pragma unroll
      for (int s = 0; s < kVSize; ++s) {
        if (LogMode) {
          AccT logsoftmax = static_cast<AccT>(srcvp[s]) - max_value[i] - sum[i];
          sumloss[i] -= logsoftmax * static_cast<AccT>(labelvp[s]);
          tmpvp[s] = std::exp(logsoftmax);
        } else {
          tmpvp[s] = static_cast<AccT>(srcvp[s]) / sum[i];
        }
      }

      int idx = threadIdx.x + it * kWarpSize;
      if (idx < idx_max_v) {
        softmax_v[idx] = tmpv;
      }
    }
  }

  // loss
  phi::WarpReduceSum<AccT, kBatchSize, kWarpSize>(sumloss);

  for (int i = 0; i < kBatchSize; i++) {
    if (i >= local_batches) break;
    loss[first_batch + i] = sumloss[i];
  }
}
#define SOFTMAX_WARP_FORWARD_SOFT_CASE(Log2Elements, VecT, AccT) \
case Log2Elements: \
WarpSoftmaxForwardSoftLabel<T, \
VecT, \
AccT, \
Log2Elements><<<blocks, threads, 0, stream>>>( \
loss, softmax, src, label, batch_size, stride, element_count); \
break;
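// Each case above expands to a launch of the warp kernel specialized for
// 2^Log2Elements elements; e.g. SOFTMAX_WARP_FORWARD_SOFT_CASE(5, VecT, AccT)
// expands to (illustrative expansion):
//
//   case 5:
//     WarpSoftmaxForwardSoftLabel<T, VecT, AccT, 5>
//         <<<blocks, threads, 0, stream>>>(
//             loss, softmax, src, label, batch_size, stride, element_count);
//     break;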
/*
  Wrapper of softmax with cross entropy grad hard label.
*/
template <typename T, typename LabelT>
__global__ void SoftmaxWithCrossEntropyGradHardLabel(
    T* logits_grad, const T* loss_grad, const T* softmax, const LabelT* labels,
    const int64_t n, const int64_t dim, const int64_t d,
    const int ignore_index) {
  int64_t idx = blockIdx.x * blockDim.x + threadIdx.x;
  int64_t idx_n = idx / (d * dim);
  int64_t idx_dim = (idx / d) % dim;
  int64_t idx_d = idx % d;
  int64_t ids = idx_n * d + idx_d;

  if (idx < n * dim * d) {
    auto lbl = static_cast<int64_t>(labels[ids]);
    if (lbl == ignore_index) {
      logits_grad[idx] = static_cast<T>(0.0);
    } else if (lbl == idx_dim) {
      logits_grad[idx] = (softmax[idx] - static_cast<T>(1.0)) * loss_grad[ids];
    } else {
      logits_grad[idx] = softmax[idx] * loss_grad[ids];
    }
  }
}

/*
  Wrapper of softmax with cross entropy forward soft label.
*/
template <typename T>
void SwitchWarpSoftmaxForwardSoftLabel(const int blocks,
const dim3 threads,
gpuStream_t stream,
T* loss,
T* softmax,
const T* src,
const T* label,
const int batch_size,
const int stride,
const int element_count,
const int log2_elements) {
using AccT = typename dtype::MPTypeTrait<T>::Type;
switch (log2_elements) {
SOFTMAX_WARP_FORWARD_SOFT_CASE(0, T, AccT);
SOFTMAX_WARP_FORWARD_SOFT_CASE(1, T, AccT);
SOFTMAX_WARP_FORWARD_SOFT_CASE(2, T, AccT);
SOFTMAX_WARP_FORWARD_SOFT_CASE(3, T, AccT);
SOFTMAX_WARP_FORWARD_SOFT_CASE(4, T, AccT);
SOFTMAX_WARP_FORWARD_SOFT_CASE(5, T, AccT);
SOFTMAX_WARP_FORWARD_SOFT_CASE(6, T, AccT);
SOFTMAX_WARP_FORWARD_SOFT_CASE(7, T, AccT);
SOFTMAX_WARP_FORWARD_SOFT_CASE(8, T, AccT);
SOFTMAX_WARP_FORWARD_SOFT_CASE(9, T, AccT);
default:
break;
}
}
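// Dispatch arithmetic for this switch (illustrative): callers guard this
// path with dim <= 320, so Log2Ceil(dim) <= 9 and every size hits a case;
// for dim = 320, kDimLog2 = 9, kDimCeil = 512, kWarpSize = 32,
// batches_per_warp = 1, warps_per_block = 128 / 32 = 4, and each block
// covers 4 batches.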
/*
  Cross entropy soft label with dynamic size on axis (log2_elements is
  variable).
  - if the input is softmax, compute loss with softmax
  - if the input is log_softmax, compute loss with log_softmax and update
    softmax
*/
template <typename T, typename VecT, bool InLogMode = false>
__global__ void CrossEntropySoftLabel(T* loss,
                                      T* softmaxwrt,
                                      const T* softmax,
                                      const T* labels,
                                      const int n,
                                      const int dim,
                                      const int d,
                                      int log2_elements) {
  const int kDimCeil = 1 << log2_elements;
  const int kVSize = sizeof(VecT) / sizeof(T);

#ifdef __HIPCC__
  const int kThreadPerBlock = 256;
#else
  const int kThreadPerBlock = 512;
#endif
  const int kBatchPerBlock = 1;
  const int kWarpSize = 32;  // (dim < 32) ? dim : 32;
  const int kBatchSize = 1;
  const int kThreadPerBatch = kThreadPerBlock / kBatchPerBlock;
  const int kWarpPerBatch = kThreadPerBatch / kWarpSize;

  const int kIterations = (dim + kThreadPerBatch - 1) / kThreadPerBatch;
  const int kIterationsV = (kIterations >= kVSize) ? (kIterations / kVSize) : 1;

  const int first_batch = (blockDim.y * blockIdx.x + threadIdx.y) * kBatchSize;

  T sum[kBatchSize]{static_cast<T>(0.0)};
#pragma unroll
  for (int i = 0; i < kBatchSize; ++i) {
    int ids = first_batch + i;
    if (ids >= n * d) break;
    int idx_n = ids / d;
    int idx_d = ids % d;
#pragma unroll
    for (int it = 0; it < kIterations; ++it) {
      int idx_dim = it * kThreadPerBatch + threadIdx.x;
      int idx = idx_n * dim * d + idx_dim * d + idx_d;

      if (idx_n < n && idx_dim < dim) {
        VecT softmaxdata;
        if (InLogMode) {
          softmaxdata = reinterpret_cast<VecT*>(&softmaxwrt[idx])[0];
        } else {
          softmaxdata = reinterpret_cast<const VecT*>(&softmax[idx])[0];
        }
        VecT labelsdata = reinterpret_cast<const VecT*>(&labels[idx])[0];
        T* softmaxptr = reinterpret_cast<T*>(&softmaxdata);
        T* labelsptr = reinterpret_cast<T*>(&labelsdata);
#pragma unroll
        for (int s = 0; s < kVSize; s++) {
          if (InLogMode) {
            sum[i] -= softmaxptr[s] * labelsptr[s];
            softmaxptr[s] = Exp(softmaxptr[s]);
          } else {
            sum[i] -= Log(softmaxptr[s]) * labelsptr[s];
          }
        }
        if (InLogMode) {
          reinterpret_cast<VecT*>(&softmaxwrt[idx])[0] = softmaxdata;
        }
      }
    }
  }
  phi::WarpReduceSum<T, kBatchSize, kWarpSize>(sum);
  __syncthreads();

  __shared__ T sumshare[kWarpPerBatch][kBatchPerBlock][kBatchSize];
  if (threadIdx.x % kWarpSize == 0) {
#pragma unroll
    for (int i = 0; i < kBatchSize; i++) {
      sumshare[threadIdx.x / kWarpSize][threadIdx.y][i] = sum[i];
    }
  }
  __syncthreads();

  // write
  if (threadIdx.x == 0) {
    for (int i = 0; i < kBatchSize; i++) {
      int ids = first_batch + i;
      if (ids < n * d) {
        loss[ids] = sumshare[0][threadIdx.y][i];
        for (int s = 1; s < kWarpPerBatch; s++) {
          loss[ids] += sumshare[s][threadIdx.y][i];
        }
      }
    }
  }
}

template <typename T>
static void SoftmaxWithCrossEntropySoftLabel(const GPUContext& dev_ctx,
                                             const int rank,
                                             const int axis,
                                             const T* logits_data,
                                             const T* labels_data,
                                             T* softmax_data,
                                             T* loss_data,
                                             int N,
                                             int dim,
                                             int D) {
#ifdef __HIPCC__
  constexpr int kMaxBlockDim = 256;
#else
  constexpr int kMaxBlockDim = 512;
#endif
  int64_t block_dim = dim >= kMaxBlockDim
                          ? kMaxBlockDim
                          : (1 << static_cast<int>(std::log2(dim)));

  int64_t grid_dim = N * D;
  constexpr int max_dim = 320;

  const int kDimLog2 = static_cast<int>(Log2Ceil(dim));
  const int kDimCeil = 1 << kDimLog2;
  auto stream = dev_ctx.stream();

  if (D == 1 && dim <= max_dim) {
    int kWarpSize = (kDimCeil < 32) ? kDimCeil : 32;
    int batches_per_warp = (kDimCeil <= 128) ? 2 : 1;

    // use 128 threads per block to maximize gpu utilization
    constexpr int threads_per_block = 128;
    int warps_per_block = (threads_per_block / kWarpSize);
    int batches_per_block = warps_per_block * batches_per_warp;
    int blocks = (N + batches_per_block - 1) / batches_per_block;
    dim3 threads(kWarpSize, warps_per_block, 1);

    SwitchWarpSoftmaxForwardSoftLabel<T>(blocks,
                                         threads,
                                         stream,
                                         loss_data,
                                         softmax_data,
                                         logits_data,
                                         labels_data,
                                         N,
                                         dim,
                                         dim,
                                         kDimLog2);
  } else {
    ScopedTensorDescriptor desc;
    std::vector<int> tensor_dims = {N, dim, D, 1};
    GPUDNNDataLayout layout = GPUDNNDataLayout::kNCHW;
#ifdef PADDLE_WITH_HIP
    miopenTensorDescriptor_t descp = desc.descriptor<T>(layout, tensor_dims);
#else
    cudnnTensorDescriptor_t descp = desc.descriptor<T>(layout, tensor_dims);
#endif

    auto handle = dev_ctx.cudnn_handle();

#ifdef PADDLE_WITH_HIP
    auto mode = axis == rank - 1 ? MIOPEN_SOFTMAX_MODE_INSTANCE
                                 : MIOPEN_SOFTMAX_MODE_CHANNEL;
    PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::miopenSoftmaxForward_V2(
        handle,
        paddle::platform::CudnnDataType<T>::kOne(),
        descp,
        logits_data,
        paddle::platform::CudnnDataType<T>::kZero(),
        descp,
        softmax_data,
        MIOPEN_SOFTMAX_LOG,
        mode));
#else
    auto mode = axis == rank - 1 ? CUDNN_SOFTMAX_MODE_INSTANCE
                                 : CUDNN_SOFTMAX_MODE_CHANNEL;
    PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cudnnSoftmaxForward(
        handle,
        CUDNN_SOFTMAX_LOG,
        mode,
        paddle::platform::CudnnDataType<T>::kOne(),
        descp,
        logits_data,
        paddle::platform::CudnnDataType<T>::kZero(),
        descp,
        softmax_data));
#endif

    const int kDimLog2 = static_cast<int>(Log2Ceil(dim));
    const int kDimCeil = 1 << kDimLog2;
#ifdef __HIPCC__
    int kThreadPerBlock = 256;
#else
    int kThreadPerBlock = 512;
#endif

    int kBatchPerBlock = 1;
    int blocks = (N * D + kBatchPerBlock - 1) / kBatchPerBlock;
    dim3 threads(kThreadPerBlock / kBatchPerBlock, kBatchPerBlock, 1);

    CrossEntropySoftLabel<T, T, true><<<blocks, threads, 0, stream>>>(
        loss_data, softmax_data, NULL, labels_data, N, dim, D, kDimLog2);
  }
}
/*
  Core function of softmax with cross entropy forward
    - softmax, SoftmaxMode=kSoftmax
    - log softmax, SoftmaxMode=kLogSoftmax
    - softmax with cross entropy hard label, SoftmaxMode=kCrossEntropy
  The computation includes
    - Compute max value: maxvalue_{i} = max_j src_{i,j}
    - Compute sum of exp: s_{i} = sum_{j}{e^{src_{i,j} - maxvalue_{i}}}
    - Compute: softmax_{i,j} = e^{src_{i,j} - maxvalue_{i}} / s_{i}
    - Compute: logsoftmax_{i,j} = src_{i,j} - maxvalue_{i} - log(s_{i})
    - Compute: loss_{i} = -logsoftmax[i, label[i]] (Hard label)
  This computation results from the following formula:
    softmax_{i,j} = e^{src_{i,j}} / sum_{j}{e^{src_{i,j}}}
                  = e^{src_{i,j} - maxvalue_{i}}
                    / sum_{j}{e^{src_{i,j} - maxvalue_{i}}}
                  = e^{src_{i,j} - maxvalue_{i}} / s_{i}
    logsoftmax_{i,j} = log(softmax_{i,j})
                     = src_{i,j} - maxvalue_{i} - log(s_{i})
  One warp (32 threads) is used to compute 1 or 2 batches (kBatchSize).
  For reduction max (sum), firstly compute max (sum) within each lane, then
  use shuffle api to combine the max (sum) across one warp.
*/
template <typename T,
          typename LabelT,
          typename VecT,
          typename AccT,
          int Log2Elements,
          SoftmaxMode mode,
          bool IgnoreIndex>
__global__ void WarpSoftmaxForward(T* loss,
                                   T* softmax,
                                   const T* src,
                                   const LabelT* label,
                                   const int batch_size,
                                   const int stride,
                                   const int element_count,
                                   const int ignore_index) {
constexpr int kDimCeil = 1 << Log2Elements;
constexpr int kWarpSize = (kDimCeil < 32) ? kDimCeil : 32;
constexpr int kVSize = sizeof(VecT) / sizeof(T);
......@@ -903,375 +900,462 @@ __global__ void WarpSoftmaxForwardSoftLabel(T* loss, T* softmax, const T* src,
constexpr int kBatchSize = (kDimCeil <= 128) ? 2 : 1;
int first_batch = (blockDim.y * blockIdx.x + threadIdx.y) * kBatchSize;
// max index to read
int idx_max_v[kBatchSize];
#pragma unroll
for (int i = 0; i < kBatchSize; i++) {
int idx_max = ((i + first_batch) < batch_size) ? element_count : 0;
idx_max_v[i] = idx_max / kVSize;
}
// read data from global memory
AccT srcdata[kBatchSize][kIterationsV][kVSize];
#pragma unroll
for (int i = 0; i < kBatchSize; ++i) {
// read data to srcdata: - KVSize==1, - KVSize>1
#pragma unroll
for (int it = 0; it < kIterationsV; ++it) {
int src_idx = threadIdx.x + it * kWarpSize;
if (kVSize == 1) {
if (src_idx < idx_max_v[i]) {
srcdata[i][it][0] =
static_cast<AccT>(src[(first_batch + i) * stride + src_idx]);
} else {
srcdata[i][it][0] = -std::numeric_limits<AccT>::infinity();
}
} else {
const VecT* src_v =
reinterpret_cast<const VecT*>(&src[(first_batch + i) * stride]);
if (src_idx < idx_max_v[i]) {
VecT srctmp = src_v[src_idx];
const T* srcinptr = reinterpret_cast<const T*>(&srctmp);
#pragma unroll
for (int s = 0; s < kVSize; s++) {
srcdata[i][it][s] = static_cast<AccT>(srcinptr[s]);
}
} else {
#pragma unroll
for (int s = 0; s < kVSize; s++) {
srcdata[i][it][s] = -std::numeric_limits<AccT>::infinity();
}
}
}
}
}
// compute max value: maxvalue_{i} = max_j src_{i,j}
AccT max_value[kBatchSize];
#pragma unroll
for (int i = 0; i < kBatchSize; ++i) {
// it = 0
AccT valmax = srcdata[i][0][0];
#pragma unroll
for (int s = 1; s < kVSize; ++s) {
valmax = (valmax > srcdata[i][0][s]) ? valmax : srcdata[i][0][s];
}
max_value[i] = valmax;
// it = 1, 2, ...
#pragma unroll
for (int it = 1; it < kIterationsV; ++it) {
AccT valmax = srcdata[i][it][0];
#pragma unroll
for (int s = 1; s < kVSize; ++s) {
valmax = (valmax > srcdata[i][it][s]) ? valmax : srcdata[i][it][s];
}
max_value[i] = (max_value[i] > valmax) ? max_value[i] : valmax;
}
}
phi::WarpReduceMax<AccT, kBatchSize, kWarpSize>(max_value);
// compute sum: s_{i} = sum_{j}{exp(src_{i,j} - maxvalue_{i})}
AccT sum[kBatchSize];
#pragma unroll
for (int i = 0; i < kBatchSize; ++i) {
// it = 0
if (mode == SoftmaxMode::kLogSoftmax ||
mode == SoftmaxMode::kCrossEntropy) {
sum[i] = std::exp(srcdata[i][0][0] - max_value[i]);
} else {
srcdata[i][0][0] = std::exp(srcdata[i][0][0] - max_value[i]);
sum[i] = srcdata[i][0][0];
}
#pragma unroll
for (int s = 1; s < kVSize; ++s) {
if (mode == SoftmaxMode::kLogSoftmax ||
mode == SoftmaxMode::kCrossEntropy) {
sum[i] += std::exp(srcdata[i][0][s] - max_value[i]);
} else {
srcdata[i][0][s] = std::exp(srcdata[i][0][s] - max_value[i]);
sum[i] += srcdata[i][0][s];
}
}
// it = 1, 2, ...
#pragma unroll
for (int it = 1; it < kIterationsV; ++it) {
#pragma unroll
for (int s = 0; s < kVSize; ++s) {
if (mode == SoftmaxMode::kLogSoftmax ||
mode == SoftmaxMode::kCrossEntropy) {
sum[i] += std::exp(srcdata[i][it][s] - max_value[i]);
} else {
srcdata[i][it][s] = std::exp(srcdata[i][it][s] - max_value[i]);
sum[i] += srcdata[i][it][s];
}
}
}
}
phi::WarpReduceSum<AccT, kBatchSize, kWarpSize>(sum);
// write data
#pragma unroll
for (int i = 0; i < kBatchSize; ++i) {
if (mode == SoftmaxMode::kLogSoftmax ||
mode == SoftmaxMode::kCrossEntropy) {
sum[i] = std::log(sum[i]);
}
#pragma unroll
for (int it = 0; it < kIterationsV; ++it) {
int idx = threadIdx.x + it * kWarpSize;
if (kVSize == 1) { // kVSize==1
if (idx < idx_max_v[i]) {
if (mode == SoftmaxMode::kLogSoftmax) { // log softmax
softmax[(first_batch + i) * stride + idx] =
srcdata[i][it][0] - max_value[i] - sum[i];
// softmax with cross entropy hard label
} else if (mode == SoftmaxMode::kCrossEntropy) {
AccT logsoftmax = srcdata[i][it][0] - max_value[i] - sum[i];
// softmax
softmax[(first_batch + i) * stride + idx] = std::exp(logsoftmax);
// label
int loss_idx = (threadIdx.x + it * kWarpSize) * kVSize;
auto lbl = static_cast<int64_t>(label[first_batch + i]);
if (IgnoreIndex == true) {
// IgnoreIndex is true
if (lbl == loss_idx) {
if (lbl != ignore_index) {
loss[first_batch + i] = -logsoftmax;
} else {
loss[first_batch + i] = static_cast<T>(0.0);
}
}
} else {
// IgnoreIndex is false
if (lbl >= 0 && lbl < element_count) {
if (lbl == loss_idx) {
loss[first_batch + i] = -logsoftmax;
}
} else {
loss[first_batch + i] = static_cast<T>(0.0);
}
}
} else { // softmax
softmax[(first_batch + i) * stride + idx] =
srcdata[i][it][0] / sum[i];
}
} else {
break;
}
} else { // KVSize>1
VecT* softmax_v =
reinterpret_cast<VecT*>(&softmax[(first_batch + i) * stride]);
VecT tmpdata;
T* tmpptr = reinterpret_cast<T*>(&tmpdata);
#pragma unroll
for (int s = 0; s < kVSize; ++s) {
if (mode == SoftmaxMode::kLogSoftmax) { // log softmax
tmpptr[s] = srcdata[i][it][s] - max_value[i] - sum[i];
// softmax with cross entropy hard label
} else if (mode == SoftmaxMode::kCrossEntropy) {
AccT logsoftmax = srcdata[i][it][s] - max_value[i] - sum[i];
// softmax
tmpptr[s] = std::exp(logsoftmax);
// label
int loss_idx = (threadIdx.x + it * kWarpSize) * kVSize + s;
auto lbl = static_cast<int64_t>(label[first_batch + i]);
if (IgnoreIndex == true) {
// IgnoreIndex is true
if (lbl == loss_idx && lbl != ignore_index) {
loss[first_batch + i] = -logsoftmax;
}
} else {
// IgnoreIndex is false
if (lbl >= 0 && lbl < element_count) {
if (lbl == loss_idx) {
loss[first_batch + i] = -logsoftmax;
}
} else {
loss[first_batch + i] = static_cast<T>(0.0);
}
}
} else { // softmax
tmpptr[s] = srcdata[i][it][s] / sum[i];
}
}
if (idx < idx_max_v[i]) {
softmax_v[idx] = tmpdata;
} else {
break;
}
}
}
}
}
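// Thread mapping in WarpSoftmaxForward (illustrative): with kDimCeil = 64,
// kWarpSize = 32 and kBatchSize = 2, each warp owns two rows and each lane
// walks its rows with stride 32 (kIterations = 64 / 32 = 2 passes); the
// per-lane partial max/sum are then combined with the shuffle-based
// WarpReduceMax / WarpReduceSum, so the reduction needs no shared memory.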
#define SOFTMAX_WARP_FORWARD_CASE(Log2Elements, LabelT, VecT, AccT) \
case Log2Elements: \
WarpSoftmaxForward<T, \
LabelT, \
VecT, \
AccT, \
Log2Elements, \
mode, \
IgnoreIndex><<<blocks, threads, 0, stream>>>( \
loss, \
softmax, \
src, \
label, \
batch_size, \
stride, \
element_count, \
ignore_index); \
break;
/*
  Wrapper of softmax with cross entropy forward hard label.
*/
template <typename T, typename LabelT, SoftmaxMode mode, bool IgnoreIndex>
void SwitchWarpSoftmaxForward(T* loss,
                              T* softmax,
                              const T* src,
                              const LabelT* label,
                              const int batch_size,
                              const int stride,
                              const int element_count,
                              const int ignore_index,
                              gpuStream_t stream) {
  using AccT = typename dtype::MPTypeTrait<T>::Type;

  // use 128 threads per block to maximize gpu utilization
  const int log2_elements = static_cast<int>(Log2Ceil(element_count));
  const int kDimCeil = 1 << log2_elements;
  int kWarpSize = (kDimCeil < 32) ? kDimCeil : 32;
  int batches_per_warp = (kDimCeil <= 128) ? 2 : 1;
  constexpr int threads_per_block = 128;
  int warps_per_block = (threads_per_block / kWarpSize);
  int batches_per_block = warps_per_block * batches_per_warp;
  int blocks = (batch_size + batches_per_block - 1) / batches_per_block;
  dim3 threads(kWarpSize, warps_per_block, 1);

  switch (log2_elements) {
    SOFTMAX_WARP_FORWARD_CASE(0, LabelT, T, AccT);
    SOFTMAX_WARP_FORWARD_CASE(1, LabelT, T, AccT);
    SOFTMAX_WARP_FORWARD_CASE(2, LabelT, T, AccT);
    SOFTMAX_WARP_FORWARD_CASE(3, LabelT, T, AccT);
    SOFTMAX_WARP_FORWARD_CASE(4, LabelT, T, AccT);
    SOFTMAX_WARP_FORWARD_CASE(5, LabelT, T, AccT);
    SOFTMAX_WARP_FORWARD_CASE(6, LabelT, T, AccT);
    SOFTMAX_WARP_FORWARD_CASE(7, LabelT, T, AccT);
    SOFTMAX_WARP_FORWARD_CASE(8, LabelT, T, AccT);
    SOFTMAX_WARP_FORWARD_CASE(9, LabelT, T, AccT);
    default:
      break;
  }
}
template <typename T, typename LabelT, bool IgnoreIndex>
void LaunchVectorizedSoftmaxForward(T* loss,
T* softmax,
const T* logits,
const LabelT* label,
const int high_dim,
const int mid_dim,
const int ignore_index,
gpuStream_t stream) {
using AccT = typename dtype::MPTypeTrait<T>::Type;
constexpr int vec_size = sizeof(float4) / sizeof(T);
const int max_num_threads = 1024;
int max_block_size = std::min(mid_dim / vec_size, max_num_threads);
if (vec_size > 1) {
max_block_size /= 2;
}
int block_size = 1;
while (block_size < max_block_size) {
block_size *= 2;
}
block_size = std::max(block_size, kps::details::kWarpSize);
dim3 grids(high_dim);
dim3 blocks(block_size);
VectorizedSoftmaxForward<T,
AccT,
LabelT,
vec_size,
IgnoreIndex><<<grids, blocks, 0, stream>>>(
loss, softmax, logits, label, high_dim, mid_dim, ignore_index);
}
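// Block-size selection above (illustrative): for T = float, vec_size =
// sizeof(float4) / sizeof(float) = 4; with mid_dim = 4096 the cap is
// min(4096 / 4, 1024) = 1024, halved to 512 because vec_size > 1, and the
// doubling loop settles on 512, clamped from below by the warp size.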
/*
Wrapper of softmax with cross entropy hard label.
- SwitchWarpSoftmaxForward for small size when axis == -1
- LaunchVectorizedSoftmaxForward for large size when axis == -1
- cudnn function for axis != -1
*/
template <typename T, typename LabelT, bool IgnoreIndex>
static void SoftmaxWithCrossEntropyHardLabel(const GPUContext& dev_ctx,
int rank,
int axis,
const T* logits_data,
const LabelT* labels_data,
T* loss_data,
T* softmax_data,
int N,
int dim,
int D,
const int ignore_index) {
auto stream = dev_ctx.stream();
constexpr int max_dim = 320;
if (D == 1) {
if (dim <= max_dim) { // small size
const SoftmaxMode mode = SoftmaxMode::kCrossEntropy;
SwitchWarpSoftmaxForward<T, LabelT, mode, IgnoreIndex>(loss_data,
softmax_data,
logits_data,
labels_data,
N,
dim,
dim,
ignore_index,
stream);
} else { // large size
LaunchVectorizedSoftmaxForward<T, LabelT, IgnoreIndex>(loss_data,
softmax_data,
logits_data,
labels_data,
N,
dim,
ignore_index,
stream);
}
} else {
ScopedTensorDescriptor desc;
std::vector<int> tensor_dims = {N, dim, D, 1};
GPUDNNDataLayout layout = GPUDNNDataLayout::kNCHW;
#ifdef PADDLE_WITH_HIP
miopenTensorDescriptor_t descp = desc.descriptor<T>(layout, tensor_dims);
#else
cudnnTensorDescriptor_t descp = desc.descriptor<T>(layout, tensor_dims);
#endif
auto handle = dev_ctx.cudnn_handle();
#ifdef PADDLE_WITH_HIP
auto mode = axis == rank - 1 ? MIOPEN_SOFTMAX_MODE_INSTANCE
: MIOPEN_SOFTMAX_MODE_CHANNEL;
PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::miopenSoftmaxForward_V2(
handle,
paddle::platform::CudnnDataType<T>::kOne(),
descp,
logits_data,
paddle::platform::CudnnDataType<T>::kZero(),
descp,
softmax_data,
MIOPEN_SOFTMAX_LOG,
mode));
#else
auto mode = axis == rank - 1 ? CUDNN_SOFTMAX_MODE_INSTANCE
: CUDNN_SOFTMAX_MODE_CHANNEL;
PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cudnnSoftmaxForward(
handle,
CUDNN_SOFTMAX_LOG,
mode,
paddle::platform::CudnnDataType<T>::kOne(),
descp,
logits_data,
paddle::platform::CudnnDataType<T>::kZero(),
descp,
softmax_data));
#endif
    int threads = 128;
    int blocks = (N * dim * D + threads - 1) / threads;
    // compute cross entropy, input is log softmax
    CrossEntropyExpHardLabel<T,
                             LabelT,
                             IgnoreIndex><<<blocks, threads, 0, stream>>>(
        loss_data, softmax_data, labels_data, N, dim, D, ignore_index);
}
}
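// Dispatch summary for SoftmaxWithCrossEntropyHardLabel (illustrative): when
// the softmax axis is innermost (D == 1), rows of up to 320 elements use the
// warp kernel and longer rows use the vectorized block kernel; otherwise the
// log-softmax is computed by cuDNN/MIOpen and CrossEntropyExpHardLabel turns
// the stored log-probabilities into the loss, recovering softmax via exp.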
template <typename T>
__global__ void SoftCrossEntropyGradientKernel(T* logit_grad,
const T* loss_grad,
const T* labels, const int64_t n,
const int64_t d,
const int64_t remain) {
int64_t ids = blockIdx.x * blockDim.x + threadIdx.x;
if (ids < n * d) {
int64_t idx_n = ids / d;
int64_t idx_remain = ids % remain;
int64_t idx_loss = idx_n * remain + idx_remain;
logit_grad[ids] = loss_grad[idx_loss] * (logit_grad[ids] - labels[ids]);
}
}
template <typename T>
__global__ void SoftLabelCrossEntropyGradientKernel(T* logit_grad,
const T* loss_grad,
const T* labels,
const int n, const int d,
const int remain) {
int ids = blockIdx.x * blockDim.x + threadIdx.x;
if (ids < n * d) {
int idx_n = ids / d;
int idx_remain = ids % remain;
int idx_loss = idx_n * remain + idx_remain;
logit_grad[ids] = loss_grad[idx_loss] * (-labels[ids] / logit_grad[ids]);
}
}
template <typename T, typename LabelT>
__global__ void HardLabelCrossEntropyGradientKernel(T* logit_grad,
const LabelT* labels,
const int n, const int d,
const int remain,
const int ignore_index) {
CUDA_KERNEL_LOOP(index, n * remain) {
int idx_n = index / remain;
int idx_remain = index % remain;
int tmp = static_cast<int>(labels[index]);
int idx = idx_n * d + tmp * remain + idx_remain;
if (ignore_index != tmp) {
logit_grad[idx] = -static_cast<T>(1.) / logit_grad[idx];
}
}
}
template <typename T, typename LabelT>
__global__ void ScaleCrossEntropyGradient(T* logit_grad, const T* loss_grad,
const int num, const int d,
const int remain,
const LabelT* labels,
const int ignore_index) {
CUDA_KERNEL_LOOP(index, num) {
int idx_n = index / d;
int idx_remain = index % remain;
int idx_lbl = idx_n * remain + idx_remain;
int k = (index % d) / remain;
auto lbl = static_cast<int64_t>(labels[idx_lbl]);
if (lbl == ignore_index || lbl != k) {
logit_grad[index] = static_cast<T>(0.);
} else {
logit_grad[index] *= loss_grad[idx_lbl];
}
}
}
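// Gradient formulas implemented above (summary, not new behavior): with the
// fused softmax path, logit_grad is pre-filled with softmax and
// SoftCrossEntropyGradientKernel computes loss_grad * (softmax - label),
// which is d loss / d logit for soft labels; when use_softmax is false the
// forward input is already a probability, so
// SoftLabelCrossEntropyGradientKernel instead applies the chain rule
// directly, loss_grad * (-label / prob), and the hard-label kernels place
// -1 / prob (later scaled by loss_grad) only at the label column, treating
// ignore_index rows as zero.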
void CrossEntropyWithSoftmaxCUDAKernel(const GPUContext& dev_ctx,
const DenseTensor& logits,
const DenseTensor& label,
bool soft_label,
bool use_softmax,
bool numeric_stable_mode,
int ignore_index,
int axis,
DenseTensor* softmax,
DenseTensor* loss) {
PADDLE_ENFORCE_EQ(
dev_ctx.GetPlace().GetType(),
AllocationType::GPU,
phi::errors::Unavailable("softmax_with_cross_entropy operator's "
"CUDA kernel only runs on GPU device."));
  // if use_softmax is false, the input is already the softmax result
if (!use_softmax) {
DenseTensor* softmax_out = softmax;
const DenseTensor* softmax = &logits;
const DenseTensor& labels = label;
const int rank = softmax->dims().size();
const int axis_v = phi::funcs::CanonicalAxis(axis, rank);
const int axis_dim = softmax->dims()[axis_v];
const int n = phi::funcs::SizeToAxis(axis_v, softmax->dims());
const int d = phi::funcs::SizeFromAxis(axis_v, softmax->dims());
auto* softmax_out_data = dev_ctx.template Alloc<T>(softmax_out);
auto* loss_data = dev_ctx.template Alloc<T>(loss);
phi::funcs::SetConstant<GPUContext, T> set_constant;
set_constant(dev_ctx, loss, static_cast<T>(0));
if (axis_dim == 1) {
set_constant(dev_ctx, softmax_out, static_cast<T>(1));
return;
}
DenseTensor softmax_2d(*softmax);
softmax_2d.Resize({n, d});
DenseTensor labels_2d(labels);
labels_2d.Resize({n, labels.numel() / n});
DenseTensor loss_2d(*loss);
loss_2d.Resize({n, 1});
DenseTensor softmax_out_2d(*softmax_out);
softmax_out_2d.Resize({n, d});
  // math::CrossEntropyFunctor supports only the case where axis is the last
if (axis_v == -1) {
paddle::operators::math::CrossEntropyFunctor<GPUContext, T>()(
dev_ctx,
&loss_2d,
&softmax_2d,
&labels_2d,
soft_label,
ignore_index,
axis_dim);
return;
}
  // if axis is not the last dim, we need a new implementation
if (soft_label) {
auto* logits_data = softmax->data<T>();
auto* labels_data = labels.data<T>();
const int kDimLog2 = static_cast<int>(Log2Ceil(axis_dim));
const int kDimCeil = 1 << kDimLog2;
......@@ -1284,210 +1368,198 @@ class SoftmaxWithCrossEntropyCUDAKernel : public framework::OpKernel<T> {
int blocks = (n * d + kBatchPerBlock - 1) / kBatchPerBlock;
dim3 threads(kThreadPerBlock / kBatchPerBlock, kBatchPerBlock, 1);
CrossEntropySoftLabel<T,
T,
false><<<blocks, threads, 0, dev_ctx.stream()>>>(
loss_data,
NULL,
logits_data,
labels_data,
n,
axis_dim,
d / axis_dim,
kDimLog2);
} else { // HardLabel
auto* logits_data = softmax->data<T>();
auto* labels_data = labels.data<LabelT>();
int threads = 128;
int blocks = (n * d / axis_dim + threads - 1) / threads;
if (ignore_index >= 0 && ignore_index < axis_dim) {
CrossEntropyHardLabel<T,
LabelT,
true><<<blocks, threads, 0, dev_ctx.stream()>>>(
loss_data,
logits_data,
labels_data,
n,
axis_dim,
d / axis_dim,
ignore_index);
} else {
CrossEntropyHardLabel<T,
LabelT,
false><<<blocks, threads, 0, dev_ctx.stream()>>>(
loss_data,
logits_data,
labels_data,
n,
axis_dim,
d / axis_dim,
ignore_index);
}
}
// since the input is already the softmax result,
// copy it to the output softmax directly
framework::TensorCopy(*softmax, context.GetPlace(),
context.device_context(), softmax_out);
phi::Copy<GPUContext>(
dev_ctx, *softmax, dev_ctx.GetPlace(), false, softmax_out);
return;
}
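// For reference, the hard-label kernels above compute, per (row, remain)
// pair, loss = -log(softmax[row, label, r]), and 0 at ignore_index. A
// minimal sketch, assuming float/double data and a hypothetical name:
//
//   template <typename T, typename LabelT>
//   __global__ void HardLabelXentSketch(T* loss, const T* softmax,
//                                       const LabelT* labels, int64_t n,
//                                       int64_t dim, int64_t remain,
//                                       int ignore_index) {
//     int64_t ids = blockIdx.x * (int64_t)blockDim.x + threadIdx.x;
//     if (ids >= n * remain) return;
//     int64_t row = ids / remain, r = ids % remain;
//     auto lbl = static_cast<int64_t>(labels[ids]);
//     if (lbl == ignore_index) {
//       loss[ids] = static_cast<T>(0);
//     } else {
//       loss[ids] = -log(softmax[(row * dim + lbl) * remain + r]);
//     }
//   }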
const Tensor* logits = context.Input<Tensor>("Logits");
Tensor* softmax = context.Output<Tensor>("Softmax");
Tensor* loss = context.Output<Tensor>("Loss");
const int rank = logits.dims().size();
const int axis_v = phi::funcs::CanonicalAxis(axis, rank);
int axis_dim = logits.dims()[axis_v];
const int rank = logits->dims().size();
const int axis = phi::funcs::CanonicalAxis(context.Attr<int>("axis"), rank);
int axis_dim = logits->dims()[axis];
const int64_t n = phi::funcs::SizeToAxis(axis_v, logits.dims());
const int64_t d = phi::funcs::SizeFromAxis(axis_v, logits.dims());
const int64_t n = phi::funcs::SizeToAxis(axis, logits->dims());
const int64_t d = phi::funcs::SizeFromAxis(axis, logits->dims());
auto* softmax_data = softmax->template mutable_data<T>(context.GetPlace());
auto* loss_data = loss->template mutable_data<T>(context.GetPlace());
auto* softmax_data = dev_ctx.template Alloc<T>(softmax);
auto* loss_data = dev_ctx.template Alloc<T>(loss);
if (axis_dim == 1) {
phi::funcs::SetConstant<platform::CUDADeviceContext, T> set_constant;
set_constant(context.cuda_device_context(), softmax, static_cast<T>(1));
set_constant(context.cuda_device_context(), loss, static_cast<T>(0));
phi::funcs::SetConstant<GPUContext, T> set_constant;
set_constant(dev_ctx, softmax, static_cast<T>(1));
set_constant(dev_ctx, loss, static_cast<T>(0));
return;
}
auto ignore_index = context.Attr<int>("ignore_index");
if (soft_label) {
auto* logits_data = logits->template data<T>();
auto* labels_data = labels.template data<T>();
SoftmaxWithCrossEntropySoftLabel<T>(
context.cuda_device_context(), rank, axis, logits_data, labels_data,
softmax_data, loss_data, n, axis_dim, d / axis_dim);
auto* logits_data = logits.data<T>();
auto* labels_data = label.data<T>();
SoftmaxWithCrossEntropySoftLabel<T>(dev_ctx,
rank,
axis_v,
logits_data,
labels_data,
softmax_data,
loss_data,
n,
axis_dim,
d / axis_dim);
} else {
if (!context.Attr<bool>("numeric_stable_mode")) {
if (!numeric_stable_mode) {
// the cuDNN kernel only supports 2-D tensors and performs softmax on the last dim
Tensor logits_2d, softmax_2d, labels_2d, loss_2d;
logits_2d.ShareDataWith(*logits).Resize({n, d});
softmax_2d.ShareDataWith(*softmax).Resize({n, d});
labels_2d.ShareDataWith(labels).Resize({n, labels.numel() / n});
loss_2d.ShareDataWith(*loss).Resize({n, 1});
math::SoftmaxCUDNNFunctor<T>()(context.cuda_device_context(),
&logits_2d, &softmax_2d);
math::CrossEntropyFunctor<platform::CUDADeviceContext, T>()(
context.cuda_device_context(), &loss_2d, &softmax_2d, &labels_2d,
false, ignore_index, axis_dim);
DenseTensor logits_2d(logits);
logits_2d.Resize({n, d});
DenseTensor softmax_2d(*softmax);
softmax_2d.Resize({n, d});
DenseTensor labels_2d(label);
labels_2d.Resize({n, label.numel() / n});
DenseTensor loss_2d(*loss);
loss_2d.Resize({n, 1});
paddle::operators::math::SoftmaxCUDNNFunctor<T, GPUContext>()(
dev_ctx, &logits_2d, &softmax_2d);
paddle::operators::math::CrossEntropyFunctor<GPUContext, T>()(
dev_ctx,
&loss_2d,
&softmax_2d,
&labels_2d,
false,
ignore_index,
axis_dim);
} else {
auto* logits_data = logits->template data<T>();
auto* labels_data = labels.template data<LabelT>();
auto* logits_data = logits.data<T>();
auto* labels_data = label.data<LabelT>();
if (ignore_index >= 0 && ignore_index < axis_dim) {
SoftmaxWithCrossEntropyHardLabel<T, LabelT, true>(
context.cuda_device_context(), rank, axis, logits_data,
labels_data, loss_data, softmax_data, n, axis_dim, d / axis_dim,
SoftmaxWithCrossEntropyHardLabel<T, LabelT, true>(dev_ctx,
rank,
axis_v,
logits_data,
labels_data,
loss_data,
softmax_data,
n,
axis_dim,
d / axis_dim,
ignore_index);
} else {
SoftmaxWithCrossEntropyHardLabel<T, LabelT, false>(
context.cuda_device_context(), rank, axis, logits_data,
labels_data, loss_data, softmax_data, n, axis_dim, d / axis_dim,
SoftmaxWithCrossEntropyHardLabel<T, LabelT, false>(dev_ctx,
rank,
axis_v,
logits_data,
labels_data,
loss_data,
softmax_data,
n,
axis_dim,
d / axis_dim,
ignore_index);
}
}
}
}
};
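// The numeric_stable_mode path above avoids computing softmax and log
// separately by using the max-shifted log-sum-exp identity
// loss = logsumexp(x) - x[label]. A minimal host-side sketch of that
// identity, assuming float data (hypothetical helper, not part of this
// diff):
//
//   #include <algorithm>
//   #include <cmath>
//   #include <vector>
//
//   float StableXentRowSketch(const std::vector<float>& logits, int label) {
//     float max_v = *std::max_element(logits.begin(), logits.end());
//     float sum = 0.f;
//     for (float x : logits) sum += std::exp(x - max_v);  // cannot overflow
//     return (max_v + std::log(sum)) - logits[label];
//   }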
template <typename T>
class SoftmaxWithCrossEntropyGradCUDAKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
RunSoftmaxWithCrossEntropyFunctor<T>(context, *this);
}
template <typename LabelT>
static void Apply(const framework::ExecutionContext& context,
const framework::Tensor& labels, const bool soft_label) {
PADDLE_ENFORCE_EQ(
platform::is_gpu_place(context.GetPlace()), true,
platform::errors::Unavailable("softmax_with_cross_entropy operator's "
"CUDA kernel only runs on GPU device."));
const T* loss_grad_data =
context.Input<Tensor>(framework::GradVarName("Loss"))
->template data<T>();
Tensor* logit_grad =
context.Output<Tensor>(framework::GradVarName("Logits"));
const Tensor* softmax = context.Input<Tensor>("Softmax");
auto stream = context.cuda_device_context().stream();
auto ignore_index = context.Attr<int>("ignore_index");
auto use_softmax = context.Attr<bool>("use_softmax");
T* logit_grad_data = nullptr;
bool copy_flag = (logit_grad != softmax && (!use_softmax || soft_label));
if (copy_flag) {
framework::TensorCopy(*softmax, context.GetPlace(),
context.device_context(), logit_grad);
logit_grad_data = logit_grad->template data<T>();
} else {
logit_grad_data =
logit_grad->template mutable_data<T>(context.GetPlace());
}
const int rank = logit_grad->dims().size();
const int axis = phi::funcs::CanonicalAxis(context.Attr<int>("axis"), rank);
int axis_dim = logit_grad->dims()[axis];
const int64_t n = phi::funcs::SizeToAxis(axis, logit_grad->dims());
const int64_t d = phi::funcs::SizeFromAxis(axis, logit_grad->dims());
const int64_t remain = d / axis_dim;
#ifdef __HIPCC__
int block = 256;
#else
int block = 512;
#endif
// not fused with the softmax op: the input is already a softmax result
if (!use_softmax) {
if (soft_label) {
int grid = (n * d + block - 1) / block;
const T* label_data = labels.template data<T>();
SoftLabelCrossEntropyGradientKernel<T><<<grid, block, 0, stream>>>(
logit_grad_data, loss_grad_data, label_data, n, d, remain);
} else {
Tensor logits_grad_2d;
logits_grad_2d.ShareDataWith(*logit_grad).Resize({n, d});
int grid = (n * remain + block - 1) / block;
const auto* label_data = labels.template data<LabelT>();
HardLabelCrossEntropyGradientKernel<T,
LabelT><<<grid, block, 0, stream>>>(
logit_grad_data, label_data, n, d, remain, ignore_index);
int num = n * d;
grid = (num + block - 1) / block;
ScaleCrossEntropyGradient<T, LabelT><<<grid, block, 0, stream>>>(
logit_grad_data, loss_grad_data, num, d, remain, label_data,
ignore_index);
}
return;
}
// fused with softmax: continue below
}
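// For reference, the fused hard-label backward used further below
// (SoftmaxWithCrossEntropyGradHardLabel) follows from the identity
// d loss / d logit_j = softmax_j - 1{j == label}, scaled by the upstream
// loss gradient. A minimal sketch, assuming float data, int labels, and
// remain == 1 for brevity (hypothetical name, not part of this diff):
//
//   __global__ void FusedXentGradSketch(float* logit_grad,
//                                       const float* loss_grad,
//                                       const float* softmax,
//                                       const int* labels,
//                                       int64_t n, int64_t dim,
//                                       int ignore_index) {
//     int64_t i = blockIdx.x * (int64_t)blockDim.x + threadIdx.x;
//     if (i >= n * dim) return;
//     int64_t row = i / dim, col = i % dim;
//     if (labels[row] == ignore_index) {
//       logit_grad[i] = 0.f;
//     } else {
//       float g = softmax[i] - (col == labels[row] ? 1.f : 0.f);
//       logit_grad[i] = g * loss_grad[row];
//     }
//   }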
template <typename T, typename Context>
void CrossEntropyWithSoftmaxKernel(const Context& dev_ctx,
const DenseTensor& logits,
const DenseTensor& label,
bool soft_label,
bool use_softmax,
bool numeric_stable_mode,
int ignore_index,
int axis,
DenseTensor* softmax,
DenseTensor* loss) {
auto dtype = label.dtype();
if (soft_label) {
int64_t grid = (n * d + block - 1) / block;
const T* label_data = labels.template data<T>();
SoftCrossEntropyGradientKernel<T><<<grid, block, 0, stream>>>(
logit_grad_data, loss_grad_data, label_data, n, d, remain);
PADDLE_ENFORCE_EQ(
dtype,
paddle::experimental::CppTypeToDataType<T>::Type(),
phi::errors::InvalidArgument("The Input(Label) should be with the "
"same data type as Input(Logits)."));
CrossEntropyWithSoftmaxCUDAKernel<T, T>(dev_ctx,
logits,
label,
soft_label,
use_softmax,
numeric_stable_mode,
ignore_index,
axis,
softmax,
loss);
} else {
const T* softmax_data = softmax->template data<T>();
const auto* label_data = labels.template data<LabelT>();
int grid = (n * d + block - 1) / block;
SoftmaxWithCrossEntropyGradHardLabel<T><<<grid, block, 0, stream>>>(
logit_grad_data, loss_grad_data, softmax_data, label_data, n,
d / remain, remain, ignore_index);
}
PD_DISPATCH_INTEGRAL_TYPES(
dtype, "CrossEntropyWithSoftmaxCUDAKernel", ([&] {
CrossEntropyWithSoftmaxCUDAKernel<T, data_t>(dev_ctx,
logits,
label,
soft_label,
use_softmax,
numeric_stable_mode,
ignore_index,
axis,
softmax,
loss);
}));
}
};
}
} // namespace operators
} // namespace paddle
} // namespace phi
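// PD_DISPATCH_INTEGRAL_TYPES above is, conceptually, a switch over the
// runtime label dtype that runs the given body once with data_t bound to
// the matching integral type. A minimal sketch of the pattern, with
// assumed names (not the actual macro):
//
//   #include <cstdint>
//   #include <stdexcept>
//   enum class DataTypeSketch { INT32, INT64 };
//
//   template <typename Functor>
//   void DispatchIntegralSketch(DataTypeSketch dtype, Functor f) {
//     switch (dtype) {
//       case DataTypeSketch::INT32: f(int32_t{}); break;
//       case DataTypeSketch::INT64: f(int64_t{}); break;
//       default: throw std::invalid_argument("unsupported label dtype");
//     }
//   }
//
//   // usage: DispatchIntegralSketch(dtype, [&](auto tag) {
//   //   using data_t = decltype(tag);
//   //   // instantiate the kernel with LabelT = data_t here
//   // });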
namespace ops = paddle::operators;
#ifdef PADDLE_WITH_HIP
// MIOPEN does not support double
REGISTER_OP_CUDA_KERNEL(
softmax_with_cross_entropy, ops::SoftmaxWithCrossEntropyCUDAKernel<float>,
ops::SoftmaxWithCrossEntropyCUDAKernel<paddle::platform::float16>);
REGISTER_OP_CUDA_KERNEL(
softmax_with_cross_entropy_grad,
ops::SoftmaxWithCrossEntropyGradCUDAKernel<float>,
ops::SoftmaxWithCrossEntropyGradCUDAKernel<paddle::platform::float16>);
PD_REGISTER_KERNEL(cross_entropy_with_softmax,
GPU,
ALL_LAYOUT,
phi::CrossEntropyWithSoftmaxKernel,
float,
phi::dtype::float16) {}
#else
REGISTER_OP_CUDA_KERNEL(
softmax_with_cross_entropy, ops::SoftmaxWithCrossEntropyCUDAKernel<float>,
ops::SoftmaxWithCrossEntropyCUDAKernel<paddle::platform::float16>,
ops::SoftmaxWithCrossEntropyCUDAKernel<double>);
REGISTER_OP_CUDA_KERNEL(
softmax_with_cross_entropy_grad,
ops::SoftmaxWithCrossEntropyGradCUDAKernel<float>,
ops::SoftmaxWithCrossEntropyGradCUDAKernel<paddle::platform::float16>,
ops::SoftmaxWithCrossEntropyGradCUDAKernel<double>);
PD_REGISTER_KERNEL(cross_entropy_with_softmax,
GPU,
ALL_LAYOUT,
phi::CrossEntropyWithSoftmaxKernel,
float,
double,
phi::dtype::float16) {}
#endif
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/core/compat/op_utils.h"
namespace phi {
KernelSignature SoftmaxWithCrossEntropyOpArgumentMapping(
const ArgumentMappingContext& ctx) {
return KernelSignature("cross_entropy_with_softmax",
{"Logits", "Label"},
{"soft_label",
"use_softmax",
"numeric_stable_mode",
"ignore_index",
"axis"},
{"Softmax", "Loss"});
}
KernelSignature SoftmaxWithCrossEntropyGradOpArgumentMapping(
const ArgumentMappingContext& ctx) {
return KernelSignature("cross_entropy_with_softmax_grad",
{"Label", "Softmax", GradVarName("Loss")},
{"soft_label",
"use_softmax",
"numeric_stable_mode",
"ignore_index",
"axis"},
{GradVarName("Logits")});
}
} // namespace phi
PD_REGISTER_BASE_KERNEL_NAME(softmax_with_cross_entropy,
cross_entropy_with_softmax);
PD_REGISTER_BASE_KERNEL_NAME(softmax_with_cross_entropy_grad,
cross_entropy_with_softmax_grad);
PD_REGISTER_ARG_MAPPING_FN(softmax_with_cross_entropy,
phi::SoftmaxWithCrossEntropyOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(softmax_with_cross_entropy_grad,
phi::SoftmaxWithCrossEntropyGradOpArgumentMapping);
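// The grad mapping above refers to gradient variables via Paddle's usual
// "@GRAD" suffix convention; a minimal sketch of the assumed naming:
//
//   #include <string>
//   inline std::string GradVarNameSketch(const std::string& name) {
//     return name + "@GRAD";
//   }
//   // GradVarNameSketch("Loss") == "Loss@GRAD" is consumed as an input;
//   // GradVarNameSketch("Logits") == "Logits@GRAD" is the produced output.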