Commit 022f1291 authored by hjchen2

Optimize elementwise/relu/im2col, support 1x1 and 7x7 conv using int8, fix some code style

Parent 3d475c9a
@@ -80,12 +80,13 @@ Executor<Dtype, P>::Executor(const framework::Program<Dtype> p, int batch_size,
 }
 
 template <typename Dtype>
-void LoadMemInternal(void **data, framework::LoDTensor *tensor) {
+static void LoadMemInternal(void **data, framework::LoDTensor *tensor,
+                            bool quant_uint8 = false) {
   char **data_buf = reinterpret_cast<char **>(data);
   int64_t size = tensor->numel();
   Dtype *tensor_data = tensor->mutable_data<Dtype>();
-  if (0) {
-    // TODO(hjchen2) should be moved into operator init function
+  if (quant_uint8) {
+    // should be moved into operator init function
     float min_value;
     float max_value;
     memcpy(&min_value, data_buf, sizeof(float));
@@ -141,7 +142,8 @@ void Executor<Dtype, P>::LoadMemory(
   // parse tensor from stream
   switch (tensor_desc.DataType()) {
     case framework::VARTYPE_TYPE_FP32:
-      LoadMemInternal<float>(reinterpret_cast<void **>(data_buf), tensor);
+      LoadMemInternal<float>(reinterpret_cast<void **>(data_buf), tensor,
+                             program_.quantification);
       break;
     case framework::VARTYPE_TYPE_INT8:
       LoadMemInternal<int8_t>(reinterpret_cast<void **>(data_buf), tensor);
...
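A note on the new `quant_uint8` path: `LoadMemInternal` now reads two floats (`min_value`, `max_value`) ahead of the quantized payload when the program is quantified. As a rough sketch of what that header implies, assuming the common linear uint8 mapping q -> min + q * (max - min) / 255 (the mapping itself is outside this hunk, so treat the helper name and formula as illustrative):

```cpp
#include <cstdint>
#include <cstring>
#include <vector>

// Hypothetical sketch of expanding uint8-quantized weights back to float
// using a min/max header like the one LoadMemInternal reads. Assumes the
// linear mapping q -> min + q * (max - min) / 255; the real operator may
// dequantize differently.
std::vector<float> DequantizeUint8(const char *buf, int64_t numel) {
  float min_value, max_value;
  std::memcpy(&min_value, buf, sizeof(float));
  std::memcpy(&max_value, buf + sizeof(float), sizeof(float));
  const uint8_t *q =
      reinterpret_cast<const uint8_t *>(buf + 2 * sizeof(float));
  const float factor = (max_value - min_value) / 255.f;
  std::vector<float> out(numel);
  for (int64_t i = 0; i < numel; ++i) {
    out[i] = min_value + factor * q[i];
  }
  return out;
}
```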
@@ -12,6 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */
 
+#ifdef DEQUANT_OP
+
 #include "operators/dequantize_op.h"
 
 namespace paddle_mobile {
@@ -30,3 +32,5 @@ namespace ops = paddle_mobile::operators;
 #ifdef PADDLE_MOBILE_CPU
 REGISTER_OPERATOR_CPU(dequantize, ops::DequantizeOp);
 #endif
+
+#endif
@@ -12,6 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */
 
+#ifdef DEQUANT_OP
+
 #pragma once
 
 #include <string>
@@ -41,3 +43,5 @@ class DequantizeOp
 }  // namespace operators
 }  // namespace paddle_mobile
+
+#endif
@@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */
 
-#ifdef PADDLE_MOBILE_CPU
+#ifdef DEQUANT_OP
 
 #include "operators/kernel/dequantize_kernel.h"
...
@@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */
 
-#ifdef PADDLE_MOBILE_CPU
+#ifdef QUANT_OP
 
 #include "operators/kernel/quantize_kernel.h"
 #include <cmath>
@@ -225,7 +225,7 @@ static void quantize_round_to_nearest(const Tensor *input, const float scale,
   const float *x = input->data<const float>();
   int8_t *y = output->mutable_data<int8_t>();
   size_t size = input->numel();
-#ifdef defined(__ARM_NEON__) || defined(__ARM_NEON)
+#if defined(__ARM_NEON__) || defined(__ARM_NEON)
   size_t loop = size >> 4;
   size_t remain = size & 0xF;
   for (size_t i = 0; i < loop; ++i) {
...
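The `#ifdef defined(...)` to `#if defined(...)` change above is a real fix, not style: `#ifdef` tests a single macro name, so the NEON branch of `quantize_round_to_nearest` was never compiled before. For orientation, a scalar sketch of the same rounding (the NEON body processes 16 floats per iteration; the sketch ignores any saturation the real kernel may apply):

```cpp
#include <cmath>
#include <cstddef>
#include <cstdint>

// Scalar reference for quantize_round_to_nearest: scale each float, round to
// the nearest integer, and narrow to int8. Sketch only; out-of-range values
// are not clamped here.
void QuantizeRoundToNearestScalar(const float *x, float scale, size_t size,
                                  int8_t *y) {
  for (size_t i = 0; i < size; ++i) {
    y[i] = static_cast<int8_t>(std::round(x[i] * scale));
  }
}
```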
@@ -15,8 +15,12 @@ limitations under the License. */
 #ifdef ELEMENTWISEADD_OP
 
 #pragma once
 
 #include "operators/math/elementwise_op_function.h"
 #include "operators/op_param.h"
+#if defined(__ARM_NEON__) || defined(__ARM_NEON)
+#include <arm_neon.h>
+#endif
 
 namespace paddle_mobile {
 namespace operators {
@@ -33,8 +37,50 @@ void ElementwiseAddCompute(const ElementwiseAddParam<CPU> &param) {
   Tensor *Out = param.Out();
   Out->mutable_data<float>();
   int axis = param.Axis();
+#if defined(__ARM_NEON__) || defined(__ARM_NEON)
+  size_t batch = 1;
+  size_t elementwise_num = 1;
+  for (int i = 0; i < axis; ++i) {
+    batch *= input_x->dims()[i];
+  }
+  for (int i = axis + 1; i < input_x->dims().size(); ++i) {
+    elementwise_num *= input_x->dims()[i];
+  }
+  for (int i = 0; i < batch; ++i) {
+    for (int j = 0; j < input_x->dims()[axis]; ++j) {
+      size_t offset = (i * input_x->dims()[axis] + j) * elementwise_num;
+      const float *input = input_x->data<float>() + offset;
+      const float *bias = input_y->data<float>() + j;
+      float *output = Out->mutable_data<float>() + offset;
+      int loop = elementwise_num >> 0x4;
+      int remain = elementwise_num & 0xF;
+      for (int k = 0; k < loop; ++k) {
+        float32x4_t rb = vdupq_n_f32(*bias);
+        float32x4_t r0 = vld1q_f32(input);
+        float32x4_t r1 = vld1q_f32(input + 4);
+        float32x4_t r2 = vld1q_f32(input + 8);
+        float32x4_t r3 = vld1q_f32(input + 12);
+        r0 = vaddq_f32(r0, rb);
+        r1 = vaddq_f32(r1, rb);
+        r2 = vaddq_f32(r2, rb);
+        r3 = vaddq_f32(r3, rb);
+        vst1q_f32(output, r0);
+        vst1q_f32(output + 4, r1);
+        vst1q_f32(output + 8, r2);
+        vst1q_f32(output + 12, r3);
+        input += 16;
+        output += 16;
+      }
+      for (int k = 0; k < remain; ++k) {
+        output[k] = input[k] + *bias;
+      }
+    }
+  }
+#else
   ElementwiseComputeEx<AddFunctor<float>, float>(input_x, input_y, axis,
                                                  AddFunctor<float>(), Out);
+#endif
 }
 
 template class ElementwiseAddKernel<CPU, float>;
...
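The NEON path above uses the standard broadcast decomposition: flatten `input_x` to `[batch, channels, elementwise_num]` around `axis` and add one scalar of `input_y` per channel. A plain-C++ sketch of the same indexing, with illustrative names:

```cpp
#include <cstddef>
#include <vector>

// Broadcast-add a 1-D bias along `axis` of an n-D tensor, mirroring the
// batch / elementwise_num split in ElementwiseAddCompute. Assumes
// bias has dims[axis] elements, as the kernel does.
void BroadcastAdd(const std::vector<int> &dims, int axis, const float *x,
                  const float *bias, float *out) {
  int batch = 1, channels = dims[axis], elementwise_num = 1;
  for (int i = 0; i < axis; ++i) batch *= dims[i];
  for (size_t i = axis + 1; i < dims.size(); ++i) elementwise_num *= dims[i];
  for (int i = 0; i < batch; ++i) {
    for (int j = 0; j < channels; ++j) {
      int offset = (i * channels + j) * elementwise_num;
      for (int k = 0; k < elementwise_num; ++k) {
        out[offset + k] = x[offset + k] + bias[j];
      }
    }
  }
}
```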
@@ -17,6 +17,9 @@ limitations under the License. */
 
 #include <operators/math/transform.h>
 #include "operators/op_param.h"
+#if defined(__ARM_NEON__) || defined(__ARM_NEON)
+#include <arm_neon.h>
+#endif
 
 namespace paddle_mobile {
 namespace operators {
@@ -37,71 +40,100 @@ void ReluCompute(const ReluParam<CPU> &param) {
   auto *out_ptr = out->mutable_data<float>();
   int numel = input_x->numel();
-  //  if (numel > 64) {
-  //    asm volatile(
-  //        "pld [%[input_x_ptr], #0]     \n\t"
-  //        "vmov.f32 q8, #0.0            \n\t"
-  //        "subs %[num], %[num], #32     \n\t"
-  //        "blt end_num_%=               \n\t"
-  //        "loop_num_%=:                 \n\t"
-  //        "pld [%[input_x_ptr], #1024]  \n\t"
-  //
-  //        "vld1.32 {q0, q1}, [%[input_x_ptr]]!   \n\t"
-  //        "vld1.32 {q2, q3}, [%[input_x_ptr]]!   \n\t"
-  //        "vld1.32 {q4, q5}, [%[input_x_ptr]]!   \n\t"
-  //        "vld1.32 {q6, q7}, [%[input_x_ptr]]!   \n\t"
-  //
-  //        "vmax.f32 q0, q0, q8          \n\t"
-  //        "vmax.f32 q1, q1, q8          \n\t"
-  //        "vmax.f32 q2, q2, q8          \n\t"
-  //        "vmax.f32 q3, q3, q8          \n\t"
-  //        "vmax.f32 q4, q4, q8          \n\t"
-  //        "vmax.f32 q5, q5, q8          \n\t"
-  //        "vmax.f32 q6, q6, q8          \n\t"
-  //        "vmax.f32 q7, q7, q8          \n\t"
-  //
-  //        "vst1.32 {q0, q1}, [%[out_ptr]]!    \n\t"
-  //        "vst1.32 {q2, q3}, [%[out_ptr]]!    \n\t"
-  //        "vst1.32 {q4, q5}, [%[out_ptr]]!    \n\t"
-  //        "vst1.32 {q6, q7}, [%[out_ptr]]!    \n\t"
-  //
-  //        "subs %[num], %[num], #32     \n\t"
-  //        "bge loop_num_%=              \n\t"
-  //        "end_num_%=:                  \n\t"
-  //        "cmp %[num], #0               \n\t"
-  //        "bge end_%=                   \n\t"
-  //        "mov r6, #4                   \n\t"
-  //        "mul r5, %[num], r6           \n\t"
-  //        "add %[input_x_ptr], %[input_x_ptr], r5 \n\t"
-  //        "vld1.32 {q0, q1}, [%[input_x_ptr]]!    \n\t"
-  //        "vld1.32 {q2, q3}, [%[input_x_ptr]]!    \n\t"
-  //        "vld1.32 {q4, q5}, [%[input_x_ptr]]!    \n\t"
-  //        "vld1.32 {q6, q7}, [%[input_x_ptr]]!    \n\t"
-  //        "vmax.f32 q0, q0, q8          \n\t"
-  //        "vmax.f32 q1, q1, q8          \n\t"
-  //        "vmax.f32 q2, q2, q8          \n\t"
-  //        "vmax.f32 q3, q3, q8          \n\t"
-  //        "vmax.f32 q4, q4, q8          \n\t"
-  //        "vmax.f32 q5, q5, q8          \n\t"
-  //        "vmax.f32 q6, q6, q8          \n\t"
-  //        "vmax.f32 q7, q7, q8          \n\t"
-  //        "add %[out_ptr], %[out_ptr], r5  \n\t"
-  //        "vst1.32 {q0, q1}, [%[out_ptr]]!    \n\t"
-  //        "vst1.32 {q2, q3}, [%[out_ptr]]!    \n\t"
-  //        "vst1.32 {q4, q5}, [%[out_ptr]]!    \n\t"
-  //        "vst1.32 {q6, q7}, [%[out_ptr]]!    \n\t"
-  //        "end_%=:                      \n\t"
-  //        :
-  //        :
-  //        [out_ptr] "r"(out_ptr), [input_x_ptr] "r"(input_x_ptr), [num]
-  //        "r"(numel) : "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q6",
-  //        "q7", "q8", "r5",
-  //        "r6");
-  //  } else {
+#if defined(__ARM_NEON__) || defined(__ARM_NEON)
+#if __aarch64__
+  if (numel > 0) {
+    int loop = numel >> 0x4;
+    int remain = numel & 0xF;
+    float32x4_t zero = vdupq_n_f32(0.f);
+    for (int i = 0; i < loop; ++i) {
+      float32x4_t r0 = vld1q_f32(input_x_ptr);
+      float32x4_t r1 = vld1q_f32(input_x_ptr + 4);
+      float32x4_t r2 = vld1q_f32(input_x_ptr + 8);
+      float32x4_t r3 = vld1q_f32(input_x_ptr + 12);
+      r0 = vmaxq_f32(r0, zero);
+      r1 = vmaxq_f32(r1, zero);
+      r2 = vmaxq_f32(r2, zero);
+      r3 = vmaxq_f32(r3, zero);
+      vst1q_f32(out_ptr, r0);
+      vst1q_f32(out_ptr + 4, r1);
+      vst1q_f32(out_ptr + 8, r2);
+      vst1q_f32(out_ptr + 12, r3);
+      input_x_ptr += 16;
+      out_ptr += 16;
+    }
+    for (int i = 0; i < remain; ++i) {
+      out_ptr[i] = (input_x_ptr[i] > 0) * input_x_ptr[i];
+    }
+#else
+  if (numel > 64) {
+    asm volatile(
+        "pld [%[input_x_ptr], #0]     \n\t"
+        "vmov.f32 q8, #0.0            \n\t"
+        "subs %[num], %[num], #32     \n\t"
+        "blt end_num_%=               \n\t"
+        "loop_num_%=:                 \n\t"
+        "pld [%[input_x_ptr], #1024]  \n\t"
+
+        "vld1.32 {q0, q1}, [%[input_x_ptr]]!   \n\t"
+        "vld1.32 {q2, q3}, [%[input_x_ptr]]!   \n\t"
+        "vld1.32 {q4, q5}, [%[input_x_ptr]]!   \n\t"
+        "vld1.32 {q6, q7}, [%[input_x_ptr]]!   \n\t"
+
+        "vmax.f32 q0, q0, q8          \n\t"
+        "vmax.f32 q1, q1, q8          \n\t"
+        "vmax.f32 q2, q2, q8          \n\t"
+        "vmax.f32 q3, q3, q8          \n\t"
+        "vmax.f32 q4, q4, q8          \n\t"
+        "vmax.f32 q5, q5, q8          \n\t"
+        "vmax.f32 q6, q6, q8          \n\t"
+        "vmax.f32 q7, q7, q8          \n\t"
+
+        "vst1.32 {q0, q1}, [%[out_ptr]]!    \n\t"
+        "vst1.32 {q2, q3}, [%[out_ptr]]!    \n\t"
+        "vst1.32 {q4, q5}, [%[out_ptr]]!    \n\t"
+        "vst1.32 {q6, q7}, [%[out_ptr]]!    \n\t"
+
+        "subs %[num], %[num], #32     \n\t"
+        "bge loop_num_%=              \n\t"
+        "end_num_%=:                  \n\t"
+        "cmp %[num], #0               \n\t"
+        "bge end_%=                   \n\t"
+        "mov r6, #4                   \n\t"
+        "mul r5, %[num], r6           \n\t"
+        "add %[input_x_ptr], %[input_x_ptr], r5 \n\t"
+        "vld1.32 {q0, q1}, [%[input_x_ptr]]!    \n\t"
+        "vld1.32 {q2, q3}, [%[input_x_ptr]]!    \n\t"
+        "vld1.32 {q4, q5}, [%[input_x_ptr]]!    \n\t"
+        "vld1.32 {q6, q7}, [%[input_x_ptr]]!    \n\t"
+        "vmax.f32 q0, q0, q8          \n\t"
+        "vmax.f32 q1, q1, q8          \n\t"
+        "vmax.f32 q2, q2, q8          \n\t"
+        "vmax.f32 q3, q3, q8          \n\t"
+        "vmax.f32 q4, q4, q8          \n\t"
+        "vmax.f32 q5, q5, q8          \n\t"
+        "vmax.f32 q6, q6, q8          \n\t"
+        "vmax.f32 q7, q7, q8          \n\t"
+        "add %[out_ptr], %[out_ptr], r5  \n\t"
+        "vst1.32 {q0, q1}, [%[out_ptr]]!    \n\t"
+        "vst1.32 {q2, q3}, [%[out_ptr]]!    \n\t"
+        "vst1.32 {q4, q5}, [%[out_ptr]]!    \n\t"
+        "vst1.32 {q6, q7}, [%[out_ptr]]!    \n\t"
+        "end_%=:                      \n\t"
+        :
+        :
+        [out_ptr] "r"(out_ptr), [input_x_ptr] "r"(input_x_ptr), [num] "r"(numel)
+        : "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", "q8", "r5",
+          "r6");
+#endif
+  } else {
+#endif
   ReluFunctor<float> func_;
   math::Transform trans;
   trans(input_x_ptr, input_x_ptr + numel, out_ptr, func_);
-  //  }
+#if defined(__ARM_NEON__) || defined(__ARM_NEON)
+  }
+#endif
 }
 
 }  // namespace operators
 }  // namespace paddle_mobile
...
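One detail worth noting in the aarch64 remainder loop: `(input_x_ptr[i] > 0) * input_x_ptr[i]` is a branchless scalar ReLU, since the comparison yields 0 or 1 and the multiply selects the value or zero. A minimal equivalent:

```cpp
// Branchless scalar ReLU, as used for the NEON tail above; equivalent to
// std::max(x, 0.f) but expressed without a branch or select.
inline float ReluScalar(float x) { return (x > 0) * x; }
```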
@@ -12,6 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */
 
+#ifdef DEQUANT_OP
+
 #pragma once
 
 #include "framework/operator.h"
@@ -30,3 +32,5 @@ class DequantizeKernel
 }  // namespace operators
 }  // namespace paddle_mobile
+
+#endif
@@ -12,6 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */
 
+#ifdef QUANT_OP
+
 #pragma once
 
 #include "framework/operator.h"
@@ -30,3 +32,5 @@ class QuantizeKernel
 }  // namespace operators
 }  // namespace paddle_mobile
+
+#endif
@@ -36,7 +36,9 @@ void conv3x3s1_int8(const framework::Tensor& input,
   int image_size = input_h * input_w;
   int out_image_size = output_h * output_w;
   memset(out_data, 0, output_c * out_image_size * sizeof(int32_t));
+#if __aarch64__
+  // TODO(hjchen2)
+#else
   int oc = 0;
   #pragma omp parallel for
   for (; oc < output_c - 1; oc += 2) {
@@ -747,7 +749,7 @@ void conv3x3s1_int8(const framework::Tensor& input,
       }
     }
   }
+#endif
 #else
   // TODO(hjchen2)
 #endif
...
@@ -36,7 +36,9 @@ void conv5x5s1_int8(const framework::Tensor& input,
   int image_size = input_h * input_w;
   int out_image_size = output_h * output_w;
   memset(out_data, 0, output_c * out_image_size * sizeof(int32_t));
+#if __aarch64__
+  // TODO(hjchen2)
+#else
   #pragma omp parallel for
   for (int oc = 0; oc < output_c; ++oc) {
     for (int ic = 0; ic < input_c; ++ic) {
@@ -537,6 +539,7 @@ void conv5x5s1_int8(const framework::Tensor& input,
       }
     }
   }
+#endif
 #else
   // TODO(hjchen2)
 #endif
...
@@ -642,6 +642,7 @@ void Gemm::WriteWithAlphaBeta(int32_t mc, int32_t nc, int32_t *c, int32_t *C,
 
 // C = A * B, 8-bit int32_t
 void Gemm::WriteBasic(int32_t mc, int32_t nc, int32_t *c, int32_t *C,
                       int32_t ldc) {
+#if __ARM_NEON
   int32_t nc1 = nc >> 4;
   int32_t _nc1 = nc & 15;
   int32_t step = sizeof(int32_t) * ldc;
@@ -695,6 +696,7 @@ void Gemm::WriteBasic(int32_t mc, int32_t nc, int32_t *c, int32_t *C,
       }
     }
   }
+#endif  // __ARM_NEON
 }
 
 // C = A * B + C
...
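`Gemm::WriteBasic` writes the finished mc x nc panel of int32 results from the packed accumulator buffer `c` into the output matrix `C` with leading dimension `ldc`. A scalar sketch of that copy, assuming `c` is packed row-major with row stride `nc` (the repository's packed layout and NEON path may differ):

```cpp
#include <cstdint>
#include <cstring>

// Scalar sketch of a WriteBasic-style panel copy: move an mc x nc block of
// int32 results from a contiguous buffer c into the output matrix C, whose
// rows are ldc elements apart. Assumes c is packed row-major with row
// stride nc; this is an illustration, not the repository's implementation.
void WriteBasicScalar(int32_t mc, int32_t nc, const int32_t *c, int32_t *C,
                      int32_t ldc) {
  for (int32_t i = 0; i < mc; ++i) {
    std::memcpy(C + i * ldc, c + i * nc, nc * sizeof(int32_t));
  }
}
```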
@@ -397,7 +397,7 @@ void Im2ColFunctor<ColFormat::kCFO, CPU, float>::operator()(
         col_data[col_idx] = (im_row_idx < 0 || im_row_idx >= im_height ||
                              im_col_idx < 0 || im_col_idx >= im_width)
-                                ? static_cast<T>(0)
+                                ? static_cast<float>(0)
                                 : im_data[im_idx];
       }
     }
@@ -405,10 +405,68 @@ void Im2ColFunctor<ColFormat::kCFO, CPU, float>::operator()(
 #endif
 }
 
-// TODO(hjchen2)
-void ExtractToRows1() {}
-
-void ExtractToRows2() {}
+void ExtractToImg(const int8_t *im_data, int8_t *col_data, const int im_height,
+                  const int im_width, const int col_height, const int col_width,
+                  const int padding_h, const int padding_w, const int stride_h,
+                  const int stride_w, const int kh, const int kw) {
+  int h = padding_h - kh;
+  int w = padding_w - kw;
+  int col_start_height = h > 0 ? (h + stride_h - 1) / stride_h : 0;
+  int col_start_width = w > 0 ? (w + stride_w - 1) / stride_w : 0;
+  int start_height = kh + col_start_height * stride_h - padding_h;
+  int start_width = kw + col_start_width * stride_w - padding_w;
+  int end_height = (col_height - col_start_height) * stride_h + start_height;
+  end_height = end_height > im_height ? im_height : end_height;
+  int end_width = (col_width - col_start_width) * stride_w + start_width;
+  end_width = end_width > im_width ? im_width : end_width;
+  int extract = (end_width - start_width + stride_w - 1) / stride_w;
+
+  im_data += start_height * im_width + start_width;
+  col_data += col_start_height * col_width + col_start_width;
+  for (int i = start_height; i < end_height; i += stride_h) {
+    if (stride_w == 1) {
+      memcpy(col_data, im_data, extract * sizeof(int8_t));
+    } else if (stride_w == 2) {
+      int s = 0;
+#if __ARM_NEON
+      for (; s < extract - 15; s += 16) {
+        int8x16x2_t img = vld2q_s8(im_data + s * 2);
+        vst1q_s8(col_data + s, img.val[0]);
+      }
+#endif
+      for (; s < extract; ++s) {
+        col_data[s] = im_data[s * 2];
+      }
+    } else if (stride_w == 3) {
+      int s = 0;
+#if __ARM_NEON
+      for (; s < extract - 15; s += 16) {
+        int8x16x3_t img = vld3q_s8(im_data + s * 3);
+        vst1q_s8(col_data + s, img.val[0]);
+      }
+#endif
+      for (; s < extract; ++s) {
+        col_data[s] = im_data[s * 3];
+      }
+    } else if (stride_w == 4) {
+      int s = 0;
+#if __ARM_NEON
+      for (; s < extract - 15; s += 16) {
+        int8x16x4_t img = vld4q_s8(im_data + s * 4);
+        vst1q_s8(col_data + s, img.val[0]);
+      }
+#endif
+      for (; s < extract; ++s) {
+        col_data[s] = im_data[s * 4];
+      }
+    } else {
+      PADDLE_MOBILE_THROW_EXCEPTION("stride_w must be one of 1, 2, 3 and 4.");
+    }
+    im_data += im_width * stride_h;
+    col_data += col_width;
+  }
+}
 
 /*
  * im = [input_channels, input_height, input_width]
@@ -432,64 +490,42 @@ void Im2ColFunctor<ColFormat::kCFO, CPU, int8_t>::operator()(
   int channels_col = im_channels * filter_height * filter_width;
   const int8_t *im_data = im.data<int8_t>();
   int8_t *col_data = col->data<int8_t>();
-// #if defined(__ARM_NEON__) || defined(__ARM_NEON)
-#if 0
-  if (stride[0] == stride[1] && stride[0] == 1 && dilation[0] == 1 &&
-      padding[0] == padding[1] && dilation[0] == dilation[1]) {
+#if defined(__ARM_NEON__) || defined(__ARM_NEON)
+  if (stride[0] <= 4 && dilation[0] == 1 && dilation[0] == dilation[1]) {
     // pad 0
     memset(col_data, 0, col->numel() * sizeof(int8_t));
     for (int ic = 0; ic < im_channels; ++ic) {
-      for (int oh = 0; oh < padding[0]; ++oh) {
-        for (int k = 0; k < filter_height * filter_width; ++k) {
-          ExtractToRows1();
-          ExtractToRows1();
-        }
-      }
-      for (int oh = padding[0]; oh < col_height - padding[0]; ++oh) {
-        for (int k = 0; k < filter_height * filter_width; ++k) {
-          ExtractToRows1();
-        }
-      }
-    }
-  } else if (stride[0] == stride[1] && stride[0] == 2 && dilation[0] == 1 &&
-             padding[0] == padding[1] && dilation[0] == dilation[1]) {
-    // pad 0
-    memset(col_data, 0, col->numel() * sizeof(int8_t));
-    for (int ic = 0; ic < im_channels; ++ic) {
-      for (int oh = 0; oh < padding[0]; ++oh) {
-        for (int k = 0; k < filter_height * filter_width; ++k) {
-          ExtractToRows2();
-          ExtractToRows2();
-        }
-      }
-      for (int oh = padding[0]; oh < col_height - padding[0]; ++oh) {
-        for (int k = 0; k < filter_height * filter_width; ++k) {
-          ExtractToRows2();
-        }
-      }
+      for (int kh = 0; kh < filter_height; ++kh) {
+        for (int kw = 0; kw < filter_width; ++kw) {
+          ExtractToImg(im_data, col_data, im_height, im_width, col_height,
+                       col_width, padding[0], padding[1], stride[0], stride[1],
+                       kh, kw);
+          col_data += col_height * col_width;
+        }
+      }
+      im_data += im_height * im_width;
     }
   } else {
 #endif
     for (int c = 0; c < channels_col; ++c) {
       int w_offset = c % filter_width;
       int h_offset = (c / filter_width) % filter_height;
       int c_im = c / (filter_width * filter_height);
       for (int h = 0; h < col_height; ++h) {
         int im_row_idx = h * stride[0] - padding[0] + h_offset * dilation[0];
         for (int w = 0; w < col_width; ++w) {
           int im_col_idx = w * stride[1] - padding[1] + w_offset * dilation[1];
           int col_idx = (c * col_height + h) * col_width + w;
           int im_idx = (im_row_idx + c_im * im_height) * im_width + im_col_idx;
           col_data[col_idx] = (im_row_idx < 0 || im_row_idx >= im_height ||
                                im_col_idx < 0 || im_col_idx >= im_width)
                                   ? static_cast<int8_t>(0)
                                   : im_data[im_idx];
         }
       }
     }
-// #if defined(__ARM_NEON__) || defined(__ARM_NEON)
-#if 0
+#if defined(__ARM_NEON__) || defined(__ARM_NEON)
   }
 #endif
 }
...
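The core NEON idea in `ExtractToImg` is that the de-interleaving loads `vld2q_s8`/`vld3q_s8`/`vld4q_s8` split interleaved bytes into separate vectors, so keeping `val[0]` gathers every 2nd/3rd/4th input element, 16 outputs at a time. A standalone sketch of the stride-2 case with its scalar tail (illustrative helper name, not the repository's API):

```cpp
#include <cstdint>
#if defined(__ARM_NEON__) || defined(__ARM_NEON)
#include <arm_neon.h>
#endif

// Copy every second int8 element of `src` into `dst` (n outputs), the same
// pattern ExtractToImg uses for stride_w == 2: vld2q_s8 loads 32 bytes and
// de-interleaves them into even- and odd-indexed vectors; storing val[0]
// keeps the even ones.
void GatherStride2(const int8_t *src, int8_t *dst, int n) {
  int s = 0;
#if defined(__ARM_NEON__) || defined(__ARM_NEON)
  for (; s < n - 15; s += 16) {
    int8x16x2_t v = vld2q_s8(src + s * 2);
    vst1q_s8(dst + s, v.val[0]);
  }
#endif
  for (; s < n; ++s) {  // scalar tail for the last < 16 elements
    dst[s] = src[s * 2];
  }
}
```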
@@ -12,6 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */
 
+#ifdef QUANT_OP
+
 #include "operators/quantize_op.h"
 #include <vector>
@@ -33,3 +35,5 @@ namespace ops = paddle_mobile::operators;
 #ifdef PADDLE_MOBILE_CPU
 REGISTER_OPERATOR_CPU(quantize, ops::QuantizeOp);
 #endif
+
+#endif
@@ -12,6 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */
 
+#ifdef QUANT_OP
+
 #pragma once
 
 #include <string>
@@ -40,3 +42,5 @@ class QuantizeOp : public framework::OperatorWithKernel<
 }  // namespace operators
 }  // namespace paddle_mobile
+
+#endif
@@ -140,10 +140,10 @@ int TestConvOp() {
   int dilation_w = 1;
 
   int batch_size = 1;
-  int input_c = 63;
-  int input_h = 51;
-  int input_w = 51;
-  int output_c = 125;
+  int input_c = 3;
+  int input_h = 100;
+  int input_w = 100;
+  int output_c = 10;
   framework::DDim input_shape =
       framework::make_ddim({batch_size, input_c, input_h, input_w});
   framework::DDim filter_shape =
@@ -174,40 +174,38 @@ int TestConvOp() {
   auto *op = new operators::ConvOp<CPU, float>("conv2d", inputs, outputs, attrs,
                                                scope);
-  struct timespec ts_begin, ts_end;
+  // struct timespec ts_begin, ts_end;
   op->InferShape();
   // warmup
+  // op->Run();
+  // clock_gettime(CLOCK_MONOTONIC, &ts_begin);
+  // for (int i = 0; i < 10; ++i) {
   op->Run();
-  clock_gettime(CLOCK_MONOTONIC, &ts_begin);
-  for (int i = 0; i < 10; ++i) {
-    op->Run();
-  }
-  clock_gettime(CLOCK_MONOTONIC, &ts_end);
-  uint64_t elapsed = (ts_end.tv_sec - ts_begin.tv_sec) * 1e3 +
-                     (ts_end.tv_nsec - ts_begin.tv_nsec) / 1e6;
-  LOG(kLOG_INFO) << "elapsed: " << elapsed / 10.0 << " ms";
+  // }
+  // clock_gettime(CLOCK_MONOTONIC, &ts_end);
+  // uint64_t elapsed = (ts_end.tv_sec - ts_begin.tv_sec) * 1e3 +
+  //                    (ts_end.tv_nsec - ts_begin.tv_nsec) / 1e6;
+  // LOG(kLOG_INFO) << "elapsed: " << elapsed / 10.0 << " ms";
 
-  /*
   int kernel_extent_h = dilation_h * (kernel_h - 1) + 1;
   int kernel_extent_w = dilation_w * (kernel_w - 1) + 1;
   int output_h = (input_h + 2 * pad_h - kernel_extent_h) / stride_h + 1;
   int output_w = (input_w + 2 * pad_w - kernel_extent_w) / stride_w + 1;
   auto output_shape = framework::make_ddim(
       std::vector<int>({batch_size, output_c, output_h, output_w}));
   framework::Tensor output_cmp;
   output_cmp.mutable_data<Otype>(output_shape);
   conv2d<Itype, Otype>(input, filter, attrs, &output_cmp);
 
   // compare results
   auto output = output_var->template Get<framework::LoDTensor>();
   const Otype *output_data = output->data<Otype>();
   Otype *output_cmp_data = output_cmp.data<Otype>();
   for (int i = 0; i < output->numel(); ++i) {
     PADDLE_MOBILE_ENFORCE(output_data[i] == output_cmp_data[i],
                           "output[%d] = %d, output_cmp[%d] = %d", i,
                           output_data[i], i, output_cmp_data[i]);
   }
-  */
   delete op;
   return 0;
 }
@@ -219,10 +217,35 @@ int main() {
   LOG(paddle_mobile::kLOG_INFO) << "int8, kernel=7, pad=0, stride=2";
   paddle_mobile::TestConvOp<int8_t, int32_t, 7, 0, 2>();
+  // kernel = 7, pad = 1, stride = 2
+  LOG(paddle_mobile::kLOG_INFO) << "int8, kernel=7, pad=1, stride=2";
+  paddle_mobile::TestConvOp<int8_t, int32_t, 7, 1, 2>();
   // kernel = 7, pad = 3, stride = 2
   LOG(paddle_mobile::kLOG_INFO) << "int8, kernel=7, pad=3, stride=2";
   paddle_mobile::TestConvOp<int8_t, int32_t, 7, 3, 2>();
+  // kernel = 7, pad = 0, stride = 1
+  LOG(paddle_mobile::kLOG_INFO) << "int8, kernel=7, pad=0, stride=1";
+  paddle_mobile::TestConvOp<int8_t, int32_t, 7, 0, 1>();
+  // kernel = 7, pad = 1, stride = 1
+  LOG(paddle_mobile::kLOG_INFO) << "int8, kernel=7, pad=1, stride=1";
+  paddle_mobile::TestConvOp<int8_t, int32_t, 7, 1, 1>();
+  // kernel = 7, pad = 3, stride = 1
+  LOG(paddle_mobile::kLOG_INFO) << "int8, kernel=7, pad=3, stride=1";
+  paddle_mobile::TestConvOp<int8_t, int32_t, 7, 3, 1>();
+  // kernel = 7, pad = 5, stride = 3
+  LOG(paddle_mobile::kLOG_INFO) << "int8, kernel=7, pad=5, stride=3";
+  paddle_mobile::TestConvOp<int8_t, int32_t, 7, 5, 3>();
+  // kernel = 7, pad = 3, stride = 4
+  LOG(paddle_mobile::kLOG_INFO) << "int8, kernel=7, pad=3, stride=4";
+  paddle_mobile::TestConvOp<int8_t, int32_t, 7, 3, 4>();
+
+  LOG(paddle_mobile::kLOG_INFO) << "\n";
   // kernel = 3, pad = 0, stride = 1
   LOG(paddle_mobile::kLOG_INFO) << "int8, kernel=3, pad=0, stride=1";
   paddle_mobile::TestConvOp<int8_t, int32_t, 3, 0, 1>();
...
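As a sanity check on the shapes these new cases exercise, the output-size formula from `TestConvOp` works out as follows for the 100x100 input:

```cpp
#include <cstdio>

// Worked example of TestConvOp's output-shape formula for kernel=7, pad=3,
// stride=2, dilation=1 on a 100x100 input: extent = 1*(7-1)+1 = 7 and
// output = (100 + 2*3 - 7) / 2 + 1 = 50, i.e. a 1x10x50x50 result tensor.
int main() {
  const int input = 100, kernel = 7, pad = 3, stride = 2, dilation = 1;
  const int kernel_extent = dilation * (kernel - 1) + 1;               // 7
  const int output = (input + 2 * pad - kernel_extent) / stride + 1;   // 50
  std::printf("output = %d\n", output);
  return 0;
}
```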
@@ -222,6 +222,8 @@ if(NOT FOUND_MATCH)
     set(SHAPE_OP ON)
     set(ELEMENTWISEMUL_OP ON)
     set(SUM_OP ON)
+    set(QUANT_OP ON)
+    set(DEQUANT_OP ON)
 endif()
 
 # option(BATCHNORM_OP "" ON)
@@ -401,3 +403,10 @@ if (SUM_OP)
     add_definitions(-DSUM_OP)
 endif()
+
+if (QUANT_OP)
+    add_definitions(-DQUANT_OP)
+endif()
+if (DEQUANT_OP)
+    add_definitions(-DDEQUANT_OP)
+endif()
@@ -5,7 +5,7 @@ TOTAL_ERRORS=0
 # The trick to remove deleted files: https://stackoverflow.com/a/2413151
 for file in $(git diff --cached --name-status | awk '$1 != "D" {print $2}' | \
     grep -v ".pb.cpp" | grep -v ".pb.h" | grep -v ".pb-c.h" | grep -v ".pb-c.c" | \
-    grep -v "protobuf-c.h" | grep -v "protobuf-c.c" | grep -v "variant.h"); do
+    grep -v "protobuf-c.h" | grep -v "protobuf-c.c"); do
     cpplint $file;
     TOTAL_ERRORS=$(expr $TOTAL_ERRORS + $?);
 done
...