未验证 提交 e84fa263 编写于 作者: Z zhangyikun02 提交者: GitHub

pad3d: support fp16 for XPU (#50653)

上级 c1b5e7c2
...@@ -452,7 +452,7 @@ XPUOpMap& get_kl2_ops() { ...@@ -452,7 +452,7 @@ XPUOpMap& get_kl2_ops() {
{"p_norm", XPUKernelSet({phi::DataType::FLOAT32})}, {"p_norm", XPUKernelSet({phi::DataType::FLOAT32})},
{"p_norm_grad", XPUKernelSet({phi::DataType::FLOAT32})}, {"p_norm_grad", XPUKernelSet({phi::DataType::FLOAT32})},
{"pad3d_grad", XPUKernelSet({phi::DataType::FLOAT32})}, {"pad3d_grad", XPUKernelSet({phi::DataType::FLOAT32})},
{"pad3d", XPUKernelSet({phi::DataType::FLOAT32})}, {"pad3d", XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})},
{"pixel_shuffle", XPUKernelSet({phi::DataType::FLOAT32})}, {"pixel_shuffle", XPUKernelSet({phi::DataType::FLOAT32})},
{"pixel_shuffle_grad", XPUKernelSet({phi::DataType::FLOAT32})}, {"pixel_shuffle_grad", XPUKernelSet({phi::DataType::FLOAT32})},
{"pool2d_grad", {"pool2d_grad",
......
...@@ -28,7 +28,6 @@ void Pad3dKernel(const Context& dev_ctx, ...@@ -28,7 +28,6 @@ void Pad3dKernel(const Context& dev_ctx,
float pad_value, float pad_value,
const std::string& data_format, const std::string& data_format,
DenseTensor* out) { DenseTensor* out) {
T value = static_cast<T>(pad_value);
std::vector<int64_t> pads = paddings.GetData(); std::vector<int64_t> pads = paddings.GetData();
auto in_dims = x.dims(); auto in_dims = x.dims();
...@@ -142,10 +141,12 @@ void Pad3dKernel(const Context& dev_ctx, ...@@ -142,10 +141,12 @@ void Pad3dKernel(const Context& dev_ctx,
pads_xpu[4] = pads[0]; // pl pads_xpu[4] = pads[0]; // pl
pads_xpu[5] = pads[1]; // pr pads_xpu[5] = pads[1]; // pr
using XPUType = typename XPUTypeTrait<T>::Type;
if (mode == "reflect") { if (mode == "reflect") {
int r = xpu::reflection_pad3d(dev_ctx.x_context(), int r = xpu::reflection_pad3d(dev_ctx.x_context(),
in_data, reinterpret_cast<const XPUType*>(in_data),
out_data, reinterpret_cast<XPUType*>(out_data),
num, num,
channels, channels,
in_depth, in_depth,
...@@ -156,8 +157,8 @@ void Pad3dKernel(const Context& dev_ctx, ...@@ -156,8 +157,8 @@ void Pad3dKernel(const Context& dev_ctx,
PADDLE_ENFORCE_XDNN_SUCCESS(r, "reflection_pad3d"); PADDLE_ENFORCE_XDNN_SUCCESS(r, "reflection_pad3d");
} else if (mode == "replicate") { } else if (mode == "replicate") {
int r = xpu::replication_pad3d(dev_ctx.x_context(), int r = xpu::replication_pad3d(dev_ctx.x_context(),
in_data, reinterpret_cast<const XPUType*>(in_data),
out_data, reinterpret_cast<XPUType*>(out_data),
num, num,
channels, channels,
in_depth, in_depth,
...@@ -167,9 +168,10 @@ void Pad3dKernel(const Context& dev_ctx, ...@@ -167,9 +168,10 @@ void Pad3dKernel(const Context& dev_ctx,
is_ncdhw); is_ncdhw);
PADDLE_ENFORCE_XDNN_SUCCESS(r, "replication_pad3d"); PADDLE_ENFORCE_XDNN_SUCCESS(r, "replication_pad3d");
} else if (mode == "constant") { } else if (mode == "constant") {
XPUType value = static_cast<XPUType>(pad_value);
int r = xpu::constant_pad3d(dev_ctx.x_context(), int r = xpu::constant_pad3d(dev_ctx.x_context(),
in_data, reinterpret_cast<const XPUType*>(in_data),
out_data, reinterpret_cast<XPUType*>(out_data),
num, num,
channels, channels,
in_depth, in_depth,
...@@ -184,4 +186,5 @@ void Pad3dKernel(const Context& dev_ctx, ...@@ -184,4 +186,5 @@ void Pad3dKernel(const Context& dev_ctx,
} // namespace phi } // namespace phi
PD_REGISTER_KERNEL(pad3d, XPU, ALL_LAYOUT, phi::Pad3dKernel, float) {} PD_REGISTER_KERNEL(
pad3d, XPU, ALL_LAYOUT, phi::Pad3dKernel, float, phi::dtype::float16) {}
...@@ -188,7 +188,9 @@ class XPUTestPad3dOp(XPUOpTestWrapper): ...@@ -188,7 +188,9 @@ class XPUTestPad3dOp(XPUOpTestWrapper):
mode = "constant" mode = "constant"
value = 100 value = 100
input_data = np.random.rand(*input_shape).astype(self.dtype) input_data = np.random.rand(*input_shape).astype(self.dtype)
x = paddle.fluid.data(name="x", shape=input_shape) x = paddle.fluid.data(
name="x", shape=input_shape, dtype=self.dtype
)
result = F.pad( result = F.pad(
x=x, pad=pad, value=value, mode=mode, data_format="NCDHW" x=x, pad=pad, value=value, mode=mode, data_format="NCDHW"
) )
...@@ -209,7 +211,9 @@ class XPUTestPad3dOp(XPUOpTestWrapper): ...@@ -209,7 +211,9 @@ class XPUTestPad3dOp(XPUOpTestWrapper):
pad = [1, 2, 1, 1, 1, 2] pad = [1, 2, 1, 1, 1, 2]
mode = "reflect" mode = "reflect"
input_data = np.random.rand(*input_shape).astype(self.dtype) input_data = np.random.rand(*input_shape).astype(self.dtype)
x = paddle.fluid.data(name="x", shape=input_shape) x = paddle.fluid.data(
name="x", shape=input_shape, dtype=self.dtype
)
result1 = F.pad(x=x, pad=pad, mode=mode, data_format="NCDHW") result1 = F.pad(x=x, pad=pad, mode=mode, data_format="NCDHW")
result2 = F.pad(x=x, pad=pad, mode=mode, data_format="NDHWC") result2 = F.pad(x=x, pad=pad, mode=mode, data_format="NDHWC")
exe = Executor(place) exe = Executor(place)
...@@ -235,7 +239,9 @@ class XPUTestPad3dOp(XPUOpTestWrapper): ...@@ -235,7 +239,9 @@ class XPUTestPad3dOp(XPUOpTestWrapper):
pad = [1, 2, 1, 1, 3, 4] pad = [1, 2, 1, 1, 3, 4]
mode = "replicate" mode = "replicate"
input_data = np.random.rand(*input_shape).astype(self.dtype) input_data = np.random.rand(*input_shape).astype(self.dtype)
x = paddle.fluid.data(name="x", shape=input_shape) x = paddle.fluid.data(
name="x", shape=input_shape, dtype=self.dtype
)
result1 = F.pad(x=x, pad=pad, mode=mode, data_format="NCDHW") result1 = F.pad(x=x, pad=pad, mode=mode, data_format="NCDHW")
result2 = F.pad(x=x, pad=pad, mode=mode, data_format="NDHWC") result2 = F.pad(x=x, pad=pad, mode=mode, data_format="NDHWC")
exe = Executor(place) exe = Executor(place)
...@@ -320,6 +326,9 @@ class XPUTestPad3dOp(XPUOpTestWrapper): ...@@ -320,6 +326,9 @@ class XPUTestPad3dOp(XPUOpTestWrapper):
self.check_static_result_3(place=place) self.check_static_result_3(place=place)
def test_dygraph_1(self): def test_dygraph_1(self):
# TODO: remove fp16 limit after support of pad op
if self.dtype == np.float16:
return
paddle.disable_static() paddle.disable_static()
input_shape = (1, 2, 3, 4, 5) input_shape = (1, 2, 3, 4, 5)
pad = [1, 2, 1, 1, 3, 4] pad = [1, 2, 1, 1, 3, 4]
...@@ -365,6 +374,9 @@ class XPUTestPad3dOp(XPUOpTestWrapper): ...@@ -365,6 +374,9 @@ class XPUTestPad3dOp(XPUOpTestWrapper):
np.testing.assert_allclose(y3.numpy(), np_out3, rtol=1e-05) np.testing.assert_allclose(y3.numpy(), np_out3, rtol=1e-05)
def test_dygraph_2(self): def test_dygraph_2(self):
# TODO: remove fp16 limit after support of pad op
if self.dtype == np.float16:
return
paddle.disable_static() paddle.disable_static()
input_shape = (2, 3, 4, 5) input_shape = (2, 3, 4, 5)
pad = [1, 1, 3, 4] pad = [1, 1, 3, 4]
...@@ -412,6 +424,9 @@ class XPUTestPad3dOp(XPUOpTestWrapper): ...@@ -412,6 +424,9 @@ class XPUTestPad3dOp(XPUOpTestWrapper):
np.testing.assert_allclose(y3.numpy(), np_out3, rtol=1e-05) np.testing.assert_allclose(y3.numpy(), np_out3, rtol=1e-05)
def test_dygraph_3(self): def test_dygraph_3(self):
# TODO: remove fp16 limit after support of pad op
if self.dtype == np.float16:
return
paddle.disable_static() paddle.disable_static()
input_shape = (3, 4, 5) input_shape = (3, 4, 5)
pad = [3, 4] pad = [3, 4]
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册