提交 c078ed46 编写于 作者: G guosheng

Enhance reshape_op by adding Input(Shape)

上级 b7e83d24
...@@ -17,88 +17,18 @@ limitations under the License. */ ...@@ -17,88 +17,18 @@ limitations under the License. */
namespace paddle { namespace paddle {
namespace operators { namespace operators {
// ReshapeOp: reshapes Input(X) into the shape given by Attr(shape),
// keeping the data unchanged. In Attr(shape), -1 marks a single dimension
// to be inferred and 0 copies the corresponding dimension of Input(X).
class ReshapeOp : public framework::OperatorWithKernel {
 public:
  ReshapeOp(const std::string &type, const framework::VariableNameMap &inputs,
            const framework::VariableNameMap &outputs,
            const framework::AttributeMap &attrs)
      : OperatorWithKernel(type, inputs, outputs, attrs) {}

  // Infers the shape of Output(Out) from Attr(shape) and Input(X)'s dims.
  void InferShape(framework::InferShapeContext *ctx) const override {
    PADDLE_ENFORCE(ctx->HasInput("X"),
                   "Input(X) of ReshapeOp should not be null.");
    PADDLE_ENFORCE(ctx->HasOutput("Out"),
                   "Output(Out) of ReshapeOp should not be null.");

    const std::vector<int> &shape = ctx->Attrs().Get<std::vector<int>>("shape");
    PADDLE_ENFORCE(!shape.empty(),
                   "The shape information must be set by Attr(shape).");

    auto x_dims = ctx->GetInputDim("X");
    auto out_dims = ValidateShape(shape, x_dims);
    ctx->SetOutputDim("Out", out_dims);
    // NOTE: Reshape op cannot reshape an input sequence batch into an
    // output sequence batch that has a different number of time steps. Here
    // output always shares the LoD information with input. But if
    // Attr(shape) contains 0 or -1, the actual output shape can only be
    // determined during runtime. The check for whether it is a valid
    // output sequence batch is performed in runtime.
    ctx->ShareLoD("X", /*->*/ "Out");
  }

