未验证 提交 1890199f 编写于 作者: S smilejames 提交者: GitHub

Merge pull request #601 from codeWorm2015/develop

fix #600 add marco to asse, add fusion conv batch relu op
......@@ -25,7 +25,7 @@ const std::string G_OP_TYPE_ELEMENTWISE_ADD = "elementwise_add";
const std::string G_OP_TYPE_FUSION_CONV_ADD_RELU = "fusion_conv_add_relu";
const std::string G_OP_TYPE_FUSION_CONV_ADD_BN_RELU = "fusion_conv_add_bn_relu";
const std::string G_OP_TYPE_FUSION_DWCONV_BN_RELU = "fusion_dwconv_bn_relu";
const std::string G_OP_TYPE_FUSION_CONV_BN_RELU = "fusion_conv_bn_relu";
const std::string G_OP_TYPE_FC = "fusion_fc";
const std::string G_OP_TYPE_FUSION_CONV_ADD = "fusion_conv_add";
const std::string G_OP_TYPE_LRN = "lrn";
......@@ -49,6 +49,8 @@ std::unordered_map<
std::string, std::pair<std::vector<std::string>, std::vector<std::string>>>
op_input_output_key = {
{G_OP_TYPE_CONV, {{"Input"}, {"Output"}}},
{G_OP_TYPE_FUSION_DWCONV_BN_RELU, {{"Input"}, {"Out"}}},
{G_OP_TYPE_FUSION_CONV_BN_RELU, {{"Input"}, {"Out"}}},
{G_OP_TYPE_FUSION_CONV_ADD, {{"Input"}, {"Out"}}},
{G_OP_TYPE_RELU, {{"X"}, {"Out"}}},
{G_OP_TYPE_SOFTMAX, {{"X"}, {"Out"}}},
......
......@@ -16,6 +16,7 @@ limitations under the License. */
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>
namespace paddle_mobile {
......@@ -82,6 +83,7 @@ extern const std::string G_OP_TYPE_FC;
extern const std::string G_OP_TYPE_FUSION_CONV_ADD;
extern const std::string G_OP_TYPE_FUSION_CONV_ADD_BN_RELU;
extern const std::string G_OP_TYPE_FUSION_DWCONV_BN_RELU;
extern const std::string G_OP_TYPE_FUSION_CONV_BN_RELU;
extern const std::string G_OP_TYPE_LRN;
extern const std::string G_OP_TYPE_MUL;
......
......@@ -28,6 +28,16 @@ vector<string> OperatorBase<Dtype>::GetOutKeys() const {
return it->second.second;
}
template <typename Dtype>
vector<string> OperatorBase<Dtype>::GetInputKeys() const {
auto it = op_input_output_key.find(type_);
if (it == op_input_output_key.end()) {
DLOG << type_ << " has no outputs";
return {};
}
return it->second.first;
}
template <typename Dtype>
OperatorBase<Dtype>::OperatorBase(const std::string &type,
const VariableNameMap &inputs,
......@@ -49,6 +59,11 @@ template <typename Dtype>
void OperatorBase<Dtype>::Run() const {
RunImpl();
#ifdef PADDLE_MOBILE_DEBUG
vector<string> input_keys = GetInputKeys();
for (const auto key : input_keys) {
Tensor *input = GetVarValue<framework::LoDTensor>(key, inputs_, *scope_);
DLOG << type_ << " input- " << key << "=" << *input;
}
vector<string> output_keys = GetOutKeys();
for (const auto key : output_keys) {
Tensor *out_ = GetVarValue<framework::LoDTensor>(key, outputs_, *scope_);
......
......@@ -61,6 +61,7 @@ class OperatorBase {
virtual ~OperatorBase() {}
void Run() const;
std::vector<string> GetOutKeys() const;
std::vector<string> GetInputKeys() const;
virtual void RunImpl() const = 0;
virtual void Init() = 0;
......@@ -118,6 +119,10 @@ class OperatorWithKernel : public OperatorBase<Dtype> {
virtual void InferShape() const = 0;
void Init() {
// for (auto i : this->inputs_) {
// DLOG << i.first;
// DLOG << i.second;
// }
PADDLE_MOBILE_ENFORCE(kernel_.Init(&param_), " %s kernel init failed",
this->type_.c_str());
}
......@@ -146,7 +151,7 @@ class OpKernelBase {
}
#endif
virtual void Compute(const P &para) const = 0;
virtual bool Init(P *para) { return true; };
virtual bool Init(P *para) { return true; }
virtual ~OpKernelBase() = default;
private:
......
......@@ -66,11 +66,11 @@ class FusionConvAddOp : public framework::OperatorWithKernel<
#ifdef PADDLE_MOBILE_CPU
//#ifndef CONV_ADD_REGISTER
// static framework::FusionOpRegistrar convadd_registrar(
// new FusionConvAddMatcher());
//#define CONV_ADD_REGISTER
//#endif
#ifndef CONV_ADD_REGISTER
static framework::FusionOpRegistrar convadd_registrar(
new FusionConvAddMatcher());
#define CONV_ADD_REGISTER
#endif
#endif
......
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifdef FUSION_CONVBNRELU_OP
#include "operators/fusion_conv_bn_relu_op.h"
#include "operators/math/conv_func.h"
namespace paddle_mobile {
namespace operators {
template <typename Dtype, typename T>
void FusionConvBNReluOp<Dtype, T>::InferShape() const {
auto in_dims = this->param_.Input()->dims();
auto filter_dims = this->param_.Filter()->dims();
const std::vector<int> &strides = this->param_.Strides();
std::vector<int> paddings = this->param_.Paddings();
int groups = this->param_.Groups();
std::vector<int> dilations = this->param_.Dilations();
PADDLE_MOBILE_ENFORCE((in_dims.size() == filter_dims.size() &&
dilations.size() == paddings.size() &&
paddings.size() == strides.size()),
"ConvParam is not suitable");
std::vector<int64_t> output_shape({in_dims[0], filter_dims[0]});
for (size_t i = 0; i < strides.size(); ++i) {
output_shape.push_back(
math::ConvOutputSize(in_dims[i + 2], filter_dims[i + 2], dilations[i],
paddings[i], strides[i]));
}
framework::DDim ddim = framework::make_ddim(output_shape);
this->param_.Output()->Resize(ddim);
}
} // namespace operators
} // namespace paddle_mobile
namespace ops = paddle_mobile::operators;
#ifdef PADDLE_MOBILE_CPU
REGISTER_OPERATOR_CPU(fusion_conv_bn_relu, ops::FusionConvBNReluOp);
#endif
#ifdef PADDLE_MOBILE_MALI_GPU
#endif
#ifdef PADDLE_MOBILE_FPGA
#endif
#endif
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifdef FUSION_CONVBNRELU_OP
#pragma once
#include <string>
#include <vector>
#include "framework/operator.h"
#include "framework/program/program-optimize/fusion_op_register.h"
#include "operators/kernel/conv_bn_relu_kernel.h"
#include "operators/op_param.h"
namespace paddle_mobile {
namespace operators {
using std::string;
using std::vector;
class FusionConvBNReluMatcher : public framework::FusionOpMatcher {
public:
FusionConvBNReluMatcher() {
node_ = framework::Node(G_OP_TYPE_CONV);
node_ > std::make_shared<framework::Node>(G_OP_TYPE_BATCHNORM) >
std::make_shared<framework::Node>(G_OP_TYPE_RELU);
}
void FolderNodes(
framework::Node *node,
std::vector<std::shared_ptr<framework::Node>> *removed_nodes) {
node->Folder(node_.Depth(), Type(),
{{G_OP_TYPE_BATCHNORM,
{{"Scale", "Scale"},
{"Mean", "Mean"},
{"Bias", "Bias"},
{"Variance", "Variance"}}}},
removed_nodes);
}
std::string Type() { return G_OP_TYPE_FUSION_CONV_BN_RELU; }
};
template <typename DeviceType, typename T>
class FusionConvBNReluOp : public framework::OperatorWithKernel<
DeviceType, FusionConvBNReluParam,
operators::ConvBNReluKernel<DeviceType, T>> {
public:
FusionConvBNReluOp(const string &type, const VariableNameMap &inputs,
const VariableNameMap &outputs,
const framework::AttributeMap &attrs,
std::shared_ptr<framework::Scope> scope)
: framework::OperatorWithKernel<
DeviceType, FusionConvBNReluParam,
operators::ConvBNReluKernel<DeviceType, T>>(type, inputs, outputs,
attrs, scope) {}
using framework::OperatorWithKernel<
DeviceType, FusionConvBNReluParam,
operators::ConvBNReluKernel<DeviceType, T>>::OperatorWithKernel;
void InferShape() const override;
protected:
};
#ifdef PADDLE_MOBILE_CPU
#ifndef FUSION_CONV_BN_RELU_REGISTER
static framework::FusionOpRegistrar fusion_conv_bn_relu_registrar(
new FusionConvBNReluMatcher());
#define FUSION_CONV_BN_RELU_REGISTER
#endif
#endif
#ifdef PADDLE_MOBILE_MALI_GPU
#endif
#ifdef PADDLE_MOBILE_FPGA
#endif
} // namespace operators
} // namespace paddle_mobile
#ifdef PADDLE_MOBILE_CPU
USE_OP_CPU(fusion_conv_bn_relu);
#endif
#ifdef PADDLE_MOBILE_MALI_GPU
#endif
#ifdef PADDLE_MOBILE_FPGA
#endif
#endif
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifdef FUSION_CONVBNRELU_OP
#include "operators/kernel/conv_bn_relu_kernel.h"
#include "operators/kernel/central-arm-func/conv_bn_relu_arm_func.h"
namespace paddle_mobile {
namespace operators {
template <>
bool ConvBNReluKernel<CPU, float>::Init(FusionConvBNReluParam *param) {
const Tensor *mean = param->InputMean();
const Tensor *variance = param->InputVariance();
const Tensor *scale = param->InputScale();
const Tensor *bias = param->InputBias();
const float epsilon = param->Epsilon();
// DLOG << "variance: " << *variance;
auto mean_ptr = mean->data<float>();
auto variance_ptr = variance->data<float>();
auto scale_ptr = scale->data<float>();
auto bias_ptr = bias->data<float>();
const int C = mean->numel();
float inv_std_ptr[C];
for (int i = 0; i < C; i++) {
inv_std_ptr[i] =
1 / static_cast<float>(pow((variance_ptr[i] + epsilon), 0.5));
}
Tensor *new_scale = new Tensor();
Tensor *new_bias = new Tensor();
auto new_scale_ptr = new_scale->mutable_data<float>({C});
auto new_bias_ptr = new_bias->mutable_data<float>({C});
for (int i = 0; i < C; i++) {
new_scale_ptr[i] = inv_std_ptr[i] * scale_ptr[i];
new_bias_ptr[i] = bias_ptr[i] - mean_ptr[i] * inv_std_ptr[i] * scale_ptr[i];
}
param->SetNewScale(new_scale);
param->SetNewBias(new_bias);
return true;
}
template <>
void ConvBNReluKernel<CPU, float>::Compute(
const FusionConvBNReluParam &param) const {
ConvBNReluCompute<float>(param);
}
template class ConvBNReluKernel<CPU, float>;
} // namespace operators
} // namespace paddle_mobile
#endif
......@@ -54,7 +54,40 @@ void BatchnormCompute(const BatchNormParam &param) {
int HXW = H * W;
#ifdef ARMV7
#if __ARM_NEON
#if __aarch64__
float *inv_std_ptr = new float[C];
for (int i = 0; i < C; i++) {
inv_std_ptr[i] =
1 / static_cast<float>(pow((variance_ptr[i] + epsilon), 0.5));
}
Tensor new_scale;
auto new_scale_ptr = new_scale.mutable_data<float>(framework::make_ddim({C}));
Tensor new_bias;
auto new_bias_ptr = new_bias.mutable_data<float>(framework::make_ddim({C}));
/// ((x - est_mean) * (inv_var) * scale + bias equal to
/// (x * inv_var * scale) + (bias - est_mean * inv_var * scale)
for (int i = 0; i < C; i++) {
new_scale_ptr[i] = inv_std_ptr[i] * scale_ptr[i];
new_bias_ptr[i] = bias_ptr[i] - mean_ptr[i] * inv_std_ptr[i] * scale_ptr[i];
{
for (int n = 0; n < N; n++) {
for (int h = 0; h < H; h++) {
int tmp_index = n * stride0 + i * stride1 + h * stride2;
for (int w = 0; w < W; w++) {
int index = tmp_index + w;
out_ptr[index] =
input_x_ptr[index] * new_scale_ptr[i] + new_bias_ptr[i];
}
}
}
}
}
delete[] inv_std_ptr;
#else
if (HXW > 32) {
int NXC = N * C;
float *inv_std_ptr = new float[NXC * 4];
......@@ -229,6 +262,7 @@ void BatchnormCompute(const BatchNormParam &param) {
delete[] inv_std_ptr;
}
#endif
#else
float *inv_std_ptr = new float[C];
for (int i = 0; i < C; i++) {
......
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifdef FUSION_CONVBNRELU_OP
#pragma once
#include <vector>
#include "operators/math/depthwise_conv_3x3.h"
#include "operators/op_param.h"
namespace paddle_mobile {
namespace operators {
void ConvBNReluBasic(const FusionConvBNReluParam &param) {
const Tensor *input = param.Input();
Tensor filter = *param.Filter();
Tensor new_bias = *param.NewBias();
Tensor new_scale = *param.NewScale();
Tensor *output = param.Output();
int groups = param.Groups();
std::vector<int> strides = param.Strides();
std::vector<int> paddings = param.Paddings();
std::vector<int> dilations = param.Dilations();
const int batch_size = static_cast<int>(input->dims()[0]);
std::vector<int64_t> filter_shape_vec(framework::vectorize(filter.dims()));
std::vector<int64_t> output_shape_vec(framework::vectorize(output->dims()));
size_t data_dim = filter_shape_vec.size() - 2;
std::vector<int64_t> col_shape_vec(1 + 2 * data_dim);
col_shape_vec[0] = input->dims()[1] / groups;
for (size_t j = 0; j < data_dim; ++j) {
col_shape_vec[j + 1] = filter_shape_vec[j + 2];
col_shape_vec[j + 1 + data_dim] = output_shape_vec[j + 2];
}
framework::DDim col_shape(framework::make_ddim(col_shape_vec));
framework::DDim col_matrix_shape =
framework::flatten_to_2d(col_shape, data_dim + 1);
bool is_expand =
math::IsExpand(filter_shape_vec, strides, paddings, dilations);
Tensor col;
Tensor col_matrix;
if (is_expand) {
col.mutable_data<float>(col_shape);
col_matrix.ShareDataWith(col);
col_matrix.Resize(col_matrix_shape);
}
framework::DDim input_shape = framework::slice_ddim(
input->dims(), 1, static_cast<int>(input->dims().size()));
framework::DDim filter_matrix_shape = {filter.dims()[0],
filter.numel() / filter.dims()[0]};
filter.Resize(filter_matrix_shape);
framework::DDim output_matrix_shape = {
output->dims()[1],
output->numel() / (output->dims()[0] * output->dims()[1])};
// convolution operator: im2col(or vol2col) + gemm
int in_step = static_cast<int>(input->dims()[1]) / groups;
int out_step = static_cast<int>(output->dims()[1]) / groups;
math::Vol2ColFunctor<CPU, float> vol2col;
math::Im2ColFunctor<math::ColFormat::kCFO, CPU, float> im2col;
for (int i = 0; i < batch_size; i++) {
Tensor in_batch = input->Slice(i, i + 1).Resize(input_shape);
Tensor out_batch = output->Slice(i, i + 1).Resize(output_matrix_shape);
for (int g = 0; g < groups; g++) {
Tensor in_slice = in_batch.Slice(g * in_step, (g + 1) * in_step);
if (!is_expand) {
col.ShareDataWith(in_slice);
col_matrix.ShareDataWith(col);
col_matrix.Resize(col_matrix_shape);
} else if (data_dim == 2U) {
// im2col
im2col(in_slice, dilations, strides,
std::vector<int>{paddings[0], paddings[1], paddings[0],
paddings[1]},
&col);
} else if (data_dim == 3U) {
// vol2col
vol2col(in_slice, dilations, strides, paddings, &col);
}
// gemm
Tensor out_slice = out_batch.Slice(g * out_step, (g + 1) * out_step);
Tensor filter_slice = filter.Slice(g * out_step, (g + 1) * out_step);
math::matmulWithBn<float>(
filter_slice, false, col_matrix, false, static_cast<float>(1),
&out_slice, static_cast<float>(0), true, &new_scale, &new_bias, g);
}
}
}
template <typename P>
void ConvBNReluCompute(const FusionConvBNReluParam &param) {
if (param.Groups() == param.Input()->dims()[1] &&
param.Input()->dims()[1] == param.Output()->dims()[1] &&
param.Filter()->dims()[2] == param.Filter()->dims()[3] &&
param.Filter()->dims()[2] == 3 && param.Strides()[0] == 1) {
math::DepthwiseConvAddBNRelu3x3s1p1(param.Input(), param.Filter(),
param.Output(), param.NewScale(),
param.NewBias(), true);
} else if (param.Groups() == param.Input()->dims()[1] &&
param.Input()->dims()[1] == param.Output()->dims()[1] &&
param.Filter()->dims()[2] == param.Filter()->dims()[3] &&
param.Filter()->dims()[2] == 3 && param.Strides()[0] == 2) {
// math::DepthwiseConvAddBNRelu3x3s2p1(param.Input(), param.Filter(),
// param.Output(), param.NewScale(),
// param.NewBias(), 1);
math::DepthwiseConvAddBNRelu3x3s2p1v2(param.Input(), param.Filter(),
param.Output(), param.NewScale(),
param.NewBias(), true);
} else {
ConvBNReluBasic(param);
}
}
} // namespace operators
} // namespace paddle_mobile
#endif
......@@ -76,7 +76,7 @@ void PoolCompute(const PoolParam &param) {
}
} else if (ksize[0] == 2 && ksize[0] == ksize[1]) {
#ifndef IOS
#if __ARM_NEON
if (pooling_type == "max") {
math::Pool2x2Max(strides, paddings, in_x, out);
} else if (pooling_type == "avg") {
......@@ -84,7 +84,8 @@ void PoolCompute(const PoolParam &param) {
}
#else
PoolBasic(pooling_type, ksize, strides, paddings, in_x, out);
#endif
#endif // __ARM_NEON
} else {
PoolBasic(pooling_type, ksize, strides, paddings, in_x, out);
}
......
......@@ -68,6 +68,7 @@ void sigmoid(const Tensor *X, Tensor *Y) {
input_outer_ptr++;
}
}
#else
#endif
}
......
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#ifdef FUSION_CONVBNRELU_OP
#include <vector>
#include "framework/ddim.h"
#include "framework/operator.h"
#include "operators/math/conv_func.h"
#include "operators/math/im2col.h"
#include "operators/math/math_function.h"
#include "operators/math/vol2col.h"
#include "operators/op_param.h"
namespace paddle_mobile {
namespace operators {
using framework::DDim;
using framework::OpKernelBase;
template <typename DeviceType, typename T>
class ConvBNReluKernel
: public OpKernelBase<DeviceType, FusionConvBNReluParam> {
public:
void Compute(const FusionConvBNReluParam &param) const;
bool Init(FusionConvBNReluParam *param);
};
} // namespace operators
} // namespace paddle_mobile
#endif
......@@ -15,7 +15,7 @@ limitations under the License. */
#include "operators/math/im2col.h"
#include <vector>
#ifdef __ARM_NEON
#include "arm_neon.h"
#include <arm_neon.h>
#endif
#include "common/types.h"
namespace paddle_mobile {
......@@ -69,7 +69,7 @@ class Im2ColFunctor<ColFormat::kCFO, CPU, T> {
int channels_col = im_channels * filter_height * filter_width;
const T *im_data = im.data<T>();
T *col_data = col->data<T>();
#ifdef __ARM_NEON
#if __ARM_NEON
const int osize = col_height;
const int isize = im_height;
bool pad1 = padding[0] > 0;
......
......@@ -13,7 +13,9 @@ See the License for the specific language governing permissions and
limitations under the License. */
#ifdef POOL_OP
#include "pool_2x2.h"
#include "operators/math/pool_2x2.h"
#include <algorithm>
#include <vector>
namespace paddle_mobile {
namespace operators {
......@@ -21,10 +23,10 @@ namespace math {
void Pool2x2Max(vector<int> strides, vector<int> paddings, const Tensor *input,
Tensor *output) {
#ifdef __ARM_NEON
#ifdef ARMV7
#if __ARM_NEON
#if __aarch64__
#else
const int batch_size = input->dims()[0];
const int input_height = input->dims()[2];
......@@ -93,15 +95,16 @@ void Pool2x2Max(vector<int> strides, vector<int> paddings, const Tensor *input,
output_data += output_batch_stride;
}
#endif
#else
#endif
}
void Pool2x2Avg(vector<int> strides, vector<int> paddings, const Tensor *input,
Tensor *output) {
#ifdef __ARM_NEON
#if __ARM_NEON
#ifdef ARMV7
#if __aarch64__
#else
const int batch_size = input->dims()[0];
const int input_height = input->dims()[2];
......@@ -171,12 +174,9 @@ void Pool2x2Avg(vector<int> strides, vector<int> paddings, const Tensor *input,
input_data += input_batch_stride;
output_data += output_batch_stride;
}
#else
// TODO(): to imp other asm
#endif
#else
#endif
}
......
......@@ -17,7 +17,7 @@ limitations under the License. */
#include <omp.h>
#endif
#include "framework/tensor.h"
#include "pool_3x3.h"
#include "operators/math/pool_3x3.h"
#if __ARM_NEON
#include <arm_neon.h>
#endif // __ARM_NEON
......@@ -518,6 +518,8 @@ void Pool3x3Maxs1p1(const Tensor *input, Tensor *output) {
input_data += input_batch_stride;
out_data += output_batch_stride;
}
#else
#endif
}
......@@ -582,7 +584,18 @@ void Pool3x3Max(vector<int> strides, vector<int> paddings, const Tensor *input,
}
output_seg[ph * output_width + pw] = max_value;
} else {
#if defined(ARMV7)
#if __aarch64__
const float32x4_t data1 = vld1q_f32(pos1);
const float32x4_t data2 = vld1q_f32(pos1 + input_width);
const float32x4_t data3 = vld1q_f32(pos1 + 2 * input_width);
const float32x4_t max_data =
vmaxq_f32(vmaxq_f32(data1, data2), data3);
float32x2_t res =
vpmax_f32(vget_high_f32(vsetq_lane_f32(-INT_MAX, max_data, 3)),
vget_low_f32(max_data));
res = vpmax_f32(res, res);
output_seg[ph * output_width + pw] = vget_lane_f32(res, 0);
#else
asm volatile(
"vld1.32 {q1}, [%[pos1]] \n\t"
"vld1.32 {q2}, [%[pos2]] \n\t"
......@@ -598,17 +611,6 @@ void Pool3x3Max(vector<int> strides, vector<int> paddings, const Tensor *input,
[pos2] "r"(pos2), [pos3] "r"(pos3),
[output_ptr] "r"(output_ptr), [negative_max] "r"(negative_max)
: "memory", "q1", "q2", "q3", "q4");
#else
const float32x4_t data1 = vld1q_f32(pos1);
const float32x4_t data2 = vld1q_f32(pos1 + input_width);
const float32x4_t data3 = vld1q_f32(pos1 + 2 * input_width);
const float32x4_t max_data =
vmaxq_f32(vmaxq_f32(data1, data2), data3);
float32x2_t res =
vpmax_f32(vget_high_f32(vsetq_lane_f32(-INT_MAX, max_data, 3)),
vget_low_f32(max_data));
res = vpmax_f32(res, res);
output_seg[ph * output_width + pw] = vget_lane_f32(res, 0);
#endif
}
}
......@@ -676,8 +678,8 @@ void Pool3x3Avg(vector<int> strides, vector<int> paddings, const Tensor *input,
}
output_seg[ph * output_width + pw] = sum / 9.0;
} else {
#if defined(ARMV7)
#if __aarch64__
#else
asm volatile(
"vld1.32 {q1}, [%[pos1]] \n\t"
"vld1.32 {q2}, [%[pos2]] \n\t"
......@@ -696,7 +698,7 @@ void Pool3x3Avg(vector<int> strides, vector<int> paddings, const Tensor *input,
[output_ptr] "r"(output_ptr), [zero] "r"(zero),
[nine_ptr] "r"(nine_ptr)
: "memory", "r6", "q1", "q2", "q3", "q4");
#else
#endif
const float32x4_t data1 = vld1q_f32(pos1);
const float32x4_t data2 = vld1q_f32(pos2);
const float32x4_t data3 = vld1q_f32(pos3);
......@@ -707,7 +709,6 @@ void Pool3x3Avg(vector<int> strides, vector<int> paddings, const Tensor *input,
vget_low_f32(sum_data));
res = vpadd_f32(res, res);
output_seg[ph * output_width + pw] = vget_lane_f32(res, 0) / 9.0;
#endif
}
}
}
......@@ -715,6 +716,7 @@ void Pool3x3Avg(vector<int> strides, vector<int> paddings, const Tensor *input,
input_data += input_batch_stride;
output_data += output_batch_stride;
}
#else
#endif
}
} // namespace math
......
......@@ -135,6 +135,7 @@ class SoftmaxFuntor<CPU, T> {
}
}
}
#else
#endif // ARM_NEON
public:
......
......@@ -1078,7 +1078,7 @@ class FusionDWConvBNReluParam : public OpParam {
input_variance_ = InputVarianceFrom<LoDTensor>(inputs, scope);
epsilon_ = GetAttr<float>("epsilon", attrs);
momentum_ = GetAttr<float>("momentum", attrs);
is_test_ = GetAttr<bool>("is_test", attrs);
// is_test_ = GetAttr<bool>("is_test", attrs);
}
const Tensor *Input() const { return input_; }
......@@ -1139,6 +1139,85 @@ class FusionDWConvBNReluParam : public OpParam {
Print &operator<<(Print &printer, const FusionConvAddParam &conv_param);
#endif
#ifdef FUSION_CONVBNRELU_OP
class FusionConvBNReluParam : public OpParam {
public:
FusionConvBNReluParam(const VariableNameMap &inputs,
const VariableNameMap &outputs,
const AttributeMap &attrs, const Scope &scope) {
filter_ = FilterFrom<LoDTensor>(inputs, scope);
input_ = InputFrom<LoDTensor>(inputs, scope);
output_ = OutFrom<LoDTensor>(outputs, scope);
strides_ = GetAttr<vector<int>>("strides", attrs);
paddings_ = GetAttr<vector<int>>("paddings", attrs);
dilations_ = GetAttr<vector<int>>("dilations", attrs);
groups = GetAttr<int>("groups", attrs);
input_bias_ = InputBiasFrom<LoDTensor>(inputs, scope);
input_mean_ = InputMeanFrom<LoDTensor>(inputs, scope);
input_scale_ = InputScaleFrom<LoDTensor>(inputs, scope);
input_variance_ = InputVarianceFrom<LoDTensor>(inputs, scope);
epsilon_ = GetAttr<float>("epsilon", attrs);
momentum_ = GetAttr<float>("momentum", attrs);
// is_test_ = GetAttr<bool>("is_test", attrs);
}
const Tensor *Input() const { return input_; }
const Tensor *Filter() const { return filter_; }
Tensor *Output() const { return output_; }
const vector<int> &Strides() const { return strides_; }
const vector<int> &Paddings() const { return paddings_; }
const vector<int> &Dilations() const { return dilations_; }
const int &Groups() const { return groups; }
const Tensor *InputBias() const { return input_bias_; }
const Tensor *InputMean() const { return input_mean_; }
const Tensor *InputScale() const { return input_scale_; }
const Tensor *InputVariance() const { return input_variance_; }
const float &Epsilon() const { return epsilon_; }
const float &Momentum() const { return momentum_; }
const bool &IsTest() const { return is_test_; }
void SetNewScale(Tensor *new_scale) { new_scale_ = new_scale; }
void SetNewBias(Tensor *new_bias) { new_bias_ = new_bias; }
const Tensor *NewScale() const { return new_scale_; }
const Tensor *NewBias() const { return new_bias_; }
protected:
Tensor *input_;
Tensor *output_;
Tensor *filter_;
vector<int> strides_;
vector<int> paddings_;
vector<int> dilations_;
int groups;
Tensor *input_bias_;
Tensor *input_mean_;
Tensor *input_scale_;
Tensor *input_variance_;
float epsilon_;
float momentum_;
bool is_test_;
Tensor *new_bias_;
Tensor *new_scale_;
};
#endif
#ifdef IM2SEQUENCE_OP
class Im2SequenceParam : public OpParam {
public:
......
......@@ -19,7 +19,9 @@ int main() {
paddle_mobile::Loader<paddle_mobile::CPU> loader;
// ../../../test/models/googlenet
// ../../../test/models/mobilenet
auto program = loader.Load(g_googlenet, true);
// auto program = loader.Load(g_googlenet, true);
auto program = loader.Load(g_mobilenet_ssd, true);
// auto program = loader.Load(g_googlenet_combine + "/model",
// g_googlenet_combine +
// "/params", true);
......
......@@ -23,7 +23,7 @@ int main() {
auto time1 = time();
if (paddle_mobile.Load(g_googlenet, optimize)) {
auto time2 = time();
DLOG << "load cost :" << time_diff(time1, time1) << "ms";
DLOG << "load cost: " << time_diff(time1, time1) << "ms";
std::vector<float> input;
std::vector<int64_t> dims{1, 3, 224, 224};
GetInput<float>(g_test_image_1x3x224x224, &input, dims);
......
......@@ -12,16 +12,20 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <fstream>
#include <iostream>
#include "../test_helper.h"
#include "../test_include.h"
int main() {
paddle_mobile::PaddleMobile<paddle_mobile::CPU> paddle_mobile;
auto time1 = time();
if (paddle_mobile.Load(g_mobilenet_ssd, true)) {
// auto isok = paddle_mobile.Load(g_mobilenet_ssd_gesture + "/model",
// g_mobilenet_ssd_gesture + "/params",
// true);
auto isok = paddle_mobile.Load(g_mobilenet_ssd, false);
if (isok) {
auto time2 = time();
DLOG << "load cost :" << time_diff(time1, time1) << "ms";
std::cout << "load cost :" << time_diff(time1, time2) << "ms" << std::endl;
std::vector<int64_t> dims{1, 3, 300, 300};
Tensor input_tensor;
......@@ -33,7 +37,8 @@ int main() {
auto time3 = time();
paddle_mobile.Predict(input, dims);
auto time4 = time();
DLOG << "predict cost :" << time_diff(time3, time4) << "ms";
std::cout << "predict cost :" << time_diff(time3, time4) << "ms"
<< std::endl;
}
return 0;
}
......@@ -16,6 +16,8 @@ limitations under the License. */
#include <fstream>
#include <random>
#include <string>
#include <vector>
#include "common/common.h"
#include "common/log.h"
......@@ -23,6 +25,8 @@ limitations under the License. */
#include "framework/tensor.h"
static const std::string g_mobilenet_ssd = "../models/mobilenet+ssd";
static const std::string g_mobilenet_ssd_gesture =
"../models/mobilenet+ssd_gesture";
static const std::string g_squeezenet = "../models/squeezenet";
static const std::string g_googlenet = "../models/googlenet";
static const std::string g_mobilenet = "../models/mobilenet";
......@@ -62,9 +66,9 @@ void GetInput(const std::string &input_name, std::vector<T> *input,
size *= dim;
}
T *input_ptr = (T *)malloc(sizeof(T) * size);
T *input_ptr = reinterpret_cast<T *>(malloc(sizeof(T) * size));
std::ifstream in(input_name, std::ios::in | std::ios::binary);
in.read((char *)(input_ptr), size * sizeof(T));
in.read(reinterpret_cast<char *>(input_ptr), size * sizeof(T));
in.close();
for (int i = 0; i < size; ++i) {
input->push_back(input_ptr[i]);
......@@ -79,6 +83,6 @@ void GetInput(const std::string &input_name,
T *input_ptr = input->mutable_data<T>(dims);
std::ifstream in(input_name, std::ios::in | std::ios::binary);
in.read((char *)(input_ptr), input->numel() * sizeof(T));
in.read(reinterpret_cast<char *>(input_ptr), input->numel() * sizeof(T));
in.close();
}
......@@ -65,6 +65,7 @@ else ()
set(FUSION_CONVADD_RELU_OP ON)
set(FUSION_CONVADDBNRELU_OP ON)
set(FUSION_DWCONVBNRELU_OP ON)
set(FUSION_CONVBNRELU_OP ON)
set(PRELU_OP ON)
set(RESIZE_OP ON)
set(SCALE_OP ON)
......@@ -159,6 +160,11 @@ endif()
if (FUSION_DWCONVBNRELU_OP)
add_definitions(-DFUSION_DWCONVBNRELU_OP)
endif()
if (FUSION_CONVBNRELU_OP)
add_definitions(-DFUSION_CONVBNRELU_OP)
endif()
if (PRELU_OP)
add_definitions(-DPRELU_OP)
endif()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册