提交 685d1e3b 编写于 作者: Y Yibing Liu

Enable reshape_op to support dimension inference

上级 a91efdde
...@@ -34,21 +34,27 @@ class ReshapeOp : public framework::OperatorWithKernel { ...@@ -34,21 +34,27 @@ class ReshapeOp : public framework::OperatorWithKernel {
auto shape = ctx->Attrs().Get<std::vector<int>>("shape"); auto shape = ctx->Attrs().Get<std::vector<int>>("shape");
PADDLE_ENFORCE(shape.size() > 0, "Attr(shape) shouldn't be empty."); PADDLE_ENFORCE(shape.size() > 0, "Attr(shape) shouldn't be empty.");
auto x_dims = ctx->GetInputDim("X"); auto x_dims = ctx->GetInputDim("X");
// TODO(qiao) change batch_size
for (size_t i = 1; i < shape.size(); ++i) { std::vector<size_t> neg_dims_idx;
PADDLE_ENFORCE(shape[i] > 0, for (size_t i = 0; i < shape.size(); ++i) {
"Each dimension of Attr(shape) " PADDLE_ENFORCE(shape[i] > 0 || shape[i] == -1,
"must be positive except the first one."); "Each dimension of Attr(shape) must be positive or -1.");
if (shape[i] == -1) {
neg_dims_idx.push_back(i);
PADDLE_ENFORCE(neg_dims_idx.size() <= 1,
"Only one dimension of Attr(shape) can be -1.");
} }
if (shape[0] < 0) {
shape[0] = x_dims[0];
} }
// capacity check // capacity check
int64_t capacity = int64_t capacity =
std::accumulate(shape.begin(), shape.end(), 1, std::multiplies<int>()); std::accumulate(shape.begin(), shape.end(), 1, std::multiplies<int>());
int64_t in_size = framework::product(x_dims); int64_t in_size = framework::product(x_dims);
PADDLE_ENFORCE_EQ(capacity, in_size, if (neg_dims_idx.size() == 1) {
shape[neg_dims_idx[0]] = in_size / (-capacity);
PADDLE_ENFORCE(shape[neg_dims_idx[0]] > 0,
"The size of Input(X) mismatches with Attr(shape)."); "The size of Input(X) mismatches with Attr(shape).");
}
// resize output // resize output
std::vector<int64_t> shape_int64(shape.size(), 0); std::vector<int64_t> shape_int64(shape.size(), 0);
std::transform(shape.begin(), shape.end(), shape_int64.begin(), std::transform(shape.begin(), shape.end(), shape_int64.begin(),
...@@ -88,6 +94,9 @@ the tensor X into a 1-D tensor: ...@@ -88,6 +94,9 @@ the tensor X into a 1-D tensor:
[1, 2, 3, 4] [1, 2, 3, 4]
One dimension in the target shape can be set -1, and the real dimension
will be infered from the original shape of Input(X) and other
dimensions in the target shape.
)DOC"); )DOC");
} }
}; };
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册