diff --git a/src/common/types.cpp b/src/common/types.cpp index 9bc594c7533b980626d8d07e89fc3ccf649a127f..7b8b5bb559a29bd28c9789f3b52fa1b65cc14dc5 100644 --- a/src/common/types.cpp +++ b/src/common/types.cpp @@ -25,7 +25,7 @@ const std::string G_OP_TYPE_ELEMENTWISE_ADD = "elementwise_add"; const std::string G_OP_TYPE_FUSION_CONV_ADD_RELU = "fusion_conv_add_relu"; const std::string G_OP_TYPE_FUSION_CONV_ADD_BN_RELU = "fusion_conv_add_bn_relu"; const std::string G_OP_TYPE_FUSION_DWCONV_BN_RELU = "fusion_dwconv_bn_relu"; - +const std::string G_OP_TYPE_FUSION_CONV_BN_RELU = "fusion_conv_bn_relu"; const std::string G_OP_TYPE_FC = "fusion_fc"; const std::string G_OP_TYPE_FUSION_CONV_ADD = "fusion_conv_add"; const std::string G_OP_TYPE_LRN = "lrn"; @@ -49,6 +49,8 @@ std::unordered_map< std::string, std::pair, std::vector>> op_input_output_key = { {G_OP_TYPE_CONV, {{"Input"}, {"Output"}}}, + {G_OP_TYPE_FUSION_DWCONV_BN_RELU, {{"Input"}, {"Out"}}}, + {G_OP_TYPE_FUSION_CONV_BN_RELU, {{"Input"}, {"Out"}}}, {G_OP_TYPE_FUSION_CONV_ADD, {{"Input"}, {"Out"}}}, {G_OP_TYPE_RELU, {{"X"}, {"Out"}}}, {G_OP_TYPE_SOFTMAX, {{"X"}, {"Out"}}}, diff --git a/src/common/types.h b/src/common/types.h index 1daf9c9b7bccfc8bcb584e5a37f920539736a911..627b7efac6ef6bee6dc96295e63ee8f0f96b7932 100644 --- a/src/common/types.h +++ b/src/common/types.h @@ -16,6 +16,7 @@ limitations under the License. */ #include #include +#include #include namespace paddle_mobile { @@ -82,6 +83,7 @@ extern const std::string G_OP_TYPE_FC; extern const std::string G_OP_TYPE_FUSION_CONV_ADD; extern const std::string G_OP_TYPE_FUSION_CONV_ADD_BN_RELU; extern const std::string G_OP_TYPE_FUSION_DWCONV_BN_RELU; +extern const std::string G_OP_TYPE_FUSION_CONV_BN_RELU; extern const std::string G_OP_TYPE_LRN; extern const std::string G_OP_TYPE_MUL; diff --git a/src/framework/operator.cpp b/src/framework/operator.cpp index 36b4663cb603d29bb60cfc297899d1c300e8ca91..765103c241a82ac224d707340f8b66ace827e335 100644 --- a/src/framework/operator.cpp +++ b/src/framework/operator.cpp @@ -28,6 +28,16 @@ vector OperatorBase::GetOutKeys() const { return it->second.second; } +template +vector OperatorBase::GetInputKeys() const { + auto it = op_input_output_key.find(type_); + if (it == op_input_output_key.end()) { + DLOG << type_ << " has no outputs"; + return {}; + } + return it->second.first; +} + template OperatorBase::OperatorBase(const std::string &type, const VariableNameMap &inputs, @@ -49,6 +59,11 @@ template void OperatorBase::Run() const { RunImpl(); #ifdef PADDLE_MOBILE_DEBUG + vector input_keys = GetInputKeys(); + for (const auto key : input_keys) { + Tensor *input = GetVarValue(key, inputs_, *scope_); + DLOG << type_ << " input- " << key << "=" << *input; + } vector output_keys = GetOutKeys(); for (const auto key : output_keys) { Tensor *out_ = GetVarValue(key, outputs_, *scope_); diff --git a/src/framework/operator.h b/src/framework/operator.h index 793551b0cd3eea290243c156c27616a34c37a3d2..084ac3c81185fe489fe1ca67589c1e8edb1d4fdf 100644 --- a/src/framework/operator.h +++ b/src/framework/operator.h @@ -61,6 +61,7 @@ class OperatorBase { virtual ~OperatorBase() {} void Run() const; std::vector GetOutKeys() const; + std::vector GetInputKeys() const; virtual void RunImpl() const = 0; virtual void Init() = 0; @@ -118,6 +119,10 @@ class OperatorWithKernel : public OperatorBase { virtual void InferShape() const = 0; void Init() { + // for (auto i : this->inputs_) { + // DLOG << i.first; + // DLOG << i.second; + // } PADDLE_MOBILE_ENFORCE(kernel_.Init(¶m_), " %s kernel init failed", this->type_.c_str()); } @@ -146,7 +151,7 @@ class OpKernelBase { } #endif virtual void Compute(const P ¶) const = 0; - virtual bool Init(P *para) { return true; }; + virtual bool Init(P *para) { return true; } virtual ~OpKernelBase() = default; private: diff --git a/src/operators/fusion_conv_add.h b/src/operators/fusion_conv_add.h index fd63efa8f6172ad244c8e61619ef286dc3ffa1de..170df9ce33e4ab90297664fbc81d723e7c246f83 100644 --- a/src/operators/fusion_conv_add.h +++ b/src/operators/fusion_conv_add.h @@ -66,11 +66,11 @@ class FusionConvAddOp : public framework::OperatorWithKernel< #ifdef PADDLE_MOBILE_CPU -//#ifndef CONV_ADD_REGISTER -// static framework::FusionOpRegistrar convadd_registrar( -// new FusionConvAddMatcher()); -//#define CONV_ADD_REGISTER -//#endif +#ifndef CONV_ADD_REGISTER +static framework::FusionOpRegistrar convadd_registrar( + new FusionConvAddMatcher()); +#define CONV_ADD_REGISTER +#endif #endif diff --git a/src/operators/fusion_conv_bn_relu_op.cpp b/src/operators/fusion_conv_bn_relu_op.cpp new file mode 100644 index 0000000000000000000000000000000000000000..49fe9c933a5a9695f2c18bd0921c2d36063dc065 --- /dev/null +++ b/src/operators/fusion_conv_bn_relu_op.cpp @@ -0,0 +1,60 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#ifdef FUSION_CONVBNRELU_OP + +#include "operators/fusion_conv_bn_relu_op.h" +#include "operators/math/conv_func.h" + +namespace paddle_mobile { +namespace operators { + +template +void FusionConvBNReluOp::InferShape() const { + auto in_dims = this->param_.Input()->dims(); + auto filter_dims = this->param_.Filter()->dims(); + const std::vector &strides = this->param_.Strides(); + std::vector paddings = this->param_.Paddings(); + int groups = this->param_.Groups(); + std::vector dilations = this->param_.Dilations(); + + PADDLE_MOBILE_ENFORCE((in_dims.size() == filter_dims.size() && + dilations.size() == paddings.size() && + paddings.size() == strides.size()), + "ConvParam is not suitable"); + + std::vector output_shape({in_dims[0], filter_dims[0]}); + for (size_t i = 0; i < strides.size(); ++i) { + output_shape.push_back( + math::ConvOutputSize(in_dims[i + 2], filter_dims[i + 2], dilations[i], + paddings[i], strides[i])); + } + + framework::DDim ddim = framework::make_ddim(output_shape); + this->param_.Output()->Resize(ddim); +} + +} // namespace operators +} // namespace paddle_mobile + +namespace ops = paddle_mobile::operators; +#ifdef PADDLE_MOBILE_CPU +REGISTER_OPERATOR_CPU(fusion_conv_bn_relu, ops::FusionConvBNReluOp); +#endif +#ifdef PADDLE_MOBILE_MALI_GPU +#endif +#ifdef PADDLE_MOBILE_FPGA +#endif + +#endif diff --git a/src/operators/fusion_conv_bn_relu_op.h b/src/operators/fusion_conv_bn_relu_op.h new file mode 100644 index 0000000000000000000000000000000000000000..4c2c1033ac0a4d6c8e3bc3f188a66884dd9e0642 --- /dev/null +++ b/src/operators/fusion_conv_bn_relu_op.h @@ -0,0 +1,103 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#ifdef FUSION_CONVBNRELU_OP + +#pragma once + +#include +#include +#include "framework/operator.h" +#include "framework/program/program-optimize/fusion_op_register.h" +#include "operators/kernel/conv_bn_relu_kernel.h" +#include "operators/op_param.h" + +namespace paddle_mobile { +namespace operators { +using std::string; +using std::vector; +class FusionConvBNReluMatcher : public framework::FusionOpMatcher { + public: + FusionConvBNReluMatcher() { + node_ = framework::Node(G_OP_TYPE_CONV); + node_ > std::make_shared(G_OP_TYPE_BATCHNORM) > + std::make_shared(G_OP_TYPE_RELU); + } + + void FolderNodes( + framework::Node *node, + std::vector> *removed_nodes) { + node->Folder(node_.Depth(), Type(), + {{G_OP_TYPE_BATCHNORM, + {{"Scale", "Scale"}, + {"Mean", "Mean"}, + {"Bias", "Bias"}, + {"Variance", "Variance"}}}}, + removed_nodes); + } + + std::string Type() { return G_OP_TYPE_FUSION_CONV_BN_RELU; } +}; + +template +class FusionConvBNReluOp : public framework::OperatorWithKernel< + DeviceType, FusionConvBNReluParam, + operators::ConvBNReluKernel> { + public: + FusionConvBNReluOp(const string &type, const VariableNameMap &inputs, + const VariableNameMap &outputs, + const framework::AttributeMap &attrs, + std::shared_ptr scope) + : framework::OperatorWithKernel< + DeviceType, FusionConvBNReluParam, + operators::ConvBNReluKernel>(type, inputs, outputs, + attrs, scope) {} + + using framework::OperatorWithKernel< + DeviceType, FusionConvBNReluParam, + operators::ConvBNReluKernel>::OperatorWithKernel; + void InferShape() const override; + + protected: +}; + +#ifdef PADDLE_MOBILE_CPU + +#ifndef FUSION_CONV_BN_RELU_REGISTER +static framework::FusionOpRegistrar fusion_conv_bn_relu_registrar( + new FusionConvBNReluMatcher()); +#define FUSION_CONV_BN_RELU_REGISTER +#endif + +#endif + +#ifdef PADDLE_MOBILE_MALI_GPU + +#endif + +#ifdef PADDLE_MOBILE_FPGA +#endif + +} // namespace operators +} // namespace paddle_mobile + +#ifdef PADDLE_MOBILE_CPU +USE_OP_CPU(fusion_conv_bn_relu); +#endif +#ifdef PADDLE_MOBILE_MALI_GPU +#endif +#ifdef PADDLE_MOBILE_FPGA +#endif + +#endif diff --git a/src/operators/kernel/arm/conv_bn_relu_kernel.cpp b/src/operators/kernel/arm/conv_bn_relu_kernel.cpp new file mode 100644 index 0000000000000000000000000000000000000000..23f06c1f0b8a0ed3f22ca9d23d24ae44c59f3618 --- /dev/null +++ b/src/operators/kernel/arm/conv_bn_relu_kernel.cpp @@ -0,0 +1,68 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#ifdef FUSION_CONVBNRELU_OP + +#include "operators/kernel/conv_bn_relu_kernel.h" +#include "operators/kernel/central-arm-func/conv_bn_relu_arm_func.h" + +namespace paddle_mobile { +namespace operators { + +template <> +bool ConvBNReluKernel::Init(FusionConvBNReluParam *param) { + const Tensor *mean = param->InputMean(); + const Tensor *variance = param->InputVariance(); + const Tensor *scale = param->InputScale(); + const Tensor *bias = param->InputBias(); + const float epsilon = param->Epsilon(); + + // DLOG << "variance: " << *variance; + + auto mean_ptr = mean->data(); + auto variance_ptr = variance->data(); + auto scale_ptr = scale->data(); + auto bias_ptr = bias->data(); + + const int C = mean->numel(); + float inv_std_ptr[C]; + for (int i = 0; i < C; i++) { + inv_std_ptr[i] = + 1 / static_cast(pow((variance_ptr[i] + epsilon), 0.5)); + } + Tensor *new_scale = new Tensor(); + Tensor *new_bias = new Tensor(); + auto new_scale_ptr = new_scale->mutable_data({C}); + auto new_bias_ptr = new_bias->mutable_data({C}); + for (int i = 0; i < C; i++) { + new_scale_ptr[i] = inv_std_ptr[i] * scale_ptr[i]; + new_bias_ptr[i] = bias_ptr[i] - mean_ptr[i] * inv_std_ptr[i] * scale_ptr[i]; + } + + param->SetNewScale(new_scale); + param->SetNewBias(new_bias); + return true; +} + +template <> +void ConvBNReluKernel::Compute( + const FusionConvBNReluParam ¶m) const { + ConvBNReluCompute(param); +} +template class ConvBNReluKernel; + +} // namespace operators +} // namespace paddle_mobile + +#endif diff --git a/src/operators/kernel/central-arm-func/batchnorm_arm_func.h b/src/operators/kernel/central-arm-func/batchnorm_arm_func.h index b2af17eb4aaf0a7ef98442f589162a3b6f371a3b..cc591035065e4cbbe71ff8f6bd6cbab9c6fe9e79 100644 --- a/src/operators/kernel/central-arm-func/batchnorm_arm_func.h +++ b/src/operators/kernel/central-arm-func/batchnorm_arm_func.h @@ -54,7 +54,40 @@ void BatchnormCompute(const BatchNormParam ¶m) { int HXW = H * W; -#ifdef ARMV7 +#if __ARM_NEON +#if __aarch64__ + float *inv_std_ptr = new float[C]; + for (int i = 0; i < C; i++) { + inv_std_ptr[i] = + 1 / static_cast(pow((variance_ptr[i] + epsilon), 0.5)); + } + + Tensor new_scale; + auto new_scale_ptr = new_scale.mutable_data(framework::make_ddim({C})); + Tensor new_bias; + auto new_bias_ptr = new_bias.mutable_data(framework::make_ddim({C})); + + /// ((x - est_mean) * (inv_var) * scale + bias equal to + /// (x * inv_var * scale) + (bias - est_mean * inv_var * scale) + for (int i = 0; i < C; i++) { + new_scale_ptr[i] = inv_std_ptr[i] * scale_ptr[i]; + new_bias_ptr[i] = bias_ptr[i] - mean_ptr[i] * inv_std_ptr[i] * scale_ptr[i]; + { + for (int n = 0; n < N; n++) { + for (int h = 0; h < H; h++) { + int tmp_index = n * stride0 + i * stride1 + h * stride2; + for (int w = 0; w < W; w++) { + int index = tmp_index + w; + out_ptr[index] = + input_x_ptr[index] * new_scale_ptr[i] + new_bias_ptr[i]; + } + } + } + } + } + delete[] inv_std_ptr; +#else + if (HXW > 32) { int NXC = N * C; float *inv_std_ptr = new float[NXC * 4]; @@ -229,6 +262,7 @@ void BatchnormCompute(const BatchNormParam ¶m) { delete[] inv_std_ptr; } +#endif #else float *inv_std_ptr = new float[C]; for (int i = 0; i < C; i++) { diff --git a/src/operators/kernel/central-arm-func/conv_bn_relu_arm_func.h b/src/operators/kernel/central-arm-func/conv_bn_relu_arm_func.h new file mode 100644 index 0000000000000000000000000000000000000000..f18d67749b96cd0ee2d84c2731af8a2c3e136db1 --- /dev/null +++ b/src/operators/kernel/central-arm-func/conv_bn_relu_arm_func.h @@ -0,0 +1,139 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#ifdef FUSION_CONVBNRELU_OP + +#pragma once +#include +#include "operators/math/depthwise_conv_3x3.h" +#include "operators/op_param.h" +namespace paddle_mobile { +namespace operators { +void ConvBNReluBasic(const FusionConvBNReluParam ¶m) { + const Tensor *input = param.Input(); + Tensor filter = *param.Filter(); + Tensor new_bias = *param.NewBias(); + Tensor new_scale = *param.NewScale(); + + Tensor *output = param.Output(); + + int groups = param.Groups(); + std::vector strides = param.Strides(); + std::vector paddings = param.Paddings(); + std::vector dilations = param.Dilations(); + + const int batch_size = static_cast(input->dims()[0]); + + std::vector filter_shape_vec(framework::vectorize(filter.dims())); + + std::vector output_shape_vec(framework::vectorize(output->dims())); + size_t data_dim = filter_shape_vec.size() - 2; + std::vector col_shape_vec(1 + 2 * data_dim); + col_shape_vec[0] = input->dims()[1] / groups; + for (size_t j = 0; j < data_dim; ++j) { + col_shape_vec[j + 1] = filter_shape_vec[j + 2]; + col_shape_vec[j + 1 + data_dim] = output_shape_vec[j + 2]; + } + framework::DDim col_shape(framework::make_ddim(col_shape_vec)); + + framework::DDim col_matrix_shape = + framework::flatten_to_2d(col_shape, data_dim + 1); + + bool is_expand = + math::IsExpand(filter_shape_vec, strides, paddings, dilations); + Tensor col; + Tensor col_matrix; + if (is_expand) { + col.mutable_data(col_shape); + col_matrix.ShareDataWith(col); + col_matrix.Resize(col_matrix_shape); + } + + framework::DDim input_shape = framework::slice_ddim( + input->dims(), 1, static_cast(input->dims().size())); + + framework::DDim filter_matrix_shape = {filter.dims()[0], + filter.numel() / filter.dims()[0]}; + filter.Resize(filter_matrix_shape); + framework::DDim output_matrix_shape = { + output->dims()[1], + output->numel() / (output->dims()[0] * output->dims()[1])}; + + // convolution operator: im2col(or vol2col) + gemm + int in_step = static_cast(input->dims()[1]) / groups; + int out_step = static_cast(output->dims()[1]) / groups; + + math::Vol2ColFunctor vol2col; + math::Im2ColFunctor im2col; + + for (int i = 0; i < batch_size; i++) { + Tensor in_batch = input->Slice(i, i + 1).Resize(input_shape); + Tensor out_batch = output->Slice(i, i + 1).Resize(output_matrix_shape); + + for (int g = 0; g < groups; g++) { + Tensor in_slice = in_batch.Slice(g * in_step, (g + 1) * in_step); + + if (!is_expand) { + col.ShareDataWith(in_slice); + col_matrix.ShareDataWith(col); + col_matrix.Resize(col_matrix_shape); + } else if (data_dim == 2U) { + // im2col + im2col(in_slice, dilations, strides, + std::vector{paddings[0], paddings[1], paddings[0], + paddings[1]}, + &col); + } else if (data_dim == 3U) { + // vol2col + vol2col(in_slice, dilations, strides, paddings, &col); + } + // gemm + Tensor out_slice = out_batch.Slice(g * out_step, (g + 1) * out_step); + Tensor filter_slice = filter.Slice(g * out_step, (g + 1) * out_step); + + math::matmulWithBn( + filter_slice, false, col_matrix, false, static_cast(1), + &out_slice, static_cast(0), true, &new_scale, &new_bias, g); + } + } +} + +template +void ConvBNReluCompute(const FusionConvBNReluParam ¶m) { + if (param.Groups() == param.Input()->dims()[1] && + param.Input()->dims()[1] == param.Output()->dims()[1] && + param.Filter()->dims()[2] == param.Filter()->dims()[3] && + param.Filter()->dims()[2] == 3 && param.Strides()[0] == 1) { + math::DepthwiseConvAddBNRelu3x3s1p1(param.Input(), param.Filter(), + param.Output(), param.NewScale(), + param.NewBias(), true); + } else if (param.Groups() == param.Input()->dims()[1] && + param.Input()->dims()[1] == param.Output()->dims()[1] && + param.Filter()->dims()[2] == param.Filter()->dims()[3] && + param.Filter()->dims()[2] == 3 && param.Strides()[0] == 2) { + // math::DepthwiseConvAddBNRelu3x3s2p1(param.Input(), param.Filter(), + // param.Output(), param.NewScale(), + // param.NewBias(), 1); + math::DepthwiseConvAddBNRelu3x3s2p1v2(param.Input(), param.Filter(), + param.Output(), param.NewScale(), + param.NewBias(), true); + } else { + ConvBNReluBasic(param); + } +} + +} // namespace operators +} // namespace paddle_mobile + +#endif diff --git a/src/operators/kernel/central-arm-func/pool_arm_func.h b/src/operators/kernel/central-arm-func/pool_arm_func.h index 892dca2ea40d40484b4c32a57f8633849cc9d038..12aa01507d83c5051f4b462fc4607d71a707f06d 100644 --- a/src/operators/kernel/central-arm-func/pool_arm_func.h +++ b/src/operators/kernel/central-arm-func/pool_arm_func.h @@ -76,7 +76,7 @@ void PoolCompute(const PoolParam ¶m) { } } else if (ksize[0] == 2 && ksize[0] == ksize[1]) { -#ifndef IOS +#if __ARM_NEON if (pooling_type == "max") { math::Pool2x2Max(strides, paddings, in_x, out); } else if (pooling_type == "avg") { @@ -84,7 +84,8 @@ void PoolCompute(const PoolParam ¶m) { } #else PoolBasic(pooling_type, ksize, strides, paddings, in_x, out); -#endif +#endif // __ARM_NEON + } else { PoolBasic(pooling_type, ksize, strides, paddings, in_x, out); } diff --git a/src/operators/kernel/central-arm-func/sigmoid_arm_func.h b/src/operators/kernel/central-arm-func/sigmoid_arm_func.h index daf6ad0e472515c8034a400dfc73de608f5b12d2..c612c4b092143ef8925f81a6d6fefe9cd9dff25b 100644 --- a/src/operators/kernel/central-arm-func/sigmoid_arm_func.h +++ b/src/operators/kernel/central-arm-func/sigmoid_arm_func.h @@ -68,6 +68,7 @@ void sigmoid(const Tensor *X, Tensor *Y) { input_outer_ptr++; } } +#else #endif } diff --git a/src/operators/kernel/conv_bn_relu_kernel.h b/src/operators/kernel/conv_bn_relu_kernel.h new file mode 100644 index 0000000000000000000000000000000000000000..c9d4df5d8f597deebaf2b53491851b7ce03fc7aa --- /dev/null +++ b/src/operators/kernel/conv_bn_relu_kernel.h @@ -0,0 +1,45 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#ifdef FUSION_CONVBNRELU_OP + +#include +#include "framework/ddim.h" +#include "framework/operator.h" +#include "operators/math/conv_func.h" +#include "operators/math/im2col.h" +#include "operators/math/math_function.h" +#include "operators/math/vol2col.h" +#include "operators/op_param.h" + +namespace paddle_mobile { +namespace operators { + +using framework::DDim; +using framework::OpKernelBase; + +template +class ConvBNReluKernel + : public OpKernelBase { + public: + void Compute(const FusionConvBNReluParam ¶m) const; + bool Init(FusionConvBNReluParam *param); +}; + +} // namespace operators +} // namespace paddle_mobile + +#endif diff --git a/src/operators/math/im2col.cpp b/src/operators/math/im2col.cpp index 625d120705aab8fcc3ea8d232b4077e213941ec4..7b0b974b542a83d381727128887bef8a48ce937f 100644 --- a/src/operators/math/im2col.cpp +++ b/src/operators/math/im2col.cpp @@ -15,7 +15,7 @@ limitations under the License. */ #include "operators/math/im2col.h" #include #ifdef __ARM_NEON -#include "arm_neon.h" +#include #endif #include "common/types.h" namespace paddle_mobile { @@ -69,7 +69,7 @@ class Im2ColFunctor { int channels_col = im_channels * filter_height * filter_width; const T *im_data = im.data(); T *col_data = col->data(); -#ifdef __ARM_NEON +#if __ARM_NEON const int osize = col_height; const int isize = im_height; bool pad1 = padding[0] > 0; diff --git a/src/operators/math/pool_2x2.cpp b/src/operators/math/pool_2x2.cpp index c86003f6f96b632efd50bbb156293510e3d8521c..0a2d96d4d065d7938e6872b4f073e080d7be8c3a 100644 --- a/src/operators/math/pool_2x2.cpp +++ b/src/operators/math/pool_2x2.cpp @@ -13,7 +13,9 @@ See the License for the specific language governing permissions and limitations under the License. */ #ifdef POOL_OP -#include "pool_2x2.h" +#include "operators/math/pool_2x2.h" +#include +#include namespace paddle_mobile { namespace operators { @@ -21,10 +23,10 @@ namespace math { void Pool2x2Max(vector strides, vector paddings, const Tensor *input, Tensor *output) { -#ifdef __ARM_NEON - -#ifdef ARMV7 +#if __ARM_NEON +#if __aarch64__ +#else const int batch_size = input->dims()[0]; const int input_height = input->dims()[2]; @@ -93,15 +95,16 @@ void Pool2x2Max(vector strides, vector paddings, const Tensor *input, output_data += output_batch_stride; } #endif - +#else #endif } void Pool2x2Avg(vector strides, vector paddings, const Tensor *input, Tensor *output) { -#ifdef __ARM_NEON +#if __ARM_NEON -#ifdef ARMV7 +#if __aarch64__ +#else const int batch_size = input->dims()[0]; const int input_height = input->dims()[2]; @@ -171,12 +174,9 @@ void Pool2x2Avg(vector strides, vector paddings, const Tensor *input, input_data += input_batch_stride; output_data += output_batch_stride; } -#else - -// TODO(): to imp other asm #endif - +#else #endif } diff --git a/src/operators/math/pool_3x3.cpp b/src/operators/math/pool_3x3.cpp index 28a8877355b2c2cc1221512884b5be1497bc4243..28547b71fca6caea2ff4341b3f832c0035436a72 100644 --- a/src/operators/math/pool_3x3.cpp +++ b/src/operators/math/pool_3x3.cpp @@ -17,7 +17,7 @@ limitations under the License. */ #include #endif #include "framework/tensor.h" -#include "pool_3x3.h" +#include "operators/math/pool_3x3.h" #if __ARM_NEON #include #endif // __ARM_NEON @@ -518,6 +518,8 @@ void Pool3x3Maxs1p1(const Tensor *input, Tensor *output) { input_data += input_batch_stride; out_data += output_batch_stride; } +#else + #endif } @@ -582,7 +584,18 @@ void Pool3x3Max(vector strides, vector paddings, const Tensor *input, } output_seg[ph * output_width + pw] = max_value; } else { -#if defined(ARMV7) +#if __aarch64__ + const float32x4_t data1 = vld1q_f32(pos1); + const float32x4_t data2 = vld1q_f32(pos1 + input_width); + const float32x4_t data3 = vld1q_f32(pos1 + 2 * input_width); + const float32x4_t max_data = + vmaxq_f32(vmaxq_f32(data1, data2), data3); + float32x2_t res = + vpmax_f32(vget_high_f32(vsetq_lane_f32(-INT_MAX, max_data, 3)), + vget_low_f32(max_data)); + res = vpmax_f32(res, res); + output_seg[ph * output_width + pw] = vget_lane_f32(res, 0); +#else asm volatile( "vld1.32 {q1}, [%[pos1]] \n\t" "vld1.32 {q2}, [%[pos2]] \n\t" @@ -598,17 +611,6 @@ void Pool3x3Max(vector strides, vector paddings, const Tensor *input, [pos2] "r"(pos2), [pos3] "r"(pos3), [output_ptr] "r"(output_ptr), [negative_max] "r"(negative_max) : "memory", "q1", "q2", "q3", "q4"); -#else - const float32x4_t data1 = vld1q_f32(pos1); - const float32x4_t data2 = vld1q_f32(pos1 + input_width); - const float32x4_t data3 = vld1q_f32(pos1 + 2 * input_width); - const float32x4_t max_data = - vmaxq_f32(vmaxq_f32(data1, data2), data3); - float32x2_t res = - vpmax_f32(vget_high_f32(vsetq_lane_f32(-INT_MAX, max_data, 3)), - vget_low_f32(max_data)); - res = vpmax_f32(res, res); - output_seg[ph * output_width + pw] = vget_lane_f32(res, 0); #endif } } @@ -676,8 +678,8 @@ void Pool3x3Avg(vector strides, vector paddings, const Tensor *input, } output_seg[ph * output_width + pw] = sum / 9.0; } else { -#if defined(ARMV7) - +#if __aarch64__ +#else asm volatile( "vld1.32 {q1}, [%[pos1]] \n\t" "vld1.32 {q2}, [%[pos2]] \n\t" @@ -696,7 +698,7 @@ void Pool3x3Avg(vector strides, vector paddings, const Tensor *input, [output_ptr] "r"(output_ptr), [zero] "r"(zero), [nine_ptr] "r"(nine_ptr) : "memory", "r6", "q1", "q2", "q3", "q4"); -#else +#endif const float32x4_t data1 = vld1q_f32(pos1); const float32x4_t data2 = vld1q_f32(pos2); const float32x4_t data3 = vld1q_f32(pos3); @@ -707,7 +709,6 @@ void Pool3x3Avg(vector strides, vector paddings, const Tensor *input, vget_low_f32(sum_data)); res = vpadd_f32(res, res); output_seg[ph * output_width + pw] = vget_lane_f32(res, 0) / 9.0; -#endif } } } @@ -715,6 +716,7 @@ void Pool3x3Avg(vector strides, vector paddings, const Tensor *input, input_data += input_batch_stride; output_data += output_batch_stride; } +#else #endif } } // namespace math diff --git a/src/operators/math/softmax.cpp b/src/operators/math/softmax.cpp index 968915f21e08fce9f25ceb63831ee40ecba9cee6..dba88c93969014f2ad0d2636b4141c734dbc2ed5 100644 --- a/src/operators/math/softmax.cpp +++ b/src/operators/math/softmax.cpp @@ -135,6 +135,7 @@ class SoftmaxFuntor { } } } +#else #endif // ARM_NEON public: diff --git a/src/operators/op_param.h b/src/operators/op_param.h index 390de2d5cfc09381112b66d58044e307275ac994..4b95ceb18740531919c4ef00dfdd912b1067e891 100644 --- a/src/operators/op_param.h +++ b/src/operators/op_param.h @@ -1078,7 +1078,7 @@ class FusionDWConvBNReluParam : public OpParam { input_variance_ = InputVarianceFrom(inputs, scope); epsilon_ = GetAttr("epsilon", attrs); momentum_ = GetAttr("momentum", attrs); - is_test_ = GetAttr("is_test", attrs); + // is_test_ = GetAttr("is_test", attrs); } const Tensor *Input() const { return input_; } @@ -1139,6 +1139,85 @@ class FusionDWConvBNReluParam : public OpParam { Print &operator<<(Print &printer, const FusionConvAddParam &conv_param); #endif +#ifdef FUSION_CONVBNRELU_OP +class FusionConvBNReluParam : public OpParam { + public: + FusionConvBNReluParam(const VariableNameMap &inputs, + const VariableNameMap &outputs, + const AttributeMap &attrs, const Scope &scope) { + filter_ = FilterFrom(inputs, scope); + input_ = InputFrom(inputs, scope); + output_ = OutFrom(outputs, scope); + + strides_ = GetAttr>("strides", attrs); + paddings_ = GetAttr>("paddings", attrs); + dilations_ = GetAttr>("dilations", attrs); + groups = GetAttr("groups", attrs); + input_bias_ = InputBiasFrom(inputs, scope); + input_mean_ = InputMeanFrom(inputs, scope); + input_scale_ = InputScaleFrom(inputs, scope); + input_variance_ = InputVarianceFrom(inputs, scope); + epsilon_ = GetAttr("epsilon", attrs); + momentum_ = GetAttr("momentum", attrs); + // is_test_ = GetAttr("is_test", attrs); + } + + const Tensor *Input() const { return input_; } + + const Tensor *Filter() const { return filter_; } + + Tensor *Output() const { return output_; } + + const vector &Strides() const { return strides_; } + + const vector &Paddings() const { return paddings_; } + + const vector &Dilations() const { return dilations_; } + + const int &Groups() const { return groups; } + + const Tensor *InputBias() const { return input_bias_; } + + const Tensor *InputMean() const { return input_mean_; } + + const Tensor *InputScale() const { return input_scale_; } + + const Tensor *InputVariance() const { return input_variance_; } + + const float &Epsilon() const { return epsilon_; } + + const float &Momentum() const { return momentum_; } + + const bool &IsTest() const { return is_test_; } + + void SetNewScale(Tensor *new_scale) { new_scale_ = new_scale; } + + void SetNewBias(Tensor *new_bias) { new_bias_ = new_bias; } + + const Tensor *NewScale() const { return new_scale_; } + + const Tensor *NewBias() const { return new_bias_; } + + protected: + Tensor *input_; + Tensor *output_; + Tensor *filter_; + vector strides_; + vector paddings_; + vector dilations_; + int groups; + Tensor *input_bias_; + Tensor *input_mean_; + Tensor *input_scale_; + Tensor *input_variance_; + float epsilon_; + float momentum_; + bool is_test_; + Tensor *new_bias_; + Tensor *new_scale_; +}; +#endif + #ifdef IM2SEQUENCE_OP class Im2SequenceParam : public OpParam { public: diff --git a/test/framework/test_load.cpp b/test/framework/test_load.cpp index f4215de46c2bafd732b0092b58c25bf6fcefdf7a..bea7d4ba7d2df1344f0819222fbdb389106fa77e 100644 --- a/test/framework/test_load.cpp +++ b/test/framework/test_load.cpp @@ -19,7 +19,9 @@ int main() { paddle_mobile::Loader loader; // ../../../test/models/googlenet // ../../../test/models/mobilenet - auto program = loader.Load(g_googlenet, true); + // auto program = loader.Load(g_googlenet, true); + + auto program = loader.Load(g_mobilenet_ssd, true); // auto program = loader.Load(g_googlenet_combine + "/model", // g_googlenet_combine + // "/params", true); diff --git a/test/net/test_googlenet.cpp b/test/net/test_googlenet.cpp index 2ab24736397c1e71350335561abbcabcba6e27a4..d230b9469229946fc74f4dc9e1ee6100196ed9aa 100644 --- a/test/net/test_googlenet.cpp +++ b/test/net/test_googlenet.cpp @@ -23,7 +23,7 @@ int main() { auto time1 = time(); if (paddle_mobile.Load(g_googlenet, optimize)) { auto time2 = time(); - DLOG << "load cost :" << time_diff(time1, time1) << "ms"; + DLOG << "load cost: " << time_diff(time1, time1) << "ms"; std::vector input; std::vector dims{1, 3, 224, 224}; GetInput(g_test_image_1x3x224x224, &input, dims); diff --git a/test/net/test_mobilenet+ssd.cpp b/test/net/test_mobilenet+ssd.cpp index 1a7c4cd49cb1707b9c7783cf74e87e74da39732e..a3aac63f5759923df5bc60df556241c6e15c3eb4 100644 --- a/test/net/test_mobilenet+ssd.cpp +++ b/test/net/test_mobilenet+ssd.cpp @@ -12,16 +12,20 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#include +#include #include "../test_helper.h" #include "../test_include.h" int main() { paddle_mobile::PaddleMobile paddle_mobile; auto time1 = time(); - if (paddle_mobile.Load(g_mobilenet_ssd, true)) { + // auto isok = paddle_mobile.Load(g_mobilenet_ssd_gesture + "/model", + // g_mobilenet_ssd_gesture + "/params", + // true); + auto isok = paddle_mobile.Load(g_mobilenet_ssd, false); + if (isok) { auto time2 = time(); - DLOG << "load cost :" << time_diff(time1, time1) << "ms"; + std::cout << "load cost :" << time_diff(time1, time2) << "ms" << std::endl; std::vector dims{1, 3, 300, 300}; Tensor input_tensor; @@ -33,7 +37,8 @@ int main() { auto time3 = time(); paddle_mobile.Predict(input, dims); auto time4 = time(); - DLOG << "predict cost :" << time_diff(time3, time4) << "ms"; + std::cout << "predict cost :" << time_diff(time3, time4) << "ms" + << std::endl; } return 0; } diff --git a/test/test_helper.h b/test/test_helper.h index 81ad23ff3b4e53db0225630eebaa34878ad4c139..fb6724f9c5764497ec81de0d73406709f098e0e0 100644 --- a/test/test_helper.h +++ b/test/test_helper.h @@ -16,6 +16,8 @@ limitations under the License. */ #include #include +#include +#include #include "common/common.h" #include "common/log.h" @@ -23,6 +25,8 @@ limitations under the License. */ #include "framework/tensor.h" static const std::string g_mobilenet_ssd = "../models/mobilenet+ssd"; +static const std::string g_mobilenet_ssd_gesture = + "../models/mobilenet+ssd_gesture"; static const std::string g_squeezenet = "../models/squeezenet"; static const std::string g_googlenet = "../models/googlenet"; static const std::string g_mobilenet = "../models/mobilenet"; @@ -62,9 +66,9 @@ void GetInput(const std::string &input_name, std::vector *input, size *= dim; } - T *input_ptr = (T *)malloc(sizeof(T) * size); + T *input_ptr = reinterpret_cast(malloc(sizeof(T) * size)); std::ifstream in(input_name, std::ios::in | std::ios::binary); - in.read((char *)(input_ptr), size * sizeof(T)); + in.read(reinterpret_cast(input_ptr), size * sizeof(T)); in.close(); for (int i = 0; i < size; ++i) { input->push_back(input_ptr[i]); @@ -79,6 +83,6 @@ void GetInput(const std::string &input_name, T *input_ptr = input->mutable_data(dims); std::ifstream in(input_name, std::ios::in | std::ios::binary); - in.read((char *)(input_ptr), input->numel() * sizeof(T)); + in.read(reinterpret_cast(input_ptr), input->numel() * sizeof(T)); in.close(); } diff --git a/tools/op.cmake b/tools/op.cmake index 456d36262e9abf997a7861838c870e698d64f3c1..feccaabbeb122d40cc5e0687a8420e147a98b1cf 100644 --- a/tools/op.cmake +++ b/tools/op.cmake @@ -65,6 +65,7 @@ else () set(FUSION_CONVADD_RELU_OP ON) set(FUSION_CONVADDBNRELU_OP ON) set(FUSION_DWCONVBNRELU_OP ON) + set(FUSION_CONVBNRELU_OP ON) set(PRELU_OP ON) set(RESIZE_OP ON) set(SCALE_OP ON) @@ -159,6 +160,11 @@ endif() if (FUSION_DWCONVBNRELU_OP) add_definitions(-DFUSION_DWCONVBNRELU_OP) endif() + +if (FUSION_CONVBNRELU_OP) + add_definitions(-DFUSION_CONVBNRELU_OP) +endif() + if (PRELU_OP) add_definitions(-DPRELU_OP) endif()