提交 4a4157d5 编写于 作者: H hjchen2

Refine: fix depthwise conv bug and support stride=2

上级 c9630379
......@@ -72,7 +72,7 @@ Executor<Dtype, P>::Executor(const framework::Program<Dtype> p, int batch_size,
op->Type(), op->GetInputs(), op->GetOutputs(), op->GetAttrMap(),
program_.scope);
// infer shape to reshape tensor before predict,
// but for lod tensor, it will need to reshape in runtime
// but for lod tensor, it will still need to reshape in runtime
if (!loddable_) {
op_base->InferShape();
}
......
......@@ -233,6 +233,3 @@ LOAD_OP1(quantize, CPU);
#ifdef DEQUANT_OP
LOAD_OP1(dequantize, CPU);
#endif
#ifdef PAD_OP
LOAD_OP1(pad, CPU);
#endif
......@@ -15,7 +15,6 @@ limitations under the License. */
#ifdef CONV_OP
#include "operators/kernel/conv_kernel.h"
#include <iostream>
#include "operators/kernel/central-arm-func/conv_arm_func.h"
namespace paddle_mobile {
......@@ -27,7 +26,8 @@ bool ConvKernel<CPU, float>::Init(ConvParam<CPU> *param) {
if (param->Groups() == param->Input()->dims()[1] &&
param->Input()->dims()[1] == param->Output()->dims()[1] &&
param->Filter()->dims()[2] == param->Filter()->dims()[3] &&
param->Filter()->dims()[2] == 3) {
param->Filter()->dims()[2] == 3 && param->Strides()[0] < 3 &&
param->Strides()[0] == param->Strides()[1]) {
param->ExecMode() = ConvParam<CPU>::EXEC_DEPTHWISE3x3_INT8;
} else {
param->ExecMode() = ConvParam<CPU>::EXEC_GEMM_INT8;
......@@ -70,30 +70,23 @@ void ConvKernel<CPU, float>::Compute(const ConvParam<CPU> &param) {
switch (param.ExecMode()) {
case ConvParam<CPU>::EXEC_GEMM_INT8:
GemmConv<int8_t, int32_t>(param);
std::cout << "EXEC_GEMM_INT8" << std::endl;
break;
case ConvParam<CPU>::EXEC_DEPTHWISE3x3_INT8:
DepthwiseConv3x3<int8_t, int32_t>(param);
std::cout << "EXEC_DEPTHWISE3x3_INT8" << std::endl;
break;
case ConvParam<CPU>::EXEC_DEPTHWISE3x3S1P1_FLOAT:
math::DepthwiseConv3x3s1p1(param.Input(), param.Filter(), param.Output(),
nullptr, false);
std::cout << "EXEC_DEPTHWISE3x3S1P1_FLOAT" << std::endl;
break;
case ConvParam<CPU>::EXEC_DEPTHWISE3x3_FLOAT:
math::DepthwiseConv3x3(param.Input(), param.Strides(), param.Paddings(),
param.Filter(), nullptr, param.Output(), false);
std::cout << "EXEC_DEPTHWISE3x3_FLOAT=" << param.Strides()[0]
<< std::endl;
break;
case ConvParam<CPU>::EXEC_WINOGRAD3X3_FLOAT:
WinogradConv3x3<8, 3>(param);
std::cout << "EXEC_WINOGRAD3X3_FLOAT" << std::endl;
break;
case ConvParam<CPU>::EXEC_GEMM_FLOAT:
GemmConv<float, float>(param);
std::cout << "EXEC_GEMM_FLOAT" << std::endl;
break;
default:
PADDLE_MOBILE_THROW_EXCEPTION("Invalid convolution execute mode %d",
......
......@@ -15,7 +15,6 @@ limitations under the License. */
#ifdef DEQUANT_OP
#include "operators/kernel/dequantize_kernel.h"
#include <iostream>
#if defined(__ARM_NEON__) || defined(__ARM_NEON)
#include <arm_neon.h>
......
......@@ -15,7 +15,6 @@ limitations under the License. */
#ifdef ELEMENTWISEADD_OP
#include "operators/kernel/elementwise_add_kernel.h"
#include <iostream>
#include "operators/kernel/central-arm-func/elementwise_add_arm_func.h"
namespace paddle_mobile {
......
......@@ -126,54 +126,6 @@ static float find_abs_max(const Tensor *input) {
return max_abs;
}
#if 0
static void quantize_round_to_zero(const Tensor *input, const float scale,
const std::vector<int> &paddings,
const int8_t padding_val, Tensor *output) {
const float *x = input->data<const float>();
int8_t *y = output->mutable_data<int8_t>();
size_t size = input->numel();
#if defined(__ARM_NEON__) || defined(__ARM_NEON)
size_t loop = size >> 4;
size_t remain = size & 0xF;
#pragma omp parallel for
for (size_t i = 0; i < loop; ++i) {
const float *local_x = x + (i << 4);
int8_t *local_y = y + (i << 4);
float32x4_t r0 = vld1q_f32(local_x);
float32x4_t r1 = vld1q_f32(local_x + 4);
float32x4_t r2 = vld1q_f32(local_x + 8);
float32x4_t r3 = vld1q_f32(local_x + 12);
r0 = vmulq_n_f32(r0, scale);
r1 = vmulq_n_f32(r1, scale);
r2 = vmulq_n_f32(r2, scale);
r3 = vmulq_n_f32(r3, scale);
int32x4_t q0 = vrnd_towards_zero(r0);
int32x4_t q1 = vrnd_towards_zero(r1);
int32x4_t q2 = vrnd_towards_zero(r2);
int32x4_t q3 = vrnd_towards_zero(r3);
int16x4_t d0 = vmovn_s32(q0);
int16x4_t d1 = vmovn_s32(q1);
int16x4_t d2 = vmovn_s32(q2);
int16x4_t d3 = vmovn_s32(q3);
int16x8_t q5 = vcombine_s16(d0, d1);
int16x8_t q6 = vcombine_s16(d2, d3);
int8x8_t d5 = vmovn_s16(q5);
int8x8_t d6 = vmovn_s16(q6);
vst1_s8(local_y, d5);
vst1_s8(local_y + 8, d6);
}
size = remain;
x += (loop << 4);
y += (loop << 4);
#endif
for (size_t i = 0; i < size; ++i) {
y[i] = static_cast<int8_t>(x[i] * scale);
}
}
#endif
#ifdef __aarch64__
static void quantize_round_to_even(const Tensor *input, const float scale,
Tensor *output) {
......@@ -320,7 +272,7 @@ static void quantize_round_to_nearest(const Tensor *input, const float scale,
y[i] = round(x[i] * scale);
}
}
#else // __aarch64__
#else // __aarch64__
static void quantize_round_to_even(const Tensor *input, const float scale,
const std::vector<int> &paddings,
......@@ -330,7 +282,7 @@ static void quantize_round_to_nearest(const Tensor *input, const float scale,
const std::vector<int> &paddings,
const int8_t padding_val,
Tensor *output) {}
#if 1
static void quantize_round_to_zero(const Tensor *input, const float scale,
const std::vector<int> &paddings,
const int8_t padding_val, Tensor *output) {
......@@ -347,6 +299,7 @@ static void quantize_round_to_zero(const Tensor *input, const float scale,
int start = paddings[0] * output_w + paddings[1];
for (int batch = 0; batch < input->dims()[0]; ++batch) {
#pragma omp parallel for
for (int c = 0; c < channels - 3; c += 4) {
const float *input0 = x + (batch * channels + c) * input_spatial_size;
const float *input1 = input0 + input_spatial_size;
......@@ -819,7 +772,6 @@ static void quantize_round_to_zero(const Tensor *input, const float scale,
}
}
}
#endif
#endif // __aarch64__
#endif // ARM_NEON
......
......@@ -98,7 +98,6 @@ inline void GemmConv(const ConvParam<CPU> &param) {
std::vector<int>{paddings[0], paddings[1], paddings[0],
paddings[1]},
&col);
} else if (data_dim == 3U) {
// vol2col
vol2col(in_slice, dilations, strides, paddings, &col);
......@@ -176,25 +175,25 @@ inline void DepthwiseConv3x3(const ConvParam<CPU> &param) {
for (int i = 0; i < batch_size; i++) {
Tensor in_batch = input->Slice(i, i + 1);
Tensor out_batch = output->Slice(i, i + 1);
// if (paddings[0] || paddings[1]) {
// framework::DDim pad_shape = in_batch.dims();
// pad_shape[2] += 2 * paddings[0];
// pad_shape[3] += 2 * paddings[1];
// input_pad.mutable_data<float>(pad_shape);
// pad(in_batch, paddings[0], paddings[0], paddings[1], paddings[1],
// &input_pad);
// } else {
// input_pad = in_batch;
// }
// math::DepthwiseConv3x3s1<Itype, Otype>(input_pad, *filter,
// &out_batch);
if (paddings[0] || paddings[1]) {
framework::DDim pad_shape = in_batch.dims();
pad_shape[2] += 2 * paddings[0];
pad_shape[3] += 2 * paddings[1];
input_pad.mutable_data<float>(pad_shape);
pad(in_batch, paddings[0], paddings[0], paddings[1], paddings[1],
&input_pad);
} else {
input_pad = in_batch;
}
if (strides[0] == 1) {
math::DepthwiseConv3x3s1<Itype, Otype>(in_batch, *filter, &out_batch);
math::DepthwiseConv3x3s1<Itype, Otype>(input_pad, *filter, &out_batch);
} else if (strides[0] == 2) {
math::DepthwiseConv3x3s2<Itype, Otype>(in_batch, *filter, &out_batch);
math::DepthwiseConv3x3s2<Itype, Otype>(input_pad, *filter, &out_batch);
} else {
// math::DepthwiseConv3x3<Itype, Otype>(in_batch, *filter,
// &out_batch);
// math::DepthwiseConv3x3<Itype, Otype>(input_pad, *filter,
// &out_batch);
PADDLE_MOBILE_THROW_EXCEPTION(
"Depthwise conv with generic strides has not been implemented.");
}
}
}
......
......@@ -65,6 +65,7 @@ void DepthwiseConv3x3s2p0(const framework::Tensor *input,
framework::Tensor *output, framework::Tensor bias,
bool if_bias);
// TODO(hjchen2) need to be implemented
// template<typename Itype, typename Otype>
// void DepthwiseConv3x3(const framework::Tensor *input,
// const framework::Tensor *filter,
......
......@@ -2564,26 +2564,5 @@ class DequantizeParam : public OpParam {
};
#endif
#ifdef PAD_OP
template <typename Dtype>
class PadParam : public OpParam {
typedef typename DtypeTensorTrait<Dtype>::gtype GType;
typedef typename DtypeTensorTrait<Dtype>::rtype RType;
public:
input_ = InputXFrom<GType>(inputs, scope);
output_ = OutFrom<GType>(outputs, scope);
paddings_ = GetVarValue<std::vector<int>>("Paddings", inputs, scope);
public:
// op input
RType *input_;
// op output
RType *output_;
// paddings
std::vector<int> paddings_;
};
#endif
} // namespace operators
} // namespace paddle_mobile
......@@ -22,11 +22,7 @@ namespace operators {
template <typename DeviceType, typename T>
void QuantizeOp<DeviceType, T>::InferShape() const {
auto input_dims = this->param_.input_->dims();
// const auto &paddings = this->param_.paddings_;
std::vector<int> paddings = {0, 0};
input_dims[2] += 2 * paddings[0];
input_dims[3] += 2 * paddings[1];
const auto &input_dims = this->param_.input_->dims();
this->param_.output_->Resize(input_dims);
auto scale_dims = framework::make_ddim(std::vector<int>{1});
this->param_.online_scale_->Resize(scale_dims);
......
......@@ -69,7 +69,6 @@ build_for_android() {
-DANDROID_ABI="${ABI}" \
-DCMAKE_BUILD_TYPE="${MODE}" \
-DCMAKE_TOOLCHAIN_FILE="${TOOLCHAIN_FILE}" \
-DANDROID_TOOLCHAIN='clang' \
-DANDROID_PLATFORM="${ANDROID_PLATFORM_VERSION}" \
-DCMAKE_CXX_FLAGS="${CXX_FLAGS}" \
-DANDROID_STL=c++_static \
......@@ -83,7 +82,6 @@ build_for_android() {
-DANDROID_ABI="${ABI}" \
-DCMAKE_BUILD_TYPE="${MODE}" \
-DCMAKE_TOOLCHAIN_FILE="${TOOLCHAIN_FILE}" \
-DANDROID_TOOLCHAIN='clang' \
-DANDROID_PLATFORM="${ANDROID_PLATFORM_VERSION}" \
-DCMAKE_CXX_FLAGS="${CXX_FLAGS}" \
-DANDROID_STL=c++_static \
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册