提交 66b84366 编写于 作者: S sweetsky0901

modify for code review by wangyi

上级 e553d572
......@@ -16,11 +16,9 @@
namespace paddle {
namespace operators {
using framework::Tensor;
class Unpool2dOpMaker : public framework::OpProtoAndCheckerMaker {
public:
Unpool2dOpMaker(framework::OpProto* proto, \
Unpool2dOpMaker(framework::OpProto* proto,
framework::OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X",
......@@ -38,26 +36,26 @@ class Unpool2dOpMaker : public framework::OpProtoAndCheckerMaker {
"the number of channels, H and W is the height and "
"width of feature.");
AddAttr<std::vector<int>>("ksize",
"(vector ), the unpooling window size(height, width) "
"(vector), the unpooling window size(height, width) "
"of unpooling operator.");
AddAttr<std::vector<int>>("strides",
"(vector, default:{1, 1}), "
"strides(height, width) of unpooling operator.")
"strides (height, width) of unpooling operator.")
.SetDefault({1, 1});
AddAttr<std::vector<int>>("paddings",
"(vector defalut:{0,0}), "
"paddings(height, width) of unpooling operator.")
"paddings (height, width) of unpooling operator.")
.SetDefault({0, 0});
AddAttr<std::string>("unpoolingtype",
"(string), unpooling type, can be \"max\" for max-unpooling ")
.InEnum({"max"});
AddComment(R"DOC(
"input: the input Tensor to invert"
"indices: the indices given out by MaxPool2d"
"ksize – Size of the max pooling window."
"stride – Stride of the max pooling window."
"It is set to kernel_size by default."
"padding – Padding that was added to the input"
"input: the input Tensor to invert
indices: the indices given out by MaxPool2d
ksize – Size of the max pooling window.
stride – Stride of the max pooling window.
"It is set to kernel_size by default.
padding – Padding that was added to the input"
)DOC");
}
};
......@@ -80,14 +78,14 @@ class UnpoolOp : public framework::OperatorWithKernel {
auto in_x_dims = ctx->GetInputDim("X");
auto in_y_dims = ctx->GetInputDim("Y");
std::string unpoolingtype = \
std::string unpoolingtype =
ctx->Attrs().Get<std::string>("unpoolingtype");
std::vector<int> ksize = ctx->Attrs().Get<std::vector<int>>("ksize");
std::vector<int> strides = ctx->Attrs().Get<std::vector<int>>("strides");
std::vector<int> paddings = ctx->Attrs().Get<std::vector<int>>("paddings");
PADDLE_ENFORCE(in_x_dims.size() == 4,
"Unpooling intput should be 4-D.");
"Unpooling intput must be of 4-dimensional.");
for (int i = 0; i < 4; ++i) {
PADDLE_ENFORCE(in_x_dims[i] == in_y_dims[i],
"X size must be eq Y size!");
......
......@@ -21,15 +21,13 @@ limitations under the License. */
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
template <typename Place, typename T>
class UnpoolKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
const Tensor* in_x = context.Input<Tensor>("X");
const Tensor* in_y = context.Input<Tensor>("Y");
auto * out = context.Output<Tensor>("Out");
const framework::Tensor* in_x = context.Input<framework::Tensor>("X");
const framework::Tensor* in_y = context.Input<framework::Tensor>("Y");
auto * out = context.Output<framework::Tensor>("Out");
std::string unpoolingtype = context.Attr<std::string>("unpoolingtype");
std::vector<int> ksize = context.Attr<std::vector<int>>("ksize");
std::vector<int> strides = context.Attr<std::vector<int>>("strides");
......@@ -39,15 +37,8 @@ class UnpoolKernel : public framework::OpKernel<T> {
math::SetConstant<Place, T> set_zero;
set_zero(context.device_context(), out, static_cast<T>(0));
}
switch (ksize.size()) {
case 2: {
if (unpoolingtype == "max") {
math::Unpool2dMaxFunctor<Place, T> unpool2d_max_forward;
unpool2d_max_forward(context.device_context(), *in_x, *in_y, out);
}
} break;
default: { PADDLE_THROW("Pool op only supports 2D input."); }
}
math::Unpool2dMaxFunctor<Place, T> unpool2d_max_forward;
unpool2d_max_forward(context.device_context(), *in_x, *in_y, out);
}
};
......@@ -55,12 +46,13 @@ template <typename Place, typename T>
class UnpoolGradKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
const Tensor* in_x = context.Input<Tensor>("X");
const Tensor* in_y = context.Input<Tensor>("Y");
const Tensor* out = context.Input<Tensor>("Out");
const Tensor* out_grad =
context.Input<Tensor>(framework::GradVarName("Out"));
Tensor* in_x_grad = context.Output<Tensor>(framework::GradVarName("X"));
const framework::Tensor* in_x = context.Input<framework::Tensor>("X");
const framework::Tensor* in_y = context.Input<framework::Tensor>("Y");
const framework::Tensor* out = context.Input<framework::Tensor>("Out");
const framework::Tensor* out_grad =
context.Input<framework::Tensor>(framework::GradVarName("Out"));
framework::Tensor* in_x_grad =
context.Output<framework::Tensor>(framework::GradVarName("X"));
std::string unpoolingtype = context.Attr<std::string>("unpoolingtype");
std::vector<int> ksize = context.Attr<std::vector<int>>("ksize");
std::vector<int> strides = context.Attr<std::vector<int>>("strides");
......@@ -70,18 +62,11 @@ class UnpoolGradKernel : public framework::OpKernel<T> {
math::SetConstant<Place, T> zero;
if (in_x_grad) {
in_x_grad->mutable_data<T>(context.GetPlace());
zero(device_ctx, in_x_grad, static_cast<T>(0.0));
}
switch (ksize.size()) {
case 2: {
if (unpoolingtype == "max") {
math::Unpool2dMaxGradFunctor<Place, T> unpool2d_max_backward;
unpool2d_max_backward(context.device_context(), *in_x, *in_y, in_x_grad,
*out, *out_grad);
}
} break;
default: { PADDLE_THROW("Unpool op only supports 2D input."); }
zero(device_ctx, in_x_grad, static_cast<T>(0));
}
math::Unpool2dMaxGradFunctor<Place, T> unpool2d_max_backward;
unpool2d_max_backward(context.device_context(), *in_x, *in_y, in_x_grad,
*out, *out_grad);
}
};
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册