diff --git a/CMakeLists.txt b/CMakeLists.txt index a98d815943cf4d4bb3d632ccfcb83fc7818e047d..bdbf5a6ea604400fb5087976df0e1e9c279fd78d 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -2,9 +2,9 @@ cmake_minimum_required(VERSION 3.0) project(paddle-mobile) # select the platform to build -option(CPU "armv7 with neon support" OFF) +option(CPU "armv7 with neon support" ON) option(MALI_GPU "mali gpu support" OFF) -option(FPGA "fpga support" ON) +option(FPGA "fpga support" OFF) option(USE_OPENMP "openmp support" OFF) option(DEBUGING "enable debug mode" ON) diff --git a/doc/development_fpga.md b/doc/development_fpga.md new file mode 100644 index 0000000000000000000000000000000000000000..14cc57c6b4055e8c4e45d8b673eb1e3be22ae256 --- /dev/null +++ b/doc/development_fpga.md @@ -0,0 +1,37 @@
+# FPGA Development Guide
+
+The FPGA platform code has been tested successfully with ResNet-50 on a Xilinx ZCU102 revision 1.0 development board, and the prediction results are correct.
+
+## Preparing the hardware
+___
+
+1. Obtain a Xilinx ZCU102 revision 1.0 development board.
+2. Download the Xilinx ZCU102 Ubuntu [image file](https://www.xilinx.com/member/forms/download/xef.html?filename=Ubuntu_Desktop_Release_2018_1.zip) and flash it onto an SD card.
+    * On Windows, use Win32DiskImager.
+    * On Linux, use the dd command: dd if=name.img of=/dev/sdb
+3. Insert the SD card into a PC and replace the existing BOOT.BIN and image.ub in partition 1 with [BOOT.BIN and image.ub](http://mms-graph.bj.bcebos.com/paddle-mobile/fpga/files.tar.gz).
+4. Insert the SD card into the ZCU102 board, set the board's DIP switches to boot from the SD card, and power on to start Linux.
+5. Load the driver: sudo insmod [fpgadrv.ko](http://mms-graph.bj.bcebos.com/paddle-mobile/fpga/files.tar.gz)
+
+
+## Building the project
+___
+1. Copy the latest paddle-mobile code onto the ZCU102 board.
+2. In the paddle-mobile root directory, set the platform in CMakeLists.txt to option(FPGA "fpga support" ON), and set the CPU and MALI\_GPU options to OFF.
+3. Run the following commands; the test-resnet50 executable is generated under ./test/build.
+    * mkdir build
+    * cd build
+    * cmake ..
+    * make
+
+## Preparing the model and data
+___
+1. The model files go in ./test/models/resnet50. Copy the [\_\_model\_\_](http://mms-graph.bj.bcebos.com/paddle-mobile/fpga/files.tar.gz) file into this folder.
+2. Also download the model [weight file](http://paddle-imagenet-models.bj.bcebos.com/resnet_50_model.tar), extract it, and place it in ./test/models/resnet50 as well.
+3. Copy the data file [image_src_float](http://mms-graph.bj.bcebos.com/paddle-mobile/fpga/files.tar.gz) into /test/images. This data file corresponds to ILSVRC2012_val_00000885.JPEG in the standard dataset; its class label is 80, which corresponds to "black grouse".
+
+## Running the program
+___
+1. Enter the ./test/build directory.
+2. sudo ./test-resnet50
+3. If the DEBUG option is enabled, a large amount of intermediate debug output is printed to the screen. The final output is the predicted class, which should be 80. (The build-and-run commands are consolidated in the shell sketch below.)
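+
+The steps above can be run as a single shell session. This is only a sketch: it assumes the driver (fpgadrv.ko), the \_\_model\_\_, weight, and image files linked above have already been copied onto the board, and that FPGA has been switched ON in CMakeLists.txt as described.
+
+```bash
+# Load the FPGA driver downloaded from the link above
+sudo insmod fpgadrv.ko
+
+# Build from the paddle-mobile root directory
+mkdir build && cd build
+cmake .. && make        # produces ./test/build/test-resnet50
+
+# Run the ResNet-50 test; the final output should be class 80 ("black grouse")
+cd ../test/build
+sudo ./test-resnet50
+```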
diff --git a/src/common/types.cpp b/src/common/types.cpp index 2a8b8c8a151e58d13093c99c753cc47e8eef64a3..8c8de7765161e61dc75036a87a34fc6abd2df43e 100644 --- a/src/common/types.cpp +++ b/src/common/types.cpp @@ -22,6 +22,7 @@ const char *G_OP_TYPE_BATCHNORM = "batch_norm"; const char *G_OP_TYPE_BOX_CODER = "box_coder"; const char *G_OP_TYPE_CONCAT = "concat"; const char *G_OP_TYPE_ELEMENTWISE_ADD = "elementwise_add"; +const char *G_OP_TYPE_FILL_CONSTANT = "fill_constant"; const char *G_OP_TYPE_FUSION_CONV_ADD_RELU = "fusion_conv_add_relu"; const char *G_OP_TYPE_FUSION_CONV_ADD_PRELU = "fusion_conv_add_prelu"; const char *G_OP_TYPE_FUSION_CONV_ADD_ADD_PRELU = "fusion_conv_add_add_prelu"; @@ -99,6 +100,7 @@ std::unordered_map< {G_OP_TYPE_FC, {{"X", "Y", "Z"}, {"Out"}}}, {G_OP_TYPE_RESHAPE, {{"X"}, {"Out"}}}, {G_OP_TYPE_DEPTHWISE_CONV, {{"Input"}, {"Output"}}}, + {G_OP_TYPE_FILL_CONSTANT, {{}, {"Out"}}}, {G_OP_TYPE_FUSION_CONV_ADD_RELU, {{"Input"}, {"Out"}}}, {G_OP_TYPE_FUSION_CONV_ADD_PRELU, {{"Input"}, {"Out"}}}, {G_OP_TYPE_FUSION_CONV_ADD_ADD_PRELU, {{"Input"}, {"Out"}}}, diff --git a/src/common/variant.h b/src/common/variant.h index 2d81160a840668e26ab052afbdd05367cde5189a..4aa4f47c628caec438ecd00522d90ebf299da6a0 100644 --- a/src/common/variant.h +++ b/src/common/variant.h @@ -12,14 +12,16 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ +#pragma once + #include #include +#include #include "common/enforce.h" #include "common/log.h" -#pragma once - namespace paddle_mobile { + template struct IDToType { typedef Type type_t; diff --git a/src/fpga/bias_scale.cpp b/src/fpga/bias_scale.cpp index 23889d5b1fee3d8cb9e4673f42b18574366411eb..50f1ed03f0121b5afdc41d427e5b52675994bd1e 100644 --- a/src/fpga/bias_scale.cpp +++ b/src/fpga/bias_scale.cpp @@ -27,9 +27,6 @@ void align_element(float **data_in, int num_per_div_before_alignment, int num) { (num + num_per_div_before_alignment - 1) / num_per_div_before_alignment; int num_per_div_after_alignment = align_to_x(num_per_div_before_alignment, BS_NUM_ALIGNMENT); - if (num_per_div_before_alignment == num_per_div_after_alignment) { - return; - } int num_element = 2 * div_num * num_per_div_after_alignment; // including bias & scale float *ptr_aligned = diff --git a/src/framework/attribute.h b/src/framework/attribute.h index ff9e1204a1e32f3ffe6271d4d2d76b8e3cf24d63..a94346bc7ab321b0f5710a98fb3cc60198f148b0 100644 --- a/src/framework/attribute.h +++ b/src/framework/attribute.h @@ -156,7 +156,7 @@ class AttrReader { template inline T Get(const string &name) const { PADDLE_MOBILE_ENFORCE(attrs_.count(name) != 0, - "%s should be in AttributeMap", name); + "%s should be in AttributeMap", name.c_str()); return ((Attribute)attrs_.at(name)).Get(); } diff --git a/src/framework/data_type.cpp b/src/framework/data_type.cpp new file mode 100644 index 0000000000000000000000000000000000000000..0bcf7d9f67dae28db5a316476778b4132b39b274 --- /dev/null +++ b/src/framework/data_type.cpp @@ -0,0 +1,107 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "framework/data_type.h" +#include +#include +#include + +namespace paddle_mobile { +namespace framework { + +struct DataTypeMap { + std::unordered_map + cpp_to_proto_; + std::unordered_map proto_to_cpp_; + std::unordered_map proto_to_str_; + std::unordered_map cpp_to_size_; +}; + +static DataTypeMap* InitDataTypeMap(); +// C++11 removes the need for manual locking. Concurrent execution shall wait if +// a static local variable is already being initialized. +// https://stackoverflow.com/questions/11711920/how-to-implement-multithread-safe-singleton-in-c11-without-using-mutex +static DataTypeMap& gDataTypeMap() { + static DataTypeMap* g_data_type_map_ = InitDataTypeMap(); + return *g_data_type_map_; +} + +template +static inline void RegisterType( + DataTypeMap* map, _PaddleMobile__Framework__Proto__VarType__Type proto_type, + const std::string& name) { + map->proto_to_cpp_.emplace(static_cast(proto_type), typeid(T)); + map->cpp_to_proto_.emplace(typeid(T), proto_type); + map->proto_to_str_.emplace(static_cast(proto_type), name); + map->cpp_to_size_.emplace(typeid(T), sizeof(T)); +} + +static DataTypeMap* InitDataTypeMap() { + auto retv = new DataTypeMap(); + +#define RegType(cc_type, proto_type) \ + RegisterType(retv, proto_type, #cc_type) + + // NOTE: Add your customize type here. + // RegType(float16, PADDLE_MOBILE__FRAMEWORK__PROTO__VAR_TYPE__TYPE__FP16); + RegType(float, PADDLE_MOBILE__FRAMEWORK__PROTO__VAR_TYPE__TYPE__FP32); + RegType(double, PADDLE_MOBILE__FRAMEWORK__PROTO__VAR_TYPE__TYPE__FP64); + RegType(int, PADDLE_MOBILE__FRAMEWORK__PROTO__VAR_TYPE__TYPE__INT32); + RegType(int64_t, PADDLE_MOBILE__FRAMEWORK__PROTO__VAR_TYPE__TYPE__INT64); + RegType(bool, PADDLE_MOBILE__FRAMEWORK__PROTO__VAR_TYPE__TYPE__BOOL); + RegType(size_t, PADDLE_MOBILE__FRAMEWORK__PROTO__VAR_TYPE__TYPE__SIZE_T); + RegType(int16_t, PADDLE_MOBILE__FRAMEWORK__PROTO__VAR_TYPE__TYPE__INT16); + RegType(uint8_t, PADDLE_MOBILE__FRAMEWORK__PROTO__VAR_TYPE__TYPE__UINT8); + RegType(int8_t, PADDLE_MOBILE__FRAMEWORK__PROTO__VAR_TYPE__TYPE__INT8); + +#undef RegType + return retv; +} + +_PaddleMobile__Framework__Proto__VarType__Type ToDataType( + std::type_index type) { + auto it = gDataTypeMap().cpp_to_proto_.find(type); + if (it != gDataTypeMap().cpp_to_proto_.end()) { + return it->second; + } + PADDLE_MOBILE_THROW_EXCEPTION("Not support %s as tensor type", type.name()); +} + +std::type_index ToTypeIndex( + _PaddleMobile__Framework__Proto__VarType__Type type) { + auto it = gDataTypeMap().proto_to_cpp_.find(static_cast(type)); + if (it != gDataTypeMap().proto_to_cpp_.end()) { + return it->second; + } + PADDLE_MOBILE_THROW_EXCEPTION( + "Not support _PaddleMobile__Framework__Proto__VarType__Type(%d) as " + "tensor type", + static_cast(type)); +} + +std::string DataTypeToString( + const _PaddleMobile__Framework__Proto__VarType__Type type) { + auto it = gDataTypeMap().proto_to_str_.find(static_cast(type)); + if (it != gDataTypeMap().proto_to_str_.end()) { + return it->second; + } + PADDLE_MOBILE_THROW_EXCEPTION( + "Not support _PaddleMobile__Framework__Proto__VarType__Type(%d) as " + "tensor type", + static_cast(type)); 
+} + +} // namespace framework +} // namespace paddle_mobile diff --git a/src/framework/data_type.h b/src/framework/data_type.h new file mode 100644 index 0000000000000000000000000000000000000000..2e3623fdedcb527cb0c85bbb7a2eaf04d91a2193 --- /dev/null +++ b/src/framework/data_type.h @@ -0,0 +1,77 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include +#include +#include "common/enforce.h" +#include "framework/framework.pb-c.h" + +namespace paddle_mobile { + +namespace framework { + +extern _PaddleMobile__Framework__Proto__VarType__Type ToDataType( + std::type_index type); +extern std::type_index ToTypeIndex( + _PaddleMobile__Framework__Proto__VarType__Type type); + +template +inline void VisitDataType(_PaddleMobile__Framework__Proto__VarType__Type type, + Visitor visitor) { + switch (type) { + // case PADDLE_MOBILE__FRAMEWORK__PROTO__VAR_TYPE__TYPE__FP16: + // visitor.template apply(); + // break; + case PADDLE_MOBILE__FRAMEWORK__PROTO__VAR_TYPE__TYPE__FP32: + visitor.template apply(); + break; + case PADDLE_MOBILE__FRAMEWORK__PROTO__VAR_TYPE__TYPE__FP64: + visitor.template apply(); + break; + case PADDLE_MOBILE__FRAMEWORK__PROTO__VAR_TYPE__TYPE__INT32: + visitor.template apply(); + break; + case PADDLE_MOBILE__FRAMEWORK__PROTO__VAR_TYPE__TYPE__INT64: + visitor.template apply(); + break; + case PADDLE_MOBILE__FRAMEWORK__PROTO__VAR_TYPE__TYPE__BOOL: + visitor.template apply(); + break; + case PADDLE_MOBILE__FRAMEWORK__PROTO__VAR_TYPE__TYPE__UINT8: + visitor.template apply(); + break; + case PADDLE_MOBILE__FRAMEWORK__PROTO__VAR_TYPE__TYPE__INT16: + visitor.template apply(); + break; + case PADDLE_MOBILE__FRAMEWORK__PROTO__VAR_TYPE__TYPE__INT8: + visitor.template apply(); + break; + default: + PADDLE_MOBILE_THROW_EXCEPTION("Not supported %d", type); + } +} + +extern std::string DataTypeToString( + const _PaddleMobile__Framework__Proto__VarType__Type type); +inline std::ostream& operator<<( + std::ostream& out, + const _PaddleMobile__Framework__Proto__VarType__Type& type) { + out << DataTypeToString(type); + return out; +} + +} // namespace framework +} // namespace paddle_mobile diff --git a/src/framework/load_ops.h b/src/framework/load_ops.h index 7fd6290704a573aac535685b4bdf48092b35e98b..2b76b0158fe06e8678208f6f98fcdb71f8d91e51 100644 --- a/src/framework/load_ops.h +++ b/src/framework/load_ops.h @@ -64,6 +64,9 @@ limitations under the License. */ // load requared ops LOAD_OP(feed) LOAD_OP(fetch) +#ifdef FILL_CONSTANT_OP +LOAD_OP(fill_constant) +#endif #ifdef BATCHNORM_OP LOAD_OP2(batch_norm, CPU, MALI_GPU); #endif diff --git a/src/framework/selected_rows.h b/src/framework/selected_rows.h index 9c8176285278afa69679ac3471f7a4adb0aeea3f..db49bd91159116883e5fcb148ef3ed012ec42e71 100644 --- a/src/framework/selected_rows.h +++ b/src/framework/selected_rows.h @@ -18,9 +18,9 @@ limitations under the License. 
*/ #include #include "framework/lod_tensor.h" +#include "framework/mixed_vector.h" #include "framework/tensor.h" #include "memory/t_malloc.h" -#include "mixed_vector.h" namespace paddle_mobile { namespace framework { diff --git a/src/framework/tensor.h b/src/framework/tensor.h index 66ad328fa98aa7d36ba33dc4929567b2ff79884e..496cde98e57561ca048f356fa397f5447b9050f5 100644 --- a/src/framework/tensor.h +++ b/src/framework/tensor.h @@ -343,7 +343,9 @@ inline Print &operator<<(Print &printer, const Tensor &tensor) { } else if (tensor.type() == typeid(int64_t)) { printer << tensor.data()[i] << " "; } else if (tensor.type() == typeid(int8_t)) { - printer << static_cast(tensor.data()[i]) << " "; + printer << static_cast(tensor.data()[i]) << " "; + } else if (tensor.type() == typeid(int32_t)) { + printer << tensor.data()[i] << " "; } } #endif diff --git a/src/io/executor.cpp b/src/io/executor.cpp index 100a774054035285d0e8b14ca195ad9c627a7ff7..9efec27c9df3d51a3411db87faee924b374d2ac7 100644 --- a/src/io/executor.cpp +++ b/src/io/executor.cpp @@ -80,12 +80,13 @@ Executor::Executor(const framework::Program p, int batch_size, } template -void LoadMemInternal(void **data, framework::LoDTensor *tensor) { +static void LoadMemInternal(void **data, framework::LoDTensor *tensor, + bool quant_uint8 = false) { char **data_buf = reinterpret_cast(data); int64_t size = tensor->numel(); Dtype *tensor_data = tensor->mutable_data(); - if (0) { - // TODO(hjchen2) should be moved into operator init function + if (quant_uint8) { + // should be moved into operator init function float min_value; float max_value; memcpy(&min_value, data_buf, sizeof(float)); @@ -141,7 +142,8 @@ void Executor::LoadMemory( // parse tensor from stream switch (tensor_desc.DataType()) { case framework::VARTYPE_TYPE_FP32: - LoadMemInternal(reinterpret_cast(data_buf), tensor); + LoadMemInternal(reinterpret_cast(data_buf), tensor, + program_.quantification); break; case framework::VARTYPE_TYPE_INT8: LoadMemInternal(reinterpret_cast(data_buf), tensor); diff --git a/src/operators/dequantize_op.cpp b/src/operators/dequantize_op.cpp index df835e3007fe90a5540d420077099a60023c913a..21cd96368c4938d309f08d036b172607a5afee8c 100644 --- a/src/operators/dequantize_op.cpp +++ b/src/operators/dequantize_op.cpp @@ -12,6 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ +#ifdef DEQUANT_OP + #include "operators/dequantize_op.h" namespace paddle_mobile { @@ -30,3 +32,5 @@ namespace ops = paddle_mobile::operators; #ifdef PADDLE_MOBILE_CPU REGISTER_OPERATOR_CPU(dequantize, ops::DequantizeOp); #endif + +#endif diff --git a/src/operators/dequantize_op.h b/src/operators/dequantize_op.h index 4855f27fc84cc4ef5acd7a4f9cbe7ad8a70b9c75..906167a9a2f3d0e4dfa4ccf02c0d819108cd3493 100644 --- a/src/operators/dequantize_op.h +++ b/src/operators/dequantize_op.h @@ -12,6 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ +#ifdef DEQUANT_OP + #pragma once #include @@ -41,3 +43,5 @@ class DequantizeOp } // namespace operators } // namespace paddle_mobile + +#endif diff --git a/src/operators/elementwise_mul_op.cpp b/src/operators/elementwise_mul_op.cpp index 920a9a546f5ea6d5ef4f41de361ba43cb9c1a7b1..335a908ace54664f0bcbca37bdcde30047edee5d 100644 --- a/src/operators/elementwise_mul_op.cpp +++ b/src/operators/elementwise_mul_op.cpp @@ -14,7 +14,7 @@ limitations under the License. */ #ifdef ELEMENTWISEMUL_OP -#include "elementwise_mul_op.h" +#include "operators/elementwise_mul_op.h" namespace paddle_mobile { namespace operators { diff --git a/src/operators/fill_constant_op.cpp b/src/operators/fill_constant_op.cpp new file mode 100644 index 0000000000000000000000000000000000000000..6d7c4f44f1b769c47d6f741d139118158292a40f --- /dev/null +++ b/src/operators/fill_constant_op.cpp @@ -0,0 +1,30 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#ifdef FILL_CONSTANT_OP + +#include "operators/fill_constant_op.h" + +namespace ops = paddle_mobile::operators; +#ifdef PADDLE_MOBILE_CPU +REGISTER_OPERATOR_CPU(fill_constant, ops::FillConstantOp); +#endif +#ifdef PADDLE_MOBILE_MALI_GPU +REGISTER_OPERATOR_MALI_GPU(fill_constant, ops::FillConstantOp); +#endif +#ifdef PADDLE_MOBILE_FPGA +REGISTER_OPERATOR_FPGA(fill_constant, ops::FillConstantOp); +#endif + +#endif diff --git a/src/operators/fill_constant_op.h b/src/operators/fill_constant_op.h new file mode 100644 index 0000000000000000000000000000000000000000..78eb162efc8ccd42b9fba363d49d1dbc4052f6b2 --- /dev/null +++ b/src/operators/fill_constant_op.h @@ -0,0 +1,81 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
*/ + +#ifdef FILL_CONSTANT_OP + +#pragma once + +#include +#include "framework/data_type.h" +#include "framework/operator.h" +#include "framework/selected_rows.h" +#include "operators/math/math_function.h" +#include "operators/op_param.h" + +namespace paddle_mobile { +namespace operators { +using std::string; + +template +class FillConstantOp : public framework::OperatorBase { + public: + FillConstantOp(const string &type, const VariableNameMap &inputs, + const VariableNameMap &outputs, + const framework::AttributeMap attrs, + std::shared_ptr scope) + : framework::OperatorBase(type, inputs, outputs, attrs, + scope), + param_(inputs, outputs, attrs, *scope) {} + void RunImpl() const { + auto data_type = + static_cast<_PaddleMobile__Framework__Proto__VarType__Type>( + param_.DataDtype()); + framework::Tensor *tensor = nullptr; + auto value = param_.Value(); + auto *outvar = param_.OutVar(); + + if (outvar->template IsType()) { + tensor = outvar->template GetMutable(); + } else if (outvar->template IsType()) { + tensor = outvar->template GetMutable() + ->mutable_value(); + } else { + PADDLE_MOBILE_THROW_EXCEPTION( + "fill constant op's output only" + "supports SelectedRows and LoDTensor"); + } + tensor->Resize(framework::make_ddim(param_.Shape())); + tensor->mutable_data(framework::ToTypeIndex(data_type)); + + math::set_constant(tensor, value); + } + + void Init() {} + + void InferShape() const { + PADDLE_MOBILE_ENFORCE( + param_.Out() != nullptr, + "Output (Out) of fill_constant op should not be null."); + framework::DDim ddim = framework::make_ddim(param_.Shape()); + param_.Out()->Resize(ddim); + } + + protected: + FillConstantParam param_; +}; + +} // namespace operators +} // namespace paddle_mobile + +#endif diff --git a/src/operators/kernel/arm/dequantize_kernel.cpp b/src/operators/kernel/arm/dequantize_kernel.cpp index 3033c16c747855455e43454b204fef8e4a345818..cd6c8d17f1ea05e3df6f8f364c2d3d5c9976e46b 100644 --- a/src/operators/kernel/arm/dequantize_kernel.cpp +++ b/src/operators/kernel/arm/dequantize_kernel.cpp @@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#ifdef PADDLE_MOBILE_CPU +#ifdef DEQUANT_OP #include "operators/kernel/dequantize_kernel.h" @@ -38,7 +38,8 @@ void DequantizeKernel::Compute( const int32_t *x = input->data(); float *y = output->mutable_data(); size_t size = output->numel(); - float scale = 1.f / (activation_scale * weight_scale); + // float scale = 1.f / (activation_scale * weight_scale); + float scale = activation_scale / weight_scale; #if defined(__ARM_NEON__) || defined(__ARM_NEON) size_t loop = size >> 4; size_t remain = size & 0xF; diff --git a/src/operators/kernel/arm/quantize_kernel.cpp b/src/operators/kernel/arm/quantize_kernel.cpp index e2c8efc299c858a3cbb907ce0e98b1c2f96d2bc1..e7552d2602b31f9a5c10e3d81122babae8fcf1a8 100644 --- a/src/operators/kernel/arm/quantize_kernel.cpp +++ b/src/operators/kernel/arm/quantize_kernel.cpp @@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ -#ifdef PADDLE_MOBILE_CPU +#ifdef QUANT_OP #include "operators/kernel/quantize_kernel.h" #include @@ -225,7 +225,7 @@ static void quantize_round_to_nearest(const Tensor *input, const float scale, const float *x = input->data(); int8_t *y = output->mutable_data(); size_t size = input->numel(); -#ifdef defined(__ARM_NEON__) || defined(__ARM_NEON) +#if defined(__ARM_NEON__) || defined(__ARM_NEON) size_t loop = size >> 4; size_t remain = size & 0xF; for (size_t i = 0; i < loop; ++i) { @@ -280,17 +280,18 @@ void QuantizeKernel::Compute( } max_abs = std::max(max_abs, 1e-6f); // only support int8 currently - float online_scale = 127 / max_abs; - param.online_scale_->mutable_data()[0] = online_scale; + float scale = 127 / max_abs; + param.online_scale_->mutable_data()[0] = max_abs; switch (param.round_type_) { case ROUND_NEAREST_TO_EVEN: - quantize_round_to_even(input, online_scale, output); + quantize_round_to_even(input, scale, output); break; case ROUND_NEAREST_TOWARDS_ZERO: - quantize_round_to_zero(input, online_scale, output); + quantize_round_to_zero(input, scale, output); break; case ROUND_NEAREST_AWAY_ZERO: - quantize_round_to_nearest(input, online_scale, output); + quantize_round_to_nearest(input, scale, output); + break; default: LOG(kLOG_ERROR) << "round type is not supported."; break; diff --git a/src/operators/kernel/central-arm-func/conv_arm_func.h b/src/operators/kernel/central-arm-func/conv_arm_func.h index a3e21e4b4b702630f7942f2a5171a3401f29a431..f80a8f944139566483c47daf10f9decac49650dc 100644 --- a/src/operators/kernel/central-arm-func/conv_arm_func.h +++ b/src/operators/kernel/central-arm-func/conv_arm_func.h @@ -16,24 +16,27 @@ limitations under the License. */ #pragma once #include +#include "operators/math/conv_arm_int8.h" #include "operators/math/conv_func.h" #include "operators/math/depthwise_conv_3x3.h" #include "operators/math/im2col.h" #include "operators/math/math_function.h" +#include "operators/math/pad.h" #include "operators/math/vol2col.h" #include "operators/op_param.h" namespace paddle_mobile { namespace operators { + +template inline void ConvBasic(const ConvParam ¶m) { const Tensor *input = param.Input(); Tensor filter = *param.Filter(); Tensor *output = param.Output(); - output->mutable_data(); int groups = param.Groups(); - std::vector strides = param.Strides(); - std::vector paddings = param.Paddings(); - std::vector dilations = param.Dilations(); + const std::vector strides = param.Strides(); + const std::vector paddings = param.Paddings(); + const std::vector dilations = param.Dilations(); const int batch_size = static_cast(input->dims()[0]); @@ -57,7 +60,7 @@ inline void ConvBasic(const ConvParam ¶m) { Tensor col; Tensor col_matrix; if (is_expand) { - col.mutable_data(col_shape); + col.mutable_data(col_shape); col_matrix.ShareDataWith(col); col_matrix.Resize(col_matrix_shape); } @@ -76,8 +79,8 @@ inline void ConvBasic(const ConvParam ¶m) { int in_step = static_cast(input->dims()[1]) / groups; int out_step = static_cast(output->dims()[1]) / groups; - math::Vol2ColFunctor vol2col; - math::Im2ColFunctor im2col; + math::Vol2ColFunctor vol2col; + math::Im2ColFunctor im2col; for (int i = 0; i < batch_size; i++) { Tensor in_batch = input->Slice(i, i + 1).Resize(input_shape); @@ -96,6 +99,7 @@ inline void ConvBasic(const ConvParam ¶m) { std::vector{paddings[0], paddings[1], paddings[0], paddings[1]}, &col); + } else if (data_dim == 3U) { // vol2col vol2col(in_slice, dilations, strides, paddings, &col); @@ -104,29 +108,85 @@ inline void ConvBasic(const 
ConvParam ¶m) { // gemm Tensor out_slice = out_batch.Slice(g * out_step, (g + 1) * out_step); Tensor filter_slice = filter.Slice(g * out_step, (g + 1) * out_step); - math::matmul(filter_slice, false, col_matrix, false, + + math::matmul(filter_slice, false, col_matrix, false, static_cast(1), &out_slice, static_cast(0)); } } } +inline void ConvCompute_int8(const ConvParam ¶m) { + typedef void (*ConvFunc)(const Tensor &input, const Tensor &kernel, + Tensor *output); + static ConvFunc conv_funcs_table[7][5] = { + {0, 0, 0, 0, 0}, // k = 1 + {0, 0, 0, 0, 0}, {conv3x3s1_int8, 0, 0, 0, 0}, // k = 3 + {0, 0, 0, 0, 0}, {conv5x5s1_int8, 0, 0, 0, 0}, // k = 5 + {0, 0, 0, 0, 0}, {0, 0, 0, 0, 0}, // k = 7 + }; + const Tensor *input = param.Input(); + Tensor *filter = param.Filter(); + Tensor *output = param.Output(); + int groups = param.Groups(); + const std::vector &strides = param.Strides(); + const std::vector &paddings = param.Paddings(); + const std::vector &dilations = param.Dilations(); + int kernel_h = filter->dims()[2]; + int kernel_w = filter->dims()[3]; + output->mutable_data(); + + ConvFunc conv_func = 0; + if (strides[1] == strides[0] && strides[1] < 6 && kernel_h == kernel_w && + kernel_h < 8 && groups == 1 && dilations[0] == dilations[1] && + dilations[1] == 1) { + conv_func = conv_funcs_table[kernel_h - 1][strides[0] - 1]; + } + if (conv_func) { + int batch_size = input->dims()[0]; + math::PadFunctor pad; + + Tensor input_pad; + for (int i = 0; i < batch_size; ++i) { + Tensor in_batch = input->Slice(i, i + 1); + Tensor out_batch = output->Slice(i, i + 1); + if (paddings[0] == 0 && paddings[1] == 0) { + input_pad = in_batch; + } else { + framework::DDim pad_shape = in_batch.dims(); + pad_shape[2] += 2 * paddings[0]; + pad_shape[3] += 2 * paddings[1]; + input_pad.mutable_data(pad_shape); + pad(in_batch, paddings[0], paddings[1], &input_pad); + } + conv_func(input_pad, *filter, &out_batch); + } + } else { + ConvBasic(param); + } +} + template void ConvCompute(const ConvParam ¶m) { - if (param.Groups() == param.Input()->dims()[1] && - param.Input()->dims()[1] == param.Output()->dims()[1] && - param.Filter()->dims()[2] == param.Filter()->dims()[3] && - param.Filter()->dims()[2] == 3 && param.Strides()[0] == 1) { - math::DepthwiseConv3x3s1p1(param.Input(), param.Filter(), param.Output(), - nullptr, false); - } else if (param.Groups() == param.Input()->dims()[1] && - param.Input()->dims()[1] == param.Output()->dims()[1] && - param.Filter()->dims()[2] == param.Filter()->dims()[3] && - param.Filter()->dims()[2] == 3) { - math::DepthwiseConv3x3(param.Input(), param.Strides(), param.Paddings(), - param.Filter(), nullptr, param.Output(), false); + if (param.Input()->type() == typeid(int8_t)) { + ConvCompute_int8(param); } else { - ConvBasic(param); + param.Output()->mutable_data(); + if (param.Groups() == param.Input()->dims()[1] && + param.Input()->dims()[1] == param.Output()->dims()[1] && + param.Filter()->dims()[2] == param.Filter()->dims()[3] && + param.Filter()->dims()[2] == 3 && param.Strides()[0] == 1) { + math::DepthwiseConv3x3s1p1(param.Input(), param.Filter(), param.Output(), + nullptr, false); + } else if (param.Groups() == param.Input()->dims()[1] && + param.Input()->dims()[1] == param.Output()->dims()[1] && + param.Filter()->dims()[2] == param.Filter()->dims()[3] && + param.Filter()->dims()[2] == 3) { + math::DepthwiseConv3x3(param.Input(), param.Strides(), param.Paddings(), + param.Filter(), nullptr, param.Output(), false); + } else { + ConvBasic(param); + } } } diff --git 
a/src/operators/kernel/central-arm-func/depthwise_conv_arm_func.h b/src/operators/kernel/central-arm-func/depthwise_conv_arm_func.h index 2a1afb3cf6fdbdc0a80cec5558c2b42fec6699f3..ff5d5d4b2a351d075fcecce209063aa66e026754 100644 --- a/src/operators/kernel/central-arm-func/depthwise_conv_arm_func.h +++ b/src/operators/kernel/central-arm-func/depthwise_conv_arm_func.h @@ -44,7 +44,7 @@ void DepthwiseConvCompute(const ConvParam ¶m) { Bias, false); } else { - ConvBasic(param); + ConvBasic(param); } } diff --git a/src/operators/kernel/central-arm-func/elementwise_add_arm_func.h b/src/operators/kernel/central-arm-func/elementwise_add_arm_func.h index ace72b6faddb04ee3547f1b2bc01461d8c9f2e98..0c01ef0072444479d2d2e2f7676b842d89e432ec 100644 --- a/src/operators/kernel/central-arm-func/elementwise_add_arm_func.h +++ b/src/operators/kernel/central-arm-func/elementwise_add_arm_func.h @@ -15,8 +15,12 @@ limitations under the License. */ #ifdef ELEMENTWISEADD_OP #pragma once + #include "operators/math/elementwise_op_function.h" #include "operators/op_param.h" +#if defined(__ARM_NEON__) || defined(__ARM_NEON) +#include +#endif namespace paddle_mobile { namespace operators { @@ -33,8 +37,61 @@ void ElementwiseAddCompute(const ElementwiseAddParam ¶m) { Tensor *Out = param.Out(); Out->mutable_data(); int axis = param.Axis(); +#if defined(__ARM_NEON__) || defined(__ARM_NEON) + const auto &x_dims = input_x->dims(); + const auto &y_dims = input_y->dims(); + /// axis = -1 represent the last dimensions. + axis = (axis == -1 ? x_dims.size() - y_dims.size() : axis); + size_t batch = 1; + size_t channels = 1; + size_t elementwise_num = 1; + for (int i = 0; i < axis; ++i) { + batch *= x_dims[i]; + } + for (int i = 0; i < y_dims.size(); ++i) { + channels *= y_dims[i]; + } + for (int i = y_dims.size() + axis; i < x_dims.size(); ++i) { + elementwise_num *= x_dims[i]; + } + const float *bias_data = input_y->data(); + const float *input_data = input_x->data(); + float *output_data = Out->mutable_data(); + for (int i = 0; i < batch; ++i) { + for (int j = 0; j < channels; ++j) { + size_t offset = (i * channels + j) * elementwise_num; + const float *input = input_data + offset; + const float *bias = bias_data + j; + float *output = output_data + offset; + + int loop = elementwise_num >> 0x4; + int remain = elementwise_num & 0xF; + for (int k = 0; k < loop; ++k) { + float32x4_t rb = vdupq_n_f32(*bias); + float32x4_t r0 = vld1q_f32(input); + float32x4_t r1 = vld1q_f32(input + 4); + float32x4_t r2 = vld1q_f32(input + 8); + float32x4_t r3 = vld1q_f32(input + 12); + r0 = vaddq_f32(r0, rb); + r1 = vaddq_f32(r1, rb); + r2 = vaddq_f32(r2, rb); + r3 = vaddq_f32(r3, rb); + vst1q_f32(output, r0); + vst1q_f32(output + 4, r1); + vst1q_f32(output + 8, r2); + vst1q_f32(output + 12, r3); + input += 16; + output += 16; + } + for (int k = 0; k < remain; ++k) { + output[k] = input[k] + *bias; + } + } + } +#else ElementwiseComputeEx, float>(input_x, input_y, axis, AddFunctor(), Out); +#endif } template class ElementwiseAddKernel; diff --git a/src/operators/kernel/central-arm-func/relu_arm_func.h b/src/operators/kernel/central-arm-func/relu_arm_func.h index d68569c0a5c0730d96a89cd534b2a89c0d3a9bff..38b2e6f334b4b24460f72450b01e4bdc2a6ff616 100644 --- a/src/operators/kernel/central-arm-func/relu_arm_func.h +++ b/src/operators/kernel/central-arm-func/relu_arm_func.h @@ -17,6 +17,9 @@ limitations under the License. 
*/ #include #include "operators/op_param.h" +#if defined(__ARM_NEON__) || defined(__ARM_NEON) +#include +#endif namespace paddle_mobile { namespace operators { @@ -37,71 +40,100 @@ void ReluCompute(const ReluParam ¶m) { auto *out_ptr = out->mutable_data(); int numel = input_x->numel(); - // if (numel > 64) { - // asm volatile( - // "pld [%[input_x_ptr], #0] \n\t" - // "vmov.f32 q8, #0.0 \n\t" - // "subs %[num], %[num], #32 \n\t" - // "blt end_num_%= \n\t" - // "loop_num_%=: \n\t" - // "pld [%[input_x_ptr], #1024] \n\t" - // - // "vld1.32 {q0, q1}, [%[input_x_ptr]]! \n\t" - // "vld1.32 {q2, q3}, [%[input_x_ptr]]! \n\t" - // "vld1.32 {q4, q5}, [%[input_x_ptr]]! \n\t" - // "vld1.32 {q6, q7}, [%[input_x_ptr]]! \n\t" - // - // "vmax.f32 q0, q0, q8 \n\t" - // "vmax.f32 q1, q1, q8 \n\t" - // "vmax.f32 q2, q2, q8 \n\t" - // "vmax.f32 q3, q3, q8 \n\t" - // "vmax.f32 q4, q4, q8 \n\t" - // "vmax.f32 q5, q5, q8 \n\t" - // "vmax.f32 q6, q6, q8 \n\t" - // "vmax.f32 q7, q7, q8 \n\t" - // - // "vst1.32 {q0, q1}, [%[out_ptr]]! \n\t" - // "vst1.32 {q2, q3}, [%[out_ptr]]! \n\t" - // "vst1.32 {q4, q5}, [%[out_ptr]]! \n\t" - // "vst1.32 {q6, q7}, [%[out_ptr]]! \n\t" - // - // "subs %[num], %[num], #32 \n\t" - // "bge loop_num_%= \n\t" - // "end_num_%=: \n\t" - // "cmp %[num], #0 \n\t" - // "bge end_%= \n\t" - // "mov r6, #4 \n\t" - // "mul r5, %[num], r6 \n\t" - // "add %[input_x_ptr], %[input_x_ptr], r5 \n\t" - // "vld1.32 {q0, q1}, [%[input_x_ptr]]! \n\t" - // "vld1.32 {q2, q3}, [%[input_x_ptr]]! \n\t" - // "vld1.32 {q4, q5}, [%[input_x_ptr]]! \n\t" - // "vld1.32 {q6, q7}, [%[input_x_ptr]]! \n\t" - // "vmax.f32 q0, q0, q8 \n\t" - // "vmax.f32 q1, q1, q8 \n\t" - // "vmax.f32 q2, q2, q8 \n\t" - // "vmax.f32 q3, q3, q8 \n\t" - // "vmax.f32 q4, q4, q8 \n\t" - // "vmax.f32 q5, q5, q8 \n\t" - // "vmax.f32 q6, q6, q8 \n\t" - // "vmax.f32 q7, q7, q8 \n\t" - // "add %[out_ptr], %[out_ptr], r5 \n\t" - // "vst1.32 {q0, q1}, [%[out_ptr]]! \n\t" - // "vst1.32 {q2, q3}, [%[out_ptr]]! \n\t" - // "vst1.32 {q4, q5}, [%[out_ptr]]! \n\t" - // "vst1.32 {q6, q7}, [%[out_ptr]]! \n\t" - // "end_%=: \n\t" - // : - // : - // [out_ptr] "r"(out_ptr), [input_x_ptr] "r"(input_x_ptr), [num] - // "r"(numel) : "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q6", - // "q7", "q8", "r5", - // "r6"); - // } else { - ReluFunctor func_; - math::Transform trans; - trans(input_x_ptr, input_x_ptr + numel, out_ptr, func_); - // } +#if defined(__ARM_NEON__) || defined(__ARM_NEON) +#if __aarch64__ + if (numel > 0) { + int loop = numel >> 0x4; + int remain = numel & 0xF; + float32x4_t zero = vdupq_n_f32(0.f); + for (int i = 0; i < loop; ++i) { + float32x4_t r0 = vld1q_f32(input_x_ptr); + float32x4_t r1 = vld1q_f32(input_x_ptr + 4); + float32x4_t r2 = vld1q_f32(input_x_ptr + 8); + float32x4_t r3 = vld1q_f32(input_x_ptr + 12); + r0 = vmaxq_f32(r0, zero); + r1 = vmaxq_f32(r1, zero); + r2 = vmaxq_f32(r2, zero); + r3 = vmaxq_f32(r3, zero); + vst1q_f32(out_ptr, r0); + vst1q_f32(out_ptr + 4, r1); + vst1q_f32(out_ptr + 8, r2); + vst1q_f32(out_ptr + 12, r3); + input_x_ptr += 16; + out_ptr += 16; + } + for (int i = 0; i < remain; ++i) { + out_ptr[i] = (input_x_ptr[i] > 0) * input_x_ptr[i]; + } +#else + if (numel > 64) { + asm volatile( + "pld [%[input_x_ptr], #0] \n\t" + "vmov.f32 q8, #0.0 \n\t" + "subs %[num], %[num], #32 \n\t" + "blt end_num_%= \n\t" + "loop_num_%=: \n\t" + "pld [%[input_x_ptr], #1024] \n\t" + + "vld1.32 {q0, q1}, [%[input_x_ptr]]! \n\t" + "vld1.32 {q2, q3}, [%[input_x_ptr]]! \n\t" + "vld1.32 {q4, q5}, [%[input_x_ptr]]! 
\n\t" + "vld1.32 {q6, q7}, [%[input_x_ptr]]! \n\t" + + "vmax.f32 q0, q0, q8 \n\t" + "vmax.f32 q1, q1, q8 \n\t" + "vmax.f32 q2, q2, q8 \n\t" + "vmax.f32 q3, q3, q8 \n\t" + "vmax.f32 q4, q4, q8 \n\t" + "vmax.f32 q5, q5, q8 \n\t" + "vmax.f32 q6, q6, q8 \n\t" + "vmax.f32 q7, q7, q8 \n\t" + + "vst1.32 {q0, q1}, [%[out_ptr]]! \n\t" + "vst1.32 {q2, q3}, [%[out_ptr]]! \n\t" + "vst1.32 {q4, q5}, [%[out_ptr]]! \n\t" + "vst1.32 {q6, q7}, [%[out_ptr]]! \n\t" + + "subs %[num], %[num], #32 \n\t" + "bge loop_num_%= \n\t" + "end_num_%=: \n\t" + "cmp %[num], #0 \n\t" + "bge end_%= \n\t" + "mov r6, #4 \n\t" + "mul r5, %[num], r6 \n\t" + "add %[input_x_ptr], %[input_x_ptr], r5 \n\t" + "vld1.32 {q0, q1}, [%[input_x_ptr]]! \n\t" + "vld1.32 {q2, q3}, [%[input_x_ptr]]! \n\t" + "vld1.32 {q4, q5}, [%[input_x_ptr]]! \n\t" + "vld1.32 {q6, q7}, [%[input_x_ptr]]! \n\t" + "vmax.f32 q0, q0, q8 \n\t" + "vmax.f32 q1, q1, q8 \n\t" + "vmax.f32 q2, q2, q8 \n\t" + "vmax.f32 q3, q3, q8 \n\t" + "vmax.f32 q4, q4, q8 \n\t" + "vmax.f32 q5, q5, q8 \n\t" + "vmax.f32 q6, q6, q8 \n\t" + "vmax.f32 q7, q7, q8 \n\t" + "add %[out_ptr], %[out_ptr], r5 \n\t" + "vst1.32 {q0, q1}, [%[out_ptr]]! \n\t" + "vst1.32 {q2, q3}, [%[out_ptr]]! \n\t" + "vst1.32 {q4, q5}, [%[out_ptr]]! \n\t" + "vst1.32 {q6, q7}, [%[out_ptr]]! \n\t" + "end_%=: \n\t" + : + : + [out_ptr] "r"(out_ptr), [input_x_ptr] "r"(input_x_ptr), [num] "r"(numel) + : "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", "q8", "r5", + "r6"); +#endif + } else { +#endif + ReluFunctor func_; + math::Transform trans; + trans(input_x_ptr, input_x_ptr + numel, out_ptr, func_); +#if defined(__ARM_NEON__) || defined(__ARM_NEON) + } +#endif } } // namespace operators } // namespace paddle_mobile diff --git a/src/operators/kernel/central-arm-func/sum_arm_func.h b/src/operators/kernel/central-arm-func/sum_arm_func.h index 25c1c51c7abd62a900665197ab4e221b76a3fa04..36c7ac9694bde85fbf702ad8adf5ffda8744da1d 100644 --- a/src/operators/kernel/central-arm-func/sum_arm_func.h +++ b/src/operators/kernel/central-arm-func/sum_arm_func.h @@ -15,11 +15,14 @@ limitations under the License. 
*/ #ifdef SUM_OP #pragma once +#include #include "operators/math/selected_rows_functor.h" namespace paddle_mobile { namespace operators { + using LoDTensorArray = std::vector; + template void SumCompute(const SumParam ¶m) { auto inputsvars = param.InputsVars(); @@ -63,31 +66,21 @@ void SumCompute(const SumParam ¶m) { std::unique_ptr in0; if (in_place) { // If is in_place, we store the input[0] to in0 - auto *in_sel0 = inputsvars[0]->Get(); + auto *in_sel0 = inputsvars[0]->Get(); auto &rows = in_sel0->rows(); - //#ifdef PADDLE_WITH_CUDA - // std::vector rows_in_cpu; - // rows_in_cpu.reserve(rows.size()); - // for (auto item : rows) { - // rows_in_cpu.push_back(item); - // } - // in0.reset(new framework::SelectedRows(rows_in_cpu, - // in_sel0.height())); - //#else in0.reset(new framework::SelectedRows(rows, in_sel0->height())); - //#endif in0->mutable_value()->ShareDataWith(in_sel0->value()); } - auto get_selected_row = [&](size_t i) -> const SelectedRows & { + auto get_selected_row = [&](size_t i) -> const framework::SelectedRows & { if (i == 0 && in0) { return *in0.get(); } else { - return *(inputsvars[i]->Get()); + return *(inputsvars[i]->Get()); } }; - auto *out = outvar->GetMutable(); + auto *out = outvar->GetMutable(); out->mutable_rows()->clear(); auto *out_value = out->mutable_value(); @@ -150,8 +143,6 @@ void SumCompute(const SumParam ¶m) { } } } else { - if (outvar->IsType()) { - } PADDLE_MOBILE_THROW_EXCEPTION( "Unexpected branch, output variable type is %s", outvar->Type().name()); } diff --git a/src/operators/kernel/dequantize_kernel.h b/src/operators/kernel/dequantize_kernel.h index 3d0437875bb64a0d32948a05725214d666ebfa01..d147e3f94ab87165cceac886289e74747906e047 100644 --- a/src/operators/kernel/dequantize_kernel.h +++ b/src/operators/kernel/dequantize_kernel.h @@ -12,6 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ +#ifdef DEQUANT_OP + #pragma once #include "framework/operator.h" @@ -30,3 +32,5 @@ class DequantizeKernel } // namespace operators } // namespace paddle_mobile + +#endif diff --git a/src/operators/kernel/elementwise_mul_kernel.h b/src/operators/kernel/elementwise_mul_kernel.h index d1e326c6c4e7830c11c387dca03da9858c9a37dd..63f0df4815dc143e482140a855eb254bd016d50c 100644 --- a/src/operators/kernel/elementwise_mul_kernel.h +++ b/src/operators/kernel/elementwise_mul_kernel.h @@ -23,8 +23,6 @@ limitations under the License. */ namespace paddle_mobile { namespace operators { -using namespace framework; - template class ElementwiseMulKernel : public framework::OpKernelBase class SumKernel : public framework::OpKernelBase> { diff --git a/src/operators/math/conv3x3_arm_int8.cpp b/src/operators/math/conv3x3_arm_int8.cpp new file mode 100644 index 0000000000000000000000000000000000000000..283dcb2255b43052dcaf2d622ad629e923810a82 --- /dev/null +++ b/src/operators/math/conv3x3_arm_int8.cpp @@ -0,0 +1,761 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and +limitations under the License. */ + +#ifdef CONV_OP + +#include "operators/math/conv_arm_int8.h" + +namespace paddle_mobile { +namespace operators { + +void conv3x3s1_int8(const framework::Tensor& input, + const framework::Tensor& weight, + framework::Tensor* output) { +#if defined(__ARM_NEON__) || defined(__ARM_NEON) + const int8_t* in_data = input.data(); + const int8_t* w_data = weight.data(); + int32_t* out_data = output->mutable_data(); + // make sure that batch size is 1 + int input_c = input.dims()[1]; + int input_h = input.dims()[2]; + int input_w = input.dims()[3]; + int output_c = output->dims()[1]; + int output_h = output->dims()[2]; + int output_w = output->dims()[3]; + int image_size = input_h * input_w; + int out_image_size = output_h * output_w; + memset(out_data, 0, output_c * out_image_size * sizeof(int32_t)); +#if __aarch64__ + // TODO(hjchen2) +#else + int oc = 0; + #pragma omp parallel for + for (; oc < output_c - 1; oc += 2) { + for (int ic = 0; ic < input_c; ++ic) { + const int8_t* kernel0 = w_data + (oc * input_c + ic) * 9; + const int8_t* kernel1 = w_data + ((oc + 1) * input_c + ic) * 9; + int32_t* output0 = out_data + oc * out_image_size; + int32_t* output0n = output0 + output_w; + int32_t* output1 = out_data + (oc + 1) * out_image_size; + int32_t* output1n = output1 + output_w; + + int oh = 0; + for (; oh < output_h - 1; oh += 2) { + const int8_t* r0 = in_data + ic * image_size + oh * input_w; + const int8_t* r1 = r0 + input_w; + const int8_t* r2 = r1 + input_w; + const int8_t* r3 = r2 + input_w; + + int ow = output_w >> 3; + int remain = output_w & 0x7; + if (ow > 0) { + asm volatile( + "vld1.8 {d0}, [%[kernel0]] \n" + "ldr r5, [%[kernel0], #8] \n" + "vld1.8 {d1}, [%[kernel1]] \n" + "ldr r6, [%[kernel1], #8] \n" + + "0: \n" + "vld1.8 {d2-d3}, [%[r0]] \n" // r0 + "add %[r0], #8 \n" + "vext.8 d4, d2, d3, #1 \n" + "vext.8 d5, d2, d3, #2 \n" + "vdup.s8 d6, d0[0] \n" + "vdup.s8 d7, d0[1] \n" + "vdup.s8 d8, d0[2] \n" + "vdup.s8 d9, d1[0] \n" + "vdup.s8 d10, d1[1] \n" + "vdup.s8 d11, d1[2] \n" + "vmull.s8 q6, d2, d6 \n" + "vmull.s8 q7, d4, d7 \n" + "vmlal.s8 q6, d5, d8 \n" + "vaddl.s16 q12, d12, d14 \n" + "vaddl.s16 q13, d13, d15 \n" + "vmull.s8 q6, d2, d9 \n" + "vmull.s8 q7, d4, d10 \n" + "vmlal.s8 q6, d5, d11 \n" + "vaddl.s16 q14, d12, d14 \n" + "vaddl.s16 q15, d13, d15 \n" + + "vld1.8 {d2-d3}, [%[r1]] \n" // r1 + "add %[r1], #8 \n" + "vext.8 d4, d2, d3, #1 \n" + "vext.8 d5, d2, d3, #2 \n" + + "vmull.s8 q6, d2, d6 \n" // next row + "vmull.s8 q7, d4, d7 \n" + "vmlal.s8 q6, d5, d8 \n" + "vaddl.s16 q8, d12, d14 \n" + "vaddl.s16 q9, d13, d15 \n" + "vmull.s8 q6, d2, d9 \n" + "vmull.s8 q7, d4, d10 \n" + "vmlal.s8 q6, d5, d11 \n" + "vaddl.s16 q10, d12, d14 \n" + "vaddl.s16 q11, d13, d15 \n" + + "vdup.s8 d6, d0[3] \n" + "vdup.s8 d7, d0[4] \n" + "vdup.s8 d8, d0[5] \n" + "vdup.s8 d9, d1[3] \n" + "vdup.s8 d10, d1[4] \n" + "vdup.s8 d11, d1[5] \n" + "vmull.s8 q6, d2, d6 \n" + "vmull.s8 q7, d4, d7 \n" + "vmlal.s8 q6, d5, d8 \n" + "vaddw.s16 q12, q12, d12 \n" + "vaddw.s16 q13, q13, d13 \n" + "vaddw.s16 q12, q12, d14 \n" + "vaddw.s16 q13, q13, d15 \n" + "vmull.s8 q6, d2, d9 \n" + "vmull.s8 q7, d4, d10 \n" + "vmlal.s8 q6, d5, d11 \n" + "vaddw.s16 q14, q14, d12 \n" + "vaddw.s16 q15, q15, d13 \n" + "vaddw.s16 q14, q14, d14 \n" + "vaddw.s16 q15, q15, d15 \n" + + "vld1.8 {d2-d3}, [%[r2]] \n" // r2 + "add %[r2], #8 \n" + "vext.8 d4, d2, d3, #1 \n" + "vext.8 d5, d2, d3, #2 \n" + + "vmull.s8 q6, d2, d6 \n" // next row + 
"vmull.s8 q7, d4, d7 \n" + "vmlal.s8 q6, d5, d8 \n" + "vaddw.s16 q8, q8, d12 \n" + "vaddw.s16 q8, q8, d14 \n" + "vaddw.s16 q9, q9, d13 \n" + "vaddw.s16 q9, q9, d15 \n" + "vmull.s8 q6, d2, d9 \n" + "vmull.s8 q7, d4, d10 \n" + "vmlal.s8 q6, d5, d11 \n" + "vaddw.s16 q10, q10, d12 \n" + "vaddw.s16 q11, q11, d13 \n" + "vaddw.s16 q10, q10, d14 \n" + "vaddw.s16 q11, q11, d15 \n" + + "vdup.s8 d6, d0[6] \n" + "vdup.s8 d7, d0[7] \n" + "vdup.s8 d8, r5 \n" + "vdup.s8 d9, d1[6] \n" + "vdup.s8 d10, d1[7] \n" + "vdup.s8 d11, r6 \n" + "vmull.s8 q6, d2, d6 \n" + "vmull.s8 q7, d4, d7 \n" + "vmlal.s8 q6, d5, d8 \n" + "vaddw.s16 q12, q12, d12 \n" + "vaddw.s16 q13, q13, d13 \n" + "vaddw.s16 q12, q12, d14 \n" + "vaddw.s16 q13, q13, d15 \n" + + "vld1.32 {d12-d15}, [%[output0]] \n" + "vadd.s32 q6, q6, q12 \n" + "vadd.s32 q7, q7, q13 \n" + "vst1.32 {d12-d15}, [%[output0]]! \n" + + "vmull.s8 q6, d2, d9 \n" + "vmull.s8 q7, d4, d10 \n" + "vmlal.s8 q6, d5, d11 \n" + "vaddw.s16 q14, q14, d12 \n" + "vaddw.s16 q15, q15, d13 \n" + "vaddw.s16 q14, q14, d14 \n" + "vaddw.s16 q15, q15, d15 \n" + + "vld1.32 {d12-d15}, [%[output1]] \n" + "vadd.s32 q6, q6, q14 \n" + "vadd.s32 q7, q7, q15 \n" + "vst1.32 {d12-d15}, [%[output1]]! \n" + + "vld1.8 {d2-d3}, [%[r3]] \n" // r3 + "add %[r3], #8 \n" + "vext.8 d4, d2, d3, #1 \n" + "vext.8 d5, d2, d3, #2 \n" + + "vmull.s8 q6, d2, d6 \n" // next row + "vmull.s8 q7, d4, d7 \n" + "vmlal.s8 q6, d5, d8 \n" + "vaddw.s16 q8, q8, d12 \n" + "vaddw.s16 q9, q9, d15 \n" + "vaddw.s16 q8, q8, d14 \n" + "vaddw.s16 q9, q9, d13 \n" + + "vld1.32 {d12-d15}, [%[output0n]] \n" + "vadd.s32 q6, q6, q8 \n" + "vadd.s32 q7, q7, q9 \n" + "vst1.32 {d12-d15}, [%[output0n]]! \n" + + "vmull.s8 q6, d2, d9 \n" + "vmull.s8 q7, d4, d10 \n" + "vmlal.s8 q6, d5, d11 \n" + "vaddw.s16 q10, q10, d12 \n" + "vaddw.s16 q11, q11, d15 \n" + "vaddw.s16 q10, q10, d14 \n" + "vaddw.s16 q11, q11, d13 \n" + + "vld1.32 {d12-d15}, [%[output1n]] \n" + "vadd.s32 q6, q6, q10 \n" + "vadd.s32 q7, q7, q11 \n" + "vst1.32 {d12-d15}, [%[output1n]]! \n" + + "subs %[ow], #1 \n" + "bne 0b \n" + : [r0] "+r"(r0), [r1] "+r"(r1), [r2] "+r"(r2), [r3] "+r"(r3), + [ow] "+r"(ow), [output0] "+r"(output0), [output1] "+r"(output1), + [output0n] "+r"(output0n), [output1n] "+r"(output1n) + : [kernel0] "r"(kernel0), [kernel1] "r"(kernel1) + : "cc", "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", + "q8", "q9", "q10", "q11", "q12", "q13", "q14", "q15", "r5", + "r6"); + } + if (remain > 0) { + asm volatile( + "vld1.8 {d0}, [%[kernel0]] \n" + "ldr r5, [%[kernel0], #8] \n" + "vld1.8 {d1}, [%[kernel1]] \n" + "ldr r6, [%[kernel1], #8] \n" + + "0: \n" + "vld1.8 d4, [%[r0]] \n" + "vld1.8 d5, [%[r1]] \n" + "vld1.8 d6, [%[r2]] \n" + "vld1.8 d7, [%[r3]] \n" + "add %[r0], #1 \n" + "add %[r1], #1 \n" + "add %[r2], #1 \n" + "add %[r3], #1 \n" + "vdup.s8 d2, r5 \n" + "vdup.s8 d3, r6 \n" + "vext.8 d8, d0, d2, #3 \n" + "vext.8 d9, d0, d2, #6 \n" + "vext.8 d10, d1, d3, #3 \n" + "vext.8 d11, d1, d3, #6 \n" + + "vmull.s8 q6, d4, d0 \n" + "vmull.s8 q7, d5, d8 \n" + "vmlal.s8 q6, d6, d9 \n" + "vaddl.s16 q12, d12, d14 \n" + "vdup.s32 d2, d24[1] \n" + "vadd.s32 d24, d24, d2 \n" + "vadd.s32 d24, d24, d25 \n" + "vmull.s8 q6, d4, d1 \n" + "vmull.s8 q7, d5, d10 \n" + "vmlal.s8 q6, d6, d11 \n" + "vaddl.s16 q13, d12, d14 \n" + "vdup.s32 d2, d26[1] \n" + "vadd.s32 d26, d26, d2 \n" + "vadd.s32 d26, d26, d27 \n" + + "ldr r7, [%[output0]] \n" + "vdup.s32 d14, r7 \n" + "vadd.s32 d14, d14, d24 \n" + "vst1.32 d14[0], [%[output0]]! 
\n" + "ldr r7, [%[output1]] \n" + "vdup.s32 d14, r7 \n" + "vadd.s32 d14, d14, d26 \n" + "vst1.32 d14[0], [%[output1]]! \n" + + "vmull.s8 q6, d5, d0 \n" + "vmull.s8 q7, d6, d8 \n" + "vmlal.s8 q6, d7, d9 \n" + "vaddl.s16 q12, d12, d14 \n" + "vdup.s32 d2, d24[1] \n" + "vadd.s32 d24, d24, d2 \n" + "vadd.s32 d24, d24, d25 \n" + "vmull.s8 q6, d5, d1 \n" + "vmull.s8 q7, d6, d10 \n" + "vmlal.s8 q6, d7, d11 \n" + "vaddl.s16 q13, d12, d14 \n" + "vdup.s32 d2, d26[1] \n" + "vadd.s32 d26, d26, d2 \n" + "vadd.s32 d26, d26, d27 \n" + + "ldr r7, [%[output0n]] \n" + "vdup.s32 d14, r7 \n" + "vadd.s32 d14, d14, d24 \n" + "vst1.32 d14[0], [%[output0n]]! \n" + "ldr r7, [%[output1n]] \n" + "vdup.s32 d14, r7 \n" + "vadd.s32 d14, d14, d26 \n" + "vst1.32 d14[0], [%[output1n]]! \n" + + "subs %[remain], #1 \n" + "bne 0b \n" + : [r0] "+r"(r0), [r1] "+r"(r1), [r2] "+r"(r2), [r3] "+r"(r3), + [remain] "+r"(remain), [output0] "+r"(output0), + [output1] "+r"(output1), [output0n] "+r"(output0n), + [output1n] "+r"(output1n) + : [kernel0] "r"(kernel0), [kernel1] "r"(kernel1) + : "cc", "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", + "q8", "q9", "q10", "r5", "r6", "r7"); + } + output0 += output_w; + output1 += output_w; + output0n += output_w; + output1n += output_w; + } + // remain output height + for (; oh < output_h; ++oh) { + const int8_t* r0 = in_data + ic * image_size + oh * input_w; + const int8_t* r1 = r0 + input_w; + const int8_t* r2 = r1 + input_w; + const int8_t* r3 = r2 + input_w; + const int8_t* r4 = r3 + input_w; + + int ow = output_w >> 3; + int remain = output_w & 0x7; + if (ow > 0) { + asm volatile( + "vld1.8 {d0}, [%[kernel0]] \n" + "ldr r5, [%[kernel0], #8] \n" + "vld1.8 {d1}, [%[kernel1]] \n" + "ldr r6, [%[kernel1], #8] \n" + + "0: \n" + "vld1.8 {d2-d3}, [%[r0]] \n" // r0 + "add %[r0], #8 \n" + "vext.8 d4, d2, d3, #1 \n" + "vext.8 d5, d2, d3, #2 \n" + "vdup.s8 d6, d0[0] \n" + "vdup.s8 d7, d0[1] \n" + "vdup.s8 d8, d0[2] \n" + "vdup.s8 d9, d1[0] \n" + "vdup.s8 d10, d1[1] \n" + "vdup.s8 d11, d1[2] \n" + "vmull.s8 q6, d2, d6 \n" + "vmull.s8 q7, d4, d7 \n" + "vmlal.s8 q6, d5, d8 \n" + "vaddl.s16 q12, d12, d14 \n" + "vaddl.s16 q13, d13, d15 \n" + "vmull.s8 q6, d2, d9 \n" + "vmull.s8 q7, d4, d10 \n" + "vmlal.s8 q6, d5, d11 \n" + "vaddl.s16 q14, d12, d14 \n" + "vaddl.s16 q15, d13, d15 \n" + + "vld1.8 {d2-d3}, [%[r1]] \n" // r1 + "add %[r1], #8 \n" + "vext.8 d4, d2, d3, #1 \n" + "vext.8 d5, d2, d3, #2 \n" + "vdup.s8 d6, d0[3] \n" + "vdup.s8 d7, d0[4] \n" + "vdup.s8 d8, d0[5] \n" + "vdup.s8 d9, d1[3] \n" + "vdup.s8 d10, d1[4] \n" + "vdup.s8 d11, d1[5] \n" + "vmull.s8 q6, d2, d6 \n" + "vmull.s8 q7, d4, d7 \n" + "vmlal.s8 q6, d5, d8 \n" + "vaddw.s16 q12, q12, d12 \n" + "vaddw.s16 q12, q12, d14 \n" + "vaddw.s16 q13, q13, d13 \n" + "vaddw.s16 q13, q13, d15 \n" + "vmull.s8 q6, d2, d9 \n" + "vmull.s8 q7, d4, d10 \n" + "vmlal.s8 q6, d5, d11 \n" + "vaddw.s16 q14, q14, d12 \n" + "vaddw.s16 q14, q14, d14 \n" + "vaddw.s16 q15, q15, d13 \n" + "vaddw.s16 q15, q15, d15 \n" + + "vld1.8 {d2-d3}, [%[r2]] \n" // r2 + "add %[r2], #8 \n" + "vext.8 d4, d2, d3, #1 \n" + "vext.8 d5, d2, d3, #2 \n" + "vdup.s8 d6, d0[6] \n" + "vdup.s8 d7, d0[7] \n" + "vdup.s8 d8, r5 \n" + "vdup.s8 d9, d1[6] \n" + "vdup.s8 d10, d1[7] \n" + "vdup.s8 d11, r6 \n" + "vmull.s8 q6, d2, d6 \n" + "vmull.s8 q7, d4, d7 \n" + "vmlal.s8 q6, d5, d8 \n" + "vaddw.s16 q12, q12, d12 \n" + "vaddw.s16 q12, q12, d14 \n" + "vaddw.s16 q13, q13, d13 \n" + "vaddw.s16 q13, q13, d15 \n" + "vmull.s8 q6, d2, d9 \n" + "vmull.s8 q7, d4, d10 \n" + "vmlal.s8 q6, d5, d11 \n" + 
"vaddw.s16 q14, q14, d12 \n" + "vaddw.s16 q14, q14, d14 \n" + "vaddw.s16 q15, q15, d13 \n" + "vaddw.s16 q15, q15, d15 \n" + + "vld1.32 {d12-d15}, [%[output0]] \n" + "vadd.s32 q6, q6, q12 \n" + "vadd.s32 q7, q7, q13 \n" + "vst1.32 {d12-d15}, [%[output0]]! \n" + "vld1.32 {d12-d15}, [%[output1]] \n" + "vadd.s32 q6, q6, q14 \n" + "vadd.s32 q7, q7, q15 \n" + "vst1.32 {d12-d15}, [%[output1]]! \n" + + "subs %[ow], #1 \n" + "bne 0b \n" + : [r0] "+r"(r0), [r1] "+r"(r1), [r2] "+r"(r2), [ow] "+r"(ow), + [output0] "+r"(output0), [output1] "+r"(output1) + : [kernel0] "r"(kernel0), [kernel1] "r"(kernel1) + : "cc", "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", + "q8", "q9", "q10", "q11", "q12", "q13", "q14", "q15", "r5", + "r6"); + } + + if (remain > 0) { + asm volatile( + "vld1.8 {d0}, [%[kernel0]] \n" + "ldr r5, [%[kernel0], #8] \n" + "vld1.8 {d1}, [%[kernel1]] \n" + "ldr r6, [%[kernel1], #8] \n" + + "0: \n" + "vld1.8 d4, [%[r0]] \n" + "vld1.8 d5, [%[r1]] \n" + "vld1.8 d6, [%[r2]] \n" + "add %[r0], #1 \n" + "add %[r1], #1 \n" + "add %[r2], #1 \n" + "vdup.s8 d2, r5 \n" + "vdup.s8 d3, r6 \n" + "vext.8 d8, d0, d2, #3 \n" + "vext.8 d9, d0, d2, #6 \n" + "vext.8 d10, d1, d3, #3 \n" + "vext.8 d11, d1, d3, #6 \n" + + "vmull.s8 q6, d4, d0 \n" + "vmull.s8 q7, d5, d8 \n" + "vmlal.s8 q6, d6, d9 \n" + "vaddl.s16 q12, d12, d14 \n" + "vdup.s32 d2, d24[1] \n" + "vadd.s32 d24, d24, d2 \n" + "vadd.s32 d24, d24, d25 \n" + "vmull.s8 q6, d4, d1 \n" + "vmull.s8 q7, d5, d10 \n" + "vmlal.s8 q6, d6, d11 \n" + "vaddl.s16 q13, d12, d14 \n" + "vdup.s32 d2, d26[1] \n" + "vadd.s32 d26, d26, d2 \n" + "vadd.s32 d26, d26, d27 \n" + + "ldr r7, [%[output0]] \n" + "vdup.s32 d14, r7 \n" + "vadd.s32 d14, d14, d24 \n" + "vst1.32 d14[0], [%[output0]]! \n" + "ldr r7, [%[output1]] \n" + "vdup.s32 d14, r7 \n" + "vadd.s32 d14, d14, d26 \n" + "vst1.32 d14[0], [%[output1]]! 
\n" + + "subs %[remain], #1 \n" + "bne 0b \n" + : [r0] "+r"(r0), [r1] "+r"(r1), [r2] "+r"(r2), + [remain] "+r"(remain), [output0] "+r"(output0), + [output1] "+r"(output1) + : [kernel0] "r"(kernel0), [kernel1] "r"(kernel1) + : "cc", "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", + "q8", "q9", "q10", "r5", "r6", "r7"); + } + } + } + } + + for (; oc < output_c; ++oc) { + for (int ic = 0; ic < input_c; ++ic) { + const int8_t* kernel0 = w_data + (oc * input_c + ic) * 9; + int32_t* output0 = out_data + oc * out_image_size; + int32_t* output0n = output0 + output_w; + + int oh = 0; + for (; oh < output_h - 1; oh += 2) { + const int8_t* r0 = in_data + ic * image_size + oh * input_w; + const int8_t* r1 = r0 + input_w; + const int8_t* r2 = r1 + input_w; + const int8_t* r3 = r2 + input_w; + + int ow = output_w >> 3; + int remain = output_w & 0x7; + if (ow > 0) { + asm volatile( + "vld1.8 {d0}, [%[kernel0]] \n" + "ldr r5, [%[kernel0], #8] \n" + + "0: \n" + "vld1.8 {d2-d3}, [%[r0]] \n" // r0 + "add %[r0], #8 \n" + "vext.8 d4, d2, d3, #1 \n" + "vext.8 d5, d2, d3, #2 \n" + "vdup.s8 d6, d0[0] \n" + "vdup.s8 d7, d0[1] \n" + "vdup.s8 d8, d0[2] \n" + "vmull.s8 q6, d2, d6 \n" + "vmull.s8 q7, d4, d7 \n" + "vmlal.s8 q6, d5, d8 \n" + "vaddl.s16 q12, d12, d14 \n" + "vaddl.s16 q13, d13, d15 \n" + + "vld1.8 {d2-d3}, [%[r1]] \n" // r1 + "add %[r1], #8 \n" + "vext.8 d4, d2, d3, #1 \n" + "vext.8 d5, d2, d3, #2 \n" + + "vmull.s8 q6, d2, d6 \n" // next row + "vmull.s8 q7, d4, d7 \n" + "vmlal.s8 q6, d5, d8 \n" + "vaddl.s16 q8, d12, d14 \n" + "vaddl.s16 q9, d13, d15 \n" + + "vdup.s8 d6, d0[3] \n" + "vdup.s8 d7, d0[4] \n" + "vdup.s8 d8, d0[5] \n" + "vmull.s8 q6, d2, d6 \n" + "vmull.s8 q7, d4, d7 \n" + "vmlal.s8 q6, d5, d8 \n" + "vaddw.s16 q12, q12, d12 \n" + "vaddw.s16 q12, q12, d14 \n" + "vaddw.s16 q13, q13, d13 \n" + "vaddw.s16 q13, q13, d15 \n" + + "vld1.8 {d2-d3}, [%[r2]] \n" // r2 + "add %[r2], #8 \n" + "vext.8 d4, d2, d3, #1 \n" + "vext.8 d5, d2, d3, #2 \n" + + "vmull.s8 q6, d2, d6 \n" // next row + "vmull.s8 q7, d4, d7 \n" + "vmlal.s8 q6, d5, d8 \n" + "vaddw.s16 q8, q8, d12 \n" + "vaddw.s16 q8, q8, d14 \n" + "vaddw.s16 q9, q9, d13 \n" + "vaddw.s16 q9, q9, d15 \n" + + "vdup.s8 d6, d0[6] \n" + "vdup.s8 d7, d0[7] \n" + "vdup.s8 d8, r5 \n" + "vmull.s8 q6, d2, d6 \n" + "vmull.s8 q7, d4, d7 \n" + "vmlal.s8 q6, d5, d8 \n" + "vaddw.s16 q12, q12, d12 \n" + "vaddw.s16 q12, q12, d14 \n" + "vaddw.s16 q13, q13, d13 \n" + "vaddw.s16 q13, q13, d15 \n" + + "vld1.32 {d12-d15}, [%[output0]] \n" + "vadd.s32 q6, q6, q12 \n" + "vadd.s32 q7, q7, q13 \n" + "vst1.32 {d12-d15}, [%[output0]]! \n" + + "vld1.8 {d2-d3}, [%[r3]] \n" // r3 + "add %[r3], #8 \n" + "vext.8 d4, d2, d3, #1 \n" + "vext.8 d5, d2, d3, #2 \n" + + "vmull.s8 q6, d2, d6 \n" // next row + "vmull.s8 q7, d4, d7 \n" + "vmlal.s8 q6, d5, d8 \n" + "vaddw.s16 q8, q8, d12 \n" + "vaddw.s16 q8, q8, d14 \n" + "vaddw.s16 q9, q9, d13 \n" + "vaddw.s16 q9, q9, d15 \n" + + "vld1.32 {d12-d15}, [%[output0n]] \n" + "vadd.s32 q6, q6, q8 \n" + "vadd.s32 q7, q7, q9 \n" + "vst1.32 {d12-d15}, [%[output0n]]! 
\n" + + "subs %[ow], #1 \n" + "bne 0b \n" + : [r0] "+r"(r0), [r1] "+r"(r1), [r2] "+r"(r2), [r3] "+r"(r3), + [ow] "+r"(ow), [output0] "+r"(output0), + [output0n] "+r"(output0n) + : [kernel0] "r"(kernel0) + : "cc", "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", + "q8", "q9", "q10", "q11", "q12", "q13", "q14", "q15", "r5"); + } + if (remain > 0) { + asm volatile( + "vld1.8 {d0}, [%[kernel0]] \n" + "ldr r5, [%[kernel0], #8] \n" + + "0: \n" + "vld1.8 d4, [%[r0]] \n" + "vld1.8 d5, [%[r1]] \n" + "vld1.8 d6, [%[r2]] \n" + "vld1.8 d7, [%[r3]] \n" + "add %[r0], #1 \n" + "add %[r1], #1 \n" + "add %[r2], #1 \n" + "add %[r3], #1 \n" + "vdup.s8 d2, r5 \n" + "vext.8 d8, d0, d2, #3 \n" + "vext.8 d9, d0, d2, #6 \n" + + "vmull.s8 q6, d4, d0 \n" + "vmull.s8 q7, d5, d8 \n" + "vmlal.s8 q6, d6, d9 \n" + "vaddl.s16 q12, d12, d14 \n" + "vdup.s32 d2, d24[1] \n" + "vadd.s32 d24, d24, d2 \n" + "vadd.s32 d24, d24, d25 \n" + + "ldr r7, [%[output0]] \n" + "vdup.s32 d14, r7 \n" + "vadd.s32 d14, d14, d24 \n" + "vst1.32 d14[0], [%[output0]]! \n" + + "vmull.s8 q6, d5, d0 \n" + "vmull.s8 q7, d6, d8 \n" + "vmlal.s8 q6, d7, d9 \n" + "vaddl.s16 q12, d12, d14 \n" + "vdup.s32 d2, d24[1] \n" + "vadd.s32 d24, d24, d2 \n" + "vadd.s32 d24, d24, d25 \n" + + "ldr r7, [%[output0n]] \n" + "vdup.s32 d14, r7 \n" + "vadd.s32 d14, d14, d24 \n" + "vst1.32 d14[0], [%[output0n]]! \n" + + "subs %[remain], #1 \n" + "bne 0b \n" + : [r0] "+r"(r0), [r1] "+r"(r1), [r2] "+r"(r2), [r3] "+r"(r3), + [remain] "+r"(remain), [output0] "+r"(output0), + [output0n] "+r"(output0n) + : [kernel0] "r"(kernel0) + : "cc", "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", + "q8", "q9", "q10", "r5", "r7"); + } + output0 += output_w; + output0n += output_w; + } + // remain output height + for (; oh < output_h; ++oh) { + const int8_t* r0 = in_data + ic * image_size + oh * input_w; + const int8_t* r1 = r0 + input_w; + const int8_t* r2 = r1 + input_w; + + int ow = output_w >> 3; + int remain = output_w & 0x7; + if (ow > 0) { + asm volatile( + "vld1.8 {d0}, [%[kernel0]] \n" + "ldr r5, [%[kernel0], #8] \n" + + "0: \n" + "vld1.8 {d2-d3}, [%[r0]] \n" // r0 + "add %[r0], #8 \n" + "vext.8 d4, d2, d3, #1 \n" + "vext.8 d5, d2, d3, #2 \n" + "vdup.s8 d6, d0[0] \n" + "vdup.s8 d7, d0[1] \n" + "vdup.s8 d8, d0[2] \n" + "vmull.s8 q6, d2, d6 \n" + "vmull.s8 q7, d4, d7 \n" + "vmlal.s8 q6, d5, d8 \n" + "vaddl.s16 q12, d12, d14 \n" + "vaddl.s16 q13, d13, d15 \n" + + "vld1.8 {d2-d3}, [%[r1]] \n" // r1 + "add %[r1], #8 \n" + "vext.8 d4, d2, d3, #1 \n" + "vext.8 d5, d2, d3, #2 \n" + "vdup.s8 d6, d0[3] \n" + "vdup.s8 d7, d0[4] \n" + "vdup.s8 d8, d0[5] \n" + "vmull.s8 q6, d2, d6 \n" + "vmull.s8 q7, d4, d7 \n" + "vmlal.s8 q6, d5, d8 \n" + "vaddw.s16 q12, q12, d12 \n" + "vaddw.s16 q12, q12, d14 \n" + "vaddw.s16 q13, q13, d13 \n" + "vaddw.s16 q13, q13, d15 \n" + + "vld1.8 {d2-d3}, [%[r2]] \n" // r2 + "add %[r2], #8 \n" + "vext.8 d4, d2, d3, #1 \n" + "vext.8 d5, d2, d3, #2 \n" + "vdup.s8 d6, d0[6] \n" + "vdup.s8 d7, d0[7] \n" + "vdup.s8 d8, r5 \n" + "vmull.s8 q6, d2, d6 \n" + "vmull.s8 q7, d4, d7 \n" + "vmlal.s8 q6, d5, d8 \n" + "vaddw.s16 q12, q12, d12 \n" + "vaddw.s16 q12, q12, d14 \n" + "vaddw.s16 q13, q13, d13 \n" + "vaddw.s16 q13, q13, d15 \n" + + "vld1.32 {d12-d15}, [%[output0]] \n" + "vadd.s32 q6, q6, q12 \n" + "vadd.s32 q7, q7, q13 \n" + "vst1.32 {d12-d15}, [%[output0]]! 
\n" + + "subs %[ow], #1 \n" + "bne 0b \n" + : [r0] "+r"(r0), [r1] "+r"(r1), [r2] "+r"(r2), [ow] "+r"(ow), + [output0] "+r"(output0) + : [kernel0] "r"(kernel0) + : "cc", "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", + "q8", "q9", "q10", "q11", "q12", "q13", "q14", "q15", "r5"); + } + + if (remain > 0) { + asm volatile( + "vld1.8 {d0}, [%[kernel0]] \n" + "ldr r5, [%[kernel0], #8] \n" + + "0: \n" + "vld1.8 d4, [%[r0]] \n" + "vld1.8 d5, [%[r1]] \n" + "vld1.8 d6, [%[r2]] \n" + "add %[r0], #1 \n" + "add %[r1], #1 \n" + "add %[r2], #1 \n" + "vdup.s8 d2, r5 \n" + "vext.8 d8, d0, d2, #3 \n" + "vext.8 d9, d0, d2, #6 \n" + + "vmull.s8 q6, d4, d0 \n" + "vmull.s8 q7, d5, d8 \n" + "vmlal.s8 q6, d6, d9 \n" + "vaddl.s16 q12, d12, d14 \n" + "vdup.s32 d2, d24[1] \n" + "vadd.s32 d24, d24, d2 \n" + "vadd.s32 d24, d24, d25 \n" + + "ldr r7, [%[output0]] \n" + "vdup.s32 d14, r7 \n" + "vadd.s32 d14, d14, d24 \n" + "vst1.32 d14[0], [%[output0]]! \n" + + "subs %[remain], #1 \n" + "bne 0b \n" + : [r0] "+r"(r0), [r1] "+r"(r1), [r2] "+r"(r2), + [remain] "+r"(remain), [output0] "+r"(output0) + : [kernel0] "r"(kernel0) + : "cc", "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", + "q8", "q9", "q10", "r5", "r7"); + } + } + } + } +#endif +#else +// TODO(hjchen2) +#endif +} + +} // namespace operators +} // namespace paddle_mobile + +#endif diff --git a/src/operators/math/conv5x5_arm_int8.cpp b/src/operators/math/conv5x5_arm_int8.cpp new file mode 100644 index 0000000000000000000000000000000000000000..c861c22d184d5428f3ab9c8f3a69b9aca5b697bd --- /dev/null +++ b/src/operators/math/conv5x5_arm_int8.cpp @@ -0,0 +1,551 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
*/ + +#ifdef CONV_OP + +#include "operators/math/conv_arm_int8.h" + +namespace paddle_mobile { +namespace operators { + +void conv5x5s1_int8(const framework::Tensor& input, + const framework::Tensor& weight, + framework::Tensor* output) { +#if defined(__ARM_NEON__) || defined(__ARM_NEON) + const int8_t* in_data = input.data(); + const int8_t* w_data = weight.data(); + int32_t* out_data = output->mutable_data(); + // make sure that batch size is 1 + int input_c = input.dims()[1]; + int input_h = input.dims()[2]; + int input_w = input.dims()[3]; + int output_c = output->dims()[1]; + int output_h = output->dims()[2]; + int output_w = output->dims()[3]; + int image_size = input_h * input_w; + int out_image_size = output_h * output_w; + memset(out_data, 0, output_c * out_image_size * sizeof(int32_t)); +#if __aarch64__ + // TODO(hjchen2) +#else + #pragma omp parallel for + for (int oc = 0; oc < output_c; ++oc) { + for (int ic = 0; ic < input_c; ++ic) { + const int8_t* kernel = w_data + (oc * input_c + ic) * 25; + int32_t* output0 = out_data + oc * out_image_size; + int32_t* output1 = output0 + output_w; + int oh = 0; + for (; oh < output_h - 1; oh += 2) { + const int8_t* r0 = in_data + ic * image_size + oh * input_w; + const int8_t* r1 = r0 + input_w; + const int8_t* r2 = r1 + input_w; + const int8_t* r3 = r2 + input_w; + const int8_t* r4 = r3 + input_w; + const int8_t* r5 = r4 + input_w; + + int ow = output_w >> 3; + int remain = output_w & 0x7; + if (ow > 0) { + asm volatile("vld1.8 {d0-d3}, [%[kernel]] \n" + : [kernel] "+r"(kernel) + : + : "cc", "memory", "q0", "q1"); + asm volatile( + "0: \n" + "vld1.8 {d4-d5}, [%[r0]] \n" // r0 + "add %[r0], #8 \n" + "vext.8 d6, d4, d5, #1 \n" + "vext.8 d7, d4, d5, #2 \n" + "vext.8 d8, d4, d5, #3 \n" + "vext.8 d9, d4, d5, #4 \n" + "vdup.s8 d10, d0[0] \n" + "vdup.s8 d11, d0[1] \n" + "vdup.s8 d12, d0[2] \n" + "vdup.s8 d13, d0[3] \n" + "vdup.s8 d14, d0[4] \n" + "vmull.s8 q8, d4, d10 \n" + "vmull.s8 q9, d6, d11 \n" + "vmlal.s8 q8, d7, d12 \n" + "vmlal.s8 q9, d8, d13 \n" + "vaddl.s16 q14, d16, d18 \n" + "vaddl.s16 q15, d17, d19 \n" + "vmull.s8 q8, d9, d14 \n" + "vaddw.s16 q14, q14, d16 \n" + "vaddw.s16 q15, q15, d17 \n" + + "vld1.8 {d4-d5}, [%[r1]] \n" // r1 + "add %[r1], #8 \n" + "vext.8 d6, d4, d5, #1 \n" + "vext.8 d7, d4, d5, #2 \n" + "vext.8 d8, d4, d5, #3 \n" + "vext.8 d9, d4, d5, #4 \n" + + "vmull.s8 q8, d4, d10 \n" // next row + "vmull.s8 q9, d6, d11 \n" + "vmlal.s8 q8, d7, d12 \n" + "vmlal.s8 q9, d8, d13 \n" + "vaddl.s16 q10, d16, d18 \n" + "vaddl.s16 q11, d17, d19 \n" + "vmull.s8 q8, d9, d14 \n" + "vaddw.s16 q10, q10, d16 \n" + "vaddw.s16 q11, q11, d17 \n" + + "vdup.s8 d10, d0[5] \n" + "vdup.s8 d11, d0[6] \n" + "vdup.s8 d12, d0[7] \n" + "vdup.s8 d13, d1[0] \n" + "vdup.s8 d14, d1[1] \n" + "vmull.s8 q8, d4, d10 \n" + "vmull.s8 q9, d6, d11 \n" + "vmlal.s8 q8, d7, d12 \n" + "vmlal.s8 q9, d8, d13 \n" + "vaddl.s16 q12, d16, d18 \n" + "vaddl.s16 q13, d17, d19 \n" + "vmull.s8 q8, d9, d14 \n" + "vaddw.s16 q12, q12, d16 \n" + "vaddw.s16 q13, q13, d17 \n" + "vadd.s32 q14, q14, q12 \n" + "vadd.s32 q15, q15, q13 \n" + + "vld1.8 {d4-d5}, [%[r2]] \n" // r2 + "add %[r2], #8 \n" + "vext.8 d6, d4, d5, #1 \n" + "vext.8 d7, d4, d5, #2 \n" + "vext.8 d8, d4, d5, #3 \n" + "vext.8 d9, d4, d5, #4 \n" + + "vmull.s8 q8, d4, d10 \n" // next row + "vmull.s8 q9, d6, d11 \n" + "vmlal.s8 q8, d7, d12 \n" + "vmlal.s8 q9, d8, d13 \n" + "vaddl.s16 q12, d16, d18 \n" + "vaddl.s16 q13, d17, d19 \n" + "vmull.s8 q8, d9, d14 \n" + "vaddw.s16 q12, q12, d16 \n" + "vaddw.s16 q13, q13, d17 \n" + 
"vadd.s32 q10, q10, q12 \n" + "vadd.s32 q11, q11, q13 \n" + + "vdup.s8 d10, d1[2] \n" + "vdup.s8 d11, d1[3] \n" + "vdup.s8 d12, d1[4] \n" + "vdup.s8 d13, d1[5] \n" + "vdup.s8 d14, d1[6] \n" + "vmull.s8 q8, d4, d10 \n" + "vmull.s8 q9, d6, d11 \n" + "vmlal.s8 q8, d7, d12 \n" + "vmlal.s8 q9, d8, d13 \n" + "vaddl.s16 q12, d16, d18 \n" + "vaddl.s16 q13, d17, d19 \n" + "vmull.s8 q8, d9, d14 \n" + "vaddw.s16 q12, q12, d16 \n" + "vaddw.s16 q13, q13, d17 \n" + "vadd.s32 q14, q14, q12 \n" + "vadd.s32 q15, q15, q13 \n" + + "vld1.8 {d4-d5}, [%[r3]] \n" // r3 + "add %[r3], #8 \n" + "vext.8 d6, d4, d5, #1 \n" + "vext.8 d7, d4, d5, #2 \n" + "vext.8 d8, d4, d5, #3 \n" + "vext.8 d9, d4, d5, #4 \n" + + "vmull.s8 q8, d4, d10 \n" // next row + "vmull.s8 q9, d6, d11 \n" + "vmlal.s8 q8, d7, d12 \n" + "vmlal.s8 q9, d8, d13 \n" + "vaddl.s16 q12, d16, d18 \n" + "vaddl.s16 q13, d17, d19 \n" + "vmull.s8 q8, d9, d14 \n" + "vaddw.s16 q12, q12, d16 \n" + "vaddw.s16 q13, q13, d17 \n" + "vadd.s32 q10, q10, q12 \n" + "vadd.s32 q11, q11, q13 \n" + + "vdup.s8 d10, d1[7] \n" + "vdup.s8 d11, d2[0] \n" + "vdup.s8 d12, d2[1] \n" + "vdup.s8 d13, d2[2] \n" + "vdup.s8 d14, d2[3] \n" + "vmull.s8 q8, d4, d10 \n" + "vmull.s8 q9, d6, d11 \n" + "vmlal.s8 q8, d7, d12 \n" + "vmlal.s8 q9, d8, d13 \n" + "vaddl.s16 q12, d16, d18 \n" + "vaddl.s16 q13, d17, d19 \n" + "vmull.s8 q8, d9, d14 \n" + "vaddw.s16 q12, q12, d16 \n" + "vaddw.s16 q13, q13, d17 \n" + "vadd.s32 q14, q14, q12 \n" + "vadd.s32 q15, q15, q13 \n" + + "vld1.8 {d4-d5}, [%[r4]] \n" // r4 + "add %[r4], #8 \n" + "vext.8 d6, d4, d5, #1 \n" + "vext.8 d7, d4, d5, #2 \n" + "vext.8 d8, d4, d5, #3 \n" + "vext.8 d9, d4, d5, #4 \n" + + "vmull.s8 q8, d4, d10 \n" // next row + "vmull.s8 q9, d6, d11 \n" + "vmlal.s8 q8, d7, d12 \n" + "vmlal.s8 q9, d8, d13 \n" + "vaddl.s16 q12, d16, d18 \n" + "vaddl.s16 q13, d17, d19 \n" + "vmull.s8 q8, d9, d14 \n" + "vaddw.s16 q12, q12, d16 \n" + "vaddw.s16 q13, q13, d17 \n" + "vadd.s32 q10, q10, q12 \n" + "vadd.s32 q11, q11, q13 \n" + + "vdup.s8 d10, d2[4] \n" + "vdup.s8 d11, d2[5] \n" + "vdup.s8 d12, d2[6] \n" + "vdup.s8 d13, d2[7] \n" + "vdup.s8 d14, d3[0] \n" + "vmull.s8 q8, d4, d10 \n" + "vmull.s8 q9, d6, d11 \n" + "vmlal.s8 q8, d7, d12 \n" + "vmlal.s8 q9, d8, d13 \n" + "vaddl.s16 q12, d16, d18 \n" + "vaddl.s16 q13, d17, d19 \n" + "vmull.s8 q8, d9, d14 \n" + "vaddw.s16 q12, q12, d16 \n" + "vaddw.s16 q13, q13, d17 \n" + "vadd.s32 q14, q14, q12 \n" + "vadd.s32 q15, q15, q13 \n" + + "vld1.32 {d24-d27}, [%[output0]] \n" + "vadd.s32 q12, q12, q14 \n" + "vadd.s32 q13, q13, q15 \n" + "vst1.32 {d24-d27}, [%[output0]]! \n" + + "vld1.8 {d4-d5}, [%[r5]] \n" // row 5 + "add %[r5], #8 \n" + "vext.8 d6, d4, d5, #1 \n" + "vext.8 d7, d4, d5, #2 \n" + "vext.8 d8, d4, d5, #3 \n" + "vext.8 d9, d4, d5, #4 \n" + "vmull.s8 q8, d4, d10 \n" + "vmull.s8 q9, d6, d11 \n" + "vmlal.s8 q8, d7, d12 \n" + "vmlal.s8 q9, d8, d13 \n" + "vaddl.s16 q12, d16, d18 \n" + "vaddl.s16 q13, d17, d19 \n" + "vmull.s8 q8, d9, d14 \n" + "vaddw.s16 q12, q12, d16 \n" + "vaddw.s16 q13, q13, d17 \n" + "vadd.s32 q10, q10, q12 \n" + "vadd.s32 q11, q11, q13 \n" + + "vld1.32 {d24-d27}, [%[output1]] \n" + "vadd.s32 q12, q12, q10 \n" + "vadd.s32 q13, q13, q11 \n" + "vst1.32 {d24-d27}, [%[output1]]! 
\n" + + "subs %[ow], #1 \n" + "bne 0b \n" + : [r0] "+r"(r0), [r1] "+r"(r1), [r2] "+r"(r2), [r3] "+r"(r3), + [r4] "+r"(r4), [r5] "+r"(r5), [ow] "+r"(ow), + [output0] "+r"(output0), [output1] "+r"(output1) + : + : "cc", "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", + "q8", "q9", "q10", "q11", "q12", "q13", "q14", "q15"); + } + if (remain > 0) { + asm volatile("vld1.8 {d0-d3}, [%[kernel]] \n" + : [kernel] "+r"(kernel) + : + : "cc", "memory", "q0", "q1"); + asm volatile( + "0: \n" + "vld1.8 d4, [%[r0]] \n" + "vld1.8 d5, [%[r1]] \n" + "vld1.8 d6, [%[r2]] \n" + "vld1.8 d7, [%[r3]] \n" + "vld1.8 d8, [%[r4]] \n" + "vld1.8 d9, [%[r5]] \n" + "add %[r0], #1 \n" + "add %[r1], #1 \n" + "add %[r2], #1 \n" + "add %[r3], #1 \n" + "add %[r4], #1 \n" + "add %[r5], #1 \n" + "vext.8 d10, d0, d1, #5 \n" + "vext.8 d11, d1, d2, #2 \n" + "vext.8 d12, d1, d2, #7 \n" + "vext.8 d13, d2, d3, #4 \n" + + "vmull.s8 q7, d4, d0 \n" + "vmull.s8 q8, d5, d10 \n" + "vmull.s8 q9, d6, d11 \n" + "vmlal.s8 q8, d7, d12 \n" + "vmlal.s8 q9, d8, d13 \n" + "vaddl.s16 q10, d14, d16 \n" + "vaddw.s16 q10, q10, d18 \n" + "vadd.s32 d4, d20, d21 \n" + "vaddl.s16 q10, d15, d17 \n" + "vaddw.s16 q10, q10, d19 \n" + "vdup.s32 d14, d4[0] \n" + "vdup.s32 d15, d4[1] \n" + "vadd.s32 d15, d15, d14 \n" + "vdup.s32 d14, d20[0] \n" + "vadd.s32 d15, d15, d14 \n" + + "ldr r6, [%[output0]] \n" + "vdup.s32 d14, r6 \n" + "vadd.s32 d15, d15, d14 \n" + "vst1.32 d15[0], [%[output0]]! \n" + + "vmull.s8 q7, d5, d0 \n" + "vmull.s8 q8, d6, d10 \n" + "vmull.s8 q9, d7, d11 \n" + "vmlal.s8 q8, d8, d12 \n" + "vmlal.s8 q9, d9, d13 \n" + "vaddl.s16 q10, d14, d16 \n" + "vaddw.s16 q10, q10, d18 \n" + "vadd.s32 d4, d20, d21 \n" + "vaddl.s16 q10, d15, d17 \n" + "vaddw.s16 q10, q10, d19 \n" + "vdup.s32 d14, d4[0] \n" + "vdup.s32 d15, d4[1] \n" + "vadd.s32 d15, d15, d14 \n" + "vdup.s32 d14, d20[0] \n" + "vadd.s32 d15, d15, d14 \n" + + "ldr r6, [%[output1]] \n" + "vdup.s32 d14, r6 \n" + "vadd.s32 d15, d15, d14 \n" + "vst1.32 d15[0], [%[output1]]! 
\n" + + "subs %[remain], #1 \n" + "bne 0b \n" + : [r0] "+r"(r0), [r1] "+r"(r1), [r2] "+r"(r2), [r3] "+r"(r3), + [r4] "+r"(r4), [r5] "+r"(r5), [remain] "+r"(remain), + [output0] "+r"(output0), [output1] "+r"(output1) + : + : "cc", "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", + "q8", "q9", "q10", "r6"); + } + output0 += output_w; + output1 += output_w; + } + // remain output height + for (; oh < output_h; ++oh) { + const int8_t* r0 = in_data + ic * image_size + oh * input_w; + const int8_t* r1 = r0 + input_w; + const int8_t* r2 = r1 + input_w; + const int8_t* r3 = r2 + input_w; + const int8_t* r4 = r3 + input_w; + + int ow = output_w >> 3; + int remain = output_w & 0x7; + if (ow > 0) { + asm volatile("vld1.8 {d0-d3}, [%[kernel]] \n" + : [kernel] "+r"(kernel) + : + : "cc", "memory", "q0", "q1"); + asm volatile( + "0: \n" + "vld1.8 {d4-d5}, [%[r0]] \n" // r0 + "add %[r0], #8 \n" + "vext.8 d6, d4, d5, #1 \n" + "vext.8 d7, d4, d5, #2 \n" + "vext.8 d8, d4, d5, #3 \n" + "vext.8 d9, d4, d5, #4 \n" + "vdup.s8 d10, d0[0] \n" + "vdup.s8 d11, d0[1] \n" + "vdup.s8 d12, d0[2] \n" + "vdup.s8 d13, d0[3] \n" + "vdup.s8 d14, d0[4] \n" + "vmull.s8 q8, d4, d10 \n" + "vmull.s8 q9, d6, d11 \n" + "vmlal.s8 q8, d7, d12 \n" + "vmlal.s8 q9, d8, d13 \n" + "vaddl.s16 q14, d16, d18 \n" + "vaddl.s16 q15, d17, d19 \n" + "vmull.s8 q8, d9, d14 \n" + "vaddw.s16 q14, q14, d16 \n" + "vaddw.s16 q15, q15, d17 \n" + + "vld1.8 {d4-d5}, [%[r1]] \n" // r1 + "add %[r1], #8 \n" + "vext.8 d6, d4, d5, #1 \n" + "vext.8 d7, d4, d5, #2 \n" + "vext.8 d8, d4, d5, #3 \n" + "vext.8 d9, d4, d5, #4 \n" + "vdup.s8 d10, d0[5] \n" + "vdup.s8 d11, d0[6] \n" + "vdup.s8 d12, d0[7] \n" + "vdup.s8 d13, d1[0] \n" + "vdup.s8 d14, d1[1] \n" + "vmull.s8 q8, d4, d10 \n" + "vmull.s8 q9, d6, d11 \n" + "vmlal.s8 q8, d7, d12 \n" + "vmlal.s8 q9, d8, d13 \n" + "vaddl.s16 q12, d16, d18 \n" + "vaddl.s16 q13, d17, d19 \n" + "vmull.s8 q8, d9, d14 \n" + "vaddw.s16 q12, q12, d16 \n" + "vaddw.s16 q13, q13, d17 \n" + "vadd.s32 q14, q14, q12 \n" + "vadd.s32 q15, q15, q13 \n" + + "vld1.8 {d4-d5}, [%[r2]] \n" // r2 + "add %[r2], #8 \n" + "vext.8 d6, d4, d5, #1 \n" + "vext.8 d7, d4, d5, #2 \n" + "vext.8 d8, d4, d5, #3 \n" + "vext.8 d9, d4, d5, #4 \n" + "vdup.s8 d10, d1[2] \n" + "vdup.s8 d11, d1[3] \n" + "vdup.s8 d12, d1[4] \n" + "vdup.s8 d13, d1[5] \n" + "vdup.s8 d14, d1[6] \n" + "vmull.s8 q8, d4, d10 \n" + "vmull.s8 q9, d6, d11 \n" + "vmlal.s8 q8, d7, d12 \n" + "vmlal.s8 q9, d8, d13 \n" + "vaddl.s16 q12, d16, d18 \n" + "vaddl.s16 q13, d17, d19 \n" + "vmull.s8 q8, d9, d14 \n" + "vaddw.s16 q12, q12, d16 \n" + "vaddw.s16 q13, q13, d17 \n" + "vadd.s32 q14, q14, q12 \n" + "vadd.s32 q15, q15, q13 \n" + + "vld1.8 {d4-d5}, [%[r3]] \n" // r3 + "add %[r3], #8 \n" + "vext.8 d6, d4, d5, #1 \n" + "vext.8 d7, d4, d5, #2 \n" + "vext.8 d8, d4, d5, #3 \n" + "vext.8 d9, d4, d5, #4 \n" + "vdup.s8 d10, d1[7] \n" + "vdup.s8 d11, d2[0] \n" + "vdup.s8 d12, d2[1] \n" + "vdup.s8 d13, d2[2] \n" + "vdup.s8 d14, d2[3] \n" + "vmull.s8 q8, d4, d10 \n" + "vmull.s8 q9, d6, d11 \n" + "vmlal.s8 q8, d7, d12 \n" + "vmlal.s8 q9, d8, d13 \n" + "vaddl.s16 q12, d16, d18 \n" + "vaddl.s16 q13, d17, d19 \n" + "vmull.s8 q8, d9, d14 \n" + "vaddw.s16 q12, q12, d16 \n" + "vaddw.s16 q13, q13, d17 \n" + "vadd.s32 q14, q14, q12 \n" + "vadd.s32 q15, q15, q13 \n" + + "vld1.8 {d4-d5}, [%[r4]] \n" // r4 + "add %[r4], #8 \n" + "vext.8 d6, d4, d5, #1 \n" + "vext.8 d7, d4, d5, #2 \n" + "vext.8 d8, d4, d5, #3 \n" + "vext.8 d9, d4, d5, #4 \n" + "vdup.s8 d10, d2[4] \n" + "vdup.s8 d11, d2[5] \n" + "vdup.s8 d12, d2[6] \n" 
+ "vdup.s8 d13, d2[7] \n" + "vdup.s8 d14, d3[0] \n" + "vmull.s8 q8, d4, d10 \n" + "vmull.s8 q9, d6, d11 \n" + "vmlal.s8 q8, d7, d12 \n" + "vmlal.s8 q9, d8, d13 \n" + "vaddl.s16 q12, d16, d18 \n" + "vaddl.s16 q13, d17, d19 \n" + "vmull.s8 q8, d9, d14 \n" + "vaddw.s16 q12, q12, d16 \n" + "vaddw.s16 q13, q13, d17 \n" + "vadd.s32 q14, q14, q12 \n" + "vadd.s32 q15, q15, q13 \n" + + "vld1.32 {d24-d27}, [%[output0]] \n" + "vadd.s32 q12, q12, q14 \n" + "vadd.s32 q13, q13, q15 \n" + "vst1.32 {d24-d27}, [%[output0]]! \n" + + "subs %[ow], #1 \n" + "bne 0b \n" + : [r0] "+r"(r0), [r1] "+r"(r1), [r2] "+r"(r2), [r3] "+r"(r3), + [r4] "+r"(r4), [ow] "+r"(ow), [output0] "+r"(output0) + : + : "cc", "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", + "q8", "q9", "q10", "q11", "q12", "q13", "q14", "q15"); + } + + if (remain > 0) { + asm volatile("vld1.8 {d0-d3}, [%[kernel]] \n" + : [kernel] "+r"(kernel) + : + : "cc", "memory", "q0", "q1"); + asm volatile( + "0: \n" + "vld1.8 d4, [%[r0]] \n" + "vld1.8 d5, [%[r1]] \n" + "vld1.8 d6, [%[r2]] \n" + "vld1.8 d7, [%[r3]] \n" + "vld1.8 d8, [%[r4]] \n" + "add %[r0], #1 \n" + "add %[r1], #1 \n" + "add %[r2], #1 \n" + "add %[r3], #1 \n" + "add %[r4], #1 \n" + "vext.8 d10, d0, d1, #5 \n" + "vext.8 d11, d1, d2, #2 \n" + "vext.8 d12, d1, d2, #7 \n" + "vext.8 d13, d2, d3, #4 \n" + + "vmull.s8 q7, d4, d0 \n" + "vmull.s8 q8, d5, d10 \n" + "vmull.s8 q9, d6, d11 \n" + "vmlal.s8 q8, d7, d12 \n" + "vmlal.s8 q9, d8, d13 \n" + "vaddl.s16 q10, d14, d16 \n" + "vaddw.s16 q10, q10, d18 \n" + "vadd.s32 d4, d20, d21 \n" + "vaddl.s16 q10, d15, d17 \n" + "vaddw.s16 q10, q10, d19 \n" + "vdup.s32 d14, d4[0] \n" + "vdup.s32 d15, d4[1] \n" + "vadd.s32 d15, d15, d14 \n" + "vdup.s32 d14, d20[0] \n" + "vadd.s32 d15, d15, d14 \n" + + "ldr r6, [%[output0]] \n" + "vdup.s32 d14, r6 \n" + "vadd.s32 d15, d15, d14 \n" + "vst1.32 d15[0], [%[output0]]! \n" + + "subs %[remain], #1 \n" + "bne 0b \n" + : [r0] "+r"(r0), [r1] "+r"(r1), [r2] "+r"(r2), [r3] "+r"(r3), + [r4] "+r"(r4), [remain] "+r"(remain), [output0] "+r"(output0) + : + : "cc", "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", + "q8", "q9", "q10", "r6"); + } + } + } + } +#endif +#else +// TODO(hjchen2) +#endif +} + +} // namespace operators +} // namespace paddle_mobile + +#endif diff --git a/src/operators/math/conv_arm_int8.h b/src/operators/math/conv_arm_int8.h new file mode 100644 index 0000000000000000000000000000000000000000..98843e6158bb0f9816bf49a1cbced5a2ea731446 --- /dev/null +++ b/src/operators/math/conv_arm_int8.h @@ -0,0 +1,37 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
*/ + +#ifdef CONV_OP + +#pragma once + +#include "framework/tensor.h" + +namespace paddle_mobile { +namespace operators { + +void conv3x3s1_int8(const framework::Tensor& input, + const framework::Tensor& weight, framework::Tensor* output); + +void conv3x3s1_int8_4c(const framework::Tensor& input, + const framework::Tensor& weight, + framework::Tensor* output); + +void conv5x5s1_int8(const framework::Tensor& input, + const framework::Tensor& weight, framework::Tensor* output); + +} // namespace operators +} // namespace paddle_mobile + +#endif diff --git a/src/operators/math/gemm.cpp b/src/operators/math/gemm.cpp index 2990f7a0f8d4712a3dc3c429d9b57e5aa3809325..44621ba99a92a3ed456b8d7d0959e3580662d910 100644 --- a/src/operators/math/gemm.cpp +++ b/src/operators/math/gemm.cpp @@ -3379,7 +3379,7 @@ void Gemm::SgemmWithBn_omp(int m, int n, int k, float alpha, const float *A, // 对 B 分块 NC = L1 / (KC * sizeof(float)); if (NC == 0) { - NC == NR; + NC = NR; } else { int nblock_num = (n + NC - 1) / NC; NC = (n + nblock_num - 1) / nblock_num; diff --git a/src/operators/math/gemm.h b/src/operators/math/gemm.h index adc6924d8ad273012a9b44677f8ad1a29bc37787..ea023bc134033aee6577ebf06c95f2a762d08bca 100644 --- a/src/operators/math/gemm.h +++ b/src/operators/math/gemm.h @@ -22,9 +22,11 @@ limitations under the License. */ #define C(i, j) C[(i)*ldc + (j)] #if __aarch64__ +#define MR_INT8 4 #define MR 6 #define NR 16 #else +#define MR_INT8 4 #define MR 6 #define NR 8 #endif @@ -189,6 +191,8 @@ void PackMatrixB(int k, int n, int n_tail, const float *B, int ldb, // 8 bits function cluster begins // 8 bits int small block inner product + void AddDot4x8(int32_t k, const int8_t *a, const int8_t *b, int32_t *c, + int32_t ldc); void AddDot6x8(int32_t k, const int8_t *a, const int8_t *b, int32_t *c, int32_t ldc); @@ -199,6 +203,8 @@ void PackMatrixB(int k, int n, int n_tail, const float *B, int ldb, int8_t *bias); // 8 bits int pack function + void PackMatrixA_4r(int32_t m, int32_t k, int32_t m_tail, const int8_t *A, + int32_t lda, int8_t *buffer); void PackMatrixA_6r(int32_t m, int32_t k, int32_t m_tail, const int8_t *A, int32_t lda, int8_t *buffer); void PackMatrixB_8c(int32_t k, int32_t n, int32_t n_tail, const int8_t *B, diff --git a/src/operators/math/gemm_int8.cpp b/src/operators/math/gemm_int8.cpp index bd5286dbcb5c871d5d327875b836ad9777c270bf..5dd8a7c3131543f426f32e258efb3181be9b2f61 100644 --- a/src/operators/math/gemm_int8.cpp +++ b/src/operators/math/gemm_int8.cpp @@ -26,11 +26,228 @@ limitations under the License. */ namespace paddle_mobile { namespace operators { namespace math { +void Gemm::AddDot4x8(int32_t k, const int8_t *a, const int8_t *b, int32_t *c, + int32_t ldc) { +#if __ARM_NEON +#if __aarch64__ +// TODO +#else + const int8_t *a_ptr, *b_ptr; + a_ptr = a; + b_ptr = b; + int32_t kc1 = k >> 3; + int32_t kc2 = k & 7; + int32_t kc3 = kc2 >> 2; + int32_t kc4 = kc2 & 3; + int32_t kc5 = kc4 >> 1; + int32_t kc6 = kc4 & 1; + int32_t step = sizeof(int32_t) * ldc; + asm volatile( + // q8-q15: save 32 results + "pld [%[a_ptr]] \n\t" + "pld [%[b_ptr]] \n\t" + "pld [%[b_ptr], #64] \n\t" + "vmov.s32 q8, #0 \n\t" + "vmov.s32 q9, q8 \n\t" + "vmov.s32 q10, q8 \n\t" + "vmov.s32 q11, q8 \n\t" + "vmov.s32 q12, q8 \n\t" + "vmov.s32 q13, q8 \n\t" + "vmov.s32 q14, q8 \n\t" + "vmov.s32 q15, q8 \n\t" + "subs %[kc1], %[kc1], #1 \n\t" + "blt 1f \n\t" + "0: \n\t" + "pld [%[a_ptr], #64] \n\t" + "pld [%[b_ptr], #128] \n\t" + "vld1.s8 {d0-d3}, [%[a_ptr]]! \n\t" // load A 8 cols + "vld1.s8 {d8-d11}, [%[b_ptr]]! 
\n\t" // load B first 4 rows + "vmovl.s8 q2, d0 \n\t" // process B first 4 + // rows + "vmovl.s8 q3, d8 \n\t" + "vmlal.s16 q8, d6, d4[0]\n\t" + "vmlal.s16 q9, d7, d4[0]\n\t" + "vmlal.s16 q10, d6, d4[1]\n\t" + "vmlal.s16 q11, d7, d4[1]\n\t" + "vmlal.s16 q12, d6, d4[2]\n\t" + "vmlal.s16 q13, d7, d4[2]\n\t" + "vmlal.s16 q14, d6, d4[3]\n\t" + "vmlal.s16 q15, d7, d4[3]\n\t" + "vmovl.s8 q3, d9 \n\t" + "vmlal.s16 q8, d6, d5[0]\n\t" + "vmlal.s16 q9, d7, d5[0]\n\t" + "vmlal.s16 q10, d6, d5[1]\n\t" + "vmlal.s16 q11, d7, d5[1]\n\t" + "vmlal.s16 q12, d6, d5[2]\n\t" + "vmlal.s16 q13, d7, d5[2]\n\t" + "vmlal.s16 q14, d6, d5[3]\n\t" + "vmlal.s16 q15, d7, d5[3]\n\t" + "vld1.s8 {d12-d15}, [%[b_ptr]]! \n\t" // load B second 4 + // rows + "vmovl.s8 q2, d1 \n\t" + "vmovl.s8 q3, d10 \n\t" + "vmlal.s16 q8, d6, d4[0]\n\t" + "vmlal.s16 q9, d7, d4[0]\n\t" + "vmlal.s16 q10, d6, d4[1]\n\t" + "vmlal.s16 q11, d7, d4[1]\n\t" + "vmlal.s16 q12, d6, d4[2]\n\t" + "vmlal.s16 q13, d7, d4[2]\n\t" + "vmlal.s16 q14, d6, d4[3]\n\t" + "vmlal.s16 q15, d7, d4[3]\n\t" + "vmovl.s8 q3, d11 \n\t" + "vmlal.s16 q8, d6, d5[0]\n\t" + "vmlal.s16 q9, d7, d5[0]\n\t" + "vmlal.s16 q10, d6, d5[1]\n\t" + "vmlal.s16 q11, d7, d5[1]\n\t" + "vmlal.s16 q12, d6, d5[2]\n\t" + "vmlal.s16 q13, d7, d5[2]\n\t" + "vmlal.s16 q14, d6, d5[3]\n\t" + "vmlal.s16 q15, d7, d5[3]\n\t" + "vmovl.s8 q2, d2 \n\t" // process B second 4 + // rows + "vmovl.s8 q3, d12 \n\t" + "vmlal.s16 q8, d6, d4[0]\n\t" + "vmlal.s16 q9, d7, d4[0]\n\t" + "vmlal.s16 q10, d6, d4[1]\n\t" + "vmlal.s16 q11, d7, d4[1]\n\t" + "vmlal.s16 q12, d6, d4[2]\n\t" + "vmlal.s16 q13, d7, d4[2]\n\t" + "vmlal.s16 q14, d6, d4[3]\n\t" + "vmlal.s16 q15, d7, d4[3]\n\t" + "vmovl.s8 q3, d13 \n\t" + "vmlal.s16 q8, d6, d5[0]\n\t" + "vmlal.s16 q9, d7, d5[0]\n\t" + "vmlal.s16 q10, d6, d5[1]\n\t" + "vmlal.s16 q11, d7, d5[1]\n\t" + "vmlal.s16 q12, d6, d5[2]\n\t" + "vmlal.s16 q13, d7, d5[2]\n\t" + "vmlal.s16 q14, d6, d5[3]\n\t" + "vmlal.s16 q15, d7, d5[3]\n\t" + "vmovl.s8 q2, d3 \n\t" + "vmovl.s8 q3, d14 \n\t" + "vmlal.s16 q8, d6, d4[0]\n\t" + "vmlal.s16 q9, d7, d4[0]\n\t" + "vmlal.s16 q10, d6, d4[1]\n\t" + "vmlal.s16 q11, d7, d4[1]\n\t" + "vmlal.s16 q12, d6, d4[2]\n\t" + "vmlal.s16 q13, d7, d4[2]\n\t" + "vmlal.s16 q14, d6, d4[3]\n\t" + "vmlal.s16 q15, d7, d4[3]\n\t" + "vmovl.s8 q3, d15 \n\t" + "vmlal.s16 q8, d6, d5[0]\n\t" + "vmlal.s16 q9, d7, d5[0]\n\t" + "vmlal.s16 q10, d6, d5[1]\n\t" + "vmlal.s16 q11, d7, d5[1]\n\t" + "vmlal.s16 q12, d6, d5[2]\n\t" + "vmlal.s16 q13, d7, d5[2]\n\t" + "vmlal.s16 q14, d6, d5[3]\n\t" + "vmlal.s16 q15, d7, d5[3]\n\t" + + "subs %[kc1], %[kc1], #1 \n\t" + "bge 0b \n\t" + "1: \n\t" // last 4 rows + "subs %[kc3], %[kc3], #1 \n\t" + "blt 2f \n\t" + "vld1.s8 {d0-d1}, [%[a_ptr]]! \n\t" // load A 4 cols + "vld1.s8 {d8-d11}, [%[b_ptr]]! 
\n\t" // load B 4 rows + "vmovl.s8 q2, d0 \n\t" + "vmovl.s8 q3, d8 \n\t" + "vmlal.s16 q8, d6, d4[0]\n\t" + "vmlal.s16 q9, d7, d4[0]\n\t" + "vmlal.s16 q10, d6, d4[1]\n\t" + "vmlal.s16 q11, d7, d4[1]\n\t" + "vmlal.s16 q12, d6, d4[2]\n\t" + "vmlal.s16 q13, d7, d4[2]\n\t" + "vmlal.s16 q14, d6, d4[3]\n\t" + "vmlal.s16 q15, d7, d4[3]\n\t" + "vmovl.s8 q3, d9 \n\t" + "vmlal.s16 q8, d6, d5[0]\n\t" + "vmlal.s16 q9, d7, d5[0]\n\t" + "vmlal.s16 q10, d6, d5[1]\n\t" + "vmlal.s16 q11, d7, d5[1]\n\t" + "vmlal.s16 q12, d6, d5[2]\n\t" + "vmlal.s16 q13, d7, d5[2]\n\t" + "vmlal.s16 q14, d6, d5[3]\n\t" + "vmlal.s16 q15, d7, d5[3]\n\t" + "vmovl.s8 q2, d1 \n\t" + "vmovl.s8 q3, d10 \n\t" + "vmlal.s16 q8, d6, d4[0]\n\t" + "vmlal.s16 q9, d7, d4[0]\n\t" + "vmlal.s16 q10, d6, d4[1]\n\t" + "vmlal.s16 q11, d7, d4[1]\n\t" + "vmlal.s16 q12, d6, d4[2]\n\t" + "vmlal.s16 q13, d7, d4[2]\n\t" + "vmlal.s16 q14, d6, d4[3]\n\t" + "vmlal.s16 q15, d7, d4[3]\n\t" + "vmovl.s8 q3, d11 \n\t" + "vmlal.s16 q8, d6, d5[0]\n\t" + "vmlal.s16 q9, d7, d5[0]\n\t" + "vmlal.s16 q10, d6, d5[1]\n\t" + "vmlal.s16 q11, d7, d5[1]\n\t" + "vmlal.s16 q12, d6, d5[2]\n\t" + "vmlal.s16 q13, d7, d5[2]\n\t" + "vmlal.s16 q14, d6, d5[3]\n\t" + "vmlal.s16 q15, d7, d5[3]\n\t" + "2: \n\t" // last 2 rows + "subs %[kc5], %[kc5], #1 \n\t" + "blt 3f \n\t" + "vld1.s8 {d0}, [%[a_ptr]]! \n\t" // load A 2 cols + "vld1.s8 {d8-d9}, [%[b_ptr]]! \n\t" // load B 2 rows + "vmovl.s8 q2, d0 \n\t" + "vmovl.s8 q3, d8 \n\t" + "vmlal.s16 q8, d6, d4[0]\n\t" + "vmlal.s16 q9, d7, d4[0]\n\t" + "vmlal.s16 q10, d6, d4[1]\n\t" + "vmlal.s16 q11, d7, d4[1]\n\t" + "vmlal.s16 q12, d6, d4[2]\n\t" + "vmlal.s16 q13, d7, d4[2]\n\t" + "vmlal.s16 q14, d6, d4[3]\n\t" + "vmlal.s16 q15, d7, d4[3]\n\t" + "vmovl.s8 q3, d9 \n\t" + "vmlal.s16 q8, d6, d5[0]\n\t" + "vmlal.s16 q9, d7, d5[0]\n\t" + "vmlal.s16 q10, d6, d5[1]\n\t" + "vmlal.s16 q11, d7, d5[1]\n\t" + "vmlal.s16 q12, d6, d5[2]\n\t" + "vmlal.s16 q13, d7, d5[2]\n\t" + "vmlal.s16 q14, d6, d5[3]\n\t" + "vmlal.s16 q15, d7, d5[3]\n\t" + "3: \n\t" // last 1 row + "subs %[kc6], %[kc6], #1 \n\t" + "blt 4f \n\t" + "vld1.s8 {d0}, [%[a_ptr]] \n\t" // load A 1 col + "vld1.s8 {d8}, [%[b_ptr]] \n\t" // load B 1 row + "vmovl.s8 q2, d0 \n\t" + "vmovl.s8 q3, d8 \n\t" + "vmlal.s16 q8, d6, d4[0]\n\t" + "vmlal.s16 q9, d7, d4[0]\n\t" + "vmlal.s16 q10, d6, d4[1]\n\t" + "vmlal.s16 q11, d7, d4[1]\n\t" + "vmlal.s16 q12, d6, d4[2]\n\t" + "vmlal.s16 q13, d7, d4[2]\n\t" + "vmlal.s16 q14, d6, d4[3]\n\t" + "vmlal.s16 q15, d7, d4[3]\n\t" + "4: \n\t" + "vst1.32 {q8, q9}, [%[c]], %[step] \n\t" + "vst1.32 {q10, q11}, [%[c]], %[step] \n\t" + "vst1.32 {q12, q13}, [%[c]], %[step] \n\t" + "vst1.32 {q14, q15}, [%[c]] \n\t" + : + : [a_ptr] "r"(a_ptr), [b_ptr] "r"(b_ptr), [c] "r"(c), [kc1] "r"(kc1), + [kc3] "r"(kc3), [kc5] "r"(kc5), [kc6] "r"(kc6), [step] "r"(step) + : "cc", "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", "q8", + "q9", "q10", "q11", "q12", "q13", "q14", "q15"); +#endif // __aarch64__ +#endif // __ARM_NEON +} // 8 bits int small block inner product void Gemm::AddDot6x8(int32_t k, const int8_t *a, const int8_t *b, int32_t *c, int32_t ldc) { #if __ARM_NEON +#if __aarch64__ +// TODO +#else const int8_t *a_ptr, *b_ptr; a_ptr = a; b_ptr = b; @@ -46,383 +263,265 @@ void Gemm::AddDot6x8(int32_t k, const int8_t *a, const int8_t *b, int32_t *c, "pld [%[a_ptr]] \n\t" "pld [%[b_ptr]] \n\t" "pld [%[b_ptr], #64] \n\t" - "vmov.s8 q4, #0 \n\t" - "vmov.s8 q5, #0 \n\t" - "vmov.s8 q6, #0 \n\t" - "vmov.s8 q7, #0 \n\t" - "vmov.s8 q8, #0 \n\t" - "vmov.s8 q9, #0 \n\t" - "vmov.s8 q10, 
#0 \n\t" - "vmov.s8 q11, #0 \n\t" - "vmov.s8 q12, #0 \n\t" - "vmov.s8 q13, #0 \n\t" - "vmov.s8 q14, #0 \n\t" - "vmov.s8 q15, #0 \n\t" + "vmov.s32 q4, #0 \n\t" + "vmov.s32 q5, q4 \n\t" + "vmov.s32 q6, q4 \n\t" + "vmov.s32 q7, q4 \n\t" + "vmov.s32 q8, q4 \n\t" + "vmov.s32 q9, q4 \n\t" + "vmov.s32 q10, q4 \n\t" + "vmov.s32 q11, q4 \n\t" + "vmov.s32 q12, q4 \n\t" + "vmov.s32 q13, q4 \n\t" + "vmov.s32 q14, q4 \n\t" + "vmov.s32 q15, q4 \n\t" "mov r0, #12 \n\t" - "subs %[kc1], %[kc1], #1 \n\t" + "subs %[kc1], %[kc1], #1 \n\t" "blt 1f \n\t" "0: \n\t" "pld [%[a_ptr], #64] \n\t" "pld [%[b_ptr], #128] \n\t" - "vld1.s8 {d0-d2}, [%[a_ptr]]! \n\t" // A 4 cols, q0 used, - // 1/2 q3 used - "vmov.s8 q2, #0 \n\t" // q2 used - "vld1.s8 {d6-d7}, [%[b_ptr]]! \n\t" // B 2 rows, B row1, - // q1 - "vdup.s8 d3, d0[0] \n\t" // q3 used // used - "vmlal.s8 q2, d6, d3 \n\t" // A col00 * B row0 - "vdup.s8 d3, d0[6] \n\t" // q3 used - "vmlal.s8 q2, d7, d3 \n\t" // A col10 * B row1, - // q3 free - "vaddw.s16 q4, q4, d4 \n\t" - "vaddw.s16 q5, q5, d5 \n\t" // res row 0 - "vmov.s8 q2, #0 \n\t" - "vdup.s8 d3, d0[1] \n\t" - "vmlal.s8 q2, d6, d3 \n\t" - "vdup.s8 d3, d0[7] \n\t" - "vmlal.s8 q2, d7, d3 \n\t" - "vaddw.s16 q6, q6, d4 \n\t" - "vaddw.s16 q7, q7, d5 \n\t" // res row 1 - "vmov.s8 q2, #0 \n\t" - "vdup.s8 d3, d0[2] \n\t" - "vmlal.s8 q2, d6, d3 \n\t" - "vdup.s8 d3, d1[0] \n\t" - "vmlal.s8 q2, d7, d3 \n\t" - "vaddw.s16 q8, q8, d4 \n\t" - "vaddw.s16 q9, q9, d5 \n\t" // res row 2 - "vmov.s8 q2, #0 \n\t" - "vdup.s8 d3, d0[3] \n\t" - "vmlal.s8 q2, d6, d3 \n\t" - "vdup.s8 d3, d1[1] \n\t" - "vmlal.s8 q2, d7, d3 \n\t" - "vaddw.s16 q10, q10, d4 \n\t" - "vaddw.s16 q11, q11, d5 \n\t" // res row 3 - "vmov.s8 q2, #0 \n\t" - "vdup.s8 d3, d0[4] \n\t" - "vmlal.s8 q2, d6, d3 \n\t" - "vdup.s8 d3, d1[2] \n\t" - "vmlal.s8 q2, d7, d3 \n\t" - "vaddw.s16 q12, q12, d4 \n\t" - "vaddw.s16 q13, q13, d5 \n\t" // res row 4 - "vmov.s8 q2, #0 \n\t" - "vdup.s8 d3, d0[5] \n\t" - "vmlal.s8 q2, d6, d3 \n\t" - "vdup.s8 d3, d1[3] \n\t" - "vmlal.s8 q2, d7, d3 \n\t" - "vaddw.s16 q14, q14, d4 \n\t" - "vaddw.s16 q15, q15, d5 \n\t" // res row 5 - - "vld1.s8 {d6-d7}, [%[b_ptr]]! 
\n\t" // B 2 rows, B row1, - // q1 - "vmov.s8 q2, #0 \n\t" // q2 used - "vdup.s8 d3, d1[4] \n\t" // q3 used // used - "vmlal.s8 q2, d6, d3 \n\t" // A col00 * B row0 - "vdup.s8 d3, d2[2] \n\t" // q3 used - "vmlal.s8 q2, d7, d3 \n\t" // A col10 * B row1, - // q3 free - "vaddw.s16 q4, q4, d4 \n\t" - "vaddw.s16 q5, q5, d5 \n\t" // res row 0 - "vmov.s8 q2, #0 \n\t" - "vdup.s8 d3, d1[5] \n\t" - "vmlal.s8 q2, d6, d3 \n\t" - "vdup.s8 d3, d2[3] \n\t" - "vmlal.s8 q2, d7, d3 \n\t" - "vaddw.s16 q6, q6, d4 \n\t" - "vaddw.s16 q7, q7, d5 \n\t" // res row 1 - "vmov.s8 q2, #0 \n\t" - "vdup.s8 d3, d1[6] \n\t" - "vmlal.s8 q2, d6, d3 \n\t" - "vdup.s8 d3, d2[4] \n\t" - "vmlal.s8 q2, d7, d3 \n\t" - "vaddw.s16 q8, q8, d4 \n\t" - "vaddw.s16 q9, q9, d5 \n\t" // res row 2 - "vmov.s8 q2, #0 \n\t" - "vdup.s8 d3, d1[7] \n\t" - "vmlal.s8 q2, d6, d3 \n\t" - "vdup.s8 d3, d2[5] \n\t" - "vmlal.s8 q2, d7, d3 \n\t" - "vaddw.s16 q10, q10, d4 \n\t" - "vaddw.s16 q11, q11, d5 \n\t" // res row 3 - "vmov.s8 q2, #0 \n\t" - "vdup.s8 d3, d2[0] \n\t" - "vmlal.s8 q2, d6, d3 \n\t" - "vdup.s8 d3, d2[6] \n\t" - "vmlal.s8 q2, d7, d3 \n\t" - "vaddw.s16 q12, q12, d4 \n\t" - "vaddw.s16 q13, q13, d5 \n\t" // res row 4 - "vmov.s8 q2, #0 \n\t" - "vdup.s8 d3, d2[1] \n\t" - "vmlal.s8 q2, d6, d3 \n\t" - "vdup.s8 d3, d2[7] \n\t" - "vmlal.s8 q2, d7, d3 \n\t" - "vaddw.s16 q14, q14, d4 \n\t" - "vaddw.s16 q15, q15, d5 \n\t" // res row 5 - - "vld1.s8 {d0-d2}, [%[a_ptr]]! \n\t" // A 4 cols, q0 used, - // 1/2 q3 used - "vmov.s8 q2, #0 \n\t" // q2 used - "vld1.s8 {d6-d7}, [%[b_ptr]]! \n\t" // B 2 rows, B row1, - // q1 - "vdup.s8 d3, d0[0] \n\t" // q3 used // used - "vmlal.s8 q2, d6, d3 \n\t" // A col00 * B row0 - "vdup.s8 d3, d0[6] \n\t" // q3 used - "vmlal.s8 q2, d7, d3 \n\t" // A col10 * B row1, - // q3 free - "vaddw.s16 q4, q4, d4 \n\t" - "vaddw.s16 q5, q5, d5 \n\t" // res row 0 - "vmov.s8 q2, #0 \n\t" - "vdup.s8 d3, d0[1] \n\t" - "vmlal.s8 q2, d6, d3 \n\t" - "vdup.s8 d3, d0[7] \n\t" - "vmlal.s8 q2, d7, d3 \n\t" - "vaddw.s16 q6, q6, d4 \n\t" - "vaddw.s16 q7, q7, d5 \n\t" // res row 1 - "vmov.s8 q2, #0 \n\t" - "vdup.s8 d3, d0[2] \n\t" - "vmlal.s8 q2, d6, d3 \n\t" - "vdup.s8 d3, d1[0] \n\t" - "vmlal.s8 q2, d7, d3 \n\t" - "vaddw.s16 q8, q8, d4 \n\t" - "vaddw.s16 q9, q9, d5 \n\t" // res row 2 - "vmov.s8 q2, #0 \n\t" - "vdup.s8 d3, d0[3] \n\t" - "vmlal.s8 q2, d6, d3 \n\t" - "vdup.s8 d3, d1[1] \n\t" - "vmlal.s8 q2, d7, d3 \n\t" - "vaddw.s16 q10, q10, d4 \n\t" - "vaddw.s16 q11, q11, d5 \n\t" // res row 3 - "vmov.s8 q2, #0 \n\t" - "vdup.s8 d3, d0[4] \n\t" - "vmlal.s8 q2, d6, d3 \n\t" - "vdup.s8 d3, d1[2] \n\t" - "vmlal.s8 q2, d7, d3 \n\t" - "vaddw.s16 q12, q12, d4 \n\t" - "vaddw.s16 q13, q13, d5 \n\t" // res row 4 - "vmov.s8 q2, #0 \n\t" - "vdup.s8 d3, d0[5] \n\t" - "vmlal.s8 q2, d6, d3 \n\t" - "vdup.s8 d3, d1[3] \n\t" - "vmlal.s8 q2, d7, d3 \n\t" - "vaddw.s16 q14, q14, d4 \n\t" - "vaddw.s16 q15, q15, d5 \n\t" // res row 5 - - "vld1.s8 {d6-d7}, [%[b_ptr]]! 
\n\t" // B 2 rows, B row1, - // q1 - "vmov.s8 q2, #0 \n\t" // q2 used - "vdup.s8 d3, d1[4] \n\t" // q3 used // used - "vmlal.s8 q2, d6, d3 \n\t" // A col00 * B row0 - "vdup.s8 d3, d2[2] \n\t" // q3 used - "vmlal.s8 q2, d7, d3 \n\t" // A col10 * B row1, - // q3 free - "vaddw.s16 q4, q4, d4 \n\t" - "vaddw.s16 q5, q5, d5 \n\t" // res row 0 - "vmov.s8 q2, #0 \n\t" - "vdup.s8 d3, d1[5] \n\t" - "vmlal.s8 q2, d6, d3 \n\t" - "vdup.s8 d3, d2[3] \n\t" - "vmlal.s8 q2, d7, d3 \n\t" - "vaddw.s16 q6, q6, d4 \n\t" - "vaddw.s16 q7, q7, d5 \n\t" // res row 1 - "vmov.s8 q2, #0 \n\t" - "vdup.s8 d3, d1[6] \n\t" - "vmlal.s8 q2, d6, d3 \n\t" - "vdup.s8 d3, d2[4] \n\t" - "vmlal.s8 q2, d7, d3 \n\t" - "vaddw.s16 q8, q8, d4 \n\t" - "vaddw.s16 q9, q9, d5 \n\t" // res row 2 - "vmov.s8 q2, #0 \n\t" - "vdup.s8 d3, d1[7] \n\t" - "vmlal.s8 q2, d6, d3 \n\t" - "vdup.s8 d3, d2[5] \n\t" - "vmlal.s8 q2, d7, d3 \n\t" - "vaddw.s16 q10, q10, d4 \n\t" - "vaddw.s16 q11, q11, d5 \n\t" // res row 3 - "vmov.s8 q2, #0 \n\t" - "vdup.s8 d3, d2[0] \n\t" - "vmlal.s8 q2, d6, d3 \n\t" - "vdup.s8 d3, d2[6] \n\t" - "vmlal.s8 q2, d7, d3 \n\t" - "vaddw.s16 q12, q12, d4 \n\t" - "vaddw.s16 q13, q13, d5 \n\t" // res row 4 - "vmov.s8 q2, #0 \n\t" - "vdup.s8 d3, d2[1] \n\t" - "vmlal.s8 q2, d6, d3 \n\t" - "vdup.s8 d3, d2[7] \n\t" - "vmlal.s8 q2, d7, d3 \n\t" - "vaddw.s16 q14, q14, d4 \n\t" - "vaddw.s16 q15, q15, d5 \n\t" // res row 5 + "vld1.s8 {d0-d2}, [%[a_ptr]]! \n\t" // A 4 cols + "vld1.s8 {d3}, [%[b_ptr]]! \n\t" // B 1st row + "vmovl.s8 q2, d0 \n\t" + "vmovl.s8 q3, d3 \n\t" + "vmlal.s16 q4, d6, d4[0]\n\t" + "vmlal.s16 q5, d7, d4[0]\n\t" + "vmlal.s16 q6, d6, d4[1]\n\t" + "vmlal.s16 q7, d7, d4[1]\n\t" + "vmlal.s16 q8, d6, d4[2]\n\t" + "vmlal.s16 q9, d7, d4[2]\n\t" + "vmlal.s16 q10, d6, d4[3]\n\t" + "vmlal.s16 q11, d7, d4[3]\n\t" + "vmlal.s16 q12, d6, d5[0]\n\t" + "vmlal.s16 q13, d7, d5[0]\n\t" + "vmlal.s16 q14, d6, d5[1]\n\t" + "vmlal.s16 q15, d7, d5[1]\n\t" + "vld1.s8 {d3}, [%[b_ptr]]! \n\t" // B 2nd row + "vmovl.s8 q3, d3 \n\t" + "vmlal.s16 q4, d6, d5[2]\n\t" + "vmlal.s16 q5, d7, d5[2]\n\t" + "vmlal.s16 q6, d6, d5[3]\n\t" + "vmlal.s16 q7, d7, d5[3]\n\t" + "vmovl.s8 q2, d1 \n\t" + "vmlal.s16 q8, d6, d4[0]\n\t" + "vmlal.s16 q9, d7, d4[0]\n\t" + "vmlal.s16 q10, d6, d4[1]\n\t" + "vmlal.s16 q11, d7, d4[1]\n\t" + "vmlal.s16 q12, d6, d4[2]\n\t" + "vmlal.s16 q13, d7, d4[2]\n\t" + "vmlal.s16 q14, d6, d4[3]\n\t" + "vmlal.s16 q15, d7, d4[3]\n\t" + "vld1.s8 {d3}, [%[b_ptr]]! \n\t" // B 3th row + "vmovl.s8 q3, d3 \n\t" + "vmlal.s16 q4, d6, d5[0]\n\t" + "vmlal.s16 q5, d7, d5[0]\n\t" + "vmlal.s16 q6, d6, d5[1]\n\t" + "vmlal.s16 q7, d7, d5[1]\n\t" + "vmlal.s16 q8, d6, d5[2]\n\t" + "vmlal.s16 q9, d7, d5[2]\n\t" + "vmlal.s16 q10, d6, d5[3]\n\t" + "vmlal.s16 q11, d7, d5[3]\n\t" + "vmovl.s8 q2, d2 \n\t" + "vmlal.s16 q12, d6, d4[0]\n\t" + "vmlal.s16 q13, d7, d4[0]\n\t" + "vmlal.s16 q14, d6, d4[1]\n\t" + "vmlal.s16 q15, d7, d4[1]\n\t" + "vld1.s8 {d3}, [%[b_ptr]]! \n\t" // B 4th row + "vmovl.s8 q3, d3 \n\t" + "vmlal.s16 q4, d6, d4[2]\n\t" + "vmlal.s16 q5, d7, d4[2]\n\t" + "vmlal.s16 q6, d6, d4[3]\n\t" + "vmlal.s16 q7, d7, d4[3]\n\t" + "vmlal.s16 q8, d6, d5[0]\n\t" + "vmlal.s16 q9, d7, d5[0]\n\t" + "vmlal.s16 q10, d6, d5[1]\n\t" + "vmlal.s16 q11, d7, d5[1]\n\t" + "vmlal.s16 q12, d6, d5[2]\n\t" + "vmlal.s16 q13, d7, d5[2]\n\t" + "vmlal.s16 q14, d6, d5[3]\n\t" + "vmlal.s16 q15, d7, d5[3]\n\t" + + "vld1.s8 {d0-d2}, [%[a_ptr]]! \n\t" // A 4 cols + "vld1.s8 {d3}, [%[b_ptr]]! 
\n\t" // B 1st row + "vmovl.s8 q2, d0 \n\t" + "vmovl.s8 q3, d3 \n\t" + "vmlal.s16 q4, d6, d4[0]\n\t" + "vmlal.s16 q5, d7, d4[0]\n\t" + "vmlal.s16 q6, d6, d4[1]\n\t" + "vmlal.s16 q7, d7, d4[1]\n\t" + "vmlal.s16 q8, d6, d4[2]\n\t" + "vmlal.s16 q9, d7, d4[2]\n\t" + "vmlal.s16 q10, d6, d4[3]\n\t" + "vmlal.s16 q11, d7, d4[3]\n\t" + "vmlal.s16 q12, d6, d5[0]\n\t" + "vmlal.s16 q13, d7, d5[0]\n\t" + "vmlal.s16 q14, d6, d5[1]\n\t" + "vmlal.s16 q15, d7, d5[1]\n\t" + "vld1.s8 {d3}, [%[b_ptr]]! \n\t" // B 2nd row + "vmovl.s8 q3, d3 \n\t" + "vmlal.s16 q4, d6, d5[2]\n\t" + "vmlal.s16 q5, d7, d5[2]\n\t" + "vmlal.s16 q6, d6, d5[3]\n\t" + "vmlal.s16 q7, d7, d5[3]\n\t" + "vmovl.s8 q2, d1 \n\t" + "vmlal.s16 q8, d6, d4[0]\n\t" + "vmlal.s16 q9, d7, d4[0]\n\t" + "vmlal.s16 q10, d6, d4[1]\n\t" + "vmlal.s16 q11, d7, d4[1]\n\t" + "vmlal.s16 q12, d6, d4[2]\n\t" + "vmlal.s16 q13, d7, d4[2]\n\t" + "vmlal.s16 q14, d6, d4[3]\n\t" + "vmlal.s16 q15, d7, d4[3]\n\t" + "vld1.s8 {d3}, [%[b_ptr]]! \n\t" // B 3th row + "vmovl.s8 q3, d3 \n\t" + "vmlal.s16 q4, d6, d5[0]\n\t" + "vmlal.s16 q5, d7, d5[0]\n\t" + "vmlal.s16 q6, d6, d5[1]\n\t" + "vmlal.s16 q7, d7, d5[1]\n\t" + "vmlal.s16 q8, d6, d5[2]\n\t" + "vmlal.s16 q9, d7, d5[2]\n\t" + "vmlal.s16 q10, d6, d5[3]\n\t" + "vmlal.s16 q11, d7, d5[3]\n\t" + "vmovl.s8 q2, d2 \n\t" + "vmlal.s16 q12, d6, d4[0]\n\t" + "vmlal.s16 q13, d7, d4[0]\n\t" + "vmlal.s16 q14, d6, d4[1]\n\t" + "vmlal.s16 q15, d7, d4[1]\n\t" + "vld1.s8 {d3}, [%[b_ptr]]! \n\t" // B 4th row + "vmovl.s8 q3, d3 \n\t" + "vmlal.s16 q4, d6, d4[2]\n\t" + "vmlal.s16 q5, d7, d4[2]\n\t" + "vmlal.s16 q6, d6, d4[3]\n\t" + "vmlal.s16 q7, d7, d4[3]\n\t" + "vmlal.s16 q8, d6, d5[0]\n\t" + "vmlal.s16 q9, d7, d5[0]\n\t" + "vmlal.s16 q10, d6, d5[1]\n\t" + "vmlal.s16 q11, d7, d5[1]\n\t" + "vmlal.s16 q12, d6, d5[2]\n\t" + "vmlal.s16 q13, d7, d5[2]\n\t" + "vmlal.s16 q14, d6, d5[3]\n\t" + "vmlal.s16 q15, d7, d5[3]\n\t" "subs %[kc1], %[kc1], #1 \n\t" "bge 0b \n\t" "1: \n\t" // last <8 rows "subs %[kc3], %[kc3], #1 \n\t" "blt 2f \n\t" - "vld1.s8 {d0-d2}, [%[a_ptr]]! \n\t" - "vmov.s8 q2, #0 \n\t" - "vld1.s8 {d6-d7}, [%[b_ptr]]! \n\t" - "vdup.s8 d3, d0[0] \n\t" - "vmlal.s8 q2, d6, d3 \n\t" - "vdup.s8 d3, d0[6] \n\t" - "vmlal.s8 q2, d7, d3 \n\t" - "vaddw.s16 q4, q4, d4 \n\t" - "vaddw.s16 q5, q5, d5 \n\t" // res row 0 - "vmov.s8 q2, #0 \n\t" - "vdup.s8 d3, d0[1] \n\t" - "vmlal.s8 q2, d6, d3 \n\t" - "vdup.s8 d3, d0[7] \n\t" - "vmlal.s8 q2, d7, d3 \n\t" - "vaddw.s16 q6, q6, d4 \n\t" - "vaddw.s16 q7, q7, d5 \n\t" // res row 1 - "vmov.s8 q2, #0 \n\t" - "vdup.s8 d3, d0[2] \n\t" - "vmlal.s8 q2, d6, d3 \n\t" - "vdup.s8 d3, d1[0] \n\t" - "vmlal.s8 q2, d7, d3 \n\t" - "vaddw.s16 q8, q8, d4 \n\t" - "vaddw.s16 q9, q9, d5 \n\t" // res row 2 - "vmov.s8 q2, #0 \n\t" - "vdup.s8 d3, d0[3] \n\t" - "vmlal.s8 q2, d6, d3 \n\t" - "vdup.s8 d3, d1[1] \n\t" - "vmlal.s8 q2, d7, d3 \n\t" - "vaddw.s16 q10, q10, d4 \n\t" - "vaddw.s16 q11, q11, d5 \n\t" // res row 3 - "vmov.s8 q2, #0 \n\t" - "vdup.s8 d3, d0[4] \n\t" - "vmlal.s8 q2, d6, d3 \n\t" - "vdup.s8 d3, d1[2] \n\t" - "vmlal.s8 q2, d7, d3 \n\t" - "vaddw.s16 q12, q12, d4 \n\t" - "vaddw.s16 q13, q13, d5 \n\t" // res row 4 - "vmov.s8 q2, #0 \n\t" - "vdup.s8 d3, d0[5] \n\t" - "vmlal.s8 q2, d6, d3 \n\t" - "vdup.s8 d3, d1[3] \n\t" - "vmlal.s8 q2, d7, d3 \n\t" - "vaddw.s16 q14, q14, d4 \n\t" - "vaddw.s16 q15, q15, d5 \n\t" // res row 5 - - "vld1.s8 {d6-d7}, [%[b_ptr]]! 
\n\t" - "vmov.s8 q2, #0 \n\t" - "vdup.s8 d3, d1[4] \n\t" - "vmlal.s8 q2, d6, d3 \n\t" - "vdup.s8 d3, d2[2] \n\t" - "vmlal.s8 q2, d7, d3 \n\t" - "vaddw.s16 q4, q4, d4 \n\t" - "vaddw.s16 q5, q5, d5 \n\t" // res row 0 - "vmov.s8 q2, #0 \n\t" - "vdup.s8 d3, d1[5] \n\t" - "vmlal.s8 q2, d6, d3 \n\t" - "vdup.s8 d3, d2[3] \n\t" - "vmlal.s8 q2, d7, d3 \n\t" - "vaddw.s16 q6, q6, d4 \n\t" - "vaddw.s16 q7, q7, d5 \n\t" // res row 1 - "vmov.s8 q2, #0 \n\t" - "vdup.s8 d3, d1[6] \n\t" - "vmlal.s8 q2, d6, d3 \n\t" - "vdup.s8 d3, d2[4] \n\t" - "vmlal.s8 q2, d7, d3 \n\t" - "vaddw.s16 q8, q8, d4 \n\t" - "vaddw.s16 q9, q9, d5 \n\t" // res row 2 - "vmov.s8 q2, #0 \n\t" - "vdup.s8 d3, d1[7] \n\t" - "vmlal.s8 q2, d6, d3 \n\t" - "vdup.s8 d3, d2[5] \n\t" - "vmlal.s8 q2, d7, d3 \n\t" - "vaddw.s16 q10, q10, d4 \n\t" - "vaddw.s16 q11, q11, d5 \n\t" // res row 3 - "vmov.s8 q2, #0 \n\t" - "vdup.s8 d3, d2[0] \n\t" - "vmlal.s8 q2, d6, d3 \n\t" - "vdup.s8 d3, d2[6] \n\t" - "vmlal.s8 q2, d7, d3 \n\t" - "vaddw.s16 q12, q12, d4 \n\t" - "vaddw.s16 q13, q13, d5 \n\t" // res row 4 - "vmov.s8 q2, #0 \n\t" - "vdup.s8 d3, d2[1] \n\t" - "vmlal.s8 q2, d6, d3 \n\t" - "vdup.s8 d3, d2[7] \n\t" - "vmlal.s8 q2, d7, d3 \n\t" - "vaddw.s16 q14, q14, d4 \n\t" - "vaddw.s16 q15, q15, d5 \n\t" // res row 5 + "vld1.s8 {d0-d2}, [%[a_ptr]]! \n\t" // A 4 cols + "vld1.s8 {d3}, [%[b_ptr]]! \n\t" // B 1st row + "vmovl.s8 q2, d0 \n\t" + "vmovl.s8 q3, d3 \n\t" + "vmlal.s16 q4, d6, d4[0]\n\t" + "vmlal.s16 q5, d7, d4[0]\n\t" + "vmlal.s16 q6, d6, d4[1]\n\t" + "vmlal.s16 q7, d7, d4[1]\n\t" + "vmlal.s16 q8, d6, d4[2]\n\t" + "vmlal.s16 q9, d7, d4[2]\n\t" + "vmlal.s16 q10, d6, d4[3]\n\t" + "vmlal.s16 q11, d7, d4[3]\n\t" + "vmlal.s16 q12, d6, d5[0]\n\t" + "vmlal.s16 q13, d7, d5[0]\n\t" + "vmlal.s16 q14, d6, d5[1]\n\t" + "vmlal.s16 q15, d7, d5[1]\n\t" + "vld1.s8 {d3}, [%[b_ptr]]! \n\t" // B 2nd row + "vmovl.s8 q3, d3 \n\t" + "vmlal.s16 q4, d6, d5[2]\n\t" + "vmlal.s16 q5, d7, d5[2]\n\t" + "vmlal.s16 q6, d6, d5[3]\n\t" + "vmlal.s16 q7, d7, d5[3]\n\t" + "vmovl.s8 q2, d1 \n\t" + "vmlal.s16 q8, d6, d4[0]\n\t" + "vmlal.s16 q9, d7, d4[0]\n\t" + "vmlal.s16 q10, d6, d4[1]\n\t" + "vmlal.s16 q11, d7, d4[1]\n\t" + "vmlal.s16 q12, d6, d4[2]\n\t" + "vmlal.s16 q13, d7, d4[2]\n\t" + "vmlal.s16 q14, d6, d4[3]\n\t" + "vmlal.s16 q15, d7, d4[3]\n\t" + "vld1.s8 {d3}, [%[b_ptr]]! \n\t" // B 3th row + "vmovl.s8 q3, d3 \n\t" + "vmlal.s16 q4, d6, d5[0]\n\t" + "vmlal.s16 q5, d7, d5[0]\n\t" + "vmlal.s16 q6, d6, d5[1]\n\t" + "vmlal.s16 q7, d7, d5[1]\n\t" + "vmlal.s16 q8, d6, d5[2]\n\t" + "vmlal.s16 q9, d7, d5[2]\n\t" + "vmlal.s16 q10, d6, d5[3]\n\t" + "vmlal.s16 q11, d7, d5[3]\n\t" + "vmovl.s8 q2, d2 \n\t" + "vmlal.s16 q12, d6, d4[0]\n\t" + "vmlal.s16 q13, d7, d4[0]\n\t" + "vmlal.s16 q14, d6, d4[1]\n\t" + "vmlal.s16 q15, d7, d4[1]\n\t" + "vld1.s8 {d3}, [%[b_ptr]]! \n\t" // B 4th row + "vmovl.s8 q3, d3 \n\t" + "vmlal.s16 q4, d6, d4[2]\n\t" + "vmlal.s16 q5, d7, d4[2]\n\t" + "vmlal.s16 q6, d6, d4[3]\n\t" + "vmlal.s16 q7, d7, d4[3]\n\t" + "vmlal.s16 q8, d6, d5[0]\n\t" + "vmlal.s16 q9, d7, d5[0]\n\t" + "vmlal.s16 q10, d6, d5[1]\n\t" + "vmlal.s16 q11, d7, d5[1]\n\t" + "vmlal.s16 q12, d6, d5[2]\n\t" + "vmlal.s16 q13, d7, d5[2]\n\t" + "vmlal.s16 q14, d6, d5[3]\n\t" + "vmlal.s16 q15, d7, d5[3]\n\t" "2: \n\t" // last <4 rows "subs %[kc5], %[kc5], #1 \n\t" "blt 3f \n\t" "vld1.s8 {d0, d1}, [%[a_ptr]], r0 \n\t" - "vmov.s8 q2, #0 \n\t" - "vdup.s8 d6, d0[0] \n\t" - "vld1.s8 {d2-d3}, [%[b_ptr]]! 
\n\t" - "vdup.s8 d7, d0[6] \n\t" - "vmlal.s8 q2, d2, d6 \n\t" - "vmlal.s8 q2, d3, d7 \n\t" - "vaddw.s16 q4, q4, d4 \n\t" - "vaddw.s16 q5, q5, d5 \n\t" // res row 0 - "vmov.s8 q2, #0 \n\t" - "vdup.s8 d6, d0[1] \n\t" - "vdup.s8 d7, d0[7] \n\t" - "vmlal.s8 q2, d2, d6 \n\t" - "vmlal.s8 q2, d3, d7 \n\t" - "vaddw.s16 q6, q6, d4 \n\t" - "vaddw.s16 q7, q7, d5 \n\t" // res row 1 - "vmov.s8 q2, #0 \n\t" - "vdup.s8 d6, d0[2] \n\t" - "vdup.s8 d7, d1[0] \n\t" - "vmlal.s8 q2, d2, d6 \n\t" - "vmlal.s8 q2, d3, d7 \n\t" - "vaddw.s16 q8, q8, d4 \n\t" - "vaddw.s16 q9, q9, d5 \n\t" // res row 2 - "vmov.s8 q2, #0 \n\t" - "vdup.s8 d6, d0[3] \n\t" - "vdup.s8 d7, d1[1] \n\t" - "vmlal.s8 q2, d2, d6 \n\t" - "vmlal.s8 q2, d3, d7 \n\t" - "vaddw.s16 q10, q10, d4 \n\t" - "vaddw.s16 q11, q11, d5 \n\t" // res row 3 - "vmov.s8 q2, #0. \n\t" - "vdup.s8 d6, d0[4] \n\t" - "vdup.s8 d7, d1[2] \n\t" - "vmlal.s8 q2, d2, d6 \n\t" - "vmlal.s8 q2, d3, d7 \n\t" - "vaddw.s16 q12, q12, d4 \n\t" - "vaddw.s16 q13, q13, d5 \n\t" // res row 4 - "vmov.s8 q2, #0 \n\t" - "vdup.s8 d6, d0[5] \n\t" - "vdup.s8 d7, d1[3] \n\t" - "vmlal.s8 q2, d2, d6 \n\t" - "vmlal.s8 q2, d3, d7 \n\t" - "vaddw.s16 q14, q14, d4 \n\t" - "vaddw.s16 q15, q15, d5 \n\t" // res row 5 - + "vld1.s8 {d3}, [%[b_ptr]]! \n\t" // B 1st row + "vmovl.s8 q2, d0 \n\t" + "vmovl.s8 q3, d3 \n\t" + "vmlal.s16 q4, d6, d4[0]\n\t" + "vmlal.s16 q5, d7, d4[0]\n\t" + "vmlal.s16 q6, d6, d4[1]\n\t" + "vmlal.s16 q7, d7, d4[1]\n\t" + "vmlal.s16 q8, d6, d4[2]\n\t" + "vmlal.s16 q9, d7, d4[2]\n\t" + "vmlal.s16 q10, d6, d4[3]\n\t" + "vmlal.s16 q11, d7, d4[3]\n\t" + "vmlal.s16 q12, d6, d5[0]\n\t" + "vmlal.s16 q13, d7, d5[0]\n\t" + "vmlal.s16 q14, d6, d5[1]\n\t" + "vmlal.s16 q15, d7, d5[1]\n\t" + "vld1.s8 {d3}, [%[b_ptr]]! \n\t" // B 2nd row + "vmovl.s8 q3, d3 \n\t" + "vmlal.s16 q4, d6, d5[2]\n\t" + "vmlal.s16 q5, d7, d5[2]\n\t" + "vmlal.s16 q6, d6, d5[3]\n\t" + "vmlal.s16 q7, d7, d5[3]\n\t" + "vmovl.s8 q2, d1 \n\t" + "vmlal.s16 q8, d6, d4[0]\n\t" + "vmlal.s16 q9, d7, d4[0]\n\t" + "vmlal.s16 q10, d6, d4[1]\n\t" + "vmlal.s16 q11, d7, d4[1]\n\t" + "vmlal.s16 q12, d6, d4[2]\n\t" + "vmlal.s16 q13, d7, d4[2]\n\t" + "vmlal.s16 q14, d6, d4[3]\n\t" + "vmlal.s16 q15, d7, d4[3]\n\t" "3: \n\t" // last <2 rows "subs %[kc6], %[kc6], #1 \n\t" "blt 4f \n\t" "vld1.s8 {d0}, [%[a_ptr]] \n\t" - "vld1.s8 {d1}, [%[b_ptr]] \n\t" - "vdup.s8 d2, d0[0] \n\t" - "vmull.s8 q2, d1, d2 \n\t" - "vaddw.s16 q4, q4, d4 \n\t" - "vaddw.s16 q5, q5, d5 \n\t" // res row 0 - "vdup.s8 d2, d0[1] \n\t" - "vmull.s8 q2, d1, d2 \n\t" - "vaddw.s16 q6, q6, d4 \n\t" - "vaddw.s16 q7, q7, d5 \n\t" // res row 1 - "vdup.s8 d2, d0[2] \n\t" - "vmull.s8 q2, d1, d2 \n\t" - "vaddw.s16 q8, q8, d4 \n\t" - "vaddw.s16 q9, q9, d5 \n\t" // res row 2 - "vdup.s8 d2, d0[3] \n\t" - "vmull.s8 q2, d1, d2 \n\t" - "vaddw.s16 q10, q10, d4 \n\t" - "vaddw.s16 q11, q11, d5 \n\t" // res row 3 - "vdup.s8 d2, d0[4] \n\t" - "vmull.s8 q2, d1, d2 \n\t" - "vaddw.s16 q12, q12, d4 \n\t" - "vaddw.s16 q13, q13, d5 \n\t" // res row 4 - "vdup.s8 d2, d0[5] \n\t" - "vmull.s8 q2, d1, d2 \n\t" - "vaddw.s16 q14, q14, d4 \n\t" - "vaddw.s16 q15, q15, d5 \n\t" // res row 4 + "vld1.s8 {d3}, [%[b_ptr]] \n\t" + "vmovl.s8 q2, d0 \n\t" + "vmovl.s8 q3, d3 \n\t" + "vmlal.s16 q4, d6, d4[0]\n\t" + "vmlal.s16 q5, d7, d4[0]\n\t" + "vmlal.s16 q6, d6, d4[1]\n\t" + "vmlal.s16 q7, d7, d4[1]\n\t" + "vmlal.s16 q8, d6, d4[2]\n\t" + "vmlal.s16 q9, d7, d4[2]\n\t" + "vmlal.s16 q10, d6, d4[3]\n\t" + "vmlal.s16 q11, d7, d4[3]\n\t" + "vmlal.s16 q12, d6, d5[0]\n\t" + "vmlal.s16 q13, d7, d5[0]\n\t" + "vmlal.s16 q14, d6, 
d5[1]\n\t" + "vmlal.s16 q15, d7, d5[1]\n\t" "4: \n\t" "vst1.32 {q4, q5}, [%[c]], %[step] \n\t" "vst1.32 {q6, q7}, [%[c]], %[step] \n\t" @@ -435,7 +534,8 @@ void Gemm::AddDot6x8(int32_t k, const int8_t *a, const int8_t *b, int32_t *c, [kc3] "r"(kc3), [kc5] "r"(kc5), [kc6] "r"(kc6), [step] "r"(step) : "cc", "memory", "r0", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", "q8", "q9", "q10", "q11", "q12", "q13", "q14", "q15"); -#endif +#endif // __aarch64__ +#endif // __ARM_NEON } // 8 bits int inner product @@ -445,8 +545,9 @@ void Gemm::InnerKernelWithBias(int32_t mc, int32_t nc, int8_t alpha, int8_t *bias) { #pragma omp parallel for for (int32_t j = 0; j < nc; j += NR) { - for (int32_t i = 0; i < mc; i += MR) { - AddDot6x8(KC, a + i * KC, b + j * KC, c + i * NC + j, NC); + for (int32_t i = 0; i < mc; i += MR_INT8) { + // AddDot6x8(KC, a + i * KC, b + j * KC, c + i * NC + j, NC); + AddDot4x8(KC, a + i * KC, b + j * KC, c + i * NC + j, NC); } } if (alpha != 1) { @@ -474,12 +575,53 @@ void Gemm::InnerKernelWithBias(int32_t mc, int32_t nc, int8_t alpha, return; } } +// 8 bits int PackMatrixA_4r +void Gemm::PackMatrixA_4r(int32_t m, int32_t k, int32_t m_tail, const int8_t *A, + int32_t lda, int8_t *buffer) { + const int8_t *a0, *a1, *a2, *a3; + for (int32_t i = 0; i < m - m_tail; i += MR_INT8) { + a0 = A + i * lda; + a1 = A + (i + 1) * lda; + a2 = A + (i + 2) * lda; + a3 = A + (i + 3) * lda; + for (int32_t j = 0; j < k; ++j) { + *buffer++ = *a0++; + *buffer++ = *a1++; + *buffer++ = *a2++; + *buffer++ = *a3++; + } + } + + if (m_tail != 0) { + a0 = &A(m - m_tail, 0); + a1 = a0 + lda; + a2 = a0 + 2 * lda; + a3 = a0 + 3 * lda; + switch (m_tail) { + case 1: + a1 = zero_int8; + case 2: + a2 = zero_int8; + case 3: + a3 = zero_int8; + break; + default: + break; + } + for (int j = 0; j < k; ++j) { + *buffer++ = *a0++; + *buffer++ = *a1++; + *buffer++ = *a2++; + *buffer++ = *a3++; + } + } +} -// 8 bits int PackMatrixA +// 8 bits int PackMatrixA_6r void Gemm::PackMatrixA_6r(int32_t m, int32_t k, int32_t m_tail, const int8_t *A, int32_t lda, int8_t *buffer) { const int32_t i_length = m - m_tail; - for (int32_t i = 0; i < i_length; i += MR) { + for (int32_t i = 0; i < i_length; i += MR_INT8) { const int8_t *a0 = A + i * lda; const int8_t *a1 = A + (i + 1) * lda; const int8_t *a2 = A + (i + 2) * lda; @@ -539,6 +681,9 @@ void Gemm::PackMatrixB_8c(int32_t k, int32_t n, int32_t n_tail, const int8_t *B, for (int32_t i = 0; i < k; ++i) { const int8_t *b0 = &B(i, j); #if __ARM_NEON +#if __aarch64__ + // TODO +#else asm volatile( // "pld [%[b0]] \n\t" "vld1.s8 {d0}, [%[b0]] \n\t" @@ -546,6 +691,7 @@ void Gemm::PackMatrixB_8c(int32_t k, int32_t n, int32_t n_tail, const int8_t *B, : [local_buffer] "+r"(local_buffer) : [b0] "r"(b0) : "memory", "q0"); +#endif // __aarch64__ #else *local_buffer++ = *b0++; *local_buffer++ = *b0++; @@ -585,13 +731,13 @@ void Gemm::Sgemm(int32_t m, int32_t n, int32_t k, int8_t alpha, const int8_t *A, MC = L1 / (KC * sizeof(int8_t)); NC = L2 / (KC * sizeof(int8_t)); - // make sure MC is multiple of MR, and NC is multiple of NR + // make sure MC is multiple of MR_INT8, and NC is multiple of NR if (MC == 0) { - MC = MR; + MC = MR_INT8; } else { int32_t mblock_num = (m + MC - 1) / MC; MC = (m + mblock_num - 1) / mblock_num; - MC = (MC + MR - 1) / MR * MR; + MC = (MC + MR_INT8 - 1) / MR_INT8 * MR_INT8; } // DLOG << "mblock_num = " << mblock_num << ", MC = " << MC << "\n"; if (NC == 0) { @@ -618,7 +764,8 @@ void Gemm::Sgemm(int32_t m, int32_t n, int32_t k, int8_t alpha, const int8_t *A, 
PackMatrixB_8c(KC, nc, nc % NR, &B(0, j), ldb, packedB_int8); for (int32_t i = 0; i < m; i += MC) { mc = s_min(m - i, MC); - PackMatrixA_6r(mc, KC, mc % MR, &A(i, 0), lda, packedA_int8); + // PackMatrixA_6r(mc, KC, mc % MR_INT8, &A(i, 0), lda, packedA_int8); + PackMatrixA_4r(mc, KC, mc % MR_INT8, &A(i, 0), lda, packedA_int8); if (bias == nullptr) { InnerKernelWithBias(mc, nc, alpha, packedA_int8, packedB_int8, beta, packedC_int8, &C(i, j), ldc, relu, nullptr); @@ -642,6 +789,10 @@ void Gemm::WriteWithAlphaBeta(int32_t mc, int32_t nc, int32_t *c, int32_t *C, // C = A * B, 8位 int32_t void Gemm::WriteBasic(int32_t mc, int32_t nc, int32_t *c, int32_t *C, int32_t ldc) { +#if __ARM_NEON +#if __aarch64__ +// TODO +#else int32_t nc1 = nc >> 4; int32_t _nc1 = nc & 15; int32_t step = sizeof(int32_t) * ldc; @@ -695,6 +846,8 @@ void Gemm::WriteBasic(int32_t mc, int32_t nc, int32_t *c, int32_t *C, } } } +#endif // __aarch64__ +#endif // __ARM_NEON } // C = A * B + C diff --git a/src/operators/math/im2col.cpp b/src/operators/math/im2col.cpp index 090ccdf24e214fc86b8a4032df228d50caa65ef9..4c81e7fa3bd4e5ea36f04b453d4f84468745f919 100644 --- a/src/operators/math/im2col.cpp +++ b/src/operators/math/im2col.cpp @@ -28,91 +28,240 @@ namespace math { * [input_channels, filter_height, filter_width, output_height, * output_width] */ -template -class Im2ColFunctor { - public: - void operator()(const framework::Tensor &im, const std::vector &dilation, - const std::vector &stride, - const std::vector &padding, framework::Tensor *col) { - // PADDLE_ENFORCE(im.dims().size() == 3); - // PADDLE_ENFORCE(col->dims().size() == 5); +template <> +void Im2ColFunctor::operator()( + const framework::Tensor &im, const std::vector &dilation, + const std::vector &stride, const std::vector &padding, + framework::Tensor *col) { + int im_channels = im.dims()[0]; + int im_height = im.dims()[1]; + int im_width = im.dims()[2]; + int filter_height = col->dims()[1]; + int filter_width = col->dims()[2]; + int col_height = col->dims()[3]; + int col_width = col->dims()[4]; + + int channels_col = im_channels * filter_height * filter_width; + const float *im_data = im.data(); + float *col_data = col->data(); +#if __ARM_NEON + const int osize = col_height; + const int isize = im_height; + bool pad1 = padding[0] > 0; + bool pad2 = + (pad1 && padding[1] && + (((isize - 2 * padding[0] + filter_height) % stride[0] == 0) ? 
1 : 0)); + int fill = isize % 2; + if (stride[0] == 1 && filter_height == 3 && pad1 && pad2 && + dilation[0] == 1 && im_height > 2) { + for (int c = 0; c < im_channels; ++c) { + int oosize = osize * osize; + int nk4 = osize / 4; + int mk4 = osize % 4; + + float *col0 = col_data + 0 * oosize + 2 * osize + 2; + float *col1 = col_data + 1 * oosize + 2 * osize + 1; + float *col2 = col_data + 2 * oosize + 2 * osize; + + float *col3 = col_data + 3 * oosize + osize + 2; + float *col4 = col_data + 4 * oosize + osize + 1; + float *col5 = col_data + 5 * oosize + osize; + + float *col6 = col_data + 6 * oosize + 2; + float *col7 = col_data + 7 * oosize + 1; + float *col8 = col_data + 8 * oosize; + + float32x4_t im1; + const float *im_tmp_data = im_data + osize + 1; + + int rrsize = oosize - osize - 1; + int nr4 = rrsize / 4; + int mr4 = rrsize % 4; + for (int i = 0; i < nr4; ++i) { + im1 = vld1q_f32(im_tmp_data); + vst1q_f32(col0, im1); + vst1q_f32(col1, im1); + vst1q_f32(col2, im1); + vst1q_f32(col3, im1); + vst1q_f32(col4, im1); + vst1q_f32(col5, im1); + vst1q_f32(col6, im1); + vst1q_f32(col7, im1); + vst1q_f32(col8, im1); + + col0 += 4; + col1 += 4; + col2 += 4; + col3 += 4; + col4 += 4; + col5 += 4; + col6 += 4; + col7 += 4; + col8 += 4; + + im_tmp_data += 4; + } + for (int i = 0; i < mr4; ++i) { + *col0 = *im_tmp_data; + *col1 = *im_tmp_data; + *col2 = *im_tmp_data; + *col3 = *im_tmp_data; + *col4 = *im_tmp_data; + *col5 = *im_tmp_data; + *col6 = *im_tmp_data; + *col7 = *im_tmp_data; + *col8 = *im_tmp_data; + + col0++; + col1++; + col2++; + col3++; + col4++; + col5++; + col6++; + col7++; + col8++; + + im_tmp_data++; + } - int im_channels = im.dims()[0]; - int im_height = im.dims()[1]; - int im_width = im.dims()[2]; - int filter_height = col->dims()[1]; - int filter_width = col->dims()[2]; - int col_height = col->dims()[3]; - int col_width = col->dims()[4]; - - // PADDLE_ENFORCE_EQ((im_height + padding[0] + padding[2] - // - - // ((dilation[0] * (filter_height - 1) - // + 1))) / - // stride[0] + - // 1, - // col_height, - // "Output_height and - // padding(padding_up, padding_down) - // are " "inconsistent."); - // PADDLE_ENFORCE_EQ((im_width + padding[1] + padding[3] - // - - // ((dilation[1] * (filter_width - 1) - // + 1))) / - // stride[1] + - // 1, - // col_width, - // "Output_height and - // padding(padding_up, padding_down) - // are " "inconsistent."); + im_tmp_data = im_data + 1; + col0 = col_data + 0 * oosize + osize + 2; + col1 = col_data + 1 * oosize + osize + 1; + col2 = col_data + 2 * oosize + osize; + + col3 = col_data + 3 * oosize + 2; + col4 = col_data + 4 * oosize + 1; + col5 = col_data + 5 * oosize; + + for (int i = 0; i < nk4; i++) { + im1 = vld1q_f32(im_tmp_data); + vst1q_f32(col0, im1); + vst1q_f32(col1, im1); + vst1q_f32(col2, im1); + vst1q_f32(col3, im1); + vst1q_f32(col4, im1); + vst1q_f32(col5, im1); + + col0 += 4; + col1 += 4; + col2 += 4; + col3 += 4; + col4 += 4; + col5 += 4; + im_tmp_data += 4; + } - int channels_col = im_channels * filter_height * filter_width; - const T *im_data = im.data(); - T *col_data = col->data(); -#if __ARM_NEON - const int osize = col_height; - const int isize = im_height; - bool pad1 = padding[0] > 0; - bool pad2 = - (pad1 && padding[1] && - (((isize - 2 * padding[0] + filter_height) % stride[0] == 0) ? 
1 : 0)); - int fill = isize % 2; - if (stride[0] == 1 && filter_height == 3 && pad1 && pad2 && - dilation[0] == 1 && im_height > 2) { - for (int c = 0; c < im_channels; ++c) { - int oosize = osize * osize; - int nk4 = osize / 4; - int mk4 = osize % 4; - - float *col0 = col_data + 0 * oosize + 2 * osize + 2; - float *col1 = col_data + 1 * oosize + 2 * osize + 1; - float *col2 = col_data + 2 * oosize + 2 * osize; - - float *col3 = col_data + 3 * oosize + osize + 2; - float *col4 = col_data + 4 * oosize + osize + 1; - float *col5 = col_data + 5 * oosize + osize; - - float *col6 = col_data + 6 * oosize + 2; - float *col7 = col_data + 7 * oosize + 1; - float *col8 = col_data + 8 * oosize; - - float32x4_t im1; - const float *im_tmp_data = im_data + osize + 1; - - int rrsize = oosize - osize - 1; - int nr4 = rrsize / 4; - int mr4 = rrsize % 4; - for (int i = 0; i < nr4; ++i) { - im1 = vld1q_f32(im_tmp_data); - vst1q_f32(col0, im1); - vst1q_f32(col1, im1); - vst1q_f32(col2, im1); - vst1q_f32(col3, im1); - vst1q_f32(col4, im1); - vst1q_f32(col5, im1); - vst1q_f32(col6, im1); - vst1q_f32(col7, im1); - vst1q_f32(col8, im1); + for (int i = 0; i < mk4; i++) { + *col0 = *im_tmp_data; + *col1 = *im_tmp_data; + *col2 = *im_tmp_data; + *col3 = *im_tmp_data; + *col4 = *im_tmp_data; + *col5 = *im_tmp_data; + col0++; + col1++; + col2++; + col3++; + col4++; + col5++; + + im_tmp_data++; + } + + // fill 0 1 11; + for (int i = 0; i < osize; ++i) { + col_data[0 * oosize + i * osize] = 0.0; + col_data[3 * oosize + i * osize] = 0.0; + col_data[6 * oosize + i * osize] = 0.0; + + col_data[2 * oosize + osize - 1 + i * osize] = 0.0; + col_data[5 * oosize + osize - 1 + i * osize] = 0.0; + col_data[8 * oosize + osize - 1 + i * osize] = 0.0; + } + + col_data[0 * oosize + osize + 1] = im_data[0]; + col_data[3 * oosize + 1] = im_data[0]; + col_data[6 * oosize + 1] = im_data[osize]; + + col_data[1 * oosize + osize] = im_data[0]; + col_data[4 * oosize] = im_data[0]; + col_data[7 * oosize] = im_data[osize]; + + float32x4_t zero4; + zero4 = vdupq_n_f32(0.0); + auto col_z0 = col_data; + auto col_z1 = col_data + oosize; + auto col_z2 = col_data + 2 * oosize; + auto col_z6 = col_data + 6 * oosize + osize * (osize - 1); + auto col_z7 = col_data + 7 * oosize + osize * (osize - 1); + auto col_z8 = col_data + 8 * oosize + osize * (osize - 1); + + for (int i = 0; i < nk4; ++i) { + vst1q_f32(col_z0, zero4); + vst1q_f32(col_z1, zero4); + vst1q_f32(col_z2, zero4); + vst1q_f32(col_z6, zero4); + vst1q_f32(col_z7, zero4); + vst1q_f32(col_z8, zero4); + + col_z0 += 4; + col_z1 += 4; + col_z2 += 4; + col_z6 += 4; + col_z7 += 4; + col_z8 += 4; + } + + for (int i = 0; i < mk4; ++i) { + col_z0[i] = 0.0; + col_z1[i] = 0.0; + col_z2[i] = 0.0; + col_z6[i] = 0.0; + col_z7[i] = 0.0; + col_z8[i] = 0.0; + } + col_data += 9 * oosize; + im_data += isize * isize; + } + } else if (stride[0] == 2 && filter_height == 3 && pad1 && dilation[0] == 1 && + im_height > 2) { + for (int c = 0; c < im_channels; ++c) { + int oosize = osize * osize; + int nk4 = osize / 4; + int mk4 = osize % 4; + + // 3 2 3 1 0 1 3 2 3 + float *col0 = col_data + 0 * oosize + osize + 1; + float *col1 = col_data + 1 * oosize + osize; + float *col2 = col_data + 2 * oosize + osize; + + float *col3 = col_data + 3 * oosize + 1; + float *col4 = col_data + 4 * oosize; + float *col5 = col_data + 5 * oosize; + + float *col6 = col_data + 6 * oosize + 1; + float *col7 = col_data + 7 * oosize; + float *col8 = col_data + 8 * oosize; + + float32x4x2_t im01; + float32x4x2_t im23; + const float 
*im_tmp_data0 = im_data; + const float *im_tmp_data2 = im_data + isize; + + for (int j = 0; j < osize; ++j) { + for (int i = 0; i < nk4; ++i) { + im01 = vld2q_f32(im_tmp_data0); + im23 = vld2q_f32(im_tmp_data2); + vst1q_f32(col0, im23.val[1]); + vst1q_f32(col1, im23.val[0]); + vst1q_f32(col2, im23.val[1]); + vst1q_f32(col3, im01.val[1]); + vst1q_f32(col4, im01.val[0]); + vst1q_f32(col5, im01.val[1]); + vst1q_f32(col6, im23.val[1]); + vst1q_f32(col7, im23.val[0]); + vst1q_f32(col8, im23.val[1]); col0 += 4; col1 += 4; @@ -124,18 +273,21 @@ class Im2ColFunctor { col7 += 4; col8 += 4; - im_tmp_data += 4; + im_tmp_data0 += 8; + im_tmp_data2 += 8; } - for (int i = 0; i < mr4; ++i) { - *col0 = *im_tmp_data; - *col1 = *im_tmp_data; - *col2 = *im_tmp_data; - *col3 = *im_tmp_data; - *col4 = *im_tmp_data; - *col5 = *im_tmp_data; - *col6 = *im_tmp_data; - *col7 = *im_tmp_data; - *col8 = *im_tmp_data; + const float *im_tmp_data1 = im_tmp_data0 + 1; + const float *im_tmp_data3 = im_tmp_data2 + 1; + for (int i = 0; i < mk4; ++i) { + *col0 = *im_tmp_data3; + *col1 = *im_tmp_data2; + *col2 = *im_tmp_data3; + *col3 = *im_tmp_data1; + *col4 = *im_tmp_data0; + *col5 = *im_tmp_data1; + *col6 = *im_tmp_data3; + *col7 = *im_tmp_data2; + *col8 = *im_tmp_data3; col0++; col1++; @@ -146,271 +298,215 @@ class Im2ColFunctor { col6++; col7++; col8++; - - im_tmp_data++; + im_tmp_data0 += 2; + im_tmp_data1 += 2; + im_tmp_data2 += 2; + im_tmp_data3 += 2; } - - im_tmp_data = im_data + 1; - col0 = col_data + 0 * oosize + osize + 2; - col1 = col_data + 1 * oosize + osize + 1; - col2 = col_data + 2 * oosize + osize; - - col3 = col_data + 3 * oosize + 2; - col4 = col_data + 4 * oosize + 1; - col5 = col_data + 5 * oosize; - - for (int i = 0; i < nk4; i++) { - im1 = vld1q_f32(im_tmp_data); - vst1q_f32(col0, im1); - vst1q_f32(col1, im1); - vst1q_f32(col2, im1); - vst1q_f32(col3, im1); - vst1q_f32(col4, im1); - vst1q_f32(col5, im1); - - col0 += 4; - col1 += 4; - col2 += 4; - col3 += 4; - col4 += 4; - col5 += 4; - im_tmp_data += 4; - } - - for (int i = 0; i < mk4; i++) { - *col0 = *im_tmp_data; - *col1 = *im_tmp_data; - *col2 = *im_tmp_data; - *col3 = *im_tmp_data; - *col4 = *im_tmp_data; - *col5 = *im_tmp_data; - col0++; - col1++; - col2++; - col3++; - col4++; - col5++; - - im_tmp_data++; - } - - // fill 0 1 11; - for (int i = 0; i < osize; ++i) { - col_data[0 * oosize + i * osize] = 0.0; - col_data[3 * oosize + i * osize] = 0.0; - col_data[6 * oosize + i * osize] = 0.0; - + im_tmp_data0 += (isize - fill); + im_tmp_data2 += (isize - fill); + } + for (int i = 0; i < osize; ++i) { + col_data[0 * oosize + i * osize] = 0.0; + col_data[3 * oosize + i * osize] = 0.0; + col_data[6 * oosize + i * osize] = 0.0; + if (pad2) { col_data[2 * oosize + osize - 1 + i * osize] = 0.0; col_data[5 * oosize + osize - 1 + i * osize] = 0.0; col_data[8 * oosize + osize - 1 + i * osize] = 0.0; } - - col_data[0 * oosize + osize + 1] = im_data[0]; - col_data[3 * oosize + 1] = im_data[0]; - col_data[6 * oosize + 1] = im_data[osize]; - - col_data[1 * oosize + osize] = im_data[0]; - col_data[4 * oosize] = im_data[0]; - col_data[7 * oosize] = im_data[osize]; - - float32x4_t zero4; - zero4 = vdupq_n_f32(0.0); - auto col_z0 = col_data; - auto col_z1 = col_data + oosize; - auto col_z2 = col_data + 2 * oosize; - auto col_z6 = col_data + 6 * oosize + osize * (osize - 1); - auto col_z7 = col_data + 7 * oosize + osize * (osize - 1); - auto col_z8 = col_data + 8 * oosize + osize * (osize - 1); - - for (int i = 0; i < nk4; ++i) { - vst1q_f32(col_z0, zero4); - 
vst1q_f32(col_z1, zero4); - vst1q_f32(col_z2, zero4); + } + float32x4_t zero4; + zero4 = vdupq_n_f32(0.0); + auto col_z0 = col_data; + auto col_z1 = col_data + oosize; + auto col_z2 = col_data + 2 * oosize; + auto col_z6 = col_data + 6 * oosize + osize * (osize - 1); + auto col_z7 = col_data + 7 * oosize + osize * (osize - 1); + auto col_z8 = col_data + 8 * oosize + osize * (osize - 1); + + for (int i = 0; i < nk4; ++i) { + vst1q_f32(col_z0, zero4); + vst1q_f32(col_z1, zero4); + vst1q_f32(col_z2, zero4); + if (pad2) { vst1q_f32(col_z6, zero4); vst1q_f32(col_z7, zero4); vst1q_f32(col_z8, zero4); - - col_z0 += 4; - col_z1 += 4; - col_z2 += 4; - col_z6 += 4; - col_z7 += 4; - col_z8 += 4; } + col_z0 += 4; + col_z1 += 4; + col_z2 += 4; + col_z6 += 4; + col_z7 += 4; + col_z8 += 4; + } - for (int i = 0; i < mk4; ++i) { - col_z0[i] = 0.0; - col_z1[i] = 0.0; - col_z2[i] = 0.0; + for (int i = 0; i < mk4; ++i) { + col_z0[i] = 0.0; + col_z1[i] = 0.0; + col_z2[i] = 0.0; + if (pad2) { col_z6[i] = 0.0; col_z7[i] = 0.0; col_z8[i] = 0.0; } - col_data += 9 * oosize; - im_data += isize * isize; } - } else if (stride[0] == 2 && filter_height == 3 && pad1 && - dilation[0] == 1 && im_height > 2) { - for (int c = 0; c < im_channels; ++c) { - int oosize = osize * osize; - int nk4 = osize / 4; - int mk4 = osize % 4; - - // 3 2 3 1 0 1 3 2 3 - float *col0 = col_data + 0 * oosize + osize + 1; - float *col1 = col_data + 1 * oosize + osize; - float *col2 = col_data + 2 * oosize + osize; - - float *col3 = col_data + 3 * oosize + 1; - float *col4 = col_data + 4 * oosize; - float *col5 = col_data + 5 * oosize; - - float *col6 = col_data + 6 * oosize + 1; - float *col7 = col_data + 7 * oosize; - float *col8 = col_data + 8 * oosize; - - float32x4x2_t im01; - float32x4x2_t im23; - const float *im_tmp_data0 = im_data; - const float *im_tmp_data2 = im_data + isize; - - for (int j = 0; j < osize; ++j) { - for (int i = 0; i < nk4; ++i) { - im01 = vld2q_f32(im_tmp_data0); - im23 = vld2q_f32(im_tmp_data2); - vst1q_f32(col0, im23.val[1]); - vst1q_f32(col1, im23.val[0]); - vst1q_f32(col2, im23.val[1]); - vst1q_f32(col3, im01.val[1]); - vst1q_f32(col4, im01.val[0]); - vst1q_f32(col5, im01.val[1]); - vst1q_f32(col6, im23.val[1]); - vst1q_f32(col7, im23.val[0]); - vst1q_f32(col8, im23.val[1]); - - col0 += 4; - col1 += 4; - col2 += 4; - col3 += 4; - col4 += 4; - col5 += 4; - col6 += 4; - col7 += 4; - col8 += 4; - - im_tmp_data0 += 8; - im_tmp_data2 += 8; - } - const float *im_tmp_data1 = im_tmp_data0 + 1; - const float *im_tmp_data3 = im_tmp_data2 + 1; - for (int i = 0; i < mk4; ++i) { - *col0 = *im_tmp_data3; - *col1 = *im_tmp_data2; - *col2 = *im_tmp_data3; - *col3 = *im_tmp_data1; - *col4 = *im_tmp_data0; - *col5 = *im_tmp_data1; - *col6 = *im_tmp_data3; - *col7 = *im_tmp_data2; - *col8 = *im_tmp_data3; - - col0++; - col1++; - col2++; - col3++; - col4++; - col5++; - col6++; - col7++; - col8++; - im_tmp_data0 += 2; - im_tmp_data1 += 2; - im_tmp_data2 += 2; - im_tmp_data3 += 2; - } - im_tmp_data0 += (isize - fill); - im_tmp_data2 += (isize - fill); - } - for (int i = 0; i < osize; ++i) { - col_data[0 * oosize + i * osize] = 0.0; - col_data[3 * oosize + i * osize] = 0.0; - col_data[6 * oosize + i * osize] = 0.0; - if (pad2) { - col_data[2 * oosize + osize - 1 + i * osize] = 0.0; - col_data[5 * oosize + osize - 1 + i * osize] = 0.0; - col_data[8 * oosize + osize - 1 + i * osize] = 0.0; - } - } - float32x4_t zero4; - zero4 = vdupq_n_f32(0.0); - auto col_z0 = col_data; - auto col_z1 = col_data + oosize; - auto col_z2 = col_data + 2 * 
oosize; - auto col_z6 = col_data + 6 * oosize + osize * (osize - 1); - auto col_z7 = col_data + 7 * oosize + osize * (osize - 1); - auto col_z8 = col_data + 8 * oosize + osize * (osize - 1); - for (int i = 0; i < nk4; ++i) { - vst1q_f32(col_z0, zero4); - vst1q_f32(col_z1, zero4); - vst1q_f32(col_z2, zero4); - if (pad2) { - vst1q_f32(col_z6, zero4); - vst1q_f32(col_z7, zero4); - vst1q_f32(col_z8, zero4); - } - col_z0 += 4; - col_z1 += 4; - col_z2 += 4; - col_z6 += 4; - col_z7 += 4; - col_z8 += 4; - } + col_data[1 * oosize + osize] = im_data[isize]; + for (int i = 1; i < osize; ++i) { + col_data[3 * oosize + i] = im_data[(i - 1) * stride[0] + 1]; + } + col_data[4 * oosize] = im_data[0]; + col_data[7 * oosize] = im_data[isize]; - for (int i = 0; i < mk4; ++i) { - col_z0[i] = 0.0; - col_z1[i] = 0.0; - col_z2[i] = 0.0; - if (pad2) { - col_z6[i] = 0.0; - col_z7[i] = 0.0; - col_z8[i] = 0.0; - } - } + col_data += 9 * oosize; + im_data += isize * isize; + } + } else { + for (int c = 0; c < channels_col; ++c) { + int w_offset = c % filter_width; + int h_offset = (c / filter_width) % filter_height; + int c_im = c / (filter_width * filter_height); + for (int h = 0; h < col_height; ++h) { + int im_row_idx = h * stride[0] - padding[0] + h_offset * dilation[0]; + for (int w = 0; w < col_width; ++w) { + int im_col_idx = w * stride[1] - padding[1] + w_offset * dilation[1]; + int col_idx = (c * col_height + h) * col_width + w; + int im_idx = (im_row_idx + c_im * im_height) * im_width + im_col_idx; - col_data[1 * oosize + osize] = im_data[isize]; - for (int i = 1; i < osize; ++i) { - col_data[3 * oosize + i] = im_data[(i - 1) * stride[0] + 1]; + col_data[col_idx] = (im_row_idx < 0 || im_row_idx >= im_height || + im_col_idx < 0 || im_col_idx >= im_width) + ? static_cast(0) + : im_data[im_idx]; } - col_data[4 * oosize] = im_data[0]; - col_data[7 * oosize] = im_data[isize]; - - col_data += 9 * oosize; - im_data += isize * isize; + } + } + } +#else + for (int c = 0; c < channels_col; ++c) { + int w_offset = c % filter_width; + int h_offset = (c / filter_width) % filter_height; + int c_im = c / (filter_width * filter_height); + for (int h = 0; h < col_height; ++h) { + int im_row_idx = h * stride[0] - padding[0] + h_offset * dilation[0]; + for (int w = 0; w < col_width; ++w) { + int im_col_idx = w * stride[1] - padding[1] + w_offset * dilation[1]; + int col_idx = (c * col_height + h) * col_width + w; + int im_idx = (im_row_idx + c_im * im_height) * im_width + im_col_idx; + + col_data[col_idx] = (im_row_idx < 0 || im_row_idx >= im_height || + im_col_idx < 0 || im_col_idx >= im_width) + ? static_cast(0) + : im_data[im_idx]; + } + } + } +#endif +} + +void ExtractToImg(const int8_t *im_data, int8_t *col_data, const int im_height, + const int im_width, const int col_height, const int col_width, + const int padding_h, const int padding_w, const int stride_h, + const int stride_w, const int kh, const int kw) { + int h = padding_h - kh; + int w = padding_w - kw; + int col_start_height = h > 0 ? (h + stride_h - 1) / stride_h : 0; + int col_start_width = w > 0 ? (w + stride_w - 1) / stride_w : 0; + int start_height = kh + col_start_height * stride_h - padding_h; + int start_width = kw + col_start_width * stride_w - padding_w; + + int end_height = (col_height - col_start_height) * stride_h + start_height; + end_height = end_height > im_height ? im_height : end_height; + int end_width = (col_width - col_start_width) * stride_w + start_width; + end_width = end_width > im_width ? 
im_width : end_width; + int extract = (end_width - start_width + stride_w - 1) / stride_w; + + im_data += start_height * im_width + start_width; + col_data += col_start_height * col_width + col_start_width; + for (int i = start_height; i < end_height; i += stride_h) { + if (stride_w == 1) { + memcpy(col_data, im_data, extract * sizeof(int8_t)); + } else if (stride_w == 2) { + int s = 0; +#if __ARM_NEON + for (; s < extract - 15; s += 16) { + int8x16x2_t img = vld2q_s8(im_data + s * 2); + vst1q_s8(col_data + s, img.val[0]); + } +#endif + for (; s < extract; ++s) { + col_data[s] = im_data[s * 2]; + } + } else if (stride_w == 3) { + int s = 0; +#if __ARM_NEON + for (; s < extract - 15; s += 16) { + int8x16x3_t img = vld3q_s8(im_data + s * 3); + vst1q_s8(col_data + s, img.val[0]); + } +#endif + for (; s < extract; ++s) { + col_data[s] = im_data[s * 3]; + } + } else if (stride_w == 4) { + int s = 0; +#if __ARM_NEON + for (; s < extract - 15; s += 16) { + int8x16x4_t img = vld4q_s8(im_data + s * 4); + vst1q_s8(col_data + s, img.val[0]); + } +#endif + for (; s < extract; ++s) { + col_data[s] = im_data[s * 4]; } } else { - for (int c = 0; c < channels_col; ++c) { - int w_offset = c % filter_width; - int h_offset = (c / filter_width) % filter_height; - int c_im = c / (filter_width * filter_height); - for (int h = 0; h < col_height; ++h) { - int im_row_idx = h * stride[0] - padding[0] + h_offset * dilation[0]; - for (int w = 0; w < col_width; ++w) { - int im_col_idx = - w * stride[1] - padding[1] + w_offset * dilation[1]; - int col_idx = (c * col_height + h) * col_width + w; - int im_idx = - (im_row_idx + c_im * im_height) * im_width + im_col_idx; - - col_data[col_idx] = (im_row_idx < 0 || im_row_idx >= im_height || - im_col_idx < 0 || im_col_idx >= im_width) - ? static_cast(0) - : im_data[im_idx]; - } + PADDLE_MOBILE_THROW_EXCEPTION("stride_w must be one of 1, 2, 3 and 4."); + } + im_data += im_width * stride_h; + col_data += col_width; + } +} + +/* + * im = [input_channels, input_height, input_width] + * col = + * [input_channels, filter_height, filter_width, output_height, + * output_width] + */ +template <> +void Im2ColFunctor::operator()( + const framework::Tensor &im, const std::vector &dilation, + const std::vector &stride, const std::vector &padding, + framework::Tensor *col) { + int im_channels = im.dims()[0]; + int im_height = im.dims()[1]; + int im_width = im.dims()[2]; + int filter_height = col->dims()[1]; + int filter_width = col->dims()[2]; + int col_height = col->dims()[3]; + int col_width = col->dims()[4]; + + int channels_col = im_channels * filter_height * filter_width; + const int8_t *im_data = im.data(); + int8_t *col_data = col->data(); +#if defined(__ARM_NEON__) || defined(__ARM_NEON) + if (stride[0] <= 4 && dilation[0] == 1 && dilation[0] == dilation[1]) { + // pad 0 + memset(col_data, 0, col->numel() * sizeof(int8_t)); + for (int ic = 0; ic < im_channels; ++ic) { + for (int kh = 0; kh < filter_height; ++kh) { + for (int kw = 0; kw < filter_width; ++kw) { + ExtractToImg(im_data, col_data, im_height, im_width, col_height, + col_width, padding[0], padding[1], stride[0], stride[1], + kh, kw); + col_data += col_height * col_width; } } + im_data += im_height * im_width; } -#else + } else { +#endif for (int c = 0; c < channels_col; ++c) { int w_offset = c % filter_width; int h_offset = (c / filter_width) % filter_height; @@ -424,14 +520,15 @@ class Im2ColFunctor { col_data[col_idx] = (im_row_idx < 0 || im_row_idx >= im_height || im_col_idx < 0 || im_col_idx >= im_width) - ? 
static_cast(0) + ? static_cast(0) : im_data[im_idx]; } } } -#endif +#if defined(__ARM_NEON__) || defined(__ARM_NEON) } -}; +#endif +} /* * im = [input_channels, input_height, input_width] @@ -456,27 +553,6 @@ class Col2ImFunctor { int col_height = col.dims()[3]; int col_width = col.dims()[4]; - // PADDLE_ENFORCE_EQ((im_height + padding[0] + padding[2] - // - - // ((dilation[0] * (filter_height - 1) - // + 1))) / - // stride[0] + - // 1, - // col_height, - // "Output_height and - // padding(padding_up, padding_down) - // are " "inconsistent."); - // PADDLE_ENFORCE_EQ((im_width + padding[1] + padding[3] - // - - // ((dilation[1] * (filter_width - 1) - // + 1))) / - // stride[1] + - // 1, - // col_width, - // "Output_height and - // padding(padding_up, padding_down) - // are " "inconsistent."); - int channels_col = im_channels * filter_height * filter_width; T *im_data = im->data(); @@ -503,9 +579,9 @@ class Col2ImFunctor { }; template class Im2ColFunctor; -// template class Im2ColFunctor; +template class Im2ColFunctor; template class Col2ImFunctor; -template class Col2ImFunctor; +template class Col2ImFunctor; /* * im = [input_channels, input_height, input_width] @@ -519,8 +595,6 @@ class Im2ColFunctor { void operator()(const framework::Tensor &im, const std::vector &dilation, const std::vector &stride, const std::vector &padding, framework::Tensor *col) { - // PADDLE_ENFORCE(im.dims().size() == 3); - // PADDLE_ENFORCE(col->dims().size() == 5); int im_channels = im.dims()[0]; int im_height = im.dims()[1]; int im_width = im.dims()[2]; @@ -529,19 +603,6 @@ class Im2ColFunctor { int col_height = col->dims()[0]; int col_width = col->dims()[1]; - // PADDLE_ENFORCE_EQ( - // (im_height + padding[0] + padding[2] - - // filter_height) / stride[0] - // + 1, col_height, "Output_height and - // padding(padding_up, - // padding_down) are " "inconsistent."); - // PADDLE_ENFORCE_EQ( - // (im_width + padding[1] + padding[3] - - // filter_width) / stride[1] + - // 1, col_width, "col_width and padding(padding_left, - // padding_right) - // are " "inconsistent."); - const T *im_data = im.data(); T *col_data = col->data(); @@ -593,8 +654,6 @@ class Col2ImFunctor { const std::vector &dilation, const std::vector &stride, const std::vector &padding, framework::Tensor *im) { - // PADDLE_ENFORCE(im->dims().size() == 3); - // PADDLE_ENFORCE(col.dims().size() == 5); int im_channels = im->dims()[0]; int im_height = im->dims()[1]; int im_width = im->dims()[2]; @@ -603,19 +662,6 @@ class Col2ImFunctor { int col_height = col.dims()[0]; int col_width = col.dims()[1]; - // PADDLE_ENFORCE_EQ( - // (im_height + padding[0] + padding[2] - - // filter_height) / stride[0] - // + 1, col_height, "Output_height and - // padding(padding_up, - // padding_down) are " "inconsistent."); - // PADDLE_ENFORCE_EQ( - // (im_width + padding[1] + padding[3] - - // filter_width) / stride[1] + - // 1, col_width, "col_width and padding(padding_left, - // padding_right) - // are " "inconsistent."); - T *im_data = im->data(); const T *col_data = col.data(); @@ -655,9 +701,7 @@ class Col2ImFunctor { }; template class Im2ColFunctor; -template class Im2ColFunctor; template class Col2ImFunctor; -template class Col2ImFunctor; } // namespace math } // namespace operators diff --git a/src/operators/math/math_function.cpp b/src/operators/math/math_function.cpp index fc4c385add5ccf30ebe42695fb616e41deb1a827..4365bf5716b8b5811f6ac66217b2fe74ae116f52 100644 --- a/src/operators/math/math_function.cpp +++ b/src/operators/math/math_function.cpp @@ -15,12 +15,31 @@ 
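For reference, the generic (non-NEON) branch of Im2ColFunctor above is plain im2col: each of the im_channels * filter_height * filter_width column planes is a shifted copy of the input image, with zeros wherever the sampled pixel lands in the padding. Below is a minimal standalone sketch of that indexing; all names are local to the sketch, not taken from the patch.

#include <cstdio>
#include <vector>

// Scalar im2col with the same index conventions as the fallback loop:
// col index = (c * col_h + h) * col_w + w,
// im  index = (c_im * height + im_row) * width + im_col.
static void im2col_ref(const float *im, int channels, int height, int width,
                       int kh, int kw, int stride, int pad, float *col,
                       int col_h, int col_w) {
  int channels_col = channels * kh * kw;
  for (int c = 0; c < channels_col; ++c) {
    int w_off = c % kw;
    int h_off = (c / kw) % kh;
    int c_im = c / (kw * kh);
    for (int h = 0; h < col_h; ++h) {
      int im_row = h * stride - pad + h_off;
      for (int w = 0; w < col_w; ++w) {
        int im_col = w * stride - pad + w_off;
        bool inside = im_row >= 0 && im_row < height &&
                      im_col >= 0 && im_col < width;
        col[(c * col_h + h) * col_w + w] =
            inside ? im[(c_im * height + im_row) * width + im_col] : 0.f;
      }
    }
  }
}

int main() {
  // 1x4x4 input, 3x3 filter, stride 1, pad 1 -> col is [9, 4, 4].
  float im[16];
  for (int i = 0; i < 16; ++i) im[i] = static_cast<float>(i);
  std::vector<float> col(9 * 4 * 4);
  im2col_ref(im, 1, 4, 4, 3, 3, 1, 1, col.data(), 4, 4);
  // Plane 4 is the centre tap, so its (0, 0) entry equals im[0][0].
  printf("col[4][0][0] = %.0f\n", col[(4 * 4 + 0) * 4 + 0]);
  return 0;
}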
limitations under the License. */ #include "operators/math/math_function.h" #include #include +#include "framework/data_type.h" +#include "framework/tensor.h" #include "operators/math/gemm.h" namespace paddle_mobile { namespace operators { namespace math { +struct TensorSetConstant { + TensorSetConstant(framework::Tensor *tensor, float value) + : tensor_(tensor), value_(value) {} + template + void apply() const { + auto *begin = tensor_->mutable_data(); + std::fill(begin, begin + tensor_->numel(), static_cast(value_)); + } + framework::Tensor *tensor_; + float value_; +}; + +void set_constant(framework::Tensor *tensor, float value) { + framework::VisitDataType(framework::ToDataType(tensor->type()), + TensorSetConstant(tensor, value)); +} + template <> void matmul(const framework::Tensor &matrix_a, bool trans_a, const framework::Tensor &matrix_b, bool trans_b, float alpha, diff --git a/src/operators/math/math_function.h b/src/operators/math/math_function.h index b70dfb43ba11400e555365485f2a632c854279ac..b91242c1868398e4541c3727567a905e5b0c8714 100644 --- a/src/operators/math/math_function.h +++ b/src/operators/math/math_function.h @@ -15,12 +15,15 @@ limitations under the License. */ #pragma once #include +#include #include "framework/tensor.h" namespace paddle_mobile { namespace operators { namespace math { +void set_constant(framework::Tensor *tensor, float value); + template void matmul(const framework::Tensor &matrix_a, bool trans_a, const framework::Tensor &matrix_b, bool trans_b, T alpha, diff --git a/src/operators/math/pad.cpp b/src/operators/math/pad.cpp new file mode 100644 index 0000000000000000000000000000000000000000..d8153c445b007e8c5a902301e2724f22c8f6add1 --- /dev/null +++ b/src/operators/math/pad.cpp @@ -0,0 +1,52 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
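The set_constant helper added above picks the element type at runtime: VisitDataType switches on the tensor's dtype once and invokes the visitor's templated apply<T>(). A rough standalone sketch of that dispatch pattern follows; the enum, visit() helper and raw buffer are stand-ins for the framework's data_type machinery, not its actual API.

#include <algorithm>
#include <cstdint>
#include <cstdio>

enum class DType { kFloat32, kInt8, kInt32 };

// Dispatch once on the runtime dtype, then call the typed apply<T>().
template <typename Visitor>
void visit(DType t, Visitor v) {
  switch (t) {
    case DType::kFloat32: v.template apply<float>(); break;
    case DType::kInt8:    v.template apply<int8_t>(); break;
    case DType::kInt32:   v.template apply<int32_t>(); break;
  }
}

struct SetConstant {
  void *data;
  int n;
  float value;
  template <typename T>
  void apply() const {
    T *p = static_cast<T *>(data);
    std::fill(p, p + n, static_cast<T>(value));
  }
};

int main() {
  float buf[4];
  visit(DType::kFloat32, SetConstant{buf, 4, 2.5f});  // fills buf with 2.5
  printf("%.1f %.1f\n", buf[0], buf[3]);
  return 0;
}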
*/ + +#include "operators/math/pad.h" + +namespace paddle_mobile { +namespace operators { +namespace math { + +template +class PadFunctor { + public: + void operator()(const framework::Tensor &input, const int pad_h, + const int pad_w, framework::Tensor *output) { + const T *in_data = input.data(); + T *out_data = output->mutable_data(); + const framework::DDim &input_shape = input.dims(); + const framework::DDim &output_shape = output->dims(); + // fill output with 0 + memset(out_data, 0, sizeof(T) * output->numel()); + // should make sure the shape of output is match with input + for (int i = 0; i < input_shape[0]; ++i) { + for (int c = 0; c < input_shape[1]; ++c) { + out_data += pad_h * output_shape[3]; + for (int h = 0; h < input_shape[2]; ++h) { + memcpy(out_data + pad_w, in_data, sizeof(T) * input_shape[3]); + out_data += output_shape[3]; + in_data += input_shape[3]; + } + out_data += pad_h * output_shape[3]; + } + } + } +}; + +template class PadFunctor; +template class PadFunctor; + +} // namespace math +} // namespace operators +} // namespace paddle_mobile diff --git a/src/operators/math/pad.h b/src/operators/math/pad.h new file mode 100644 index 0000000000000000000000000000000000000000..0f5a4b89674f92746f75bb1e4f9364d5a16fdba2 --- /dev/null +++ b/src/operators/math/pad.h @@ -0,0 +1,31 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
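A rough usage sketch for the PadFunctor added above (not part of the patch): the caller is expected to size the output as [N, C, H + 2 * pad_h, W + 2 * pad_w]; the functor zero-fills it and copies each input row into the interior. The scalar equivalent on a tiny example:

#include <cstdio>
#include <cstring>

static void pad2d(const float *in, int n, int c, int h, int w,
                  int pad_h, int pad_w, float *out) {
  int oh = h + 2 * pad_h, ow = w + 2 * pad_w;
  std::memset(out, 0, sizeof(float) * n * c * oh * ow);    // zero border
  for (int i = 0; i < n * c; ++i) {
    for (int r = 0; r < h; ++r) {
      // copy one input row into the padded interior of the output
      std::memcpy(out + (i * oh + r + pad_h) * ow + pad_w,
                  in + (i * h + r) * w, sizeof(float) * w);
    }
  }
}

int main() {
  float in[2 * 2] = {1, 2, 3, 4};  // 1x1x2x2 input
  float out[4 * 4];                // pad_h = pad_w = 1 -> 1x1x4x4 output
  pad2d(in, 1, 1, 2, 2, 1, 1, out);
  printf("%.0f %.0f\n", out[0], out[1 * 4 + 1]);  // 0 (border), 1 (in[0][0])
  return 0;
}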
*/ + +#pragma once +#include "framework/tensor.h" + +namespace paddle_mobile { +namespace operators { +namespace math { + +template +class PadFunctor { + public: + void operator()(const framework::Tensor &input, const int pad_h, + const int pad_w, framework::Tensor *output); +}; + +} // namespace math +} // namespace operators +} // namespace paddle_mobile diff --git a/src/operators/math/vol2col.cpp b/src/operators/math/vol2col.cpp index afee3f7f85a6b2b3f84e9c3430211c4d97656d1c..9311e9e2291709631bc8ee07d2cc94f9ca99f4c2 100644 --- a/src/operators/math/vol2col.cpp +++ b/src/operators/math/vol2col.cpp @@ -32,9 +32,6 @@ class Vol2ColFunctor { void operator()(const Tensor &vol, const std::vector &dilations, const std::vector &strides, const std::vector &paddings, Tensor *col) const { - // PADDLE_ENFORCE(vol.dims().size() == 4); - // PADDLE_ENFORCE(col->dims().size() == 7); - int input_channels = vol.dims()[0]; int input_depth = vol.dims()[1]; int input_height = vol.dims()[2]; @@ -48,32 +45,6 @@ class Vol2ColFunctor { int channels_col = input_channels * filter_depth * filter_height * filter_width; - // PADDLE_ENFORCE_EQ((input_depth + 2 * paddings[0] - - // ((dilations[0] * (filter_depth - 1) - // + 1))) / - // strides[0] + - // 1, - // output_depth, - // "input_depth and output_depth are " - // "mismatching."); - // PADDLE_ENFORCE_EQ((input_height + 2 * paddings[1] - - // ((dilations[1] * (filter_height - - // 1) + 1))) / - // strides[1] + - // 1, - // output_height, - // "input_height and output_height are - // " - // "mismatching."); - // PADDLE_ENFORCE_EQ((input_width + 2 * paddings[2] - - // ((dilations[2] * (filter_width - 1) - // + 1))) / - // strides[2] + - // 1, - // output_width, - // "input_width and output_width are " - // "mismatching."); - const T *vol_data = vol.data(); T *col_data = col->data(); @@ -119,9 +90,6 @@ class Col2VolFunctor { void operator()(const Tensor &col, const std::vector &dilations, const std::vector &strides, const std::vector &paddings, Tensor *vol) const { - // PADDLE_ENFORCE(vol->dims().size() == 4); - // PADDLE_ENFORCE(col.dims().size() == 7); - int input_channels = vol->dims()[0]; int input_depth = vol->dims()[1]; int input_height = vol->dims()[2]; @@ -135,31 +103,6 @@ class Col2VolFunctor { int channels_col = input_channels * filter_depth * filter_height * filter_width; - // PADDLE_ENFORCE_EQ((input_depth + 2 * paddings[0] - - // ((dilations[0] * (filter_depth - 1) - // + 1))) / - // strides[0] + - // 1, - // output_depth, - // "input_depth and output_depth are " - // "mismatching."); - // PADDLE_ENFORCE_EQ((input_height + 2 * paddings[1] - - // ((dilations[1] * (filter_height - - // 1) + 1))) / - // strides[1] + - // 1, - // output_height, - // "input_height and output_height are - // " - // "mismatching."); - // PADDLE_ENFORCE_EQ((input_width + 2 * paddings[2] - - // ((dilations[2] * (filter_width - 1) - // + 1))) / - // strides[2] + - // 1, - // output_width, - // "input_width and output_width are " - // "mismatching."); T *vol_data = vol->data(); const T *col_data = col.data(); @@ -195,9 +138,9 @@ class Col2VolFunctor { }; template class Vol2ColFunctor; -template class Vol2ColFunctor; +template class Vol2ColFunctor; template class Col2VolFunctor; -template class Col2VolFunctor; +template class Col2VolFunctor; } // namespace math } // namespace operators diff --git a/src/operators/op_param.h b/src/operators/op_param.h index 2e049226282ced4bad60b48eee68e2e4deae6706..568cf77b8e4e81732cd9a783c1a9ea64d347102b 100644 --- a/src/operators/op_param.h +++ 
b/src/operators/op_param.h @@ -1063,6 +1063,42 @@ class FetchParam : public OpParam { RType *out_; }; +#ifdef FILL_CONSTANT_OP +template +class FillConstantParam : public OpParam { + typedef typename DtypeTensorTrait::gtype GType; + typedef typename DtypeTensorTrait::rtype RType; + + public: + FillConstantParam(const VariableNameMap &inputs, + const VariableNameMap &outputs, const AttributeMap &attrs, + const Scope &scope) { + out_var_ = OutVarFrom(outputs, scope); + out_ = OutFrom(outputs, scope); + dtype_ = GetAttr("dtype", attrs); + shape_ = GetAttr>("shape", attrs); + value_ = GetAttr("value", attrs); + } + + Variable *OutVar() const { return out_var_; } + + RType *Out() const { return out_; } + + const int &DataDtype() const { return dtype_; } + + const vector &Shape() const { return shape_; } + + const float &Value() const { return value_; } + + private: + Variable *out_var_; + RType *out_; + int dtype_; + vector shape_; + float value_; +}; +#endif + #ifdef TRANSPOSE_OP template class TransposeParam : public OpParam { @@ -2294,6 +2330,7 @@ class ShapeParam : public OpParam { }; #endif +#ifdef QUANT_OP template class QuantizeParam : public OpParam { typedef typename DtypeTensorTrait::gtype GType; @@ -2304,14 +2341,12 @@ class QuantizeParam : public OpParam { const AttributeMap &attrs, const Scope &scope) { input_ = InputXFrom(inputs, scope); out_ = OutFrom(outputs, scope); - if (HasAttr("is_static", attrs)) { - is_static_ = GetAttr("is_static", attrs); - } // online // scale = max(abs(x)) online_scale_ = GetVarValue("OutScale", outputs, scope); // offline if (HasAttr("static_scale", attrs)) { + is_static_ = true; static_scale_ = GetAttr("static_scale", attrs); } // x = round(scale * x) @@ -2333,9 +2368,11 @@ class QuantizeParam : public OpParam { float static_scale_ = 1.0f; // round method type // nearest_zero and nearest_even is valid currently - RoundType round_type_ = ROUND_NEAREST_TO_EVEN; + RoundType round_type_ = ROUND_NEAREST_AWAY_ZERO; }; +#endif +#ifdef DEQUANT_OP template class DequantizeParam : public OpParam { typedef typename DtypeTensorTrait::gtype GType; @@ -2363,6 +2400,7 @@ class DequantizeParam : public OpParam { RType *activation_scale_; float weight_scale_; }; +#endif } // namespace operators } // namespace paddle_mobile diff --git a/src/operators/quantize_op.cpp b/src/operators/quantize_op.cpp index 7958b054de3665132b52582b8bd4126413c0597a..865539d7d26de41b319b4d82ed168b2ec74d722d 100644 --- a/src/operators/quantize_op.cpp +++ b/src/operators/quantize_op.cpp @@ -12,6 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ +#ifdef QUANT_OP + #include "operators/quantize_op.h" #include @@ -33,3 +35,5 @@ namespace ops = paddle_mobile::operators; #ifdef PADDLE_MOBILE_CPU REGISTER_OPERATOR_CPU(quantize, ops::QuantizeOp); #endif + +#endif diff --git a/src/operators/quantize_op.h b/src/operators/quantize_op.h index 2b0d2f8e321b9e15324e5aa2b38ba50fb4f7aebf..ca04c1213a5cdcb44082848fb45b1ade3f19086f 100644 --- a/src/operators/quantize_op.h +++ b/src/operators/quantize_op.h @@ -12,6 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
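The fill_constant kernel itself is not in this diff; judging from FillConstantParam above, the op resizes its output to the "shape" attribute and fills it with "value" in the requested "dtype", which is what the set_constant helper added to math_function.cpp provides. A standalone sketch of that behaviour, with a plain std::vector standing in for the output tensor:

#include <cstdio>
#include <vector>

struct FillConstantAttrs {
  std::vector<int> shape;  // e.g. {1, 3, 2, 2}
  float value;             // e.g. 0.5f
};

template <typename T>
std::vector<T> run_fill_constant(const FillConstantAttrs &attrs) {
  // numel is the product of the shape attribute.
  int numel = 1;
  for (int d : attrs.shape) numel *= d;
  return std::vector<T>(numel, static_cast<T>(attrs.value));
}

int main() {
  FillConstantAttrs attrs{{1, 3, 2, 2}, 0.5f};
  auto out = run_fill_constant<float>(attrs);  // 12 elements, all 0.5
  printf("numel=%zu first=%.1f\n", out.size(), out[0]);
  return 0;
}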
*/ +#ifdef QUANT_OP + #pragma once #include @@ -40,3 +42,5 @@ class QuantizeOp : public framework::OperatorWithKernel< } // namespace operators } // namespace paddle_mobile + +#endif diff --git a/src/operators/sum_op.cpp b/src/operators/sum_op.cpp index f821364b92f74534b76ea6069e94a8233ee0a769..2e10363b07498128b5573e27a3d63b59c454d8b6 100644 --- a/src/operators/sum_op.cpp +++ b/src/operators/sum_op.cpp @@ -26,7 +26,7 @@ void SumOp::InferShape() const { auto inputs = this->param_.Inputs(); const size_t n = inputs.size(); - std::vector inputs_dims; + std::vector inputs_dims; inputs_dims.reserve(n); for (int i = 0; i < n; i++) { inputs_dims.push_back(inputs[i]->dims()); diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt index bbb2a81613deb0ddd4e1c769805a757f8a4665c3..64324b08a572e9b37ba8814ae5eb0dbfaa88088c 100644 --- a/test/CMakeLists.txt +++ b/test/CMakeLists.txt @@ -192,6 +192,10 @@ if (NOT FOUND_MATCH) ADD_EXECUTABLE(test-polygon-box-transform-op operators/test_polygon_box_transform_op.cpp test_helper.h test_include.h) target_link_libraries(test-polygon-box-transform-op paddle-mobile) + # gen test + ADD_EXECUTABLE(test-fill-constant-op operators/test_fill_constant_op.cpp test_helper.h test_include.h) + target_link_libraries(test-fill-constant-op paddle-mobile) + # gen test ADD_EXECUTABLE(test-reshape-op operators/test_reshape_op.cpp test_helper.h test_include.h) target_link_libraries(test-reshape-op paddle-mobile) @@ -216,6 +220,10 @@ if (NOT FOUND_MATCH) ADD_EXECUTABLE(test-dequantize-op operators/test_dequantize_op.cpp test_helper.h test_include.h) target_link_libraries(test-dequantize-op paddle-mobile) + # test int8 conv op + ADD_EXECUTABLE(test-int8-conv-op operators/test_int8_conv_op.cpp test_helper.h test_include.h) + target_link_libraries(test-int8-conv-op paddle-mobile) + # gen test log ADD_EXECUTABLE(test-log common/test_log.cpp) target_link_libraries(test-log paddle-mobile) diff --git a/test/net/test_googlenet.cpp b/test/net/test_googlenet.cpp index a2f030eeac5c2584b33fad2b082b9d5513707260..c88a78974c330ec270fbcb3f5c28e368ef16440e 100644 --- a/test/net/test_googlenet.cpp +++ b/test/net/test_googlenet.cpp @@ -25,27 +25,31 @@ int main() { paddle_mobile::PaddleMobile paddle_mobile; #endif - paddle_mobile.SetThreadNum(4); - bool optimize = true; + paddle_mobile.SetThreadNum(1); + bool optimize = false; auto time1 = time(); if (paddle_mobile.Load(g_googlenet, optimize)) { auto time2 = time(); std::cout << "load cost :" << time_diff(time1, time2) << "ms" << std::endl; std::vector input; + std::vector output; std::vector dims{1, 3, 224, 224}; GetInput(g_test_image_1x3x224x224, &input, dims); - // 预热十次 - for (int i = 0; i < 10; ++i) { - auto vec_result = paddle_mobile.Predict(input, dims); - } + // // 预热十次 + // for (int i = 0; i < 10; ++i) { + // output = paddle_mobile.Predict(input, dims); + // } auto time3 = time(); for (int i = 0; i < 10; ++i) { - auto vec_result = paddle_mobile.Predict(input, dims); + output = paddle_mobile.Predict(input, dims); } auto time4 = time(); std::cout << "predict cost :" << time_diff(time3, time4) / 10 << "ms" << std::endl; + for (int i = 0; i < output.size(); ++i) { + DLOG << "result[" << i << "] = " << output[i]; + } } return 0; } diff --git a/test/operators/test_dequantize_op.cpp b/test/operators/test_dequantize_op.cpp index 8c61ae32d90169c5f8c6fdced94ce70f29d93b96..8e89d8f7af3694bcc4701c268451f28675db7fc9 100644 --- a/test/operators/test_dequantize_op.cpp +++ b/test/operators/test_dequantize_op.cpp @@ -59,7 +59,7 @@ int TestDequqntizeOp() { 
framework::Tensor output_cmp; output_cmp.Resize(dim); - float dequant_scale = 1.f / (1.27 * 1.74); + float dequant_scale = 1.27 / 1.74; dequantize(input, dequant_scale, &output_cmp); const float* output_cmp_data = output_cmp.data(); for (int i = 0; i < output->numel(); ++i) { diff --git a/test/operators/test_fill_constant_op.cpp b/test/operators/test_fill_constant_op.cpp new file mode 100644 index 0000000000000000000000000000000000000000..b099217d1641eb221b3d0d86d780fb6ecfa929bd --- /dev/null +++ b/test/operators/test_fill_constant_op.cpp @@ -0,0 +1,113 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once +#include "../test_include.h" +#include "operators/fill_constant_op.h" + +namespace paddle_mobile { +namespace framework { + +template +class TestFillConstantOp { + public: + explicit TestFillConstantOp(const Program p) : program_(p) { + if (use_optimize_) { + to_predict_program_ = program_.optimizeProgram; + } else { + to_predict_program_ = program_.originProgram; + } + const std::vector> blocks = + to_predict_program_->Blocks(); + for (auto block_desc : blocks) { + std::vector> ops = block_desc->Ops(); + for (auto op : ops) { + if (op->Type() == "fill_constant") { + DLOG << " attr size: " << op->GetAttrMap().size(); + std::unordered_map attrs = op->GetAttrMap(); + for (std::unordered_map::iterator it = + attrs.begin(); + it != attrs.end(); ++it) { + DLOG << " " << it->first << " " << it->second; + } + DLOG << " inputs size: " << op->GetInputs().size(); + DLOG << " outputs size: " << op->GetOutputs().size(); + DLOG << " output is : " << op->Output("Out")[0]; + output_var_name = op->Output("Out")[0]; + std::shared_ptr> op_ptr = + std::make_shared>( + op->Type(), op->GetInputs(), op->GetOutputs(), + op->GetAttrMap(), program_.scope); + ops_of_block_[*block_desc.get()].push_back(op_ptr); + } + } + } + } + + std::shared_ptr predict() { + auto scope = program_.scope; + + Variable *output = scope->Var(output_var_name); + auto *output_tensor = output->GetMutable(); + + std::shared_ptr out_tensor = std::make_shared(); + out_tensor.reset(output_tensor); + + predict(0); + + return out_tensor; + } + + private: + const framework::Program program_; + std::shared_ptr to_predict_program_; + std::map>>> + ops_of_block_; + bool use_optimize_ = false; + string output_var_name; + + void predict(int block_id) { + std::shared_ptr to_predict_block = + to_predict_program_->Block(block_id); + for (int j = 0; j < ops_of_block_[*to_predict_block.get()].size(); ++j) { + auto op = ops_of_block_[*to_predict_block.get()][j]; + op->Run(); + } + } +}; + +template class TestFillConstantOp; +} // namespace framework +} // namespace paddle_mobile + +int main() { + DLOG << "----------**********----------"; + DLOG << "begin to run FillConstant Test"; + paddle_mobile::Loader loader; + auto program = loader.Load(std::string(g_ocr) + "/model", + std::string(g_ocr) + "/params"); + + paddle_mobile::framework::TestFillConstantOp + testFillConstantOp(program); + + auto 
output = testFillConstantOp.predict(); + auto *output_ptr = output->data(); + + DLOG << "output : "; + for (int i = 0; i < output->numel(); ++i) { + DLOG << " index " << i << " : " << output_ptr[i]; + } + return 0; +} diff --git a/test/operators/test_int8_conv_op.cpp b/test/operators/test_int8_conv_op.cpp new file mode 100644 index 0000000000000000000000000000000000000000..2ab40ba5833939e4456bb13bf4d5f9819a332693 --- /dev/null +++ b/test/operators/test_int8_conv_op.cpp @@ -0,0 +1,279 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "../test_helper.h" +#include "../test_include.h" +#include "operators/conv_op.h" + +namespace paddle_mobile { + +// Reference convolution for checking results: +// accumulate through explicit loops over input, output, and filters. +template +void conv2d(const framework::Tensor *input, const framework::Tensor *filter, + const framework::AttributeMap &attrs, framework::Tensor *output) { + framework::AttrReader attr_reader(attrs); + std::vector paddings = attr_reader.Get>("paddings"); + std::vector strides = attr_reader.Get>("strides"); + std::vector dilations = attr_reader.Get>("dilations"); + int groups = attr_reader.Get("groups"); + int kernel_h = filter->dims()[2]; + int kernel_w = filter->dims()[3]; + int pad_h = paddings[0]; + int pad_w = paddings[1]; + int stride_h = strides[0]; + int stride_w = strides[1]; + int dilation_h = dilations[0]; + int dilation_w = dilations[1]; + auto in_shape = input->dims(); + auto out_shape = output->dims(); + + const bool has_depth = 0; + int kernel_d, pad_d, stride_d, dilation_d; + if (has_depth) { + kernel_d = kernel_h; + stride_d = stride_h; + pad_d = pad_h; + dilation_d = dilation_h; + } else { + kernel_d = stride_d = dilation_d = 1; + pad_d = 0; + } + // Groups + int o_g = out_shape[1] / groups; + int k_g = in_shape[1] / groups; + int o_head, k_head; + // Convolution + vector weight_offset(4 + has_depth); + vector in_offset(4 + has_depth); + vector out_offset(4 + has_depth); + auto offset = [](const framework::Tensor *input, const vector &indics) { + framework::DDim shape = input->dims(); + size_t count = 0; + for (int i = 0; i < indics.size(); ++i) { + count *= shape[i]; + count += indics[i]; + } + return count; + }; + + const Itype *in_data = input->data(); + const Itype *w_data = filter->data(); + Otype *out_data = output->mutable_data(); + memset(out_data, 0, output->numel() * sizeof(Otype)); + for (int n = 0; n < out_shape[0]; n++) { + for (int g = 0; g < groups; g++) { + o_head = o_g * g; + k_head = k_g * g; + for (int o = 0; o < o_g; o++) { + for (int k = 0; k < k_g; k++) { + for (int z = 0; z < (has_depth ? 
out_shape[2] : 1); z++) { + for (int y = 0; y < out_shape[2 + has_depth]; y++) { + for (int x = 0; x < out_shape[3 + has_depth]; x++) { + for (int r = 0; r < kernel_d; r++) { + for (int p = 0; p < kernel_h; p++) { + for (int q = 0; q < kernel_w; q++) { + int in_z = z * stride_d - pad_d + r * dilation_d; + int in_y = y * stride_h - pad_h + p * dilation_h; + int in_x = x * stride_w - pad_w + q * dilation_w; + if (in_z >= 0 && in_z < (has_depth ? in_shape[2] : 1) && + in_y >= 0 && in_y < in_shape[2 + has_depth] && + in_x >= 0 && in_x < in_shape[3 + has_depth]) { + weight_offset[0] = o + o_head; + weight_offset[1] = k; + if (has_depth) { + weight_offset[2] = r; + } + weight_offset[2 + has_depth] = p; + weight_offset[3 + has_depth] = q; + in_offset[0] = n; + in_offset[1] = k + k_head; + if (has_depth) { + in_offset[2] = in_z; + } + in_offset[2 + has_depth] = in_y; + in_offset[3 + has_depth] = in_x; + out_offset[0] = n; + out_offset[1] = o + o_head; + if (has_depth) { + out_offset[2] = z; + } + out_offset[2 + has_depth] = y; + out_offset[3 + has_depth] = x; + + out_data[offset(output, out_offset)] += + in_data[offset(input, in_offset)] * + w_data[offset(filter, weight_offset)]; + } + } + } + } + } + } + } + } + } + } + } +} + +template +int TestConvOp() { + int kernel_h = Kernel; + int kernel_w = Kernel; + int pad_h = Pad; + int pad_w = Pad; + int stride_h = Stride; + int stride_w = Stride; + int dilation_h = 1; + int dilation_w = 1; + + int batch_size = 1; + int input_c = 3; + int input_h = 100; + int input_w = 100; + int output_c = 10; + framework::DDim input_shape = + framework::make_ddim({batch_size, input_c, input_h, input_w}); + framework::DDim filter_shape = + framework::make_ddim({output_c, input_c, kernel_h, kernel_w}); + + VariableNameMap inputs; + VariableNameMap outputs; + auto scope = std::make_shared(); + inputs["Input"] = std::vector({"input"}); + inputs["Filter"] = std::vector({"filter"}); + outputs["Output"] = std::vector({"output"}); + + auto input_var = scope.get()->Var("input"); + auto input = input_var->template GetMutable(); + SetupTensor(input, input_shape, -20, 20); + + auto filter_var = scope.get()->Var("filter"); + auto filter = filter_var->template GetMutable(); + SetupTensor(filter, filter_shape, -20, 20); + + auto output_var = scope.get()->Var("output"); + framework::AttributeMap attrs; + attrs["strides"].Set>(std::vector({stride_h, stride_w})); + attrs["paddings"].Set>(std::vector({pad_h, pad_w})); + attrs["dilations"].Set>( + std::vector({dilation_h, dilation_w})); + attrs["groups"].Set(1); + + auto *op = new operators::ConvOp("conv2d", inputs, outputs, attrs, + scope); + // struct timespec ts_begin, ts_end; + op->InferShape(); + // warmup + // op->Run(); + // clock_gettime(CLOCK_MONOTONIC, &ts_begin); + // for (int i = 0; i < 10; ++i) { + op->Run(); + // } + // clock_gettime(CLOCK_MONOTONIC, &ts_end); + // uint64_t elapsed = (ts_end.tv_sec - ts_begin.tv_sec) * 1e3 + + // (ts_end.tv_nsec - ts_begin.tv_nsec) / 1e6; + // LOG(kLOG_INFO) << "elapsed: " << elapsed / 10.0 << " ms"; + + int kernel_extent_h = dilation_h * (kernel_h - 1) + 1; + int kernel_extent_w = dilation_w * (kernel_w - 1) + 1; + int output_h = (input_h + 2 * pad_h - kernel_extent_h) / stride_h + 1; + int output_w = (input_w + 2 * pad_w - kernel_extent_w) / stride_w + 1; + auto output_shape = framework::make_ddim( + std::vector({batch_size, output_c, output_h, output_w})); + framework::Tensor output_cmp; + output_cmp.mutable_data(output_shape); + conv2d(input, filter, attrs, &output_cmp); + + // compare 
results + auto output = output_var->template Get(); + const Otype *output_data = output->data(); + Otype *output_cmp_data = output_cmp.data(); + for (int i = 0; i < output->numel(); ++i) { + PADDLE_MOBILE_ENFORCE(output_data[i] == output_cmp_data[i], + "output[%d] = %d, output_cmp[%d] = %d", i, + output_data[i], i, output_cmp_data[i]); + } + delete op; + return 0; +} + +} // namespace paddle_mobile + +int main() { + // kernel = 7, pad = 0, stride = 2 + LOG(paddle_mobile::kLOG_INFO) << "int8, kernel=7, pad=0, stride=2"; + paddle_mobile::TestConvOp(); + + // kernel = 7, pad = 1, stride = 2 + LOG(paddle_mobile::kLOG_INFO) << "int8, kernel=7, pad=1, stride=2"; + paddle_mobile::TestConvOp(); + + // kernel = 7, pad = 3, stride = 2 + LOG(paddle_mobile::kLOG_INFO) << "int8, kernel=7, pad=3, stride=2"; + paddle_mobile::TestConvOp(); + + // kernel = 7, pad = 0, stride = 1 + LOG(paddle_mobile::kLOG_INFO) << "int8, kernel=7, pad=0, stride=1"; + paddle_mobile::TestConvOp(); + + // kernel = 7, pad = 1, stride = 1 + LOG(paddle_mobile::kLOG_INFO) << "int8, kernel=7, pad=1, stride=1"; + paddle_mobile::TestConvOp(); + + // kernel = 7, pad = 3, stride = 1 + LOG(paddle_mobile::kLOG_INFO) << "int8, kernel=7, pad=3, stride=1"; + paddle_mobile::TestConvOp(); + + // kernel = 7, pad = 5, stride = 3 + LOG(paddle_mobile::kLOG_INFO) << "int8, kernel=7, pad=5, stride=3"; + paddle_mobile::TestConvOp(); + + // kernel = 7, pad = 3, stride = 4 + LOG(paddle_mobile::kLOG_INFO) << "int8, kernel=7, pad=3, stride=4"; + paddle_mobile::TestConvOp(); + LOG(paddle_mobile::kLOG_INFO) << "\n"; + + // kernel = 3, pad = 0, stride = 1 + LOG(paddle_mobile::kLOG_INFO) << "int8, kernel=3, pad=0, stride=1"; + paddle_mobile::TestConvOp(); + // kernel = 3, pad = 0, stride = 1 + LOG(paddle_mobile::kLOG_INFO) << "float, kernel=3, pad=0, stride=1"; + paddle_mobile::TestConvOp(); + LOG(paddle_mobile::kLOG_INFO) << "\n"; + + // kernel = 3, pad = 1, stride = 1 + LOG(paddle_mobile::kLOG_INFO) << "int8, kernel=3, pad=1, stride=1"; + paddle_mobile::TestConvOp(); + // kernel = 3, pad = 1, stride = 1 + LOG(paddle_mobile::kLOG_INFO) << "float, kernel=3, pad=1, stride=1"; + paddle_mobile::TestConvOp(); + LOG(paddle_mobile::kLOG_INFO) << "\n"; + + // kernel = 5, pad = 0, stride = 1 + LOG(paddle_mobile::kLOG_INFO) << "int8, kernel=5, pad=0, stride=1"; + paddle_mobile::TestConvOp(); + // kernel = 5, pad = 0, stride = 1 + LOG(paddle_mobile::kLOG_INFO) << "float, kernel=5, pad=0, stride=1"; + paddle_mobile::TestConvOp(); + LOG(paddle_mobile::kLOG_INFO) << "\n"; + + // kernel = 5, pad = 2, stride = 1 + LOG(paddle_mobile::kLOG_INFO) << "int8, kernel=5, pad=2, stride=1"; + paddle_mobile::TestConvOp(); + // kernel = 5, pad = 2, stride = 1 + LOG(paddle_mobile::kLOG_INFO) << "float, kernel=5, pad=2, stride=1"; + paddle_mobile::TestConvOp(); +} diff --git a/test/operators/test_mul_op.cpp b/test/operators/test_mul_op.cpp index 678add6dcedd22e788e0bd2df64a8eba59ad8514..10dab2cda1b3c692f42cf8760eb2b48ae6451f39 100644 --- a/test/operators/test_mul_op.cpp +++ b/test/operators/test_mul_op.cpp @@ -12,7 +12,6 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
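A quick check of the output-size arithmetic TestConvOp uses above (kernel_extent = dilation * (kernel - 1) + 1, output = (input + 2 * pad - kernel_extent) / stride + 1), worked for two of the cases exercised in main():

#include <cstdio>

static int conv_out_size(int in, int kernel, int pad, int stride, int dilation) {
  int kernel_extent = dilation * (kernel - 1) + 1;
  return (in + 2 * pad - kernel_extent) / stride + 1;
}

int main() {
  // 100x100 input, kernel=7, pad=0, stride=2: (100 - 7) / 2 + 1 = 47
  printf("%d\n", conv_out_size(100, 7, 0, 2, 1));
  // 100x100 input, kernel=3, pad=1, stride=1: (100 + 2 - 3) / 1 + 1 = 100
  printf("%d\n", conv_out_size(100, 3, 1, 1, 1));
  return 0;
}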
*/ -#include #include "../test_helper.h" #include "../test_include.h" #include "operators/mul_op.h" diff --git a/test/operators/test_quantize_op.cpp b/test/operators/test_quantize_op.cpp index c988862f6d91c87f47525fa36b7ee61f253682ab..5b1f276bebb0b956a7907a500645612c5aeaf8f9 100644 --- a/test/operators/test_quantize_op.cpp +++ b/test/operators/test_quantize_op.cpp @@ -18,14 +18,6 @@ limitations under the License. */ namespace paddle_mobile { -// static float g_test_data[50] = { -// -5.55, -5.5, -5.45, -5.0, -4.55, -4.5, -4.45, -4.0, -3.55, -3.5, -// -3.45, -3.01, -2.75, -2.5, -2.501, -2.49, -2.01, -1.75, -1.5, -1.25, -// -1.0, -0.75, -0.5, -0.25, 0.0, 0.25, 0.5, 0.75, 1.0, 1.25, -// 1.5, 1.75, 2.01, 2.49, 2.501, 2.5, 2.75, 3.01, 3.45, 3.5, -// 3.55, 4.0, 4.45, 4.5, 4.55, 5.0, 5.45, 5.5, 5.55, 6.0, -// }; - static float find_abs_max(const Tensor *input) { float max_abs = 0.f; const float *x = input->data(); @@ -60,6 +52,16 @@ static void quantize_round_to_even(const Tensor *input, const float scale, } } +static void quantize_round_to_nearest(const Tensor *input, const float scale, + Tensor *output) { + const float *x = input->data(); + int8_t *y = output->mutable_data(); + size_t size = input->numel(); + for (size_t i = 0; i < size; ++i) { + y[i] = round(x[i] * scale); + } +} + int TestQuqntizeOp() { framework::DDim dim = framework::make_ddim({1, 3, 224, 224}); @@ -88,15 +90,16 @@ int TestQuqntizeOp() { auto output_scale = output_scale_var->template Get(); const float *output_scale_data = output_scale->data(); - float max_abs = find_abs_max(input); - float output_scale_cmp = 127 / max_abs; + float output_scale_cmp = find_abs_max(input); PADDLE_MOBILE_ENFORCE(output_scale_cmp == output_scale_data[0], "output_scale = %.6f, output_scale_cmp = %.6f", output_scale_cmp, output_scale_data[0]); framework::Tensor output_cmp; output_cmp.Resize(dim); - quantize_round_to_even(input, output_scale_cmp, &output_cmp); + float scale = 127 / output_scale_cmp; + // quantize_round_to_even(input, scale, &output_cmp); + quantize_round_to_nearest(input, scale, &output_cmp); int8_t *output_cmp_data = output_cmp.data(); for (int i = 0; i < output->numel(); ++i) { PADDLE_MOBILE_ENFORCE(output_data[i] == output_cmp_data[i], diff --git a/tools/op.cmake b/tools/op.cmake index 8b7378e1b2d0a0b13f9bec862fd155ec86845ad5..f7a6ed4b134f78ddb23487cd3a861f244e6a86db 100644 --- a/tools/op.cmake +++ b/tools/op.cmake @@ -188,6 +188,7 @@ if(NOT FOUND_MATCH) set(ELEMENTWISEADD_OP ON) set(ELEMENTWISESUB_OP ON) set(IM2SEQUENCE_OP ON) + set(FILL_CONSTANT_OP ON) set(FUSION_CONVADD_OP ON) set(FUSION_CONVADDPRELU_OP ON) set(FUSION_CONVADDRELU_OP ON) @@ -223,6 +224,8 @@ if(NOT FOUND_MATCH) set(SHAPE_OP ON) set(ELEMENTWISEMUL_OP ON) set(SUM_OP ON) + set(QUANT_OP ON) + set(DEQUANT_OP ON) endif() # option(BATCHNORM_OP "" ON) @@ -231,6 +234,7 @@ endif() # option(CONV_OP "" ON) # option(DEPTHWISECONV_OP "" ON) # option(ELEMENTWISEADD_OP "" ON) + # option(FILL_CONSTANT_OP "" ON) # option(FUSION_CONVADD_OP "" ON) # option(FUSION_CONVADDRELU_OP "" ON) # option(FUSION_FC_OP "" ON) @@ -268,6 +272,9 @@ endif() if (ELEMENTWISESUB_OP) add_definitions(-DELEMENTWISESUB_OP) endif() +if (FILL_CONSTANT_OP) + add_definitions(-DFILL_CONSTANT_OP) +endif() if (FUSION_CONVADD_OP) add_definitions(-DFUSION_CONVADD_OP) endif() @@ -406,3 +413,10 @@ if (SUM_OP) add_definitions(-DSUM_OP) endif() +if (QUANT_OP) + add_definitions(-DQUANT_OP) +endif() +if (DEQUANT_OP) + add_definitions(-DDEQUANT_OP) +endif() +
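Finally, a small worked example of the quantization convention the updated test_quantize_op.cpp checks: the op is expected to report max(|x|) as OutScale and to quantize with y = round(x * 127 / max_abs), rounding to nearest with ties away from zero (the ROUND_NEAREST_AWAY_ZERO default above).

#include <algorithm>
#include <cmath>
#include <cstdint>
#include <cstdio>

int main() {
  const float x[4] = {-6.35f, -0.1f, 0.05f, 6.35f};
  float max_abs = 0.f;
  for (float v : x) max_abs = std::max(max_abs, std::fabs(v));  // 6.35
  float scale = 127.f / max_abs;                                // 20.0
  for (float v : x) {
    int8_t q = static_cast<int8_t>(std::round(v * scale));
    // -6.35 -> -127, -0.10 -> -2, 0.05 -> 1, 6.35 -> 127
    printf("% .2f -> %d\n", v, q);
  }
  return 0;
}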