提交 8137d199 编写于 作者: qnqinan's avatar qnqinan

Merge remote-tracking branch 'upstream/develop' into develop

......@@ -5,6 +5,7 @@ option(DEBUGING "enable debug mode" ON)
option(USE_EXCEPTION "use std exception" ON)
option(SYMBOL_HIDDEN "symbol hidden" OFF) # on when use jni or ios io
option(LOG_PROFILE "log profile" OFF)
# select the platform to build
option(CPU "armv7 with neon" ON)
option(GPU_MALI "mali gpu" OFF)
......@@ -15,7 +16,6 @@ if(FPGA)
option(FPGAV2 "fpga v2" OFF)
endif()
project(paddle-mobile)
file(GLOB_RECURSE PADDLE_MOBILE_CC src/*.cc src/*.cpp src/*.c src/*.mm)
......@@ -247,6 +247,3 @@ elseif(FPGA)
add_subdirectory(test)
endif()
......@@ -71,6 +71,8 @@ const char *G_OP_TYPE_SUM = "sum";
const char *G_OP_TYPE_QUANTIZE = "quantize";
const char *G_OP_TYPE_DEQUANTIZE = "dequantize";
const char *G_OP_TYPE_FUSION_DEQUANT_ADD_BN_RELU = "fusion_dequant_add_bn_relu";
const char *G_OP_TYPE_TANH = "tanh";
const char *G_OP_TYPE_FUSION_DECONV_RELU = "fusion_deconv_relu";
const char *G_OP_TYPE_FUSION_DECONV_ADD = "fusion_deconv_add";
......@@ -134,6 +136,7 @@ std::unordered_map<
{G_OP_TYPE_ELEMENTWISE_MUL, {{"X", "Y"}, {"Out"}}},
{G_OP_TYPE_QUANTIZE, {{"X"}, {"Out", "OutScale"}}},
{G_OP_TYPE_DEQUANTIZE, {{"X", "Scale"}, {"Out"}}},
{G_OP_TYPE_FUSION_DEQUANT_ADD_BN_RELU, {{"X", "Scale"}, {"Out"}}},
{G_OP_TYPE_TANH, {{"X"}, {"Out"}}},
{G_OP_TYPE_FUSION_DECONV_RELU, {{"Input"}, {"Out"}}},
{G_OP_TYPE_FUSION_DECONV_ADD, {{"Input"}, {"Out"}}},
......
......@@ -138,6 +138,7 @@ extern const char *G_OP_TYPE_ELEMENTWISE_MUL;
extern const char *G_OP_TYPE_QUANTIZE;
extern const char *G_OP_TYPE_DEQUANTIZE;
extern const char *G_OP_TYPE_FUSION_DEQUANT_ADD_BN_RELU;
extern const char *G_OP_TYPE_TANH;
extern const char *G_OP_TYPE_FUSION_DECONV_RELU;
......
......@@ -30,7 +30,6 @@ limitations under the License. */
#ifdef PADDLE_EXECUTOR_MULTITHREAD
#include <queue>
#include <utility>
#include "common/threadpool.h"
#endif
......@@ -73,7 +72,7 @@ Executor<Dtype, P>::Executor(const framework::Program<Dtype> p, int batch_size,
op->Type(), op->GetInputs(), op->GetOutputs(), op->GetAttrMap(),
program_.scope);
// infer shape to reshape tensor before predict,
// but for lod tensor, it will need to reshape in runtime
// but for lod tensor, it will still need to reshape in runtime
if (!loddable_) {
op_base->InferShape();
}
......
......@@ -233,3 +233,7 @@ LOAD_OP1(quantize, CPU);
#ifdef DEQUANT_OP
LOAD_OP1(dequantize, CPU);
#endif
#ifdef FUSION_DEQUANT_ADD_BN_RELU_OP
LOAD_OP1(fusion_dequant_add_bn_relu, CPU);
LOAD_FUSION_MATCHER(fusion_dequant_add_bn_relu);
#endif
......@@ -127,11 +127,6 @@ class OperatorWithKernel : public OperatorBase<Dtype> {
virtual void InferShape() const = 0;
void Init() {
// for (auto i : this->inputs_) {
// DLOG << i.first;
// DLOG << i.second;
// }
PADDLE_MOBILE_ENFORCE(kernel_.Init(&param_), " %s kernel init failed",
this->type_.c_str());
}
......
......@@ -54,22 +54,6 @@ class Tensor : public TensorBase {
this->offset_ = inTensor.offset_;
}
#ifdef PADDLE_MOBILE_DEBUG
template <typename T>
inline void dump(std::string filename) const {
const T *dataptr = data<T>();
std::ofstream out(filename.c_str());
for (int i = 0; i < numel(); ++i) {
out << dataptr[i] << " ";
}
out << "形状:";
for (int j = 0; j < dims_.size(); ++j) {
out << dims_[j] << " ";
}
out.close();
}
#endif
/*! Resize the dimensions of the memory block. */
inline Tensor &Resize(const DDim &dims) {
dims_ = dims;
......
......@@ -22,7 +22,7 @@ namespace operators {
template <typename DeviceType, typename T>
void DequantizeOp<DeviceType, T>::InferShape() const {
const auto& input_dims = this->param_.input_->dims();
this->param_.out_->Resize(input_dims);
this->param_.output_->Resize(input_dims);
}
} // namespace operators
......
......@@ -12,33 +12,29 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "../test_include.h"
#include "operators/conv_op.h"
int main() {
paddle_mobile::framework::Loader<paddle_mobile::GPU_MALI> loader;
// ../models/image_classification_resnet.inference.model
auto program = loader.Load(g_googlenet);
PADDLE_MOBILE_ENFORCE(program.originProgram != nullptr,
"program file read fail");
Executor4Test<paddle_mobile::GPU_MALI, paddle_mobile::operators::ConvOp<
paddle_mobile::GPU_MALI, float>>
executor(program, "conv2d");
paddle_mobile::framework::Tensor input;
GetInput<float>(g_test_image_1x3x224x224, &input, {1, 3, 224, 224});
// // use SetupTensor if not has local input image .
// SetupTensor<float>(&input, {1, 3, 224, 224}, static_cast<float>(0),
// static_cast<float>(1));
auto out_ddim = paddle_mobile::framework::make_ddim({1, 64, 112, 112});
auto output = executor.Predict(input, "data", "conv2d_0.tmp_0", out_ddim);
auto output_ptr = output->data<float>();
for (int j = 0; j < 20; ++j) {
DLOG << " value of output: " << output_ptr[j];
}
return 0;
#ifdef FUSION_DEQUANT_ADD_BN_RELU_OP
#include "operators/fusion_dequant_add_bn_relu_op.h"
namespace paddle_mobile {
namespace operators {
template <typename Dtype, typename T>
void FusionDequantAddBNReluOp<Dtype, T>::InferShape() const {
const auto& input_dims = this->param_.input_->dims();
this->param_.output_->Resize(input_dims);
}
} // namespace operators
} // namespace paddle_mobile
namespace ops = paddle_mobile::operators;
REGISTER_FUSION_MATCHER(fusion_dequant_add_bn_relu,
ops::FusionDequantAddBNReluMatcher);
#ifdef PADDLE_MOBILE_CPU
REGISTER_OPERATOR_CPU(fusion_dequant_add_bn_relu,
ops::FusionDequantAddBNReluOp);
#endif
#endif
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifdef FUSION_DEQUANT_ADD_BN_RELU_OP
#pragma once
#include <string>
#include <vector>
#include "framework/operator.h"
#include "framework/program/program-optimize/fusion_op_register.h"
#include "operators/kernel/dequant_add_bn_relu_kernel.h"
#include "operators/op_param.h"
namespace paddle_mobile {
namespace operators {
class FusionDequantAddBNReluMatcher : public framework::FusionOpMatcher {
public:
FusionDequantAddBNReluMatcher() {
node_ = framework::Node(G_OP_TYPE_DEQUANTIZE);
node_ > std::make_shared<framework::Node>(G_OP_TYPE_ELEMENTWISE_ADD) >
std::make_shared<framework::Node>(G_OP_TYPE_BATCHNORM) >
std::make_shared<framework::Node>(G_OP_TYPE_RELU);
}
void FolderNodes(
framework::Node *node,
std::vector<std::shared_ptr<framework::Node>> *removed_nodes) {
node->Folder(node_.Depth(), Type(),
{{G_OP_TYPE_ELEMENTWISE_ADD, {{"Y", "Y"}}},
{G_OP_TYPE_BATCHNORM,
{{"Scale", "BNScale"},
{"Mean", "BNMean"},
{"Bias", "BNBias"},
{"Variance", "BNVariance"}}}},
removed_nodes);
}
std::string Type() { return G_OP_TYPE_FUSION_DEQUANT_ADD_BN_RELU; }
};
template <typename DeviceType, typename T>
class FusionDequantAddBNReluOp
: public framework::OperatorWithKernel<
DeviceType, FusionDequantAddBNReluParam<DeviceType>,
operators::FusionDequantAddBNReluKernel<DeviceType, T>> {
public:
FusionDequantAddBNReluOp(const std::string &type,
const VariableNameMap &inputs,
const VariableNameMap &outputs,
const framework::AttributeMap &attrs,
std::shared_ptr<framework::Scope> scope)
: framework::OperatorWithKernel<
DeviceType, FusionDequantAddBNReluParam<DeviceType>,
operators::FusionDequantAddBNReluKernel<DeviceType, T>>(
type, inputs, outputs, attrs, scope) {}
// inference output shape
void InferShape() const override;
};
} // namespace operators
} // namespace paddle_mobile
#endif
......@@ -22,12 +22,76 @@ namespace operators {
template <>
bool ConvKernel<CPU, float>::Init(ConvParam<CPU> *param) {
if (param->Filter()->type() == typeid(int8_t)) {
if (param->Groups() == param->Input()->dims()[1] &&
param->Input()->dims()[1] == param->Output()->dims()[1] &&
param->Filter()->dims()[2] == param->Filter()->dims()[3] &&
param->Filter()->dims()[2] == 3 && param->Strides()[0] < 3 &&
param->Strides()[0] == param->Strides()[1]) {
param->ExecMode() = ConvParam<CPU>::EXEC_DEPTHWISE3x3_INT8;
} else {
param->ExecMode() = ConvParam<CPU>::EXEC_GEMM_INT8;
}
} else {
if (param->Groups() == param->Input()->dims()[1] &&
param->Input()->dims()[1] == param->Output()->dims()[1] &&
param->Filter()->dims()[2] == param->Filter()->dims()[3] &&
param->Filter()->dims()[2] == 3 && param->Strides()[0] == 1) {
param->ExecMode() = ConvParam<CPU>::EXEC_DEPTHWISE3x3S1P1_FLOAT;
} else if (param->Groups() == param->Input()->dims()[1] &&
param->Input()->dims()[1] == param->Output()->dims()[1] &&
param->Filter()->dims()[2] == param->Filter()->dims()[3] &&
param->Filter()->dims()[2] == 3) {
param->ExecMode() = ConvParam<CPU>::EXEC_DEPTHWISE3x3_FLOAT;
#ifndef __aarch64__
} else if (param->Filter()->dims()[2] == param->Filter()->dims()[3] &&
param->Strides()[0] == param->Strides()[1] &&
param->Dilations()[0] == param->Dilations()[1] &&
param->Filter()->dims()[2] == 3 && param->Strides()[0] == 1 &&
param->Dilations()[0] == 1 && param->Output()->dims()[1] >= 16 &&
param->Input()->dims()[1] >= 16 &&
param->Input()->dims()[2] <= 140 /* refered from ncnn */) {
param->ExecMode() = ConvParam<CPU>::EXEC_WINOGRAD3X3_FLOAT;
// transform weight
framework::Tensor *transformed_weight = new framework::Tensor;
operators::math::winograd_transform_weight<8, 3>(*param->Filter(),
transformed_weight);
param->Filter() = transformed_weight;
#endif
} else {
param->ExecMode() = ConvParam<CPU>::EXEC_GEMM_FLOAT;
}
}
return true;
}
template <>
void ConvKernel<CPU, float>::Compute(const ConvParam<CPU> &param) {
ConvCompute<float>(param);
switch (param.ExecMode()) {
case ConvParam<CPU>::EXEC_GEMM_INT8:
GemmConv<int8_t, int32_t>(param);
break;
case ConvParam<CPU>::EXEC_DEPTHWISE3x3_INT8:
DepthwiseConv3x3<int8_t, int32_t>(param);
break;
case ConvParam<CPU>::EXEC_DEPTHWISE3x3S1P1_FLOAT:
math::DepthwiseConv3x3s1p1(param.Input(), param.Filter(), param.Output(),
nullptr, false);
break;
case ConvParam<CPU>::EXEC_DEPTHWISE3x3_FLOAT:
math::DepthwiseConv3x3(param.Input(), param.Strides(), param.Paddings(),
param.Filter(), nullptr, param.Output(), false);
break;
case ConvParam<CPU>::EXEC_WINOGRAD3X3_FLOAT:
WinogradConv3x3<8, 3>(param);
break;
case ConvParam<CPU>::EXEC_GEMM_FLOAT:
GemmConv<float, float>(param);
break;
default:
PADDLE_MOBILE_THROW_EXCEPTION("Invalid convolution execute mode %d",
param.ExecMode());
}
}
template class ConvKernel<CPU, float>;
......
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifdef FUSION_DEQUANT_ADD_BN_RELU_OP
#include "operators/kernel/dequant_add_bn_relu_kernel.h"
#include <cmath>
#if defined(__ARM_NEON__) || defined(__ARM_NEON)
#include <arm_neon.h>
#endif
namespace paddle_mobile {
namespace operators {
template <>
bool FusionDequantAddBNReluKernel<CPU, float>::Init(
FusionDequantAddBNReluParam<CPU> *param) {
// elementwise add params
const Tensor *bias = param->bias_;
// batch norm params
const Tensor *bn_mean = param->bn_mean_;
const Tensor *bn_variance = param->bn_variance_;
Tensor *bn_scale = param->bn_scale_;
Tensor *bn_bias = param->bn_bias_;
const float epsilon = param->epsilon_;
const float *bias_ptr = bias->data<float>();
const float *mean_ptr = bn_mean->data<float>();
const float *var_ptr = bn_variance->data<float>();
float *bn_scale_ptr = bn_scale->mutable_data<float>();
float *bn_bias_ptr = bn_bias->mutable_data<float>();
for (int c = 0; c < bn_scale->numel(); ++c) {
float inv_scale = bn_scale_ptr[c] / (std::sqrt(var_ptr[c] + epsilon));
bn_scale_ptr[c] = inv_scale;
bn_bias_ptr[c] = inv_scale * (bias_ptr[c] - mean_ptr[c]) + bn_bias_ptr[c];
}
return true;
}
template <>
void FusionDequantAddBNReluKernel<CPU, float>::Compute(
const FusionDequantAddBNReluParam<CPU> &param) {
const int32_t *input = param.input_->data<int32_t>();
const float *bn_scale = param.bn_scale_->data<float>();
const float *bn_bias = param.bn_bias_->data<float>();
// dequantize params
const float activation_scale = param.activation_scale_->data<float>()[0];
const float weight_scale = param.weight_scale_;
const float dequant_scale = activation_scale / weight_scale;
float *output = param.output_->mutable_data<float>();
int batch_size = param.input_->dims()[0];
int channels = param.input_->dims()[1];
size_t spatial_size = param.input_->dims()[2] * param.input_->dims()[3];
#pragma omp parallel for collapse(2)
for (int batch = 0; batch < batch_size; ++batch) {
for (int c = 0; c < channels; ++c) {
float scale = bn_scale[c] * dequant_scale;
float bias = bn_bias[c];
size_t offset = (batch * channels + c) * spatial_size;
const int32_t *x = input + offset;
float *y = output + offset;
size_t remain = spatial_size;
#if defined(__ARM_NEON__) || defined(__ARM_NEON)
int loop = spatial_size >> 4;
remain = spatial_size & 0xF;
float32x4_t __scale = vdupq_n_f32(scale);
float32x4_t __bias = vdupq_n_f32(bias);
float32x4_t __zero = vdupq_n_f32(0.f);
for (int k = 0; k < loop; ++k, x += 16, y += 16) {
int32x4_t r0 = vld1q_s32(x);
int32x4_t r1 = vld1q_s32(x + 4);
int32x4_t r2 = vld1q_s32(x + 8);
int32x4_t r3 = vld1q_s32(x + 12);
float32x4_t f0 = vcvtq_f32_s32(r0);
float32x4_t f1 = vcvtq_f32_s32(r1);
float32x4_t f2 = vcvtq_f32_s32(r2);
float32x4_t f3 = vcvtq_f32_s32(r3);
f0 = vmlaq_f32(__bias, __scale, f0);
f1 = vmlaq_f32(__bias, __scale, f1);
f2 = vmlaq_f32(__bias, __scale, f2);
f3 = vmlaq_f32(__bias, __scale, f3);
f0 = vmaxq_f32(__zero, f0);
f1 = vmaxq_f32(__zero, f1);
f2 = vmaxq_f32(__zero, f2);
f3 = vmaxq_f32(__zero, f3);
vst1q_f32(y, f0);
vst1q_f32(y + 4, f1);
vst1q_f32(y + 8, f2);
vst1q_f32(y + 12, f3);
}
#endif // __ARM_NEON__
for (int k = 0; k < remain; ++k) {
y[k] = std::max(scale * x[k] + bias, 0.f);
}
}
}
}
} // namespace operators
} // namespace paddle_mobile
#endif // FUSION_DEQUANT_ADD_BN_RELU_OP
......@@ -31,7 +31,7 @@ bool DequantizeKernel<CPU, float>::Init(DequantizeParam<CPU> *param) {
template <>
void DequantizeKernel<CPU, float>::Compute(const DequantizeParam<CPU> &param) {
const Tensor *input = param.input_;
Tensor *output = param.out_;
Tensor *output = param.output_;
float activation_scale = param.activation_scale_->data<float>()[0];
float weight_scale = param.weight_scale_;
const int32_t *x = input->data<const int32_t>();
......@@ -43,11 +43,15 @@ void DequantizeKernel<CPU, float>::Compute(const DequantizeParam<CPU> &param) {
size_t loop = size >> 4;
size_t remain = size & 0xF;
float32x4_t s = vdupq_n_f32(scale);
#pragma omp parallel for
for (size_t i = 0; i < loop; ++i) {
int32x4_t r0 = vld1q_s32(x);
int32x4_t r1 = vld1q_s32(x + 4);
int32x4_t r2 = vld1q_s32(x + 8);
int32x4_t r3 = vld1q_s32(x + 12);
const int32_t *local_x = x + (i << 4);
float *local_y = y + (i << 4);
int32x4_t r0 = vld1q_s32(local_x);
int32x4_t r1 = vld1q_s32(local_x + 4);
int32x4_t r2 = vld1q_s32(local_x + 8);
int32x4_t r3 = vld1q_s32(local_x + 12);
float32x4_t f0 = vcvtq_f32_s32(r0);
float32x4_t f1 = vcvtq_f32_s32(r1);
float32x4_t f2 = vcvtq_f32_s32(r2);
......@@ -56,14 +60,14 @@ void DequantizeKernel<CPU, float>::Compute(const DequantizeParam<CPU> &param) {
f1 = vmulq_f32(f1, s);
f2 = vmulq_f32(f2, s);
f3 = vmulq_f32(f3, s);
vst1q_f32(y, f0);
vst1q_f32(y + 4, f1);
vst1q_f32(y + 8, f2);
vst1q_f32(y + 12, f3);
x += 16;
y += 16;
vst1q_f32(local_y, f0);
vst1q_f32(local_y + 4, f1);
vst1q_f32(local_y + 8, f2);
vst1q_f32(local_y + 12, f3);
}
size = remain;
x += (loop << 4);
y += (loop << 4);
#endif
for (size_t i = 0; i < size; ++i) {
y[i] = x[i] * scale;
......
......@@ -17,7 +17,7 @@ limitations under the License. */
#include <vector>
#include "operators/math/conv_func.h"
#include "operators/math/depthwise_conv_3x3.h"
#include "operators/math/depthwise_conv3x3.h"
#include "operators/math/im2col.h"
#include "operators/math/math_function.h"
#include "operators/math/vol2col.h"
......
......@@ -17,7 +17,7 @@ limitations under the License. */
#pragma once
#include <vector>
#include "operators/math/depthwise_conv_3x3.h"
#include "operators/math/depthwise_conv3x3.h"
#include "operators/math/im2col.h"
#include "operators/math/math_function.h"
#include "operators/math/vol2col.h"
......
......@@ -17,18 +17,19 @@ limitations under the License. */
#pragma once
#include <vector>
#include "operators/math/conv_func.h"
#include "operators/math/depthwise_conv_3x3.h"
#include "operators/math/depthwise_conv3x3.h"
#include "operators/math/im2col.h"
#include "operators/math/math_function.h"
#include "operators/math/pad.h"
#include "operators/math/vol2col.h"
#include "operators/math/winograd/winograd_transform.h"
#include "operators/op_param.h"
namespace paddle_mobile {
namespace operators {
template <typename Itype, typename Otype>
inline void ConvBasic(const ConvParam<CPU> &param) {
inline void GemmConv(const ConvParam<CPU> &param) {
const Tensor *input = param.Input();
Tensor filter = *param.Filter();
Tensor *output = param.Output();
......@@ -38,10 +39,7 @@ inline void ConvBasic(const ConvParam<CPU> &param) {
const std::vector<int> paddings = param.Paddings();
const std::vector<int> dilations = param.Dilations();
const int batch_size = static_cast<int>(input->dims()[0]);
std::vector<int64_t> filter_shape_vec(framework::vectorize(filter.dims()));
std::vector<int64_t> output_shape_vec(framework::vectorize(output->dims()));
size_t data_dim = filter_shape_vec.size() - 2;
std::vector<int64_t> col_shape_vec(1 + 2 * data_dim);
......@@ -82,6 +80,7 @@ inline void ConvBasic(const ConvParam<CPU> &param) {
math::Vol2ColFunctor<CPU, Itype> vol2col;
math::Im2ColFunctor<math::ColFormat::kCFO, CPU, Itype> im2col;
const int batch_size = static_cast<int>(input->dims()[0]);
for (int i = 0; i < batch_size; i++) {
Tensor in_batch = input->Slice(i, i + 1).Resize(input_shape);
Tensor out_batch = output->Slice(i, i + 1).Resize(output_matrix_shape);
......@@ -99,7 +98,6 @@ inline void ConvBasic(const ConvParam<CPU> &param) {
std::vector<int>{paddings[0], paddings[1], paddings[0],
paddings[1]},
&col);
} else if (data_dim == 3U) {
// vol2col
vol2col(in_slice, dilations, strides, paddings, &col);
......@@ -116,25 +114,86 @@ inline void ConvBasic(const ConvParam<CPU> &param) {
}
}
template <typename P>
void ConvCompute(const ConvParam<CPU> &param) {
if (param.Input()->type() == typeid(int8_t)) {
ConvBasic<int8_t, int32_t>(param);
} else {
if (param.Groups() == param.Input()->dims()[1] &&
param.Input()->dims()[1] == param.Output()->dims()[1] &&
param.Filter()->dims()[2] == param.Filter()->dims()[3] &&
param.Filter()->dims()[2] == 3 && param.Strides()[0] == 1) {
math::DepthwiseConv3x3s1p1(param.Input(), param.Filter(), param.Output(),
nullptr, false);
} else if (param.Groups() == param.Input()->dims()[1] &&
param.Input()->dims()[1] == param.Output()->dims()[1] &&
param.Filter()->dims()[2] == param.Filter()->dims()[3] &&
param.Filter()->dims()[2] == 3) {
math::DepthwiseConv3x3(param.Input(), param.Strides(), param.Paddings(),
param.Filter(), nullptr, param.Output(), false);
template <int tile, int kernel>
inline void WinogradConv3x3(const ConvParam<CPU> &param) {
const Tensor *input = param.Input();
const Tensor *filter = param.Filter();
Tensor *output = param.Output();
output->mutable_data<float>();
int batch_size = input->dims()[0];
int groups = param.Groups();
const std::vector<int> &paddings = param.Paddings();
auto winograd_pad = [&](int width, int pad) {
int output_tile = tile - kernel + 1;
// int tiles = (width + pad - kernel) / output_tile + 1;
// return (tiles - 1) * output_tile + tile - width;
int pad_width = (width + 2 * pad - kernel) / output_tile * output_tile;
return pad_width + tile - width;
};
math::PadFunctor<CPU, float> pad;
Tensor input_pad;
framework::Tensor transformed_input;
for (int i = 0; i < batch_size; ++i) {
Tensor in_batch = input->Slice(i, i + 1);
Tensor out_batch = output->Slice(i, i + 1);
// int pad_bottom = winograd_pad(in_batch.dims()[2], paddings[0]);
// int pad_right = winograd_pad(in_batch.dims()[3], paddings[1]);
int pad_bottom = paddings[0];
int pad_right = paddings[1];
if (paddings[0] || paddings[1] || pad_bottom || pad_right) {
framework::DDim pad_shape = in_batch.dims();
pad_shape[2] += paddings[0] + pad_bottom;
pad_shape[3] += paddings[1] + pad_right;
input_pad.mutable_data<float>(pad_shape);
pad(in_batch, paddings[0], pad_bottom, paddings[1], pad_right,
&input_pad);
} else {
input_pad = in_batch;
}
// tile input and transform
math::winograd_transform_input<tile, kernel>(input_pad, &transformed_input);
// caculate output
math::winograd_transform_output<tile, kernel>(transformed_input, *filter,
output);
}
}
template <typename Itype, typename Otype>
inline void DepthwiseConv3x3(const ConvParam<CPU> &param) {
const Tensor *input = param.Input();
const Tensor *filter = param.Filter();
Tensor *output = param.Output();
output->mutable_data<Otype>();
const std::vector<int> &paddings = param.Paddings();
const std::vector<int> &strides = param.Strides();
const int batch_size = static_cast<int>(input->dims()[0]);
Tensor input_pad;
math::PadFunctor<CPU, Itype> pad;
for (int i = 0; i < batch_size; i++) {
Tensor in_batch = input->Slice(i, i + 1);
Tensor out_batch = output->Slice(i, i + 1);
if (paddings[0] || paddings[1]) {
framework::DDim pad_shape = in_batch.dims();
pad_shape[2] += 2 * paddings[0];
pad_shape[3] += 2 * paddings[1];
input_pad.mutable_data<float>(pad_shape);
pad(in_batch, paddings[0], paddings[0], paddings[1], paddings[1],
&input_pad);
} else {
input_pad = in_batch;
}
if (strides[0] == 1) {
math::DepthwiseConv3x3s1<Itype, Otype>(input_pad, *filter, &out_batch);
} else if (strides[0] == 2) {
math::DepthwiseConv3x3s2<Itype, Otype>(input_pad, *filter, &out_batch);
} else {
ConvBasic<float, float>(param);
// math::DepthwiseConv3x3<Itype, Otype>(input_pad, *filter,
// &out_batch);
PADDLE_MOBILE_THROW_EXCEPTION(
"Depthwise conv with generic strides has not been implemented.");
}
}
}
......
......@@ -17,7 +17,7 @@ limitations under the License. */
#pragma once
#include <vector>
#include "operators/math/depthwise_conv_3x3.h"
#include "operators/math/depthwise_conv3x3.h"
#include "operators/math/im2col.h"
#include "operators/math/math_function.h"
#include "operators/math/vol2col.h"
......
......@@ -16,13 +16,15 @@ limitations under the License. */
#pragma once
#include <vector>
#include "operators/math/depthwise_conv_3x3.h"
#include "operators/math/depthwise_conv3x3.h"
#include "operators/math/im2col.h"
#include "operators/math/math_function.h"
#include "operators/math/vol2col.h"
#include "operators/op_param.h"
namespace paddle_mobile {
namespace operators {
void ConvBNReluBasic(const FusionConvBNReluParam<CPU> &param) {
const Tensor *input = param.Input();
Tensor filter = *param.Filter();
......
......@@ -15,10 +15,9 @@ limitations under the License. */
#ifdef DEPTHWISECONV_OP
#pragma once
#include <operators/math/depthwise_conv_3x3.h>
#include <vector>
#include "operators/kernel/central-arm-func/conv_arm_func.h"
#include "operators/math/depthwise_conv3x3.h"
#include "operators/op_param.h"
namespace paddle_mobile {
......@@ -44,7 +43,7 @@ void DepthwiseConvCompute(const ConvParam<CPU> &param) {
Bias, false);
} else {
ConvBasic<float, float>(param);
GemmConv<float, float>(param);
}
}
......
......@@ -16,13 +16,15 @@ limitations under the License. */
#pragma once
#include <vector>
#include "operators/math/depthwise_conv_3x3.h"
#include "operators/math/depthwise_conv3x3.h"
#include "operators/math/im2col.h"
#include "operators/math/math_function.h"
#include "operators/math/vol2col.h"
#include "operators/op_param.h"
namespace paddle_mobile {
namespace operators {
void DWConvBNReluBasic(const FusionDWConvBNReluParam<CPU> &param) {
const Tensor *input = param.Input();
Tensor filter = *param.Filter();
......
......@@ -24,7 +24,7 @@ limitations under the License. */
#include "framework/ddim.h"
#include "framework/operator.h"
#include "operators/math/conv_func.h"
#include "operators/math/depthwise_conv_3x3.h"
#include "operators/math/depthwise_conv3x3.h"
#include "operators/math/im2col.h"
#include "operators/math/math_function.h"
#include "operators/math/vol2col.h"
......
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#ifdef FUSION_DEQUANT_ADD_BN_RELU_OP
#include "framework/operator.h"
#include "operators/op_param.h"
namespace paddle_mobile {
namespace operators {
template <typename DeviceType, typename T>
class FusionDequantAddBNReluKernel
: public framework::OpKernelBase<DeviceType,
FusionDequantAddBNReluParam<DeviceType>> {
public:
void Compute(const FusionDequantAddBNReluParam<DeviceType> &param);
bool Init(FusionDequantAddBNReluParam<DeviceType> *param);
};
} // namespace operators
} // namespace paddle_mobile
#endif
......@@ -11,18 +11,22 @@ distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "operators/math/depthwise_conv_3x3.h"
#include "operators/math/depthwise_conv3x3.h"
#include <vector>
#if __ARM_NEON
#include <arm_neon.h>
#endif
#include <vector>
namespace paddle_mobile {
namespace operators {
namespace math {
void DepthwiseConv3x3(const Tensor *input, vector<int> strides,
vector<int> paddings, const Tensor *filter, Tensor *bias,
Tensor *output, bool if_bias) {
void DepthwiseConv3x3(const framework::Tensor *input,
const std::vector<int> &strides,
const std::vector<int> &paddings,
const framework::Tensor *filter, framework::Tensor *bias,
framework::Tensor *output, bool if_bias) {
const int batch_size = input->dims()[0];
const int input_height = input->dims()[2];
......@@ -67,12 +71,12 @@ void DepthwiseConv3x3(const Tensor *input, vector<int> strides,
for (int pw = 0; pw < output_width; pw++) {
hstart = ph * stride_height - padding_height;
wstart = pw * stride_width - padding_width;
hend = min(hstart + _kernel_size, input_height + padding_height);
wend = min(wstart + _kernel_size, input_width + padding_width);
hstart = max(hstart, 0);
wstart = max(wstart, 0);
hend = min(hend, input_height);
wend = min(wend, input_width);
hend = std::min(hstart + _kernel_size, input_height + padding_height);
wend = std::min(wstart + _kernel_size, input_width + padding_width);
hstart = std::max(hstart, 0);
wstart = std::max(wstart, 0);
hend = std::min(hend, input_height);
wend = std::min(wend, input_width);
pos1 = input_data + hstart * input_width + wstart;
pos2 = input_data + (hstart + 1) * input_width + wstart;
pos3 = input_data + (hstart + 2) * input_width + wstart;
......@@ -244,12 +248,14 @@ void DepthwiseConv3x3(const Tensor *input, vector<int> strides,
}
}
void DepthwiseConv3x3s1p1(const Tensor *input, const Tensor *filter,
Tensor *output, Tensor *bias, bool if_bias) {
void DepthwiseConv3x3s1p1(const framework::Tensor *input,
const framework::Tensor *filter,
framework::Tensor *output, framework::Tensor *bias,
bool if_bias) {
#if __ARM_NEON
const float *input_data = input->data<float>();
const float *filter_data = filter->data<float>();
float *output_data = output->data<float>();
float *output_data = output->mutable_data<float>();
const float *bias_data;
if (if_bias) {
bias_data = bias->data<float>();
......@@ -517,9 +523,12 @@ void DepthwiseConv3x3s1p1(const Tensor *input, const Tensor *filter,
#endif
}
void DepthwiseConvAddBNRelu3x3s1p1(const Tensor *input, const Tensor *filter,
Tensor *output, const Tensor *new_scale,
const Tensor *new_bias, bool if_relu) {
void DepthwiseConvAddBNRelu3x3s1p1(const framework::Tensor *input,
const framework::Tensor *filter,
framework::Tensor *output,
const framework::Tensor *new_scale,
const framework::Tensor *new_bias,
bool if_relu) {
#if __ARM_NEON
const float *input_data = input->data<float>();
const float *filter_data = filter->data<float>();
......@@ -1059,9 +1068,12 @@ void DepthwiseConvAddBNRelu3x3s1p1(const Tensor *input, const Tensor *filter,
}
/// w!=h not fix
void DepthwiseConvAddBNRelu3x3s2p1(const Tensor *input, const Tensor *filter,
Tensor *output, const Tensor *new_scale,
const Tensor *new_bias, bool if_relu) {
void DepthwiseConvAddBNRelu3x3s2p1(const framework::Tensor *input,
const framework::Tensor *filter,
framework::Tensor *output,
const framework::Tensor *new_scale,
const framework::Tensor *new_bias,
bool if_relu) {
#if __ARM_NEON
const int batch_size = input->dims()[0];
......@@ -1107,12 +1119,12 @@ void DepthwiseConvAddBNRelu3x3s2p1(const Tensor *input, const Tensor *filter,
for (int pw = 0; pw < output_width; pw++) {
hstart = ph * stride_height - padding_height;
wstart = pw * stride_width - padding_width;
hend = min(hstart + _kernel_size, input_height + padding_height);
wend = min(wstart + _kernel_size, input_width + padding_width);
hstart = max(hstart, 0);
wstart = max(wstart, 0);
hend = min(hend, input_height);
wend = min(wend, input_width);
hend = std::min(hstart + _kernel_size, input_height + padding_height);
wend = std::min(wstart + _kernel_size, input_width + padding_width);
hstart = std::max(hstart, 0);
wstart = std::max(wstart, 0);
hend = std::min(hend, input_height);
wend = std::min(wend, input_width);
pos1 = input_data + hstart * input_width + wstart;
pos2 = input_data + (hstart + 1) * input_width + wstart;
pos3 = input_data + (hstart + 2) * input_width + wstart;
......@@ -1258,8 +1270,10 @@ void DepthwiseConvAddBNRelu3x3s2p1(const Tensor *input, const Tensor *filter,
#endif
}
void DepthwiseConv3x3s2p1v2(const Tensor *input, const Tensor *filter,
Tensor *output, Tensor bias, bool if_bias) {
void DepthwiseConv3x3s2p1v2(const framework::Tensor *input,
const framework::Tensor *filter,
framework::Tensor *output, framework::Tensor bias,
bool if_bias) {
#if __ARM_NEON
const float *input_data = input->data<float>();
const float *filter_data = filter->data<float>();
......@@ -1463,9 +1477,12 @@ void DepthwiseConv3x3s2p1v2(const Tensor *input, const Tensor *filter,
#endif
}
void DepthwiseConvAddBNRelu3x3s2p1v2(const Tensor *input, const Tensor *filter,
Tensor *output, const Tensor *new_scale,
const Tensor *new_bias, bool if_relu) {
void DepthwiseConvAddBNRelu3x3s2p1v2(const framework::Tensor *input,
const framework::Tensor *filter,
framework::Tensor *output,
const framework::Tensor *new_scale,
const framework::Tensor *new_bias,
bool if_relu) {
#if __ARM_NEON
// #ifdef _OPENMP
// const float *newscale_data = new_scale->data<float>();
......@@ -1886,8 +1903,10 @@ void DepthwiseConvAddBNRelu3x3s2p1v2(const Tensor *input, const Tensor *filter,
#endif
}
void DepthwiseConv3x3s2p0(const Tensor *input, const Tensor *filter,
Tensor *output, Tensor bias, bool if_bias) {
void DepthwiseConv3x3s2p0(const framework::Tensor *input,
const framework::Tensor *filter,
framework::Tensor *output, framework::Tensor bias,
bool if_bias) {
#if __ARM_NEON
const int batch_size = static_cast<int>(input->dims()[0]);
......
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <algorithm>
#include <vector>
#include "framework/tensor.h"
#include "operators/math/conv_func.h"
namespace paddle_mobile {
namespace operators {
namespace math {
void DepthwiseConv3x3(const framework::Tensor *input,
const std::vector<int> &strides,
const std::vector<int> &paddings,
const framework::Tensor *filter, framework::Tensor *bias,
framework::Tensor *output, bool if_bias);
void DepthwiseConv3x3s1p1(const framework::Tensor *input,
const framework::Tensor *filter,
framework::Tensor *output, framework::Tensor *bias,
bool if_bias);
void DepthwiseConvAddBNRelu3x3s1p1(const framework::Tensor *input,
const framework::Tensor *filter,
framework::Tensor *output,
const framework::Tensor *new_scale,
const framework::Tensor *new_bias,
bool if_relu);
void DepthwiseConvAddBNRelu3x3s2p1(const framework::Tensor *input,
const framework::Tensor *filter,
framework::Tensor *output,
const framework::Tensor *new_scale,
const framework::Tensor *new_bias,
bool if_relu);
void DepthwiseConv3x3s2p1v2(const framework::Tensor *input,
const framework::Tensor *filter,
framework::Tensor *output, framework::Tensor bias,
bool if_bias);
void DepthwiseConvAddBNRelu3x3s2p1v2(const framework::Tensor *input,
const framework::Tensor *filter,
framework::Tensor *output,
const framework::Tensor *new_scale,
const framework::Tensor *new_bias,
bool if_relu);
void DepthwiseConv3x3s2p0(const framework::Tensor *input,
const framework::Tensor *filter,
framework::Tensor *output, framework::Tensor bias,
bool if_bias);
// TODO(hjchen2) need to be implemented
// template<typename Itype, typename Otype>
// void DepthwiseConv3x3(const framework::Tensor *input,
// const framework::Tensor *filter,
// const std::vector<int> &strides,
// framework::Tensor *output);
template <typename Itype, typename Otype>
void DepthwiseConv3x3s1(const framework::Tensor &input,
const framework::Tensor &filter,
framework::Tensor *output);
template <typename Itype, typename Otype>
void DepthwiseConv3x3s2(const framework::Tensor &input,
const framework::Tensor &filter,
framework::Tensor *output);
} // namespace math
} // namespace operators
} // namespace paddle_mobile
此差异已折叠。
......@@ -26,79 +26,6 @@ limitations under the License. */
namespace paddle_mobile {
namespace operators {
namespace math {
/*int MC = 0;
int KC = 0;
int NC = 0;
float *packedA;
float *packedB;
float *packedC;
float *zero;
typedef void (*FnPack)(int, int, int, const float *, int, float *);
typedef void (*FnAddDot)(int, const float *, const float *, float *, int);
FnPack procPackA;
FnPack procPackB;
FnAddDot procAddDot;*/
/*
// 将A矩阵分块复制到连续内存(ColMajor)
void PackMatrixA(int m, int k, int m_tail, const float *A, int lda,
float *buffer) {
int i, j;
const float *Aij;
for (i = 0; i < m - m_tail; i += MR) {
for (j = 0; j < k; ++j) {
Aij = &A(i, j);
*buffer++ = *Aij;
*buffer++ = *(Aij + 1);
*buffer++ = *(Aij + 2);
*buffer++ = *(Aij + 3);
}
}
if (m_tail != 0) {
for (j = 0; j < k; ++j) {
Aij = &A(m - m_tail, j);
for (i = 0; i < m_tail; ++i) {
*buffer++ = *(Aij + i);
}
for (i = m_tail; i < MR; ++i) {
*buffer++ = 0;
}
}
}
}
// 将B矩阵分块复制到连续内存(ColMajor)
void PackMatrixB(int k, int n, int n_tail, const float *B, int ldb,
float *buffer) {
int i, j;
const float *Bj, *Bj1, *Bj2, *Bj3;
for (j = 0; j < n - n_tail; j += NR) {
Bj = &B(0, j);
Bj1 = &B(0, j + 1);
Bj2 = &B(0, j + 2);
Bj3 = &B(0, j + 3);
for (i = 0; i < k; ++i) {
*buffer++ = *Bj++;
*buffer++ = *Bj1++;
*buffer++ = *Bj2++;
*buffer++ = *Bj3++;
}
}
if (n_tail != 0) {
for (i = 0; i < k; ++i) {
for (int j = n - n_tail; j < n; ++j) {
*buffer++ = B(i, j);
}
for (int j = n; j < n + (NR - n_tail); ++j) {
*buffer++ = 0;
}
}
}
}
*/
// 将A矩阵分块复制到连续内存(RowMajor)
void Gemm::PackMatrixA_4r(int m, int k, int m_tail, const float *A, int lda,
......
......@@ -22,6 +22,70 @@ namespace paddle_mobile {
namespace operators {
namespace math {
void ExtractToImg(const float *im_data, float *col_data, const int im_height,
const int im_width, const int col_height, const int col_width,
const int padding_h, const int padding_w, const int stride_h,
const int stride_w, const int kh, const int kw) {
int h = padding_h - kh;
int w = padding_w - kw;
int col_start_height = h > 0 ? (h + stride_h - 1) / stride_h : 0;
int col_start_width = w > 0 ? (w + stride_w - 1) / stride_w : 0;
int start_height = kh + col_start_height * stride_h - padding_h;
int start_width = kw + col_start_width * stride_w - padding_w;
int end_height = (col_height - col_start_height) * stride_h + start_height;
end_height = end_height > im_height ? im_height : end_height;
int end_width = (col_width - col_start_width) * stride_w + start_width;
end_width = end_width > im_width ? im_width : end_width;
int extract = (end_width - start_width + stride_w - 1) / stride_w;
im_data += start_height * im_width + start_width;
col_data += col_start_height * col_width + col_start_width;
for (int i = start_height; i < end_height; i += stride_h) {
if (stride_w == 1) {
memcpy(col_data, im_data, extract * sizeof(float));
} else if (stride_w == 2) {
int s = 0;
#if __ARM_NEON
for (; s < extract - 3; s += 4) {
float32x4x2_t img = vld2q_f32(im_data + s * 2);
vst1q_f32(col_data + s, img.val[0]);
}
#endif
for (; s < extract; ++s) {
col_data[s] = im_data[s * 2];
}
} else if (stride_w == 3) {
int s = 0;
#if __ARM_NEON
for (; s < extract - 3; s += 4) {
float32x4x3_t img = vld3q_f32(im_data + s * 3);
vst1q_f32(col_data + s, img.val[0]);
}
#endif
for (; s < extract; ++s) {
col_data[s] = im_data[s * 3];
}
} else if (stride_w == 4) {
int s = 0;
#if __ARM_NEON
for (; s < extract - 3; s += 4) {
float32x4x4_t img = vld4q_f32(im_data + s * 4);
vst1q_f32(col_data + s, img.val[0]);
}
#endif
for (; s < extract; ++s) {
col_data[s] = im_data[s * 4];
}
} else {
PADDLE_MOBILE_THROW_EXCEPTION("stride_w must be one of 1, 2, 3 and 4.");
}
im_data += im_width * stride_h;
col_data += col_width;
}
}
/*
* im = [input_channels, input_height, input_width]
* col =
......@@ -363,7 +427,27 @@ void Im2ColFunctor<ColFormat::kCFO, CPU, float>::operator()(
col_data += 9 * oosize;
im_data += isize * isize;
}
} else if (stride[0] <= 4 && dilation[0] == 1 && dilation[0] == dilation[1]) {
int im_spatial_size = im_height * im_width;
int col_spatial_size = col_height * col_width;
// pad 0
memset(col_data, 0, col->numel() * sizeof(float));
#pragma omp parallel for
for (int ic = 0; ic < im_channels; ++ic) {
const float *local_im_data = im_data + ic * im_spatial_size;
float *local_col_data =
col_data + ic * filter_height * filter_width * col_spatial_size;
for (int kh = 0; kh < filter_height; ++kh) {
for (int kw = 0; kw < filter_width; ++kw) {
ExtractToImg(local_im_data, local_col_data, im_height, im_width,
col_height, col_width, padding[0], padding[1], stride[0],
stride[1], kh, kw);
local_col_data += col_spatial_size;
}
}
}
} else {
#endif
for (int c = 0; c < channels_col; ++c) {
int w_offset = c % filter_width;
int h_offset = (c / filter_width) % filter_height;
......@@ -382,25 +466,7 @@ void Im2ColFunctor<ColFormat::kCFO, CPU, float>::operator()(
}
}
}
}
#else
for (int c = 0; c < channels_col; ++c) {
int w_offset = c % filter_width;
int h_offset = (c / filter_width) % filter_height;
int c_im = c / (filter_width * filter_height);
for (int h = 0; h < col_height; ++h) {
int im_row_idx = h * stride[0] - padding[0] + h_offset * dilation[0];
for (int w = 0; w < col_width; ++w) {
int im_col_idx = w * stride[1] - padding[1] + w_offset * dilation[1];
int col_idx = (c * col_height + h) * col_width + w;
int im_idx = (im_row_idx + c_im * im_height) * im_width + im_col_idx;
col_data[col_idx] = (im_row_idx < 0 || im_row_idx >= im_height ||
im_col_idx < 0 || im_col_idx >= im_width)
? static_cast<float>(0)
: im_data[im_idx];
}
}
#if __ARM_NEON
}
#endif
}
......@@ -489,21 +555,26 @@ void Im2ColFunctor<ColFormat::kCFO, CPU, int8_t>::operator()(
int channels_col = im_channels * filter_height * filter_width;
const int8_t *im_data = im.data<int8_t>();
int8_t *col_data = col->data<int8_t>();
int8_t *col_data = col->mutable_data<int8_t>();
#if defined(__ARM_NEON__) || defined(__ARM_NEON)
if (stride[0] <= 4 && dilation[0] == 1 && dilation[0] == dilation[1]) {
int im_spatial_size = im_height * im_width;
int col_spatial_size = col_height * col_width;
// pad 0
memset(col_data, 0, col->numel() * sizeof(int8_t));
#pragma omp parallel for
for (int ic = 0; ic < im_channels; ++ic) {
const int8_t *local_im_data = im_data + ic * im_spatial_size;
int8_t *local_col_data =
col_data + ic * filter_height * filter_width * col_spatial_size;
for (int kh = 0; kh < filter_height; ++kh) {
for (int kw = 0; kw < filter_width; ++kw) {
ExtractToImg(im_data, col_data, im_height, im_width, col_height,
col_width, padding[0], padding[1], stride[0], stride[1],
kh, kw);
col_data += col_height * col_width;
ExtractToImg(local_im_data, local_col_data, im_height, im_width,
col_height, col_width, padding[0], padding[1], stride[0],
stride[1], kh, kw);
local_col_data += col_spatial_size;
}
}
im_data += im_height * im_width;
}
} else {
#endif
......
......@@ -21,10 +21,12 @@ namespace math {
template <typename T>
class PadFunctor<CPU, T> {
public:
void operator()(const framework::Tensor &input, const int pad_h,
const int pad_w, framework::Tensor *output) {
void operator()(const framework::Tensor &input, const int pad_top,
const int pad_bottom, const int pad_left, const int pad_right,
framework::Tensor *output) {
const T *in_data = input.data<T>();
T *out_data = output->mutable_data<T>();
// should check output shape is valid for such pad parameters
const framework::DDim &input_shape = input.dims();
const framework::DDim &output_shape = output->dims();
// fill output with 0
......@@ -32,13 +34,13 @@ class PadFunctor<CPU, T> {
// should make sure the shape of output is match with input
for (int i = 0; i < input_shape[0]; ++i) {
for (int c = 0; c < input_shape[1]; ++c) {
out_data += pad_h * output_shape[3];
out_data += pad_top * output_shape[3];
for (int h = 0; h < input_shape[2]; ++h) {
memcpy(out_data + pad_w, in_data, sizeof(T) * input_shape[3]);
memcpy(out_data + pad_left, in_data, sizeof(T) * input_shape[3]);
out_data += output_shape[3];
in_data += input_shape[3];
}
out_data += pad_h * output_shape[3];
out_data += pad_bottom * output_shape[3];
}
}
}
......
......@@ -22,8 +22,9 @@ namespace math {
template <typename DeviceType, typename T>
class PadFunctor {
public:
void operator()(const framework::Tensor &input, const int pad_h,
const int pad_w, framework::Tensor *output);
void operator()(const framework::Tensor &input, const int pad_top,
const int pad_bottom, const int pad_left, const int pad_right,
framework::Tensor *output);
};
} // namespace math
......
......@@ -12,40 +12,31 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifdef CONV_OP
#pragma once
#include <algorithm>
#include <vector>
#include "framework/tensor.h"
#include "operators/math/conv_func.h"
namespace paddle_mobile {
namespace operators {
namespace math {
using framework::Tensor;
using std::max;
using std::min;
using std::vector;
void DepthwiseConv3x3(const Tensor *input, vector<int> strides,
vector<int> paddings, const Tensor *filter, Tensor *bias,
Tensor *output, bool if_bias);
void DepthwiseConv3x3s1p1(const Tensor *input, const Tensor *filter,
Tensor *output, Tensor *bias, bool if_bias);
void DepthwiseConvAddBNRelu3x3s1p1(const Tensor *input, const Tensor *filter,
Tensor *output, const Tensor *new_scale,
const Tensor *new_bias, bool if_relu);
void DepthwiseConvAddBNRelu3x3s2p1(const Tensor *input, const Tensor *filter,
Tensor *output, const Tensor *new_scale,
const Tensor *new_bias, bool if_relu);
void DepthwiseConv3x3s2p1v2(const Tensor *input, const Tensor *filter,
Tensor *output, Tensor bias, bool if_bias);
void DepthwiseConvAddBNRelu3x3s2p1v2(const Tensor *input, const Tensor *filter,
Tensor *output, const Tensor *new_scale,
const Tensor *new_bias, bool if_relu);
void DepthwiseConv3x3s2p0(const Tensor *input, const Tensor *filter,
Tensor *output, Tensor bias, bool if_bias);
template <int tile, int kernel>
void winograd_transform_weight(const framework::Tensor &weight,
framework::Tensor *output);
template <int tile, int kernel>
void winograd_transform_input(const framework::Tensor &input,
framework::Tensor *output);
template <int tile, int kernel>
void winograd_transform_output(const framework::Tensor &input,
const framework::Tensor &weight,
framework::Tensor *output);
} // namespace math
} // namespace operators
} // namespace paddle_mobile
#endif
此差异已折叠。
......@@ -405,9 +405,9 @@ class ConvParam : public OpParam {
const RType *Input() const { return input_; }
RType *Filter() const { return filter_; }
RType *&Filter() const { return filter_; }
RType *Output() const { return output_; }
RType *&Output() const { return output_; }
const vector<int> &Strides() const { return strides_; }
......@@ -415,6 +415,19 @@ class ConvParam : public OpParam {
const vector<int> &Dilations() const { return dilations_; }
enum ExecMode {
EXEC_INVALID = 0,
EXEC_GEMM_FLOAT,
EXEC_DEPTHWISE3x3S1P1_FLOAT,
EXEC_DEPTHWISE3x3_FLOAT,
EXEC_WINOGRAD3X3_FLOAT,
EXEC_WINOGRAD5X5_FLOAT,
EXEC_GEMM_INT8,
EXEC_DEPTHWISE3x3_INT8,
};
ExecMode &ExecMode() const { return exec_mode_; }
const int &Groups() const { return groups; }
#ifdef PADDLE_MOBILE_CL
......@@ -426,11 +439,12 @@ class ConvParam : public OpParam {
private:
RType *input_;
RType *output_;
RType *filter_;
mutable RType *output_;
mutable RType *filter_;
vector<int> strides_;
vector<int> paddings_;
vector<int> dilations_;
mutable enum ExecMode exec_mode_;
int groups;
#ifdef PADDLE_MOBILE_CL
......@@ -2509,10 +2523,10 @@ class QuantizeParam : public OpParam {
QuantizeParam(const VariableNameMap &inputs, const VariableNameMap &outputs,
const AttributeMap &attrs, const Scope &scope) {
input_ = InputXFrom<GType>(inputs, scope);
out_ = OutFrom<GType>(outputs, scope);
output_ = OutFrom<GType>(outputs, scope);
// online
// scale = max(abs(x))
online_scale_ = GetVarValue<GType>("OutScale", outputs, scope);
online_scale_ = OpParam::GetVarValue<GType>("OutScale", outputs, scope);
// offline
if (HasAttr("static_scale", attrs)) {
is_static_ = true;
......@@ -2522,14 +2536,18 @@ class QuantizeParam : public OpParam {
if (HasAttr("round_type", attrs)) {
round_type_ = GetAttr<RoundType>("round_type", attrs);
}
// get paddings
paddings_ = std::vector<int>({0, 0});
if (HasAttr("paddings", attrs)) {
paddings_ = GetAttr<vector<int>>("paddings", attrs);
}
}
public:
// op input
RType *input_;
// op output
RType *out_;
//
RType *output_;
RType *online_scale_;
// if static scale or not
bool is_static_ = false;
......@@ -2537,7 +2555,11 @@ class QuantizeParam : public OpParam {
float static_scale_ = 1.0f;
// round method type
// nearest_zero and nearest_even is valid currently
RoundType round_type_ = ROUND_NEAREST_AWAY_ZERO;
// RoundType round_type_ = ROUND_NEAREST_AWAY_ZERO;
RoundType round_type_ = ROUND_NEAREST_TOWARDS_ZERO;
// optional paddings
std::vector<int> paddings_;
int8_t padding_val_;
};
#endif
......@@ -2551,8 +2573,8 @@ class DequantizeParam : public OpParam {
DequantizeParam(const VariableNameMap &inputs, const VariableNameMap &outputs,
const AttributeMap &attrs, const Scope &scope) {
input_ = InputXFrom<GType>(inputs, scope);
out_ = OutFrom<GType>(outputs, scope);
activation_scale_ = GetVarValue<GType>("Scale", inputs, scope);
output_ = OutFrom<GType>(outputs, scope);
activation_scale_ = OpParam::GetVarValue<GType>("Scale", inputs, scope);
// dequantization is performed as x = x / static_scale / online_scale
if (HasAttr("weight_scale", attrs)) {
weight_scale_ = GetAttr<float>("weight_scale", attrs);
......@@ -2565,11 +2587,50 @@ class DequantizeParam : public OpParam {
// op input
RType *input_;
// op output
RType *out_;
RType *output_;
RType *activation_scale_;
float weight_scale_;
};
#endif
#ifdef FUSION_DEQUANT_ADD_BN_RELU_OP
template <typename Dtype>
class FusionDequantAddBNReluParam : public DequantizeParam<Dtype> {
typedef typename DtypeTensorTrait<Dtype>::gtype GType;
typedef typename DtypeTensorTrait<Dtype>::rtype RType;
public:
FusionDequantAddBNReluParam(const VariableNameMap &inputs,
const VariableNameMap &outputs,
const AttributeMap &attrs, const Scope &scope)
: DequantizeParam<Dtype>(inputs, outputs, attrs, scope) {
// element wise add params
axis_ = OpParam::GetAttr<int>("axis", attrs);
bias_ = OpParam::InputYFrom<GType>(inputs, scope);
// batch norm params
bn_mean_ = OpParam::GetVarValue<GType>("BNMean", inputs, scope);
bn_variance_ = OpParam::GetVarValue<GType>("BNVariance", inputs, scope);
bn_scale_ = OpParam::GetVarValue<GType>("BNScale", inputs, scope);
bn_bias_ = OpParam::GetVarValue<GType>("BNBias", inputs, scope);
epsilon_ = OpParam::GetAttr<float>("epsilon", attrs);
// output
output_ = OpParam::OutFrom<GType>(outputs, scope);
}
public:
// elementwise add
int axis_;
RType *bias_;
// batch norm
RType *bn_mean_;
RType *bn_variance_;
RType *bn_scale_;
RType *bn_bias_;
float epsilon_;
// output
RType *output_;
};
#endif
} // namespace operators
} // namespace paddle_mobile
......@@ -22,8 +22,11 @@ namespace operators {
template <typename DeviceType, typename T>
void QuantizeOp<DeviceType, T>::InferShape() const {
const auto& input_dims = this->param_.input_->dims();
this->param_.out_->Resize(input_dims);
auto input_dims = this->param_.input_->dims();
const std::vector<int> &paddings = this->param_.paddings_;
input_dims[2] += 2 * paddings[0];
input_dims[3] += 2 * paddings[1];
this->param_.output_->Resize(input_dims);
auto scale_dims = framework::make_ddim(std::vector<int>{1});
this->param_.online_scale_->Resize(scale_dims);
}
......
......@@ -155,7 +155,7 @@ if (NOT FOUND_MATCH)
target_link_libraries(test-googlenet-quali paddle-mobile)
# gen test
ADD_EXECUTABLE(test-conv-op operators/test_cov_op.cpp test_helper.h test_include.h executor_for_test.h)
ADD_EXECUTABLE(test-conv-op operators/test_conv_op.cpp test_helper.h test_include.h executor_for_test.h)
target_link_libraries(test-conv-op paddle-mobile)
# gen test
......@@ -242,10 +242,6 @@ if (NOT FOUND_MATCH)
ADD_EXECUTABLE(test-dequantize-op operators/test_dequantize_op.cpp test_helper.h test_include.h)
target_link_libraries(test-dequantize-op paddle-mobile)
# test int8 conv op
ADD_EXECUTABLE(test-int8-conv-op operators/test_int8_conv_op.cpp test_helper.h test_include.h)
target_link_libraries(test-int8-conv-op paddle-mobile)
# gen test log
ADD_EXECUTABLE(test-log common/test_log.cpp)
target_link_libraries(test-log paddle-mobile)
......@@ -368,6 +364,10 @@ if (NOT FOUND_MATCH)
ADD_EXECUTABLE(test-multi-process net/test_multi_inference_predict.cpp test_helper.h test_include.h)
target_link_libraries(test-multi-process paddle-mobile)
# gen test benchmark
ADD_EXECUTABLE(test-benchmark net/test_benchmark.cpp)
target_link_libraries(test-benchmark paddle-mobile)
# gen test
ADD_EXECUTABLE(test-eng net/test_eng.cpp test_helper.h test_include.h)
target_link_libraries(test-eng paddle-mobile)
......
......@@ -12,10 +12,11 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <iostream>
#include <string>
#include "../test_helper.h"
#include "../test_include.h"
static size_t ReadBuffer(const char *file_name, uint8_t **out) {
FILE *fp;
fp = fopen(file_name, "rb");
......
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <iostream>
#include "../test_helper.h"
#include "../test_include.h"
int main(int argc, char* argv[]) {
if (argc < 4) {
std::cout << "Usage: " << std::endl
<< "./test_benchmark fluid_model feed_shape thread_num [use_fuse]"
<< std::endl;
std::cout << "use_fuse: optional, bool, default is 1\n";
return 1;
}
bool optimize = true;
char* fluid_model = argv[1];
char* feed_shape = argv[2];
int thread_num = atoi(argv[3]);
if (argc == 5) {
optimize = atoi(argv[4]);
}
paddle_mobile::PaddleMobile<paddle_mobile::CPU> paddle_mobile;
paddle_mobile.SetThreadNum(thread_num);
auto time1 = time();
if (paddle_mobile.Load(fluid_model, optimize)) {
auto time2 = time();
std::cout << "load cost :" << time_diff(time1, time2) << "ms\n";
paddle_mobile::framework::Tensor input;
std::shared_ptr<paddle_mobile::framework::Tensor> output;
std::vector<int64_t> dims{1, 3, 224, 224};
if (feed_shape) {
sscanf(feed_shape, "%d,%d,%d,%d", &dims[0], &dims[1], &dims[2], &dims[3]);
}
std::cout << "feed shape: [" << dims[0] << ", " << dims[1] << ", "
<< dims[2] << ", " << dims[3] << "]\n";
paddle_mobile::framework::DDim in_shape =
paddle_mobile::framework::make_ddim(dims);
SetupTensor<float>(&input, in_shape, 0.f, 255.f);
// warmup
for (int i = 0; i < 10; ++i) {
output = paddle_mobile.Predict(input);
}
auto time3 = time();
for (int i = 0; i < 10; ++i) {
output = paddle_mobile.Predict(input);
}
auto time4 = time();
std::cout << "predict cost :" << time_diff(time3, time4) / 10 << "ms\n";
}
return 0;
}
......@@ -20,12 +20,11 @@ int main() {
#ifdef PADDLE_MOBILE_FPGA
paddle_mobile::PaddleMobile<paddle_mobile::FPGA> paddle_mobile;
#endif
#ifdef PADDLE_MOBILE_CPU
paddle_mobile::PaddleMobile<paddle_mobile::CPU> paddle_mobile;
#endif
paddle_mobile.SetThreadNum(4);
paddle_mobile.SetThreadNum(1);
bool optimize = true;
auto time1 = time();
if (paddle_mobile.Load(g_googlenet, optimize)) {
......@@ -36,7 +35,7 @@ int main() {
std::vector<float> output;
std::vector<int64_t> dims{1, 3, 224, 224};
GetInput<float>(g_test_image_1x3x224x224, &input, dims);
// 预热十次
// warmup
for (int i = 0; i < 10; ++i) {
output = paddle_mobile.Predict(input, dims);
}
......@@ -46,8 +45,7 @@ int main() {
}
auto time4 = time();
std::cout << "predict cost :" << time_diff(time3, time4) / 10 << "ms"
<< std::endl;
std::cout << "predict cost: " << time_diff(time3, time4) / 10 << "ms\n";
}
return 0;
}
......@@ -18,7 +18,7 @@ limitations under the License. */
namespace paddle_mobile {
// Reference convolution for checking results:
// Reference convolution from Caffe for checking results.
// accumulate through explicit loops over input, output, and filters.
template <typename Itype, typename Otype>
void conv2d(const framework::Tensor *input, const framework::Tensor *filter,
......@@ -129,7 +129,7 @@ void conv2d(const framework::Tensor *input, const framework::Tensor *filter,
}
template <typename Itype, typename Otype, int Kernel, int Pad, int Stride>
int TestConvOp() {
int TestConvOp(int in_channels, int in_height, int in_width, int out_channels) {
int kernel_h = Kernel;
int kernel_w = Kernel;
int pad_h = Pad;
......@@ -140,10 +140,10 @@ int TestConvOp() {
int dilation_w = 1;
int batch_size = 1;
int input_c = 3;
int input_h = 100;
int input_w = 100;
int output_c = 10;
int input_c = in_channels;
int input_h = in_height;
int input_w = in_width;
int output_c = out_channels;
framework::DDim input_shape =
framework::make_ddim({batch_size, input_c, input_h, input_w});
framework::DDim filter_shape =
......@@ -158,7 +158,7 @@ int TestConvOp() {
auto input_var = scope.get()->Var("input");
auto input = input_var->template GetMutable<framework::LoDTensor>();
SetupTensor<Itype>(input, input_shape, -20, 20);
SetupTensor<Itype>(input, input_shape, -20.0, 20.0);
auto filter_var = scope.get()->Var("filter");
auto filter = filter_var->template GetMutable<framework::LoDTensor>();
......@@ -174,8 +174,9 @@ int TestConvOp() {
auto *op = new operators::ConvOp<CPU, float>("conv2d", inputs, outputs, attrs,
scope);
// struct timespec ts_begin, ts_end;
op->InferShape();
op->Init();
// struct timespec ts_begin, ts_end;
// warmup
// op->Run();
// clock_gettime(CLOCK_MONOTONIC, &ts_begin);
......@@ -202,9 +203,16 @@ int TestConvOp() {
const Otype *output_data = output->data<Otype>();
Otype *output_cmp_data = output_cmp.data<Otype>();
for (int i = 0; i < output->numel(); ++i) {
PADDLE_MOBILE_ENFORCE(output_data[i] == output_cmp_data[i],
float gap = output_data[i] - output_cmp_data[i];
PADDLE_MOBILE_ENFORCE(std::abs(gap / (output_data[i] + 1e-5)) < 1e-3,
"output[%d] = %d, output_cmp[%d] = %d", i,
output_data[i], i, output_cmp_data[i]);
// if (std::abs(gap / (output_data[i] + 1e-5)) > 1e-3) {
// LOG(kLOG_INFO) << "output_data[" << i << "] = " << output_data[i]
// << ", output_cmp_data[" << i << "] = " <<
// output_cmp_data[i];
// return 1;
// }
}
delete op;
return 0;
......@@ -212,68 +220,88 @@ int TestConvOp() {
} // namespace paddle_mobile
int main() {
int main(int argc, char *argv[]) {
if (argc < 5) {
LOG(paddle_mobile::kLOG_INFO)
<< "Usage:\n"
<< " ./test-int8-conv-op in_channels in_height in_width out_channels\n"
<< " params:\n"
<< " -in_channels: int, input image's channels\n"
<< " -in_height: int, input image's height\n"
<< " -in_width: int, input image's width\n"
<< " -out_channels: int, conv output channels\n";
return 1;
}
int in_channels = atoi(argv[1]);
int in_height = atoi(argv[2]);
int in_width = atoi(argv[3]);
int out_channels = atoi(argv[4]);
// kernel = 3, pad = 1, stride = 1
LOG(paddle_mobile::kLOG_INFO) << "float, kernel=3, pad=1, stride=1";
paddle_mobile::TestConvOp<float, float, 3, 1, 1>(in_channels, in_height,
in_width, out_channels);
// kernel = 7, pad = 0, stride = 2
LOG(paddle_mobile::kLOG_INFO) << "int8, kernel=7, pad=0, stride=2";
paddle_mobile::TestConvOp<int8_t, int32_t, 7, 0, 2>();
paddle_mobile::TestConvOp<int8_t, int32_t, 7, 0, 2>(in_channels, in_height,
in_width, out_channels);
// kernel = 7, pad = 1, stride = 2
LOG(paddle_mobile::kLOG_INFO) << "int8, kernel=7, pad=1, stride=2";
paddle_mobile::TestConvOp<int8_t, int32_t, 7, 1, 2>();
paddle_mobile::TestConvOp<int8_t, int32_t, 7, 1, 2>(in_channels, in_height,
in_width, out_channels);
// kernel = 7, pad = 3, stride = 2
LOG(paddle_mobile::kLOG_INFO) << "int8, kernel=7, pad=3, stride=2";
paddle_mobile::TestConvOp<int8_t, int32_t, 7, 3, 2>();
paddle_mobile::TestConvOp<int8_t, int32_t, 7, 3, 2>(in_channels, in_height,
in_width, out_channels);
// kernel = 7, pad = 0, stride = 1
LOG(paddle_mobile::kLOG_INFO) << "int8, kernel=7, pad=0, stride=1";
paddle_mobile::TestConvOp<int8_t, int32_t, 7, 0, 1>();
paddle_mobile::TestConvOp<int8_t, int32_t, 7, 0, 1>(in_channels, in_height,
in_width, out_channels);
// kernel = 7, pad = 1, stride = 1
LOG(paddle_mobile::kLOG_INFO) << "int8, kernel=7, pad=1, stride=1";
paddle_mobile::TestConvOp<int8_t, int32_t, 7, 1, 1>();
paddle_mobile::TestConvOp<int8_t, int32_t, 7, 1, 1>(in_channels, in_height,
in_width, out_channels);
// kernel = 7, pad = 3, stride = 1
LOG(paddle_mobile::kLOG_INFO) << "int8, kernel=7, pad=3, stride=1";
paddle_mobile::TestConvOp<int8_t, int32_t, 7, 3, 1>();
paddle_mobile::TestConvOp<int8_t, int32_t, 7, 3, 1>(in_channels, in_height,
in_width, out_channels);
// kernel = 7, pad = 5, stride = 3
LOG(paddle_mobile::kLOG_INFO) << "int8, kernel=7, pad=5, stride=3";
paddle_mobile::TestConvOp<int8_t, int32_t, 7, 5, 3>();
paddle_mobile::TestConvOp<int8_t, int32_t, 7, 5, 3>(in_channels, in_height,
in_width, out_channels);
// kernel = 7, pad = 3, stride = 4
LOG(paddle_mobile::kLOG_INFO) << "int8, kernel=7, pad=3, stride=4";
paddle_mobile::TestConvOp<int8_t, int32_t, 7, 3, 4>();
LOG(paddle_mobile::kLOG_INFO) << "\n";
paddle_mobile::TestConvOp<int8_t, int32_t, 7, 3, 4>(in_channels, in_height,
in_width, out_channels);
// kernel = 3, pad = 0, stride = 1
LOG(paddle_mobile::kLOG_INFO) << "int8, kernel=3, pad=0, stride=1";
paddle_mobile::TestConvOp<int8_t, int32_t, 3, 0, 1>();
paddle_mobile::TestConvOp<int8_t, int32_t, 3, 0, 1>(in_channels, in_height,
in_width, out_channels);
// kernel = 3, pad = 0, stride = 1
LOG(paddle_mobile::kLOG_INFO) << "float, kernel=3, pad=0, stride=1";
paddle_mobile::TestConvOp<float, float, 3, 0, 1>();
LOG(paddle_mobile::kLOG_INFO) << "\n";
paddle_mobile::TestConvOp<float, float, 3, 0, 1>(in_channels, in_height,
in_width, out_channels);
// kernel = 3, pad = 1, stride = 1
LOG(paddle_mobile::kLOG_INFO) << "int8, kernel=3, pad=1, stride=1";
paddle_mobile::TestConvOp<int8_t, int32_t, 3, 1, 1>();
paddle_mobile::TestConvOp<int8_t, int32_t, 3, 1, 1>(in_channels, in_height,
in_width, out_channels);
// kernel = 3, pad = 1, stride = 1
LOG(paddle_mobile::kLOG_INFO) << "float, kernel=3, pad=1, stride=1";
paddle_mobile::TestConvOp<float, float, 3, 1, 1>();
LOG(paddle_mobile::kLOG_INFO) << "\n";
paddle_mobile::TestConvOp<float, float, 3, 1, 1>(in_channels, in_height,
in_width, out_channels);
// kernel = 5, pad = 0, stride = 1
LOG(paddle_mobile::kLOG_INFO) << "int8, kernel=5, pad=0, stride=1";
paddle_mobile::TestConvOp<int8_t, int32_t, 5, 0, 1>();
paddle_mobile::TestConvOp<int8_t, int32_t, 5, 0, 1>(in_channels, in_height,
in_width, out_channels);
// kernel = 5, pad = 0, stride = 1
LOG(paddle_mobile::kLOG_INFO) << "float, kernel=5, pad=0, stride=1";
paddle_mobile::TestConvOp<float, float, 5, 0, 1>();
LOG(paddle_mobile::kLOG_INFO) << "\n";
paddle_mobile::TestConvOp<float, float, 5, 0, 1>(in_channels, in_height,
in_width, out_channels);
// kernel = 5, pad = 2, stride = 1
LOG(paddle_mobile::kLOG_INFO) << "int8, kernel=5, pad=2, stride=1";
paddle_mobile::TestConvOp<int8_t, int32_t, 5, 2, 1>();
paddle_mobile::TestConvOp<int8_t, int32_t, 5, 2, 1>(in_channels, in_height,
in_width, out_channels);
// kernel = 5, pad = 2, stride = 1
LOG(paddle_mobile::kLOG_INFO) << "float, kernel=5, pad=2, stride=1";
paddle_mobile::TestConvOp<float, float, 5, 2, 1>();
paddle_mobile::TestConvOp<float, float, 5, 2, 1>(in_channels, in_height,
in_width, out_channels);
}
此差异已折叠。
......@@ -212,4 +212,4 @@ else
else
build_error "$1"
fi
fi
\ No newline at end of file
fi
......@@ -249,6 +249,7 @@ if(NOT FOUND_MATCH)
set(SUM_OP ON)
set(QUANT_OP ON)
set(DEQUANT_OP ON)
set(FUSION_DEQUANT_ADD_BN_RELU ON)
endif()
# option(BATCHNORM_OP "" ON)
......@@ -450,6 +451,9 @@ endif()
if (DEQUANT_OP)
add_definitions(-DDEQUANT_OP)
endif()
if (FUSION_DEQUANT_ADD_BN_RELU)
add_definitions(-DFUSION_DEQUANT_ADD_BN_RELU_OP)
endif()
if (TANH_OP)
add_definitions(-DTANH_OP)
......@@ -462,4 +466,4 @@ if (FUSION_DECONVADD_OP)
endif()
if (FUSION_DECONVADDRELU_OP)
add_definitions(-DFUSION_DECONVADDRELU_OP)
endif()
\ No newline at end of file
endif()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册