提交 d3d16f76 编写于 作者: Y ying

enhance reshape operator.

上级 ea4e6c7a
......@@ -31,47 +31,68 @@ class ReshapeOp : public framework::OperatorWithKernel {
PADDLE_ENFORCE(ctx->HasOutput("Out"),
"Output(Out) of ReshapeOp should not be null.");
auto shape = ctx->Attrs().Get<std::vector<int>>("shape");
PADDLE_ENFORCE(shape.size() > 0, "Attr(shape) shouldn't be empty.");
const std::vector<int> &shape = ctx->Attrs().Get<std::vector<int>>("shape");
PADDLE_ENFORCE_EQ(shape.empty(), ctx->HasInput("Shape"),
"The shape information can only be set by Attr(shape) or "
"by Input(Shape). Attr(shape) and Input(Shape) cannot be "
"set at the same time.");
auto x_dims = ctx->GetInputDim("X");
if (ctx->HasInput("Shape")) {
auto shape_dims = ctx->GetInputDim("Shape");
PADDLE_ENFORCE(shape_dims.size() == 2UL && shape_dims[0] == 1UL,
"The Input(Label) should be a 2-D tensor with the 1st "
"dimensions fixed to 1 (a row vector).");
// The actual output shape will be set at runtime; here, temporarily set
// the shape of the output to be the same as the shape of the input.
ctx->SetOutputDim("Out", x_dims);
} else {
std::vector<int64_t> output_shape;
ValidateShape(shape, framework::product(x_dims), output_shape);
auto out_dims = framework::make_ddim(output_shape);
ctx->SetOutputDim("Out", out_dims);
}
if (shape[0] == x_dims[0]) {
// Only pass LoD when the first dimension of output and input are the
// same.
ctx->ShareLoD("X", /*->*/ "Out");
}
}
private:
void ValidateShape(const std::vector<int> &shape, const int64_t in_size,
std::vector<int64_t> &output_shape) const {
std::vector<size_t> neg_dims_idx;
// set some dimension to -1 if it is unknown
const int unknown_size = -1;
const int unknown_index = -1; // only one dimension can be set to -1, whose
// size will be automatically inferred.
for (size_t i = 0; i < shape.size(); ++i) {
PADDLE_ENFORCE(shape[i] > 0 || shape[i] == unknown_size,
"Each dimension of Attr(shape) must be positive or %d.",
unknown_size);
if (shape[i] == unknown_size) {
neg_dims_idx.push_back(i);
PADDLE_ENFORCE(neg_dims_idx.size() <= 1,
"Only one dimension of Attr(shape) can be unknown.");
PADDLE_ENFORCE(shape[i] > 1 || shape[i] == unknown_index,
"Each input dimension of Attr(shape) must be positive, or "
"only one input dimension can be -1.");
if (shape[i] == unknown_index) neg_dims_idx.push_back(i);
}
PADDLE_ENFORCE_LE(
neg_dims_idx.size(), 1,
"Only one input dimension of Attr(shape) may be unknown.");
int64_t inferred_dim = 0;
if (neg_dims_idx.size()) {
int64_t capacity = std::accumulate(shape.begin(), shape.end(), 1,
std::multiplies<int>());
inferred_dim = in_size / (-capacity);
}
int64_t capacity =
std::accumulate(shape.begin(), shape.end(), 1, std::multiplies<int>());
int64_t in_size = framework::product(x_dims);
if (neg_dims_idx.size() == 1) {
// dim infer
shape[neg_dims_idx[0]] = in_size / (-capacity);
// recalculate capacity
capacity = shape[neg_dims_idx[0]] * (-capacity);
}
// capacity check
PADDLE_ENFORCE(capacity == in_size,
"The size of Input(X) mismatches with Attr(shape).");
// resize output
std::vector<int64_t> shape_int64(shape.size(), 0);
std::transform(shape.begin(), shape.end(), shape_int64.begin(),
output_shape.resize(shape.size(), 0);
std::transform(shape.begin(), shape.end(), output_shape.begin(),
[](int a) { return static_cast<int64_t>(a); });
auto out_dims = framework::make_ddim(shape_int64);
ctx->SetOutputDim("Out", out_dims);
if (shape[0] == x_dims[0]) {
// Only pass LoD when the first dimension is equal between
// output and input.
ctx->ShareLoD("X", /*->*/ "Out");
}
if (neg_dims_idx.size()) output_shape[neg_dims_idx[0]] = inferred_dim;
}
};
......@@ -80,10 +101,12 @@ class ReshapeOpMaker : public framework::OpProtoAndCheckerMaker {
ReshapeOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "The input tensor of reshape operator.");
AddInput("Shape", "a 1-D tensor that provides the shape information.")
.AsDispensable();
AddOutput("Out", "The output tensor of reshape operator.");
AddAttr<std::vector<int>>("shape",
"(vector<int>) "
"Target shape of reshape operator.");
"(vector<int>) Target shape of reshape operator.")
.SetDefault(std::vector<int>());
AddComment(R"DOC(
Reshape Operator.
......
......@@ -26,11 +26,57 @@ class ReshapeKernel : public framework::OpKernel<T> {
/// Copies Input(X) into Output(Out) and gives Out its target shape.
///
/// If the optional Input(Shape) tensor is present, the target shape is read
/// from it at run time through ValidateShape(); otherwise the dims that Out
/// already carries on entry are kept.
///
/// @param ctx Execution context providing the "X", "Shape" and "Out"
///            variables, the target place and the device context.
void Compute(const framework::ExecutionContext& ctx) const {
  auto* out = ctx.Output<framework::Tensor>("Out");
  auto* in = ctx.Input<framework::Tensor>("X");
  auto* shape = ctx.Input<framework::Tensor>("Shape");

  // Declared exactly once; both branches below assign it.
  framework::DDim out_dims;
  if (shape) {
    // Run-time shape: resolve it (including a possible -1 entry) against the
    // number of elements of X.
    std::vector<int64_t> output_shape;
    ValidateShape(*shape, framework::product(in->dims()), output_shape);
    out_dims = framework::make_ddim(output_shape);
  } else {
    // Remember the dims Out has on entry so they can be re-applied below.
    out_dims = out->dims();
  }

  out->mutable_data<T>(ctx.GetPlace());
  framework::TensorCopy(*in, ctx.GetPlace(), ctx.device_context(), out);
  // NOTE(review): the resize is applied after the copy, presumably because
  // TensorCopy leaves Out with the dims of X -- confirm against TensorCopy.
  out->Resize(out_dims);
}
private:
/// Resolves the target shape stored in the run-time Input(Shape) tensor.
///
/// Every entry must be positive, except that at most one entry may be -1;
/// that entry is replaced by the value that makes the total element count
/// equal to `in_size`.
///
/// @param shape        Tensor holding the requested dimensions. Its elements
///                     are read as type T and the number of entries is taken
///                     from shape.dims()[1].
/// @param in_size      Number of elements of the tensor being reshaped.
/// @param output_shape Receives the resolved dimensions (resized to the
///                     number of entries in `shape`).
///
/// Enforce failures are raised when an entry is neither positive nor -1, when
/// more than one entry is -1, or when the requested shape cannot hold exactly
/// `in_size` elements.
void ValidateShape(const framework::Tensor& shape, const int64_t in_size,
                   std::vector<int64_t>& output_shape) const {
  // Only one dimension can be set to -1; its size is inferred automatically.
  const int64_t unknown_index = -1;

  // NOTE(review): indexing dims()[1] assumes Shape is a 2-D row vector
  // ([1, N]); a genuinely 1-D tensor would index out of range -- confirm the
  // caller guarantees the 2-D layout.
  const int64_t dimension = shape.dims()[1];
  // NOTE(review): the shape values are read with the kernel's element type T;
  // verify that Input(Shape) is really stored with that type.
  const T* shape_data = shape.data<T>();

  int64_t inferred_pos = -1;  // position of the -1 entry, or -1 if none
  int64_t capacity = 1;       // product of the explicitly given dimensions
  output_shape.resize(dimension, 0);
  for (int64_t i = 0; i < dimension; ++i) {
    // Convert once so that validation, the product and the stored value all
    // use the same integer.
    const int64_t dim = static_cast<int64_t>(shape_data[i]);
    // A dimension of exactly 1 is legal, hence "> 0" rather than "> 1".
    PADDLE_ENFORCE(dim > 0 || dim == unknown_index,
                   "Each dimension of Input(Shape) must be positive, or "
                   "only one dimension can be -1.");
    if (dim == unknown_index) {
      PADDLE_ENFORCE(inferred_pos < 0,
                     "Only one dimension of Input(Shape) can be unknown.");
      inferred_pos = i;
    } else {
      capacity *= dim;
    }
    output_shape[i] = dim;
  }

  if (inferred_pos >= 0) {
    // The known dimensions must divide the input size evenly, otherwise no
    // integer value for the unknown dimension exists.
    PADDLE_ENFORCE(in_size % capacity == 0,
                   "The size of Input(X) is not divisible by the known "
                   "dimensions of Input(Shape).");
    output_shape[inferred_pos] = in_size / capacity;
  } else {
    PADDLE_ENFORCE(capacity == in_size,
                   "The size of Input(X) mismatches with Input(Shape).");
  }
}
};
template <typename DeviceContext, typename T>
......
......@@ -334,7 +334,7 @@ class OpTest(unittest.TestCase):
np.allclose(
actual_t, expect_t, atol=atol),
"Output (" + out_name + ") has diff at " + str(place) +
str(actual_t) + str(expect_t))
str(actual_t) + "\n" + str(expect_t))
if isinstance(expect, tuple):
self.assertListEqual(actual.lod(), expect[1],
"Output (" + out_name +
......@@ -546,6 +546,6 @@ class OpTest(unittest.TestCase):
fetch_list = [g for p, g in param_grad_list]
executor = Executor(place)
return map(
np.array,
executor.run(prog, feed_dict, fetch_list, return_numpy=False))
return map(np.array,
executor.run(prog, feed_dict, fetch_list,
return_numpy=False))
......@@ -14,29 +14,51 @@
import unittest
import numpy as np
from op_test import OpTest
import pdb
class TestReshapeOp(OpTest):
def setUp(self):
self.op_type = "reshape"
self.inputs = {'X': np.random.random((10, 20)).astype("float32")}
self.attrs = {'shape': [10 * 20]}
self.outputs = {'Out': self.inputs['X'].reshape(self.attrs['shape'])}
from op_test import OpTest
def test_check_output(self):
self.check_output()
# class TestReshapeOp1(OpTest):
# def setUp(self):
# ori_shape = (2, 25)
# new_shape = [5, 10]
#
# self.op_type = "reshape"
# self.inputs = {"X": np.random.random(ori_shape).astype("float32")}
# self.attrs = {"shape": new_shape}
# self.outputs = {"Out": self.inputs["X"].reshape(new_shape)}
#
# def test_check_output(self):
# self.check_output()
#
# def test_check_grad(self):
# self.check_grad(["X"], "Out")
def test_check_grad(self):
self.check_grad(["X"], "Out")
# class TestReshapeOpDimInfer1(OpTest):
# def setUp(self):
# self.op_type = "reshape"
# self.inputs = {"X": np.random.random((5, 10)).astype("float32")}
# self.attrs = {"shape": [5, -1, 5]}
# self.outputs = {"Out": self.inputs["X"].reshape(self.attrs["shape"])}
#
# def test_check_output(self):
# self.check_output()
#
# def test_check_grad(self):
# self.check_grad(["X"], "Out")
class TestReshapeOpDimInfer(OpTest):
class TestReshapeOp2(OpTest):
def setUp(self):
ori_shape = (2, 25)
new_shape = ([5, 10], )
self.op_type = "reshape"
self.inputs = {'X': np.random.random((10, 20)).astype("float32")}
self.attrs = {'shape': [4, -1, 5]}
self.outputs = {'Out': self.inputs['X'].reshape(self.attrs['shape'])}
self.inputs = {
"X": np.random.random(ori_shape).astype("float32"),
"Shape": np.array(new_shape)
}
self.outputs = {"Out": self.inputs["X"].reshape(new_shape[0])}
def test_check_output(self):
self.check_output()
......@@ -45,5 +67,5 @@ class TestReshapeOpDimInfer(OpTest):
self.check_grad(["X"], "Out")
if __name__ == '__main__':
if __name__ == "__main__":
unittest.main()
文件模式从 100755 更改为 100644
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册