Unverified commit 54825257 authored by zhangyang0701, committed by GitHub

Merge pull request #1509 from hjchen2/backup

Add a memory optimize pass that makes all variables reuse memory as much as possible
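This pass is implemented in pass/memory_optimize.{h,cpp} (added at the end of this diff) and wired into the executor below. For orientation, here is a minimal, self-contained sketch of the reuse bookkeeping it performs, using reference counts plus a pool of reuse lists; the names here are illustrative only, not the pass's actual API.

// Illustrative sketch (not part of the diff): count the remaining references
// to each non-persistable variable, then, while draining the analysis order,
// append a variable to a reuse list whose last member has no references left,
// so every list can later share a single allocation.
#include <string>
#include <vector>

struct VarNode {
  std::string name;
  int count = 0;       // remaining references
  bool visited = false;
};

// Append `node` to an existing reuse list whose last node is no longer
// needed, or start a new list if none qualifies.
void AssignToReuseList(VarNode *node,
                       std::vector<std::vector<VarNode *>> *reused_lists) {
  for (auto &list : *reused_lists) {
    if (list.back()->count == 0) {
      list.push_back(node);
      return;
    }
  }
  reused_lists->push_back({node});
}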
......@@ -28,7 +28,7 @@ limitations under the License. */
#include "framework/scope.h"
#include "framework/tensor.h"
#include "memory/t_malloc.h"
#include "pass/memory_optimize.h"
#ifdef PADDLE_MOBILE_CL
#include "framework/cl/cl_image.h"
#endif
......@@ -62,6 +62,7 @@ Executor<Device, T>::Executor(const Program<Device> &program,
use_optimize_ ? program_.optimizeProgram : program_.originProgram;
PADDLE_MOBILE_ENFORCE(program_desc_ != nullptr,
"program_desc_ should not be nullptr");
pass::MemoryOptPass()(program_desc_.get(), program_.scope.get());
// resize feed and fetch list
// should init feed and fetch variables before infer shape
InitFeedFetchList();
......@@ -210,6 +211,7 @@ void Executor<Device, T>::InitMemory() {
var->template GetMutable<framework::LoDTensorArray>();
continue;
}
DLOG << "init persistable var: " << var_desc->Name();
char *origin_data =
ReadFileToBuff(program_.model_path + "/" + var_desc->Name());
char *data = origin_data;
......@@ -322,7 +324,6 @@ bool Executor<Device, T>::varInputMemory(
if (type == VARTYPE_TYPE_LOD_TENSOR) {
auto data_type = var_desc->Tensor_desc().DataType();
framework::LoDTensor *tensor = var->template GetMutable<LoDTensor>();
tensor->mutable_data(TypeId(data_type));
} else if (type == VARTYPE_TYPE_STEP_SCOPES) {
std::vector<framework::Scope *> *step_scopes =
var->template GetMutable<std::vector<framework::Scope *>>();
......@@ -458,6 +459,7 @@ PMStatus Executor<Device, T>::Predict() {
clock_gettime(CLOCK_MONOTONIC, &ts);
profile[op_index].runBegin = (uint64_t)ts.tv_sec * 1e9 + ts.tv_nsec;
#endif
DLOG << "run op: " << op_handler->Type();
if (lod_mode_) {
op_handler->InferShape();
}
......
......@@ -46,7 +46,7 @@ ProgramDesc::ProgramDesc(PaddleMobile__Framework__Proto__ProgramDesc *desc) {
}
}
void ProgramDesc::Description(std::string header) {
void ProgramDesc::Description(std::string header) const {
#ifdef PADDLE_MOBILE_DEBUG
if (header.size()) {
LOG(kLOG_INFO) << header;
......
......@@ -30,6 +30,14 @@ class ProgramDesc {
friend class ProgramOptimize;
explicit ProgramDesc(PaddleMobile__Framework__Proto__ProgramDesc *desc);
ProgramDesc(const ProgramDesc &program_desc) {
for (auto &block : program_desc.blocks_) {
std::shared_ptr<BlockDesc> copy_block =
std::make_shared<BlockDesc>(*block);
blocks_.push_back(copy_block);
}
}
std::shared_ptr<BlockDesc> Block(size_t idx);
BlockDesc *MutableBlock(size_t idx) {
......@@ -40,16 +48,11 @@ class ProgramDesc {
}
}
const std::vector<std::shared_ptr<BlockDesc>> &Blocks() { return blocks_; }
ProgramDesc(const ProgramDesc &program_desc) {
for (auto &block : program_desc.blocks_) {
std::shared_ptr<BlockDesc> copy_block =
std::make_shared<BlockDesc>(*block);
blocks_.push_back(copy_block);
}
const std::vector<std::shared_ptr<BlockDesc>> &Blocks() const {
return blocks_;
}
void Description(std::string header = "");
void Description(std::string header = "") const;
private:
std::vector<std::shared_ptr<BlockDesc>> blocks_;
......
......@@ -74,6 +74,15 @@ class Tensor : public TensorBase {
return *this;
}
/*! Make this tensor share the same underlying memory block (holder) as src. */
inline Tensor &ShareHolderWith(const Tensor &src) {
src.check_memory_size();
if (holder_.get() != src.holder_.get()) {
holder_ = src.holder_;
}
return *this;
}
inline void *mutable_data(std::type_index type) {
if (holder_ != nullptr) {
holder_->set_type(type);
......@@ -81,7 +90,11 @@ class Tensor : public TensorBase {
PADDLE_MOBILE_ENFORCE(numel() >= 0, "the Tensor's numel must >=0.")
int64_t size = numel() * SizeOfType(type);
if (holder_ == nullptr || holder_->size() < size + offset_) {
holder_.reset(new PlaceholderImpl(size, type));
if (holder_ == nullptr) {
holder_.reset(new PlaceholderImpl(size, type));
} else {
holder_->resize(size);
}
offset_ = 0;
}
return reinterpret_cast<void *>(
......@@ -180,6 +193,7 @@ class Tensor : public TensorBase {
: ptr_(static_cast<uint8_t *>(memory::Alloc(size)),
memory::PODDeleter<uint8_t>()),
size_(size),
capatity_(size),
type_(type) {
PADDLE_MOBILE_ENFORCE(ptr_ != nullptr,
"Insufficient memory to allocation");
......@@ -193,11 +207,21 @@ class Tensor : public TensorBase {
virtual void set_type(std::type_index type) { type_ = type; }
virtual void resize(size_t size) {
if (size > capatity_) {
capatity_ = size;
ptr_.reset(static_cast<uint8_t *>(memory::Alloc(capatity_)));
}
size_ = size;
}
std::unique_ptr<uint8_t, memory::PODDeleter<uint8_t>> ptr_;
/*! the size of memory block. */
size_t size_;
size_t capatity_;
/* the current type of memory */
std::type_index type_;
};
......
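The Tensor changes above add ShareHolderWith (so two tensors can share one holder) and a capacity-tracked resize, so mutable_data only reallocates when the requested size exceeds the holder's recorded capacity. A standalone analogue of that resize policy, assuming a plain heap buffer rather than the real PlaceholderImpl:

// Standalone analogue (assumption: simple new[]-backed storage) of the
// capacity-tracked resize added above: growing past capacity reallocates
// and discards old contents, shrinking only updates the logical size.
#include <cstddef>
#include <cstdint>
#include <memory>

class GrowOnlyBuffer {
 public:
  void Resize(size_t size) {
    if (size > capacity_) {
      capacity_ = size;
      ptr_.reset(new uint8_t[capacity_]);  // old contents are discarded,
    }                                      // matching PlaceholderImpl::resize
    size_ = size;
  }
  size_t size() const { return size_; }

 private:
  std::unique_ptr<uint8_t[]> ptr_;
  size_t size_ = 0;
  size_t capacity_ = 0;
};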
......@@ -117,6 +117,8 @@ class TensorBase {
virtual std::type_index type() const = 0;
virtual void set_type(std::type_index type) = 0;
virtual void resize(size_t size) = 0;
};
/**
......
......@@ -79,7 +79,7 @@ struct CompareCompute<float, Comp> {
if (elementwise_num == 1) {
int remain_start = 0;
#if defined(__ARM_NEON__) || defined(__ARM_NEON)
remain_start = channels & 0xfff8;
remain_start = channels & 0xfffffff8;
uint8x8_t __mask = vdup_n_u8(0x1);
for (int i = 0; i < batch; ++i) {
for (int j = 0; j < channels - 7; j += 8) {
......@@ -112,7 +112,7 @@ struct CompareCompute<float, Comp> {
int y_offset = j * elementwise_num;
int remain_start = 0;
#if defined(__ARM_NEON__) || defined(__ARM_NEON)
remain_start = elementwise_num & 0xfff8;
remain_start = elementwise_num & 0xfffffff8;
uint8x8_t __mask = vdup_n_u8(0x1);
for (int k = 0; k < elementwise_num - 7; k += 8) {
float32x4_t __x0 = vld1q_f32(x + x_offset);
......
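The mask widening above (0xfff8 to 0xfffffff8), repeated for the 0xfffc/0xfffe/0xfff0 masks throughout this diff, fixes the vectorized-loop boundary for extents of 65536 and above: the old 16-bit constant silently clears every bit above bit 15, so the round-down-to-a-multiple boundary comes out far too small. A small check illustrating the difference:

// Illustration (not part of the diff) of why the masks are widened: a 16-bit
// constant drops bits above bit 15, so any extent >= 65536 produces the
// wrong round-down-to-multiple-of-8 boundary.
#include <cassert>

int main() {
  int channels = 65544;                      // any value >= 65536
  assert((channels & 0xfff8) == 8);          // old mask: boundary is wrong
  assert((channels & 0xfffffff8) == 65544);  // new mask: rounds down correctly
  return 0;
}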
......@@ -18,7 +18,7 @@ limitations under the License. */
#include <cmath>
#include "operators/kernel/arm/convolution/conv_common.h"
#include "operators/kernel/central-arm-func/conv_arm_func.h"
#include "operators/math/channel_wise.h"
#include "operators/math/element_wise.h"
namespace paddle_mobile {
namespace operators {
......
......@@ -17,7 +17,7 @@ limitations under the License. */
#include "operators/kernel/conv_add_kernel.h"
#include "operators/kernel/arm/convolution/conv_common.h"
#include "operators/kernel/central-arm-func/conv_arm_func.h"
#include "operators/math/channel_wise.h"
#include "operators/math/element_wise.h"
namespace paddle_mobile {
namespace operators {
......
......@@ -17,7 +17,7 @@ limitations under the License. */
#include "operators/kernel/conv_add_relu_kernel.h"
#include "operators/kernel/arm/convolution/conv_common.h"
#include "operators/kernel/central-arm-func/conv_arm_func.h"
#include "operators/math/channel_wise.h"
#include "operators/math/element_wise.h"
namespace paddle_mobile {
namespace operators {
......
......@@ -18,7 +18,7 @@ limitations under the License. */
#include <cmath>
#include "operators/kernel/arm/convolution/conv_common.h"
#include "operators/kernel/central-arm-func/conv_arm_func.h"
#include "operators/math/channel_wise.h"
#include "operators/math/element_wise.h"
namespace paddle_mobile {
namespace operators {
......@@ -34,27 +34,15 @@ bool ConvBNAddReluKernel<CPU, float>::Init(
auto mean_ptr = mean->data<float>();
auto variance_ptr = variance->data<float>();
auto scale_ptr = scale->data<float>();
auto bias_ptr = bias->data<float>();
auto scale_ptr = const_cast<float *>(scale->data<float>());
auto bias_ptr = const_cast<float *>(bias->data<float>());
const int C = mean->numel();
float inv_std_ptr[C];
for (int i = 0; i < C; i++) {
inv_std_ptr[i] =
1 / static_cast<float>(pow((variance_ptr[i] + epsilon), 0.5));
for (int c = 0; c < scale->numel(); ++c) {
float inv_scale = 1.f / (pow(variance_ptr[c] + epsilon, 0.5));
bias_ptr[c] -= inv_scale * scale_ptr[c] * mean_ptr[c];
scale_ptr[c] *= inv_scale;
}
auto *new_scale = param->CreateNewScale<framework::LoDTensor>();
auto *new_bias = param->CreateNewBiase<framework::LoDTensor>();
auto new_scale_ptr = new_scale->mutable_data<float>({C});
auto new_bias_ptr = new_bias->mutable_data<float>({C});
for (int i = 0; i < C; i++) {
new_scale_ptr[i] = inv_std_ptr[i] * scale_ptr[i];
new_bias_ptr[i] = bias_ptr[i] - mean_ptr[i] * inv_std_ptr[i] * scale_ptr[i];
}
param->SetNewScale(new_scale);
param->SetNewBias(new_bias);
InitBaseConvKernel(param);
return true;
}
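The rewritten Init above folds the batch-norm statistics into the existing scale and bias buffers in place, instead of building separate NewScale/NewBias tensors. The algebra behind the per-channel loop, with inv_scale in the code playing the role of s:

% Batch-norm folding performed per channel c in Init:
\[
  y = \gamma \frac{x - \mu}{\sqrt{\sigma^{2} + \epsilon}}
      + \beta
    = \underbrace{\gamma s}_{\texttt{scale\_ptr}[c]}\, x
      + \underbrace{\beta - \gamma s \mu}_{\texttt{bias\_ptr}[c]},
  \qquad s = \frac{1}{\sqrt{\sigma^{2} + \epsilon}} .
\]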
......@@ -84,9 +72,19 @@ void ConvBNAddReluKernel<CPU, float>::Compute(
PADDLE_MOBILE_THROW_EXCEPTION("Invalid convolution execute mode %d",
param.ExecMode());
}
math::ScaleAddChannelWise<RELU>(param.Output(), param.NewScale(),
param.NewBias(), param.Output());
if (param.Bias()->dims() == param.Output()->dims()) {
math::ScaleAddChannelWise<RELU>(param.Output(), param.InputScale(),
param.InputBias(), param.Bias(),
param.Output());
} else {
math::ScaleAddChannelWise<IDENTITY>(param.Output(), param.InputScale(),
param.InputBias(), param.Output());
math::AddElememtWise<RELU>(param.Output(), param.Bias(), param.Axis(),
param.Output());
}
}
template class ConvBNAddReluKernel<CPU, float>;
} // namespace operators
......
......@@ -18,7 +18,7 @@ limitations under the License. */
#include <cmath>
#include "operators/kernel/arm/convolution/conv_common.h"
#include "operators/kernel/central-arm-func/conv_arm_func.h"
#include "operators/math/channel_wise.h"
#include "operators/math/element_wise.h"
namespace paddle_mobile {
namespace operators {
......
......@@ -52,7 +52,7 @@ void InitBaseConvKernel(ConvParam<CPU> *param) {
} else if (depth5x5 && param->Strides()[0] == param->Strides()[1] &&
param->Strides()[0] == 1) {
param->ExecMode() = ConvParam<CPU>::EXEC_DEPTHWISE5x5_FLOAT;
} else if (conv3x3 && !depth3x3 &&
} else if (conv3x3 && param->Groups() == 1 &&
param->Strides()[0] == param->Strides()[1] &&
param->Dilations()[0] == param->Dilations()[1] &&
param->Strides()[0] == 1 && param->Dilations()[0] == 1
......@@ -66,7 +66,7 @@ void InitBaseConvKernel(ConvParam<CPU> *param) {
param->transformed_filter_ = new framework::LoDTensor;
operators::math::winograd_transform_weight<8, 3>(
*param->Filter(), param->transformed_filter_);
} else if (conv3x3 && !depth3x3 &&
} else if (conv3x3 && param->Groups() == 1 &&
param->Strides()[0] == param->Strides()[1] &&
param->Dilations()[0] == param->Dilations()[1] &&
param->Strides()[0] == 1 && param->Dilations()[0] == 1
......@@ -76,7 +76,7 @@ void InitBaseConvKernel(ConvParam<CPU> *param) {
#endif
) {
param->ExecMode() = ConvParam<CPU>::EXEC_SLIDINGWINDOW3x3S1_FLOAT;
} else if (conv3x3 && !depth3x3 &&
} else if (conv3x3 && param->Groups() == 1 &&
param->Strides()[0] == param->Strides()[1] &&
param->Dilations()[0] == param->Dilations()[1] &&
param->Strides()[0] == 2 && param->Dilations()[0] == 1
......
......@@ -18,7 +18,7 @@ limitations under the License. */
#include <cmath>
#include "operators/kernel/arm/convolution/conv_common.h"
#include "operators/kernel/central-arm-func/conv_arm_func.h"
#include "operators/math/channel_wise.h"
#include "operators/math/element_wise.h"
namespace paddle_mobile {
namespace operators {
......
......@@ -68,7 +68,7 @@ void SequencePoolImpl(const framework::LoDTensor &input,
int remain_h = height - 1;
int remain_w_start = 0;
#ifdef __ARM_NEON__
remain_w_start = width & 0xfffc;
remain_w_start = width & 0xfffffffc;
#endif // __ARM_NEON__
for (int h = 0; h < remain_h; ++h) {
#ifdef __ARM_NEON__
......@@ -128,7 +128,7 @@ void SequencePoolImpl<SUM, float>(const framework::LoDTensor &input,
int remain_w_start = 0;
#ifdef __ARM_NEON__
int loop_w = width >> 2;
remain_w_start = width & 0xfffc;
remain_w_start = width & 0xfffffffc;
#endif // __ARM_NEON__
for (int h = 0; h < remain_h; ++h) {
#ifdef __ARM_NEON__
......
......@@ -16,7 +16,7 @@ limitations under the License. */
#pragma once
#include "operators/math/elementwise_op_function.h"
#include "operators/math/element_wise.h"
#include "operators/op_param.h"
#if defined(__ARM_NEON__) || defined(__ARM_NEON)
#include <arm_neon.h>
......@@ -29,98 +29,10 @@ template <typename T>
inline void ElementwiseAddCompute(const ElementwiseAddParam<CPU> &param) {
const framework::Tensor *input_x = param.InputX();
const framework::Tensor *input_y = param.InputY();
framework::Tensor *Out = param.Out();
framework::Tensor *output = param.Out();
int axis = param.Axis();
const auto &x_dims = input_x->dims();
const auto &y_dims = input_y->dims();
/// axis = -1 represent the last dimensions.
axis = (axis == -1 ? x_dims.size() - y_dims.size() : axis);
size_t batch = 1;
size_t channels = 1;
size_t elementwise_num = 1;
for (int i = 0; i < axis; ++i) {
batch *= x_dims[i];
}
for (int i = 0; i < y_dims.size(); ++i) {
channels *= y_dims[i];
}
for (int i = y_dims.size() + axis; i < x_dims.size(); ++i) {
elementwise_num *= x_dims[i];
}
const float *bias_data = input_y->data<float>();
const float *input_data = input_x->data<float>();
float *output_data = Out->mutable_data<float>();
#pragma omp parallel for collapse(2)
for (int i = 0; i < batch; ++i) {
for (int j = 0; j < channels; ++j) {
size_t offset = (i * channels + j) * elementwise_num;
const float *input = input_data + offset;
const float bias = bias_data[j];
float *output = output_data + offset;
#if defined(__ARM_NEON__) || defined(__ARM_NEON)
int loop = elementwise_num >> 0x4;
int remain = elementwise_num & 0xF;
float32x4_t rb = vdupq_n_f32(bias);
for (int k = 0; k < loop; ++k) {
float32x4_t r0 = vld1q_f32(input);
float32x4_t r1 = vld1q_f32(input + 4);
float32x4_t r2 = vld1q_f32(input + 8);
float32x4_t r3 = vld1q_f32(input + 12);
r0 = vaddq_f32(r0, rb);
r1 = vaddq_f32(r1, rb);
r2 = vaddq_f32(r2, rb);
r3 = vaddq_f32(r3, rb);
vst1q_f32(output, r0);
vst1q_f32(output + 4, r1);
vst1q_f32(output + 8, r2);
vst1q_f32(output + 12, r3);
input += 16;
output += 16;
}
if (remain >= 8) {
float32x4_t r0 = vld1q_f32(input);
float32x4_t r1 = vld1q_f32(input + 4);
r0 = vaddq_f32(r0, rb);
r1 = vaddq_f32(r1, rb);
vst1q_f32(output, r0);
vst1q_f32(output + 4, r1);
input += 8;
output += 8;
remain -= 8;
}
if (remain >= 4) {
float32x4_t r0 = vld1q_f32(input);
r0 = vaddq_f32(r0, rb);
vst1q_f32(output, r0);
input += 4;
output += 4;
remain -= 4;
}
if (remain > 0) {
float32x4_t r0 = vld1q_f32(input);
r0 = vaddq_f32(r0, rb);
switch (remain) {
case 1:
vst1q_lane_f32(output, r0, 0);
break;
case 2:
vst1_f32(output, vget_low_f32(r0));
break;
case 3:
vst1_f32(output, vget_low_f32(r0));
vst1q_lane_f32(output, r0, 2);
break;
}
}
#else
for (int k = 0; k < elementwise_num; ++k) {
output[k] = input[k] + bias;
}
#endif // __ARM_NEON__
}
}
math::AddElememtWise<IDENTITY>(input_x, input_y, axis, output);
}
template class ElementwiseAddKernel<CPU, float>;
......
......@@ -448,7 +448,7 @@ void DepthwiseConv3x3S1<float, float>(const framework::Tensor &input,
}
}
// remain height
int start_h = valid_h_start + (valid_h & 0xfffe);
int start_h = valid_h_start + (valid_h & 0xfffffffe);
if (start_h < valid_h_end) {
const float *input_ptr0 = input_ptr + (start_h - padding_h) * input_w;
const float *input_ptr1 = input_ptr0 + input_w;
......@@ -906,7 +906,7 @@ void DepthwiseConv3x3S2<float, float>(const framework::Tensor &input,
}
}
// remain height
int start_h = valid_h_start + (valid_h & 0xfffe);
int start_h = valid_h_start + (valid_h & 0xfffffffe);
if (start_h < valid_h_end) {
const float *input_ptr0 = input_ptr + (2 * start_h - padding_h) * input_w;
const float *input_ptr1 = input_ptr0 + input_w;
......
......@@ -580,7 +580,7 @@ void DepthwiseConv3x3S1<int8_t, int32_t>(const framework::Tensor &input,
}
}
// remain height
int start_h = valid_h_start + (valid_h & 0xFFFC);
int start_h = valid_h_start + (valid_h & 0xFFFFFFFC);
for (int h = start_h; h < valid_h_end - 1; h += 2) {
const int8_t *input_ptr0 = input_ptr + (h - padding_h) * input_w;
const int8_t *input_ptr1 = input_ptr0 + input_w;
......@@ -844,7 +844,7 @@ void DepthwiseConv3x3S1<int8_t, int32_t>(const framework::Tensor &input,
}
}
start_h = valid_h_start + (valid_h & 0xFFFE);
start_h = valid_h_start + (valid_h & 0xFFFFFFFE);
if (start_h < valid_h_end) {
const int8_t *input_ptr0 = input_ptr + (start_h - padding_h) * input_w;
const int8_t *input_ptr1 = input_ptr0 + input_w;
......
......@@ -721,7 +721,7 @@ void DepthwiseConv5x5S1<float, float>(const framework::Tensor &input,
}
}
// remain height
int start_h = valid_h_start + (valid_h & 0xfffe);
int start_h = valid_h_start + (valid_h & 0xfffffffe);
if (start_h < valid_h_end) {
const float *input_ptr0 = input_ptr + (start_h - padding_h) * input_w;
const float *input_ptr1 = input_ptr0 + input_w;
......
......@@ -686,7 +686,7 @@ void DepthwiseConv5x5S1<int8_t, int32_t>(const framework::Tensor &input,
}
}
// remain height
int start_h = valid_h_start + (valid_h & 0xfffe);
int start_h = valid_h_start + (valid_h & 0xfffffffe);
if (start_h < valid_h_end) {
const int8_t *input_ptr0 = input_ptr + (start_h - padding_h) * input_w;
const int8_t *input_ptr1 = input_ptr0 + input_w;
......
......@@ -133,6 +133,227 @@ void ScaleAddChannelWise(const framework::Tensor *input,
}
}
template <ActivationType Act>
void ScaleAddChannelWise(const framework::Tensor *input,
const framework::Tensor *scale,
const framework::Tensor *bias,
const framework::Tensor *tensorwise_bias,
framework::Tensor *output) {
const float *input_ptr = input->data<float>();
const float *scale_ptr = scale->data<float>();
const float *bias_ptr = bias->data<float>();
const float *tensorwise_bias_ptr = tensorwise_bias->data<float>();
float *output_ptr = output->mutable_data<float>();
// maybe check shape
int batch_size = input->dims()[0];
int channels = input->dims()[1];
int spatial_size = input->dims()[2] * input->dims()[3];
for (int batch = 0; batch < batch_size; ++batch) {
for (int channel = 0; channel < channels; ++channel) {
size_t offset = (batch * channels + channel) * spatial_size;
const float *x = input_ptr + offset;
const float *b = tensorwise_bias_ptr + offset;
float *y = output_ptr + offset;
float alpha = scale_ptr[channel];
float beta = bias_ptr[channel];
int j = 0;
#if defined(__ARM_NEON__) || defined(__ARM_NEON)
float32x4_t __scale = vdupq_n_f32(alpha);
float32x4_t __bias = vdupq_n_f32(beta);
for (; j < spatial_size - 15; j += 16, x += 16, b += 16, y += 16) {
float32x4_t in0 = vld1q_f32(x);
float32x4_t in1 = vld1q_f32(x + 4);
float32x4_t in2 = vld1q_f32(x + 8);
float32x4_t in3 = vld1q_f32(x + 12);
float32x4_t b0 = vld1q_f32(b);
float32x4_t b1 = vld1q_f32(b + 4);
float32x4_t b2 = vld1q_f32(b + 8);
float32x4_t b3 = vld1q_f32(b + 12);
in0 = vmlaq_f32(__bias, __scale, in0);
in1 = vmlaq_f32(__bias, __scale, in1);
in2 = vmlaq_f32(__bias, __scale, in2);
in3 = vmlaq_f32(__bias, __scale, in3);
in0 = vaddq_f32(in0, b0);
in1 = vaddq_f32(in1, b1);
in2 = vaddq_f32(in2, b2);
in3 = vaddq_f32(in3, b3);
in0 = math::vActiveq_f32<Act>(in0);
in1 = math::vActiveq_f32<Act>(in1);
in2 = math::vActiveq_f32<Act>(in2);
in3 = math::vActiveq_f32<Act>(in3);
vst1q_f32(y, in0);
vst1q_f32(y + 4, in1);
vst1q_f32(y + 8, in2);
vst1q_f32(y + 12, in3);
}
for (; j < spatial_size - 3; j += 4, x += 4, b += 4, y += 4) {
float32x4_t in0 = vld1q_f32(x);
float32x4_t b0 = vld1q_f32(b);
in0 = vmlaq_f32(__bias, __scale, in0);
in0 = vaddq_f32(in0, b0);
in0 = math::vActiveq_f32<Act>(in0);
vst1q_f32(y, in0);
}
#endif
for (; j < spatial_size; ++j, ++x, ++b, ++y) {
*y = math::Active<Act>(alpha * (*x) + beta + (*b));
}
}
}
}
template <ActivationType Act>
void AddElememtWise(const framework::Tensor *input,
const framework::Tensor *bias, const int axis,
framework::Tensor *output) {
const auto &x_dims = input->dims();
const auto &y_dims = bias->dims();
const float *input_data = input->data<float>();
const float *bias_data = bias->data<float>();
float *output_data = output->mutable_data<float>();
if (x_dims == y_dims) {
int remain_start = 0;
#if defined(__ARM_NEON__) || defined(__ARM_NEON)
remain_start = input->numel() & 0xfffffffc;
#pragma omp parallel for
for (int i = 0; i < input->numel() - 15; i += 16) {
float32x4_t r0 = vld1q_f32(input_data);
float32x4_t r1 = vld1q_f32(input_data + 4);
float32x4_t r2 = vld1q_f32(input_data + 8);
float32x4_t r3 = vld1q_f32(input_data + 12);
float32x4_t b0 = vld1q_f32(bias_data);
float32x4_t b1 = vld1q_f32(bias_data + 4);
float32x4_t b2 = vld1q_f32(bias_data + 8);
float32x4_t b3 = vld1q_f32(bias_data + 12);
r0 = vaddq_f32(r0, b0);
r1 = vaddq_f32(r1, b1);
r2 = vaddq_f32(r2, b2);
r3 = vaddq_f32(r3, b3);
r0 = math::vActiveq_f32<Act>(r0);
r1 = math::vActiveq_f32<Act>(r1);
r2 = math::vActiveq_f32<Act>(r2);
r3 = math::vActiveq_f32<Act>(r3);
vst1q_f32(output_data, r0);
vst1q_f32(output_data + 4, r1);
vst1q_f32(output_data + 8, r2);
vst1q_f32(output_data + 12, r3);
input_data += 16;
bias_data += 16;
output_data += 16;
}
for (int i = input->numel() & 0xfffffff0; i < input->numel() - 3; i += 4) {
float32x4_t r0 = vld1q_f32(input_data);
float32x4_t b0 = vld1q_f32(bias_data);
r0 = vaddq_f32(r0, b0);
r0 = math::vActiveq_f32<Act>(r0);
vst1q_f32(output_data, r0);
input_data += 4;
bias_data += 4;
output_data += 4;
}
#endif // __ARM_NEON__
for (int i = remain_start; i < input->numel(); ++i) {
output_data[i] = math::Active<Act>(input_data[i] + bias_data[i]);
}
} else {
// axis = -1 represent the last dimensions.
int dim = (axis == -1 ? x_dims.size() - y_dims.size() : axis);
size_t batch = 1;
size_t channels = 1;
size_t elementwise_num = 1;
for (int i = 0; i < dim; ++i) {
batch *= x_dims[i];
}
for (int i = 0; i < y_dims.size(); ++i) {
channels *= y_dims[i];
}
for (int i = y_dims.size() + dim; i < x_dims.size(); ++i) {
elementwise_num *= x_dims[i];
}
#pragma omp parallel for collapse(2)
for (int i = 0; i < batch; ++i) {
for (int j = 0; j < channels; ++j) {
size_t offset = (i * channels + j) * elementwise_num;
const float *input = input_data + offset;
const float bias = bias_data[j];
float *output = output_data + offset;
#if defined(__ARM_NEON__) || defined(__ARM_NEON)
int loop = elementwise_num >> 0x4;
int remain = elementwise_num & 0xF;
float32x4_t rb = vdupq_n_f32(bias);
for (int k = 0; k < loop; ++k) {
float32x4_t r0 = vld1q_f32(input);
float32x4_t r1 = vld1q_f32(input + 4);
float32x4_t r2 = vld1q_f32(input + 8);
float32x4_t r3 = vld1q_f32(input + 12);
r0 = vaddq_f32(r0, rb);
r1 = vaddq_f32(r1, rb);
r2 = vaddq_f32(r2, rb);
r3 = vaddq_f32(r3, rb);
r0 = math::vActiveq_f32<Act>(r0);
r1 = math::vActiveq_f32<Act>(r1);
r2 = math::vActiveq_f32<Act>(r2);
r3 = math::vActiveq_f32<Act>(r3);
vst1q_f32(output, r0);
vst1q_f32(output + 4, r1);
vst1q_f32(output + 8, r2);
vst1q_f32(output + 12, r3);
input += 16;
output += 16;
}
if (remain >= 8) {
float32x4_t r0 = vld1q_f32(input);
float32x4_t r1 = vld1q_f32(input + 4);
r0 = vaddq_f32(r0, rb);
r1 = vaddq_f32(r1, rb);
r0 = math::vActiveq_f32<Act>(r0);
r1 = math::vActiveq_f32<Act>(r1);
vst1q_f32(output, r0);
vst1q_f32(output + 4, r1);
input += 8;
output += 8;
remain -= 8;
}
if (remain >= 4) {
float32x4_t r0 = vld1q_f32(input);
r0 = vaddq_f32(r0, rb);
r0 = math::vActiveq_f32<Act>(r0);
vst1q_f32(output, r0);
input += 4;
output += 4;
remain -= 4;
}
if (remain > 0) {
float32x4_t r0 = vld1q_f32(input);
r0 = vaddq_f32(r0, rb);
r0 = math::vActiveq_f32<Act>(r0);
switch (remain) {
case 1:
vst1q_lane_f32(output, r0, 0);
break;
case 2:
vst1_f32(output, vget_low_f32(r0));
break;
case 3:
vst1_f32(output, vget_low_f32(r0));
vst1q_lane_f32(output, r0, 2);
break;
}
}
#else
for (int k = 0; k < elementwise_num; ++k) {
output[k] = math::Active<Act>(input[k] + bias);
}
#endif // __ARM_NEON__
}
}
}
}
} // namespace math
} // namespace operators
} // namespace paddle_mobile
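AddElememtWise above broadcasts the bias over the input when their shapes differ, splitting the input into batch x channels x elementwise_num runs. A worked example of that decomposition, using illustrative shapes:

// Worked example (illustrative shapes, not from the diff) of the broadcast
// decomposition used by AddElememtWise when x and bias differ in rank.
#include <cassert>
#include <vector>

int main() {
  std::vector<int> x_dims = {2, 3, 4, 5};  // input
  std::vector<int> y_dims = {3};           // per-channel bias
  int axis = 1;
  int batch = 1, channels = 1, elementwise_num = 1;
  for (int i = 0; i < axis; ++i) batch *= x_dims[i];
  for (size_t i = 0; i < y_dims.size(); ++i) channels *= y_dims[i];
  for (size_t i = y_dims.size() + axis; i < x_dims.size(); ++i)
    elementwise_num *= x_dims[i];
  // each bias value is added to a contiguous run of 20 elements,
  // repeated for every (batch, channel) pair
  assert(batch == 2 && channels == 3 && elementwise_num == 20);
  return 0;
}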
......@@ -388,7 +388,7 @@ void sgemv_notrans_mx1(const int M, const int N, const float alpha,
vst1q_f32(output, _sum0);
}
// remain m
for (int m = (M & 0xfffc); m < M; ++m) {
for (int m = (M & 0xfffffffc); m < M; ++m) {
const float *in0 = A + m * lda;
float *output = C + m;
float32x4_t _sum0 = vdupq_n_f32(0.f);
......@@ -426,7 +426,7 @@ void sgemv_trans_mx1(const int M, const int N, const float alpha,
for (int m = 0; m < M - 3; m += 4) {
vst1q_f32(C + m, vzero);
}
for (int m = (M & 0xfffc); m < M; ++m) {
for (int m = (M & 0xfffffffc); m < M; ++m) {
C[m] = 0.f;
}
} else {
......@@ -436,7 +436,7 @@ void sgemv_trans_mx1(const int M, const int N, const float alpha,
_vc = vmulq_f32(_vc, vbeta);
vst1q_f32(C + m, _vc);
}
for (int m = (M & 0xfffc); m < M; ++m) {
for (int m = (M & 0xfffffffc); m < M; ++m) {
C[m] *= beta;
}
}
......@@ -491,7 +491,7 @@ void sgemv_trans_mx1(const int M, const int N, const float alpha,
}
}
// remain n
for (int n = (N & 0xfffc); n < N; ++n) {
for (int n = (N & 0xfffffffc); n < N; ++n) {
const float *in0 = A + n * lda;
float32x4_t _b = vld1q_dup_f32(B + n);
float32x4_t _sum0;
......
......@@ -325,7 +325,7 @@ void pack_rhs_16c(int k, int n, const float *B, int ldb, float *output,
: "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7");
}
for (; j < n - 7; j += 8) {
float *out_ptr0 = output + (j & 0xFFF0) * k + 16 * i + (j & 0xF);
float *out_ptr0 = output + (j & 0xFFFFFFF0) * k + 16 * i + (j & 0xF);
int step = 64;
asm volatile(
"ld1 {v0.4s, v1.4s}, [%[b0]], #32 \n"
......@@ -343,7 +343,7 @@ void pack_rhs_16c(int k, int n, const float *B, int ldb, float *output,
: "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7");
}
if (j < n) {
float *out_ptr0 = output + (j & 0xFFF0) * k + 16 * i + (j & 0xF);
float *out_ptr0 = output + (j & 0xFFFFFFF0) * k + 16 * i + (j & 0xF);
int step = 64;
asm volatile(
"ld1 {v0.4s, v1.4s}, [%[b0]] \n"
......@@ -372,7 +372,7 @@ void pack_rhs_16c(int k, int n, const float *B, int ldb, float *output,
}
if (j & 0xf) {
float *out_ptr0 = output + (j & 0xFFF0) * k + 16 * i + (j & 0xF);
float *out_ptr0 = output + (j & 0xFFFFFFF0) * k + 16 * i + (j & 0xF);
vst1q_f32(out_ptr0, vzero);
vst1q_f32(out_ptr0 + 4, vzero);
out_ptr0 += 16;
......@@ -387,7 +387,7 @@ void pack_rhs_16c(int k, int n, const float *B, int ldb, float *output,
}
}
// remain k
for (int i = (k & 0xFFFC); i < k; ++i) {
for (int i = (k & 0xFFFFFFFC); i < k; ++i) {
const float *b0 = B + i * ldb;
int j = 0;
asm volatile("prfm pldl1keep, [%[b0]] \n"
......@@ -404,7 +404,7 @@ void pack_rhs_16c(int k, int n, const float *B, int ldb, float *output,
: "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7");
}
for (; j < n - 7; j += 8) {
float *out_ptr0 = output + (j & 0xFFF0) * k + 16 * i + (j & 0xF);
float *out_ptr0 = output + (j & 0xFFFFFFF0) * k + 16 * i + (j & 0xF);
int step = 64;
asm volatile(
"ld1 {v0.4s, v1.4s}, [%[b0]], #32 \n"
......@@ -414,7 +414,7 @@ void pack_rhs_16c(int k, int n, const float *B, int ldb, float *output,
: "memory", "v0", "v1");
}
if (j < n) {
float *out_ptr0 = output + (j & 0xFFF0) * k + 16 * i + (j & 0xF);
float *out_ptr0 = output + (j & 0xFFFFFFF0) * k + 16 * i + (j & 0xF);
asm volatile(
"ld1 {v0.4s, v1.4s}, [%[b0]] \n"
"and v0.16b, v0.16b, %[vmask1].16b \n"
......@@ -426,7 +426,7 @@ void pack_rhs_16c(int k, int n, const float *B, int ldb, float *output,
j += 8;
}
if (j & 0xf) {
float *out_ptr0 = output + (j & 0xFFF0) * k + 16 * i + (j & 0xF);
float *out_ptr0 = output + (j & 0xFFFFFFF0) * k + 16 * i + (j & 0xF);
vst1q_f32(out_ptr0, vzero);
vst1q_f32(out_ptr0 + 4, vzero);
}
......@@ -517,7 +517,7 @@ void pack_rhs_8c(int k, int n, const float *B, int ldb, float *output,
}
}
// remain k
for (int i = (k & 0xFFFC); i < k; ++i) {
for (int i = (k & 0xFFFFFFFC); i < k; ++i) {
const float *b0 = B + i * ldb;
int j = 0;
for (; j < n - 15; j += 16) {
......
......@@ -424,7 +424,7 @@ struct Pooling2x2<P, 1> {
}
}
// remain height
int start_h = valid_h_start + (valid_h & 0xFFFC);
int start_h = valid_h_start + (valid_h & 0xFFFFFFFC);
for (int h = start_h; h < valid_h_end; ++h) {
const float *input_ptr0 = input_ptr + (h - padding_h) * input_w;
const float *input_ptr1 = input_ptr0 + input_w;
......@@ -692,7 +692,7 @@ struct Pooling2x2<P, 2> {
}
}
// remain height
int start_h = valid_h_start + (valid_h & 0xfffe);
int start_h = valid_h_start + (valid_h & 0xfffffffe);
for (int h = start_h; h < valid_h_end; ++h) {
const float *input_ptr0 = input_ptr + (2 * h - padding_h) * input_w;
const float *input_ptr1 = input_ptr0 + input_w;
......
......@@ -560,7 +560,7 @@ struct Pooling3x3<P, 1> {
}
}
// remain height
int start_h = valid_h_start + (valid_h & 0xFFFC);
int start_h = valid_h_start + (valid_h & 0xFFFFFFFC);
for (int h = start_h; h < valid_h_end; ++h) {
const float *input_ptr0 = input_ptr + (h - padding_h) * input_w;
const float *input_ptr1 = input_ptr0 + input_w;
......
......@@ -55,7 +55,7 @@ void winograd_transform_weight<8, 3>(const framework::Tensor &weight,
#if __aarch64__
int remain_start = 0;
#else
int remain_start = out_channel & 0xFFFC;
int remain_start = out_channel & 0xFFFFFFFC;
#pragma omp parallel for
for (int oc = 0; oc < out_channel - 3; oc += 4) {
......@@ -268,7 +268,7 @@ void winograd_transform_weight<8, 3>(const framework::Tensor &weight,
float gw[3][8]; // gw[3][8]
const float *inptr0 = inptr + oc * in_channel * 9; //
// (oc / 4) * 64 * in_channel * 4 + oc % 4
int offset = ((oc & 0xFFFC) << 6) * in_channel + (oc & 0x3);
int offset = ((oc & 0xFFFFFFFC) << 6) * in_channel + (oc & 0x3);
int steps = (in_channel << 2); // in_channel * 4
float *outptr = trans_outptr + offset;
for (int ic = 0; ic < in_channel; ++ic) {
......
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "pass/memory_optimize.h"
#include "framework/lod_tensor.h"
namespace paddle_mobile {
namespace pass {
void MemoryOptPass::AppendBlockVars(const framework::BlockDesc *block) {
// block_vars_.clear();
for (const auto var : block->Vars()) {
block_vars_[var->Name()] = var.get();
}
}
bool MemoryOptPass::IsPersistable(const std::string name) {
const auto it = block_vars_.find(name);
if (it != block_vars_.end()) {
return it->second->Persistable();
}
return false;
}
VarNode *MemoryOptPass::CreateNode(const std::string name) {
auto it = created_nodes_.find(name);
if (it != created_nodes_.end()) {
++(it->second->count);
return it->second;
}
VarNode *var = new VarNode;
var->name = name;
var->count = 1;
var->visited = false;
created_nodes_[name] = var;
return var;
}
void MemoryOptPass::operator()(const framework::ProgramDesc *program,
framework::Scope *scope) {
const auto &blocks = program->Blocks();
for (const auto &block : blocks) {
// access all variables in each block
AppendBlockVars(block.get());
reused_nodes_.clear();
// collect all non-persistable variables and accumulate
// their reference counts
std::stack<VarNode *> empty_var_nodes;
analysis_nodes_.swap(empty_var_nodes);
for (const auto &op : block->Ops()) {
DLOG << "op_desc->Type(): " << op->Type();
for (const auto &outputs : op->GetOutputs()) {
for (const auto &output : outputs.second) {
if (!IsPersistable(output)) {
DLOG << "output: " << output;
VarNode *node = CreateNode(output);
analysis_nodes_.push(node);
}
}
}
for (const auto &inputs : op->GetInputs()) {
for (const auto &input : inputs.second) {
if (!IsPersistable(input)) {
DLOG << "input: " << input;
VarNode *node = CreateNode(input);
analysis_nodes_.push(node);
}
}
}
for (const auto &outputs : op->GetOutputs()) {
for (const auto &output : outputs.second) {
if (!IsPersistable(output)) {
DLOG << "output: " << output;
VarNode *node = CreateNode(output);
analysis_nodes_.push(node);
}
}
}
}
// apply the optimization
while (!analysis_nodes_.empty()) {
auto *node = analysis_nodes_.top();
analysis_nodes_.pop();
// only an unvisited node may reuse memory with nodes whose count
// has dropped to 0, indicating they will not be used any more
if (!node->visited) {
bool reused = false;
// find a possible reuse list
for (auto &list : reused_nodes_) {
if (list.back()->count == 0) {
list.push_back(node);
reused = true;
break;
}
}
// create a new list if no reusable list was found
if (!reused) {
std::vector<VarNode *> list;
list.push_back(node);
reused_nodes_.push_back(std::move(list));
}
}
node->visited = true;
node->count -= 1;
}
// share data among all variables in the same reuse list
for (const auto &list : reused_nodes_) {
DLOG << "\n";
DLOG << "share memory within these variables";
std::string name = list[0]->name;
auto *reused_var = scope->Var(name);
auto *reuse_tensor =
reused_var->template GetMutable<framework::LoDTensor>();
reuse_tensor->mutable_data<float>();
for (const auto &node : list) {
DLOG << node->name;
auto *var = scope->Var(node->name);
auto *tensor = var->template GetMutable<framework::LoDTensor>();
tensor->ShareHolderWith(*reuse_tensor);
}
}
}
}
} // namespace pass
} // namespace paddle_mobile
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <stack>
#include <string>
#include <unordered_map>
#include <vector>
#include "framework/program/program.h"
namespace paddle_mobile {
namespace pass {
typedef struct {
std::string name; // variable name
int count; // reference count
bool visited;
} VarNode;
class PassBase {
public:
PassBase() {}
virtual ~PassBase() {}
};
// MemoryOptPass analyzes the program and reuses memory between
// variables as much as possible
class MemoryOptPass : public PassBase {
public:
MemoryOptPass() {}
virtual ~MemoryOptPass() {
for (auto &it : created_nodes_) {
delete it.second;
}
}
void operator()(const framework::ProgramDesc *program,
framework::Scope *scope);
void AppendBlockVars(const framework::BlockDesc *block);
bool IsPersistable(const std::string name);
VarNode *CreateNode(const std::string name);
private:
std::stack<VarNode *> analysis_nodes_;
std::vector<std::vector<VarNode *>> reused_nodes_;
std::unordered_map<std::string, VarNode *> created_nodes_;
std::unordered_map<std::string, framework::VarDesc *> block_vars_;
};
} // namespace pass
} // namespace paddle_mobile
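The pass keeps no state beyond its bookkeeping maps, so the executor (see the executor.cpp hunk at the top of this diff) simply constructs and invokes it once, right after the ProgramDesc is chosen and before feed/fetch initialization. A hedged usage sketch; RunMemoryOptimize is a hypothetical wrapper, not part of the PR:

// Hypothetical wrapper showing how the pass is driven; the actual call in
// framework/executor.cpp is the single line inside this function.
#include "pass/memory_optimize.h"

void RunMemoryOptimize(paddle_mobile::framework::ProgramDesc *program_desc,
                       paddle_mobile::framework::Scope *scope) {
  paddle_mobile::pass::MemoryOptPass()(program_desc, scope);
}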