diff --git a/src/common/types.cpp b/src/common/types.cpp index 372331ad32244ca43ebad929b2918002f7fe42bd..6503f6383d22c7342c7446c44fab436810a7c46f 100644 --- a/src/common/types.cpp +++ b/src/common/types.cpp @@ -26,6 +26,7 @@ const char *G_OP_TYPE_FUSION_CONV_ADD_RELU = "fusion_conv_add_relu"; const char *G_OP_TYPE_FUSION_CONV_ADD_PRELU = "fusion_conv_add_prelu"; const char *G_OP_TYPE_FUSION_CONV_ADD_ADD_PRELU = "fusion_conv_add_add_prelu"; const char *G_OP_TYPE_FUSION_CONV_ADD_BN_RELU = "fusion_conv_add_bn_relu"; +const char *G_OP_TYPE_FUSION_CONV_BN_ADD_RELU = "fusion_conv_bn_add_relu"; const char *G_OP_TYPE_FUSION_DWCONV_BN_RELU = "fusion_dwconv_bn_relu"; const char *G_OP_TYPE_FUSION_CONV_BN_RELU = "fusion_conv_bn_relu"; const char *G_OP_TYPE_FC = "fusion_fc"; @@ -79,6 +80,7 @@ std::unordered_map< {G_OP_TYPE_BOX_CODER, {{"PriorBox", "PriorBoxVar", "TargetBox"}, {"OutputBox"}}}, {G_OP_TYPE_FUSION_CONV_ADD_BN_RELU, {{"Input"}, {"Out"}}}, + {G_OP_TYPE_FUSION_CONV_BN_ADD_RELU, {{"Input"}, {"Out"}}}, {G_OP_TYPE_PRIOR_BOX, {{"Image", "Input"}, {"Boxes", "Variances"}}}, {G_OP_TYPE_MULTICLASS_NMS, {{"BBoxes", "Scores"}, {"Out"}}}, {G_OP_TYPE_FC, {{"X", "Y", "Z"}, {"Out"}}}, diff --git a/src/common/types.h b/src/common/types.h index dcbea1132866d7c0dadfc8a5c308bf837f3abbcf..6d38e4178907aa30968a6760a6ae5d69f4b61167 100644 --- a/src/common/types.h +++ b/src/common/types.h @@ -90,6 +90,7 @@ extern const char *G_OP_TYPE_FUSION_CONV_ADD_ADD_PRELU; extern const char *G_OP_TYPE_FC; extern const char *G_OP_TYPE_FUSION_CONV_ADD; extern const char *G_OP_TYPE_FUSION_CONV_ADD_BN_RELU; +extern const char *G_OP_TYPE_FUSION_CONV_BN_ADD_RELU; extern const char *G_OP_TYPE_FUSION_DWCONV_BN_RELU; extern const char *G_OP_TYPE_FUSION_CONV_BN_RELU; diff --git a/src/io/loader.cpp b/src/io/loader.cpp index 7a0912106db8acaafb751d2c467245ee302c84e6..48a2b5cfdaa5f53cd9611dd0be1ce3df05988311 100644 --- a/src/io/loader.cpp +++ b/src/io/loader.cpp @@ -33,9 +33,7 @@ void InitMemoryFromProgram( for (const auto &var_desc : block->Vars()) { auto var = scope.get()->Var(var_desc->Name()); if (var_desc->Type() == framework::VARTYPE_TYPE_LOD_TENSOR) { - if (var_desc->Persistable() && - var_desc->Type() != framework::VARTYPE_TYPE_FEED_MINIBATCH && - var_desc->Type() != framework::VARTYPE_TYPE_FETCH_LIST) { + if (var_desc->Persistable()) { auto dim = var_desc->Tensor_desc().Dims(); auto tensor = var->GetMutable(); tensor->Resize(framework::make_ddim(dim)); diff --git a/src/operators/fusion_conv_bn_add_relu_op.cpp b/src/operators/fusion_conv_bn_add_relu_op.cpp new file mode 100644 index 0000000000000000000000000000000000000000..9823a3111e54f5aec90d5518073ca52255706c1a --- /dev/null +++ b/src/operators/fusion_conv_bn_add_relu_op.cpp @@ -0,0 +1,61 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
*/
+
+#ifdef FUSION_CONVBNADDRELU_OP
+
+#include "operators/fusion_conv_bn_add_relu_op.h"
+#include "operators/math/conv_func.h"
+
+namespace paddle_mobile {
+namespace operators {
+
+template <typename Dtype, typename T>
+void FusionConvBNAddReluOp<Dtype, T>::InferShape() const {
+  auto in_dims = this->param_.Input()->dims();
+  auto filter_dims = this->param_.Filter()->dims();
+  const std::vector<int> &strides = this->param_.Strides();
+  std::vector<int> paddings = this->param_.Paddings();
+  int groups = this->param_.Groups();
+  std::vector<int> dilations = this->param_.Dilations();
+
+  PADDLE_MOBILE_ENFORCE((in_dims.size() == filter_dims.size() &&
+                         dilations.size() == paddings.size() &&
+                         paddings.size() == strides.size()),
+                        "ConvParam is not suitable");
+
+  std::vector<int64_t> output_shape({in_dims[0], filter_dims[0]});
+  for (size_t i = 0; i < strides.size(); ++i) {
+    output_shape.push_back(
+        math::ConvOutputSize(in_dims[i + 2], filter_dims[i + 2], dilations[i],
+                             paddings[i], strides[i]));
+  }
+
+  framework::DDim ddim = framework::make_ddim(output_shape);
+  this->param_.Output()->Resize(ddim);
+}
+
+}  // namespace operators
+}  // namespace paddle_mobile
+
+namespace ops = paddle_mobile::operators;
+#ifdef PADDLE_MOBILE_CPU
+REGISTER_OPERATOR_CPU(fusion_conv_bn_add_relu, ops::FusionConvBNAddReluOp);
+#endif
+#ifdef PADDLE_MOBILE_MALI_GPU
+#endif
+#ifdef PADDLE_MOBILE_FPGA
+REGISTER_OPERATOR_FPGA(fusion_conv_bn_add_relu, ops::FusionConvBNAddReluOp);
+#endif
+
+#endif
diff --git a/src/operators/fusion_conv_bn_add_relu_op.h b/src/operators/fusion_conv_bn_add_relu_op.h
new file mode 100644
index 0000000000000000000000000000000000000000..62f3ccf37dfbff9720f39fb96b099f6d7eb5ddcc
--- /dev/null
+++ b/src/operators/fusion_conv_bn_add_relu_op.h
@@ -0,0 +1,125 @@
+/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
*/ + +#ifdef FUSION_CONVBNADDRELU_OP + +#pragma once + +#include +#include +#include "framework/operator.h" +#include "framework/program/program-optimize/fusion_op_register.h" +#include "op_param.h" +#include "operators/kernel/conv_bn_add_relu_kernel.h" + +namespace paddle_mobile { +namespace operators { +using std::string; +using std::vector; +class FusionConvBNAddReluMatcher : public framework::FusionOpMatcher { + public: + FusionConvBNAddReluMatcher() { + node_ = framework::Node(G_OP_TYPE_CONV); + node_ > std::make_shared(G_OP_TYPE_BATCHNORM) > + std::make_shared(G_OP_TYPE_ELEMENTWISE_ADD) > + std::make_shared(G_OP_TYPE_RELU); + } + + void FolderNodes( + framework::Node *node, + std::vector> *removed_nodes) { + node->Folder(node_.Depth(), Type(), + {{G_OP_TYPE_ELEMENTWISE_ADD, {{"Y", "Y"}, {"X", "X"}}}, + {G_OP_TYPE_BATCHNORM, + {{"Scale", "Scale"}, + {"Mean", "Mean"}, + {"Bias", "Bias"}, + {"Variance", "Variance"}, + {"Y", "BNY"}}}}, + removed_nodes); + } + + std::string Type() { return G_OP_TYPE_FUSION_CONV_BN_ADD_RELU; } + std::vector> NeedCheck() { + DLOG << " conv bn add relu check add X "; + return {{2, "Y"}, {2, "X"}}; + } +}; + +template +class FusionConvBNAddReluOp + : public framework::OperatorWithKernel< + DeviceType, FusionConvBNAddReluParam, + operators::ConvBNAddReluKernel> { + public: + FusionConvBNAddReluOp(const string &type, const VariableNameMap &inputs, + const VariableNameMap &outputs, + const framework::AttributeMap &attrs, + std::shared_ptr scope) + : framework::OperatorWithKernel< + DeviceType, FusionConvBNAddReluParam, + operators::ConvBNAddReluKernel>( + type, inputs, outputs, attrs, scope) {} + + using framework::OperatorWithKernel< + DeviceType, FusionConvBNAddReluParam, + operators::ConvBNAddReluKernel>::OperatorWithKernel; + void InferShape() const override; + + protected: +}; + +#ifdef PADDLE_MOBILE_CPU + +#ifndef FUSION_CONV_BN_ADD_RELU_REGISTER +static framework::FusionOpRegistrar fusion_conv_bn_add_relu_registrar( + new FusionConvBNAddReluMatcher()); +#define FUSION_CONV_BN_ADD_RELU_REGISTER +#endif + +#endif + +#ifdef PADDLE_MOBILE_MALI_GPU + +#ifndef FUSION_CONV_BN_ADD_RELU_REGISTER +static framework::FusionOpRegistrar fusion_conv_bn_add_relu_registrar( + new FusionConvBNAddReluMatcher()); +#define FUSION_CONV_BN_ADD_RELU_REGISTER +#endif + +#endif + +#ifdef PADDLE_MOBILE_FPGA + +#ifndef FUSION_CONV_BN_ADD_RELU_REGISTER +static framework::FusionOpRegistrar fusion_conv_bn_add_relu_registrar( + new FusionConvBNAddReluMatcher()); +#define FUSION_CONV_BN_ADD_RELU_REGISTER +#endif + +#endif + +} // namespace operators +} // namespace paddle_mobile + +#ifdef PADDLE_MOBILE_CPU +USE_OP_CPU(fusion_conv_bn_add_relu); +#endif +#ifdef PADDLE_MOBILE_MALI_GPU +#endif +#ifdef PADDLE_MOBILE_FPGA +USE_OP_FPGA(fusion_conv_bn_add_relu); +#endif + +#endif diff --git a/src/operators/kernel/arm/conv_bn_add_relu_kernel.cpp b/src/operators/kernel/arm/conv_bn_add_relu_kernel.cpp new file mode 100644 index 0000000000000000000000000000000000000000..785b13dde2ec1196792d17b253bb0d904da799f5 --- /dev/null +++ b/src/operators/kernel/arm/conv_bn_add_relu_kernel.cpp @@ -0,0 +1,66 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#ifdef FUSION_CONVBNADDRELU_OP + +#include "operators/kernel/conv_bn_add_relu_kernel.h" +#include "operators/kernel/central-arm-func/conv_bn_add_relu_arm_func.h" + +namespace paddle_mobile { +namespace operators { + +template <> +bool ConvBNAddReluKernel::Init( + FusionConvBNAddReluParam *param) { + const Tensor *mean = param->InputMean(); + const Tensor *variance = param->InputVariance(); + const Tensor *scale = param->InputScale(); + const Tensor *bias = param->InputBias(); + const float epsilon = param->Epsilon(); + + auto mean_ptr = mean->data(); + auto variance_ptr = variance->data(); + auto scale_ptr = scale->data(); + auto bias_ptr = bias->data(); + + const int C = mean->numel(); + float inv_std_ptr[C]; + for (int i = 0; i < C; i++) { + inv_std_ptr[i] = + 1 / static_cast(pow((variance_ptr[i] + epsilon), 0.5)); + } + Tensor *new_scale = new Tensor(); + Tensor *new_bias = new Tensor(); + auto new_scale_ptr = new_scale->mutable_data({C}); + auto new_bias_ptr = new_bias->mutable_data({C}); + for (int i = 0; i < C; i++) { + new_scale_ptr[i] = inv_std_ptr[i] * scale_ptr[i]; + new_bias_ptr[i] = bias_ptr[i] - mean_ptr[i] * inv_std_ptr[i] * scale_ptr[i]; + } + param->SetNewScale(new_scale); + param->SetNewBias(new_bias); + return true; +} + +template <> +void ConvBNAddReluKernel::Compute( + const FusionConvBNAddReluParam ¶m) const { + ConvBNAddReluCompute(param); +} +template class ConvBNAddReluKernel; + +} // namespace operators +} // namespace paddle_mobile + +#endif diff --git a/src/operators/kernel/central-arm-func/conv_bn_add_relu_arm_func.h b/src/operators/kernel/central-arm-func/conv_bn_add_relu_arm_func.h new file mode 100644 index 0000000000000000000000000000000000000000..7c31eed19693d20084e25daa485a0553d5d795f2 --- /dev/null +++ b/src/operators/kernel/central-arm-func/conv_bn_add_relu_arm_func.h @@ -0,0 +1,147 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
*/ + +#ifdef FUSION_CONVBNADDRELU_OP + +#pragma once + +#include +#include "operators/math/depthwise_conv_3x3.h" +#include "operators/math/im2col.h" +#include "operators/math/math_function.h" +#include "operators/math/vol2col.h" +#include "operators/op_param.h" + +namespace paddle_mobile { +namespace operators { +void ConvBNAddReluBasic(const FusionConvBNAddReluParam ¶m) { + const Tensor *input = param.Input(); + Tensor filter = *param.Filter(); + Tensor new_bias = *param.NewBias(); + Tensor new_scale = *param.NewScale(); + Tensor *output = param.Output(); + Tensor *bias1 = param.Bias(); + int groups = param.Groups(); + DLOG << "yangfei2"; + DLOG << bias1->dims(); + std::vector strides = param.Strides(); + std::vector paddings = param.Paddings(); + std::vector dilations = param.Dilations(); + + const int batch_size = static_cast(input->dims()[0]); + + std::vector filter_shape_vec(framework::vectorize(filter.dims())); + + std::vector output_shape_vec(framework::vectorize(output->dims())); + size_t data_dim = filter_shape_vec.size() - 2; + std::vector col_shape_vec(1 + 2 * data_dim); + col_shape_vec[0] = input->dims()[1] / groups; + for (size_t j = 0; j < data_dim; ++j) { + col_shape_vec[j + 1] = filter_shape_vec[j + 2]; + col_shape_vec[j + 1 + data_dim] = output_shape_vec[j + 2]; + } + framework::DDim col_shape(framework::make_ddim(col_shape_vec)); + + framework::DDim col_matrix_shape = + framework::flatten_to_2d(col_shape, data_dim + 1); + + bool is_expand = + math::IsExpand(filter_shape_vec, strides, paddings, dilations); + Tensor col; + Tensor col_matrix; + if (is_expand) { + col.mutable_data(col_shape); + col_matrix.ShareDataWith(col); + col_matrix.Resize(col_matrix_shape); + } + + framework::DDim input_shape = framework::slice_ddim( + input->dims(), 1, static_cast(input->dims().size())); + + framework::DDim filter_matrix_shape = {filter.dims()[0], + filter.numel() / filter.dims()[0]}; + filter.Resize(filter_matrix_shape); + framework::DDim output_matrix_shape = { + output->dims()[1], + output->numel() / (output->dims()[0] * output->dims()[1])}; + + // convolution operator: im2col(or vol2col) + gemm + int in_step = static_cast(input->dims()[1]) / groups; + int out_step = static_cast(output->dims()[1]) / groups; + + math::Vol2ColFunctor vol2col; + math::Im2ColFunctor im2col; + + for (int i = 0; i < batch_size; i++) { + Tensor in_batch = input->Slice(i, i + 1).Resize(input_shape); + Tensor out_batch = output->Slice(i, i + 1).Resize(output_matrix_shape); + Tensor bias_batch = bias1->Slice(i, i + 1).Resize(output_matrix_shape); + for (int g = 0; g < groups; g++) { + Tensor in_slice = in_batch.Slice(g * in_step, (g + 1) * in_step); + + if (!is_expand) { + col.ShareDataWith(in_slice); + col_matrix.ShareDataWith(col); + col_matrix.Resize(col_matrix_shape); + } else if (data_dim == 2U) { + // im2col + im2col(in_slice, dilations, strides, + std::vector{paddings[0], paddings[1], paddings[0], + paddings[1]}, + &col); + } else if (data_dim == 3U) { + // vol2col + vol2col(in_slice, dilations, strides, paddings, &col); + } + // gemm + Tensor out_slice = out_batch.Slice(g * out_step, (g + 1) * out_step); + Tensor filter_slice = filter.Slice(g * out_step, (g + 1) * out_step); + Tensor bias_data = bias_batch.Slice(g * out_step, (g + 1) * out_step); + math::matmulWithBn(filter_slice, false, col_matrix, false, + static_cast(1), &out_slice, + static_cast(1), true, &new_scale, + &new_bias, g, bias_data.data()); + } + } +} +template +void ConvBNAddReluCompute(const FusionConvBNAddReluParam ¶m) { + Tensor 
Bias; + Bias.mutable_data({param.Groups()}); + if (param.Groups() == param.Input()->dims()[1] && + param.Input()->dims()[1] == param.Output()->dims()[1] && + param.Filter()->dims()[2] == param.Filter()->dims()[3] && + param.Filter()->dims()[2] == 3 && param.Strides()[0] == 1) { + math::DepthwiseConvAddBNRelu3x3s1p1(param.Input(), param.Filter(), + param.Output(), param.NewScale(), + param.NewBias(), true); + } else if (param.Groups() == param.Input()->dims()[1] && + param.Input()->dims()[1] == param.Output()->dims()[1] && + param.Filter()->dims()[2] == param.Filter()->dims()[3] && + param.Filter()->dims()[2] == 3 && param.Strides()[0] == 2) { + // math::DepthwiseConvAddBNRelu3x3s2p1(param.Input(), param.Filter(), + // param.Output(), param.NewScale(), + // param.NewBias(), 1); + math::DepthwiseConvAddBNRelu3x3s2p1v2(param.Input(), param.Filter(), + param.Output(), param.NewScale(), + param.NewBias(), true); + } else { + ConvBNAddReluBasic(param); + } +} + +} // namespace operators +} // namespace paddle_mobile + +#endif diff --git a/src/operators/kernel/central-arm-func/fusion_fc_arm_func.h b/src/operators/kernel/central-arm-func/fusion_fc_arm_func.h index 425685ba319afee2ce3bba285d69c4edb756a718..0d8d793cccf1b8de596bffa023ba367fb1b46155 100644 --- a/src/operators/kernel/central-arm-func/fusion_fc_arm_func.h +++ b/src/operators/kernel/central-arm-func/fusion_fc_arm_func.h @@ -30,7 +30,6 @@ void FusionFcCompute(const FusionFcParam ¶m) { int axis = param.Axis(); Tensor *out = param.Out(); auto *out_data = out->mutable_data(); - float *bias_data = out->mutable_data(); const Tensor x_matrix = input_x->dims().size() > 2 ? framework::ReshapeToMatrix(*input_x, param.XNumColDims()) @@ -59,7 +58,7 @@ void FusionFcCompute(const FusionFcParam ¶m) { // DLOG << out_data[i]; // } math::matmul(x_matrix, false, y_matrix, false, static_cast(1), - out, static_cast(1), false, bias_data); + out, static_cast(1), false); PADDLE_MOBILE_ENFORCE(out_dim.size() == 2, " out_dim.size must be 2."); // if (out_dim.size() != 2) { // out->Resize(out_dim); diff --git a/src/operators/kernel/central-arm-func/pool_arm_func.h b/src/operators/kernel/central-arm-func/pool_arm_func.h index 35fb09f94c49cd91915128260b6426fe0fedf725..37479c22efe95b6506054cf3ded5855aa766c34c 100644 --- a/src/operators/kernel/central-arm-func/pool_arm_func.h +++ b/src/operators/kernel/central-arm-func/pool_arm_func.h @@ -76,15 +76,17 @@ void PoolCompute(const PoolParam ¶m) { } } - } else if (ksize[0] == 2 && ksize[0] == ksize[1]) { + } else if (ksize[0] == 2 && ksize[0] == ksize[1] && strides[0] == 2 && + strides[0] == strides[1] && paddings[0] == paddings[1] && + paddings[1] == 0) { #if __ARM_NEON #if __aarch64__ PoolBasic(pooling_type, ksize, strides, paddings, in_x, out); #else if (pooling_type == "max") { - math::Pool2x2Max(strides, paddings, in_x, out); + math::Pool2x2Maxs2p0(strides, paddings, in_x, out); } else if (pooling_type == "avg") { - math::Pool2x2Avg(strides, paddings, in_x, out); + math::Pool2x2Avgs2p0(strides, paddings, in_x, out); } #endif #else diff --git a/src/operators/kernel/conv_bn_add_relu_kernel.h b/src/operators/kernel/conv_bn_add_relu_kernel.h new file mode 100644 index 0000000000000000000000000000000000000000..820e5f8bcbf58676e8374e575044b10fe4676efa --- /dev/null +++ b/src/operators/kernel/conv_bn_add_relu_kernel.h @@ -0,0 +1,45 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#ifdef FUSION_CONVBNADDRELU_OP + +#include +#include "framework/ddim.h" +#include "framework/operator.h" +#include "operators/math/conv_func.h" +#include "operators/math/im2col.h" +#include "operators/math/math_function.h" +#include "operators/math/vol2col.h" +#include "operators/op_param.h" + +namespace paddle_mobile { +namespace operators { + +using framework::DDim; +using framework::OpKernelBase; + +template +class ConvBNAddReluKernel + : public OpKernelBase> { + public: + void Compute(const FusionConvBNAddReluParam ¶m) const; + bool Init(FusionConvBNAddReluParam *param); +}; + +} // namespace operators +} // namespace paddle_mobile + +#endif diff --git a/src/operators/math/gemm.cpp b/src/operators/math/gemm.cpp index 63cea79997e47f60877cc8fa9a2d04308032cd05..9f0a18f04f9f247cc06ccf73a36b574cb19d92ad 100644 --- a/src/operators/math/gemm.cpp +++ b/src/operators/math/gemm.cpp @@ -716,6 +716,27 @@ void InnerKernelWithBn(int mc, int nc, float alpha, const float *a, } } +// 分块矩阵乘法 +void InnerKernelWithBnAdd(int mc, int nc, float alpha, const float *a, + const float *b, float beta, float *c, float *C, + int ldc, bool relu, float *new_scale, float *new_bias, + float *bias) { +#pragma omp parallel for + for (int j = 0; j < nc; j += NR) { + for (int i = 0; i < mc; i += MR) { +#if __aarch64__ + // AddDot8x12(KC, a + i * KC, b + j * KC, c + i * NC + j, NC); + AddDot6x16(KC, a + i * KC, b + j * KC, c + i * NC + j, NC); +#else + // AddDot4x4(KC, a + i * KC, b + j * KC, c + i * NC + j, NC); + // AddDot4x8(KC, a + i * KC, b + j * KC, c + i * NC + j, NC); + AddDot6x8(KC, a + i * KC, b + j * KC, c + i * NC + j, NC); +#endif + } + } + WriteWithBnAddRelu(mc, nc, c, C, ldc, new_scale, new_bias, bias); +} + void InnerKernelWithPRelu(int mc, int nc, const float *a, const float *b, float *c, float *C, int ldc, float *p, std::string mode, float *bias, float *bias1) { @@ -1183,6 +1204,59 @@ void WriteWithBnRelu(int mc, int nc, float *c, float *C, int ldc, } } +// C = A * B, batchnorm(C),C = C + bias; relu(C) +void WriteWithBnAddRelu(int mc, int nc, float *c, float *C, int ldc, + float *new_scale, float *new_bias, float *bias) { + int nc1 = nc / 4; + int _nc1 = nc % 4; + + float *c_ptr, *C_ptr, *bias_ptr; + float32x4_t cv; + float32x4_t nbias; + float32x2_t scale; + float32x4_t biasv; + float32x4_t zero = vdupq_n_f32(0.0); + for (int i = 0; i < mc; ++i) { + c_ptr = c + i * NC; + C_ptr = C + i * ldc; + bias_ptr = bias + i * ldc; + nbias = vld1q_dup_f32(new_bias); + scale = vld1_dup_f32(new_scale); + new_bias++; + new_scale++; + float scale0 = vget_lane_f32(scale, 0); + for (int j = 0; j < nc1; ++j) { + cv = vld1q_f32(c_ptr); + biasv = vld1q_f32(bias_ptr); + cv = vmlaq_n_f32(nbias, cv, scale0); + cv = vaddq_f32(cv, biasv); + cv = vmaxq_f32(cv, zero); + vst1q_f32(C_ptr, cv); + c_ptr += 4; + C_ptr += 4; + bias_ptr += 4; + } + if (_nc1 != 0) { + cv = vld1q_f32(c_ptr); + biasv = vld1q_f32(bias_ptr); + cv = vmlaq_n_f32(nbias, cv, scale0); + cv = vaddq_f32(cv, biasv); + cv = 
vmaxq_f32(cv, zero); + if (_nc1 >= 1) { + vst1q_lane_f32(C_ptr, cv, 0); + C_ptr++; + } + if (_nc1 >= 2) { + vst1q_lane_f32(C_ptr, cv, 1); + C_ptr++; + } + if (_nc1 >= 3) { + vst1q_lane_f32(C_ptr, cv, 2); + } + } + } +} + #else void AddDot4x4(int k, const float *a, const float *b, float *c, int ldc) { @@ -2426,6 +2500,59 @@ void WriteWithBnRelu(int mc, int nc, float *c, float *C, int ldc, float *scale, "q8", "q10", "q11", "q12", "q13", "q14"); } +// C = A * B, batchnorm(C),C = C + bias; relu(C) +void WriteWithBnAddRelu(int mc, int nc, float *c, float *C, int ldc, + float *new_scale, float *new_bias, float *bias) { + int nc1 = nc / 4; + int _nc1 = nc % 4; + + float *c_ptr, *C_ptr, *bias_ptr; + float32x4_t cv; + float32x4_t nbias; + float32x2_t scale; + float32x4_t biasv; + float32x4_t zero = vdupq_n_f32(0.0); + for (int i = 0; i < mc; ++i) { + c_ptr = c + i * NC; + C_ptr = C + i * ldc; + bias_ptr = bias + i * ldc; + nbias = vld1q_dup_f32(new_bias); + scale = vld1_dup_f32(new_scale); + new_bias++; + new_scale++; + float scale0 = vget_lane_f32(scale, 0); + for (int j = 0; j < nc1; ++j) { + cv = vld1q_f32(c_ptr); + biasv = vld1q_f32(bias_ptr); + cv = vmlaq_n_f32(nbias, cv, scale0); + cv = vaddq_f32(cv, biasv); + cv = vmaxq_f32(cv, zero); + vst1q_f32(C_ptr, cv); + c_ptr += 4; + C_ptr += 4; + bias_ptr += 4; + } + if (_nc1 != 0) { + cv = vld1q_f32(c_ptr); + biasv = vld1q_f32(bias_ptr); + cv = vmlaq_n_f32(nbias, cv, scale0); + cv = vaddq_f32(cv, biasv); + cv = vmaxq_f32(cv, zero); + if (_nc1 >= 1) { + vst1q_lane_f32(C_ptr, cv, 0); + C_ptr++; + } + if (_nc1 >= 2) { + vst1q_lane_f32(C_ptr, cv, 1); + C_ptr++; + } + if (_nc1 >= 3) { + vst1q_lane_f32(C_ptr, cv, 2); + } + } + } +} + /* // C = A * B void VecWriteBasic(int n, float *c, float *C, int ldc) { @@ -2835,7 +2962,7 @@ void Sgemm(int m, int n, int k, float alpha, const float *A, int lda, void SgemmWithBn(int m, int n, int k, float alpha, const float *A, int lda, const float *B, int ldb, float beta, float *C, int ldc, - bool relu, float *new_scale, float *new_bias) { + bool relu, float *new_scale, float *new_bias, float *bias) { // L1 data cache is 32 kib (Per Contex-A57, Contex-A72, Contex-A73) // L2 cache is 0.5~4 Mib (Contex-A72 cluster) int L1 = 32 * 1024; @@ -2882,8 +3009,14 @@ void SgemmWithBn(int m, int n, int k, float alpha, const float *A, int lda, #else PackMatrixA_6r(mc, KC, mc % MR, &A(i, 0), lda, packedA); #endif - InnerKernelWithBn(mc, nc, alpha, packedA, packedB, beta, packedC, - &C(i, j), ldc, relu, new_scale + i, new_bias + i); + if (bias == nullptr) { + InnerKernelWithBn(mc, nc, alpha, packedA, packedB, beta, packedC, + &C(i, j), ldc, relu, new_scale + i, new_bias + i); + } else { + InnerKernelWithBnAdd(mc, nc, alpha, packedA, packedB, beta, packedC, + &C(i, j), ldc, relu, new_scale + i, new_bias + i, + bias + i * ldc + j); + } } } @@ -3071,7 +3204,8 @@ void Sgemm_omp(int m, int n, int k, float alpha, const float *A, int lda, void SgemmWithBn_omp(int m, int n, int k, float alpha, const float *A, int lda, const float *B, int ldb, float beta, float *C, int ldc, - bool relu, float *new_scale, float *new_bias) { + bool relu, float *new_scale, float *new_bias, + float *bias) { #ifdef _OPENMP int max_threads = omp_get_max_threads(); #else @@ -3148,8 +3282,14 @@ void SgemmWithBn_omp(int m, int n, int k, float alpha, const float *A, int lda, float *local_A = packedA + MC * KC * local_threads; float *local_C = packedC + MC * NC * local_threads; procPackA(mc, KC, mc % MR, &A(i, 0), lda, local_A); - InnerKernelWithBn(mc, n, alpha, local_A, 
packedB, beta, local_C, &C(i, 0), - ldc, relu, new_scale + i, new_bias + i); + if (bias == nullptr) { + InnerKernelWithBn(mc, n, alpha, local_A, packedB, beta, local_C, + &C(i, 0), ldc, relu, new_scale + i, new_bias + i); + } else { + InnerKernelWithBnAdd(mc, n, alpha, local_A, packedB, beta, local_C, + &C(i, 0), ldc, relu, new_scale + i, new_bias + i, + bias + i * ldc); + } } } else { #pragma omp parallel for @@ -3165,8 +3305,14 @@ void SgemmWithBn_omp(int m, int n, int k, float alpha, const float *A, int lda, float *local_B = packedB + KC * NC * local_threads; float *local_C = packedC + MC * NC * local_threads; procPackB(KC, nc, nc % NR, &B(0, j), ldb, local_B); - InnerKernelWithBn(m, nc, alpha, packedA, local_B, beta, local_C, &C(0, j), - ldc, relu, new_scale, new_bias); + if (bias == nullptr) { + InnerKernelWithBn(m, nc, alpha, packedA, local_B, beta, local_C, + &C(0, j), ldc, relu, new_scale, new_bias); + } else { + InnerKernelWithBnAdd(m, nc, alpha, packedA, local_B, beta, local_C, + &C(0, j), ldc, relu, new_scale, new_bias, + bias + j); + } } } diff --git a/src/operators/math/gemm.h b/src/operators/math/gemm.h index 6139f1b45f3c6e76d859625ca000ea6d46d3c328..abd209bb45c650363b7d19c495bea4d9848fc834 100644 --- a/src/operators/math/gemm.h +++ b/src/operators/math/gemm.h @@ -81,6 +81,10 @@ void InnerKernelWithBias(int mc, int nc, float alpha, const float *a, void InnerKernelWithBn(int mc, int nc, float alpha, const float *a, const float *b, float beta, float *c, float *C, int ldc, bool relu, float *new_scale, float *new_bias); +void InnerKernelWithBnAdd(int mc, int nc, float alpha, const float *a, + const float *b, float beta, float *c, float *C, + int ldc, bool relu, float *new_scale, float *new_bias, + float *bias); void InnerKernelWithPRelu(int mc, int nc, const float *a, const float *b, float *c, float *C, int ldc, float *p, std::string mode, float *bias, float *bias1); @@ -125,7 +129,8 @@ void WriteWithBn(int mc, int nc, float *c, float *C, int ldc, float *new_scale, // C = A * B, batchnorm(C), relu(C) void WriteWithBnRelu(int mc, int nc, float *c, float *C, int ldc, float *new_scale, float *new_bias); - +void WriteWithBnAddRelu(int mc, int nc, float *c, float *C, int ldc, + float *new_scale, float *new_bias, float *bias1); /* // 向量矩阵乘法结果回写 // C = A * B @@ -152,8 +157,7 @@ void Sgemm(int m, int n, int k, float alpha, const float *A, int lda, // 32位 float 矩阵乘法, 并对结果进行 batchnrom void SgemmWithBn(int m, int n, int k, float alpha, const float *A, int lda, const float *B, int ldb, float beta, float *C, int ldc, - bool relu, float *new_scale, float *new_bias); - + bool relu, float *new_scale, float *new_bias, float *bias); void SgemmWithPRelu(int m, int n, int k, const float *A, int lda, const float *B, int ldb, float *C, int ldc, float *p, std::string mode, float *bias, float *bias1); @@ -166,7 +170,7 @@ void Sgemm_omp(int m, int n, int k, float alpha, const float *A, int lda, // 32位 float 矩阵乘法, 并对结果进行 batchnrom(openmp 多线程版本) void SgemmWithBn_omp(int m, int n, int k, float alpha, const float *A, int lda, const float *B, int ldb, float beta, float *C, int ldc, - bool relu, float *new_scale, float *new_bias); + bool relu, float *new_scale, float *new_bias, float *bias); void SgemmWithPRelu_omp(int m, int n, int k, const float *A, int lda, const float *B, int ldb, float *C, int ldc, float *p, diff --git a/src/operators/math/math_function.cpp b/src/operators/math/math_function.cpp index c5192441b2e89f4a5346f5d580fe87890becc432..576b06422cd0665d9e211633ce2f559e73c11fb5 100644 --- 
a/src/operators/math/math_function.cpp +++ b/src/operators/math/math_function.cpp @@ -56,7 +56,7 @@ void matmulWithBn(const framework::Tensor &matrix_a, bool trans_a, const framework::Tensor &matrix_b, bool trans_b, float alpha, framework::Tensor *matrix_out, float beta, bool relu, framework::Tensor *new_scale, - framework::Tensor *new_bias, int group) { + framework::Tensor *new_bias, int group, float *bias) { auto dim_a = matrix_a.dims(); auto dim_b = matrix_b.dims(); auto dim_out = matrix_out->dims(); @@ -79,12 +79,12 @@ void matmulWithBn(const framework::Tensor &matrix_a, bool trans_a, SgemmWithBn_omp(M, N, K, alpha, matrix_a.data(), K, matrix_b.data(), N, beta, matrix_out->data(), N, relu, new_scale->data() + group, - new_bias->data() + group); + new_bias->data() + group, bias); #else SgemmWithBn(M, N, K, alpha, matrix_a.data(), K, matrix_b.data(), N, beta, matrix_out->data(), N, relu, - new_scale->data() + group, - new_bias->data() + group); + new_scale->data() + group, new_bias->data() + group, + bias); #endif } void matmulWithPRelu(const framework::Tensor &matrix_a, bool trans_a, diff --git a/src/operators/math/math_function.h b/src/operators/math/math_function.h index 26ec50872b1dccf1bc2f24cfea284de02e57fc9c..8d97f8628fb4f71cdd7664161983225136ec7c7f 100644 --- a/src/operators/math/math_function.h +++ b/src/operators/math/math_function.h @@ -32,7 +32,7 @@ void matmulWithBn(const framework::Tensor &matrix_a, bool trans_a, const framework::Tensor &matrix_b, bool trans_b, T alpha, framework::Tensor *matrix_out, T beta, bool relu, framework::Tensor *new_scale, framework::Tensor *new_bias, - int group); + int group, float *bias = nullptr); void matmulWithPRelu(const framework::Tensor &matrix_a, bool trans_a, const framework::Tensor &matrix_b, bool trans_b, diff --git a/src/operators/math/pool_2x2.cpp b/src/operators/math/pool_2x2.cpp index 0a2d96d4d065d7938e6872b4f073e080d7be8c3a..76af743818edacac6dd9e1878e8d8220ccff6d73 100644 --- a/src/operators/math/pool_2x2.cpp +++ b/src/operators/math/pool_2x2.cpp @@ -20,21 +20,15 @@ limitations under the License. 
*/ namespace paddle_mobile { namespace operators { namespace math { +#define FLT_MAX __FLT_MAX__ -void Pool2x2Max(vector strides, vector paddings, const Tensor *input, - Tensor *output) { -#if __ARM_NEON - -#if __aarch64__ -#else +void Pool2x2Maxs2p0(vector strides, vector paddings, + const Tensor *input, Tensor *output) { const int batch_size = input->dims()[0]; - const int input_height = input->dims()[2]; - const int input_width = input->dims()[3]; const int output_channels = output->dims()[1]; - int output_height = output->dims()[2]; const int output_width = output->dims()[3]; const int ksize_height = 2; @@ -47,72 +41,110 @@ void Pool2x2Max(vector strides, vector paddings, const Tensor *input, const int input_channel_stride = input_height * input_width; const int output_channel_stride = output_height * output_width; + const int input_batch_stride = output_channels * input_channel_stride; + const int output_batch_stride = output_channels * output_channel_stride; + const float *input_data = input->data(); float *output_data = output->mutable_data(); - int out_w_num = output_width >> 2; - const int in_h_num = output_height >> 1; - const int input_batch_stride = output_channels * input_channel_stride; - const int output_batch_stride = output_channels * output_channel_stride; - int remain = output_width - out_w_num << 2; + int w1 = input_width / 16; + int _w1 = input_width % 16; + int w2 = _w1 / 4; + int _w2 = _w1 % 4; + for (int i = 0; i < batch_size; ++i) { for (int c = 0; c < output_channels; ++c) { - const float *input_data_chanel_row_next = input_data + input_width; - for (; output_height > 0; output_height--) { - if (out_w_num > 0) { - asm volatile( - "max_loop: \n\t" - "vld1.f32 {q0,q1}, [%[in_ptr1]]! \n\t" - "vld1.f32 {q2,q3}, [%[in_ptr2]]! \n\t" - "vmax.f32 q0, q0, q2 \n\t" - "vmax.f32 q1, q1, q3 \n\t" - "vpmax.f32 d4, d0, d1 \n\t" - "vpmax.f32 d5, d2, d3 \n\t" - "subs %[out_w_num], #1 \n\t" - "vst1.32 {q2}, [%[out_ptr]]! \n\t" - "bne max_loop \n\t" - : [in_ptr1] "+r"(input_data), - [in_ptr2] "+r"(input_data_chanel_row_next), - [out_ptr] "+r"(output_data), [out_w_num] "+r"(out_w_num) - : - : "memory", "q0", "q1", "q2", "q3"); + for (int ph = 0; ph < input_height; ph += 2) { + const float *in_ptr1 = input_data + i * input_batch_stride + + c * input_channel_stride + ph * input_width; + const float *in_ptr2 = in_ptr1 + input_width; + if (ph + 1 >= input_height) { + in_ptr2 = static_cast( + paddle_mobile::memory::Alloc(sizeof(float) * input_width)); + memset(static_cast(const_cast(in_ptr2)), -FLT_MAX, + sizeof(float) * input_width); } + float *out_ptr = output_data + i * output_batch_stride + + c * output_channel_stride + ph / 2 * output_width; + asm volatile( + "subs %[w1], %[w1], #1 \n\t" + "blt end_w1_%= \n\t" + "loop_w1_%=: \n\t" + + "pld [%[in_ptr1], #64] \n\t" + "pld [%[in_ptr2], #64] \n\t" + + "vld1.f32 {q0, q1}, [%[in_ptr1]]! \n\t" + "vld1.f32 {q2, q3}, [%[in_ptr2]]! \n\t" + "vld1.f32 {q6, q7}, [%[in_ptr1]]! \n\t" + "vld1.f32 {q8, q9}, [%[in_ptr2]]! 
\n\t" - for (; remain > 0; remain--) { - float max_row1 = std::max(input_data[0], input_data[1]); - float max_row2 = std::max(input_data_chanel_row_next[0], - input_data_chanel_row_next[1]); - *output_data = std::max(max_row1, max_row2); - input_data += 2; - input_data_chanel_row_next += 2; - output_data++; + "vmax.f32 q0, q0, q2 \n\t" + "vmax.f32 q1, q1, q3 \n\t" + + "vmax.f32 q6, q6, q8 \n\t" + "vmax.f32 q7, q7, q9 \n\t" + + "vpmax.f32 d8, d0, d1 \n\t" + "vpmax.f32 d9, d2, d3 \n\t" + + "vpmax.f32 d10, d12, d13 \n\t" + "vpmax.f32 d11, d14, d15 \n\t" + + "vst1.32 {q4, q5}, [%[out_ptr]]! \n\t" + + "subs %[w1], %[w1], #1 \n\t" + "bge loop_w1_%= \n\t" + "end_w1_%=: \n\t" + + "subs %[w2], %[w2], #1 \n\t" + "blt end_w2_%= \n\t" + "loop_w2_%=: \n\t" + + "vld1.f32 {q0}, [%[in_ptr1]]! \n\t" + "vld1.f32 {q1}, [%[in_ptr2]]! \n\t" + "vmax.f32 q0, q0, q1 \n\t" + "vpmax.f32 d4, d0, d1 \n\t" + "vst1.32 {d4}, [%[out_ptr]]! \n\t" + + "subs %[w2], %[w2], #1 \n\t" + "bge loop_w2_%= \n\t" + "end_w2_%=: \n\t" + : + : [w1] "r"(w1), [w2] "r"(w2), [in_ptr1] "r"(in_ptr1), + [in_ptr2] "r"(in_ptr2), [out_ptr] "r"(out_ptr) + : "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", "q8", + "q9"); + + if (_w2 != 0) { + in_ptr1 += 16 * w1 + 4 * w2; + in_ptr2 += 16 * w1 + 4 * w2; + out_ptr += 8 * w1 + 2 * w2; + if (_w2 == 1) { + *out_ptr = (*in_ptr1 > *in_ptr2) ? *in_ptr1 : *in_ptr2; + } else if (_w2 == 2) { + float temp = (*in_ptr1++ > *in_ptr2++) ? *in_ptr1++ : *in_ptr2++; + float temp1 = (*in_ptr1 > *in_ptr2) ? *in_ptr1 : *in_ptr2; + *out_ptr = (temp > temp1) ? temp : temp1; + } else if (_w2 == 3) { + float temp = (*in_ptr1++ > *in_ptr2++) ? *in_ptr1++ : *in_ptr2++; + float temp1 = (*in_ptr1++ > *in_ptr2++) ? *in_ptr1++ : *in_ptr2++; + *out_ptr++ = (temp > temp1) ? temp : temp1; + *out_ptr = (*in_ptr1 > *in_ptr2) ? 
*in_ptr1 : *in_ptr2; + } } } - input_data += input_channel_stride; - output_data += output_channel_stride; } - input_data += input_batch_stride; - output_data += output_batch_stride; } -#endif -#else -#endif } -void Pool2x2Avg(vector strides, vector paddings, const Tensor *input, - Tensor *output) { -#if __ARM_NEON - -#if __aarch64__ -#else +void Pool2x2Avgs2p0(vector strides, vector paddings, + const Tensor *input, Tensor *output) { const int batch_size = input->dims()[0]; - const int input_height = input->dims()[2]; - const int input_width = input->dims()[3]; const int output_channels = output->dims()[1]; - int output_height = output->dims()[2]; const int output_width = output->dims()[3]; const int ksize_height = 2; @@ -125,59 +157,114 @@ void Pool2x2Avg(vector strides, vector paddings, const Tensor *input, const int input_channel_stride = input_height * input_width; const int output_channel_stride = output_height * output_width; + const int input_batch_stride = output_channels * input_channel_stride; + const int output_batch_stride = output_channels * output_channel_stride; + const float *input_data = input->data(); float *output_data = output->mutable_data(); - int out_w_num = output_width >> 2; - const int input_batch_stride = output_channels * input_channel_stride; - const int output_batch_stride = output_channels * output_channel_stride; - float vqua[] = {0.25f, 0.25f, 0.25f, 0.25f}; - int remain = output_width - out_w_num << 2; + int w1 = input_width / 16; + int _w1 = input_width % 16; + int w2 = _w1 / 4; + int _w2 = _w1 % 4; + + float quarter = 1 / 4; for (int i = 0; i < batch_size; ++i) { for (int c = 0; c < output_channels; ++c) { - const float *input_data_chanel_row_next = input_data + input_width; - for (; output_height > 0; output_height--) { - if (out_w_num > 0) { - asm volatile( - "avg_loop: \n\t" - "vld1.32 {q0,q1}, [%[in_ptr1]]! \n\t" - "vld1.32 {q2,q3}, [%[in_ptr2]]! \n\t" - "vadd.f32 q0, q0, q2 \n\t" - "vadd.f32 q1, q1, q3 \n\t" - "vpadd.f32 d4, d0, d1 \n\t" - "vpadd.f32 d5, d2, d3 \n\t" - "vld1.32 {q4}, [%[vqua]]! \n\t" - "vmul.f32 q2, q2, q4 \n\t" - "subs %[out_w_num], #1 \n\t" - "vst1.32 {q2}, [%[out_ptr]]! \n\t" - "bne avg_loop \n\t" - : [in_ptr1] "+r"(input_data), - [in_ptr2] "+r"(input_data_chanel_row_next), - [out_ptr] "+r"(output_data), [out_w_num] "+r"(out_w_num) - : [vqua] "r"(vqua) - : "memory", "q0", "q1", "q2", "q3", "q4"); + for (int ph = 0; ph < input_height; ph += 2) { + const float *in_ptr1 = input_data + i * input_batch_stride + + c * input_channel_stride + ph * input_width; + const float *in_ptr2 = in_ptr1 + input_width; + if (ph + 1 >= input_height) { + in_ptr2 = static_cast( + paddle_mobile::memory::Alloc(sizeof(float) * input_width)); + memset(static_cast(const_cast(in_ptr2)), 0, + sizeof(float) * input_width); } + float *out_ptr = output_data + i * output_batch_stride + + c * output_channel_stride + ph / 2 * output_width; + asm volatile( + "subs %[w1], %[w1], #1 \n\t" + "blt end_w1_%= \n\t" + "loop_w1_%=: \n\t" + + "pld [%[in_ptr1], #64] \n\t" + "pld [%[in_ptr2], #64] \n\t" + + "vmov.f32 d0[0], %[quarter] \n\t" + "vld1.f32 {q1, q2}, [%[in_ptr1]]! \n\t" + "vld1.f32 {q3, q4}, [%[in_ptr2]]! \n\t" + "vld1.f32 {q7, q8}, [%[in_ptr1]]! \n\t" + "vld1.f32 {q9, q10}, [%[in_ptr2]]! 
\n\t" + + "vadd.f32 q1, q1, q3 \n\t" + "vadd.f32 q2, q2, q4 \n\t" - for (; remain > 0; remain--) { - float max_row1 = std::max(input_data[0], input_data[1]); - float max_row2 = std::max(input_data_chanel_row_next[0], - input_data_chanel_row_next[1]); - *output_data = std::max(max_row1, max_row2); - input_data += 2; - input_data_chanel_row_next += 2; - output_data++; + "vadd.f32 q7, q7, q9 \n\t" + "vadd.f32 q8, q8, q10 \n\t" + + "vpadd.f32 d10, d2, d3 \n\t" + "vpadd.f32 d11, d4, d5 \n\t" + + "vpadd.f32 d12, d14, d15 \n\t" + "vpadd.f32 d13, d16, d17 \n\t" + + "vmul.f32 q5, q5, d0[0] \n\t" + "vmul.f32 q6, q6, d0[0] \n\t" + + "vst1.32 {q5, q6}, [%[out_ptr]]! \n\t" + + "subs %[w1], %[w1], #1 \n\t" + "bge loop_w1_%= \n\t" + "end_w1_%=: \n\t" + + "subs %[w2], %[w2], #1 \n\t" + "blt end_w2_%= \n\t" + "loop_w2_%=: \n\t" + + "vld1.f32 {q1}, [%[in_ptr1]]! \n\t" + "vld1.f32 {q2}, [%[in_ptr2]]! \n\t" + "vadd.f32 q1, q1, q2 \n\t" + "vpadd.f32 d4, d2, d3 \n\t" + "vmul.f32 d4, d4, d0[0] \n\t" + "vst1.32 {d4}, [%[out_ptr]]! \n\t" + + "subs %[w2], %[w2], #1 \n\t" + "bge loop_w2_%= \n\t" + "end_w2_%=: \n\t" + : + : [w1] "r"(w1), [w2] "r"(w2), [in_ptr1] "r"(in_ptr1), + [in_ptr2] "r"(in_ptr2), [out_ptr] "r"(out_ptr), + [quarter] "r"(quarter) + : "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", "q8", + "q9", "q10"); + + if (_w2 != 0) { + in_ptr1 += 16 * w1 + 4 * w2; + in_ptr2 += 16 * w1 + 4 * w2; + out_ptr += 8 * w1 + 2 * w2; + if (_w2 == 1) { + *out_ptr = 0.5 * (*in_ptr1 + *in_ptr2); + } else if (_w2 == 2) { + float temp = 0; + temp += *in_ptr1++; + temp += *in_ptr2++; + temp += *in_ptr1; + temp += *in_ptr2; + *out_ptr = 0.5 * temp; + } else if (_w2 == 3) { + float temp = 0; + temp += *in_ptr1++; + temp += *in_ptr2++; + temp += *in_ptr1++; + temp += *in_ptr2++; + *out_ptr++ = 0.5 * temp; + *out_ptr = 0.5 * (*in_ptr1 + *in_ptr2); + } } } - input_data += input_channel_stride; - output_data += output_channel_stride; } - input_data += input_batch_stride; - output_data += output_batch_stride; } - -#endif -#else -#endif } //} diff --git a/src/operators/math/pool_2x2.h b/src/operators/math/pool_2x2.h index ae32a3912b677efb50d8558700741a225e3eb3f8..bd5e48482607cc868408b6371f47e0cb55caf499 100644 --- a/src/operators/math/pool_2x2.h +++ b/src/operators/math/pool_2x2.h @@ -26,11 +26,11 @@ namespace math { using framework::Tensor; using std::vector; -void Pool2x2Max(vector strides, vector paddings, const Tensor *input, - Tensor *output); +void Pool2x2Maxs2p0(vector strides, vector paddings, + const Tensor *input, Tensor *output); -void Pool2x2Avg(vector strides, vector paddings, const Tensor *in_x, - Tensor *out); +void Pool2x2Avgs2p0(vector strides, vector paddings, + const Tensor *in_x, Tensor *out); } // namespace math } // namespace operators } // namespace paddle_mobile diff --git a/src/operators/math/pool_3x3.cpp b/src/operators/math/pool_3x3.cpp index 28547b71fca6caea2ff4341b3f832c0035436a72..05d3017f635a040a52d2cc377c8f384dbbd8086c 100644 --- a/src/operators/math/pool_3x3.cpp +++ b/src/operators/math/pool_3x3.cpp @@ -558,15 +558,13 @@ void Pool3x3Max(vector strides, vector paddings, const Tensor *input, const float *input_seg = input_data + c * input_channel_stride; float *output_seg = output_data + c * output_channel_stride; for (int ph = 0; ph < output_height; ph++) { + int hstart = ph * stride - padding; + int hend = min(hstart + 3, input_height); + hstart = max(hstart, 0); for (int pw = 0; pw < output_width; pw++) { - int hstart = ph * stride - padding; int wstart = pw * stride - padding; - int hend = 
min(hstart + 3, input_height + padding); - int wend = min(wstart + 3, input_width + padding); - hstart = max(hstart, 0); + int wend = min(wstart + 3, input_width); wstart = max(wstart, 0); - hend = min(hend, input_height); - wend = min(wend, input_width); const float *pos1 = input_seg + hstart * input_width + wstart; const float *pos2 = input_seg + (hstart + 1) * input_width + wstart; const float *pos3 = input_seg + (hstart + 2) * input_width + wstart; diff --git a/src/operators/op_param.h b/src/operators/op_param.h index 8f289b24ffa16e4af92ddff77b722fd458bc7c84..a6077812a0a4f56b58e666617e880b91f7c19b97 100644 --- a/src/operators/op_param.h +++ b/src/operators/op_param.h @@ -1472,6 +1472,119 @@ class FusionConvAddBNReluParam : public OpParam { }; #endif +#ifdef FUSION_CONVBNADDRELU_OP +template +class FusionConvBNAddReluParam : public OpParam { + typedef typename DtypeTensorTrait::gtype GType; + typedef typename DtypeTensorTrait::rtype RType; + + public: + FusionConvBNAddReluParam(const VariableNameMap &inputs, + const VariableNameMap &outputs, + const AttributeMap &attrs, const Scope &scope) { + bias_ = InputYFrom(inputs, scope); + axis_ = GetAttr("axis", attrs); + filter_ = FilterFrom(inputs, scope); + input_ = InputFrom(inputs, scope); + output_ = OutFrom(outputs, scope); + strides_ = GetAttr>("strides", attrs); + paddings_ = GetAttr>("paddings", attrs); + dilations_ = GetAttr>("dilations", attrs); + groups = GetAttr("groups", attrs); + input_bias_ = InputBiasFrom(inputs, scope); + input_mean_ = InputMeanFrom(inputs, scope); + input_scale_ = InputScaleFrom(inputs, scope); + input_variance_ = InputVarianceFrom(inputs, scope); + epsilon_ = GetAttr("epsilon", attrs); + momentum_ = GetAttr("momentum", attrs); + keyBNY_ = getkey("BNY", inputs, 0); + keyX_ = getkey("X", inputs, 0); + keyY_ = getkey("Y", inputs, 0); + if (keyX_ == keyBNY_) { + bias_ = InputYFrom(inputs, scope); + } else if (keyY_ == keyBNY_) { + bias_ = InputXFrom(inputs, scope); + } + // is_test_ = GetAttr("is_test", attrs); + } + RType *Bias() const { return bias_; } + + const int &Axis() const { return axis_; } + + const RType *Input() const { return input_; } + +#ifdef PADDLE_MOBILE_FPGA + RType *Filter() const { return filter_; } +#else + const RType *Filter() const { return filter_; } +#endif + + RType *Output() const { return output_; } + + const vector &Strides() const { return strides_; } + + const vector &Paddings() const { return paddings_; } + + const vector &Dilations() const { return dilations_; } + + const int &Groups() const { return groups; } + + const RType *InputBias() const { return input_bias_; } + + const RType *InputMean() const { return input_mean_; } + + const RType *InputScale() const { return input_scale_; } + + const RType *InputVariance() const { return input_variance_; } + + const float &Epsilon() const { return epsilon_; } + + const float &Momentum() const { return momentum_; } + + const bool &IsTest() const { return is_test_; } + + void SetNewScale(RType *new_scale) { new_scale_ = new_scale; } + + void SetNewBias(RType *new_bias) { new_bias_ = new_bias; } + + const RType *NewScale() const { return new_scale_; } + + const RType *NewBias() const { return new_bias_; } + + protected: + RType *bias_; + int axis_; + RType *input_; + RType *output_; + RType *filter_; + vector strides_; + vector paddings_; + vector dilations_; + int groups; + RType *input_bias_; + RType *input_mean_; + RType *input_scale_; + RType *input_variance_; + float epsilon_; + float momentum_; + bool is_test_; + RType 
*new_bias_; + RType *new_scale_; + std::string keyBNY_; + std::string keyX_; + std::string keyY_; +#ifdef PADDLE_MOBILE_FPGA + + private: + fpga::ConvArgs fpga_conv_args; + + public: + const fpga::ConvArgs &FpgaArgs() const { return fpga_conv_args; } + void SetFpgaArgs(const fpga::ConvArgs &args) { fpga_conv_args = args; } +#endif +}; +#endif + #ifdef FUSION_CONVBN_OP template class FusionConvBNParam : public OpParam { diff --git a/test/common/test_gemm_accuracy.cpp b/test/common/test_gemm_accuracy.cpp index 35241fbd535e062be1c7f1f28eb3860d118a3455..3e31a5f2fe9b41f90f9aebfe44db908682f83ce1 100644 --- a/test/common/test_gemm_accuracy.cpp +++ b/test/common/test_gemm_accuracy.cpp @@ -83,8 +83,8 @@ int do_sgemm(int m, int n, int k, bool relu, int t1, int t2, int pr) { } } - paddle_mobile::operators::math::SgemmWithBn(m, n, k, 0.9, a, lda, b, ldb, 0.3, - c, ldc, relu, scale, bias); + paddle_mobile::operators::math::SgemmWithBn( + m, n, k, 0.9, a, lda, b, ldb, 0.3, c, ldc, relu, scale, bias, nullptr); int eq = 0; int neq = 0; for (int i = 0; i < m * n; ++i) { diff --git a/tools/op.cmake b/tools/op.cmake index 38c511400532dd73de03aeff6686c881a3c0ad26..5965cf030fb935c89a5fb42fa72b5e810288552b 100644 --- a/tools/op.cmake +++ b/tools/op.cmake @@ -87,6 +87,7 @@ if ("resnet" IN_LIST NET) set(ELEMENTWISEADD_OP ON) set(POOL_OP ON) set(BATCHNORM_OP ON) + set(FUSION_CONVBNADDRELU_OP ON) set(MUL_OP ON) set(RESHAPE_OP ON) set(SOFTMAX_OP ON) @@ -141,6 +142,7 @@ if(NOT FOUND_MATCH) set(FUSION_CONVADDADDPRELU_OP ON) set(FUSION_DWCONVBNRELU_OP ON) set(FUSION_CONVBNRELU_OP ON) + set(FUSION_CONVBNADDRELU_OP ON) set(PRELU_OP ON) set(RESIZE_OP ON) set(SCALE_OP ON) @@ -244,6 +246,10 @@ if (FUSION_CONVBNRELU_OP) add_definitions(-DFUSION_CONVBNRELU_OP) endif() +if (FUSION_CONVBNADDRELU_OP) + add_definitions(-DFUSION_CONVBNADDRELU_OP) +endif() + if (PRELU_OP) add_definitions(-DPRELU_OP) endif()
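
For reference, a minimal standalone sketch of the arithmetic this patch relies on: the batchnorm fold computed once in the CPU ConvBNAddReluKernel Init (new_scale = scale / sqrt(variance + epsilon), new_bias = bias - mean * new_scale) and the per-element write-back performed by WriteWithBnAddRelu (scale the gemm result, add the elementwise-add operand, then apply relu). The helper names FoldBatchNorm and FusedEpilogue below are illustrative only and do not appear in the patch.

// Reference-only sketch, not part of the patch: scalar view of the folded
// batchnorm parameters and the fused conv_bn_add_relu epilogue.
#include <algorithm>
#include <cmath>
#include <vector>

struct FoldedBN {
  std::vector<float> new_scale;
  std::vector<float> new_bias;
};

// Fold per-channel batchnorm statistics into a single scale and bias,
// mirroring what ConvBNAddReluKernel::Init stores via SetNewScale/SetNewBias:
//   new_scale[c] = scale[c] / sqrt(variance[c] + epsilon)
//   new_bias[c]  = bias[c] - mean[c] * new_scale[c]
FoldedBN FoldBatchNorm(const std::vector<float> &scale,
                       const std::vector<float> &bias,
                       const std::vector<float> &mean,
                       const std::vector<float> &variance, float epsilon) {
  FoldedBN folded;
  const size_t C = mean.size();
  folded.new_scale.resize(C);
  folded.new_bias.resize(C);
  for (size_t c = 0; c < C; ++c) {
    float inv_std = 1.f / std::sqrt(variance[c] + epsilon);
    folded.new_scale[c] = inv_std * scale[c];
    folded.new_bias[c] = bias[c] - mean[c] * inv_std * scale[c];
  }
  return folded;
}

// Per-element epilogue applied to the gemm result: batchnorm, add the
// elementwise-add input (the residual branch), then relu.
float FusedEpilogue(float conv_out, float elementwise_add_in, float new_scale,
                    float new_bias) {
  float v = new_scale * conv_out + new_bias + elementwise_add_in;
  return std::max(v, 0.f);
}

Folding the batchnorm statistics into new_scale/new_bias up front keeps the square root and divide out of the gemm inner loop; the fused write-back then costs only a multiply-add, an add, and a max per output element.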