提交 66b84366 编写于 作者: S sweetsky0901

modify for code review by wangyi

上级 e553d572
...@@ -16,11 +16,9 @@ ...@@ -16,11 +16,9 @@
namespace paddle { namespace paddle {
namespace operators { namespace operators {
using framework::Tensor;
class Unpool2dOpMaker : public framework::OpProtoAndCheckerMaker { class Unpool2dOpMaker : public framework::OpProtoAndCheckerMaker {
public: public:
Unpool2dOpMaker(framework::OpProto* proto, \ Unpool2dOpMaker(framework::OpProto* proto,
framework::OpAttrChecker* op_checker) framework::OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) { : OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", AddInput("X",
...@@ -38,26 +36,26 @@ class Unpool2dOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -38,26 +36,26 @@ class Unpool2dOpMaker : public framework::OpProtoAndCheckerMaker {
"the number of channels, H and W is the height and " "the number of channels, H and W is the height and "
"width of feature."); "width of feature.");
AddAttr<std::vector<int>>("ksize", AddAttr<std::vector<int>>("ksize",
"(vector ), the unpooling window size(height, width) " "(vector), the unpooling window size(height, width) "
"of unpooling operator."); "of unpooling operator.");
AddAttr<std::vector<int>>("strides", AddAttr<std::vector<int>>("strides",
"(vector, default:{1, 1}), " "(vector, default:{1, 1}), "
"strides(height, width) of unpooling operator.") "strides (height, width) of unpooling operator.")
.SetDefault({1, 1}); .SetDefault({1, 1});
AddAttr<std::vector<int>>("paddings", AddAttr<std::vector<int>>("paddings",
"(vector defalut:{0,0}), " "(vector defalut:{0,0}), "
"paddings(height, width) of unpooling operator.") "paddings (height, width) of unpooling operator.")
.SetDefault({0, 0}); .SetDefault({0, 0});
AddAttr<std::string>("unpoolingtype", AddAttr<std::string>("unpoolingtype",
"(string), unpooling type, can be \"max\" for max-unpooling ") "(string), unpooling type, can be \"max\" for max-unpooling ")
.InEnum({"max"}); .InEnum({"max"});
AddComment(R"DOC( AddComment(R"DOC(
"input: the input Tensor to invert" "input: the input Tensor to invert
"indices: the indices given out by MaxPool2d" indices: the indices given out by MaxPool2d
"ksize – Size of the max pooling window." ksize – Size of the max pooling window.
"stride – Stride of the max pooling window." stride – Stride of the max pooling window.
"It is set to kernel_size by default." "It is set to kernel_size by default.
"padding – Padding that was added to the input" padding – Padding that was added to the input"
)DOC"); )DOC");
} }
}; };
...@@ -80,14 +78,14 @@ class UnpoolOp : public framework::OperatorWithKernel { ...@@ -80,14 +78,14 @@ class UnpoolOp : public framework::OperatorWithKernel {
auto in_x_dims = ctx->GetInputDim("X"); auto in_x_dims = ctx->GetInputDim("X");
auto in_y_dims = ctx->GetInputDim("Y"); auto in_y_dims = ctx->GetInputDim("Y");
std::string unpoolingtype = \ std::string unpoolingtype =
ctx->Attrs().Get<std::string>("unpoolingtype"); ctx->Attrs().Get<std::string>("unpoolingtype");
std::vector<int> ksize = ctx->Attrs().Get<std::vector<int>>("ksize"); std::vector<int> ksize = ctx->Attrs().Get<std::vector<int>>("ksize");
std::vector<int> strides = ctx->Attrs().Get<std::vector<int>>("strides"); std::vector<int> strides = ctx->Attrs().Get<std::vector<int>>("strides");
std::vector<int> paddings = ctx->Attrs().Get<std::vector<int>>("paddings"); std::vector<int> paddings = ctx->Attrs().Get<std::vector<int>>("paddings");
PADDLE_ENFORCE(in_x_dims.size() == 4, PADDLE_ENFORCE(in_x_dims.size() == 4,
"Unpooling intput should be 4-D."); "Unpooling intput must be of 4-dimensional.");
for (int i = 0; i < 4; ++i) { for (int i = 0; i < 4; ++i) {
PADDLE_ENFORCE(in_x_dims[i] == in_y_dims[i], PADDLE_ENFORCE(in_x_dims[i] == in_y_dims[i],
"X size must be eq Y size!"); "X size must be eq Y size!");
......
...@@ -21,15 +21,13 @@ limitations under the License. */ ...@@ -21,15 +21,13 @@ limitations under the License. */
namespace paddle { namespace paddle {
namespace operators { namespace operators {
using Tensor = framework::Tensor;
template <typename Place, typename T> template <typename Place, typename T>
class UnpoolKernel : public framework::OpKernel<T> { class UnpoolKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& context) const override { void Compute(const framework::ExecutionContext& context) const override {
const Tensor* in_x = context.Input<Tensor>("X"); const framework::Tensor* in_x = context.Input<framework::Tensor>("X");
const Tensor* in_y = context.Input<Tensor>("Y"); const framework::Tensor* in_y = context.Input<framework::Tensor>("Y");
auto * out = context.Output<Tensor>("Out"); auto * out = context.Output<framework::Tensor>("Out");
std::string unpoolingtype = context.Attr<std::string>("unpoolingtype"); std::string unpoolingtype = context.Attr<std::string>("unpoolingtype");
std::vector<int> ksize = context.Attr<std::vector<int>>("ksize"); std::vector<int> ksize = context.Attr<std::vector<int>>("ksize");
std::vector<int> strides = context.Attr<std::vector<int>>("strides"); std::vector<int> strides = context.Attr<std::vector<int>>("strides");
...@@ -39,28 +37,22 @@ class UnpoolKernel : public framework::OpKernel<T> { ...@@ -39,28 +37,22 @@ class UnpoolKernel : public framework::OpKernel<T> {
math::SetConstant<Place, T> set_zero; math::SetConstant<Place, T> set_zero;
set_zero(context.device_context(), out, static_cast<T>(0)); set_zero(context.device_context(), out, static_cast<T>(0));
} }
switch (ksize.size()) {
case 2: {
if (unpoolingtype == "max") {
math::Unpool2dMaxFunctor<Place, T> unpool2d_max_forward; math::Unpool2dMaxFunctor<Place, T> unpool2d_max_forward;
unpool2d_max_forward(context.device_context(), *in_x, *in_y, out); unpool2d_max_forward(context.device_context(), *in_x, *in_y, out);
} }
} break;
default: { PADDLE_THROW("Pool op only supports 2D input."); }
}
}
}; };
template <typename Place, typename T> template <typename Place, typename T>
class UnpoolGradKernel : public framework::OpKernel<T> { class UnpoolGradKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& context) const override { void Compute(const framework::ExecutionContext& context) const override {
const Tensor* in_x = context.Input<Tensor>("X"); const framework::Tensor* in_x = context.Input<framework::Tensor>("X");
const Tensor* in_y = context.Input<Tensor>("Y"); const framework::Tensor* in_y = context.Input<framework::Tensor>("Y");
const Tensor* out = context.Input<Tensor>("Out"); const framework::Tensor* out = context.Input<framework::Tensor>("Out");
const Tensor* out_grad = const framework::Tensor* out_grad =
context.Input<Tensor>(framework::GradVarName("Out")); context.Input<framework::Tensor>(framework::GradVarName("Out"));
Tensor* in_x_grad = context.Output<Tensor>(framework::GradVarName("X")); framework::Tensor* in_x_grad =
context.Output<framework::Tensor>(framework::GradVarName("X"));
std::string unpoolingtype = context.Attr<std::string>("unpoolingtype"); std::string unpoolingtype = context.Attr<std::string>("unpoolingtype");
std::vector<int> ksize = context.Attr<std::vector<int>>("ksize"); std::vector<int> ksize = context.Attr<std::vector<int>>("ksize");
std::vector<int> strides = context.Attr<std::vector<int>>("strides"); std::vector<int> strides = context.Attr<std::vector<int>>("strides");
...@@ -70,19 +62,12 @@ class UnpoolGradKernel : public framework::OpKernel<T> { ...@@ -70,19 +62,12 @@ class UnpoolGradKernel : public framework::OpKernel<T> {
math::SetConstant<Place, T> zero; math::SetConstant<Place, T> zero;
if (in_x_grad) { if (in_x_grad) {
in_x_grad->mutable_data<T>(context.GetPlace()); in_x_grad->mutable_data<T>(context.GetPlace());
zero(device_ctx, in_x_grad, static_cast<T>(0.0)); zero(device_ctx, in_x_grad, static_cast<T>(0));
} }
switch (ksize.size()) {
case 2: {
if (unpoolingtype == "max") {
math::Unpool2dMaxGradFunctor<Place, T> unpool2d_max_backward; math::Unpool2dMaxGradFunctor<Place, T> unpool2d_max_backward;
unpool2d_max_backward(context.device_context(), *in_x, *in_y, in_x_grad, unpool2d_max_backward(context.device_context(), *in_x, *in_y, in_x_grad,
*out, *out_grad); *out, *out_grad);
} }
} break;
default: { PADDLE_THROW("Unpool op only supports 2D input."); }
}
}
}; };
} // namespace operators } // namespace operators
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册