提交 e9e3c087 编写于 作者: L liym27 提交者: Aurelius84

fix expand op: (#19302)

1. add tensor support for argument expand_times in expand op;
2. add support parameter inference when argument expand_times is a list containing integer and tensor variable;

improve expand op according to reviews:
1. add doc of ExpandTimes in expand_op.cc;
2. improve the test of test_api.

add stop_gradient=True when attr(expand_times) is tensor Variable, change code examples.
test=develop,test=document_preview
上级 6bf298bf
...@@ -226,7 +226,7 @@ paddle.fluid.layers.unstack (ArgSpec(args=['x', 'axis', 'num'], varargs=None, ke ...@@ -226,7 +226,7 @@ paddle.fluid.layers.unstack (ArgSpec(args=['x', 'axis', 'num'], varargs=None, ke
paddle.fluid.layers.sequence_enumerate (ArgSpec(args=['input', 'win_size', 'pad_value', 'name'], varargs=None, keywords=None, defaults=(0, None)), ('document', 'b870fed41abd2aecf929ece65f555fa1')) paddle.fluid.layers.sequence_enumerate (ArgSpec(args=['input', 'win_size', 'pad_value', 'name'], varargs=None, keywords=None, defaults=(0, None)), ('document', 'b870fed41abd2aecf929ece65f555fa1'))
paddle.fluid.layers.unique (ArgSpec(args=['x', 'dtype'], varargs=None, keywords=None, defaults=('int32',)), ('document', 'cab0b06e5683875f12f0efc62fa230a9')) paddle.fluid.layers.unique (ArgSpec(args=['x', 'dtype'], varargs=None, keywords=None, defaults=('int32',)), ('document', 'cab0b06e5683875f12f0efc62fa230a9'))
paddle.fluid.layers.unique_with_counts (ArgSpec(args=['x', 'dtype'], varargs=None, keywords=None, defaults=('int32',)), ('document', '1cb59c65b41766116944b8ed1e6ad345')) paddle.fluid.layers.unique_with_counts (ArgSpec(args=['x', 'dtype'], varargs=None, keywords=None, defaults=('int32',)), ('document', '1cb59c65b41766116944b8ed1e6ad345'))
paddle.fluid.layers.expand (ArgSpec(args=['x', 'expand_times', 'name'], varargs=None, keywords=None, defaults=(None,)), ('document', '33bc4f6010282ffe044d77be7ba7c275')) paddle.fluid.layers.expand (ArgSpec(args=['x', 'expand_times', 'name'], varargs=None, keywords=None, defaults=(None,)), ('document', '7b97042c3ba55fb5fec6a06308523b73'))
paddle.fluid.layers.sequence_concat (ArgSpec(args=['input', 'name'], varargs=None, keywords=None, defaults=(None,)), ('document', 'b992616c1afbd6b0c2a897ac23036381')) paddle.fluid.layers.sequence_concat (ArgSpec(args=['input', 'name'], varargs=None, keywords=None, defaults=(None,)), ('document', 'b992616c1afbd6b0c2a897ac23036381'))
paddle.fluid.layers.scale (ArgSpec(args=['x', 'scale', 'bias', 'bias_after_scale', 'act', 'name'], varargs=None, keywords=None, defaults=(1.0, 0.0, True, None, None)), ('document', '463e4713806e5adaa4d20a41e2218453')) paddle.fluid.layers.scale (ArgSpec(args=['x', 'scale', 'bias', 'bias_after_scale', 'act', 'name'], varargs=None, keywords=None, defaults=(1.0, 0.0, True, None, None)), ('document', '463e4713806e5adaa4d20a41e2218453'))
paddle.fluid.layers.elementwise_add (ArgSpec(args=['x', 'y', 'axis', 'act', 'name'], varargs=None, keywords=None, defaults=(-1, None, None)), ('document', '5c0fb7298aec32525f96d451ae4c2851')) paddle.fluid.layers.elementwise_add (ArgSpec(args=['x', 'y', 'axis', 'act', 'name'], varargs=None, keywords=None, defaults=(-1, None, None)), ('document', '5c0fb7298aec32525f96d451ae4c2851'))
......
...@@ -28,14 +28,15 @@ class ExpandOp : public framework::OperatorWithKernel { ...@@ -28,14 +28,15 @@ class ExpandOp : public framework::OperatorWithKernel {
protected: protected:
void InferShape(framework::InferShapeContext* ctx) const override { void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should not be null."); PADDLE_ENFORCE_EQ(ctx->HasInput("X"), true, "Input(X) should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("Out"), "Output(Out) should not be null."); PADDLE_ENFORCE_EQ(ctx->HasOutput("Out"), true,
"Output(Out) should not be null.");
auto x_dims = ctx->GetInputDim("X"); auto x_dims = ctx->GetInputDim("X");
std::vector<int> expand_times(x_dims.size(), -1); auto expand_times = ctx->Attrs().Get<std::vector<int>>("expand_times");
if (!ctx->HasInputs("expand_times_tensor")) { if (expand_times.size() == 0) {
expand_times = ctx->Attrs().Get<std::vector<int>>("expand_times"); expand_times = std::vector<int>(x_dims.size(), -1);
} }
PADDLE_ENFORCE_EQ(static_cast<size_t>(x_dims.size()), expand_times.size(), PADDLE_ENFORCE_EQ(static_cast<size_t>(x_dims.size()), expand_times.size(),
...@@ -49,6 +50,9 @@ class ExpandOp : public framework::OperatorWithKernel { ...@@ -49,6 +50,9 @@ class ExpandOp : public framework::OperatorWithKernel {
if (x_dims[i] == -1 || expand_times[i] == -1) { if (x_dims[i] == -1 || expand_times[i] == -1) {
out_shape[i] = -1; out_shape[i] = -1;
} else { } else {
PADDLE_ENFORCE_GT(
expand_times[i], 0,
"The element of Attr(expand_times) must greater than 0.");
out_shape[i] = x_dims[i] * expand_times[i]; out_shape[i] = x_dims[i] * expand_times[i];
} }
} }
...@@ -69,7 +73,7 @@ class ExpandOp : public framework::OperatorWithKernel { ...@@ -69,7 +73,7 @@ class ExpandOp : public framework::OperatorWithKernel {
framework::OpKernelType GetKernelTypeForVar( framework::OpKernelType GetKernelTypeForVar(
const std::string& var_name, const Tensor& tensor, const std::string& var_name, const Tensor& tensor,
const framework::OpKernelType& expected_kernel_type) const override { const framework::OpKernelType& expected_kernel_type) const override {
if (var_name == "expand_times_tensor") { if (var_name == "expand_times_tensor" || var_name == "ExpandTimes") {
return expected_kernel_type; return expected_kernel_type;
} }
return framework::OpKernelType(expected_kernel_type.data_type_, return framework::OpKernelType(expected_kernel_type.data_type_,
...@@ -83,7 +87,15 @@ class ExpandOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -83,7 +87,15 @@ class ExpandOpMaker : public framework::OpProtoAndCheckerMaker {
AddInput("X", AddInput("X",
"(Tensor, default Tensor<float>). A tensor with rank in [1, 6]." "(Tensor, default Tensor<float>). A tensor with rank in [1, 6]."
"X is the input to be expanded."); "X is the input to be expanded.");
AddInput("expand_times_tensor", "(Tensor Tensor<int>), epxand times for X") AddInput("ExpandTimes",
"(Tensor<int>), optional). If provided, expand according to "
"this given expand times. It has a higher priority than "
"expand_times_tensor and expand_times.")
.AsDispensable();
AddInput("expand_times_tensor",
"(Tensor Tensor<int>), epxand times for X."
"It has a higher priority than expand_times, but a lower priority "
"than ExpandTimes")
.AsDuplicable() .AsDuplicable()
.AsDispensable(); .AsDispensable();
AddOutput("Out", AddOutput("Out",
...@@ -127,9 +139,9 @@ class ExpandGradOp : public framework::OperatorWithKernel { ...@@ -127,9 +139,9 @@ class ExpandGradOp : public framework::OperatorWithKernel {
protected: protected:
void InferShape(framework::InferShapeContext* ctx) const override { void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should not be null."); PADDLE_ENFORCE_EQ(ctx->HasInput("X"), true, "Input(X) should not be null.");
PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")), PADDLE_ENFORCE_EQ(ctx->HasInput(framework::GradVarName("Out")), true,
"Input(Out@GRAD) should not be null."); "Input(Out@GRAD) should not be null.");
auto x_dims = ctx->GetInputDim("X"); auto x_dims = ctx->GetInputDim("X");
std::vector<int> expand_times = std::vector<int> expand_times =
...@@ -147,12 +159,15 @@ class ExpandGradOp : public framework::OperatorWithKernel { ...@@ -147,12 +159,15 @@ class ExpandGradOp : public framework::OperatorWithKernel {
} }
for (size_t i = start_pos; i < expand_times.size(); ++i) { for (size_t i = start_pos; i < expand_times.size(); ++i) {
PADDLE_ENFORCE_EQ(x_dims[i] * expand_times[i], out_dims[i], if (expand_times[i] == -1) {
"Each dimension size of Input(Out@GRAD) should be " continue;
"equal to multiplication of crroresponding dimension " } else {
"size of Input(X) and Attr(expand_times) value."); PADDLE_ENFORCE_EQ(x_dims[i] * expand_times[i], out_dims[i],
"Each dimension size of Input(Out@GRAD) should be "
"equal to multiplication of crroresponding dimension "
"size of Input(X) and Attr(expand_times) value.");
}
} }
auto x_grad_name = framework::GradVarName("X"); auto x_grad_name = framework::GradVarName("X");
if (ctx->HasOutput(x_grad_name)) { if (ctx->HasOutput(x_grad_name)) {
...@@ -191,6 +206,7 @@ class ExpandGradOpDescMaker : public framework::SingleGradOpDescMaker { ...@@ -191,6 +206,7 @@ class ExpandGradOpDescMaker : public framework::SingleGradOpDescMaker {
op->SetInput(framework::GradVarName("Out"), OutputGrad("Out")); op->SetInput(framework::GradVarName("Out"), OutputGrad("Out"));
op->SetOutput(framework::GradVarName("X"), InputGrad("X")); op->SetOutput(framework::GradVarName("X"), InputGrad("X"));
op->SetInput("expand_times_tensor", Input("expand_times_tensor")); op->SetInput("expand_times_tensor", Input("expand_times_tensor"));
op->SetInput("ExpandTimes", Input("ExpandTimes"));
op->SetAttrMap(Attrs()); op->SetAttrMap(Attrs());
return op; return op;
} }
......
...@@ -50,6 +50,19 @@ namespace paddle { ...@@ -50,6 +50,19 @@ namespace paddle {
namespace operators { namespace operators {
inline std::vector<int> get_expand_times( inline std::vector<int> get_expand_times(
const framework::ExecutionContext& ctx) { const framework::ExecutionContext& ctx) {
if (ctx.HasInput("ExpandTimes")) {
auto* expand_tensor = ctx.Input<framework::LoDTensor>("ExpandTimes");
auto* expand_data = expand_tensor->data<int>();
framework::Tensor cpu_expand_tensor;
if (platform::is_gpu_place(expand_tensor->place())) {
TensorCopySync(*expand_tensor, platform::CPUPlace(), &cpu_expand_tensor);
expand_data = cpu_expand_tensor.data<int>();
}
auto vec_epxand_times =
std::vector<int>(expand_data, expand_data + expand_tensor->numel());
return vec_epxand_times;
}
auto list_expand_times_tensor = auto list_expand_times_tensor =
ctx.MultiInput<framework::Tensor>("expand_times_tensor"); ctx.MultiInput<framework::Tensor>("expand_times_tensor");
if (list_expand_times_tensor.size() > 0) { if (list_expand_times_tensor.size() > 0) {
...@@ -100,6 +113,9 @@ class ExpandKernel : public framework::OpKernel<T> { ...@@ -100,6 +113,9 @@ class ExpandKernel : public framework::OpKernel<T> {
auto in_dims = in0->dims(); auto in_dims = in0->dims();
auto expand_times = get_expand_times(context); auto expand_times = get_expand_times(context);
PADDLE_ENFORCE_EQ(static_cast<size_t>(in_dims.size()), expand_times.size(),
"The number of Attr(expand_times)'s value must be equal "
"to the rank of Input(X).");
auto* out0 = context.Output<Tensor>("Out"); auto* out0 = context.Output<Tensor>("Out");
Eigen::DSizes<int, Rank> bcast_dims; Eigen::DSizes<int, Rank> bcast_dims;
for (size_t i = 0; i < expand_times.size(); ++i) { for (size_t i = 0; i < expand_times.size(); ++i) {
......
...@@ -10290,7 +10290,7 @@ def expand(x, expand_times, name=None): ...@@ -10290,7 +10290,7 @@ def expand(x, expand_times, name=None):
Args: Args:
x (Variable): A tensor with rank in [1, 6]. x (Variable): A tensor with rank in [1, 6].
expand_times (list|tuple): Expand times number for each dimension. expand_times (list|tuple|Variable): Expand times number for each dimension.
Returns: Returns:
Variable: The expanded variable which is a LoDTensor. After expanding, size of each dimension of Output(Out) is equal to ithe size of the corresponding dimension of Input(X) multiplying the corresponding value given by expand_times. Variable: The expanded variable which is a LoDTensor. After expanding, size of each dimension of Output(Out) is equal to ithe size of the corresponding dimension of Input(X) multiplying the corresponding value given by expand_times.
...@@ -10298,46 +10298,72 @@ def expand(x, expand_times, name=None): ...@@ -10298,46 +10298,72 @@ def expand(x, expand_times, name=None):
Examples: Examples:
.. code-block:: python .. code-block:: python
import paddle.fluid as fluid import paddle.fluid as fluid
x = fluid.layers.fill_constant(shape=[2, 3, 1], dtype='int32', value=0)
out = fluid.layers.expand(x=x, expand_times=[1, 2, 2]) # example 1:
data_1 = fluid.layers.fill_constant(shape=[2, 3, 1], dtype='int32', value=0)
expanded_1 = fluid.layers.expand(data_1, expand_times=[1, 2, 2])
# example 2:
data_2 = fluid.layers.fill_constant(shape=[12, 14], dtype="int32", value=3)
expand_times = fluid.layers.fill_constant(shape=[2], dtype="int32", value=4)
expanded_2 = fluid.layers.expand(data_2, expand_times=expand_times)
""" """
if not isinstance(expand_times, (list, tuple, Variable)):
raise ValueError(
"Input expand_times must be an Variable, python list or tuple.")
helper = LayerHelper('expand', input=x, **locals()) helper = LayerHelper('expand', input=x, **locals())
dtype = helper.input_dtype(input_param_name='x') inputs = {"X": x}
out = helper.create_variable_for_type_inference(dtype) attrs = {}
# check expand_times have tensor
def contain_var(expand_times):
for ele in expand_times:
if isinstance(ele, Variable):
return True
return False
def get_attr_expand_times(list_expand_times):
attrs_expand_times = []
for idx, times in enumerate(list_expand_times):
if isinstance(times, Variable):
attrs_expand_times.append(-1)
else:
attrs_expand_times.append(times)
assert times > 0, (
"Each element given in expand_times must not be negtive.")
return attrs_expand_times
def get_new_expand_times_tensor(list_expand_times):
new_expand_times_tensor = []
for ele in list_expand_times:
if isinstance(ele, Variable):
ele.stop_gradient = True
new_expand_times_tensor.append(ele)
else:
assert (isinstance(ele, int))
temp_out = helper.create_variable_for_type_inference('int32')
fill_constant([1], 'int32', ele, force_cpu=True, out=temp_out)
new_expand_times_tensor.append(temp_out)
return new_expand_times_tensor
if in_dygraph_mode(): if in_dygraph_mode():
inputs = {'X': x} inputs = {'X': x}
attrs = {'expand_times': expand_times} attrs = {'expand_times': expand_times}
else: else:
if isinstance(expand_times, Variable):
expand_times.stop_gradient = True
inputs['ExpandTimes'] = expand_times
elif isinstance(expand_times, (list, tuple)):
attrs['expand_times'] = get_attr_expand_times(expand_times)
if contain_var(expand_times):
inputs['expand_times_tensor'] = get_new_expand_times_tensor(
expand_times)
def contain_tensor(expand_times): dtype = helper.input_dtype(input_param_name='x')
for ele in expand_times: out = helper.create_variable_for_type_inference(dtype)
if isinstance(ele, Variable):
return True
return False
if contain_tensor(expand_times):
new_expand_times = []
for ele in expand_times:
if isinstance(ele, Variable):
ele.stop_gradient = True
new_expand_times.append(ele)
else:
assert (isinstance(ele, int))
temp_out = helper.create_variable_for_type_inference(
"int32")
fill_constant(
[1], 'int32', ele, force_cpu=True, out=temp_out)
new_expand_times.append(temp_out)
inputs = {'X': x, 'expand_times_tensor': new_expand_times}
attrs = {}
else:
inputs = {'X': x}
attrs = {'expand_times': expand_times}
helper.append_op( helper.append_op(
type='expand', inputs=inputs, outputs={'Out': out}, attrs=attrs) type='expand', inputs=inputs, outputs={'Out': out}, attrs=attrs)
return out return out
......
...@@ -17,16 +17,24 @@ from __future__ import print_function ...@@ -17,16 +17,24 @@ from __future__ import print_function
import unittest import unittest
import numpy as np import numpy as np
from op_test import OpTest from op_test import OpTest
import paddle.fluid as fluid
# Situation 1: expand_times is a list(without tensor)
class TestExpandOpRank1(OpTest): class TestExpandOpRank1(OpTest):
def setUp(self): def setUp(self):
self.op_type = "expand" self.op_type = "expand"
self.inputs = {'X': np.random.random(12).astype("float32")} self.init_data()
self.attrs = {'expand_times': [2]}
output = np.tile(self.inputs['X'], 2) self.inputs = {'X': np.random.random(self.ori_shape).astype("float32")}
self.attrs = {'expand_times': self.expand_times}
output = np.tile(self.inputs['X'], self.expand_times)
self.outputs = {'Out': output} self.outputs = {'Out': output}
def init_data(self):
self.ori_shape = [12]
self.expand_times = [2]
def test_check_output(self): def test_check_output(self):
self.check_output() self.check_output()
...@@ -34,51 +42,59 @@ class TestExpandOpRank1(OpTest): ...@@ -34,51 +42,59 @@ class TestExpandOpRank1(OpTest):
self.check_grad(['X'], 'Out') self.check_grad(['X'], 'Out')
class TestExpandOpRank1_tensor_attr(OpTest): class TestExpandOpRank2_Corner(TestExpandOpRank1):
def setUp(self): def init_data(self):
self.op_type = "expand" self.ori_shape = [12]
self.inputs = { self.expand_times = [2]
'X': np.random.random(12).astype("float32"),
'expand_times_tensor': [('x1', np.ones((1)).astype('int32') * 2)]
}
self.attrs = {}
output = np.tile(self.inputs['X'], 2)
self.outputs = {'Out': output}
def test_check_output(self):
self.check_output()
def test_check_grad(self): class TestExpandOpRank2(TestExpandOpRank1):
self.check_grad(['X'], 'Out', no_grad_set=set('x1')) def init_data(self):
self.ori_shape = [12, 14]
self.expand_times = [2, 3]
class TestExpandOpRank2_Corner(OpTest): class TestExpandOpRank3_Corner(TestExpandOpRank1):
def setUp(self): def init_data(self):
self.op_type = "expand" self.ori_shape = (2, 4, 5)
self.inputs = {'X': np.random.random((12, 14)).astype("float32")} self.expand_times = (1, 1, 1)
self.attrs = {'expand_times': [1, 1]}
output = np.tile(self.inputs['X'], (1, 1))
self.outputs = {'Out': output}
def test_check_output(self):
self.check_output()
def test_check_grad(self): class TestExpandOpRank3(TestExpandOpRank1):
self.check_grad(['X'], 'Out') def init_data(self):
self.ori_shape = (2, 4, 5)
self.expand_times = (2, 1, 4)
class TestExpandOpRank4(TestExpandOpRank1):
def init_data(self):
self.ori_shape = (2, 4, 5, 7)
self.expand_times = (3, 2, 1, 2)
class TestExpandOpRank2_Corner_tensor_attr(OpTest):
# Situation 2: expand_times is a list(with tensor)
class TestExpandOpRank1_tensor_attr(OpTest):
def setUp(self): def setUp(self):
self.op_type = "expand" self.op_type = "expand"
self.init_data()
expand_times_tensor = []
for index, ele in enumerate(self.expand_times):
expand_times_tensor.append(("x" + str(index), np.ones(
(1)).astype('int32') * ele))
self.inputs = { self.inputs = {
'X': np.random.random((12, 14)).astype("float32"), 'X': np.random.random(self.ori_shape).astype("float32"),
'expand_times_tensor': [('x1', np.ones((1)).astype('int32')), 'expand_times_tensor': expand_times_tensor,
('x2', np.ones((1)).astype('int32'))]
} }
self.attrs = {} self.attrs = {"expand_times": self.infer_expand_times}
output = np.tile(self.inputs['X'], (1, 1)) output = np.tile(self.inputs['X'], self.expand_times)
self.outputs = {'Out': output} self.outputs = {'Out': output}
def init_data(self):
self.ori_shape = [12]
self.expand_times = [2]
self.infer_expand_times = [-1]
def test_check_output(self): def test_check_output(self):
self.check_output() self.check_output()
...@@ -86,47 +102,37 @@ class TestExpandOpRank2_Corner_tensor_attr(OpTest): ...@@ -86,47 +102,37 @@ class TestExpandOpRank2_Corner_tensor_attr(OpTest):
self.check_grad(['X'], 'Out') self.check_grad(['X'], 'Out')
class TestExpandOpRank2(OpTest): class TestExpandOpRank2_Corner_tensor_attr(TestExpandOpRank1_tensor_attr):
def setUp(self): def init_data(self):
self.op_type = "expand" self.ori_shape = [12, 14]
self.inputs = {'X': np.random.random((12, 14)).astype("float32")} self.expand_times = [1, 1]
self.attrs = {'expand_times': [2, 3]} self.infer_expand_times = [1, -1]
output = np.tile(self.inputs['X'], (2, 3))
self.outputs = {'Out': output}
def test_check_output(self):
self.check_output()
def test_check_grad(self): class TestExpandOpRank2_attr_tensor(TestExpandOpRank1_tensor_attr):
self.check_grad(['X'], 'Out') def init_data(self):
self.ori_shape = [12, 14]
self.expand_times = [2, 3]
self.infer_expand_times = [-1, 3]
class TestExpandOpRank2_attr_tensor(OpTest): # Situation 3: expand_times is a tensor
class TestExpandOpRank1_tensor(OpTest):
def setUp(self): def setUp(self):
self.op_type = "expand" self.op_type = "expand"
self.init_data()
self.inputs = { self.inputs = {
'X': np.random.random((12, 14)).astype("float32"), 'X': np.random.random(self.ori_shape).astype("float32"),
'expand_times_tensor': [('x1', np.ones((1)).astype('int32') * 2), 'ExpandTimes': np.array(self.expand_times).astype("int32"),
('x2', np.ones((1)).astype('int32') * 3)]
} }
self.attrs = {} self.attrs = {}
output = np.tile(self.inputs['X'], (2, 3)) output = np.tile(self.inputs['X'], self.expand_times)
self.outputs = {'Out': output} self.outputs = {'Out': output}
def test_check_output(self): def init_data(self):
self.check_output() self.ori_shape = [12]
self.expand_times = [2]
def test_check_grad(self):
self.check_grad(['X'], 'Out')
class TestExpandOpRank3_Corner(OpTest):
def setUp(self):
self.op_type = "expand"
self.inputs = {'X': np.random.random((2, 4, 5)).astype("float32")}
self.attrs = {'expand_times': [1, 1, 1]}
output = np.tile(self.inputs['X'], (1, 1, 1))
self.outputs = {'Out': output}
def test_check_output(self): def test_check_output(self):
self.check_output() self.check_output()
...@@ -135,36 +141,13 @@ class TestExpandOpRank3_Corner(OpTest): ...@@ -135,36 +141,13 @@ class TestExpandOpRank3_Corner(OpTest):
self.check_grad(['X'], 'Out') self.check_grad(['X'], 'Out')
class TestExpandOpRank3(OpTest): class TestExpandOpRank2_tensor(TestExpandOpRank1_tensor):
def setUp(self): def init_data(self):
self.op_type = "expand" self.ori_shape = [12, 14]
self.inputs = {'X': np.random.random((2, 4, 5)).astype("float32")} self.expand_times = [2, 3]
self.attrs = {'expand_times': [2, 1, 4]}
output = np.tile(self.inputs['X'], (2, 1, 4))
self.outputs = {'Out': output}
def test_check_output(self):
self.check_output()
def test_check_grad(self):
self.check_grad(['X'], 'Out')
class TestExpandOpRank4(OpTest):
def setUp(self):
self.op_type = "expand"
self.inputs = {'X': np.random.random((2, 4, 5, 7)).astype("float32")}
self.attrs = {'expand_times': [3, 2, 1, 2]}
output = np.tile(self.inputs['X'], (3, 2, 1, 2))
self.outputs = {'Out': output}
def test_check_output(self):
self.check_output()
def test_check_grad(self):
self.check_grad(['X'], 'Out')
# Situation 4: input x is Integer
class TestExpandOpInteger(OpTest): class TestExpandOpInteger(OpTest):
def setUp(self): def setUp(self):
self.op_type = "expand" self.op_type = "expand"
...@@ -180,6 +163,7 @@ class TestExpandOpInteger(OpTest): ...@@ -180,6 +163,7 @@ class TestExpandOpInteger(OpTest):
self.check_output() self.check_output()
# Situation 5: input x is Bool
class TestExpandOpBoolean(OpTest): class TestExpandOpBoolean(OpTest):
def setUp(self): def setUp(self):
self.op_type = "expand" self.op_type = "expand"
...@@ -192,5 +176,33 @@ class TestExpandOpBoolean(OpTest): ...@@ -192,5 +176,33 @@ class TestExpandOpBoolean(OpTest):
self.check_output() self.check_output()
# Test python API
class TestExpandAPI(OpTest):
def test_api(self):
input = np.random.random([12, 14]).astype("float32")
x = fluid.layers.data(
name='x', shape=[12, 14], append_batch_size=False, dtype="float32")
positive_2 = fluid.layers.fill_constant([1], "int32", 2)
expand_times = fluid.layers.data(
name="expand_times", shape=[2], append_batch_size=False)
out_1 = fluid.layers.expand(x, expand_times=[2, 3])
out_2 = fluid.layers.expand(x, expand_times=[positive_2, 3])
out_3 = fluid.layers.expand(x, expand_times=expand_times)
exe = fluid.Executor(place=fluid.CPUPlace())
res_1, res_2, res_3 = exe.run(fluid.default_main_program(),
feed={
"x": input,
"expand_times":
np.array([1, 3]).astype("int32")
},
fetch_list=[out_1, out_2, out_3])
assert np.array_equal(res_1, np.tile(input, (2, 3)))
assert np.array_equal(res_2, np.tile(input, (2, 3)))
assert np.array_equal(res_3, np.tile(input, (1, 3)))
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册