Commit 899c7d6b authored by Yibing Liu

pass unit test

Parent 12eaa22a
......@@ -38,6 +38,7 @@ class ReshapeOp : public framework::OperatorWithKernel {
size_t in_size = framework::product(in->dims());
PADDLE_ENFORCE_EQ(shape_size, in_size,
"The size of Input(X) mismatches with Attr(shape).");
ctx.Output<framework::Tensor>("Out")->Resize(in->dims());
}
};
......@@ -51,7 +52,7 @@ class ReshapeOpMaker : public framework::OpProtoAndCheckerMaker {
AddAttr<std::vector<int>>("shape", "Target shape of reshape operator.");
AddComment(R"DOC(Reshape operator
The input tensor will be reshaped with Attr(shape).
Reshape Input(X) into the shape specified by Attr(shape).
)DOC");
}
};
......
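The `InferShape` change above enforces that the element count implied by `Attr(shape)` equals the element count of `Input(X)` before resizing the output. A minimal NumPy sketch of that constraint (illustrative only; `infer_reshape` is a made-up helper, not part of the Paddle C++ API):

```python
# Mirrors the PADDLE_ENFORCE_EQ(shape_size, in_size, ...) check in
# ReshapeOp::InferShape, using plain NumPy instead of framework::Tensor.
import numpy as np

def infer_reshape(x, shape):
    shape_size = int(np.prod(shape))
    in_size = x.size
    assert shape_size == in_size, \
        "The size of Input(X) mismatches with Attr(shape)."
    return x.reshape(shape)

x = np.random.random((2, 4)).astype("float32")
out = infer_reshape(x, [4, 2])   # OK: 2*4 == 4*2
# infer_reshape(x, [3, 3])       # would raise: 8 != 9
```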
......@@ -23,13 +23,13 @@ namespace operators {
using Tensor = framework::Tensor;
template <typename Place, typename T, typename AttrType = T>
template <typename Place, typename T>
class ReshapeKernel : public framework::OpKernel {
public:
void Compute(const framework::ExecutionContext& ctx) const {
auto* out = ctx.Output<Tensor>("Out");
auto* in = ctx.Input<Tensor>("X");
out->mutable_data<T>(in->place());
out->mutable_data<T>(ctx.GetPlace());
auto shape = ctx.Attr<std::vector<int>>("shape");
std::vector<int64_t> tmp;
......@@ -42,7 +42,7 @@ class ReshapeKernel : public framework::OpKernel {
}
};
template <typename Place, typename T, typename AttrType = T>
template <typename Place, typename T>
class ReshapeGradKernel : public framework::OpKernel {
public:
void Compute(const framework::ExecutionContext& ctx) const {
......@@ -51,7 +51,6 @@ class ReshapeGradKernel : public framework::OpKernel {
d_x->mutable_data<T>(ctx.GetPlace());
auto in_dims = d_x->dims();
d_x->CopyFrom<T>(*d_out, ctx.GetPlace());
d_x->Resize(in_dims);
}
......
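For reference, a rough NumPy analogue of the two kernels in this header: the forward kernel allocates `Out` on `ctx.GetPlace()`, copies `X`, and applies the target shape, while the grad kernel copies `dOut` back and restores the input dims. This is a sketch only; the real kernels work on `framework::Tensor` and device memory, and `reshape_forward` / `reshape_backward` are hypothetical names:

```python
# Illustrative NumPy analogue of ReshapeKernel / ReshapeGradKernel.
import numpy as np

def reshape_forward(x, shape):
    # ReshapeKernel: copy X into Out, then view Out with Attr(shape).
    return np.array(x, copy=True).reshape(shape)

def reshape_backward(d_out, x_dims):
    # ReshapeGradKernel: copy dOut into dX, then restore X's original dims.
    return np.array(d_out, copy=True).reshape(x_dims)

x = np.random.random((10, 20)).astype("float32")
out = reshape_forward(x, [5, 40])
d_x = reshape_backward(np.ones_like(out), x.shape)
assert d_x.shape == x.shape
```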
import unittest
import numpy as np
from gradient_checker import GradientChecker, create_op
from gradient_checker import GradientChecker, Operator
from op_test_util import OpTestMeta
......@@ -9,19 +9,16 @@ class TestReshapeOp(unittest.TestCase):
def setUp(self):
self.type = "reshape"
self.inputs = {'X': np.random.random((2, 4)).astype("float32"), }
print self.inputs
self.attrs = {'shape': [4, 2]}
self.inputs = {'X': np.random.random((37, 51)).astype("float32"), }
self.attrs = {'shape': [51, 37]}
self.outputs = {'Out': self.inputs['X'].reshape(self.attrs['shape'])}
print self.outputs
class ReshapeGradOpTest(GradientChecker):
def test_normal(self):
op = create_op("reshape")
inputs = {"X": np.random.random((2, 4)).astype("float32")}
attrs = {'shape': [4, 2]}
self.check_grad(op, inputs, attrs, set("X"), "Out")
op = Operator("reshape", X='X', Out='Out', shape=[5, 40])
inputs = {"X": np.random.random((10, 20)).astype("float32")}
self.check_grad(op, inputs, set("X"), "Out")
if __name__ == '__main__':
......
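The rewritten gradient test passes the target shape directly as an `Operator` attribute. As a standalone sanity check of what `check_grad` verifies for reshape, here is a hedged NumPy sketch (independent of the `GradientChecker` harness): since reshape only rearranges elements, the gradient of `sum(Out)` with respect to `X` is a tensor of ones with `X`'s shape. float64 is used so the finite differences stay numerically stable:

```python
# Central-difference check that d(sum(Out))/dX == ones(X.shape) for reshape.
import numpy as np

x = np.random.random((10, 20)).astype("float64")
shape = [5, 40]

def loss(x_flat):
    return x_flat.reshape(shape).sum()

eps = 1e-3
numeric = np.zeros_like(x)
for i in range(x.size):
    xp = x.copy().ravel(); xp[i] += eps
    xm = x.copy().ravel(); xm[i] -= eps
    numeric.ravel()[i] = (loss(xp) - loss(xm)) / (2 * eps)

analytic = np.ones_like(x)  # gradient of sum through a pure reshape
assert np.allclose(numeric, analytic, atol=1e-6)
```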