Commit b7e92db8 authored by H hjchen2

Optimize: fuse quantize and pad op

Parent b680fc96
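The fusion works by letting the quantize op reserve the padded output itself: QuantizeParam gains optional paddings_ / padding_val_ fields and QuantizeOp::InferShape grows the spatial dimensions by 2 * padding (see the op_param.h and quantize_op.cpp hunks below), so a separate pad op no longer has to run over the int8 feature map before the int8 depthwise kernels. Below is a minimal standalone sketch of that shape bookkeeping only; the helper name, the hard-coded NCHW layout, and the main() driver are illustrative assumptions, not paddle-mobile API.

#include <cstdint>
#include <iostream>
#include <vector>

// Hypothetical helper: grow the H/W dims of an NCHW shape by symmetric
// paddings, mirroring what the fused quantize op's InferShape does.
std::vector<int64_t> InferQuantizePadShape(std::vector<int64_t> dims,
                                           const std::vector<int> &paddings) {
  dims[2] += 2 * paddings[0];  // height
  dims[3] += 2 * paddings[1];  // width
  return dims;
}

int main() {
  std::vector<int64_t> dims = {1, 32, 56, 56};
  for (int64_t d : InferQuantizePadShape(dims, {1, 1})) {
    std::cout << d << ' ';  // prints: 1 32 58 58
  }
  std::cout << std::endl;
  return 0;
}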
@@ -233,3 +233,6 @@ LOAD_OP1(quantize, CPU);
 #ifdef DEQUANT_OP
 LOAD_OP1(dequantize, CPU);
 #endif
+#ifdef PAD_OP
+LOAD_OP1(pad, CPU);
+#endif
@@ -22,7 +22,7 @@ namespace operators {
 template <typename DeviceType, typename T>
 void DequantizeOp<DeviceType, T>::InferShape() const {
   const auto& input_dims = this->param_.input_->dims();
-  this->param_.out_->Resize(input_dims);
+  this->param_.output_->Resize(input_dims);
 }
 }  // namespace operators
...
@@ -15,6 +15,7 @@ limitations under the License. */
 #ifdef CONV_OP
 #include "operators/kernel/conv_kernel.h"
+#include <iostream>
 #include "operators/kernel/central-arm-func/conv_arm_func.h"
 namespace paddle_mobile {
@@ -22,8 +23,15 @@ namespace operators {
 template <>
 bool ConvKernel<CPU, float>::Init(ConvParam<CPU> *param) {
-  if (param->Input()->type() == typeid(int8_t)) {
-    param->ExecMode() = ConvParam<CPU>::EXEC_GEMM_INT8;
+  if (param->Filter()->type() == typeid(int8_t)) {
+    if (param->Groups() == param->Input()->dims()[1] &&
+        param->Input()->dims()[1] == param->Output()->dims()[1] &&
+        param->Filter()->dims()[2] == param->Filter()->dims()[3] &&
+        param->Filter()->dims()[2] == 3) {
+      param->ExecMode() = ConvParam<CPU>::EXEC_DEPTHWISE3x3_INT8;
+    } else {
+      param->ExecMode() = ConvParam<CPU>::EXEC_GEMM_INT8;
+    }
   } else {
     if (param->Groups() == param->Input()->dims()[1] &&
         param->Input()->dims()[1] == param->Output()->dims()[1] &&
@@ -35,6 +43,7 @@ bool ConvKernel<CPU, float>::Init(ConvParam<CPU> *param) {
         param->Filter()->dims()[2] == param->Filter()->dims()[3] &&
         param->Filter()->dims()[2] == 3) {
       param->ExecMode() = ConvParam<CPU>::EXEC_DEPTHWISE3x3_FLOAT;
+#ifndef __aarch64__
     } else if (param->Filter()->dims()[2] == param->Filter()->dims()[3] &&
                param->Strides()[0] == param->Strides()[1] &&
                param->Dilations()[0] == param->Dilations()[1] &&
@@ -48,6 +57,7 @@ bool ConvKernel<CPU, float>::Init(ConvParam<CPU> *param) {
       operators::math::winograd_transform_weight<8, 3>(*param->Filter(),
                                                        transformed_weight);
       param->Filter() = transformed_weight;
+#endif
     } else {
       param->ExecMode() = ConvParam<CPU>::EXEC_GEMM_FLOAT;
     }
@@ -60,25 +70,36 @@ void ConvKernel<CPU, float>::Compute(const ConvParam<CPU> &param) {
   switch (param.ExecMode()) {
     case ConvParam<CPU>::EXEC_GEMM_INT8:
       GemmConv<int8_t, int32_t>(param);
+      std::cout << "EXEC_GEMM_INT8" << std::endl;
+      break;
+    case ConvParam<CPU>::EXEC_DEPTHWISE3x3_INT8:
+      DepthwiseConv3x3<int8_t, int32_t>(param);
+      std::cout << "EXEC_DEPTHWISE3x3_INT8" << std::endl;
       break;
     case ConvParam<CPU>::EXEC_DEPTHWISE3x3S1P1_FLOAT:
       math::DepthwiseConv3x3s1p1(param.Input(), param.Filter(), param.Output(),
                                  nullptr, false);
+      std::cout << "EXEC_DEPTHWISE3x3S1P1_FLOAT" << std::endl;
       break;
     case ConvParam<CPU>::EXEC_DEPTHWISE3x3_FLOAT:
      math::DepthwiseConv3x3(param.Input(), param.Strides(), param.Paddings(),
                             param.Filter(), nullptr, param.Output(), false);
+      std::cout << "EXEC_DEPTHWISE3x3_FLOAT=" << param.Strides()[0]
+                << std::endl;
       break;
     case ConvParam<CPU>::EXEC_WINOGRAD3X3_FLOAT:
       WinogradConv3x3<8, 3>(param);
+      std::cout << "EXEC_WINOGRAD3X3_FLOAT" << std::endl;
       break;
     case ConvParam<CPU>::EXEC_GEMM_FLOAT:
       GemmConv<float, float>(param);
+      std::cout << "EXEC_GEMM_FLOAT" << std::endl;
       break;
     default:
       PADDLE_MOBILE_THROW_EXCEPTION("Invalid convolution execute mode %d",
                                     param.ExecMode());
   }
+  std::cout << "exec here..." << std::endl;
 }
 template class ConvKernel<CPU, float>;
...
@@ -15,6 +15,7 @@ limitations under the License. */
 #ifdef DEQUANT_OP
 #include "operators/kernel/dequantize_kernel.h"
+#include <iostream>
 #if defined(__ARM_NEON__) || defined(__ARM_NEON)
 #include <arm_neon.h>
@@ -31,7 +32,7 @@ bool DequantizeKernel<CPU, float>::Init(DequantizeParam<CPU> *param) {
 template <>
 void DequantizeKernel<CPU, float>::Compute(const DequantizeParam<CPU> &param) {
   const Tensor *input = param.input_;
-  Tensor *output = param.out_;
+  Tensor *output = param.output_;
   float activation_scale = param.activation_scale_->data<float>()[0];
   float weight_scale = param.weight_scale_;
   const int32_t *x = input->data<const int32_t>();
...
@@ -15,6 +15,7 @@ limitations under the License. */
 #ifdef ELEMENTWISEADD_OP
 #include "operators/kernel/elementwise_add_kernel.h"
+#include <iostream>
 #include "operators/kernel/central-arm-func/elementwise_add_arm_func.h"
 namespace paddle_mobile {
...
@@ -17,7 +17,7 @@ limitations under the License. */
 #include <vector>
 #include "operators/math/conv_func.h"
-#include "operators/math/depthwise_conv_3x3.h"
+#include "operators/math/depthwise_conv3x3.h"
 #include "operators/math/im2col.h"
 #include "operators/math/math_function.h"
 #include "operators/math/vol2col.h"
...
@@ -17,7 +17,7 @@ limitations under the License. */
 #pragma once
 #include <vector>
-#include "operators/math/depthwise_conv_3x3.h"
+#include "operators/math/depthwise_conv3x3.h"
 #include "operators/math/im2col.h"
 #include "operators/math/math_function.h"
 #include "operators/math/vol2col.h"
...
@@ -17,7 +17,7 @@ limitations under the License. */
 #pragma once
 #include <vector>
 #include "operators/math/conv_func.h"
-#include "operators/math/depthwise_conv_3x3.h"
+#include "operators/math/depthwise_conv3x3.h"
 #include "operators/math/im2col.h"
 #include "operators/math/math_function.h"
 #include "operators/math/pad.h"
@@ -39,10 +39,7 @@ inline void GemmConv(const ConvParam<CPU> &param) {
   const std::vector<int> paddings = param.Paddings();
   const std::vector<int> dilations = param.Dilations();
-  const int batch_size = static_cast<int>(input->dims()[0]);
   std::vector<int64_t> filter_shape_vec(framework::vectorize(filter.dims()));
   std::vector<int64_t> output_shape_vec(framework::vectorize(output->dims()));
   size_t data_dim = filter_shape_vec.size() - 2;
   std::vector<int64_t> col_shape_vec(1 + 2 * data_dim);
@@ -83,6 +80,7 @@ inline void GemmConv(const ConvParam<CPU> &param) {
   math::Vol2ColFunctor<CPU, Itype> vol2col;
   math::Im2ColFunctor<math::ColFormat::kCFO, CPU, Itype> im2col;
+  const int batch_size = static_cast<int>(input->dims()[0]);
   for (int i = 0; i < batch_size; i++) {
     Tensor in_batch = input->Slice(i, i + 1).Resize(input_shape);
     Tensor out_batch = output->Slice(i, i + 1).Resize(output_matrix_shape);
@@ -126,7 +124,6 @@ inline void WinogradConv3x3(const ConvParam<CPU> &param) {
   int batch_size = input->dims()[0];
   int groups = param.Groups();
   const std::vector<int> &paddings = param.Paddings();
-  math::PadFunctor<CPU, float> pad;
   auto winograd_pad = [&](int width, int pad) {
     int output_tile = tile - kernel + 1;
@@ -136,6 +133,7 @@ inline void WinogradConv3x3(const ConvParam<CPU> &param) {
     return pad_width + tile - width;
   };
+  math::PadFunctor<CPU, float> pad;
   Tensor input_pad;
   framework::Tensor transformed_input;
   for (int i = 0; i < batch_size; ++i) {
@@ -155,15 +153,49 @@ inline void WinogradConv3x3(const ConvParam<CPU> &param) {
     } else {
       input_pad = in_batch;
     }
+#if __aarch64__
+    // TODO(hjchen2)
+#else
     // tile input and transform
     math::winograd_transform_input<tile, kernel>(input_pad, &transformed_input);
     // caculate output
     math::winograd_transform_output<tile, kernel>(transformed_input, *filter,
                                                   output);
+#endif
   }
 }
+
+template <typename Itype, typename Otype>
+inline void DepthwiseConv3x3(const ConvParam<CPU> &param) {
+  const Tensor *input = param.Input();
+  const Tensor *filter = param.Filter();
+  Tensor *output = param.Output();
+  output->mutable_data<Otype>();
+  const std::vector<int> &paddings = param.Paddings();
+  const std::vector<int> &strides = param.Strides();
+  const int batch_size = static_cast<int>(input->dims()[0]);
+  Tensor input_pad;
+  math::PadFunctor<CPU, Itype> pad;
+  for (int i = 0; i < batch_size; i++) {
+    Tensor in_batch = input->Slice(i, i + 1);
+    Tensor out_batch = output->Slice(i, i + 1);
+    // if (paddings[0] || paddings[1]) {
+    //   framework::DDim pad_shape = in_batch.dims();
+    //   pad_shape[2] += 2 * paddings[0];
+    //   pad_shape[3] += 2 * paddings[1];
+    //   input_pad.mutable_data<float>(pad_shape);
+    //   pad(in_batch, paddings[0], paddings[0], paddings[1], paddings[1],
+    //       &input_pad);
+    // } else {
+    //   input_pad = in_batch;
+    // }
+    // math::DepthwiseConv3x3s1<Itype, Otype>(input_pad, *filter,
+    //                                        &out_batch);
+    if (strides[0] == 1) {
+      math::DepthwiseConv3x3s1<Itype, Otype>(in_batch, *filter, &out_batch);
+    } else if (strides[0] == 2) {
+      math::DepthwiseConv3x3s2<Itype, Otype>(in_batch, *filter, &out_batch);
+    } else {
+      // math::DepthwiseConv3x3<Itype, Otype>(in_batch, *filter,
+      //                                      &out_batch);
+    }
+  }
+}
...
@@ -17,7 +17,7 @@ limitations under the License. */
 #pragma once
 #include <vector>
-#include "operators/math/depthwise_conv_3x3.h"
+#include "operators/math/depthwise_conv3x3.h"
 #include "operators/math/im2col.h"
 #include "operators/math/math_function.h"
 #include "operators/math/vol2col.h"
...
@@ -16,13 +16,15 @@ limitations under the License. */
 #pragma once
 #include <vector>
-#include "operators/math/depthwise_conv_3x3.h"
+#include "operators/math/depthwise_conv3x3.h"
 #include "operators/math/im2col.h"
 #include "operators/math/math_function.h"
 #include "operators/math/vol2col.h"
 #include "operators/op_param.h"
 namespace paddle_mobile {
 namespace operators {
 void ConvBNReluBasic(const FusionConvBNReluParam<CPU> &param) {
   const Tensor *input = param.Input();
   Tensor filter = *param.Filter();
...
@@ -15,10 +15,9 @@ limitations under the License. */
 #ifdef DEPTHWISECONV_OP
 #pragma once
-#include <operators/math/depthwise_conv_3x3.h>
 #include <vector>
 #include "operators/kernel/central-arm-func/conv_arm_func.h"
+#include "operators/math/depthwise_conv3x3.h"
 #include "operators/op_param.h"
 namespace paddle_mobile {
...
@@ -16,13 +16,15 @@ limitations under the License. */
 #pragma once
 #include <vector>
-#include "operators/math/depthwise_conv_3x3.h"
+#include "operators/math/depthwise_conv3x3.h"
 #include "operators/math/im2col.h"
 #include "operators/math/math_function.h"
 #include "operators/math/vol2col.h"
 #include "operators/op_param.h"
 namespace paddle_mobile {
 namespace operators {
 void DWConvBNReluBasic(const FusionDWConvBNReluParam<CPU> &param) {
   const Tensor *input = param.Input();
   Tensor filter = *param.Filter();
...
@@ -24,7 +24,7 @@ limitations under the License. */
 #include "framework/ddim.h"
 #include "framework/operator.h"
 #include "operators/math/conv_func.h"
-#include "operators/math/depthwise_conv_3x3.h"
+#include "operators/math/depthwise_conv3x3.h"
 #include "operators/math/im2col.h"
 #include "operators/math/math_function.h"
 #include "operators/math/vol2col.h"
...
@@ -11,18 +11,22 @@ distributed under the License is distributed on an "AS IS" BASIS,
 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */
-#include "operators/math/depthwise_conv_3x3.h"
+#include "operators/math/depthwise_conv3x3.h"
+#include <vector>
 #if __ARM_NEON
 #include <arm_neon.h>
 #endif
-#include <vector>
 namespace paddle_mobile {
 namespace operators {
 namespace math {
-void DepthwiseConv3x3(const Tensor *input, vector<int> strides,
-                      vector<int> paddings, const Tensor *filter, Tensor *bias,
-                      Tensor *output, bool if_bias) {
+void DepthwiseConv3x3(const framework::Tensor *input,
+                      const std::vector<int> &strides,
+                      const std::vector<int> &paddings,
+                      const framework::Tensor *filter, framework::Tensor *bias,
+                      framework::Tensor *output, bool if_bias) {
   const int batch_size = input->dims()[0];
   const int input_height = input->dims()[2];
@@ -67,12 +71,12 @@ void DepthwiseConv3x3(const Tensor *input, vector<int> strides,
       for (int pw = 0; pw < output_width; pw++) {
         hstart = ph * stride_height - padding_height;
         wstart = pw * stride_width - padding_width;
-        hend = min(hstart + _kernel_size, input_height + padding_height);
-        wend = min(wstart + _kernel_size, input_width + padding_width);
-        hstart = max(hstart, 0);
-        wstart = max(wstart, 0);
-        hend = min(hend, input_height);
-        wend = min(wend, input_width);
+        hend = std::min(hstart + _kernel_size, input_height + padding_height);
+        wend = std::min(wstart + _kernel_size, input_width + padding_width);
+        hstart = std::max(hstart, 0);
+        wstart = std::max(wstart, 0);
+        hend = std::min(hend, input_height);
+        wend = std::min(wend, input_width);
         pos1 = input_data + hstart * input_width + wstart;
         pos2 = input_data + (hstart + 1) * input_width + wstart;
         pos3 = input_data + (hstart + 2) * input_width + wstart;
@@ -244,8 +248,10 @@ void DepthwiseConv3x3(const Tensor *input, vector<int> strides,
   }
 }
-void DepthwiseConv3x3s1p1(const Tensor *input, const Tensor *filter,
-                          Tensor *output, Tensor *bias, bool if_bias) {
+void DepthwiseConv3x3s1p1(const framework::Tensor *input,
+                          const framework::Tensor *filter,
+                          framework::Tensor *output, framework::Tensor *bias,
+                          bool if_bias) {
 #if __ARM_NEON
   const float *input_data = input->data<float>();
   const float *filter_data = filter->data<float>();
@@ -517,9 +523,12 @@ void DepthwiseConv3x3s1p1(const Tensor *input, const Tensor *filter,
 #endif
 }
-void DepthwiseConvAddBNRelu3x3s1p1(const Tensor *input, const Tensor *filter,
-                                   Tensor *output, const Tensor *new_scale,
-                                   const Tensor *new_bias, bool if_relu) {
+void DepthwiseConvAddBNRelu3x3s1p1(const framework::Tensor *input,
+                                   const framework::Tensor *filter,
+                                   framework::Tensor *output,
+                                   const framework::Tensor *new_scale,
+                                   const framework::Tensor *new_bias,
+                                   bool if_relu) {
 #if __ARM_NEON
   const float *input_data = input->data<float>();
   const float *filter_data = filter->data<float>();
@@ -1059,9 +1068,12 @@ void DepthwiseConvAddBNRelu3x3s1p1(const Tensor *input, const Tensor *filter,
 }
 /// w!=h not fix
-void DepthwiseConvAddBNRelu3x3s2p1(const Tensor *input, const Tensor *filter,
-                                   Tensor *output, const Tensor *new_scale,
-                                   const Tensor *new_bias, bool if_relu) {
+void DepthwiseConvAddBNRelu3x3s2p1(const framework::Tensor *input,
+                                   const framework::Tensor *filter,
+                                   framework::Tensor *output,
+                                   const framework::Tensor *new_scale,
+                                   const framework::Tensor *new_bias,
+                                   bool if_relu) {
 #if __ARM_NEON
   const int batch_size = input->dims()[0];
@@ -1107,12 +1119,12 @@ void DepthwiseConvAddBNRelu3x3s2p1(const Tensor *input, const Tensor *filter,
       for (int pw = 0; pw < output_width; pw++) {
         hstart = ph * stride_height - padding_height;
         wstart = pw * stride_width - padding_width;
-        hend = min(hstart + _kernel_size, input_height + padding_height);
-        wend = min(wstart + _kernel_size, input_width + padding_width);
-        hstart = max(hstart, 0);
-        wstart = max(wstart, 0);
-        hend = min(hend, input_height);
-        wend = min(wend, input_width);
+        hend = std::min(hstart + _kernel_size, input_height + padding_height);
+        wend = std::min(wstart + _kernel_size, input_width + padding_width);
+        hstart = std::max(hstart, 0);
+        wstart = std::max(wstart, 0);
+        hend = std::min(hend, input_height);
+        wend = std::min(wend, input_width);
         pos1 = input_data + hstart * input_width + wstart;
         pos2 = input_data + (hstart + 1) * input_width + wstart;
         pos3 = input_data + (hstart + 2) * input_width + wstart;
@@ -1258,8 +1270,10 @@ void DepthwiseConvAddBNRelu3x3s2p1(const Tensor *input, const Tensor *filter,
 #endif
 }
-void DepthwiseConv3x3s2p1v2(const Tensor *input, const Tensor *filter,
-                            Tensor *output, Tensor bias, bool if_bias) {
+void DepthwiseConv3x3s2p1v2(const framework::Tensor *input,
+                            const framework::Tensor *filter,
+                            framework::Tensor *output, framework::Tensor bias,
+                            bool if_bias) {
 #if __ARM_NEON
   const float *input_data = input->data<float>();
   const float *filter_data = filter->data<float>();
@@ -1463,9 +1477,12 @@ void DepthwiseConv3x3s2p1v2(const Tensor *input, const Tensor *filter,
 #endif
 }
-void DepthwiseConvAddBNRelu3x3s2p1v2(const Tensor *input, const Tensor *filter,
-                                     Tensor *output, const Tensor *new_scale,
-                                     const Tensor *new_bias, bool if_relu) {
+void DepthwiseConvAddBNRelu3x3s2p1v2(const framework::Tensor *input,
+                                     const framework::Tensor *filter,
+                                     framework::Tensor *output,
+                                     const framework::Tensor *new_scale,
+                                     const framework::Tensor *new_bias,
+                                     bool if_relu) {
 #if __ARM_NEON
 // #ifdef _OPENMP
 //   const float *newscale_data = new_scale->data<float>();
@@ -1886,8 +1903,10 @@ void DepthwiseConvAddBNRelu3x3s2p1v2(const Tensor *input, const Tensor *filter,
 #endif
 }
-void DepthwiseConv3x3s2p0(const Tensor *input, const Tensor *filter,
-                          Tensor *output, Tensor bias, bool if_bias) {
+void DepthwiseConv3x3s2p0(const framework::Tensor *input,
+                          const framework::Tensor *filter,
+                          framework::Tensor *output, framework::Tensor bias,
+                          bool if_bias) {
 #if __ARM_NEON
   const int batch_size = static_cast<int>(input->dims()[0]);
...
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <algorithm>
#include <vector>
#include "framework/tensor.h"
#include "operators/math/conv_func.h"
namespace paddle_mobile {
namespace operators {
namespace math {
void DepthwiseConv3x3(const framework::Tensor *input,
const std::vector<int> &strides,
const std::vector<int> &paddings,
const framework::Tensor *filter, framework::Tensor *bias,
framework::Tensor *output, bool if_bias);
void DepthwiseConv3x3s1p1(const framework::Tensor *input,
const framework::Tensor *filter,
framework::Tensor *output, framework::Tensor *bias,
bool if_bias);
void DepthwiseConvAddBNRelu3x3s1p1(const framework::Tensor *input,
const framework::Tensor *filter,
framework::Tensor *output,
const framework::Tensor *new_scale,
const framework::Tensor *new_bias,
bool if_relu);
void DepthwiseConvAddBNRelu3x3s2p1(const framework::Tensor *input,
const framework::Tensor *filter,
framework::Tensor *output,
const framework::Tensor *new_scale,
const framework::Tensor *new_bias,
bool if_relu);
void DepthwiseConv3x3s2p1v2(const framework::Tensor *input,
const framework::Tensor *filter,
framework::Tensor *output, framework::Tensor bias,
bool if_bias);
void DepthwiseConvAddBNRelu3x3s2p1v2(const framework::Tensor *input,
const framework::Tensor *filter,
framework::Tensor *output,
const framework::Tensor *new_scale,
const framework::Tensor *new_bias,
bool if_relu);
void DepthwiseConv3x3s2p0(const framework::Tensor *input,
const framework::Tensor *filter,
framework::Tensor *output, framework::Tensor bias,
bool if_bias);
// template<typename Itype, typename Otype>
// void DepthwiseConv3x3(const framework::Tensor *input,
// const framework::Tensor *filter,
// const std::vector<int> &strides,
// framework::Tensor *output);
template <typename Itype, typename Otype>
void DepthwiseConv3x3s1(const framework::Tensor &input,
const framework::Tensor &filter,
framework::Tensor *output);
template <typename Itype, typename Otype>
void DepthwiseConv3x3s2(const framework::Tensor &input,
const framework::Tensor &filter,
framework::Tensor *output);
} // namespace math
} // namespace operators
} // namespace paddle_mobile
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <vector>
#include "framework/tensor.h"
namespace paddle_mobile {
namespace operators {
namespace math {
void DepthwiseConv3x3_int8(const framework::Tensor *input,
const framework::Tensor *filter,
const std::vector<int> &strides,
framework::Tensor *output);
void DepthwiseConv3x3s1_int8(const framework::Tensor *input,
const framework::Tensor *filter,
framework::Tensor *output);
void DepthwiseConv3x3s2_int8(const framework::Tensor *input,
const framework::Tensor *filter,
framework::Tensor *output);
} // namespace math
} // namespace operators
} // namespace paddle_mobile
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <algorithm>
#include <vector>
#include "framework/tensor.h"
#include "operators/math/conv_func.h"
namespace paddle_mobile {
namespace operators {
namespace math {
using framework::Tensor;
using std::max;
using std::min;
using std::vector;
void DepthwiseConv3x3(const Tensor *input, vector<int> strides,
vector<int> paddings, const Tensor *filter, Tensor *bias,
Tensor *output, bool if_bias);
void DepthwiseConv3x3s1p1(const Tensor *input, const Tensor *filter,
Tensor *output, Tensor *bias, bool if_bias);
void DepthwiseConvAddBNRelu3x3s1p1(const Tensor *input, const Tensor *filter,
Tensor *output, const Tensor *new_scale,
const Tensor *new_bias, bool if_relu);
void DepthwiseConvAddBNRelu3x3s2p1(const Tensor *input, const Tensor *filter,
Tensor *output, const Tensor *new_scale,
const Tensor *new_bias, bool if_relu);
void DepthwiseConv3x3s2p1v2(const Tensor *input, const Tensor *filter,
Tensor *output, Tensor bias, bool if_bias);
void DepthwiseConvAddBNRelu3x3s2p1v2(const Tensor *input, const Tensor *filter,
Tensor *output, const Tensor *new_scale,
const Tensor *new_bias, bool if_relu);
void DepthwiseConv3x3s2p0(const Tensor *input, const Tensor *filter,
Tensor *output, Tensor bias, bool if_bias);
} // namespace math
} // namespace operators
} // namespace paddle_mobile
@@ -26,79 +26,6 @@ limitations under the License. */
 namespace paddle_mobile {
 namespace operators {
 namespace math {
-/*int MC = 0;
-int KC = 0;
-int NC = 0;
-float *packedA;
-float *packedB;
-float *packedC;
-float *zero;
-typedef void (*FnPack)(int, int, int, const float *, int, float *);
-typedef void (*FnAddDot)(int, const float *, const float *, float *, int);
-FnPack procPackA;
-FnPack procPackB;
-FnAddDot procAddDot;*/
-/*
-// Pack blocks of matrix A into contiguous memory (ColMajor)
-void PackMatrixA(int m, int k, int m_tail, const float *A, int lda,
-                 float *buffer) {
-  int i, j;
-  const float *Aij;
-  for (i = 0; i < m - m_tail; i += MR) {
-    for (j = 0; j < k; ++j) {
-      Aij = &A(i, j);
-      *buffer++ = *Aij;
-      *buffer++ = *(Aij + 1);
-      *buffer++ = *(Aij + 2);
-      *buffer++ = *(Aij + 3);
-    }
-  }
-  if (m_tail != 0) {
-    for (j = 0; j < k; ++j) {
-      Aij = &A(m - m_tail, j);
-      for (i = 0; i < m_tail; ++i) {
-        *buffer++ = *(Aij + i);
-      }
-      for (i = m_tail; i < MR; ++i) {
-        *buffer++ = 0;
-      }
-    }
-  }
-}
-
-// Pack blocks of matrix B into contiguous memory (ColMajor)
-void PackMatrixB(int k, int n, int n_tail, const float *B, int ldb,
-                 float *buffer) {
-  int i, j;
-  const float *Bj, *Bj1, *Bj2, *Bj3;
-  for (j = 0; j < n - n_tail; j += NR) {
-    Bj = &B(0, j);
-    Bj1 = &B(0, j + 1);
-    Bj2 = &B(0, j + 2);
-    Bj3 = &B(0, j + 3);
-    for (i = 0; i < k; ++i) {
-      *buffer++ = *Bj++;
-      *buffer++ = *Bj1++;
-      *buffer++ = *Bj2++;
-      *buffer++ = *Bj3++;
-    }
-  }
-  if (n_tail != 0) {
-    for (i = 0; i < k; ++i) {
-      for (int j = n - n_tail; j < n; ++j) {
-        *buffer++ = B(i, j);
-      }
-      for (int j = n; j < n + (NR - n_tail); ++j) {
-        *buffer++ = 0;
-      }
-    }
-  }
-}
-*/
 // Pack blocks of matrix A into contiguous memory (RowMajor)
 void Gemm::PackMatrixA_4r(int m, int k, int m_tail, const float *A, int lda,
...
@@ -423,6 +423,7 @@ class ConvParam : public OpParam {
     EXEC_WINOGRAD3X3_FLOAT,
     EXEC_WINOGRAD5X5_FLOAT,
     EXEC_GEMM_INT8,
+    EXEC_DEPTHWISE3x3_INT8,
   };
   ExecMode &ExecMode() const { return exec_mode_; }
@@ -2498,7 +2499,7 @@ class QuantizeParam : public OpParam {
   QuantizeParam(const VariableNameMap &inputs, const VariableNameMap &outputs,
                 const AttributeMap &attrs, const Scope &scope) {
     input_ = InputXFrom<GType>(inputs, scope);
-    out_ = OutFrom<GType>(outputs, scope);
+    output_ = OutFrom<GType>(outputs, scope);
     // online
     // scale = max(abs(x))
     online_scale_ = GetVarValue<GType>("OutScale", outputs, scope);
@@ -2517,8 +2518,7 @@ class QuantizeParam : public OpParam {
   // op input
   RType *input_;
   // op output
-  RType *out_;
-  //
+  RType *output_;
   RType *online_scale_;
   // if static scale or not
   bool is_static_ = false;
@@ -2526,7 +2526,11 @@ class QuantizeParam : public OpParam {
   float static_scale_ = 1.0f;
   // round method type
   // nearest_zero and nearest_even is valid currently
-  RoundType round_type_ = ROUND_NEAREST_AWAY_ZERO;
+  // RoundType round_type_ = ROUND_NEAREST_AWAY_ZERO;
+  RoundType round_type_ = ROUND_NEAREST_TOWARDS_ZERO;
+  // optional paddings
+  std::vector<int> paddings_;
+  int8_t padding_val_;
 };
 #endif
@@ -2540,7 +2544,7 @@ class DequantizeParam : public OpParam {
   DequantizeParam(const VariableNameMap &inputs, const VariableNameMap &outputs,
                   const AttributeMap &attrs, const Scope &scope) {
     input_ = InputXFrom<GType>(inputs, scope);
-    out_ = OutFrom<GType>(outputs, scope);
+    output_ = OutFrom<GType>(outputs, scope);
     activation_scale_ = GetVarValue<GType>("Scale", inputs, scope);
     // dequantization is performed as x = x / static_scale / online_scale
     if (HasAttr("weight_scale", attrs)) {
@@ -2554,11 +2558,32 @@ class DequantizeParam : public OpParam {
   // op input
   RType *input_;
   // op output
-  RType *out_;
+  RType *output_;
   RType *activation_scale_;
   float weight_scale_;
 };
 #endif
+
+#ifdef PAD_OP
+template <typename Dtype>
+class PadParam : public OpParam {
+  typedef typename DtypeTensorTrait<Dtype>::gtype GType;
+  typedef typename DtypeTensorTrait<Dtype>::rtype RType;
+
+ public:
+  // constructor follows the QuantizeParam/DequantizeParam pattern above
+  PadParam(const VariableNameMap &inputs, const VariableNameMap &outputs,
+           const AttributeMap &attrs, const Scope &scope) {
+    input_ = InputXFrom<GType>(inputs, scope);
+    output_ = OutFrom<GType>(outputs, scope);
+    paddings_ = GetVarValue<std::vector<int>>("Paddings", inputs, scope);
+  }
+
+ public:
+  // op input
+  RType *input_;
+  // op output
+  RType *output_;
+  // paddings
+  std::vector<int> paddings_;
+};
+#endif
 }  // namespace operators
 }  // namespace paddle_mobile
@@ -22,8 +22,12 @@ namespace operators {
 template <typename DeviceType, typename T>
 void QuantizeOp<DeviceType, T>::InferShape() const {
-  const auto& input_dims = this->param_.input_->dims();
-  this->param_.out_->Resize(input_dims);
+  auto input_dims = this->param_.input_->dims();
+  // const auto &paddings = this->param_.paddings_;
+  std::vector<int> paddings = {0, 0};
+  input_dims[2] += 2 * paddings[0];
+  input_dims[3] += 2 * paddings[1];
+  this->param_.output_->Resize(input_dims);
   auto scale_dims = framework::make_ddim(std::vector<int>{1});
   this->param_.online_scale_->Resize(scale_dims);
 }
...