From 71b046cda4d2c1751cfbc280e3695261f12fe8b4 Mon Sep 17 00:00:00 2001 From: xiongkun Date: Wed, 25 May 2022 10:58:15 +0800 Subject: [PATCH] [EinsumOp] Optimize the backward speed of EinsumOp (#42663) * change logic for optimize * modifty * optimize the backward speed of EinsumOp * add cache optimizer for einsum op * EinsumOp: fix new dygraph mode error * fix bug * change Cache->InnerCache * fix code * fix * add nan inf utils for einsum op * add as_extra * Compatible with v2.3 EinsumOp * remove dispensable --- paddle/fluid/eager/nan_inf_utils.cc | 6 ++ paddle/fluid/eager/nan_inf_utils.h | 4 ++ paddle/fluid/operators/einsum_op.cc | 8 +++ paddle/phi/infermeta/unary.cc | 3 +- paddle/phi/infermeta/unary.h | 3 +- paddle/phi/kernels/cpu/einsum_kernel.cc | 3 +- paddle/phi/kernels/einsum_grad_kernel.h | 1 + paddle/phi/kernels/einsum_kernel.h | 7 +++ paddle/phi/kernels/gpu/einsum_grad_kernel.cu | 9 ++- paddle/phi/kernels/gpu/einsum_kernel.cu | 9 ++- paddle/phi/kernels/impl/einsum_grad_impl.h | 41 ++++++++++-- paddle/phi/kernels/impl/einsum_impl.h | 63 +++++++++++++++---- paddle/phi/ops/compat/einsum_sig.cc | 7 ++- .../fluid/tests/unittests/test_einsum_op.py | 8 ++- .../white_list/no_check_set_white_list.py | 1 + python/paddle/tensor/einsum.py | 14 +++-- python/paddle/utils/code_gen/api.yaml | 2 +- python/paddle/utils/code_gen/api_base.py | 6 +- python/paddle/utils/code_gen/api_gen.py | 4 +- python/paddle/utils/code_gen/backward.yaml | 4 +- 20 files changed, 165 insertions(+), 38 deletions(-) diff --git a/paddle/fluid/eager/nan_inf_utils.cc b/paddle/fluid/eager/nan_inf_utils.cc index d6769550166..d1c5983a370 100644 --- a/paddle/fluid/eager/nan_inf_utils.cc +++ b/paddle/fluid/eager/nan_inf_utils.cc @@ -110,4 +110,10 @@ void CheckTensorHasNanOrInf( } } +void CheckTensorHasNanOrInf(const std::string& api_name, + const TupleOfTensorAndVector& tensors) { + CheckTensorHasNanOrInf(api_name, std::get<0>(tensors)); + CheckTensorHasNanOrInf(api_name, std::get<1>(tensors)); +} + } // namespace egr diff --git a/paddle/fluid/eager/nan_inf_utils.h b/paddle/fluid/eager/nan_inf_utils.h index 5309eeb2959..a411504fa49 100644 --- a/paddle/fluid/eager/nan_inf_utils.h +++ b/paddle/fluid/eager/nan_inf_utils.h @@ -31,6 +31,7 @@ using TupleOfFourTensors = std::tuple; using TupleOfFiveTensors = std::tuple; using TupleOfSixTensors = std::tuple; +using TupleOfTensorAndVector = std::tuple>; void CheckTensorHasNanOrInf(const std::string& api_name, const Tensor& tensor); @@ -52,6 +53,9 @@ void CheckTensorHasNanOrInf(const std::string& api_name, void CheckTensorHasNanOrInf(const std::string& api_name, const std::vector& tensors); +void CheckTensorHasNanOrInf(const std::string& api_name, + const TupleOfTensorAndVector& tensors); + void CheckTensorHasNanOrInf( const std::string& api_name, const paddle::small_vector, diff --git a/paddle/fluid/operators/einsum_op.cc b/paddle/fluid/operators/einsum_op.cc index 8fdde1ccdc0..6da0045443c 100644 --- a/paddle/fluid/operators/einsum_op.cc +++ b/paddle/fluid/operators/einsum_op.cc @@ -33,6 +33,13 @@ class EinsumOpMaker : public framework::OpProtoAndCheckerMaker { AddInput("Operands", "(TensorList), The input tensor of einsum op.") .AsDuplicable(); AddOutput("Out", "(Tensor), The output tensor of einsum op."); + AddOutput( + "InnerCache", + "(Tensor), The cache of the forward transpose tensors: tA and tB.") + .AsDuplicable() + .AsExtra() + .AsIntermediate(); + AddAttr("equation", "(string) A einsum equation. such as `ij,jk->ik`" "There must have `->` and the number of operands in " @@ -72,6 +79,7 @@ class EinsumGradMaker : public framework::SingleGradOpMaker { void Apply(GradOpPtr retv) const override { retv->SetType("einsum_grad"); retv->SetInput("Operands", this->Input("Operands")); + retv->SetInput("InnerCache", this->Output("InnerCache")); retv->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out")); retv->SetAttrMap(this->Attrs()); retv->SetOutput(framework::GradVarName("Operands"), diff --git a/paddle/phi/infermeta/unary.cc b/paddle/phi/infermeta/unary.cc index c88c2d6f60f..1ec804d1bf8 100644 --- a/paddle/phi/infermeta/unary.cc +++ b/paddle/phi/infermeta/unary.cc @@ -401,7 +401,8 @@ void EighInferMeta(const MetaTensor& x, void EinsumInferMeta(const std::vector& inputs, const std::string& equation, - MetaTensor* out) { + MetaTensor* out, + std::vector inner_cache) { // collect the following informations to prepare einsum. LabelMap labelshape(0); LabelMap labeltype(LabelType::Reduction); diff --git a/paddle/phi/infermeta/unary.h b/paddle/phi/infermeta/unary.h index 58b256dc66e..25ea003f58f 100644 --- a/paddle/phi/infermeta/unary.h +++ b/paddle/phi/infermeta/unary.h @@ -82,7 +82,8 @@ void EighInferMeta(const MetaTensor& x, void EinsumInferMeta(const std::vector& inputs, const std::string& equation, - MetaTensor* out); + MetaTensor* out, + std::vector inner_cache); void ExpandInferMeta(const MetaTensor& x, const IntArray& shape, diff --git a/paddle/phi/kernels/cpu/einsum_kernel.cc b/paddle/phi/kernels/cpu/einsum_kernel.cc index 3e25a65526d..8968542b3e0 100644 --- a/paddle/phi/kernels/cpu/einsum_kernel.cc +++ b/paddle/phi/kernels/cpu/einsum_kernel.cc @@ -17,4 +17,5 @@ #include "paddle/phi/core/kernel_registry.h" #include "paddle/phi/kernels/impl/einsum_impl.h" -PD_REGISTER_KERNEL(einsum, CPU, ALL_LAYOUT, phi::EinsumKernel, float, double) {} +PD_REGISTER_KERNEL( + einsum, CPU, ALL_LAYOUT, phi::EinsumKernelRaw, float, double) {} diff --git a/paddle/phi/kernels/einsum_grad_kernel.h b/paddle/phi/kernels/einsum_grad_kernel.h index 5c1970e7758..06785c8532e 100644 --- a/paddle/phi/kernels/einsum_grad_kernel.h +++ b/paddle/phi/kernels/einsum_grad_kernel.h @@ -21,6 +21,7 @@ namespace phi { template void EinsumGradKernel(const Context& dev_ctx, const std::vector& x, + const std::vector& inner_cache, const DenseTensor& out_grad, const std::string& equation, std::vector x_grad); diff --git a/paddle/phi/kernels/einsum_kernel.h b/paddle/phi/kernels/einsum_kernel.h index 3d9e8feda74..87df2b1c64a 100644 --- a/paddle/phi/kernels/einsum_kernel.h +++ b/paddle/phi/kernels/einsum_kernel.h @@ -24,4 +24,11 @@ void EinsumKernel(const Context& dev_ctx, const std::string& equation, DenseTensor* out); +template +void EinsumKernelRaw(const Context& dev_ctx, + const std::vector& inputs, + const std::string& equation, + DenseTensor* out, + std::vector cache); + } // namespace phi diff --git a/paddle/phi/kernels/gpu/einsum_grad_kernel.cu b/paddle/phi/kernels/gpu/einsum_grad_kernel.cu index c8a8745f345..6ca8dbd9205 100644 --- a/paddle/phi/kernels/gpu/einsum_grad_kernel.cu +++ b/paddle/phi/kernels/gpu/einsum_grad_kernel.cu @@ -18,5 +18,10 @@ #include "paddle/phi/core/kernel_registry.h" #include "paddle/phi/kernels/impl/einsum_grad_impl.h" -PD_REGISTER_KERNEL( - einsum_grad, GPU, ALL_LAYOUT, phi::EinsumGradKernel, float, double) {} +PD_REGISTER_KERNEL(einsum_grad, + GPU, + ALL_LAYOUT, + phi::EinsumGradKernel, + float, + double, + phi::dtype::float16) {} diff --git a/paddle/phi/kernels/gpu/einsum_kernel.cu b/paddle/phi/kernels/gpu/einsum_kernel.cu index d73e154eb40..d1f4c659038 100644 --- a/paddle/phi/kernels/gpu/einsum_kernel.cu +++ b/paddle/phi/kernels/gpu/einsum_kernel.cu @@ -18,4 +18,11 @@ #include "paddle/phi/core/kernel_registry.h" #include "paddle/phi/kernels/impl/einsum_impl.h" -PD_REGISTER_KERNEL(einsum, GPU, ALL_LAYOUT, phi::EinsumKernel, float, double) {} +PD_REGISTER_KERNEL(einsum, + GPU, + ALL_LAYOUT, + phi::EinsumKernelRaw, + float, + double, + phi::dtype::float16, + phi::dtype::bfloat16) {} diff --git a/paddle/phi/kernels/impl/einsum_grad_impl.h b/paddle/phi/kernels/impl/einsum_grad_impl.h index 2b087f8dcae..aceb97a49b1 100644 --- a/paddle/phi/kernels/impl/einsum_grad_impl.h +++ b/paddle/phi/kernels/impl/einsum_grad_impl.h @@ -13,6 +13,7 @@ // limitations under the License. #pragma once +#include "paddle/fluid/platform/profiler.h" #include "paddle/phi/core/dense_tensor.h" #include "paddle/phi/kernels/impl/einsum_impl.h" #include "paddle/phi/kernels/tile_kernel.h" @@ -55,7 +56,13 @@ DenseTensor PerformTileAndReduction(const Context& dev_ctx, } t.Resize(make_ddim(resize_dims)); DenseTensor after_tile; - TileKernel(dev_ctx, t, repeat_times, &after_tile); + if (std::all_of(repeat_times.begin(), repeat_times.end(), [](int x) { + return x == 1; + })) { + after_tile = t; + } else { + TileKernel(dev_ctx, t, repeat_times, &after_tile); + } size_t n_ellipsis_idx = op_label.find(".", 0); if (n_ellipsis_idx != std::string::npos) { // may be we need reduce. broadcast_dims is not equal to ellipsis dims. @@ -91,10 +98,11 @@ DenseTensor PerformTileAndReduction(const Context& dev_ctx, template void EinsumGradKernel(const Context& dev_ctx, const std::vector& x, + const std::vector& inner_cache, const DenseTensor& out_grad, const std::string& equation, std::vector x_grad) { - VLOG(5) << "Start EisumGradKernel:"; + VLOG(5) << "Start EinsumGradKernel:"; LabelMap labelshape(0); LabelMap labeltype(LabelType::Reduction); std::vector label2perms(x.size(), LabelMap(-1)); @@ -162,8 +170,33 @@ void EinsumGradKernel(const Context& dev_ctx, operands_for_B.push_back(x[0]); DenseTensor before_tile; - EinsumKernel(dev_ctx, operands_for_A, equation_for_A, &dA); - EinsumKernel(dev_ctx, operands_for_B, equation_for_B, &dB); + + std::vector cache(3); // set empty; TA, TB, TdC + if (inner_cache.size() > + 0) { // for compatibility, we can load and run v2.3 EinsumOp. + cache[0].ShareBufferWith(*(inner_cache[0])); + cache[1].ShareBufferWith(*(inner_cache[1])); + } + + EinsumKernelImpl(dev_ctx, + all_labels, + operands_for_A, + equation_for_A, + &dA, + {&cache[1], &cache[2]}, + false); + + EinsumKernelImpl(dev_ctx, + all_labels, + operands_for_B, + equation_for_B, + &dB, + {&cache[2], &cache[0]}, + false); + + // release the cache tensor dTC to save memory right now. they are useless + // now. + cache.clear(); *(x_grad[0]) = PerformTileAndReduction(dev_ctx, labeltype, labelshape, diff --git a/paddle/phi/kernels/impl/einsum_impl.h b/paddle/phi/kernels/impl/einsum_impl.h index 901147734b2..5e4480426c0 100644 --- a/paddle/phi/kernels/impl/einsum_impl.h +++ b/paddle/phi/kernels/impl/einsum_impl.h @@ -137,7 +137,6 @@ inline std::vector TransformLabelsOrder( std::vector tmp; for (int c : all_labels) { if (type[c] == cnt_type) tmp.push_back(c); - std::sort(tmp.begin(), tmp.end()); } ret.insert(ret.end(), tmp.begin(), tmp.end()); } @@ -176,6 +175,15 @@ inline static void GlobalInfo(const std::vector& op_labels, (*label2type)['.'] = LabelType::Batch; + if (sorted_labels->size()) { + std::set exist(all.begin(), all.end()); + all.clear(); + std::for_each( + sorted_labels->begin(), sorted_labels->end(), [&exist, &all](char c) { + if (exist.count(c)) all.push_back(c); + }); + } + *sorted_labels = TransformLabelsOrder(all, *label2type, {LabelType::Batch, @@ -409,7 +417,8 @@ DenseTensor PerformContraction( const LabelMap& label2shape, const std::vector>& ellipsis_dims, const std::vector& broadcast_dims, - std::vector cache) { + std::vector cache, + bool use_cache) { // Get All the Batches, so perm is auto all_valid = LabelMap(1); auto recover_dim = GetShapeByType(all_labels, @@ -447,14 +456,17 @@ DenseTensor PerformContraction( } // reduction DenseTensor trans_t; - if (cache[operand_idx]->IsInitialized()) { + if (use_cache && cache[operand_idx] != nullptr && + cache[operand_idx]->IsInitialized()) { trans_t.ShareBufferWith(*(cache[operand_idx])); + VLOG(5) << "Cache Used!"; } else { auto reduct_t = PerformReduction( dev_ctx, t, perm, all_labels, ellipsis, label2type); trans_t = PerformTranspose( dev_ctx, reduct_t, perm, reordered_all_labels, ellipsis, label2type); - cache[operand_idx]->ShareBufferWith(trans_t); + if (cache[operand_idx] != nullptr) + cache[operand_idx]->ShareBufferWith(trans_t); } auto mul_dims = GetShapeByType(all_labels, label2type, @@ -515,18 +527,23 @@ void TransposeToOutput(const Context& dev_ctx, axis.push_back(it - all_labels.begin() + offset); } } - if (is_no_need_transpose(axis)) return output->ShareBufferWith(to_trans); + if (is_no_need_transpose(axis)) { + output->ShareBufferWith(to_trans); + return; + } VLOG(5) << "call TransposeToOutput: with axis: " << paddle::string::join_strings(axis, ","); - return TransposeKernel(dev_ctx, to_trans, axis, output); + TransposeKernel(dev_ctx, to_trans, axis, output); } template void EinsumKernelImpl(const Context& dev_ctx, + const std::vector& forward_all_labels, const std::vector& inputs, const std::string& equation, DenseTensor* out, - std::vector cache) { + std::vector cache, + bool is_forward = true) { ValidationCheck(equation); // collect the following informations to prepare einsum. LabelMap labelshape(0); @@ -542,6 +559,9 @@ void EinsumKernelImpl(const Context& dev_ctx, input_dims.push_back(i->dims()); } std::string right; + if (!is_forward) { + all_labels = forward_all_labels; + } ParseEinsumEquation(equation, input_dims, &labelshape, @@ -557,7 +577,6 @@ void EinsumKernelImpl(const Context& dev_ctx, auto& A = inputs[0]; auto& B = inputs[1]; // Reduction and Contract Procedure - dev_ctx.template Alloc(out); auto after_contraction = PerformContraction(dev_ctx, *A, *B, @@ -567,7 +586,8 @@ void EinsumKernelImpl(const Context& dev_ctx, labelshape, ellipsis_dims, broadcast_dims, - cache); + cache, + !is_forward); TransposeToOutput(dev_ctx, after_contraction, right, @@ -599,18 +619,37 @@ void EinsumKernelImpl(const Context& dev_ctx, } } +template +void EinsumKernelRaw(const Context& dev_ctx, + const std::vector& inputs, + const std::string& equation, + DenseTensor* out, + std::vector cache) { + std::vector tmp; + // for the sake of compatibility, we may load and run v2.3 EinsumOp. Output + // may have nullptr and the cache.size() is not equal to inputs.size(). refer + // to BuildPhiKernelContext for details. + int diff = inputs.size() - cache.size(); + for (int i = 0; i < diff; ++i) { + cache.push_back(nullptr); + } + EinsumKernelImpl( + dev_ctx, tmp, inputs, equation, out, cache, /*forward=*/true); +} + template void EinsumKernel(const Context& dev_ctx, const std::vector& inputs, const std::string& equation, DenseTensor* out) { - std::vector cache(inputs.size()); // set empty; TA, TB, TdC + std::vector place_holder; std::vector cache_tensor( inputs.size()); // set empty; TA, TB, TdC for (size_t i = 0; i < inputs.size(); ++i) { - cache_tensor[i] = &cache[i]; + cache_tensor[i] = nullptr; } - EinsumKernelImpl(dev_ctx, inputs, equation, out, cache_tensor); + EinsumKernelImpl( + dev_ctx, place_holder, inputs, equation, out, cache_tensor, true); } } // namespace phi diff --git a/paddle/phi/ops/compat/einsum_sig.cc b/paddle/phi/ops/compat/einsum_sig.cc index 0b3cc3425df..5e45bcf97ce 100644 --- a/paddle/phi/ops/compat/einsum_sig.cc +++ b/paddle/phi/ops/compat/einsum_sig.cc @@ -17,14 +17,15 @@ limitations under the License. */ namespace phi { KernelSignature EinsumOpArgumentMapping(const ArgumentMappingContext& ctx) { - return KernelSignature("einsum", {"Operands"}, {"equation"}, {"Out"}); + return KernelSignature( + "einsum", {"Operands"}, {"equation"}, {"Out", "InnerCache"}); } KernelSignature EinsumGradOpArgumentMapping(const ArgumentMappingContext& ctx) { return KernelSignature("einsum_grad", - {"Operands", {"Out@GRAD"}}, + {"Operands", "InnerCache", "Out@GRAD"}, {"equation"}, - {{"Operands@GRAD"}}); + {"Operands@GRAD"}); } } // namespace phi diff --git a/python/paddle/fluid/tests/unittests/test_einsum_op.py b/python/paddle/fluid/tests/unittests/test_einsum_op.py index 565e43214ea..1a4ae54afef 100644 --- a/python/paddle/fluid/tests/unittests/test_einsum_op.py +++ b/python/paddle/fluid/tests/unittests/test_einsum_op.py @@ -34,7 +34,11 @@ class TestEinsumBinary(OpTest): self.operands.append(("x" + str(idx), inp)) self.inputs = {"Operands": self.operands} self.attrs = {"equation": self.equation} - self.outputs = {'Out': out} + self.outputs = { + 'Out': out, + "InnerCache": [('cache_' + str(i), np.array([1.0])) + for i in range(len(self.operands))] + } def init_input(self): self.inputs = [] @@ -49,7 +53,7 @@ class TestEinsumBinary(OpTest): def test_check_output(self): if not self.disable: - self.check_output() + self.check_output(no_check_set=["InnerCache"]) def test_grad(self): if not self.disable: diff --git a/python/paddle/fluid/tests/unittests/white_list/no_check_set_white_list.py b/python/paddle/fluid/tests/unittests/white_list/no_check_set_white_list.py index 23bbc377cae..ea3264ba0db 100644 --- a/python/paddle/fluid/tests/unittests/white_list/no_check_set_white_list.py +++ b/python/paddle/fluid/tests/unittests/white_list/no_check_set_white_list.py @@ -35,4 +35,5 @@ no_check_set_white_list = [ 'eigh', 'eigvalsh', 'class_center_sample', + 'einsum', ] diff --git a/python/paddle/tensor/einsum.py b/python/paddle/tensor/einsum.py index 713a611f9f3..4cdbebb0552 100644 --- a/python/paddle/tensor/einsum.py +++ b/python/paddle/tensor/einsum.py @@ -798,11 +798,12 @@ def gen_einsum_op(equation, *operands): """ assert len(operands) <= 2, "Only support two operands in EinsumOp." if in_dygraph_mode(): - return _C_ops.final_state_einsum(operands, equation) + return _C_ops.final_state_einsum(operands, equation)[0] if _in_legacy_dygraph(): # dygraph - return _C_ops.einsum(operands, 'equation', equation) + return _C_ops.einsum(operands, len(operands), 'equation', equation)[0] + # static graph for inp in operands: check_variable_and_dtype(inp, 'dtype', ['float32', 'float64'], 'einsum') @@ -811,11 +812,16 @@ def gen_einsum_op(equation, *operands): out = helper.create_variable_for_type_inference(dtype=operands[0].dtype) attrs = dict() attrs['equation'] = equation + caches = [ + helper.create_variable_for_type_inference(dtype=operands[0].dtype) + for i in range(len(operands)) + ] helper.append_op( type='einsum', inputs={'Operands': operands}, - outputs={'Out': out}, - attrs=attrs, ) + outputs={'Out': out, + "InnerCache": caches}, + attrs=attrs) return out diff --git a/python/paddle/utils/code_gen/api.yaml b/python/paddle/utils/code_gen/api.yaml index f9e0efa5950..c5418916628 100644 --- a/python/paddle/utils/code_gen/api.yaml +++ b/python/paddle/utils/code_gen/api.yaml @@ -585,7 +585,7 @@ - api : einsum args : (Tensor[] x, str equation) - output : Tensor + output : Tensor, Tensor[]{x.size()} infer_meta : func : EinsumInferMeta param : [x, equation] diff --git a/python/paddle/utils/code_gen/api_base.py b/python/paddle/utils/code_gen/api_base.py index ac9a4315937..146925ccef6 100644 --- a/python/paddle/utils/code_gen/api_base.py +++ b/python/paddle/utils/code_gen/api_base.py @@ -224,16 +224,18 @@ class BaseAPI(object): if len(temp_list) == 1: out_type, out_name, size_expr = parse_output_item(temp_list[0]) - return [out_type], [out_name], size_expr + return [out_type], [out_name], [size_expr] else: out_type_list = [] out_name_list = [] + out_size_expr_list = [] for output_item in temp_list: out_type, out_name, size_expr = parse_output_item(output_item) out_type_list.append(out_type) out_name_list.append(out_name) + out_size_expr_list.append(size_expr) - return out_type_list, out_name_list, size_expr + return out_type_list, out_name_list, out_size_expr_list def parse_infer_meta(self, infer_meta_config): infer_meta = infer_meta_config diff --git a/python/paddle/utils/code_gen/api_gen.py b/python/paddle/utils/code_gen/api_gen.py index 4e98985c9b1..c0923adf39c 100644 --- a/python/paddle/utils/code_gen/api_gen.py +++ b/python/paddle/utils/code_gen/api_gen.py @@ -111,10 +111,10 @@ class ForwardAPI(BaseAPI): {code_indent} {return_type} api_output{inplace_assign};""" if return_type == 'std::vector': - assert self.outputs['out_size_expr'] is not None, \ + assert self.outputs['out_size_expr'][0] is not None, \ f"{api_name}: The out size expr : '{{expr}}' should be set when output has Tensor[]. You can refer 'split' api." output_create = output_create + f""" -{code_indent} auto kernel_out = {set_out_func}({self.outputs['out_size_expr']}, kernel_backend, &api_output);""" +{code_indent} auto kernel_out = {set_out_func}({self.outputs['out_size_expr'][0]}, kernel_backend, &api_output);""" else: output_create = output_create + f""" diff --git a/python/paddle/utils/code_gen/backward.yaml b/python/paddle/utils/code_gen/backward.yaml index eb00e2e615f..81c211e6407 100644 --- a/python/paddle/utils/code_gen/backward.yaml +++ b/python/paddle/utils/code_gen/backward.yaml @@ -552,8 +552,8 @@ skip_transform : out_w, out_w_grad - backward_api : einsum_grad - forward : einsum (Tensor[] x, str equation) -> Tensor(out) - args : (Tensor[] x, Tensor out_grad, str equation) + forward : einsum (Tensor[] x, str equation) -> Tensor(out), Tensor[](inner_cache) + args : (Tensor[] x, Tensor[] inner_cache, Tensor out_grad, str equation) output : Tensor[](x_grad){x.size()} infer_meta : func : UnchangedMultiInferMeta -- GitLab