diff --git a/paddle/operators/math/unpooling.cc b/paddle/operators/math/unpooling.cc index d8647c6b236dfc79000f3b8d35efe662b08003f9..ab6212f3872d2b6583b5b1ff5aa9dce50b758994 100644 --- a/paddle/operators/math/unpooling.cc +++ b/paddle/operators/math/unpooling.cc @@ -19,8 +19,8 @@ namespace operators { namespace math { // All tensors are in NCHW format -template -class Unpool2dMaxFunctor { +template +class Unpool2dMaxFunctor { public: void operator()(const platform::DeviceContext& context, const framework::Tensor& input, @@ -35,7 +35,7 @@ class Unpool2dMaxFunctor { int input_feasize = input_height * input_width; int output_feasize = output_height * output_width; const T* input_data = input.data(); - const T * indices_data = indices.data(); + const T2 * indices_data = indices.data(); T* output_data = output->mutable_data(context.GetPlace()); for (int b = 0; b < batch_size; ++b) { for (int c = 0; c < output_channels; ++c) { @@ -54,8 +54,8 @@ class Unpool2dMaxFunctor { -template -class Unpool2dMaxGradFunctor { +template +class Unpool2dMaxGradFunctor { public: void operator()(const platform::DeviceContext& context, const framework::Tensor& input, @@ -71,7 +71,7 @@ public: const int output_width = output.dims()[3]; int input_feasize = input_height * input_width; int output_feasize = output_height * output_width; - const T* indices_data = indices.data(); + const T2 * indices_data = indices.data(); const T* output_grad_data = output_grad.data(); T* input_grad_data = input_grad->mutable_data(context.GetPlace()); @@ -90,10 +90,10 @@ public: } }; -template class Unpool2dMaxGradFunctor; -template class Unpool2dMaxGradFunctor; -template class Unpool2dMaxFunctor; -template class Unpool2dMaxFunctor; +template class Unpool2dMaxGradFunctor; +template class Unpool2dMaxGradFunctor; +template class Unpool2dMaxFunctor; +template class Unpool2dMaxFunctor; } // namespace math } // namespace operators diff --git a/paddle/operators/math/unpooling.cu b/paddle/operators/math/unpooling.cu index d3eaa48547ee381ed48434058482e255f287ca4c..c8fd58eca55318ae12348c7c7d173cbb51aabb25 100644 --- a/paddle/operators/math/unpooling.cu +++ b/paddle/operators/math/unpooling.cu @@ -19,10 +19,10 @@ namespace paddle { namespace operators { namespace math { -template +template __global__ void KernelUnpool2dMax(const int nthreads, const T* input_data, - const T* indices_data, + const T2 * indices_data, const int input_height, const int input_width, const int channels, @@ -45,10 +45,10 @@ __global__ void KernelUnpool2dMax(const int nthreads, output_data[out_offset + out_index] = input_data[i]; } } -template +template __global__ void KernelUnpool2dMaxGrad(const int nthreads, const T* input_data, - const T* indices_data, + const T2* indices_data, const int input_height, const int input_width, const int channels, @@ -76,8 +76,8 @@ __global__ void KernelUnpool2dMaxGrad(const int nthreads, /* * All tensors are in NCHW format. */ -template -class Unpool2dMaxFunctor { +template +class Unpool2dMaxFunctor { public: void operator()(const platform::DeviceContext& context, const framework::Tensor& input, @@ -90,7 +90,7 @@ class Unpool2dMaxFunctor { const int output_height = output->dims()[2]; const int output_width = output->dims()[3]; const T* input_data = input.data(); - const T* indices_data = indices.data(); + const T2 * indices_data = indices.data(); T* output_data = output->mutable_data(context.GetPlace()); int nthreads = batch_size * output_channels * input_height * input_width; int blocks = (nthreads + 1024 - 1) / 1024; @@ -98,7 +98,7 @@ class Unpool2dMaxFunctor { dim3 grid(blocks, 1); KernelUnpool2dMax< - T><<<<(context) .stream()>>>(nthreads, input_data, indices_data, input_height, input_width, output_channels, @@ -108,8 +108,8 @@ class Unpool2dMaxFunctor { /* * All tensors are in NCHW format. */ -template -class Unpool2dMaxGradFunctor { +template +class Unpool2dMaxGradFunctor { public: void operator()(const platform::DeviceContext& context, const framework::Tensor& input, @@ -124,7 +124,7 @@ class Unpool2dMaxGradFunctor { const int output_height = output.dims()[2]; const int output_width = output.dims()[3]; const T* input_data = input.data(); - const T* indices_data = indices.data(); + const T2 * indices_data = indices.data(); const T* output_data = output.data(); const T* output_grad_data = output_grad.data(); T* input_grad_data = input_grad->mutable_data(context.GetPlace()); @@ -134,7 +134,7 @@ class Unpool2dMaxGradFunctor { dim3 grid(blocks, 1); KernelUnpool2dMaxGrad< - T><<<<(context) .stream()>>>( nthreads, input_data, indices_data, @@ -145,11 +145,11 @@ class Unpool2dMaxGradFunctor { } }; -template class Unpool2dMaxGradFunctor; -template class Unpool2dMaxGradFunctor; +template class Unpool2dMaxGradFunctor; +template class Unpool2dMaxGradFunctor; -template class Unpool2dMaxFunctor; -template class Unpool2dMaxFunctor; +template class Unpool2dMaxFunctor; +template class Unpool2dMaxFunctor; } // namespace math } // namespace operators diff --git a/paddle/operators/math/unpooling.h b/paddle/operators/math/unpooling.h index bf79354ed927017e0ad802cfca832a3295c7309f..e086b891a1653668c507c87ea74b8b56146ed65d 100644 --- a/paddle/operators/math/unpooling.h +++ b/paddle/operators/math/unpooling.h @@ -19,7 +19,7 @@ namespace paddle { namespace operators { namespace math { -template +template class Unpool2dMaxFunctor { public: @@ -29,7 +29,7 @@ class Unpool2dMaxFunctor { framework::Tensor * output); }; -template +template class Unpool2dMaxGradFunctor { public: void operator()(const platform::DeviceContext& context, diff --git a/paddle/operators/unpool_op.cc b/paddle/operators/unpool_op.cc index ada9ce8ce5debb74b5500ee97839fd35707268d7..f00459cd8590886cb8ff13a3bdd3c6298c65fc70 100644 --- a/paddle/operators/unpool_op.cc +++ b/paddle/operators/unpool_op.cc @@ -66,7 +66,15 @@ int OutputSize(int input_size, int ksize, int padding, int stride) { } class UnpoolOp : public framework::OperatorWithKernel { - public: +protected: + framework::OpKernelType GetKernelType( + const framework::ExecutionContext& ctx) const override { + return framework::OpKernelType( + framework::ToDataType(ctx.Input("X")->type()), + ctx.device_context()); + } + +public: using framework::OperatorWithKernel::OperatorWithKernel; void InferShape(framework::InferShapeContext* ctx) const override { PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) of UnpoolOp" @@ -102,6 +110,14 @@ class UnpoolOp : public framework::OperatorWithKernel { }; class UnpoolOpGrad : public framework::OperatorWithKernel { + protected: + framework::OpKernelType GetKernelType( + const framework::ExecutionContext& ctx) const override { + return framework::OpKernelType( + framework::ToDataType(ctx.Input("X")->type()), + ctx.device_context()); + } + public: using framework::OperatorWithKernel::OperatorWithKernel; void InferShape(framework::InferShapeContext* ctx) const override { @@ -118,9 +134,9 @@ namespace ops = paddle::operators; REGISTER_OP(unpool, ops::UnpoolOp, ops::Unpool2dOpMaker, unpool_grad, ops::UnpoolOpGrad); REGISTER_OP_CPU_KERNEL(unpool, - ops::UnpoolKernel, - ops::UnpoolKernel); + ops::UnpoolKernel, + ops::UnpoolKernel); REGISTER_OP_CPU_KERNEL(unpool_grad, - ops::UnpoolGradKernel, - ops::UnpoolGradKernel); + ops::UnpoolGradKernel, + ops::UnpoolGradKernel); diff --git a/paddle/operators/unpool_op.cu.cc b/paddle/operators/unpool_op.cu.cc index 4949fc467e006598956f2c8b4d30fd5d2e8bf68b..0a1d8b5996de47faef50042911dcca72d5d8a337 100644 --- a/paddle/operators/unpool_op.cu.cc +++ b/paddle/operators/unpool_op.cu.cc @@ -16,10 +16,10 @@ namespace ops = paddle::operators; REGISTER_OP_GPU_KERNEL(unpool, - ops::UnpoolKernel, - ops::UnpoolKernel); + ops::UnpoolKernel, + ops::UnpoolKernel); REGISTER_OP_GPU_KERNEL(unpool_grad, ops::UnpoolGradKernel, + float, int>, ops::UnpoolGradKernel); + double, int>); diff --git a/paddle/operators/unpool_op.h b/paddle/operators/unpool_op.h index ae11a9f4f801a0725e033395a2307c95c601af5e..c2942211816f872a0f9456a84f648b788ac6f95d 100644 --- a/paddle/operators/unpool_op.h +++ b/paddle/operators/unpool_op.h @@ -21,7 +21,7 @@ limitations under the License. */ namespace paddle { namespace operators { -template +template class UnpoolKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& context) const override { @@ -37,12 +37,12 @@ class UnpoolKernel : public framework::OpKernel { math::SetConstant set_zero; set_zero(context.device_context(), out, static_cast(0)); } - math::Unpool2dMaxFunctor unpool2d_max_forward; + math::Unpool2dMaxFunctor unpool2d_max_forward; unpool2d_max_forward(context.device_context(), *in_x, *in_y, out); } }; -template +template class UnpoolGradKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& context) const override { @@ -64,7 +64,7 @@ class UnpoolGradKernel : public framework::OpKernel { in_x_grad->mutable_data(context.GetPlace()); zero(device_ctx, in_x_grad, static_cast(0)); } - math::Unpool2dMaxGradFunctor unpool2d_max_backward; + math::Unpool2dMaxGradFunctor unpool2d_max_backward; unpool2d_max_backward(context.device_context(), *in_x, *in_y, *out, *out_grad, in_x_grad); } diff --git a/python/paddle/v2/fluid/tests/test_unpool_op.py b/python/paddle/v2/fluid/tests/test_unpool_op.py index 106af9f5d91794c2db66a7e1b777f0afdbb81e41..3fdee9091fc68e7f8953453080c3e503161b3c29 100644 --- a/python/paddle/v2/fluid/tests/test_unpool_op.py +++ b/python/paddle/v2/fluid/tests/test_unpool_op.py @@ -53,7 +53,7 @@ class TestUnpoolOp(OpTest): output = self.Unpool2d_forward_naive(input, indices, self.ksize, \ self.strides, self.paddings).astype("float32") self.inputs = {'X': input.astype('float32'), - 'Y': indices.astype('int16')} + 'Y': indices.astype('int32')} self.attrs = { 'strides': self.strides, 'paddings': self.paddings,