Commit ab1f29d6 authored by eclipsess

update convaddbnrelu

Parent d8a336d2
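
Summary (inferred from the diff below): the generic im2col + GEMM implementation of the fused conv + add + batchnorm + relu kernel is renamed from the templated ConvAddBNReluCompute to a plain ConvAddBNReluBasic, and ConvAddBNReluCompute is reintroduced as a thin dispatcher. When the convolution is depthwise (Groups() equal to both the input and output channel counts) with a square 3x3 filter and stride 1, the dispatcher calls the NEON kernel math::DepthwiseConvAddBNRelu3x3s1p1; a matching stride-2 branch calling math::DepthwiseConv3x3 is present but disabled by a hard-coded "0 &&" prefix in its condition; everything else falls through to ConvAddBNReluBasic. The depthwise kernel's signature also changes from a by-value Tensor filter with mutable Tensor *new_scale / *new_bias to const Tensor * parameters, with the call site, the definition, and the header declaration updated together.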
@@ -17,11 +17,10 @@ limitations under the License. */
 #pragma once
 #include "operators/math/depthwise_conv_3x3.h"
 #include "operators/op_param.h"
 namespace paddle_mobile {
 namespace operators {
-template <typename P>
-void ConvAddBNReluCompute(const FusionConvAddBNReluParam &param) {
+void ConvAddBNReluBasic(const FusionConvAddBNReluParam &param) {
   const Tensor *input = param.Input();
   Tensor filter = *param.Filter();
   Tensor bias = *param.Bias();
@@ -30,21 +29,17 @@ void ConvAddBNReluCompute(const FusionConvAddBNReluParam &param) {
   auto new_bias_ptr = new_bias.data<float>();
   auto new_scale_ptr = new_scale.data<float>();
   int axis = param.Axis();
-  Tensor *output = param.Output();
-  math::expand_bias(bias, axis, output->dims());
-  output->ShareDataWith(bias);
   int groups = param.Groups();
   std::vector<int> strides = param.Strides();
   std::vector<int> paddings = param.Paddings();
   std::vector<int> dilations = param.Dilations();
-  std::vector<int64_t> filter_shape_vec(framework::vectorize(filter.dims()));
-  if (filter_shape_vec[2] == 3 && strides[0] == 1 && groups > 1) {
-    math::DepthwiseConvAddBNRelu3x3s1p1(input, filter, output, &bias, 1,
-                                        &new_scale, &new_bias, 1, 1);
-  } else {
-    const int batch_size = static_cast<int>(input->dims()[0]);
+  Tensor *output = param.Output();
+  const int batch_size = static_cast<int>(input->dims()[0]);
+  math::expand_bias(bias, axis, output->dims());
+  output->ShareDataWith(bias);
+  std::vector<int64_t> filter_shape_vec(framework::vectorize(filter.dims()));
   std::vector<int64_t> output_shape_vec(framework::vectorize(output->dims()));
   size_t data_dim = filter_shape_vec.size() - 2;
@@ -107,16 +102,15 @@ void ConvAddBNReluCompute(const FusionConvAddBNReluParam &param) {
         // vol2col
         vol2col(in_slice, dilations, strides, paddings, &col);
       }
       // gemm
       Tensor out_slice = out_batch.Slice(g * out_step, (g + 1) * out_step);
       Tensor filter_slice = filter.Slice(g * out_step, (g + 1) * out_step);
       math::matmul<float>(filter_slice, false, col_matrix, false,
                           static_cast<float>(1), &out_slice,
-                          static_cast<float>(1), false);
+                          static_cast<float>(1));
     }
   }
-  /// todo : use neon in special case instead of 2for(300ms)
   auto output_ptr = output->data<float>();
   for (int c = 0; c < output_matrix_shape[0]; c++) {
     int start = c * output_matrix_shape[1];
@@ -127,8 +121,29 @@ void ConvAddBNReluCompute(const FusionConvAddBNReluParam &param) {
           output_ptr[start + j] < 0 ? 0 : output_ptr[start + j];
     }
   }
-  }
 }
+
+template <typename P>
+void ConvAddBNReluCompute(const FusionConvAddBNReluParam &param) {
+  Tensor Bias;
+  Bias.mutable_data<float>({param.Groups()});
+  if (param.Groups() == param.Input()->dims()[1] &&
+      param.Input()->dims()[1] == param.Output()->dims()[1] &&
+      param.Filter()->dims()[2] == param.Filter()->dims()[3] &&
+      param.Filter()->dims()[2] == 3 && param.Strides()[0] == 1) {
+    math::DepthwiseConvAddBNRelu3x3s1p1(
+        param.Input(), param.Filter(), param.Output(), &Bias, 1,
+        param.NewScale(), param.NewBias(), 1, 1);
+  } else if (0 && param.Groups() == param.Input()->dims()[1] &&
+             param.Input()->dims()[1] == param.Output()->dims()[1] &&
+             param.Filter()->dims()[2] == param.Filter()->dims()[3] &&
+             param.Filter()->dims()[2] == 3 && param.Strides()[0] == 2) {
+    math::DepthwiseConv3x3(param.Input(), param.Strides(), param.Paddings(),
+                           param.Filter(), &Bias, param.Output(), false);
+  } else {
+    ConvAddBNReluBasic(param);
+  }
+}
 }  // namespace operators
 }  // namespace paddle_mobile
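
For readers skimming the new dispatcher: the depthwise eligibility check and the per-channel epilogue that ConvAddBNReluBasic runs after the GEMM can be summarized by the sketch below. This is a minimal, self-contained C++ sketch, not code from the repository; ConvDesc, is_depthwise_3x3s1 and bn_relu_epilogue are illustrative names, and the scale/shift step is an assumption based on the declared new_scale_ptr / new_bias_ptr (the hunk above only shows the ReLU clamp; the multiply-add line is elided).

#include <algorithm>
#include <cstddef>

// Stand-ins for the tensor metadata the real dispatcher reads from param.
struct ConvDesc {
  int groups;        // param.Groups()
  int in_channels;   // param.Input()->dims()[1]
  int out_channels;  // param.Output()->dims()[1]
  int filter_h;      // param.Filter()->dims()[2]
  int filter_w;      // param.Filter()->dims()[3]
  int stride;        // param.Strides()[0]
};

// Condition under which the new ConvAddBNReluCompute picks the NEON
// depthwise 3x3 stride-1 kernel instead of the im2col + GEMM path.
bool is_depthwise_3x3s1(const ConvDesc &d) {
  return d.groups == d.in_channels && d.in_channels == d.out_channels &&
         d.filter_h == d.filter_w && d.filter_h == 3 && d.stride == 1;
}

// Per-channel batch-norm scale/shift followed by ReLU, i.e. the two-loop
// epilogue applied to the GEMM output in ConvAddBNReluBasic:
//   out = max(0, out * new_scale[c] + new_bias[c])
void bn_relu_epilogue(float *out, const float *new_scale,
                      const float *new_bias, int channels, int spatial) {
  for (int c = 0; c < channels; ++c) {
    float *row = out + static_cast<std::ptrdiff_t>(c) * spatial;
    for (int j = 0; j < spatial; ++j) {
      row[j] = std::max(0.0f, row[j] * new_scale[c] + new_bias[c]);
    }
  }
}

Note that the stride-2 branch in the dispatcher starts with 0 &&, which short-circuits the whole condition to false; the DepthwiseConv3x3 path is compiled in but never taken, presumably until that kernel is validated for this fusion.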
...
@@ -508,12 +508,13 @@ void DepthwiseConv3x3s1p1(const Tensor *input, const Tensor *filter,
   }
 }
-void DepthwiseConvAddBNRelu3x3s1p1(const Tensor *input, Tensor filter,
+void DepthwiseConvAddBNRelu3x3s1p1(const Tensor *input, const Tensor *filter,
                                    Tensor *output, Tensor *bias, bool if_bias,
-                                   Tensor *new_scale, Tensor *new_bias,
-                                   bool if_bn, bool if_relu) {
+                                   const Tensor *new_scale,
+                                   const Tensor *new_bias, bool if_bn,
+                                   bool if_relu) {
   const float *input_data = input->data<float>();
-  const float *filter_data = filter.data<float>();
+  const float *filter_data = filter->data<float>();
   float *output_data = output->data<float>();
   const float *bias_data = bias->data<float>();
   const float *newscale_data = new_scale->data<float>();
...
@@ -32,10 +32,11 @@ void DepthwiseConv3x3(const Tensor *input, vector<int> strides,
                       Tensor *output, bool if_bias);
 void DepthwiseConv3x3s1p1(const Tensor *input, const Tensor *filter,
                           Tensor *output, Tensor *bias, bool if_bias);
-void DepthwiseConvAddBNRelu3x3s1p1(const Tensor *input, Tensor filter,
+void DepthwiseConvAddBNRelu3x3s1p1(const Tensor *input, const Tensor *filter,
                                    Tensor *output, Tensor *bias, bool if_bias,
-                                   Tensor *new_scale, Tensor *new_bias,
-                                   bool if_bn, bool if_relu);
+                                   const Tensor *new_scale,
+                                   const Tensor *new_bias, bool if_bn,
+                                   bool if_relu);
 }  // namespace math
 }  // namespace operators
 }  // namespace paddle_mobile
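
Call sites migrate together with the signature. Both calls below appear verbatim in the first file of this diff; the old one passed local Tensor copies (filter by value, new_scale / new_bias by mutable pointer), while the new one passes the pointers returned by the param accessors directly:

// Old call, from the removed depthwise branch of ConvAddBNReluCompute:
math::DepthwiseConvAddBNRelu3x3s1p1(input, filter, output, &bias, 1,
                                    &new_scale, &new_bias, 1, 1);

// New call, from the new dispatcher:
math::DepthwiseConvAddBNRelu3x3s1p1(
    param.Input(), param.Filter(), param.Output(), &Bias, 1,
    param.NewScale(), param.NewBias(), 1, 1);

Passing const Tensor * avoids copying the filter Tensor object on every call and makes it explicit that the kernel only reads the filter and the folded batch-norm parameters.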