未验证 提交 f6d9ec27 编写于 作者: X xiaohemaikoo 提交者: GitHub

elementwise op support fp16 (#45496)

上级 72b5b5bf
......@@ -237,7 +237,8 @@ PD_REGISTER_KERNEL(remainder,
float,
double,
int,
int64_t) {}
int64_t,
phi::dtype::float16) {}
PD_REGISTER_KERNEL(
floor_divide, KPS, ALL_LAYOUT, phi::FloorDivideKernel, int, int64_t) {}
PD_REGISTER_KERNEL(elementwise_heaviside,
......@@ -247,7 +248,8 @@ PD_REGISTER_KERNEL(elementwise_heaviside,
float,
double,
int,
int64_t) {}
int64_t,
phi::dtype::float16) {}
PD_REGISTER_KERNEL(elementwise_pow,
KPS,
ALL_LAYOUT,
......@@ -255,7 +257,8 @@ PD_REGISTER_KERNEL(elementwise_pow,
float,
double,
int,
int64_t) {}
int64_t,
phi::dtype::float16) {}
#endif
......
......@@ -524,6 +524,19 @@ struct RemainderFunctor<
}
};
template <>
struct RemainderFunctor<dtype::float16> {
inline HOSTDEVICE dtype::float16 operator()(const dtype::float16 a,
const dtype::float16 b) const {
float b_float = static_cast<float>(b);
float res = fmod(static_cast<float>(a), b_float);
// Accoding to #PR26732: in dividen % divsor
// remainder shall have the same sign as divsor.
if ((res != 0.0f) && ((res < 0.0f) != (b_float < 0.0f))) res += b_float;
return static_cast<dtype::float16>(res);
}
};
template <typename T, typename Enable = void>
struct InverseRemainderFunctor {
inline HOSTDEVICE T operator()(const T a, const T b) const {
......@@ -547,7 +560,7 @@ struct InverseRemainderFunctor<
template <typename T>
struct ElementwiseHeavisideFunctor {
inline HOSTDEVICE T operator()(const T a, const T b) const {
return a == static_cast<T>(0) ? b : static_cast<T>(a > 0);
return a == static_cast<T>(0) ? b : static_cast<T>(a > static_cast<T>(0));
}
};
......@@ -592,5 +605,16 @@ struct ElementwisePowFunctor {
return std::pow(a, b);
}
};
template <>
struct ElementwisePowFunctor<dtype::float16> {
inline HOSTDEVICE dtype::float16 operator()(const dtype::float16 a,
const dtype::float16 b) const {
float f_a = static_cast<float>(a);
float f_b = static_cast<float>(b);
return static_cast<dtype::float16>(std::pow(f_a, f_b));
}
};
} // namespace funcs
} // namespace phi
......@@ -35,6 +35,7 @@ void MaximumGradKernel(const Context& dev_ctx,
DenseTensor* dx,
DenseTensor* dy) {
const auto place = dev_ctx.GetPlace();
if (dx != nullptr && dy != nullptr) {
std::vector<const DenseTensor*> ins = {&x, &y, &dout};
GetGradXAndYOut<ElementwiseType::kTernary, T>(
......@@ -96,6 +97,7 @@ PD_REGISTER_KERNEL(fmax_grad,
float,
double,
int,
phi::dtype::float16,
int64_t) {}
PD_REGISTER_KERNEL(fmin_grad,
......@@ -105,6 +107,7 @@ PD_REGISTER_KERNEL(fmin_grad,
float,
double,
int,
phi::dtype::float16,
int64_t) {}
PD_REGISTER_KERNEL(maximum_grad,
......@@ -136,6 +139,7 @@ PD_REGISTER_KERNEL(elementwise_heaviside_grad,
float,
double,
int,
phi::dtype::float16,
int64_t) {}
PD_REGISTER_KERNEL(elementwise_pow_grad,
......@@ -145,4 +149,5 @@ PD_REGISTER_KERNEL(elementwise_pow_grad,
float,
double,
int,
phi::dtype::float16,
int64_t) {}
......@@ -15,6 +15,7 @@ limitations under the License. */
#pragma once
#include "paddle/phi/common/complex.h"
#include "paddle/phi/common/float16.h"
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/tensor_utils.h"
#include "paddle/phi/kernels/funcs/broadcast_function.h"
......@@ -753,6 +754,20 @@ struct PowGradDX {
}
};
template <>
struct PowGradDX<dtype::float16> {
HOSTDEVICE dtype::float16 operator()(dtype::float16 x,
dtype::float16 y,
dtype::float16 out,
dtype::float16 dout) const {
float tmp_y = static_cast<float>(y);
float tmp_dout = static_cast<float>(dout);
float tmp_x = static_cast<float>(x);
float result = tmp_dout * tmp_y * std::pow(tmp_x, tmp_y - 1.0f);
return static_cast<dtype::float16>(result);
}
};
template <typename T, typename Enable = void>
struct PowGradDY {
HOSTDEVICE T operator()(T x, T y, T out, T dout) const {
......@@ -766,6 +781,21 @@ struct PowGradDY {
}
};
template <>
struct PowGradDY<dtype::float16, void> {
HOSTDEVICE dtype::float16 operator()(dtype::float16 x,
dtype::float16 y,
dtype::float16 out,
dtype::float16 dout) const {
float tmp_y = static_cast<float>(y);
float tmp_dout = static_cast<float>(dout);
float tmp_x = static_cast<float>(x);
float tmp_pow = std::pow(tmp_x, tmp_y);
float result = tmp_pow * tmp_dout * std::log(tmp_x);
return static_cast<dtype::float16>(result);
}
};
template <typename T, typename Context>
void ElementwisePowGradKernel(const Context& dev_ctx,
const DenseTensor& x,
......
......@@ -32,6 +32,7 @@ void MaximumKernel(const Context& dev_ctx,
int axis = -1;
MaximumRawKernel<T>(dev_ctx, x, y, axis, out);
}
// Create the definition of Minimum
DEFINE_CUDA_ELEMENTWISE_OP(Minimum)
template <typename T, typename Context>
......@@ -92,11 +93,25 @@ using bfloat16 = phi::dtype::bfloat16;
using complex64 = ::phi::dtype::complex<float>;
using complex128 = ::phi::dtype::complex<double>;
PD_REGISTER_KERNEL(
fmax, KPS, ALL_LAYOUT, phi::FMaxKernel, float, double, int, int64_t) {}
PD_REGISTER_KERNEL(fmax,
KPS,
ALL_LAYOUT,
phi::FMaxKernel,
float,
double,
int,
float16,
int64_t) {}
PD_REGISTER_KERNEL(
fmin, KPS, ALL_LAYOUT, phi::FMinKernel, float, double, int, int64_t) {}
PD_REGISTER_KERNEL(fmin,
KPS,
ALL_LAYOUT,
phi::FMinKernel,
float,
double,
int,
float16,
int64_t) {}
PD_REGISTER_KERNEL(maximum_raw,
KPS,
......@@ -125,6 +140,7 @@ PD_REGISTER_KERNEL(remainder_raw,
float,
double,
int,
float16,
int64_t) {}
PD_REGISTER_KERNEL(floor_divide_raw,
KPS,
......@@ -139,6 +155,7 @@ PD_REGISTER_KERNEL(elementwise_heaviside_raw,
float,
double,
int,
float16,
int64_t) {}
PD_REGISTER_KERNEL(elementwise_pow_raw,
KPS,
......@@ -147,5 +164,6 @@ PD_REGISTER_KERNEL(elementwise_pow_raw,
float,
double,
int,
float16,
int64_t) {}
#endif
......@@ -18,6 +18,13 @@ from op_test import OpTest
import paddle
def Heaviside_grad(x, y, dout):
tmp = np.zeros(x.shape).astype("float16")
dx = np.multiply(tmp, dout)
dy = np.multiply(np.equal(x, 0), dout).astype("float16")
return dx, dy
class TestElementwiseOp(OpTest):
def setUp(self):
......@@ -152,6 +159,30 @@ class TestHeavisideAPI_int32(TestHeavisideAPI_float64):
self.dtype = "int32"
class TestHeavisideAPI_float16(OpTest):
def setUp(self):
self.dtype = np.float16
self.op_type = "elementwise_heaviside"
self.python_api = paddle.heaviside
self.inputs = {
'X': np.random.uniform(1, 2, [20, 5]).astype("float16"),
'Y': np.random.uniform(1, 2, [20, 5]).astype("float16")
}
self.outputs = {'Out': np.heaviside(self.inputs['X'], self.inputs['Y'])}
def test_check_output(self):
self.check_output()
def test_check_grad(self):
self.check_grad(['X', 'Y'],
'Out',
user_defined_grads=Heaviside_grad(
self.inputs['X'], self.inputs['Y'],
1 / self.inputs['X'].size),
check_eager=True)
class TestHeavisideError(unittest.TestCase):
def test_input(self):
......
......@@ -89,6 +89,23 @@ class TestElementwiseModOpFloat(TestElementwiseModOp):
self.check_output(check_eager=False)
class TestElementwiseModOpFp16(TestElementwiseModOp):
def init_dtype(self):
self.dtype = np.float16
def init_input_output(self):
self.x = np.random.uniform(-1000, 1000, [10, 10]).astype(self.dtype)
self.y = np.random.uniform(-100, 100, [10, 10]).astype(self.dtype)
self.out = np.mod(self.x, self.y)
def test_check_output(self):
if self.attrs['axis'] == -1:
self.check_output(check_eager=True)
else:
self.check_output(check_eager=False)
class TestElementwiseModOpDouble(TestElementwiseModOpFloat):
def init_dtype(self):
......
......@@ -20,6 +20,12 @@ import paddle.fluid as fluid
import paddle
def pow_grad(x, y, dout):
dx = dout * y * np.power(x, (y - 1))
dy = dout * np.log(x) * np.power(x, y)
return dx, dy
class TestElementwisePowOp(OpTest):
def setUp(self):
......@@ -194,7 +200,6 @@ class TestElementwisePowGradOpInt(unittest.TestCase):
# dy = dout * log(x) * pow(x, y)
self.grad_y = (self.grad_res * np.log(self.x) *
(self.x**self.y)).astype("int")
print(self.grad_res, self.grad_x, self.grad_y)
def test_grad(self):
fluid.set_flags({"FLAGS_retain_grad_for_all_tensor": True})
......@@ -205,7 +210,6 @@ class TestElementwisePowGradOpInt(unittest.TestCase):
with fluid.dygraph.guard(place):
x = fluid.dygraph.to_variable(self.x, zero_copy=False)
y = fluid.dygraph.to_variable(self.y, zero_copy=False)
print(x, y)
x.stop_gradient = False
y.stop_gradient = False
res = x**y
......@@ -216,5 +220,31 @@ class TestElementwisePowGradOpInt(unittest.TestCase):
fluid.set_flags({"FLAGS_retain_grad_for_all_tensor": False})
class TestElementwisePowOpFP16(OpTest):
def setUp(self):
self.op_type = "elementwise_pow"
self.python_api = paddle.pow
self.inputs = {
'X': np.random.uniform(1, 2, [20, 5]).astype("float16"),
'Y': np.random.uniform(1, 2, [20, 5]).astype("float16")
}
self.outputs = {'Out': np.power(self.inputs['X'], self.inputs['Y'])}
def test_check_output(self):
if hasattr(self, 'attrs'):
self.check_output(check_eager=False)
else:
self.check_output(check_eager=True)
def test_check_grad(self):
self.check_grad(['X', 'Y'],
'Out',
user_defined_grads=pow_grad(self.inputs['X'],
self.inputs['Y'],
1 / self.inputs['X'].size),
check_eager=True)
if __name__ == '__main__':
unittest.main()
......@@ -209,3 +209,33 @@ class TestElementwiseFmax2Op(OpTest):
max_relative_error=0.005,
no_grad_set=set('Y'),
check_eager=True)
class TestElementwiseFmax3Op(OpTest):
"""TestElementwiseFmax3Op"""
def setUp(self):
"""setUp"""
self.op_type = "elementwise_fmax"
self.python_api = paddle.fmax
# If x and y have the same value, the max() is not differentiable.
# So we generate test data by the following method
# to avoid them being too close to each other.
x = np.random.uniform(0.1, 1, [13, 17]).astype("float16")
sgn = np.random.choice([-1, 1], [13, 17]).astype("float16")
y = x + sgn * np.random.uniform(0.1, 1, [13, 17]).astype("float16")
self.inputs = {'X': x, 'Y': y}
self.outputs = {'Out': np.fmax(self.inputs['X'], self.inputs['Y'])}
def test_check_output(self):
"""test_check_output"""
self.check_output(check_eager=True)
def test_check_grad_normal(self):
"""test_check_grad_normal"""
self.check_grad(['X', 'Y'], 'Out', check_eager=True)
if __name__ == "__main__":
unittest.main()
......@@ -213,6 +213,32 @@ class TestElementwiseFmin2Op(OpTest):
check_eager=True)
class TestElementwiseFmin3Op(OpTest):
"""TestElementwiseFmin2Op"""
def setUp(self):
"""setUp"""
self.op_type = "elementwise_fmin"
self.python_api = paddle.fmin
# If x and y have the same value, the min() is not differentiable.
# So we generate test data by the following method
# to avoid them being too close to each other.
x = np.random.uniform(1, 1, [13, 17]).astype("float16")
sgn = np.random.choice([-1, 1], [13, 17]).astype("float16")
y = x + sgn * np.random.uniform(1, 1, [13, 17]).astype("float16")
self.inputs = {'X': x, 'Y': y}
self.outputs = {'Out': np.fmin(self.inputs['X'], self.inputs['Y'])}
def test_check_output(self):
"""test_check_output"""
self.check_output(check_eager=True)
def test_check_grad_normal(self):
"""test_check_grad_normal"""
self.check_grad(['X', 'Y'], 'Out', check_eager=True)
if __name__ == "__main__":
paddle.enable_static()
unittest.main()
......@@ -352,7 +352,7 @@ def pow(x, y, name=None):
Args:
x (Tensor): An N-D Tensor, the data type is float32, float64, int32 or int64.
x (Tensor): An N-D Tensor, the data type is float16, float32, float64, int32 or int64.
y (float|int|Tensor): If it is an N-D Tensor, its data type should be the same as `x`.
name (str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`.
......@@ -762,8 +762,8 @@ def remainder(x, y, name=None):
``paddle.remainder`` supports broadcasting. If you want know more about broadcasting, please refer to :ref:`user_guide_broadcasting` .
Args:
x (Tensor): the input tensor, it's data type should be float32, float64, int32, int64.
y (Tensor): the input tensor, it's data type should be float32, float64, int32, int64.
x (Tensor): the input tensor, it's data type should be float16, float32, float64, int32, int64.
y (Tensor): the input tensor, it's data type should be float16, float32, float64, int32, int64.
name (str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`.
Returns:
......@@ -1003,8 +1003,8 @@ def fmax(x, y, name=None):
``paddle.fmax`` supports broadcasting. If you want know more about broadcasting, please refer to :ref:`user_guide_broadcasting` .
Args:
x (Tensor): the input tensor, it's data type should be float32, float64, int32, int64.
y (Tensor): the input tensor, it's data type should be float32, float64, int32, int64.
x (Tensor): the input tensor, it's data type should be float16, float32, float64, int32, int64.
y (Tensor): the input tensor, it's data type should be float16, float32, float64, int32, int64.
name (str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`.
Returns:
......@@ -1066,8 +1066,8 @@ def fmin(x, y, name=None):
``paddle.fmin`` supports broadcasting. If you want know more about broadcasting, please refer to :ref:`user_guide_broadcasting` .
Args:
x (Tensor): the input tensor, it's data type should be float32, float64, int32, int64.
y (Tensor): the input tensor, it's data type should be float32, float64, int32, int64.
x (Tensor): the input tensor, it's data type should be float16, float32, float64, int32, int64.
y (Tensor): the input tensor, it's data type should be float16, float32, float64, int32, int64.
name (str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`.
Returns:
......@@ -4696,8 +4696,8 @@ def heaviside(x, y, name=None):
``paddle.heaviside`` supports broadcasting. If you want know more about broadcasting, please refer to :ref:`user_guide_broadcasting`.
Args:
x (Tensor): The input tensor of Heaviside step function, it's data type should be float32, float64, int32 or int64.
y (Tensor): The tensor that determines a Heaviside step function, it's data type should be float32, float64, int32 or int64.
x (Tensor): The input tensor of Heaviside step function, it's data type should be float16, float32, float64, int32 or int64.
y (Tensor): The tensor that determines a Heaviside step function, it's data type should be float16, float32, float64, int32 or int64.
name (str, optional): Name for the operation (optional, default is None). Normally there is no need for user to set this property. For more information, please refer to :ref:`api_guide_Name`.
Returns:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册