提交 5b08ad31 编写于 作者: Z zhaojiaying01

FusionConvXXXParam inheritance ConvParam

上级 1d896aa0
...@@ -343,20 +343,22 @@ class OpParam { ...@@ -343,20 +343,22 @@ class OpParam {
#ifdef CONV_OP #ifdef CONV_OP
template <typename Dtype> template <typename Dtype>
class ConvParam : OpParam { class ConvParam : public OpParam {
typedef typename DtypeTensorTrait<Dtype>::gtype GType; typedef typename DtypeTensorTrait<Dtype>::gtype GType;
typedef typename DtypeTensorTrait<Dtype>::rtype RType; typedef typename DtypeTensorTrait<Dtype>::rtype RType;
public: public:
ConvParam(const VariableNameMap &inputs, const VariableNameMap &outputs, ConvParam(const VariableNameMap &inputs, const VariableNameMap &outputs,
const AttributeMap &attrs, const Scope &scope) { const AttributeMap &attrs, const Scope &scope) {
filter_ = FilterFrom<GType>(inputs, scope); filter_ = OpParam::FilterFrom<GType>(inputs, scope);
input_ = InputFrom<GType>(inputs, scope); input_ = OpParam::InputFrom<GType>(inputs, scope);
output_ = OutputFrom<GType>(outputs, scope); if (outputs.count("Output")) {
strides_ = GetAttr<vector<int>>("strides", attrs); output_ = OpParam::OutputFrom<GType>(outputs, scope);
paddings_ = GetAttr<vector<int>>("paddings", attrs); }
dilations_ = GetAttr<vector<int>>("dilations", attrs); strides_ = OpParam::GetAttr<vector<int>>("strides", attrs);
groups = GetAttr<int>("groups", attrs); paddings_ = OpParam::GetAttr<vector<int>>("paddings", attrs);
dilations_ = OpParam::GetAttr<vector<int>>("dilations", attrs);
groups = OpParam::GetAttr<int>("groups", attrs);
} }
const RType *Input() const { return input_; } const RType *Input() const { return input_; }
...@@ -1294,52 +1296,29 @@ using FusionFcReluParam = FusionFcParam<DeviceType>; ...@@ -1294,52 +1296,29 @@ using FusionFcReluParam = FusionFcParam<DeviceType>;
#endif #endif
template <typename Dtype> template <typename Dtype>
class FusionConvAddParam : public OpParam { class FusionConvAddParam : public ConvParam<Dtype> {
typedef typename DtypeTensorTrait<Dtype>::gtype GType; typedef typename DtypeTensorTrait<Dtype>::gtype GType;
typedef typename DtypeTensorTrait<Dtype>::rtype RType; typedef typename DtypeTensorTrait<Dtype>::rtype RType;
public: public:
FusionConvAddParam(const VariableNameMap &inputs, FusionConvAddParam(const VariableNameMap &inputs,
const VariableNameMap &outputs, const AttributeMap &attrs, const VariableNameMap &outputs, const AttributeMap &attrs,
const Scope &scope) { const Scope &scope)
bias_ = InputYFrom<GType>(inputs, scope); : ConvParam<Dtype>(inputs, outputs, attrs, scope) {
axis_ = GetAttr<int>("axis", attrs); bias_ = OpParam::InputYFrom<GType>(inputs, scope);
filter_ = FilterFrom<GType>(inputs, scope); axis_ = OpParam::GetAttr<int>("axis", attrs);
input_ = InputFrom<GType>(inputs, scope); output_ = OpParam::OutFrom<GType>(outputs, scope);
output_ = OutFrom<GType>(outputs, scope);
strides_ = GetAttr<vector<int>>("strides", attrs);
paddings_ = GetAttr<vector<int>>("paddings", attrs);
dilations_ = GetAttr<vector<int>>("dilations", attrs);
groups = GetAttr<int>("groups", attrs);
} }
RType *Bias() const { return bias_; } RType *Bias() const { return bias_; }
const int &Axis() const { return axis_; } const int &Axis() const { return axis_; }
const RType *Input() const { return input_; }
const RType *Filter() const { return filter_; }
RType *Output() const { return output_; } RType *Output() const { return output_; }
const vector<int> &Strides() const { return strides_; }
const vector<int> &Paddings() const { return paddings_; }
const vector<int> &Dilations() const { return dilations_; }
const int &Groups() const { return groups; }
protected: protected:
RType *bias_; RType *bias_;
int axis_; int axis_;
RType *input_;
RType *output_; RType *output_;
RType *filter_;
vector<int> strides_;
vector<int> paddings_;
vector<int> dilations_;
int groups;
#ifdef PADDLE_MOBILE_FPGA #ifdef PADDLE_MOBILE_FPGA
private: private:
...@@ -1366,58 +1345,33 @@ class FusionConvAddReluParam : public FusionConvAddParam<DeviceType> { ...@@ -1366,58 +1345,33 @@ class FusionConvAddReluParam : public FusionConvAddParam<DeviceType> {
#endif #endif
#ifdef FUSION_CONVADDPRELU_OP #ifdef FUSION_CONVADDPRELU_OP
template <typename DeviceType> template <typename Dtype>
class FusionConvAddPReluParam : public OpParam { class FusionConvAddPReluParam : public ConvParam<Dtype> {
typedef typename DtypeTensorTrait<DeviceType>::gtype GType; typedef typename DtypeTensorTrait<Dtype>::gtype GType;
typedef typename DtypeTensorTrait<DeviceType>::rtype RType; typedef typename DtypeTensorTrait<Dtype>::rtype RType;
public: public:
FusionConvAddPReluParam(const VariableNameMap &inputs, FusionConvAddPReluParam(const VariableNameMap &inputs,
const VariableNameMap &outputs, const VariableNameMap &outputs,
const AttributeMap &attrs, const Scope &scope) { const AttributeMap &attrs, const Scope &scope)
alpha_ = InputAlphaFrom<GType>(inputs, scope); : ConvParam<Dtype>(inputs, outputs, attrs, scope) {
mode_ = GetAttr<std::string>("mode", attrs); alpha_ = OpParam::InputAlphaFrom<GType>(inputs, scope);
mode_ = OpParam::GetAttr<std::string>("mode", attrs);
framework::DDim dims = alpha_->dims(); framework::DDim dims = alpha_->dims();
bias_ = InputYFrom<GType>(inputs, scope); bias_ = OpParam::InputYFrom<GType>(inputs, scope);
axis_ = GetAttr<int>("axis", attrs); axis_ = OpParam::GetAttr<int>("axis", attrs);
filter_ = FilterFrom<GType>(inputs, scope); output_ = OpParam::OutFrom<GType>(outputs, scope);
input_ = InputFrom<GType>(inputs, scope);
output_ = OutFrom<GType>(outputs, scope);
strides_ = GetAttr<vector<int>>("strides", attrs);
paddings_ = GetAttr<vector<int>>("paddings", attrs);
dilations_ = GetAttr<vector<int>>("dilations", attrs);
groups = GetAttr<int>("groups", attrs);
} }
const RType *InputAlpha() const { return alpha_; } const RType *InputAlpha() const { return alpha_; }
const std::string &Mode() const { return mode_; } const std::string &Mode() const { return mode_; }
RType *Bias() const { return bias_; } RType *Bias() const { return bias_; }
const int &Axis() const { return axis_; } const int &Axis() const { return axis_; }
const RType *Input() const { return input_; }
const RType *Filter() const { return filter_; }
RType *Output() const { return output_; } RType *Output() const { return output_; }
const vector<int> &Strides() const { return strides_; }
const vector<int> &Paddings() const { return paddings_; }
const vector<int> &Dilations() const { return dilations_; }
const int &Groups() const { return groups; }
protected: protected:
RType *bias_; RType *bias_;
int axis_; int axis_;
RType *input_;
RType *output_; RType *output_;
RType *filter_;
vector<int> strides_;
vector<int> paddings_;
vector<int> dilations_;
int groups;
RType *alpha_; RType *alpha_;
std::string mode_; std::string mode_;
#ifdef PADDLE_MOBILE_FPGA #ifdef PADDLE_MOBILE_FPGA
...@@ -1433,35 +1387,30 @@ class FusionConvAddPReluParam : public OpParam { ...@@ -1433,35 +1387,30 @@ class FusionConvAddPReluParam : public OpParam {
#endif #endif
#ifdef FUSION_CONVADDADDPRELU_OP #ifdef FUSION_CONVADDADDPRELU_OP
template <typename DeviceType> template <typename Dtype>
class FusionConvAddAddPReluParam : public OpParam { class FusionConvAddAddPReluParam : public ConvParam<Dtype> {
typedef typename DtypeTensorTrait<DeviceType>::gtype GType; typedef typename DtypeTensorTrait<Dtype>::gtype GType;
typedef typename DtypeTensorTrait<DeviceType>::rtype RType; typedef typename DtypeTensorTrait<Dtype>::rtype RType;
public: public:
FusionConvAddAddPReluParam(const VariableNameMap &inputs, FusionConvAddAddPReluParam(const VariableNameMap &inputs,
const VariableNameMap &outputs, const VariableNameMap &outputs,
const AttributeMap &attrs, const Scope &scope) { const AttributeMap &attrs, const Scope &scope)
bias1_ = InputYFrom1<GType>(inputs, scope); : ConvParam<Dtype>(inputs, outputs, attrs, scope) {
alpha_ = InputAlphaFrom<GType>(inputs, scope); bias1_ = OpParam::InputYFrom1<GType>(inputs, scope);
mode_ = GetAttr<std::string>("mode", attrs); alpha_ = OpParam::InputAlphaFrom<GType>(inputs, scope);
mode_ = OpParam::GetAttr<std::string>("mode", attrs);
framework::DDim dims = alpha_->dims(); framework::DDim dims = alpha_->dims();
bias_ = InputYFrom<GType>(inputs, scope); bias_ = OpParam::InputYFrom<GType>(inputs, scope);
axis_ = GetAttr<int>("axis", attrs); output_ = OpParam::OutFrom<GType>(outputs, scope);
filter_ = FilterFrom<GType>(inputs, scope); axis_ = OpParam::GetAttr<int>("axis", attrs);
input_ = InputFrom<GType>(inputs, scope); keyOutput_ = OpParam::getkey("addOut", inputs, 0);
output_ = OutFrom<GType>(outputs, scope); keyX1_ = OpParam::getkey("addX", inputs, 1);
strides_ = GetAttr<vector<int>>("strides", attrs); keyY1_ = OpParam::getkey("Y", inputs, 1);
paddings_ = GetAttr<vector<int>>("paddings", attrs);
dilations_ = GetAttr<vector<int>>("dilations", attrs);
groups = GetAttr<int>("groups", attrs);
keyOutput_ = getkey("addOut", inputs, 0);
keyX1_ = getkey("addX", inputs, 1);
keyY1_ = getkey("Y", inputs, 1);
if (keyX1_ == keyOutput_) { if (keyX1_ == keyOutput_) {
bias1_ = InputYFrom1<GType>(inputs, scope); bias1_ = OpParam::InputYFrom1<GType>(inputs, scope);
} else if (keyY1_ == keyOutput_) { } else if (keyY1_ == keyOutput_) {
bias1_ = InputXFrom1<GType>(inputs, scope); bias1_ = OpParam::InputXFrom1<GType>(inputs, scope);
} }
} }
const RType *InputAlpha() const { return alpha_; } const RType *InputAlpha() const { return alpha_; }
...@@ -1471,31 +1420,12 @@ class FusionConvAddAddPReluParam : public OpParam { ...@@ -1471,31 +1420,12 @@ class FusionConvAddAddPReluParam : public OpParam {
RType *Bias() const { return bias_; } RType *Bias() const { return bias_; }
const int &Axis() const { return axis_; } const int &Axis() const { return axis_; }
const RType *Input() const { return input_; }
const RType *Filter() const { return filter_; }
RType *Output() const { return output_; } RType *Output() const { return output_; }
const vector<int> &Strides() const { return strides_; }
const vector<int> &Paddings() const { return paddings_; }
const vector<int> &Dilations() const { return dilations_; }
const int &Groups() const { return groups; }
protected: protected:
RType *bias_; RType *bias_;
int axis_; int axis_;
RType *input_;
RType *output_; RType *output_;
RType *filter_;
vector<int> strides_;
vector<int> paddings_;
vector<int> dilations_;
int groups;
RType *alpha_; RType *alpha_;
std::string mode_; std::string mode_;
RType *bias1_; RType *bias1_;
...@@ -1516,49 +1446,32 @@ class FusionConvAddAddPReluParam : public OpParam { ...@@ -1516,49 +1446,32 @@ class FusionConvAddAddPReluParam : public OpParam {
#ifdef FUSION_CONVADDBNRELU_OP #ifdef FUSION_CONVADDBNRELU_OP
template <typename Dtype> template <typename Dtype>
class FusionConvAddBNReluParam : public OpParam { class FusionConvAddBNReluParam : public ConvParam<Dtype> {
typedef typename DtypeTensorTrait<Dtype>::gtype GType; typedef typename DtypeTensorTrait<Dtype>::gtype GType;
typedef typename DtypeTensorTrait<Dtype>::rtype RType; typedef typename DtypeTensorTrait<Dtype>::rtype RType;
public: public:
FusionConvAddBNReluParam(const VariableNameMap &inputs, FusionConvAddBNReluParam(const VariableNameMap &inputs,
const VariableNameMap &outputs, const VariableNameMap &outputs,
const AttributeMap &attrs, const Scope &scope) { const AttributeMap &attrs, const Scope &scope)
bias_ = InputYFrom<GType>(inputs, scope); : ConvParam<Dtype>(inputs, outputs, attrs, scope) {
axis_ = GetAttr<int>("axis", attrs); bias_ = OpParam::InputYFrom<GType>(inputs, scope);
filter_ = FilterFrom<GType>(inputs, scope); axis_ = OpParam::GetAttr<int>("axis", attrs);
input_ = InputFrom<GType>(inputs, scope); output_ = OpParam::OutFrom<GType>(outputs, scope);
output_ = OutFrom<GType>(outputs, scope); input_bias_ = OpParam::InputBiasFrom<GType>(inputs, scope);
strides_ = GetAttr<vector<int>>("strides", attrs); input_mean_ = OpParam::InputMeanFrom<GType>(inputs, scope);
paddings_ = GetAttr<vector<int>>("paddings", attrs); input_scale_ = OpParam::InputScaleFrom<GType>(inputs, scope);
dilations_ = GetAttr<vector<int>>("dilations", attrs); input_variance_ = OpParam::InputVarianceFrom<GType>(inputs, scope);
groups = GetAttr<int>("groups", attrs); epsilon_ = OpParam::GetAttr<float>("epsilon", attrs);
input_bias_ = InputBiasFrom<GType>(inputs, scope); momentum_ = OpParam::GetAttr<float>("momentum", attrs);
input_mean_ = InputMeanFrom<GType>(inputs, scope); // is_test_ = OpParam::GetAttr<bool>("is_test", attrs);
input_scale_ = InputScaleFrom<GType>(inputs, scope);
input_variance_ = InputVarianceFrom<GType>(inputs, scope);
epsilon_ = GetAttr<float>("epsilon", attrs);
momentum_ = GetAttr<float>("momentum", attrs);
// is_test_ = GetAttr<bool>("is_test", attrs);
} }
RType *Bias() const { return bias_; } RType *Bias() const { return bias_; }
const int &Axis() const { return axis_; } const int &Axis() const { return axis_; }
const RType *Input() const { return input_; }
const RType *Filter() const { return filter_; }
RType *Output() const { return output_; } RType *Output() const { return output_; }
const vector<int> &Strides() const { return strides_; }
const vector<int> &Paddings() const { return paddings_; }
const vector<int> &Dilations() const { return dilations_; }
const int &Groups() const { return groups; }
const RType *InputBias() const { return input_bias_; } const RType *InputBias() const { return input_bias_; }
const RType *InputMean() const { return input_mean_; } const RType *InputMean() const { return input_mean_; }
...@@ -1584,13 +1497,7 @@ class FusionConvAddBNReluParam : public OpParam { ...@@ -1584,13 +1497,7 @@ class FusionConvAddBNReluParam : public OpParam {
protected: protected:
RType *bias_; RType *bias_;
int axis_; int axis_;
RType *input_;
RType *output_; RType *output_;
RType *filter_;
vector<int> strides_;
vector<int> paddings_;
vector<int> dilations_;
int groups;
RType *input_bias_; RType *input_bias_;
RType *input_mean_; RType *input_mean_;
RType *input_scale_; RType *input_scale_;
...@@ -1614,57 +1521,40 @@ class FusionConvAddBNReluParam : public OpParam { ...@@ -1614,57 +1521,40 @@ class FusionConvAddBNReluParam : public OpParam {
#ifdef FUSION_CONVBNADDRELU_OP #ifdef FUSION_CONVBNADDRELU_OP
template <typename Dtype> template <typename Dtype>
class FusionConvBNAddReluParam : public OpParam { class FusionConvBNAddReluParam : public ConvParam<Dtype> {
typedef typename DtypeTensorTrait<Dtype>::gtype GType; typedef typename DtypeTensorTrait<Dtype>::gtype GType;
typedef typename DtypeTensorTrait<Dtype>::rtype RType; typedef typename DtypeTensorTrait<Dtype>::rtype RType;
public: public:
FusionConvBNAddReluParam(const VariableNameMap &inputs, FusionConvBNAddReluParam(const VariableNameMap &inputs,
const VariableNameMap &outputs, const VariableNameMap &outputs,
const AttributeMap &attrs, const Scope &scope) { const AttributeMap &attrs, const Scope &scope)
bias_ = InputYFrom<GType>(inputs, scope); : ConvParam<Dtype>(inputs, outputs, attrs, scope) {
axis_ = GetAttr<int>("axis", attrs); bias_ = OpParam::InputYFrom<GType>(inputs, scope);
filter_ = FilterFrom<GType>(inputs, scope); axis_ = OpParam::GetAttr<int>("axis", attrs);
input_ = InputFrom<GType>(inputs, scope); output_ = OpParam::OutFrom<GType>(outputs, scope);
output_ = OutFrom<GType>(outputs, scope); input_bias_ = OpParam::InputBiasFrom<GType>(inputs, scope);
strides_ = GetAttr<vector<int>>("strides", attrs); input_mean_ = OpParam::InputMeanFrom<GType>(inputs, scope);
paddings_ = GetAttr<vector<int>>("paddings", attrs); input_scale_ = OpParam::InputScaleFrom<GType>(inputs, scope);
dilations_ = GetAttr<vector<int>>("dilations", attrs); input_variance_ = OpParam::InputVarianceFrom<GType>(inputs, scope);
groups = GetAttr<int>("groups", attrs); epsilon_ = OpParam::GetAttr<float>("epsilon", attrs);
input_bias_ = InputBiasFrom<GType>(inputs, scope); momentum_ = OpParam::GetAttr<float>("momentum", attrs);
input_mean_ = InputMeanFrom<GType>(inputs, scope); keyBNY_ = OpParam::getkey("BNY", inputs, 0);
input_scale_ = InputScaleFrom<GType>(inputs, scope); keyX_ = OpParam::getkey("X", inputs, 0);
input_variance_ = InputVarianceFrom<GType>(inputs, scope); keyY_ = OpParam::getkey("Y", inputs, 0);
epsilon_ = GetAttr<float>("epsilon", attrs);
momentum_ = GetAttr<float>("momentum", attrs);
keyBNY_ = getkey("BNY", inputs, 0);
keyX_ = getkey("X", inputs, 0);
keyY_ = getkey("Y", inputs, 0);
if (keyX_ == keyBNY_) { if (keyX_ == keyBNY_) {
bias_ = InputYFrom<GType>(inputs, scope); bias_ = OpParam::InputYFrom<GType>(inputs, scope);
} else if (keyY_ == keyBNY_) { } else if (keyY_ == keyBNY_) {
bias_ = InputXFrom<GType>(inputs, scope); bias_ = OpParam::InputXFrom<GType>(inputs, scope);
} }
// is_test_ = GetAttr<bool>("is_test", attrs); // is_test_ = OpParam::GetAttr<bool>("is_test", attrs);
} }
RType *Bias() const { return bias_; } RType *Bias() const { return bias_; }
const int &Axis() const { return axis_; } const int &Axis() const { return axis_; }
const RType *Input() const { return input_; }
const RType *Filter() const { return filter_; }
RType *Output() const { return output_; } RType *Output() const { return output_; }
const vector<int> &Strides() const { return strides_; }
const vector<int> &Paddings() const { return paddings_; }
const vector<int> &Dilations() const { return dilations_; }
const int &Groups() const { return groups; }
const RType *InputBias() const { return input_bias_; } const RType *InputBias() const { return input_bias_; }
const RType *InputMean() const { return input_mean_; } const RType *InputMean() const { return input_mean_; }
...@@ -1690,13 +1580,7 @@ class FusionConvBNAddReluParam : public OpParam { ...@@ -1690,13 +1580,7 @@ class FusionConvBNAddReluParam : public OpParam {
protected: protected:
RType *bias_; RType *bias_;
int axis_; int axis_;
RType *input_;
RType *output_; RType *output_;
RType *filter_;
vector<int> strides_;
vector<int> paddings_;
vector<int> dilations_;
int groups;
RType *input_bias_; RType *input_bias_;
RType *input_mean_; RType *input_mean_;
RType *input_scale_; RType *input_scale_;
...@@ -1723,44 +1607,26 @@ class FusionConvBNAddReluParam : public OpParam { ...@@ -1723,44 +1607,26 @@ class FusionConvBNAddReluParam : public OpParam {
#ifdef FUSION_CONVBN_OP #ifdef FUSION_CONVBN_OP
template <typename Dtype> template <typename Dtype>
class FusionConvBNParam : public OpParam { class FusionConvBNParam : public ConvParam<Dtype> {
typedef typename DtypeTensorTrait<Dtype>::gtype GType; typedef typename DtypeTensorTrait<Dtype>::gtype GType;
typedef typename DtypeTensorTrait<Dtype>::rtype RType; typedef typename DtypeTensorTrait<Dtype>::rtype RType;
public: public:
FusionConvBNParam(const VariableNameMap &inputs, FusionConvBNParam(const VariableNameMap &inputs,
const VariableNameMap &outputs, const AttributeMap &attrs, const VariableNameMap &outputs, const AttributeMap &attrs,
const Scope &scope) { const Scope &scope)
filter_ = FilterFrom<GType>(inputs, scope); : ConvParam<Dtype>(inputs, outputs, attrs, scope) {
input_ = InputFrom<GType>(inputs, scope); output_y_ = OpParam::OutputYFrom<GType>(outputs, scope);
output_y_ = OutputYFrom<GType>(outputs, scope); input_bias_ = OpParam::InputBiasFrom<GType>(inputs, scope);
strides_ = GetAttr<vector<int>>("strides", attrs); input_mean_ = OpParam::InputMeanFrom<GType>(inputs, scope);
paddings_ = GetAttr<vector<int>>("paddings", attrs); input_scale_ = OpParam::InputScaleFrom<GType>(inputs, scope);
dilations_ = GetAttr<vector<int>>("dilations", attrs); input_variance_ = OpParam::InputVarianceFrom<GType>(inputs, scope);
groups = GetAttr<int>("groups", attrs); epsilon_ = OpParam::GetAttr<float>("epsilon", attrs);
input_bias_ = InputBiasFrom<GType>(inputs, scope); momentum_ = OpParam::GetAttr<float>("momentum", attrs);
input_mean_ = InputMeanFrom<GType>(inputs, scope); // is_test_ = OpParam::GetAttr<bool>("is_test", attrs);
input_scale_ = InputScaleFrom<GType>(inputs, scope);
input_variance_ = InputVarianceFrom<GType>(inputs, scope);
epsilon_ = GetAttr<float>("epsilon", attrs);
momentum_ = GetAttr<float>("momentum", attrs);
// is_test_ = GetAttr<bool>("is_test", attrs);
} }
const RType *Input() const { return input_; }
const RType *Filter() const { return filter_; }
RType *Output() const { return output_y_; } RType *Output() const { return output_y_; }
const vector<int> &Strides() const { return strides_; }
const vector<int> &Paddings() const { return paddings_; }
const vector<int> &Dilations() const { return dilations_; }
const int &Groups() const { return groups; }
const RType *InputBias() const { return input_bias_; } const RType *InputBias() const { return input_bias_; }
const RType *InputMean() const { return input_mean_; } const RType *InputMean() const { return input_mean_; }
...@@ -1784,13 +1650,7 @@ class FusionConvBNParam : public OpParam { ...@@ -1784,13 +1650,7 @@ class FusionConvBNParam : public OpParam {
const RType *NewBias() const { return new_bias_; } const RType *NewBias() const { return new_bias_; }
protected: protected:
RType *input_;
RType *output_y_; RType *output_y_;
RType *filter_;
vector<int> strides_;
vector<int> paddings_;
vector<int> dilations_;
int groups;
RType *input_bias_; RType *input_bias_;
RType *input_mean_; RType *input_mean_;
RType *input_scale_; RType *input_scale_;
...@@ -1814,49 +1674,32 @@ class FusionConvBNParam : public OpParam { ...@@ -1814,49 +1674,32 @@ class FusionConvBNParam : public OpParam {
#ifdef FUSION_CONVADDBN_OP #ifdef FUSION_CONVADDBN_OP
template <typename Dtype> template <typename Dtype>
class FusionConvAddBNParam : public OpParam { class FusionConvAddBNParam : public ConvParam<Dtype> {
typedef typename DtypeTensorTrait<Dtype>::gtype GType; typedef typename DtypeTensorTrait<Dtype>::gtype GType;
typedef typename DtypeTensorTrait<Dtype>::rtype RType; typedef typename DtypeTensorTrait<Dtype>::rtype RType;
public: public:
FusionConvAddBNParam(const VariableNameMap &inputs, FusionConvAddBNParam(const VariableNameMap &inputs,
const VariableNameMap &outputs, const VariableNameMap &outputs,
const AttributeMap &attrs, const Scope &scope) { const AttributeMap &attrs, const Scope &scope)
bias_ = InputYFrom<GType>(inputs, scope); : ConvParam<Dtype>(inputs, outputs, attrs, scope) {
axis_ = GetAttr<int>("axis", attrs); bias_ = OpParam::InputYFrom<GType>(inputs, scope);
filter_ = FilterFrom<GType>(inputs, scope); axis_ = OpParam::GetAttr<int>("axis", attrs);
input_ = InputFrom<GType>(inputs, scope); output_y_ = OpParam::OutputYFrom<GType>(outputs, scope);
output_y_ = OutputYFrom<GType>(outputs, scope); input_bias_ = OpParam::InputBiasFrom<GType>(inputs, scope);
strides_ = GetAttr<vector<int>>("strides", attrs); input_mean_ = OpParam::InputMeanFrom<GType>(inputs, scope);
paddings_ = GetAttr<vector<int>>("paddings", attrs); input_scale_ = OpParam::InputScaleFrom<GType>(inputs, scope);
dilations_ = GetAttr<vector<int>>("dilations", attrs); input_variance_ = OpParam::InputVarianceFrom<GType>(inputs, scope);
groups = GetAttr<int>("groups", attrs); epsilon_ = OpParam::GetAttr<float>("epsilon", attrs);
input_bias_ = InputBiasFrom<GType>(inputs, scope); momentum_ = OpParam::GetAttr<float>("momentum", attrs);
input_mean_ = InputMeanFrom<GType>(inputs, scope); // is_test_ = OpParam::GetAttr<bool>("is_test", attrs);
input_scale_ = InputScaleFrom<GType>(inputs, scope);
input_variance_ = InputVarianceFrom<GType>(inputs, scope);
epsilon_ = GetAttr<float>("epsilon", attrs);
momentum_ = GetAttr<float>("momentum", attrs);
// is_test_ = GetAttr<bool>("is_test", attrs);
} }
RType *Bias() const { return bias_; } RType *Bias() const { return bias_; }
const int &Axis() const { return axis_; } const int &Axis() const { return axis_; }
const RType *Input() const { return input_; }
const RType *Filter() const { return filter_; }
RType *Output() const { return output_y_; } RType *Output() const { return output_y_; }
const vector<int> &Strides() const { return strides_; }
const vector<int> &Paddings() const { return paddings_; }
const vector<int> &Dilations() const { return dilations_; }
const int &Groups() const { return groups; }
const RType *InputBias() const { return input_bias_; } const RType *InputBias() const { return input_bias_; }
const RType *InputMean() const { return input_mean_; } const RType *InputMean() const { return input_mean_; }
...@@ -1882,13 +1725,7 @@ class FusionConvAddBNParam : public OpParam { ...@@ -1882,13 +1725,7 @@ class FusionConvAddBNParam : public OpParam {
protected: protected:
RType *bias_; RType *bias_;
int axis_; int axis_;
RType *input_;
RType *output_y_; RType *output_y_;
RType *filter_;
vector<int> strides_;
vector<int> paddings_;
vector<int> dilations_;
int groups;
RType *input_bias_; RType *input_bias_;
RType *input_mean_; RType *input_mean_;
RType *input_scale_; RType *input_scale_;
...@@ -1912,44 +1749,26 @@ class FusionConvAddBNParam : public OpParam { ...@@ -1912,44 +1749,26 @@ class FusionConvAddBNParam : public OpParam {
#ifdef FUSION_DWCONVBNRELU_OP #ifdef FUSION_DWCONVBNRELU_OP
template <typename Dtype> template <typename Dtype>
class FusionDWConvBNReluParam : public OpParam { class FusionDWConvBNReluParam : public ConvParam<Dtype> {
typedef typename DtypeTensorTrait<Dtype>::gtype GType; typedef typename DtypeTensorTrait<Dtype>::gtype GType;
typedef typename DtypeTensorTrait<Dtype>::rtype RType; typedef typename DtypeTensorTrait<Dtype>::rtype RType;
public: public:
FusionDWConvBNReluParam(const VariableNameMap &inputs, FusionDWConvBNReluParam(const VariableNameMap &inputs,
const VariableNameMap &outputs, const VariableNameMap &outputs,
const AttributeMap &attrs, const Scope &scope) { const AttributeMap &attrs, const Scope &scope)
filter_ = FilterFrom<GType>(inputs, scope); : ConvParam<Dtype>(inputs, outputs, attrs, scope) {
input_ = InputFrom<GType>(inputs, scope); output_ = OpParam::OutFrom<GType>(outputs, scope);
output_ = OutFrom<GType>(outputs, scope); input_bias_ = OpParam::InputBiasFrom<GType>(inputs, scope);
strides_ = GetAttr<vector<int>>("strides", attrs); input_mean_ = OpParam::InputMeanFrom<GType>(inputs, scope);
paddings_ = GetAttr<vector<int>>("paddings", attrs); input_scale_ = OpParam::InputScaleFrom<GType>(inputs, scope);
dilations_ = GetAttr<vector<int>>("dilations", attrs); input_variance_ = OpParam::InputVarianceFrom<GType>(inputs, scope);
groups = GetAttr<int>("groups", attrs); epsilon_ = OpParam::GetAttr<float>("epsilon", attrs);
input_bias_ = InputBiasFrom<GType>(inputs, scope); momentum_ = OpParam::GetAttr<float>("momentum", attrs);
input_mean_ = InputMeanFrom<GType>(inputs, scope); // is_test_ = OpParam::GetAttr<bool>("is_test", attrs);
input_scale_ = InputScaleFrom<GType>(inputs, scope);
input_variance_ = InputVarianceFrom<GType>(inputs, scope);
epsilon_ = GetAttr<float>("epsilon", attrs);
momentum_ = GetAttr<float>("momentum", attrs);
// is_test_ = GetAttr<bool>("is_test", attrs);
} }
const RType *Input() const { return input_; }
const RType *Filter() const { return filter_; }
RType *Output() const { return output_; } RType *Output() const { return output_; }
const vector<int> &Strides() const { return strides_; }
const vector<int> &Paddings() const { return paddings_; }
const vector<int> &Dilations() const { return dilations_; }
const int &Groups() const { return groups; }
const RType *InputBias() const { return input_bias_; } const RType *InputBias() const { return input_bias_; }
const RType *InputMean() const { return input_mean_; } const RType *InputMean() const { return input_mean_; }
...@@ -1973,13 +1792,7 @@ class FusionDWConvBNReluParam : public OpParam { ...@@ -1973,13 +1792,7 @@ class FusionDWConvBNReluParam : public OpParam {
const RType *NewBias() const { return new_bias_; } const RType *NewBias() const { return new_bias_; }
protected: protected:
RType *input_;
RType *output_; RType *output_;
RType *filter_;
vector<int> strides_;
vector<int> paddings_;
vector<int> dilations_;
int groups;
RType *input_bias_; RType *input_bias_;
RType *input_mean_; RType *input_mean_;
RType *input_scale_; RType *input_scale_;
...@@ -1995,45 +1808,26 @@ class FusionDWConvBNReluParam : public OpParam { ...@@ -1995,45 +1808,26 @@ class FusionDWConvBNReluParam : public OpParam {
#ifdef FUSION_CONVBNRELU_OP #ifdef FUSION_CONVBNRELU_OP
template <typename Dtype> template <typename Dtype>
class FusionConvBNReluParam : public OpParam { class FusionConvBNReluParam : public ConvParam<Dtype> {
typedef typename DtypeTensorTrait<Dtype>::gtype GType; typedef typename DtypeTensorTrait<Dtype>::gtype GType;
typedef typename DtypeTensorTrait<Dtype>::rtype RType; typedef typename DtypeTensorTrait<Dtype>::rtype RType;
public: public:
FusionConvBNReluParam(const VariableNameMap &inputs, FusionConvBNReluParam(const VariableNameMap &inputs,
const VariableNameMap &outputs, const VariableNameMap &outputs,
const AttributeMap &attrs, const Scope &scope) { const AttributeMap &attrs, const Scope &scope)
filter_ = FilterFrom<GType>(inputs, scope); : ConvParam<Dtype>(inputs, outputs, attrs, scope) {
input_ = InputFrom<GType>(inputs, scope); output_ = OpParam::OutFrom<GType>(outputs, scope);
output_ = OutFrom<GType>(outputs, scope); input_bias_ = OpParam::InputBiasFrom<GType>(inputs, scope);
input_mean_ = OpParam::InputMeanFrom<GType>(inputs, scope);
strides_ = GetAttr<vector<int>>("strides", attrs); input_scale_ = OpParam::InputScaleFrom<GType>(inputs, scope);
paddings_ = GetAttr<vector<int>>("paddings", attrs); input_variance_ = OpParam::InputVarianceFrom<GType>(inputs, scope);
dilations_ = GetAttr<vector<int>>("dilations", attrs); epsilon_ = OpParam::GetAttr<float>("epsilon", attrs);
groups = GetAttr<int>("groups", attrs); momentum_ = OpParam::GetAttr<float>("momentum", attrs);
input_bias_ = InputBiasFrom<GType>(inputs, scope); // is_test_ = OpParam::GetAttr<bool>("is_test", attrs);
input_mean_ = InputMeanFrom<GType>(inputs, scope);
input_scale_ = InputScaleFrom<GType>(inputs, scope);
input_variance_ = InputVarianceFrom<GType>(inputs, scope);
epsilon_ = GetAttr<float>("epsilon", attrs);
momentum_ = GetAttr<float>("momentum", attrs);
// is_test_ = GetAttr<bool>("is_test", attrs);
} }
const RType *Input() const { return input_; }
const RType *Filter() const { return filter_; }
RType *Output() const { return output_; } RType *Output() const { return output_; }
const vector<int> &Strides() const { return strides_; }
const vector<int> &Paddings() const { return paddings_; }
const vector<int> &Dilations() const { return dilations_; }
const int &Groups() const { return groups; }
const RType *InputBias() const { return input_bias_; } const RType *InputBias() const { return input_bias_; }
const RType *InputMean() const { return input_mean_; } const RType *InputMean() const { return input_mean_; }
...@@ -2057,13 +1851,7 @@ class FusionConvBNReluParam : public OpParam { ...@@ -2057,13 +1851,7 @@ class FusionConvBNReluParam : public OpParam {
const RType *NewBias() const { return new_bias_; } const RType *NewBias() const { return new_bias_; }
protected: protected:
RType *input_;
RType *output_; RType *output_;
RType *filter_;
vector<int> strides_;
vector<int> paddings_;
vector<int> dilations_;
int groups;
RType *input_bias_; RType *input_bias_;
RType *input_mean_; RType *input_mean_;
RType *input_scale_; RType *input_scale_;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册