diff --git a/python/paddle/fluid/tests/custom_op/custom_linear_op.cc b/python/paddle/fluid/tests/custom_op/custom_linear_op.cc
index a561c845aba2b59a05a50da2a312c744c580f043..ebfaaecd4909348058c7430aa6bf2b08d162eebb 100644
--- a/python/paddle/fluid/tests/custom_op/custom_linear_op.cc
+++ b/python/paddle/fluid/tests/custom_op/custom_linear_op.cc
@@ -23,6 +23,16 @@ std::vector<paddle::Tensor> PhiLinearForward(const paddle::Tensor& x,
   return {paddle::add(paddle::matmul(x, weight), bias)};
 }
 
+std::vector<paddle::Tensor> PhiLinearBackward(const paddle::Tensor& x,
+                                              const paddle::Tensor& weight,
+                                              const paddle::Tensor& bias,
+                                              const paddle::Tensor& out_grad) {
+  auto x_grad = paddle::matmul(out_grad, weight, false, true);
+  auto weight_grad = paddle::matmul(x, out_grad, true, false);
+  auto bias_grad = paddle::experimental::sum(out_grad, {0});
+  return {x_grad, weight_grad, bias_grad};
+}
+
 std::vector<std::vector<int64_t>> LinearInferShape(
     const std::vector<int64_t>& x_shape,
     const std::vector<int64_t>& weight_shape,
@@ -86,9 +96,14 @@ std::vector<paddle::DataType> LinearInferDtype(
   return {x_dtype};
 }
 
-PD_BUILD_OP(pten_linear)
+PD_BUILD_OP(phi_linear)
     .Inputs({"X", "Weight", "Bias"})
     .Outputs({"Out"})
     .SetKernelFn(PD_KERNEL(PhiLinearForward))
     .SetInferShapeFn(PD_INFER_SHAPE(LinearInferShape))
     .SetInferDtypeFn(PD_INFER_DTYPE(LinearInferDtype));
+
+PD_BUILD_GRAD_OP(phi_linear)
+    .Inputs({"X", "Weight", "Bias", paddle::Grad("Out")})
+    .Outputs({paddle::Grad("X"), paddle::Grad("Weight"), paddle::Grad("Bias")})
+    .SetKernelFn(PD_KERNEL(PhiLinearBackward));
diff --git a/python/paddle/fluid/tests/custom_op/test_custom_linear.py b/python/paddle/fluid/tests/custom_op/test_custom_linear.py
index be49513da35dd1e27767989cec3ac3f78dc23417..fba512d511c36976f91ea224be94022d6d6038da 100644
--- a/python/paddle/fluid/tests/custom_op/test_custom_linear.py
+++ b/python/paddle/fluid/tests/custom_op/test_custom_linear.py
@@ -40,43 +40,56 @@ custom_ops = load(
     verbose=True)
 
 
-def linear_dynamic(func, dtype, np_x, np_weight, np_bias):
-    paddle.set_device("cpu")
-    x = paddle.to_tensor(np_x, dtype=dtype)
-    weight = paddle.to_tensor(np_weight, dtype=dtype)
-    bias = paddle.to_tensor(np_bias, dtype=dtype)
+def linear_dynamic(func, device, dtype, np_x, np_weight, np_bias):
+    paddle.set_device(device)
+    x = paddle.to_tensor(np_x, dtype=dtype, stop_gradient=False)
+    weight = paddle.to_tensor(np_weight, dtype=dtype, stop_gradient=False)
+    bias = paddle.to_tensor(np_bias, dtype=dtype, stop_gradient=False)
     out = func(x, weight, bias)
-    return out.numpy()
+    out.backward()
+    return out.numpy(), x.grad.numpy(), weight.grad.numpy(), bias.grad.numpy()
 
 
-def linear_static(func, dtype, np_x, np_weight, np_bias):
+def linear_static(func, device, dtype, np_x, np_weight, np_bias):
     paddle.enable_static()
-    paddle.set_device("cpu")
+    paddle.set_device(device)
     with static.scope_guard(static.Scope()):
         with static.program_guard(static.Program()):
-            x = static.data(name="x", shape=np_x.shape, dtype=dtype)
+            x = static.data(name="x", shape=[None, np_x.shape[1]], dtype=dtype)
             weight = static.data(
                 name="weight", shape=np_weight.shape, dtype=dtype)
             bias = static.data(name="bias", shape=np_bias.shape, dtype=dtype)
+            x.stop_gradient = False
+            weight.stop_gradient = False
+            bias.stop_gradient = False
             out = func(x, weight, bias)
+            mean_out = paddle.mean(out)
+            static.append_backward(mean_out)
 
             exe = static.Executor()
             exe.run(static.default_startup_program())
 
-            out_v, = exe.run(static.default_main_program(),
-                             feed={
-                                 "x": np_x.astype(dtype),
-                                 "weight": np_weight.astype(dtype),
-                                 "bias": np_bias.astype(dtype)
-                             },
-                             fetch_list=[out.name])
+            out_v, x_grad_v, weight_grad_v, bias_grad_v = exe.run(
+                static.default_main_program(),
+                feed={
+                    "x": np_x.astype(dtype),
+                    "weight": np_weight.astype(dtype),
+                    "bias": np_bias.astype(dtype)
+                },
+                fetch_list=[
+                    out.name, x.name + "@GRAD", weight.name + "@GRAD",
+                    bias.name + "@GRAD"
+                ])
     paddle.disable_static()
-    return out_v
+    return out_v, x_grad_v, weight_grad_v, bias_grad_v
 
 
 class TestCustomLinearJit(unittest.TestCase):
     def setUp(self):
         self.dtypes = ['float32', 'float64']
+        self.devices = ['cpu']
+        if paddle.is_compiled_with_cuda():
+            self.devices.append('gpu')
         self.np_x = np.random.random((3, 2)).astype("float32")
         self.np_weight = np.full([2, 4], fill_value=0.5, dtype="float32")
         self.np_bias = np.ones([4], dtype="float32")
@@ -88,20 +101,34 @@ class TestCustomLinearJit(unittest.TestCase):
                                                           pd_out))
 
     def test_static(self):
-        for dtype in self.dtypes:
-            pten_out = linear_static(custom_ops.pten_linear, dtype, self.np_x,
-                                     self.np_weight, self.np_bias)
-            pd_out = linear_static(F.linear, dtype, self.np_x, self.np_weight,
-                                   self.np_bias)
-            self.check_output(pten_out, pd_out, "pten_out")
+        for device in self.devices:
+            for dtype in self.dtypes:
+                phi_out, phi_x_grad, phi_weight_grad, phi_bias_grad = linear_static(
+                    custom_ops.phi_linear, device, dtype, self.np_x,
+                    self.np_weight, self.np_bias)
+                pd_out, pd_x_grad, pd_weight_grad, pd_bias_grad = linear_static(
+                    F.linear, device, dtype, self.np_x, self.np_weight,
+                    self.np_bias)
+                self.check_output(phi_out, pd_out, "out")
+                self.check_output(phi_x_grad, pd_x_grad, "x_grad")
+                self.check_output(phi_weight_grad, pd_weight_grad,
+                                  "weight_grad")
+                self.check_output(phi_bias_grad, pd_bias_grad, "bias_grad")
 
     def func_dynamic(self):
-        for dtype in self.dtypes:
-            pten_out = linear_dynamic(custom_ops.pten_linear, dtype, self.np_x,
-                                      self.np_weight, self.np_bias)
-            pd_out = linear_dynamic(F.linear, dtype, self.np_x, self.np_weight,
-                                    self.np_bias)
-            self.check_output(pten_out, pd_out, "pten_out")
+        for device in self.devices:
+            for dtype in self.dtypes:
+                phi_out, phi_x_grad, phi_weight_grad, phi_bias_grad = linear_dynamic(
+                    custom_ops.phi_linear, device, dtype, self.np_x,
+                    self.np_weight, self.np_bias)
+                pd_out, pd_x_grad, pd_weight_grad, pd_bias_grad = linear_dynamic(
+                    F.linear, device, dtype, self.np_x, self.np_weight,
+                    self.np_bias)
+                self.check_output(phi_out, pd_out, "phi_out")
+                self.check_output(phi_x_grad, pd_x_grad, "x_grad")
+                self.check_output(phi_weight_grad, pd_weight_grad,
+                                  "weight_grad")
+                self.check_output(phi_bias_grad, pd_bias_grad, "bias_grad")
 
     def test_dynamic(self):
         with _test_eager_guard():
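As a quick sanity check on the gradient formulas implemented by PhiLinearBackward, the following minimal NumPy sketch (illustration only, not part of the change set) mirrors the test fixture shapes and assumes out = x @ weight + bias with an all-ones upstream gradient, which is effectively what out.backward() supplies for the non-scalar out in the dynamic test:

import numpy as np

x = np.random.random((3, 2)).astype("float32")   # matches self.np_x
weight = np.full([2, 4], 0.5, dtype="float32")   # matches self.np_weight
bias = np.ones([4], dtype="float32")             # matches self.np_bias

out = x @ weight + bias           # forward: phi_linear / F.linear
out_grad = np.ones_like(out)      # upstream gradient

# Same three formulas as PhiLinearBackward:
x_grad = out_grad @ weight.T      # paddle::matmul(out_grad, weight, false, true)
weight_grad = x.T @ out_grad      # paddle::matmul(x, out_grad, true, false)
bias_grad = out_grad.sum(axis=0)  # paddle::experimental::sum(out_grad, {0})

assert x_grad.shape == x.shape and weight_grad.shape == weight.shape
assert bias_grad.shape == bias.shape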