diff --git a/src/common/types.cpp b/src/common/types.cpp
index cea42171f0205e0d40b2703d5c90f0b9fc253e68..9bc594c7533b980626d8d07e89fc3ccf649a127f 100644
--- a/src/common/types.cpp
+++ b/src/common/types.cpp
@@ -24,6 +24,8 @@ const std::string G_OP_TYPE_CONCAT = "concat";
 const std::string G_OP_TYPE_ELEMENTWISE_ADD = "elementwise_add";
 const std::string G_OP_TYPE_FUSION_CONV_ADD_RELU = "fusion_conv_add_relu";
 const std::string G_OP_TYPE_FUSION_CONV_ADD_BN_RELU = "fusion_conv_add_bn_relu";
+const std::string G_OP_TYPE_FUSION_DWCONV_BN_RELU = "fusion_dwconv_bn_relu";
+
 const std::string G_OP_TYPE_FC = "fusion_fc";
 const std::string G_OP_TYPE_FUSION_CONV_ADD = "fusion_conv_add";
 const std::string G_OP_TYPE_LRN = "lrn";
diff --git a/src/common/types.h b/src/common/types.h
index ec428b9911f64d7ccc8c6f5dc4be7f970e855d3c..1daf9c9b7bccfc8bcb584e5a37f920539736a911 100644
--- a/src/common/types.h
+++ b/src/common/types.h
@@ -81,6 +81,7 @@ extern const std::string G_OP_TYPE_FUSION_CONV_ADD_RELU;
 extern const std::string G_OP_TYPE_FC;
 extern const std::string G_OP_TYPE_FUSION_CONV_ADD;
 extern const std::string G_OP_TYPE_FUSION_CONV_ADD_BN_RELU;
+extern const std::string G_OP_TYPE_FUSION_DWCONV_BN_RELU;
 
 extern const std::string G_OP_TYPE_LRN;
 extern const std::string G_OP_TYPE_MUL;
diff --git a/src/operators/fusion_conv_add.h b/src/operators/fusion_conv_add.h
index 170df9ce33e4ab90297664fbc81d723e7c246f83..fd63efa8f6172ad244c8e61619ef286dc3ffa1de 100644
--- a/src/operators/fusion_conv_add.h
+++ b/src/operators/fusion_conv_add.h
@@ -66,11 +66,11 @@ class FusionConvAddOp : public framework::OperatorWithKernel<
 
 #ifdef PADDLE_MOBILE_CPU
 
-#ifndef CONV_ADD_REGISTER
-static framework::FusionOpRegistrar convadd_registrar(
-    new FusionConvAddMatcher());
-#define CONV_ADD_REGISTER
-#endif
+//#ifndef CONV_ADD_REGISTER
+// static framework::FusionOpRegistrar convadd_registrar(
+//     new FusionConvAddMatcher());
+//#define CONV_ADD_REGISTER
+//#endif
 
 #endif
 
