提交 1d4dfc09 编写于 作者: C caoying03

fix bugs.

上级 d3d16f76
...@@ -32,7 +32,6 @@ class ReshapeOp : public framework::OperatorWithKernel { ...@@ -32,7 +32,6 @@ class ReshapeOp : public framework::OperatorWithKernel {
"Output(Out) of ReshapeOp should not be null."); "Output(Out) of ReshapeOp should not be null.");
const std::vector<int> &shape = ctx->Attrs().Get<std::vector<int>>("shape"); const std::vector<int> &shape = ctx->Attrs().Get<std::vector<int>>("shape");
PADDLE_ENFORCE_EQ(shape.empty(), ctx->HasInput("Shape"), PADDLE_ENFORCE_EQ(shape.empty(), ctx->HasInput("Shape"),
"The shape information can only be set by Attr(shape) or " "The shape information can only be set by Attr(shape) or "
"by Input(Shape). Attr(shape) and Input(Shape) cannot be " "by Input(Shape). Attr(shape) and Input(Shape) cannot be "
...@@ -41,29 +40,31 @@ class ReshapeOp : public framework::OperatorWithKernel { ...@@ -41,29 +40,31 @@ class ReshapeOp : public framework::OperatorWithKernel {
auto x_dims = ctx->GetInputDim("X"); auto x_dims = ctx->GetInputDim("X");
if (ctx->HasInput("Shape")) { if (ctx->HasInput("Shape")) {
// The shape information is given by Input(Shape).
auto shape_dims = ctx->GetInputDim("Shape"); auto shape_dims = ctx->GetInputDim("Shape");
PADDLE_ENFORCE(shape_dims.size() == 2UL && shape_dims[0] == 1UL, PADDLE_ENFORCE(shape_dims.size() == 2UL && shape_dims[0] == 1UL,
"The Input(Shape) should be a 2-D tensor with the 1st " "The Input(Shape) should be a 2-D tensor with the 1st "
"dimensions fixed to 1 (a row vector)."); "dimensions fixed to 1 (a row vector).");
// The actual output shape will be set at runtime, here temporarily the // The actual output shape will be set at runtime, here temporarily set
// the shape of output the same as the shape of input. // the shape of output the same as the shape of input.
ctx->SetOutputDim("Out", x_dims); ctx->SetOutputDim("Out", x_dims);
} else { } else {
// The shape information is given by Attr(shape).
std::vector<int64_t> output_shape; std::vector<int64_t> output_shape;
ValidateShape(shape, framework::product(x_dims), output_shape); ValidateShape(shape, framework::product(x_dims), output_shape);
auto out_dims = framework::make_ddim(output_shape); auto out_dims = framework::make_ddim(output_shape);
ctx->SetOutputDim("Out", out_dims); ctx->SetOutputDim("Out", out_dims);
}
if (shape[0] == x_dims[0]) { if (shape[0] == x_dims[0]) {
// Only pass LoD when the first dimension of output and input are the // Only pass LoD when the first dimension of output and Input(X)
// same. // are the same.
ctx->ShareLoD("X", /*->*/ "Out"); ctx->ShareLoD("X", /*->*/ "Out");
} }
} }
}
private: private:
void ValidateShape(const std::vector<int> &shape, const int64_t in_size, void ValidateShape(const std::vector<int> &shape, const int64_t in_size,
...@@ -94,6 +95,14 @@ class ReshapeOp : public framework::OperatorWithKernel { ...@@ -94,6 +95,14 @@ class ReshapeOp : public framework::OperatorWithKernel {
[](int a) { return static_cast<int64_t>(a); }); [](int a) { return static_cast<int64_t>(a); });
if (neg_dims_idx.size()) output_shape[neg_dims_idx[0]] = inferred_dim; if (neg_dims_idx.size()) output_shape[neg_dims_idx[0]] = inferred_dim;
} }
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override {
return framework::OpKernelType(
framework::ToDataType(ctx.Input<framework::Tensor>("X")->type()),
ctx.device_context());
}
}; };
class ReshapeOpMaker : public framework::OpProtoAndCheckerMaker { class ReshapeOpMaker : public framework::OpProtoAndCheckerMaker {
...@@ -101,11 +110,13 @@ class ReshapeOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -101,11 +110,13 @@ class ReshapeOpMaker : public framework::OpProtoAndCheckerMaker {
ReshapeOpMaker(OpProto *proto, OpAttrChecker *op_checker) ReshapeOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) { : OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "The input tensor of reshape operator."); AddInput("X", "The input tensor of reshape operator.");
AddInput("Shape", "a 1-D tensor that provides the shape information.") AddInput(
"Shape",
"Tensor<int64_t>, a 1-D tensor that provides the shape information.")
.AsDispensable(); .AsDispensable();
AddOutput("Out", "The output tensor of reshape operator."); AddOutput("Out", "The output tensor of reshape operator.");
AddAttr<std::vector<int>>("shape", AddAttr<std::vector<int>>(
"(vector<int>) Target shape of reshape operator.") "shape", "(std::vector<int>) Target shape of reshape operator.")
.SetDefault(std::vector<int>()); .SetDefault(std::vector<int>());
AddComment(R"DOC( AddComment(R"DOC(
Reshape Operator. Reshape Operator.
...@@ -139,6 +150,14 @@ class ReshapeGradOp : public framework::OperatorWithKernel { ...@@ -139,6 +150,14 @@ class ReshapeGradOp : public framework::OperatorWithKernel {
"Input(Out@GRAD) shouldn't be null."); "Input(Out@GRAD) shouldn't be null.");
ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("X")); ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("X"));
} }
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override {
return framework::OpKernelType(
framework::ToDataType(ctx.Input<framework::Tensor>("X")->type()),
ctx.device_context());
}
}; };
} // namespace operators } // namespace operators
......
...@@ -33,9 +33,6 @@ class ReshapeKernel : public framework::OpKernel<T> { ...@@ -33,9 +33,6 @@ class ReshapeKernel : public framework::OpKernel<T> {
std::vector<int64_t> output_shape; std::vector<int64_t> output_shape;
ValidateShape(*shape, framework::product(in->dims()), output_shape); ValidateShape(*shape, framework::product(in->dims()), output_shape);
for (auto d : output_shape) std::cout << d << " ";
std::cout << std::endl;
out_dims = framework::make_ddim(output_shape); out_dims = framework::make_ddim(output_shape);
} else { } else {
out_dims = out->dims(); out_dims = out->dims();
...@@ -85,11 +82,18 @@ class ReshapeGradKernel : public framework::OpKernel<T> { ...@@ -85,11 +82,18 @@ class ReshapeGradKernel : public framework::OpKernel<T> {
void Compute(const framework::ExecutionContext& ctx) const { void Compute(const framework::ExecutionContext& ctx) const {
auto* d_out = ctx.Input<framework::Tensor>(framework::GradVarName("Out")); auto* d_out = ctx.Input<framework::Tensor>(framework::GradVarName("Out"));
auto* d_x = ctx.Output<framework::Tensor>(framework::GradVarName("X")); auto* d_x = ctx.Output<framework::Tensor>(framework::GradVarName("X"));
d_x->mutable_data<T>(ctx.GetPlace()); d_x->mutable_data<T>(ctx.GetPlace());
bool inplace = ctx.Attr<bool>("inplace");
auto in_dims = d_x->dims(); auto in_dims = d_x->dims();
if (!inplace) {
framework::TensorCopy(*d_out, ctx.GetPlace(), ctx.device_context(), d_x); framework::TensorCopy(*d_out, ctx.GetPlace(), ctx.device_context(), d_x);
d_x->Resize(in_dims); d_x->Resize(in_dims);
} else {
d_x->ShareDataWith(*d_out);
d_x->Resize(in_dims);
}
} }
}; };
} // namespace operators } // namespace operators
......
...@@ -33,7 +33,8 @@ from op_test import OpTest ...@@ -33,7 +33,8 @@ from op_test import OpTest
# #
# def test_check_grad(self): # def test_check_grad(self):
# self.check_grad(["X"], "Out") # self.check_grad(["X"], "Out")
#
#
# class TestReshapeOpDimInfer1(OpTest): # class TestReshapeOpDimInfer1(OpTest):
# def setUp(self): # def setUp(self):
# self.op_type = "reshape" # self.op_type = "reshape"
...@@ -56,7 +57,8 @@ class TestReshapeOp2(OpTest): ...@@ -56,7 +57,8 @@ class TestReshapeOp2(OpTest):
self.op_type = "reshape" self.op_type = "reshape"
self.inputs = { self.inputs = {
"X": np.random.random(ori_shape).astype("float32"), "X": np.random.random(ori_shape).astype("float32"),
"Shape": np.array(new_shape) "Shape": np.array(
new_shape, dtype="int64")
} }
self.outputs = {"Out": self.inputs["X"].reshape(new_shape[0])} self.outputs = {"Out": self.inputs["X"].reshape(new_shape[0])}
...@@ -67,5 +69,32 @@ class TestReshapeOp2(OpTest): ...@@ -67,5 +69,32 @@ class TestReshapeOp2(OpTest):
self.check_grad(["X"], "Out") self.check_grad(["X"], "Out")
# class TestReshapeOpInplace(OpTest):
# def setUp(self):
# self.op_type = "reshape"
# self.inputs = {'X': np.random.random((10, 20)).astype("float32")}
# self.attrs = {'shape': [10 * 20], 'inplace': True}
# self.outputs = {'Out': self.inputs['X'].reshape(self.attrs['shape'])}
#
# def test_check_output(self):
# self.check_output()
#
# def test_check_grad(self):
# self.check_grad(["X"], "Out")
#
#
# class TestReshapeOpDimInferInplace(OpTest):
# def setUp(self):
# self.op_type = "reshape"
# self.inputs = {'X': np.random.random((10, 20)).astype("float32")}
# self.attrs = {'shape': [4, -1, 5], 'inplace': True}
# self.outputs = {'Out': self.inputs['X'].reshape(self.attrs['shape'])}
#
# def test_check_output(self):
# self.check_output()
#
# def test_check_grad(self):
# self.check_grad(["X"], "Out")
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册