From 59c9d75e5f9db33e277aacff8fe8db375833d49d Mon Sep 17 00:00:00 2001 From: HongyuJia Date: Mon, 3 Apr 2023 18:45:05 +0800 Subject: [PATCH] [CustomOP Optional Inplace] Custom operator supports inplace optional vector Tensor input (#52421) * [CustomOP Optional Inplace] Custom operator supports inplace optional vector Tensor input * uncomment unittest codes --- paddle/fluid/framework/custom_operator.cc | 6 +- paddle/phi/api/ext/op_meta_info.h | 44 +- paddle/phi/api/lib/op_meta_info.cc | 12 + .../utils/cpp_extension/extension_utils.py | 16 + test/custom_op/custom_optional.cc | 113 ++++++ test/custom_op/test_custom_optional.py | 377 +++++++++++++++--- 6 files changed, 493 insertions(+), 75 deletions(-) diff --git a/paddle/fluid/framework/custom_operator.cc b/paddle/fluid/framework/custom_operator.cc index ceeb2a224d7..641674695ca 100644 --- a/paddle/fluid/framework/custom_operator.cc +++ b/paddle/fluid/framework/custom_operator.cc @@ -833,7 +833,11 @@ static void RunInferDtypeFunc( auto in_name = inplace_reverse_map.at(out_name); // make sure ctx has valid inplace optional outputs if (ctx->HasOutput(out_name)) { - ctx->SetOutputDataTypes(out_name, ctx->GetInputDataTypes(in_name)); + size_t size = ctx->InputSize(in_name); + for (size_t i = 0; i < size; ++i) { + auto dtype = ctx->GetInputDataType(in_name, i); + ctx->SetOutputDataType(out_name, dtype, i); + } } else { PADDLE_ENFORCE( detail::IsOptionalVar(out_name), diff --git a/paddle/phi/api/ext/op_meta_info.h b/paddle/phi/api/ext/op_meta_info.h index 479b701e3fb..4a9a10a53aa 100644 --- a/paddle/phi/api/ext/op_meta_info.h +++ b/paddle/phi/api/ext/op_meta_info.h @@ -120,6 +120,8 @@ class PADDLE_API CustomOpKernelContext { std::vector InputsBetween(size_t start, size_t end) const; Tensor& MutableInputAt(size_t idx); paddle::optional OptionalInputAt(size_t idx); + paddle::optional> OptionalInputsBetween(size_t start, + size_t end); const std::vector& Attrs() const { return attrs_; } const std::vector>& InputRange() { return input_range_; @@ -294,6 +296,34 @@ struct KernelFuncImpl { } }; + // Handle args for inplace vector case + template + struct ComputeCallHelper&, Tail...> { + template + static void Compute(CustomOpKernelContext* ctx, PreviousArgs&... pargs) { + auto& range = ctx->InputRangeAt(in_idx); + auto arg = ctx->InputsBetween(range.first, range.second); + ComputeCallHelper< + Tail...>::template Compute(ctx, + pargs..., + arg); + } + }; + + // Handle args for optional inplace vector case + template + struct ComputeCallHelper>&, Tail...> { + template + static void Compute(CustomOpKernelContext* ctx, PreviousArgs&... pargs) { + auto& range = ctx->InputRangeAt(in_idx); + auto arg = ctx->OptionalInputsBetween(range.first, range.second); + ComputeCallHelper< + Tail...>::template Compute(ctx, + pargs..., + arg); + } + }; + PD_SPECIALIZE_ComputeCallHelper(bool); PD_SPECIALIZE_ComputeCallHelper(int); PD_SPECIALIZE_ComputeCallHelper(float); @@ -358,20 +388,6 @@ struct KernelFuncImpl { } }; - // Handle args for inplace vector case - template - struct ComputeCallHelper&, Tail...> { - template - static void Compute(CustomOpKernelContext* ctx, PreviousArgs&... pargs) { - auto& range = ctx->InputRangeAt(in_idx); - auto arg = ctx->InputsBetween(range.first, range.second); - ComputeCallHelper< - Tail...>::template Compute(ctx, - pargs..., - arg); - } - }; - template struct ComputeReturnHelper; diff --git a/paddle/phi/api/lib/op_meta_info.cc b/paddle/phi/api/lib/op_meta_info.cc index 7b11bd02084..bdc46a4e0e7 100644 --- a/paddle/phi/api/lib/op_meta_info.cc +++ b/paddle/phi/api/lib/op_meta_info.cc @@ -110,6 +110,18 @@ paddle::optional CustomOpKernelContext::OptionalInputAt(size_t idx) { return paddle::make_optional(inputs_.at(idx)); } +paddle::optional> +CustomOpKernelContext::OptionalInputsBetween(size_t start, size_t end) { + std::vector rlt; + for (size_t i = start; i < end; ++i) { + if (!inputs_.at(i).is_initialized()) { + return paddle::none; + } + rlt.emplace_back(inputs_.at(i)); + } + return paddle::make_optional>(rlt); +} + Tensor* CustomOpKernelContext::MutableOutputAt(size_t idx) { return &(outputs_.at(idx)); } diff --git a/python/paddle/utils/cpp_extension/extension_utils.py b/python/paddle/utils/cpp_extension/extension_utils.py index cf19ec558ea..8958c6bc7ac 100644 --- a/python/paddle/utils/cpp_extension/extension_utils.py +++ b/python/paddle/utils/cpp_extension/extension_utils.py @@ -1057,6 +1057,22 @@ def _gen_output_content( if out_idx in inplace_reverse_idx: in_idx = inplace_reverse_idx[out_idx] if ( + in_idx != -1 + and "@VECTOR" in in_names[in_idx] + and "@OPTIONAL" in in_names[in_idx] + ): + # inplace optional vector output case + lower_in_names = in_names[in_idx].split("@")[0].lower() + dynamic_content += f""" +{indent}if {lower_in_names} is not None: +{indent} outs['{out_name}'] = [core.eager.Tensor() for _ in range(len({lower_in_names}))] +{indent}else: +{indent} outs['{out_name}'] = core.eager.Tensor() +{indent}ctx.add_outputs(outs['{out_name}'])""" + static_content += f""" +{indent}if {lower_in_names} is not None: +{indent} outs['{out_name}'] = [helper.create_variable(dtype='float32') for _ in range(len({lower_in_names}))]""" + elif ( in_idx != -1 and "@VECTOR" in in_names[in_idx] ): # inplace vector output case lower_in_names = in_names[in_idx].split("@")[0].lower() diff --git a/test/custom_op/custom_optional.cc b/test/custom_op/custom_optional.cc index 52c8e989d0e..0e28ce84d5a 100644 --- a/test/custom_op/custom_optional.cc +++ b/test/custom_op/custom_optional.cc @@ -300,3 +300,116 @@ PD_BUILD_GRAD_OP(custom_optional_inplace_add) paddle::Grad(paddle::Optional("Y"))}}) .SetKernelFn(PD_KERNEL(AddOptionalInplaceBackward)) .SetInferShapeFn(PD_INFER_SHAPE(AddOptionalInplaceBackwardInferShape)); + +/* +if (y) { + outX = 2 * x + y[1...n]; + outY[i] = x + y[i]; +} else { + outX = 2 * x; + outY = None; +} +*/ +std::vector AddOptionalInplaceVectorForward( + const paddle::Tensor& x, + paddle::optional>& y) { // NOLINT + PD_CHECK(x.place() == paddle::PlaceType::kCPU, "x must be a CPU Tensor."); + paddle::Tensor outX = paddle::zeros(x.shape(), x.dtype(), x.place()); + + PD_DISPATCH_FLOATING_TYPES( + x.type(), "AddOptionalInplaceVectorForward", ([&] { + add_two_pointers( + x.data(), x.data(), outX.data(), x.size()); + if (y) { + for (size_t i = 0; i < y->size(); ++i) { + add_one_pointer( + y->at(i).data(), outX.data(), outX.size()); + add_one_pointer( + x.data(), y->at(i).data(), x.size()); + } + } + })); + // No need to return y, because we set it as inplace input. + return {outX}; +} + +std::vector AddOptionalInplaceVectorInferDtype( + const paddle::DataType& x_dtype, + const paddle::optional>& y_dtype) { + return {x_dtype}; +} + +std::vector> AddOptionalInplaceVectorInferShape( + const std::vector& x_shape, + const paddle::optional>>& y_shape) { + return {x_shape}; +} + +/* +if (outy_grad) { + x_grad = outX_grad * 2 + outY_grad[1...n]; + y_grad[i] = outX_grad + outY_grad[i]; +} else { + x_grad = outX_grad * 2; + y_grad = None; +} +*/ +std::vector AddOptionalInplaceVectorBackward( + const paddle::Tensor& x, + const paddle::optional>& y, + const paddle::Tensor& outx_grad, + paddle::optional>& outy_grad) { // NOLINT + PD_CHECK(x.place() == paddle::PlaceType::kCPU, "x must be a CPU Tensor."); + + paddle::Tensor x_grad = paddle::zeros(x.shape(), x.dtype(), x.place()); + + PD_DISPATCH_FLOATING_TYPES( + outx_grad.type(), "AddOptionalInplaceVectorBackward", ([&] { + add_two_pointers(outx_grad.data(), + outx_grad.data(), + x_grad.data(), + x_grad.size()); + if (outy_grad) { + for (size_t i = 0; i < outy_grad->size(); ++i) { + add_one_pointer(outy_grad->at(i).data(), + x_grad.data(), + x_grad.size()); + add_one_pointer(outx_grad.data(), + outy_grad->at(i).data(), + outx_grad.size()); + } + } + })); + + return {x_grad}; +} + +std::vector> AddOptionalInplaceVectorBackwardInferShape( + const std::vector& x_shape, + const paddle::optional>>& y_shape, + const std::vector& x_grad_shape, + const paddle::optional>>& y_grad_shape) { + return {x_shape}; +} + +PD_BUILD_OP(custom_optional_inplace_add_vec) + .Inputs({"X", paddle::Optional(paddle::Vec("Y"))}) + .Outputs({"OutX", paddle::Optional(paddle::Vec("OutY"))}) + .SetInplaceMap({{paddle::Optional(paddle::Vec("Y")), + paddle::Optional(paddle::Vec("OutY"))}}) + .SetKernelFn(PD_KERNEL(AddOptionalInplaceVectorForward)) + .SetInferShapeFn(PD_INFER_SHAPE(AddOptionalInplaceVectorInferShape)) + .SetInferDtypeFn(PD_INFER_DTYPE(AddOptionalInplaceVectorInferDtype)); + +PD_BUILD_GRAD_OP(custom_optional_inplace_add_vec) + .Inputs({"X", + paddle::Optional(paddle::Vec("Y")), + paddle::Grad("OutX"), + paddle::Grad(paddle::Optional(paddle::Vec("OutY")))}) + .Outputs({paddle::Grad("X"), + paddle::Grad(paddle::Optional(paddle::Vec("Y")))}) + .SetInplaceMap({{paddle::Grad(paddle::Optional(paddle::Vec("OutY"))), + paddle::Grad(paddle::Optional(paddle::Vec("Y")))}}) + .SetKernelFn(PD_KERNEL(AddOptionalInplaceVectorBackward)) + .SetInferShapeFn( + PD_INFER_SHAPE(AddOptionalInplaceVectorBackwardInferShape)); diff --git a/test/custom_op/test_custom_optional.py b/test/custom_op/test_custom_optional.py index 8cc92dd7d8a..53d4f159527 100644 --- a/test/custom_op/test_custom_optional.py +++ b/test/custom_op/test_custom_optional.py @@ -41,7 +41,7 @@ custom_optional = load( ) -def optional_dynamic_add(phi_func, device, dtype, np_x, np_y): +def optional_dynamic_add(custom_func, device, dtype, np_x, np_y): paddle.set_device(device) x = paddle.to_tensor(np_x, dtype=dtype, stop_gradient=False) @@ -49,7 +49,7 @@ def optional_dynamic_add(phi_func, device, dtype, np_x, np_y): y = paddle.to_tensor(np_y, dtype=dtype, stop_gradient=False) else: y = x - if phi_func: + if custom_func: out = custom_optional.custom_add(x, y if np_y is not None else None) else: out = paddle.add(x, y) @@ -58,7 +58,7 @@ def optional_dynamic_add(phi_func, device, dtype, np_x, np_y): return x.numpy(), out.numpy(), x.grad.numpy() -def optional_static_add(phi_func, device, dtype, np_x, np_y): +def optional_static_add(custom_func, device, dtype, np_x, np_y): paddle.enable_static() paddle.set_device(device) with static.scope_guard(static.Scope()): @@ -79,7 +79,7 @@ def optional_static_add(phi_func, device, dtype, np_x, np_y): feed_dict = { "x": np_x.astype(dtype), } - if phi_func: + if custom_func: out = custom_optional.custom_add( x, y if np_y is not None else None ) @@ -116,13 +116,13 @@ if (y) { ''' -def optional_inplace_dynamic_add(phi_func, device, dtype, np_x, np_y): +def optional_inplace_dynamic_add(custom_func, device, dtype, np_x, np_y): paddle.set_device(device) x = paddle.to_tensor(np_x, dtype=dtype, stop_gradient=False) if np_y is not None: y = paddle.to_tensor(np_y, dtype=dtype, stop_gradient=True) - if phi_func: + if custom_func: outx, outy = custom_optional.custom_optional_inplace_add(x, y) else: # We need to accumulate y's grad here. @@ -133,7 +133,7 @@ def optional_inplace_dynamic_add(phi_func, device, dtype, np_x, np_y): outy = y.add_(x) else: y = None - if phi_func: + if custom_func: outx, outy = custom_optional.custom_optional_inplace_add(x, y) else: outx = 2 * x @@ -155,7 +155,7 @@ def optional_inplace_dynamic_add(phi_func, device, dtype, np_x, np_y): ) -def optional_inplace_static_add(phi_func, device, dtype, np_x, np_y): +def optional_inplace_static_add(custom_func, device, dtype, np_x, np_y): paddle.enable_static() paddle.set_device(device) with static.scope_guard(static.Scope()): @@ -171,7 +171,7 @@ def optional_inplace_static_add(phi_func, device, dtype, np_x, np_y): "x": np_x.astype(dtype), "y": np_y.astype(dtype), } - if phi_func: + if custom_func: outx, outy = custom_optional.custom_optional_inplace_add( x, y ) @@ -182,7 +182,7 @@ def optional_inplace_static_add(phi_func, device, dtype, np_x, np_y): feed_dict = { "x": np_x.astype(dtype), } - if phi_func: + if custom_func: outx, outy = custom_optional.custom_optional_inplace_add( x, None ) @@ -223,7 +223,7 @@ def optional_inplace_static_add(phi_func, device, dtype, np_x, np_y): return [x_v, out_v, x_grad_v] -def optional_vector_dynamic_add(phi_func, device, dtype, np_x, np_inputs): +def optional_vector_dynamic_add(custom_func, device, dtype, np_x, np_inputs): paddle.set_device(device) x = paddle.to_tensor(np_x, dtype=dtype, stop_gradient=False) @@ -232,14 +232,14 @@ def optional_vector_dynamic_add(phi_func, device, dtype, np_x, np_inputs): paddle.to_tensor(np_input, dtype=dtype, stop_gradient=False) for np_input in np_inputs ] - if phi_func: + if custom_func: out = custom_optional.custom_add_vec(x, inputs) else: out = paddle.add(x, inputs[0]) for input in inputs[1:]: out = paddle.add(out, input) else: - if phi_func: + if custom_func: out = custom_optional.custom_add_vec(x, None) else: out = paddle.add(x, x) @@ -248,7 +248,7 @@ def optional_vector_dynamic_add(phi_func, device, dtype, np_x, np_inputs): return x.numpy(), out.numpy(), x.grad.numpy() -def optional_vector_static_add(phi_func, device, dtype, np_x, np_inputs): +def optional_vector_static_add(custom_func, device, dtype, np_x, np_inputs): paddle.enable_static() paddle.set_device(device) with static.scope_guard(static.Scope()): @@ -271,13 +271,13 @@ def optional_vector_static_add(phi_func, device, dtype, np_x, np_inputs): "y2": np_inputs[1].astype(dtype), } ) - if phi_func: + if custom_func: out = custom_optional.custom_add_vec(x, [y1, y2]) else: out = paddle.add(x, y1) out = paddle.add(out, y2) else: - if phi_func: + if custom_func: out = custom_optional.custom_add_vec(x, None) else: out = paddle.add(x, x) @@ -301,6 +301,159 @@ def optional_vector_static_add(phi_func, device, dtype, np_x, np_inputs): return x_v, out_v, x_grad_v +''' +if (y) { + outX = 2 * x + y[1...n]; + outY[i] = x + y[i]; +} else { + outX = 2 * x; + outY = None; +} +''' + + +def optional_inplace_vector_dynamic_add( + custom_func, device, dtype, np_x, np_inputs +): + paddle.set_device(device) + x = paddle.to_tensor(np_x, dtype=dtype, stop_gradient=False) + + if np_inputs is not None: + inputs = [ + paddle.to_tensor(np_input, dtype=dtype, stop_gradient=True) + for np_input in np_inputs + ] + if custom_func: + outx, outy = custom_optional.custom_optional_inplace_add_vec( + x, inputs + ) + else: + outx = 2 * x + outy = [] + for input in inputs: + # We need to accumulate y's grad here. + input.stop_gradient = False + outx = outx + input + # Inplace leaf Tensor's stop_gradient should be True + input.stop_gradient = True + outy.append(input.add_(x)) + else: + if custom_func: + outx, outy = custom_optional.custom_optional_inplace_add_vec( + x, None + ) + else: + outx = 2 * x + outy = None + assert ( + outy is None + ), "The output `outy` of optional_inplace_dynamic_add should be None" + + if outy is not None: + out = outx + for tensor in outy: + out = out + tensor + else: + out = outx + out.backward() + return ( + x.numpy(), + outx.numpy(), + [y.numpy() for y in inputs] if np_inputs is not None else None, + [t.numpy() for t in outy] if outy is not None else None, + out.numpy(), + x.grad.numpy(), + [y.grad.numpy() for y in inputs] + if np_inputs is not None and inputs[0].grad is not None + else None, + ) + + +def optional_inplace_vector_static_add( + custom_func, device, dtype, np_x, np_inputs +): + paddle.enable_static() + paddle.set_device(device) + with static.scope_guard(static.Scope()): + with static.program_guard(static.Program()): + x = static.data(name="x", shape=[None, np_x.shape[1]], dtype=dtype) + x.stop_gradient = False + feed_dict = { + "x": np_x.astype(dtype), + } + if np_inputs is not None: + y1 = static.data( + name="y1", shape=[None, np_x.shape[1]], dtype=dtype + ) + y1.stop_gradient = False + y2 = static.data( + name="y2", shape=[None, np_x.shape[1]], dtype=dtype + ) + y2.stop_gradient = False + feed_dict.update( + { + "y1": np_inputs[0].astype(dtype), + "y2": np_inputs[1].astype(dtype), + } + ) + if custom_func: + ( + outx, + outy, + ) = custom_optional.custom_optional_inplace_add_vec( + x, [y1, y2] + ) + else: + outx = paddle.add(paddle.add(paddle.add(x, x), y1), y2) + # outx = 2 * x + y1 + y2 + outy = [x + y1, x + y2] + else: + if custom_func: + ( + outx, + outy, + ) = custom_optional.custom_optional_inplace_add_vec(x, None) + else: + outx = 2 * x + outy = None + if np_inputs is not None: + out = outx + outy[0] + outy[1] + else: + out = outx + mean_out = paddle.mean(out) + static.append_backward(mean_out) + + exe = static.Executor() + exe.run(static.default_startup_program()) + + if np_inputs is not None: + x_v, out_v, x_grad_v, y1_grad_v, y2_grad_v = exe.run( + static.default_main_program(), + feed=feed_dict, + fetch_list=[ + x.name, + out.name, + x.name + "@GRAD", + y1.name + "@GRAD", + y2.name + "@GRAD", + ], + ) + paddle.disable_static() + return [x_v, out_v, x_grad_v, y1_grad_v, y2_grad_v] + else: + x_v, out_v, x_grad_v = exe.run( + static.default_main_program(), + feed=feed_dict, + fetch_list=[ + x.name, + out.name, + x.name + "@GRAD", + ], + ) + paddle.disable_static() + return [x_v, out_v, x_grad_v] + + class TestCustomOptionalJit(unittest.TestCase): def setUp(self): self.dtypes = ['float32', 'float64'] @@ -317,13 +470,23 @@ class TestCustomOptionalJit(unittest.TestCase): return assert out is not None, "out value of " + name + " is None" assert pd_out is not None, "pd_out value of " + name + " is None" - np.testing.assert_array_equal( - out, - pd_out, - err_msg='custom op {}: {},\n paddle api {}: {}'.format( - name, out, name, pd_out - ), - ) + if isinstance(out, list) and isinstance(pd_out, list): + for idx in range(len(out)): + np.testing.assert_array_equal( + out[idx], + pd_out[idx], + err_msg='custom op {}: {},\n paddle api {}: {}'.format( + name, out[idx], name, pd_out[idx] + ), + ) + else: + np.testing.assert_array_equal( + out, + pd_out, + err_msg='custom op {}: {},\n paddle api {}: {}'.format( + name, out, name, pd_out + ), + ) def check_output_allclose(self, out, pd_out, name): if out is None and pd_out is None: @@ -351,7 +514,11 @@ class TestCustomOptionalJit(unittest.TestCase): self.np_x, np_y, ) - (phi_x, phi_out, phi_x_grad,) = optional_static_add( + ( + custom_x, + custom_out, + custom_x_grad, + ) = optional_static_add( True, device, dtype, @@ -359,9 +526,9 @@ class TestCustomOptionalJit(unittest.TestCase): np_y, ) - self.check_output(phi_x, pd_x, "x") - self.check_output(phi_out, pd_out, "out") - self.check_output(phi_x_grad, pd_x_grad, "x_grad") + self.check_output(custom_x, pd_x, "x") + self.check_output(custom_out, pd_out, "out") + self.check_output(custom_x_grad, pd_x_grad, "x_grad") def test_optional_dynamic_add(self): for device in self.devices: @@ -374,7 +541,11 @@ class TestCustomOptionalJit(unittest.TestCase): self.np_x, np_y, ) - (phi_x, phi_out, phi_x_grad,) = optional_dynamic_add( + ( + custom_x, + custom_out, + custom_x_grad, + ) = optional_dynamic_add( True, device, dtype, @@ -382,9 +553,9 @@ class TestCustomOptionalJit(unittest.TestCase): np_y, ) - self.check_output(phi_x, pd_x, "x") - self.check_output(phi_out, pd_out, "out") - self.check_output(phi_x_grad, pd_x_grad, "x_grad") + self.check_output(custom_x, pd_x, "x") + self.check_output(custom_out, pd_out, "out") + self.check_output(custom_x_grad, pd_x_grad, "x_grad") def test_optional_inplace_static_add(self): for device in self.devices: @@ -397,7 +568,7 @@ class TestCustomOptionalJit(unittest.TestCase): self.np_x, np_y, ) - phi_tuple = optional_inplace_static_add( + custom_tuple = optional_inplace_static_add( True, device, dtype, @@ -405,11 +576,13 @@ class TestCustomOptionalJit(unittest.TestCase): np_y, ) - self.check_output(phi_tuple[0], pd_tuple[0], "x") - self.check_output(phi_tuple[1], pd_tuple[1], "out") - self.check_output(phi_tuple[2], pd_tuple[2], "x_grad") - if len(phi_tuple) > 3: - self.check_output(phi_tuple[3], pd_tuple[3], "y_grad") + self.check_output(custom_tuple[0], pd_tuple[0], "x") + self.check_output(custom_tuple[1], pd_tuple[1], "out") + self.check_output(custom_tuple[2], pd_tuple[2], "x_grad") + if len(custom_tuple) > 3: + self.check_output( + custom_tuple[3], pd_tuple[3], "y_grad" + ) def test_optional_inplace_dynamic_add(self): for device in self.devices: @@ -431,13 +604,13 @@ class TestCustomOptionalJit(unittest.TestCase): np_y, ) ( - phi_x, - phi_outx, - phi_y, - phi_outy, - phi_out, - phi_x_grad, - phi_y_grad, + custom_x, + custom_outx, + custom_y, + custom_outy, + custom_out, + custom_x_grad, + custom_y_grad, ) = optional_inplace_dynamic_add( True, device, @@ -447,21 +620,25 @@ class TestCustomOptionalJit(unittest.TestCase): ) self.check_output(pd_y, pd_outy, "inplace_pd_y") - self.check_output(phi_y, phi_outy, "inplace_phi_y") + self.check_output(custom_y, custom_outy, "inplace_custom_y") - self.check_output(phi_x, pd_x, "x") - self.check_output(phi_outx, pd_outx, "outx") - self.check_output(phi_y, pd_y, "y") - self.check_output(phi_outy, pd_outy, "outy") - self.check_output(phi_out, pd_out, "out") - self.check_output(phi_x_grad, pd_x_grad, "x_grad") - self.check_output(phi_y_grad, pd_y_grad, "y_grad") + self.check_output(custom_x, pd_x, "x") + self.check_output(custom_outx, pd_outx, "outx") + self.check_output(custom_y, pd_y, "y") + self.check_output(custom_outy, pd_outy, "outy") + self.check_output(custom_out, pd_out, "out") + self.check_output(custom_x_grad, pd_x_grad, "x_grad") + self.check_output(custom_y_grad, pd_y_grad, "y_grad") def test_optional_vector_static_add(self): for device in self.devices: for dtype in self.dtypes: for np_y in [None, self.np_inputs]: - (phi_x, phi_out, phi_x_grad,) = optional_vector_static_add( + ( + custom_x, + custom_out, + custom_x_grad, + ) = optional_vector_static_add( True, device, dtype, @@ -476,15 +653,19 @@ class TestCustomOptionalJit(unittest.TestCase): np_y, ) - self.check_output(phi_x, pd_x, "x") - self.check_output(phi_out, pd_out, "out") - self.check_output(phi_x_grad, pd_x_grad, "x_grad") + self.check_output(custom_x, pd_x, "x") + self.check_output(custom_out, pd_out, "out") + self.check_output(custom_x_grad, pd_x_grad, "x_grad") def test_optional_vector_dynamic_add(self): for device in self.devices: for dtype in self.dtypes: for np_y in [None, self.np_inputs]: - (phi_x, phi_out, phi_x_grad,) = optional_vector_dynamic_add( + ( + custom_x, + custom_out, + custom_x_grad, + ) = optional_vector_dynamic_add( True, device, dtype, @@ -499,9 +680,85 @@ class TestCustomOptionalJit(unittest.TestCase): np_y, ) - self.check_output(phi_x, pd_x, "x") - self.check_output(phi_out, pd_out, "out") - self.check_output(phi_x_grad, pd_x_grad, "x_grad") + self.check_output(custom_x, pd_x, "x") + self.check_output(custom_out, pd_out, "out") + self.check_output(custom_x_grad, pd_x_grad, "x_grad") + + def test_optional_inplace_vector_static_add(self): + for device in self.devices: + for dtype in self.dtypes: + for np_y in [None, self.np_inputs]: + pd_tuple = optional_inplace_vector_static_add( + False, + device, + dtype, + self.np_x, + np_y, + ) + custom_tuple = optional_inplace_vector_static_add( + True, + device, + dtype, + self.np_x, + np_y, + ) + + self.check_output(custom_tuple[0], pd_tuple[0], "x") + self.check_output(custom_tuple[1], pd_tuple[1], "out") + self.check_output(custom_tuple[2], pd_tuple[2], "x_grad") + if len(custom_tuple) > 3: + self.check_output( + custom_tuple[3], pd_tuple[3], "y1_grad" + ) + self.check_output( + custom_tuple[4], pd_tuple[4], "y2_grad" + ) + + def test_optional_inplace_vector_dynamic_add(self): + for device in self.devices: + for dtype in self.dtypes: + for np_y in [None, self.np_inputs]: + ( + custom_x, + custom_outx, + custom_y, + custom_outy, + custom_out, + custom_x_grad, + custom_y_grad, + ) = optional_inplace_vector_dynamic_add( + True, + device, + dtype, + self.np_x, + np_y, + ) + ( + pd_x, + pd_outx, + pd_y, + pd_outy, + pd_out, + pd_x_grad, + pd_y_grad, + ) = optional_inplace_vector_dynamic_add( + False, + device, + dtype, + self.np_x, + np_y, + ) + + self.check_output(pd_y, pd_outy, "inplace_pd_y") + self.check_output(custom_y, custom_outy, "inplace_custom_y") + + self.check_output(custom_x, pd_x, "x") + self.check_output(custom_outx, pd_outx, "outx") + self.check_output(custom_y, pd_y, "y") + self.check_output(custom_outy, pd_outy, "outy") + self.check_output(custom_out, pd_out, "out") + self.check_output(custom_x_grad, pd_x_grad, "x_grad") + self.check_output(custom_y_grad, pd_y_grad, "y_grad") if __name__ == "__main__": -- GitLab