Commit 74478de5 authored by eclipsess

conflict

Parent 25e97deb
......@@ -23,8 +23,9 @@ const std::string G_OP_TYPE_BOX_CODER = "box_coder";
const std::string G_OP_TYPE_CONCAT = "concat";
const std::string G_OP_TYPE_ELEMENTWISE_ADD = "elementwise_add";
const std::string G_OP_TYPE_FUSION_CONV_ADD_RELU = "fusion_conv_add_relu";
const std::string G_OP_TYPE_FC = "fc";
const std::string G_OP_TYPE_CONV_ADD = "conv_add";
const std::string G_OP_TYPE_FUSION_CONV_ADD_BN_RELU = "fusion_conv_add_bn_relu";
const std::string G_OP_TYPE_FC = "fusion_fc";
const std::string G_OP_TYPE_FUSION_CONV_ADD = "fusion_conv_add";
const std::string G_OP_TYPE_LRN = "lrn";
const std::string G_OP_TYPE_MUL = "mul";
const std::string G_OP_TYPE_MULTICLASS_NMS = "multiclass_nms";
......@@ -44,7 +45,7 @@ std::unordered_map<
std::string, std::pair<std::vector<std::string>, std::vector<std::string>>>
op_input_output_key = {
{G_OP_TYPE_CONV, {{"Input"}, {"Output"}}},
{G_OP_TYPE_CONV_ADD, {{"Input"}, {"Out"}}},
{G_OP_TYPE_FUSION_CONV_ADD, {{"Input"}, {"Out"}}},
{G_OP_TYPE_RELU, {{"X"}, {"Out"}}},
{G_OP_TYPE_SOFTMAX, {{"X"}, {"Out"}}},
{G_OP_TYPE_MUL, {{"X"}, {"Out"}}},
......@@ -59,6 +60,8 @@ std::unordered_map<
{G_OP_TYPE_TRANSPOSE, {{"X"}, {"Out"}}},
{G_OP_TYPE_BOX_CODER,
{{"PriorBox", "PriorBoxVar", "TargetBox"}, {"OutputBox"}}},
{G_OP_TYPE_FUSION_CONV_ADD_BN_RELU,
{{"Input"}, {"Out"}}},
{G_OP_TYPE_PRIOR_BOX, {{"Image", "Input"}, {"Boxes", "Variances"}}},
{G_OP_TYPE_MULTICLASS_NMS, {{"BBoxes", "Scores"}, {"Out"}}},
{G_OP_TYPE_FC, {{"X", "Y", "Z"}, {"Out"}}},
......
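A note on the table above: each op type maps to its canonical input and output variable names. A minimal sketch of how the table is consumed (the lookup itself is not part of this diff):

std::pair<std::vector<std::string>, std::vector<std::string>> keys =
    op_input_output_key.at(G_OP_TYPE_FUSION_CONV_ADD);
// keys.first  == {"Input"} : variable names the fused op reads
// keys.second == {"Out"}   : variable names the fused op writes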
......@@ -79,7 +79,9 @@ extern const std::string G_OP_TYPE_CONCAT;
extern const std::string G_OP_TYPE_ELEMENTWISE_ADD;
extern const std::string G_OP_TYPE_FUSION_CONV_ADD_RELU;
extern const std::string G_OP_TYPE_FC;
extern const std::string G_OP_TYPE_CONV_ADD;
extern const std::string G_OP_TYPE_FUSION_CONV_ADD;
extern const std::string G_OP_TYPE_FUSION_CONV_ADD_BN_RELU;
extern const std::string G_OP_TYPE_LRN;
extern const std::string G_OP_TYPE_MUL;
extern const std::string G_OP_TYPE_MULTICLASS_NMS;
......
......@@ -63,7 +63,7 @@ class OperatorBase {
std::vector<string> GetOutKeys() const;
virtual void RunImpl() const = 0;
virtual void Init() const = 0;
virtual void Init() = 0;
/*
* @b op: inputs required by the computation, e.g. the previous layer's output or the convolution kernel
* */
......@@ -117,8 +117,8 @@ class OperatorWithKernel : public OperatorBase<Dtype> {
virtual void InferShape() const = 0;
void Init() const {
PADDLE_MOBILE_ENFORCE(kernel_.Init(param_), " %s kernel init failed",
void Init() {
PADDLE_MOBILE_ENFORCE(kernel_.Init(&param_), " %s kernel init failed",
this->type_.c_str());
}
......@@ -146,7 +146,7 @@ class OpKernelBase {
}
#endif
virtual void Compute(const P &para) const = 0;
virtual bool Init(const P &para) const { return true; };
virtual bool Init(P *para) { return true; };
virtual ~OpKernelBase() = default;
private:
......
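The signature change above is the heart of this commit: Init used to take a const reference and be const itself, so a kernel could not precompute anything during initialization. Taking P* lets Init write derived tensors back into the param, which ConvAddBNReluKernel below relies on. A sketch of the new contract (SomeKernel and SomeParam are hypothetical names):

template <>
bool SomeKernel<CPU, float>::Init(SomeParam *param) {
  // Fold constants once here, stash them in the param, and let every
  // subsequent Compute() call reuse them, e.g.:
  // param->SetNewScale(precomputed_scale);
  return true;
}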
......@@ -32,7 +32,7 @@ class FeedOp : public framework::OperatorBase<DeviceType> {
param_(inputs, outputs, attrs, *scope) {}
void RunImpl() const { param_.Out()->ShareDataWith(*param_.InputX()); }
void Init() const {}
void Init() {}
void InferShape() const {
auto out_dims = param_.Out()->dims();
......
......@@ -33,7 +33,7 @@ class FetchOp : public framework::OperatorBase<DeviceType> {
param_(inputs, outputs, attrs, *scope) {}
void RunImpl() const { param_.Out()->ShareDataWith(*param_.InputX()); }
void Init() const {}
void Init() {}
void InferShape() const {
auto x_dims = param_.InputX()->dims();
......
......@@ -50,8 +50,8 @@ template class FusionConvAddOp<CPU, float>;
namespace ops = paddle_mobile::operators;
#ifdef PADDLE_MOBILE_CPU
USE_OP_CPU(conv_add);
REGISTER_OPERATOR_CPU(conv_add, ops::FusionConvAddOp);
USE_OP_CPU(fusion_conv_add);
REGISTER_OPERATOR_CPU(fusion_conv_add, ops::FusionConvAddOp);
#endif
#ifdef PADDLE_MOBILE_MALI_GPU
USE_OP_MALI_GPU(conv_add);
......
......@@ -42,7 +42,7 @@ class FusionConvAddMatcher : public framework::FusionOpMatcher {
{{G_OP_TYPE_ELEMENTWISE_ADD, {{"Y", "Y"}}}}, removed_nodes);
}
std::string Type() { return G_OP_TYPE_CONV_ADD; }
std::string Type() { return G_OP_TYPE_FUSION_CONV_ADD; }
};
template <typename DeviceType, typename T>
......@@ -68,11 +68,11 @@ class FusionConvAddOp : public framework::OperatorWithKernel<
#ifdef PADDLE_MOBILE_CPU
#ifndef CONV_ADD_REGISTER
static framework::FusionOpRegistrar convadd_registrar(
new FusionConvAddMatcher());
#define CONV_ADD_REGISTER
#endif
//#ifndef CONV_ADD_REGISTER
//static framework::FusionOpRegistrar convadd_registrar(
// new FusionConvAddMatcher());
//#define CONV_ADD_REGISTER
//#endif
#endif
......
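Note that the static FusionOpRegistrar for FusionConvAddMatcher is commented out here, so the standalone conv+add fusion pass appears to be disabled in this commit; only the guard-macro scaffolding is kept.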
//
// Created by Yang,Sui on 2018/6/28.
//
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifdef FUSION_CONVADDBNRELU_OP
#include "operators/fusion_conv_add_bn_relu_op.h"
#include "operators/math/conv_func.h"
namespace paddle_mobile {
namespace operators {
template <typename Dtype, typename T>
void FusionConvAddBNReluOp<Dtype, T>::InferShape() const {
auto in_dims = this->param_.Input()->dims();
auto filter_dims = this->param_.Filter()->dims();
const std::vector<int> &strides = this->param_.Strides();
std::vector<int> paddings = this->param_.Paddings();
int groups = this->param_.Groups();
std::vector<int> dilations = this->param_.Dilations();
PADDLE_MOBILE_ENFORCE((in_dims.size() == filter_dims.size() &&
dilations.size() == paddings.size() &&
paddings.size() == strides.size()),
"ConvParam is not suitable");
std::vector<int64_t> output_shape({in_dims[0], filter_dims[0]});
for (size_t i = 0; i < strides.size(); ++i) {
output_shape.push_back(
math::ConvOutputSize(in_dims[i + 2], filter_dims[i + 2], dilations[i],
paddings[i], strides[i]));
}
framework::DDim ddim = framework::make_ddim(output_shape);
this->param_.Output()->Resize(ddim);
}
template class FusionConvAddBNReluOp<CPU, float>;
} // namespace operators
} // namespace paddle_mobile
namespace ops = paddle_mobile::operators;
#ifdef PADDLE_MOBILE_CPU
USE_OP_CPU(fusion_conv_add_bn_relu);
REGISTER_OPERATOR_CPU(fusion_conv_add_bn_relu, ops::FusionConvAddBNReluOp);
#endif
#ifdef PADDLE_MOBILE_MALI_GPU
#endif
#ifdef PADDLE_MOBILE_FPGA
#endif
#endif
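InferShape above delegates the per-dimension arithmetic to math::ConvOutputSize, whose body is not shown in this diff; it implements the standard convolution output-size formula, sketched here:

// Standard conv output-size arithmetic (sketch; the actual
// math::ConvOutputSize body is assumed to match):
inline int ConvOutputSize(int input_size, int filter_size, int dilation,
                          int padding, int stride) {
  const int dkernel = dilation * (filter_size - 1) + 1;  // dilated kernel extent
  return (input_size + 2 * padding - dkernel) / stride + 1;
}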
//
// Created by Yang,Sui on 2018/6/28.
//
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#ifndef PADDLE_MOBILE_FUSION_CONV_ADD_BN_RELU_OP_H
#define PADDLE_MOBILE_FUSION_CONV_ADD_BN_RELU_OP_H
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
#endif //PADDLE_MOBILE_FUSION_CONV_ADD_BN_RELU_OP_H
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#define FUSION_CONVADDBNRELU_OP
#ifdef FUSION_CONVADDBNRELU_OP
#pragma once
#include <string>
#include <vector>
#include "framework/operator.h"
#include "framework/program/program-optimize/fusion_op_register.h"
#include "op_param.h"
#include "operators/kernel/conv_add_bn_relu_kernel.h"
namespace paddle_mobile {
namespace operators {
using std::string;
using std::vector;
class FusionConvAddBNReluMatcher : public framework::FusionOpMatcher {
public:
FusionConvAddBNReluMatcher() {
node_ = framework::Node(G_OP_TYPE_CONV);
node_ > std::make_shared<framework::Node>(G_OP_TYPE_ELEMENTWISE_ADD) >
std::make_shared<framework::Node>(G_OP_TYPE_BATCHNORM) >
std::make_shared<framework::Node>(G_OP_TYPE_RELU);
}
void FolderNodes(
framework::Node *node,
std::vector<std::shared_ptr<framework::Node>> *removed_nodes) {
vector<std::shared_ptr<framework::OpDesc>> origin_descs =
node->OpDescs(node_.Depth());
node->Folder(node_.Depth(), Type(),
{ {G_OP_TYPE_ELEMENTWISE_ADD, {{"Y", "Y"}}},
{G_OP_TYPE_BATCHNORM, {{"Scale", "Scale"},
{"Mean", "Mean"},
{"Bias", "Bias"},
{"Variance", "Variance"}}}}, removed_nodes);
}
std::string Type() { return G_OP_TYPE_FUSION_CONV_ADD_BN_RELU; }
};
template <typename DeviceType, typename T>
class FusionConvAddBNReluOp
: public framework::OperatorWithKernel<
DeviceType, FusionConvAddBNReluParam,
operators::ConvAddBNReluKernel<DeviceType, T>> {
public:
FusionConvAddBNReluOp(const string &type, const VariableNameMap &inputs,
const VariableNameMap &outputs,
const framework::AttributeMap &attrs,
std::shared_ptr<framework::Scope> scope)
: framework::OperatorWithKernel<
DeviceType, FusionConvAddBNReluParam,
operators::ConvAddBNReluKernel<DeviceType, T>>(
type, inputs, outputs, attrs, scope) {}
using framework::OperatorWithKernel<
DeviceType, FusionConvAddBNReluParam,
operators::ConvAddBNReluKernel<DeviceType, T>>::OperatorWithKernel;
void InferShape() const override;
protected:
};
#ifdef PADDLE_MOBILE_CPU
//#ifndef FUSION_CONV_ADD_BN_RELU_REGISTER
//static framework::FusionOpRegistrar fusion_conv_add_bn_relu_registrar(
// new FusionConvAddBNReluMatcher());
//#define FUSION_CONV_ADD_BN_RELU_REGISTER
//#endif
#endif
#ifdef PADDLE_MOBILE_MALI_GPU
#ifndef FUSION_CONV_ADD_BN_RELU_REGISTER
static framework::FusionOpRegistrar fusion_conv_add_bn_relu_registrar(
new FusionConvAddBNReluMatcher());
#define FUSION_CONV_ADD_BN_RELU_REGISTER
#endif
#endif
#ifdef PADDLE_MOBILE_FPGA
#endif
} // namespace operators
} // namespace paddle_mobile
#endif
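For orientation, the matcher above recognizes a four-op chain, and FolderNodes() collapses it into a single node while remapping the listed inputs onto the fused op:

// Matched subgraph (inferred from the node_ chain and Folder() mappings):
//   conv2d -> elementwise_add(Y) -> batch_norm(Scale, Bias, Mean, Variance) -> relu
// Fused result:
//   fusion_conv_add_bn_relu(Input, Filter, Y, Scale, Bias, Mean, Variance) -> Out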
......@@ -55,8 +55,8 @@ template class FusionFcOp<CPU, float>;
namespace ops = paddle_mobile::operators;
#ifdef PADDLE_MOBILE_CPU
USE_OP_CPU(fc);
REGISTER_OPERATOR_CPU(fc, ops::FusionFcOp);
USE_OP_CPU(fusion_fc);
REGISTER_OPERATOR_CPU(fusion_fc, ops::FusionFcOp);
#endif
#ifdef PADDLE_MOBILE_MALI_GPU
USE_OP_MALI_GPU(fc);
......
......@@ -21,7 +21,7 @@ namespace paddle_mobile {
namespace operators {
template <>
bool BatchNormKernel<CPU, float>::Init(const BatchNormParam &para) const {
bool BatchNormKernel<CPU, float>::Init(BatchNormParam *param) const {
return true;
}
......
......@@ -111,7 +111,7 @@ void DecodeCenterSize(const framework::Tensor& target_box,
}
template <>
bool BoxCoderKernel<CPU, float>::Init(const BoxCoderParam& para) const {
bool BoxCoderKernel<CPU, float>::Init(BoxCoderParam* param) const {
return true;
}
......
......@@ -53,7 +53,7 @@ class ConcatFunctor {
};
template <>
bool ConcatKernel<CPU, float>::Init(const ConcatParam &para) const {
bool ConcatKernel<CPU, float>::Init(ConcatParam *param) const {
return true;
}
......
//
// Created by Yang,Sui on 2018/6/28.
//
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifdef FUSION_CONVADDBNRELU_OP
#include "operators/kernel/conv_add_bn_relu_kernel.h"
#include "operators/kernel/central-arm-func/conv_add_bn_relu_func.h"
namespace paddle_mobile {
namespace operators {
template <>
bool ConvAddBNReluKernel<CPU, float>::Init(
FusionConvAddBNReluParam *param) const {
const Tensor *mean = (*param).InputMean();
const Tensor *variance = (*param).InputVariance();
const Tensor *scale = (*param).InputScale();
const Tensor *bias = (*param).InputBias();
const float epsilon = (*param).Epsilon();
auto mean_ptr = mean->data<float>();
auto variance_ptr = variance->data<float>();
auto scale_ptr = scale->data<float>();
auto bias_ptr = bias->data<float>();
const int C = mean->numel();
float inv_std_ptr[C];
for (int i = 0; i < C; i++) {
inv_std_ptr[i] =
1 / static_cast<float>(pow((variance_ptr[i] + epsilon), 0.5));
}
Tensor *new_scale = new Tensor();
Tensor *new_bias = new Tensor();
auto new_scale_ptr = new_scale->mutable_data<float>({C});
auto new_bias_ptr = new_bias->mutable_data<float>({C});
for (int i = 0; i < C; i++) {
new_scale_ptr[i] = inv_std_ptr[i] * scale_ptr[i];
new_bias_ptr[i] = bias_ptr[i] - mean_ptr[i] * inv_std_ptr[i] * scale_ptr[i];
}
std::cout << "yes" << std::endl;
(*param).SetNewScale(new_scale);
(*param).SetNewBias(new_bias);
return true;
}
template <>
void ConvAddBNReluKernel<CPU, float>::Compute(
const FusionConvAddBNReluParam &param) const {
ConvAddBNReluCompute<float>(param);
}
template class ConvAddBNReluKernel<CPU, float>;
} // namespace operators
} // namespace paddle_mobile
#endif
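The loop in Init above folds batch norm into a per-channel affine transform. Written out, for each channel c:

new_scale[c] = scale[c] / sqrt(variance[c] + epsilon)
new_bias[c]  = bias[c] - mean[c] * new_scale[c]

so the inference-time batch norm scale * (x - mean) / sqrt(variance + epsilon) + bias reduces to x * new_scale + new_bias, which Compute can apply as a cheap epilogue after the convolution.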
......@@ -21,8 +21,7 @@ namespace paddle_mobile {
namespace operators {
template <>
bool ConvAddReluKernel<CPU, float>::Init(
const FusionConvAddReluParam &para) const {
bool ConvAddReluKernel<CPU, float>::Init(FusionConvAddReluParam *param) const {
return true;
}
......
......@@ -21,7 +21,7 @@ namespace paddle_mobile {
namespace operators {
template <>
bool ConvKernel<CPU, float>::Init(const ConvParam &para) const {
bool ConvKernel<CPU, float>::Init(ConvParam *param) const {
return true;
}
......
......@@ -21,7 +21,7 @@ namespace paddle_mobile {
namespace operators {
template <>
bool DepthwiseConvKernel<CPU, float>::Init(const ConvParam &para) const {
bool DepthwiseConvKernel<CPU, float>::Init(ConvParam *param) const {
return true;
}
......
......@@ -27,8 +27,7 @@ struct AddFunctor {
};
template <>
bool ElementwiseAddKernel<CPU, float>::Init(
const ElementwiseAddParam &para) const {
bool ElementwiseAddKernel<CPU, float>::Init(ElementwiseAddParam *param) const {
return true;
}
......
......@@ -22,7 +22,7 @@ namespace paddle_mobile {
namespace operators {
template <>
bool FusionFcKernel<CPU, float>::Init(const FusionFcParam &para) const {
bool FusionFcKernel<CPU, float>::Init(FusionFcParam *param) const {
return true;
}
......
......@@ -22,7 +22,7 @@ namespace paddle_mobile {
namespace operators {
template <>
bool LrnKernel<CPU, float>::Init(const LrnParam &para) const {
bool LrnKernel<CPU, float>::Init(LrnParam *param) const {
return true;
}
......
......@@ -22,7 +22,7 @@ namespace paddle_mobile {
namespace operators {
template <>
bool MulKernel<CPU, float>::Init(const MulParam &para) const {
bool MulKernel<CPU, float>::Init(MulParam *param) const {
return true;
}
......
......@@ -204,8 +204,7 @@ void MultiClassOutput(const Tensor& scores, const Tensor& bboxes,
}
template <>
bool MultiClassNMSKernel<CPU, float>::Init(
const MultiClassNMSParam& para) const {
bool MultiClassNMSKernel<CPU, float>::Init(MultiClassNMSParam* param) const {
return true;
}
......
......@@ -36,7 +36,7 @@ inline void PoolBasic(std::string pooling_type, std::vector<int> ksize,
}
template <>
bool PoolKernel<CPU, float>::Init(const PoolParam &para) const {
bool PoolKernel<CPU, float>::Init(PoolParam *param) const {
return true;
}
......
......@@ -27,7 +27,7 @@ struct ClipFunctor {
};
template <>
bool PriorBoxKernel<CPU, float>::Init(const PriorBoxParam &para) const {
bool PriorBoxKernel<CPU, float>::Init(PriorBoxParam *param) const {
return true;
}
......
......@@ -26,7 +26,7 @@ struct ReluFunctor {
};
template <>
bool ReluKernel<CPU, float>::Init(const ReluParam &para) const {
bool ReluKernel<CPU, float>::Init(ReluParam *param) const {
return true;
}
......
......@@ -20,7 +20,7 @@ namespace paddle_mobile {
namespace operators {
template <>
bool ReshapeKernel<CPU, float>::Init(const ReshapeParam &para) const {
bool ReshapeKernel<CPU, float>::Init(ReshapeParam *param) const {
return true;
}
......
......@@ -72,7 +72,7 @@ void sigmoid(const Tensor *X, Tensor *Y) {
}
template <>
bool SigmoidKernel<CPU, float>::Init(const SigmoidParam &para) const {
bool SigmoidKernel<CPU, float>::Init(SigmoidParam *param) const {
return true;
}
......
......@@ -20,7 +20,7 @@ namespace paddle_mobile {
namespace operators {
template <>
bool SoftmaxKernel<CPU, float>::Init(const SoftmaxParam &para) const {
bool SoftmaxKernel<CPU, float>::Init(SoftmaxParam *param) const {
return true;
}
......
......@@ -35,7 +35,7 @@ namespace operators {
// }
template <>
bool TransposeKernel<CPU, float>::Init(const TransposeParam& para) const {
bool TransposeKernel<CPU, float>::Init(TransposeParam* param) const {
return true;
}
......
......@@ -29,7 +29,7 @@ class BatchNormKernel
: public framework::OpKernelBase<DeviceType, BatchNormParam> {
public:
void Compute(const BatchNormParam &param) const;
bool Init(const BatchNormParam &para) const;
bool Init(BatchNormParam *param) const;
};
} // namespace operators
......
......@@ -30,7 +30,7 @@ class BoxCoderKernel
: public framework::OpKernelBase<DeviceType, BoxCoderParam> {
public:
void Compute(const BoxCoderParam& param) const;
bool Init(const BoxCoderParam& para) const;
bool Init(BoxCoderParam* param) const;
};
} // namespace operators
} // namespace paddle_mobile
......
......@@ -53,7 +53,7 @@ void BatchnormCompute(const BatchNormParam &param) {
"C must equal to variance.numel()");
int HXW = H * W;
if (HXW > 32) {
if (0 && HXW > 32) {
int NXC = N * C;
float *inv_std_ptr = new float[NXC * 4];
float *volatile new_scale_ptr = new float[NXC * 4];
......@@ -222,8 +222,15 @@ void BatchnormCompute(const BatchNormParam &param) {
}
}
}
}
}
// for(int i = 0; i < new_scale.numel(); i++){
// std::cout << "new_scale " << new_scale_ptr[i] <<std::endl;
// }
// for(int i = 0; i < new_bias.numel(); i++){
// std::cout << "new_bias " << new_bias_ptr[i] <<std::endl;
// }
delete[] inv_std_ptr;
}
......
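Note the 0 && in the condition above: it short-circuits to false, so the optimized HXW > 32 branch is never taken and every input falls through to the scalar path. This reads like a temporary debugging switch rather than an intended change.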
//
// Created by Yang,Sui on 2018/6/28.
//
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#ifndef PADDLE_MOBILE_CONV_ADD_BN_RELU_FUNC_H
#define PADDLE_MOBILE_CONV_ADD_BN_RELU_FUNC_H
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
#endif //PADDLE_MOBILE_CONV_ADD_BN_RELU_FUNC_H
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifdef FUSION_CONVADDBNRELU_OP
#pragma once
#include "operators/math/depthwiseconv3x3s1p1.h"
#include "operators/kernel/conv_add_bn_relu_kernel.h"
#include "operators/op_param.h"
namespace paddle_mobile {
namespace operators {
template <typename P>
void ConvAddBNReluCompute(const FusionConvAddBNReluParam &param) {
const Tensor *input = param.Input();
DLOG << "input: " << *input;
Tensor filter = *param.Filter();
DLOG << "filter: " << filter;
Tensor bias = *param.Bias();
DLOG << "bias: " << bias;
Tensor new_bias = *param.NewBias();
Tensor new_scale = *param.NewScale();
auto new_bias_ptr = new_bias.data<float>();
auto new_scale_ptr = new_scale.data<float>();
//
// for(int i = 0; i < new_scale.numel(); i++){
// std::cout << "new_scale " << new_scale_ptr[i] <<std::endl;
// }
// for(int i = 0; i < new_bias.numel(); i++){
// std::cout << "new_bias " << new_bias_ptr[i] <<std::endl;
// }
int axis = param.Axis();
int groups = param.Groups();
std::vector<int> strides = param.Strides();
std::vector<int> paddings = param.Paddings();
std::vector<int> dilations = param.Dilations();
Tensor *output = param.Output();
std::vector<int64_t> filter_shape_vec(framework::vectorize(filter.dims()));
if (filter_shape_vec[2] == 3 && strides[0] == 1 && groups > 1) {
math::DepthwiseConv3x3s1p1(input, filter, output, &bias, 1, &new_scale,
&new_bias, 1, 1);
} else {
const int batch_size = static_cast<int>(input->dims()[0]);
math::expand_bias(bias, axis, output->dims());
output->ShareDataWith(bias);
std::vector<int64_t> output_shape_vec(framework::vectorize(output->dims()));
size_t data_dim = filter_shape_vec.size() - 2;
std::vector<int64_t> col_shape_vec(1 + 2 * data_dim);
col_shape_vec[0] = input->dims()[1] / groups;
for (size_t j = 0; j < data_dim; ++j) {
col_shape_vec[j + 1] = filter_shape_vec[j + 2];
col_shape_vec[j + 1 + data_dim] = output_shape_vec[j + 2];
}
framework::DDim col_shape(framework::make_ddim(col_shape_vec));
framework::DDim col_matrix_shape =
framework::flatten_to_2d(col_shape, data_dim + 1);
bool is_expand =
math::IsExpand(filter_shape_vec, strides, paddings, dilations);
Tensor col;
Tensor col_matrix;
if (is_expand) {
col.mutable_data<float>(col_shape);
col_matrix.ShareDataWith(col);
col_matrix.Resize(col_matrix_shape);
}
framework::DDim input_shape = framework::slice_ddim(
input->dims(), 1, static_cast<int>(input->dims().size()));
framework::DDim filter_matrix_shape = {filter.dims()[0],
filter.numel() / filter.dims()[0]};
filter.Resize(filter_matrix_shape);
framework::DDim output_matrix_shape = {
output->dims()[1],
output->numel() / (output->dims()[0] * output->dims()[1])};
// convolution operator: im2col(or vol2col) + gemm
int in_step = static_cast<int>(input->dims()[1]) / groups;
int out_step = static_cast<int>(output->dims()[1]) / groups;
math::Vol2ColFunctor<CPU, float> vol2col;
math::Im2ColFunctor<math::ColFormat::kCFO, CPU, float> im2col;
for (int i = 0; i < batch_size; i++) {
Tensor in_batch = input->Slice(i, i + 1).Resize(input_shape);
Tensor out_batch = output->Slice(i, i + 1).Resize(output_matrix_shape);
for (int g = 0; g < groups; g++) {
Tensor in_slice = in_batch.Slice(g * in_step, (g + 1) * in_step);
if (!is_expand) {
col.ShareDataWith(in_slice);
col_matrix.ShareDataWith(col);
col_matrix.Resize(col_matrix_shape);
} else if (data_dim == 2U) {
// im2col
im2col(in_slice, dilations, strides,
std::vector<int>{paddings[0], paddings[1], paddings[0],
paddings[1]},
&col);
} else if (data_dim == 3U) {
// vol2col
vol2col(in_slice, dilations, strides, paddings, &col);
}
// gemm
Tensor out_slice = out_batch.Slice(g * out_step, (g + 1) * out_step);
Tensor filter_slice = filter.Slice(g * out_step, (g + 1) * out_step);
math::matmul<float>(filter_slice, false, col_matrix, false,
static_cast<float>(1), &out_slice,
static_cast<float>(1), false);
}
}
auto output_ptr = output->data<float>();
for (int c = 0; c < output_matrix_shape[0]; c++){
// int start = c * output_matrix_shape[1];
for (int j = 0; j < output_matrix_shape[1]; j++){
// output_ptr[start + j] = output_ptr[start +j]*new_scale_ptr[c]+new_bias_ptr[c];
// output_ptr[start + j] = output_ptr[start+j]< 0 ? 0 : output_ptr[start +j];
}
}
}
}
} // namespace operators
} // namespace paddle_mobile
#endif
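The per-channel epilogue at the end of the im2col + gemm path above is left commented out, so the folded BN scale/bias and the ReLU are not yet applied on this branch. Restoring the commented lines would give roughly:

// Sketch of the commented-out epilogue (names as in ConvAddBNReluCompute):
for (int c = 0; c < output_matrix_shape[0]; ++c) {
  const int start = c * output_matrix_shape[1];
  for (int j = 0; j < output_matrix_shape[1]; ++j) {
    float v = output_ptr[start + j] * new_scale_ptr[c] + new_bias_ptr[c];
    output_ptr[start + j] = v < 0 ? 0 : v;  // fused ReLU
  }
}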
......@@ -27,7 +27,7 @@ template <typename DeviceType, typename T>
class ConcatKernel : public framework::OpKernelBase<DeviceType, ConcatParam> {
public:
void Compute(const ConcatParam &param) const;
bool Init(const ConcatParam &para) const;
bool Init(ConcatParam *param) const;
};
} // namespace operators
......
......@@ -14,7 +14,7 @@ limitations under the License. */
#pragma once
#ifdef FUSION_CONVADD_BN_RELU_OP
#ifdef FUSION_CONVADDBNRELU_OP
#include <vector>
#include "framework/ddim.h"
......@@ -26,20 +26,20 @@ limitations under the License. */
#include "operators/op_param.h"
namespace paddle_mobile {
namespace operators {
namespace operators {
using framework::DDim;
using framework::OpKernelBase;
using framework::DDim;
using framework::OpKernelBase;
template <typename DeviceType, typename T>
class ConvAddBNReluKernel
: public OpKernelBase<DeviceType, FusionConvAddBNReluParam> {
public:
void Compute(const FusionConvAddBNReluParam &param) const;
bool Init(const FusionConvAddBNReluParam &para) const;
};
template <typename DeviceType, typename T>
class ConvAddBNReluKernel
: public OpKernelBase<DeviceType, FusionConvAddBNReluParam> {
public:
void Compute(const FusionConvAddBNReluParam &param) const;
bool Init(FusionConvAddBNReluParam *param) const;
};
} // namespace operators
} // namespace operators
} // namespace paddle_mobile
#endif
......@@ -40,7 +40,7 @@ template <typename DeviceType, typename T>
class ConvAddKernel : public OpKernelBase<DeviceType, FusionConvAddParam> {
public:
void Compute(const FusionConvAddParam &param) const;
bool Init(const FusionConvAddParam &para) const;
bool Init(FusionConvAddParam *param) const;
};
} // namespace operators
......
......@@ -36,7 +36,7 @@ class ConvAddReluKernel
: public OpKernelBase<DeviceType, FusionConvAddReluParam> {
public:
void Compute(const FusionConvAddReluParam &param) const;
bool Init(const FusionConvAddReluParam &para) const;
bool Init(FusionConvAddReluParam *param) const;
};
} // namespace operators
......
......@@ -32,7 +32,7 @@ template <typename DeviceType, typename T>
class ConvKernel : public OpKernelBase<DeviceType, ConvParam> {
public:
void Compute(const ConvParam &param) const;
bool Init(const ConvParam &para) const;
bool Init(ConvParam *param) const;
};
} // namespace operators
......
......@@ -31,7 +31,7 @@ template <typename DeviceType, typename T>
class DepthwiseConvKernel : public OpKernelBase<DeviceType, ConvParam> {
public:
void Compute(const ConvParam &param) const;
bool Init(const ConvParam &para) const;
bool Init(ConvParam *param) const;
};
} // namespace operators
} // namespace paddle_mobile
......
......@@ -30,7 +30,7 @@ class ElementwiseAddKernel
: public framework::OpKernelBase<DeviceType, ElementwiseAddParam> {
public:
void Compute(const ElementwiseAddParam &param) const;
bool Init(const ElementwiseAddParam &para) const;
bool Init(ElementwiseAddParam *param) const;
};
} // namespace operators
} // namespace paddle_mobile
......
......@@ -20,7 +20,7 @@ namespace paddle_mobile {
namespace operators {
template <>
bool ConvKernel<FPGA, float>::Init(const ConvParam &para) const {
bool ConvKernel<FPGA, float>::Init(ConvParam *param) const {
return true;
}
......
......@@ -28,7 +28,7 @@ class FusionFcKernel
: public framework::OpKernelBase<DeviceType, FusionFcParam> {
public:
void Compute(const FusionFcParam& param) const;
bool Init(const FusionFcParam& para) const;
bool Init(FusionFcParam* param) const;
};
} // namespace operators
} // namespace paddle_mobile
......
......@@ -170,7 +170,7 @@ template <typename DeviceType, typename T>
class LrnKernel : public framework::OpKernelBase<DeviceType, LrnParam> {
public:
void Compute(const LrnParam &param) const;
bool Init(const LrnParam &para) const;
bool Init(LrnParam *param) const;
};
} // namespace operators
} // namespace paddle_mobile
......
......@@ -128,7 +128,7 @@ class AclBatchNormOp : public acl::ACLOperator {
};
template <>
bool BatchNormKernel<GPU_MALI, float>::Init(const BatchNormParam& param) const {
bool BatchNormKernel<GPU_MALI, float>::Init(BatchNormParam *param) const {
AclBatchNormOp<GPU_MALI, float>* acl_op =
reinterpret_cast<AclBatchNormOp<GPU_MALI, float>*>(this->GetAclOp());
if (acl_op == nullptr) {
......
......@@ -195,7 +195,7 @@ class AclConvOp : public acl::ACLOperator {
};
template <>
bool ConvKernel<GPU_MALI, float>::Init(const ConvParam& param) const {
bool ConvKernel<GPU_MALI, float>::Init(ConvParam *param) const {
AclConvOp<GPU_MALI, float>* acl_op =
reinterpret_cast<AclConvOp<GPU_MALI, float>*>(this->GetAclOp());
if (acl_op == nullptr) {
......
......@@ -29,7 +29,7 @@ template <typename DeviceType, typename T>
class MulKernel : public framework::OpKernelBase<DeviceType, MulParam> {
public:
void Compute(const MulParam &param) const;
bool Init(const MulParam &para) const;
bool Init(MulParam *param) const;
};
} // namespace operators
} // namespace paddle_mobile
......
......@@ -28,7 +28,7 @@ class MultiClassNMSKernel
: public framework::OpKernelBase<DeviceType, MultiClassNMSParam> {
public:
void Compute(const MultiClassNMSParam& param) const;
bool Init(const MultiClassNMSParam& para) const;
bool Init(MultiClassNMSParam* param) const;
};
} // namespace operators
} // namespace paddle_mobile
......
......@@ -28,7 +28,7 @@ template <typename DeviceType, typename T>
class PoolKernel : public OpKernelBase<DeviceType, PoolParam> {
public:
void Compute(const PoolParam &param) const override;
bool Init(const PoolParam &para) const;
bool Init(PoolParam *param) const;
};
} // namespace operators
} // namespace paddle_mobile
......
......@@ -55,7 +55,7 @@ class PriorBoxKernel
: public framework::OpKernelBase<DeviceType, PriorBoxParam> {
public:
void Compute(const PriorBoxParam& param) const;
bool Init(const PriorBoxParam& para) const;
bool Init(PriorBoxParam* param) const;
};
} // namespace operators
} // namespace paddle_mobile
......
......@@ -27,7 +27,7 @@ template <typename DeviceType, typename T>
class ReluKernel : public framework::OpKernelBase<DeviceType, ReluParam> {
public:
void Compute(const ReluParam& param) const;
bool Init(const ReluParam& para) const;
bool Init(ReluParam* param) const;
};
} // namespace operators
} // namespace paddle_mobile
......
......@@ -71,7 +71,7 @@ template <typename DeviceType, typename T>
class ReshapeKernel : public framework::OpKernelBase<DeviceType, ReshapeParam> {
public:
void Compute(const ReshapeParam& param) const;
bool Init(const ReshapeParam& para) const;
bool Init(ReshapeParam* param) const;
};
} // namespace operators
} // namespace paddle_mobile
......
......@@ -26,7 +26,7 @@ template <typename DeviceType, typename T>
class SigmoidKernel : public OpKernelBase<DeviceType, SigmoidParam> {
public:
void Compute(const SigmoidParam& param) const override;
bool Init(const SigmoidParam& para) const;
bool Init(SigmoidParam* param) const;
};
} // namespace operators
} // namespace paddle_mobile
......
......@@ -29,7 +29,7 @@ template <typename DeviceType, typename T>
class SoftmaxKernel : public OpKernelBase<DeviceType, SoftmaxParam> {
public:
void Compute(const SoftmaxParam &param) const override;
bool Init(const SoftmaxParam &para) const;
bool Init(SoftmaxParam *param) const;
};
} // namespace operators
} // namespace paddle_mobile
......
......@@ -29,7 +29,7 @@ class TransposeKernel
: public framework::OpKernelBase<DeviceType, TransposeParam> {
public:
void Compute(const TransposeParam& param) const;
bool Init(const TransposeParam& para) const;
bool Init(TransposeParam* param) const;
};
} // namespace operators
} // namespace paddle_mobile
......
......@@ -14,6 +14,7 @@ limitations under the License. */
#include "operators/math/depthwiseconv3x3s1p1.h"
#include <arm_neon.h>
#include <algorithm>
namespace paddle_mobile {
namespace operators {
......@@ -22,11 +23,14 @@ namespace math {
using framework::Tensor;
void DepthwiseConv3x3s1p1(const Tensor *input, Tensor filter, Tensor *output,
Tensor bias, bool if_bias) {
Tensor *bias, bool if_bias, Tensor *new_scale,
Tensor *new_bias, bool if_bn, bool if_relu) {
const float *input_data = input->data<float>();
const float *filter_data = filter.data<float>();
float *output_data = output->data<float>();
const float *bias_data = bias.data<float>();
const float *bias_data = bias->data<float>();
const float *newscale_data = new_scale->data<float>();
const float *newbias_data = new_bias->data<float>();
const int h = static_cast<int>(input->dims()[2]);
const int w = static_cast<int>(input->dims()[3]);
......@@ -36,6 +40,10 @@ void DepthwiseConv3x3s1p1(const Tensor *input, Tensor filter, Tensor *output,
const int c = static_cast<int>(input->dims()[1]);
const int hxw = h * w;
float32x4_t vbias = vdupq_n_f32(0.0);
float32x4_t vnewbias = vdupq_n_f32(0.0);
float32x4_t vnewscale = vdupq_n_f32(1.0);
float32x4_t vzero = vdupq_n_f32(0);
for (int b = 0; b < batch_size; ++b) {
const float *filter_data_tmp = filter_data;
......@@ -43,7 +51,10 @@ void DepthwiseConv3x3s1p1(const Tensor *input, Tensor filter, Tensor *output,
if (if_bias) {
vbias = vdupq_n_f32(bias_data[j]);
}
if (if_bn) {
vnewbias = vdupq_n_f32(newbias_data[j]);
vnewscale = vdupq_n_f32(newscale_data[j]);
}
int l_mid = l - 2; // l=1->l_mid=-1,l=2->l_mid=0
float w00 = filter_data_tmp[0];
float w01 = filter_data_tmp[1];
......@@ -55,34 +66,55 @@ void DepthwiseConv3x3s1p1(const Tensor *input, Tensor filter, Tensor *output,
float w21 = filter_data_tmp[7];
float w22 = filter_data_tmp[8];
output_data[0] = w11 * input_data[0] + w12 * input_data[1] +
w21 * input_data[l] + w22 * input_data[l + 1] +
bias_data[j];
output_data[l - 1] = w10 * input_data[l - 2] + w11 * input_data[l - 1] +
w20 * input_data[2 * l - 2] +
w21 * input_data[2 * l - 1] + bias_data[j];
output_data[(l - 1) * l] =
w01 * input_data[(l - 2) * l] + w02 * input_data[(l - 2) * l + 1] +
w11 * input_data[(l - 1) * l] + w12 * input_data[(l - 1) * l + 1] +
bias_data[j];
output_data[l * l - 1] = w00 * input_data[(l - 2) * (l + 1)] +
w01 * input_data[(l - 2) * (l + 1) + 1] +
w10 * input_data[l * l - 2] +
w11 * input_data[l * l - 1] + bias_data[j];
output_data[0] = (w11 * input_data[0] + w12 * input_data[1] + w21 * input_data[l] +
w22 * input_data[l + 1] + bias_data[j]) *
newscale_data[j] +
newbias_data[j];
output_data[l - 1] = (w10 * input_data[l - 2] + w11 * input_data[l - 1] +
w20 * input_data[2 * l - 2] +
w21 * input_data[2 * l - 1] + bias_data[j]) *
newscale_data[j] +
newbias_data[j];
output_data[(l - 1) * l] =
(w01 * input_data[(l - 2) * l] + w02 * input_data[(l - 2) * l + 1] +
w11 * input_data[(l - 1) * l] + w12 * input_data[(l - 1) * l + 1] +
bias_data[j]) *
newscale_data[j] +
newbias_data[j];
output_data[l * l - 1] = (w00 * input_data[(l - 2) * (l + 1)] +
w01 * input_data[(l - 2) * (l + 1) + 1] +
w10 * input_data[l * l - 2] +
w11 * input_data[l * l - 1] + bias_data[j]) *
newscale_data[j] +
newbias_data[j];
if (if_relu) {
output_data[0] = output_data[0] < 0 ? 0 : output_data[0];
output_data[l - 1] = output_data[l - 1] < 0 ? 0 : output_data[l - 1];
output_data[(l - 1) * l] =
output_data[(l - 1) * l] < 0 ? 0 : output_data[(l - 1) * l];
output_data[l * l - 1] =
output_data[l * l - 1] < 0 ? 0 : output_data[l * l - 1];
}
for (int i = 1; i < l - 1; ++i) {
output_data[i * l] =
w01 * input_data[i * l - l] + w02 * input_data[i * l - l + 1] +
w11 * input_data[i * l] + w12 * input_data[i * l + 1] +
w21 * input_data[i * l + l] + w22 * input_data[i * l + l + 1] +
bias_data[j];
output_data[i * l + l - 1] = w00 * input_data[i * l + l - 1 - l - 1] +
w01 * input_data[i * l + l - 1 - l] +
w10 * input_data[i * l + l - 1 - 1] +
w11 * input_data[i * l + l - 1] +
w20 * input_data[i * l + l - 1 + l - 1] +
w21 * input_data[i * l + l - 1 + l] +
bias_data[j];
(w01 * input_data[i * l - l] + w02 * input_data[i * l - l + 1] +
w11 * input_data[i * l] + w12 * input_data[i * l + 1] +
w21 * input_data[i * l + l] + w22 * input_data[i * l + l + 1] +
bias_data[j]) *
newscale_data[j] +
newbias_data[j];
output_data[i * l + l - 1] =
(w00 * input_data[i * l + l - 1 - l - 1] +
w01 * input_data[i * l + l - 1 - l] +
w10 * input_data[i * l + l - 1 - 1] +
w11 * input_data[i * l + l - 1] +
w20 * input_data[i * l + l - 1 + l - 1] +
w21 * input_data[i * l + l - 1 + l] + bias_data[j]) *
newscale_data[j] +
newbias_data[j];
if (if_relu) {
output_data[i * l] = output_data[i * l] < 0 ? 0 : output_data[i * l];
output_data[i * l + l - 1] =
output_data[i * l + l - 1] < 0 ? 0 : output_data[i * l + l - 1];
}
}
// top 1 row and bottom 1 row
......@@ -114,7 +146,10 @@ void DepthwiseConv3x3s1p1(const Tensor *input, Tensor filter, Tensor *output,
out0 = vmlaq_n_f32(out0, tmp2, w21);
out0 = vmlaq_n_f32(out0, tmp3, w22);
out0 = vaddq_f32(out0, vbias);
out0 = vmlaq_f32(vnewbias, vnewscale, out0);
if (if_relu) {
out0 = vmaxq_f32(out0, vzero);
}
vst1q_f32(output_ptr, out0);
in5 = vld1q_f32(input_tmp_end + 4);
......@@ -132,7 +167,10 @@ void DepthwiseConv3x3s1p1(const Tensor *input, Tensor filter, Tensor *output,
out0 = vmlaq_n_f32(out0, tmp2, w11);
out0 = vmlaq_n_f32(out0, tmp3, w12);
out0 = vaddq_f32(out0, vbias);
out0 = vmlaq_f32(vnewbias, vnewscale, out0);
if (if_relu) {
out0 = vmaxq_f32(out0, vzero);
}
vst1q_f32(output_ptr + (l - 1) * l, out0);
// can optimize to each 8 stride.
......@@ -161,7 +199,10 @@ void DepthwiseConv3x3s1p1(const Tensor *input, Tensor filter, Tensor *output,
out0 = vmlaq_n_f32(out0, tmp2, w21);
out0 = vmlaq_n_f32(out0, tmp3, w22);
out0 = vaddq_f32(out0, vbias);
out0 = vmlaq_f32(vnewbias, vnewscale, out0);
if (if_relu) {
out0 = vmaxq_f32(out0, vzero);
}
for (int i = 0; i < c_mid; ++i) {
if (i == 0) {
vst1q_lane_f32(output_ptr + i, out0, 0);
......@@ -190,7 +231,10 @@ void DepthwiseConv3x3s1p1(const Tensor *input, Tensor filter, Tensor *output,
out0 = vmlaq_n_f32(out0, tmp2, w11);
out0 = vmlaq_n_f32(out0, tmp3, w12);
out0 = vaddq_f32(out0, vbias);
out0 = vmlaq_f32(vnewbias, vnewscale, out0);
if (if_relu) {
out0 = vmaxq_f32(out0, vzero);
}
for (int i = 0; i < c_mid; ++i) {
if (i == 0) {
vst1q_lane_f32(output_ptr + (l - 1) * l + i, out0, 0);
......@@ -233,7 +277,10 @@ void DepthwiseConv3x3s1p1(const Tensor *input, Tensor filter, Tensor *output,
out0 = vmlaq_n_f32(out0, tmp4, w21);
out0 = vmlaq_n_f32(out0, tmp5, w22);
out0 = vaddq_f32(out0, vbias);
out0 = vmlaq_f32(vnewbias, vnewscale, out0);
if (if_relu) {
out0 = vmaxq_f32(out0, vzero);
}
vst1q_f32(output_ptr, out0);
output_ptr += 4;
......@@ -264,7 +311,10 @@ void DepthwiseConv3x3s1p1(const Tensor *input, Tensor filter, Tensor *output,
out0 = vmlaq_n_f32(out0, tmp4, w21);
out0 = vmlaq_n_f32(out0, tmp5, w22);
out0 = vaddq_f32(out0, vbias);
out0 = vmlaq_f32(vnewbias, vnewscale, out0);
if (if_relu) {
out0 = vmaxq_f32(out0, vzero);
}
for (int i = 0; i < c_mid; ++i) {
if (i == 0) {
vst1q_lane_f32(output_ptr + i, out0, 0);
......@@ -282,6 +332,7 @@ void DepthwiseConv3x3s1p1(const Tensor *input, Tensor filter, Tensor *output,
filter_data_tmp += 9;
}
}
}
} // namespace math
} // namespace operators
......
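Every vector store in the kernel above now runs the same two-instruction epilogue. In isolation (a sketch; vnewscale and vnewbias hold the per-channel duplicates set at the top of the channel loop, defaulting to 1 and 0 when if_bn is false):

#include <arm_neon.h>

static inline float32x4_t BnReluEpilogue(float32x4_t acc,
                                         float32x4_t vnewscale,
                                         float32x4_t vnewbias,
                                         float32x4_t vzero, bool if_relu) {
  // vmlaq_f32(a, b, c) computes a + b * c, i.e. new_bias + new_scale * acc
  acc = vmlaq_f32(vnewbias, vnewscale, acc);
  return if_relu ? vmaxq_f32(acc, vzero) : acc;
}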
......@@ -21,7 +21,8 @@ namespace math {
using framework::Tensor;
void DepthwiseConv3x3s1p1(const Tensor *input, Tensor filter, Tensor *output,
Tensor bias, bool if_bias);
Tensor *bias, bool if_bias, Tensor *new_scale,
Tensor *new_bias, bool if_bn, bool if_relu);
} // namespace math
} // namespace operators
} // namespace paddle_mobile
......@@ -823,6 +823,10 @@ class FusionConvAddParam : public OpParam {
const int &Groups() const { return groups; }
void Set(Tensor *t) {t_ = t;}
const Tensor *Get() const {return t_;}
protected:
Tensor *bias_;
int axis_;
......@@ -833,6 +837,7 @@ class FusionConvAddParam : public OpParam {
vector<int> paddings_;
vector<int> dilations_;
int groups;
Tensor *t_;
};
Print &operator<<(Print &printer, const FusionConvAddParam &conv_param);
......@@ -848,5 +853,91 @@ class FusionConvAddReluParam : public FusionConvAddParam {
};
#endif
#ifdef FUSION_CONVADDBNRELU_OP
class FusionConvAddBNReluParam : public OpParam {
public:
FusionConvAddBNReluParam(const VariableNameMap &inputs,
const VariableNameMap &outputs,
const AttributeMap &attrs, const Scope &scope) {
bias_ = InputYFrom<LoDTensor>(inputs, scope);
axis_ = GetAttr<int>("axis", attrs);
filter_ = FilterFrom<LoDTensor>(inputs, scope);
input_ = InputFrom<LoDTensor>(inputs, scope);
output_ = OutFrom<LoDTensor>(outputs, scope);
strides_ = GetAttr<vector<int>>("strides", attrs);
paddings_ = GetAttr<vector<int>>("paddings", attrs);
dilations_ = GetAttr<vector<int>>("dilations", attrs);
groups = GetAttr<int>("groups", attrs);
input_bias_ = InputBiasFrom<framework::LoDTensor>(inputs, scope);
input_mean_ = InputMeanFrom<framework::LoDTensor>(inputs, scope);
input_scale_ = InputScaleFrom<framework::LoDTensor>(inputs, scope);
input_variance_ = InputVarianceFrom<framework::LoDTensor>(inputs, scope);
epsilon_ = GetAttr<float>("epsilon", attrs);
momentum_ = GetAttr<float>("momentum", attrs);
is_test_ = GetAttr<bool>("is_test", attrs);
}
Tensor *Bias() const { return bias_; }
const int &Axis() const { return axis_; }
const Tensor *Input() const { return input_; }
const Tensor *Filter() const { return filter_; }
Tensor *Output() const { return output_; }
const vector<int> &Strides() const { return strides_; }
const vector<int> &Paddings() const { return paddings_; }
const vector<int> &Dilations() const { return dilations_; }
const int &Groups() const { return groups; }
const Tensor *InputBias() const { return input_bias_; }
const Tensor *InputMean() const { return input_mean_; }
const Tensor *InputScale() const { return input_scale_; }
const Tensor *InputVariance() const { return input_variance_; }
const float &Epsilon() const { return epsilon_; }
const float &Momentum() const { return momentum_; }
const bool &IsTest() const { return is_test_; }
void SetNewScale(Tensor *new_scale) { new_scale_ = new_scale; }
void SetNewBias(Tensor *new_bias) { new_bias_ = new_bias; }
const Tensor *NewScale() const { return new_scale_; }
const Tensor *NewBias() const { return new_bias_; }
protected:
Tensor *bias_;
int axis_;
Tensor *input_;
Tensor *output_;
Tensor *filter_;
vector<int> strides_;
vector<int> paddings_;
vector<int> dilations_;
int groups;
Tensor *input_bias_;
Tensor *input_mean_;
Tensor *input_scale_;
Tensor *input_variance_;
float epsilon_;
float momentum_;
bool is_test_;
Tensor *new_bias_;
Tensor *new_scale_;
};
Print &operator<<(Print &printer, const FusionConvAddParam &conv_param);
#endif
} // namespace operators
} // namespace paddle_mobile
......@@ -22,6 +22,7 @@ elseif (NET EQUAL "mobilenet")
set(BATCHNORM_OP ON)
set(POOL_OP ON)
set(RESHAPE_OP ON)
set(FUSION_CONVADDBNRELU_OP)
elseif (NET EQUAL "yolo")
set(BATCHNORM_OP ON)
set(CONV_OP ON)
......@@ -64,6 +65,8 @@ else ()
set(SOFTMAX_OP ON)
set(TRANSPOSE_OP ON)
set(FUSION_CONVADD_RELU_OP ON)
set(FUSION_CONVADDBNRELU_OP ON)
# option(BATCHNORM_OP "" ON)
# option(BOXCODER_OP "" ON)
# option(CONCAT_OP "" ON)
......
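One caveat in the CMake change for the mobilenet profile: set(FUSION_CONVADDBNRELU_OP) with no value unsets the variable rather than turning it on, so the fused op is not actually enabled for that net. If enabling it was the intent, it would presumably need to read set(FUSION_CONVADDBNRELU_OP ON), matching the default branch.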