未验证 提交 91110661 编写于 作者: A Aurelius84 提交者: GitHub

[OpAttr]axis of Reverse Support Tensor type (#45391)

* [OpAttr]axis of Reverse Support Tensor type

* fix coverage

* fix unittest
上级 9b5b005e
......@@ -26,6 +26,15 @@ namespace operators {
class ReverseOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
auto input_data_type =
framework::OperatorWithKernel::IndicateVarDataType(ctx, "X");
return framework::OpKernelType(input_data_type, ctx.GetPlace());
}
};
class ReverseOpVarTypeInference : public framework::VarTypeInference {
......@@ -42,7 +51,8 @@ class ReverseOpMaker : public framework::OpProtoAndCheckerMaker {
AddInput("X", "The LoDTensor to be flipped.");
AddOutput("Out", "The LoDTensor after flipping.");
AddAttr<std::vector<int>>(
"axis", "The axises that along which order of elements is reversed.");
"axis", "The axises that along which order of elements is reversed.")
.SupportTensor();
AddComment(R"DOC(
Reverse Operator.
......
......@@ -2161,7 +2161,7 @@
backward: reshape_grad
- api : reverse
args : (Tensor x, int[] axis)
args : (Tensor x, IntArray axis)
output : Tensor
infer_meta :
func : ReverseInferMeta
......@@ -2170,7 +2170,7 @@
backward : reverse_grad
- api : reverse_array
args : (Tensor[] x, int[] axis)
args : (Tensor[] x, IntArray axis)
output : Tensor[]{x.size()}
infer_meta :
func : ReverseArrayInferMeta
......
......@@ -1963,8 +1963,8 @@
inplace : (out_grad -> x_grad)
- backward_api : reverse_array_grad
forward : reverse_array (Tensor[] x, int[] axis) -> Tensor[](out)
args : (Tensor[] out_grad, int[] axis)
forward : reverse_array (Tensor[] x, IntArray axis) -> Tensor[](out)
args : (Tensor[] out_grad, IntArray axis)
output : Tensor[](x_grad){out_grad.size()}
infer_meta :
func : ReverseArrayInferMeta
......@@ -1972,8 +1972,8 @@
func : reverse
- backward_api : reverse_grad
forward : reverse (Tensor x, int[] axis) -> Tensor(out)
args : (Tensor out_grad, int[] axis)
forward : reverse (Tensor x, IntArray axis) -> Tensor(out)
args : (Tensor out_grad, IntArray axis)
output : Tensor(x_grad)
infer_meta :
func : ReverseInferMeta
......
......@@ -2744,13 +2744,22 @@ void ReshapeWithXShapeInferMeta(const MetaTensor& x,
}
void ReverseInferMeta(const MetaTensor& x,
const std::vector<int>& axis,
MetaTensor* out) {
PADDLE_ENFORCE_NE(axis.empty(),
const IntArray& axis,
MetaTensor* out,
MetaConfig config) {
// NOTE(Aurelius84): In Reverse Op, output TensorMeta is always same
// as input, so we only verify axis when it is not from Tensor or in
// runtime.
if (!config.is_runtime && axis.FromTensor()) {
out->share_meta(x);
return;
}
auto& axis_data = axis.GetData();
PADDLE_ENFORCE_NE(axis_data.empty(),
true,
phi::errors::InvalidArgument("'axis' can not be empty."));
const auto& x_dims = x.dims();
for (int a : axis) {
for (int a : axis_data) {
PADDLE_ENFORCE_LT(a,
x_dims.size(),
phi::errors::OutOfRange(
......@@ -2771,22 +2780,27 @@ void ReverseInferMeta(const MetaTensor& x,
}
void ReverseArrayInferMeta(const std::vector<const phi::MetaTensor*>& x,
const std::vector<int>& axis,
std::vector<phi::MetaTensor*> out) {
const IntArray& axis,
std::vector<phi::MetaTensor*> out,
MetaConfig config) {
if (!config.is_runtime && axis.FromTensor()) {
return;
}
auto& axis_data = axis.GetData();
PADDLE_ENFORCE_EQ(
axis.size(),
axis_data.size(),
1,
phi::errors::InvalidArgument(
"The size of axis must be 1 when the Input(X) is LoDTensorArray, "
"but received %d.",
axis.size()));
axis_data.size()));
PADDLE_ENFORCE_EQ(
axis[0],
axis_data[0],
0,
phi::errors::InvalidArgument("The value of axis should be 1 when "
"the Input(X) is LoDTensorArray, "
"but received %d.",
axis[0]));
axis_data[0]));
}
void RollInferMeta(const MetaTensor& x,
......
......@@ -388,12 +388,14 @@ void ReshapeWithXShapeInferMeta(const MetaTensor& x,
MetaConfig config = MetaConfig());
void ReverseInferMeta(const MetaTensor& x,
const std::vector<int>& axis,
MetaTensor* out);
const IntArray& axis,
MetaTensor* out,
MetaConfig config = MetaConfig());
void ReverseArrayInferMeta(const std::vector<const phi::MetaTensor*>& x,
const std::vector<int>& axis,
std::vector<phi::MetaTensor*> out);
const IntArray& axis,
std::vector<phi::MetaTensor*> out,
MetaConfig config = MetaConfig());
void RollInferMeta(const MetaTensor& x,
const IntArray& shifts,
......
......@@ -25,12 +25,13 @@ struct ReverseFunctor {
void operator()(const Context& dev_ctx,
const DenseTensor& in,
DenseTensor* out,
const std::vector<int>& axis) {
const IntArray& axis) {
auto& axis_data = axis.GetData();
Eigen::DSizes<bool, Rank> reverse_axis;
for (int i = 0; i < Rank; ++i) {
reverse_axis[i] = false;
}
for (int a : axis) {
for (int a : axis_data) {
if (a >= 0) {
reverse_axis[a] = true;
} else {
......@@ -50,7 +51,7 @@ struct ReverseFunctor {
template <typename T, typename Context>
void ReverseKernel(const Context& dev_ctx,
const DenseTensor& x,
const std::vector<int>& axis,
const IntArray& axis,
DenseTensor* out) {
dev_ctx.template Alloc<T>(out);
int rank = x.dims().size();
......
......@@ -23,7 +23,7 @@ namespace phi {
template <typename T, typename Context>
void ReverseArrayKernel(const Context& dev_ctx,
const std::vector<const DenseTensor*>& x,
const std::vector<int>& axis,
const IntArray& axis,
std::vector<DenseTensor*> out) {
PADDLE_ENFORCE_EQ(
x.size(),
......
......@@ -16,6 +16,7 @@
#include <vector>
#include "paddle/phi/common/int_array.h"
#include "paddle/phi/core/dense_tensor.h"
namespace phi {
......@@ -23,13 +24,13 @@ namespace phi {
template <typename T, typename Context>
void ReverseKernel(const Context& dev_ctx,
const DenseTensor& x,
const std::vector<int>& axis,
const IntArray& axis,
DenseTensor* out);
template <typename T, typename Context>
void ReverseArrayKernel(const Context& dev_ctx,
const std::vector<const DenseTensor*>& x,
const std::vector<int>& axis,
const IntArray& axis,
std::vector<DenseTensor*> out);
} // namespace phi
......@@ -1279,7 +1279,7 @@ def reverse(x, axis):
check_variable_and_dtype(x, 'x',
('float32', 'float64', 'int32', 'int64', 'uint8'),
'reverse')
check_type(axis, 'axis', (int, tuple, list), 'reverse')
check_type(axis, 'axis', (int, tuple, list, Variable), 'reverse')
if isinstance(axis, int):
axis = [axis]
if in_dygraph_mode():
......
......@@ -37,6 +37,9 @@ class UnittestBase(unittest.TestCase):
self.shapes = None
self.save_path = None
def path_prefix(self):
return type(self).__name__
def infer_prog(self):
config = paddle_infer.Config(self.save_path + '.pdmodel',
self.save_path + '.pdiparams')
......@@ -44,15 +47,21 @@ class UnittestBase(unittest.TestCase):
input_names = predictor.get_input_names()
for i, shape in enumerate(self.shapes):
input_handle = predictor.get_input_handle(input_names[i])
fake_input = np.random.randn(*shape).astype("float32")
self.fake_input = np.random.randn(*shape).astype("float32")
input_handle.reshape(shape)
input_handle.copy_from_cpu(fake_input)
input_handle.copy_from_cpu(self.fake_input)
predictor.run()
output_names = predictor.get_output_names()
output_handle = predictor.get_output_handle(output_names[0])
res = []
for out_name in output_names:
output_handle = predictor.get_output_handle(out_name)
output_data = output_handle.copy_to_cpu()
res.append(output_data)
if len(output_names) == 1:
res = res[0]
return output_data
return res
class TestDropout(UnittestBase):
......
......@@ -14,6 +14,7 @@
from __future__ import print_function
import os
import unittest
import numpy as np
from op_test import OpTest
......@@ -21,6 +22,9 @@ import paddle
import paddle.fluid as fluid
from paddle.fluid import core
from paddle.fluid.framework import program_guard, Program
from test_attribute_var import UnittestBase
class TestReverseOp(OpTest):
......@@ -195,6 +199,130 @@ class TestReverseLoDTensorArray(unittest.TestCase):
self.run_program(arr_len=3, axis=1)
class TestReverseAxisTensor(UnittestBase):
def init_info(self):
self.shapes = [[2, 3, 4]]
self.save_path = os.path.join(self.temp_dir.name, self.path_prefix())
def test_static(self):
main_prog = Program()
starup_prog = Program()
with program_guard(main_prog, starup_prog):
fc = paddle.nn.Linear(4, 10)
x = paddle.randn([2, 3, 4])
x.stop_gradient = False
feat = fc(x) # [2,3,10]
out = self.call_func(feat)
sgd = paddle.optimizer.SGD()
sgd.minimize(paddle.mean(out))
self.assertTrue(self.var_prefix() in str(main_prog))
exe = paddle.static.Executor()
exe.run(starup_prog)
res = exe.run(fetch_list=[feat, out])
gt = res[0][::-1, :, ::-1]
np.testing.assert_allclose(res[1], gt)
paddle.static.save_inference_model(self.save_path, [x], [feat, out],
exe)
# Test for Inference Predictor
infer_outs = self.infer_prog()
gt = infer_outs[0][::-1, :, ::-1]
np.testing.assert_allclose(infer_outs[1], gt)
def path_prefix(self):
return 'reverse_tensor'
def var_prefix(self):
return "Var["
def call_func(self, x):
# axes is a Variable
axes = paddle.assign([0, 2])
out = paddle.fluid.layers.reverse(x, axes)
return out
class TestReverseAxisListTensor(TestReverseAxisTensor):
def path_prefix(self):
return 'reverse_tensors'
def var_prefix(self):
return "Vars["
def call_func(self, x):
# axes is a List[Variable]
axes = [paddle.assign([0]), paddle.assign([2])]
out = paddle.fluid.layers.reverse(x, axes)
return out
class TestAReverseEagerAPI(UnittestBase):
def test_api(self):
paddle.disable_static()
x = paddle.randn([4, 10])
y = paddle.randn([4, 10])
out = paddle._C_ops.final_state_reverse_array([x, y], [0])
np.testing.assert_allclose(x.numpy(), out[1].numpy())
np.testing.assert_allclose(y.numpy(), out[0].numpy())
paddle.enable_static()
class TestReverseTensorArrayAxisTensor(UnittestBase):
def init_info(self):
self.shapes = [[2, 3, 4]]
self.save_path = os.path.join(self.temp_dir.name,
'reverse_tensor_array')
def test_static(self):
main_prog = Program()
starup_prog = Program()
with program_guard(main_prog, starup_prog):
fc = paddle.nn.Linear(4, 2)
x = paddle.randn([2, 3, 4])
x.stop_gradient = False
feat = fc(x) # [2,3,10]
# tensor_array.shape: [[2,3,10], [2,3,10]]
tensor_array = paddle.fluid.layers.create_array(dtype='float32')
idx0 = paddle.full(shape=[1], fill_value=0, dtype="int64")
val0 = paddle.randn([2, 3, 2])
paddle.fluid.layers.array_write(val0, idx0, tensor_array)
idx1 = paddle.full(shape=[1], fill_value=1, dtype="int64")
paddle.fluid.layers.array_write(feat, idx1, tensor_array)
# axes is a Variable
axes = paddle.assign([0])
# tensor_array.shape: [[2,3,10], [2,3,10]]
reverse_array = paddle.fluid.layers.reverse(tensor_array, axes)
out, _ = paddle.fluid.layers.tensor_array_to_tensor(reverse_array,
axis=0)
sgd = paddle.optimizer.SGD()
sgd.minimize(paddle.mean(out))
self.assertTrue("Var[" in str(main_prog))
exe = paddle.static.Executor()
exe.run(starup_prog)
res = exe.run(fetch_list=[val0, feat, out])
np.testing.assert_allclose(res[1], res[-1][0:2])
np.testing.assert_allclose(res[0], res[-1][2:4])
paddle.static.save_inference_model(self.save_path, [x],
[val0, feat, out], exe)
# Test for Inference Predictor
infer_outs = self.infer_prog()
np.testing.assert_allclose(infer_outs[1], infer_outs[-1][0:2])
np.testing.assert_allclose(infer_outs[0], infer_outs[-1][2:4])
if __name__ == '__main__':
paddle.enable_static()
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册