未验证 提交 9b6a02d4 编写于 作者: C Chen Weihang 提交者: GitHub

[Phi] Add shape and strided_slice yaml & Adapt eager mode (#41131)

* add several yaml

* polish strided slice kernel & add yaml

* reorder yaml

* add several yaml

* revert yaml config change

* resolve conflict

* Update test_strided_slice_op.py
上级 98303291
......@@ -228,7 +228,7 @@ DECLARE_NO_NEED_BUFFER_VARS_INFERER(StridedSliceOpGradNoNeedBufferVarsInferer,
namespace ops = paddle::operators;
DECLARE_INFER_SHAPE_FUNCTOR(strided_slice, StridedSliceInferShape,
PD_INFER_META(phi::StridedSliceInferMeta));
PD_INFER_META(phi::StridedSliceRawInferMeta));
REGISTER_OPERATOR(strided_slice, ops::StridedSliceOp, ops::StridedSliceOpMaker,
ops::StridedSliceOpGradMaker<paddle::framework::OpDesc>,
......
......@@ -1922,15 +1922,15 @@ void SqueezeInferMeta(const MetaTensor& x,
out->set_dtype(x.dtype());
}
void StridedSliceInferMeta(const MetaTensor& x,
const std::vector<int>& axes,
const IntArray& starts,
const IntArray& ends,
const IntArray& strides,
const std::vector<int>& infer_flags,
const std::vector<int>& decrease_axis,
MetaTensor* out,
MetaConfig config) {
void StridedSliceRawInferMeta(const MetaTensor& x,
const std::vector<int>& axes,
const IntArray& starts,
const IntArray& ends,
const IntArray& strides,
const std::vector<int>& infer_flags,
const std::vector<int>& decrease_axis,
MetaTensor* out,
MetaConfig config) {
auto in_dims = x.dims();
PADDLE_ENFORCE_LT(
in_dims.size(),
......@@ -2052,6 +2052,19 @@ void StridedSliceInferMeta(const MetaTensor& x,
out->set_dtype(x.dtype());
}
void StridedSliceInferMeta(const MetaTensor& x,
const std::vector<int>& axes,
const IntArray& starts,
const IntArray& ends,
const IntArray& strides,
MetaTensor* out,
MetaConfig config) {
std::vector<int> infer_flags(axes.size(), 1);
std::vector<int> decrease_axis;
StridedSliceRawInferMeta(
x, axes, starts, ends, strides, infer_flags, decrease_axis, out, config);
}
/* Why not use SumRawInferMeta directly?
Because we need make InferMetaFunction's args follow the design of api.yaml
*/
......
......@@ -284,13 +284,21 @@ void SqueezeInferMeta(const MetaTensor& x,
MetaTensor* xshape,
MetaTensor* out);
void StridedSliceRawInferMeta(const MetaTensor& x,
const std::vector<int>& axes,
const IntArray& starts,
const IntArray& ends,
const IntArray& strides,
const std::vector<int>& infer_flags,
const std::vector<int>& decrease_axis,
MetaTensor* out,
MetaConfig config = MetaConfig());
void StridedSliceInferMeta(const MetaTensor& x,
const std::vector<int>& axes,
const IntArray& starts,
const IntArray& ends,
const IntArray& strides,
const std::vector<int>& infer_flags,
const std::vector<int>& decrease_axis,
MetaTensor* out,
MetaConfig config = MetaConfig());
......
......@@ -19,10 +19,10 @@
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/impl/strided_slice_grad_kernel_impl.h"
PD_REGISTER_KERNEL(strided_slice_grad,
PD_REGISTER_KERNEL(strided_slice_raw_grad,
CPU,
ALL_LAYOUT,
phi::StridedSliceGradKernel,
phi::StridedSliceRawGradKernel,
bool,
int,
int64_t,
......
......@@ -19,10 +19,10 @@
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/impl/strided_slice_kernel_impl.h"
PD_REGISTER_KERNEL(strided_slice,
PD_REGISTER_KERNEL(strided_slice_raw,
CPU,
ALL_LAYOUT,
phi::StridedSliceKernel,
phi::StridedSliceRawKernel,
bool,
int,
int64_t,
......
......@@ -19,10 +19,10 @@
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/impl/strided_slice_grad_kernel_impl.h"
PD_REGISTER_KERNEL(strided_slice_grad,
PD_REGISTER_KERNEL(strided_slice_raw_grad,
GPU,
ALL_LAYOUT,
phi::StridedSliceGradKernel,
phi::StridedSliceRawGradKernel,
bool,
int,
int64_t,
......
......@@ -19,10 +19,10 @@
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/impl/strided_slice_kernel_impl.h"
PD_REGISTER_KERNEL(strided_slice,
PD_REGISTER_KERNEL(strided_slice_raw,
GPU,
ALL_LAYOUT,
phi::StridedSliceKernel,
phi::StridedSliceRawKernel,
bool,
int,
int64_t,
......
......@@ -20,16 +20,16 @@
namespace phi {
template <typename T, typename Context>
void StridedSliceGradKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& out_grad,
const std::vector<int>& axes,
const IntArray& starts,
const IntArray& ends,
const IntArray& strides,
const std::vector<int>& infer_flags,
const std::vector<int>& decrease_axis,
DenseTensor* x_grad) {
void StridedSliceRawGradKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& out_grad,
const std::vector<int>& axes,
const IntArray& starts,
const IntArray& ends,
const IntArray& strides,
const std::vector<int>& infer_flags,
const std::vector<int>& decrease_axis,
DenseTensor* x_grad) {
int rank = x.dims().size();
#define SLICE_CASE(Rank) \
case Rank: \
......
......@@ -20,15 +20,15 @@
namespace phi {
template <typename T, typename Context>
void StridedSliceKernel(const Context& dev_ctx,
const DenseTensor& x,
const std::vector<int>& axes,
const IntArray& starts,
const IntArray& ends,
const IntArray& strides,
const std::vector<int>& infer_flags,
const std::vector<int>& decrease_axis,
DenseTensor* out) {
void StridedSliceRawKernel(const Context& dev_ctx,
const DenseTensor& x,
const std::vector<int>& axes,
const IntArray& starts,
const IntArray& ends,
const IntArray& strides,
const std::vector<int>& infer_flags,
const std::vector<int>& decrease_axis,
DenseTensor* out) {
int rank = x.dims().size();
#define SLICE_CASE(Rank) \
case Rank: \
......
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/strided_slice_grad_kernel.h"
#include "paddle/phi/core/kernel_registry.h"
namespace phi {
template <typename T, typename Context>
void StridedSliceGradKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& out_grad,
const std::vector<int>& axes,
const IntArray& starts,
const IntArray& ends,
const IntArray& strides,
DenseTensor* x_grad) {
std::vector<int> infer_flags(axes.size(), 1);
std::vector<int> decrease_axis;
StridedSliceRawGradKernel<T, Context>(dev_ctx,
x,
out_grad,
axes,
starts,
ends,
strides,
infer_flags,
decrease_axis,
x_grad);
}
} // namespace phi
PD_REGISTER_KERNEL(strided_slice_grad,
CPU,
ALL_LAYOUT,
phi::StridedSliceGradKernel,
bool,
int,
int64_t,
float,
double,
phi::dtype::complex<float>,
phi::dtype::complex<double>) {}
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
PD_REGISTER_KERNEL(strided_slice_grad,
GPU,
ALL_LAYOUT,
phi::StridedSliceGradKernel,
bool,
int,
int64_t,
float,
double,
phi::dtype::complex<float>,
phi::dtype::complex<double>) {}
#endif
......@@ -19,6 +19,18 @@
namespace phi {
template <typename T, typename Context>
void StridedSliceRawGradKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& out_grad,
const std::vector<int>& axes,
const IntArray& starts,
const IntArray& ends,
const IntArray& strides,
const std::vector<int>& infer_flags,
const std::vector<int>& decrease_axis,
DenseTensor* x_grad);
template <typename T, typename Context>
void StridedSliceGradKernel(const Context& dev_ctx,
const DenseTensor& x,
......@@ -27,8 +39,6 @@ void StridedSliceGradKernel(const Context& dev_ctx,
const IntArray& starts,
const IntArray& ends,
const IntArray& strides,
const std::vector<int>& infer_flags,
const std::vector<int>& decrease_axis,
DenseTensor* x_grad);
template <typename T, typename Context>
......
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/strided_slice_kernel.h"
#include "paddle/phi/core/kernel_registry.h"
namespace phi {
template <typename T, typename Context>
void StridedSliceKernel(const Context& dev_ctx,
const DenseTensor& x,
const std::vector<int>& axes,
const IntArray& starts,
const IntArray& ends,
const IntArray& strides,
DenseTensor* out) {
std::vector<int> infer_flags(axes.size(), 1);
std::vector<int> decrease_axis;
StridedSliceRawKernel<T, Context>(
dev_ctx, x, axes, starts, ends, strides, infer_flags, decrease_axis, out);
}
} // namespace phi
PD_REGISTER_KERNEL(strided_slice,
CPU,
ALL_LAYOUT,
phi::StridedSliceKernel,
bool,
int,
int64_t,
float,
double,
phi::dtype::complex<float>,
phi::dtype::complex<double>) {}
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
PD_REGISTER_KERNEL(strided_slice,
GPU,
ALL_LAYOUT,
phi::StridedSliceKernel,
bool,
int,
int64_t,
float,
double,
phi::dtype::complex<float>,
phi::dtype::complex<double>) {}
#endif
......@@ -19,6 +19,17 @@
namespace phi {
template <typename T, typename Context>
void StridedSliceRawKernel(const Context& dev_ctx,
const DenseTensor& x,
const std::vector<int>& axes,
const IntArray& starts,
const IntArray& ends,
const IntArray& strides,
const std::vector<int>& infer_flags,
const std::vector<int>& decrease_axis,
DenseTensor* out);
template <typename T, typename Context>
void StridedSliceKernel(const Context& dev_ctx,
const DenseTensor& x,
......@@ -26,8 +37,6 @@ void StridedSliceKernel(const Context& dev_ctx,
const IntArray& starts,
const IntArray& ends,
const IntArray& strides,
const std::vector<int>& infer_flags,
const std::vector<int>& decrease_axis,
DenseTensor* out);
template <typename T, typename Context>
......
......@@ -11426,6 +11426,10 @@ def strided_slice(input, axes, starts, ends, strides):
sliced_2 = fluid.layers.strided_slice(input, axes=axes, starts=[minus_3, 0, 2], ends=ends, strides=strides_2)
# sliced_2 is input[:, 0:3:1, 0:2:1, 2:4:2].
"""
if in_dygraph_mode():
return _C_ops.final_state_strided_slice(input, axes, starts, ends,
strides)
helper = LayerHelper('strided_slice', **locals())
check_variable_and_dtype(input, 'input',
......@@ -11590,7 +11594,11 @@ def shape(input):
res = exe.run(fluid.default_main_program(), feed={'x':img}, fetch_list=[output])
print(res) # [array([ 3, 100, 100], dtype=int32)]
"""
if _non_static_mode():
if in_dygraph_mode():
out = _C_ops.final_state_shape(input)
out.stop_gradient = True
return out
if _in_legacy_dygraph():
out = _C_ops.shape(input)
out.stop_gradient = True
return out
......
......@@ -17,6 +17,7 @@ from __future__ import print_function
import unittest
import numpy as np
from op_test import OpTest
import paddle
from paddle.fluid import core
from paddle.fluid.op import Operator
......@@ -24,6 +25,7 @@ from paddle.fluid.op import Operator
class TestShapeOp(OpTest):
def setUp(self):
self.op_type = "shape"
self.python_api = paddle.shape
self.config()
self.shape = [2, 3]
input = np.zeros(self.shape)
......@@ -34,7 +36,7 @@ class TestShapeOp(OpTest):
self.shape = [2, 3]
def test_check_output(self):
self.check_output()
self.check_output(check_eager=True)
class case1(TestShapeOp):
......
......@@ -58,6 +58,7 @@ class TestStrideSliceOp(OpTest):
def setUp(self):
self.initTestCase()
self.op_type = 'strided_slice'
self.python_api = paddle.strided_slice
self.output = strided_slice_native_forward(
self.input, self.axes, self.starts, self.ends, self.strides)
......@@ -72,10 +73,10 @@ class TestStrideSliceOp(OpTest):
}
def test_check_output(self):
self.check_output()
self.check_output(check_eager=True)
def test_check_grad(self):
self.check_grad(set(['Input']), 'Out')
self.check_grad(set(['Input']), 'Out', check_eager=True)
def initTestCase(self):
self.input = np.random.rand(100)
......@@ -704,7 +705,7 @@ class TestStridedSliceTensorArray(unittest.TestCase):
l2.sum().backward()
grads_static = net.get_all_grads()
net.clear_all_grad()
# compare result of dygraph and static
# compare result of dygraph and static
self.is_grads_equal(grads_static, grads_dy)
self.assertTrue(
np.array_equal(s1, s2),
......
......@@ -951,6 +951,14 @@
func : selu
backward : selu_grad
- api : shape
args : (Tensor input)
output : Tensor
infer_meta :
func : ShapeInferMeta
kernel :
func : shape, shape_sr
# shard_index
- api : shard_index
args : (Tensor in, int index_num, int nshards, int shard_id, int ignore_value)
......@@ -1070,6 +1078,15 @@
func : square
backward : square_grad
- api : strided_slice
args : (Tensor x, int[] axes, IntArray starts, IntArray ends, IntArray strides)
output : Tensor
infer_meta :
func : StridedSliceInferMeta
kernel :
func : strided_slice
backward : strided_slice_grad
- api : subtract
args : (Tensor x, Tensor y)
output : Tensor
......
......@@ -660,6 +660,16 @@
kernel :
func : square_grad
- backward_api : strided_slice_grad
forward : strided_slice (Tensor x, int[] axes, IntArray starts, IntArray ends, IntArray strides) -> Tensor(out)
args : (Tensor x, Tensor out_grad, int[] axes, IntArray starts, IntArray ends, IntArray strides)
output : Tensor(x_grad)
infer_meta :
func : GeneralUnaryGradInferMeta
param : [x]
kernel :
func : strided_slice_grad
- backward_api : subtract_grad
forward : subtract (Tensor x, Tensor y) -> Tensor(out)
args : (Tensor x, Tensor y, Tensor out_grad, int axis = -1)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册