Commit 8ea3708c authored by: H hjchen2

Merge depthwise_conv2d and conv2d kernel, add some dequant fusion kernels

Parent e54bf8c5
@@ -71,6 +71,8 @@ const char *G_OP_TYPE_SUM = "sum";
 const char *G_OP_TYPE_QUANTIZE = "quantize";
 const char *G_OP_TYPE_DEQUANTIZE = "dequantize";
+const char *G_OP_TYPE_FUSION_DEQUANT_ADD_BN = "fusion_dequant_add_bn";
+const char *G_OP_TYPE_FUSION_DEQUANT_BN_RELU = "fusion_dequant_bn_relu";
 const char *G_OP_TYPE_FUSION_DEQUANT_ADD_BN_RELU = "fusion_dequant_add_bn_relu";
 const char *G_OP_TYPE_TANH = "tanh";
@@ -136,6 +138,8 @@ std::unordered_map<
     {G_OP_TYPE_ELEMENTWISE_MUL, {{"X", "Y"}, {"Out"}}},
     {G_OP_TYPE_QUANTIZE, {{"X"}, {"Out", "OutScale"}}},
     {G_OP_TYPE_DEQUANTIZE, {{"X", "Scale"}, {"Out"}}},
+    {G_OP_TYPE_FUSION_DEQUANT_ADD_BN, {{"X", "Scale"}, {"Y"}}},
+    {G_OP_TYPE_FUSION_DEQUANT_BN_RELU, {{"X", "Scale"}, {"Out"}}},
     {G_OP_TYPE_FUSION_DEQUANT_ADD_BN_RELU, {{"X", "Scale"}, {"Out"}}},
     {G_OP_TYPE_TANH, {{"X"}, {"Out"}}},
     {G_OP_TYPE_FUSION_DECONV_RELU, {{"Input"}, {"Out"}}},
......
@@ -138,6 +138,8 @@ extern const char *G_OP_TYPE_ELEMENTWISE_MUL;
 extern const char *G_OP_TYPE_QUANTIZE;
 extern const char *G_OP_TYPE_DEQUANTIZE;
+extern const char *G_OP_TYPE_FUSION_DEQUANT_ADD_BN;
+extern const char *G_OP_TYPE_FUSION_DEQUANT_BN_RELU;
 extern const char *G_OP_TYPE_FUSION_DEQUANT_ADD_BN_RELU;
 extern const char *G_OP_TYPE_TANH;
......
@@ -233,6 +233,14 @@ LOAD_OP1(quantize, CPU);
 #ifdef DEQUANT_OP
 LOAD_OP1(dequantize, CPU);
 #endif
+#ifdef FUSION_DEQUANT_ADD_BN_OP
+LOAD_OP1(fusion_dequant_add_bn, CPU);
+LOAD_FUSION_MATCHER(fusion_dequant_add_bn);
+#endif
+#ifdef FUSION_DEQUANT_BN_RELU_OP
+LOAD_OP1(fusion_dequant_bn_relu, CPU);
+LOAD_FUSION_MATCHER(fusion_dequant_bn_relu);
+#endif
 #ifdef FUSION_DEQUANT_ADD_BN_RELU_OP
 LOAD_OP1(fusion_dequant_add_bn_relu, CPU);
 LOAD_FUSION_MATCHER(fusion_dequant_add_bn_relu);
......
@@ -18,7 +18,7 @@ limitations under the License. */
 #include <string>
 #include "framework/operator.h"
-#include "operators/kernel/depthwise_conv_kernel.h"
+#include "operators/kernel/conv_kernel.h"
 namespace paddle_mobile {
 namespace operators {
@@ -26,19 +26,16 @@ namespace operators {
 template <typename DeviceType, typename T>
 class DepthwiseConvOp : public framework::OperatorWithKernel<
                             DeviceType, ConvParam<DeviceType>,
-                            operators::DepthwiseConvKernel<DeviceType, T>> {
+                            operators::ConvKernel<DeviceType, T>> {
  public:
  DepthwiseConvOp(const std::string &type, const VariableNameMap &inputs,
                  const VariableNameMap &outputs,
                  const framework::AttributeMap &attrs,
                  std::shared_ptr<framework::Scope> scope)
-      : framework::OperatorWithKernel<
-            DeviceType, ConvParam<DeviceType>,
-            operators::DepthwiseConvKernel<DeviceType, T>>(
+      : framework::OperatorWithKernel<DeviceType, ConvParam<DeviceType>,
+                                      operators::ConvKernel<DeviceType, T>>(
             type, inputs, outputs, attrs, scope) {}
  void InferShape() const override;
- private:
 };
 } // namespace operators
......
@@ -12,27 +12,27 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */
-#ifdef DEPTHWISECONV_OP
-#include "operators/kernel/depthwise_conv_kernel.h"
-#include "operators/kernel/central-arm-func/depthwise_conv_arm_func.h"
+#ifdef FUSION_DEQUANT_ADD_BN_OP
+#include "operators/fusion_dequant_add_bn_op.h"
 namespace paddle_mobile {
 namespace operators {
-template <>
-bool DepthwiseConvKernel<CPU, float>::Init(ConvParam<CPU> *param) {
-  return true;
-}
-template <>
-void DepthwiseConvKernel<CPU, float>::Compute(const ConvParam<CPU> &param) {
-  DepthwiseConvCompute<float>(param);
-}
-template class DepthwiseConvKernel<CPU, float>;
+template <typename Dtype, typename T>
+void FusionDequantAddBNOp<Dtype, T>::InferShape() const {
+  const auto& input_dims = this->param_.input_->dims();
+  this->param_.output_->Resize(input_dims);
+}
 } // namespace operators
 } // namespace paddle_mobile
+namespace ops = paddle_mobile::operators;
+REGISTER_FUSION_MATCHER(fusion_dequant_add_bn, ops::FusionDequantAddBNMatcher);
+#ifdef PADDLE_MOBILE_CPU
+REGISTER_OPERATOR_CPU(fusion_dequant_add_bn, ops::FusionDequantAddBNOp);
+#endif
 #endif
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifdef FUSION_DEQUANT_ADD_BN_OP
#pragma once
#include <string>
#include <vector>
#include "framework/operator.h"
#include "framework/program/program-optimize/fusion_op_register.h"
#include "operators/kernel/dequant_add_bn_kernel.h"
#include "operators/op_param.h"
namespace paddle_mobile {
namespace operators {
class FusionDequantAddBNMatcher : public framework::FusionOpMatcher {
public:
FusionDequantAddBNMatcher() {
node_ = framework::Node(G_OP_TYPE_DEQUANTIZE);
node_ > std::make_shared<framework::Node>(G_OP_TYPE_ELEMENTWISE_ADD) >
std::make_shared<framework::Node>(G_OP_TYPE_BATCHNORM);
}
void FolderNodes(
framework::Node *node,
std::vector<std::shared_ptr<framework::Node>> *removed_nodes) {
node->Folder(node_.Depth(), Type(),
{{G_OP_TYPE_ELEMENTWISE_ADD, {{"Y", "Y"}}},
{G_OP_TYPE_BATCHNORM,
{{"Scale", "BNScale"},
{"Mean", "BNMean"},
{"Bias", "BNBias"},
{"Variance", "BNVariance"}}}},
removed_nodes);
}
std::string Type() { return G_OP_TYPE_FUSION_DEQUANT_ADD_BN; }
};
template <typename DeviceType, typename T>
class FusionDequantAddBNOp
: public framework::OperatorWithKernel<
DeviceType, FusionDequantAddBNParam<DeviceType>,
operators::FusionDequantAddBNKernel<DeviceType, T>> {
public:
FusionDequantAddBNOp(const std::string &type, const VariableNameMap &inputs,
const VariableNameMap &outputs,
const framework::AttributeMap &attrs,
std::shared_ptr<framework::Scope> scope)
: framework::OperatorWithKernel<
DeviceType, FusionDequantAddBNParam<DeviceType>,
operators::FusionDequantAddBNKernel<DeviceType, T>>(
type, inputs, outputs, attrs, scope) {}
// inference output shape
void InferShape() const override;
};
} // namespace operators
} // namespace paddle_mobile
#endif
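For reference, the new fused op collapses dequantize -> elementwise_add -> batch_norm into a single per-channel affine transform. A minimal scalar sketch of the math the CPU kernel folds at Init() time (function and parameter names here are illustrative, not part of the library):

// Scalar reference for one output channel:
//   dequant:  x_f = x_q * activation_scale / weight_scale
//   add:      x_f += bias
//   bn:       y   = bn_scale * (x_f - mean) / sqrt(var + eps) + bn_bias
// which folds into y = scale * x_q + new_bias, as computed in the kernel code below.
#include <cmath>
#include <cstdint>

float fused_dequant_add_bn_ref(int32_t x_q, float activation_scale,
                               float weight_scale, float bias, float bn_scale,
                               float bn_bias, float mean, float var,
                               float eps) {
  float inv_std = bn_scale / std::sqrt(var + eps);          // folded scale
  float scale = inv_std * activation_scale / weight_scale;  // applied to x_q
  float new_bias = inv_std * (bias - mean) + bn_bias;       // folded bias
  return scale * static_cast<float>(x_q) + new_bias;
}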
@@ -20,7 +20,7 @@ limitations under the License. */
 #include <vector>
 #include "framework/operator.h"
 #include "framework/program/program-optimize/fusion_op_register.h"
-#include "operators/kernel/dequant_add_bn_relu_kernel.h"
+#include "operators/kernel/dequant_bn_relu_kernel.h"
 #include "operators/op_param.h"
 namespace paddle_mobile {
......
@@ -12,29 +12,28 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */
-#ifdef DEPTHWISECONV_OP
-#pragma once
-#include "framework/operator.h"
-#include "operators/math/im2col.h"
-#include "operators/math/math_function.h"
-#include "operators/math/vol2col.h"
-#include "operators/op_param.h"
+#ifdef FUSION_DEQUANT_BN_RELU_OP
+#include "operators/fusion_dequant_bn_relu_op.h"
 namespace paddle_mobile {
 namespace operators {
-using framework::OpKernelBase;
-template <typename DeviceType, typename T>
-class DepthwiseConvKernel
-    : public OpKernelBase<DeviceType, ConvParam<DeviceType>> {
- public:
-  void Compute(const ConvParam<DeviceType> &param);
-  bool Init(ConvParam<DeviceType> *param);
-};
+template <typename Dtype, typename T>
+void FusionDequantBNReluOp<Dtype, T>::InferShape() const {
+  const auto& input_dims = this->param_.input_->dims();
+  this->param_.output_->Resize(input_dims);
+}
 } // namespace operators
 } // namespace paddle_mobile
+namespace ops = paddle_mobile::operators;
+REGISTER_FUSION_MATCHER(fusion_dequant_bn_relu,
+                        ops::FusionDequantBNReluMatcher);
+#ifdef PADDLE_MOBILE_CPU
+REGISTER_OPERATOR_CPU(fusion_dequant_bn_relu, ops::FusionDequantBNReluOp);
+#endif
 #endif
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifdef FUSION_DEQUANT_BN_RELU_OP
#pragma once
#include <string>
#include <vector>
#include "framework/operator.h"
#include "framework/program/program-optimize/fusion_op_register.h"
#include "operators/kernel/dequant_bn_relu_kernel.h"
#include "operators/op_param.h"
namespace paddle_mobile {
namespace operators {
class FusionDequantBNReluMatcher : public framework::FusionOpMatcher {
public:
FusionDequantBNReluMatcher() {
node_ = framework::Node(G_OP_TYPE_DEQUANTIZE);
node_ > std::make_shared<framework::Node>(G_OP_TYPE_BATCHNORM) >
std::make_shared<framework::Node>(G_OP_TYPE_RELU);
}
void FolderNodes(
framework::Node *node,
std::vector<std::shared_ptr<framework::Node>> *removed_nodes) {
node->Folder(node_.Depth(), Type(),
{{G_OP_TYPE_BATCHNORM,
{{"Scale", "BNScale"},
{"Mean", "BNMean"},
{"Bias", "BNBias"},
{"Variance", "BNVariance"}}}},
removed_nodes);
}
std::string Type() { return G_OP_TYPE_FUSION_DEQUANT_BN_RELU; }
};
template <typename DeviceType, typename T>
class FusionDequantBNReluOp
: public framework::OperatorWithKernel<
DeviceType, FusionDequantBNReluParam<DeviceType>,
operators::FusionDequantBNReluKernel<DeviceType, T>> {
public:
FusionDequantBNReluOp(const std::string &type, const VariableNameMap &inputs,
const VariableNameMap &outputs,
const framework::AttributeMap &attrs,
std::shared_ptr<framework::Scope> scope)
: framework::OperatorWithKernel<
DeviceType, FusionDequantBNReluParam<DeviceType>,
operators::FusionDequantBNReluKernel<DeviceType, T>>(
type, inputs, outputs, attrs, scope) {}
// inference output shape
void InferShape() const override;
};
} // namespace operators
} // namespace paddle_mobile
#endif
@@ -22,33 +22,35 @@ namespace operators {
 template <>
 bool ConvKernel<CPU, float>::Init(ConvParam<CPU> *param) {
+  bool conv3x3 = param->Filter()->dims()[2] == param->Filter()->dims()[3] &&
+                 param->Filter()->dims()[2] == 3;
+  bool depth3x3 = conv3x3 && param->Groups() == param->Input()->dims()[1] &&
+                  param->Input()->dims()[1] == param->Output()->dims()[1];
   if (param->Filter()->type() == typeid(int8_t)) {
-    if (param->Groups() == param->Input()->dims()[1] &&
-        param->Input()->dims()[1] == param->Output()->dims()[1] &&
-        param->Filter()->dims()[2] == param->Filter()->dims()[3] &&
-        param->Filter()->dims()[2] == 3 && param->Strides()[0] < 3 &&
+    if (depth3x3 && param->Strides()[0] < 3 &&
         param->Strides()[0] == param->Strides()[1]) {
       param->ExecMode() = ConvParam<CPU>::EXEC_DEPTHWISE3x3_INT8;
     } else {
       param->ExecMode() = ConvParam<CPU>::EXEC_GEMM_INT8;
     }
   } else {
-    if (param->Groups() == param->Input()->dims()[1] &&
-        param->Input()->dims()[1] == param->Output()->dims()[1] &&
-        param->Filter()->dims()[2] == param->Filter()->dims()[3] &&
-        param->Filter()->dims()[2] == 3 && param->Strides()[0] == 1) {
+    if (depth3x3 && param->Strides()[0] == param->Strides()[1] &&
+        param->Strides()[0] == 1 && param->Paddings()[0] == 1 &&
+        param->Paddings()[0] == param->Paddings()[1]) {
       param->ExecMode() = ConvParam<CPU>::EXEC_DEPTHWISE3x3S1P1_FLOAT;
-    } else if (param->Groups() == param->Input()->dims()[1] &&
-               param->Input()->dims()[1] == param->Output()->dims()[1] &&
-               param->Filter()->dims()[2] == param->Filter()->dims()[3] &&
-               param->Filter()->dims()[2] == 3) {
-      param->ExecMode() = ConvParam<CPU>::EXEC_DEPTHWISE3x3_FLOAT;
+    } else if (depth3x3 && param->Strides()[0] == param->Strides()[1] &&
+               param->Strides()[0] == 2 && param->Paddings()[0] == 0 &&
+               param->Paddings()[0] == param->Paddings()[1]) {
+      param->ExecMode() = ConvParam<CPU>::EXEC_DEPTHWISE3x3S2P0_FLOAT;
+    } else if (depth3x3 && param->Strides()[0] == param->Strides()[1] &&
+               param->Strides()[0] == 2 && param->Paddings()[0] == 1 &&
+               param->Paddings()[0] == param->Paddings()[1]) {
+      param->ExecMode() = ConvParam<CPU>::EXEC_DEPTHWISE3x3S2P1_FLOAT;
 #ifndef __aarch64__
-    } else if (param->Filter()->dims()[2] == param->Filter()->dims()[3] &&
-               param->Strides()[0] == param->Strides()[1] &&
+    } else if (conv3x3 && param->Strides()[0] == param->Strides()[1] &&
                param->Dilations()[0] == param->Dilations()[1] &&
-               param->Filter()->dims()[2] == 3 && param->Strides()[0] == 1 &&
-               param->Dilations()[0] == 1 && param->Output()->dims()[1] >= 16 &&
+               param->Strides()[0] == 1 && param->Dilations()[0] == 1 &&
+               param->Output()->dims()[1] >= 16 &&
                param->Input()->dims()[1] >= 16 &&
                param->Input()->dims()[2] <= 140 /* refered from ncnn */) {
       param->ExecMode() = ConvParam<CPU>::EXEC_WINOGRAD3X3_FLOAT;
@@ -78,9 +80,13 @@ void ConvKernel<CPU, float>::Compute(const ConvParam<CPU> &param) {
       math::DepthwiseConv3x3s1p1(param.Input(), param.Filter(), param.Output(),
                                  nullptr, false);
       break;
-    case ConvParam<CPU>::EXEC_DEPTHWISE3x3_FLOAT:
-      math::DepthwiseConv3x3(param.Input(), param.Strides(), param.Paddings(),
-                             param.Filter(), nullptr, param.Output(), false);
+    case ConvParam<CPU>::EXEC_DEPTHWISE3x3S2P1_FLOAT:
+      math::DepthwiseConv3x3s2p1v2(param.Input(), param.Filter(),
+                                   param.Output(), nullptr, false);
+      break;
+    case ConvParam<CPU>::EXEC_DEPTHWISE3x3S2P0_FLOAT:
+      math::DepthwiseConv3x3s2p0(param.Input(), param.Filter(), param.Output(),
+                                 nullptr, false);
       break;
     case ConvParam<CPU>::EXEC_WINOGRAD3X3_FLOAT:
       WinogradConv3x3<8, 3>(param);
......
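The new conv3x3/depth3x3 booleans above are what gate the specialized kernels. A small hedged restatement of those predicates (hypothetical struct, for illustration only):

#include <cstdint>

// Hypothetical shape summary; the real code reads these values from ConvParam.
struct ConvShape {
  int64_t filter_h, filter_w, groups, in_channels, out_channels;
};

bool IsConv3x3(const ConvShape &s) {
  return s.filter_h == s.filter_w && s.filter_h == 3;
}

bool IsDepthwise3x3(const ConvShape &s) {
  // depthwise: one group per channel, same channel count in and out
  return IsConv3x3(s) && s.groups == s.in_channels &&
         s.in_channels == s.out_channels;
}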
@@ -12,9 +12,9 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */
-#ifdef FUSION_DEQUANT_ADD_BN_RELU_OP
-#include "operators/kernel/dequant_add_bn_relu_kernel.h"
+#ifdef FUSION_DEQUANT_ADD_BN_OP
+#include "operators/kernel/dequant_add_bn_kernel.h"
 #include <cmath>
 #if defined(__ARM_NEON__) || defined(__ARM_NEON)
 #include <arm_neon.h>
@@ -24,8 +24,8 @@ namespace paddle_mobile {
 namespace operators {
 template <>
-bool FusionDequantAddBNReluKernel<CPU, float>::Init(
-    FusionDequantAddBNReluParam<CPU> *param) {
+bool FusionDequantAddBNKernel<CPU, float>::Init(
+    FusionDequantAddBNParam<CPU> *param) {
   // elementwise add params
   const Tensor *bias = param->bias_;
   // batch norm params
@@ -49,8 +49,8 @@ bool FusionDequantAddBNReluKernel<CPU, float>::Init(
 }
 template <>
-void FusionDequantAddBNReluKernel<CPU, float>::Compute(
-    const FusionDequantAddBNReluParam<CPU> &param) {
+void FusionDequantAddBNKernel<CPU, float>::Compute(
+    const FusionDequantAddBNParam<CPU> &param) {
   const int32_t *input = param.input_->data<int32_t>();
   const float *bn_scale = param.bn_scale_->data<float>();
   const float *bn_bias = param.bn_bias_->data<float>();
@@ -78,7 +78,6 @@ void FusionDequantAddBNReluKernel<CPU, float>::Compute(
       remain = spatial_size & 0xF;
       float32x4_t __scale = vdupq_n_f32(scale);
       float32x4_t __bias = vdupq_n_f32(bias);
-      float32x4_t __zero = vdupq_n_f32(0.f);
       for (int k = 0; k < loop; ++k, x += 16, y += 16) {
         int32x4_t r0 = vld1q_s32(x);
@@ -93,10 +92,6 @@ void FusionDequantAddBNReluKernel<CPU, float>::Compute(
         f1 = vmlaq_f32(__bias, __scale, f1);
         f2 = vmlaq_f32(__bias, __scale, f2);
         f3 = vmlaq_f32(__bias, __scale, f3);
-        f0 = vmaxq_f32(__zero, f0);
-        f1 = vmaxq_f32(__zero, f1);
-        f2 = vmaxq_f32(__zero, f2);
-        f3 = vmaxq_f32(__zero, f3);
         vst1q_f32(y, f0);
         vst1q_f32(y + 4, f1);
         vst1q_f32(y + 8, f2);
@@ -104,7 +99,7 @@ void FusionDequantAddBNReluKernel<CPU, float>::Compute(
       }
 #endif  // __ARM_NEON__
       for (int k = 0; k < remain; ++k) {
-        y[k] = std::max(scale * x[k] + bias, 0.f);
+        y[k] = scale * x[k] + bias;
       }
     }
   }
@@ -113,4 +108,4 @@ void FusionDequantAddBNReluKernel<CPU, float>::Compute(
 } // namespace operators
 } // namespace paddle_mobile
-#endif  // FUSION_DEQUANT_ADD_BN_RELU_OP
+#endif  // FUSION_DEQUANT_ADD_BN_OP
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "operators/kernel/dequant_bn_relu_kernel.h"
#include <cmath>
#if defined(__ARM_NEON__) || defined(__ARM_NEON)
#include <arm_neon.h>
#endif
namespace paddle_mobile {
namespace operators {
#if defined(FUSION_DEQUANT_BN_RELU_OP) || defined(FUSION_DEQUANT_ADD_BN_RELU_OP)
void DequantBNReluCompute(const FusionDequantBNParam<CPU> *param) {
const int32_t *input = param->input_->data<int32_t>();
const float *bn_scale = param->bn_scale_->data<float>();
const float *bn_bias = param->bn_bias_->data<float>();
// dequantize params
const float activation_scale = param->activation_scale_->data<float>()[0];
const float weight_scale = param->weight_scale_;
const float dequant_scale = activation_scale / weight_scale;
float *output = param->output_->mutable_data<float>();
int batch_size = param->input_->dims()[0];
int channels = param->input_->dims()[1];
size_t spatial_size = param->input_->dims()[2] * param->input_->dims()[3];
#pragma omp parallel for collapse(2)
for (int batch = 0; batch < batch_size; ++batch) {
for (int c = 0; c < channels; ++c) {
float scale = bn_scale[c] * dequant_scale;
float bias = bn_bias[c];
size_t offset = (batch * channels + c) * spatial_size;
const int32_t *x = input + offset;
float *y = output + offset;
size_t remain = spatial_size;
#if defined(__ARM_NEON__) || defined(__ARM_NEON)
int loop = spatial_size >> 4;
remain = spatial_size & 0xF;
float32x4_t __scale = vdupq_n_f32(scale);
float32x4_t __bias = vdupq_n_f32(bias);
float32x4_t __zero = vdupq_n_f32(0.f);
for (int k = 0; k < loop; ++k, x += 16, y += 16) {
int32x4_t r0 = vld1q_s32(x);
int32x4_t r1 = vld1q_s32(x + 4);
int32x4_t r2 = vld1q_s32(x + 8);
int32x4_t r3 = vld1q_s32(x + 12);
float32x4_t f0 = vcvtq_f32_s32(r0);
float32x4_t f1 = vcvtq_f32_s32(r1);
float32x4_t f2 = vcvtq_f32_s32(r2);
float32x4_t f3 = vcvtq_f32_s32(r3);
f0 = vmlaq_f32(__bias, __scale, f0);
f1 = vmlaq_f32(__bias, __scale, f1);
f2 = vmlaq_f32(__bias, __scale, f2);
f3 = vmlaq_f32(__bias, __scale, f3);
f0 = vmaxq_f32(__zero, f0);
f1 = vmaxq_f32(__zero, f1);
f2 = vmaxq_f32(__zero, f2);
f3 = vmaxq_f32(__zero, f3);
vst1q_f32(y, f0);
vst1q_f32(y + 4, f1);
vst1q_f32(y + 8, f2);
vst1q_f32(y + 12, f3);
}
#endif // __ARM_NEON__
for (int k = 0; k < remain; ++k) {
y[k] = std::max(scale * x[k] + bias, 0.f);
}
}
}
}
#endif
#ifdef FUSION_DEQUANT_BN_RELU_OP
template <>
bool FusionDequantBNReluKernel<CPU, float>::Init(
FusionDequantBNReluParam<CPU> *param) {
// batch norm params
const Tensor *bn_mean = param->bn_mean_;
const Tensor *bn_variance = param->bn_variance_;
Tensor *bn_scale = param->bn_scale_;
Tensor *bn_bias = param->bn_bias_;
const float epsilon = param->epsilon_;
const float *mean_ptr = bn_mean->data<float>();
const float *var_ptr = bn_variance->data<float>();
float *bn_scale_ptr = bn_scale->mutable_data<float>();
float *bn_bias_ptr = bn_bias->mutable_data<float>();
for (int c = 0; c < bn_scale->numel(); ++c) {
float inv_scale = bn_scale_ptr[c] / (std::sqrt(var_ptr[c] + epsilon));
bn_scale_ptr[c] = inv_scale;
bn_bias_ptr[c] = bn_bias_ptr[c] - inv_scale * mean_ptr[c];
}
return true;
}
template <>
void FusionDequantBNReluKernel<CPU, float>::Compute(
const FusionDequantBNReluParam<CPU> &param) {
DequantBNReluCompute(&param);
}
#endif // FUSION_DEQUANT_BN_RELU_OP
#ifdef FUSION_DEQUANT_ADD_BN_RELU_OP
template <>
bool FusionDequantAddBNReluKernel<CPU, float>::Init(
FusionDequantAddBNReluParam<CPU> *param) {
// elementwise add params
const Tensor *bias = param->bias_;
// batch norm params
const Tensor *bn_mean = param->bn_mean_;
const Tensor *bn_variance = param->bn_variance_;
Tensor *bn_scale = param->bn_scale_;
Tensor *bn_bias = param->bn_bias_;
const float epsilon = param->epsilon_;
const float *bias_ptr = bias->data<float>();
const float *mean_ptr = bn_mean->data<float>();
const float *var_ptr = bn_variance->data<float>();
float *bn_scale_ptr = bn_scale->mutable_data<float>();
float *bn_bias_ptr = bn_bias->mutable_data<float>();
for (int c = 0; c < bn_scale->numel(); ++c) {
float inv_scale = bn_scale_ptr[c] / (std::sqrt(var_ptr[c] + epsilon));
bn_scale_ptr[c] = inv_scale;
bn_bias_ptr[c] = inv_scale * (bias_ptr[c] - mean_ptr[c]) + bn_bias_ptr[c];
}
return true;
}
template <>
void FusionDequantAddBNReluKernel<CPU, float>::Compute(
const FusionDequantAddBNReluParam<CPU> &param) {
DequantBNReluCompute(&param);
}
#endif // FUSION_DEQUANT_ADD_BN_RELU_OP
} // namespace operators
} // namespace paddle_mobile
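The Init() functions above fold the batch-norm statistics (and, for the Add variants, the elementwise bias) into one per-channel scale/bias, so Compute() only applies an affine transform plus ReLU. A quick self-contained check of that folding on scalars (purely illustrative values and names):

#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstdint>

int main() {
  const float act_scale = 0.5f, weight_scale = 2.0f, eps = 1e-5f;
  const float add_bias = 0.1f, mean = 0.2f, var = 4.0f;
  const float gamma = 1.5f, beta = -0.3f;  // BN scale/bias for one channel
  const int32_t x_q = 12;

  // Unfused reference: dequantize, add bias, batch-norm, relu.
  float x_f = x_q * act_scale / weight_scale + add_bias;
  float ref = std::max(gamma * (x_f - mean) / std::sqrt(var + eps) + beta, 0.f);

  // Folded form, mirroring the kernels' Init()/Compute().
  float inv_scale = gamma / std::sqrt(var + eps);
  float scale = inv_scale * act_scale / weight_scale;
  float bias = inv_scale * (add_bias - mean) + beta;
  float fused = std::max(scale * x_q + bias, 0.f);

  assert(std::fabs(ref - fused) < 1e-5f);
  return 0;
}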
@@ -132,10 +132,10 @@ void ConvAddCompute(const FusionConvAddParam<CPU> &param) {
     //                    param.Output(), false);
     if (param.Paddings()[0] == 0) {
       math::DepthwiseConv3x3s2p0(param.Input(), param.Filter(), param.Output(),
-                                 *param.Bias(), true);
+                                 param.Bias(), true);
     } else {
       math::DepthwiseConv3x3s2p1v2(param.Input(), param.Filter(),
-                                   param.Output(), *param.Bias(), true);
+                                   param.Output(), param.Bias(), true);
     }
   } else {
     ConvAddBasic(param);
......
@@ -107,9 +107,15 @@ inline void GemmConv(const ConvParam<CPU> &param) {
       Tensor out_slice = out_batch.Slice(g * out_step, (g + 1) * out_step);
       Tensor filter_slice = filter.Slice(g * out_step, (g + 1) * out_step);
-      math::matmul<Itype>(filter_slice, false, col_matrix, false,
+      if (param.Input()->type() == typeid(int8_t)) {
+        math::matmul_int8(filter_slice, false, col_matrix, false,
                           static_cast<float>(1), &out_slice,
                           static_cast<float>(0));
+      } else {
+        math::matmul<float>(filter_slice, false, col_matrix, false,
+                            static_cast<float>(1), &out_slice,
+                            static_cast<float>(0));
+      }
     }
   }
 }
......
@@ -73,8 +73,8 @@ void MulCompute(const MulParam<CPU> &param) {
   }
   if (param.InputX()->type() == typeid(int8_t)) {
     out->mutable_data<int32_t>();
-    math::matmul<int8_t>(x_matrix, false, y_matrix, false,
-                         static_cast<int8_t>(1), out, static_cast<int8_t>(0));
+    math::matmul_int8(x_matrix, false, y_matrix, false, static_cast<float>(1),
+                      out, static_cast<float>(0));
   } else {
     out->mutable_data<float>();
......
@@ -14,7 +14,7 @@ limitations under the License. */
 #pragma once
-#ifdef FUSION_DEQUANT_ADD_BN_RELU_OP
+#ifdef FUSION_DEQUANT_ADD_BN_OP
 #include "framework/operator.h"
 #include "operators/op_param.h"
@@ -23,12 +23,12 @@ namespace paddle_mobile {
 namespace operators {
 template <typename DeviceType, typename T>
-class FusionDequantAddBNReluKernel
+class FusionDequantAddBNKernel
     : public framework::OpKernelBase<DeviceType,
-                                     FusionDequantAddBNReluParam<DeviceType>> {
+                                     FusionDequantAddBNParam<DeviceType>> {
  public:
-  void Compute(const FusionDequantAddBNReluParam<DeviceType> &param);
-  bool Init(FusionDequantAddBNReluParam<DeviceType> *param);
+  void Compute(const FusionDequantAddBNParam<DeviceType> &param);
+  bool Init(FusionDequantAddBNParam<DeviceType> *param);
 };
 } // namespace operators
......
@@ -12,42 +12,35 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */
-#ifdef DEPTHWISECONV_OP
 #pragma once
-#include <vector>
-#include "operators/kernel/central-arm-func/conv_arm_func.h"
-#include "operators/math/depthwise_conv3x3.h"
+#include "framework/operator.h"
 #include "operators/op_param.h"
 namespace paddle_mobile {
 namespace operators {
-template <typename P>
-void DepthwiseConvCompute(const ConvParam<CPU> &param) {
-  Tensor Bias;
-  Bias.mutable_data<float>({param.Groups()});
-  if (param.Groups() == param.Input()->dims()[1] &&
-      param.Filter()->dims()[2] == param.Filter()->dims()[3] &&
-      param.Filter()->dims()[2] == 3 && param.Strides()[0] == 1) {
-    math::DepthwiseConv3x3s1p1(param.Input(), param.Filter(), param.Output(),
-                               &Bias, false);
-  } else if (param.Groups() == param.Input()->dims()[1] &&
-             param.Input()->dims()[1] == param.Output()->dims()[1] &&
-             param.Filter()->dims()[2] == param.Filter()->dims()[3] &&
-             param.Filter()->dims()[2] == 3 && param.Strides()[0] == 2) {
-    // math::DepthwiseConv3x3(param.Input(), param.Strides(),
-    //                        param.Paddings(),
-    //                        param.Filter(), &Bias, param.Output(), false);
-    math::DepthwiseConv3x3s2p1v2(param.Input(), param.Filter(), param.Output(),
-                                 Bias, false);
-  } else {
-    GemmConv<float, float>(param);
-  }
-}
+#ifdef FUSION_DEQUANT_BN_RELU_OP
+template <typename DeviceType, typename T>
+class FusionDequantBNReluKernel
+    : public framework::OpKernelBase<DeviceType,
+                                     FusionDequantBNReluParam<DeviceType>> {
+ public:
+  void Compute(const FusionDequantBNReluParam<DeviceType> &param);
+  bool Init(FusionDequantBNReluParam<DeviceType> *param);
+};
+#endif
+#ifdef FUSION_DEQUANT_ADD_BN_RELU_OP
+template <typename DeviceType, typename T>
+class FusionDequantAddBNReluKernel
+    : public framework::OpKernelBase<DeviceType,
+                                     FusionDequantAddBNReluParam<DeviceType>> {
+ public:
+  void Compute(const FusionDequantAddBNReluParam<DeviceType> &param);
+  bool Init(FusionDequantAddBNReluParam<DeviceType> *param);
+};
+#endif
 } // namespace operators
 } // namespace paddle_mobile
-#endif
@@ -1272,13 +1272,13 @@ void DepthwiseConvAddBNRelu3x3s2p1(const framework::Tensor *input,
 void DepthwiseConv3x3s2p1v2(const framework::Tensor *input,
                             const framework::Tensor *filter,
-                            framework::Tensor *output, framework::Tensor bias,
+                            framework::Tensor *output, framework::Tensor *bias,
                             bool if_bias) {
 #if __ARM_NEON
   const float *input_data = input->data<float>();
   const float *filter_data = filter->data<float>();
   float *output_data = output->data<float>();
-  const float *bias_data = bias.data<float>();
+  const float *bias_data = bias->data<float>();
   const int in_h = static_cast<int>(input->dims()[2]);
   const int in_w = static_cast<int>(input->dims()[3]);
@@ -1905,7 +1905,7 @@ void DepthwiseConvAddBNRelu3x3s2p1v2(const framework::Tensor *input,
 void DepthwiseConv3x3s2p0(const framework::Tensor *input,
                           const framework::Tensor *filter,
-                          framework::Tensor *output, framework::Tensor bias,
+                          framework::Tensor *output, framework::Tensor *bias,
                           bool if_bias) {
 #if __ARM_NEON
@@ -1925,7 +1925,7 @@ void DepthwiseConv3x3s2p0(const framework::Tensor *input,
   for (int c = 0; c < input_channel; c++) {
     const float *filter_data = filter->data<float>() + c * 9;
     const float *input_data = input->data<float>() + c * inhxw;
-    const float *bias_data = bias.data<float>() + c;
+    const float *bias_data = bias->data<float>() + c;
     float *output_data = output->data<float>() + c * outhxw;
     float w00 = filter_data[0];
     float w01 = filter_data[1];
......
@@ -50,7 +50,7 @@ void DepthwiseConvAddBNRelu3x3s2p1(const framework::Tensor *input,
 void DepthwiseConv3x3s2p1v2(const framework::Tensor *input,
                             const framework::Tensor *filter,
-                            framework::Tensor *output, framework::Tensor bias,
+                            framework::Tensor *output, framework::Tensor *bias,
                             bool if_bias);
 void DepthwiseConvAddBNRelu3x3s2p1v2(const framework::Tensor *input,
@@ -62,7 +62,7 @@ void DepthwiseConvAddBNRelu3x3s2p1v2(const framework::Tensor *input,
 void DepthwiseConv3x3s2p0(const framework::Tensor *input,
                           const framework::Tensor *filter,
-                          framework::Tensor *output, framework::Tensor bias,
+                          framework::Tensor *output, framework::Tensor *bias,
                           bool if_bias);
 // TODO(hjchen2) need to be implemented
......
@@ -23,10 +23,12 @@ limitations under the License. */
 #if __aarch64__
 #define MR_INT8 4
+#define NR_INT8 2
 #define MR 6
 #define NR 16
 #else
 #define MR_INT8 4
+#define NR_INT8 2
 #define MR 6
 #define NR 8
 #endif
@@ -193,52 +195,58 @@ void PackMatrixB(int k, int n, int n_tail, const float *B, int ldb,
   // 8 bits int small block inner product
   void AddDot4x8(int32_t k, const int8_t *a, const int8_t *b, int32_t *c,
                  int32_t ldc);
+  void AddDot4x2(int32_t k, const int8_t *a, const int8_t *b, int32_t *c,
+                 int32_t ldc);
   void AddDot6x8(int32_t k, const int8_t *a, const int8_t *b, int32_t *c,
                  int32_t ldc);
   // 8 bits int inner product
-  void InnerKernelWithBias(int32_t mc, int32_t nc, int8_t alpha,
-                           const int8_t *a, const int8_t *b, int8_t beta,
-                           int32_t *c, int32_t *C, int32_t ldc, bool relu,
-                           int8_t *bias);
+  void InnerKernel(int32_t mc, int32_t nc, float alpha, const int8_t *a,
+                   const int8_t *b, float beta, int32_t *c, int32_t *C,
+                   int32_t ldc, bool relu);
+  void InnerKernelWithBias(int32_t mc, int32_t nc, float alpha, const int8_t *a,
+                           const int8_t *b, float beta, int32_t *c, int8_t *C,
+                           int32_t ldc, bool relu, int32_t *bias);
   // 8 bits int pack function
   void PackMatrixA_4r(int32_t m, int32_t k, int32_t m_tail, const int8_t *A,
                       int32_t lda, int8_t *buffer);
+  void PackMatrixA_4r_16(int32_t m, int32_t k, int32_t m_tail, const int8_t *A,
+                         int32_t lda, int8_t *buffer);
   void PackMatrixA_6r(int32_t m, int32_t k, int32_t m_tail, const int8_t *A,
                       int32_t lda, int8_t *buffer);
+  void PackMatrixB_2c_16(int32_t k, int32_t n, int32_t n_tail, const int8_t *B,
+                         int32_t ldb, int8_t *buffer);
   void PackMatrixB_8c(int32_t k, int32_t n, int32_t n_tail, const int8_t *B,
                       int32_t ldb, int8_t *buffer);
   void PackMatrixA_omp_4r(int32_t m, int32_t k, int32_t m_tail, const int8_t *A,
                           int32_t lda, int8_t *buffer);
   void PackMatrixB_omp_8c(int32_t k, int32_t n, int32_t n_tail, const int8_t *B,
                           int32_t ldb, int8_t *buffer);
+  void PackMatrixA_omp_4r_16(int32_t m, int32_t k, int32_t m_tail,
+                             const int8_t *A, int32_t lda, int8_t *buffer);
+  void PackMatrixB_omp_2c_16(int32_t k, int32_t n, int32_t n_tail,
+                             const int8_t *B, int32_t ldb, int8_t *buffer);
   // 8 bits int matrix product
-  void Sgemm(int32_t m, int32_t n, int32_t k, int8_t alpha, const int8_t *A,
-             int32_t lda, const int8_t *B, int32_t ldb, int8_t beta, int32_t *C,
-             int32_t ldc, bool relu, int8_t *bias);
-  void Sgemm_omp(int32_t m, int32_t n, int32_t k, int8_t alpha, const int8_t *A,
-                 int32_t lda, const int8_t *B, int32_t ldb, int8_t beta,
-                 int32_t *C, int32_t ldc, bool relu, int8_t *bias);
+  void Sgemm(int32_t m, int32_t n, int32_t k, float alpha, const int8_t *A,
+             int32_t lda, const int8_t *B, int32_t ldb, float beta, int32_t *C,
+             int32_t ldc, bool relu, int32_t *bias);
+  void Sgemm(int32_t m, int32_t n, int32_t k, float alpha, const int8_t *A,
+             int32_t lda, const int8_t *B, int32_t ldb, float beta, int8_t *C,
+             int32_t ldc, bool relu, int32_t *bias);
+  void Sgemm_omp(int32_t m, int32_t n, int32_t k, float alpha, const int8_t *A,
+                 int32_t lda, const int8_t *B, int32_t ldb, float beta,
+                 int32_t *C, int32_t ldc, bool relu, int32_t *bias);
   // 8 bits int write back
-  // C = alpha * A * B + beta * C
-  void WriteWithAlphaBeta(int32_t mc, int32_t nc, int32_t *c, int32_t *C,
-                          int32_t ldc);
   // C = A * B
   void WriteBasic(int32_t mc, int32_t nc, int32_t *c, int32_t *C, int32_t ldc);
-  // C = A * B + C
-  void WriteWithAdd(int32_t mc, int32_t nc, int32_t *c, int32_t *C,
-                    int32_t ldc);
-  // C = A * B + bias
-  void WriteWithAddV1(int32_t mc, int32_t nc, int32_t *c, int32_t *C,
-                      int32_t ldc, int8_t *bias);
-  // C = A * B + C, relu(C)
-  void WriteWithAddRelu(int32_t mc, int32_t nc, int32_t *c, int32_t *C,
-                        int32_t ldc);
-  // C = A * B + bias, relu(C)
-  void WriteWithAddReluV1(int32_t mc, int32_t nc, int32_t *c, int32_t *C,
-                          int32_t ldc, int8_t *bias);
+  // C = A * B + bias, scale * relu(C)
+  void WriteWithAddReluScale(int32_t mc, int32_t nc, int32_t *c, int8_t *C,
+                             int32_t ldc, int32_t *bias, float scale);
+  // C = A * B + bias, scale * C
+  void WriteWithAddScale(int32_t mc, int32_t nc, int32_t *c, int8_t *C,
+                         int32_t ldc, int32_t *bias, float scale);
  private:
   int MC = 0;
@@ -254,7 +262,7 @@ void PackMatrixB(int k, int n, int n_tail, const float *B, int ldb,
   // 8 bits int
   int8_t *packedA_int8;
   int8_t *packedB_int8;
-  int32_t *packedC_int8;
+  int32_t *packedC_int32;
   int8_t *zero_int8;
 };
......
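The int8 Sgemm now comes in two overloads: one accumulates into an int32 C without a bias, the other takes a per-row int32 bias and writes a scaled int8 C. A hedged calling sketch follows; treating alpha as the re-quantization scale for the int8-output overload is an assumption based on the WriteWithAddScale comment, not something this header states.

#include <cstdint>
#include "operators/math/gemm.h"

// Sketch only: A and B are already int8 row-major matrices, ld* equal their widths.
void int8_gemm_example(paddle_mobile::operators::math::Gemm &gemm, int32_t m,
                       int32_t n, int32_t k, const int8_t *A, const int8_t *B,
                       int32_t *C32, int8_t *C8, int32_t *bias, float scale,
                       bool relu) {
  if (bias == nullptr) {
    // No bias: raw int32 accumulators.
    gemm.Sgemm(m, n, k, 1.f, A, k, B, n, 0.f, C32, n, relu, nullptr);
  } else {
    // With bias: int32 bias added per row, result written back as int8
    // (scale assumed to be the re-quantization factor).
    gemm.Sgemm(m, n, k, scale, A, k, B, n, 0.f, C8, n, relu, bias);
  }
}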
This diff is collapsed.
@@ -28,10 +28,10 @@ namespace operators {
 namespace math {
 // 8 bits int matrix product (m*k x k*n)
-void Gemm::Sgemm_omp(int32_t m, int32_t n, int32_t k, int8_t alpha,
+void Gemm::Sgemm_omp(int32_t m, int32_t n, int32_t k, float alpha,
                      const int8_t *A, int32_t lda, const int8_t *B, int32_t ldb,
-                     int8_t beta, int32_t *C, int32_t ldc, bool relu,
-                     int8_t *bias) {
+                     float beta, int32_t *C, int32_t ldc, bool relu,
+                     int32_t *bias) {
 #ifdef _OPENMP
   int32_t max_threads = omp_get_max_threads();
 #else
@@ -39,10 +39,11 @@ void Gemm::Sgemm_omp(int32_t m, int32_t n, int32_t k, int8_t alpha,
 #endif
   int32_t L1 = 64 / max_threads * 1024;
-  KC = k;
+  const int32_t k_complete = (k + 15) - ((k + 15) & 15);
+  KC = k_complete;
   zero_int8 =
-      static_cast<int8_t *>(paddle_mobile::memory::Alloc(sizeof(int8_t) * KC));
-  memset(static_cast<void *>(zero_int8), 0, sizeof(int8_t) * KC);
+      static_cast<int8_t *>(paddle_mobile::memory::Alloc(sizeof(int8_t) * k));
+  memset(static_cast<void *>(zero_int8), 0, sizeof(int8_t) * k);
   if (m > n) {
     // partition A into blocks
     MC = L1 / (KC * sizeof(int8_t));
@@ -54,14 +55,14 @@ void Gemm::Sgemm_omp(int32_t m, int32_t n, int32_t k, int8_t alpha,
       MC = (MC + MR_INT8 - 1) / MR_INT8 * MR_INT8;
     }
     // pad B
-    NC = (n + NR - 1) / NR * NR;
+    NC = (n + NR_INT8 - 1) / NR_INT8 * NR_INT8;
     packedB_int8 = static_cast<int8_t *>(
         paddle_mobile::memory::Alloc(sizeof(int8_t) * KC * NC));
 #if __aarch64__
     // TODO(wzzju)
 #else
-    PackMatrixB_omp_8c(KC, n, n % NR, B, ldb, packedB_int8);
+    PackMatrixB_omp_2c_16(k, n, n % NR_INT8, B, ldb, packedB_int8);
 #endif
     packedA_int8 = static_cast<int8_t *>(
         paddle_mobile::memory::Alloc(sizeof(int8_t) * MC * KC * max_threads));
@@ -69,11 +70,11 @@ void Gemm::Sgemm_omp(int32_t m, int32_t n, int32_t k, int8_t alpha,
     // partition B into blocks
     NC = L1 / (KC * sizeof(int8_t));
     if (NC == 0) {
-      NC = NR;
+      NC = NR_INT8;
     } else {
       int32_t nblock_num = (n + NC - 1) / NC;
       NC = (n + nblock_num - 1) / nblock_num;
-      NC = (NC + NR - 1) / NR * NR;
+      NC = (NC + NR_INT8 - 1) / NR_INT8 * NR_INT8;
     }
     // pad A
     MC = (m + MR_INT8 - 1) / MR_INT8 * MR_INT8;
@@ -83,12 +84,12 @@ void Gemm::Sgemm_omp(int32_t m, int32_t n, int32_t k, int8_t alpha,
 #if __aarch64__
     // TODO(wzzju)
 #else
-    PackMatrixA_omp_4r(m, KC, m % MR_INT8, A, lda, packedA_int8);
+    PackMatrixA_omp_4r_16(m, k, m % MR_INT8, A, lda, packedA_int8);
 #endif
     packedB_int8 = static_cast<int8_t *>(
         paddle_mobile::memory::Alloc(sizeof(int8_t) * KC * NC * max_threads));
   }
-  packedC_int8 = static_cast<int32_t *>(
+  packedC_int32 = static_cast<int32_t *>(
       paddle_mobile::memory::Alloc(sizeof(int32_t) * MC * NC * max_threads));
   if (m > n) {
@@ -103,14 +104,19 @@ void Gemm::Sgemm_omp(int32_t m, int32_t n, int32_t k, int8_t alpha,
       int32_t mc;
       mc = s_min(m - i, MC);
       int8_t *local_A = packedA_int8 + MC * KC * local_threads;
-      int32_t *local_C = packedC_int8 + MC * NC * local_threads;
+      int32_t *local_C = packedC_int32 + MC * NC * local_threads;
 #if __aarch64__
       // TODO(wzzju)
 #else
-      PackMatrixA_4r(mc, KC, mc % MR_INT8, &A(i, 0), lda, local_A);
+      PackMatrixA_4r_16(mc, k, mc % MR_INT8, &A(i, 0), lda, local_A);
 #endif
-      InnerKernelWithBias(mc, n, alpha, local_A, packedB_int8, beta, local_C,
-                          &C(i, 0), ldc, relu, bias + i);
+      // InnerKernelWithBias(mc, n, alpha, local_A, packedB_int8, beta,
+      // local_C,
+      //                     &C(i, 0), ldc, relu, bias + i);
+      if (bias == nullptr) {
+        InnerKernel(mc, n, alpha, local_A, packedB_int8, beta, local_C,
+                    &C(i, 0), ldc, relu);
+      }
     }
   } else {
 #pragma omp parallel for
@@ -123,20 +129,25 @@ void Gemm::Sgemm_omp(int32_t m, int32_t n, int32_t k, int8_t alpha,
       int32_t nc;
       nc = s_min(n - j, NC);
       int8_t *local_B = packedB_int8 + KC * NC * local_threads;
-      int32_t *local_C = packedC_int8 + MC * NC * local_threads;
+      int32_t *local_C = packedC_int32 + MC * NC * local_threads;
 #if __aarch64__
       // TODO(wzzju)
 #else
-      PackMatrixB_8c(KC, nc, nc % NR, &B(0, j), ldb, local_B);
+      PackMatrixB_2c_16(k, nc, nc % NR_INT8, &B(0, j), ldb, local_B);
 #endif
-      InnerKernelWithBias(m, nc, alpha, packedA_int8, local_B, beta, local_C,
-                          &C(0, j), ldc, relu, bias);
+      // InnerKernelWithBias(m, nc, alpha, packedA_int8, local_B, beta,
+      // local_C,
+      //                     &C(0, j), ldc, relu, bias);
+      if (bias == nullptr) {
+        InnerKernel(m, nc, alpha, packedA_int8, local_B, beta, local_C,
+                    &C(0, j), ldc, relu);
+      }
     }
   }
   paddle_mobile::memory::Free(packedA_int8);
   paddle_mobile::memory::Free(packedB_int8);
-  paddle_mobile::memory::Free(packedC_int8);
+  paddle_mobile::memory::Free(packedC_int32);
   paddle_mobile::memory::Free(zero_int8);
 }
@@ -144,7 +155,7 @@ void Gemm::PackMatrixB_omp_8c(int32_t k, int32_t n, int32_t n_tail,
                               const int8_t *B, int32_t ldb, int8_t *buffer) {
   const int32_t j_length = n - n_tail;
 #pragma omp parallel for
-  for (int32_t j = 0; j < j_length; j += NR) {
+  for (int32_t j = 0; j < j_length; j += 8) {
     int8_t *local_buffer = buffer + j * k;
     for (int32_t i = 0; i < k; ++i) {
       const int8_t *b0 = &B(i, j);
@@ -179,7 +190,7 @@ void Gemm::PackMatrixB_omp_8c(int32_t k, int32_t n, int32_t n_tail,
     for (int32_t j = j_length; j < n; ++j) {
       *local_buffer++ = *b0++;
     }
-    for (int32_t j = n; j < j_length + NR; ++j) {
+    for (int32_t j = n; j < j_length + 8; ++j) {
       *local_buffer++ = 0;
     }
   }
@@ -188,9 +199,9 @@ void Gemm::PackMatrixB_omp_8c(int32_t k, int32_t n, int32_t n_tail,
 void Gemm::PackMatrixA_omp_4r(int32_t m, int32_t k, int32_t m_tail,
                               const int8_t *A, int32_t lda, int8_t *buffer) {
-  const int i_length = m - m_tail;
+  const int32_t i_length = m - m_tail;
 #pragma omp parallel for
-  for (int32_t i = 0; i < i_length; i += MR_INT8) {
+  for (int32_t i = 0; i < i_length; i += 4) {
     const int8_t *a0 = A + i * lda;
     const int8_t *a1 = A + (i + 1) * lda;
     const int8_t *a2 = A + (i + 2) * lda;
@@ -221,7 +232,7 @@ void Gemm::PackMatrixA_omp_4r(int32_t m, int32_t k, int32_t m_tail,
       default:
         break;
     }
-    for (int j = 0; j < k; ++j) {
+    for (int32_t j = 0; j < k; ++j) {
       *local_buffer++ = *a0++;
       *local_buffer++ = *a1++;
       *local_buffer++ = *a2++;
@@ -230,6 +241,232 @@ void Gemm::PackMatrixA_omp_4r(int32_t m, int32_t k, int32_t m_tail,
     }
   }
// 8 bits int PackMatrixA_4r
void Gemm::PackMatrixA_omp_4r_16(int32_t m, int32_t k, int32_t m_tail,
const int8_t *A, int32_t lda, int8_t *buffer) {
const int32_t i_length = m - m_tail;
const int32_t k_count = k >> 4;
const int32_t k_tail = k & 15;
#pragma omp parallel for
for (int32_t i = 0; i < i_length; i += 4) {
const int8_t *a0 = A + i * lda;
const int8_t *a1 = A + (i + 1) * lda;
const int8_t *a2 = A + (i + 2) * lda;
const int8_t *a3 = A + (i + 3) * lda;
int8_t *local_buffer = buffer + i * KC;
for (int32_t j = 0; j < k_count; ++j) {
#if __ARM_NEON
#if __aarch64__
// TODO(wzzju)
#else
asm volatile(
"vld1.s8 {d0, d1}, [%[a0]]! \n\t"
"vld1.s8 {d2, d3}, [%[a1]]! \n\t"
"vld1.s8 {d4, d5}, [%[a2]]! \n\t"
"vld1.s8 {d6, d7}, [%[a3]]! \n\t"
"vst1.s8 {d0, d1}, [%[local_buffer]]! \n\t"
"vst1.s8 {d2, d3}, [%[local_buffer]]! \n\t"
"vst1.s8 {d4, d5}, [%[local_buffer]]! \n\t"
"vst1.s8 {d6, d7}, [%[local_buffer]]! \n\t"
: [local_buffer] "+r"(local_buffer), [a0] "+r"(a0), [a1] "+r"(a1),
[a2] "+r"(a2), [a3] "+r"(a3)
:
: "memory", "q0", "q1", "q2", "q3");
#endif // __aarch64__
#else
for (int32_t l = 0; l < 16; ++l) {
*local_buffer++ = *a0++;
}
for (int32_t l = 0; l < 16; ++l) {
*local_buffer++ = *a1++;
}
for (int32_t l = 0; l < 16; ++l) {
*local_buffer++ = *a2++;
}
for (int32_t l = 0; l < 16; ++l) {
*local_buffer++ = *a3++;
}
#endif // __ARM_NEON
}
if (k_tail != 0) {
for (int32_t j = k_count << 4; j < k; ++j) {
*local_buffer++ = *a0++;
}
for (int32_t j = k; j < KC; ++j) {
*local_buffer++ = 0;
}
for (int32_t j = k_count << 4; j < k; ++j) {
*local_buffer++ = *a1++;
}
for (int32_t j = k; j < KC; ++j) {
*local_buffer++ = 0;
}
for (int32_t j = k_count << 4; j < k; ++j) {
*local_buffer++ = *a2++;
}
for (int32_t j = k; j < KC; ++j) {
*local_buffer++ = 0;
}
for (int32_t j = k_count << 4; j < k; ++j) {
*local_buffer++ = *a3++;
}
for (int32_t j = k; j < KC; ++j) {
*local_buffer++ = 0;
}
}
}
if (m_tail != 0) {
const int8_t *a0 = &A(i_length, 0);
const int8_t *a1 = a0 + lda;
const int8_t *a2 = a0 + 2 * lda;
const int8_t *a3 = a0 + 3 * lda;
int8_t *local_buffer = buffer + i_length * KC;
switch (m_tail) {
case 1:
a1 = zero_int8;
case 2:
a2 = zero_int8;
case 3:
a3 = zero_int8;
break;
default:
break;
}
for (int32_t j = 0; j < k_count; ++j) {
#if __ARM_NEON
#if __aarch64__
// TODO(wzzju)
#else
asm volatile(
"vld1.s8 {d0, d1}, [%[a0]]! \n\t"
"vld1.s8 {d2, d3}, [%[a1]]! \n\t"
"vld1.s8 {d4, d5}, [%[a2]]! \n\t"
"vld1.s8 {d6, d7}, [%[a3]]! \n\t"
"vst1.s8 {d0, d1}, [%[local_buffer]]! \n\t"
"vst1.s8 {d2, d3}, [%[local_buffer]]! \n\t"
"vst1.s8 {d4, d5}, [%[local_buffer]]! \n\t"
"vst1.s8 {d6, d7}, [%[local_buffer]]! \n\t"
: [local_buffer] "+r"(local_buffer), [a0] "+r"(a0), [a1] "+r"(a1),
[a2] "+r"(a2), [a3] "+r"(a3)
:
: "memory", "q0", "q1", "q2", "q3");
#endif // __aarch64__
#else
for (int32_t l = 0; l < 16; ++l) {
*local_buffer++ = *a0++;
}
for (int32_t l = 0; l < 16; ++l) {
*local_buffer++ = *a1++;
}
for (int32_t l = 0; l < 16; ++l) {
*local_buffer++ = *a2++;
}
for (int32_t l = 0; l < 16; ++l) {
*local_buffer++ = *a3++;
}
#endif // __ARM_NEON
}
if (k_tail != 0) {
for (int32_t j = k_count << 4; j < k; ++j) {
*local_buffer++ = *a0++;
}
for (int32_t j = k; j < KC; ++j) {
*local_buffer++ = 0;
}
for (int32_t j = k_count << 4; j < k; ++j) {
*local_buffer++ = *a1++;
}
for (int32_t j = k; j < KC; ++j) {
*local_buffer++ = 0;
}
for (int32_t j = k_count << 4; j < k; ++j) {
*local_buffer++ = *a2++;
}
for (int32_t j = k; j < KC; ++j) {
*local_buffer++ = 0;
}
for (int32_t j = k_count << 4; j < k; ++j) {
*local_buffer++ = *a3++;
}
for (int32_t j = k; j < KC; ++j) {
*local_buffer++ = 0;
}
}
}
}
// 8 bits int PackMatrixB
void Gemm::PackMatrixB_omp_2c_16(int32_t k, int32_t n, int32_t n_tail,
const int8_t *B, int32_t ldb, int8_t *buffer) {
const int32_t j_length = n - n_tail;
const int32_t k_count = k >> 4;
const int32_t k_tail = k & 15;
#pragma omp parallel for
for (int32_t j = 0; j < j_length; j += 2) {
int8_t *local_buffer = buffer + j * KC;
for (int32_t i = 0; i < k_count; ++i) {
const int8_t *b0 = &B((i << 4), j);
const int8_t *b1 = &B((i << 4), j + 1);
for (int m = 0; m < 16; ++m) {
*local_buffer++ = *b0;
b0 += ldb;
}
for (int m = 0; m < 16; ++m) {
*local_buffer++ = *b1;
b1 += ldb;
}
}
if (k_tail != 0) {
const int8_t *b0 = &B((k_count << 4), j);
const int8_t *b1 = &B((k_count << 4), j + 1);
for (int32_t j = k_count << 4; j < k; ++j) {
*local_buffer++ = *b0;
b0 += ldb;
}
for (int32_t j = k; j < KC; ++j) {
*local_buffer++ = 0;
}
for (int32_t j = k_count << 4; j < k; ++j) {
*local_buffer++ = *b1;
b1 += ldb;
}
for (int32_t j = k; j < KC; ++j) {
*local_buffer++ = 0;
}
}
}
if (n_tail != 0) {
int8_t *local_buffer = buffer + j_length * KC;
for (int32_t i = 0; i < k_count; ++i) {
const int8_t *b0 = &B((i << 4), j_length);
for (int m = 0; m < 16; ++m) {
*local_buffer++ = *b0;
b0 += ldb;
}
for (int m = 0; m < 16; ++m) {
*local_buffer++ = 0;
}
}
if (k_tail != 0) {
const int8_t *b0 = &B((k_count << 4), j_length);
for (int32_t j = k_count << 4; j < k; ++j) {
*local_buffer++ = *b0;
b0 += ldb;
}
for (int32_t j = k; j < KC; ++j) {
*local_buffer++ = 0;
}
for (int32_t j = k_count << 4; j < KC; ++j) {
*local_buffer++ = 0;
}
}
}
}
} // namespace math
} // namespace operators
} // namespace paddle_mobile
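The new 4r_16 / 2c_16 packing routines consume 16 values of k per step, which is why Sgemm_omp now rounds KC up to a multiple of 16 with (k + 15) - ((k + 15) & 15). A tiny check of that expression (illustrative only):

#include <cassert>
#include <cstdint>

// Same round-up used for KC in Sgemm_omp: next multiple of 16.
int32_t round_up_16(int32_t k) { return (k + 15) - ((k + 15) & 15); }

int main() {
  assert(round_up_16(1) == 16);
  assert(round_up_16(16) == 16);
  assert(round_up_16(17) == 32);
  assert(round_up_16(100) == ((100 + 15) / 16) * 16);  // matches the familiar form
  return 0;
}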
@@ -28,7 +28,12 @@ template <typename T>
 void matmul(const framework::Tensor &matrix_a, bool trans_a,
             const framework::Tensor &matrix_b, bool trans_b, T alpha,
             framework::Tensor *matrix_out, T beta, bool relu = false,
-            T *bias = nullptr);
+            float *bias = nullptr);
+void matmul_int8(const framework::Tensor &matrix_a, bool trans_a,
+                 const framework::Tensor &matrix_b, bool trans_b, float alpha,
+                 framework::Tensor *matrix_out, float beta, bool relu = false,
+                 int32_t *bias = nullptr);
 template <typename T>
 void matmulWithBn(const framework::Tensor &matrix_a, bool trans_a,
......
...@@ -20,11 +20,10 @@ limitations under the License. */ ...@@ -20,11 +20,10 @@ limitations under the License. */
namespace paddle_mobile { namespace paddle_mobile {
namespace operators { namespace operators {
namespace math { namespace math {
template <> void matmul_int8(const framework::Tensor &matrix_a, bool trans_a,
void matmul<int8_t>(const framework::Tensor &matrix_a, bool trans_a, const framework::Tensor &matrix_b, bool trans_b, float alpha,
const framework::Tensor &matrix_b, bool trans_b, framework::Tensor *matrix_out, float beta, bool relu,
int8_t alpha, framework::Tensor *matrix_out, int8_t beta, int32_t *bias) {
bool relu, int8_t *bias) {
auto dim_a = matrix_a.dims(); auto dim_a = matrix_a.dims();
auto dim_b = matrix_b.dims(); auto dim_b = matrix_b.dims();
auto dim_out = matrix_out->dims(); auto dim_out = matrix_out->dims();
...@@ -52,21 +51,45 @@ void matmul<int8_t>(const framework::Tensor &matrix_a, bool trans_a, ...@@ -52,21 +51,45 @@ void matmul<int8_t>(const framework::Tensor &matrix_a, bool trans_a,
} }
#ifdef _OPENMP #ifdef _OPENMP
gemm.Sgemm_omp(M, N, K, alpha, a, K, matrix_b.data<int8_t>(), N, beta, if (bias != nullptr) {
matrix_out->data<int32_t>(), N, relu, bias); // TODO(wzzju): gemm.Sgemm_omp_with_bias, now use single thread instead.
gemm.Sgemm(M, N, K, alpha, a, K, matrix_b.data<int8_t>(), N, beta,
matrix_out->data<int8_t>(), N, relu, bias);
} else {
gemm.Sgemm_omp(M, N, K, alpha, a, K, matrix_b.data<int8_t>(), N, beta,
matrix_out->data<int32_t>(), N, relu, bias);
}
#else #else
gemm.Sgemm(M, N, K, alpha, a, K, matrix_b.data<int8_t>(), N, beta, if (bias != nullptr) {
matrix_out->data<int32_t>(), N, relu, bias); gemm.Sgemm(M, N, K, alpha, a, K, matrix_b.data<int8_t>(), N, beta,
matrix_out->data<int8_t>(), N, relu, bias);
} else {
gemm.Sgemm(M, N, K, alpha, a, K, matrix_b.data<int8_t>(), N, beta,
matrix_out->data<int32_t>(), N, relu, bias);
}
#endif #endif
} else { } else {
#ifdef _OPENMP #ifdef _OPENMP
gemm.Sgemm_omp(M, N, K, alpha, matrix_a.data<int8_t>(), K, if (bias != nullptr) {
matrix_b.data<int8_t>(), N, beta, // TODO(wzzju): gemm.Sgemm_omp_with_bias, now use single thread instead.
matrix_out->data<int32_t>(), N, relu, bias); gemm.Sgemm(M, N, K, alpha, matrix_a.data<int8_t>(), K,
matrix_b.data<int8_t>(), N, beta, matrix_out->data<int8_t>(),
N, relu, bias);
} else {
gemm.Sgemm_omp(M, N, K, alpha, matrix_a.data<int8_t>(), K,
matrix_b.data<int8_t>(), N, beta,
matrix_out->data<int32_t>(), N, relu, bias);
}
#else #else
gemm.Sgemm(M, N, K, alpha, matrix_a.data<int8_t>(), K, if (bias != nullptr) {
matrix_b.data<int8_t>(), N, beta, matrix_out->data<int32_t>(), N, gemm.Sgemm(M, N, K, alpha, matrix_a.data<int8_t>(), K,
relu, bias); matrix_b.data<int8_t>(), N, beta, matrix_out->data<int8_t>(),
N, relu, bias);
} else {
gemm.Sgemm(M, N, K, alpha, matrix_a.data<int8_t>(), K,
matrix_b.data<int8_t>(), N, beta, matrix_out->data<int32_t>(),
N, relu, bias);
}
#endif #endif
} }
} }
......
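Both call sites now dispatch on the bias pointer: with a bias, the (for now single-threaded) Sgemm writes int8; without one, the OpenMP variant keeps int32 accumulators. The int8 output implies the kernel fuses the bias add, the optional relu and saturation; a scalar reference for that assumed epilogue, offered only as an illustration and not taken from the NEON kernel, would be:

// Assumed per-element epilogue of the bias + relu int8 path; the real kernel
// may order or implement these steps differently.
static inline int8_t epilogue_int8(int32_t acc, int32_t bias, bool relu) {
  int32_t v = acc + bias;          // fused bias add
  if (relu && v < 0) v = 0;        // optional relu
  if (v > 127) v = 127;            // saturate to the int8 range
  if (v < -128) v = -128;
  return static_cast<int8_t>(v);
}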
...@@ -419,6 +419,8 @@ class ConvParam : public OpParam { ...@@ -419,6 +419,8 @@ class ConvParam : public OpParam {
EXEC_INVALID = 0, EXEC_INVALID = 0,
EXEC_GEMM_FLOAT, EXEC_GEMM_FLOAT,
EXEC_DEPTHWISE3x3S1P1_FLOAT, EXEC_DEPTHWISE3x3S1P1_FLOAT,
EXEC_DEPTHWISE3x3S2P0_FLOAT,
EXEC_DEPTHWISE3x3S2P1_FLOAT,
EXEC_DEPTHWISE3x3_FLOAT, EXEC_DEPTHWISE3x3_FLOAT,
EXEC_WINOGRAD3X3_FLOAT, EXEC_WINOGRAD3X3_FLOAT,
EXEC_WINOGRAD5X5_FLOAT, EXEC_WINOGRAD5X5_FLOAT,
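The two new ExecMode values cover stride-2 depthwise 3x3 convolution with padding 0 and 1. A hypothetical selection sketch (depthwise, ksize, stride, pad and mode are placeholders; the actual dispatch lives in the conv kernel init code, which is not part of this hunk):

// Hypothetical exec-mode selection for depthwise 3x3 convolutions.
if (depthwise && ksize == 3) {
  if (stride == 1 && pad == 1)       mode = EXEC_DEPTHWISE3x3S1P1_FLOAT;
  else if (stride == 2 && pad == 0)  mode = EXEC_DEPTHWISE3x3S2P0_FLOAT;
  else if (stride == 2 && pad == 1)  mode = EXEC_DEPTHWISE3x3S2P1_FLOAT;
  else                               mode = EXEC_DEPTHWISE3x3_FLOAT;
}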
...@@ -2573,7 +2575,9 @@ class DequantizeParam : public OpParam { ...@@ -2573,7 +2575,9 @@ class DequantizeParam : public OpParam {
DequantizeParam(const VariableNameMap &inputs, const VariableNameMap &outputs, DequantizeParam(const VariableNameMap &inputs, const VariableNameMap &outputs,
const AttributeMap &attrs, const Scope &scope) { const AttributeMap &attrs, const Scope &scope) {
input_ = InputXFrom<GType>(inputs, scope); input_ = InputXFrom<GType>(inputs, scope);
output_ = OutFrom<GType>(outputs, scope); if (outputs.count("Out")) {
output_ = OutFrom<GType>(outputs, scope);
}
activation_scale_ = OpParam::GetVarValue<GType>("Scale", inputs, scope); activation_scale_ = OpParam::GetVarValue<GType>("Scale", inputs, scope);
// dequantization is performed as x = x / static_scale / online_scale // dequantization is performed as x = x / static_scale / online_scale
if (HasAttr("weight_scale", attrs)) { if (HasAttr("weight_scale", attrs)) {
...@@ -2593,20 +2597,19 @@ class DequantizeParam : public OpParam { ...@@ -2593,20 +2597,19 @@ class DequantizeParam : public OpParam {
}; };
#endif #endif
#ifdef FUSION_DEQUANT_ADD_BN_RELU_OP #if defined(FUSION_DEQUANT_ADD_BN_OP) || \
defined(FUSION_DEQUANT_ADD_BN_RELU_OP) || \
defined(FUSION_DEQUANT_BN_RELU_OP) || defined(FUSION_DEQUANT_BN_OP)
template <typename Dtype> template <typename Dtype>
class FusionDequantAddBNReluParam : public DequantizeParam<Dtype> { class FusionDequantBNParam : public DequantizeParam<Dtype> {
typedef typename DtypeTensorTrait<Dtype>::gtype GType; typedef typename DtypeTensorTrait<Dtype>::gtype GType;
typedef typename DtypeTensorTrait<Dtype>::rtype RType; typedef typename DtypeTensorTrait<Dtype>::rtype RType;
public: public:
FusionDequantAddBNReluParam(const VariableNameMap &inputs, FusionDequantBNParam(const VariableNameMap &inputs,
const VariableNameMap &outputs, const VariableNameMap &outputs,
const AttributeMap &attrs, const Scope &scope) const AttributeMap &attrs, const Scope &scope)
: DequantizeParam<Dtype>(inputs, outputs, attrs, scope) { : DequantizeParam<Dtype>(inputs, outputs, attrs, scope) {
// element wise add params
axis_ = OpParam::GetAttr<int>("axis", attrs);
bias_ = OpParam::InputYFrom<GType>(inputs, scope);
// batch norm params // batch norm params
bn_mean_ = OpParam::GetVarValue<GType>("BNMean", inputs, scope); bn_mean_ = OpParam::GetVarValue<GType>("BNMean", inputs, scope);
bn_variance_ = OpParam::GetVarValue<GType>("BNVariance", inputs, scope); bn_variance_ = OpParam::GetVarValue<GType>("BNVariance", inputs, scope);
...@@ -2614,21 +2617,83 @@ class FusionDequantAddBNReluParam : public DequantizeParam<Dtype> { ...@@ -2614,21 +2617,83 @@ class FusionDequantAddBNReluParam : public DequantizeParam<Dtype> {
bn_bias_ = OpParam::GetVarValue<GType>("BNBias", inputs, scope); bn_bias_ = OpParam::GetVarValue<GType>("BNBias", inputs, scope);
epsilon_ = OpParam::GetAttr<float>("epsilon", attrs); epsilon_ = OpParam::GetAttr<float>("epsilon", attrs);
// output // output
output_ = OpParam::OutFrom<GType>(outputs, scope); if (outputs.count("Y")) {
this->output_ = OpParam::OutputYFrom<GType>(outputs, scope);
}
} }
public: public:
// elementwise add
int axis_;
RType *bias_;
// batch norm // batch norm
RType *bn_mean_; RType *bn_mean_;
RType *bn_variance_; RType *bn_variance_;
RType *bn_scale_; RType *bn_scale_;
RType *bn_bias_; RType *bn_bias_;
float epsilon_; float epsilon_;
// output };
RType *output_; #endif
#if defined(FUSION_DEQUANT_ADD_BN_RELU_OP) || defined(FUSION_DEQUANT_ADD_BN_OP)
template <typename Dtype>
class FusionDequantAddBNParam : public FusionDequantBNParam<Dtype> {
typedef typename DtypeTensorTrait<Dtype>::gtype GType;
typedef typename DtypeTensorTrait<Dtype>::rtype RType;
public:
FusionDequantAddBNParam(const VariableNameMap &inputs,
const VariableNameMap &outputs,
const AttributeMap &attrs, const Scope &scope)
: FusionDequantBNParam<Dtype>(inputs, outputs, attrs, scope) {
// element wise add params
axis_ = OpParam::GetAttr<int>("axis", attrs);
bias_ = OpParam::InputYFrom<GType>(inputs, scope);
// output
if (outputs.count("Y")) {
this->output_ = OpParam::OutputYFrom<GType>(outputs, scope);
}
}
public:
// elementwise add
int axis_;
RType *bias_;
};
#endif
#ifdef FUSION_DEQUANT_BN_RELU_OP
template <typename Dtype>
class FusionDequantBNReluParam : public FusionDequantBNParam<Dtype> {
typedef typename DtypeTensorTrait<Dtype>::gtype GType;
typedef typename DtypeTensorTrait<Dtype>::rtype RType;
public:
FusionDequantBNReluParam(const VariableNameMap &inputs,
const VariableNameMap &outputs,
const AttributeMap &attrs, const Scope &scope)
: FusionDequantBNParam<Dtype>(inputs, outputs, attrs, scope) {
// output
if (outputs.count("Out")) {
this->output_ = OpParam::OutFrom<GType>(outputs, scope);
}
}
};
#endif
#ifdef FUSION_DEQUANT_ADD_BN_RELU_OP
template <typename Dtype>
class FusionDequantAddBNReluParam : public FusionDequantAddBNParam<Dtype> {
typedef typename DtypeTensorTrait<Dtype>::gtype GType;
typedef typename DtypeTensorTrait<Dtype>::rtype RType;
public:
FusionDequantAddBNReluParam(const VariableNameMap &inputs,
const VariableNameMap &outputs,
const AttributeMap &attrs, const Scope &scope)
: FusionDequantAddBNParam<Dtype>(inputs, outputs, attrs, scope) {
// output
if (outputs.count("Out")) {
this->output_ = OpParam::OutFrom<GType>(outputs, scope);
}
}
}; };
#endif #endif
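Taken together, the params form a small hierarchy: FusionDequantBNParam holds the batch-norm tensors, FusionDequantAddBNParam adds the elementwise-add axis and bias, and the *_RELU variants only re-bind the output name. A scalar sketch of the fused computation these params describe, pieced together from the parameter names and the dequantization comment above rather than from the ARM kernels:

#include <cmath>
// with_add / with_relu select the *_ADD_* and *_RELU variants respectively.
inline float fused_dequant_bn(int32_t q, float act_scale, float weight_scale,
                              float add_bias, float mean, float var,
                              float scale, float shift, float eps,
                              bool with_add, bool with_relu) {
  float x = static_cast<float>(q) / act_scale / weight_scale;  // dequantize
  if (with_add) x += add_bias;                                 // elementwise add (Y input)
  float y = scale * (x - mean) / sqrtf(var + eps) + shift;     // batch norm
  return (with_relu && y < 0.f) ? 0.f : y;                     // optional relu
}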
......
...@@ -28,7 +28,7 @@ limitations under the License. */ ...@@ -28,7 +28,7 @@ limitations under the License. */
int main() { int main() {
paddle_mobile::PaddleMobile<paddle_mobile::CPU> paddle_mobile; paddle_mobile::PaddleMobile<paddle_mobile::CPU> paddle_mobile;
paddle_mobile.SetThreadNum(8); paddle_mobile.SetThreadNum(4);
Tensor aa, bb, cc; Tensor aa, bb, cc;
auto aaptr = aa.mutable_data<float>({m, k}); auto aaptr = aa.mutable_data<float>({m, k});
auto bbptr = bb.mutable_data<float>({k, n}); auto bbptr = bb.mutable_data<float>({k, n});
...@@ -44,10 +44,12 @@ int main() { ...@@ -44,10 +44,12 @@ int main() {
ccptr[i] = 2; ccptr[i] = 2;
} }
Tensor aa_int8, bb_int8, cc_int8; Tensor aa_int8, bb_int8, cc_int32, cc_int8;
auto aaptr_int8 = aa_int8.mutable_data<int8_t>({m, k}); auto aaptr_int8 = aa_int8.mutable_data<int8_t>({m, k});
auto bbptr_int8 = bb_int8.mutable_data<int8_t>({k, n}); auto bbptr_int8 = bb_int8.mutable_data<int8_t>({k, n});
auto ccptr_int8 = cc_int8.mutable_data<int32_t>({m, n}); auto ccptr_int32 = cc_int32.mutable_data<int32_t>({m, n});
auto ccptr_int8 = cc_int8.mutable_data<int8_t>({m, n});
int32_t* bias_data = new int32_t[m];
for (int i = 0; i < m * k; ++i) { for (int i = 0; i < m * k; ++i) {
aaptr_int8[i] = static_cast<int8_t>(2); aaptr_int8[i] = static_cast<int8_t>(2);
...@@ -56,7 +58,11 @@ int main() { ...@@ -56,7 +58,11 @@ int main() {
bbptr_int8[i] = static_cast<int8_t>(2); bbptr_int8[i] = static_cast<int8_t>(2);
} }
for (int i = 0; i < m * n; ++i) { for (int i = 0; i < m * n; ++i) {
ccptr_int8[i] = static_cast<int32_t>(2); ccptr_int32[i] = static_cast<int32_t>(2);
}
for (int i = 0; i < m; ++i) {
bias_data[i] = 2;
} }
// float // float
...@@ -76,22 +82,41 @@ int main() { ...@@ -76,22 +82,41 @@ int main() {
auto time2 = time(); auto time2 = time();
std::cout << "float gemm cost :" << time_diff(time1, time2) / 10 << "ms\n"; std::cout << "float gemm cost :" << time_diff(time1, time2) / 10 << "ms\n";
// int8_t // int8_t without bias
// warm-up 10 times // warm-up 10 times
for (int j = 0; j < 10; ++j) { for (int j = 0; j < 10; ++j) {
paddle_mobile::operators::math::matmul<int8_t>( paddle_mobile::operators::math::matmul_int8(
aa_int8, false, bb_int8, false, static_cast<int8_t>(1), &cc_int8, aa_int8, false, bb_int8, false, static_cast<float>(1), &cc_int32,
static_cast<int8_t>(0), false, nullptr); static_cast<float>(0), false, nullptr);
} }
auto time3 = time(); auto time3 = time();
for (int j = 0; j < 10; ++j) { for (int j = 0; j < 10; ++j) {
paddle_mobile::operators::math::matmul<int8_t>( paddle_mobile::operators::math::matmul_int8(
aa_int8, false, bb_int8, false, static_cast<int8_t>(1), &cc_int8, aa_int8, false, bb_int8, false, static_cast<float>(1), &cc_int32,
static_cast<int8_t>(0), false, nullptr); static_cast<float>(0), false, nullptr);
} }
auto time4 = time(); auto time4 = time();
std::cout << "int8_t gemm cost :" << time_diff(time3, time4) / 10 << "ms\n"; std::cout << "int8_t gemm cost :" << time_diff(time3, time4) / 10 << "ms\n";
// int8_t with bias&relu
// warm-up 10 times
for (int j = 0; j < 10; ++j) {
paddle_mobile::operators::math::matmul_int8(
aa_int8, false, bb_int8, false, static_cast<float>(1), &cc_int8,
static_cast<float>(0), true, &bias_data[0]);
}
auto time5 = time();
for (int j = 0; j < 10; ++j) {
paddle_mobile::operators::math::matmul_int8(
aa_int8, false, bb_int8, false, static_cast<float>(1), &cc_int8,
static_cast<float>(0), true, &bias_data[0]);
}
auto time6 = time();
std::cout << "int8_t gemm_with_bias_relu cost :"
<< time_diff(time5, time6) / 10 << "ms\n";
delete[] bias_data;
return 0; return 0;
} }
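Because every input element is 2 and beta is 0, the int32 result of the no-bias benchmark should be exactly 4 * k in every position. A small sanity check that could be added before the bias benchmark, using only variables already defined in the test:

// Optional check for the no-bias int8 GEMM: all-2 inputs give 2 * 2 summed
// over k, i.e. 4 * k, in every output element.
bool ok = true;
for (int i = 0; i < m * n; ++i) {
  ok = ok && (ccptr_int32[i] == 4 * k);
}
std::cout << "int32 gemm result " << (ok ? "matches" : "does not match")
          << " the expected 4 * k\n";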
...@@ -249,7 +249,9 @@ if(NOT FOUND_MATCH) ...@@ -249,7 +249,9 @@ if(NOT FOUND_MATCH)
set(SUM_OP ON) set(SUM_OP ON)
set(QUANT_OP ON) set(QUANT_OP ON)
set(DEQUANT_OP ON) set(DEQUANT_OP ON)
set(FUSION_DEQUANT_ADD_BN_RELU ON) set(FUSION_DEQUANT_ADD_BN_OP ON)
set(FUSION_DEQUANT_BN_RELU_OP ON)
set(FUSION_DEQUANT_ADD_BN_RELU_OP ON)
endif() endif()
# option(BATCHNORM_OP "" ON) # option(BATCHNORM_OP "" ON)
...@@ -451,10 +453,17 @@ endif() ...@@ -451,10 +453,17 @@ endif()
if (DEQUANT_OP) if (DEQUANT_OP)
add_definitions(-DDEQUANT_OP) add_definitions(-DDEQUANT_OP)
endif() endif()
if (FUSION_DEQUANT_ADD_BN_RELU) if (FUSION_DEQUANT_ADD_BN_OP)
add_definitions(-DFUSION_DEQUANT_ADD_BN_OP)
endif()
if (FUSION_DEQUANT_BN_RELU_OP)
add_definitions(-DFUSION_DEQUANT_BN_RELU_OP)
endif()
if (FUSION_DEQUANT_ADD_BN_RELU_OP)
add_definitions(-DFUSION_DEQUANT_ADD_BN_RELU_OP) add_definitions(-DFUSION_DEQUANT_ADD_BN_RELU_OP)
endif() endif()
if (TANH_OP) if (TANH_OP)
add_definitions(-DTANH_OP) add_definitions(-DTANH_OP)
endif() endif()
...@@ -467,3 +476,4 @@ endif() ...@@ -467,3 +476,4 @@ endif()
if (FUSION_DECONVADDRELU_OP) if (FUSION_DECONVADDRELU_OP)
add_definitions(-DFUSION_DECONVADDRELU_OP) add_definitions(-DFUSION_DECONVADDRELU_OP)
endif() endif()