提交 d3d16f76 编写于 作者: Y ying

enhance reshape operator.

上级 ea4e6c7a
......@@ -31,48 +31,69 @@ class ReshapeOp : public framework::OperatorWithKernel {
PADDLE_ENFORCE(ctx->HasOutput("Out"),
"Output(Out) of ReshapeOp should not be null.");
auto shape = ctx->Attrs().Get<std::vector<int>>("shape");
PADDLE_ENFORCE(shape.size() > 0, "Attr(shape) shouldn't be empty.");
const std::vector<int> &shape = ctx->Attrs().Get<std::vector<int>>("shape");
PADDLE_ENFORCE_EQ(shape.empty(), ctx->HasInput("Shape"),
"The shape information can only be set by Attr(shape) or "
"by Input(Shape). Attr(shape) and Input(Shape) cannot be "
"set at the same time.");
auto x_dims = ctx->GetInputDim("X");
std::vector<size_t> neg_dims_idx;
// set some dimension to -1 if it is unknown
const int unknown_size = -1;
for (size_t i = 0; i < shape.size(); ++i) {
PADDLE_ENFORCE(shape[i] > 0 || shape[i] == unknown_size,
"Each dimension of Attr(shape) must be positive or %d.",
unknown_size);
if (shape[i] == unknown_size) {
neg_dims_idx.push_back(i);
PADDLE_ENFORCE(neg_dims_idx.size() <= 1,
"Only one dimension of Attr(shape) can be unknown.");
}
}
if (ctx->HasInput("Shape")) {
auto shape_dims = ctx->GetInputDim("Shape");
int64_t capacity =
std::accumulate(shape.begin(), shape.end(), 1, std::multiplies<int>());
int64_t in_size = framework::product(x_dims);
if (neg_dims_idx.size() == 1) {
// dim infer
shape[neg_dims_idx[0]] = in_size / (-capacity);
// recalculate capacity
capacity = shape[neg_dims_idx[0]] * (-capacity);
PADDLE_ENFORCE(shape_dims.size() == 2UL && shape_dims[0] == 1UL,
"The Input(Label) should be a 2-D tensor with the 1st "
"dimensions fixed to 1 (a row vector).");
// The actual output shape will be set at runtime, here temporially the
// the shape of output the same as the shape of input.
ctx->SetOutputDim("Out", x_dims);
} else {
std::vector<int64_t> output_shape;
ValidateShape(shape, framework::product(x_dims), output_shape);
auto out_dims = framework::make_ddim(output_shape);
ctx->SetOutputDim("Out", out_dims);
}
// capacity check
PADDLE_ENFORCE(capacity == in_size,
"The size of Input(X) mismatches with Attr(shape).");
// resize output
std::vector<int64_t> shape_int64(shape.size(), 0);
std::transform(shape.begin(), shape.end(), shape_int64.begin(),
[](int a) { return static_cast<int64_t>(a); });
auto out_dims = framework::make_ddim(shape_int64);
ctx->SetOutputDim("Out", out_dims);
if (shape[0] == x_dims[0]) {
// Only pass LoD when the first dimension is equal between
// output and input.
// Only pass LoD when the first dimension of output and input are the
// same.
ctx->ShareLoD("X", /*->*/ "Out");
}
}
private:
void ValidateShape(const std::vector<int> &shape, const int64_t in_size,
std::vector<int64_t> &output_shape) const {
std::vector<size_t> neg_dims_idx;
const int unknown_index = -1; // only one dimension canbe set to -1, whose
// size will be automatically infered.
for (size_t i = 0; i < shape.size(); ++i) {
PADDLE_ENFORCE(shape[i] > 1 || shape[i] == unknown_index,
"Each input dimension of Attr(shape) must be positive, or "
"only one input dimension can be -1.");
if (shape[i] == unknown_index) neg_dims_idx.push_back(i);
}
PADDLE_ENFORCE_LE(
neg_dims_idx.size(), 1,
"Only one input dimension of Attr(shape) may be unknown.");
int64_t inferred_dim = 0;
if (neg_dims_idx.size()) {
int64_t capacity = std::accumulate(shape.begin(), shape.end(), 1,
std::multiplies<int>());
inferred_dim = in_size / (-capacity);
}
output_shape.resize(shape.size(), 0);
std::transform(shape.begin(), shape.end(), output_shape.begin(),
[](int a) { return static_cast<int64_t>(a); });
if (neg_dims_idx.size()) output_shape[neg_dims_idx[0]] = inferred_dim;
}
};
class ReshapeOpMaker : public framework::OpProtoAndCheckerMaker {
......@@ -80,10 +101,12 @@ class ReshapeOpMaker : public framework::OpProtoAndCheckerMaker {
ReshapeOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "The input tensor of reshape operator.");
AddInput("Shape", "a 1-D tensor that provides the shape information.")
.AsDispensable();
AddOutput("Out", "The output tensor of reshape operator.");
AddAttr<std::vector<int>>("shape",
"(vector<int>) "
"Target shape of reshape operator.");
"(vector<int>) Target shape of reshape operator.")
.SetDefault(std::vector<int>());
AddComment(R"DOC(
Reshape Operator.
......@@ -96,7 +119,7 @@ and target shape = [1, 4], the reshape operator will transform
the tensor X into a 2-D tensor: [[1, 2, 3, 4]]
One dimension in the target shape can be set -1, representing that its
size is unknown. In this case, the real dimension will be infered from
size is unknown. In this case, the real dimension will be infered from
the original shape of Input(X) and other dimensions in the target shape.
)DOC");
}
......
......@@ -26,11 +26,57 @@ class ReshapeKernel : public framework::OpKernel<T> {
void Compute(const framework::ExecutionContext& ctx) const {
auto* out = ctx.Output<framework::Tensor>("Out");
auto* in = ctx.Input<framework::Tensor>("X");
auto out_dims = out->dims();
auto* shape = ctx.Input<framework::Tensor>("Shape");
framework::DDim out_dims;
if (shape) {
std::vector<int64_t> output_shape;
ValidateShape(*shape, framework::product(in->dims()), output_shape);
for (auto d : output_shape) std::cout << d << " ";
std::cout << std::endl;
out_dims = framework::make_ddim(output_shape);
} else {
out_dims = out->dims();
}
out->mutable_data<T>(ctx.GetPlace());
framework::TensorCopy(*in, ctx.GetPlace(), ctx.device_context(), out);
out->Resize(out_dims);
}
private:
void ValidateShape(const framework::Tensor& shape, const int64_t in_size,
std::vector<int64_t>& output_shape) const {
std::vector<size_t> neg_dims_idx;
const int unknown_index = -1; // only one dimension canbe set to -1, whose
// size will be automatically infered.
const int64_t dimension = shape.dims()[1];
std::cout << "dimension =" << dimension << std::endl;
const T* shape_data = shape.data<T>();
for (int64_t i = 0; i < dimension; ++i) {
PADDLE_ENFORCE(shape_data[i] > 1 || shape_data[i] == unknown_index,
"Each input dimension of Attr(shape) must be positive, or "
"only one input dimension can be -1.");
if (shape_data[i] == unknown_index) neg_dims_idx.push_back(i);
}
PADDLE_ENFORCE_LE(
neg_dims_idx.size(), 1,
"Only one input dimension of Attr(shape) can be unknown.");
int64_t capacity = 1;
output_shape.resize(dimension, 0);
for (int64_t i = 0; i < dimension; ++i) {
capacity *= shape_data[i];
output_shape[i] = static_cast<int64_t>(shape_data[i]);
}
if (neg_dims_idx.size())
output_shape[neg_dims_idx[0]] = in_size / (-capacity);
}
};
template <typename DeviceContext, typename T>
......
......@@ -334,7 +334,7 @@ class OpTest(unittest.TestCase):
np.allclose(
actual_t, expect_t, atol=atol),
"Output (" + out_name + ") has diff at " + str(place) +
str(actual_t) + str(expect_t))
str(actual_t) + "\n" + str(expect_t))
if isinstance(expect, tuple):
self.assertListEqual(actual.lod(), expect[1],
"Output (" + out_name +
......@@ -546,6 +546,6 @@ class OpTest(unittest.TestCase):
fetch_list = [g for p, g in param_grad_list]
executor = Executor(place)
return map(
np.array,
executor.run(prog, feed_dict, fetch_list, return_numpy=False))
return map(np.array,
executor.run(prog, feed_dict, fetch_list,
return_numpy=False))
......@@ -14,29 +14,51 @@
import unittest
import numpy as np
from op_test import OpTest
import pdb
class TestReshapeOp(OpTest):
def setUp(self):
self.op_type = "reshape"
self.inputs = {'X': np.random.random((10, 20)).astype("float32")}
self.attrs = {'shape': [10 * 20]}
self.outputs = {'Out': self.inputs['X'].reshape(self.attrs['shape'])}
from op_test import OpTest
def test_check_output(self):
self.check_output()
# class TestReshapeOp1(OpTest):
# def setUp(self):
# ori_shape = (2, 25)
# new_shape = [5, 10]
#
# self.op_type = "reshape"
# self.inputs = {"X": np.random.random(ori_shape).astype("float32")}
# self.attrs = {"shape": new_shape}
# self.outputs = {"Out": self.inputs["X"].reshape(new_shape)}
#
# def test_check_output(self):
# self.check_output()
#
# def test_check_grad(self):
# self.check_grad(["X"], "Out")
def test_check_grad(self):
self.check_grad(["X"], "Out")
# class TestReshapeOpDimInfer1(OpTest):
# def setUp(self):
# self.op_type = "reshape"
# self.inputs = {"X": np.random.random((5, 10)).astype("float32")}
# self.attrs = {"shape": [5, -1, 5]}
# self.outputs = {"Out": self.inputs["X"].reshape(self.attrs["shape"])}
#
# def test_check_output(self):
# self.check_output()
#
# def test_check_grad(self):
# self.check_grad(["X"], "Out")
class TestReshapeOpDimInfer(OpTest):
class TestReshapeOp2(OpTest):
def setUp(self):
ori_shape = (2, 25)
new_shape = ([5, 10], )
self.op_type = "reshape"
self.inputs = {'X': np.random.random((10, 20)).astype("float32")}
self.attrs = {'shape': [4, -1, 5]}
self.outputs = {'Out': self.inputs['X'].reshape(self.attrs['shape'])}
self.inputs = {
"X": np.random.random(ori_shape).astype("float32"),
"Shape": np.array(new_shape)
}
self.outputs = {"Out": self.inputs["X"].reshape(new_shape[0])}
def test_check_output(self):
self.check_output()
......@@ -45,5 +67,5 @@ class TestReshapeOpDimInfer(OpTest):
self.check_grad(["X"], "Out")
if __name__ == '__main__':
if __name__ == "__main__":
unittest.main()
文件模式从 100755 更改为 100644
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册