提交 6540b7cc 编写于 作者: Y yangfei

imp fusion_conv_add_prelu and fusion_conv_add_add_prelu op

上级 25830a06
......@@ -23,6 +23,8 @@ const char *G_OP_TYPE_BOX_CODER = "box_coder";
const char *G_OP_TYPE_CONCAT = "concat";
const char *G_OP_TYPE_ELEMENTWISE_ADD = "elementwise_add";
const char *G_OP_TYPE_FUSION_CONV_ADD_RELU = "fusion_conv_add_relu";
const char *G_OP_TYPE_FUSION_CONV_ADD_PRELU = "fusion_conv_add_prelu";
const char *G_OP_TYPE_FUSION_CONV_ADD_ADD_PRELU = "fusion_conv_add_add_prelu";
const char *G_OP_TYPE_FUSION_CONV_ADD_BN_RELU = "fusion_conv_add_bn_relu";
const char *G_OP_TYPE_FUSION_DWCONV_BN_RELU = "fusion_dwconv_bn_relu";
const char *G_OP_TYPE_FUSION_CONV_BN_RELU = "fusion_conv_bn_relu";
......@@ -83,6 +85,8 @@ std::unordered_map<
{G_OP_TYPE_RESHAPE, {{"X"}, {"Out"}}},
{G_OP_TYPE_DEPTHWISE_CONV, {{"Input"}, {"Output"}}},
{G_OP_TYPE_FUSION_CONV_ADD_RELU, {{"Input"}, {"Out"}}},
{G_OP_TYPE_FUSION_CONV_ADD_PRELU, {{"Input"}, {"Out"}}},
{G_OP_TYPE_FUSION_CONV_ADD_ADD_PRELU, {{"Input"}, {"Out"}}},
{G_OP_TYPE_IM2SEQUENCE, {{"X"}, {"Out"}}},
{G_OP_TYPE_DROPOUT, {{"X"}, {"Out"}}},
{G_OP_TYPE_FUSION_CONV_ADD_BN, {{"Input"}, {"Y"}}},
......
......@@ -85,6 +85,8 @@ extern const char *G_OP_TYPE_BOX_CODER;
extern const char *G_OP_TYPE_CONCAT;
extern const char *G_OP_TYPE_ELEMENTWISE_ADD;
extern const char *G_OP_TYPE_FUSION_CONV_ADD_RELU;
extern const char *G_OP_TYPE_FUSION_CONV_ADD_PRELU;
extern const char *G_OP_TYPE_FUSION_CONV_ADD_ADD_PRELU;
extern const char *G_OP_TYPE_FC;
extern const char *G_OP_TYPE_FUSION_CONV_ADD;
extern const char *G_OP_TYPE_FUSION_CONV_ADD_BN_RELU;
......
......@@ -100,8 +100,16 @@ void Node::Folder(
if (change->find(this->type_) != change->end()) {
auto change_pairs = (*change)[this->type_];
for (const auto &change_pair : change_pairs) {
op_desc->GetInputs()[change_pair.second] =
this->op_desc_->GetInputs()[change_pair.first];
std::map<std::string, int> f;
if (op_desc->GetInputs().find(change_pair.second) !=
op_desc->GetInputs().end()) {
for (auto value : this->op_desc_->GetInputs()[change_pair.first]) {
op_desc->GetInputs()[change_pair.second].push_back(value);
}
} else {
op_desc->GetInputs()[change_pair.second] =
this->op_desc_->GetInputs()[change_pair.first];
}
}
}
......
......@@ -95,11 +95,13 @@ std::shared_ptr<ProgramDesc> ProgramOptimize::FusionOptimize(
std::vector<std::shared_ptr<Node>> removed_nodes;
matcher->FolderNodes(match_node.get(), &removed_nodes);
for (int j = 0; j < removed_nodes.size(); ++j) {
auto removed_node = removed_nodes[j];
for (int k = removed_nodes.size() - 1; k >= 0; --k) {
auto removed_node = removed_nodes[k];
auto removed_ite =
std::find(nodes.begin(), nodes.end(), removed_node);
nodes.erase(removed_ite);
if (removed_ite != nodes.end()) {
nodes.erase(removed_ite);
}
}
}
}
......
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifdef FUSION_CONVADDADDPRELU_OP
#include "fusion_conv_add_add_prelu_op.h"
#include "operators/math/conv_func.h"
namespace paddle_mobile {
namespace operators {
template <typename Dtype, typename T>
void FusionConvAddAddPReluOp<Dtype, T>::InferShape() const {
auto in_dims = this->param_.Input()->dims();
auto filter_dims = this->param_.Filter()->dims();
const std::vector<int> &strides = this->param_.Strides();
std::vector<int> paddings = this->param_.Paddings();
int groups = this->param_.Groups();
std::vector<int> dilations = this->param_.Dilations();
PADDLE_MOBILE_ENFORCE((in_dims.size() == filter_dims.size() &&
dilations.size() == paddings.size() &&
paddings.size() == strides.size()),
"ConvParam is not suitable");
std::vector<int64_t> output_shape({in_dims[0], filter_dims[0]});
for (size_t i = 0; i < strides.size(); ++i) {
output_shape.push_back(
math::ConvOutputSize(in_dims[i + 2], filter_dims[i + 2], dilations[i],
paddings[i], strides[i]));
}
framework::DDim ddim = framework::make_ddim(output_shape);
this->param_.Output()->Resize(ddim);
}
} // namespace operators
} // namespace paddle_mobile
namespace ops = paddle_mobile::operators;
#ifdef PADDLE_MOBILE_CPU
REGISTER_OPERATOR_CPU(fusion_conv_add_add_prelu, ops::FusionConvAddAddPReluOp);
#endif
#ifdef PADDLE_MOBILE_MALI_GPU
#endif
#ifdef PADDLE_MOBILE_FPGA
REGISTER_OPERATOR_FPGA(fusion_conv_add_add_prelu, ops::FusionConvAddAddPReluOp);
#endif
#endif
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifdef FUSION_CONVADDADDPRELU_OP
#pragma once
#include <string>
#include <vector>
#include "framework/operator.h"
#include "framework/program/program-optimize/fusion_op_register.h"
#include "operators/kernel/conv_add_add_prelu_kernel.h"
#include "operators/op_param.h"
namespace paddle_mobile {
namespace operators {
class FusionConvAddAddPReluOpMatcher : public framework::FusionOpMatcher {
public:
FusionConvAddAddPReluOpMatcher() {
node_ = framework::Node(G_OP_TYPE_CONV);
node_ > std::make_shared<framework::Node>(G_OP_TYPE_ELEMENTWISE_ADD) >
std::make_shared<framework::Node>(G_OP_TYPE_ELEMENTWISE_ADD) >
std::make_shared<framework::Node>(G_OP_TYPE_PRELU);
}
void FolderNodes(
framework::Node *node,
std::vector<std::shared_ptr<framework::Node>> *removed_nodes) {
node->Folder(node_.Depth(), Type(),
{{G_OP_TYPE_ELEMENTWISE_ADD, {{"Y", "Y"}}},
{G_OP_TYPE_PRELU, {{"Alpha", "Alpha"}}}
},
removed_nodes);
}
std::string Type() { return G_OP_TYPE_FUSION_CONV_ADD_ADD_PRELU; }
};
template <typename DeviceType, typename T>
class FusionConvAddAddPReluOp
: public framework::OperatorWithKernel<
DeviceType, FusionConvAddAddPReluParam,
operators::ConvAddAddPReluKernel<DeviceType, T>> {
public:
FusionConvAddAddPReluOp(const string &type, const VariableNameMap &inputs,
const VariableNameMap &outputs,
const framework::AttributeMap &attrs,
std::shared_ptr<framework::Scope> scope)
: framework::OperatorWithKernel<
DeviceType, FusionConvAddAddPReluParam,
operators::ConvAddAddPReluKernel<DeviceType, T>>(
type, inputs, outputs, attrs, scope) {}
using framework::OperatorWithKernel<
DeviceType, FusionConvAddAddPReluParam,
operators::ConvAddAddPReluKernel<DeviceType, T>>::OperatorWithKernel;
void InferShape() const override;
protected:
};
#ifdef PADDLE_MOBILE_CPU
#ifndef CONV_ADD_ADD_PRELU_REGISTER
#define CONV_ADD_ADD_PRELU_REGISTER
static framework::FusionOpRegistrar fusion_conv_add_add_prelu_registrar(
new FusionConvAddAddPReluOpMatcher());
#endif
#endif
#ifdef PADDLE_MOBILE_MALI_GPU
#endif
#ifdef PADDLE_MOBILE_FPGA
#ifndef CONV_ADD_ADD_PRELU_REGISTER
#define CONV_ADD_ADD_PRELU_REGISTER
static framework::FusionOpRegistrar fusion_conv_add_add_prelu_registrar(
new FusionConvAddAddPReluOpMatcher());
#endif
#endif
} // namespace operators
} // namespace paddle_mobile
#ifdef PADDLE_MOBILE_CPU
USE_OP_CPU(fusion_conv_add_add_prelu);
#endif
#ifdef PADDLE_MOBILE_MALI_GPU
#endif
#ifdef PADDLE_MOBILE_FPGA
USE_OP_FPGA(fusion_conv_add_add_prelu);
#endif
#endif
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifdef FUSION_CONVADDPRELU_OP
#include "fusion_conv_add_prelu_op.h"
#include "operators/math/conv_func.h"
namespace paddle_mobile {
namespace operators {
template <typename Dtype, typename T>
void FusionConvAddPReluOp<Dtype, T>::InferShape() const {
auto in_dims = this->param_.Input()->dims();
auto filter_dims = this->param_.Filter()->dims();
const std::vector<int> &strides = this->param_.Strides();
std::vector<int> paddings = this->param_.Paddings();
int groups = this->param_.Groups();
std::vector<int> dilations = this->param_.Dilations();
PADDLE_MOBILE_ENFORCE((in_dims.size() == filter_dims.size() &&
dilations.size() == paddings.size() &&
paddings.size() == strides.size()),
"ConvParam is not suitable");
std::vector<int64_t> output_shape({in_dims[0], filter_dims[0]});
for (size_t i = 0; i < strides.size(); ++i) {
output_shape.push_back(
math::ConvOutputSize(in_dims[i + 2], filter_dims[i + 2], dilations[i],
paddings[i], strides[i]));
}
framework::DDim ddim = framework::make_ddim(output_shape);
this->param_.Output()->Resize(ddim);
}
} // namespace operators
} // namespace paddle_mobile
namespace ops = paddle_mobile::operators;
#ifdef PADDLE_MOBILE_CPU
REGISTER_OPERATOR_CPU(fusion_conv_add_prelu, ops::FusionConvAddPReluOp);
#endif
#ifdef PADDLE_MOBILE_MALI_GPU
#endif
#ifdef PADDLE_MOBILE_FPGA
REGISTER_OPERATOR_FPGA(fusion_conv_add_prelu, ops::FusionConvAddPReluOp);
#endif
#endif
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifdef FUSION_CONVADDPRELU_OP
#pragma once
#include <string>
#include <vector>
#include "framework/operator.h"
#include "framework/program/program-optimize/fusion_op_register.h"
#include "operators/kernel/conv_add_prelu_kernel.h"
#include "operators/op_param.h"
namespace paddle_mobile {
namespace operators {
class FusionConvAddPReluOpMatcher : public framework::FusionOpMatcher {
public:
FusionConvAddPReluOpMatcher() {
node_ = framework::Node(G_OP_TYPE_CONV);
node_ > std::make_shared<framework::Node>(G_OP_TYPE_ELEMENTWISE_ADD) >
std::make_shared<framework::Node>(G_OP_TYPE_PRELU);
}
void FolderNodes(
framework::Node *node,
std::vector<std::shared_ptr<framework::Node>> *removed_nodes) {
node->Folder(node_.Depth(), Type(),
{{G_OP_TYPE_ELEMENTWISE_ADD, {{"Y", "Y"}}},
{G_OP_TYPE_PRELU, {{"Alpha", "Alpha"}}}
},
removed_nodes);
}
std::string Type() { return G_OP_TYPE_FUSION_CONV_ADD_PRELU; }
};
template <typename DeviceType, typename T>
class FusionConvAddPReluOp : public framework::OperatorWithKernel<
DeviceType, FusionConvAddPReluParam,
operators::ConvAddPReluKernel<DeviceType, T>> {
public:
FusionConvAddPReluOp(const string &type, const VariableNameMap &inputs,
const VariableNameMap &outputs,
const framework::AttributeMap &attrs,
std::shared_ptr<framework::Scope> scope)
: framework::OperatorWithKernel<
DeviceType, FusionConvAddPReluParam,
operators::ConvAddPReluKernel<DeviceType, T>>(type, inputs, outputs,
attrs, scope) {}
using framework::OperatorWithKernel<
DeviceType, FusionConvAddPReluParam,
operators::ConvAddPReluKernel<DeviceType, T>>::OperatorWithKernel;
void InferShape() const override;
protected:
};
#ifdef PADDLE_MOBILE_CPU
#ifndef CONV_ADD_PRELU_REGISTER
#define CONV_ADD_PRELU_REGISTER
static framework::FusionOpRegistrar fusion_conv_add_prelu_registrar(
new FusionConvAddPReluOpMatcher());
#endif
#endif
#ifdef PADDLE_MOBILE_MALI_GPU
#endif
#ifdef PADDLE_MOBILE_FPGA
#ifndef CONV_ADD_PRELU_REGISTER
#define CONV_ADD_PRELU_REGISTER
static framework::FusionOpRegistrar fusion_conv_add_prelu_registrar(
new FusionConvAddPReluOpMatcher());
#endif
#endif
} // namespace operators
} // namespace paddle_mobile
#ifdef PADDLE_MOBILE_CPU
USE_OP_CPU(fusion_conv_add_prelu);
#endif
#ifdef PADDLE_MOBILE_MALI_GPU
#endif
#ifdef PADDLE_MOBILE_FPGA
USE_OP_FPGA(fusion_conv_add_prelu);
#endif
#endif
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifdef FUSION_CONVADDADDPRELU_OP
#include "operators/kernel/conv_add_add_prelu_kernel.h"
#include "operators/kernel/central-arm-func/conv_add_add_prelu_arm_func.h"
namespace paddle_mobile {
namespace operators {
template <>
bool ConvAddAddPReluKernel<CPU, float>::Init(
FusionConvAddAddPReluParam *param) {
return true;
}
template <>
void ConvAddAddPReluKernel<CPU, float>::Compute(
const FusionConvAddAddPReluParam &param) const {
ConvAddAddPReluCompute<float>(param);
}
template class ConvAddAddPReluKernel<CPU, float>;
} // namespace operators
} // namespace paddle_mobile
#endif
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifdef FUSION_CONVADDPRELU_OP
#include "operators/kernel/conv_add_prelu_kernel.h"
#include "operators/kernel/central-arm-func/conv_add_prelu_arm_func.h"
namespace paddle_mobile {
namespace operators {
template <>
bool ConvAddPReluKernel<CPU, float>::Init(FusionConvAddPReluParam *param) {
return true;
}
template <>
void ConvAddPReluKernel<CPU, float>::Compute(
const FusionConvAddPReluParam &param) const {
ConvAddPReluCompute<float>(param);
}
template class ConvAddPReluKernel<CPU, float>;
} // namespace operators
} // namespace paddle_mobile
#endif
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifdef FUSION_CONVADDADDPRELU_OP
#pragma once
#include <vector>
#include "operators/math/conv_func.h"
#include "operators/math/im2col.h"
#include "operators/math/math_function.h"
#include "operators/math/vol2col.h"
#include "operators/op_param.h"
namespace paddle_mobile {
namespace operators {
template <typename P>
void ConvAddAddPReluCompute(const FusionConvAddAddPReluParam &param) {
const Tensor *input = param.Input();
Tensor filter = *param.Filter();
Tensor bias = *param.Bias();
Tensor bias1 = *param.Bias1();
int axis = param.Axis();
Tensor *output = param.Output();
float *biase_data = bias.data<float>();
float *biase_data1 = bias1.data<float>();
int groups = param.Groups();
std::vector<int> strides = param.Strides();
std::vector<int> paddings = param.Paddings();
std::vector<int> dilations = param.Dilations();
Tensor aa = *param.InputAlpha();
float *p = aa.data<float>();
DLOG << "bias1";
DLOG << bias1;
std::string mode = param.Mode();
const int batch_size = static_cast<int>(input->dims()[0]);
std::vector<int64_t> filter_shape_vec(framework::vectorize(filter.dims()));
std::vector<int64_t> output_shape_vec(framework::vectorize(output->dims()));
size_t data_dim = filter_shape_vec.size() - 2;
std::vector<int64_t> col_shape_vec(1 + 2 * data_dim);
col_shape_vec[0] = input->dims()[1] / groups;
for (size_t j = 0; j < data_dim; ++j) {
col_shape_vec[j + 1] = filter_shape_vec[j + 2];
col_shape_vec[j + 1 + data_dim] = output_shape_vec[j + 2];
}
framework::DDim col_shape(framework::make_ddim(col_shape_vec));
framework::DDim col_matrix_shape =
framework::flatten_to_2d(col_shape, data_dim + 1);
bool is_expand =
math::IsExpand(filter_shape_vec, strides, paddings, dilations);
Tensor col;
Tensor col_matrix;
if (is_expand) {
col.mutable_data<float>(col_shape);
col_matrix.ShareDataWith(col);
col_matrix.Resize(col_matrix_shape);
}
framework::DDim input_shape = framework::slice_ddim(
input->dims(), 1, static_cast<int>(input->dims().size()));
framework::DDim filter_matrix_shape = {filter.dims()[0],
filter.numel() / filter.dims()[0]};
filter.Resize(filter_matrix_shape);
framework::DDim output_matrix_shape = {
output->dims()[1],
output->numel() / (output->dims()[0] * output->dims()[1])};
// convolution operator: im2col(or vol2col) + gemm
int in_step = static_cast<int>(input->dims()[1]) / groups;
int out_step = static_cast<int>(output->dims()[1]) / groups;
math::Vol2ColFunctor<CPU, float> vol2col;
math::Im2ColFunctor<math::ColFormat::kCFO, CPU, float> im2col;
for (int i = 0; i < batch_size; i++) {
Tensor in_batch = input->Slice(i, i + 1).Resize(input_shape);
Tensor out_batch = output->Slice(i, i + 1).Resize(output_matrix_shape);
Tensor bias1_batch = bias1.Slice(i, i + 1).Resize(output_matrix_shape);
for (int g = 0; g < groups; g++) {
Tensor in_slice = in_batch.Slice(g * in_step, (g + 1) * in_step);
if (!is_expand) {
col.ShareDataWith(in_slice);
col_matrix.ShareDataWith(col);
col_matrix.Resize(col_matrix_shape);
} else if (data_dim == 2U) {
// im2col
im2col(in_slice, dilations, strides,
std::vector<int>{paddings[0], paddings[1], paddings[0],
paddings[1]},
&col);
} else if (data_dim == 3U) {
// vol2col
vol2col(in_slice, dilations, strides, paddings, &col);
}
// gemm
Tensor out_slice = out_batch.Slice(g * out_step, (g + 1) * out_step);
Tensor filter_slice = filter.Slice(g * out_step, (g + 1) * out_step);
Tensor bias1_slice = bias1_batch.Slice(g * out_step, (g + 1) * out_step);
biase_data1 = bias1_slice.data<float>();
// int n = bias1_slice.dims()[0];
// int m = bias1_slice.dims()[1];
// for(int i=0;i<n*m;i++){
// if(biase_data1[i]!=0)
// DLOG<<biase_data1[i]<<",yangfei";
// }
// math::matmul<float>(filter_slice, false, col_matrix,
// false,
// static_cast<float>(1),
// &out_slice,
// static_cast<float>(1), true,
// biase_data);
math::matmulWithPRelu(filter_slice, false, col_matrix, false, &out_slice,
p, mode, biase_data, biase_data1);
}
}
}
} // namespace operators
} // namespace paddle_mobile
#endif
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifdef FUSION_CONVADDPRELU_OP
#pragma once
#include <vector>
#include "operators/math/conv_func.h"
#include "operators/math/im2col.h"
#include "operators/math/math_function.h"
#include "operators/math/vol2col.h"
#include "operators/op_param.h"
namespace paddle_mobile {
namespace operators {
template <typename P>
void ConvAddPReluCompute(const FusionConvAddPReluParam &param) {
const Tensor *input = param.Input();
Tensor filter = *param.Filter();
Tensor bias = *param.Bias();
// DLOG<<"yangfei";
// DLOG<<bias.dims();
int axis = param.Axis();
Tensor *output = param.Output();
float *biase_data = bias.data<float>();
int groups = param.Groups();
std::vector<int> strides = param.Strides();
std::vector<int> paddings = param.Paddings();
std::vector<int> dilations = param.Dilations();
Tensor aa = *param.InputAlpha();
float *p = aa.data<float>();
std::string mode = param.Mode();
const int batch_size = static_cast<int>(input->dims()[0]);
std::vector<int64_t> filter_shape_vec(framework::vectorize(filter.dims()));
std::vector<int64_t> output_shape_vec(framework::vectorize(output->dims()));
size_t data_dim = filter_shape_vec.size() - 2;
std::vector<int64_t> col_shape_vec(1 + 2 * data_dim);
col_shape_vec[0] = input->dims()[1] / groups;
for (size_t j = 0; j < data_dim; ++j) {
col_shape_vec[j + 1] = filter_shape_vec[j + 2];
col_shape_vec[j + 1 + data_dim] = output_shape_vec[j + 2];
}
framework::DDim col_shape(framework::make_ddim(col_shape_vec));
framework::DDim col_matrix_shape =
framework::flatten_to_2d(col_shape, data_dim + 1);
bool is_expand =
math::IsExpand(filter_shape_vec, strides, paddings, dilations);
Tensor col;
Tensor col_matrix;
if (is_expand) {
col.mutable_data<float>(col_shape);
col_matrix.ShareDataWith(col);
col_matrix.Resize(col_matrix_shape);
}
framework::DDim input_shape = framework::slice_ddim(
input->dims(), 1, static_cast<int>(input->dims().size()));
framework::DDim filter_matrix_shape = {filter.dims()[0],
filter.numel() / filter.dims()[0]};
filter.Resize(filter_matrix_shape);
framework::DDim output_matrix_shape = {
output->dims()[1],
output->numel() / (output->dims()[0] * output->dims()[1])};
// convolution operator: im2col(or vol2col) + gemm
int in_step = static_cast<int>(input->dims()[1]) / groups;
int out_step = static_cast<int>(output->dims()[1]) / groups;
math::Vol2ColFunctor<CPU, float> vol2col;
math::Im2ColFunctor<math::ColFormat::kCFO, CPU, float> im2col;
for (int i = 0; i < batch_size; i++) {
Tensor in_batch = input->Slice(i, i + 1).Resize(input_shape);
Tensor out_batch = output->Slice(i, i + 1).Resize(output_matrix_shape);
for (int g = 0; g < groups; g++) {
Tensor in_slice = in_batch.Slice(g * in_step, (g + 1) * in_step);
if (!is_expand) {
col.ShareDataWith(in_slice);
col_matrix.ShareDataWith(col);
col_matrix.Resize(col_matrix_shape);
} else if (data_dim == 2U) {
// im2col
im2col(in_slice, dilations, strides,
std::vector<int>{paddings[0], paddings[1], paddings[0],
paddings[1]},
&col);
} else if (data_dim == 3U) {
// vol2col
vol2col(in_slice, dilations, strides, paddings, &col);
}
// gemm
Tensor out_slice = out_batch.Slice(g * out_step, (g + 1) * out_step);
Tensor filter_slice = filter.Slice(g * out_step, (g + 1) * out_step);
// math::matmul<float>(filter_slice, false, col_matrix,
// false,
// static_cast<float>(1),
// &out_slice,
// static_cast<float>(1), true,
// biase_data);
math::matmulWithPRelu(filter_slice, false, col_matrix, false, &out_slice,
p, mode, biase_data, nullptr);
}
}
}
} // namespace operators
} // namespace paddle_mobile
#endif
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#ifdef FUSION_CONVADDADDPRELU_OP
#include <vector>
#include "framework/ddim.h"
#include "framework/operator.h"
#include "operators/math/conv_func.h"
#include "operators/math/im2col.h"
#include "operators/math/math_function.h"
#include "operators/math/vol2col.h"
#include "operators/op_param.h"
namespace paddle_mobile {
namespace operators {
using framework::DDim;
using framework::OpKernelBase;
template <typename DeviceType, typename T>
class ConvAddAddPReluKernel
: public OpKernelBase<DeviceType, FusionConvAddAddPReluParam> {
public:
void Compute(const FusionConvAddAddPReluParam &param) const;
bool Init(FusionConvAddAddPReluParam *param);
};
} // namespace operators
} // namespace paddle_mobile
#endif
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#ifdef FUSION_CONVADDPRELU_OP
#include <vector>
#include "framework/ddim.h"
#include "framework/operator.h"
#include "operators/math/conv_func.h"
#include "operators/math/im2col.h"
#include "operators/math/math_function.h"
#include "operators/math/vol2col.h"
#include "operators/op_param.h"
namespace paddle_mobile {
namespace operators {
using framework::DDim;
using framework::OpKernelBase;
template <typename DeviceType, typename T>
class ConvAddPReluKernel
: public OpKernelBase<DeviceType, FusionConvAddPReluParam> {
public:
void Compute(const FusionConvAddPReluParam &param) const;
bool Init(FusionConvAddPReluParam *param);
};
} // namespace operators
} // namespace paddle_mobile
#endif
......@@ -707,6 +707,26 @@ void InnerKernelWithBn(int mc, int nc, float alpha, const float *a,
}
}
// 分块矩阵乘法
void InnerKernelWithPRelu(int mc, int nc, const float *a, const float *b,
float *c, float *C, int ldc, float *p,
std::string mode, float *bias, float *bias1) {
#pragma omp parallel for
for (int j = 0; j < nc; j += NR) {
for (int i = 0; i < mc; i += MR) {
#if __aarch64__
// AddDot8x12(KC, a + i * KC, b + j * KC, c + i * NC + j, NC);
AddDot6x16(KC, a + i * KC, b + j * KC, c + i * NC + j, NC);
#else
// AddDot4x4(KC, a + i * KC, b + j * KC, c + i * NC + j, NC);
// AddDot4x8(KC, a + i * KC, b + j * KC, c + i * NC + j, NC);
AddDot6x8(KC, a + i * KC, b + j * KC, c + i * NC + j, NC);
#endif
}
}
WriteWithAddPRelu(mc, nc, c, C, ldc, p, mode, bias, bias1);
}
#if __ARM_NEON
#if __aarch64__
......@@ -987,6 +1007,81 @@ void WriteWithAddReluV1(int mc, int nc, float *c, float *C, int ldc,
}
}
}
// C = A * B + C,prelu(C)
void WriteWithAddPRelu(int mc, int nc, float *c, float *C, int ldc, float *p,
std::string mode, float *bias, float *bias1) {
int nc1 = nc / 4;
int _nc1 = nc % 4;
float *c_ptr, *C_ptr;
float32x4_t cv;
float32x4_t cv1;
float32x4_t biasv;
float32x4_t biasv1;
float32x4_t zero = vdupq_n_f32(0.0);
float32x4_t pv;
float *ptr = p;
for (int i = 0; i < mc; ++i) {
c_ptr = c + i * NC;
C_ptr = C + i * ldc;
biasv = vld1q_dup_f32(bias + i);
if (bias1 == nullptr) {
biasv1 = zero;
} else {
biasv1 = vld1q_dup_f32(bias1 + i);
}
for (int j = 0; j < nc1; ++j) {
cv = vld1q_f32(c_ptr);
cv = vaddq_f32(cv, biasv);
cv = vaddq_f32(cv, biasv1);
cv = vmaxq_f32(cv, zero);
cv1 = vminq_f32(cv, zero);
if (mode == "channel") {
cv1 = vmulq_n_f32(cv1, ptr[i]);
} else if (mode == "element") {
pv = vld1q_f32(ptr);
cv1 = vmulq_f32(cv1, pv);
ptr = ptr + 4;
} else {
cv1 = vmulq_n_f32(cv1, ptr[0]);
}
cv = vaddq_f32(cv, cv1);
vst1q_f32(C_ptr, cv);
c_ptr += 4;
C_ptr += 4;
}
if (_nc1 != 0) {
cv = vld1q_f32(c_ptr);
cv = vaddq_f32(cv, biasv);
cv = vaddq_f32(cv, biasv1);
cv = vmaxq_f32(cv, zero);
cv1 = vminq_f32(cv, zero);
if (mode == "channel") {
cv1 = vmulq_n_f32(cv1, ptr[i]);
} else if (mode == "element") {
pv = vld1q_f32(ptr);
cv1 = vmulq_f32(cv1, pv);
ptr = ptr + 4;
} else {
cv1 = vmulq_n_f32(cv1, ptr[0]);
}
cv = vaddq_f32(cv, cv1);
if (_nc1 >= 1) {
vst1q_lane_f32(C_ptr, cv, 0);
C_ptr++;
}
if (_nc1 >= 2) {
vst1q_lane_f32(C_ptr, cv, 1);
C_ptr++;
}
if (_nc1 >= 3) {
vst1q_lane_f32(C_ptr, cv, 2);
C_ptr++;
}
}
}
}
// C = A * B, batchnorm(C)
void WriteWithBn(int mc, int nc, float *c, float *C, int ldc, float *new_scale,
......@@ -1971,6 +2066,162 @@ void WriteWithAddReluV1(int mc, int nc, float *c, float *C, int ldc,
}
}
// C = A * B + C,prelu(C)
void WriteWithAddPRelu(int mc, int nc, float *c, float *C, int ldc, float *p,
std::string mode, float *bias, float *bias1) {
int nc1 = nc / 4;
int _nc1 = nc % 4;
float *c_ptr, *C_ptr;
float32x4_t cv;
float32x4_t cv1;
float32x4_t cv2;
float32x4_t biasv;
float32x4_t biasv1;
float32x4_t zero = vdupq_n_f32(0.0);
float32x4_t pv;
float *ptr = p;
float *tmp;
if (bias1 == nullptr) {
for (int i = 0; i < mc; ++i) {
c_ptr = c + i * NC;
C_ptr = C + i * ldc;
biasv = vld1q_dup_f32(bias + i);
for (int j = 0; j < nc1; ++j) {
cv = vld1q_f32(c_ptr);
cv = vaddq_f32(cv, biasv);
cv1 = vmaxq_f32(cv, zero);
cv2 = vminq_f32(cv, zero);
if (mode == "channel") {
cv2 = vmulq_n_f32(cv2, ptr[i]);
} else if (mode == "element") {
pv = vld1q_f32(ptr);
cv2 = vmulq_f32(cv2, pv);
ptr = ptr + 4;
} else {
cv1 = vmulq_n_f32(cv2, ptr[0]);
}
cv = vaddq_f32(cv1, cv2);
vst1q_f32(C_ptr, cv);
c_ptr += 4;
C_ptr += 4;
}
if (_nc1 != 0) {
cv = vld1q_f32(c_ptr);
cv = vaddq_f32(cv, biasv);
cv1 = vmaxq_f32(cv, zero);
cv2 = vminq_f32(cv, zero);
if (mode == "channel") {
cv2 = vmulq_n_f32(cv2, ptr[i]);
} else if (mode == "element") {
pv = vld1q_f32(ptr);
cv2 = vmulq_f32(cv2, pv);
ptr = ptr + 4;
} else {
cv2 = vmulq_n_f32(cv2, ptr[0]);
}
cv = vaddq_f32(cv1, cv2);
if (_nc1 >= 1) {
vst1q_lane_f32(C_ptr, cv, 0);
if (mode == "element") {
ptr++;
}
C_ptr++;
}
if (_nc1 >= 2) {
vst1q_lane_f32(C_ptr, cv, 1);
if (mode == "element") {
ptr++;
}
C_ptr++;
}
if (_nc1 >= 3) {
vst1q_lane_f32(C_ptr, cv, 2);
if (mode == "element") {
ptr++;
}
C_ptr++;
}
}
}
} else {
for (int i = 0; i < mc; ++i) {
c_ptr = c + i * NC;
C_ptr = C + i * ldc;
tmp = bias1 + i * ldc;
biasv = vld1q_dup_f32(bias + i);
for (int j = 0; j < nc1; ++j) {
biasv1 = vld1q_f32(tmp);
tmp += 4;
cv = vld1q_f32(c_ptr);
cv = vaddq_f32(cv, biasv);
cv = vaddq_f32(cv, biasv1);
cv1 = vmaxq_f32(cv, zero);
cv2 = vminq_f32(cv, zero);
if (mode == "channel") {
cv2 = vmulq_n_f32(cv2, ptr[i]);
} else if (mode == "element") {
pv = vld1q_f32(ptr);
cv2 = vmulq_f32(cv2, pv);
ptr = ptr + 4;
} else {
cv2 = vmulq_n_f32(cv2, ptr[0]);
}
cv = vaddq_f32(cv1, cv2);
vst1q_f32(C_ptr, cv);
c_ptr += 4;
C_ptr += 4;
}
if (_nc1 != 0) {
biasv1 = vld1q_f32(tmp);
tmp += 4;
cv = vld1q_f32(c_ptr);
cv = vaddq_f32(cv, biasv);
cv = vaddq_f32(cv, biasv1);
cv1 = vmaxq_f32(cv, zero);
cv2 = vminq_f32(cv, zero);
if (mode == "channel") {
cv2 = vmulq_n_f32(cv2, ptr[i]);
} else if (mode == "element") {
pv = vld1q_f32(ptr);
cv2 = vmulq_f32(cv2, pv);
} else {
cv2 = vmulq_n_f32(cv2, ptr[0]);
}
cv = vaddq_f32(cv1, cv2);
if (_nc1 >= 1) {
vst1q_lane_f32(C_ptr, cv, 0);
C_ptr++;
tmp++;
if (mode == "element") {
ptr++;
}
}
if (_nc1 >= 2) {
vst1q_lane_f32(C_ptr, cv, 1);
C_ptr++;
tmp++;
if (mode == "element") {
ptr++;
}
}
if (_nc1 >= 3) {
vst1q_lane_f32(C_ptr, cv, 2);
C_ptr++;
tmp++;
if (mode == "element") {
ptr++;
}
}
}
}
}
}
// C = A * B, batchnorm(C)
void WriteWithBn(int mc, int nc, float *c, float *C, int ldc, float *scale,
float *bias) {
......@@ -2512,6 +2763,8 @@ void WriteWithAddRelu(int mc, int nc, float *c, float *C, int ldc) {}
void WriteWithAddReluV1(int mc, int nc, float *c, float *C, int ldc,
float *bias) {}
void WriteWithAddPRelu(int mc, int nc, float *c, float *C, int ldc, float *p,
std::string mode, float *bias, float *bias1) {}
void WriteWithBn(int mc, int nc, float *c, float *C, int ldc, float *new_scale,
float *new_bias) {}
......@@ -2648,6 +2901,74 @@ void SgemmWithBn(int m, int n, int k, float alpha, const float *A, int lda,
paddle_mobile::memory::Free(zero);
}
void SgemmWithPRelu(int m, int n, int k, const float *A, int lda,
const float *B, int ldb, float *C, int ldc, float *p,
std::string mode, float *bias, float *bias1) {
// L1 data cache is 32 kib (Per Contex-A57, Contex-A72, Contex-A73)
// L2 cache is 0.5~4 Mib (Contex-A72 cluster)
int L1 = 32 * 1024;
int L2 = 0.5 * 1024 * 1024;
KC = k;
MC = L1 / (KC * sizeof(float));
NC = L2 / (KC * sizeof(float));
// make sure MC is multiple of MR, and NC is multiple of NR
int mblock_num = (m + MC - 1) / MC;
MC = (m + mblock_num - 1) / mblock_num;
MC = (MC + MR - 1) / MR * MR;
// DLOG << "mblock_num = " << mblock_num << ", MC = " << MC << "\n";
int nblock_num = (n + NC - 1) / NC;
NC = (n + nblock_num - 1) / nblock_num;
NC = (NC + NR - 1) / NR * NR;
// DLOG << "nblock_num = " << nblock_num << ", NC = " << NC << "\n";
packedA = static_cast<float *>(
paddle_mobile::memory::Alloc(sizeof(float) * MC * KC));
packedB = static_cast<float *>(
paddle_mobile::memory::Alloc(sizeof(float) * KC * NC));
packedC = static_cast<float *>(
paddle_mobile::memory::Alloc(sizeof(float) * MC * NC));
zero = static_cast<float *>(paddle_mobile::memory::Alloc(sizeof(float) * KC));
for (int l = 0; l < KC; ++l) {
zero[l] = 0;
}
int mc, nc;
for (int j = 0; j < n; j += NC) {
nc = s_min(n - j, NC);
#if __aarch64__
// PackMatrixB_12c(KC, nc, nc % NR, &B(0, j), ldb, packedB);
PackMatrixB_16c(KC, nc, nc % NR, &B(0, j), ldb, packedB);
#else
PackMatrixB_8c(KC, nc, nc % NR, &B(0, j), ldb, packedB);
#endif
for (int i = 0; i < m; i += MC) {
mc = s_min(m - i, MC);
#if __aarch64__
PackMatrixA_6r(mc, KC, mc % MR, &A(i, 0), lda, packedA);
// PackMatrixA_8r(mc, KC, mc % MR, &A(i, 0), lda, packedA);
#else
PackMatrixA_6r(mc, KC, mc % MR, &A(i, 0), lda, packedA);
#endif
if (bias1 == nullptr) {
InnerKernelWithPRelu(mc, nc, packedA, packedB, packedC, &C(i, j), ldc,
p + i, mode, bias + i, nullptr);
} else {
InnerKernelWithPRelu(mc, nc, packedA, packedB, packedC, &C(i, j), ldc,
p + i, mode, bias + i, bias1 + i * ldc + j);
}
}
}
paddle_mobile::memory::Free(packedA);
paddle_mobile::memory::Free(packedB);
paddle_mobile::memory::Free(packedC);
paddle_mobile::memory::Free(zero);
}
// 32位 float 矩阵乘法
void Sgemm_omp(int m, int n, int k, float alpha, const float *A, int lda,
const float *B, int ldb, float beta, float *C, int ldc,
......
......@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "common/log.h"
// 矩阵取值运算宏,假设矩阵按行存储
#define A(i, j) A[(i)*lda + (j)]
......@@ -79,6 +80,9 @@ void InnerKernelWithBias(int mc, int nc, float alpha, const float *a,
void InnerKernelWithBn(int mc, int nc, float alpha, const float *a,
const float *b, float beta, float *c, float *C, int ldc,
bool relu, float *new_scale, float *new_bias);
void InnerKernelWithPRelu(int mc, int nc, const float *a, const float *b,
float *c, float *C, int ldc, float *p,
std::string mode, float *bias, float *bias1);
/*
// 向量矩阵乘法 (M = 1)
void VectorKernel(int m, int n, int k, float alpha, const float *A, int lda,
......@@ -108,6 +112,9 @@ void WriteWithAdd(int mc, int nc, float *c, float *C, int ldc);
void WriteWithAddV1(int mc, int nc, float *c, float *C, int ldc, float *bias);
// C = A * B + C, relu(C)
void WriteWithAddRelu(int mc, int nc, float *c, float *C, int ldc);
// C = A * B + C,prelu(C)
void WriteWithAddPRelu(int mc, int nc, float *c, float *C, int ldc, float *p,
std::string mode, float *bias, float *bias1);
// C = A * B + bias ,relu(C)
void WriteWithAddReluV1(int mc, int nc, float *c, float *C, int ldc,
float *bias);
......@@ -146,6 +153,10 @@ void SgemmWithBn(int m, int n, int k, float alpha, const float *A, int lda,
const float *B, int ldb, float beta, float *C, int ldc,
bool relu, float *new_scale, float *new_bias);
void SgemmWithPRelu(int m, int n, int k, const float *A, int lda,
const float *B, int ldb, float *C, int ldc, float *p,
std::string mode, float *bias, float *bias1);
// 32位 float 矩阵乘法(openmp 多线程版本)
void Sgemm_omp(int m, int n, int k, float alpha, const float *A, int lda,
const float *B, int ldb, float beta, float *C, int ldc,
......
......@@ -87,6 +87,37 @@ void matmulWithBn<float>(const framework::Tensor &matrix_a, bool trans_a,
new_bias->data<float>() + group);
#endif
}
void matmulWithPRelu(const framework::Tensor &matrix_a, bool trans_a,
const framework::Tensor &matrix_b, bool trans_b,
framework::Tensor *matrix_out, float *p, std::string mode,
float *bias, float *bias1) {
auto dim_a = matrix_a.dims();
auto dim_b = matrix_b.dims();
auto dim_out = matrix_out->dims();
// PADDLE_ENFORCE(dim_a.size() == 2 && dim_b.size() == 2 &&
// dim_out.size() ==
// 2,
// "The input and output of matmul be matrix");
//
// PADDLE_ENFORCE(platform::is_cpu_place(matrix_a.place()) &&
// platform::is_cpu_place(matrix_b.place())
// &&
// platform::is_cpu_place(matrix_out->place()),
// "Matrix must all be in CPUPlace");
int M = dim_out[0];
int N = dim_out[1];
int K = (!trans_a) ? dim_a[1] : dim_a[0];
#ifdef _OPENMP
Sgemm_omp(M, N, K, alpha, matrix_a.data<float>(), K, matrix_b.data<float>(),
N, beta, matrix_out->data<float>(), N, relu, bias);
#else
SgemmWithPRelu(M, N, K, matrix_a.data<float>(), K, matrix_b.data<float>(), N,
matrix_out->data<float>(), N, p, mode, bias, bias1);
#endif
}
} // namespace math
} // namespace operators
......
......@@ -33,6 +33,11 @@ void matmulWithBn(const framework::Tensor &matrix_a, bool trans_a,
framework::Tensor *matrix_out, T beta, bool relu,
framework::Tensor *new_scale, framework::Tensor *new_bias,
int group);
void matmulWithPRelu(const framework::Tensor &matrix_a, bool trans_a,
const framework::Tensor &matrix_b, bool trans_b,
framework::Tensor *matrix_out, float *p, std::string mode,
float *bias, float *bias1);
} // namespace math
} // namespace operators
} // namespace paddle_mobile
......@@ -92,6 +92,10 @@ class OpParam {
static T *InputYFrom(const VariableNameMap &inputs, const Scope &scope) {
return GetVarValue<T>("Y", inputs, scope);
}
template <typename T>
static T *InputYFrom1(const VariableNameMap &inputs, const Scope &scope) {
return GetVarValue1<T>("Y", inputs, scope);
}
template <typename T>
static T *InputZFrom(const VariableNameMap &inputs, const Scope &scope) {
......@@ -217,6 +221,19 @@ class OpParam {
return nullptr;
}
}
template <typename T>
static T *GetVarValue1(const string &key, const VariableNameMap &var_map,
const Scope &scope) {
PADDLE_MOBILE_ENFORCE(var_map.count(key) > 0,
"%s is not contained in var_map", key.c_str())
auto var_vec = var_map.at(key);
if (!var_vec.empty()) {
auto var = scope.FindVar(var_vec[1]);
return var->GetMutable<T>();
} else {
return nullptr;
}
}
template <typename T>
static vector<T *> GetMultiVarValue(const string &key,
......@@ -1174,6 +1191,48 @@ class FusionConvAddReluParam : public FusionConvAddParam<DeviceType> {
};
#endif
#ifdef FUSION_CONVADDPRELU_OP
class FusionConvAddPReluParam : public FusionConvAddParam {
public:
FusionConvAddPReluParam(const VariableNameMap &inputs,
const VariableNameMap &outputs,
const AttributeMap &attrs, const Scope &scope)
: FusionConvAddParam(inputs, outputs, attrs, scope) {
alpha_ = InputAlphaFrom<LoDTensor>(inputs, scope);
mode_ = GetAttr<std::string>("mode", attrs);
framework::DDim dims = alpha_->dims();
}
const Tensor *InputAlpha() const { return alpha_; }
const std::string &Mode() const { return mode_; }
private:
Tensor *alpha_;
std::string mode_;
};
#endif
#ifdef FUSION_CONVADDADDPRELU_OP
class FusionConvAddAddPReluParam : public FusionConvAddParam {
public:
FusionConvAddAddPReluParam(const VariableNameMap &inputs,
const VariableNameMap &outputs,
const AttributeMap &attrs, const Scope &scope)
: FusionConvAddParam(inputs, outputs, attrs, scope) {
bias1_ = InputYFrom1<LoDTensor>(inputs, scope);
alpha_ = InputAlphaFrom<LoDTensor>(inputs, scope);
mode_ = GetAttr<std::string>("mode", attrs);
framework::DDim dims = alpha_->dims();
}
const Tensor *InputAlpha() const { return alpha_; }
const std::string &Mode() const { return mode_; }
const Tensor *Bias1() const { return bias1_; }
private:
Tensor *alpha_;
std::string mode_;
Tensor *bias1_;
};
#endif
#ifdef FUSION_CONVADDBNRELU_OP
template <typename Dtype>
class FusionConvAddBNReluParam : public OpParam {
......
......@@ -124,6 +124,7 @@ if(NOT FOUND_MATCH)
set(DEPTHWISECONV_OP ON)
set(ELEMENTWISEADD_OP ON)
set(FUSION_CONVADD_OP ON)
set(FUSION_CONVADDPRELU_OP OFF)
set(FUSION_CONVADDRELU_OP ON)
set(FUSION_FC_OP ON)
set(LRN_OP ON)
......@@ -137,6 +138,7 @@ if(NOT FOUND_MATCH)
set(SOFTMAX_OP ON)
set(TRANSPOSE_OP ON)
set(FUSION_CONVADDBNRELU_OP ON)
set(FUSION_CONVADDADDPRELU_OP ON)
set(FUSION_DWCONVBNRELU_OP ON)
set(FUSION_CONVBNRELU_OP ON)
set(PRELU_OP ON)
......@@ -192,6 +194,12 @@ endif()
if (FUSION_CONVADDRELU_OP)
add_definitions(-DFUSION_CONVADDRELU_OP)
endif()
if (FUSION_CONVADDPRELU_OP)
add_definitions(-DFUSION_CONVADDPRELU_OP)
endif()
if (FUSION_CONVADDADDPRELU_OP)
add_definitions(-DFUSION_CONVADDADDPRELU_OP)
endif()
if (FUSION_FC_OP)
add_definitions(-DFUSION_FC_OP)
endif()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册