Unverified commit 4ef5d5ec authored by Leonardo-Ding, committed by GitHub

[arm]add benchmark ops for arm,test=develop (#4148)

Parent 5dd5ed67
......@@ -26,6 +26,88 @@ namespace lite {
namespace kernels {
namespace arm {
template <typename Dtype>
void naive_transpose(const Dtype* din, Dtype* dout, int m, int n) {
int k = 0;
for (int i = 0; i < n; ++i) {
for (int j = 0; j < m; ++j) {
dout[k++] = din[j * n + i];
}
}
}
template <PrecisionType PType>
void fc_trans_weights(const Tensor& tin, Tensor* tout);
template <>
void fc_trans_weights<PRECISION(kFloat)>(const Tensor& tin, Tensor* tout) {
CHECK_EQ(tin.dims().size(), 2) << "fc weights size must = 2";
int m = tin.dims()[0];
int n = tin.dims()[1];
tout->Resize({n, m});
auto* ptr_in = tin.data<float>();
auto* ptr_out = tout->mutable_data<float>();
naive_transpose(ptr_in, ptr_out, m, n);
}
template <>
void fc_trans_weights<PRECISION(kInt8)>(const Tensor& tin, Tensor* tout) {
CHECK_EQ(tin.dims().size(), 2) << "fc weights size must = 2";
int m = tin.dims()[0];
int n = tin.dims()[1];
tout->Resize({n, m});
auto* ptr_in = tin.data<int8_t>();
auto* ptr_out = tout->mutable_data<int8_t>();
naive_transpose(ptr_in, ptr_out, m, n);
}
template <PrecisionType PType, PrecisionType OutType>
bool check_fc_use_gemm(int m, const std::vector<float>& scale, bool has_bias) {
return m > 1;
}
template <>
bool check_fc_use_gemm<PRECISION(kInt8), PRECISION(kFloat)>(
int m, const std::vector<float>& scale, bool has_bias) {
CHECK_GT(scale.size(), 0) << "Int8 FC param must have weight_scale";
return m > 1 && scale.size() == 1;
}
template <>
bool check_fc_use_gemm<PRECISION(kInt8), PRECISION(kInt8)>(
int m, const std::vector<float>& scale, bool has_bias) {
CHECK_GT(scale.size(), 0) << "Int8 FC param must have weight_scale";
return m > 1 && scale.size() == 1 && !has_bias;
}
template <PrecisionType PType, PrecisionType OutType>
void FcCompute<PType, OutType>::ReInitWhenNeeded() {
auto& param = this->template Param<operators::FcParam>();
auto x_dims = param.input->dims();
if (last_shape_ == x_dims) {
return;
}
last_shape_ = x_dims;
auto w_dims = param.w->dims();
auto& ctx = this->ctx_->template As<ARMContext>();
CHECK_GE(x_dims.size(), 2UL);
CHECK_EQ(w_dims.size(), 2UL);
CHECK_GE(param.output->dims().size(), 2UL);
m_ = x_dims.Slice(0, param.in_num_col_dims).production();
k_ = x_dims.Slice(param.in_num_col_dims, x_dims.size()).production();
  n_ = w_dims[1];
  CHECK_EQ(k_, static_cast<int>(w_dims[0]));
flag_gemm_ = check_fc_use_gemm<PType, OutType>(
m_, param.weight_scale, param.bias != nullptr);
if (!flag_trans_weights_ && !flag_gemm_) {
flag_trans_weights_ = true;
fc_trans_weights<PType>(*param.w, &weights_);
}
}
/// for fp32 kernel
template <>
void FcCompute<PRECISION(kFloat), PRECISION(kFloat)>::PrepareForRun() {
......@@ -71,8 +153,8 @@ void FcCompute<PRECISION(kInt8), PRECISION(kInt8)>::PrepareForRun() {
/// update bias
if (param.bias) {
bias_.Resize(param.bias->dims());
    auto* ptr = bias_.mutable_data<float>();
    auto* ptr_in = bias_.data<float>();
float out_scale = param.output_scale;
for (int i = 0; i < bias_.numel(); ++i) {
ptr[i] = ptr_in[i] / out_scale;
......@@ -86,9 +168,9 @@ void FcCompute<PRECISION(kFloat), PRECISION(kFloat)>::Run() {
auto& param = this->Param<operators::FcParam>();
auto& ctx = this->ctx_->template As<ARMContext>();
  auto* i_data = param.input->data<float>();
  auto* o_data = param.output->mutable_data<float>();
  auto* w_data = flag_gemm_ ? param.w->data<float>() : weights_.data<float>();
const float* b_data = param.bias ? param.bias->data<float>() : nullptr;
if (flag_trans_bias_) {
b_data = bias_.data<float>();
......@@ -125,8 +207,8 @@ void FcCompute<PRECISION(kFloat), PRECISION(kFloat)>::Run() {
}
} else {
for (int i = 0; i < m_; ++i) {
      auto* i_data_batch = i_data + i * k_;
      auto* o_data_batch = o_data + i * n_;
lite::arm::math::sgemv(w_data,
i_data_batch,
o_data_batch,
......@@ -147,9 +229,10 @@ void FcCompute<PRECISION(kInt8), PRECISION(kFloat)>::Run() {
auto& param = this->Param<operators::FcParam>();
auto& ctx = this->ctx_->template As<ARMContext>();
  auto* i_data = param.input->data<int8_t>();
  auto* o_data = param.output->mutable_data<float>();
  auto* w_data =
      flag_trans_weights_ ? weights_.data<int8_t>() : param.w->data<int8_t>();
const float* b_data = param.bias ? param.bias->data<float>() : nullptr;
if (flag_trans_bias_) {
b_data = bias_.data<float>();
......@@ -182,8 +265,8 @@ void FcCompute<PRECISION(kInt8), PRECISION(kFloat)>::Run() {
}
} else {
for (int i = 0; i < m_; ++i) {
      auto* i_data_batch = i_data + i * k_;
      auto* o_data_batch = o_data + i * n_;
lite::arm::math::gemv_int8(w_data,
i_data_batch,
o_data_batch,
......@@ -205,9 +288,10 @@ void FcCompute<PRECISION(kInt8), PRECISION(kInt8)>::Run() {
auto& param = this->Param<operators::FcParam>();
auto& ctx = this->ctx_->template As<ARMContext>();
  auto* i_data = param.input->data<int8_t>();
  auto* o_data = param.output->mutable_data<int8_t>();
  auto* w_data =
      flag_trans_weights_ ? weights_.data<int8_t>() : param.w->data<int8_t>();
const float* b_data = param.bias ? param.bias->data<float>() : nullptr;
if (flag_trans_bias_) {
b_data = bias_.data<float>();
......@@ -240,8 +324,8 @@ void FcCompute<PRECISION(kInt8), PRECISION(kInt8)>::Run() {
&ctx);
} else {
for (int i = 0; i < m_; ++i) {
      auto* i_data_batch = i_data + i * k_;
      auto* o_data_batch = o_data + i * n_;
lite::arm::math::gemv_int8(w_data,
i_data_batch,
o_data_batch,
......
......@@ -24,92 +24,12 @@ namespace lite {
namespace kernels {
namespace arm {
template <typename Dtype>
void naive_transpose(const Dtype* din, Dtype* dout, int m, int n) {
int k = 0;
for (int i = 0; i < n; ++i) {
for (int j = 0; j < m; ++j) {
dout[k++] = din[j * n + i];
}
}
}
template <PrecisionType PType>
void fc_trans_weights(const Tensor& tin, Tensor* tout);
template <>
void fc_trans_weights<PRECISION(kFloat)>(const Tensor& tin, Tensor* tout) {
CHECK_EQ(tin.dims().size(), 2) << "fc weights size must = 2";
int m = tin.dims()[0];
int n = tin.dims()[1];
tout->Resize({n, m});
auto ptr_in = tin.data<float>();
auto ptr_out = tout->mutable_data<float>();
naive_transpose(ptr_in, ptr_out, m, n);
}
template <>
void fc_trans_weights<PRECISION(kInt8)>(const Tensor& tin, Tensor* tout) {
CHECK_EQ(tin.dims().size(), 2) << "fc weights size must = 2";
int m = tin.dims()[0];
int n = tin.dims()[1];
tout->Resize({n, m});
auto ptr_in = tin.data<int8_t>();
auto ptr_out = tout->mutable_data<int8_t>();
naive_transpose(ptr_in, ptr_out, m, n);
}
template <PrecisionType PType, PrecisionType OutType>
bool check_fc_use_gemm(int m, const std::vector<float>& scale, bool has_bias) {
return m > 1;
}
template <>
bool check_fc_use_gemm<PRECISION(kInt8), PRECISION(kFloat)>(
int m, const std::vector<float>& scale, bool has_bias) {
CHECK(scale.size() > 0) << "Int8 FC param must has weight_scale";
return m > 1 && scale.size() == 1;
}
template <>
bool check_fc_use_gemm<PRECISION(kInt8), PRECISION(kInt8)>(
int m, const std::vector<float>& scale, bool has_bias) {
CHECK(scale.size() > 0) << "Int8 FC param must has weight_scale";
return m > 1 && scale.size() == 1 && !has_bias;
}
template <PrecisionType PType, PrecisionType OutType>
class FcCompute : public KernelLite<TARGET(kARM), PType> {
public:
using param_t = operators::FcParam;
virtual void ReInitWhenNeeded() {
auto& param = this->template Param<operators::FcParam>();
auto x_dims = param.input->dims();
if (last_shape_ == x_dims) {
return;
}
last_shape_ = x_dims;
auto w_dims = param.w_dims;
auto& ctx = this->ctx_->template As<ARMContext>();
CHECK_GE(x_dims.size(), 2UL);
CHECK_EQ(w_dims.size(), 2UL);
CHECK_GE(param.output->dims().size(), 2UL);
m_ = x_dims.Slice(0, param.in_num_col_dims).production();
k_ = x_dims.Slice(param.in_num_col_dims, x_dims.size()).production();
n_ = w_dims[1];
flag_gemm_ = check_fc_use_gemm<PType, OutType>(
m_, param.weight_scale, param.bias != nullptr);
if (flag_trans_weights_ == flag_gemm_) {
flag_trans_weights_ = !flag_trans_weights_;
Tensor tmp_tensor;
fc_trans_weights<PType>(*param.w, &tmp_tensor);
param.w->CopyDataFrom(tmp_tensor);
}
}
virtual void ReInitWhenNeeded();
virtual void PrepareForRun();
virtual void Run();
......@@ -117,6 +37,7 @@ class FcCompute : public KernelLite<TARGET(kARM), PType> {
private:
DDim last_shape_;
Tensor weights_;
Tensor bias_;
bool flag_trans_weights_{false};
bool flag_trans_bias_{false};
......
......@@ -3,3 +3,4 @@ add_subdirectory(math)
add_subdirectory(cv)
add_subdirectory(cv/anakin)
add_subdirectory(api)
add_subdirectory(benchmark)
if((NOT LITE_WITH_OPENCL AND NOT LITE_WITH_FPGA AND NOT LITE_WITH_MLU AND NOT LITE_WITH_XPU) AND (LITE_WITH_ARM))
lite_cc_test(get_conv_latency SRCS src/get_conv_latency.cc DEPS arena_framework ${arm_kernels} ${lite_ops} ${host_kernels})
lite_cc_test(get_batchnorm_latency SRCS src/get_batchnorm_latency.cc DEPS ${arm_kernels} ${lite_ops} ${host_kernels})
lite_cc_test(get_pooling_latency SRCS src/get_pooling_latency.cc DEPS ${arm_kernels} ${lite_ops} ${host_kernels})
lite_cc_test(get_fc_latency SRCS src/get_fc_latency.cc DEPS ${arm_kernels} ${lite_ops} ${host_kernels})
lite_cc_test(get_activation_latency SRCS src/get_activation_latency.cc DEPS ${arm_kernels} ${lite_ops} ${host_kernels})
endif()
# How to run
```shell
cd Paddle-Lite/lite/tests/benchmark
./build_benchmark_ops.sh  # pushes all unit-test executables under the build directory to the phone
# build_benchmark_ops.sh then runs:
#   python get_latency_lookup_table.py --ops_path ops.txt --latency_lookup_table_path latency_lookup_table.txt
# ops.txt is the input network description file; latency_lookup_table.txt is the
# per-op latency report produced after running the Lite unit tests.
```
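Under the hood, the Python driver runs each pushed binary over adb and scrapes the printed latency lines, mirroring `get_op_latency` in get_latency_lookup_table.py; below is a trimmed sketch with example arguments (an attached device and the pushed binaries are assumed):

```python
import subprocess

# Mirrors get_op_latency('android'): run one benchmark binary on the device
# and scrape its "Avg Latency" output line. The arguments here are only an
# example: batch=1, channel=8, h=w=64, act_type=1 (relu), 1 thread,
# power_mode=0, 5 warmup runs, 100 timed runs.
cmd = ('adb shell "cd /data/local/tmp/bin && '
       './get_activation_latency 1 8 64 64 1 1 0 5 100"')
out = subprocess.check_output(cmd, shell=True).decode()
avg = [line for line in out.split('\n') if 'Avg Latency' in line][-1]
print(avg)  # e.g. "Avg Latency is 0.011"
```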
# Input ops.txt format
-- op_name [dim0 dim1 dim2 dim3] (op_param0, op_param1, ..., dtype=xxx)
Each line of ops.txt has three fields: the first is op_name, the second is the input tensor's input_dims,
and the third, enclosed in (), describes the op's parameters.
# Note: the three fields are separated by tabs, the sub-fields inside the parameter list are separated by commas,
# and the numbers inside the [] describing tensor dimensions are separated by spaces; commas and tabs are not allowed there.
op_name currently supports the values conv/activation/batchnorm/pooling/fc;
input_dims describes the input tensor layout; 4D NCHW and similar tensor formats are supported;
op_param0, op_param1, etc. describe the op's parameters; for example, a conv op has ch_out/stride/group/kernel/pad/dilation/flag_bias/flag_act attributes;
dtype is the data type the op runs in; legal values are float/int8_float/int8_int8. conv currently supports all three data types, while the other ops support only float.
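To make the separators concrete, here is a minimal parsing sketch following exactly these rules; `parse_ops_line` is an illustrative helper, not part of the shipped scripts:

```python
def parse_ops_line(line):
    """Split one ops.txt line into (op_name, input_dims, params)."""
    # The three fields are tab-separated: op_name, input_dims, (param list).
    op_name, dims_field, param_field = line.rstrip('\n').split('\t')
    # Dims inside [] are space-separated integers.
    input_dims = [int(d) for d in dims_field.strip('[]').split()]
    # Parameters are comma-separated key=value pairs inside ().
    params = dict(kv.strip().split('=', 1)
                  for kv in param_field.strip('()').split(','))
    return op_name, input_dims, params

# Example:
# parse_ops_line('activation\t[1 8 64 64]\t(act_type=relu)')
# -> ('activation', [1, 8, 64, 64], {'act_type': 'relu'})
```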
# conv op format
conv [1 96 112 112] (ch_out=48, stride=[1 1], group=1, kernel=1x1, pad=[0 0 0 0], dilation=[1 1], flag_bias=0, flag_act=0, dtype=float)
ch_out is the number of output channels; kernel is the convolution kernel size, with legal values such as 1x1/3x3/5x5; pad is the boundary padding given as [top bottom left right]; stride and dilation are given as [h w]; flag_bias says whether the conv has a bias; flag_act says whether an activation is fused, with legal values 0/1/2/4.
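The output_dims reported for a conv line follow the standard convolution shape arithmetic, the same formula `compute_out_dim` implements in the C++ benchmarks; a minimal sketch (the helper name `conv_out_hw` is illustrative):

```python
def conv_out_hw(h, w, kernel, stride, pad, dilation):
    """Conv output height/width given [h w]-style stride/dilation and
    [top bottom left right] padding; mirrors compute_out_dim."""
    kh, kw = kernel
    ext_h = dilation[0] * (kh - 1) + 1  # effective kernel extent
    ext_w = dilation[1] * (kw - 1) + 1
    hout = (h + pad[0] + pad[1] - ext_h) // stride[0] + 1
    wout = (w + pad[2] + pad[3] - ext_w) // stride[1] + 1
    return hout, wout

# conv [1 96 112 112] with kernel=1x1, stride=[1 1], pad=[0 0 0 0]:
# conv_out_hw(112, 112, (1, 1), (1, 1), (0, 0, 0, 0), (1, 1)) -> (112, 112)
```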
# activation op format
activation [1 8 64 64] (act_type=relu)
act_type is the activation function type; legal values are relu/relu6/leaky_relu/sigmoid/tanh/swish/exp/abs/hard_swish/reciprocal/threshold_relu.
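Each act_type name is translated by get_latency_lookup_table.py into the integer code that the get_activation_latency binary expects; collected here as a reference dict:

```python
# act_type name -> integer code passed to get_activation_latency
# (taken from the mapping in get_latency_lookup_table.py)
ACT_TYPE_CODES = {
    'relu': 1, 'relu6': 2, 'leaky_relu': 4, 'sigmoid': 5,
    'tanh': 6, 'swish': 7, 'exp': 8, 'abs': 9,
    'hard_swish': 10, 'reciprocal': 11, 'threshold_relu': 12,
}
```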
# batchnorm op format
batchnorm [1 8 64 64] (epsilon=1e-4f, momentum=0.9f)
epsilon is the epsilon parameter of batchnorm, with default value 1e-4f;
momentum is the momentum parameter of batchnorm, with default value 0.9f.
# pooling op format
pooling [1 8 64 64] (stride=[2 2], pad=[0 0 0 0], kernel=2x2, ceil_mode=0, flag_global=0, exclusive=1, pooling_type=max)
stride is the pooling stride given as [h w], defaulting to [2 2]; pad is the boundary padding given as [top bottom left right], defaulting to [0 0 0 0];
kernel is the pooling kernel size, commonly 2x2 (the default);
ceil_mode says whether pooling uses ceil rounding: 0 means false (the default), any other value means true;
flag_global says whether pooling is global over the WxH dimensions: 0 means false (the default), any other value means true;
exclusive is the exclusive flag of the pooling op: 1 means true (the default), any other value means false;
pooling_type is the pooling type; legal values are max (the default) and avg.
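The pooling output size uses the floor or ceil variant of the same arithmetic, matching `compute_out_dim` in get_pooling_latency.cc; a sketch (`pool_out_hw` is an illustrative name):

```python
def pool_out_hw(h, w, kernel, stride, pad, ceil_mode, flag_global):
    """Pooling output height/width; mirrors compute_out_dim in
    get_pooling_latency.cc."""
    if flag_global:
        return 1, 1  # global pooling collapses H and W
    kh, kw = kernel
    # ceil_mode adds (stride - 1) before the floor division.
    extra_h, extra_w = (stride[0] - 1, stride[1] - 1) if ceil_mode else (0, 0)
    hout = (h - kh + pad[0] + pad[1] + extra_h) // stride[0] + 1
    wout = (w - kw + pad[2] + pad[3] + extra_w) // stride[1] + 1
    return hout, wout

# pooling [1 8 64 64] with kernel=2x2, stride=[2 2], pad=[0 0 0 0]:
# pool_out_hw(64, 64, (2, 2), (2, 2), (0, 0, 0, 0), False, False) -> (32, 32)
```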
# fc op format
fc [1 64] (flag_bias=1, param_dim=64x1000)
flag_bias says whether the fc op has a bias: 1 (the default) means true, any other value means false;
param_dim is the `k x n` weight shape of the fc op, where k must match the k in input_dims=[m k] (see the sketch below).
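An fc line therefore describes an [m k] x (k x n) matrix multiply; a small sketch of the resulting output shape (`fc_out_dims` is an illustrative name):

```python
def fc_out_dims(input_dims, param_dim):
    """An [m k] input with a k x n weight yields an [m n] output."""
    m, k = input_dims
    wk, n = (int(v) for v in param_dim.split('x'))
    assert k == wk, 'k in input_dims must match k in param_dim'
    return [m, n]

# fc [1 64] with param_dim=64x1000 -> output_dims [1 1000]
```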
# Output latency_lookup_table.txt format
dev_info arm_v7/v8 core_num thread_num power_mode core0 arch core1 arch core2 arch core3 arch core4 arch core5 arch core6 arch core7 arch
Hisilicon Kirin980 armv8 8 1 0 ARM_A55 ARM_A55 ARM_A55 ARM_A55 ARM_A76 ARM_A76 ARM_A76 ARM_A76
op_name input_dims output_dims param_info min_latency(ms) max_latency(ms) avg_latency(ms)
conv [1 96 112 112] [1 48 112 112] (ch_out=48, stride=[1 1], pad=[0 0 0 0], kernel=1x1, group=1, dilation=[1 1], flag_bias=0, flag_act=0, dtype=float) 3.469 4.111 3.52088
fc [1 64] [1 1000] (param_dim=64x1000, flag_bias=1, dtype=float) 0.135 0.176 0.13779
batchnorm [1 8 64 64] [1 8 64 64] (epsilon=1e-4f, momentum=0.9f, dtype=float) 0.014 0.178 0.01679
pooling [1 8 64 64] [1 8 32 32] (stride=[2 2], pad=[0 0 0 0], kernel=2x2, ceil_mode=0, flag_global=0, exclusive=0, pooling_type=avg, dtype=float) 0.009 0.011 0.00983
activation [1 8 64 64] [1 8 64 64] (act_type=relu, dtype=float) 0.01 0.036 0.01103
-- The first part is the header, with the fields `dev_info` `arm_v7/v8` `core_num` `thread_num` `power_mode` `core0 arch` ... `core7 arch`:
`dev_info` is the vendor/model of the handset hardware, `arm_v7/v8` says whether the armv7 or armv8 architecture is used, `core_num` is the number of CPU cores, `thread_num` is the number of threads configured for the run,
`power_mode` is the CPU core-binding mode,
and `core0 arch` ... `core7 arch` give the architecture of each ARM CPU core.
The second part is the op information, with the fields `op_name` `input_dims` `output_dims` `param_info` `min_latency` `max_latency` `avg_latency`:
`output_dims` is the output tensor shape computed for the op from `input_dims` and `param_info`;
`min_latency(ms)` `max_latency(ms)` `avg_latency(ms)` are the min/max/avg latencies measured when the op is run.
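A typical consumer of this table (for example a latency predictor for architecture search) only needs the per-op average latency. Below is a minimal parsing sketch, assuming the two device-info rows and one op header row produced above; `load_avg_latency` is an illustrative name, not part of the shipped scripts:

```python
def load_avg_latency(path='latency_lookup_table.txt'):
    """Map (op_name, input_dims, param_info) -> avg latency in ms."""
    table = {}
    with open(path) as f:
        # Skip the device header row, the device row, and the op header row.
        for row in f.read().splitlines()[3:]:
            fields = row.split('\t')
            key = (fields[0].strip(), fields[1].strip(), fields[3].strip())
            table[key] = float(fields[6])  # avg_latency(ms) column
    return table
```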
#!/usr/bin/env bash
exe_dir="/data/local/tmp/bin"
work_dir=$(pwd)
os=android
abi=armv8
lang=gcc
function print_usage {
echo "----------------------------------------"
echo -e " ./push2device.sh --arm_os=<os> --arm_abi=<abi> --arm_lang=<lang>"
echo -e "--arm_os:\t android, only support android now"
echo -e "--arm_abi:\t armv8|armv7"
echo -e "--arm_lang:\t gcc|clang"
echo -e "make sure directory: PaddleLite/build.lite.${arm_os}.${arm_abi}.${arm_lang} exsits!"
echo "----------------------------------------"
}
function main {
for i in "$@"; do
case $i in
--arm_os=*)
os="${i#*=}"
shift
;;
--arm_abi=*)
abi="${i#*=}"
shift
;;
--arm_lang=*)
lang="${i#*=}"
shift
;;
*)
print_usage
exit 1
;;
esac
done
build_dir=$work_dir/../../../build.lite.${os}.${abi}.${lang}
lib_path=$build_dir/lite/tests/benchmark
lib_files=$lib_path/get*latency
adb shell mkdir -p ${exe_dir}
for file in ${lib_files}
do
adb push ${file} ${exe_dir}
done
}
main "$@"
python get_latency_lookup_table.py --arm_v7_v8 ${abi}
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import sys
import re
import argparse
import subprocess
def get_args():
"""Get arguments.
Returns:
Namespace, arguments.
"""
parser = argparse.ArgumentParser(description=__doc__)
parser.add_argument('--ops_path', default='ops.txt', help='Input ops path.')
parser.add_argument(
'--latency_lookup_table_path',
default='latency_lookup_table.txt',
help='Output ops latency path.')
parser.add_argument(
'--platform', default='android', help='Platform: android/ios/custom.')
parser.add_argument('--threads', type=int, default=1, help='Threads.')
parser.add_argument('--power_mode', type=int, default=0, help='PowerMode.')
parser.add_argument('--warmup_times', type=int, default=5,
help='Warm up times of op when estimating latency.')
parser.add_argument('--repeats_times', type=int, default=100,
help='Running times of op when estimating latency.')
parser.add_argument('--arm_v7_v8', type=str, default='armv8',
help='Indicate arm architecture v7 or v8.')
args = parser.parse_args()
return args
def check_dev_connect():
cmd = 'adb devices | grep device'
dev_info = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, shell=True)
out = dev_info.communicate()[0]
res = out.decode().find("\tdevice")
if res == -1:
print("No android device is attached")
sys.exit()
def get_dev_info():
cmd = 'adb shell "cat /proc/cpuinfo | grep Hardware"'
dev_info = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, shell=True)
out = dev_info.communicate()[0]
out = out.decode().strip('\n')
dev_info = out.split(':', 1)[-1].strip()
cmd = 'adb shell "cat /proc/cpuinfo | grep part"'
cpu_info = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, shell=True)
out = cpu_info.communicate()[0]
out = (out.decode().strip('\n').split('\n'))
core_num = len(out)
arch_type = ['UNKNOWN CPU ARCH']*core_num
for i, v in enumerate(out):
out = v.split(':', 1)[-1].strip()
if out == '0xd03':
arch_type[i] = 'ARM_A53'
elif out == '0xd05':
arch_type[i] = 'ARM_A55'
elif out == '0xd07':
arch_type[i] = 'ARM_A57'
elif out == '0xd08':
arch_type[i] = 'ARM_A72'
elif out == '0xd09':
arch_type[i] = 'ARM_A73'
elif out == '0xd0a':
arch_type[i] = 'ARM_A75'
elif out == '0xd40':
arch_type[i] = 'ARM_A76'
elif out == '0x804':
# 855
arch_type[i] = 'ARM_A76'
elif out == '0x805':
# 855
arch_type[i] = 'ARM_A55'
elif out == '0x802':
# 845
arch_type[i] = 'ARM_A75'
elif out == '0x803':
# 845
arch_type[i] = 'ARM_A55'
elif out == '0x801':
# 835
arch_type[i] = 'ARM_A73'
elif out == '0x800':
# 835
arch_type[i] = 'ARM_A73'
elif out == '0x205':
# 820
arch_type[i] = 'ARM_A72'
else:
arch_type[i] = 'UNKNOWN CPU ARCH'
return dev_info, core_num, arch_type
def get_op_latency(op, platform):
"""Get model latency.
Args:
op: list, a list of str represents the op and its parameters.
platform: str, platform name.
Returns:
float, op latency.
"""
if platform == 'android':
commands = 'adb shell "cd /data/local/tmp/bin && ./get_{}_latency {}"'.format(
op[0], ' '.join(op[1:]))
proc = subprocess.Popen(
commands,
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
shell=True)
out = proc.communicate()[0]
avg_out = [_ for _ in out.decode().split('\n') if 'Avg Latency' in _][-1]
avg_out = re.findall(r'\d+\.?\d*', avg_out)[0]
avg_out = float(avg_out)
min_out = [_ for _ in out.decode().split('\n') if 'Min Latency' in _][-1]
min_out = re.findall(r'\d+\.?\d*', min_out)[0]
min_out = float(min_out)
max_out = [_ for _ in out.decode().split('\n') if 'Max Latency' in _][-1]
max_out = re.findall(r'\d+\.?\d*', max_out)[0]
max_out = float(max_out)
elif platform == 'ios':
print('ios platform is not supported now')
sys.exit()
else:
print('Please define `get_op_latency` for {} platform'.format(platform))
sys.exit()
return avg_out, min_out, max_out
def main():
args = get_args()
check_dev_connect()
conv_param_dict = {'ch_out': '1', 'stride':'[1 1]', 'pad':'[0 0 0 0]', 'kernel':'3x3',
'group':'1', 'dilation':'[1 1]', 'flag_bias':'1',
'flag_act':'0', 'dtype':'float'}
batchnorm_param_dict = {'epsilon':'1e-4f', 'momentum':'0.9f',
'dtype':'float'}
pooling_param_dict = {'stride':'[2 2]', 'pad':'[0 0 0 0]', 'kernel':'2x2', 'ceil_mode':'0',
                      'flag_global':'0', 'exclusive':'1', 'pooling_type': 'max',
                      'dtype':'float'}
activation_param_dict = {'act_type':'relu', 'dtype':'float'}
fc_param_dict = {'param_dim':'1x1','flag_bias':'1', 'dtype':'float'}
op_info = {}
cur_op_name = ''
cur_param_dict = {}
input_dims = ''
output_dims = ''
runtime_cmd = []
fid = open(args.ops_path, 'r')
handle = open(args.latency_lookup_table_path, 'w')
handle.write('{}\t{}\t{}\t{}\t{}\t{}\t{}\t{}\t{}\t{}\t{}\t{}\t{}\n'.format('dev_info'.ljust(30), 'armv7/v8'.ljust(10), 'core_num'.ljust(10), 'thread_num'.ljust(10), 'power_mode'.ljust(10), 'core0 arch'.ljust(10), 'core1 arch'.ljust(10),
'core2 arch'.ljust(10), 'core3 arch'.ljust(10), 'core4 arch'.ljust(10), 'core5 arch'.ljust(10),
'core6 arch'.ljust(10), 'core7 arch'.ljust(10)))
dev_info, core_num, arch_type = get_dev_info()
handle.write('{}\t{}\t{}\t{}\t{}'.format(dev_info.ljust(30), str(args.arm_v7_v8).ljust(10), str(core_num).ljust(10), str(args.threads).ljust(10), str(args.power_mode).ljust(10)))
for i in arch_type:
handle.write('\t{}'.format(i.ljust(10)))
handle.write('\n')
handle.write('{}\t{}\t{}\t{}\t{}\t{}\t{}\n'.format('op_name'.ljust(10), 'input_dims'.ljust(10), 'output_dims'.ljust(10), 'param_info'.ljust(80), 'min_latency(ms)'.ljust(10), 'max_latency(ms)'.ljust(10), 'avg_latency(ms)'.ljust(10)))
for line in fid.readlines():
line = [line.strip('\n')]
for data_item in line:
data_item = data_item.strip().split('\t')
cur_op_name = data_item[0]
input_dims = data_item[1]
parameters = data_item[2].strip('( )').split(',')
for item_ in parameters:
item_ = item_.strip().split('=')
# conv op dict
if cur_op_name == 'conv':
cur_param_dict = conv_param_dict
if item_[0] == 'ch_out':
cur_param_dict['ch_out'] = item_[1]
elif item_[0] == 'stride':
cur_param_dict['stride'] = item_[1]
elif item_[0] == 'pad':
cur_param_dict['pad'] = item_[1]
elif item_[0] == 'kernel':
cur_param_dict['kernel'] = item_[1]
elif item_[0] == 'group':
cur_param_dict['group'] = item_[1]
elif item_[0] == 'dilation':
cur_param_dict['dilation'] = item_[1]
elif item_[0] == 'flag_bias':
cur_param_dict['flag_bias'] = item_[1]
elif item_[0] == 'flag_act':
cur_param_dict['flag_act'] = item_[1]
elif item_[0] == 'dtype':
cur_param_dict['dtype'] = item_[1]
#batchnorm op dict
elif cur_op_name == 'batchnorm':
cur_param_dict = batchnorm_param_dict
if item_[0] == 'epsilon':
cur_param_dict['epsilon'] = item_[1]
elif item_[0] == 'momentum':
cur_param_dict['momentum'] = item_[1]
#pooling op dict
elif cur_op_name == 'pooling':
cur_param_dict = pooling_param_dict
if item_[0] == 'stride':
cur_param_dict['stride'] = item_[1]
elif item_[0] == 'pad':
cur_param_dict['pad'] = item_[1]
elif item_[0] == 'kernel':
cur_param_dict['kernel'] = item_[1]
elif item_[0] == 'ceil_mode':
cur_param_dict['ceil_mode'] = item_[1]
elif item_[0] == 'flag_global':
cur_param_dict['flag_global'] = item_[1]
elif item_[0] == 'exclusive':
cur_param_dict['exclusive'] = item_[1]
elif item_[0] == 'pooling_type':
cur_param_dict['pooling_type'] = item_[1]
#activation op dict
elif cur_op_name == 'activation':
cur_param_dict = activation_param_dict
if item_[0] == 'act_type':
cur_param_dict['act_type'] = item_[1]
# fc op dict
elif cur_op_name == 'fc':
cur_param_dict = fc_param_dict
if item_[0] == 'param_dim':
cur_param_dict['param_dim'] = item_[1]
elif item_[0] == 'flag_bias':
cur_param_dict['flag_bias'] = item_[1]
elif item_[0] == 'dtype':
cur_param_dict['dtype'] = 'float'
op_info[cur_op_name] = cur_param_dict
if cur_op_name == 'conv':
batch = input_dims.strip('[' ']').split()[0]
in_ch = input_dims.strip('[' ']').split()[1]
height = input_dims.strip('[' ']').split()[2]
width = input_dims.strip('[' ']').split()[3]
out_ch = cur_param_dict['ch_out']
pad_top = cur_param_dict['pad'].strip('[' ']').split()[0]
pad_bottom = cur_param_dict['pad'].strip('[' ']').split()[1]
pad_left = cur_param_dict['pad'].strip('[' ']').split()[2]
pad_right = cur_param_dict['pad'].strip('[' ']').split()[3]
dila_h = cur_param_dict['dilation'].strip('[' ']').split()[0]
dila_w = cur_param_dict['dilation'].strip('[' ']').split()[1]
kernel_h = cur_param_dict['kernel'][0]
kernel_w = cur_param_dict['kernel'][2]
stride_h = cur_param_dict['stride'].strip('[' ']').split()[0]
stride_w = cur_param_dict['stride'].strip('[' ']').split()[1]
hout = (int(height) + int(pad_top) + int(pad_bottom) - int(dila_h) *
        (int(kernel_h) - 1) - 1) / int(stride_h) + 1
wout = (int(width) + int(pad_left) + int(pad_right) - int(dila_w) *
        (int(kernel_w) - 1) - 1) / int(stride_w) + 1
output_dims = '[' + str(batch) + ' ' + str(out_ch) + ' ' + str(int(hout)) + ' ' + str(int(wout)) + ']'
dtype = 0
if cur_param_dict['dtype'] == 'float':
dtype = 0
elif cur_param_dict['dtype'] == 'int8_float':
dtype = 1
elif cur_param_dict['dtype'] == 'int8_int8':
dtype = 2
runtime_cmd = [str(batch), str(in_ch), str(height), str(width), str(out_ch),
str(cur_param_dict['group']), str(cur_param_dict['kernel'])[0],
str(pad_top), str(pad_bottom),
str(pad_left), str(pad_right),
str(stride_h), str(stride_w),
str(dila_h), str(dila_w),
str(cur_param_dict['flag_bias']), str(cur_param_dict['flag_act']),
str(dtype)]
elif cur_op_name == 'batchnorm':
batch = input_dims.strip('[' ']').split()[0]
in_ch = input_dims.strip('[' ']').split()[1]
height = input_dims.strip('[' ']').split()[2]
width = input_dims.strip('[' ']').split()[3]
output_dims = input_dims
runtime_cmd = [str(batch), str(in_ch), str(height), str(width),
str(cur_param_dict['epsilon']), str(cur_param_dict['momentum'])]
elif cur_op_name == 'pooling':
batch = input_dims.strip('[' ']').split()[0]
in_ch = input_dims.strip('[' ']').split()[1]
height = input_dims.strip('[' ']').split()[2]
width = input_dims.strip('[' ']').split()[3]
hout = 1
wout = 1
pad_top = cur_param_dict['pad'].strip('[' ']').split()[0]
pad_bottom = cur_param_dict['pad'].strip('[' ']').split()[1]
pad_left = cur_param_dict['pad'].strip('[' ']').split()[2]
pad_right = cur_param_dict['pad'].strip('[' ']').split()[3]
kernel_h = cur_param_dict['kernel'][0]
kernel_w = cur_param_dict['kernel'][2]
stride_h = cur_param_dict['stride'].strip('[' ']').split()[0]
stride_w = cur_param_dict['stride'].strip('[' ']').split()[1]
if cur_param_dict['flag_global'] == '0':
if cur_param_dict['ceil_mode'] == '0':
hout = (int(height) - int(kernel_h) + int(pad_top) + int(pad_bottom)) / int(stride_h) + 1
wout = (int(width) - int(kernel_w) + int(pad_left) + int(pad_right)) / int(stride_w) + 1
else:
hout = (int(height) - int(kernel_h) + int(pad_top) + int(pad_bottom) + int(stride_h) - 1) / int(stride_h) + 1
wout = (int(width) - int(kernel_w) + int(pad_left) + int(pad_right) + int(stride_w) - 1) / int(stride_w) + 1
output_dims = '[' + batch + ' ' + str(in_ch) + ' ' + str(int(hout)) + ' ' + str(int(wout)) + ']'
pooling_type = 0
if cur_param_dict['pooling_type'] == 'max':
pooling_type = 0
else:
pooling_type = 1
runtime_cmd = [str(batch), str(in_ch), str(height), str(width),
str(stride_h), str(stride_w),
str(pad_top), str(pad_bottom),
str(pad_left), str(pad_right),
str(cur_param_dict['kernel'])[0], str(cur_param_dict['ceil_mode']),
str(cur_param_dict['flag_global']), str(cur_param_dict['exclusive']),
str(pooling_type)]
elif cur_op_name == 'activation':
batch = input_dims.strip('[' ']').split()[0]
in_ch = input_dims.strip('[' ']').split()[1]
height = input_dims.strip('[' ']').split()[2]
width = input_dims.strip('[' ']').split()[3]
act_type = 1
if cur_param_dict['act_type'] == 'relu':
act_type = 1
elif cur_param_dict['act_type'] == 'relu6':
act_type = 2
elif cur_param_dict['act_type'] == 'leaky_relu':
act_type = 4
elif cur_param_dict['act_type'] == 'sigmoid':
act_type = 5
elif cur_param_dict['act_type'] == 'tanh':
act_type = 6
elif cur_param_dict['act_type'] == 'swish':
act_type = 7
elif cur_param_dict['act_type'] == 'exp':
act_type = 8
elif cur_param_dict['act_type'] == 'abs':
act_type = 9
elif cur_param_dict['act_type'] == 'hard_swish':
act_type = 10
elif cur_param_dict['act_type'] == 'reciprocal':
act_type = 11
elif cur_param_dict['act_type'] == 'threshold_relu':
act_type = 12
output_dims = input_dims
runtime_cmd = [str(batch), str(in_ch), str(height), str(width),
str(act_type)]
elif cur_op_name == 'fc':
m = input_dims.strip('[' ']').split()[0]
k = input_dims.strip('[' ']').split()[1]
n = cur_param_dict['param_dim'].split('x')[1]
output_dims = '[' + m + ' ' + n + ']'
runtime_cmd = [str(m), str(n), str(k), str(cur_param_dict['flag_bias']),
str(cur_param_dict['dtype'])]
avg_latency, min_latency, max_latency = get_op_latency([cur_op_name] +
runtime_cmd + [str(args.threads), str(args.power_mode),
str(args.warmup_times), str(args.repeats_times)],
args.platform)
param_dict = ''
for k in cur_param_dict:
param_dict += str(k) + '=' + str(cur_param_dict[k]) + ','
param_dict = '(' + param_dict[:-1] + ')'
handle.write('{}\t{}\t{}\t{}\t{}\t{}\t{}\n'.format(cur_op_name.ljust(10), input_dims.ljust(10), output_dims.ljust(10), param_dict.ljust(80), str(min_latency).ljust(10), str(max_latency).ljust(10), str(avg_latency).ljust(10)))
fid.close()
handle.close()
print('Congratulations! Get Latency LookUp Table is Completed.')
if __name__ == '__main__':
main()
dev_info armv7/v8 core_num thread_num power_mode core0 arch core1 arch core2 arch core3 arch core4 arch core5 arch core6 arch core7 arch
Hisilicon Kirin980 armv8 8 1 0 ARM_A55 ARM_A55 ARM_A55 ARM_A55 ARM_A76 ARM_A76 ARM_A76 ARM_A76
op_name input_dims output_dims param_info min_latency(ms) max_latency(ms) avg_latency(ms)
conv [1 96 112 112] [1 48 112 112] (ch_out=48,stride=[1 1],pad=[0 0 0 0],kernel=1x1,group=1,dilation=[1 1],flag_bias=0,flag_act=0,dtype=float) 3.472 5.384 3.97393
fc [4 8] [4 1000] (param_dim=8x1000,flag_bias=1,dtype=float) 0.009 0.023 0.00951
batchnorm [1 8 64 64] [1 8 64 64] (epsilon=1e-4f,momentum=0.9f,dtype=float) 0.01 0.012 0.0114
pooling [1 8 64 64] [1 8 32 32] (stride=[2 2],pad=[0 0 0 0],kernel=2x2,ceil_mode=0,flag_global=0,exclusive=0,pooling_type=avg,dtype=float) 0.009 0.01 0.00969
activation [1 8 64 64] [1 8 64 64] (act_type=relu,dtype=float) 0.01 0.028 0.01098
conv [1 96 112 112] (ch_out=48, stride=[1 1], group=1, kernel=1x1, pad=[0 0 0 0], dilation=[1 1], flag_bias=0, flag_act=0, dtype=float)
fc [4 8] (flag_bias=1, param_dim=8x1000)
batchnorm [1 8 64 64] (epsilon=1e-4f, momentum=0.9f)
pooling [1 8 64 64] (stride=[2 2], kernel=2x2, pad=[0 0 0 0], exclusive=0, pooling_type=avg)
activation [1 8 64 64] (act_type=relu)
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <stdlib.h>
#include <iostream>
#include <memory>
#include "lite/core/context.h"
#include "lite/core/profile/timer.h"
#include "lite/core/tensor.h"
#include "lite/kernels/arm/activation_compute.h"
#include "lite/operators/op_params.h"
#include "lite/tests/utils/tensor_utils.h"
typedef paddle::lite::Tensor Tensor;
typedef paddle::lite::DDim DDim;
typedef paddle::lite::operators::ActivationParam ActivationParam;
using paddle::lite::profile::Timer;
#ifdef LITE_WITH_ARM
// Shared harness for every activation kernel type: bind the ARM context,
// prepare the kernel, warm it up, then time `repeats` launches.
template <class ComputeT>
void bench_act(const ActivationParam& act_param,
               int thread_num,
               int power_mode,
               int warmup,
               int repeats,
               Timer* t0) {
  ComputeT act_compute;
  act_compute.SetParam(act_param);
  std::unique_ptr<paddle::lite::KernelContext> ctx1(
      new paddle::lite::KernelContext);
  auto& ctx = ctx1->As<paddle::lite::ARMContext>();
  ctx.SetRunMode(static_cast<paddle::lite_api::PowerMode>(power_mode),
                 thread_num);
  act_compute.SetContext(std::move(ctx1));
  act_compute.PrepareForRun();
  // warm up
  for (int i = 0; i < warmup; ++i) {
    act_compute.Launch();
  }
  // timed runs
  for (int i = 0; i < repeats; ++i) {
    t0->Start();
    act_compute.Launch();
    t0->Stop();
  }
}
#endif  // LITE_WITH_ARM
int main(int argc, char** argv) {
if (argc != 10) {
std::cerr << "usage: " << argv[0] << "\n"
<< " <batch_size>\n"
<< " <input_channel>\n"
<< " <input_height>\n"
<< " <input_width>\n"
<< " <act_type>\n"
<< " <thread_num>\n"
<< " <power_mode>\n"
<< " <warmup_times>\n"
<< " <repeats_times>" << std::endl;
return 0;
}
#ifdef LITE_WITH_ARM
paddle::lite::DeviceInfo::Init();
#endif
int batch_size = atoi(argv[1]);
int input_channel = atoi(argv[2]);
int input_height = atoi(argv[3]);
int input_width = atoi(argv[4]);
int thread_num = atoi(argv[6]);
int power_mode = atoi(argv[7]);
int warmup = atoi(argv[8]);
int repeats = atoi(argv[9]);
int act_type = atoi(argv[5]);
const float six = 6.f;
const float leakey_relu_scale = 8.88f;
#ifdef LITE_WITH_ARM
ActivationParam act_param;
Tensor x, y;
DDim dim_in = DDim({batch_size, input_channel, input_height, input_width});
x.set_precision(PRECISION(kFloat));
x.Resize(dim_in);
paddle::lite::fill_tensor_rand(x, -1.f, 1.f);
act_param.X = &x;
act_param.active_type = (paddle::lite_api::ActivationType)act_type;
act_param.has_active = true;
if (act_type == 2) {
act_param.Relu_clipped_coef = six;
} else if (act_type == 4) {
act_param.Leaky_relu_alpha = leakey_relu_scale;
}
act_param.Out = &y;
act_param.Out->set_precision(PRECISION(kFloat));
act_param.Out->Resize(dim_in);
Timer t0;
  switch (act_type) {
    case 1:
      bench_act<paddle::lite::kernels::arm::ReluCompute>(
          act_param, thread_num, power_mode, warmup, repeats, &t0);
      break;
    case 2:
      bench_act<paddle::lite::kernels::arm::Relu6Compute>(
          act_param, thread_num, power_mode, warmup, repeats, &t0);
      break;
    case 4:
      bench_act<paddle::lite::kernels::arm::LeakyReluCompute>(
          act_param, thread_num, power_mode, warmup, repeats, &t0);
      break;
    case 5:
      bench_act<paddle::lite::kernels::arm::SigmoidCompute>(
          act_param, thread_num, power_mode, warmup, repeats, &t0);
      break;
    case 6:
      bench_act<paddle::lite::kernels::arm::TanhCompute>(
          act_param, thread_num, power_mode, warmup, repeats, &t0);
      break;
    case 7:
      bench_act<paddle::lite::kernels::arm::SwishCompute>(
          act_param, thread_num, power_mode, warmup, repeats, &t0);
      break;
    case 8:
      bench_act<paddle::lite::kernels::arm::ExpCompute>(
          act_param, thread_num, power_mode, warmup, repeats, &t0);
      break;
    case 9:
      bench_act<paddle::lite::kernels::arm::AbsCompute>(
          act_param, thread_num, power_mode, warmup, repeats, &t0);
      break;
    case 10:
      bench_act<paddle::lite::kernels::arm::HardSwishCompute>(
          act_param, thread_num, power_mode, warmup, repeats, &t0);
      break;
    case 11:
      bench_act<paddle::lite::kernels::arm::ReciprocalCompute>(
          act_param, thread_num, power_mode, warmup, repeats, &t0);
      break;
    case 12:
      bench_act<paddle::lite::kernels::arm::ThresholdedReluCompute>(
          act_param, thread_num, power_mode, warmup, repeats, &t0);
      break;
    default:
      break;
  }
printf("Avg Latency is %f\n", t0.LapTimes().Avg());
printf("Min Latency is %f\n", t0.LapTimes().Min());
printf("Max Latency is %f\n", t0.LapTimes().Max());
#endif
return 0;
}
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <stdlib.h>
#include <iostream>
#include <memory>
#include <vector>
#include "lite/core/context.h"
#include "lite/core/profile/timer.h"
#include "lite/core/tensor.h"
#include "lite/kernels/arm/batch_norm_compute.h"
#include "lite/operators/op_params.h"
typedef paddle::lite::Tensor Tensor;
typedef paddle::lite::kernels::arm::BatchNormCompute BatchNormCompute;
using paddle::lite::profile::Timer;
int main(int argc, char** argv) {
if (argc != 11) {
std::cerr << "usage: " << argv[0] << "\n"
<< " <batch_size>\n"
<< " <input_channel>\n"
<< " <input_height>\n"
<< " <input_width>\n"
<< " <epsilon>\n"
<< " <momentum>\n"
<< " <thread_num>\n"
<< " <power_mode>\n"
<< " <warmup_times>\n"
<< " <repeats_times>\n"
<< std::endl;
return 0;
}
#ifdef LITE_WITH_ARM
paddle::lite::DeviceInfo::Init();
#endif
int batch_size = atoi(argv[1]);
int input_channel = atoi(argv[2]);
int input_height = atoi(argv[3]);
int input_width = atoi(argv[4]);
float epsilon = atof(argv[5]);
float momentum = atof(argv[6]);
int thread_num = atoi(argv[7]);
int power_mode = atoi(argv[8]);
int warmup = atoi(argv[9]);
int repeats = atoi(argv[10]);
#ifdef LITE_WITH_ARM
Tensor x;
Tensor scale;
Tensor bias;
Tensor mean;
Tensor variance;
Tensor y;
Tensor mean_out;
Tensor variance_out;
Tensor saved_mean;
Tensor saved_variance;
std::vector<int64_t> in_out_shape = {
batch_size, input_channel, input_height, input_width};
x.Resize(in_out_shape);
scale.Resize({input_channel});
bias.Resize({input_channel});
mean.Resize({input_channel});
variance.Resize({input_channel});
y.Resize(in_out_shape);
mean_out.Resize({input_channel});
variance_out.Resize({input_channel});
saved_mean.Resize({input_channel});
saved_variance.Resize({input_channel});
// initialize the data of input tensors
auto* x_data = x.mutable_data<float>();
auto* scale_data = scale.mutable_data<float>();
auto* bias_data = bias.mutable_data<float>();
auto* mean_data = mean.mutable_data<float>();
auto* variance_data = variance.mutable_data<float>();
for (int i = 0; i < x.dims().production(); i++) {
x_data[i] = static_cast<float>(i % 64);
}
for (int i = 0; i < scale.dims().production(); i++) {
scale_data[i] = static_cast<float>(i) * 0.01f + 0.03f;
}
for (int i = 0; i < bias.dims().production(); i++) {
bias_data[i] = static_cast<float>(i) * 0.065f + 0.1f;
}
for (int i = 0; i < mean.dims().production(); i++) {
mean_data[i] = static_cast<float>(i) * 0.0565f;
}
for (int i = 0; i < variance.dims().production(); i++) {
variance_data[i] = static_cast<float>(i) * 2.08f + 1.5f;
}
// prepare kernel params and run
BatchNormCompute batch_norm;
std::unique_ptr<paddle::lite::KernelContext> ctx1(
new paddle::lite::KernelContext);
auto& ctx = ctx1->As<paddle::lite::ARMContext>();
ctx.SetRunMode(static_cast<paddle::lite_api::PowerMode>(power_mode),
thread_num);
batch_norm.SetContext(std::move(ctx1));
paddle::lite::operators::BatchNormParam param;
param.x = &x;
param.scale = &scale;
param.bias = &bias;
param.mean = &mean;
param.variance = &variance;
param.is_test = false;
param.use_global_stats = true;
param.epsilon = epsilon;
param.momentum = momentum;
param.data_layout = DATALAYOUT(kNCHW);
param.y = &y;
param.mean_out = &mean_out;
param.variance_out = &variance_out;
param.saved_mean = &saved_mean;
param.saved_variance = &saved_variance;
batch_norm.SetParam(param);
// warm up
for (int i = 0; i < warmup; ++i) {
batch_norm.Launch();
}
// compute
Timer t0;
for (int i = 0; i < repeats; ++i) {
t0.Start();
batch_norm.Launch();
t0.Stop();
}
printf("Avg Latency is %f\n", t0.LapTimes().Avg());
printf("Min Latency is %f\n", t0.LapTimes().Min());
printf("Max Latency is %f\n", t0.LapTimes().Max());
#endif
return 0;
}
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <stdlib.h>
#include <iostream>
#include <memory>
#include <vector>
#include "lite/core/context.h"
#include "lite/core/profile/timer.h"
#include "lite/core/tensor.h"
#include "lite/kernels/arm/conv_compute.h"
#include "lite/operators/op_params.h"
#include "lite/tests/utils/tensor_utils.h"
typedef paddle::lite::operators::ConvParam ConvParam;
typedef paddle::lite::Tensor Tensor;
typedef paddle::lite::DDim DDim;
typedef paddle::lite::operators::ActivationParam ActivationParam;
using paddle::lite::profile::Timer;
using paddle::lite_api::PrecisionType;
DDim compute_out_dim(const DDim& dim_in,
const paddle::lite::operators::ConvParam& param) {
DDim dim_out = dim_in;
auto paddings = *param.paddings;
auto dilations = *param.dilations;
dim_out[1] = param.filter->dims()[0];
auto kernel_h = param.filter->dims()[2];
auto kernel_w = param.filter->dims()[3];
auto h = dim_in[2];
auto w = dim_in[3];
int dila_h = dilations[0];
int dila_w = dilations[1];
int pad_top = paddings[0];
int pad_bottom = paddings[1];
int pad_left = paddings[2];
int pad_right = paddings[3];
int stride_h = param.strides[0];
int stride_w = param.strides[1];
auto kernel_exten = dila_h * (kernel_h - 1) + 1;
auto hout = (h + pad_top + pad_bottom - kernel_exten) / stride_h + 1;
kernel_exten = dila_w * (kernel_w - 1) + 1;
auto wout = (w + pad_left + pad_right - kernel_exten) / stride_w + 1;
dim_out[2] = hout;
dim_out[3] = wout;
return dim_out;
}
template <PrecisionType Ptype, PrecisionType OutType>
void test_conv(const DDim& input_dims,
const DDim& weight_dims,
const int group,
const std::vector<int>& strides,
const std::vector<int>& pads,
const std::vector<int>& dilas,
const bool flag_bias,
const int flag_act,
const int thread_num,
const int power_mode,
const int warmup,
const int repeats,
const float leakey_relu_scale = 8.88f) {
ConvParam param;
Tensor x, f, y;
Tensor bias;
param.x = &x;
param.x->set_precision(Ptype);
param.filter = &f;
param.filter->Resize(weight_dims);
param.filter->set_precision(Ptype);
if (flag_bias) {
param.bias = &bias;
param.bias->Resize({weight_dims[0]});
param.bias->set_precision(PRECISION(kFloat));
}
param.strides = strides;
param.paddings = std::make_shared<std::vector<int>>(pads);
param.dilations = std::make_shared<std::vector<int>>(dilas);
param.groups = group;
const float six = 6.f;
if (Ptype == PRECISION(kInt8)) {
std::vector<float> scale_in{1.f / 127};
std::vector<float> scale_out(1, weight_dims.count(1, 4) / 127.f);
if (flag_act == 2) {
scale_out[0] = six / 127.f;
} else if (flag_act == 4) {
if (std::abs(leakey_relu_scale) > 1) {
scale_out[0] *= std::abs(leakey_relu_scale);
}
}
std::vector<float> scale_w(weight_dims[0], 1.f / 127);
param.input_scale = scale_in[0];
param.output_scale = scale_out[0];
param.weight_scale = scale_w;
}
if (flag_act > 0) {
ActivationParam act_param;
act_param.has_active = true;
act_param.active_type = (paddle::lite_api::ActivationType)
flag_act; // 1-relu, 2-relu6, 4-leakyrelu
if (flag_act == 1) {
param.fuse_relu = true;
} else if (flag_act == 2) {
act_param.Relu_clipped_coef = six;
} else if (flag_act == 4) {
act_param.Leaky_relu_alpha = leakey_relu_scale;
}
param.activation_param = act_param;
}
param.output = &y;
param.output->set_precision(OutType);
paddle::lite::fill_tensor_rand(*param.filter, -1.f, 1.f);
if (flag_bias) {
paddle::lite::fill_tensor_rand(*param.bias, -1.f, 1.f);
}
paddle::lite::kernels::arm::ConvCompute<Ptype, OutType> conv;
std::unique_ptr<paddle::lite::KernelContext> ctx1(
new paddle::lite::KernelContext);
auto& ctx = ctx1->As<paddle::lite::ARMContext>();
ctx.SetRunMode(static_cast<paddle::lite_api::PowerMode>(power_mode),
thread_num);
param.x->Resize(input_dims);
DDim dim_out = compute_out_dim(input_dims, param);
param.output->Resize(dim_out);
conv.SetParam(param);
conv.SetContext(std::move(ctx1));
conv.PrepareForRun();
paddle::lite::fill_tensor_rand(*param.x, -1.f, 1.f);
// warm up
for (int i = 0; i < warmup; ++i) {
conv.Launch();
}
// compute
Timer t0;
for (int i = 0; i < repeats; ++i) {
t0.Start();
conv.Launch();
t0.Stop();
}
printf("Avg Latency is %f\n", t0.LapTimes().Avg());
printf("Min Latency is %f\n", t0.LapTimes().Min());
printf("Max Latency is %f\n", t0.LapTimes().Max());
}
int main(int argc, char** argv) {
if (argc != 23) {
std::cerr << "usage: " << argv[0] << "\n"
<< " <batch_size>\n"
<< " <input_channel>\n"
<< " <input_height>\n"
<< " <input_width>\n"
<< " <output_channel>\n"
<< " <group_size>\n"
<< " <kernel_size>\n"
<< " <pad_top>\n"
<< " <pad_bottom>\n"
<< " <pad_left>\n"
<< " <pad_right>\n"
<< " <stride_h>\n"
<< " <stride_w>\n"
<< " <dilation_h>\n"
<< " <dilation_w>\n"
<< " <flag_bias>\n"
<< " <flag_act>\n"
<< " <dtype>\n"
<< " <thread_num>\n"
<< " <power_mode>\n"
<< " <warmup_times>\n"
<< " <repeats_times>\n"
<< std::endl;
return 0;
}
#ifdef LITE_WITH_ARM
paddle::lite::DeviceInfo::Init();
#endif
int batch_size = atoi(argv[1]);
int input_channel = atoi(argv[2]);
int input_height = atoi(argv[3]);
int input_width = atoi(argv[4]);
int output_channel = atoi(argv[5]);
int group_size = atoi(argv[6]);
int kernel_size = atoi(argv[7]);
int pad_top = atoi(argv[8]);
int pad_bottom = atoi(argv[9]);
int pad_left = atoi(argv[10]);
int pad_right = atoi(argv[11]);
int stride_h = atoi(argv[12]);
int stride_w = atoi(argv[13]);
int dilation_h = atoi(argv[14]);
int dilation_w = atoi(argv[15]);
int flag_bias = atoi(argv[16]);
int flag_act = atoi(argv[17]);
int dtype = atoi(argv[18]);
int thread_num = atoi(argv[19]);
int power_mode = atoi(argv[20]);
int warmup = atoi(argv[21]);
int repeats = atoi(argv[22]);
DDim weight_dims(
{output_channel, input_channel / group_size, kernel_size, kernel_size});
DDim input_dims({batch_size, input_channel, input_height, input_width});
switch (dtype) {
case 0:
test_conv<PRECISION(kFloat), PRECISION(kFloat)>(
input_dims,
weight_dims,
group_size,
{stride_h, stride_w},
{pad_top, pad_bottom, pad_left, pad_right},
{dilation_h, dilation_w},
flag_bias,
flag_act,
thread_num,
power_mode,
warmup,
repeats);
break;
case 1:
test_conv<PRECISION(kInt8), PRECISION(kFloat)>(
input_dims,
weight_dims,
group_size,
{stride_h, stride_w},
{pad_top, pad_bottom, pad_left, pad_right},
{dilation_h, dilation_w},
flag_bias,
flag_act,
thread_num,
power_mode,
warmup,
repeats);
break;
case 2:
test_conv<PRECISION(kInt8), PRECISION(kInt8)>(
input_dims,
weight_dims,
group_size,
{stride_h, stride_w},
{pad_top, pad_bottom, pad_left, pad_right},
{dilation_h, dilation_w},
flag_bias,
flag_act,
thread_num,
power_mode,
warmup,
repeats);
break;
default:
test_conv<PRECISION(kFloat), PRECISION(kFloat)>(
input_dims,
weight_dims,
group_size,
{stride_h, stride_w},
{pad_top, pad_bottom, pad_left, pad_right},
{dilation_h, dilation_w},
flag_bias,
flag_act,
thread_num,
power_mode,
warmup,
repeats);
}
return 0;
}
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <stdlib.h>
#include <string.h>
#include <iostream>
#include <memory>
#include "lite/core/context.h"
#include "lite/core/profile/timer.h"
#include "lite/core/tensor.h"
#include "lite/kernels/arm/fc_compute.h"
#include "lite/operators/op_params.h"
#include "lite/tests/utils/tensor_utils.h"
typedef paddle::lite::Tensor Tensor;
typedef paddle::lite::DDim DDim;
typedef paddle::lite::operators::FcParam FcParam;
using paddle::lite::profile::Timer;
using paddle::lite_api::PrecisionType;
template <PrecisionType Ptype, PrecisionType OutType>
void test_fc(const int m,
const int n,
const int k,
const bool has_bias,
const int thread_num,
const int power_mode,
const int warmup,
const int repeats) {
FcParam param;
Tensor x, y, bias, w;
param.input = &x;
param.input->set_precision(Ptype);
param.input->Resize({m, k});
param.w = &w;
param.w->set_precision(Ptype);
param.w->Resize({k, n});
if (has_bias) {
param.bias = &bias;
param.bias->set_precision(Ptype);
param.bias->Resize({1, n});
} else {
param.bias = nullptr;
}
param.output = &y;
param.output->set_precision(OutType);
param.output->Resize({m, n});
param.in_num_col_dims = 1;
param.in_mat_dims = param.input->dims();
paddle::lite::kernels::arm::FcCompute<Ptype, OutType> fc_compute;
std::unique_ptr<paddle::lite::KernelContext> ctx1(
new paddle::lite::KernelContext);
auto& ctx = ctx1->As<paddle::lite::ARMContext>();
ctx.SetRunMode(static_cast<paddle::lite_api::PowerMode>(power_mode),
thread_num);
// set param and context
fc_compute.SetParam(param);
fc_compute.SetContext(std::move(ctx1));
// prepare for run
fc_compute.PrepareForRun();
paddle::lite::fill_tensor_rand(*param.input, -1.f, 1.f);
paddle::lite::fill_tensor_rand(*param.w, -1.f, 1.f);
if (has_bias) {
paddle::lite::fill_tensor_rand(*param.bias, -1.f, 1.f);
}
// warm up
for (int i = 0; i < warmup; ++i) {
fc_compute.Launch();
}
// compute
Timer t0;
for (int i = 0; i < repeats; ++i) {
t0.Start();
fc_compute.Launch();
t0.Stop();
}
printf("Avg Latency is %f\n", t0.LapTimes().Avg());
printf("Min Latency is %f\n", t0.LapTimes().Min());
printf("Max Latency is %f\n", t0.LapTimes().Max());
}
int main(int argc, char** argv) {
if (argc != 10) {
std::cerr << "usage: " << argv[0] << "\n"
<< " <m>\n"
<< " <n>\n"
<< " <k>\n"
<< " <has_bias>\n"
<< " <dtype>\n"
<< " <thread_num>\n"
<< " <power_mode>\n"
<< " <warmup_times>\n"
<< " <repeats_times>\n"
<< std::endl;
return 0;
}
#ifdef LITE_WITH_ARM
paddle::lite::DeviceInfo::Init();
#endif
int m = atoi(argv[1]);
int n = atoi(argv[2]);
int k = atoi(argv[3]);
bool has_bias = atoi(argv[4]) == 0 ? false : true;
  int dtype = 0;
  if (strcmp(argv[5], "int8_int8") == 0) {
    dtype = 2;
  } else if (strcmp(argv[5], "int8_float") == 0) {
    dtype = 1;
  }
int thread_num = atoi(argv[6]);
int power_mode = atoi(argv[7]);
int warmup = atoi(argv[8]);
int repeats = atoi(argv[9]);
switch (dtype) {
case 0:
test_fc<PRECISION(kFloat), PRECISION(kFloat)>(
m, n, k, has_bias, thread_num, power_mode, warmup, repeats);
break;
case 1:
test_fc<PRECISION(kInt8), PRECISION(kFloat)>(
m, n, k, has_bias, thread_num, power_mode, warmup, repeats);
break;
case 2:
test_fc<PRECISION(kInt8), PRECISION(kInt8)>(
m, n, k, has_bias, thread_num, power_mode, warmup, repeats);
break;
default:
test_fc<PRECISION(kFloat), PRECISION(kFloat)>(
m, n, k, has_bias, thread_num, power_mode, warmup, repeats);
break;
}
return 0;
}
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <stdlib.h>
#include <iostream>
#include <memory>
#include <string>
#include <vector>
#include "lite/core/context.h"
#include "lite/core/profile/timer.h"
#include "lite/core/tensor.h"
#include "lite/kernels/arm/pool_compute.h"
#include "lite/operators/op_params.h"
#include "lite/tests/utils/tensor_utils.h"
typedef paddle::lite::Tensor Tensor;
typedef paddle::lite::DDim DDim;
typedef paddle::lite::operators::PoolParam PoolParam;
using paddle::lite::profile::Timer;
DDim compute_out_dim(const DDim& dim_in,
const paddle::lite::operators::PoolParam& param) {
DDim dim_out = dim_in;
auto kernel_h = param.ksize[0];
auto kernel_w = param.ksize[1];
auto h = dim_in[2];
auto w = dim_in[3];
auto paddings = *param.paddings;
int stride_h = param.strides[0];
int stride_w = param.strides[1];
bool ceil_mode = param.ceil_mode;
bool flag_global = param.global_pooling;
int hout = 1;
int wout = 1;
if (!flag_global) {
if (!ceil_mode) {
hout = (h - kernel_h + paddings[0] + paddings[1]) / stride_h + 1;
wout = (w - kernel_w + paddings[2] + paddings[3]) / stride_w + 1;
} else {
hout =
(h - kernel_h + paddings[0] + paddings[1] + stride_h - 1) / stride_h +
1;
wout =
(w - kernel_w + paddings[2] + paddings[3] + stride_w - 1) / stride_w +
1;
}
}
dim_out[2] = hout;
dim_out[3] = wout;
return dim_out;
}
int main(int argc, char** argv) {
if (argc != 20) {
std::cerr << "usage: " << argv[0] << "\n"
<< " <batch_size>\n"
<< " <input_channel>\n"
<< " <input_height>\n"
<< " <input_width>\n"
<< " <kernel_size>\n"
<< " <stride_size>\n"
<< " <pad_size>\n"
<< " <exclusive>\n"
<< " <pooling_type>\n"
<< " <ceil_mode>\n"
<< " <flag_global>\n"
<< " <thread_num>\n"
<< " <power_mode>\n"
<< " <warmup_times>\n"
<< " <repeats_times>\n"
<< std::endl;
return 0;
}
#ifdef LITE_WITH_ARM
paddle::lite::DeviceInfo::Init();
#endif
int batch_size = atoi(argv[1]);
int input_channel = atoi(argv[2]);
int input_height = atoi(argv[3]);
int input_width = atoi(argv[4]);
int stride_h = atoi(argv[5]);
int stride_w = atoi(argv[6]);
int pad_top = atoi(argv[7]);
int pad_bottom = atoi(argv[8]);
int pad_left = atoi(argv[9]);
int pad_right = atoi(argv[10]);
int kernel_size = atoi(argv[11]);
  bool ceil_mode = atoi(argv[12]) == 0 ? false : true;
  bool flag_global = atoi(argv[13]) == 0 ? false : true;
bool exclusive = atoi(argv[14]) == 0 ? false : true;
std::string pooling_type = atoi(argv[15]) == 0 ? "max" : "avg";
int thread_num = atoi(argv[16]);
int power_mode = atoi(argv[17]);
int warmup = atoi(argv[18]);
int repeats = atoi(argv[19]);
#ifdef LITE_WITH_ARM
PoolParam param;
Tensor x, y;
param.x = &x;
param.x->set_precision(PRECISION(kFloat));
param.ksize = {kernel_size, kernel_size};
param.strides = {stride_h, stride_w};
param.paddings = std::make_shared<std::vector<int>>(
std::vector<int>{pad_top, pad_bottom, pad_left, pad_right});
param.ceil_mode = ceil_mode;
param.global_pooling = flag_global;
param.pooling_type = pooling_type;
param.exclusive = exclusive;
param.adaptive = false;
param.use_quantizer = false;
param.output = &y;
param.output->set_precision(PRECISION(kFloat));
paddle::lite::kernels::arm::PoolCompute pool;
std::unique_ptr<paddle::lite::KernelContext> ctx1(
new paddle::lite::KernelContext);
auto& ctx = ctx1->As<paddle::lite::ARMContext>();
ctx.SetRunMode(static_cast<paddle::lite_api::PowerMode>(power_mode),
thread_num);
// set param and context
pool.SetParam(param);
pool.SetContext(std::move(ctx1));
// prepare for run
pool.PrepareForRun();
DDim dim_in = DDim({batch_size, input_channel, input_height, input_width});
DDim dim_out = compute_out_dim(dim_in, param);
param.x->Resize(dim_in);
param.output->Resize(dim_out);
paddle::lite::fill_tensor_rand(*param.x, -1.f, 1.f);
// warm up
for (int i = 0; i < warmup; ++i) {
pool.Launch();
}
// compute
Timer t0;
for (int i = 0; i < repeats; ++i) {
t0.Start();
pool.Launch();
t0.Stop();
}
printf("Avg Latency is %f\n", t0.LapTimes().Avg());
printf("Min Latency is %f\n", t0.LapTimes().Min());
printf("Max Latency is %f\n", t0.LapTimes().Max());
#endif
return 0;
}