From 9b6a02d4563cef827ebf03a3f010f214dcb0931d Mon Sep 17 00:00:00 2001 From: Chen Weihang Date: Fri, 1 Apr 2022 10:04:24 +0800 Subject: [PATCH] [Phi] Add shape and strided_slice yaml & Adapt eager mode (#41131) * add several yaml * polish strided slice kernel & add yaml * reorder yaml * add several yaml * revert yaml config change * resolve conflict * Update test_strided_slice_op.py --- paddle/fluid/operators/strided_slice_op.cc | 2 +- paddle/phi/infermeta/unary.cc | 31 +- paddle/phi/infermeta/unary.h | 12 +- .../kernels/cpu/strided_slice_grad_kernel.cc | 4 +- .../phi/kernels/cpu/strided_slice_kernel.cc | 4 +- .../kernels/gpu/strided_slice_grad_kernel.cu | 4 +- .../phi/kernels/gpu/strided_slice_kernel.cu | 4 +- .../impl/strided_slice_grad_kernel_impl.h | 20 +- .../kernels/impl/strided_slice_kernel_impl.h | 18 +- .../phi/kernels/strided_slice_grad_kernel.cc | 69 +++ .../phi/kernels/strided_slice_grad_kernel.h | 14 +- paddle/phi/kernels/strided_slice_kernel.cc | 60 +++ paddle/phi/kernels/strided_slice_kernel.h | 13 +- paddle/phi/ops/compat/strided_slice_sig.cc | 424 +++--------------- python/paddle/fluid/layers/nn.py | 10 +- .../fluid/tests/unittests/test_shape_op.py | 4 +- .../tests/unittests/test_strided_slice_op.py | 7 +- python/paddle/utils/code_gen/api.yaml | 17 + python/paddle/utils/code_gen/backward.yaml | 10 + 19 files changed, 317 insertions(+), 410 deletions(-) create mode 100644 paddle/phi/kernels/strided_slice_grad_kernel.cc create mode 100644 paddle/phi/kernels/strided_slice_kernel.cc diff --git a/paddle/fluid/operators/strided_slice_op.cc b/paddle/fluid/operators/strided_slice_op.cc index 0ff7d654fc..6f092bbef0 100644 --- a/paddle/fluid/operators/strided_slice_op.cc +++ b/paddle/fluid/operators/strided_slice_op.cc @@ -228,7 +228,7 @@ DECLARE_NO_NEED_BUFFER_VARS_INFERER(StridedSliceOpGradNoNeedBufferVarsInferer, namespace ops = paddle::operators; DECLARE_INFER_SHAPE_FUNCTOR(strided_slice, StridedSliceInferShape, - PD_INFER_META(phi::StridedSliceInferMeta)); + PD_INFER_META(phi::StridedSliceRawInferMeta)); REGISTER_OPERATOR(strided_slice, ops::StridedSliceOp, ops::StridedSliceOpMaker, ops::StridedSliceOpGradMaker, diff --git a/paddle/phi/infermeta/unary.cc b/paddle/phi/infermeta/unary.cc index d763b23ef5..6bf7a36b06 100644 --- a/paddle/phi/infermeta/unary.cc +++ b/paddle/phi/infermeta/unary.cc @@ -1922,15 +1922,15 @@ void SqueezeInferMeta(const MetaTensor& x, out->set_dtype(x.dtype()); } -void StridedSliceInferMeta(const MetaTensor& x, - const std::vector& axes, - const IntArray& starts, - const IntArray& ends, - const IntArray& strides, - const std::vector& infer_flags, - const std::vector& decrease_axis, - MetaTensor* out, - MetaConfig config) { +void StridedSliceRawInferMeta(const MetaTensor& x, + const std::vector& axes, + const IntArray& starts, + const IntArray& ends, + const IntArray& strides, + const std::vector& infer_flags, + const std::vector& decrease_axis, + MetaTensor* out, + MetaConfig config) { auto in_dims = x.dims(); PADDLE_ENFORCE_LT( in_dims.size(), @@ -2052,6 +2052,19 @@ void StridedSliceInferMeta(const MetaTensor& x, out->set_dtype(x.dtype()); } +void StridedSliceInferMeta(const MetaTensor& x, + const std::vector& axes, + const IntArray& starts, + const IntArray& ends, + const IntArray& strides, + MetaTensor* out, + MetaConfig config) { + std::vector infer_flags(axes.size(), 1); + std::vector decrease_axis; + StridedSliceRawInferMeta( + x, axes, starts, ends, strides, infer_flags, decrease_axis, out, config); +} + /* Why not use SumRawInferMeta directly? Because we need make InferMetaFunction's args follow the design of api.yaml */ diff --git a/paddle/phi/infermeta/unary.h b/paddle/phi/infermeta/unary.h index 7ab0f3df2a..54f70d8d55 100644 --- a/paddle/phi/infermeta/unary.h +++ b/paddle/phi/infermeta/unary.h @@ -284,13 +284,21 @@ void SqueezeInferMeta(const MetaTensor& x, MetaTensor* xshape, MetaTensor* out); +void StridedSliceRawInferMeta(const MetaTensor& x, + const std::vector& axes, + const IntArray& starts, + const IntArray& ends, + const IntArray& strides, + const std::vector& infer_flags, + const std::vector& decrease_axis, + MetaTensor* out, + MetaConfig config = MetaConfig()); + void StridedSliceInferMeta(const MetaTensor& x, const std::vector& axes, const IntArray& starts, const IntArray& ends, const IntArray& strides, - const std::vector& infer_flags, - const std::vector& decrease_axis, MetaTensor* out, MetaConfig config = MetaConfig()); diff --git a/paddle/phi/kernels/cpu/strided_slice_grad_kernel.cc b/paddle/phi/kernels/cpu/strided_slice_grad_kernel.cc index cdc5534d63..e6c812cf6b 100644 --- a/paddle/phi/kernels/cpu/strided_slice_grad_kernel.cc +++ b/paddle/phi/kernels/cpu/strided_slice_grad_kernel.cc @@ -19,10 +19,10 @@ #include "paddle/phi/core/kernel_registry.h" #include "paddle/phi/kernels/impl/strided_slice_grad_kernel_impl.h" -PD_REGISTER_KERNEL(strided_slice_grad, +PD_REGISTER_KERNEL(strided_slice_raw_grad, CPU, ALL_LAYOUT, - phi::StridedSliceGradKernel, + phi::StridedSliceRawGradKernel, bool, int, int64_t, diff --git a/paddle/phi/kernels/cpu/strided_slice_kernel.cc b/paddle/phi/kernels/cpu/strided_slice_kernel.cc index f34a3301fc..d0aa7b2f4c 100644 --- a/paddle/phi/kernels/cpu/strided_slice_kernel.cc +++ b/paddle/phi/kernels/cpu/strided_slice_kernel.cc @@ -19,10 +19,10 @@ #include "paddle/phi/core/kernel_registry.h" #include "paddle/phi/kernels/impl/strided_slice_kernel_impl.h" -PD_REGISTER_KERNEL(strided_slice, +PD_REGISTER_KERNEL(strided_slice_raw, CPU, ALL_LAYOUT, - phi::StridedSliceKernel, + phi::StridedSliceRawKernel, bool, int, int64_t, diff --git a/paddle/phi/kernels/gpu/strided_slice_grad_kernel.cu b/paddle/phi/kernels/gpu/strided_slice_grad_kernel.cu index 5f31d48853..90d9f1d986 100644 --- a/paddle/phi/kernels/gpu/strided_slice_grad_kernel.cu +++ b/paddle/phi/kernels/gpu/strided_slice_grad_kernel.cu @@ -19,10 +19,10 @@ #include "paddle/phi/core/kernel_registry.h" #include "paddle/phi/kernels/impl/strided_slice_grad_kernel_impl.h" -PD_REGISTER_KERNEL(strided_slice_grad, +PD_REGISTER_KERNEL(strided_slice_raw_grad, GPU, ALL_LAYOUT, - phi::StridedSliceGradKernel, + phi::StridedSliceRawGradKernel, bool, int, int64_t, diff --git a/paddle/phi/kernels/gpu/strided_slice_kernel.cu b/paddle/phi/kernels/gpu/strided_slice_kernel.cu index ff10718edb..716150ff47 100644 --- a/paddle/phi/kernels/gpu/strided_slice_kernel.cu +++ b/paddle/phi/kernels/gpu/strided_slice_kernel.cu @@ -19,10 +19,10 @@ #include "paddle/phi/core/kernel_registry.h" #include "paddle/phi/kernels/impl/strided_slice_kernel_impl.h" -PD_REGISTER_KERNEL(strided_slice, +PD_REGISTER_KERNEL(strided_slice_raw, GPU, ALL_LAYOUT, - phi::StridedSliceKernel, + phi::StridedSliceRawKernel, bool, int, int64_t, diff --git a/paddle/phi/kernels/impl/strided_slice_grad_kernel_impl.h b/paddle/phi/kernels/impl/strided_slice_grad_kernel_impl.h index f0fddce6b5..95780682c9 100644 --- a/paddle/phi/kernels/impl/strided_slice_grad_kernel_impl.h +++ b/paddle/phi/kernels/impl/strided_slice_grad_kernel_impl.h @@ -20,16 +20,16 @@ namespace phi { template -void StridedSliceGradKernel(const Context& dev_ctx, - const DenseTensor& x, - const DenseTensor& out_grad, - const std::vector& axes, - const IntArray& starts, - const IntArray& ends, - const IntArray& strides, - const std::vector& infer_flags, - const std::vector& decrease_axis, - DenseTensor* x_grad) { +void StridedSliceRawGradKernel(const Context& dev_ctx, + const DenseTensor& x, + const DenseTensor& out_grad, + const std::vector& axes, + const IntArray& starts, + const IntArray& ends, + const IntArray& strides, + const std::vector& infer_flags, + const std::vector& decrease_axis, + DenseTensor* x_grad) { int rank = x.dims().size(); #define SLICE_CASE(Rank) \ case Rank: \ diff --git a/paddle/phi/kernels/impl/strided_slice_kernel_impl.h b/paddle/phi/kernels/impl/strided_slice_kernel_impl.h index 2df937524e..81e6d50562 100644 --- a/paddle/phi/kernels/impl/strided_slice_kernel_impl.h +++ b/paddle/phi/kernels/impl/strided_slice_kernel_impl.h @@ -20,15 +20,15 @@ namespace phi { template -void StridedSliceKernel(const Context& dev_ctx, - const DenseTensor& x, - const std::vector& axes, - const IntArray& starts, - const IntArray& ends, - const IntArray& strides, - const std::vector& infer_flags, - const std::vector& decrease_axis, - DenseTensor* out) { +void StridedSliceRawKernel(const Context& dev_ctx, + const DenseTensor& x, + const std::vector& axes, + const IntArray& starts, + const IntArray& ends, + const IntArray& strides, + const std::vector& infer_flags, + const std::vector& decrease_axis, + DenseTensor* out) { int rank = x.dims().size(); #define SLICE_CASE(Rank) \ case Rank: \ diff --git a/paddle/phi/kernels/strided_slice_grad_kernel.cc b/paddle/phi/kernels/strided_slice_grad_kernel.cc new file mode 100644 index 0000000000..38dd360ea6 --- /dev/null +++ b/paddle/phi/kernels/strided_slice_grad_kernel.cc @@ -0,0 +1,69 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/kernels/strided_slice_grad_kernel.h" + +#include "paddle/phi/core/kernel_registry.h" + +namespace phi { + +template +void StridedSliceGradKernel(const Context& dev_ctx, + const DenseTensor& x, + const DenseTensor& out_grad, + const std::vector& axes, + const IntArray& starts, + const IntArray& ends, + const IntArray& strides, + DenseTensor* x_grad) { + std::vector infer_flags(axes.size(), 1); + std::vector decrease_axis; + StridedSliceRawGradKernel(dev_ctx, + x, + out_grad, + axes, + starts, + ends, + strides, + infer_flags, + decrease_axis, + x_grad); +} + +} // namespace phi + +PD_REGISTER_KERNEL(strided_slice_grad, + CPU, + ALL_LAYOUT, + phi::StridedSliceGradKernel, + bool, + int, + int64_t, + float, + double, + phi::dtype::complex, + phi::dtype::complex) {} +#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) +PD_REGISTER_KERNEL(strided_slice_grad, + GPU, + ALL_LAYOUT, + phi::StridedSliceGradKernel, + bool, + int, + int64_t, + float, + double, + phi::dtype::complex, + phi::dtype::complex) {} +#endif diff --git a/paddle/phi/kernels/strided_slice_grad_kernel.h b/paddle/phi/kernels/strided_slice_grad_kernel.h index 07fba9d27b..21d01310b6 100644 --- a/paddle/phi/kernels/strided_slice_grad_kernel.h +++ b/paddle/phi/kernels/strided_slice_grad_kernel.h @@ -19,6 +19,18 @@ namespace phi { +template +void StridedSliceRawGradKernel(const Context& dev_ctx, + const DenseTensor& x, + const DenseTensor& out_grad, + const std::vector& axes, + const IntArray& starts, + const IntArray& ends, + const IntArray& strides, + const std::vector& infer_flags, + const std::vector& decrease_axis, + DenseTensor* x_grad); + template void StridedSliceGradKernel(const Context& dev_ctx, const DenseTensor& x, @@ -27,8 +39,6 @@ void StridedSliceGradKernel(const Context& dev_ctx, const IntArray& starts, const IntArray& ends, const IntArray& strides, - const std::vector& infer_flags, - const std::vector& decrease_axis, DenseTensor* x_grad); template diff --git a/paddle/phi/kernels/strided_slice_kernel.cc b/paddle/phi/kernels/strided_slice_kernel.cc new file mode 100644 index 0000000000..547d574cd7 --- /dev/null +++ b/paddle/phi/kernels/strided_slice_kernel.cc @@ -0,0 +1,60 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/kernels/strided_slice_kernel.h" + +#include "paddle/phi/core/kernel_registry.h" + +namespace phi { + +template +void StridedSliceKernel(const Context& dev_ctx, + const DenseTensor& x, + const std::vector& axes, + const IntArray& starts, + const IntArray& ends, + const IntArray& strides, + DenseTensor* out) { + std::vector infer_flags(axes.size(), 1); + std::vector decrease_axis; + StridedSliceRawKernel( + dev_ctx, x, axes, starts, ends, strides, infer_flags, decrease_axis, out); +} + +} // namespace phi + +PD_REGISTER_KERNEL(strided_slice, + CPU, + ALL_LAYOUT, + phi::StridedSliceKernel, + bool, + int, + int64_t, + float, + double, + phi::dtype::complex, + phi::dtype::complex) {} +#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) +PD_REGISTER_KERNEL(strided_slice, + GPU, + ALL_LAYOUT, + phi::StridedSliceKernel, + bool, + int, + int64_t, + float, + double, + phi::dtype::complex, + phi::dtype::complex) {} +#endif diff --git a/paddle/phi/kernels/strided_slice_kernel.h b/paddle/phi/kernels/strided_slice_kernel.h index fd90d81b85..2c8b373bf0 100644 --- a/paddle/phi/kernels/strided_slice_kernel.h +++ b/paddle/phi/kernels/strided_slice_kernel.h @@ -19,6 +19,17 @@ namespace phi { +template +void StridedSliceRawKernel(const Context& dev_ctx, + const DenseTensor& x, + const std::vector& axes, + const IntArray& starts, + const IntArray& ends, + const IntArray& strides, + const std::vector& infer_flags, + const std::vector& decrease_axis, + DenseTensor* out); + template void StridedSliceKernel(const Context& dev_ctx, const DenseTensor& x, @@ -26,8 +37,6 @@ void StridedSliceKernel(const Context& dev_ctx, const IntArray& starts, const IntArray& ends, const IntArray& strides, - const std::vector& infer_flags, - const std::vector& decrease_axis, DenseTensor* out); template diff --git a/paddle/phi/ops/compat/strided_slice_sig.cc b/paddle/phi/ops/compat/strided_slice_sig.cc index 70ce2e3e07..9fb70af0de 100644 --- a/paddle/phi/ops/compat/strided_slice_sig.cc +++ b/paddle/phi/ops/compat/strided_slice_sig.cc @@ -57,14 +57,14 @@ KernelSignature StridedSliceOpArgumentMapping( "decrease_axis"}; paddle::SmallVector outputs = {"Out"}; - std::string op_type; + std::string kernel_name; if (ctx.IsDenseTensorVectorInput("Input")) { - op_type = "strided_slice_array"; + kernel_name = "strided_slice_array"; } else { - op_type = "strided_slice"; + kernel_name = "strided_slice_raw"; } // NOTE(dev): Use this to avoid regularization. - KernelSignature sig(op_type, inputs, attrs, outputs); + KernelSignature sig(kernel_name, inputs, attrs, outputs); return sig; } @@ -106,15 +106,15 @@ KernelSignature StridedSliceGradOpArgumentMapping( "decrease_axis"}; paddle::SmallVector outputs = {GradVarName("Input")}; - std::string op_type; + std::string kernel_name; if (ctx.IsDenseTensorVectorInput("Input")) { - op_type = "strided_slice_array_grad"; + kernel_name = "strided_slice_array_grad"; } else { - op_type = "strided_slice_grad"; + kernel_name = "strided_slice_raw_grad"; } // NOTE(dev): Use this to avoid regularization. - KernelSignature sig(op_type, inputs, attrs, outputs); + KernelSignature sig(kernel_name, inputs, attrs, outputs); return sig; } @@ -132,573 +132,273 @@ NOTE: The following codes are for 'get_compat_kernel_signature.py' ############################ Forward ############################ -return KernelSignature("{strided_slice}", {"Input"}, +return KernelSignature("strided_slice_raw", {"Input"}, {"axes", "StartsTensor", "EndsTensor", "StartsTensor","infer_flags", "decrease_axis"}, {"Out"}); -return KernelSignature("{strided_slice}", {"Input"}, +return KernelSignature("strided_slice_raw", {"Input"}, {"axes", "StartsTensor", "EndsTensor", "StartsTensorList","infer_flags", "decrease_axis"}, {"Out"}); -return KernelSignature("{strided_slice}", {"Input"}, +return KernelSignature("strided_slice_raw", {"Input"}, {"axes", "StartsTensor", "EndsTensor", "starts","infer_flags", "decrease_axis"}, {"Out"}); -return KernelSignature("{strided_slice}", {"Input"}, +return KernelSignature("strided_slice_raw", {"Input"}, {"axes", "StartsTensor", "EndsTensorList", "StartsTensor","infer_flags", "decrease_axis"}, {"Out"}); -return KernelSignature("{strided_slice}", {"Input"}, +return KernelSignature("strided_slice_raw", {"Input"}, {"axes", "StartsTensor", "EndsTensorList", "StartsTensorList","infer_flags", "decrease_axis"}, {"Out"}); -return KernelSignature("{strided_slice}", {"Input"}, +return KernelSignature("strided_slice_raw", {"Input"}, {"axes", "StartsTensor", "EndsTensorList", "starts","infer_flags", "decrease_axis"}, {"Out"}); -return KernelSignature("{strided_slice}", {"Input"}, +return KernelSignature("strided_slice_raw", {"Input"}, {"axes", "StartsTensor", "ends", "StartsTensor","infer_flags", "decrease_axis"}, {"Out"}); -return KernelSignature("{strided_slice}", {"Input"}, +return KernelSignature("strided_slice_raw", {"Input"}, {"axes", "StartsTensor", "ends", "StartsTensorList","infer_flags", "decrease_axis"}, {"Out"}); -return KernelSignature("{strided_slice}", {"Input"}, +return KernelSignature("strided_slice_raw", {"Input"}, {"axes", "StartsTensor", "ends", "starts","infer_flags", "decrease_axis"}, {"Out"}); -return KernelSignature("{strided_slice}", {"Input"}, +return KernelSignature("strided_slice_raw", {"Input"}, {"axes", "StartsTensorList", "EndsTensor", "StartsTensor","infer_flags", "decrease_axis"}, {"Out"}); -return KernelSignature("{strided_slice}", {"Input"}, +return KernelSignature("strided_slice_raw", {"Input"}, {"axes", "StartsTensorList", "EndsTensor", "StartsTensorList","infer_flags", "decrease_axis"}, {"Out"}); -return KernelSignature("{strided_slice}", {"Input"}, +return KernelSignature("strided_slice_raw", {"Input"}, {"axes", "StartsTensorList", "EndsTensor", "starts","infer_flags", "decrease_axis"}, {"Out"}); -return KernelSignature("{strided_slice}", {"Input"}, +return KernelSignature("strided_slice_raw", {"Input"}, {"axes", "StartsTensorList", "EndsTensorList", "StartsTensor","infer_flags", "decrease_axis"}, {"Out"}); -return KernelSignature("{strided_slice}", {"Input"}, +return KernelSignature("strided_slice_raw", {"Input"}, {"axes", "StartsTensorList", "EndsTensorList", "StartsTensorList","infer_flags", "decrease_axis"}, {"Out"}); -return KernelSignature("{strided_slice}", {"Input"}, +return KernelSignature("strided_slice_raw", {"Input"}, {"axes", "StartsTensorList", "EndsTensorList", "starts","infer_flags", "decrease_axis"}, {"Out"}); -return KernelSignature("{strided_slice}", {"Input"}, +return KernelSignature("strided_slice_raw", {"Input"}, {"axes", "StartsTensorList", "ends", "StartsTensor","infer_flags", "decrease_axis"}, {"Out"}); -return KernelSignature("{strided_slice}", {"Input"}, +return KernelSignature("strided_slice_raw", {"Input"}, {"axes", "StartsTensorList", "ends", "StartsTensorList","infer_flags", "decrease_axis"}, {"Out"}); -return KernelSignature("{strided_slice}", {"Input"}, +return KernelSignature("strided_slice_raw", {"Input"}, {"axes", "StartsTensorList", "ends", "starts","infer_flags", "decrease_axis"}, {"Out"}); -return KernelSignature("{strided_slice}", {"Input"}, +return KernelSignature("strided_slice_raw", {"Input"}, {"axes", "starts", "EndsTensor", "StartsTensor","infer_flags", "decrease_axis"}, {"Out"}); -return KernelSignature("{strided_slice}", {"Input"}, +return KernelSignature("strided_slice_raw", {"Input"}, {"axes", "starts", "EndsTensor", "StartsTensorList","infer_flags", "decrease_axis"}, {"Out"}); -return KernelSignature("{strided_slice}", {"Input"}, +return KernelSignature("strided_slice_raw", {"Input"}, {"axes", "starts", "EndsTensor", "starts","infer_flags", "decrease_axis"}, {"Out"}); -return KernelSignature("{strided_slice}", {"Input"}, +return KernelSignature("strided_slice_raw", {"Input"}, {"axes", "starts", "EndsTensorList", "StartsTensor","infer_flags", "decrease_axis"}, {"Out"}); -return KernelSignature("{strided_slice}", {"Input"}, +return KernelSignature("strided_slice_raw", {"Input"}, {"axes", "starts", "EndsTensorList", "StartsTensorList","infer_flags", "decrease_axis"}, {"Out"}); -return KernelSignature("{strided_slice}", {"Input"}, +return KernelSignature("strided_slice_raw", {"Input"}, {"axes", "starts", "EndsTensorList", "starts","infer_flags", "decrease_axis"}, {"Out"}); -return KernelSignature("{strided_slice}", {"Input"}, +return KernelSignature("strided_slice_raw", {"Input"}, {"axes", "starts", "ends", "StartsTensor","infer_flags", "decrease_axis"}, {"Out"}); -return KernelSignature("{strided_slice}", {"Input"}, +return KernelSignature("strided_slice_raw", {"Input"}, {"axes", "starts", "ends", "StartsTensorList","infer_flags", "decrease_axis"}, {"Out"}); -return KernelSignature("{strided_slice}", {"Input"}, +return KernelSignature("strided_slice_raw", {"Input"}, {"axes", "starts", "ends", "starts","infer_flags", "decrease_axis"}, {"Out"}); -return KernelSignature("{strided_slice_array}", {"Input"}, +return KernelSignature("strided_slice_array", {"Input"}, {"axes", "StartsTensor", "EndsTensor", "StartsTensor","infer_flags", "decrease_axis"}, {"Out"}); -return KernelSignature("{strided_slice_array}", {"Input"}, +return KernelSignature("strided_slice_array", {"Input"}, {"axes", "StartsTensor", "EndsTensor", "StartsTensorList","infer_flags", "decrease_axis"}, {"Out"}); -return KernelSignature("{strided_slice_array}", {"Input"}, +return KernelSignature("strided_slice_array", {"Input"}, {"axes", "StartsTensor", "EndsTensor", "starts","infer_flags", "decrease_axis"}, {"Out"}); -return KernelSignature("{strided_slice_array}", {"Input"}, +return KernelSignature("strided_slice_array", {"Input"}, {"axes", "StartsTensor", "EndsTensorList", "StartsTensor","infer_flags", "decrease_axis"}, {"Out"}); -return KernelSignature("{strided_slice_array}", {"Input"}, +return KernelSignature("strided_slice_array", {"Input"}, {"axes", "StartsTensor", "EndsTensorList", "StartsTensorList","infer_flags", "decrease_axis"}, {"Out"}); -return KernelSignature("{strided_slice_array}", {"Input"}, +return KernelSignature("strided_slice_array", {"Input"}, {"axes", "StartsTensor", "EndsTensorList", "starts","infer_flags", "decrease_axis"}, {"Out"}); -return KernelSignature("{strided_slice_array}", {"Input"}, +return KernelSignature("strided_slice_array", {"Input"}, {"axes", "StartsTensor", "ends", "StartsTensor","infer_flags", "decrease_axis"}, {"Out"}); -return KernelSignature("{strided_slice_array}", {"Input"}, +return KernelSignature("strided_slice_array", {"Input"}, {"axes", "StartsTensor", "ends", "StartsTensorList","infer_flags", "decrease_axis"}, {"Out"}); -return KernelSignature("{strided_slice_array}", {"Input"}, +return KernelSignature("strided_slice_array", {"Input"}, {"axes", "StartsTensor", "ends", "starts","infer_flags", "decrease_axis"}, {"Out"}); -return KernelSignature("{strided_slice_array}", {"Input"}, +return KernelSignature("strided_slice_array", {"Input"}, {"axes", "StartsTensorList", "EndsTensor", "StartsTensor","infer_flags", "decrease_axis"}, {"Out"}); -return KernelSignature("{strided_slice_array}", {"Input"}, +return KernelSignature("strided_slice_array", {"Input"}, {"axes", "StartsTensorList", "EndsTensor", "StartsTensorList","infer_flags", "decrease_axis"}, {"Out"}); -return KernelSignature("{strided_slice_array}", {"Input"}, +return KernelSignature("strided_slice_array", {"Input"}, {"axes", "StartsTensorList", "EndsTensor", "starts","infer_flags", "decrease_axis"}, {"Out"}); -return KernelSignature("{strided_slice_array}", {"Input"}, +return KernelSignature("strided_slice_array", {"Input"}, {"axes", "StartsTensorList", "EndsTensorList", "StartsTensor","infer_flags", "decrease_axis"}, {"Out"}); -return KernelSignature("{strided_slice_array}", {"Input"}, +return KernelSignature("strided_slice_array", {"Input"}, {"axes", "StartsTensorList", "EndsTensorList", "StartsTensorList","infer_flags", "decrease_axis"}, {"Out"}); -return KernelSignature("{strided_slice_array}", {"Input"}, +return KernelSignature("strided_slice_array", {"Input"}, {"axes", "StartsTensorList", "EndsTensorList", "starts","infer_flags", "decrease_axis"}, {"Out"}); -return KernelSignature("{strided_slice_array}", {"Input"}, +return KernelSignature("strided_slice_array", {"Input"}, {"axes", "StartsTensorList", "ends", "StartsTensor","infer_flags", "decrease_axis"}, {"Out"}); -return KernelSignature("{strided_slice_array}", {"Input"}, +return KernelSignature("strided_slice_array", {"Input"}, {"axes", "StartsTensorList", "ends", "StartsTensorList","infer_flags", "decrease_axis"}, {"Out"}); -return KernelSignature("{strided_slice_array}", {"Input"}, +return KernelSignature("strided_slice_array", {"Input"}, {"axes", "StartsTensorList", "ends", "starts","infer_flags", "decrease_axis"}, {"Out"}); -return KernelSignature("{strided_slice_array}", {"Input"}, +return KernelSignature("strided_slice_array", {"Input"}, {"axes", "starts", "EndsTensor", "StartsTensor","infer_flags", "decrease_axis"}, {"Out"}); -return KernelSignature("{strided_slice_array}", {"Input"}, +return KernelSignature("strided_slice_array", {"Input"}, {"axes", "starts", "EndsTensor", "StartsTensorList","infer_flags", "decrease_axis"}, {"Out"}); -return KernelSignature("{strided_slice_array}", {"Input"}, +return KernelSignature("strided_slice_array", {"Input"}, {"axes", "starts", "EndsTensor", "starts","infer_flags", "decrease_axis"}, {"Out"}); -return KernelSignature("{strided_slice_array}", {"Input"}, +return KernelSignature("strided_slice_array", {"Input"}, {"axes", "starts", "EndsTensorList", "StartsTensor","infer_flags", "decrease_axis"}, {"Out"}); -return KernelSignature("{strided_slice_array}", {"Input"}, +return KernelSignature("strided_slice_array", {"Input"}, {"axes", "starts", "EndsTensorList", "StartsTensorList","infer_flags", "decrease_axis"}, {"Out"}); -return KernelSignature("{strided_slice_array}", {"Input"}, +return KernelSignature("strided_slice_array", {"Input"}, {"axes", "starts", "EndsTensorList", "starts","infer_flags", "decrease_axis"}, {"Out"}); -return KernelSignature("{strided_slice_array}", {"Input"}, +return KernelSignature("strided_slice_array", {"Input"}, {"axes", "starts", "ends", "StartsTensor","infer_flags", "decrease_axis"}, {"Out"}); -return KernelSignature("{strided_slice_array}", {"Input"}, +return KernelSignature("strided_slice_array", {"Input"}, {"axes", "starts", "ends", "StartsTensorList","infer_flags", "decrease_axis"}, {"Out"}); -return KernelSignature("{strided_slice_array}", {"Input"}, +return KernelSignature("strided_slice_array", {"Input"}, {"axes", "starts", "ends", "starts","infer_flags", "decrease_axis"}, {"Out"}); - -############################ Backward ############################ - - -return KernelSignature("{strided_slice_grad}", {"Input", GradVarName("Out")}, - {"axes", "StartsTensor", "EndsTensor", -"StartsTensor","infer_flags", "decrease_axis"}, - {GradVarName("Input")}); - -return KernelSignature("{strided_slice_grad}", {"Input", GradVarName("Out")}, - {"axes", "StartsTensor", "EndsTensor", -"StartsTensorList","infer_flags", "decrease_axis"}, - {GradVarName("Input")}); - -return KernelSignature("{strided_slice_grad}", {"Input", GradVarName("Out")}, - {"axes", "StartsTensor", "EndsTensor", "starts","infer_flags", -"decrease_axis"}, - {GradVarName("Input")}); - -return KernelSignature("{strided_slice_grad}", {"Input", GradVarName("Out")}, - {"axes", "StartsTensor", "EndsTensorList", -"StartsTensor","infer_flags", "decrease_axis"}, - {GradVarName("Input")}); - -return KernelSignature("{strided_slice_grad}", {"Input", GradVarName("Out")}, - {"axes", "StartsTensor", "EndsTensorList", -"StartsTensorList","infer_flags", "decrease_axis"}, - {GradVarName("Input")}); - -return KernelSignature("{strided_slice_grad}", {"Input", GradVarName("Out")}, - {"axes", "StartsTensor", "EndsTensorList", "starts","infer_flags", -"decrease_axis"}, - {GradVarName("Input")}); - -return KernelSignature("{strided_slice_grad}", {"Input", GradVarName("Out")}, - {"axes", "StartsTensor", "ends", "StartsTensor","infer_flags", -"decrease_axis"}, - {GradVarName("Input")}); - -return KernelSignature("{strided_slice_grad}", {"Input", GradVarName("Out")}, - {"axes", "StartsTensor", "ends", "StartsTensorList","infer_flags", -"decrease_axis"}, - {GradVarName("Input")}); - -return KernelSignature("{strided_slice_grad}", {"Input", GradVarName("Out")}, - {"axes", "StartsTensor", "ends", "starts","infer_flags", -"decrease_axis"}, - {GradVarName("Input")}); - -return KernelSignature("{strided_slice_grad}", {"Input", GradVarName("Out")}, - {"axes", "StartsTensorList", "EndsTensor", -"StartsTensor","infer_flags", "decrease_axis"}, - {GradVarName("Input")}); - -return KernelSignature("{strided_slice_grad}", {"Input", GradVarName("Out")}, - {"axes", "StartsTensorList", "EndsTensor", -"StartsTensorList","infer_flags", "decrease_axis"}, - {GradVarName("Input")}); - -return KernelSignature("{strided_slice_grad}", {"Input", GradVarName("Out")}, - {"axes", "StartsTensorList", "EndsTensor", "starts","infer_flags", -"decrease_axis"}, - {GradVarName("Input")}); - -return KernelSignature("{strided_slice_grad}", {"Input", GradVarName("Out")}, - {"axes", "StartsTensorList", "EndsTensorList", -"StartsTensor","infer_flags", "decrease_axis"}, - {GradVarName("Input")}); - -return KernelSignature("{strided_slice_grad}", {"Input", GradVarName("Out")}, - {"axes", "StartsTensorList", "EndsTensorList", -"StartsTensorList","infer_flags", "decrease_axis"}, - {GradVarName("Input")}); - -return KernelSignature("{strided_slice_grad}", {"Input", GradVarName("Out")}, - {"axes", "StartsTensorList", "EndsTensorList", -"starts","infer_flags", "decrease_axis"}, - {GradVarName("Input")}); - -return KernelSignature("{strided_slice_grad}", {"Input", GradVarName("Out")}, - {"axes", "StartsTensorList", "ends", "StartsTensor","infer_flags", -"decrease_axis"}, - {GradVarName("Input")}); - -return KernelSignature("{strided_slice_grad}", {"Input", GradVarName("Out")}, - {"axes", "StartsTensorList", "ends", -"StartsTensorList","infer_flags", "decrease_axis"}, - {GradVarName("Input")}); - -return KernelSignature("{strided_slice_grad}", {"Input", GradVarName("Out")}, - {"axes", "StartsTensorList", "ends", "starts","infer_flags", -"decrease_axis"}, - {GradVarName("Input")}); - -return KernelSignature("{strided_slice_grad}", {"Input", GradVarName("Out")}, - {"axes", "starts", "EndsTensor", "StartsTensor","infer_flags", -"decrease_axis"}, - {GradVarName("Input")}); - -return KernelSignature("{strided_slice_grad}", {"Input", GradVarName("Out")}, - {"axes", "starts", "EndsTensor", "StartsTensorList","infer_flags", -"decrease_axis"}, - {GradVarName("Input")}); - -return KernelSignature("{strided_slice_grad}", {"Input", GradVarName("Out")}, - {"axes", "starts", "EndsTensor", "starts","infer_flags", -"decrease_axis"}, - {GradVarName("Input")}); - -return KernelSignature("{strided_slice_grad}", {"Input", GradVarName("Out")}, - {"axes", "starts", "EndsTensorList", "StartsTensor","infer_flags", -"decrease_axis"}, - {GradVarName("Input")}); - -return KernelSignature("{strided_slice_grad}", {"Input", GradVarName("Out")}, - {"axes", "starts", "EndsTensorList", -"StartsTensorList","infer_flags", "decrease_axis"}, - {GradVarName("Input")}); - -return KernelSignature("{strided_slice_grad}", {"Input", GradVarName("Out")}, - {"axes", "starts", "EndsTensorList", "starts","infer_flags", -"decrease_axis"}, - {GradVarName("Input")}); - -return KernelSignature("{strided_slice_grad}", {"Input", GradVarName("Out")}, - {"axes", "starts", "ends", "StartsTensor","infer_flags", -"decrease_axis"}, - {GradVarName("Input")}); - -return KernelSignature("{strided_slice_grad}", {"Input", GradVarName("Out")}, - {"axes", "starts", "ends", "StartsTensorList","infer_flags", -"decrease_axis"}, - {GradVarName("Input")}); - -return KernelSignature("{strided_slice_grad}", {"Input", GradVarName("Out")}, - {"axes", "starts", "ends", "starts","infer_flags", -"decrease_axis"}, - {GradVarName("Input")}); - -return KernelSignature("{strided_slice_array_grad}", {"Input", -GradVarName("Out")}, - {"axes", "StartsTensor", "EndsTensor", -"StartsTensor","infer_flags", "decrease_axis"}, - {GradVarName("Input")}); - -return KernelSignature("{strided_slice_array_grad}", {"Input", -GradVarName("Out")}, - {"axes", "StartsTensor", "EndsTensor", -"StartsTensorList","infer_flags", "decrease_axis"}, - {GradVarName("Input")}); - -return KernelSignature("{strided_slice_array_grad}", {"Input", -GradVarName("Out")}, - {"axes", "StartsTensor", "EndsTensor", "starts","infer_flags", -"decrease_axis"}, - {GradVarName("Input")}); - -return KernelSignature("{strided_slice_array_grad}", {"Input", -GradVarName("Out")}, - {"axes", "StartsTensor", "EndsTensorList", -"StartsTensor","infer_flags", "decrease_axis"}, - {GradVarName("Input")}); - -return KernelSignature("{strided_slice_array_grad}", {"Input", -GradVarName("Out")}, - {"axes", "StartsTensor", "EndsTensorList", -"StartsTensorList","infer_flags", "decrease_axis"}, - {GradVarName("Input")}); - -return KernelSignature("{strided_slice_array_grad}", {"Input", -GradVarName("Out")}, - {"axes", "StartsTensor", "EndsTensorList", "starts","infer_flags", -"decrease_axis"}, - {GradVarName("Input")}); - -return KernelSignature("{strided_slice_array_grad}", {"Input", -GradVarName("Out")}, - {"axes", "StartsTensor", "ends", "StartsTensor","infer_flags", -"decrease_axis"}, - {GradVarName("Input")}); - -return KernelSignature("{strided_slice_array_grad}", {"Input", -GradVarName("Out")}, - {"axes", "StartsTensor", "ends", "StartsTensorList","infer_flags", -"decrease_axis"}, - {GradVarName("Input")}); - -return KernelSignature("{strided_slice_array_grad}", {"Input", -GradVarName("Out")}, - {"axes", "StartsTensor", "ends", "starts","infer_flags", -"decrease_axis"}, - {GradVarName("Input")}); - -return KernelSignature("{strided_slice_array_grad}", {"Input", -GradVarName("Out")}, - {"axes", "StartsTensorList", "EndsTensor", -"StartsTensor","infer_flags", "decrease_axis"}, - {GradVarName("Input")}); - -return KernelSignature("{strided_slice_array_grad}", {"Input", -GradVarName("Out")}, - {"axes", "StartsTensorList", "EndsTensor", -"StartsTensorList","infer_flags", "decrease_axis"}, - {GradVarName("Input")}); - -return KernelSignature("{strided_slice_array_grad}", {"Input", -GradVarName("Out")}, - {"axes", "StartsTensorList", "EndsTensor", "starts","infer_flags", -"decrease_axis"}, - {GradVarName("Input")}); - -return KernelSignature("{strided_slice_array_grad}", {"Input", -GradVarName("Out")}, - {"axes", "StartsTensorList", "EndsTensorList", -"StartsTensor","infer_flags", "decrease_axis"}, - {GradVarName("Input")}); - -return KernelSignature("{strided_slice_array_grad}", {"Input", -GradVarName("Out")}, - {"axes", "StartsTensorList", "EndsTensorList", -"StartsTensorList","infer_flags", "decrease_axis"}, - {GradVarName("Input")}); - -return KernelSignature("{strided_slice_array_grad}", {"Input", -GradVarName("Out")}, - {"axes", "StartsTensorList", "EndsTensorList", -"starts","infer_flags", "decrease_axis"}, - {GradVarName("Input")}); - -return KernelSignature("{strided_slice_array_grad}", {"Input", -GradVarName("Out")}, - {"axes", "StartsTensorList", "ends", "StartsTensor","infer_flags", -"decrease_axis"}, - {GradVarName("Input")}); - -return KernelSignature("{strided_slice_array_grad}", {"Input", -GradVarName("Out")}, - {"axes", "StartsTensorList", "ends", -"StartsTensorList","infer_flags", "decrease_axis"}, - {GradVarName("Input")}); - -return KernelSignature("{strided_slice_array_grad}", {"Input", -GradVarName("Out")}, - {"axes", "StartsTensorList", "ends", "starts","infer_flags", -"decrease_axis"}, - {GradVarName("Input")}); - -return KernelSignature("{strided_slice_array_grad}", {"Input", -GradVarName("Out")}, - {"axes", "starts", "EndsTensor", "StartsTensor","infer_flags", -"decrease_axis"}, - {GradVarName("Input")}); - -return KernelSignature("{strided_slice_array_grad}", {"Input", -GradVarName("Out")}, - {"axes", "starts", "EndsTensor", "StartsTensorList","infer_flags", -"decrease_axis"}, - {GradVarName("Input")}); - -return KernelSignature("{strided_slice_array_grad}", {"Input", -GradVarName("Out")}, - {"axes", "starts", "EndsTensor", "starts","infer_flags", -"decrease_axis"}, - {GradVarName("Input")}); - -return KernelSignature("{strided_slice_array_grad}", {"Input", -GradVarName("Out")}, - {"axes", "starts", "EndsTensorList", "StartsTensor","infer_flags", -"decrease_axis"}, - {GradVarName("Input")}); - -return KernelSignature("{strided_slice_array_grad}", {"Input", -GradVarName("Out")}, - {"axes", "starts", "EndsTensorList", -"StartsTensorList","infer_flags", "decrease_axis"}, - {GradVarName("Input")}); - -return KernelSignature("{strided_slice_array_grad}", {"Input", -GradVarName("Out")}, - {"axes", "starts", "EndsTensorList", "starts","infer_flags", -"decrease_axis"}, - {GradVarName("Input")}); - -return KernelSignature("{strided_slice_array_grad}", {"Input", -GradVarName("Out")}, - {"axes", "starts", "ends", "StartsTensor","infer_flags", -"decrease_axis"}, - {GradVarName("Input")}); - -return KernelSignature("{strided_slice_array_grad}", {"Input", -GradVarName("Out")}, - {"axes", "starts", "ends", "StartsTensorList","infer_flags", -"decrease_axis"}, - {GradVarName("Input")}); - -return KernelSignature("{strided_slice_array_grad}", {"Input", -GradVarName("Out")}, - {"axes", "starts", "ends", "starts","infer_flags", -"decrease_axis"}, - {GradVarName("Input")}); */ diff --git a/python/paddle/fluid/layers/nn.py b/python/paddle/fluid/layers/nn.py index cb3781d5c2..0be014394f 100755 --- a/python/paddle/fluid/layers/nn.py +++ b/python/paddle/fluid/layers/nn.py @@ -11426,6 +11426,10 @@ def strided_slice(input, axes, starts, ends, strides): sliced_2 = fluid.layers.strided_slice(input, axes=axes, starts=[minus_3, 0, 2], ends=ends, strides=strides_2) # sliced_2 is input[:, 0:3:1, 0:2:1, 2:4:2]. """ + if in_dygraph_mode(): + return _C_ops.final_state_strided_slice(input, axes, starts, ends, + strides) + helper = LayerHelper('strided_slice', **locals()) check_variable_and_dtype(input, 'input', @@ -11590,7 +11594,11 @@ def shape(input): res = exe.run(fluid.default_main_program(), feed={'x':img}, fetch_list=[output]) print(res) # [array([ 3, 100, 100], dtype=int32)] """ - if _non_static_mode(): + if in_dygraph_mode(): + out = _C_ops.final_state_shape(input) + out.stop_gradient = True + return out + if _in_legacy_dygraph(): out = _C_ops.shape(input) out.stop_gradient = True return out diff --git a/python/paddle/fluid/tests/unittests/test_shape_op.py b/python/paddle/fluid/tests/unittests/test_shape_op.py index bada62e323..3d961a7413 100644 --- a/python/paddle/fluid/tests/unittests/test_shape_op.py +++ b/python/paddle/fluid/tests/unittests/test_shape_op.py @@ -17,6 +17,7 @@ from __future__ import print_function import unittest import numpy as np from op_test import OpTest +import paddle from paddle.fluid import core from paddle.fluid.op import Operator @@ -24,6 +25,7 @@ from paddle.fluid.op import Operator class TestShapeOp(OpTest): def setUp(self): self.op_type = "shape" + self.python_api = paddle.shape self.config() self.shape = [2, 3] input = np.zeros(self.shape) @@ -34,7 +36,7 @@ class TestShapeOp(OpTest): self.shape = [2, 3] def test_check_output(self): - self.check_output() + self.check_output(check_eager=True) class case1(TestShapeOp): diff --git a/python/paddle/fluid/tests/unittests/test_strided_slice_op.py b/python/paddle/fluid/tests/unittests/test_strided_slice_op.py index e9be6b338f..ae17cb9b1b 100644 --- a/python/paddle/fluid/tests/unittests/test_strided_slice_op.py +++ b/python/paddle/fluid/tests/unittests/test_strided_slice_op.py @@ -58,6 +58,7 @@ class TestStrideSliceOp(OpTest): def setUp(self): self.initTestCase() self.op_type = 'strided_slice' + self.python_api = paddle.strided_slice self.output = strided_slice_native_forward( self.input, self.axes, self.starts, self.ends, self.strides) @@ -72,10 +73,10 @@ class TestStrideSliceOp(OpTest): } def test_check_output(self): - self.check_output() + self.check_output(check_eager=True) def test_check_grad(self): - self.check_grad(set(['Input']), 'Out') + self.check_grad(set(['Input']), 'Out', check_eager=True) def initTestCase(self): self.input = np.random.rand(100) @@ -704,7 +705,7 @@ class TestStridedSliceTensorArray(unittest.TestCase): l2.sum().backward() grads_static = net.get_all_grads() net.clear_all_grad() - # compare result of dygraph and static + # compare result of dygraph and static self.is_grads_equal(grads_static, grads_dy) self.assertTrue( np.array_equal(s1, s2), diff --git a/python/paddle/utils/code_gen/api.yaml b/python/paddle/utils/code_gen/api.yaml index 5499c81c7e..c89e519f80 100644 --- a/python/paddle/utils/code_gen/api.yaml +++ b/python/paddle/utils/code_gen/api.yaml @@ -951,6 +951,14 @@ func : selu backward : selu_grad +- api : shape + args : (Tensor input) + output : Tensor + infer_meta : + func : ShapeInferMeta + kernel : + func : shape, shape_sr + # shard_index - api : shard_index args : (Tensor in, int index_num, int nshards, int shard_id, int ignore_value) @@ -1070,6 +1078,15 @@ func : square backward : square_grad +- api : strided_slice + args : (Tensor x, int[] axes, IntArray starts, IntArray ends, IntArray strides) + output : Tensor + infer_meta : + func : StridedSliceInferMeta + kernel : + func : strided_slice + backward : strided_slice_grad + - api : subtract args : (Tensor x, Tensor y) output : Tensor diff --git a/python/paddle/utils/code_gen/backward.yaml b/python/paddle/utils/code_gen/backward.yaml index 5efe6e7451..3830d7f926 100644 --- a/python/paddle/utils/code_gen/backward.yaml +++ b/python/paddle/utils/code_gen/backward.yaml @@ -660,6 +660,16 @@ kernel : func : square_grad +- backward_api : strided_slice_grad + forward : strided_slice (Tensor x, int[] axes, IntArray starts, IntArray ends, IntArray strides) -> Tensor(out) + args : (Tensor x, Tensor out_grad, int[] axes, IntArray starts, IntArray ends, IntArray strides) + output : Tensor(x_grad) + infer_meta : + func : GeneralUnaryGradInferMeta + param : [x] + kernel : + func : strided_slice_grad + - backward_api : subtract_grad forward : subtract (Tensor x, Tensor y) -> Tensor(out) args : (Tensor x, Tensor y, Tensor out_grad, int axis = -1) -- GitLab