提交 f9c2a5c3 编写于 作者: S sweetsky0901

modify for code review zcd

上级 022b48e1
...@@ -46,7 +46,7 @@ class Unpool2dOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -46,7 +46,7 @@ class Unpool2dOpMaker : public framework::OpProtoAndCheckerMaker {
"(vector defalut:{0,0}), " "(vector defalut:{0,0}), "
"paddings (height, width) of unpooling operator.") "paddings (height, width) of unpooling operator.")
.SetDefault({0, 0}); .SetDefault({0, 0});
AddAttr<std::string>("unpoolingtype", AddAttr<std::string>("unpooling_type",
"(string), unpooling type, can be \"max\" for max-unpooling ") "(string), unpooling type, can be \"max\" for max-unpooling ")
.InEnum({"max"}); .InEnum({"max"});
AddComment(R"DOC( AddComment(R"DOC(
...@@ -87,7 +87,7 @@ public: ...@@ -87,7 +87,7 @@ public:
auto in_x_dims = ctx->GetInputDim("X"); auto in_x_dims = ctx->GetInputDim("X");
auto in_y_dims = ctx->GetInputDim("Y"); auto in_y_dims = ctx->GetInputDim("Y");
std::string unpooling_type = std::string unpooling_type =
ctx->Attrs().Get<std::string>("unpoolingtype"); ctx->Attrs().Get<std::string>("unpooling_type");
std::vector<int> ksize = ctx->Attrs().Get<std::vector<int>>("ksize"); std::vector<int> ksize = ctx->Attrs().Get<std::vector<int>>("ksize");
std::vector<int> strides = ctx->Attrs().Get<std::vector<int>>("strides"); std::vector<int> strides = ctx->Attrs().Get<std::vector<int>>("strides");
std::vector<int> paddings = ctx->Attrs().Get<std::vector<int>>("paddings"); std::vector<int> paddings = ctx->Attrs().Get<std::vector<int>>("paddings");
......
...@@ -28,7 +28,7 @@ class UnpoolKernel : public framework::OpKernel<T> { ...@@ -28,7 +28,7 @@ class UnpoolKernel : public framework::OpKernel<T> {
const framework::Tensor* in_x = context.Input<framework::Tensor>("X"); const framework::Tensor* in_x = context.Input<framework::Tensor>("X");
const framework::Tensor* in_y = context.Input<framework::Tensor>("Y"); const framework::Tensor* in_y = context.Input<framework::Tensor>("Y");
auto * out = context.Output<framework::Tensor>("Out"); auto * out = context.Output<framework::Tensor>("Out");
std::string unpooling_type = context.Attr<std::string>("unpoolingtype"); std::string unpooling_type = context.Attr<std::string>("unpooling_type");
std::vector<int> ksize = context.Attr<std::vector<int>>("ksize"); std::vector<int> ksize = context.Attr<std::vector<int>>("ksize");
std::vector<int> strides = context.Attr<std::vector<int>>("strides"); std::vector<int> strides = context.Attr<std::vector<int>>("strides");
std::vector<int> paddings = context.Attr<std::vector<int>>("paddings"); std::vector<int> paddings = context.Attr<std::vector<int>>("paddings");
...@@ -53,7 +53,7 @@ class UnpoolGradKernel : public framework::OpKernel<T> { ...@@ -53,7 +53,7 @@ class UnpoolGradKernel : public framework::OpKernel<T> {
context.Input<framework::Tensor>(framework::GradVarName("Out")); context.Input<framework::Tensor>(framework::GradVarName("Out"));
framework::Tensor* in_x_grad = framework::Tensor* in_x_grad =
context.Output<framework::Tensor>(framework::GradVarName("X")); context.Output<framework::Tensor>(framework::GradVarName("X"));
std::string unpooling_type = context.Attr<std::string>("unpoolingtype"); std::string unpooling_type = context.Attr<std::string>("unpooling_type");
std::vector<int> ksize = context.Attr<std::vector<int>>("ksize"); std::vector<int> ksize = context.Attr<std::vector<int>>("ksize");
std::vector<int> strides = context.Attr<std::vector<int>>("strides"); std::vector<int> strides = context.Attr<std::vector<int>>("strides");
std::vector<int> paddings = context.Attr<std::vector<int>>("paddings"); std::vector<int> paddings = context.Attr<std::vector<int>>("paddings");
......
...@@ -58,7 +58,7 @@ class TestUnpoolOp(OpTest): ...@@ -58,7 +58,7 @@ class TestUnpoolOp(OpTest):
'strides': self.strides, 'strides': self.strides,
'paddings': self.paddings, 'paddings': self.paddings,
'ksize': self.ksize, 'ksize': self.ksize,
'unpoolingtype': self.unpoolingtype, 'unpooling_type': self.unpooling_type,
} }
self.outputs = {'Out': output.astype('float32')} self.outputs = {'Out': output.astype('float32')}
...@@ -70,7 +70,7 @@ class TestUnpoolOp(OpTest): ...@@ -70,7 +70,7 @@ class TestUnpoolOp(OpTest):
def init_test_case(self): def init_test_case(self):
self.Unpool2d_forward_naive = unpool2dmax_forward_naive self.Unpool2d_forward_naive = unpool2dmax_forward_naive
self.unpoolingtype = "max" self.unpooling_type = "max"
self.shape = [6, 4, 5, 5] self.shape = [6, 4, 5, 5]
self.ksize = [3, 3] self.ksize = [3, 3]
self.strides = [2, 2] self.strides = [2, 2]
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册