Unverified commit 288ad844, authored by Lin Manhui, committed by GitHub

[AMP] Add bfloat16 Support for `elementwise_pow` Op (#51888)

* Add bf16 support for elementwise_pow

* Update ut
Parent 4bf1c163
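As a quick usage sketch of what this commit enables: once the bfloat16 kernels below are registered, `paddle.pow` can be fed bfloat16 tensors directly. The snippet is illustrative only and assumes a Paddle build that already contains this change; the shapes and the cast-based setup are not taken from the PR.

```python
import paddle

# Illustrative only: exercise the new bfloat16 elementwise_pow kernel.
x = paddle.cast(paddle.rand([20, 5]), 'bfloat16')
y = paddle.cast(paddle.rand([20, 5]), 'bfloat16')

out = paddle.pow(x, y)  # should dispatch the bfloat16 kernel registered below
print(out.dtype)        # expected: paddle.bfloat16
```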
@@ -105,4 +105,5 @@ PD_REGISTER_KERNEL(elementwise_pow_grad,
float,
double,
int,
int64_t) {}
int64_t,
phi::dtype::bfloat16) {}
@@ -166,7 +166,8 @@ PD_REGISTER_KERNEL(elementwise_pow_raw,
float,
double,
int,
int64_t) {}
int64_t,
phi::dtype::bfloat16) {}
PD_REGISTER_KERNEL(heaviside,
CPU,
......
@@ -140,7 +140,8 @@ PD_REGISTER_KERNEL(elementwise_pow,
float,
double,
int,
int64_t) {}
int64_t,
phi::dtype::bfloat16) {}
PD_REGISTER_KERNEL(subtract,
CPU,
@@ -232,7 +233,8 @@ PD_REGISTER_KERNEL(elementwise_pow,
double,
int,
int64_t,
phi::dtype::float16) {}
phi::dtype::float16,
phi::dtype::bfloat16) {}
#endif
......
@@ -23,6 +23,7 @@ limitations under the License. */
#include "xpu/kernel/math_xpu2.h" // pow()
#endif
#include "paddle/phi/common/amp_type_traits.h"
namespace phi {
namespace funcs {
@@ -585,68 +586,55 @@ struct InverseFloorDivideFunctor {
}
};
template <typename T>
struct ElementwisePowFunctor {
inline HOSTDEVICE T operator()(const T a, const T b) const {
// TODO(wujionghao): A potential speed improvement is supporting different
// types in C++.
#if defined(__CUDA_ARCH__) || defined(__HIPCC__)
// On CUDAPlace, std::pow(3, 1) calls pow(float, float), and
// it will return a float number like 2.99... , which floor to 2
// when cast to int by default and it is wrong.
// Use llrint to cast it to the nearest integer, which is 3.
if (std::is_integral<T>::value) {
return std::llrint(
std::pow(static_cast<double>(a), static_cast<double>(b)));
}
template <typename T, typename MPType>
inline HOSTDEVICE typename std::enable_if<std::is_integral<T>::value, T>::type
compute_pow(const T a, const T b) {
// TODO(wujionghao): A potential speed improvement is supporting different
// types in C++.
// On CUDAPlace, std::pow(3, 1) calls pow(float, float), and
// it will return a float number like 2.99... , which floor to 2
// when cast to int by default and it is wrong.
// Use llrint to cast it to the nearest integer, which is 3.
return std::llrint(std::pow(static_cast<double>(a), static_cast<double>(b)));
}
template <typename T, typename MPType>
inline HOSTDEVICE typename std::enable_if<!std::is_integral<T>::value, T>::type
compute_pow(const T a, const T b) {
MPType a_val = static_cast<MPType>(a);
MPType b_val = static_cast<MPType>(b);
#ifdef PADDLE_WITH_XPU_KP
return static_cast<T>(pow(a_val, b_val));
#endif
return static_cast<T>(std::pow(a_val, b_val));
}
#else
template <typename T, typename MPType>
inline HOSTDEVICE T compute_pow(const T a, const T b) {
MPType a_val = static_cast<MPType>(a);
MPType b_val = static_cast<MPType>(b);
#ifdef PADDLE_WITH_XPU_KP
return pow(a, b);
return static_cast<T>(pow(a_val, b_val));
#endif
return static_cast<T>(std::pow(a_val, b_val));
}
#endif
return std::pow(a, b);
}
};
template <typename T>
struct ElementwiseInversePowFunctor {
struct ElementwisePowFunctor {
using MPType = typename phi::dtype::MPTypeTrait<T>::Type;
inline HOSTDEVICE T operator()(const T a, const T b) const {
// TODO(wujionghao): A potential speed improvement is supporting different
// types in C++.
#if defined(__CUDA_ARCH__) || defined(__HIPCC__)
// On CUDAPlace, std::pow(3, 1) calls pow(float, float), and
// it will return a float number like 2.99... , which floor to 2
// when cast to int by default and it is wrong.
// Use llrint to cast it to the nearest integer, which is 3.
if (std::is_integral<T>::value) {
return std::llrint(
std::pow(static_cast<double>(b), static_cast<double>(a)));
}
#endif
#ifdef PADDLE_WITH_XPU_KP
return pow(b, a);
#endif
return std::pow(b, a);
return compute_pow<T, MPType>(a, b);
}
};
template <>
struct ElementwisePowFunctor<dtype::float16> {
inline HOSTDEVICE dtype::float16 operator()(const dtype::float16 a,
const dtype::float16 b) const {
float f_a = static_cast<float>(a);
float f_b = static_cast<float>(b);
return static_cast<dtype::float16>(std::pow(f_a, f_b));
template <typename T>
struct ElementwiseInversePowFunctor {
using MPType = typename phi::dtype::MPTypeTrait<T>::Type;
inline HOSTDEVICE T operator()(const T a, const T b) const {
return compute_pow<T, MPType>(b, a);
}
};
template <>
struct ElementwiseInversePowFunctor<dtype::float16> {
inline HOSTDEVICE dtype::float16 operator()(const dtype::float16 a,
const dtype::float16 b) const {
float f_a = static_cast<float>(a);
float f_b = static_cast<float>(b);
return static_cast<dtype::float16>(std::pow(f_b, f_a));
}
};
} // namespace funcs
} // namespace phi
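The hunk above replaces the old float16-only specializations of `ElementwisePowFunctor` / `ElementwiseInversePowFunctor` with a single `compute_pow<T, MPType>` helper: both operands are upcast to the higher-precision type chosen by `phi::dtype::MPTypeTrait<T>` (float for float16 and bfloat16), the power is evaluated there, and the result is cast back to `T`. Integral types still go through `std::llrint` so that, for example, `pow(3, 1)` evaluated as `pow(float, float)` does not truncate to 2 on CUDA. A minimal NumPy sketch of the same idea follows, using float16 as a stand-in for bfloat16 since stock NumPy has no bfloat16 dtype.

```python
import numpy as np

def compute_pow_low_precision(a, b, mp_dtype=np.float32):
    """Sketch of compute_pow<T, MPType>: upcast, pow in MPType, cast back.

    float16 stands in for bfloat16 here because stock NumPy has no
    bfloat16 dtype; the MPType for both is float32.
    """
    a_val = a.astype(mp_dtype)
    b_val = b.astype(mp_dtype)
    return np.power(a_val, b_val).astype(a.dtype)

a = np.array([1.5, 2.0, 3.0], dtype=np.float16)
b = np.array([2.0, 0.5, 3.0], dtype=np.float16)
print(compute_pow_low_precision(a, b))  # pow evaluated in float32, stored as float16
```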
@@ -150,4 +150,5 @@ PD_REGISTER_KERNEL(elementwise_pow_grad,
double,
int,
phi::dtype::float16,
phi::dtype::bfloat16,
int64_t) {}
@@ -14,6 +14,7 @@ limitations under the License. */
#pragma once
#include "paddle/phi/common/amp_type_traits.h"
#include "paddle/phi/common/complex.h"
#include "paddle/phi/common/float16.h"
#include "paddle/phi/core/dense_tensor.h"
@@ -851,58 +852,65 @@ void HeavisideGradKernel(const Context& dev_ctx,
HeavisideGradDy<T>());
}
template <typename T>
struct PowGradDX {
HOSTDEVICE T operator()(T x, T y, T out, T dout) const {
#if defined(__CUDA_ARCH__) || defined(__HIPCC__)
if (std::is_integral<T>::value) {
return dout * y *
std::pow(static_cast<double>(x), static_cast<double>(y - 1));
}
template <typename T, typename MPType>
HOSTDEVICE typename std::enable_if<std::is_integral<T>::value, T>::type
compute_pow_grad_dx(T x, T y, T out, T dout) {
return dout * y *
std::pow(static_cast<double>(x), static_cast<double>(y - 1));
}
template <typename T, typename MPType>
HOSTDEVICE typename std::enable_if<!std::is_integral<T>::value, T>::type
compute_pow_grad_dx(T x, T y, T out, T dout) {
MPType x_val = static_cast<MPType>(x);
MPType y_val = static_cast<MPType>(y);
return static_cast<T>(static_cast<MPType>(dout) * y_val *
std::pow(x_val, y_val - 1));
}
template <typename T, typename MPType>
HOSTDEVICE typename std::enable_if<std::is_integral<T>::value, T>::type
compute_pow_grad_dy(T x, T y, T out, T dout) {
return dout * std::log(static_cast<double>(x)) *
std::pow(static_cast<double>(x), static_cast<double>(y));
}
template <typename T, typename MPType>
HOSTDEVICE typename std::enable_if<!std::is_integral<T>::value, T>::type
compute_pow_grad_dy(T x, T y, T out, T dout) {
MPType x_val = static_cast<MPType>(x);
MPType y_val = static_cast<MPType>(y);
return static_cast<T>(static_cast<MPType>(dout) * std::log(x_val) *
std::pow(x_val, y_val));
}
#else
template <typename T, typename MPType>
HOSTDEVICE T compute_pow_grad_dx(T x, T y, T out, T dout) {
MPType x_val = static_cast<MPType>(x);
MPType y_val = static_cast<MPType>(y);
return static_cast<T>(static_cast<MPType>(dout) * y_val *
std::pow(x_val, y_val - 1));
}
template <typename T, typename MPType>
HOSTDEVICE T compute_pow_grad_dy(T x, T y, T out, T dout) {
MPType x_val = static_cast<MPType>(x);
MPType y_val = static_cast<MPType>(y);
return static_cast<T>(static_cast<MPType>(dout) * std::log(x_val) *
std::pow(x_val, y_val));
}
#endif
return dout * y * std::pow(x, y - 1);
}
};
template <>
struct PowGradDX<dtype::float16> {
HOSTDEVICE dtype::float16 operator()(dtype::float16 x,
dtype::float16 y,
dtype::float16 out,
dtype::float16 dout) const {
float tmp_y = static_cast<float>(y);
float tmp_dout = static_cast<float>(dout);
float tmp_x = static_cast<float>(x);
float result = tmp_dout * tmp_y * std::pow(tmp_x, tmp_y - 1.0f);
return static_cast<dtype::float16>(result);
template <typename T>
struct PowGradDX {
using MPType = typename phi::dtype::MPTypeTrait<T>::Type;
HOSTDEVICE T operator()(T x, T y, T out, T dout) const {
return compute_pow_grad_dx<T, MPType>(x, y, out, dout);
}
};
template <typename T, typename Enable = void>
struct PowGradDY {
using MPType = typename phi::dtype::MPTypeTrait<T>::Type;
HOSTDEVICE T operator()(T x, T y, T out, T dout) const {
#if defined(__CUDA_ARCH__) || defined(__HIPCC__)
if (std::is_integral<T>::value) {
return dout * std::log(static_cast<double>(x)) *
std::pow(static_cast<double>(x), static_cast<double>(y));
}
#endif
return dout * std::log(x) * std::pow(x, y);
}
};
template <>
struct PowGradDY<dtype::float16, void> {
HOSTDEVICE dtype::float16 operator()(dtype::float16 x,
dtype::float16 y,
dtype::float16 out,
dtype::float16 dout) const {
float tmp_y = static_cast<float>(y);
float tmp_dout = static_cast<float>(dout);
float tmp_x = static_cast<float>(x);
float tmp_pow = std::pow(tmp_x, tmp_y);
float result = tmp_pow * tmp_dout * std::log(tmp_x);
return static_cast<dtype::float16>(result);
return compute_pow_grad_dy<T, MPType>(x, y, out, dout);
}
};
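Likewise, `PowGradDX` / `PowGradDY` now delegate to `compute_pow_grad_dx` / `compute_pow_grad_dy`, which evaluate dx = dout * y * x^(y-1) and dy = dout * log(x) * x^y in the MPType before casting back, instead of relying on float16-only specializations. A rough NumPy sketch of those two formulas (again with float16 standing in for bfloat16):

```python
import numpy as np

def pow_grad_low_precision(x, y, dout, mp_dtype=np.float32):
    """Sketch of compute_pow_grad_dx / compute_pow_grad_dy:
    dx = dout * y * x**(y - 1)
    dy = dout * log(x) * x**y
    evaluated in the higher-precision type, then cast back."""
    xv, yv, dv = x.astype(mp_dtype), y.astype(mp_dtype), dout.astype(mp_dtype)
    dx = dv * yv * np.power(xv, yv - 1)
    dy = dv * np.log(xv) * np.power(xv, yv)
    return dx.astype(x.dtype), dy.astype(x.dtype)

x = np.array([0.5, 1.0, 2.0], dtype=np.float16)
y = np.array([2.0, 3.0, 0.5], dtype=np.float16)
dout = np.ones_like(x)
print(pow_grad_low_precision(x, y, dout))
```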
......
@@ -181,5 +181,6 @@ PD_REGISTER_KERNEL(elementwise_pow_raw,
double,
int,
float16,
bfloat16,
int64_t) {}
#endif
@@ -15,7 +15,7 @@
import unittest
import numpy as np
from op_test import OpTest, skip_check_grad_ci
from op_test import OpTest, convert_float_to_uint16, skip_check_grad_ci
import paddle
import paddle.fluid as fluid
@@ -308,6 +308,7 @@ class TestElementwisePowGradOpInt(unittest.TestCase):
class TestElementwisePowOpFP16(OpTest):
    def setUp(self):
        self.op_type = "elementwise_pow"
        self.dtype = np.float16
        self.python_api = paddle.pow
        self.public_python_api = paddle.pow
        self.prim_op_type = "prim"
@@ -336,5 +337,30 @@ class TestElementwisePowOpFP16(OpTest):
)
class TestElementwisePowBF16Op(OpTest):
    def setUp(self):
        self.op_type = "elementwise_pow"
        self.dtype = np.uint16
        self.python_api = paddle.pow
        x = np.random.uniform(0, 1, [20, 5]).astype(np.float32)
        y = np.random.uniform(0, 1, [20, 5]).astype(np.float32)
        out = np.power(x, y)
        self.inputs = {
            'X': convert_float_to_uint16(x),
            'Y': convert_float_to_uint16(y),
        }
        self.outputs = {'Out': convert_float_to_uint16(out)}

    def test_check_output(self):
        if hasattr(self, 'attrs'):
            self.check_output(check_eager=False)
        else:
            self.check_output(check_eager=True)

    def test_check_grad(self):
        self.check_grad(['X', 'Y'], 'Out', check_eager=True)


if __name__ == '__main__':
    unittest.main()
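The new `TestElementwisePowBF16Op` feeds OpTest uint16 arrays that hold the bfloat16 bit pattern, produced by `convert_float_to_uint16`. Below is a rough, hypothetical sketch of that kind of conversion: it simply keeps the upper 16 bits of each float32, whereas the real helper in `op_test` may additionally round to nearest-even.

```python
import numpy as np

def float32_to_bfloat16_bits(x):
    # Keep the upper 16 bits of each float32: the bfloat16 bit pattern,
    # stored as uint16 (truncation only; no rounding in this sketch).
    return (np.asarray(x, dtype=np.float32).view(np.uint32) >> 16).astype(np.uint16)

def bfloat16_bits_to_float32(u):
    # Inverse mapping: widen the uint16 pattern back to a float32 value.
    return (u.astype(np.uint32) << 16).view(np.float32)

x = np.random.uniform(0, 1, [20, 5]).astype(np.float32)
x_bf16 = float32_to_bfloat16_bits(x)
print(np.max(np.abs(x - bfloat16_bits_to_float32(x_bf16))))  # small truncation error
```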