 private:
  // Translates Attr(shape) into a concrete output DDim given the input
  // dims: a single -1 entry is inferred from the remaining capacity and a
  // 0 entry copies the corresponding input dimension. Enforces that the
  // resulting element count matches Input(X)'s element count.
  framework::DDim ValidateShape(const std::vector<int> shape,
                                const framework::DDim &in_dims) const {
    const int64_t in_size = framework::product(in_dims);
    // Only one dimension can be set to -1, whose size will be
    // automatically inferred.
    const int64_t unk_dim_val = -1;
    const int64_t copy_dim_val = 0;

    std::vector<int64_t> output_shape(shape.size(), 0);
    int64_t capacity = 1;
    int unk_dim_idx = -1;
    for (size_t i = 0; i < shape.size(); ++i) {
      if (shape[i] == unk_dim_val) {
        PADDLE_ENFORCE(
            unk_dim_idx == -1,
            "Only one input dimension of Attr(shape) can be unknown.");
        unk_dim_idx = i;
      } else if (shape[i] == copy_dim_val) {
        PADDLE_ENFORCE(
            static_cast<int>(i) < in_dims.size(),
            "The index of dimension to copy from input shape must be less "
            "than the size of input shape.");
      } else {
        PADDLE_ENFORCE(
            shape[i] > 0,
            "Each input dimension of Attr(shape) must not be negtive except "
            "one unknown dimension.");
      }

      // A -1 entry multiplies capacity by -1, so capacity carries the sign
      // used when recovering the unknown dimension below.
      capacity *= (shape[i] ? shape[i] : in_dims[i]);
      output_shape[i] =
          (shape[i] ? static_cast<int64_t>(shape[i]) : in_dims[i]);
    }

    if (unk_dim_idx != -1) {
      // capacity is negative here (it includes the -1 factor); the division
      // must be exact for the shape to be valid.
      output_shape[unk_dim_idx] = -in_size / capacity;
      PADDLE_ENFORCE_EQ(output_shape[unk_dim_idx] * capacity, -in_size,
                        "Invalid shape is given.");
    } else {
      PADDLE_ENFORCE_EQ(capacity, in_size, "Invalid shape is given.");
    }
    return framework::make_ddim(output_shape);
  }
};
class ReshapeOpMaker : public framework::OpProtoAndCheckerMaker { class ReshapeOpMaker : public framework::OpProtoAndCheckerMaker {
public: public:
ReshapeOpMaker(OpProto *proto, OpAttrChecker *op_checker) ReshapeOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) { : OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "The input tensor of reshape operator."); AddInput("X", "(Tensor). The input tensor of reshape operator.");
AddOutput("Out", "The output tensor of reshape operator."); AddInput("Shape",
"(Tensor<int32>, optional). If provided, reshape according to "
"this given shape. That is to say it has a higher priority than "
"the shape attribute, while the shape attribute still should be "
"set correctly to gurantee shape inference in compile time.")
.AsDispensable();
AddOutput("Out", "(Tensor). The output tensor of reshape operator.");
AddAttr<std::vector<int>>( AddAttr<std::vector<int>>(
"shape", "(std::vector<int>) Target shape of reshape operator."); "shape", "(std::vector<int>) Target shape of reshape operator.");
AddAttr<bool>("inplace", AddAttr<bool>("inplace",
...@@ -110,8 +40,8 @@ class ReshapeOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -110,8 +40,8 @@ class ReshapeOpMaker : public framework::OpProtoAndCheckerMaker {
AddComment(R"DOC( AddComment(R"DOC(
Reshape Operator. Reshape Operator.
Reshape Input(X) into the shape specified by Attr(shape). The data in Input(X) Reshape Input(X) into the shape specified by Attr(shape) or Input(Shape). The
are unchanged. data in Input(X) are unchanged.
Examples: Examples:
...@@ -141,6 +71,9 @@ Input(X) and remaining dimensions. ...@@ -141,6 +71,9 @@ Input(X) and remaining dimensions.
dimension value will be copied from Input(X) at runtime. Note that the index of dimension value will be copied from Input(X) at runtime. Note that the index of
0 can not exceed Rank(X). For example, Input(X) is a 3-D tensor with shape 0 can not exceed Rank(X). For example, Input(X) is a 3-D tensor with shape
[2, 3, 4], Attr(shape) = [2, 3, 2, 0] is an invalid input. [2, 3, 4], Attr(shape) = [2, 3, 2, 0] is an invalid input.
1. Input(Shape) has a higher priority than Attr(shape) if it is provided, while
Attr(shape) still should be set correctly to gurantee shape inference in
compile-time.
)DOC"); )DOC");
} }
...@@ -160,6 +93,14 @@ class ReshapeGradOp : public framework::OperatorWithKernel { ...@@ -160,6 +93,14 @@ class ReshapeGradOp : public framework::OperatorWithKernel {
"Input(Out@GRAD) shouldn't be null."); "Input(Out@GRAD) shouldn't be null.");
ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("X")); ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("X"));
} }
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override {
return framework::OpKernelType(
framework::ToDataType(ctx.Input<framework::LoDTensor>("X")->type()),
ctx.device_context());
}
}; };
} // namespace operators } // namespace operators
......
...@@ -20,15 +20,115 @@ limitations under the License. */ ...@@ -20,15 +20,115 @@ limitations under the License. */
namespace paddle { namespace paddle {
namespace operators { namespace operators {
// ReshapeOp: reshapes Input(X) into the shape given by Attr(shape) or, if
// provided, by the optional Input(Shape) tensor (which takes priority at
// runtime). In the target shape, -1 marks a single dimension to be
// inferred and 0 copies the corresponding dimension of Input(X).
class ReshapeOp : public framework::OperatorWithKernel {
 public:
  ReshapeOp(const std::string &type, const framework::VariableNameMap &inputs,
            const framework::VariableNameMap &outputs,
            const framework::AttributeMap &attrs)
      : OperatorWithKernel(type, inputs, outputs, attrs) {}

