未验证 提交 5d2eb678 编写于 作者: W wanghuancoder 提交者: GitHub

optimize attr default value (#33357)

* optimize attr default value, test=develop

* refine, test=develop

* refine, test=develop

* refine, test=develop

* fix bug in AttrReader, test=develop

* fix bug, test=develop

* fix double_grad, test=develop

* refine, test=develop

* refine, test=develop

* fix checker null, test=develop

* for test, test=develop

* refine, test=develop

* refine, test=develop

* refine, test=develop

* refine, test=develop

* refine, test=develop

* refine, test=develop
上级 2133b45a
...@@ -208,15 +208,27 @@ Attribute GetAttrValue(const proto::OpDesc::Attr& attr_desc); ...@@ -208,15 +208,27 @@ Attribute GetAttrValue(const proto::OpDesc::Attr& attr_desc);
class AttrReader { class AttrReader {
public: public:
explicit AttrReader(const AttributeMap& attrs) : attrs_(attrs) {} explicit AttrReader(const AttributeMap& attrs)
: attrs_(attrs), default_attrs_(nullptr) {}
AttrReader(const AttributeMap& attrs, const AttributeMap& default_attrs)
: attrs_(attrs), default_attrs_(&default_attrs) {}
template <typename T> template <typename T>
inline const T& Get(const std::string& name) const { inline const T& Get(const std::string& name) const {
PADDLE_ENFORCE_NE(attrs_.count(name), 0, auto it = attrs_.find(name);
bool found = it != attrs_.end();
if (!found) {
if (default_attrs_ != nullptr) {
it = default_attrs_->find(name);
found = it != default_attrs_->end();
}
}
PADDLE_ENFORCE_EQ(found, true,
platform::errors::NotFound( platform::errors::NotFound(
"Attribute (%s) should be in AttributeMap.", name)); "Attribute (%s) should be in AttributeMap.", name));
Attribute& attr = const_cast<Attribute&>(attrs_.at(name)); Attribute& attr = const_cast<Attribute&>(it->second);
ExtractAttribute<T> extract_attr(name); ExtractAttribute<T> extract_attr(name);
T* attr_value = extract_attr(attr); T* attr_value = extract_attr(attr);
return *attr_value; return *attr_value;
...@@ -224,6 +236,7 @@ class AttrReader { ...@@ -224,6 +236,7 @@ class AttrReader {
private: private:
const AttributeMap& attrs_; const AttributeMap& attrs_;
const AttributeMap* default_attrs_;
}; };
// check whether a value(attribute) fit a certain limit // check whether a value(attribute) fit a certain limit
...@@ -234,8 +247,8 @@ class GreaterThanChecker { ...@@ -234,8 +247,8 @@ class GreaterThanChecker {
void operator()(const T& value) const { void operator()(const T& value) const {
PADDLE_ENFORCE_GT( PADDLE_ENFORCE_GT(
value, lower_bound_, value, lower_bound_,
platform::errors::OutOfRange( platform::errors::OutOfRange("Check for attribute value greater than "
"Check for attribute value greater than a certain value failed.")); "a certain value failed."));
} }
private: private:
...@@ -332,9 +345,9 @@ class TypedAttrChecker { ...@@ -332,9 +345,9 @@ class TypedAttrChecker {
TypedAttrChecker& SetDefault(const T& default_value) { TypedAttrChecker& SetDefault(const T& default_value) {
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
default_value_setter_.empty(), true, default_value_setter_.empty(), true,
platform::errors::AlreadyExists( platform::errors::AlreadyExists("Attribute (%s) has a default value "
"Attribute (%s) has a default value and cannot be set repeatedly.", "and cannot be set repeatedly.",
attr_name_)); attr_name_));
default_value_setter_.push_back(DefaultValueSetter<T>(default_value)); default_value_setter_.push_back(DefaultValueSetter<T>(default_value));
return *this; return *this;
} }
...@@ -345,8 +358,8 @@ class TypedAttrChecker { ...@@ -345,8 +358,8 @@ class TypedAttrChecker {
return *this; return *this;
} }
void operator()(AttributeMap* attr_map, void operator()(AttributeMap* attr_map, bool get_default_value_only = false,
bool get_default_value_only = false) const { bool only_check_exist_value = false) const {
if (get_default_value_only) { if (get_default_value_only) {
if (!default_value_setter_.empty()) { if (!default_value_setter_.empty()) {
attr_map->emplace(attr_name_, default_value_setter_[0]()); attr_map->emplace(attr_name_, default_value_setter_[0]());
...@@ -354,21 +367,32 @@ class TypedAttrChecker { ...@@ -354,21 +367,32 @@ class TypedAttrChecker {
return; return;
} }
auto it = attr_map->find(attr_name_); if (only_check_exist_value) {
if (it == attr_map->end()) { auto it = attr_map->find(attr_name_);
// user do not set this attr if (it != attr_map->end()) {
PADDLE_ENFORCE_EQ( ExtractAttribute<T> extract_attr(attr_name_);
default_value_setter_.empty(), false, T* attr_value = extract_attr(it->second);
platform::errors::InvalidArgument( for (const auto& checker : value_checkers_) {
"Attribute (%s) is not set correctly.", attr_name_)); checker(*attr_value);
// default_value_setter_ has no more than one element }
attr_map->emplace(attr_name_, default_value_setter_[0]()); }
} } else {
it = attr_map->find(attr_name_); auto it = attr_map->find(attr_name_);
ExtractAttribute<T> extract_attr(attr_name_); if (it == attr_map->end()) {
T* attr_value = extract_attr(it->second); // user do not set this attr
for (const auto& checker : value_checkers_) { PADDLE_ENFORCE_EQ(
checker(*attr_value); default_value_setter_.empty(), false,
platform::errors::InvalidArgument(
"Attribute (%s) is not set correctly.", attr_name_));
// default_value_setter_ has no more than one element
auto tmp = attr_map->emplace(attr_name_, default_value_setter_[0]());
it = tmp.first;
}
ExtractAttribute<T> extract_attr(attr_name_);
T* attr_value = extract_attr(it->second);
for (const auto& checker : value_checkers_) {
checker(*attr_value);
}
} }
} }
...@@ -380,7 +404,7 @@ class TypedAttrChecker { ...@@ -380,7 +404,7 @@ class TypedAttrChecker {
// check whether op's all attributes fit their own limits // check whether op's all attributes fit their own limits
class OpAttrChecker { class OpAttrChecker {
typedef std::function<void(AttributeMap*, bool)> AttrChecker; typedef std::function<void(AttributeMap*, bool, bool)> AttrChecker;
public: public:
template <typename T> template <typename T>
...@@ -390,18 +414,19 @@ class OpAttrChecker { ...@@ -390,18 +414,19 @@ class OpAttrChecker {
return *(checker.target<TypedAttrChecker<T>>()); return *(checker.target<TypedAttrChecker<T>>());
} }
void Check(AttributeMap* attr_map, bool explicit_only = false) const { void Check(AttributeMap* attr_map, bool explicit_only = false,
bool only_check_exist_value = false) const {
auto checker_num = attr_checkers_.size(); auto checker_num = attr_checkers_.size();
if (explicit_only) checker_num = explicit_checker_num_; if (explicit_only) checker_num = explicit_checker_num_;
for (size_t i = 0; i < checker_num; ++i) { for (size_t i = 0; i < checker_num; ++i) {
attr_checkers_[i](attr_map, false); attr_checkers_[i](attr_map, false, only_check_exist_value);
} }
} }
AttributeMap GetAttrsDefaultValuesMap() const { AttributeMap GetDefaultAttrsMap() const {
AttributeMap default_values_map; AttributeMap default_values_map;
for (const auto& checker : attr_checkers_) { for (const auto& checker : attr_checkers_) {
checker(&default_values_map, true); checker(&default_values_map, true, false);
} }
return default_values_map; return default_values_map;
} }
...@@ -410,15 +435,26 @@ class OpAttrChecker { ...@@ -410,15 +435,26 @@ class OpAttrChecker {
explicit_checker_num_ = attr_checkers_.size(); explicit_checker_num_ = attr_checkers_.size();
} }
void InitDefaultAttributeMap() {
for (const auto& checker : attr_checkers_) {
checker(&default_attrs_, true, false);
}
}
const AttributeMap& GetDefaultAttrMap() const { return default_attrs_; }
private: private:
std::vector<AttrChecker> attr_checkers_; std::vector<AttrChecker> attr_checkers_;
AttributeMap default_attrs_;
// in order to improve the efficiency of dynamic graph mode, // in order to improve the efficiency of dynamic graph mode,
// we divede the attribute into explicit type and implicit type. // we divede the attribute into explicit type and implicit type.
// for explicit attribute, we mean the attribute added in the customized // for explicit attribute, we mean the attribute added in the customized
// op makers, usually it's defined in the overloaded Make method. // op makers, usually it's defined in the overloaded Make method.
// for implicit attribute, we mean the attribute added outside of the Make // for implicit attribute, we mean the attribute added outside of the Make
// method like "op_role", "op_role_var", and they are useless in dynamic graph // method like "op_role", "op_role_var", and they are useless in dynamic
// graph
// mode // mode
size_t explicit_checker_num_; size_t explicit_checker_num_;
}; };
......
...@@ -781,10 +781,12 @@ void RegisterOperatorWithMetaInfo( ...@@ -781,10 +781,12 @@ void RegisterOperatorWithMetaInfo(
const imperative::NameVarBaseMap& var_base_map_in, const imperative::NameVarBaseMap& var_base_map_in,
const imperative::NameVarBaseMap& var_base_map_out, const imperative::NameVarBaseMap& var_base_map_out,
const framework::AttributeMap& attrs, const framework::AttributeMap& attrs,
const framework::AttributeMap& default_attrs,
const std::map<std::string, std::string>& inplace_map) { const std::map<std::string, std::string>& inplace_map) {
CustomGradOpMaker<paddle::imperative::OpBase> maker( CustomGradOpMaker<paddle::imperative::OpBase> maker(
type, var_base_map_in, var_base_map_out, attrs, inplace_map, type, var_base_map_in, var_base_map_out, attrs, inplace_map,
grad_op_name, grad_op_inputs, grad_op_outputs); grad_op_name, grad_op_inputs, grad_op_outputs);
maker.SetDygraphDefaultAttrsMap(default_attrs);
return maker(); return maker();
}; };
......
...@@ -249,8 +249,10 @@ struct OpInfoFiller<T, kGradOpBaseMaker> { ...@@ -249,8 +249,10 @@ struct OpInfoFiller<T, kGradOpBaseMaker> {
const imperative::NameVarBaseMap& var_base_map_in, const imperative::NameVarBaseMap& var_base_map_in,
const imperative::NameVarBaseMap& var_base_map_out, const imperative::NameVarBaseMap& var_base_map_out,
const framework::AttributeMap& attrs, const framework::AttributeMap& attrs,
const framework::AttributeMap& default_attrs,
const std::map<std::string, std::string>& inplace_map) { const std::map<std::string, std::string>& inplace_map) {
T maker(type, var_base_map_in, var_base_map_out, attrs, inplace_map); T maker(type, var_base_map_in, var_base_map_out, attrs, inplace_map);
maker.SetDygraphDefaultAttrsMap(default_attrs);
return maker(); return maker();
}; };
} }
......
...@@ -219,6 +219,19 @@ class SingleGradOpMaker<imperative::OpBase> ...@@ -219,6 +219,19 @@ class SingleGradOpMaker<imperative::OpBase>
public: public:
using GradOpBaseMakerBase::GradOpBaseMakerBase; using GradOpBaseMakerBase::GradOpBaseMakerBase;
virtual const framework::Attribute& GetAttr(const std::string& name) const {
auto it = Attrs().find(name);
if (it == Attrs().end()) {
it = this->DefaultAttrsMap().find(name);
PADDLE_ENFORCE_EQ(it != this->DefaultAttrsMap().end(), true,
platform::errors::NotFound(
"Cannot find attribute [%s] in operator [%s]", name,
this->ForwardOpType()));
}
return it->second;
}
std::shared_ptr<imperative::GradOpNode> operator()() const final { std::shared_ptr<imperative::GradOpNode> operator()() const final {
auto node = this->NewGradNode(); auto node = this->NewGradNode();
auto& inplace_map = this->GetInplaceMap(); auto& inplace_map = this->GetInplaceMap();
...@@ -228,6 +241,7 @@ class SingleGradOpMaker<imperative::OpBase> ...@@ -228,6 +241,7 @@ class SingleGradOpMaker<imperative::OpBase>
{ {
imperative::TracedGradOp traced_grad_op(node); imperative::TracedGradOp traced_grad_op(node);
try { try {
traced_grad_op.SetDefaultAttrsMap(this->DefaultAttrsMap());
this->Apply(&traced_grad_op); this->Apply(&traced_grad_op);
} catch (platform::EnforceNotMet& exception) { } catch (platform::EnforceNotMet& exception) {
framework::AppendErrorOpHint(traced_grad_op.Type(), &exception); framework::AppendErrorOpHint(traced_grad_op.Type(), &exception);
......
...@@ -61,7 +61,7 @@ AttrCompat& AttrCompat::IsLeftDefault() { ...@@ -61,7 +61,7 @@ AttrCompat& AttrCompat::IsLeftDefault() {
return *this; return *this;
} }
const OpInfo& op_info = OpInfoMap::Instance().Get(op_name); const OpInfo& op_info = OpInfoMap::Instance().Get(op_name);
const AttributeMap attrs = op_info.Checker()->GetAttrsDefaultValuesMap(); const AttributeMap attrs = op_info.Checker()->GetDefaultAttrsMap();
if (attrs.find(attr_name_) == attrs.end()) { if (attrs.find(attr_name_) == attrs.end()) {
LOG(WARNING) << "Op (" << op_name << ") has no default attr:" << attr_name_; LOG(WARNING) << "Op (" << op_name << ") has no default attr:" << attr_name_;
conditions_.emplace_back([](const Attribute& attr) { return false; }); conditions_.emplace_back([](const Attribute& attr) { return false; });
......
...@@ -66,6 +66,7 @@ void OpProtoAndCheckerMaker::operator()(proto::OpProto* proto, ...@@ -66,6 +66,7 @@ void OpProtoAndCheckerMaker::operator()(proto::OpProto* proto,
op_checker_ = attr_checker; op_checker_ = attr_checker;
Make(); Make();
op_checker_->RecordExplicitCheckerNum(); op_checker_->RecordExplicitCheckerNum();
op_checker_->InitDefaultAttributeMap();
AddAttr<int>(OpRoleAttrName(), "The role of this operator") AddAttr<int>(OpRoleAttrName(), "The role of this operator")
.InEnum( .InEnum(
......
...@@ -71,6 +71,7 @@ using DygraphGradOpMakerFN = ...@@ -71,6 +71,7 @@ using DygraphGradOpMakerFN =
const imperative::NameVarBaseMap& /*var_base_map_in*/, const imperative::NameVarBaseMap& /*var_base_map_in*/,
const imperative::NameVarBaseMap& /*var_base_map_out*/, const imperative::NameVarBaseMap& /*var_base_map_out*/,
const framework::AttributeMap& /*attributes*/, const framework::AttributeMap& /*attributes*/,
const framework::AttributeMap& /*default attributes*/,
const std::map<std::string, std::string>& /*inplace_map*/)>; const std::map<std::string, std::string>& /*inplace_map*/)>;
using InferVarTypeFN = using InferVarTypeFN =
......
...@@ -474,10 +474,11 @@ void BasicEngine::Execute() { ...@@ -474,10 +474,11 @@ void BasicEngine::Execute() {
try { try {
if (tmp_ins_ptr == nullptr) { if (tmp_ins_ptr == nullptr) {
OpBase::Run(cur_op.InnerOp(), bwd_ins, tmp_outs, cur_op.Attrs(), OpBase::Run(cur_op.InnerOp(), bwd_ins, tmp_outs, cur_op.Attrs(),
cur_op.place()); cur_op.DefaultAttrsMap(), cur_op.place());
} else { } else {
OpBase::Run(cur_op.InnerOp(), *tmp_ins_ptr, tmp_outs, OpBase::Run(cur_op.InnerOp(), *tmp_ins_ptr, tmp_outs,
cur_op.Attrs(), cur_op.place()); cur_op.Attrs(), cur_op.DefaultAttrsMap(),
cur_op.place());
} }
} catch (platform::EnforceNotMet& exception) { } catch (platform::EnforceNotMet& exception) {
Clear(); Clear();
......
...@@ -113,9 +113,18 @@ class GradOpBaseMakerBase { ...@@ -113,9 +113,18 @@ class GradOpBaseMakerBase {
return vec_temp; return vec_temp;
} }
// Only for dygraph
void SetDygraphDefaultAttrsMap(const framework::AttributeMap& default_attrs) {
default_attrs_ = &default_attrs;
}
const framework::AttributeMap& DefaultAttrsMap() const {
return *default_attrs_;
}
const framework::AttributeMap& Attrs() const { return attrs_; } const framework::AttributeMap& Attrs() const { return attrs_; }
const framework::Attribute& GetAttr(const std::string& name) const { virtual const framework::Attribute& GetAttr(const std::string& name) const {
auto it = attrs_.find(name); auto it = attrs_.find(name);
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
it != attrs_.end(), true, it != attrs_.end(), true,
...@@ -199,6 +208,7 @@ class GradOpBaseMakerBase { ...@@ -199,6 +208,7 @@ class GradOpBaseMakerBase {
const NameVarBaseMap& var_base_map_in_; const NameVarBaseMap& var_base_map_in_;
const NameVarBaseMap& var_base_map_out_; const NameVarBaseMap& var_base_map_out_;
const framework::AttributeMap& attrs_; const framework::AttributeMap& attrs_;
const framework::AttributeMap* default_attrs_;
const std::map<std::string, std::string>& inplace_map_; const std::map<std::string, std::string>& inplace_map_;
}; };
...@@ -285,6 +295,10 @@ class TracedGradOp { ...@@ -285,6 +295,10 @@ class TracedGradOp {
return op_->SetAttrMap(attrs); return op_->SetAttrMap(attrs);
} }
void SetDefaultAttrsMap(const framework::AttributeMap& attrs) {
return op_->SetDefaultAttrsMap(attrs);
}
void SetAttr(const std::string& name, const framework::Attribute& v) { void SetAttr(const std::string& name, const framework::Attribute& v) {
op_->SetAttr(name, v); op_->SetAttr(name, v);
} }
......
...@@ -35,11 +35,13 @@ class DygraphExecutionContext : public framework::ExecutionContext { ...@@ -35,11 +35,13 @@ class DygraphExecutionContext : public framework::ExecutionContext {
const framework::RuntimeContext& ctx, const framework::RuntimeContext& ctx,
const NameVarMap<VarType>& var_base_map_in, const NameVarMap<VarType>& var_base_map_in,
const NameVarMap<VarType>& var_base_map_out, const NameVarMap<VarType>& var_base_map_out,
const framework::AttributeMap& attrs) const framework::AttributeMap& attrs,
const framework::AttributeMap& default_attrs)
: ExecutionContext(op, scope, device_context, ctx), : ExecutionContext(op, scope, device_context, ctx),
var_base_map_in_(var_base_map_in), var_base_map_in_(var_base_map_in),
var_base_map_out_(var_base_map_out), var_base_map_out_(var_base_map_out),
attrs_(attrs) {} attrs_(attrs),
default_attrs_(default_attrs) {}
std::string InputName(const std::string& name) const override { std::string InputName(const std::string& name) const override {
auto it = var_base_map_in_.find(name); auto it = var_base_map_in_.find(name);
...@@ -92,7 +94,7 @@ class DygraphExecutionContext : public framework::ExecutionContext { ...@@ -92,7 +94,7 @@ class DygraphExecutionContext : public framework::ExecutionContext {
} }
bool HasAttr(const std::string& name) const override { bool HasAttr(const std::string& name) const override {
return attrs_.count(name) != 0; return attrs_.count(name) != 0 || default_attrs_.count(name) != 0;
} }
const framework::AttributeMap& Attrs() const override { return attrs_; } const framework::AttributeMap& Attrs() const override { return attrs_; }
...@@ -100,9 +102,14 @@ class DygraphExecutionContext : public framework::ExecutionContext { ...@@ -100,9 +102,14 @@ class DygraphExecutionContext : public framework::ExecutionContext {
const framework::Attribute& GetAttr(const std::string& name) const override { const framework::Attribute& GetAttr(const std::string& name) const override {
auto it = attrs_.find(name); auto it = attrs_.find(name);
PADDLE_ENFORCE_NE( if (it == attrs_.end()) {
it, attrs_.end(), it = default_attrs_.find(name);
platform::errors::NotFound("can not find [%s] in attrs", name)); if (it == default_attrs_.end()) {
PADDLE_THROW(platform::errors::NotFound(
"Can not find [%s] in attributes of op %s.", name,
this->GetOp().Type()));
}
}
return it->second; return it->second;
} }
...@@ -192,6 +199,7 @@ class DygraphExecutionContext : public framework::ExecutionContext { ...@@ -192,6 +199,7 @@ class DygraphExecutionContext : public framework::ExecutionContext {
const NameVarMap<VarType>& var_base_map_in_; const NameVarMap<VarType>& var_base_map_in_;
const NameVarMap<VarType>& var_base_map_out_; const NameVarMap<VarType>& var_base_map_out_;
const framework::AttributeMap& attrs_; const framework::AttributeMap& attrs_;
const framework::AttributeMap& default_attrs_;
}; };
} // namespace imperative } // namespace imperative
......
...@@ -35,10 +35,12 @@ class DygraphInferShapeContext : public framework::InferShapeContext { ...@@ -35,10 +35,12 @@ class DygraphInferShapeContext : public framework::InferShapeContext {
DygraphInferShapeContext(const NameVarMap<VarType>* in, DygraphInferShapeContext(const NameVarMap<VarType>* in,
const NameVarMap<VarType>* out, const NameVarMap<VarType>* out,
const framework::AttributeMap* attr, const framework::AttributeMap* attr,
const framework::AttributeMap* default_attr,
const std::string op_type) const std::string op_type)
: var_base_map_in_(in), : var_base_map_in_(in),
var_base_map_out_(out), var_base_map_out_(out),
attrs_(attr), attrs_(attr),
default_attrs_(default_attr),
op_type_(op_type) {} op_type_(op_type) {}
bool HasInput(const std::string& name) const override { bool HasInput(const std::string& name) const override {
...@@ -101,7 +103,7 @@ class DygraphInferShapeContext : public framework::InferShapeContext { ...@@ -101,7 +103,7 @@ class DygraphInferShapeContext : public framework::InferShapeContext {
} }
framework::AttrReader Attrs() const override { framework::AttrReader Attrs() const override {
return framework::AttrReader(*attrs_); return framework::AttrReader(*attrs_, *default_attrs_);
} }
std::vector<std::string> Inputs(const std::string& name) const override { std::vector<std::string> Inputs(const std::string& name) const override {
...@@ -395,6 +397,7 @@ class DygraphInferShapeContext : public framework::InferShapeContext { ...@@ -395,6 +397,7 @@ class DygraphInferShapeContext : public framework::InferShapeContext {
const NameVarMap<VarType>* var_base_map_in_; const NameVarMap<VarType>* var_base_map_in_;
const NameVarMap<VarType>* var_base_map_out_; const NameVarMap<VarType>* var_base_map_out_;
const framework::AttributeMap* attrs_; const framework::AttributeMap* attrs_;
const framework::AttributeMap* default_attrs_;
const std::string op_type_; const std::string op_type_;
}; };
......
...@@ -32,20 +32,28 @@ class RuntimeInferVarTypeContext : public framework::InferVarTypeContext { ...@@ -32,20 +32,28 @@ class RuntimeInferVarTypeContext : public framework::InferVarTypeContext {
public: public:
RuntimeInferVarTypeContext(const NameVarMap<VarType>& inputs, RuntimeInferVarTypeContext(const NameVarMap<VarType>& inputs,
const NameVarMap<VarType>& outputs, const NameVarMap<VarType>& outputs,
const framework::AttributeMap& attrs_map) const framework::AttributeMap& attrs_map,
const framework::AttributeMap& default_attrs_map)
: InferVarTypeContext(nullptr, nullptr), : InferVarTypeContext(nullptr, nullptr),
inputs_(inputs), inputs_(inputs),
outputs_(outputs), outputs_(outputs),
attrs_(attrs_map) {} attrs_(attrs_map),
default_attrs_(default_attrs_map) {}
virtual ~RuntimeInferVarTypeContext() {} virtual ~RuntimeInferVarTypeContext() {}
framework::Attribute GetAttr(const std::string& name) const override { framework::Attribute GetAttr(const std::string& name) const override {
auto iter = attrs_.find(name); auto it = attrs_.find(name);
PADDLE_ENFORCE_EQ(
iter != attrs_.end(), true, if (it == attrs_.end()) {
platform::errors::NotFound("Cannot find attribute %s", name)); it = default_attrs_.find(name);
return iter->second; if (it == default_attrs_.end()) {
PADDLE_THROW(platform::errors::NotFound(
"Can not find [%s] in attributes.", name));
}
}
return it->second;
} }
bool HasInput(const std::string& name) const override { bool HasInput(const std::string& name) const override {
...@@ -233,6 +241,7 @@ class RuntimeInferVarTypeContext : public framework::InferVarTypeContext { ...@@ -233,6 +241,7 @@ class RuntimeInferVarTypeContext : public framework::InferVarTypeContext {
const NameVarMap<VarType>& inputs_; const NameVarMap<VarType>& inputs_;
const NameVarMap<VarType>& outputs_; const NameVarMap<VarType>& outputs_;
const framework::AttributeMap& attrs_; const framework::AttributeMap& attrs_;
const framework::AttributeMap& default_attrs_;
}; };
} // namespace imperative } // namespace imperative
......
...@@ -329,6 +329,7 @@ static void OpBaseRunImpl(const framework::OperatorBase& op, ...@@ -329,6 +329,7 @@ static void OpBaseRunImpl(const framework::OperatorBase& op,
const NameVarMap<VarType>& ins, const NameVarMap<VarType>& ins,
const NameVarMap<VarType>& outs, const NameVarMap<VarType>& outs,
const framework::AttributeMap& attrs, const framework::AttributeMap& attrs,
const framework::AttributeMap& default_attrs,
const platform::Place& place) { const platform::Place& place) {
auto* op_kernel = dynamic_cast<const framework::OperatorWithKernel*>(&op); auto* op_kernel = dynamic_cast<const framework::OperatorWithKernel*>(&op);
PADDLE_ENFORCE_NOT_NULL( PADDLE_ENFORCE_NOT_NULL(
...@@ -336,7 +337,8 @@ static void OpBaseRunImpl(const framework::OperatorBase& op, ...@@ -336,7 +337,8 @@ static void OpBaseRunImpl(const framework::OperatorBase& op,
"Only support operator with kernel in Dygraph mode.")); "Only support operator with kernel in Dygraph mode."));
auto& info = op.Info(); auto& info = op.Info();
if (info.infer_var_type_) { if (info.infer_var_type_) {
RuntimeInferVarTypeContext<VarType> infer_var_type_ctx(ins, outs, attrs); RuntimeInferVarTypeContext<VarType> infer_var_type_ctx(ins, outs, attrs,
default_attrs);
info.infer_var_type_(&infer_var_type_ctx); info.infer_var_type_(&infer_var_type_ctx);
} }
...@@ -369,13 +371,14 @@ static void OpBaseRunImpl(const framework::OperatorBase& op, ...@@ -369,13 +371,14 @@ static void OpBaseRunImpl(const framework::OperatorBase& op,
* after the execution of op, but the original input is directly * after the execution of op, but the original input is directly
* overwritten in the previous dynamic graph implemention. * overwritten in the previous dynamic graph implemention.
*/ */
auto prepared_op = PreparedOp::Prepare(ins, outs, *op_kernel, place, attrs); auto prepared_op =
PreparedOp::Prepare(ins, outs, *op_kernel, place, attrs, default_attrs);
auto tmp_ins_ptr = auto tmp_ins_ptr =
PrepareData<VarType>(*op_kernel, ins, prepared_op.kernel_type()); PrepareData<VarType>(*op_kernel, ins, prepared_op.kernel_type());
if (tmp_ins_ptr == nullptr) { if (tmp_ins_ptr == nullptr) {
prepared_op.Run(ins, outs, attrs); prepared_op.Run(ins, outs, attrs, default_attrs);
} else { } else {
prepared_op.Run(*tmp_ins_ptr, outs, attrs); prepared_op.Run(*tmp_ins_ptr, outs, attrs, default_attrs);
} }
VLOG(4) << LayerDebugString(op.Type(), ins, outs); VLOG(4) << LayerDebugString(op.Type(), ins, outs);
...@@ -395,16 +398,18 @@ void OpBase::Run(const framework::OperatorBase& op, ...@@ -395,16 +398,18 @@ void OpBase::Run(const framework::OperatorBase& op,
const NameVarMap<VarBase>& ins, const NameVarMap<VarBase>& ins,
const NameVarMap<VarBase>& outs, const NameVarMap<VarBase>& outs,
const framework::AttributeMap& attrs, const framework::AttributeMap& attrs,
const framework::AttributeMap& default_attrs,
const platform::Place& place) { const platform::Place& place) {
OpBaseRunImpl<VarBase>(op, ins, outs, attrs, place); OpBaseRunImpl<VarBase>(op, ins, outs, attrs, default_attrs, place);
} }
void OpBase::Run(const framework::OperatorBase& op, void OpBase::Run(const framework::OperatorBase& op,
const NameVarMap<VariableWrapper>& ins, const NameVarMap<VariableWrapper>& ins,
const NameVarMap<VariableWrapper>& outs, const NameVarMap<VariableWrapper>& outs,
const framework::AttributeMap& attrs, const framework::AttributeMap& attrs,
const framework::AttributeMap& default_attrs,
const platform::Place& place) { const platform::Place& place) {
OpBaseRunImpl<VariableWrapper>(op, ins, outs, attrs, place); OpBaseRunImpl<VariableWrapper>(op, ins, outs, attrs, default_attrs, place);
} }
void ClearNoNeedBufferInputs(OpBase* op) { void ClearNoNeedBufferInputs(OpBase* op) {
...@@ -446,15 +451,15 @@ void ClearNoNeedBufferInputs(OpBase* op) { ...@@ -446,15 +451,15 @@ void ClearNoNeedBufferInputs(OpBase* op) {
std::shared_ptr<GradOpNode> CreateGradOpNode( std::shared_ptr<GradOpNode> CreateGradOpNode(
const framework::OperatorBase& op, const NameVarBaseMap& ins, const framework::OperatorBase& op, const NameVarBaseMap& ins,
const NameVarBaseMap& outs, const framework::AttributeMap& attrs, const NameVarBaseMap& outs, const framework::AttributeMap& attrs,
const platform::Place& place, const framework::AttributeMap& default_attrs, const platform::Place& place,
const std::map<std::string, std::string>& inplace_map) { const std::map<std::string, std::string>& inplace_map) {
const auto& info = op.Info(); const auto& info = op.Info();
if (!info.dygraph_grad_op_maker_) { if (!info.dygraph_grad_op_maker_) {
return nullptr; return nullptr;
} }
auto grad_node = auto grad_node = info.dygraph_grad_op_maker_(op.Type(), ins, outs, attrs,
info.dygraph_grad_op_maker_(op.Type(), ins, outs, attrs, inplace_map); default_attrs, inplace_map);
if (grad_node && !grad_node->empty()) { if (grad_node && !grad_node->empty()) {
for (auto& grad_op : *grad_node) { for (auto& grad_op : *grad_node) {
grad_op.SetId(OpBase::GenerateUniqueId()); grad_op.SetId(OpBase::GenerateUniqueId());
......
...@@ -108,7 +108,7 @@ class VarBase { ...@@ -108,7 +108,7 @@ class VarBase {
void ClearGradVarBase() { grad_var_ = nullptr; } void ClearGradVarBase() { grad_var_ = nullptr; }
void SetGradVarBase(VarBase& grad_var) { void SetGradVarBase(const VarBase& grad_var) {
MutableGradVarBase()->CopyFrom(grad_var, true); MutableGradVarBase()->CopyFrom(grad_var, true);
} }
...@@ -283,7 +283,7 @@ class Layer { ...@@ -283,7 +283,7 @@ class Layer {
std::shared_ptr<GradOpNode> CreateGradOpNode( std::shared_ptr<GradOpNode> CreateGradOpNode(
const framework::OperatorBase& op, const NameVarBaseMap& ins, const framework::OperatorBase& op, const NameVarBaseMap& ins,
const NameVarBaseMap& outs, const framework::AttributeMap& attrs, const NameVarBaseMap& outs, const framework::AttributeMap& attrs,
const platform::Place& place, const framework::AttributeMap& default_attrs, const platform::Place& place,
const std::map<std::string, std::string>& inplace_map); const std::map<std::string, std::string>& inplace_map);
void ClearNoNeedBufferInputs(OpBase* op); void ClearNoNeedBufferInputs(OpBase* op);
......
...@@ -50,6 +50,10 @@ class OpBase { ...@@ -50,6 +50,10 @@ class OpBase {
const framework::AttributeMap& Attrs() const { return attrs_; } const framework::AttributeMap& Attrs() const { return attrs_; }
const framework::AttributeMap& DefaultAttrsMap() const {
return *default_attrs_;
}
const framework::OpInfo& Info() const { const framework::OpInfo& Info() const {
PADDLE_ENFORCE_NOT_NULL(op_, platform::errors::PreconditionNotMet( PADDLE_ENFORCE_NOT_NULL(op_, platform::errors::PreconditionNotMet(
"OpBase::Info() should be called after " "OpBase::Info() should be called after "
...@@ -99,6 +103,10 @@ class OpBase { ...@@ -99,6 +103,10 @@ class OpBase {
void SetAttrMap(const framework::AttributeMap& attrs) { attrs_ = attrs; } void SetAttrMap(const framework::AttributeMap& attrs) { attrs_ = attrs; }
void SetDefaultAttrsMap(const framework::AttributeMap& default_attrs) {
default_attrs_ = &default_attrs;
}
void SetAttr(const std::string& name, const framework::Attribute& v) { void SetAttr(const std::string& name, const framework::Attribute& v) {
attrs_[name] = v; attrs_[name] = v;
} }
...@@ -110,14 +118,23 @@ class OpBase { ...@@ -110,14 +118,23 @@ class OpBase {
const framework::AttributeMap& Attrs() { return attrs_; } const framework::AttributeMap& Attrs() { return attrs_; }
bool HasAttr(const std::string& name) const { return attrs_.count(name) > 0; } const framework::AttributeMap& DefaultAttrsMap() { return *default_attrs_; }
bool HasAttr(const std::string& name) const {
return attrs_.count(name) > 0 || default_attrs_->count(name) > 0;
}
const framework::Attribute& GetAttr(const std::string& name) const { const framework::Attribute& GetAttr(const std::string& name) const {
auto it = attrs_.find(name); auto it = attrs_.find(name);
PADDLE_ENFORCE_NE( if (it != attrs_.end()) {
it, attrs_.end(), return it->second;
platform::errors::NotFound("can not find attribute [%s]", name)); } else {
return it->second; auto it_default = default_attrs_->find(name);
PADDLE_ENFORCE_NE(
it_default, default_attrs_->end(),
platform::errors::NotFound("can not find attribute [%s]", name));
return it_default->second;
}
} }
template <typename T> template <typename T>
...@@ -156,12 +173,14 @@ class OpBase { ...@@ -156,12 +173,14 @@ class OpBase {
const NameVarMap<VarBase>& ins, const NameVarMap<VarBase>& ins,
const NameVarMap<VarBase>& outs, const NameVarMap<VarBase>& outs,
const framework::AttributeMap& attrs, const framework::AttributeMap& attrs,
const framework::AttributeMap& default_attrs,
const platform::Place& place); const platform::Place& place);
static void Run(const framework::OperatorBase& op, static void Run(const framework::OperatorBase& op,
const NameVarMap<VariableWrapper>& ins, const NameVarMap<VariableWrapper>& ins,
const NameVarMap<VariableWrapper>& outs, const NameVarMap<VariableWrapper>& outs,
const framework::AttributeMap& attrs, const framework::AttributeMap& attrs,
const framework::AttributeMap& default_attrs,
const platform::Place& place); const platform::Place& place);
private: private:
...@@ -174,6 +193,7 @@ class OpBase { ...@@ -174,6 +193,7 @@ class OpBase {
NameVarMap<VariableWrapper> ins_; NameVarMap<VariableWrapper> ins_;
NameVarMap<VariableWrapper> outs_; NameVarMap<VariableWrapper> outs_;
framework::AttributeMap attrs_; framework::AttributeMap attrs_;
const framework::AttributeMap* default_attrs_;
std::unique_ptr<framework::OperatorBase> op_; std::unique_ptr<framework::OperatorBase> op_;
platform::Place place_; platform::Place place_;
size_t id_{-1UL}; size_t id_{-1UL};
......
...@@ -884,11 +884,13 @@ void PartialGradTask::RunEachOp(OpBase *op) { ...@@ -884,11 +884,13 @@ void PartialGradTask::RunEachOp(OpBase *op) {
} }
// Run op // Run op
OpBase::Run(op->InnerOp(), tmp_ins, tmp_outs, op->Attrs(), op->place()); OpBase::Run(op->InnerOp(), tmp_ins, tmp_outs, op->Attrs(),
op->DefaultAttrsMap(), op->place());
if (create_graph_) { if (create_graph_) {
auto double_grad_node = CreateGradOpNode(op->InnerOp(), tmp_ins, tmp_outs, auto double_grad_node =
op->Attrs(), op->place(), {}); CreateGradOpNode(op->InnerOp(), tmp_ins, tmp_outs, op->Attrs(),
op->DefaultAttrsMap(), op->place(), {});
PADDLE_ENFORCE_NOT_NULL( PADDLE_ENFORCE_NOT_NULL(
double_grad_node, double_grad_node,
platform::errors::NotFound("The Op %s doesn't have any grad op. If you " platform::errors::NotFound("The Op %s doesn't have any grad op. If you "
......
...@@ -91,7 +91,8 @@ PreparedOp PrepareImpl(const NameVarMap<VarType>& ins, ...@@ -91,7 +91,8 @@ PreparedOp PrepareImpl(const NameVarMap<VarType>& ins,
const NameVarMap<VarType>& outs, const NameVarMap<VarType>& outs,
const framework::OperatorWithKernel& op, const framework::OperatorWithKernel& op,
const platform::Place& place, const platform::Place& place,
const framework::AttributeMap& attrs) { const framework::AttributeMap& attrs,
const framework::AttributeMap& default_attrs) {
platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance(); platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
auto* dev_ctx = pool.Get(place); auto* dev_ctx = pool.Get(place);
...@@ -108,9 +109,9 @@ PreparedOp PrepareImpl(const NameVarMap<VarType>& ins, ...@@ -108,9 +109,9 @@ PreparedOp PrepareImpl(const NameVarMap<VarType>& ins,
#endif #endif
// 1. get expected kernel key // 1. get expected kernel key
auto expected_kernel_key = auto expected_kernel_key = op.GetExpectedKernelType(
op.GetExpectedKernelType(DygraphExecutionContext<VarType>( DygraphExecutionContext<VarType>(op, framework::Scope(), *dev_ctx, ctx,
op, framework::Scope(), *dev_ctx, ctx, ins, outs, attrs)); ins, outs, attrs, default_attrs));
VLOG(3) << "expected_kernel_key:" << expected_kernel_key; VLOG(3) << "expected_kernel_key:" << expected_kernel_key;
// 2. check if op[type] has kernel registered. // 2. check if op[type] has kernel registered.
...@@ -148,16 +149,19 @@ PreparedOp PreparedOp::Prepare(const NameVarMap<VarBase>& ins, ...@@ -148,16 +149,19 @@ PreparedOp PreparedOp::Prepare(const NameVarMap<VarBase>& ins,
const NameVarMap<VarBase>& outs, const NameVarMap<VarBase>& outs,
const framework::OperatorWithKernel& op, const framework::OperatorWithKernel& op,
const platform::Place& place, const platform::Place& place,
const framework::AttributeMap& attrs) { const framework::AttributeMap& attrs,
return PrepareImpl<VarBase>(ins, outs, op, place, attrs); const framework::AttributeMap& default_attrs) {
return PrepareImpl<VarBase>(ins, outs, op, place, attrs, default_attrs);
} }
PreparedOp PreparedOp::Prepare(const NameVarMap<VariableWrapper>& ins, PreparedOp PreparedOp::Prepare(const NameVarMap<VariableWrapper>& ins,
const NameVarMap<VariableWrapper>& outs, const NameVarMap<VariableWrapper>& outs,
const framework::OperatorWithKernel& op, const framework::OperatorWithKernel& op,
const platform::Place& place, const platform::Place& place,
const framework::AttributeMap& attrs) { const framework::AttributeMap& attrs,
return PrepareImpl<VariableWrapper>(ins, outs, op, place, attrs); const framework::AttributeMap& default_attrs) {
return PrepareImpl<VariableWrapper>(ins, outs, op, place, attrs,
default_attrs);
} }
template <typename VarType> template <typename VarType>
...@@ -166,17 +170,18 @@ static void PreparedOpRunImpl( ...@@ -166,17 +170,18 @@ static void PreparedOpRunImpl(
const framework::OpKernelType& kernel_type, const framework::OpKernelType& kernel_type,
const framework::OperatorWithKernel::OpKernelFunc& func, const framework::OperatorWithKernel::OpKernelFunc& func,
platform::DeviceContext* dev_ctx, const NameVarMap<VarType>& ins, platform::DeviceContext* dev_ctx, const NameVarMap<VarType>& ins,
const NameVarMap<VarType>& outs, const framework::AttributeMap& attrs) { const NameVarMap<VarType>& outs, const framework::AttributeMap& attrs,
const framework::AttributeMap& default_attrs) {
// TODO(zjl): remove scope in dygraph // TODO(zjl): remove scope in dygraph
framework::Scope scope; framework::Scope scope;
DygraphInferShapeContext<VarType> infer_shape_ctx(&ins, &outs, &attrs, DygraphInferShapeContext<VarType> infer_shape_ctx(&ins, &outs, &attrs,
op.Type()); &default_attrs, op.Type());
static_cast<const framework::OperatorWithKernel&>(op).InferShape( static_cast<const framework::OperatorWithKernel&>(op).InferShape(
&infer_shape_ctx); &infer_shape_ctx);
func(DygraphExecutionContext<VarType>(op, scope, *dev_ctx, ctx, ins, outs, func(DygraphExecutionContext<VarType>(op, scope, *dev_ctx, ctx, ins, outs,
attrs)); attrs, default_attrs));
if (FLAGS_check_nan_inf) { if (FLAGS_check_nan_inf) {
framework::details::CheckOpHasNanOrInfInDygraph<VarType>( framework::details::CheckOpHasNanOrInfInDygraph<VarType>(
...@@ -202,16 +207,18 @@ static void PreparedOpRunImpl( ...@@ -202,16 +207,18 @@ static void PreparedOpRunImpl(
void PreparedOp::Run(const NameVarMap<VarBase>& ins, void PreparedOp::Run(const NameVarMap<VarBase>& ins,
const NameVarMap<VarBase>& outs, const NameVarMap<VarBase>& outs,
const framework::AttributeMap& attrs) { const framework::AttributeMap& attrs,
const framework::AttributeMap& default_attrs) {
PreparedOpRunImpl<VarBase>(op_, ctx_, kernel_type_, func_, dev_ctx_, ins, PreparedOpRunImpl<VarBase>(op_, ctx_, kernel_type_, func_, dev_ctx_, ins,
outs, attrs); outs, attrs, default_attrs);
} }
void PreparedOp::Run(const NameVarMap<VariableWrapper>& ins, void PreparedOp::Run(const NameVarMap<VariableWrapper>& ins,
const NameVarMap<VariableWrapper>& outs, const NameVarMap<VariableWrapper>& outs,
const framework::AttributeMap& attrs) { const framework::AttributeMap& attrs,
const framework::AttributeMap& default_attrs) {
PreparedOpRunImpl<VariableWrapper>(op_, ctx_, kernel_type_, func_, dev_ctx_, PreparedOpRunImpl<VariableWrapper>(op_, ctx_, kernel_type_, func_, dev_ctx_,
ins, outs, attrs); ins, outs, attrs, default_attrs);
} }
} // namespace imperative } // namespace imperative
......
...@@ -151,20 +151,24 @@ class PreparedOp { ...@@ -151,20 +151,24 @@ class PreparedOp {
const NameVarMap<VarBase>& outs, const NameVarMap<VarBase>& outs,
const framework::OperatorWithKernel& op, const framework::OperatorWithKernel& op,
const platform::Place& place, const platform::Place& place,
const framework::AttributeMap& attrs); const framework::AttributeMap& attrs,
const framework::AttributeMap& default_attrs);
static PreparedOp Prepare(const NameVarMap<VariableWrapper>& ins, static PreparedOp Prepare(const NameVarMap<VariableWrapper>& ins,
const NameVarMap<VariableWrapper>& outs, const NameVarMap<VariableWrapper>& outs,
const framework::OperatorWithKernel& op, const framework::OperatorWithKernel& op,
const platform::Place& place, const platform::Place& place,
const framework::AttributeMap& attrs); const framework::AttributeMap& attrs,
const framework::AttributeMap& default_attrs);
void Run(const NameVarMap<VarBase>& in, const NameVarMap<VarBase>& out, void Run(const NameVarMap<VarBase>& in, const NameVarMap<VarBase>& out,
const framework::AttributeMap& attrs); const framework::AttributeMap& attrs,
const framework::AttributeMap& default_attrs);
void Run(const NameVarMap<VariableWrapper>& ins, void Run(const NameVarMap<VariableWrapper>& ins,
const NameVarMap<VariableWrapper>& outs, const NameVarMap<VariableWrapper>& outs,
const framework::AttributeMap& attrs); const framework::AttributeMap& attrs,
const framework::AttributeMap& default_attrs);
const framework::OpKernelType& kernel_type() const { return kernel_type_; } const framework::OpKernelType& kernel_type() const { return kernel_type_; }
......
...@@ -43,10 +43,12 @@ template <typename VarType> ...@@ -43,10 +43,12 @@ template <typename VarType>
class TestRuntimeInferVarTypeContext class TestRuntimeInferVarTypeContext
: public RuntimeInferVarTypeContext<VarType> { : public RuntimeInferVarTypeContext<VarType> {
public: public:
TestRuntimeInferVarTypeContext(const NameVarMap<VarType>& inputs, TestRuntimeInferVarTypeContext(
const NameVarMap<VarType>& outputs, const NameVarMap<VarType>& inputs, const NameVarMap<VarType>& outputs,
const framework::AttributeMap& attrs_map) const framework::AttributeMap& attrs_map,
: RuntimeInferVarTypeContext<VarType>(inputs, outputs, attrs_map) {} const framework::AttributeMap& default_attrs_map)
: RuntimeInferVarTypeContext<VarType>(inputs, outputs, attrs_map,
default_attrs_map) {}
bool HasVar(const std::string& name) const { bool HasVar(const std::string& name) const {
return RuntimeInferVarTypeContext<VarType>::HasVar(name); return RuntimeInferVarTypeContext<VarType>::HasVar(name);
...@@ -125,7 +127,7 @@ TEST(test_layer, test_runtime_context) { ...@@ -125,7 +127,7 @@ TEST(test_layer, test_runtime_context) {
auto* ctx = auto* ctx =
new imperative::TestRuntimeInferVarTypeContext<imperative::VarBase>( new imperative::TestRuntimeInferVarTypeContext<imperative::VarBase>(
ins, outs, attrs); ins, outs, attrs, {});
ASSERT_TRUE(ctx->HasInput("X")); ASSERT_TRUE(ctx->HasInput("X"));
ASSERT_TRUE(ctx->HasOutput("Out")); ASSERT_TRUE(ctx->HasOutput("Out"));
...@@ -358,7 +360,7 @@ TEST(test_layer, test_dygraph_execution_context) { ...@@ -358,7 +360,7 @@ TEST(test_layer, test_dygraph_execution_context) {
framework::Scope scope; framework::Scope scope;
DygraphExecutionContext<imperative::VarBase> dy_exe_context( DygraphExecutionContext<imperative::VarBase> dy_exe_context(
*(op.get()), scope, *dev_ctx, ctx, ins, outs, concat_att_map); *(op.get()), scope, *dev_ctx, ctx, ins, outs, concat_att_map, {});
ASSERT_EQ(dy_exe_context.InputSize("X"), 1u); ASSERT_EQ(dy_exe_context.InputSize("X"), 1u);
ASSERT_EQ(dy_exe_context.InputName("X"), "vin"); ASSERT_EQ(dy_exe_context.InputName("X"), "vin");
...@@ -386,7 +388,7 @@ TEST(test_layer, test_dygraph_infershape_context) { ...@@ -386,7 +388,7 @@ TEST(test_layer, test_dygraph_infershape_context) {
concat_att_map["axis"] = 1; concat_att_map["axis"] = 1;
DygraphInferShapeContext<imperative::VarBase> infer_shape_ctx( DygraphInferShapeContext<imperative::VarBase> infer_shape_ctx(
&ins, &outs, &concat_att_map, "dummy"); &ins, &outs, &concat_att_map, {}, "dummy");
bool have_x = infer_shape_ctx.HasOutputs("Out"); bool have_x = infer_shape_ctx.HasOutputs("Out");
ASSERT_EQ(have_x, true); ASSERT_EQ(have_x, true);
......
...@@ -93,7 +93,7 @@ TEST(test_prepare_op, test_prepare_op) { ...@@ -93,7 +93,7 @@ TEST(test_prepare_op, test_prepare_op) {
ASSERT_NO_FATAL_FAILURE(PreparedOp preparedOp = PreparedOp::Prepare( ASSERT_NO_FATAL_FAILURE(PreparedOp preparedOp = PreparedOp::Prepare(
ins, outs, ins, outs,
dynamic_cast<framework::OperatorWithKernel&>(*op), dynamic_cast<framework::OperatorWithKernel&>(*op),
place, split_attr_map)); place, split_attr_map, {}));
} }
const framework::Tensor* GetTensorFromVar(const framework::Variable& var); const framework::Tensor* GetTensorFromVar(const framework::Variable& var);
...@@ -144,7 +144,7 @@ TEST(test_prepare_op, test_prepare_data) { ...@@ -144,7 +144,7 @@ TEST(test_prepare_op, test_prepare_data) {
// test if it can be transformed to GPU place // test if it can be transformed to GPU place
auto prepared_op = PreparedOp::Prepare( auto prepared_op = PreparedOp::Prepare(
ins, outs, dynamic_cast<framework::OperatorWithKernel&>(*op), gpu_place, ins, outs, dynamic_cast<framework::OperatorWithKernel&>(*op), gpu_place,
attr_map); attr_map, {});
PrepareData<imperative::VarBase>( PrepareData<imperative::VarBase>(
dynamic_cast<framework::OperatorWithKernel&>(*op), ins, dynamic_cast<framework::OperatorWithKernel&>(*op), ins,
prepared_op.kernel_type()); prepared_op.kernel_type());
...@@ -193,7 +193,7 @@ void TestPrepareDataSamePlace(framework::AttributeMap attr_map) { ...@@ -193,7 +193,7 @@ void TestPrepareDataSamePlace(framework::AttributeMap attr_map) {
// test if it never transferred on GPU place // test if it never transferred on GPU place
auto prepared_op = PreparedOp::Prepare( auto prepared_op = PreparedOp::Prepare(
ins, outs, dynamic_cast<framework::OperatorWithKernel&>(*op), cpu_place, ins, outs, dynamic_cast<framework::OperatorWithKernel&>(*op), cpu_place,
attr_map); attr_map, {});
PrepareData<imperative::VarBase>( PrepareData<imperative::VarBase>(
dynamic_cast<framework::OperatorWithKernel&>(*op), ins, dynamic_cast<framework::OperatorWithKernel&>(*op), ins,
prepared_op.kernel_type()); prepared_op.kernel_type());
......
...@@ -154,9 +154,14 @@ void Tracer::TraceOp(const std::string& type, const NameVarBaseMap& ins, ...@@ -154,9 +154,14 @@ void Tracer::TraceOp(const std::string& type, const NameVarBaseMap& ins,
const auto& op_info = op->Info(); const auto& op_info = op->Info();
auto* attr_checker = op_info.Checker(); auto* attr_checker = op_info.Checker();
if (attr_checker) { if (attr_checker) {
attr_checker->Check(&attrs, true); attr_checker->Check(&attrs, true, /*only_check_exist_value=*/true);
} }
static paddle::framework::AttributeMap empty_attrs_map = {};
const paddle::framework::AttributeMap& default_attrs =
attr_checker == nullptr ? empty_attrs_map
: attr_checker->GetDefaultAttrMap();
NameVarBaseMap new_ins = ins; NameVarBaseMap new_ins = ins;
if (enable_autocast_) { if (enable_autocast_) {
VLOG(5) << "Auto mixed precision run operator: " << type; VLOG(5) << "Auto mixed precision run operator: " << type;
...@@ -181,7 +186,7 @@ void Tracer::TraceOp(const std::string& type, const NameVarBaseMap& ins, ...@@ -181,7 +186,7 @@ void Tracer::TraceOp(const std::string& type, const NameVarBaseMap& ins,
#endif #endif
} }
OpBase::Run(*op, new_ins, outs, attrs, place); OpBase::Run(*op, new_ins, outs, attrs, default_attrs, place);
} catch (platform::EnforceNotMet& exception) { } catch (platform::EnforceNotMet& exception) {
framework::AppendErrorOpHint(type, &exception); framework::AppendErrorOpHint(type, &exception);
throw std::move(exception); throw std::move(exception);
...@@ -204,7 +209,8 @@ void Tracer::TraceOp(const std::string& type, const NameVarBaseMap& ins, ...@@ -204,7 +209,8 @@ void Tracer::TraceOp(const std::string& type, const NameVarBaseMap& ins,
} }
if (ComputeRequiredGrad(new_ins, outs, trace_backward)) { if (ComputeRequiredGrad(new_ins, outs, trace_backward)) {
CreateGradOpNode(*op, new_ins, outs, attrs, place, inplace_map); CreateGradOpNode(*op, new_ins, outs, attrs, default_attrs, place,
inplace_map);
} else { } else {
VLOG(3) << "No Grad to track for Op: " << type; VLOG(3) << "No Grad to track for Op: " << type;
} }
......
...@@ -48,7 +48,7 @@ class DygraphInferShapeTest { ...@@ -48,7 +48,7 @@ class DygraphInferShapeTest {
void SetOpType(const std::string& op_type) { op_type_ = op_type; } void SetOpType(const std::string& op_type) { op_type_ = op_type; }
void Run(std::function<void(framework::InferShapeContext* ctx)> infer_shape) { void Run(std::function<void(framework::InferShapeContext* ctx)> infer_shape) {
imperative::DygraphInferShapeContext<imperative::VarBase> ctx( imperative::DygraphInferShapeContext<imperative::VarBase> ctx(
&ins_, &outs_, &attrs_, op_type_); &ins_, &outs_, &attrs_, {}, op_type_);
infer_shape(&ctx); infer_shape(&ctx);
for (const auto& pair : expected_dims_) { for (const auto& pair : expected_dims_) {
auto out = outs_[pair.first][0]; auto out = outs_[pair.first][0];
......
...@@ -1308,7 +1308,7 @@ All parameter, weight, gradient are variables in Paddle. ...@@ -1308,7 +1308,7 @@ All parameter, weight, gradient are variables in Paddle.
if (info != nullptr) { if (info != nullptr) {
if (info->HasOpProtoAndChecker()) { if (info->HasOpProtoAndChecker()) {
auto op_checker = info->Checker(); auto op_checker = info->Checker();
res = op_checker->GetAttrsDefaultValuesMap(); res = op_checker->GetDefaultAttrsMap();
} }
} }
return res; return res;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册