未验证 提交 49f03648 编写于 作者: Y yiicy 提交者: GitHub

[ARM] change global pooling choose kernel policy, test=develop (#2602)

* [ARM] change global pooling choose kernel policy, test=develop
上级 25159de9
...@@ -262,14 +262,10 @@ void Instruction::Run() { ...@@ -262,14 +262,10 @@ void Instruction::Run() {
if (op_->run_once() && has_run_) { if (op_->run_once() && has_run_) {
return; return;
} }
#ifndef LITE_SHUTDOWN_LOG // VLOG(4) << "kernel launch";
VLOG(4) << "kernel launch";
#endif
op_->InferShape(); op_->InferShape();
#ifndef LITE_SHUTDOWN_LOG // VLOG(4) << ">> Running kernel: " << op_->op_info()->Repr() << " on Target "
VLOG(4) << ">> Running kernel: " << op_->op_info()->Repr() << " on Target " // << TargetToStr(kernel_->target());
<< TargetToStr(kernel_->target());
#endif
kernel_->Launch(); kernel_->Launch();
has_run_ = true; has_run_ = true;
} }
......
...@@ -65,20 +65,20 @@ void ConvCompute<PRECISION(kFloat), PRECISION(kFloat)>::PrepareForRun() { ...@@ -65,20 +65,20 @@ void ConvCompute<PRECISION(kFloat), PRECISION(kFloat)>::PrepareForRun() {
no_dilation && flag_dw) { no_dilation && flag_dw) {
/// dw conv impl /// dw conv impl
impl_ = new DepthwiseConv<PRECISION(kFloat), PRECISION(kFloat)>; impl_ = new DepthwiseConv<PRECISION(kFloat), PRECISION(kFloat)>;
VLOG(3) << "invoking dw conv"; // VLOG(3) << "invoking dw conv";
} else if (param.groups == 1 && kw == 3 && stride == 1 && kps_equal && } else if (param.groups == 1 && kw == 3 && stride == 1 && kps_equal &&
no_dilation && pads_all_equal) { no_dilation && pads_all_equal) {
/// winograd conv impl /// winograd conv impl
impl_ = new WinogradConv<PRECISION(kFloat), PRECISION(kFloat)>; impl_ = new WinogradConv<PRECISION(kFloat), PRECISION(kFloat)>;
VLOG(3) << "invoking winograd conv"; // VLOG(3) << "invoking winograd conv";
} else if (param.groups == 1 && kw == 3 && stride == 2 && } else if (param.groups == 1 && kw == 3 && stride == 2 &&
chin * chout < 4 * hin * win && kps_equal && no_dilation) { chin * chout < 4 * hin * win && kps_equal && no_dilation) {
/// direct conv impl /// direct conv impl
impl_ = new DirectConv<PRECISION(kFloat), PRECISION(kFloat)>; impl_ = new DirectConv<PRECISION(kFloat), PRECISION(kFloat)>;
VLOG(3) << "invoking direct conv"; // VLOG(3) << "invoking direct conv";
} else { } else {
impl_ = new GemmLikeConv<PRECISION(kFloat), PRECISION(kFloat)>; impl_ = new GemmLikeConv<PRECISION(kFloat), PRECISION(kFloat)>;
VLOG(3) << "invoking gemm like conv"; // VLOG(3) << "invoking gemm like conv";
} }
impl_->SetContext(std::move(this->ctx_)); impl_->SetContext(std::move(this->ctx_));
impl_->SetParam(param); impl_->SetParam(param);
...@@ -117,14 +117,14 @@ void ConvCompute<PRECISION(kInt8), PRECISION(kFloat)>::PrepareForRun() { ...@@ -117,14 +117,14 @@ void ConvCompute<PRECISION(kInt8), PRECISION(kFloat)>::PrepareForRun() {
if (param.groups == ic && ic == oc && kps_equal && pads_equal && if (param.groups == ic && ic == oc && kps_equal && pads_equal &&
no_dilation && flag_dw) { no_dilation && flag_dw) {
impl_ = new DepthwiseConv<PRECISION(kInt8), PRECISION(kFloat)>; impl_ = new DepthwiseConv<PRECISION(kInt8), PRECISION(kFloat)>;
VLOG(3) << "Run DepthwiseConv Int8"; // VLOG(3) << "Run DepthwiseConv Int8";
} else if (param.groups == 1 && kw == 3 && (sw == 1 || sw == 2) && } else if (param.groups == 1 && kw == 3 && (sw == 1 || sw == 2) &&
kps_equal && no_dilation) { kps_equal && no_dilation) {
impl_ = new DirectConv<PRECISION(kInt8), PRECISION(kFloat)>; impl_ = new DirectConv<PRECISION(kInt8), PRECISION(kFloat)>;
VLOG(3) << "Run DirectConv Int8"; // VLOG(3) << "Run DirectConv Int8";
} else { } else {
impl_ = new GemmLikeConv<PRECISION(kInt8), PRECISION(kFloat)>; impl_ = new GemmLikeConv<PRECISION(kInt8), PRECISION(kFloat)>;
VLOG(3) << "Run GemmLikeConvInt8"; // VLOG(3) << "Run GemmLikeConvInt8";
} }
impl_->SetContext(std::move(this->ctx_)); impl_->SetContext(std::move(this->ctx_));
impl_->SetParam(param); impl_->SetParam(param);
...@@ -163,14 +163,14 @@ void ConvCompute<PRECISION(kInt8), PRECISION(kInt8)>::PrepareForRun() { ...@@ -163,14 +163,14 @@ void ConvCompute<PRECISION(kInt8), PRECISION(kInt8)>::PrepareForRun() {
if (param.groups == ic && ic == oc && kps_equal && pads_equal && if (param.groups == ic && ic == oc && kps_equal && pads_equal &&
no_dilation && flag_dw) { no_dilation && flag_dw) {
impl_ = new DepthwiseConv<PRECISION(kInt8), PRECISION(kInt8)>; impl_ = new DepthwiseConv<PRECISION(kInt8), PRECISION(kInt8)>;
VLOG(3) << "Run DepthwiseConv Int8"; // VLOG(3) << "Run DepthwiseConv Int8";
} else if (param.groups == 1 && kw == 3 && (sw == 1 || sw == 2) && } else if (param.groups == 1 && kw == 3 && (sw == 1 || sw == 2) &&
kps_equal && no_dilation) { kps_equal && no_dilation) {
impl_ = new DirectConv<PRECISION(kInt8), PRECISION(kInt8)>; impl_ = new DirectConv<PRECISION(kInt8), PRECISION(kInt8)>;
VLOG(3) << "Run DirectConv Int8"; // VLOG(3) << "Run DirectConv Int8";
} else { } else {
impl_ = new GemmLikeConv<PRECISION(kInt8), PRECISION(kInt8)>; impl_ = new GemmLikeConv<PRECISION(kInt8), PRECISION(kInt8)>;
VLOG(3) << "Run GemmLikeConvInt8"; // VLOG(3) << "Run GemmLikeConvInt8";
} }
impl_->SetContext(std::move(this->ctx_)); impl_->SetContext(std::move(this->ctx_));
impl_->SetParam(param); impl_->SetParam(param);
......
...@@ -30,7 +30,7 @@ void DepthwiseConv<PRECISION(kFloat), PRECISION(kFloat)>::PrepareForRun() { ...@@ -30,7 +30,7 @@ void DepthwiseConv<PRECISION(kFloat), PRECISION(kFloat)>::PrepareForRun() {
auto kw = w_dims[3]; auto kw = w_dims[3];
// select dw conv kernel // select dw conv kernel
if (kw == 3) { if (kw == 3) {
VLOG(5) << "invoke 3x3 dw conv fp32"; // VLOG(5) << "invoke 3x3 dw conv fp32";
auto paddings = *param.paddings; auto paddings = *param.paddings;
bool pads_equal = bool pads_equal =
((paddings[0] == paddings[1]) && (paddings[2] == paddings[3])); ((paddings[0] == paddings[1]) && (paddings[2] == paddings[3]));
...@@ -54,7 +54,7 @@ void DepthwiseConv<PRECISION(kFloat), PRECISION(kFloat)>::PrepareForRun() { ...@@ -54,7 +54,7 @@ void DepthwiseConv<PRECISION(kFloat), PRECISION(kFloat)>::PrepareForRun() {
flag_trans_weights_ = true; flag_trans_weights_ = true;
} }
} else if (kw == 5) { } else if (kw == 5) {
VLOG(5) << "invoke 5x5 dw conv fp32"; // VLOG(5) << "invoke 5x5 dw conv fp32";
impl_ = lite::arm::math::conv_depthwise_5x5_fp32; impl_ = lite::arm::math::conv_depthwise_5x5_fp32;
} else { } else {
LOG(FATAL) << "this type dw conv not impl"; LOG(FATAL) << "this type dw conv not impl";
...@@ -86,7 +86,7 @@ void DepthwiseConv<PRECISION(kInt8), PRECISION(kFloat)>::PrepareForRun() { ...@@ -86,7 +86,7 @@ void DepthwiseConv<PRECISION(kInt8), PRECISION(kFloat)>::PrepareForRun() {
/// select dw conv kernel /// select dw conv kernel
if (kw == 3) { if (kw == 3) {
// trans weights // trans weights
VLOG(5) << "invoke 3x3 dw conv int8 kernel fp32 out"; // VLOG(5) << "invoke 3x3 dw conv int8 kernel fp32 out";
impl_ = lite::arm::math::conv_depthwise_3x3_int8_fp32; impl_ = lite::arm::math::conv_depthwise_3x3_int8_fp32;
int cround = ROUNDUP(w_dims[0], 8); int cround = ROUNDUP(w_dims[0], 8);
weights_.Resize({cround / 8, 1, kh * kw, 8}); weights_.Resize({cround / 8, 1, kh * kw, 8});
...@@ -96,7 +96,7 @@ void DepthwiseConv<PRECISION(kInt8), PRECISION(kFloat)>::PrepareForRun() { ...@@ -96,7 +96,7 @@ void DepthwiseConv<PRECISION(kInt8), PRECISION(kFloat)>::PrepareForRun() {
flag_trans_weights_ = true; flag_trans_weights_ = true;
} else if (kw == 5) { } else if (kw == 5) {
// trans weights // trans weights
VLOG(5) << "invoke 5x5 dw conv int8 kernel fp32 out"; // VLOG(5) << "invoke 5x5 dw conv int8 kernel fp32 out";
impl_ = lite::arm::math::conv_depthwise_5x5_int8_fp32; impl_ = lite::arm::math::conv_depthwise_5x5_int8_fp32;
int cround = ROUNDUP(w_dims[0], 8); int cround = ROUNDUP(w_dims[0], 8);
weights_.Resize({cround / 8, 1, kh * kw, 8}); weights_.Resize({cround / 8, 1, kh * kw, 8});
...@@ -145,7 +145,7 @@ void DepthwiseConv<PRECISION(kInt8), PRECISION(kInt8)>::PrepareForRun() { ...@@ -145,7 +145,7 @@ void DepthwiseConv<PRECISION(kInt8), PRECISION(kInt8)>::PrepareForRun() {
/// select dw conv kernel /// select dw conv kernel
if (kw == 3) { if (kw == 3) {
// trans weights // trans weights
VLOG(5) << "invoke 3x3 dw conv int8 kernel int8 out"; // VLOG(5) << "invoke 3x3 dw conv int8 kernel int8 out";
impl_ = lite::arm::math::conv_depthwise_3x3_int8_int8; impl_ = lite::arm::math::conv_depthwise_3x3_int8_int8;
int cround = ROUNDUP(w_dims[0], 8); int cround = ROUNDUP(w_dims[0], 8);
weights_.Resize({cround / 8, 1, kh * kw, 8}); weights_.Resize({cround / 8, 1, kh * kw, 8});
...@@ -155,7 +155,7 @@ void DepthwiseConv<PRECISION(kInt8), PRECISION(kInt8)>::PrepareForRun() { ...@@ -155,7 +155,7 @@ void DepthwiseConv<PRECISION(kInt8), PRECISION(kInt8)>::PrepareForRun() {
flag_trans_weights_ = true; flag_trans_weights_ = true;
} else if (kw == 5) { } else if (kw == 5) {
// trans weights // trans weights
VLOG(5) << "invoke 5x5 dw conv int8 kernel int8 out"; // VLOG(5) << "invoke 5x5 dw conv int8 kernel int8 out";
impl_ = lite::arm::math::conv_depthwise_5x5_int8_int8; impl_ = lite::arm::math::conv_depthwise_5x5_int8_int8;
int cround = ROUNDUP(w_dims[0], 8); int cround = ROUNDUP(w_dims[0], 8);
weights_.Resize({cround / 8, 1, kh * kw, 8}); weights_.Resize({cround / 8, 1, kh * kw, 8});
......
...@@ -41,18 +41,20 @@ void PoolCompute::Run() { ...@@ -41,18 +41,20 @@ void PoolCompute::Run() {
std::vector<int>& paddings = *param.paddings; std::vector<int>& paddings = *param.paddings;
std::string& pooling_type = param.pooling_type; std::string& pooling_type = param.pooling_type;
bool global_pooling = param.global_pooling;
bool exclusive = param.exclusive; bool exclusive = param.exclusive;
bool adaptive = param.adaptive; bool adaptive = param.adaptive;
bool ceil_mode = param.ceil_mode; bool ceil_mode = param.ceil_mode;
bool use_quantizer = param.use_quantizer; bool use_quantizer = param.use_quantizer;
std::string& data_format = param.data_format; std::string& data_format = param.data_format;
bool pads_equal = bool pads_equal = (paddings[0] == paddings[1]) &&
(paddings[0] == paddings[1]) && (paddings[2] == paddings[3]); (paddings[2] == paddings[3]) &&
(paddings[0] == paddings[2]);
bool kps_equal = (ksize[0] == ksize[1]) && (strides[0] == strides[1]) && bool kps_equal =
(paddings[0] == paddings[2]); (ksize[0] == ksize[1]) && (strides[0] == strides[1]) && pads_equal;
bool global_pooling = (paddings[0] == 0) && (ksize[0] == in_dims[2]) &&
(ksize[1] == in_dims[3]) && pads_equal;
global_pooling = param.global_pooling || global_pooling;
if (global_pooling) { if (global_pooling) {
for (size_t i = 0; i < ksize.size(); ++i) { for (size_t i = 0; i < ksize.size(); ++i) {
paddings[2 * i] = 0; paddings[2 * i] = 0;
...@@ -83,8 +85,7 @@ void PoolCompute::Run() { ...@@ -83,8 +85,7 @@ void PoolCompute::Run() {
return; return;
} }
} else { } else {
if (ksize[0] == 2 && strides[0] == 2 && paddings[0] == 0 && pads_equal && if (ksize[0] == 2 && strides[0] == 2 && paddings[0] == 0 && kps_equal) {
kps_equal) {
if (pooling_type == "max") { if (pooling_type == "max") {
lite::arm::math::pooling2x2s2_max(din, lite::arm::math::pooling2x2s2_max(din,
dout, dout,
...@@ -110,7 +111,7 @@ void PoolCompute::Run() { ...@@ -110,7 +111,7 @@ void PoolCompute::Run() {
return; return;
} }
} else if (ksize[0] == 3 && strides[0] == 1 && paddings[0] == 1 && } else if (ksize[0] == 3 && strides[0] == 1 && paddings[0] == 1 &&
pads_equal && kps_equal) { kps_equal) {
if (pooling_type == "max") { if (pooling_type == "max") {
lite::arm::math::pooling3x3s1p1_max(din, lite::arm::math::pooling3x3s1p1_max(din,
dout, dout,
...@@ -136,7 +137,7 @@ void PoolCompute::Run() { ...@@ -136,7 +137,7 @@ void PoolCompute::Run() {
return; return;
} }
} else if (ksize[0] == 3 && strides[0] == 1 && paddings[0] == 0 && } else if (ksize[0] == 3 && strides[0] == 1 && paddings[0] == 0 &&
pads_equal && kps_equal) { kps_equal) {
if (pooling_type == "max") { if (pooling_type == "max") {
lite::arm::math::pooling3x3s1p0_max(din, lite::arm::math::pooling3x3s1p0_max(din,
dout, dout,
...@@ -162,7 +163,7 @@ void PoolCompute::Run() { ...@@ -162,7 +163,7 @@ void PoolCompute::Run() {
return; return;
} }
} else if (ksize[0] == 3 && strides[0] == 2 && paddings[0] == 0 && } else if (ksize[0] == 3 && strides[0] == 2 && paddings[0] == 0 &&
pads_equal && kps_equal) { kps_equal) {
if (pooling_type == "max") { if (pooling_type == "max") {
lite::arm::math::pooling3x3s2p0_max(din, lite::arm::math::pooling3x3s2p0_max(din,
dout, dout,
...@@ -188,7 +189,7 @@ void PoolCompute::Run() { ...@@ -188,7 +189,7 @@ void PoolCompute::Run() {
return; return;
} }
} else if (ksize[0] == 3 && strides[0] == 2 && paddings[0] == 1 && } else if (ksize[0] == 3 && strides[0] == 2 && paddings[0] == 1 &&
pads_equal && kps_equal) { kps_equal) {
if (pooling_type == "max") { if (pooling_type == "max") {
lite::arm::math::pooling3x3s2p1_max(din, lite::arm::math::pooling3x3s2p1_max(din,
dout, dout,
......
...@@ -54,7 +54,7 @@ void SplitLodTensorCompute::Run() { ...@@ -54,7 +54,7 @@ void SplitLodTensorCompute::Run() {
} }
lod->clear(); lod->clear();
for (size_t i = 0; i < static_cast<size_t>(mask_dim[0]); i++) { for (size_t i = 0; i < static_cast<size_t>(mask_dim[0]); i++) {
VLOG(4) << "mask: " << mask_data[i]; // VLOG(4) << "mask: " << mask_data[i];
if (static_cast<size_t>(mask_data[i]) == t) { if (static_cast<size_t>(mask_data[i]) == t) {
size_t start_idx = i; size_t start_idx = i;
auto lod_and_offset = lite::arm::math::GetSubLoDAndAbsoluteOffset( auto lod_and_offset = lite::arm::math::GetSubLoDAndAbsoluteOffset(
......
...@@ -36,7 +36,7 @@ class StepExecutor { ...@@ -36,7 +36,7 @@ class StepExecutor {
auto &op_desc = *block->template GetOp<cpp::OpDesc>(i); auto &op_desc = *block->template GetOp<cpp::OpDesc>(i);
auto op_type = op_desc.Type(); auto op_type = op_desc.Type();
auto op_handler = lite::LiteOpRegistry::Global().Create(op_desc.Type()); auto op_handler = lite::LiteOpRegistry::Global().Create(op_desc.Type());
VLOG(4) << "while: creating Op [" << op_type << "]"; // VLOG(4) << "while: creating Op [" << op_type << "]";
op_handler->Attach(op_desc, scope); op_handler->Attach(op_desc, scope);
auto hostplace = place_; auto hostplace = place_;
...@@ -51,9 +51,9 @@ class StepExecutor { ...@@ -51,9 +51,9 @@ class StepExecutor {
void Run() { void Run() {
for (auto &op_handler : ops_of_block_) { for (auto &op_handler : ops_of_block_) {
VLOG(4) << op_handler->op_info()->Repr(); // VLOG(4) << op_handler->op_info()->Repr();
op_handler->InferShape(); op_handler->InferShape();
VLOG(4) << "while: infered shape"; // VLOG(4) << "while: infered shape";
op_handler->Run(); op_handler->Run();
} }
} }
......
...@@ -355,7 +355,8 @@ void test_pool_fp32(const std::vector<DDim>& input_dims, ...@@ -355,7 +355,8 @@ void test_pool_fp32(const std::vector<DDim>& input_dims,
LOG(FATAL) << "test fp32 pool: input: " << dim_in LOG(FATAL) << "test fp32 pool: input: " << dim_in
<< ", output: " << dim_out << ", output: " << dim_out
<< ", kernel dim: " << ksize[0] << ", " << ksize[1] << ", kernel dim: " << ksize[0] << ", " << ksize[1]
<< ", pad: " << pads[0] << ", " << pads[1] << ", pad: " << pads[0] << ", " << pads[1] << ", "
<< pads[2] << ", " << pads[3]
<< ", stride: " << strides[0] << ", " << strides[1] << ", stride: " << strides[0] << ", " << strides[1]
<< ", global_pooling: " << ", global_pooling: "
<< (flag_global ? "global" : "false") << (flag_global ? "global" : "false")
...@@ -370,6 +371,7 @@ void test_pool_fp32(const std::vector<DDim>& input_dims, ...@@ -370,6 +371,7 @@ void test_pool_fp32(const std::vector<DDim>& input_dims,
LOG(INFO) << "test fp32 pool: input: " << dim_in LOG(INFO) << "test fp32 pool: input: " << dim_in
<< ", output: " << dim_out << ", kernel dim: " << ksize[0] << ", output: " << dim_out << ", kernel dim: " << ksize[0]
<< ", " << ksize[1] << ", pad: " << pads[0] << ", " << pads[1] << ", " << ksize[1] << ", pad: " << pads[0] << ", " << pads[1]
<< ", " << pads[2] << ", " << pads[3]
<< ", stride: " << strides[0] << ", " << strides[1] << ", stride: " << strides[0] << ", " << strides[1]
<< ", global_pooling: " << (flag_global ? "global" : "false") << ", global_pooling: " << (flag_global ? "global" : "false")
<< ", pooling_type: " << pooling_type << ", pooling_type: " << pooling_type
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册