提交 1fc2ed55 编写于 作者: Y Yanzhan Yang 提交者: GitHub

add faster depthwise implementations (#1747)

* add faster depthwise implementations

* fix style
上级 b84df55a
......@@ -14,6 +14,7 @@ limitations under the License. */
#include "operators/kernel/central-arm-func/conv_arm_func.h"
#include <vector>
#include "operators/math/depthwise/faster_depthwise_conv3x3.h"
#include "operators/math/depthwise_conv3x3.h"
#include "operators/math/depthwise_conv5x5.h"
#include "operators/math/im2col.h"
......@@ -211,6 +212,65 @@ void DepthwiseConv3x3(const ConvParam<CPU> &param) {
}
}
template <>
void DepthwiseConv3x3<float, float>(const ConvParam<CPU> &param) {
const Tensor *input = param.Input();
const Tensor *filter = param.Filter();
const std::vector<int> &paddings = param.Paddings();
const std::vector<int> &strides = param.Strides();
const int batch_size = input->dims()[0];
Tensor *output = param.Output();
output->mutable_data<float>();
if (paddings.size() == 2 && paddings[0] == paddings[1] &&
strides.size() == 2 && strides[0] == strides[1]) {
int pad = paddings[0];
int stride = strides[0];
const float *din = input->data<float>();
float *dout = output->mutable_data<float>();
const float *weights = filter->data<float>();
const float *bias = nullptr;
const int num = input->dims()[0];
const int chin = input->dims()[1];
const int hin = input->dims()[2];
const int win = input->dims()[3];
const int chout = output->dims()[1];
const int hout = output->dims()[2];
const int wout = output->dims()[3];
bool flag_relu = false;
bool flag_bias = bias != nullptr;
if (pad == 0 && hin > 2) {
math::depthwise::conv_depthwise_3x3p0(din, dout, num, chout, hout, wout,
chin, hin, win, weights, bias,
stride, flag_bias, flag_relu);
} else if (pad == 1) {
math::depthwise::conv_depthwise_3x3p1(din, dout, num, chout, hout, wout,
chin, hin, win, weights, bias,
stride, flag_bias, flag_relu);
} else {
GemmConv<float, float>(param);
}
} else {
if (strides[0] == 1) {
for (int i = 0; i < batch_size; i++) {
Tensor in_batch = input->Slice(i, i + 1);
Tensor out_batch = output->Slice(i, i + 1);
math::DepthwiseConv3x3S1<float, float>(in_batch, *filter, paddings,
&out_batch);
}
} else if (strides[0] == 2) {
for (int i = 0; i < batch_size; i++) {
Tensor in_batch = input->Slice(i, i + 1);
Tensor out_batch = output->Slice(i, i + 1);
math::DepthwiseConv3x3S2<float, float>(in_batch, *filter, paddings,
&out_batch);
}
} else {
GemmConv<float, float>(param);
}
}
}
template <typename Itype, typename Otype>
void DepthwiseConv5x5(const ConvParam<CPU> &param) {
const Tensor *input = param.Input();
......
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <algorithm>
#include <vector>
#include "framework/tensor.h"
namespace paddle_mobile {
namespace operators {
namespace math {
namespace depthwise {
void conv_depthwise_3x3p0(const float* din, float* dout, int num, int ch_out,
int h_out, int w_out, int ch_in, int h_in, int w_in,
const float* weights, const float* bias, int stride,
bool flag_bias, bool flag_relu);
void conv_depthwise_3x3p1(const float* din, float* dout, int num, int ch_out,
int h_out, int w_out, int ch_in, int h_in, int w_in,
const float* weights, const float* bias, int stride,
bool flag_bias, bool flag_relu);
} // namespace depthwise
} // namespace math
} // namespace operators
} // namespace paddle_mobile
......@@ -59,12 +59,13 @@ int main(int argc, char* argv[]) {
paddle_mobile.Predict(input);
}
auto time3 = time();
for (int i = 0; i < 10; ++i) {
int test_count = 100;
for (int i = 0; i < test_count; ++i) {
paddle_mobile.Predict(input);
}
auto time4 = time();
std::cout << "predict cost :" << time_diff(time3, time4) / 10 << "ms\n";
std::cout << "predict cost :" << time_diff(time3, time4) / test_count
<< "ms\n";
std::ostringstream os("output tensor size: ");
output = paddle_mobile.Fetch();
os << output->numel() << "\n" << output->data<float>()[0];
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册