Unverified commit 288ad844, authored by Lin Manhui, committed by GitHub

[AMP] Add bfloat16 Support for `elementwise_pow` Op (#51888)

* Add bf16 support for elementwise_pow

* Update ut
Parent 4bf1c163
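A minimal usage-level sketch of what the registrations in this commit enable, assuming a CUDA build of Paddle in which the bfloat16 `elementwise_pow` kernel is available (shapes, values, and the dtype-cast style are illustrative only, not part of this change):

```python
import paddle

# Assumes a device/build where the bfloat16 elementwise_pow kernel
# registered in this commit is available.
x = paddle.rand([20, 5]).astype("bfloat16")
y = paddle.rand([20, 5]).astype("bfloat16")

# paddle.pow with a Tensor exponent dispatches to elementwise_pow.
out = paddle.pow(x, y)
print(out.dtype)  # expected: a bfloat16 dtype
```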
@@ -105,4 +105,5 @@ PD_REGISTER_KERNEL(elementwise_pow_grad,
                    float,
                    double,
                    int,
-                   int64_t) {}
+                   int64_t,
+                   phi::dtype::bfloat16) {}
@@ -166,7 +166,8 @@ PD_REGISTER_KERNEL(elementwise_pow_raw,
                    float,
                    double,
                    int,
-                   int64_t) {}
+                   int64_t,
+                   phi::dtype::bfloat16) {}
 
 PD_REGISTER_KERNEL(heaviside,
                    CPU,
@@ -140,7 +140,8 @@ PD_REGISTER_KERNEL(elementwise_pow,
                    float,
                    double,
                    int,
-                   int64_t) {}
+                   int64_t,
+                   phi::dtype::bfloat16) {}
 
 PD_REGISTER_KERNEL(subtract,
                    CPU,
@@ -232,7 +233,8 @@ PD_REGISTER_KERNEL(elementwise_pow,
                    double,
                    int,
                    int64_t,
-                   phi::dtype::float16) {}
+                   phi::dtype::float16,
+                   phi::dtype::bfloat16) {}
 
 #endif
@@ -23,6 +23,7 @@ limitations under the License. */
 #include "xpu/kernel/math_xpu2.h"  // pow()
 #endif
+#include "paddle/phi/common/amp_type_traits.h"
 
 namespace phi {
 namespace funcs {
@@ -585,68 +586,55 @@ struct InverseFloorDivideFunctor {
   }
 };
 
-template <typename T>
-struct ElementwisePowFunctor {
-  inline HOSTDEVICE T operator()(const T a, const T b) const {
-// TODO(wujionghao): A potential speed improvement is supporting different
-// types in C++.
-#if defined(__CUDA_ARCH__) || defined(__HIPCC__)
-    // On CUDAPlace, std::pow(3, 1) calls pow(float, float), and
-    // it will return a float number like 2.99... , which floor to 2
-    // when cast to int by default and it is wrong.
-    // Use llrint to cast it to the nearest integer, which is 3.
-    if (std::is_integral<T>::value) {
-      return std::llrint(
-          std::pow(static_cast<double>(a), static_cast<double>(b)));
-    }
-#endif
-#ifdef PADDLE_WITH_XPU_KP
-    return pow(a, b);
-#endif
-    return std::pow(a, b);
-  }
-};
-
-template <typename T>
-struct ElementwiseInversePowFunctor {
-  inline HOSTDEVICE T operator()(const T a, const T b) const {
-    // TODO(wujionghao): A potential speed improvement is supporting different
-    // types in C++.
-#if defined(__CUDA_ARCH__) || defined(__HIPCC__)
-    // On CUDAPlace, std::pow(3, 1) calls pow(float, float), and
-    // it will return a float number like 2.99... , which floor to 2
-    // when cast to int by default and it is wrong.
-    // Use llrint to cast it to the nearest integer, which is 3.
-    if (std::is_integral<T>::value) {
-      return std::llrint(
-          std::pow(static_cast<double>(b), static_cast<double>(a)));
-    }
-#endif
-#ifdef PADDLE_WITH_XPU_KP
-    return pow(b, a);
-#endif
-    return std::pow(b, a);
-  }
-};
-
-template <>
-struct ElementwisePowFunctor<dtype::float16> {
-  inline HOSTDEVICE dtype::float16 operator()(const dtype::float16 a,
-                                              const dtype::float16 b) const {
-    float f_a = static_cast<float>(a);
-    float f_b = static_cast<float>(b);
-    return static_cast<dtype::float16>(std::pow(f_a, f_b));
-  }
-};
-
-template <>
-struct ElementwiseInversePowFunctor<dtype::float16> {
-  inline HOSTDEVICE dtype::float16 operator()(const dtype::float16 a,
-                                              const dtype::float16 b) const {
-    float f_a = static_cast<float>(a);
-    float f_b = static_cast<float>(b);
-    return static_cast<dtype::float16>(std::pow(f_b, f_a));
-  }
-};
+#if defined(__CUDA_ARCH__) || defined(__HIPCC__)
+template <typename T, typename MPType>
+inline HOSTDEVICE typename std::enable_if<std::is_integral<T>::value, T>::type
+compute_pow(const T a, const T b) {
+  // TODO(wujionghao): A potential speed improvement is supporting different
+  // types in C++.
+  // On CUDAPlace, std::pow(3, 1) calls pow(float, float), and
+  // it will return a float number like 2.99... , which floor to 2
+  // when cast to int by default and it is wrong.
+  // Use llrint to cast it to the nearest integer, which is 3.
+  return std::llrint(std::pow(static_cast<double>(a), static_cast<double>(b)));
+}
+template <typename T, typename MPType>
+inline HOSTDEVICE typename std::enable_if<!std::is_integral<T>::value, T>::type
+compute_pow(const T a, const T b) {
+  MPType a_val = static_cast<MPType>(a);
+  MPType b_val = static_cast<MPType>(b);
+#ifdef PADDLE_WITH_XPU_KP
+  return static_cast<T>(pow(a_val, b_val));
+#endif
+  return static_cast<T>(std::pow(a_val, b_val));
+}
+#else
+template <typename T, typename MPType>
+inline HOSTDEVICE T compute_pow(const T a, const T b) {
+  MPType a_val = static_cast<MPType>(a);
+  MPType b_val = static_cast<MPType>(b);
+#ifdef PADDLE_WITH_XPU_KP
+  return static_cast<T>(pow(a_val, b_val));
+#endif
+  return static_cast<T>(std::pow(a_val, b_val));
+}
+#endif
+
+template <typename T>
+struct ElementwisePowFunctor {
+  using MPType = typename phi::dtype::MPTypeTrait<T>::Type;
+  inline HOSTDEVICE T operator()(const T a, const T b) const {
+    return compute_pow<T, MPType>(a, b);
+  }
+};
+
+template <typename T>
+struct ElementwiseInversePowFunctor {
+  using MPType = typename phi::dtype::MPTypeTrait<T>::Type;
+  inline HOSTDEVICE T operator()(const T a, const T b) const {
+    return compute_pow<T, MPType>(b, a);
+  }
+};
 
 }  // namespace funcs
 }  // namespace phi
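The refactored functors above route each element type through `MPTypeTrait<T>::Type` (float for float16/bfloat16), so `std::pow` itself runs at higher precision and only the result is rounded back. A rough numpy sketch of that pattern, simulating bfloat16 by keeping the upper 16 bits of a float32; the helper names are local to this sketch, not Paddle APIs:

```python
import numpy as np

def to_bf16_bits(a):
    # Keep the high 16 bits of each float32 value (truncation, no rounding).
    return (np.asarray(a, dtype=np.float32).view(np.uint32) >> 16).astype(np.uint16)

def from_bf16_bits(bits):
    return (bits.astype(np.uint32) << 16).view(np.float32)

x = np.random.uniform(0, 1, [20, 5]).astype(np.float32)
y = np.random.uniform(0, 1, [20, 5]).astype(np.float32)

# Inputs as a bfloat16 kernel would see them.
xb = from_bf16_bits(to_bf16_bits(x))
yb = from_bf16_bits(to_bf16_bits(y))

# Promote to float32 for the pow itself, then round the result back to
# bfloat16 -- the same shape of computation as compute_pow<T, MPType>.
out = from_bf16_bits(to_bf16_bits(np.power(xb, yb)))

print(np.max(np.abs(out - np.power(x, y))))  # small, on the order of bf16 precision
```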
@@ -150,4 +150,5 @@ PD_REGISTER_KERNEL(elementwise_pow_grad,
                    double,
                    int,
                    phi::dtype::float16,
+                   phi::dtype::bfloat16,
                    int64_t) {}
@@ -14,6 +14,7 @@ limitations under the License. */
 #pragma once
 
+#include "paddle/phi/common/amp_type_traits.h"
 #include "paddle/phi/common/complex.h"
 #include "paddle/phi/common/float16.h"
 #include "paddle/phi/core/dense_tensor.h"
@@ -851,58 +852,65 @@ void HeavisideGradKernel(const Context& dev_ctx,
                          HeavisideGradDy<T>());
 }
 
-template <typename T>
-struct PowGradDX {
-  HOSTDEVICE T operator()(T x, T y, T out, T dout) const {
-#if defined(__CUDA_ARCH__) || defined(__HIPCC__)
-    if (std::is_integral<T>::value) {
-      return dout * y *
-             std::pow(static_cast<double>(x), static_cast<double>(y - 1));
-    }
-#endif
-    return dout * y * std::pow(x, y - 1);
-  }
-};
-
-template <>
-struct PowGradDX<dtype::float16> {
-  HOSTDEVICE dtype::float16 operator()(dtype::float16 x,
-                                       dtype::float16 y,
-                                       dtype::float16 out,
-                                       dtype::float16 dout) const {
-    float tmp_y = static_cast<float>(y);
-    float tmp_dout = static_cast<float>(dout);
-    float tmp_x = static_cast<float>(x);
-    float result = tmp_dout * tmp_y * std::pow(tmp_x, tmp_y - 1.0f);
-    return static_cast<dtype::float16>(result);
-  }
-};
-
-template <typename T, typename Enable = void>
-struct PowGradDY {
-  HOSTDEVICE T operator()(T x, T y, T out, T dout) const {
-#if defined(__CUDA_ARCH__) || defined(__HIPCC__)
-    if (std::is_integral<T>::value) {
-      return dout * std::log(static_cast<double>(x)) *
-             std::pow(static_cast<double>(x), static_cast<double>(y));
-    }
-#endif
-    return dout * std::log(x) * std::pow(x, y);
-  }
-};
-
-template <>
-struct PowGradDY<dtype::float16, void> {
-  HOSTDEVICE dtype::float16 operator()(dtype::float16 x,
-                                       dtype::float16 y,
-                                       dtype::float16 out,
-                                       dtype::float16 dout) const {
-    float tmp_y = static_cast<float>(y);
-    float tmp_dout = static_cast<float>(dout);
-    float tmp_x = static_cast<float>(x);
-    float tmp_pow = std::pow(tmp_x, tmp_y);
-    float result = tmp_pow * tmp_dout * std::log(tmp_x);
-    return static_cast<dtype::float16>(result);
-  }
-};
+#if defined(__CUDA_ARCH__) || defined(__HIPCC__)
+template <typename T, typename MPType>
+HOSTDEVICE typename std::enable_if<std::is_integral<T>::value, T>::type
+compute_pow_grad_dx(T x, T y, T out, T dout) {
+  return dout * y *
+         std::pow(static_cast<double>(x), static_cast<double>(y - 1));
+}
+template <typename T, typename MPType>
+HOSTDEVICE typename std::enable_if<!std::is_integral<T>::value, T>::type
+compute_pow_grad_dx(T x, T y, T out, T dout) {
+  MPType x_val = static_cast<MPType>(x);
+  MPType y_val = static_cast<MPType>(y);
+  return static_cast<T>(static_cast<MPType>(dout) * y_val *
+                        std::pow(x_val, y_val - 1));
+}
+template <typename T, typename MPType>
+HOSTDEVICE typename std::enable_if<std::is_integral<T>::value, T>::type
+compute_pow_grad_dy(T x, T y, T out, T dout) {
+  return dout * std::log(static_cast<double>(x)) *
+         std::pow(static_cast<double>(x), static_cast<double>(y));
+}
+template <typename T, typename MPType>
+HOSTDEVICE typename std::enable_if<!std::is_integral<T>::value, T>::type
+compute_pow_grad_dy(T x, T y, T out, T dout) {
+  MPType x_val = static_cast<MPType>(x);
+  MPType y_val = static_cast<MPType>(y);
+  return static_cast<T>(static_cast<MPType>(dout) * std::log(x_val) *
+                        std::pow(x_val, y_val));
+}
+#else
+template <typename T, typename MPType>
+HOSTDEVICE T compute_pow_grad_dx(T x, T y, T out, T dout) {
+  MPType x_val = static_cast<MPType>(x);
+  MPType y_val = static_cast<MPType>(y);
+  return static_cast<T>(static_cast<MPType>(dout) * y_val *
+                        std::pow(x_val, y_val - 1));
+}
+template <typename T, typename MPType>
+HOSTDEVICE T compute_pow_grad_dy(T x, T y, T out, T dout) {
+  MPType x_val = static_cast<MPType>(x);
+  MPType y_val = static_cast<MPType>(y);
+  return static_cast<T>(static_cast<MPType>(dout) * std::log(x_val) *
+                        std::pow(x_val, y_val));
+}
+#endif
+
+template <typename T>
+struct PowGradDX {
+  using MPType = typename phi::dtype::MPTypeTrait<T>::Type;
+  HOSTDEVICE T operator()(T x, T y, T out, T dout) const {
+    return compute_pow_grad_dx<T, MPType>(x, y, out, dout);
+  }
+};
+
+template <typename T, typename Enable = void>
+struct PowGradDY {
+  using MPType = typename phi::dtype::MPTypeTrait<T>::Type;
+  HOSTDEVICE T operator()(T x, T y, T out, T dout) const {
+    return compute_pow_grad_dy<T, MPType>(x, y, out, dout);
+  }
+};
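For reference, `compute_pow_grad_dx` and `compute_pow_grad_dy` implement the analytic derivatives d(x^y)/dx = y * x^(y-1) and d(x^y)/dy = log(x) * x^y, each scaled by `dout`. A quick float64 finite-difference check of those formulas (plain numpy, values chosen arbitrarily):

```python
import numpy as np

x, y, eps = 1.7, 2.3, 1e-6

dx_analytic = y * x ** (y - 1)    # what compute_pow_grad_dx accumulates (dout == 1)
dy_analytic = np.log(x) * x ** y  # what compute_pow_grad_dy accumulates (dout == 1)

# Central finite differences.
dx_numeric = ((x + eps) ** y - (x - eps) ** y) / (2 * eps)
dy_numeric = (x ** (y + eps) - x ** (y - eps)) / (2 * eps)

assert abs(dx_analytic - dx_numeric) < 1e-4
assert abs(dy_analytic - dy_numeric) < 1e-4
```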
@@ -181,5 +181,6 @@ PD_REGISTER_KERNEL(elementwise_pow_raw,
                    double,
                    int,
                    float16,
+                   bfloat16,
                    int64_t) {}
 #endif
@@ -15,7 +15,7 @@
 import unittest
 
 import numpy as np
-from op_test import OpTest, skip_check_grad_ci
+from op_test import OpTest, convert_float_to_uint16, skip_check_grad_ci
 
 import paddle
 import paddle.fluid as fluid
@@ -308,6 +308,7 @@ class TestElementwisePowGradOpInt(unittest.TestCase):
 class TestElementwisePowOpFP16(OpTest):
     def setUp(self):
         self.op_type = "elementwise_pow"
+        self.dtype = np.float16
         self.python_api = paddle.pow
         self.public_python_api = paddle.pow
         self.prim_op_type = "prim"
@@ -336,5 +337,30 @@ class TestElementwisePowOpFP16(OpTest):
         )
 
 
+class TestElementwisePowBF16Op(OpTest):
+    def setUp(self):
+        self.op_type = "elementwise_pow"
+        self.dtype = np.uint16
+        self.python_api = paddle.pow
+
+        x = np.random.uniform(0, 1, [20, 5]).astype(np.float32)
+        y = np.random.uniform(0, 1, [20, 5]).astype(np.float32)
+        out = np.power(x, y)
+        self.inputs = {
+            'X': convert_float_to_uint16(x),
+            'Y': convert_float_to_uint16(y),
+        }
+        self.outputs = {'Out': convert_float_to_uint16(out)}
+
+    def test_check_output(self):
+        if hasattr(self, 'attrs'):
+            self.check_output(check_eager=False)
+        else:
+            self.check_output(check_eager=True)
+
+    def test_check_grad(self):
+        self.check_grad(['X', 'Y'], 'Out', check_eager=True)
+
+
 if __name__ == '__main__':
     unittest.main()