From c30e7b81a87b7acb2cf999381e36d0e778ece63d Mon Sep 17 00:00:00 2001 From: HongyuJia Date: Fri, 31 Mar 2023 14:18:01 +0800 Subject: [PATCH] [CustomOP Unittest] Add unit test for outputs with discrete order (#52348) --- test/custom_op/multi_out_test_op.cc | 49 +++++++++++ test/custom_op/test_multi_out_jit.py | 124 ++++++++++++++++++++++++++- 2 files changed, 171 insertions(+), 2 deletions(-) diff --git a/test/custom_op/multi_out_test_op.cc b/test/custom_op/multi_out_test_op.cc index 18940f20e76..d9e0526e420 100644 --- a/test/custom_op/multi_out_test_op.cc +++ b/test/custom_op/multi_out_test_op.cc @@ -65,9 +65,58 @@ std::vector InferDtype(paddle::DataType x_dtype) { return {x_dtype, paddle::DataType::FLOAT64, paddle::DataType::INT32}; } +// out = w * 1 + x * 2 + y * 3 + z * 4 +std::vector DiscreteOutForward(const paddle::Tensor& w, + const paddle::Tensor& x, + const paddle::Tensor& y, + const paddle::Tensor& z) { + paddle::Tensor out = w * 1 + x * 2 + y * 3 + z * 4; + return {out}; +} + +std::vector> DiscreteOutInferShape( + const std::vector& w_shape, + const std::vector& x_shape, + const std::vector& y_shape, + const std::vector& z_shape) { + return {w_shape}; +} + +std::vector DiscreteOutInferDtype( + const paddle::DataType& w_dtype, + const paddle::DataType& x_dtype, + const paddle::DataType& y_dtype, + const paddle::DataType& z_dtype) { + return {w_dtype}; +} + +// w_grad = out_grad +// y_grad = out_grad * 3 +std::vector DiscreteOutBackward( + const paddle::Tensor& w, + const paddle::Tensor& x, + const paddle::Tensor& y, + const paddle::Tensor& z, + const paddle::Tensor& out_grad) { + return {out_grad, out_grad * 3}; +} + PD_BUILD_OP(multi_out) .Inputs({"X"}) .Outputs({"Out", "Fake_float64", "ZFake_int32"}) .SetKernelFn(PD_KERNEL(MultiOutCPU)) .SetInferShapeFn(PD_INFER_SHAPE(InferShape)) .SetInferDtypeFn(PD_INFER_DTYPE(InferDtype)); + +PD_BUILD_OP(discrete_out) + .Inputs({"w", "x", "y", "z"}) + .Outputs({"output"}) + .SetKernelFn(PD_KERNEL(DiscreteOutForward)) + .SetInferShapeFn(PD_INFER_SHAPE(DiscreteOutInferShape)) + .SetInferDtypeFn(PD_INFER_DTYPE(DiscreteOutInferDtype)); + +// Test gradient operator whose output order is discrete. +PD_BUILD_GRAD_OP(discrete_out) + .Inputs({"w", "x", "y", "z", paddle::Grad("output")}) + .Outputs({paddle::Grad("w"), paddle::Grad("y")}) + .SetKernelFn(PD_KERNEL(DiscreteOutBackward)); diff --git a/test/custom_op/test_multi_out_jit.py b/test/custom_op/test_multi_out_jit.py index 8582650a986..9b652a0efcc 100644 --- a/test/custom_op/test_multi_out_jit.py +++ b/test/custom_op/test_multi_out_jit.py @@ -19,6 +19,7 @@ import numpy as np from utils import extra_cc_args, paddle_includes import paddle +from paddle import static from paddle.utils.cpp_extension import get_build_directory, load from paddle.utils.cpp_extension.extension_utils import run_cmd @@ -39,11 +40,79 @@ multi_out_module = load( ) +def discrete_out_dynamic(use_phi, device, dtype, np_w, np_x, np_y, np_z): + paddle.set_device(device) + w = paddle.to_tensor(np_w, dtype=dtype, stop_gradient=False) + x = paddle.to_tensor(np_x, dtype=dtype, stop_gradient=False) + y = paddle.to_tensor(np_y, dtype=dtype, stop_gradient=False) + z = paddle.to_tensor(np_z, dtype=dtype, stop_gradient=False) + if use_phi: + out = multi_out_module.discrete_out(w, x, y, z) + else: + out = w * 1 + x * 2 + y * 3 + z * 4 + + out.backward() + return out.numpy(), w.grad.numpy(), y.grad.numpy() + + +def discrete_out_static(use_phi, device, dtype, np_w, np_x, np_y, np_z): + paddle.enable_static() + paddle.set_device(device) + with static.scope_guard(static.Scope()): + with static.program_guard(static.Program()): + w = static.data(name="w", shape=[None, np_x.shape[1]], dtype=dtype) + x = static.data(name="x", shape=[None, np_x.shape[1]], dtype=dtype) + y = static.data(name="y", shape=[None, np_y.shape[1]], dtype=dtype) + z = static.data(name="z", shape=[None, np_z.shape[1]], dtype=dtype) + w.stop_gradient = False + x.stop_gradient = False + y.stop_gradient = False + z.stop_gradient = False + if use_phi: + out = multi_out_module.discrete_out(w, x, y, z) + else: + out = w * 1 + x * 2 + y * 3 + z * 4 + static.append_backward(out) + + exe = static.Executor() + exe.run(static.default_startup_program()) + + out_v, w_grad_v, y_grad_v = exe.run( + static.default_main_program(), + feed={ + "w": np_w.astype(dtype), + "x": np_x.astype(dtype), + "y": np_y.astype(dtype), + "z": np_z.astype(dtype), + }, + fetch_list=[ + out.name, + w.name + "@GRAD", + y.name + "@GRAD", + ], + ) + paddle.disable_static() + return out_v, w_grad_v, y_grad_v + + class TestMultiOutputDtypes(unittest.TestCase): def setUp(self): self.custom_op = multi_out_module.multi_out self.dtypes = ['float32', 'float64'] self.devices = ['cpu'] + self.np_w = np.random.uniform(-1, 1, [4, 8]).astype("float32") + self.np_x = np.random.uniform(-1, 1, [4, 8]).astype("float32") + self.np_y = np.random.uniform(-1, 1, [4, 8]).astype("float32") + self.np_z = np.random.uniform(-1, 1, [4, 8]).astype("float32") + + def check_output(self, out, pd_out, name): + np.testing.assert_array_equal( + out, + pd_out, + err_msg='custom op {}: {},\n paddle api {}: {}'.format( + name, out, name, pd_out + ), + ) def run_static(self, device, dtype): paddle.set_device(device) @@ -80,7 +149,7 @@ class TestMultiOutputDtypes(unittest.TestCase): one_int32, np.ones([4, 8]).astype('int32') ) - def test_static(self): + def test_multi_out_static(self): paddle.enable_static() for device in self.devices: for dtype in self.dtypes: @@ -88,7 +157,7 @@ class TestMultiOutputDtypes(unittest.TestCase): self.check_multi_outputs(res) paddle.disable_static() - def test_dynamic(self): + def test_multi_out_dynamic(self): for device in self.devices: for dtype in self.dtypes: paddle.set_device(device) @@ -99,6 +168,57 @@ class TestMultiOutputDtypes(unittest.TestCase): self.assertTrue(len(outs) == 3) self.check_multi_outputs(outs, True) + def test_discrete_out_static(self): + for device in self.devices: + for dtype in self.dtypes: + (pd_out, pd_w_grad, pd_y_grad,) = discrete_out_static( + False, + device, + dtype, + self.np_w, + self.np_x, + self.np_y, + self.np_z, + ) + (phi_out, phi_w_grad, phi_y_grad,) = discrete_out_static( + True, + device, + dtype, + self.np_w, + self.np_x, + self.np_y, + self.np_z, + ) + self.check_output(phi_out, pd_out, "out") + # NOTE: In static mode, the output gradient of custom operator has been optimized to shape=[1]. However, native paddle op's output shape = [4, 8], hence we need to fetch pd_w_grad[0][0] (By the way, something wrong with native paddle's gradient, the outputs with other indexes instead of pd_w_grad[0][0] is undefined in this unittest.) + self.check_output(phi_w_grad, pd_w_grad[0][0], "w_grad") + self.check_output(phi_y_grad, pd_y_grad[0][0], "y_grad") + + def test_discrete_out_dynamic(self): + for device in self.devices: + for dtype in self.dtypes: + (pd_out, pd_w_grad, pd_y_grad,) = discrete_out_dynamic( + False, + device, + dtype, + self.np_w, + self.np_x, + self.np_y, + self.np_z, + ) + (phi_out, phi_w_grad, phi_y_grad,) = discrete_out_dynamic( + True, + device, + dtype, + self.np_w, + self.np_x, + self.np_y, + self.np_z, + ) + self.check_output(phi_out, pd_out, "out") + self.check_output(phi_w_grad, pd_w_grad, "w_grad") + self.check_output(phi_y_grad, pd_y_grad, "y_grad") + if __name__ == '__main__': unittest.main() -- GitLab