提交 8ea3708c 编写于 作者: H hjchen2

Merge depthwise_conv2d and conv2d kernel, add some dequant fusion kernels

上级 e54bf8c5
......@@ -71,6 +71,8 @@ const char *G_OP_TYPE_SUM = "sum";
const char *G_OP_TYPE_QUANTIZE = "quantize";
const char *G_OP_TYPE_DEQUANTIZE = "dequantize";
const char *G_OP_TYPE_FUSION_DEQUANT_ADD_BN = "fusion_dequant_add_bn";
const char *G_OP_TYPE_FUSION_DEQUANT_BN_RELU = "fusion_dequant_bn_relu";
const char *G_OP_TYPE_FUSION_DEQUANT_ADD_BN_RELU = "fusion_dequant_add_bn_relu";
const char *G_OP_TYPE_TANH = "tanh";
......@@ -136,6 +138,8 @@ std::unordered_map<
{G_OP_TYPE_ELEMENTWISE_MUL, {{"X", "Y"}, {"Out"}}},
{G_OP_TYPE_QUANTIZE, {{"X"}, {"Out", "OutScale"}}},
{G_OP_TYPE_DEQUANTIZE, {{"X", "Scale"}, {"Out"}}},
{G_OP_TYPE_FUSION_DEQUANT_ADD_BN, {{"X", "Scale"}, {"Y"}}},
{G_OP_TYPE_FUSION_DEQUANT_BN_RELU, {{"X", "Scale"}, {"Out"}}},
{G_OP_TYPE_FUSION_DEQUANT_ADD_BN_RELU, {{"X", "Scale"}, {"Out"}}},
{G_OP_TYPE_TANH, {{"X"}, {"Out"}}},
{G_OP_TYPE_FUSION_DECONV_RELU, {{"Input"}, {"Out"}}},
......
......@@ -138,6 +138,8 @@ extern const char *G_OP_TYPE_ELEMENTWISE_MUL;
extern const char *G_OP_TYPE_QUANTIZE;
extern const char *G_OP_TYPE_DEQUANTIZE;
extern const char *G_OP_TYPE_FUSION_DEQUANT_ADD_BN;
extern const char *G_OP_TYPE_FUSION_DEQUANT_BN_RELU;
extern const char *G_OP_TYPE_FUSION_DEQUANT_ADD_BN_RELU;
extern const char *G_OP_TYPE_TANH;
......
......@@ -233,6 +233,14 @@ LOAD_OP1(quantize, CPU);
#ifdef DEQUANT_OP
LOAD_OP1(dequantize, CPU);
#endif
#ifdef FUSION_DEQUANT_ADD_BN_OP
LOAD_OP1(fusion_dequant_add_bn, CPU);
LOAD_FUSION_MATCHER(fusion_dequant_add_bn);
#endif
#ifdef FUSION_DEQUANT_BN_RELU_OP
LOAD_OP1(fusion_dequant_bn_relu, CPU);
LOAD_FUSION_MATCHER(fusion_dequant_bn_relu);
#endif
#ifdef FUSION_DEQUANT_ADD_BN_RELU_OP
LOAD_OP1(fusion_dequant_add_bn_relu, CPU);
LOAD_FUSION_MATCHER(fusion_dequant_add_bn_relu);
......
......@@ -18,7 +18,7 @@ limitations under the License. */
#include <string>
#include "framework/operator.h"
#include "operators/kernel/depthwise_conv_kernel.h"
#include "operators/kernel/conv_kernel.h"
namespace paddle_mobile {
namespace operators {
......@@ -26,19 +26,16 @@ namespace operators {
template <typename DeviceType, typename T>
class DepthwiseConvOp : public framework::OperatorWithKernel<
DeviceType, ConvParam<DeviceType>,
operators::DepthwiseConvKernel<DeviceType, T>> {
operators::ConvKernel<DeviceType, T>> {
public:
DepthwiseConvOp(const std::string &type, const VariableNameMap &inputs,
const VariableNameMap &outputs,
const framework::AttributeMap &attrs,
std::shared_ptr<framework::Scope> scope)
: framework::OperatorWithKernel<
DeviceType, ConvParam<DeviceType>,
operators::DepthwiseConvKernel<DeviceType, T>>(
: framework::OperatorWithKernel<DeviceType, ConvParam<DeviceType>,
operators::ConvKernel<DeviceType, T>>(
type, inputs, outputs, attrs, scope) {}
void InferShape() const override;
private:
};
} // namespace operators
......
......@@ -12,27 +12,27 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifdef DEPTHWISECONV_OP
#ifdef FUSION_DEQUANT_ADD_BN_OP
#include "operators/kernel/depthwise_conv_kernel.h"
#include "operators/kernel/central-arm-func/depthwise_conv_arm_func.h"
#include "operators/fusion_dequant_add_bn_op.h"
namespace paddle_mobile {
namespace operators {
template <>
bool DepthwiseConvKernel<CPU, float>::Init(ConvParam<CPU> *param) {
return true;
template <typename Dtype, typename T>
void FusionDequantAddBNOp<Dtype, T>::InferShape() const {
const auto& input_dims = this->param_.input_->dims();
this->param_.output_->Resize(input_dims);
}
template <>
void DepthwiseConvKernel<CPU, float>::Compute(const ConvParam<CPU> &param) {
DepthwiseConvCompute<float>(param);
}
template class DepthwiseConvKernel<CPU, float>;
} // namespace operators
} // namespace paddle_mobile
namespace ops = paddle_mobile::operators;
REGISTER_FUSION_MATCHER(fusion_dequant_add_bn, ops::FusionDequantAddBNMatcher);
#ifdef PADDLE_MOBILE_CPU
REGISTER_OPERATOR_CPU(fusion_dequant_add_bn, ops::FusionDequantAddBNOp);
#endif
#endif
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifdef FUSION_DEQUANT_ADD_BN_OP
#pragma once
#include <string>
#include <vector>
#include "framework/operator.h"
#include "framework/program/program-optimize/fusion_op_register.h"
#include "operators/kernel/dequant_add_bn_kernel.h"
#include "operators/op_param.h"
namespace paddle_mobile {
namespace operators {
class FusionDequantAddBNMatcher : public framework::FusionOpMatcher {
public:
FusionDequantAddBNMatcher() {
node_ = framework::Node(G_OP_TYPE_DEQUANTIZE);
node_ > std::make_shared<framework::Node>(G_OP_TYPE_ELEMENTWISE_ADD) >
std::make_shared<framework::Node>(G_OP_TYPE_BATCHNORM);
}
void FolderNodes(
framework::Node *node,
std::vector<std::shared_ptr<framework::Node>> *removed_nodes) {
node->Folder(node_.Depth(), Type(),
{{G_OP_TYPE_ELEMENTWISE_ADD, {{"Y", "Y"}}},
{G_OP_TYPE_BATCHNORM,
{{"Scale", "BNScale"},
{"Mean", "BNMean"},
{"Bias", "BNBias"},
{"Variance", "BNVariance"}}}},
removed_nodes);
}
std::string Type() { return G_OP_TYPE_FUSION_DEQUANT_ADD_BN; }
};
template <typename DeviceType, typename T>
class FusionDequantAddBNOp
: public framework::OperatorWithKernel<
DeviceType, FusionDequantAddBNParam<DeviceType>,
operators::FusionDequantAddBNKernel<DeviceType, T>> {
public:
FusionDequantAddBNOp(const std::string &type, const VariableNameMap &inputs,
const VariableNameMap &outputs,
const framework::AttributeMap &attrs,
std::shared_ptr<framework::Scope> scope)
: framework::OperatorWithKernel<
DeviceType, FusionDequantAddBNParam<DeviceType>,
operators::FusionDequantAddBNKernel<DeviceType, T>>(
type, inputs, outputs, attrs, scope) {}
// inference output shape
void InferShape() const override;
};
} // namespace operators
} // namespace paddle_mobile
#endif
......@@ -20,7 +20,7 @@ limitations under the License. */
#include <vector>
#include "framework/operator.h"
#include "framework/program/program-optimize/fusion_op_register.h"
#include "operators/kernel/dequant_add_bn_relu_kernel.h"
#include "operators/kernel/dequant_bn_relu_kernel.h"
#include "operators/op_param.h"
namespace paddle_mobile {
......
......@@ -12,29 +12,28 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifdef DEPTHWISECONV_OP
#ifdef FUSION_DEQUANT_BN_RELU_OP
#pragma once
#include "framework/operator.h"
#include "operators/math/im2col.h"
#include "operators/math/math_function.h"
#include "operators/math/vol2col.h"
#include "operators/op_param.h"
#include "operators/fusion_dequant_bn_relu_op.h"
namespace paddle_mobile {
namespace operators {
using framework::OpKernelBase;
template <typename Dtype, typename T>
void FusionDequantBNReluOp<Dtype, T>::InferShape() const {
const auto& input_dims = this->param_.input_->dims();
this->param_.output_->Resize(input_dims);
}
template <typename DeviceType, typename T>
class DepthwiseConvKernel
: public OpKernelBase<DeviceType, ConvParam<DeviceType>> {
public:
void Compute(const ConvParam<DeviceType> &param);
bool Init(ConvParam<DeviceType> *param);
};
} // namespace operators
} // namespace paddle_mobile
namespace ops = paddle_mobile::operators;
REGISTER_FUSION_MATCHER(fusion_dequant_bn_relu,
ops::FusionDequantBNReluMatcher);
#ifdef PADDLE_MOBILE_CPU
REGISTER_OPERATOR_CPU(fusion_dequant_bn_relu, ops::FusionDequantBNReluOp);
#endif
#endif
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifdef FUSION_DEQUANT_BN_RELU_OP
#pragma once
#include <string>
#include <vector>
#include "framework/operator.h"
#include "framework/program/program-optimize/fusion_op_register.h"
#include "operators/kernel/dequant_bn_relu_kernel.h"
#include "operators/op_param.h"
namespace paddle_mobile {
namespace operators {
class FusionDequantBNReluMatcher : public framework::FusionOpMatcher {
public:
FusionDequantBNReluMatcher() {
node_ = framework::Node(G_OP_TYPE_DEQUANTIZE);
node_ > std::make_shared<framework::Node>(G_OP_TYPE_BATCHNORM) >
std::make_shared<framework::Node>(G_OP_TYPE_RELU);
}
void FolderNodes(
framework::Node *node,
std::vector<std::shared_ptr<framework::Node>> *removed_nodes) {
node->Folder(node_.Depth(), Type(),
{{G_OP_TYPE_BATCHNORM,
{{"Scale", "BNScale"},
{"Mean", "BNMean"},
{"Bias", "BNBias"},
{"Variance", "BNVariance"}}}},
removed_nodes);
}
std::string Type() { return G_OP_TYPE_FUSION_DEQUANT_BN_RELU; }
};
template <typename DeviceType, typename T>
class FusionDequantBNReluOp
: public framework::OperatorWithKernel<
DeviceType, FusionDequantBNReluParam<DeviceType>,
operators::FusionDequantBNReluKernel<DeviceType, T>> {
public:
FusionDequantBNReluOp(const std::string &type, const VariableNameMap &inputs,
const VariableNameMap &outputs,
const framework::AttributeMap &attrs,
std::shared_ptr<framework::Scope> scope)
: framework::OperatorWithKernel<
DeviceType, FusionDequantBNReluParam<DeviceType>,
operators::FusionDequantBNReluKernel<DeviceType, T>>(
type, inputs, outputs, attrs, scope) {}
// inference output shape
void InferShape() const override;
};
} // namespace operators
} // namespace paddle_mobile
#endif
......@@ -22,33 +22,35 @@ namespace operators {
template <>
bool ConvKernel<CPU, float>::Init(ConvParam<CPU> *param) {
bool conv3x3 = param->Filter()->dims()[2] == param->Filter()->dims()[3] &&
param->Filter()->dims()[2] == 3;
bool depth3x3 = conv3x3 && param->Groups() == param->Input()->dims()[1] &&
param->Input()->dims()[1] == param->Output()->dims()[1];
if (param->Filter()->type() == typeid(int8_t)) {
if (param->Groups() == param->Input()->dims()[1] &&
param->Input()->dims()[1] == param->Output()->dims()[1] &&
param->Filter()->dims()[2] == param->Filter()->dims()[3] &&
param->Filter()->dims()[2] == 3 && param->Strides()[0] < 3 &&
if (depth3x3 && param->Strides()[0] < 3 &&
param->Strides()[0] == param->Strides()[1]) {
param->ExecMode() = ConvParam<CPU>::EXEC_DEPTHWISE3x3_INT8;
} else {
param->ExecMode() = ConvParam<CPU>::EXEC_GEMM_INT8;
}
} else {
if (param->Groups() == param->Input()->dims()[1] &&
param->Input()->dims()[1] == param->Output()->dims()[1] &&
param->Filter()->dims()[2] == param->Filter()->dims()[3] &&
param->Filter()->dims()[2] == 3 && param->Strides()[0] == 1) {
if (depth3x3 && param->Strides()[0] == param->Strides()[1] &&
param->Strides()[0] == 1 && param->Paddings()[0] == 1 &&
param->Paddings()[0] == param->Paddings()[1]) {
param->ExecMode() = ConvParam<CPU>::EXEC_DEPTHWISE3x3S1P1_FLOAT;
} else if (param->Groups() == param->Input()->dims()[1] &&
param->Input()->dims()[1] == param->Output()->dims()[1] &&
param->Filter()->dims()[2] == param->Filter()->dims()[3] &&
param->Filter()->dims()[2] == 3) {
param->ExecMode() = ConvParam<CPU>::EXEC_DEPTHWISE3x3_FLOAT;
} else if (depth3x3 && param->Strides()[0] == param->Strides()[1] &&
param->Strides()[0] == 2 && param->Paddings()[0] == 0 &&
param->Paddings()[0] == param->Paddings()[1]) {
param->ExecMode() = ConvParam<CPU>::EXEC_DEPTHWISE3x3S2P0_FLOAT;
} else if (depth3x3 && param->Strides()[0] == param->Strides()[1] &&
param->Strides()[0] == 2 && param->Paddings()[0] == 1 &&
param->Paddings()[0] == param->Paddings()[1]) {
param->ExecMode() = ConvParam<CPU>::EXEC_DEPTHWISE3x3S2P1_FLOAT;
#ifndef __aarch64__
} else if (param->Filter()->dims()[2] == param->Filter()->dims()[3] &&
param->Strides()[0] == param->Strides()[1] &&
} else if (conv3x3 && param->Strides()[0] == param->Strides()[1] &&
param->Dilations()[0] == param->Dilations()[1] &&
param->Filter()->dims()[2] == 3 && param->Strides()[0] == 1 &&
param->Dilations()[0] == 1 && param->Output()->dims()[1] >= 16 &&
param->Strides()[0] == 1 && param->Dilations()[0] == 1 &&
param->Output()->dims()[1] >= 16 &&
param->Input()->dims()[1] >= 16 &&
param->Input()->dims()[2] <= 140 /* refered from ncnn */) {
param->ExecMode() = ConvParam<CPU>::EXEC_WINOGRAD3X3_FLOAT;
......@@ -78,9 +80,13 @@ void ConvKernel<CPU, float>::Compute(const ConvParam<CPU> &param) {
math::DepthwiseConv3x3s1p1(param.Input(), param.Filter(), param.Output(),
nullptr, false);
break;
case ConvParam<CPU>::EXEC_DEPTHWISE3x3_FLOAT:
math::DepthwiseConv3x3(param.Input(), param.Strides(), param.Paddings(),
param.Filter(), nullptr, param.Output(), false);
case ConvParam<CPU>::EXEC_DEPTHWISE3x3S2P1_FLOAT:
math::DepthwiseConv3x3s2p1v2(param.Input(), param.Filter(),
param.Output(), nullptr, false);
break;
case ConvParam<CPU>::EXEC_DEPTHWISE3x3S2P0_FLOAT:
math::DepthwiseConv3x3s2p0(param.Input(), param.Filter(), param.Output(),
nullptr, false);
break;
case ConvParam<CPU>::EXEC_WINOGRAD3X3_FLOAT:
WinogradConv3x3<8, 3>(param);
......
......@@ -12,9 +12,9 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifdef FUSION_DEQUANT_ADD_BN_RELU_OP
#ifdef FUSION_DEQUANT_ADD_BN_OP
#include "operators/kernel/dequant_add_bn_relu_kernel.h"
#include "operators/kernel/dequant_add_bn_kernel.h"
#include <cmath>
#if defined(__ARM_NEON__) || defined(__ARM_NEON)
#include <arm_neon.h>
......@@ -24,8 +24,8 @@ namespace paddle_mobile {
namespace operators {
template <>
bool FusionDequantAddBNReluKernel<CPU, float>::Init(
FusionDequantAddBNReluParam<CPU> *param) {
bool FusionDequantAddBNKernel<CPU, float>::Init(
FusionDequantAddBNParam<CPU> *param) {
// elementwise add params
const Tensor *bias = param->bias_;
// batch norm params
......@@ -49,8 +49,8 @@ bool FusionDequantAddBNReluKernel<CPU, float>::Init(
}
template <>
void FusionDequantAddBNReluKernel<CPU, float>::Compute(
const FusionDequantAddBNReluParam<CPU> &param) {
void FusionDequantAddBNKernel<CPU, float>::Compute(
const FusionDequantAddBNParam<CPU> &param) {
const int32_t *input = param.input_->data<int32_t>();
const float *bn_scale = param.bn_scale_->data<float>();
const float *bn_bias = param.bn_bias_->data<float>();
......@@ -78,7 +78,6 @@ void FusionDequantAddBNReluKernel<CPU, float>::Compute(
remain = spatial_size & 0xF;
float32x4_t __scale = vdupq_n_f32(scale);
float32x4_t __bias = vdupq_n_f32(bias);
float32x4_t __zero = vdupq_n_f32(0.f);
for (int k = 0; k < loop; ++k, x += 16, y += 16) {
int32x4_t r0 = vld1q_s32(x);
......@@ -93,10 +92,6 @@ void FusionDequantAddBNReluKernel<CPU, float>::Compute(
f1 = vmlaq_f32(__bias, __scale, f1);
f2 = vmlaq_f32(__bias, __scale, f2);
f3 = vmlaq_f32(__bias, __scale, f3);
f0 = vmaxq_f32(__zero, f0);
f1 = vmaxq_f32(__zero, f1);
f2 = vmaxq_f32(__zero, f2);
f3 = vmaxq_f32(__zero, f3);
vst1q_f32(y, f0);
vst1q_f32(y + 4, f1);
vst1q_f32(y + 8, f2);
......@@ -104,7 +99,7 @@ void FusionDequantAddBNReluKernel<CPU, float>::Compute(
}
#endif // __ARM_NEON__
for (int k = 0; k < remain; ++k) {
y[k] = std::max(scale * x[k] + bias, 0.f);
y[k] = scale * x[k] + bias;
}
}
}
......@@ -113,4 +108,4 @@ void FusionDequantAddBNReluKernel<CPU, float>::Compute(
} // namespace operators
} // namespace paddle_mobile
#endif // FUSION_DEQUANT_ADD_BN_RELU_OP
#endif // FUSION_DEQUANT_ADD_BN_OP
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "operators/kernel/dequant_bn_relu_kernel.h"
#include <cmath>
#if defined(__ARM_NEON__) || defined(__ARM_NEON)
#include <arm_neon.h>
#endif
namespace paddle_mobile {
namespace operators {
#if defined(FUSION_DEQUANT_BN_RELU_OP) || defined(FUSION_DEQUANT_ADD_BN_RELU_OP)
void DequantBNReluCompute(const FusionDequantBNParam<CPU> *param) {
const int32_t *input = param->input_->data<int32_t>();
const float *bn_scale = param->bn_scale_->data<float>();
const float *bn_bias = param->bn_bias_->data<float>();
// dequantize params
const float activation_scale = param->activation_scale_->data<float>()[0];
const float weight_scale = param->weight_scale_;
const float dequant_scale = activation_scale / weight_scale;
float *output = param->output_->mutable_data<float>();
int batch_size = param->input_->dims()[0];
int channels = param->input_->dims()[1];
size_t spatial_size = param->input_->dims()[2] * param->input_->dims()[3];
#pragma omp parallel for collapse(2)
for (int batch = 0; batch < batch_size; ++batch) {
for (int c = 0; c < channels; ++c) {
float scale = bn_scale[c] * dequant_scale;
float bias = bn_bias[c];
size_t offset = (batch * channels + c) * spatial_size;
const int32_t *x = input + offset;
float *y = output + offset;
size_t remain = spatial_size;
#if defined(__ARM_NEON__) || defined(__ARM_NEON)
int loop = spatial_size >> 4;
remain = spatial_size & 0xF;
float32x4_t __scale = vdupq_n_f32(scale);
float32x4_t __bias = vdupq_n_f32(bias);
float32x4_t __zero = vdupq_n_f32(0.f);
for (int k = 0; k < loop; ++k, x += 16, y += 16) {
int32x4_t r0 = vld1q_s32(x);
int32x4_t r1 = vld1q_s32(x + 4);
int32x4_t r2 = vld1q_s32(x + 8);
int32x4_t r3 = vld1q_s32(x + 12);
float32x4_t f0 = vcvtq_f32_s32(r0);
float32x4_t f1 = vcvtq_f32_s32(r1);
float32x4_t f2 = vcvtq_f32_s32(r2);
float32x4_t f3 = vcvtq_f32_s32(r3);
f0 = vmlaq_f32(__bias, __scale, f0);
f1 = vmlaq_f32(__bias, __scale, f1);
f2 = vmlaq_f32(__bias, __scale, f2);
f3 = vmlaq_f32(__bias, __scale, f3);
f0 = vmaxq_f32(__zero, f0);
f1 = vmaxq_f32(__zero, f1);
f2 = vmaxq_f32(__zero, f2);
f3 = vmaxq_f32(__zero, f3);
vst1q_f32(y, f0);
vst1q_f32(y + 4, f1);
vst1q_f32(y + 8, f2);
vst1q_f32(y + 12, f3);
}
#endif // __ARM_NEON__
for (int k = 0; k < remain; ++k) {
y[k] = std::max(scale * x[k] + bias, 0.f);
}
}
}
}
#endif
#ifdef FUSION_DEQUANT_BN_RELU_OP
template <>
bool FusionDequantBNReluKernel<CPU, float>::Init(
FusionDequantBNReluParam<CPU> *param) {
// batch norm params
const Tensor *bn_mean = param->bn_mean_;
const Tensor *bn_variance = param->bn_variance_;
Tensor *bn_scale = param->bn_scale_;
Tensor *bn_bias = param->bn_bias_;
const float epsilon = param->epsilon_;
const float *mean_ptr = bn_mean->data<float>();
const float *var_ptr = bn_variance->data<float>();
float *bn_scale_ptr = bn_scale->mutable_data<float>();
float *bn_bias_ptr = bn_bias->mutable_data<float>();
for (int c = 0; c < bn_scale->numel(); ++c) {
float inv_scale = bn_scale_ptr[c] / (std::sqrt(var_ptr[c] + epsilon));
bn_scale_ptr[c] = inv_scale;
bn_bias_ptr[c] = bn_bias_ptr[c] - inv_scale * mean_ptr[c];
}
return true;
}
template <>
void FusionDequantBNReluKernel<CPU, float>::Compute(
const FusionDequantBNReluParam<CPU> &param) {
DequantBNReluCompute(&param);
}
#endif // FUSION_DEQUANT_BN_RELU_OP
#ifdef FUSION_DEQUANT_ADD_BN_RELU_OP
template <>
bool FusionDequantAddBNReluKernel<CPU, float>::Init(
FusionDequantAddBNReluParam<CPU> *param) {
// elementwise add params
const Tensor *bias = param->bias_;
// batch norm params
const Tensor *bn_mean = param->bn_mean_;
const Tensor *bn_variance = param->bn_variance_;
Tensor *bn_scale = param->bn_scale_;
Tensor *bn_bias = param->bn_bias_;
const float epsilon = param->epsilon_;
const float *bias_ptr = bias->data<float>();
const float *mean_ptr = bn_mean->data<float>();
const float *var_ptr = bn_variance->data<float>();
float *bn_scale_ptr = bn_scale->mutable_data<float>();
float *bn_bias_ptr = bn_bias->mutable_data<float>();
for (int c = 0; c < bn_scale->numel(); ++c) {
float inv_scale = bn_scale_ptr[c] / (std::sqrt(var_ptr[c] + epsilon));
bn_scale_ptr[c] = inv_scale;
bn_bias_ptr[c] = inv_scale * (bias_ptr[c] - mean_ptr[c]) + bn_bias_ptr[c];
}
return true;
}
template <>
void FusionDequantAddBNReluKernel<CPU, float>::Compute(
const FusionDequantAddBNReluParam<CPU> &param) {
DequantBNReluCompute(&param);
}
#endif // FUSION_DEQUANT_ADD_BN_RELU_OP
} // namespace operators
} // namespace paddle_mobile
......@@ -132,10 +132,10 @@ void ConvAddCompute(const FusionConvAddParam<CPU> &param) {
// param.Output(), false);
if (param.Paddings()[0] == 0) {
math::DepthwiseConv3x3s2p0(param.Input(), param.Filter(), param.Output(),
*param.Bias(), true);
param.Bias(), true);
} else {
math::DepthwiseConv3x3s2p1v2(param.Input(), param.Filter(),
param.Output(), *param.Bias(), true);
param.Output(), param.Bias(), true);
}
} else {
ConvAddBasic(param);
......
......@@ -107,9 +107,15 @@ inline void GemmConv(const ConvParam<CPU> &param) {
Tensor out_slice = out_batch.Slice(g * out_step, (g + 1) * out_step);
Tensor filter_slice = filter.Slice(g * out_step, (g + 1) * out_step);
math::matmul<Itype>(filter_slice, false, col_matrix, false,
if (param.Input()->type() == typeid(int8_t)) {
math::matmul_int8(filter_slice, false, col_matrix, false,
static_cast<float>(1), &out_slice,
static_cast<float>(0));
} else {
math::matmul<float>(filter_slice, false, col_matrix, false,
static_cast<float>(1), &out_slice,
static_cast<float>(0));
}
}
}
}
......
......@@ -73,8 +73,8 @@ void MulCompute(const MulParam<CPU> &param) {
}
if (param.InputX()->type() == typeid(int8_t)) {
out->mutable_data<int32_t>();
math::matmul<int8_t>(x_matrix, false, y_matrix, false,
static_cast<int8_t>(1), out, static_cast<int8_t>(0));
math::matmul_int8(x_matrix, false, y_matrix, false, static_cast<float>(1),
out, static_cast<float>(0));
} else {
out->mutable_data<float>();
......
......@@ -14,7 +14,7 @@ limitations under the License. */
#pragma once
#ifdef FUSION_DEQUANT_ADD_BN_RELU_OP
#ifdef FUSION_DEQUANT_ADD_BN_OP
#include "framework/operator.h"
#include "operators/op_param.h"
......@@ -23,12 +23,12 @@ namespace paddle_mobile {
namespace operators {
template <typename DeviceType, typename T>
class FusionDequantAddBNReluKernel
class FusionDequantAddBNKernel
: public framework::OpKernelBase<DeviceType,
FusionDequantAddBNReluParam<DeviceType>> {
FusionDequantAddBNParam<DeviceType>> {
public:
void Compute(const FusionDequantAddBNReluParam<DeviceType> &param);
bool Init(FusionDequantAddBNReluParam<DeviceType> *param);
void Compute(const FusionDequantAddBNParam<DeviceType> &param);
bool Init(FusionDequantAddBNParam<DeviceType> *param);
};
} // namespace operators
......
......@@ -12,42 +12,35 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifdef DEPTHWISECONV_OP
#pragma once
#include <vector>
#include "operators/kernel/central-arm-func/conv_arm_func.h"
#include "operators/math/depthwise_conv3x3.h"
#include "framework/operator.h"
#include "operators/op_param.h"
namespace paddle_mobile {
namespace operators {
template <typename P>
void DepthwiseConvCompute(const ConvParam<CPU> &param) {
Tensor Bias;
Bias.mutable_data<float>({param.Groups()});
if (param.Groups() == param.Input()->dims()[1] &&
param.Filter()->dims()[2] == param.Filter()->dims()[3] &&
param.Filter()->dims()[2] == 3 && param.Strides()[0] == 1) {
math::DepthwiseConv3x3s1p1(param.Input(), param.Filter(), param.Output(),
&Bias, false);
} else if (param.Groups() == param.Input()->dims()[1] &&
param.Input()->dims()[1] == param.Output()->dims()[1] &&
param.Filter()->dims()[2] == param.Filter()->dims()[3] &&
param.Filter()->dims()[2] == 3 && param.Strides()[0] == 2) {
// math::DepthwiseConv3x3(param.Input(), param.Strides(),
// param.Paddings(),
// param.Filter(), &Bias, param.Output(), false);
math::DepthwiseConv3x3s2p1v2(param.Input(), param.Filter(), param.Output(),
Bias, false);
} else {
GemmConv<float, float>(param);
}
}
#ifdef FUSION_DEQUANT_BN_RELU_OP
template <typename DeviceType, typename T>
class FusionDequantBNReluKernel
: public framework::OpKernelBase<DeviceType,
FusionDequantBNReluParam<DeviceType>> {
public:
void Compute(const FusionDequantBNReluParam<DeviceType> &param);
bool Init(FusionDequantBNReluParam<DeviceType> *param);
};
#endif
#ifdef FUSION_DEQUANT_ADD_BN_RELU_OP
template <typename DeviceType, typename T>
class FusionDequantAddBNReluKernel
: public framework::OpKernelBase<DeviceType,
FusionDequantAddBNReluParam<DeviceType>> {
public:
void Compute(const FusionDequantAddBNReluParam<DeviceType> &param);
bool Init(FusionDequantAddBNReluParam<DeviceType> *param);
};
#endif
} // namespace operators
} // namespace paddle_mobile
#endif
......@@ -1272,13 +1272,13 @@ void DepthwiseConvAddBNRelu3x3s2p1(const framework::Tensor *input,
void DepthwiseConv3x3s2p1v2(const framework::Tensor *input,
const framework::Tensor *filter,
framework::Tensor *output, framework::Tensor bias,
framework::Tensor *output, framework::Tensor *bias,
bool if_bias) {
#if __ARM_NEON
const float *input_data = input->data<float>();
const float *filter_data = filter->data<float>();
float *output_data = output->data<float>();
const float *bias_data = bias.data<float>();
const float *bias_data = bias->data<float>();
const int in_h = static_cast<int>(input->dims()[2]);
const int in_w = static_cast<int>(input->dims()[3]);
......@@ -1905,7 +1905,7 @@ void DepthwiseConvAddBNRelu3x3s2p1v2(const framework::Tensor *input,
void DepthwiseConv3x3s2p0(const framework::Tensor *input,
const framework::Tensor *filter,
framework::Tensor *output, framework::Tensor bias,
framework::Tensor *output, framework::Tensor *bias,
bool if_bias) {
#if __ARM_NEON
......@@ -1925,7 +1925,7 @@ void DepthwiseConv3x3s2p0(const framework::Tensor *input,
for (int c = 0; c < input_channel; c++) {
const float *filter_data = filter->data<float>() + c * 9;
const float *input_data = input->data<float>() + c * inhxw;
const float *bias_data = bias.data<float>() + c;
const float *bias_data = bias->data<float>() + c;
float *output_data = output->data<float>() + c * outhxw;
float w00 = filter_data[0];
float w01 = filter_data[1];
......
......@@ -50,7 +50,7 @@ void DepthwiseConvAddBNRelu3x3s2p1(const framework::Tensor *input,
void DepthwiseConv3x3s2p1v2(const framework::Tensor *input,
const framework::Tensor *filter,
framework::Tensor *output, framework::Tensor bias,
framework::Tensor *output, framework::Tensor *bias,
bool if_bias);
void DepthwiseConvAddBNRelu3x3s2p1v2(const framework::Tensor *input,
......@@ -62,7 +62,7 @@ void DepthwiseConvAddBNRelu3x3s2p1v2(const framework::Tensor *input,
void DepthwiseConv3x3s2p0(const framework::Tensor *input,
const framework::Tensor *filter,
framework::Tensor *output, framework::Tensor bias,
framework::Tensor *output, framework::Tensor *bias,
bool if_bias);
// TODO(hjchen2) need to be implemented
......
......@@ -23,10 +23,12 @@ limitations under the License. */
#if __aarch64__
#define MR_INT8 4
#define NR_INT8 2
#define MR 6
#define NR 16
#else
#define MR_INT8 4
#define NR_INT8 2
#define MR 6
#define NR 8
#endif
......@@ -193,52 +195,58 @@ void PackMatrixB(int k, int n, int n_tail, const float *B, int ldb,
// 8 bits int small block inner product
void AddDot4x8(int32_t k, const int8_t *a, const int8_t *b, int32_t *c,
int32_t ldc);
void AddDot4x2(int32_t k, const int8_t *a, const int8_t *b, int32_t *c,
int32_t ldc);
void AddDot6x8(int32_t k, const int8_t *a, const int8_t *b, int32_t *c,
int32_t ldc);
// 8 bits int inner product
void InnerKernelWithBias(int32_t mc, int32_t nc, int8_t alpha,
const int8_t *a, const int8_t *b, int8_t beta,
int32_t *c, int32_t *C, int32_t ldc, bool relu,
int8_t *bias);
void InnerKernel(int32_t mc, int32_t nc, float alpha, const int8_t *a,
const int8_t *b, float beta, int32_t *c, int32_t *C,
int32_t ldc, bool relu);
void InnerKernelWithBias(int32_t mc, int32_t nc, float alpha, const int8_t *a,
const int8_t *b, float beta, int32_t *c, int8_t *C,
int32_t ldc, bool relu, int32_t *bias);
// 8 bits int pack function
void PackMatrixA_4r(int32_t m, int32_t k, int32_t m_tail, const int8_t *A,
int32_t lda, int8_t *buffer);
void PackMatrixA_4r_16(int32_t m, int32_t k, int32_t m_tail, const int8_t *A,
int32_t lda, int8_t *buffer);
void PackMatrixA_6r(int32_t m, int32_t k, int32_t m_tail, const int8_t *A,
int32_t lda, int8_t *buffer);
void PackMatrixB_2c_16(int32_t k, int32_t n, int32_t n_tail, const int8_t *B,
int32_t ldb, int8_t *buffer);
void PackMatrixB_8c(int32_t k, int32_t n, int32_t n_tail, const int8_t *B,
int32_t ldb, int8_t *buffer);
void PackMatrixA_omp_4r(int32_t m, int32_t k, int32_t m_tail, const int8_t *A,
int32_t lda, int8_t *buffer);
void PackMatrixB_omp_8c(int32_t k, int32_t n, int32_t n_tail, const int8_t *B,
int32_t ldb, int8_t *buffer);
void PackMatrixA_omp_4r_16(int32_t m, int32_t k, int32_t m_tail,
const int8_t *A, int32_t lda, int8_t *buffer);
void PackMatrixB_omp_2c_16(int32_t k, int32_t n, int32_t n_tail,
const int8_t *B, int32_t ldb, int8_t *buffer);
// 8 bits int matrix product
void Sgemm(int32_t m, int32_t n, int32_t k, int8_t alpha, const int8_t *A,
int32_t lda, const int8_t *B, int32_t ldb, int8_t beta, int32_t *C,
int32_t ldc, bool relu, int8_t *bias);
void Sgemm_omp(int32_t m, int32_t n, int32_t k, int8_t alpha, const int8_t *A,
int32_t lda, const int8_t *B, int32_t ldb, int8_t beta,
int32_t *C, int32_t ldc, bool relu, int8_t *bias);
void Sgemm(int32_t m, int32_t n, int32_t k, float alpha, const int8_t *A,
int32_t lda, const int8_t *B, int32_t ldb, float beta, int32_t *C,
int32_t ldc, bool relu, int32_t *bias);
void Sgemm(int32_t m, int32_t n, int32_t k, float alpha, const int8_t *A,
int32_t lda, const int8_t *B, int32_t ldb, float beta, int8_t *C,
int32_t ldc, bool relu, int32_t *bias);
void Sgemm_omp(int32_t m, int32_t n, int32_t k, float alpha, const int8_t *A,
int32_t lda, const int8_t *B, int32_t ldb, float beta,
int32_t *C, int32_t ldc, bool relu, int32_t *bias);
// 8 bits int write back
// C = alpha * A * B + beta * C
void WriteWithAlphaBeta(int32_t mc, int32_t nc, int32_t *c, int32_t *C,
int32_t ldc);
// C = A * B
void WriteBasic(int32_t mc, int32_t nc, int32_t *c, int32_t *C, int32_t ldc);
// C = A * B + C
void WriteWithAdd(int32_t mc, int32_t nc, int32_t *c, int32_t *C,
int32_t ldc);
// C = A * B + bias
void WriteWithAddV1(int32_t mc, int32_t nc, int32_t *c, int32_t *C,
int32_t ldc, int8_t *bias);
// C = A * B + C, relu(C)
void WriteWithAddRelu(int32_t mc, int32_t nc, int32_t *c, int32_t *C,
int32_t ldc);
// C = A * B + bias, relu(C)
void WriteWithAddReluV1(int32_t mc, int32_t nc, int32_t *c, int32_t *C,
int32_t ldc, int8_t *bias);
// C = A * B + bias, scale * relu(C)
void WriteWithAddReluScale(int32_t mc, int32_t nc, int32_t *c, int8_t *C,
int32_t ldc, int32_t *bias, float scale);
// C = A * B + bias, scale * C
void WriteWithAddScale(int32_t mc, int32_t nc, int32_t *c, int8_t *C,
int32_t ldc, int32_t *bias, float scale);
private:
int MC = 0;
......@@ -254,7 +262,7 @@ void PackMatrixB(int k, int n, int n_tail, const float *B, int ldb,
// 8 bits int
int8_t *packedA_int8;
int8_t *packedB_int8;
int32_t *packedC_int8;
int32_t *packedC_int32;
int8_t *zero_int8;
};
......
......@@ -18,6 +18,8 @@ limitations under the License. */
#include "operators/math/gemm.h"
#if __ARM_NEON
#include <arm_neon.h>
#include <iostream>
#endif
#ifdef _OPENMP
#include <omp.h>
......@@ -62,7 +64,7 @@ void Gemm::AddDot4x8(int32_t k, const int8_t *a, const int8_t *b, int32_t *c,
"pld [%[b_ptr], #128] \n\t"
"vld1.s8 {d0-d3}, [%[a_ptr]]! \n\t" // load A 8 cols
"vld1.s8 {d8-d11}, [%[b_ptr]]! \n\t" // load B first 4 rows
"vmovl.s8 q2, d0 \n\t" // process B first 4
"vmovl.s8 q2, d0 \n\t" // process B first
// rows
"vmovl.s8 q3, d8 \n\t"
"vmlal.s16 q8, d6, d4[0]\n\t"
......@@ -241,6 +243,132 @@ void Gemm::AddDot4x8(int32_t k, const int8_t *a, const int8_t *b, int32_t *c,
#endif // __ARM_NEON
}
void Gemm::AddDot4x2(int32_t k, const int8_t *a, const int8_t *b, int32_t *c,
int32_t ldc) {
#if __ARM_NEON
#if __aarch64__
// TODO(wzzju)
#else
#define PADDLE_LABEL_LOOP "1"
#define PADDLE_LABEL_AFTER_LOOP "2"
asm volatile(
"lsl %[ldc], %[ldc], #2 \n\t" // sizeof(int32) == 4
"vldr d0, [%[b], #0] \n\t"
"vmov.s32 q8, #0 \n\t"
"vldr d4, [%[a], #0] \n\t"
"vmov.s32 q9, q8 \n\t"
"vldr d2, [%[b], #16] \n\t"
"vmov.s32 q10, q8 \n\t"
"vldr d6, [%[a], #16] \n\t"
"vmov.s32 q11, q8 \n\t"
"vldr d1, [%[b], #8]\n\t"
"vmov.s32 q12, q8 \n\t"
"vldr d5, [%[a], #8]\n"
"vmov.s32 q13, q8 \n\t"
"vldr d3, [%[b], #24]\n\t"
"vmov.s32 q14, q8 \n\t"
"vldr d7, [%[a], #24]\n"
"vmov.s32 q15, q8 \n\t"
PADDLE_LABEL_LOOP
": \n\t"
"vmull.s8 q4, d0, d4 \n\t" // first half
"add %[b], %[b], #32 \n\t"
"vmull.s8 q5, d2, d4 \n\t"
"vldr d4, [%[a], #32] \n\t"
"vmull.s8 q6, d0, d6 \n\t"
"vmull.s8 q7, d2, d6 \n\t"
"vldr d6, [%[a], #48] \n\t"
"vmlal.s8 q4, d1, d5 \n\t" // second half
"vmlal.s8 q5, d3, d5 \n\t"
"vldr d5, [%[a], #40] \n\t"
"vmlal.s8 q6, d1, d7 \n\t"
"vmlal.s8 q7, d3, d7 \n\t"
"vldr d7, [%[a], #56] \n\t"
"vpadal.s16 q8, q4 \n\t" // pairwise-add
"add %[a], %[a], #64 \n\t"
"vpadal.s16 q9, q5 \n\t"
"subs %[k], %[k], #16 \n\t"
"vpadal.s16 q10, q6 \n\t"
"vpadal.s16 q11, q7 \n\t"
"beq " PADDLE_LABEL_AFTER_LOOP
"f \n\t"
"vmull.s8 q4, d0, d4 \n\t" // first half
"vmull.s8 q5, d2, d4 \n\t"
"vldr d4, [%[a], #0] \n\t"
"vmull.s8 q6, d0, d6 \n\t"
"vldr d0, [%[b], #0] \n\t"
"vmull.s8 q7, d2, d6 \n\t"
"vldr d2, [%[b], #16] \n\t"
"vmlal.s8 q4, d1, d5 \n\t" // second half
"vldr d6, [%[a], #16] \n\t"
"vmlal.s8 q5, d3, d5 \n\t"
"vldr d5, [%[a], #8] \n\t"
"vmlal.s8 q6, d1, d7 \n\t"
"vldr d1, [%[b], #8] \n\t"
"vmlal.s8 q7, d3, d7 \n\t"
"vldr d3, [%[b], #24] \n\t"
"vpadal.s16 q12, q4 \n\t" // pairwise-add
"vldr d7, [%[a], #24] \n\t"
"vpadal.s16 q13, q5 \n\t"
"vpadal.s16 q14, q6 \n\t"
"vpadal.s16 q15, q7 \n\t"
"b " PADDLE_LABEL_LOOP "b \n\t"
PADDLE_LABEL_AFTER_LOOP
": \n\t"
"vmull.s8 q4, d0, d4 \n\t" // first half
"vmull.s8 q5, d2, d4 \n\t"
"vmull.s8 q6, d0, d6 \n\t"
"vmull.s8 q7, d2, d6 \n\t"
"vmlal.s8 q4, d1, d5 \n\t" // second half
"vmlal.s8 q5, d3, d5 \n\t"
"vmlal.s8 q6, d1, d7 \n\t"
"vmlal.s8 q7, d3, d7 \n\t"
"vpadal.s16 q12, q4 \n\t" // pairwise-add
"vpadal.s16 q13, q5 \n\t"
"vpadal.s16 q14, q6 \n\t"
"vpadal.s16 q15, q7 \n\t"
"vpadd.s32 d0, d16, d17 \n\t" // reduce to int32
"vpadd.s32 d1, d18, d19 \n\t"
"vpadd.s32 d2, d20, d21 \n\t"
"vpadd.s32 d3, d22, d23 \n\t"
"vpadd.s32 d4, d24, d25 \n\t"
"vpadd.s32 d5, d26, d27 \n\t"
"vpadd.s32 d6, d28, d29 \n\t"
"vpadd.s32 d7, d30, d31 \n\t"
"vpadd.s32 d8, d0, d1 \n\t" // reduce to int32 again
"vpadd.s32 d9, d2, d3 \n\t"
"vpadd.s32 d10, d4, d5 \n\t"
"vpadd.s32 d11, d6, d7 \n\t"
"vst1.32 {d8}, [%[c]], %[ldc] \n\t"
"vst1.32 {d9}, [%[c]], %[ldc] \n\t"
"vst1.32 {d10}, [%[c]], %[ldc] \n\t"
"vst1.32 {d11}, [%[c]] \n\t"
: [k] "+r"(k), [a] "+r"(a), [b] "+r"(b), [c] "+r"(c)
: [ldc] "r"(ldc)
: "cc", "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", "q8",
"q9", "q10", "q11", "q12", "q13", "q14", "q15");
#undef PADDLE_LABEL_AFTER_LOOP
#undef PADDLE_LABEL_LOOP
#endif // __aarch64__
#endif // __ARM_NEON
}
// 8 bits int small block inner product
void Gemm::AddDot6x8(int32_t k, const int8_t *a, const int8_t *b, int32_t *c,
int32_t ldc) {
......@@ -539,51 +667,213 @@ void Gemm::AddDot6x8(int32_t k, const int8_t *a, const int8_t *b, int32_t *c,
}
// 8 bits int inner product
void Gemm::InnerKernelWithBias(int32_t mc, int32_t nc, int8_t alpha,
const int8_t *a, const int8_t *b, int8_t beta,
int32_t *c, int32_t *C, int32_t ldc, bool relu,
int8_t *bias) {
void Gemm::InnerKernel(int32_t mc, int32_t nc, float alpha, const int8_t *a,
const int8_t *b, float beta, int32_t *c, int32_t *C,
int32_t ldc, bool relu) {
#pragma omp parallel for
for (int32_t j = 0; j < nc; j += NR) {
for (int32_t j = 0; j < nc; j += NR_INT8) {
for (int32_t i = 0; i < mc; i += MR_INT8) {
#if __aarch64__
// TODO(wzzju)
#else
// AddDot6x8(KC, a + i * KC, b + j * KC, c + i * NC + j, NC);
AddDot4x8(KC, a + i * KC, b + j * KC, c + i * NC + j, NC);
// AddDot4x8(KC, a + i * KC, b + j * KC, c + i * NC + j, NC);
AddDot4x2(KC, a + i * KC, b + j * KC, c + i * NC + j, NC);
#endif // __aarch64__
}
}
if (alpha != 1) {
WriteWithAlphaBeta(mc, nc, c, C, ldc);
return;
}
if (beta == 0) {
if (!relu) {
WriteBasic(mc, nc, c, C, ldc);
return;
}
if (beta == 1 && !relu) {
if (bias == nullptr) {
WriteWithAdd(mc, nc, c, C, ldc);
} else {
WriteWithAddV1(mc, nc, c, C, ldc, bias);
}
void Gemm::InnerKernelWithBias(int32_t mc, int32_t nc, float alpha,
const int8_t *a, const int8_t *b, float beta,
int32_t *c, int8_t *C, int32_t ldc, bool relu,
int32_t *bias) {
#pragma omp parallel for
for (int32_t j = 0; j < nc; j += NR_INT8) {
for (int32_t i = 0; i < mc; i += MR_INT8) {
#if __aarch64__
// TODO(wzzju)
#else
// AddDot6x8(KC, a + i * KC, b + j * KC, c + i * NC + j, NC);
// AddDot4x8(KC, a + i * KC, b + j * KC, c + i * NC + j, NC);
AddDot4x2(KC, a + i * KC, b + j * KC, c + i * NC + j, NC);
#endif // __aarch64__
}
return;
}
if (beta == 1 && relu) {
if (bias == nullptr) {
WriteWithAddRelu(mc, nc, c, C, ldc);
if (relu) {
WriteWithAddReluScale(mc, nc, c, C, ldc, bias, alpha);
return;
} else {
WriteWithAddReluV1(mc, nc, c, C, ldc, bias);
WriteWithAddScale(mc, nc, c, C, ldc, bias, alpha);
}
}
// 8 bits int PackMatrixA_4r
void Gemm::PackMatrixA_4r_16(int32_t m, int32_t k, int32_t m_tail,
const int8_t *A, int32_t lda, int8_t *buffer) {
const int32_t i_length = m - m_tail;
const int32_t k_count = k >> 4;
const int32_t k_tail = k & 15;
for (int32_t i = 0; i < i_length; i += 4) {
const int8_t *a0 = A + i * lda;
const int8_t *a1 = A + (i + 1) * lda;
const int8_t *a2 = A + (i + 2) * lda;
const int8_t *a3 = A + (i + 3) * lda;
int8_t *local_buffer = buffer + i * KC;
for (int32_t j = 0; j < k_count; ++j) {
#if __ARM_NEON
#if __aarch64__
// TODO(wzzju)
#else
asm volatile(
"vld1.s8 {d0, d1}, [%[a0]]! \n\t"
"vld1.s8 {d2, d3}, [%[a1]]! \n\t"
"vld1.s8 {d4, d5}, [%[a2]]! \n\t"
"vld1.s8 {d6, d7}, [%[a3]]! \n\t"
"vst1.s8 {d0, d1}, [%[local_buffer]]! \n\t"
"vst1.s8 {d2, d3}, [%[local_buffer]]! \n\t"
"vst1.s8 {d4, d5}, [%[local_buffer]]! \n\t"
"vst1.s8 {d6, d7}, [%[local_buffer]]! \n\t"
: [local_buffer] "+r"(local_buffer), [a0] "+r"(a0), [a1] "+r"(a1),
[a2] "+r"(a2), [a3] "+r"(a3)
:
: "memory", "q0", "q1", "q2", "q3");
#endif // __aarch64__
#else
for (int32_t l = 0; l < 16; ++l) {
*local_buffer++ = *a0++;
}
for (int32_t l = 0; l < 16; ++l) {
*local_buffer++ = *a1++;
}
for (int32_t l = 0; l < 16; ++l) {
*local_buffer++ = *a2++;
}
for (int32_t l = 0; l < 16; ++l) {
*local_buffer++ = *a3++;
}
#endif // __ARM_NEON
}
if (k_tail != 0) {
for (int32_t j = k_count << 4; j < k; ++j) {
*local_buffer++ = *a0++;
}
for (int32_t j = k; j < KC; ++j) {
*local_buffer++ = 0;
}
for (int32_t j = k_count << 4; j < k; ++j) {
*local_buffer++ = *a1++;
}
for (int32_t j = k; j < KC; ++j) {
*local_buffer++ = 0;
}
for (int32_t j = k_count << 4; j < k; ++j) {
*local_buffer++ = *a2++;
}
for (int32_t j = k; j < KC; ++j) {
*local_buffer++ = 0;
}
for (int32_t j = k_count << 4; j < k; ++j) {
*local_buffer++ = *a3++;
}
for (int32_t j = k; j < KC; ++j) {
*local_buffer++ = 0;
}
}
}
if (m_tail != 0) {
const int8_t *a0 = &A(i_length, 0);
const int8_t *a1 = a0 + lda;
const int8_t *a2 = a0 + 2 * lda;
const int8_t *a3 = a0 + 3 * lda;
int8_t *local_buffer = buffer + i_length * KC;
switch (m_tail) {
case 1:
a1 = zero_int8;
case 2:
a2 = zero_int8;
case 3:
a3 = zero_int8;
break;
default:
break;
}
for (int32_t j = 0; j < k_count; ++j) {
#if __ARM_NEON
#if __aarch64__
// TODO(wzzju)
#else
asm volatile(
"vld1.s8 {d0, d1}, [%[a0]]! \n\t"
"vld1.s8 {d2, d3}, [%[a1]]! \n\t"
"vld1.s8 {d4, d5}, [%[a2]]! \n\t"
"vld1.s8 {d6, d7}, [%[a3]]! \n\t"
"vst1.s8 {d0, d1}, [%[local_buffer]]! \n\t"
"vst1.s8 {d2, d3}, [%[local_buffer]]! \n\t"
"vst1.s8 {d4, d5}, [%[local_buffer]]! \n\t"
"vst1.s8 {d6, d7}, [%[local_buffer]]! \n\t"
: [local_buffer] "+r"(local_buffer), [a0] "+r"(a0), [a1] "+r"(a1),
[a2] "+r"(a2), [a3] "+r"(a3)
:
: "memory", "q0", "q1", "q2", "q3");
#endif // __aarch64__
#else
for (int32_t l = 0; l < 16; ++l) {
*local_buffer++ = *a0++;
}
for (int32_t l = 0; l < 16; ++l) {
*local_buffer++ = *a1++;
}
for (int32_t l = 0; l < 16; ++l) {
*local_buffer++ = *a2++;
}
for (int32_t l = 0; l < 16; ++l) {
*local_buffer++ = *a3++;
}
#endif // __ARM_NEON
}
if (k_tail != 0) {
for (int32_t j = k_count << 4; j < k; ++j) {
*local_buffer++ = *a0++;
}
for (int32_t j = k; j < KC; ++j) {
*local_buffer++ = 0;
}
for (int32_t j = k_count << 4; j < k; ++j) {
*local_buffer++ = *a1++;
}
for (int32_t j = k; j < KC; ++j) {
*local_buffer++ = 0;
}
for (int32_t j = k_count << 4; j < k; ++j) {
*local_buffer++ = *a2++;
}
for (int32_t j = k; j < KC; ++j) {
*local_buffer++ = 0;
}
for (int32_t j = k_count << 4; j < k; ++j) {
*local_buffer++ = *a3++;
}
for (int32_t j = k; j < KC; ++j) {
*local_buffer++ = 0;
}
}
return;
}
}
// 8 bits int PackMatrixA_4r
void Gemm::PackMatrixA_4r(int32_t m, int32_t k, int32_t m_tail, const int8_t *A,
int32_t lda, int8_t *buffer) {
const int8_t *a0, *a1, *a2, *a3;
for (int32_t i = 0; i < m - m_tail; i += MR_INT8) {
for (int32_t i = 0; i < m - m_tail; i += 4) {
a0 = A + i * lda;
a1 = A + (i + 1) * lda;
a2 = A + (i + 2) * lda;
......@@ -625,7 +915,7 @@ void Gemm::PackMatrixA_4r(int32_t m, int32_t k, int32_t m_tail, const int8_t *A,
void Gemm::PackMatrixA_6r(int32_t m, int32_t k, int32_t m_tail, const int8_t *A,
int32_t lda, int8_t *buffer) {
const int32_t i_length = m - m_tail;
for (int32_t i = 0; i < i_length; i += MR_INT8) {
for (int32_t i = 0; i < i_length; i += 6) {
const int8_t *a0 = A + i * lda;
const int8_t *a1 = A + (i + 1) * lda;
const int8_t *a2 = A + (i + 2) * lda;
......@@ -676,11 +966,79 @@ void Gemm::PackMatrixA_6r(int32_t m, int32_t k, int32_t m_tail, const int8_t *A,
}
}
// 8 bits int PackMatrixB
void Gemm::PackMatrixB_2c_16(int32_t k, int32_t n, int32_t n_tail,
const int8_t *B, int32_t ldb, int8_t *buffer) {
const int32_t j_length = n - n_tail;
const int32_t k_count = k >> 4;
const int32_t k_tail = k & 15;
for (int32_t j = 0; j < j_length; j += 2) {
int8_t *local_buffer = buffer + j * KC;
for (int32_t i = 0; i < k_count; ++i) {
const int8_t *b0 = &B((i << 4), j);
const int8_t *b1 = &B((i << 4), j + 1);
for (int m = 0; m < 16; ++m) {
*local_buffer++ = *b0;
b0 += ldb;
}
for (int m = 0; m < 16; ++m) {
*local_buffer++ = *b1;
b1 += ldb;
}
}
if (k_tail != 0) {
const int8_t *b0 = &B((k_count << 4), j);
const int8_t *b1 = &B((k_count << 4), j + 1);
for (int32_t j = k_count << 4; j < k; ++j) {
*local_buffer++ = *b0;
b0 += ldb;
}
for (int32_t j = k; j < KC; ++j) {
*local_buffer++ = 0;
}
for (int32_t j = k_count << 4; j < k; ++j) {
*local_buffer++ = *b1;
b1 += ldb;
}
for (int32_t j = k; j < KC; ++j) {
*local_buffer++ = 0;
}
}
}
if (n_tail != 0) {
int8_t *local_buffer = buffer + j_length * KC;
for (int32_t i = 0; i < k_count; ++i) {
const int8_t *b0 = &B((i << 4), j_length);
for (int m = 0; m < 16; ++m) {
*local_buffer++ = *b0;
b0 += ldb;
}
for (int m = 0; m < 16; ++m) {
*local_buffer++ = 0;
}
}
if (k_tail != 0) {
const int8_t *b0 = &B((k_count << 4), j_length);
for (int32_t j = k_count << 4; j < k; ++j) {
*local_buffer++ = *b0;
b0 += ldb;
}
for (int32_t j = k; j < KC; ++j) {
*local_buffer++ = 0;
}
for (int32_t j = k_count << 4; j < KC; ++j) {
*local_buffer++ = 0;
}
}
}
}
// 8 bits int PackMatrixB
void Gemm::PackMatrixB_8c(int32_t k, int32_t n, int32_t n_tail, const int8_t *B,
int32_t ldb, int8_t *buffer) {
const int32_t j_length = n - n_tail;
for (int32_t j = 0; j < j_length; j += NR) {
for (int32_t j = 0; j < j_length; j += 8) {
int8_t *local_buffer = buffer + j * k;
for (int32_t i = 0; i < k; ++i) {
const int8_t *b0 = &B(i, j);
......@@ -715,7 +1073,7 @@ void Gemm::PackMatrixB_8c(int32_t k, int32_t n, int32_t n_tail, const int8_t *B,
for (int32_t j = j_length; j < n; ++j) {
*local_buffer++ = *b0++;
}
for (int32_t j = n; j < j_length + NR; ++j) {
for (int32_t j = n; j < j_length + 8; ++j) {
*local_buffer++ = 0;
}
}
......@@ -723,19 +1081,20 @@ void Gemm::PackMatrixB_8c(int32_t k, int32_t n, int32_t n_tail, const int8_t *B,
}
// 8 bits int matrix product (m*k x k*n)
void Gemm::Sgemm(int32_t m, int32_t n, int32_t k, int8_t alpha, const int8_t *A,
int32_t lda, const int8_t *B, int32_t ldb, int8_t beta,
int32_t *C, int32_t ldc, bool relu, int8_t *bias) {
void Gemm::Sgemm(int32_t m, int32_t n, int32_t k, float alpha, const int8_t *A,
int32_t lda, const int8_t *B, int32_t ldb, float beta,
int32_t *C, int32_t ldc, bool relu, int32_t *bias) {
// L1 data cache is 32 kib (Per Contex-A57, Contex-A72, Contex-A73)
// L2 cache is 0.5~4 Mib (Contex-A72 cluster)
int32_t L1 = 32 * 1024;
int32_t L2 = 512 * 1024;
KC = k;
const int32_t k_complete = (k + 15) - ((k + 15) & 15);
KC = k_complete;
MC = L1 / (KC * sizeof(int8_t));
NC = L2 / (KC * sizeof(int8_t));
// make sure MC is multiple of MR_INT8, and NC is multiple of NR
// make sure MC is multiple of MR_INT8, and NC is multiple of NR_INT8
if (MC == 0) {
MC = MR_INT8;
} else {
......@@ -745,52 +1104,106 @@ void Gemm::Sgemm(int32_t m, int32_t n, int32_t k, int8_t alpha, const int8_t *A,
}
// DLOG << "mblock_num = " << mblock_num << ", MC = " << MC << "\n";
if (NC == 0) {
NC = NR;
NC = NR_INT8;
} else {
int32_t nblock_num = (n + NC - 1) / NC;
NC = (n + nblock_num - 1) / nblock_num;
NC = (NC + NR - 1) / NR * NR;
NC = (NC + NR_INT8 - 1) / NR_INT8 * NR_INT8;
}
// DLOG << "nblock_num = " << nblock_num << ", NC = " << NC << "\n";
packedA_int8 = static_cast<int8_t *>(
paddle_mobile::memory::Alloc(sizeof(int8_t) * MC * KC));
packedB_int8 = static_cast<int8_t *>(
paddle_mobile::memory::Alloc(sizeof(int8_t) * KC * NC));
packedC_int8 = static_cast<int32_t *>(
packedC_int32 = static_cast<int32_t *>(
paddle_mobile::memory::Alloc(sizeof(int32_t) * MC * NC));
zero_int8 =
static_cast<int8_t *>(paddle_mobile::memory::Alloc(sizeof(int8_t) * KC));
static_cast<int8_t *>(paddle_mobile::memory::Alloc(sizeof(int8_t) * k));
memset(static_cast<void *>(zero_int8), 0, sizeof(int8_t) * KC);
memset(static_cast<void *>(zero_int8), 0, sizeof(int8_t) * k);
int32_t mc, nc;
for (int32_t j = 0; j < n; j += NC) {
nc = s_min(n - j, NC);
PackMatrixB_8c(KC, nc, nc % NR, &B(0, j), ldb, packedB_int8);
PackMatrixB_2c_16(k, nc, nc % NR_INT8, &B(0, j), ldb, packedB_int8);
for (int32_t i = 0; i < m; i += MC) {
mc = s_min(m - i, MC);
// PackMatrixA_6r(mc, KC, mc % MR_INT8, &A(i, 0), lda, packedA_int8);
PackMatrixA_4r(mc, KC, mc % MR_INT8, &A(i, 0), lda, packedA_int8);
PackMatrixA_4r_16(mc, k, mc % MR_INT8, &A(i, 0), lda, packedA_int8);
if (bias == nullptr) {
InnerKernelWithBias(mc, nc, alpha, packedA_int8, packedB_int8, beta,
packedC_int8, &C(i, j), ldc, relu, nullptr);
InnerKernel(mc, nc, alpha, packedA_int8, packedB_int8, beta,
packedC_int32, &C(i, j), ldc, relu);
}
}
}
paddle_mobile::memory::Free(packedA_int8);
paddle_mobile::memory::Free(packedB_int8);
paddle_mobile::memory::Free(packedC_int32);
paddle_mobile::memory::Free(zero_int8);
}
// 8 bits int matrix product (m*k x k*n)
void Gemm::Sgemm(int32_t m, int32_t n, int32_t k, float alpha, const int8_t *A,
int32_t lda, const int8_t *B, int32_t ldb, float beta,
int8_t *C, int32_t ldc, bool relu, int32_t *bias) {
// L1 data cache is 32 kib (Per Contex-A57, Contex-A72, Contex-A73)
// L2 cache is 0.5~4 Mib (Contex-A72 cluster)
int32_t L1 = 32 * 1024;
int32_t L2 = 512 * 1024;
const int32_t k_complete = (k + 15) - ((k + 15) & 15);
KC = k_complete;
MC = L1 / (KC * sizeof(int8_t));
NC = L2 / (KC * sizeof(int8_t));
// make sure MC is multiple of MR_INT8, and NC is multiple of NR_INT8
if (MC == 0) {
MC = MR_INT8;
} else {
int32_t mblock_num = (m + MC - 1) / MC;
MC = (m + mblock_num - 1) / mblock_num;
MC = (MC + MR_INT8 - 1) / MR_INT8 * MR_INT8;
}
// DLOG << "mblock_num = " << mblock_num << ", MC = " << MC << "\n";
if (NC == 0) {
NC = NR_INT8;
} else {
int32_t nblock_num = (n + NC - 1) / NC;
NC = (n + nblock_num - 1) / nblock_num;
NC = (NC + NR_INT8 - 1) / NR_INT8 * NR_INT8;
}
// DLOG << "nblock_num = " << nblock_num << ", NC = " << NC << "\n";
packedA_int8 = static_cast<int8_t *>(
paddle_mobile::memory::Alloc(sizeof(int8_t) * MC * KC));
packedB_int8 = static_cast<int8_t *>(
paddle_mobile::memory::Alloc(sizeof(int8_t) * KC * NC));
packedC_int32 = static_cast<int32_t *>(
paddle_mobile::memory::Alloc(sizeof(int32_t) * MC * NC));
zero_int8 =
static_cast<int8_t *>(paddle_mobile::memory::Alloc(sizeof(int8_t) * k));
memset(static_cast<void *>(zero_int8), 0, sizeof(int8_t) * k);
int32_t mc, nc;
for (int32_t j = 0; j < n; j += NC) {
nc = s_min(n - j, NC);
PackMatrixB_2c_16(k, nc, nc % NR_INT8, &B(0, j), ldb, packedB_int8);
for (int32_t i = 0; i < m; i += MC) {
mc = s_min(m - i, MC);
PackMatrixA_4r_16(mc, k, mc % MR_INT8, &A(i, 0), lda, packedA_int8);
if (bias != nullptr) {
InnerKernelWithBias(mc, nc, alpha, packedA_int8, packedB_int8, beta,
packedC_int8, &C(i, j), ldc, relu, bias + i);
packedC_int32, &C(i, j), ldc, relu, bias + i);
}
}
}
paddle_mobile::memory::Free(packedA_int8);
paddle_mobile::memory::Free(packedB_int8);
paddle_mobile::memory::Free(packedC_int8);
paddle_mobile::memory::Free(packedC_int32);
paddle_mobile::memory::Free(zero_int8);
}
// 8 bits int write back
// C = alpha * A * B + beta * C
void Gemm::WriteWithAlphaBeta(int32_t mc, int32_t nc, int32_t *c, int32_t *C,
int32_t ldc) {}
// C = A * B, 8位 int32_t
// C = A * B
void Gemm::WriteBasic(int32_t mc, int32_t nc, int32_t *c, int32_t *C,
int32_t ldc) {
#if __ARM_NEON
......@@ -802,7 +1215,7 @@ void Gemm::WriteBasic(int32_t mc, int32_t nc, int32_t *c, int32_t *C,
int32_t step = sizeof(int32_t) * ldc;
int32_t step1 = sizeof(int32_t) * (NC - (nc1 << 4));
int32_t volatile m = mc;
int32_t volatile n = nc1;
int32_t *volatile c_ptr, *volatile C_ptr;
int32_t *C0, *c0;
c_ptr = c;
......@@ -836,7 +1249,7 @@ void Gemm::WriteBasic(int32_t mc, int32_t nc, int32_t *c, int32_t *C,
"end_mc_%=: \n\t"
:
: [C_ptr] "r"(C_ptr), [c_ptr] "r"(c_ptr), [mc] "r"(m), [nc1] "r"(nc1),
: [C_ptr] "r"(C_ptr), [c_ptr] "r"(c_ptr), [mc] "r"(m), [nc1] "r"(n),
[step] "r"(step), [step1] "r"(step1)
: "memory", "r5", "r6", "q0", "q1", "q2", "q3");
}
......@@ -854,20 +1267,254 @@ void Gemm::WriteBasic(int32_t mc, int32_t nc, int32_t *c, int32_t *C,
#endif // __ARM_NEON
}
// C = A * B + C
void Gemm::WriteWithAdd(int32_t mc, int32_t nc, int32_t *c, int32_t *C,
int32_t ldc) {}
// C = A * B + bias, scale * C
void Gemm::WriteWithAddScale(int32_t mc, int32_t nc, int32_t *c, int8_t *C,
int32_t ldc, int32_t *bias, float scale) {
#if __ARM_NEON
#if __aarch64__
// TODO(wzzju)
#else
int32_t zero = 0;
int8_t narrow = -128;
int32_t nc1 = nc >> 3;
int32_t _nc1 = nc & 7;
int32_t step = sizeof(int8_t) * ldc;
int32_t step1 = sizeof(int32_t) * (NC - (nc1 << 3));
int32_t volatile m = mc;
int32_t volatile n = nc1;
int32_t *volatile c_ptr, *volatile bias_ptr;
int8_t *volatile C_ptr;
c_ptr = c;
C_ptr = C;
bias_ptr = bias;
if (nc1 > 0) {
asm volatile(
"subs %[mc], %[mc], #1 \n\t"
"blt end_mc_%= \n\t"
"vdup.32 q15, %[scale] \n\t"
"vdup.32 q14, %[zero] \n\t"
"vdup.8 d24, %[narrow] \n\t"
"loop_mc_%=: \n\t"
"vld1.32 {d26[0]}, [%[bias_ptr]]!\n\t"
"vdup.32 q13, d26[0] \n\t"
"mov r6, %[C_ptr] \n\t"
"mov r5, %[nc1] \n\t"
"subs r5, r5, #1 \n\t"
"blt end_nc1_%= \n\t"
"loop_nc1_%=: \n\t"
"vld1.32 {q0, q1}, [%[c_ptr]]! \n\t"
"vqadd.s32 q0, q0, q13 \n\t"
"vqadd.s32 q1, q1, q13 \n\t"
"vcvt.f32.s32 q2, q0 \n\t"
"vcvt.f32.s32 q3, q1 \n\t"
"vmul.f32 q2, q2, q15 \n\t"
"vmul.f32 q3, q3, q15 \n\t"
"vcvt.s32.f32 q4, q2 \n\t"
"vcvt.s32.f32 q5, q3 \n\t"
"vqmovn.s32 d12, q4 \n\t"
"vqmovn.s32 d13, q5 \n\t"
"vqmovn.s16 d14, q6 \n\t"
"vceq.s8 d15, d14, d24 \n\t"
"vsub.s8 d14, d14, d15 \n\t"
"vst1.8 {d14}, [r6]! \n\t"
"subs r5, r5, #1 \n\t"
"bge loop_nc1_%= \n\t"
"end_nc1_%=: \n\t"
"add %[C_ptr], %[C_ptr], %[step] \n\t"
"add %[c_ptr], %[c_ptr], %[step1] \n\t"
"subs %[mc], %[mc], #1 \n\t"
"bge loop_mc_%= \n\t"
"end_mc_%=: \n\t"
:
: [C_ptr] "r"(C_ptr), [c_ptr] "r"(c_ptr), [mc] "r"(m), [nc1] "r"(n),
[step] "r"(step), [step1] "r"(step1), [bias_ptr] "r"(bias_ptr),
[scale] "r"(scale), [zero] "r"(zero), [narrow] "r"(narrow)
: "cc", "memory", "r5", "r6", "q0", "q1", "q2", "q3", "q4", "q5", "q6",
"q7", "q12", "q13", "q14", "q15");
}
int32_t nc_left;
int32_t *c0;
int8_t *C0;
int32_t bias_v;
if (_nc1 != 0) {
for (int32_t i = 0; i < mc; i++) {
C0 = C_ptr + nc1 * 8 + i * ldc;
c0 = c_ptr + nc1 * 8 + i * NC;
bias_v = *(bias_ptr + i);
nc_left = _nc1;
asm volatile(
"vdup.32 q15, %[scale] \n\t"
"vdup.32 q14, %[zero] \n\t"
"vdup.8 d24, %[narrow] \n\t"
"vdup.32 q13, %[bias_v] \n\t"
"cmp %[_nc1], #4 \n\t"
"blt less_four_%= \n\t"
"vld1.32 {q0}, [%[c0]]! \n\t"
"vqadd.s32 q0, q0, q13 \n\t"
"vcvt.f32.s32 q1, q0 \n\t"
"vmul.f32 q1, q1, q15 \n\t"
"vcvt.s32.f32 q2, q1 \n\t"
"vqmovn.s32 d6, q2 \n\t"
"vqmovn.s16 d8, q3 \n\t"
"vceq.s8 d9, d8, d24 \n\t"
"vsub.s8 d8, d8, d9 \n\t"
"vst1.8 {d8[0]}, [%[C0]]! \n\t"
"vst1.8 {d8[1]}, [%[C0]]! \n\t"
"vst1.8 {d8[2]}, [%[C0]]! \n\t"
"vst1.8 {d8[3]}, [%[C0]]! \n\t"
"subs %[_nc1], %[_nc1], #4 \n\t"
"beq process_over_%= \n\t"
"less_four_%=: \n\t"
"vld1.32 {q0}, [%[c0]]! \n\t"
"vqadd.s32 q0, q0, q13 \n\t"
"vcvt.f32.s32 q1, q0 \n\t"
"vmul.f32 q1, q1, q15 \n\t"
"vcvt.s32.f32 q2, q1 \n\t"
"vqmovn.s32 d6, q2 \n\t"
"vqmovn.s16 d8, q3 \n\t"
"vceq.s8 d9, d8, d24 \n\t"
"vsub.s8 d8, d8, d9 \n\t"
"loop_save_%=: \n\t"
"vst1.8 {d8[0]}, [%[C0]]! \n\t"
"vext.8 d8, d8, d8, #1 \n\t"
"subs %[_nc1], %[_nc1], #1 \n\t"
"bgt loop_save_%= \n\t"
"process_over_%=: \n\t"
:
: [_nc1] "r"(nc_left), [C0] "r"(C0), [c0] "r"(c0),
[bias_v] "r"(bias_v), [scale] "r"(scale), [zero] "r"(zero),
[narrow] "r"(narrow)
: "cc", "memory", "q0", "q1", "q2", "q3", "q4", "q12", "q13", "q14",
"q15");
}
}
#endif // __aarch64__
#endif // __ARM_NEON
}
// C = A * B + bias, scale * relu(C)
void Gemm::WriteWithAddReluScale(int32_t mc, int32_t nc, int32_t *c, int8_t *C,
int32_t ldc, int32_t *bias, float scale) {
#if __ARM_NEON
#if __aarch64__
// TODO(wzzju)
#else
int32_t zero = 0;
int32_t nc1 = nc >> 3;
int32_t _nc1 = nc & 7;
int32_t step = sizeof(int8_t) * ldc;
int32_t step1 = sizeof(int32_t) * (NC - (nc1 << 3));
int32_t volatile m = mc;
int32_t volatile n = nc1;
int32_t *volatile c_ptr, *volatile bias_ptr;
int8_t *volatile C_ptr;
c_ptr = c;
C_ptr = C;
bias_ptr = bias;
if (nc1 > 0) {
asm volatile(
"subs %[mc], %[mc], #1 \n\t"
"blt end_mc_%= \n\t"
"vdup.32 q15, %[scale] \n\t"
"vdup.32 q14, %[zero] \n\t"
"loop_mc_%=: \n\t"
"vld1.32 {d26[0]}, [%[bias_ptr]]!\n\t"
"vdup.32 q13, d26[0] \n\t"
"mov r6, %[C_ptr] \n\t"
"mov r5, %[nc1] \n\t"
"subs r5, r5, #1 \n\t"
"blt end_nc1_%= \n\t"
"loop_nc1_%=: \n\t"
"vld1.32 {q0, q1}, [%[c_ptr]]! \n\t"
"vqadd.s32 q0, q0, q13 \n\t"
"vqadd.s32 q1, q1, q13 \n\t"
"vmax.s32 q0, q0, q14 \n\t"
"vmax.s32 q1, q1, q14 \n\t"
"vcvt.f32.s32 q2, q0 \n\t"
"vcvt.f32.s32 q3, q1 \n\t"
"vmul.f32 q2, q2, q15 \n\t"
"vmul.f32 q3, q3, q15 \n\t"
"vcvt.s32.f32 q4, q2 \n\t"
"vcvt.s32.f32 q5, q3 \n\t"
"vqmovn.s32 d12, q4 \n\t"
"vqmovn.s32 d13, q5 \n\t"
"vqmovn.s16 d14, q6 \n\t"
"vst1.8 {d14}, [r6]! \n\t"
"subs r5, r5, #1 \n\t"
"bge loop_nc1_%= \n\t"
"end_nc1_%=: \n\t"
"add %[C_ptr], %[C_ptr], %[step] \n\t"
"add %[c_ptr], %[c_ptr], %[step1] \n\t"
"subs %[mc], %[mc], #1 \n\t"
"bge loop_mc_%= \n\t"
"end_mc_%=: \n\t"
// C = A * B + bias
void Gemm::WriteWithAddV1(int32_t mc, int32_t nc, int32_t *c, int32_t *C,
int32_t ldc, int8_t *bias) {}
// C = A * B + C, relu(C)
void Gemm::WriteWithAddRelu(int32_t mc, int32_t nc, int32_t *c, int32_t *C,
int32_t ldc) {}
:
: [C_ptr] "r"(C_ptr), [c_ptr] "r"(c_ptr), [mc] "r"(m), [nc1] "r"(n),
[step] "r"(step), [step1] "r"(step1), [bias_ptr] "r"(bias_ptr),
[scale] "r"(scale), [zero] "r"(zero)
: "cc", "memory", "r5", "r6", "q0", "q1", "q2", "q3", "q4", "q5", "q6",
"q7", "q13", "q14", "q15");
}
// C = A * B + bias, relu(C)
void Gemm::WriteWithAddReluV1(int32_t mc, int32_t nc, int32_t *c, int32_t *C,
int32_t ldc, int8_t *bias) {}
int32_t nc_left;
int32_t *c0;
int8_t *C0;
int32_t bias_v;
if (_nc1 != 0) {
for (int32_t i = 0; i < mc; i++) {
C0 = C_ptr + nc1 * 8 + i * ldc;
c0 = c_ptr + nc1 * 8 + i * NC;
bias_v = *(bias_ptr + i);
nc_left = _nc1;
asm volatile(
"vdup.32 q15, %[scale] \n\t"
"vdup.32 q14, %[zero] \n\t"
"vdup.32 q13, %[bias_v] \n\t"
"cmp %[_nc1], #4 \n\t"
"blt less_four_%= \n\t"
"vld1.32 {q0}, [%[c0]]! \n\t"
"vqadd.s32 q0, q0, q13 \n\t"
"vmax.s32 q0, q0, q14 \n\t"
"vcvt.f32.s32 q1, q0 \n\t"
"vmul.f32 q1, q1, q15 \n\t"
"vcvt.s32.f32 q2, q1 \n\t"
"vqmovn.s32 d6, q2 \n\t"
"vqmovn.s16 d8, q3 \n\t"
"vst1.8 {d8[0]}, [%[C0]]! \n\t"
"vst1.8 {d8[1]}, [%[C0]]! \n\t"
"vst1.8 {d8[2]}, [%[C0]]! \n\t"
"vst1.8 {d8[3]}, [%[C0]]! \n\t"
"subs %[_nc1], %[_nc1], #4 \n\t"
"beq process_over_%= \n\t"
"less_four_%=: \n\t"
"vld1.32 {q0}, [%[c0]]! \n\t"
"vqadd.s32 q0, q0, q13 \n\t"
"vmax.s32 q0, q0, q14 \n\t"
"vcvt.f32.s32 q1, q0 \n\t"
"vmul.f32 q1, q1, q15 \n\t"
"vcvt.s32.f32 q2, q1 \n\t"
"vqmovn.s32 d6, q2 \n\t"
"vqmovn.s16 d8, q3 \n\t"
"loop_save_%=: \n\t"
"vst1.8 {d8[0]}, [%[C0]]! \n\t"
"vext.8 d8, d8, d8, #1 \n\t"
"subs %[_nc1], %[_nc1], #1 \n\t"
"bgt loop_save_%= \n\t"
"process_over_%=: \n\t"
:
: [_nc1] "r"(nc_left), [C0] "r"(C0), [c0] "r"(c0),
[bias_v] "r"(bias_v), [scale] "r"(scale), [zero] "r"(zero)
: "cc", "memory", "q0", "q1", "q2", "q3", "q4", "q13", "q14", "q15");
}
}
#endif // __aarch64__
#endif // __ARM_NEON
}
} // namespace math
} // namespace operators
......
......@@ -28,10 +28,10 @@ namespace operators {
namespace math {
// 8 bits int matrix product (m*k x k*n)
void Gemm::Sgemm_omp(int32_t m, int32_t n, int32_t k, int8_t alpha,
void Gemm::Sgemm_omp(int32_t m, int32_t n, int32_t k, float alpha,
const int8_t *A, int32_t lda, const int8_t *B, int32_t ldb,
int8_t beta, int32_t *C, int32_t ldc, bool relu,
int8_t *bias) {
float beta, int32_t *C, int32_t ldc, bool relu,
int32_t *bias) {
#ifdef _OPENMP
int32_t max_threads = omp_get_max_threads();
#else
......@@ -39,10 +39,11 @@ void Gemm::Sgemm_omp(int32_t m, int32_t n, int32_t k, int8_t alpha,
#endif
int32_t L1 = 64 / max_threads * 1024;
KC = k;
const int32_t k_complete = (k + 15) - ((k + 15) & 15);
KC = k_complete;
zero_int8 =
static_cast<int8_t *>(paddle_mobile::memory::Alloc(sizeof(int8_t) * KC));
memset(static_cast<void *>(zero_int8), 0, sizeof(int8_t) * KC);
static_cast<int8_t *>(paddle_mobile::memory::Alloc(sizeof(int8_t) * k));
memset(static_cast<void *>(zero_int8), 0, sizeof(int8_t) * k);
if (m > n) {
// 对 A 分块
MC = L1 / (KC * sizeof(int8_t));
......@@ -54,14 +55,14 @@ void Gemm::Sgemm_omp(int32_t m, int32_t n, int32_t k, int8_t alpha,
MC = (MC + MR_INT8 - 1) / MR_INT8 * MR_INT8;
}
// 补齐 B
NC = (n + NR - 1) / NR * NR;
NC = (n + NR_INT8 - 1) / NR_INT8 * NR_INT8;
packedB_int8 = static_cast<int8_t *>(
paddle_mobile::memory::Alloc(sizeof(int8_t) * KC * NC));
#if __aarch64__
// TODO(wzzju)
#else
PackMatrixB_omp_8c(KC, n, n % NR, B, ldb, packedB_int8);
PackMatrixB_omp_2c_16(k, n, n % NR_INT8, B, ldb, packedB_int8);
#endif
packedA_int8 = static_cast<int8_t *>(
paddle_mobile::memory::Alloc(sizeof(int8_t) * MC * KC * max_threads));
......@@ -69,11 +70,11 @@ void Gemm::Sgemm_omp(int32_t m, int32_t n, int32_t k, int8_t alpha,
// 对 B 分块
NC = L1 / (KC * sizeof(int8_t));
if (NC == 0) {
NC = NR;
NC = NR_INT8;
} else {
int32_t nblock_num = (n + NC - 1) / NC;
NC = (n + nblock_num - 1) / nblock_num;
NC = (NC + NR - 1) / NR * NR;
NC = (NC + NR_INT8 - 1) / NR_INT8 * NR_INT8;
}
// 补齐 A
MC = (m + MR_INT8 - 1) / MR_INT8 * MR_INT8;
......@@ -83,12 +84,12 @@ void Gemm::Sgemm_omp(int32_t m, int32_t n, int32_t k, int8_t alpha,
#if __aarch64__
// TODO(wzzju)
#else
PackMatrixA_omp_4r(m, KC, m % MR_INT8, A, lda, packedA_int8);
PackMatrixA_omp_4r_16(m, k, m % MR_INT8, A, lda, packedA_int8);
#endif
packedB_int8 = static_cast<int8_t *>(
paddle_mobile::memory::Alloc(sizeof(int8_t) * KC * NC * max_threads));
}
packedC_int8 = static_cast<int32_t *>(
packedC_int32 = static_cast<int32_t *>(
paddle_mobile::memory::Alloc(sizeof(int32_t) * MC * NC * max_threads));
if (m > n) {
......@@ -103,14 +104,19 @@ void Gemm::Sgemm_omp(int32_t m, int32_t n, int32_t k, int8_t alpha,
int32_t mc;
mc = s_min(m - i, MC);
int8_t *local_A = packedA_int8 + MC * KC * local_threads;
int32_t *local_C = packedC_int8 + MC * NC * local_threads;
int32_t *local_C = packedC_int32 + MC * NC * local_threads;
#if __aarch64__
// TODO(wzzju)
#else
PackMatrixA_4r(mc, KC, mc % MR_INT8, &A(i, 0), lda, local_A);
PackMatrixA_4r_16(mc, k, mc % MR_INT8, &A(i, 0), lda, local_A);
#endif
InnerKernelWithBias(mc, n, alpha, local_A, packedB_int8, beta, local_C,
&C(i, 0), ldc, relu, bias + i);
// InnerKernelWithBias(mc, n, alpha, local_A, packedB_int8, beta,
// local_C,
// &C(i, 0), ldc, relu, bias + i);
if (bias == nullptr) {
InnerKernel(mc, n, alpha, local_A, packedB_int8, beta, local_C,
&C(i, 0), ldc, relu);
}
}
} else {
#pragma omp parallel for
......@@ -123,20 +129,25 @@ void Gemm::Sgemm_omp(int32_t m, int32_t n, int32_t k, int8_t alpha,
int32_t nc;
nc = s_min(n - j, NC);
int8_t *local_B = packedB_int8 + KC * NC * local_threads;
int32_t *local_C = packedC_int8 + MC * NC * local_threads;
int32_t *local_C = packedC_int32 + MC * NC * local_threads;
#if __aarch64__
// TODO(wzzju)
#else
PackMatrixB_8c(KC, nc, nc % NR, &B(0, j), ldb, local_B);
PackMatrixB_2c_16(k, nc, nc % NR_INT8, &B(0, j), ldb, local_B);
#endif
InnerKernelWithBias(m, nc, alpha, packedA_int8, local_B, beta, local_C,
&C(0, j), ldc, relu, bias);
// InnerKernelWithBias(m, nc, alpha, packedA_int8, local_B, beta,
// local_C,
// &C(0, j), ldc, relu, bias);
if (bias == nullptr) {
InnerKernel(m, nc, alpha, packedA_int8, local_B, beta, local_C,
&C(0, j), ldc, relu);
}
}
}
paddle_mobile::memory::Free(packedA_int8);
paddle_mobile::memory::Free(packedB_int8);
paddle_mobile::memory::Free(packedC_int8);
paddle_mobile::memory::Free(packedC_int32);
paddle_mobile::memory::Free(zero_int8);
}
......@@ -144,7 +155,7 @@ void Gemm::PackMatrixB_omp_8c(int32_t k, int32_t n, int32_t n_tail,
const int8_t *B, int32_t ldb, int8_t *buffer) {
const int32_t j_length = n - n_tail;
#pragma omp parallel for
for (int32_t j = 0; j < j_length; j += NR) {
for (int32_t j = 0; j < j_length; j += 8) {
int8_t *local_buffer = buffer + j * k;
for (int32_t i = 0; i < k; ++i) {
const int8_t *b0 = &B(i, j);
......@@ -179,7 +190,7 @@ void Gemm::PackMatrixB_omp_8c(int32_t k, int32_t n, int32_t n_tail,
for (int32_t j = j_length; j < n; ++j) {
*local_buffer++ = *b0++;
}
for (int32_t j = n; j < j_length + NR; ++j) {
for (int32_t j = n; j < j_length + 8; ++j) {
*local_buffer++ = 0;
}
}
......@@ -188,9 +199,9 @@ void Gemm::PackMatrixB_omp_8c(int32_t k, int32_t n, int32_t n_tail,
void Gemm::PackMatrixA_omp_4r(int32_t m, int32_t k, int32_t m_tail,
const int8_t *A, int32_t lda, int8_t *buffer) {
const int i_length = m - m_tail;
const int32_t i_length = m - m_tail;
#pragma omp parallel for
for (int32_t i = 0; i < i_length; i += MR_INT8) {
for (int32_t i = 0; i < i_length; i += 4) {
const int8_t *a0 = A + i * lda;
const int8_t *a1 = A + (i + 1) * lda;
const int8_t *a2 = A + (i + 2) * lda;
......@@ -221,7 +232,7 @@ void Gemm::PackMatrixA_omp_4r(int32_t m, int32_t k, int32_t m_tail,
default:
break;
}
for (int j = 0; j < k; ++j) {
for (int32_t j = 0; j < k; ++j) {
*local_buffer++ = *a0++;
*local_buffer++ = *a1++;
*local_buffer++ = *a2++;
......@@ -230,6 +241,232 @@ void Gemm::PackMatrixA_omp_4r(int32_t m, int32_t k, int32_t m_tail,
}
}
// 8 bits int PackMatrixA_4r
void Gemm::PackMatrixA_omp_4r_16(int32_t m, int32_t k, int32_t m_tail,
const int8_t *A, int32_t lda, int8_t *buffer) {
const int32_t i_length = m - m_tail;
const int32_t k_count = k >> 4;
const int32_t k_tail = k & 15;
#pragma omp parallel for
for (int32_t i = 0; i < i_length; i += 4) {
const int8_t *a0 = A + i * lda;
const int8_t *a1 = A + (i + 1) * lda;
const int8_t *a2 = A + (i + 2) * lda;
const int8_t *a3 = A + (i + 3) * lda;
int8_t *local_buffer = buffer + i * KC;
for (int32_t j = 0; j < k_count; ++j) {
#if __ARM_NEON
#if __aarch64__
// TODO(wzzju)
#else
asm volatile(
"vld1.s8 {d0, d1}, [%[a0]]! \n\t"
"vld1.s8 {d2, d3}, [%[a1]]! \n\t"
"vld1.s8 {d4, d5}, [%[a2]]! \n\t"
"vld1.s8 {d6, d7}, [%[a3]]! \n\t"
"vst1.s8 {d0, d1}, [%[local_buffer]]! \n\t"
"vst1.s8 {d2, d3}, [%[local_buffer]]! \n\t"
"vst1.s8 {d4, d5}, [%[local_buffer]]! \n\t"
"vst1.s8 {d6, d7}, [%[local_buffer]]! \n\t"
: [local_buffer] "+r"(local_buffer), [a0] "+r"(a0), [a1] "+r"(a1),
[a2] "+r"(a2), [a3] "+r"(a3)
:
: "memory", "q0", "q1", "q2", "q3");
#endif // __aarch64__
#else
for (int32_t l = 0; l < 16; ++l) {
*local_buffer++ = *a0++;
}
for (int32_t l = 0; l < 16; ++l) {
*local_buffer++ = *a1++;
}
for (int32_t l = 0; l < 16; ++l) {
*local_buffer++ = *a2++;
}
for (int32_t l = 0; l < 16; ++l) {
*local_buffer++ = *a3++;
}
#endif // __ARM_NEON
}
if (k_tail != 0) {
for (int32_t j = k_count << 4; j < k; ++j) {
*local_buffer++ = *a0++;
}
for (int32_t j = k; j < KC; ++j) {
*local_buffer++ = 0;
}
for (int32_t j = k_count << 4; j < k; ++j) {
*local_buffer++ = *a1++;
}
for (int32_t j = k; j < KC; ++j) {
*local_buffer++ = 0;
}
for (int32_t j = k_count << 4; j < k; ++j) {
*local_buffer++ = *a2++;
}
for (int32_t j = k; j < KC; ++j) {
*local_buffer++ = 0;
}
for (int32_t j = k_count << 4; j < k; ++j) {
*local_buffer++ = *a3++;
}
for (int32_t j = k; j < KC; ++j) {
*local_buffer++ = 0;
}
}
}
if (m_tail != 0) {
const int8_t *a0 = &A(i_length, 0);
const int8_t *a1 = a0 + lda;
const int8_t *a2 = a0 + 2 * lda;
const int8_t *a3 = a0 + 3 * lda;
int8_t *local_buffer = buffer + i_length * KC;
switch (m_tail) {
case 1:
a1 = zero_int8;
case 2:
a2 = zero_int8;
case 3:
a3 = zero_int8;
break;
default:
break;
}
for (int32_t j = 0; j < k_count; ++j) {
#if __ARM_NEON
#if __aarch64__
// TODO(wzzju)
#else
asm volatile(
"vld1.s8 {d0, d1}, [%[a0]]! \n\t"
"vld1.s8 {d2, d3}, [%[a1]]! \n\t"
"vld1.s8 {d4, d5}, [%[a2]]! \n\t"
"vld1.s8 {d6, d7}, [%[a3]]! \n\t"
"vst1.s8 {d0, d1}, [%[local_buffer]]! \n\t"
"vst1.s8 {d2, d3}, [%[local_buffer]]! \n\t"
"vst1.s8 {d4, d5}, [%[local_buffer]]! \n\t"
"vst1.s8 {d6, d7}, [%[local_buffer]]! \n\t"
: [local_buffer] "+r"(local_buffer), [a0] "+r"(a0), [a1] "+r"(a1),
[a2] "+r"(a2), [a3] "+r"(a3)
:
: "memory", "q0", "q1", "q2", "q3");
#endif // __aarch64__
#else
for (int32_t l = 0; l < 16; ++l) {
*local_buffer++ = *a0++;
}
for (int32_t l = 0; l < 16; ++l) {
*local_buffer++ = *a1++;
}
for (int32_t l = 0; l < 16; ++l) {
*local_buffer++ = *a2++;
}
for (int32_t l = 0; l < 16; ++l) {
*local_buffer++ = *a3++;
}
#endif // __ARM_NEON
}
if (k_tail != 0) {
for (int32_t j = k_count << 4; j < k; ++j) {
*local_buffer++ = *a0++;
}
for (int32_t j = k; j < KC; ++j) {
*local_buffer++ = 0;
}
for (int32_t j = k_count << 4; j < k; ++j) {
*local_buffer++ = *a1++;
}
for (int32_t j = k; j < KC; ++j) {
*local_buffer++ = 0;
}
for (int32_t j = k_count << 4; j < k; ++j) {
*local_buffer++ = *a2++;
}
for (int32_t j = k; j < KC; ++j) {
*local_buffer++ = 0;
}
for (int32_t j = k_count << 4; j < k; ++j) {
*local_buffer++ = *a3++;
}
for (int32_t j = k; j < KC; ++j) {
*local_buffer++ = 0;
}
}
}
}
// 8 bits int PackMatrixB
void Gemm::PackMatrixB_omp_2c_16(int32_t k, int32_t n, int32_t n_tail,
const int8_t *B, int32_t ldb, int8_t *buffer) {
const int32_t j_length = n - n_tail;
const int32_t k_count = k >> 4;
const int32_t k_tail = k & 15;
#pragma omp parallel for
for (int32_t j = 0; j < j_length; j += 2) {
int8_t *local_buffer = buffer + j * KC;
for (int32_t i = 0; i < k_count; ++i) {
const int8_t *b0 = &B((i << 4), j);
const int8_t *b1 = &B((i << 4), j + 1);
for (int m = 0; m < 16; ++m) {
*local_buffer++ = *b0;
b0 += ldb;
}
for (int m = 0; m < 16; ++m) {
*local_buffer++ = *b1;
b1 += ldb;
}
}
if (k_tail != 0) {
const int8_t *b0 = &B((k_count << 4), j);
const int8_t *b1 = &B((k_count << 4), j + 1);
for (int32_t j = k_count << 4; j < k; ++j) {
*local_buffer++ = *b0;
b0 += ldb;
}
for (int32_t j = k; j < KC; ++j) {
*local_buffer++ = 0;
}
for (int32_t j = k_count << 4; j < k; ++j) {
*local_buffer++ = *b1;
b1 += ldb;
}
for (int32_t j = k; j < KC; ++j) {
*local_buffer++ = 0;
}
}
}
if (n_tail != 0) {
int8_t *local_buffer = buffer + j_length * KC;
for (int32_t i = 0; i < k_count; ++i) {
const int8_t *b0 = &B((i << 4), j_length);
for (int m = 0; m < 16; ++m) {
*local_buffer++ = *b0;
b0 += ldb;
}
for (int m = 0; m < 16; ++m) {
*local_buffer++ = 0;
}
}
if (k_tail != 0) {
const int8_t *b0 = &B((k_count << 4), j_length);
for (int32_t j = k_count << 4; j < k; ++j) {
*local_buffer++ = *b0;
b0 += ldb;
}
for (int32_t j = k; j < KC; ++j) {
*local_buffer++ = 0;
}
for (int32_t j = k_count << 4; j < KC; ++j) {
*local_buffer++ = 0;
}
}
}
}
} // namespace math
} // namespace operators
} // namespace paddle_mobile
......@@ -28,7 +28,12 @@ template <typename T>
void matmul(const framework::Tensor &matrix_a, bool trans_a,
const framework::Tensor &matrix_b, bool trans_b, T alpha,
framework::Tensor *matrix_out, T beta, bool relu = false,
T *bias = nullptr);
float *bias = nullptr);
void matmul_int8(const framework::Tensor &matrix_a, bool trans_a,
const framework::Tensor &matrix_b, bool trans_b, float alpha,
framework::Tensor *matrix_out, float beta, bool relu = false,
int32_t *bias = nullptr);
template <typename T>
void matmulWithBn(const framework::Tensor &matrix_a, bool trans_a,
......
......@@ -20,11 +20,10 @@ limitations under the License. */
namespace paddle_mobile {
namespace operators {
namespace math {
template <>
void matmul<int8_t>(const framework::Tensor &matrix_a, bool trans_a,
const framework::Tensor &matrix_b, bool trans_b,
int8_t alpha, framework::Tensor *matrix_out, int8_t beta,
bool relu, int8_t *bias) {
void matmul_int8(const framework::Tensor &matrix_a, bool trans_a,
const framework::Tensor &matrix_b, bool trans_b, float alpha,
framework::Tensor *matrix_out, float beta, bool relu,
int32_t *bias) {
auto dim_a = matrix_a.dims();
auto dim_b = matrix_b.dims();
auto dim_out = matrix_out->dims();
......@@ -52,21 +51,45 @@ void matmul<int8_t>(const framework::Tensor &matrix_a, bool trans_a,
}
#ifdef _OPENMP
if (bias != nullptr) {
// TODO(wzzju): gemm.Sgemm_omp_with_bias, now use single thread instead.
gemm.Sgemm(M, N, K, alpha, a, K, matrix_b.data<int8_t>(), N, beta,
matrix_out->data<int8_t>(), N, relu, bias);
} else {
gemm.Sgemm_omp(M, N, K, alpha, a, K, matrix_b.data<int8_t>(), N, beta,
matrix_out->data<int32_t>(), N, relu, bias);
}
#else
if (bias != nullptr) {
gemm.Sgemm(M, N, K, alpha, a, K, matrix_b.data<int8_t>(), N, beta,
matrix_out->data<int8_t>(), N, relu, bias);
} else {
gemm.Sgemm(M, N, K, alpha, a, K, matrix_b.data<int8_t>(), N, beta,
matrix_out->data<int32_t>(), N, relu, bias);
}
#endif
} else {
#ifdef _OPENMP
if (bias != nullptr) {
// TODO(wzzju): gemm.Sgemm_omp_with_bias, now use single thread instead.
gemm.Sgemm(M, N, K, alpha, matrix_a.data<int8_t>(), K,
matrix_b.data<int8_t>(), N, beta, matrix_out->data<int8_t>(),
N, relu, bias);
} else {
gemm.Sgemm_omp(M, N, K, alpha, matrix_a.data<int8_t>(), K,
matrix_b.data<int8_t>(), N, beta,
matrix_out->data<int32_t>(), N, relu, bias);
}
#else
if (bias != nullptr) {
gemm.Sgemm(M, N, K, alpha, matrix_a.data<int8_t>(), K,
matrix_b.data<int8_t>(), N, beta, matrix_out->data<int32_t>(), N,
relu, bias);
matrix_b.data<int8_t>(), N, beta, matrix_out->data<int8_t>(),
N, relu, bias);
} else {
gemm.Sgemm(M, N, K, alpha, matrix_a.data<int8_t>(), K,
matrix_b.data<int8_t>(), N, beta, matrix_out->data<int32_t>(),
N, relu, bias);
}
#endif
}
}
......
......@@ -419,6 +419,8 @@ class ConvParam : public OpParam {
EXEC_INVALID = 0,
EXEC_GEMM_FLOAT,
EXEC_DEPTHWISE3x3S1P1_FLOAT,
EXEC_DEPTHWISE3x3S2P0_FLOAT,
EXEC_DEPTHWISE3x3S2P1_FLOAT,
EXEC_DEPTHWISE3x3_FLOAT,
EXEC_WINOGRAD3X3_FLOAT,
EXEC_WINOGRAD5X5_FLOAT,
......@@ -2573,7 +2575,9 @@ class DequantizeParam : public OpParam {
DequantizeParam(const VariableNameMap &inputs, const VariableNameMap &outputs,
const AttributeMap &attrs, const Scope &scope) {
input_ = InputXFrom<GType>(inputs, scope);
if (outputs.count("Out")) {
output_ = OutFrom<GType>(outputs, scope);
}
activation_scale_ = OpParam::GetVarValue<GType>("Scale", inputs, scope);
// dequantization is performed as x = x / static_scale / online_scale
if (HasAttr("weight_scale", attrs)) {
......@@ -2593,20 +2597,19 @@ class DequantizeParam : public OpParam {
};
#endif
#ifdef FUSION_DEQUANT_ADD_BN_RELU_OP
#if defined(FUSION_DEQUANT_ADD_BN_OP) || \
defined(FUSION_DEQUANT_ADD_BN_RELU_OP) || \
defined(FUSION_DEQUANT_BN_RELU_OP) || defined(FUSION_DEQUANT_BN_OP)
template <typename Dtype>
class FusionDequantAddBNReluParam : public DequantizeParam<Dtype> {
class FusionDequantBNParam : public DequantizeParam<Dtype> {
typedef typename DtypeTensorTrait<Dtype>::gtype GType;
typedef typename DtypeTensorTrait<Dtype>::rtype RType;
public:
FusionDequantAddBNReluParam(const VariableNameMap &inputs,
FusionDequantBNParam(const VariableNameMap &inputs,
const VariableNameMap &outputs,
const AttributeMap &attrs, const Scope &scope)
: DequantizeParam<Dtype>(inputs, outputs, attrs, scope) {
// element wise add params
axis_ = OpParam::GetAttr<int>("axis", attrs);
bias_ = OpParam::InputYFrom<GType>(inputs, scope);
// batch norm params
bn_mean_ = OpParam::GetVarValue<GType>("BNMean", inputs, scope);
bn_variance_ = OpParam::GetVarValue<GType>("BNVariance", inputs, scope);
......@@ -2614,21 +2617,83 @@ class FusionDequantAddBNReluParam : public DequantizeParam<Dtype> {
bn_bias_ = OpParam::GetVarValue<GType>("BNBias", inputs, scope);
epsilon_ = OpParam::GetAttr<float>("epsilon", attrs);
// output
output_ = OpParam::OutFrom<GType>(outputs, scope);
if (outputs.count("Y")) {
this->output_ = OpParam::OutputYFrom<GType>(outputs, scope);
}
}
public:
// elementwise add
int axis_;
RType *bias_;
// batch norm
RType *bn_mean_;
RType *bn_variance_;
RType *bn_scale_;
RType *bn_bias_;
float epsilon_;
};
#endif
#if defined(FUSION_DEQUANT_ADD_BN_RELU_OP) || defined(FUSION_DEQUANT_ADD_BN_OP)
template <typename Dtype>
class FusionDequantAddBNParam : public FusionDequantBNParam<Dtype> {
typedef typename DtypeTensorTrait<Dtype>::gtype GType;
typedef typename DtypeTensorTrait<Dtype>::rtype RType;
public:
FusionDequantAddBNParam(const VariableNameMap &inputs,
const VariableNameMap &outputs,
const AttributeMap &attrs, const Scope &scope)
: FusionDequantBNParam<Dtype>(inputs, outputs, attrs, scope) {
// element wise add params
axis_ = OpParam::GetAttr<int>("axis", attrs);
bias_ = OpParam::InputYFrom<GType>(inputs, scope);
// output
RType *output_;
if (outputs.count("Y")) {
this->output_ = OpParam::OutputYFrom<GType>(outputs, scope);
}
}
public:
// elementwise add
int axis_;
RType *bias_;
};
#endif
#ifdef FUSION_DEQUANT_BN_RELU_OP
template <typename Dtype>
class FusionDequantBNReluParam : public FusionDequantBNParam<Dtype> {
typedef typename DtypeTensorTrait<Dtype>::gtype GType;
typedef typename DtypeTensorTrait<Dtype>::rtype RType;
public:
FusionDequantBNReluParam(const VariableNameMap &inputs,
const VariableNameMap &outputs,
const AttributeMap &attrs, const Scope &scope)
: FusionDequantBNParam<Dtype>(inputs, outputs, attrs, scope) {
// output
if (outputs.count("Out")) {
this->output_ = OpParam::OutFrom<GType>(outputs, scope);
}
}
};
#endif
#ifdef FUSION_DEQUANT_ADD_BN_RELU_OP
template <typename Dtype>
class FusionDequantAddBNReluParam : public FusionDequantAddBNParam<Dtype> {
typedef typename DtypeTensorTrait<Dtype>::gtype GType;
typedef typename DtypeTensorTrait<Dtype>::rtype RType;
public:
FusionDequantAddBNReluParam(const VariableNameMap &inputs,
const VariableNameMap &outputs,
const AttributeMap &attrs, const Scope &scope)
: FusionDequantAddBNParam<Dtype>(inputs, outputs, attrs, scope) {
// output
if (outputs.count("Out")) {
this->output_ = OpParam::OutFrom<GType>(outputs, scope);
}
}
};
#endif
......
......@@ -28,7 +28,7 @@ limitations under the License. */
int main() {
paddle_mobile::PaddleMobile<paddle_mobile::CPU> paddle_mobile;
paddle_mobile.SetThreadNum(8);
paddle_mobile.SetThreadNum(4);
Tensor aa, bb, cc;
auto aaptr = aa.mutable_data<float>({m, k});
auto bbptr = bb.mutable_data<float>({k, n});
......@@ -44,10 +44,12 @@ int main() {
ccptr[i] = 2;
}
Tensor aa_int8, bb_int8, cc_int8;
Tensor aa_int8, bb_int8, cc_int32, cc_int8;
auto aaptr_int8 = aa_int8.mutable_data<int8_t>({m, k});
auto bbptr_int8 = bb_int8.mutable_data<int8_t>({k, n});
auto ccptr_int8 = cc_int8.mutable_data<int32_t>({m, n});
auto ccptr_int32 = cc_int32.mutable_data<int32_t>({m, n});
auto ccptr_int8 = cc_int8.mutable_data<int8_t>({m, n});
int32_t* bias_data = new int32_t[m];
for (int i = 0; i < m * k; ++i) {
aaptr_int8[i] = static_cast<int8_t>(2);
......@@ -56,7 +58,11 @@ int main() {
bbptr_int8[i] = static_cast<int8_t>(2);
}
for (int i = 0; i < m * n; ++i) {
ccptr_int8[i] = static_cast<int32_t>(2);
ccptr_int32[i] = static_cast<int32_t>(2);
}
for (int i = 0; i < m; ++i) {
bias_data[i] = 2;
}
// float
......@@ -76,22 +82,41 @@ int main() {
auto time2 = time();
std::cout << "float gemm cost :" << time_diff(time1, time2) / 10 << "ms\n";
// int8_t
// int8_t without bias
// warm-up 10 times
for (int j = 0; j < 10; ++j) {
paddle_mobile::operators::math::matmul<int8_t>(
aa_int8, false, bb_int8, false, static_cast<int8_t>(1), &cc_int8,
static_cast<int8_t>(0), false, nullptr);
paddle_mobile::operators::math::matmul_int8(
aa_int8, false, bb_int8, false, static_cast<float>(1), &cc_int32,
static_cast<float>(0), false, nullptr);
}
auto time3 = time();
for (int j = 0; j < 10; ++j) {
paddle_mobile::operators::math::matmul<int8_t>(
aa_int8, false, bb_int8, false, static_cast<int8_t>(1), &cc_int8,
static_cast<int8_t>(0), false, nullptr);
paddle_mobile::operators::math::matmul_int8(
aa_int8, false, bb_int8, false, static_cast<float>(1), &cc_int32,
static_cast<float>(0), false, nullptr);
}
auto time4 = time();
std::cout << "int8_t gemm cost :" << time_diff(time3, time4) / 10 << "ms\n";
// int8_t with bias&relu
// warm-up 10 times
for (int j = 0; j < 10; ++j) {
paddle_mobile::operators::math::matmul_int8(
aa_int8, false, bb_int8, false, static_cast<float>(1), &cc_int8,
static_cast<float>(0), true, &bias_data[0]);
}
auto time5 = time();
for (int j = 0; j < 10; ++j) {
paddle_mobile::operators::math::matmul_int8(
aa_int8, false, bb_int8, false, static_cast<float>(1), &cc_int8,
static_cast<float>(0), true, &bias_data[0]);
}
auto time6 = time();
std::cout << "int8_t gemm_with_bias_relu cost :"
<< time_diff(time5, time6) / 10 << "ms\n";
delete[] bias_data;
return 0;
}
......@@ -249,7 +249,9 @@ if(NOT FOUND_MATCH)
set(SUM_OP ON)
set(QUANT_OP ON)
set(DEQUANT_OP ON)
set(FUSION_DEQUANT_ADD_BN_RELU ON)
set(FUSION_DEQUANT_ADD_BN_OP ON)
set(FUSION_DEQUANT_BN_RELU_OP ON)
set(FUSION_DEQUANT_ADD_BN_RELU_OP ON)
endif()
# option(BATCHNORM_OP "" ON)
......@@ -451,10 +453,17 @@ endif()
if (DEQUANT_OP)
add_definitions(-DDEQUANT_OP)
endif()
if (FUSION_DEQUANT_ADD_BN_RELU)
if (FUSION_DEQUANT_ADD_BN_OP)
add_definitions(-DFUSION_DEQUANT_ADD_BN_OP)
endif()
if (FUSION_DEQUANT_BN_RELU_OP)
add_definitions(-DFUSION_DEQUANT_BN_RELU_OP)
endif()
if (FUSION_DEQUANT_ADD_BN_RELU_OP)
add_definitions(-DFUSION_DEQUANT_ADD_BN_RELU_OP)
endif()
if (TANH_OP)
add_definitions(-DTANH_OP)
endif()
......@@ -467,3 +476,4 @@ endif()
if (FUSION_DECONVADDRELU_OP)
add_definitions(-DFUSION_DECONVADDRELU_OP)
endif()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册