未验证 提交 8062bd51 编写于 作者: H Hongyu Liu 提交者: GitHub

Reshape support tensor attribute (#17781)

* add reshape support tensor; test=develop

* fix reshape bug; test=develop

* change reshape attribute default value; test=develop

* fix reshape input name; test=develop

* fix reshape unitest; test=develop

* check dim tensor shape; test=develop
上级 972c54cd
...@@ -19,6 +19,29 @@ limitations under the License. */ ...@@ -19,6 +19,29 @@ limitations under the License. */
namespace paddle { namespace paddle {
namespace operators { namespace operators {
using Tensor = framework::Tensor;
inline std::vector<int> get_new_shape(
const std::vector<const Tensor *> &list_new_shape_tensor) {
// get tensor from
std::vector<int> vec_new_shape;
for (size_t i = 0; i < list_new_shape_tensor.size(); ++i) {
auto tensor = list_new_shape_tensor[i];
PADDLE_ENFORCE_EQ(tensor->dims(), framework::make_ddim({1}),
"shape of dim tensor should be [1]");
if (platform::is_gpu_place(tensor->place())) {
framework::Tensor temp;
TensorCopySync(*tensor, platform::CPUPlace(), &temp);
vec_new_shape.push_back(static_cast<int32_t>(*temp.data<int32_t>()));
} else {
vec_new_shape.push_back(static_cast<int32_t>(*tensor->data<int32_t>()));
}
}
return vec_new_shape;
}
class ReshapeOp : public framework::OperatorWithKernel { class ReshapeOp : public framework::OperatorWithKernel {
public: public:
ReshapeOp(const std::string &type, const framework::VariableNameMap &inputs, ReshapeOp(const std::string &type, const framework::VariableNameMap &inputs,
...@@ -32,17 +55,24 @@ class ReshapeOp : public framework::OperatorWithKernel { ...@@ -32,17 +55,24 @@ class ReshapeOp : public framework::OperatorWithKernel {
PADDLE_ENFORCE(ctx->HasOutput("Out"), PADDLE_ENFORCE(ctx->HasOutput("Out"),
"Output(Out) of ReshapeOp should not be null."); "Output(Out) of ReshapeOp should not be null.");
const std::vector<int> &shape = ctx->Attrs().Get<std::vector<int>>("shape"); if (ctx->HasInputs("ShapeTensor")) {
PADDLE_ENFORCE(!shape.empty(), // top prority shape
"The shape information must be set by Attr(shape)."); auto inputs_name = ctx->Inputs("ShapeTensor");
PADDLE_ENFORCE(inputs_name.size() > 0, "shape tensor size can't be zero");
auto out_dims = std::vector<int>(inputs_name.size(), -1);
ctx->SetOutputDim("Out", framework::make_ddim(out_dims));
return;
}
if (ctx->HasInput("Shape") && ctx->IsRuntime()) { if (ctx->HasInput("Shape") && ctx->IsRuntime()) {
// If true, set the shape of Output(Out) according to Input(Shape) in // If true, set the shape of Output(Out) according to Input(Shape) in
// ReshapeKernel with ExecutionContext. Also check LoD in ReshapeKernel. // ReshapeKernel with ExecutionContext. Also check LoD in ReshapeKernel.
ctx->ShareLoD("X", /*->*/ "Out"); ctx->ShareLoD("X", /*->*/ "Out");
return; return;
} }
const std::vector<int> &shape = ctx->Attrs().Get<std::vector<int>>("shape");
PADDLE_ENFORCE(!shape.empty(),
"The shape information must be set by Attr(shape).");
auto x_dims = ctx->GetInputDim("X"); auto x_dims = ctx->GetInputDim("X");
auto out_dims = ValidateShape(shape, x_dims); auto out_dims = ValidateShape(shape, x_dims);
ctx->SetOutputDim("Out", out_dims); ctx->SetOutputDim("Out", out_dims);
...@@ -114,6 +144,16 @@ class ReshapeOp : public framework::OperatorWithKernel { ...@@ -114,6 +144,16 @@ class ReshapeOp : public framework::OperatorWithKernel {
return framework::OpKernelType(ctx.Input<framework::LoDTensor>("X")->type(), return framework::OpKernelType(ctx.Input<framework::LoDTensor>("X")->type(),
ctx.device_context()); ctx.device_context());
} }
framework::OpKernelType GetKernelTypeForVar(
const std::string &var_name, const Tensor &tensor,
const framework::OpKernelType &expected_kernel_type) const override {
if (var_name == "ShapeTensor") {
return expected_kernel_type;
}
return framework::OpKernelType(expected_kernel_type.data_type_,
tensor.place(), tensor.layout());
}
}; };
class ReshapeOpMaker : public framework::OpProtoAndCheckerMaker { class ReshapeOpMaker : public framework::OpProtoAndCheckerMaker {
...@@ -126,9 +166,18 @@ class ReshapeOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -126,9 +166,18 @@ class ReshapeOpMaker : public framework::OpProtoAndCheckerMaker {
"the shape attribute, while the shape attribute still should be " "the shape attribute, while the shape attribute still should be "
"set correctly to gurantee shape inference in compile time.") "set correctly to gurantee shape inference in compile time.")
.AsDispensable(); .AsDispensable();
AddInput(
"ShapeTensor",
"(vector<Tensor<int32>>, optional). If provided, reshape will use this"
"The shape of the tensor in vector MUST BE [1]"
"it has the highest priority compare with Input(Shape) and "
"attr(shape).")
.AsDuplicable()
.AsDispensable();
AddOutput("Out", "(Tensor). The output tensor of reshape operator."); AddOutput("Out", "(Tensor). The output tensor of reshape operator.");
AddAttr<std::vector<int>>( AddAttr<std::vector<int>>(
"shape", "(std::vector<int>) Target shape of reshape operator."); "shape", "(std::vector<int>) Target shape of reshape operator.")
.SetDefault({});
AddComment(R"DOC( AddComment(R"DOC(
Reshape Operator. Reshape Operator.
...@@ -202,24 +251,35 @@ class ReshapeKernel { ...@@ -202,24 +251,35 @@ class ReshapeKernel {
auto *out = ctx.Output<framework::LoDTensor>("Out"); auto *out = ctx.Output<framework::LoDTensor>("Out");
auto *in = ctx.Input<framework::LoDTensor>("X"); auto *in = ctx.Input<framework::LoDTensor>("X");
framework::DDim out_dims = out->dims();
auto list_new_shape_tensor =
ctx.MultiInput<framework::Tensor>("ShapeTensor");
if (list_new_shape_tensor.size() > 0) {
// have shape tensor
auto new_shape = get_new_shape(list_new_shape_tensor);
out_dims = ReshapeOp::ValidateShape(new_shape, in->dims());
} else {
auto *shape_tensor = ctx.HasInput("Shape") auto *shape_tensor = ctx.HasInput("Shape")
? ctx.Input<framework::LoDTensor>("Shape") ? ctx.Input<framework::LoDTensor>("Shape")
: nullptr; : nullptr;
framework::DDim out_dims = out->dims();
if (shape_tensor) { if (shape_tensor) {
auto *shape_data = shape_tensor->data<int>(); auto *shape_data = shape_tensor->data<int>();
framework::Tensor cpu_shape_tensor; framework::Tensor cpu_shape_tensor;
if (platform::is_gpu_place(shape_tensor->place())) { if (platform::is_gpu_place(shape_tensor->place())) {
TensorCopySync(*shape_tensor, platform::CPUPlace(), &cpu_shape_tensor); TensorCopySync(*shape_tensor, platform::CPUPlace(),
&cpu_shape_tensor);
shape_data = cpu_shape_tensor.data<int>(); shape_data = cpu_shape_tensor.data<int>();
} }
auto shape = auto shape =
std::vector<int>(shape_data, shape_data + shape_tensor->numel()); std::vector<int>(shape_data, shape_data + shape_tensor->numel());
out_dims = ReshapeOp::ValidateShape(shape, in->dims()); out_dims = ReshapeOp::ValidateShape(shape, in->dims());
} }
}
out->Resize(out_dims);
out->mutable_data(ctx.GetPlace(), in->type()); out->mutable_data(ctx.GetPlace(), in->type());
framework::TensorCopy( framework::TensorCopy(
*in, ctx.GetPlace(), *in, ctx.GetPlace(),
...@@ -288,6 +348,7 @@ class Reshape2GradMaker : public framework::SingleGradOpDescMaker { ...@@ -288,6 +348,7 @@ class Reshape2GradMaker : public framework::SingleGradOpDescMaker {
auto *grad_op = new framework::OpDesc(); auto *grad_op = new framework::OpDesc();
grad_op->SetType("reshape2_grad"); grad_op->SetType("reshape2_grad");
grad_op->SetInput("XShape", Output("XShape")); grad_op->SetInput("XShape", Output("XShape"));
grad_op->SetInput("ShapeTensor", Input("ShapeTensor"));
grad_op->SetInput(framework::GradVarName("Out"), OutputGrad("Out")); grad_op->SetInput(framework::GradVarName("Out"), OutputGrad("Out"));
grad_op->SetOutput(framework::GradVarName("X"), InputGrad("X")); grad_op->SetOutput(framework::GradVarName("X"), InputGrad("X"));
grad_op->SetAttrMap(Attrs()); grad_op->SetAttrMap(Attrs());
...@@ -320,6 +381,16 @@ class Reshape2GradOp : public framework::OperatorWithKernel { ...@@ -320,6 +381,16 @@ class Reshape2GradOp : public framework::OperatorWithKernel {
ctx.Input<framework::LoDTensor>(framework::GradVarName("Out"))->type(), ctx.Input<framework::LoDTensor>(framework::GradVarName("Out"))->type(),
ctx.device_context()); ctx.device_context());
} }
framework::OpKernelType GetKernelTypeForVar(
const std::string &var_name, const Tensor &tensor,
const framework::OpKernelType &expected_kernel_type) const override {
if (var_name == "ShapeTensor") {
return expected_kernel_type;
}
return framework::OpKernelType(expected_kernel_type.data_type_,
tensor.place(), tensor.layout());
}
}; };
class ReshapeOpInplaceInToOut : public framework::InplaceOpInference { class ReshapeOpInplaceInToOut : public framework::InplaceOpInference {
......
...@@ -6703,6 +6703,7 @@ def reshape(x, shape, actual_shape=None, act=None, inplace=False, name=None): ...@@ -6703,6 +6703,7 @@ def reshape(x, shape, actual_shape=None, act=None, inplace=False, name=None):
if not (isinstance(shape, list) or isinstance(shape, tuple)): if not (isinstance(shape, list) or isinstance(shape, tuple)):
raise ValueError("Input shape must be a python list or tuple.") raise ValueError("Input shape must be a python list or tuple.")
inputs = {"X": x} inputs = {"X": x}
if isinstance(actual_shape, Variable): if isinstance(actual_shape, Variable):
inputs["Shape"] = actual_shape inputs["Shape"] = actual_shape
...@@ -6711,7 +6712,12 @@ def reshape(x, shape, actual_shape=None, act=None, inplace=False, name=None): ...@@ -6711,7 +6712,12 @@ def reshape(x, shape, actual_shape=None, act=None, inplace=False, name=None):
# Validate the shape # Validate the shape
unk_dim_idx = -1 unk_dim_idx = -1
contain_var = False
for dim_idx, dim_size in enumerate(shape): for dim_idx, dim_size in enumerate(shape):
if isinstance(dim_size, Variable):
contain_var = True
continue
if dim_size == -1: if dim_size == -1:
assert unk_dim_idx == -1, ( assert unk_dim_idx == -1, (
"Only one dimension in shape can be unknown.") "Only one dimension in shape can be unknown.")
...@@ -6725,13 +6731,35 @@ def reshape(x, shape, actual_shape=None, act=None, inplace=False, name=None): ...@@ -6725,13 +6731,35 @@ def reshape(x, shape, actual_shape=None, act=None, inplace=False, name=None):
"except one unknown dimension.") "except one unknown dimension.")
helper = LayerHelper("reshape2", **locals()) helper = LayerHelper("reshape2", **locals())
if in_dygraph_mode():
inputs = {'X': x}
attrs = {'shape': shape}
else:
if contain_var:
new_shape_tensor = []
for dim in shape:
if isinstance(dim, Variable):
dim.stop_gradient = True
new_shape_tensor.append(dim)
else:
assert (isinstance(dim, int))
temp_out = helper.create_variable_for_type_inference(
'int32')
fill_constant(
[1], 'int32', dim, force_cpu=True, out=temp_out)
new_shape_tensor.append(temp_out)
inputs['ShapeTensor'] = new_shape_tensor
attrs = {}
else:
attrs = {'shape': shape}
out = x if inplace else helper.create_variable_for_type_inference( out = x if inplace else helper.create_variable_for_type_inference(
dtype=x.dtype) dtype=x.dtype)
x_shape = helper.create_variable_for_type_inference(dtype=x.dtype) x_shape = helper.create_variable_for_type_inference(dtype=x.dtype)
helper.append_op( helper.append_op(
type="reshape2", type="reshape2",
inputs=inputs, inputs=inputs,
attrs={"shape": shape}, attrs=attrs,
outputs={"Out": out, outputs={"Out": out,
"XShape": x_shape}) "XShape": x_shape})
......
...@@ -37,6 +37,7 @@ class TestReshapeOp(OpTest): ...@@ -37,6 +37,7 @@ class TestReshapeOp(OpTest):
self.infered_shape = (5, 10) self.infered_shape = (5, 10)
def test_check_output(self): def test_check_output(self):
self.check_output(no_check_set=['XShape']) self.check_output(no_check_set=['XShape'])
def test_check_grad(self): def test_check_grad(self):
...@@ -82,5 +83,51 @@ class TestReshapeOpWithInputShape(OpTest): ...@@ -82,5 +83,51 @@ class TestReshapeOpWithInputShape(OpTest):
self.check_grad(["X"], "Out") self.check_grad(["X"], "Out")
class TestReshapeOp_attr_tensor(OpTest):
def setUp(self):
self.init_data()
self.op_type = "reshape2"
shape_tensor = []
for index, ele in enumerate(self.new_shape):
shape_tensor.append(("x" + str(index), np.ones(
(1)).astype('int32') * ele))
self.inputs = {
"X": np.random.random(self.ori_shape).astype("float32"),
'ShapeTensor': shape_tensor
}
self.attrs = {}
self.outputs = {
"Out": self.inputs["X"].reshape(self.infered_shape),
'XShape': np.random.random(self.ori_shape).astype("float32")
}
def init_data(self):
self.ori_shape = (2, 25)
self.new_shape = (5, 10)
self.infered_shape = (5, 10)
def test_check_output(self):
self.check_output(no_check_set=['XShape'])
def test_check_grad(self):
self.check_grad(["X"], "Out")
class TestReshapeOpDimInfer1_attr_tensor(TestReshapeOp_attr_tensor):
def init_data(self):
self.ori_shape = (5, 10)
self.new_shape = (5, -1, 5)
self.infered_shape = (5, -1, 5)
class TestReshapeOpDimInfer2_attr_tensor(TestReshapeOp_attr_tensor):
def init_data(self):
self.ori_shape = (2, 2, 6)
self.new_shape = (2, 0, 3, -1)
self.infered_shape = (2, 2, 3, -1)
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册