From 0bfba16b66ca5c496760b01ffac56c788444decb Mon Sep 17 00:00:00 2001 From: hong <43953930+phlrain@users.noreply.github.com> Date: Fri, 4 Mar 2022 14:51:57 +0800 Subject: [PATCH] Add digamma abs trunc yaml (#40024) * add digamma, abs, trunc; test=develop * fix bug and add diagonal; test=develop * add name coverter; test=develop * update tracer.py; test=develop * add test case; test=develop * fix bugs; test=develop --- paddle/fluid/operators/diagonal_op.cc | 77 ++----------------- paddle/phi/infermeta/backward.h | 1 + paddle/phi/infermeta/unary.cc | 75 ++++++++++++++++++ paddle/phi/infermeta/unary.h | 3 + paddle/phi/kernels/cpu/norm_grad_kernel.cc | 2 +- paddle/phi/kernels/digamma_grad_kernel.h | 2 +- paddle/phi/kernels/gpu/norm_grad_kernel.cu | 2 +- .../kernels/impl/digamma_grad_kernel_impl.h | 2 +- paddle/phi/kernels/norm_grad_kernel.h | 2 +- paddle/phi/ops/compat/digamma_sig.cc | 2 +- paddle/phi/ops/compat/norm_sig.cc | 2 +- python/paddle/fluid/dygraph/tracer.py | 23 ++++++ .../fluid/layers/layer_function_generator.py | 2 +- .../tests/unittests/test_activation_op.py | 2 +- .../fluid/tests/unittests/test_diagonal_op.py | 19 +++++ .../fluid/tests/unittests/test_trunc_op.py | 10 +++ python/paddle/tensor/math.py | 8 +- python/paddle/utils/code_gen/api.yaml | 46 +++++++++++ python/paddle/utils/code_gen/backward.yaml | 55 +++++++++++++ 19 files changed, 256 insertions(+), 79 deletions(-) diff --git a/paddle/fluid/operators/diagonal_op.cc b/paddle/fluid/operators/diagonal_op.cc index b419f629a1e..20813f8bb44 100644 --- a/paddle/fluid/operators/diagonal_op.cc +++ b/paddle/fluid/operators/diagonal_op.cc @@ -12,7 +12,10 @@ // See the License for the specific language governing permissions and // limitations under the License. +#include "paddle/fluid/framework/infershape_utils.h" #include "paddle/fluid/framework/op_registry.h" +#include "paddle/phi/core/infermeta_utils.h" +#include "paddle/phi/infermeta/unary.h" namespace paddle { namespace operators { @@ -20,74 +23,6 @@ namespace operators { class DiagonalOp : public framework::OperatorWithKernel { public: using framework::OperatorWithKernel::OperatorWithKernel; - - void InferShape(framework::InferShapeContext *ctx) const override { - OP_INOUT_CHECK(ctx->HasInput("Input"), "Input", "Input", "diagonal"); - OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "diagonal"); - - int offset_ = ctx->Attrs().Get("offset"); - int axis1 = ctx->Attrs().Get("axis1"); - int axis2 = ctx->Attrs().Get("axis2"); - - auto x_dims = ctx->GetInputDim("Input"); - int axis1_ = axis1 < 0 ? x_dims.size() + axis1 : axis1; - int axis2_ = axis2 < 0 ? x_dims.size() + axis2 : axis2; - - PADDLE_ENFORCE_GE( - x_dims.size(), 2, - platform::errors::OutOfRange("Input's dim is out of range (expected at " - "least 2 dimensions, but got %ld).", - x_dims.size())); - PADDLE_ENFORCE_LT( - axis1_, x_dims.size(), - platform::errors::OutOfRange( - "Attr(axis1) is out of range (expected to be in range of [%ld, " - "%ld], but got %ld).", - -(x_dims.size()), (x_dims.size() - 1), axis1)); - PADDLE_ENFORCE_LT( - axis2_, x_dims.size(), - platform::errors::OutOfRange( - "Attr(axis2) is out of range (expected to be in range of [%ld, " - "%ld], but got %ld).", - -(x_dims.size()), (x_dims.size() - 1), axis2)); - PADDLE_ENFORCE_NE(axis1_, axis2_, - platform::errors::InvalidArgument( - "The dimensions should not be identical " - "%d vs %d.", - axis1, axis2)); - - auto out_dims = vectorize(x_dims); - // from out_dims get the dim size of axis1_. 
- auto axis1_size = out_dims[axis1_]; - auto axis2_size = out_dims[axis2_]; - // delete two dims by attr axis1 and axis2 from out_dims. - /* example: - out_dim = [2, 3, 4]; - axis1 = 0; - axis2 = 1; - according to the attr of axis1 and axis2, we get: - out_dim = [4]. - */ - out_dims.erase(out_dims.begin() + std::max(axis1_, axis2_)); - out_dims.erase(out_dims.begin() + std::min(axis1_, axis2_)); - - if (offset_ == 0) { - out_dims.push_back(std::min(axis1_size, axis2_size)); - } else if (offset_ > 0) { - if ((axis2_size - offset_) > 0) { - out_dims.push_back(std::min(axis1_size, axis2_size - offset_)); - } else { - out_dims.push_back(0); - } - } else { - if ((axis1_size + offset_) > 0) { - out_dims.push_back(std::min(axis1_size + offset_, axis2_size)); - } else { - out_dims.push_back(0); - } - } - ctx->SetOutputDim("Out", phi::make_ddim(out_dims)); - } }; class DiagonalOpMaker : public framework::OpProtoAndCheckerMaker { @@ -170,9 +105,13 @@ DECLARE_NO_NEED_BUFFER_VARS_INFERER(DiagonalGradNoNeedBufferVarsInferer, namespace ops = paddle::operators; +DELCARE_INFER_SHAPE_FUNCTOR(diagonal, DiagonalInferShapeFunctor, + PT_INFER_META(phi::DiagonalInferMeta)); + REGISTER_OPERATOR(diagonal, ops::DiagonalOp, ops::DiagonalOpMaker, ops::DiagonalGradOpMaker, - ops::DiagonalGradOpMaker); + ops::DiagonalGradOpMaker, + DiagonalInferShapeFunctor); REGISTER_OPERATOR(diagonal_grad, ops::DiagonalGradOp, ops::DiagonalGradNoNeedBufferVarsInferer) diff --git a/paddle/phi/infermeta/backward.h b/paddle/phi/infermeta/backward.h index c7090ed664b..f2c0cf8a689 100644 --- a/paddle/phi/infermeta/backward.h +++ b/paddle/phi/infermeta/backward.h @@ -17,6 +17,7 @@ limitations under the License. */ #include #include "paddle/phi/core/meta_tensor.h" +#include "paddle/phi/infermeta/unary.h" namespace phi { diff --git a/paddle/phi/infermeta/unary.cc b/paddle/phi/infermeta/unary.cc index ff58c53ad9b..85db1547f16 100644 --- a/paddle/phi/infermeta/unary.cc +++ b/paddle/phi/infermeta/unary.cc @@ -706,6 +706,81 @@ void TraceInferMeta( out->set_dims(phi::make_ddim(sizes)); } +void DiagonalInferMeta(const MetaTensor& input, + int offset, + int axis1, + int axis2, + MetaTensor* out) { + auto x_dims = input.dims(); + int offset_ = offset; + int axis1_ = axis1 < 0 ? x_dims.size() + axis1 : axis1; + int axis2_ = axis2 < 0 ? x_dims.size() + axis2 : axis2; + + PADDLE_ENFORCE_GE( + x_dims.size(), + 2, + phi::errors::OutOfRange("Input's dim is out of range (expected at " + "least 2 dimensions, but got %ld).", + x_dims.size())); + PADDLE_ENFORCE_LT( + axis1_, + x_dims.size(), + phi::errors::OutOfRange( + "Attr(axis1) is out of range (expected to be in range of [%ld, " + "%ld], but got %ld).", + -(x_dims.size()), + (x_dims.size() - 1), + axis1)); + PADDLE_ENFORCE_LT( + axis2_, + x_dims.size(), + phi::errors::OutOfRange( + "Attr(axis2) is out of range (expected to be in range of [%ld, " + "%ld], but got %ld).", + -(x_dims.size()), + (x_dims.size() - 1), + axis2)); + PADDLE_ENFORCE_NE( + axis1_, + axis2_, + phi::errors::InvalidArgument("The dimensions should not be identical " + "%d vs %d.", + axis1, + axis2)); + + auto out_dims = vectorize(x_dims); + // from out_dims get the dim size of axis1_. + auto axis1_size = out_dims[axis1_]; + auto axis2_size = out_dims[axis2_]; + // delete two dims by attr axis1 and axis2 from out_dims. + /* example: + out_dim = [2, 3, 4]; + axis1 = 0; + axis2 = 1; + according to the attr of axis1 and axis2, we get: + out_dim = [4]. 
+ */ + out_dims.erase(out_dims.begin() + std::max(axis1_, axis2_)); + out_dims.erase(out_dims.begin() + std::min(axis1_, axis2_)); + + if (offset_ == 0) { + out_dims.push_back(std::min(axis1_size, axis2_size)); + } else if (offset_ > 0) { + if ((axis2_size - offset_) > 0) { + out_dims.push_back(std::min(axis1_size, axis2_size - offset_)); + } else { + out_dims.push_back(0); + } + } else { + if ((axis1_size + offset_) > 0) { + out_dims.push_back(std::min(axis1_size + offset_, axis2_size)); + } else { + out_dims.push_back(0); + } + } + out->set_dims(phi::make_ddim(out_dims)); +} + void UnfoldInferMeta(const MetaTensor& x, const std::vector& kernel_sizes, const std::vector& strides, diff --git a/paddle/phi/infermeta/unary.h b/paddle/phi/infermeta/unary.h index 97ec6f7fa58..d4e21fbd824 100644 --- a/paddle/phi/infermeta/unary.h +++ b/paddle/phi/infermeta/unary.h @@ -140,6 +140,9 @@ void DiagInferMeta(const MetaTensor& x, void SizeInferMeta(const MetaTensor& input, MetaTensor* out); +void DiagonalInferMeta( + const MetaTensor& input, int offset, int axis1, int axis2, MetaTensor* out); + void PixelShuffleInferMeta(const MetaTensor& x, int upscale_factor, const std::string& data_format, diff --git a/paddle/phi/kernels/cpu/norm_grad_kernel.cc b/paddle/phi/kernels/cpu/norm_grad_kernel.cc index 597207a05a2..bd05e2c4c6e 100644 --- a/paddle/phi/kernels/cpu/norm_grad_kernel.cc +++ b/paddle/phi/kernels/cpu/norm_grad_kernel.cc @@ -26,9 +26,9 @@ namespace phi { template void NormGradKernel(const Context& ctx, - const DenseTensor& out_grad, const DenseTensor& x, const DenseTensor& norm, + const DenseTensor& out_grad, int axis, float epsilon, bool is_test, diff --git a/paddle/phi/kernels/digamma_grad_kernel.h b/paddle/phi/kernels/digamma_grad_kernel.h index 38912a5ccc4..ae5346080d3 100644 --- a/paddle/phi/kernels/digamma_grad_kernel.h +++ b/paddle/phi/kernels/digamma_grad_kernel.h @@ -20,8 +20,8 @@ namespace phi { template void DigammaGradKernel(const Context& ctx, - const DenseTensor& out_grad, const DenseTensor& x, + const DenseTensor& out_grad, DenseTensor* x_grad); } // namepsace phi diff --git a/paddle/phi/kernels/gpu/norm_grad_kernel.cu b/paddle/phi/kernels/gpu/norm_grad_kernel.cu index ab38a82eceb..43a08b0603e 100644 --- a/paddle/phi/kernels/gpu/norm_grad_kernel.cu +++ b/paddle/phi/kernels/gpu/norm_grad_kernel.cu @@ -75,9 +75,9 @@ __global__ void NormalizeGradient(const T* x, template void NormGradKernel(const Context& ctx, - const DenseTensor& out_grad, const DenseTensor& x, const DenseTensor& norm, + const DenseTensor& out_grad, int axis, float epsilon, bool is_test, diff --git a/paddle/phi/kernels/impl/digamma_grad_kernel_impl.h b/paddle/phi/kernels/impl/digamma_grad_kernel_impl.h index 74ded1569eb..92550de1800 100644 --- a/paddle/phi/kernels/impl/digamma_grad_kernel_impl.h +++ b/paddle/phi/kernels/impl/digamma_grad_kernel_impl.h @@ -38,8 +38,8 @@ struct DigammaGradFunctor { template void DigammaGradKernel(const Context& ctx, - const DenseTensor& out_grad, const DenseTensor& x, + const DenseTensor& out_grad, DenseTensor* x_grad) { x_grad->mutable_data(ctx.GetPlace()); diff --git a/paddle/phi/kernels/norm_grad_kernel.h b/paddle/phi/kernels/norm_grad_kernel.h index 7b09d6463d0..55714b8a4a0 100644 --- a/paddle/phi/kernels/norm_grad_kernel.h +++ b/paddle/phi/kernels/norm_grad_kernel.h @@ -20,9 +20,9 @@ namespace phi { template void NormGradKernel(const Context& ctx, - const DenseTensor& out_grad, const DenseTensor& x, const DenseTensor& out, + const DenseTensor& out_grad, int axis, float epsilon, bool 
is_test, diff --git a/paddle/phi/ops/compat/digamma_sig.cc b/paddle/phi/ops/compat/digamma_sig.cc index fa693f92c6f..12ef3056f1e 100644 --- a/paddle/phi/ops/compat/digamma_sig.cc +++ b/paddle/phi/ops/compat/digamma_sig.cc @@ -19,7 +19,7 @@ namespace phi { KernelSignature DigammaGradOpArgumentMapping( const ArgumentMappingContext& ctx) { return KernelSignature( - "digamma_grad", {GradVarName("Out"), "X"}, {}, {GradVarName("X")}); + "digamma_grad", {"X", GradVarName("Out")}, {}, {GradVarName("X")}); } } // namespace phi diff --git a/paddle/phi/ops/compat/norm_sig.cc b/paddle/phi/ops/compat/norm_sig.cc index 81d294b8424..a74db9b5686 100644 --- a/paddle/phi/ops/compat/norm_sig.cc +++ b/paddle/phi/ops/compat/norm_sig.cc @@ -23,7 +23,7 @@ KernelSignature NormOpArgumentMapping(const ArgumentMappingContext& ctx) { KernelSignature NormGradOpArgumentMapping(const ArgumentMappingContext& ctx) { return KernelSignature("norm_grad", - {GradVarName("Out"), "X", "Norm"}, + {"X", "Norm", GradVarName("Out")}, {"axis", "epsilon", "is_test"}, {GradVarName("X")}); } diff --git a/python/paddle/fluid/dygraph/tracer.py b/python/paddle/fluid/dygraph/tracer.py index e0c594b07ae..563cd433910 100644 --- a/python/paddle/fluid/dygraph/tracer.py +++ b/python/paddle/fluid/dygraph/tracer.py @@ -29,6 +29,29 @@ final_state_name_mapping = { "x": "X", "y": "Y", "out": "Out", + }, + "trunc": { + "final_op_name": "final_state_trunc", + "x": "X", + "out": "Out", + }, + "abs": { + "final_op_name": "final_state_abs", + "x": "X", + "out": "Out", + }, + "digamma": { + "final_op_name": "final_state_digamma", + "x": "X", + "out": "Out", + }, + "diagonal": { + "final_op_name": "final_state_diagonal", + "x": "Input", + "offset": "offset", + "axis1": "axis1", + "axis2": "axis2", + "out": "Out", } } diff --git a/python/paddle/fluid/layers/layer_function_generator.py b/python/paddle/fluid/layers/layer_function_generator.py index 56af7e341fd..676ee3e3c77 100755 --- a/python/paddle/fluid/layers/layer_function_generator.py +++ b/python/paddle/fluid/layers/layer_function_generator.py @@ -20,7 +20,7 @@ import string from six.moves import cStringIO from ..proto import framework_pb2 -from ..framework import OpProtoHolder, Variable, core, convert_np_dtype_to_dtype_, in_dygraph_mode +from ..framework import OpProtoHolder, Variable, core, convert_np_dtype_to_dtype_, in_dygraph_mode, _in_eager_mode from ..layer_helper import LayerHelper from ..data_feeder import check_variable_and_dtype from paddle import _C_ops diff --git a/python/paddle/fluid/tests/unittests/test_activation_op.py b/python/paddle/fluid/tests/unittests/test_activation_op.py index d3d8fdd7031..b4b5944e27c 100755 --- a/python/paddle/fluid/tests/unittests/test_activation_op.py +++ b/python/paddle/fluid/tests/unittests/test_activation_op.py @@ -983,7 +983,7 @@ class TestAbs(TestActivation): def test_check_grad(self): if self.dtype == np.float16: return - self.check_grad(['X'], 'Out') + self.check_grad(['X'], 'Out', check_eager=True) class TestCeil(TestActivation): diff --git a/python/paddle/fluid/tests/unittests/test_diagonal_op.py b/python/paddle/fluid/tests/unittests/test_diagonal_op.py index 4dab7c0df40..b4854aea52a 100644 --- a/python/paddle/fluid/tests/unittests/test_diagonal_op.py +++ b/python/paddle/fluid/tests/unittests/test_diagonal_op.py @@ -124,6 +124,25 @@ class TestDiagonalAPI(unittest.TestCase): self.assertEqual(np.allclose(out.numpy(), out_ref, rtol=1e-08), True) paddle.enable_static() + def test_api_eager(self): + paddle.disable_static(self.place) + with 
_test_eager_guard(): + x_tensor = paddle.to_tensor(self.x) + out = paddle.diagonal(x_tensor) + out2 = paddle.diagonal(x_tensor, offset=0, axis1=2, axis2=1) + out3 = paddle.diagonal(x_tensor, offset=1, axis1=0, axis2=1) + out4 = paddle.diagonal(x_tensor, offset=0, axis1=1, axis2=2) + out_ref = np.diagonal(self.x) + self.assertEqual(np.allclose(out.numpy(), out_ref, rtol=1e-08), True) + out2_ref = np.diagonal(self.x, offset=0, axis1=2, axis2=1) + self.assertEqual(np.allclose(out2.numpy(), out2_ref, rtol=1e-08), True) + out3_ref = np.diagonal(self.x, offset=1, axis1=0, axis2=1) + self.assertEqual(np.allclose(out3.numpy(), out3_ref, rtol=1e-08), True) + out4_ref = np.diagonal(self.x, offset=0, axis1=1, axis2=2) + self.assertEqual(np.allclose(out4.numpy(), out4_ref, rtol=1e-08), True) + + paddle.enable_static() + def test_api_eager_dygraph(self): with _test_eager_guard(): self.test_api_dygraph() diff --git a/python/paddle/fluid/tests/unittests/test_trunc_op.py b/python/paddle/fluid/tests/unittests/test_trunc_op.py index 08a35db3ac4..b70fa04adc1 100644 --- a/python/paddle/fluid/tests/unittests/test_trunc_op.py +++ b/python/paddle/fluid/tests/unittests/test_trunc_op.py @@ -79,6 +79,16 @@ class TestTruncAPI(unittest.TestCase): self.assertEqual(np.allclose(out.numpy(), out_ref, rtol=1e-08), True) paddle.enable_static() + def test_api_eager(self): + paddle.disable_static(self.place) + + with _test_eager_guard(): + x_tensor = paddle.to_tensor(self.x) + out = paddle.trunc(x_tensor) + out_ref = np.trunc(self.x) + self.assertEqual(np.allclose(out.numpy(), out_ref, rtol=1e-08), True) + paddle.enable_static() + def test_api_eager_dygraph(self): with _test_eager_guard(): self.test_api_dygraph() diff --git a/python/paddle/tensor/math.py b/python/paddle/tensor/math.py index ce29e9dce81..9a013910565 100755 --- a/python/paddle/tensor/math.py +++ b/python/paddle/tensor/math.py @@ -27,7 +27,7 @@ from paddle.tensor import cast from paddle.tensor.attribute import _complex_to_real_dtype import paddle from paddle.static import Variable -from ..framework import core +from ..framework import core, _in_eager_mode from ..framework import _varbase_creator, convert_np_dtype_to_dtype_ from ..fluid.layer_helper import LayerHelper from ..fluid.data_feeder import check_variable_and_dtype, check_type, check_dtype, convert_dtype @@ -1083,6 +1083,8 @@ def trunc(input, name=None): # [0., 0.]])) ''' if paddle.in_dynamic_mode(): + if _in_eager_mode(): + return _C_ops.final_state_trunc(input) return _C_ops.trunc(input) else: inputs = {"X": input} @@ -2425,6 +2427,8 @@ def diagonal(x, offset=0, axis1=0, axis2=1, name=None): """ if paddle.in_dynamic_mode(): + if _in_eager_mode(): + return _C_ops.final_state_diagonal(x, offset, axis1, axis2) return _C_ops.diagonal(x, 'offset', offset, 'axis1', axis1, 'axis2', axis2) def __check_input(input, offset, dim1, dim2): @@ -3184,6 +3188,8 @@ def digamma(x, name=None): """ if paddle.in_dynamic_mode(): + if _in_eager_mode(): + return _C_ops.final_state_digamma(x) return _C_ops.digamma(x) check_variable_and_dtype(x, 'x', ['float32', 'float64'], 'digamma') diff --git a/python/paddle/utils/code_gen/api.yaml b/python/paddle/utils/code_gen/api.yaml index 45a6aae5e6d..699e42f2373 100644 --- a/python/paddle/utils/code_gen/api.yaml +++ b/python/paddle/utils/code_gen/api.yaml @@ -193,3 +193,49 @@ args : (Tensor x, DataType dtype=DataType::UNDEFINED, Backend place=Backend::UNDEFINED) output : Tensor invoke : full_like(x, 0, dtype, place) + +- api : digamma + args : (Tensor x) + output : Tensor + infer_meta 
: + func : UnchangedInferMeta + kernel : + func : digamma + backward : digamma_grad + +- api : abs + args : (Tensor x) + output : Tensor + infer_meta : + func : UnchangedInferMeta + kernel : + func : abs + backward : abs_grad + +- api : trunc + args : (Tensor x) + output : Tensor + infer_meta : + func : UnchangedInferMeta + kernel : + func : trunc + backward : trunc_grad + +# - api : norm +# args : (Tensor x, int axis, float epsilon, bool is_test) +# output : Tensor(out), Tensor(norm) +# infer_meta : +# func : NormInferMeta +# kernel : +# func : norm +# intermediate : norm +# backward : norm_grad + +- api : diagonal + args : (Tensor x, int offset, int axis1, int axis2) + output : Tensor + infer_meta : + func : DiagonalInferMeta + kernel : + func : diagonal + backward : diagonal_grad diff --git a/python/paddle/utils/code_gen/backward.yaml b/python/paddle/utils/code_gen/backward.yaml index cdda5cb1f05..c69bbf35b97 100644 --- a/python/paddle/utils/code_gen/backward.yaml +++ b/python/paddle/utils/code_gen/backward.yaml @@ -25,6 +25,61 @@ output : Tensor(x_grad) invoke : scale(out_grad, scale, bias, bias_after_scale) +- backward_api : digamma_grad + forward : digamma (Tensor x) -> Tensor(out) + args : (Tensor x, Tensor out_grad) + output : Tensor(x_grad) + infer_meta : + func : UnchangedInferMeta + param : [x] + kernel : + func : digamma_grad + +- backward_api : abs_grad + forward : abs (Tensor x) -> Tensor(out) + args : (Tensor x, Tensor out_grad) + output : Tensor(x_grad) + infer_meta : + func : UnchangedInferMeta + param : [x] + kernel : + func : abs_grad + +- backward_api : trunc_grad + forward : trunc (Tensor x) -> Tensor(out) + args : (Tensor out_grad) + output : Tensor(x_grad) + infer_meta : + func : UnchangedInferMeta + param : [out_grad] + kernel : + func : trunc_grad + +# - backward_api : norm_grad +# forward : norm (Tensor x, int axis, float epsilon, bool is_test) -> Tensor(out), Tensor(norm) +# args : (Tensor out_grad, Tensor x, Tensor norm, int axis, float epsilon, bool is_test) +# output : Tensor(x_grad) +# infer_meta : +# func : UnchangedInferMeta +# param : [x] +# kernel : +# func : norm_grad + +- backward_api : diagonal_grad + forward : diagonal (Tensor x, int offset, int axis1, int axis2) -> Tensor(out) + args : (Tensor x, Tensor out_grad, int offset = 0, int axis1 = 0, int axis2 = 1) + output : Tensor(x_grad) + infer_meta : + func : UnchangedInferMeta + param : [x] + kernel : + func : diagonal_grad + +# - backward_api : split_grad +# forward : split (Tensor x, ScalarArray num_or_sections, Scalar axis) -> Tensor[](out) +# args : (Tensor[] out_grad, Scalar axis) +# output : Tensor(x_grad) +# invoke : concat( out_grad, axis) # TODO(zhangyunfei) The config of double grad and triple grad will be supported in the future. # - backward_api : matmul_triple_grad -- GitLab
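For reference only (not part of the patch above): a minimal sketch of how the new eager-mode ("final state") paths wired up by this change can be exercised, following the pattern of the unit tests added in test_trunc_op.py and test_diagonal_op.py. It assumes a Paddle build that contains this change and that _test_eager_guard is importable from paddle.fluid.framework, as it is in those tests.

import numpy as np
import paddle
from paddle.fluid.framework import _test_eager_guard  # assumed import, as used by the new tests

x = paddle.to_tensor(np.random.rand(2, 3, 4).astype('float64'))

with _test_eager_guard():
    # With eager mode active, these Python APIs take the _in_eager_mode() branch
    # added in python/paddle/tensor/math.py and dispatch to the generated
    # _C_ops.final_state_* kernels declared in api.yaml / backward.yaml.
    out_trunc = paddle.trunc(x)
    out_digamma = paddle.digamma(x)
    out_diag = paddle.diagonal(x, offset=1, axis1=1, axis2=2)

# DiagonalInferMeta shape rule: for input shape [2, 3, 4] with axis1=1, axis2=2,
# offset=1, the two diagonal axes are dropped and min(3, 4 - 1) = 3 is appended,
# giving an output shape of [2, 3].
assert out_diag.shape == [2, 3]
assert np.allclose(out_diag.numpy(), np.diagonal(x.numpy(), 1, 1, 2))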