未验证 提交 4a4215ff 编写于 作者: Z zhangbo9674 提交者: GitHub

[bf16] add bf16 kernel: softmax & log_softmax (#39999)

* add softmax log_softmax

* refine rocm

* refine unittest
上级 c9cd47d9
...@@ -13,9 +13,9 @@ ...@@ -13,9 +13,9 @@
// limitations under the License. // limitations under the License.
#include <limits> #include <limits>
#include "paddle/fluid/operators/amp/fp16_type_traits.h"
#include "paddle/fluid/operators/log_softmax_op.h" #include "paddle/fluid/operators/log_softmax_op.h"
#include "paddle/fluid/platform/device/gpu/gpu_device_function.h" #include "paddle/fluid/platform/device/gpu/gpu_device_function.h"
#include "paddle/phi/common/amp_type_traits.h"
#include "paddle/phi/kernels/funcs/elementwise_functor.h" #include "paddle/phi/kernels/funcs/elementwise_functor.h"
#include "paddle/phi/kernels/funcs/functors.h" #include "paddle/phi/kernels/funcs/functors.h"
...@@ -311,7 +311,7 @@ void LaunchLogSoftmaxForwardCUDAKernelNotLastAxis(T *output_data, ...@@ -311,7 +311,7 @@ void LaunchLogSoftmaxForwardCUDAKernelNotLastAxis(T *output_data,
template <typename T> template <typename T>
class LogSoftmaxKernel<platform::CUDADeviceContext, T> class LogSoftmaxKernel<platform::CUDADeviceContext, T>
: public framework::OpKernel<T> { : public framework::OpKernel<T> {
using MPDType = typename details::MPTypeTrait<T>::Type; using MPDType = typename phi::dtype::MPTypeTrait<T>::Type;
public: public:
void Compute(const framework::ExecutionContext &context) const override { void Compute(const framework::ExecutionContext &context) const override {
...@@ -433,7 +433,7 @@ void LaunchSoftmaxBackwardForLastAxis(T *grad_input, const T *grad_output, ...@@ -433,7 +433,7 @@ void LaunchSoftmaxBackwardForLastAxis(T *grad_input, const T *grad_output,
template <typename T> template <typename T>
class LogSoftmaxGradKernel<platform::CUDADeviceContext, T> class LogSoftmaxGradKernel<platform::CUDADeviceContext, T>
: public framework::OpKernel<T> { : public framework::OpKernel<T> {
using MPDType = typename details::MPTypeTrait<T>::Type; using MPDType = typename phi::dtype::MPTypeTrait<T>::Type;
public: public:
void Compute(const framework::ExecutionContext &context) const override { void Compute(const framework::ExecutionContext &context) const override {
...@@ -468,16 +468,18 @@ class LogSoftmaxGradKernel<platform::CUDADeviceContext, T> ...@@ -468,16 +468,18 @@ class LogSoftmaxGradKernel<platform::CUDADeviceContext, T>
} }
}; };
} // operators } // namespace operators
} // paddle } // namespace paddle
namespace ops = paddle::operators; namespace ops = paddle::operators;
namespace plat = paddle::platform; namespace plat = paddle::platform;
REGISTER_OP_CUDA_KERNEL( REGISTER_OP_CUDA_KERNEL(
log_softmax, ops::LogSoftmaxKernel<plat::CUDADeviceContext, float>, log_softmax, ops::LogSoftmaxKernel<plat::CUDADeviceContext, float>,
ops::LogSoftmaxKernel<plat::CUDADeviceContext, double>, ops::LogSoftmaxKernel<plat::CUDADeviceContext, double>,
ops::LogSoftmaxKernel<plat::CUDADeviceContext, plat::float16>); ops::LogSoftmaxKernel<plat::CUDADeviceContext, plat::float16>,
ops::LogSoftmaxKernel<plat::CUDADeviceContext, plat::bfloat16>);
REGISTER_OP_CUDA_KERNEL( REGISTER_OP_CUDA_KERNEL(
log_softmax_grad, ops::LogSoftmaxGradKernel<plat::CUDADeviceContext, float>, log_softmax_grad, ops::LogSoftmaxGradKernel<plat::CUDADeviceContext, float>,
ops::LogSoftmaxGradKernel<plat::CUDADeviceContext, double>, ops::LogSoftmaxGradKernel<plat::CUDADeviceContext, double>,
ops::LogSoftmaxGradKernel<plat::CUDADeviceContext, plat::float16>); ops::LogSoftmaxGradKernel<plat::CUDADeviceContext, plat::float16>,
ops::LogSoftmaxGradKernel<plat::CUDADeviceContext, plat::bfloat16>);
...@@ -120,6 +120,10 @@ template class SoftmaxCUDNNFunctor<float>; ...@@ -120,6 +120,10 @@ template class SoftmaxCUDNNFunctor<float>;
template class SoftmaxCUDNNFunctor<platform::float16>; template class SoftmaxCUDNNFunctor<platform::float16>;
template class SoftmaxGradCUDNNFunctor<float>; template class SoftmaxGradCUDNNFunctor<float>;
template class SoftmaxGradCUDNNFunctor<platform::float16>; template class SoftmaxGradCUDNNFunctor<platform::float16>;
#if CUDNN_VERSION_MIN(8, 1, 0)
template class SoftmaxCUDNNFunctor<platform::bfloat16>;
template class SoftmaxGradCUDNNFunctor<platform::bfloat16>;
#endif
// MIOPEN do not support double // MIOPEN do not support double
#ifndef PADDLE_WITH_HIP #ifndef PADDLE_WITH_HIP
...@@ -131,6 +135,10 @@ template class SoftmaxFunctor<platform::CUDADeviceContext, platform::float16, ...@@ -131,6 +135,10 @@ template class SoftmaxFunctor<platform::CUDADeviceContext, platform::float16,
false>; false>;
template class SoftmaxFunctor<platform::CUDADeviceContext, platform::float16, template class SoftmaxFunctor<platform::CUDADeviceContext, platform::float16,
true>; true>;
template class SoftmaxFunctor<platform::CUDADeviceContext, platform::bfloat16,
false>;
template class SoftmaxFunctor<platform::CUDADeviceContext, platform::bfloat16,
true>;
template class SoftmaxFunctor<platform::CUDADeviceContext, float, false>; template class SoftmaxFunctor<platform::CUDADeviceContext, float, false>;
template class SoftmaxFunctor<platform::CUDADeviceContext, double, false>; template class SoftmaxFunctor<platform::CUDADeviceContext, double, false>;
template class SoftmaxFunctor<platform::CUDADeviceContext, float, true>; template class SoftmaxFunctor<platform::CUDADeviceContext, float, true>;
...@@ -139,9 +147,13 @@ template class SoftmaxGradFunctor<platform::CUDADeviceContext, float>; ...@@ -139,9 +147,13 @@ template class SoftmaxGradFunctor<platform::CUDADeviceContext, float>;
template class SoftmaxGradFunctor<platform::CUDADeviceContext, double>; template class SoftmaxGradFunctor<platform::CUDADeviceContext, double>;
template class SoftmaxGradFunctor<platform::CUDADeviceContext, template class SoftmaxGradFunctor<platform::CUDADeviceContext,
platform::float16>; platform::float16>;
template class SoftmaxGradFunctor<platform::CUDADeviceContext,
platform::bfloat16>;
template class SoftmaxFunctor<phi::GPUContext, platform::float16, false>; template class SoftmaxFunctor<phi::GPUContext, platform::float16, false>;
template class SoftmaxFunctor<phi::GPUContext, platform::float16, true>; template class SoftmaxFunctor<phi::GPUContext, platform::float16, true>;
template class SoftmaxFunctor<phi::GPUContext, platform::bfloat16, false>;
template class SoftmaxFunctor<phi::GPUContext, platform::bfloat16, true>;
template class SoftmaxFunctor<phi::GPUContext, float, false>; template class SoftmaxFunctor<phi::GPUContext, float, false>;
template class SoftmaxFunctor<phi::GPUContext, double, false>; template class SoftmaxFunctor<phi::GPUContext, double, false>;
template class SoftmaxFunctor<phi::GPUContext, float, true>; template class SoftmaxFunctor<phi::GPUContext, float, true>;
...@@ -149,6 +161,7 @@ template class SoftmaxFunctor<phi::GPUContext, double, true>; ...@@ -149,6 +161,7 @@ template class SoftmaxFunctor<phi::GPUContext, double, true>;
template class SoftmaxGradFunctor<phi::GPUContext, float>; template class SoftmaxGradFunctor<phi::GPUContext, float>;
template class SoftmaxGradFunctor<phi::GPUContext, double>; template class SoftmaxGradFunctor<phi::GPUContext, double>;
template class SoftmaxGradFunctor<phi::GPUContext, platform::float16>; template class SoftmaxGradFunctor<phi::GPUContext, platform::float16>;
template class SoftmaxGradFunctor<phi::GPUContext, platform::bfloat16>;
} // namespace math } // namespace math
} // namespace operators } // namespace operators
......
...@@ -156,6 +156,65 @@ class SoftmaxEigen<DeviceContext, platform::float16, is_test> { ...@@ -156,6 +156,65 @@ class SoftmaxEigen<DeviceContext, platform::float16, is_test> {
} }
}; };
template <typename DeviceContext, bool is_test>
class SoftmaxEigen<DeviceContext, platform::bfloat16, is_test> {
public:
void operator()(const DeviceContext& context, const int axis_dim,
const framework::Tensor* X, framework::Tensor* Y) {
constexpr int kBatchDim = 0;
constexpr int kClassDim = 1;
constexpr int kAxisDim = 1;
auto logits = EigenMatrix<platform::bfloat16>::From(*X);
auto softmax = EigenMatrix<platform::bfloat16>::From(*Y);
const int batch_size = logits.dimension(kBatchDim);
const int num_classes = logits.dimension(kClassDim);
const int num_remain = num_classes / axis_dim;
Eigen::DSizes<int, 1> along_axis(kAxisDim);
Eigen::DSizes<int, 2> batch_classes(batch_size, num_classes);
Eigen::DSizes<int, 2> batch_by_one(batch_size, 1);
Eigen::DSizes<int, 2> one_by_class(1, num_classes);
Eigen::DSizes<int, 3> batch_one_remain(batch_size, 1, num_remain);
Eigen::DSizes<int, 3> one_axis_one(1, axis_dim, 1);
Eigen::DSizes<int, 2> one_axis(1, axis_dim);
Eigen::DSizes<int, 3> batch_axis_remain(batch_size, axis_dim, num_remain);
// For numerical stability, logits should be shifted by maximum number along
// axis, calculate shifted_logits into softmax tensor for memory reuse.
if (num_remain == 1) {
// axis == -1, axis and class in same dimension, calculate along
// class dimension directly for higher performance
softmax.device(*context.eigen_device()) =
(logits -
logits.maximum(along_axis)
.reshape(batch_by_one)
.broadcast(one_by_class))
.unaryExpr(ValueClip<platform::bfloat16>());
} else {
// axis != -1, class dimension split into (axis, remain), max and sum
// should be calculated along axis dimension
softmax.device(*context.eigen_device()) =
(logits.reshape(batch_axis_remain) -
logits.reshape(batch_axis_remain)
.maximum(along_axis)
.reshape(batch_one_remain)
.broadcast(one_axis_one)
.reshape(batch_classes))
.unaryExpr(ValueClip<platform::bfloat16>());
}
softmax.device(*context.eigen_device()) = softmax.exp();
softmax.device(*context.eigen_device()) =
(softmax *
softmax.reshape(batch_axis_remain)
.sum(along_axis)
.inverse()
.broadcast(one_axis));
}
};
template <typename DeviceContext, typename T, bool is_test, typename Enable> template <typename DeviceContext, typename T, bool is_test, typename Enable>
void SoftmaxFunctor<DeviceContext, T, is_test, Enable>::operator()( void SoftmaxFunctor<DeviceContext, T, is_test, Enable>::operator()(
const DeviceContext& context, const int axis_dim, const DeviceContext& context, const int axis_dim,
...@@ -289,6 +348,38 @@ class SoftmaxGradEigen<DeviceContext, platform::float16> { ...@@ -289,6 +348,38 @@ class SoftmaxGradEigen<DeviceContext, platform::float16> {
} }
}; };
template <typename DeviceContext>
class SoftmaxGradEigen<DeviceContext, platform::bfloat16> {
public:
void operator()(const DeviceContext& context, const int axis_dim,
const framework::Tensor* y, const framework::Tensor* y_grad,
framework::Tensor* x_grad) {
auto softmax = EigenMatrix<platform::bfloat16>::From(*y);
auto softmax_grad = EigenMatrix<platform::bfloat16>::From(*y_grad);
auto logits_grad = EigenMatrix<platform::bfloat16>::From(*x_grad);
constexpr int kBatchDim = 0;
constexpr int kClassDim = 1;
const int batch_size = softmax.dimension(kBatchDim);
const int num_classes = softmax.dimension(kClassDim);
const int num_remain = num_classes / axis_dim;
Eigen::DSizes<int, 1> along_class(kClassDim);
Eigen::DSizes<int, 2> batch_by_one(batch_size, 1);
Eigen::DSizes<int, 2> one_by_class(1, num_classes);
Eigen::DSizes<int, 3> batch_axis_remain(batch_size, axis_dim, num_remain);
Eigen::DSizes<int, 2> one_axis(1, axis_dim);
auto dot = (softmax * softmax_grad)
.reshape(batch_axis_remain)
.sum(along_class)
.broadcast(one_axis);
logits_grad.device(*context.eigen_device()) =
(softmax_grad - dot) * softmax;
}
};
template <typename DeviceContext, typename T, typename Enable> template <typename DeviceContext, typename T, typename Enable>
void SoftmaxGradFunctor<DeviceContext, T, Enable>::operator()( void SoftmaxGradFunctor<DeviceContext, T, Enable>::operator()(
const DeviceContext& context, const int axis_dim, const DeviceContext& context, const int axis_dim,
......
...@@ -140,6 +140,23 @@ class CudnnDataType<float16> { ...@@ -140,6 +140,23 @@ class CudnnDataType<float16> {
} }
}; };
template <>
class CudnnDataType<bfloat16> {
public:
static const miopenDataType_t type = miopenBFloat16;
// The scaling param type is float for HALF and FLOAT tensors
using ScalingParamType = const float;
using BatchNormParamType = float;
static ScalingParamType* kOne() {
static ScalingParamType v = 1.0;
return &v;
}
static ScalingParamType* kZero() {
static ScalingParamType v = 0.0;
return &v;
}
};
template <> template <>
class CudnnDataType<float> { class CudnnDataType<float> {
public: public:
......
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/phi/common/bfloat16.h"
#include "paddle/phi/common/float16.h"
namespace phi {
namespace dtype {
template <typename T>
class MPTypeTrait {
public:
using Type = T;
};
template <>
class MPTypeTrait<phi::dtype::float16> {
public:
using Type = float;
};
template <>
class MPTypeTrait<phi::dtype::bfloat16> {
public:
using Type = float;
};
} // namespace dtype
} // namespace phi
...@@ -377,31 +377,31 @@ struct numeric_limits<phi::dtype::bfloat16> { ...@@ -377,31 +377,31 @@ struct numeric_limits<phi::dtype::bfloat16> {
static const bool traps = true; static const bool traps = true;
static const bool tinyness_before = false; static const bool tinyness_before = false;
static phi::dtype::bfloat16(min)() { HOSTDEVICE static phi::dtype::bfloat16(min)() {
return phi::dtype::raw_uint16_to_bfloat16(0x007f); return phi::dtype::raw_uint16_to_bfloat16(0x007f);
} }
static phi::dtype::bfloat16 lowest() { HOSTDEVICE static phi::dtype::bfloat16 lowest() {
return phi::dtype::raw_uint16_to_bfloat16(0xff7f); return phi::dtype::raw_uint16_to_bfloat16(0xff7f);
} }
static phi::dtype::bfloat16(max)() { HOSTDEVICE static phi::dtype::bfloat16(max)() {
return phi::dtype::raw_uint16_to_bfloat16(0x7f7f); return phi::dtype::raw_uint16_to_bfloat16(0x7f7f);
} }
static phi::dtype::bfloat16 epsilon() { HOSTDEVICE static phi::dtype::bfloat16 epsilon() {
return phi::dtype::raw_uint16_to_bfloat16(0x3400); return phi::dtype::raw_uint16_to_bfloat16(0x3400);
} }
static phi::dtype::bfloat16 round_error() { HOSTDEVICE static phi::dtype::bfloat16 round_error() {
return phi::dtype::bfloat16(0.5); return phi::dtype::bfloat16(0.5);
} }
static phi::dtype::bfloat16 infinity() { HOSTDEVICE static phi::dtype::bfloat16 infinity() {
return phi::dtype::raw_uint16_to_bfloat16(0x7f80); return phi::dtype::raw_uint16_to_bfloat16(0x7f80);
} }
static phi::dtype::bfloat16 quiet_NaN() { HOSTDEVICE static phi::dtype::bfloat16 quiet_NaN() {
return phi::dtype::raw_uint16_to_bfloat16(0xffc1); return phi::dtype::raw_uint16_to_bfloat16(0xffc1);
} }
static phi::dtype::bfloat16 signaling_NaN() { HOSTDEVICE static phi::dtype::bfloat16 signaling_NaN() {
return phi::dtype::raw_uint16_to_bfloat16(0xff81); return phi::dtype::raw_uint16_to_bfloat16(0xff81);
} }
static phi::dtype::bfloat16 denorm_min() { HOSTDEVICE static phi::dtype::bfloat16 denorm_min() {
return phi::dtype::raw_uint16_to_bfloat16(0x0001); return phi::dtype::raw_uint16_to_bfloat16(0x0001);
} }
}; };
......
...@@ -988,18 +988,6 @@ inline std::ostream& operator<<(std::ostream& os, const float16& a) { ...@@ -988,18 +988,6 @@ inline std::ostream& operator<<(std::ostream& os, const float16& a) {
return os; return os;
} }
template <typename T>
class MPTypeTrait {
public:
using Type = T;
};
template <>
class MPTypeTrait<float16> {
public:
using Type = float;
};
} // namespace dtype } // namespace dtype
} // namespace phi } // namespace phi
......
...@@ -15,6 +15,7 @@ limitations under the License. */ ...@@ -15,6 +15,7 @@ limitations under the License. */
#include "paddle/phi/kernels/softmax_grad_kernel.h" #include "paddle/phi/kernels/softmax_grad_kernel.h"
#include "paddle/phi/backends/gpu/gpu_context.h" #include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/common/bfloat16.h"
#include "paddle/phi/common/float16.h" #include "paddle/phi/common/float16.h"
#include "paddle/phi/core/kernel_registry.h" #include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/impl/softmax_grad_kernel_impl.h" #include "paddle/phi/kernels/impl/softmax_grad_kernel_impl.h"
...@@ -25,4 +26,5 @@ PD_REGISTER_KERNEL(softmax_grad, ...@@ -25,4 +26,5 @@ PD_REGISTER_KERNEL(softmax_grad,
phi::SoftmaxGradKernel, phi::SoftmaxGradKernel,
float, float,
double, double,
phi::dtype::float16) {} phi::dtype::float16,
phi::dtype::bfloat16) {}
...@@ -15,6 +15,7 @@ limitations under the License. */ ...@@ -15,6 +15,7 @@ limitations under the License. */
#include "paddle/phi/kernels/softmax_kernel.h" #include "paddle/phi/kernels/softmax_kernel.h"
#include "paddle/phi/backends/gpu/gpu_context.h" #include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/common/bfloat16.h"
#include "paddle/phi/common/float16.h" #include "paddle/phi/common/float16.h"
#include "paddle/phi/core/kernel_registry.h" #include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/impl/softmax_kernel_impl.h" #include "paddle/phi/kernels/impl/softmax_kernel_impl.h"
...@@ -25,4 +26,5 @@ PD_REGISTER_KERNEL(softmax, ...@@ -25,4 +26,5 @@ PD_REGISTER_KERNEL(softmax,
phi::SoftmaxRawKernel, phi::SoftmaxRawKernel,
float, float,
double, double,
phi::dtype::float16) {} phi::dtype::float16,
phi::dtype::bfloat16) {}
...@@ -15,6 +15,8 @@ limitations under the License. */ ...@@ -15,6 +15,8 @@ limitations under the License. */
#pragma once #pragma once
#include "paddle/phi/backends/gpu/gpu_info.h" #include "paddle/phi/backends/gpu/gpu_info.h"
#include "paddle/phi/common/amp_type_traits.h"
#include "paddle/phi/common/bfloat16.h"
#include "paddle/phi/common/float16.h" #include "paddle/phi/common/float16.h"
#include "paddle/phi/core/dense_tensor.h" #include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/kernels/funcs/axis_utils.h" #include "paddle/phi/kernels/funcs/axis_utils.h"
...@@ -47,6 +49,11 @@ class VecT4<phi::dtype::float16> { ...@@ -47,6 +49,11 @@ class VecT4<phi::dtype::float16> {
public: public:
using Type = int2; using Type = int2;
}; };
template <>
class VecT4<phi::dtype::bfloat16> {
public:
using Type = int2;
};
// Vectorization trait 2 * sizeof(T) // Vectorization trait 2 * sizeof(T)
template <typename T> template <typename T>
...@@ -66,6 +73,11 @@ class VecT2<phi::dtype::float16> { ...@@ -66,6 +73,11 @@ class VecT2<phi::dtype::float16> {
public: public:
using Type = int; using Type = int;
}; };
template <>
class VecT2<phi::dtype::bfloat16> {
public:
using Type = int;
};
static inline int log2_ceil(int value) { static inline int log2_ceil(int value) {
int log2_value = 0; int log2_value = 0;
......
...@@ -38,7 +38,18 @@ PD_REGISTER_KERNEL(softmax_grad, ...@@ -38,7 +38,18 @@ PD_REGISTER_KERNEL(softmax_grad,
ALL_LAYOUT, ALL_LAYOUT,
phi::SoftmaxGradGPUDNNKernel, phi::SoftmaxGradGPUDNNKernel,
float, float,
phi::dtype::float16) {} phi::dtype::float16,
phi::dtype::bfloat16) {}
#else
#if CUDNN_VERSION_MIN(8, 1, 0)
PD_REGISTER_KERNEL(softmax_grad,
GPUDNN,
ALL_LAYOUT,
phi::SoftmaxGradGPUDNNKernel,
float,
double,
phi::dtype::float16,
phi::dtype::bfloat16) {}
#else #else
PD_REGISTER_KERNEL(softmax_grad, PD_REGISTER_KERNEL(softmax_grad,
GPUDNN, GPUDNN,
...@@ -48,3 +59,4 @@ PD_REGISTER_KERNEL(softmax_grad, ...@@ -48,3 +59,4 @@ PD_REGISTER_KERNEL(softmax_grad,
double, double,
phi::dtype::float16) {} phi::dtype::float16) {}
#endif #endif
#endif
...@@ -37,7 +37,18 @@ PD_REGISTER_KERNEL(softmax, ...@@ -37,7 +37,18 @@ PD_REGISTER_KERNEL(softmax,
ALL_LAYOUT, ALL_LAYOUT,
phi::SoftmaxRawGPUDNNKernel, phi::SoftmaxRawGPUDNNKernel,
float, float,
phi::dtype::float16) {} phi::dtype::float16,
phi::dtype::bfloat16) {}
#else
#if CUDNN_VERSION_MIN(8, 1, 0)
PD_REGISTER_KERNEL(softmax,
GPUDNN,
ALL_LAYOUT,
phi::SoftmaxRawGPUDNNKernel,
float,
double,
phi::dtype::float16,
phi::dtype::bfloat16) {}
#else #else
PD_REGISTER_KERNEL(softmax, PD_REGISTER_KERNEL(softmax,
GPUDNN, GPUDNN,
...@@ -47,3 +58,4 @@ PD_REGISTER_KERNEL(softmax, ...@@ -47,3 +58,4 @@ PD_REGISTER_KERNEL(softmax,
double, double,
phi::dtype::float16) {} phi::dtype::float16) {}
#endif #endif
#endif
...@@ -14,8 +14,9 @@ ...@@ -14,8 +14,9 @@
import unittest import unittest
import numpy as np import numpy as np
from paddle.fluid.tests.unittests.op_test import OpTest from paddle.fluid.tests.unittests.op_test import OpTest, convert_float_to_uint16
import paddle import paddle
import paddle.fluid.core as core
import paddle.nn.functional as F import paddle.nn.functional as F
np.random.seed(10) np.random.seed(10)
...@@ -74,6 +75,33 @@ class TestLogSoftmaxAxis(TestLogSoftmaxOp): ...@@ -74,6 +75,33 @@ class TestLogSoftmaxAxis(TestLogSoftmaxOp):
self.axis = 1 self.axis = 1
@unittest.skipIf(not core.is_compiled_with_cuda(),
"core is not compiled with CUDA")
class TestLogSoftmaxBF16Op(OpTest):
def setUp(self):
self.op_type = 'log_softmax'
self.dtype = np.uint16
self.shape = [2, 3, 4, 5]
self.axis = -1
x = np.random.uniform(0.1, 1., self.shape).astype(np.float32)
out = np.apply_along_axis(ref_log_softmax, self.axis, x)
self.x_grad = ref_log_softmax_grad(x, self.axis)
self.inputs = {'X': convert_float_to_uint16(x)}
self.outputs = {'Out': convert_float_to_uint16(out)}
self.attrs = {'axis': self.axis}
def test_check_output(self):
place = core.CUDAPlace(0)
self.check_output_with_place(place)
def test_check_grad(self):
place = core.CUDAPlace(0)
self.check_grad_with_place(
place, ['X'], ['Out'], user_defined_grads=[self.x_grad])
class TestNNLogSoftmaxAPI(unittest.TestCase): class TestNNLogSoftmaxAPI(unittest.TestCase):
def setUp(self): def setUp(self):
self.x_shape = [2, 3, 4, 5] self.x_shape = [2, 3, 4, 5]
......
...@@ -16,7 +16,7 @@ from __future__ import print_function ...@@ -16,7 +16,7 @@ from __future__ import print_function
import unittest import unittest
import numpy as np import numpy as np
from op_test import OpTest from op_test import OpTest, convert_float_to_uint16
import paddle.fluid.core as core import paddle.fluid.core as core
import paddle.fluid as fluid import paddle.fluid as fluid
from paddle.fluid import compiler, Program, program_guard from paddle.fluid import compiler, Program, program_guard
...@@ -296,6 +296,56 @@ class TestSoftmaxFP16CUDNNOp2(TestSoftmaxFP16CUDNNOp): ...@@ -296,6 +296,56 @@ class TestSoftmaxFP16CUDNNOp2(TestSoftmaxFP16CUDNNOp):
return [2, 3, 4, 5] return [2, 3, 4, 5]
@unittest.skipIf(not core.is_compiled_with_cuda(),
"core is not compiled with CUDA")
class TestSoftmaxBF16Op(OpTest):
def setUp(self):
self.op_type = "softmax"
self.use_cudnn = self.init_cudnn()
self.use_mkldnn = False
self.dtype = np.uint16
self.shape = [10, 10]
self.axis = -1
np.random.seed(0)
x = np.random.uniform(0.1, 1, self.shape).astype(np.float32)
out = np.apply_along_axis(stable_softmax, self.axis, x)
self.inputs = {
'X': OpTest.np_dtype_to_fluid_dtype(convert_float_to_uint16(x))
}
self.outputs = {'Out': convert_float_to_uint16(out)}
self.attrs = {
'axis': self.axis,
'use_cudnn': self.use_cudnn,
'use_mkldnn': self.use_mkldnn
}
def init_cudnn(self):
return False
def test_check_output(self):
place = core.CUDAPlace(0)
self.check_output_with_place(
place, check_dygraph=(self.use_mkldnn == False))
def test_check_grad(self):
place = core.CUDAPlace(0)
self.check_grad_with_place(
place, ["X"],
"Out",
numeric_grad_delta=0.05,
check_dygraph=(self.use_mkldnn == False))
@unittest.skipIf(
not core.is_compiled_with_cuda() or core.cudnn_version() < 8100,
"core is not compiled with CUDA and cudnn version need larger than 8.1.0")
class TestSoftmaxBF16CUDNNOp(TestSoftmaxBF16Op):
def init_cudnn(self):
return True
class TestSoftmaxAPI(unittest.TestCase): class TestSoftmaxAPI(unittest.TestCase):
def setUp(self): def setUp(self):
self.place = paddle.CUDAPlace(0) if core.is_compiled_with_cuda( self.place = paddle.CUDAPlace(0) if core.is_compiled_with_cuda(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册