Unverified commit 3027c58a, authored by houj04, committed by GitHub

[XPU] add fp16 support for cumsum and log (#50599)

* [XPU] add fp16 support for cumsum and log.

* [XPU] add fp16 support for cumsum and log.
Parent: 61469eec
@@ -134,6 +134,7 @@ XPUOpMap& get_kl2_ops() {
     {"conv2d_transpose", XPUKernelSet({phi::DataType::FLOAT32})},
     {"cumsum",
      XPUKernelSet({phi::DataType::FLOAT32,
+                   phi::DataType::FLOAT16,
                    phi::DataType::INT32,
                    phi::DataType::INT64})},
     {"deformable_conv_grad", XPUKernelSet({phi::DataType::FLOAT32})},
@@ -379,7 +380,7 @@ XPUOpMap& get_kl2_ops() {
                    phi::DataType::INT8,
                    phi::DataType::INT32,
                    phi::DataType::INT64})},
-    {"log", XPUKernelSet({phi::DataType::FLOAT32})},
+    {"log", XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})},
     {"log_grad", XPUKernelSet({phi::DataType::FLOAT32})},
     {"log_softmax", XPUKernelSet({phi::DataType::FLOAT32})},
     {"log_softmax_grad", XPUKernelSet({phi::DataType::FLOAT32})},
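With FLOAT16 added to the kl2 op map, the framework can dispatch fp16 tensors to these XPU kernels. A minimal usage sketch, assuming a PaddlePaddle build with XPU support and an XPU device available (shapes and the added constant are illustrative):

import paddle

paddle.set_device('xpu')  # route subsequent ops to the XPU backend

x = paddle.rand([2, 5]).astype('float16')
y = paddle.cumsum(x, axis=1)  # served by the fp16 XPU cumsum kernel
z = paddle.log(x + 1.0)       # served by the fp16 XPU log kernel
print(y.dtype, z.dtype)       # paddle.float16 paddle.float16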
@@ -545,9 +545,11 @@ PD_REGISTER_KERNEL(
 PD_REGISTER_KERNEL(
     square, XPU, ALL_LAYOUT, phi::SquareKernel, float, phi::dtype::float16) {}
+PD_REGISTER_KERNEL(
+    log, XPU, ALL_LAYOUT, phi::LogKernel, float, phi::dtype::float16) {}
 PD_REGISTER_ACTIVATION_KERNEL(exp, ExpKernel)  // no grad
 PD_REGISTER_ACTIVATION_KERNEL(floor, FloorKernel)
-PD_REGISTER_ACTIVATION_KERNEL(log, LogKernel)
 PD_REGISTER_ACTIVATION_KERNEL(leaky_relu, LeakyReluKernel)
 PD_REGISTER_ACTIVATION_KERNEL(hard_sigmoid, HardSigmoidKernel)
 PD_REGISTER_ACTIVATION_KERNEL(hardswish_raw, HardSwishRawKernel)
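A plausible reading of why log needs only this registration change while cumsum (next file) gets a dedicated compute path: an elementwise op rounds once per output, so fp16 error stays near one ulp, whereas a scan compounds rounding along the running sum. A numpy illustration of the gap, independent of the XPU runtime (variable names are mine):

import numpy as np

rng = np.random.default_rng(0)
x = rng.uniform(0.5, 2.0, 10000).astype(np.float16)

# Elementwise: each output is rounded once, so fp16 tracks the fp32
# reference closely.
log_err = np.abs(np.log(x) - np.log(x.astype(np.float32))).max()

# Scan: rounding errors compound, and the running sum eventually outgrows
# fp16's resolution entirely.
cum_err = np.abs(np.cumsum(x) - np.cumsum(x.astype(np.float32))).max()

print(log_err, cum_err)  # cum_err is larger by orders of magnitude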
@@ -66,19 +66,55 @@ void CumsumKernel(const Context& dev_ctx,
     }
   }

-  // template<typename T> DLL_EXPORT int cumsum(Context* ctx, const T* x, T* y,
-  // const std::vector<int>& xshape, bool reverse, bool exclusive, int axis);
-  int r = cumsum(dev_ctx.x_context(),
-                 reinterpret_cast<const XPUType*>(x.data<T>()),
-                 reinterpret_cast<XPUType*>(out->data<T>()),
-                 x_shape,
-                 reverse,
-                 exclusive,
-                 axis_as_int);
-  PADDLE_ENFORCE_XDNN_SUCCESS(r, "cumsum");
+  // special path for fp16: compute in fp32 and cast back, since fp16
+  // accumulation loses precision along the running sum
+  if (std::is_same<T, dtype::float16>::value) {
+    xpu::ctx_guard RAII_GUARD(dev_ctx.x_context());
+    float* cast_input_fp32 = RAII_GUARD.alloc_l3_or_gm<float>(x.numel());
+    float* temp_result_fp32 = RAII_GUARD.alloc_l3_or_gm<float>(x.numel());
+    // cast input to fp32
+    int r =
+        xpu::cast<XPUType, float>(dev_ctx.x_context(),
+                                  reinterpret_cast<const XPUType*>(x.data<T>()),
+                                  cast_input_fp32,
+                                  x.numel());
+    PADDLE_ENFORCE_XDNN_SUCCESS(r, "cast");
+    // cumsum in fp32
+    r = xpu::cumsum<float>(dev_ctx.x_context(),
+                           cast_input_fp32,
+                           temp_result_fp32,
+                           x_shape,
+                           reverse,
+                           exclusive,
+                           axis_as_int);
+    PADDLE_ENFORCE_XDNN_SUCCESS(r, "cumsum");
+    // cast result back to fp16
+    r = xpu::cast<float, XPUType>(dev_ctx.x_context(),
+                                  temp_result_fp32,
+                                  reinterpret_cast<XPUType*>(out->data<T>()),
+                                  x.numel());
+    PADDLE_ENFORCE_XDNN_SUCCESS(r, "cast");
+  } else {
+    // template<typename T> DLL_EXPORT int cumsum(Context* ctx, const T* x, T*
+    // y, const std::vector<int>& xshape, bool reverse, bool exclusive, int
+    // axis);
+    int r = xpu::cumsum<XPUType>(dev_ctx.x_context(),
+                                 reinterpret_cast<const XPUType*>(x.data<T>()),
+                                 reinterpret_cast<XPUType*>(out->data<T>()),
+                                 x_shape,
+                                 reverse,
+                                 exclusive,
+                                 axis_as_int);
+    PADDLE_ENFORCE_XDNN_SUCCESS(r, "cumsum");
+  }
 }

 }  // namespace phi

-PD_REGISTER_KERNEL(
-    cumsum, XPU, ALL_LAYOUT, phi::CumsumKernel, float, int, int64_t) {}
+PD_REGISTER_KERNEL(cumsum,
+                   XPU,
+                   ALL_LAYOUT,
+                   phi::CumsumKernel,
+                   float,
+                   int,
+                   int64_t,
+                   phi::dtype::float16) {}
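The fp16 branch above pays one extra pass over memory to upcast, scans in fp32, and casts the result back; the widened registration then exposes the phi::dtype::float16 instantiation to the dispatcher. A minimal numpy analogue of that cast-compute-cast scheme (the helper name is mine, not Paddle API):

import numpy as np

def cumsum_fp16_via_fp32(x):
    # mirror the kernel's fp16 path: cast up, scan in fp32, cast back
    return np.cumsum(x.astype(np.float32)).astype(np.float16)

x = np.full(4096, 1.0, dtype=np.float16)
print(np.cumsum(x)[-1])             # fp16 running sum stalls at 2048.0: past
                                    # 2048 the increment is below half an
                                    # fp16 ulp and rounds away
print(cumsum_fp16_via_fp32(x)[-1])  # 4096.0, exact after one final rounding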
@@ -48,9 +48,9 @@ class XPUTestCumsumOP(XPUOpTestWrapper):
             self.op_type = 'cumsum'
             self.init_config()
-            self.data = np.random.uniform(
-                -100.0, 100.0, self.input_shape
-            ).astype(self.dtype)
+            self.data = np.random.uniform(-1.0, 1.0, self.input_shape).astype(
+                self.dtype
+            )
             reference_out = np.cumsum(self.data, axis=self.axis)
             self.inputs = {
                 'X': self.data,
@@ -68,6 +68,9 @@ class XPUTestCumsumOP(XPUOpTestWrapper):
         def test_check_output(self):
             self.check_output_with_place(self.place)

+        def test_check_grad(self):
+            self.check_grad_with_place(self.place, ['X'], 'Out')
+
         def init_config(self):
             self.input_shape = (2, 5)
             self.axis = None
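Two test changes follow from fp16 support: the input range shrinks from [-100, 100] to [-1, 1], keeping partial sums comfortably inside half precision's resolution, and a gradient check is added. For cumsum the analytic gradient is simply the reversed cumulative sum of the upstream gradient, which a finite-difference probe (the same idea check_grad_with_place relies on) confirms; a self-contained sketch with names of my choosing:

import numpy as np

# For y = cumsum(x), dy_j/dx_i = 1 whenever i <= j, so dL/dx is the
# reversed cumulative sum of the upstream gradient dL/dy.
rng = np.random.default_rng(0)
x = rng.random(5).astype(np.float32)
gy = rng.random(5).astype(np.float32)  # upstream gradient

gx_analytic = np.cumsum(gy[::-1])[::-1]

eps = 1e-3
gx_numeric = np.empty_like(x)
for i in range(x.size):
    xp, xm = x.copy(), x.copy()
    xp[i] += eps
    xm[i] -= eps
    gx_numeric[i] = ((np.cumsum(xp) - np.cumsum(xm)) * gy).sum() / (2 * eps)

print(np.allclose(gx_analytic, gx_numeric, atol=1e-2))  # True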