提交 cfd7721b 编写于 作者: S sweetsky0901

add unpool_op.h modify

上级 a38bbc86
...@@ -28,7 +28,7 @@ class UnpoolKernel : public framework::OpKernel<T> { ...@@ -28,7 +28,7 @@ class UnpoolKernel : public framework::OpKernel<T> {
const framework::Tensor* in_x = context.Input<framework::Tensor>("X"); const framework::Tensor* in_x = context.Input<framework::Tensor>("X");
const framework::Tensor* in_y = context.Input<framework::Tensor>("Y"); const framework::Tensor* in_y = context.Input<framework::Tensor>("Y");
auto * out = context.Output<framework::Tensor>("Out"); auto * out = context.Output<framework::Tensor>("Out");
std::string unpoolingtype = context.Attr<std::string>("unpoolingtype"); std::string unpooling_type = context.Attr<std::string>("unpoolingtype");
std::vector<int> ksize = context.Attr<std::vector<int>>("ksize"); std::vector<int> ksize = context.Attr<std::vector<int>>("ksize");
std::vector<int> strides = context.Attr<std::vector<int>>("strides"); std::vector<int> strides = context.Attr<std::vector<int>>("strides");
std::vector<int> paddings = context.Attr<std::vector<int>>("paddings"); std::vector<int> paddings = context.Attr<std::vector<int>>("paddings");
...@@ -53,7 +53,7 @@ class UnpoolGradKernel : public framework::OpKernel<T> { ...@@ -53,7 +53,7 @@ class UnpoolGradKernel : public framework::OpKernel<T> {
context.Input<framework::Tensor>(framework::GradVarName("Out")); context.Input<framework::Tensor>(framework::GradVarName("Out"));
framework::Tensor* in_x_grad = framework::Tensor* in_x_grad =
context.Output<framework::Tensor>(framework::GradVarName("X")); context.Output<framework::Tensor>(framework::GradVarName("X"));
std::string unpoolingtype = context.Attr<std::string>("unpoolingtype"); std::string unpooling_type = context.Attr<std::string>("unpoolingtype");
std::vector<int> ksize = context.Attr<std::vector<int>>("ksize"); std::vector<int> ksize = context.Attr<std::vector<int>>("ksize");
std::vector<int> strides = context.Attr<std::vector<int>>("strides"); std::vector<int> strides = context.Attr<std::vector<int>>("strides");
std::vector<int> paddings = context.Attr<std::vector<int>>("paddings"); std::vector<int> paddings = context.Attr<std::vector<int>>("paddings");
...@@ -65,8 +65,8 @@ class UnpoolGradKernel : public framework::OpKernel<T> { ...@@ -65,8 +65,8 @@ class UnpoolGradKernel : public framework::OpKernel<T> {
zero(device_ctx, in_x_grad, static_cast<T>(0)); zero(device_ctx, in_x_grad, static_cast<T>(0));
} }
math::Unpool2dMaxGradFunctor<Place, T> unpool2d_max_backward; math::Unpool2dMaxGradFunctor<Place, T> unpool2d_max_backward;
unpool2d_max_backward(context.device_context(), *in_x, *in_y, in_x_grad, unpool2d_max_backward(context.device_context(), *in_x, *in_y,
*out, *out_grad); *out, *out_grad, in_x_grad);
} }
}; };
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册