未验证 提交 165b02f1 编写于 作者: X xiebaiyuan 提交者: GitHub

[mobile][opencl] make model attr support issue more readable ,fastfail before...

[mobile][opencl] make model attr support issue more readable ,fastfail before cxxlib err ,test=mobile (#2697)
上级 2856c2b7
...@@ -344,10 +344,14 @@ class OpParam { ...@@ -344,10 +344,14 @@ class OpParam {
template <typename T> template <typename T>
static const T GetAttr(const string &key, const AttributeMap &map) { static const T GetAttr(const string &key, const AttributeMap &map) {
PADDLE_MOBILE_ENFORCE(HasAttr(key, map), "%s is not contained in attr map",
key.c_str())
return ((Attribute)map.at(key)).Get<T>(); return ((Attribute)map.at(key)).Get<T>();
} }
static const std::string GetStringAttr(const string &key, static const std::string GetStringAttr(const string &key,
const AttributeMap &map) { const AttributeMap &map) {
PADDLE_MOBILE_ENFORCE(HasAttr(key, map), "%s is not contained in attr map",
key.c_str())
return ((Attribute)map.at(key)).GetString(); return ((Attribute)map.at(key)).GetString();
} }
...@@ -355,6 +359,10 @@ class OpParam { ...@@ -355,6 +359,10 @@ class OpParam {
return map.count(key) > 0; return map.count(key) > 0;
} }
static const bool HasVar(const string &key, const VariableNameMap &var_map) {
return var_map.count(key) > 0;
}
template <typename T> template <typename T>
static T *GetVarValue(const string &key, const VariableNameMap &var_map, static T *GetVarValue(const string &key, const VariableNameMap &var_map,
const Scope &scope) { const Scope &scope) {
...@@ -3100,16 +3108,37 @@ class NearestInterpolationParam : public OpParam { ...@@ -3100,16 +3108,37 @@ class NearestInterpolationParam : public OpParam {
const AttributeMap &attrs, Scope *scope) const AttributeMap &attrs, Scope *scope)
: OpParam(inputs, outputs, attrs, scope) { : OpParam(inputs, outputs, attrs, scope) {
input_x_ = InputXFrom<GType>(inputs, *scope); input_x_ = InputXFrom<GType>(inputs, *scope);
input_outsize_ = InputOutSizeFrom<GType>(inputs, *scope); const bool has_out_size = HasVar("OutSize", inputs);
if (has_out_size) {
input_outsize_ = InputOutSizeFrom<GType>(inputs, *scope);
}
out_ = OutFrom<GType>(outputs, *scope); out_ = OutFrom<GType>(outputs, *scope);
out_h_ = GetAttr<int>("out_h", attrs);
out_w_ = GetAttr<int>("out_w", attrs); if (HasAttr("out_h", attrs)) {
out_h_ = GetAttr<int>("out_h", attrs);
} else if (HasAttr("out_h ", attrs)) {
// some models hurts .... attr with space ..
out_h_ = GetAttr<int>("out_h ", attrs);
}
if (HasAttr("out_w", attrs)) {
out_w_ = GetAttr<int>("out_w", attrs);
} else if (HasAttr("out_w ", attrs)) {
// some models hurts .... attr with space ..
out_w_ = GetAttr<int>("out_w ", attrs);
}
LOG(kLOG_DEBUG1) << "out_h_: " << out_h_;
LOG(kLOG_DEBUG1) << "out_w_: " << out_w_;
if (HasAttr("scale", attrs)) { if (HasAttr("scale", attrs)) {
has_scale_ = true; has_scale_ = true;
scale_ = GetAttr<float>("scale", attrs); scale_ = GetAttr<float>("scale", attrs);
} }
DLOG << "has_scale_: " << has_scale_; LOG(kLOG_DEBUG1) << "has_scale_: " << has_scale_;
DLOG << "scale_: " << scale_; LOG(kLOG_DEBUG1) << "scale_: " << scale_;
} }
const GType *InputX() const { return input_x_; } const GType *InputX() const { return input_x_; }
const GType *InputOutPutSize() const { return input_outsize_; } const GType *InputOutPutSize() const { return input_outsize_; }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册