Unverified commit e84fa263, authored by zhangyikun02, committed by GitHub

pad3d support fp16 for xpu (#50653)

Parent: c1b5e7c2
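In short: the KL2 op list gains a FLOAT16 entry for `pad3d`, the XPU `Pad3dKernel` is registered for `phi::dtype::float16`, and the kernel now routes its input/output buffers through `XPUTypeTrait<T>::Type` so the XDNN pad routines receive the device-side fp16 type. A minimal sketch of how the new static-graph fp16 path can be exercised (assuming an XPU build of PaddlePaddle with a device attached; shapes and pad widths are illustrative, and `paddle.static.data` stands in for the `paddle.fluid.data` call used in the tests below):

```python
import numpy as np
import paddle
import paddle.nn.functional as F

paddle.enable_static()
prog = paddle.static.Program()
with paddle.static.program_guard(prog):
    # float16 input is what this commit newly allows for pad3d on XPU
    x = paddle.static.data(name="x", shape=[1, 2, 3, 4, 5], dtype="float16")
    # pad = [left, right, top, bottom, front, back] for NCDHW
    y = F.pad(x, pad=[1, 2, 1, 1, 3, 4], mode="constant", value=0.0,
              data_format="NCDHW")

exe = paddle.static.Executor(paddle.XPUPlace(0))  # assumes XPU device 0
x_np = np.random.rand(1, 2, 3, 4, 5).astype(np.float16)
(out,) = exe.run(prog, feed={"x": x_np}, fetch_list=[y])
print(out.shape)  # (1, 2, 10, 6, 8): D 3+3+4, H 4+1+1, W 5+1+2
```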
@@ -452,7 +452,7 @@ XPUOpMap& get_kl2_ops() {
       {"p_norm", XPUKernelSet({phi::DataType::FLOAT32})},
       {"p_norm_grad", XPUKernelSet({phi::DataType::FLOAT32})},
       {"pad3d_grad", XPUKernelSet({phi::DataType::FLOAT32})},
-      {"pad3d", XPUKernelSet({phi::DataType::FLOAT32})},
+      {"pad3d", XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})},
       {"pixel_shuffle", XPUKernelSet({phi::DataType::FLOAT32})},
       {"pixel_shuffle_grad", XPUKernelSet({phi::DataType::FLOAT32})},
       {"pool2d_grad",
@@ -28,7 +28,6 @@ void Pad3dKernel(const Context& dev_ctx,
                  float pad_value,
                  const std::string& data_format,
                  DenseTensor* out) {
-  T value = static_cast<T>(pad_value);
   std::vector<int64_t> pads = paddings.GetData();
   auto in_dims = x.dims();
@@ -142,10 +141,12 @@ void Pad3dKernel(const Context& dev_ctx,
   pads_xpu[4] = pads[0];  // pl
   pads_xpu[5] = pads[1];  // pr
 
+  using XPUType = typename XPUTypeTrait<T>::Type;
+
   if (mode == "reflect") {
     int r = xpu::reflection_pad3d(dev_ctx.x_context(),
-                                  in_data,
-                                  out_data,
+                                  reinterpret_cast<const XPUType*>(in_data),
+                                  reinterpret_cast<XPUType*>(out_data),
                                   num,
                                   channels,
                                   in_depth,
@@ -156,8 +157,8 @@ void Pad3dKernel(const Context& dev_ctx,
     PADDLE_ENFORCE_XDNN_SUCCESS(r, "reflection_pad3d");
   } else if (mode == "replicate") {
     int r = xpu::replication_pad3d(dev_ctx.x_context(),
-                                   in_data,
-                                   out_data,
+                                   reinterpret_cast<const XPUType*>(in_data),
+                                   reinterpret_cast<XPUType*>(out_data),
                                    num,
                                    channels,
                                    in_depth,
@@ -167,9 +168,10 @@ void Pad3dKernel(const Context& dev_ctx,
                                    is_ncdhw);
     PADDLE_ENFORCE_XDNN_SUCCESS(r, "replication_pad3d");
   } else if (mode == "constant") {
+    XPUType value = static_cast<XPUType>(pad_value);
     int r = xpu::constant_pad3d(dev_ctx.x_context(),
-                                in_data,
-                                out_data,
+                                reinterpret_cast<const XPUType*>(in_data),
+                                reinterpret_cast<XPUType*>(out_data),
                                 num,
                                 channels,
                                 in_depth,
@@ -184,4 +186,5 @@ void Pad3dKernel(const Context& dev_ctx,
 }  // namespace phi
 
-PD_REGISTER_KERNEL(pad3d, XPU, ALL_LAYOUT, phi::Pad3dKernel, float) {}
+PD_REGISTER_KERNEL(
+    pad3d, XPU, ALL_LAYOUT, phi::Pad3dKernel, float, phi::dtype::float16) {}
@@ -188,7 +188,9 @@ class XPUTestPad3dOp(XPUOpTestWrapper):
             mode = "constant"
             value = 100
             input_data = np.random.rand(*input_shape).astype(self.dtype)
-            x = paddle.fluid.data(name="x", shape=input_shape)
+            x = paddle.fluid.data(
+                name="x", shape=input_shape, dtype=self.dtype
+            )
             result = F.pad(
                 x=x, pad=pad, value=value, mode=mode, data_format="NCDHW"
             )
@@ -209,7 +211,9 @@ class XPUTestPad3dOp(XPUOpTestWrapper):
             pad = [1, 2, 1, 1, 1, 2]
             mode = "reflect"
             input_data = np.random.rand(*input_shape).astype(self.dtype)
-            x = paddle.fluid.data(name="x", shape=input_shape)
+            x = paddle.fluid.data(
+                name="x", shape=input_shape, dtype=self.dtype
+            )
             result1 = F.pad(x=x, pad=pad, mode=mode, data_format="NCDHW")
             result2 = F.pad(x=x, pad=pad, mode=mode, data_format="NDHWC")
             exe = Executor(place)
@@ -235,7 +239,9 @@ class XPUTestPad3dOp(XPUOpTestWrapper):
             pad = [1, 2, 1, 1, 3, 4]
             mode = "replicate"
             input_data = np.random.rand(*input_shape).astype(self.dtype)
-            x = paddle.fluid.data(name="x", shape=input_shape)
+            x = paddle.fluid.data(
+                name="x", shape=input_shape, dtype=self.dtype
+            )
             result1 = F.pad(x=x, pad=pad, mode=mode, data_format="NCDHW")
             result2 = F.pad(x=x, pad=pad, mode=mode, data_format="NDHWC")
             exe = Executor(place)
@@ -320,6 +326,9 @@ class XPUTestPad3dOp(XPUOpTestWrapper):
             self.check_static_result_3(place=place)
 
     def test_dygraph_1(self):
+        # TODO: remove fp16 limit after support of pad op
+        if self.dtype == np.float16:
+            return
         paddle.disable_static()
         input_shape = (1, 2, 3, 4, 5)
         pad = [1, 2, 1, 1, 3, 4]
@@ -365,6 +374,9 @@ class XPUTestPad3dOp(XPUOpTestWrapper):
         np.testing.assert_allclose(y3.numpy(), np_out3, rtol=1e-05)
 
     def test_dygraph_2(self):
+        # TODO: remove fp16 limit after support of pad op
+        if self.dtype == np.float16:
+            return
         paddle.disable_static()
         input_shape = (2, 3, 4, 5)
         pad = [1, 1, 3, 4]
@@ -412,6 +424,9 @@ class XPUTestPad3dOp(XPUOpTestWrapper):
         np.testing.assert_allclose(y3.numpy(), np_out3, rtol=1e-05)
 
     def test_dygraph_3(self):
+        # TODO: remove fp16 limit after support of pad op
+        if self.dtype == np.float16:
+            return
         paddle.disable_static()
         input_shape = (3, 4, 5)
         pad = [3, 4]
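Note that the dygraph tests above keep skipping float16 per the TODO, presumably because some of those cases dispatch to the generic `pad` op, which does not yet support fp16 on XPU. A hedged sketch of the kind of numerical check the static-graph fp16 path does admit, validating `replicate` padding against a NumPy `edge`-mode reference (shapes, pad widths, and tolerance are illustrative; `XPUPlace(0)` assumes a device is present):

```python
import numpy as np
import paddle
import paddle.nn.functional as F

paddle.enable_static()
prog = paddle.static.Program()
with paddle.static.program_guard(prog):
    x = paddle.static.data(name="x", shape=[2, 3, 4, 5, 6], dtype="float16")
    y = F.pad(x, pad=[1, 0, 1, 2, 0, 3], mode="replicate", data_format="NCDHW")

exe = paddle.static.Executor(paddle.XPUPlace(0))
x_np = np.random.rand(2, 3, 4, 5, 6).astype(np.float16)
(out,) = exe.run(prog, feed={"x": x_np}, fetch_list=[y])

# NumPy "edge" mode replicates border values, matching mode="replicate".
# F.pad orders pads as (W-left, W-right, H-top, H-bottom, D-front, D-back),
# so the NCDHW axes get ((0,0), (0,0), (0,3), (1,2), (1,0)).
ref = np.pad(x_np, ((0, 0), (0, 0), (0, 3), (1, 2), (1, 0)), mode="edge")
np.testing.assert_allclose(
    out.astype(np.float32), ref.astype(np.float32), rtol=1e-3
)
```

Replication only copies existing fp16 values, so the comparison should hold essentially exactly; the loose tolerance mirrors the style of the fp16-aware checks in the test file.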