diff --git a/src/framework/executor.cpp b/src/framework/executor.cpp index 6fa97687e74d1128c29b1858338e129c4544402f..fe95e6c7f3e913e0c0801b0371ffe5a179fb77ff 100644 --- a/src/framework/executor.cpp +++ b/src/framework/executor.cpp @@ -28,7 +28,7 @@ limitations under the License. */ #include "framework/scope.h" #include "framework/tensor.h" #include "memory/t_malloc.h" - +#include "pass/memory_optimize.h" #ifdef PADDLE_MOBILE_CL #include "framework/cl/cl_image.h" #endif @@ -62,6 +62,7 @@ Executor::Executor(const Program &program, use_optimize_ ? program_.optimizeProgram : program_.originProgram; PADDLE_MOBILE_ENFORCE(program_desc_ != nullptr, "program_desc_ should not be nullptr"); + pass::MemoryOptPass()(program_desc_.get(), program_.scope.get()); // resize feed and fetch list // should init feed and fetch variables before infer shape InitFeedFetchList(); @@ -210,6 +211,7 @@ void Executor::InitMemory() { var->template GetMutable(); continue; } + DLOG << "init persistable var: " << var_desc->Name(); char *origin_data = ReadFileToBuff(program_.model_path + "/" + var_desc->Name()); char *data = origin_data; @@ -322,7 +324,6 @@ bool Executor::varInputMemory( if (type == VARTYPE_TYPE_LOD_TENSOR) { auto data_type = var_desc->Tensor_desc().DataType(); framework::LoDTensor *tensor = var->template GetMutable(); - tensor->mutable_data(TypeId(data_type)); } else if (type == VARTYPE_TYPE_STEP_SCOPES) { std::vector *step_scopes = var->template GetMutable>(); @@ -458,6 +459,7 @@ PMStatus Executor::Predict() { clock_gettime(CLOCK_MONOTONIC, &ts); profile[op_index].runBegin = (uint64_t)ts.tv_sec * 1e9 + ts.tv_nsec; #endif + DLOG << "run op: " << op_handler->Type(); if (lod_mode_) { op_handler->InferShape(); } diff --git a/src/framework/program/program_desc.cpp b/src/framework/program/program_desc.cpp index b66c7a0dcf97ef8517e1122d2834aa992736c6e7..23781fe77962608b515ebb7c0479f284ebfc4277 100644 --- a/src/framework/program/program_desc.cpp +++ b/src/framework/program/program_desc.cpp @@ -46,7 +46,7 @@ ProgramDesc::ProgramDesc(PaddleMobile__Framework__Proto__ProgramDesc *desc) { } } -void ProgramDesc::Description(std::string header) { +void ProgramDesc::Description(std::string header) const { #ifdef PADDLE_MOBILE_DEBUG if (header.size()) { LOG(kLOG_INFO) << header; diff --git a/src/framework/program/program_desc.h b/src/framework/program/program_desc.h index 5c75c915223d0768120b4153c38a3772ba74d8e9..f4551509ee2846e96e8e9b672a22b9de673658ab 100644 --- a/src/framework/program/program_desc.h +++ b/src/framework/program/program_desc.h @@ -30,6 +30,14 @@ class ProgramDesc { friend class ProgramOptimize; explicit ProgramDesc(PaddleMobile__Framework__Proto__ProgramDesc *desc); + ProgramDesc(const ProgramDesc &program_desc) { + for (auto &block : program_desc.blocks_) { + std::shared_ptr copy_block = + std::make_shared(*block); + blocks_.push_back(copy_block); + } + } + std::shared_ptr Block(size_t idx); BlockDesc *MutableBlock(size_t idx) { @@ -40,16 +48,11 @@ class ProgramDesc { } } - const std::vector> &Blocks() { return blocks_; } - ProgramDesc(const ProgramDesc &program_desc) { - for (auto &block : program_desc.blocks_) { - std::shared_ptr copy_block = - std::make_shared(*block); - blocks_.push_back(copy_block); - } + const std::vector> &Blocks() const { + return blocks_; } - void Description(std::string header = ""); + void Description(std::string header = "") const; private: std::vector> blocks_; diff --git a/src/framework/tensor.h b/src/framework/tensor.h index 
24f09662ea5ecca2a96ccdac7e863034f6a3a311..4fb06c654983b9e1b8441b074d5a30220fb960a2 100644 --- a/src/framework/tensor.h +++ b/src/framework/tensor.h @@ -74,6 +74,15 @@ class Tensor : public TensorBase { return *this; } + /*! The internal of two tensors share the same memory block. */ + inline Tensor &ShareHolderWith(const Tensor &src) { + src.check_memory_size(); + if (holder_.get() != src.holder_.get()) { + holder_ = src.holder_; + } + return *this; + } + inline void *mutable_data(std::type_index type) { if (holder_ != nullptr) { holder_->set_type(type); @@ -81,7 +90,11 @@ class Tensor : public TensorBase { PADDLE_MOBILE_ENFORCE(numel() >= 0, "the Tensor's numel must >=0.") int64_t size = numel() * SizeOfType(type); if (holder_ == nullptr || holder_->size() < size + offset_) { - holder_.reset(new PlaceholderImpl(size, type)); + if (holder_ == nullptr) { + holder_.reset(new PlaceholderImpl(size, type)); + } else { + holder_->resize(size); + } offset_ = 0; } return reinterpret_cast( @@ -180,6 +193,7 @@ class Tensor : public TensorBase { : ptr_(static_cast(memory::Alloc(size)), memory::PODDeleter()), size_(size), + capatity_(size), type_(type) { PADDLE_MOBILE_ENFORCE(ptr_ != nullptr, "Insufficient memory to allocation"); @@ -193,11 +207,21 @@ class Tensor : public TensorBase { virtual void set_type(std::type_index type) { type_ = type; } + virtual void resize(size_t size) { + if (size > capatity_) { + capatity_ = size; + ptr_.reset(static_cast(memory::Alloc(capatity_))); + } + size_ = size; + } + std::unique_ptr> ptr_; /*! the size of memory block. */ size_t size_; + size_t capatity_; + /* the current type of memory */ std::type_index type_; }; diff --git a/src/framework/tensor_base.h b/src/framework/tensor_base.h index b41d7786c15222b3133d02b820cf4e089b19c1d3..e5ab7793c0eb392126359b050b9f0b8bc25c3515 100644 --- a/src/framework/tensor_base.h +++ b/src/framework/tensor_base.h @@ -117,6 +117,8 @@ class TensorBase { virtual std::type_index type() const = 0; virtual void set_type(std::type_index type) = 0; + + virtual void resize(size_t size) = 0; }; /** diff --git a/src/operators/kernel/arm/compare_kernel.cpp b/src/operators/kernel/arm/compare_kernel.cpp index c2ba6b583a1f80c3545f565464ff614ba4aaf52e..d83fae1748bc9b942425219ea169b63d5c84b867 100644 --- a/src/operators/kernel/arm/compare_kernel.cpp +++ b/src/operators/kernel/arm/compare_kernel.cpp @@ -79,7 +79,7 @@ struct CompareCompute { if (elementwise_num == 1) { int remain_start = 0; #if defined(__ARM_NEON__) || defined(__ARM_NEON) - remain_start = channels & 0xfff8; + remain_start = channels & 0xfffffff8; uint8x8_t __mask = vdup_n_u8(0x1); for (int i = 0; i < batch; ++i) { for (int j = 0; j < channels - 7; j += 8) { @@ -112,7 +112,7 @@ struct CompareCompute { int y_offset = j * elementwise_num; int remain_start = 0; #if defined(__ARM_NEON__) || defined(__ARM_NEON) - remain_start = elementwise_num & 0xfff8; + remain_start = elementwise_num & 0xfffffff8; uint8x8_t __mask = vdup_n_u8(0x1); for (int k = 0; k < elementwise_num - 7; k += 8) { float32x4_t __x0 = vld1q_f32(x + x_offset); diff --git a/src/operators/kernel/arm/convolution/conv_add_bn_relu_kernel.cpp b/src/operators/kernel/arm/convolution/conv_add_bn_relu_kernel.cpp index 1a256eb733a11892c72ef4a12a84c78b914d87e6..d8d17dec2d3fefec174b756791792e734d37a9c7 100644 --- a/src/operators/kernel/arm/convolution/conv_add_bn_relu_kernel.cpp +++ b/src/operators/kernel/arm/convolution/conv_add_bn_relu_kernel.cpp @@ -18,7 +18,7 @@ limitations under the License. 
*/ #include #include "operators/kernel/arm/convolution/conv_common.h" #include "operators/kernel/central-arm-func/conv_arm_func.h" -#include "operators/math/channel_wise.h" +#include "operators/math/element_wise.h" namespace paddle_mobile { namespace operators { diff --git a/src/operators/kernel/arm/convolution/conv_add_kernel.cpp b/src/operators/kernel/arm/convolution/conv_add_kernel.cpp index 3ac1315ba9d0df36725ad6937594a3a8ddf82bf4..dabe1d389b9734c8613f18b908bf936c5ad6c353 100644 --- a/src/operators/kernel/arm/convolution/conv_add_kernel.cpp +++ b/src/operators/kernel/arm/convolution/conv_add_kernel.cpp @@ -17,7 +17,7 @@ limitations under the License. */ #include "operators/kernel/conv_add_kernel.h" #include "operators/kernel/arm/convolution/conv_common.h" #include "operators/kernel/central-arm-func/conv_arm_func.h" -#include "operators/math/channel_wise.h" +#include "operators/math/element_wise.h" namespace paddle_mobile { namespace operators { diff --git a/src/operators/kernel/arm/convolution/conv_add_relu_kernel.cpp b/src/operators/kernel/arm/convolution/conv_add_relu_kernel.cpp index 104bb6d8b227455594ab34a37dabdb978553aac1..c06cd33dc5e8a40630842e36388a6d5b5ed1be45 100644 --- a/src/operators/kernel/arm/convolution/conv_add_relu_kernel.cpp +++ b/src/operators/kernel/arm/convolution/conv_add_relu_kernel.cpp @@ -17,7 +17,7 @@ limitations under the License. */ #include "operators/kernel/conv_add_relu_kernel.h" #include "operators/kernel/arm/convolution/conv_common.h" #include "operators/kernel/central-arm-func/conv_arm_func.h" -#include "operators/math/channel_wise.h" +#include "operators/math/element_wise.h" namespace paddle_mobile { namespace operators { diff --git a/src/operators/kernel/arm/convolution/conv_bn_add_relu_kernel.cpp b/src/operators/kernel/arm/convolution/conv_bn_add_relu_kernel.cpp index ceb1cf5144212f2e0e791b70d8a36ed3b7a62700..509bfd4df456116fb9fe27762978105dcc42f54b 100644 --- a/src/operators/kernel/arm/convolution/conv_bn_add_relu_kernel.cpp +++ b/src/operators/kernel/arm/convolution/conv_bn_add_relu_kernel.cpp @@ -18,7 +18,7 @@ limitations under the License. 
*/ #include #include "operators/kernel/arm/convolution/conv_common.h" #include "operators/kernel/central-arm-func/conv_arm_func.h" -#include "operators/math/channel_wise.h" +#include "operators/math/element_wise.h" namespace paddle_mobile { namespace operators { @@ -34,27 +34,15 @@ bool ConvBNAddReluKernel::Init( auto mean_ptr = mean->data(); auto variance_ptr = variance->data(); - auto scale_ptr = scale->data(); - auto bias_ptr = bias->data(); + auto scale_ptr = const_cast(scale->data()); + auto bias_ptr = const_cast(bias->data()); - const int C = mean->numel(); - float inv_std_ptr[C]; - for (int i = 0; i < C; i++) { - inv_std_ptr[i] = - 1 / static_cast(pow((variance_ptr[i] + epsilon), 0.5)); + for (int c = 0; c < scale->numel(); ++c) { + float inv_scale = 1.f / (pow(variance_ptr[c] + epsilon, 0.5)); + bias_ptr[c] -= inv_scale * scale_ptr[c] * mean_ptr[c]; + scale_ptr[c] *= inv_scale; } - auto *new_scale = param->CreateNewScale(); - auto *new_bias = param->CreateNewBiase(); - auto new_scale_ptr = new_scale->mutable_data({C}); - auto new_bias_ptr = new_bias->mutable_data({C}); - for (int i = 0; i < C; i++) { - new_scale_ptr[i] = inv_std_ptr[i] * scale_ptr[i]; - new_bias_ptr[i] = bias_ptr[i] - mean_ptr[i] * inv_std_ptr[i] * scale_ptr[i]; - } - param->SetNewScale(new_scale); - param->SetNewBias(new_bias); - InitBaseConvKernel(param); return true; } @@ -84,9 +72,19 @@ void ConvBNAddReluKernel::Compute( PADDLE_MOBILE_THROW_EXCEPTION("Invalid convolution execute mode %d", param.ExecMode()); } - math::ScaleAddChannelWise(param.Output(), param.NewScale(), - param.NewBias(), param.Output()); + + if (param.Bias()->dims() == param.Output()->dims()) { + math::ScaleAddChannelWise(param.Output(), param.InputScale(), + param.InputBias(), param.Bias(), + param.Output()); + } else { + math::ScaleAddChannelWise(param.Output(), param.InputScale(), + param.InputBias(), param.Output()); + math::AddElememtWise(param.Output(), param.Bias(), param.Axis(), + param.Output()); + } } + template class ConvBNAddReluKernel; } // namespace operators diff --git a/src/operators/kernel/arm/convolution/conv_bn_relu_kernel.cpp b/src/operators/kernel/arm/convolution/conv_bn_relu_kernel.cpp index eafb9f763108b28d627f14f9a9d04e4378de4423..1bfbfccae49b185fdc221f7208f350b16719e353 100644 --- a/src/operators/kernel/arm/convolution/conv_bn_relu_kernel.cpp +++ b/src/operators/kernel/arm/convolution/conv_bn_relu_kernel.cpp @@ -18,7 +18,7 @@ limitations under the License. 
*/ #include #include "operators/kernel/arm/convolution/conv_common.h" #include "operators/kernel/central-arm-func/conv_arm_func.h" -#include "operators/math/channel_wise.h" +#include "operators/math/element_wise.h" namespace paddle_mobile { namespace operators { diff --git a/src/operators/kernel/arm/convolution/conv_common.cpp b/src/operators/kernel/arm/convolution/conv_common.cpp index 361d315a59aa8ace4e964a25514a8ebbf165717d..86c6e8d3373f743f485e4a69599e8c8323ba0083 100644 --- a/src/operators/kernel/arm/convolution/conv_common.cpp +++ b/src/operators/kernel/arm/convolution/conv_common.cpp @@ -52,7 +52,7 @@ void InitBaseConvKernel(ConvParam *param) { } else if (depth5x5 && param->Strides()[0] == param->Strides()[1] && param->Strides()[0] == 1) { param->ExecMode() = ConvParam::EXEC_DEPTHWISE5x5_FLOAT; - } else if (conv3x3 && !depth3x3 && + } else if (conv3x3 && param->Groups() == 1 && param->Strides()[0] == param->Strides()[1] && param->Dilations()[0] == param->Dilations()[1] && param->Strides()[0] == 1 && param->Dilations()[0] == 1 @@ -66,7 +66,7 @@ void InitBaseConvKernel(ConvParam *param) { param->transformed_filter_ = new framework::LoDTensor; operators::math::winograd_transform_weight<8, 3>( *param->Filter(), param->transformed_filter_); - } else if (conv3x3 && !depth3x3 && + } else if (conv3x3 && param->Groups() == 1 && param->Strides()[0] == param->Strides()[1] && param->Dilations()[0] == param->Dilations()[1] && param->Strides()[0] == 1 && param->Dilations()[0] == 1 @@ -76,7 +76,7 @@ void InitBaseConvKernel(ConvParam *param) { #endif ) { param->ExecMode() = ConvParam::EXEC_SLIDINGWINDOW3x3S1_FLOAT; - } else if (conv3x3 && !depth3x3 && + } else if (conv3x3 && param->Groups() == 1 && param->Strides()[0] == param->Strides()[1] && param->Dilations()[0] == param->Dilations()[1] && param->Strides()[0] == 2 && param->Dilations()[0] == 1 diff --git a/src/operators/kernel/arm/convolution/dwconv_bn_relu_kernel.cpp b/src/operators/kernel/arm/convolution/dwconv_bn_relu_kernel.cpp index 063d51330eb5cc03c5596a8e209480ba0505009f..2035ad4739e7b6cef9f60df6de5d6b5f0f2a2125 100644 --- a/src/operators/kernel/arm/convolution/dwconv_bn_relu_kernel.cpp +++ b/src/operators/kernel/arm/convolution/dwconv_bn_relu_kernel.cpp @@ -18,7 +18,7 @@ limitations under the License. 
*/ #include #include "operators/kernel/arm/convolution/conv_common.h" #include "operators/kernel/central-arm-func/conv_arm_func.h" -#include "operators/math/channel_wise.h" +#include "operators/math/element_wise.h" namespace paddle_mobile { namespace operators { diff --git a/src/operators/kernel/arm/sequence_pool_kernel.cpp b/src/operators/kernel/arm/sequence_pool_kernel.cpp index 352158b973050c99555a82c0d0f02c318b7702ac..2be2accf58cc1fca527c44d9083538c4a72c828d 100644 --- a/src/operators/kernel/arm/sequence_pool_kernel.cpp +++ b/src/operators/kernel/arm/sequence_pool_kernel.cpp @@ -68,7 +68,7 @@ void SequencePoolImpl(const framework::LoDTensor &input, int remain_h = height - 1; int remain_w_start = 0; #ifdef __ARM_NEON__ - remain_w_start = width & 0xfffc; + remain_w_start = width & 0xfffffffc; #endif // __ARM_NEON__ for (int h = 0; h < remain_h; ++h) { #ifdef __ARM_NEON__ @@ -128,7 +128,7 @@ void SequencePoolImpl(const framework::LoDTensor &input, int remain_w_start = 0; #ifdef __ARM_NEON__ int loop_w = width >> 2; - remain_w_start = width & 0xfffc; + remain_w_start = width & 0xfffffffc; #endif // __ARM_NEON__ for (int h = 0; h < remain_h; ++h) { #ifdef __ARM_NEON__ diff --git a/src/operators/kernel/central-arm-func/elementwise_add_arm_func.h b/src/operators/kernel/central-arm-func/elementwise_add_arm_func.h index df78b96147b270b592eea68668550b0c55fde0bf..5a2b416b79ada2eea8ab1aad9135a223885c05be 100644 --- a/src/operators/kernel/central-arm-func/elementwise_add_arm_func.h +++ b/src/operators/kernel/central-arm-func/elementwise_add_arm_func.h @@ -16,7 +16,7 @@ limitations under the License. */ #pragma once -#include "operators/math/elementwise_op_function.h" +#include "operators/math/element_wise.h" #include "operators/op_param.h" #if defined(__ARM_NEON__) || defined(__ARM_NEON) #include @@ -29,98 +29,10 @@ template inline void ElementwiseAddCompute(const ElementwiseAddParam ¶m) { const framework::Tensor *input_x = param.InputX(); const framework::Tensor *input_y = param.InputY(); - framework::Tensor *Out = param.Out(); + framework::Tensor *output = param.Out(); int axis = param.Axis(); - const auto &x_dims = input_x->dims(); - const auto &y_dims = input_y->dims(); - /// axis = -1 represent the last dimensions. - axis = (axis == -1 ? 
x_dims.size() - y_dims.size() : axis); - size_t batch = 1; - size_t channels = 1; - size_t elementwise_num = 1; - for (int i = 0; i < axis; ++i) { - batch *= x_dims[i]; - } - for (int i = 0; i < y_dims.size(); ++i) { - channels *= y_dims[i]; - } - for (int i = y_dims.size() + axis; i < x_dims.size(); ++i) { - elementwise_num *= x_dims[i]; - } - const float *bias_data = input_y->data(); - const float *input_data = input_x->data(); - float *output_data = Out->mutable_data(); - - #pragma omp parallel for collapse(2) - for (int i = 0; i < batch; ++i) { - for (int j = 0; j < channels; ++j) { - size_t offset = (i * channels + j) * elementwise_num; - const float *input = input_data + offset; - const float bias = bias_data[j]; - float *output = output_data + offset; -#if defined(__ARM_NEON__) || defined(__ARM_NEON) - int loop = elementwise_num >> 0x4; - int remain = elementwise_num & 0xF; - float32x4_t rb = vdupq_n_f32(bias); - for (int k = 0; k < loop; ++k) { - float32x4_t r0 = vld1q_f32(input); - float32x4_t r1 = vld1q_f32(input + 4); - float32x4_t r2 = vld1q_f32(input + 8); - float32x4_t r3 = vld1q_f32(input + 12); - r0 = vaddq_f32(r0, rb); - r1 = vaddq_f32(r1, rb); - r2 = vaddq_f32(r2, rb); - r3 = vaddq_f32(r3, rb); - vst1q_f32(output, r0); - vst1q_f32(output + 4, r1); - vst1q_f32(output + 8, r2); - vst1q_f32(output + 12, r3); - input += 16; - output += 16; - } - if (remain >= 8) { - float32x4_t r0 = vld1q_f32(input); - float32x4_t r1 = vld1q_f32(input + 4); - r0 = vaddq_f32(r0, rb); - r1 = vaddq_f32(r1, rb); - vst1q_f32(output, r0); - vst1q_f32(output + 4, r1); - input += 8; - output += 8; - remain -= 8; - } - if (remain >= 4) { - float32x4_t r0 = vld1q_f32(input); - r0 = vaddq_f32(r0, rb); - vst1q_f32(output, r0); - input += 4; - output += 4; - remain -= 4; - } - if (remain > 0) { - float32x4_t r0 = vld1q_f32(input); - r0 = vaddq_f32(r0, rb); - switch (remain) { - case 1: - vst1q_lane_f32(output, r0, 0); - break; - case 2: - vst1_f32(output, vget_low_f32(r0)); - break; - case 3: - vst1_f32(output, vget_low_f32(r0)); - vst1q_lane_f32(output, r0, 2); - break; - } - } -#else - for (int k = 0; k < elementwise_num; ++k) { - output[k] = input[k] + bias; - } -#endif // __ARM_NEON__ - } - } + math::AddElememtWise(input_x, input_y, axis, output); } template class ElementwiseAddKernel; diff --git a/src/operators/math/channel_wise.h b/src/operators/math/channel_wise.h deleted file mode 100644 index e4c0cbe05bfabde42df7f33a71882aa8ec08c477..0000000000000000000000000000000000000000 --- a/src/operators/math/channel_wise.h +++ /dev/null @@ -1,138 +0,0 @@ -/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
*/ - -#pragma once - -#include "framework/tensor.h" -#include "operators/math/activation.h" -#ifdef __ARM_NEON -#include -#endif - -namespace paddle_mobile { -namespace operators { -namespace math { - -template -void AddChannelWise(const framework::Tensor *input, - const framework::Tensor *bias, framework::Tensor *output) { - const float *input_ptr = input->data(); - const float *bias_ptr = bias->data(); - float *output_ptr = output->mutable_data(); - // maybe check shape - int batch_size = input->dims()[0]; - int channels = input->dims()[1]; - int spatial_size = input->dims()[2] * input->dims()[3]; - - for (int batch = 0; batch < batch_size; ++batch) { - for (int channel = 0; channel < channels; ++channel) { - size_t offset = (batch * channels + channel) * spatial_size; - const float *x = input_ptr + offset; - float *y = output_ptr + offset; - float beta = bias_ptr[channel]; - int j = 0; -#if defined(__ARM_NEON__) || defined(__ARM_NEON) - float32x4_t __bias = vdupq_n_f32(beta); - for (; j < spatial_size - 15; j += 16, x += 16, y += 16) { - float32x4_t in0 = vld1q_f32(x); - float32x4_t in1 = vld1q_f32(x + 4); - float32x4_t in2 = vld1q_f32(x + 8); - float32x4_t in3 = vld1q_f32(x + 12); - in0 = vaddq_f32(__bias, in0); - in1 = vaddq_f32(__bias, in1); - in2 = vaddq_f32(__bias, in2); - in3 = vaddq_f32(__bias, in3); - in0 = math::vActiveq_f32(in0); - in1 = math::vActiveq_f32(in1); - in2 = math::vActiveq_f32(in2); - in3 = math::vActiveq_f32(in3); - vst1q_f32(y, in0); - vst1q_f32(y + 4, in1); - vst1q_f32(y + 8, in2); - vst1q_f32(y + 12, in3); - } - for (; j < spatial_size - 3; j += 4, x += 4, y += 4) { - float32x4_t in0 = vld1q_f32(x); - in0 = vaddq_f32(__bias, in0); - in0 = math::vActiveq_f32(in0); - vst1q_f32(y, in0); - } -#endif - for (; j < spatial_size; ++j, ++x, ++y) { - *y = math::Active((*x) + beta); - } - } - } -} - -template -void ScaleAddChannelWise(const framework::Tensor *input, - const framework::Tensor *scale, - const framework::Tensor *bias, - framework::Tensor *output) { - const float *input_ptr = input->data(); - const float *scale_ptr = scale->data(); - const float *bias_ptr = bias->data(); - float *output_ptr = output->mutable_data(); - // maybe check shape - int batch_size = input->dims()[0]; - int channels = input->dims()[1]; - int spatial_size = input->dims()[2] * input->dims()[3]; - - for (int batch = 0; batch < batch_size; ++batch) { - for (int channel = 0; channel < channels; ++channel) { - size_t offset = (batch * channels + channel) * spatial_size; - const float *x = input_ptr + offset; - float *y = output_ptr + offset; - float alpha = scale_ptr[channel]; - float beta = bias_ptr[channel]; - int j = 0; -#if defined(__ARM_NEON__) || defined(__ARM_NEON) - float32x4_t __scale = vdupq_n_f32(alpha); - float32x4_t __bias = vdupq_n_f32(beta); - for (; j < spatial_size - 15; j += 16, x += 16, y += 16) { - float32x4_t in0 = vld1q_f32(x); - float32x4_t in1 = vld1q_f32(x + 4); - float32x4_t in2 = vld1q_f32(x + 8); - float32x4_t in3 = vld1q_f32(x + 12); - in0 = vmlaq_f32(__bias, __scale, in0); - in1 = vmlaq_f32(__bias, __scale, in1); - in2 = vmlaq_f32(__bias, __scale, in2); - in3 = vmlaq_f32(__bias, __scale, in3); - in0 = math::vActiveq_f32(in0); - in1 = math::vActiveq_f32(in1); - in2 = math::vActiveq_f32(in2); - in3 = math::vActiveq_f32(in3); - vst1q_f32(y, in0); - vst1q_f32(y + 4, in1); - vst1q_f32(y + 8, in2); - vst1q_f32(y + 12, in3); - } - for (; j < spatial_size - 3; j += 4, x += 4, y += 4) { - float32x4_t in0 = vld1q_f32(x); - in0 = vmlaq_f32(__bias, __scale, in0); - in0 = 
math::vActiveq_f32(in0); - vst1q_f32(y, in0); - } -#endif - for (; j < spatial_size; ++j, ++x, ++y) { - *y = math::Active(alpha * (*x) + beta); - } - } - } -} - -} // namespace math -} // namespace operators -} // namespace paddle_mobile diff --git a/src/operators/math/depthwise_conv3x3.cpp b/src/operators/math/depthwise_conv3x3.cpp index 62fae35060c97e52142143fcc87b7571b13b1959..807067cc33190f0b84a7f87dd5bb2b1ede5377b2 100644 --- a/src/operators/math/depthwise_conv3x3.cpp +++ b/src/operators/math/depthwise_conv3x3.cpp @@ -448,7 +448,7 @@ void DepthwiseConv3x3S1(const framework::Tensor &input, } } // remain height - int start_h = valid_h_start + (valid_h & 0xfffe); + int start_h = valid_h_start + (valid_h & 0xfffffffe); if (start_h < valid_h_end) { const float *input_ptr0 = input_ptr + (start_h - padding_h) * input_w; const float *input_ptr1 = input_ptr0 + input_w; @@ -906,7 +906,7 @@ void DepthwiseConv3x3S2(const framework::Tensor &input, } } // remain height - int start_h = valid_h_start + (valid_h & 0xfffe); + int start_h = valid_h_start + (valid_h & 0xfffffffe); if (start_h < valid_h_end) { const float *input_ptr0 = input_ptr + (2 * start_h - padding_h) * input_w; const float *input_ptr1 = input_ptr0 + input_w; diff --git a/src/operators/math/depthwise_conv3x3_int8.cpp b/src/operators/math/depthwise_conv3x3_int8.cpp index b8d7939badbfafb0f5c3ee2034320bf817eb5c32..e69df3e6bec76e74c64178e8b790642764a7c35c 100644 --- a/src/operators/math/depthwise_conv3x3_int8.cpp +++ b/src/operators/math/depthwise_conv3x3_int8.cpp @@ -580,7 +580,7 @@ void DepthwiseConv3x3S1(const framework::Tensor &input, } } // remain height - int start_h = valid_h_start + (valid_h & 0xFFFC); + int start_h = valid_h_start + (valid_h & 0xFFFFFFFC); for (int h = start_h; h < valid_h_end - 1; h += 2) { const int8_t *input_ptr0 = input_ptr + (h - padding_h) * input_w; const int8_t *input_ptr1 = input_ptr0 + input_w; @@ -844,7 +844,7 @@ void DepthwiseConv3x3S1(const framework::Tensor &input, } } - start_h = valid_h_start + (valid_h & 0xFFFE); + start_h = valid_h_start + (valid_h & 0xFFFFFFFE); if (start_h < valid_h_end) { const int8_t *input_ptr0 = input_ptr + (start_h - padding_h) * input_w; const int8_t *input_ptr1 = input_ptr0 + input_w; diff --git a/src/operators/math/depthwise_conv5x5.cpp b/src/operators/math/depthwise_conv5x5.cpp index 99c08c185cc8b28a4159226c2f0502794e0a0c37..a721cce71efcc44328274f41f812b24da7d2370e 100644 --- a/src/operators/math/depthwise_conv5x5.cpp +++ b/src/operators/math/depthwise_conv5x5.cpp @@ -721,7 +721,7 @@ void DepthwiseConv5x5S1(const framework::Tensor &input, } } // remain height - int start_h = valid_h_start + (valid_h & 0xfffe); + int start_h = valid_h_start + (valid_h & 0xfffffffe); if (start_h < valid_h_end) { const float *input_ptr0 = input_ptr + (start_h - padding_h) * input_w; const float *input_ptr1 = input_ptr0 + input_w; diff --git a/src/operators/math/depthwise_conv5x5_int8.cpp b/src/operators/math/depthwise_conv5x5_int8.cpp index a92d48272f3c3abbc9c86a652521db4564498d2e..1e9482beb4d0f46532becc5fa86fc6590e7790aa 100644 --- a/src/operators/math/depthwise_conv5x5_int8.cpp +++ b/src/operators/math/depthwise_conv5x5_int8.cpp @@ -686,7 +686,7 @@ void DepthwiseConv5x5S1(const framework::Tensor &input, } } // remain height - int start_h = valid_h_start + (valid_h & 0xfffe); + int start_h = valid_h_start + (valid_h & 0xfffffffe); if (start_h < valid_h_end) { const int8_t *input_ptr0 = input_ptr + (start_h - padding_h) * input_w; const int8_t *input_ptr1 = input_ptr0 + input_w; diff 
--git a/src/operators/math/element_wise.h b/src/operators/math/element_wise.h new file mode 100644 index 0000000000000000000000000000000000000000..6c75e53cb7c4d1a8f8c513e5435f6516bc9d720a --- /dev/null +++ b/src/operators/math/element_wise.h @@ -0,0 +1,359 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include "framework/tensor.h" +#include "operators/math/activation.h" +#ifdef __ARM_NEON +#include +#endif + +namespace paddle_mobile { +namespace operators { +namespace math { + +template +void AddChannelWise(const framework::Tensor *input, + const framework::Tensor *bias, framework::Tensor *output) { + const float *input_ptr = input->data(); + const float *bias_ptr = bias->data(); + float *output_ptr = output->mutable_data(); + // maybe check shape + int batch_size = input->dims()[0]; + int channels = input->dims()[1]; + int spatial_size = input->dims()[2] * input->dims()[3]; + + for (int batch = 0; batch < batch_size; ++batch) { + for (int channel = 0; channel < channels; ++channel) { + size_t offset = (batch * channels + channel) * spatial_size; + const float *x = input_ptr + offset; + float *y = output_ptr + offset; + float beta = bias_ptr[channel]; + int j = 0; +#if defined(__ARM_NEON__) || defined(__ARM_NEON) + float32x4_t __bias = vdupq_n_f32(beta); + for (; j < spatial_size - 15; j += 16, x += 16, y += 16) { + float32x4_t in0 = vld1q_f32(x); + float32x4_t in1 = vld1q_f32(x + 4); + float32x4_t in2 = vld1q_f32(x + 8); + float32x4_t in3 = vld1q_f32(x + 12); + in0 = vaddq_f32(__bias, in0); + in1 = vaddq_f32(__bias, in1); + in2 = vaddq_f32(__bias, in2); + in3 = vaddq_f32(__bias, in3); + in0 = math::vActiveq_f32(in0); + in1 = math::vActiveq_f32(in1); + in2 = math::vActiveq_f32(in2); + in3 = math::vActiveq_f32(in3); + vst1q_f32(y, in0); + vst1q_f32(y + 4, in1); + vst1q_f32(y + 8, in2); + vst1q_f32(y + 12, in3); + } + for (; j < spatial_size - 3; j += 4, x += 4, y += 4) { + float32x4_t in0 = vld1q_f32(x); + in0 = vaddq_f32(__bias, in0); + in0 = math::vActiveq_f32(in0); + vst1q_f32(y, in0); + } +#endif + for (; j < spatial_size; ++j, ++x, ++y) { + *y = math::Active((*x) + beta); + } + } + } +} + +template +void ScaleAddChannelWise(const framework::Tensor *input, + const framework::Tensor *scale, + const framework::Tensor *bias, + framework::Tensor *output) { + const float *input_ptr = input->data(); + const float *scale_ptr = scale->data(); + const float *bias_ptr = bias->data(); + float *output_ptr = output->mutable_data(); + // maybe check shape + int batch_size = input->dims()[0]; + int channels = input->dims()[1]; + int spatial_size = input->dims()[2] * input->dims()[3]; + + for (int batch = 0; batch < batch_size; ++batch) { + for (int channel = 0; channel < channels; ++channel) { + size_t offset = (batch * channels + channel) * spatial_size; + const float *x = input_ptr + offset; + float *y = output_ptr + offset; + float alpha = scale_ptr[channel]; + float beta = bias_ptr[channel]; + int j = 0; +#if 
defined(__ARM_NEON__) || defined(__ARM_NEON) + float32x4_t __scale = vdupq_n_f32(alpha); + float32x4_t __bias = vdupq_n_f32(beta); + for (; j < spatial_size - 15; j += 16, x += 16, y += 16) { + float32x4_t in0 = vld1q_f32(x); + float32x4_t in1 = vld1q_f32(x + 4); + float32x4_t in2 = vld1q_f32(x + 8); + float32x4_t in3 = vld1q_f32(x + 12); + in0 = vmlaq_f32(__bias, __scale, in0); + in1 = vmlaq_f32(__bias, __scale, in1); + in2 = vmlaq_f32(__bias, __scale, in2); + in3 = vmlaq_f32(__bias, __scale, in3); + in0 = math::vActiveq_f32(in0); + in1 = math::vActiveq_f32(in1); + in2 = math::vActiveq_f32(in2); + in3 = math::vActiveq_f32(in3); + vst1q_f32(y, in0); + vst1q_f32(y + 4, in1); + vst1q_f32(y + 8, in2); + vst1q_f32(y + 12, in3); + } + for (; j < spatial_size - 3; j += 4, x += 4, y += 4) { + float32x4_t in0 = vld1q_f32(x); + in0 = vmlaq_f32(__bias, __scale, in0); + in0 = math::vActiveq_f32(in0); + vst1q_f32(y, in0); + } +#endif + for (; j < spatial_size; ++j, ++x, ++y) { + *y = math::Active(alpha * (*x) + beta); + } + } + } +} + +template +void ScaleAddChannelWise(const framework::Tensor *input, + const framework::Tensor *scale, + const framework::Tensor *bias, + const framework::Tensor *tensorwise_bias, + framework::Tensor *output) { + const float *input_ptr = input->data(); + const float *scale_ptr = scale->data(); + const float *bias_ptr = bias->data(); + const float *tensorwise_bias_ptr = tensorwise_bias->data(); + float *output_ptr = output->mutable_data(); + // maybe check shape + int batch_size = input->dims()[0]; + int channels = input->dims()[1]; + int spatial_size = input->dims()[2] * input->dims()[3]; + + for (int batch = 0; batch < batch_size; ++batch) { + for (int channel = 0; channel < channels; ++channel) { + size_t offset = (batch * channels + channel) * spatial_size; + const float *x = input_ptr + offset; + const float *b = tensorwise_bias_ptr + offset; + float *y = output_ptr + offset; + float alpha = scale_ptr[channel]; + float beta = bias_ptr[channel]; + int j = 0; +#if defined(__ARM_NEON__) || defined(__ARM_NEON) + float32x4_t __scale = vdupq_n_f32(alpha); + float32x4_t __bias = vdupq_n_f32(beta); + for (; j < spatial_size - 15; j += 16, x += 16, b += 16, y += 16) { + float32x4_t in0 = vld1q_f32(x); + float32x4_t in1 = vld1q_f32(x + 4); + float32x4_t in2 = vld1q_f32(x + 8); + float32x4_t in3 = vld1q_f32(x + 12); + float32x4_t b0 = vld1q_f32(b); + float32x4_t b1 = vld1q_f32(b + 4); + float32x4_t b2 = vld1q_f32(b + 8); + float32x4_t b3 = vld1q_f32(b + 12); + in0 = vmlaq_f32(__bias, __scale, in0); + in1 = vmlaq_f32(__bias, __scale, in1); + in2 = vmlaq_f32(__bias, __scale, in2); + in3 = vmlaq_f32(__bias, __scale, in3); + in0 = vaddq_f32(in0, b0); + in1 = vaddq_f32(in1, b1); + in2 = vaddq_f32(in2, b2); + in3 = vaddq_f32(in3, b3); + in0 = math::vActiveq_f32(in0); + in1 = math::vActiveq_f32(in1); + in2 = math::vActiveq_f32(in2); + in3 = math::vActiveq_f32(in3); + vst1q_f32(y, in0); + vst1q_f32(y + 4, in1); + vst1q_f32(y + 8, in2); + vst1q_f32(y + 12, in3); + } + for (; j < spatial_size - 3; j += 4, x += 4, b += 4, y += 4) { + float32x4_t in0 = vld1q_f32(x); + float32x4_t b0 = vld1q_f32(b); + in0 = vmlaq_f32(__bias, __scale, in0); + in0 = vaddq_f32(in0, b0); + in0 = math::vActiveq_f32(in0); + vst1q_f32(y, in0); + } +#endif + for (; j < spatial_size; ++j, ++x, ++b, ++y) { + *y = math::Active(alpha * (*x) + beta + (*b)); + } + } + } +} + +template +void AddElememtWise(const framework::Tensor *input, + const framework::Tensor *bias, const int axis, + framework::Tensor *output) { + 
const auto &x_dims = input->dims(); + const auto &y_dims = bias->dims(); + const float *input_data = input->data(); + const float *bias_data = bias->data(); + float *output_data = output->mutable_data(); + + if (x_dims == y_dims) { + int remain_start = 0; +#if defined(__ARM_NEON__) || defined(__ARM_NEON) + remain_start = input->numel() & 0xfffffffc; + + #pragma omp parallel for + for (int i = 0; i < input->numel() - 15; i += 16) { + float32x4_t r0 = vld1q_f32(input_data); + float32x4_t r1 = vld1q_f32(input_data + 4); + float32x4_t r2 = vld1q_f32(input_data + 8); + float32x4_t r3 = vld1q_f32(input_data + 12); + float32x4_t b0 = vld1q_f32(bias_data); + float32x4_t b1 = vld1q_f32(bias_data + 4); + float32x4_t b2 = vld1q_f32(bias_data + 8); + float32x4_t b3 = vld1q_f32(bias_data + 12); + r0 = vaddq_f32(r0, b0); + r1 = vaddq_f32(r1, b1); + r2 = vaddq_f32(r2, b2); + r3 = vaddq_f32(r3, b3); + r0 = math::vActiveq_f32(r0); + r1 = math::vActiveq_f32(r1); + r2 = math::vActiveq_f32(r2); + r3 = math::vActiveq_f32(r3); + vst1q_f32(output_data, r0); + vst1q_f32(output_data + 4, r1); + vst1q_f32(output_data + 8, r2); + vst1q_f32(output_data + 12, r3); + input_data += 16; + bias_data += 16; + output_data += 16; + } + for (int i = input->numel() & 0xfffffff0; i < input->numel() - 3; i += 4) { + float32x4_t r0 = vld1q_f32(input_data); + float32x4_t b0 = vld1q_f32(bias_data); + r0 = vaddq_f32(r0, b0); + r0 = math::vActiveq_f32(r0); + vst1q_f32(output_data, r0); + input_data += 4; + bias_data += 4; + output_data += 4; + } +#endif // __ARM_NEON__ + for (int i = remain_start; i < input->numel(); ++i) { + output_data[i] = math::Active(input_data[i] + bias_data[i]); + } + } else { + // axis = -1 represent the last dimensions. + int dim = (axis == -1 ? x_dims.size() - y_dims.size() : axis); + size_t batch = 1; + size_t channels = 1; + size_t elementwise_num = 1; + for (int i = 0; i < dim; ++i) { + batch *= x_dims[i]; + } + for (int i = 0; i < y_dims.size(); ++i) { + channels *= y_dims[i]; + } + for (int i = y_dims.size() + dim; i < x_dims.size(); ++i) { + elementwise_num *= x_dims[i]; + } + + #pragma omp parallel for collapse(2) + for (int i = 0; i < batch; ++i) { + for (int j = 0; j < channels; ++j) { + size_t offset = (i * channels + j) * elementwise_num; + const float *input = input_data + offset; + const float bias = bias_data[j]; + float *output = output_data + offset; +#if defined(__ARM_NEON__) || defined(__ARM_NEON) + int loop = elementwise_num >> 0x4; + int remain = elementwise_num & 0xF; + float32x4_t rb = vdupq_n_f32(bias); + for (int k = 0; k < loop; ++k) { + float32x4_t r0 = vld1q_f32(input); + float32x4_t r1 = vld1q_f32(input + 4); + float32x4_t r2 = vld1q_f32(input + 8); + float32x4_t r3 = vld1q_f32(input + 12); + r0 = vaddq_f32(r0, rb); + r1 = vaddq_f32(r1, rb); + r2 = vaddq_f32(r2, rb); + r3 = vaddq_f32(r3, rb); + r0 = math::vActiveq_f32(r0); + r1 = math::vActiveq_f32(r1); + r2 = math::vActiveq_f32(r2); + r3 = math::vActiveq_f32(r3); + vst1q_f32(output, r0); + vst1q_f32(output + 4, r1); + vst1q_f32(output + 8, r2); + vst1q_f32(output + 12, r3); + input += 16; + output += 16; + } + if (remain >= 8) { + float32x4_t r0 = vld1q_f32(input); + float32x4_t r1 = vld1q_f32(input + 4); + r0 = vaddq_f32(r0, rb); + r1 = vaddq_f32(r1, rb); + r0 = math::vActiveq_f32(r0); + r1 = math::vActiveq_f32(r1); + vst1q_f32(output, r0); + vst1q_f32(output + 4, r1); + input += 8; + output += 8; + remain -= 8; + } + if (remain >= 4) { + float32x4_t r0 = vld1q_f32(input); + r0 = vaddq_f32(r0, rb); + r0 = 
math::vActiveq_f32(r0); + vst1q_f32(output, r0); + input += 4; + output += 4; + remain -= 4; + } + if (remain > 0) { + float32x4_t r0 = vld1q_f32(input); + r0 = vaddq_f32(r0, rb); + r0 = math::vActiveq_f32(r0); + switch (remain) { + case 1: + vst1q_lane_f32(output, r0, 0); + break; + case 2: + vst1_f32(output, vget_low_f32(r0)); + break; + case 3: + vst1_f32(output, vget_low_f32(r0)); + vst1q_lane_f32(output, r0, 2); + break; + } + } +#else + for (int k = 0; k < elementwise_num; ++k) { + output[k] = math::Active(input[k] + bias); + } +#endif // __ARM_NEON__ + } + } + } +} + +} // namespace math +} // namespace operators +} // namespace paddle_mobile diff --git a/src/operators/math/gemm/gemm_kernel.h b/src/operators/math/gemm/gemm_kernel.h index eea54114786ff14f21318fba50c83303f08a8dab..2d2985a39c822b8bec7a090b04c9472cbd6b87f4 100644 --- a/src/operators/math/gemm/gemm_kernel.h +++ b/src/operators/math/gemm/gemm_kernel.h @@ -388,7 +388,7 @@ void sgemv_notrans_mx1(const int M, const int N, const float alpha, vst1q_f32(output, _sum0); } // remain m - for (int m = (M & 0xfffc); m < M; ++m) { + for (int m = (M & 0xfffffffc); m < M; ++m) { const float *in0 = A + m * lda; float *output = C + m; float32x4_t _sum0 = vdupq_n_f32(0.f); @@ -426,7 +426,7 @@ void sgemv_trans_mx1(const int M, const int N, const float alpha, for (int m = 0; m < M - 3; m += 4) { vst1q_f32(C + m, vzero); } - for (int m = (M & 0xfffc); m < M; ++m) { + for (int m = (M & 0xfffffffc); m < M; ++m) { C[m] = 0.f; } } else { @@ -436,7 +436,7 @@ void sgemv_trans_mx1(const int M, const int N, const float alpha, _vc = vmulq_f32(_vc, vbeta); vst1q_f32(C + m, _vc); } - for (int m = (M & 0xfffc); m < M; ++m) { + for (int m = (M & 0xfffffffc); m < M; ++m) { C[m] *= beta; } } @@ -491,7 +491,7 @@ void sgemv_trans_mx1(const int M, const int N, const float alpha, } } // remain n - for (int n = (N & 0xfffc); n < N; ++n) { + for (int n = (N & 0xfffffffc); n < N; ++n) { const float *in0 = A + n * lda; float32x4_t _b = vld1q_dup_f32(B + n); float32x4_t _sum0; diff --git a/src/operators/math/gemm/pack_kernel.h b/src/operators/math/gemm/pack_kernel.h index b1f6a9d35ec5630a1bc0ae9fc997dcb05419f0aa..d3b135961056192583afa9ff59516094437720c9 100644 --- a/src/operators/math/gemm/pack_kernel.h +++ b/src/operators/math/gemm/pack_kernel.h @@ -325,7 +325,7 @@ void pack_rhs_16c(int k, int n, const float *B, int ldb, float *output, : "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7"); } for (; j < n - 7; j += 8) { - float *out_ptr0 = output + (j & 0xFFF0) * k + 16 * i + (j & 0xF); + float *out_ptr0 = output + (j & 0xFFFFFFF0) * k + 16 * i + (j & 0xF); int step = 64; asm volatile( "ld1 {v0.4s, v1.4s}, [%[b0]], #32 \n" @@ -343,7 +343,7 @@ void pack_rhs_16c(int k, int n, const float *B, int ldb, float *output, : "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7"); } if (j < n) { - float *out_ptr0 = output + (j & 0xFFF0) * k + 16 * i + (j & 0xF); + float *out_ptr0 = output + (j & 0xFFFFFFF0) * k + 16 * i + (j & 0xF); int step = 64; asm volatile( "ld1 {v0.4s, v1.4s}, [%[b0]] \n" @@ -372,7 +372,7 @@ void pack_rhs_16c(int k, int n, const float *B, int ldb, float *output, } if (j & 0xf) { - float *out_ptr0 = output + (j & 0xFFF0) * k + 16 * i + (j & 0xF); + float *out_ptr0 = output + (j & 0xFFFFFFF0) * k + 16 * i + (j & 0xF); vst1q_f32(out_ptr0, vzero); vst1q_f32(out_ptr0 + 4, vzero); out_ptr0 += 16; @@ -387,7 +387,7 @@ void pack_rhs_16c(int k, int n, const float *B, int ldb, float *output, } } // remain k - for (int i = (k & 0xFFFC); i < k; ++i) { + for 
(int i = (k & 0xFFFFFFFC); i < k; ++i) { const float *b0 = B + i * ldb; int j = 0; asm volatile("prfm pldl1keep, [%[b0]] \n" @@ -404,7 +404,7 @@ void pack_rhs_16c(int k, int n, const float *B, int ldb, float *output, : "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7"); } for (; j < n - 7; j += 8) { - float *out_ptr0 = output + (j & 0xFFF0) * k + 16 * i + (j & 0xF); + float *out_ptr0 = output + (j & 0xFFFFFFF0) * k + 16 * i + (j & 0xF); int step = 64; asm volatile( "ld1 {v0.4s, v1.4s}, [%[b0]], #32 \n" @@ -414,7 +414,7 @@ void pack_rhs_16c(int k, int n, const float *B, int ldb, float *output, : "memory", "v0", "v1"); } if (j < n) { - float *out_ptr0 = output + (j & 0xFFF0) * k + 16 * i + (j & 0xF); + float *out_ptr0 = output + (j & 0xFFFFFFF0) * k + 16 * i + (j & 0xF); asm volatile( "ld1 {v0.4s, v1.4s}, [%[b0]] \n" "and v0.16b, v0.16b, %[vmask1].16b \n" @@ -426,7 +426,7 @@ void pack_rhs_16c(int k, int n, const float *B, int ldb, float *output, j += 8; } if (j & 0xf) { - float *out_ptr0 = output + (j & 0xFFF0) * k + 16 * i + (j & 0xF); + float *out_ptr0 = output + (j & 0xFFFFFFF0) * k + 16 * i + (j & 0xF); vst1q_f32(out_ptr0, vzero); vst1q_f32(out_ptr0 + 4, vzero); } @@ -517,7 +517,7 @@ void pack_rhs_8c(int k, int n, const float *B, int ldb, float *output, } } // remain k - for (int i = (k & 0xFFFC); i < k; ++i) { + for (int i = (k & 0xFFFFFFFC); i < k; ++i) { const float *b0 = B + i * ldb; int j = 0; for (; j < n - 15; j += 16) { diff --git a/src/operators/math/pooling2x2.cpp b/src/operators/math/pooling2x2.cpp index 675a6392ed21dce4f9e324bc2dacd8609e2de999..1d8845ce69743b32f5901e0b6fa8c92b9cc05d0b 100644 --- a/src/operators/math/pooling2x2.cpp +++ b/src/operators/math/pooling2x2.cpp @@ -424,7 +424,7 @@ struct Pooling2x2 { } } // remain height - int start_h = valid_h_start + (valid_h & 0xFFFC); + int start_h = valid_h_start + (valid_h & 0xFFFFFFFC); for (int h = start_h; h < valid_h_end; ++h) { const float *input_ptr0 = input_ptr + (h - padding_h) * input_w; const float *input_ptr1 = input_ptr0 + input_w; @@ -692,7 +692,7 @@ struct Pooling2x2 { } } // remain height - int start_h = valid_h_start + (valid_h & 0xfffe); + int start_h = valid_h_start + (valid_h & 0xfffffffe); for (int h = start_h; h < valid_h_end; ++h) { const float *input_ptr0 = input_ptr + (2 * h - padding_h) * input_w; const float *input_ptr1 = input_ptr0 + input_w; diff --git a/src/operators/math/pooling3x3.cpp b/src/operators/math/pooling3x3.cpp index 35029c6425c07b4bed03d667a014bc3e7d960df6..e67404469334aec33d66fe1c0bc51aadbb0ffe93 100644 --- a/src/operators/math/pooling3x3.cpp +++ b/src/operators/math/pooling3x3.cpp @@ -560,7 +560,7 @@ struct Pooling3x3 { } } // remain height - int start_h = valid_h_start + (valid_h & 0xFFFC); + int start_h = valid_h_start + (valid_h & 0xFFFFFFFC); for (int h = start_h; h < valid_h_end; ++h) { const float *input_ptr0 = input_ptr + (h - padding_h) * input_w; const float *input_ptr1 = input_ptr0 + input_w; diff --git a/src/operators/math/winograd/winograd_transform_f6k3.cpp b/src/operators/math/winograd/winograd_transform_f6k3.cpp index 234de599ad2cab72c471176330bc3c0aacd02d5f..4ba0ee4cb60b569cc7208c10fe1983e16dbfffbb 100644 --- a/src/operators/math/winograd/winograd_transform_f6k3.cpp +++ b/src/operators/math/winograd/winograd_transform_f6k3.cpp @@ -55,7 +55,7 @@ void winograd_transform_weight<8, 3>(const framework::Tensor &weight, #if __aarch64__ int remain_start = 0; #else - int remain_start = out_channel & 0xFFFC; + int remain_start = out_channel & 0xFFFFFFFC; #pragma omp 
parallel for for (int oc = 0; oc < out_channel - 3; oc += 4) { @@ -268,7 +268,7 @@ void winograd_transform_weight<8, 3>(const framework::Tensor &weight, float gw[3][8]; // gw[3][8] const float *inptr0 = inptr + oc * in_channel * 9; // // (oc / 4) * 64 * in_channel * 4 + oc % 4 - int offset = ((oc & 0xFFFC) << 6) * in_channel + (oc & 0x3); + int offset = ((oc & 0xFFFFFFFC) << 6) * in_channel + (oc & 0x3); int steps = (in_channel << 2); // in_channel * 4 float *outptr = trans_outptr + offset; for (int ic = 0; ic < in_channel; ++ic) { diff --git a/src/pass/memory_optimize.cpp b/src/pass/memory_optimize.cpp new file mode 100644 index 0000000000000000000000000000000000000000..cc754491fae9ba9a1604a4941a67015314bb2b13 --- /dev/null +++ b/src/pass/memory_optimize.cpp @@ -0,0 +1,141 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "pass/memory_optimize.h" +#include "framework/lod_tensor.h" + +namespace paddle_mobile { +namespace pass { + +void MemoryOptPass::AppendBlockVars(const framework::BlockDesc *block) { + // block_vars_.clear(); + for (const auto var : block->Vars()) { + block_vars_[var->Name()] = var.get(); + } +} + +bool MemoryOptPass::IsPersistable(const std::string name) { + const auto it = block_vars_.find(name); + if (it != block_vars_.end()) { + return it->second->Persistable(); + } + return false; +} + +VarNode *MemoryOptPass::CreateNode(const std::string name) { + auto it = created_nodes_.find(name); + if (it != created_nodes_.end()) { + ++(it->second->count); + return it->second; + } + VarNode *var = new VarNode; + var->name = name; + var->count = 1; + var->visited = false; + created_nodes_[name] = var; + return var; +} + +void MemoryOptPass::operator()(const framework::ProgramDesc *program, + framework::Scope *scope) { + const auto &blocks = program->Blocks(); + for (const auto &block : blocks) { + // access all variables in each block + AppendBlockVars(block.get()); + + reused_nodes_.clear(); + // collect all not persistable variables, and accumulate + // it's reference count + std::stack empty_var_nodes; + analysis_nodes_.swap(empty_var_nodes); + + for (const auto &op : block->Ops()) { + DLOG << "op_desc->Type(): " << op->Type(); + for (const auto &outputs : op->GetOutputs()) { + for (const auto &output : outputs.second) { + if (!IsPersistable(output)) { + DLOG << "output: " << output; + VarNode *node = CreateNode(output); + analysis_nodes_.push(node); + } + } + } + for (const auto &inputs : op->GetInputs()) { + for (const auto &input : inputs.second) { + if (!IsPersistable(input)) { + DLOG << "input: " << input; + VarNode *node = CreateNode(input); + analysis_nodes_.push(node); + } + } + } + for (const auto &outputs : op->GetOutputs()) { + for (const auto &output : outputs.second) { + if (!IsPersistable(output)) { + DLOG << "output: " << output; + VarNode *node = CreateNode(output); + analysis_nodes_.push(node); + } + } + } + } + + // apply optimize + while (!analysis_nodes_.empty()) { + auto *node = 
analysis_nodes_.top(); + analysis_nodes_.pop(); + // only not visited node can reuse memory between other nodes + // with 0 count which indicate they will not be used any more + if (!node->visited) { + bool reused = false; + // find out a possable reuse list + for (auto &list : reused_nodes_) { + if (list.back()->count == 0) { + list.push_back(node); + reused = true; + break; + } + } + // create new list if can't find a reused list + if (!reused) { + std::vector list; + list.push_back(node); + reused_nodes_.push_back(std::move(list)); + } + } + node->visited = true; + node->count -= 1; + } + + // shared data within all variables in the same reused list + for (const auto &list : reused_nodes_) { + DLOG << "\n"; + DLOG << "share memory within these variables"; + std::string name = list[0]->name; + auto *reused_var = scope->Var(name); + auto *reuse_tensor = + reused_var->template GetMutable(); + reuse_tensor->mutable_data(); + for (const auto &node : list) { + DLOG << node->name; + auto *var = scope->Var(node->name); + auto *tensor = var->template GetMutable(); + tensor->ShareHolderWith(*reuse_tensor); + } + } + } +} + +} // namespace pass +} // namespace paddle_mobile diff --git a/src/pass/memory_optimize.h b/src/pass/memory_optimize.h new file mode 100644 index 0000000000000000000000000000000000000000..116100af0bae137d74bbc9aaa24a8f8d61d9dfdf --- /dev/null +++ b/src/pass/memory_optimize.h @@ -0,0 +1,66 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include +#include +#include +#include +#include "framework/program/program.h" + +namespace paddle_mobile { +namespace pass { + +typedef struct { + std::string name; // variable name + int count; // reference count + bool visited; +} VarNode; + +class PassBase { + public: + PassBase() {} + virtual ~PassBase() {} +}; + +// MemoryOptPass will analyze the program, and reuse memory between +// variables as much as possible +class MemoryOptPass : public PassBase { + public: + MemoryOptPass() {} + virtual ~MemoryOptPass() { + for (auto &it : created_nodes_) { + delete it.second; + } + } + + void operator()(const framework::ProgramDesc *program, + framework::Scope *scope); + + void AppendBlockVars(const framework::BlockDesc *block); + + bool IsPersistable(const std::string name); + + VarNode *CreateNode(const std::string name); + + private: + std::stack analysis_nodes_; + std::vector> reused_nodes_; + std::unordered_map created_nodes_; + std::unordered_map block_vars_; +}; + +} // namespace pass +} // namespace paddle_mobile
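Notes on the main pieces of this patch follow, each with a small standalone C++ sketch. The sketches use simplified stand-in types and illustrative names, not the framework's actual classes, and are meant only to show the technique under the stated assumptions.

ProgramDesc copy (src/framework/program/program_desc.h): the copy constructor is hoisted above the accessors and Blocks()/Description() become const; the constructor deep-copies every block so a copied program does not alias the source's BlockDesc objects. A minimal sketch of that deep-copy pattern, with a stand-in Block type in place of BlockDesc:

#include <memory>
#include <string>
#include <vector>

struct Block {  // stand-in for framework::BlockDesc
  std::string name;
};

class Program {  // stand-in for framework::ProgramDesc
 public:
  Program() = default;

  // Deep copy: every shared_ptr gets its own Block, so mutating the copy
  // cannot affect the source program's blocks.
  Program(const Program &other) {
    blocks_.reserve(other.blocks_.size());
    for (const auto &block : other.blocks_) {
      blocks_.push_back(std::make_shared<Block>(*block));
    }
  }

  void AddBlock(std::string name) {
    blocks_.push_back(std::make_shared<Block>(Block{std::move(name)}));
  }

  // const accessor, callable on a const Program (cf. Blocks() const).
  const std::vector<std::shared_ptr<Block>> &Blocks() const { return blocks_; }

 private:
  std::vector<std::shared_ptr<Block>> blocks_;
};

int main() {
  Program a;
  a.AddBlock("block0");
  const Program b(a);  // deep copy; b.Blocks()[0] is a distinct object
  return a.Blocks()[0].get() == b.Blocks()[0].get() ? 1 : 0;
}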
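Tensor holder reuse (src/framework/tensor.h): mutable_data() now keeps the existing holder and calls the new resize() on it, which reallocates only when the request exceeds the recorded capacity, and ShareHolderWith() lets two tensors point at the same holder, which is what the memory pass relies on. A minimal sketch of those two pieces, assuming a simple Holder and SimpleTensor in place of the real Placeholder/Tensor types (no type tracking or offsets):

#include <cassert>
#include <cstddef>
#include <memory>

// Stand-in for the tensor's placeholder: remembers both the requested size
// and the allocated capacity, and reallocates only when it must grow.
class Holder {
 public:
  explicit Holder(size_t size)
      : ptr_(new char[size]), size_(size), capacity_(size) {}

  void resize(size_t size) {
    if (size > capacity_) {
      capacity_ = size;
      ptr_.reset(new char[capacity_]);  // old contents are not preserved
    }
    size_ = size;
  }

  void *data() { return ptr_.get(); }
  size_t size() const { return size_; }

 private:
  std::unique_ptr<char[]> ptr_;
  size_t size_;
  size_t capacity_;
};

class SimpleTensor {
 public:
  void *mutable_data(size_t bytes) {
    if (holder_ == nullptr || holder_->size() < bytes) {
      if (holder_ == nullptr) {
        holder_ = std::make_shared<Holder>(bytes);
      } else {
        holder_->resize(bytes);  // reuse capacity instead of a fresh holder
      }
    }
    return holder_->data();
  }

  // Share the underlying block with another tensor (cf. ShareHolderWith).
  SimpleTensor &ShareHolderWith(const SimpleTensor &src) {
    if (holder_.get() != src.holder_.get()) {
      holder_ = src.holder_;
    }
    return *this;
  }

 private:
  std::shared_ptr<Holder> holder_;
};

int main() {
  SimpleTensor a, b;
  a.mutable_data(64);
  b.ShareHolderWith(a);
  // A smaller request keeps the shared block, so both tensors still alias it.
  assert(a.mutable_data(32) == b.mutable_data(32));
  return 0;
}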
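Mask widening (compare_kernel, sequence_pool, depthwise/pooling/gemm/winograd kernels): constants such as 0xfffc and 0xfff8 are widened to 0xfffffffc / 0xfffffff8. The short constants round down correctly only while the value fits in 16 bits; for larger dimensions the high bits were silently masked away, so the scalar remainder loop started at the wrong index. A tiny demonstration, assuming a 32-bit int:

#include <cstdio>

int main() {
  int width = 70000;  // larger than 0xffff, where the 16-bit masks break
  int old_round = width & 0xfffc;      // keeps only bits 2..15 -> 4464
  int new_round = width & 0xfffffffc;  // full-width mask        -> 70000
  int portable = width & ~3;           // equivalent, width-agnostic spelling
  std::printf("%d %d %d\n", old_round, new_round, portable);
  return 0;
}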
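Batch-norm folding (src/operators/kernel/arm/convolution/conv_bn_add_relu_kernel.cpp): Init() no longer allocates NewScale/NewBias; it rewrites the existing scale/bias tensors in place so that bn(x) = scale * (x - mean) / sqrt(var + eps) + bias collapses into a single affine scale' * x + bias'. A self-contained sketch of that folding on plain float arrays (FoldBatchNorm is an illustrative name, not a framework function):

#include <cmath>
#include <cstdio>

// Fold batch-norm statistics into per-channel scale/bias in place:
//   scale' = scale / sqrt(var + eps)
//   bias'  = bias  - scale * mean / sqrt(var + eps)
void FoldBatchNorm(float *scale, float *bias, const float *mean,
                   const float *variance, float epsilon, int channels) {
  for (int c = 0; c < channels; ++c) {
    float inv_std = 1.f / std::sqrt(variance[c] + epsilon);
    bias[c] -= inv_std * scale[c] * mean[c];  // uses the original scale
    scale[c] *= inv_std;
  }
}

int main() {
  float scale[2] = {1.f, 2.f};
  float bias[2] = {0.5f, -1.f};
  const float mean[2] = {0.1f, 0.2f};
  const float variance[2] = {1.f, 4.f};
  FoldBatchNorm(scale, bias, mean, variance, 1e-5f, 2);

  // After folding, applying batch norm is a single multiply-add per channel.
  float x[2] = {1.f, 1.f};
  for (int c = 0; c < 2; ++c) {
    std::printf("channel %d: y = %f\n", c, scale[c] * x[c] + bias[c]);
  }
  return 0;
}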
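Fused epilogue (same kernel, Compute()): when the shortcut tensor (param.Bias()) has the same shape as the output, the per-channel scale/bias and the element-wise shortcut are applied in one pass via the new four-tensor ScaleAddChannelWise overload; otherwise the channel-wise pass runs first and AddElememtWise handles the broadcast. A scalar sketch of the fused path, with ReLU written out instead of the Active<RELU> template and plain arrays in place of tensors:

#include <algorithm>
#include <cstdio>
#include <vector>

// Scalar form of the fused epilogue: per-channel affine (folded batch norm)
// plus an element-wise shortcut tensor, followed by ReLU.
void ScaleAddChannelWiseRelu(const float *in, const float *scale,
                             const float *bias, const float *shortcut,
                             float *out, int batch, int channels,
                             int spatial_size) {
  for (int b = 0; b < batch; ++b) {
    for (int c = 0; c < channels; ++c) {
      int offset = (b * channels + c) * spatial_size;
      for (int i = 0; i < spatial_size; ++i) {
        float v = scale[c] * in[offset + i] + bias[c] + shortcut[offset + i];
        out[offset + i] = std::max(v, 0.f);
      }
    }
  }
}

int main() {
  const int batch = 1, channels = 2, spatial = 4;
  std::vector<float> in(8, 1.f), shortcut(8, 0.5f), out(8);
  float scale[2] = {1.f, -2.f};
  float bias[2] = {0.f, 0.25f};
  ScaleAddChannelWiseRelu(in.data(), scale, bias, shortcut.data(), out.data(),
                          batch, channels, spatial);
  for (float v : out) std::printf("%.2f ", v);  // 1.50 x4, then 0.00 x4
  std::printf("\n");
  return 0;
}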
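Element-wise add with broadcast (src/operators/math/element_wise.h, used by elementwise_add_arm_func.h): AddElememtWise either adds same-shaped tensors directly or broadcasts the smaller operand along axis, flattening the work into batch x channels x elementwise_num loops that the NEON code then vectorizes in blocks of 16/8/4. A scalar sketch of that flattening (plain arrays, no NEON; AddBroadcast is an illustrative name):

#include <cstdio>
#include <vector>

// Scalar version of the broadcast decomposition: y's shape matches a
// contiguous slice of x's shape starting at `axis`; axis == -1 means
// "align y with the trailing dimensions of x".
void AddBroadcast(const float *x, const std::vector<int> &x_dims,
                  const float *y, const std::vector<int> &y_dims, int axis,
                  float *out) {
  if (axis == -1) axis = static_cast<int>(x_dims.size() - y_dims.size());
  size_t batch = 1, channels = 1, elementwise_num = 1;
  for (int i = 0; i < axis; ++i) batch *= x_dims[i];
  for (size_t i = 0; i < y_dims.size(); ++i) channels *= y_dims[i];
  for (size_t i = y_dims.size() + axis; i < x_dims.size(); ++i) {
    elementwise_num *= x_dims[i];
  }

  for (size_t i = 0; i < batch; ++i) {
    for (size_t j = 0; j < channels; ++j) {
      size_t offset = (i * channels + j) * elementwise_num;
      for (size_t k = 0; k < elementwise_num; ++k) {
        out[offset + k] = x[offset + k] + y[j];
      }
    }
  }
}

int main() {
  // x: [2, 3, 2], y: [3] broadcast along axis 1 -> batch=2, channels=3, num=2.
  std::vector<int> x_dims = {2, 3, 2}, y_dims = {3};
  std::vector<float> x(12, 1.f), y = {10.f, 20.f, 30.f}, out(12);
  AddBroadcast(x.data(), x_dims, y.data(), y_dims, /*axis=*/1, out.data());
  for (float v : out) std::printf("%.0f ", v);
  std::printf("\n");  // 11 11 21 21 31 31 11 11 21 21 31 31
  return 0;
}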
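Memory-reuse pass (src/pass/memory_optimize.cpp, run from the Executor constructor): the pass reference-counts every non-persistable variable in op order, then pops them off a stack and greedily appends each first-seen node to a reuse list whose tail has no remaining uses; every list then shares one LoDTensor holder via ShareHolderWith. The sketch below mirrors that greedy grouping on simplified Op/VarNode structs (not the framework types) so the bookkeeping is easier to follow:

#include <cstdio>
#include <map>
#include <stack>
#include <string>
#include <vector>

struct Op {  // stand-in for an OpDesc: just input/output variable names
  std::vector<std::string> inputs;
  std::vector<std::string> outputs;
};

struct VarNode {
  std::string name;
  int count;     // remaining references
  bool visited;
};

int main() {
  // Two chained ops: feed "a" -> op0 -> "b" -> op1 -> "c".
  std::vector<Op> ops = {{{"a"}, {"b"}}, {{"b"}, {"c"}}};

  std::map<std::string, VarNode *> nodes;
  std::stack<VarNode *> analysis;
  auto touch = [&](const std::string &name) {
    VarNode *&node = nodes[name];
    if (node == nullptr) node = new VarNode{name, 0, false};
    ++node->count;
    analysis.push(node);
  };

  // Mirror the pass: push outputs, inputs, then outputs again for each op,
  // so an op's outputs end up above its inputs on the stack.
  for (const auto &op : ops) {
    for (const auto &o : op.outputs) touch(o);
    for (const auto &i : op.inputs) touch(i);
    for (const auto &o : op.outputs) touch(o);
  }

  // Greedy reuse: a not-yet-visited node joins the first list whose tail
  // has a zero count, i.e. a variable that will never be read again.
  std::vector<std::vector<VarNode *>> reuse_lists;
  while (!analysis.empty()) {
    VarNode *node = analysis.top();
    analysis.pop();
    if (!node->visited) {
      bool reused = false;
      for (auto &list : reuse_lists) {
        if (list.back()->count == 0) {
          list.push_back(node);
          reused = true;
          break;
        }
      }
      if (!reused) reuse_lists.push_back({node});
    }
    node->visited = true;
    node->count -= 1;
  }

  // Prints "shared buffer: c a" and "shared buffer: b" for this toy graph.
  for (const auto &list : reuse_lists) {
    std::printf("shared buffer:");
    for (const auto *n : list) std::printf(" %s", n->name.c_str());
    std::printf("\n");
  }
  for (auto &kv : nodes) delete kv.second;
  return 0;
}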