未验证 提交 c9076543 编写于 作者: Y Yibing Liu 提交者: GitHub

Merge pull request #6609 from kuke/fix_reshape_op

Enable reshape_op to support dimension inference
...@@ -34,21 +34,33 @@ class ReshapeOp : public framework::OperatorWithKernel { ...@@ -34,21 +34,33 @@ class ReshapeOp : public framework::OperatorWithKernel {
auto shape = ctx->Attrs().Get<std::vector<int>>("shape"); auto shape = ctx->Attrs().Get<std::vector<int>>("shape");
PADDLE_ENFORCE(shape.size() > 0, "Attr(shape) shouldn't be empty."); PADDLE_ENFORCE(shape.size() > 0, "Attr(shape) shouldn't be empty.");
auto x_dims = ctx->GetInputDim("X"); auto x_dims = ctx->GetInputDim("X");
// TODO(qiao) change batch_size
for (size_t i = 1; i < shape.size(); ++i) { std::vector<size_t> neg_dims_idx;
PADDLE_ENFORCE(shape[i] > 0, // set some dimension to -1 if it is unknown
"Each dimension of Attr(shape) " const int unknown_size = -1;
"must be positive except the first one."); for (size_t i = 0; i < shape.size(); ++i) {
} PADDLE_ENFORCE(shape[i] > 0 || shape[i] == unknown_size,
if (shape[0] < 0) { "Each dimension of Attr(shape) must be positive or %d.",
shape[0] = x_dims[0]; unknown_size);
if (shape[i] == unknown_size) {
neg_dims_idx.push_back(i);
PADDLE_ENFORCE(neg_dims_idx.size() <= 1,
"Only one dimension of Attr(shape) can be unknown.");
}
} }
// capacity check
int64_t capacity = int64_t capacity =
std::accumulate(shape.begin(), shape.end(), 1, std::multiplies<int>()); std::accumulate(shape.begin(), shape.end(), 1, std::multiplies<int>());
int64_t in_size = framework::product(x_dims); int64_t in_size = framework::product(x_dims);
PADDLE_ENFORCE_EQ(capacity, in_size, if (neg_dims_idx.size() == 1) {
"The size of Input(X) mismatches with Attr(shape)."); // dim infer
shape[neg_dims_idx[0]] = in_size / (-capacity);
// recalculate capacity
capacity = shape[neg_dims_idx[0]] * (-capacity);
}
// capacity check
PADDLE_ENFORCE(capacity == in_size,
"The size of Input(X) mismatches with Attr(shape).");
// resize output // resize output
std::vector<int64_t> shape_int64(shape.size(), 0); std::vector<int64_t> shape_int64(shape.size(), 0);
std::transform(shape.begin(), shape.end(), shape_int64.begin(), std::transform(shape.begin(), shape.end(), shape_int64.begin(),
...@@ -88,6 +100,9 @@ the tensor X into a 2-D tensor: ...@@ -88,6 +100,9 @@ the tensor X into a 2-D tensor:
[[1, 2, 3, 4]] [[1, 2, 3, 4]]
One dimension in the target shape can be set -1, representing that its
size is unknown. In this case, the real dimension will be infered from
the original shape of Input(X) and other dimensions in the target shape.
)DOC"); )DOC");
} }
}; };
......
...@@ -17,5 +17,19 @@ class TestReshapeOp(OpTest): ...@@ -17,5 +17,19 @@ class TestReshapeOp(OpTest):
self.check_grad(["X"], "Out") self.check_grad(["X"], "Out")
class TestReshapeOpDimInfer(OpTest):
def setUp(self):
self.op_type = "reshape"
self.inputs = {'X': np.random.random((10, 20)).astype("float32")}
self.attrs = {'shape': [4, -1, 5]}
self.outputs = {'Out': self.inputs['X'].reshape(self.attrs['shape'])}
def test_check_output(self):
self.check_output()
def test_check_grad(self):
self.check_grad(["X"], "Out")
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册