diff --git a/paddle/fluid/framework/new_executor/standalone_executor_test.cc b/paddle/fluid/framework/new_executor/standalone_executor_test.cc index 8a3b40bbd76efde76dd8594afd77d4d02db59154..b5670565e2a64b55983e4c85faa85fff9d7f48f0 100644 --- a/paddle/fluid/framework/new_executor/standalone_executor_test.cc +++ b/paddle/fluid/framework/new_executor/standalone_executor_test.cc @@ -35,7 +35,7 @@ USE_OP_ITSELF(elementwise_add); USE_OP_ITSELF(sigmoid); USE_OP_ITSELF(tanh); USE_OP_ITSELF(elementwise_mul); -USE_OP(softmax_with_cross_entropy); +USE_OP_ITSELF(softmax_with_cross_entropy); USE_OP_ITSELF(reduce_mean); USE_OP_ITSELF(reduce_sum); USE_OP_ITSELF(reduce_sum_grad); @@ -83,6 +83,8 @@ PD_DECLARE_KERNEL(max_raw, GPU, ALL_LAYOUT); PD_DECLARE_KERNEL(sgd, GPU, ALL_LAYOUT); PD_DECLARE_KERNEL(slice, GPU, ALL_LAYOUT); PD_DECLARE_KERNEL(slice_grad, GPU, ALL_LAYOUT); +PD_DECLARE_KERNEL(cross_entropy_with_softmax, GPU, ALL_LAYOUT); +PD_DECLARE_KERNEL(cross_entropy_with_softmax_grad, GPU, ALL_LAYOUT); PD_DECLARE_KERNEL(sqrt, GPU, ALL_LAYOUT); DECLARE_double(eager_delete_tensor_gb); diff --git a/paddle/fluid/framework/phi_utils.cc b/paddle/fluid/framework/phi_utils.cc index 82c2c339311e6f59c7db9223607653638b298b89..8e6f082da10267268855c5fa467f10e49a6692de 100644 --- a/paddle/fluid/framework/phi_utils.cc +++ b/paddle/fluid/framework/phi_utils.cc @@ -87,7 +87,7 @@ phi::KernelKey TransOpKernelTypeToPhiKernelKey( } else if (kernel_type.library_type_ == LibraryType::kKP) { backend = phi::Backend::KPS; } else { - // do + // do nothing } paddle::experimental::DataLayout layout = kernel_type.data_layout_; paddle::experimental::DataType dtype = diff --git a/paddle/fluid/imperative/prepared_operator.cc b/paddle/fluid/imperative/prepared_operator.cc index d248715f00c2ba7dddb24a79450f76cd45cfbf5f..077dd54bc9fa59b5763faf16bcd7d2bb6efce833 100644 --- a/paddle/fluid/imperative/prepared_operator.cc +++ b/paddle/fluid/imperative/prepared_operator.cc @@ -484,6 +484,11 @@ static void PreparedOpRunPtImpl( pt_kernel(&pt_kernel_context); } + if (FLAGS_check_nan_inf) { + framework::details::CheckOpHasNanOrInfInDygraph( + op.Type(), outs, dev_ctx->GetPlace()); + } + if (FLAGS_benchmark) { dev_ctx->Wait(); #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) diff --git a/paddle/fluid/operators/math/cross_entropy.cc b/paddle/fluid/operators/math/cross_entropy.cc index 0b0584608a3006819f1e7b1475ab2fb0e66f5d13..cb2f59182c11125d5c30d589ccf88a288d1b53eb 100644 --- a/paddle/fluid/operators/math/cross_entropy.cc +++ b/paddle/fluid/operators/math/cross_entropy.cc @@ -14,6 +14,7 @@ limitations under the License. */ #include "paddle/fluid/operators/math/cross_entropy.h" #include "paddle/fluid/framework/convert_utils.h" +#include "paddle/phi/backends/cpu/cpu_context.h" namespace paddle { namespace platform { @@ -89,38 +90,38 @@ struct HardLabelCrossEntropyCPUFunctorImpl { const int axis_dim_; }; -template -class CrossEntropyFunctor { - public: - void operator()(const platform::CPUDeviceContext& ctx, framework::Tensor* out, - const framework::Tensor* prob, - const framework::Tensor* labels, const bool softLabel, - const int ignore_index, const int axis_dim) { - if (softLabel) { - const int batch_size = prob->dims()[0]; - const int num_classes = prob->dims()[1]; - const int num_remain = num_classes / axis_dim; - - Eigen::DSizes batch_axis_remain(batch_size, axis_dim, num_remain); - auto in = EigenMatrix::From(*prob); - auto lbl = EigenMatrix::From(*labels); - auto loss = EigenMatrix::From(*out); - - loss.device(*ctx.eigen_device()) = - -((lbl * in.log().unaryExpr(math::TolerableValue())) - .reshape(batch_axis_remain) - .sum(Eigen::DSizes(1))); - } else { - HardLabelCrossEntropyCPUFunctorImpl functor_impl( - out, prob, labels, ignore_index, axis_dim); - framework::VisitIntDataType( - framework::TransToProtoVarType(labels->dtype()), functor_impl); - } +template +void CrossEntropyFunctor::operator()( + const DeviceContext& ctx, framework::Tensor* out, + const framework::Tensor* prob, const framework::Tensor* labels, + const bool softLabel, const int ignore_index, const int axis_dim) { + if (softLabel) { + const int batch_size = prob->dims()[0]; + const int num_classes = prob->dims()[1]; + const int num_remain = num_classes / axis_dim; + + Eigen::DSizes batch_axis_remain(batch_size, axis_dim, num_remain); + auto in = EigenMatrix::From(*prob); + auto lbl = EigenMatrix::From(*labels); + auto loss = EigenMatrix::From(*out); + + loss.device(*ctx.eigen_device()) = + -((lbl * in.log().unaryExpr(math::TolerableValue())) + .reshape(batch_axis_remain) + .sum(Eigen::DSizes(1))); + } else { + HardLabelCrossEntropyCPUFunctorImpl functor_impl(out, prob, labels, + ignore_index, axis_dim); + framework::VisitIntDataType(framework::TransToProtoVarType(labels->dtype()), + functor_impl); } -}; +} template class CrossEntropyFunctor; template class CrossEntropyFunctor; + +template class CrossEntropyFunctor; +template class CrossEntropyFunctor; } // namespace math } // namespace operators } // namespace paddle diff --git a/paddle/fluid/operators/math/cross_entropy.cu b/paddle/fluid/operators/math/cross_entropy.cu index 829ac9fb559645e36a1ef76bed435710aa5279cf..80e06d4b7f688a4b888cb7c266efaf9df0950802 100644 --- a/paddle/fluid/operators/math/cross_entropy.cu +++ b/paddle/fluid/operators/math/cross_entropy.cu @@ -17,6 +17,7 @@ limitations under the License. */ #include "paddle/fluid/operators/math/cross_entropy.h" #include "paddle/fluid/platform/device/gpu/gpu_device_function.h" #include "paddle/fluid/platform/device/gpu/gpu_primitives.h" +#include "paddle/phi/backends/gpu/gpu_context.h" namespace paddle { namespace operators { @@ -93,46 +94,48 @@ struct HardLabelCrossEntropyCUDAFunctorImpl { gpuStream_t stream_; }; -template -class CrossEntropyFunctor { - public: - void operator()(const platform::CUDADeviceContext& ctx, - framework::Tensor* out, const framework::Tensor* prob, - const framework::Tensor* labels, const bool softLabel, - const int ignore_index, const int axis_dim) { - const T* prob_data = prob->data(); - T* loss_data = out->mutable_data(ctx.GetPlace()); - - int batch_size = prob->dims()[0]; - int class_num = prob->dims()[1]; +template +void CrossEntropyFunctor::operator()( + const DeviceContext& ctx, framework::Tensor* out, + const framework::Tensor* prob, const framework::Tensor* labels, + const bool softLabel, const int ignore_index, const int axis_dim) { + const T* prob_data = prob->data(); + T* loss_data = out->mutable_data(ctx.GetPlace()); + + int batch_size = prob->dims()[0]; + int class_num = prob->dims()[1]; #ifdef __HIPCC__ - constexpr int kMaxBlockDim = 256; + constexpr int kMaxBlockDim = 256; #else - constexpr int kMaxBlockDim = 512; + constexpr int kMaxBlockDim = 512; #endif - if (softLabel) { - const T* label_data = labels->data(); - int block = class_num > kMaxBlockDim - ? kMaxBlockDim - : pow(2, static_cast(std::log2(class_num))); - - SoftCrossEntropyKernel<<>>( - loss_data, prob_data, label_data, class_num); - } else { - HardLabelCrossEntropyCUDAFunctorImpl functor( - loss_data, prob_data, labels->data(), batch_size, class_num, - ignore_index, kMaxBlockDim, ctx.stream()); - framework::VisitDataType(framework::TransToProtoVarType(labels->dtype()), - functor); - } + if (softLabel) { + const T* label_data = labels->data(); + int block = class_num > kMaxBlockDim + ? kMaxBlockDim + : pow(2, static_cast(std::log2(class_num))); + + SoftCrossEntropyKernel<<>>( + loss_data, prob_data, label_data, class_num); + } else { + HardLabelCrossEntropyCUDAFunctorImpl functor( + loss_data, prob_data, labels->data(), batch_size, class_num, + ignore_index, kMaxBlockDim, ctx.stream()); + framework::VisitDataType(framework::TransToProtoVarType(labels->dtype()), + functor); } -}; +} template class CrossEntropyFunctor; template class CrossEntropyFunctor; template class CrossEntropyFunctor; + +template class CrossEntropyFunctor; +template class CrossEntropyFunctor; +template class CrossEntropyFunctor; + } // namespace math } // namespace operators } // namespace paddle diff --git a/paddle/fluid/operators/math/softmax.cu b/paddle/fluid/operators/math/softmax.cu index 83b124902ebb74e65af0a25e432ff6b488e5cee1..e960dc8a60832fb847a71e412e097feb1098e020 100644 --- a/paddle/fluid/operators/math/softmax.cu +++ b/paddle/fluid/operators/math/softmax.cu @@ -29,9 +29,9 @@ using DataLayout = platform::DataLayout; template using CudnnDataType = platform::CudnnDataType; -template -void SoftmaxCUDNNFunctor::operator()( - const platform::CUDADeviceContext& context, const framework::Tensor* X, +template +void SoftmaxCUDNNFunctor::operator()( + const DeviceContext& context, const framework::Tensor* X, framework::Tensor* Y) { // ------------------- cudnn descriptors --------------------- ScopedTensorDescriptor xDesc; @@ -69,9 +69,9 @@ void SoftmaxCUDNNFunctor::operator()( #endif } -template -void SoftmaxGradCUDNNFunctor::operator()( - const platform::CUDADeviceContext& context, const framework::Tensor* Y, +template +void SoftmaxGradCUDNNFunctor::operator()( + const DeviceContext& context, const framework::Tensor* Y, const framework::Tensor* YGrad, framework::Tensor* XGrad) { // ------------------- cudnn descriptors --------------------- ScopedTensorDescriptor yDesc; @@ -116,19 +116,31 @@ void SoftmaxGradCUDNNFunctor::operator()( #endif } -template class SoftmaxCUDNNFunctor; -template class SoftmaxCUDNNFunctor; -template class SoftmaxGradCUDNNFunctor; -template class SoftmaxGradCUDNNFunctor; +template class SoftmaxCUDNNFunctor; +template class SoftmaxCUDNNFunctor; +template class SoftmaxGradCUDNNFunctor; +template class SoftmaxGradCUDNNFunctor; +template class SoftmaxCUDNNFunctor; +template class SoftmaxCUDNNFunctor; +template class SoftmaxGradCUDNNFunctor; +template class SoftmaxGradCUDNNFunctor; #if CUDNN_VERSION_MIN(8, 1, 0) -template class SoftmaxCUDNNFunctor; -template class SoftmaxGradCUDNNFunctor; +template class SoftmaxCUDNNFunctor; +template class SoftmaxGradCUDNNFunctor; +template class SoftmaxCUDNNFunctor; +template class SoftmaxGradCUDNNFunctor; #endif // MIOPEN do not support double #ifndef PADDLE_WITH_HIP -template class SoftmaxCUDNNFunctor; -template class SoftmaxGradCUDNNFunctor; +template class SoftmaxCUDNNFunctor; +template class SoftmaxGradCUDNNFunctor; +template class SoftmaxCUDNNFunctor; +template class SoftmaxGradCUDNNFunctor; #endif template class SoftmaxFunctor +template class SoftmaxCUDNNFunctor { public: - void operator()(const platform::CUDADeviceContext& context, - const framework::Tensor* X, framework::Tensor* Y); + void operator()(const DeviceContext& context, const framework::Tensor* X, + framework::Tensor* Y); }; -template +template class SoftmaxGradCUDNNFunctor { public: - void operator()(const platform::CUDADeviceContext& context, - const framework::Tensor* Y, const framework::Tensor* y_grad, - framework::Tensor* x_grad); + void operator()(const DeviceContext& context, const framework::Tensor* Y, + const framework::Tensor* y_grad, framework::Tensor* x_grad); }; #endif diff --git a/paddle/fluid/operators/sequence_ops/sequence_softmax_cudnn_op.cu.cc b/paddle/fluid/operators/sequence_ops/sequence_softmax_cudnn_op.cu.cc index 57064301d7afb6ad1403a17f7a0c3a25ded1ca07..976c10d0f433f41fa670330763f954b5dc48410b 100644 --- a/paddle/fluid/operators/sequence_ops/sequence_softmax_cudnn_op.cu.cc +++ b/paddle/fluid/operators/sequence_ops/sequence_softmax_cudnn_op.cu.cc @@ -58,7 +58,7 @@ class SequenceSoftmaxCUDNNKernel : public framework::OpKernel { phi::make_ddim({1UL, end_pos - start_pos}); x_i.Resize(dims_i); out_i.Resize(dims_i); - math::SoftmaxCUDNNFunctor()( + math::SoftmaxCUDNNFunctor()( ctx.template device_context(), &x_i, &out_i); } @@ -93,7 +93,7 @@ class SequenceSoftmaxGradCUDNNKernel : public framework::OpKernel { out_i.Resize(dims_i); out_grad_i.Resize(dims_i); x_grad_i.Resize(dims_i); - math::SoftmaxGradCUDNNFunctor()( + math::SoftmaxGradCUDNNFunctor()( ctx.template device_context(), &out_i, &out_grad_i, &x_grad_i); } diff --git a/paddle/fluid/operators/softmax_with_cross_entropy_op.cc b/paddle/fluid/operators/softmax_with_cross_entropy_op.cc index 6f0881e9fc98f6c1ce6c7535c9c68a2fe64e2241..22b592c1eb62aad19dc9ea0e1e71b6ca70c941b2 100644 --- a/paddle/fluid/operators/softmax_with_cross_entropy_op.cc +++ b/paddle/fluid/operators/softmax_with_cross_entropy_op.cc @@ -12,8 +12,9 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#include "paddle/fluid/operators/softmax_with_cross_entropy_op.h" +#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_version_registry.h" +#include "paddle/phi/kernels/funcs/axis_utils.h" namespace paddle { namespace operators { @@ -335,12 +336,6 @@ REGISTER_OPERATOR(softmax_with_cross_entropy, ops::SoftmaxWithCrossEntropyOp, REGISTER_OPERATOR(softmax_with_cross_entropy_grad, ops::SoftmaxWithCrossEntropyOpGrad, ops::SoftmaxWithCrossEntropyGradInplaceInferer); -REGISTER_OP_CPU_KERNEL(softmax_with_cross_entropy, - ops::SoftmaxWithCrossEntropyKernel, - ops::SoftmaxWithCrossEntropyKernel); -REGISTER_OP_CPU_KERNEL(softmax_with_cross_entropy_grad, - ops::SoftmaxWithCrossEntropyGradKernel, - ops::SoftmaxWithCrossEntropyGradKernel); REGISTER_OP_VERSION(softmax_with_cross_entropy) #if defined(PADDLE_WITH_ASCEND_CL) || defined(PADDLE_WITH_MLU) diff --git a/paddle/fluid/operators/softmax_with_cross_entropy_op.h b/paddle/fluid/operators/softmax_with_cross_entropy_op.h deleted file mode 100644 index 4b875cbf5841f661b55e668808051c8928b45cdd..0000000000000000000000000000000000000000 --- a/paddle/fluid/operators/softmax_with_cross_entropy_op.h +++ /dev/null @@ -1,318 +0,0 @@ -/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#pragma once -#include "paddle/fluid/framework/convert_utils.h" -#include "paddle/fluid/framework/eigen.h" -#include "paddle/fluid/framework/op_registry.h" -#include "paddle/fluid/operators/math/cross_entropy.h" -#include "paddle/fluid/operators/math/softmax.h" -#include "paddle/phi/kernels/funcs/axis_utils.h" - -namespace paddle { -namespace operators { - -using Tensor = framework::Tensor; - -template -struct SoftmaxWithCrossEntropyFunctor { - public: - SoftmaxWithCrossEntropyFunctor(const framework::ExecutionContext& context, - const framework::Tensor& labels, - const bool soft_label, const Visitor& visitor) - : context_(context), - labels_(labels), - soft_label_(soft_label), - visitor_(visitor) {} - - template - void apply() const { - visitor_.template Apply(context_, labels_, soft_label_); - } - - private: - const framework::ExecutionContext& context_; - const framework::Tensor& labels_; - const bool soft_label_; - const Visitor& visitor_; -}; - -template -static void RunSoftmaxWithCrossEntropyFunctor( - const framework::ExecutionContext& context, const Visitor& visitor) { - const auto* labels = context.Input("Label"); - const bool soft_label = context.Attr("soft_label"); - SoftmaxWithCrossEntropyFunctor functor(context, *labels, - soft_label, visitor); - auto dtype = framework::TransToProtoVarType(labels->dtype()); - if (soft_label) { - PADDLE_ENFORCE_EQ( - dtype, framework::DataTypeTrait::DataType(), - platform::errors::InvalidArgument("The Input(Label) should be with the " - "same data type as Input(Logits).")); - functor.template apply(); - } else { - framework::VisitIntDataType(dtype, functor); - } -} - -template -class SoftmaxWithCrossEntropyKernel : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext& context) const override { - PADDLE_ENFORCE_EQ( - platform::is_cpu_place(context.GetPlace()), true, - platform::errors::Unimplemented("This kernel only runs on CPU.")); - const bool use_softmax = context.Attr("use_softmax"); - const Tensor* labels = context.Input("Label"); - const bool soft_label = context.Attr("soft_label"); - - // do not with softmax op, and input is softmax - if (!use_softmax) { - const Tensor* softmax = context.Input("Logits"); - Tensor* softmax_out = context.Output("Softmax"); - Tensor* loss = context.Output("Loss"); - const int rank = softmax->dims().size(); - const int axis = - phi::funcs::CanonicalAxis(context.Attr("axis"), rank); - int axis_dim = softmax->dims()[axis]; - - PADDLE_ENFORCE_GT( - axis_dim, 0, - platform::errors::InvalidArgument( - "The axis dimention should be larger than 0, but received " - "axis dimention is %d.", - axis_dim)); - - softmax_out->mutable_data(context.GetPlace()); - loss->mutable_data(context.GetPlace()); - - const int n = phi::funcs::SizeToAxis(axis, softmax->dims()); - - PADDLE_ENFORCE_GT( - n, 0, platform::errors::InvalidArgument( - "The size of axis should be larger than 0, but received " - "SizeToAxis of softmax is %d.", - n)); - - const int d = phi::funcs::SizeFromAxis(axis, softmax->dims()); - - Tensor softmax_2d, labels_2d, loss_2d, softmax_out_2d; - softmax_2d.ShareDataWith(*softmax).Resize({n, d}); - labels_2d.ShareDataWith(*labels).Resize({n, labels->numel() / n}); - loss_2d.ShareDataWith(*loss).Resize({n, d / axis_dim}); - softmax_out_2d.ShareDataWith(*softmax_out).Resize({n, d}); - - auto& dev_ctx = - context.template device_context(); - - math::CrossEntropyFunctor()( - dev_ctx, &loss_2d, &softmax_2d, &labels_2d, soft_label, - context.Attr("ignore_index"), axis_dim); - - // cause of input is softmax - // copy to output softmax, directly - framework::TensorCopy(*softmax, context.GetPlace(), - context.device_context(), softmax_out); - - return; - } - - const Tensor* logits = context.Input("Logits"); - Tensor* softmax = context.Output("Softmax"); - Tensor* loss = context.Output("Loss"); - - const int rank = logits->dims().size(); - const int axis = phi::funcs::CanonicalAxis(context.Attr("axis"), rank); - int axis_dim = logits->dims()[axis]; - PADDLE_ENFORCE_GT( - axis_dim, 0, - platform::errors::InvalidArgument( - "The axis dimention should be larger than 0, but received " - "axis dimention is %d.", - axis_dim)); - - softmax->mutable_data(context.GetPlace()); - loss->mutable_data(context.GetPlace()); - - const int n = phi::funcs::SizeToAxis(axis, logits->dims()); - PADDLE_ENFORCE_GT( - n, 0, platform::errors::InvalidArgument( - "The size of axis should be larger than 0, but received " - "SizeToAxis of logits is %d.", - n)); - - const int d = phi::funcs::SizeFromAxis(axis, logits->dims()); - Tensor logits_2d, softmax_2d, labels_2d, loss_2d; - logits_2d.ShareDataWith(*logits).Resize({n, d}); - softmax_2d.ShareDataWith(*softmax).Resize({n, d}); - labels_2d.ShareDataWith(*labels).Resize({n, labels->numel() / n}); - loss_2d.ShareDataWith(*loss).Resize({n, d / axis_dim}); - - auto& dev_ctx = - context.template device_context(); - math::SoftmaxFunctor()( - dev_ctx, axis_dim, &logits_2d, &softmax_2d); - math::CrossEntropyFunctor()( - dev_ctx, &loss_2d, &softmax_2d, &labels_2d, soft_label, - context.Attr("ignore_index"), axis_dim); - } -}; - -template -class SoftmaxWithCrossEntropyGradKernel : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext& context) const override { - RunSoftmaxWithCrossEntropyFunctor(context, *this); - } - - template - static void Apply(const framework::ExecutionContext& context, - const framework::Tensor& labels, const bool soft_label) { - const Tensor* out_grad = - context.Input(framework::GradVarName("Loss")); - Tensor* logit_grad = - context.Output(framework::GradVarName("Logits")); - const Tensor* softmax = context.Input("Softmax"); - const bool use_softmax = context.Attr("use_softmax"); - if (logit_grad != softmax || !use_softmax) { - framework::TensorCopy(*softmax, context.GetPlace(), - context.device_context(), logit_grad); - } - auto ignore_index = context.Attr("ignore_index"); - - const int rank = logit_grad->dims().size(); - const int axis = phi::funcs::CanonicalAxis(context.Attr("axis"), rank); - int axis_dim = logit_grad->dims()[axis]; - PADDLE_ENFORCE_GT( - axis_dim, 0, - platform::errors::InvalidArgument( - "The axis dimention should be larger than 0, but received " - "axis dimention is %d.", - axis_dim)); - - const int n = phi::funcs::SizeToAxis(axis, logit_grad->dims()); - PADDLE_ENFORCE_GT( - n, 0, platform::errors::InvalidArgument( - "The size of axis should be larger than 0, but received " - "SizeToAxis of logit_grad is %d.", - n)); - - const int d = phi::funcs::SizeFromAxis(axis, logit_grad->dims()); - Tensor logit_grad_2d, labels_2d, out_grad_2d; - logit_grad_2d.ShareDataWith(*logit_grad).Resize({n, d}); - labels_2d.ShareDataWith(labels).Resize({n, labels.numel() / n}); - out_grad_2d.ShareDataWith(*out_grad).Resize({n, d / axis_dim}); - auto out_grad_mat = framework::EigenMatrix::From(out_grad_2d); - auto logit_grad_mat = framework::EigenMatrix::From(logit_grad_2d); - auto& place = *context.template device_context() - .eigen_device(); - if (!use_softmax) { - // use_softmax step1 - if (soft_label) { - auto lbl_mat = framework::EigenMatrix::From(labels_2d); - logit_grad_mat.device(place) = - (-lbl_mat / logit_grad_mat); // for each sample ,i is sample id - logit_grad_mat.device(place) = - out_grad_mat.broadcast(Eigen::DSizes(1, axis_dim)) * - logit_grad_mat; - } else { - // use_softmax step2 - const auto* label_data = labels.template data(); - T* logit_grad_data = logit_grad->template data(); - const T* out_grad_data = out_grad->template data(); - const int remain = d / axis_dim; - for (int i = 0; i < n; ++i) { // for each sample_1_dim - for (int j = 0; j < remain; j++) { // for each sample_other_dims - int idx = i * remain + j; // this sample's label_idx. for 1d case, - // remain=1 and j=0, so, idx = i - auto lbl = static_cast(label_data[idx]); - if (lbl == ignore_index) { - for (int k = 0; k < axis_dim; ++k) { // for each class id's label - logit_grad_data[i * d + k * remain + j] = 0; - } - } else { - // only for this sample's label_idx, the label is 1, others is 0, - // so, only compute this label_idx's class - logit_grad_data[i * d + lbl * remain + j] = - (-1 / logit_grad_data[i * d + lbl * remain + j]) * - out_grad_data[idx]; - for (int k = 0; k < axis_dim; ++k) { // for each class id's label - if (k != - label_data[idx]) { // label_data[idx]: this sample's label - logit_grad_data[i * d + k * remain + j] = 0; - } - } - } - } - } - } - return; - } - // for use_softmax=False, continue - - if (soft_label) { - // when soft_label = True, ignore_index is not supported - auto lbl_mat = framework::EigenMatrix::From(labels_2d); - logit_grad_mat.device(place) = - out_grad_mat.broadcast(Eigen::DSizes(1, axis_dim)) * - (logit_grad_mat - lbl_mat); // for each sample ,i is sample id - // 1) compute dy/dx by p_j - y_j or P-Y, where j is class id, - // P=logit_grad_mat[i] is all class's probs, Y=lbl_mat[i] is - // all class's labels - // 2) compute dy * dy/dx by Chain rule, dy=out_grad_mat[i] - // for high dims, e.g. (n,c) or (n,d1,...,dm, c), compute grad by matrix - // operation - - } else { - logit_grad_mat.device(place) = - logit_grad_mat * // element_wise multiply - out_grad_mat.broadcast(Eigen::DSizes(1, axis_dim)); - - const auto* label_data = labels.template data(); - T* logit_grad_data = logit_grad->template data(); - const T* out_grad_data = out_grad->template data(); - const int remain = d / axis_dim; - for (int i = 0; i < n; ++i) { // for each sample_1_dim - for (int j = 0; j < remain; j++) { // for each sample_other_dims - int idx = i * remain + j; // this sample's label_idx. for 1d case, - // remain=1 and j=0, so, idx = i - auto lbl = static_cast(label_data[idx]); - if (lbl == ignore_index) { - for (int k = 0; k < axis_dim; ++k) { // for each class id's label - logit_grad_data[i * d + k * remain + j] = 0; - } - } else { - // only for this sample's label_idx, the label is 1, others is 0, - // so, only compute this label_idx's class - // for 1d case, remain=1 and j=0, so, [i * d + label_data[idx] * - // remain + j] = [i * d + label_data[idx]] - // let idx_x = i * d + label_data[idx] * remain + j, - // logit_grad_data[idx_x] = logit_grad_data[idx_x] - - // out_grad_data[idx] - // note: logit_grad_mat = logit_grad_mat * out_grad_mat - // so: logit_grad_data[idx_x] = (logit_grad_data[idx_x] - 1) * - // out_grad_data[idx] - // means: dy/dp * dy= ( p - y ) * dy - - logit_grad_data[i * d + lbl * remain + j] -= out_grad_data[idx]; - } - } - } - } - } -}; - -} // namespace operators -} // namespace paddle diff --git a/paddle/fluid/operators/softmax_with_cross_entropy_op_mlu.cc b/paddle/fluid/operators/softmax_with_cross_entropy_op_mlu.cc index 34650c2e06245532eda5ebcf9e8d8454ee93237b..7056bcd4f76bc6a1c80d0b2aaba527ebf752c5b5 100644 --- a/paddle/fluid/operators/softmax_with_cross_entropy_op_mlu.cc +++ b/paddle/fluid/operators/softmax_with_cross_entropy_op_mlu.cc @@ -12,8 +12,9 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#include "paddle/fluid/operators/softmax_with_cross_entropy_op.h" +#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/operators/mlu/mlu_baseop.h" +#include "paddle/phi/kernels/funcs/axis_utils.h" namespace paddle { namespace operators { diff --git a/paddle/fluid/operators/softmax_with_cross_entropy_op_npu.cc b/paddle/fluid/operators/softmax_with_cross_entropy_op_npu.cc index 1f1fbea090c13f2eff7e389c9b7c4774ccbb7700..f64d9e022a1adbd24d33bfaef43956aeb08ab9f6 100644 --- a/paddle/fluid/operators/softmax_with_cross_entropy_op_npu.cc +++ b/paddle/fluid/operators/softmax_with_cross_entropy_op_npu.cc @@ -12,13 +12,14 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#include "paddle/fluid/operators/softmax_with_cross_entropy_op.h" - #include #include + +#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/operators/math/cross_entropy.h" #include "paddle/fluid/operators/math/softmax.h" #include "paddle/fluid/platform/device/npu/npu_op_runner.h" +#include "paddle/phi/kernels/funcs/axis_utils.h" namespace paddle { namespace operators { diff --git a/paddle/fluid/operators/softmax_with_cross_entropy_op_xpu.cc b/paddle/fluid/operators/softmax_with_cross_entropy_op_xpu.cc index b5514525f5981d5184c24067143cac667abaf1ce..c07467a9b0ba33ce3bc0d9d140d72ffa4ed7108c 100644 --- a/paddle/fluid/operators/softmax_with_cross_entropy_op_xpu.cc +++ b/paddle/fluid/operators/softmax_with_cross_entropy_op_xpu.cc @@ -12,20 +12,23 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#include "paddle/fluid/operators/softmax_with_cross_entropy_op.h" #ifdef PADDLE_WITH_XPU #include #include #include #include +#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/platform/device/device_wrapper.h" +#include "paddle/phi/kernels/funcs/axis_utils.h" #include "xpu/refactor/math.h" #include "xpu/refactor/nn.h" namespace paddle { namespace operators { +using Tensor = framework::Tensor; + template class SoftmaxWithCrossEntropyXPUKernel : public framework::OpKernel { using XPUType = typename XPUTypeTrait::Type; diff --git a/paddle/phi/core/compat/convert_utils.cc b/paddle/phi/core/compat/convert_utils.cc index 667cee10675d8a6b756a6af701b65c63b3b359de..cc9c2caa889917691624958c1a0386ac538ebca2 100644 --- a/paddle/phi/core/compat/convert_utils.cc +++ b/paddle/phi/core/compat/convert_utils.cc @@ -33,6 +33,8 @@ Backend TransToPhiBackend(const phi::Place& place) { return Backend::GPU; } else if (allocation_type == phi::AllocationType::XPU) { return Backend::XPU; + } else if (allocation_type == phi::AllocationType::NPU) { + return Backend::NPU; } else if (allocation_type == phi::AllocationType::CUSTOM) { return static_cast( static_cast(Backend::NUM_BACKENDS) + diff --git a/paddle/phi/kernels/CMakeLists.txt b/paddle/phi/kernels/CMakeLists.txt index b0d762d00ecf9575eeb2e85109818aa0a0108e10..d4b832cef0bd253fa90c7f445667d94d886aca19 100644 --- a/paddle/phi/kernels/CMakeLists.txt +++ b/paddle/phi/kernels/CMakeLists.txt @@ -27,7 +27,7 @@ kernel_library(full_kernel DEPS ${COMMON_KERNEL_DEPS} empty_kernel) # Some kernels depend on some targets that are not commonly used. # These targets are not suitable for common dependencies. # In this case, you need to manually generate them here. -set(MANUAL_BUILD_KERNELS adam_kernel adamw_kernel deformable_conv_kernel deformable_conv_grad_kernel eigh_kernel +set(MANUAL_BUILD_KERNELS cross_entropy_kernel adam_kernel adamw_kernel deformable_conv_kernel deformable_conv_grad_kernel eigh_kernel gumbel_softmax_kernel gumbel_softmax_grad_kernel hierarchical_sigmoid_kernel hierarchical_sigmoid_grad_kernel matrix_power_kernel matrix_power_grad_kernel maxout_kernel maxout_grad_kernel pool_kernel put_along_axis_kernel put_along_axis_grad_kernel segment_pool_kernel segment_pool_grad_kernel @@ -35,8 +35,10 @@ set(MANUAL_BUILD_KERNELS adam_kernel adamw_kernel deformable_conv_kernel deforma triangular_solve_grad_kernel determinant_grad_kernel reduce_kernel rnn_kernel rnn_grad_kernel warpctc_kernel warpctc_grad_kernel) kernel_library(adam_kernel DEPS gflags glog flags ${COMMON_KERNEL_DEPS} selected_rows_functor threadpool jit_kernel_helper) kernel_library(adamw_kernel DEPS ${COMMON_KERNEL_DEPS} adam_kernel) +kernel_library(cross_entropy_kernel DEPS ${COMMON_KERNEL_DEPS} softmax cross_entropy) kernel_library(deformable_conv_kernel DEPS ${COMMON_KERNEL_DEPS} deformable_conv_functor) kernel_library(deformable_conv_grad_kernel DEPS ${COMMON_KERNEL_DEPS} deformable_conv_functor) +kernel_library(determinant_grad_kernel DEPS ${COMMON_KERNEL_DEPS} matrix_inverse) kernel_library(eigh_kernel DEPS ${COMMON_KERNEL_DEPS} lapack_function) kernel_library(hierarchical_sigmoid_kernel DEPS ${COMMON_KERNEL_DEPS} matrix_bit_code) kernel_library(hierarchical_sigmoid_grad_kernel DEPS ${COMMON_KERNEL_DEPS} matrix_bit_code) @@ -57,7 +59,6 @@ kernel_library(softmax_grad_kernel DEPS ${COMMON_KERNEL_DEPS} softmax) kernel_library(take_along_axis_kernel DEPS ${COMMON_KERNEL_DEPS} gather_scatter_kernel) kernel_library(take_along_axis_grad_kernel DEPS ${COMMON_KERNEL_DEPS} gather_scatter_kernel) kernel_library(triangular_solve_grad_kernel DEPS ${COMMON_KERNEL_DEPS} matrix_reduce) -kernel_library(determinant_grad_kernel DEPS ${COMMON_KERNEL_DEPS} matrix_inverse) kernel_library(rnn_kernel DEPS ${COMMON_KERNEL_DEPS} concat_and_split_functor lstm_compute gru_compute) kernel_library(rnn_grad_kernel DEPS ${COMMON_KERNEL_DEPS} concat_and_split_functor lstm_compute gru_compute) kernel_library(warpctc_kernel DEPS ${COMMON_KERNEL_DEPS} phi_dynload_warpctc sequence_padding sequence_scale) diff --git a/paddle/phi/kernels/cpu/cross_entropy_grad_kernel.cc b/paddle/phi/kernels/cpu/cross_entropy_grad_kernel.cc new file mode 100644 index 0000000000000000000000000000000000000000..d4a632b5e6ece09030a7071ee7919cb43a2015df --- /dev/null +++ b/paddle/phi/kernels/cpu/cross_entropy_grad_kernel.cc @@ -0,0 +1,226 @@ +/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/phi/kernels/cross_entropy_grad_kernel.h" + +#include "paddle/phi/backends/cpu/cpu_context.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/copy_kernel.h" +#include "paddle/phi/kernels/funcs/axis_utils.h" +#include "paddle/phi/kernels/funcs/eigen/common.h" + +// TODO(chenweihang): move dispatch.h into phi/core +#include "paddle/phi/api/ext/dispatch.h" + +namespace phi { + +template +void CrossEntropyWithSoftmaxGradCPUKernel(const CPUContext& dev_ctx, + const DenseTensor& label, + const DenseTensor& softmax, + const DenseTensor& loss_grad, + bool soft_label, + bool use_softmax, + bool numeric_stable_mode, + int ignore_index, + int axis, + DenseTensor* logits_grad) { + const DenseTensor* out_grad = &loss_grad; + DenseTensor* logit_grad = logits_grad; + + if (logit_grad != &softmax || !use_softmax) { + phi::Copy(dev_ctx, softmax, dev_ctx.GetPlace(), false, logit_grad); + } + + const int rank = logit_grad->dims().size(); + const int axis_v = phi::funcs::CanonicalAxis(axis, rank); + int axis_dim = logit_grad->dims()[axis_v]; + PADDLE_ENFORCE_GT( + axis_dim, + 0, + phi::errors::InvalidArgument( + "The axis dimention should be larger than 0, but received " + "axis dimention is %d.", + axis_dim)); + + const int n = phi::funcs::SizeToAxis(axis_v, logit_grad->dims()); + PADDLE_ENFORCE_GT( + n, + 0, + phi::errors::InvalidArgument( + "The size of axis should be larger than 0, but received " + "SizeToAxis of logit_grad is %d.", + n)); + + const int d = phi::funcs::SizeFromAxis(axis_v, logit_grad->dims()); + DenseTensor logit_grad_2d(*logit_grad); + logit_grad_2d.Resize({n, d}); + DenseTensor labels_2d(label); + labels_2d.Resize({n, label.numel() / n}); + DenseTensor out_grad_2d(*out_grad); + out_grad_2d.Resize({n, d / axis_dim}); + + auto out_grad_mat = EigenMatrix::From(out_grad_2d); + auto logit_grad_mat = EigenMatrix::From(logit_grad_2d); + auto& place = *dev_ctx.eigen_device(); + + if (!use_softmax) { + // use_softmax step1 + if (soft_label) { + auto lbl_mat = EigenMatrix::From(labels_2d); + logit_grad_mat.device(place) = + (-lbl_mat / logit_grad_mat); // for each sample ,i is sample id + logit_grad_mat.device(place) = + out_grad_mat.broadcast(Eigen::DSizes(1, axis_dim)) * + logit_grad_mat; + } else { + // use_softmax step2 + const auto* label_data = label.data(); + T* logit_grad_data = logit_grad->data(); + const T* out_grad_data = out_grad->data(); + const int remain = d / axis_dim; + for (int i = 0; i < n; ++i) { // for each sample_1_dim + for (int j = 0; j < remain; j++) { // for each sample_other_dims + int idx = i * remain + j; // this sample's label_idx. for 1d case, + // remain=1 and j=0, so, idx = i + auto lbl = static_cast(label_data[idx]); + if (lbl == ignore_index) { + for (int k = 0; k < axis_dim; ++k) { // for each class id's label + logit_grad_data[i * d + k * remain + j] = 0; + } + } else { + // only for this sample's label_idx, the label is 1, others is 0, + // so, only compute this label_idx's class + logit_grad_data[i * d + lbl * remain + j] = + (-1 / logit_grad_data[i * d + lbl * remain + j]) * + out_grad_data[idx]; + for (int k = 0; k < axis_dim; ++k) { // for each class id's label + if (k != + label_data[idx]) { // label_data[idx]: this sample's label + logit_grad_data[i * d + k * remain + j] = 0; + } + } + } + } + } + } + return; + } + // for use_softmax=False, continue + + if (soft_label) { + // when soft_label = True, ignore_index is not supported + auto lbl_mat = EigenMatrix::From(labels_2d); + logit_grad_mat.device(place) = + out_grad_mat.broadcast(Eigen::DSizes(1, axis_dim)) * + (logit_grad_mat - lbl_mat); + // for each sample, i is sample id + // 1) compute dy/dx by p_j - y_j or P-Y, where j is class id, + // P=logit_grad_mat[i] is all class's probs, Y=lbl_mat[i] is + // all class's label + // 2) compute dy * dy/dx by Chain rule, dy=out_grad_mat[i] + // for high dims, e.g. (n,c) or (n,d1,...,dm, c), compute grad by matrix + // operation + + } else { + logit_grad_mat.device(place) = + logit_grad_mat * // element_wise multiply + out_grad_mat.broadcast(Eigen::DSizes(1, axis_dim)); + + const auto* label_data = label.data(); + T* logit_grad_data = logit_grad->data(); + const T* out_grad_data = out_grad->data(); + const int remain = d / axis_dim; + for (int i = 0; i < n; ++i) { // for each sample_1_dim + for (int j = 0; j < remain; j++) { // for each sample_other_dims + int idx = i * remain + j; // this sample's label_idx. for 1d case, + // remain=1 and j=0, so, idx = i + auto lbl = static_cast(label_data[idx]); + if (lbl == ignore_index) { + for (int k = 0; k < axis_dim; ++k) { // for each class id's label + logit_grad_data[i * d + k * remain + j] = 0; + } + } else { + // only for this sample's label_idx, the label is 1, others is 0, + // so, only compute this label_idx's class + // for 1d case, remain=1 and j=0, so, [i * d + label_data[idx] * + // remain + j] = [i * d + label_data[idx]] + // let idx_x = i * d + label_data[idx] * remain + j, + // logit_grad_data[idx_x] = logit_grad_data[idx_x] - + // out_grad_data[idx] + // note: logit_grad_mat = logit_grad_mat * out_grad_mat + // so: logit_grad_data[idx_x] = (logit_grad_data[idx_x] - 1) * + // out_grad_data[idx] + // means: dy/dp * dy= ( p - y ) * dy + + logit_grad_data[i * d + lbl * remain + j] -= out_grad_data[idx]; + } + } + } + } +} + +template +void CrossEntropyWithSoftmaxGradKernel(const Context& dev_ctx, + const DenseTensor& label, + const DenseTensor& softmax, + const DenseTensor& loss_grad, + bool soft_label, + bool use_softmax, + bool numeric_stable_mode, + int ignore_index, + int axis, + DenseTensor* logits_grad) { + auto dtype = label.dtype(); + if (soft_label) { + PADDLE_ENFORCE_EQ( + dtype, + paddle::experimental::CppTypeToDataType::Type(), + phi::errors::InvalidArgument("The Input(Label) should be with the " + "same data type as kernel data type.")); + CrossEntropyWithSoftmaxGradCPUKernel(dev_ctx, + label, + softmax, + loss_grad, + soft_label, + use_softmax, + numeric_stable_mode, + ignore_index, + axis, + logits_grad); + } else { + PD_DISPATCH_INTEGRAL_TYPES( + dtype, "CrossEntropyWithSoftmaxGradCPUKernel", ([&] { + CrossEntropyWithSoftmaxGradCPUKernel(dev_ctx, + label, + softmax, + loss_grad, + soft_label, + use_softmax, + numeric_stable_mode, + ignore_index, + axis, + logits_grad); + })); + } +} + +} // namespace phi + +PD_REGISTER_KERNEL(cross_entropy_with_softmax_grad, + CPU, + ALL_LAYOUT, + phi::CrossEntropyWithSoftmaxGradKernel, + float, + double) {} diff --git a/paddle/phi/kernels/cpu/cross_entropy_kernel.cc b/paddle/phi/kernels/cpu/cross_entropy_kernel.cc new file mode 100644 index 0000000000000000000000000000000000000000..c684fb416eaab38461d490dec940998ad705b6f6 --- /dev/null +++ b/paddle/phi/kernels/cpu/cross_entropy_kernel.cc @@ -0,0 +1,104 @@ +/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/phi/kernels/cross_entropy_kernel.h" + +#include "paddle/phi/backends/cpu/cpu_context.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/copy_kernel.h" +#include "paddle/phi/kernels/funcs/axis_utils.h" +#include "paddle/phi/kernels/funcs/math_function.h" +#include "paddle/phi/kernels/softmax_kernel.h" + +#include "paddle/fluid/operators/math/cross_entropy.h" + +namespace phi { + +template +void CrossEntropy(const CPUContext& dev_ctx, + const DenseTensor& x, + const DenseTensor& label, + bool soft_label, + int ignore_index, + int axis, + DenseTensor* out) { + const int rank = x.dims().size(); + const int axis_v = phi::funcs::CanonicalAxis(axis, rank); + int axis_dim = x.dims()[axis_v]; + + PADDLE_ENFORCE_GT( + axis_dim, + 0, + phi::errors::InvalidArgument( + "The axis dimention should be larger than 0, but received " + "axis dimention is %d.", + axis_dim)); + + dev_ctx.template Alloc(out); + + const int n = phi::funcs::SizeToAxis(axis_v, x.dims()); + PADDLE_ENFORCE_GT( + n, + 0, + phi::errors::InvalidArgument( + "The size of axis should be larger than 0, but received " + "SizeToAxis of softmax is %d.", + n)); + + const int d = phi::funcs::SizeFromAxis(axis_v, x.dims()); + + DenseTensor x_2d(x); + x_2d.Resize({n, d}); + DenseTensor label_2d(label); + label_2d.Resize({n, label.numel() / n}); + DenseTensor out_2d(*out); + out_2d.Resize({n, d / axis_dim}); + + paddle::operators::math::CrossEntropyFunctor()( + dev_ctx, &out_2d, &x_2d, &label_2d, soft_label, ignore_index, axis_dim); +} + +template +void CrossEntropyWithSoftmaxKernel(const Context& dev_ctx, + const DenseTensor& logits, + const DenseTensor& label, + bool soft_label, + bool use_softmax, + bool numeric_stable_mode, + int ignore_index, + int axis, + DenseTensor* softmax, + DenseTensor* loss) { + // do not with softmax op, and input is softmax + if (!use_softmax) { + CrossEntropy( + dev_ctx, logits, label, soft_label, ignore_index, axis, loss); + // cause of input is softmax, copy to output softmax, directly + phi::Copy(dev_ctx, logits, dev_ctx.GetPlace(), false, softmax); + return; + } + + phi::SoftmaxKernel(dev_ctx, logits, axis, softmax); + CrossEntropy( + dev_ctx, *softmax, label, soft_label, ignore_index, axis, loss); +} + +} // namespace phi + +PD_REGISTER_KERNEL(cross_entropy_with_softmax, + CPU, + ALL_LAYOUT, + phi::CrossEntropyWithSoftmaxKernel, + float, + double) {} diff --git a/paddle/phi/kernels/cross_entropy_grad_kernel.h b/paddle/phi/kernels/cross_entropy_grad_kernel.h new file mode 100644 index 0000000000000000000000000000000000000000..ae4b0436c93ca09707407a00f120ef78b3a2ca0a --- /dev/null +++ b/paddle/phi/kernels/cross_entropy_grad_kernel.h @@ -0,0 +1,33 @@ +/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include "paddle/phi/core/dense_tensor.h" + +namespace phi { + +template +void CrossEntropyWithSoftmaxGradKernel(const Context& dev_ctx, + const DenseTensor& label, + const DenseTensor& softmax, + const DenseTensor& loss_grad, + bool soft_label, + bool use_softmax, + bool numeric_stable_mode, + int ignore_index, + int axis, + DenseTensor* logits_grad); + +} // namespace phi diff --git a/paddle/phi/kernels/cross_entropy_kernel.h b/paddle/phi/kernels/cross_entropy_kernel.h new file mode 100644 index 0000000000000000000000000000000000000000..621c5f366621351137fc2000e618149375ab842b --- /dev/null +++ b/paddle/phi/kernels/cross_entropy_kernel.h @@ -0,0 +1,39 @@ +/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include "paddle/phi/core/dense_tensor.h" + +namespace phi { + +// The deformed product of operator iterative upgrade, there is no strict 2.0 +// API corresponding to it! In 2.0 API paddle.nn.functional.cross_entropy, +// use_softmax has become an optional argument, which may be called +// CrossEntropyWithSoftmax more accurately, here we keep this kernel arguments +// same as original OpMaker, and if need a CrossEntropyKernel like +// paddle.nn.functional.cross_entropy, we can reuse this kernel +template +void CrossEntropyWithSoftmaxKernel(const Context& dev_ctx, + const DenseTensor& logits, + const DenseTensor& label, + bool soft_label, + bool use_softmax, + bool numeric_stable_mode, + int ignore_index, + int axis, + DenseTensor* softmax, + DenseTensor* loss); + +} // namespace phi diff --git a/paddle/phi/kernels/gpu/cross_entropy_grad_kernel.cu b/paddle/phi/kernels/gpu/cross_entropy_grad_kernel.cu new file mode 100644 index 0000000000000000000000000000000000000000..215b94c52b3950f68fcc084ad1942a612e79352b --- /dev/null +++ b/paddle/phi/kernels/gpu/cross_entropy_grad_kernel.cu @@ -0,0 +1,294 @@ +/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/phi/kernels/cross_entropy_grad_kernel.h" + +#ifdef __NVCC__ +#include "cub/cub.cuh" +#endif +#ifdef __HIPCC__ +#include +namespace cub = hipcub; +#endif + +#include "paddle/phi/common/amp_type_traits.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/copy_kernel.h" +#include "paddle/phi/kernels/funcs/axis_utils.h" +#include "paddle/phi/kernels/funcs/for_range.h" +#include "paddle/phi/kernels/funcs/math_function.h" +#include "paddle/phi/kernels/gpudnn/softmax_gpudnn.h" + +// TODO(chenweihang): move dispatch.h into phi/core +#include "paddle/phi/api/ext/dispatch.h" + +#include "paddle/fluid/operators/math/cross_entropy.h" +#include "paddle/fluid/operators/math/softmax.h" +#include "paddle/fluid/platform/device/gpu/gpu_device_function.h" +#include "paddle/fluid/platform/device/gpu/gpu_dnn.h" + +namespace phi { + +template +__global__ void SoftLabelCrossEntropyGradientKernel(T* logit_grad, + const T* loss_grad, + const T* labels, + const int n, + const int d, + const int remain) { + int ids = blockIdx.x * blockDim.x + threadIdx.x; + if (ids < n * d) { + int idx_n = ids / d; + int idx_remain = ids % remain; + int idx_loss = idx_n * remain + idx_remain; + logit_grad[ids] = loss_grad[idx_loss] * (-labels[ids] / logit_grad[ids]); + } +} + +template +__global__ void HardLabelCrossEntropyGradientKernel(T* logit_grad, + const LabelT* labels, + const int n, + const int d, + const int remain, + const int ignore_index) { + CUDA_KERNEL_LOOP(index, n * remain) { + int idx_n = index / remain; + int idx_remain = index % remain; + int tmp = static_cast(labels[index]); + int idx = idx_n * d + tmp * remain + idx_remain; + if (ignore_index != tmp) { + logit_grad[idx] = -static_cast(1.) / logit_grad[idx]; + } + } +} + +template +__global__ void ScaleCrossEntropyGradient(T* logit_grad, + const T* loss_grad, + const int num, + const int d, + const int remain, + const LabelT* labels, + const int ignore_index) { + CUDA_KERNEL_LOOP(index, num) { + int idx_n = index / d; + int idx_remain = index % remain; + int idx_lbl = idx_n * remain + idx_remain; + int k = (index % d) / remain; + auto lbl = static_cast(labels[idx_lbl]); + if (lbl == ignore_index || lbl != k) { + logit_grad[index] = static_cast(0.); + } else { + logit_grad[index] *= loss_grad[idx_lbl]; + } + } +} + +template +__global__ void SoftCrossEntropyGradientKernel(T* logit_grad, + const T* loss_grad, + const T* labels, + const int64_t n, + const int64_t d, + const int64_t remain) { + int64_t ids = blockIdx.x * blockDim.x + threadIdx.x; + if (ids < n * d) { + int64_t idx_n = ids / d; + int64_t idx_remain = ids % remain; + int64_t idx_loss = idx_n * remain + idx_remain; + logit_grad[ids] = loss_grad[idx_loss] * (logit_grad[ids] - labels[ids]); + } +} + +/* + Wrapper of softmax with cross entropy grad hard label. +*/ +template +__global__ void SoftmaxWithCrossEntropyGradHardLabel(T* logits_grad, + const T* loss_grad, + const T* softmax, + const LabelT* labels, + const int64_t n, + const int64_t dim, + const int64_t d, + const int ignore_index) { + int64_t idx = blockIdx.x * blockDim.x + threadIdx.x; + int64_t idx_n = idx / (d * dim); + int64_t idx_dim = (idx / d) % dim; + int64_t idx_d = idx % d; + int64_t ids = idx_n * d + idx_d; + + if (idx < n * dim * d) { + auto lbl = static_cast(labels[ids]); + if (lbl == ignore_index) { + logits_grad[idx] = static_cast(0.0); + } else if (lbl == idx_dim) { + logits_grad[idx] = (softmax[idx] - static_cast(1.0)) * loss_grad[ids]; + } else { + logits_grad[idx] = softmax[idx] * loss_grad[ids]; + } + } +} + +template +void CrossEntropyWithSoftmaxGradGPUKernel(const GPUContext& dev_ctx, + const DenseTensor& label, + const DenseTensor& softmax, + const DenseTensor& loss_grad, + bool soft_label, + bool use_softmax, + bool numeric_stable_mode, + int ignore_index, + int axis, + DenseTensor* logits_grad) { + PADDLE_ENFORCE_EQ( + dev_ctx.GetPlace().GetType(), + phi::AllocationType::GPU, + phi::errors::Unavailable("softmax_with_cross_entropy operator's " + "CUDA kernel only runs on GPU device.")); + const T* loss_grad_data = loss_grad.data(); + DenseTensor* logit_grad = logits_grad; + + T* logit_grad_data = nullptr; + bool copy_flag = (logit_grad != &softmax && (!use_softmax || soft_label)); + if (copy_flag) { + phi::Copy(dev_ctx, softmax, dev_ctx.GetPlace(), false, logit_grad); + logit_grad_data = logit_grad->data(); + } else { + logit_grad_data = dev_ctx.template Alloc(logit_grad); + } + + const int rank = logit_grad->dims().size(); + const int axis_v = phi::funcs::CanonicalAxis(axis, rank); + int axis_dim = logit_grad->dims()[axis_v]; + + const int64_t n = phi::funcs::SizeToAxis(axis_v, logit_grad->dims()); + const int64_t d = phi::funcs::SizeFromAxis(axis_v, logit_grad->dims()); + const int64_t remain = d / axis_dim; + +#ifdef __HIPCC__ + int block = 256; +#else + int block = 512; +#endif + auto stream = dev_ctx.stream(); + + // do not with softmax op, and input is softmax + if (!use_softmax) { + if (soft_label) { + int grid = (n * d + block - 1) / block; + const T* label_data = label.data(); + SoftLabelCrossEntropyGradientKernel<<>>( + logit_grad_data, loss_grad_data, label_data, n, d, remain); + } else { + DenseTensor logits_grad_2d(*logit_grad); + logits_grad_2d.Resize({n, d}); + int grid = (n * remain + block - 1) / block; + const auto* label_data = label.data(); + HardLabelCrossEntropyGradientKernel<<>>( + logit_grad_data, label_data, n, d, remain, ignore_index); + int num = n * d; + grid = (num + block - 1) / block; + ScaleCrossEntropyGradient<<>>( + logit_grad_data, + loss_grad_data, + num, + d, + remain, + label_data, + ignore_index); + } + + return; + } + + // with softmax, continue + + if (soft_label) { + int64_t grid = (n * d + block - 1) / block; + const T* label_data = label.data(); + SoftCrossEntropyGradientKernel<<>>( + logit_grad_data, loss_grad_data, label_data, n, d, remain); + } else { + const T* softmax_data = softmax.data(); + const auto* label_data = label.data(); + int grid = (n * d + block - 1) / block; + SoftmaxWithCrossEntropyGradHardLabel<<>>( + logit_grad_data, + loss_grad_data, + softmax_data, + label_data, + n, + d / remain, + remain, + ignore_index); + } +} + +template +void CrossEntropyWithSoftmaxGradKernel(const Context& dev_ctx, + const DenseTensor& label, + const DenseTensor& softmax, + const DenseTensor& loss_grad, + bool soft_label, + bool use_softmax, + bool numeric_stable_mode, + int ignore_index, + int axis, + DenseTensor* logits_grad) { + auto dtype = label.dtype(); + if (soft_label) { + PADDLE_ENFORCE_EQ( + dtype, + paddle::experimental::CppTypeToDataType::Type(), + phi::errors::InvalidArgument("The Input(Label) should be with the " + "same data type as kernel data type.")); + CrossEntropyWithSoftmaxGradGPUKernel(dev_ctx, + label, + softmax, + loss_grad, + soft_label, + use_softmax, + numeric_stable_mode, + ignore_index, + axis, + logits_grad); + } else { + PD_DISPATCH_INTEGRAL_TYPES( + dtype, "CrossEntropyWithSoftmaxGradGPUKernel", ([&] { + CrossEntropyWithSoftmaxGradGPUKernel(dev_ctx, + label, + softmax, + loss_grad, + soft_label, + use_softmax, + numeric_stable_mode, + ignore_index, + axis, + logits_grad); + })); + } +} + +} // namespace phi + +PD_REGISTER_KERNEL(cross_entropy_with_softmax_grad, + GPU, + ALL_LAYOUT, + phi::CrossEntropyWithSoftmaxGradKernel, + float, + double, + phi::dtype::float16) {} diff --git a/paddle/fluid/operators/softmax_with_cross_entropy_op.cu b/paddle/phi/kernels/gpu/cross_entropy_kernel.cu similarity index 58% rename from paddle/fluid/operators/softmax_with_cross_entropy_op.cu rename to paddle/phi/kernels/gpu/cross_entropy_kernel.cu index 41545a1ca20b267e79f43c2af4c58ea64dd479b2..055706cffd41e50693cc20682f75a46b7f439d04 100644 --- a/paddle/fluid/operators/softmax_with_cross_entropy_op.cu +++ b/paddle/phi/kernels/gpu/cross_entropy_kernel.cu @@ -1,13 +1,19 @@ -/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. + Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 + Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ + +#include "paddle/phi/kernels/cross_entropy_kernel.h" + #ifdef __NVCC__ #include "cub/cub.cuh" #endif @@ -15,39 +21,43 @@ limitations under the License. */ #include namespace cub = hipcub; #endif -#include "paddle/fluid/operators/amp/fp16_type_traits.h" + +#include "paddle/phi/common/amp_type_traits.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/copy_kernel.h" +#include "paddle/phi/kernels/funcs/axis_utils.h" +#include "paddle/phi/kernels/funcs/for_range.h" +#include "paddle/phi/kernels/funcs/math_function.h" +#include "paddle/phi/kernels/gpudnn/softmax_gpudnn.h" + +// TODO(chenweihang): move dispatch.h into phi/core +#include "paddle/phi/api/ext/dispatch.h" + #include "paddle/fluid/operators/math/cross_entropy.h" -#include "paddle/fluid/operators/softmax_with_cross_entropy_op.h" +#include "paddle/fluid/operators/math/softmax.h" #include "paddle/fluid/platform/device/gpu/gpu_device_function.h" #include "paddle/fluid/platform/device/gpu/gpu_dnn.h" -#include "paddle/fluid/platform/for_range.h" -#include "paddle/phi/kernels/funcs/math_function.h" -#include "paddle/phi/kernels/gpudnn/softmax_gpudnn.h" -namespace paddle { -namespace operators { +namespace phi { #define ALIGN_BYTES 16 -using ScopedTensorDescriptor = platform::ScopedTensorDescriptor; -using DataLayout = platform::DataLayout; -using Tensor = framework::Tensor; -namespace kps = phi::kps; +enum class SoftmaxMode { kSoftmax, kLogSoftmax, kCrossEntropy }; // Wrapper of log function. Use log(float32) for float16 template static __device__ __forceinline__ T Log(T x) { - using AccT = typename details::MPTypeTrait::Type; + using AccT = typename dtype::MPTypeTrait::Type; AccT logx = std::log(static_cast(x)); - return math::TolerableValue()(static_cast(logx)); + return paddle::operators::math::TolerableValue()(static_cast(logx)); } // Wrapper of exp function. Use exp(float32) for float16 template static __device__ __forceinline__ T Exp(T x) { - using AccT = typename details::MPTypeTrait::Type; + using AccT = typename dtype::MPTypeTrait::Type; AccT expx = std::exp(static_cast(x)); - return math::TolerableValue()(static_cast(expx)); + return paddle::operators::math::TolerableValue()(static_cast(expx)); } template @@ -62,22 +72,114 @@ struct ExpAddFunctor { Tx max; }; -// log2(value) -static inline int Log2Ceil(int value) { - int log2_value = 0; - while ((1 << log2_value) < value) ++log2_value; - return log2_value; -} +/* + Cross entropy soft label with dynamic size on axis (log2_elements is + varibale). + - if the input is softmax,compute loss with softmax + - if the input is log_softmax, compute loss with log_softmax and update + softmax +*/ +template +__global__ void CrossEntropySoftLabel(T* loss, + T* softmaxwrt, + const T* softmax, + const T* labels, + const int n, + const int dim, + const int d, + int log2_elements) { + const int kDimCeil = 1 << log2_elements; + const int kVSize = sizeof(VecT) / sizeof(T); -enum class SoftmaxMode { kSoftmax, kLogSoftmax, kCrossEntropy }; +#ifdef __HIPCC__ + const int kThreadPerBlock = 256; +#else + const int kThreadPerBlock = 512; +#endif + const int kBatchPerBlock = 1; + const int kWarpSize = 32; // (dim < 32) ? dim : 32; + const int kBatchSize = 1; + const int kThreadPerBatch = kThreadPerBlock / kBatchPerBlock; + const int kWarpPerBatch = kThreadPerBatch / kWarpSize; + + const int kIterations = (dim + kThreadPerBatch - 1) / kThreadPerBatch; + const int kIterationsV = (kIterations >= kVSize) ? (kIterations / kVSize) : 1; + + const int first_batch = (blockDim.y * blockIdx.x + threadIdx.y) * kBatchSize; + + T sum[kBatchSize]{static_cast(0.0)}; +#pragma unroll + for (int i = 0; i < kBatchSize; ++i) { + int ids = first_batch + i; + if (ids >= n * d) break; + int idx_n = ids / d; + int idx_d = ids % d; +#pragma unroll + for (int it = 0; it < kIterations; ++it) { + int idx_dim = it * kThreadPerBatch + threadIdx.x; + int idx = idx_n * dim * d + idx_dim * d + idx_d; + + if (idx_n < n && idx_dim < dim) { + VecT softmaxdata; + if (InLogMode) { + softmaxdata = reinterpret_cast(&softmaxwrt[idx])[0]; + } else { + softmaxdata = reinterpret_cast(&softmax[idx])[0]; + } + VecT labelsdata = reinterpret_cast(&labels[idx])[0]; + T* softmaxptr = reinterpret_cast(&softmaxdata); + T* labelsptr = reinterpret_cast(&labelsdata); +#pragma unroll + for (int s = 0; s < kVSize; s++) { + if (InLogMode) { + sum[i] -= softmaxptr[s] * labelsptr[s]; + softmaxptr[s] = Exp(softmaxptr[s]); + } else { + sum[i] -= Log(softmaxptr[s]) * labelsptr[s]; + } + } + if (InLogMode) { + reinterpret_cast(&softmaxwrt[idx])[0] = softmaxdata; + } + } + } + } + phi::WarpReduceSum(sum); + __syncthreads(); + + __shared__ T sumshare[kWarpPerBatch][kBatchPerBlock][kBatchSize]; + if (threadIdx.x % kWarpSize == 0) { +#pragma unroll + for (int i = 0; i < kBatchSize; i++) { + sumshare[threadIdx.x / kWarpSize][threadIdx.y][i] = sum[i]; + } + } + __syncthreads(); + + // write + if (threadIdx.x == 0) { + for (int i = 0; i < kBatchSize; i++) { + int ids = first_batch + i; + if (ids < n * d) { + loss[ids] = sumshare[0][threadIdx.y][i]; + for (int s = 1; s < kWarpPerBatch; s++) { + loss[ids] += sumshare[s][threadIdx.y][i]; + } + } + } + } +} /* Hard label cross entropy. */ template -__global__ void CrossEntropyHardLabel(T* loss, const T* softmax, - const LabelT* labels, const int n, - const int dim, const int d, +__global__ void CrossEntropyHardLabel(T* loss, + const T* softmax, + const LabelT* labels, + const int n, + const int dim, + const int d, const int ignore_idx) { int64_t ids = blockIdx.x * blockDim.x + threadIdx.x; int64_t idx_n = ids / d; @@ -111,9 +213,12 @@ __global__ void CrossEntropyHardLabel(T* loss, const T* softmax, Output: loss and exp(input) */ template -__global__ void CrossEntropyExpHardLabel(T* loss, T* softmax, - const LabelT* labels, const int n, - const int dim, const int d, +__global__ void CrossEntropyExpHardLabel(T* loss, + T* softmax, + const LabelT* labels, + const int n, + const int dim, + const int d, const int ignore_idx) { int64_t idx = blockIdx.x * blockDim.x + threadIdx.x; int64_t idx_n = idx / (d * dim); @@ -146,308 +251,64 @@ __global__ void CrossEntropyExpHardLabel(T* loss, T* softmax, } } -/* - Core function of softmax with cross entropy forward - - softmax, SoftmaxMode=kSoftmax - - log softmax, SoftmaxMode=kLogSoftmax - - softmax with cross entropy hard label, SoftmaxMode=kCrossEntropy - The computation includes - - Compute max value: maxvalue_{i} = max_j src_{i,j} - - Compute sum of exp: s_{i} = sum_{j}{e^{src_{i,j} - maxvalue_{i}}} - - Compute: softmax_{i,j} = e^{src_{i,j} - maxvalue_{i}} / s_{i} - - Compute: logsoftmax_{i,j} = src_{i,j} - maxvalue_{i} - log(s_{i}) - - Compute: loss_{i} = -logsoftmax[i,label[i]] (Hard label) - This computation results from following formula: - softmax_{i,j} = e^{src_{i,j}} / sum_{j}{e^{src_{i,j}}} - = e^{src_{i,j} - maxvalue_{i}} - / sum_{j}{e^{src_{i,j} - maxvalue_{i}}} - = e^{src_{i,j} - maxvalue_{i}} / s_{i} - logsoftmax_{i,j} = log(softmax_{i,j}) - = src_{i,j} - maxvalue_{i} - log(s_{i}) - One warp (32 threads) is used to compute 1 or 2 batch (kBatchSize). - For reduction max (sum), firstly compute max (sum) to one warp, then use - shuffle api to compute max (sum) in one warp. -*/ -template -__global__ void WarpSoftmaxForward(T* loss, T* softmax, const T* src, - const LabelT* label, const int batch_size, - const int stride, const int element_count, - const int ignore_index) { - constexpr int kDimCeil = 1 << Log2Elements; - constexpr int kWarpSize = (kDimCeil < 32) ? kDimCeil : 32; - constexpr int kVSize = sizeof(VecT) / sizeof(T); - constexpr int kIterations = kDimCeil / kWarpSize; - constexpr int kIterationsV = - (kIterations >= kVSize) ? (kIterations / kVSize) : 1; - constexpr int kBatchSize = (kDimCeil <= 128) ? 2 : 1; - - int first_batch = (blockDim.y * blockIdx.x + threadIdx.y) * kBatchSize; +template +__device__ __forceinline__ AccT ThreadReduce(const T* input, + int size, + const int offset, + AccT init, + ReduceFunctor reducer) { + using VecT = kps::details::VectorType; + int tid = threadIdx.x; + AccT val = init; - // max index to read - int idx_max_v[kBatchSize]; -#pragma unroll - for (int i = 0; i < kBatchSize; i++) { - int idx_max = ((i + first_batch) < batch_size) ? element_count : 0; - idx_max_v[i] = idx_max / kVSize; + if (offset > 0) { + input -= offset; + size += offset; + if (tid >= offset) { + val = reducer(val, input[tid]); + } + size -= blockDim.x; + input += blockDim.x; } + int remain = size % (VecSize * blockDim.x); - // read data from global memory - AccT srcdata[kBatchSize][kIterationsV][kVSize]; + T ins[VecSize]; + VecT* ins_vec = reinterpret_cast(&ins); -#pragma unroll - for (int i = 0; i < kBatchSize; ++i) { -// read data to srcdata: - KVSize==1, - KVSize>1 -#pragma unroll - for (int it = 0; it < kIterationsV; ++it) { - int src_idx = threadIdx.x + it * kWarpSize; - if (kVSize == 1) { - if (src_idx < idx_max_v[i]) { - srcdata[i][it][0] = - static_cast(src[(first_batch + i) * stride + src_idx]); - } else { - srcdata[i][it][0] = -std::numeric_limits::infinity(); - } - } else { - const VecT* src_v = - reinterpret_cast(&src[(first_batch + i) * stride]); - if (src_idx < idx_max_v[i]) { - VecT srctmp = src_v[src_idx]; - const T* srcinptr = reinterpret_cast(&srctmp); -#pragma unroll - for (int s = 0; s < kVSize; s++) { - srcdata[i][it][s] = static_cast(srcinptr[s]); - } - } else { -#pragma unroll - for (int s = 0; s < kVSize; s++) { - srcdata[i][it][s] = -std::numeric_limits::infinity(); - } - } - } - } - } + // vector part + for (; VecSize * tid < (size - remain); tid += blockDim.x) { + *ins_vec = reinterpret_cast(input)[tid]; - // compute max value: maxvalue_{i} = max_j src_{i,j} - AccT max_value[kBatchSize]; -#pragma unroll - for (int i = 0; i < kBatchSize; ++i) { - // it = 0 - AccT valmax = srcdata[i][0][0]; #pragma unroll - for (int s = 1; s < kVSize; ++s) { - valmax = (valmax > srcdata[i][0][s]) ? valmax : srcdata[i][0][s]; + for (int i = 0; i < VecSize; ++i) { + val = reducer(val, ins[i]); } - max_value[i] = valmax; + } -// it = 1, 2, ... -#pragma unroll - for (int it = 1; it < kIterationsV; ++it) { - AccT valmax = srcdata[i][it][0]; -#pragma unroll - for (int s = 1; s < kVSize; ++s) { - valmax = (valmax > srcdata[i][it][s]) ? valmax : srcdata[i][it][s]; - } - max_value[i] = (max_value[i] > valmax) ? max_value[i] : valmax; - } + // scalar part + tid = size - remain + threadIdx.x; + for (; tid < size; tid += blockDim.x) { + val = reducer(val, input[tid]); } - phi::WarpReduceMax(max_value); + return val; +} - // compute sum: s_{i} = sum_{j}{ exp(src_{i,j} - maxvalue_{i} } - AccT sum[kBatchSize]; -#pragma unroll - for (int i = 0; i < kBatchSize; ++i) { - // it = 0 - if (mode == SoftmaxMode::kLogSoftmax || - mode == SoftmaxMode::kCrossEntropy) { - sum[i] = std::exp(srcdata[i][0][0] - max_value[i]); - } else { - srcdata[i][0][0] = std::exp(srcdata[i][0][0] - max_value[i]); - sum[i] = srcdata[i][0][0]; - } -#pragma unroll - for (int s = 1; s < kVSize; ++s) { - if (mode == SoftmaxMode::kLogSoftmax || - mode == SoftmaxMode::kCrossEntropy) { - sum[i] += std::exp(srcdata[i][0][s] - max_value[i]); +template +__device__ __forceinline__ void ComputeLoss(T* loss, + const T loss_value, + const int label_id, + const int64_t label_value, + const int tid, + const int vec_size, + const int offset, + const int ignore_index) { + int loss_id = vec_size * tid + offset; + if (IgnoreIndex) { + if (label_value == loss_id) { + if (label_value == ignore_index) { + loss[label_id] = static_cast(0.0f); } else { - srcdata[i][0][s] = std::exp(srcdata[i][0][s] - max_value[i]); - sum[i] += srcdata[i][0][s]; - } - } - -// it = 1, 2, ... -#pragma unroll - for (int it = 1; it < kIterationsV; ++it) { -#pragma unroll - for (int s = 0; s < kVSize; ++s) { - if (mode == SoftmaxMode::kLogSoftmax || - mode == SoftmaxMode::kCrossEntropy) { - sum[i] += std::exp(srcdata[i][it][s] - max_value[i]); - } else { - srcdata[i][it][s] = std::exp(srcdata[i][it][s] - max_value[i]); - sum[i] += srcdata[i][it][s]; - } - } - } - } - phi::WarpReduceSum(sum); - -// write data -#pragma unroll - for (int i = 0; i < kBatchSize; ++i) { - if (mode == SoftmaxMode::kLogSoftmax || - mode == SoftmaxMode::kCrossEntropy) { - sum[i] = std::log(sum[i]); - } - -#pragma unroll - for (int it = 0; it < kIterationsV; ++it) { - int idx = threadIdx.x + it * kWarpSize; - if (kVSize == 1) { // kVSize==1 - if (idx < idx_max_v[i]) { - if (mode == SoftmaxMode::kLogSoftmax) { // log softmax - softmax[(first_batch + i) * stride + idx] = - srcdata[i][it][0] - max_value[i] - sum[i]; - // softmax with cross entropy hard label - } else if (mode == SoftmaxMode::kCrossEntropy) { - AccT logsoftmax = srcdata[i][it][0] - max_value[i] - sum[i]; - // softmax - softmax[(first_batch + i) * stride + idx] = std::exp(logsoftmax); - // label - int loss_idx = (threadIdx.x + it * kWarpSize) * kVSize; - auto lbl = static_cast(label[first_batch + i]); - if (IgnoreIndex == true) { - // IgnoreIndex is true - if (lbl == loss_idx) { - if (lbl != ignore_index) { - loss[first_batch + i] = -logsoftmax; - } else { - loss[first_batch + i] = static_cast(0.0); - } - } - } else { - // IgnoreIndex is false - if (lbl >= 0 && lbl < element_count) { - if (lbl == loss_idx) { - loss[first_batch + i] = -logsoftmax; - } - } else { - loss[first_batch + i] = static_cast(0.0); - } - } - } else { // softmax - softmax[(first_batch + i) * stride + idx] = - srcdata[i][it][0] / sum[i]; - } - } else { - break; - } - } else { // KVSize>1 - VecT* softmax_v = - reinterpret_cast(&softmax[(first_batch + i) * stride]); - VecT tmpdata; - T* tmpptr = reinterpret_cast(&tmpdata); -#pragma unroll - for (int s = 0; s < kVSize; ++s) { - if (mode == SoftmaxMode::kLogSoftmax) { // log softmax - tmpptr[s] = srcdata[i][it][s] - max_value[i] - sum[i]; - // softmax with cross entropy hard label - } else if (mode == SoftmaxMode::kCrossEntropy) { - AccT logsoftmax = srcdata[i][it][s] - max_value[i] - sum[i]; - // softmax - tmpptr[s] = std::exp(logsoftmax); - // label - int loss_idx = (threadIdx.x + it * kWarpSize) * kVSize + s; - auto lbl = static_cast(label[first_batch + i]); - if (IgnoreIndex == true) { - // IgnoreIndex is true - if (lbl == loss_idx && lbl != ignore_index) { - loss[first_batch + i] = -logsoftmax; - } - } else { - // IgnoreIndex is false - if (lbl >= 0 && lbl < element_count) { - if (lbl == loss_idx) { - loss[first_batch + i] = -logsoftmax; - } - } else { - loss[first_batch + i] = static_cast(0.0); - } - } - } else { // softmax - tmpptr[s] = srcdata[i][it][s] / sum[i]; - } - } - if (idx < idx_max_v[i]) { - softmax_v[idx] = tmpdata; - } else { - break; - } - } - } - } -} - -#define SOFTMAX_WARP_FORWARD_CASE(Log2Elements, LabelT, VecT, AccT) \ - case Log2Elements: \ - WarpSoftmaxForward<<>>( \ - loss, softmax, src, label, batch_size, stride, element_count, \ - ignore_index); \ - break; - -/* - Wrapper of softmax with cross entropy forward hard label. -*/ -template -void SwitchWarpSoftmaxForward(T* loss, T* softmax, const T* src, - const LabelT* label, const int batch_size, - const int stride, const int element_count, - const int ignore_index, gpuStream_t stream) { - using AccT = typename details::MPTypeTrait::Type; - - // use 128 threads per block to maximimize gpu utilization - const int log2_elements = static_cast(Log2Ceil(element_count)); - const int kDimCeil = 1 << log2_elements; - int kWarpSize = (kDimCeil < 32) ? kDimCeil : 32; - int batches_per_warp = (kDimCeil <= 128) ? 2 : 1; - constexpr int threads_per_block = 128; - int warps_per_block = (threads_per_block / kWarpSize); - int batches_per_block = warps_per_block * batches_per_warp; - int blocks = (batch_size + batches_per_block - 1) / batches_per_block; - dim3 threads(kWarpSize, warps_per_block, 1); - - switch (log2_elements) { - SOFTMAX_WARP_FORWARD_CASE(0, LabelT, T, AccT); - SOFTMAX_WARP_FORWARD_CASE(1, LabelT, T, AccT); - SOFTMAX_WARP_FORWARD_CASE(2, LabelT, T, AccT); - SOFTMAX_WARP_FORWARD_CASE(3, LabelT, T, AccT); - SOFTMAX_WARP_FORWARD_CASE(4, LabelT, T, AccT); - SOFTMAX_WARP_FORWARD_CASE(5, LabelT, T, AccT); - SOFTMAX_WARP_FORWARD_CASE(6, LabelT, T, AccT); - SOFTMAX_WARP_FORWARD_CASE(7, LabelT, T, AccT); - SOFTMAX_WARP_FORWARD_CASE(8, LabelT, T, AccT); - SOFTMAX_WARP_FORWARD_CASE(9, LabelT, T, AccT); - default: - break; - } -} - -template -__device__ __forceinline__ void ComputeLoss(T* loss, const T loss_value, - const int label_id, - const int64_t label_value, - const int tid, const int vec_size, - const int offset, - const int ignore_index) { - int loss_id = vec_size * tid + offset; - if (IgnoreIndex) { - if (label_value == loss_id) { - if (label_value == ignore_index) { - loss[label_id] = static_cast(0.0f); - } else { - loss[label_id] = loss_value; + loss[label_id] = loss_value; } } } else { @@ -457,51 +318,19 @@ __device__ __forceinline__ void ComputeLoss(T* loss, const T loss_value, } } -template -__device__ __forceinline__ AccT ThreadReduce(const T* input, int size, - const int offset, AccT init, - ReduceFunctor reducer) { - using VecT = kps::details::VectorType; - int tid = threadIdx.x; - AccT val = init; - - if (offset > 0) { - input -= offset; - size += offset; - if (tid >= offset) { - val = reducer(val, input[tid]); - } - size -= blockDim.x; - input += blockDim.x; - } - int remain = size % (VecSize * blockDim.x); - - T ins[VecSize]; - VecT* ins_vec = reinterpret_cast(&ins); - - // vector part - for (; VecSize * tid < (size - remain); tid += blockDim.x) { - *ins_vec = reinterpret_cast(input)[tid]; - -#pragma unroll - for (int i = 0; i < VecSize; ++i) { - val = reducer(val, ins[i]); - } - } - - // scalar part - tid = size - remain + threadIdx.x; - for (; tid < size; tid += blockDim.x) { - val = reducer(val, input[tid]); - } - return val; -} - -template __device__ __forceinline__ void VectorizedSoftmaxForwardImpl( - T* loss, T* softmax, const T* logits, const LabelT* label, int size, - const int offset, const phi::LogSoftmaxForwardFunctor& func, + T* loss, + T* softmax, + const T* logits, + const LabelT* label, + int size, + const int offset, + const phi::LogSoftmaxForwardFunctor& func, const int ignore_index) { using VecT = kps::details::VectorType; int tid = threadIdx.x; @@ -520,9 +349,14 @@ __device__ __forceinline__ void VectorizedSoftmaxForwardImpl( softmax[tid] = static_cast(std::exp(log_softmax)); // loss if (label_valid) { - ComputeLoss(loss, static_cast(-log_softmax), - label_id, label_value, tid, 1, - loss_id_offset, ignore_index); + ComputeLoss(loss, + static_cast(-log_softmax), + label_id, + label_value, + tid, + 1, + loss_id_offset, + ignore_index); } } size -= blockDim.x; @@ -550,9 +384,14 @@ __device__ __forceinline__ void VectorizedSoftmaxForwardImpl( // loss if (label_valid) { - ComputeLoss(loss, static_cast(-log_softmax), - label_id, label_value, tid, VecSize, - loss_id_offset + i, ignore_index); + ComputeLoss(loss, + static_cast(-log_softmax), + label_id, + label_value, + tid, + VecSize, + loss_id_offset + i, + ignore_index); } } @@ -568,8 +407,13 @@ __device__ __forceinline__ void VectorizedSoftmaxForwardImpl( // loss if (label_valid) { - ComputeLoss(loss, static_cast(-log_softmax), label_id, - label_value, tid, 1, loss_id_offset, + ComputeLoss(loss, + static_cast(-log_softmax), + label_id, + label_value, + tid, + 1, + loss_id_offset, ignore_index); } } @@ -580,11 +424,19 @@ __device__ __forceinline__ void VectorizedSoftmaxForwardImpl( } } -template __device__ __forceinline__ void ScalarSoftmaxForwardImpl( - T* loss, T* softmax, const T* logits, const LabelT* label, const int size, - const phi::LogSoftmaxForwardFunctor& func, const int ignore_index) { + T* loss, + T* softmax, + const T* logits, + const LabelT* label, + const int size, + const phi::LogSoftmaxForwardFunctor& func, + const int ignore_index) { int tid = threadIdx.x; int remain = size % (VecSize * blockDim.x); int label_id = blockIdx.x; @@ -605,8 +457,13 @@ __device__ __forceinline__ void ScalarSoftmaxForwardImpl( softmax[tid + i * blockDim.x] = static_cast(std::exp(log_softmax)); // loss if (label_valid) { - ComputeLoss(loss, static_cast(-log_softmax), - label_id, label_value, tid, VecSize, i, + ComputeLoss(loss, + static_cast(-log_softmax), + label_id, + label_value, + tid, + VecSize, + i, ignore_index); } } @@ -618,8 +475,14 @@ __device__ __forceinline__ void ScalarSoftmaxForwardImpl( softmax[tid] = static_cast(std::exp(log_softmax)); // loss if (label_valid) { - ComputeLoss(loss, static_cast(-log_softmax), label_id, - label_value, tid, 1, 0, ignore_index); + ComputeLoss(loss, + static_cast(-log_softmax), + label_id, + label_value, + tid, + 1, + 0, + ignore_index); } } @@ -629,11 +492,17 @@ __device__ __forceinline__ void ScalarSoftmaxForwardImpl( } } -template -__global__ void VectorizedSoftmaxForward(T* loss, T* softmax, const T* logits, +__global__ void VectorizedSoftmaxForward(T* loss, + T* softmax, + const T* logits, const LabelT* label, - const int high_dim, const int mid_dim, + const int high_dim, + const int mid_dim, const int ignore_index) { using VecT = kps::details::VectorType; @@ -646,14 +515,20 @@ __global__ void VectorizedSoftmaxForward(T* loss, T* softmax, const T* logits, // 1. reduce max AccT max = ThreadReduce>( - logits, mid_dim, input_offset, -std::numeric_limits::infinity(), + logits, + mid_dim, + input_offset, + -std::numeric_limits::infinity(), kps::MaxFunctor()); max = kps::details::BlockXReduce>( max, kps::MaxFunctor()); // 2. reduce sum AccT sum = ThreadReduce>( - logits, mid_dim, input_offset, static_cast(0), + logits, + mid_dim, + input_offset, + static_cast(0), ExpAddFunctor(max)); sum = kps::details::BlockXReduce>( sum, kps::AddFunctor()); @@ -662,7 +537,13 @@ __global__ void VectorizedSoftmaxForward(T* loss, T* softmax, const T* logits, phi::LogSoftmaxForwardFunctor func(max, sum); if (input_offset == output_offset) { VectorizedSoftmaxForwardImpl( - loss, softmax, logits, label, mid_dim, input_offset, func, + loss, + softmax, + logits, + label, + mid_dim, + input_offset, + func, ignore_index); } else { ScalarSoftmaxForwardImpl( @@ -670,229 +551,26 @@ __global__ void VectorizedSoftmaxForward(T* loss, T* softmax, const T* logits, } } -template -void LaunchVectorizedSoftmaxForward(T* loss, T* softmax, const T* logits, - const LabelT* label, const int high_dim, - const int mid_dim, const int ignore_index, - gpuStream_t stream) { - using AccT = typename details::MPTypeTrait::Type; - constexpr int vec_size = sizeof(float4) / sizeof(T); - const int max_num_threads = 1024; - int max_block_size = std::min(mid_dim / vec_size, max_num_threads); - if (vec_size > 1) { - max_block_size /= 2; - } - - int block_size = 1; - while (block_size < max_block_size) { - block_size *= 2; - } - block_size = std::max(block_size, kps::details::kWarpSize); - dim3 grids(high_dim); - dim3 blocks(block_size); - VectorizedSoftmaxForward<<>>( - loss, softmax, logits, label, high_dim, mid_dim, ignore_index); -} - /* - Wrapper of softmax with cross entropy hard label. - - SwitchWarpSoftmaxForward for small size when axis == -1 - - LaunchVectorizedSoftmaxForward for large size when axis == -1 - - cudnn function for axis != -1 +Core function of softmax with cross entropy forward soft label. +The computation includes + - Compute maximum of batch: maxvalue_{i} = max_j src_{i,j} + - Compute sum of exp batch: s_{i} = sum_{j}{ exp(src_{i,j} - maxvalue_{i} } + - Compute: sum of - sum_{j}{ label_{i,j} * (src_{i,j} - maxvalue_{i} - +log(sum[i]))} +One warp (32 threads) is used to compute 1 or 2 batch (kBatchSize). +For reduction max (sum), firstly compute max (sum) to one warp, then use shuffle +api to compute max (sum) in one warp. */ -template -static void SoftmaxWithCrossEntropyHardLabel( - const platform::CUDADeviceContext& ctx, int rank, int axis, - const T* logits_data, const LabelT* labels_data, T* loss_data, - T* softmax_data, int N, int dim, int D, const int ignore_index) { - auto stream = ctx.stream(); - constexpr int max_dim = 320; - if (D == 1) { - if (dim <= max_dim) { // small size - const SoftmaxMode mode = SoftmaxMode::kCrossEntropy; - SwitchWarpSoftmaxForward( - loss_data, softmax_data, logits_data, labels_data, N, dim, dim, - ignore_index, stream); - } else { // large size - LaunchVectorizedSoftmaxForward( - loss_data, softmax_data, logits_data, labels_data, N, dim, - ignore_index, stream); - } - } else { - ScopedTensorDescriptor desc; - std::vector tensor_dims = {N, dim, D, 1}; - DataLayout layout = DataLayout::kNCHW; -#ifdef PADDLE_WITH_HIP - miopenTensorDescriptor_t descp = desc.descriptor(layout, tensor_dims); -#else - cudnnTensorDescriptor_t descp = desc.descriptor(layout, tensor_dims); -#endif - - auto handle = ctx.cudnn_handle(); - -#ifdef PADDLE_WITH_HIP - auto mode = axis == rank - 1 ? MIOPEN_SOFTMAX_MODE_INSTANCE - : MIOPEN_SOFTMAX_MODE_CHANNEL; - PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::miopenSoftmaxForward_V2( - handle, platform::CudnnDataType::kOne(), descp, logits_data, - platform::CudnnDataType::kZero(), descp, softmax_data, - MIOPEN_SOFTMAX_LOG, mode)); -#else - auto mode = axis == rank - 1 ? CUDNN_SOFTMAX_MODE_INSTANCE - : CUDNN_SOFTMAX_MODE_CHANNEL; - PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cudnnSoftmaxForward( - handle, CUDNN_SOFTMAX_LOG, mode, platform::CudnnDataType::kOne(), - descp, logits_data, platform::CudnnDataType::kZero(), descp, - softmax_data)); -#endif - int threads = 128; - int blocks = (N * dim * D + threads - 1) / threads; - // compute cross entropy, input is log softmax - CrossEntropyExpHardLabel<<>>( - loss_data, softmax_data, labels_data, N, dim, D, ignore_index); - } -} - -/* - Wrapper of softmax with cross entropy grad hard label. -*/ -template -__global__ void SoftmaxWithCrossEntropyGradHardLabel( - T* logits_grad, const T* loss_grad, const T* softmax, const LabelT* labels, - const int64_t n, const int64_t dim, const int64_t d, - const int ignore_index) { - int64_t idx = blockIdx.x * blockDim.x + threadIdx.x; - int64_t idx_n = idx / (d * dim); - int64_t idx_dim = (idx / d) % dim; - int64_t idx_d = idx % d; - int64_t ids = idx_n * d + idx_d; - - if (idx < n * dim * d) { - auto lbl = static_cast(labels[ids]); - if (lbl == ignore_index) { - logits_grad[idx] = static_cast(0.0); - } else if (lbl == idx_dim) { - logits_grad[idx] = (softmax[idx] - static_cast(1.0)) * loss_grad[ids]; - } else { - logits_grad[idx] = softmax[idx] * loss_grad[ids]; - } - } -} - -/* - Cross entropy soft label with dynamic size on axis (log2_elements is - varibale). - - if the input is softmax,compute loss with softmax - - if the input is log_softmax, compute loss with log_softmax and update - softmax -*/ -template -__global__ void CrossEntropySoftLabel(T* loss, T* softmaxwrt, const T* softmax, - const T* labels, const int n, - const int dim, const int d, - int log2_elements) { - const int kDimCeil = 1 << log2_elements; - const int kVSize = sizeof(VecT) / sizeof(T); - -#ifdef __HIPCC__ - const int kThreadPerBlock = 256; -#else - const int kThreadPerBlock = 512; -#endif - const int kBatchPerBlock = 1; - const int kWarpSize = 32; // (dim < 32) ? dim : 32; - const int kBatchSize = 1; - const int kThreadPerBatch = kThreadPerBlock / kBatchPerBlock; - const int kWarpPerBatch = kThreadPerBatch / kWarpSize; - - const int kIterations = (dim + kThreadPerBatch - 1) / kThreadPerBatch; - const int kIterationsV = (kIterations >= kVSize) ? (kIterations / kVSize) : 1; - - const int first_batch = (blockDim.y * blockIdx.x + threadIdx.y) * kBatchSize; - - T sum[kBatchSize]{static_cast(0.0)}; -#pragma unroll - for (int i = 0; i < kBatchSize; ++i) { - int ids = first_batch + i; - if (ids >= n * d) break; - int idx_n = ids / d; - int idx_d = ids % d; -#pragma unroll - for (int it = 0; it < kIterations; ++it) { - int idx_dim = it * kThreadPerBatch + threadIdx.x; - int idx = idx_n * dim * d + idx_dim * d + idx_d; - - if (idx_n < n && idx_dim < dim) { - VecT softmaxdata; - if (InLogMode) { - softmaxdata = reinterpret_cast(&softmaxwrt[idx])[0]; - } else { - softmaxdata = reinterpret_cast(&softmax[idx])[0]; - } - VecT labelsdata = reinterpret_cast(&labels[idx])[0]; - T* softmaxptr = reinterpret_cast(&softmaxdata); - T* labelsptr = reinterpret_cast(&labelsdata); -#pragma unroll - for (int s = 0; s < kVSize; s++) { - if (InLogMode) { - sum[i] -= softmaxptr[s] * labelsptr[s]; - softmaxptr[s] = Exp(softmaxptr[s]); - } else { - sum[i] -= Log(softmaxptr[s]) * labelsptr[s]; - } - } - if (InLogMode) { - reinterpret_cast(&softmaxwrt[idx])[0] = softmaxdata; - } - } - } - } - phi::WarpReduceSum(sum); - __syncthreads(); - - __shared__ T sumshare[kWarpPerBatch][kBatchPerBlock][kBatchSize]; - if (threadIdx.x % kWarpSize == 0) { -#pragma unroll - for (int i = 0; i < kBatchSize; i++) { - sumshare[threadIdx.x / kWarpSize][threadIdx.y][i] = sum[i]; - } - } - __syncthreads(); - - // write - if (threadIdx.x == 0) { - for (int i = 0; i < kBatchSize; i++) { - int ids = first_batch + i; - if (ids < n * d) { - loss[ids] = sumshare[0][threadIdx.y][i]; - for (int s = 1; s < kWarpPerBatch; s++) { - loss[ids] += sumshare[s][threadIdx.y][i]; - } - } - } - } -} - -/* -Core function of softmax with cross entropy forward soft label. -The computation includes - - Compute maximum of batch: maxvalue_{i} = max_j src_{i,j} - - Compute sum of exp batch: s_{i} = sum_{j}{ exp(src_{i,j} - maxvalue_{i} } - - Compute: sum of - sum_{j}{ label_{i,j} * (src_{i,j} - maxvalue_{i} - -log(sum[i]))} -One warp (32 threads) is used to compute 1 or 2 batch (kBatchSize). -For reduction max (sum), firstly compute max (sum) to one warp, then use shuffle -api to compute max (sum) in one warp. -*/ -template -__global__ void WarpSoftmaxForwardSoftLabel(T* loss, T* softmax, const T* src, - const T* label, - const int batch_size, - const int stride, - const int element_count) { - const bool LogMode = true; +template +__global__ void WarpSoftmaxForwardSoftLabel(T* loss, + T* softmax, + const T* src, + const T* label, + const int batch_size, + const int stride, + const int element_count) { + const bool LogMode = true; constexpr int kDimCeil = 1 << Log2Elements; constexpr int kWarpSize = (kDimCeil < 32) ? kDimCeil : 32; @@ -1030,7 +708,9 @@ __global__ void WarpSoftmaxForwardSoftLabel(T* loss, T* softmax, const T* src, #define SOFTMAX_WARP_FORWARD_SOFT_CASE(Log2Elements, VecT, AccT) \ case Log2Elements: \ - WarpSoftmaxForwardSoftLabel<<>>( \ loss, softmax, src, label, batch_size, stride, element_count); \ break; @@ -1039,13 +719,18 @@ __global__ void WarpSoftmaxForwardSoftLabel(T* loss, T* softmax, const T* src, Wrapper of softmax with cross entropy forward soft label. */ template -void SwitchWarpSoftmaxForwardSoftLabel(const int blocks, const dim3 threads, - gpuStream_t stream, T* loss, T* softmax, - const T* src, const T* label, - const int batch_size, const int stride, +void SwitchWarpSoftmaxForwardSoftLabel(const int blocks, + const dim3 threads, + gpuStream_t stream, + T* loss, + T* softmax, + const T* src, + const T* label, + const int batch_size, + const int stride, const int element_count, const int log2_elements) { - using AccT = typename details::MPTypeTrait::Type; + using AccT = typename dtype::MPTypeTrait::Type; switch (log2_elements) { SOFTMAX_WARP_FORWARD_SOFT_CASE(0, T, AccT); SOFTMAX_WARP_FORWARD_SOFT_CASE(1, T, AccT); @@ -1063,10 +748,16 @@ void SwitchWarpSoftmaxForwardSoftLabel(const int blocks, const dim3 threads, } template -static void SoftmaxWithCrossEntropySoftLabel( - const platform::CUDADeviceContext& ctx, const int rank, const int axis, - const T* logits_data, const T* labels_data, T* softmax_data, T* loss_data, - int N, int dim, int D) { +static void SoftmaxWithCrossEntropySoftLabel(const GPUContext& dev_ctx, + const int rank, + const int axis, + const T* logits_data, + const T* labels_data, + T* softmax_data, + T* loss_data, + int N, + int dim, + int D) { #ifdef __HIPCC__ constexpr int kMaxBlockDim = 256; #else @@ -1081,7 +772,7 @@ static void SoftmaxWithCrossEntropySoftLabel( const int kDimLog2 = static_cast(Log2Ceil(dim)); const int kDimCeil = 1 << kDimLog2; - auto stream = ctx.stream(); + auto stream = dev_ctx.stream(); if (D == 1 && dim <= max_dim) { int kWarpSize = (kDimCeil < 32) ? kDimCeil : 32; @@ -1094,35 +785,55 @@ static void SoftmaxWithCrossEntropySoftLabel( int blocks = (N + batches_per_block - 1) / batches_per_block; dim3 threads(kWarpSize, warps_per_block, 1); - SwitchWarpSoftmaxForwardSoftLabel(blocks, threads, stream, loss_data, - softmax_data, logits_data, labels_data, - N, dim, dim, kDimLog2); + SwitchWarpSoftmaxForwardSoftLabel(blocks, + threads, + stream, + loss_data, + softmax_data, + logits_data, + labels_data, + N, + dim, + dim, + kDimLog2); } else { ScopedTensorDescriptor desc; std::vector tensor_dims = {N, dim, D, 1}; - DataLayout layout = DataLayout::kNCHW; + GPUDNNDataLayout layout = GPUDNNDataLayout::kNCHW; #ifdef PADDLE_WITH_HIP miopenTensorDescriptor_t descp = desc.descriptor(layout, tensor_dims); #else cudnnTensorDescriptor_t descp = desc.descriptor(layout, tensor_dims); #endif - auto handle = ctx.cudnn_handle(); + auto handle = dev_ctx.cudnn_handle(); #ifdef PADDLE_WITH_HIP auto mode = axis == rank - 1 ? MIOPEN_SOFTMAX_MODE_INSTANCE : MIOPEN_SOFTMAX_MODE_CHANNEL; - PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::miopenSoftmaxForward_V2( - handle, platform::CudnnDataType::kOne(), descp, logits_data, - platform::CudnnDataType::kZero(), descp, softmax_data, - MIOPEN_SOFTMAX_LOG, mode)); + PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::miopenSoftmaxForward_V2( + handle, + paddle::platform::CudnnDataType::kOne(), + descp, + logits_data, + paddle::platform::CudnnDataType::kZero(), + descp, + softmax_data, + MIOPEN_SOFTMAX_LOG, + mode)); #else auto mode = axis == rank - 1 ? CUDNN_SOFTMAX_MODE_INSTANCE : CUDNN_SOFTMAX_MODE_CHANNEL; - PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cudnnSoftmaxForward( - handle, CUDNN_SOFTMAX_LOG, mode, platform::CudnnDataType::kOne(), - descp, logits_data, platform::CudnnDataType::kZero(), descp, + PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cudnnSoftmaxForward( + handle, + CUDNN_SOFTMAX_LOG, + mode, + paddle::platform::CudnnDataType::kOne(), + descp, + logits_data, + paddle::platform::CudnnDataType::kZero(), + descp, softmax_data)); #endif @@ -1143,351 +854,712 @@ static void SoftmaxWithCrossEntropySoftLabel( } } -template -__global__ void SoftCrossEntropyGradientKernel(T* logit_grad, - const T* loss_grad, - const T* labels, const int64_t n, - const int64_t d, - const int64_t remain) { - int64_t ids = blockIdx.x * blockDim.x + threadIdx.x; - if (ids < n * d) { - int64_t idx_n = ids / d; - int64_t idx_remain = ids % remain; - int64_t idx_loss = idx_n * remain + idx_remain; - logit_grad[ids] = loss_grad[idx_loss] * (logit_grad[ids] - labels[ids]); - } -} +/* + Core function of softmax with cross entropy forward + - softmax, SoftmaxMode=kSoftmax + - log softmax, SoftmaxMode=kLogSoftmax + - softmax with cross entropy hard label, SoftmaxMode=kCrossEntropy + The computation includes + - Compute max value: maxvalue_{i} = max_j src_{i,j} + - Compute sum of exp: s_{i} = sum_{j}{e^{src_{i,j} - maxvalue_{i}}} + - Compute: softmax_{i,j} = e^{src_{i,j} - maxvalue_{i}} / s_{i} + - Compute: logsoftmax_{i,j} = src_{i,j} - maxvalue_{i} - log(s_{i}) + - Compute: loss_{i} = -logsoftmax[i,label[i]] (Hard label) + This computation results from following formula: + softmax_{i,j} = e^{src_{i,j}} / sum_{j}{e^{src_{i,j}}} + = e^{src_{i,j} - maxvalue_{i}} + / sum_{j}{e^{src_{i,j} - maxvalue_{i}}} + = e^{src_{i,j} - maxvalue_{i}} / s_{i} + logsoftmax_{i,j} = log(softmax_{i,j}) + = src_{i,j} - maxvalue_{i} - log(s_{i}) + One warp (32 threads) is used to compute 1 or 2 batch (kBatchSize). + For reduction max (sum), firstly compute max (sum) to one warp, then use + shuffle api to compute max (sum) in one warp. +*/ +template +__global__ void WarpSoftmaxForward(T* loss, + T* softmax, + const T* src, + const LabelT* label, + const int batch_size, + const int stride, + const int element_count, + const int ignore_index) { + constexpr int kDimCeil = 1 << Log2Elements; + constexpr int kWarpSize = (kDimCeil < 32) ? kDimCeil : 32; + constexpr int kVSize = sizeof(VecT) / sizeof(T); + constexpr int kIterations = kDimCeil / kWarpSize; + constexpr int kIterationsV = + (kIterations >= kVSize) ? (kIterations / kVSize) : 1; + constexpr int kBatchSize = (kDimCeil <= 128) ? 2 : 1; -template -__global__ void SoftLabelCrossEntropyGradientKernel(T* logit_grad, - const T* loss_grad, - const T* labels, - const int n, const int d, - const int remain) { - int ids = blockIdx.x * blockDim.x + threadIdx.x; - if (ids < n * d) { - int idx_n = ids / d; - int idx_remain = ids % remain; - int idx_loss = idx_n * remain + idx_remain; - logit_grad[ids] = loss_grad[idx_loss] * (-labels[ids] / logit_grad[ids]); + int first_batch = (blockDim.y * blockIdx.x + threadIdx.y) * kBatchSize; + + // max index to read + int idx_max_v[kBatchSize]; +#pragma unroll + for (int i = 0; i < kBatchSize; i++) { + int idx_max = ((i + first_batch) < batch_size) ? element_count : 0; + idx_max_v[i] = idx_max / kVSize; } -} -template -__global__ void HardLabelCrossEntropyGradientKernel(T* logit_grad, - const LabelT* labels, - const int n, const int d, - const int remain, - const int ignore_index) { - CUDA_KERNEL_LOOP(index, n * remain) { - int idx_n = index / remain; - int idx_remain = index % remain; - int tmp = static_cast(labels[index]); - int idx = idx_n * d + tmp * remain + idx_remain; - if (ignore_index != tmp) { - logit_grad[idx] = -static_cast(1.) / logit_grad[idx]; + // read data from global memory + AccT srcdata[kBatchSize][kIterationsV][kVSize]; + +#pragma unroll + for (int i = 0; i < kBatchSize; ++i) { +// read data to srcdata: - KVSize==1, - KVSize>1 +#pragma unroll + for (int it = 0; it < kIterationsV; ++it) { + int src_idx = threadIdx.x + it * kWarpSize; + if (kVSize == 1) { + if (src_idx < idx_max_v[i]) { + srcdata[i][it][0] = + static_cast(src[(first_batch + i) * stride + src_idx]); + } else { + srcdata[i][it][0] = -std::numeric_limits::infinity(); + } + } else { + const VecT* src_v = + reinterpret_cast(&src[(first_batch + i) * stride]); + if (src_idx < idx_max_v[i]) { + VecT srctmp = src_v[src_idx]; + const T* srcinptr = reinterpret_cast(&srctmp); +#pragma unroll + for (int s = 0; s < kVSize; s++) { + srcdata[i][it][s] = static_cast(srcinptr[s]); + } + } else { +#pragma unroll + for (int s = 0; s < kVSize; s++) { + srcdata[i][it][s] = -std::numeric_limits::infinity(); + } + } + } } } -} -template -__global__ void ScaleCrossEntropyGradient(T* logit_grad, const T* loss_grad, - const int num, const int d, - const int remain, - const LabelT* labels, - const int ignore_index) { - CUDA_KERNEL_LOOP(index, num) { - int idx_n = index / d; - int idx_remain = index % remain; - int idx_lbl = idx_n * remain + idx_remain; - int k = (index % d) / remain; - auto lbl = static_cast(labels[idx_lbl]); - if (lbl == ignore_index || lbl != k) { - logit_grad[index] = static_cast(0.); - } else { - logit_grad[index] *= loss_grad[idx_lbl]; + // compute max value: maxvalue_{i} = max_j src_{i,j} + AccT max_value[kBatchSize]; +#pragma unroll + for (int i = 0; i < kBatchSize; ++i) { + // it = 0 + AccT valmax = srcdata[i][0][0]; +#pragma unroll + for (int s = 1; s < kVSize; ++s) { + valmax = (valmax > srcdata[i][0][s]) ? valmax : srcdata[i][0][s]; } - } -} + max_value[i] = valmax; -template -class SoftmaxWithCrossEntropyCUDAKernel : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext& context) const override { - RunSoftmaxWithCrossEntropyFunctor(context, *this); +// it = 1, 2, ... +#pragma unroll + for (int it = 1; it < kIterationsV; ++it) { + AccT valmax = srcdata[i][it][0]; +#pragma unroll + for (int s = 1; s < kVSize; ++s) { + valmax = (valmax > srcdata[i][it][s]) ? valmax : srcdata[i][it][s]; + } + max_value[i] = (max_value[i] > valmax) ? max_value[i] : valmax; + } } + phi::WarpReduceMax(max_value); - template - static void Apply(const framework::ExecutionContext& context, - const framework::Tensor& labels, const bool soft_label) { - PADDLE_ENFORCE_EQ( - platform::is_gpu_place(context.GetPlace()), true, - platform::errors::Unavailable("softmax_with_cross_entropy operator's " - "CUDA kernel only runs on GPU device.")); - const bool use_softmax = context.Attr("use_softmax"); - - // do not with softmax op, and input is softmax - if (!use_softmax) { - const Tensor* softmax = context.Input("Logits"); - Tensor* softmax_out = context.Output("Softmax"); - Tensor* loss = context.Output("Loss"); - - const int rank = softmax->dims().size(); - const int axis = - phi::funcs::CanonicalAxis(context.Attr("axis"), rank); - const int axis_dim = softmax->dims()[axis]; - - const int n = phi::funcs::SizeToAxis(axis, softmax->dims()); - const int d = phi::funcs::SizeFromAxis(axis, softmax->dims()); - - auto* softmax_out_data = - softmax_out->template mutable_data(context.GetPlace()); - auto* loss_data = loss->template mutable_data(context.GetPlace()); - - phi::funcs::SetConstant set_constant; - set_constant(context.cuda_device_context(), loss, static_cast(0)); - if (axis_dim == 1) { - set_constant(context.cuda_device_context(), softmax_out, - static_cast(1)); - return; + // compute sum: s_{i} = sum_{j}{ exp(src_{i,j} - maxvalue_{i} } + AccT sum[kBatchSize]; +#pragma unroll + for (int i = 0; i < kBatchSize; ++i) { + // it = 0 + if (mode == SoftmaxMode::kLogSoftmax || + mode == SoftmaxMode::kCrossEntropy) { + sum[i] = std::exp(srcdata[i][0][0] - max_value[i]); + } else { + srcdata[i][0][0] = std::exp(srcdata[i][0][0] - max_value[i]); + sum[i] = srcdata[i][0][0]; + } +#pragma unroll + for (int s = 1; s < kVSize; ++s) { + if (mode == SoftmaxMode::kLogSoftmax || + mode == SoftmaxMode::kCrossEntropy) { + sum[i] += std::exp(srcdata[i][0][s] - max_value[i]); + } else { + srcdata[i][0][s] = std::exp(srcdata[i][0][s] - max_value[i]); + sum[i] += srcdata[i][0][s]; } + } - auto ignore_index = context.Attr("ignore_index"); - - Tensor softmax_2d, labels_2d, loss_2d, softmax_out_2d; - softmax_2d.ShareDataWith(*softmax).Resize({n, d}); - labels_2d.ShareDataWith(labels).Resize({n, labels.numel() / n}); - loss_2d.ShareDataWith(*loss).Resize({n, 1}); - softmax_out_2d.ShareDataWith(*softmax_out).Resize({n, d}); - - // math::CrossEntropyFunctor support axis is the last - if (axis == -1) { - math::CrossEntropyFunctor()( - context.cuda_device_context(), &loss_2d, &softmax_2d, &labels_2d, - soft_label, ignore_index, axis_dim); - return; +// it = 1, 2, ... +#pragma unroll + for (int it = 1; it < kIterationsV; ++it) { +#pragma unroll + for (int s = 0; s < kVSize; ++s) { + if (mode == SoftmaxMode::kLogSoftmax || + mode == SoftmaxMode::kCrossEntropy) { + sum[i] += std::exp(srcdata[i][it][s] - max_value[i]); + } else { + srcdata[i][it][s] = std::exp(srcdata[i][it][s] - max_value[i]); + sum[i] += srcdata[i][it][s]; + } } + } + } + phi::WarpReduceSum(sum); - // if axis is not the last, we need a new impliment - if (soft_label) { - auto* logits_data = softmax->template data(); - auto* labels_data = labels.template data(); +// write data +#pragma unroll + for (int i = 0; i < kBatchSize; ++i) { + if (mode == SoftmaxMode::kLogSoftmax || + mode == SoftmaxMode::kCrossEntropy) { + sum[i] = std::log(sum[i]); + } - const int kDimLog2 = static_cast(Log2Ceil(axis_dim)); - const int kDimCeil = 1 << kDimLog2; -#ifdef __HIPCC__ - int kThreadPerBlock = 256; -#else - int kThreadPerBlock = 512; -#endif - int kBatchPerBlock = 1; - int blocks = (n * d + kBatchPerBlock - 1) / kBatchPerBlock; - dim3 threads(kThreadPerBlock / kBatchPerBlock, kBatchPerBlock, 1); - - CrossEntropySoftLabel<<< - blocks, threads, 0, context.cuda_device_context().stream()>>>( - loss_data, NULL, logits_data, labels_data, n, axis_dim, - d / axis_dim, kDimLog2); - } else { // HardLabel - auto* logits_data = softmax->template data(); - auto* labels_data = labels.template data(); - int threads = 128; - int blocks = (n * d / axis_dim + threads - 1) / threads; - if (ignore_index >= 0 && ignore_index < axis_dim) { - CrossEntropyHardLabel<<< - blocks, threads, 0, context.cuda_device_context().stream()>>>( - loss_data, logits_data, labels_data, n, axis_dim, d / axis_dim, - ignore_index); +#pragma unroll + for (int it = 0; it < kIterationsV; ++it) { + int idx = threadIdx.x + it * kWarpSize; + if (kVSize == 1) { // kVSize==1 + if (idx < idx_max_v[i]) { + if (mode == SoftmaxMode::kLogSoftmax) { // log softmax + softmax[(first_batch + i) * stride + idx] = + srcdata[i][it][0] - max_value[i] - sum[i]; + // softmax with cross entropy hard label + } else if (mode == SoftmaxMode::kCrossEntropy) { + AccT logsoftmax = srcdata[i][it][0] - max_value[i] - sum[i]; + // softmax + softmax[(first_batch + i) * stride + idx] = std::exp(logsoftmax); + // label + int loss_idx = (threadIdx.x + it * kWarpSize) * kVSize; + auto lbl = static_cast(label[first_batch + i]); + if (IgnoreIndex == true) { + // IgnoreIndex is true + if (lbl == loss_idx) { + if (lbl != ignore_index) { + loss[first_batch + i] = -logsoftmax; + } else { + loss[first_batch + i] = static_cast(0.0); + } + } + } else { + // IgnoreIndex is false + if (lbl >= 0 && lbl < element_count) { + if (lbl == loss_idx) { + loss[first_batch + i] = -logsoftmax; + } + } else { + loss[first_batch + i] = static_cast(0.0); + } + } + } else { // softmax + softmax[(first_batch + i) * stride + idx] = + srcdata[i][it][0] / sum[i]; + } + } else { + break; + } + } else { // KVSize>1 + VecT* softmax_v = + reinterpret_cast(&softmax[(first_batch + i) * stride]); + VecT tmpdata; + T* tmpptr = reinterpret_cast(&tmpdata); +#pragma unroll + for (int s = 0; s < kVSize; ++s) { + if (mode == SoftmaxMode::kLogSoftmax) { // log softmax + tmpptr[s] = srcdata[i][it][s] - max_value[i] - sum[i]; + // softmax with cross entropy hard label + } else if (mode == SoftmaxMode::kCrossEntropy) { + AccT logsoftmax = srcdata[i][it][s] - max_value[i] - sum[i]; + // softmax + tmpptr[s] = std::exp(logsoftmax); + // label + int loss_idx = (threadIdx.x + it * kWarpSize) * kVSize + s; + auto lbl = static_cast(label[first_batch + i]); + if (IgnoreIndex == true) { + // IgnoreIndex is true + if (lbl == loss_idx && lbl != ignore_index) { + loss[first_batch + i] = -logsoftmax; + } + } else { + // IgnoreIndex is false + if (lbl >= 0 && lbl < element_count) { + if (lbl == loss_idx) { + loss[first_batch + i] = -logsoftmax; + } + } else { + loss[first_batch + i] = static_cast(0.0); + } + } + } else { // softmax + tmpptr[s] = srcdata[i][it][s] / sum[i]; + } + } + if (idx < idx_max_v[i]) { + softmax_v[idx] = tmpdata; } else { - CrossEntropyHardLabel<<< - blocks, threads, 0, context.cuda_device_context().stream()>>>( - loss_data, logits_data, labels_data, n, axis_dim, d / axis_dim, - ignore_index); + break; } } - - // cause of input is softmax - // copy to output softmax, directly - framework::TensorCopy(*softmax, context.GetPlace(), - context.device_context(), softmax_out); - - return; } + } +} - const Tensor* logits = context.Input("Logits"); - Tensor* softmax = context.Output("Softmax"); - Tensor* loss = context.Output("Loss"); +#define SOFTMAX_WARP_FORWARD_CASE(Log2Elements, LabelT, VecT, AccT) \ + case Log2Elements: \ + WarpSoftmaxForward<<>>( \ + loss, \ + softmax, \ + src, \ + label, \ + batch_size, \ + stride, \ + element_count, \ + ignore_index); \ + break; - const int rank = logits->dims().size(); - const int axis = phi::funcs::CanonicalAxis(context.Attr("axis"), rank); - int axis_dim = logits->dims()[axis]; +/* + Wrapper of softmax with cross entropy forward hard label. +*/ +template +void SwitchWarpSoftmaxForward(T* loss, + T* softmax, + const T* src, + const LabelT* label, + const int batch_size, + const int stride, + const int element_count, + const int ignore_index, + gpuStream_t stream) { + using AccT = typename dtype::MPTypeTrait::Type; - const int64_t n = phi::funcs::SizeToAxis(axis, logits->dims()); - const int64_t d = phi::funcs::SizeFromAxis(axis, logits->dims()); + // use 128 threads per block to maximimize gpu utilization + const int log2_elements = static_cast(Log2Ceil(element_count)); + const int kDimCeil = 1 << log2_elements; + int kWarpSize = (kDimCeil < 32) ? kDimCeil : 32; + int batches_per_warp = (kDimCeil <= 128) ? 2 : 1; + constexpr int threads_per_block = 128; + int warps_per_block = (threads_per_block / kWarpSize); + int batches_per_block = warps_per_block * batches_per_warp; + int blocks = (batch_size + batches_per_block - 1) / batches_per_block; + dim3 threads(kWarpSize, warps_per_block, 1); - auto* softmax_data = softmax->template mutable_data(context.GetPlace()); - auto* loss_data = loss->template mutable_data(context.GetPlace()); + switch (log2_elements) { + SOFTMAX_WARP_FORWARD_CASE(0, LabelT, T, AccT); + SOFTMAX_WARP_FORWARD_CASE(1, LabelT, T, AccT); + SOFTMAX_WARP_FORWARD_CASE(2, LabelT, T, AccT); + SOFTMAX_WARP_FORWARD_CASE(3, LabelT, T, AccT); + SOFTMAX_WARP_FORWARD_CASE(4, LabelT, T, AccT); + SOFTMAX_WARP_FORWARD_CASE(5, LabelT, T, AccT); + SOFTMAX_WARP_FORWARD_CASE(6, LabelT, T, AccT); + SOFTMAX_WARP_FORWARD_CASE(7, LabelT, T, AccT); + SOFTMAX_WARP_FORWARD_CASE(8, LabelT, T, AccT); + SOFTMAX_WARP_FORWARD_CASE(9, LabelT, T, AccT); + default: + break; + } +} - if (axis_dim == 1) { - phi::funcs::SetConstant set_constant; - set_constant(context.cuda_device_context(), softmax, static_cast(1)); - set_constant(context.cuda_device_context(), loss, static_cast(0)); - return; - } +template +void LaunchVectorizedSoftmaxForward(T* loss, + T* softmax, + const T* logits, + const LabelT* label, + const int high_dim, + const int mid_dim, + const int ignore_index, + gpuStream_t stream) { + using AccT = typename dtype::MPTypeTrait::Type; + constexpr int vec_size = sizeof(float4) / sizeof(T); + const int max_num_threads = 1024; + int max_block_size = std::min(mid_dim / vec_size, max_num_threads); + if (vec_size > 1) { + max_block_size /= 2; + } - auto ignore_index = context.Attr("ignore_index"); + int block_size = 1; + while (block_size < max_block_size) { + block_size *= 2; + } + block_size = std::max(block_size, kps::details::kWarpSize); + dim3 grids(high_dim); + dim3 blocks(block_size); + VectorizedSoftmaxForward<<>>( + loss, softmax, logits, label, high_dim, mid_dim, ignore_index); +} - if (soft_label) { - auto* logits_data = logits->template data(); - auto* labels_data = labels.template data(); - SoftmaxWithCrossEntropySoftLabel( - context.cuda_device_context(), rank, axis, logits_data, labels_data, - softmax_data, loss_data, n, axis_dim, d / axis_dim); - } else { - if (!context.Attr("numeric_stable_mode")) { - // CUDNN kernel only suppoer 2-D tensor and perfome softmax on last dim - Tensor logits_2d, softmax_2d, labels_2d, loss_2d; - logits_2d.ShareDataWith(*logits).Resize({n, d}); - softmax_2d.ShareDataWith(*softmax).Resize({n, d}); - labels_2d.ShareDataWith(labels).Resize({n, labels.numel() / n}); - loss_2d.ShareDataWith(*loss).Resize({n, 1}); - math::SoftmaxCUDNNFunctor()(context.cuda_device_context(), - &logits_2d, &softmax_2d); - math::CrossEntropyFunctor()( - context.cuda_device_context(), &loss_2d, &softmax_2d, &labels_2d, - false, ignore_index, axis_dim); - } else { - auto* logits_data = logits->template data(); - auto* labels_data = labels.template data(); - if (ignore_index >= 0 && ignore_index < axis_dim) { - SoftmaxWithCrossEntropyHardLabel( - context.cuda_device_context(), rank, axis, logits_data, - labels_data, loss_data, softmax_data, n, axis_dim, d / axis_dim, - ignore_index); - } else { - SoftmaxWithCrossEntropyHardLabel( - context.cuda_device_context(), rank, axis, logits_data, - labels_data, loss_data, softmax_data, n, axis_dim, d / axis_dim, - ignore_index); - } - } +/* + Wrapper of softmax with cross entropy hard label. + - SwitchWarpSoftmaxForward for small size when axis == -1 + - LaunchVectorizedSoftmaxForward for large size when axis == -1 + - cudnn function for axis != -1 +*/ +template +static void SoftmaxWithCrossEntropyHardLabel(const GPUContext& dev_ctx, + int rank, + int axis, + const T* logits_data, + const LabelT* labels_data, + T* loss_data, + T* softmax_data, + int N, + int dim, + int D, + const int ignore_index) { + auto stream = dev_ctx.stream(); + constexpr int max_dim = 320; + if (D == 1) { + if (dim <= max_dim) { // small size + const SoftmaxMode mode = SoftmaxMode::kCrossEntropy; + SwitchWarpSoftmaxForward(loss_data, + softmax_data, + logits_data, + labels_data, + N, + dim, + dim, + ignore_index, + stream); + } else { // large size + LaunchVectorizedSoftmaxForward(loss_data, + softmax_data, + logits_data, + labels_data, + N, + dim, + ignore_index, + stream); } - } -}; + } else { + ScopedTensorDescriptor desc; + std::vector tensor_dims = {N, dim, D, 1}; + GPUDNNDataLayout layout = GPUDNNDataLayout::kNCHW; +#ifdef PADDLE_WITH_HIP + miopenTensorDescriptor_t descp = desc.descriptor(layout, tensor_dims); +#else + cudnnTensorDescriptor_t descp = desc.descriptor(layout, tensor_dims); +#endif -template -class SoftmaxWithCrossEntropyGradCUDAKernel : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext& context) const override { - RunSoftmaxWithCrossEntropyFunctor(context, *this); + auto handle = dev_ctx.cudnn_handle(); + +#ifdef PADDLE_WITH_HIP + auto mode = axis == rank - 1 ? MIOPEN_SOFTMAX_MODE_INSTANCE + : MIOPEN_SOFTMAX_MODE_CHANNEL; + PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::miopenSoftmaxForward_V2( + handle, + paddle::platform::CudnnDataType::kOne(), + descp, + logits_data, + paddle::platform::CudnnDataType::kZero(), + descp, + softmax_data, + MIOPEN_SOFTMAX_LOG, + mode)); +#else + auto mode = axis == rank - 1 ? CUDNN_SOFTMAX_MODE_INSTANCE + : CUDNN_SOFTMAX_MODE_CHANNEL; + PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cudnnSoftmaxForward( + handle, + CUDNN_SOFTMAX_LOG, + mode, + paddle::platform::CudnnDataType::kOne(), + descp, + logits_data, + paddle::platform::CudnnDataType::kZero(), + descp, + softmax_data)); +#endif + int threads = 128; + int blocks = (N * dim * D + threads - 1) / threads; + // compute cross entropy, input is log softmax + CrossEntropyExpHardLabel<<>>( + loss_data, softmax_data, labels_data, N, dim, D, ignore_index); } +} - template - static void Apply(const framework::ExecutionContext& context, - const framework::Tensor& labels, const bool soft_label) { - PADDLE_ENFORCE_EQ( - platform::is_gpu_place(context.GetPlace()), true, - platform::errors::Unavailable("softmax_with_cross_entropy operator's " - "CUDA kernel only runs on GPU device.")); - const T* loss_grad_data = - context.Input(framework::GradVarName("Loss")) - ->template data(); - Tensor* logit_grad = - context.Output(framework::GradVarName("Logits")); - const Tensor* softmax = context.Input("Softmax"); - auto stream = context.cuda_device_context().stream(); - auto ignore_index = context.Attr("ignore_index"); - auto use_softmax = context.Attr("use_softmax"); - - T* logit_grad_data = nullptr; - bool copy_flag = (logit_grad != softmax && (!use_softmax || soft_label)); - if (copy_flag) { - framework::TensorCopy(*softmax, context.GetPlace(), - context.device_context(), logit_grad); - logit_grad_data = logit_grad->template data(); - } else { - logit_grad_data = - logit_grad->template mutable_data(context.GetPlace()); +template +void CrossEntropyWithSoftmaxCUDAKernel(const GPUContext& dev_ctx, + const DenseTensor& logits, + const DenseTensor& label, + bool soft_label, + bool use_softmax, + bool numeric_stable_mode, + int ignore_index, + int axis, + DenseTensor* softmax, + DenseTensor* loss) { + PADDLE_ENFORCE_EQ( + dev_ctx.GetPlace().GetType(), + AllocationType::GPU, + phi::errors::Unavailable("softmax_with_cross_entropy operator's " + "CUDA kernel only runs on GPU device.")); + + // do not with softmax op, and input is softmax + if (!use_softmax) { + DenseTensor* softmax_out = softmax; + const DenseTensor* softmax = &logits; + const DenseTensor& labels = label; + + const int rank = softmax->dims().size(); + const int axis_v = phi::funcs::CanonicalAxis(axis, rank); + const int axis_dim = softmax->dims()[axis_v]; + + const int n = phi::funcs::SizeToAxis(axis_v, softmax->dims()); + const int d = phi::funcs::SizeFromAxis(axis_v, softmax->dims()); + + auto* softmax_out_data = dev_ctx.template Alloc(softmax_out); + auto* loss_data = dev_ctx.template Alloc(loss); + + phi::funcs::SetConstant set_constant; + set_constant(dev_ctx, loss, static_cast(0)); + if (axis_dim == 1) { + set_constant(dev_ctx, softmax_out, static_cast(1)); + return; } - const int rank = logit_grad->dims().size(); - const int axis = phi::funcs::CanonicalAxis(context.Attr("axis"), rank); - int axis_dim = logit_grad->dims()[axis]; + DenseTensor softmax_2d(*softmax); + softmax_2d.Resize({n, d}); + DenseTensor labels_2d(labels); + labels_2d.Resize({n, labels.numel() / n}); + DenseTensor loss_2d(*loss); + loss_2d.Resize({n, 1}); + DenseTensor softmax_out_2d(*softmax_out); + softmax_out_2d.Resize({n, d}); + + // math::CrossEntropyFunctor support axis is the last + if (axis_v == -1) { + paddle::operators::math::CrossEntropyFunctor()( + dev_ctx, + &loss_2d, + &softmax_2d, + &labels_2d, + soft_label, + ignore_index, + axis_dim); + return; + } - const int64_t n = phi::funcs::SizeToAxis(axis, logit_grad->dims()); - const int64_t d = phi::funcs::SizeFromAxis(axis, logit_grad->dims()); - const int64_t remain = d / axis_dim; + // if axis is not the last, we need a new impliment + if (soft_label) { + auto* logits_data = softmax->data(); + auto* labels_data = labels.data(); + const int kDimLog2 = static_cast(Log2Ceil(axis_dim)); + const int kDimCeil = 1 << kDimLog2; #ifdef __HIPCC__ - int block = 256; + int kThreadPerBlock = 256; #else - int block = 512; + int kThreadPerBlock = 512; #endif - - // do not with softmax op, and input is softmax - if (!use_softmax) { - if (soft_label) { - int grid = (n * d + block - 1) / block; - const T* label_data = labels.template data(); - SoftLabelCrossEntropyGradientKernel<<>>( - logit_grad_data, loss_grad_data, label_data, n, d, remain); + int kBatchPerBlock = 1; + int blocks = (n * d + kBatchPerBlock - 1) / kBatchPerBlock; + dim3 threads(kThreadPerBlock / kBatchPerBlock, kBatchPerBlock, 1); + + CrossEntropySoftLabel<<>>( + loss_data, + NULL, + logits_data, + labels_data, + n, + axis_dim, + d / axis_dim, + kDimLog2); + } else { // HardLabel + auto* logits_data = softmax->data(); + auto* labels_data = labels.data(); + int threads = 128; + int blocks = (n * d / axis_dim + threads - 1) / threads; + if (ignore_index >= 0 && ignore_index < axis_dim) { + CrossEntropyHardLabel<<>>( + loss_data, + logits_data, + labels_data, + n, + axis_dim, + d / axis_dim, + ignore_index); } else { - Tensor logits_grad_2d; - logits_grad_2d.ShareDataWith(*logit_grad).Resize({n, d}); - int grid = (n * remain + block - 1) / block; - const auto* label_data = labels.template data(); - HardLabelCrossEntropyGradientKernel<<>>( - logit_grad_data, label_data, n, d, remain, ignore_index); - int num = n * d; - grid = (num + block - 1) / block; - ScaleCrossEntropyGradient<<>>( - logit_grad_data, loss_grad_data, num, d, remain, label_data, + CrossEntropyHardLabel<<>>( + loss_data, + logits_data, + labels_data, + n, + axis_dim, + d / axis_dim, ignore_index); } - - return; } - // with softmax, continue + // cause of input is softmax + // copy to output softmax, directly + phi::Copy( + dev_ctx, *softmax, dev_ctx.GetPlace(), false, softmax_out); - if (soft_label) { - int64_t grid = (n * d + block - 1) / block; - const T* label_data = labels.template data(); - SoftCrossEntropyGradientKernel<<>>( - logit_grad_data, loss_grad_data, label_data, n, d, remain); + return; + } + + const int rank = logits.dims().size(); + const int axis_v = phi::funcs::CanonicalAxis(axis, rank); + int axis_dim = logits.dims()[axis_v]; + + const int64_t n = phi::funcs::SizeToAxis(axis_v, logits.dims()); + const int64_t d = phi::funcs::SizeFromAxis(axis_v, logits.dims()); + + auto* softmax_data = dev_ctx.template Alloc(softmax); + auto* loss_data = dev_ctx.template Alloc(loss); + + if (axis_dim == 1) { + phi::funcs::SetConstant set_constant; + set_constant(dev_ctx, softmax, static_cast(1)); + set_constant(dev_ctx, loss, static_cast(0)); + return; + } + + if (soft_label) { + auto* logits_data = logits.data(); + auto* labels_data = label.data(); + SoftmaxWithCrossEntropySoftLabel(dev_ctx, + rank, + axis_v, + logits_data, + labels_data, + softmax_data, + loss_data, + n, + axis_dim, + d / axis_dim); + } else { + if (!numeric_stable_mode) { + // CUDNN kernel only suppoer 2-D tensor and perfome softmax on last dim + DenseTensor logits_2d(logits); + logits_2d.Resize({n, d}); + DenseTensor softmax_2d(*softmax); + softmax_2d.Resize({n, d}); + DenseTensor labels_2d(label); + labels_2d.Resize({n, label.numel() / n}); + DenseTensor loss_2d(*loss); + loss_2d.Resize({n, 1}); + paddle::operators::math::SoftmaxCUDNNFunctor()( + dev_ctx, &logits_2d, &softmax_2d); + paddle::operators::math::CrossEntropyFunctor()( + dev_ctx, + &loss_2d, + &softmax_2d, + &labels_2d, + false, + ignore_index, + axis_dim); } else { - const T* softmax_data = softmax->template data(); - const auto* label_data = labels.template data(); - int grid = (n * d + block - 1) / block; - SoftmaxWithCrossEntropyGradHardLabel<<>>( - logit_grad_data, loss_grad_data, softmax_data, label_data, n, - d / remain, remain, ignore_index); + auto* logits_data = logits.data(); + auto* labels_data = label.data(); + if (ignore_index >= 0 && ignore_index < axis_dim) { + SoftmaxWithCrossEntropyHardLabel(dev_ctx, + rank, + axis_v, + logits_data, + labels_data, + loss_data, + softmax_data, + n, + axis_dim, + d / axis_dim, + ignore_index); + } else { + SoftmaxWithCrossEntropyHardLabel(dev_ctx, + rank, + axis_v, + logits_data, + labels_data, + loss_data, + softmax_data, + n, + axis_dim, + d / axis_dim, + ignore_index); + } } } -}; +} + +template +void CrossEntropyWithSoftmaxKernel(const Context& dev_ctx, + const DenseTensor& logits, + const DenseTensor& label, + bool soft_label, + bool use_softmax, + bool numeric_stable_mode, + int ignore_index, + int axis, + DenseTensor* softmax, + DenseTensor* loss) { + auto dtype = label.dtype(); + if (soft_label) { + PADDLE_ENFORCE_EQ( + dtype, + paddle::experimental::CppTypeToDataType::Type(), + phi::errors::InvalidArgument("The Input(Label) should be with the " + "same data type as Input(Logits).")); + CrossEntropyWithSoftmaxCUDAKernel(dev_ctx, + logits, + label, + soft_label, + use_softmax, + numeric_stable_mode, + ignore_index, + axis, + softmax, + loss); + } else { + PD_DISPATCH_INTEGRAL_TYPES( + dtype, "CrossEntropyWithSoftmaxCUDAKernel", ([&] { + CrossEntropyWithSoftmaxCUDAKernel(dev_ctx, + logits, + label, + soft_label, + use_softmax, + numeric_stable_mode, + ignore_index, + axis, + softmax, + loss); + })); + } +} -} // namespace operators -} // namespace paddle +} // namespace phi -namespace ops = paddle::operators; #ifdef PADDLE_WITH_HIP -// MIOPEN do not support double -REGISTER_OP_CUDA_KERNEL( - softmax_with_cross_entropy, ops::SoftmaxWithCrossEntropyCUDAKernel, - ops::SoftmaxWithCrossEntropyCUDAKernel); -REGISTER_OP_CUDA_KERNEL( - softmax_with_cross_entropy_grad, - ops::SoftmaxWithCrossEntropyGradCUDAKernel, - ops::SoftmaxWithCrossEntropyGradCUDAKernel); +PD_REGISTER_KERNEL(cross_entropy_with_softmax, + GPU, + ALL_LAYOUT, + phi::CrossEntropyWithSoftmaxKernel, + float, + phi::dtype::float16) {} #else -REGISTER_OP_CUDA_KERNEL( - softmax_with_cross_entropy, ops::SoftmaxWithCrossEntropyCUDAKernel, - ops::SoftmaxWithCrossEntropyCUDAKernel, - ops::SoftmaxWithCrossEntropyCUDAKernel); -REGISTER_OP_CUDA_KERNEL( - softmax_with_cross_entropy_grad, - ops::SoftmaxWithCrossEntropyGradCUDAKernel, - ops::SoftmaxWithCrossEntropyGradCUDAKernel, - ops::SoftmaxWithCrossEntropyGradCUDAKernel); +PD_REGISTER_KERNEL(cross_entropy_with_softmax, + GPU, + ALL_LAYOUT, + phi::CrossEntropyWithSoftmaxKernel, + float, + double, + phi::dtype::float16) {} #endif diff --git a/paddle/phi/ops/compat/softmax_with_cross_entropy_sig.cc b/paddle/phi/ops/compat/softmax_with_cross_entropy_sig.cc new file mode 100644 index 0000000000000000000000000000000000000000..9cfc5ded90a49a1572a136f4de609b8ff4b742af --- /dev/null +++ b/paddle/phi/ops/compat/softmax_with_cross_entropy_sig.cc @@ -0,0 +1,53 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/core/compat/op_utils.h" + +namespace phi { + +KernelSignature SoftmaxWithCrossEntropyOpArgumentMapping( + const ArgumentMappingContext& ctx) { + return KernelSignature("cross_entropy_with_softmax", + {"Logits", "Label"}, + {"soft_label", + "use_softmax", + "numeric_stable_mode", + "ignore_index", + "axis"}, + {"Softmax", "Loss"}); +} + +KernelSignature SoftmaxWithCrossEntropyGradOpArgumentMapping( + const ArgumentMappingContext& ctx) { + return KernelSignature("cross_entropy_with_softmax_grad", + {"Label", "Softmax", GradVarName("Loss")}, + {"soft_label", + "use_softmax", + "numeric_stable_mode", + "ignore_index", + "axis"}, + {GradVarName("Logits")}); +} + +} // namespace phi + +PD_REGISTER_BASE_KERNEL_NAME(softmax_with_cross_entropy, + cross_entropy_with_softmax); +PD_REGISTER_BASE_KERNEL_NAME(softmax_with_cross_entropy_grad, + cross_entropy_with_softmax_grad); + +PD_REGISTER_ARG_MAPPING_FN(softmax_with_cross_entropy, + phi::SoftmaxWithCrossEntropyOpArgumentMapping); +PD_REGISTER_ARG_MAPPING_FN(softmax_with_cross_entropy_grad, + phi::SoftmaxWithCrossEntropyGradOpArgumentMapping);