提交 1d4dfc09 编写于 作者: C caoying03

fix bugs.

上级 d3d16f76
......@@ -32,7 +32,6 @@ class ReshapeOp : public framework::OperatorWithKernel {
"Output(Out) of ReshapeOp should not be null.");
const std::vector<int> &shape = ctx->Attrs().Get<std::vector<int>>("shape");
PADDLE_ENFORCE_EQ(shape.empty(), ctx->HasInput("Shape"),
"The shape information can only be set by Attr(shape) or "
"by Input(Shape). Attr(shape) and Input(Shape) cannot be "
......@@ -41,29 +40,31 @@ class ReshapeOp : public framework::OperatorWithKernel {
auto x_dims = ctx->GetInputDim("X");
if (ctx->HasInput("Shape")) {
// The shape information in given by Input(Shape).
auto shape_dims = ctx->GetInputDim("Shape");
PADDLE_ENFORCE(shape_dims.size() == 2UL && shape_dims[0] == 1UL,
"The Input(Label) should be a 2-D tensor with the 1st "
"dimensions fixed to 1 (a row vector).");
// The actual output shape will be set at runtime, here temporially the
// The actual output shape will be set at runtime, here temporially set
// the shape of output the same as the shape of input.
ctx->SetOutputDim("Out", x_dims);
} else {
// The shape information in given by Attr(shape).
std::vector<int64_t> output_shape;
ValidateShape(shape, framework::product(x_dims), output_shape);
auto out_dims = framework::make_ddim(output_shape);
ctx->SetOutputDim("Out", out_dims);
}
if (shape[0] == x_dims[0]) {
// Only pass LoD when the first dimension of output and input are the
// same.
// Only pass LoD when the first dimension of output and Input(X)
// are the same.
ctx->ShareLoD("X", /*->*/ "Out");
}
}
}
private:
void ValidateShape(const std::vector<int> &shape, const int64_t in_size,
......@@ -94,6 +95,14 @@ class ReshapeOp : public framework::OperatorWithKernel {
[](int a) { return static_cast<int64_t>(a); });
if (neg_dims_idx.size()) output_shape[neg_dims_idx[0]] = inferred_dim;
}
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override {
return framework::OpKernelType(
framework::ToDataType(ctx.Input<framework::Tensor>("X")->type()),
ctx.device_context());
}
};
class ReshapeOpMaker : public framework::OpProtoAndCheckerMaker {
......@@ -101,11 +110,13 @@ class ReshapeOpMaker : public framework::OpProtoAndCheckerMaker {
ReshapeOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "The input tensor of reshape operator.");
AddInput("Shape", "a 1-D tensor that provides the shape information.")
AddInput(
"Shape",
"Tensor<int64_t>, a 1-D tensor that provides the shape information.")
.AsDispensable();
AddOutput("Out", "The output tensor of reshape operator.");
AddAttr<std::vector<int>>("shape",
"(vector<int>) Target shape of reshape operator.")
AddAttr<std::vector<int>>(
"shape", "(std::vector<int>) Target shape of reshape operator.")
.SetDefault(std::vector<int>());
AddComment(R"DOC(
Reshape Operator.
......@@ -139,6 +150,14 @@ class ReshapeGradOp : public framework::OperatorWithKernel {
"Input(Out@GRAD) shouldn't be null.");
ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("X"));
}
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override {
return framework::OpKernelType(
framework::ToDataType(ctx.Input<framework::Tensor>("X")->type()),
ctx.device_context());
}
};
} // namespace operators
......
......@@ -33,9 +33,6 @@ class ReshapeKernel : public framework::OpKernel<T> {
std::vector<int64_t> output_shape;
ValidateShape(*shape, framework::product(in->dims()), output_shape);
for (auto d : output_shape) std::cout << d << " ";
std::cout << std::endl;
out_dims = framework::make_ddim(output_shape);
} else {
out_dims = out->dims();
......@@ -85,11 +82,18 @@ class ReshapeGradKernel : public framework::OpKernel<T> {
void Compute(const framework::ExecutionContext& ctx) const {
auto* d_out = ctx.Input<framework::Tensor>(framework::GradVarName("Out"));
auto* d_x = ctx.Output<framework::Tensor>(framework::GradVarName("X"));
d_x->mutable_data<T>(ctx.GetPlace());
bool inplace = ctx.Attr<bool>("inplace");
auto in_dims = d_x->dims();
if (!inplace) {
framework::TensorCopy(*d_out, ctx.GetPlace(), ctx.device_context(), d_x);
d_x->Resize(in_dims);
} else {
d_x->ShareDataWith(*d_out);
d_x->Resize(in_dims);
}
}
};
} // namespace operators
......
......@@ -33,7 +33,8 @@ from op_test import OpTest
#
# def test_check_grad(self):
# self.check_grad(["X"], "Out")
#
#
# class TestReshapeOpDimInfer1(OpTest):
# def setUp(self):
# self.op_type = "reshape"
......@@ -56,7 +57,8 @@ class TestReshapeOp2(OpTest):
self.op_type = "reshape"
self.inputs = {
"X": np.random.random(ori_shape).astype("float32"),
"Shape": np.array(new_shape)
"Shape": np.array(
new_shape, dtype="int64")
}
self.outputs = {"Out": self.inputs["X"].reshape(new_shape[0])}
......@@ -67,5 +69,32 @@ class TestReshapeOp2(OpTest):
self.check_grad(["X"], "Out")
# class TestReshapeOpInplace(OpTest):
# def setUp(self):
# self.op_type = "reshape"
# self.inputs = {'X': np.random.random((10, 20)).astype("float32")}
# self.attrs = {'shape': [10 * 20], 'inplace': True}
# self.outputs = {'Out': self.inputs['X'].reshape(self.attrs['shape'])}
#
# def test_check_output(self):
# self.check_output()
#
# def test_check_grad(self):
# self.check_grad(["X"], "Out")
#
#
# class TestReshapeOpDimInferInplace(OpTest):
# def setUp(self):
# self.op_type = "reshape"
# self.inputs = {'X': np.random.random((10, 20)).astype("float32")}
# self.attrs = {'shape': [4, -1, 5], 'inplace': True}
# self.outputs = {'Out': self.inputs['X'].reshape(self.attrs['shape'])}
#
# def test_check_output(self):
# self.check_output()
#
# def test_check_grad(self):
# self.check_grad(["X"], "Out")
if __name__ == "__main__":
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册