未验证 提交 3027c58a 编写于 作者: H houj04 提交者: GitHub

[XPU] add fp16 support for cumsum and log (#50599)

* [XPU] add fp16 support for cumsum and log.

* [XPU] add fp16 support for cumsum and log.
上级 61469eec
...@@ -134,6 +134,7 @@ XPUOpMap& get_kl2_ops() { ...@@ -134,6 +134,7 @@ XPUOpMap& get_kl2_ops() {
{"conv2d_transpose", XPUKernelSet({phi::DataType::FLOAT32})}, {"conv2d_transpose", XPUKernelSet({phi::DataType::FLOAT32})},
{"cumsum", {"cumsum",
XPUKernelSet({phi::DataType::FLOAT32, XPUKernelSet({phi::DataType::FLOAT32,
+ phi::DataType::FLOAT16,
phi::DataType::INT32, phi::DataType::INT32,
phi::DataType::INT64})}, phi::DataType::INT64})},
{"deformable_conv_grad", XPUKernelSet({phi::DataType::FLOAT32})}, {"deformable_conv_grad", XPUKernelSet({phi::DataType::FLOAT32})},
...@@ -379,7 +380,7 @@ XPUOpMap& get_kl2_ops() { ...@@ -379,7 +380,7 @@ XPUOpMap& get_kl2_ops() {
phi::DataType::INT8, phi::DataType::INT8,
phi::DataType::INT32, phi::DataType::INT32,
phi::DataType::INT64})}, phi::DataType::INT64})},
- {"log", XPUKernelSet({phi::DataType::FLOAT32})},
+ {"log", XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})},
{"log_grad", XPUKernelSet({phi::DataType::FLOAT32})}, {"log_grad", XPUKernelSet({phi::DataType::FLOAT32})},
{"log_softmax", XPUKernelSet({phi::DataType::FLOAT32})}, {"log_softmax", XPUKernelSet({phi::DataType::FLOAT32})},
{"log_softmax_grad", XPUKernelSet({phi::DataType::FLOAT32})}, {"log_softmax_grad", XPUKernelSet({phi::DataType::FLOAT32})},
......
...@@ -545,9 +545,11 @@ PD_REGISTER_KERNEL( ...@@ -545,9 +545,11 @@ PD_REGISTER_KERNEL(
PD_REGISTER_KERNEL( PD_REGISTER_KERNEL(
square, XPU, ALL_LAYOUT, phi::SquareKernel, float, phi::dtype::float16) {} square, XPU, ALL_LAYOUT, phi::SquareKernel, float, phi::dtype::float16) {}
+ PD_REGISTER_KERNEL(
+ log, XPU, ALL_LAYOUT, phi::LogKernel, float, phi::dtype::float16) {}
PD_REGISTER_ACTIVATION_KERNEL(exp, ExpKernel) // no grad PD_REGISTER_ACTIVATION_KERNEL(exp, ExpKernel) // no grad
PD_REGISTER_ACTIVATION_KERNEL(floor, FloorKernel) PD_REGISTER_ACTIVATION_KERNEL(floor, FloorKernel)
- PD_REGISTER_ACTIVATION_KERNEL(log, LogKernel)
PD_REGISTER_ACTIVATION_KERNEL(leaky_relu, LeakyReluKernel) PD_REGISTER_ACTIVATION_KERNEL(leaky_relu, LeakyReluKernel)
PD_REGISTER_ACTIVATION_KERNEL(hard_sigmoid, HardSigmoidKernel) PD_REGISTER_ACTIVATION_KERNEL(hard_sigmoid, HardSigmoidKernel)
PD_REGISTER_ACTIVATION_KERNEL(hardswish_raw, HardSwishRawKernel) PD_REGISTER_ACTIVATION_KERNEL(hardswish_raw, HardSwishRawKernel)
......
...@@ -66,9 +66,38 @@ void CumsumKernel(const Context& dev_ctx, ...@@ -66,9 +66,38 @@ void CumsumKernel(const Context& dev_ctx,
} }
} }
- // template<typename T> DLL_EXPORT int cumsum(Context* ctx, const T* x, T* y,
- // const std::vector<int>& xshape, bool reverse, bool exclusive, int axis);
- int r = cumsum(dev_ctx.x_context(),
+ // special for fp16
+ if (std::is_same<T, dtype::float16>::value) {
+ xpu::ctx_guard RAII_GUARD(dev_ctx.x_context());
+ float* cast_input_fp32 = RAII_GUARD.alloc_l3_or_gm<float>(x.numel());
+ float* temp_result_fp32 = RAII_GUARD.alloc_l3_or_gm<float>(x.numel());
+ // cast to fp32
+ int r =
+ xpu::cast<XPUType, float>(dev_ctx.x_context(),
+ reinterpret_cast<const XPUType*>(x.data<T>()),
+ cast_input_fp32,
+ x.numel());
+ PADDLE_ENFORCE_XDNN_SUCCESS(r, "cast");
+ // cumsum in fp32
+ r = xpu::cumsum<float>(dev_ctx.x_context(),
+ cast_input_fp32,
+ temp_result_fp32,
+ x_shape,
+ reverse,
+ exclusive,
+ axis_as_int);
+ PADDLE_ENFORCE_XDNN_SUCCESS(r, "cumsum");
+ // cast back to fp16
+ r = xpu::cast<float, XPUType>(dev_ctx.x_context(),
+ temp_result_fp32,
+ reinterpret_cast<XPUType*>(out->data<T>()),
+ x.numel());
+ PADDLE_ENFORCE_XDNN_SUCCESS(r, "cast");
+ } else {
+ // template<typename T> DLL_EXPORT int cumsum(Context* ctx, const T* x, T*
+ // y, const std::vector<int>& xshape, bool reverse, bool exclusive, int
+ // axis);
+ int r = xpu::cumsum<XPUType>(dev_ctx.x_context(),
reinterpret_cast<const XPUType*>(x.data<T>()), reinterpret_cast<const XPUType*>(x.data<T>()),
reinterpret_cast<XPUType*>(out->data<T>()), reinterpret_cast<XPUType*>(out->data<T>()),
x_shape, x_shape,
...@@ -76,9 +105,16 @@ void CumsumKernel(const Context& dev_ctx, ...@@ -76,9 +105,16 @@ void CumsumKernel(const Context& dev_ctx,
exclusive, exclusive,
axis_as_int); axis_as_int);
PADDLE_ENFORCE_XDNN_SUCCESS(r, "cumsum"); PADDLE_ENFORCE_XDNN_SUCCESS(r, "cumsum");
+ }
} }
} // namespace phi } // namespace phi
- PD_REGISTER_KERNEL(
- cumsum, XPU, ALL_LAYOUT, phi::CumsumKernel, float, int, int64_t) {}
+ PD_REGISTER_KERNEL(cumsum,
+ XPU,
+ ALL_LAYOUT,
+ phi::CumsumKernel,
+ float,
+ int,
+ int64_t,
+ phi::dtype::float16) {}
...@@ -48,9 +48,9 @@ class XPUTestCumsumOP(XPUOpTestWrapper): ...@@ -48,9 +48,9 @@ class XPUTestCumsumOP(XPUOpTestWrapper):
self.op_type = 'cumsum' self.op_type = 'cumsum'
self.init_config() self.init_config()
- self.data = np.random.uniform(
- -100.0, 100.0, self.input_shape
- ).astype(self.dtype)
+ self.data = np.random.uniform(-1.0, 1.0, self.input_shape).astype(
+ self.dtype
+ )
reference_out = np.cumsum(self.data, axis=self.axis) reference_out = np.cumsum(self.data, axis=self.axis)
self.inputs = { self.inputs = {
'X': self.data, 'X': self.data,
...@@ -68,6 +68,9 @@ class XPUTestCumsumOP(XPUOpTestWrapper): ...@@ -68,6 +68,9 @@ class XPUTestCumsumOP(XPUOpTestWrapper):
def test_check_output(self): def test_check_output(self):
self.check_output_with_place(self.place) self.check_output_with_place(self.place)
+ def test_check_grad(self):
+ self.check_grad_with_place(self.place, ['X'], 'Out')
def init_config(self): def init_config(self):
self.input_shape = (2, 5) self.input_shape = (2, 5)
self.axis = None self.axis = None
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册