  // Infers the shape of Output(Out). When Input(Shape) is present at
  // runtime, shape resolution is deferred to ReshapeKernel.
  void InferShape(framework::InferShapeContext *ctx) const override {
    PADDLE_ENFORCE(ctx->HasInput("X"),
                   "Input(X) of ReshapeOp should not be null.");
    PADDLE_ENFORCE(ctx->HasOutput("Out"),
                   "Output(Out) of ReshapeOp should not be null.");

    const std::vector<int> &shape = ctx->Attrs().Get<std::vector<int>>("shape");
    PADDLE_ENFORCE(!shape.empty(),
                   "The shape information must be set by Attr(shape).");

    if (ctx->HasInput("Shape") && ctx->IsRuntime()) {
      // If true, set the shape of Output(Out) according to Input(Shape) in
      // ReshapeKernel with ExecutionContext. Also check LoD in ReshapeKernel.
      ctx->ShareLoD("X", /*->*/ "Out");
      return;
    }

    auto x_dims = ctx->GetInputDim("X");
    auto out_dims = ValidateShape(shape, x_dims);
    ctx->SetOutputDim("Out", out_dims);

    if (x_dims[0] == out_dims[0]) {
      // Only pass LoD when the first dimension of output and Input(X)
      // are the same.
      ctx->ShareLoD("X", /*->*/ "Out");
    }
  }

  // Translates a target shape into a concrete output DDim given the input
  // dims: a single -1 entry is inferred from the remaining capacity and a
  // 0 entry copies the corresponding input dimension. Static so that
  // ReshapeKernel can reuse it for Input(Shape) at runtime.
  static framework::DDim ValidateShape(const std::vector<int> shape,
                                       const framework::DDim &in_dims) {
    const int64_t in_size = framework::product(in_dims);
    // Only one dimension can be set to -1, whose size will be
    // automatically inferred.
    const int64_t unk_dim_val = -1;
    const int64_t copy_dim_val = 0;

    std::vector<int64_t> output_shape(shape.size(), 0);
    int64_t capacity = 1;
    int unk_dim_idx = -1;
    for (size_t i = 0; i < shape.size(); ++i) {
      if (shape[i] == unk_dim_val) {
        PADDLE_ENFORCE(
            unk_dim_idx == -1,
            "Only one input dimension of Attr(shape) can be unknown.");
        unk_dim_idx = i;
      } else if (shape[i] == copy_dim_val) {
        PADDLE_ENFORCE(
            static_cast<int>(i) < in_dims.size(),
            "The index of dimension to copy from input shape must be less "
            "than the size of input shape.");
      } else {
        PADDLE_ENFORCE(
            shape[i] > 0,
            "Each input dimension of Attr(shape) must not be negtive except "
            "one unknown dimension.");
      }

      // A -1 entry multiplies capacity by -1, so capacity carries the sign
      // used when recovering the unknown dimension below.
      capacity *= (shape[i] ? shape[i] : in_dims[i]);
      output_shape[i] =
          (shape[i] ? static_cast<int64_t>(shape[i]) : in_dims[i]);
    }

    if (unk_dim_idx != -1) {
      // capacity is negative here (it includes the -1 factor); the division
      // must be exact for the shape to be valid.
      output_shape[unk_dim_idx] = -in_size / capacity;
      PADDLE_ENFORCE_EQ(output_shape[unk_dim_idx] * capacity, -in_size,
                        "Invalid shape is given.");
    } else {
      PADDLE_ENFORCE_EQ(capacity, in_size, "Invalid shape is given.");
    }
    return framework::make_ddim(output_shape);
  }

