未验证 提交 9b6a02d4 编写于 作者: C Chen Weihang 提交者: GitHub

[Phi] Add shape and strided_slice yaml & Adapt eager mode (#41131)

* add several yaml

* polish strided slice kernel & add yaml

* reorder yaml

* add several yaml

* revert yaml config change

* resolve conflict

* Update test_strided_slice_op.py
上级 98303291
...@@ -228,7 +228,7 @@ DECLARE_NO_NEED_BUFFER_VARS_INFERER(StridedSliceOpGradNoNeedBufferVarsInferer, ...@@ -228,7 +228,7 @@ DECLARE_NO_NEED_BUFFER_VARS_INFERER(StridedSliceOpGradNoNeedBufferVarsInferer,
namespace ops = paddle::operators; namespace ops = paddle::operators;
DECLARE_INFER_SHAPE_FUNCTOR(strided_slice, StridedSliceInferShape, DECLARE_INFER_SHAPE_FUNCTOR(strided_slice, StridedSliceInferShape,
PD_INFER_META(phi::StridedSliceInferMeta)); PD_INFER_META(phi::StridedSliceRawInferMeta));
REGISTER_OPERATOR(strided_slice, ops::StridedSliceOp, ops::StridedSliceOpMaker, REGISTER_OPERATOR(strided_slice, ops::StridedSliceOp, ops::StridedSliceOpMaker,
ops::StridedSliceOpGradMaker<paddle::framework::OpDesc>, ops::StridedSliceOpGradMaker<paddle::framework::OpDesc>,
......
...@@ -1922,15 +1922,15 @@ void SqueezeInferMeta(const MetaTensor& x, ...@@ -1922,15 +1922,15 @@ void SqueezeInferMeta(const MetaTensor& x,
out->set_dtype(x.dtype()); out->set_dtype(x.dtype());
} }
void StridedSliceInferMeta(const MetaTensor& x, void StridedSliceRawInferMeta(const MetaTensor& x,
const std::vector<int>& axes, const std::vector<int>& axes,
const IntArray& starts, const IntArray& starts,
const IntArray& ends, const IntArray& ends,
const IntArray& strides, const IntArray& strides,
const std::vector<int>& infer_flags, const std::vector<int>& infer_flags,
const std::vector<int>& decrease_axis, const std::vector<int>& decrease_axis,
MetaTensor* out, MetaTensor* out,
MetaConfig config) { MetaConfig config) {
auto in_dims = x.dims(); auto in_dims = x.dims();
PADDLE_ENFORCE_LT( PADDLE_ENFORCE_LT(
in_dims.size(), in_dims.size(),
...@@ -2052,6 +2052,19 @@ void StridedSliceInferMeta(const MetaTensor& x, ...@@ -2052,6 +2052,19 @@ void StridedSliceInferMeta(const MetaTensor& x,
out->set_dtype(x.dtype()); out->set_dtype(x.dtype());
} }
void StridedSliceInferMeta(const MetaTensor& x,
const std::vector<int>& axes,
const IntArray& starts,
const IntArray& ends,
const IntArray& strides,
MetaTensor* out,
MetaConfig config) {
std::vector<int> infer_flags(axes.size(), 1);
std::vector<int> decrease_axis;
StridedSliceRawInferMeta(
x, axes, starts, ends, strides, infer_flags, decrease_axis, out, config);
}
/* Why not use SumRawInferMeta directly? /* Why not use SumRawInferMeta directly?
Because we need make InferMetaFunction's args follow the design of api.yaml Because we need make InferMetaFunction's args follow the design of api.yaml
*/ */
......
...@@ -284,13 +284,21 @@ void SqueezeInferMeta(const MetaTensor& x, ...@@ -284,13 +284,21 @@ void SqueezeInferMeta(const MetaTensor& x,
MetaTensor* xshape, MetaTensor* xshape,
MetaTensor* out); MetaTensor* out);
void StridedSliceRawInferMeta(const MetaTensor& x,
const std::vector<int>& axes,
const IntArray& starts,
const IntArray& ends,
const IntArray& strides,
const std::vector<int>& infer_flags,
const std::vector<int>& decrease_axis,
MetaTensor* out,
MetaConfig config = MetaConfig());
void StridedSliceInferMeta(const MetaTensor& x, void StridedSliceInferMeta(const MetaTensor& x,
const std::vector<int>& axes, const std::vector<int>& axes,
const IntArray& starts, const IntArray& starts,
const IntArray& ends, const IntArray& ends,
const IntArray& strides, const IntArray& strides,
const std::vector<int>& infer_flags,
const std::vector<int>& decrease_axis,
MetaTensor* out, MetaTensor* out,
MetaConfig config = MetaConfig()); MetaConfig config = MetaConfig());
......
...@@ -19,10 +19,10 @@ ...@@ -19,10 +19,10 @@
#include "paddle/phi/core/kernel_registry.h" #include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/impl/strided_slice_grad_kernel_impl.h" #include "paddle/phi/kernels/impl/strided_slice_grad_kernel_impl.h"
PD_REGISTER_KERNEL(strided_slice_grad, PD_REGISTER_KERNEL(strided_slice_raw_grad,
CPU, CPU,
ALL_LAYOUT, ALL_LAYOUT,
phi::StridedSliceGradKernel, phi::StridedSliceRawGradKernel,
bool, bool,
int, int,
int64_t, int64_t,
......
...@@ -19,10 +19,10 @@ ...@@ -19,10 +19,10 @@
#include "paddle/phi/core/kernel_registry.h" #include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/impl/strided_slice_kernel_impl.h" #include "paddle/phi/kernels/impl/strided_slice_kernel_impl.h"
PD_REGISTER_KERNEL(strided_slice, PD_REGISTER_KERNEL(strided_slice_raw,
CPU, CPU,
ALL_LAYOUT, ALL_LAYOUT,
phi::StridedSliceKernel, phi::StridedSliceRawKernel,
bool, bool,
int, int,
int64_t, int64_t,
......
...@@ -19,10 +19,10 @@ ...@@ -19,10 +19,10 @@
#include "paddle/phi/core/kernel_registry.h" #include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/impl/strided_slice_grad_kernel_impl.h" #include "paddle/phi/kernels/impl/strided_slice_grad_kernel_impl.h"
PD_REGISTER_KERNEL(strided_slice_grad, PD_REGISTER_KERNEL(strided_slice_raw_grad,
GPU, GPU,
ALL_LAYOUT, ALL_LAYOUT,
phi::StridedSliceGradKernel, phi::StridedSliceRawGradKernel,
bool, bool,
int, int,
int64_t, int64_t,
......
...@@ -19,10 +19,10 @@ ...@@ -19,10 +19,10 @@
#include "paddle/phi/core/kernel_registry.h" #include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/impl/strided_slice_kernel_impl.h" #include "paddle/phi/kernels/impl/strided_slice_kernel_impl.h"
PD_REGISTER_KERNEL(strided_slice, PD_REGISTER_KERNEL(strided_slice_raw,
GPU, GPU,
ALL_LAYOUT, ALL_LAYOUT,
phi::StridedSliceKernel, phi::StridedSliceRawKernel,
bool, bool,
int, int,
int64_t, int64_t,
......
...@@ -20,16 +20,16 @@ ...@@ -20,16 +20,16 @@
namespace phi { namespace phi {
template <typename T, typename Context> template <typename T, typename Context>
void StridedSliceGradKernel(const Context& dev_ctx, void StridedSliceRawGradKernel(const Context& dev_ctx,
const DenseTensor& x, const DenseTensor& x,
const DenseTensor& out_grad, const DenseTensor& out_grad,
const std::vector<int>& axes, const std::vector<int>& axes,
const IntArray& starts, const IntArray& starts,
const IntArray& ends, const IntArray& ends,
const IntArray& strides, const IntArray& strides,
const std::vector<int>& infer_flags, const std::vector<int>& infer_flags,
const std::vector<int>& decrease_axis, const std::vector<int>& decrease_axis,
DenseTensor* x_grad) { DenseTensor* x_grad) {
int rank = x.dims().size(); int rank = x.dims().size();
#define SLICE_CASE(Rank) \ #define SLICE_CASE(Rank) \
case Rank: \ case Rank: \
......
...@@ -20,15 +20,15 @@ ...@@ -20,15 +20,15 @@
namespace phi { namespace phi {
template <typename T, typename Context> template <typename T, typename Context>
void StridedSliceKernel(const Context& dev_ctx, void StridedSliceRawKernel(const Context& dev_ctx,
const DenseTensor& x, const DenseTensor& x,
const std::vector<int>& axes, const std::vector<int>& axes,
const IntArray& starts, const IntArray& starts,
const IntArray& ends, const IntArray& ends,
const IntArray& strides, const IntArray& strides,
const std::vector<int>& infer_flags, const std::vector<int>& infer_flags,
const std::vector<int>& decrease_axis, const std::vector<int>& decrease_axis,
DenseTensor* out) { DenseTensor* out) {
int rank = x.dims().size(); int rank = x.dims().size();
#define SLICE_CASE(Rank) \ #define SLICE_CASE(Rank) \
case Rank: \ case Rank: \
......
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/strided_slice_grad_kernel.h"
#include "paddle/phi/core/kernel_registry.h"
namespace phi {
template <typename T, typename Context>
void StridedSliceGradKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& out_grad,
const std::vector<int>& axes,
const IntArray& starts,
const IntArray& ends,
const IntArray& strides,
DenseTensor* x_grad) {
std::vector<int> infer_flags(axes.size(), 1);
std::vector<int> decrease_axis;
StridedSliceRawGradKernel<T, Context>(dev_ctx,
x,
out_grad,
axes,
starts,
ends,
strides,
infer_flags,
decrease_axis,
x_grad);
}
} // namespace phi
PD_REGISTER_KERNEL(strided_slice_grad,
CPU,
ALL_LAYOUT,
phi::StridedSliceGradKernel,
bool,
int,
int64_t,
float,
double,
phi::dtype::complex<float>,
phi::dtype::complex<double>) {}
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
PD_REGISTER_KERNEL(strided_slice_grad,
GPU,
ALL_LAYOUT,
phi::StridedSliceGradKernel,
bool,
int,
int64_t,
float,
double,
phi::dtype::complex<float>,
phi::dtype::complex<double>) {}
#endif
...@@ -19,6 +19,18 @@ ...@@ -19,6 +19,18 @@
namespace phi { namespace phi {
template <typename T, typename Context>
void StridedSliceRawGradKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& out_grad,
const std::vector<int>& axes,
const IntArray& starts,
const IntArray& ends,
const IntArray& strides,
const std::vector<int>& infer_flags,
const std::vector<int>& decrease_axis,
DenseTensor* x_grad);
template <typename T, typename Context> template <typename T, typename Context>
void StridedSliceGradKernel(const Context& dev_ctx, void StridedSliceGradKernel(const Context& dev_ctx,
const DenseTensor& x, const DenseTensor& x,
...@@ -27,8 +39,6 @@ void StridedSliceGradKernel(const Context& dev_ctx, ...@@ -27,8 +39,6 @@ void StridedSliceGradKernel(const Context& dev_ctx,
const IntArray& starts, const IntArray& starts,
const IntArray& ends, const IntArray& ends,
const IntArray& strides, const IntArray& strides,
const std::vector<int>& infer_flags,
const std::vector<int>& decrease_axis,
DenseTensor* x_grad); DenseTensor* x_grad);
template <typename T, typename Context> template <typename T, typename Context>
......
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/strided_slice_kernel.h"
#include "paddle/phi/core/kernel_registry.h"
namespace phi {
template <typename T, typename Context>
void StridedSliceKernel(const Context& dev_ctx,
const DenseTensor& x,
const std::vector<int>& axes,
const IntArray& starts,
const IntArray& ends,
const IntArray& strides,
DenseTensor* out) {
std::vector<int> infer_flags(axes.size(), 1);
std::vector<int> decrease_axis;
StridedSliceRawKernel<T, Context>(
dev_ctx, x, axes, starts, ends, strides, infer_flags, decrease_axis, out);
}
} // namespace phi
PD_REGISTER_KERNEL(strided_slice,
CPU,
ALL_LAYOUT,
phi::StridedSliceKernel,
bool,
int,
int64_t,
float,
double,
phi::dtype::complex<float>,
phi::dtype::complex<double>) {}
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
PD_REGISTER_KERNEL(strided_slice,
GPU,
ALL_LAYOUT,
phi::StridedSliceKernel,
bool,
int,
int64_t,
float,
double,
phi::dtype::complex<float>,
phi::dtype::complex<double>) {}
#endif
...@@ -19,6 +19,17 @@ ...@@ -19,6 +19,17 @@
namespace phi { namespace phi {
template <typename T, typename Context>
void StridedSliceRawKernel(const Context& dev_ctx,
const DenseTensor& x,
const std::vector<int>& axes,
const IntArray& starts,
const IntArray& ends,
const IntArray& strides,
const std::vector<int>& infer_flags,
const std::vector<int>& decrease_axis,
DenseTensor* out);
template <typename T, typename Context> template <typename T, typename Context>
void StridedSliceKernel(const Context& dev_ctx, void StridedSliceKernel(const Context& dev_ctx,
const DenseTensor& x, const DenseTensor& x,
...@@ -26,8 +37,6 @@ void StridedSliceKernel(const Context& dev_ctx, ...@@ -26,8 +37,6 @@ void StridedSliceKernel(const Context& dev_ctx,
const IntArray& starts, const IntArray& starts,
const IntArray& ends, const IntArray& ends,
const IntArray& strides, const IntArray& strides,
const std::vector<int>& infer_flags,
const std::vector<int>& decrease_axis,
DenseTensor* out); DenseTensor* out);
template <typename T, typename Context> template <typename T, typename Context>
......
...@@ -57,14 +57,14 @@ KernelSignature StridedSliceOpArgumentMapping( ...@@ -57,14 +57,14 @@ KernelSignature StridedSliceOpArgumentMapping(
"decrease_axis"}; "decrease_axis"};
paddle::SmallVector<std::string> outputs = {"Out"}; paddle::SmallVector<std::string> outputs = {"Out"};
std::string op_type; std::string kernel_name;
if (ctx.IsDenseTensorVectorInput("Input")) { if (ctx.IsDenseTensorVectorInput("Input")) {
op_type = "strided_slice_array"; kernel_name = "strided_slice_array";
} else { } else {
op_type = "strided_slice"; kernel_name = "strided_slice_raw";
} }
// NOTE(dev): Use this to avoid regularization. // NOTE(dev): Use this to avoid regularization.
KernelSignature sig(op_type, inputs, attrs, outputs); KernelSignature sig(kernel_name, inputs, attrs, outputs);
return sig; return sig;
} }
...@@ -106,15 +106,15 @@ KernelSignature StridedSliceGradOpArgumentMapping( ...@@ -106,15 +106,15 @@ KernelSignature StridedSliceGradOpArgumentMapping(
"decrease_axis"}; "decrease_axis"};
paddle::SmallVector<std::string> outputs = {GradVarName("Input")}; paddle::SmallVector<std::string> outputs = {GradVarName("Input")};
std::string op_type; std::string kernel_name;
if (ctx.IsDenseTensorVectorInput("Input")) { if (ctx.IsDenseTensorVectorInput("Input")) {
op_type = "strided_slice_array_grad"; kernel_name = "strided_slice_array_grad";
} else { } else {
op_type = "strided_slice_grad"; kernel_name = "strided_slice_raw_grad";
} }
// NOTE(dev): Use this to avoid regularization. // NOTE(dev): Use this to avoid regularization.
KernelSignature sig(op_type, inputs, attrs, outputs); KernelSignature sig(kernel_name, inputs, attrs, outputs);
return sig; return sig;
} }
...@@ -132,573 +132,273 @@ NOTE: The following codes are for 'get_compat_kernel_signature.py' ...@@ -132,573 +132,273 @@ NOTE: The following codes are for 'get_compat_kernel_signature.py'
############################ Forward ############################ ############################ Forward ############################
return KernelSignature("{strided_slice}", {"Input"}, return KernelSignature("strided_slice_raw", {"Input"},
{"axes", "StartsTensor", "EndsTensor", {"axes", "StartsTensor", "EndsTensor",
"StartsTensor","infer_flags", "decrease_axis"}, "StartsTensor","infer_flags", "decrease_axis"},
{"Out"}); {"Out"});
return KernelSignature("{strided_slice}", {"Input"}, return KernelSignature("strided_slice_raw", {"Input"},
{"axes", "StartsTensor", "EndsTensor", {"axes", "StartsTensor", "EndsTensor",
"StartsTensorList","infer_flags", "decrease_axis"}, "StartsTensorList","infer_flags", "decrease_axis"},
{"Out"}); {"Out"});
return KernelSignature("{strided_slice}", {"Input"}, return KernelSignature("strided_slice_raw", {"Input"},
{"axes", "StartsTensor", "EndsTensor", "starts","infer_flags", {"axes", "StartsTensor", "EndsTensor", "starts","infer_flags",
"decrease_axis"}, "decrease_axis"},
{"Out"}); {"Out"});
return KernelSignature("{strided_slice}", {"Input"}, return KernelSignature("strided_slice_raw", {"Input"},
{"axes", "StartsTensor", "EndsTensorList", {"axes", "StartsTensor", "EndsTensorList",
"StartsTensor","infer_flags", "decrease_axis"}, "StartsTensor","infer_flags", "decrease_axis"},
{"Out"}); {"Out"});
return KernelSignature("{strided_slice}", {"Input"}, return KernelSignature("strided_slice_raw", {"Input"},
{"axes", "StartsTensor", "EndsTensorList", {"axes", "StartsTensor", "EndsTensorList",
"StartsTensorList","infer_flags", "decrease_axis"}, "StartsTensorList","infer_flags", "decrease_axis"},
{"Out"}); {"Out"});
return KernelSignature("{strided_slice}", {"Input"}, return KernelSignature("strided_slice_raw", {"Input"},
{"axes", "StartsTensor", "EndsTensorList", "starts","infer_flags", {"axes", "StartsTensor", "EndsTensorList", "starts","infer_flags",
"decrease_axis"}, "decrease_axis"},
{"Out"}); {"Out"});
return KernelSignature("{strided_slice}", {"Input"}, return KernelSignature("strided_slice_raw", {"Input"},
{"axes", "StartsTensor", "ends", "StartsTensor","infer_flags", {"axes", "StartsTensor", "ends", "StartsTensor","infer_flags",
"decrease_axis"}, "decrease_axis"},
{"Out"}); {"Out"});
return KernelSignature("{strided_slice}", {"Input"}, return KernelSignature("strided_slice_raw", {"Input"},
{"axes", "StartsTensor", "ends", "StartsTensorList","infer_flags", {"axes", "StartsTensor", "ends", "StartsTensorList","infer_flags",
"decrease_axis"}, "decrease_axis"},
{"Out"}); {"Out"});
return KernelSignature("{strided_slice}", {"Input"}, return KernelSignature("strided_slice_raw", {"Input"},
{"axes", "StartsTensor", "ends", "starts","infer_flags", {"axes", "StartsTensor", "ends", "starts","infer_flags",
"decrease_axis"}, "decrease_axis"},
{"Out"}); {"Out"});
return KernelSignature("{strided_slice}", {"Input"}, return KernelSignature("strided_slice_raw", {"Input"},
{"axes", "StartsTensorList", "EndsTensor", {"axes", "StartsTensorList", "EndsTensor",
"StartsTensor","infer_flags", "decrease_axis"}, "StartsTensor","infer_flags", "decrease_axis"},
{"Out"}); {"Out"});
return KernelSignature("{strided_slice}", {"Input"}, return KernelSignature("strided_slice_raw", {"Input"},
{"axes", "StartsTensorList", "EndsTensor", {"axes", "StartsTensorList", "EndsTensor",
"StartsTensorList","infer_flags", "decrease_axis"}, "StartsTensorList","infer_flags", "decrease_axis"},
{"Out"}); {"Out"});
return KernelSignature("{strided_slice}", {"Input"}, return KernelSignature("strided_slice_raw", {"Input"},
{"axes", "StartsTensorList", "EndsTensor", "starts","infer_flags", {"axes", "StartsTensorList", "EndsTensor", "starts","infer_flags",
"decrease_axis"}, "decrease_axis"},
{"Out"}); {"Out"});
return KernelSignature("{strided_slice}", {"Input"}, return KernelSignature("strided_slice_raw", {"Input"},
{"axes", "StartsTensorList", "EndsTensorList", {"axes", "StartsTensorList", "EndsTensorList",
"StartsTensor","infer_flags", "decrease_axis"}, "StartsTensor","infer_flags", "decrease_axis"},
{"Out"}); {"Out"});
return KernelSignature("{strided_slice}", {"Input"}, return KernelSignature("strided_slice_raw", {"Input"},
{"axes", "StartsTensorList", "EndsTensorList", {"axes", "StartsTensorList", "EndsTensorList",
"StartsTensorList","infer_flags", "decrease_axis"}, "StartsTensorList","infer_flags", "decrease_axis"},
{"Out"}); {"Out"});
return KernelSignature("{strided_slice}", {"Input"}, return KernelSignature("strided_slice_raw", {"Input"},
{"axes", "StartsTensorList", "EndsTensorList", {"axes", "StartsTensorList", "EndsTensorList",
"starts","infer_flags", "decrease_axis"}, "starts","infer_flags", "decrease_axis"},
{"Out"}); {"Out"});
return KernelSignature("{strided_slice}", {"Input"}, return KernelSignature("strided_slice_raw", {"Input"},
{"axes", "StartsTensorList", "ends", "StartsTensor","infer_flags", {"axes", "StartsTensorList", "ends", "StartsTensor","infer_flags",
"decrease_axis"}, "decrease_axis"},
{"Out"}); {"Out"});
return KernelSignature("{strided_slice}", {"Input"}, return KernelSignature("strided_slice_raw", {"Input"},
{"axes", "StartsTensorList", "ends", {"axes", "StartsTensorList", "ends",
"StartsTensorList","infer_flags", "decrease_axis"}, "StartsTensorList","infer_flags", "decrease_axis"},
{"Out"}); {"Out"});
return KernelSignature("{strided_slice}", {"Input"}, return KernelSignature("strided_slice_raw", {"Input"},
{"axes", "StartsTensorList", "ends", "starts","infer_flags", {"axes", "StartsTensorList", "ends", "starts","infer_flags",
"decrease_axis"}, "decrease_axis"},
{"Out"}); {"Out"});
return KernelSignature("{strided_slice}", {"Input"}, return KernelSignature("strided_slice_raw", {"Input"},
{"axes", "starts", "EndsTensor", "StartsTensor","infer_flags", {"axes", "starts", "EndsTensor", "StartsTensor","infer_flags",
"decrease_axis"}, "decrease_axis"},
{"Out"}); {"Out"});
return KernelSignature("{strided_slice}", {"Input"}, return KernelSignature("strided_slice_raw", {"Input"},
{"axes", "starts", "EndsTensor", "StartsTensorList","infer_flags", {"axes", "starts", "EndsTensor", "StartsTensorList","infer_flags",
"decrease_axis"}, "decrease_axis"},
{"Out"}); {"Out"});
return KernelSignature("{strided_slice}", {"Input"}, return KernelSignature("strided_slice_raw", {"Input"},
{"axes", "starts", "EndsTensor", "starts","infer_flags", {"axes", "starts", "EndsTensor", "starts","infer_flags",
"decrease_axis"}, "decrease_axis"},
{"Out"}); {"Out"});
return KernelSignature("{strided_slice}", {"Input"}, return KernelSignature("strided_slice_raw", {"Input"},
{"axes", "starts", "EndsTensorList", "StartsTensor","infer_flags", {"axes", "starts", "EndsTensorList", "StartsTensor","infer_flags",
"decrease_axis"}, "decrease_axis"},
{"Out"}); {"Out"});
return KernelSignature("{strided_slice}", {"Input"}, return KernelSignature("strided_slice_raw", {"Input"},
{"axes", "starts", "EndsTensorList", {"axes", "starts", "EndsTensorList",
"StartsTensorList","infer_flags", "decrease_axis"}, "StartsTensorList","infer_flags", "decrease_axis"},
{"Out"}); {"Out"});
return KernelSignature("{strided_slice}", {"Input"}, return KernelSignature("strided_slice_raw", {"Input"},
{"axes", "starts", "EndsTensorList", "starts","infer_flags", {"axes", "starts", "EndsTensorList", "starts","infer_flags",
"decrease_axis"}, "decrease_axis"},
{"Out"}); {"Out"});
return KernelSignature("{strided_slice}", {"Input"}, return KernelSignature("strided_slice_raw", {"Input"},
{"axes", "starts", "ends", "StartsTensor","infer_flags", {"axes", "starts", "ends", "StartsTensor","infer_flags",
"decrease_axis"}, "decrease_axis"},
{"Out"}); {"Out"});
return KernelSignature("{strided_slice}", {"Input"}, return KernelSignature("strided_slice_raw", {"Input"},
{"axes", "starts", "ends", "StartsTensorList","infer_flags", {"axes", "starts", "ends", "StartsTensorList","infer_flags",
"decrease_axis"}, "decrease_axis"},
{"Out"}); {"Out"});
return KernelSignature("{strided_slice}", {"Input"}, return KernelSignature("strided_slice_raw", {"Input"},
{"axes", "starts", "ends", "starts","infer_flags", {"axes", "starts", "ends", "starts","infer_flags",
"decrease_axis"}, "decrease_axis"},
{"Out"}); {"Out"});
return KernelSignature("{strided_slice_array}", {"Input"}, return KernelSignature("strided_slice_array", {"Input"},
{"axes", "StartsTensor", "EndsTensor", {"axes", "StartsTensor", "EndsTensor",
"StartsTensor","infer_flags", "decrease_axis"}, "StartsTensor","infer_flags", "decrease_axis"},
{"Out"}); {"Out"});
return KernelSignature("{strided_slice_array}", {"Input"}, return KernelSignature("strided_slice_array", {"Input"},
{"axes", "StartsTensor", "EndsTensor", {"axes", "StartsTensor", "EndsTensor",
"StartsTensorList","infer_flags", "decrease_axis"}, "StartsTensorList","infer_flags", "decrease_axis"},
{"Out"}); {"Out"});
return KernelSignature("{strided_slice_array}", {"Input"}, return KernelSignature("strided_slice_array", {"Input"},
{"axes", "StartsTensor", "EndsTensor", "starts","infer_flags", {"axes", "StartsTensor", "EndsTensor", "starts","infer_flags",
"decrease_axis"}, "decrease_axis"},
{"Out"}); {"Out"});
return KernelSignature("{strided_slice_array}", {"Input"}, return KernelSignature("strided_slice_array", {"Input"},
{"axes", "StartsTensor", "EndsTensorList", {"axes", "StartsTensor", "EndsTensorList",
"StartsTensor","infer_flags", "decrease_axis"}, "StartsTensor","infer_flags", "decrease_axis"},
{"Out"}); {"Out"});
return KernelSignature("{strided_slice_array}", {"Input"}, return KernelSignature("strided_slice_array", {"Input"},
{"axes", "StartsTensor", "EndsTensorList", {"axes", "StartsTensor", "EndsTensorList",
"StartsTensorList","infer_flags", "decrease_axis"}, "StartsTensorList","infer_flags", "decrease_axis"},
{"Out"}); {"Out"});
return KernelSignature("{strided_slice_array}", {"Input"}, return KernelSignature("strided_slice_array", {"Input"},
{"axes", "StartsTensor", "EndsTensorList", "starts","infer_flags", {"axes", "StartsTensor", "EndsTensorList", "starts","infer_flags",
"decrease_axis"}, "decrease_axis"},
{"Out"}); {"Out"});
return KernelSignature("{strided_slice_array}", {"Input"}, return KernelSignature("strided_slice_array", {"Input"},
{"axes", "StartsTensor", "ends", "StartsTensor","infer_flags", {"axes", "StartsTensor", "ends", "StartsTensor","infer_flags",
"decrease_axis"}, "decrease_axis"},
{"Out"}); {"Out"});
return KernelSignature("{strided_slice_array}", {"Input"}, return KernelSignature("strided_slice_array", {"Input"},
{"axes", "StartsTensor", "ends", "StartsTensorList","infer_flags", {"axes", "StartsTensor", "ends", "StartsTensorList","infer_flags",
"decrease_axis"}, "decrease_axis"},
{"Out"}); {"Out"});
return KernelSignature("{strided_slice_array}", {"Input"}, return KernelSignature("strided_slice_array", {"Input"},
{"axes", "StartsTensor", "ends", "starts","infer_flags", {"axes", "StartsTensor", "ends", "starts","infer_flags",
"decrease_axis"}, "decrease_axis"},
{"Out"}); {"Out"});
return KernelSignature("{strided_slice_array}", {"Input"}, return KernelSignature("strided_slice_array", {"Input"},
{"axes", "StartsTensorList", "EndsTensor", {"axes", "StartsTensorList", "EndsTensor",
"StartsTensor","infer_flags", "decrease_axis"}, "StartsTensor","infer_flags", "decrease_axis"},
{"Out"}); {"Out"});
return KernelSignature("{strided_slice_array}", {"Input"}, return KernelSignature("strided_slice_array", {"Input"},
{"axes", "StartsTensorList", "EndsTensor", {"axes", "StartsTensorList", "EndsTensor",
"StartsTensorList","infer_flags", "decrease_axis"}, "StartsTensorList","infer_flags", "decrease_axis"},
{"Out"}); {"Out"});
return KernelSignature("{strided_slice_array}", {"Input"}, return KernelSignature("strided_slice_array", {"Input"},
{"axes", "StartsTensorList", "EndsTensor", "starts","infer_flags", {"axes", "StartsTensorList", "EndsTensor", "starts","infer_flags",
"decrease_axis"}, "decrease_axis"},
{"Out"}); {"Out"});
return KernelSignature("{strided_slice_array}", {"Input"}, return KernelSignature("strided_slice_array", {"Input"},
{"axes", "StartsTensorList", "EndsTensorList", {"axes", "StartsTensorList", "EndsTensorList",
"StartsTensor","infer_flags", "decrease_axis"}, "StartsTensor","infer_flags", "decrease_axis"},
{"Out"}); {"Out"});
return KernelSignature("{strided_slice_array}", {"Input"}, return KernelSignature("strided_slice_array", {"Input"},
{"axes", "StartsTensorList", "EndsTensorList", {"axes", "StartsTensorList", "EndsTensorList",
"StartsTensorList","infer_flags", "decrease_axis"}, "StartsTensorList","infer_flags", "decrease_axis"},
{"Out"}); {"Out"});
return KernelSignature("{strided_slice_array}", {"Input"}, return KernelSignature("strided_slice_array", {"Input"},
{"axes", "StartsTensorList", "EndsTensorList", {"axes", "StartsTensorList", "EndsTensorList",
"starts","infer_flags", "decrease_axis"}, "starts","infer_flags", "decrease_axis"},
{"Out"}); {"Out"});
return KernelSignature("{strided_slice_array}", {"Input"}, return KernelSignature("strided_slice_array", {"Input"},
{"axes", "StartsTensorList", "ends", "StartsTensor","infer_flags", {"axes", "StartsTensorList", "ends", "StartsTensor","infer_flags",
"decrease_axis"}, "decrease_axis"},
{"Out"}); {"Out"});
return KernelSignature("{strided_slice_array}", {"Input"}, return KernelSignature("strided_slice_array", {"Input"},
{"axes", "StartsTensorList", "ends", {"axes", "StartsTensorList", "ends",
"StartsTensorList","infer_flags", "decrease_axis"}, "StartsTensorList","infer_flags", "decrease_axis"},
{"Out"}); {"Out"});
return KernelSignature("{strided_slice_array}", {"Input"}, return KernelSignature("strided_slice_array", {"Input"},
{"axes", "StartsTensorList", "ends", "starts","infer_flags", {"axes", "StartsTensorList", "ends", "starts","infer_flags",
"decrease_axis"}, "decrease_axis"},
{"Out"}); {"Out"});
return KernelSignature("{strided_slice_array}", {"Input"}, return KernelSignature("strided_slice_array", {"Input"},
{"axes", "starts", "EndsTensor", "StartsTensor","infer_flags", {"axes", "starts", "EndsTensor", "StartsTensor","infer_flags",
"decrease_axis"}, "decrease_axis"},
{"Out"}); {"Out"});
return KernelSignature("{strided_slice_array}", {"Input"}, return KernelSignature("strided_slice_array", {"Input"},
{"axes", "starts", "EndsTensor", "StartsTensorList","infer_flags", {"axes", "starts", "EndsTensor", "StartsTensorList","infer_flags",
"decrease_axis"}, "decrease_axis"},
{"Out"}); {"Out"});
return KernelSignature("{strided_slice_array}", {"Input"}, return KernelSignature("strided_slice_array", {"Input"},
{"axes", "starts", "EndsTensor", "starts","infer_flags", {"axes", "starts", "EndsTensor", "starts","infer_flags",
"decrease_axis"}, "decrease_axis"},
{"Out"}); {"Out"});
return KernelSignature("{strided_slice_array}", {"Input"}, return KernelSignature("strided_slice_array", {"Input"},
{"axes", "starts", "EndsTensorList", "StartsTensor","infer_flags", {"axes", "starts", "EndsTensorList", "StartsTensor","infer_flags",
"decrease_axis"}, "decrease_axis"},
{"Out"}); {"Out"});
return KernelSignature("{strided_slice_array}", {"Input"}, return KernelSignature("strided_slice_array", {"Input"},
{"axes", "starts", "EndsTensorList", {"axes", "starts", "EndsTensorList",
"StartsTensorList","infer_flags", "decrease_axis"}, "StartsTensorList","infer_flags", "decrease_axis"},
{"Out"}); {"Out"});
return KernelSignature("{strided_slice_array}", {"Input"}, return KernelSignature("strided_slice_array", {"Input"},
{"axes", "starts", "EndsTensorList", "starts","infer_flags", {"axes", "starts", "EndsTensorList", "starts","infer_flags",
"decrease_axis"}, "decrease_axis"},
{"Out"}); {"Out"});
return KernelSignature("{strided_slice_array}", {"Input"}, return KernelSignature("strided_slice_array", {"Input"},
{"axes", "starts", "ends", "StartsTensor","infer_flags", {"axes", "starts", "ends", "StartsTensor","infer_flags",
"decrease_axis"}, "decrease_axis"},
{"Out"}); {"Out"});
return KernelSignature("{strided_slice_array}", {"Input"}, return KernelSignature("strided_slice_array", {"Input"},
{"axes", "starts", "ends", "StartsTensorList","infer_flags", {"axes", "starts", "ends", "StartsTensorList","infer_flags",
"decrease_axis"}, "decrease_axis"},
{"Out"}); {"Out"});
return KernelSignature("{strided_slice_array}", {"Input"}, return KernelSignature("strided_slice_array", {"Input"},
{"axes", "starts", "ends", "starts","infer_flags", {"axes", "starts", "ends", "starts","infer_flags",
"decrease_axis"}, "decrease_axis"},
{"Out"}); {"Out"});
############################ Backward ############################
return KernelSignature("{strided_slice_grad}", {"Input", GradVarName("Out")},
{"axes", "StartsTensor", "EndsTensor",
"StartsTensor","infer_flags", "decrease_axis"},
{GradVarName("Input")});
return KernelSignature("{strided_slice_grad}", {"Input", GradVarName("Out")},
{"axes", "StartsTensor", "EndsTensor",
"StartsTensorList","infer_flags", "decrease_axis"},
{GradVarName("Input")});
return KernelSignature("{strided_slice_grad}", {"Input", GradVarName("Out")},
{"axes", "StartsTensor", "EndsTensor", "starts","infer_flags",
"decrease_axis"},
{GradVarName("Input")});
return KernelSignature("{strided_slice_grad}", {"Input", GradVarName("Out")},
{"axes", "StartsTensor", "EndsTensorList",
"StartsTensor","infer_flags", "decrease_axis"},
{GradVarName("Input")});
return KernelSignature("{strided_slice_grad}", {"Input", GradVarName("Out")},
{"axes", "StartsTensor", "EndsTensorList",
"StartsTensorList","infer_flags", "decrease_axis"},
{GradVarName("Input")});
return KernelSignature("{strided_slice_grad}", {"Input", GradVarName("Out")},
{"axes", "StartsTensor", "EndsTensorList", "starts","infer_flags",
"decrease_axis"},
{GradVarName("Input")});
return KernelSignature("{strided_slice_grad}", {"Input", GradVarName("Out")},
{"axes", "StartsTensor", "ends", "StartsTensor","infer_flags",
"decrease_axis"},
{GradVarName("Input")});
return KernelSignature("{strided_slice_grad}", {"Input", GradVarName("Out")},
{"axes", "StartsTensor", "ends", "StartsTensorList","infer_flags",
"decrease_axis"},
{GradVarName("Input")});
return KernelSignature("{strided_slice_grad}", {"Input", GradVarName("Out")},
{"axes", "StartsTensor", "ends", "starts","infer_flags",
"decrease_axis"},
{GradVarName("Input")});
return KernelSignature("{strided_slice_grad}", {"Input", GradVarName("Out")},
{"axes", "StartsTensorList", "EndsTensor",
"StartsTensor","infer_flags", "decrease_axis"},
{GradVarName("Input")});
return KernelSignature("{strided_slice_grad}", {"Input", GradVarName("Out")},
{"axes", "StartsTensorList", "EndsTensor",
"StartsTensorList","infer_flags", "decrease_axis"},
{GradVarName("Input")});
return KernelSignature("{strided_slice_grad}", {"Input", GradVarName("Out")},
{"axes", "StartsTensorList", "EndsTensor", "starts","infer_flags",
"decrease_axis"},
{GradVarName("Input")});
return KernelSignature("{strided_slice_grad}", {"Input", GradVarName("Out")},
{"axes", "StartsTensorList", "EndsTensorList",
"StartsTensor","infer_flags", "decrease_axis"},
{GradVarName("Input")});
return KernelSignature("{strided_slice_grad}", {"Input", GradVarName("Out")},
{"axes", "StartsTensorList", "EndsTensorList",
"StartsTensorList","infer_flags", "decrease_axis"},
{GradVarName("Input")});
return KernelSignature("{strided_slice_grad}", {"Input", GradVarName("Out")},
{"axes", "StartsTensorList", "EndsTensorList",
"starts","infer_flags", "decrease_axis"},
{GradVarName("Input")});
return KernelSignature("{strided_slice_grad}", {"Input", GradVarName("Out")},
{"axes", "StartsTensorList", "ends", "StartsTensor","infer_flags",
"decrease_axis"},
{GradVarName("Input")});
return KernelSignature("{strided_slice_grad}", {"Input", GradVarName("Out")},
{"axes", "StartsTensorList", "ends",
"StartsTensorList","infer_flags", "decrease_axis"},
{GradVarName("Input")});
return KernelSignature("{strided_slice_grad}", {"Input", GradVarName("Out")},
{"axes", "StartsTensorList", "ends", "starts","infer_flags",
"decrease_axis"},
{GradVarName("Input")});
return KernelSignature("{strided_slice_grad}", {"Input", GradVarName("Out")},
{"axes", "starts", "EndsTensor", "StartsTensor","infer_flags",
"decrease_axis"},
{GradVarName("Input")});
return KernelSignature("{strided_slice_grad}", {"Input", GradVarName("Out")},
{"axes", "starts", "EndsTensor", "StartsTensorList","infer_flags",
"decrease_axis"},
{GradVarName("Input")});
return KernelSignature("{strided_slice_grad}", {"Input", GradVarName("Out")},
{"axes", "starts", "EndsTensor", "starts","infer_flags",
"decrease_axis"},
{GradVarName("Input")});
return KernelSignature("{strided_slice_grad}", {"Input", GradVarName("Out")},
{"axes", "starts", "EndsTensorList", "StartsTensor","infer_flags",
"decrease_axis"},
{GradVarName("Input")});
return KernelSignature("{strided_slice_grad}", {"Input", GradVarName("Out")},
{"axes", "starts", "EndsTensorList",
"StartsTensorList","infer_flags", "decrease_axis"},
{GradVarName("Input")});
return KernelSignature("{strided_slice_grad}", {"Input", GradVarName("Out")},
{"axes", "starts", "EndsTensorList", "starts","infer_flags",
"decrease_axis"},
{GradVarName("Input")});
return KernelSignature("{strided_slice_grad}", {"Input", GradVarName("Out")},
{"axes", "starts", "ends", "StartsTensor","infer_flags",
"decrease_axis"},
{GradVarName("Input")});
return KernelSignature("{strided_slice_grad}", {"Input", GradVarName("Out")},
{"axes", "starts", "ends", "StartsTensorList","infer_flags",
"decrease_axis"},
{GradVarName("Input")});
return KernelSignature("{strided_slice_grad}", {"Input", GradVarName("Out")},
{"axes", "starts", "ends", "starts","infer_flags",
"decrease_axis"},
{GradVarName("Input")});
return KernelSignature("{strided_slice_array_grad}", {"Input",
GradVarName("Out")},
{"axes", "StartsTensor", "EndsTensor",
"StartsTensor","infer_flags", "decrease_axis"},
{GradVarName("Input")});
return KernelSignature("{strided_slice_array_grad}", {"Input",
GradVarName("Out")},
{"axes", "StartsTensor", "EndsTensor",
"StartsTensorList","infer_flags", "decrease_axis"},
{GradVarName("Input")});
return KernelSignature("{strided_slice_array_grad}", {"Input",
GradVarName("Out")},
{"axes", "StartsTensor", "EndsTensor", "starts","infer_flags",
"decrease_axis"},
{GradVarName("Input")});
return KernelSignature("{strided_slice_array_grad}", {"Input",
GradVarName("Out")},
{"axes", "StartsTensor", "EndsTensorList",
"StartsTensor","infer_flags", "decrease_axis"},
{GradVarName("Input")});
return KernelSignature("{strided_slice_array_grad}", {"Input",
GradVarName("Out")},
{"axes", "StartsTensor", "EndsTensorList",
"StartsTensorList","infer_flags", "decrease_axis"},
{GradVarName("Input")});
return KernelSignature("{strided_slice_array_grad}", {"Input",
GradVarName("Out")},
{"axes", "StartsTensor", "EndsTensorList", "starts","infer_flags",
"decrease_axis"},
{GradVarName("Input")});
return KernelSignature("{strided_slice_array_grad}", {"Input",
GradVarName("Out")},
{"axes", "StartsTensor", "ends", "StartsTensor","infer_flags",
"decrease_axis"},
{GradVarName("Input")});
return KernelSignature("{strided_slice_array_grad}", {"Input",
GradVarName("Out")},
{"axes", "StartsTensor", "ends", "StartsTensorList","infer_flags",
"decrease_axis"},
{GradVarName("Input")});
return KernelSignature("{strided_slice_array_grad}", {"Input",
GradVarName("Out")},
{"axes", "StartsTensor", "ends", "starts","infer_flags",
"decrease_axis"},
{GradVarName("Input")});
return KernelSignature("{strided_slice_array_grad}", {"Input",
GradVarName("Out")},
{"axes", "StartsTensorList", "EndsTensor",
"StartsTensor","infer_flags", "decrease_axis"},
{GradVarName("Input")});
return KernelSignature("{strided_slice_array_grad}", {"Input",
GradVarName("Out")},
{"axes", "StartsTensorList", "EndsTensor",
"StartsTensorList","infer_flags", "decrease_axis"},
{GradVarName("Input")});
return KernelSignature("{strided_slice_array_grad}", {"Input",
GradVarName("Out")},
{"axes", "StartsTensorList", "EndsTensor", "starts","infer_flags",
"decrease_axis"},
{GradVarName("Input")});
return KernelSignature("{strided_slice_array_grad}", {"Input",
GradVarName("Out")},
{"axes", "StartsTensorList", "EndsTensorList",
"StartsTensor","infer_flags", "decrease_axis"},
{GradVarName("Input")});
return KernelSignature("{strided_slice_array_grad}", {"Input",
GradVarName("Out")},
{"axes", "StartsTensorList", "EndsTensorList",
"StartsTensorList","infer_flags", "decrease_axis"},
{GradVarName("Input")});
return KernelSignature("{strided_slice_array_grad}", {"Input",
GradVarName("Out")},
{"axes", "StartsTensorList", "EndsTensorList",
"starts","infer_flags", "decrease_axis"},
{GradVarName("Input")});
return KernelSignature("{strided_slice_array_grad}", {"Input",
GradVarName("Out")},
{"axes", "StartsTensorList", "ends", "StartsTensor","infer_flags",
"decrease_axis"},
{GradVarName("Input")});
return KernelSignature("{strided_slice_array_grad}", {"Input",
GradVarName("Out")},
{"axes", "StartsTensorList", "ends",
"StartsTensorList","infer_flags", "decrease_axis"},
{GradVarName("Input")});
return KernelSignature("{strided_slice_array_grad}", {"Input",
GradVarName("Out")},
{"axes", "StartsTensorList", "ends", "starts","infer_flags",
"decrease_axis"},
{GradVarName("Input")});
return KernelSignature("{strided_slice_array_grad}", {"Input",
GradVarName("Out")},
{"axes", "starts", "EndsTensor", "StartsTensor","infer_flags",
"decrease_axis"},
{GradVarName("Input")});
return KernelSignature("{strided_slice_array_grad}", {"Input",
GradVarName("Out")},
{"axes", "starts", "EndsTensor", "StartsTensorList","infer_flags",
"decrease_axis"},
{GradVarName("Input")});
return KernelSignature("{strided_slice_array_grad}", {"Input",
GradVarName("Out")},
{"axes", "starts", "EndsTensor", "starts","infer_flags",
"decrease_axis"},
{GradVarName("Input")});
return KernelSignature("{strided_slice_array_grad}", {"Input",
GradVarName("Out")},
{"axes", "starts", "EndsTensorList", "StartsTensor","infer_flags",
"decrease_axis"},
{GradVarName("Input")});
return KernelSignature("{strided_slice_array_grad}", {"Input",
GradVarName("Out")},
{"axes", "starts", "EndsTensorList",
"StartsTensorList","infer_flags", "decrease_axis"},
{GradVarName("Input")});
return KernelSignature("{strided_slice_array_grad}", {"Input",
GradVarName("Out")},
{"axes", "starts", "EndsTensorList", "starts","infer_flags",
"decrease_axis"},
{GradVarName("Input")});
return KernelSignature("{strided_slice_array_grad}", {"Input",
GradVarName("Out")},
{"axes", "starts", "ends", "StartsTensor","infer_flags",
"decrease_axis"},
{GradVarName("Input")});
return KernelSignature("{strided_slice_array_grad}", {"Input",
GradVarName("Out")},
{"axes", "starts", "ends", "StartsTensorList","infer_flags",
"decrease_axis"},
{GradVarName("Input")});
return KernelSignature("{strided_slice_array_grad}", {"Input",
GradVarName("Out")},
{"axes", "starts", "ends", "starts","infer_flags",
"decrease_axis"},
{GradVarName("Input")});
*/ */
...@@ -11426,6 +11426,10 @@ def strided_slice(input, axes, starts, ends, strides): ...@@ -11426,6 +11426,10 @@ def strided_slice(input, axes, starts, ends, strides):
sliced_2 = fluid.layers.strided_slice(input, axes=axes, starts=[minus_3, 0, 2], ends=ends, strides=strides_2) sliced_2 = fluid.layers.strided_slice(input, axes=axes, starts=[minus_3, 0, 2], ends=ends, strides=strides_2)
# sliced_2 is input[:, 0:3:1, 0:2:1, 2:4:2]. # sliced_2 is input[:, 0:3:1, 0:2:1, 2:4:2].
""" """
if in_dygraph_mode():
return _C_ops.final_state_strided_slice(input, axes, starts, ends,
strides)
helper = LayerHelper('strided_slice', **locals()) helper = LayerHelper('strided_slice', **locals())
check_variable_and_dtype(input, 'input', check_variable_and_dtype(input, 'input',
...@@ -11590,7 +11594,11 @@ def shape(input): ...@@ -11590,7 +11594,11 @@ def shape(input):
res = exe.run(fluid.default_main_program(), feed={'x':img}, fetch_list=[output]) res = exe.run(fluid.default_main_program(), feed={'x':img}, fetch_list=[output])
print(res) # [array([ 3, 100, 100], dtype=int32)] print(res) # [array([ 3, 100, 100], dtype=int32)]
""" """
if _non_static_mode(): if in_dygraph_mode():
out = _C_ops.final_state_shape(input)
out.stop_gradient = True
return out
if _in_legacy_dygraph():
out = _C_ops.shape(input) out = _C_ops.shape(input)
out.stop_gradient = True out.stop_gradient = True
return out return out
......
...@@ -17,6 +17,7 @@ from __future__ import print_function ...@@ -17,6 +17,7 @@ from __future__ import print_function
import unittest import unittest
import numpy as np import numpy as np
from op_test import OpTest from op_test import OpTest
import paddle
from paddle.fluid import core from paddle.fluid import core
from paddle.fluid.op import Operator from paddle.fluid.op import Operator
...@@ -24,6 +25,7 @@ from paddle.fluid.op import Operator ...@@ -24,6 +25,7 @@ from paddle.fluid.op import Operator
class TestShapeOp(OpTest): class TestShapeOp(OpTest):
def setUp(self): def setUp(self):
self.op_type = "shape" self.op_type = "shape"
self.python_api = paddle.shape
self.config() self.config()
self.shape = [2, 3] self.shape = [2, 3]
input = np.zeros(self.shape) input = np.zeros(self.shape)
...@@ -34,7 +36,7 @@ class TestShapeOp(OpTest): ...@@ -34,7 +36,7 @@ class TestShapeOp(OpTest):
self.shape = [2, 3] self.shape = [2, 3]
def test_check_output(self): def test_check_output(self):
self.check_output() self.check_output(check_eager=True)
class case1(TestShapeOp): class case1(TestShapeOp):
......
...@@ -58,6 +58,7 @@ class TestStrideSliceOp(OpTest): ...@@ -58,6 +58,7 @@ class TestStrideSliceOp(OpTest):
def setUp(self): def setUp(self):
self.initTestCase() self.initTestCase()
self.op_type = 'strided_slice' self.op_type = 'strided_slice'
self.python_api = paddle.strided_slice
self.output = strided_slice_native_forward( self.output = strided_slice_native_forward(
self.input, self.axes, self.starts, self.ends, self.strides) self.input, self.axes, self.starts, self.ends, self.strides)
...@@ -72,10 +73,10 @@ class TestStrideSliceOp(OpTest): ...@@ -72,10 +73,10 @@ class TestStrideSliceOp(OpTest):
} }
def test_check_output(self): def test_check_output(self):
self.check_output() self.check_output(check_eager=True)
def test_check_grad(self): def test_check_grad(self):
self.check_grad(set(['Input']), 'Out') self.check_grad(set(['Input']), 'Out', check_eager=True)
def initTestCase(self): def initTestCase(self):
self.input = np.random.rand(100) self.input = np.random.rand(100)
...@@ -704,7 +705,7 @@ class TestStridedSliceTensorArray(unittest.TestCase): ...@@ -704,7 +705,7 @@ class TestStridedSliceTensorArray(unittest.TestCase):
l2.sum().backward() l2.sum().backward()
grads_static = net.get_all_grads() grads_static = net.get_all_grads()
net.clear_all_grad() net.clear_all_grad()
# compare result of dygraph and static # compare result of dygraph and static
self.is_grads_equal(grads_static, grads_dy) self.is_grads_equal(grads_static, grads_dy)
self.assertTrue( self.assertTrue(
np.array_equal(s1, s2), np.array_equal(s1, s2),
......
...@@ -951,6 +951,14 @@ ...@@ -951,6 +951,14 @@
func : selu func : selu
backward : selu_grad backward : selu_grad
- api : shape
args : (Tensor input)
output : Tensor
infer_meta :
func : ShapeInferMeta
kernel :
func : shape, shape_sr
# shard_index # shard_index
- api : shard_index - api : shard_index
args : (Tensor in, int index_num, int nshards, int shard_id, int ignore_value) args : (Tensor in, int index_num, int nshards, int shard_id, int ignore_value)
...@@ -1070,6 +1078,15 @@ ...@@ -1070,6 +1078,15 @@
func : square func : square
backward : square_grad backward : square_grad
- api : strided_slice
args : (Tensor x, int[] axes, IntArray starts, IntArray ends, IntArray strides)
output : Tensor
infer_meta :
func : StridedSliceInferMeta
kernel :
func : strided_slice
backward : strided_slice_grad
- api : subtract - api : subtract
args : (Tensor x, Tensor y) args : (Tensor x, Tensor y)
output : Tensor output : Tensor
......
...@@ -660,6 +660,16 @@ ...@@ -660,6 +660,16 @@
kernel : kernel :
func : square_grad func : square_grad
- backward_api : strided_slice_grad
forward : strided_slice (Tensor x, int[] axes, IntArray starts, IntArray ends, IntArray strides) -> Tensor(out)
args : (Tensor x, Tensor out_grad, int[] axes, IntArray starts, IntArray ends, IntArray strides)
output : Tensor(x_grad)
infer_meta :
func : GeneralUnaryGradInferMeta
param : [x]
kernel :
func : strided_slice_grad
- backward_api : subtract_grad - backward_api : subtract_grad
forward : subtract (Tensor x, Tensor y) -> Tensor(out) forward : subtract (Tensor x, Tensor y) -> Tensor(out)
args : (Tensor x, Tensor y, Tensor out_grad, int axis = -1) args : (Tensor x, Tensor y, Tensor out_grad, int axis = -1)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册