Commit 7be0ed00 authored by: Z zhangyang

Merge remote-tracking branch 'upstream/develop' into develop

......@@ -24,6 +24,7 @@ const char *G_OP_TYPE_CONCAT = "concat";
const char *G_OP_TYPE_ELEMENTWISE_ADD = "elementwise_add";
const char *G_OP_TYPE_FILL_CONSTANT = "fill_constant";
const char *G_OP_TYPE_FUSION_CONV_ADD_RELU = "fusion_conv_add_relu";
const char *G_OP_TYPE_FUSION_CONV_ADD_RELU_INT8 = "fusion_conv_add_relu_int8";
const char *G_OP_TYPE_FUSION_CONV_ADD_PRELU = "fusion_conv_add_prelu";
const char *G_OP_TYPE_FUSION_CONV_ADD_ADD_PRELU = "fusion_conv_add_add_prelu";
const char *G_OP_TYPE_FUSION_CONV_ADD_BN_RELU = "fusion_conv_add_bn_relu";
......@@ -31,6 +32,7 @@ const char *G_OP_TYPE_FUSION_CONV_BN_ADD_RELU = "fusion_conv_bn_add_relu";
const char *G_OP_TYPE_FUSION_DWCONV_BN_RELU = "fusion_dwconv_bn_relu";
const char *G_OP_TYPE_FUSION_CONV_BN_RELU = "fusion_conv_bn_relu";
const char *G_OP_TYPE_FC = "fusion_fc";
const char *G_OP_TYPE_FC_INT8 = "fusion_fc_int8";
const char *G_OP_TYPE_FUSION_CONV_ADD = "fusion_conv_add";
const char *G_OP_TYPE_LRN = "lrn";
const char *G_OP_TYPE_MUL = "mul";
......@@ -110,11 +112,13 @@ std::unordered_map<
{G_OP_TYPE_MULTICLASS_NMS, {{"BBoxes", "Scores"}, {"Out"}}},
{G_OP_TYPE_POLYGON_BOX_TRANSFORM, {{"Input"}, {"Output"}}},
{G_OP_TYPE_FC, {{"X", "Y", "Z"}, {"Out"}}},
{G_OP_TYPE_FC_INT8, {{"X", "Y", "Z", "Scale"}, {"Out"}}},
{G_OP_TYPE_RESHAPE, {{"X"}, {"Out"}}},
{G_OP_TYPE_RESHAPE2, {{"X"}, {"Out", "XShape"}}},
{G_OP_TYPE_DEPTHWISE_CONV, {{"Input"}, {"Output"}}},
{G_OP_TYPE_FILL_CONSTANT, {{}, {"Out"}}},
{G_OP_TYPE_FUSION_CONV_ADD_RELU, {{"Input"}, {"Out"}}},
{G_OP_TYPE_FUSION_CONV_ADD_RELU_INT8, {{"Input", "Scale"}, {"Out"}}},
{G_OP_TYPE_FUSION_CONV_ADD_PRELU, {{"Input"}, {"Out"}}},
{G_OP_TYPE_FUSION_CONV_ADD_ADD_PRELU, {{"Input"}, {"Out"}}},
{G_OP_TYPE_IM2SEQUENCE, {{"X"}, {"Out"}}},
......
......@@ -99,9 +99,11 @@ extern const char *G_OP_TYPE_BOX_CODER;
extern const char *G_OP_TYPE_CONCAT;
extern const char *G_OP_TYPE_ELEMENTWISE_ADD;
extern const char *G_OP_TYPE_FUSION_CONV_ADD_RELU;
extern const char *G_OP_TYPE_FUSION_CONV_ADD_RELU_INT8;
extern const char *G_OP_TYPE_FUSION_CONV_ADD_PRELU;
extern const char *G_OP_TYPE_FUSION_CONV_ADD_ADD_PRELU;
extern const char *G_OP_TYPE_FC;
extern const char *G_OP_TYPE_FC_INT8;
extern const char *G_OP_TYPE_FUSION_CONV_ADD;
extern const char *G_OP_TYPE_FUSION_CONV_ADD_BN_RELU;
extern const char *G_OP_TYPE_FUSION_CONV_BN_ADD_RELU;
......
......@@ -98,6 +98,24 @@ class OpRegistry {
}
};
#define REGISTER_OPERATOR_INT8(op_type, op_class, device_name, device_type) \
template class op_class<device_type, int8_t>; \
template <typename Dtype, typename T> \
class _OpClass_##op_type##_##device_name : public op_class<Dtype, T> { \
public: \
DEFINE_OP_CONSTRUCTOR(_OpClass_##op_type##_##device_name, op_class); \
}; \
static paddle_mobile::framework::OperatorRegistrar< \
device_type, _OpClass_##op_type##_##device_name<device_type, int8_t>> \
__op_registrar_##op_type##_##device_name(#op_type); \
int TouchOpRegistrar_##op_type##_##device_name() { \
__op_registrar_##op_type##_##device_name.Touch(); \
return 0; \
}
#define REGISTER_OPERATOR_CPU_INT8(op_type, op_class) \
REGISTER_OPERATOR_INT8(op_type, op_class, cpu, paddle_mobile::CPU);
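// Illustrative expansion (a sketch, assuming only the macros above): REGISTER_OPERATOR_CPU_INT8(
// fusion_fc_int8, ops::FusionFcInt8Op) explicitly instantiates
// ops::FusionFcInt8Op<paddle_mobile::CPU, int8_t>, defines the wrapper class
// _OpClass_fusion_fc_int8_cpu, creates the static registrar __op_registrar_fusion_fc_int8_cpu
// registered under the name "fusion_fc_int8", and emits TouchOpRegistrar_fusion_fc_int8_cpu(),
// whose presumed purpose is to keep the registration from being dropped by the linker.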
#define REGISTER_OPERATOR(op_type, op_class, device_name, device_type) \
template class op_class<device_type, float>; \
template <typename Dtype, typename T> \
......
......@@ -153,7 +153,8 @@ double PaddleMobile<CPU, Precision::FP32>::GetPredictTime() {
paddle_mobile::operators::math::Gemm gemm;
auto time1 = paddle_mobile::time();
gemm.Sgemm(m, n, k, static_cast<float>(1), a, lda, b, ldb,
static_cast<float>(0), c, ldc, false, nullptr);
static_cast<float>(0), c, ldc, false,
static_cast<float *>(nullptr));
auto time2 = paddle_mobile::time();
double cost = paddle_mobile::time_diff(time1, time2);
paddle_mobile::memory::Free(a);
......
......@@ -30,6 +30,9 @@ namespace ops = paddle_mobile::operators;
#ifdef PADDLE_MOBILE_CPU
REGISTER_OPERATOR_CPU(dropout, ops::DropoutOp);
#endif
#ifdef PADDLE_MOBILE_CL
REGISTER_OPERATOR_CL(dropout, ops::DropoutOp);
#endif
#ifdef PADDLE_MOBILE_FPGA
REGISTER_OPERATOR_FPGA(dropout, ops::DropoutOp);
#endif
......
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifdef FUSION_CONVADDRELU_INT8_OP
#include "operators/fusion_conv_add_relu_int8_op.h"
#include <vector>
#include "operators/math/conv_func.h"
namespace paddle_mobile {
namespace operators {
template <typename Dtype, typename T>
void FusionConvAddReluInt8Op<Dtype, T>::InferShape() const {
auto in_dims = this->param_.Input()->dims();
auto filter_dims = this->param_.Filter()->dims();
const std::vector<int> &strides = this->param_.Strides();
std::vector<int> paddings = this->param_.Paddings();
int groups = this->param_.Groups();
std::vector<int> dilations = this->param_.Dilations();
PADDLE_MOBILE_ENFORCE((in_dims.size() == filter_dims.size() &&
dilations.size() == paddings.size() &&
paddings.size() == strides.size()),
"ConvParam is not suitable");
std::vector<int64_t> output_shape({in_dims[0], filter_dims[0]});
for (size_t i = 0; i < strides.size(); ++i) {
output_shape.push_back(
math::ConvOutputSize(in_dims[i + 2], filter_dims[i + 2], dilations[i],
paddings[i], strides[i]));
}
framework::DDim ddim = framework::make_ddim(output_shape);
this->param_.Output()->Resize(ddim);
}
} // namespace operators
} // namespace paddle_mobile
namespace ops = paddle_mobile::operators;
#ifdef PADDLE_MOBILE_CPU
REGISTER_OPERATOR_CPU_INT8(fusion_conv_add_relu_int8,
ops::FusionConvAddReluInt8Op);
#endif
#endif // FUSION_CONVADDRELU_INT8_OP
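For reference, the per-dimension output size computed in the InferShape loop above follows the usual dilated-convolution arithmetic; a minimal sketch of what math::ConvOutputSize is assumed to compute (illustrative only):
// Sketch only: assumed equivalent of math::ConvOutputSize.
inline int ConvOutputSizeSketch(int input, int filter, int dilation, int padding, int stride) {
  const int dkernel = dilation * (filter - 1) + 1;      // effective kernel extent
  return (input + 2 * padding - dkernel) / stride + 1;  // floor division
}
// e.g. input 224, filter 3, dilation 1, padding 1, stride 2 -> (224 + 2 - 3) / 2 + 1 = 112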
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifdef FUSION_CONVADDRELU_INT8_OP
#pragma once
#include <string>
#include "framework/operator.h"
#include "operators/kernel/conv_add_relu_kernel.h"
#include "operators/op_param.h"
namespace paddle_mobile {
namespace operators {
template <typename DeviceType, typename T>
class FusionConvAddReluInt8Op
: public framework::OperatorWithKernel<DeviceType,
FusionConvAddReluParam<DeviceType>,
ConvAddReluKernel<DeviceType, T>> {
public:
FusionConvAddReluInt8Op(const std::string &type,
const VariableNameMap &inputs,
const VariableNameMap &outputs,
const framework::AttributeMap &attrs,
std::shared_ptr<framework::Scope> scope)
: framework::OperatorWithKernel<DeviceType,
FusionConvAddReluParam<DeviceType>,
ConvAddReluKernel<DeviceType, T>>(
type, inputs, outputs, attrs, scope) {}
void InferShape() const override;
};
} // namespace operators
} // namespace paddle_mobile
#endif // FUSION_CONVADDRELU_INT8_OP
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifdef FUSION_FC_INT8_OP
#include "operators/fusion_fc_int8_op.h"
namespace paddle_mobile {
namespace operators {
template <typename Dtype, typename T>
void FusionFcInt8Op<Dtype, T>::InferShape() const {
auto x_dims = this->param_.InputX()->dims();
auto y_dims = this->param_.InputY()->dims();
int x_num_col_dims = this->param_.XNumColDims();
int y_num_col_dims = this->param_.YNumColDims();
assert(x_dims.size() > x_num_col_dims);
assert(y_dims.size() > y_num_col_dims);
/// (1,2,3,4) , x_num_col_dims = 2 -> (2,12)
auto x_mat_dims = framework::flatten_to_2d(x_dims, x_num_col_dims);
auto y_mat_dims = framework::flatten_to_2d(y_dims, y_num_col_dims);
assert(x_mat_dims[1] == y_mat_dims[0]);
std::vector<int64_t> output_dims;
output_dims.reserve(
static_cast<size_t>(x_num_col_dims + y_dims.size() - y_num_col_dims));
for (int i = 0; i < x_num_col_dims; ++i) {
output_dims.push_back(x_dims[i]);
}
for (int i = y_num_col_dims; i < y_dims.size(); ++i) {
output_dims.push_back(y_dims[i]);
}
framework::DDim ddim = framework::make_ddim(output_dims);
this->param_.Out()->Resize(ddim);
}
} // namespace operators
} // namespace paddle_mobile
namespace ops = paddle_mobile::operators;
#ifdef PADDLE_MOBILE_CPU
REGISTER_OPERATOR_CPU_INT8(fusion_fc_int8, ops::FusionFcInt8Op);
#endif
#endif // FUSION_FC_INT8_OP
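A quick worked example of the shape logic above, consistent with the inline comment (values illustrative):
// x_dims = (1, 2, 3, 4), x_num_col_dims = 2  ->  x_mat_dims = (1*2, 3*4) = (2, 12)
// y_dims = (12, 10),     y_num_col_dims = 1  ->  y_mat_dims = (12, 10)
// output dims = first x_num_col_dims entries of x_dims, then trailing entries of y_dims = (1, 2, 10)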
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifdef FUSION_FC_INT8_OP
#pragma once
#include <string>
#include <vector>
#include "framework/operator.h"
#include "framework/program/program-optimize/fusion_op_register.h"
#include "operators/kernel/fusion_fc_kernel.h"
#include "operators/op_param.h"
namespace paddle_mobile {
namespace operators {
template <typename DeviceType, typename T>
class FusionFcInt8Op
: public framework::OperatorWithKernel<DeviceType,
FusionFcParam<DeviceType>,
FusionFcKernel<DeviceType, T>> {
public:
FusionFcInt8Op(const std::string &type, const VariableNameMap &inputs,
const VariableNameMap &outputs,
const framework::AttributeMap &attrs,
std::shared_ptr<framework::Scope> scope)
: framework::OperatorWithKernel<DeviceType, FusionFcParam<DeviceType>,
FusionFcKernel<DeviceType, T>>(
type, inputs, outputs, attrs, scope) {}
void InferShape() const override;
};
} // namespace operators
} // namespace paddle_mobile
#endif // FUSION_FC_INT8_OP
......@@ -28,10 +28,24 @@ bool ConvAddReluKernel<CPU, float>::Init(FusionConvAddReluParam<CPU> *param) {
template <>
void ConvAddReluKernel<CPU, float>::Compute(
const FusionConvAddReluParam<CPU> &param) {
ConvAddReluCompute<float>(param);
ConvAddReluCompute<float, float>(param);
}
template class ConvAddReluKernel<CPU, float>;
#ifdef FUSION_CONVADDRELU_INT8_OP
template <>
bool ConvAddReluKernel<CPU, int8_t>::Init(FusionConvAddReluParam<CPU> *param) {
return true;
}
template <>
void ConvAddReluKernel<CPU, int8_t>::Compute(
const FusionConvAddReluParam<CPU> &param) {
ConvAddReluCompute<int8_t, int32_t>(param);
}
template class ConvAddReluKernel<CPU, int8_t>;
#endif
} // namespace operators
} // namespace paddle_mobile
......
......@@ -27,10 +27,27 @@ bool FusionFcKernel<CPU, float>::Init(FusionFcParam<CPU> *param) {
template <>
void FusionFcKernel<CPU, float>::Compute(const FusionFcParam<CPU> &param) {
FusionFcCompute<float>(param);
FusionFcCompute<float, float>(param);
param.Out()->set_lod(param.InputX()->lod());
}
template class FusionFcKernel<CPU, float>;
#ifdef FUSION_FC_INT8_OP
template <>
bool FusionFcKernel<CPU, int8_t>::Init(FusionFcParam<CPU> *param) {
return true;
}
template <>
void FusionFcKernel<CPU, int8_t>::Compute(const FusionFcParam<CPU> &param) {
FusionFcCompute<int8_t, int32_t>(param);
param.Out()->set_lod(param.InputX()->lod());
}
template class FusionFcKernel<CPU, int8_t>;
#endif
} // namespace operators
} // namespace paddle_mobile
......
......@@ -25,21 +25,30 @@ limitations under the License. */
namespace paddle_mobile {
namespace operators {
template <typename P>
template <typename P, typename S>
void ConvAddReluCompute(const FusionConvAddReluParam<CPU> &param) {
const Tensor *input = param.Input();
Tensor filter = *param.Filter();
Tensor bias = *param.Bias();
int axis = param.Axis();
int32_t axis = param.Axis();
S *bias_data = bias.data<S>();
Tensor *output = param.Output();
float *biase_data = bias.data<float>();
output->mutable_data<P>();
int groups = param.Groups();
std::vector<int> strides = param.Strides();
std::vector<int> paddings = param.Paddings();
std::vector<int> dilations = param.Dilations();
float alpha = 1.0f;
float beta = 1.0f;
const int batch_size = static_cast<int>(input->dims()[0]);
#ifdef FUSION_CONVADDRELU_INT8_OP
alpha = param.InputScale()->data<float>()[0];
beta = 0.0f;
#endif
int32_t groups = param.Groups();
std::vector<int32_t> strides = param.Strides();
std::vector<int32_t> paddings = param.Paddings();
std::vector<int32_t> dilations = param.Dilations();
const int32_t batch_size = static_cast<int32_t>(input->dims()[0]);
std::vector<int64_t> filter_shape_vec(framework::vectorize(filter.dims()));
......@@ -61,13 +70,13 @@ void ConvAddReluCompute(const FusionConvAddReluParam<CPU> &param) {
Tensor col;
Tensor col_matrix;
if (is_expand) {
col.mutable_data<float>(col_shape);
col.mutable_data<P>(col_shape);
col_matrix.ShareDataWith(col);
col_matrix.Resize(col_matrix_shape);
}
framework::DDim input_shape = framework::slice_ddim(
input->dims(), 1, static_cast<int>(input->dims().size()));
input->dims(), 1, static_cast<int32_t>(input->dims().size()));
framework::DDim filter_matrix_shape = {filter.dims()[0],
filter.numel() / filter.dims()[0]};
......@@ -77,17 +86,17 @@ void ConvAddReluCompute(const FusionConvAddReluParam<CPU> &param) {
output->numel() / (output->dims()[0] * output->dims()[1])};
// convolution operator: im2col(or vol2col) + gemm
int in_step = static_cast<int>(input->dims()[1]) / groups;
int out_step = static_cast<int>(output->dims()[1]) / groups;
int32_t in_step = static_cast<int32_t>(input->dims()[1]) / groups;
int32_t out_step = static_cast<int32_t>(output->dims()[1]) / groups;
math::Vol2ColFunctor<CPU, float> vol2col;
math::Im2ColFunctor<math::ColFormat::kCFO, CPU, float> im2col;
math::Vol2ColFunctor<CPU, P> vol2col;
math::Im2ColFunctor<math::ColFormat::kCFO, CPU, P> im2col;
for (int i = 0; i < batch_size; i++) {
for (int32_t i = 0; i < batch_size; i++) {
Tensor in_batch = input->Slice(i, i + 1).Resize(input_shape);
Tensor out_batch = output->Slice(i, i + 1).Resize(output_matrix_shape);
for (int g = 0; g < groups; g++) {
for (int32_t g = 0; g < groups; g++) {
Tensor in_slice = in_batch.Slice(g * in_step, (g + 1) * in_step);
if (!is_expand) {
......@@ -97,8 +106,8 @@ void ConvAddReluCompute(const FusionConvAddReluParam<CPU> &param) {
} else if (data_dim == 2U) {
// im2col
im2col(in_slice, dilations, strides,
std::vector<int>{paddings[0], paddings[1], paddings[0],
paddings[1]},
std::vector<int32_t>{paddings[0], paddings[1], paddings[0],
paddings[1]},
&col);
} else if (data_dim == 3U) {
// vol2col
......@@ -108,9 +117,9 @@ void ConvAddReluCompute(const FusionConvAddReluParam<CPU> &param) {
// gemm
Tensor out_slice = out_batch.Slice(g * out_step, (g + 1) * out_step);
Tensor filter_slice = filter.Slice(g * out_step, (g + 1) * out_step);
math::matmul<float>(filter_slice, false, col_matrix, false,
static_cast<float>(1), &out_slice,
static_cast<float>(1), true, biase_data);
math::matmul(filter_slice, false, col_matrix, false, alpha, &out_slice,
beta, true, bias_data);
}
}
}
......
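In the int8 build above, alpha carries the requantization scale read from param.InputScale() and beta is zero, so the int32 accumulators are requantized rather than mixed with stale output; per the write-back declarations in gemm.h (WriteWithAddReluScale), the assumed per-element result is:
// int8 path:  out[i][j] = saturate_int8(alpha * relu(acc[i][j] + bias[i]))  // bias per output channel
// float path: alpha = beta = 1.0f, bias handled by the GEMM write-back as before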
......@@ -106,10 +106,9 @@ inline void GemmConv(const ConvParam<CPU> &param) {
// gemm
Tensor out_slice = out_batch.Slice(g * out_step, (g + 1) * out_step);
Tensor filter_slice = filter.Slice(g * out_step, (g + 1) * out_step);
math::matmul<Itype>(filter_slice, false, col_matrix, false,
static_cast<float>(1), &out_slice,
static_cast<float>(0));
math::matmul(filter_slice, false, col_matrix, false,
static_cast<float>(1), &out_slice, static_cast<float>(0),
false, static_cast<Otype *>(nullptr));
}
}
}
......
......@@ -15,23 +15,29 @@ limitations under the License. */
#ifdef FUSION_FC_OP
#pragma once
#include <type_traits>
#include "operators/math/math_function.h"
#include "operators/op_param.h"
namespace paddle_mobile {
namespace operators {
template <typename P>
template <typename P, typename S>
void FusionFcCompute(const FusionFcParam<CPU> &param) {
const Tensor *input_x = param.InputX();
const Tensor *input_y = param.InputY();
const Tensor *input_z = param.InputZ();
auto *input_z_data = input_z->data<float>();
Tensor *input_z = param.InputZ();
S *input_z_data = input_z->data<S>();
int axis = param.Axis();
Tensor *out = param.Out();
// int m = out->dims()[0];
// int n = out->dims()[1];
auto *out_data = out->mutable_data<float>();
auto *out_data = out->mutable_data<P>();
float alpha = 1.0f;
float beta = 1.0f;
const Tensor x_matrix =
input_x->dims().size() > 2
? framework::ReshapeToMatrix(*input_x, param.XNumColDims())
......@@ -51,21 +57,28 @@ void FusionFcCompute(const FusionFcParam<CPU> &param) {
axis = (axis == -1 ? out_dim.size() - input_z->dims().size() : axis);
PADDLE_MOBILE_ENFORCE(axis == 1, " to fit broadcast, axis = 1. ");
int64_t classes = input_z->numel();
for (int i = 0; i < out_dim[0]; i++) {
memory::Copy(out_data + i * classes, input_z_data, sizeof(float) * classes);
}
if (std::is_same<P, int8_t>::value) {
#ifdef FUSION_FC_INT8_OP
alpha = param.InputScale()->data<float>()[0];
beta = 0.0f;
math::matmul(x_matrix, false, y_matrix, false, alpha, out, beta, false,
input_z_data, true);
#endif
} else {
// the dimension of bias_data matches the second dimension of out
int64_t classes = input_z->numel();
for (int i = 0; i < out_dim[0]; i++) {
memory::Copy(out_data + i * classes, input_z_data,
sizeof(float) * classes);
}
// for (int i = 0; i < out->numel(); i++) {
// DLOG << out_data[i];
// }
// the dimension of bias_data matches the dimension of out
math::matmul<float>(x_matrix, false, y_matrix, false, static_cast<float>(1),
out, static_cast<float>(1), false);
math::matmul<float>(x_matrix, false, y_matrix, false, alpha, out, beta,
false);
}
PADDLE_MOBILE_ENFORCE(out_dim.size() == 2, " out_dim.size must be 2.");
// if (out_dim.size() != 2) {
// out->Resize(out_dim);
// }
// if (out_dim.size() != 2) {
// out->Resize(out_dim);
// }
}
} // namespace operators
......
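In the int8 branch above, addOnRow = true asks the GEMM to add the bias vector along each output row (one entry per output column), mirroring what the float branch does by copying input_z_data into every row of out before its GEMM; the assumed write-back (see WriteWithAddScaleT in gemm.h) is:
// out[i][j] = saturate_int8(alpha * (acc[i][j] + bias[j]))  // alpha = input scale, beta = 0, no relu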
......@@ -73,8 +73,9 @@ void MulCompute(const MulParam<CPU> &param) {
}
if (param.InputX()->type() == typeid(int8_t)) {
out->mutable_data<int32_t>();
math::matmul<int8_t>(x_matrix, false, y_matrix, false,
static_cast<int8_t>(1), out, static_cast<int8_t>(0));
math::matmul<float, int32_t>(x_matrix, false, y_matrix, false,
static_cast<float>(1), out,
static_cast<float>(0));
} else {
out->mutable_data<float>();
......
......@@ -23,20 +23,22 @@ namespace paddle_mobile {
namespace operators {
using framework::Tensor;
inline void PoolBasic(std::string pooling_type, std::vector<int> ksize,
std::vector<int> strides, std::vector<int> paddings,
const Tensor *in_x, Tensor *out) {
template <typename T, typename S>
void PoolBasic(std::string pooling_type, std::vector<int> ksize,
std::vector<int> strides, std::vector<int> paddings,
const Tensor *in_x, Tensor *out) {
if (pooling_type == "max") {
math::PoolFunctor<CPU, math::MaxPool<float>, float> pool2d_forward;
math::MaxPool<float> pool_process;
math::PoolFunctor<CPU, math::MaxPool<T>, T> pool2d_forward;
math::MaxPool<T> pool_process;
pool2d_forward(*in_x, ksize, strides, paddings, pool_process, out);
} else if (pooling_type == "avg") {
math::PoolFunctor<CPU, math::AvgPool<float>, float> pool2d_forward;
math::AvgPool<float> pool_process;
math::PoolFunctor<CPU, math::AvgPool<T, S>, T> pool2d_forward;
math::AvgPool<T, S> pool_process;
pool2d_forward(*in_x, ksize, strides, paddings, pool_process, out);
}
}
template <typename P>
void PoolCompute(const PoolParam<CPU> &param) {
const Tensor *in_x = param.Input();
......@@ -52,50 +54,67 @@ void PoolCompute(const PoolParam<CPU> &param) {
LOG(paddle_mobile::LogLevel::kLOG_ERROR)
<< "Pool op only supports 2D and 3D input.";
}
if (param.isGlobalPooling()) {
for (size_t i = 0; i < ksize.size(); ++i) {
paddings[i] = 0;
ksize[i] = static_cast<int>(in_x->dims()[i + 2]);
}
}
if (ksize[0] == 3 && ksize[0] == ksize[1]) {
if (pooling_type == "max") {
if (strides[0] == strides[1] && strides[0] == 1 &&
paddings[0] == paddings[1] && paddings[1] == 1) {
math::Pool3x3Maxs1p1(in_x, out);
if (in_x->type() == typeid(int8_t)) {
if (pooling_type == "max" && ksize[0] == 3 && ksize[0] == ksize[1]) {
if (strides[0] == strides[1] && strides[0] == 1) {
math::Pool3x3Maxs1_int8(in_x, out, paddings[0], paddings[1]);
} else if (strides[0] == strides[1] && strides[0] == 2) {
math::Pool3x3Maxs2_int8(in_x, out, paddings[0], paddings[1]);
} else {
math::Pool3x3Max(strides, paddings, in_x, out);
}
} else if (pooling_type == "avg") {
if (strides[0] == strides[1] && strides[0] == 1 &&
paddings[0] == paddings[1] && paddings[1] == 1) {
math::Pool3x3Avgs1p1(in_x, out);
} else {
math::Pool3x3Avg(strides, paddings, in_x, out);
math::Pool3x3Max_int8(strides, paddings, in_x, out);
}
} else {
PoolBasic<int8_t, int32_t>(pooling_type, ksize, strides, paddings, in_x,
out);
}
} else {
if (ksize[0] == 3 && ksize[0] == ksize[1]) {
if (pooling_type == "max") {
if (strides[0] == strides[1] && strides[0] == 1 &&
paddings[0] == paddings[1] && paddings[1] == 1) {
math::Pool3x3Maxs1p1(in_x, out);
} else {
math::Pool3x3Max(strides, paddings, in_x, out);
}
} else if (pooling_type == "avg") {
if (strides[0] == strides[1] && strides[0] == 1 &&
paddings[0] == paddings[1] && paddings[1] == 1) {
math::Pool3x3Avgs1p1(in_x, out);
} else {
math::Pool3x3Avg(strides, paddings, in_x, out);
}
}
} else if (ksize[0] == 2 && ksize[0] == ksize[1] && strides[0] == 2 &&
strides[0] == strides[1] && paddings[0] == paddings[1] &&
paddings[1] == 0) {
} else if (ksize[0] == 2 && ksize[0] == ksize[1] && strides[0] == 2 &&
strides[0] == strides[1] && paddings[0] == paddings[1] &&
paddings[1] == 0) {
#if __ARM_NEON
#if __aarch64__
PoolBasic(pooling_type, ksize, strides, paddings, in_x, out);
PoolBasic<float, float>(pooling_type, ksize, strides, paddings, in_x,
out);
#else
/// todo: fix bug in Pool2x2
if (pooling_type == "max") {
math::Pool2x2Maxs2p0(strides, paddings, in_x, out);
} else if (pooling_type == "avg") {
math::Pool2x2Avgs2p0(strides, paddings, in_x, out);
}
/// todo: fix bug in Pool2x2
if (pooling_type == "max") {
math::Pool2x2Maxs2p0(strides, paddings, in_x, out);
} else if (pooling_type == "avg") {
math::Pool2x2Avgs2p0(strides, paddings, in_x, out);
}
#endif
#else
PoolBasic(pooling_type, ksize, strides, paddings, in_x, out);
PoolBasic<float, float>(pooling_type, ksize, strides, paddings, in_x,
out);
#endif // __ARM_NEON
} else {
PoolBasic(pooling_type, ksize, strides, paddings, in_x, out);
} else {
PoolBasic<float, float>(pooling_type, ksize, strides, paddings, in_x,
out);
}
}
}
......
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma OPENCL EXTENSION cl_khr_fp16 : enable
__kernel void dropout(__read_only image2d_t input_image,
__write_only image2d_t output_image,
__private const int out_W,
__private const float dropoutPro) {
const int out_c = get_global_id(0);
const int out_w = get_global_id(1);
const int out_nh = get_global_id(2);
int2 output_pos;
output_pos.x = out_c * out_W + out_w;
output_pos.y = out_nh;
const sampler_t sampler = CLK_NORMALIZED_COORDS_TRUE |
CLK_ADDRESS_CLAMP |
CLK_FILTER_NEAREST;
half4 input;
half4 output;
input = read_imageh(input_image, sampler, output_pos);
half4 dropout = (half4)(1 - dropoutPro);
output = dropout * input;
write_imageh(output_image, output_pos, output);
}
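The kernel above implements the inference-time behaviour of (non-inverted) dropout: it simply scales each element by the keep probability. A scalar C++ sketch of the same math (illustrative only):
// output = (1 - dropout_prob) * input; e.g. dropout_prob = 0.2 keeps 80% of each activation.
inline float DropoutInferenceSketch(float x, float dropout_prob) {
  return (1.0f - dropout_prob) * x;
}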
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifdef DROPOUT_OP
#include "operators/kernel/dropout_kernel.h"
namespace paddle_mobile {
namespace operators {
template <>
bool DropoutKernel<GPU_CL, float>::Init(DropoutParam<GPU_CL> *param) {
this->cl_helper_.AddKernel("dropout", "dropout_kernel.cl");
return true;
}
template <>
void DropoutKernel<GPU_CL, float>::Compute(const DropoutParam<GPU_CL> &param) {
auto kernel = this->cl_helper_.KernelAt(0);
auto default_work_size = this->cl_helper_.DefaultWorkSize(*(param.Out()));
auto *input_image = param.InputX()->GetCLImage();
auto *output_image = param.Out()->GetCLImage();
const float dropoutProb = param.DropoutProb();
const auto &inputDim = param.InputX()->dims();
int input_dims[4] = {1, 1, 1, 1};
// 1 1000 1 1
for (int i = 0; i < inputDim.size(); i++) {
input_dims[4 - inputDim.size() + i] = inputDim[i];
}
int out_W = input_dims[1];
cl_int status;
status = clSetKernelArg(kernel, 0, sizeof(cl_mem), &input_image);
CL_CHECK_ERRORS(status);
status = clSetKernelArg(kernel, 1, sizeof(cl_mem), &output_image);
CL_CHECK_ERRORS(status);
status = clSetKernelArg(kernel, 2, sizeof(int), &out_W);
CL_CHECK_ERRORS(status);
status = clSetKernelArg(kernel, 3, sizeof(float), &dropoutProb);
CL_CHECK_ERRORS(status);
status = clEnqueueNDRangeKernel(
this->cl_helper_.CLCommandQueue(), kernel, default_work_size.size(), NULL,
default_work_size.data(), NULL, 0, NULL, NULL);
CL_CHECK_ERRORS(status);
}
} // namespace operators
} // namespace paddle_mobile
#endif
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifdef MUL_OP
#include "operators/kernel/mul_kernel.h"
namespace paddle_mobile {
namespace operators {
template <>
bool MulKernel<GPU_CL, float>::Init(MulParam<GPU_CL> *param) {
return true;
}
template <>
void MulKernel<GPU_CL, float>::Compute(const MulParam<GPU_CL> &param) {}
template class MulKernel<GPU_CL, float>;
} // namespace operators
} // namespace paddle_mobile
#endif
......@@ -15,6 +15,10 @@ limitations under the License. */
#pragma once
#include <string>
#include "common/log.h"
#include "memory/t_malloc.h"
#ifdef _OPENMP
#include <omp.h>
#endif
// matrix element access macros, assuming row-major storage
#define A(i, j) A[(i)*lda + (j)]
......@@ -23,10 +27,12 @@ limitations under the License. */
#if __aarch64__
#define MR_INT8 4
#define NR_INT8 2
#define MR 6
#define NR 16
#else
#define MR_INT8 4
#define NR_INT8 2
#define MR 6
#define NR 8
#endif
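// MR/NR (and MR_INT8/NR_INT8) are presumably the micro-kernel tile sizes: each inner-kernel call
// computes an MR x NR block of C, matching the AddDot4x2 / AddDot6x8 kernels declared below.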
......@@ -170,6 +176,7 @@ void PackMatrixB(int k, int n, int n_tail, const float *B, int ldb,
void SgemmWithBn(int m, int n, int k, float alpha, const float *A, int lda,
const float *B, int ldb, float beta, float *C, int ldc,
bool relu, float *new_scale, float *new_bias, float *bias);
void SgemmWithPRelu(int m, int n, int k, const float *A, int lda,
const float *B, int ldb, float *C, int ldc, float *p,
std::string mode, float *bias, float *bias1);
......@@ -193,52 +200,72 @@ void PackMatrixB(int k, int n, int n_tail, const float *B, int ldb,
// 8 bits int small block inner product
void AddDot4x8(int32_t k, const int8_t *a, const int8_t *b, int32_t *c,
int32_t ldc);
void AddDot4x2(int32_t k, const int8_t *a, const int8_t *b, int32_t *c,
int32_t ldc);
void AddDot6x8(int32_t k, const int8_t *a, const int8_t *b, int32_t *c,
int32_t ldc);
// 8 bits int inner product
void InnerKernelWithBias(int32_t mc, int32_t nc, int8_t alpha,
const int8_t *a, const int8_t *b, int8_t beta,
int32_t *c, int32_t *C, int32_t ldc, bool relu,
int8_t *bias);
template <typename Otype>
void InnerKernel(int32_t mc, int32_t nc, float alpha, const int8_t *a,
const int8_t *b, float beta, int32_t *c, Otype *C,
int32_t ldc, bool relu);
template <typename Otype>
void InnerKernelWithBias(int32_t mc, int32_t nc, float alpha, const int8_t *a,
const int8_t *b, float beta, int32_t *c, Otype *C,
int32_t ldc, bool relu, int32_t *bias,
bool addOnRow = false);
// 8 bits int pack function
void PackMatrixA_4r(int32_t m, int32_t k, int32_t m_tail, const int8_t *A,
int32_t lda, int8_t *buffer);
void PackMatrixA_4r_16(int32_t m, int32_t k, int32_t m_tail, const int8_t *A,
int32_t lda, int8_t *buffer);
void PackMatrixA_6r(int32_t m, int32_t k, int32_t m_tail, const int8_t *A,
int32_t lda, int8_t *buffer);
void PackMatrixB_2c_16(int32_t k, int32_t n, int32_t n_tail, const int8_t *B,
int32_t ldb, int8_t *buffer);
void PackMatrixB_8c(int32_t k, int32_t n, int32_t n_tail, const int8_t *B,
int32_t ldb, int8_t *buffer);
void PackMatrixA_omp_4r(int32_t m, int32_t k, int32_t m_tail, const int8_t *A,
int32_t lda, int8_t *buffer);
void PackMatrixB_omp_8c(int32_t k, int32_t n, int32_t n_tail, const int8_t *B,
int32_t ldb, int8_t *buffer);
void PackMatrixA_omp_4r_16(int32_t m, int32_t k, int32_t m_tail,
const int8_t *A, int32_t lda, int8_t *buffer);
void PackMatrixB_omp_2c_16(int32_t k, int32_t n, int32_t n_tail,
const int8_t *B, int32_t ldb, int8_t *buffer);
// 8 bits int matrix product
void Sgemm(int32_t m, int32_t n, int32_t k, int8_t alpha, const int8_t *A,
int32_t lda, const int8_t *B, int32_t ldb, int8_t beta, int32_t *C,
int32_t ldc, bool relu, int8_t *bias);
void Sgemm_omp(int32_t m, int32_t n, int32_t k, int8_t alpha, const int8_t *A,
int32_t lda, const int8_t *B, int32_t ldb, int8_t beta,
int32_t *C, int32_t ldc, bool relu, int8_t *bias);
template <typename Itype, typename Btype, typename Otype>
void Sgemm_omp(int32_t m, int32_t n, int32_t k, float alpha, const Itype *A,
int32_t lda, const Itype *B, int32_t ldb, float beta, Otype *C,
int32_t ldc, bool relu, Btype *bias, bool addOnRow = false);
template <typename Otype>
void Sgemm_omp(int32_t m, int32_t n, int32_t k, float alpha, const int8_t *A,
int32_t lda, const int8_t *B, int32_t ldb, float beta,
Otype *C, int32_t ldc, bool relu, int32_t *bias,
bool addOnRow = false);
template <typename Itype, typename Btype, typename Otype>
void Sgemm(int32_t m, int32_t n, int32_t k, float alpha, const Itype *A,
int32_t lda, const Itype *B, int32_t ldb, float beta, Otype *C,
int32_t ldc, bool relu, Btype *bias, bool addOnRow = false);
template <typename Otype>
void Sgemm(int32_t m, int32_t n, int32_t k, float alpha, const int8_t *A,
int32_t lda, const int8_t *B, int32_t ldb, float beta, Otype *C,
int32_t ldc, bool relu, int32_t *bias, bool addOnRow = false);
// 8 bits int write back
// C = alpha * A * B + beta * C
void WriteWithAlphaBeta(int32_t mc, int32_t nc, int32_t *c, int32_t *C,
int32_t ldc);
// C = A * B
void WriteBasic(int32_t mc, int32_t nc, int32_t *c, int32_t *C, int32_t ldc);
// C = A * B + C
void WriteWithAdd(int32_t mc, int32_t nc, int32_t *c, int32_t *C,
int32_t ldc);
// C = A * B + bias
void WriteWithAddV1(int32_t mc, int32_t nc, int32_t *c, int32_t *C,
int32_t ldc, int8_t *bias);
// C = A * B + C, relu(C)
void WriteWithAddRelu(int32_t mc, int32_t nc, int32_t *c, int32_t *C,
int32_t ldc);
// C = A * B + bias, relu(C)
void WriteWithAddReluV1(int32_t mc, int32_t nc, int32_t *c, int32_t *C,
int32_t ldc, int8_t *bias);
// C = A * B + bias, scale * relu(C)
void WriteWithAddReluScale(int32_t mc, int32_t nc, int32_t *c, int8_t *C,
int32_t ldc, int32_t *bias, float scale);
// C = A * B + bias, scale * C, bias is added on column
void WriteWithAddScale(int32_t mc, int32_t nc, int32_t *c, int8_t *C,
int32_t ldc, int32_t *bias, float scale);
// C = A * B + bias, scale * C, bias is added on row
void WriteWithAddScaleT(int32_t mc, int32_t nc, int32_t *c, int8_t *C,
int32_t ldc, int32_t *bias, float scale);
private:
int MC = 0;
......@@ -254,10 +281,218 @@ void PackMatrixB(int k, int n, int n_tail, const float *B, int ldb,
// 8 bits int
int8_t *packedA_int8;
int8_t *packedB_int8;
int32_t *packedC_int8;
int32_t *packedC_int32;
int8_t *zero_int8;
};
// 8 bits int matrix product (m*k x k*n)
template <typename Otype>
void Gemm::Sgemm(int32_t m, int32_t n, int32_t k, float alpha, const int8_t *A,
int32_t lda, const int8_t *B, int32_t ldb, float beta,
Otype *C, int32_t ldc, bool relu, int32_t *bias,
bool addOnRow) {
// L1 data cache is 32 KiB (per Cortex-A57, Cortex-A72, Cortex-A73)
// L2 cache is 0.5~4 MiB (Cortex-A72 cluster)
int32_t L1 = 32 * 1024;
int32_t L2 = 512 * 1024;
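// (k + 15) - ((k + 15) & 15) rounds k up to a multiple of 16, matching the 16-deep
// panels produced by PackMatrixA_4r_16 / PackMatrixB_2c_16.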
const int32_t k_complete = (k + 15) - ((k + 15) & 15);
KC = k_complete;
MC = L1 / (KC * sizeof(int8_t));
NC = L2 / (KC * sizeof(int8_t));
// make sure MC is multiple of MR_INT8, and NC is multiple of NR_INT8
if (MC == 0) {
MC = MR_INT8;
} else {
int32_t mblock_num = (m + MC - 1) / MC;
MC = (m + mblock_num - 1) / mblock_num;
MC = (MC + MR_INT8 - 1) / MR_INT8 * MR_INT8;
}
// DLOG << "mblock_num = " << mblock_num << ", MC = " << MC << "\n";
if (NC == 0) {
NC = NR_INT8;
} else {
int32_t nblock_num = (n + NC - 1) / NC;
NC = (n + nblock_num - 1) / nblock_num;
NC = (NC + NR_INT8 - 1) / NR_INT8 * NR_INT8;
}
// DLOG << "nblock_num = " << nblock_num << ", NC = " << NC << "\n";
packedA_int8 = static_cast<int8_t *>(
paddle_mobile::memory::Alloc(sizeof(int8_t) * MC * KC));
packedB_int8 = static_cast<int8_t *>(
paddle_mobile::memory::Alloc(sizeof(int8_t) * KC * NC));
packedC_int32 = static_cast<int32_t *>(
paddle_mobile::memory::Alloc(sizeof(int32_t) * MC * NC));
zero_int8 =
static_cast<int8_t *>(paddle_mobile::memory::Alloc(sizeof(int8_t) * k));
memset(static_cast<void *>(zero_int8), 0, sizeof(int8_t) * k);
int32_t mc, nc;
for (int32_t j = 0; j < n; j += NC) {
nc = s_min(n - j, NC);
PackMatrixB_2c_16(k, nc, nc % NR_INT8, &B(0, j), ldb, packedB_int8);
for (int32_t i = 0; i < m; i += MC) {
mc = s_min(m - i, MC);
PackMatrixA_4r_16(mc, k, mc % MR_INT8, &A(i, 0), lda, packedA_int8);
if (bias == nullptr) {
InnerKernel(mc, nc, alpha, packedA_int8, packedB_int8, beta,
packedC_int32, &C(i, j), ldc, relu);
} else {
if (addOnRow) {
InnerKernelWithBias(mc, nc, alpha, packedA_int8, packedB_int8, beta,
packedC_int32, &C(i, j), ldc, relu, bias + j,
addOnRow);
} else {
InnerKernelWithBias(mc, nc, alpha, packedA_int8, packedB_int8, beta,
packedC_int32, &C(i, j), ldc, relu, bias + i,
addOnRow);
}
}
}
}
paddle_mobile::memory::Free(packedA_int8);
paddle_mobile::memory::Free(packedB_int8);
paddle_mobile::memory::Free(packedC_int32);
paddle_mobile::memory::Free(zero_int8);
}
// 8 bits int matrix product (m*k x k*n), omp version
template <typename Otype>
void Gemm::Sgemm_omp(int32_t m, int32_t n, int32_t k, float alpha,
const int8_t *A, int32_t lda, const int8_t *B, int32_t ldb,
float beta, Otype *C, int32_t ldc, bool relu,
int32_t *bias, bool addOnRow) {
#ifdef _OPENMP
int32_t max_threads = omp_get_max_threads();
#else
int32_t max_threads = 1;
#endif
int32_t L1 = 64 / max_threads * 1024;
const int32_t k_complete = (k + 15) - ((k + 15) & 15);
KC = k_complete;
zero_int8 =
static_cast<int8_t *>(paddle_mobile::memory::Alloc(sizeof(int8_t) * k));
memset(static_cast<void *>(zero_int8), 0, sizeof(int8_t) * k);
if (m > n) {
// partition A into blocks
MC = L1 / (KC * sizeof(int8_t));
if (MC == 0) {
MC = MR_INT8;
} else {
int32_t mblock_num = (m + MC - 1) / MC;
MC = (m + mblock_num - 1) / mblock_num;
MC = (MC + MR_INT8 - 1) / MR_INT8 * MR_INT8;
}
// pad B to a full block
NC = (n + NR_INT8 - 1) / NR_INT8 * NR_INT8;
packedB_int8 = static_cast<int8_t *>(
paddle_mobile::memory::Alloc(sizeof(int8_t) * KC * NC));
#if __aarch64__
// TODO()
#else
PackMatrixB_omp_2c_16(k, n, n % NR_INT8, B, ldb, packedB_int8);
#endif
packedA_int8 = static_cast<int8_t *>(
paddle_mobile::memory::Alloc(sizeof(int8_t) * MC * KC * max_threads));
} else {
// partition B into blocks
NC = L1 / (KC * sizeof(int8_t));
if (NC == 0) {
NC = NR_INT8;
} else {
int32_t nblock_num = (n + NC - 1) / NC;
NC = (n + nblock_num - 1) / nblock_num;
NC = (NC + NR_INT8 - 1) / NR_INT8 * NR_INT8;
}
// pad A to a full block
MC = (m + MR_INT8 - 1) / MR_INT8 * MR_INT8;
packedA_int8 = static_cast<int8_t *>(
paddle_mobile::memory::Alloc(sizeof(int8_t) * MC * KC));
#if __aarch64__
// TODO()
#else
PackMatrixA_omp_4r_16(m, k, m % MR_INT8, A, lda, packedA_int8);
#endif
packedB_int8 = static_cast<int8_t *>(
paddle_mobile::memory::Alloc(sizeof(int8_t) * KC * NC * max_threads));
}
packedC_int32 = static_cast<int32_t *>(
paddle_mobile::memory::Alloc(sizeof(int32_t) * MC * NC * max_threads));
if (m > n) {
#pragma omp parallel for
for (int32_t i = 0; i < m; i += MC) {
#ifdef _OPENMP
int32_t local_threads = omp_get_thread_num();
#else
int32_t local_threads = 0;
#endif
int32_t mc;
mc = s_min(m - i, MC);
int8_t *local_A = packedA_int8 + MC * KC * local_threads;
int32_t *local_C = packedC_int32 + MC * NC * local_threads;
#if __aarch64__
// TODO()
#else
PackMatrixA_4r_16(mc, k, mc % MR_INT8, &A(i, 0), lda, local_A);
#endif
if (bias == nullptr) {
InnerKernel(mc, n, alpha, local_A, packedB_int8, beta, local_C,
&C(i, 0), ldc, relu);
} else {
if (addOnRow) {
InnerKernelWithBias(mc, n, alpha, local_A, packedB_int8, beta,
local_C, &C(i, 0), ldc, relu, bias, addOnRow);
} else {
InnerKernelWithBias(mc, n, alpha, local_A, packedB_int8, beta,
local_C, &C(i, 0), ldc, relu, bias + i, addOnRow);
}
}
}
} else {
#pragma omp parallel for
for (int32_t j = 0; j < n; j += NC) {
#ifdef _OPENMP
int32_t local_threads = omp_get_thread_num();
#else
int32_t local_threads = 0;
#endif
int32_t nc;
nc = s_min(n - j, NC);
int8_t *local_B = packedB_int8 + KC * NC * local_threads;
int32_t *local_C = packedC_int32 + MC * NC * local_threads;
#if __aarch64__
// TODO()
#else
PackMatrixB_2c_16(k, nc, nc % NR_INT8, &B(0, j), ldb, local_B);
#endif
if (bias == nullptr) {
InnerKernel(m, nc, alpha, packedA_int8, local_B, beta, local_C,
&C(0, j), ldc, relu);
} else {
if (addOnRow) {
InnerKernelWithBias(m, nc, alpha, packedA_int8, local_B, beta,
local_C, &C(0, j), ldc, relu, bias + j, addOnRow);
} else {
InnerKernelWithBias(m, nc, alpha, packedA_int8, local_B, beta,
local_C, &C(0, j), ldc, relu, bias, addOnRow);
}
}
}
}
paddle_mobile::memory::Free(packedA_int8);
paddle_mobile::memory::Free(packedB_int8);
paddle_mobile::memory::Free(packedC_int32);
paddle_mobile::memory::Free(zero_int8);
}
} // namespace math
} // namespace operators
} // namespace paddle_mobile
This diff is collapsed.
......@@ -27,130 +27,17 @@ namespace paddle_mobile {
namespace operators {
namespace math {
// 8 bits int matrix product (m*k x k*n)
void Gemm::Sgemm_omp(int32_t m, int32_t n, int32_t k, int8_t alpha,
const int8_t *A, int32_t lda, const int8_t *B, int32_t ldb,
int8_t beta, int32_t *C, int32_t ldc, bool relu,
int8_t *bias) {
#ifdef _OPENMP
int32_t max_threads = omp_get_max_threads();
#else
int32_t max_threads = 1;
#endif
int32_t L1 = 64 / max_threads * 1024;
KC = k;
zero_int8 =
static_cast<int8_t *>(paddle_mobile::memory::Alloc(sizeof(int8_t) * KC));
memset(static_cast<void *>(zero_int8), 0, sizeof(int8_t) * KC);
if (m > n) {
// partition A into blocks
MC = L1 / (KC * sizeof(int8_t));
if (MC == 0) {
MC = MR_INT8;
} else {
int32_t mblock_num = (m + MC - 1) / MC;
MC = (m + mblock_num - 1) / mblock_num;
MC = (MC + MR_INT8 - 1) / MR_INT8 * MR_INT8;
}
// pad B to a full block
NC = (n + NR - 1) / NR * NR;
packedB_int8 = static_cast<int8_t *>(
paddle_mobile::memory::Alloc(sizeof(int8_t) * KC * NC));
#if __aarch64__
// TODO(wzzju)
#else
PackMatrixB_omp_8c(KC, n, n % NR, B, ldb, packedB_int8);
#endif
packedA_int8 = static_cast<int8_t *>(
paddle_mobile::memory::Alloc(sizeof(int8_t) * MC * KC * max_threads));
} else {
// partition B into blocks
NC = L1 / (KC * sizeof(int8_t));
if (NC == 0) {
NC = NR;
} else {
int32_t nblock_num = (n + NC - 1) / NC;
NC = (n + nblock_num - 1) / nblock_num;
NC = (NC + NR - 1) / NR * NR;
}
// pad A to a full block
MC = (m + MR_INT8 - 1) / MR_INT8 * MR_INT8;
packedA_int8 = static_cast<int8_t *>(
paddle_mobile::memory::Alloc(sizeof(int8_t) * MC * KC));
#if __aarch64__
// TODO(wzzju)
#else
PackMatrixA_omp_4r(m, KC, m % MR_INT8, A, lda, packedA_int8);
#endif
packedB_int8 = static_cast<int8_t *>(
paddle_mobile::memory::Alloc(sizeof(int8_t) * KC * NC * max_threads));
}
packedC_int8 = static_cast<int32_t *>(
paddle_mobile::memory::Alloc(sizeof(int32_t) * MC * NC * max_threads));
if (m > n) {
#pragma omp parallel for
for (int32_t i = 0; i < m; i += MC) {
#ifdef _OPENMP
int32_t local_threads = omp_get_thread_num();
#else
int32_t local_threads = 0;
#endif
int32_t mc;
mc = s_min(m - i, MC);
int8_t *local_A = packedA_int8 + MC * KC * local_threads;
int32_t *local_C = packedC_int8 + MC * NC * local_threads;
#if __aarch64__
// TODO(wzzju)
#else
PackMatrixA_4r(mc, KC, mc % MR_INT8, &A(i, 0), lda, local_A);
#endif
InnerKernelWithBias(mc, n, alpha, local_A, packedB_int8, beta, local_C,
&C(i, 0), ldc, relu, bias + i);
}
} else {
#pragma omp parallel for
for (int32_t j = 0; j < n; j += NC) {
#ifdef _OPENMP
int32_t local_threads = omp_get_thread_num();
#else
int32_t local_threads = 0;
#endif
int32_t nc;
nc = s_min(n - j, NC);
int8_t *local_B = packedB_int8 + KC * NC * local_threads;
int32_t *local_C = packedC_int8 + MC * NC * local_threads;
#if __aarch64__
// TODO(wzzju)
#else
PackMatrixB_8c(KC, nc, nc % NR, &B(0, j), ldb, local_B);
#endif
InnerKernelWithBias(m, nc, alpha, packedA_int8, local_B, beta, local_C,
&C(0, j), ldc, relu, bias);
}
}
paddle_mobile::memory::Free(packedA_int8);
paddle_mobile::memory::Free(packedB_int8);
paddle_mobile::memory::Free(packedC_int8);
paddle_mobile::memory::Free(zero_int8);
}
void Gemm::PackMatrixB_omp_8c(int32_t k, int32_t n, int32_t n_tail,
const int8_t *B, int32_t ldb, int8_t *buffer) {
const int32_t j_length = n - n_tail;
#pragma omp parallel for
for (int32_t j = 0; j < j_length; j += NR) {
for (int32_t j = 0; j < j_length; j += 8) {
int8_t *local_buffer = buffer + j * k;
for (int32_t i = 0; i < k; ++i) {
const int8_t *b0 = &B(i, j);
#if __ARM_NEON
#if __aarch64__
// TODO(wzzju)
// TODO
#else
asm volatile(
// "pld [%[b0]] \n\t"
......@@ -179,7 +66,7 @@ void Gemm::PackMatrixB_omp_8c(int32_t k, int32_t n, int32_t n_tail,
for (int32_t j = j_length; j < n; ++j) {
*local_buffer++ = *b0++;
}
for (int32_t j = n; j < j_length + NR; ++j) {
for (int32_t j = n; j < j_length + 8; ++j) {
*local_buffer++ = 0;
}
}
......@@ -188,9 +75,9 @@ void Gemm::PackMatrixB_omp_8c(int32_t k, int32_t n, int32_t n_tail,
void Gemm::PackMatrixA_omp_4r(int32_t m, int32_t k, int32_t m_tail,
const int8_t *A, int32_t lda, int8_t *buffer) {
const int i_length = m - m_tail;
const int32_t i_length = m - m_tail;
#pragma omp parallel for
for (int32_t i = 0; i < i_length; i += MR_INT8) {
for (int32_t i = 0; i < i_length; i += 4) {
const int8_t *a0 = A + i * lda;
const int8_t *a1 = A + (i + 1) * lda;
const int8_t *a2 = A + (i + 2) * lda;
......@@ -221,7 +108,7 @@ void Gemm::PackMatrixA_omp_4r(int32_t m, int32_t k, int32_t m_tail,
default:
break;
}
for (int j = 0; j < k; ++j) {
for (int32_t j = 0; j < k; ++j) {
*local_buffer++ = *a0++;
*local_buffer++ = *a1++;
*local_buffer++ = *a2++;
......@@ -230,6 +117,232 @@ void Gemm::PackMatrixA_omp_4r(int32_t m, int32_t k, int32_t m_tail,
}
}
// 8 bits int PackMatrixA_4r
void Gemm::PackMatrixA_omp_4r_16(int32_t m, int32_t k, int32_t m_tail,
const int8_t *A, int32_t lda, int8_t *buffer) {
const int32_t i_length = m - m_tail;
const int32_t k_count = k >> 4;
const int32_t k_tail = k & 15;
#pragma omp parallel for
for (int32_t i = 0; i < i_length; i += 4) {
const int8_t *a0 = A + i * lda;
const int8_t *a1 = A + (i + 1) * lda;
const int8_t *a2 = A + (i + 2) * lda;
const int8_t *a3 = A + (i + 3) * lda;
int8_t *local_buffer = buffer + i * KC;
for (int32_t j = 0; j < k_count; ++j) {
#if __ARM_NEON
#if __aarch64__
// TODO
#else
asm volatile(
"vld1.s8 {d0, d1}, [%[a0]]! \n\t"
"vld1.s8 {d2, d3}, [%[a1]]! \n\t"
"vld1.s8 {d4, d5}, [%[a2]]! \n\t"
"vld1.s8 {d6, d7}, [%[a3]]! \n\t"
"vst1.s8 {d0, d1}, [%[local_buffer]]! \n\t"
"vst1.s8 {d2, d3}, [%[local_buffer]]! \n\t"
"vst1.s8 {d4, d5}, [%[local_buffer]]! \n\t"
"vst1.s8 {d6, d7}, [%[local_buffer]]! \n\t"
: [local_buffer] "+r"(local_buffer), [a0] "+r"(a0), [a1] "+r"(a1),
[a2] "+r"(a2), [a3] "+r"(a3)
:
: "memory", "q0", "q1", "q2", "q3");
#endif // __aarch64__
#else
for (int32_t l = 0; l < 16; ++l) {
*local_buffer++ = *a0++;
}
for (int32_t l = 0; l < 16; ++l) {
*local_buffer++ = *a1++;
}
for (int32_t l = 0; l < 16; ++l) {
*local_buffer++ = *a2++;
}
for (int32_t l = 0; l < 16; ++l) {
*local_buffer++ = *a3++;
}
#endif // __ARM_NEON
}
if (k_tail != 0) {
for (int32_t j = k_count << 4; j < k; ++j) {
*local_buffer++ = *a0++;
}
for (int32_t j = k; j < KC; ++j) {
*local_buffer++ = 0;
}
for (int32_t j = k_count << 4; j < k; ++j) {
*local_buffer++ = *a1++;
}
for (int32_t j = k; j < KC; ++j) {
*local_buffer++ = 0;
}
for (int32_t j = k_count << 4; j < k; ++j) {
*local_buffer++ = *a2++;
}
for (int32_t j = k; j < KC; ++j) {
*local_buffer++ = 0;
}
for (int32_t j = k_count << 4; j < k; ++j) {
*local_buffer++ = *a3++;
}
for (int32_t j = k; j < KC; ++j) {
*local_buffer++ = 0;
}
}
}
if (m_tail != 0) {
const int8_t *a0 = &A(i_length, 0);
const int8_t *a1 = a0 + lda;
const int8_t *a2 = a0 + 2 * lda;
const int8_t *a3 = a0 + 3 * lda;
int8_t *local_buffer = buffer + i_length * KC;
switch (m_tail) {
case 1:
a1 = zero_int8;
case 2:
a2 = zero_int8;
case 3:
a3 = zero_int8;
break;
default:
break;
}
for (int32_t j = 0; j < k_count; ++j) {
#if __ARM_NEON
#if __aarch64__
// TODO
#else
asm volatile(
"vld1.s8 {d0, d1}, [%[a0]]! \n\t"
"vld1.s8 {d2, d3}, [%[a1]]! \n\t"
"vld1.s8 {d4, d5}, [%[a2]]! \n\t"
"vld1.s8 {d6, d7}, [%[a3]]! \n\t"
"vst1.s8 {d0, d1}, [%[local_buffer]]! \n\t"
"vst1.s8 {d2, d3}, [%[local_buffer]]! \n\t"
"vst1.s8 {d4, d5}, [%[local_buffer]]! \n\t"
"vst1.s8 {d6, d7}, [%[local_buffer]]! \n\t"
: [local_buffer] "+r"(local_buffer), [a0] "+r"(a0), [a1] "+r"(a1),
[a2] "+r"(a2), [a3] "+r"(a3)
:
: "memory", "q0", "q1", "q2", "q3");
#endif // __aarch64__
#else
for (int32_t l = 0; l < 16; ++l) {
*local_buffer++ = *a0++;
}
for (int32_t l = 0; l < 16; ++l) {
*local_buffer++ = *a1++;
}
for (int32_t l = 0; l < 16; ++l) {
*local_buffer++ = *a2++;
}
for (int32_t l = 0; l < 16; ++l) {
*local_buffer++ = *a3++;
}
#endif // __ARM_NEON
}
if (k_tail != 0) {
for (int32_t j = k_count << 4; j < k; ++j) {
*local_buffer++ = *a0++;
}
for (int32_t j = k; j < KC; ++j) {
*local_buffer++ = 0;
}
for (int32_t j = k_count << 4; j < k; ++j) {
*local_buffer++ = *a1++;
}
for (int32_t j = k; j < KC; ++j) {
*local_buffer++ = 0;
}
for (int32_t j = k_count << 4; j < k; ++j) {
*local_buffer++ = *a2++;
}
for (int32_t j = k; j < KC; ++j) {
*local_buffer++ = 0;
}
for (int32_t j = k_count << 4; j < k; ++j) {
*local_buffer++ = *a3++;
}
for (int32_t j = k; j < KC; ++j) {
*local_buffer++ = 0;
}
}
}
}
// 8 bits int PackMatrixB
void Gemm::PackMatrixB_omp_2c_16(int32_t k, int32_t n, int32_t n_tail,
const int8_t *B, int32_t ldb, int8_t *buffer) {
const int32_t j_length = n - n_tail;
const int32_t k_count = k >> 4;
const int32_t k_tail = k & 15;
#pragma omp parallel for
for (int32_t j = 0; j < j_length; j += 2) {
int8_t *local_buffer = buffer + j * KC;
for (int32_t i = 0; i < k_count; ++i) {
const int8_t *b0 = &B((i << 4), j);
const int8_t *b1 = &B((i << 4), j + 1);
for (int m = 0; m < 16; ++m) {
*local_buffer++ = *b0;
b0 += ldb;
}
for (int m = 0; m < 16; ++m) {
*local_buffer++ = *b1;
b1 += ldb;
}
}
if (k_tail != 0) {
const int8_t *b0 = &B((k_count << 4), j);
const int8_t *b1 = &B((k_count << 4), j + 1);
for (int32_t j = k_count << 4; j < k; ++j) {
*local_buffer++ = *b0;
b0 += ldb;
}
for (int32_t j = k; j < KC; ++j) {
*local_buffer++ = 0;
}
for (int32_t j = k_count << 4; j < k; ++j) {
*local_buffer++ = *b1;
b1 += ldb;
}
for (int32_t j = k; j < KC; ++j) {
*local_buffer++ = 0;
}
}
}
if (n_tail != 0) {
int8_t *local_buffer = buffer + j_length * KC;
for (int32_t i = 0; i < k_count; ++i) {
const int8_t *b0 = &B((i << 4), j_length);
for (int m = 0; m < 16; ++m) {
*local_buffer++ = *b0;
b0 += ldb;
}
for (int m = 0; m < 16; ++m) {
*local_buffer++ = 0;
}
}
if (k_tail != 0) {
const int8_t *b0 = &B((k_count << 4), j_length);
for (int32_t j = k_count << 4; j < k; ++j) {
*local_buffer++ = *b0;
b0 += ldb;
}
for (int32_t j = k; j < KC; ++j) {
*local_buffer++ = 0;
}
for (int32_t j = k_count << 4; j < KC; ++j) {
*local_buffer++ = 0;
}
}
}
}
} // namespace math
} // namespace operators
} // namespace paddle_mobile
......@@ -34,12 +34,12 @@ struct GRUUnitFunctor<CPU, T> {
gemm.Sgemm_omp(batch_size, frame_size * 2, frame_size, 1,
value.prev_out_value, frame_size, value.gate_weight,
frame_size * 2, 1, value.gate_value, frame_size * 3, false,
nullptr);
static_cast<float *>(nullptr));
#else
gemm.Sgemm(batch_size, frame_size * 2, frame_size, 1,
value.prev_out_value, frame_size, value.gate_weight,
frame_size * 2, 1, value.gate_value, frame_size * 3, false,
nullptr);
static_cast<float *>(nullptr));
#endif
}
......@@ -51,12 +51,12 @@ struct GRUUnitFunctor<CPU, T> {
gemm.Sgemm_omp(batch_size, frame_size, frame_size, 1,
value.reset_output_value, frame_size, value.state_weight,
frame_size, 1, value.gate_value + frame_size * 2,
frame_size * 3, false, nullptr);
frame_size * 3, false, static_cast<float *>(nullptr));
#else
gemm.Sgemm(batch_size, frame_size, frame_size, 1,
value.reset_output_value, frame_size, value.state_weight,
frame_size, 1, value.gate_value + frame_size * 2,
frame_size * 3, false, nullptr);
frame_size * 3, false, static_cast<float *>(nullptr));
#endif
}
......
......@@ -28,7 +28,13 @@ template <typename T>
void matmul(const framework::Tensor &matrix_a, bool trans_a,
const framework::Tensor &matrix_b, bool trans_b, T alpha,
framework::Tensor *matrix_out, T beta, bool relu = false,
T *bias = nullptr);
float *bias = nullptr);
template <typename T, typename S>
void matmul(const framework::Tensor &matrix_a, bool trans_a,
const framework::Tensor &matrix_b, bool trans_b, T alpha,
framework::Tensor *matrix_out, T beta, bool relu = false,
S *bias = nullptr, bool addOnRow = false);
template <typename T>
void matmulWithBn(const framework::Tensor &matrix_a, bool trans_a,
......
......@@ -20,11 +20,12 @@ limitations under the License. */
namespace paddle_mobile {
namespace operators {
namespace math {
template <>
void matmul<int8_t>(const framework::Tensor &matrix_a, bool trans_a,
const framework::Tensor &matrix_b, bool trans_b,
int8_t alpha, framework::Tensor *matrix_out, int8_t beta,
bool relu, int8_t *bias) {
void matmul(const framework::Tensor &matrix_a, bool trans_a,
const framework::Tensor &matrix_b, bool trans_b, float alpha,
framework::Tensor *matrix_out, float beta, bool relu, int32_t *bias,
bool addOnRow) {
auto dim_a = matrix_a.dims();
auto dim_b = matrix_b.dims();
auto dim_out = matrix_out->dims();
......@@ -52,21 +53,43 @@ void matmul<int8_t>(const framework::Tensor &matrix_a, bool trans_a,
}
#ifdef _OPENMP
gemm.Sgemm_omp(M, N, K, alpha, a, K, matrix_b.data<int8_t>(), N, beta,
matrix_out->data<int32_t>(), N, relu, bias);
if (bias != nullptr) {
gemm.Sgemm_omp(M, N, K, alpha, a, K, matrix_b.data<int8_t>(), N, beta,
matrix_out->data<int8_t>(), N, relu, bias, addOnRow);
} else {
gemm.Sgemm_omp(M, N, K, alpha, a, K, matrix_b.data<int8_t>(), N, beta,
matrix_out->data<int32_t>(), N, relu, bias, addOnRow);
}
#else
gemm.Sgemm(M, N, K, alpha, a, K, matrix_b.data<int8_t>(), N, beta,
matrix_out->data<int32_t>(), N, relu, bias);
if (bias != nullptr) {
gemm.Sgemm(M, N, K, alpha, a, K, matrix_b.data<int8_t>(), N, beta,
matrix_out->data<int8_t>(), N, relu, bias, addOnRow);
} else {
gemm.Sgemm(M, N, K, alpha, a, K, matrix_b.data<int8_t>(), N, beta,
matrix_out->data<int32_t>(), N, relu, bias, addOnRow);
}
#endif
} else {
#ifdef _OPENMP
gemm.Sgemm_omp(M, N, K, alpha, matrix_a.data<int8_t>(), K,
matrix_b.data<int8_t>(), N, beta,
matrix_out->data<int32_t>(), N, relu, bias);
if (bias != nullptr) {
gemm.Sgemm_omp(M, N, K, alpha, matrix_a.data<int8_t>(), K,
matrix_b.data<int8_t>(), N, beta,
matrix_out->data<int8_t>(), N, relu, bias, addOnRow);
} else {
gemm.Sgemm_omp(M, N, K, alpha, matrix_a.data<int8_t>(), K,
matrix_b.data<int8_t>(), N, beta,
matrix_out->data<int32_t>(), N, relu, bias, addOnRow);
}
#else
gemm.Sgemm(M, N, K, alpha, matrix_a.data<int8_t>(), K,
matrix_b.data<int8_t>(), N, beta, matrix_out->data<int32_t>(), N,
relu, bias);
if (bias != nullptr) {
gemm.Sgemm(M, N, K, alpha, matrix_a.data<int8_t>(), K,
matrix_b.data<int8_t>(), N, beta, matrix_out->data<int8_t>(),
N, relu, bias, addOnRow);
} else {
gemm.Sgemm(M, N, K, alpha, matrix_a.data<int8_t>(), K,
matrix_b.data<int8_t>(), N, beta, matrix_out->data<int32_t>(),
N, relu, bias, addOnRow);
}
#endif
}
}
......
......@@ -38,6 +38,7 @@ void Pool3x3Avgs1p1(const Tensor *input, Tensor *output) {
const int input_width = static_cast<int>(input->dims()[3]);
const int output_height = static_cast<int>(output->dims()[2]);
const int output_width = static_cast<int>(output->dims()[3]);
output->mutable_data<float>();
const int hxw = input_height * input_width;
......@@ -472,7 +473,7 @@ void Pool3x3Maxs1p1(const Tensor *input, Tensor *output) {
const int inputdata_channel_stride = h_in * w_in;
const int input_batch_stride = output_channels * inputdata_channel_stride;
const int output_batch_stride = output_channels * outputdata_channel_stride;
float *out_data = output->data<float>();
float *out_data = output->mutable_data<float>();
const float *input_data = input->data<float>();
for (int k = 0; k < batch_size; ++k) {
#pragma omp parallel for
......
......@@ -28,15 +28,21 @@ limitations under the License. */
namespace paddle_mobile {
namespace operators {
namespace math {
using framework::Tensor;
using std::vector;
void Pool3x3Avgs1p1(const Tensor *input, Tensor *output);
void Pool3x3Maxs1p1(const Tensor *input, Tensor *output);
void Pool3x3Max(vector<int> strides, vector<int> paddings, const Tensor *input,
Tensor *output);
void Pool3x3Avg(vector<int> strides, vector<int> paddings, const Tensor *in_x,
Tensor *out);
void Pool3x3Avgs1p1(const framework::Tensor *input, framework::Tensor *output);
void Pool3x3Maxs1p1(const framework::Tensor *input, framework::Tensor *output);
void Pool3x3Max(std::vector<int> strides, std::vector<int> paddings,
const framework::Tensor *input, framework::Tensor *output);
void Pool3x3Avg(std::vector<int> strides, std::vector<int> paddings,
const framework::Tensor *in_x, framework::Tensor *out);
void Pool3x3Maxs1_int8(const framework::Tensor *input,
framework::Tensor *output, int32_t pad_h, int32_t pad_w);
void Pool3x3Maxs2_int8(const framework::Tensor *input,
framework::Tensor *output, int32_t pad_h, int32_t pad_w);
void Pool3x3Max_int8(const std::vector<int> &strides,
const std::vector<int> &paddings,
const framework::Tensor *input, framework::Tensor *output);
} // namespace math
} // namespace operators
} // namespace paddle_mobile
......
This diff is collapsed.
......@@ -70,15 +70,15 @@ class PoolFunctor<CPU, PoolProcess, T> {
int wend = std::min(wstart + ksize_width, input_width);
wstart = std::max(wstart, 0);
T ele = pool_process.initial();
auto ele = pool_process.initial();
for (int h = hstart; h < hend; ++h) {
for (int w = wstart; w < wend; ++w) {
pool_process.compute(input_data[h * input_width + w], &ele);
}
}
int pool_size = (hend - hstart) * (wend - wstart);
pool_process.finalize(static_cast<T>(pool_size), &ele);
output_data[ph * output_width + pw] = ele;
pool_process.finalize(static_cast<float>(pool_size), &ele);
output_data[ph * output_width + pw] = static_cast<T>(ele);
}
}
input_data += input_stride;
......@@ -88,8 +88,10 @@ class PoolFunctor<CPU, PoolProcess, T> {
}
};
template class PoolFunctor<CPU, math::AvgPool<float>, float>;
template class PoolFunctor<CPU, math::AvgPool<float, float>, float>;
template class PoolFunctor<CPU, math::MaxPool<float>, float>;
template class PoolFunctor<CPU, math::AvgPool<int8_t, int32_t>, int8_t>;
template class PoolFunctor<CPU, math::MaxPool<int8_t>, int8_t>;
} // namespace math
} // namespace operators
} // namespace paddle_mobile
......
......@@ -16,6 +16,8 @@ limitations under the License. */
#pragma once
#include <climits>
#include <cmath>
#include "common/log.h"
#include "framework/tensor.h"
#include "pool_2x2.h"
......@@ -37,24 +39,42 @@ namespace math {
* in pool pooling, and finally takes the average.
* MaxPoolGrad and AvgPoolGrad are gradient operations respectively.
*/
template <class T>
template <typename T>
class MaxPool {
public:
inline T initial() { return static_cast<T>(-FLT_MAX); }
inline T initial() {
if (typeid(T) == typeid(int8_t)) {
return static_cast<T>(-SCHAR_MAX);
}
return static_cast<T>(-FLT_MAX);
}
inline void compute(const T &x, T *y) { *y = *y > x ? *y : x; }
inline void finalize(const T &pool_field, T *y) {}
};
template <class T>
template <typename Itype, typename Otype>
class AvgPool {
public:
inline T initial() { return static_cast<T>(0); }
inline void compute(const T &x, T *y) { *y += x; }
inline void finalize(const T &pool_field, T *y) { *y /= pool_field; }
inline Otype initial() { return static_cast<Otype>(0); }
inline void compute(const Itype &x, Otype *y) { *y += x; }
inline void finalize(const float &pool_field, Otype *y) {
if (typeid(Itype) == typeid(int8_t)) {
float tmp = *y / pool_field;
if (tmp > SCHAR_MAX) {
*y = SCHAR_MAX;
} else if (tmp < -SCHAR_MAX) {
*y = -SCHAR_MAX;
} else {
*y = static_cast<Otype>(std::round(tmp));
}
} else {
*y /= pool_field;
}
}
};
template <typename DeviceType, typename PoolProcess, typename T>
......
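The pooling traits above now split the input type from the accumulator/output type, so int8 windows are summed as int32 and only converted back at the end: finalize divides by the window size, rounds half away from zero with std::round, and clamps to ±127. For example, a 3x3 window summing to 23 averages to 23 / 9 ≈ 2.56 and is stored as 3, while a sum of 2000 saturates to 127. A stand-alone restatement of that finalize step (illustrative helper, not the class itself):

```cpp
#include <climits>
#include <cmath>
#include <cstdint>

// Hedged restatement of AvgPool<int8_t, int32_t>::finalize for one window.
int8_t finalize_int8_avg(int32_t sum, float pool_field) {
  float tmp = sum / pool_field;
  if (tmp > SCHAR_MAX) return SCHAR_MAX;
  if (tmp < -SCHAR_MAX) return -SCHAR_MAX;
  return static_cast<int8_t>(std::round(tmp));
}
```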
......@@ -58,6 +58,9 @@ namespace ops = paddle_mobile::operators;
#ifdef PADDLE_MOBILE_CPU
REGISTER_OPERATOR_CPU(mul, ops::MulOp);
#endif
#ifdef PADDLE_MOBILE_CL
REGISTER_OPERATOR_CL(mul, ops::MulOp);
#endif
#ifdef PADDLE_MOBILE_MALI_GPU
REGISTER_OPERATOR_MALI_GPU(mul, ops::MulOp);
#endif
......
......@@ -439,7 +439,7 @@ class ConvParam : public OpParam {
#endif
private:
protected:
RType *input_;
RType *output_;
RType *filter_;
......@@ -1632,6 +1632,10 @@ class FusionFcParam : public OpParam {
x_num_col_dims_ = GetAttr<int>("x_num_col_dims", attrs);
y_num_col_dims_ = GetAttr<int>("y_num_col_dims", attrs);
axis_ = GetAttr<int>("axis", attrs);
#ifdef FUSION_FC_INT8_OP
scale_ = InputScaleFrom<GType>(inputs, scope);
#endif
}
GType *InputX() const { return input_x_; }
......@@ -1655,8 +1659,16 @@ class FusionFcParam : public OpParam {
int x_num_col_dims_;
int y_num_col_dims_;
int axis_;
#ifdef PADDLE_MOBILE_FPGA
#ifdef FUSION_FC_INT8_OP
public:
const RType *InputScale() const { return scale_; }
private:
RType *scale_;
#endif
#ifdef PADDLE_MOBILE_FPGA
private:
fpga::SplitConvArgs fpga_conv_args;
......@@ -1707,7 +1719,19 @@ class FusionConvAddReluParam : public FusionConvAddParam<DeviceType> {
FusionConvAddReluParam(const VariableNameMap &inputs,
const VariableNameMap &outputs,
const AttributeMap &attrs, const Scope &scope)
: FusionConvAddParam<DeviceType>(inputs, outputs, attrs, scope) {}
: FusionConvAddParam<DeviceType>(inputs, outputs, attrs, scope) {
#ifdef FUSION_CONVADDRELU_INT8_OP
scale_ = OpParam::InputScaleFrom<GType>(inputs, scope);
#endif
}
#ifdef FUSION_CONVADDRELU_INT8_OP
typedef typename DtypeTensorTrait<DeviceType>::gtype GType;
typedef typename DtypeTensorTrait<DeviceType>::rtype RType;
const RType *InputScale() const { return scale_; }
private:
RType *scale_;
#endif
};
#endif
......
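Both parameter classes above expose the quantization scale the same way: when the corresponding *_INT8_OP macro is defined, the constructor reads an extra "Scale" input through OpParam::InputScaleFrom and the kernel retrieves it via InputScale(). The tests in this change create that input as a single-element float tensor. A hedged sketch of the kernel-side read (the surrounding kernel is illustrative; only the accessor comes from the classes above):

```cpp
// Inside an int8 kernel, given a FusionConvAddReluParam-style `param`:
float scale = param.InputScale()->data<float>()[0];  // single-element tensor
// ... pass `scale` on to the int8 conv/GEMM routine for re-quantization ...
```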
......@@ -269,8 +269,8 @@ if (NOT FOUND_MATCH)
#gen test
ADD_EXECUTABLE(test-pool operators/test_pool_op.cpp test_helper.h test_include.h executor_for_test.h)
target_link_libraries(test-pool paddle-mobile)
ADD_EXECUTABLE(test-pool-op operators/test_pool_op.cpp test_helper.h test_include.h executor_for_test.h)
target_link_libraries(test-pool-op paddle-mobile)
#gen test
ADD_EXECUTABLE(test-softmax operators/test_softmax_op.cpp test_helper.h test_include.h executor_for_test.h)
......@@ -324,6 +324,10 @@ if (NOT FOUND_MATCH)
ADD_EXECUTABLE(test-conv-add-relu-op operators/test_conv_add_relu_op.cpp test_helper.h test_include.h executor_for_test.h)
target_link_libraries(test-conv-add-relu-op paddle-mobile)
# gen test
ADD_EXECUTABLE(test-conv-add-relu-int8-op operators/test_fusion_conv_add_relu_int8_op.cpp test_helper.h test_include.h)
target_link_libraries(test-conv-add-relu-int8-op paddle-mobile)
# gen test
ADD_EXECUTABLE(test-conv-add-bn-relu-op operators/test_fusion_conv_add_bn_relu_op.cpp test_helper.h test_include.h executor_for_test.h)
target_link_libraries(test-conv-add-bn-relu-op paddle-mobile)
......
......@@ -25,7 +25,7 @@ limitations under the License. */
#define c(i, j) c[(i)*ldc + (j)]
#define c1(i, j) c1[(i)*ldc + (j)]
void print_matirx(int m, int n, int ldc, float *c) {
void print_matrix(int m, int n, int ldc, float *c) {
for (int i = 0; i < m; ++i) {
std::cout << c(i, 0);
for (int j = 1; j < n; ++j) {
......@@ -98,18 +98,20 @@ int do_sgemm(int m, int n, int k, bool relu, int t1, int t2, int pr) {
if (pr > 0) {
std::cout << "A:" << std::endl;
print_matirx(m, k, lda, a);
print_matrix(m, k, lda, a);
std::cout << "B:" << std::endl;
print_matirx(k, n, ldb, b);
print_matrix(k, n, ldb, b);
std::cout << "C:" << std::endl;
print_matirx(m, n, ldc, c);
print_matrix(m, n, ldc, c);
std::cout << "C1:" << std::endl;
print_matirx(m, n, ldc, c1);
print_matrix(m, n, ldc, c1);
}
std::cout << "mnk=" << m << " " << n << " " << k << " relu=" << relu
<< " eq=" << eq << " neq=" << neq << std::endl;
PADDLE_MOBILE_ENFORCE(neq == 0, "The execution of do_sgemm is failed!");
paddle_mobile::memory::Free(a);
paddle_mobile::memory::Free(b);
paddle_mobile::memory::Free(c);
......
......@@ -15,7 +15,9 @@ limitations under the License. */
#include <cstdlib>
#include <ctime>
#include <iostream>
#include <limits>
#include <random>
#include <type_traits>
#include "../test_helper.h"
#include "common/log.h"
#include "memory/t_malloc.h"
......@@ -32,26 +34,65 @@ limitations under the License. */
using std::default_random_engine;
using std::uniform_int_distribution;
void print_matirx(int m, int n, int ldc, int32_t *c) {
template <typename T>
void print_matrix(int m, int n, int ldc, T *c) {
for (int i = 0; i < m; ++i) {
std::cout << c(i, 0);
if (std::is_same<T, int8_t>::value) {
std::cout.setf(std::ios::left);
std::cout.width(4);
std::cout << static_cast<int32_t>(c(i, 0));
} else {
std::cout.setf(std::ios::left);
std::cout.width(6);
std::cout << c(i, 0);
}
for (int j = 1; j < n; ++j) {
std::cout << " | " << c(i, j);
if (std::is_same<T, int8_t>::value) {
std::cout << " | ";
std::cout.setf(std::ios::left);
std::cout.width(4);
std::cout << static_cast<int32_t>(c(i, j));
} else {
std::cout << " | ";
std::cout.setf(std::ios::left);
std::cout.width(6);
std::cout << c(i, j);
}
}
std::cout << std::endl;
std::cout << "\n";
}
std::cout << std::endl;
}
void print_matirx(int m, int n, int ldc, int8_t *c) {
for (int i = 0; i < m; ++i) {
std::cout << static_cast<int32_t>(c(i, 0));
for (int j = 1; j < n; ++j) {
std::cout << " | " << static_cast<int32_t>(c(i, j));
}
std::cout << std::endl;
}
std::cout << std::endl;
int32_t qadd_int32(int32_t l, int32_t r) {
int64_t res = static_cast<int64_t>(l) + static_cast<int64_t>(r);
if (res > std::numeric_limits<int32_t>::max())
return std::numeric_limits<int32_t>::max();
else if (res < std::numeric_limits<int32_t>::min())
return std::numeric_limits<int32_t>::min();
else
return static_cast<int32_t>(res);
}
// round to zero
float round2zero(float v) {
float res = 0.f;  // stays 0 when v == 0 (neither branch below runs)
if (v > 0)
res = std::floor(v);
else if (v < 0)
res = std::ceil(v);
return res;
}
int8_t qscale_int32(int32_t v, float scale) {
float res = static_cast<float>(v) * scale;
res = round2zero(res);
if (res > 127)
return static_cast<int8_t>(127);
else if (res < -127)
return static_cast<int8_t>(-127);
else
return static_cast<int8_t>(res);
}
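As a quick sanity check of the two helpers above (values chosen only for illustration): 20000 * 0.00628 ≈ 125.6 truncates toward zero to 125, 30000 * 0.00628 ≈ 188.4 saturates to 127, and qadd_int32 clamps instead of wrapping on overflow.

```cpp
#include <cassert>
#include <limits>

void check_quantize_helpers() {
  assert(qadd_int32(std::numeric_limits<int32_t>::max(), 1) ==
         std::numeric_limits<int32_t>::max());   // saturating add
  assert(qscale_int32(20000, 0.00628f) == 125);  // 125.6 -> round toward zero
  assert(qscale_int32(30000, 0.00628f) == 127);  // 188.4 -> saturate to +127
}
```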
int do_sgemm(int m, int n, int k, bool relu, int pr) {
......@@ -106,30 +147,152 @@ int do_sgemm(int m, int n, int k, bool relu, int pr) {
if (pr > 0) {
std::cout << "A:" << std::endl;
print_matirx(m, k, lda, a);
print_matrix(m, k, lda, a);
std::cout << "B:" << std::endl;
print_matrix(k, n, ldb, b);
std::cout << "C:" << std::endl;
print_matrix(m, n, ldc, c);
std::cout << "C1:" << std::endl;
print_matrix(m, n, ldc, c1);
}
std::cout << "mnk=" << m << " " << n << " " << k << " relu=" << relu
<< " eq=" << eq << " neq=" << neq << std::endl;
PADDLE_MOBILE_ENFORCE(neq == 0, "The execution of do_sgemm is failed!");
paddle_mobile::memory::Free(a);
paddle_mobile::memory::Free(b);
paddle_mobile::memory::Free(c);
paddle_mobile::memory::Free(c1);
return 0;
}
int do_sgemm_with_bias(int m, int n, int k, bool relu, int pr,
bool addOnRow = false) {
int lda = k;
int ldb = n;
int ldc = n;
float scale = 0.00628f;
default_random_engine e;
uniform_int_distribution<int8_t> pixel(-127, 127);
int8_t *a = static_cast<int8_t *>(
paddle_mobile::memory::Alloc(sizeof(int8_t) * m * k));
int8_t *b = static_cast<int8_t *>(
paddle_mobile::memory::Alloc(sizeof(int8_t) * k * n));
int8_t *c = static_cast<int8_t *>(
paddle_mobile::memory::Alloc(sizeof(int8_t) * m * n));
int8_t *c1 = static_cast<int8_t *>(
paddle_mobile::memory::Alloc(sizeof(int8_t) * m * n));
int32_t *bias = nullptr;
if (addOnRow) {
bias = static_cast<int32_t *>(
paddle_mobile::memory::Alloc(sizeof(int32_t) * n));
} else {
bias = static_cast<int32_t *>(
paddle_mobile::memory::Alloc(sizeof(int32_t) * m));
}
for (int i = 0; i < m * k; ++i) {
a[i] = pixel(e);
}
for (int i = 0; i < k * n; ++i) {
b[i] = pixel(e);
}
if (addOnRow) {
for (int i = 0; i < n; ++i) {
bias[i] = static_cast<int32_t>(pixel(e));
}
for (int i = 0; i < m; ++i) {
for (int j = 0; j < n; ++j) {
int32_t bias_v = bias[j];
int32_t r = 0;
for (int p = 0; p < k; p++) {
r += static_cast<int32_t>(a(i, p)) * static_cast<int32_t>(b(p, j));
}
r = qadd_int32(r, bias_v);
if (relu) r = std::max(0, r);
c1(i, j) = qscale_int32(r, scale);
}
}
} else {
for (int i = 0; i < m; ++i) {
bias[i] = static_cast<int32_t>(pixel(e));
}
for (int i = 0; i < m; ++i) {
int32_t bias_v = bias[i];
for (int j = 0; j < n; ++j) {
int32_t r = 0;
for (int p = 0; p < k; p++) {
r += static_cast<int32_t>(a(i, p)) * static_cast<int32_t>(b(p, j));
}
r = qadd_int32(r, bias_v);
if (relu) r = std::max(0, r);
c1(i, j) = qscale_int32(r, scale);
}
}
}
paddle_mobile::operators::math::Gemm gemm;
#ifdef _OPENMP
gemm.Sgemm_omp(m, n, k, scale, a, lda, b, ldb, static_cast<float>(0), c, ldc,
relu, bias, addOnRow);
#else
gemm.Sgemm(m, n, k, scale, a, lda, b, ldb, static_cast<float>(0), c, ldc,
relu, bias, addOnRow);
#endif
int eq = 0;
int neq = 0;
for (int i = 0; i < m * n; ++i) {
if (c[i] == c1[i]) {
++eq;
} else {
++neq;
}
}
if (pr > 0) {
std::cout << "A:" << std::endl;
print_matrix(m, k, lda, a);
std::cout << "B:" << std::endl;
print_matirx(k, n, ldb, b);
print_matrix(k, n, ldb, b);
std::cout << "Bias:" << std::endl;
if (addOnRow) {
print_matrix(1, n, n, bias);
} else {
print_matrix(m, 1, 1, bias);
}
std::cout << "C:" << std::endl;
print_matirx(m, n, ldc, c);
print_matrix(m, n, ldc, c);
std::cout << "C1:" << std::endl;
print_matirx(m, n, ldc, c1);
print_matrix(m, n, ldc, c1);
}
std::cout << "mnk=" << m << " " << n << " " << k << " relu=" << relu
<< " eq=" << eq << " neq=" << neq << std::endl;
PADDLE_MOBILE_ENFORCE(neq == 0,
"The execution of do_sgemm_with_bias is failed!");
paddle_mobile::memory::Free(a);
paddle_mobile::memory::Free(b);
paddle_mobile::memory::Free(c);
paddle_mobile::memory::Free(c1);
paddle_mobile::memory::Free(bias);
return 0;
}
int main() {
#ifdef _OPENMP
omp_set_num_threads(8);
omp_set_num_threads(4);
#endif
std::cout << "\n\n******************************************************\n\n"
<< std::endl;
std::cout << "Test gemm without bias:" << std::endl;
do_sgemm(9, 9, 9, false, 1);
do_sgemm(10, 6, 12, false, 0);
do_sgemm(512, 256, 384, false, 0);
......@@ -140,5 +303,44 @@ int main() {
do_sgemm(333, 797, 939, false, 0);
do_sgemm(1024, 1024, 1024, false, 0);
std::cout << "\n\n******************************************************\n\n"
<< std::endl;
std::cout << "Test gemm with bias(bias is added on column):" << std::endl;
do_sgemm_with_bias(9, 9, 9, false, 1);
do_sgemm_with_bias(10, 6, 12, false, 0);
do_sgemm_with_bias(512, 256, 384, false, 0);
do_sgemm_with_bias(1366, 768, 256, false, 0);
do_sgemm_with_bias(1255, 755, 333, false, 0);
do_sgemm_with_bias(599, 1133, 393, false, 0);
do_sgemm_with_bias(777, 555, 999, false, 0);
do_sgemm_with_bias(333, 797, 939, false, 0);
do_sgemm_with_bias(1024, 1024, 1024, false, 0);
std::cout << "\n\n******************************************************\n\n"
<< std::endl;
std::cout << "Test gemm with bias(bias is added on row):" << std::endl;
do_sgemm_with_bias(9, 9, 9, false, 1, true);
do_sgemm_with_bias(10, 6, 12, false, 0, true);
do_sgemm_with_bias(512, 256, 384, false, 0, true);
do_sgemm_with_bias(1366, 768, 256, false, 0, true);
do_sgemm_with_bias(1255, 755, 333, false, 0, true);
do_sgemm_with_bias(599, 1133, 393, false, 0, true);
do_sgemm_with_bias(777, 555, 999, false, 0, true);
do_sgemm_with_bias(333, 797, 939, false, 0, true);
do_sgemm_with_bias(1024, 1024, 1024, false, 0, true);
std::cout << "\n\n******************************************************\n\n"
<< std::endl;
std::cout << "Test gemm with relu and bias:" << std::endl;
do_sgemm_with_bias(9, 9, 9, true, 1);
do_sgemm_with_bias(10, 6, 12, true, 0);
do_sgemm_with_bias(512, 256, 384, true, 0);
do_sgemm_with_bias(1366, 768, 256, true, 0);
do_sgemm_with_bias(1255, 755, 333, true, 0);
do_sgemm_with_bias(599, 1133, 393, true, 0);
do_sgemm_with_bias(777, 555, 999, true, 0);
do_sgemm_with_bias(333, 797, 939, true, 0);
do_sgemm_with_bias(1024, 1024, 1024, true, 0);
return 0;
}
......@@ -28,7 +28,7 @@ limitations under the License. */
int main() {
paddle_mobile::PaddleMobile<paddle_mobile::CPU> paddle_mobile;
paddle_mobile.SetThreadNum(8);
paddle_mobile.SetThreadNum(4);
Tensor aa, bb, cc;
auto aaptr = aa.mutable_data<float>({m, k});
auto bbptr = bb.mutable_data<float>({k, n});
......@@ -44,10 +44,13 @@ int main() {
ccptr[i] = 2;
}
Tensor aa_int8, bb_int8, cc_int8;
Tensor aa_int8, bb_int8, cc_int32, cc_int8;
auto aaptr_int8 = aa_int8.mutable_data<int8_t>({m, k});
auto bbptr_int8 = bb_int8.mutable_data<int8_t>({k, n});
auto ccptr_int8 = cc_int8.mutable_data<int32_t>({m, n});
auto ccptr_int32 = cc_int32.mutable_data<int32_t>({m, n});
auto ccptr_int8 = cc_int8.mutable_data<int8_t>({m, n});
int32_t* bias_data_col = new int32_t[m];
int32_t* bias_data_row = new int32_t[n];
for (int i = 0; i < m * k; ++i) {
aaptr_int8[i] = static_cast<int8_t>(2);
......@@ -56,7 +59,15 @@ int main() {
bbptr_int8[i] = static_cast<int8_t>(2);
}
for (int i = 0; i < m * n; ++i) {
ccptr_int8[i] = static_cast<int32_t>(2);
ccptr_int32[i] = static_cast<int32_t>(2);
}
for (int i = 0; i < m; ++i) {
bias_data_col[i] = 2;
}
for (int i = 0; i < n; ++i) {
bias_data_row[i] = 2;
}
// float
......@@ -67,31 +78,87 @@ int main() {
false, nullptr);
}
auto time1 = time();
auto time_start0 = time();
for (int j = 0; j < 10; ++j) {
paddle_mobile::operators::math::matmul<float>(
aa, false, bb, false, static_cast<float>(1), &cc, static_cast<float>(0),
false, nullptr);
}
auto time2 = time();
std::cout << "float gemm cost :" << time_diff(time1, time2) / 10 << "ms\n";
auto time_end0 = time();
std::cout << "float gemm cost :" << time_diff(time_start0, time_end0) / 10
<< "ms\n";
// int8_t
// int8_t without bias
// warm-up 10 times
for (int j = 0; j < 10; ++j) {
paddle_mobile::operators::math::matmul<int8_t>(
aa_int8, false, bb_int8, false, static_cast<int8_t>(1), &cc_int8,
static_cast<int8_t>(0), false, nullptr);
paddle_mobile::operators::math::matmul<float, int32_t>(
aa_int8, false, bb_int8, false, static_cast<float>(1), &cc_int32,
static_cast<float>(0));
}
auto time_start1 = time();
for (int j = 0; j < 10; ++j) {
paddle_mobile::operators::math::matmul<float, int32_t>(
aa_int8, false, bb_int8, false, static_cast<float>(1), &cc_int32,
static_cast<float>(0));
}
auto time_end1 = time();
std::cout << "int8_t gemm cost :" << time_diff(time_start1, time_end1) / 10
<< "ms\n";
auto time3 = time();
// int8_t with bias, column element wise add
// warm-up 10 times
for (int j = 0; j < 10; ++j) {
paddle_mobile::operators::math::matmul(
aa_int8, false, bb_int8, false, static_cast<float>(0.618), &cc_int8,
static_cast<float>(0), false, bias_data_col, false);
}
auto time_start2 = time();
for (int j = 0; j < 10; ++j) {
paddle_mobile::operators::math::matmul<int8_t>(
aa_int8, false, bb_int8, false, static_cast<int8_t>(1), &cc_int8,
static_cast<int8_t>(0), false, nullptr);
paddle_mobile::operators::math::matmul(
aa_int8, false, bb_int8, false, static_cast<float>(0.618), &cc_int8,
static_cast<float>(0), false, bias_data_col, false);
}
auto time4 = time();
std::cout << "int8_t gemm cost :" << time_diff(time3, time4) / 10 << "ms\n";
auto time_end2 = time();
std::cout << "int8_t gemm_with_bias(column add) cost :"
<< time_diff(time_start2, time_end2) / 10 << "ms\n";
// int8_t with bias, row element wise add
// warm-up 10 times
for (int j = 0; j < 10; ++j) {
paddle_mobile::operators::math::matmul(
aa_int8, false, bb_int8, false, static_cast<float>(0.618), &cc_int8,
static_cast<float>(0), false, bias_data_row, true);
}
auto time_start3 = time();
for (int j = 0; j < 10; ++j) {
paddle_mobile::operators::math::matmul(
aa_int8, false, bb_int8, false, static_cast<float>(0.618), &cc_int8,
static_cast<float>(0), false, bias_data_row, true);
}
auto time_end3 = time();
std::cout << "int8_t gemm_with_bias(row add) cost :"
<< time_diff(time_start3, time_end3) / 10 << "ms\n";
// int8_t with bias&relu
// warm-up 10 times
for (int j = 0; j < 10; ++j) {
paddle_mobile::operators::math::matmul(
aa_int8, false, bb_int8, false, static_cast<float>(0.618), &cc_int8,
static_cast<float>(0), true, bias_data_col, false);
}
auto time_start4 = time();
for (int j = 0; j < 10; ++j) {
paddle_mobile::operators::math::matmul(
aa_int8, false, bb_int8, false, static_cast<float>(0.618), &cc_int8,
static_cast<float>(0), true, bias_data_col, false);
}
auto time_end4 = time();
std::cout << "int8_t gemm_with_bias_relu cost :"
<< time_diff(time_start4, time_end4) / 10 << "ms\n";
delete[] bias_data_row;
delete[] bias_data_col;
return 0;
}
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <iostream>
#ifdef FUSION_CONVADDRELU_INT8_OP
#include <limits>
#include "../test_helper.h"
#include "../test_include.h"
#include "operators/fusion_conv_add_relu_int8_op.h"
namespace paddle_mobile {
int32_t qadd_int32(int32_t l, int32_t r) {
int64_t res = static_cast<int64_t>(l) + static_cast<int64_t>(r);
if (res > std::numeric_limits<int32_t>::max())
return std::numeric_limits<int32_t>::max();
else if (res < std::numeric_limits<int32_t>::min())
return std::numeric_limits<int32_t>::min();
else
return static_cast<int32_t>(res);
}
// round to zero
float round2zero(float v) {
float res = 0.f;  // stays 0 when v == 0 (neither branch below runs)
if (v > 0)
res = std::floor(v);
else if (v < 0)
res = std::ceil(v);
return res;
}
int8_t qscale_int32(int32_t v, float scale) {
float res = static_cast<float>(v) * scale;
res = round2zero(res);
if (res > 127)
return static_cast<int8_t>(127);
else if (res < -127)
return static_cast<int8_t>(-127);
else
return static_cast<int8_t>(res);
}
// Reference convolution from Caffe for checking results.
// accumulate through explicit loops over input, output, and filters.
template <typename T>
void conv2d(const framework::Tensor *input, const framework::Tensor *filter,
const framework::Tensor *bias, const framework::AttributeMap &attrs,
framework::Tensor *output, float scale) {
framework::AttrReader attr_reader(attrs);
std::vector<int> paddings = attr_reader.Get<std::vector<int>>("paddings");
std::vector<int> strides = attr_reader.Get<std::vector<int>>("strides");
std::vector<int> dilations = attr_reader.Get<std::vector<int>>("dilations");
int groups = attr_reader.Get<int>("groups");
int kernel_h = filter->dims()[2];
int kernel_w = filter->dims()[3];
int pad_h = paddings[0];
int pad_w = paddings[1];
int stride_h = strides[0];
int stride_w = strides[1];
int dilation_h = dilations[0];
int dilation_w = dilations[1];
auto in_shape = input->dims();
auto out_shape = output->dims();
const bool has_depth = 0;
int kernel_d, pad_d, stride_d, dilation_d;
if (has_depth) {
kernel_d = kernel_h;
stride_d = stride_h;
pad_d = pad_h;
dilation_d = dilation_h;
} else {
kernel_d = stride_d = dilation_d = 1;
pad_d = 0;
}
// Groups
int o_g = out_shape[1] / groups;
int k_g = in_shape[1] / groups;
int o_head, k_head;
// Convolution
vector<int> weight_offset(4 + has_depth);
vector<int> in_offset(4 + has_depth);
vector<int> out_offset(4 + has_depth);
auto offset = [](const framework::Tensor *input, const vector<int> &indics) {
framework::DDim shape = input->dims();
size_t count = 0;
for (int i = 0; i < indics.size(); ++i) {
count *= shape[i];
count += indics[i];
}
return count;
};
const T *in_data = input->data<T>();
const T *w_data = filter->data<T>();
framework::Tensor output_32;
int32_t *out_data_32 = output_32.mutable_data<int32_t>(out_shape);
memset(out_data_32, 0, output_32.numel() * sizeof(int32_t));
for (int n = 0; n < out_shape[0]; n++) {
for (int g = 0; g < groups; g++) {
o_head = o_g * g;
k_head = k_g * g;
for (int o = 0; o < o_g; o++) {
for (int k = 0; k < k_g; k++) {
for (int z = 0; z < (has_depth ? out_shape[2] : 1); z++) {
for (int y = 0; y < out_shape[2 + has_depth]; y++) {
for (int x = 0; x < out_shape[3 + has_depth]; x++) {
for (int r = 0; r < kernel_d; r++) {
for (int p = 0; p < kernel_h; p++) {
for (int q = 0; q < kernel_w; q++) {
int in_z = z * stride_d - pad_d + r * dilation_d;
int in_y = y * stride_h - pad_h + p * dilation_h;
int in_x = x * stride_w - pad_w + q * dilation_w;
if (in_z >= 0 && in_z < (has_depth ? in_shape[2] : 1) &&
in_y >= 0 && in_y < in_shape[2 + has_depth] &&
in_x >= 0 && in_x < in_shape[3 + has_depth]) {
weight_offset[0] = o + o_head;
weight_offset[1] = k;
if (has_depth) {
weight_offset[2] = r;
}
weight_offset[2 + has_depth] = p;
weight_offset[3 + has_depth] = q;
in_offset[0] = n;
in_offset[1] = k + k_head;
if (has_depth) {
in_offset[2] = in_z;
}
in_offset[2 + has_depth] = in_y;
in_offset[3 + has_depth] = in_x;
out_offset[0] = n;
out_offset[1] = o + o_head;
if (has_depth) {
out_offset[2] = z;
}
out_offset[2 + has_depth] = y;
out_offset[3 + has_depth] = x;
out_data_32[offset(output, out_offset)] +=
in_data[offset(input, in_offset)] *
w_data[offset(filter, weight_offset)];
}
}
}
}
}
}
}
}
}
}
}
T *out_data = output->mutable_data<T>();
int32_t n = out_shape[0];
int32_t c = out_shape[1];
int32_t h = out_shape[2];
int32_t w = out_shape[3];
const int32_t *bias_data = bias->data<int32_t>();
for (int i = 0; i < n; ++i) {
for (int j = 0; j < c; ++j) {
int32_t bias_v = bias_data[j];
for (int k = 0; k < h; ++k) {
for (int l = 0; l < w; ++l) {
int32_t tmp = out_data_32[i * c * h * w + j * h * w + k * w + l];
tmp = qadd_int32(tmp, bias_v);
tmp = std::max(0, tmp);
out_data[i * c * h * w + j * h * w + k * w + l] =
qscale_int32(tmp, scale);
}
}
}
}
}
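The tail of conv2d above applies, per output element, a saturating bias add, the fused relu, and a re-quantization to int8 with the given scale. Restated as a small helper (illustrative only; the real code keeps it inline):

```cpp
// Per-element post-processing of the reference convolution above.
int8_t reference_postprocess(int32_t acc, int32_t bias_v, float scale) {
  int32_t tmp = qadd_int32(acc, bias_v);  // saturating bias add
  tmp = std::max(0, tmp);                 // fused relu
  return qscale_int32(tmp, scale);        // round toward zero + clamp to ±127
}
```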
template <typename T, int Kernel, int Pad, int Stride>
int TestConvOp(int in_channels, int in_height, int in_width, int out_channels) {
int kernel_h = Kernel;
int kernel_w = Kernel;
int pad_h = Pad;
int pad_w = Pad;
int stride_h = Stride;
int stride_w = Stride;
int dilation_h = 1;
int dilation_w = 1;
int batch_size = 1;
int input_c = in_channels;
int input_h = in_height;
int input_w = in_width;
int output_c = out_channels;
framework::DDim input_shape =
framework::make_ddim({batch_size, input_c, input_h, input_w});
framework::DDim filter_shape =
framework::make_ddim({output_c, input_c, kernel_h, kernel_w});
int kernel_extent_h = dilation_h * (kernel_h - 1) + 1;
int kernel_extent_w = dilation_w * (kernel_w - 1) + 1;
int output_h = (input_h + 2 * pad_h - kernel_extent_h) / stride_h + 1;
int output_w = (input_w + 2 * pad_w - kernel_extent_w) / stride_w + 1;
framework::DDim output_shape = framework::make_ddim(
std::vector<int>({batch_size, output_c, output_h, output_w}));
framework::DDim bias_shape = framework::make_ddim({output_c});
VariableNameMap inputs;
VariableNameMap outputs;
auto scope = std::make_shared<framework::Scope>();
inputs["Input"] = std::vector<std::string>({"input"});
inputs["Filter"] = std::vector<std::string>({"filter"});
inputs["Scale"] = std::vector<std::string>({"scale"});
inputs["Y"] = std::vector<std::string>({"bias"});
outputs["Out"] = std::vector<std::string>({"output"});
auto input_var = scope.get()->Var("input");
auto input = input_var->template GetMutable<framework::LoDTensor>();
SetupTensor<T>(input, input_shape, -127, 127);
auto filter_var = scope.get()->Var("filter");
auto filter = filter_var->template GetMutable<framework::LoDTensor>();
SetupTensor<T>(filter, filter_shape, -127, 127);
auto scale_var = scope.get()->Var("scale");
auto scale = scale_var->template GetMutable<framework::LoDTensor>();
scale->Resize(framework::make_ddim({1}));
float scale_v = 0.000828f;
scale->mutable_data<float>()[0] = scale_v;
auto bias_var = scope.get()->Var("bias");
auto bias = bias_var->template GetMutable<framework::LoDTensor>();
SetupTensor<int32_t>(bias, bias_shape, -127, 127);
auto output_var = scope.get()->Var("output");
framework::AttributeMap attrs;
attrs["strides"].Set<vector<int>>(std::vector<int>({stride_h, stride_w}));
attrs["paddings"].Set<vector<int>>(std::vector<int>({pad_h, pad_w}));
attrs["dilations"].Set<vector<int>>(
std::vector<int>({dilation_h, dilation_w}));
attrs["groups"].Set<int>(1);
attrs["axis"].Set<int>(0);
auto *op = new operators::FusionConvAddReluInt8Op<CPU, T>(
"fusion_conv_add_relu_int8", inputs, outputs, attrs, scope);
op->InferShape();
op->Init();
op->Run();
framework::Tensor output_cmp;
output_cmp.mutable_data<T>(output_shape);
conv2d<T>(input, filter, bias, attrs, &output_cmp, scale_v);
// compare results
int eq = 0;
int neq = 0;
auto output = output_var->template Get<framework::LoDTensor>();
const T *output_data = output->data<T>();
T *output_cmp_data = output_cmp.data<T>();
for (int i = 0; i < output->numel(); ++i) {
PADDLE_MOBILE_ENFORCE(
output_data[i] == output_cmp_data[i],
"The execution of test_fusion_conv_add_relu_int8_op is failed!");
if (output_data[i] == output_cmp_data[i]) {
++eq;
} else {
++neq;
}
}
std::cout << "eq = " << eq << ", neq = " << neq << std::endl;
delete op;
return 0;
}
} // namespace paddle_mobile
int main(int argc, char *argv[]) {
if (argc < 5) {
LOG(paddle_mobile::kLOG_INFO)
<< "Usage:\n"
<< " ./test-conv-add-relu-int8-op in_channels in_height in_width "
"out_channels\n"
<< " params:\n"
<< " -in_channels: int, input image's channels\n"
<< " -in_height: int, input image's height\n"
<< " -in_width: int, input image's width\n"
<< " -out_channels: int, conv output channels\n";
return 1;
}
int in_channels = atoi(argv[1]);
int in_height = atoi(argv[2]);
int in_width = atoi(argv[3]);
int out_channels = atoi(argv[4]);
// kernel = 3, pad = 1, stride = 1
LOG(paddle_mobile::kLOG_INFO) << "int8_t, kernel=3, pad=1, stride=1";
paddle_mobile::TestConvOp<int8_t, 3, 1, 1>(in_channels, in_height, in_width,
out_channels);
// kernel = 7, pad = 0, stride = 2
LOG(paddle_mobile::kLOG_INFO) << "int8, kernel=7, pad=0, stride=2";
paddle_mobile::TestConvOp<int8_t, 7, 0, 2>(in_channels, in_height, in_width,
out_channels);
// kernel = 7, pad = 1, stride = 2
LOG(paddle_mobile::kLOG_INFO) << "int8, kernel=7, pad=1, stride=2";
paddle_mobile::TestConvOp<int8_t, 7, 1, 2>(in_channels, in_height, in_width,
out_channels);
// kernel = 7, pad = 3, stride = 2
LOG(paddle_mobile::kLOG_INFO) << "int8, kernel=7, pad=3, stride=2";
paddle_mobile::TestConvOp<int8_t, 7, 3, 2>(in_channels, in_height, in_width,
out_channels);
// kernel = 7, pad = 0, stride = 1
LOG(paddle_mobile::kLOG_INFO) << "int8, kernel=7, pad=0, stride=1";
paddle_mobile::TestConvOp<int8_t, 7, 0, 1>(in_channels, in_height, in_width,
out_channels);
// kernel = 7, pad = 1, stride = 1
LOG(paddle_mobile::kLOG_INFO) << "int8, kernel=7, pad=1, stride=1";
paddle_mobile::TestConvOp<int8_t, 7, 1, 1>(in_channels, in_height, in_width,
out_channels);
// kernel = 7, pad = 3, stride = 1
LOG(paddle_mobile::kLOG_INFO) << "int8, kernel=7, pad=3, stride=1";
paddle_mobile::TestConvOp<int8_t, 7, 3, 1>(in_channels, in_height, in_width,
out_channels);
// kernel = 7, pad = 5, stride = 3
LOG(paddle_mobile::kLOG_INFO) << "int8, kernel=7, pad=5, stride=3";
paddle_mobile::TestConvOp<int8_t, 7, 5, 3>(in_channels, in_height, in_width,
out_channels);
// kernel = 7, pad = 3, stride = 4
LOG(paddle_mobile::kLOG_INFO) << "int8, kernel=7, pad=3, stride=4";
paddle_mobile::TestConvOp<int8_t, 7, 3, 4>(in_channels, in_height, in_width,
out_channels);
// kernel = 3, pad = 0, stride = 1
LOG(paddle_mobile::kLOG_INFO) << "int8, kernel=3, pad=0, stride=1";
paddle_mobile::TestConvOp<int8_t, 3, 0, 1>(in_channels, in_height, in_width,
out_channels);
// kernel = 3, pad = 1, stride = 1
LOG(paddle_mobile::kLOG_INFO) << "int8, kernel=3, pad=1, stride=1";
paddle_mobile::TestConvOp<int8_t, 3, 1, 1>(in_channels, in_height, in_width,
out_channels);
// kernel = 5, pad = 0, stride = 1
LOG(paddle_mobile::kLOG_INFO) << "int8, kernel=5, pad=0, stride=1";
paddle_mobile::TestConvOp<int8_t, 5, 0, 1>(in_channels, in_height, in_width,
out_channels);
// kernel = 5, pad = 2, stride = 1
LOG(paddle_mobile::kLOG_INFO) << "int8, kernel=5, pad=2, stride=1";
paddle_mobile::TestConvOp<int8_t, 5, 2, 1>(in_channels, in_height, in_width,
out_channels);
}
#else
int main() {
std::cout << "FUSION_CONVADDRELU_INT8_OP is not defined!" << std::endl;
return 0;
}
#endif
......@@ -12,147 +12,163 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <framework/program/program-optimize/program_optimize.h>
#include <iostream>
#include <type_traits>
#include "../test_helper.h"
#include "../test_include.h"
#include "framework/operator.h"
#include "operators/fusion_fc_int8_op.h"
#include "operators/fusion_fc_op.h"
#define a(i, j) a[(i)*lda + (j)]
#define b(i, j) b[(i)*ldb + (j)]
#define c(i, j) c[(i)*ldc + (j)]
namespace paddle_mobile {
namespace framework {
using framework::AttributeMap;
using framework::DDim;
using framework::Scope;
using framework::make_ddim;
int32_t qadd_int32(int32_t l, int32_t r) {
int64_t res = static_cast<int64_t>(l) + static_cast<int64_t>(r);
if (res > std::numeric_limits<int32_t>::max())
return std::numeric_limits<int32_t>::max();
else if (res < std::numeric_limits<int32_t>::min())
return std::numeric_limits<int32_t>::min();
else
return static_cast<int32_t>(res);
}
template <typename Dtype>
class TestFcOp {
public:
explicit TestFcOp(const Program<Dtype> p) : program_(p) {
use_optimize_ = true;
if (use_optimize_) {
to_predict_program_ = program_.optimizeProgram;
} else {
to_predict_program_ = program_.originProgram;
}
// round to zero
float round2zero(float v) {
float res = 0.f;  // stays 0 when v == 0 (neither branch below runs)
if (v > 0)
res = std::floor(v);
else if (v < 0)
res = std::ceil(v);
return res;
}
const std::vector<std::shared_ptr<BlockDesc>> blocks =
to_predict_program_->Blocks();
// DLOG << " **block size " << blocks.size();
for (int i = 0; i < blocks.size(); ++i) {
std::shared_ptr<BlockDesc> block_desc = blocks[i];
std::vector<std::shared_ptr<OpDesc>> ops = block_desc->Ops();
// DLOG << " ops " << ops.size();
for (int j = 0; j < ops.size(); ++j) {
std::shared_ptr<OpDesc> op = ops[j];
if (op->Type() == "fc" && op->Input("X")[0] == "pool2d_13.tmp_0") {
DLOG << " fc attr size: " << op->GetAttrMap().size();
DLOG << " inputs size: " << op->GetInputs().size();
DLOG << " outputs size: " << op->GetOutputs().size();
DLOG << " Input X is : " << op->Input("X")[0];
DLOG << " Input Y is : " << op->Input("Y")[0];
DLOG << " Input Y is : " << op->Input("Z")[0];
DLOG << " Output Out is : " << op->Output("Out")[0];
std::shared_ptr<operators::FusionFcOp<Dtype, float>> testOp =
std::make_shared<operators::FusionFcOp<Dtype, float>>(
op->Type(), op->GetInputs(), op->GetOutputs(),
op->GetAttrMap(), program_.scope);
ops_of_block_[*block_desc.get()].push_back(testOp);
int8_t qscale_int32(int32_t v, float scale) {
float res = static_cast<float>(v) * scale;
res = round2zero(res);
if (res > 127)
return static_cast<int8_t>(127);
else if (res < -127)
return static_cast<int8_t>(-127);
else
return static_cast<int8_t>(res);
}
template <typename T, typename S>
int TestFcOP() {
int32_t m = 377;
int32_t n = 1363;
int32_t k = 577;
int32_t lda = k;
int32_t ldb = n;
int32_t ldc = n;
DDim inputA_shape = make_ddim({m, k});
DDim inputB_shape = make_ddim({k, n});
DDim bias_shape = make_ddim({n});
VariableNameMap inputs;
VariableNameMap outputs;
auto scope = std::make_shared<Scope>();
inputs["X"] = std::vector<std::string>({"inputA"});
inputs["Y"] = std::vector<std::string>({"inputB"});
inputs["Z"] = std::vector<std::string>({"bias"});
inputs["Scale"] = std::vector<std::string>({"scale"});
outputs["Out"] = std::vector<std::string>({"output"});
auto inputA_var = scope.get()->Var("inputA");
auto inputA = inputA_var->template GetMutable<framework::LoDTensor>();
SetupTensor<T>(inputA, inputA_shape, -127, 127);
auto inputB_var = scope.get()->Var("inputB");
auto inputB = inputB_var->template GetMutable<framework::LoDTensor>();
SetupTensor<T>(inputB, inputB_shape, -127, 127);
auto bias_var = scope.get()->Var("bias");
auto bias = bias_var->template GetMutable<framework::LoDTensor>();
SetupTensor<S>(bias, bias_shape, -127, 127);
auto scale_var = scope.get()->Var("scale");
auto scale = scale_var->template GetMutable<framework::LoDTensor>();
scale->Resize(framework::make_ddim({1}));
float scale_v = 0.000828f;
scale->mutable_data<float>()[0] = scale_v;
auto output_var = scope.get()->Var("output");
AttributeMap attrs;
attrs["x_num_col_dims"].Set<int>(1);
attrs["y_num_col_dims"].Set<int>(1);
attrs["axis"].Set<int>(1);
operators::OperatorBase<CPU> *op = nullptr;
#ifdef FUSION_FC_INT8_OP
if (std::is_same<T, int8_t>::value) {
op = new operators::FusionFcInt8Op<CPU, T>("fusion_fc_int8", inputs,
outputs, attrs, scope);
} else {
op = new operators::FusionFcOp<CPU, T>("fusion_fc", inputs, outputs, attrs,
scope);
}
#else
op = new operators::FusionFcOp<CPU, T>("fusion_fc", inputs, outputs, attrs,
scope);
#endif
op->InferShape();
op->Run();
auto output = output_var->template Get<framework::LoDTensor>();
const T *output_data = output->data<T>();
// compare
T *c = static_cast<T *>(memory::Alloc(sizeof(T) * m * n));
T *a = inputA->data<T>();
T *b = inputB->data<T>();
S *bias_data = bias->data<S>();
for (int32_t i = 0; i < m; ++i) {
for (int32_t j = 0; j < n; ++j) {
S bias_v = bias_data[j];
if (std::is_same<T, int8_t>::value) {
int32_t r = 0;
for (int32_t p = 0; p < k; p++) {
r += static_cast<int32_t>(a(i, p)) * static_cast<int32_t>(b(p, j));
}
r = qadd_int32(r, bias_v);
c(i, j) = qscale_int32(r, scale_v);
} else {
T r = 0;
for (int32_t p = 0; p < k; p++) {
r += a(i, p) * b(p, j);
}
r += bias_v;
c(i, j) = r;
}
}
}
std::shared_ptr<Tensor> predict(const Tensor &t1, const Tensor &t2,
const Tensor &t3) {
// feed
auto scope = program_.scope;
Variable *x_feed_value = scope->Var("pool2d_13.tmp_0");
auto tensor_x = x_feed_value->GetMutable<LoDTensor>();
tensor_x->ShareDataWith(t1);
Variable *y_feed_value = scope->Var("loss3_classifier-loc_weights");
auto tensor_y = y_feed_value->GetMutable<LoDTensor>();
tensor_y->ShareDataWith(t2);
Variable *z_feed_value = scope->Var("loss3_classifier-loc_biases");
auto tensor_z = z_feed_value->GetMutable<LoDTensor>();
tensor_z->ShareDataWith(t3);
Variable *con_output = scope->Var("loss3_classifier-loc.tmp_1");
auto *output_tensor = con_output->GetMutable<LoDTensor>();
output_tensor->mutable_data<float>({3, 10});
// DLOG << typeid(output_tensor).name();
// DLOG << "output_tensor dims: " << output_tensor->dims();
std::shared_ptr<LoDTensor> out_tensor = std::make_shared<LoDTensor>();
out_tensor.reset(output_tensor);
predict(t1, t2, t3, 0);
return out_tensor;
}
private:
const framework::Program<Dtype> program_;
std::shared_ptr<ProgramDesc> to_predict_program_;
std::map<framework::BlockDesc,
std::vector<std::shared_ptr<OperatorBase<Dtype>>>>
ops_of_block_;
bool use_optimize_ = false;
void predict(const Tensor &t1, const Tensor &t2, const Tensor &t3,
int block_id) {
std::shared_ptr<BlockDesc> to_predict_block =
to_predict_program_->Block(block_id);
for (int j = 0; j < ops_of_block_[*to_predict_block.get()].size(); ++j) {
auto op = ops_of_block_[*to_predict_block.get()][j];
DLOG << "op -> run()";
op->Run();
int32_t eq = 0;
int32_t neq = 0;
for (int32_t i = 0; i < m * n; ++i) {
PADDLE_MOBILE_ENFORCE(output_data[i] == c[i],
"The execution of test_fusion_fc_op is failed!");
if (output_data[i] == c[i]) {
++eq;
} else {
++neq;
}
}
};
template class TestFcOp<CPU>;
} // namespace framework
std::cout << "mnk=" << m << " " << n << " " << k << " eq=" << eq
<< " neq=" << neq << std::endl;
delete op;
return 0;
}
} // namespace paddle_mobile
int main() {
DLOG << "----------**********----------";
DLOG << "begin to run Fc Test";
paddle_mobile::framework::Loader<paddle_mobile::CPU> loader;
// "../../../test/models/googlenet"
auto program = loader.Load(g_googlenet);
paddle_mobile::framework::ProgramOptimize optimize;
// program.originProgram->Description("origin");
auto optimize_program = optimize.FusionOptimize(program.originProgram);
program.optimizeProgram = optimize_program;
if (optimize_program != nullptr) {
optimize_program->Description("optimize");
} else {
LOG(paddle_mobile::kLOG_ERROR) << "optimize_program is null";
}
/// input x (1,3,224,224)
paddle_mobile::framework::LoDTensor inputx;
SetupTensor<float>(&inputx, {3, 64, 1, 1}, static_cast<float>(1),
static_cast<float>(1));
auto *inputx_ptr = inputx.data<float>();
/// input y (224,)
paddle_mobile::framework::LoDTensor inputy;
SetupTensor<float>(&inputy, {64, 10}, static_cast<float>(1.5),
static_cast<float>(1.5));
auto *inputy_ptr = inputy.data<float>();
paddle_mobile::framework::LoDTensor inputz;
SetupTensor<float>(&inputz, {10}, static_cast<float>(0),
static_cast<float>(1));
auto *inputz_ptr = inputz.data<float>();
paddle_mobile::framework::TestFcOp<paddle_mobile::CPU> testFcOp(program);
auto output = testFcOp.predict(inputx, inputy, inputz);
auto *output_ptr = output->data<float>();
for (int j = 0; j < output->numel(); ++j) {
DLOG << "value of output: " << output_ptr[j];
}
DLOG << "1 (3,64) * 2 (64,10) = 96(3,10)";
DLOG << "output : 96(3,10) + bias(10)";
int main() {
paddle_mobile::PaddleMobile<paddle_mobile::CPU> paddle_mobile;
paddle_mobile.SetThreadNum(4);
#ifdef FUSION_FC_INT8_OP
paddle_mobile::TestFcOP<int8_t, int32_t>();
#endif
paddle_mobile::TestFcOP<float, float>();
return 0;
}
......@@ -12,6 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <iostream>
#include "../test_helper.h"
#include "../test_include.h"
#include "operators/mul_op.h"
......@@ -79,14 +80,14 @@ int TestMulOP() {
PADDLE_MOBILE_ENFORCE(
output_data[i] == c[i], "output[%d] = %d, output_cmp[%d] = %d", i,
static_cast<int32_t>(output_data[i]), i, static_cast<int32_t>(c[i]));
if (static_cast<int>(output_data[i] == c[i])) {
if (output_data[i] == c[i]) {
++eq;
} else {
++neq;
}
}
DLOG << "mnk=" << m << " " << n << " " << k << " eq=" << eq
<< " neq=" << neq;
std::cout << "mnk=" << m << " " << n << " " << k << " eq=" << eq
<< " neq=" << neq << std::endl;
delete op;
return 0;
}
......@@ -94,7 +95,7 @@ int TestMulOP() {
int main() {
paddle_mobile::PaddleMobile<paddle_mobile::CPU> paddle_mobile;
paddle_mobile.SetThreadNum(8);
paddle_mobile.SetThreadNum(4);
paddle_mobile::TestMulOP<int8_t, int32_t>();
paddle_mobile::TestMulOP<float, float>();
return 0;
......
......@@ -12,30 +12,281 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <iostream>
#include "../test_include.h"
#include "operators/kernel/central-arm-func/pool_arm_func.h"
#include "operators/pool_op.h"
int main() {
paddle_mobile::framework::Loader<paddle_mobile::CPU> loader;
auto program = loader.Load(std::string(g_googlenet));
if (program.originProgram == nullptr) {
DLOG << "program read file";
namespace paddle_mobile {
static int PoolOutputSize(int input_size, int filter_size, int padding,
int stride, bool ceil_mode) {
int output_size;
if (!ceil_mode) {
output_size = (input_size - filter_size + 2 * padding) / stride + 1;
} else {
output_size =
(input_size - filter_size + 2 * padding + stride - 1) / stride + 1;
}
return output_size;
}
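For example, a 112-wide input with a 3x3 window, no padding and stride 2 yields 55 outputs with ceil_mode off and 56 with it on. A quick check of the helper above (illustrative sizes only):

```cpp
#include <cassert>

static void check_pool_output_size() {
  assert(PoolOutputSize(112, 3, 0, 2, false) == 55);  // (112 - 3) / 2 + 1
  assert(PoolOutputSize(112, 3, 0, 2, true) == 56);   // (112 - 3 + 1) / 2 + 1
}
```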
template <typename T>
static void PoolAvgPad0(std::vector<int> ksize, std::vector<int> strides,
const framework::Tensor *input,
framework::Tensor *out) {
const int32_t batch_size = input->dims()[0];
const int32_t input_c = input->dims()[1];
const int32_t input_h = input->dims()[2];
const int32_t input_w = input->dims()[3];
const int32_t out_c = out->dims()[1];
const int32_t out_h = out->dims()[2];
const int32_t out_w = out->dims()[3];
const int32_t kernel_h = ksize[0];
const int32_t kernel_w = ksize[1];
const int32_t stride_h = strides[0];
const int32_t stride_w = strides[1];
const int32_t inputdata_channel_stride = input_h * input_w;
const int32_t input_batch_stride = input_c * inputdata_channel_stride;
const int32_t outputdata_channel_stride = out_h * out_w;
const int32_t output_batch_stride = out_c * outputdata_channel_stride;
T *out_data = out->mutable_data<T>();
const T *input_data = input->data<T>();
const T **rows = new const T *[kernel_h];
for (int i = 0; i < batch_size; ++i) {
for (int j = 0; j < out_c; ++j) {
const T *img_in = input_data + j * inputdata_channel_stride;
T *img_out = out_data + j * outputdata_channel_stride;
for (int k = 0; k < out_h; ++k) {
for (int m = 0; m < kernel_h; ++m) {
rows[m] = img_in + (stride_h * k + m) * input_w;
}
int32_t left = out_w;
while (left > 0) {
float tmp = 0;
for (int m = 0; m < kernel_h; ++m) {
for (int l = 0; l < kernel_w; ++l) {
tmp += rows[m][l];
}
}
if (typeid(T) == typeid(int8_t)) {
tmp = tmp / (kernel_h * kernel_w);
if (tmp < -127) {
*img_out = -127;
} else if (tmp > 127) {
*img_out = 127;
} else {
*img_out = static_cast<T>(std::round(tmp));
}
} else {
*img_out = static_cast<T>(tmp / (kernel_h * kernel_w));
}
for (int m = 0; m < kernel_h; ++m) {
rows[m] += stride_w;
}
img_out++;
left--;
}
}
}
input_data += input_batch_stride;
out_data += output_batch_stride;
}
delete[] rows;
}
template <typename T, int CeilMode, int PoolType, int Kernel, int Pad,
int Stride>
int TestPoolOp(int in_channels, int in_height, int in_width) {
int kernel_h = Kernel;
int kernel_w = Kernel;
int pad_h = Pad;
int pad_w = Pad;
int stride_h = Stride;
int stride_w = Stride;
bool ceil_mode = CeilMode != 0;
std::string pooling_type = (PoolType == 0 ? "max" : "avg");
int batch_size = 1;
int input_c = in_channels;
int input_h = in_height;
int input_w = in_width;
framework::DDim input_shape =
framework::make_ddim({batch_size, input_c, input_h, input_w});
std::vector<int64_t> output_shape_v({batch_size, input_c});
output_shape_v.push_back(
PoolOutputSize(input_h, kernel_h, pad_h, stride_h, ceil_mode));
output_shape_v.push_back(
PoolOutputSize(input_w, kernel_w, pad_w, stride_w, ceil_mode));
framework::DDim output_shape = framework::make_ddim(output_shape_v);
Executor4Test<paddle_mobile::CPU,
paddle_mobile::operators::PoolOp<paddle_mobile::CPU, float>>
executor(program, "pool2d");
VariableNameMap inputs;
VariableNameMap outputs;
auto scope = std::make_shared<framework::Scope>();
inputs["X"] = std::vector<std::string>({"input"});
outputs["Out"] = std::vector<std::string>({"output"});
paddle_mobile::framework::Tensor input;
SetupTensor<float>(&input, {1, 64, 112, 112}, static_cast<float>(0),
static_cast<float>(1));
auto out_ddim = paddle_mobile::framework::make_ddim({1, 64, 56, 56});
auto output =
executor.Predict(input, "conv2d_0.tmp_1", "pool2d_0.tmp_0", out_ddim);
auto input_var = scope.get()->Var("input");
auto input = input_var->template GetMutable<framework::LoDTensor>();
SetupTensor<T>(input, input_shape, -127, 127);
float *output_ptr = output->data<float>();
for (int j = 0; j < output->numel(); ++j) {
DLOG << " value of output: " << output_ptr[j];
auto output_var = scope.get()->Var("output");
framework::AttributeMap attrs;
attrs["pooling_type"].SetString(pooling_type);
attrs["ksize"].Set<vector<int>>(std::vector<int>({kernel_h, kernel_w}));
attrs["strides"].Set<vector<int>>(std::vector<int>({stride_h, stride_w}));
attrs["paddings"].Set<vector<int>>(std::vector<int>({pad_h, pad_w}));
attrs["ceil_mode"].Set<bool>(false);
attrs["global_pooling"].Set<bool>(false);
auto *op = new operators::PoolOp<CPU, float>("pool2d", inputs, outputs, attrs,
scope);
op->InferShape();
op->Init();
op->Run();
framework::Tensor output_cmp;
output_cmp.mutable_data<T>(output_shape);
if (pooling_type == "avg" && pad_h == 0 && pad_h == pad_w) {
PoolAvgPad0<T>(std::vector<int>{kernel_h, kernel_w},
std::vector<int>{stride_h, stride_w}, input, &output_cmp);
} else {
if (typeid(T) == typeid(int8_t)) {
operators::PoolBasic<int8_t, int32_t>(
pooling_type, std::vector<int>{kernel_h, kernel_w},
std::vector<int>{stride_h, stride_w}, std::vector<int>{pad_h, pad_w},
input, &output_cmp);
} else {
operators::PoolBasic<float, float>(
pooling_type, std::vector<int>{kernel_h, kernel_w},
std::vector<int>{stride_h, stride_w}, std::vector<int>{pad_h, pad_w},
input, &output_cmp);
}
}
// compare results
int eq = 0;
int neq = 0;
auto output = output_var->template Get<framework::LoDTensor>();
const T *output_data = output->data<T>();
T *output_cmp_data = output_cmp.data<T>();
for (int i = 0; i < output->numel(); ++i) {
PADDLE_MOBILE_ENFORCE(output_data[i] == output_cmp_data[i],
"The execution of test_pool_op is failed!");
if (output_data[i] == output_cmp_data[i]) {
++eq;
} else {
++neq;
}
}
std::cout << "eq = " << eq << ", neq = " << neq << std::endl;
delete op;
return 0;
}
} // namespace paddle_mobile
int main(int argc, char *argv[]) {
if (argc < 4) {
LOG(paddle_mobile::kLOG_INFO)
<< "Usage:\n"
<< " ./test-pool-op in_channels in_height in_width \n"
<< " params:\n"
<< " -in_channels: int, input image's channels\n"
<< " -in_height: int, input image's height\n"
<< " -in_width: int, input image's width\n";
return 1;
}
int in_channels = atoi(argv[1]);
int in_height = atoi(argv[2]);
int in_width = atoi(argv[3]);
#if __ARM_NEON
// kernel = 3, pad = 1, stride = 1
LOG(paddle_mobile::kLOG_INFO)
<< "float, ceil_mode=false, pooling_type=max, kernel=3, pad=1, stride=1";
paddle_mobile::TestPoolOp<float, 0, 0, 3, 1, 1>(in_channels, in_height,
in_width);
// kernel = 3, pad = 0, stride = 2
LOG(paddle_mobile::kLOG_INFO)
<< "float, ceil_mode=false, pooling_type=max, kernel=3, pad=0, stride=2";
paddle_mobile::TestPoolOp<float, 0, 0, 3, 0, 2>(in_channels, in_height,
in_width);
#endif
// kernel = 3, pad = 0, stride = 1
LOG(paddle_mobile::kLOG_INFO)
<< "int8_t, ceil_mode=false, pooling_type=max, kernel=3, pad=0, stride=1";
paddle_mobile::TestPoolOp<int8_t, 0, 0, 3, 0, 1>(in_channels, in_height,
in_width);
// kernel = 3, pad = 1, stride = 1
LOG(paddle_mobile::kLOG_INFO)
<< "int8_t, ceil_mode=false, pooling_type=max, kernel=3, pad=1, stride=1";
paddle_mobile::TestPoolOp<int8_t, 0, 0, 3, 1, 1>(in_channels, in_height,
in_width);
// kernel = 3, pad = 2, stride = 1
LOG(paddle_mobile::kLOG_INFO)
<< "int8_t, ceil_mode=false, pooling_type=max, kernel=3, pad=2, stride=1";
paddle_mobile::TestPoolOp<int8_t, 0, 0, 3, 2, 1>(in_channels, in_height,
in_width);
// kernel = 3, pad = 0, stride = 2
LOG(paddle_mobile::kLOG_INFO)
<< "int8_t, ceil_mode=false, pooling_type=max, kernel=3, pad=0, stride=2";
paddle_mobile::TestPoolOp<int8_t, 0, 0, 3, 0, 2>(in_channels, in_height,
in_width);
// kernel = 3, pad = 1, stride = 2
LOG(paddle_mobile::kLOG_INFO)
<< "int8_t, ceil_mode=false, pooling_type=max, kernel=3, pad=1, stride=2";
paddle_mobile::TestPoolOp<int8_t, 0, 0, 3, 1, 2>(in_channels, in_height,
in_width);
// kernel = 3, pad = 2, stride = 2
LOG(paddle_mobile::kLOG_INFO)
<< "int8_t, ceil_mode=false, pooling_type=max, kernel=3, pad=2, stride=2";
paddle_mobile::TestPoolOp<int8_t, 0, 0, 3, 2, 2>(in_channels, in_height,
in_width);
// kernel = 3, pad = 3, stride = 3
LOG(paddle_mobile::kLOG_INFO)
<< "int8_t, ceil_mode=false, pooling_type=max, kernel=3, pad=3, stride=3";
paddle_mobile::TestPoolOp<int8_t, 0, 0, 3, 3, 3>(in_channels, in_height,
in_width);
// kernel = 7, pad = 0, stride = 1
LOG(paddle_mobile::kLOG_INFO)
<< "int8_t, ceil_mode=false, pooling_type=avg, kernel=7, pad=0, stride=1";
paddle_mobile::TestPoolOp<int8_t, 0, 1, 7, 0, 1>(in_channels, in_height,
in_width);
// kernel = 7, pad = 0, stride = 2
LOG(paddle_mobile::kLOG_INFO)
<< "int8_t, ceil_mode=false, pooling_type=avg, kernel=7, pad=0, stride=2";
paddle_mobile::TestPoolOp<int8_t, 0, 1, 7, 0, 2>(in_channels, in_height,
in_width);
// kernel = 7, pad = 0, stride = 3
LOG(paddle_mobile::kLOG_INFO)
<< "int8_t, ceil_mode=false, pooling_type=avg, kernel=7, pad=0, stride=3";
paddle_mobile::TestPoolOp<int8_t, 0, 1, 7, 0, 3>(in_channels, in_height,
in_width);
// kernel = 3, pad = 0, stride = 1
LOG(paddle_mobile::kLOG_INFO)
<< "int8_t, ceil_mode=false, pooling_type=avg, kernel=3, pad=0, stride=1";
paddle_mobile::TestPoolOp<int8_t, 0, 1, 3, 0, 1>(in_channels, in_height,
in_width);
// kernel = 3, pad = 0, stride = 3
LOG(paddle_mobile::kLOG_INFO)
<< "int8_t, ceil_mode=false, pooling_type=avg, kernel=3, pad=0, stride=3";
paddle_mobile::TestPoolOp<int8_t, 0, 1, 3, 0, 3>(in_channels, in_height,
in_width);
// kernel = 7, pad = 0, stride = 1
LOG(paddle_mobile::kLOG_INFO)
<< "float, ceil_mode=false, pooling_type=avg, kernel=7, pad=0, stride=1";
paddle_mobile::TestPoolOp<float, 0, 1, 7, 0, 1>(in_channels, in_height,
in_width);
// kernel = 7, pad = 0, stride = 4
LOG(paddle_mobile::kLOG_INFO)
<< "float, ceil_mode=false, pooling_type=avg, kernel=7, pad=0, stride=4";
paddle_mobile::TestPoolOp<float, 0, 1, 7, 0, 4>(in_channels, in_height,
in_width);
// kernel = 5, pad = 0, stride = 1
LOG(paddle_mobile::kLOG_INFO)
<< "float, ceil_mode=false, pooling_type=avg, kernel=5, pad=0, stride=1";
paddle_mobile::TestPoolOp<float, 0, 1, 5, 0, 1>(in_channels, in_height,
in_width);
}
......@@ -213,6 +213,8 @@ if(NOT FOUND_MATCH)
set(FUSION_CONVADD_OP ON)
set(FUSION_CONVADDPRELU_OP ON)
set(FUSION_CONVADDRELU_OP ON)
set(FUSION_CONVADDRELU_INT8_OP ON)
set(FUSION_FC_INT8_OP ON)
set(FUSION_FC_OP ON)
set(LRN_OP ON)
set(MUL_OP ON)
......@@ -309,6 +311,9 @@ endif()
if (FUSION_CONVADDRELU_OP)
add_definitions(-DFUSION_CONVADDRELU_OP)
endif()
if (FUSION_CONVADDRELU_INT8_OP)
add_definitions(-DFUSION_CONVADDRELU_INT8_OP)
endif()
if (FUSION_CONVADDPRELU_OP)
add_definitions(-DFUSION_CONVADDPRELU_OP)
endif()
......@@ -318,6 +323,9 @@ endif()
if (FUSION_FC_OP)
add_definitions(-DFUSION_FC_OP)
endif()
if(FUSION_FC_INT8_OP)
add_definitions(-DFUSION_FC_INT8_OP)
endif()
if (LRN_OP)
add_definitions(-DLRN_OP)
endif()
......