 protected:
  // Picks the kernel by Input(X)'s data type so the optional int32
  // Input(Shape) does not influence kernel selection.
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext &ctx) const override {
    return framework::OpKernelType(
        framework::ToDataType(ctx.Input<framework::LoDTensor>("X")->type()),
        ctx.device_context());
  }
};
template <typename DeviceContext, typename T> template <typename DeviceContext, typename T>
class ReshapeKernel : public framework::OpKernel<T> { class ReshapeKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext &ctx) const { void Compute(const framework::ExecutionContext &ctx) const {
auto *out = ctx.Output<framework::LoDTensor>("Out"); auto *out = ctx.Output<framework::LoDTensor>("Out");
auto *in = ctx.Input<framework::LoDTensor>("X"); auto *in = ctx.Input<framework::LoDTensor>("X");
auto *shape_tensor = ctx.Input<framework::LoDTensor>("Shape");
auto out_dims = out->dims(); framework::DDim out_dims = out->dims();
if (shape_tensor) {
auto *shape_data = shape_tensor->data<int>();
if (platform::is_gpu_place(ctx.GetPlace())) {
framework::Tensor cpu_shape_tensor;
TensorCopy(*shape_tensor, platform::CPUPlace(), ctx.device_context(),
&cpu_shape_tensor);
shape_data = cpu_shape_tensor.data<int>();
}
auto shape =
std::vector<int>(shape_data, shape_data + shape_tensor->numel());
out_dims = ReshapeOp::ValidateShape(shape, in->dims());
}
if (!in->lod().empty()) { if (!in->lod().empty()) {
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
out_dims[0], in->dims()[0], out_dims[0], in->dims()[0],
...@@ -39,9 +139,11 @@ class ReshapeKernel : public framework::OpKernel<T> { ...@@ -39,9 +139,11 @@ class ReshapeKernel : public framework::OpKernel<T> {
} }
bool inplace = ctx.Attr<bool>("inplace"); bool inplace = ctx.Attr<bool>("inplace");
out->Resize(out_dims);
if (!inplace) { if (!inplace) {
out->mutable_data<T>(ctx.GetPlace()); out->mutable_data<T>(ctx.GetPlace());
framework::TensorCopy(*in, ctx.GetPlace(), ctx.device_context(), out); framework::TensorCopy(*in, ctx.GetPlace(), ctx.device_context(), out);
// TensorCopy will resize to in_dims.
out->Resize(out_dims); out->Resize(out_dims);
} else { } else {
out->ShareDataWith(*in); out->ShareDataWith(*in);
......
...@@ -3320,42 +3320,54 @@ def autoincreased_step_counter(counter_name=None, begin=1, step=1): ...@@ -3320,42 +3320,54 @@ def autoincreased_step_counter(counter_name=None, begin=1, step=1):
return counter return counter
def reshape(x, shape, act=None, inplace=True, name=None): def reshape(x, shape, actual_shape=None, act=None, inplace=True, name=None):
""" """
Gives a new shape to the input Tensor without changing its data. Gives a new shape to the input Tensor without changing its data.
This layer takes a tensor and the attribute shape which specifies the The target shape can be given by :attr:`shape` or :attr:`actual_shape`.
new shape as its inputs. The shape attribute must be given. It cannot be :attr:`shape` is a list of integer while :attr:`actual_shape` is a tensor
empty. One and only one dimension of shape can be -1. More than one variable. :attr:`actual_shape` has a higher priority than :attr:`shape`
dimension of shape can be 0. if it is provided, while :attr:`shape` still should be set correctly to
gurantee shape inference in compile-time.
-1 means the value of this dimension is inferred from the total element Some tricks exist when specifying the target shape.
number of x and remaining dimensions.
0 means the actual dimension value is going to be copied from the 1. -1 means the value of this dimension is inferred from the total element
corresponding dimension of x. number of x and remaining dimensions. Thus one and only one dimension can
be set -1.
1. 0 means the actual dimension value is going to be copied from the
corresponding dimension of x. The indice of 0s in shape can not exceed
Rank(X).
Here are some examples to explain it.
1. Given a 3-D tensor x with a shape [2, 4, 6], and the target shape 1. Given a 3-D tensor x with a shape [2, 4, 6], and the target shape
specified by Attr(shape) is [6, 8], the reshape operator will transform x is [6, 8], the reshape operator will transform x into a 2-D tensor with
into a 2-D tensor with shape [6, 8] and leaving x's data unchanged. shape [6, 8] and leaving x's data unchanged.
1. Given a 3-D tensor x with a shape [2, 4, 6], and the target shape 1. Given a 3-D tensor x with a shape [2, 4, 6], and the target shape
specified by Attr(shape) is [2, 3, -1, 2], the reshape operator will specified is [2, 3, -1, 2], the reshape operator will transform x into a
transform x into a 4-D tensor with shape [2, 3, 4, 2] and leaving x's data 4-D tensor with shape [2, 3, 4, 2] and leaving x's data unchanged. In this
unchanged. In this case, one and only dimension of Attr(shape) can be set case, one dimension of the target shape is set to -1, the value of this
to -1, the value of this dimension is inferred from the total element number dimension is inferred from the total element number of x and remaining
of x and remaining dimensions. dimensions.
1. Given a 3-D tensor x with a shape [2, 4, 6], and the target shape 1. Given a 3-D tensor x with a shape [2, 4, 6], and the target shape
specified by Attr(shape) is [-1, 0, 3, 2], the reshape operator will is [-1, 0, 3, 2], the reshape operator will transform x into a 4-D tensor
transform x into a 4-D tensor with shape [2, 4, 3, 2] and leaving x's data with shape [2, 4, 3, 2] and leaving x's data unchanged. In this case,
unchanged. In this case, besides -1, 0 means the actual dimension value is besides -1, 0 means the actual dimension value is going to be copied from
going to be copied from the corresponding dimension of x during runtime. the corresponding dimension of x.
Args: Args:
input(variable): The input tensor. input(variable): The input tensor.
shape(list): The new shape. At most one dimension of the new shape can shape(list): The new shape. At most one dimension of the new shape can
be -1. be -1.
actual_shape(variable): An optional input. If provided, reshape
according to this given shape rather than
:attr:`shape` specifying shape. That is to
say :attr:`actual_shape` has a higher priority
than :attr:`shape`.
act (str): The non-linear activation to be applied to output variable. act (str): The non-linear activation to be applied to output variable.
inplace(bool): If this flag is set true, a new output tensor is created inplace(bool): If this flag is set true, a new output tensor is created
whose data is copied from input x, otherwise the output whose data is copied from input x, otherwise the output
...@@ -3366,12 +3378,9 @@ def reshape(x, shape, act=None, inplace=True, name=None): ...@@ -3366,12 +3378,9 @@ def reshape(x, shape, act=None, inplace=True, name=None):
Examples: Examples:
.. code-block:: python .. code-block:: python
data = fluid.layers.data( data = fluid.layers.data(
name='data', shape=[2, 4, 6], dtype='float32' name='data', shape=[2, 4, 6], dtype='float32')
)
reshaped = fluid.layers.reshape( reshaped = fluid.layers.reshape(
x=data, shape=[-1, 0, 3, 2], act='tanh', inplace=True x=data, shape=[-1, 0, 3, 2], act='tanh', inplace=True)
)
""" """
if not (isinstance(shape, list) or isinstance(shape, tuple)): if not (isinstance(shape, list) or isinstance(shape, tuple)):
...@@ -3396,7 +3405,9 @@ def reshape(x, shape, act=None, inplace=True, name=None): ...@@ -3396,7 +3405,9 @@ def reshape(x, shape, act=None, inplace=True, name=None):
reshaped = helper.create_tmp_variable(dtype=x.dtype) reshaped = helper.create_tmp_variable(dtype=x.dtype)
helper.append_op( helper.append_op(
type="reshape", type="reshape",
inputs={"X": x}, inputs={"X": x,
"Shape": actual_shape}
if isinstance(actual_shape, Variable) else {"X": x},
attrs={"shape": shape, attrs={"shape": shape,
"inplace": inplace}, "inplace": inplace},
outputs={"Out": reshaped}) outputs={"Out": reshaped})
......
...@@ -122,5 +122,27 @@ class TestReshapeOpDimInferInplace2(OpTest): ...@@ -122,5 +122,27 @@ class TestReshapeOpDimInferInplace2(OpTest):
self.check_grad(["X"], "Out") self.check_grad(["X"], "Out")
class TestReshapeOpWithInputShape(OpTest):
    # Exercises the reshape op when the optional Input(Shape) tensor is
    # provided: Attr(shape) is (0, -1, 5) but the Shape tensor holds
    # (2, 3, 5), which takes priority, so the output must match it.
    def setUp(self):
        ori_shape = (6, 5)  # shape of the input tensor X
        new_shape = (0, -1, 5)  # Attr(shape); overridden by Input(Shape)
        actual_shape = (2, 3, 5)  # value fed through Input(Shape)

        self.op_type = "reshape"
        self.inputs = {
            "X": np.random.random(ori_shape).astype("float32"),
            "Shape": np.array(
                actual_shape, dtype="int32")
        }
        self.attrs = {"shape": new_shape}
        # Expected output follows Input(Shape), not Attr(shape).
        self.outputs = {"Out": self.inputs["X"].reshape(actual_shape)}

    def test_check_output(self):
        self.check_output()

    # NOTE(review): gradient check left disabled — confirm whether reshape
    # with Input(Shape) supports check_grad before re-enabling.
    # def test_check_grad(self):
    #     self.check_grad(["X"], "Out")
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册