提交 93a1705a 编写于 作者: L liuruilong

put tensor when kernel init in scope

上级 abab2bf9
...@@ -103,7 +103,7 @@ class OperatorWithKernel : public OperatorBase<Dtype> { ...@@ -103,7 +103,7 @@ class OperatorWithKernel : public OperatorBase<Dtype> {
const VariableNameMap &outputs, const AttributeMap &attrs, const VariableNameMap &outputs, const AttributeMap &attrs,
std::shared_ptr<Scope> scope) std::shared_ptr<Scope> scope)
: OperatorBase<Dtype>(type, inputs, outputs, attrs, scope), : OperatorBase<Dtype>(type, inputs, outputs, attrs, scope),
param_(inputs, outputs, attrs, *scope) { param_(inputs, outputs, attrs, scope.get()) {
#ifdef PADDLE_MOBILE_CL #ifdef PADDLE_MOBILE_CL
kernel_.InitCLHelper(scope->GetCLScpoe()); kernel_.InitCLHelper(scope->GetCLScpoe());
#endif #endif
......
...@@ -35,7 +35,7 @@ class FillConstantOp : public framework::OperatorBase<DeviceType> { ...@@ -35,7 +35,7 @@ class FillConstantOp : public framework::OperatorBase<DeviceType> {
std::shared_ptr<framework::Scope> scope) std::shared_ptr<framework::Scope> scope)
: framework::OperatorBase<DeviceType>(type, inputs, outputs, attrs, : framework::OperatorBase<DeviceType>(type, inputs, outputs, attrs,
scope), scope),
param_(inputs, outputs, attrs, *scope) {} param_(inputs, outputs, attrs, scope.get()) {}
void RunImpl() { void RunImpl() {
auto data_type = auto data_type =
static_cast<_PaddleMobile__Framework__Proto__VarType__Type>( static_cast<_PaddleMobile__Framework__Proto__VarType__Type>(
......
...@@ -41,8 +41,11 @@ bool ConvAddBNReluKernel<CPU, float>::Init( ...@@ -41,8 +41,11 @@ bool ConvAddBNReluKernel<CPU, float>::Init(
inv_std_ptr[i] = inv_std_ptr[i] =
1 / static_cast<float>(pow((variance_ptr[i] + epsilon), 0.5)); 1 / static_cast<float>(pow((variance_ptr[i] + epsilon), 0.5));
} }
Tensor *new_scale = new Tensor(); // Tensor *new_scale = new Tensor();
Tensor *new_bias = new Tensor(); // Tensor *new_bias = new Tensor();
Tensor *new_scale = param->CreateNewScale<Tensor>();
Tensor *new_bias = param->CreateNewBiase<Tensor>();
auto new_scale_ptr = new_scale->mutable_data<float>({C}); auto new_scale_ptr = new_scale->mutable_data<float>({C});
auto new_bias_ptr = new_bias->mutable_data<float>({C}); auto new_bias_ptr = new_bias->mutable_data<float>({C});
for (int i = 0; i < C; i++) { for (int i = 0; i < C; i++) {
......
...@@ -42,8 +42,8 @@ bool ConvBNReluKernel<CPU, float>::Init(FusionConvBNReluParam<CPU> *param) { ...@@ -42,8 +42,8 @@ bool ConvBNReluKernel<CPU, float>::Init(FusionConvBNReluParam<CPU> *param) {
inv_std_ptr[i] = inv_std_ptr[i] =
1 / static_cast<float>(pow((variance_ptr[i] + epsilon), 0.5)); 1 / static_cast<float>(pow((variance_ptr[i] + epsilon), 0.5));
} }
Tensor *new_scale = new Tensor(); Tensor *new_scale = param->CreateNewScale<Tensor>();
Tensor *new_bias = new Tensor(); Tensor *new_bias = param->CreateNewBiase<Tensor>();
auto new_scale_ptr = new_scale->mutable_data<float>({C}); auto new_scale_ptr = new_scale->mutable_data<float>({C});
auto new_bias_ptr = new_bias->mutable_data<float>({C}); auto new_bias_ptr = new_bias->mutable_data<float>({C});
for (int i = 0; i < C; i++) { for (int i = 0; i < C; i++) {
......
...@@ -27,12 +27,14 @@ class AnchorGeneratorParam : public OpParam { ...@@ -27,12 +27,14 @@ class AnchorGeneratorParam : public OpParam {
public: public:
AnchorGeneratorParam(const VariableNameMap &inputs, AnchorGeneratorParam(const VariableNameMap &inputs,
const VariableNameMap &outputs, const VariableNameMap &outputs,
const AttributeMap &attrs, const Scope &scope) { const AttributeMap &attrs, Scope *scope)
input_ = OpParam::GetVarValue<framework::LoDTensor>("Input", inputs, scope); : OpParam(inputs, outputs, attrs, scope) {
input_ =
OpParam::GetVarValue<framework::LoDTensor>("Input", inputs, *scope);
output_anchors_ = output_anchors_ =
OpParam::GetVarValue<framework::LoDTensor>("Anchors", outputs, scope); OpParam::GetVarValue<framework::LoDTensor>("Anchors", outputs, *scope);
output_variances_ = output_variances_ = OpParam::GetVarValue<framework::LoDTensor>(
OpParam::GetVarValue<framework::LoDTensor>("Variances", outputs, scope); "Variances", outputs, *scope);
anchor_sizes_ = OpParam::GetAttr<std::vector<float>>("anchor_sizes", attrs); anchor_sizes_ = OpParam::GetAttr<std::vector<float>>("anchor_sizes", attrs);
aspect_ratios_ = aspect_ratios_ =
...@@ -64,22 +66,23 @@ template <typename Dtype> ...@@ -64,22 +66,23 @@ template <typename Dtype>
class ProposalParam : public OpParam { class ProposalParam : public OpParam {
public: public:
ProposalParam(const VariableNameMap &inputs, const VariableNameMap &outputs, ProposalParam(const VariableNameMap &inputs, const VariableNameMap &outputs,
const AttributeMap &attrs, const Scope &scope) { const AttributeMap &attrs, Scope *scope)
: OpParam(inputs, outputs, attrs, scope) {
scores_ = scores_ =
OpParam::GetVarValue<framework::LoDTensor>("Scores", inputs, scope); OpParam::GetVarValue<framework::LoDTensor>("Scores", inputs, *scope);
bbox_deltas_ = bbox_deltas_ = OpParam::GetVarValue<framework::LoDTensor>("BboxDeltas",
OpParam::GetVarValue<framework::LoDTensor>("BboxDeltas", inputs, scope); inputs, *scope);
im_info_ = im_info_ =
OpParam::GetVarValue<framework::LoDTensor>("ImInfo", inputs, scope); OpParam::GetVarValue<framework::LoDTensor>("ImInfo", inputs, *scope);
anchors_ = anchors_ =
OpParam::GetVarValue<framework::LoDTensor>("Anchors", inputs, scope); OpParam::GetVarValue<framework::LoDTensor>("Anchors", inputs, *scope);
variances_ = variances_ =
OpParam::GetVarValue<framework::LoDTensor>("Variances", inputs, scope); OpParam::GetVarValue<framework::LoDTensor>("Variances", inputs, *scope);
rpn_rois_ = rpn_rois_ =
OpParam::GetVarValue<framework::LoDTensor>("RpnRois", outputs, scope); OpParam::GetVarValue<framework::LoDTensor>("RpnRois", outputs, *scope);
rpn_probs_ = OpParam::GetVarValue<framework::LoDTensor>("RpnRoiProbs", rpn_probs_ = OpParam::GetVarValue<framework::LoDTensor>("RpnRoiProbs",
outputs, scope); outputs, *scope);
pre_nms_topn_ = OpParam::GetAttr<int>("pre_nms_topN", attrs); pre_nms_topn_ = OpParam::GetAttr<int>("pre_nms_topN", attrs);
post_nms_topn_ = OpParam::GetAttr<int>("post_nms_topN", attrs); post_nms_topn_ = OpParam::GetAttr<int>("post_nms_topN", attrs);
...@@ -117,11 +120,13 @@ template <typename Dtype> ...@@ -117,11 +120,13 @@ template <typename Dtype>
class PSRoiPoolParam : public OpParam { class PSRoiPoolParam : public OpParam {
public: public:
PSRoiPoolParam(const VariableNameMap &inputs, const VariableNameMap &outputs, PSRoiPoolParam(const VariableNameMap &inputs, const VariableNameMap &outputs,
const AttributeMap &attrs, const Scope &scope) { const AttributeMap &attrs, Scope *scope)
input_x_ = OpParam::GetVarValue<framework::LoDTensor>("X", inputs, scope); : OpParam(inputs, outputs, attrs, scope) {
input_x_ = OpParam::GetVarValue<framework::LoDTensor>("X", inputs, *scope);
input_rois_ = input_rois_ =
OpParam::GetVarValue<framework::LoDTensor>("ROIs", inputs, scope); OpParam::GetVarValue<framework::LoDTensor>("ROIs", inputs, *scope);
output_ = OpParam::GetVarValue<framework::LoDTensor>("Out", outputs, scope); output_ =
OpParam::GetVarValue<framework::LoDTensor>("Out", outputs, *scope);
output_channels_ = OpParam::GetAttr<int>("output_channels", attrs); output_channels_ = OpParam::GetAttr<int>("output_channels", attrs);
pooled_height_ = OpParam::GetAttr<int>("pooled_height", attrs); pooled_height_ = OpParam::GetAttr<int>("pooled_height", attrs);
...@@ -152,11 +157,13 @@ class RoiPerspectiveParam : public OpParam { ...@@ -152,11 +157,13 @@ class RoiPerspectiveParam : public OpParam {
public: public:
RoiPerspectiveParam(const VariableNameMap &inputs, RoiPerspectiveParam(const VariableNameMap &inputs,
const VariableNameMap &outputs, const AttributeMap &attrs, const VariableNameMap &outputs, const AttributeMap &attrs,
const Scope &scope) { Scope *scope)
input_x_ = OpParam::GetVarValue<framework::LoDTensor>("X", inputs, scope); : OpParam(inputs, outputs, attrs, scope) {
input_x_ = OpParam::GetVarValue<framework::LoDTensor>("X", inputs, *scope);
input_rois_ = input_rois_ =
OpParam::GetVarValue<framework::LoDTensor>("ROIs", inputs, scope); OpParam::GetVarValue<framework::LoDTensor>("ROIs", inputs, *scope);
output_ = OpParam::GetVarValue<framework::LoDTensor>("Out", outputs, scope); output_ =
OpParam::GetVarValue<framework::LoDTensor>("Out", outputs, *scope);
spatial_scale_ = OpParam::GetAttr<float>("spatial_scale", attrs); spatial_scale_ = OpParam::GetAttr<float>("spatial_scale", attrs);
transformed_height_ = OpParam::GetAttr<int>("transformed_height", attrs); transformed_height_ = OpParam::GetAttr<int>("transformed_height", attrs);
......
...@@ -25,10 +25,13 @@ template <typename Dtype> ...@@ -25,10 +25,13 @@ template <typename Dtype>
class WhileParam : public OpParam { class WhileParam : public OpParam {
public: public:
WhileParam(const VariableNameMap &inputs, const VariableNameMap &outputs, WhileParam(const VariableNameMap &inputs, const VariableNameMap &outputs,
const AttributeMap &attrs, const Scope &scope) const AttributeMap &attrs, Scope *scope)
: inputs_(inputs), outputs_(outputs), scope_(scope) { : inputs_(inputs),
outputs_(outputs),
scope_(*scope),
OpParam(inputs, outputs, attrs, scope) {
cond_ = cond_ =
OpParam::GetVarValue<framework::LoDTensor>("Condition", inputs, scope); OpParam::GetVarValue<framework::LoDTensor>("Condition", inputs, *scope);
sub_block_ = OpParam::GetAttr<int>("sub_block", attrs); sub_block_ = OpParam::GetAttr<int>("sub_block", attrs);
} }
......
...@@ -69,6 +69,30 @@ struct DtypeTensorTrait<GPU_CL> { ...@@ -69,6 +69,30 @@ struct DtypeTensorTrait<GPU_CL> {
#endif #endif
class OpParam { class OpParam {
public:
OpParam(const VariableNameMap &inputs, const VariableNameMap &outputs,
const AttributeMap &attrs, Scope *scope) {
scope_pointer_ = scope;
inputs_ = inputs;
}
template <typename T>
T *CreateNewScale() {
std::string scale_key = Getkey("Scale", inputs_, 0);
auto var = scope_pointer_->Var(scale_key + "_new");
return var->GetMutable<T>();
}
template <typename T>
T *CreateNewBiase() {
std::string biase_key = Getkey("Bias", inputs_, 0);
auto var = scope_pointer_->Var(biase_key + "_new");
return var->GetMutable<T>();
}
VariableNameMap inputs_;
Scope *scope_pointer_ = nullptr;
protected: protected:
template <typename T> template <typename T>
static T *InputH0From(const VariableNameMap &inputs, const Scope &scope) { static T *InputH0From(const VariableNameMap &inputs, const Scope &scope) {
...@@ -359,8 +383,10 @@ class OpParam { ...@@ -359,8 +383,10 @@ class OpParam {
} }
} }
static std::string getkey(const string &key, const VariableNameMap &var_map, static std::string Getkey(const string &key, const VariableNameMap &var_map,
int index) { int index) {
PADDLE_MOBILE_ENFORCE(var_map.count(key) > index,
"%s is not contained in var_map", key.c_str())
auto var_vec = var_map.at(key); auto var_vec = var_map.at(key);
return var_vec[index]; return var_vec[index];
} }
...@@ -414,11 +440,12 @@ class ConvParam : public OpParam { ...@@ -414,11 +440,12 @@ class ConvParam : public OpParam {
public: public:
ConvParam(const VariableNameMap &inputs, const VariableNameMap &outputs, ConvParam(const VariableNameMap &inputs, const VariableNameMap &outputs,
const AttributeMap &attrs, const Scope &scope) { const AttributeMap &attrs, Scope *scope)
filter_ = OpParam::FilterFrom<GType>(inputs, scope); : OpParam(inputs, outputs, attrs, scope) {
input_ = OpParam::InputFrom<GType>(inputs, scope); filter_ = OpParam::FilterFrom<GType>(inputs, *scope);
input_ = OpParam::InputFrom<GType>(inputs, *scope);
if (outputs.count("Output")) { if (outputs.count("Output")) {
output_ = OpParam::OutputFrom<GType>(outputs, scope); output_ = OpParam::OutputFrom<GType>(outputs, *scope);
} }
strides_ = OpParam::GetAttr<vector<int>>("strides", attrs); strides_ = OpParam::GetAttr<vector<int>>("strides", attrs);
paddings_ = OpParam::GetAttr<vector<int>>("paddings", attrs); paddings_ = OpParam::GetAttr<vector<int>>("paddings", attrs);
...@@ -500,17 +527,18 @@ template <typename Dtype> ...@@ -500,17 +527,18 @@ template <typename Dtype>
Print &operator<<(Print &printer, const ConvParam<Dtype> &conv_param); Print &operator<<(Print &printer, const ConvParam<Dtype> &conv_param);
template <typename Dtype> template <typename Dtype>
class ElementwiseAddParam : OpParam { class ElementwiseAddParam : public OpParam {
typedef typename DtypeTensorTrait<Dtype>::gtype GType; typedef typename DtypeTensorTrait<Dtype>::gtype GType;
typedef typename DtypeTensorTrait<Dtype>::rtype RType; typedef typename DtypeTensorTrait<Dtype>::rtype RType;
public: public:
ElementwiseAddParam(const VariableNameMap &inputs, ElementwiseAddParam(const VariableNameMap &inputs,
const VariableNameMap &outputs, const AttributeMap &attrs, const VariableNameMap &outputs, const AttributeMap &attrs,
const Scope &scope) { Scope *scope)
input_x_ = InputXFrom<GType>(inputs, scope); : OpParam(inputs, outputs, attrs, scope) {
input_y_ = InputYFrom<GType>(inputs, scope); input_x_ = InputXFrom<GType>(inputs, *scope);
out_ = OutFrom<GType>(outputs, scope); input_y_ = InputYFrom<GType>(inputs, *scope);
out_ = OutFrom<GType>(outputs, *scope);
axis_ = GetAttr<int>("axis", attrs); axis_ = GetAttr<int>("axis", attrs);
} }
...@@ -540,17 +568,18 @@ class ElementwiseAddParam : OpParam { ...@@ -540,17 +568,18 @@ class ElementwiseAddParam : OpParam {
#ifdef ELEMENTWISEMUL_OP #ifdef ELEMENTWISEMUL_OP
template <typename Dtype> template <typename Dtype>
class ElementwiseMulParam : OpParam { class ElementwiseMulParam : public OpParam {
typedef typename DtypeTensorTrait<Dtype>::gtype GType; typedef typename DtypeTensorTrait<Dtype>::gtype GType;
typedef typename DtypeTensorTrait<Dtype>::rtype RType; typedef typename DtypeTensorTrait<Dtype>::rtype RType;
public: public:
ElementwiseMulParam(const VariableNameMap &inputs, ElementwiseMulParam(const VariableNameMap &inputs,
const VariableNameMap &outputs, const AttributeMap &attrs, const VariableNameMap &outputs, const AttributeMap &attrs,
const Scope &scope) { Scope *scope)
input_x_ = InputXFrom<GType>(inputs, scope); : OpParam(inputs, outputs, attrs, scope) {
input_y_ = InputYFrom<GType>(inputs, scope); input_x_ = InputXFrom<GType>(inputs, *scope);
out_ = OutFrom<GType>(outputs, scope); input_y_ = InputYFrom<GType>(inputs, *scope);
out_ = OutFrom<GType>(outputs, *scope);
axis_ = GetAttr<int>("axis", attrs); axis_ = GetAttr<int>("axis", attrs);
} }
...@@ -577,17 +606,18 @@ using ElementwiseAddReluParam = ElementwiseAddParam<Dtype>; ...@@ -577,17 +606,18 @@ using ElementwiseAddReluParam = ElementwiseAddParam<Dtype>;
#ifdef ELEMENTWISESUB_OP #ifdef ELEMENTWISESUB_OP
template <typename Dtype> template <typename Dtype>
class ElementwiseSubParam : OpParam { class ElementwiseSubParam : public OpParam {
typedef typename DtypeTensorTrait<Dtype>::gtype GType; typedef typename DtypeTensorTrait<Dtype>::gtype GType;
typedef typename DtypeTensorTrait<Dtype>::rtype RType; typedef typename DtypeTensorTrait<Dtype>::rtype RType;
public: public:
ElementwiseSubParam(const VariableNameMap &inputs, ElementwiseSubParam(const VariableNameMap &inputs,
const VariableNameMap &outputs, const AttributeMap &attrs, const VariableNameMap &outputs, const AttributeMap &attrs,
const Scope &scope) { Scope *scope)
input_x_ = InputXFrom<GType>(inputs, scope); : OpParam(inputs, outputs, attrs, scope) {
input_y_ = InputYFrom<GType>(inputs, scope); input_x_ = InputXFrom<GType>(inputs, *scope);
out_ = OutFrom<GType>(outputs, scope); input_y_ = InputYFrom<GType>(inputs, *scope);
out_ = OutFrom<GType>(outputs, *scope);
axis_ = GetAttr<int>("axis", attrs); axis_ = GetAttr<int>("axis", attrs);
} }
...@@ -609,16 +639,17 @@ class ElementwiseSubParam : OpParam { ...@@ -609,16 +639,17 @@ class ElementwiseSubParam : OpParam {
#ifdef MUL_OP #ifdef MUL_OP
template <typename Dtype> template <typename Dtype>
class MulParam : OpParam { class MulParam : public OpParam {
typedef typename DtypeTensorTrait<Dtype>::gtype GType; typedef typename DtypeTensorTrait<Dtype>::gtype GType;
typedef typename DtypeTensorTrait<Dtype>::rtype RType; typedef typename DtypeTensorTrait<Dtype>::rtype RType;
public: public:
MulParam(const VariableNameMap &inputs, const VariableNameMap &outputs, MulParam(const VariableNameMap &inputs, const VariableNameMap &outputs,
const AttributeMap &attrs, const Scope &scope) { const AttributeMap &attrs, Scope *scope)
input_x_ = InputXFrom<GType>(inputs, scope); : OpParam(inputs, outputs, attrs, scope) {
input_y_ = InputYFrom<GType>(inputs, scope); input_x_ = InputXFrom<GType>(inputs, *scope);
out_ = OutFrom<GType>(outputs, scope); input_y_ = InputYFrom<GType>(inputs, *scope);
out_ = OutFrom<GType>(outputs, *scope);
x_num_col_dims_ = GetAttr<int>("x_num_col_dims", attrs); x_num_col_dims_ = GetAttr<int>("x_num_col_dims", attrs);
y_num_col_dims_ = GetAttr<int>("y_num_col_dims", attrs); y_num_col_dims_ = GetAttr<int>("y_num_col_dims", attrs);
} }
...@@ -650,9 +681,10 @@ class ConcatParam : public OpParam { ...@@ -650,9 +681,10 @@ class ConcatParam : public OpParam {
public: public:
ConcatParam(const VariableNameMap &inputs, const VariableNameMap &outputs, ConcatParam(const VariableNameMap &inputs, const VariableNameMap &outputs,
const AttributeMap &attrs, const Scope &scope) { const AttributeMap &attrs, Scope *scope)
inputs_ = InputMultiFrom<GType>(inputs, scope); : OpParam(inputs, outputs, attrs, scope) {
out_ = OutFrom<GType>(outputs, scope); inputs_ = InputMultiFrom<GType>(inputs, *scope);
out_ = OutFrom<GType>(outputs, *scope);
axis_ = GetAttr<int>("axis", attrs); axis_ = GetAttr<int>("axis", attrs);
} }
...@@ -686,11 +718,12 @@ class SumParam : public OpParam { ...@@ -686,11 +718,12 @@ class SumParam : public OpParam {
public: public:
SumParam(const VariableNameMap &inputs, const VariableNameMap &outputs, SumParam(const VariableNameMap &inputs, const VariableNameMap &outputs,
const AttributeMap &attrs, const Scope &scope) { const AttributeMap &attrs, Scope *scope)
inputs_vars_ = InputMultiVarsFrom(inputs, scope); : OpParam(inputs, outputs, attrs, scope) {
out_var_ = OutVarFrom(outputs, scope); inputs_vars_ = InputMultiVarsFrom(inputs, *scope);
inputs_ = InputMultiFrom<GType>(inputs, scope); out_var_ = OutVarFrom(outputs, *scope);
out_ = OutFrom<GType>(outputs, scope); inputs_ = InputMultiFrom<GType>(inputs, *scope);
out_ = OutFrom<GType>(outputs, *scope);
} }
vector<Variable *> InputsVars() const { return inputs_vars_; } vector<Variable *> InputsVars() const { return inputs_vars_; }
...@@ -717,10 +750,11 @@ class LrnParam : public OpParam { ...@@ -717,10 +750,11 @@ class LrnParam : public OpParam {
public: public:
LrnParam(const VariableNameMap &inputs, const VariableNameMap &outputs, LrnParam(const VariableNameMap &inputs, const VariableNameMap &outputs,
const AttributeMap &attrs, const Scope &scope) { const AttributeMap &attrs, Scope *scope)
input_x_ = InputXFrom<GType>(inputs, scope); : OpParam(inputs, outputs, attrs, scope) {
out_ = OutFrom<GType>(outputs, scope); input_x_ = InputXFrom<GType>(inputs, *scope);
mid_out_ = MidOutFrom<GType>(outputs, scope); out_ = OutFrom<GType>(outputs, *scope);
mid_out_ = MidOutFrom<GType>(outputs, *scope);
n_ = GetAttr<int>("n", attrs); n_ = GetAttr<int>("n", attrs);
alpha_ = GetAttr<float>("alpha", attrs); alpha_ = GetAttr<float>("alpha", attrs);
beta_ = GetAttr<float>("beta", attrs); beta_ = GetAttr<float>("beta", attrs);
...@@ -758,16 +792,17 @@ class LrnParam : public OpParam { ...@@ -758,16 +792,17 @@ class LrnParam : public OpParam {
#ifdef NORM_OP #ifdef NORM_OP
template <typename Dtype> template <typename Dtype>
class NormParam : OpParam { class NormParam : public OpParam {
typedef typename DtypeTensorTrait<Dtype>::gtype GType; typedef typename DtypeTensorTrait<Dtype>::gtype GType;
typedef typename DtypeTensorTrait<Dtype>::rtype RType; typedef typename DtypeTensorTrait<Dtype>::rtype RType;
public: public:
NormParam(const VariableNameMap &inputs, const VariableNameMap &outputs, NormParam(const VariableNameMap &inputs, const VariableNameMap &outputs,
const AttributeMap &attrs, const Scope &scope) { const AttributeMap &attrs, Scope *scope)
input_x_ = InputXFrom<GType>(inputs, scope); : OpParam(inputs, outputs, attrs, scope) {
out_ = OutFrom<GType>(outputs, scope); input_x_ = InputXFrom<GType>(inputs, *scope);
output_norm_ = OutputNormFrom<GType>(outputs, scope); out_ = OutFrom<GType>(outputs, *scope);
output_norm_ = OutputNormFrom<GType>(outputs, *scope);
epsilon_ = GetAttr<float>("epsilon", attrs); epsilon_ = GetAttr<float>("epsilon", attrs);
axis_ = GetAttr<int>("axis", attrs); axis_ = GetAttr<int>("axis", attrs);
} }
...@@ -793,19 +828,20 @@ class NormParam : OpParam { ...@@ -793,19 +828,20 @@ class NormParam : OpParam {
#ifdef BATCHNORM_OP #ifdef BATCHNORM_OP
template <typename Dtype> template <typename Dtype>
class BatchNormParam : OpParam { class BatchNormParam : public OpParam {
typedef typename DtypeTensorTrait<Dtype>::gtype GType; typedef typename DtypeTensorTrait<Dtype>::gtype GType;
typedef typename DtypeTensorTrait<Dtype>::rtype RType; typedef typename DtypeTensorTrait<Dtype>::rtype RType;
public: public:
BatchNormParam(const VariableNameMap &inputs, const VariableNameMap &outputs, BatchNormParam(const VariableNameMap &inputs, const VariableNameMap &outputs,
const AttributeMap &attrs, const Scope &scope) { const AttributeMap &attrs, Scope *scope)
input_x_ = InputXFrom<GType>(inputs, scope); : OpParam(inputs, outputs, attrs, scope) {
output_y_ = OutputYFrom<GType>(outputs, scope); input_x_ = InputXFrom<GType>(inputs, *scope);
input_bias_ = InputBiasFrom<GType>(inputs, scope); output_y_ = OutputYFrom<GType>(outputs, *scope);
input_mean_ = InputMeanFrom<GType>(inputs, scope); input_bias_ = InputBiasFrom<GType>(inputs, *scope);
input_scale_ = InputScaleFrom<GType>(inputs, scope); input_mean_ = InputMeanFrom<GType>(inputs, *scope);
input_variance_ = InputVarianceFrom<GType>(inputs, scope); input_scale_ = InputScaleFrom<GType>(inputs, *scope);
input_variance_ = InputVarianceFrom<GType>(inputs, *scope);
epsilon_ = GetAttr<float>("epsilon", attrs); epsilon_ = GetAttr<float>("epsilon", attrs);
momentum_ = GetAttr<float>("momentum", attrs); momentum_ = GetAttr<float>("momentum", attrs);
// is_test_ = GetAttr<bool>("is_test", attrs); // is_test_ = GetAttr<bool>("is_test", attrs);
...@@ -863,10 +899,11 @@ class PoolParam : public OpParam { ...@@ -863,10 +899,11 @@ class PoolParam : public OpParam {
public: public:
PoolParam(const VariableNameMap &inputs, const VariableNameMap &outputs, PoolParam(const VariableNameMap &inputs, const VariableNameMap &outputs,
const AttributeMap &attrs, const Scope &scope) { const AttributeMap &attrs, Scope *scope)
input_ = InputXFrom<GType>(inputs, scope); : OpParam(inputs, outputs, attrs, scope) {
input_ = InputXFrom<GType>(inputs, *scope);
output_ = OutFrom<GType>(outputs, scope); output_ = OutFrom<GType>(outputs, *scope);
pooling_type_ = GetStringAttr("pooling_type", attrs); pooling_type_ = GetStringAttr("pooling_type", attrs);
ksize_ = GetAttr<vector<int>>("ksize", attrs); ksize_ = GetAttr<vector<int>>("ksize", attrs);
strides_ = GetAttr<vector<int>>("strides", attrs); strides_ = GetAttr<vector<int>>("strides", attrs);
...@@ -920,11 +957,12 @@ class PriorBoxParam : public OpParam { ...@@ -920,11 +957,12 @@ class PriorBoxParam : public OpParam {
public: public:
PriorBoxParam(const VariableNameMap &inputs, const VariableNameMap &outputs, PriorBoxParam(const VariableNameMap &inputs, const VariableNameMap &outputs,
const AttributeMap &attrs, const Scope &scope) { const AttributeMap &attrs, Scope *scope)
input_ = InputFrom<GType>(inputs, scope); : OpParam(inputs, outputs, attrs, scope) {
input_image_ = InputImageFrom<GType>(inputs, scope); input_ = InputFrom<GType>(inputs, *scope);
output_boxes_ = OutputBoxesFrom<GType>(outputs, scope); input_image_ = InputImageFrom<GType>(inputs, *scope);
output_variances_ = OutputVariancesFrom<GType>(outputs, scope); output_boxes_ = OutputBoxesFrom<GType>(outputs, *scope);
output_variances_ = OutputVariancesFrom<GType>(outputs, *scope);
min_sizes_ = GetAttr<vector<float>>("min_sizes", attrs); min_sizes_ = GetAttr<vector<float>>("min_sizes", attrs);
max_sizes_ = GetAttr<vector<float>>("max_sizes", attrs); max_sizes_ = GetAttr<vector<float>>("max_sizes", attrs);
aspect_ratios_ = GetAttr<vector<float>>("aspect_ratios", attrs); aspect_ratios_ = GetAttr<vector<float>>("aspect_ratios", attrs);
...@@ -998,11 +1036,12 @@ class BoxCoderParam : public OpParam { ...@@ -998,11 +1036,12 @@ class BoxCoderParam : public OpParam {
public: public:
BoxCoderParam(const VariableNameMap &inputs, const VariableNameMap &outputs, BoxCoderParam(const VariableNameMap &inputs, const VariableNameMap &outputs,
const AttributeMap &attrs, const Scope &scope) { const AttributeMap &attrs, Scope *scope)
input_priorbox_ = InputPriorBoxFrom<GType>(inputs, scope); : OpParam(inputs, outputs, attrs, scope) {
input_priorboxvar_ = InputPriorBoxVarFrom<GType>(inputs, scope); input_priorbox_ = InputPriorBoxFrom<GType>(inputs, *scope);
input_targetbox_ = InputTargetBoxFrom<GType>(inputs, scope); input_priorboxvar_ = InputPriorBoxVarFrom<GType>(inputs, *scope);
output_box_ = OutputBoxFrom<GType>(outputs, scope); input_targetbox_ = InputTargetBoxFrom<GType>(inputs, *scope);
output_box_ = OutputBoxFrom<GType>(outputs, *scope);
code_type_ = GetStringAttr("code_type", attrs); code_type_ = GetStringAttr("code_type", attrs);
} }
const RType *InputPriorBox() const { return input_priorbox_; } const RType *InputPriorBox() const { return input_priorbox_; }
...@@ -1032,9 +1071,10 @@ class SoftmaxParam : public OpParam { ...@@ -1032,9 +1071,10 @@ class SoftmaxParam : public OpParam {
public: public:
SoftmaxParam(const VariableNameMap &inputs, const VariableNameMap &outputs, SoftmaxParam(const VariableNameMap &inputs, const VariableNameMap &outputs,
const AttributeMap &attrs, const Scope &scope) { const AttributeMap &attrs, Scope *scope)
input_x_ = InputXFrom<GType>(inputs, scope); : OpParam(inputs, outputs, attrs, scope) {
out_ = OutFrom<GType>(outputs, scope); input_x_ = InputXFrom<GType>(inputs, *scope);
out_ = OutFrom<GType>(outputs, *scope);
} }
const GType *InputX() const { return input_x_; } const GType *InputX() const { return input_x_; }
GType *Out() const { return out_; } GType *Out() const { return out_; }
...@@ -1068,9 +1108,10 @@ class SigmoidParam : public OpParam { ...@@ -1068,9 +1108,10 @@ class SigmoidParam : public OpParam {
public: public:
SigmoidParam(const VariableNameMap &inputs, const VariableNameMap &outputs, SigmoidParam(const VariableNameMap &inputs, const VariableNameMap &outputs,
const AttributeMap &attrs, const Scope &scope) { const AttributeMap &attrs, Scope *scope)
input_x_ = InputXFrom<GType>(inputs, scope); : OpParam(inputs, outputs, attrs, scope) {
out_ = OutFrom<GType>(outputs, scope); input_x_ = InputXFrom<GType>(inputs, *scope);
out_ = OutFrom<GType>(outputs, *scope);
} }
const RType *InputX() const { return input_x_; } const RType *InputX() const { return input_x_; }
RType *Out() const { return out_; } RType *Out() const { return out_; }
...@@ -1099,10 +1140,11 @@ class MultiClassNMSParam : public OpParam { ...@@ -1099,10 +1140,11 @@ class MultiClassNMSParam : public OpParam {
public: public:
MultiClassNMSParam(const VariableNameMap &inputs, MultiClassNMSParam(const VariableNameMap &inputs,
const VariableNameMap &outputs, const AttributeMap &attrs, const VariableNameMap &outputs, const AttributeMap &attrs,
const Scope &scope) { Scope *scope)
input_bboxes_ = InputBBoxesFrom<GType>(inputs, scope); : OpParam(inputs, outputs, attrs, scope) {
input_scores_ = InputScoresFrom<GType>(inputs, scope); input_bboxes_ = InputBBoxesFrom<GType>(inputs, *scope);
out_ = OutFrom<GType>(outputs, scope); input_scores_ = InputScoresFrom<GType>(inputs, *scope);
out_ = OutFrom<GType>(outputs, *scope);
background_label_ = GetAttr<int>("background_label", attrs); background_label_ = GetAttr<int>("background_label", attrs);
nms_top_k_ = GetAttr<int>("nms_top_k", attrs); nms_top_k_ = GetAttr<int>("nms_top_k", attrs);
keep_top_k_ = GetAttr<int>("keep_top_k", attrs); keep_top_k_ = GetAttr<int>("keep_top_k", attrs);
...@@ -1151,9 +1193,10 @@ class PolygonBoxTransformParam : public OpParam { ...@@ -1151,9 +1193,10 @@ class PolygonBoxTransformParam : public OpParam {
public: public:
PolygonBoxTransformParam(const VariableNameMap &inputs, PolygonBoxTransformParam(const VariableNameMap &inputs,
const VariableNameMap &outputs, const VariableNameMap &outputs,
const AttributeMap &attrs, const Scope &scope) { const AttributeMap &attrs, Scope *scope)
input_ = InputFrom<GType>(inputs, scope); : OpParam(inputs, outputs, attrs, scope) {
output_ = OutputFrom<GType>(outputs, scope); input_ = InputFrom<GType>(inputs, *scope);
output_ = OutputFrom<GType>(outputs, *scope);
} }
const RType *Input() const { return input_; } const RType *Input() const { return input_; }
RType *Output() const { return output_; } RType *Output() const { return output_; }
...@@ -1171,16 +1214,17 @@ class FeedParam : public OpParam { ...@@ -1171,16 +1214,17 @@ class FeedParam : public OpParam {
public: public:
FeedParam(const VariableNameMap &inputs, const VariableNameMap &outputs, FeedParam(const VariableNameMap &inputs, const VariableNameMap &outputs,
const AttributeMap &attrs, const Scope &scope) { const AttributeMap &attrs, Scope *scope)
: OpParam(inputs, outputs, attrs, scope) {
#ifdef PADDLE_MOBILE_FPGA #ifdef PADDLE_MOBILE_FPGA
static int feed_num = 0; static int feed_num = 0;
auto new_name = std::string("feed") + std::to_string(feed_num++); auto new_name = std::string("feed") + std::to_string(feed_num++);
const_cast<VariableNameMap &>(inputs).at("X") = {string(new_name)}; const_cast<VariableNameMap &>(inputs).at("X") = {string(new_name)};
#endif #endif
input_x_ = InputXFrom<LoDTensor>(inputs, scope); input_x_ = InputXFrom<LoDTensor>(inputs, *scope);
out_ = OutFrom<GType>(outputs, scope); out_ = OutFrom<GType>(outputs, *scope);
auto var = scope.FindVar("batch_size"); auto var = scope->FindVar("batch_size");
batch_size = var->GetValue<int>(); batch_size = var->GetValue<int>();
} }
const LoDTensor *InputX() const { return input_x_; } const LoDTensor *InputX() const { return input_x_; }
...@@ -1200,14 +1244,15 @@ class FetchParam : public OpParam { ...@@ -1200,14 +1244,15 @@ class FetchParam : public OpParam {
public: public:
FetchParam(const VariableNameMap &inputs, const VariableNameMap &outputs, FetchParam(const VariableNameMap &inputs, const VariableNameMap &outputs,
const AttributeMap &attrs, const Scope &scope) { const AttributeMap &attrs, Scope *scope)
: OpParam(inputs, outputs, attrs, scope) {
#ifdef PADDLE_MOBILE_FPGA #ifdef PADDLE_MOBILE_FPGA
static int fetch_num = 0; static int fetch_num = 0;
auto new_name = std::string("fetch") + std::to_string(fetch_num++); auto new_name = std::string("fetch") + std::to_string(fetch_num++);
const_cast<VariableNameMap &>(outputs).at("Out") = {string(new_name)}; const_cast<VariableNameMap &>(outputs).at("Out") = {string(new_name)};
#endif #endif
input_x_ = InputXFrom<GType>(inputs, scope); input_x_ = InputXFrom<GType>(inputs, *scope);
out_ = OutFrom(outputs, scope); out_ = OutFrom(outputs, *scope);
} }
const RType *InputX() const { return input_x_; } const RType *InputX() const { return input_x_; }
...@@ -1237,9 +1282,10 @@ class FillConstantParam : public OpParam { ...@@ -1237,9 +1282,10 @@ class FillConstantParam : public OpParam {
public: public:
FillConstantParam(const VariableNameMap &inputs, FillConstantParam(const VariableNameMap &inputs,
const VariableNameMap &outputs, const AttributeMap &attrs, const VariableNameMap &outputs, const AttributeMap &attrs,
const Scope &scope) { Scope *scope)
out_var_ = OutVarFrom(outputs, scope); : OpParam(inputs, outputs, attrs, scope) {
out_ = OutFrom<GType>(outputs, scope); out_var_ = OutVarFrom(outputs, *scope);
out_ = OutFrom<GType>(outputs, *scope);
dtype_ = GetAttr<int>("dtype", attrs); dtype_ = GetAttr<int>("dtype", attrs);
shape_ = GetAttr<vector<int>>("shape", attrs); shape_ = GetAttr<vector<int>>("shape", attrs);
value_ = GetAttr<float>("value", attrs); value_ = GetAttr<float>("value", attrs);
...@@ -1272,9 +1318,10 @@ class TransposeParam : public OpParam { ...@@ -1272,9 +1318,10 @@ class TransposeParam : public OpParam {
public: public:
TransposeParam(const VariableNameMap &inputs, const VariableNameMap &outputs, TransposeParam(const VariableNameMap &inputs, const VariableNameMap &outputs,
const AttributeMap &attrs, const Scope &scope) { const AttributeMap &attrs, Scope *scope)
input_x_ = InputXFrom<GType>(inputs, scope); : OpParam(inputs, outputs, attrs, scope) {
out_ = OutFrom<GType>(outputs, scope); input_x_ = InputXFrom<GType>(inputs, *scope);
out_ = OutFrom<GType>(outputs, *scope);
axis_ = GetAttr<vector<int>>("axis", attrs); axis_ = GetAttr<vector<int>>("axis", attrs);
} }
...@@ -1299,10 +1346,11 @@ class Transpose2Param : public OpParam { ...@@ -1299,10 +1346,11 @@ class Transpose2Param : public OpParam {
public: public:
Transpose2Param(const VariableNameMap &inputs, const VariableNameMap &outputs, Transpose2Param(const VariableNameMap &inputs, const VariableNameMap &outputs,
const AttributeMap &attrs, const Scope &scope) { const AttributeMap &attrs, Scope *scope)
input_x_ = InputXFrom<GType>(inputs, scope); : OpParam(inputs, outputs, attrs, scope) {
out_ = OutFrom<GType>(outputs, scope); input_x_ = InputXFrom<GType>(inputs, *scope);
output_xshape_ = OutputXShapeFrom<GType>(outputs, scope); out_ = OutFrom<GType>(outputs, *scope);
output_xshape_ = OutputXShapeFrom<GType>(outputs, *scope);
axis_ = GetAttr<vector<int>>("axis", attrs); axis_ = GetAttr<vector<int>>("axis", attrs);
} }
...@@ -1330,10 +1378,11 @@ class LookupParam : public OpParam { ...@@ -1330,10 +1378,11 @@ class LookupParam : public OpParam {
public: public:
LookupParam(const VariableNameMap &inputs, const VariableNameMap &outputs, LookupParam(const VariableNameMap &inputs, const VariableNameMap &outputs,
const AttributeMap &attrs, const Scope &scope) { const AttributeMap &attrs, Scope *scope)
input_w_ = InputWFrom<GType>(inputs, scope); : OpParam(inputs, outputs, attrs, scope) {
input_ids_ = InputIdsFrom<GType>(inputs, scope); input_w_ = InputWFrom<GType>(inputs, *scope);
out_ = OutFrom<GType>(outputs, scope); input_ids_ = InputIdsFrom<GType>(inputs, *scope);
out_ = OutFrom<GType>(outputs, *scope);
padding_idx_ = GetAttr<int64_t>("padding_idx", attrs); padding_idx_ = GetAttr<int64_t>("padding_idx", attrs);
} }
...@@ -1360,12 +1409,13 @@ class CrfParam : public OpParam { ...@@ -1360,12 +1409,13 @@ class CrfParam : public OpParam {
// {G_OP_TYPE_CRF, {{"Emission", "Transition", "Label"}, {"ViterbiPath"}}}, // {G_OP_TYPE_CRF, {{"Emission", "Transition", "Label"}, {"ViterbiPath"}}},
CrfParam(const VariableNameMap &inputs, const VariableNameMap &outputs, CrfParam(const VariableNameMap &inputs, const VariableNameMap &outputs,
const AttributeMap &attrs, const Scope &scope) { const AttributeMap &attrs, Scope *scope)
: OpParam(inputs, outputs, attrs, scope) {
// todo crf params // todo crf params
input_emission_ = InputEmissionFrom<GType>(inputs, scope); input_emission_ = InputEmissionFrom<GType>(inputs, *scope);
input_transition_ = InputTransitionFrom<GType>(inputs, scope); input_transition_ = InputTransitionFrom<GType>(inputs, *scope);
input_label_ = InputLabelFrom<GType>(inputs, scope); input_label_ = InputLabelFrom<GType>(inputs, *scope);
output_viterbipath_ = OutputViterbiPathFrom<GType>(outputs, scope); output_viterbipath_ = OutputViterbiPathFrom<GType>(outputs, *scope);
// padding_idx_ = GetAttr<int64_t>("padding_idx", attrs); // padding_idx_ = GetAttr<int64_t>("padding_idx", attrs);
} }
const GType *InputEmission() const { return input_emission_; } const GType *InputEmission() const { return input_emission_; }
...@@ -1396,10 +1446,11 @@ class ReshapeParam : public OpParam { ...@@ -1396,10 +1446,11 @@ class ReshapeParam : public OpParam {
public: public:
ReshapeParam(const VariableNameMap &inputs, const VariableNameMap &outputs, ReshapeParam(const VariableNameMap &inputs, const VariableNameMap &outputs,
const AttributeMap &attrs, const Scope &scope) { const AttributeMap &attrs, Scope *scope)
input_x_ = InputXFrom<GType>(inputs, scope); : OpParam(inputs, outputs, attrs, scope) {
input_shape_ = InputShapeFrom<GType>(inputs, scope); input_x_ = InputXFrom<GType>(inputs, *scope);
out_ = OutFrom<GType>(outputs, scope); input_shape_ = InputShapeFrom<GType>(inputs, *scope);
out_ = OutFrom<GType>(outputs, *scope);
shape_ = GetAttr<vector<int>>("shape", attrs); shape_ = GetAttr<vector<int>>("shape", attrs);
if (HasAttr("inplace", attrs)) { if (HasAttr("inplace", attrs)) {
...@@ -1437,11 +1488,12 @@ class Reshape2Param : public OpParam { ...@@ -1437,11 +1488,12 @@ class Reshape2Param : public OpParam {
public: public:
Reshape2Param(const VariableNameMap &inputs, const VariableNameMap &outputs, Reshape2Param(const VariableNameMap &inputs, const VariableNameMap &outputs,
const AttributeMap &attrs, const Scope &scope) { const AttributeMap &attrs, Scope *scope)
input_x_ = InputXFrom<GType>(inputs, scope); : OpParam(inputs, outputs, attrs, scope) {
input_shape_ = InputShapeFrom<GType>(inputs, scope); input_x_ = InputXFrom<GType>(inputs, *scope);
out_ = OutFrom<GType>(outputs, scope); input_shape_ = InputShapeFrom<GType>(inputs, *scope);
output_xshape_ = OutputXShapeFrom<GType>(outputs, scope); out_ = OutFrom<GType>(outputs, *scope);
output_xshape_ = OutputXShapeFrom<GType>(outputs, *scope);
shape_ = GetAttr<vector<int>>("shape", attrs); shape_ = GetAttr<vector<int>>("shape", attrs);
if (HasAttr("inplace", attrs)) { if (HasAttr("inplace", attrs)) {
inplace_ = GetAttr<bool>("inplace", attrs); inplace_ = GetAttr<bool>("inplace", attrs);
...@@ -1480,10 +1532,11 @@ class ScaleParam : public OpParam { ...@@ -1480,10 +1532,11 @@ class ScaleParam : public OpParam {
public: public:
ScaleParam(const VariableNameMap &inputs, const VariableNameMap &outputs, ScaleParam(const VariableNameMap &inputs, const VariableNameMap &outputs,
const AttributeMap &attrs, const Scope &scope) { const AttributeMap &attrs, Scope *scope)
input_x_ = InputXFrom<GType>(inputs, scope); : OpParam(inputs, outputs, attrs, scope) {
input_bias_ = InputBiasFrom<GType>(inputs, scope); input_x_ = InputXFrom<GType>(inputs, *scope);
out_ = OutFrom<GType>(outputs, scope); input_bias_ = InputBiasFrom<GType>(inputs, *scope);
out_ = OutFrom<GType>(outputs, *scope);
inplace_ = GetAttr<bool>("inplace", attrs); inplace_ = GetAttr<bool>("inplace", attrs);
has_bias_ = GetAttr<bool>("has_bias", attrs); has_bias_ = GetAttr<bool>("has_bias", attrs);
scales_ = GetAttr<vector<float>>("scales", attrs); scales_ = GetAttr<vector<float>>("scales", attrs);
...@@ -1523,9 +1576,10 @@ class SliceParam : public OpParam { ...@@ -1523,9 +1576,10 @@ class SliceParam : public OpParam {
public: public:
SliceParam(const VariableNameMap &inputs, const VariableNameMap &outputs, SliceParam(const VariableNameMap &inputs, const VariableNameMap &outputs,
const AttributeMap &attrs, const Scope &scope) { const AttributeMap &attrs, Scope *scope)
input_ = InputFrom<GType>(inputs, scope); : OpParam(inputs, outputs, attrs, scope) {
output_ = OutFrom<GType>(outputs, scope); input_ = InputFrom<GType>(inputs, *scope);
output_ = OutFrom<GType>(outputs, *scope);
axes_ = GetAttr<std::vector<int>>("axes", attrs); axes_ = GetAttr<std::vector<int>>("axes", attrs);
starts_ = GetAttr<std::vector<int>>("starts", attrs); starts_ = GetAttr<std::vector<int>>("starts", attrs);
...@@ -1549,10 +1603,11 @@ class ResizeParam : public OpParam { ...@@ -1549,10 +1603,11 @@ class ResizeParam : public OpParam {
public: public:
ResizeParam(const VariableNameMap &inputs, const VariableNameMap &outputs, ResizeParam(const VariableNameMap &inputs, const VariableNameMap &outputs,
const AttributeMap &attrs, const Scope &scope) { const AttributeMap &attrs, Scope *scope)
input_x_ = InputXFrom<GType>(inputs, scope); : OpParam(inputs, outputs, attrs, scope) {
input_shape_ = InputShapeFrom<GType>(inputs, scope); input_x_ = InputXFrom<GType>(inputs, *scope);
out_ = OutFrom<GType>(outputs, scope); input_shape_ = InputShapeFrom<GType>(inputs, *scope);
out_ = OutFrom<GType>(outputs, *scope);
is_pyramid_test_ = GetAttr<bool>("is_pyramid_test", attrs); is_pyramid_test_ = GetAttr<bool>("is_pyramid_test", attrs);
height_ = GetAttr<int>("height", attrs); height_ = GetAttr<int>("height", attrs);
width_ = GetAttr<int>("width", attrs); width_ = GetAttr<int>("width", attrs);
...@@ -1599,9 +1654,10 @@ class ReluParamBase : public OpParam { ...@@ -1599,9 +1654,10 @@ class ReluParamBase : public OpParam {
public: public:
ReluParamBase(const VariableNameMap &inputs, const VariableNameMap &outputs, ReluParamBase(const VariableNameMap &inputs, const VariableNameMap &outputs,
const AttributeMap &attrs, const Scope &scope) { const AttributeMap &attrs, Scope *scope)
input_x_ = InputXFrom<GType>(inputs, scope); : OpParam(inputs, outputs, attrs, scope) {
out_ = OutFrom<GType>(outputs, scope); input_x_ = InputXFrom<GType>(inputs, *scope);
out_ = OutFrom<GType>(outputs, *scope);
} }
const RType *InputX() const { return input_x_; } const RType *InputX() const { return input_x_; }
...@@ -1641,9 +1697,10 @@ class TanhParam : public OpParam { ...@@ -1641,9 +1697,10 @@ class TanhParam : public OpParam {
public: public:
TanhParam(const VariableNameMap &inputs, const VariableNameMap &outputs, TanhParam(const VariableNameMap &inputs, const VariableNameMap &outputs,
const AttributeMap &attrs, const Scope &scope) { const AttributeMap &attrs, Scope *scope)
input_x_ = InputXFrom<GType>(inputs, scope); : OpParam(inputs, outputs, attrs, scope) {
out_ = OutFrom<GType>(outputs, scope); input_x_ = InputXFrom<GType>(inputs, *scope);
out_ = OutFrom<GType>(outputs, *scope);
} }
const RType *InputX() const { return input_x_; } const RType *InputX() const { return input_x_; }
RType *Out() const { return out_; } RType *Out() const { return out_; }
...@@ -1676,12 +1733,13 @@ class PReluParam : public OpParam { ...@@ -1676,12 +1733,13 @@ class PReluParam : public OpParam {
public: public:
PReluParam(const VariableNameMap &inputs, const VariableNameMap &outputs, PReluParam(const VariableNameMap &inputs, const VariableNameMap &outputs,
const AttributeMap &attrs, const Scope &scope) { const AttributeMap &attrs, Scope *scope)
: OpParam(inputs, outputs, attrs, scope) {
DLOG << "PReluParam inputs before"; DLOG << "PReluParam inputs before";
input_x_ = InputXFrom<GType>(inputs, scope); input_x_ = InputXFrom<GType>(inputs, *scope);
alpha_ = InputAlphaFrom<GType>(inputs, scope); alpha_ = InputAlphaFrom<GType>(inputs, *scope);
framework::DDim dims = alpha_->dims(); framework::DDim dims = alpha_->dims();
out_ = OutFrom<GType>(outputs, scope); out_ = OutFrom<GType>(outputs, *scope);
mode_ = GetStringAttr("mode", attrs); mode_ = GetStringAttr("mode", attrs);
DLOG << "PReluParam mode after" << mode_; DLOG << "PReluParam mode after" << mode_;
} }
...@@ -1705,11 +1763,12 @@ class FusionFcParam : public OpParam { ...@@ -1705,11 +1763,12 @@ class FusionFcParam : public OpParam {
public: public:
FusionFcParam(const VariableNameMap &inputs, const VariableNameMap &outputs, FusionFcParam(const VariableNameMap &inputs, const VariableNameMap &outputs,
const AttributeMap &attrs, const Scope &scope) { const AttributeMap &attrs, Scope *scope)
input_x_ = InputXFrom<GType>(inputs, scope); : OpParam(inputs, outputs, attrs, scope) {
input_y_ = InputYFrom<GType>(inputs, scope); input_x_ = InputXFrom<GType>(inputs, *scope);
input_z_ = InputZFrom<GType>(inputs, scope); input_y_ = InputYFrom<GType>(inputs, *scope);
out_ = OutFrom<GType>(outputs, scope); input_z_ = InputZFrom<GType>(inputs, *scope);
out_ = OutFrom<GType>(outputs, *scope);
x_num_col_dims_ = GetAttr<int>("x_num_col_dims", attrs); x_num_col_dims_ = GetAttr<int>("x_num_col_dims", attrs);
y_num_col_dims_ = GetAttr<int>("y_num_col_dims", attrs); y_num_col_dims_ = GetAttr<int>("y_num_col_dims", attrs);
axis_ = GetAttr<int>("axis", attrs); axis_ = GetAttr<int>("axis", attrs);
...@@ -1760,11 +1819,11 @@ class FusionConvAddParam : public ConvParam<Dtype> { ...@@ -1760,11 +1819,11 @@ class FusionConvAddParam : public ConvParam<Dtype> {
public: public:
FusionConvAddParam(const VariableNameMap &inputs, FusionConvAddParam(const VariableNameMap &inputs,
const VariableNameMap &outputs, const AttributeMap &attrs, const VariableNameMap &outputs, const AttributeMap &attrs,
const Scope &scope) Scope *scope)
: ConvParam<Dtype>(inputs, outputs, attrs, scope) { : ConvParam<Dtype>(inputs, outputs, attrs, scope) {
bias_ = OpParam::InputYFrom<GType>(inputs, scope); bias_ = OpParam::InputYFrom<GType>(inputs, *scope);
axis_ = OpParam::GetAttr<int>("axis", attrs); axis_ = OpParam::GetAttr<int>("axis", attrs);
output_ = OpParam::OutFrom<GType>(outputs, scope); output_ = OpParam::OutFrom<GType>(outputs, *scope);
} }
RType *Bias() const { return bias_; } RType *Bias() const { return bias_; }
...@@ -1787,7 +1846,7 @@ class FusionConvAddReluParam : public FusionConvAddParam<DeviceType> { ...@@ -1787,7 +1846,7 @@ class FusionConvAddReluParam : public FusionConvAddParam<DeviceType> {
public: public:
FusionConvAddReluParam(const VariableNameMap &inputs, FusionConvAddReluParam(const VariableNameMap &inputs,
const VariableNameMap &outputs, const VariableNameMap &outputs,
const AttributeMap &attrs, const Scope &scope) const AttributeMap &attrs, Scope *scope)
: FusionConvAddParam<DeviceType>(inputs, outputs, attrs, scope) {} : FusionConvAddParam<DeviceType>(inputs, outputs, attrs, scope) {}
}; };
#endif #endif
...@@ -1801,14 +1860,14 @@ class FusionConvAddPReluParam : public ConvParam<Dtype> { ...@@ -1801,14 +1860,14 @@ class FusionConvAddPReluParam : public ConvParam<Dtype> {
public: public:
FusionConvAddPReluParam(const VariableNameMap &inputs, FusionConvAddPReluParam(const VariableNameMap &inputs,
const VariableNameMap &outputs, const VariableNameMap &outputs,
const AttributeMap &attrs, const Scope &scope) const AttributeMap &attrs, Scope *scope)
: ConvParam<Dtype>(inputs, outputs, attrs, scope) { : ConvParam<Dtype>(inputs, outputs, attrs, scope) {
alpha_ = OpParam::InputAlphaFrom<GType>(inputs, scope); alpha_ = OpParam::InputAlphaFrom<GType>(inputs, *scope);
mode_ = OpParam::GetStringAttr("mode", attrs); mode_ = OpParam::GetStringAttr("mode", attrs);
framework::DDim dims = alpha_->dims(); framework::DDim dims = alpha_->dims();
bias_ = OpParam::InputYFrom<GType>(inputs, scope); bias_ = OpParam::InputYFrom<GType>(inputs, *scope);
axis_ = OpParam::GetAttr<int>("axis", attrs); axis_ = OpParam::GetAttr<int>("axis", attrs);
output_ = OpParam::OutFrom<GType>(outputs, scope); output_ = OpParam::OutFrom<GType>(outputs, *scope);
} }
const RType *InputAlpha() const { return alpha_; } const RType *InputAlpha() const { return alpha_; }
const std::string &Mode() const { return mode_; } const std::string &Mode() const { return mode_; }
...@@ -1834,22 +1893,22 @@ class FusionConvAddAddPReluParam : public ConvParam<Dtype> { ...@@ -1834,22 +1893,22 @@ class FusionConvAddAddPReluParam : public ConvParam<Dtype> {
public: public:
FusionConvAddAddPReluParam(const VariableNameMap &inputs, FusionConvAddAddPReluParam(const VariableNameMap &inputs,
const VariableNameMap &outputs, const VariableNameMap &outputs,
const AttributeMap &attrs, const Scope &scope) const AttributeMap &attrs, Scope *scope)
: ConvParam<Dtype>(inputs, outputs, attrs, scope) { : ConvParam<Dtype>(inputs, outputs, attrs, scope) {
bias1_ = OpParam::InputYFrom1<GType>(inputs, scope); bias1_ = OpParam::InputYFrom1<GType>(inputs, *scope);
alpha_ = OpParam::InputAlphaFrom<GType>(inputs, scope); alpha_ = OpParam::InputAlphaFrom<GType>(inputs, *scope);
mode_ = OpParam::GetStringAttr("mode", attrs); mode_ = OpParam::GetStringAttr("mode", attrs);
framework::DDim dims = alpha_->dims(); framework::DDim dims = alpha_->dims();
bias_ = OpParam::InputYFrom<GType>(inputs, scope); bias_ = OpParam::InputYFrom<GType>(inputs, *scope);
output_ = OpParam::OutFrom<GType>(outputs, scope); output_ = OpParam::OutFrom<GType>(outputs, *scope);
axis_ = OpParam::GetAttr<int>("axis", attrs); axis_ = OpParam::GetAttr<int>("axis", attrs);
keyOutput_ = OpParam::getkey("addOut", inputs, 0); keyOutput_ = OpParam::Getkey("addOut", inputs, 0);
keyX1_ = OpParam::getkey("addX", inputs, 1); keyX1_ = OpParam::Getkey("addX", inputs, 1);
keyY1_ = OpParam::getkey("Y", inputs, 1); keyY1_ = OpParam::Getkey("Y", inputs, 1);
if (keyX1_ == keyOutput_) { if (keyX1_ == keyOutput_) {
bias1_ = OpParam::InputYFrom1<GType>(inputs, scope); bias1_ = OpParam::InputYFrom1<GType>(inputs, *scope);
} else if (keyY1_ == keyOutput_) { } else if (keyY1_ == keyOutput_) {
bias1_ = OpParam::InputXFrom1<GType>(inputs, scope); bias1_ = OpParam::InputXFrom1<GType>(inputs, *scope);
} }
} }
const RType *InputAlpha() const { return alpha_; } const RType *InputAlpha() const { return alpha_; }
...@@ -1883,15 +1942,15 @@ class FusionConvAddBNReluParam : public ConvParam<Dtype> { ...@@ -1883,15 +1942,15 @@ class FusionConvAddBNReluParam : public ConvParam<Dtype> {
public: public:
FusionConvAddBNReluParam(const VariableNameMap &inputs, FusionConvAddBNReluParam(const VariableNameMap &inputs,
const VariableNameMap &outputs, const VariableNameMap &outputs,
const AttributeMap &attrs, const Scope &scope) const AttributeMap &attrs, Scope *scope)
: ConvParam<Dtype>(inputs, outputs, attrs, scope) { : ConvParam<Dtype>(inputs, outputs, attrs, scope) {
bias_ = OpParam::InputYFrom<GType>(inputs, scope); bias_ = OpParam::InputYFrom<GType>(inputs, *scope);
axis_ = OpParam::GetAttr<int>("axis", attrs); axis_ = OpParam::GetAttr<int>("axis", attrs);
output_ = OpParam::OutFrom<GType>(outputs, scope); output_ = OpParam::OutFrom<GType>(outputs, *scope);
input_bias_ = OpParam::InputBiasFrom<GType>(inputs, scope); input_bias_ = OpParam::InputBiasFrom<GType>(inputs, *scope);
input_mean_ = OpParam::InputMeanFrom<GType>(inputs, scope); input_mean_ = OpParam::InputMeanFrom<GType>(inputs, *scope);
input_scale_ = OpParam::InputScaleFrom<GType>(inputs, scope); input_scale_ = OpParam::InputScaleFrom<GType>(inputs, *scope);
input_variance_ = OpParam::InputVarianceFrom<GType>(inputs, scope); input_variance_ = OpParam::InputVarianceFrom<GType>(inputs, *scope);
epsilon_ = OpParam::GetAttr<float>("epsilon", attrs); epsilon_ = OpParam::GetAttr<float>("epsilon", attrs);
momentum_ = OpParam::GetAttr<float>("momentum", attrs); momentum_ = OpParam::GetAttr<float>("momentum", attrs);
// is_test_ = OpParam::GetAttr<bool>("is_test", attrs); // is_test_ = OpParam::GetAttr<bool>("is_test", attrs);
...@@ -1949,24 +2008,24 @@ class FusionConvBNAddReluParam : public ConvParam<Dtype> { ...@@ -1949,24 +2008,24 @@ class FusionConvBNAddReluParam : public ConvParam<Dtype> {
public: public:
FusionConvBNAddReluParam(const VariableNameMap &inputs, FusionConvBNAddReluParam(const VariableNameMap &inputs,
const VariableNameMap &outputs, const VariableNameMap &outputs,
const AttributeMap &attrs, const Scope &scope) const AttributeMap &attrs, Scope *scope)
: ConvParam<Dtype>(inputs, outputs, attrs, scope) { : ConvParam<Dtype>(inputs, outputs, attrs, scope) {
bias_ = OpParam::InputYFrom<GType>(inputs, scope); bias_ = OpParam::InputYFrom<GType>(inputs, *scope);
axis_ = OpParam::GetAttr<int>("axis", attrs); axis_ = OpParam::GetAttr<int>("axis", attrs);
output_ = OpParam::OutFrom<GType>(outputs, scope); output_ = OpParam::OutFrom<GType>(outputs, *scope);
input_bias_ = OpParam::InputBiasFrom<GType>(inputs, scope); input_bias_ = OpParam::InputBiasFrom<GType>(inputs, *scope);
input_mean_ = OpParam::InputMeanFrom<GType>(inputs, scope); input_mean_ = OpParam::InputMeanFrom<GType>(inputs, *scope);
input_scale_ = OpParam::InputScaleFrom<GType>(inputs, scope); input_scale_ = OpParam::InputScaleFrom<GType>(inputs, *scope);
input_variance_ = OpParam::InputVarianceFrom<GType>(inputs, scope); input_variance_ = OpParam::InputVarianceFrom<GType>(inputs, *scope);
epsilon_ = OpParam::GetAttr<float>("epsilon", attrs); epsilon_ = OpParam::GetAttr<float>("epsilon", attrs);
momentum_ = OpParam::GetAttr<float>("momentum", attrs); momentum_ = OpParam::GetAttr<float>("momentum", attrs);
keyBNY_ = OpParam::getkey("BNY", inputs, 0); keyBNY_ = OpParam::Getkey("BNY", inputs, 0);
keyX_ = OpParam::getkey("X", inputs, 0); keyX_ = OpParam::Getkey("X", inputs, 0);
keyY_ = OpParam::getkey("Y", inputs, 0); keyY_ = OpParam::Getkey("Y", inputs, 0);
if (keyX_ == keyBNY_) { if (keyX_ == keyBNY_) {
bias_ = OpParam::InputYFrom<GType>(inputs, scope); bias_ = OpParam::InputYFrom<GType>(inputs, *scope);
} else if (keyY_ == keyBNY_) { } else if (keyY_ == keyBNY_) {
bias_ = OpParam::InputXFrom<GType>(inputs, scope); bias_ = OpParam::InputXFrom<GType>(inputs, *scope);
} }
// is_test_ = OpParam::GetAttr<bool>("is_test", attrs); // is_test_ = OpParam::GetAttr<bool>("is_test", attrs);
} }
...@@ -2026,13 +2085,13 @@ class FusionConvBNParam : public ConvParam<Dtype> { ...@@ -2026,13 +2085,13 @@ class FusionConvBNParam : public ConvParam<Dtype> {
public: public:
FusionConvBNParam(const VariableNameMap &inputs, FusionConvBNParam(const VariableNameMap &inputs,
const VariableNameMap &outputs, const AttributeMap &attrs, const VariableNameMap &outputs, const AttributeMap &attrs,
const Scope &scope) Scope *scope)
: ConvParam<Dtype>(inputs, outputs, attrs, scope) { : ConvParam<Dtype>(inputs, outputs, attrs, scope) {
output_y_ = OpParam::OutputYFrom<GType>(outputs, scope); output_y_ = OpParam::OutputYFrom<GType>(outputs, *scope);
input_bias_ = OpParam::InputBiasFrom<GType>(inputs, scope); input_bias_ = OpParam::InputBiasFrom<GType>(inputs, *scope);
input_mean_ = OpParam::InputMeanFrom<GType>(inputs, scope); input_mean_ = OpParam::InputMeanFrom<GType>(inputs, *scope);
input_scale_ = OpParam::InputScaleFrom<GType>(inputs, scope); input_scale_ = OpParam::InputScaleFrom<GType>(inputs, *scope);
input_variance_ = OpParam::InputVarianceFrom<GType>(inputs, scope); input_variance_ = OpParam::InputVarianceFrom<GType>(inputs, *scope);
epsilon_ = OpParam::GetAttr<float>("epsilon", attrs); epsilon_ = OpParam::GetAttr<float>("epsilon", attrs);
momentum_ = OpParam::GetAttr<float>("momentum", attrs); momentum_ = OpParam::GetAttr<float>("momentum", attrs);
// is_test_ = OpParam::GetAttr<bool>("is_test", attrs); // is_test_ = OpParam::GetAttr<bool>("is_test", attrs);
...@@ -2084,15 +2143,15 @@ class FusionConvAddBNParam : public ConvParam<Dtype> { ...@@ -2084,15 +2143,15 @@ class FusionConvAddBNParam : public ConvParam<Dtype> {
public: public:
FusionConvAddBNParam(const VariableNameMap &inputs, FusionConvAddBNParam(const VariableNameMap &inputs,
const VariableNameMap &outputs, const VariableNameMap &outputs,
const AttributeMap &attrs, const Scope &scope) const AttributeMap &attrs, Scope *scope)
: ConvParam<Dtype>(inputs, outputs, attrs, scope) { : ConvParam<Dtype>(inputs, outputs, attrs, scope) {
bias_ = OpParam::InputYFrom<GType>(inputs, scope); bias_ = OpParam::InputYFrom<GType>(inputs, *scope);
axis_ = OpParam::GetAttr<int>("axis", attrs); axis_ = OpParam::GetAttr<int>("axis", attrs);
output_y_ = OpParam::OutputYFrom<GType>(outputs, scope); output_y_ = OpParam::OutputYFrom<GType>(outputs, *scope);
input_bias_ = OpParam::InputBiasFrom<GType>(inputs, scope); input_bias_ = OpParam::InputBiasFrom<GType>(inputs, *scope);
input_mean_ = OpParam::InputMeanFrom<GType>(inputs, scope); input_mean_ = OpParam::InputMeanFrom<GType>(inputs, *scope);
input_scale_ = OpParam::InputScaleFrom<GType>(inputs, scope); input_scale_ = OpParam::InputScaleFrom<GType>(inputs, *scope);
input_variance_ = OpParam::InputVarianceFrom<GType>(inputs, scope); input_variance_ = OpParam::InputVarianceFrom<GType>(inputs, *scope);
epsilon_ = OpParam::GetAttr<float>("epsilon", attrs); epsilon_ = OpParam::GetAttr<float>("epsilon", attrs);
momentum_ = OpParam::GetAttr<float>("momentum", attrs); momentum_ = OpParam::GetAttr<float>("momentum", attrs);
// is_test_ = OpParam::GetAttr<bool>("is_test", attrs); // is_test_ = OpParam::GetAttr<bool>("is_test", attrs);
...@@ -2150,13 +2209,13 @@ class FusionDWConvBNReluParam : public ConvParam<Dtype> { ...@@ -2150,13 +2209,13 @@ class FusionDWConvBNReluParam : public ConvParam<Dtype> {
public: public:
FusionDWConvBNReluParam(const VariableNameMap &inputs, FusionDWConvBNReluParam(const VariableNameMap &inputs,
const VariableNameMap &outputs, const VariableNameMap &outputs,
const AttributeMap &attrs, const Scope &scope) const AttributeMap &attrs, Scope *scope)
: ConvParam<Dtype>(inputs, outputs, attrs, scope) { : ConvParam<Dtype>(inputs, outputs, attrs, scope) {
output_ = OpParam::OutFrom<GType>(outputs, scope); output_ = OpParam::OutFrom<GType>(outputs, *scope);
input_bias_ = OpParam::InputBiasFrom<GType>(inputs, scope); input_bias_ = OpParam::InputBiasFrom<GType>(inputs, *scope);
input_mean_ = OpParam::InputMeanFrom<GType>(inputs, scope); input_mean_ = OpParam::InputMeanFrom<GType>(inputs, *scope);
input_scale_ = OpParam::InputScaleFrom<GType>(inputs, scope); input_scale_ = OpParam::InputScaleFrom<GType>(inputs, *scope);
input_variance_ = OpParam::InputVarianceFrom<GType>(inputs, scope); input_variance_ = OpParam::InputVarianceFrom<GType>(inputs, *scope);
epsilon_ = OpParam::GetAttr<float>("epsilon", attrs); epsilon_ = OpParam::GetAttr<float>("epsilon", attrs);
momentum_ = OpParam::GetAttr<float>("momentum", attrs); momentum_ = OpParam::GetAttr<float>("momentum", attrs);
// is_test_ = OpParam::GetAttr<bool>("is_test", attrs); // is_test_ = OpParam::GetAttr<bool>("is_test", attrs);
...@@ -2209,13 +2268,13 @@ class FusionConvBNReluParam : public ConvParam<Dtype> { ...@@ -2209,13 +2268,13 @@ class FusionConvBNReluParam : public ConvParam<Dtype> {
public: public:
FusionConvBNReluParam(const VariableNameMap &inputs, FusionConvBNReluParam(const VariableNameMap &inputs,
const VariableNameMap &outputs, const VariableNameMap &outputs,
const AttributeMap &attrs, const Scope &scope) const AttributeMap &attrs, Scope *scope)
: ConvParam<Dtype>(inputs, outputs, attrs, scope) { : ConvParam<Dtype>(inputs, outputs, attrs, scope) {
output_ = OpParam::OutFrom<GType>(outputs, scope); output_ = OpParam::OutFrom<GType>(outputs, *scope);
input_bias_ = OpParam::InputBiasFrom<GType>(inputs, scope); input_bias_ = OpParam::InputBiasFrom<GType>(inputs, *scope);
input_mean_ = OpParam::InputMeanFrom<GType>(inputs, scope); input_mean_ = OpParam::InputMeanFrom<GType>(inputs, *scope);
input_scale_ = OpParam::InputScaleFrom<GType>(inputs, scope); input_scale_ = OpParam::InputScaleFrom<GType>(inputs, *scope);
input_variance_ = OpParam::InputVarianceFrom<GType>(inputs, scope); input_variance_ = OpParam::InputVarianceFrom<GType>(inputs, *scope);
epsilon_ = OpParam::GetAttr<float>("epsilon", attrs); epsilon_ = OpParam::GetAttr<float>("epsilon", attrs);
momentum_ = OpParam::GetAttr<float>("momentum", attrs); momentum_ = OpParam::GetAttr<float>("momentum", attrs);
// is_test_ = OpParam::GetAttr<bool>("is_test", attrs); // is_test_ = OpParam::GetAttr<bool>("is_test", attrs);
...@@ -2267,9 +2326,10 @@ class Im2SequenceParam : public OpParam { ...@@ -2267,9 +2326,10 @@ class Im2SequenceParam : public OpParam {
public: public:
Im2SequenceParam(const VariableNameMap &inputs, Im2SequenceParam(const VariableNameMap &inputs,
const VariableNameMap &outputs, const AttributeMap &attrs, const VariableNameMap &outputs, const AttributeMap &attrs,
const Scope &scope) { Scope *scope)
input_x_ = InputXFrom<GType>(inputs, scope); : OpParam(inputs, outputs, attrs, scope) {
out_ = OutFrom<GType>(outputs, scope); input_x_ = InputXFrom<GType>(inputs, *scope);
out_ = OutFrom<GType>(outputs, *scope);
kernels_ = GetAttr<vector<int>>("kernels", attrs); kernels_ = GetAttr<vector<int>>("kernels", attrs);
strides_ = GetAttr<vector<int>>("strides", attrs); strides_ = GetAttr<vector<int>>("strides", attrs);
paddings_ = GetAttr<vector<int>>("paddings", attrs); paddings_ = GetAttr<vector<int>>("paddings", attrs);
...@@ -2302,9 +2362,10 @@ class DropoutParam : public OpParam { ...@@ -2302,9 +2362,10 @@ class DropoutParam : public OpParam {
public: public:
DropoutParam(const VariableNameMap &inputs, const VariableNameMap &outputs, DropoutParam(const VariableNameMap &inputs, const VariableNameMap &outputs,
const AttributeMap &attrs, const Scope &scope) { const AttributeMap &attrs, Scope *scope)
input_x_ = InputXFrom<GType>(inputs, scope); : OpParam(inputs, outputs, attrs, scope) {
out_ = OutFrom<GType>(outputs, scope); input_x_ = InputXFrom<GType>(inputs, *scope);
out_ = OutFrom<GType>(outputs, *scope);
dropout_prob_ = GetAttr<float>("dropout_prob", attrs); dropout_prob_ = GetAttr<float>("dropout_prob", attrs);
} }
...@@ -2330,12 +2391,13 @@ class ConvTransposeParam : public OpParam { ...@@ -2330,12 +2391,13 @@ class ConvTransposeParam : public OpParam {
public: public:
ConvTransposeParam(const VariableNameMap &inputs, ConvTransposeParam(const VariableNameMap &inputs,
const VariableNameMap &outputs, const AttributeMap &attrs, const VariableNameMap &outputs, const AttributeMap &attrs,
const Scope &scope) { Scope *scope)
filter_ = FilterFrom<GType>(inputs, scope); : OpParam(inputs, outputs, attrs, scope) {
input_ = InputFrom<GType>(inputs, scope); filter_ = FilterFrom<GType>(inputs, *scope);
input_ = InputFrom<GType>(inputs, *scope);
// output_ = OutputFrom<GType>(outputs, scope); // output_ = OutputFrom<GType>(outputs, scope);
if (outputs.count("Output")) { if (outputs.count("Output")) {
output_ = OpParam::OutputFrom<GType>(outputs, scope); output_ = OpParam::OutputFrom<GType>(outputs, *scope);
} }
strides_ = GetAttr<vector<int>>("strides", attrs); strides_ = GetAttr<vector<int>>("strides", attrs);
paddings_ = GetAttr<vector<int>>("paddings", attrs); paddings_ = GetAttr<vector<int>>("paddings", attrs);
...@@ -2393,11 +2455,11 @@ class FusionDeconvAddParam : public ConvTransposeParam<Dtype> { ...@@ -2393,11 +2455,11 @@ class FusionDeconvAddParam : public ConvTransposeParam<Dtype> {
public: public:
FusionDeconvAddParam(const VariableNameMap &inputs, FusionDeconvAddParam(const VariableNameMap &inputs,
const VariableNameMap &outputs, const VariableNameMap &outputs,
const AttributeMap &attrs, const Scope &scope) const AttributeMap &attrs, Scope *scope)
: ConvTransposeParam<Dtype>(inputs, outputs, attrs, scope) { : ConvTransposeParam<Dtype>(inputs, outputs, attrs, scope) {
bias_ = OpParam::InputYFrom<GType>(inputs, scope); bias_ = OpParam::InputYFrom<GType>(inputs, *scope);
axis_ = OpParam::GetAttr<int>("axis", attrs); axis_ = OpParam::GetAttr<int>("axis", attrs);
output_ = OpParam::OutFrom<GType>(outputs, scope); output_ = OpParam::OutFrom<GType>(outputs, *scope);
} }
RType *Bias() const { return bias_; } RType *Bias() const { return bias_; }
...@@ -2425,13 +2487,13 @@ class FusionDeconvAddBNParam : public ConvTransposeParam<Dtype> { ...@@ -2425,13 +2487,13 @@ class FusionDeconvAddBNParam : public ConvTransposeParam<Dtype> {
public: public:
FusionDeconvAddBNParam(const VariableNameMap &inputs, FusionDeconvAddBNParam(const VariableNameMap &inputs,
const VariableNameMap &outputs, const VariableNameMap &outputs,
const AttributeMap &attrs, const Scope &scope) const AttributeMap &attrs, Scope *scope)
: ConvTransposeParam<Dtype>(inputs, outputs, attrs, scope) { : ConvTransposeParam<Dtype>(inputs, outputs, attrs, scope) {
output_ = OpParam::OutFrom<GType>(outputs, scope); output_ = OpParam::OutFrom<GType>(outputs, *scope);
input_bias_ = OpParam::InputBiasFrom<GType>(inputs, scope); input_bias_ = OpParam::InputBiasFrom<GType>(inputs, *scope);
input_mean_ = OpParam::InputMeanFrom<GType>(inputs, scope); input_mean_ = OpParam::InputMeanFrom<GType>(inputs, *scope);
input_scale_ = OpParam::InputScaleFrom<GType>(inputs, scope); input_scale_ = OpParam::InputScaleFrom<GType>(inputs, *scope);
input_variance_ = OpParam::InputVarianceFrom<GType>(inputs, scope); input_variance_ = OpParam::InputVarianceFrom<GType>(inputs, *scope);
epsilon_ = OpParam::GetAttr<float>("epsilon", attrs); epsilon_ = OpParam::GetAttr<float>("epsilon", attrs);
momentum_ = OpParam::GetAttr<float>("momentum", attrs); momentum_ = OpParam::GetAttr<float>("momentum", attrs);
// is_test_ = OpParam::GetAttr<bool>("is_test", attrs); // is_test_ = OpParam::GetAttr<bool>("is_test", attrs);
...@@ -2482,13 +2544,13 @@ class FusionDeconvAddBNReluParam : public ConvTransposeParam<Dtype> { ...@@ -2482,13 +2544,13 @@ class FusionDeconvAddBNReluParam : public ConvTransposeParam<Dtype> {
public: public:
FusionDeconvAddBNReluParam(const VariableNameMap &inputs, FusionDeconvAddBNReluParam(const VariableNameMap &inputs,
const VariableNameMap &outputs, const VariableNameMap &outputs,
const AttributeMap &attrs, const Scope &scope) const AttributeMap &attrs, Scope *scope)
: ConvTransposeParam<Dtype>(inputs, outputs, attrs, scope) { : ConvTransposeParam<Dtype>(inputs, outputs, attrs, scope) {
output_ = OpParam::OutFrom<GType>(outputs, scope); output_ = OpParam::OutFrom<GType>(outputs, *scope);
input_bias_ = OpParam::InputBiasFrom<GType>(inputs, scope); input_bias_ = OpParam::InputBiasFrom<GType>(inputs, *scope);
input_mean_ = OpParam::InputMeanFrom<GType>(inputs, scope); input_mean_ = OpParam::InputMeanFrom<GType>(inputs, *scope);
input_scale_ = OpParam::InputScaleFrom<GType>(inputs, scope); input_scale_ = OpParam::InputScaleFrom<GType>(inputs, *scope);
input_variance_ = OpParam::InputVarianceFrom<GType>(inputs, scope); input_variance_ = OpParam::InputVarianceFrom<GType>(inputs, *scope);
epsilon_ = OpParam::GetAttr<float>("epsilon", attrs); epsilon_ = OpParam::GetAttr<float>("epsilon", attrs);
momentum_ = OpParam::GetAttr<float>("momentum", attrs); momentum_ = OpParam::GetAttr<float>("momentum", attrs);
// is_test_ = OpParam::GetAttr<bool>("is_test", attrs); // is_test_ = OpParam::GetAttr<bool>("is_test", attrs);
...@@ -2550,17 +2612,18 @@ class GruParam : public OpParam { ...@@ -2550,17 +2612,18 @@ class GruParam : public OpParam {
* @param scope * @param scope
* */ * */
GruParam(const VariableNameMap &inputs, const VariableNameMap &outputs, GruParam(const VariableNameMap &inputs, const VariableNameMap &outputs,
const AttributeMap &attrs, const Scope &scope) { const AttributeMap &attrs, Scope *scope)
input_input_ = InputFrom<GType>(inputs, scope); : OpParam(inputs, outputs, attrs, scope) {
input_h0_ = InputH0From<GType>(inputs, scope); input_input_ = InputFrom<GType>(inputs, *scope);
input_bias_ = InputBiasFrom<GType>(inputs, scope); input_h0_ = InputH0From<GType>(inputs, *scope);
input_weight_ = InputWeightFrom<GType>(inputs, scope); input_bias_ = InputBiasFrom<GType>(inputs, *scope);
input_weight_ = InputWeightFrom<GType>(inputs, *scope);
output_batch_gate_ = OutputBatchGateFrom<GType>(outputs, scope);
output_batch_gate_ = OutputBatchGateFrom<GType>(outputs, *scope);
output_batch_reset_hidden_prev_ = output_batch_reset_hidden_prev_ =
OutputBatchResetHiddenPrevFrom<GType>(outputs, scope); OutputBatchResetHiddenPrevFrom<GType>(outputs, *scope);
output_batch_hidden_ = OutputBatchHiddenFrom<GType>(outputs, scope); output_batch_hidden_ = OutputBatchHiddenFrom<GType>(outputs, *scope);
output_hidden_ = OutputHiddenFrom<GType>(outputs, scope); output_hidden_ = OutputHiddenFrom<GType>(outputs, *scope);
activation_ = GetStringAttr("activation", attrs); activation_ = GetStringAttr("activation", attrs);
gate_activation_ = GetStringAttr("gate_activation", attrs); gate_activation_ = GetStringAttr("gate_activation", attrs);
is_reverse_ = GetAttr<bool>("is_reverse", attrs); is_reverse_ = GetAttr<bool>("is_reverse", attrs);
...@@ -2603,16 +2666,17 @@ class GruUnitParam : public OpParam { ...@@ -2603,16 +2666,17 @@ class GruUnitParam : public OpParam {
public: public:
GruUnitParam(const VariableNameMap &inputs, const VariableNameMap &outputs, GruUnitParam(const VariableNameMap &inputs, const VariableNameMap &outputs,
const AttributeMap &attrs, const Scope &scope) { const AttributeMap &attrs, Scope *scope)
input_input_ = InputFrom<GType>(inputs, scope); : OpParam(inputs, outputs, attrs, scope) {
input_hidden_prev_ = InputHiddenPrevFrom<GType>(inputs, scope); input_input_ = InputFrom<GType>(inputs, *scope);
input_bias_ = InputBiasFrom<GType>(inputs, scope); input_hidden_prev_ = InputHiddenPrevFrom<GType>(inputs, *scope);
input_weight_ = InputWeightFrom<GType>(inputs, scope); input_bias_ = InputBiasFrom<GType>(inputs, *scope);
input_weight_ = InputWeightFrom<GType>(inputs, *scope);
output_gate_ = OutputGateFrom<GType>(outputs, scope);
output_gate_ = OutputGateFrom<GType>(outputs, *scope);
output_reset_hidden_prev_ = output_reset_hidden_prev_ =
OutputResetHiddenPrevFrom<GType>(outputs, scope); OutputResetHiddenPrevFrom<GType>(outputs, *scope);
output_hidden_ = OutputHiddenFrom<GType>(outputs, scope); output_hidden_ = OutputHiddenFrom<GType>(outputs, *scope);
activation_ = GetAttr<int>("activation", attrs); activation_ = GetAttr<int>("activation", attrs);
gate_activation_ = GetAttr<int>("gate_activation", attrs); gate_activation_ = GetAttr<int>("gate_activation", attrs);
} }
...@@ -2649,9 +2713,10 @@ class FlattenParam : public OpParam { ...@@ -2649,9 +2713,10 @@ class FlattenParam : public OpParam {
public: public:
FlattenParam(const VariableNameMap &inputs, const VariableNameMap &outputs, FlattenParam(const VariableNameMap &inputs, const VariableNameMap &outputs,
const AttributeMap &attrs, const Scope &scope) { const AttributeMap &attrs, Scope *scope)
input_x_ = InputXFrom<GType>(inputs, scope); : OpParam(inputs, outputs, attrs, scope) {
out_ = OutFrom<GType>(outputs, scope); input_x_ = InputXFrom<GType>(inputs, *scope);
out_ = OutFrom<GType>(outputs, *scope);
axis = GetAttr<int>("axis", attrs); axis = GetAttr<int>("axis", attrs);
} }
const RType *InputX() const { return input_x_; } const RType *InputX() const { return input_x_; }
...@@ -2673,9 +2738,10 @@ class SplitParam : public OpParam { ...@@ -2673,9 +2738,10 @@ class SplitParam : public OpParam {
public: public:
SplitParam(const VariableNameMap &inputs, const VariableNameMap &outputs, SplitParam(const VariableNameMap &inputs, const VariableNameMap &outputs,
const AttributeMap &attrs, const Scope &scope) { const AttributeMap &attrs, Scope *scope)
input_x_ = InputXFrom<GType>(inputs, scope); : OpParam(inputs, outputs, attrs, scope) {
outs_ = OutMultiFrom<GType>(outputs, scope); input_x_ = InputXFrom<GType>(inputs, *scope);
outs_ = OutMultiFrom<GType>(outputs, *scope);
axis = GetAttr<int>("axis", attrs); axis = GetAttr<int>("axis", attrs);
num = GetAttr<int>("num", attrs); num = GetAttr<int>("num", attrs);
sections = GetAttr<std::vector<int>>("sections", attrs); sections = GetAttr<std::vector<int>>("sections", attrs);
...@@ -2719,10 +2785,11 @@ class BilinearInterpParam : public OpParam { ...@@ -2719,10 +2785,11 @@ class BilinearInterpParam : public OpParam {
public: public:
BilinearInterpParam(const VariableNameMap &inputs, BilinearInterpParam(const VariableNameMap &inputs,
const VariableNameMap &outputs, const AttributeMap &attrs, const VariableNameMap &outputs, const AttributeMap &attrs,
const Scope &scope) { Scope *scope)
input_x_ = InputXFrom<GType>(inputs, scope); : OpParam(inputs, outputs, attrs, scope) {
input_outsize_ = InputOutSizeFrom<GType>(inputs, scope); input_x_ = InputXFrom<GType>(inputs, *scope);
out_ = OutFrom<GType>(outputs, scope); input_outsize_ = InputOutSizeFrom<GType>(inputs, *scope);
out_ = OutFrom<GType>(outputs, *scope);
out_h_ = GetAttr<int>("out_h", attrs); out_h_ = GetAttr<int>("out_h", attrs);
out_w_ = GetAttr<int>("out_w", attrs); out_w_ = GetAttr<int>("out_w", attrs);
} }
...@@ -2749,9 +2816,10 @@ class ShapeParam : public OpParam { ...@@ -2749,9 +2816,10 @@ class ShapeParam : public OpParam {
public: public:
ShapeParam(const VariableNameMap &inputs, const VariableNameMap &outputs, ShapeParam(const VariableNameMap &inputs, const VariableNameMap &outputs,
const AttributeMap &attrs, const Scope &scope) { const AttributeMap &attrs, Scope *scope)
input_ = InputFrom<GType>(inputs, scope); : OpParam(inputs, outputs, attrs, scope) {
out_ = OutFrom<GType>(outputs, scope); input_ = InputFrom<GType>(inputs, *scope);
out_ = OutFrom<GType>(outputs, *scope);
} }
const RType *Input() const { return input_; } const RType *Input() const { return input_; }
RType *Out() const { return out_; } RType *Out() const { return out_; }
...@@ -2770,10 +2838,11 @@ class TopKParam : public OpParam { ...@@ -2770,10 +2838,11 @@ class TopKParam : public OpParam {
public: public:
TopKParam(const VariableNameMap &inputs, const VariableNameMap &outputs, TopKParam(const VariableNameMap &inputs, const VariableNameMap &outputs,
const AttributeMap &attrs, const Scope &scope) { const AttributeMap &attrs, Scope *scope)
input_ = OpParam::GetVarValue<GType>("X", inputs, scope); : OpParam(inputs, outputs, attrs, scope) {
output_ = OpParam::GetVarValue<GType>("Out", outputs, scope); input_ = OpParam::GetVarValue<GType>("X", inputs, *scope);
indices_ = OpParam::GetVarValue<GType>("Indices", outputs, scope); output_ = OpParam::GetVarValue<GType>("Out", outputs, *scope);
indices_ = OpParam::GetVarValue<GType>("Indices", outputs, *scope);
k_ = OpParam::GetAttr<int>("k", attrs); k_ = OpParam::GetAttr<int>("k", attrs);
} }
...@@ -2793,9 +2862,10 @@ class CastParam : public OpParam { ...@@ -2793,9 +2862,10 @@ class CastParam : public OpParam {
public: public:
CastParam(const VariableNameMap &inputs, const VariableNameMap &outputs, CastParam(const VariableNameMap &inputs, const VariableNameMap &outputs,
const AttributeMap &attrs, const Scope &scope) { const AttributeMap &attrs, Scope *scope)
input_ = OpParam::GetVarValue<GType>("X", inputs, scope); : OpParam(inputs, outputs, attrs, scope) {
output_ = OpParam::GetVarValue<GType>("Out", outputs, scope); input_ = OpParam::GetVarValue<GType>("X", inputs, *scope);
output_ = OpParam::GetVarValue<GType>("Out", outputs, *scope);
input_type_ = OpParam::GetAttr<int>("in_dtype", attrs); input_type_ = OpParam::GetAttr<int>("in_dtype", attrs);
output_type_ = OpParam::GetAttr<int>("out_dtype", attrs); output_type_ = OpParam::GetAttr<int>("out_dtype", attrs);
} }
...@@ -2816,16 +2886,17 @@ class QuantizeParam : public OpParam { ...@@ -2816,16 +2886,17 @@ class QuantizeParam : public OpParam {
public: public:
QuantizeParam(const VariableNameMap &inputs, const VariableNameMap &outputs, QuantizeParam(const VariableNameMap &inputs, const VariableNameMap &outputs,
const AttributeMap &attrs, const Scope &scope) { const AttributeMap &attrs, Scope *scope)
input_ = InputXFrom<GType>(inputs, scope); : OpParam(inputs, outputs, attrs, scope) {
output_ = OutFrom<GType>(outputs, scope); input_ = InputXFrom<GType>(inputs, *scope);
output_ = OutFrom<GType>(outputs, *scope);
// online // online
// scale = max(abs(x)) // scale = max(abs(x))
online_scale_ = OpParam::GetVarValue<GType>("OutScale", outputs, scope); online_scale_ = OpParam::GetVarValue<GType>("OutScale", outputs, *scope);
// offline // offline
if (inputs.count("InScale")) { if (inputs.count("InScale")) {
offline_ = true; offline_ = true;
offline_scale_ = OpParam::GetVarValue<GType>("InScale", inputs, scope); offline_scale_ = OpParam::GetVarValue<GType>("InScale", inputs, *scope);
} }
// x = round(scale * x) // x = round(scale * x)
if (OpParam::HasAttr("round_type", attrs)) { if (OpParam::HasAttr("round_type", attrs)) {
...@@ -2857,10 +2928,11 @@ class DequantizeParam : public OpParam { ...@@ -2857,10 +2928,11 @@ class DequantizeParam : public OpParam {
public: public:
DequantizeParam(const VariableNameMap &inputs, const VariableNameMap &outputs, DequantizeParam(const VariableNameMap &inputs, const VariableNameMap &outputs,
const AttributeMap &attrs, const Scope &scope) { const AttributeMap &attrs, Scope *scope)
input_ = InputXFrom<GType>(inputs, scope); : OpParam(inputs, outputs, attrs, scope) {
output_ = OutFrom<GType>(outputs, scope); input_ = InputXFrom<GType>(inputs, *scope);
activation_scale_ = OpParam::GetVarValue<GType>("Scale", inputs, scope); output_ = OutFrom<GType>(outputs, *scope);
activation_scale_ = OpParam::GetVarValue<GType>("Scale", inputs, *scope);
// dequantization is performed as x = x / static_scale / online_scale // dequantization is performed as x = x / static_scale / online_scale
if (OpParam::HasAttr("weight_scale", attrs)) { if (OpParam::HasAttr("weight_scale", attrs)) {
weight_scale_ = OpParam::GetAttr<float>("weight_scale", attrs); weight_scale_ = OpParam::GetAttr<float>("weight_scale", attrs);
...@@ -2892,13 +2964,13 @@ class FusionDequantBNParam : public DequantizeParam<Dtype> { ...@@ -2892,13 +2964,13 @@ class FusionDequantBNParam : public DequantizeParam<Dtype> {
public: public:
FusionDequantBNParam(const VariableNameMap &inputs, FusionDequantBNParam(const VariableNameMap &inputs,
const VariableNameMap &outputs, const VariableNameMap &outputs,
const AttributeMap &attrs, const Scope &scope) const AttributeMap &attrs, Scope *scope)
: DequantizeParam<Dtype>(inputs, outputs, attrs, scope) { : DequantizeParam<Dtype>(inputs, outputs, attrs, scope) {
// batch norm params // batch norm params
bn_mean_ = OpParam::GetVarValue<GType>("BNMean", inputs, scope); bn_mean_ = OpParam::GetVarValue<GType>("BNMean", inputs, *scope);
bn_variance_ = OpParam::GetVarValue<GType>("BNVariance", inputs, scope); bn_variance_ = OpParam::GetVarValue<GType>("BNVariance", inputs, *scope);
bn_scale_ = OpParam::GetVarValue<GType>("BNScale", inputs, scope); bn_scale_ = OpParam::GetVarValue<GType>("BNScale", inputs, *scope);
bn_bias_ = OpParam::GetVarValue<GType>("BNBias", inputs, scope); bn_bias_ = OpParam::GetVarValue<GType>("BNBias", inputs, *scope);
epsilon_ = OpParam::GetAttr<float>("epsilon", attrs); epsilon_ = OpParam::GetAttr<float>("epsilon", attrs);
} }
...@@ -2924,11 +2996,11 @@ class FusionDequantAddBNParam : public FusionDequantBNParam<Dtype> { ...@@ -2924,11 +2996,11 @@ class FusionDequantAddBNParam : public FusionDequantBNParam<Dtype> {
public: public:
FusionDequantAddBNParam(const VariableNameMap &inputs, FusionDequantAddBNParam(const VariableNameMap &inputs,
const VariableNameMap &outputs, const VariableNameMap &outputs,
const AttributeMap &attrs, const Scope &scope) const AttributeMap &attrs, Scope *scope)
: FusionDequantBNParam<Dtype>(inputs, outputs, attrs, scope) { : FusionDequantBNParam<Dtype>(inputs, outputs, attrs, scope) {
// element wise add params // element wise add params
axis_ = OpParam::GetAttr<int>("axis", attrs); axis_ = OpParam::GetAttr<int>("axis", attrs);
bias_ = OpParam::InputYFrom<GType>(inputs, scope); bias_ = OpParam::InputYFrom<GType>(inputs, *scope);
} }
public: public:
...@@ -2947,14 +3019,14 @@ class FusionDequantAddBNQuantParam : public FusionDequantAddBNParam<Dtype> { ...@@ -2947,14 +3019,14 @@ class FusionDequantAddBNQuantParam : public FusionDequantAddBNParam<Dtype> {
public: public:
FusionDequantAddBNQuantParam(const VariableNameMap &inputs, FusionDequantAddBNQuantParam(const VariableNameMap &inputs,
const VariableNameMap &outputs, const VariableNameMap &outputs,
const AttributeMap &attrs, const Scope &scope) const AttributeMap &attrs, Scope *scope)
: FusionDequantAddBNParam<Dtype>(inputs, outputs, attrs, scope) { : FusionDequantAddBNParam<Dtype>(inputs, outputs, attrs, scope) {
// scale output // scale output
online_scale_ = OpParam::GetVarValue<GType>("OutScale", outputs, scope); online_scale_ = OpParam::GetVarValue<GType>("OutScale", outputs, *scope);
// offline // offline
if (inputs.count("InScale")) { if (inputs.count("InScale")) {
offline_ = true; offline_ = true;
offline_scale_ = OpParam::GetVarValue<GType>("InScale", inputs, scope); offline_scale_ = OpParam::GetVarValue<GType>("InScale", inputs, *scope);
} }
// x = round(scale * x) // x = round(scale * x)
if (OpParam::HasAttr("round_type", attrs)) { if (OpParam::HasAttr("round_type", attrs)) {
...@@ -2983,10 +3055,11 @@ class SequenceExpandParam : public OpParam { ...@@ -2983,10 +3055,11 @@ class SequenceExpandParam : public OpParam {
public: public:
SequenceExpandParam(const VariableNameMap &inputs, SequenceExpandParam(const VariableNameMap &inputs,
const VariableNameMap &outputs, const AttributeMap &attrs, const VariableNameMap &outputs, const AttributeMap &attrs,
const Scope &scope) { Scope *scope)
input_x_ = InputXFrom<GType>(inputs, scope); : OpParam(inputs, outputs, attrs, scope) {
input_y_ = InputYFrom<GType>(inputs, scope); input_x_ = InputXFrom<GType>(inputs, *scope);
output_ = OutFrom<GType>(outputs, scope); input_y_ = InputYFrom<GType>(inputs, *scope);
output_ = OutFrom<GType>(outputs, *scope);
ref_level_ = -1; ref_level_ = -1;
if (OpParam::HasAttr("ref_level", attrs)) { if (OpParam::HasAttr("ref_level", attrs)) {
ref_level_ = OpParam::GetAttr<int>("ref_level", attrs); ref_level_ = OpParam::GetAttr<int>("ref_level", attrs);
...@@ -3010,9 +3083,10 @@ class SequencePoolParam : public OpParam { ...@@ -3010,9 +3083,10 @@ class SequencePoolParam : public OpParam {
public: public:
SequencePoolParam(const VariableNameMap &inputs, SequencePoolParam(const VariableNameMap &inputs,
const VariableNameMap &outputs, const AttributeMap &attrs, const VariableNameMap &outputs, const AttributeMap &attrs,
const Scope &scope) { Scope *scope)
input_ = InputXFrom<GType>(inputs, scope); : OpParam(inputs, outputs, attrs, scope) {
output_ = OutFrom<GType>(outputs, scope); input_ = InputXFrom<GType>(inputs, *scope);
output_ = OutFrom<GType>(outputs, *scope);
pool_type_ = "MAX"; pool_type_ = "MAX";
if (OpParam::HasAttr("pooltype", attrs)) { if (OpParam::HasAttr("pooltype", attrs)) {
pool_type_ = OpParam::GetStringAttr("pooltype", attrs); pool_type_ = OpParam::GetStringAttr("pooltype", attrs);
...@@ -3034,12 +3108,13 @@ class LodResetParam : public OpParam { ...@@ -3034,12 +3108,13 @@ class LodResetParam : public OpParam {
public: public:
LodResetParam(const VariableNameMap &inputs, const VariableNameMap &outputs, LodResetParam(const VariableNameMap &inputs, const VariableNameMap &outputs,
const AttributeMap &attrs, const Scope &scope) { const AttributeMap &attrs, Scope *scope)
input_x_ = InputXFrom<GType>(inputs, scope); : OpParam(inputs, outputs, attrs, scope) {
output_ = OutFrom<GType>(outputs, scope); input_x_ = InputXFrom<GType>(inputs, *scope);
output_ = OutFrom<GType>(outputs, *scope);
input_y_ = nullptr; input_y_ = nullptr;
if (inputs.count("Y")) { if (inputs.count("Y")) {
input_y_ = InputYFrom<GType>(inputs, scope); input_y_ = InputYFrom<GType>(inputs, *scope);
} else { } else {
target_lod_ = OpParam::GetAttr<vector<int>>("target_lod", attrs); target_lod_ = OpParam::GetAttr<vector<int>>("target_lod", attrs);
} }
...@@ -3061,10 +3136,11 @@ class CompareParam : public OpParam { ...@@ -3061,10 +3136,11 @@ class CompareParam : public OpParam {
public: public:
CompareParam(const VariableNameMap &inputs, const VariableNameMap &outputs, CompareParam(const VariableNameMap &inputs, const VariableNameMap &outputs,
const AttributeMap &attrs, const Scope &scope) { const AttributeMap &attrs, Scope *scope)
input_x_ = InputXFrom<GType>(inputs, scope); : OpParam(inputs, outputs, attrs, scope) {
input_y_ = InputYFrom<GType>(inputs, scope); input_x_ = InputXFrom<GType>(inputs, *scope);
output_ = OutFrom<GType>(outputs, scope); input_y_ = InputYFrom<GType>(inputs, *scope);
output_ = OutFrom<GType>(outputs, *scope);
axis_ = OpParam::GetAttr<int>("axis", attrs); axis_ = OpParam::GetAttr<int>("axis", attrs);
} }
...@@ -3085,10 +3161,11 @@ class LogicalBinaryParam : public OpParam { ...@@ -3085,10 +3161,11 @@ class LogicalBinaryParam : public OpParam {
public: public:
LogicalBinaryParam(const VariableNameMap &inputs, LogicalBinaryParam(const VariableNameMap &inputs,
const VariableNameMap &outputs, const AttributeMap &attrs, const VariableNameMap &outputs, const AttributeMap &attrs,
const Scope &scope) { Scope *scope)
input_x_ = InputXFrom<GType>(inputs, scope); : OpParam(inputs, outputs, attrs, scope) {
input_y_ = InputYFrom<GType>(inputs, scope); input_x_ = InputXFrom<GType>(inputs, *scope);
output_ = OutFrom<GType>(outputs, scope); input_y_ = InputYFrom<GType>(inputs, *scope);
output_ = OutFrom<GType>(outputs, *scope);
} }
const GType *InputX() const { return input_x_; } const GType *InputX() const { return input_x_; }
...@@ -3111,9 +3188,10 @@ class LogicalUnaryParam : public OpParam { ...@@ -3111,9 +3188,10 @@ class LogicalUnaryParam : public OpParam {
public: public:
LogicalUnaryParam(const VariableNameMap &inputs, LogicalUnaryParam(const VariableNameMap &inputs,
const VariableNameMap &outputs, const AttributeMap &attrs, const VariableNameMap &outputs, const AttributeMap &attrs,
const Scope &scope) { Scope *scope)
input_x_ = InputXFrom<GType>(inputs, scope); : OpParam(inputs, outputs, attrs, scope) {
output_ = OutFrom<GType>(outputs, scope); input_x_ = InputXFrom<GType>(inputs, *scope);
output_ = OutFrom<GType>(outputs, *scope);
} }
const GType *InputX() const { return input_x_; } const GType *InputX() const { return input_x_; }
...@@ -3131,7 +3209,7 @@ class LogicalUnaryParam : public OpParam { ...@@ -3131,7 +3209,7 @@ class LogicalUnaryParam : public OpParam {
// public: // public:
// WhileParam(const VariableNameMap &inputs, // WhileParam(const VariableNameMap &inputs,
// const VariableNameMap &outputs, const AttributeMap &attrs, // const VariableNameMap &outputs, const AttributeMap &attrs,
// const Scope &scope) { // const Scope &scope) : OpParam(inputs, outputs, attrs, scope) {
// cond_ = OpParam::GetVarValue<framework::LoDTensor>("Condition", inputs, // cond_ = OpParam::GetVarValue<framework::LoDTensor>("Condition", inputs,
// scope); block_desc_ = OpParam::GetAttr<framework::BlockDesc // scope); block_desc_ = OpParam::GetAttr<framework::BlockDesc
// *>("sub_block", attrs); // *>("sub_block", attrs);
...@@ -3149,11 +3227,12 @@ class WriteToArrayParam : public OpParam { ...@@ -3149,11 +3227,12 @@ class WriteToArrayParam : public OpParam {
public: public:
WriteToArrayParam(const VariableNameMap &inputs, WriteToArrayParam(const VariableNameMap &inputs,
const VariableNameMap &outputs, const AttributeMap &attrs, const VariableNameMap &outputs, const AttributeMap &attrs,
const Scope &scope) { Scope *scope)
input_ = OpParam::GetVarValue<framework::LoDTensor>("X", inputs, scope); : OpParam(inputs, outputs, attrs, scope) {
index_ = OpParam::GetVarValue<framework::LoDTensor>("I", inputs, scope); input_ = OpParam::GetVarValue<framework::LoDTensor>("X", inputs, *scope);
index_ = OpParam::GetVarValue<framework::LoDTensor>("I", inputs, *scope);
output_ = output_ =
OpParam::GetVarValue<framework::LoDTensorArray>("Out", outputs, scope); OpParam::GetVarValue<framework::LoDTensorArray>("Out", outputs, *scope);
} }
public: public:
...@@ -3169,11 +3248,13 @@ class ReadFromArrayParam : public OpParam { ...@@ -3169,11 +3248,13 @@ class ReadFromArrayParam : public OpParam {
public: public:
ReadFromArrayParam(const VariableNameMap &inputs, ReadFromArrayParam(const VariableNameMap &inputs,
const VariableNameMap &outputs, const AttributeMap &attrs, const VariableNameMap &outputs, const AttributeMap &attrs,
const Scope &scope) { Scope *scope)
: OpParam(inputs, outputs, attrs, scope) {
input_ = input_ =
OpParam::GetVarValue<framework::LoDTensorArray>("X", inputs, scope); OpParam::GetVarValue<framework::LoDTensorArray>("X", inputs, *scope);
index_ = OpParam::GetVarValue<framework::LoDTensor>("I", inputs, scope); index_ = OpParam::GetVarValue<framework::LoDTensor>("I", inputs, *scope);
output_ = OpParam::GetVarValue<framework::LoDTensor>("Out", outputs, scope); output_ =
OpParam::GetVarValue<framework::LoDTensor>("Out", outputs, *scope);
} }
public: public:
...@@ -3191,9 +3272,10 @@ class IsEmptyParam : public OpParam { ...@@ -3191,9 +3272,10 @@ class IsEmptyParam : public OpParam {
public: public:
IsEmptyParam(const VariableNameMap &inputs, const VariableNameMap &outputs, IsEmptyParam(const VariableNameMap &inputs, const VariableNameMap &outputs,
const AttributeMap &attrs, const Scope &scope) { const AttributeMap &attrs, Scope *scope)
input_x_ = InputXFrom<GType>(inputs, scope); : OpParam(inputs, outputs, attrs, scope) {
output_ = OutFrom<GType>(outputs, scope); input_x_ = InputXFrom<GType>(inputs, *scope);
output_ = OutFrom<GType>(outputs, *scope);
} }
const GType *InputX() const { return input_x_; } const GType *InputX() const { return input_x_; }
...@@ -3213,9 +3295,10 @@ class IncrementParam : public OpParam { ...@@ -3213,9 +3295,10 @@ class IncrementParam : public OpParam {
public: public:
IncrementParam(const VariableNameMap &inputs, const VariableNameMap &outputs, IncrementParam(const VariableNameMap &inputs, const VariableNameMap &outputs,
const AttributeMap &attrs, const Scope &scope) { const AttributeMap &attrs, Scope *scope)
input_x_ = InputXFrom<GType>(inputs, scope); : OpParam(inputs, outputs, attrs, scope) {
output_ = OutFrom<GType>(outputs, scope); input_x_ = InputXFrom<GType>(inputs, *scope);
output_ = OutFrom<GType>(outputs, *scope);
step_ = OpParam::GetAttr<int>("step", attrs); step_ = OpParam::GetAttr<int>("step", attrs);
} }
...@@ -3237,9 +3320,10 @@ class Pad2dParam : public OpParam { ...@@ -3237,9 +3320,10 @@ class Pad2dParam : public OpParam {
public: public:
Pad2dParam(const VariableNameMap &inputs, const VariableNameMap &outputs, Pad2dParam(const VariableNameMap &inputs, const VariableNameMap &outputs,
const AttributeMap &attrs, const Scope &scope) { const AttributeMap &attrs, Scope *scope)
input_x_ = InputXFrom<GType>(inputs, scope); : OpParam(inputs, outputs, attrs, scope) {
out_ = OutFrom<GType>(outputs, scope); input_x_ = InputXFrom<GType>(inputs, *scope);
out_ = OutFrom<GType>(outputs, *scope);
} }
const RType *InputX() const { return input_x_; } const RType *InputX() const { return input_x_; }
RType *Out() const { return out_; } RType *Out() const { return out_; }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册