Unverified commit aaf3a13e authored by sneaxiy, committed by GitHub

add bfloat16 support for more ops (#48272)

* add bfloat16 support for more ops

* fix ci compile

* fix windows compile error

* fix windows compile error

* fix rocm compile error

* fix ROCM compile error
Parent 1b59830b
@@ -14,6 +14,7 @@ limitations under the License. */
 #include "paddle/phi/kernels/funcs/cross_entropy.h"
+#include "paddle/fluid/platform/device/gpu/gpu_dnn.h"
 #include "paddle/phi/backends/gpu/gpu_context.h"
 #include "paddle/phi/backends/gpu/gpu_device_function.h"
 #include "paddle/phi/backends/gpu/gpu_primitives.h"
@@ -153,6 +154,9 @@ void CrossEntropyFunctor<DeviceContext, T>::operator()(
 template class CrossEntropyFunctor<phi::GPUContext, float>;
 template class CrossEntropyFunctor<phi::GPUContext, double>;
 template class CrossEntropyFunctor<phi::GPUContext, phi::dtype::float16>;
+#if defined(PADDLE_WITH_CUDA) && CUDNN_VERSION_MIN(8, 1, 0)
+template class CrossEntropyFunctor<phi::GPUContext, phi::dtype::bfloat16>;
+#endif
 }  // namespace funcs
 }  // namespace phi
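For context, the new instantiation is guarded because only sufficiently new cuDNN toolchains can compile the bfloat16 path, and ROCm builds must not see it; cuDNN's bfloat16 support is generally tied to the 8.1 release. A minimal sketch of how such a version guard works, assuming cuDNN's usual encoding of CUDNN_VERSION (the real CUDNN_VERSION_MIN ships in the gpu_dnn.h header included above):

```cpp
#include <cstdio>

// Sketch only: Paddle's real CUDNN_VERSION_MIN comes from gpu_dnn.h.
// It assumes cuDNN's usual version encoding:
//   CUDNN_VERSION = major * 1000 + minor * 100 + patchlevel
#ifndef CUDNN_VERSION
#define CUDNN_VERSION 8100  // pretend cuDNN 8.1.0 for this sketch
#endif
#define CUDNN_VERSION_MIN(major, minor, patch) \
  (CUDNN_VERSION >= ((major)*1000 + (minor)*100 + (patch)))

int main() {
  // Prints 1 when the headers are new enough to compile the bfloat16 paths.
  std::printf("%d\n", CUDNN_VERSION_MIN(8, 1, 0));
}
```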
@@ -15,6 +15,7 @@ limitations under the License. */
 #pragma once
 #include <limits>
+#include "paddle/phi/common/bfloat16.h"
 #include "paddle/phi/common/float16.h"
 #include "paddle/phi/core/dense_tensor.h"
 #include "paddle/phi/core/hostdevice.h"
@@ -48,12 +49,27 @@ template <>
 struct TolerableValue<phi::dtype::float16> {
   HOSTDEVICE phi::dtype::float16 operator()(
       const phi::dtype::float16& x) const {
-    if (phi::dtype::isfinite(x))
+    if (phi::dtype::isfinite(x)) {
       return x;
-    else if (x > static_cast<phi::dtype::float16>(0))
+    } else if (x > static_cast<phi::dtype::float16>(0)) {
       return std::numeric_limits<phi::dtype::float16>::max();
-    else
+    } else {
       return std::numeric_limits<phi::dtype::float16>::min();
+    }
+  }
+};
+
+template <>
+struct TolerableValue<phi::dtype::bfloat16> {
+  HOSTDEVICE phi::dtype::bfloat16 operator()(
+      const phi::dtype::bfloat16& x) const {
+    if (phi::dtype::isfinite(x)) {
+      return x;
+    } else if (x > static_cast<phi::dtype::bfloat16>(0)) {
+      return std::numeric_limits<phi::dtype::bfloat16>::max();
+    } else {
+      return std::numeric_limits<phi::dtype::bfloat16>::min();
+    }
   }
 };
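The bfloat16 specialization mirrors the float16 one: non-finite values are clamped to the type's numeric_limits extremes so downstream reductions stay finite (phi specializes std::numeric_limits for its 16-bit types, which is what makes max()/min() meaningful above). A standalone sketch of the same clamp for plain float:

```cpp
#include <cmath>
#include <cstdio>
#include <limits>

// Same clamping pattern as the specializations above, shown for plain
// float so it compiles standalone. Note that min() is the smallest
// positive normal for floating-point types, which matches what the
// float16/bfloat16 specializations return for -inf and NaN.
template <typename T>
struct TolerableValue {
  T operator()(const T& x) const {
    if (std::isfinite(x)) {
      return x;
    } else if (x > static_cast<T>(0)) {
      return std::numeric_limits<T>::max();
    } else {
      return std::numeric_limits<T>::min();
    }
  }
};

int main() {
  TolerableValue<float> tol;
  std::printf("%g\n", tol(1.5f));        // 1.5 (finite passes through)
  std::printf("%g\n", tol(HUGE_VALF));   // 3.40282e+38 (+inf -> max)
  std::printf("%g\n", tol(-HUGE_VALF));  // 1.17549e-38 (-inf -> min)
}
```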
@@ -33,6 +33,10 @@ inline HOSTDEVICE phi::dtype::float16 real_log(phi::dtype::float16 x) {
   return static_cast<phi::dtype::float16>(::logf(static_cast<float>(x)));
 }
 
+inline HOSTDEVICE phi::dtype::bfloat16 real_log(phi::dtype::bfloat16 x) {
+  return static_cast<phi::dtype::bfloat16>(::logf(static_cast<float>(x)));
+}
+
 inline HOSTDEVICE float real_log(float x) { return ::logf(x); }
 
 inline HOSTDEVICE double real_log(double x) { return ::log(x); }
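The new overload repeats the widen-compute-narrow pattern used for float16: there is no native bfloat16 log routine, so the operand is promoted to float, ::logf does the work, and the result is truncated back. A standalone sketch of the pattern, with bf16 as a hypothetical stand-in for phi::dtype::bfloat16:

```cpp
#include <cmath>
#include <cstdio>

// Hypothetical stand-in for phi::dtype::bfloat16; all the pattern needs
// is conversion to and from float.
struct bf16 {
  float v;  // stored widened, purely for this sketch
  explicit bf16(float f) : v(f) {}
  explicit operator float() const { return v; }
};

// Overload set mirroring the header: each type computes the log at the
// precision it can reach, reduced-precision types via float.
inline float real_log(float x) { return ::logf(x); }
inline double real_log(double x) { return ::log(x); }
inline bf16 real_log(bf16 x) {
  return bf16(::logf(static_cast<float>(x)));  // widen, log, narrow
}

int main() {
  std::printf("%f\n", static_cast<float>(real_log(bf16(2.718282f))));  // ~1.0
}
```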
@@ -253,6 +253,7 @@ PD_REGISTER_KERNEL(arg_min,
                    ALL_LAYOUT,
                    phi::ArgMinKernel,
                    phi::dtype::float16,
+                   phi::dtype::bfloat16,
                    float,
                    double,
                    int32_t,
@@ -265,6 +266,7 @@ PD_REGISTER_KERNEL(arg_max,
                    ALL_LAYOUT,
                    phi::ArgMaxKernel,
                    phi::dtype::float16,
+                   phi::dtype::bfloat16,
                    float,
                    double,
                    int32_t,
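The arg_min/arg_max kernels only ever compare elements, so any type with operator< / operator> slots in; registering bfloat16 requires no new math. A toy scalar reduction making that point (arg_max below is a hypothetical host-side version, not phi::ArgMaxKernel):

```cpp
#include <cstddef>
#include <cstdio>

// Toy version of the reduction inside an argmax kernel: the element type
// only needs operator>, which phi::dtype::bfloat16 provides.
template <typename T>
std::size_t arg_max(const T* data, std::size_t n) {
  std::size_t best = 0;
  for (std::size_t i = 1; i < n; ++i) {
    if (data[i] > data[best]) best = i;  // comparison is all that's required
  }
  return best;
}

int main() {
  const float xs[] = {0.25f, 1.5f, -3.0f, 1.0f};
  std::printf("%zu\n", arg_max(xs, 4));  // prints 1
}
```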
@@ -281,6 +281,7 @@ void CrossEntropyWithSoftmaxGradKernel(const Context& dev_ctx,
 }  // namespace phi
+#ifdef PADDLE_WITH_HIP
 PD_REGISTER_KERNEL(cross_entropy_with_softmax_grad,
                    GPU,
                    ALL_LAYOUT,
@@ -288,3 +289,23 @@ PD_REGISTER_KERNEL(cross_entropy_with_softmax_grad,
                    float,
                    double,
                    phi::dtype::float16) {}
+#else
+#if CUDNN_VERSION_MIN(8, 1, 0)
+PD_REGISTER_KERNEL(cross_entropy_with_softmax_grad,
+                   GPU,
+                   ALL_LAYOUT,
+                   phi::CrossEntropyWithSoftmaxGradKernel,
+                   float,
+                   double,
+                   phi::dtype::float16,
+                   phi::dtype::bfloat16) {}
+#else
+PD_REGISTER_KERNEL(cross_entropy_with_softmax_grad,
+                   GPU,
+                   ALL_LAYOUT,
+                   phi::CrossEntropyWithSoftmaxGradKernel,
+                   float,
+                   double,
+                   phi::dtype::float16) {}
+#endif
+#endif
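The grad registration now branches three ways: ROCm builds keep the original float/double/float16 list, CUDA builds with cuDNN 8.1+ additionally register bfloat16, and older CUDA builds fall back to the original list. Conceptually, PD_REGISTER_KERNEL populates a per-op dtype table, so the preprocessor branch simply decides how many dtypes a given build advertises. A toy sketch under that assumption (the registry and the SKETCH_WITH_BF16 flag are hypothetical, not phi internals):

```cpp
#include <cstdio>
#include <map>
#include <string>

// Toy stand-in for the (op, dtype) table that PD_REGISTER_KERNEL fills.
enum class DType { kFloat32, kFloat64, kFloat16, kBFloat16 };

static std::map<std::string, int> g_registry;

void Register(const std::string& op, DType dt) {
  g_registry[op + "/" + std::to_string(static_cast<int>(dt))] = 1;
}

int main() {
  // Every build registers the base list...
  Register("cross_entropy_with_softmax_grad", DType::kFloat32);
  Register("cross_entropy_with_softmax_grad", DType::kFloat64);
  Register("cross_entropy_with_softmax_grad", DType::kFloat16);
#ifdef SKETCH_WITH_BF16  // hypothetical flag standing in for
                         // !PADDLE_WITH_HIP && CUDNN_VERSION_MIN(8, 1, 0)
  Register("cross_entropy_with_softmax_grad", DType::kBFloat16);
#endif
  std::printf("registered %zu dtypes\n", g_registry.size());  // 3 or 4
}
```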
@@ -1468,6 +1468,16 @@ PD_REGISTER_KERNEL(cross_entropy_with_softmax,
                    float,
                    phi::dtype::float16) {}
 #else
+#if CUDNN_VERSION_MIN(8, 1, 0)
+PD_REGISTER_KERNEL(cross_entropy_with_softmax,
+                   GPU,
+                   ALL_LAYOUT,
+                   phi::CrossEntropyWithSoftmaxKernel,
+                   float,
+                   double,
+                   phi::dtype::float16,
+                   phi::dtype::bfloat16) {}
+#else
 PD_REGISTER_KERNEL(cross_entropy_with_softmax,
                    GPU,
                    ALL_LAYOUT,
@@ -1476,3 +1486,4 @@ PD_REGISTER_KERNEL(cross_entropy_with_softmax,
                    double,
                    phi::dtype::float16) {}
 #endif
+#endif
@@ -130,6 +130,8 @@ PD_REGISTER_KERNEL(index_sample_grad,
                    GPU,
                    ALL_LAYOUT,
                    phi::IndexSampleGradKernel,
+                   phi::dtype::float16,
+                   phi::dtype::bfloat16,
                    float,
                    double,
                    int,
@@ -103,6 +103,8 @@ PD_REGISTER_KERNEL(index_sample,
                    GPU,
                    ALL_LAYOUT,
                    phi::IndexSampleKernel,
+                   phi::dtype::float16,
+                   phi::dtype::bfloat16,
                    float,
                    double,
                    int,
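index_sample is a per-row gather, out[i][j] = x[i][index[i][j]], so the kernel logic is dtype-agnostic and supporting float16/bfloat16 here is purely a registration change. A host-side sketch of the semantics, assuming 2-D inputs (this is not phi's GPU kernel):

```cpp
#include <cstddef>
#include <cstdio>
#include <vector>

// Host-side sketch of what index_sample computes, independent of the
// element type T: for each row i, out[i][j] = x[i][index[i][j]].
template <typename T>
std::vector<std::vector<T>> IndexSample(
    const std::vector<std::vector<T>>& x,
    const std::vector<std::vector<int>>& index) {
  std::vector<std::vector<T>> out(x.size());
  for (std::size_t i = 0; i < x.size(); ++i) {
    for (int j : index[i]) {
      out[i].push_back(x[i][j]);
    }
  }
  return out;
}

int main() {
  const std::vector<std::vector<float>> x = {{1, 2, 3}, {4, 5, 6}};
  const std::vector<std::vector<int>> idx = {{2, 0}, {1, 1}};
  for (const auto& row : IndexSample(x, idx)) {
    for (float v : row) std::printf("%g ", v);
    std::printf("\n");
  }
  // prints:
  // 3 1
  // 5 5
}
```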