diff --git a/src/operators/fusion_conv_add.h b/src/operators/fusion_conv_add.h
index f0a3ea17d9a86e2c8638c164cfa2bf21d4fb727d..24f1d3f63b3300db9b60a595466a0ced3b9e996b 100644
--- a/src/operators/fusion_conv_add.h
+++ b/src/operators/fusion_conv_add.h
@@ -68,11 +68,23 @@ class FusionConvAddOp : public framework::OperatorWithKernel<
 };
 
 #ifdef PADDLE_MOBILE_CPU
+#ifndef CONV_ADD_REGISTER
 static framework::FusionOpRegistrar convadd_registrar(
     new FusionConvAddMatcher());
+#define CONV_ADD_REGISTER
 #endif
+#endif
+
 #ifdef PADDLE_MOBILE_MALI_GPU
+
+#ifndef CONV_ADD_REGISTER
+static framework::FusionOpRegistrar convadd_registrar(
+    new FusionConvAddMatcher());
+#define CONV_ADD_REGISTER
+#endif
+
 #endif
+
 #ifdef PADDLE_MOBILE_FPGA
 #endif
diff --git a/src/operators/fusion_conv_add_relu_op.h b/src/operators/fusion_conv_add_relu_op.h
index b87f1c4110de6c525e4544d5a350b2beaf98af95..fd27005c8bef8f8cb91fbf5b6e5a852306c28a9b 100644
--- a/src/operators/fusion_conv_add_relu_op.h
+++ b/src/operators/fusion_conv_add_relu_op.h
@@ -64,8 +64,13 @@ class FusionConvAddReluOp : public framework::OperatorWithKernel<
 };
 
 #ifdef PADDLE_MOBILE_CPU
+
+#ifndef CONV_ADD_RELU_REGISTER
+#define CONV_ADD_RELU_REGISTER
 // static framework::FusionOpRegistrar fusion_conv_add_relu_registrar(new
 // FusionConvAddReluOpMatcher());
+#endif
+
 #endif
 #ifdef PADDLE_MOBILE_MALI_GPU
 #endif
diff --git a/src/operators/fusion_fc_op.h b/src/operators/fusion_fc_op.h
index 2035704bb60eb96bfb22fc4f277d30817efcf646..0ca4d2b27ad46b77ddba55b6b377e741c97bdc9e 100644
--- a/src/operators/fusion_fc_op.h
+++ b/src/operators/fusion_fc_op.h
@@ -66,11 +66,19 @@ class FusionFcOp
 };
 
 #ifdef PADDLE_MOBILE_CPU
+#ifndef CONV_CPU_REGISTER
+#define CONV_CPU_REGISTER
 static framework::FusionOpRegistrar fc_registrar(new FusionFcMatcher());
 #endif
+#endif
+
 #ifdef PADDLE_MOBILE_MALI_GPU
-// static framework::FusionOpRegistrar fc_registrar(new FusionFcMatcher());
+#ifndef CONV_CPU_REGISTER
+#define CONV_CPU_REGISTER
+static framework::FusionOpRegistrar fc_registrar(new FusionFcMatcher());
 #endif
+#endif
+
 #ifdef PADDLE_MOBILE_FPGA
 #endif
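
Note: the three header edits above all apply the same once-only guard so the file-scope registrar is emitted a single time even when more than one backend branch is enabled. A minimal sketch of the idea, using a hypothetical SimpleRegistrar in place of framework::FusionOpRegistrar (which registers a fusion matcher when its static instance is constructed):

    // Sketch only: SimpleRegistrar is a stand-in, not the real registrar.
    #include <iostream>

    struct SimpleRegistrar {
      explicit SimpleRegistrar(const char *name) {
        std::cout << "register " << name << "\n";  // real code registers a matcher
      }
    };

    #ifdef PADDLE_MOBILE_CPU
    #ifndef CONV_ADD_REGISTER
    static SimpleRegistrar convadd_registrar("FusionConvAdd");
    #define CONV_ADD_REGISTER  // later branches see this and skip their copy
    #endif
    #endif

    #ifdef PADDLE_MOBILE_MALI_GPU
    #ifndef CONV_ADD_REGISTER  // already defined if the CPU branch ran
    static SimpleRegistrar convadd_registrar("FusionConvAdd");
    #define CONV_ADD_REGISTER
    #endif
    #endif

    int main() { return 0; }

Compiling this with -DPADDLE_MOBILE_CPU -DPADDLE_MOBILE_MALI_GPU yields exactly one definition of convadd_registrar instead of a redefinition error.
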
diff --git a/src/operators/kernel/arm/batchnorm_kernel.cpp b/src/operators/kernel/arm/batchnorm_kernel.cpp
index af4639c4e7c13d486ad78f4877db19c6d3e15a31..68e0c7fa1e6996534ef87f771c9f1a3fb924224f 100644
--- a/src/operators/kernel/arm/batchnorm_kernel.cpp
+++ b/src/operators/kernel/arm/batchnorm_kernel.cpp
@@ -17,6 +17,7 @@ limitations under the License. */
 #pragma once
 
 #include "operators/kernel/batchnorm_kernel.h"
+#include "operators/kernel/central-arm-func/batchnorm_func.h"
 
 namespace paddle_mobile {
 namespace operators {
@@ -28,215 +29,7 @@ bool BatchNormKernel<CPU, float>::Init(const BatchNormParam &para) const {
 
 template <>
 void BatchNormKernel<CPU, float>::Compute(const BatchNormParam &param) const {
-  const Tensor *input_x = param.InputX();
-  auto input_x_ptr = input_x->data<float>();
-  const auto &x_dims = input_x->dims();
-  const int N = x_dims[0];
-  const int C = x_dims[1];
-  const int H = x_dims[2];
-  const int W = x_dims[3];
-  const int stride0 = C * H * W;
-  const int stride1 = H * W;
-  const int stride2 = W;
-  Tensor *out = param.OutputY();
-  auto out_ptr = out->mutable_data<float>();
-  const float epsilon = param.Epsilon();
-  const Tensor *mean = param.InputMean();
-  const Tensor *variance = param.InputVariance();
-  const Tensor *scale = param.InputScale();
-  const Tensor *bias = param.InputBias();
-  auto mean_ptr = mean->data<float>();
-  auto variance_ptr = variance->data<float>();
-  auto scale_ptr = scale->data<float>();
-  auto bias_ptr = bias->data<float>();
-
-  // Tensor inv_std;
-  // auto inv_std_ptr = inv_std.mutable_data<float>(make_ddim({C}));
-
-  PADDLE_MOBILE_ENFORCE(C == variance->numel(),
-                        "C must equal to variance.numel()");
-
-  int HXW = H * W;
-  if (HXW > 32) {
-    int NXC = N * C;
-    float *inv_std_ptr = new float[NXC * 4];
-    float *volatile new_scale_ptr = new float[NXC * 4];
-    float *volatile new_bias_ptr = new float[NXC * 4];
-
-    /// std = (var + epsilon).sqrt();
-    /// inv_std = 1 / std;
-    for (int i = 0; i < C * 4; i += 4) {
-      int index = i / 4;
-      inv_std_ptr[i] =
-          1 / static_cast<float>(pow((variance_ptr[index] + epsilon), 0.5));
-      inv_std_ptr[i + 1] = inv_std_ptr[i];
-      inv_std_ptr[i + 2] = inv_std_ptr[i];
-      inv_std_ptr[i + 3] = inv_std_ptr[i];
-
-      new_scale_ptr[i] = inv_std_ptr[i] * scale_ptr[index];
-      new_scale_ptr[i + 1] = new_scale_ptr[i];
-      new_scale_ptr[i + 2] = new_scale_ptr[i];
-      new_scale_ptr[i + 3] = new_scale_ptr[i];
-
-      new_bias_ptr[i] =
-          bias_ptr[index] - mean_ptr[index] * inv_std_ptr[i] * scale_ptr[index];
-
-      new_bias_ptr[i + 1] = new_bias_ptr[i];
-      new_bias_ptr[i + 2] = new_bias_ptr[i];
-      new_bias_ptr[i + 3] = new_bias_ptr[i];
-    }
-
-    for (int j = C * 4; j < NXC * 4; ++j) {
-      new_scale_ptr[j] = new_scale_ptr[j - C * 4];
-      new_bias_ptr[j] = new_bias_ptr[j - C * 4];
-    }
-
-    asm volatile(
-        "subs %[N], %[N], #1                    \n\t"
-        "blt end_n_%=                           \n\t"
-        "loop_n_%=:                             \n\t"
-
-        "subs %[C], %[C], #1                    \n\t"
-        "blt end_c_%=                           \n\t"
-        "loop_c_%=:                             \n\t"
-
-        "vld1.32 {q9}, [%[new_scale_ptr]]!      \n\t"
-        "vld1.32 {q10}, [%[new_bias_ptr]]!      \n\t"
-
-        "mov r6, %[HXW]                         \n\t"
-
-        "subs r6, r6, #32                       \n\t"
-        "blt end_hw_%=                          \n\t"
-        "loop_hw_%=:                            \n\t"
-
-        "vld1.32 {q1, q2}, [%[input_x_ptr]]!    \n\t"
-        "vld1.32 {q3, q4}, [%[input_x_ptr]]!    \n\t"
-        "vld1.32 {q5, q6}, [%[input_x_ptr]]!    \n\t"
-        "vld1.32 {q7, q8}, [%[input_x_ptr]]!    \n\t"
-
-        "vmul.f32 q1, q1, q9                    \n\t"
-        "vmul.f32 q2, q2, q9                    \n\t"
-        "vmul.f32 q3, q3, q9                    \n\t"
-        "vmul.f32 q4, q4, q9                    \n\t"
-
-        "vmul.f32 q5, q5, q9                    \n\t"
-        "vmul.f32 q6, q6, q9                    \n\t"
-        "vmul.f32 q7, q7, q9                    \n\t"
-        "vmul.f32 q8, q8, q9                    \n\t"
-
-        "vadd.f32 q1, q1, q10                   \n\t"
-        "vadd.f32 q2, q2, q10                   \n\t"
-        "vadd.f32 q3, q3, q10                   \n\t"
-        "vadd.f32 q4, q4, q10                   \n\t"
-        "vadd.f32 q5, q5, q10                   \n\t"
-        "vadd.f32 q6, q6, q10                   \n\t"
-        "vadd.f32 q7, q7, q10                   \n\t"
-        "vadd.f32 q8, q8, q10                   \n\t"
-
-        "vst1.32 {q1, q2}, [%[out_ptr]]!        \n\t"
-        "vst1.32 {q3, q4}, [%[out_ptr]]!        \n\t"
-        "vst1.32 {q5, q6}, [%[out_ptr]]!        \n\t"
-        "vst1.32 {q7, q8}, [%[out_ptr]]!        \n\t"
-
-        "subs r6, r6, #32                       \n\t"
-        "bge loop_hw_%=                         \n\t"
-        "end_hw_%=:                             \n\t"
-
-        "cmp r6, #0                             \n\t"
-        "bge end_remainder_%=                   \n\t"
-        "mov r5, #4                             \n\t"
-        "mul r6, r6, r5                         \n\t"
-        "add %[input_x_ptr], %[input_x_ptr], r6 \n\t"
-
-        "vld1.32 {q1, q2}, [%[input_x_ptr]]!    \n\t"
-        "vld1.32 {q3, q4}, [%[input_x_ptr]]!    \n\t"
-        "vld1.32 {q5, q6}, [%[input_x_ptr]]!    \n\t"
-        "vld1.32 {q7, q8}, [%[input_x_ptr]]!    \n\t"
-
-        "vmul.f32 q1, q1, q9                    \n\t"
-        "vmul.f32 q2, q2, q9                    \n\t"
-        "vmul.f32 q3, q3, q9                    \n\t"
-        "vmul.f32 q4, q4, q9                    \n\t"
-        "vmul.f32 q5, q5, q9                    \n\t"
-        "vmul.f32 q6, q6, q9                    \n\t"
-        "vmul.f32 q7, q7, q9                    \n\t"
-        "vmul.f32 q8, q8, q9                    \n\t"
-        "vadd.f32 q1, q1, q10                   \n\t"
-        "vadd.f32 q2, q2, q10                   \n\t"
-        "vadd.f32 q3, q3, q10                   \n\t"
-        "vadd.f32 q4, q4, q10                   \n\t"
-        "vadd.f32 q5, q5, q10                   \n\t"
-        "vadd.f32 q6, q6, q10                   \n\t"
-        "vadd.f32 q7, q7, q10                   \n\t"
-        "vadd.f32 q8, q8, q10                   \n\t"
-
-        "add %[out_ptr], %[out_ptr], r6         \n\t"
-        "vst1.32 {q1, q2}, [%[out_ptr]]!        \n\t"
-        "vst1.32 {q3, q4}, [%[out_ptr]]!        \n\t"
-        "vst1.32 {q5, q6}, [%[out_ptr]]!        \n\t"
-        "vst1.32 {q7, q8}, [%[out_ptr]]!        \n\t"
-
-        "end_remainder_%=:                      \n\t"
-
-        "subs %[C], %[C], #1                    \n\t"
-        "bge loop_c_%=                          \n\t"
-        "end_c_%=:                              \n\t"
-
-        "subs %[N], %[N], #1                    \n\t"
-        "bge loop_n_%=                          \n\t"
-        "end_n_%=:                              \n\t"
-        :
-        : [input_x_ptr] "r"(input_x_ptr), [out_ptr] "r"(out_ptr),
-          [new_scale_ptr] "r"(new_scale_ptr), [new_bias_ptr] "r"(new_bias_ptr),
-          [N] "r"(N), [C] "r"(C), [HXW] "r"(HXW)
-        : "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", "q8", "q9",
-          "q10", "r5", "r6");
-
-    delete[] inv_std_ptr;
-    delete[] new_scale_ptr;
-    delete[] new_bias_ptr;
-
-  } else {
-    float *inv_std_ptr = new float[C];
-    for (int i = 0; i < C; i++) {
-      inv_std_ptr[i] =
-          1 / static_cast<float>(pow((variance_ptr[i] + epsilon), 0.5));
-    }
-
-    Tensor new_scale;
-    auto new_scale_ptr = new_scale.mutable_data<float>(make_ddim({C}));
-    Tensor new_bias;
-    auto new_bias_ptr = new_bias.mutable_data<float>(make_ddim({C}));
-
-    /// ((x - est_mean) * (inv_var) * scale + bias equal to
-    /// (x * inv_var * scale) + (bias - est_mean * inv_var * scale)
-    for (int i = 0; i < C; i++) {
-      new_scale_ptr[i] = inv_std_ptr[i] * scale_ptr[i];
-      new_bias_ptr[i] =
-          bias_ptr[i] - mean_ptr[i] * inv_std_ptr[i] * scale_ptr[i];
-      {
-        for (int n = 0; n < N; n++) {
-          for (int h = 0; h < H; h++) {
-            int tmp_index = n * stride0 + i * stride1 + h * stride2;
-            for (int w = 0; w < W; w++) {
-              int index = tmp_index + w;
-              out_ptr[index] =
-                  input_x_ptr[index] * new_scale_ptr[i] + new_bias_ptr[i];
-            }
-          }
-        }
-      }
-    }
-
-    delete[] inv_std_ptr;
-    // DLOG << "input[2,5,1,0](input[102]) ,channel 5 :";
-    // DLOG << "input_x_ptr : " << input_x_ptr[102];
-    // DLOG << "variance : " << variance_ptr[5];
-    // DLOG << "inv_std_ptr : " << inv_std_ptr[5];
-    // DLOG << "new_scale_ptr : " << new_scale_ptr[5];
-    // DLOG << "new_bias_ptr : " << new_bias_ptr[5];
-    // DLOG << "out_ptr : " << out_ptr[102];
-  }
+  BatchnormCompute<float>(param);
 }
 
 }  // namespace operators
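
The kernel change above is the pattern the rest of this diff repeats: the ARM kernel keeps only its Init/Compute plumbing and forwards to a header-only compute template under central-arm-func, so the same arithmetic can later be shared by other kernels. A rough sketch of the pattern with hypothetical SomeOp names (not the real paddle-mobile types):

    // Sketch only: SomeOpParam, SomeOpCompute and SomeOpKernel are stand-ins.
    struct SomeOpParam {
      float value = 0.0f;
    };

    // central-arm-func/some_op_func.h: header-only, holds the arithmetic.
    template <typename P>
    void SomeOpCompute(const SomeOpParam &param) {
      // ... the actual math lives here ...
    }

    // kernel/arm/some_op_kernel.cpp: the device kernel is now a thin wrapper.
    struct CPU {};
    template <typename DeviceType, typename T>
    struct SomeOpKernel {
      void Compute(const SomeOpParam &param) const {
        SomeOpCompute<float>(param);  // forward to the shared implementation
      }
    };

    int main() {
      SomeOpKernel<CPU, float>{}.Compute(SomeOpParam{});
      return 0;
    }
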
diff --git a/src/operators/kernel/arm/conv_add_relu_kernel.cpp b/src/operators/kernel/arm/conv_add_relu_kernel.cpp
index 2df48222e0923e403f2ad44b3d5c4a89aceb4cc4..0ff86c7344fed8e4060b8d46b7ff457b031479d6 100644
--- a/src/operators/kernel/arm/conv_add_relu_kernel.cpp
+++ b/src/operators/kernel/arm/conv_add_relu_kernel.cpp
@@ -15,6 +15,7 @@ limitations under the License. */
 #ifdef FUSION_CONVADD_RELU_OP
 
 #include "operators/kernel/conv_add_relu_kernel.h"
+#include "operators/kernel/central-arm-func/conv_add_relu_func.h"
 
 namespace paddle_mobile {
 namespace operators {
@@ -28,92 +29,7 @@ bool ConvAddReluKernel<CPU, float>::Init(
 template <>
 void ConvAddReluKernel<CPU, float>::Compute(
     const FusionConvAddReluParam &param) const {
-  const Tensor *input = param.Input();
-  Tensor filter = *param.Filter();
-  Tensor bias = *param.Bias();
-  int axis = param.Axis();
-  Tensor *output = param.Output();
-  math::expand_bias(bias, axis, output->dims());
-  output->ShareDataWith(bias);
-  int groups = param.Groups();
-  std::vector<int> strides = param.Strides();
-  std::vector<int> paddings = param.Paddings();
-  std::vector<int> dilations = param.Dilations();
-
-  const int batch_size = static_cast<int>(input->dims()[0]);
-
-  std::vector<int64_t> filter_shape_vec(framework::vectorize(filter.dims()));
-
-  std::vector<int64_t> output_shape_vec(framework::vectorize(output->dims()));
-  size_t data_dim = filter_shape_vec.size() - 2;
-  std::vector<int64_t> col_shape_vec(1 + 2 * data_dim);
-  col_shape_vec[0] = input->dims()[1] / groups;
-  for (size_t j = 0; j < data_dim; ++j) {
-    col_shape_vec[j + 1] = filter_shape_vec[j + 2];
-    col_shape_vec[j + 1 + data_dim] = output_shape_vec[j + 2];
-  }
-  framework::DDim col_shape(framework::make_ddim(col_shape_vec));
-
-  framework::DDim col_matrix_shape =
-      framework::flatten_to_2d(col_shape, data_dim + 1);
-
-  bool is_expand =
-      math::IsExpand(filter_shape_vec, strides, paddings, dilations);
-  Tensor col;
-  Tensor col_matrix;
-  if (is_expand) {
-    col.mutable_data<float>(col_shape);
-    col_matrix.ShareDataWith(col);
-    col_matrix.Resize(col_matrix_shape);
-  }
-
-  framework::DDim input_shape = framework::slice_ddim(
-      input->dims(), 1, static_cast<int>(input->dims().size()));
-
-  framework::DDim filter_matrix_shape = {filter.dims()[0],
-                                         filter.numel() / filter.dims()[0]};
-  filter.Resize(filter_matrix_shape);
-  framework::DDim output_matrix_shape = {
-      output->dims()[1],
-      output->numel() / (output->dims()[0] * output->dims()[1])};
-
-  // convolution operator: im2col(or vol2col) + gemm
-  int in_step = static_cast<int>(input->dims()[1]) / groups;
-  int out_step = static_cast<int>(output->dims()[1]) / groups;
-
-  math::Vol2ColFunctor<CPU, float> vol2col;
-  math::Im2ColFunctor<math::ColFormat::kCFO, CPU, float> im2col;
-
-  for (int i = 0; i < batch_size; i++) {
-    Tensor in_batch = input->Slice(i, i + 1).Resize(input_shape);
-    Tensor out_batch = output->Slice(i, i + 1).Resize(output_matrix_shape);
-
-    for (int g = 0; g < groups; g++) {
-      Tensor in_slice = in_batch.Slice(g * in_step, (g + 1) * in_step);
-
-      if (!is_expand) {
-        col.ShareDataWith(in_slice);
-        col_matrix.ShareDataWith(col);
-        col_matrix.Resize(col_matrix_shape);
-      } else if (data_dim == 2U) {
-        // im2col
-        im2col(in_slice, dilations, strides,
-               std::vector<int>{paddings[0], paddings[1], paddings[0],
-                                paddings[1]},
-               &col);
-      } else if (data_dim == 3U) {
-        // vol2col
-        vol2col(in_slice, dilations, strides, paddings, &col);
-      }
-
-      // gemm
-      Tensor out_slice = out_batch.Slice(g * out_step, (g + 1) * out_step);
-      Tensor filter_slice = filter.Slice(g * out_step, (g + 1) * out_step);
-      math::matmul<float>(filter_slice, false, col_matrix, false,
-                          static_cast<float>(1), &out_slice,
-                          static_cast<float>(1), true);
-    }
-  }
+  ConvAddReluCompute<float>(param);
 }
 
 template class ConvAddReluKernel<CPU, float>;
diff --git a/src/operators/kernel/arm/conv_kernel.cpp b/src/operators/kernel/arm/conv_kernel.cpp
index ce7c8b2bb3d596bc365eecd31ae4181f37be5e38..45fae59a838d4d5fd44a94559cdf60b615f5b924 100644
--- a/src/operators/kernel/arm/conv_kernel.cpp
+++ b/src/operators/kernel/arm/conv_kernel.cpp
@@ -15,6 +15,7 @@ limitations under the License. */
 #ifdef CONV_OP
 
 #include "operators/kernel/conv_kernel.h"
+#include "operators/kernel/central-arm-func/conv_func.h"
 
 namespace paddle_mobile {
 namespace operators {
@@ -26,88 +27,7 @@ bool ConvKernel<CPU, float>::Init(const ConvParam &para) const {
 
 template <>
 void ConvKernel<CPU, float>::Compute(const ConvParam &param) const {
-  const Tensor *input = param.Input();
-  Tensor filter = *param.Filter();
-  Tensor *output = param.Output();
-  output->mutable_data<float>();
-  int groups = param.Groups();
-  std::vector<int> strides = param.Strides();
-  std::vector<int> paddings = param.Paddings();
-  std::vector<int> dilations = param.Dilations();
-
-  const int batch_size = static_cast<int>(input->dims()[0]);
-
-  std::vector<int64_t> filter_shape_vec(framework::vectorize(filter.dims()));
-
-  std::vector<int64_t> output_shape_vec(framework::vectorize(output->dims()));
-  size_t data_dim = filter_shape_vec.size() - 2;
-  std::vector<int64_t> col_shape_vec(1 + 2 * data_dim);
-  col_shape_vec[0] = input->dims()[1] / groups;
-  for (size_t j = 0; j < data_dim; ++j) {
-    col_shape_vec[j + 1] = filter_shape_vec[j + 2];
-    col_shape_vec[j + 1 + data_dim] = output_shape_vec[j + 2];
-  }
-  framework::DDim col_shape(framework::make_ddim(col_shape_vec));
-
-  framework::DDim col_matrix_shape =
-      framework::flatten_to_2d(col_shape, data_dim + 1);
-
-  bool is_expand = IsExpand(filter_shape_vec, strides, paddings, dilations);
-  Tensor col;
-  Tensor col_matrix;
-  if (is_expand) {
-    col.mutable_data<float>(col_shape);
-    col_matrix.ShareDataWith(col);
-    col_matrix.Resize(col_matrix_shape);
-  }
-
-  framework::DDim input_shape = framework::slice_ddim(
-      input->dims(), 1, static_cast<int>(input->dims().size()));
-
-  framework::DDim filter_matrix_shape = {filter.dims()[0],
-                                         filter.numel() / filter.dims()[0]};
-  filter.Resize(filter_matrix_shape);
-  framework::DDim output_matrix_shape = {
-      output->dims()[1],
-      output->numel() / (output->dims()[0] * output->dims()[1])};
-
-  // convolution operator: im2col(or vol2col) + gemm
-  int in_step = static_cast<int>(input->dims()[1]) / groups;
-  int out_step = static_cast<int>(output->dims()[1]) / groups;
-
-  math::Vol2ColFunctor<CPU, float> vol2col;
-  math::Im2ColFunctor<math::ColFormat::kCFO, CPU, float> im2col;
-
-  for (int i = 0; i < batch_size; i++) {
-    Tensor in_batch = input->Slice(i, i + 1).Resize(input_shape);
-    Tensor out_batch = output->Slice(i, i + 1).Resize(output_matrix_shape);
-
-    for (int g = 0; g < groups; g++) {
-      Tensor in_slice = in_batch.Slice(g * in_step, (g + 1) * in_step);
-
-      if (!is_expand) {
-        col.ShareDataWith(in_slice);
-        col_matrix.ShareDataWith(col);
-        col_matrix.Resize(col_matrix_shape);
-      } else if (data_dim == 2U) {
-        // im2col
-        im2col(in_slice, dilations, strides,
-               std::vector<int>{paddings[0], paddings[1], paddings[0],
-                                paddings[1]},
-               &col);
-      } else if (data_dim == 3U) {
-        // vol2col
-        vol2col(in_slice, dilations, strides, paddings, &col);
-      }
-
-      // gemm
-      Tensor out_slice = out_batch.Slice(g * out_step, (g + 1) * out_step);
-      Tensor filter_slice = filter.Slice(g * out_step, (g + 1) * out_step);
-      math::matmul<float>(filter_slice, false, col_matrix, false,
-                          static_cast<float>(1), &out_slice,
-                          static_cast<float>(0));
-    }
-  }
+  ConvCompute<float>(param);
 }
 
 template class ConvKernel<CPU, float>;
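
As far as the extracted code goes, the only functional difference between the two conv helpers introduced below is the GEMM epilogue: ConvAddReluCompute first copies the broadcast bias into the output and then, judging by the call sites, asks matmul to accumulate onto it (trailing scalar 1) with the fused-relu flag set, while ConvCompute writes into a freshly allocated output with the trailing scalar 0:

    // from conv_add_relu_func.h: accumulate onto the pre-filled bias, fuse relu
    math::matmul<float>(filter_slice, false, col_matrix, false,
                        static_cast<float>(1), &out_slice,
                        static_cast<float>(1), true);

    // from conv_func.h: plain convolution, overwrite the output
    math::matmul<float>(filter_slice, false, col_matrix, false,
                        static_cast<float>(1), &out_slice,
                        static_cast<float>(0));
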
diff --git a/src/operators/kernel/central-arm-func/batchnorm_func.h b/src/operators/kernel/central-arm-func/batchnorm_func.h
new file mode 100644
index 0000000000000000000000000000000000000000..7f02d768b790b5f496ab0eac369fa3a4100ee733
--- /dev/null
+++ b/src/operators/kernel/central-arm-func/batchnorm_func.h
@@ -0,0 +1,234 @@
+/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#ifdef BATCHNORM_OP
+
+#pragma once
+
+#include "operators/op_param.h"
+
+namespace paddle_mobile {
+namespace operators {
+
+template <typename P>
+void BatchnormCompute(const BatchNormParam &param) {
+  const Tensor *input_x = param.InputX();
+  auto input_x_ptr = input_x->data<float>();
+  const auto &x_dims = input_x->dims();
+  const int N = x_dims[0];
+  const int C = x_dims[1];
+  const int H = x_dims[2];
+  const int W = x_dims[3];
+  const int stride0 = C * H * W;
+  const int stride1 = H * W;
+  const int stride2 = W;
+  Tensor *out = param.OutputY();
+  auto out_ptr = out->mutable_data<float>();
+  const float epsilon = param.Epsilon();
+  const Tensor *mean = param.InputMean();
+  const Tensor *variance = param.InputVariance();
+  const Tensor *scale = param.InputScale();
+  const Tensor *bias = param.InputBias();
+  auto mean_ptr = mean->data<float>();
+  auto variance_ptr = variance->data<float>();
+  auto scale_ptr = scale->data<float>();
+  auto bias_ptr = bias->data<float>();
+
+  // Tensor inv_std;
+  // auto inv_std_ptr = inv_std.mutable_data<float>(make_ddim({C}));
+
+  PADDLE_MOBILE_ENFORCE(C == variance->numel(),
+                        "C must equal to variance.numel()");
+
+  int HXW = H * W;
+  if (HXW > 32) {
+    int NXC = N * C;
+    float *inv_std_ptr = new float[NXC * 4];
+    float *volatile new_scale_ptr = new float[NXC * 4];
+    float *volatile new_bias_ptr = new float[NXC * 4];
+
+    /// std = (var + epsilon).sqrt();
+    /// inv_std = 1 / std;
+    for (int i = 0; i < C * 4; i += 4) {
+      int index = i / 4;
+      inv_std_ptr[i] =
+          1 / static_cast<float>(pow((variance_ptr[index] + epsilon), 0.5));
+      inv_std_ptr[i + 1] = inv_std_ptr[i];
+      inv_std_ptr[i + 2] = inv_std_ptr[i];
+      inv_std_ptr[i + 3] = inv_std_ptr[i];
+
+      new_scale_ptr[i] = inv_std_ptr[i] * scale_ptr[index];
+      new_scale_ptr[i + 1] = new_scale_ptr[i];
+      new_scale_ptr[i + 2] = new_scale_ptr[i];
+      new_scale_ptr[i + 3] = new_scale_ptr[i];
+
+      new_bias_ptr[i] =
+          bias_ptr[index] - mean_ptr[index] * inv_std_ptr[i] * scale_ptr[index];
+
+      new_bias_ptr[i + 1] = new_bias_ptr[i];
+      new_bias_ptr[i + 2] = new_bias_ptr[i];
+      new_bias_ptr[i + 3] = new_bias_ptr[i];
+    }
+
+    for (int j = C * 4; j < NXC * 4; ++j) {
+      new_scale_ptr[j] = new_scale_ptr[j - C * 4];
+      new_bias_ptr[j] = new_bias_ptr[j - C * 4];
+    }
+
+    asm volatile(
+        "subs %[N], %[N], #1                    \n\t"
+        "blt end_n_%=                           \n\t"
+        "loop_n_%=:                             \n\t"
+
+        "subs %[C], %[C], #1                    \n\t"
+        "blt end_c_%=                           \n\t"
+        "loop_c_%=:                             \n\t"
+
+        "vld1.32 {q9}, [%[new_scale_ptr]]!      \n\t"
+        "vld1.32 {q10}, [%[new_bias_ptr]]!      \n\t"
+
+        "mov r6, %[HXW]                         \n\t"
+
+        "subs r6, r6, #32                       \n\t"
+        "blt end_hw_%=                          \n\t"
+        "loop_hw_%=:                            \n\t"
+
+        "vld1.32 {q1, q2}, [%[input_x_ptr]]!    \n\t"
+        "vld1.32 {q3, q4}, [%[input_x_ptr]]!    \n\t"
+        "vld1.32 {q5, q6}, [%[input_x_ptr]]!    \n\t"
+        "vld1.32 {q7, q8}, [%[input_x_ptr]]!    \n\t"
+
+        "vmul.f32 q1, q1, q9                    \n\t"
+        "vmul.f32 q2, q2, q9                    \n\t"
+        "vmul.f32 q3, q3, q9                    \n\t"
+        "vmul.f32 q4, q4, q9                    \n\t"
+
+        "vmul.f32 q5, q5, q9                    \n\t"
+        "vmul.f32 q6, q6, q9                    \n\t"
+        "vmul.f32 q7, q7, q9                    \n\t"
+        "vmul.f32 q8, q8, q9                    \n\t"
+
+        "vadd.f32 q1, q1, q10                   \n\t"
+        "vadd.f32 q2, q2, q10                   \n\t"
+        "vadd.f32 q3, q3, q10                   \n\t"
+        "vadd.f32 q4, q4, q10                   \n\t"
+        "vadd.f32 q5, q5, q10                   \n\t"
+        "vadd.f32 q6, q6, q10                   \n\t"
+        "vadd.f32 q7, q7, q10                   \n\t"
+        "vadd.f32 q8, q8, q10                   \n\t"
+
+        "vst1.32 {q1, q2}, [%[out_ptr]]!        \n\t"
+        "vst1.32 {q3, q4}, [%[out_ptr]]!        \n\t"
+        "vst1.32 {q5, q6}, [%[out_ptr]]!        \n\t"
+        "vst1.32 {q7, q8}, [%[out_ptr]]!        \n\t"
+
+        "subs r6, r6, #32                       \n\t"
+        "bge loop_hw_%=                         \n\t"
+        "end_hw_%=:                             \n\t"
+
+        "cmp r6, #0                             \n\t"
+        "bge end_remainder_%=                   \n\t"
+        "mov r5, #4                             \n\t"
+        "mul r6, r6, r5                         \n\t"
+        "add %[input_x_ptr], %[input_x_ptr], r6 \n\t"
+
+        "vld1.32 {q1, q2}, [%[input_x_ptr]]!    \n\t"
+        "vld1.32 {q3, q4}, [%[input_x_ptr]]!    \n\t"
+        "vld1.32 {q5, q6}, [%[input_x_ptr]]!    \n\t"
+        "vld1.32 {q7, q8}, [%[input_x_ptr]]!    \n\t"
+
+        "vmul.f32 q1, q1, q9                    \n\t"
+        "vmul.f32 q2, q2, q9                    \n\t"
+        "vmul.f32 q3, q3, q9                    \n\t"
+        "vmul.f32 q4, q4, q9                    \n\t"
+        "vmul.f32 q5, q5, q9                    \n\t"
+        "vmul.f32 q6, q6, q9                    \n\t"
+        "vmul.f32 q7, q7, q9                    \n\t"
+        "vmul.f32 q8, q8, q9                    \n\t"
+        "vadd.f32 q1, q1, q10                   \n\t"
+        "vadd.f32 q2, q2, q10                   \n\t"
+        "vadd.f32 q3, q3, q10                   \n\t"
+        "vadd.f32 q4, q4, q10                   \n\t"
+        "vadd.f32 q5, q5, q10                   \n\t"
+        "vadd.f32 q6, q6, q10                   \n\t"
+        "vadd.f32 q7, q7, q10                   \n\t"
+        "vadd.f32 q8, q8, q10                   \n\t"
+
+        "add %[out_ptr], %[out_ptr], r6         \n\t"
+        "vst1.32 {q1, q2}, [%[out_ptr]]!        \n\t"
+        "vst1.32 {q3, q4}, [%[out_ptr]]!        \n\t"
+        "vst1.32 {q5, q6}, [%[out_ptr]]!        \n\t"
+        "vst1.32 {q7, q8}, [%[out_ptr]]!        \n\t"
+
+        "end_remainder_%=:                      \n\t"
+
+        "subs %[C], %[C], #1                    \n\t"
+        "bge loop_c_%=                          \n\t"
+        "end_c_%=:                              \n\t"
+
+        "subs %[N], %[N], #1                    \n\t"
+        "bge loop_n_%=                          \n\t"
+        "end_n_%=:                              \n\t"
+        :
+        : [input_x_ptr] "r"(input_x_ptr), [out_ptr] "r"(out_ptr),
+          [new_scale_ptr] "r"(new_scale_ptr), [new_bias_ptr] "r"(new_bias_ptr),
+          [N] "r"(N), [C] "r"(C), [HXW] "r"(HXW)
+        : "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", "q8", "q9",
+          "q10", "r5", "r6");
+
+    delete[] inv_std_ptr;
+    delete[] new_scale_ptr;
+    delete[] new_bias_ptr;
+
+  } else {
+    float *inv_std_ptr = new float[C];
+    for (int i = 0; i < C; i++) {
+      inv_std_ptr[i] =
+          1 / static_cast<float>(pow((variance_ptr[i] + epsilon), 0.5));
+    }
+
+    Tensor new_scale;
+    auto new_scale_ptr =
+        new_scale.mutable_data<float>(framework::make_ddim({C}));
+    Tensor new_bias;
+    auto new_bias_ptr = new_bias.mutable_data<float>(framework::make_ddim({C}));
+
+    /// ((x - est_mean) * (inv_var) * scale + bias equal to
+    /// (x * inv_var * scale) + (bias - est_mean * inv_var * scale)
+    for (int i = 0; i < C; i++) {
+      new_scale_ptr[i] = inv_std_ptr[i] * scale_ptr[i];
+      new_bias_ptr[i] =
+          bias_ptr[i] - mean_ptr[i] * inv_std_ptr[i] * scale_ptr[i];
+      {
+        for (int n = 0; n < N; n++) {
+          for (int h = 0; h < H; h++) {
+            int tmp_index = n * stride0 + i * stride1 + h * stride2;
+            for (int w = 0; w < W; w++) {
+              int index = tmp_index + w;
+              out_ptr[index] =
+                  input_x_ptr[index] * new_scale_ptr[i] + new_bias_ptr[i];
+            }
+          }
+        }
+      }
+    }
+
+    delete[] inv_std_ptr;
+  }
+}
+
+}  // namespace operators
+}  // namespace paddle_mobile
+
+#endif
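
The NEON fast path in batchnorm_func.h folds mean, variance, scale, bias and epsilon into per-channel new_scale/new_bias factors so the inner loop is a single multiply-add per element; q9/q10 hold those two factors while q1..q8 stream 32 elements at a time. A scalar reference of the same computation, sketch only (names are illustrative, not part of the patch):

    // Scalar reference for the vectorized path: y = x * new_scale[c] + new_bias[c].
    #include <cmath>
    #include <vector>

    void BatchnormReference(const float *x, float *y, int N, int C, int HxW,
                            const float *mean, const float *var,
                            const float *scale, const float *bias,
                            float epsilon) {
      std::vector<float> new_scale(C), new_bias(C);
      for (int c = 0; c < C; ++c) {
        float inv_std = 1.0f / std::sqrt(var[c] + epsilon);
        new_scale[c] = inv_std * scale[c];
        new_bias[c] = bias[c] - mean[c] * inv_std * scale[c];
      }
      for (int n = 0; n < N; ++n) {
        for (int c = 0; c < C; ++c) {
          for (int i = 0; i < HxW; ++i) {
            int idx = (n * C + c) * HxW + i;
            y[idx] = x[idx] * new_scale[c] + new_bias[c];  // one multiply-add
          }
        }
      }
    }
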
diff --git a/src/operators/kernel/central-arm-func/conv_add_relu_func.h b/src/operators/kernel/central-arm-func/conv_add_relu_func.h
new file mode 100644
index 0000000000000000000000000000000000000000..416b4963aefdbd6f2c796378aec4b953a08e28cb
--- /dev/null
+++ b/src/operators/kernel/central-arm-func/conv_add_relu_func.h
@@ -0,0 +1,116 @@
+/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#ifdef FUSION_CONVADD_RELU_OP
+
+#pragma once
+#include "operators/op_param.h"
+
+namespace paddle_mobile {
+namespace operators {
+
+template <typename P>
+void ConvAddReluCompute(const FusionConvAddReluParam &param) {
+  const Tensor *input = param.Input();
+  Tensor filter = *param.Filter();
+  Tensor bias = *param.Bias();
+  int axis = param.Axis();
+  Tensor *output = param.Output();
+  math::expand_bias(bias, axis, output->dims());
+  output->ShareDataWith(bias);
+  int groups = param.Groups();
+  std::vector<int> strides = param.Strides();
+  std::vector<int> paddings = param.Paddings();
+  std::vector<int> dilations = param.Dilations();
+
+  const int batch_size = static_cast<int>(input->dims()[0]);
+
+  std::vector<int64_t> filter_shape_vec(framework::vectorize(filter.dims()));
+
+  std::vector<int64_t> output_shape_vec(framework::vectorize(output->dims()));
+  size_t data_dim = filter_shape_vec.size() - 2;
+  std::vector<int64_t> col_shape_vec(1 + 2 * data_dim);
+  col_shape_vec[0] = input->dims()[1] / groups;
+  for (size_t j = 0; j < data_dim; ++j) {
+    col_shape_vec[j + 1] = filter_shape_vec[j + 2];
+    col_shape_vec[j + 1 + data_dim] = output_shape_vec[j + 2];
+  }
+  framework::DDim col_shape(framework::make_ddim(col_shape_vec));
+
+  framework::DDim col_matrix_shape =
+      framework::flatten_to_2d(col_shape, data_dim + 1);
+
+  bool is_expand =
+      math::IsExpand(filter_shape_vec, strides, paddings, dilations);
+  Tensor col;
+  Tensor col_matrix;
+  if (is_expand) {
+    col.mutable_data<float>(col_shape);
+    col_matrix.ShareDataWith(col);
+    col_matrix.Resize(col_matrix_shape);
+  }
+
+  framework::DDim input_shape = framework::slice_ddim(
+      input->dims(), 1, static_cast<int>(input->dims().size()));
+
+  framework::DDim filter_matrix_shape = {filter.dims()[0],
+                                         filter.numel() / filter.dims()[0]};
+  filter.Resize(filter_matrix_shape);
+  framework::DDim output_matrix_shape = {
+      output->dims()[1],
+      output->numel() / (output->dims()[0] * output->dims()[1])};
+
+  // convolution operator: im2col(or vol2col) + gemm
+  int in_step = static_cast<int>(input->dims()[1]) / groups;
+  int out_step = static_cast<int>(output->dims()[1]) / groups;
+
+  math::Vol2ColFunctor<CPU, float> vol2col;
+  math::Im2ColFunctor<math::ColFormat::kCFO, CPU, float> im2col;
+
+  for (int i = 0; i < batch_size; i++) {
+    Tensor in_batch = input->Slice(i, i + 1).Resize(input_shape);
+    Tensor out_batch = output->Slice(i, i + 1).Resize(output_matrix_shape);
+
+    for (int g = 0; g < groups; g++) {
+      Tensor in_slice = in_batch.Slice(g * in_step, (g + 1) * in_step);
+
+      if (!is_expand) {
+        col.ShareDataWith(in_slice);
+        col_matrix.ShareDataWith(col);
+        col_matrix.Resize(col_matrix_shape);
+      } else if (data_dim == 2U) {
+        // im2col
+        im2col(in_slice, dilations, strides,
+               std::vector<int>{paddings[0], paddings[1], paddings[0],
+                                paddings[1]},
+               &col);
+      } else if (data_dim == 3U) {
+        // vol2col
+        vol2col(in_slice, dilations, strides, paddings, &col);
+      }
+
+      // gemm
+      Tensor out_slice = out_batch.Slice(g * out_step, (g + 1) * out_step);
+      Tensor filter_slice = filter.Slice(g * out_step, (g + 1) * out_step);
+      math::matmul<float>(filter_slice, false, col_matrix, false,
+                          static_cast<float>(1), &out_slice,
+                          static_cast<float>(1), true);
+    }
+  }
+}
+
+}  // namespace operators
+}  // namespace paddle_mobile
+
+#endif
diff --git a/src/operators/kernel/central-arm-func/conv_func.h b/src/operators/kernel/central-arm-func/conv_func.h
new file mode 100644
index 0000000000000000000000000000000000000000..30cfb24043b32effb723f029ef7e5e5cdd1f1e99
--- /dev/null
+++ b/src/operators/kernel/central-arm-func/conv_func.h
@@ -0,0 +1,112 @@
+/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#ifdef CONV_OP
+
+#pragma once
+#include "operators/op_param.h"
+
+namespace paddle_mobile {
+namespace operators {
+
+template <typename P>
+void ConvCompute(const ConvParam &param) {
+  const Tensor *input = param.Input();
+  Tensor filter = *param.Filter();
+  Tensor *output = param.Output();
+  output->mutable_data<float>();
+  int groups = param.Groups();
+  std::vector<int> strides = param.Strides();
+  std::vector<int> paddings = param.Paddings();
+  std::vector<int> dilations = param.Dilations();
+
+  const int batch_size = static_cast<int>(input->dims()[0]);
+
+  std::vector<int64_t> filter_shape_vec(framework::vectorize(filter.dims()));
+
+  std::vector<int64_t> output_shape_vec(framework::vectorize(output->dims()));
+  size_t data_dim = filter_shape_vec.size() - 2;
+  std::vector<int64_t> col_shape_vec(1 + 2 * data_dim);
+  col_shape_vec[0] = input->dims()[1] / groups;
+  for (size_t j = 0; j < data_dim; ++j) {
+    col_shape_vec[j + 1] = filter_shape_vec[j + 2];
+    col_shape_vec[j + 1 + data_dim] = output_shape_vec[j + 2];
+  }
+  framework::DDim col_shape(framework::make_ddim(col_shape_vec));
+
+  framework::DDim col_matrix_shape =
+      framework::flatten_to_2d(col_shape, data_dim + 1);
+
+  bool is_expand = IsExpand(filter_shape_vec, strides, paddings, dilations);
+  Tensor col;
+  Tensor col_matrix;
+  if (is_expand) {
+    col.mutable_data<float>(col_shape);
+    col_matrix.ShareDataWith(col);
+    col_matrix.Resize(col_matrix_shape);
+  }
+
+  framework::DDim input_shape = framework::slice_ddim(
+      input->dims(), 1, static_cast<int>(input->dims().size()));
+
+  framework::DDim filter_matrix_shape = {filter.dims()[0],
+                                         filter.numel() / filter.dims()[0]};
+  filter.Resize(filter_matrix_shape);
+  framework::DDim output_matrix_shape = {
+      output->dims()[1],
+      output->numel() / (output->dims()[0] * output->dims()[1])};
+
+  // convolution operator: im2col(or vol2col) + gemm
+  int in_step = static_cast<int>(input->dims()[1]) / groups;
+  int out_step = static_cast<int>(output->dims()[1]) / groups;
+
+  math::Vol2ColFunctor<CPU, float> vol2col;
+  math::Im2ColFunctor<math::ColFormat::kCFO, CPU, float> im2col;
+
+  for (int i = 0; i < batch_size; i++) {
+    Tensor in_batch = input->Slice(i, i + 1).Resize(input_shape);
+    Tensor out_batch = output->Slice(i, i + 1).Resize(output_matrix_shape);
+
+    for (int g = 0; g < groups; g++) {
+      Tensor in_slice = in_batch.Slice(g * in_step, (g + 1) * in_step);
+
+      if (!is_expand) {
+        col.ShareDataWith(in_slice);
+        col_matrix.ShareDataWith(col);
+        col_matrix.Resize(col_matrix_shape);
+      } else if (data_dim == 2U) {
+        // im2col
+        im2col(in_slice, dilations, strides,
+               std::vector<int>{paddings[0], paddings[1], paddings[0],
+                                paddings[1]},
+               &col);
+      } else if (data_dim == 3U) {
+        // vol2col
+        vol2col(in_slice, dilations, strides, paddings, &col);
+      }
+
+      // gemm
+      Tensor out_slice = out_batch.Slice(g * out_step, (g + 1) * out_step);
+      Tensor filter_slice = filter.Slice(g * out_step, (g + 1) * out_step);
+      math::matmul<float>(filter_slice, false, col_matrix, false,
+                          static_cast<float>(1), &out_slice,
+                          static_cast<float>(0));
+    }
+  }
+}
+
+}  // namespace operators
+}  // namespace paddle_mobile
+
+#endif
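
Both compute helpers implement convolution as im2col (vol2col in the 3-D case) followed by one GEMM per group; the Resize calls above pin down the matrix shapes, roughly as follows for the 2-D case (names illustrative only):

    // Shape sketch for one group g, inferred from the Resize() calls above.
    //   filter_slice : [out_c / groups, (in_c / groups) * k_h * k_w]
    //   col_matrix   : [(in_c / groups) * k_h * k_w, out_h * out_w]   (im2col output)
    //   out_slice    : [out_c / groups, out_h * out_w]
    // so each group's output is the single product filter_slice * col_matrix.
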
diff --git a/tools/push2android.sh b/tools/android-debug-script/push2android.sh
similarity index 59%
rename from tools/push2android.sh
rename to tools/android-debug-script/push2android.sh
index d7d1ad9950d58f415804834b8ebc0740a3e796cb..fae1a856123bd16cf3f7a115f61b3e4473ff58a3 100644
--- a/tools/push2android.sh
+++ b/tools/android-debug-script/push2android.sh
@@ -1,10 +1,10 @@
 #!/usr/bin/env sh
 
 push_fn () {
-MODELS_PATH="../test/models/*"
-MODELS_SRC="../test/models"
-IMAGE_PATH="../test/images/*"
-EXE_FILE="../test/build/*"
+MODELS_PATH="../../test/models/*"
+MODELS_SRC="../../test/models"
+IMAGE_PATH="../../test/images/*"
+EXE_FILE="../../test/build/*"
 EXE_DIR="data/local/tmp/bin"
 adb shell mkdir ${EXE_DIR}
 MODELS_DIR="data/local/tmp/models"
@@ -14,9 +14,14 @@ do
   adb shell mkdir ${MODELS_DIR}"/"${file}
 done
 
+if [[ -d "../../src/operators/kernel/mali/ACL_Android/build" ]]; then
+ACL_BUILD_PATH="../../src/operators/kernel/mali/ACL_Android/build/*"
+adb push ${ACL_BUILD_PATH} ${EXE_DIR}
+fi
+
 IMAGES_DIR="data/local/tmp/images"
 adb shell mkdir ${IMAGES_DIR}
-LIB_PATH="../build/release/arm-v7a/build/*"
+LIB_PATH="../../build/release/arm-v7a/build/*"
 adb push ${EXE_FILE} ${EXE_DIR}
 adb push ${LIB_PATH} ${EXE_DIR}
 if [[ $1 != "npm" ]]; then
diff --git a/tools/scripts/run_on_android.sh b/tools/android-debug-script/run_on_android.sh
similarity index 100%
rename from tools/scripts/run_on_android.sh
rename to tools/android-debug-script/run_on_android.sh
diff --git a/tools/run.sh b/tools/run.sh
deleted file mode 100644
index aaf0f52f0335d6e73060ed9b8e86a78ba357c552..0000000000000000000000000000000000000000
--- a/tools/run.sh
+++ /dev/null
@@ -1,38 +0,0 @@
-#!/usr/bin/env sh
-# auto build and run
-
-BUILDNET="mobilenetssd"
-TESTUNIT="test-mobilenetssd"
-
-push_fn () {
-sh build.sh android ${BUILDNET}
-MODELS_PATH="../test/models/*"
-MODELS_SRC="../test/models"
-IMAGE_PATH="../test/images/*"
-EXE_FILE="../test/build/*"
-EXE_DIR="data/local/tmp/bin"
-adb shell mkdir ${EXE_DIR}
-MODELS_DIR="data/local/tmp/models"
-adb shell mkdir ${MODELS_DIR}
-for file in `ls ${MODELS_SRC}`
-do
-  adb shell mkdir ${MODELS_DIR}"/"${file}
-done
-
-IMAGES_DIR="data/local/tmp/images"
-adb shell mkdir ${IMAGES_DIR}
-LIB_PATH="../build/release/arm-v7a/build/*"
-adb push ${EXE_FILE} ${EXE_DIR}
-adb push ${LIB_PATH} ${EXE_DIR}
-if [[ $1 != "npm" ]]; then
-adb push ${IMAGE_PATH} ${IMAGES_DIR}
-adb push ${MODELS_PATH} ${MODELS_DIR}
-fi
-adb shell "cd /data/local/tmp/bin; LD_LIBRARY_PATH=. ./${TESTUNIT}"
-}
-
-if [[ $1 == "npm" ]]; then
-push_fn $1
-else
-push_fn
-fi