提交 70cb9f7f 编写于 作者: xiebaiyuan's avatar xiebaiyuan

add HasAttr() for op_param.h and check "inplace" attr in reshapeOp

上级 6cb30b4d
......@@ -204,6 +204,10 @@ class OpParam {
return ((Attribute)map.at(key)).Get<T>();
}
static const bool HasAttr(const string &key, const AttributeMap &map) {
return map.count(key) > 0;
}
template <typename T>
static T *GetVarValue(const string &key, const VariableNameMap &var_map,
const Scope &scope) {
......@@ -833,7 +837,13 @@ class ReshapeParam : public OpParam {
input_shape_ = InputShapeFrom<GType>(inputs, scope);
out_ = OutFrom<GType>(outputs, scope);
shape_ = GetAttr<vector<int>>("shape", attrs);
inplace_ = GetAttr<bool>("inplace", attrs);
if (HasAttr("inplace", attrs)) {
inplace_ = GetAttr<bool>("inplace", attrs);
} else {
inplace_ = false;
DLOG << "ReshapeParam lost inplace params. maybe fluid updated";
}
}
const RType *InputX() const { return input_x_; }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册