提交 d3d16f76 编写于 作者: Y ying

enhance reshape operator.

上级 ea4e6c7a
...@@ -31,48 +31,69 @@ class ReshapeOp : public framework::OperatorWithKernel { ...@@ -31,48 +31,69 @@ class ReshapeOp : public framework::OperatorWithKernel {
PADDLE_ENFORCE(ctx->HasOutput("Out"), PADDLE_ENFORCE(ctx->HasOutput("Out"),
"Output(Out) of ReshapeOp should not be null."); "Output(Out) of ReshapeOp should not be null.");
auto shape = ctx->Attrs().Get<std::vector<int>>("shape"); const std::vector<int> &shape = ctx->Attrs().Get<std::vector<int>>("shape");
PADDLE_ENFORCE(shape.size() > 0, "Attr(shape) shouldn't be empty.");
PADDLE_ENFORCE_EQ(shape.empty(), ctx->HasInput("Shape"),
"The shape information can only be set by Attr(shape) or "
"by Input(Shape). Attr(shape) and Input(Shape) cannot be "
"set at the same time.");
auto x_dims = ctx->GetInputDim("X"); auto x_dims = ctx->GetInputDim("X");
std::vector<size_t> neg_dims_idx; if (ctx->HasInput("Shape")) {
// set some dimension to -1 if it is unknown auto shape_dims = ctx->GetInputDim("Shape");
const int unknown_size = -1;
for (size_t i = 0; i < shape.size(); ++i) {
PADDLE_ENFORCE(shape[i] > 0 || shape[i] == unknown_size,
"Each dimension of Attr(shape) must be positive or %d.",
unknown_size);
if (shape[i] == unknown_size) {
neg_dims_idx.push_back(i);
PADDLE_ENFORCE(neg_dims_idx.size() <= 1,
"Only one dimension of Attr(shape) can be unknown.");
}
}
int64_t capacity = PADDLE_ENFORCE(shape_dims.size() == 2UL && shape_dims[0] == 1UL,
std::accumulate(shape.begin(), shape.end(), 1, std::multiplies<int>()); "The Input(Label) should be a 2-D tensor with the 1st "
int64_t in_size = framework::product(x_dims); "dimensions fixed to 1 (a row vector).");
if (neg_dims_idx.size() == 1) {
// dim infer // The actual output shape will be set at runtime, here temporially the
shape[neg_dims_idx[0]] = in_size / (-capacity); // the shape of output the same as the shape of input.
// recalculate capacity ctx->SetOutputDim("Out", x_dims);
capacity = shape[neg_dims_idx[0]] * (-capacity); } else {
std::vector<int64_t> output_shape;
ValidateShape(shape, framework::product(x_dims), output_shape);
auto out_dims = framework::make_ddim(output_shape);
ctx->SetOutputDim("Out", out_dims);
} }
// capacity check
PADDLE_ENFORCE(capacity == in_size,
"The size of Input(X) mismatches with Attr(shape).");
// resize output
std::vector<int64_t> shape_int64(shape.size(), 0);
std::transform(shape.begin(), shape.end(), shape_int64.begin(),
[](int a) { return static_cast<int64_t>(a); });
auto out_dims = framework::make_ddim(shape_int64);
ctx->SetOutputDim("Out", out_dims);
if (shape[0] == x_dims[0]) { if (shape[0] == x_dims[0]) {
// Only pass LoD when the first dimension is equal between // Only pass LoD when the first dimension of output and input are the
// output and input. // same.
ctx->ShareLoD("X", /*->*/ "Out"); ctx->ShareLoD("X", /*->*/ "Out");
} }
} }
private:
// Validates the target shape given by Attr(shape) and writes the resolved
// shape into `output_shape`.
//
// shape:        target shape; every entry must be positive, except that at
//               most one entry may be -1, meaning "infer this dimension".
// in_size:      number of elements of Input(X).
// output_shape: receives `shape` widened to int64_t, with the -1 entry (if
//               any) replaced by in_size / (product of the other entries).
//
// Fails via PADDLE_ENFORCE when an entry is neither positive nor -1, when
// more than one entry is -1, or when the target shape cannot hold exactly
// in_size elements.
void ValidateShape(const std::vector<int> &shape, const int64_t in_size,
                   std::vector<int64_t> &output_shape) const {
  const int unknown_index = -1;  // only one dimension can be set to -1, whose
                                 // size will be automatically inferred.
  bool has_unknown = false;
  size_t unknown_pos = 0;
  // Product of all explicitly given dimensions. Accumulated in int64_t so
  // that large shapes do not overflow int before being widened.
  int64_t known_capacity = 1;

  output_shape.assign(shape.begin(), shape.end());
  for (size_t i = 0; i < shape.size(); ++i) {
    // A dimension of size 1 is valid (e.g. target shape [1, 4]), so the
    // lower bound is "> 0", not "> 1".
    PADDLE_ENFORCE(shape[i] > 0 || shape[i] == unknown_index,
                   "Each input dimension of Attr(shape) must be positive, or "
                   "only one input dimension can be -1.");
    if (shape[i] == unknown_index) {
      PADDLE_ENFORCE(!has_unknown,
                     "Only one input dimension of Attr(shape) may be unknown.");
      has_unknown = true;
      unknown_pos = i;
    } else {
      known_capacity *= static_cast<int64_t>(shape[i]);
    }
  }

  // NOTE(review): the consistency checks are skipped when in_size is not
  // positive, on the assumption that Input(X) may then have dimensions that
  // are not known yet -- confirm against the callers of InferShape.
  if (has_unknown) {
    if (in_size > 0) {
      PADDLE_ENFORCE(in_size % known_capacity == 0,
                     "The size of Input(X) mismatches with Attr(shape).");
    }
    // known_capacity is never zero: every explicit dimension is positive.
    output_shape[unknown_pos] = in_size / known_capacity;
  } else if (in_size > 0) {
    PADDLE_ENFORCE(known_capacity == in_size,
                   "The size of Input(X) mismatches with Attr(shape).");
  }
}
}; };
class ReshapeOpMaker : public framework::OpProtoAndCheckerMaker { class ReshapeOpMaker : public framework::OpProtoAndCheckerMaker {
...@@ -80,10 +101,12 @@ class ReshapeOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -80,10 +101,12 @@ class ReshapeOpMaker : public framework::OpProtoAndCheckerMaker {
ReshapeOpMaker(OpProto *proto, OpAttrChecker *op_checker) ReshapeOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) { : OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "The input tensor of reshape operator."); AddInput("X", "The input tensor of reshape operator.");
AddInput("Shape", "a 1-D tensor that provides the shape information.")
.AsDispensable();
AddOutput("Out", "The output tensor of reshape operator."); AddOutput("Out", "The output tensor of reshape operator.");
AddAttr<std::vector<int>>("shape", AddAttr<std::vector<int>>("shape",
"(vector<int>) " "(vector<int>) Target shape of reshape operator.")
"Target shape of reshape operator."); .SetDefault(std::vector<int>());
AddComment(R"DOC( AddComment(R"DOC(
Reshape Operator. Reshape Operator.
...@@ -96,7 +119,7 @@ and target shape = [1, 4], the reshape operator will transform ...@@ -96,7 +119,7 @@ and target shape = [1, 4], the reshape operator will transform
the tensor X into a 2-D tensor: [[1, 2, 3, 4]] the tensor X into a 2-D tensor: [[1, 2, 3, 4]]
One dimension in the target shape can be set -1, representing that its One dimension in the target shape can be set -1, representing that its
size is unknown. In this case, the real dimension will be infered from size is unknown. In this case, the real dimension will be infered from
the original shape of Input(X) and other dimensions in the target shape. the original shape of Input(X) and other dimensions in the target shape.
)DOC"); )DOC");
} }
......
...@@ -26,11 +26,57 @@ class ReshapeKernel : public framework::OpKernel<T> { ...@@ -26,11 +26,57 @@ class ReshapeKernel : public framework::OpKernel<T> {
void Compute(const framework::ExecutionContext& ctx) const { void Compute(const framework::ExecutionContext& ctx) const {
auto* out = ctx.Output<framework::Tensor>("Out"); auto* out = ctx.Output<framework::Tensor>("Out");
auto* in = ctx.Input<framework::Tensor>("X"); auto* in = ctx.Input<framework::Tensor>("X");
auto out_dims = out->dims();
auto* shape = ctx.Input<framework::Tensor>("Shape");
framework::DDim out_dims;
if (shape) {
std::vector<int64_t> output_shape;
ValidateShape(*shape, framework::product(in->dims()), output_shape);
for (auto d : output_shape) std::cout << d << " ";
std::cout << std::endl;
out_dims = framework::make_ddim(output_shape);
} else {
out_dims = out->dims();
}
out->mutable_data<T>(ctx.GetPlace()); out->mutable_data<T>(ctx.GetPlace());
framework::TensorCopy(*in, ctx.GetPlace(), ctx.device_context(), out); framework::TensorCopy(*in, ctx.GetPlace(), ctx.device_context(), out);
out->Resize(out_dims); out->Resize(out_dims);
} }
private:
// Validates the target shape carried by the Input(Shape) tensor and writes
// the resolved shape into `output_shape`.
//
// shape:        tensor holding the target shape; its second dimension
//               (shape.dims()[1]) is taken as the number of target
//               dimensions. Every entry must be positive, except that at
//               most one entry may be -1, meaning "infer this dimension".
// in_size:      number of elements of Input(X).
// output_shape: receives the target dimensions as int64_t, with the -1
//               entry (if any) replaced by
//               in_size / (product of the other entries).
//
// Fails via PADDLE_ENFORCE when an entry is neither positive nor -1, when
// more than one entry is -1, or when the target shape cannot hold exactly
// in_size elements.
void ValidateShape(const framework::Tensor& shape, const int64_t in_size,
                   std::vector<int64_t>& output_shape) const {
  const int64_t unknown_index = -1;  // only one dimension can be set to -1,
                                     // whose size will be automatically
                                     // inferred.
  const int64_t dimension = shape.dims()[1];
  // NOTE(review): the shape tensor is read with the kernel's element type T;
  // confirm that Input(Shape) is really stored as T and not as an integer
  // tensor.
  const T* shape_data = shape.data<T>();

  bool has_unknown = false;
  int64_t unknown_pos = 0;
  int64_t known_capacity = 1;  // product of all explicitly given dimensions

  output_shape.resize(dimension, 0);
  for (int64_t i = 0; i < dimension; ++i) {
    const int64_t dim = static_cast<int64_t>(shape_data[i]);
    // A dimension of size 1 is valid, so the lower bound is "> 0".
    PADDLE_ENFORCE(dim > 0 || dim == unknown_index,
                   "Each dimension of Input(Shape) must be positive, or "
                   "only one dimension can be -1.");
    if (dim == unknown_index) {
      PADDLE_ENFORCE(!has_unknown,
                     "Only one dimension of Input(Shape) can be unknown.");
      has_unknown = true;
      unknown_pos = i;
    } else {
      known_capacity *= dim;
    }
    output_shape[i] = dim;
  }

  if (has_unknown) {
    // known_capacity is never zero: every explicit dimension is positive.
    PADDLE_ENFORCE(in_size % known_capacity == 0,
                   "The size of Input(X) mismatches with Input(Shape).");
    output_shape[unknown_pos] = in_size / known_capacity;
  } else {
    PADDLE_ENFORCE(known_capacity == in_size,
                   "The size of Input(X) mismatches with Input(Shape).");
  }
}
}; };
template <typename DeviceContext, typename T> template <typename DeviceContext, typename T>
......
...@@ -334,7 +334,7 @@ class OpTest(unittest.TestCase): ...@@ -334,7 +334,7 @@ class OpTest(unittest.TestCase):
np.allclose( np.allclose(
actual_t, expect_t, atol=atol), actual_t, expect_t, atol=atol),
"Output (" + out_name + ") has diff at " + str(place) + "Output (" + out_name + ") has diff at " + str(place) +
str(actual_t) + str(expect_t)) str(actual_t) + "\n" + str(expect_t))
if isinstance(expect, tuple): if isinstance(expect, tuple):
self.assertListEqual(actual.lod(), expect[1], self.assertListEqual(actual.lod(), expect[1],
"Output (" + out_name + "Output (" + out_name +
...@@ -546,6 +546,6 @@ class OpTest(unittest.TestCase): ...@@ -546,6 +546,6 @@ class OpTest(unittest.TestCase):
fetch_list = [g for p, g in param_grad_list] fetch_list = [g for p, g in param_grad_list]
executor = Executor(place) executor = Executor(place)
return map( return map(np.array,
np.array, executor.run(prog, feed_dict, fetch_list,
executor.run(prog, feed_dict, fetch_list, return_numpy=False)) return_numpy=False))
...@@ -14,29 +14,51 @@ ...@@ -14,29 +14,51 @@
import unittest import unittest
import numpy as np import numpy as np
from op_test import OpTest import pdb
class TestReshapeOp(OpTest): from op_test import OpTest
def setUp(self):
self.op_type = "reshape"
self.inputs = {'X': np.random.random((10, 20)).astype("float32")}
self.attrs = {'shape': [10 * 20]}
self.outputs = {'Out': self.inputs['X'].reshape(self.attrs['shape'])}
def test_check_output(self): # class TestReshapeOp1(OpTest):
self.check_output() # def setUp(self):
# ori_shape = (2, 25)
# new_shape = [5, 10]
#
# self.op_type = "reshape"
# self.inputs = {"X": np.random.random(ori_shape).astype("float32")}
# self.attrs = {"shape": new_shape}
# self.outputs = {"Out": self.inputs["X"].reshape(new_shape)}
#
# def test_check_output(self):
# self.check_output()
#
# def test_check_grad(self):
# self.check_grad(["X"], "Out")
def test_check_grad(self): # class TestReshapeOpDimInfer1(OpTest):
self.check_grad(["X"], "Out") # def setUp(self):
# self.op_type = "reshape"
# self.inputs = {"X": np.random.random((5, 10)).astype("float32")}
# self.attrs = {"shape": [5, -1, 5]}
# self.outputs = {"Out": self.inputs["X"].reshape(self.attrs["shape"])}
#
# def test_check_output(self):
# self.check_output()
#
# def test_check_grad(self):
# self.check_grad(["X"], "Out")
class TestReshapeOpDimInfer(OpTest): class TestReshapeOp2(OpTest):
def setUp(self): def setUp(self):
ori_shape = (2, 25)
new_shape = ([5, 10], )
self.op_type = "reshape" self.op_type = "reshape"
self.inputs = {'X': np.random.random((10, 20)).astype("float32")} self.inputs = {
self.attrs = {'shape': [4, -1, 5]} "X": np.random.random(ori_shape).astype("float32"),
self.outputs = {'Out': self.inputs['X'].reshape(self.attrs['shape'])} "Shape": np.array(new_shape)
}
self.outputs = {"Out": self.inputs["X"].reshape(new_shape[0])}
def test_check_output(self): def test_check_output(self):
self.check_output() self.check_output()
...@@ -45,5 +67,5 @@ class TestReshapeOpDimInfer(OpTest): ...@@ -45,5 +67,5 @@ class TestReshapeOpDimInfer(OpTest):
self.check_grad(["X"], "Out") self.check_grad(["X"], "Out")
if __name__ == '__main__': if __name__ == "__main__":
unittest.main() unittest.main()
文件模式从 100755 更改为 100644
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册