Unverified commit 499b389e, authored by: S Santa An, committed by: GitHub

[LITE][BM] adaptive pool,test=develop (#3425)

* [LITE][BM] fix reshape infer shape issue, optimize global pool,adaptive pool,multi thread 
Parent 43777438
...@@ -48,6 +48,7 @@ USE_LITE_OP(concat) ...@@ -48,6 +48,7 @@ USE_LITE_OP(concat)
USE_LITE_OP(conv2d) USE_LITE_OP(conv2d)
USE_LITE_OP(depthwise_conv2d) USE_LITE_OP(depthwise_conv2d)
USE_LITE_OP(pool2d) USE_LITE_OP(pool2d)
USE_LITE_OP(max_pool2d_with_index)
USE_LITE_OP(batch_norm) USE_LITE_OP(batch_norm)
USE_LITE_OP(fusion_elementwise_sub_activation) USE_LITE_OP(fusion_elementwise_sub_activation)
USE_LITE_OP(transpose) USE_LITE_OP(transpose)
......
...@@ -15,6 +15,7 @@ ...@@ -15,6 +15,7 @@
#include <gflags/gflags.h> #include <gflags/gflags.h>
#include <gtest/gtest.h> #include <gtest/gtest.h>
#include <fstream> #include <fstream>
#include <thread> //NOLINT
#include <vector> #include <vector>
#include "lite/api/cxx_api.h" #include "lite/api/cxx_api.h"
#include "lite/api/paddle_use_kernels.h" #include "lite/api/paddle_use_kernels.h"
...@@ -30,14 +31,18 @@ DEFINE_string(input_img_txt_path, ...@@ -30,14 +31,18 @@ DEFINE_string(input_img_txt_path,
namespace paddle { namespace paddle {
namespace lite { namespace lite {
void TestModel(const std::vector<Place>& valid_places) { const int g_batch_size = 1;
const int g_thread_num = 1;
void instance_run() {
lite::Predictor predictor; lite::Predictor predictor;
std::vector<std::string> passes; std::vector<std::string> passes;
std::vector<Place> valid_places({Place{TARGET(kBM), PRECISION(kFloat)},
Place{TARGET(kX86), PRECISION(kFloat)}});
predictor.Build(FLAGS_model_dir, "", "", valid_places, passes); predictor.Build(FLAGS_model_dir, "", "", valid_places, passes);
auto* input_tensor = predictor.GetInput(0); auto* input_tensor = predictor.GetInput(0);
input_tensor->Resize(DDim( input_tensor->Resize(DDim(std::vector<DDim::value_type>(
std::vector<DDim::value_type>({1, 3, FLAGS_im_height, FLAGS_im_width}))); {g_batch_size, 3, FLAGS_im_height, FLAGS_im_width})));
auto* data = input_tensor->mutable_data<float>(); auto* data = input_tensor->mutable_data<float>();
auto item_size = input_tensor->dims().production(); auto item_size = input_tensor->dims().production();
if (FLAGS_input_img_txt_path.empty()) { if (FLAGS_input_img_txt_path.empty()) {
...@@ -45,12 +50,15 @@ void TestModel(const std::vector<Place>& valid_places) { ...@@ -45,12 +50,15 @@ void TestModel(const std::vector<Place>& valid_places) {
data[i] = 1; data[i] = 1;
} }
} else { } else {
std::fstream fs(FLAGS_input_img_txt_path, std::ios::in); for (int j = 0; j < g_batch_size; j++) {
if (!fs.is_open()) { std::fstream fs(FLAGS_input_img_txt_path, std::ios::in);
LOG(FATAL) << "open input_img_txt error."; if (!fs.is_open()) {
} LOG(FATAL) << "open input_img_txt error.";
for (int i = 0; i < item_size; i++) { }
fs >> data[i]; for (int i = 0; i < item_size / g_batch_size; i++) {
fs >> data[i];
}
data += j * item_size / g_batch_size;
} }
} }
for (int i = 0; i < FLAGS_warmup; ++i) { for (int i = 0; i < FLAGS_warmup; ++i) {
...@@ -72,6 +80,7 @@ void TestModel(const std::vector<Place>& valid_places) { ...@@ -72,6 +80,7 @@ void TestModel(const std::vector<Place>& valid_places) {
FILE* fp = fopen("result.txt", "wb"); FILE* fp = fopen("result.txt", "wb");
for (int i = 0; i < out.size(); i++) { for (int i = 0; i < out.size(); i++) {
auto* out_data = out[i]->data<float>(); auto* out_data = out[i]->data<float>();
LOG(INFO) << out[i]->numel();
for (int j = 0; j < out[i]->numel(); j++) { for (int j = 0; j < out[i]->numel(); j++) {
fprintf(fp, "%f\n", out_data[j]); fprintf(fp, "%f\n", out_data[j]);
} }
...@@ -79,6 +88,16 @@ void TestModel(const std::vector<Place>& valid_places) { ...@@ -79,6 +88,16 @@ void TestModel(const std::vector<Place>& valid_places) {
fclose(fp); fclose(fp);
} }
void TestModel(const std::vector<Place>& valid_places) {
std::vector<std::unique_ptr<std::thread>> instances_vec;
for (int i = 0; i < g_thread_num; ++i) {
instances_vec.emplace_back(new std::thread(&instance_run));
}
for (int i = 0; i < g_thread_num; ++i) {
instances_vec[i]->join();
}
}
TEST(Classify, test_bm) { TEST(Classify, test_bm) {
std::vector<Place> valid_places({Place{TARGET(kBM), PRECISION(kFloat)}, std::vector<Place> valid_places({Place{TARGET(kBM), PRECISION(kFloat)},
Place{TARGET(kX86), PRECISION(kFloat)}}); Place{TARGET(kX86), PRECISION(kFloat)}});
......
...@@ -36,6 +36,7 @@ lite_cc_library(subgraph_bridge_shape_op_bm SRCS shape_op.cc DEPS ${bm_subgraph_ ...@@ -36,6 +36,7 @@ lite_cc_library(subgraph_bridge_shape_op_bm SRCS shape_op.cc DEPS ${bm_subgraph_
lite_cc_library(subgraph_bridge_split_op_bm SRCS split_op.cc DEPS ${bm_subgraph_bridge_deps}) lite_cc_library(subgraph_bridge_split_op_bm SRCS split_op.cc DEPS ${bm_subgraph_bridge_deps})
lite_cc_library(subgraph_bridge_matmul_op_bm SRCS matmul_op.cc DEPS ${bm_subgraph_bridge_deps}) lite_cc_library(subgraph_bridge_matmul_op_bm SRCS matmul_op.cc DEPS ${bm_subgraph_bridge_deps})
set(bm_subgraph_bridges set(bm_subgraph_bridges
subgraph_bridge_registry subgraph_bridge_registry
subgraph_bridge_engine subgraph_bridge_engine
......
...@@ -54,6 +54,8 @@ int ActConverter(void* ctx, OpLite* op, KernelBase* kernel) { ...@@ -54,6 +54,8 @@ int ActConverter(void* ctx, OpLite* op, KernelBase* kernel) {
active_type_id = ACTIVE_SQRT; active_type_id = ACTIVE_SQRT;
} else if (op_type == "square") { } else if (op_type == "square") {
active_type_id = ACTIVE_SQUARE; active_type_id = ACTIVE_SQUARE;
} else if (op_type == "sigmoid") {
active_type_id = ACTIVE_SIGMOID;
} else { } else {
LOG(FATAL) << "[BM] unsupport act type"; LOG(FATAL) << "[BM] unsupport act type";
return FAILED; return FAILED;
...@@ -102,3 +104,6 @@ REGISTER_SUBGRAPH_BRIDGE(leaky_relu, ...@@ -102,3 +104,6 @@ REGISTER_SUBGRAPH_BRIDGE(leaky_relu,
paddle::lite::subgraph::bm::ActConverter); paddle::lite::subgraph::bm::ActConverter);
REGISTER_SUBGRAPH_BRIDGE(sqrt, kBM, paddle::lite::subgraph::bm::ActConverter); REGISTER_SUBGRAPH_BRIDGE(sqrt, kBM, paddle::lite::subgraph::bm::ActConverter);
REGISTER_SUBGRAPH_BRIDGE(square, kBM, paddle::lite::subgraph::bm::ActConverter); REGISTER_SUBGRAPH_BRIDGE(square, kBM, paddle::lite::subgraph::bm::ActConverter);
REGISTER_SUBGRAPH_BRIDGE(sigmoid,
kBM,
paddle::lite::subgraph::bm::ActConverter);
...@@ -20,11 +20,14 @@ namespace lite { ...@@ -20,11 +20,14 @@ namespace lite {
namespace subgraph { namespace subgraph {
namespace bm { namespace bm {
pthread_mutex_t Graph::mutex_compiler_ = PTHREAD_MUTEX_INITIALIZER;
void Graph::AddNode(const std::string& name) { void Graph::AddNode(const std::string& name) {
nodes_.insert(std::make_pair(name, name)); nodes_.insert(std::make_pair(name, name));
} }
void Graph::CreateCompilerHandle() { void Graph::CreateCompilerHandle() {
pthread_mutex_lock(&mutex_compiler_);
#ifdef BM1682 #ifdef BM1682
compiler_handle_ = create_bmcompiler("BM1682"); compiler_handle_ = create_bmcompiler("BM1682");
#else #else
...@@ -33,6 +36,8 @@ void Graph::CreateCompilerHandle() { ...@@ -33,6 +36,8 @@ void Graph::CreateCompilerHandle() {
CHECK(compiler_handle_ != nullptr); CHECK(compiler_handle_ != nullptr);
} }
void Graph::UnlockCompilerMutex() { pthread_mutex_unlock(&mutex_compiler_); }
} // namespace bm } // namespace bm
} // namespace subgraph } // namespace subgraph
} // namespace lite } // namespace lite
......
...@@ -14,6 +14,7 @@ ...@@ -14,6 +14,7 @@
#pragma once #pragma once
#include <pthread.h>
#include <memory> #include <memory>
#include <string> #include <string>
#include <unordered_map> #include <unordered_map>
...@@ -36,10 +37,12 @@ class Graph { ...@@ -36,10 +37,12 @@ class Graph {
} }
void CreateCompilerHandle(); void CreateCompilerHandle();
void* GetCompilerHandle() { return compiler_handle_; } void* GetCompilerHandle() { return compiler_handle_; }
void UnlockCompilerMutex();
private: private:
std::unordered_map<std::string, std::string> nodes_; std::unordered_map<std::string, std::string> nodes_;
void* compiler_handle_; void* compiler_handle_;
static pthread_mutex_t mutex_compiler_;
}; };
} // namespace bm } // namespace bm
......
...@@ -58,3 +58,5 @@ USE_SUBGRAPH_BRIDGE(depthwise_conv2d_transpose, kBM); ...@@ -58,3 +58,5 @@ USE_SUBGRAPH_BRIDGE(depthwise_conv2d_transpose, kBM);
USE_SUBGRAPH_BRIDGE(shape, kBM); USE_SUBGRAPH_BRIDGE(shape, kBM);
USE_SUBGRAPH_BRIDGE(split, kBM); USE_SUBGRAPH_BRIDGE(split, kBM);
USE_SUBGRAPH_BRIDGE(matmul, kBM); USE_SUBGRAPH_BRIDGE(matmul, kBM);
USE_SUBGRAPH_BRIDGE(max_pool2d_with_index, kBM);
USE_SUBGRAPH_BRIDGE(sigmoid, kBM);
...@@ -11,7 +11,10 @@ ...@@ -11,7 +11,10 @@
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#include <bmcompiler_defs.h>
#include <bmcompiler_if.h> #include <bmcompiler_if.h>
#include <bmcompiler_if_lite.h>
#include <user_bmcpu_common.h>
#include "lite/kernels/bm/bridges/graph.h" #include "lite/kernels/bm/bridges/graph.h"
#include "lite/kernels/bm/bridges/utility.h" #include "lite/kernels/bm/bridges/utility.h"
#include "lite/kernels/npu/bridges/registry.h" #include "lite/kernels/npu/bridges/registry.h"
...@@ -54,46 +57,84 @@ int PoolConverter(void* ctx, OpLite* op, KernelBase* kernel) { ...@@ -54,46 +57,84 @@ int PoolConverter(void* ctx, OpLite* op, KernelBase* kernel) {
shape[0] = &i_output_shape_data[0]; shape[0] = &i_output_shape_data[0];
name[0] = static_cast<const char*>(output_var_name.c_str()); name[0] = static_cast<const char*>(output_var_name.c_str());
dim[0] = output_dims.size(); dim[0] = output_dims.size();
auto pooling_type = op_info->GetAttr<std::string>("pooling_type"); std::string pooling_type;
if (op_info->HasAttr("pooling_type")) {
pooling_type = op_info->GetAttr<std::string>("pooling_type");
} else if (op_type == "max_pool2d_with_index") {
pooling_type = "max";
}
CHECK(pooling_type == "max" || pooling_type == "avg"); CHECK(pooling_type == "max" || pooling_type == "avg");
auto ksize = op_info->GetAttr<std::vector<int>>("ksize"); auto ksize = op_info->GetAttr<std::vector<int>>("ksize");
auto paddings = op_info->GetAttr<std::vector<int>>("paddings"); auto paddings = op_info->GetAttr<std::vector<int>>("paddings");
auto strides = op_info->GetAttr<std::vector<int>>("strides"); auto strides = op_info->GetAttr<std::vector<int>>("strides");
auto global_pooling = op_info->GetAttr<bool>("global_pooling"); auto global_pooling = op_info->GetAttr<bool>("global_pooling");
auto ceil_mode = op_info->GetAttr<bool>("ceil_mode"); bool ceil_mode = false;
if (op_info->HasAttr("ceil_mode")) {
ceil_mode = op_info->GetAttr<bool>("ceil_mode");
}
bool adaptive = false;
if (op_info->HasAttr("adaptive")) {
adaptive = op_info->GetAttr<bool>("adaptive");
}
bool average_exclusive = false; bool average_exclusive = false;
if (pooling_type == "avg") { if (pooling_type == "avg") {
average_exclusive = op_info->GetAttr<bool>("exclusive"); average_exclusive = op_info->GetAttr<bool>("exclusive");
} }
if (output_dims[2] == 1 && output_dims[3] == 1) {
global_pooling = true;
}
if (global_pooling) { if (global_pooling) {
paddings[0] = 0; paddings[0] = 0;
paddings[1] = 0; paddings[1] = 0;
ksize[0] = i_x_shape_data[2]; ksize[0] = i_x_shape_data[2];
ksize[1] = i_x_shape_data[3]; ksize[1] = i_x_shape_data[3];
} }
add_pooling_layer( bool is_max = (pooling_type == "max");
graph->GetCompilerHandle(), if (adaptive && !global_pooling) {
const_cast<const int*>(&i_x_shape_data[0]), user_cpu_param_t bm_param;
x_dims.size(), bm_param.op_type = USER_PADDLE_ADAPTIVE_POOL;
static_cast<const char*>(x_var_name.c_str()), bm_param.u.adaptive_pool_parm.is_avg = !is_max;
1, int32_t* in_shape[1];
shape, int32_t in_dim[1];
dim, const char* in_name[1];
name, in_shape[0] = &i_x_shape_data[0];
ksize[0], in_name[0] = static_cast<const char*>(x_var_name.c_str());
ksize[1], in_dim[0] = x_dims.size();
paddings[0], add_user_cpu_layer(graph->GetCompilerHandle(),
paddings[0], 1,
paddings[1], in_shape,
paddings[1], in_dim,
strides[0], in_name,
strides[1], 1,
(ksize[0] > 1 && ksize[1] > 1) && pooling_type == "max" ? 0 : 1, shape,
static_cast<int>(average_exclusive), dim,
static_cast<int>(global_pooling), name,
static_cast<int>(ceil_mode), &bm_param,
static_cast<const char*>(unique_op_name.c_str()), static_cast<int>(sizeof(bm_param)));
nullptr); } else {
add_pooling_layer(graph->GetCompilerHandle(),
const_cast<const int*>(&i_x_shape_data[0]),
x_dims.size(),
static_cast<const char*>(x_var_name.c_str()),
1,
shape,
dim,
name,
ksize[0],
ksize[1],
paddings[0],
paddings[0],
paddings[1],
paddings[1],
strides[0],
strides[1],
is_max ? 0 : 1,
static_cast<int>(average_exclusive),
static_cast<int>(global_pooling),
static_cast<int>(ceil_mode),
static_cast<const char*>(unique_op_name.c_str()),
nullptr);
}
graph->AddNode(output_var_name); graph->AddNode(output_var_name);
return SUCCESS; return SUCCESS;
} }
...@@ -105,3 +146,6 @@ int PoolConverter(void* ctx, OpLite* op, KernelBase* kernel) { ...@@ -105,3 +146,6 @@ int PoolConverter(void* ctx, OpLite* op, KernelBase* kernel) {
REGISTER_SUBGRAPH_BRIDGE(pool2d, REGISTER_SUBGRAPH_BRIDGE(pool2d,
kBM, kBM,
paddle::lite::subgraph::bm::PoolConverter); paddle::lite::subgraph::bm::PoolConverter);
REGISTER_SUBGRAPH_BRIDGE(max_pool2d_with_index,
kBM,
paddle::lite::subgraph::bm::PoolConverter);
...@@ -40,6 +40,7 @@ int SubgraphEngine::BuildDeviceProgram() { ...@@ -40,6 +40,7 @@ int SubgraphEngine::BuildDeviceProgram() {
op->CheckShape(); op->CheckShape();
op->InferShape(); op->InferShape();
std::string op_type = op->op_info()->Type(); std::string op_type = op->op_info()->Type();
LOG(INFO) << op_type;
if (!bridges.Exists(op_type, TARGET(kBM))) { if (!bridges.Exists(op_type, TARGET(kBM))) {
return subgraph::FAILED; return subgraph::FAILED;
} }
...@@ -59,6 +60,7 @@ int SubgraphEngine::BuildDeviceProgram() { ...@@ -59,6 +60,7 @@ int SubgraphEngine::BuildDeviceProgram() {
unsigned int data_size = 0; unsigned int data_size = 0;
bm_hd_ = static_cast<bm_handle_t>(ctx.GetHandle()); bm_hd_ = static_cast<bm_handle_t>(ctx.GetHandle());
finish_bmcompiler_data(graph.GetCompilerHandle(), &bmodel_data, &data_size); finish_bmcompiler_data(graph.GetCompilerHandle(), &bmodel_data, &data_size);
graph.UnlockCompilerMutex();
bmrt_hd_ = bmrt_create(bm_hd_); bmrt_hd_ = bmrt_create(bm_hd_);
if (false == bmrt_load_bmodel_data(bmrt_hd_, bmodel_data, data_size)) { if (false == bmrt_load_bmodel_data(bmrt_hd_, bmodel_data, data_size)) {
return subgraph::FAILED; return subgraph::FAILED;
......
...@@ -108,6 +108,7 @@ add_operator(collect_fpn_proposals_op_lite extra SRCS collect_fpn_proposals_op.c ...@@ -108,6 +108,7 @@ add_operator(collect_fpn_proposals_op_lite extra SRCS collect_fpn_proposals_op.c
add_operator(distribute_fpn_proposals_op_lite extra SRCS distribute_fpn_proposals_op.cc DEPS ${op_DEPS}) add_operator(distribute_fpn_proposals_op_lite extra SRCS distribute_fpn_proposals_op.cc DEPS ${op_DEPS})
add_operator(crf_decoding_op_lite extra SRCS crf_decoding_op.cc DEPS ${op_DEPS}) add_operator(crf_decoding_op_lite extra SRCS crf_decoding_op.cc DEPS ${op_DEPS})
add_operator(ctc_align_op_lite extra SRCS ctc_align_op.cc DEPS ${op_DEPS}) add_operator(ctc_align_op_lite extra SRCS ctc_align_op.cc DEPS ${op_DEPS})
add_operator(max_pool_with_index_op extra SRCS max_pool_with_index_op.cc DEPS ${op_DEPS})
# for OCR specific # for OCR specific
add_operator(while_op extra SRCS while_op.cc DEPS ${op_DEPS}) add_operator(while_op extra SRCS while_op.cc DEPS ${op_DEPS})
......
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/operators/max_pool_with_index_op.h"
#include <algorithm>
#include "lite/core/op_registry.h"
namespace paddle {
namespace lite {
namespace operators {
bool MaxPoolWithIndexOpLite::CheckShape() const {
  CHECK_OR_FALSE(param_.x);
  CHECK_OR_FALSE(param_.output);
  const auto &in_dims = param_.x->dims();
  const auto &kernel_vec = param_.ksize;
  const auto &stride_vec = param_.strides;
  const auto &pad_vec = *param_.paddings;
  // Pooling input must be a 4-D (NCHW) or 5-D (NCDHW) tensor.
  CHECK_OR_FALSE(in_dims.size() == 4 || in_dims.size() == 5);
  // The kernel must cover exactly the spatial dims (input rank - 2).
  CHECK_OR_FALSE(in_dims.size() - kernel_vec.size() == 2U);
  // One stride per kernel dimension.
  CHECK_OR_FALSE(kernel_vec.size() == stride_vec.size());
  // Paddings arrive in the expanded 4-element form (see AttachImpl).
  CHECK_OR_FALSE(pad_vec.size() == 4L);
  return true;
}
// Computes one spatial output extent of a (non-adaptive) pooling window:
// (input + 2*padding - filter) / stride + 1, with truncating division.
inline int MaxPoolOutputSize(int input_size,
                             int filter_size,
                             int padding,
                             int stride) {
  return (input_size + 2 * padding - filter_size) / stride + 1;
}
// Infers the output shape: batch and channel dims are kept, and each spatial
// dim is either taken from ksize (adaptive pooling) or computed from the
// standard pooling output-size formula.
bool MaxPoolWithIndexOpLite::InferShapeImpl() const {
  const auto x_dims = param_.x->dims();
  const auto &ksize = param_.ksize;
  const auto &strides = param_.strides;
  const auto &paddings = *param_.paddings;
  std::vector<int64_t> output_shape({x_dims[0], x_dims[1]});
  if (param_.adaptive) {
    // Adaptive pooling: ksize directly specifies the output spatial extents.
    output_shape.insert(output_shape.end(), ksize.begin(), ksize.end());
  } else {
    for (size_t i = 0; i < ksize.size(); ++i) {
      // paddings is stored in the expanded 4-element form
      // {pad_h, pad_h, pad_w, pad_w} (built in AttachImpl, enforced by
      // CheckShape), so the padding of spatial dim i lives at index 2*i.
      // The previous `paddings[i]` picked pad_h for the width dim as well,
      // producing a wrong width whenever pad_h != pad_w.
      output_shape.push_back(MaxPoolOutputSize(
          x_dims[i + 2], ksize[i], paddings[2 * i], strides[i]));
    }
  }
  param_.output->Resize(lite::DDim(output_shape));
  return true;
}
} // namespace operators
} // namespace lite
} // namespace paddle
REGISTER_LITE_OP(max_pool2d_with_index,
paddle::lite::operators::MaxPoolWithIndexOpLite);
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <algorithm>
#include <memory>
#include <string>
#include <vector>
#include "lite/core/kernel.h"
#include "lite/core/op_lite.h"
#include "lite/core/scope.h"
#include "lite/core/tensor.h"
#include "lite/operators/op_params.h"
#include "lite/utils/all.h"
namespace paddle {
namespace lite {
namespace operators {
// Operator wrapper for max_pool2d_with_index: max pooling that also produces
// the arg-max index mask. Reuses PoolParam for its attributes.
class MaxPoolWithIndexOpLite : public OpLite {
 public:
  MaxPoolWithIndexOpLite() {}

  explicit MaxPoolWithIndexOpLite(const std::string &type) : OpLite(type) {}

  bool CheckShape() const override;

  bool InferShapeImpl() const override;

  // TODO(Superjomn) replace framework::OpDesc with a lite one.
  // Binds the op-desc inputs/outputs/attributes into param_.
  bool AttachImpl(const cpp::OpDesc &op_desc, lite::Scope *scope) override {
    const auto x_name = op_desc.Input("X").front();
    const auto out_name = op_desc.Output("Out").front();
    const auto mask_name = op_desc.Output("Mask").front();
    CHECK(scope->FindVar(x_name));
    CHECK(scope->FindVar(out_name));
    // NOTE(review): "Mask" is verified to exist but is never bound into
    // param_ — confirm downstream kernels do not need it.
    CHECK(scope->FindVar(mask_name));
    param_.x = scope->FindVar(x_name)->GetMutable<lite::Tensor>();
    param_.output = scope->FindVar(out_name)->GetMutable<lite::Tensor>();
    param_.ksize = op_desc.GetAttr<std::vector<int>>("ksize");
    param_.global_pooling = op_desc.GetAttr<bool>("global_pooling");
    param_.strides = op_desc.GetAttr<std::vector<int>>("strides");
    auto paddings = op_desc.GetAttr<std::vector<int>>("paddings");
    if (op_desc.HasAttr("adaptive")) {
      param_.adaptive = op_desc.GetAttr<bool>("adaptive");
    }
    // Expand a 2-element {pad_h, pad_w} into the 4-element
    // {pad_h, pad_h, pad_w, pad_w} layout the rest of the op expects.
    if (paddings.size() == 2L) {
      paddings = {paddings[0], paddings[0], paddings[1], paddings[1]};
    } else if (paddings.size() != 4L) {
      LOG(FATAL)
          << "Paddings size should be the same or twice as the inputs size.";
    }
    param_.paddings = std::make_shared<std::vector<int>>(paddings);
    return true;
  }

  void AttachKernel(KernelBase *kernel) override { kernel->SetParam(param_); }

  std::string DebugString() const override { return "max_pool2d_with_index"; }

 private:
  mutable PoolParam param_;
};
} // namespace operators
} // namespace lite
} // namespace paddle
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
To comment, please register