提交 67059524 编写于 作者: H hjchen2

backup

上级 e4615bde
......@@ -16,7 +16,8 @@ limitations under the License. */
#include "operators/kernel/conv_add_bn_relu_kernel.h"
#include <cmath>
#include "operators/kernel/central-arm-func/conv_add_bn_relu_arm_func.h"
#include "operators/kernel/arm/convolution/conv_common.h"
#include "operators/kernel/central-arm-func/conv_arm_func.h"
namespace paddle_mobile {
namespace operators {
......@@ -51,14 +52,58 @@ bool ConvAddBNReluKernel<CPU, float>::Init(
}
param->SetNewScale(new_scale);
param->SetNewBias(new_bias);
InitBaseConvKernel(param);
return true;
}
template <>
void ConvAddBNReluKernel<CPU, float>::Compute(
const FusionConvAddBNReluParam<CPU> &param) {
ConvAddBNReluCompute<float>(param);
switch (param.ExecMode()) {
case ConvParam<CPU>::EXEC_DEPTHWISE3x3S1P1_FLOAT:
math::DepthwiseConvAddBNRelu3x3s1p1(param.Input(), param.Filter(),
param.Output(), param.NewScale(),
param.NewBias(), true);
break;
case ConvParam<CPU>::EXEC_DEPTHWISE3x3S2P1_FLOAT:
math::DepthwiseConvAddBNRelu3x3s2p1v2(param.Input(), param.Filter(),
param.Output(), param.NewScale(),
param.NewBias(), true);
break;
case ConvParam<CPU>::EXEC_DEPTHWISE3x3S2P0_FLOAT:
math::DepthwiseConv3x3s2p0(param.Input(), param.Filter(), param.Output(),
nullptr, false, false);
math::ScaleAddChannelWise<RELU>(param.Output(), param.NewScale(),
param.NewBias(), param.Output());
break;
case ConvParam<CPU>::EXEC_DEPTHWISE3x3_FLOAT:
math::DepthwiseConv3x3(param.Input(), param.Strides(), param.Paddings(),
param.Filter(), nullptr, param.Output(), false);
math::ScaleAddChannelWise<RELU>(param.Output(), param.NewScale(),
param.NewBias(), param.Output());
break;
#ifndef __aarch64__
case ConvParam<CPU>::EXEC_DEPTHWISE5x5_FLOAT:
DepthwiseConv5x5<float, float>(param);
math::ScaleAddChannelWise<RELU>(param.Output(), param.NewScale(),
param.NewBias(), param.Output());
break;
case ConvParam<CPU>::EXEC_WINOGRAD3X3_FLOAT:
WinogradConv3x3<8, 3>(param);
math::ScaleAddChannelWise<RELU>(param.Output(), param.NewScale(),
param.NewBias(), param.Output());
break;
#endif // __aarch64__
case ConvParam<CPU>::EXEC_GEMM_FLOAT:
ConvBNReluBasic<FusionConvAddBNReluParam<CPU>>(param);
break;
default:
PADDLE_MOBILE_THROW_EXCEPTION("Invalid convolution execute mode %d",
param.ExecMode());
}
}
template class ConvAddBNReluKernel<CPU, float>;
} // namespace operators
......
......@@ -14,7 +14,7 @@ limitations under the License. */
#ifdef FUSION_CONVADD_OP
#include "operators/kernel/conv_add_kernel.h"
#include "../central-arm-func/conv_add_arm_func.h"
#include "operators/kernel/central-arm-func/conv_add_arm_func.h"
namespace paddle_mobile {
namespace operators {
......
......@@ -16,7 +16,8 @@ limitations under the License. */
#include "operators/kernel/conv_bn_relu_kernel.h"
#include <cmath>
#include "operators/kernel/central-arm-func/conv_bn_relu_arm_func.h"
#include "operators/kernel/arm/convolution/conv_common.h"
#include "operators/kernel/central-arm-func/conv_arm_func.h"
namespace paddle_mobile {
namespace operators {
......@@ -29,8 +30,6 @@ bool ConvBNReluKernel<CPU, float>::Init(FusionConvBNReluParam<CPU> *param) {
const Tensor *bias = param->InputBias();
const float epsilon = param->Epsilon();
// DLOG << "variance: " << *variance;
auto mean_ptr = mean->data<float>();
auto variance_ptr = variance->data<float>();
auto scale_ptr = scale->data<float>();
......@@ -50,16 +49,58 @@ bool ConvBNReluKernel<CPU, float>::Init(FusionConvBNReluParam<CPU> *param) {
new_scale_ptr[i] = inv_std_ptr[i] * scale_ptr[i];
new_bias_ptr[i] = bias_ptr[i] - mean_ptr[i] * inv_std_ptr[i] * scale_ptr[i];
}
param->SetNewScale(new_scale);
param->SetNewBias(new_bias);
InitBaseConvKernel(param);
return true;
}
template <>
void ConvBNReluKernel<CPU, float>::Compute(
const FusionConvBNReluParam<CPU> &param) {
ConvBNReluCompute<float>(param);
switch (param.ExecMode()) {
case ConvParam<CPU>::EXEC_DEPTHWISE3x3S1P1_FLOAT:
math::DepthwiseConvAddBNRelu3x3s1p1(param.Input(), param.Filter(),
param.Output(), param.NewScale(),
param.NewBias(), true);
break;
case ConvParam<CPU>::EXEC_DEPTHWISE3x3S2P1_FLOAT:
math::DepthwiseConvAddBNRelu3x3s2p1v2(param.Input(), param.Filter(),
param.Output(), param.NewScale(),
param.NewBias(), true);
break;
case ConvParam<CPU>::EXEC_DEPTHWISE3x3S2P0_FLOAT:
math::DepthwiseConv3x3s2p0(param.Input(), param.Filter(), param.Output(),
nullptr, false, false);
math::ScaleAddChannelWise<RELU>(param.Output(), param.NewScale(),
param.NewBias(), param.Output());
break;
case ConvParam<CPU>::EXEC_DEPTHWISE3x3_FLOAT:
math::DepthwiseConv3x3(param.Input(), param.Strides(), param.Paddings(),
param.Filter(), nullptr, param.Output(), false);
math::ScaleAddChannelWise<RELU>(param.Output(), param.NewScale(),
param.NewBias(), param.Output());
break;
#ifndef __aarch64__
case ConvParam<CPU>::EXEC_DEPTHWISE5x5_FLOAT:
DepthwiseConv5x5<float, float>(param);
math::ScaleAddChannelWise<RELU>(param.Output(), param.NewScale(),
param.NewBias(), param.Output());
break;
case ConvParam<CPU>::EXEC_WINOGRAD3X3_FLOAT:
WinogradConv3x3<8, 3>(param);
math::ScaleAddChannelWise<RELU>(param.Output(), param.NewScale(),
param.NewBias(), param.Output());
break;
#endif // __aarch64__
case ConvParam<CPU>::EXEC_GEMM_FLOAT:
ConvBNReluBasic<FusionConvBNReluParam<CPU>>(param);
break;
default:
PADDLE_MOBILE_THROW_EXCEPTION("Invalid convolution execute mode %d",
param.ExecMode());
}
}
template class ConvBNReluKernel<CPU, float>;
......
......@@ -12,22 +12,20 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifdef CONV_OP
#include "operators/kernel/conv_kernel.h"
#include "operators/kernel/central-arm-func/conv_arm_func.h"
#include "operators/kernel/arm/convolution/conv_common.h"
#include "operators/math/winograd/winograd_transform.h"
namespace paddle_mobile {
namespace operators {
template <>
bool ConvKernel<CPU, float>::Init(ConvParam<CPU> *param) {
void InitBaseConvKernel(ConvParam<CPU> *param) {
bool conv3x3 = param->Filter()->dims()[2] == param->Filter()->dims()[3] &&
param->Filter()->dims()[2] == 3;
bool conv5x5 = param->Filter()->dims()[2] == param->Filter()->dims()[3] &&
param->Filter()->dims()[2] == 5;
bool depth3x3 = conv3x3 && param->Groups() == param->Input()->dims()[1] &&
param->Input()->dims()[1] == param->Output()->dims()[1];
bool depth5x5 = conv5x5 && param->Groups() == param->Input()->dims()[1] &&
param->Input()->dims()[1] == param->Output()->dims()[1];
if (param->Filter()->type() == typeid(int8_t)) {
......@@ -65,10 +63,10 @@ bool ConvKernel<CPU, float>::Init(ConvParam<CPU> *param) {
param->ExecMode() = ConvParam<CPU>::EXEC_DEPTHWISE5x5_FLOAT;
} else if (conv3x3 && param->Strides()[0] == param->Strides()[1] &&
param->Dilations()[0] == param->Dilations()[1] &&
param->Strides()[0] == 1 && param->Dilations()[0] == 1 &&
param->Strides()[0] == 1 && param->Dilations()[0] == 1 /* &&
param->Output()->dims()[1] >= 16 &&
param->Input()->dims()[1] >= 16 &&
param->Input()->dims()[2] <= 140 /* refered from ncnn */) {
param->Input()->dims()[2] <= 140 */ /* refered from ncnn */) {
param->ExecMode() = ConvParam<CPU>::EXEC_WINOGRAD3X3_FLOAT;
// transform weight
param->transformed_filter_ = new framework::LoDTensor;
......@@ -79,59 +77,7 @@ bool ConvKernel<CPU, float>::Init(ConvParam<CPU> *param) {
param->ExecMode() = ConvParam<CPU>::EXEC_GEMM_FLOAT;
}
}
return true;
}
template <>
void ConvKernel<CPU, float>::Compute(const ConvParam<CPU> &param) {
switch (param.ExecMode()) {
case ConvParam<CPU>::EXEC_GEMM_INT8:
GemmConv<int8_t, int32_t>(param);
break;
#ifndef __aarch64__
case ConvParam<CPU>::EXEC_DEPTHWISE3x3_INT8:
DepthwiseConv3x3<int8_t, int32_t>(param);
break;
case ConvParam<CPU>::EXEC_DEPTHWISE5x5_INT8:
DepthwiseConv5x5<int8_t, int32_t>(param);
break;
#endif // __aarch64__
case ConvParam<CPU>::EXEC_DEPTHWISE3x3S1P1_FLOAT:
math::DepthwiseConv3x3s1p1(param.Input(), param.Filter(), param.Output(),
nullptr, false, false);
break;
case ConvParam<CPU>::EXEC_DEPTHWISE3x3S2P1_FLOAT:
math::DepthwiseConv3x3s2p1v2(param.Input(), param.Filter(),
param.Output(), nullptr, false, false);
break;
case ConvParam<CPU>::EXEC_DEPTHWISE3x3S2P0_FLOAT:
math::DepthwiseConv3x3s2p0(param.Input(), param.Filter(), param.Output(),
nullptr, false, false);
break;
case ConvParam<CPU>::EXEC_DEPTHWISE3x3_FLOAT:
math::DepthwiseConv3x3(param.Input(), param.Strides(), param.Paddings(),
param.Filter(), nullptr, param.Output(), false);
break;
#ifndef __aarch64__
case ConvParam<CPU>::EXEC_DEPTHWISE5x5_FLOAT:
DepthwiseConv5x5<float, float>(param);
break;
case ConvParam<CPU>::EXEC_WINOGRAD3X3_FLOAT:
WinogradConv3x3<8, 3>(param);
break;
#endif // __aarch64__
case ConvParam<CPU>::EXEC_GEMM_FLOAT:
GemmConv<float, float>(param);
break;
default:
PADDLE_MOBILE_THROW_EXCEPTION("Invalid convolution execute mode %d",
param.ExecMode());
}
}
template class ConvKernel<CPU, float>;
} // namespace operators
} // namespace paddle_mobile
#endif
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "operators/op_param.h"
namespace paddle_mobile {
namespace operators {
void InitBaseConvKernel(ConvParam<CPU> *param);
} // namespace operators
} // namespace paddle_mobile
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifdef CONV_OP
#include "operators/kernel/conv_kernel.h"
#include "operators/kernel/arm/convolution/conv_common.h"
#include "operators/kernel/central-arm-func/conv_arm_func.h"
namespace paddle_mobile {
namespace operators {
template <>
bool ConvKernel<CPU, float>::Init(ConvParam<CPU> *param) {
InitBaseConvKernel(param);
return true;
}
template <>
void ConvKernel<CPU, float>::Compute(const ConvParam<CPU> &param) {
switch (param.ExecMode()) {
case ConvParam<CPU>::EXEC_GEMM_INT8:
GemmConv<int8_t, int32_t>(param);
break;
#ifndef __aarch64__
case ConvParam<CPU>::EXEC_DEPTHWISE3x3_INT8:
DepthwiseConv3x3<int8_t, int32_t>(param);
break;
case ConvParam<CPU>::EXEC_DEPTHWISE5x5_INT8:
DepthwiseConv5x5<int8_t, int32_t>(param);
break;
#endif // __aarch64__
case ConvParam<CPU>::EXEC_DEPTHWISE3x3S1P1_FLOAT:
math::DepthwiseConv3x3s1p1(param.Input(), param.Filter(), param.Output(),
nullptr, false, false);
break;
case ConvParam<CPU>::EXEC_DEPTHWISE3x3S2P1_FLOAT:
math::DepthwiseConv3x3s2p1v2(param.Input(), param.Filter(),
param.Output(), nullptr, false, false);
break;
case ConvParam<CPU>::EXEC_DEPTHWISE3x3S2P0_FLOAT:
math::DepthwiseConv3x3s2p0(param.Input(), param.Filter(), param.Output(),
nullptr, false, false);
break;
case ConvParam<CPU>::EXEC_DEPTHWISE3x3_FLOAT:
math::DepthwiseConv3x3(param.Input(), param.Strides(), param.Paddings(),
param.Filter(), nullptr, param.Output(), false);
break;
#ifndef __aarch64__
case ConvParam<CPU>::EXEC_DEPTHWISE5x5_FLOAT:
DepthwiseConv5x5<float, float>(param);
break;
case ConvParam<CPU>::EXEC_WINOGRAD3X3_FLOAT:
WinogradConv3x3<8, 3>(param);
break;
#endif // __aarch64__
case ConvParam<CPU>::EXEC_GEMM_FLOAT:
GemmConv<float, float>(param);
break;
default:
PADDLE_MOBILE_THROW_EXCEPTION("Invalid convolution execute mode %d",
param.ExecMode());
}
}
template class ConvKernel<CPU, float>;
} // namespace operators
} // namespace paddle_mobile
#endif
......@@ -16,7 +16,8 @@ limitations under the License. */
#include "operators/kernel/dwconv_bn_relu_kernel.h"
#include <cmath>
#include "operators/kernel/central-arm-func/dwconv_bn_relu_arm_func.h"
#include "operators/kernel/arm/convolution/conv_common.h"
#include "operators/kernel/central-arm-func/conv_arm_func.h"
namespace paddle_mobile {
namespace operators {
......@@ -50,13 +51,56 @@ bool DWConvBNReluKernel<CPU, float>::Init(FusionDWConvBNReluParam<CPU> *param) {
}
param->SetNewScale(new_scale);
param->SetNewBias(new_bias);
InitBaseConvKernel(param);
return true;
}
template <>
void DWConvBNReluKernel<CPU, float>::Compute(
const FusionDWConvBNReluParam<CPU> &param) {
DWConvBNReluCompute<float>(param);
switch (param.ExecMode()) {
case ConvParam<CPU>::EXEC_DEPTHWISE3x3S1P1_FLOAT:
math::DepthwiseConvAddBNRelu3x3s1p1(param.Input(), param.Filter(),
param.Output(), param.NewScale(),
param.NewBias(), true);
break;
case ConvParam<CPU>::EXEC_DEPTHWISE3x3S2P1_FLOAT:
math::DepthwiseConvAddBNRelu3x3s2p1v2(param.Input(), param.Filter(),
param.Output(), param.NewScale(),
param.NewBias(), true);
break;
case ConvParam<CPU>::EXEC_DEPTHWISE3x3S2P0_FLOAT:
math::DepthwiseConv3x3s2p0(param.Input(), param.Filter(), param.Output(),
nullptr, false, false);
math::ScaleAddChannelWise<RELU>(param.Output(), param.NewScale(),
param.NewBias(), param.Output());
break;
case ConvParam<CPU>::EXEC_DEPTHWISE3x3_FLOAT:
math::DepthwiseConv3x3(param.Input(), param.Strides(), param.Paddings(),
param.Filter(), nullptr, param.Output(), false);
math::ScaleAddChannelWise<RELU>(param.Output(), param.NewScale(),
param.NewBias(), param.Output());
break;
#ifndef __aarch64__
case ConvParam<CPU>::EXEC_DEPTHWISE5x5_FLOAT:
DepthwiseConv5x5<float, float>(param);
math::ScaleAddChannelWise<RELU>(param.Output(), param.NewScale(),
param.NewBias(), param.Output());
break;
case ConvParam<CPU>::EXEC_WINOGRAD3X3_FLOAT:
WinogradConv3x3<8, 3>(param);
math::ScaleAddChannelWise<RELU>(param.Output(), param.NewScale(),
param.NewBias(), param.Output());
break;
#endif // __aarch64__
case ConvParam<CPU>::EXEC_GEMM_FLOAT:
ConvBNReluBasic<FusionDWConvBNReluParam<CPU>>(param);
break;
default:
PADDLE_MOBILE_THROW_EXCEPTION("Invalid convolution execute mode %d",
param.ExecMode());
}
}
template class DWConvBNReluKernel<CPU, float>;
......
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifdef FUSION_CONVADDBNRELU_OP
#pragma once
#include <vector>
#include "operators/math/depthwise_conv3x3.h"
#include "operators/math/im2col.h"
#include "operators/math/math_function.h"
#include "operators/math/vol2col.h"
#include "operators/op_param.h"
namespace paddle_mobile {
namespace operators {
void ConvAddBNReluBasic(const FusionConvAddBNReluParam<CPU> &param) {
const Tensor *input = param.Input();
Tensor filter = *param.Filter();
Tensor new_bias = *param.NewBias();
Tensor new_scale = *param.NewScale();
Tensor *output = param.Output();
output->mutable_data<float>();
int groups = param.Groups();
std::vector<int> strides = param.Strides();
std::vector<int> paddings = param.Paddings();
std::vector<int> dilations = param.Dilations();
const int batch_size = static_cast<int>(input->dims()[0]);
std::vector<int64_t> filter_shape_vec(framework::vectorize(filter.dims()));
std::vector<int64_t> output_shape_vec(framework::vectorize(output->dims()));
size_t data_dim = filter_shape_vec.size() - 2;
std::vector<int64_t> col_shape_vec(1 + 2 * data_dim);
col_shape_vec[0] = input->dims()[1] / groups;
for (size_t j = 0; j < data_dim; ++j) {
col_shape_vec[j + 1] = filter_shape_vec[j + 2];
col_shape_vec[j + 1 + data_dim] = output_shape_vec[j + 2];
}
framework::DDim col_shape(framework::make_ddim(col_shape_vec));
framework::DDim col_matrix_shape =
framework::flatten_to_2d(col_shape, data_dim + 1);
bool is_expand =
math::IsExpand(filter_shape_vec, strides, paddings, dilations);
Tensor col;
Tensor col_matrix;
if (is_expand) {
col.mutable_data<float>(col_shape);
col_matrix.ShareDataWith(col);
col_matrix.Resize(col_matrix_shape);
}
framework::DDim input_shape = framework::slice_ddim(
input->dims(), 1, static_cast<int>(input->dims().size()));
framework::DDim filter_matrix_shape = {filter.dims()[0],
filter.numel() / filter.dims()[0]};
filter.Resize(filter_matrix_shape);
framework::DDim output_matrix_shape = {
output->dims()[1],
output->numel() / (output->dims()[0] * output->dims()[1])};
// convolution operator: im2col(or vol2col) + gemm
int in_step = static_cast<int>(input->dims()[1]) / groups;
int out_step = static_cast<int>(output->dims()[1]) / groups;
math::Vol2ColFunctor<CPU, float> vol2col;
math::Im2ColFunctor<math::ColFormat::kCFO, CPU, float> im2col;
for (int i = 0; i < batch_size; i++) {
Tensor in_batch = input->Slice(i, i + 1).Resize(input_shape);
Tensor out_batch = output->Slice(i, i + 1).Resize(output_matrix_shape);
for (int g = 0; g < groups; g++) {
Tensor in_slice = in_batch.Slice(g * in_step, (g + 1) * in_step);
if (!is_expand) {
col.ShareDataWith(in_slice);
col_matrix.ShareDataWith(col);
col_matrix.Resize(col_matrix_shape);
} else if (data_dim == 2U) {
// im2col
im2col(in_slice, dilations, strides,
std::vector<int>{paddings[0], paddings[1], paddings[0],
paddings[1]},
&col);
} else if (data_dim == 3U) {
// vol2col
vol2col(in_slice, dilations, strides, paddings, &col);
}
// gemm
Tensor out_slice = out_batch.Slice(g * out_step, (g + 1) * out_step);
Tensor filter_slice = filter.Slice(g * out_step, (g + 1) * out_step);
math::MatMulWithBn(filter_slice, false, col_matrix, false,
static_cast<float>(1), &out_slice,
static_cast<float>(0), true, &new_scale, &new_bias, g);
}
}
}
template <typename P>
void ConvAddBNReluCompute(const FusionConvAddBNReluParam<CPU> &param) {
if (param.Groups() == param.Input()->dims()[1] &&
param.Input()->dims()[1] == param.Output()->dims()[1] &&
param.Filter()->dims()[2] == param.Filter()->dims()[3] &&
param.Filter()->dims()[2] == 3 && param.Strides()[0] == 1) {
math::DepthwiseConvAddBNRelu3x3s1p1(param.Input(), param.Filter(),
param.Output(), param.NewScale(),
param.NewBias(), true);
} else if (param.Groups() == param.Input()->dims()[1] &&
param.Input()->dims()[1] == param.Output()->dims()[1] &&
param.Filter()->dims()[2] == param.Filter()->dims()[3] &&
param.Filter()->dims()[2] == 3 && param.Strides()[0] == 2) {
math::DepthwiseConvAddBNRelu3x3s2p1v2(param.Input(), param.Filter(),
param.Output(), param.NewScale(),
param.NewBias(), true);
} else {
ConvAddBNReluBasic(param);
}
}
} // namespace operators
} // namespace paddle_mobile
#endif
......@@ -212,6 +212,95 @@ inline void DepthwiseConv5x5(const ConvParam<CPU> &param) {
}
#endif // __aarch64__
template <typename ParamType>
void ConvBNReluBasic(const ParamType &param) {
const Tensor *input = param.Input();
Tensor filter = *param.Filter();
Tensor new_bias = *param.NewBias();
Tensor new_scale = *param.NewScale();
Tensor *output = param.Output();
output->mutable_data<float>();
int groups = param.Groups();
std::vector<int> strides = param.Strides();
std::vector<int> paddings = param.Paddings();
std::vector<int> dilations = param.Dilations();
const int batch_size = static_cast<int>(input->dims()[0]);
std::vector<int64_t> filter_shape_vec(framework::vectorize(filter.dims()));
std::vector<int64_t> output_shape_vec(framework::vectorize(output->dims()));
size_t data_dim = filter_shape_vec.size() - 2;
std::vector<int64_t> col_shape_vec(1 + 2 * data_dim);
col_shape_vec[0] = input->dims()[1] / groups;
for (size_t j = 0; j < data_dim; ++j) {
col_shape_vec[j + 1] = filter_shape_vec[j + 2];
col_shape_vec[j + 1 + data_dim] = output_shape_vec[j + 2];
}
framework::DDim col_shape(framework::make_ddim(col_shape_vec));
framework::DDim col_matrix_shape =
framework::flatten_to_2d(col_shape, data_dim + 1);
bool is_expand =
math::IsExpand(filter_shape_vec, strides, paddings, dilations);
Tensor col;
Tensor col_matrix;
if (is_expand) {
col.mutable_data<float>(col_shape);
col_matrix.ShareDataWith(col);
col_matrix.Resize(col_matrix_shape);
}
framework::DDim input_shape = framework::slice_ddim(
input->dims(), 1, static_cast<int>(input->dims().size()));
framework::DDim filter_matrix_shape = {filter.dims()[0],
filter.numel() / filter.dims()[0]};
filter.Resize(filter_matrix_shape);
framework::DDim output_matrix_shape = {
output->dims()[1],
output->numel() / (output->dims()[0] * output->dims()[1])};
// convolution operator: im2col(or vol2col) + gemm
int in_step = static_cast<int>(input->dims()[1]) / groups;
int out_step = static_cast<int>(output->dims()[1]) / groups;
math::Vol2ColFunctor<CPU, float> vol2col;
math::Im2ColFunctor<math::ColFormat::kCFO, CPU, float> im2col;
for (int i = 0; i < batch_size; i++) {
Tensor in_batch = input->Slice(i, i + 1).Resize(input_shape);
Tensor out_batch = output->Slice(i, i + 1).Resize(output_matrix_shape);
for (int g = 0; g < groups; g++) {
Tensor in_slice = in_batch.Slice(g * in_step, (g + 1) * in_step);
if (!is_expand) {
col_matrix = in_slice;
col_matrix.Resize(col_matrix_shape);
} else if (data_dim == 2U) {
// im2col
im2col(in_slice, dilations, strides,
std::vector<int>{paddings[0], paddings[1], paddings[0],
paddings[1]},
&col);
} else if (data_dim == 3U) {
// vol2col
vol2col(in_slice, dilations, strides, paddings, &col);
}
// gemm
Tensor out_slice = out_batch.Slice(g * out_step, (g + 1) * out_step);
Tensor filter_slice = filter.Slice(g * out_step, (g + 1) * out_step);
math::MatMulWithBn(filter_slice, false, col_matrix, false,
static_cast<float>(1), &out_slice,
static_cast<float>(0), true, &new_scale, &new_bias, g);
}
}
}
} // namespace operators
} // namespace paddle_mobile
......
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifdef FUSION_CONVBNRELU_OP
#pragma once
#include <vector>
#include "operators/math/depthwise_conv3x3.h"
#include "operators/math/im2col.h"
#include "operators/math/math_function.h"
#include "operators/math/vol2col.h"
#include "operators/op_param.h"
namespace paddle_mobile {
namespace operators {
void ConvBNReluBasic(const FusionConvBNReluParam<CPU> &param) {
const Tensor *input = param.Input();
Tensor filter = *param.Filter();
Tensor new_bias = *param.NewBias();
Tensor new_scale = *param.NewScale();
Tensor *output = param.Output();
output->mutable_data<float>();
int groups = param.Groups();
std::vector<int> strides = param.Strides();
std::vector<int> paddings = param.Paddings();
std::vector<int> dilations = param.Dilations();
const int batch_size = static_cast<int>(input->dims()[0]);
std::vector<int64_t> filter_shape_vec(framework::vectorize(filter.dims()));
std::vector<int64_t> output_shape_vec(framework::vectorize(output->dims()));
size_t data_dim = filter_shape_vec.size() - 2;
std::vector<int64_t> col_shape_vec(1 + 2 * data_dim);
col_shape_vec[0] = input->dims()[1] / groups;
for (size_t j = 0; j < data_dim; ++j) {
col_shape_vec[j + 1] = filter_shape_vec[j + 2];
col_shape_vec[j + 1 + data_dim] = output_shape_vec[j + 2];
}
framework::DDim col_shape(framework::make_ddim(col_shape_vec));
framework::DDim col_matrix_shape =
framework::flatten_to_2d(col_shape, data_dim + 1);
bool is_expand =
math::IsExpand(filter_shape_vec, strides, paddings, dilations);
Tensor col;
Tensor col_matrix;
if (is_expand) {
col.mutable_data<float>(col_shape);
col_matrix.ShareDataWith(col);
col_matrix.Resize(col_matrix_shape);
}
framework::DDim input_shape = framework::slice_ddim(
input->dims(), 1, static_cast<int>(input->dims().size()));
framework::DDim filter_matrix_shape = {filter.dims()[0],
filter.numel() / filter.dims()[0]};
filter.Resize(filter_matrix_shape);
framework::DDim output_matrix_shape = {
output->dims()[1],
output->numel() / (output->dims()[0] * output->dims()[1])};
// convolution operator: im2col(or vol2col) + gemm
int in_step = static_cast<int>(input->dims()[1]) / groups;
int out_step = static_cast<int>(output->dims()[1]) / groups;
math::Vol2ColFunctor<CPU, float> vol2col;
math::Im2ColFunctor<math::ColFormat::kCFO, CPU, float> im2col;
for (int i = 0; i < batch_size; i++) {
Tensor in_batch = input->Slice(i, i + 1).Resize(input_shape);
Tensor out_batch = output->Slice(i, i + 1).Resize(output_matrix_shape);
for (int g = 0; g < groups; g++) {
Tensor in_slice = in_batch.Slice(g * in_step, (g + 1) * in_step);
if (!is_expand) {
col.ShareDataWith(in_slice);
col_matrix.ShareDataWith(col);
col_matrix.Resize(col_matrix_shape);
} else if (data_dim == 2U) {
// im2col
im2col(in_slice, dilations, strides,
std::vector<int>{paddings[0], paddings[1], paddings[0],
paddings[1]},
&col);
} else if (data_dim == 3U) {
// vol2col
vol2col(in_slice, dilations, strides, paddings, &col);
}
// gemm
Tensor out_slice = out_batch.Slice(g * out_step, (g + 1) * out_step);
Tensor filter_slice = filter.Slice(g * out_step, (g + 1) * out_step);
math::MatMulWithBn(filter_slice, false, col_matrix, false,
static_cast<float>(1), &out_slice,
static_cast<float>(0), true, &new_scale, &new_bias, g);
}
}
}
template <typename P>
void ConvBNReluCompute(const FusionConvBNReluParam<CPU> &param) {
if (param.Groups() == param.Input()->dims()[1] &&
param.Input()->dims()[1] == param.Output()->dims()[1] &&
param.Filter()->dims()[2] == param.Filter()->dims()[3] &&
param.Filter()->dims()[2] == 3 && param.Strides()[0] == 1) {
math::DepthwiseConvAddBNRelu3x3s1p1(param.Input(), param.Filter(),
param.Output(), param.NewScale(),
param.NewBias(), true);
} else if (param.Groups() == param.Input()->dims()[1] &&
param.Input()->dims()[1] == param.Output()->dims()[1] &&
param.Filter()->dims()[2] == param.Filter()->dims()[3] &&
param.Filter()->dims()[2] == 3 && param.Strides()[0] == 2) {
// math::DepthwiseConvAddBNRelu3x3s2p1(param.Input(), param.Filter(),
// param.Output(), param.NewScale(),
// param.NewBias(), 1);
math::DepthwiseConvAddBNRelu3x3s2p1v2(param.Input(), param.Filter(),
param.Output(), param.NewScale(),
param.NewBias(), true);
} else {
ConvBNReluBasic(param);
}
}
} // namespace operators
} // namespace paddle_mobile
#endif
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifdef FUSION_DWCONVBNRELU_OP
#pragma once
#include <vector>
#include "operators/math/depthwise_conv3x3.h"
#include "operators/math/im2col.h"
#include "operators/math/math_function.h"
#include "operators/math/vol2col.h"
#include "operators/op_param.h"
namespace paddle_mobile {
namespace operators {
void DWConvBNReluBasic(const FusionDWConvBNReluParam<CPU> &param) {
const Tensor *input = param.Input();
Tensor filter = *param.Filter();
Tensor new_bias = *param.NewBias();
Tensor new_scale = *param.NewScale();
Tensor *output = param.Output();
output->mutable_data<float>();
int groups = param.Groups();
std::vector<int> strides = param.Strides();
std::vector<int> paddings = param.Paddings();
std::vector<int> dilations = param.Dilations();
const int batch_size = static_cast<int>(input->dims()[0]);
std::vector<int64_t> filter_shape_vec(framework::vectorize(filter.dims()));
std::vector<int64_t> output_shape_vec(framework::vectorize(output->dims()));
size_t data_dim = filter_shape_vec.size() - 2;
std::vector<int64_t> col_shape_vec(1 + 2 * data_dim);
col_shape_vec[0] = input->dims()[1] / groups;
for (size_t j = 0; j < data_dim; ++j) {
col_shape_vec[j + 1] = filter_shape_vec[j + 2];
col_shape_vec[j + 1 + data_dim] = output_shape_vec[j + 2];
}
framework::DDim col_shape(framework::make_ddim(col_shape_vec));
framework::DDim col_matrix_shape =
framework::flatten_to_2d(col_shape, data_dim + 1);
bool is_expand =
math::IsExpand(filter_shape_vec, strides, paddings, dilations);
Tensor col;
Tensor col_matrix;
if (is_expand) {
col.mutable_data<float>(col_shape);
col_matrix.ShareDataWith(col);
col_matrix.Resize(col_matrix_shape);
}
framework::DDim input_shape = framework::slice_ddim(
input->dims(), 1, static_cast<int>(input->dims().size()));
framework::DDim filter_matrix_shape = {filter.dims()[0],
filter.numel() / filter.dims()[0]};
filter.Resize(filter_matrix_shape);
framework::DDim output_matrix_shape = {
output->dims()[1],
output->numel() / (output->dims()[0] * output->dims()[1])};
// convolution operator: im2col(or vol2col) + gemm
int in_step = static_cast<int>(input->dims()[1]) / groups;
int out_step = static_cast<int>(output->dims()[1]) / groups;
math::Vol2ColFunctor<CPU, float> vol2col;
math::Im2ColFunctor<math::ColFormat::kCFO, CPU, float> im2col;
for (int i = 0; i < batch_size; i++) {
Tensor in_batch = input->Slice(i, i + 1).Resize(input_shape);
Tensor out_batch = output->Slice(i, i + 1).Resize(output_matrix_shape);
for (int g = 0; g < groups; g++) {
Tensor in_slice = in_batch.Slice(g * in_step, (g + 1) * in_step);
if (!is_expand) {
col.ShareDataWith(in_slice);
col_matrix.ShareDataWith(col);
col_matrix.Resize(col_matrix_shape);
} else if (data_dim == 2U) {
// im2col
im2col(in_slice, dilations, strides,
std::vector<int>{paddings[0], paddings[1], paddings[0],
paddings[1]},
&col);
} else if (data_dim == 3U) {
// vol2col
vol2col(in_slice, dilations, strides, paddings, &col);
}
// gemm
Tensor out_slice = out_batch.Slice(g * out_step, (g + 1) * out_step);
Tensor filter_slice = filter.Slice(g * out_step, (g + 1) * out_step);
math::MatMulWithBn(filter_slice, false, col_matrix, false,
static_cast<float>(1), &out_slice,
static_cast<float>(0), true, &new_scale, &new_bias, g);
}
}
}
template <typename P>
void DWConvBNReluCompute(const FusionDWConvBNReluParam<CPU> &param) {
if (param.Groups() == param.Input()->dims()[1] &&
param.Input()->dims()[1] == param.Output()->dims()[1] &&
param.Filter()->dims()[2] == param.Filter()->dims()[3] &&
param.Filter()->dims()[2] == 3 && param.Strides()[0] == 1) {
math::DepthwiseConvAddBNRelu3x3s1p1(param.Input(), param.Filter(),
param.Output(), param.NewScale(),
param.NewBias(), true);
} else if (param.Groups() == param.Input()->dims()[1] &&
param.Input()->dims()[1] == param.Output()->dims()[1] &&
param.Filter()->dims()[2] == param.Filter()->dims()[3] &&
param.Filter()->dims()[2] == 3 && param.Strides()[0] == 2) {
// math::DepthwiseConvAddBNRelu3x3s2p1(param.Input(), param.Filter(),
// param.Output(), param.NewScale(),
// param.NewBias(), 1);
math::DepthwiseConvAddBNRelu3x3s2p1v2(param.Input(), param.Filter(),
param.Output(), param.NewScale(),
param.NewBias(), true);
} else {
DWConvBNReluBasic(param);
}
}
} // namespace operators
} // namespace paddle_mobile
#endif
......@@ -14,12 +14,13 @@ limitations under the License. */
#pragma once
#include <vector>
#ifdef __ARM_NEON
#include <arm_neon.h>
#endif
#include "framework/ddim.h"
#include "framework/tensor.h"
#include "operators/math/activation.h"
namespace paddle_mobile {
namespace operators {
......@@ -35,8 +36,8 @@ inline int ConvOutputSize(int input_size, int filter_size, int dilation,
return output_size;
}
inline void expand_bias(Tensor &bias, int axis, const DDim &dDim) {
auto bias_ptr = bias.data<float>();
inline void expand_bias(Tensor &bias, int axis, const DDim &dDim) { // NOLINT
const auto bias_ptr = bias.data<float>();
const DDim bias_ddim = bias.dims();
PADDLE_MOBILE_ENFORCE(bias.dims().size() == 1,
"the bias tensor's dims size != 1")
......@@ -98,6 +99,63 @@ inline bool IsExpand(const std::vector<int64_t> &filter_dim,
return !(filter_1 && strides_1 && padding_0 && dilation_1);
}
template <ActivationType Act>
void ScaleAddChannelWise(const framework::Tensor *input,
const framework::Tensor *scale,
const framework::Tensor *bias,
framework::Tensor *output) {
const float *input_ptr = input->data<float>();
const float *scale_ptr = scale->data<float>();
const float *bias_ptr = bias->data<float>();
float *output_ptr = output->mutable_data<float>();
// maybe check shape
int batch_size = input->dims()[0];
int channels = input->dims()[1];
size_t spatial_size = input->dims()[2] * input->dims()[3];
for (int batch = 0; batch < batch_size; ++batch) {
for (int channel = 0; channel < channels; ++channel) {
size_t offset = (batch * channels + channel) * spatial_size;
const float *x = input_ptr + offset;
float *y = output_ptr + offset;
float alpha = scale_ptr[channel];
float beta = bias_ptr[channel];
int j = 0;
#if defined(__ARM_NEON__) || defined(__ARM_NEON)
float32x4_t __scale = vdupq_n_f32(alpha);
float32x4_t __bias = vdupq_n_f32(beta);
for (; j < spatial_size - 15; j += 16, x += 16, y += 16) {
float32x4_t in0 = vld1q_f32(x);
float32x4_t in1 = vld1q_f32(x + 4);
float32x4_t in2 = vld1q_f32(x + 8);
float32x4_t in3 = vld1q_f32(x + 12);
in0 = vmlaq_f32(__bias, __scale, in0);
in1 = vmlaq_f32(__bias, __scale, in1);
in2 = vmlaq_f32(__bias, __scale, in2);
in3 = vmlaq_f32(__bias, __scale, in3);
in0 = math::vActiveq_f32<Act>(in0);
in1 = math::vActiveq_f32<Act>(in1);
in2 = math::vActiveq_f32<Act>(in2);
in3 = math::vActiveq_f32<Act>(in3);
vst1q_f32(y, in0);
vst1q_f32(y + 4, in1);
vst1q_f32(y + 8, in2);
vst1q_f32(y + 12, in3);
}
for (; j < spatial_size - 3; j += 4, x += 4, y += 4) {
float32x4_t in0 = vld1q_f32(x);
in0 = vmlaq_f32(__bias, __scale, in0);
in0 = math::vActiveq_f32<Act>(in0);
vst1q_f32(y, in0);
}
#endif
for (; j < spatial_size; ++j, ++x, ++y) {
*y = math::Active<Act>(alpha * (*x) + beta);
}
}
}
}
} // namespace math
} // namespace operators
} // namespace paddle_mobile
......@@ -3255,8 +3255,6 @@ void Gemm::Sgemm(int m, int n, int k, float alpha, const float *A, int lda,
paddle_mobile::memory::Alloc(sizeof(float) * KC * NC));
packedC = static_cast<float *>(
paddle_mobile::memory::Alloc(sizeof(float) * MC * NC));
zero = static_cast<float *>(paddle_mobile::memory::Alloc(sizeof(float) * KC));
memset(static_cast<void *>(zero), 0, sizeof(float) * KC);
int mc, nc;
for (int j = 0; j < n; j += NC) {
......@@ -3288,7 +3286,6 @@ void Gemm::Sgemm(int m, int n, int k, float alpha, const float *A, int lda,
paddle_mobile::memory::Free(packedA);
paddle_mobile::memory::Free(packedB);
paddle_mobile::memory::Free(packedC);
paddle_mobile::memory::Free(zero);
}
void Gemm::SgemmWithBn(int m, int n, int k, float alpha, const float *A,
......@@ -3328,8 +3325,6 @@ void Gemm::SgemmWithBn(int m, int n, int k, float alpha, const float *A,
paddle_mobile::memory::Alloc(sizeof(float) * KC * NC));
packedC = static_cast<float *>(
paddle_mobile::memory::Alloc(sizeof(float) * MC * NC));
zero = static_cast<float *>(paddle_mobile::memory::Alloc(sizeof(float) * KC));
memset(static_cast<void *>(zero), 0, sizeof(float) * KC);
int mc, nc;
for (int j = 0; j < n; j += NC) {
......@@ -3362,7 +3357,6 @@ void Gemm::SgemmWithBn(int m, int n, int k, float alpha, const float *A,
paddle_mobile::memory::Free(packedA);
paddle_mobile::memory::Free(packedB);
paddle_mobile::memory::Free(packedC);
paddle_mobile::memory::Free(zero);
}
void Gemm::SgemmWithPRelu(int m, int n, int k, const float *A, int lda,
......@@ -3401,11 +3395,6 @@ void Gemm::SgemmWithPRelu(int m, int n, int k, const float *A, int lda,
paddle_mobile::memory::Alloc(sizeof(float) * KC * NC));
packedC = static_cast<float *>(
paddle_mobile::memory::Alloc(sizeof(float) * MC * NC));
zero = static_cast<float *>(paddle_mobile::memory::Alloc(sizeof(float) * KC));
for (int l = 0; l < KC; ++l) {
zero[l] = 0;
}
int mc, nc;
for (int j = 0; j < n; j += NC) {
......@@ -3437,7 +3426,6 @@ void Gemm::SgemmWithPRelu(int m, int n, int k, const float *A, int lda,
paddle_mobile::memory::Free(packedA);
paddle_mobile::memory::Free(packedB);
paddle_mobile::memory::Free(packedC);
paddle_mobile::memory::Free(zero);
}
// 32位 float 矩阵乘法
......@@ -3459,8 +3447,6 @@ void Gemm::Sgemm_omp(int m, int n, int k, float alpha, const float *A, int lda,
int L = (max_threads > 2) ? 64 : 32;
int L1 = L / max_threads * 1024;
KC = k;
zero = static_cast<float *>(paddle_mobile::memory::Alloc(sizeof(float) * KC));
memset(static_cast<void *>(zero), 0, sizeof(float) * KC);
if (m > n) {
// 对 A 分块
MC = L1 / (KC * sizeof(float));
......@@ -3566,7 +3552,6 @@ void Gemm::Sgemm_omp(int m, int n, int k, float alpha, const float *A, int lda,
paddle_mobile::memory::Free(packedA);
paddle_mobile::memory::Free(packedB);
paddle_mobile::memory::Free(packedC);
paddle_mobile::memory::Free(zero);
}
void Gemm::SgemmWithBn_omp(int m, int n, int k, float alpha, const float *A,
......@@ -3581,8 +3566,6 @@ void Gemm::SgemmWithBn_omp(int m, int n, int k, float alpha, const float *A,
int L1 = 64 / max_threads * 1024;
KC = k;
zero = static_cast<float *>(paddle_mobile::memory::Alloc(sizeof(float) * KC));
memset(static_cast<void *>(zero), 0, sizeof(float) * KC);
if (m > n) {
// 对 A 分块
MC = L1 / (KC * sizeof(float));
......@@ -3694,7 +3677,6 @@ void Gemm::SgemmWithBn_omp(int m, int n, int k, float alpha, const float *A,
paddle_mobile::memory::Free(packedA);
paddle_mobile::memory::Free(packedB);
paddle_mobile::memory::Free(packedC);
paddle_mobile::memory::Free(zero);
}
void Gemm::SgemmWithPRelu_omp(int m, int n, int k, const float *A, int lda,
......@@ -3709,8 +3691,6 @@ void Gemm::SgemmWithPRelu_omp(int m, int n, int k, const float *A, int lda,
int L1 = 8 * 1024;
KC = k;
zero = static_cast<float *>(paddle_mobile::memory::Alloc(sizeof(float) * KC));
memset(static_cast<void *>(zero), 0, sizeof(float) * KC);
if (m > n) {
// 对 A 分块
MC = L1 / (KC * sizeof(float));
......@@ -3820,7 +3800,6 @@ void Gemm::SgemmWithPRelu_omp(int m, int n, int k, const float *A, int lda,
paddle_mobile::memory::Free(packedA);
paddle_mobile::memory::Free(packedB);
paddle_mobile::memory::Free(packedC);
paddle_mobile::memory::Free(zero);
}
} // namespace math
......
......@@ -260,7 +260,6 @@ class Gemm {
float *packedA;
float *packedB;
float *packedC;
float *zero;
// 8 bits int
int8_t *packedA_int8;
......
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "operators/math/gemm/cblas.h"
#include "operators/math/gemm/cpu_info.h"
#include "operators/math/gemm/executor.h"
#include "operators/math/gemm/strategy.h"
namespace paddle_mobile {
namespace operators {
namespace math {
void cblas_sgemm(const bool transA, const bool transB, const int M, const int N,
const int K, const float alpha, const float *A, const int lda,
const float *B, const int ldb, const float beta, float *C,
const int ldc) {
if (N == 1) {
return cblas_sgemv(transA, M, K, alpha, A, lda, B, beta, C);
}
CPUInfo *info = CPUInfo::Info();
GemmExecutor<SgemmStrategy> exec(info, transA, transB, M, N, K);
exec(alpha, A, lda, B, ldb, beta, C, ldc);
}
void cblas_sgemv(const bool trans, const int M, const int N, const float alpha,
const float *A, const int lda, const float *B,
const float beta, float *C) {
CPUInfo *info = CPUInfo::Info();
GemvExecutor<SgemvStrategy> exec(info, trans, M, N);
exec(alpha, A, lda, B, beta, C);
}
} // namespace math
} // namespace operators
} // namespace paddle_mobile
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
namespace paddle_mobile {
namespace operators {
namespace math {
void cblas_sgemm(const bool transA, const bool transB, const int M, const int N,
const int K, const float alpha, const float *A, const int lda,
const float *B, const int ldb, const float beta, float *C,
const int ldc);
void cblas_sgemv(const bool trans, const int M, const int N, const float alpha,
const float *A, const int lda, const float *B,
const float beta, float *C);
} // namespace math
} // namespace operators
} // namespace paddle_mobile
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#define MOBILE_MAX_CPU_NUM 8
namespace paddle_mobile {
namespace operators {
namespace math {
struct CPUInfo {
private:
CPUInfo() {
// TODO(hjchen2)
num_cpus = 4;
for (int i = 0; i < num_cpus; ++i) {
cpu_frequency[i] = 2400; // 2400 MHz
max_cpu_frequency[i] = 2400; // 2400 MHz
}
// L1_cache = 32000; // 32K
L1_cache = 32 * 1024;
L2_cache = 2000000; // 2M
// L2_cache = 512000;
}
virtual ~CPUInfo() {}
public:
static CPUInfo* Info() {
static CPUInfo* ctx = new CPUInfo;
return ctx;
}
int num_cpus;
int cpu_frequency[MOBILE_MAX_CPU_NUM];
int max_cpu_frequency[MOBILE_MAX_CPU_NUM];
int L1_cache;
int L2_cache;
};
} // namespace math
} // namespace operators
} // namespace paddle_mobile
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <algorithm>
#ifdef _OPENMP
#include <omp.h>
#endif
#include <sys/time.h>
#include <iostream>
#include "common/log.h"
#include "memory/t_malloc.h"
#include "operators/math/gemm/cpu_info.h"
#include "operators/math/gemm/gemm_kernel.h"
namespace paddle_mobile {
namespace operators {
namespace math {
inline int CeilDiv(const int &x, const int &y) { return (x + y - 1) / y; }
class Executor {
public:
Executor() : num_threads_(1) {
#ifdef _OPENMP
num_threads_ = omp_get_max_threads();
#endif
}
virtual ~Executor() {}
protected:
int num_threads_;
};
template <typename Strategy>
class GemmExecutor : public Executor {
typedef typename Strategy::Itype Itype;
typedef typename Strategy::Otype Otype;
public:
GemmExecutor(const CPUInfo *info, const bool transA, const bool transB,
const int M, const int N, const int K)
: Executor(),
info_(info),
transA_(transA),
transB_(transB),
M_(M),
N_(N),
K_(K) {
unsigned int L1_size = info->L1_cache;
unsigned int L2_size = info->L2_cache;
// if (N_ > 10000) L1_size *= 2;
if (num_threads_ >= 2) L1_size /= 2;
rhs_tile_num_ = L1_size / (K * sizeof(Itype));
if (rhs_tile_num_ == 0) {
rhs_tile_num_ = Strategy::out_width();
} else {
int n_block = CeilDiv(N, rhs_tile_num_);
rhs_tile_num_ = CeilDiv(N, n_block);
rhs_tile_num_ = CeilDiv(rhs_tile_num_, Strategy::out_width());
rhs_tile_num_ *= Strategy::out_width();
}
// lhs_tile_num_ = CeilDiv(M, Strategy::out_height()) *
// Strategy::out_height();
lhs_tile_num_ = L2_size / (K * sizeof(Itype));
if (lhs_tile_num_ == 0) {
lhs_tile_num_ = Strategy::out_height();
} else {
int m_block = CeilDiv(M, lhs_tile_num_);
lhs_tile_num_ = CeilDiv(M, m_block);
lhs_tile_num_ = CeilDiv(lhs_tile_num_, Strategy::out_height());
lhs_tile_num_ *= Strategy::out_height();
}
}
void operator()(const float alpha, const Itype *A, const int lda,
const Itype *B, const int ldb, const float beta, Otype *C,
const int ldc) {
// struct timeval tv_begin, tv_end;
// gettimeofday(&tv_begin,NULL);
int mblock = CeilDiv(M_, Strategy::out_height()) * Strategy::out_height();
lhs_worksize_ = sizeof(Itype) * mblock * K_;
rhs_worksize_ = sizeof(Itype) * K_ * rhs_tile_num_ * num_threads_;
out_worksize_ = sizeof(Otype) * mblock * rhs_tile_num_ * num_threads_;
lhs_workspace_ =
static_cast<Itype *>(paddle_mobile::memory::Alloc(lhs_worksize_));
rhs_workspace_ =
static_cast<Itype *>(paddle_mobile::memory::Alloc(rhs_worksize_));
out_workspace_ =
static_cast<Otype *>(paddle_mobile::memory::Alloc(out_worksize_));
strategy_.pack_lhs(M_, K_, A, lda, lhs_workspace_, true);
// std::cout << "M: " << M_ << ", N: " << N_ << ", K: " << K_ <<
// std::endl; std::cout << "rhs_block: " << CeilDiv(N_, rhs_tile_num_) <<
// std::endl;
#pragma omp parallel for if (N_ > 128)
for (int rhs_block = 0; rhs_block < N_; rhs_block += rhs_tile_num_) {
int rhs_range = std::min(N_ - rhs_block, rhs_tile_num_);
#ifdef _OPENMP
int thread_id = omp_get_thread_num();
#else
int thread_id = 0;
#endif
float *local_B = rhs_workspace_ + K_ * rhs_tile_num_ * thread_id;
float *local_C =
out_workspace_ + lhs_tile_num_ * rhs_tile_num_ * thread_id;
// load rhs into rhs_workspace
strategy_.pack_rhs(K_, rhs_range, B + rhs_block, ldb, local_B, false);
for (int lhs_block = 0; lhs_block < M_; lhs_block += lhs_tile_num_) {
int lhs_range = std::min(M_ - lhs_block, lhs_tile_num_);
float *local_A = lhs_workspace_ + lhs_block * lda;
for (int lhs_tile = 0; lhs_tile < lhs_range;
lhs_tile += Strategy::out_height()) {
for (int rhs_tile = 0; rhs_tile < rhs_range;
rhs_tile += Strategy::out_width()) {
int offset = (lhs_block + lhs_tile) * rhs_tile_num_ + rhs_tile;
strategy_.kernel(local_A + lhs_tile * K_, local_B + rhs_tile * K_,
K_, local_C + offset, rhs_tile_num_);
}
}
}
strategy_.write(M_, rhs_range, local_C, rhs_tile_num_, C + rhs_block,
ldc);
}
paddle_mobile::memory::Free(lhs_workspace_);
paddle_mobile::memory::Free(rhs_workspace_);
paddle_mobile::memory::Free(out_workspace_);
// gettimeofday(&tv_end,NULL);
// float elapsed = (tv_end.tv_sec - tv_begin.tv_sec) * 1000.f +
// (tv_end.tv_usec - tv_begin.tv_usec) / 1000.f; std::cout << "elapsed: "
// << elapsed << "ms, speed: " << (M_ * N_ * K_ / 1000.f / 1000.f) /
// elapsed << " gflops" << std::endl;
}
virtual ~GemmExecutor() {}
private:
const CPUInfo *info_;
const unsigned int M_;
const unsigned int N_;
const unsigned int K_;
const bool transA_;
const bool transB_;
unsigned int lhs_tile_num_ = 0;
unsigned int rhs_tile_num_ = 0;
unsigned int out_tile_num_ = 0;
unsigned int lhs_worksize_ = 0;
unsigned int rhs_worksize_ = 0;
unsigned int out_worksize_ = 0;
Itype *lhs_workspace_ = nullptr;
Itype *rhs_workspace_ = nullptr;
Otype *out_workspace_ = nullptr;
Strategy strategy_;
};
template <typename Strategy>
class GemvExecutor : public Executor {
typedef typename Strategy::Itype Itype;
typedef typename Strategy::Otype Otype;
public:
GemvExecutor(const CPUInfo *info, const bool transA, const int M, const int N)
: Executor(), info_(info), M_(M), N_(N) {}
void operator()(const float alpha, const Itype *A, const int lda,
const Itype *B, const float beta, Otype *C) {
// strategy_.kernel();
}
virtual ~GemvExecutor() {}
private:
const CPUInfo *const info_;
const unsigned int M_;
const unsigned int N_;
Strategy strategy_;
};
} // namespace math
} // namespace operators
} // namespace paddle_mobile
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#ifdef __ARM_NEON__
#include <arm_neon.h>
namespace paddle_mobile {
namespace operators {
namespace math {
#ifdef __aarch64__
void sgemm_12x8(const float *lhs, const float *rhs, const int k, float *output,
const int ldc) {
// TODO(hjchen2)
}
#else
void sgemm_6x8(const float *lhs, const float *rhs, const int k, float *output,
const int ldc) {
int kc1 = k >> 3; // k / 8
int kc2 = k & 0x7; // k % 8
int step = sizeof(float) * ldc;
asm volatile(
"pld [%[lhs]] \n\t"
"pld [%[lhs], #64] \n\t"
"pld [%[rhs]] \n\t"
"pld [%[rhs], #64] \n\t"
"vmov.f32 q4, #0.0 \n\t"
"vmov.f32 q5, #0.0 \n\t"
"vmov.f32 q6, #0.0 \n\t"
"vmov.f32 q7, #0.0 \n\t"
"vmov.f32 q8, #0.0 \n\t"
"vmov.f32 q9, #0.0 \n\t"
"vmov.f32 q10, #0.0 \n\t"
"vmov.f32 q11, #0.0 \n\t"
"vmov.f32 q12, #0.0 \n\t"
"vmov.f32 q13, #0.0 \n\t"
"vmov.f32 q14, #0.0 \n\t"
"vmov.f32 q15, #0.0 \n\t"
"subs %[kc1], %[kc1], #1 \n\t"
"blt 2f \n\t"
"1: \n\t"
"pld [%[lhs], #128] \n\t"
"pld [%[rhs], #128] \n\t"
"vld1.32 {d0-d2}, [%[lhs]]! \n\t"
"vld1.32 {q2, q3}, [%[rhs]]! \n\t"
"vmla.f32 q4, q2, d0[0] \n\t"
"vmla.f32 q5, q3, d0[0] \n\t"
"vmla.f32 q6, q2, d0[1] \n\t"
"vmla.f32 q7, q3, d0[1] \n\t"
"vmla.f32 q8, q2, d1[0] \n\t"
"vmla.f32 q9, q3, d1[0] \n\t"
"vmla.f32 q10, q2, d1[1] \n\t"
"vmla.f32 q11, q3, d1[1] \n\t"
"vmla.f32 q12, q2, d2[0] \n\t"
"vmla.f32 q13, q3, d2[0] \n\t"
"vmla.f32 q14, q2, d2[1] \n\t"
"vmla.f32 q15, q3, d2[1] \n\t"
"vld1.32 {d0-d2}, [%[lhs]]! \n\t"
"vld1.32 {q2, q3}, [%[rhs]]! \n\t"
"vmla.f32 q4, q2, d0[0] \n\t"
"vmla.f32 q5, q3, d0[0] \n\t"
"vmla.f32 q6, q2, d0[1] \n\t"
"vmla.f32 q7, q3, d0[1] \n\t"
"vmla.f32 q8, q2, d1[0] \n\t"
"vmla.f32 q9, q3, d1[0] \n\t"
"vmla.f32 q10, q2, d1[1] \n\t"
"vmla.f32 q11, q3, d1[1] \n\t"
"vmla.f32 q12, q2, d2[0] \n\t"
"vmla.f32 q13, q3, d2[0] \n\t"
"vmla.f32 q14, q2, d2[1] \n\t"
"vmla.f32 q15, q3, d2[1] \n\t"
"pld [%[lhs], #128] \n\t"
"pld [%[rhs], #128] \n\t"
"vld1.32 {d0-d2}, [%[lhs]]! \n\t"
"vld1.32 {q2, q3}, [%[rhs]]! \n\t"
"vmla.f32 q4, q2, d0[0] \n\t"
"vmla.f32 q5, q3, d0[0] \n\t"
"vmla.f32 q6, q2, d0[1] \n\t"
"vmla.f32 q7, q3, d0[1] \n\t"
"vmla.f32 q8, q2, d1[0] \n\t"
"vmla.f32 q9, q3, d1[0] \n\t"
"vmla.f32 q10, q2, d1[1] \n\t"
"vmla.f32 q11, q3, d1[1] \n\t"
"vmla.f32 q12, q2, d2[0] \n\t"
"vmla.f32 q13, q3, d2[0] \n\t"
"vmla.f32 q14, q2, d2[1] \n\t"
"vmla.f32 q15, q3, d2[1] \n\t"
"vld1.32 {d0-d2}, [%[lhs]]! \n\t"
"vld1.32 {q2, q3}, [%[rhs]]! \n\t"
"vmla.f32 q4, q2, d0[0] \n\t"
"vmla.f32 q5, q3, d0[0] \n\t"
"vmla.f32 q6, q2, d0[1] \n\t"
"vmla.f32 q7, q3, d0[1] \n\t"
"vmla.f32 q8, q2, d1[0] \n\t"
"vmla.f32 q9, q3, d1[0] \n\t"
"vmla.f32 q10, q2, d1[1] \n\t"
"vmla.f32 q11, q3, d1[1] \n\t"
"vmla.f32 q12, q2, d2[0] \n\t"
"vmla.f32 q13, q3, d2[0] \n\t"
"vmla.f32 q14, q2, d2[1] \n\t"
"vmla.f32 q15, q3, d2[1] \n\t"
"pld [%[lhs], #128] \n\t"
"pld [%[rhs], #128] \n\t"
"vld1.32 {d0-d2}, [%[lhs]]! \n\t"
"vld1.32 {q2, q3}, [%[rhs]]! \n\t"
"vmla.f32 q4, q2, d0[0] \n\t"
"vmla.f32 q5, q3, d0[0] \n\t"
"vmla.f32 q6, q2, d0[1] \n\t"
"vmla.f32 q7, q3, d0[1] \n\t"
"vmla.f32 q8, q2, d1[0] \n\t"
"vmla.f32 q9, q3, d1[0] \n\t"
"vmla.f32 q10, q2, d1[1] \n\t"
"vmla.f32 q11, q3, d1[1] \n\t"
"vmla.f32 q12, q2, d2[0] \n\t"
"vmla.f32 q13, q3, d2[0] \n\t"
"vmla.f32 q14, q2, d2[1] \n\t"
"vmla.f32 q15, q3, d2[1] \n\t"
"vld1.32 {d0-d2}, [%[lhs]]! \n\t"
"vld1.32 {q2, q3}, [%[rhs]]! \n\t"
"vmla.f32 q4, q2, d0[0] \n\t"
"vmla.f32 q5, q3, d0[0] \n\t"
"vmla.f32 q6, q2, d0[1] \n\t"
"vmla.f32 q7, q3, d0[1] \n\t"
"vmla.f32 q8, q2, d1[0] \n\t"
"vmla.f32 q9, q3, d1[0] \n\t"
"vmla.f32 q10, q2, d1[1] \n\t"
"vmla.f32 q11, q3, d1[1] \n\t"
"vmla.f32 q12, q2, d2[0] \n\t"
"vmla.f32 q13, q3, d2[0] \n\t"
"vmla.f32 q14, q2, d2[1] \n\t"
"vmla.f32 q15, q3, d2[1] \n\t"
"pld [%[lhs], #128] \n\t"
"pld [%[rhs], #128] \n\t"
"vld1.32 {d0-d2}, [%[lhs]]! \n\t"
"vld1.32 {q2, q3}, [%[rhs]]! \n\t"
"vmla.f32 q4, q2, d0[0] \n\t"
"vmla.f32 q5, q3, d0[0] \n\t"
"vmla.f32 q6, q2, d0[1] \n\t"
"vmla.f32 q7, q3, d0[1] \n\t"
"vmla.f32 q8, q2, d1[0] \n\t"
"vmla.f32 q9, q3, d1[0] \n\t"
"vmla.f32 q10, q2, d1[1] \n\t"
"vmla.f32 q11, q3, d1[1] \n\t"
"vmla.f32 q12, q2, d2[0] \n\t"
"vmla.f32 q13, q3, d2[0] \n\t"
"vmla.f32 q14, q2, d2[1] \n\t"
"vmla.f32 q15, q3, d2[1] \n\t"
"vld1.32 {d0-d2}, [%[lhs]]! \n\t"
"vld1.32 {q2, q3}, [%[rhs]]! \n\t"
"vmla.f32 q4, q2, d0[0] \n\t"
"vmla.f32 q5, q3, d0[0] \n\t"
"vmla.f32 q6, q2, d0[1] \n\t"
"vmla.f32 q7, q3, d0[1] \n\t"
"vmla.f32 q8, q2, d1[0] \n\t"
"vmla.f32 q9, q3, d1[0] \n\t"
"vmla.f32 q10, q2, d1[1] \n\t"
"vmla.f32 q11, q3, d1[1] \n\t"
"vmla.f32 q12, q2, d2[0] \n\t"
"vmla.f32 q13, q3, d2[0] \n\t"
"vmla.f32 q14, q2, d2[1] \n\t"
"vmla.f32 q15, q3, d2[1] \n\t"
"subs %[kc1], %[kc1], #1 \n\t"
"bge 1b \n\t"
"2: \n\t"
"subs %[kc2], %[kc2], #1 \n\t"
"blt 4f \n\t"
"3: \n\t"
"vld1.32 {d0-d2}, [%[lhs]]! \n\t"
"vld1.32 {q2, q3}, [%[rhs]]! \n\t"
"vmla.f32 q4, q2, d0[0] \n\t"
"vmla.f32 q5, q3, d0[0] \n\t"
"vmla.f32 q6, q2, d0[1] \n\t"
"vmla.f32 q7, q3, d0[1] \n\t"
"vmla.f32 q8, q2, d1[0] \n\t"
"vmla.f32 q9, q3, d1[0] \n\t"
"vmla.f32 q10, q2, d1[1] \n\t"
"vmla.f32 q11, q3, d1[1] \n\t"
"vmla.f32 q12, q2, d2[0] \n\t"
"vmla.f32 q13, q3, d2[0] \n\t"
"vmla.f32 q14, q2, d2[1] \n\t"
"vmla.f32 q15, q3, d2[1] \n\t"
"subs %[kc2], %[kc2], #1 \n\t"
"bge 3b \n\t"
"4: \n\t"
"mov r5, %[c] \n\t"
"mov r6, %[step] \n\t"
"vst1.32 {q4, q5}, [r5], r6 \n\t"
"vst1.32 {q6, q7}, [r5], r6 \n\t"
"vst1.32 {q8, q9}, [r5], r6 \n\t"
"vst1.32 {q10, q11}, [r5], r6 \n\t"
"vst1.32 {q12, q13}, [r5], r6 \n\t"
"vst1.32 {q14, q15}, [r5] \n\t"
:
: [lhs] "r"(lhs), [rhs] "r"(rhs), [c] "r"(output), [kc1] "r"(kc1),
[kc2] "r"(kc2), [step] "r"(step)
: "cc", "memory", "r5", "r6", "q0", "q1", "q2", "q3", "q4", "q5", "q6",
"q7", "q8", "q9", "q10", "q11", "q12", "q13", "q14", "q15");
}
#endif // __aarch64__
} // namespace math
} // namespace operators
} // namespace paddle_mobile
#endif // __ARM_NEON__
此差异已折叠。
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "operators/math/gemm/gemm_kernel.h"
#include "operators/math/gemm/pack_kernel.h"
namespace paddle_mobile {
namespace operators {
namespace math {
struct SgemmStrategy {
typedef float Itype;
typedef float Otype;
typedef void (*packLhsFunc)(const int, const int, const Itype *, const int,
Itype *, const bool);
typedef void (*packRhsFunc)(const int, const int, const Itype *, const int,
Itype *, const bool);
typedef void (*kernelFunc)(const Itype *, const Itype *, const int, Otype *,
const int);
typedef void (*WriteFunc)(const int, const int, const Otype *, const int,
Otype *, const int);
packLhsFunc pack_lhs;
packRhsFunc pack_rhs;
kernelFunc kernel;
WriteFunc write;
static int out_width() { return 8; }
static int out_height() {
#ifdef __aarch64__
return 12;
#else
return 6;
#endif
}
SgemmStrategy() {
#ifdef __aarch64__
pack_lhs = pack_lhs_12r;
pack_rhs = pack_rhs_8c;
kernel = sgemm_12x8;
#else
pack_lhs = pack_lhs_6r;
pack_rhs = pack_rhs_8c;
kernel = sgemm_6x8;
#endif
write = write_back;
}
};
struct I8o32gemmStrategy {
typedef int8_t Itype;
typedef int32_t Otype;
typedef void (*kern_type)(const Itype *, const Itype *, const int, Otype *,
const int);
kern_type kernel;
static int out_width() { return 8; }
static int out_height() {
#ifdef __aarch64__
return 12;
#else
return 6;
#endif
}
I8o32gemmStrategy() {}
};
struct SgemvStrategy {
typedef float Itype;
typedef float Otype;
typedef void (*kern_type)(const Itype *, const Itype *, const int, Otype *,
const int);
kern_type kernel;
static int out_width() { return 1; }
static int out_height() {
#ifdef __aarch64__
return 12;
#else
return 6;
#endif
}
};
struct I8o32gemvStrategy {
typedef int8_t Itype;
typedef int32_t Otype;
typedef void (*kern_type)(const Itype *, const Itype *, const int, Otype *,
const int);
kern_type kernel;
static int out_width() { return 1; }
static int out_height() {
#ifdef __aarch64__
return 12;
#else
return 6;
#endif
}
};
} // namespace math
} // namespace operators
} // namespace paddle_mobile
......@@ -18,6 +18,7 @@ limitations under the License. */
#include "framework/data_type.h"
#include "framework/tensor.h"
#include "operators/math/gemm.h"
#include "operators/math/gemm/cblas.h"
namespace paddle_mobile {
namespace operators {
......@@ -55,6 +56,7 @@ void MatMul<float, float>(const framework::Tensor &matrix_a, bool trans_a,
int M = dim_out[0];
int N = dim_out[1];
int K = (!trans_a) ? dim_a[1] : dim_a[0];
Gemm gemm;
if (trans_a) {
framework::Tensor matrix_trans;
......@@ -69,24 +71,34 @@ void MatMul<float, float>(const framework::Tensor &matrix_a, bool trans_a,
a[index++] = tmp[i * n + j];
}
}
if (M > N || M == 1) {
#ifdef _OPENMP
gemm.Sgemm_omp(M, N, K, alpha, a, K, matrix_b.data<float>(), N, beta,
matrix_out->data<float>(), N, relu, bias);
gemm.Sgemm_omp(M, N, K, alpha, a, K, matrix_b.data<float>(), N, beta,
matrix_out->data<float>(), N, relu, bias);
#else
gemm.Sgemm(M, N, K, alpha, a, K, matrix_b.data<float>(), N, beta,
matrix_out->data<float>(), N, relu, bias);
gemm.Sgemm(M, N, K, alpha, a, K, matrix_b.data<float>(), N, beta,
matrix_out->data<float>(), N, relu, bias);
#endif
} else {
cblas_sgemm(false, false, M, N, K, alpha, a, K, matrix_b.data<float>(), N,
beta, matrix_out->data<float>(), N);
}
} else {
if (M > N || M == 1) {
#ifdef _OPENMP
gemm.Sgemm_omp(M, N, K, alpha, matrix_a.data<float>(), K,
matrix_b.data<float>(), N, beta, matrix_out->data<float>(),
N, relu, bias);
gemm.Sgemm_omp(M, N, K, alpha, matrix_a.data<float>(), K,
matrix_b.data<float>(), N, beta, matrix_out->data<float>(),
N, relu, bias);
#else
gemm.Sgemm(M, N, K, alpha, matrix_a.data<float>(), K,
matrix_b.data<float>(), N, beta, matrix_out->data<float>(), N,
relu, bias);
gemm.Sgemm(M, N, K, alpha, matrix_a.data<float>(), K,
matrix_b.data<float>(), N, beta, matrix_out->data<float>(), N,
relu, bias);
#endif
} else {
cblas_sgemm(false, false, M, N, K, alpha, matrix_a.data<float>(), K,
matrix_b.data<float>(), N, beta, matrix_out->data<float>(),
N);
}
}
}
......
......@@ -52,9 +52,7 @@ void winograd_transform_weight<8, 3>(const framework::Tensor &weight,
const float transform_matrix[8] = {2.f, -2.f / 9, 1.f / 90, 1.f / 180};
const float *inptr = weight.data<float>();
int remain_start = out_channel & 0xFFFC;
#if 0
remain_start = 0;
#else
#pragma omp parallel for
for (int oc = 0; oc < out_channel - 3; oc += 4) {
float gw[96]; // gw[3][8][4]
......@@ -258,7 +256,6 @@ void winograd_transform_weight<8, 3>(const framework::Tensor &weight,
"q13", "r0");
}
}
#endif
// remain output channel
#pragma omp parallel for
......@@ -350,311 +347,8 @@ void winograd_transform_input<8, 3>(const framework::Tensor &input,
size_t image_size = height * width;
const float transform_matrix[8] = {5.25f, -5.f, -4.25f, -2.5f,
2.f, -1.25f, 0.5f, 0.25f};
int remain_c_start = channel & 0xFFFC;
#if 1
remain_c_start = 0;
#else
#pragma omp parallel for
for (int c = 0; c < channel - 3; c += 4) {
const float *in = inptr + c * image_size;
float d_bt[64 * 4]; // d * B_t
for (int h = 0; h < h_tiles; ++h) {
for (int w = 0; w < w_tiles; ++w) {
const float *in0 = in + (h * width + w) * 6;
const float *in1 = in0 + image_size;
const float *in2 = in1 + image_size;
const float *in3 = in2 + image_size;
int steps = width * sizeof(float);
float *d_bt_ptr = d_bt;
asm volatile(
"mov r0, #8 \n"
"vld1.32 {d0-d3}, [%[tm_ptr]] \n"
// row loop
"loop_r_%=: \n"
"vld1.32 {d4-d7}, [%[in0]], %[steps] \n"
"vld1.32 {d8-d11}, [%[in1]], %[steps] \n"
"vld1.32 {d12-d15}, [%[in2]], %[steps] \n"
"vld1.32 {d16-d19}, [%[in3]], %[steps] \n"
"vtrn.32 q2, q4 \n" // d0: q2
"vtrn.32 q3, q5 \n" // d1: q4
"vtrn.32 q6, q8 \n" // d2: q6
"vtrn.32 q7, q9 \n" // d3: q8
"vswp.32 d5, d12 \n" // d4: q3
"vswp.32 d9, d16 \n" // d5: q5
"vswp.32 d7, d14 \n" // d6: q7
"vswp.32 d11, d18 \n" // d7: q9
"vsub.f32 q10, q2, q7 \n"
"vsub.f32 q11, q3, q6 \n"
"vmla.f32 q10, q11, d0[0] \n" // d0 - d6 + (d4 -
// d2) * 5.25
"vst1.32 {d20-d21}, [%[d_bt]]! \n"
"vadd.f32 q10, q6, q7 \n"
"vadd.f32 q11, q4, q5 \n"
"vmla.f32 q10, q3, d1[0] \n" // d2 - 4.25 * d4 +
// d6
"vmla.f32 q11, q8, d1[0] \n" // d1 - 4.25 * d3 +
// d5
"vadd.f32 q12, q10, q11 \n"
"vsub.f32 q13, q10, q11 \n"
"vst1.32 {d24-d27}, [%[d_bt]]! \n"
"vmul.f32 q10, q6, d3[1] \n" // 0.25 * d2
"vmul.f32 q11, q4, d3[0] \n" // 0.5 * d1
"vadd.f32 q10, q10, q7 \n" // 0.25 * d2 + d6
"vmla.f32 q11, q5, d2[0] \n" // 0.5 * d1 + 2 *
// d5
"vmla.f32 q10, q3, d2[1] \n" // 0.25 * d2 + d6
// - 1.25 * d4
"vmla.f32 q11, q8, d1[1] \n" // 0.5 * d1 + 2 *
// d5 - 2.5 * d3
"vadd.f32 q12, q10, q11 \n"
"vsub.f32 q13, q10, q11 \n"
"vst1.32 {d24-d27}, [%[d_bt]]! \n"
"vmul.f32 q10, q6, d2[0] \n" // 2 * d2
"vmul.f32 q11, q4, d2[0] \n" // 2 * d1
"vmla.f32 q10, q3, d1[1] \n" // 2 * d2 - 2.5 *
// d4
"vmla.f32 q11, q8, d1[1] \n" // 2 * d1 - 2.5 *
// d3
"vmla.f32 q10, q7, d3[0] \n" // 2 * d1 - 2.5 *
// d3 + 0.5 * d6
"vmla.f32 q11, q5, d3[0] \n" // 2 * d2 - 2.5 *
// d4 + 0.5 * d5
"vmul.f32 q10, q10, d2[0] \n" // 4 * d1 - 5 * d3
// + d6
"vadd.f32 q12, q10, q11 \n"
"vsub.f32 q13, q10, q11 \n"
"vst1.32 {d24-d27}, [%[d_bt]]! \n"
"vsub.f32 q10, q9, q4 \n"
"vsub.f32 q11, q8, q5 \n"
"vmla.f32 q10, q11, d0[0] \n"
"vst1.32 {d20-d21}, [%[d_bt]]! \n"
"subs r0, #1 \n"
"bne loop_r_%= \n"
: [d_bt] "+r"(d_bt_ptr), [in0] "+r"(in0), [in1] "+r"(in1),
[in2] "+r"(in2), [in3] "+r"(in3)
: [tm_ptr] "r"((float *)transform_matrix), [steps] "r"(steps)
: "cc", "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7",
"q8", "q9", "q10", "q11", "q12", "q13", "r0");
float *ptr0 = d_bt;
float *ptr1 = ptr0 + 32;
float *ptr2 = ptr1 + 32;
float *ptr3 = ptr2 + 32;
float *ptr4 = ptr3 + 32;
float *ptr5 = ptr4 + 32;
float *ptr6 = ptr5 + 32;
float *ptr7 = ptr6 + 32;
int tile_indics = h * w_tiles + w;
int tile_block = tile_indics >> 3;
int block_indics = tile_indics & 0x7;
// (tiles / 8, 64, channel, 8)
float *out0 =
outptr + (tile_block * 64 * channel + c) * 8 + block_indics;
steps = (channel - 3) * 8 * sizeof(float);
asm volatile(
"vld1.32 {d0-d3}, [%[tm_ptr]] \n"
"mov r0, 4 \n"
"mov r1, 32 \n"
"loop_col_%=: \n"
// col 0:
"vld1.32 {d4-d5}, [%[ptr0]]! \n" // q2: d0
"vld1.32 {d6-d7}, [%[ptr1]]! \n" // q3: d1
"vld1.32 {d8-d9}, [%[ptr2]]! \n" // q4: d2
"vld1.32 {d10-d11}, [%[ptr3]]! \n" // q5: d3
"vld1.32 {d12-d13}, [%[ptr4]]! \n" // q6: d4
"vld1.32 {d14-d15}, [%[ptr5]]! \n" // q7: d5
"vld1.32 {d16-d17}, [%[ptr6]]! \n" // q8: d6
"vld1.32 {d18-d19}, [%[ptr7]]! \n" // q9: d7
"vsub.f32 q10, q2, q8 \n" // d0 - d6
"vsub.f32 q11, q6, q4 \n" // d4 - d2
"vmla.f32 q10, q11, d0[0] \n" // d0 - d6 + (d4 -
// d2) * 5.25
"vst1.32 {d20[0]}, [%[out0]], r1 \n"
"vst1.32 {d20[1]}, [%[out0]], r1 \n"
"vst1.32 {d21[0]}, [%[out0]], r1 \n"
"vst1.32 {d21[1]}, [%[out0]], %[steps] \n"
"vadd.f32 q10, q4, q8 \n"
"vadd.f32 q11, q3, q7 \n"
"vmla.f32 q10, q6, d1[0] \n" // d2 - 4.25 * d4 +
// d6
"vmla.f32 q11, q5, d1[0] \n" // d1 - 4.25 * d3 +
// d5
"vadd.f32 q12, q10, q11 \n"
"vst1.32 {d24[0]}, [%[out0]], r1 \n"
"vst1.32 {d24[1]}, [%[out0]], r1 \n"
"vst1.32 {d25[0]}, [%[out0]], r1 \n"
"vst1.32 {d25[1]}, [%[out0]], %[steps] \n"
"vsub.f32 q12, q10, q11 \n"
"vst1.32 {d24[0]}, [%[out0]], r1 \n"
"vst1.32 {d24[1]}, [%[out0]], r1 \n"
"vst1.32 {d25[0]}, [%[out0]], r1 \n"
"vst1.32 {d25[1]}, [%[out0]], %[steps] \n"
"vmul.f32 q10, q4, d3[1] \n" // 0.25 * d2
"vmul.f32 q11, q3, d3[0] \n" // 0.5 * d1
"vadd.f32 q10, q10, q8 \n" // 0.25 * d2 + d6
"vmla.f32 q11, q7, d2[0] \n" // 0.5 * d1 + 2 *
// d5
"vmla.f32 q10, q6, d2[1] \n" // 0.25 * d2 + d6
// - 1.25 * d4
"vmla.f32 q11, q5, d1[1] \n" // 0.5 * d1 + 2 *
// d5 - 2.5 * d3
"vadd.f32 q12, q10, q11 \n"
"vst1.32 {d24[0]}, [%[out0]], r1 \n"
"vst1.32 {d24[1]}, [%[out0]], r1 \n"
"vst1.32 {d25[0]}, [%[out0]], r1 \n"
"vst1.32 {d25[1]}, [%[out0]], %[steps] \n"
"vsub.f32 q12, q10, q11 \n"
"vst1.32 {d24[0]}, [%[out0]], r1 \n"
"vst1.32 {d24[1]}, [%[out0]], r1 \n"
"vst1.32 {d25[0]}, [%[out0]], r1 \n"
"vst1.32 {d25[1]}, [%[out0]], %[steps] \n"
"vmul.f32 q10, q4, d2[0] \n" // 2 * d2
"vmul.f32 q11, q3, d2[0] \n" // 2 * d1
"vmla.f32 q10, q6, d1[1] \n" // 2 * d2 - 2.5 *
// d4
"vmla.f32 q11, q5, d1[1] \n" // 2 * d1 - 2.5 *
// d3
"vmla.f32 q10, q8, d3[0] \n" // 2 * d1 - 2.5 *
// d3 + 0.5 * d6
"vmla.f32 q11, q7, d3[0] \n" // 2 * d2 - 2.5 *
// d4 + 0.5 * d5
"vmul.f32 q10, q10, d2[0] \n" // 4 * d1 - 5 * d3
// + d6
"vadd.f32 q12, q10, q11 \n"
"vst1.32 {d24[0]}, [%[out0]], r1 \n"
"vst1.32 {d24[1]}, [%[out0]], r1 \n"
"vst1.32 {d25[0]}, [%[out0]], r1 \n"
"vst1.32 {d25[1]}, [%[out0]], %[steps] \n"
"vsub.f32 q12, q10, q11 \n"
"vst1.32 {d24[0]}, [%[out0]], r1 \n"
"vst1.32 {d24[1]}, [%[out0]], r1 \n"
"vst1.32 {d25[0]}, [%[out0]], r1 \n"
"vst1.32 {d25[1]}, [%[out0]], %[steps] \n"
"vsub.f32 q10, q9, q3 \n"
"vsub.f32 q11, q5, q7 \n"
"vmla.f32 q10, q11, d0[0] \n"
"vst1.32 {d20[0]}, [%[out0]], r1 \n"
"vst1.32 {d20[1]}, [%[out0]], r1 \n"
"vst1.32 {d21[0]}, [%[out0]], r1 \n"
"vst1.32 {d21[1]}, [%[out0]], %[steps] \n"
// col 1:
"vld1.32 {d4-d5}, [%[ptr0]]! \n" // q2: d0
"vld1.32 {d6-d7}, [%[ptr1]]! \n" // q3: d1
"vld1.32 {d8-d9}, [%[ptr2]]! \n" // q4: d2
"vld1.32 {d10-d11}, [%[ptr3]]! \n" // q5: d3
"vld1.32 {d12-d13}, [%[ptr4]]! \n" // q6: d4
"vld1.32 {d14-d15}, [%[ptr5]]! \n" // q7: d5
"vld1.32 {d16-d17}, [%[ptr6]]! \n" // q8: d6
"vld1.32 {d18-d19}, [%[ptr7]]! \n" // q9: d7
"vsub.f32 q10, q2, q8 \n" // d0 - d6
"vsub.f32 q11, q6, q4 \n" // d4 - d2
"vmla.f32 q10, q11, d0[0] \n" // d0 - d6 + (d4 -
// d2) * 5.25
"vst1.32 {d20[0]}, [%[out0]], r1 \n"
"vst1.32 {d20[1]}, [%[out0]], r1 \n"
"vst1.32 {d21[0]}, [%[out0]], r1 \n"
"vst1.32 {d21[1]}, [%[out0]], %[steps] \n"
"vadd.f32 q10, q4, q8 \n"
"vadd.f32 q11, q3, q7 \n"
"vmla.f32 q10, q6, d1[0] \n" // d2 - 4.25 * d4 +
// d6
"vmla.f32 q11, q5, d1[0] \n" // d1 - 4.25 * d3 +
// d5
"vadd.f32 q12, q10, q11 \n"
"vst1.32 {d24[0]}, [%[out0]], r1 \n"
"vst1.32 {d24[1]}, [%[out0]], r1 \n"
"vst1.32 {d25[0]}, [%[out0]], r1 \n"
"vst1.32 {d25[1]}, [%[out0]], %[steps] \n"
"vsub.f32 q12, q10, q11 \n"
"vst1.32 {d24[0]}, [%[out0]], r1 \n"
"vst1.32 {d24[1]}, [%[out0]], r1 \n"
"vst1.32 {d25[0]}, [%[out0]], r1 \n"
"vst1.32 {d25[1]}, [%[out0]], %[steps] \n"
"vmul.f32 q10, q4, d3[1] \n" // 0.25 * d2
"vmul.f32 q11, q3, d3[0] \n" // 0.5 * d1
"vadd.f32 q10, q10, q8 \n" // 0.25 * d2 + d6
"vmla.f32 q11, q7, d2[0] \n" // 0.5 * d1 + 2 *
// d5
"vmla.f32 q10, q6, d2[1] \n" // 0.25 * d2 + d6
// - 1.25 * d4
"vmla.f32 q11, q5, d1[1] \n" // 0.5 * d1 + 2 *
// d5 - 2.5 * d3
"vadd.f32 q12, q10, q11 \n"
"vst1.32 {d24[0]}, [%[out0]], r1 \n"
"vst1.32 {d24[1]}, [%[out0]], r1 \n"
"vst1.32 {d25[0]}, [%[out0]], r1 \n"
"vst1.32 {d25[1]}, [%[out0]], %[steps] \n"
"vsub.f32 q12, q10, q11 \n"
"vst1.32 {d24[0]}, [%[out0]], r1 \n"
"vst1.32 {d24[1]}, [%[out0]], r1 \n"
"vst1.32 {d25[0]}, [%[out0]], r1 \n"
"vst1.32 {d25[1]}, [%[out0]], %[steps] \n"
"vmul.f32 q10, q4, d2[0] \n" // 2 * d2
"vmul.f32 q11, q3, d2[0] \n" // 2 * d1
"vmla.f32 q10, q6, d1[1] \n" // 2 * d2 - 2.5 *
// d4
"vmla.f32 q11, q5, d1[1] \n" // 2 * d1 - 2.5 *
// d3
"vmla.f32 q10, q8, d3[0] \n" // 2 * d1 - 2.5 *
// d3 + 0.5 * d6
"vmla.f32 q11, q7, d3[0] \n" // 2 * d2 - 2.5 *
// d4 + 0.5 * d5
"vmul.f32 q10, q10, d2[0] \n" // 4 * d1 - 5 * d3
// + d6
"vadd.f32 q12, q10, q11 \n"
"vst1.32 {d24[0]}, [%[out0]], r1 \n"
"vst1.32 {d24[1]}, [%[out0]], r1 \n"
"vst1.32 {d25[0]}, [%[out0]], r1 \n"
"vst1.32 {d25[1]}, [%[out0]], %[steps] \n"
"vsub.f32 q12, q10, q11 \n"
"vst1.32 {d24[0]}, [%[out0]], r1 \n"
"vst1.32 {d24[1]}, [%[out0]], r1 \n"
"vst1.32 {d25[0]}, [%[out0]], r1 \n"
"vst1.32 {d25[1]}, [%[out0]], %[steps] \n"
"vsub.f32 q10, q9, q3 \n"
"vsub.f32 q11, q5, q7 \n"
"vmla.f32 q10, q11, d0[0] \n"
"vst1.32 {d20[0]}, [%[out0]], r1 \n"
"vst1.32 {d20[1]}, [%[out0]], r1 \n"
"vst1.32 {d21[0]}, [%[out0]], r1 \n"
"vst1.32 {d21[1]}, [%[out0]], %[steps] \n"
"subs r0, #1 \n"
"bne loop_col_%= \n"
: [out0] "+r"(out0), [ptr0] "+r"(ptr0), [ptr1] "+r"(ptr1),
[ptr2] "+r"(ptr2), [ptr3] "+r"(ptr3), [ptr4] "+r"(ptr4),
[ptr5] "+r"(ptr5), [ptr6] "+r"(ptr6), [ptr7] "+r"(ptr7)
: [tm_ptr] "r"((float *)transform_matrix), [steps] "r"(steps)
: "cc", "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7",
"q8", "q9", "q10", "q11", "q12", "q13", "r0", "r1");
}
}
}
#endif
// remainer channels
#pragma omp parallel for
for (int c = remain_c_start; c < channel; ++c) {
for (int c = 0; c < channel; ++c) {
const float *in = inptr + c * image_size;
float d_bt[64]; // d * B_t
for (int h = 0; h < h_tiles; ++h) {
......
......@@ -1753,18 +1753,15 @@ class FusionConvAddParam : public ConvParam<Dtype> {
: ConvParam<Dtype>(inputs, outputs, attrs, scope) {
bias_ = OpParam::InputYFrom<GType>(inputs, scope);
axis_ = OpParam::GetAttr<int>("axis", attrs);
output_ = OpParam::OutFrom<GType>(outputs, scope);
this->output_ = OpParam::OutFrom<GType>(outputs, scope);
}
GType *Bias() const { return bias_; }
const int &Axis() const { return axis_; }
GType *Output() const { return output_; }
protected:
GType *bias_;
int axis_;
GType *output_;
};
template <typename Dtype>
......@@ -1797,18 +1794,16 @@ class FusionConvAddPReluParam : public ConvParam<Dtype> {
framework::DDim dims = alpha_->dims();
bias_ = OpParam::InputYFrom<GType>(inputs, scope);
axis_ = OpParam::GetAttr<int>("axis", attrs);
output_ = OpParam::OutFrom<GType>(outputs, scope);
this->output_ = OpParam::OutFrom<GType>(outputs, scope);
}
const GType *InputAlpha() const { return alpha_; }
const std::string &Mode() const { return mode_; }
GType *Bias() const { return bias_; }
const int &Axis() const { return axis_; }
GType *Output() const { return output_; }
protected:
GType *bias_;
int axis_;
GType *output_;
GType *alpha_;
std::string mode_;
};
......@@ -1830,7 +1825,6 @@ class FusionConvAddAddPReluParam : public ConvParam<Dtype> {
mode_ = OpParam::GetStringAttr("mode", attrs);
framework::DDim dims = alpha_->dims();
bias_ = OpParam::InputYFrom<GType>(inputs, scope);
output_ = OpParam::OutFrom<GType>(outputs, scope);
axis_ = OpParam::GetAttr<int>("axis", attrs);
keyOutput_ = OpParam::getkey("addOut", inputs, 0);
keyX1_ = OpParam::getkey("addX", inputs, 1);
......@@ -1840,6 +1834,7 @@ class FusionConvAddAddPReluParam : public ConvParam<Dtype> {
} else if (keyY1_ == keyOutput_) {
bias1_ = OpParam::InputXFrom1<GType>(inputs, scope);
}
this->output_ = OpParam::OutFrom<GType>(outputs, scope);
}
const GType *InputAlpha() const { return alpha_; }
const std::string &Mode() const { return mode_; }
......@@ -1848,12 +1843,10 @@ class FusionConvAddAddPReluParam : public ConvParam<Dtype> {
GType *Bias() const { return bias_; }
const int &Axis() const { return axis_; }
GType *Output() const { return output_; }
protected:
GType *bias_;
int axis_;
GType *output_;
GType *alpha_;
std::string mode_;
GType *bias1_;
......@@ -1876,21 +1869,18 @@ class FusionConvAddBNReluParam : public ConvParam<Dtype> {
: ConvParam<Dtype>(inputs, outputs, attrs, scope) {
bias_ = OpParam::InputYFrom<GType>(inputs, scope);
axis_ = OpParam::GetAttr<int>("axis", attrs);
output_ = OpParam::OutFrom<GType>(outputs, scope);
input_bias_ = OpParam::InputBiasFrom<GType>(inputs, scope);
input_mean_ = OpParam::InputMeanFrom<GType>(inputs, scope);
input_scale_ = OpParam::InputScaleFrom<GType>(inputs, scope);
input_variance_ = OpParam::InputVarianceFrom<GType>(inputs, scope);
epsilon_ = OpParam::GetAttr<float>("epsilon", attrs);
momentum_ = OpParam::GetAttr<float>("momentum", attrs);
// is_test_ = OpParam::GetAttr<bool>("is_test", attrs);
this->output_ = OpParam::OutFrom<GType>(outputs, scope);
}
GType *Bias() const { return bias_; }
const int &Axis() const { return axis_; }
GType *Output() const { return output_; }
const GType *InputBias() const { return input_bias_; }
const GType *InputMean() const { return input_mean_; }
......@@ -1903,8 +1893,6 @@ class FusionConvAddBNReluParam : public ConvParam<Dtype> {
const float &Momentum() const { return momentum_; }
const bool &IsTest() const { return is_test_; }
void SetNewScale(GType *new_scale) { new_scale_ = new_scale; }
void SetNewBias(GType *new_bias) { new_bias_ = new_bias; }
......@@ -1916,14 +1904,12 @@ class FusionConvAddBNReluParam : public ConvParam<Dtype> {
protected:
GType *bias_;
int axis_;
GType *output_;
GType *input_bias_;
GType *input_mean_;
GType *input_scale_;
GType *input_variance_;
float epsilon_;
float momentum_;
bool is_test_;
GType *new_bias_;
GType *new_scale_;
};
......@@ -1942,7 +1928,6 @@ class FusionConvBNAddReluParam : public ConvParam<Dtype> {
: ConvParam<Dtype>(inputs, outputs, attrs, scope) {
bias_ = OpParam::InputYFrom<GType>(inputs, scope);
axis_ = OpParam::GetAttr<int>("axis", attrs);
output_ = OpParam::OutFrom<GType>(outputs, scope);
input_bias_ = OpParam::InputBiasFrom<GType>(inputs, scope);
input_mean_ = OpParam::InputMeanFrom<GType>(inputs, scope);
input_scale_ = OpParam::InputScaleFrom<GType>(inputs, scope);
......@@ -1957,14 +1942,12 @@ class FusionConvBNAddReluParam : public ConvParam<Dtype> {
} else if (keyY_ == keyBNY_) {
bias_ = OpParam::InputXFrom<GType>(inputs, scope);
}
// is_test_ = OpParam::GetAttr<bool>("is_test", attrs);
this->output_ = OpParam::OutFrom<GType>(outputs, scope);
}
GType *Bias() const { return bias_; }
const int &Axis() const { return axis_; }
GType *Output() const { return output_; }
const GType *InputBias() const { return input_bias_; }
const GType *InputMean() const { return input_mean_; }
......@@ -1977,8 +1960,6 @@ class FusionConvBNAddReluParam : public ConvParam<Dtype> {
const float &Momentum() const { return momentum_; }
const bool &IsTest() const { return is_test_; }
void SetNewScale(GType *new_scale) { new_scale_ = new_scale; }
void SetNewBias(GType *new_bias) { new_bias_ = new_bias; }
......@@ -1990,14 +1971,12 @@ class FusionConvBNAddReluParam : public ConvParam<Dtype> {
protected:
GType *bias_;
int axis_;
GType *output_;
GType *input_bias_;
GType *input_mean_;
GType *input_scale_;
GType *input_variance_;
float epsilon_;
float momentum_;
bool is_test_;
GType *new_bias_;
GType *new_scale_;
std::string keyBNY_;
......@@ -2017,16 +1996,14 @@ class FusionConvBNParam : public ConvParam<Dtype> {
const VariableNameMap &outputs, const AttributeMap &attrs,
const Scope &scope)
: ConvParam<Dtype>(inputs, outputs, attrs, scope) {
output_y_ = OpParam::OutputYFrom<GType>(outputs, scope);
input_bias_ = OpParam::InputBiasFrom<GType>(inputs, scope);
input_mean_ = OpParam::InputMeanFrom<GType>(inputs, scope);
input_scale_ = OpParam::InputScaleFrom<GType>(inputs, scope);
input_variance_ = OpParam::InputVarianceFrom<GType>(inputs, scope);
epsilon_ = OpParam::GetAttr<float>("epsilon", attrs);
momentum_ = OpParam::GetAttr<float>("momentum", attrs);
// is_test_ = OpParam::GetAttr<bool>("is_test", attrs);
this->output_ = OpParam::OutputYFrom<GType>(outputs, scope);
}
GType *Output() const { return output_y_; }
const GType *InputBias() const { return input_bias_; }
......@@ -2040,8 +2017,6 @@ class FusionConvBNParam : public ConvParam<Dtype> {
const float &Momentum() const { return momentum_; }
const bool &IsTest() const { return is_test_; }
void SetNewScale(GType *new_scale) { new_scale_ = new_scale; }
void SetNewBias(GType *new_bias) { new_bias_ = new_bias; }
......@@ -2051,14 +2026,12 @@ class FusionConvBNParam : public ConvParam<Dtype> {
const GType *NewBias() const { return new_bias_; }
protected:
GType *output_y_;
GType *input_bias_;
GType *input_mean_;
GType *input_scale_;
GType *input_variance_;
float epsilon_;
float momentum_;
bool is_test_;
GType *new_bias_;
GType *new_scale_;
};
......@@ -2077,21 +2050,18 @@ class FusionConvAddBNParam : public ConvParam<Dtype> {
: ConvParam<Dtype>(inputs, outputs, attrs, scope) {
bias_ = OpParam::InputYFrom<GType>(inputs, scope);
axis_ = OpParam::GetAttr<int>("axis", attrs);
output_y_ = OpParam::OutputYFrom<GType>(outputs, scope);
input_bias_ = OpParam::InputBiasFrom<GType>(inputs, scope);
input_mean_ = OpParam::InputMeanFrom<GType>(inputs, scope);
input_scale_ = OpParam::InputScaleFrom<GType>(inputs, scope);
input_variance_ = OpParam::InputVarianceFrom<GType>(inputs, scope);
epsilon_ = OpParam::GetAttr<float>("epsilon", attrs);
momentum_ = OpParam::GetAttr<float>("momentum", attrs);
// is_test_ = OpParam::GetAttr<bool>("is_test", attrs);
this->output_ = OpParam::OutputYFrom<GType>(outputs, scope);
}
GType *Bias() const { return bias_; }
const int &Axis() const { return axis_; }
GType *Output() const { return output_y_; }
const GType *InputBias() const { return input_bias_; }
const GType *InputMean() const { return input_mean_; }
......@@ -2104,8 +2074,6 @@ class FusionConvAddBNParam : public ConvParam<Dtype> {
const float &Momentum() const { return momentum_; }
const bool &IsTest() const { return is_test_; }
void SetNewScale(GType *new_scale) { new_scale_ = new_scale; }
void SetNewBias(GType *new_bias) { new_bias_ = new_bias; }
......@@ -2117,14 +2085,12 @@ class FusionConvAddBNParam : public ConvParam<Dtype> {
protected:
GType *bias_;
int axis_;
GType *output_y_;
GType *input_bias_;
GType *input_mean_;
GType *input_scale_;
GType *input_variance_;
float epsilon_;
float momentum_;
bool is_test_;
GType *new_bias_;
GType *new_scale_;
};
......@@ -2141,16 +2107,14 @@ class FusionDWConvBNReluParam : public ConvParam<Dtype> {
const VariableNameMap &outputs,
const AttributeMap &attrs, const Scope &scope)
: ConvParam<Dtype>(inputs, outputs, attrs, scope) {
output_ = OpParam::OutFrom<GType>(outputs, scope);
input_bias_ = OpParam::InputBiasFrom<GType>(inputs, scope);
input_mean_ = OpParam::InputMeanFrom<GType>(inputs, scope);
input_scale_ = OpParam::InputScaleFrom<GType>(inputs, scope);
input_variance_ = OpParam::InputVarianceFrom<GType>(inputs, scope);
epsilon_ = OpParam::GetAttr<float>("epsilon", attrs);
momentum_ = OpParam::GetAttr<float>("momentum", attrs);
// is_test_ = OpParam::GetAttr<bool>("is_test", attrs);
this->output_ = OpParam::OutFrom<GType>(outputs, scope);
}
GType *Output() const { return output_; }
const GType *InputBias() const { return input_bias_; }
......@@ -2164,8 +2128,6 @@ class FusionDWConvBNReluParam : public ConvParam<Dtype> {
const float &Momentum() const { return momentum_; }
const bool &IsTest() const { return is_test_; }
void SetNewScale(GType *new_scale) { new_scale_ = new_scale; }
void SetNewBias(GType *new_bias) { new_bias_ = new_bias; }
......@@ -2175,14 +2137,12 @@ class FusionDWConvBNReluParam : public ConvParam<Dtype> {
const GType *NewBias() const { return new_bias_; }
protected:
GType *output_;
GType *input_bias_;
GType *input_mean_;
GType *input_scale_;
GType *input_variance_;
float epsilon_;
float momentum_;
bool is_test_;
GType *new_bias_;
GType *new_scale_;
};
......@@ -2200,16 +2160,14 @@ class FusionConvBNReluParam : public ConvParam<Dtype> {
const VariableNameMap &outputs,
const AttributeMap &attrs, const Scope &scope)
: ConvParam<Dtype>(inputs, outputs, attrs, scope) {
output_ = OpParam::OutFrom<GType>(outputs, scope);
input_bias_ = OpParam::InputBiasFrom<GType>(inputs, scope);
input_mean_ = OpParam::InputMeanFrom<GType>(inputs, scope);
input_scale_ = OpParam::InputScaleFrom<GType>(inputs, scope);
input_variance_ = OpParam::InputVarianceFrom<GType>(inputs, scope);
epsilon_ = OpParam::GetAttr<float>("epsilon", attrs);
momentum_ = OpParam::GetAttr<float>("momentum", attrs);
// is_test_ = OpParam::GetAttr<bool>("is_test", attrs);
this->output_ = OpParam::OutFrom<GType>(outputs, scope);
}
GType *Output() const { return output_; }
const GType *InputBias() const { return input_bias_; }
......@@ -2223,8 +2181,6 @@ class FusionConvBNReluParam : public ConvParam<Dtype> {
const float &Momentum() const { return momentum_; }
const bool &IsTest() const { return is_test_; }
void SetNewScale(GType *new_scale) { new_scale_ = new_scale; }
void SetNewBias(GType *new_bias) { new_bias_ = new_bias; }
......@@ -2234,14 +2190,12 @@ class FusionConvBNReluParam : public ConvParam<Dtype> {
const GType *NewBias() const { return new_bias_; }
protected:
GType *output_;
GType *input_bias_;
GType *input_mean_;
GType *input_scale_;
GType *input_variance_;
float epsilon_;
float momentum_;
bool is_test_;
GType *new_bias_;
GType *new_scale_;
};
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册