提交 d37ed6cb 编写于 作者: Y Yibing Liu

polish code in reshape_op

上级 5ac8a0be
......@@ -36,10 +36,13 @@ class ReshapeOp : public framework::OperatorWithKernel {
auto x_dims = ctx->GetInputDim("X");
std::vector<size_t> neg_dims_idx;
// set some dimension to -1 if it is unknown
const int unknown_size = -1;
for (size_t i = 0; i < shape.size(); ++i) {
PADDLE_ENFORCE(shape[i] > 0 || shape[i] == -1,
"Each dimension of Attr(shape) must be positive or -1.");
if (shape[i] == -1) {
PADDLE_ENFORCE(shape[i] > 0 || shape[i] == unknown_size,
"Each dimension of Attr(shape) must be positive or %d.",
unknown_size);
if (shape[i] == unknown_size) {
neg_dims_idx.push_back(i);
PADDLE_ENFORCE(neg_dims_idx.size() <= 1,
"Only one dimension of Attr(shape) can be unknown.");
......@@ -53,8 +56,7 @@ class ReshapeOp : public framework::OperatorWithKernel {
// dim infer
shape[neg_dims_idx[0]] = in_size / (-capacity);
// recalculate capacity
capacity = std::accumulate(shape.begin(), shape.end(), 1,
std::multiplies<int>());
capacity = shape[neg_dims_idx[0]] * (-capacity);
}
// capacity check
PADDLE_ENFORCE(capacity == in_size,
......@@ -98,9 +100,9 @@ the tensor X into a 2-D tensor:
[[1, 2, 3, 4]]
One dimension in the target shape can be set -1, and the real dimension
will be infered from the original shape of Input(X) and other
dimensions in the target shape.
One dimension in the target shape can be set -1, representing that its
size is unknown. In this case, the real dimension will be infered from
the original shape of Input(X) and other dimensions in the target shape.
)DOC");
}
};
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册