From 3bc305b6c17e5ea82cc4cad5fa7cfd9fb1a18cf5 Mon Sep 17 00:00:00 2001 From: lijianshe02 <48898730+lijianshe02@users.noreply.github.com> Date: Tue, 11 Jun 2019 10:51:07 +0800 Subject: [PATCH] add conv2d, pool2d op and kernels test=develop (#17939) * add conv2d, pool2d op and kernels test=develop --- paddle/fluid/lite/api/cxx_api_test.cc | 6 + paddle/fluid/lite/core/hvy_tensor.h | 3 + paddle/fluid/lite/kernels/x86/CMakeLists.txt | 4 + paddle/fluid/lite/kernels/x86/conv_compute.cc | 169 ++++++++++++++++++ paddle/fluid/lite/kernels/x86/pool_compute.cc | 80 +++++++++ paddle/fluid/lite/operators/CMakeLists.txt | 4 + paddle/fluid/lite/operators/conv_op.cc | 60 +++++++ paddle/fluid/lite/operators/conv_op.h | 94 ++++++++++ paddle/fluid/lite/operators/pool_op.cc | 78 ++++++++ paddle/fluid/lite/operators/pool_op.h | 64 +++++++ 10 files changed, 562 insertions(+) create mode 100644 paddle/fluid/lite/kernels/x86/conv_compute.cc create mode 100644 paddle/fluid/lite/kernels/x86/pool_compute.cc create mode 100644 paddle/fluid/lite/operators/conv_op.cc create mode 100644 paddle/fluid/lite/operators/conv_op.h create mode 100644 paddle/fluid/lite/operators/pool_op.cc create mode 100644 paddle/fluid/lite/operators/pool_op.h diff --git a/paddle/fluid/lite/api/cxx_api_test.cc b/paddle/fluid/lite/api/cxx_api_test.cc index c818f33e029..430bd9b58f8 100644 --- a/paddle/fluid/lite/api/cxx_api_test.cc +++ b/paddle/fluid/lite/api/cxx_api_test.cc @@ -131,6 +131,9 @@ USE_LITE_OP(square) USE_LITE_OP(softmax) USE_LITE_OP(dropout) USE_LITE_OP(concat) +USE_LITE_OP(conv2d) +USE_LITE_OP(depthwise_conv2d) +USE_LITE_OP(pool2d) USE_LITE_KERNEL(feed, kHost, kAny, kAny, def); USE_LITE_KERNEL(fetch, kHost, kAny, kAny, def); @@ -145,6 +148,9 @@ USE_LITE_KERNEL(elementwise_add, kX86, kFloat, kNCHW, def); USE_LITE_KERNEL(softmax, kX86, kFloat, kNCHW, def); USE_LITE_KERNEL(dropout, kX86, kFloat, kNCHW, def); USE_LITE_KERNEL(concat, kX86, kFloat, kNCHW, def); +USE_LITE_KERNEL(conv2d, kX86, kFloat, kNCHW, def); +USE_LITE_KERNEL(depthwise_conv2d, kX86, kFloat, kNCHW, def); +USE_LITE_KERNEL(pool2d, kX86, kFloat, kNCHW, def); #endif #ifdef LITE_WITH_CUDA diff --git a/paddle/fluid/lite/core/hvy_tensor.h b/paddle/fluid/lite/core/hvy_tensor.h index 21bfa2b48e2..16172a80035 100644 --- a/paddle/fluid/lite/core/hvy_tensor.h +++ b/paddle/fluid/lite/core/hvy_tensor.h @@ -110,6 +110,9 @@ class TensorHvy : public TensorBase { void ShareDataWith(const TensorHvy& other) { data_.ShareDataWith(other.data_); } + void ShareDataWith(const framework::Tensor& other) { + data_.ShareDataWith(other); + } void CopyDataFrom(const TensorHvy& other) { data_.mutable_data(other.data_.place(), other.data_.type()); TensorCopySync(other.data_, data_.place(), &data_); diff --git a/paddle/fluid/lite/kernels/x86/CMakeLists.txt b/paddle/fluid/lite/kernels/x86/CMakeLists.txt index 62db7a0a226..3747351d562 100644 --- a/paddle/fluid/lite/kernels/x86/CMakeLists.txt +++ b/paddle/fluid/lite/kernels/x86/CMakeLists.txt @@ -15,6 +15,8 @@ cc_library(elementwise_compute_x86 SRCS elementwise_compute.cc DEPS ${lite_kerne cc_library(softmax_compute_x86 SRCS softmax_compute.cc DEPS ${lite_kernel_deps} softmax) cc_library(dropout_compute_x86 SRCS dropout_compute.cc DEPS ${lite_kernel_deps} ) cc_library(concat_compute_x86 SRCS concat_compute.cc DEPS ${lite_kernel_deps} ) +cc_library(conv_compute_x86 SRCS conv_compute.cc DEPS ${lite_kernel_deps} blas im2col vol2col) +cc_library(pool_compute_x86 SRCS pool_compute.cc DEPS ${lite_kernel_deps} pooling) set(x86_kernels activation_compute_x86 @@ -28,6 +30,8 @@ set(x86_kernels softmax_compute_x86 dropout_compute_x86 concat_compute_x86 + conv_compute_x86 + pool_compute_x86 ) set(x86_kernels "${x86_kernels}" CACHE INTERNAL "x86 kernels") diff --git a/paddle/fluid/lite/kernels/x86/conv_compute.cc b/paddle/fluid/lite/kernels/x86/conv_compute.cc new file mode 100644 index 00000000000..9d2de5be452 --- /dev/null +++ b/paddle/fluid/lite/kernels/x86/conv_compute.cc @@ -0,0 +1,169 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include +#include "paddle/fluid/framework/eigen.h" +#include "paddle/fluid/lite/core/kernel.h" +#include "paddle/fluid/lite/core/op_registry.h" +#include "paddle/fluid/lite/core/types.h" +#include "paddle/fluid/lite/operators/conv_op.h" +#include "paddle/fluid/operators/math/blas.h" +#include "paddle/fluid/operators/math/depthwise_conv.h" +#include "paddle/fluid/operators/math/im2col.h" +#include "paddle/fluid/operators/math/vol2col.h" + +namespace paddle { +namespace lite { +namespace kernels { +namespace x86 { + +inline bool IsExpand(const std::vector& filter_dim, + const std::vector& strides, + const std::vector& paddings, + const std::vector& dilations) { + bool filter_1 = true, strides_1 = true, padding_0 = true, dilation_1 = true; + for (size_t j = 0; j < strides.size(); ++j) { + filter_1 = filter_1 && (static_cast(filter_dim[j + 2]) == 1); + strides_1 = strides_1 && (strides[j] == 1); + padding_0 = padding_0 && (paddings[j] == 0); + dilation_1 = dilation_1 && (dilations[j] == 1); + } + return !(filter_1 && strides_1 && padding_0 && dilation_1); +} + +template +class Conv2dCompute : public KernelLite { + public: + using param_t = operators::ConvParam; + void Run() override { + auto& param = *param_.get_mutable(); + lite::Tensor filter = *param.filter; + param.output->template mutable_data(); + + const int batch_size = static_cast(param.x->dims()[0]); + + std::vector filter_shape_vec(filter.dims().Vectorize()); + std::vector output_shape_vec(param.output->dims().Vectorize()); + + size_t data_dim = filter_shape_vec.size() - 2; + std::vector col_shape_vec(1 + 2 * data_dim); + col_shape_vec[0] = param.x->dims()[1] / param.groups; + for (size_t j = 0; j < data_dim; ++j) { + col_shape_vec[j + 1] = filter_shape_vec[j + 2]; + col_shape_vec[j + 1 + data_dim] = output_shape_vec[j + 2]; + } + lite::DDim col_shape(col_shape_vec); + lite::DDim col_matrix_shape = col_shape.Flattern2D(data_dim + 1); + bool is_expand = IsExpand(filter_shape_vec, param.strides, param.paddings, + param.dilations); + + lite::Tensor col; + lite::Tensor col_matrix; + if (is_expand) { + col.Resize(col_shape); + col_matrix.ShareDataWith(col); + col_matrix.Resize(col_matrix_shape); + } + lite::DDim input_shape = param.x->dims().Slice(1, param.x->dims().size()); + + lite::DDim filter_matrix_shape(std::vector{ + filter.dims()[0], filter.dims().production() / filter.dims()[0]}); + filter.Resize(filter_matrix_shape); + + lite::DDim output_matrix_shape(std::vector{ + param.output->dims()[1], + param.output->dims().production() / + (param.output->dims()[0] * param.output->dims()[1])}); + + int in_step = static_cast(param.x->dims()[1]) / param.groups; + int out_step = static_cast(param.output->dims()[1]) / param.groups; + + paddle::operators::math::Vol2ColFunctor + vol2col; + paddle::operators::math::Im2ColFunctor< + paddle::operators::math::ColFormat::kCFO, platform::CPUDeviceContext, T> + im2col; + auto blas = paddle::operators::math::GetBlas( + platform::CPUDeviceContext()); + for (int i = 0; i < batch_size; i++) { + lite::Tensor in_batch; + in_batch.ShareDataWith( + param.x->raw_tensor().Slice(i, i + 1).Resize(input_shape.data())); + lite::Tensor out_batch; + out_batch.ShareDataWith(param.output->raw_tensor().Slice(i, i + 1).Resize( + input_shape.data())); + + for (int g = 0; g < param.groups; g++) { + lite::Tensor in_slice; + in_slice.ShareDataWith( + in_batch.raw_tensor().Slice(g * in_step, (g + 1) * in_step)); + + if (!is_expand) { + col.ShareDataWith(in_slice); + col_matrix.ShareDataWith(col); + col_matrix.Resize(col_matrix_shape); + } else if (data_dim == 2U) { + // im2col + im2col(platform::CPUDeviceContext(), in_slice.raw_tensor(), + param.dilations, param.strides, + std::vector{param.paddings[0], param.paddings[1], + param.paddings[0], param.paddings[1]}, + &(col.raw_tensor())); + } else if (data_dim == 3U) { + // vol2col + vol2col(platform::CPUDeviceContext(), in_slice.raw_tensor(), + param.dilations, param.strides, param.paddings, + &(col.raw_tensor())); + } + + // gemm + lite::Tensor out_slice; + out_slice.ShareDataWith( + out_batch.raw_tensor().Slice(g * out_step, (g + 1) * out_step)); + lite::Tensor filter_slice; + filter_slice.ShareDataWith( + filter.raw_tensor().Slice(g * out_step, (g + 1) * out_step)); + blas.MatMul(filter_slice.raw_tensor(), false, col_matrix.raw_tensor(), + false, T(1.0), &(out_slice.raw_tensor()), T(0.0)); + } + } + } + + virtual ~Conv2dCompute() = default; +}; + +} // namespace x86 +} // namespace kernels +} // namespace lite +} // namespace paddle + +REGISTER_LITE_KERNEL(conv2d, kX86, kFloat, kNCHW, + paddle::lite::kernels::x86::Conv2dCompute, def) + .BindInput("Input", {LiteType::GetTensorTy(TARGET(kX86))}) + .BindInput("Filter", {LiteType::GetTensorTy(TARGET(kX86))}) + .BindInput("Bias", {LiteType::GetTensorTy(TARGET(kX86))}) + .BindInput("Input", {LiteType::GetTensorTy(TARGET(kX86))}) + .BindOutput("Output", {LiteType::GetTensorTy(TARGET(kX86))}) + .Finalize(); + +REGISTER_LITE_KERNEL(depthwise_conv2d, kX86, kFloat, kNCHW, + paddle::lite::kernels::x86::Conv2dCompute, def) + .BindInput("Input", {LiteType::GetTensorTy(TARGET(kX86))}) + .BindInput("Filter", {LiteType::GetTensorTy(TARGET(kX86))}) + .BindInput("Bias", {LiteType::GetTensorTy(TARGET(kX86))}) + .BindInput("Input", {LiteType::GetTensorTy(TARGET(kX86))}) + .BindOutput("Output", {LiteType::GetTensorTy(TARGET(kX86))}) + .Finalize(); diff --git a/paddle/fluid/lite/kernels/x86/pool_compute.cc b/paddle/fluid/lite/kernels/x86/pool_compute.cc new file mode 100644 index 00000000000..745c2a78789 --- /dev/null +++ b/paddle/fluid/lite/kernels/x86/pool_compute.cc @@ -0,0 +1,80 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include "paddle/fluid/framework/eigen.h" +#include "paddle/fluid/lite/core/kernel.h" +#include "paddle/fluid/lite/core/op_registry.h" +#include "paddle/fluid/lite/core/types.h" +#include "paddle/fluid/operators/math/math_function.h" +#include "paddle/fluid/operators/math/pooling.h" + +namespace paddle { +namespace lite { +namespace kernels { +namespace x86 { + +template +class PoolCompute : public KernelLite { + public: + using param_t = operators::PoolParam; + void Run() override { + auto& param = *param_.get_mutable(); + if (param.global_pooling) { + for (size_t i = 0; i < param.ksize.size(); ++i) { + param.paddings[i] = 0; + param.ksize[i] = static_cast(param.x->dims()[i + 2]); + } + } + switch (param.ksize.size()) { + case 2: { + if (param.pooling_type == "max") { + paddle::operators::math::Pool2dFunctor< + platform::CPUDeviceContext, paddle::operators::math::MaxPool, + T> + pool2d_forward; + paddle::operators::math::MaxPool pool_process; + pool2d_forward(platform::CPUDeviceContext(), param.x->raw_tensor(), + param.ksize, param.strides, param.paddings, + pool_process, true, false, + &(param.output->raw_tensor())); + } else if (param.pooling_type == "avg") { + paddle::operators::math::Pool2dFunctor< + platform::CPUDeviceContext, paddle::operators::math::AvgPool, + T> + pool2d_forward; + paddle::operators::math::AvgPool pool_process; + pool2d_forward(platform::CPUDeviceContext(), param.x->raw_tensor(), + param.ksize, param.strides, param.paddings, + pool_process, param.exclusive, param.adaptive, + &(param.output->raw_tensor())); + } + } break; + case 3: { + } break; + } + } + virtual ~PoolCompute() = default; +}; + +} // namespace x86 +} // namespace kernels +} // namespace lite +} // namespace paddle + +REGISTER_LITE_KERNEL(pool2d, kX86, kFloat, kNCHW, + paddle::lite::kernels::x86::PoolCompute, def) + .BindInput("X", {LiteType::GetTensorTy(TARGET(kX86))}) + .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kX86))}) + .Finalize(); diff --git a/paddle/fluid/lite/operators/CMakeLists.txt b/paddle/fluid/lite/operators/CMakeLists.txt index 8c855ed4610..ed26f5fdb1f 100644 --- a/paddle/fluid/lite/operators/CMakeLists.txt +++ b/paddle/fluid/lite/operators/CMakeLists.txt @@ -17,6 +17,8 @@ cc_library(fill_constant_op_lite SRCS fill_constant_op.cc DEPS ${op_DEPS}) cc_library(op_params_lite SRCS op_params.cc DEPS ${tensor_lite} any_lite framework_proto_lite) cc_library(dropout_op_lite SRCS dropout_op.cc DEPS ${op_DEPS}) cc_library(concat_op_lite SRCS concat_op.cc DEPS ${op_DEPS}) +cc_library(conv_op_lite SRCS conv_op.cc DEPS ${op_DEPS}) +cc_library(pool_op_lite SRCS pool_op.cc DEPS ${op_DEPS}) set(ops_lite fc_op_lite @@ -34,6 +36,8 @@ set(ops_lite activation_ops_lite dropout_op_lite concat_op_lite + conv_op_lite + pool_op_lite PARENT_SCOPE) lite_cc_test(test_fc_op_lite SRCS fc_op_test.cc diff --git a/paddle/fluid/lite/operators/conv_op.cc b/paddle/fluid/lite/operators/conv_op.cc new file mode 100644 index 00000000000..63838efd6fe --- /dev/null +++ b/paddle/fluid/lite/operators/conv_op.cc @@ -0,0 +1,60 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/lite/operators/conv_op.h" +#include +#include "paddle/fluid/lite/core/op_registry.h" + +namespace paddle { +namespace lite { +namespace operators { + +bool ConvOpLite::CheckShape() const { + CHECK_OR_FALSE(param_.x); + CHECK_OR_FALSE(param_.output); + CHECK_OR_FALSE(param_.filter); + return true; +} + +bool ConvOpLite::InferShape() const { + auto in_dims = param_.x->dims(); + auto filter_dims = param_.filter->dims(); + std::vector strides = param_.strides; + std::vector paddings = param_.paddings; + int groups = param_.groups; + std::vector dilations = param_.dilations; + + CHECK_OR_FALSE(in_dims.size() == 4 || in_dims.size() == 5); + CHECK_EQ_OR_FALSE(in_dims.size(), filter_dims.size()); + CHECK_OR_FALSE(in_dims.size() - strides.size() == 2U); + CHECK_EQ_OR_FALSE(paddings.size(), strides.size()); + CHECK_EQ_OR_FALSE(in_dims[1], filter_dims[1] * groups); + CHECK_EQ_OR_FALSE(filter_dims[0] % groups, 0); + + std::vector output_shape({in_dims[0], filter_dims[0]}); + for (size_t i = 0; i < strides.size(); ++i) { + output_shape.push_back(ConvOutputSize(in_dims[i + 2], filter_dims[i + 2], + dilations[i], paddings[i], + strides[i])); + } + param_.output->Resize(lite::DDim(output_shape)); + return true; +} + +} // namespace operators +} // namespace lite +} // namespace paddle + +REGISTER_LITE_OP(conv2d, paddle::lite::operators::ConvOpLite); +REGISTER_LITE_OP(depthwise_conv2d, paddle::lite::operators::ConvOpLite); diff --git a/paddle/fluid/lite/operators/conv_op.h b/paddle/fluid/lite/operators/conv_op.h new file mode 100644 index 00000000000..3ab30eb787b --- /dev/null +++ b/paddle/fluid/lite/operators/conv_op.h @@ -0,0 +1,94 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once +#include +#include +#include "paddle/fluid/lite/core/compatible_tensor.h" +#include "paddle/fluid/lite/core/kernel.h" +#include "paddle/fluid/lite/core/op_lite.h" +#include "paddle/fluid/lite/core/scope.h" +#include "paddle/fluid/lite/operators/op_params.h" +#include "paddle/fluid/lite/utils/all.h" + +namespace paddle { +namespace lite { +namespace operators { + +inline int ConvOutputSize(int input_size, int filter_size, int dilation, + int padding, int stride) { + const int dkernel = dilation * (filter_size - 1) + 1; + int output_size = (input_size + 2 * padding - dkernel) / stride + 1; + CHECK_OR_FALSE(output_size > 0); + + return output_size; +} + +inline bool IsExpand(const std::vector& filter_dim, + const std::vector& strides, + const std::vector& paddings, + const std::vector& dilations) { + bool filter_1 = true, strides_1 = true, padding_0 = true, dilation_1 = true; + for (size_t j = 0; j < strides.size(); ++j) { + filter_1 = filter_1 && (static_cast(filter_dim[j + 2]) == 1); + strides_1 = strides_1 && (strides[j] == 1); + padding_0 = padding_0 && (paddings[j] == 0); + dilation_1 = dilation_1 && (dilations[j] == 1); + } + return !(filter_1 && strides_1 && padding_0 && dilation_1); +} + +class ConvOpLite : public OpLite { + public: + ConvOpLite() {} + + explicit ConvOpLite(const std::string& type) : OpLite(type) {} + + bool CheckShape() const override; + + bool InferShape() const override; + + void AttachKernel(KernelBase* kernel) override { kernel->SetParam(param_); } + // TODO(Superjomn) replace framework::OpDesc with a lite one. + bool AttachImpl(const cpp::OpDesc& op_desc, lite::Scope* scope) override { + auto X = op_desc.Input("Input").front(); + auto Filter = op_desc.Input("Filter").front(); + auto Bias = op_desc.Input("Bias").front(); + // auto ResidualData = op_desc.Input("ResidualData"); + auto Out = op_desc.Output("Output").front(); + + param_.x = scope->FindVar(X)->GetMutable(); + param_.filter = scope->FindVar(Filter)->GetMutable(); + param_.bias = scope->FindVar(Bias)->GetMutable(); + // param_.residualData = + // scope->FindVar(ResidualData)->GetMutable(); + param_.output = scope->FindVar(Out)->GetMutable(); + + param_.strides = op_desc.GetAttr>("strides"); + param_.paddings = op_desc.GetAttr>("paddings"); + param_.groups = op_desc.GetAttr("groups"); + param_.dilations = op_desc.GetAttr>("dilations"); + + return true; + } + + std::string DebugString() const override { return "conv2d"; } + + private: + mutable ConvParam param_; +}; + +} // namespace operators +} // namespace lite +} // namespace paddle diff --git a/paddle/fluid/lite/operators/pool_op.cc b/paddle/fluid/lite/operators/pool_op.cc new file mode 100644 index 00000000000..055f00f90a4 --- /dev/null +++ b/paddle/fluid/lite/operators/pool_op.cc @@ -0,0 +1,78 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/lite/operators/pool_op.h" +#include "paddle/fluid/lite/core/op_registry.h" + +namespace paddle { +namespace lite { +namespace operators { + +int PoolOutputSize(int input_size, int filter_size, int padding, int stride, + bool ceil_mode) { + int output_size; + if (!ceil_mode) { + output_size = (input_size - filter_size + 2 * padding) / stride + 1; + } else { + output_size = + (input_size - filter_size + 2 * padding + stride - 1) / stride + 1; + } + CHECK_OR_FALSE(output_size > 0); + return output_size; +} + +bool PoolOpLite::CheckShape() const { + CHECK_OR_FALSE(param_.x); + CHECK_OR_FALSE(param_.output); + return true; +} + +bool PoolOpLite::InferShape() const { + const auto input_dims = param_.x->dims(); + CHECK_OR_FALSE(input_dims.size() == 4 || input_dims.size() == 5); + + if (param_.global_pooling) { + param_.ksize.resize(static_cast(input_dims.size()) - 2); + for (size_t i = 0; i < param_.ksize.size(); ++i) { + param_.paddings[i] = 0; + param_.ksize[i] = static_cast(input_dims[i + 2]); + } + } + + CHECK_OR_FALSE(input_dims.size() - param_.ksize.size() == 2U); + CHECK_EQ_OR_FALSE(param_.ksize.size(), param_.strides.size()); + CHECK_EQ_OR_FALSE(param_.ksize.size(), param_.paddings.size()); + + std::vector output_shape({input_dims[0], input_dims[1]}); + if (param_.adaptive) { + output_shape.insert(output_shape.end(), param_.ksize.begin(), + param_.ksize.end()); + } else { + for (size_t i = 0; i < param_.ksize.size(); ++i) { + output_shape.push_back( + PoolOutputSize(input_dims[i + 2], param_.ksize[i], param_.paddings[i], + param_.strides[i], param_.ceil_mode)); + } + } + // share LoD + // param_.output->set_lod(param_.input->lod()); + param_.output->Resize(lite::DDim(output_shape)); + return true; +} + +} // namespace operators +} // namespace lite +} // namespace paddle + +REGISTER_LITE_OP(pool2d, paddle::lite::operators::PoolOpLite); diff --git a/paddle/fluid/lite/operators/pool_op.h b/paddle/fluid/lite/operators/pool_op.h new file mode 100644 index 00000000000..64c15ccf1db --- /dev/null +++ b/paddle/fluid/lite/operators/pool_op.h @@ -0,0 +1,64 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once +#include +#include +#include "paddle/fluid/lite/core/kernel.h" +#include "paddle/fluid/lite/core/op_lite.h" +#include "paddle/fluid/lite/core/scope.h" +#include "paddle/fluid/lite/operators/op_params.h" +#include "paddle/fluid/lite/utils/all.h" + +namespace paddle { +namespace lite { +namespace operators { + +class PoolOpLite : public OpLite { + public: + PoolOpLite() {} + + explicit PoolOpLite(const std::string &type) : OpLite(type) {} + + bool CheckShape() const override; + + bool InferShape() const override; + + void AttachKernel(KernelBase *kernel) override { kernel->SetParam(param_); } + // TODO(Superjomn) replace framework::OpDesc with a lite one. + bool AttachImpl(const cpp::OpDesc &op_desc, lite::Scope *scope) override { + auto input = op_desc.Input("X").front(); + auto out = op_desc.Output("Out").front(); + + param_.x = scope->FindVar(input)->GetMutable(); + param_.output = scope->FindVar(out)->GetMutable(); + param_.pooling_type = op_desc.GetAttr("pooling_type"); + param_.ksize = op_desc.GetAttr>("ksize"); + param_.strides = op_desc.GetAttr>("strides"); + param_.paddings = op_desc.GetAttr>("paddings"); + param_.ceil_mode = op_desc.GetAttr("ceil_mode"); + param_.adaptive = op_desc.GetAttr("adaptive"); + param_.global_pooling = op_desc.GetAttr("global_pooling"); + return true; + } + + std::string DebugString() const override { return "pool"; } + + private: + mutable PoolParam param_; +}; + +} // namespace operators +} // namespace lite +} // namespace paddle -- GitLab