Commit 5b08ad31 authored by: Z zhaojiaying01

Make FusionConvXXXParam inherit from ConvParam

Parent 1d896aa0
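The diff below applies the same mechanical change to every fusion conv param class: each one now derives from ConvParam<Dtype> (which itself becomes a public subclass of OpParam), forwards inputs/outputs/attrs/scope to the base constructor, and keeps only the members the fused op adds (bias, axis, batch-norm tensors, and so on), so the duplicated filter/input/stride/padding/dilation/groups extraction disappears. The following is a minimal, self-contained C++ sketch of that pattern; the AttributeMap layout and the names used here are simplified stand-ins for illustration, not the real paddle-mobile VariableNameMap/Scope API.

// Minimal sketch of the "derive the fusion param from the conv param" pattern.
// NOTE: simplified stand-in types; not the actual paddle-mobile interfaces.
#include <iostream>
#include <map>
#include <string>
#include <vector>

using AttributeMap = std::map<std::string, std::vector<int>>;

// Base param: owns the fields that every conv-like op needs.
class ConvParam {
 public:
  explicit ConvParam(const AttributeMap &attrs) {
    strides_ = attrs.at("strides");
    paddings_ = attrs.at("paddings");
    dilations_ = attrs.at("dilations");
    groups_ = attrs.at("groups")[0];
  }
  const std::vector<int> &Strides() const { return strides_; }
  const std::vector<int> &Paddings() const { return paddings_; }
  const std::vector<int> &Dilations() const { return dilations_; }
  int Groups() const { return groups_; }

 protected:
  std::vector<int> strides_, paddings_, dilations_;
  int groups_ = 1;
};

// Fusion param: delegates the shared fields to the base constructor and
// only extracts what the fused op adds (here: the elementwise-add axis).
class FusionConvAddParam : public ConvParam {
 public:
  explicit FusionConvAddParam(const AttributeMap &attrs) : ConvParam(attrs) {
    axis_ = attrs.at("axis")[0];
  }
  int Axis() const { return axis_; }

 protected:
  int axis_ = 0;
};

int main() {
  AttributeMap attrs{{"strides", {1, 1}},
                     {"paddings", {0, 0}},
                     {"dilations", {1, 1}},
                     {"groups", {1}},
                     {"axis", {1}}};
  FusionConvAddParam param(attrs);
  // Shared accessors come from the base class; only Axis() is fusion-specific.
  std::cout << param.Strides()[0] << " " << param.Groups() << " "
            << param.Axis() << std::endl;
  return 0;
}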
@@ -343,20 +343,22 @@ class OpParam {
#ifdef CONV_OP
template <typename Dtype>
class ConvParam : OpParam {
class ConvParam : public OpParam {
typedef typename DtypeTensorTrait<Dtype>::gtype GType;
typedef typename DtypeTensorTrait<Dtype>::rtype RType;
public:
ConvParam(const VariableNameMap &inputs, const VariableNameMap &outputs,
const AttributeMap &attrs, const Scope &scope) {
filter_ = FilterFrom<GType>(inputs, scope);
input_ = InputFrom<GType>(inputs, scope);
output_ = OutputFrom<GType>(outputs, scope);
strides_ = GetAttr<vector<int>>("strides", attrs);
paddings_ = GetAttr<vector<int>>("paddings", attrs);
dilations_ = GetAttr<vector<int>>("dilations", attrs);
groups = GetAttr<int>("groups", attrs);
filter_ = OpParam::FilterFrom<GType>(inputs, scope);
input_ = OpParam::InputFrom<GType>(inputs, scope);
if (outputs.count("Output")) {
output_ = OpParam::OutputFrom<GType>(outputs, scope);
}
strides_ = OpParam::GetAttr<vector<int>>("strides", attrs);
paddings_ = OpParam::GetAttr<vector<int>>("paddings", attrs);
dilations_ = OpParam::GetAttr<vector<int>>("dilations", attrs);
groups = OpParam::GetAttr<int>("groups", attrs);
}
const RType *Input() const { return input_; }
@@ -1294,52 +1296,29 @@ using FusionFcReluParam = FusionFcParam<DeviceType>;
#endif
template <typename Dtype>
class FusionConvAddParam : public OpParam {
class FusionConvAddParam : public ConvParam<Dtype> {
typedef typename DtypeTensorTrait<Dtype>::gtype GType;
typedef typename DtypeTensorTrait<Dtype>::rtype RType;
public:
FusionConvAddParam(const VariableNameMap &inputs,
const VariableNameMap &outputs, const AttributeMap &attrs,
const Scope &scope) {
bias_ = InputYFrom<GType>(inputs, scope);
axis_ = GetAttr<int>("axis", attrs);
filter_ = FilterFrom<GType>(inputs, scope);
input_ = InputFrom<GType>(inputs, scope);
output_ = OutFrom<GType>(outputs, scope);
strides_ = GetAttr<vector<int>>("strides", attrs);
paddings_ = GetAttr<vector<int>>("paddings", attrs);
dilations_ = GetAttr<vector<int>>("dilations", attrs);
groups = GetAttr<int>("groups", attrs);
const Scope &scope)
: ConvParam<Dtype>(inputs, outputs, attrs, scope) {
bias_ = OpParam::InputYFrom<GType>(inputs, scope);
axis_ = OpParam::GetAttr<int>("axis", attrs);
output_ = OpParam::OutFrom<GType>(outputs, scope);
}
RType *Bias() const { return bias_; }
const int &Axis() const { return axis_; }
const RType *Input() const { return input_; }
const RType *Filter() const { return filter_; }
RType *Output() const { return output_; }
const vector<int> &Strides() const { return strides_; }
const vector<int> &Paddings() const { return paddings_; }
const vector<int> &Dilations() const { return dilations_; }
const int &Groups() const { return groups; }
protected:
RType *bias_;
int axis_;
RType *input_;
RType *output_;
RType *filter_;
vector<int> strides_;
vector<int> paddings_;
vector<int> dilations_;
int groups;
#ifdef PADDLE_MOBILE_FPGA
private:
@@ -1366,58 +1345,33 @@ class FusionConvAddReluParam : public FusionConvAddParam<DeviceType> {
#endif
#ifdef FUSION_CONVADDPRELU_OP
template <typename DeviceType>
class FusionConvAddPReluParam : public OpParam {
typedef typename DtypeTensorTrait<DeviceType>::gtype GType;
typedef typename DtypeTensorTrait<DeviceType>::rtype RType;
template <typename Dtype>
class FusionConvAddPReluParam : public ConvParam<Dtype> {
typedef typename DtypeTensorTrait<Dtype>::gtype GType;
typedef typename DtypeTensorTrait<Dtype>::rtype RType;
public:
FusionConvAddPReluParam(const VariableNameMap &inputs,
const VariableNameMap &outputs,
const AttributeMap &attrs, const Scope &scope) {
alpha_ = InputAlphaFrom<GType>(inputs, scope);
mode_ = GetAttr<std::string>("mode", attrs);
const AttributeMap &attrs, const Scope &scope)
: ConvParam<Dtype>(inputs, outputs, attrs, scope) {
alpha_ = OpParam::InputAlphaFrom<GType>(inputs, scope);
mode_ = OpParam::GetAttr<std::string>("mode", attrs);
framework::DDim dims = alpha_->dims();
bias_ = InputYFrom<GType>(inputs, scope);
axis_ = GetAttr<int>("axis", attrs);
filter_ = FilterFrom<GType>(inputs, scope);
input_ = InputFrom<GType>(inputs, scope);
output_ = OutFrom<GType>(outputs, scope);
strides_ = GetAttr<vector<int>>("strides", attrs);
paddings_ = GetAttr<vector<int>>("paddings", attrs);
dilations_ = GetAttr<vector<int>>("dilations", attrs);
groups = GetAttr<int>("groups", attrs);
bias_ = OpParam::InputYFrom<GType>(inputs, scope);
axis_ = OpParam::GetAttr<int>("axis", attrs);
output_ = OpParam::OutFrom<GType>(outputs, scope);
}
const RType *InputAlpha() const { return alpha_; }
const std::string &Mode() const { return mode_; }
RType *Bias() const { return bias_; }
const int &Axis() const { return axis_; }
const RType *Input() const { return input_; }
const RType *Filter() const { return filter_; }
RType *Output() const { return output_; }
const vector<int> &Strides() const { return strides_; }
const vector<int> &Paddings() const { return paddings_; }
const vector<int> &Dilations() const { return dilations_; }
const int &Groups() const { return groups; }
protected:
RType *bias_;
int axis_;
RType *input_;
RType *output_;
RType *filter_;
vector<int> strides_;
vector<int> paddings_;
vector<int> dilations_;
int groups;
RType *alpha_;
std::string mode_;
#ifdef PADDLE_MOBILE_FPGA
@@ -1433,35 +1387,30 @@ class FusionConvAddPReluParam : public OpParam {
#endif
#ifdef FUSION_CONVADDADDPRELU_OP
template <typename DeviceType>
class FusionConvAddAddPReluParam : public OpParam {
typedef typename DtypeTensorTrait<DeviceType>::gtype GType;
typedef typename DtypeTensorTrait<DeviceType>::rtype RType;
template <typename Dtype>
class FusionConvAddAddPReluParam : public ConvParam<Dtype> {
typedef typename DtypeTensorTrait<Dtype>::gtype GType;
typedef typename DtypeTensorTrait<Dtype>::rtype RType;
public:
FusionConvAddAddPReluParam(const VariableNameMap &inputs,
const VariableNameMap &outputs,
const AttributeMap &attrs, const Scope &scope) {
bias1_ = InputYFrom1<GType>(inputs, scope);
alpha_ = InputAlphaFrom<GType>(inputs, scope);
mode_ = GetAttr<std::string>("mode", attrs);
const AttributeMap &attrs, const Scope &scope)
: ConvParam<Dtype>(inputs, outputs, attrs, scope) {
bias1_ = OpParam::InputYFrom1<GType>(inputs, scope);
alpha_ = OpParam::InputAlphaFrom<GType>(inputs, scope);
mode_ = OpParam::GetAttr<std::string>("mode", attrs);
framework::DDim dims = alpha_->dims();
bias_ = InputYFrom<GType>(inputs, scope);
axis_ = GetAttr<int>("axis", attrs);
filter_ = FilterFrom<GType>(inputs, scope);
input_ = InputFrom<GType>(inputs, scope);
output_ = OutFrom<GType>(outputs, scope);
strides_ = GetAttr<vector<int>>("strides", attrs);
paddings_ = GetAttr<vector<int>>("paddings", attrs);
dilations_ = GetAttr<vector<int>>("dilations", attrs);
groups = GetAttr<int>("groups", attrs);
keyOutput_ = getkey("addOut", inputs, 0);
keyX1_ = getkey("addX", inputs, 1);
keyY1_ = getkey("Y", inputs, 1);
bias_ = OpParam::InputYFrom<GType>(inputs, scope);
output_ = OpParam::OutFrom<GType>(outputs, scope);
axis_ = OpParam::GetAttr<int>("axis", attrs);
keyOutput_ = OpParam::getkey("addOut", inputs, 0);
keyX1_ = OpParam::getkey("addX", inputs, 1);
keyY1_ = OpParam::getkey("Y", inputs, 1);
if (keyX1_ == keyOutput_) {
bias1_ = InputYFrom1<GType>(inputs, scope);
bias1_ = OpParam::InputYFrom1<GType>(inputs, scope);
} else if (keyY1_ == keyOutput_) {
bias1_ = InputXFrom1<GType>(inputs, scope);
bias1_ = OpParam::InputXFrom1<GType>(inputs, scope);
}
}
const RType *InputAlpha() const { return alpha_; }
@@ -1471,31 +1420,12 @@ class FusionConvAddAddPReluParam : public OpParam {
RType *Bias() const { return bias_; }
const int &Axis() const { return axis_; }
const RType *Input() const { return input_; }
const RType *Filter() const { return filter_; }
RType *Output() const { return output_; }
const vector<int> &Strides() const { return strides_; }
const vector<int> &Paddings() const { return paddings_; }
const vector<int> &Dilations() const { return dilations_; }
const int &Groups() const { return groups; }
protected:
RType *bias_;
int axis_;
RType *input_;
RType *output_;
RType *filter_;
vector<int> strides_;
vector<int> paddings_;
vector<int> dilations_;
int groups;
RType *alpha_;
std::string mode_;
RType *bias1_;
@@ -1516,49 +1446,32 @@ class FusionConvAddAddPReluParam : public OpParam {
#ifdef FUSION_CONVADDBNRELU_OP
template <typename Dtype>
class FusionConvAddBNReluParam : public OpParam {
class FusionConvAddBNReluParam : public ConvParam<Dtype> {
typedef typename DtypeTensorTrait<Dtype>::gtype GType;
typedef typename DtypeTensorTrait<Dtype>::rtype RType;
public:
FusionConvAddBNReluParam(const VariableNameMap &inputs,
const VariableNameMap &outputs,
const AttributeMap &attrs, const Scope &scope) {
bias_ = InputYFrom<GType>(inputs, scope);
axis_ = GetAttr<int>("axis", attrs);
filter_ = FilterFrom<GType>(inputs, scope);
input_ = InputFrom<GType>(inputs, scope);
output_ = OutFrom<GType>(outputs, scope);
strides_ = GetAttr<vector<int>>("strides", attrs);
paddings_ = GetAttr<vector<int>>("paddings", attrs);
dilations_ = GetAttr<vector<int>>("dilations", attrs);
groups = GetAttr<int>("groups", attrs);
input_bias_ = InputBiasFrom<GType>(inputs, scope);
input_mean_ = InputMeanFrom<GType>(inputs, scope);
input_scale_ = InputScaleFrom<GType>(inputs, scope);
input_variance_ = InputVarianceFrom<GType>(inputs, scope);
epsilon_ = GetAttr<float>("epsilon", attrs);
momentum_ = GetAttr<float>("momentum", attrs);
// is_test_ = GetAttr<bool>("is_test", attrs);
const AttributeMap &attrs, const Scope &scope)
: ConvParam<Dtype>(inputs, outputs, attrs, scope) {
bias_ = OpParam::InputYFrom<GType>(inputs, scope);
axis_ = OpParam::GetAttr<int>("axis", attrs);
output_ = OpParam::OutFrom<GType>(outputs, scope);
input_bias_ = OpParam::InputBiasFrom<GType>(inputs, scope);
input_mean_ = OpParam::InputMeanFrom<GType>(inputs, scope);
input_scale_ = OpParam::InputScaleFrom<GType>(inputs, scope);
input_variance_ = OpParam::InputVarianceFrom<GType>(inputs, scope);
epsilon_ = OpParam::GetAttr<float>("epsilon", attrs);
momentum_ = OpParam::GetAttr<float>("momentum", attrs);
// is_test_ = OpParam::GetAttr<bool>("is_test", attrs);
}
RType *Bias() const { return bias_; }
const int &Axis() const { return axis_; }
const RType *Input() const { return input_; }
const RType *Filter() const { return filter_; }
RType *Output() const { return output_; }
const vector<int> &Strides() const { return strides_; }
const vector<int> &Paddings() const { return paddings_; }
const vector<int> &Dilations() const { return dilations_; }
const int &Groups() const { return groups; }
const RType *InputBias() const { return input_bias_; }
const RType *InputMean() const { return input_mean_; }
@@ -1584,13 +1497,7 @@ class FusionConvAddBNReluParam : public OpParam {
protected:
RType *bias_;
int axis_;
RType *input_;
RType *output_;
RType *filter_;
vector<int> strides_;
vector<int> paddings_;
vector<int> dilations_;
int groups;
RType *input_bias_;
RType *input_mean_;
RType *input_scale_;
@@ -1614,57 +1521,40 @@ class FusionConvAddBNReluParam : public OpParam {
#ifdef FUSION_CONVBNADDRELU_OP
template <typename Dtype>
class FusionConvBNAddReluParam : public OpParam {
class FusionConvBNAddReluParam : public ConvParam<Dtype> {
typedef typename DtypeTensorTrait<Dtype>::gtype GType;
typedef typename DtypeTensorTrait<Dtype>::rtype RType;
public:
FusionConvBNAddReluParam(const VariableNameMap &inputs,
const VariableNameMap &outputs,
const AttributeMap &attrs, const Scope &scope) {
bias_ = InputYFrom<GType>(inputs, scope);
axis_ = GetAttr<int>("axis", attrs);
filter_ = FilterFrom<GType>(inputs, scope);
input_ = InputFrom<GType>(inputs, scope);
output_ = OutFrom<GType>(outputs, scope);
strides_ = GetAttr<vector<int>>("strides", attrs);
paddings_ = GetAttr<vector<int>>("paddings", attrs);
dilations_ = GetAttr<vector<int>>("dilations", attrs);
groups = GetAttr<int>("groups", attrs);
input_bias_ = InputBiasFrom<GType>(inputs, scope);
input_mean_ = InputMeanFrom<GType>(inputs, scope);
input_scale_ = InputScaleFrom<GType>(inputs, scope);
input_variance_ = InputVarianceFrom<GType>(inputs, scope);
epsilon_ = GetAttr<float>("epsilon", attrs);
momentum_ = GetAttr<float>("momentum", attrs);
keyBNY_ = getkey("BNY", inputs, 0);
keyX_ = getkey("X", inputs, 0);
keyY_ = getkey("Y", inputs, 0);
const AttributeMap &attrs, const Scope &scope)
: ConvParam<Dtype>(inputs, outputs, attrs, scope) {
bias_ = OpParam::InputYFrom<GType>(inputs, scope);
axis_ = OpParam::GetAttr<int>("axis", attrs);
output_ = OpParam::OutFrom<GType>(outputs, scope);
input_bias_ = OpParam::InputBiasFrom<GType>(inputs, scope);
input_mean_ = OpParam::InputMeanFrom<GType>(inputs, scope);
input_scale_ = OpParam::InputScaleFrom<GType>(inputs, scope);
input_variance_ = OpParam::InputVarianceFrom<GType>(inputs, scope);
epsilon_ = OpParam::GetAttr<float>("epsilon", attrs);
momentum_ = OpParam::GetAttr<float>("momentum", attrs);
keyBNY_ = OpParam::getkey("BNY", inputs, 0);
keyX_ = OpParam::getkey("X", inputs, 0);
keyY_ = OpParam::getkey("Y", inputs, 0);
if (keyX_ == keyBNY_) {
bias_ = InputYFrom<GType>(inputs, scope);
bias_ = OpParam::InputYFrom<GType>(inputs, scope);
} else if (keyY_ == keyBNY_) {
bias_ = InputXFrom<GType>(inputs, scope);
bias_ = OpParam::InputXFrom<GType>(inputs, scope);
}
// is_test_ = GetAttr<bool>("is_test", attrs);
// is_test_ = OpParam::GetAttr<bool>("is_test", attrs);
}
RType *Bias() const { return bias_; }
const int &Axis() const { return axis_; }
const RType *Input() const { return input_; }
const RType *Filter() const { return filter_; }
RType *Output() const { return output_; }
const vector<int> &Strides() const { return strides_; }
const vector<int> &Paddings() const { return paddings_; }
const vector<int> &Dilations() const { return dilations_; }
const int &Groups() const { return groups; }
const RType *InputBias() const { return input_bias_; }
const RType *InputMean() const { return input_mean_; }
@@ -1690,13 +1580,7 @@ class FusionConvBNAddReluParam : public OpParam {
protected:
RType *bias_;
int axis_;
RType *input_;
RType *output_;
RType *filter_;
vector<int> strides_;
vector<int> paddings_;
vector<int> dilations_;
int groups;
RType *input_bias_;
RType *input_mean_;
RType *input_scale_;
@@ -1723,44 +1607,26 @@ class FusionConvBNAddReluParam : public OpParam {
#ifdef FUSION_CONVBN_OP
template <typename Dtype>
class FusionConvBNParam : public OpParam {
class FusionConvBNParam : public ConvParam<Dtype> {
typedef typename DtypeTensorTrait<Dtype>::gtype GType;
typedef typename DtypeTensorTrait<Dtype>::rtype RType;
public:
FusionConvBNParam(const VariableNameMap &inputs,
const VariableNameMap &outputs, const AttributeMap &attrs,
const Scope &scope) {
filter_ = FilterFrom<GType>(inputs, scope);
input_ = InputFrom<GType>(inputs, scope);
output_y_ = OutputYFrom<GType>(outputs, scope);
strides_ = GetAttr<vector<int>>("strides", attrs);
paddings_ = GetAttr<vector<int>>("paddings", attrs);
dilations_ = GetAttr<vector<int>>("dilations", attrs);
groups = GetAttr<int>("groups", attrs);
input_bias_ = InputBiasFrom<GType>(inputs, scope);
input_mean_ = InputMeanFrom<GType>(inputs, scope);
input_scale_ = InputScaleFrom<GType>(inputs, scope);
input_variance_ = InputVarianceFrom<GType>(inputs, scope);
epsilon_ = GetAttr<float>("epsilon", attrs);
momentum_ = GetAttr<float>("momentum", attrs);
// is_test_ = GetAttr<bool>("is_test", attrs);
const Scope &scope)
: ConvParam<Dtype>(inputs, outputs, attrs, scope) {
output_y_ = OpParam::OutputYFrom<GType>(outputs, scope);
input_bias_ = OpParam::InputBiasFrom<GType>(inputs, scope);
input_mean_ = OpParam::InputMeanFrom<GType>(inputs, scope);
input_scale_ = OpParam::InputScaleFrom<GType>(inputs, scope);
input_variance_ = OpParam::InputVarianceFrom<GType>(inputs, scope);
epsilon_ = OpParam::GetAttr<float>("epsilon", attrs);
momentum_ = OpParam::GetAttr<float>("momentum", attrs);
// is_test_ = OpParam::GetAttr<bool>("is_test", attrs);
}
const RType *Input() const { return input_; }
const RType *Filter() const { return filter_; }
RType *Output() const { return output_y_; }
const vector<int> &Strides() const { return strides_; }
const vector<int> &Paddings() const { return paddings_; }
const vector<int> &Dilations() const { return dilations_; }
const int &Groups() const { return groups; }
const RType *InputBias() const { return input_bias_; }
const RType *InputMean() const { return input_mean_; }
@@ -1784,13 +1650,7 @@ class FusionConvBNParam : public OpParam {
const RType *NewBias() const { return new_bias_; }
protected:
RType *input_;
RType *output_y_;
RType *filter_;
vector<int> strides_;
vector<int> paddings_;
vector<int> dilations_;
int groups;
RType *input_bias_;
RType *input_mean_;
RType *input_scale_;
@@ -1814,49 +1674,32 @@ class FusionConvBNParam : public OpParam {
#ifdef FUSION_CONVADDBN_OP
template <typename Dtype>
class FusionConvAddBNParam : public OpParam {
class FusionConvAddBNParam : public ConvParam<Dtype> {
typedef typename DtypeTensorTrait<Dtype>::gtype GType;
typedef typename DtypeTensorTrait<Dtype>::rtype RType;
public:
FusionConvAddBNParam(const VariableNameMap &inputs,
const VariableNameMap &outputs,
const AttributeMap &attrs, const Scope &scope) {
bias_ = InputYFrom<GType>(inputs, scope);
axis_ = GetAttr<int>("axis", attrs);
filter_ = FilterFrom<GType>(inputs, scope);
input_ = InputFrom<GType>(inputs, scope);
output_y_ = OutputYFrom<GType>(outputs, scope);
strides_ = GetAttr<vector<int>>("strides", attrs);
paddings_ = GetAttr<vector<int>>("paddings", attrs);
dilations_ = GetAttr<vector<int>>("dilations", attrs);
groups = GetAttr<int>("groups", attrs);
input_bias_ = InputBiasFrom<GType>(inputs, scope);
input_mean_ = InputMeanFrom<GType>(inputs, scope);
input_scale_ = InputScaleFrom<GType>(inputs, scope);
input_variance_ = InputVarianceFrom<GType>(inputs, scope);
epsilon_ = GetAttr<float>("epsilon", attrs);
momentum_ = GetAttr<float>("momentum", attrs);
// is_test_ = GetAttr<bool>("is_test", attrs);
const AttributeMap &attrs, const Scope &scope)
: ConvParam<Dtype>(inputs, outputs, attrs, scope) {
bias_ = OpParam::InputYFrom<GType>(inputs, scope);
axis_ = OpParam::GetAttr<int>("axis", attrs);
output_y_ = OpParam::OutputYFrom<GType>(outputs, scope);
input_bias_ = OpParam::InputBiasFrom<GType>(inputs, scope);
input_mean_ = OpParam::InputMeanFrom<GType>(inputs, scope);
input_scale_ = OpParam::InputScaleFrom<GType>(inputs, scope);
input_variance_ = OpParam::InputVarianceFrom<GType>(inputs, scope);
epsilon_ = OpParam::GetAttr<float>("epsilon", attrs);
momentum_ = OpParam::GetAttr<float>("momentum", attrs);
// is_test_ = OpParam::GetAttr<bool>("is_test", attrs);
}
RType *Bias() const { return bias_; }
const int &Axis() const { return axis_; }
const RType *Input() const { return input_; }
const RType *Filter() const { return filter_; }
RType *Output() const { return output_y_; }
const vector<int> &Strides() const { return strides_; }
const vector<int> &Paddings() const { return paddings_; }
const vector<int> &Dilations() const { return dilations_; }
const int &Groups() const { return groups; }
const RType *InputBias() const { return input_bias_; }
const RType *InputMean() const { return input_mean_; }
@@ -1882,13 +1725,7 @@ class FusionConvAddBNParam : public OpParam {
protected:
RType *bias_;
int axis_;
RType *input_;
RType *output_y_;
RType *filter_;
vector<int> strides_;
vector<int> paddings_;
vector<int> dilations_;
int groups;
RType *input_bias_;
RType *input_mean_;
RType *input_scale_;
@@ -1912,44 +1749,26 @@ class FusionConvAddBNParam : public OpParam {
#ifdef FUSION_DWCONVBNRELU_OP
template <typename Dtype>
class FusionDWConvBNReluParam : public OpParam {
class FusionDWConvBNReluParam : public ConvParam<Dtype> {
typedef typename DtypeTensorTrait<Dtype>::gtype GType;
typedef typename DtypeTensorTrait<Dtype>::rtype RType;
public:
FusionDWConvBNReluParam(const VariableNameMap &inputs,
const VariableNameMap &outputs,
const AttributeMap &attrs, const Scope &scope) {
filter_ = FilterFrom<GType>(inputs, scope);
input_ = InputFrom<GType>(inputs, scope);
output_ = OutFrom<GType>(outputs, scope);
strides_ = GetAttr<vector<int>>("strides", attrs);
paddings_ = GetAttr<vector<int>>("paddings", attrs);
dilations_ = GetAttr<vector<int>>("dilations", attrs);
groups = GetAttr<int>("groups", attrs);
input_bias_ = InputBiasFrom<GType>(inputs, scope);
input_mean_ = InputMeanFrom<GType>(inputs, scope);
input_scale_ = InputScaleFrom<GType>(inputs, scope);
input_variance_ = InputVarianceFrom<GType>(inputs, scope);
epsilon_ = GetAttr<float>("epsilon", attrs);
momentum_ = GetAttr<float>("momentum", attrs);
// is_test_ = GetAttr<bool>("is_test", attrs);
const AttributeMap &attrs, const Scope &scope)
: ConvParam<Dtype>(inputs, outputs, attrs, scope) {
output_ = OpParam::OutFrom<GType>(outputs, scope);
input_bias_ = OpParam::InputBiasFrom<GType>(inputs, scope);
input_mean_ = OpParam::InputMeanFrom<GType>(inputs, scope);
input_scale_ = OpParam::InputScaleFrom<GType>(inputs, scope);
input_variance_ = OpParam::InputVarianceFrom<GType>(inputs, scope);
epsilon_ = OpParam::GetAttr<float>("epsilon", attrs);
momentum_ = OpParam::GetAttr<float>("momentum", attrs);
// is_test_ = OpParam::GetAttr<bool>("is_test", attrs);
}
const RType *Input() const { return input_; }
const RType *Filter() const { return filter_; }
RType *Output() const { return output_; }
const vector<int> &Strides() const { return strides_; }
const vector<int> &Paddings() const { return paddings_; }
const vector<int> &Dilations() const { return dilations_; }
const int &Groups() const { return groups; }
const RType *InputBias() const { return input_bias_; }
const RType *InputMean() const { return input_mean_; }
@@ -1973,13 +1792,7 @@ class FusionDWConvBNReluParam : public OpParam {
const RType *NewBias() const { return new_bias_; }
protected:
RType *input_;
RType *output_;
RType *filter_;
vector<int> strides_;
vector<int> paddings_;
vector<int> dilations_;
int groups;
RType *input_bias_;
RType *input_mean_;
RType *input_scale_;
@@ -1995,45 +1808,26 @@ class FusionDWConvBNReluParam : public OpParam {
#ifdef FUSION_CONVBNRELU_OP
template <typename Dtype>
class FusionConvBNReluParam : public OpParam {
class FusionConvBNReluParam : public ConvParam<Dtype> {
typedef typename DtypeTensorTrait<Dtype>::gtype GType;
typedef typename DtypeTensorTrait<Dtype>::rtype RType;
public:
FusionConvBNReluParam(const VariableNameMap &inputs,
const VariableNameMap &outputs,
const AttributeMap &attrs, const Scope &scope) {
filter_ = FilterFrom<GType>(inputs, scope);
input_ = InputFrom<GType>(inputs, scope);
output_ = OutFrom<GType>(outputs, scope);
strides_ = GetAttr<vector<int>>("strides", attrs);
paddings_ = GetAttr<vector<int>>("paddings", attrs);
dilations_ = GetAttr<vector<int>>("dilations", attrs);
groups = GetAttr<int>("groups", attrs);
input_bias_ = InputBiasFrom<GType>(inputs, scope);
input_mean_ = InputMeanFrom<GType>(inputs, scope);
input_scale_ = InputScaleFrom<GType>(inputs, scope);
input_variance_ = InputVarianceFrom<GType>(inputs, scope);
epsilon_ = GetAttr<float>("epsilon", attrs);
momentum_ = GetAttr<float>("momentum", attrs);
// is_test_ = GetAttr<bool>("is_test", attrs);
const AttributeMap &attrs, const Scope &scope)
: ConvParam<Dtype>(inputs, outputs, attrs, scope) {
output_ = OpParam::OutFrom<GType>(outputs, scope);
input_bias_ = OpParam::InputBiasFrom<GType>(inputs, scope);
input_mean_ = OpParam::InputMeanFrom<GType>(inputs, scope);
input_scale_ = OpParam::InputScaleFrom<GType>(inputs, scope);
input_variance_ = OpParam::InputVarianceFrom<GType>(inputs, scope);
epsilon_ = OpParam::GetAttr<float>("epsilon", attrs);
momentum_ = OpParam::GetAttr<float>("momentum", attrs);
// is_test_ = OpParam::GetAttr<bool>("is_test", attrs);
}
const RType *Input() const { return input_; }
const RType *Filter() const { return filter_; }
RType *Output() const { return output_; }
const vector<int> &Strides() const { return strides_; }
const vector<int> &Paddings() const { return paddings_; }
const vector<int> &Dilations() const { return dilations_; }
const int &Groups() const { return groups; }
const RType *InputBias() const { return input_bias_; }
const RType *InputMean() const { return input_mean_; }
@@ -2057,13 +1851,7 @@ class FusionConvBNReluParam : public OpParam {
const RType *NewBias() const { return new_bias_; }
protected:
RType *input_;
RType *output_;
RType *filter_;
vector<int> strides_;
vector<int> paddings_;
vector<int> dilations_;
int groups;
RType *input_bias_;
RType *input_mean_;
RType *input_scale_;