提交 685d1e3b 编写于 作者: Y Yibing Liu

Enable reshape_op to support dimension inference

上级 a91efdde
......@@ -34,21 +34,27 @@ class ReshapeOp : public framework::OperatorWithKernel {
auto shape = ctx->Attrs().Get<std::vector<int>>("shape");
PADDLE_ENFORCE(shape.size() > 0, "Attr(shape) shouldn't be empty.");
auto x_dims = ctx->GetInputDim("X");
// TODO(qiao) change batch_size
for (size_t i = 1; i < shape.size(); ++i) {
PADDLE_ENFORCE(shape[i] > 0,
"Each dimension of Attr(shape) "
"must be positive except the first one.");
}
if (shape[0] < 0) {
shape[0] = x_dims[0];
std::vector<size_t> neg_dims_idx;
for (size_t i = 0; i < shape.size(); ++i) {
PADDLE_ENFORCE(shape[i] > 0 || shape[i] == -1,
"Each dimension of Attr(shape) must be positive or -1.");
if (shape[i] == -1) {
neg_dims_idx.push_back(i);
PADDLE_ENFORCE(neg_dims_idx.size() <= 1,
"Only one dimension of Attr(shape) can be -1.");
}
}
// capacity check
int64_t capacity =
std::accumulate(shape.begin(), shape.end(), 1, std::multiplies<int>());
int64_t in_size = framework::product(x_dims);
PADDLE_ENFORCE_EQ(capacity, in_size,
if (neg_dims_idx.size() == 1) {
shape[neg_dims_idx[0]] = in_size / (-capacity);
PADDLE_ENFORCE(shape[neg_dims_idx[0]] > 0,
"The size of Input(X) mismatches with Attr(shape).");
}
// resize output
std::vector<int64_t> shape_int64(shape.size(), 0);
std::transform(shape.begin(), shape.end(), shape_int64.begin(),
......@@ -88,6 +94,9 @@ the tensor X into a 1-D tensor:
[1, 2, 3, 4]
One dimension in the target shape can be set -1, and the real dimension
will be infered from the original shape of Input(X) and other
dimensions in the target shape.
)DOC");
}
};
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册