提交 88960828 编写于 作者: S smilejames 提交者: GitHub

Merge pull request #899 from yangfei963158659/develop

imp fusion_conv_bn_add_relu op in resnet
......@@ -26,6 +26,7 @@ const char *G_OP_TYPE_FUSION_CONV_ADD_RELU = "fusion_conv_add_relu";
const char *G_OP_TYPE_FUSION_CONV_ADD_PRELU = "fusion_conv_add_prelu";
const char *G_OP_TYPE_FUSION_CONV_ADD_ADD_PRELU = "fusion_conv_add_add_prelu";
const char *G_OP_TYPE_FUSION_CONV_ADD_BN_RELU = "fusion_conv_add_bn_relu";
const char *G_OP_TYPE_FUSION_CONV_BN_ADD_RELU = "fusion_conv_bn_add_relu";
const char *G_OP_TYPE_FUSION_DWCONV_BN_RELU = "fusion_dwconv_bn_relu";
const char *G_OP_TYPE_FUSION_CONV_BN_RELU = "fusion_conv_bn_relu";
const char *G_OP_TYPE_FC = "fusion_fc";
......@@ -79,6 +80,7 @@ std::unordered_map<
{G_OP_TYPE_BOX_CODER,
{{"PriorBox", "PriorBoxVar", "TargetBox"}, {"OutputBox"}}},
{G_OP_TYPE_FUSION_CONV_ADD_BN_RELU, {{"Input"}, {"Out"}}},
{G_OP_TYPE_FUSION_CONV_BN_ADD_RELU, {{"Input"}, {"Out"}}},
{G_OP_TYPE_PRIOR_BOX, {{"Image", "Input"}, {"Boxes", "Variances"}}},
{G_OP_TYPE_MULTICLASS_NMS, {{"BBoxes", "Scores"}, {"Out"}}},
{G_OP_TYPE_FC, {{"X", "Y", "Z"}, {"Out"}}},
......
......@@ -90,6 +90,7 @@ extern const char *G_OP_TYPE_FUSION_CONV_ADD_ADD_PRELU;
extern const char *G_OP_TYPE_FC;
extern const char *G_OP_TYPE_FUSION_CONV_ADD;
extern const char *G_OP_TYPE_FUSION_CONV_ADD_BN_RELU;
extern const char *G_OP_TYPE_FUSION_CONV_BN_ADD_RELU;
extern const char *G_OP_TYPE_FUSION_DWCONV_BN_RELU;
extern const char *G_OP_TYPE_FUSION_CONV_BN_RELU;
......
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifdef FUSION_CONVBNADDRELU_OP
#include "operators/fusion_conv_bn_add_relu_op.h"
#include "operators/math/conv_func.h"
namespace paddle_mobile {
namespace operators {
template <typename Dtype, typename T>
void FusionConvBNAddReluOp<Dtype, T>::InferShape() const {
auto in_dims = this->param_.Input()->dims();
auto filter_dims = this->param_.Filter()->dims();
const std::vector<int> &strides = this->param_.Strides();
std::vector<int> paddings = this->param_.Paddings();
int groups = this->param_.Groups();
std::vector<int> dilations = this->param_.Dilations();
PADDLE_MOBILE_ENFORCE((in_dims.size() == filter_dims.size() &&
dilations.size() == paddings.size() &&
paddings.size() == strides.size()),
"ConvParam is not suitable");
std::vector<int64_t> output_shape({in_dims[0], filter_dims[0]});
for (size_t i = 0; i < strides.size(); ++i) {
output_shape.push_back(
math::ConvOutputSize(in_dims[i + 2], filter_dims[i + 2], dilations[i],
paddings[i], strides[i]));
}
framework::DDim ddim = framework::make_ddim(output_shape);
this->param_.Output()->Resize(ddim);
}
} // namespace operators
} // namespace paddle_mobile
namespace ops = paddle_mobile::operators;
#ifdef PADDLE_MOBILE_CPU
REGISTER_OPERATOR_CPU(fusion_conv_bn_add_relu, ops::FusionConvBNAddReluOp);
#endif
#ifdef PADDLE_MOBILE_MALI_GPU
#endif
#ifdef PADDLE_MOBILE_FPGA
REGISTER_OPERATOR_FPGA(fusion_conv_bn_add_relu, ops::FusionConvBNAddReluOp);
#endif
#endif
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifdef FUSION_CONVBNADDRELU_OP
#pragma once
#include <string>
#include <vector>
#include "framework/operator.h"
#include "framework/program/program-optimize/fusion_op_register.h"
#include "op_param.h"
#include "operators/kernel/conv_bn_add_relu_kernel.h"
namespace paddle_mobile {
namespace operators {
using std::string;
using std::vector;
class FusionConvBNAddReluMatcher : public framework::FusionOpMatcher {
public:
FusionConvBNAddReluMatcher() {
node_ = framework::Node(G_OP_TYPE_CONV);
node_ > std::make_shared<framework::Node>(G_OP_TYPE_BATCHNORM) >
std::make_shared<framework::Node>(G_OP_TYPE_ELEMENTWISE_ADD) >
std::make_shared<framework::Node>(G_OP_TYPE_RELU);
}
void FolderNodes(
framework::Node *node,
std::vector<std::shared_ptr<framework::Node>> *removed_nodes) {
node->Folder(node_.Depth(), Type(),
{{G_OP_TYPE_ELEMENTWISE_ADD, {{"Y", "Y"}, {"X", "X"}}},
{G_OP_TYPE_BATCHNORM,
{{"Scale", "Scale"},
{"Mean", "Mean"},
{"Bias", "Bias"},
{"Variance", "Variance"},
{"Y", "BNY"}}}},
removed_nodes);
}
std::string Type() { return G_OP_TYPE_FUSION_CONV_BN_ADD_RELU; }
std::vector<std::pair<int, std::string>> NeedCheck() {
DLOG << " conv bn add relu check add X ";
return {{2, "Y"}, {2, "X"}};
}
};
template <typename DeviceType, typename T>
class FusionConvBNAddReluOp
: public framework::OperatorWithKernel<
DeviceType, FusionConvBNAddReluParam<DeviceType>,
operators::ConvBNAddReluKernel<DeviceType, T>> {
public:
FusionConvBNAddReluOp(const string &type, const VariableNameMap &inputs,
const VariableNameMap &outputs,
const framework::AttributeMap &attrs,
std::shared_ptr<framework::Scope> scope)
: framework::OperatorWithKernel<
DeviceType, FusionConvBNAddReluParam<DeviceType>,
operators::ConvBNAddReluKernel<DeviceType, T>>(
type, inputs, outputs, attrs, scope) {}
using framework::OperatorWithKernel<
DeviceType, FusionConvBNAddReluParam<DeviceType>,
operators::ConvBNAddReluKernel<DeviceType, T>>::OperatorWithKernel;
void InferShape() const override;
protected:
};
#ifdef PADDLE_MOBILE_CPU
#ifndef FUSION_CONV_BN_ADD_RELU_REGISTER
static framework::FusionOpRegistrar fusion_conv_bn_add_relu_registrar(
new FusionConvBNAddReluMatcher());
#define FUSION_CONV_BN_ADD_RELU_REGISTER
#endif
#endif
#ifdef PADDLE_MOBILE_MALI_GPU
#ifndef FUSION_CONV_BN_ADD_RELU_REGISTER
static framework::FusionOpRegistrar fusion_conv_bn_add_relu_registrar(
new FusionConvBNAddReluMatcher());
#define FUSION_CONV_BN_ADD_RELU_REGISTER
#endif
#endif
#ifdef PADDLE_MOBILE_FPGA
#ifndef FUSION_CONV_BN_ADD_RELU_REGISTER
static framework::FusionOpRegistrar fusion_conv_bn_add_relu_registrar(
new FusionConvBNAddReluMatcher());
#define FUSION_CONV_BN_ADD_RELU_REGISTER
#endif
#endif
} // namespace operators
} // namespace paddle_mobile
#ifdef PADDLE_MOBILE_CPU
USE_OP_CPU(fusion_conv_bn_add_relu);
#endif
#ifdef PADDLE_MOBILE_MALI_GPU
#endif
#ifdef PADDLE_MOBILE_FPGA
USE_OP_FPGA(fusion_conv_bn_add_relu);
#endif
#endif
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifdef FUSION_CONVBNADDRELU_OP
#include "operators/kernel/conv_bn_add_relu_kernel.h"
#include "operators/kernel/central-arm-func/conv_bn_add_relu_arm_func.h"
namespace paddle_mobile {
namespace operators {
template <>
bool ConvBNAddReluKernel<CPU, float>::Init(
FusionConvBNAddReluParam<CPU> *param) {
const Tensor *mean = param->InputMean();
const Tensor *variance = param->InputVariance();
const Tensor *scale = param->InputScale();
const Tensor *bias = param->InputBias();
const float epsilon = param->Epsilon();
auto mean_ptr = mean->data<float>();
auto variance_ptr = variance->data<float>();
auto scale_ptr = scale->data<float>();
auto bias_ptr = bias->data<float>();
const int C = mean->numel();
float inv_std_ptr[C];
for (int i = 0; i < C; i++) {
inv_std_ptr[i] =
1 / static_cast<float>(pow((variance_ptr[i] + epsilon), 0.5));
}
Tensor *new_scale = new Tensor();
Tensor *new_bias = new Tensor();
auto new_scale_ptr = new_scale->mutable_data<float>({C});
auto new_bias_ptr = new_bias->mutable_data<float>({C});
for (int i = 0; i < C; i++) {
new_scale_ptr[i] = inv_std_ptr[i] * scale_ptr[i];
new_bias_ptr[i] = bias_ptr[i] - mean_ptr[i] * inv_std_ptr[i] * scale_ptr[i];
}
param->SetNewScale(new_scale);
param->SetNewBias(new_bias);
return true;
}
template <>
void ConvBNAddReluKernel<CPU, float>::Compute(
const FusionConvBNAddReluParam<CPU> &param) const {
ConvBNAddReluCompute<float>(param);
}
template class ConvBNAddReluKernel<CPU, float>;
} // namespace operators
} // namespace paddle_mobile
#endif
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifdef FUSION_CONVBNADDRELU_OP
#pragma once
#include <vector>
#include "operators/math/depthwise_conv_3x3.h"
#include "operators/math/im2col.h"
#include "operators/math/math_function.h"
#include "operators/math/vol2col.h"
#include "operators/op_param.h"
namespace paddle_mobile {
namespace operators {
void ConvBNAddReluBasic(const FusionConvBNAddReluParam<CPU> &param) {
const Tensor *input = param.Input();
Tensor filter = *param.Filter();
Tensor new_bias = *param.NewBias();
Tensor new_scale = *param.NewScale();
Tensor *output = param.Output();
Tensor *bias1 = param.Bias();
int groups = param.Groups();
DLOG << "yangfei2";
DLOG << bias1->dims();
std::vector<int> strides = param.Strides();
std::vector<int> paddings = param.Paddings();
std::vector<int> dilations = param.Dilations();
const int batch_size = static_cast<int>(input->dims()[0]);
std::vector<int64_t> filter_shape_vec(framework::vectorize(filter.dims()));
std::vector<int64_t> output_shape_vec(framework::vectorize(output->dims()));
size_t data_dim = filter_shape_vec.size() - 2;
std::vector<int64_t> col_shape_vec(1 + 2 * data_dim);
col_shape_vec[0] = input->dims()[1] / groups;
for (size_t j = 0; j < data_dim; ++j) {
col_shape_vec[j + 1] = filter_shape_vec[j + 2];
col_shape_vec[j + 1 + data_dim] = output_shape_vec[j + 2];
}
framework::DDim col_shape(framework::make_ddim(col_shape_vec));
framework::DDim col_matrix_shape =
framework::flatten_to_2d(col_shape, data_dim + 1);
bool is_expand =
math::IsExpand(filter_shape_vec, strides, paddings, dilations);
Tensor col;
Tensor col_matrix;
if (is_expand) {
col.mutable_data<float>(col_shape);
col_matrix.ShareDataWith(col);
col_matrix.Resize(col_matrix_shape);
}
framework::DDim input_shape = framework::slice_ddim(
input->dims(), 1, static_cast<int>(input->dims().size()));
framework::DDim filter_matrix_shape = {filter.dims()[0],
filter.numel() / filter.dims()[0]};
filter.Resize(filter_matrix_shape);
framework::DDim output_matrix_shape = {
output->dims()[1],
output->numel() / (output->dims()[0] * output->dims()[1])};
// convolution operator: im2col(or vol2col) + gemm
int in_step = static_cast<int>(input->dims()[1]) / groups;
int out_step = static_cast<int>(output->dims()[1]) / groups;
math::Vol2ColFunctor<CPU, float> vol2col;
math::Im2ColFunctor<math::ColFormat::kCFO, CPU, float> im2col;
for (int i = 0; i < batch_size; i++) {
Tensor in_batch = input->Slice(i, i + 1).Resize(input_shape);
Tensor out_batch = output->Slice(i, i + 1).Resize(output_matrix_shape);
Tensor bias_batch = bias1->Slice(i, i + 1).Resize(output_matrix_shape);
for (int g = 0; g < groups; g++) {
Tensor in_slice = in_batch.Slice(g * in_step, (g + 1) * in_step);
if (!is_expand) {
col.ShareDataWith(in_slice);
col_matrix.ShareDataWith(col);
col_matrix.Resize(col_matrix_shape);
} else if (data_dim == 2U) {
// im2col
im2col(in_slice, dilations, strides,
std::vector<int>{paddings[0], paddings[1], paddings[0],
paddings[1]},
&col);
} else if (data_dim == 3U) {
// vol2col
vol2col(in_slice, dilations, strides, paddings, &col);
}
// gemm
Tensor out_slice = out_batch.Slice(g * out_step, (g + 1) * out_step);
Tensor filter_slice = filter.Slice(g * out_step, (g + 1) * out_step);
Tensor bias_data = bias_batch.Slice(g * out_step, (g + 1) * out_step);
math::matmulWithBn<float>(filter_slice, false, col_matrix, false,
static_cast<float>(1), &out_slice,
static_cast<float>(1), true, &new_scale,
&new_bias, g, bias_data.data<float>());
}
}
}
template <typename P>
void ConvBNAddReluCompute(const FusionConvBNAddReluParam<CPU> &param) {
Tensor Bias;
Bias.mutable_data<float>({param.Groups()});
if (param.Groups() == param.Input()->dims()[1] &&
param.Input()->dims()[1] == param.Output()->dims()[1] &&
param.Filter()->dims()[2] == param.Filter()->dims()[3] &&
param.Filter()->dims()[2] == 3 && param.Strides()[0] == 1) {
math::DepthwiseConvAddBNRelu3x3s1p1(param.Input(), param.Filter(),
param.Output(), param.NewScale(),
param.NewBias(), true);
} else if (param.Groups() == param.Input()->dims()[1] &&
param.Input()->dims()[1] == param.Output()->dims()[1] &&
param.Filter()->dims()[2] == param.Filter()->dims()[3] &&
param.Filter()->dims()[2] == 3 && param.Strides()[0] == 2) {
// math::DepthwiseConvAddBNRelu3x3s2p1(param.Input(), param.Filter(),
// param.Output(), param.NewScale(),
// param.NewBias(), 1);
math::DepthwiseConvAddBNRelu3x3s2p1v2(param.Input(), param.Filter(),
param.Output(), param.NewScale(),
param.NewBias(), true);
} else {
ConvBNAddReluBasic(param);
}
}
} // namespace operators
} // namespace paddle_mobile
#endif
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#ifdef FUSION_CONVBNADDRELU_OP
#include <vector>
#include "framework/ddim.h"
#include "framework/operator.h"
#include "operators/math/conv_func.h"
#include "operators/math/im2col.h"
#include "operators/math/math_function.h"
#include "operators/math/vol2col.h"
#include "operators/op_param.h"
namespace paddle_mobile {
namespace operators {
using framework::DDim;
using framework::OpKernelBase;
template <typename DeviceType, typename T>
class ConvBNAddReluKernel
: public OpKernelBase<DeviceType, FusionConvBNAddReluParam<DeviceType>> {
public:
void Compute(const FusionConvBNAddReluParam<DeviceType> &param) const;
bool Init(FusionConvBNAddReluParam<DeviceType> *param);
};
} // namespace operators
} // namespace paddle_mobile
#endif
......@@ -716,6 +716,27 @@ void InnerKernelWithBn(int mc, int nc, float alpha, const float *a,
}
}
// 分块矩阵乘法
void InnerKernelWithBnAdd(int mc, int nc, float alpha, const float *a,
const float *b, float beta, float *c, float *C,
int ldc, bool relu, float *new_scale, float *new_bias,
float *bias) {
#pragma omp parallel for
for (int j = 0; j < nc; j += NR) {
for (int i = 0; i < mc; i += MR) {
#if __aarch64__
// AddDot8x12(KC, a + i * KC, b + j * KC, c + i * NC + j, NC);
AddDot6x16(KC, a + i * KC, b + j * KC, c + i * NC + j, NC);
#else
// AddDot4x4(KC, a + i * KC, b + j * KC, c + i * NC + j, NC);
// AddDot4x8(KC, a + i * KC, b + j * KC, c + i * NC + j, NC);
AddDot6x8(KC, a + i * KC, b + j * KC, c + i * NC + j, NC);
#endif
}
}
WriteWithBnAddRelu(mc, nc, c, C, ldc, new_scale, new_bias, bias);
}
void InnerKernelWithPRelu(int mc, int nc, const float *a, const float *b,
float *c, float *C, int ldc, float *p,
std::string mode, float *bias, float *bias1) {
......@@ -1183,6 +1204,59 @@ void WriteWithBnRelu(int mc, int nc, float *c, float *C, int ldc,
}
}
// C = A * B, batchnorm(C),C = C + bias; relu(C)
void WriteWithBnAddRelu(int mc, int nc, float *c, float *C, int ldc,
float *new_scale, float *new_bias, float *bias) {
int nc1 = nc / 4;
int _nc1 = nc % 4;
float *c_ptr, *C_ptr, *bias_ptr;
float32x4_t cv;
float32x4_t nbias;
float32x2_t scale;
float32x4_t biasv;
float32x4_t zero = vdupq_n_f32(0.0);
for (int i = 0; i < mc; ++i) {
c_ptr = c + i * NC;
C_ptr = C + i * ldc;
bias_ptr = bias + i * ldc;
nbias = vld1q_dup_f32(new_bias);
scale = vld1_dup_f32(new_scale);
new_bias++;
new_scale++;
float scale0 = vget_lane_f32(scale, 0);
for (int j = 0; j < nc1; ++j) {
cv = vld1q_f32(c_ptr);
biasv = vld1q_f32(bias_ptr);
cv = vmlaq_n_f32(nbias, cv, scale0);
cv = vaddq_f32(cv, biasv);
cv = vmaxq_f32(cv, zero);
vst1q_f32(C_ptr, cv);
c_ptr += 4;
C_ptr += 4;
bias_ptr += 4;
}
if (_nc1 != 0) {
cv = vld1q_f32(c_ptr);
biasv = vld1q_f32(bias_ptr);
cv = vmlaq_n_f32(nbias, cv, scale0);
cv = vaddq_f32(cv, biasv);
cv = vmaxq_f32(cv, zero);
if (_nc1 >= 1) {
vst1q_lane_f32(C_ptr, cv, 0);
C_ptr++;
}
if (_nc1 >= 2) {
vst1q_lane_f32(C_ptr, cv, 1);
C_ptr++;
}
if (_nc1 >= 3) {
vst1q_lane_f32(C_ptr, cv, 2);
}
}
}
}
#else
void AddDot4x4(int k, const float *a, const float *b, float *c, int ldc) {
......@@ -2426,6 +2500,59 @@ void WriteWithBnRelu(int mc, int nc, float *c, float *C, int ldc, float *scale,
"q8", "q10", "q11", "q12", "q13", "q14");
}
// C = A * B, batchnorm(C),C = C + bias; relu(C)
void WriteWithBnAddRelu(int mc, int nc, float *c, float *C, int ldc,
float *new_scale, float *new_bias, float *bias) {
int nc1 = nc / 4;
int _nc1 = nc % 4;
float *c_ptr, *C_ptr, *bias_ptr;
float32x4_t cv;
float32x4_t nbias;
float32x2_t scale;
float32x4_t biasv;
float32x4_t zero = vdupq_n_f32(0.0);
for (int i = 0; i < mc; ++i) {
c_ptr = c + i * NC;
C_ptr = C + i * ldc;
bias_ptr = bias + i * ldc;
nbias = vld1q_dup_f32(new_bias);
scale = vld1_dup_f32(new_scale);
new_bias++;
new_scale++;
float scale0 = vget_lane_f32(scale, 0);
for (int j = 0; j < nc1; ++j) {
cv = vld1q_f32(c_ptr);
biasv = vld1q_f32(bias_ptr);
cv = vmlaq_n_f32(nbias, cv, scale0);
cv = vaddq_f32(cv, biasv);
cv = vmaxq_f32(cv, zero);
vst1q_f32(C_ptr, cv);
c_ptr += 4;
C_ptr += 4;
bias_ptr += 4;
}
if (_nc1 != 0) {
cv = vld1q_f32(c_ptr);
biasv = vld1q_f32(bias_ptr);
cv = vmlaq_n_f32(nbias, cv, scale0);
cv = vaddq_f32(cv, biasv);
cv = vmaxq_f32(cv, zero);
if (_nc1 >= 1) {
vst1q_lane_f32(C_ptr, cv, 0);
C_ptr++;
}
if (_nc1 >= 2) {
vst1q_lane_f32(C_ptr, cv, 1);
C_ptr++;
}
if (_nc1 >= 3) {
vst1q_lane_f32(C_ptr, cv, 2);
}
}
}
}
/*
// C = A * B
void VecWriteBasic(int n, float *c, float *C, int ldc) {
......@@ -2835,7 +2962,7 @@ void Sgemm(int m, int n, int k, float alpha, const float *A, int lda,
void SgemmWithBn(int m, int n, int k, float alpha, const float *A, int lda,
const float *B, int ldb, float beta, float *C, int ldc,
bool relu, float *new_scale, float *new_bias) {
bool relu, float *new_scale, float *new_bias, float *bias) {
// L1 data cache is 32 kib (Per Contex-A57, Contex-A72, Contex-A73)
// L2 cache is 0.5~4 Mib (Contex-A72 cluster)
int L1 = 32 * 1024;
......@@ -2882,8 +3009,14 @@ void SgemmWithBn(int m, int n, int k, float alpha, const float *A, int lda,
#else
PackMatrixA_6r(mc, KC, mc % MR, &A(i, 0), lda, packedA);
#endif
if (bias == nullptr) {
InnerKernelWithBn(mc, nc, alpha, packedA, packedB, beta, packedC,
&C(i, j), ldc, relu, new_scale + i, new_bias + i);
} else {
InnerKernelWithBnAdd(mc, nc, alpha, packedA, packedB, beta, packedC,
&C(i, j), ldc, relu, new_scale + i, new_bias + i,
bias + i * ldc + j);
}
}
}
......@@ -3071,7 +3204,8 @@ void Sgemm_omp(int m, int n, int k, float alpha, const float *A, int lda,
void SgemmWithBn_omp(int m, int n, int k, float alpha, const float *A, int lda,
const float *B, int ldb, float beta, float *C, int ldc,
bool relu, float *new_scale, float *new_bias) {
bool relu, float *new_scale, float *new_bias,
float *bias) {
#ifdef _OPENMP
int max_threads = omp_get_max_threads();
#else
......@@ -3148,8 +3282,14 @@ void SgemmWithBn_omp(int m, int n, int k, float alpha, const float *A, int lda,
float *local_A = packedA + MC * KC * local_threads;
float *local_C = packedC + MC * NC * local_threads;
procPackA(mc, KC, mc % MR, &A(i, 0), lda, local_A);
InnerKernelWithBn(mc, n, alpha, local_A, packedB, beta, local_C, &C(i, 0),
ldc, relu, new_scale + i, new_bias + i);
if (bias == nullptr) {
InnerKernelWithBn(mc, n, alpha, local_A, packedB, beta, local_C,
&C(i, 0), ldc, relu, new_scale + i, new_bias + i);
} else {
InnerKernelWithBnAdd(mc, n, alpha, local_A, packedB, beta, local_C,
&C(i, 0), ldc, relu, new_scale + i, new_bias + i,
bias + i * ldc);
}
}
} else {
#pragma omp parallel for
......@@ -3165,8 +3305,14 @@ void SgemmWithBn_omp(int m, int n, int k, float alpha, const float *A, int lda,
float *local_B = packedB + KC * NC * local_threads;
float *local_C = packedC + MC * NC * local_threads;
procPackB(KC, nc, nc % NR, &B(0, j), ldb, local_B);
InnerKernelWithBn(m, nc, alpha, packedA, local_B, beta, local_C, &C(0, j),
ldc, relu, new_scale, new_bias);
if (bias == nullptr) {
InnerKernelWithBn(m, nc, alpha, packedA, local_B, beta, local_C,
&C(0, j), ldc, relu, new_scale, new_bias);
} else {
InnerKernelWithBnAdd(m, nc, alpha, packedA, local_B, beta, local_C,
&C(0, j), ldc, relu, new_scale, new_bias,
bias + j);
}
}
}
......
......@@ -81,6 +81,10 @@ void InnerKernelWithBias(int mc, int nc, float alpha, const float *a,
void InnerKernelWithBn(int mc, int nc, float alpha, const float *a,
const float *b, float beta, float *c, float *C, int ldc,
bool relu, float *new_scale, float *new_bias);
void InnerKernelWithBnAdd(int mc, int nc, float alpha, const float *a,
const float *b, float beta, float *c, float *C,
int ldc, bool relu, float *new_scale, float *new_bias,
float *bias);
void InnerKernelWithPRelu(int mc, int nc, const float *a, const float *b,
float *c, float *C, int ldc, float *p,
std::string mode, float *bias, float *bias1);
......@@ -125,7 +129,8 @@ void WriteWithBn(int mc, int nc, float *c, float *C, int ldc, float *new_scale,
// C = A * B, batchnorm(C), relu(C)
void WriteWithBnRelu(int mc, int nc, float *c, float *C, int ldc,
float *new_scale, float *new_bias);
void WriteWithBnAddRelu(int mc, int nc, float *c, float *C, int ldc,
float *new_scale, float *new_bias, float *bias1);
/*
// 向量矩阵乘法结果回写
// C = A * B
......@@ -152,8 +157,7 @@ void Sgemm(int m, int n, int k, float alpha, const float *A, int lda,
// 32位 float 矩阵乘法, 并对结果进行 batchnrom
void SgemmWithBn(int m, int n, int k, float alpha, const float *A, int lda,
const float *B, int ldb, float beta, float *C, int ldc,
bool relu, float *new_scale, float *new_bias);
bool relu, float *new_scale, float *new_bias, float *bias);
void SgemmWithPRelu(int m, int n, int k, const float *A, int lda,
const float *B, int ldb, float *C, int ldc, float *p,
std::string mode, float *bias, float *bias1);
......@@ -166,7 +170,7 @@ void Sgemm_omp(int m, int n, int k, float alpha, const float *A, int lda,
// 32位 float 矩阵乘法, 并对结果进行 batchnrom(openmp 多线程版本)
void SgemmWithBn_omp(int m, int n, int k, float alpha, const float *A, int lda,
const float *B, int ldb, float beta, float *C, int ldc,
bool relu, float *new_scale, float *new_bias);
bool relu, float *new_scale, float *new_bias, float *bias);
void SgemmWithPRelu_omp(int m, int n, int k, const float *A, int lda,
const float *B, int ldb, float *C, int ldc, float *p,
......
......@@ -56,7 +56,7 @@ void matmulWithBn<float>(const framework::Tensor &matrix_a, bool trans_a,
const framework::Tensor &matrix_b, bool trans_b,
float alpha, framework::Tensor *matrix_out, float beta,
bool relu, framework::Tensor *new_scale,
framework::Tensor *new_bias, int group) {
framework::Tensor *new_bias, int group, float *bias) {
auto dim_a = matrix_a.dims();
auto dim_b = matrix_b.dims();
auto dim_out = matrix_out->dims();
......@@ -79,12 +79,12 @@ void matmulWithBn<float>(const framework::Tensor &matrix_a, bool trans_a,
SgemmWithBn_omp(M, N, K, alpha, matrix_a.data<float>(), K,
matrix_b.data<float>(), N, beta, matrix_out->data<float>(), N,
relu, new_scale->data<float>() + group,
new_bias->data<float>() + group);
new_bias->data<float>() + group, bias);
#else
SgemmWithBn(M, N, K, alpha, matrix_a.data<float>(), K, matrix_b.data<float>(),
N, beta, matrix_out->data<float>(), N, relu,
new_scale->data<float>() + group,
new_bias->data<float>() + group);
new_scale->data<float>() + group, new_bias->data<float>() + group,
bias);
#endif
}
void matmulWithPRelu(const framework::Tensor &matrix_a, bool trans_a,
......
......@@ -32,7 +32,7 @@ void matmulWithBn(const framework::Tensor &matrix_a, bool trans_a,
const framework::Tensor &matrix_b, bool trans_b, T alpha,
framework::Tensor *matrix_out, T beta, bool relu,
framework::Tensor *new_scale, framework::Tensor *new_bias,
int group);
int group, float *bias = nullptr);
void matmulWithPRelu(const framework::Tensor &matrix_a, bool trans_a,
const framework::Tensor &matrix_b, bool trans_b,
......
......@@ -1472,6 +1472,119 @@ class FusionConvAddBNReluParam : public OpParam {
};
#endif
#ifdef FUSION_CONVBNADDRELU_OP
template <typename Dtype>
class FusionConvBNAddReluParam : public OpParam {
typedef typename DtypeTensorTrait<Dtype>::gtype GType;
typedef typename DtypeTensorTrait<Dtype>::rtype RType;
public:
FusionConvBNAddReluParam(const VariableNameMap &inputs,
const VariableNameMap &outputs,
const AttributeMap &attrs, const Scope &scope) {
bias_ = InputYFrom<GType>(inputs, scope);
axis_ = GetAttr<int>("axis", attrs);
filter_ = FilterFrom<GType>(inputs, scope);
input_ = InputFrom<GType>(inputs, scope);
output_ = OutFrom<GType>(outputs, scope);
strides_ = GetAttr<vector<int>>("strides", attrs);
paddings_ = GetAttr<vector<int>>("paddings", attrs);
dilations_ = GetAttr<vector<int>>("dilations", attrs);
groups = GetAttr<int>("groups", attrs);
input_bias_ = InputBiasFrom<GType>(inputs, scope);
input_mean_ = InputMeanFrom<GType>(inputs, scope);
input_scale_ = InputScaleFrom<GType>(inputs, scope);
input_variance_ = InputVarianceFrom<GType>(inputs, scope);
epsilon_ = GetAttr<float>("epsilon", attrs);
momentum_ = GetAttr<float>("momentum", attrs);
keyBNY_ = getkey("BNY", inputs, 0);
keyX_ = getkey("X", inputs, 0);
keyY_ = getkey("Y", inputs, 0);
if (keyX_ == keyBNY_) {
bias_ = InputYFrom<GType>(inputs, scope);
} else if (keyY_ == keyBNY_) {
bias_ = InputXFrom<GType>(inputs, scope);
}
// is_test_ = GetAttr<bool>("is_test", attrs);
}
RType *Bias() const { return bias_; }
const int &Axis() const { return axis_; }
const RType *Input() const { return input_; }
#ifdef PADDLE_MOBILE_FPGA
RType *Filter() const { return filter_; }
#else
const RType *Filter() const { return filter_; }
#endif
RType *Output() const { return output_; }
const vector<int> &Strides() const { return strides_; }
const vector<int> &Paddings() const { return paddings_; }
const vector<int> &Dilations() const { return dilations_; }
const int &Groups() const { return groups; }
const RType *InputBias() const { return input_bias_; }
const RType *InputMean() const { return input_mean_; }
const RType *InputScale() const { return input_scale_; }
const RType *InputVariance() const { return input_variance_; }
const float &Epsilon() const { return epsilon_; }
const float &Momentum() const { return momentum_; }
const bool &IsTest() const { return is_test_; }
void SetNewScale(RType *new_scale) { new_scale_ = new_scale; }
void SetNewBias(RType *new_bias) { new_bias_ = new_bias; }
const RType *NewScale() const { return new_scale_; }
const RType *NewBias() const { return new_bias_; }
protected:
RType *bias_;
int axis_;
RType *input_;
RType *output_;
RType *filter_;
vector<int> strides_;
vector<int> paddings_;
vector<int> dilations_;
int groups;
RType *input_bias_;
RType *input_mean_;
RType *input_scale_;
RType *input_variance_;
float epsilon_;
float momentum_;
bool is_test_;
RType *new_bias_;
RType *new_scale_;
std::string keyBNY_;
std::string keyX_;
std::string keyY_;
#ifdef PADDLE_MOBILE_FPGA
private:
fpga::ConvArgs fpga_conv_args;
public:
const fpga::ConvArgs &FpgaArgs() const { return fpga_conv_args; }
void SetFpgaArgs(const fpga::ConvArgs &args) { fpga_conv_args = args; }
#endif
};
#endif
#ifdef FUSION_CONVBN_OP
template <typename Dtype>
class FusionConvBNParam : public OpParam {
......
......@@ -83,8 +83,8 @@ int do_sgemm(int m, int n, int k, bool relu, int t1, int t2, int pr) {
}
}
paddle_mobile::operators::math::SgemmWithBn(m, n, k, 0.9, a, lda, b, ldb, 0.3,
c, ldc, relu, scale, bias);
paddle_mobile::operators::math::SgemmWithBn(
m, n, k, 0.9, a, lda, b, ldb, 0.3, c, ldc, relu, scale, bias, nullptr);
int eq = 0;
int neq = 0;
for (int i = 0; i < m * n; ++i) {
......
......@@ -87,6 +87,7 @@ if ("resnet" IN_LIST NET)
set(ELEMENTWISEADD_OP ON)
set(POOL_OP ON)
set(BATCHNORM_OP ON)
set(FUSION_CONVBNADDRELU_OP ON)
set(MUL_OP ON)
set(RESHAPE_OP ON)
set(SOFTMAX_OP ON)
......@@ -141,6 +142,7 @@ if(NOT FOUND_MATCH)
set(FUSION_CONVADDADDPRELU_OP ON)
set(FUSION_DWCONVBNRELU_OP ON)
set(FUSION_CONVBNRELU_OP ON)
set(FUSION_CONVBNADDRELU_OP ON)
set(PRELU_OP ON)
set(RESIZE_OP ON)
set(SCALE_OP ON)
......@@ -244,6 +246,10 @@ if (FUSION_CONVBNRELU_OP)
add_definitions(-DFUSION_CONVBNRELU_OP)
endif()
if (FUSION_CONVBNADDRELU_OP)
add_definitions(-DFUSION_CONVBNADDRELU_OP)
endif()
if (PRELU_OP)
add_definitions(-DPRELU_OP)
endif()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册