提交 d541ab39 编写于 作者: Y yangfei

imp fusion_bn_add_relu op in resnet

上级 db214eb0
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifdef FUSION_CONVBNADDRELU_OP
#include "operators/fusion_conv_bn_add_relu_op.h"
#include "operators/math/conv_func.h"
namespace paddle_mobile {
namespace operators {
template <typename Dtype, typename T>
void FusionConvBNAddReluOp<Dtype, T>::InferShape() const {
auto in_dims = this->param_.Input()->dims();
auto filter_dims = this->param_.Filter()->dims();
const std::vector<int> &strides = this->param_.Strides();
std::vector<int> paddings = this->param_.Paddings();
int groups = this->param_.Groups();
std::vector<int> dilations = this->param_.Dilations();
PADDLE_MOBILE_ENFORCE((in_dims.size() == filter_dims.size() &&
dilations.size() == paddings.size() &&
paddings.size() == strides.size()),
"ConvParam is not suitable");
std::vector<int64_t> output_shape({in_dims[0], filter_dims[0]});
for (size_t i = 0; i < strides.size(); ++i) {
output_shape.push_back(
math::ConvOutputSize(in_dims[i + 2], filter_dims[i + 2], dilations[i],
paddings[i], strides[i]));
}
framework::DDim ddim = framework::make_ddim(output_shape);
this->param_.Output()->Resize(ddim);
}
} // namespace operators
} // namespace paddle_mobile
namespace ops = paddle_mobile::operators;
#ifdef PADDLE_MOBILE_CPU
REGISTER_OPERATOR_CPU(fusion_conv_bn_add_relu, ops::FusionConvBNAddReluOp);
#endif
#ifdef PADDLE_MOBILE_MALI_GPU
#endif
#ifdef PADDLE_MOBILE_FPGA
REGISTER_OPERATOR_FPGA(fusion_conv_bn_add_relu, ops::FusionConvBNAddReluOp);
#endif
#endif
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifdef FUSION_CONVBNADDRELU_OP
#pragma once
#include <string>
#include <vector>
#include "framework/operator.h"
#include "framework/program/program-optimize/fusion_op_register.h"
#include "op_param.h"
#include "operators/kernel/conv_bn_add_relu_kernel.h"
namespace paddle_mobile {
namespace operators {
using std::string;
using std::vector;
class FusionConvBNAddReluMatcher : public framework::FusionOpMatcher {
public:
FusionConvBNAddReluMatcher() {
node_ = framework::Node(G_OP_TYPE_CONV);
node_ > std::make_shared<framework::Node>(G_OP_TYPE_BATCHNORM) >
std::make_shared<framework::Node>(G_OP_TYPE_ELEMENTWISE_ADD) >
std::make_shared<framework::Node>(G_OP_TYPE_RELU);
}
void FolderNodes(
framework::Node *node,
std::vector<std::shared_ptr<framework::Node>> *removed_nodes) {
node->Folder(node_.Depth(), Type(),
{{G_OP_TYPE_ELEMENTWISE_ADD, {{"Y", "Y"},{"X","X"}}},
{G_OP_TYPE_BATCHNORM,
{{"Scale", "Scale"},
{"Mean", "Mean"},
{"Bias", "Bias"},
{"Variance", "Variance"},
{"Y","BNY"}}}},
removed_nodes);
}
std::string Type() { return G_OP_TYPE_FUSION_CONV_BN_ADD_RELU; }
std::vector<std::pair<int, std::string>> NeedCheck() {
DLOG << " conv bn add relu check add X ";
return {{2, "Y"}, {2, "X"}};
}
};
template <typename DeviceType, typename T>
class FusionConvBNAddReluOp
: public framework::OperatorWithKernel<
DeviceType, FusionConvBNAddReluParam<DeviceType>,
operators::ConvBNAddReluKernel<DeviceType, T>> {
public:
FusionConvBNAddReluOp(const string &type, const VariableNameMap &inputs,
const VariableNameMap &outputs,
const framework::AttributeMap &attrs,
std::shared_ptr<framework::Scope> scope)
: framework::OperatorWithKernel<
DeviceType, FusionConvBNAddReluParam<DeviceType>,
operators::ConvBNAddReluKernel<DeviceType, T>>(
type, inputs, outputs, attrs, scope) {}
using framework::OperatorWithKernel<
DeviceType, FusionConvBNAddReluParam<DeviceType>,
operators::ConvBNAddReluKernel<DeviceType, T>>::OperatorWithKernel;
void InferShape() const override;
protected:
};
#ifdef PADDLE_MOBILE_CPU
#ifndef FUSION_CONV_BN_ADD_RELU_REGISTER
static framework::FusionOpRegistrar fusion_conv_bn_add_relu_registrar(
new FusionConvBNAddReluMatcher());
#define FUSION_CONV_BN_ADD_RELU_REGISTER
#endif
#endif
#ifdef PADDLE_MOBILE_MALI_GPU
#ifndef FUSION_CONV_BN_ADD_RELU_REGISTER
static framework::FusionOpRegistrar fusion_conv_bn_add_relu_registrar(
new FusionConvBNAddReluMatcher());
#define FUSION_CONV_BN_ADD_RELU_REGISTER
#endif
#endif
#ifdef PADDLE_MOBILE_FPGA
#ifndef FUSION_CONV_BN_ADD_RELU_REGISTER
static framework::FusionOpRegistrar fusion_conv_bn_add_relu_registrar(
new FusionConvBNAddReluMatcher());
#define FUSION_CONV_BN_ADD_RELU_REGISTER
#endif
#endif
} // namespace operators
} // namespace paddle_mobile
#ifdef PADDLE_MOBILE_CPU
USE_OP_CPU(fusion_conv_bn_add_relu);
#endif
#ifdef PADDLE_MOBILE_MALI_GPU
#endif
#ifdef PADDLE_MOBILE_FPGA
USE_OP_FPGA(fusion_conv_bn_add_relu);
#endif
#endif
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifdef FUSION_CONVBNADDRELU_OP
#include "operators/kernel/conv_bn_add_relu_kernel.h"
#include "operators/kernel/central-arm-func/conv_bn_add_relu_arm_func.h"
namespace paddle_mobile {
namespace operators {
template <>
bool ConvBNAddReluKernel<CPU, float>::Init(
FusionConvBNAddReluParam<CPU> *param) {
const Tensor *mean = param->InputMean();
const Tensor *variance = param->InputVariance();
const Tensor *scale = param->InputScale();
const Tensor *bias = param->InputBias();
const float epsilon = param->Epsilon();
auto mean_ptr = mean->data<float>();
auto variance_ptr = variance->data<float>();
auto scale_ptr = scale->data<float>();
auto bias_ptr = bias->data<float>();
const int C = mean->numel();
float inv_std_ptr[C];
for (int i = 0; i < C; i++) {
inv_std_ptr[i] =
1 / static_cast<float>(pow((variance_ptr[i] + epsilon), 0.5));
}
Tensor *new_scale = new Tensor();
Tensor *new_bias = new Tensor();
auto new_scale_ptr = new_scale->mutable_data<float>({C});
auto new_bias_ptr = new_bias->mutable_data<float>({C});
for (int i = 0; i < C; i++) {
new_scale_ptr[i] = inv_std_ptr[i] * scale_ptr[i];
new_bias_ptr[i] = bias_ptr[i] - mean_ptr[i] * inv_std_ptr[i] * scale_ptr[i];
}
param->SetNewScale(new_scale);
param->SetNewBias(new_bias);
return true;
}
template <>
void ConvBNAddReluKernel<CPU, float>::Compute(
const FusionConvBNAddReluParam<CPU> &param) const {
ConvBNAddReluCompute<float>(param);
}
template class ConvBNAddReluKernel<CPU, float>;
} // namespace operators
} // namespace paddle_mobile
#endif
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifdef FUSION_CONVBNADDRELU_OP
#pragma once
#include <vector>
#include "operators/math/depthwise_conv_3x3.h"
#include "operators/math/im2col.h"
#include "operators/math/math_function.h"
#include "operators/math/vol2col.h"
#include "operators/op_param.h"
namespace paddle_mobile {
namespace operators {
void ConvBNAddReluBasic(const FusionConvBNAddReluParam<CPU> &param) {
const Tensor *input = param.Input();
Tensor filter = *param.Filter();
Tensor new_bias = *param.NewBias();
Tensor new_scale = *param.NewScale();
Tensor *output = param.Output();
Tensor *bias1 = param.Bias();
int groups = param.Groups();
DLOG<<"yangfei2";
DLOG<<bias1->dims();
std::vector<int> strides = param.Strides();
std::vector<int> paddings = param.Paddings();
std::vector<int> dilations = param.Dilations();
const int batch_size = static_cast<int>(input->dims()[0]);
std::vector<int64_t> filter_shape_vec(framework::vectorize(filter.dims()));
std::vector<int64_t> output_shape_vec(framework::vectorize(output->dims()));
size_t data_dim = filter_shape_vec.size() - 2;
std::vector<int64_t> col_shape_vec(1 + 2 * data_dim);
col_shape_vec[0] = input->dims()[1] / groups;
for (size_t j = 0; j < data_dim; ++j) {
col_shape_vec[j + 1] = filter_shape_vec[j + 2];
col_shape_vec[j + 1 + data_dim] = output_shape_vec[j + 2];
}
framework::DDim col_shape(framework::make_ddim(col_shape_vec));
framework::DDim col_matrix_shape =
framework::flatten_to_2d(col_shape, data_dim + 1);
bool is_expand =
math::IsExpand(filter_shape_vec, strides, paddings, dilations);
Tensor col;
Tensor col_matrix;
if (is_expand) {
col.mutable_data<float>(col_shape);
col_matrix.ShareDataWith(col);
col_matrix.Resize(col_matrix_shape);
}
framework::DDim input_shape = framework::slice_ddim(
input->dims(), 1, static_cast<int>(input->dims().size()));
framework::DDim filter_matrix_shape = {filter.dims()[0],
filter.numel() / filter.dims()[0]};
filter.Resize(filter_matrix_shape);
framework::DDim output_matrix_shape = {
output->dims()[1],
output->numel() / (output->dims()[0] * output->dims()[1])};
// convolution operator: im2col(or vol2col) + gemm
int in_step = static_cast<int>(input->dims()[1]) / groups;
int out_step = static_cast<int>(output->dims()[1]) / groups;
math::Vol2ColFunctor<CPU, float> vol2col;
math::Im2ColFunctor<math::ColFormat::kCFO, CPU, float> im2col;
for (int i = 0; i < batch_size; i++) {
Tensor in_batch = input->Slice(i, i + 1).Resize(input_shape);
Tensor out_batch = output->Slice(i, i + 1).Resize(output_matrix_shape);
Tensor bias_batch = bias1->Slice(i, i + 1).Resize(output_matrix_shape);
for (int g = 0; g < groups; g++) {
Tensor in_slice = in_batch.Slice(g * in_step, (g + 1) * in_step);
if (!is_expand) {
col.ShareDataWith(in_slice);
col_matrix.ShareDataWith(col);
col_matrix.Resize(col_matrix_shape);
} else if (data_dim == 2U) {
// im2col
im2col(in_slice, dilations, strides,
std::vector<int>{paddings[0], paddings[1], paddings[0],
paddings[1]},
&col);
} else if (data_dim == 3U) {
// vol2col
vol2col(in_slice, dilations, strides, paddings, &col);
}
// gemm
Tensor out_slice = out_batch.Slice(g * out_step, (g + 1) * out_step);
Tensor filter_slice = filter.Slice(g * out_step, (g + 1) * out_step);
Tensor bias_data = bias_batch.Slice(g * out_step, (g + 1) * out_step);
math::matmulWithBnAdd<float>(
filter_slice, false, col_matrix, false, static_cast<float>(1),
&out_slice, static_cast<float>(1), true, &new_scale, &new_bias, g,bias_data.data<float>());
}
}
}
template <typename P>
void ConvBNAddReluCompute(const FusionConvBNAddReluParam<CPU> &param) {
Tensor Bias;
Bias.mutable_data<float>({param.Groups()});
if (param.Groups() == param.Input()->dims()[1] &&
param.Input()->dims()[1] == param.Output()->dims()[1] &&
param.Filter()->dims()[2] == param.Filter()->dims()[3] &&
param.Filter()->dims()[2] == 3 && param.Strides()[0] == 1) {
math::DepthwiseConvAddBNRelu3x3s1p1(param.Input(), param.Filter(),
param.Output(), param.NewScale(),
param.NewBias(), true);
} else if (param.Groups() == param.Input()->dims()[1] &&
param.Input()->dims()[1] == param.Output()->dims()[1] &&
param.Filter()->dims()[2] == param.Filter()->dims()[3] &&
param.Filter()->dims()[2] == 3 && param.Strides()[0] == 2) {
// math::DepthwiseConvAddBNRelu3x3s2p1(param.Input(), param.Filter(),
// param.Output(), param.NewScale(),
// param.NewBias(), 1);
math::DepthwiseConvAddBNRelu3x3s2p1v2(param.Input(), param.Filter(),
param.Output(), param.NewScale(),
param.NewBias(), true);
} else {
ConvBNAddReluBasic(param);
}
}
} // namespace operators
} // namespace paddle_mobile
#endif
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#ifdef FUSION_CONVBNADDRELU_OP
#include <vector>
#include "framework/ddim.h"
#include "framework/operator.h"
#include "operators/math/conv_func.h"
#include "operators/math/im2col.h"
#include "operators/math/math_function.h"
#include "operators/math/vol2col.h"
#include "operators/op_param.h"
namespace paddle_mobile {
namespace operators {
using framework::DDim;
using framework::OpKernelBase;
template <typename DeviceType, typename T>
class ConvBNAddReluKernel
: public OpKernelBase<DeviceType, FusionConvBNAddReluParam<DeviceType>> {
public:
void Compute(const FusionConvBNAddReluParam<DeviceType> &param) const;
bool Init(FusionConvBNAddReluParam<DeviceType> *param);
};
} // namespace operators
} // namespace paddle_mobile
#endif
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册