diff --git a/src/operators/fusion_dwconv_bn_relu_op.cpp b/src/operators/fusion_dwconv_bn_relu_op.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..ba03a436c37cc8f1dcba94036fd6a3fbbd8fcaf3
--- /dev/null
+++ b/src/operators/fusion_dwconv_bn_relu_op.cpp
@@ -0,0 +1,60 @@
+/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#ifdef FUSION_DWCONVBNRELU_OP
+
+#include "operators/fusion_dwconv_bn_relu_op.h"
+#include "operators/math/conv_func.h"
+
+namespace paddle_mobile {
+namespace operators {
+
+template <typename Dtype, typename T>
+void FusionDWConvBNReluOp<Dtype, T>::InferShape() const {
+  auto in_dims = this->param_.Input()->dims();
+  auto filter_dims = this->param_.Filter()->dims();
+  const std::vector<int> &strides = this->param_.Strides();
+  std::vector<int> paddings = this->param_.Paddings();
+  int groups = this->param_.Groups();
+  std::vector<int> dilations = this->param_.Dilations();
+
+  PADDLE_MOBILE_ENFORCE((in_dims.size() == filter_dims.size() &&
+                         dilations.size() == paddings.size() &&
+                         paddings.size() == strides.size()),
+                        "ConvParam is not suitable");
+
+  std::vector<int64_t> output_shape({in_dims[0], filter_dims[0]});
+  for (size_t i = 0; i < strides.size(); ++i) {
+    output_shape.push_back(
+        math::ConvOutputSize(in_dims[i + 2], filter_dims[i + 2], dilations[i],
+                             paddings[i], strides[i]));
+  }
+
+  framework::DDim ddim = framework::make_ddim(output_shape);
+  this->param_.Output()->Resize(ddim);
+}
+template class FusionDWConvBNReluOp<CPU, float>;
+}  // namespace operators
+}  // namespace paddle_mobile
+
+namespace ops = paddle_mobile::operators;
+#ifdef PADDLE_MOBILE_CPU
+REGISTER_OPERATOR_CPU(fusion_dwconv_bn_relu, ops::FusionDWConvBNReluOp);
+#endif
+#ifdef PADDLE_MOBILE_MALI_GPU
+#endif
+#ifdef PADDLE_MOBILE_FPGA
+#endif
+
+#endif
diff --git a/src/operators/fusion_dwconv_bn_relu_op.h b/src/operators/fusion_dwconv_bn_relu_op.h
new file mode 100644
index 0000000000000000000000000000000000000000..bf95b51da43b8e9c0cec102876d48828b3749575
--- /dev/null
+++ b/src/operators/fusion_dwconv_bn_relu_op.h
@@ -0,0 +1,111 @@
+/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#ifdef FUSION_DWCONVBNRELU_OP
+
+#pragma once
+
+#include <string>
+#include <vector>
+#include "framework/operator.h"
+#include "framework/program/program-optimize/fusion_op_register.h"
+#include "op_param.h"
+#include "operators/kernel/dwconv_bn_relu_kernel.h"
+
+namespace paddle_mobile {
+namespace operators {
+using std::string;
+using std::vector;
+class FusionDWConvBNReluMatcher : public framework::FusionOpMatcher {
+ public:
+  FusionDWConvBNReluMatcher() {
+    node_ = framework::Node(G_OP_TYPE_DEPTHWISE_CONV);
+    node_ > std::make_shared<framework::Node>(G_OP_TYPE_BATCHNORM) >
+        std::make_shared<framework::Node>(G_OP_TYPE_RELU);
+  }
+
+  void FolderNodes(
+      framework::Node *node,
+      std::vector<std::shared_ptr<framework::Node>> *removed_nodes) {
+    vector<std::shared_ptr<framework::OpDesc>> origin_descs =
+        node->OpDescs(node_.Depth());
+    node->Folder(node_.Depth(), Type(),
+                 {{G_OP_TYPE_BATCHNORM,
+                   {{"Scale", "Scale"},
+                    {"Mean", "Mean"},
+                    {"Bias", "Bias"},
+                    {"Variance", "Variance"}}}},
+                 removed_nodes);
+  }
+
+  std::string Type() { return G_OP_TYPE_FUSION_DWCONV_BN_RELU; }
+};
+
+template <typename DeviceType, typename T>
+class FusionDWConvBNReluOp : public framework::OperatorWithKernel<
+                                 DeviceType, FusionDWConvBNReluParam,
+                                 operators::DWConvBNReluKernel<DeviceType, T>> {
+ public:
+  FusionDWConvBNReluOp(const string &type, const VariableNameMap &inputs,
+                       const VariableNameMap &outputs,
+                       const framework::AttributeMap &attrs,
+                       std::shared_ptr<framework::Scope> scope)
+      : framework::OperatorWithKernel<
+            DeviceType, FusionDWConvBNReluParam,
+            operators::DWConvBNReluKernel<DeviceType, T>>(type, inputs, outputs,
+                                                          attrs, scope) {}
+
+  using framework::OperatorWithKernel<
+      DeviceType, FusionDWConvBNReluParam,
+      operators::DWConvBNReluKernel<DeviceType, T>>::OperatorWithKernel;
+  void InferShape() const override;
+
+ protected:
+};
+
+#ifdef PADDLE_MOBILE_CPU
+
+#ifndef FUSION_DWCONV_BN_RELU_REGISTER
+static framework::FusionOpRegistrar fusion_dwconv_bn_relu_registrar(
+    new FusionDWConvBNReluMatcher());
+#define FUSION_DWCONV_BN_RELU_REGISTER
+#endif
+
+#endif
+
+#ifdef PADDLE_MOBILE_MALI_GPU
+
+#ifndef FUSION_DWCONV_BN_RELU_REGISTER
+static framework::FusionOpRegistrar fusion_dwconv_bn_relu_registrar(
+    new FusionDWConvBNReluMatcher());
+#define FUSION_DWCONV_BN_RELU_REGISTER
+#endif
+
+#endif
+
+#ifdef PADDLE_MOBILE_FPGA
+#endif
+
+}  // namespace operators
+}  // namespace paddle_mobile
+
+#ifdef PADDLE_MOBILE_CPU
+USE_OP_CPU(fusion_dwconv_bn_relu);
+#endif
+#ifdef PADDLE_MOBILE_MALI_GPU
+#endif
+#ifdef PADDLE_MOBILE_FPGA
+#endif
+
+#endif
diff --git a/src/operators/kernel/arm/conv_add_bn_relu_kernel.cpp b/src/operators/kernel/arm/conv_add_bn_relu_kernel.cpp
index 1fd1c66d4dc92a9918243b23e400ef5309422050..dbf3745eb15cf56bba32dc8cbae50d242ce2da76 100644
--- a/src/operators/kernel/arm/conv_add_bn_relu_kernel.cpp
+++ b/src/operators/kernel/arm/conv_add_bn_relu_kernel.cpp
@@ -15,7 +15,7 @@ limitations under the License. */
 #ifdef FUSION_CONVADDBNRELU_OP
 
 #include "operators/kernel/conv_add_bn_relu_kernel.h"
