Commit 158f4210, authored by eclipsess

Add fused dwconv + bn + relu operator

Parent commit: c9deac4a
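In short, the commit registers a new fused operator, fusion_dwconv_bn_relu, that matches the subgraph depthwise_conv2d → batch_norm → relu and, in the kernel's Init(), folds the batch-norm parameters into a per-channel scale and bias, so the compute path only applies the convolution, a scale/shift, and ReLU. A minimal, self-contained sketch of that folding follows; it mirrors DWConvBNReluKernel<CPU, float>::Init() from this commit, but the std::vector interface and the names FoldBatchNorm/FoldedBN are illustrative, not part of the paddle-mobile API.

#include <cmath>
#include <cstddef>
#include <vector>

// Per-channel scale/bias produced by folding batch norm into the conv epilogue.
struct FoldedBN {
  std::vector<float> new_scale;  // gamma / sqrt(variance + epsilon)
  std::vector<float> new_bias;   // beta - mean * new_scale
};

// Sketch of the fold done in Init(): applying y = new_scale * conv(x) + new_bias,
// then ReLU, matches conv -> batch_norm -> relu at inference time.
inline FoldedBN FoldBatchNorm(const std::vector<float> &mean,
                              const std::vector<float> &variance,
                              const std::vector<float> &gamma,  // BN "Scale" input
                              const std::vector<float> &beta,   // BN "Bias" input
                              float epsilon) {
  FoldedBN out;
  const std::size_t C = mean.size();
  out.new_scale.resize(C);
  out.new_bias.resize(C);
  for (std::size_t i = 0; i < C; ++i) {
    const float inv_std = 1.0f / std::sqrt(variance[i] + epsilon);
    out.new_scale[i] = gamma[i] * inv_std;
    out.new_bias[i] = beta[i] - mean[i] * out.new_scale[i];
  }
  return out;
}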
@@ -24,6 +24,8 @@ const std::string G_OP_TYPE_CONCAT = "concat";
 const std::string G_OP_TYPE_ELEMENTWISE_ADD = "elementwise_add";
 const std::string G_OP_TYPE_FUSION_CONV_ADD_RELU = "fusion_conv_add_relu";
 const std::string G_OP_TYPE_FUSION_CONV_ADD_BN_RELU = "fusion_conv_add_bn_relu";
+const std::string G_OP_TYPE_FUSION_DWCONV_BN_RELU = "fusion_dwconv_bn_relu";
 const std::string G_OP_TYPE_FC = "fusion_fc";
 const std::string G_OP_TYPE_FUSION_CONV_ADD = "fusion_conv_add";
 const std::string G_OP_TYPE_LRN = "lrn";
......
@@ -81,6 +81,7 @@ extern const std::string G_OP_TYPE_FUSION_CONV_ADD_RELU;
 extern const std::string G_OP_TYPE_FC;
 extern const std::string G_OP_TYPE_FUSION_CONV_ADD;
 extern const std::string G_OP_TYPE_FUSION_CONV_ADD_BN_RELU;
+extern const std::string G_OP_TYPE_FUSION_DWCONV_BN_RELU;
 extern const std::string G_OP_TYPE_LRN;
 extern const std::string G_OP_TYPE_MUL;
......
@@ -52,10 +52,10 @@ void ProgramDesc::Description(std::string header) {
         LOG(kLOG_DEBUG3) << "argument - " << n;
       }
     }
-    for (auto &attr : op->GetAttrMap()) {
-      LOG(kLOG_DEBUG2) << "attr name:: " << attr.first;
-      LOG(kLOG_DEBUG3) << "argument - " << attr.second;
-    }
+    // for (auto &attr : op->GetAttrMap()) {
+    //   LOG(kLOG_DEBUG2) << "attr name:: " << attr.first;
+    //   LOG(kLOG_DEBUG3) << "argument - " << attr.second;
+    // }
   }
   for (const auto &var_desc : block->Vars()) {
......
@@ -68,11 +68,11 @@ class FusionConvAddOp : public framework::OperatorWithKernel<
 #ifdef PADDLE_MOBILE_CPU
-#ifndef CONV_ADD_REGISTER
-static framework::FusionOpRegistrar convadd_registrar(
-    new FusionConvAddMatcher());
-#define CONV_ADD_REGISTER
-#endif
+// #ifndef CONV_ADD_REGISTER
+// static framework::FusionOpRegistrar convadd_registrar(
+//     new FusionConvAddMatcher());
+// #define CONV_ADD_REGISTER
+// #endif
 #endif
......
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifdef FUSION_DWCONVBNRELU_OP
#include "operators/fusion_dwconv_bn_relu_op.h"
#include "operators/math/conv_func.h"
namespace paddle_mobile {
namespace operators {
template <typename Dtype, typename T>
void FusionDWConvBNReluOp<Dtype, T>::InferShape() const {
auto in_dims = this->param_.Input()->dims();
auto filter_dims = this->param_.Filter()->dims();
const std::vector<int> &strides = this->param_.Strides();
std::vector<int> paddings = this->param_.Paddings();
int groups = this->param_.Groups();
std::vector<int> dilations = this->param_.Dilations();
PADDLE_MOBILE_ENFORCE((in_dims.size() == filter_dims.size() &&
dilations.size() == paddings.size() &&
paddings.size() == strides.size()),
"ConvParam is not suitable");
std::vector<int64_t> output_shape({in_dims[0], filter_dims[0]});
for (size_t i = 0; i < strides.size(); ++i) {
output_shape.push_back(
math::ConvOutputSize(in_dims[i + 2], filter_dims[i + 2], dilations[i],
paddings[i], strides[i]));
}
framework::DDim ddim = framework::make_ddim(output_shape);
this->param_.Output()->Resize(ddim);
}
template class FusionDWConvBNReluOp<CPU, float>;
} // namespace operators
} // namespace paddle_mobile
namespace ops = paddle_mobile::operators;
#ifdef PADDLE_MOBILE_CPU
REGISTER_OPERATOR_CPU(fusion_dwconv_bn_relu, ops::FusionDWConvBNReluOp);
#endif
#ifdef PADDLE_MOBILE_MALI_GPU
#endif
#ifdef PADDLE_MOBILE_FPGA
#endif
#endif
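For reference, the InferShape() above derives each spatial output dimension through math::ConvOutputSize. A minimal sketch of that arithmetic is shown below, assuming the standard convolution output-size formula; the helper name is hypothetical and this is not the paddle-mobile implementation.

// Illustrative only: assumes math::ConvOutputSize follows the standard formula.
constexpr int ConvOutputSizeSketch(int input_size, int filter_size, int dilation,
                                   int padding, int stride) {
  // Effective kernel extent grows with dilation: dilation * (k - 1) + 1.
  return (input_size + 2 * padding - (dilation * (filter_size - 1) + 1)) / stride + 1;
}
// Example: a 3x3 filter with stride 1, padding 1, dilation 1 preserves spatial size.
static_assert(ConvOutputSizeSketch(32, 3, 1, 1, 1) == 32,
              "3x3, stride 1, padding 1 keeps the spatial dimension");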
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifdef FUSION_DWCONVBNRELU_OP
#pragma once
#include <string>
#include <vector>
#include "framework/operator.h"
#include "framework/program/program-optimize/fusion_op_register.h"
#include "op_param.h"
#include "operators/kernel/dwconv_bn_relu_kernel.h"
namespace paddle_mobile {
namespace operators {
using std::string;
using std::vector;
class FusionDWConvBNReluMatcher : public framework::FusionOpMatcher {
public:
FusionDWConvBNReluMatcher() {
node_ = framework::Node(G_OP_TYPE_DEPTHWISE_CONV);
node_ > std::make_shared<framework::Node>(G_OP_TYPE_BATCHNORM) >
std::make_shared<framework::Node>(G_OP_TYPE_RELU);
}
void FolderNodes(
framework::Node *node,
std::vector<std::shared_ptr<framework::Node>> *removed_nodes) {
vector<std::shared_ptr<framework::OpDesc>> origin_descs =
node->OpDescs(node_.Depth());
node->Folder(node_.Depth(), Type(),
{{G_OP_TYPE_BATCHNORM,
{{"Scale", "Scale"},
{"Mean", "Mean"},
{"Bias", "Bias"},
{"Variance", "Variance"}}}},
removed_nodes);
}
std::string Type() { return G_OP_TYPE_FUSION_DWCONV_BN_RELU; }
};
template <typename DeviceType, typename T>
class FusionDWConvBNReluOp : public framework::OperatorWithKernel<
DeviceType, FusionDWConvBNReluParam,
operators::DWConvBNReluKernel<DeviceType, T>> {
public:
FusionDWConvBNReluOp(const string &type, const VariableNameMap &inputs,
const VariableNameMap &outputs,
const framework::AttributeMap &attrs,
std::shared_ptr<framework::Scope> scope)
: framework::OperatorWithKernel<
DeviceType, FusionDWConvBNReluParam,
operators::DWConvBNReluKernel<DeviceType, T>>(type, inputs, outputs,
attrs, scope) {}
using framework::OperatorWithKernel<
DeviceType, FusionDWConvBNReluParam,
operators::DWConvBNReluKernel<DeviceType, T>>::OperatorWithKernel;
void InferShape() const override;
protected:
};
#ifdef PADDLE_MOBILE_CPU
#ifndef FUSION_DWCONV_BN_RELU_REGISTER
static framework::FusionOpRegistrar fusion_dwconv_bn_relu_registrar(
new FusionDWConvBNReluMatcher());
#define FUSION_DWCONV_BN_RELU_REGISTER
#endif
#endif
#ifdef PADDLE_MOBILE_MALI_GPU
#ifndef FUSION_DWCONV_BN_RELU_REGISTER
static framework::FusionOpRegistrar fusion_dwconv_bn_relu_registrar(
new FusionDWConvBNReluMatcher());
#define FUSION_DWCONV_BN_RELU_REGISTER
#endif
#endif
#ifdef PADDLE_MOBILE_FPGA
#endif
} // namespace operators
} // namespace paddle_mobile
#ifdef PADDLE_MOBILE_CPU
USE_OP_CPU(fusion_dwconv_bn_relu);
#endif
#ifdef PADDLE_MOBILE_MALI_GPU
#endif
#ifdef PADDLE_MOBILE_FPGA
#endif
#endif
@@ -15,7 +15,7 @@ limitations under the License. */
 #ifdef FUSION_CONVADDBNRELU_OP
 #include "operators/kernel/conv_add_bn_relu_kernel.h"
-#include "operators/kernel/central-arm-func/conv_add_bn_relu_func.h"
+#include "operators/kernel/central-arm-func/conv_add_bn_relu_arm_func.h"
 namespace paddle_mobile {
 namespace operators {
......
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifdef FUSION_DWCONVBNRELU_OP
#include "operators/kernel/dwconv_bn_relu_kernel.h"
#include "operators/kernel/central-arm-func/dwconv_bn_relu_arm_func.h"
namespace paddle_mobile {
namespace operators {
template <>
bool DWConvBNReluKernel<CPU, float>::Init(FusionDWConvBNReluParam *param) {
const Tensor *mean = param->InputMean();
const Tensor *variance = param->InputVariance();
const Tensor *scale = param->InputScale();
const Tensor *bias = param->InputBias();
const float epsilon = param->Epsilon();
auto mean_ptr = mean->data<float>();
auto variance_ptr = variance->data<float>();
auto scale_ptr = scale->data<float>();
auto bias_ptr = bias->data<float>();
const int C = mean->numel();
float inv_std_ptr[C];
for (int i = 0; i < C; i++) {
inv_std_ptr[i] =
1 / static_cast<float>(pow((variance_ptr[i] + epsilon), 0.5));
}
Tensor *new_scale = new Tensor();
Tensor *new_bias = new Tensor();
auto new_scale_ptr = new_scale->mutable_data<float>({C});
auto new_bias_ptr = new_bias->mutable_data<float>({C});
for (int i = 0; i < C; i++) {
new_scale_ptr[i] = inv_std_ptr[i] * scale_ptr[i];
new_bias_ptr[i] = bias_ptr[i] - mean_ptr[i] * inv_std_ptr[i] * scale_ptr[i];
}
param->SetNewScale(new_scale);
param->SetNewBias(new_bias);
return true;
}
template <>
void DWConvBNReluKernel<CPU, float>::Compute(
const FusionDWConvBNReluParam &param) const {
DWConvBNReluCompute<float>(param);
}
template class DWConvBNReluKernel<CPU, float>;
} // namespace operators
} // namespace paddle_mobile
#endif
@@ -15,6 +15,8 @@ limitations under the License. */
 #ifdef FUSION_CONVADDBNRELU_OP
 #pragma once
+#include <vector>
 #include "operators/math/depthwise_conv_3x3.h"
 #include "operators/op_param.h"
@@ -23,14 +25,9 @@ namespace operators {
 void ConvAddBNReluBasic(const FusionConvAddBNReluParam &param) {
   const Tensor *input = param.Input();
   Tensor filter = *param.Filter();
-  Tensor bias = *param.Bias();
   Tensor new_bias = *param.NewBias();
   Tensor new_scale = *param.NewScale();
-  int axis = param.Axis();
   Tensor *output = param.Output();
-  math::expand_bias(bias, axis, output->dims());
-  output->ShareDataWith(bias);
   int groups = param.Groups();
   std::vector<int> strides = param.Strides();
   std::vector<int> paddings = param.Paddings();
@@ -121,7 +118,7 @@ void ConvAddBNReluCompute(const FusionConvAddBNReluParam &param) {
       param.Filter()->dims()[2] == 3 && param.Strides()[0] == 1) {
     math::DepthwiseConvAddBNRelu3x3s1p1(param.Input(), param.Filter(),
                                         param.Output(), param.NewScale(),
-                                        param.NewBias(), 1);
+                                        param.NewBias(), true);
   } else if (param.Groups() == param.Input()->dims()[1] &&
              param.Input()->dims()[1] == param.Output()->dims()[1] &&
              param.Filter()->dims()[2] == param.Filter()->dims()[3] &&
......
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifdef FUSION_DWCONVBNRELU_OP
#pragma once
#include <vector>
#include "operators/math/depthwise_conv_3x3.h"
#include "operators/op_param.h"
namespace paddle_mobile {
namespace operators {
void DWConvBNReluBasic(const FusionDWConvBNReluParam &param) {
const Tensor *input = param.Input();
Tensor filter = *param.Filter();
Tensor new_bias = *param.NewBias();
Tensor new_scale = *param.NewScale();
Tensor *output = param.Output();
int groups = param.Groups();
std::vector<int> strides = param.Strides();
std::vector<int> paddings = param.Paddings();
std::vector<int> dilations = param.Dilations();
const int batch_size = static_cast<int>(input->dims()[0]);
std::vector<int64_t> filter_shape_vec(framework::vectorize(filter.dims()));
std::vector<int64_t> output_shape_vec(framework::vectorize(output->dims()));
size_t data_dim = filter_shape_vec.size() - 2;
std::vector<int64_t> col_shape_vec(1 + 2 * data_dim);
col_shape_vec[0] = input->dims()[1] / groups;
for (size_t j = 0; j < data_dim; ++j) {
col_shape_vec[j + 1] = filter_shape_vec[j + 2];
col_shape_vec[j + 1 + data_dim] = output_shape_vec[j + 2];
}
framework::DDim col_shape(framework::make_ddim(col_shape_vec));
framework::DDim col_matrix_shape =
framework::flatten_to_2d(col_shape, data_dim + 1);
bool is_expand =
math::IsExpand(filter_shape_vec, strides, paddings, dilations);
Tensor col;
Tensor col_matrix;
if (is_expand) {
col.mutable_data<float>(col_shape);
col_matrix.ShareDataWith(col);
col_matrix.Resize(col_matrix_shape);
}
framework::DDim input_shape = framework::slice_ddim(
input->dims(), 1, static_cast<int>(input->dims().size()));
framework::DDim filter_matrix_shape = {filter.dims()[0],
filter.numel() / filter.dims()[0]};
filter.Resize(filter_matrix_shape);
framework::DDim output_matrix_shape = {
output->dims()[1],
output->numel() / (output->dims()[0] * output->dims()[1])};
// convolution operator: im2col(or vol2col) + gemm
int in_step = static_cast<int>(input->dims()[1]) / groups;
int out_step = static_cast<int>(output->dims()[1]) / groups;
math::Vol2ColFunctor<CPU, float> vol2col;
math::Im2ColFunctor<math::ColFormat::kCFO, CPU, float> im2col;
for (int i = 0; i < batch_size; i++) {
Tensor in_batch = input->Slice(i, i + 1).Resize(input_shape);
Tensor out_batch = output->Slice(i, i + 1).Resize(output_matrix_shape);
for (int g = 0; g < groups; g++) {
Tensor in_slice = in_batch.Slice(g * in_step, (g + 1) * in_step);
if (!is_expand) {
col.ShareDataWith(in_slice);
col_matrix.ShareDataWith(col);
col_matrix.Resize(col_matrix_shape);
} else if (data_dim == 2U) {
// im2col
im2col(in_slice, dilations, strides,
std::vector<int>{paddings[0], paddings[1], paddings[0],
paddings[1]},
&col);
} else if (data_dim == 3U) {
// vol2col
vol2col(in_slice, dilations, strides, paddings, &col);
}
// gemm
Tensor out_slice = out_batch.Slice(g * out_step, (g + 1) * out_step);
Tensor filter_slice = filter.Slice(g * out_step, (g + 1) * out_step);
math::matmulWithBn<float>(
filter_slice, false, col_matrix, false, static_cast<float>(1),
&out_slice, static_cast<float>(0), false, &new_scale, &new_bias);
}
}
}
template <typename P>
void DWConvBNReluCompute(const FusionDWConvBNReluParam &param) {
if (param.Groups() == param.Input()->dims()[1] &&
param.Input()->dims()[1] == param.Output()->dims()[1] &&
param.Filter()->dims()[2] == param.Filter()->dims()[3] &&
param.Filter()->dims()[2] == 3 && param.Strides()[0] == 1) {
math::DepthwiseConvAddBNRelu3x3s1p1(param.Input(), param.Filter(),
param.Output(), param.NewScale(),
param.NewBias(), true);
} else if (param.Groups() == param.Input()->dims()[1] &&
param.Input()->dims()[1] == param.Output()->dims()[1] &&
param.Filter()->dims()[2] == param.Filter()->dims()[3] &&
param.Filter()->dims()[2] == 3 && param.Strides()[0] == 2) {
// math::DepthwiseConvAddBNRelu3x3s2p1(param.Input(), param.Filter(),
// param.Output(), param.NewScale(),
// param.NewBias(), 1);
math::DepthwiseConvAddBNRelu3x3s2p1v2(param.Input(), param.Filter(),
param.Output(), param.NewScale(),
param.NewBias(), true);
} else {
DWConvBNReluBasic(param);
}
}
} // namespace operators
} // namespace paddle_mobile
#endif
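DWConvBNReluCompute() above takes one of the hand-written 3x3 depthwise paths (DepthwiseConvAddBNRelu3x3s1p1 or 3x3s2p1v2) only when the convolution is a true depthwise 3x3 with stride 1 or 2, and otherwise falls back to the generic im2col + GEMM routine in DWConvBNReluBasic. The standalone sketch below restates that dispatch condition; the ConvShape struct and the function name are assumptions for illustration, not part of the codebase.

#include <cstdint>

// Simplified stand-in for the shape data read from FusionDWConvBNReluParam.
struct ConvShape {
  int64_t in_channels;   // input->dims()[1]
  int64_t out_channels;  // output->dims()[1]
  int64_t kernel_h;      // filter->dims()[2]
  int64_t kernel_w;      // filter->dims()[3]
  int groups;            // param.Groups()
  int stride;            // param.Strides()[0]
};

// Mirrors the conditions checked in DWConvBNReluCompute: true when the
// hand-tuned 3x3 depthwise kernels can be used instead of im2col + GEMM.
inline bool UsesDepthwise3x3FastPath(const ConvShape &s) {
  const bool depthwise =
      s.groups == s.in_channels && s.in_channels == s.out_channels;
  const bool square_3x3 = s.kernel_h == s.kernel_w && s.kernel_h == 3;
  return depthwise && square_3x3 && (s.stride == 1 || s.stride == 2);
}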
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#ifdef FUSION_DWCONVBNRELU_OP
#include <vector>
#include "framework/ddim.h"
#include "framework/operator.h"
#include "operators/math/conv_func.h"
#include "operators/math/im2col.h"
#include "operators/math/math_function.h"
#include "operators/math/vol2col.h"
#include "operators/op_param.h"
namespace paddle_mobile {
namespace operators {
using framework::DDim;
using framework::OpKernelBase;
template <typename DeviceType, typename T>
class DWConvBNReluKernel
: public OpKernelBase<DeviceType, FusionDWConvBNReluParam> {
public:
void Compute(const FusionDWConvBNReluParam &param) const;
bool Init(FusionDWConvBNReluParam *param);
};
} // namespace operators
} // namespace paddle_mobile
#endif
@@ -1059,6 +1059,86 @@ class FusionConvAddBNReluParam : public OpParam {
 Print &operator<<(Print &printer, const FusionConvAddParam &conv_param);
 #endif
#ifdef FUSION_DWCONVBNRELU_OP
class FusionDWConvBNReluParam : public OpParam {
public:
FusionDWConvBNReluParam(const VariableNameMap &inputs,
const VariableNameMap &outputs,
const AttributeMap &attrs, const Scope &scope) {
filter_ = FilterFrom<LoDTensor>(inputs, scope);
input_ = InputFrom<LoDTensor>(inputs, scope);
output_ = OutFrom<LoDTensor>(outputs, scope);
strides_ = GetAttr<vector<int>>("strides", attrs);
paddings_ = GetAttr<vector<int>>("paddings", attrs);
dilations_ = GetAttr<vector<int>>("dilations", attrs);
groups = GetAttr<int>("groups", attrs);
input_bias_ = InputBiasFrom<LoDTensor>(inputs, scope);
input_mean_ = InputMeanFrom<LoDTensor>(inputs, scope);
input_scale_ = InputScaleFrom<LoDTensor>(inputs, scope);
input_variance_ = InputVarianceFrom<LoDTensor>(inputs, scope);
epsilon_ = GetAttr<float>("epsilon", attrs);
momentum_ = GetAttr<float>("momentum", attrs);
is_test_ = GetAttr<bool>("is_test", attrs);
}
const Tensor *Input() const { return input_; }
const Tensor *Filter() const { return filter_; }
Tensor *Output() const { return output_; }
const vector<int> &Strides() const { return strides_; }
const vector<int> &Paddings() const { return paddings_; }
const vector<int> &Dilations() const { return dilations_; }
const int &Groups() const { return groups; }
const Tensor *InputBias() const { return input_bias_; }
const Tensor *InputMean() const { return input_mean_; }
const Tensor *InputScale() const { return input_scale_; }
const Tensor *InputVariance() const { return input_variance_; }
const float &Epsilon() const { return epsilon_; }
const float &Momentum() const { return momentum_; }
const bool &IsTest() const { return is_test_; }
void SetNewScale(Tensor *new_scale) { new_scale_ = new_scale; }
void SetNewBias(Tensor *new_bias) { new_bias_ = new_bias; }
const Tensor *NewScale() const { return new_scale_; }
const Tensor *NewBias() const { return new_bias_; }
protected:
Tensor *input_;
Tensor *output_;
Tensor *filter_;
vector<int> strides_;
vector<int> paddings_;
vector<int> dilations_;
int groups;
Tensor *input_bias_;
Tensor *input_mean_;
Tensor *input_scale_;
Tensor *input_variance_;
float epsilon_;
float momentum_;
bool is_test_;
Tensor *new_bias_;
Tensor *new_scale_;
};
Print &operator<<(Print &printer, const FusionConvAddParam &conv_param);
#endif
#ifdef IM2SEQUENCE_OP #ifdef IM2SEQUENCE_OP
class Im2SequenceParam : public OpParam { class Im2SequenceParam : public OpParam {
public: public:
......
@@ -64,6 +64,7 @@ else ()
   set(TRANSPOSE_OP ON)
   set(FUSION_CONVADD_RELU_OP ON)
   set(FUSION_CONVADDBNRELU_OP ON)
+  set(FUSION_DWCONVBNRELU_OP ON)
   set(PRELU_OP ON)
   set(RESIZE_OP ON)
   set(SCALE_OP ON)
@@ -155,6 +156,9 @@ endif()
 if (FUSION_CONVADDBNRELU_OP)
   add_definitions(-DFUSION_CONVADDBNRELU_OP)
 endif()
+if (FUSION_DWCONVBNRELU_OP)
+  add_definitions(-DFUSION_DWCONVBNRELU_OP)
+endif()
 if (PRELU_OP)
   add_definitions(-DPRELU_OP)
 endif()
......