提交 d37ed6cb 编写于 作者: Y Yibing Liu

polish code in reshape_op

上级 5ac8a0be
...@@ -36,10 +36,13 @@ class ReshapeOp : public framework::OperatorWithKernel { ...@@ -36,10 +36,13 @@ class ReshapeOp : public framework::OperatorWithKernel {
auto x_dims = ctx->GetInputDim("X"); auto x_dims = ctx->GetInputDim("X");
std::vector<size_t> neg_dims_idx; std::vector<size_t> neg_dims_idx;
// set some dimension to -1 if it is unknown
const int unknown_size = -1;
for (size_t i = 0; i < shape.size(); ++i) { for (size_t i = 0; i < shape.size(); ++i) {
PADDLE_ENFORCE(shape[i] > 0 || shape[i] == -1, PADDLE_ENFORCE(shape[i] > 0 || shape[i] == unknown_size,
"Each dimension of Attr(shape) must be positive or -1."); "Each dimension of Attr(shape) must be positive or %d.",
if (shape[i] == -1) { unknown_size);
if (shape[i] == unknown_size) {
neg_dims_idx.push_back(i); neg_dims_idx.push_back(i);
PADDLE_ENFORCE(neg_dims_idx.size() <= 1, PADDLE_ENFORCE(neg_dims_idx.size() <= 1,
"Only one dimension of Attr(shape) can be unknown."); "Only one dimension of Attr(shape) can be unknown.");
...@@ -53,8 +56,7 @@ class ReshapeOp : public framework::OperatorWithKernel { ...@@ -53,8 +56,7 @@ class ReshapeOp : public framework::OperatorWithKernel {
// dim infer // dim infer
shape[neg_dims_idx[0]] = in_size / (-capacity); shape[neg_dims_idx[0]] = in_size / (-capacity);
// recalculate capacity // recalculate capacity
capacity = std::accumulate(shape.begin(), shape.end(), 1, capacity = shape[neg_dims_idx[0]] * (-capacity);
std::multiplies<int>());
} }
// capacity check // capacity check
PADDLE_ENFORCE(capacity == in_size, PADDLE_ENFORCE(capacity == in_size,
...@@ -98,9 +100,9 @@ the tensor X into a 2-D tensor: ...@@ -98,9 +100,9 @@ the tensor X into a 2-D tensor:
[[1, 2, 3, 4]] [[1, 2, 3, 4]]
One dimension in the target shape can be set -1, and the real dimension One dimension in the target shape can be set -1, representing that its
will be infered from the original shape of Input(X) and other size is unknown. In this case, the real dimension will be infered from
dimensions in the target shape. the original shape of Input(X) and other dimensions in the target shape.
)DOC"); )DOC");
} }
}; };
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册