diff --git a/mobile/src/operators/op_param.h b/mobile/src/operators/op_param.h index e58159fbb74e7a91a88c3e76f8aa713b679d94b8..85dabe3bcd009c8c00a59ccf74b7651d907b6dc2 100644 --- a/mobile/src/operators/op_param.h +++ b/mobile/src/operators/op_param.h @@ -344,10 +344,14 @@ class OpParam { template static const T GetAttr(const string &key, const AttributeMap &map) { + PADDLE_MOBILE_ENFORCE(HasAttr(key, map), "%s is not contained in attr map", + key.c_str()) return ((Attribute)map.at(key)).Get(); } static const std::string GetStringAttr(const string &key, const AttributeMap &map) { + PADDLE_MOBILE_ENFORCE(HasAttr(key, map), "%s is not contained in attr map", + key.c_str()) return ((Attribute)map.at(key)).GetString(); } @@ -355,6 +359,10 @@ class OpParam { return map.count(key) > 0; } + static const bool HasVar(const string &key, const VariableNameMap &var_map) { + return var_map.count(key) > 0; + } + template static T *GetVarValue(const string &key, const VariableNameMap &var_map, const Scope &scope) { @@ -3100,16 +3108,37 @@ class NearestInterpolationParam : public OpParam { const AttributeMap &attrs, Scope *scope) : OpParam(inputs, outputs, attrs, scope) { input_x_ = InputXFrom(inputs, *scope); - input_outsize_ = InputOutSizeFrom(inputs, *scope); + const bool has_out_size = HasVar("OutSize", inputs); + + if (has_out_size) { + input_outsize_ = InputOutSizeFrom(inputs, *scope); + } + out_ = OutFrom(outputs, *scope); - out_h_ = GetAttr("out_h", attrs); - out_w_ = GetAttr("out_w", attrs); + + if (HasAttr("out_h", attrs)) { + out_h_ = GetAttr("out_h", attrs); + } else if (HasAttr("out_h ", attrs)) { + // some models hurts .... attr with space .. + out_h_ = GetAttr("out_h ", attrs); + } + + if (HasAttr("out_w", attrs)) { + out_w_ = GetAttr("out_w", attrs); + } else if (HasAttr("out_w ", attrs)) { + // some models hurts .... attr with space .. + out_w_ = GetAttr("out_w ", attrs); + } + + LOG(kLOG_DEBUG1) << "out_h_: " << out_h_; + LOG(kLOG_DEBUG1) << "out_w_: " << out_w_; + if (HasAttr("scale", attrs)) { has_scale_ = true; scale_ = GetAttr("scale", attrs); } - DLOG << "has_scale_: " << has_scale_; - DLOG << "scale_: " << scale_; + LOG(kLOG_DEBUG1) << "has_scale_: " << has_scale_; + LOG(kLOG_DEBUG1) << "scale_: " << scale_; } const GType *InputX() const { return input_x_; } const GType *InputOutPutSize() const { return input_outsize_; }