提交 de37013f 编写于 作者: H hjchen2

Support padding in 8bit depthwise conv, so remove padding from dequantize kernel

上级 7b5a6c39
...@@ -55,10 +55,10 @@ bool ConvKernel<CPU, float>::Init(ConvParam<CPU> *param) { ...@@ -55,10 +55,10 @@ bool ConvKernel<CPU, float>::Init(ConvParam<CPU> *param) {
param->Input()->dims()[2] <= 140 /* referenced from ncnn */) { param->Input()->dims()[2] <= 140 /* referenced from ncnn */) {
param->ExecMode() = ConvParam<CPU>::EXEC_WINOGRAD3X3_FLOAT; param->ExecMode() = ConvParam<CPU>::EXEC_WINOGRAD3X3_FLOAT;
// transform weight // transform weight
framework::Tensor *transformed_weight = new framework::Tensor; framework::Tensor transformed_weight;
operators::math::winograd_transform_weight<8, 3>(*param->Filter(), operators::math::winograd_transform_weight<8, 3>(*param->Filter(),
transformed_weight); &transformed_weight);
param->Filter() = transformed_weight; framework::TensorCopy(transformed_weight, param->Filter());
#endif #endif
} else { } else {
param->ExecMode() = ConvParam<CPU>::EXEC_GEMM_FLOAT; param->ExecMode() = ConvParam<CPU>::EXEC_GEMM_FLOAT;
......
...@@ -20,6 +20,9 @@ limitations under the License. */ ...@@ -20,6 +20,9 @@ limitations under the License. */
#if defined(__ARM_NEON__) || defined(__ARM_NEON) #if defined(__ARM_NEON__) || defined(__ARM_NEON)
#include <arm_neon.h> #include <arm_neon.h>
namespace paddle_mobile {
namespace operators {
#ifndef __aarch64__ #ifndef __aarch64__
inline float32_t vmaxvq_f32(float32x4_t r) { inline float32_t vmaxvq_f32(float32x4_t r) {
float32x2_t v = vmax_f32(vget_high_f32(r), vget_low_f32(r)); float32x2_t v = vmax_f32(vget_high_f32(r), vget_low_f32(r));
...@@ -27,9 +30,13 @@ inline float32_t vmaxvq_f32(float32x4_t r) { ...@@ -27,9 +30,13 @@ inline float32_t vmaxvq_f32(float32x4_t r) {
} }
#endif #endif
inline int32x4_t vrnd_towards_zero(float32x4_t r) { return vcvtq_s32_f32(r); } template <RoundType R = ROUND_NEAREST_TOWARDS_ZERO>
inline int32x4_t vround_f32(float32x4_t r) {
return vcvtq_s32_f32(r);
}
inline int32x4_t vrnd_away_zero(float32x4_t r) { template <>
inline int32x4_t vround_f32<ROUND_NEAREST_AWAY_ZERO>(float32x4_t r) {
float32x4_t plus = vdupq_n_f32(0.5); float32x4_t plus = vdupq_n_f32(0.5);
float32x4_t minus = vdupq_n_f32(-0.5); float32x4_t minus = vdupq_n_f32(-0.5);
float32x4_t zero = vdupq_n_f32(0); float32x4_t zero = vdupq_n_f32(0);
...@@ -40,31 +47,13 @@ inline int32x4_t vrnd_away_zero(float32x4_t r) { ...@@ -40,31 +47,13 @@ inline int32x4_t vrnd_away_zero(float32x4_t r) {
return ret; return ret;
} }
inline int32x4_t vrnd_to_even(float32x4_t r) { template <>
#if 0 inline int32x4_t vround_f32<ROUND_NEAREST_TO_EVEN>(float32x4_t r) {
int32x4_t ret;
float value[4];
vst1q_f32(value, r);
for (int i = 0; i < 4; ++i) {
float v = round(value[i]);
int32_t q = (int32_t)v;
if (abs(abs(v - value[i]) - 0.5) > 0) {
ret[i] = q;
} else {
if (abs(q) % 2 == 0) {
ret[i] = q;
} else {
ret[i] = q + ((q > 0) ? -1 : 1);
}
}
}
return ret;
#else
float32x4_t point5 = vdupq_n_f32(0.5); float32x4_t point5 = vdupq_n_f32(0.5);
int32x4_t one = vdupq_n_s32(1); int32x4_t one = vdupq_n_s32(1);
int32x4_t zero = vdupq_n_s32(0); int32x4_t zero = vdupq_n_s32(0);
int32x4_t rnd = vrnd_away_zero(r); int32x4_t rnd = vround_f32<ROUND_NEAREST_AWAY_ZERO>(r);
float32x4_t frnd = vcvtq_f32_s32(rnd); float32x4_t frnd = vcvtq_f32_s32(rnd);
frnd = vsubq_f32(frnd, r); frnd = vsubq_f32(frnd, r);
frnd = vabsq_f32(frnd); frnd = vabsq_f32(frnd);
...@@ -82,117 +71,39 @@ inline int32x4_t vrnd_to_even(float32x4_t r) { ...@@ -82,117 +71,39 @@ inline int32x4_t vrnd_to_even(float32x4_t r) {
smask = vsubq_s32(smask, one); smask = vsubq_s32(smask, one);
rnd = vaddq_s32(rnd, smask); rnd = vaddq_s32(rnd, smask);
return rnd; return rnd;
#endif
} }
namespace paddle_mobile {
namespace operators {
static float find_abs_max(const Tensor *input) {
float max_abs = 0.f;
const float *x = input->data<const float>();
size_t size = input->numel();
#if defined(__ARM_NEON__) || defined(__ARM_NEON)
size_t loop = size >> 4;
size_t remain = size & 0xF;
for (size_t i = 0; i < loop; ++i) {
float32x4_t max;
float32x4_t r0 = vld1q_f32(x);
float32x4_t r1 = vld1q_f32(x + 4);
float32x4_t r2 = vld1q_f32(x + 8);
float32x4_t r3 = vld1q_f32(x + 12);
r0 = vabsq_f32(r0);
r1 = vabsq_f32(r1);
r2 = vabsq_f32(r2);
r3 = vabsq_f32(r3);
max[0] = vmaxvq_f32(r0);
max[1] = vmaxvq_f32(r1);
max[2] = vmaxvq_f32(r2);
max[3] = vmaxvq_f32(r3);
max[0] = vmaxvq_f32(max);
if (max[0] > max_abs) {
max_abs = max[0];
}
x += 16;
}
size = remain;
#endif #endif
for (size_t i = 0; i < size; ++i) {
float value = std::abs(x[i]); template <RoundType R = ROUND_NEAREST_TOWARDS_ZERO>
if (value > max_abs) { inline int8_t Round(const float &x) {
max_abs = value; return static_cast<int8_t>(x);
}
}
return max_abs;
} }
#ifdef __aarch64__ template <>
static void quantize_round_to_even(const Tensor *input, const float scale, inline int8_t Round<ROUND_NEAREST_AWAY_ZERO>(const float &x) {
const std::vector<int> &paddings, return std::round(x);
const int8_t padding_val, Tensor *output) { }
const float *x = input->data<const float>();
int8_t *y = output->mutable_data<int8_t>();
size_t size = input->numel();
#if defined(__ARM_NEON__) || defined(__ARM_NEON)
size_t loop = size >> 4;
size_t remain = size & 0xF;
#pragma omp parallel for template <>
for (size_t i = 0; i < loop; ++i) { inline int8_t Round<ROUND_NEAREST_TO_EVEN>(const float &x) {
const float *local_x = x + (i << 4); float v = std::round(x);
int8_t *local_y = y + (i << 4); int32_t q = static_cast<int32_t>(v);
float32x4_t r0 = vld1q_f32(local_x); if (abs(abs(q - v) - 0.5) <= 0) {
float32x4_t r1 = vld1q_f32(local_x + 4); if (abs(q) % 2 != 0) {
float32x4_t r2 = vld1q_f32(local_x + 8); q = q + ((q > 0) ? -1 : 1);
float32x4_t r3 = vld1q_f32(local_x + 12);
r0 = vmulq_n_f32(r0, scale);
r1 = vmulq_n_f32(r1, scale);
r2 = vmulq_n_f32(r2, scale);
r3 = vmulq_n_f32(r3, scale);
int32x4_t q0 = vrnd_to_even(r0);
int32x4_t q1 = vrnd_to_even(r1);
int32x4_t q2 = vrnd_to_even(r2);
int32x4_t q3 = vrnd_to_even(r3);
int16x4_t d0 = vmovn_s32(q0);
int16x4_t d1 = vmovn_s32(q1);
int16x4_t d2 = vmovn_s32(q2);
int16x4_t d3 = vmovn_s32(q3);
int16x8_t q5 = vcombine_s16(d0, d1);
int16x8_t q6 = vcombine_s16(d2, d3);
int8x8_t d5 = vmovn_s16(q5);
int8x8_t d6 = vmovn_s16(q6);
vst1_s8(local_y, d5);
vst1_s8(local_y + 8, d6);
}
size = remain;
x += (loop << 4);
y += (loop << 4);
#endif
for (size_t i = 0; i < size; ++i) {
float value = x[i] * scale;
float v = round(value);
int32_t q = (int32_t)v;
if (abs(abs(q - value) - 0.5) > 0) {
y[i] = q;
} else {
if (abs(q) % 2 == 0) {
y[i] = q;
} else {
y[i] = q + ((q > 0) ? -1 : 1);
}
} }
} }
return static_cast<int8_t>(q);
} }
static void quantize_round_to_zero(const Tensor *input, const float scale, template <RoundType R>
const std::vector<int> &paddings, static void Quantize(const Tensor *input, const float scale, Tensor *output) {
const int8_t padding_val, Tensor *output) {
const float *x = input->data<const float>(); const float *x = input->data<const float>();
int8_t *y = output->mutable_data<int8_t>(); int8_t *y = output->mutable_data<int8_t>();
size_t size = input->numel(); size_t remain = input->numel();
#if defined(__ARM_NEON__) || defined(__ARM_NEON) #if defined(__ARM_NEON__) || defined(__ARM_NEON)
size_t loop = size >> 4; size_t loop = remain >> 4;
size_t remain = size & 0xF; remain = remain & 0xF;
#pragma omp parallel for #pragma omp parallel for
for (size_t i = 0; i < loop; ++i) { for (size_t i = 0; i < loop; ++i) {
...@@ -206,10 +117,10 @@ static void quantize_round_to_zero(const Tensor *input, const float scale, ...@@ -206,10 +117,10 @@ static void quantize_round_to_zero(const Tensor *input, const float scale,
r1 = vmulq_n_f32(r1, scale); r1 = vmulq_n_f32(r1, scale);
r2 = vmulq_n_f32(r2, scale); r2 = vmulq_n_f32(r2, scale);
r3 = vmulq_n_f32(r3, scale); r3 = vmulq_n_f32(r3, scale);
int32x4_t q0 = vrnd_towards_zero(r0); int32x4_t q0 = vround_f32<R>(r0);
int32x4_t q1 = vrnd_towards_zero(r1); int32x4_t q1 = vround_f32<R>(r1);
int32x4_t q2 = vrnd_towards_zero(r2); int32x4_t q2 = vround_f32<R>(r2);
int32x4_t q3 = vrnd_towards_zero(r3); int32x4_t q3 = vround_f32<R>(r3);
int16x4_t d0 = vmovn_s32(q0); int16x4_t d0 = vmovn_s32(q0);
int16x4_t d1 = vmovn_s32(q1); int16x4_t d1 = vmovn_s32(q1);
int16x4_t d2 = vmovn_s32(q2); int16x4_t d2 = vmovn_s32(q2);
...@@ -221,563 +132,44 @@ static void quantize_round_to_zero(const Tensor *input, const float scale, ...@@ -221,563 +132,44 @@ static void quantize_round_to_zero(const Tensor *input, const float scale,
vst1_s8(local_y, d5); vst1_s8(local_y, d5);
vst1_s8(local_y + 8, d6); vst1_s8(local_y + 8, d6);
} }
size = remain;
x += (loop << 4); x += (loop << 4);
y += (loop << 4); y += (loop << 4);
#endif #endif
for (size_t i = 0; i < size; ++i) { for (size_t i = 0; i < remain; ++i) {
y[i] = static_cast<int8_t>(x[i] * scale); y[i] = Round<R>(x[i] * scale);
} }
} }
static void quantize_round_to_nearest(const Tensor *input, const float scale, float find_abs_max(const Tensor *input) {
const std::vector<int> &paddings, float max_abs = 0.f;
const int8_t padding_val,
Tensor *output) {
const float *x = input->data<const float>(); const float *x = input->data<const float>();
int8_t *y = output->mutable_data<int8_t>(); size_t remain = input->numel();
size_t size = input->numel();
#if defined(__ARM_NEON__) || defined(__ARM_NEON) #if defined(__ARM_NEON__) || defined(__ARM_NEON)
size_t loop = size >> 4; size_t loop = remain >> 4;
size_t remain = size & 0xF; remain = remain & 0xF;
float32x4_t __max = {0.f, 0.f, 0.f, 0.f};
#pragma omp parallel for for (size_t i = 0; i < loop; ++i, x += 16) {
for (size_t i = 0; i < loop; ++i) { float32x4_t r0 = vld1q_f32(x);
const float *local_x = x + (i << 4); float32x4_t r1 = vld1q_f32(x + 4);
int8_t *local_y = y + (i << 4); float32x4_t r2 = vld1q_f32(x + 8);
float32x4_t r0 = vld1q_f32(local_x); float32x4_t r3 = vld1q_f32(x + 12);
float32x4_t r1 = vld1q_f32(local_x + 4); r0 = vabsq_f32(r0);
float32x4_t r2 = vld1q_f32(local_x + 8); r1 = vabsq_f32(r1);
float32x4_t r3 = vld1q_f32(local_x + 12); r2 = vabsq_f32(r2);
r0 = vmulq_n_f32(r0, scale); r3 = vabsq_f32(r3);
r1 = vmulq_n_f32(r1, scale); r0 = vmaxq_f32(r0, r1);
r2 = vmulq_n_f32(r2, scale); r1 = vmaxq_f32(r2, r3);
r3 = vmulq_n_f32(r3, scale); r0 = vmaxq_f32(r0, r1);
int32x4_t q0 = vrnd_away_zero(r0); __max = vmaxq_f32(r0, __max);
int32x4_t q1 = vrnd_away_zero(r1);
int32x4_t q2 = vrnd_away_zero(r2);
int32x4_t q3 = vrnd_away_zero(r3);
int16x4_t d0 = vmovn_s32(q0);
int16x4_t d1 = vmovn_s32(q1);
int16x4_t d2 = vmovn_s32(q2);
int16x4_t d3 = vmovn_s32(q3);
int16x8_t q5 = vcombine_s16(d0, d1);
int16x8_t q6 = vcombine_s16(d2, d3);
int8x8_t d5 = vmovn_s16(q5);
int8x8_t d6 = vmovn_s16(q6);
vst1_s8(local_y, d5);
vst1_s8(local_y + 8, d6);
} }
size = remain; max_abs = vmaxvq_f32(__max);
x += (loop << 4);
y += (loop << 4);
#endif #endif
for (size_t i = 0; i < size; ++i) { for (size_t i = 0; i < remain; ++i) {
y[i] = round(x[i] * scale); max_abs = std::max(max_abs, std::abs(x[i]));
}
}
#else // __aarch64__
static void quantize_round_to_even(const Tensor *input, const float scale,
const std::vector<int> &paddings,
const int8_t padding_val, Tensor *output) {}
static void quantize_round_to_nearest(const Tensor *input, const float scale,
const std::vector<int> &paddings,
const int8_t padding_val,
Tensor *output) {}
static void quantize_round_to_zero(const Tensor *input, const float scale,
const std::vector<int> &paddings,
const int8_t padding_val, Tensor *output) {
int channels = input->dims()[1];
int input_h = input->dims()[2];
int input_w = input->dims()[3];
int output_h = output->dims()[2];
int output_w = output->dims()[3];
int input_spatial_size = input_h * input_w;
int output_spatial_size = output_h * output_w;
const float *x = input->data<float>();
int8_t *y = output->mutable_data<int8_t>();
// valid area start
int start = paddings[0] * output_w + paddings[1];
for (int batch = 0; batch < input->dims()[0]; ++batch) {
#pragma omp parallel for
for (int c = 0; c < channels - 3; c += 4) {
const float *input0 = x + (batch * channels + c) * input_spatial_size;
const float *input1 = input0 + input_spatial_size;
const float *input2 = input1 + input_spatial_size;
const float *input3 = input2 + input_spatial_size;
size_t offset = (batch * channels + c) * output_spatial_size;
for (int h = 0; h < 2; ++h) {
int8_t *y0 =
y + offset + h * ((input_h + paddings[0]) * output_w - paddings[1]);
int8_t *y1 = y0 + output_spatial_size;
int8_t *y2 = y1 + output_spatial_size;
int8_t *y3 = y2 + output_spatial_size;
int loop = start >> 4;
int remain = start & 0xF;
asm volatile(
"vdup.s8 q0, %[val] \n"
"cmp %[loop], #0 \n"
"ble start_remain_%= \n"
"store_16w_%=: \n"
"vst1.32 {q0}, [%[y0]]! \n"
"vst1.32 {q0}, [%[y1]]! \n"
"vst1.32 {q0}, [%[y2]]! \n"
"vst1.32 {q0}, [%[y3]]! \n"
"subs %[loop], #1 \n"
"bne store_16w_%= \n"
"start_remain_%=: \n"
"cmp %[remain], #8 \n"
"blt store_4w_%= \n"
"vst1.32 {d0}, [%[y0]]! \n"
"vst1.32 {d0}, [%[y1]]! \n"
"vst1.32 {d0}, [%[y2]]! \n"
"vst1.32 {d0}, [%[y3]]! \n"
"sub %[remain], #8 \n"
"store_4w_%=: \n"
"cmp %[remain], #4 \n"
"blt store_2w_%= \n"
"vst1.32 {d0[0]}, [%[y0]]! \n"
"vst1.32 {d0[0]}, [%[y1]]! \n"
"vst1.32 {d0[0]}, [%[y2]]! \n"
"vst1.32 {d0[0]}, [%[y3]]! \n"
"sub %[remain], #4 \n"
"store_2w_%=: \n"
"cmp %[remain], #4 \n"
"blt store_1w_%= \n"
"vst1.16 {d0[0]}, [%[y0]]! \n"
"vst1.16 {d0[0]}, [%[y1]]! \n"
"vst1.16 {d0[0]}, [%[y2]]! \n"
"vst1.16 {d0[0]}, [%[y3]]! \n"
"sub %[remain], #2 \n"
"store_1w_%=: \n"
"cmp %[remain], #1 \n"
"blt end_%= \n"
"vst1.8 {d0[0]}, [%[y0]]! \n"
"vst1.8 {d0[0]}, [%[y1]]! \n"
"vst1.8 {d0[0]}, [%[y2]]! \n"
"vst1.8 {d0[0]}, [%[y3]]! \n"
"end_%=: \n"
: [y0] "+r"(y0), [y1] "+r"(y1), [y2] "+r"(y2), [y3] "+r"(y3),
[loop] "+r"(loop), [remain] "+r"(remain)
: [val] "r"(padding_val)
: "cc", "memory", "q0");
}
// quantize valid area
int8_t *y0 = y + offset + start;
int8_t *y1 = y0 + output_spatial_size;
int8_t *y2 = y1 + output_spatial_size;
int8_t *y3 = y2 + output_spatial_size;
for (int h = 0; h < input_h; ++h) {
const float *x0 = input0 + h * input_w;
const float *x1 = input1 + h * input_w;
const float *x2 = input2 + h * input_w;
const float *x3 = input3 + h * input_w;
int loop = input_w >> 4;
int remain = input_w & 0xF;
int pad_loop = paddings[1] >> 1; // (paddings[1] << 1) >> 2
int pad_remain = (paddings[1] << 1) & 0x3;
int remain_steps = remain;
asm volatile(
"vdup.f32 q0, %[scale] \n"
"cmp %[loop], #0 \n"
"ble quantize_remain_%= \n"
"loop_quantize_%=: \n"
"vld1.32 {q1, q2}, [%[x0]]! \n"
"vld1.32 {q3, q4}, [%[x1]]! \n"
"vld1.32 {q5, q6}, [%[x2]]! \n"
"vld1.32 {q7, q8}, [%[x3]]! \n"
"vmul.f32 q1, q1, q0 \n"
"vmul.f32 q2, q2, q0 \n"
"vmul.f32 q3, q3, q0 \n"
"vmul.f32 q4, q4, q0 \n"
"vmul.f32 q5, q5, q0 \n"
"vmul.f32 q6, q6, q0 \n"
"vmul.f32 q7, q7, q0 \n"
"vmul.f32 q8, q8, q0 \n"
"vcvt.s32.f32 q1, q1 \n"
"vcvt.s32.f32 q2, q2 \n"
"vcvt.s32.f32 q3, q3 \n"
"vcvt.s32.f32 q4, q4 \n"
"vcvt.s32.f32 q5, q5 \n"
"vcvt.s32.f32 q6, q6 \n"
"vcvt.s32.f32 q7, q7 \n"
"vcvt.s32.f32 q8, q8 \n"
"vmovn.s32 d2, q1 \n"
"vmovn.s32 d3, q2 \n"
"vmovn.s32 d4, q3 \n"
"vmovn.s32 d5, q4 \n"
"vmovn.s32 d6, q5 \n"
"vmovn.s32 d7, q6 \n"
"vmovn.s32 d8, q7 \n"
"vmovn.s32 d9, q8 \n"
"vmovn.s16 d18, q1 \n"
"vmovn.s16 d20, q2 \n"
"vmovn.s16 d22, q3 \n"
"vmovn.s16 d24, q4 \n"
"vld1.32 {q1, q2}, [%[x0]]! \n"
"vld1.32 {q3, q4}, [%[x1]]! \n"
"vld1.32 {q5, q6}, [%[x2]]! \n"
"vld1.32 {q7, q8}, [%[x3]]! \n"
"vmul.f32 q1, q1, q0 \n"
"vmul.f32 q2, q2, q0 \n"
"vmul.f32 q3, q3, q0 \n"
"vmul.f32 q4, q4, q0 \n"
"vmul.f32 q5, q5, q0 \n"
"vmul.f32 q6, q6, q0 \n"
"vmul.f32 q7, q7, q0 \n"
"vmul.f32 q8, q8, q0 \n"
"vcvt.s32.f32 q1, q1 \n"
"vcvt.s32.f32 q2, q2 \n"
"vcvt.s32.f32 q3, q3 \n"
"vcvt.s32.f32 q4, q4 \n"
"vcvt.s32.f32 q5, q5 \n"
"vcvt.s32.f32 q6, q6 \n"
"vcvt.s32.f32 q7, q7 \n"
"vcvt.s32.f32 q8, q8 \n"
"vmovn.s32 d2, q1 \n"
"vmovn.s32 d3, q2 \n"
"vmovn.s32 d4, q3 \n"
"vmovn.s32 d5, q4 \n"
"vmovn.s32 d6, q5 \n"
"vmovn.s32 d7, q6 \n"
"vmovn.s32 d8, q7 \n"
"vmovn.s32 d9, q8 \n"
"vmovn.s16 d19, q1 \n"
"vmovn.s16 d21, q2 \n"
"vmovn.s16 d23, q3 \n"
"vmovn.s16 d25, q4 \n"
"vst1.32 {q9}, [%[y0]]! \n"
"vst1.32 {q10}, [%[y1]]! \n"
"vst1.32 {q11}, [%[y2]]! \n"
"vst1.32 {q12}, [%[y3]]! \n"
"subs %[loop], #1 \n"
"bne loop_quantize_%= \n"
"quantize_remain_%=: \n"
"cmp %[remain], #0 \n"
"ble end_%= \n"
"vld1.32 {q1, q2}, [%[x0]]! \n"
"vld1.32 {q3, q4}, [%[x1]]! \n"
"vld1.32 {q5, q6}, [%[x2]]! \n"
"vld1.32 {q7, q8}, [%[x3]]! \n"
"vmul.f32 q1, q1, q0 \n"
"vmul.f32 q2, q2, q0 \n"
"vmul.f32 q3, q3, q0 \n"
"vmul.f32 q4, q4, q0 \n"
"vmul.f32 q5, q5, q0 \n"
"vmul.f32 q6, q6, q0 \n"
"vmul.f32 q7, q7, q0 \n"
"vmul.f32 q8, q8, q0 \n"
"vcvt.s32.f32 q1, q1 \n"
"vcvt.s32.f32 q2, q2 \n"
"vcvt.s32.f32 q3, q3 \n"
"vcvt.s32.f32 q4, q4 \n"
"vcvt.s32.f32 q5, q5 \n"
"vcvt.s32.f32 q6, q6 \n"
"vcvt.s32.f32 q7, q7 \n"
"vcvt.s32.f32 q8, q8 \n"
"vmovn.s32 d2, q1 \n"
"vmovn.s32 d3, q2 \n"
"vmovn.s32 d4, q3 \n"
"vmovn.s32 d5, q4 \n"
"vmovn.s32 d6, q5 \n"
"vmovn.s32 d7, q6 \n"
"vmovn.s32 d8, q7 \n"
"vmovn.s32 d9, q8 \n"
"vmovn.s16 d18, q1 \n"
"vmovn.s16 d20, q2 \n"
"vmovn.s16 d22, q3 \n"
"vmovn.s16 d24, q4 \n"
"vld1.32 {q1, q2}, [%[x0]] \n"
"vld1.32 {q3, q4}, [%[x1]] \n"
"vld1.32 {q5, q6}, [%[x2]] \n"
"vld1.32 {q7, q8}, [%[x3]] \n"
"vmul.f32 q1, q1, q0 \n"
"vmul.f32 q2, q2, q0 \n"
"vmul.f32 q3, q3, q0 \n"
"vmul.f32 q4, q4, q0 \n"
"vmul.f32 q5, q5, q0 \n"
"vmul.f32 q6, q6, q0 \n"
"vmul.f32 q7, q7, q0 \n"
"vmul.f32 q8, q8, q0 \n"
"vcvt.s32.f32 q1, q1 \n"
"vcvt.s32.f32 q2, q2 \n"
"vcvt.s32.f32 q3, q3 \n"
"vcvt.s32.f32 q4, q4 \n"
"vcvt.s32.f32 q5, q5 \n"
"vcvt.s32.f32 q6, q6 \n"
"vcvt.s32.f32 q7, q7 \n"
"vcvt.s32.f32 q8, q8 \n"
"vmovn.s32 d2, q1 \n"
"vmovn.s32 d3, q2 \n"
"vmovn.s32 d4, q3 \n"
"vmovn.s32 d5, q4 \n"
"vmovn.s32 d6, q5 \n"
"vmovn.s32 d7, q6 \n"
"vmovn.s32 d8, q7 \n"
"vmovn.s32 d9, q8 \n"
"vmovn.s16 d19, q1 \n"
"vmovn.s16 d21, q2 \n"
"vmovn.s16 d23, q3 \n"
"vmovn.s16 d25, q4 \n"
"cmp %[remain], #8 \n"
"blt store_4w_%= \n"
"vst1.32 {d18}, [%[y0]]! \n"
"vst1.32 {d20}, [%[y1]]! \n"
"vst1.32 {d22}, [%[y2]]! \n"
"vst1.32 {d24}, [%[y3]]! \n"
"vmov.32 d18, d19 \n"
"vmov.32 d20, d21 \n"
"vmov.32 d22, d23 \n"
"vmov.32 d24, d25 \n"
"sub %[remain], #8 \n"
"store_4w_%=: \n"
"cmp %[remain], #4 \n"
"blt store_2w_%= \n"
"vst1.32 {d18[0]}, [%[y0]]! \n"
"vst1.32 {d20[0]}, [%[y1]]! \n"
"vst1.32 {d22[0]}, [%[y2]]! \n"
"vst1.32 {d24[0]}, [%[y3]]! \n"
"vext.32 d18, d18, d18, #1 \n"
"vext.32 d20, d20, d20, #1 \n"
"vext.32 d22, d22, d22, #1 \n"
"vext.32 d24, d24, d24, #1 \n"
"sub %[remain], #4 \n"
"store_2w_%=: \n"
"cmp %[remain], #2 \n"
"blt store_1w_%= \n"
"vst1.16 {d18[0]}, [%[y0]]! \n"
"vst1.16 {d20[0]}, [%[y1]]! \n"
"vst1.16 {d22[0]}, [%[y2]]! \n"
"vst1.16 {d24[0]}, [%[y3]]! \n"
"vext.16 d18, d18, d18, #1 \n"
"vext.16 d20, d20, d20, #1 \n"
"vext.16 d22, d22, d22, #1 \n"
"vext.16 d24, d24, d24, #1 \n"
"sub %[remain], #2 \n"
"store_1w_%=:"
"cmp %[remain], #1 \n"
"blt end_%= \n"
"vst1.8 {d18[0]}, [%[y0]]! \n"
"vst1.8 {d20[0]}, [%[y1]]! \n"
"vst1.8 {d22[0]}, [%[y2]]! \n"
"vst1.8 {d24[0]}, [%[y3]]! \n"
"end_%=: \n"
: [x0] "+r"(x0), [x1] "+r"(x1), [x2] "+r"(x2), [x3] "+r"(x3),
[y0] "+r"(y0), [y1] "+r"(y1), [y2] "+r"(y2), [y3] "+r"(y3),
[loop] "+r"(loop), [remain] "+r"(remain)
: [scale] "r"(scale)
: "cc", "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7",
"q8", "q9", "q10", "q11", "q12");
asm volatile(
"vdup.s8 d0, %[val] \n"
"cmp %[pad_loop], #0 \n"
"ble store_pad_2w_%= \n"
"loop_pad_4w_%=: \n"
"vst1.32 {d0[0]}, [%[y0]]! \n"
"vst1.32 {d0[0]}, [%[y1]]! \n"
"vst1.32 {d0[0]}, [%[y2]]! \n"
"vst1.32 {d0[0]}, [%[y3]]! \n"
"subs %[pad_loop], #1 \n"
"bne loop_pad_4w_%= \n"
"store_pad_2w_%=: \n"
"cmp %[pad_remain], #2 \n"
"blt store_pad_1w_%= \n"
"vst1.16 {d0[0]}, [%[y0]]! \n"
"vst1.16 {d0[0]}, [%[y1]]! \n"
"vst1.16 {d0[0]}, [%[y2]]! \n"
"vst1.16 {d0[0]}, [%[y3]]! \n"
"sub %[pad_remain], #2 \n"
"store_pad_1w_%=: \n"
"cmp %[pad_remain], #1 \n"
"blt end_%= \n"
"vst1.8 {d0[0]}, [%[y0]]! \n"
"vst1.8 {d0[0]}, [%[y1]]! \n"
"vst1.8 {d0[0]}, [%[y2]]! \n"
"vst1.8 {d0[0]}, [%[y3]]! \n"
"end_%=: \n"
: [y0] "+r"(y0), [y1] "+r"(y1), [y2] "+r"(y2), [y3] "+r"(y3),
[pad_loop] "+r"(pad_loop), [pad_remain] "+r"(pad_remain)
: [val] "r"(padding_val)
: "cc", "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7",
"q8", "q9", "q10", "q11", "q12");
}
}
for (int c = (channels & 0xFFFC); c < channels; ++c) {
const float *input0 = x + (batch * channels + c) * input_spatial_size;
size_t offset = (batch * channels + c) * output_spatial_size;
for (int h = 0; h < 2; ++h) {
int8_t *y0 =
y + offset + h * ((input_h + paddings[0]) * output_w - paddings[1]);
int loop = start >> 4;
int remain = start & 0xF;
asm volatile(
"vdup.s8 q0, %[val] \n"
"cmp %[loop], #0 \n"
"ble start_remain_%= \n"
"store_16w_%=: \n"
"vst1.32 {q0}, [%[y0]]! \n"
"subs %[loop], #1 \n"
"bne store_16w_%= \n"
"start_remain_%=: \n"
"cmp %[remain], #8 \n"
"blt store_4w_%= \n"
"vst1.32 {d0}, [%[y0]]! \n"
"sub %[remain], #8 \n"
"store_4w_%=: \n"
"cmp %[remain], #4 \n"
"blt store_2w_%= \n"
"vst1.32 {d0[0]}, [%[y0]]! \n"
"sub %[remain], #4 \n"
"store_2w_%=: \n"
"cmp %[remain], #4 \n"
"blt store_1w_%= \n"
"vst1.16 {d0[0]}, [%[y0]]! \n"
"sub %[remain], #2 \n"
"store_1w_%=: \n"
"cmp %[remain], #1 \n"
"blt end_%= \n"
"vst1.8 {d0[0]}, [%[y0]]! \n"
"end_%=: \n"
: [y0] "+r"(y0), [loop] "+r"(loop), [remain] "+r"(remain)
: [val] "r"(padding_val)
: "cc", "memory", "q0");
}
// quantize valid area
int8_t *y0 = y + offset + start;
for (int h = 0; h < input_h; ++h) {
const float *x0 = input0 + h * input_w;
int loop = input_w >> 4;
int remain = input_w & 0xF;
int pad_loop = paddings[1] >> 1; // (paddings[1] << 1) >> 2
int pad_remain = (paddings[1] << 1) & 0x3;
asm volatile(
"vdup.f32 q0, %[scale] \n"
"cmp %[loop], #0 \n"
"ble quantize_remain_%= \n"
"loop_quantize_%=: \n"
"vld1.32 {q1, q2}, [%[x0]]! \n"
"vmul.f32 q1, q1, q0 \n"
"vmul.f32 q2, q2, q0 \n"
"vcvt.s32.f32 q1, q1 \n"
"vcvt.s32.f32 q2, q2 \n"
"vmovn.s32 d2, q1 \n"
"vmovn.s32 d3, q2 \n"
"vmovn.s16 d18, q1 \n"
"vld1.32 {q1, q2}, [%[x0]]! \n"
"vmul.f32 q1, q1, q0 \n"
"vmul.f32 q2, q2, q0 \n"
"vcvt.s32.f32 q1, q1 \n"
"vcvt.s32.f32 q2, q2 \n"
"vmovn.s32 d2, q1 \n"
"vmovn.s32 d3, q2 \n"
"vmovn.s16 d19, q1 \n"
"vst1.32 {q9}, [%[y0]]! \n"
"subs %[loop], #1 \n"
"bne loop_quantize_%= \n"
"quantize_remain_%=: \n"
"cmp %[remain], #0 \n"
"ble start_pad_%= \n"
"vldm %[x0], {d2-d9} \n"
"vmul.f32 q1, q1, q0 \n"
"vmul.f32 q2, q2, q0 \n"
"vcvt.s32.f32 q1, q1 \n"
"vcvt.s32.f32 q2, q2 \n"
"vmovn.s32 d2, q1 \n"
"vmovn.s32 d3, q2 \n"
"vmovn.s16 d18, q1 \n"
"vmul.f32 q3, q3, q0 \n"
"vmul.f32 q4, q4, q0 \n"
"vcvt.s32.f32 q1, q3 \n"
"vcvt.s32.f32 q2, q4 \n"
"vmovn.s32 d2, q1 \n"
"vmovn.s32 d3, q2 \n"
"vmovn.s16 d19, q1 \n"
"cmp %[remain], #8 \n"
"blt store_4w_%= \n"
"vst1.32 {d18}, [%[y0]]! \n"
"vmov.32 d18, d19 \n"
"sub %[remain], #8 \n"
"store_4w_%=: \n"
"cmp %[remain], #4 \n"
"blt store_2w_%= \n"
"vst1.32 {d18[0]}, [%[y0]]! \n"
"vext.32 d18, d18, d18, #1 \n"
"sub %[remain], #4 \n"
"store_2w_%=: \n"
"cmp %[remain], #2 \n"
"blt store_1w_%= \n"
"vst1.16 {d18[0]}, [%[y0]]! \n"
"vext.16 d18, d18, d18, #1 \n"
"sub %[remain], #2 \n"
"store_1w_%=:"
"cmp %[remain], #1 \n"
"blt start_pad_%= \n"
"vst1.8 {d18[0]}, [%[y0]]! \n"
"start_pad_%=: \n"
"vdup.s8 d0, %[val] \n"
"cmp %[pad_loop], #0 \n"
"ble pad_remain_%= \n"
"loop_pad_4w_%=: \n"
"vst1.32 {d0[0]}, [%[y0]]! \n"
"subs %[pad_loop], #1 \n"
"bne loop_pad_4w_%= \n"
"pad_remain_%=: \n"
"cmp %[pad_remain], #2 \n"
"blt store_pad_1w_%= \n"
"vst1.16 {d0[0]}, [%[y0]]! \n"
"sub %[pad_remain], #2 \n"
"store_pad_1w_%=: \n"
"cmp %[pad_remain], #1 \n"
"blt end_%= \n"
"vst1.8 {d0[0]}, [%[y0]]! \n"
"end_%=: \n"
: [x0] "+r"(x0), [y0] "+r"(y0), [loop] "+r"(loop),
[remain] "+r"(remain), [pad_loop] "+r"(pad_loop),
[pad_remain] "+r"(pad_remain)
: [scale] "r"(scale), [val] "r"(padding_val)
: "cc", "memory", "q0", "q1", "q2", "q3", "q4", "q9");
}
}
} }
return max_abs;
} }
#endif // __aarch64__
#endif // ARM_NEON
template <> template <>
bool QuantizeKernel<CPU, float>::Init(QuantizeParam<CPU> *param) { bool QuantizeKernel<CPU, float>::Init(QuantizeParam<CPU> *param) {
...@@ -799,19 +191,15 @@ void QuantizeKernel<CPU, float>::Compute(const QuantizeParam<CPU> &param) { ...@@ -799,19 +191,15 @@ void QuantizeKernel<CPU, float>::Compute(const QuantizeParam<CPU> &param) {
// only support int8 currently // only support int8 currently
float scale = 127 / max_abs; float scale = 127 / max_abs;
param.online_scale_->mutable_data<float>()[0] = max_abs; param.online_scale_->mutable_data<float>()[0] = max_abs;
const auto &paddings = param.paddings_;
// std::vector<int> paddings = {0, 0};
// const auto padding_val = param.padding_val_;
int8_t padding_val = 0;
switch (param.round_type_) { switch (param.round_type_) {
case ROUND_NEAREST_TO_EVEN: case ROUND_NEAREST_TO_EVEN:
quantize_round_to_even(input, scale, paddings, padding_val, output); Quantize<ROUND_NEAREST_TO_EVEN>(input, scale, output);
break; break;
case ROUND_NEAREST_TOWARDS_ZERO: case ROUND_NEAREST_TOWARDS_ZERO:
quantize_round_to_zero(input, scale, paddings, padding_val, output); Quantize<ROUND_NEAREST_TOWARDS_ZERO>(input, scale, output);
break; break;
case ROUND_NEAREST_AWAY_ZERO: case ROUND_NEAREST_AWAY_ZERO:
quantize_round_to_nearest(input, scale, paddings, padding_val, output); Quantize<ROUND_NEAREST_AWAY_ZERO>(input, scale, output);
break; break;
default: default:
LOG(kLOG_ERROR) << "round type is not supported."; LOG(kLOG_ERROR) << "round type is not supported.";
......
...@@ -170,31 +170,21 @@ template <typename Itype, typename Otype> ...@@ -170,31 +170,21 @@ template <typename Itype, typename Otype>
inline void DepthwiseConv3x3(const ConvParam<CPU> &param) { inline void DepthwiseConv3x3(const ConvParam<CPU> &param) {
const Tensor *input = param.Input(); const Tensor *input = param.Input();
const Tensor *filter = param.Filter(); const Tensor *filter = param.Filter();
const std::vector<int> &paddings = param.Paddings();
const std::vector<int> &strides = param.Strides();
const int batch_size = input->dims()[0];
Tensor *output = param.Output(); Tensor *output = param.Output();
output->mutable_data<Otype>(); output->mutable_data<Otype>();
const std::vector<int> &paddings = param.Paddings();
const std::vector<int> &strides = param.Strides();
const int batch_size = static_cast<int>(input->dims()[0]);
Tensor input_pad;
math::PadFunctor<CPU, Itype> pad;
for (int i = 0; i < batch_size; i++) { for (int i = 0; i < batch_size; i++) {
Tensor in_batch = input->Slice(i, i + 1); Tensor in_batch = input->Slice(i, i + 1);
Tensor out_batch = output->Slice(i, i + 1); Tensor out_batch = output->Slice(i, i + 1);
if (paddings[0] || paddings[1]) {
framework::DDim pad_shape = in_batch.dims();
pad_shape[2] += 2 * paddings[0];
pad_shape[3] += 2 * paddings[1];
input_pad.mutable_data<float>(pad_shape);
pad(in_batch, paddings[0], paddings[0], paddings[1], paddings[1],
&input_pad);
} else {
input_pad = in_batch;
}
if (strides[0] == 1) { if (strides[0] == 1) {
math::DepthwiseConv3x3s1<Itype, Otype>(input_pad, *filter, &out_batch); math::DepthwiseConv3x3s1<Itype, Otype>(in_batch, *filter, paddings,
&out_batch);
} else if (strides[0] == 2) { } else if (strides[0] == 2) {
math::DepthwiseConv3x3s2<Itype, Otype>(input_pad, *filter, &out_batch); math::DepthwiseConv3x3s2<Itype, Otype>(in_batch, *filter, paddings,
&out_batch);
} else { } else {
// math::DepthwiseConv3x3<Itype, Otype>(input_pad, *filter, // math::DepthwiseConv3x3<Itype, Otype>(input_pad, *filter,
// &out_batch); // &out_batch);
......
...@@ -1278,7 +1278,10 @@ void DepthwiseConv3x3s2p1v2(const framework::Tensor *input, ...@@ -1278,7 +1278,10 @@ void DepthwiseConv3x3s2p1v2(const framework::Tensor *input,
const float *input_data = input->data<float>(); const float *input_data = input->data<float>();
const float *filter_data = filter->data<float>(); const float *filter_data = filter->data<float>();
float *output_data = output->data<float>(); float *output_data = output->data<float>();
const float *bias_data = bias->data<float>(); const float *bias_data;
if (if_bias) {
bias_data = bias->data<float>();
}
const int in_h = static_cast<int>(input->dims()[2]); const int in_h = static_cast<int>(input->dims()[2]);
const int in_w = static_cast<int>(input->dims()[3]); const int in_w = static_cast<int>(input->dims()[3]);
......
...@@ -70,16 +70,19 @@ void DepthwiseConv3x3s2p0(const framework::Tensor *input, ...@@ -70,16 +70,19 @@ void DepthwiseConv3x3s2p0(const framework::Tensor *input,
// void DepthwiseConv3x3(const framework::Tensor *input, // void DepthwiseConv3x3(const framework::Tensor *input,
// const framework::Tensor *filter, // const framework::Tensor *filter,
// const std::vector<int> &strides, // const std::vector<int> &strides,
// const std::vector<int> &paddings,
// framework::Tensor *output); // framework::Tensor *output);
template <typename Itype, typename Otype> template <typename Itype, typename Otype>
void DepthwiseConv3x3s1(const framework::Tensor &input, void DepthwiseConv3x3s1(const framework::Tensor &input,
const framework::Tensor &filter, const framework::Tensor &filter,
const std::vector<int> &paddings,
framework::Tensor *output); framework::Tensor *output);
template <typename Itype, typename Otype> template <typename Itype, typename Otype>
void DepthwiseConv3x3s2(const framework::Tensor &input, void DepthwiseConv3x3s2(const framework::Tensor &input,
const framework::Tensor &filter, const framework::Tensor &filter,
const std::vector<int> &paddings,
framework::Tensor *output); framework::Tensor *output);
} // namespace math } // namespace math
......
...@@ -29,6 +29,7 @@ namespace math { ...@@ -29,6 +29,7 @@ namespace math {
template <> template <>
void DepthwiseConv3x3s1<int8_t, int32_t>(const framework::Tensor &input, void DepthwiseConv3x3s1<int8_t, int32_t>(const framework::Tensor &input,
const framework::Tensor &filter, const framework::Tensor &filter,
const std::vector<int> &paddings,
framework::Tensor *output) { framework::Tensor *output) {
const int8_t *input_data = input.data<int8_t>(); const int8_t *input_data = input.data<int8_t>();
const int8_t *filter_data = filter.data<int8_t>(); const int8_t *filter_data = filter.data<int8_t>();
...@@ -751,6 +752,7 @@ void DepthwiseConv3x3s1<int8_t, int32_t>(const framework::Tensor &input, ...@@ -751,6 +752,7 @@ void DepthwiseConv3x3s1<int8_t, int32_t>(const framework::Tensor &input,
template <> template <>
void DepthwiseConv3x3s2<int8_t, int32_t>(const framework::Tensor &input, void DepthwiseConv3x3s2<int8_t, int32_t>(const framework::Tensor &input,
const framework::Tensor &filter, const framework::Tensor &filter,
const std::vector<int> &paddings,
framework::Tensor *output) { framework::Tensor *output) {
const int8_t *input_data = input.data<int8_t>(); const int8_t *input_data = input.data<int8_t>();
const int8_t *filter_data = filter.data<int8_t>(); const int8_t *filter_data = filter.data<int8_t>();
......
...@@ -405,9 +405,9 @@ class ConvParam : public OpParam { ...@@ -405,9 +405,9 @@ class ConvParam : public OpParam {
const RType *Input() const { return input_; } const RType *Input() const { return input_; }
RType *&Filter() const { return filter_; } RType *Filter() const { return filter_; }
RType *&Output() const { return output_; } RType *Output() const { return output_; }
const vector<int> &Strides() const { return strides_; } const vector<int> &Strides() const { return strides_; }
...@@ -441,8 +441,8 @@ class ConvParam : public OpParam { ...@@ -441,8 +441,8 @@ class ConvParam : public OpParam {
private: private:
RType *input_; RType *input_;
mutable RType *output_; RType *output_;
mutable RType *filter_; RType *filter_;
vector<int> strides_; vector<int> strides_;
vector<int> paddings_; vector<int> paddings_;
vector<int> dilations_; vector<int> dilations_;
......
...@@ -44,25 +44,19 @@ struct Round<round::RoundTowardsZero> { ...@@ -44,25 +44,19 @@ struct Round<round::RoundTowardsZero> {
template <> template <>
struct Round<round::RoundToEven> { struct Round<round::RoundToEven> {
int8_t operator()(float x) { int8_t operator()(float x) {
int8_t ret = 0;
float v = std::round(x); float v = std::round(x);
int32_t q = (int32_t)v; int32_t q = static_cast<int32_t>(v);
if (abs(abs(q - x) - 0.5) > 0) { if (abs(abs(q - v) - 0.5) <= 0) {
ret = q; if (abs(q) % 2 != 0) {
} else { q = q + ((q > 0) ? -1 : 1);
if (abs(q) % 2 == 0) {
ret = q;
} else {
ret = q + ((q > 0) ? -1 : 1);
} }
} }
return ret; return static_cast<int8_t>(q);
} }
}; };
template <round::RoundType T> template <round::RoundType T>
static void quantize(const Tensor *input, const float scale, const int pad, static void quantize(const Tensor *input, const float scale, Tensor *output) {
const int8_t pad_val, Tensor *output) {
int batch_size = input->dims()[0]; int batch_size = input->dims()[0];
int channels = input->dims()[1]; int channels = input->dims()[1];
int input_h = input->dims()[2]; int input_h = input->dims()[2];
...@@ -77,29 +71,9 @@ static void quantize(const Tensor *input, const float scale, const int pad, ...@@ -77,29 +71,9 @@ static void quantize(const Tensor *input, const float scale, const int pad,
for (int nc = 0; nc < batch_size * channels; ++nc) { for (int nc = 0; nc < batch_size * channels; ++nc) {
const float *xh = x + nc * input_spatial; const float *xh = x + nc * input_spatial;
int8_t *yh = y + nc * output_spatial; int8_t *yh = y + nc * output_spatial;
// pad top
for (int h = 0; h < pad; ++h, yh += output_w) {
for (int w = 0; w < output_w; ++w) {
yh[w] = pad_val;
}
}
for (int h = 0; h < input_h; ++h, yh += output_w, xh += input_w) { for (int h = 0; h < input_h; ++h, yh += output_w, xh += input_w) {
// pad left
for (int w = 0; w < pad; ++w) {
yh[w] = pad_val;
}
for (int w = 0; w < input_w; ++w) { for (int w = 0; w < input_w; ++w) {
yh[w + pad] = Round<T>()(xh[w] * scale); yh[w] = Round<T>()(xh[w] * scale);
}
// pad right
for (int w = 0; w < pad; ++w) {
yh[pad + input_w + w] = pad_val;
}
}
// pad bottom
for (int h = 0; h < pad; ++h, yh += output_w) {
for (int w = 0; w < output_w; ++w) {
yh[w] = pad_val;
} }
} }
} }
...@@ -120,19 +94,14 @@ static float find_abs_max(const Tensor *input) { ...@@ -120,19 +94,14 @@ static float find_abs_max(const Tensor *input) {
int TestQuqntizeOp(int argc, char *argv[]) { int TestQuqntizeOp(int argc, char *argv[]) {
if (argc < 5) { if (argc < 5) {
std::cout std::cout << "Usage: ./test-quantize-op batch_size channel height width"
<< "Usage: ./test-quantize-op batch_size channel height width [pad]" << std::endl;
<< std::endl;
return 1; return 1;
} }
int pad = 0;
int batch_size = atoi(argv[1]); int batch_size = atoi(argv[1]);
int channel = atoi(argv[2]); int channel = atoi(argv[2]);
int height = atoi(argv[3]); int height = atoi(argv[3]);
int width = atoi(argv[4]); int width = atoi(argv[4]);
if (argc == 6) {
pad = atoi(argv[5]);
}
std::cout << "batch_size: " << batch_size << ", channel: " << channel std::cout << "batch_size: " << batch_size << ", channel: " << channel
<< ", height: " << height << ", width: " << width << std::endl; << ", height: " << height << ", width: " << width << std::endl;
framework::DDim dim = framework::DDim dim =
...@@ -153,7 +122,6 @@ int TestQuqntizeOp(int argc, char *argv[]) { ...@@ -153,7 +122,6 @@ int TestQuqntizeOp(int argc, char *argv[]) {
auto output_scale_var = scope.get()->Var("output_scale"); auto output_scale_var = scope.get()->Var("output_scale");
framework::AttributeMap attrs; framework::AttributeMap attrs;
attrs["paddings"].Set<vector<int>>(std::vector<int>({pad, pad}));
auto *op = new operators::QuantizeOp<CPU, float>("quantize", inputs, outputs, auto *op = new operators::QuantizeOp<CPU, float>("quantize", inputs, outputs,
attrs, scope); attrs, scope);
op->InferShape(); op->InferShape();
...@@ -172,9 +140,9 @@ int TestQuqntizeOp(int argc, char *argv[]) { ...@@ -172,9 +140,9 @@ int TestQuqntizeOp(int argc, char *argv[]) {
framework::Tensor output_cmp; framework::Tensor output_cmp;
output_cmp.Resize(output->dims()); output_cmp.Resize(output->dims());
float scale = 127 / output_scale_cmp; float scale = 127 / output_scale_cmp;
// quantize<round::RoundToEven>(input, scale, pad, 0, &output_cmp); // quantize<round::RoundToEven>(input, scale, &output_cmp);
// quantize<round::RoundAwayZero>(input, scale, pad, 0, &output_cmp); // quantize<round::RoundAwayZero>(input, scale, &output_cmp);
quantize<round::RoundTowardsZero>(input, scale, pad, 0, &output_cmp); quantize<round::RoundTowardsZero>(input, scale, &output_cmp);
int8_t *output_cmp_data = output_cmp.data<int8_t>(); int8_t *output_cmp_data = output_cmp.data<int8_t>();
for (int i = 0; i < output->numel(); ++i) { for (int i = 0; i < output->numel(); ++i) {
PADDLE_MOBILE_ENFORCE(output_data[i] == output_cmp_data[i], PADDLE_MOBILE_ENFORCE(output_data[i] == output_cmp_data[i],
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册