提交 fb110f72 编写于 作者: L lijianshe02 提交者: GitHub

add conv2d, pool2d op and kernels test=develop (#17939)

* add conv2d, pool2d op and kernels test=develop
上级 b1a63197
...@@ -131,6 +131,9 @@ USE_LITE_OP(square) ...@@ -131,6 +131,9 @@ USE_LITE_OP(square)
USE_LITE_OP(softmax) USE_LITE_OP(softmax)
USE_LITE_OP(dropout) USE_LITE_OP(dropout)
USE_LITE_OP(concat) USE_LITE_OP(concat)
USE_LITE_OP(conv2d)
USE_LITE_OP(depthwise_conv2d)
USE_LITE_OP(pool2d)
USE_LITE_KERNEL(feed, kHost, kAny, kAny, def); USE_LITE_KERNEL(feed, kHost, kAny, kAny, def);
USE_LITE_KERNEL(fetch, kHost, kAny, kAny, def); USE_LITE_KERNEL(fetch, kHost, kAny, kAny, def);
...@@ -145,6 +148,9 @@ USE_LITE_KERNEL(elementwise_add, kX86, kFloat, kNCHW, def); ...@@ -145,6 +148,9 @@ USE_LITE_KERNEL(elementwise_add, kX86, kFloat, kNCHW, def);
USE_LITE_KERNEL(softmax, kX86, kFloat, kNCHW, def); USE_LITE_KERNEL(softmax, kX86, kFloat, kNCHW, def);
USE_LITE_KERNEL(dropout, kX86, kFloat, kNCHW, def); USE_LITE_KERNEL(dropout, kX86, kFloat, kNCHW, def);
USE_LITE_KERNEL(concat, kX86, kFloat, kNCHW, def); USE_LITE_KERNEL(concat, kX86, kFloat, kNCHW, def);
USE_LITE_KERNEL(conv2d, kX86, kFloat, kNCHW, def);
USE_LITE_KERNEL(depthwise_conv2d, kX86, kFloat, kNCHW, def);
USE_LITE_KERNEL(pool2d, kX86, kFloat, kNCHW, def);
#endif #endif
#ifdef LITE_WITH_CUDA #ifdef LITE_WITH_CUDA
......
...@@ -110,6 +110,9 @@ class TensorHvy : public TensorBase<TensorHvy> { ...@@ -110,6 +110,9 @@ class TensorHvy : public TensorBase<TensorHvy> {
void ShareDataWith(const TensorHvy& other) { void ShareDataWith(const TensorHvy& other) {
data_.ShareDataWith(other.data_); data_.ShareDataWith(other.data_);
} }
void ShareDataWith(const framework::Tensor& other) {
data_.ShareDataWith(other);
}
void CopyDataFrom(const TensorHvy& other) { void CopyDataFrom(const TensorHvy& other) {
data_.mutable_data(other.data_.place(), other.data_.type()); data_.mutable_data(other.data_.place(), other.data_.type());
TensorCopySync(other.data_, data_.place(), &data_); TensorCopySync(other.data_, data_.place(), &data_);
......
...@@ -15,6 +15,8 @@ cc_library(elementwise_compute_x86 SRCS elementwise_compute.cc DEPS ${lite_kerne ...@@ -15,6 +15,8 @@ cc_library(elementwise_compute_x86 SRCS elementwise_compute.cc DEPS ${lite_kerne
cc_library(softmax_compute_x86 SRCS softmax_compute.cc DEPS ${lite_kernel_deps} softmax) cc_library(softmax_compute_x86 SRCS softmax_compute.cc DEPS ${lite_kernel_deps} softmax)
cc_library(dropout_compute_x86 SRCS dropout_compute.cc DEPS ${lite_kernel_deps} ) cc_library(dropout_compute_x86 SRCS dropout_compute.cc DEPS ${lite_kernel_deps} )
cc_library(concat_compute_x86 SRCS concat_compute.cc DEPS ${lite_kernel_deps} ) cc_library(concat_compute_x86 SRCS concat_compute.cc DEPS ${lite_kernel_deps} )
cc_library(conv_compute_x86 SRCS conv_compute.cc DEPS ${lite_kernel_deps} blas im2col vol2col)
cc_library(pool_compute_x86 SRCS pool_compute.cc DEPS ${lite_kernel_deps} pooling)
set(x86_kernels set(x86_kernels
activation_compute_x86 activation_compute_x86
...@@ -28,6 +30,8 @@ set(x86_kernels ...@@ -28,6 +30,8 @@ set(x86_kernels
softmax_compute_x86 softmax_compute_x86
dropout_compute_x86 dropout_compute_x86
concat_compute_x86 concat_compute_x86
conv_compute_x86
pool_compute_x86
) )
set(x86_kernels "${x86_kernels}" CACHE INTERNAL "x86 kernels") set(x86_kernels "${x86_kernels}" CACHE INTERNAL "x86 kernels")
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <Eigen/Core>
#include <string>
#include <vector>
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/lite/core/kernel.h"
#include "paddle/fluid/lite/core/op_registry.h"
#include "paddle/fluid/lite/core/types.h"
#include "paddle/fluid/lite/operators/conv_op.h"
#include "paddle/fluid/operators/math/blas.h"
#include "paddle/fluid/operators/math/depthwise_conv.h"
#include "paddle/fluid/operators/math/im2col.h"
#include "paddle/fluid/operators/math/vol2col.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace x86 {
inline bool IsExpand(const std::vector<int64_t>& filter_dim,
const std::vector<int>& strides,
const std::vector<int>& paddings,
const std::vector<int>& dilations) {
bool filter_1 = true, strides_1 = true, padding_0 = true, dilation_1 = true;
for (size_t j = 0; j < strides.size(); ++j) {
filter_1 = filter_1 && (static_cast<int>(filter_dim[j + 2]) == 1);
strides_1 = strides_1 && (strides[j] == 1);
padding_0 = padding_0 && (paddings[j] == 0);
dilation_1 = dilation_1 && (dilations[j] == 1);
}
return !(filter_1 && strides_1 && padding_0 && dilation_1);
}
template <typename T>
class Conv2dCompute : public KernelLite<TARGET(kX86), PRECISION(kFloat)> {
public:
using param_t = operators::ConvParam;
void Run() override {
auto& param = *param_.get_mutable<operators::ConvParam>();
lite::Tensor filter = *param.filter;
param.output->template mutable_data<T>();
const int batch_size = static_cast<int>(param.x->dims()[0]);
std::vector<int64_t> filter_shape_vec(filter.dims().Vectorize());
std::vector<int64_t> output_shape_vec(param.output->dims().Vectorize());
size_t data_dim = filter_shape_vec.size() - 2;
std::vector<int64_t> col_shape_vec(1 + 2 * data_dim);
col_shape_vec[0] = param.x->dims()[1] / param.groups;
for (size_t j = 0; j < data_dim; ++j) {
col_shape_vec[j + 1] = filter_shape_vec[j + 2];
col_shape_vec[j + 1 + data_dim] = output_shape_vec[j + 2];
}
lite::DDim col_shape(col_shape_vec);
lite::DDim col_matrix_shape = col_shape.Flattern2D(data_dim + 1);
bool is_expand = IsExpand(filter_shape_vec, param.strides, param.paddings,
param.dilations);
lite::Tensor col;
lite::Tensor col_matrix;
if (is_expand) {
col.Resize(col_shape);
col_matrix.ShareDataWith(col);
col_matrix.Resize(col_matrix_shape);
}
lite::DDim input_shape = param.x->dims().Slice(1, param.x->dims().size());
lite::DDim filter_matrix_shape(std::vector<int64_t>{
filter.dims()[0], filter.dims().production() / filter.dims()[0]});
filter.Resize(filter_matrix_shape);
lite::DDim output_matrix_shape(std::vector<int64_t>{
param.output->dims()[1],
param.output->dims().production() /
(param.output->dims()[0] * param.output->dims()[1])});
int in_step = static_cast<int>(param.x->dims()[1]) / param.groups;
int out_step = static_cast<int>(param.output->dims()[1]) / param.groups;
paddle::operators::math::Vol2ColFunctor<platform::CPUDeviceContext, T>
vol2col;
paddle::operators::math::Im2ColFunctor<
paddle::operators::math::ColFormat::kCFO, platform::CPUDeviceContext, T>
im2col;
auto blas = paddle::operators::math::GetBlas<platform::CPUDeviceContext, T>(
platform::CPUDeviceContext());
for (int i = 0; i < batch_size; i++) {
lite::Tensor in_batch;
in_batch.ShareDataWith(
param.x->raw_tensor().Slice(i, i + 1).Resize(input_shape.data()));
lite::Tensor out_batch;
out_batch.ShareDataWith(param.output->raw_tensor().Slice(i, i + 1).Resize(
input_shape.data()));
for (int g = 0; g < param.groups; g++) {
lite::Tensor in_slice;
in_slice.ShareDataWith(
in_batch.raw_tensor().Slice(g * in_step, (g + 1) * in_step));
if (!is_expand) {
col.ShareDataWith(in_slice);
col_matrix.ShareDataWith(col);
col_matrix.Resize(col_matrix_shape);
} else if (data_dim == 2U) {
// im2col
im2col(platform::CPUDeviceContext(), in_slice.raw_tensor(),
param.dilations, param.strides,
std::vector<int>{param.paddings[0], param.paddings[1],
param.paddings[0], param.paddings[1]},
&(col.raw_tensor()));
} else if (data_dim == 3U) {
// vol2col
vol2col(platform::CPUDeviceContext(), in_slice.raw_tensor(),
param.dilations, param.strides, param.paddings,
&(col.raw_tensor()));
}
// gemm
lite::Tensor out_slice;
out_slice.ShareDataWith(
out_batch.raw_tensor().Slice(g * out_step, (g + 1) * out_step));
lite::Tensor filter_slice;
filter_slice.ShareDataWith(
filter.raw_tensor().Slice(g * out_step, (g + 1) * out_step));
blas.MatMul(filter_slice.raw_tensor(), false, col_matrix.raw_tensor(),
false, T(1.0), &(out_slice.raw_tensor()), T(0.0));
}
}
}
virtual ~Conv2dCompute() = default;
};
} // namespace x86
} // namespace kernels
} // namespace lite
} // namespace paddle
REGISTER_LITE_KERNEL(conv2d, kX86, kFloat, kNCHW,
paddle::lite::kernels::x86::Conv2dCompute<float>, def)
.BindInput("Input", {LiteType::GetTensorTy(TARGET(kX86))})
.BindInput("Filter", {LiteType::GetTensorTy(TARGET(kX86))})
.BindInput("Bias", {LiteType::GetTensorTy(TARGET(kX86))})
.BindInput("Input", {LiteType::GetTensorTy(TARGET(kX86))})
.BindOutput("Output", {LiteType::GetTensorTy(TARGET(kX86))})
.Finalize();
REGISTER_LITE_KERNEL(depthwise_conv2d, kX86, kFloat, kNCHW,
paddle::lite::kernels::x86::Conv2dCompute<float>, def)
.BindInput("Input", {LiteType::GetTensorTy(TARGET(kX86))})
.BindInput("Filter", {LiteType::GetTensorTy(TARGET(kX86))})
.BindInput("Bias", {LiteType::GetTensorTy(TARGET(kX86))})
.BindInput("Input", {LiteType::GetTensorTy(TARGET(kX86))})
.BindOutput("Output", {LiteType::GetTensorTy(TARGET(kX86))})
.Finalize();
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <Eigen/Core>
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/lite/core/kernel.h"
#include "paddle/fluid/lite/core/op_registry.h"
#include "paddle/fluid/lite/core/types.h"
#include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/operators/math/pooling.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace x86 {
template <typename T>
class PoolCompute : public KernelLite<TARGET(kX86), PRECISION(kFloat)> {
public:
using param_t = operators::PoolParam;
void Run() override {
auto& param = *param_.get_mutable<param_t>();
if (param.global_pooling) {
for (size_t i = 0; i < param.ksize.size(); ++i) {
param.paddings[i] = 0;
param.ksize[i] = static_cast<int>(param.x->dims()[i + 2]);
}
}
switch (param.ksize.size()) {
case 2: {
if (param.pooling_type == "max") {
paddle::operators::math::Pool2dFunctor<
platform::CPUDeviceContext, paddle::operators::math::MaxPool<T>,
T>
pool2d_forward;
paddle::operators::math::MaxPool<T> pool_process;
pool2d_forward(platform::CPUDeviceContext(), param.x->raw_tensor(),
param.ksize, param.strides, param.paddings,
pool_process, true, false,
&(param.output->raw_tensor()));
} else if (param.pooling_type == "avg") {
paddle::operators::math::Pool2dFunctor<
platform::CPUDeviceContext, paddle::operators::math::AvgPool<T>,
T>
pool2d_forward;
paddle::operators::math::AvgPool<T> pool_process;
pool2d_forward(platform::CPUDeviceContext(), param.x->raw_tensor(),
param.ksize, param.strides, param.paddings,
pool_process, param.exclusive, param.adaptive,
&(param.output->raw_tensor()));
}
} break;
case 3: {
} break;
}
}
virtual ~PoolCompute() = default;
};
} // namespace x86
} // namespace kernels
} // namespace lite
} // namespace paddle
REGISTER_LITE_KERNEL(pool2d, kX86, kFloat, kNCHW,
paddle::lite::kernels::x86::PoolCompute<float>, def)
.BindInput("X", {LiteType::GetTensorTy(TARGET(kX86))})
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kX86))})
.Finalize();
...@@ -17,6 +17,8 @@ cc_library(fill_constant_op_lite SRCS fill_constant_op.cc DEPS ${op_DEPS}) ...@@ -17,6 +17,8 @@ cc_library(fill_constant_op_lite SRCS fill_constant_op.cc DEPS ${op_DEPS})
cc_library(op_params_lite SRCS op_params.cc DEPS ${tensor_lite} any_lite framework_proto_lite) cc_library(op_params_lite SRCS op_params.cc DEPS ${tensor_lite} any_lite framework_proto_lite)
cc_library(dropout_op_lite SRCS dropout_op.cc DEPS ${op_DEPS}) cc_library(dropout_op_lite SRCS dropout_op.cc DEPS ${op_DEPS})
cc_library(concat_op_lite SRCS concat_op.cc DEPS ${op_DEPS}) cc_library(concat_op_lite SRCS concat_op.cc DEPS ${op_DEPS})
cc_library(conv_op_lite SRCS conv_op.cc DEPS ${op_DEPS})
cc_library(pool_op_lite SRCS pool_op.cc DEPS ${op_DEPS})
set(ops_lite set(ops_lite
fc_op_lite fc_op_lite
...@@ -34,6 +36,8 @@ set(ops_lite ...@@ -34,6 +36,8 @@ set(ops_lite
activation_ops_lite activation_ops_lite
dropout_op_lite dropout_op_lite
concat_op_lite concat_op_lite
conv_op_lite
pool_op_lite
PARENT_SCOPE) PARENT_SCOPE)
lite_cc_test(test_fc_op_lite SRCS fc_op_test.cc lite_cc_test(test_fc_op_lite SRCS fc_op_test.cc
......
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/lite/operators/conv_op.h"
#include <vector>
#include "paddle/fluid/lite/core/op_registry.h"
namespace paddle {
namespace lite {
namespace operators {
bool ConvOpLite::CheckShape() const {
CHECK_OR_FALSE(param_.x);
CHECK_OR_FALSE(param_.output);
CHECK_OR_FALSE(param_.filter);
return true;
}
bool ConvOpLite::InferShape() const {
auto in_dims = param_.x->dims();
auto filter_dims = param_.filter->dims();
std::vector<int> strides = param_.strides;
std::vector<int> paddings = param_.paddings;
int groups = param_.groups;
std::vector<int> dilations = param_.dilations;
CHECK_OR_FALSE(in_dims.size() == 4 || in_dims.size() == 5);
CHECK_EQ_OR_FALSE(in_dims.size(), filter_dims.size());
CHECK_OR_FALSE(in_dims.size() - strides.size() == 2U);
CHECK_EQ_OR_FALSE(paddings.size(), strides.size());
CHECK_EQ_OR_FALSE(in_dims[1], filter_dims[1] * groups);
CHECK_EQ_OR_FALSE(filter_dims[0] % groups, 0);
std::vector<int64_t> output_shape({in_dims[0], filter_dims[0]});
for (size_t i = 0; i < strides.size(); ++i) {
output_shape.push_back(ConvOutputSize(in_dims[i + 2], filter_dims[i + 2],
dilations[i], paddings[i],
strides[i]));
}
param_.output->Resize(lite::DDim(output_shape));
return true;
}
} // namespace operators
} // namespace lite
} // namespace paddle
REGISTER_LITE_OP(conv2d, paddle::lite::operators::ConvOpLite);
REGISTER_LITE_OP(depthwise_conv2d, paddle::lite::operators::ConvOpLite);
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <string>
#include <vector>
#include "paddle/fluid/lite/core/compatible_tensor.h"
#include "paddle/fluid/lite/core/kernel.h"
#include "paddle/fluid/lite/core/op_lite.h"
#include "paddle/fluid/lite/core/scope.h"
#include "paddle/fluid/lite/operators/op_params.h"
#include "paddle/fluid/lite/utils/all.h"
namespace paddle {
namespace lite {
namespace operators {
inline int ConvOutputSize(int input_size, int filter_size, int dilation,
int padding, int stride) {
const int dkernel = dilation * (filter_size - 1) + 1;
int output_size = (input_size + 2 * padding - dkernel) / stride + 1;
CHECK_OR_FALSE(output_size > 0);
return output_size;
}
inline bool IsExpand(const std::vector<int64_t>& filter_dim,
const std::vector<int>& strides,
const std::vector<int>& paddings,
const std::vector<int>& dilations) {
bool filter_1 = true, strides_1 = true, padding_0 = true, dilation_1 = true;
for (size_t j = 0; j < strides.size(); ++j) {
filter_1 = filter_1 && (static_cast<int>(filter_dim[j + 2]) == 1);
strides_1 = strides_1 && (strides[j] == 1);
padding_0 = padding_0 && (paddings[j] == 0);
dilation_1 = dilation_1 && (dilations[j] == 1);
}
return !(filter_1 && strides_1 && padding_0 && dilation_1);
}
class ConvOpLite : public OpLite {
public:
ConvOpLite() {}
explicit ConvOpLite(const std::string& type) : OpLite(type) {}
bool CheckShape() const override;
bool InferShape() const override;
void AttachKernel(KernelBase* kernel) override { kernel->SetParam(param_); }
// TODO(Superjomn) replace framework::OpDesc with a lite one.
bool AttachImpl(const cpp::OpDesc& op_desc, lite::Scope* scope) override {
auto X = op_desc.Input("Input").front();
auto Filter = op_desc.Input("Filter").front();
auto Bias = op_desc.Input("Bias").front();
// auto ResidualData = op_desc.Input("ResidualData");
auto Out = op_desc.Output("Output").front();
param_.x = scope->FindVar(X)->GetMutable<lite::Tensor>();
param_.filter = scope->FindVar(Filter)->GetMutable<lite::Tensor>();
param_.bias = scope->FindVar(Bias)->GetMutable<lite::Tensor>();
// param_.residualData =
// scope->FindVar(ResidualData)->GetMutable<lite::Tensor>();
param_.output = scope->FindVar(Out)->GetMutable<lite::Tensor>();
param_.strides = op_desc.GetAttr<std::vector<int>>("strides");
param_.paddings = op_desc.GetAttr<std::vector<int>>("paddings");
param_.groups = op_desc.GetAttr<int>("groups");
param_.dilations = op_desc.GetAttr<std::vector<int>>("dilations");
return true;
}
std::string DebugString() const override { return "conv2d"; }
private:
mutable ConvParam param_;
};
} // namespace operators
} // namespace lite
} // namespace paddle
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/lite/operators/pool_op.h"
#include "paddle/fluid/lite/core/op_registry.h"
namespace paddle {
namespace lite {
namespace operators {
int PoolOutputSize(int input_size, int filter_size, int padding, int stride,
bool ceil_mode) {
int output_size;
if (!ceil_mode) {
output_size = (input_size - filter_size + 2 * padding) / stride + 1;
} else {
output_size =
(input_size - filter_size + 2 * padding + stride - 1) / stride + 1;
}
CHECK_OR_FALSE(output_size > 0);
return output_size;
}
bool PoolOpLite::CheckShape() const {
CHECK_OR_FALSE(param_.x);
CHECK_OR_FALSE(param_.output);
return true;
}
bool PoolOpLite::InferShape() const {
const auto input_dims = param_.x->dims();
CHECK_OR_FALSE(input_dims.size() == 4 || input_dims.size() == 5);
if (param_.global_pooling) {
param_.ksize.resize(static_cast<size_t>(input_dims.size()) - 2);
for (size_t i = 0; i < param_.ksize.size(); ++i) {
param_.paddings[i] = 0;
param_.ksize[i] = static_cast<int>(input_dims[i + 2]);
}
}
CHECK_OR_FALSE(input_dims.size() - param_.ksize.size() == 2U);
CHECK_EQ_OR_FALSE(param_.ksize.size(), param_.strides.size());
CHECK_EQ_OR_FALSE(param_.ksize.size(), param_.paddings.size());
std::vector<int64_t> output_shape({input_dims[0], input_dims[1]});
if (param_.adaptive) {
output_shape.insert(output_shape.end(), param_.ksize.begin(),
param_.ksize.end());
} else {
for (size_t i = 0; i < param_.ksize.size(); ++i) {
output_shape.push_back(
PoolOutputSize(input_dims[i + 2], param_.ksize[i], param_.paddings[i],
param_.strides[i], param_.ceil_mode));
}
}
// share LoD
// param_.output->set_lod(param_.input->lod());
param_.output->Resize(lite::DDim(output_shape));
return true;
}
} // namespace operators
} // namespace lite
} // namespace paddle
REGISTER_LITE_OP(pool2d, paddle::lite::operators::PoolOpLite);
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <string>
#include <vector>
#include "paddle/fluid/lite/core/kernel.h"
#include "paddle/fluid/lite/core/op_lite.h"
#include "paddle/fluid/lite/core/scope.h"
#include "paddle/fluid/lite/operators/op_params.h"
#include "paddle/fluid/lite/utils/all.h"
namespace paddle {
namespace lite {
namespace operators {
class PoolOpLite : public OpLite {
public:
PoolOpLite() {}
explicit PoolOpLite(const std::string &type) : OpLite(type) {}
bool CheckShape() const override;
bool InferShape() const override;
void AttachKernel(KernelBase *kernel) override { kernel->SetParam(param_); }
// TODO(Superjomn) replace framework::OpDesc with a lite one.
bool AttachImpl(const cpp::OpDesc &op_desc, lite::Scope *scope) override {
auto input = op_desc.Input("X").front();
auto out = op_desc.Output("Out").front();
param_.x = scope->FindVar(input)->GetMutable<Tensor>();
param_.output = scope->FindVar(out)->GetMutable<Tensor>();
param_.pooling_type = op_desc.GetAttr<std::string>("pooling_type");
param_.ksize = op_desc.GetAttr<std::vector<int>>("ksize");
param_.strides = op_desc.GetAttr<std::vector<int>>("strides");
param_.paddings = op_desc.GetAttr<std::vector<int>>("paddings");
param_.ceil_mode = op_desc.GetAttr<bool>("ceil_mode");
param_.adaptive = op_desc.GetAttr<bool>("adaptive");
param_.global_pooling = op_desc.GetAttr<bool>("global_pooling");
return true;
}
std::string DebugString() const override { return "pool"; }
private:
mutable PoolParam param_;
};
} // namespace operators
} // namespace lite
} // namespace paddle
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册