diff --git a/paddle/fluid/operators/log_softmax_op.cu b/paddle/fluid/operators/log_softmax_op.cu
index 034e67568b34cebdfeddb884345b21cd99afb34f..8770abdac838f63b0c9f3a95b1ac0283a80ecbf2 100644
--- a/paddle/fluid/operators/log_softmax_op.cu
+++ b/paddle/fluid/operators/log_softmax_op.cu
@@ -13,9 +13,9 @@
 // limitations under the License.
 
 #include
-#include "paddle/fluid/operators/amp/fp16_type_traits.h"
 #include "paddle/fluid/operators/log_softmax_op.h"
 #include "paddle/fluid/platform/device/gpu/gpu_device_function.h"
+#include "paddle/phi/common/amp_type_traits.h"
 #include "paddle/phi/kernels/funcs/elementwise_functor.h"
 #include "paddle/phi/kernels/funcs/functors.h"
@@ -311,7 +311,7 @@ void LaunchLogSoftmaxForwardCUDAKernelNotLastAxis(T *output_data,
 template
 class LogSoftmaxKernel : public framework::OpKernel {
-  using MPDType = typename details::MPTypeTrait::Type;
+  using MPDType = typename phi::dtype::MPTypeTrait::Type;
 
  public:
   void Compute(const framework::ExecutionContext &context) const override {
@@ -433,7 +433,7 @@ void LaunchSoftmaxBackwardForLastAxis(T *grad_input, const T *grad_output,
 template
 class LogSoftmaxGradKernel : public framework::OpKernel {
-  using MPDType = typename details::MPTypeTrait::Type;
+  using MPDType = typename phi::dtype::MPTypeTrait::Type;
 
  public:
   void Compute(const framework::ExecutionContext &context) const override {
@@ -468,16 +468,18 @@
   }
 };
 
-}  // operators
-}  // paddle
+}  // namespace operators
+}  // namespace paddle
 
 namespace ops = paddle::operators;
 namespace plat = paddle::platform;
 
 REGISTER_OP_CUDA_KERNEL(
     log_softmax, ops::LogSoftmaxKernel,
     ops::LogSoftmaxKernel,
-    ops::LogSoftmaxKernel);
+    ops::LogSoftmaxKernel,
+    ops::LogSoftmaxKernel);
 REGISTER_OP_CUDA_KERNEL(
     log_softmax_grad, ops::LogSoftmaxGradKernel,
     ops::LogSoftmaxGradKernel,
-    ops::LogSoftmaxGradKernel);
+    ops::LogSoftmaxGradKernel,
+    ops::LogSoftmaxGradKernel);
diff --git a/paddle/fluid/operators/math/softmax.cu b/paddle/fluid/operators/math/softmax.cu
index fd879e9e6ffe72a2175acc2db98727f5ff39fbbb..83b124902ebb74e65af0a25e432ff6b488e5cee1 100644
--- a/paddle/fluid/operators/math/softmax.cu
+++ b/paddle/fluid/operators/math/softmax.cu
@@ -120,6 +120,10 @@ template class SoftmaxCUDNNFunctor;
 template class SoftmaxCUDNNFunctor;
 template class SoftmaxGradCUDNNFunctor;
 template class SoftmaxGradCUDNNFunctor;
+#if CUDNN_VERSION_MIN(8, 1, 0)
+template class SoftmaxCUDNNFunctor;
+template class SoftmaxGradCUDNNFunctor;
+#endif
 
 // MIOPEN do not support double
 #ifndef PADDLE_WITH_HIP
@@ -131,6 +135,10 @@ template class SoftmaxFunctor;
 template class SoftmaxFunctor;
+template class SoftmaxFunctor;
+template class SoftmaxFunctor;
 template class SoftmaxFunctor;
 template class SoftmaxFunctor;
 template class SoftmaxFunctor;
@@ -139,9 +147,13 @@ template class SoftmaxGradFunctor;
 template class SoftmaxGradFunctor;
 template class SoftmaxGradFunctor;
+template class SoftmaxGradFunctor;
 template class SoftmaxFunctor;
 template class SoftmaxFunctor;
+template class SoftmaxFunctor;
+template class SoftmaxFunctor;
 template class SoftmaxFunctor;
 template class SoftmaxFunctor;
 template class SoftmaxFunctor;
@@ -149,6 +161,7 @@ template class SoftmaxFunctor;
 template class SoftmaxGradFunctor;
 template class SoftmaxGradFunctor;
 template class SoftmaxGradFunctor;
+template class SoftmaxGradFunctor;
 
 }  // namespace math
 }  // namespace operators
diff --git a/paddle/fluid/operators/math/softmax_impl.h b/paddle/fluid/operators/math/softmax_impl.h
index d51d638e0c19f43f9b0a91adbac15dffcdf14588..9833b4447ec45376e04ad520315e88568f7991d8 100644
--- a/paddle/fluid/operators/math/softmax_impl.h
+++ b/paddle/fluid/operators/math/softmax_impl.h
@@ -156,6 +156,65 @@ class SoftmaxEigen {
   }
 };
 
+template
+class SoftmaxEigen {
+ public:
+  void operator()(const DeviceContext& context, const int axis_dim,
+                  const framework::Tensor* X, framework::Tensor* Y) {
+    constexpr int kBatchDim = 0;
+    constexpr int kClassDim = 1;
+    constexpr int kAxisDim = 1;
+
+    auto logits = EigenMatrix::From(*X);
+    auto softmax = EigenMatrix::From(*Y);
+
+    const int batch_size = logits.dimension(kBatchDim);
+    const int num_classes = logits.dimension(kClassDim);
+    const int num_remain = num_classes / axis_dim;
+
+    Eigen::DSizes along_axis(kAxisDim);
+    Eigen::DSizes batch_classes(batch_size, num_classes);
+    Eigen::DSizes batch_by_one(batch_size, 1);
+    Eigen::DSizes one_by_class(1, num_classes);
+    Eigen::DSizes batch_one_remain(batch_size, 1, num_remain);
+    Eigen::DSizes one_axis_one(1, axis_dim, 1);
+    Eigen::DSizes one_axis(1, axis_dim);
+    Eigen::DSizes batch_axis_remain(batch_size, axis_dim, num_remain);
+
+    // For numerical stability, logits should be shifted by maximum number along
+    // axis, calculate shifted_logits into softmax tensor for memory reuse.
+    if (num_remain == 1) {
+      // axis == -1, axis and class in same dimension, calculate along
+      // class dimension directly for higher performance
+      softmax.device(*context.eigen_device()) =
+          (logits -
+           logits.maximum(along_axis)
+               .reshape(batch_by_one)
+               .broadcast(one_by_class))
+              .unaryExpr(ValueClip());
+    } else {
+      // axis != -1, class dimension split into (axis, remain), max and sum
+      // should be calculated along axis dimension
+      softmax.device(*context.eigen_device()) =
+          (logits.reshape(batch_axis_remain) -
+           logits.reshape(batch_axis_remain)
+               .maximum(along_axis)
+               .reshape(batch_one_remain)
+               .broadcast(one_axis_one)
+               .reshape(batch_classes))
+              .unaryExpr(ValueClip());
+    }
+
+    softmax.device(*context.eigen_device()) = softmax.exp();
+    softmax.device(*context.eigen_device()) =
+        (softmax *
+         softmax.reshape(batch_axis_remain)
+             .sum(along_axis)
+             .inverse()
+             .broadcast(one_axis));
+  }
+};
+
 template
 void SoftmaxFunctor::operator()(
     const DeviceContext& context, const int axis_dim,
@@ -289,6 +348,38 @@ class SoftmaxGradEigen {
   }
 };
 
+template
+class SoftmaxGradEigen {
+ public:
+  void operator()(const DeviceContext& context, const int axis_dim,
+                  const framework::Tensor* y, const framework::Tensor* y_grad,
+                  framework::Tensor* x_grad) {
+    auto softmax = EigenMatrix::From(*y);
+    auto softmax_grad = EigenMatrix::From(*y_grad);
+    auto logits_grad = EigenMatrix::From(*x_grad);
+
+    constexpr int kBatchDim = 0;
+    constexpr int kClassDim = 1;
+
+    const int batch_size = softmax.dimension(kBatchDim);
+    const int num_classes = softmax.dimension(kClassDim);
+    const int num_remain = num_classes / axis_dim;
+
+    Eigen::DSizes along_class(kClassDim);
+    Eigen::DSizes batch_by_one(batch_size, 1);
+    Eigen::DSizes one_by_class(1, num_classes);
+    Eigen::DSizes batch_axis_remain(batch_size, axis_dim, num_remain);
+    Eigen::DSizes one_axis(1, axis_dim);
+
+    auto dot = (softmax * softmax_grad)
+                   .reshape(batch_axis_remain)
+                   .sum(along_class)
+                   .broadcast(one_axis);
+    logits_grad.device(*context.eigen_device()) =
+        (softmax_grad - dot) * softmax;
+  }
+};
+
 template
 void SoftmaxGradFunctor::operator()(
     const DeviceContext& context, const int axis_dim,
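
For reference, the SoftmaxEigen/SoftmaxGradEigen specializations above implement the usual numerically stable formulation: shift the logits by the per-axis maximum, exponentiate, normalize by the sum, and compute the gradient as dX = (dY - sum(dY * Y)) * Y. A minimal NumPy sketch of that math (illustrative only; the helper names below are not part of this patch):

import numpy as np

def stable_softmax_ref(x, axis=-1):
    # Shift by the per-axis maximum for numerical stability, then normalize.
    shifted = x - np.max(x, axis=axis, keepdims=True)
    exps = np.exp(shifted)
    return exps / np.sum(exps, axis=axis, keepdims=True)

def softmax_grad_ref(y, dy, axis=-1):
    # Mirrors the Eigen expression above: dX = (dY - sum(dY * Y, axis)) * Y.
    dot = np.sum(dy * y, axis=axis, keepdims=True)
    return (dy - dot) * y
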
diff --git a/paddle/fluid/platform/device/gpu/rocm/miopen_helper.h b/paddle/fluid/platform/device/gpu/rocm/miopen_helper.h
index 34b9d57e055d57533a02466d17d83e26ddaa40d9..1a514d2aca2675932396fede6d22f4962e4e0d76 100644
--- a/paddle/fluid/platform/device/gpu/rocm/miopen_helper.h
+++ b/paddle/fluid/platform/device/gpu/rocm/miopen_helper.h
@@ -140,6 +140,23 @@ class CudnnDataType {
   }
 };
 
+template <>
+class CudnnDataType {
+ public:
+  static const miopenDataType_t type = miopenBFloat16;
+  // The scaling param type is float for HALF and FLOAT tensors
+  using ScalingParamType = const float;
+  using BatchNormParamType = float;
+  static ScalingParamType* kOne() {
+    static ScalingParamType v = 1.0;
+    return &v;
+  }
+  static ScalingParamType* kZero() {
+    static ScalingParamType v = 0.0;
+    return &v;
+  }
+};
+
 template <>
 class CudnnDataType {
  public:
diff --git a/paddle/phi/common/amp_type_traits.h b/paddle/phi/common/amp_type_traits.h
new file mode 100644
index 0000000000000000000000000000000000000000..ce3a469f5aeddc29e67e320141d2ebaab925fabd
--- /dev/null
+++ b/paddle/phi/common/amp_type_traits.h
@@ -0,0 +1,42 @@
+/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#pragma once
+
+#include "paddle/phi/common/bfloat16.h"
+#include "paddle/phi/common/float16.h"
+
+namespace phi {
+namespace dtype {
+
+template
+class MPTypeTrait {
+ public:
+  using Type = T;
+};
+
+template <>
+class MPTypeTrait {
+ public:
+  using Type = float;
+};
+
+template <>
+class MPTypeTrait {
+ public:
+  using Type = float;
+};
+
+}  // namespace dtype
+}  // namespace phi
diff --git a/paddle/phi/common/bfloat16.h b/paddle/phi/common/bfloat16.h
index 3fd8eb1b2684a0d3b04c88549cb38698975ace2f..cf99bb8f19af0516ed5524c1d2af89777c1e1d0b 100644
--- a/paddle/phi/common/bfloat16.h
+++ b/paddle/phi/common/bfloat16.h
@@ -377,31 +377,31 @@ struct numeric_limits {
   static const bool traps = true;
   static const bool tinyness_before = false;
 
-  static phi::dtype::bfloat16(min)() {
+  HOSTDEVICE static phi::dtype::bfloat16(min)() {
     return phi::dtype::raw_uint16_to_bfloat16(0x007f);
   }
-  static phi::dtype::bfloat16 lowest() {
+  HOSTDEVICE static phi::dtype::bfloat16 lowest() {
     return phi::dtype::raw_uint16_to_bfloat16(0xff7f);
   }
-  static phi::dtype::bfloat16(max)() {
+  HOSTDEVICE static phi::dtype::bfloat16(max)() {
     return phi::dtype::raw_uint16_to_bfloat16(0x7f7f);
   }
-  static phi::dtype::bfloat16 epsilon() {
+  HOSTDEVICE static phi::dtype::bfloat16 epsilon() {
     return phi::dtype::raw_uint16_to_bfloat16(0x3400);
   }
-  static phi::dtype::bfloat16 round_error() {
+  HOSTDEVICE static phi::dtype::bfloat16 round_error() {
     return phi::dtype::bfloat16(0.5);
   }
-  static phi::dtype::bfloat16 infinity() {
+  HOSTDEVICE static phi::dtype::bfloat16 infinity() {
     return phi::dtype::raw_uint16_to_bfloat16(0x7f80);
   }
-  static phi::dtype::bfloat16 quiet_NaN() {
+  HOSTDEVICE static phi::dtype::bfloat16 quiet_NaN() {
     return phi::dtype::raw_uint16_to_bfloat16(0xffc1);
   }
-  static phi::dtype::bfloat16 signaling_NaN() {
+  HOSTDEVICE static phi::dtype::bfloat16 signaling_NaN() {
     return phi::dtype::raw_uint16_to_bfloat16(0xff81);
   }
-  static phi::dtype::bfloat16 denorm_min() {
+  HOSTDEVICE static phi::dtype::bfloat16 denorm_min() {
     return phi::dtype::raw_uint16_to_bfloat16(0x0001);
   }
 };
diff --git a/paddle/phi/common/float16.h b/paddle/phi/common/float16.h
index 6ed9c88d705106ce3b03732096fa34b23422875f..1cdcdef2c12eec1c59c0fd2dfdf1c4dd6e62bd37 100644
--- a/paddle/phi/common/float16.h
+++ b/paddle/phi/common/float16.h
@@ -988,18 +988,6 @@ inline std::ostream& operator<<(std::ostream& os, const float16& a) {
   return os;
 }
 
-template
-class MPTypeTrait {
- public:
-  using Type = T;
-};
-
-template <>
-class MPTypeTrait {
- public:
-  using Type = float;
-};
-
 }  // namespace dtype
 }  // namespace phi
diff --git a/paddle/phi/kernels/gpu/softmax_grad_kernel.cu b/paddle/phi/kernels/gpu/softmax_grad_kernel.cu
index aa496d3cd391b59bef16c57dc8b7f0c39834c107..04052e0dfc39a44a0f485557b9e8dc57b8794c38 100644
--- a/paddle/phi/kernels/gpu/softmax_grad_kernel.cu
+++ b/paddle/phi/kernels/gpu/softmax_grad_kernel.cu
@@ -15,6 +15,7 @@ limitations under the License. */
 #include "paddle/phi/kernels/softmax_grad_kernel.h"
 
 #include "paddle/phi/backends/gpu/gpu_context.h"
+#include "paddle/phi/common/bfloat16.h"
 #include "paddle/phi/common/float16.h"
 #include "paddle/phi/core/kernel_registry.h"
 #include "paddle/phi/kernels/impl/softmax_grad_kernel_impl.h"
@@ -25,4 +26,5 @@ PD_REGISTER_KERNEL(softmax_grad,
                    phi::SoftmaxGradKernel,
                    float,
                    double,
-                   phi::dtype::float16) {}
+                   phi::dtype::float16,
+                   phi::dtype::bfloat16) {}
diff --git a/paddle/phi/kernels/gpu/softmax_kernel.cu b/paddle/phi/kernels/gpu/softmax_kernel.cu
index 32efb9b776419efe5733ab0493c38f9c1a9c237e..03c5714b967841ef1bd124bd9191830a79567514 100644
--- a/paddle/phi/kernels/gpu/softmax_kernel.cu
+++ b/paddle/phi/kernels/gpu/softmax_kernel.cu
@@ -15,6 +15,7 @@ limitations under the License. */
 #include "paddle/phi/kernels/softmax_kernel.h"
 
 #include "paddle/phi/backends/gpu/gpu_context.h"
+#include "paddle/phi/common/bfloat16.h"
 #include "paddle/phi/common/float16.h"
 #include "paddle/phi/core/kernel_registry.h"
 #include "paddle/phi/kernels/impl/softmax_kernel_impl.h"
@@ -25,4 +26,5 @@ PD_REGISTER_KERNEL(softmax,
                    phi::SoftmaxRawKernel,
                    float,
                    double,
-                   phi::dtype::float16) {}
+                   phi::dtype::float16,
+                   phi::dtype::bfloat16) {}
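
The registrations above (and the GPUDNN ones below) add bfloat16 to the dtype list of the phi softmax kernels; the cuDNN path is only compiled in when cuDNN >= 8.1. A rough usage sketch of what this enables, assuming a CUDA build with bfloat16 cast support (this snippet is illustrative and not part of the patch):

import numpy as np
import paddle
import paddle.nn.functional as F

if paddle.is_compiled_with_cuda():
    x32 = paddle.rand([2, 3, 4, 5], dtype='float32')
    x16 = paddle.cast(x32, 'bfloat16')   # bfloat16 input on GPU
    y16 = F.softmax(x16, axis=-1)        # dispatches to the bf16 softmax kernel
    y32 = F.softmax(x32, axis=-1)
    # bfloat16 keeps ~8 mantissa bits, so compare with a loose tolerance.
    np.testing.assert_allclose(
        paddle.cast(y16, 'float32').numpy(), y32.numpy(), atol=1e-2)
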
diff --git a/paddle/phi/kernels/gpudnn/softmax_gpudnn.h b/paddle/phi/kernels/gpudnn/softmax_gpudnn.h
index 45798b88bb58a3b088b2545f4a343c18ebec0ec4..c9c549379bbce6ceb4c2314f34f24ad659f5c272 100644
--- a/paddle/phi/kernels/gpudnn/softmax_gpudnn.h
+++ b/paddle/phi/kernels/gpudnn/softmax_gpudnn.h
@@ -15,6 +15,8 @@ limitations under the License. */
 #pragma once
 
 #include "paddle/phi/backends/gpu/gpu_info.h"
+#include "paddle/phi/common/amp_type_traits.h"
+#include "paddle/phi/common/bfloat16.h"
 #include "paddle/phi/common/float16.h"
 #include "paddle/phi/core/dense_tensor.h"
 #include "paddle/phi/kernels/funcs/axis_utils.h"
@@ -47,6 +49,11 @@ class VecT4 {
  public:
   using Type = int2;
 };
+template <>
+class VecT4 {
+ public:
+  using Type = int2;
+};
 
 // Vectorization trait 2 * sizeof(T)
 template
@@ -66,6 +73,11 @@ class VecT2 {
  public:
   using Type = int;
 };
+template <>
+class VecT2 {
+ public:
+  using Type = int;
+};
 
 static inline int log2_ceil(int value) {
   int log2_value = 0;
diff --git a/paddle/phi/kernels/gpudnn/softmax_grad_kernel_gpudnn.cu b/paddle/phi/kernels/gpudnn/softmax_grad_kernel_gpudnn.cu
index 56e5fef6e37e41dd6405af25c214013211670246..45ab645d3736734fb9ec4c6a7b949274c1f0a91e 100644
--- a/paddle/phi/kernels/gpudnn/softmax_grad_kernel_gpudnn.cu
+++ b/paddle/phi/kernels/gpudnn/softmax_grad_kernel_gpudnn.cu
@@ -38,7 +38,18 @@ PD_REGISTER_KERNEL(softmax_grad,
                    ALL_LAYOUT,
                    phi::SoftmaxGradGPUDNNKernel,
                    float,
-                   phi::dtype::float16) {}
+                   phi::dtype::float16,
+                   phi::dtype::bfloat16) {}
+#else
+#if CUDNN_VERSION_MIN(8, 1, 0)
+PD_REGISTER_KERNEL(softmax_grad,
+                   GPUDNN,
+                   ALL_LAYOUT,
+                   phi::SoftmaxGradGPUDNNKernel,
+                   float,
+                   double,
+                   phi::dtype::float16,
+                   phi::dtype::bfloat16) {}
 #else
 PD_REGISTER_KERNEL(softmax_grad,
                    GPUDNN,
@@ -48,3 +59,4 @@ PD_REGISTER_KERNEL(softmax_grad,
                    double,
                    phi::dtype::float16) {}
 #endif
+#endif
diff --git a/paddle/phi/kernels/gpudnn/softmax_kernel_gpudnn.cu b/paddle/phi/kernels/gpudnn/softmax_kernel_gpudnn.cu
index 427d1729a13a8ea8e0caf4aa534b012af76e79f2..7685c7dbb6894b4e640ea4b63010c4d22fc5e18f 100644
--- a/paddle/phi/kernels/gpudnn/softmax_kernel_gpudnn.cu
+++ b/paddle/phi/kernels/gpudnn/softmax_kernel_gpudnn.cu
@@ -37,7 +37,18 @@ PD_REGISTER_KERNEL(softmax,
                    ALL_LAYOUT,
                    phi::SoftmaxRawGPUDNNKernel,
                    float,
-                   phi::dtype::float16) {}
+                   phi::dtype::float16,
+                   phi::dtype::bfloat16) {}
+#else
+#if CUDNN_VERSION_MIN(8, 1, 0)
+PD_REGISTER_KERNEL(softmax,
+                   GPUDNN,
+                   ALL_LAYOUT,
+                   phi::SoftmaxRawGPUDNNKernel,
+                   float,
+                   double,
+                   phi::dtype::float16,
+                   phi::dtype::bfloat16) {}
 #else
 PD_REGISTER_KERNEL(softmax,
                    GPUDNN,
@@ -47,3 +58,4 @@ PD_REGISTER_KERNEL(softmax,
                    double,
                    phi::dtype::float16) {}
 #endif
+#endif
diff --git a/python/paddle/fluid/tests/unittests/test_log_softmax.py b/python/paddle/fluid/tests/unittests/test_log_softmax.py
index d1437ca9c96f1ba5fd2b9e1e420f91414d4f923a..16f954708d4d4149f46a18cfd48e35dfbe147153 100644
--- a/python/paddle/fluid/tests/unittests/test_log_softmax.py
+++ b/python/paddle/fluid/tests/unittests/test_log_softmax.py
@@ -14,8 +14,9 @@
 
 import unittest
 import numpy as np
-from paddle.fluid.tests.unittests.op_test import OpTest
+from paddle.fluid.tests.unittests.op_test import OpTest, convert_float_to_uint16
 import paddle
+import paddle.fluid.core as core
 import paddle.nn.functional as F
 
 np.random.seed(10)
@@ -74,6 +75,33 @@ class TestLogSoftmaxAxis(TestLogSoftmaxOp):
         self.axis = 1
 
 
+@unittest.skipIf(not core.is_compiled_with_cuda(),
+                 "core is not compiled with CUDA")
+class TestLogSoftmaxBF16Op(OpTest):
+    def setUp(self):
+        self.op_type = 'log_softmax'
+        self.dtype = np.uint16
+        self.shape = [2, 3, 4, 5]
+        self.axis = -1
+
+        x = np.random.uniform(0.1, 1., self.shape).astype(np.float32)
+        out = np.apply_along_axis(ref_log_softmax, self.axis, x)
+        self.x_grad = ref_log_softmax_grad(x, self.axis)
+
+        self.inputs = {'X': convert_float_to_uint16(x)}
+        self.outputs = {'Out': convert_float_to_uint16(out)}
+        self.attrs = {'axis': self.axis}
+
+    def test_check_output(self):
+        place = core.CUDAPlace(0)
+        self.check_output_with_place(place)
+
+    def test_check_grad(self):
+        place = core.CUDAPlace(0)
+        self.check_grad_with_place(
+            place, ['X'], ['Out'], user_defined_grads=[self.x_grad])
+
+
 class TestNNLogSoftmaxAPI(unittest.TestCase):
     def setUp(self):
         self.x_shape = [2, 3, 4, 5]
diff --git a/python/paddle/fluid/tests/unittests/test_softmax_op.py b/python/paddle/fluid/tests/unittests/test_softmax_op.py
index a1cbefa40f307f5cdc1a64feaa51573f68a259f5..4f1c37a242474a63078336cbbecae06d78e5cdbd 100644
--- a/python/paddle/fluid/tests/unittests/test_softmax_op.py
+++ b/python/paddle/fluid/tests/unittests/test_softmax_op.py
@@ -16,7 +16,7 @@ from __future__ import print_function
 
 import unittest
 import numpy as np
-from op_test import OpTest
+from op_test import OpTest, convert_float_to_uint16
 import paddle.fluid.core as core
 import paddle.fluid as fluid
 from paddle.fluid import compiler, Program, program_guard
@@ -296,6 +296,56 @@ class TestSoftmaxFP16CUDNNOp2(TestSoftmaxFP16CUDNNOp):
         return [2, 3, 4, 5]
 
 
+@unittest.skipIf(not core.is_compiled_with_cuda(),
+                 "core is not compiled with CUDA")
+class TestSoftmaxBF16Op(OpTest):
+    def setUp(self):
+        self.op_type = "softmax"
+        self.use_cudnn = self.init_cudnn()
+        self.use_mkldnn = False
+        self.dtype = np.uint16
+        self.shape = [10, 10]
+        self.axis = -1
+
+        np.random.seed(0)
+        x = np.random.uniform(0.1, 1, self.shape).astype(np.float32)
+        out = np.apply_along_axis(stable_softmax, self.axis, x)
+
+        self.inputs = {
+            'X': OpTest.np_dtype_to_fluid_dtype(convert_float_to_uint16(x))
+        }
+        self.outputs = {'Out': convert_float_to_uint16(out)}
+        self.attrs = {
+            'axis': self.axis,
+            'use_cudnn': self.use_cudnn,
+            'use_mkldnn': self.use_mkldnn
+        }
+
+    def init_cudnn(self):
+        return False
+
+    def test_check_output(self):
+        place = core.CUDAPlace(0)
+        self.check_output_with_place(
+            place, check_dygraph=(self.use_mkldnn == False))
+
+    def test_check_grad(self):
+        place = core.CUDAPlace(0)
+        self.check_grad_with_place(
+            place, ["X"],
+            "Out",
+            numeric_grad_delta=0.05,
+            check_dygraph=(self.use_mkldnn == False))
+
+
+@unittest.skipIf(
+    not core.is_compiled_with_cuda() or core.cudnn_version() < 8100,
+    "core is not compiled with CUDA and cudnn version need larger than 8.1.0")
+class TestSoftmaxBF16CUDNNOp(TestSoftmaxBF16Op):
+    def init_cudnn(self):
+        return True
+
+
 class TestSoftmaxAPI(unittest.TestCase):
     def setUp(self):
         self.place = paddle.CUDAPlace(0) if core.is_compiled_with_cuda(