-#include "operators/kernel/central-arm-func/conv_add_bn_relu_func.h"
+#include "operators/kernel/central-arm-func/conv_add_bn_relu_arm_func.h"
 
 namespace paddle_mobile {
 namespace operators {
diff --git a/src/operators/kernel/arm/dwconv_bn_relu_kernel.cpp b/src/operators/kernel/arm/dwconv_bn_relu_kernel.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..0ec08fcecb9fefaa247e0acbb8a085e752b8dba3
--- /dev/null
+++ b/src/operators/kernel/arm/dwconv_bn_relu_kernel.cpp
@@ -0,0 +1,65 @@
+/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#ifdef FUSION_DWCONVBNRELU_OP
+
+#include "operators/kernel/dwconv_bn_relu_kernel.h"
+#include "operators/kernel/central-arm-func/dwconv_bn_relu_arm_func.h"
+
+namespace paddle_mobile {
+namespace operators {
+
+template <>
+bool DWConvBNReluKernel<CPU, float>::Init(FusionDWConvBNReluParam *param) {
+  const Tensor *mean = param->InputMean();
+  const Tensor *variance = param->InputVariance();
+  const Tensor *scale = param->InputScale();
+  const Tensor *bias = param->InputBias();
+  const float epsilon = param->Epsilon();
+
+  auto mean_ptr = mean->data<float>();
+  auto variance_ptr = variance->data<float>();
+  auto scale_ptr = scale->data<float>();
+  auto bias_ptr = bias->data<float>();
+
+  const int C = mean->numel();
+  float inv_std_ptr[C];
+  for (int i = 0; i < C; i++) {
+    inv_std_ptr[i] =
+        1 / static_cast<float>(pow((variance_ptr[i] + epsilon), 0.5));
+  }
+  Tensor *new_scale = new Tensor();
+  Tensor *new_bias = new Tensor();
+  auto new_scale_ptr = new_scale->mutable_data<float>({C});
+  auto new_bias_ptr = new_bias->mutable_data<float>({C});
+  for (int i = 0; i < C; i++) {
+    new_scale_ptr[i] = inv_std_ptr[i] * scale_ptr[i];
+    new_bias_ptr[i] = bias_ptr[i] - mean_ptr[i] * inv_std_ptr[i] * scale_ptr[i];
+  }
+  param->SetNewScale(new_scale);
+  param->SetNewBias(new_bias);
+  return true;
+}
+
+template <>
+void DWConvBNReluKernel<CPU, float>::Compute(
+    const FusionDWConvBNReluParam &param) const {
+  DWConvBNReluCompute<float>(param);
+}
+template class DWConvBNReluKernel<CPU, float>;
+
+}  // namespace operators
+}  // namespace paddle_mobile
+
+#endif
diff --git a/src/operators/kernel/central-arm-func/conv_add_bn_relu_func.h b/src/operators/kernel/central-arm-func/conv_add_bn_relu_arm_func.h
similarity index 96%
rename from src/operators/kernel/central-arm-func/conv_add_bn_relu_func.h
rename to src/operators/kernel/central-arm-func/conv_add_bn_relu_arm_func.h
index fb49a33c67face81a2615516bffd6aa151868fe3..b74ea66fe28fbae0ffd6e6d3d4e503f5d739251b 100644
--- a/src/operators/kernel/central-arm-func/conv_add_bn_relu_func.h
+++ b/src/operators/kernel/central-arm-func/conv_add_bn_relu_arm_func.h
@@ -15,6 +15,8 @@ limitations under the License. */
 #ifdef FUSION_CONVADDBNRELU_OP
 
 #pragma once
+
+#include <vector>
 #include "operators/math/depthwise_conv_3x3.h"
 #include "operators/op_param.h"
 
@@ -23,14 +25,9 @@ namespace operators {
 void ConvAddBNReluBasic(const FusionConvAddBNReluParam &param) {
   const Tensor *input = param.Input();
   Tensor filter = *param.Filter();
-  Tensor bias = *param.Bias();
   Tensor new_bias = *param.NewBias();
   Tensor new_scale = *param.NewScale();
-  int axis = param.Axis();
   Tensor *output = param.Output();
-  math::expand_bias(bias, axis, output->dims());
-  output->ShareDataWith(bias);
-
   int groups = param.Groups();
   std::vector<int> strides = param.Strides();
   std::vector<int> paddings = param.Paddings();
@@ -121,7 +118,7 @@ void ConvAddBNReluCompute(const FusionConvAddBNReluParam &param) {
       param.Filter()->dims()[2] == 3 && param.Strides()[0] == 1) {
     math::DepthwiseConvAddBNRelu3x3s1p1(param.Input(), param.Filter(),
                                         param.Output(), param.NewScale(),
-                                        param.NewBias(), 1);
+                                        param.NewBias(), true);
   } else if (param.Groups() == param.Input()->dims()[1] &&
              param.Input()->dims()[1] == param.Output()->dims()[1] &&
              param.Filter()->dims()[2] == param.Filter()->dims()[3] &&
diff --git a/src/operators/kernel/central-arm-func/dwconv_bn_relu_arm_func.h b/src/operators/kernel/central-arm-func/dwconv_bn_relu_arm_func.h
new file mode 100644
index 0000000000000000000000000000000000000000..737165884bfb89feeebfe7cf38c58edb44bc3e83
--- /dev/null
+++ b/src/operators/kernel/central-arm-func/dwconv_bn_relu_arm_func.h
@@ -0,0 +1,138 @@
+/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#ifdef FUSION_DWCONVBNRELU_OP
+
+#pragma once
+#include <vector>
+#include "operators/math/depthwise_conv_3x3.h"
+#include "operators/op_param.h"
+namespace paddle_mobile {
+namespace operators {
+void DWConvBNReluBasic(const FusionDWConvBNReluParam &param) {
+  const Tensor *input = param.Input();
+  Tensor filter = *param.Filter();
+  Tensor new_bias = *param.NewBias();
+  Tensor new_scale = *param.NewScale();
+
+  Tensor *output = param.Output();
+
+  int groups = param.Groups();
+  std::vector<int> strides = param.Strides();
+  std::vector<int> paddings = param.Paddings();
+  std::vector<int> dilations = param.Dilations();
+
+  const int batch_size = static_cast<int>(input->dims()[0]);
+
+  std::vector<int64_t> filter_shape_vec(framework::vectorize(filter.dims()));
+
+  std::vector<int64_t> output_shape_vec(framework::vectorize(output->dims()));
+  size_t data_dim = filter_shape_vec.size() - 2;
+  std::vector<int64_t> col_shape_vec(1 + 2 * data_dim);
+  col_shape_vec[0] = input->dims()[1] / groups;
+  for (size_t j = 0; j < data_dim; ++j) {
+    col_shape_vec[j + 1] = filter_shape_vec[j + 2];
+    col_shape_vec[j + 1 + data_dim] = output_shape_vec[j + 2];
+  }
+  framework::DDim col_shape(framework::make_ddim(col_shape_vec));
+
+  framework::DDim col_matrix_shape =
+      framework::flatten_to_2d(col_shape, data_dim + 1);
+
+  bool is_expand =
+      math::IsExpand(filter_shape_vec, strides, paddings, dilations);
+  Tensor col;
+  Tensor col_matrix;
+  if (is_expand) {
+    col.mutable_data<float>(col_shape);
+    col_matrix.ShareDataWith(col);
+    col_matrix.Resize(col_matrix_shape);
+  }
+
+  framework::DDim input_shape = framework::slice_ddim(
+      input->dims(), 1, static_cast<int>(input->dims().size()));
+
+  framework::DDim filter_matrix_shape = {filter.dims()[0],
+                                         filter.numel() / filter.dims()[0]};
+  filter.Resize(filter_matrix_shape);
+  framework::DDim output_matrix_shape = {
+      output->dims()[1],
+      output->numel() / (output->dims()[0] * output->dims()[1])};
+
+  // convolution operator: im2col(or vol2col) + gemm
+  int in_step = static_cast<int>(input->dims()[1]) / groups;
+  int out_step = static_cast<int>(output->dims()[1]) / groups;
+
+  math::Vol2ColFunctor<CPU, float> vol2col;
+  math::Im2ColFunctor<math::ColFormat::kCFO, CPU, float> im2col;
+
+  for (int i = 0; i < batch_size; i++) {
+    Tensor in_batch = input->Slice(i, i + 1).Resize(input_shape);
+    Tensor out_batch = output->Slice(i, i + 1).Resize(output_matrix_shape);
+
+    for (int g = 0; g < groups; g++) {
+      Tensor in_slice = in_batch.Slice(g * in_step, (g + 1) * in_step);
+
+      if (!is_expand) {
+        col.ShareDataWith(in_slice);
+        col_matrix.ShareDataWith(col);
+        col_matrix.Resize(col_matrix_shape);
+      } else if (data_dim == 2U) {
+        // im2col
+        im2col(in_slice, dilations, strides,
+               std::vector<int>{paddings[0], paddings[1], paddings[0],
+                                paddings[1]},
+               &col);
+      } else if (data_dim == 3U) {
+        // vol2col
+        vol2col(in_slice, dilations, strides, paddings, &col);
+      }
+      // gemm
+      Tensor out_slice = out_batch.Slice(g * out_step, (g + 1) * out_step);
+      Tensor filter_slice = filter.Slice(g * out_step, (g + 1) * out_step);
+      std::cout << "***************" << std::endl;
+      math::matmulWithBn<float>(
+          filter_slice, false, col_matrix, false, static_cast<float>(1),
+          &out_slice, static_cast<float>(0), false, &new_scale, &new_bias);
+    }
+  }
+}
+template <typename P>
+void DWConvBNReluCompute(const FusionDWConvBNReluParam &param) {
+  if (param.Groups() == param.Input()->dims()[1] &&
+      param.Input()->dims()[1] == param.Output()->dims()[1] &&
+      param.Filter()->dims()[2] == param.Filter()->dims()[3] &&
+      param.Filter()->dims()[2] == 3 && param.Strides()[0] == 1) {
+    math::DepthwiseConvAddBNRelu3x3s1p1(param.Input(), param.Filter(),
+                                        param.Output(), param.NewScale(),
+                                        param.NewBias(), true);
+  } else if (param.Groups() == param.Input()->dims()[1] &&
+             param.Input()->dims()[1] == param.Output()->dims()[1] &&
+             param.Filter()->dims()[2] == param.Filter()->dims()[3] &&
+             param.Filter()->dims()[2] == 3 && param.Strides()[0] == 2) {
+    //    math::DepthwiseConvAddBNRelu3x3s2p1(param.Input(), param.Filter(),
+    //                                        param.Output(), param.NewScale(),
+    //                                        param.NewBias(), 1);
+    math::DepthwiseConvAddBNRelu3x3s2p1v2(param.Input(), param.Filter(),
+                                          param.Output(), param.NewScale(),
+                                          param.NewBias(), true);
+  } else {
+    DWConvBNReluBasic(param);
+  }
+}
+
+}  // namespace operators
+}  // namespace paddle_mobile
+
+#endif
diff --git a/src/operators/kernel/dwconv_bn_relu_kernel.h b/src/operators/kernel/dwconv_bn_relu_kernel.h
new file mode 100644
index 0000000000000000000000000000000000000000..91478ae5ecba37472e7e30f774f2c515b6952eee
--- /dev/null
+++ b/src/operators/kernel/dwconv_bn_relu_kernel.h
@@ -0,0 +1,45 @@
+/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#pragma once
+
+#ifdef FUSION_DWCONVBNRELU_OP
+
+#include <vector>
+#include "framework/ddim.h"
+#include "framework/operator.h"
+#include "operators/math/conv_func.h"
+#include "operators/math/im2col.h"
+#include "operators/math/math_function.h"
+#include "operators/math/vol2col.h"
+#include "operators/op_param.h"
+
+namespace paddle_mobile {
+namespace operators {
+
+using framework::DDim;
+using framework::OpKernelBase;
+
+template <typename DeviceType, typename T>
+class DWConvBNReluKernel
+    : public OpKernelBase<DeviceType, FusionDWConvBNReluParam> {
+ public:
+  void Compute(const FusionDWConvBNReluParam &param) const;
+  bool Init(FusionDWConvBNReluParam *param);
+};
+
+}  // namespace operators
+}  // namespace paddle_mobile
+
+#endif
diff --git a/src/operators/op_param.h b/src/operators/op_param.h
index 6de4b77dbe486c3b7504212f0b0e6bd2c1c1cae2..994392d6786d42ad894ba8723cfc7ed82005f69a 100644
--- a/src/operators/op_param.h
+++ b/src/operators/op_param.h
@@ -1059,6 +1059,86 @@ class FusionConvAddBNReluParam : public OpParam {
 Print &operator<<(Print &printer, const FusionConvAddParam &conv_param);
 #endif
 
+#ifdef FUSION_DWCONVBNRELU_OP
+class FusionDWConvBNReluParam : public OpParam {
+ public:
+  FusionDWConvBNReluParam(const VariableNameMap &inputs,
+                          const VariableNameMap &outputs,
+                          const AttributeMap &attrs, const Scope &scope) {
+    filter_ = FilterFrom(inputs, scope);
+    input_ = InputFrom(inputs, scope);
+    output_ = OutFrom(outputs, scope);
+    strides_ = GetAttr<vector<int>>("strides", attrs);
+    paddings_ = GetAttr<vector<int>>("paddings", attrs);
+    dilations_ = GetAttr<vector<int>>("dilations", attrs);
+    groups = GetAttr<int>("groups", attrs);
+    input_bias_ = InputBiasFrom(inputs, scope);
+    input_mean_ = InputMeanFrom(inputs, scope);
+    input_scale_ = InputScaleFrom(inputs, scope);
+    input_variance_ = InputVarianceFrom(inputs, scope);
+    epsilon_ = GetAttr<float>("epsilon", attrs);
+    momentum_ = GetAttr<float>("momentum", attrs);
+    is_test_ = GetAttr<bool>("is_test", attrs);
+  }
+
+  const Tensor *Input() const { return input_; }
+
+  const Tensor *Filter() const { return filter_; }
+
+  Tensor *Output() const { return output_; }
+
+  const vector<int> &Strides() const { return strides_; }
+
+  const vector<int> &Paddings() const { return paddings_; }
+
+  const vector<int> &Dilations() const { return dilations_; }
+
+  const int &Groups() const { return groups; }
+
+  const Tensor *InputBias() const { return input_bias_; }
+
+  const Tensor *InputMean() const { return input_mean_; }
+
+  const Tensor *InputScale() const { return input_scale_; }
+
+  const Tensor *InputVariance() const { return input_variance_; }
+
+  const float &Epsilon() const { return epsilon_; }
+
+  const float &Momentum() const { return momentum_; }
+
+  const bool &IsTest() const { return is_test_; }
+
+  void SetNewScale(Tensor *new_scale) { new_scale_ = new_scale; }
+
+  void SetNewBias(Tensor *new_bias) { new_bias_ = new_bias; }
+
+  const Tensor *NewScale() const { return new_scale_; }
+
+  const Tensor *NewBias() const { return new_bias_; }
+
+ protected:
+  Tensor *input_;
+  Tensor *output_;
+  Tensor *filter_;
+  vector<int> strides_;
+  vector<int> paddings_;
+  vector<int> dilations_;
+  int groups;
+  Tensor *input_bias_;
+  Tensor *input_mean_;
+  Tensor *input_scale_;
+  Tensor *input_variance_;
+  float epsilon_;
+  float momentum_;
+  bool is_test_;
+  Tensor *new_bias_;
+  Tensor *new_scale_;
+};
+
+Print &operator<<(Print &printer, const FusionConvAddParam &conv_param);
+#endif
+
 #ifdef IM2SEQUENCE_OP
 class Im2SequenceParam : public OpParam {
  public:
diff --git a/tools/op.cmake b/tools/op.cmake
index 71defeffcc919848e165ea836f4bfed2fcc7e0ff..456d36262e9abf997a7861838c870e698d64f3c1 100644
--- a/tools/op.cmake
+++ b/tools/op.cmake
@@ -64,6 +64,7 @@ else ()
     set(TRANSPOSE_OP ON)
     set(FUSION_CONVADD_RELU_OP ON)
     set(FUSION_CONVADDBNRELU_OP ON)
+    set(FUSION_DWCONVBNRELU_OP ON)
    set(PRELU_OP ON)
    set(RESIZE_OP ON)
    set(SCALE_OP ON)
@@ -155,6 +156,9 @@ endif()
 if (FUSION_CONVADDBNRELU_OP)
     add_definitions(-DFUSION_CONVADDBNRELU_OP)
 endif()
+if (FUSION_DWCONVBNRELU_OP)
+    add_definitions(-DFUSION_DWCONVBNRELU_OP)
+endif()
 if (PRELU_OP)
     add_definitions(-DPRELU_OP)
 endif()