diff --git a/src/operators/kernel/arm/conv_add_bn_relu_kernel.cpp b/src/operators/kernel/arm/conv_add_bn_relu_kernel.cpp
deleted file mode 100644
index 635aac4dffa0fbc7d0bb2ff604973f0325d00566..0000000000000000000000000000000000000000
--- a/src/operators/kernel/arm/conv_add_bn_relu_kernel.cpp
+++ /dev/null
@@ -1,67 +0,0 @@
-/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License. */
-
-#ifdef FUSION_CONVADDBNRELU_OP
-
-#include "operators/kernel/conv_add_bn_relu_kernel.h"
-#include <cmath>
-#include "operators/kernel/central-arm-func/conv_add_bn_relu_arm_func.h"
-
-namespace paddle_mobile {
-namespace operators {
-
-template <>
-bool ConvAddBNReluKernel<CPU, float>::Init(
-    FusionConvAddBNReluParam<CPU> *param) {
-  const Tensor *mean = param->InputMean();
-  const Tensor *variance = param->InputVariance();
-  const Tensor *scale = param->InputScale();
-  const Tensor *bias = param->InputBias();
-  const float epsilon = param->Epsilon();
-
-  auto mean_ptr = mean->data<float>();
-  auto variance_ptr = variance->data<float>();
-  auto scale_ptr = scale->data<float>();
-  auto bias_ptr = bias->data<float>();
-
-  const int C = mean->numel();
-  float inv_std_ptr[C];
-  for (int i = 0; i < C; i++) {
-    inv_std_ptr[i] =
-        1 / static_cast<float>(pow((variance_ptr[i] + epsilon), 0.5));
-  }
-  LoDTensor *new_scale = new LoDTensor();
-  LoDTensor *new_bias = new LoDTensor();
-  auto new_scale_ptr = new_scale->mutable_data<float>({C});
-  auto new_bias_ptr = new_bias->mutable_data<float>({C});
-  for (int i = 0; i < C; i++) {
-    new_scale_ptr[i] = inv_std_ptr[i] * scale_ptr[i];
-    new_bias_ptr[i] = bias_ptr[i] - mean_ptr[i] * inv_std_ptr[i] * scale_ptr[i];
-  }
-  param->SetNewScale(new_scale);
-  param->SetNewBias(new_bias);
-  return true;
-}
-
-template <>
-void ConvAddBNReluKernel<CPU, float>::Compute(
-    const FusionConvAddBNReluParam<CPU> &param) {
-  ConvAddBNReluCompute<float>(param);
-}
-template class ConvAddBNReluKernel<CPU, float>;
-
-}  // namespace operators
-}  // namespace paddle_mobile
-
-#endif
diff --git a/src/operators/kernel/arm/conv_bn_relu_kernel.cpp b/src/operators/kernel/arm/conv_bn_relu_kernel.cpp
deleted file mode 100644
index bac91ff273e5dced71ae0e46852697a8696b3a68..0000000000000000000000000000000000000000
--- a/src/operators/kernel/arm/conv_bn_relu_kernel.cpp
+++ /dev/null
@@ -1,69 +0,0 @@
-/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License. */
-
-#ifdef FUSION_CONVBNRELU_OP
-
-#include "operators/kernel/conv_bn_relu_kernel.h"
-#include <cmath>
-#include "operators/kernel/central-arm-func/conv_bn_relu_arm_func.h"
-
-namespace paddle_mobile {
-namespace operators {
-
-template <>
-bool ConvBNReluKernel<CPU, float>::Init(FusionConvBNReluParam<CPU> *param) {
-  const Tensor *mean = param->InputMean();
-  const Tensor *variance = param->InputVariance();
-  const Tensor *scale = param->InputScale();
-  const Tensor *bias = param->InputBias();
-  const float epsilon = param->Epsilon();
-
-  // DLOG << "variance: " << *variance;
-
-  auto mean_ptr = mean->data<float>();
-  auto variance_ptr = variance->data<float>();
-  auto scale_ptr = scale->data<float>();
-  auto bias_ptr = bias->data<float>();
-
-  const int C = mean->numel();
-  float inv_std_ptr[C];
-  for (int i = 0; i < C; i++) {
-    inv_std_ptr[i] =
-        1 / static_cast<float>(pow((variance_ptr[i] + epsilon), 0.5));
-  }
-  LoDTensor *new_scale = new LoDTensor();
-  LoDTensor *new_bias = new LoDTensor();
-  auto new_scale_ptr = new_scale->mutable_data<float>({C});
-  auto new_bias_ptr = new_bias->mutable_data<float>({C});
-  for (int i = 0; i < C; i++) {
-    new_scale_ptr[i] = inv_std_ptr[i] * scale_ptr[i];
-    new_bias_ptr[i] = bias_ptr[i] - mean_ptr[i] * inv_std_ptr[i] * scale_ptr[i];
-  }
-
-  param->SetNewScale(new_scale);
-  param->SetNewBias(new_bias);
-  return true;
-}
-
-template <>
-void ConvBNReluKernel<CPU, float>::Compute(
-    const FusionConvBNReluParam<CPU> &param) {
-  ConvBNReluCompute<float>(param);
-}
-template class ConvBNReluKernel<CPU, float>;
-
-}  // namespace operators
-}  // namespace paddle_mobile
-
-#endif
diff --git a/src/operators/kernel/arm/conv_add_add_prelu_kernel.cpp b/src/operators/kernel/arm/convolution/conv_add_add_prelu_kernel.cpp
similarity index 100%
rename from src/operators/kernel/arm/conv_add_add_prelu_kernel.cpp
rename to src/operators/kernel/arm/convolution/conv_add_add_prelu_kernel.cpp
diff --git a/src/operators/kernel/arm/convolution/conv_add_bn_relu_kernel.cpp b/src/operators/kernel/arm/convolution/conv_add_bn_relu_kernel.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..b0bfae799c2f6afdef858f71910b9cd6f7e5a276
--- /dev/null
+++ b/src/operators/kernel/arm/convolution/conv_add_bn_relu_kernel.cpp
@@ -0,0 +1,112 @@
+/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#ifdef FUSION_CONVADDBNRELU_OP
+
+#include "operators/kernel/conv_add_bn_relu_kernel.h"
+#include <cmath>
+#include "operators/kernel/arm/convolution/conv_common.h"
+#include "operators/kernel/central-arm-func/conv_arm_func.h"
+
+namespace paddle_mobile {
+namespace operators {
+
+template <>
+bool ConvAddBNReluKernel<CPU, float>::Init(
+    FusionConvAddBNReluParam<CPU> *param) {
+  const Tensor *mean = param->InputMean();
+  const Tensor *variance = param->InputVariance();
+  const Tensor *scale = param->InputScale();
+  const Tensor *bias = param->InputBias();
+  const float epsilon = param->Epsilon();
+
+  auto mean_ptr = mean->data<float>();
+  auto variance_ptr = variance->data<float>();
+  auto scale_ptr = scale->data<float>();
+  auto bias_ptr = bias->data<float>();
+
+  const int C = mean->numel();
+  float inv_std_ptr[C];
+  for (int i = 0; i < C; i++) {
+    inv_std_ptr[i] =
+        1 / static_cast<float>(pow((variance_ptr[i] + epsilon), 0.5));
+  }
+  LoDTensor *new_scale = new LoDTensor();
+  LoDTensor *new_bias = new LoDTensor();
+  auto new_scale_ptr = new_scale->mutable_data<float>({C});
+  auto new_bias_ptr = new_bias->mutable_data<float>({C});
+  for (int i = 0; i < C; i++) {
+    new_scale_ptr[i] = inv_std_ptr[i] * scale_ptr[i];
+    new_bias_ptr[i] = bias_ptr[i] - mean_ptr[i] * inv_std_ptr[i] * scale_ptr[i];
+  }
+  param->SetNewScale(new_scale);
+  param->SetNewBias(new_bias);
+
+  InitBaseConvKernel(param);
+  return true;
+}
+
+template <>
+void ConvAddBNReluKernel<CPU, float>::Compute(
+    const FusionConvAddBNReluParam<CPU> &param) {
+  switch (param.ExecMode()) {
+    case ConvParam<CPU>::EXEC_DEPTHWISE3x3S1P1_FLOAT:
+      math::DepthwiseConvAddBNRelu3x3s1p1(param.Input(), param.Filter(),
+                                          param.Output(), param.NewScale(),
+                                          param.NewBias(), true);
+      break;
+    case ConvParam<CPU>::EXEC_DEPTHWISE3x3S2P1_FLOAT:
+      math::DepthwiseConvAddBNRelu3x3s2p1v2(param.Input(), param.Filter(),
+                                            param.Output(), param.NewScale(),
+                                            param.NewBias(), true);
+      break;
+    case ConvParam<CPU>::EXEC_DEPTHWISE3x3S2P0_FLOAT:
+      math::DepthwiseConv3x3s2p0(param.Input(), param.Filter(), param.Output(),
+                                 nullptr, false, false);
+      math::ScaleAddChannelWise<RELU>(param.Output(), param.NewScale(),
+                                      param.NewBias(), param.Output());
+      break;
+    case ConvParam<CPU>::EXEC_DEPTHWISE3x3_FLOAT:
+      math::DepthwiseConv3x3(param.Input(), param.Strides(), param.Paddings(),
+                             param.Filter(), nullptr, param.Output(), false);
+      math::ScaleAddChannelWise<RELU>(param.Output(), param.NewScale(),
+                                      param.NewBias(), param.Output());
+      break;
+#ifndef __aarch64__
+    case ConvParam<CPU>::EXEC_DEPTHWISE5x5_FLOAT:
+      DepthwiseConv5x5<float, float>(param);
+      math::ScaleAddChannelWise<RELU>(param.Output(), param.NewScale(),
+                                      param.NewBias(), param.Output());
+      break;
+    case ConvParam<CPU>::EXEC_WINOGRAD3X3_FLOAT:
+      WinogradConv3x3<8, 3>(param);
+      math::ScaleAddChannelWise<RELU>(param.Output(), param.NewScale(),
+                                      param.NewBias(), param.Output());
+      break;
+#endif  // __aarch64__
+    case ConvParam<CPU>::EXEC_GEMM_FLOAT:
+      ConvBNReluBasic<FusionConvAddBNReluParam<CPU>>(param);
+      break;
+    default:
+      PADDLE_MOBILE_THROW_EXCEPTION("Invalid convolution execute mode %d",
+                                    param.ExecMode());
+  }
+}
+
+template class ConvAddBNReluKernel<CPU, float>;
+
+}  // namespace operators
+}  // namespace paddle_mobile
+
+#endif
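Both the deleted and the new Init() bodies above fold the batch-norm statistics into a single per-channel affine transform before the convolution runs. A minimal standalone sketch of that folding, assuming plain std::vector inputs in place of the framework's LoDTensor:

#include <cmath>
#include <vector>

// y = scale * (x - mean) / sqrt(variance + epsilon) + bias collapses to
// y = new_scale * x + new_bias with the per-channel values computed here.
void FoldBatchNorm(const std::vector<float> &mean,
                   const std::vector<float> &variance,
                   const std::vector<float> &scale,
                   const std::vector<float> &bias, float epsilon,
                   std::vector<float> *new_scale,
                   std::vector<float> *new_bias) {
  const size_t C = mean.size();
  new_scale->resize(C);
  new_bias->resize(C);
  for (size_t c = 0; c < C; ++c) {
    const float inv_std = 1.f / std::sqrt(variance[c] + epsilon);
    (*new_scale)[c] = inv_std * scale[c];
    (*new_bias)[c] = bias[c] - mean[c] * inv_std * scale[c];
  }
}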
diff --git a/src/operators/kernel/arm/conv_add_kernel.cpp b/src/operators/kernel/arm/convolution/conv_add_kernel.cpp
similarity index 94%
rename from src/operators/kernel/arm/conv_add_kernel.cpp
rename to src/operators/kernel/arm/convolution/conv_add_kernel.cpp
index e016b8efbd15472ae0d77423d84dc19671bfa316..9e27dc62fb79ef8e764da3f8c82cbbd0cf079815 100644
--- a/src/operators/kernel/arm/conv_add_kernel.cpp
+++ b/src/operators/kernel/arm/convolution/conv_add_kernel.cpp
@@ -14,7 +14,7 @@ limitations under the License. */
 #ifdef FUSION_CONVADD_OP
 
 #include "operators/kernel/conv_add_kernel.h"
-#include "../central-arm-func/conv_add_arm_func.h"
+#include "operators/kernel/central-arm-func/conv_add_arm_func.h"
 
 namespace paddle_mobile {
 namespace operators {
diff --git a/src/operators/kernel/arm/conv_add_prelu_kernel.cpp b/src/operators/kernel/arm/convolution/conv_add_prelu_kernel.cpp
similarity index 100%
rename from src/operators/kernel/arm/conv_add_prelu_kernel.cpp
rename to src/operators/kernel/arm/convolution/conv_add_prelu_kernel.cpp
diff --git a/src/operators/kernel/arm/conv_add_relu_kernel.cpp b/src/operators/kernel/arm/convolution/conv_add_relu_kernel.cpp
similarity index 100%
rename from src/operators/kernel/arm/conv_add_relu_kernel.cpp
rename to src/operators/kernel/arm/convolution/conv_add_relu_kernel.cpp
diff --git a/src/operators/kernel/arm/conv_bn_add_relu_kernel.cpp b/src/operators/kernel/arm/convolution/conv_bn_add_relu_kernel.cpp
similarity index 100%
rename from src/operators/kernel/arm/conv_bn_add_relu_kernel.cpp
rename to src/operators/kernel/arm/convolution/conv_bn_add_relu_kernel.cpp
diff --git a/src/operators/kernel/arm/convolution/conv_bn_relu_kernel.cpp b/src/operators/kernel/arm/convolution/conv_bn_relu_kernel.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..738e85b01b714365797378b1fd580223cecef625
--- /dev/null
+++ b/src/operators/kernel/arm/convolution/conv_bn_relu_kernel.cpp
@@ -0,0 +1,110 @@
+/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#ifdef FUSION_CONVBNRELU_OP
+
+#include "operators/kernel/conv_bn_relu_kernel.h"
+#include <cmath>
+#include "operators/kernel/arm/convolution/conv_common.h"
+#include "operators/kernel/central-arm-func/conv_arm_func.h"
+
+namespace paddle_mobile {
+namespace operators {
+
+template <>
+bool ConvBNReluKernel<CPU, float>::Init(FusionConvBNReluParam<CPU> *param) {
+  const Tensor *mean = param->InputMean();
+  const Tensor *variance = param->InputVariance();
+  const Tensor *scale = param->InputScale();
+  const Tensor *bias = param->InputBias();
+  const float epsilon = param->Epsilon();
+
+  auto mean_ptr = mean->data<float>();
+  auto variance_ptr = variance->data<float>();
+  auto scale_ptr = scale->data<float>();
+  auto bias_ptr = bias->data<float>();
+
+  const int C = mean->numel();
+  float inv_std_ptr[C];
+  for (int i = 0; i < C; i++) {
+    inv_std_ptr[i] =
+        1 / static_cast<float>(pow((variance_ptr[i] + epsilon), 0.5));
+  }
+  LoDTensor *new_scale = new LoDTensor();
+  LoDTensor *new_bias = new LoDTensor();
+  auto new_scale_ptr = new_scale->mutable_data<float>({C});
+  auto new_bias_ptr = new_bias->mutable_data<float>({C});
+  for (int i = 0; i < C; i++) {
+    new_scale_ptr[i] = inv_std_ptr[i] * scale_ptr[i];
+    new_bias_ptr[i] = bias_ptr[i] - mean_ptr[i] * inv_std_ptr[i] * scale_ptr[i];
+  }
+  param->SetNewScale(new_scale);
+  param->SetNewBias(new_bias);
+
+  InitBaseConvKernel(param);
+  return true;
+}
+
+template <>
+void ConvBNReluKernel<CPU, float>::Compute(
+    const FusionConvBNReluParam<CPU> &param) {
+  switch (param.ExecMode()) {
+    case ConvParam<CPU>::EXEC_DEPTHWISE3x3S1P1_FLOAT:
+      math::DepthwiseConvAddBNRelu3x3s1p1(param.Input(), param.Filter(),
+                                          param.Output(), param.NewScale(),
+                                          param.NewBias(), true);
+      break;
+    case ConvParam<CPU>::EXEC_DEPTHWISE3x3S2P1_FLOAT:
+      math::DepthwiseConvAddBNRelu3x3s2p1v2(param.Input(), param.Filter(),
+                                            param.Output(), param.NewScale(),
+                                            param.NewBias(), true);
+      break;
+    case ConvParam<CPU>::EXEC_DEPTHWISE3x3S2P0_FLOAT:
+      math::DepthwiseConv3x3s2p0(param.Input(), param.Filter(), param.Output(),
+                                 nullptr, false, false);
+      math::ScaleAddChannelWise<RELU>(param.Output(), param.NewScale(),
+                                      param.NewBias(), param.Output());
+      break;
+    case ConvParam<CPU>::EXEC_DEPTHWISE3x3_FLOAT:
+      math::DepthwiseConv3x3(param.Input(), param.Strides(), param.Paddings(),
+                             param.Filter(), nullptr, param.Output(), false);
+      math::ScaleAddChannelWise<RELU>(param.Output(), param.NewScale(),
+                                      param.NewBias(), param.Output());
+      break;
+#ifndef __aarch64__
+    case ConvParam<CPU>::EXEC_DEPTHWISE5x5_FLOAT:
+      DepthwiseConv5x5<float, float>(param);
+      math::ScaleAddChannelWise<RELU>(param.Output(), param.NewScale(),
+                                      param.NewBias(), param.Output());
+      break;
+    case ConvParam<CPU>::EXEC_WINOGRAD3X3_FLOAT:
+      WinogradConv3x3<8, 3>(param);
+      math::ScaleAddChannelWise<RELU>(param.Output(), param.NewScale(),
+                                      param.NewBias(), param.Output());
+      break;
+#endif  // __aarch64__
+    case ConvParam<CPU>::EXEC_GEMM_FLOAT:
+      ConvBNReluBasic<FusionConvBNReluParam<CPU>>(param);
+      break;
+    default:
+      PADDLE_MOBILE_THROW_EXCEPTION("Invalid convolution execute mode %d",
+                                    param.ExecMode());
+  }
+}
+template class ConvBNReluKernel<CPU, float>;
+
+}  // namespace operators
+}  // namespace paddle_mobile
+
+#endif
diff --git a/src/operators/kernel/arm/conv_kernel.cpp b/src/operators/kernel/arm/convolution/conv_common.cpp
similarity index 62%
rename from src/operators/kernel/arm/conv_kernel.cpp
rename to src/operators/kernel/arm/convolution/conv_common.cpp
index 95a46816b9f35046388e9a14e659e02822871ecd..e070db0eccacb3f80a6c1d12bef4250e7f69d548 100644
--- a/src/operators/kernel/arm/conv_kernel.cpp
+++ b/src/operators/kernel/arm/convolution/conv_common.cpp
@@ -12,22 +12,20 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */
 
-#ifdef CONV_OP
-
-#include "operators/kernel/conv_kernel.h"
-#include "operators/kernel/central-arm-func/conv_arm_func.h"
+#include "operators/kernel/arm/convolution/conv_common.h"
+#include "operators/math/winograd/winograd_transform.h"
 
 namespace paddle_mobile {
 namespace operators {
 
-template <>
-bool ConvKernel<CPU, float>::Init(ConvParam<CPU> *param) {
+void InitBaseConvKernel(ConvParam<CPU> *param) {
   bool conv3x3 = param->Filter()->dims()[2] == param->Filter()->dims()[3] &&
                  param->Filter()->dims()[2] == 3;
   bool conv5x5 = param->Filter()->dims()[2] == param->Filter()->dims()[3] &&
                  param->Filter()->dims()[2] == 5;
   bool depth3x3 = conv3x3 && param->Groups() == param->Input()->dims()[1] &&
                   param->Input()->dims()[1] == param->Output()->dims()[1];
+
   bool depth5x5 = conv5x5 && param->Groups() == param->Input()->dims()[1] &&
                   param->Input()->dims()[1] == param->Output()->dims()[1];
   if (param->Filter()->type() == typeid(int8_t)) {
@@ -65,10 +63,10 @@ bool ConvKernel<CPU, float>::Init(ConvParam<CPU> *param) {
       param->ExecMode() = ConvParam<CPU>::EXEC_DEPTHWISE5x5_FLOAT;
     } else if (conv3x3 && param->Strides()[0] == param->Strides()[1] &&
                param->Dilations()[0] == param->Dilations()[1] &&
-               param->Strides()[0] == 1 && param->Dilations()[0] == 1 &&
+               param->Strides()[0] == 1 && param->Dilations()[0] == 1 /* &&
                param->Output()->dims()[1] >= 16 &&
                param->Input()->dims()[1] >= 16 &&
-               param->Input()->dims()[2] <= 140 /* refered from ncnn */) {
+               param->Input()->dims()[2] <= 140 */ /* refered from ncnn */) {
       param->ExecMode() = ConvParam<CPU>::EXEC_WINOGRAD3X3_FLOAT;
       // transform weight
       param->transformed_filter_ = new framework::LoDTensor;
@@ -79,59 +77,7 @@ bool ConvKernel<CPU, float>::Init(ConvParam<CPU> *param) {
       param->ExecMode() = ConvParam<CPU>::EXEC_GEMM_FLOAT;
     }
   }
-  return true;
-}
-
-template <>
-void ConvKernel<CPU, float>::Compute(const ConvParam<CPU> &param) {
-  switch (param.ExecMode()) {
-    case ConvParam<CPU>::EXEC_GEMM_INT8:
-      GemmConv<int8_t, int32_t>(param);
-      break;
-#ifndef __aarch64__
-    case ConvParam<CPU>::EXEC_DEPTHWISE3x3_INT8:
-      DepthwiseConv3x3<int8_t, int32_t>(param);
-      break;
-    case ConvParam<CPU>::EXEC_DEPTHWISE5x5_INT8:
-      DepthwiseConv5x5<int8_t, int32_t>(param);
-      break;
-#endif  // __aarch64__
-    case ConvParam<CPU>::EXEC_DEPTHWISE3x3S1P1_FLOAT:
-      math::DepthwiseConv3x3s1p1(param.Input(), param.Filter(), param.Output(),
-                                 nullptr, false, false);
-      break;
-    case ConvParam<CPU>::EXEC_DEPTHWISE3x3S2P1_FLOAT:
-      math::DepthwiseConv3x3s2p1v2(param.Input(), param.Filter(),
-                                   param.Output(), nullptr, false, false);
-      break;
-    case ConvParam<CPU>::EXEC_DEPTHWISE3x3S2P0_FLOAT:
-      math::DepthwiseConv3x3s2p0(param.Input(), param.Filter(), param.Output(),
-                                 nullptr, false, false);
-      break;
-    case ConvParam<CPU>::EXEC_DEPTHWISE3x3_FLOAT:
-      math::DepthwiseConv3x3(param.Input(), param.Strides(), param.Paddings(),
-                             param.Filter(), nullptr, param.Output(), false);
-      break;
-#ifndef __aarch64__
-    case ConvParam<CPU>::EXEC_DEPTHWISE5x5_FLOAT:
-      DepthwiseConv5x5<float, float>(param);
-      break;
-    case ConvParam<CPU>::EXEC_WINOGRAD3X3_FLOAT:
-      WinogradConv3x3<8, 3>(param);
-      break;
-#endif  // __aarch64__
-    case ConvParam<CPU>::EXEC_GEMM_FLOAT:
-      GemmConv<float, float>(param);
-      break;
-    default:
-      PADDLE_MOBILE_THROW_EXCEPTION("Invalid convolution execute mode %d",
-                                    param.ExecMode());
-  }
 }
-template class ConvKernel<CPU, float>;
-
 }  // namespace operators
 }  // namespace paddle_mobile
-
-#endif
diff --git a/src/operators/kernel/arm/convolution/conv_common.h b/src/operators/kernel/arm/convolution/conv_common.h
new file mode 100644
index 0000000000000000000000000000000000000000..4db37715c4302439fa0e43446bd62ef68675276e
--- /dev/null
+++ b/src/operators/kernel/arm/convolution/conv_common.h
@@ -0,0 +1,25 @@
+/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#pragma once
+
+#include "operators/op_param.h"
+
+namespace paddle_mobile {
+namespace operators {
+
+void InitBaseConvKernel(ConvParam<CPU> *param);
+
+}  // namespace operators
+}  // namespace paddle_mobile
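InitBaseConvKernel selects an ExecMode from the filter shape, strides, dilations, group count and element type. A sketch of the depthwise predicate it keys on, written with hypothetical free-standing parameters instead of ConvParam<CPU>:

#include <cstdint>

// A convolution is dispatched to the depthwise 3x3 kernels when the filter
// is square 3x3 and every input channel forms its own group whose output
// stays in that group (groups == in_channels == out_channels).
bool IsDepthwise3x3(int64_t filter_h, int64_t filter_w, int groups,
                    int64_t in_channels, int64_t out_channels) {
  const bool conv3x3 = filter_h == filter_w && filter_h == 3;
  return conv3x3 && groups == in_channels && in_channels == out_channels;
}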
diff --git a/src/operators/kernel/arm/convolution/conv_kernel.cpp b/src/operators/kernel/arm/convolution/conv_kernel.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..45776de8b89f8efedc8bbbcf0a581d78a5678b81
--- /dev/null
+++ b/src/operators/kernel/arm/convolution/conv_kernel.cpp
@@ -0,0 +1,82 @@
+/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#ifdef CONV_OP
+
+#include "operators/kernel/conv_kernel.h"
+#include "operators/kernel/arm/convolution/conv_common.h"
+#include "operators/kernel/central-arm-func/conv_arm_func.h"
+
+namespace paddle_mobile {
+namespace operators {
+
+template <>
+bool ConvKernel<CPU, float>::Init(ConvParam<CPU> *param) {
+  InitBaseConvKernel(param);
+  return true;
+}
+
+template <>
+void ConvKernel<CPU, float>::Compute(const ConvParam<CPU> &param) {
+  switch (param.ExecMode()) {
+    case ConvParam<CPU>::EXEC_GEMM_INT8:
+      GemmConv<int8_t, int32_t>(param);
+      break;
+#ifndef __aarch64__
+    case ConvParam<CPU>::EXEC_DEPTHWISE3x3_INT8:
+      DepthwiseConv3x3<int8_t, int32_t>(param);
+      break;
+    case ConvParam<CPU>::EXEC_DEPTHWISE5x5_INT8:
+      DepthwiseConv5x5<int8_t, int32_t>(param);
+      break;
+#endif  // __aarch64__
+    case ConvParam<CPU>::EXEC_DEPTHWISE3x3S1P1_FLOAT:
+      math::DepthwiseConv3x3s1p1(param.Input(), param.Filter(), param.Output(),
+                                 nullptr, false, false);
+      break;
+    case ConvParam<CPU>::EXEC_DEPTHWISE3x3S2P1_FLOAT:
+      math::DepthwiseConv3x3s2p1v2(param.Input(), param.Filter(),
+                                   param.Output(), nullptr, false, false);
+      break;
+    case ConvParam<CPU>::EXEC_DEPTHWISE3x3S2P0_FLOAT:
+      math::DepthwiseConv3x3s2p0(param.Input(), param.Filter(), param.Output(),
+                                 nullptr, false, false);
+      break;
+    case ConvParam<CPU>::EXEC_DEPTHWISE3x3_FLOAT:
+      math::DepthwiseConv3x3(param.Input(), param.Strides(), param.Paddings(),
+                             param.Filter(), nullptr, param.Output(), false);
+      break;
+#ifndef __aarch64__
+    case ConvParam<CPU>::EXEC_DEPTHWISE5x5_FLOAT:
+      DepthwiseConv5x5<float, float>(param);
+      break;
+    case ConvParam<CPU>::EXEC_WINOGRAD3X3_FLOAT:
+      WinogradConv3x3<8, 3>(param);
+      break;
+#endif  // __aarch64__
+    case ConvParam<CPU>::EXEC_GEMM_FLOAT:
+      GemmConv<float, float>(param);
+      break;
+    default:
+      PADDLE_MOBILE_THROW_EXCEPTION("Invalid convolution execute mode %d",
+                                    param.ExecMode());
+  }
+}
+
+template class ConvKernel<CPU, float>;
+
+}  // namespace operators
+}  // namespace paddle_mobile
+
+#endif
diff --git a/src/operators/kernel/arm/conv_transpose_kernel.cpp b/src/operators/kernel/arm/convolution/conv_transpose_kernel.cpp
similarity index 100%
rename from src/operators/kernel/arm/conv_transpose_kernel.cpp
rename to src/operators/kernel/arm/convolution/conv_transpose_kernel.cpp
diff --git a/src/operators/kernel/arm/convolution/dwconv_bn_relu_kernel.cpp b/src/operators/kernel/arm/convolution/dwconv_bn_relu_kernel.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..f600d8af0c81a40c588fc16586138c71e756f832
--- /dev/null
+++ b/src/operators/kernel/arm/convolution/dwconv_bn_relu_kernel.cpp
@@ -0,0 +1,110 @@
+/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#ifdef FUSION_DWCONVBNRELU_OP
+
+#include "operators/kernel/dwconv_bn_relu_kernel.h"
+#include <cmath>
+#include "operators/kernel/arm/convolution/conv_common.h"
+#include "operators/kernel/central-arm-func/conv_arm_func.h"
+
+namespace paddle_mobile {
+namespace operators {
+
+template <>
+bool DWConvBNReluKernel<CPU, float>::Init(FusionDWConvBNReluParam<CPU> *param) {
+  const Tensor *mean = param->InputMean();
+  const Tensor *variance = param->InputVariance();
+  const Tensor *scale = param->InputScale();
+  const Tensor *bias = param->InputBias();
+  const float epsilon = param->Epsilon();
+
+  auto mean_ptr = mean->data<float>();
+  auto variance_ptr = variance->data<float>();
+  auto scale_ptr = scale->data<float>();
+  auto bias_ptr = bias->data<float>();
+
+  const int C = mean->numel();
+  float inv_std_ptr[C];
+  for (int i = 0; i < C; i++) {
+    inv_std_ptr[i] =
+        1 / static_cast<float>(pow((variance_ptr[i] + epsilon), 0.5));
+  }
+  LoDTensor *new_scale = new LoDTensor();
+  LoDTensor *new_bias = new LoDTensor();
+  auto new_scale_ptr = new_scale->mutable_data<float>({C});
+  auto new_bias_ptr = new_bias->mutable_data<float>({C});
+  for (int i = 0; i < C; i++) {
+    new_scale_ptr[i] = inv_std_ptr[i] * scale_ptr[i];
+    new_bias_ptr[i] = bias_ptr[i] - mean_ptr[i] * inv_std_ptr[i] * scale_ptr[i];
+  }
+  param->SetNewScale(new_scale);
+  param->SetNewBias(new_bias);
+
+  InitBaseConvKernel(param);
+  return true;
+}
+
+template <>
+void DWConvBNReluKernel<CPU, float>::Compute(
+    const FusionDWConvBNReluParam<CPU> &param) {
+  switch (param.ExecMode()) {
+    case ConvParam<CPU>::EXEC_DEPTHWISE3x3S1P1_FLOAT:
+      math::DepthwiseConvAddBNRelu3x3s1p1(param.Input(), param.Filter(),
+                                          param.Output(), param.NewScale(),
+                                          param.NewBias(), true);
+      break;
+    case ConvParam<CPU>::EXEC_DEPTHWISE3x3S2P1_FLOAT:
+      math::DepthwiseConvAddBNRelu3x3s2p1v2(param.Input(), param.Filter(),
+                                            param.Output(), param.NewScale(),
+                                            param.NewBias(), true);
+      break;
+    case ConvParam<CPU>::EXEC_DEPTHWISE3x3S2P0_FLOAT:
+      math::DepthwiseConv3x3s2p0(param.Input(), param.Filter(), param.Output(),
+                                 nullptr, false, false);
+      math::ScaleAddChannelWise<RELU>(param.Output(), param.NewScale(),
+                                      param.NewBias(), param.Output());
+      break;
+    case ConvParam<CPU>::EXEC_DEPTHWISE3x3_FLOAT:
+      math::DepthwiseConv3x3(param.Input(), param.Strides(), param.Paddings(),
+                             param.Filter(), nullptr, param.Output(), false);
+      math::ScaleAddChannelWise<RELU>(param.Output(), param.NewScale(),
+                                      param.NewBias(), param.Output());
+      break;
+#ifndef __aarch64__
+    case ConvParam<CPU>::EXEC_DEPTHWISE5x5_FLOAT:
+      DepthwiseConv5x5<float, float>(param);
+      math::ScaleAddChannelWise<RELU>(param.Output(), param.NewScale(),
+                                      param.NewBias(), param.Output());
+      break;
+    case ConvParam<CPU>::EXEC_WINOGRAD3X3_FLOAT:
+      WinogradConv3x3<8, 3>(param);
+      math::ScaleAddChannelWise<RELU>(param.Output(), param.NewScale(),
+                                      param.NewBias(), param.Output());
+      break;
+#endif  // __aarch64__
+    case ConvParam<CPU>::EXEC_GEMM_FLOAT:
+      ConvBNReluBasic<FusionDWConvBNReluParam<CPU>>(param);
+      break;
+    default:
+      PADDLE_MOBILE_THROW_EXCEPTION("Invalid convolution execute mode %d",
+                                    param.ExecMode());
+  }
+}
+template class DWConvBNReluKernel<CPU, float>;
+
+}  // namespace operators
+}  // namespace paddle_mobile
+
+#endif
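The Compute() bodies of ConvAddBNReluKernel, ConvBNReluKernel and DWConvBNReluKernel above are line-for-line identical. A possible follow-up, not part of this patch, would hoist the switch into one shared template; sketched here with only the GEMM and default cases, the depthwise and Winograd cases being copied verbatim from the kernels above:

// Hypothetical shared dispatcher for the three fused BN+ReLU kernels.
template <typename ParamType>
void FusedConvBNReluDispatch(const ParamType &param) {
  switch (param.ExecMode()) {
    // ... depthwise / Winograd cases as in the kernels above ...
    case ConvParam<CPU>::EXEC_GEMM_FLOAT:
      ConvBNReluBasic<ParamType>(param);
      break;
    default:
      PADDLE_MOBILE_THROW_EXCEPTION("Invalid convolution execute mode %d",
                                    param.ExecMode());
  }
}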
diff --git a/src/operators/kernel/arm/dwconv_bn_relu_kernel.cpp b/src/operators/kernel/arm/dwconv_bn_relu_kernel.cpp
deleted file mode 100644
index 38dd6ae181ccc623b3321bf9358764f7ef2c22bc..0000000000000000000000000000000000000000
--- a/src/operators/kernel/arm/dwconv_bn_relu_kernel.cpp
+++ /dev/null
@@ -1,66 +0,0 @@
-/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License. */
-
-#ifdef FUSION_DWCONVBNRELU_OP
-
-#include "operators/kernel/dwconv_bn_relu_kernel.h"
-#include <cmath>
-#include "operators/kernel/central-arm-func/dwconv_bn_relu_arm_func.h"
-
-namespace paddle_mobile {
-namespace operators {
-
-template <>
-bool DWConvBNReluKernel<CPU, float>::Init(FusionDWConvBNReluParam<CPU> *param) {
-  const Tensor *mean = param->InputMean();
-  const Tensor *variance = param->InputVariance();
-  const Tensor *scale = param->InputScale();
-  const Tensor *bias = param->InputBias();
-  const float epsilon = param->Epsilon();
-
-  auto mean_ptr = mean->data<float>();
-  auto variance_ptr = variance->data<float>();
-  auto scale_ptr = scale->data<float>();
-  auto bias_ptr = bias->data<float>();
-
-  const int C = mean->numel();
-  float inv_std_ptr[C];
-  for (int i = 0; i < C; i++) {
-    inv_std_ptr[i] =
-        1 / static_cast<float>(pow((variance_ptr[i] + epsilon), 0.5));
-  }
-  LoDTensor *new_scale = new LoDTensor();
-  LoDTensor *new_bias = new LoDTensor();
-  auto new_scale_ptr = new_scale->mutable_data<float>({C});
-  auto new_bias_ptr = new_bias->mutable_data<float>({C});
-  for (int i = 0; i < C; i++) {
-    new_scale_ptr[i] = inv_std_ptr[i] * scale_ptr[i];
-    new_bias_ptr[i] = bias_ptr[i] - mean_ptr[i] * inv_std_ptr[i] * scale_ptr[i];
-  }
-  param->SetNewScale(new_scale);
-  param->SetNewBias(new_bias);
-  return true;
-}
-
-template <>
-void DWConvBNReluKernel<CPU, float>::Compute(
-    const FusionDWConvBNReluParam<CPU> &param) {
-  DWConvBNReluCompute<float>(param);
-}
-template class DWConvBNReluKernel<CPU, float>;
-
-}  // namespace operators
-}  // namespace paddle_mobile
-
-#endif
diff --git a/src/operators/kernel/central-arm-func/conv_add_bn_relu_arm_func.h b/src/operators/kernel/central-arm-func/conv_add_bn_relu_arm_func.h
deleted file mode 100644
index b94d6c97f278a71c4fa1519d9ceb76a7f3675852..0000000000000000000000000000000000000000
--- a/src/operators/kernel/central-arm-func/conv_add_bn_relu_arm_func.h
+++ /dev/null
@@ -1,142 +0,0 @@
-/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License. */
-
-#ifdef FUSION_CONVADDBNRELU_OP
-
-#pragma once
-
-#include <vector>
-#include "operators/math/depthwise_conv3x3.h"
-#include "operators/math/im2col.h"
-#include "operators/math/math_function.h"
-#include "operators/math/vol2col.h"
-#include "operators/op_param.h"
-
-namespace paddle_mobile {
-namespace operators {
-
-void ConvAddBNReluBasic(const FusionConvAddBNReluParam<CPU> &param) {
-  const Tensor *input = param.Input();
-  Tensor filter = *param.Filter();
-  Tensor new_bias = *param.NewBias();
-  Tensor new_scale = *param.NewScale();
-  Tensor *output = param.Output();
-  output->mutable_data<float>();
-
-  int groups = param.Groups();
-  std::vector<int> strides = param.Strides();
-  std::vector<int> paddings = param.Paddings();
-  std::vector<int> dilations = param.Dilations();
-
-  const int batch_size = static_cast<int>(input->dims()[0]);
-
-  std::vector<int64_t> filter_shape_vec(framework::vectorize(filter.dims()));
-
-  std::vector<int64_t> output_shape_vec(framework::vectorize(output->dims()));
-  size_t data_dim = filter_shape_vec.size() - 2;
-  std::vector<int64_t> col_shape_vec(1 + 2 * data_dim);
-  col_shape_vec[0] = input->dims()[1] / groups;
-  for (size_t j = 0; j < data_dim; ++j) {
-    col_shape_vec[j + 1] = filter_shape_vec[j + 2];
-    col_shape_vec[j + 1 + data_dim] = output_shape_vec[j + 2];
-  }
-  framework::DDim col_shape(framework::make_ddim(col_shape_vec));
-
-  framework::DDim col_matrix_shape =
-      framework::flatten_to_2d(col_shape, data_dim + 1);
-
-  bool is_expand =
-      math::IsExpand(filter_shape_vec, strides, paddings, dilations);
-  Tensor col;
-  Tensor col_matrix;
-  if (is_expand) {
-    col.mutable_data<float>(col_shape);
-    col_matrix.ShareDataWith(col);
-    col_matrix.Resize(col_matrix_shape);
-  }
-
-  framework::DDim input_shape = framework::slice_ddim(
-      input->dims(), 1, static_cast<int>(input->dims().size()));
-
-  framework::DDim filter_matrix_shape = {filter.dims()[0],
-                                         filter.numel() / filter.dims()[0]};
-  filter.Resize(filter_matrix_shape);
-  framework::DDim output_matrix_shape = {
-      output->dims()[1],
-      output->numel() / (output->dims()[0] * output->dims()[1])};
-
-  // convolution operator: im2col(or vol2col) + gemm
-  int in_step = static_cast<int>(input->dims()[1]) / groups;
-  int out_step = static_cast<int>(output->dims()[1]) / groups;
-
-  math::Vol2ColFunctor<CPU, float> vol2col;
-  math::Im2ColFunctor<math::ColFormat::kCFO, CPU, float> im2col;
-
-  for (int i = 0; i < batch_size; i++) {
-    Tensor in_batch = input->Slice(i, i + 1).Resize(input_shape);
-    Tensor out_batch = output->Slice(i, i + 1).Resize(output_matrix_shape);
-
-    for (int g = 0; g < groups; g++) {
-      Tensor in_slice = in_batch.Slice(g * in_step, (g + 1) * in_step);
-
-      if (!is_expand) {
-        col.ShareDataWith(in_slice);
-        col_matrix.ShareDataWith(col);
-        col_matrix.Resize(col_matrix_shape);
-      } else if (data_dim == 2U) {
-        // im2col
-        im2col(in_slice, dilations, strides,
-               std::vector<int>{paddings[0], paddings[1], paddings[0],
-                                paddings[1]},
-               &col);
-      } else if (data_dim == 3U) {
-        // vol2col
-        vol2col(in_slice, dilations, strides, paddings, &col);
-      }
-      // gemm
-      Tensor out_slice = out_batch.Slice(g * out_step, (g + 1) * out_step);
-      Tensor filter_slice = filter.Slice(g * out_step, (g + 1) * out_step);
-
-      math::MatMulWithBn(filter_slice, false, col_matrix, false,
-                         static_cast<float>(1), &out_slice,
-                         static_cast<float>(0), true, &new_scale, &new_bias, g);
-    }
-  }
-}
-
-template <typename P>
-void ConvAddBNReluCompute(const FusionConvAddBNReluParam<CPU> &param) {
-  if (param.Groups() == param.Input()->dims()[1] &&
-      param.Input()->dims()[1] == param.Output()->dims()[1] &&
-      param.Filter()->dims()[2] == param.Filter()->dims()[3] &&
-      param.Filter()->dims()[2] == 3 && param.Strides()[0] == 1) {
-    math::DepthwiseConvAddBNRelu3x3s1p1(param.Input(), param.Filter(),
-                                        param.Output(), param.NewScale(),
-                                        param.NewBias(), true);
-  } else if (param.Groups() == param.Input()->dims()[1] &&
-             param.Input()->dims()[1] == param.Output()->dims()[1] &&
-             param.Filter()->dims()[2] == param.Filter()->dims()[3] &&
-             param.Filter()->dims()[2] == 3 && param.Strides()[0] == 2) {
-    math::DepthwiseConvAddBNRelu3x3s2p1v2(param.Input(), param.Filter(),
-                                          param.Output(), param.NewScale(),
-                                          param.NewBias(), true);
-  } else {
-    ConvAddBNReluBasic(param);
-  }
-}
-
-}  // namespace operators
-}  // namespace paddle_mobile
-
-#endif
diff --git a/src/operators/kernel/central-arm-func/conv_arm_func.h b/src/operators/kernel/central-arm-func/conv_arm_func.h
index b1556a3a771231fd62e8cadda2d9d7d40721856a..6f37a0b711fed43f67d38777b471225768858c6a 100644
--- a/src/operators/kernel/central-arm-func/conv_arm_func.h
+++ b/src/operators/kernel/central-arm-func/conv_arm_func.h
@@ -212,6 +212,95 @@ inline void DepthwiseConv5x5(const ConvParam<CPU> &param) {
   }
 }
 #endif  // __aarch64__
 
+template <typename ParamType>
+void ConvBNReluBasic(const ParamType &param) {
+  const Tensor *input = param.Input();
+  Tensor filter = *param.Filter();
+  Tensor new_bias = *param.NewBias();
+  Tensor new_scale = *param.NewScale();
+  Tensor *output = param.Output();
+  output->mutable_data<float>();
+
+  int groups = param.Groups();
+  std::vector<int> strides = param.Strides();
+  std::vector<int> paddings = param.Paddings();
+  std::vector<int> dilations = param.Dilations();
+
+  const int batch_size = static_cast<int>(input->dims()[0]);
+
+  std::vector<int64_t> filter_shape_vec(framework::vectorize(filter.dims()));
+
+  std::vector<int64_t> output_shape_vec(framework::vectorize(output->dims()));
+  size_t data_dim = filter_shape_vec.size() - 2;
+  std::vector<int64_t> col_shape_vec(1 + 2 * data_dim);
+  col_shape_vec[0] = input->dims()[1] / groups;
+  for (size_t j = 0; j < data_dim; ++j) {
+    col_shape_vec[j + 1] = filter_shape_vec[j + 2];
+    col_shape_vec[j + 1 + data_dim] = output_shape_vec[j + 2];
+  }
+  framework::DDim col_shape(framework::make_ddim(col_shape_vec));
+
+  framework::DDim col_matrix_shape =
+      framework::flatten_to_2d(col_shape, data_dim + 1);
+
+  bool is_expand =
+      math::IsExpand(filter_shape_vec, strides, paddings, dilations);
+  Tensor col;
+  Tensor col_matrix;
+  if (is_expand) {
+    col.mutable_data<float>(col_shape);
+    col_matrix.ShareDataWith(col);
+    col_matrix.Resize(col_matrix_shape);
+  }
+
+  framework::DDim input_shape = framework::slice_ddim(
+      input->dims(), 1, static_cast<int>(input->dims().size()));
+
+  framework::DDim filter_matrix_shape = {filter.dims()[0],
+                                         filter.numel() / filter.dims()[0]};
+  filter.Resize(filter_matrix_shape);
+  framework::DDim output_matrix_shape = {
+      output->dims()[1],
+      output->numel() / (output->dims()[0] * output->dims()[1])};
+
+  // convolution operator: im2col(or vol2col) + gemm
+  int in_step = static_cast<int>(input->dims()[1]) / groups;
+  int out_step = static_cast<int>(output->dims()[1]) / groups;
+
+  math::Vol2ColFunctor<CPU, float> vol2col;
+  math::Im2ColFunctor<math::ColFormat::kCFO, CPU, float> im2col;
+
+  for (int i = 0; i < batch_size; i++) {
+    Tensor in_batch = input->Slice(i, i + 1).Resize(input_shape);
+    Tensor out_batch = output->Slice(i, i + 1).Resize(output_matrix_shape);
+
+    for (int g = 0; g < groups; g++) {
+      Tensor in_slice = in_batch.Slice(g * in_step, (g + 1) * in_step);
+
+      if (!is_expand) {
+        col_matrix = in_slice;
+        col_matrix.Resize(col_matrix_shape);
+      } else if (data_dim == 2U) {
+        // im2col
+        im2col(in_slice, dilations, strides,
+               std::vector<int>{paddings[0], paddings[1], paddings[0],
+                                paddings[1]},
+               &col);
+      } else if (data_dim == 3U) {
+        // vol2col
+        vol2col(in_slice, dilations, strides, paddings, &col);
+      }
+      // gemm
+      Tensor out_slice = out_batch.Slice(g * out_step, (g + 1) * out_step);
+      Tensor filter_slice = filter.Slice(g * out_step, (g + 1) * out_step);
+
+      math::MatMulWithBn(filter_slice, false, col_matrix, false,
+                         static_cast<float>(1), &out_slice,
+                         static_cast<float>(0), true, &new_scale, &new_bias, g);
+    }
+  }
+}
+
 }  // namespace operators
 }  // namespace paddle_mobile
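For a concrete sense of the shapes ConvBNReluBasic derives, a worked example under assumed layer sizes (a 3x3, stride-1, pad-1 convolution; the numbers are illustrative, not from this patch):

#include <cstdio>

// Same output-size formula as ConvOutputSize in conv_func.h; prints the
// per-group GEMM dimensions for a 1x32x112x112 input, 64 output channels.
int main() {
  const int in_c = 32, out_c = 64, h = 112, w = 112;
  const int k = 3, stride = 1, pad = 1, dilation = 1, groups = 1;
  const int out_h = (h + 2 * pad - (dilation * (k - 1) + 1)) / stride + 1;
  const int out_w = (w + 2 * pad - (dilation * (k - 1) + 1)) / stride + 1;
  // col_matrix: (in_c/groups * k * k) x (out_h * out_w) = 288 x 12544
  // filter slice: (out_c/groups) x (in_c/groups * k * k) = 64 x 288
  std::printf("GEMM per group: (%d x %d) * (%d x %d)\n", out_c / groups,
              in_c / groups * k * k, in_c / groups * k * k, out_h * out_w);
  return 0;
}

MatMulWithBn then applies new_scale and new_bias per output channel while writing the GEMM result, which is why no separate batch-norm pass is needed on this path.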
diff --git a/src/operators/kernel/central-arm-func/conv_bn_relu_arm_func.h b/src/operators/kernel/central-arm-func/conv_bn_relu_arm_func.h
deleted file mode 100644
index 7eeb7f76670aa5c5a39544484ac92e611ff9066a..0000000000000000000000000000000000000000
--- a/src/operators/kernel/central-arm-func/conv_bn_relu_arm_func.h
+++ /dev/null
@@ -1,145 +0,0 @@
-/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License. */
-
-#ifdef FUSION_CONVBNRELU_OP
-
-#pragma once
-#include <vector>
-#include "operators/math/depthwise_conv3x3.h"
-#include "operators/math/im2col.h"
-#include "operators/math/math_function.h"
-#include "operators/math/vol2col.h"
-#include "operators/op_param.h"
-
-namespace paddle_mobile {
-namespace operators {
-
-void ConvBNReluBasic(const FusionConvBNReluParam<CPU> &param) {
-  const Tensor *input = param.Input();
-  Tensor filter = *param.Filter();
-  Tensor new_bias = *param.NewBias();
-  Tensor new_scale = *param.NewScale();
-
-  Tensor *output = param.Output();
-  output->mutable_data<float>();
-
-  int groups = param.Groups();
-  std::vector<int> strides = param.Strides();
-  std::vector<int> paddings = param.Paddings();
-  std::vector<int> dilations = param.Dilations();
-
-  const int batch_size = static_cast<int>(input->dims()[0]);
-
-  std::vector<int64_t> filter_shape_vec(framework::vectorize(filter.dims()));
-
-  std::vector<int64_t> output_shape_vec(framework::vectorize(output->dims()));
-  size_t data_dim = filter_shape_vec.size() - 2;
-  std::vector<int64_t> col_shape_vec(1 + 2 * data_dim);
-  col_shape_vec[0] = input->dims()[1] / groups;
-  for (size_t j = 0; j < data_dim; ++j) {
-    col_shape_vec[j + 1] = filter_shape_vec[j + 2];
-    col_shape_vec[j + 1 + data_dim] = output_shape_vec[j + 2];
-  }
-  framework::DDim col_shape(framework::make_ddim(col_shape_vec));
-
-  framework::DDim col_matrix_shape =
-      framework::flatten_to_2d(col_shape, data_dim + 1);
-
-  bool is_expand =
-      math::IsExpand(filter_shape_vec, strides, paddings, dilations);
-  Tensor col;
-  Tensor col_matrix;
-  if (is_expand) {
-    col.mutable_data<float>(col_shape);
-    col_matrix.ShareDataWith(col);
-    col_matrix.Resize(col_matrix_shape);
-  }
-
-  framework::DDim input_shape = framework::slice_ddim(
-      input->dims(), 1, static_cast<int>(input->dims().size()));
-
-  framework::DDim filter_matrix_shape = {filter.dims()[0],
-                                         filter.numel() / filter.dims()[0]};
-  filter.Resize(filter_matrix_shape);
-  framework::DDim output_matrix_shape = {
-      output->dims()[1],
-      output->numel() / (output->dims()[0] * output->dims()[1])};
-
-  // convolution operator: im2col(or vol2col) + gemm
-  int in_step = static_cast<int>(input->dims()[1]) / groups;
-  int out_step = static_cast<int>(output->dims()[1]) / groups;
-
-  math::Vol2ColFunctor<CPU, float> vol2col;
-  math::Im2ColFunctor<math::ColFormat::kCFO, CPU, float> im2col;
-
-  for (int i = 0; i < batch_size; i++) {
-    Tensor in_batch = input->Slice(i, i + 1).Resize(input_shape);
-    Tensor out_batch = output->Slice(i, i + 1).Resize(output_matrix_shape);
-
-    for (int g = 0; g < groups; g++) {
-      Tensor in_slice = in_batch.Slice(g * in_step, (g + 1) * in_step);
-
-      if (!is_expand) {
-        col.ShareDataWith(in_slice);
-        col_matrix.ShareDataWith(col);
-        col_matrix.Resize(col_matrix_shape);
-      } else if (data_dim == 2U) {
-        // im2col
-        im2col(in_slice, dilations, strides,
-               std::vector<int>{paddings[0], paddings[1], paddings[0],
-                                paddings[1]},
-               &col);
-      } else if (data_dim == 3U) {
-        // vol2col
-        vol2col(in_slice, dilations, strides, paddings, &col);
-      }
-      // gemm
-      Tensor out_slice = out_batch.Slice(g * out_step, (g + 1) * out_step);
-      Tensor filter_slice = filter.Slice(g * out_step, (g + 1) * out_step);
-
-      math::MatMulWithBn(filter_slice, false, col_matrix, false,
-                         static_cast<float>(1), &out_slice,
-                         static_cast<float>(0), true, &new_scale, &new_bias, g);
-    }
-  }
-}
-
-template <typename P>
-void ConvBNReluCompute(const FusionConvBNReluParam<CPU> &param) {
-  if (param.Groups() == param.Input()->dims()[1] &&
-      param.Input()->dims()[1] == param.Output()->dims()[1] &&
-      param.Filter()->dims()[2] == param.Filter()->dims()[3] &&
-      param.Filter()->dims()[2] == 3 && param.Strides()[0] == 1) {
-    math::DepthwiseConvAddBNRelu3x3s1p1(param.Input(), param.Filter(),
-                                        param.Output(), param.NewScale(),
-                                        param.NewBias(), true);
-  } else if (param.Groups() == param.Input()->dims()[1] &&
-             param.Input()->dims()[1] == param.Output()->dims()[1] &&
-             param.Filter()->dims()[2] == param.Filter()->dims()[3] &&
-             param.Filter()->dims()[2] == 3 && param.Strides()[0] == 2) {
-    // math::DepthwiseConvAddBNRelu3x3s2p1(param.Input(), param.Filter(),
-    //                                     param.Output(), param.NewScale(),
-    //                                     param.NewBias(), 1);
-    math::DepthwiseConvAddBNRelu3x3s2p1v2(param.Input(), param.Filter(),
-                                          param.Output(), param.NewScale(),
-                                          param.NewBias(), true);
-  } else {
-    ConvBNReluBasic(param);
-  }
-}
-
-}  // namespace operators
-}  // namespace paddle_mobile
-
-#endif
diff --git a/src/operators/kernel/central-arm-func/dwconv_bn_relu_arm_func.h b/src/operators/kernel/central-arm-func/dwconv_bn_relu_arm_func.h
deleted file mode 100644
index e0299d00ae09de62c133676449f0148a49beae5e..0000000000000000000000000000000000000000
--- a/src/operators/kernel/central-arm-func/dwconv_bn_relu_arm_func.h
+++ /dev/null
@@ -1,143 +0,0 @@
-/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License. */
-
-#ifdef FUSION_DWCONVBNRELU_OP
-
-#pragma once
-#include <vector>
-#include "operators/math/depthwise_conv3x3.h"
-#include "operators/math/im2col.h"
-#include "operators/math/math_function.h"
-#include "operators/math/vol2col.h"
-#include "operators/op_param.h"
-
-namespace paddle_mobile {
-namespace operators {
-
-void DWConvBNReluBasic(const FusionDWConvBNReluParam<CPU> &param) {
-  const Tensor *input = param.Input();
-  Tensor filter = *param.Filter();
-  Tensor new_bias = *param.NewBias();
-  Tensor new_scale = *param.NewScale();
-
-  Tensor *output = param.Output();
-  output->mutable_data<float>();
-
-  int groups = param.Groups();
-  std::vector<int> strides = param.Strides();
-  std::vector<int> paddings = param.Paddings();
-  std::vector<int> dilations = param.Dilations();
-
-  const int batch_size = static_cast<int>(input->dims()[0]);
-
-  std::vector<int64_t> filter_shape_vec(framework::vectorize(filter.dims()));
-
-  std::vector<int64_t> output_shape_vec(framework::vectorize(output->dims()));
-  size_t data_dim = filter_shape_vec.size() - 2;
-  std::vector<int64_t> col_shape_vec(1 + 2 * data_dim);
-  col_shape_vec[0] = input->dims()[1] / groups;
-  for (size_t j = 0; j < data_dim; ++j) {
-    col_shape_vec[j + 1] = filter_shape_vec[j + 2];
-    col_shape_vec[j + 1 + data_dim] = output_shape_vec[j + 2];
-  }
-  framework::DDim col_shape(framework::make_ddim(col_shape_vec));
-
-  framework::DDim col_matrix_shape =
-      framework::flatten_to_2d(col_shape, data_dim + 1);
-
-  bool is_expand =
-      math::IsExpand(filter_shape_vec, strides, paddings, dilations);
-  Tensor col;
-  Tensor col_matrix;
-  if (is_expand) {
-    col.mutable_data<float>(col_shape);
-    col_matrix.ShareDataWith(col);
-    col_matrix.Resize(col_matrix_shape);
-  }
-
-  framework::DDim input_shape = framework::slice_ddim(
-      input->dims(), 1, static_cast<int>(input->dims().size()));
-
-  framework::DDim filter_matrix_shape = {filter.dims()[0],
-                                         filter.numel() / filter.dims()[0]};
-  filter.Resize(filter_matrix_shape);
-  framework::DDim output_matrix_shape = {
-      output->dims()[1],
-      output->numel() / (output->dims()[0] * output->dims()[1])};
-
-  // convolution operator: im2col(or vol2col) + gemm
-  int in_step = static_cast<int>(input->dims()[1]) / groups;
-  int out_step = static_cast<int>(output->dims()[1]) / groups;
-
-  math::Vol2ColFunctor<CPU, float> vol2col;
-  math::Im2ColFunctor<math::ColFormat::kCFO, CPU, float> im2col;
-
-  for (int i = 0; i < batch_size; i++) {
-    Tensor in_batch = input->Slice(i, i + 1).Resize(input_shape);
-    Tensor out_batch = output->Slice(i, i + 1).Resize(output_matrix_shape);
-
-    for (int g = 0; g < groups; g++) {
-      Tensor in_slice = in_batch.Slice(g * in_step, (g + 1) * in_step);
-
-      if (!is_expand) {
-        col.ShareDataWith(in_slice);
-        col_matrix.ShareDataWith(col);
-        col_matrix.Resize(col_matrix_shape);
-      } else if (data_dim == 2U) {
-        // im2col
-        im2col(in_slice, dilations, strides,
-               std::vector<int>{paddings[0], paddings[1], paddings[0],
-                                paddings[1]},
-               &col);
-      } else if (data_dim == 3U) {
-        // vol2col
-        vol2col(in_slice, dilations, strides, paddings, &col);
-      }
-      // gemm
-      Tensor out_slice = out_batch.Slice(g * out_step, (g + 1) * out_step);
-      Tensor filter_slice = filter.Slice(g * out_step, (g + 1) * out_step);
-      math::MatMulWithBn(filter_slice, false, col_matrix, false,
-                         static_cast<float>(1), &out_slice,
-                         static_cast<float>(0), true, &new_scale, &new_bias, g);
-    }
-  }
-}
-template <typename P>
-void DWConvBNReluCompute(const FusionDWConvBNReluParam<CPU> &param) {
-  if (param.Groups() == param.Input()->dims()[1] &&
-      param.Input()->dims()[1] == param.Output()->dims()[1] &&
-      param.Filter()->dims()[2] == param.Filter()->dims()[3] &&
-      param.Filter()->dims()[2] == 3 && param.Strides()[0] == 1) {
-    math::DepthwiseConvAddBNRelu3x3s1p1(param.Input(), param.Filter(),
-                                        param.Output(), param.NewScale(),
-                                        param.NewBias(), true);
-  } else if (param.Groups() == param.Input()->dims()[1] &&
-             param.Input()->dims()[1] == param.Output()->dims()[1] &&
-             param.Filter()->dims()[2] == param.Filter()->dims()[3] &&
-             param.Filter()->dims()[2] == 3 && param.Strides()[0] == 2) {
-    // math::DepthwiseConvAddBNRelu3x3s2p1(param.Input(), param.Filter(),
-    //                                     param.Output(), param.NewScale(),
-    //                                     param.NewBias(), 1);
-    math::DepthwiseConvAddBNRelu3x3s2p1v2(param.Input(), param.Filter(),
-                                          param.Output(), param.NewScale(),
-                                          param.NewBias(), true);
-  } else {
-    DWConvBNReluBasic(param);
-  }
-}
-
-}  // namespace operators
-}  // namespace paddle_mobile
-
-#endif
diff --git a/src/operators/math/conv_func.h b/src/operators/math/conv_func.h
index d9e2da0db5c50e0b0f9b11d5584bfce8b75777cd..40320dedac564f4c66ed1773e0c3f050b6e07144 100644
--- a/src/operators/math/conv_func.h
+++ b/src/operators/math/conv_func.h
@@ -14,12 +14,13 @@ limitations under the License. */
 
 #pragma once
 
+#include <vector>
 #ifdef __ARM_NEON
 #include <arm_neon.h>
 #endif
-
 #include "framework/ddim.h"
 #include "framework/tensor.h"
+#include "operators/math/activation.h"
 
 namespace paddle_mobile {
 namespace operators {
@@ -35,8 +36,8 @@ inline int ConvOutputSize(int input_size, int filter_size, int dilation,
   return output_size;
 }
 
-inline void expand_bias(Tensor &bias, int axis, const DDim &dDim) {
-  auto bias_ptr = bias.data<float>();
+inline void expand_bias(Tensor &bias, int axis, const DDim &dDim) {  // NOLINT
+  const auto bias_ptr = bias.data<float>();
   const DDim bias_ddim = bias.dims();
   PADDLE_MOBILE_ENFORCE(bias.dims().size() == 1,
                         "the bias tensor's dims size != 1")
@@ -98,6 +99,63 @@ inline bool IsExpand(const std::vector<int64_t> &filter_dim,
   return !(filter_1 && strides_1 && padding_0 && dilation_1);
 }
 
+template <ActivationType Act>
+void ScaleAddChannelWise(const framework::Tensor *input,
+                         const framework::Tensor *scale,
+                         const framework::Tensor *bias,
+                         framework::Tensor *output) {
+  const float *input_ptr = input->data<float>();
+  const float *scale_ptr = scale->data<float>();
+  const float *bias_ptr = bias->data<float>();
+  float *output_ptr = output->mutable_data<float>();
+  // maybe check shape
+  int batch_size = input->dims()[0];
+  int channels = input->dims()[1];
+  size_t spatial_size = input->dims()[2] * input->dims()[3];
+
+  for (int batch = 0; batch < batch_size; ++batch) {
+    for (int channel = 0; channel < channels; ++channel) {
+      size_t offset = (batch * channels + channel) * spatial_size;
+      const float *x = input_ptr + offset;
+      float *y = output_ptr + offset;
+      float alpha = scale_ptr[channel];
+      float beta = bias_ptr[channel];
+      int j = 0;
+#if defined(__ARM_NEON__) || defined(__ARM_NEON)
+      float32x4_t __scale = vdupq_n_f32(alpha);
+      float32x4_t __bias = vdupq_n_f32(beta);
+      for (; j < spatial_size - 15; j += 16, x += 16, y += 16) {
+        float32x4_t in0 = vld1q_f32(x);
+        float32x4_t in1 = vld1q_f32(x + 4);
+        float32x4_t in2 = vld1q_f32(x + 8);
+        float32x4_t in3 = vld1q_f32(x + 12);
+        in0 = vmlaq_f32(__bias, __scale, in0);
+        in1 = vmlaq_f32(__bias, __scale, in1);
+        in2 = vmlaq_f32(__bias, __scale, in2);
+        in3 = vmlaq_f32(__bias, __scale, in3);
+        in0 = math::vActiveq_f32<Act>(in0);
+        in1 = math::vActiveq_f32<Act>(in1);
+        in2 = math::vActiveq_f32<Act>(in2);
+        in3 = math::vActiveq_f32<Act>(in3);
+        vst1q_f32(y, in0);
+        vst1q_f32(y + 4, in1);
+        vst1q_f32(y + 8, in2);
+        vst1q_f32(y + 12, in3);
+      }
+      for (; j < spatial_size - 3; j += 4, x += 4, y += 4) {
+        float32x4_t in0 = vld1q_f32(x);
+        in0 = vmlaq_f32(__bias, __scale, in0);
+        in0 = math::vActiveq_f32<Act>(in0);
+        vst1q_f32(y, in0);
+      }
+#endif
+      for (; j < spatial_size; ++j, ++x, ++y) {
+        *y = math::Active<Act>(alpha * (*x) + beta);
+      }
+    }
+  }
+}
+
 }  // namespace math
 }  // namespace operators
 }  // namespace paddle_mobile
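ScaleAddChannelWise is what lets the new dispatch run a plain convolution first and then apply the folded batch norm and activation in place: vmlaq_f32(a, b, c) computes a + b * c per lane, so each vector step is bias + scale * x followed by the activation, with the scalar tail loop handling the remainder. Usage as instantiated by the fused kernels above, where RELU is the ActivationType value they pass:

// In-place folded BN + ReLU once a convolution has filled param.Output().
math::ScaleAddChannelWise<RELU>(param.Output(), param.NewScale(),
                                param.NewBias(), param.Output());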
64 : 32; int L1 = L / max_threads * 1024; KC = k; - zero = static_cast(paddle_mobile::memory::Alloc(sizeof(float) * KC)); - memset(static_cast(zero), 0, sizeof(float) * KC); if (m > n) { // 对 A 分块 MC = L1 / (KC * sizeof(float)); @@ -3566,7 +3552,6 @@ void Gemm::Sgemm_omp(int m, int n, int k, float alpha, const float *A, int lda, paddle_mobile::memory::Free(packedA); paddle_mobile::memory::Free(packedB); paddle_mobile::memory::Free(packedC); - paddle_mobile::memory::Free(zero); } void Gemm::SgemmWithBn_omp(int m, int n, int k, float alpha, const float *A, @@ -3581,8 +3566,6 @@ void Gemm::SgemmWithBn_omp(int m, int n, int k, float alpha, const float *A, int L1 = 64 / max_threads * 1024; KC = k; - zero = static_cast(paddle_mobile::memory::Alloc(sizeof(float) * KC)); - memset(static_cast(zero), 0, sizeof(float) * KC); if (m > n) { // 对 A 分块 MC = L1 / (KC * sizeof(float)); @@ -3694,7 +3677,6 @@ void Gemm::SgemmWithBn_omp(int m, int n, int k, float alpha, const float *A, paddle_mobile::memory::Free(packedA); paddle_mobile::memory::Free(packedB); paddle_mobile::memory::Free(packedC); - paddle_mobile::memory::Free(zero); } void Gemm::SgemmWithPRelu_omp(int m, int n, int k, const float *A, int lda, @@ -3709,8 +3691,6 @@ void Gemm::SgemmWithPRelu_omp(int m, int n, int k, const float *A, int lda, int L1 = 8 * 1024; KC = k; - zero = static_cast(paddle_mobile::memory::Alloc(sizeof(float) * KC)); - memset(static_cast(zero), 0, sizeof(float) * KC); if (m > n) { // 对 A 分块 MC = L1 / (KC * sizeof(float)); @@ -3820,7 +3800,6 @@ void Gemm::SgemmWithPRelu_omp(int m, int n, int k, const float *A, int lda, paddle_mobile::memory::Free(packedA); paddle_mobile::memory::Free(packedB); paddle_mobile::memory::Free(packedC); - paddle_mobile::memory::Free(zero); } } // namespace math diff --git a/src/operators/math/gemm.h b/src/operators/math/gemm.h index 3fc418003c7faa804c0f7a146b1f9108e0b01789..113e04fe3c94d4f9edbfab0520ca881b9cbab4e7 100644 --- a/src/operators/math/gemm.h +++ b/src/operators/math/gemm.h @@ -260,7 +260,6 @@ class Gemm { float *packedA; float *packedB; float *packedC; - float *zero; // 8 bits int int8_t *packedA_int8; diff --git a/src/operators/math/gemm/cblas.cc b/src/operators/math/gemm/cblas.cc new file mode 100644 index 0000000000000000000000000000000000000000..5fb9e290b30eaf99262ed6da0fade7e6a832ec6c --- /dev/null +++ b/src/operators/math/gemm/cblas.cc @@ -0,0 +1,49 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
*/ + +#pragma once + +#include "operators/math/gemm/cblas.h" +#include "operators/math/gemm/cpu_info.h" +#include "operators/math/gemm/executor.h" +#include "operators/math/gemm/strategy.h" + +namespace paddle_mobile { +namespace operators { +namespace math { + +void cblas_sgemm(const bool transA, const bool transB, const int M, const int N, + const int K, const float alpha, const float *A, const int lda, + const float *B, const int ldb, const float beta, float *C, + const int ldc) { + if (N == 1) { + return cblas_sgemv(transA, M, K, alpha, A, lda, B, beta, C); + } + + CPUInfo *info = CPUInfo::Info(); + GemmExecutor<SgemmStrategy> exec(info, transA, transB, M, N, K); + exec(alpha, A, lda, B, ldb, beta, C, ldc); +} + +void cblas_sgemv(const bool trans, const int M, const int N, const float alpha, + const float *A, const int lda, const float *B, + const float beta, float *C) { + CPUInfo *info = CPUInfo::Info(); + GemvExecutor<SgemvStrategy> exec(info, trans, M, N); + exec(alpha, A, lda, B, beta, C); +} + +} // namespace math +} // namespace operators +} // namespace paddle_mobile diff --git a/src/operators/math/gemm/cblas.h b/src/operators/math/gemm/cblas.h new file mode 100644 index 0000000000000000000000000000000000000000..c7c9201869f56a7d339cccfcb3d898a4751836a6 --- /dev/null +++ b/src/operators/math/gemm/cblas.h @@ -0,0 +1,32 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +namespace paddle_mobile { +namespace operators { +namespace math { + +void cblas_sgemm(const bool transA, const bool transB, const int M, const int N, + const int K, const float alpha, const float *A, const int lda, + const float *B, const int ldb, const float beta, float *C, + const int ldc); + +void cblas_sgemv(const bool trans, const int M, const int N, const float alpha, + const float *A, const int lda, const float *B, + const float beta, float *C); + +} // namespace math +} // namespace operators +} // namespace paddle_mobile diff --git a/src/operators/math/gemm/cpu_info.h b/src/operators/math/gemm/cpu_info.h new file mode 100644 index 0000000000000000000000000000000000000000..54975797c782be1964c562f38fd12edbcd6a2f0e --- /dev/null +++ b/src/operators/math/gemm/cpu_info.h @@ -0,0 +1,55 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#define MOBILE_MAX_CPU_NUM 8 + +namespace paddle_mobile { +namespace operators { +namespace math { + +struct CPUInfo { + private: + CPUInfo() { + // TODO(hjchen2) + num_cpus = 4; + for (int i = 0; i < num_cpus; ++i) { + cpu_frequency[i] = 2400; // 2400 MHz + max_cpu_frequency[i] = 2400; // 2400 MHz + } + // L1_cache = 32000; // 32K + L1_cache = 32 * 1024; + L2_cache = 2000000; // 2M + // L2_cache = 512000; + } + virtual ~CPUInfo() {} + + public: + static CPUInfo* Info() { + static CPUInfo* ctx = new CPUInfo; + return ctx; + } + + int num_cpus; + int cpu_frequency[MOBILE_MAX_CPU_NUM]; + int max_cpu_frequency[MOBILE_MAX_CPU_NUM]; + + int L1_cache; + int L2_cache; +}; + +} // namespace math +} // namespace operators +} // namespace paddle_mobile diff --git a/src/operators/math/gemm/executor.h b/src/operators/math/gemm/executor.h new file mode 100644 index 0000000000000000000000000000000000000000..9dcf8080196373d5cc7a2183f831370835288bfa --- /dev/null +++ b/src/operators/math/gemm/executor.h @@ -0,0 +1,208 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include <sys/time.h> +#ifdef _OPENMP +#include <omp.h> +#endif +#include <algorithm> +#include <iostream> +#include "common/log.h" +#include "memory/t_malloc.h" +#include "operators/math/gemm/cpu_info.h" +#include "operators/math/gemm/gemm_kernel.h" + +namespace paddle_mobile { +namespace operators { +namespace math { + +inline int CeilDiv(const int &x, const int &y) { return (x + y - 1) / y; } + +class Executor { + public: + Executor() : num_threads_(1) { +#ifdef _OPENMP + num_threads_ = omp_get_max_threads(); +#endif + } + virtual ~Executor() {} + + protected: + int num_threads_; +}; + +template <typename Strategy> +class GemmExecutor : public Executor { + typedef typename Strategy::Itype Itype; + typedef typename Strategy::Otype Otype; + + public: + GemmExecutor(const CPUInfo *info, const bool transA, const bool transB, + const int M, const int N, const int K) + : Executor(), + info_(info), + M_(M), + N_(N), + K_(K), + transA_(transA), + transB_(transB) { + unsigned int L1_size = info->L1_cache; + unsigned int L2_size = info->L2_cache; + // if (N_ > 10000) L1_size *= 2; + if (num_threads_ >= 2) L1_size /= 2; + + rhs_tile_num_ = L1_size / (K * sizeof(Itype)); + if (rhs_tile_num_ == 0) { + rhs_tile_num_ = Strategy::out_width(); + } else { + int n_block = CeilDiv(N, rhs_tile_num_); + rhs_tile_num_ = CeilDiv(N, n_block); + rhs_tile_num_ = CeilDiv(rhs_tile_num_, Strategy::out_width()); + rhs_tile_num_ *= Strategy::out_width(); + } + + // lhs_tile_num_ = CeilDiv(M, Strategy::out_height()) * + // Strategy::out_height(); + lhs_tile_num_ = L2_size / (K * sizeof(Itype)); + if (lhs_tile_num_ == 0) { + lhs_tile_num_ = Strategy::out_height(); + } else { + int m_block = CeilDiv(M, lhs_tile_num_); + lhs_tile_num_ = CeilDiv(M, m_block); + lhs_tile_num_ = CeilDiv(lhs_tile_num_, Strategy::out_height()); + lhs_tile_num_ *= Strategy::out_height(); + } + } + + void operator()(const float alpha, const Itype *A, const int lda, + const
Itype *B, const int ldb, const float beta, Otype *C, + const int ldc) { + // struct timeval tv_begin, tv_end; + // gettimeofday(&tv_begin,NULL); + + int mblock = CeilDiv(M_, Strategy::out_height()) * Strategy::out_height(); + lhs_worksize_ = sizeof(Itype) * mblock * K_; + rhs_worksize_ = sizeof(Itype) * K_ * rhs_tile_num_ * num_threads_; + out_worksize_ = sizeof(Otype) * mblock * rhs_tile_num_ * num_threads_; + + lhs_workspace_ = + static_cast(paddle_mobile::memory::Alloc(lhs_worksize_)); + rhs_workspace_ = + static_cast(paddle_mobile::memory::Alloc(rhs_worksize_)); + out_workspace_ = + static_cast(paddle_mobile::memory::Alloc(out_worksize_)); + + strategy_.pack_lhs(M_, K_, A, lda, lhs_workspace_, true); + + // std::cout << "M: " << M_ << ", N: " << N_ << ", K: " << K_ << + // std::endl; std::cout << "rhs_block: " << CeilDiv(N_, rhs_tile_num_) << + // std::endl; + + #pragma omp parallel for if (N_ > 128) + for (int rhs_block = 0; rhs_block < N_; rhs_block += rhs_tile_num_) { + int rhs_range = std::min(N_ - rhs_block, rhs_tile_num_); +#ifdef _OPENMP + int thread_id = omp_get_thread_num(); +#else + int thread_id = 0; +#endif + float *local_B = rhs_workspace_ + K_ * rhs_tile_num_ * thread_id; + float *local_C = + out_workspace_ + lhs_tile_num_ * rhs_tile_num_ * thread_id; + // load rhs into rhs_workspace + strategy_.pack_rhs(K_, rhs_range, B + rhs_block, ldb, local_B, false); + for (int lhs_block = 0; lhs_block < M_; lhs_block += lhs_tile_num_) { + int lhs_range = std::min(M_ - lhs_block, lhs_tile_num_); + float *local_A = lhs_workspace_ + lhs_block * lda; + for (int lhs_tile = 0; lhs_tile < lhs_range; + lhs_tile += Strategy::out_height()) { + for (int rhs_tile = 0; rhs_tile < rhs_range; + rhs_tile += Strategy::out_width()) { + int offset = (lhs_block + lhs_tile) * rhs_tile_num_ + rhs_tile; + strategy_.kernel(local_A + lhs_tile * K_, local_B + rhs_tile * K_, + K_, local_C + offset, rhs_tile_num_); + } + } + } + strategy_.write(M_, rhs_range, local_C, rhs_tile_num_, C + rhs_block, + ldc); + } + + paddle_mobile::memory::Free(lhs_workspace_); + paddle_mobile::memory::Free(rhs_workspace_); + paddle_mobile::memory::Free(out_workspace_); + + // gettimeofday(&tv_end,NULL); + // float elapsed = (tv_end.tv_sec - tv_begin.tv_sec) * 1000.f + + // (tv_end.tv_usec - tv_begin.tv_usec) / 1000.f; std::cout << "elapsed: " + // << elapsed << "ms, speed: " << (M_ * N_ * K_ / 1000.f / 1000.f) / + // elapsed << " gflops" << std::endl; + } + + virtual ~GemmExecutor() {} + + private: + const CPUInfo *info_; + + const unsigned int M_; + const unsigned int N_; + const unsigned int K_; + const bool transA_; + const bool transB_; + + unsigned int lhs_tile_num_ = 0; + unsigned int rhs_tile_num_ = 0; + unsigned int out_tile_num_ = 0; + + unsigned int lhs_worksize_ = 0; + unsigned int rhs_worksize_ = 0; + unsigned int out_worksize_ = 0; + + Itype *lhs_workspace_ = nullptr; + Itype *rhs_workspace_ = nullptr; + Otype *out_workspace_ = nullptr; + + Strategy strategy_; +}; + +template +class GemvExecutor : public Executor { + typedef typename Strategy::Itype Itype; + typedef typename Strategy::Otype Otype; + + public: + GemvExecutor(const CPUInfo *info, const bool transA, const int M, const int N) + : Executor(), info_(info), M_(M), N_(N) {} + + void operator()(const float alpha, const Itype *A, const int lda, + const Itype *B, const float beta, Otype *C) { + // strategy_.kernel(); + } + + virtual ~GemvExecutor() {} + + private: + const CPUInfo *const info_; + + const unsigned int M_; + const unsigned int N_; + + Strategy 
strategy_; +}; + +} // namespace math +} // namespace operators +} // namespace paddle_mobile diff --git a/src/operators/math/gemm/gemm_kernel.h b/src/operators/math/gemm/gemm_kernel.h new file mode 100644 index 0000000000000000000000000000000000000000..78489ed793664b57af0416910d1d509e0a53553d --- /dev/null +++ b/src/operators/math/gemm/gemm_kernel.h @@ -0,0 +1,247 @@ +/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#ifdef __ARM_NEON__ + +#include + +namespace paddle_mobile { +namespace operators { +namespace math { + +#ifdef __aarch64__ +void sgemm_12x8(const float *lhs, const float *rhs, const int k, float *output, + const int ldc) { + // TODO(hjchen2) +} +#else +void sgemm_6x8(const float *lhs, const float *rhs, const int k, float *output, + const int ldc) { + int kc1 = k >> 3; // k / 8 + int kc2 = k & 0x7; // k % 8 + int step = sizeof(float) * ldc; + asm volatile( + "pld [%[lhs]] \n\t" + "pld [%[lhs], #64] \n\t" + "pld [%[rhs]] \n\t" + "pld [%[rhs], #64] \n\t" + + "vmov.f32 q4, #0.0 \n\t" + "vmov.f32 q5, #0.0 \n\t" + "vmov.f32 q6, #0.0 \n\t" + "vmov.f32 q7, #0.0 \n\t" + "vmov.f32 q8, #0.0 \n\t" + "vmov.f32 q9, #0.0 \n\t" + "vmov.f32 q10, #0.0 \n\t" + "vmov.f32 q11, #0.0 \n\t" + "vmov.f32 q12, #0.0 \n\t" + "vmov.f32 q13, #0.0 \n\t" + "vmov.f32 q14, #0.0 \n\t" + "vmov.f32 q15, #0.0 \n\t" + + "subs %[kc1], %[kc1], #1 \n\t" + "blt 2f \n\t" + "1: \n\t" + + "pld [%[lhs], #128] \n\t" + "pld [%[rhs], #128] \n\t" + + "vld1.32 {d0-d2}, [%[lhs]]! \n\t" + "vld1.32 {q2, q3}, [%[rhs]]! \n\t" + + "vmla.f32 q4, q2, d0[0] \n\t" + "vmla.f32 q5, q3, d0[0] \n\t" + "vmla.f32 q6, q2, d0[1] \n\t" + "vmla.f32 q7, q3, d0[1] \n\t" + "vmla.f32 q8, q2, d1[0] \n\t" + "vmla.f32 q9, q3, d1[0] \n\t" + "vmla.f32 q10, q2, d1[1] \n\t" + "vmla.f32 q11, q3, d1[1] \n\t" + "vmla.f32 q12, q2, d2[0] \n\t" + "vmla.f32 q13, q3, d2[0] \n\t" + "vmla.f32 q14, q2, d2[1] \n\t" + "vmla.f32 q15, q3, d2[1] \n\t" + + "vld1.32 {d0-d2}, [%[lhs]]! \n\t" + "vld1.32 {q2, q3}, [%[rhs]]! \n\t" + + "vmla.f32 q4, q2, d0[0] \n\t" + "vmla.f32 q5, q3, d0[0] \n\t" + "vmla.f32 q6, q2, d0[1] \n\t" + "vmla.f32 q7, q3, d0[1] \n\t" + "vmla.f32 q8, q2, d1[0] \n\t" + "vmla.f32 q9, q3, d1[0] \n\t" + "vmla.f32 q10, q2, d1[1] \n\t" + "vmla.f32 q11, q3, d1[1] \n\t" + "vmla.f32 q12, q2, d2[0] \n\t" + "vmla.f32 q13, q3, d2[0] \n\t" + "vmla.f32 q14, q2, d2[1] \n\t" + "vmla.f32 q15, q3, d2[1] \n\t" + + "pld [%[lhs], #128] \n\t" + "pld [%[rhs], #128] \n\t" + + "vld1.32 {d0-d2}, [%[lhs]]! \n\t" + "vld1.32 {q2, q3}, [%[rhs]]! \n\t" + + "vmla.f32 q4, q2, d0[0] \n\t" + "vmla.f32 q5, q3, d0[0] \n\t" + "vmla.f32 q6, q2, d0[1] \n\t" + "vmla.f32 q7, q3, d0[1] \n\t" + "vmla.f32 q8, q2, d1[0] \n\t" + "vmla.f32 q9, q3, d1[0] \n\t" + "vmla.f32 q10, q2, d1[1] \n\t" + "vmla.f32 q11, q3, d1[1] \n\t" + "vmla.f32 q12, q2, d2[0] \n\t" + "vmla.f32 q13, q3, d2[0] \n\t" + "vmla.f32 q14, q2, d2[1] \n\t" + "vmla.f32 q15, q3, d2[1] \n\t" + + "vld1.32 {d0-d2}, [%[lhs]]! \n\t" + "vld1.32 {q2, q3}, [%[rhs]]! 
\n\t" + + "vmla.f32 q4, q2, d0[0] \n\t" + "vmla.f32 q5, q3, d0[0] \n\t" + "vmla.f32 q6, q2, d0[1] \n\t" + "vmla.f32 q7, q3, d0[1] \n\t" + "vmla.f32 q8, q2, d1[0] \n\t" + "vmla.f32 q9, q3, d1[0] \n\t" + "vmla.f32 q10, q2, d1[1] \n\t" + "vmla.f32 q11, q3, d1[1] \n\t" + "vmla.f32 q12, q2, d2[0] \n\t" + "vmla.f32 q13, q3, d2[0] \n\t" + "vmla.f32 q14, q2, d2[1] \n\t" + "vmla.f32 q15, q3, d2[1] \n\t" + + "pld [%[lhs], #128] \n\t" + "pld [%[rhs], #128] \n\t" + + "vld1.32 {d0-d2}, [%[lhs]]! \n\t" + "vld1.32 {q2, q3}, [%[rhs]]! \n\t" + + "vmla.f32 q4, q2, d0[0] \n\t" + "vmla.f32 q5, q3, d0[0] \n\t" + "vmla.f32 q6, q2, d0[1] \n\t" + "vmla.f32 q7, q3, d0[1] \n\t" + "vmla.f32 q8, q2, d1[0] \n\t" + "vmla.f32 q9, q3, d1[0] \n\t" + "vmla.f32 q10, q2, d1[1] \n\t" + "vmla.f32 q11, q3, d1[1] \n\t" + "vmla.f32 q12, q2, d2[0] \n\t" + "vmla.f32 q13, q3, d2[0] \n\t" + "vmla.f32 q14, q2, d2[1] \n\t" + "vmla.f32 q15, q3, d2[1] \n\t" + + "vld1.32 {d0-d2}, [%[lhs]]! \n\t" + "vld1.32 {q2, q3}, [%[rhs]]! \n\t" + + "vmla.f32 q4, q2, d0[0] \n\t" + "vmla.f32 q5, q3, d0[0] \n\t" + "vmla.f32 q6, q2, d0[1] \n\t" + "vmla.f32 q7, q3, d0[1] \n\t" + "vmla.f32 q8, q2, d1[0] \n\t" + "vmla.f32 q9, q3, d1[0] \n\t" + "vmla.f32 q10, q2, d1[1] \n\t" + "vmla.f32 q11, q3, d1[1] \n\t" + "vmla.f32 q12, q2, d2[0] \n\t" + "vmla.f32 q13, q3, d2[0] \n\t" + "vmla.f32 q14, q2, d2[1] \n\t" + "vmla.f32 q15, q3, d2[1] \n\t" + + "pld [%[lhs], #128] \n\t" + "pld [%[rhs], #128] \n\t" + + "vld1.32 {d0-d2}, [%[lhs]]! \n\t" + "vld1.32 {q2, q3}, [%[rhs]]! \n\t" + + "vmla.f32 q4, q2, d0[0] \n\t" + "vmla.f32 q5, q3, d0[0] \n\t" + "vmla.f32 q6, q2, d0[1] \n\t" + "vmla.f32 q7, q3, d0[1] \n\t" + "vmla.f32 q8, q2, d1[0] \n\t" + "vmla.f32 q9, q3, d1[0] \n\t" + "vmla.f32 q10, q2, d1[1] \n\t" + "vmla.f32 q11, q3, d1[1] \n\t" + "vmla.f32 q12, q2, d2[0] \n\t" + "vmla.f32 q13, q3, d2[0] \n\t" + "vmla.f32 q14, q2, d2[1] \n\t" + "vmla.f32 q15, q3, d2[1] \n\t" + + "vld1.32 {d0-d2}, [%[lhs]]! \n\t" + "vld1.32 {q2, q3}, [%[rhs]]! \n\t" + + "vmla.f32 q4, q2, d0[0] \n\t" + "vmla.f32 q5, q3, d0[0] \n\t" + "vmla.f32 q6, q2, d0[1] \n\t" + "vmla.f32 q7, q3, d0[1] \n\t" + "vmla.f32 q8, q2, d1[0] \n\t" + "vmla.f32 q9, q3, d1[0] \n\t" + "vmla.f32 q10, q2, d1[1] \n\t" + "vmla.f32 q11, q3, d1[1] \n\t" + "vmla.f32 q12, q2, d2[0] \n\t" + "vmla.f32 q13, q3, d2[0] \n\t" + "vmla.f32 q14, q2, d2[1] \n\t" + "vmla.f32 q15, q3, d2[1] \n\t" + + "subs %[kc1], %[kc1], #1 \n\t" + "bge 1b \n\t" + "2: \n\t" + + "subs %[kc2], %[kc2], #1 \n\t" + "blt 4f \n\t" + "3: \n\t" + + "vld1.32 {d0-d2}, [%[lhs]]! \n\t" + "vld1.32 {q2, q3}, [%[rhs]]! 
\n\t" + + "vmla.f32 q4, q2, d0[0] \n\t" + "vmla.f32 q5, q3, d0[0] \n\t" + "vmla.f32 q6, q2, d0[1] \n\t" + "vmla.f32 q7, q3, d0[1] \n\t" + "vmla.f32 q8, q2, d1[0] \n\t" + "vmla.f32 q9, q3, d1[0] \n\t" + "vmla.f32 q10, q2, d1[1] \n\t" + "vmla.f32 q11, q3, d1[1] \n\t" + "vmla.f32 q12, q2, d2[0] \n\t" + "vmla.f32 q13, q3, d2[0] \n\t" + "vmla.f32 q14, q2, d2[1] \n\t" + "vmla.f32 q15, q3, d2[1] \n\t" + + "subs %[kc2], %[kc2], #1 \n\t" + "bge 3b \n\t" + "4: \n\t" + + "mov r5, %[c] \n\t" + "mov r6, %[step] \n\t" + "vst1.32 {q4, q5}, [r5], r6 \n\t" + "vst1.32 {q6, q7}, [r5], r6 \n\t" + "vst1.32 {q8, q9}, [r5], r6 \n\t" + "vst1.32 {q10, q11}, [r5], r6 \n\t" + "vst1.32 {q12, q13}, [r5], r6 \n\t" + "vst1.32 {q14, q15}, [r5] \n\t" + : + : [lhs] "r"(lhs), [rhs] "r"(rhs), [c] "r"(output), [kc1] "r"(kc1), + [kc2] "r"(kc2), [step] "r"(step) + : "cc", "memory", "r5", "r6", "q0", "q1", "q2", "q3", "q4", "q5", "q6", + "q7", "q8", "q9", "q10", "q11", "q12", "q13", "q14", "q15"); +} +#endif // __aarch64__ + +} // namespace math +} // namespace operators +} // namespace paddle_mobile + +#endif // __ARM_NEON__ diff --git a/src/operators/math/gemm/pack_kernel.h b/src/operators/math/gemm/pack_kernel.h new file mode 100644 index 0000000000000000000000000000000000000000..48d9ed651604afc2e3dce492b026c47424ad0ea5 --- /dev/null +++ b/src/operators/math/gemm/pack_kernel.h @@ -0,0 +1,658 @@ +/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
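pack_kernel.h below supplies the panel packing the micro-kernels consume: pack_lhs_6r interleaves A into 6-row panels (6 consecutive floats per k step, transposed with vtrn/vswp), pack_rhs_8c copies B into 8-column panels, and ragged edges are zeroed through the vand/vbif masks so the kernel never reads past the matrix. A scalar model of the rhs layout, inferred from the destination arithmetic (output + j * k + 8 * i):

// Reference-only sketch: panel-major over 8-column strips of B (k x n,
// row stride ldb), row-major inside each panel, tail columns zero-padded.
void pack_rhs_8c_ref(int k, int n, const float *B, int ldb, float *out) {
  for (int j = 0; j < n; j += 8) {
    for (int i = 0; i < k; ++i) {
      for (int c = 0; c < 8; ++c) {
        *out++ = (j + c < n) ? B[i * ldb + j + c] : 0.f;
      }
    }
  }
}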
*/ + +#pragma once + +#ifdef __ARM_NEON__ + +#include +#ifdef _OPENMP +#include +#endif + +namespace paddle_mobile { +namespace operators { +namespace math { + +inline float32x4_t vandq_f32_u32(float32x4_t x, uint32x4_t mask) { + return vreinterpretq_f32_u32(vandq_u32(vreinterpretq_u32_f32(x), mask)); +} + +void pack_lhs_12r(const int m, const int k, const float *A, const int lda, + float *output, const bool parallel) { + // TODO(hjchen2) +} + +void pack_lhs_6r(const int m, const int k, const float *A, const int lda, + float *output, const bool parallel) { + uint32_t mask[8] = {0, 1, 2, 3, 4, 5, 4, 5}; + int remain_k = k & 0x3; + uint32x4_t vzero = vdupq_n_u32(0); + uint32x4_t vmask1 = vcltq_u32(vld1q_u32(mask), vdupq_n_u32(remain_k)); + + #pragma omp parallel for if (parallel) + for (int i = 0; i < m - 5; i += 6) { + const float *a0 = A + i * lda; + const float *a1 = A + (i + 1) * lda; + const float *a2 = A + (i + 2) * lda; + const float *a3 = A + (i + 3) * lda; + const float *a4 = A + (i + 4) * lda; + const float *a5 = A + (i + 5) * lda; + float *out_ptr = output + i * k; + + int loops = k >> 2; + if (loops > 0) { +#if __aarch64__ + for (int l = 0; l < loops; ++l) { + float32x4_t _d0 = vld1q_f32(a0); + float32x4_t _d1 = vld1q_f32(a1); + float32x4_t _d2 = vld1q_f32(a2); + float32x4_t _d3 = vld1q_f32(a3); + float32x4_t _d4 = vld1q_f32(a4); + float32x4_t _d5 = vld1q_f32(a5); + + float32x4x2_t _q0 = vtrnq_f32(_d0, _d1); + float32x4x2_t _q1 = vtrnq_f32(_d2, _d3); + float32x4x2_t _q3 = vtrnq_f32(_d4, _d5); + _d0 = vcombine_f32(vget_low_f32(_q0.val[0]), vget_low_f32(_q1.val[0])); + _d1 = vcombine_f32(vget_low_f32(_q0.val[1]), vget_low_f32(_q1.val[1])); + _d2 = + vcombine_f32(vget_high_f32(_q0.val[0]), vget_high_f32(_q1.val[0])); + _d3 = + vcombine_f32(vget_high_f32(_q0.val[1]), vget_high_f32(_q1.val[1])); + + vst1q_f32(out_ptr, _d0); + vst1_f32(out_ptr + 4, vget_low_f32(_q3.val[0])); + vst1q_f32(out_ptr + 6, _d1); + vst1_f32(out_ptr + 10, vget_low_f32(_q3.val[1])); + vst1q_f32(out_ptr + 12, _d2); + vst1_f32(out_ptr + 16, vget_high_f32(_q3.val[0])); + vst1q_f32(out_ptr + 18, _d3); + vst1_f32(out_ptr + 22, vget_high_f32(_q3.val[1])); + + a0 += 4; + a1 += 4; + a2 += 4; + a3 += 4; + a4 += 4; + a5 += 4; + out_ptr += 24; + } +#else + asm volatile( + "loop_4k_%=: \n" + "vld1.32 {d0-d1}, [%[a0]]! \n" + "vld1.32 {d2-d3}, [%[a1]]! \n" + "vld1.32 {d4-d5}, [%[a2]]! \n" + "vld1.32 {d6-d7}, [%[a3]]! \n" + "vld1.32 {d8-d9}, [%[a4]]! \n" + "vld1.32 {d10-d11}, [%[a5]]! \n" + "vtrn.32 q0, q1 \n" + "vtrn.32 q2, q3 \n" + "vtrn.32 q4, q5 \n" + "vswp.32 d1, d4 \n" + "vswp.32 d3, d6 \n" + + "vst1.32 {q0}, [%[out]]! \n" + "vst1.32 {d8}, [%[out]]! \n" + "vst1.32 {q1}, [%[out]]! \n" + "vst1.32 {d10}, [%[out]]! \n" + "vst1.32 {q2}, [%[out]]! \n" + "vst1.32 {d9}, [%[out]]! \n" + "vst1.32 {q3}, [%[out]]! \n" + "vst1.32 {d11}, [%[out]]! 
\n" + + "subs %[loops], #1 \n" + "bne loop_4k_%= \n" + : [out] "+r"(out_ptr), [a0] "+r"(a0), [a1] "+r"(a1), [a2] "+r"(a2), + [a3] "+r"(a3), [a4] "+r"(a4), [a5] "+r"(a5), [loops] "+r"(loops) + : + : "cc", "memory", "q0", "q1", "q2", "q3", "q4", "q5"); +#endif + } + + if (remain_k > 0) { + float32x4_t _d0 = vld1q_f32(a0); + float32x4_t _d1 = vld1q_f32(a1); + float32x4_t _d2 = vld1q_f32(a2); + float32x4_t _d3 = vld1q_f32(a3); + float32x4_t _d4 = vld1q_f32(a4); + float32x4_t _d5 = vld1q_f32(a5); + + _d0 = vandq_f32_u32(_d0, vmask1); + _d1 = vandq_f32_u32(_d1, vmask1); + _d2 = vandq_f32_u32(_d2, vmask1); + _d3 = vandq_f32_u32(_d3, vmask1); + _d4 = vandq_f32_u32(_d4, vmask1); + _d5 = vandq_f32_u32(_d5, vmask1); + + float32x4x2_t _q0 = vtrnq_f32(_d0, _d1); + float32x4x2_t _q1 = vtrnq_f32(_d2, _d3); + float32x4x2_t _q3 = vtrnq_f32(_d4, _d5); + _d0 = vcombine_f32(vget_low_f32(_q0.val[0]), vget_low_f32(_q1.val[0])); + _d1 = vcombine_f32(vget_low_f32(_q0.val[1]), vget_low_f32(_q1.val[1])); + _d2 = vcombine_f32(vget_high_f32(_q0.val[0]), vget_high_f32(_q1.val[0])); + + switch (remain_k) { + case 3: + vst1q_f32(out_ptr + 12, _d2); + vst1_f32(out_ptr + 16, vget_high_f32(_q3.val[0])); + case 2: + vst1q_f32(out_ptr + 6, _d1); + vst1_f32(out_ptr + 10, vget_low_f32(_q3.val[1])); + case 1: + vst1q_f32(out_ptr, _d0); + vst1_f32(out_ptr + 4, vget_low_f32(_q3.val[0])); + default: + break; + } + } + } + + int remain_m = m % 6; + if (remain_m) { + int remain_m_start = m - remain_m; + const float *a0 = A + remain_m_start * lda; + const float *a1 = a0 + lda; + const float *a2 = a0 + 2 * lda; + const float *a3 = a0 + 3 * lda; + const float *a4 = a0 + 4 * lda; + const float *a5 = a0 + 5 * lda; + float *out_ptr = output + remain_m_start * k; + + uint32x4_t vmask2 = vcltq_u32(vld1q_u32(mask), vdupq_n_u32(remain_m)); + uint32x4_t vmask3 = vcltq_u32(vld1q_u32(mask + 4), vdupq_n_u32(remain_m)); + + int loops = k >> 2; + if (loops > 0) { +#if __aarch64__ + for (int l = 0; l < loops; ++l) { + float32x4_t _d0 = vld1q_f32(a0); + float32x4_t _d1 = vld1q_f32(a1); + float32x4_t _d2 = vld1q_f32(a2); + float32x4_t _d3 = vld1q_f32(a3); + float32x4_t _d4 = vld1q_f32(a4); + float32x4_t _d5 = vld1q_f32(a5); + + float32x4x2_t _q0 = vtrnq_f32(_d0, _d1); + float32x4x2_t _q1 = vtrnq_f32(_d2, _d3); + float32x4x2_t _q3 = vtrnq_f32(_d4, _d5); + _d0 = vcombine_f32(vget_low_f32(_q0.val[0]), vget_low_f32(_q1.val[0])); + _d1 = vcombine_f32(vget_low_f32(_q0.val[1]), vget_low_f32(_q1.val[1])); + _d2 = + vcombine_f32(vget_high_f32(_q0.val[0]), vget_high_f32(_q1.val[0])); + _d3 = + vcombine_f32(vget_high_f32(_q0.val[1]), vget_high_f32(_q1.val[1])); + + _d0 = vandq_f32_u32(_d0, vmask2); + _d1 = vandq_f32_u32(_d1, vmask2); + _d2 = vandq_f32_u32(_d2, vmask2); + _d3 = vandq_f32_u32(_d3, vmask2); + _d4 = vandq_f32_u32(_q3.val[0], vmask3); + _d5 = vandq_f32_u32(_q3.val[1], vmask3); + + vst1q_f32(out_ptr, _d0); + vst1_f32(out_ptr + 4, vget_low_f32(_d4)); + vst1q_f32(out_ptr + 6, _d1); + vst1_f32(out_ptr + 10, vget_low_f32(_d5)); + vst1q_f32(out_ptr + 12, _d2); + vst1_f32(out_ptr + 16, vget_high_f32(_d4)); + vst1q_f32(out_ptr + 18, _d3); + vst1_f32(out_ptr + 22, vget_high_f32(_d5)); + + a0 += 4; + a1 += 4; + a2 += 4; + a3 += 4; + a4 += 4; + a5 += 4; + out_ptr += 24; + } +#else + asm volatile( + "loop_4k_%=: \n" + "vld1.32 {d0-d1}, [%[a0]]! \n" + "vld1.32 {d2-d3}, [%[a1]]! \n" + "vld1.32 {d4-d5}, [%[a2]]! \n" + "vld1.32 {d6-d7}, [%[a3]]! \n" + "vld1.32 {d8-d9}, [%[a4]]! \n" + "vld1.32 {d10-d11}, [%[a5]]! 
\n" + "vtrn.32 q0, q1 \n" + "vtrn.32 q2, q3 \n" + "vtrn.32 q4, q5 \n" + "vswp.32 d1, d4 \n" + "vswp.32 d3, d6 \n" + + "vbif q0, %q[vzero], %q[vmask2] \n" + "vbif q1, %q[vzero], %q[vmask2] \n" + "vbif q2, %q[vzero], %q[vmask2] \n" + "vbif q3, %q[vzero], %q[vmask2] \n" + "vbif q4, %q[vzero], %q[vmask3] \n" + "vbif q5, %q[vzero], %q[vmask3] \n" + + "vst1.32 {q0}, [%[out]]! \n" + "vst1.32 {d8}, [%[out]]! \n" + "vst1.32 {q1}, [%[out]]! \n" + "vst1.32 {d10}, [%[out]]! \n" + "vst1.32 {q2}, [%[out]]! \n" + "vst1.32 {d9}, [%[out]]! \n" + "vst1.32 {q3}, [%[out]]! \n" + "vst1.32 {d11}, [%[out]]! \n" + + "subs %[loops], #1 \n" + "bne loop_4k_%= \n" + : [out] "+r"(out_ptr), [a0] "+r"(a0), [a1] "+r"(a1), [a2] "+r"(a2), + [a3] "+r"(a3), [a4] "+r"(a4), [a5] "+r"(a5), [loops] "+r"(loops) + : [vmask2] "w"(vmask2), [vmask3] "w"(vmask3), [vzero] "w"(vzero) + : "cc", "memory", "q0", "q1", "q2", "q3", "q4", "q5"); +#endif + } + + if (remain_k > 0) { + float32x4_t _d0 = vld1q_f32(a0); + float32x4_t _d1 = vld1q_f32(a1); + float32x4_t _d2 = vld1q_f32(a2); + float32x4_t _d3 = vld1q_f32(a3); + float32x4_t _d4 = vld1q_f32(a4); + float32x4_t _d5 = vld1q_f32(a5); + + _d0 = vandq_f32_u32(_d0, vmask1); + _d1 = vandq_f32_u32(_d1, vmask1); + _d2 = vandq_f32_u32(_d2, vmask1); + _d3 = vandq_f32_u32(_d3, vmask1); + _d4 = vandq_f32_u32(_d4, vmask1); + _d5 = vandq_f32_u32(_d5, vmask1); + + float32x4x2_t _q0 = vtrnq_f32(_d0, _d1); + float32x4x2_t _q1 = vtrnq_f32(_d2, _d3); + float32x4x2_t _q3 = vtrnq_f32(_d4, _d5); + _d0 = vcombine_f32(vget_low_f32(_q0.val[0]), vget_low_f32(_q1.val[0])); + _d1 = vcombine_f32(vget_low_f32(_q0.val[1]), vget_low_f32(_q1.val[1])); + _d2 = vcombine_f32(vget_high_f32(_q0.val[0]), vget_high_f32(_q1.val[0])); + // _d3 = vcombine_f32(vget_high_f32(_q0.val[1]), + // vget_high_f32(_q1.val[1])); + + _d0 = vandq_f32_u32(_d0, vmask2); + _d1 = vandq_f32_u32(_d1, vmask2); + _d2 = vandq_f32_u32(_d2, vmask2); + // _d3 = vandq_f32_u32(_d3, vmask2); + _d4 = vandq_f32_u32(_q3.val[0], vmask3); + _d5 = vandq_f32_u32(_q3.val[1], vmask3); + + switch (remain_k) { + case 3: + vst1q_f32(out_ptr + 12, _d2); + vst1_f32(out_ptr + 16, vget_high_f32(_d4)); + case 2: + vst1q_f32(out_ptr + 6, _d1); + vst1_f32(out_ptr + 10, vget_low_f32(_d5)); + case 1: + vst1q_f32(out_ptr, _d0); + vst1_f32(out_ptr + 4, vget_low_f32(_d4)); + default: + break; + } + } + } +} + +void pack_rhs_8c(const int k, const int n, const float *B, const int ldb, + float *output, const bool parallel) { + #pragma omp parallel for if (parallel) + for (int i = 0; i < k - 3; i += 4) { + int j = 0; + for (; j < n - 15; j += 16) { + float *out_ptr0 = output + j * k + 8 * i; + float *out_ptr1 = out_ptr0 + 8 * k; + const float *b0 = B + i * ldb + j; + const float *b1 = b0 + ldb; + const float *b2 = b1 + ldb; + const float *b3 = b2 + ldb; +#if __aarch64__ + asm volatile( + "prfm pldl1keep, [%[b0]] \n" + "prfm pldl1keep, [%[b1]] \n" + "prfm pldl1keep, [%[b2]] \n" + "prfm pldl1keep, [%[b3]] \n" + + "ld1 {v0.4s, v1.4s}, [%[b0]], #32 \n" + "ld1 {v2.4s, v3.4s}, [%[b0]], #32 \n" + "ld1 {v4.4s, v5.4s}, [%[b1]], #32 \n" + "ld1 {v6.4s, v7.4s}, [%[b1]], #32 \n" + "ld1 {v8.4s, v9.4s}, [%[b2]], #32 \n" + "ld1 {v10.4s, v11.4s}, [%[b2]], #32 \n" + "ld1 {v12.4s, v13.4s}, [%[b3]], #32 \n" + "ld1 {v14.4s, v15.4s}, [%[b3]], #32 \n" + + "st1 {v0.4s, v1.4s}, [%[out_ptr0]], #32 \n" + "st1 {v4.4s, v5.4s}, [%[out_ptr0]], #32 \n" + "st1 {v8.4s, v9.4s}, [%[out_ptr0]], #32 \n" + "st1 {v12.4s, v13.4s}, [%[out_ptr0]], #32 \n" + "st1 {v2.4s, v3.4s}, [%[out_ptr1]], #32 \n" + "st1 {v6.4s, v7.4s}, 
[%[out_ptr1]], #32 \n" + "st1 {v10.4s, v11.4s}, [%[out_ptr1]], #32 \n" + "st1 {v14.4s, v15.4s}, [%[out_ptr1]], #32 \n" + : [out_ptr0] "+r"(out_ptr0), [out_ptr1] "+r"(out_ptr1), [b0] "+r"(b0) + : + : "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", + "v9", "v10", "v11", "v12", "v13", "v14", "v15"); +#else + asm volatile( + // "pld [%[b]] \n" + "vld1.32 {q0, q1}, [%[b0]]! \n" + "vld1.32 {q4, q5}, [%[b1]]! \n" + "vld1.32 {q2, q3}, [%[b0]]! \n" + "vld1.32 {q6, q7}, [%[b1]]! \n" + "vld1.32 {q8, q9}, [%[b2]]! \n" + "vld1.32 {q12, q13}, [%[b3]]! \n" + "vld1.32 {q10, q11}, [%[b2]]! \n" + "vld1.32 {q14, q15}, [%[b3]]! \n" + + "vst1.32 {q0, q1}, [%[out_ptr0]]! \n" + "vst1.32 {q4, q5}, [%[out_ptr0]]! \n" + "vst1.32 {q8, q9}, [%[out_ptr0]]! \n" + "vst1.32 {q12, q13}, [%[out_ptr0]]! \n" + "vst1.32 {q2, q3}, [%[out_ptr1]]! \n" + "vst1.32 {q6, q7}, [%[out_ptr1]]! \n" + "vst1.32 {q10, q11}, [%[out_ptr1]]! \n" + "vst1.32 {q14, q15}, [%[out_ptr1]]! \n" + : [out_ptr0] "+r"(out_ptr0), [out_ptr1] "+r"(out_ptr1), [b0] "+r"(b0), + [b1] "+r"(b1), [b2] "+r"(b2), [b3] "+r"(b3) + : + : "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", "q8", + "q9", "q10", "q11", "q12", "q13", "q14", "q15"); +#endif // __aarch64__ + } + + for (; j < n - 7; j += 8) { + float *out_ptr0 = output + j * k + 8 * i; + const float *b0 = B + i * ldb + j; + const float *b1 = b0 + ldb; + const float *b2 = b1 + ldb; + const float *b3 = b2 + ldb; +#if __aarch64__ + asm volatile( + "prfm pldl1keep, [%[b0]] \n" + "ld1 {v0.4s, v1.4s}, [%[b0]] \n" + "ld1 {v2.4s, v3.4s}, [%[b1]] \n" + "ld1 {v4.4s, v5.4s}, [%[b2]] \n" + "ld1 {v6.4s, v7.4s}, [%[b3]] \n" + + "st1 {v0.4s, v1.4s}, [%[out_ptr0]], #32 \n" + "st1 {v2.4s, v3.4s}, [%[out_ptr0]], #32 \n" + "st1 {v4.4s, v5.4s}, [%[out_ptr0]], #32 \n" + "st1 {v6.4s, v7.4s}, [%[out_ptr0]], #32 \n" + : [out_ptr0] "+r"(out_ptr0), [b0] "+r"(b0), [b1] "+r"(b1), + [b2] "+r"(b2), [b3] "+r"(b3) + : + : "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "r0"); +#else + asm volatile( + // "pld [%[b]] \n" + "vld1.32 {q0, q1}, [%[b0]] \n" + "vld1.32 {q2, q3}, [%[b1]] \n" + "vld1.32 {q4, q5}, [%[b2]] \n" + "vld1.32 {q6, q7}, [%[b3]] \n" + + "vst1.32 {q0, q1}, [%[out_ptr0]]! \n" + "vst1.32 {q2, q3}, [%[out_ptr0]]! \n" + "vst1.32 {q4, q5}, [%[out_ptr0]]! \n" + "vst1.32 {q6, q7}, [%[out_ptr0]]! \n" + : [out_ptr0] "+r"(out_ptr0), [b0] "+r"(b0), [b1] "+r"(b1), + [b2] "+r"(b2), [b3] "+r"(b3) + : + : "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7"); +#endif // __aarch64__ + } + } + + int remain_k_start = k & 0xFFFC; + if (remain_k_start < k) { + int j = 0; + for (; j < n - 15; j += 16) { + float *out_ptr0 = output + j * k + 8 * remain_k_start; + float *out_ptr1 = out_ptr0 + 8 * k; + const float *b0 = B + remain_k_start * ldb + j; +#if __aarch64__ + asm volatile( + "prfm pldl1keep, [%[b0]] \n" + "ld1 {v0.4s, v1.4s}, [%[b0]], #32 \n" + "ld1 {v2.4s, v3.4s}, [%[b0]] \n" + "st1 {v0.4s, v1.4s}, [%[out_ptr0]] \n" + "st1 {v2.4s, v3.4s}, [%[out_ptr1]] \n" + : [out_ptr0] "+r"(out_ptr0), [out_ptr1] "+r"(out_ptr1), [b0] "+r"(b0) + : + : "memory", "v0", "v1", "v2", "v3"); +#else + asm volatile( + // "pld [%[b]] \n" + "vld1.32 {q0, q1}, [%[b0]]! 
\n" + "vld1.32 {q2, q3}, [%[b0]] \n" + "vst1.32 {q0, q1}, [%[out_ptr0]] \n" + "vst1.32 {q2, q3}, [%[out_ptr1]] \n" + : [out_ptr0] "+r"(out_ptr0), [out_ptr1] "+r"(out_ptr1), [b0] "+r"(b0) + : + : "memory", "q0", "q1", "q2", "q3"); +#endif // __aarch64__ + } + for (; j < n - 7; j += 8) { + float *out_ptr0 = output + j * k + 8 * remain_k_start; + const float *b0 = B + remain_k_start * ldb + j; +#if __aarch64__ + asm volatile( + "prfm pldl1keep, [%[b0]] \n" + "ld1 {v0.4s, v1.4s}, [%[b0]] \n" + "st1 {v0.4s, v1.4s}, [%[out_ptr0]] \n" + : [out_ptr0] "+r"(out_ptr0) + : [b0] "r"(b0) + : "memory", "v0", "v1"); +#else + asm volatile( + // "pld [%[b]] \n" + "vld1.32 {q0, q1}, [%[b0]] \n" + "vst1.32 {q0, q1}, [%[out_ptr0]] \n" + : [out_ptr0] "+r"(out_ptr0) + : [b0] "r"(b0) + : "memory", "q0", "q1"); +#endif // __aarch64__ + } + } + + int remain_n_start = n & 0xFFF8; + if (remain_n_start < n) { + uint32_t mask[8] = {0, 1, 2, 3, 4, 5, 6, 7}; + uint32_t remain_n = n & 0x7; + + uint32x4_t vzero = vdupq_n_u32(0); + uint32x4_t vmask1 = vcltq_u32(vld1q_u32(mask), vdupq_n_u32(remain_n)); + uint32x4_t vmask2 = vcltq_u32(vld1q_u32(mask + 4), vdupq_n_u32(remain_n)); + + float *out_ptr = output + remain_n_start * k; + for (int i = 0; i < k; ++i) { + const float *b0 = B + i * ldb + remain_n_start; +#if __aarch64__ + asm volatile( + "prfm pldl1keep, [%[b0]] \n" + "ld1 {v0.4s, v1.4s}, [%[b0]] \n" + "bif v0.8b, %[vzero].8b, %[vmask1].8b \n" + "bif v1.8b, %[vzero].8b, %[vmask2].8b \n" + "st1 {v0.4s, v1.4s}, [%[out_ptr]], #32 \n" + : [out_ptr] "+r"(out_ptr) + : [vmask1] "w"(vmask1), [vmask2] "w"(vmask2), [vzero] "w"(vzero), + [b0] "r"(b0) + : "memory", "v0", "v1"); +#else + asm volatile( + "vld1.32 {q0, q1}, [%[b0]] \n" + "vbif q0, %q[vzero], %q[vmask1] \n" + "vbif q1, %q[vzero], %q[vmask2] \n" + "vst1.32 {q0, q1}, [%[out_ptr]] \n" + : [out_ptr] "+r"(out_ptr) + : [vmask1] "w"(vmask1), [vmask2] "w"(vmask2), [vzero] "w"(vzero), + [b0] "r"(b0) + : "memory", "q0", "q1"); +#endif + } + } +} + +void write_back(const int mc, const int nc, const float *c, const int ldc1, + float *C, const int ldc2) { + /* + int remain_n = nc & 0x3; + //#ifndef __aarch64__ + // register float32x4_t _in00 __asm("q0"); + // register float32x4_t _in01 __asm("q1"); + // register float32x4_t _in10 __asm("q2"); + // register float32x4_t _in11 __asm("q3"); + //#endif + + int m = 0; + for (; m < mc - 1; m += 2) { + const float *in0 = c + m * ldc1; + const float *in1 = in0 + ldc1; + float *out0 = C + m * ldc2; + float *out1 = out0 + ldc2; + int n = 0; + for (; n < nc - 7; n += 8) { + float32x4_t _in00 = vld1q_f32(in0 + n); + float32x4_t _in01 = vld1q_f32(in0 + n + 4); + float32x4_t _in10 = vld1q_f32(in1 + n); + float32x4_t _in11 = vld1q_f32(in1 + n + 4); + vst1q_f32(out0 + n, _in00); + vst1q_f32(out0 + n + 4, _in01); + vst1q_f32(out1 + n, _in10); + vst1q_f32(out1 + n + 4, _in11); + } + for (; n < nc - 3; n += 4) { + float32x4_t _in00 = vld1q_f32(in0 + n); + float32x4_t _in10 = vld1q_f32(in1 + n); + vst1q_f32(out0 + n, _in00); + vst1q_f32(out1 + n, _in10); + } + if (n < nc) { + float32x4_t _in00 = vld1q_f32(in0 + n); + float32x4_t _in10 = vld1q_f32(in1 + n); + switch (remain_n) { + case 3: + vst1_f32(out0 + n, vget_low_f32(_in00)); + vst1q_lane_f32(out0 + n + 2, _in00, 2); + vst1_f32(out1 + n, vget_low_f32(_in10)); + vst1q_lane_f32(out1 + n + 2, _in10, 2); + break; + case 2: + vst1_f32(out0 + n, vget_low_f32(_in00)); + vst1_f32(out1 + n, vget_low_f32(_in10)); + break; + case 1: + vst1q_lane_f32(out0 + n, _in00, 2); + vst1q_lane_f32(out1 + n, _in10, 2); + 
break; + default: + break; + } + } + } + + for (; m < mc; ++m) { + const float *in0 = c + m * ldc1; + float *out0 = C + m * ldc2; + int n = 0; + for (; n < nc - 7; n += 8) { + float32x4_t _in0 = vld1q_f32(in0 + n); + float32x4_t _in1 = vld1q_f32(in0 + n + 4); + vst1q_f32(out0 + n, _in0); + vst1q_f32(out0 + n + 4, _in1); + } + for (; n < nc - 3; n += 4) { + float32x4_t _in0 = vld1q_f32(in0 + n); + vst1q_f32(out0 + n, _in0); + } + if (n < nc) { + float32x4_t _in0 = vld1q_f32(in0 + n); + switch (remain_n) { + case 3: + vst1_f32(out0 + n, vget_low_f32(_in0)); + vst1q_lane_f32(out0 + n + 2, _in0, 2); + break; + case 2: + vst1_f32(out0 + n, vget_low_f32(_in0)); + break; + case 1: + vst1q_lane_f32(out0 + n, _in0, 2); + break; + default: + break; + } + } + } + */ + + int nc1 = nc / 16; + int nc2 = nc % 16; + int step1 = 4 * (ldc1 - 16 * nc1); + int step2 = 4 * ldc2; + int volatile m = mc; + + const float *volatile c_ptr = c; + float *volatile C_ptr = C; + if (nc1 > 0) { + asm volatile( + "subs %[mc], %[mc], #1 \n\t" + "blt end_mc_%= \n\t" + "loop_mc_%=: \n\t" + + "mov r6, %[C_ptr] \n\t" + "mov r5, %[nc1] \n\t" + "subs r5, r5, #1 \n\t" + "blt end_nc1_%= \n\t" + "loop_nc1_%=: \n\t" + + "vld1.32 {q0, q1}, [%[c_ptr]]! \n\t" + "vst1.32 {q0, q1}, [r6]! \n\t" + + "vld1.32 {q2, q3}, [%[c_ptr]]! \n\t" + "vst1.32 {q2, q3}, [r6]! \n\t" + + "subs r5, r5, #1 \n\t" + "bge loop_nc1_%= \n\t" + "end_nc1_%=: \n\t" + + "add %[c_ptr], %[c_ptr], %[step1] \n\t" + "add %[C_ptr], %[C_ptr], %[step2] \n\t" + "subs %[mc], %[mc], #1 \n\t" + "bge loop_mc_%= \n\t" + "end_mc_%=: \n\t" + : + : [C_ptr] "r"(C_ptr), [c_ptr] "r"(c_ptr), [mc] "r"(m), [nc1] "r"(nc1), + [step1] "r"(step1), [step2] "r"(step2) + : "memory", "r5", "r6", "q0", "q1", "q2", "q3"); + } + + if (nc2 != 0) { + for (int i = 0; i < mc; i++) { + const float *c0 = c_ptr + nc1 * 16 + i * ldc1; + float *C0 = C_ptr + nc1 * 16 + i * ldc2; + for (int j = 0; j < nc2; j++) { + *C0++ = *c0++; + } + } + } +} + +} // namespace math +} // namespace operators +} // namespace paddle_mobile + +#endif // __ARM_NEON__ diff --git a/src/operators/math/gemm/strategy.h b/src/operators/math/gemm/strategy.h new file mode 100644 index 0000000000000000000000000000000000000000..8baf4e7e8cfd32d3c5c5bc1f769ceb2ac643dee3 --- /dev/null +++ b/src/operators/math/gemm/strategy.h @@ -0,0 +1,127 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
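strategy.h below is the glue: a strategy bundles the pack_lhs/pack_rhs/kernel/write function pointers together with the register-tile shape, 6x8 on armv7 and 12x8 on aarch64 (sgemm_12x8 and pack_lhs_12r are still TODO stubs above, so the aarch64 path is not yet functional in this patch). The contract the executor relies on can be written as a scalar reference of the armv7 kernel; note that the asm stores its accumulators directly, i.e. it overwrites the output tile rather than accumulating into C:

// What sgemm_6x8 computes, in scalar form: one 6x8 tile of A_panel * B_panel.
// lhs holds 6 floats per k step, rhs 8 floats per k step (packed layouts).
void sgemm_6x8_ref(const float *lhs, const float *rhs, int k,
                   float *output, int ldc) {
  float acc[6][8] = {{0.f}};
  for (int p = 0; p < k; ++p) {
    for (int i = 0; i < 6; ++i) {
      for (int j = 0; j < 8; ++j) {
        acc[i][j] += lhs[p * 6 + i] * rhs[p * 8 + j];
      }
    }
  }
  for (int i = 0; i < 6; ++i) {
    for (int j = 0; j < 8; ++j) {
      output[i * ldc + j] = acc[i][j];
    }
  }
}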
*/ + +#pragma once + +#include "operators/math/gemm/gemm_kernel.h" +#include "operators/math/gemm/pack_kernel.h" + +namespace paddle_mobile { +namespace operators { +namespace math { + +struct SgemmStrategy { + typedef float Itype; + typedef float Otype; + + typedef void (*packLhsFunc)(const int, const int, const Itype *, const int, + Itype *, const bool); + typedef void (*packRhsFunc)(const int, const int, const Itype *, const int, + Itype *, const bool); + typedef void (*kernelFunc)(const Itype *, const Itype *, const int, Otype *, + const int); + typedef void (*WriteFunc)(const int, const int, const Otype *, const int, + Otype *, const int); + + packLhsFunc pack_lhs; + packRhsFunc pack_rhs; + kernelFunc kernel; + WriteFunc write; + + static int out_width() { return 8; } + + static int out_height() { +#ifdef __aarch64__ + return 12; +#else + return 6; +#endif + } + + SgemmStrategy() { +#ifdef __aarch64__ + pack_lhs = pack_lhs_12r; + pack_rhs = pack_rhs_8c; + kernel = sgemm_12x8; +#else + pack_lhs = pack_lhs_6r; + pack_rhs = pack_rhs_8c; + kernel = sgemm_6x8; +#endif + write = write_back; + } +}; + +struct I8o32gemmStrategy { + typedef int8_t Itype; + typedef int32_t Otype; + + typedef void (*kern_type)(const Itype *, const Itype *, const int, Otype *, + const int); + kern_type kernel; + + static int out_width() { return 8; } + + static int out_height() { +#ifdef __aarch64__ + return 12; +#else + return 6; +#endif + } + + I8o32gemmStrategy() {} +}; + +struct SgemvStrategy { + typedef float Itype; + typedef float Otype; + + typedef void (*kern_type)(const Itype *, const Itype *, const int, Otype *, + const int); + kern_type kernel; + + static int out_width() { return 1; } + + static int out_height() { +#ifdef __aarch64__ + return 12; +#else + return 6; +#endif + } +}; + +struct I8o32gemvStrategy { + typedef int8_t Itype; + typedef int32_t Otype; + + typedef void (*kern_type)(const Itype *, const Itype *, const int, Otype *, + const int); + kern_type kernel; + + static int out_width() { return 1; } + + static int out_height() { +#ifdef __aarch64__ + return 12; +#else + return 6; +#endif + } +}; + +} // namespace math +} // namespace operators +} // namespace paddle_mobile diff --git a/src/operators/math/math_function.cpp b/src/operators/math/math_function.cpp index b1e49e377b661cdfdefe08e8043f11b43ab0f9ee..b576963cc4b9a6754de3f8ebc3f6ffcb40b161f2 100644 --- a/src/operators/math/math_function.cpp +++ b/src/operators/math/math_function.cpp @@ -18,6 +18,7 @@ limitations under the License. */ #include "framework/data_type.h" #include "framework/tensor.h" #include "operators/math/gemm.h" +#include "operators/math/gemm/cblas.h" namespace paddle_mobile { namespace operators { @@ -55,6 +56,7 @@ void MatMul(const framework::Tensor &matrix_a, bool trans_a, int M = dim_out[0]; int N = dim_out[1]; int K = (!trans_a) ? 
dim_a[1] : dim_a[0]; + Gemm gemm; if (trans_a) { framework::Tensor matrix_trans; @@ -69,24 +71,34 @@ void MatMul(const framework::Tensor &matrix_a, bool trans_a, a[index++] = tmp[i * n + j]; } } - + if (M > N || M == 1) { #ifdef _OPENMP - gemm.Sgemm_omp(M, N, K, alpha, a, K, matrix_b.data(), N, beta, - matrix_out->data(), N, relu, bias); + gemm.Sgemm_omp(M, N, K, alpha, a, K, matrix_b.data(), N, beta, + matrix_out->data(), N, relu, bias); #else - gemm.Sgemm(M, N, K, alpha, a, K, matrix_b.data(), N, beta, - matrix_out->data(), N, relu, bias); + gemm.Sgemm(M, N, K, alpha, a, K, matrix_b.data(), N, beta, + matrix_out->data(), N, relu, bias); #endif + } else { + cblas_sgemm(false, false, M, N, K, alpha, a, K, matrix_b.data(), N, + beta, matrix_out->data(), N); + } } else { + if (M > N || M == 1) { #ifdef _OPENMP - gemm.Sgemm_omp(M, N, K, alpha, matrix_a.data(), K, - matrix_b.data(), N, beta, matrix_out->data(), - N, relu, bias); + gemm.Sgemm_omp(M, N, K, alpha, matrix_a.data(), K, + matrix_b.data(), N, beta, matrix_out->data(), + N, relu, bias); #else - gemm.Sgemm(M, N, K, alpha, matrix_a.data(), K, - matrix_b.data(), N, beta, matrix_out->data(), N, - relu, bias); + gemm.Sgemm(M, N, K, alpha, matrix_a.data(), K, + matrix_b.data(), N, beta, matrix_out->data(), N, + relu, bias); #endif + } else { + cblas_sgemm(false, false, M, N, K, alpha, matrix_a.data(), K, + matrix_b.data(), N, beta, matrix_out->data(), + N); + } } } diff --git a/src/operators/math/winograd/winograd_transform_f6k3.cpp b/src/operators/math/winograd/winograd_transform_f6k3.cpp index 937050ebbd8f6bab3b0c9b075e9b4fa54c25b1ba..f95ead62445861566a1591df3c883085fc3eb16e 100644 --- a/src/operators/math/winograd/winograd_transform_f6k3.cpp +++ b/src/operators/math/winograd/winograd_transform_f6k3.cpp @@ -52,9 +52,7 @@ void winograd_transform_weight<8, 3>(const framework::Tensor &weight, const float transform_matrix[8] = {2.f, -2.f / 9, 1.f / 90, 1.f / 180}; const float *inptr = weight.data(); int remain_start = out_channel & 0xFFFC; -#if 0 - remain_start = 0; -#else + #pragma omp parallel for for (int oc = 0; oc < out_channel - 3; oc += 4) { float gw[96]; // gw[3][8][4] @@ -258,7 +256,6 @@ void winograd_transform_weight<8, 3>(const framework::Tensor &weight, "q13", "r0"); } } -#endif // remain output channel #pragma omp parallel for @@ -350,311 +347,8 @@ void winograd_transform_input<8, 3>(const framework::Tensor &input, size_t image_size = height * width; const float transform_matrix[8] = {5.25f, -5.f, -4.25f, -2.5f, 2.f, -1.25f, 0.5f, 0.25f}; - int remain_c_start = channel & 0xFFFC; -#if 1 - remain_c_start = 0; -#else - #pragma omp parallel for - for (int c = 0; c < channel - 3; c += 4) { - const float *in = inptr + c * image_size; - float d_bt[64 * 4]; // d * B_t - for (int h = 0; h < h_tiles; ++h) { - for (int w = 0; w < w_tiles; ++w) { - const float *in0 = in + (h * width + w) * 6; - const float *in1 = in0 + image_size; - const float *in2 = in1 + image_size; - const float *in3 = in2 + image_size; - int steps = width * sizeof(float); - float *d_bt_ptr = d_bt; - asm volatile( - "mov r0, #8 \n" - "vld1.32 {d0-d3}, [%[tm_ptr]] \n" - // row loop - "loop_r_%=: \n" - "vld1.32 {d4-d7}, [%[in0]], %[steps] \n" - "vld1.32 {d8-d11}, [%[in1]], %[steps] \n" - "vld1.32 {d12-d15}, [%[in2]], %[steps] \n" - "vld1.32 {d16-d19}, [%[in3]], %[steps] \n" - "vtrn.32 q2, q4 \n" // d0: q2 - "vtrn.32 q3, q5 \n" // d1: q4 - "vtrn.32 q6, q8 \n" // d2: q6 - "vtrn.32 q7, q9 \n" // d3: q8 - "vswp.32 d5, d12 \n" // d4: q3 - "vswp.32 d9, d16 \n" // d5: q5 - 
"vswp.32 d7, d14 \n" // d6: q7 - "vswp.32 d11, d18 \n" // d7: q9 - - "vsub.f32 q10, q2, q7 \n" - "vsub.f32 q11, q3, q6 \n" - "vmla.f32 q10, q11, d0[0] \n" // d0 - d6 + (d4 - - // d2) * 5.25 - "vst1.32 {d20-d21}, [%[d_bt]]! \n" - - "vadd.f32 q10, q6, q7 \n" - "vadd.f32 q11, q4, q5 \n" - "vmla.f32 q10, q3, d1[0] \n" // d2 - 4.25 * d4 + - // d6 - "vmla.f32 q11, q8, d1[0] \n" // d1 - 4.25 * d3 + - // d5 - "vadd.f32 q12, q10, q11 \n" - "vsub.f32 q13, q10, q11 \n" - "vst1.32 {d24-d27}, [%[d_bt]]! \n" - - "vmul.f32 q10, q6, d3[1] \n" // 0.25 * d2 - "vmul.f32 q11, q4, d3[0] \n" // 0.5 * d1 - "vadd.f32 q10, q10, q7 \n" // 0.25 * d2 + d6 - "vmla.f32 q11, q5, d2[0] \n" // 0.5 * d1 + 2 * - // d5 - "vmla.f32 q10, q3, d2[1] \n" // 0.25 * d2 + d6 - // - 1.25 * d4 - "vmla.f32 q11, q8, d1[1] \n" // 0.5 * d1 + 2 * - // d5 - 2.5 * d3 - "vadd.f32 q12, q10, q11 \n" - "vsub.f32 q13, q10, q11 \n" - "vst1.32 {d24-d27}, [%[d_bt]]! \n" - - "vmul.f32 q10, q6, d2[0] \n" // 2 * d2 - "vmul.f32 q11, q4, d2[0] \n" // 2 * d1 - "vmla.f32 q10, q3, d1[1] \n" // 2 * d2 - 2.5 * - // d4 - "vmla.f32 q11, q8, d1[1] \n" // 2 * d1 - 2.5 * - // d3 - "vmla.f32 q10, q7, d3[0] \n" // 2 * d1 - 2.5 * - // d3 + 0.5 * d6 - "vmla.f32 q11, q5, d3[0] \n" // 2 * d2 - 2.5 * - // d4 + 0.5 * d5 - "vmul.f32 q10, q10, d2[0] \n" // 4 * d1 - 5 * d3 - // + d6 - "vadd.f32 q12, q10, q11 \n" - "vsub.f32 q13, q10, q11 \n" - "vst1.32 {d24-d27}, [%[d_bt]]! \n" - - "vsub.f32 q10, q9, q4 \n" - "vsub.f32 q11, q8, q5 \n" - "vmla.f32 q10, q11, d0[0] \n" - "vst1.32 {d20-d21}, [%[d_bt]]! \n" - - "subs r0, #1 \n" - "bne loop_r_%= \n" - : [d_bt] "+r"(d_bt_ptr), [in0] "+r"(in0), [in1] "+r"(in1), - [in2] "+r"(in2), [in3] "+r"(in3) - : [tm_ptr] "r"((float *)transform_matrix), [steps] "r"(steps) - : "cc", "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", - "q8", "q9", "q10", "q11", "q12", "q13", "r0"); - - float *ptr0 = d_bt; - float *ptr1 = ptr0 + 32; - float *ptr2 = ptr1 + 32; - float *ptr3 = ptr2 + 32; - float *ptr4 = ptr3 + 32; - float *ptr5 = ptr4 + 32; - float *ptr6 = ptr5 + 32; - float *ptr7 = ptr6 + 32; - int tile_indics = h * w_tiles + w; - int tile_block = tile_indics >> 3; - int block_indics = tile_indics & 0x7; - // (tiles / 8, 64, channel, 8) - float *out0 = - outptr + (tile_block * 64 * channel + c) * 8 + block_indics; - steps = (channel - 3) * 8 * sizeof(float); - asm volatile( - "vld1.32 {d0-d3}, [%[tm_ptr]] \n" - "mov r0, 4 \n" - "mov r1, 32 \n" - "loop_col_%=: \n" - // col 0: - "vld1.32 {d4-d5}, [%[ptr0]]! \n" // q2: d0 - "vld1.32 {d6-d7}, [%[ptr1]]! \n" // q3: d1 - "vld1.32 {d8-d9}, [%[ptr2]]! \n" // q4: d2 - "vld1.32 {d10-d11}, [%[ptr3]]! \n" // q5: d3 - "vld1.32 {d12-d13}, [%[ptr4]]! \n" // q6: d4 - "vld1.32 {d14-d15}, [%[ptr5]]! \n" // q7: d5 - "vld1.32 {d16-d17}, [%[ptr6]]! \n" // q8: d6 - "vld1.32 {d18-d19}, [%[ptr7]]! 
\n" // q9: d7 - - "vsub.f32 q10, q2, q8 \n" // d0 - d6 - "vsub.f32 q11, q6, q4 \n" // d4 - d2 - "vmla.f32 q10, q11, d0[0] \n" // d0 - d6 + (d4 - - // d2) * 5.25 - "vst1.32 {d20[0]}, [%[out0]], r1 \n" - "vst1.32 {d20[1]}, [%[out0]], r1 \n" - "vst1.32 {d21[0]}, [%[out0]], r1 \n" - "vst1.32 {d21[1]}, [%[out0]], %[steps] \n" - - "vadd.f32 q10, q4, q8 \n" - "vadd.f32 q11, q3, q7 \n" - "vmla.f32 q10, q6, d1[0] \n" // d2 - 4.25 * d4 + - // d6 - "vmla.f32 q11, q5, d1[0] \n" // d1 - 4.25 * d3 + - // d5 - "vadd.f32 q12, q10, q11 \n" - "vst1.32 {d24[0]}, [%[out0]], r1 \n" - "vst1.32 {d24[1]}, [%[out0]], r1 \n" - "vst1.32 {d25[0]}, [%[out0]], r1 \n" - "vst1.32 {d25[1]}, [%[out0]], %[steps] \n" - "vsub.f32 q12, q10, q11 \n" - "vst1.32 {d24[0]}, [%[out0]], r1 \n" - "vst1.32 {d24[1]}, [%[out0]], r1 \n" - "vst1.32 {d25[0]}, [%[out0]], r1 \n" - "vst1.32 {d25[1]}, [%[out0]], %[steps] \n" - - "vmul.f32 q10, q4, d3[1] \n" // 0.25 * d2 - "vmul.f32 q11, q3, d3[0] \n" // 0.5 * d1 - "vadd.f32 q10, q10, q8 \n" // 0.25 * d2 + d6 - "vmla.f32 q11, q7, d2[0] \n" // 0.5 * d1 + 2 * - // d5 - "vmla.f32 q10, q6, d2[1] \n" // 0.25 * d2 + d6 - // - 1.25 * d4 - "vmla.f32 q11, q5, d1[1] \n" // 0.5 * d1 + 2 * - // d5 - 2.5 * d3 - "vadd.f32 q12, q10, q11 \n" - "vst1.32 {d24[0]}, [%[out0]], r1 \n" - "vst1.32 {d24[1]}, [%[out0]], r1 \n" - "vst1.32 {d25[0]}, [%[out0]], r1 \n" - "vst1.32 {d25[1]}, [%[out0]], %[steps] \n" - "vsub.f32 q12, q10, q11 \n" - "vst1.32 {d24[0]}, [%[out0]], r1 \n" - "vst1.32 {d24[1]}, [%[out0]], r1 \n" - "vst1.32 {d25[0]}, [%[out0]], r1 \n" - "vst1.32 {d25[1]}, [%[out0]], %[steps] \n" - - "vmul.f32 q10, q4, d2[0] \n" // 2 * d2 - "vmul.f32 q11, q3, d2[0] \n" // 2 * d1 - "vmla.f32 q10, q6, d1[1] \n" // 2 * d2 - 2.5 * - // d4 - "vmla.f32 q11, q5, d1[1] \n" // 2 * d1 - 2.5 * - // d3 - "vmla.f32 q10, q8, d3[0] \n" // 2 * d1 - 2.5 * - // d3 + 0.5 * d6 - "vmla.f32 q11, q7, d3[0] \n" // 2 * d2 - 2.5 * - // d4 + 0.5 * d5 - "vmul.f32 q10, q10, d2[0] \n" // 4 * d1 - 5 * d3 - // + d6 - "vadd.f32 q12, q10, q11 \n" - "vst1.32 {d24[0]}, [%[out0]], r1 \n" - "vst1.32 {d24[1]}, [%[out0]], r1 \n" - "vst1.32 {d25[0]}, [%[out0]], r1 \n" - "vst1.32 {d25[1]}, [%[out0]], %[steps] \n" - "vsub.f32 q12, q10, q11 \n" - "vst1.32 {d24[0]}, [%[out0]], r1 \n" - "vst1.32 {d24[1]}, [%[out0]], r1 \n" - "vst1.32 {d25[0]}, [%[out0]], r1 \n" - "vst1.32 {d25[1]}, [%[out0]], %[steps] \n" - - "vsub.f32 q10, q9, q3 \n" - "vsub.f32 q11, q5, q7 \n" - "vmla.f32 q10, q11, d0[0] \n" - "vst1.32 {d20[0]}, [%[out0]], r1 \n" - "vst1.32 {d20[1]}, [%[out0]], r1 \n" - "vst1.32 {d21[0]}, [%[out0]], r1 \n" - "vst1.32 {d21[1]}, [%[out0]], %[steps] \n" - - // col 1: - "vld1.32 {d4-d5}, [%[ptr0]]! \n" // q2: d0 - "vld1.32 {d6-d7}, [%[ptr1]]! \n" // q3: d1 - "vld1.32 {d8-d9}, [%[ptr2]]! \n" // q4: d2 - "vld1.32 {d10-d11}, [%[ptr3]]! \n" // q5: d3 - "vld1.32 {d12-d13}, [%[ptr4]]! \n" // q6: d4 - "vld1.32 {d14-d15}, [%[ptr5]]! \n" // q7: d5 - "vld1.32 {d16-d17}, [%[ptr6]]! \n" // q8: d6 - "vld1.32 {d18-d19}, [%[ptr7]]! 
\n" // q9: d7 - - "vsub.f32 q10, q2, q8 \n" // d0 - d6 - "vsub.f32 q11, q6, q4 \n" // d4 - d2 - "vmla.f32 q10, q11, d0[0] \n" // d0 - d6 + (d4 - - // d2) * 5.25 - "vst1.32 {d20[0]}, [%[out0]], r1 \n" - "vst1.32 {d20[1]}, [%[out0]], r1 \n" - "vst1.32 {d21[0]}, [%[out0]], r1 \n" - "vst1.32 {d21[1]}, [%[out0]], %[steps] \n" - - "vadd.f32 q10, q4, q8 \n" - "vadd.f32 q11, q3, q7 \n" - "vmla.f32 q10, q6, d1[0] \n" // d2 - 4.25 * d4 + - // d6 - "vmla.f32 q11, q5, d1[0] \n" // d1 - 4.25 * d3 + - // d5 - "vadd.f32 q12, q10, q11 \n" - "vst1.32 {d24[0]}, [%[out0]], r1 \n" - "vst1.32 {d24[1]}, [%[out0]], r1 \n" - "vst1.32 {d25[0]}, [%[out0]], r1 \n" - "vst1.32 {d25[1]}, [%[out0]], %[steps] \n" - - "vsub.f32 q12, q10, q11 \n" - "vst1.32 {d24[0]}, [%[out0]], r1 \n" - "vst1.32 {d24[1]}, [%[out0]], r1 \n" - "vst1.32 {d25[0]}, [%[out0]], r1 \n" - "vst1.32 {d25[1]}, [%[out0]], %[steps] \n" - - "vmul.f32 q10, q4, d3[1] \n" // 0.25 * d2 - "vmul.f32 q11, q3, d3[0] \n" // 0.5 * d1 - "vadd.f32 q10, q10, q8 \n" // 0.25 * d2 + d6 - "vmla.f32 q11, q7, d2[0] \n" // 0.5 * d1 + 2 * - // d5 - "vmla.f32 q10, q6, d2[1] \n" // 0.25 * d2 + d6 - // - 1.25 * d4 - "vmla.f32 q11, q5, d1[1] \n" // 0.5 * d1 + 2 * - // d5 - 2.5 * d3 - "vadd.f32 q12, q10, q11 \n" - "vst1.32 {d24[0]}, [%[out0]], r1 \n" - "vst1.32 {d24[1]}, [%[out0]], r1 \n" - "vst1.32 {d25[0]}, [%[out0]], r1 \n" - "vst1.32 {d25[1]}, [%[out0]], %[steps] \n" - - "vsub.f32 q12, q10, q11 \n" - "vst1.32 {d24[0]}, [%[out0]], r1 \n" - "vst1.32 {d24[1]}, [%[out0]], r1 \n" - "vst1.32 {d25[0]}, [%[out0]], r1 \n" - "vst1.32 {d25[1]}, [%[out0]], %[steps] \n" - - "vmul.f32 q10, q4, d2[0] \n" // 2 * d2 - "vmul.f32 q11, q3, d2[0] \n" // 2 * d1 - "vmla.f32 q10, q6, d1[1] \n" // 2 * d2 - 2.5 * - // d4 - "vmla.f32 q11, q5, d1[1] \n" // 2 * d1 - 2.5 * - // d3 - "vmla.f32 q10, q8, d3[0] \n" // 2 * d1 - 2.5 * - // d3 + 0.5 * d6 - "vmla.f32 q11, q7, d3[0] \n" // 2 * d2 - 2.5 * - // d4 + 0.5 * d5 - "vmul.f32 q10, q10, d2[0] \n" // 4 * d1 - 5 * d3 - // + d6 - "vadd.f32 q12, q10, q11 \n" - "vst1.32 {d24[0]}, [%[out0]], r1 \n" - "vst1.32 {d24[1]}, [%[out0]], r1 \n" - "vst1.32 {d25[0]}, [%[out0]], r1 \n" - "vst1.32 {d25[1]}, [%[out0]], %[steps] \n" - - "vsub.f32 q12, q10, q11 \n" - "vst1.32 {d24[0]}, [%[out0]], r1 \n" - "vst1.32 {d24[1]}, [%[out0]], r1 \n" - "vst1.32 {d25[0]}, [%[out0]], r1 \n" - "vst1.32 {d25[1]}, [%[out0]], %[steps] \n" - - "vsub.f32 q10, q9, q3 \n" - "vsub.f32 q11, q5, q7 \n" - "vmla.f32 q10, q11, d0[0] \n" - "vst1.32 {d20[0]}, [%[out0]], r1 \n" - "vst1.32 {d20[1]}, [%[out0]], r1 \n" - "vst1.32 {d21[0]}, [%[out0]], r1 \n" - "vst1.32 {d21[1]}, [%[out0]], %[steps] \n" - - "subs r0, #1 \n" - "bne loop_col_%= \n" - : [out0] "+r"(out0), [ptr0] "+r"(ptr0), [ptr1] "+r"(ptr1), - [ptr2] "+r"(ptr2), [ptr3] "+r"(ptr3), [ptr4] "+r"(ptr4), - [ptr5] "+r"(ptr5), [ptr6] "+r"(ptr6), [ptr7] "+r"(ptr7) - : [tm_ptr] "r"((float *)transform_matrix), [steps] "r"(steps) - : "cc", "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", - "q8", "q9", "q10", "q11", "q12", "q13", "r0", "r1"); - } - } - } -#endif - - // remainer channels #pragma omp parallel for - for (int c = remain_c_start; c < channel; ++c) { + for (int c = 0; c < channel; ++c) { const float *in = inptr + c * image_size; float d_bt[64]; // d * B_t for (int h = 0; h < h_tiles; ++h) { diff --git a/src/operators/op_param.h b/src/operators/op_param.h index 5eaeb784bd81b21d92a57fde282e7d80bb3f553e..a735fbee48164406e38256bab30d7ae97abfa31a 100644 --- a/src/operators/op_param.h +++ b/src/operators/op_param.h @@ -1753,18 +1753,15 @@ 
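The op_param.h hunks that follow are one mechanical refactor applied to every fused-conv parameter class: each class used to redeclare its own output_ member and Output() accessor; those are deleted, and the constructors now assign this->output_, the single member evidently owned by the base ConvParam (the explicit this-> is needed because output_ sits in a dependent base of these class templates). The write-only is_test_ flags, orphaned by the commented-out GetAttr calls, are dropped at the same time. The shape of the change, with simplified types (the Sketch names are hypothetical; the real classes resolve GType via DtypeTensorTrait):

template <typename T>
class ConvParamSketch {
 public:
  T *Output() const { return output_; }  // one accessor for all fused variants
 protected:
  T *output_ = nullptr;  // now owned solely by the base class
};

template <typename T>
class FusionConvAddParamSketch : public ConvParamSketch<T> {
 public:
  explicit FusionConvAddParamSketch(T *out) { this->output_ = out; }
};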
class FusionConvAddParam : public ConvParam { : ConvParam(inputs, outputs, attrs, scope) { bias_ = OpParam::InputYFrom(inputs, scope); axis_ = OpParam::GetAttr("axis", attrs); - output_ = OpParam::OutFrom(outputs, scope); + this->output_ = OpParam::OutFrom(outputs, scope); } GType *Bias() const { return bias_; } const int &Axis() const { return axis_; } - GType *Output() const { return output_; } - protected: GType *bias_; int axis_; - GType *output_; }; template @@ -1797,18 +1794,16 @@ class FusionConvAddPReluParam : public ConvParam { framework::DDim dims = alpha_->dims(); bias_ = OpParam::InputYFrom(inputs, scope); axis_ = OpParam::GetAttr("axis", attrs); - output_ = OpParam::OutFrom(outputs, scope); + this->output_ = OpParam::OutFrom(outputs, scope); } const GType *InputAlpha() const { return alpha_; } const std::string &Mode() const { return mode_; } GType *Bias() const { return bias_; } const int &Axis() const { return axis_; } - GType *Output() const { return output_; } protected: GType *bias_; int axis_; - GType *output_; GType *alpha_; std::string mode_; }; @@ -1830,7 +1825,6 @@ class FusionConvAddAddPReluParam : public ConvParam { mode_ = OpParam::GetStringAttr("mode", attrs); framework::DDim dims = alpha_->dims(); bias_ = OpParam::InputYFrom(inputs, scope); - output_ = OpParam::OutFrom(outputs, scope); axis_ = OpParam::GetAttr("axis", attrs); keyOutput_ = OpParam::getkey("addOut", inputs, 0); keyX1_ = OpParam::getkey("addX", inputs, 1); @@ -1840,6 +1834,7 @@ class FusionConvAddAddPReluParam : public ConvParam { } else if (keyY1_ == keyOutput_) { bias1_ = OpParam::InputXFrom1(inputs, scope); } + this->output_ = OpParam::OutFrom(outputs, scope); } const GType *InputAlpha() const { return alpha_; } const std::string &Mode() const { return mode_; } @@ -1848,12 +1843,10 @@ class FusionConvAddAddPReluParam : public ConvParam { GType *Bias() const { return bias_; } const int &Axis() const { return axis_; } - GType *Output() const { return output_; } protected: GType *bias_; int axis_; - GType *output_; GType *alpha_; std::string mode_; GType *bias1_; @@ -1876,21 +1869,18 @@ class FusionConvAddBNReluParam : public ConvParam { : ConvParam(inputs, outputs, attrs, scope) { bias_ = OpParam::InputYFrom(inputs, scope); axis_ = OpParam::GetAttr("axis", attrs); - output_ = OpParam::OutFrom(outputs, scope); input_bias_ = OpParam::InputBiasFrom(inputs, scope); input_mean_ = OpParam::InputMeanFrom(inputs, scope); input_scale_ = OpParam::InputScaleFrom(inputs, scope); input_variance_ = OpParam::InputVarianceFrom(inputs, scope); epsilon_ = OpParam::GetAttr("epsilon", attrs); momentum_ = OpParam::GetAttr("momentum", attrs); - // is_test_ = OpParam::GetAttr("is_test", attrs); + this->output_ = OpParam::OutFrom(outputs, scope); } GType *Bias() const { return bias_; } const int &Axis() const { return axis_; } - GType *Output() const { return output_; } - const GType *InputBias() const { return input_bias_; } const GType *InputMean() const { return input_mean_; } @@ -1903,8 +1893,6 @@ class FusionConvAddBNReluParam : public ConvParam { const float &Momentum() const { return momentum_; } - const bool &IsTest() const { return is_test_; } - void SetNewScale(GType *new_scale) { new_scale_ = new_scale; } void SetNewBias(GType *new_bias) { new_bias_ = new_bias; } @@ -1916,14 +1904,12 @@ class FusionConvAddBNReluParam : public ConvParam { protected: GType *bias_; int axis_; - GType *output_; GType *input_bias_; GType *input_mean_; GType *input_scale_; GType *input_variance_; float epsilon_; float momentum_; - 
   GType *new_bias_;
   GType *new_scale_;
 };
@@ -1942,7 +1928,6 @@ class FusionConvBNAddReluParam : public ConvParam<Dtype> {
       : ConvParam<Dtype>(inputs, outputs, attrs, scope) {
     bias_ = OpParam::InputYFrom<GType>(inputs, scope);
     axis_ = OpParam::GetAttr<int>("axis", attrs);
-    output_ = OpParam::OutFrom<GType>(outputs, scope);
     input_bias_ = OpParam::InputBiasFrom<GType>(inputs, scope);
     input_mean_ = OpParam::InputMeanFrom<GType>(inputs, scope);
     input_scale_ = OpParam::InputScaleFrom<GType>(inputs, scope);
@@ -1957,14 +1942,12 @@ class FusionConvBNAddReluParam : public ConvParam<Dtype> {
     } else if (keyY_ == keyBNY_) {
       bias_ = OpParam::InputXFrom<GType>(inputs, scope);
     }
-    // is_test_ = OpParam::GetAttr<bool>("is_test", attrs);
+    this->output_ = OpParam::OutFrom<GType>(outputs, scope);
   }
 
   GType *Bias() const { return bias_; }
 
   const int &Axis() const { return axis_; }
 
-  GType *Output() const { return output_; }
-
   const GType *InputBias() const { return input_bias_; }
 
   const GType *InputMean() const { return input_mean_; }
@@ -1977,8 +1960,6 @@
 
   const float &Momentum() const { return momentum_; }
 
-  const bool &IsTest() const { return is_test_; }
-
   void SetNewScale(GType *new_scale) { new_scale_ = new_scale; }
 
   void SetNewBias(GType *new_bias) { new_bias_ = new_bias; }
@@ -1990,14 +1971,12 @@
  protected:
   GType *bias_;
   int axis_;
-  GType *output_;
   GType *input_bias_;
   GType *input_mean_;
   GType *input_scale_;
   GType *input_variance_;
   float epsilon_;
   float momentum_;
-  bool is_test_;
   GType *new_bias_;
   GType *new_scale_;
   std::string keyBNY_;
@@ -2017,16 +1996,14 @@ class FusionConvBNParam : public ConvParam<Dtype> {
                     const VariableNameMap &outputs, const AttributeMap &attrs,
                     const Scope &scope)
       : ConvParam<Dtype>(inputs, outputs, attrs, scope) {
-    output_y_ = OpParam::OutputYFrom<GType>(outputs, scope);
     input_bias_ = OpParam::InputBiasFrom<GType>(inputs, scope);
     input_mean_ = OpParam::InputMeanFrom<GType>(inputs, scope);
     input_scale_ = OpParam::InputScaleFrom<GType>(inputs, scope);
     input_variance_ = OpParam::InputVarianceFrom<GType>(inputs, scope);
     epsilon_ = OpParam::GetAttr<float>("epsilon", attrs);
     momentum_ = OpParam::GetAttr<float>("momentum", attrs);
-    // is_test_ = OpParam::GetAttr<bool>("is_test", attrs);
+    this->output_ = OpParam::OutputYFrom<GType>(outputs, scope);
   }
-  GType *Output() const { return output_y_; }
 
   const GType *InputBias() const { return input_bias_; }
@@ -2040,8 +2017,6 @@
 
   const float &Momentum() const { return momentum_; }
 
-  const bool &IsTest() const { return is_test_; }
-
   void SetNewScale(GType *new_scale) { new_scale_ = new_scale; }
 
   void SetNewBias(GType *new_bias) { new_bias_ = new_bias; }
@@ -2051,14 +2026,12 @@
   const GType *NewBias() const { return new_bias_; }
 
  protected:
-  GType *output_y_;
   GType *input_bias_;
   GType *input_mean_;
   GType *input_scale_;
   GType *input_variance_;
   float epsilon_;
   float momentum_;
-  bool is_test_;
   GType *new_bias_;
   GType *new_scale_;
 };
@@ -2077,21 +2050,18 @@
       : ConvParam<Dtype>(inputs, outputs, attrs, scope) {
     bias_ = OpParam::InputYFrom<GType>(inputs, scope);
     axis_ = OpParam::GetAttr<int>("axis", attrs);
-    output_y_ = OpParam::OutputYFrom<GType>(outputs, scope);
     input_bias_ = OpParam::InputBiasFrom<GType>(inputs, scope);
     input_mean_ = OpParam::InputMeanFrom<GType>(inputs, scope);
     input_scale_ = OpParam::InputScaleFrom<GType>(inputs, scope);
     input_variance_ = OpParam::InputVarianceFrom<GType>(inputs, scope);
     epsilon_ = OpParam::GetAttr<float>("epsilon", attrs);
     momentum_ = OpParam::GetAttr<float>("momentum", attrs);
-    // is_test_ = OpParam::GetAttr<bool>("is_test", attrs);
+    this->output_ = OpParam::OutputYFrom<GType>(outputs, scope);
   }
 
   GType *Bias() const { return bias_; }
 
   const int &Axis() const { return axis_; }
 
-  GType *Output() const { return output_y_; }
-
   const GType *InputBias() const { return input_bias_; }
 
   const GType *InputMean() const { return input_mean_; }
@@ -2104,8 +2074,6 @@
 
   const float &Momentum() const { return momentum_; }
 
-  const bool &IsTest() const { return is_test_; }
-
   void SetNewScale(GType *new_scale) { new_scale_ = new_scale; }
 
   void SetNewBias(GType *new_bias) { new_bias_ = new_bias; }
@@ -2117,14 +2085,12 @@
  protected:
   GType *bias_;
   int axis_;
-  GType *output_y_;
   GType *input_bias_;
   GType *input_mean_;
   GType *input_scale_;
   GType *input_variance_;
   float epsilon_;
   float momentum_;
-  bool is_test_;
   GType *new_bias_;
   GType *new_scale_;
 };
@@ -2141,16 +2107,14 @@ class FusionDWConvBNReluParam : public ConvParam<Dtype> {
                          const VariableNameMap &outputs,
                          const AttributeMap &attrs, const Scope &scope)
       : ConvParam<Dtype>(inputs, outputs, attrs, scope) {
-    output_ = OpParam::OutFrom<GType>(outputs, scope);
     input_bias_ = OpParam::InputBiasFrom<GType>(inputs, scope);
     input_mean_ = OpParam::InputMeanFrom<GType>(inputs, scope);
     input_scale_ = OpParam::InputScaleFrom<GType>(inputs, scope);
     input_variance_ = OpParam::InputVarianceFrom<GType>(inputs, scope);
     epsilon_ = OpParam::GetAttr<float>("epsilon", attrs);
     momentum_ = OpParam::GetAttr<float>("momentum", attrs);
-    // is_test_ = OpParam::GetAttr<bool>("is_test", attrs);
+    this->output_ = OpParam::OutFrom<GType>(outputs, scope);
   }
-  GType *Output() const { return output_; }
 
   const GType *InputBias() const { return input_bias_; }
@@ -2164,8 +2128,6 @@
 
   const float &Momentum() const { return momentum_; }
 
-  const bool &IsTest() const { return is_test_; }
-
   void SetNewScale(GType *new_scale) { new_scale_ = new_scale; }
 
   void SetNewBias(GType *new_bias) { new_bias_ = new_bias; }
@@ -2175,14 +2137,12 @@
  protected:
-  GType *output_;
   GType *input_bias_;
   GType *input_mean_;
   GType *input_scale_;
   GType *input_variance_;
   float epsilon_;
   float momentum_;
-  bool is_test_;
   GType *new_bias_;
   GType *new_scale_;
 };
@@ -2200,16 +2160,14 @@ class FusionConvBNReluParam : public ConvParam<Dtype> {
                     const VariableNameMap &outputs, const AttributeMap &attrs,
                     const Scope &scope)
       : ConvParam<Dtype>(inputs, outputs, attrs, scope) {
-    output_ = OpParam::OutFrom<GType>(outputs, scope);
     input_bias_ = OpParam::InputBiasFrom<GType>(inputs, scope);
     input_mean_ = OpParam::InputMeanFrom<GType>(inputs, scope);
     input_scale_ = OpParam::InputScaleFrom<GType>(inputs, scope);
     input_variance_ = OpParam::InputVarianceFrom<GType>(inputs, scope);
     epsilon_ = OpParam::GetAttr<float>("epsilon", attrs);
     momentum_ = OpParam::GetAttr<float>("momentum", attrs);
-    // is_test_ = OpParam::GetAttr<bool>("is_test", attrs);
+    this->output_ = OpParam::OutFrom<GType>(outputs, scope);
   }
-  GType *Output() const { return output_; }
 
   const GType *InputBias() const { return input_bias_; }
@@ -2223,8 +2181,6 @@
 
   const float &Momentum() const { return momentum_; }
 
-  const bool &IsTest() const { return is_test_; }
-
   void SetNewScale(GType *new_scale) { new_scale_ = new_scale; }
 
   void SetNewBias(GType *new_bias) { new_bias_ = new_bias; }
@@ -2234,14 +2190,12 @@
 protected:
-  GType *output_;
   GType *input_bias_;
   GType *input_mean_;
   GType *input_scale_;
   GType *input_variance_;
   float epsilon_;
   float momentum_;
-  bool is_test_;
   GType *new_bias_;
   GType *new_scale_;
 };
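The op_param.h hunks are one refactor applied uniformly: each fused conv param class used to carry its own output_ (or output_y_) member plus an Output() accessor, and now inherits a single output_ from the ConvParam base, assigned at the end of each constructor; the dead is_test_ member and IsTest() accessor are dropped along the way. The this-> qualification is not cosmetic: these params are class templates whose base ConvParam<Dtype> is a dependent type, so an unqualified output_ would not be found during the first phase of name lookup. A reduced, self-contained sketch of the pattern follows; every name other than output_ is illustrative.

// Reduced sketch of the refactor: the output tensor lives once in the
// base, and derived fusion params assign it through this-> because the
// base is a dependent type. Class names are hypothetical.
template <typename Dtype>
class ConvParamSketch {
 public:
  float *Output() const { return output_; }  // single accessor, in the base

 protected:
  float *output_ = nullptr;
};

template <typename Dtype>
class FusionParamSketch : public ConvParamSketch<Dtype> {
 public:
  explicit FusionParamSketch(float *out) {
    // output_ = out;     // error: unqualified lookup ignores dependent bases
    this->output_ = out;  // dependent lookup, resolved at instantiation
  }
};

Writing this->output_ (or, equivalently, ConvParamSketch<Dtype>::output_) defers the lookup to instantiation time, when the base's members are visible; that is why every assignment in the diff gains the this-> prefix while the duplicated members and accessors disappear.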