diff --git a/test/custom_op/custom_linear_op.cc b/test/custom_op/custom_linear_op.cc index ebfaaecd4909348058c7430aa6bf2b08d162eebb..fb7aac40b4f41070c5100378cf1da16b0215a533 100644 --- a/test/custom_op/custom_linear_op.cc +++ b/test/custom_op/custom_linear_op.cc @@ -17,16 +17,17 @@ limitations under the License. */ #include "paddle/extension.h" // The linear implemented here must be passed in bias -std::vector PhiLinearForward(const paddle::Tensor& x, - const paddle::Tensor& weight, - const paddle::Tensor& bias) { +std::vector CustomLinearForward(const paddle::Tensor& x, + const paddle::Tensor& weight, + const paddle::Tensor& bias) { return {paddle::add(paddle::matmul(x, weight), bias)}; } -std::vector PhiLinearBackward(const paddle::Tensor& x, - const paddle::Tensor& weight, - const paddle::Tensor& bias, - const paddle::Tensor& out_grad) { +std::vector CustomLinearBackward( + const paddle::Tensor& x, + const paddle::Tensor& weight, + const paddle::Tensor& bias, + const paddle::Tensor& out_grad) { auto x_grad = paddle::matmul(out_grad, weight, false, true); auto weight_grad = paddle::matmul(x, out_grad, true, false); auto bias_grad = paddle::experimental::sum(out_grad, {0}); @@ -96,14 +97,14 @@ std::vector LinearInferDtype( return {x_dtype}; } -PD_BUILD_OP(phi_linear) +PD_BUILD_OP(custom_linear) .Inputs({"X", "Weight", "Bias"}) .Outputs({"Out"}) - .SetKernelFn(PD_KERNEL(PhiLinearForward)) + .SetKernelFn(PD_KERNEL(CustomLinearForward)) .SetInferShapeFn(PD_INFER_SHAPE(LinearInferShape)) .SetInferDtypeFn(PD_INFER_DTYPE(LinearInferDtype)); -PD_BUILD_GRAD_OP(phi_linear) +PD_BUILD_GRAD_OP(custom_linear) .Inputs({"X", "Weight", "Bias", paddle::Grad("Out")}) .Outputs({paddle::Grad("X"), paddle::Grad("Weight"), paddle::Grad("Bias")}) - .SetKernelFn(PD_KERNEL(PhiLinearBackward)); + .SetKernelFn(PD_KERNEL(CustomLinearBackward)); diff --git a/test/custom_op/test_custom_inplace.py b/test/custom_op/test_custom_inplace.py index bdfe018c40f672b8ed61714d80efdb6a72e85175..2c0a5d4c513c18476e4c3c8c7823d7cac54ca70d 100644 --- a/test/custom_op/test_custom_inplace.py +++ b/test/custom_op/test_custom_inplace.py @@ -41,11 +41,11 @@ custom_inplace = load( ) -def inplace_dynamic_add(phi_func, device, dtype, np_x, np_y): +def inplace_dynamic_add(custom_func, device, dtype, np_x, np_y): paddle.set_device(device) x = paddle.to_tensor(np_x, dtype=dtype, stop_gradient=True) y = paddle.to_tensor(np_y, dtype=dtype, stop_gradient=False) - if phi_func: + if custom_func: out = custom_inplace.custom_add(x, y) else: out = x.add_(y) @@ -88,14 +88,14 @@ def inplace_static_add(func, device, dtype, np_x, np_y): return x_v, out_v, x_grad_v, y_grad_v, out_grad_v -def inplace_dynamic_add_vector(phi_func, device, dtype, np_inputs, np_y): +def inplace_dynamic_add_vector(custom_func, device, dtype, np_inputs, np_y): paddle.set_device(device) inputs = [ paddle.to_tensor(np_input, dtype=dtype, stop_gradient=True) for np_input in np_inputs ] y = paddle.to_tensor(np_y, dtype=dtype, stop_gradient=False) - if phi_func: + if custom_func: out = custom_inplace.custom_add_vec(inputs, y) else: out = [x.add_(y) for x in inputs] @@ -111,7 +111,7 @@ def inplace_dynamic_add_vector(phi_func, device, dtype, np_inputs, np_y): ) -def inplace_static_add_vector(phi_func, device, dtype, np_inputs, np_y): +def inplace_static_add_vector(custom_func, device, dtype, np_inputs, np_y): paddle.enable_static() paddle.set_device(device) with static.scope_guard(static.Scope()): @@ -126,7 +126,7 @@ def inplace_static_add_vector(phi_func, device, dtype, np_inputs, np_y): x1.stop_gradient = False x2.stop_gradient = False y.stop_gradient = False - if phi_func: + if custom_func: out = custom_inplace.custom_add_vec([x1, x2], y) else: out = [paddle.add(x1, y), paddle.add(x2, y)] @@ -170,13 +170,13 @@ def inplace_static_add_vector(phi_func, device, dtype, np_inputs, np_y): ) -def inplace_dynamic_relu_net(phi_func, device, dtype, np_x, np_y, np_z): +def inplace_dynamic_relu_net(custom_func, device, dtype, np_x, np_y, np_z): paddle.set_device(device) x = paddle.to_tensor(np_x, dtype=dtype, stop_gradient=False) y = paddle.to_tensor(np_y, dtype=dtype, stop_gradient=False) z = paddle.to_tensor(np_z, dtype=dtype, stop_gradient=False) out_xy = x + y - if phi_func: + if custom_func: out_xy = custom_inplace.custom_relu_inplace(out_xy) out_xyz = out_xy + z out = custom_inplace.custom_relu_inplace(out_xyz) @@ -229,13 +229,13 @@ def inplace_static_relu_net(func, device, dtype, np_x, np_y, np_z): return x_v, y_v, out_v, x_grad_v, y_grad_v -def dynamic_multi_inplace(phi_func, device, dtype, np_x, np_y, np_a, np_b): +def dynamic_multi_inplace(custom_func, device, dtype, np_x, np_y, np_a, np_b): paddle.set_device(device) x = paddle.to_tensor(np_x, dtype=dtype, stop_gradient=True) y = paddle.to_tensor(np_y, dtype=dtype, stop_gradient=False) a = paddle.to_tensor(np_a, dtype=dtype, stop_gradient=True) b = paddle.to_tensor(np_b, dtype=dtype, stop_gradient=False) - if phi_func: + if custom_func: out_xy, out_ab = custom_inplace.custom_multi_inplace(x, y, a, b) else: out_xy = x.add_(y) @@ -257,7 +257,7 @@ def dynamic_multi_inplace(phi_func, device, dtype, np_x, np_y, np_a, np_b): ) -def static_multi_inplace(phi_func, device, dtype, np_x, np_y, np_a, np_b): +def static_multi_inplace(custom_func, device, dtype, np_x, np_y, np_a, np_b): paddle.enable_static() paddle.set_device(device) with static.scope_guard(static.Scope()): @@ -270,7 +270,7 @@ def static_multi_inplace(phi_func, device, dtype, np_x, np_y, np_a, np_b): y.stop_gradient = False a.stop_gradient = False b.stop_gradient = False - if phi_func: + if custom_func: out_xy, out_ab = custom_inplace.custom_multi_inplace(x, y, a, b) else: out_xy = paddle.add(x, y) @@ -379,11 +379,11 @@ class TestCustomInplaceJit(unittest.TestCase): self.np_y, ) ( - phi_x, - phi_out, - phi_x_grad, - phi_y_grad, - phi_out_grad, + custom_x, + custom_out, + custom_x_grad, + custom_y_grad, + custom_out_grad, ) = inplace_static_add( custom_inplace.custom_add, device, @@ -391,15 +391,15 @@ class TestCustomInplaceJit(unittest.TestCase): self.np_x, self.np_y, ) - self.check_output(phi_x, phi_out, "inplace_phi_x") + self.check_output(custom_x, custom_out, "inplace_custom_x") self.check_output( - phi_x_grad, phi_out_grad, "inplace_phi_x_grad" + custom_x_grad, custom_out_grad, "inplace_custom_x_grad" ) - self.check_output(phi_out, pd_out, "out") - self.check_output(phi_x_grad, pd_x_grad, "x_grad") - self.check_output(phi_y_grad, pd_y_grad, "y_grad") - self.check_output(phi_out_grad, pd_out_grad, "out_grad") + self.check_output(custom_out, pd_out, "out") + self.check_output(custom_x_grad, pd_x_grad, "x_grad") + self.check_output(custom_y_grad, pd_y_grad, "y_grad") + self.check_output(custom_out_grad, pd_out_grad, "out_grad") def test_dynamic_add(self): for device in self.devices: @@ -418,11 +418,11 @@ class TestCustomInplaceJit(unittest.TestCase): self.np_y, ) ( - phi_x, - phi_y, - phi_out, - phi_x_grad, - phi_y_grad, + custom_x, + custom_y, + custom_out, + custom_x_grad, + custom_y_grad, ) = inplace_dynamic_add( True, device, @@ -431,14 +431,14 @@ class TestCustomInplaceJit(unittest.TestCase): self.np_y, ) - self.check_output(phi_x, phi_out, "inplace_phi_x") + self.check_output(custom_x, custom_out, "inplace_custom_x") self.check_output(pd_x, pd_out, "inplace_pd_x") - self.check_output(phi_x, pd_x, "x") - self.check_output(phi_y, pd_y, "y") - self.check_output(phi_out, pd_out, "out") - self.check_output(phi_x_grad, pd_x_grad, "x_grad") - self.check_output(phi_y_grad, pd_y_grad, "y_grad") + self.check_output(custom_x, pd_x, "x") + self.check_output(custom_y, pd_y, "y") + self.check_output(custom_out, pd_out, "out") + self.check_output(custom_x_grad, pd_x_grad, "x_grad") + self.check_output(custom_y_grad, pd_y_grad, "y_grad") def test_static_add_vector(self): for device in self.devices: @@ -456,10 +456,10 @@ class TestCustomInplaceJit(unittest.TestCase): self.np_y, ) ( - phi_out, - phi_x_grad, - phi_y_grad, - phi_out_grad, + custom_out, + custom_x_grad, + custom_y_grad, + custom_out_grad, ) = inplace_static_add_vector( False, device, @@ -468,10 +468,10 @@ class TestCustomInplaceJit(unittest.TestCase): self.np_y, ) - self.check_output(phi_out, pd_out, "out") - self.check_output(phi_x_grad, pd_x_grad, "x_grad") - self.check_output(phi_y_grad, pd_y_grad, "y_grad") - self.check_output(phi_out_grad, pd_out_grad, "out_grad") + self.check_output(custom_out, pd_out, "out") + self.check_output(custom_x_grad, pd_x_grad, "x_grad") + self.check_output(custom_y_grad, pd_y_grad, "y_grad") + self.check_output(custom_out_grad, pd_out_grad, "out_grad") def test_dynamic_add_vector(self): for device in self.devices: @@ -490,11 +490,11 @@ class TestCustomInplaceJit(unittest.TestCase): self.np_y, ) ( - phi_x, - phi_y, - phi_out, - phi_x_grad, - phi_y_grad, + custom_x, + custom_y, + custom_out, + custom_x_grad, + custom_y_grad, ) = inplace_dynamic_add_vector( False, device, @@ -503,14 +503,14 @@ class TestCustomInplaceJit(unittest.TestCase): self.np_y, ) - self.check_output(phi_x, phi_out, "inplace_phi_x") + self.check_output(custom_x, custom_out, "inplace_custom_x") self.check_output(pd_x, pd_out, "inplace_pd_x") - self.check_output(phi_x, pd_x, "x") - self.check_output(phi_y, pd_y, "y") - self.check_output(phi_out, pd_out, "out") - self.check_output(phi_x_grad, pd_x_grad, "x_grad") - self.check_output(phi_y_grad, pd_y_grad, "y_grad") + self.check_output(custom_x, pd_x, "x") + self.check_output(custom_y, pd_y, "y") + self.check_output(custom_out, pd_out, "out") + self.check_output(custom_x_grad, pd_x_grad, "x_grad") + self.check_output(custom_y_grad, pd_y_grad, "y_grad") def test_static_relu_net(self): for device in self.devices: @@ -530,11 +530,11 @@ class TestCustomInplaceJit(unittest.TestCase): self.np_z, ) ( - phi_x, - phi_y, - phi_out, - phi_x_grad, - phi_y_grad, + custom_x, + custom_y, + custom_out, + custom_x_grad, + custom_y_grad, ) = inplace_static_relu_net( custom_inplace.custom_relu_inplace, device, @@ -543,11 +543,11 @@ class TestCustomInplaceJit(unittest.TestCase): self.np_y, self.np_z, ) - self.check_output_allclose(phi_x, pd_x, "x") - self.check_output_allclose(phi_y, pd_y, "y") - self.check_output_allclose(phi_out, pd_out, "out") - self.check_output_allclose(phi_x_grad, pd_x_grad, "x_grad") - self.check_output_allclose(phi_y_grad, pd_y_grad, "y_grad") + self.check_output_allclose(custom_x, pd_x, "x") + self.check_output_allclose(custom_y, pd_y, "y") + self.check_output_allclose(custom_out, pd_out, "out") + self.check_output_allclose(custom_x_grad, pd_x_grad, "x_grad") + self.check_output_allclose(custom_y_grad, pd_y_grad, "y_grad") def test_dynamic_relu_net(self): for device in self.devices: @@ -567,11 +567,11 @@ class TestCustomInplaceJit(unittest.TestCase): self.np_z, ) ( - phi_x, - phi_y, - phi_out, - phi_x_grad, - phi_y_grad, + custom_x, + custom_y, + custom_out, + custom_x_grad, + custom_y_grad, ) = inplace_dynamic_relu_net( True, device, @@ -581,11 +581,11 @@ class TestCustomInplaceJit(unittest.TestCase): self.np_z, ) - self.check_output(phi_x, pd_x, "x") - self.check_output(phi_y, pd_y, "y") - self.check_output(phi_out, pd_out, "out") - self.check_output(phi_x_grad, pd_x_grad, "x_grad") - self.check_output(phi_y_grad, pd_y_grad, "y_grad") + self.check_output(custom_x, pd_x, "x") + self.check_output(custom_y, pd_y, "y") + self.check_output(custom_out, pd_out, "out") + self.check_output(custom_x_grad, pd_x_grad, "x_grad") + self.check_output(custom_y_grad, pd_y_grad, "y_grad") def test_static_multi_inplace(self): for device in self.devices: @@ -611,16 +611,16 @@ class TestCustomInplaceJit(unittest.TestCase): self.np_b, ) ( - phi_x, - phi_out_xy, - phi_x_grad, - phi_y_grad, - phi_out_xy_grad, - phi_a, - phi_out_ab, - phi_a_grad, - phi_b_grad, - phi_out_ab_grad, + custom_x, + custom_out_xy, + custom_x_grad, + custom_y_grad, + custom_out_xy_grad, + custom_a, + custom_out_ab, + custom_a_grad, + custom_b_grad, + custom_out_ab_grad, ) = static_multi_inplace( True, device, @@ -630,23 +630,27 @@ class TestCustomInplaceJit(unittest.TestCase): self.np_a, self.np_b, ) - self.check_output(phi_x, pd_out_xy, "inplace_phi_x") + self.check_output(custom_x, pd_out_xy, "inplace_custom_x") self.check_output( - phi_x_grad, phi_out_xy_grad, "inplace_phi_x_grad" + custom_x_grad, custom_out_xy_grad, "inplace_custom_x_grad" ) - self.check_output(phi_a, pd_out_ab, "inplace_phi_a") + self.check_output(custom_a, pd_out_ab, "inplace_custom_a") self.check_output( - phi_a_grad, phi_out_ab_grad, "inplace_phi_a_grad" + custom_a_grad, custom_out_ab_grad, "inplace_custom_a_grad" ) - self.check_output(phi_out_xy, pd_out_xy, "outxy") - self.check_output(phi_x_grad, pd_x_grad, "x_grad") - self.check_output(phi_y_grad, pd_y_grad, "y_grad") - self.check_output(phi_out_xy_grad, pd_out_xy_grad, "outxy_grad") - self.check_output(phi_out_ab, pd_out_ab, "outab") - self.check_output(phi_a_grad, pd_a_grad, "a_grad") - self.check_output(phi_b_grad, pd_b_grad, "b_grad") - self.check_output(phi_out_ab_grad, pd_out_ab_grad, "outab_grad") + self.check_output(custom_out_xy, pd_out_xy, "outxy") + self.check_output(custom_x_grad, pd_x_grad, "x_grad") + self.check_output(custom_y_grad, pd_y_grad, "y_grad") + self.check_output( + custom_out_xy_grad, pd_out_xy_grad, "outxy_grad" + ) + self.check_output(custom_out_ab, pd_out_ab, "outab") + self.check_output(custom_a_grad, pd_a_grad, "a_grad") + self.check_output(custom_b_grad, pd_b_grad, "b_grad") + self.check_output( + custom_out_ab_grad, pd_out_ab_grad, "outab_grad" + ) def test_dynamic_multi_inplace(self): for device in self.devices: @@ -672,16 +676,16 @@ class TestCustomInplaceJit(unittest.TestCase): self.np_b, ) ( - phi_x, - phi_y, - phi_out_xy, - phi_x_grad, - phi_y_grad, - phi_a, - phi_b, - phi_out_ab, - phi_a_grad, - phi_b_grad, + custom_x, + custom_y, + custom_out_xy, + custom_x_grad, + custom_y_grad, + custom_a, + custom_b, + custom_out_ab, + custom_a_grad, + custom_b_grad, ) = dynamic_multi_inplace( True, device, @@ -692,21 +696,21 @@ class TestCustomInplaceJit(unittest.TestCase): self.np_b, ) - self.check_output(phi_x, phi_out_xy, "inplace_phi_x") + self.check_output(custom_x, custom_out_xy, "inplace_custom_x") self.check_output(pd_x, pd_out_xy, "inplace_pd_x") - self.check_output(phi_a, phi_out_ab, "inplace_phi_a") + self.check_output(custom_a, custom_out_ab, "inplace_custom_a") self.check_output(pd_a, pd_out_ab, "inplace_pd_a") - self.check_output(phi_x, pd_x, "x") - self.check_output(phi_y, pd_y, "y") - self.check_output(phi_out_xy, pd_out_xy, "outxy") - self.check_output(phi_x_grad, pd_x_grad, "x_grad") - self.check_output(phi_y_grad, pd_y_grad, "y_grad") - self.check_output(phi_a, pd_a, "a") - self.check_output(phi_b, pd_b, "b") - self.check_output(phi_out_ab, pd_out_ab, "outab") - self.check_output(phi_a_grad, pd_a_grad, "a_grad") - self.check_output(phi_b_grad, pd_b_grad, "b_grad") + self.check_output(custom_x, pd_x, "x") + self.check_output(custom_y, pd_y, "y") + self.check_output(custom_out_xy, pd_out_xy, "outxy") + self.check_output(custom_x_grad, pd_x_grad, "x_grad") + self.check_output(custom_y_grad, pd_y_grad, "y_grad") + self.check_output(custom_a, pd_a, "a") + self.check_output(custom_b, pd_b, "b") + self.check_output(custom_out_ab, pd_out_ab, "outab") + self.check_output(custom_a_grad, pd_a_grad, "a_grad") + self.check_output(custom_b_grad, pd_b_grad, "b_grad") if __name__ == "__main__": diff --git a/test/custom_op/test_custom_linear.py b/test/custom_op/test_custom_linear.py index 5d2a55456d7d23c60391f9f184abd0fed69bfe5d..5cd4b5e14f7dd5aad7667a6775ef12c5f3a018c0 100644 --- a/test/custom_op/test_custom_linear.py +++ b/test/custom_op/test_custom_linear.py @@ -112,12 +112,12 @@ class TestCustomLinearJit(unittest.TestCase): for device in self.devices: for dtype in self.dtypes: ( - phi_out, - phi_x_grad, - phi_weight_grad, - phi_bias_grad, + custom_out, + custom_x_grad, + custom_weight_grad, + custom_bias_grad, ) = linear_static( - custom_ops.phi_linear, + custom_ops.custom_linear, device, dtype, self.np_x, @@ -132,23 +132,23 @@ class TestCustomLinearJit(unittest.TestCase): self.np_weight, self.np_bias, ) - self.check_output(phi_out, pd_out, "out") - self.check_output(phi_x_grad, pd_x_grad, "x_grad") + self.check_output(custom_out, pd_out, "out") + self.check_output(custom_x_grad, pd_x_grad, "x_grad") self.check_output( - phi_weight_grad, pd_weight_grad, "weight_grad" + custom_weight_grad, pd_weight_grad, "weight_grad" ) - self.check_output(phi_bias_grad, pd_bias_grad, "bias_grad") + self.check_output(custom_bias_grad, pd_bias_grad, "bias_grad") def test_dynamic(self): for device in self.devices: for dtype in self.dtypes: ( - phi_out, - phi_x_grad, - phi_weight_grad, - phi_bias_grad, + custom_out, + custom_x_grad, + custom_weight_grad, + custom_bias_grad, ) = linear_dynamic( - custom_ops.phi_linear, + custom_ops.custom_linear, device, dtype, self.np_x, @@ -168,12 +168,12 @@ class TestCustomLinearJit(unittest.TestCase): self.np_weight, self.np_bias, ) - self.check_output(phi_out, pd_out, "phi_out") - self.check_output(phi_x_grad, pd_x_grad, "x_grad") + self.check_output(custom_out, pd_out, "custom_out") + self.check_output(custom_x_grad, pd_x_grad, "x_grad") self.check_output( - phi_weight_grad, pd_weight_grad, "weight_grad" + custom_weight_grad, pd_weight_grad, "weight_grad" ) - self.check_output(phi_bias_grad, pd_bias_grad, "bias_grad") + self.check_output(custom_bias_grad, pd_bias_grad, "bias_grad") if __name__ == "__main__": diff --git a/test/custom_op/test_multi_out_jit.py b/test/custom_op/test_multi_out_jit.py index 9b652a0efccae187de8840ecd59a61441a1700c7..f3e3a6ec8abc138580d86f3b42361416b171b510 100644 --- a/test/custom_op/test_multi_out_jit.py +++ b/test/custom_op/test_multi_out_jit.py @@ -40,13 +40,13 @@ multi_out_module = load( ) -def discrete_out_dynamic(use_phi, device, dtype, np_w, np_x, np_y, np_z): +def discrete_out_dynamic(use_custom, device, dtype, np_w, np_x, np_y, np_z): paddle.set_device(device) w = paddle.to_tensor(np_w, dtype=dtype, stop_gradient=False) x = paddle.to_tensor(np_x, dtype=dtype, stop_gradient=False) y = paddle.to_tensor(np_y, dtype=dtype, stop_gradient=False) z = paddle.to_tensor(np_z, dtype=dtype, stop_gradient=False) - if use_phi: + if use_custom: out = multi_out_module.discrete_out(w, x, y, z) else: out = w * 1 + x * 2 + y * 3 + z * 4 @@ -55,7 +55,7 @@ def discrete_out_dynamic(use_phi, device, dtype, np_w, np_x, np_y, np_z): return out.numpy(), w.grad.numpy(), y.grad.numpy() -def discrete_out_static(use_phi, device, dtype, np_w, np_x, np_y, np_z): +def discrete_out_static(use_custom, device, dtype, np_w, np_x, np_y, np_z): paddle.enable_static() paddle.set_device(device) with static.scope_guard(static.Scope()): @@ -68,7 +68,7 @@ def discrete_out_static(use_phi, device, dtype, np_w, np_x, np_y, np_z): x.stop_gradient = False y.stop_gradient = False z.stop_gradient = False - if use_phi: + if use_custom: out = multi_out_module.discrete_out(w, x, y, z) else: out = w * 1 + x * 2 + y * 3 + z * 4 @@ -180,7 +180,11 @@ class TestMultiOutputDtypes(unittest.TestCase): self.np_y, self.np_z, ) - (phi_out, phi_w_grad, phi_y_grad,) = discrete_out_static( + ( + custom_out, + custom_w_grad, + custom_y_grad, + ) = discrete_out_static( True, device, dtype, @@ -189,10 +193,10 @@ class TestMultiOutputDtypes(unittest.TestCase): self.np_y, self.np_z, ) - self.check_output(phi_out, pd_out, "out") + self.check_output(custom_out, pd_out, "out") # NOTE: In static mode, the output gradient of custom operator has been optimized to shape=[1]. However, native paddle op's output shape = [4, 8], hence we need to fetch pd_w_grad[0][0] (By the way, something wrong with native paddle's gradient, the outputs with other indexes instead of pd_w_grad[0][0] is undefined in this unittest.) - self.check_output(phi_w_grad, pd_w_grad[0][0], "w_grad") - self.check_output(phi_y_grad, pd_y_grad[0][0], "y_grad") + self.check_output(custom_w_grad, pd_w_grad[0][0], "w_grad") + self.check_output(custom_y_grad, pd_y_grad[0][0], "y_grad") def test_discrete_out_dynamic(self): for device in self.devices: @@ -206,7 +210,11 @@ class TestMultiOutputDtypes(unittest.TestCase): self.np_y, self.np_z, ) - (phi_out, phi_w_grad, phi_y_grad,) = discrete_out_dynamic( + ( + custom_out, + custom_w_grad, + custom_y_grad, + ) = discrete_out_dynamic( True, device, dtype, @@ -215,9 +223,9 @@ class TestMultiOutputDtypes(unittest.TestCase): self.np_y, self.np_z, ) - self.check_output(phi_out, pd_out, "out") - self.check_output(phi_w_grad, pd_w_grad, "w_grad") - self.check_output(phi_y_grad, pd_y_grad, "y_grad") + self.check_output(custom_out, pd_out, "out") + self.check_output(custom_w_grad, pd_w_grad, "w_grad") + self.check_output(custom_y_grad, pd_y_grad, "y_grad") if __name__ == '__main__':