Unverified commit aaf3a13e authored by sneaxiy, committed by GitHub

add bfloat16 support for more ops (#48272)

* add bfloat16 support for more ops

* fix ci compile

* fix windows compile error

* fix windows compile error

* fix rocm compile error

* fix ROCM compile error
Parent: 1b59830b
......@@ -14,6 +14,7 @@ limitations under the License. */
#include "paddle/phi/kernels/funcs/cross_entropy.h"
#include "paddle/fluid/platform/device/gpu/gpu_dnn.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/backends/gpu/gpu_device_function.h"
#include "paddle/phi/backends/gpu/gpu_primitives.h"
......@@ -153,6 +154,9 @@ void CrossEntropyFunctor<DeviceContext, T>::operator()(
template class CrossEntropyFunctor<phi::GPUContext, float>;
template class CrossEntropyFunctor<phi::GPUContext, double>;
template class CrossEntropyFunctor<phi::GPUContext, phi::dtype::float16>;
#if defined(PADDLE_WITH_CUDA) && CUDNN_VERSION_MIN(8, 1, 0)
template class CrossEntropyFunctor<phi::GPUContext, phi::dtype::bfloat16>;
#endif
} // namespace funcs
} // namespace phi
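The new explicit instantiation is guarded so that a bfloat16 CrossEntropyFunctor exists only in CUDA builds with cuDNN 8.1 or newer, the version Paddle treats as the baseline for bfloat16 support; ROCm builds and older cuDNN builds compile exactly what they did before. A minimal sketch of the guard pattern, with a hypothetical functor name standing in for the real one:

// Sketch only: SomeFunctor is a hypothetical placeholder.
// The bfloat16 variant is instantiated solely when the toolchain
// can actually compile bfloat16 GPU math (CUDA + cuDNN >= 8.1).
#if defined(PADDLE_WITH_CUDA) && CUDNN_VERSION_MIN(8, 1, 0)
template class SomeFunctor<phi::GPUContext, phi::dtype::bfloat16>;
#endif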
......@@ -15,6 +15,7 @@ limitations under the License. */
#pragma once
#include <limits>
#include "paddle/phi/common/bfloat16.h"
#include "paddle/phi/common/float16.h"
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/hostdevice.h"
......@@ -48,12 +49,27 @@ template <>
struct TolerableValue<phi::dtype::float16> {
  HOSTDEVICE phi::dtype::float16 operator()(
      const phi::dtype::float16& x) const {
    if (phi::dtype::isfinite(x)) {
      return x;
    } else if (x > static_cast<phi::dtype::float16>(0)) {
      return std::numeric_limits<phi::dtype::float16>::max();
    } else {
      return std::numeric_limits<phi::dtype::float16>::min();
    }
  }
};

template <>
struct TolerableValue<phi::dtype::bfloat16> {
  HOSTDEVICE phi::dtype::bfloat16 operator()(
      const phi::dtype::bfloat16& x) const {
    if (phi::dtype::isfinite(x)) {
      return x;
    } else if (x > static_cast<phi::dtype::bfloat16>(0)) {
      return std::numeric_limits<phi::dtype::bfloat16>::max();
    } else {
      return std::numeric_limits<phi::dtype::bfloat16>::min();
    }
  }
};
......
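TolerableValue clamps non-finite values so that an overflowing loss term becomes a large finite number instead of propagating inf or NaN through the reduction. A host-side sketch of the intended behavior of the new bfloat16 specialization (hypothetical test code, not part of the patch; it assumes the std::numeric_limits specialization that Paddle's bfloat16.h provides):

// +inf is clamped to max(); -inf and NaN fall through to the else
// branch (NaN compares false against 0) and become min().
phi::dtype::bfloat16 inf = std::numeric_limits<phi::dtype::bfloat16>::infinity();
TolerableValue<phi::dtype::bfloat16> tolerate;
phi::dtype::bfloat16 clamped = tolerate(inf);
// clamped == std::numeric_limits<phi::dtype::bfloat16>::max()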
......@@ -33,6 +33,10 @@ inline HOSTDEVICE phi::dtype::float16 real_log(phi::dtype::float16 x) {
  return static_cast<phi::dtype::float16>(::logf(static_cast<float>(x)));
}

inline HOSTDEVICE phi::dtype::bfloat16 real_log(phi::dtype::bfloat16 x) {
  return static_cast<phi::dtype::bfloat16>(::logf(static_cast<float>(x)));
}
inline HOSTDEVICE float real_log(float x) { return ::logf(x); }
inline HOSTDEVICE double real_log(double x) { return ::log(x); }
......
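The bfloat16 overload of real_log uses the same recipe as the existing float16 one: widen to float, evaluate with ::logf, and narrow the result back, which keeps a single implementation correct on both host and device. The pattern generalizes to other 16-bit math helpers; a sketch with a hypothetical helper name:

// Hypothetical helper mirroring the overloads above: compute in
// float, store in the narrow 16-bit type T (float16 or bfloat16).
template <typename T>
inline HOSTDEVICE T log_via_float(T x) {
  return static_cast<T>(::logf(static_cast<float>(x)));
}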
......@@ -253,6 +253,7 @@ PD_REGISTER_KERNEL(arg_min,
ALL_LAYOUT,
phi::ArgMinKernel,
phi::dtype::float16,
phi::dtype::bfloat16,
float,
double,
int32_t,
......@@ -265,6 +266,7 @@ PD_REGISTER_KERNEL(arg_max,
ALL_LAYOUT,
phi::ArgMaxKernel,
phi::dtype::float16,
phi::dtype::bfloat16,
float,
double,
int32_t,
......
......@@ -281,6 +281,7 @@ void CrossEntropyWithSoftmaxGradKernel(const Context& dev_ctx,
} // namespace phi
#ifdef PADDLE_WITH_HIP
PD_REGISTER_KERNEL(cross_entropy_with_softmax_grad,
GPU,
ALL_LAYOUT,
......@@ -288,3 +289,23 @@ PD_REGISTER_KERNEL(cross_entropy_with_softmax_grad,
float,
double,
phi::dtype::float16) {}
#else
#if CUDNN_VERSION_MIN(8, 1, 0)
PD_REGISTER_KERNEL(cross_entropy_with_softmax_grad,
GPU,
ALL_LAYOUT,
phi::CrossEntropyWithSoftmaxGradKernel,
float,
double,
phi::dtype::float16,
phi::dtype::bfloat16) {}
#else
PD_REGISTER_KERNEL(cross_entropy_with_softmax_grad,
GPU,
ALL_LAYOUT,
phi::CrossEntropyWithSoftmaxGradKernel,
float,
double,
phi::dtype::float16) {}
#endif
#endif
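The registration is therefore three-way: ROCm builds keep the float/double/float16 list, CUDA builds with cuDNN 8.1 or newer add bfloat16, and CUDA builds with an older cuDNN fall back to the original list. The forward kernel below mirrors the same selection. A compile-time sketch of the choice (constant name hypothetical; it assumes gpu_dnn.h is included so CUDNN_VERSION_MIN is defined in CUDA builds):

#ifdef PADDLE_WITH_HIP
static constexpr bool kHasBF16Softmax = false;  // ROCm: no bfloat16 here
#elif CUDNN_VERSION_MIN(8, 1, 0)
static constexpr bool kHasBF16Softmax = true;   // CUDA + cuDNN >= 8.1
#else
static constexpr bool kHasBF16Softmax = false;  // CUDA + older cuDNN
#endif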
......@@ -1468,6 +1468,16 @@ PD_REGISTER_KERNEL(cross_entropy_with_softmax,
float,
phi::dtype::float16) {}
#else
#if CUDNN_VERSION_MIN(8, 1, 0)
PD_REGISTER_KERNEL(cross_entropy_with_softmax,
GPU,
ALL_LAYOUT,
phi::CrossEntropyWithSoftmaxKernel,
float,
double,
phi::dtype::float16,
phi::dtype::bfloat16) {}
#else
PD_REGISTER_KERNEL(cross_entropy_with_softmax,
GPU,
ALL_LAYOUT,
......@@ -1476,3 +1486,4 @@ PD_REGISTER_KERNEL(cross_entropy_with_softmax,
double,
phi::dtype::float16) {}
#endif
#endif
......@@ -130,6 +130,8 @@ PD_REGISTER_KERNEL(index_sample_grad,
GPU,
ALL_LAYOUT,
phi::IndexSampleGradKernel,
phi::dtype::float16,
phi::dtype::bfloat16,
float,
double,
int,
......
......@@ -103,6 +103,8 @@ PD_REGISTER_KERNEL(index_sample,
GPU,
ALL_LAYOUT,
phi::IndexSampleKernel,
phi::dtype::float16,
phi::dtype::bfloat16,
float,
double,
int,
......
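Unlike the softmax kernels, index_sample, its gradient, and the arg_min/arg_max kernels above register bfloat16 unconditionally: they gather, scatter, or compare elements, so bfloat16 acts essentially as a 16-bit storage type and no cuDNN bfloat16 arithmetic is involved. A sketch of why storage-only ops need no version guard (hypothetical kernel, not the real implementation):

// Gathering copies raw 16-bit values; no bfloat16 math is performed,
// so any toolchain that compiles phi::dtype::bfloat16 can run this.
__global__ void gather_bf16(const phi::dtype::bfloat16* src,
                            const int64_t* index,
                            phi::dtype::bfloat16* dst,
                            int64_t n) {
  int64_t i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) dst[i] = src[index[i]];
}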