未验证 提交 e6ec98fe 编写于 作者: C Chen Weihang 提交者: GitHub

[Phi] Move softmax with cross entropy kernel into phi (#40832)

* add cross_entropy_with_softmax phi kernel

* remove softmax_with_cross_entropy kernel

* add softmax_with_cross_entropy grad kernel

* remove original op kernel

* refine cross entropy impl

* fix pointer error

* revert kernel cu change

* fix xpu failed

* fix cinn failed

* fix npu failed

* add forward sig

* add check_nan_inf for pt kernel

* remove repeat cmake item

* fix unittest error
上级 d65a7a46
...@@ -35,7 +35,7 @@ USE_OP_ITSELF(elementwise_add); ...@@ -35,7 +35,7 @@ USE_OP_ITSELF(elementwise_add);
USE_OP_ITSELF(sigmoid); USE_OP_ITSELF(sigmoid);
USE_OP_ITSELF(tanh); USE_OP_ITSELF(tanh);
USE_OP_ITSELF(elementwise_mul); USE_OP_ITSELF(elementwise_mul);
USE_OP(softmax_with_cross_entropy); USE_OP_ITSELF(softmax_with_cross_entropy);
USE_OP_ITSELF(reduce_mean); USE_OP_ITSELF(reduce_mean);
USE_OP_ITSELF(reduce_sum); USE_OP_ITSELF(reduce_sum);
USE_OP_ITSELF(reduce_sum_grad); USE_OP_ITSELF(reduce_sum_grad);
...@@ -83,6 +83,8 @@ PD_DECLARE_KERNEL(max_raw, GPU, ALL_LAYOUT); ...@@ -83,6 +83,8 @@ PD_DECLARE_KERNEL(max_raw, GPU, ALL_LAYOUT);
PD_DECLARE_KERNEL(sgd, GPU, ALL_LAYOUT); PD_DECLARE_KERNEL(sgd, GPU, ALL_LAYOUT);
PD_DECLARE_KERNEL(slice, GPU, ALL_LAYOUT); PD_DECLARE_KERNEL(slice, GPU, ALL_LAYOUT);
PD_DECLARE_KERNEL(slice_grad, GPU, ALL_LAYOUT); PD_DECLARE_KERNEL(slice_grad, GPU, ALL_LAYOUT);
PD_DECLARE_KERNEL(cross_entropy_with_softmax, GPU, ALL_LAYOUT);
PD_DECLARE_KERNEL(cross_entropy_with_softmax_grad, GPU, ALL_LAYOUT);
PD_DECLARE_KERNEL(sqrt, GPU, ALL_LAYOUT); PD_DECLARE_KERNEL(sqrt, GPU, ALL_LAYOUT);
DECLARE_double(eager_delete_tensor_gb); DECLARE_double(eager_delete_tensor_gb);
......
...@@ -87,7 +87,7 @@ phi::KernelKey TransOpKernelTypeToPhiKernelKey( ...@@ -87,7 +87,7 @@ phi::KernelKey TransOpKernelTypeToPhiKernelKey(
} else if (kernel_type.library_type_ == LibraryType::kKP) { } else if (kernel_type.library_type_ == LibraryType::kKP) {
backend = phi::Backend::KPS; backend = phi::Backend::KPS;
} else { } else {
// do // do nothing
} }
paddle::experimental::DataLayout layout = kernel_type.data_layout_; paddle::experimental::DataLayout layout = kernel_type.data_layout_;
paddle::experimental::DataType dtype = paddle::experimental::DataType dtype =
......
...@@ -484,6 +484,11 @@ static void PreparedOpRunPtImpl( ...@@ -484,6 +484,11 @@ static void PreparedOpRunPtImpl(
pt_kernel(&pt_kernel_context); pt_kernel(&pt_kernel_context);
} }
if (FLAGS_check_nan_inf) {
framework::details::CheckOpHasNanOrInfInDygraph<VarType>(
op.Type(), outs, dev_ctx->GetPlace());
}
if (FLAGS_benchmark) { if (FLAGS_benchmark) {
dev_ctx->Wait(); dev_ctx->Wait();
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
......
...@@ -14,6 +14,7 @@ limitations under the License. */ ...@@ -14,6 +14,7 @@ limitations under the License. */
#include "paddle/fluid/operators/math/cross_entropy.h" #include "paddle/fluid/operators/math/cross_entropy.h"
#include "paddle/fluid/framework/convert_utils.h" #include "paddle/fluid/framework/convert_utils.h"
#include "paddle/phi/backends/cpu/cpu_context.h"
namespace paddle { namespace paddle {
namespace platform { namespace platform {
...@@ -89,13 +90,11 @@ struct HardLabelCrossEntropyCPUFunctorImpl { ...@@ -89,13 +90,11 @@ struct HardLabelCrossEntropyCPUFunctorImpl {
const int axis_dim_; const int axis_dim_;
}; };
template <typename T> template <typename DeviceContext, typename T>
class CrossEntropyFunctor<platform::CPUDeviceContext, T> { void CrossEntropyFunctor<DeviceContext, T>::operator()(
public: const DeviceContext& ctx, framework::Tensor* out,
void operator()(const platform::CPUDeviceContext& ctx, framework::Tensor* out, const framework::Tensor* prob, const framework::Tensor* labels,
const framework::Tensor* prob, const bool softLabel, const int ignore_index, const int axis_dim) {
const framework::Tensor* labels, const bool softLabel,
const int ignore_index, const int axis_dim) {
if (softLabel) { if (softLabel) {
const int batch_size = prob->dims()[0]; const int batch_size = prob->dims()[0];
const int num_classes = prob->dims()[1]; const int num_classes = prob->dims()[1];
...@@ -111,16 +110,18 @@ class CrossEntropyFunctor<platform::CPUDeviceContext, T> { ...@@ -111,16 +110,18 @@ class CrossEntropyFunctor<platform::CPUDeviceContext, T> {
.reshape(batch_axis_remain) .reshape(batch_axis_remain)
.sum(Eigen::DSizes<int, 1>(1))); .sum(Eigen::DSizes<int, 1>(1)));
} else { } else {
HardLabelCrossEntropyCPUFunctorImpl<T> functor_impl( HardLabelCrossEntropyCPUFunctorImpl<T> functor_impl(out, prob, labels,
out, prob, labels, ignore_index, axis_dim); ignore_index, axis_dim);
framework::VisitIntDataType( framework::VisitIntDataType(framework::TransToProtoVarType(labels->dtype()),
framework::TransToProtoVarType(labels->dtype()), functor_impl); functor_impl);
}
} }
}; }
template class CrossEntropyFunctor<platform::CPUDeviceContext, float>; template class CrossEntropyFunctor<platform::CPUDeviceContext, float>;
template class CrossEntropyFunctor<platform::CPUDeviceContext, double>; template class CrossEntropyFunctor<platform::CPUDeviceContext, double>;
template class CrossEntropyFunctor<phi::CPUContext, float>;
template class CrossEntropyFunctor<phi::CPUContext, double>;
} // namespace math } // namespace math
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle
...@@ -17,6 +17,7 @@ limitations under the License. */ ...@@ -17,6 +17,7 @@ limitations under the License. */
#include "paddle/fluid/operators/math/cross_entropy.h" #include "paddle/fluid/operators/math/cross_entropy.h"
#include "paddle/fluid/platform/device/gpu/gpu_device_function.h" #include "paddle/fluid/platform/device/gpu/gpu_device_function.h"
#include "paddle/fluid/platform/device/gpu/gpu_primitives.h" #include "paddle/fluid/platform/device/gpu/gpu_primitives.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
...@@ -93,13 +94,11 @@ struct HardLabelCrossEntropyCUDAFunctorImpl { ...@@ -93,13 +94,11 @@ struct HardLabelCrossEntropyCUDAFunctorImpl {
gpuStream_t stream_; gpuStream_t stream_;
}; };
template <typename T> template <typename DeviceContext, typename T>
class CrossEntropyFunctor<platform::CUDADeviceContext, T> { void CrossEntropyFunctor<DeviceContext, T>::operator()(
public: const DeviceContext& ctx, framework::Tensor* out,
void operator()(const platform::CUDADeviceContext& ctx, const framework::Tensor* prob, const framework::Tensor* labels,
framework::Tensor* out, const framework::Tensor* prob, const bool softLabel, const int ignore_index, const int axis_dim) {
const framework::Tensor* labels, const bool softLabel,
const int ignore_index, const int axis_dim) {
const T* prob_data = prob->data<T>(); const T* prob_data = prob->data<T>();
T* loss_data = out->mutable_data<T>(ctx.GetPlace()); T* loss_data = out->mutable_data<T>(ctx.GetPlace());
...@@ -126,13 +125,17 @@ class CrossEntropyFunctor<platform::CUDADeviceContext, T> { ...@@ -126,13 +125,17 @@ class CrossEntropyFunctor<platform::CUDADeviceContext, T> {
framework::VisitDataType(framework::TransToProtoVarType(labels->dtype()), framework::VisitDataType(framework::TransToProtoVarType(labels->dtype()),
functor); functor);
} }
} }
};
template class CrossEntropyFunctor<platform::CUDADeviceContext, float>; template class CrossEntropyFunctor<platform::CUDADeviceContext, float>;
template class CrossEntropyFunctor<platform::CUDADeviceContext, double>; template class CrossEntropyFunctor<platform::CUDADeviceContext, double>;
template class CrossEntropyFunctor<platform::CUDADeviceContext, template class CrossEntropyFunctor<platform::CUDADeviceContext,
platform::float16>; platform::float16>;
template class CrossEntropyFunctor<phi::GPUContext, float>;
template class CrossEntropyFunctor<phi::GPUContext, double>;
template class CrossEntropyFunctor<phi::GPUContext, platform::float16>;
} // namespace math } // namespace math
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle
...@@ -29,9 +29,9 @@ using DataLayout = platform::DataLayout; ...@@ -29,9 +29,9 @@ using DataLayout = platform::DataLayout;
template <typename T> template <typename T>
using CudnnDataType = platform::CudnnDataType<T>; using CudnnDataType = platform::CudnnDataType<T>;
template <typename T> template <typename T, typename DeviceContext>
void SoftmaxCUDNNFunctor<T>::operator()( void SoftmaxCUDNNFunctor<T, DeviceContext>::operator()(
const platform::CUDADeviceContext& context, const framework::Tensor* X, const DeviceContext& context, const framework::Tensor* X,
framework::Tensor* Y) { framework::Tensor* Y) {
// ------------------- cudnn descriptors --------------------- // ------------------- cudnn descriptors ---------------------
ScopedTensorDescriptor xDesc; ScopedTensorDescriptor xDesc;
...@@ -69,9 +69,9 @@ void SoftmaxCUDNNFunctor<T>::operator()( ...@@ -69,9 +69,9 @@ void SoftmaxCUDNNFunctor<T>::operator()(
#endif #endif
} }
template <typename T> template <typename T, typename DeviceContext>
void SoftmaxGradCUDNNFunctor<T>::operator()( void SoftmaxGradCUDNNFunctor<T, DeviceContext>::operator()(
const platform::CUDADeviceContext& context, const framework::Tensor* Y, const DeviceContext& context, const framework::Tensor* Y,
const framework::Tensor* YGrad, framework::Tensor* XGrad) { const framework::Tensor* YGrad, framework::Tensor* XGrad) {
// ------------------- cudnn descriptors --------------------- // ------------------- cudnn descriptors ---------------------
ScopedTensorDescriptor yDesc; ScopedTensorDescriptor yDesc;
...@@ -116,19 +116,31 @@ void SoftmaxGradCUDNNFunctor<T>::operator()( ...@@ -116,19 +116,31 @@ void SoftmaxGradCUDNNFunctor<T>::operator()(
#endif #endif
} }
template class SoftmaxCUDNNFunctor<float>; template class SoftmaxCUDNNFunctor<float, platform::CUDADeviceContext>;
template class SoftmaxCUDNNFunctor<platform::float16>; template class SoftmaxCUDNNFunctor<platform::float16,
template class SoftmaxGradCUDNNFunctor<float>; platform::CUDADeviceContext>;
template class SoftmaxGradCUDNNFunctor<platform::float16>; template class SoftmaxGradCUDNNFunctor<float, platform::CUDADeviceContext>;
template class SoftmaxGradCUDNNFunctor<platform::float16,
platform::CUDADeviceContext>;
template class SoftmaxCUDNNFunctor<float, phi::GPUContext>;
template class SoftmaxCUDNNFunctor<platform::float16, phi::GPUContext>;
template class SoftmaxGradCUDNNFunctor<float, phi::GPUContext>;
template class SoftmaxGradCUDNNFunctor<platform::float16, phi::GPUContext>;
#if CUDNN_VERSION_MIN(8, 1, 0) #if CUDNN_VERSION_MIN(8, 1, 0)
template class SoftmaxCUDNNFunctor<platform::bfloat16>; template class SoftmaxCUDNNFunctor<platform::bfloat16,
template class SoftmaxGradCUDNNFunctor<platform::bfloat16>; platform::CUDADeviceContext>;
template class SoftmaxGradCUDNNFunctor<platform::bfloat16,
platform::CUDADeviceContext>;
template class SoftmaxCUDNNFunctor<platform::bfloat16, phi::GPUContext>;
template class SoftmaxGradCUDNNFunctor<platform::bfloat16, phi::GPUContext>;
#endif #endif
// MIOPEN do not support double // MIOPEN do not support double
#ifndef PADDLE_WITH_HIP #ifndef PADDLE_WITH_HIP
template class SoftmaxCUDNNFunctor<double>; template class SoftmaxCUDNNFunctor<double, platform::CUDADeviceContext>;
template class SoftmaxGradCUDNNFunctor<double>; template class SoftmaxGradCUDNNFunctor<double, platform::CUDADeviceContext>;
template class SoftmaxCUDNNFunctor<double, phi::GPUContext>;
template class SoftmaxGradCUDNNFunctor<double, phi::GPUContext>;
#endif #endif
template class SoftmaxFunctor<platform::CUDADeviceContext, platform::float16, template class SoftmaxFunctor<platform::CUDADeviceContext, platform::float16,
......
...@@ -36,19 +36,18 @@ class SoftmaxGradFunctor { ...@@ -36,19 +36,18 @@ class SoftmaxGradFunctor {
}; };
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
template <typename T> template <typename T, typename DeviceContext>
class SoftmaxCUDNNFunctor { class SoftmaxCUDNNFunctor {
public: public:
void operator()(const platform::CUDADeviceContext& context, void operator()(const DeviceContext& context, const framework::Tensor* X,
const framework::Tensor* X, framework::Tensor* Y); framework::Tensor* Y);
}; };
template <typename T> template <typename T, typename DeviceContext>
class SoftmaxGradCUDNNFunctor { class SoftmaxGradCUDNNFunctor {
public: public:
void operator()(const platform::CUDADeviceContext& context, void operator()(const DeviceContext& context, const framework::Tensor* Y,
const framework::Tensor* Y, const framework::Tensor* y_grad, const framework::Tensor* y_grad, framework::Tensor* x_grad);
framework::Tensor* x_grad);
}; };
#endif #endif
......
...@@ -58,7 +58,7 @@ class SequenceSoftmaxCUDNNKernel : public framework::OpKernel<T> { ...@@ -58,7 +58,7 @@ class SequenceSoftmaxCUDNNKernel : public framework::OpKernel<T> {
phi::make_ddim({1UL, end_pos - start_pos}); phi::make_ddim({1UL, end_pos - start_pos});
x_i.Resize(dims_i); x_i.Resize(dims_i);
out_i.Resize(dims_i); out_i.Resize(dims_i);
math::SoftmaxCUDNNFunctor<T>()( math::SoftmaxCUDNNFunctor<T, platform::CUDADeviceContext>()(
ctx.template device_context<platform::CUDADeviceContext>(), &x_i, ctx.template device_context<platform::CUDADeviceContext>(), &x_i,
&out_i); &out_i);
} }
...@@ -93,7 +93,7 @@ class SequenceSoftmaxGradCUDNNKernel : public framework::OpKernel<T> { ...@@ -93,7 +93,7 @@ class SequenceSoftmaxGradCUDNNKernel : public framework::OpKernel<T> {
out_i.Resize(dims_i); out_i.Resize(dims_i);
out_grad_i.Resize(dims_i); out_grad_i.Resize(dims_i);
x_grad_i.Resize(dims_i); x_grad_i.Resize(dims_i);
math::SoftmaxGradCUDNNFunctor<T>()( math::SoftmaxGradCUDNNFunctor<T, platform::CUDADeviceContext>()(
ctx.template device_context<platform::CUDADeviceContext>(), &out_i, ctx.template device_context<platform::CUDADeviceContext>(), &out_i,
&out_grad_i, &x_grad_i); &out_grad_i, &x_grad_i);
} }
......
...@@ -12,8 +12,9 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ...@@ -12,8 +12,9 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include "paddle/fluid/operators/softmax_with_cross_entropy_op.h" #include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/op_version_registry.h" #include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/phi/kernels/funcs/axis_utils.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
...@@ -335,12 +336,6 @@ REGISTER_OPERATOR(softmax_with_cross_entropy, ops::SoftmaxWithCrossEntropyOp, ...@@ -335,12 +336,6 @@ REGISTER_OPERATOR(softmax_with_cross_entropy, ops::SoftmaxWithCrossEntropyOp,
REGISTER_OPERATOR(softmax_with_cross_entropy_grad, REGISTER_OPERATOR(softmax_with_cross_entropy_grad,
ops::SoftmaxWithCrossEntropyOpGrad, ops::SoftmaxWithCrossEntropyOpGrad,
ops::SoftmaxWithCrossEntropyGradInplaceInferer); ops::SoftmaxWithCrossEntropyGradInplaceInferer);
REGISTER_OP_CPU_KERNEL(softmax_with_cross_entropy,
ops::SoftmaxWithCrossEntropyKernel<float>,
ops::SoftmaxWithCrossEntropyKernel<double>);
REGISTER_OP_CPU_KERNEL(softmax_with_cross_entropy_grad,
ops::SoftmaxWithCrossEntropyGradKernel<float>,
ops::SoftmaxWithCrossEntropyGradKernel<double>);
REGISTER_OP_VERSION(softmax_with_cross_entropy) REGISTER_OP_VERSION(softmax_with_cross_entropy)
#if defined(PADDLE_WITH_ASCEND_CL) || defined(PADDLE_WITH_MLU) #if defined(PADDLE_WITH_ASCEND_CL) || defined(PADDLE_WITH_MLU)
......
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/fluid/framework/convert_utils.h"
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/math/cross_entropy.h"
#include "paddle/fluid/operators/math/softmax.h"
#include "paddle/phi/kernels/funcs/axis_utils.h"
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
template <typename T, typename Visitor>
struct SoftmaxWithCrossEntropyFunctor {
public:
SoftmaxWithCrossEntropyFunctor(const framework::ExecutionContext& context,
const framework::Tensor& labels,
const bool soft_label, const Visitor& visitor)
: context_(context),
labels_(labels),
soft_label_(soft_label),
visitor_(visitor) {}
template <typename U>
void apply() const {
visitor_.template Apply<U>(context_, labels_, soft_label_);
}
private:
const framework::ExecutionContext& context_;
const framework::Tensor& labels_;
const bool soft_label_;
const Visitor& visitor_;
};
template <typename T, typename Visitor>
static void RunSoftmaxWithCrossEntropyFunctor(
const framework::ExecutionContext& context, const Visitor& visitor) {
const auto* labels = context.Input<framework::Tensor>("Label");
const bool soft_label = context.Attr<bool>("soft_label");
SoftmaxWithCrossEntropyFunctor<T, Visitor> functor(context, *labels,
soft_label, visitor);
auto dtype = framework::TransToProtoVarType(labels->dtype());
if (soft_label) {
PADDLE_ENFORCE_EQ(
dtype, framework::DataTypeTrait<T>::DataType(),
platform::errors::InvalidArgument("The Input(Label) should be with the "
"same data type as Input(Logits)."));
functor.template apply<T>();
} else {
framework::VisitIntDataType(dtype, functor);
}
}
template <typename T>
class SoftmaxWithCrossEntropyKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
PADDLE_ENFORCE_EQ(
platform::is_cpu_place(context.GetPlace()), true,
platform::errors::Unimplemented("This kernel only runs on CPU."));
const bool use_softmax = context.Attr<bool>("use_softmax");
const Tensor* labels = context.Input<Tensor>("Label");
const bool soft_label = context.Attr<bool>("soft_label");
// do not with softmax op, and input is softmax
if (!use_softmax) {
const Tensor* softmax = context.Input<Tensor>("Logits");
Tensor* softmax_out = context.Output<Tensor>("Softmax");
Tensor* loss = context.Output<Tensor>("Loss");
const int rank = softmax->dims().size();
const int axis =
phi::funcs::CanonicalAxis(context.Attr<int>("axis"), rank);
int axis_dim = softmax->dims()[axis];
PADDLE_ENFORCE_GT(
axis_dim, 0,
platform::errors::InvalidArgument(
"The axis dimention should be larger than 0, but received "
"axis dimention is %d.",
axis_dim));
softmax_out->mutable_data<T>(context.GetPlace());
loss->mutable_data<T>(context.GetPlace());
const int n = phi::funcs::SizeToAxis(axis, softmax->dims());
PADDLE_ENFORCE_GT(
n, 0, platform::errors::InvalidArgument(
"The size of axis should be larger than 0, but received "
"SizeToAxis of softmax is %d.",
n));
const int d = phi::funcs::SizeFromAxis(axis, softmax->dims());
Tensor softmax_2d, labels_2d, loss_2d, softmax_out_2d;
softmax_2d.ShareDataWith(*softmax).Resize({n, d});
labels_2d.ShareDataWith(*labels).Resize({n, labels->numel() / n});
loss_2d.ShareDataWith(*loss).Resize({n, d / axis_dim});
softmax_out_2d.ShareDataWith(*softmax_out).Resize({n, d});
auto& dev_ctx =
context.template device_context<platform::CPUDeviceContext>();
math::CrossEntropyFunctor<platform::CPUDeviceContext, T>()(
dev_ctx, &loss_2d, &softmax_2d, &labels_2d, soft_label,
context.Attr<int>("ignore_index"), axis_dim);
// cause of input is softmax
// copy to output softmax, directly
framework::TensorCopy(*softmax, context.GetPlace(),
context.device_context(), softmax_out);
return;
}
const Tensor* logits = context.Input<Tensor>("Logits");
Tensor* softmax = context.Output<Tensor>("Softmax");
Tensor* loss = context.Output<Tensor>("Loss");
const int rank = logits->dims().size();
const int axis = phi::funcs::CanonicalAxis(context.Attr<int>("axis"), rank);
int axis_dim = logits->dims()[axis];
PADDLE_ENFORCE_GT(
axis_dim, 0,
platform::errors::InvalidArgument(
"The axis dimention should be larger than 0, but received "
"axis dimention is %d.",
axis_dim));
softmax->mutable_data<T>(context.GetPlace());
loss->mutable_data<T>(context.GetPlace());
const int n = phi::funcs::SizeToAxis(axis, logits->dims());
PADDLE_ENFORCE_GT(
n, 0, platform::errors::InvalidArgument(
"The size of axis should be larger than 0, but received "
"SizeToAxis of logits is %d.",
n));
const int d = phi::funcs::SizeFromAxis(axis, logits->dims());
Tensor logits_2d, softmax_2d, labels_2d, loss_2d;
logits_2d.ShareDataWith(*logits).Resize({n, d});
softmax_2d.ShareDataWith(*softmax).Resize({n, d});
labels_2d.ShareDataWith(*labels).Resize({n, labels->numel() / n});
loss_2d.ShareDataWith(*loss).Resize({n, d / axis_dim});
auto& dev_ctx =
context.template device_context<platform::CPUDeviceContext>();
math::SoftmaxFunctor<platform::CPUDeviceContext, T, false>()(
dev_ctx, axis_dim, &logits_2d, &softmax_2d);
math::CrossEntropyFunctor<platform::CPUDeviceContext, T>()(
dev_ctx, &loss_2d, &softmax_2d, &labels_2d, soft_label,
context.Attr<int>("ignore_index"), axis_dim);
}
};
template <typename T>
class SoftmaxWithCrossEntropyGradKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
RunSoftmaxWithCrossEntropyFunctor<T>(context, *this);
}
template <typename LabelT>
static void Apply(const framework::ExecutionContext& context,
const framework::Tensor& labels, const bool soft_label) {
const Tensor* out_grad =
context.Input<Tensor>(framework::GradVarName("Loss"));
Tensor* logit_grad =
context.Output<Tensor>(framework::GradVarName("Logits"));
const Tensor* softmax = context.Input<Tensor>("Softmax");
const bool use_softmax = context.Attr<bool>("use_softmax");
if (logit_grad != softmax || !use_softmax) {
framework::TensorCopy(*softmax, context.GetPlace(),
context.device_context(), logit_grad);
}
auto ignore_index = context.Attr<int>("ignore_index");
const int rank = logit_grad->dims().size();
const int axis = phi::funcs::CanonicalAxis(context.Attr<int>("axis"), rank);
int axis_dim = logit_grad->dims()[axis];
PADDLE_ENFORCE_GT(
axis_dim, 0,
platform::errors::InvalidArgument(
"The axis dimention should be larger than 0, but received "
"axis dimention is %d.",
axis_dim));
const int n = phi::funcs::SizeToAxis(axis, logit_grad->dims());
PADDLE_ENFORCE_GT(
n, 0, platform::errors::InvalidArgument(
"The size of axis should be larger than 0, but received "
"SizeToAxis of logit_grad is %d.",
n));
const int d = phi::funcs::SizeFromAxis(axis, logit_grad->dims());
Tensor logit_grad_2d, labels_2d, out_grad_2d;
logit_grad_2d.ShareDataWith(*logit_grad).Resize({n, d});
labels_2d.ShareDataWith(labels).Resize({n, labels.numel() / n});
out_grad_2d.ShareDataWith(*out_grad).Resize({n, d / axis_dim});
auto out_grad_mat = framework::EigenMatrix<T>::From(out_grad_2d);
auto logit_grad_mat = framework::EigenMatrix<T>::From(logit_grad_2d);
auto& place = *context.template device_context<platform::CPUDeviceContext>()
.eigen_device();
if (!use_softmax) {
// use_softmax step1
if (soft_label) {
auto lbl_mat = framework::EigenMatrix<T>::From(labels_2d);
logit_grad_mat.device(place) =
(-lbl_mat / logit_grad_mat); // for each sample ,i is sample id
logit_grad_mat.device(place) =
out_grad_mat.broadcast(Eigen::DSizes<int, 2>(1, axis_dim)) *
logit_grad_mat;
} else {
// use_softmax step2
const auto* label_data = labels.template data<LabelT>();
T* logit_grad_data = logit_grad->template data<T>();
const T* out_grad_data = out_grad->template data<T>();
const int remain = d / axis_dim;
for (int i = 0; i < n; ++i) { // for each sample_1_dim
for (int j = 0; j < remain; j++) { // for each sample_other_dims
int idx = i * remain + j; // this sample's label_idx. for 1d case,
// remain=1 and j=0, so, idx = i
auto lbl = static_cast<int64_t>(label_data[idx]);
if (lbl == ignore_index) {
for (int k = 0; k < axis_dim; ++k) { // for each class id's label
logit_grad_data[i * d + k * remain + j] = 0;
}
} else {
// only for this sample's label_idx, the label is 1, others is 0,
// so, only compute this label_idx's class
logit_grad_data[i * d + lbl * remain + j] =
(-1 / logit_grad_data[i * d + lbl * remain + j]) *
out_grad_data[idx];
for (int k = 0; k < axis_dim; ++k) { // for each class id's label
if (k !=
label_data[idx]) { // label_data[idx]: this sample's label
logit_grad_data[i * d + k * remain + j] = 0;
}
}
}
}
}
}
return;
}
// for use_softmax=False, continue
if (soft_label) {
// when soft_label = True, ignore_index is not supported
auto lbl_mat = framework::EigenMatrix<T>::From(labels_2d);
logit_grad_mat.device(place) =
out_grad_mat.broadcast(Eigen::DSizes<int, 2>(1, axis_dim)) *
(logit_grad_mat - lbl_mat); // for each sample ,i is sample id
// 1) compute dy/dx by p_j - y_j or P-Y, where j is class id,
// P=logit_grad_mat[i] is all class's probs, Y=lbl_mat[i] is
// all class's labels
// 2) compute dy * dy/dx by Chain rule, dy=out_grad_mat[i]
// for high dims, e.g. (n,c) or (n,d1,...,dm, c), compute grad by matrix
// operation
} else {
logit_grad_mat.device(place) =
logit_grad_mat * // element_wise multiply
out_grad_mat.broadcast(Eigen::DSizes<int, 2>(1, axis_dim));
const auto* label_data = labels.template data<LabelT>();
T* logit_grad_data = logit_grad->template data<T>();
const T* out_grad_data = out_grad->template data<T>();
const int remain = d / axis_dim;
for (int i = 0; i < n; ++i) { // for each sample_1_dim
for (int j = 0; j < remain; j++) { // for each sample_other_dims
int idx = i * remain + j; // this sample's label_idx. for 1d case,
// remain=1 and j=0, so, idx = i
auto lbl = static_cast<int64_t>(label_data[idx]);
if (lbl == ignore_index) {
for (int k = 0; k < axis_dim; ++k) { // for each class id's label
logit_grad_data[i * d + k * remain + j] = 0;
}
} else {
// only for this sample's label_idx, the label is 1, others is 0,
// so, only compute this label_idx's class
// for 1d case, remain=1 and j=0, so, [i * d + label_data[idx] *
// remain + j] = [i * d + label_data[idx]]
// let idx_x = i * d + label_data[idx] * remain + j,
// logit_grad_data[idx_x] = logit_grad_data[idx_x] -
// out_grad_data[idx]
// note: logit_grad_mat = logit_grad_mat * out_grad_mat
// so: logit_grad_data[idx_x] = (logit_grad_data[idx_x] - 1) *
// out_grad_data[idx]
// means: dy/dp * dy= ( p - y ) * dy
logit_grad_data[i * d + lbl * remain + j] -= out_grad_data[idx];
}
}
}
}
}
};
} // namespace operators
} // namespace paddle
...@@ -12,8 +12,9 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ...@@ -12,8 +12,9 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include "paddle/fluid/operators/softmax_with_cross_entropy_op.h" #include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/mlu/mlu_baseop.h" #include "paddle/fluid/operators/mlu/mlu_baseop.h"
#include "paddle/phi/kernels/funcs/axis_utils.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
......
...@@ -12,13 +12,14 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ...@@ -12,13 +12,14 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include "paddle/fluid/operators/softmax_with_cross_entropy_op.h"
#include <memory> #include <memory>
#include <string> #include <string>
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/math/cross_entropy.h" #include "paddle/fluid/operators/math/cross_entropy.h"
#include "paddle/fluid/operators/math/softmax.h" #include "paddle/fluid/operators/math/softmax.h"
#include "paddle/fluid/platform/device/npu/npu_op_runner.h" #include "paddle/fluid/platform/device/npu/npu_op_runner.h"
#include "paddle/phi/kernels/funcs/axis_utils.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
......
...@@ -12,20 +12,23 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ...@@ -12,20 +12,23 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include "paddle/fluid/operators/softmax_with_cross_entropy_op.h"
#ifdef PADDLE_WITH_XPU #ifdef PADDLE_WITH_XPU
#include <memory> #include <memory>
#include <string> #include <string>
#include <unordered_map> #include <unordered_map>
#include <vector> #include <vector>
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/platform/device/device_wrapper.h" #include "paddle/fluid/platform/device/device_wrapper.h"
#include "paddle/phi/kernels/funcs/axis_utils.h"
#include "xpu/refactor/math.h" #include "xpu/refactor/math.h"
#include "xpu/refactor/nn.h" #include "xpu/refactor/nn.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
using Tensor = framework::Tensor;
template <typename T> template <typename T>
class SoftmaxWithCrossEntropyXPUKernel : public framework::OpKernel<T> { class SoftmaxWithCrossEntropyXPUKernel : public framework::OpKernel<T> {
using XPUType = typename XPUTypeTrait<T>::Type; using XPUType = typename XPUTypeTrait<T>::Type;
......
...@@ -33,6 +33,8 @@ Backend TransToPhiBackend(const phi::Place& place) { ...@@ -33,6 +33,8 @@ Backend TransToPhiBackend(const phi::Place& place) {
return Backend::GPU; return Backend::GPU;
} else if (allocation_type == phi::AllocationType::XPU) { } else if (allocation_type == phi::AllocationType::XPU) {
return Backend::XPU; return Backend::XPU;
} else if (allocation_type == phi::AllocationType::NPU) {
return Backend::NPU;
} else if (allocation_type == phi::AllocationType::CUSTOM) { } else if (allocation_type == phi::AllocationType::CUSTOM) {
return static_cast<Backend>( return static_cast<Backend>(
static_cast<size_t>(Backend::NUM_BACKENDS) + static_cast<size_t>(Backend::NUM_BACKENDS) +
......
...@@ -27,7 +27,7 @@ kernel_library(full_kernel DEPS ${COMMON_KERNEL_DEPS} empty_kernel) ...@@ -27,7 +27,7 @@ kernel_library(full_kernel DEPS ${COMMON_KERNEL_DEPS} empty_kernel)
# Some kernels depend on some targets that are not commonly used. # Some kernels depend on some targets that are not commonly used.
# These targets are not suitable for common dependencies. # These targets are not suitable for common dependencies.
# In this case, you need to manually generate them here. # In this case, you need to manually generate them here.
set(MANUAL_BUILD_KERNELS adam_kernel adamw_kernel deformable_conv_kernel deformable_conv_grad_kernel eigh_kernel set(MANUAL_BUILD_KERNELS cross_entropy_kernel adam_kernel adamw_kernel deformable_conv_kernel deformable_conv_grad_kernel eigh_kernel
gumbel_softmax_kernel gumbel_softmax_grad_kernel hierarchical_sigmoid_kernel hierarchical_sigmoid_grad_kernel gumbel_softmax_kernel gumbel_softmax_grad_kernel hierarchical_sigmoid_kernel hierarchical_sigmoid_grad_kernel
matrix_power_kernel matrix_power_grad_kernel maxout_kernel maxout_grad_kernel pool_kernel matrix_power_kernel matrix_power_grad_kernel maxout_kernel maxout_grad_kernel pool_kernel
put_along_axis_kernel put_along_axis_grad_kernel segment_pool_kernel segment_pool_grad_kernel put_along_axis_kernel put_along_axis_grad_kernel segment_pool_kernel segment_pool_grad_kernel
...@@ -35,8 +35,10 @@ set(MANUAL_BUILD_KERNELS adam_kernel adamw_kernel deformable_conv_kernel deforma ...@@ -35,8 +35,10 @@ set(MANUAL_BUILD_KERNELS adam_kernel adamw_kernel deformable_conv_kernel deforma
triangular_solve_grad_kernel determinant_grad_kernel reduce_kernel rnn_kernel rnn_grad_kernel warpctc_kernel warpctc_grad_kernel) triangular_solve_grad_kernel determinant_grad_kernel reduce_kernel rnn_kernel rnn_grad_kernel warpctc_kernel warpctc_grad_kernel)
kernel_library(adam_kernel DEPS gflags glog flags ${COMMON_KERNEL_DEPS} selected_rows_functor threadpool jit_kernel_helper) kernel_library(adam_kernel DEPS gflags glog flags ${COMMON_KERNEL_DEPS} selected_rows_functor threadpool jit_kernel_helper)
kernel_library(adamw_kernel DEPS ${COMMON_KERNEL_DEPS} adam_kernel) kernel_library(adamw_kernel DEPS ${COMMON_KERNEL_DEPS} adam_kernel)
kernel_library(cross_entropy_kernel DEPS ${COMMON_KERNEL_DEPS} softmax cross_entropy)
kernel_library(deformable_conv_kernel DEPS ${COMMON_KERNEL_DEPS} deformable_conv_functor) kernel_library(deformable_conv_kernel DEPS ${COMMON_KERNEL_DEPS} deformable_conv_functor)
kernel_library(deformable_conv_grad_kernel DEPS ${COMMON_KERNEL_DEPS} deformable_conv_functor) kernel_library(deformable_conv_grad_kernel DEPS ${COMMON_KERNEL_DEPS} deformable_conv_functor)
kernel_library(determinant_grad_kernel DEPS ${COMMON_KERNEL_DEPS} matrix_inverse)
kernel_library(eigh_kernel DEPS ${COMMON_KERNEL_DEPS} lapack_function) kernel_library(eigh_kernel DEPS ${COMMON_KERNEL_DEPS} lapack_function)
kernel_library(hierarchical_sigmoid_kernel DEPS ${COMMON_KERNEL_DEPS} matrix_bit_code) kernel_library(hierarchical_sigmoid_kernel DEPS ${COMMON_KERNEL_DEPS} matrix_bit_code)
kernel_library(hierarchical_sigmoid_grad_kernel DEPS ${COMMON_KERNEL_DEPS} matrix_bit_code) kernel_library(hierarchical_sigmoid_grad_kernel DEPS ${COMMON_KERNEL_DEPS} matrix_bit_code)
...@@ -57,7 +59,6 @@ kernel_library(softmax_grad_kernel DEPS ${COMMON_KERNEL_DEPS} softmax) ...@@ -57,7 +59,6 @@ kernel_library(softmax_grad_kernel DEPS ${COMMON_KERNEL_DEPS} softmax)
kernel_library(take_along_axis_kernel DEPS ${COMMON_KERNEL_DEPS} gather_scatter_kernel) kernel_library(take_along_axis_kernel DEPS ${COMMON_KERNEL_DEPS} gather_scatter_kernel)
kernel_library(take_along_axis_grad_kernel DEPS ${COMMON_KERNEL_DEPS} gather_scatter_kernel) kernel_library(take_along_axis_grad_kernel DEPS ${COMMON_KERNEL_DEPS} gather_scatter_kernel)
kernel_library(triangular_solve_grad_kernel DEPS ${COMMON_KERNEL_DEPS} matrix_reduce) kernel_library(triangular_solve_grad_kernel DEPS ${COMMON_KERNEL_DEPS} matrix_reduce)
kernel_library(determinant_grad_kernel DEPS ${COMMON_KERNEL_DEPS} matrix_inverse)
kernel_library(rnn_kernel DEPS ${COMMON_KERNEL_DEPS} concat_and_split_functor lstm_compute gru_compute) kernel_library(rnn_kernel DEPS ${COMMON_KERNEL_DEPS} concat_and_split_functor lstm_compute gru_compute)
kernel_library(rnn_grad_kernel DEPS ${COMMON_KERNEL_DEPS} concat_and_split_functor lstm_compute gru_compute) kernel_library(rnn_grad_kernel DEPS ${COMMON_KERNEL_DEPS} concat_and_split_functor lstm_compute gru_compute)
kernel_library(warpctc_kernel DEPS ${COMMON_KERNEL_DEPS} phi_dynload_warpctc sequence_padding sequence_scale) kernel_library(warpctc_kernel DEPS ${COMMON_KERNEL_DEPS} phi_dynload_warpctc sequence_padding sequence_scale)
......
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/phi/kernels/cross_entropy_grad_kernel.h"
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/copy_kernel.h"
#include "paddle/phi/kernels/funcs/axis_utils.h"
#include "paddle/phi/kernels/funcs/eigen/common.h"
// TODO(chenweihang): move dispatch.h into phi/core
#include "paddle/phi/api/ext/dispatch.h"
namespace phi {
template <typename T, typename LabelT>
void CrossEntropyWithSoftmaxGradCPUKernel(const CPUContext& dev_ctx,
const DenseTensor& label,
const DenseTensor& softmax,
const DenseTensor& loss_grad,
bool soft_label,
bool use_softmax,
bool numeric_stable_mode,
int ignore_index,
int axis,
DenseTensor* logits_grad) {
const DenseTensor* out_grad = &loss_grad;
DenseTensor* logit_grad = logits_grad;
if (logit_grad != &softmax || !use_softmax) {
phi::Copy(dev_ctx, softmax, dev_ctx.GetPlace(), false, logit_grad);
}
const int rank = logit_grad->dims().size();
const int axis_v = phi::funcs::CanonicalAxis(axis, rank);
int axis_dim = logit_grad->dims()[axis_v];
PADDLE_ENFORCE_GT(
axis_dim,
0,
phi::errors::InvalidArgument(
"The axis dimention should be larger than 0, but received "
"axis dimention is %d.",
axis_dim));
const int n = phi::funcs::SizeToAxis(axis_v, logit_grad->dims());
PADDLE_ENFORCE_GT(
n,
0,
phi::errors::InvalidArgument(
"The size of axis should be larger than 0, but received "
"SizeToAxis of logit_grad is %d.",
n));
const int d = phi::funcs::SizeFromAxis(axis_v, logit_grad->dims());
DenseTensor logit_grad_2d(*logit_grad);
logit_grad_2d.Resize({n, d});
DenseTensor labels_2d(label);
labels_2d.Resize({n, label.numel() / n});
DenseTensor out_grad_2d(*out_grad);
out_grad_2d.Resize({n, d / axis_dim});
auto out_grad_mat = EigenMatrix<T>::From(out_grad_2d);
auto logit_grad_mat = EigenMatrix<T>::From(logit_grad_2d);
auto& place = *dev_ctx.eigen_device();
if (!use_softmax) {
// use_softmax step1
if (soft_label) {
auto lbl_mat = EigenMatrix<T>::From(labels_2d);
logit_grad_mat.device(place) =
(-lbl_mat / logit_grad_mat); // for each sample ,i is sample id
logit_grad_mat.device(place) =
out_grad_mat.broadcast(Eigen::DSizes<int, 2>(1, axis_dim)) *
logit_grad_mat;
} else {
// use_softmax step2
const auto* label_data = label.data<LabelT>();
T* logit_grad_data = logit_grad->data<T>();
const T* out_grad_data = out_grad->data<T>();
const int remain = d / axis_dim;
for (int i = 0; i < n; ++i) { // for each sample_1_dim
for (int j = 0; j < remain; j++) { // for each sample_other_dims
int idx = i * remain + j; // this sample's label_idx. for 1d case,
// remain=1 and j=0, so, idx = i
auto lbl = static_cast<int64_t>(label_data[idx]);
if (lbl == ignore_index) {
for (int k = 0; k < axis_dim; ++k) { // for each class id's label
logit_grad_data[i * d + k * remain + j] = 0;
}
} else {
// only for this sample's label_idx, the label is 1, others is 0,
// so, only compute this label_idx's class
logit_grad_data[i * d + lbl * remain + j] =
(-1 / logit_grad_data[i * d + lbl * remain + j]) *
out_grad_data[idx];
for (int k = 0; k < axis_dim; ++k) { // for each class id's label
if (k !=
label_data[idx]) { // label_data[idx]: this sample's label
logit_grad_data[i * d + k * remain + j] = 0;
}
}
}
}
}
}
return;
}
// for use_softmax=False, continue
if (soft_label) {
// when soft_label = True, ignore_index is not supported
auto lbl_mat = EigenMatrix<T>::From(labels_2d);
logit_grad_mat.device(place) =
out_grad_mat.broadcast(Eigen::DSizes<int, 2>(1, axis_dim)) *
(logit_grad_mat - lbl_mat);
// for each sample, i is sample id
// 1) compute dy/dx by p_j - y_j or P-Y, where j is class id,
// P=logit_grad_mat[i] is all class's probs, Y=lbl_mat[i] is
// all class's label
// 2) compute dy * dy/dx by Chain rule, dy=out_grad_mat[i]
// for high dims, e.g. (n,c) or (n,d1,...,dm, c), compute grad by matrix
// operation
} else {
logit_grad_mat.device(place) =
logit_grad_mat * // element_wise multiply
out_grad_mat.broadcast(Eigen::DSizes<int, 2>(1, axis_dim));
const auto* label_data = label.data<LabelT>();
T* logit_grad_data = logit_grad->data<T>();
const T* out_grad_data = out_grad->data<T>();
const int remain = d / axis_dim;
for (int i = 0; i < n; ++i) { // for each sample_1_dim
for (int j = 0; j < remain; j++) { // for each sample_other_dims
int idx = i * remain + j; // this sample's label_idx. for 1d case,
// remain=1 and j=0, so, idx = i
auto lbl = static_cast<int64_t>(label_data[idx]);
if (lbl == ignore_index) {
for (int k = 0; k < axis_dim; ++k) { // for each class id's label
logit_grad_data[i * d + k * remain + j] = 0;
}
} else {
// only for this sample's label_idx, the label is 1, others is 0,
// so, only compute this label_idx's class
// for 1d case, remain=1 and j=0, so, [i * d + label_data[idx] *
// remain + j] = [i * d + label_data[idx]]
// let idx_x = i * d + label_data[idx] * remain + j,
// logit_grad_data[idx_x] = logit_grad_data[idx_x] -
// out_grad_data[idx]
// note: logit_grad_mat = logit_grad_mat * out_grad_mat
// so: logit_grad_data[idx_x] = (logit_grad_data[idx_x] - 1) *
// out_grad_data[idx]
// means: dy/dp * dy= ( p - y ) * dy
logit_grad_data[i * d + lbl * remain + j] -= out_grad_data[idx];
}
}
}
}
}
template <typename T, typename Context>
void CrossEntropyWithSoftmaxGradKernel(const Context& dev_ctx,
const DenseTensor& label,
const DenseTensor& softmax,
const DenseTensor& loss_grad,
bool soft_label,
bool use_softmax,
bool numeric_stable_mode,
int ignore_index,
int axis,
DenseTensor* logits_grad) {
auto dtype = label.dtype();
if (soft_label) {
PADDLE_ENFORCE_EQ(
dtype,
paddle::experimental::CppTypeToDataType<T>::Type(),
phi::errors::InvalidArgument("The Input(Label) should be with the "
"same data type as kernel data type."));
CrossEntropyWithSoftmaxGradCPUKernel<T, T>(dev_ctx,
label,
softmax,
loss_grad,
soft_label,
use_softmax,
numeric_stable_mode,
ignore_index,
axis,
logits_grad);
} else {
PD_DISPATCH_INTEGRAL_TYPES(
dtype, "CrossEntropyWithSoftmaxGradCPUKernel", ([&] {
CrossEntropyWithSoftmaxGradCPUKernel<T, data_t>(dev_ctx,
label,
softmax,
loss_grad,
soft_label,
use_softmax,
numeric_stable_mode,
ignore_index,
axis,
logits_grad);
}));
}
}
} // namespace phi
PD_REGISTER_KERNEL(cross_entropy_with_softmax_grad,
CPU,
ALL_LAYOUT,
phi::CrossEntropyWithSoftmaxGradKernel,
float,
double) {}
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/phi/kernels/cross_entropy_kernel.h"
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/copy_kernel.h"
#include "paddle/phi/kernels/funcs/axis_utils.h"
#include "paddle/phi/kernels/funcs/math_function.h"
#include "paddle/phi/kernels/softmax_kernel.h"
#include "paddle/fluid/operators/math/cross_entropy.h"
namespace phi {
template <typename T>
void CrossEntropy(const CPUContext& dev_ctx,
const DenseTensor& x,
const DenseTensor& label,
bool soft_label,
int ignore_index,
int axis,
DenseTensor* out) {
const int rank = x.dims().size();
const int axis_v = phi::funcs::CanonicalAxis(axis, rank);
int axis_dim = x.dims()[axis_v];
PADDLE_ENFORCE_GT(
axis_dim,
0,
phi::errors::InvalidArgument(
"The axis dimention should be larger than 0, but received "
"axis dimention is %d.",
axis_dim));
dev_ctx.template Alloc<T>(out);
const int n = phi::funcs::SizeToAxis(axis_v, x.dims());
PADDLE_ENFORCE_GT(
n,
0,
phi::errors::InvalidArgument(
"The size of axis should be larger than 0, but received "
"SizeToAxis of softmax is %d.",
n));
const int d = phi::funcs::SizeFromAxis(axis_v, x.dims());
DenseTensor x_2d(x);
x_2d.Resize({n, d});
DenseTensor label_2d(label);
label_2d.Resize({n, label.numel() / n});
DenseTensor out_2d(*out);
out_2d.Resize({n, d / axis_dim});
paddle::operators::math::CrossEntropyFunctor<CPUContext, T>()(
dev_ctx, &out_2d, &x_2d, &label_2d, soft_label, ignore_index, axis_dim);
}
template <typename T, typename Context>
void CrossEntropyWithSoftmaxKernel(const Context& dev_ctx,
const DenseTensor& logits,
const DenseTensor& label,
bool soft_label,
bool use_softmax,
bool numeric_stable_mode,
int ignore_index,
int axis,
DenseTensor* softmax,
DenseTensor* loss) {
// do not with softmax op, and input is softmax
if (!use_softmax) {
CrossEntropy<T>(
dev_ctx, logits, label, soft_label, ignore_index, axis, loss);
// cause of input is softmax, copy to output softmax, directly
phi::Copy<Context>(dev_ctx, logits, dev_ctx.GetPlace(), false, softmax);
return;
}
phi::SoftmaxKernel<T, Context>(dev_ctx, logits, axis, softmax);
CrossEntropy<T>(
dev_ctx, *softmax, label, soft_label, ignore_index, axis, loss);
}
} // namespace phi
PD_REGISTER_KERNEL(cross_entropy_with_softmax,
CPU,
ALL_LAYOUT,
phi::CrossEntropyWithSoftmaxKernel,
float,
double) {}
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/phi/core/dense_tensor.h"
namespace phi {
template <typename T, typename Context>
void CrossEntropyWithSoftmaxGradKernel(const Context& dev_ctx,
const DenseTensor& label,
const DenseTensor& softmax,
const DenseTensor& loss_grad,
bool soft_label,
bool use_softmax,
bool numeric_stable_mode,
int ignore_index,
int axis,
DenseTensor* logits_grad);
} // namespace phi
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/phi/core/dense_tensor.h"
namespace phi {
// The deformed product of operator iterative upgrade, there is no strict 2.0
// API corresponding to it! In 2.0 API paddle.nn.functional.cross_entropy,
// use_softmax has become an optional argument, which may be called
// CrossEntropyWithSoftmax more accurately, here we keep this kernel arguments
// same as original OpMaker, and if need a CrossEntropyKernel like
// paddle.nn.functional.cross_entropy, we can reuse this kernel
template <typename T, typename Context>
void CrossEntropyWithSoftmaxKernel(const Context& dev_ctx,
const DenseTensor& logits,
const DenseTensor& label,
bool soft_label,
bool use_softmax,
bool numeric_stable_mode,
int ignore_index,
int axis,
DenseTensor* softmax,
DenseTensor* loss);
} // namespace phi
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/phi/kernels/cross_entropy_grad_kernel.h"
#ifdef __NVCC__
#include "cub/cub.cuh"
#endif
#ifdef __HIPCC__
#include <hipcub/hipcub.hpp>
namespace cub = hipcub;
#endif
#include "paddle/phi/common/amp_type_traits.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/copy_kernel.h"
#include "paddle/phi/kernels/funcs/axis_utils.h"
#include "paddle/phi/kernels/funcs/for_range.h"
#include "paddle/phi/kernels/funcs/math_function.h"
#include "paddle/phi/kernels/gpudnn/softmax_gpudnn.h"
// TODO(chenweihang): move dispatch.h into phi/core
#include "paddle/phi/api/ext/dispatch.h"
#include "paddle/fluid/operators/math/cross_entropy.h"
#include "paddle/fluid/operators/math/softmax.h"
#include "paddle/fluid/platform/device/gpu/gpu_device_function.h"
#include "paddle/fluid/platform/device/gpu/gpu_dnn.h"
namespace phi {
template <typename T>
__global__ void SoftLabelCrossEntropyGradientKernel(T* logit_grad,
const T* loss_grad,
const T* labels,
const int n,
const int d,
const int remain) {
int ids = blockIdx.x * blockDim.x + threadIdx.x;
if (ids < n * d) {
int idx_n = ids / d;
int idx_remain = ids % remain;
int idx_loss = idx_n * remain + idx_remain;
logit_grad[ids] = loss_grad[idx_loss] * (-labels[ids] / logit_grad[ids]);
}
}
template <typename T, typename LabelT>
__global__ void HardLabelCrossEntropyGradientKernel(T* logit_grad,
const LabelT* labels,
const int n,
const int d,
const int remain,
const int ignore_index) {
CUDA_KERNEL_LOOP(index, n * remain) {
int idx_n = index / remain;
int idx_remain = index % remain;
int tmp = static_cast<int>(labels[index]);
int idx = idx_n * d + tmp * remain + idx_remain;
if (ignore_index != tmp) {
logit_grad[idx] = -static_cast<T>(1.) / logit_grad[idx];
}
}
}
template <typename T, typename LabelT>
__global__ void ScaleCrossEntropyGradient(T* logit_grad,
const T* loss_grad,
const int num,
const int d,
const int remain,
const LabelT* labels,
const int ignore_index) {
CUDA_KERNEL_LOOP(index, num) {
int idx_n = index / d;
int idx_remain = index % remain;
int idx_lbl = idx_n * remain + idx_remain;
int k = (index % d) / remain;
auto lbl = static_cast<int64_t>(labels[idx_lbl]);
if (lbl == ignore_index || lbl != k) {
logit_grad[index] = static_cast<T>(0.);
} else {
logit_grad[index] *= loss_grad[idx_lbl];
}
}
}
template <typename T>
__global__ void SoftCrossEntropyGradientKernel(T* logit_grad,
const T* loss_grad,
const T* labels,
const int64_t n,
const int64_t d,
const int64_t remain) {
int64_t ids = blockIdx.x * blockDim.x + threadIdx.x;
if (ids < n * d) {
int64_t idx_n = ids / d;
int64_t idx_remain = ids % remain;
int64_t idx_loss = idx_n * remain + idx_remain;
logit_grad[ids] = loss_grad[idx_loss] * (logit_grad[ids] - labels[ids]);
}
}
/*
Wrapper of softmax with cross entropy grad hard label.
*/
template <typename T, typename LabelT>
__global__ void SoftmaxWithCrossEntropyGradHardLabel(T* logits_grad,
const T* loss_grad,
const T* softmax,
const LabelT* labels,
const int64_t n,
const int64_t dim,
const int64_t d,
const int ignore_index) {
int64_t idx = blockIdx.x * blockDim.x + threadIdx.x;
int64_t idx_n = idx / (d * dim);
int64_t idx_dim = (idx / d) % dim;
int64_t idx_d = idx % d;
int64_t ids = idx_n * d + idx_d;
if (idx < n * dim * d) {
auto lbl = static_cast<int64_t>(labels[ids]);
if (lbl == ignore_index) {
logits_grad[idx] = static_cast<T>(0.0);
} else if (lbl == idx_dim) {
logits_grad[idx] = (softmax[idx] - static_cast<T>(1.0)) * loss_grad[ids];
} else {
logits_grad[idx] = softmax[idx] * loss_grad[ids];
}
}
}
template <typename T, typename LabelT>
void CrossEntropyWithSoftmaxGradGPUKernel(const GPUContext& dev_ctx,
const DenseTensor& label,
const DenseTensor& softmax,
const DenseTensor& loss_grad,
bool soft_label,
bool use_softmax,
bool numeric_stable_mode,
int ignore_index,
int axis,
DenseTensor* logits_grad) {
PADDLE_ENFORCE_EQ(
dev_ctx.GetPlace().GetType(),
phi::AllocationType::GPU,
phi::errors::Unavailable("softmax_with_cross_entropy operator's "
"CUDA kernel only runs on GPU device."));
const T* loss_grad_data = loss_grad.data<T>();
DenseTensor* logit_grad = logits_grad;
T* logit_grad_data = nullptr;
bool copy_flag = (logit_grad != &softmax && (!use_softmax || soft_label));
if (copy_flag) {
phi::Copy(dev_ctx, softmax, dev_ctx.GetPlace(), false, logit_grad);
logit_grad_data = logit_grad->data<T>();
} else {
logit_grad_data = dev_ctx.template Alloc<T>(logit_grad);
}
const int rank = logit_grad->dims().size();
const int axis_v = phi::funcs::CanonicalAxis(axis, rank);
int axis_dim = logit_grad->dims()[axis_v];
const int64_t n = phi::funcs::SizeToAxis(axis_v, logit_grad->dims());
const int64_t d = phi::funcs::SizeFromAxis(axis_v, logit_grad->dims());
const int64_t remain = d / axis_dim;
#ifdef __HIPCC__
int block = 256;
#else
int block = 512;
#endif
auto stream = dev_ctx.stream();
// do not with softmax op, and input is softmax
if (!use_softmax) {
if (soft_label) {
int grid = (n * d + block - 1) / block;
const T* label_data = label.data<T>();
SoftLabelCrossEntropyGradientKernel<T><<<grid, block, 0, stream>>>(
logit_grad_data, loss_grad_data, label_data, n, d, remain);
} else {
DenseTensor logits_grad_2d(*logit_grad);
logits_grad_2d.Resize({n, d});
int grid = (n * remain + block - 1) / block;
const auto* label_data = label.data<LabelT>();
HardLabelCrossEntropyGradientKernel<T,
LabelT><<<grid, block, 0, stream>>>(
logit_grad_data, label_data, n, d, remain, ignore_index);
int num = n * d;
grid = (num + block - 1) / block;
ScaleCrossEntropyGradient<T, LabelT><<<grid, block, 0, stream>>>(
logit_grad_data,
loss_grad_data,
num,
d,
remain,
label_data,
ignore_index);
}
return;
}
// with softmax, continue
if (soft_label) {
int64_t grid = (n * d + block - 1) / block;
const T* label_data = label.data<T>();
SoftCrossEntropyGradientKernel<T><<<grid, block, 0, stream>>>(
logit_grad_data, loss_grad_data, label_data, n, d, remain);
} else {
const T* softmax_data = softmax.data<T>();
const auto* label_data = label.data<LabelT>();
int grid = (n * d + block - 1) / block;
SoftmaxWithCrossEntropyGradHardLabel<T><<<grid, block, 0, stream>>>(
logit_grad_data,
loss_grad_data,
softmax_data,
label_data,
n,
d / remain,
remain,
ignore_index);
}
}
template <typename T, typename Context>
void CrossEntropyWithSoftmaxGradKernel(const Context& dev_ctx,
const DenseTensor& label,
const DenseTensor& softmax,
const DenseTensor& loss_grad,
bool soft_label,
bool use_softmax,
bool numeric_stable_mode,
int ignore_index,
int axis,
DenseTensor* logits_grad) {
auto dtype = label.dtype();
if (soft_label) {
PADDLE_ENFORCE_EQ(
dtype,
paddle::experimental::CppTypeToDataType<T>::Type(),
phi::errors::InvalidArgument("The Input(Label) should be with the "
"same data type as kernel data type."));
CrossEntropyWithSoftmaxGradGPUKernel<T, T>(dev_ctx,
label,
softmax,
loss_grad,
soft_label,
use_softmax,
numeric_stable_mode,
ignore_index,
axis,
logits_grad);
} else {
PD_DISPATCH_INTEGRAL_TYPES(
dtype, "CrossEntropyWithSoftmaxGradGPUKernel", ([&] {
CrossEntropyWithSoftmaxGradGPUKernel<T, data_t>(dev_ctx,
label,
softmax,
loss_grad,
soft_label,
use_softmax,
numeric_stable_mode,
ignore_index,
axis,
logits_grad);
}));
}
}
} // namespace phi
PD_REGISTER_KERNEL(cross_entropy_with_softmax_grad,
GPU,
ALL_LAYOUT,
phi::CrossEntropyWithSoftmaxGradKernel,
float,
double,
phi::dtype::float16) {}
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/core/compat/op_utils.h"
namespace phi {
KernelSignature SoftmaxWithCrossEntropyOpArgumentMapping(
const ArgumentMappingContext& ctx) {
return KernelSignature("cross_entropy_with_softmax",
{"Logits", "Label"},
{"soft_label",
"use_softmax",
"numeric_stable_mode",
"ignore_index",
"axis"},
{"Softmax", "Loss"});
}
KernelSignature SoftmaxWithCrossEntropyGradOpArgumentMapping(
const ArgumentMappingContext& ctx) {
return KernelSignature("cross_entropy_with_softmax_grad",
{"Label", "Softmax", GradVarName("Loss")},
{"soft_label",
"use_softmax",
"numeric_stable_mode",
"ignore_index",
"axis"},
{GradVarName("Logits")});
}
} // namespace phi
PD_REGISTER_BASE_KERNEL_NAME(softmax_with_cross_entropy,
cross_entropy_with_softmax);
PD_REGISTER_BASE_KERNEL_NAME(softmax_with_cross_entropy_grad,
cross_entropy_with_softmax_grad);
PD_REGISTER_ARG_MAPPING_FN(softmax_with_cross_entropy,
phi::SoftmaxWithCrossEntropyOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(softmax_with_cross_entropy_grad,
phi::SoftmaxWithCrossEntropyGradOpArgumentMapping);
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册