Commit e9e3c087 authored by liym27, committed by Aurelius84

fix expand op: (#19302)

1. add Tensor support for the argument expand_times in the expand op;
2. support inference when the argument expand_times is a list containing both integers and tensor Variables;

improve the expand op according to reviews:
1. add documentation for ExpandTimes in expand_op.cc;
2. improve the test_api test.

set stop_gradient=True when Attr(expand_times) is a tensor Variable; update the code examples.
test=develop,test=document_preview
Parent 6bf298bf
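Taken together, after this change expand_times can be supplied in three ways; a minimal usage sketch assembled from the docstring examples and the unit test in this commit:

import paddle.fluid as fluid

x = fluid.layers.fill_constant(shape=[12, 14], dtype="float32", value=3)

# 1. a plain python list of ints (the original behaviour)
out_1 = fluid.layers.expand(x, expand_times=[2, 3])

# 2. a list mixing ints and 1-element int32 Variables
positive_2 = fluid.layers.fill_constant([1], "int32", 2)
out_2 = fluid.layers.expand(x, expand_times=[positive_2, 3])

# 3. a single int32 tensor Variable holding all expand times
expand_times = fluid.layers.fill_constant(shape=[2], dtype="int32", value=4)
out_3 = fluid.layers.expand(x, expand_times=expand_times)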
......@@ -226,7 +226,7 @@ paddle.fluid.layers.unstack (ArgSpec(args=['x', 'axis', 'num'], varargs=None, ke
paddle.fluid.layers.sequence_enumerate (ArgSpec(args=['input', 'win_size', 'pad_value', 'name'], varargs=None, keywords=None, defaults=(0, None)), ('document', 'b870fed41abd2aecf929ece65f555fa1'))
paddle.fluid.layers.unique (ArgSpec(args=['x', 'dtype'], varargs=None, keywords=None, defaults=('int32',)), ('document', 'cab0b06e5683875f12f0efc62fa230a9'))
paddle.fluid.layers.unique_with_counts (ArgSpec(args=['x', 'dtype'], varargs=None, keywords=None, defaults=('int32',)), ('document', '1cb59c65b41766116944b8ed1e6ad345'))
paddle.fluid.layers.expand (ArgSpec(args=['x', 'expand_times', 'name'], varargs=None, keywords=None, defaults=(None,)), ('document', '33bc4f6010282ffe044d77be7ba7c275'))
paddle.fluid.layers.expand (ArgSpec(args=['x', 'expand_times', 'name'], varargs=None, keywords=None, defaults=(None,)), ('document', '7b97042c3ba55fb5fec6a06308523b73'))
paddle.fluid.layers.sequence_concat (ArgSpec(args=['input', 'name'], varargs=None, keywords=None, defaults=(None,)), ('document', 'b992616c1afbd6b0c2a897ac23036381'))
paddle.fluid.layers.scale (ArgSpec(args=['x', 'scale', 'bias', 'bias_after_scale', 'act', 'name'], varargs=None, keywords=None, defaults=(1.0, 0.0, True, None, None)), ('document', '463e4713806e5adaa4d20a41e2218453'))
paddle.fluid.layers.elementwise_add (ArgSpec(args=['x', 'y', 'axis', 'act', 'name'], varargs=None, keywords=None, defaults=(-1, None, None)), ('document', '5c0fb7298aec32525f96d451ae4c2851'))
......
......@@ -28,14 +28,15 @@ class ExpandOp : public framework::OperatorWithKernel {
protected:
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("Out"), "Output(Out) should not be null.");
PADDLE_ENFORCE_EQ(ctx->HasInput("X"), true, "Input(X) should not be null.");
PADDLE_ENFORCE_EQ(ctx->HasOutput("Out"), true,
"Output(Out) should not be null.");
auto x_dims = ctx->GetInputDim("X");
std::vector<int> expand_times(x_dims.size(), -1);
auto expand_times = ctx->Attrs().Get<std::vector<int>>("expand_times");
if (!ctx->HasInputs("expand_times_tensor")) {
expand_times = ctx->Attrs().Get<std::vector<int>>("expand_times");
if (expand_times.size() == 0) {
expand_times = std::vector<int>(x_dims.size(), -1);
}
PADDLE_ENFORCE_EQ(static_cast<size_t>(x_dims.size()), expand_times.size(),
......@@ -49,6 +50,9 @@ class ExpandOp : public framework::OperatorWithKernel {
if (x_dims[i] == -1 || expand_times[i] == -1) {
out_shape[i] = -1;
} else {
PADDLE_ENFORCE_GT(
expand_times[i], 0,
"The element of Attr(expand_times) must greater than 0.");
out_shape[i] = x_dims[i] * expand_times[i];
}
}
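The shape rule implemented above, restated as a small Python sketch (a hypothetical helper, for illustration only): a -1 in either x_dims or expand_times defers that dimension to runtime.

def infer_expand_shape(x_dims, expand_times):
    # mirrors ExpandOp::InferShape: -1 marks a size unknown until runtime
    out_shape = []
    for d, t in zip(x_dims, expand_times):
        if d == -1 or t == -1:
            out_shape.append(-1)
        else:
            assert t > 0, "each element of expand_times must be positive"
            out_shape.append(d * t)
    return out_shape

# infer_expand_shape([2, 3, 1], [1, 2, 2]) -> [2, 6, 2]
# infer_expand_shape([2, -1],   [2, 3])    -> [4, -1]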
......@@ -69,7 +73,7 @@ class ExpandOp : public framework::OperatorWithKernel {
framework::OpKernelType GetKernelTypeForVar(
const std::string& var_name, const Tensor& tensor,
const framework::OpKernelType& expected_kernel_type) const override {
if (var_name == "expand_times_tensor") {
if (var_name == "expand_times_tensor" || var_name == "ExpandTimes") {
return expected_kernel_type;
}
return framework::OpKernelType(expected_kernel_type.data_type_,
......@@ -83,7 +87,15 @@ class ExpandOpMaker : public framework::OpProtoAndCheckerMaker {
AddInput("X",
"(Tensor, default Tensor<float>). A tensor with rank in [1, 6]."
"X is the input to be expanded.");
AddInput("expand_times_tensor", "(Tensor Tensor<int>), epxand times for X")
AddInput("ExpandTimes",
"(Tensor<int>), optional). If provided, expand according to "
"this given expand times. It has a higher priority than "
"expand_times_tensor and expand_times.")
.AsDispensable();
AddInput("expand_times_tensor",
"(Tensor Tensor<int>), epxand times for X."
"It has a higher priority than expand_times, but a lower priority "
"than ExpandTimes")
.AsDuplicable()
.AsDispensable();
AddOutput("Out",
......@@ -127,9 +139,9 @@ class ExpandGradOp : public framework::OperatorWithKernel {
protected:
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should not be null.");
PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")),
"Input(Out@GRAD) should not be null.");
PADDLE_ENFORCE_EQ(ctx->HasInput("X"), true, "Input(X) should not be null.");
PADDLE_ENFORCE_EQ(ctx->HasInput(framework::GradVarName("Out")), true,
"Input(Out@GRAD) should not be null.");
auto x_dims = ctx->GetInputDim("X");
std::vector<int> expand_times =
......@@ -147,12 +159,15 @@ class ExpandGradOp : public framework::OperatorWithKernel {
}
for (size_t i = start_pos; i < expand_times.size(); ++i) {
PADDLE_ENFORCE_EQ(x_dims[i] * expand_times[i], out_dims[i],
"Each dimension size of Input(Out@GRAD) should be "
"equal to multiplication of crroresponding dimension "
"size of Input(X) and Attr(expand_times) value.");
if (expand_times[i] == -1) {
continue;
} else {
PADDLE_ENFORCE_EQ(x_dims[i] * expand_times[i], out_dims[i],
"Each dimension size of Input(Out@GRAD) should be "
"equal to multiplication of crroresponding dimension "
"size of Input(X) and Attr(expand_times) value.");
}
}
auto x_grad_name = framework::GradVarName("X");
if (ctx->HasOutput(x_grad_name)) {
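Dimensions whose expand time is -1 (supplied via a tensor and unknown at compile time) are now skipped in the gradient shape check; a Python restatement of the loop above (hypothetical helper):

def check_grad_dims(x_dims, expand_times, out_dims, start_pos=0):
    for i in range(start_pos, len(expand_times)):
        if expand_times[i] == -1:
            continue  # unknown until runtime, cannot be checked here
        assert x_dims[i] * expand_times[i] == out_dims[i], (
            "Each dimension size of Out@GRAD should equal the "
            "corresponding dimension size of X times expand_times")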
......@@ -191,6 +206,7 @@ class ExpandGradOpDescMaker : public framework::SingleGradOpDescMaker {
op->SetInput(framework::GradVarName("Out"), OutputGrad("Out"));
op->SetOutput(framework::GradVarName("X"), InputGrad("X"));
op->SetInput("expand_times_tensor", Input("expand_times_tensor"));
op->SetInput("ExpandTimes", Input("ExpandTimes"));
op->SetAttrMap(Attrs());
return op;
}
......
......@@ -50,6 +50,19 @@ namespace paddle {
namespace operators {
inline std::vector<int> get_expand_times(
const framework::ExecutionContext& ctx) {
if (ctx.HasInput("ExpandTimes")) {
auto* expand_tensor = ctx.Input<framework::LoDTensor>("ExpandTimes");
auto* expand_data = expand_tensor->data<int>();
framework::Tensor cpu_expand_tensor;
if (platform::is_gpu_place(expand_tensor->place())) {
TensorCopySync(*expand_tensor, platform::CPUPlace(), &cpu_expand_tensor);
expand_data = cpu_expand_tensor.data<int>();
}
auto vec_expand_times =
std::vector<int>(expand_data, expand_data + expand_tensor->numel());
return vec_expand_times;
}
auto list_expand_times_tensor =
ctx.MultiInput<framework::Tensor>("expand_times_tensor");
if (list_expand_times_tensor.size() > 0) {
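Note the TensorCopySync above: the int values of ExpandTimes must be read on the host, so a GPU-resident tensor is first copied synchronously to CPU memory. Conceptually (a hypothetical Python sketch; on_gpu and copy_to_cpu stand in for is_gpu_place and TensorCopySync):

def read_expand_times(expand_tensor, on_gpu, copy_to_cpu):
    if on_gpu:
        expand_tensor = copy_to_cpu(expand_tensor)  # blocking device-to-host copy
    return [int(v) for v in expand_tensor]          # one int per element (numel)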
......@@ -100,6 +113,9 @@ class ExpandKernel : public framework::OpKernel<T> {
auto in_dims = in0->dims();
auto expand_times = get_expand_times(context);
PADDLE_ENFORCE_EQ(static_cast<size_t>(in_dims.size()), expand_times.size(),
"The number of Attr(expand_times)'s value must be equal "
"to the rank of Input(X).");
auto* out0 = context.Output<Tensor>("Out");
Eigen::DSizes<int, Rank> bcast_dims;
for (size_t i = 0; i < expand_times.size(); ++i) {
......
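Numerically the kernel's Eigen broadcast matches numpy's tile, which is also what the unit tests below compare against; a quick check (sketch):

import numpy as np

x = np.random.random((2, 4, 5)).astype("float32")
out = np.tile(x, (2, 1, 4))
assert out.shape == (4, 4, 20)  # x_dims[i] * expand_times[i] per dimension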
......@@ -10290,7 +10290,7 @@ def expand(x, expand_times, name=None):
Args:
x (Variable): A tensor with rank in [1, 6].
expand_times (list|tuple): Expand times number for each dimension.
expand_times (list|tuple|Variable): The number of times to expand each dimension of x.
Returns:
Variable: The expanded variable which is a LoDTensor. After expanding, the size of each dimension of Output(Out) is equal to the size of the corresponding dimension of Input(X) multiplied by the corresponding value given in expand_times.
......@@ -10298,46 +10298,72 @@ def expand(x, expand_times, name=None):
Examples:
.. code-block:: python
import paddle.fluid as fluid
x = fluid.layers.fill_constant(shape=[2, 3, 1], dtype='int32', value=0)
out = fluid.layers.expand(x=x, expand_times=[1, 2, 2])
# example 1:
data_1 = fluid.layers.fill_constant(shape=[2, 3, 1], dtype='int32', value=0)
expanded_1 = fluid.layers.expand(data_1, expand_times=[1, 2, 2])
# example 2:
data_2 = fluid.layers.fill_constant(shape=[12, 14], dtype="int32", value=3)
expand_times = fluid.layers.fill_constant(shape=[2], dtype="int32", value=4)
expanded_2 = fluid.layers.expand(data_2, expand_times=expand_times)
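# example 3 (a sketch mirroring the unit test in this commit): expand_times
# may also mix python ints with 1-element int32 Variables
data_3 = fluid.layers.fill_constant(shape=[12, 14], dtype="int32", value=3)
positive_2 = fluid.layers.fill_constant([1], "int32", 2)
expanded_3 = fluid.layers.expand(data_3, expand_times=[positive_2, 3])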
"""
if not isinstance(expand_times, (list, tuple, Variable)):
raise ValueError(
"Input expand_times must be an Variable, python list or tuple.")
helper = LayerHelper('expand', input=x, **locals())
dtype = helper.input_dtype(input_param_name='x')
out = helper.create_variable_for_type_inference(dtype)
# check whether expand_times contains tensor Variables
inputs = {"X": x}
attrs = {}
def contain_var(expand_times):
for ele in expand_times:
if isinstance(ele, Variable):
return True
return False
def get_attr_expand_times(list_expand_times):
attrs_expand_times = []
for idx, times in enumerate(list_expand_times):
if isinstance(times, Variable):
attrs_expand_times.append(-1)
else:
attrs_expand_times.append(times)
assert times > 0, (
"Each element given in expand_times must not be negtive.")
return attrs_expand_times
def get_new_expand_times_tensor(list_expand_times):
new_expand_times_tensor = []
for ele in list_expand_times:
if isinstance(ele, Variable):
ele.stop_gradient = True
new_expand_times_tensor.append(ele)
else:
assert (isinstance(ele, int))
temp_out = helper.create_variable_for_type_inference('int32')
fill_constant([1], 'int32', ele, force_cpu=True, out=temp_out)
new_expand_times_tensor.append(temp_out)
return new_expand_times_tensor
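# Worked example: with expand_times = [positive_2, 3], where positive_2 is a
# 1-element int32 Variable, get_attr_expand_times returns [-1, 3] for
# compile-time shape inference, and get_new_expand_times_tensor returns
# [positive_2, fill_constant(3)] as the op's expand_times_tensor input.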
if in_dygraph_mode():
inputs = {'X': x}
attrs = {'expand_times': expand_times}
else:
if isinstance(expand_times, Variable):
expand_times.stop_gradient = True
inputs['ExpandTimes'] = expand_times
elif isinstance(expand_times, (list, tuple)):
attrs['expand_times'] = get_attr_expand_times(expand_times)
if contain_var(expand_times):
inputs['expand_times_tensor'] = get_new_expand_times_tensor(
expand_times)
def contain_tensor(expand_times):
for ele in expand_times:
if isinstance(ele, Variable):
return True
return False
if contain_tensor(expand_times):
new_expand_times = []
for ele in expand_times:
if isinstance(ele, Variable):
ele.stop_gradient = True
new_expand_times.append(ele)
else:
assert (isinstance(ele, int))
temp_out = helper.create_variable_for_type_inference(
"int32")
fill_constant(
[1], 'int32', ele, force_cpu=True, out=temp_out)
new_expand_times.append(temp_out)
inputs = {'X': x, 'expand_times_tensor': new_expand_times}
attrs = {}
else:
inputs = {'X': x}
attrs = {'expand_times': expand_times}
dtype = helper.input_dtype(input_param_name='x')
out = helper.create_variable_for_type_inference(dtype)
helper.append_op(
type='expand', inputs=inputs, outputs={'Out': out}, attrs=attrs)
return out
......
......@@ -17,16 +17,24 @@ from __future__ import print_function
import unittest
import numpy as np
from op_test import OpTest
import paddle.fluid as fluid
# Situation 1: expand_times is a list (without tensors)
class TestExpandOpRank1(OpTest):
def setUp(self):
self.op_type = "expand"
self.inputs = {'X': np.random.random(12).astype("float32")}
self.attrs = {'expand_times': [2]}
output = np.tile(self.inputs['X'], 2)
self.init_data()
self.inputs = {'X': np.random.random(self.ori_shape).astype("float32")}
self.attrs = {'expand_times': self.expand_times}
output = np.tile(self.inputs['X'], self.expand_times)
self.outputs = {'Out': output}
def init_data(self):
self.ori_shape = [12]
self.expand_times = [2]
def test_check_output(self):
self.check_output()
......@@ -34,51 +42,59 @@ class TestExpandOpRank1(OpTest):
self.check_grad(['X'], 'Out')
class TestExpandOpRank1_tensor_attr(OpTest):
def setUp(self):
self.op_type = "expand"
self.inputs = {
'X': np.random.random(12).astype("float32"),
'expand_times_tensor': [('x1', np.ones((1)).astype('int32') * 2)]
}
self.attrs = {}
output = np.tile(self.inputs['X'], 2)
self.outputs = {'Out': output}
class TestExpandOpRank2_Corner(TestExpandOpRank1):
def init_data(self):
self.ori_shape = [12]
self.expand_times = [2]
def test_check_output(self):
self.check_output()
def test_check_grad(self):
self.check_grad(['X'], 'Out', no_grad_set=set('x1'))
class TestExpandOpRank2(TestExpandOpRank1):
def init_data(self):
self.ori_shape = [12, 14]
self.expand_times = [2, 3]
class TestExpandOpRank2_Corner(OpTest):
def setUp(self):
self.op_type = "expand"
self.inputs = {'X': np.random.random((12, 14)).astype("float32")}
self.attrs = {'expand_times': [1, 1]}
output = np.tile(self.inputs['X'], (1, 1))
self.outputs = {'Out': output}
class TestExpandOpRank3_Corner(TestExpandOpRank1):
def init_data(self):
self.ori_shape = (2, 4, 5)
self.expand_times = (1, 1, 1)
def test_check_output(self):
self.check_output()
def test_check_grad(self):
self.check_grad(['X'], 'Out')
class TestExpandOpRank3(TestExpandOpRank1):
def init_data(self):
self.ori_shape = (2, 4, 5)
self.expand_times = (2, 1, 4)
class TestExpandOpRank4(TestExpandOpRank1):
def init_data(self):
self.ori_shape = (2, 4, 5, 7)
self.expand_times = (3, 2, 1, 2)
class TestExpandOpRank2_Corner_tensor_attr(OpTest):
# Situation 2: expand_times is a list (with tensors)
class TestExpandOpRank1_tensor_attr(OpTest):
def setUp(self):
self.op_type = "expand"
self.init_data()
expand_times_tensor = []
for index, ele in enumerate(self.expand_times):
expand_times_tensor.append(("x" + str(index), np.ones(
(1)).astype('int32') * ele))
self.inputs = {
'X': np.random.random((12, 14)).astype("float32"),
'expand_times_tensor': [('x1', np.ones((1)).astype('int32')),
('x2', np.ones((1)).astype('int32'))]
'X': np.random.random(self.ori_shape).astype("float32"),
'expand_times_tensor': expand_times_tensor,
}
self.attrs = {}
output = np.tile(self.inputs['X'], (1, 1))
self.attrs = {"expand_times": self.infer_expand_times}
output = np.tile(self.inputs['X'], self.expand_times)
self.outputs = {'Out': output}
def init_data(self):
self.ori_shape = [12]
self.expand_times = [2]
self.infer_expand_times = [-1]
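# infer_expand_times is what gets recorded in Attr(expand_times): -1 marks a
# dimension whose expand time is only known at runtime (it comes from
# expand_times_tensor), while concrete values remain available for
# compile-time shape inference.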
def test_check_output(self):
self.check_output()
......@@ -86,47 +102,37 @@ class TestExpandOpRank2_Corner_tensor_attr(OpTest):
self.check_grad(['X'], 'Out')
class TestExpandOpRank2(OpTest):
def setUp(self):
self.op_type = "expand"
self.inputs = {'X': np.random.random((12, 14)).astype("float32")}
self.attrs = {'expand_times': [2, 3]}
output = np.tile(self.inputs['X'], (2, 3))
self.outputs = {'Out': output}
class TestExpandOpRank2_Corner_tensor_attr(TestExpandOpRank1_tensor_attr):
def init_data(self):
self.ori_shape = [12, 14]
self.expand_times = [1, 1]
self.infer_expand_times = [1, -1]
def test_check_output(self):
self.check_output()
def test_check_grad(self):
self.check_grad(['X'], 'Out')
class TestExpandOpRank2_attr_tensor(TestExpandOpRank1_tensor_attr):
def init_data(self):
self.ori_shape = [12, 14]
self.expand_times = [2, 3]
self.infer_expand_times = [-1, 3]
class TestExpandOpRank2_attr_tensor(OpTest):
# Situation 3: expand_times is a tensor
class TestExpandOpRank1_tensor(OpTest):
def setUp(self):
self.op_type = "expand"
self.init_data()
self.inputs = {
'X': np.random.random((12, 14)).astype("float32"),
'expand_times_tensor': [('x1', np.ones((1)).astype('int32') * 2),
('x2', np.ones((1)).astype('int32') * 3)]
'X': np.random.random(self.ori_shape).astype("float32"),
'ExpandTimes': np.array(self.expand_times).astype("int32"),
}
self.attrs = {}
output = np.tile(self.inputs['X'], (2, 3))
output = np.tile(self.inputs['X'], self.expand_times)
self.outputs = {'Out': output}
def test_check_output(self):
self.check_output()
def test_check_grad(self):
self.check_grad(['X'], 'Out')
class TestExpandOpRank3_Corner(OpTest):
def setUp(self):
self.op_type = "expand"
self.inputs = {'X': np.random.random((2, 4, 5)).astype("float32")}
self.attrs = {'expand_times': [1, 1, 1]}
output = np.tile(self.inputs['X'], (1, 1, 1))
self.outputs = {'Out': output}
def init_data(self):
self.ori_shape = [12]
self.expand_times = [2]
def test_check_output(self):
self.check_output()
......@@ -135,36 +141,13 @@ class TestExpandOpRank3_Corner(OpTest):
self.check_grad(['X'], 'Out')
class TestExpandOpRank3(OpTest):
def setUp(self):
self.op_type = "expand"
self.inputs = {'X': np.random.random((2, 4, 5)).astype("float32")}
self.attrs = {'expand_times': [2, 1, 4]}
output = np.tile(self.inputs['X'], (2, 1, 4))
self.outputs = {'Out': output}
def test_check_output(self):
self.check_output()
def test_check_grad(self):
self.check_grad(['X'], 'Out')
class TestExpandOpRank4(OpTest):
def setUp(self):
self.op_type = "expand"
self.inputs = {'X': np.random.random((2, 4, 5, 7)).astype("float32")}
self.attrs = {'expand_times': [3, 2, 1, 2]}
output = np.tile(self.inputs['X'], (3, 2, 1, 2))
self.outputs = {'Out': output}
def test_check_output(self):
self.check_output()
def test_check_grad(self):
self.check_grad(['X'], 'Out')
class TestExpandOpRank2_tensor(TestExpandOpRank1_tensor):
def init_data(self):
self.ori_shape = [12, 14]
self.expand_times = [2, 3]
# Situation 4: input x is Integer
class TestExpandOpInteger(OpTest):
def setUp(self):
self.op_type = "expand"
......@@ -180,6 +163,7 @@ class TestExpandOpInteger(OpTest):
self.check_output()
# Situation 5: input x is Bool
class TestExpandOpBoolean(OpTest):
def setUp(self):
self.op_type = "expand"
......@@ -192,5 +176,33 @@ class TestExpandOpBoolean(OpTest):
self.check_output()
# Test python API
class TestExpandAPI(OpTest):
def test_api(self):
input = np.random.random([12, 14]).astype("float32")
x = fluid.layers.data(
name='x', shape=[12, 14], append_batch_size=False, dtype="float32")
positive_2 = fluid.layers.fill_constant([1], "int32", 2)
expand_times = fluid.layers.data(
    name="expand_times", shape=[2], append_batch_size=False, dtype="int32")
out_1 = fluid.layers.expand(x, expand_times=[2, 3])
out_2 = fluid.layers.expand(x, expand_times=[positive_2, 3])
out_3 = fluid.layers.expand(x, expand_times=expand_times)
exe = fluid.Executor(place=fluid.CPUPlace())
res_1, res_2, res_3 = exe.run(fluid.default_main_program(),
feed={
"x": input,
"expand_times":
np.array([1, 3]).astype("int32")
},
fetch_list=[out_1, out_2, out_3])
assert np.array_equal(res_1, np.tile(input, (2, 3)))
assert np.array_equal(res_2, np.tile(input, (2, 3)))
assert np.array_equal(res_3, np.tile(input, (1, 3)))
if __name__ == "__main__":
unittest.main()