提交 899c7d6b 编写于 作者: Y Yibing Liu

pass unit test

上级 12eaa22a
...@@ -38,6 +38,7 @@ class ReshapeOp : public framework::OperatorWithKernel { ...@@ -38,6 +38,7 @@ class ReshapeOp : public framework::OperatorWithKernel {
size_t in_size = framework::product(in->dims()); size_t in_size = framework::product(in->dims());
PADDLE_ENFORCE_EQ(shape_size, in_size, PADDLE_ENFORCE_EQ(shape_size, in_size,
"The size of Input(X) mismatches with Attr(shape)."); "The size of Input(X) mismatches with Attr(shape).");
ctx.Output<framework::Tensor>("Out")->Resize(in->dims());
} }
}; };
...@@ -51,7 +52,7 @@ class ReshapeOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -51,7 +52,7 @@ class ReshapeOpMaker : public framework::OpProtoAndCheckerMaker {
AddAttr<std::vector<int>>("shape", "Target shape of reshape operator."); AddAttr<std::vector<int>>("shape", "Target shape of reshape operator.");
AddComment(R"DOC(Reshape operator AddComment(R"DOC(Reshape operator
The input tensor will be reshaped with Attr(shape). Reshape Input(X) into the shape specified by Attr(shape).
)DOC"); )DOC");
} }
}; };
......
...@@ -23,13 +23,13 @@ namespace operators { ...@@ -23,13 +23,13 @@ namespace operators {
using Tensor = framework::Tensor; using Tensor = framework::Tensor;
template <typename Place, typename T, typename AttrType = T> template <typename Place, typename T>
class ReshapeKernel : public framework::OpKernel { class ReshapeKernel : public framework::OpKernel {
public: public:
void Compute(const framework::ExecutionContext& ctx) const { void Compute(const framework::ExecutionContext& ctx) const {
auto* out = ctx.Output<Tensor>("Out"); auto* out = ctx.Output<Tensor>("Out");
auto* in = ctx.Input<Tensor>("X"); auto* in = ctx.Input<Tensor>("X");
out->mutable_data<T>(in->place()); out->mutable_data<T>(ctx.GetPlace());
auto shape = ctx.Attr<std::vector<int>>("shape"); auto shape = ctx.Attr<std::vector<int>>("shape");
std::vector<int64_t> tmp; std::vector<int64_t> tmp;
...@@ -42,7 +42,7 @@ class ReshapeKernel : public framework::OpKernel { ...@@ -42,7 +42,7 @@ class ReshapeKernel : public framework::OpKernel {
} }
}; };
template <typename Place, typename T, typename AttrType = T> template <typename Place, typename T>
class ReshapeGradKernel : public framework::OpKernel { class ReshapeGradKernel : public framework::OpKernel {
public: public:
void Compute(const framework::ExecutionContext& ctx) const { void Compute(const framework::ExecutionContext& ctx) const {
...@@ -51,7 +51,6 @@ class ReshapeGradKernel : public framework::OpKernel { ...@@ -51,7 +51,6 @@ class ReshapeGradKernel : public framework::OpKernel {
d_x->mutable_data<T>(ctx.GetPlace()); d_x->mutable_data<T>(ctx.GetPlace());
auto in_dims = d_x->dims(); auto in_dims = d_x->dims();
d_x->CopyFrom<T>(*d_out, ctx.GetPlace()); d_x->CopyFrom<T>(*d_out, ctx.GetPlace());
d_x->Resize(in_dims); d_x->Resize(in_dims);
} }
......
import unittest import unittest
import numpy as np import numpy as np
from gradient_checker import GradientChecker, create_op from gradient_checker import GradientChecker, Operator
from op_test_util import OpTestMeta from op_test_util import OpTestMeta
...@@ -9,19 +9,16 @@ class TestReshapeOp(unittest.TestCase): ...@@ -9,19 +9,16 @@ class TestReshapeOp(unittest.TestCase):
def setUp(self): def setUp(self):
self.type = "reshape" self.type = "reshape"
self.inputs = {'X': np.random.random((2, 4)).astype("float32"), } self.inputs = {'X': np.random.random((37, 51)).astype("float32"), }
print self.inputs self.attrs = {'shape': [51, 37]}
self.attrs = {'shape': [4, 2]}
self.outputs = {'Out': self.inputs['X'].reshape(self.attrs['shape'])} self.outputs = {'Out': self.inputs['X'].reshape(self.attrs['shape'])}
print self.outputs
class ReshapeGradOpTest(GradientChecker): class ReshapeGradOpTest(GradientChecker):
def test_normal(self): def test_normal(self):
op = create_op("reshape") op = Operator("reshape", X='X', Out='Out', shape=[5, 40])
inputs = {"X": np.random.random((2, 4)).astype("float32")} inputs = {"X": np.random.random((10, 20)).astype("float32")}
attrs = {'shape': [4, 2]} self.check_grad(op, inputs, set("X"), "Out")
self.check_grad(op, inputs, attrs, set("X"), "Out")
if __name__ == '__main__': if __name__ == '__main__':
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册