未验证 提交 5d2eb678 编写于 作者: W wanghuancoder 提交者: GitHub

optimize attr default value (#33357)

* optimize attr default value, test=develop

* refine, test=develop

* refine, test=develop

* refine, test=develop

* fix bug in AttrReader, test=develop

* fix bug, test=develop

* fix double_grad, test=develop

* refine, test=develop

* refine, test=develop

* fix checker null, test=develop

* for test, test=develop

* refine, test=develop

* refine, test=develop

* refine, test=develop

* refine, test=develop

* refine, test=develop

* refine, test=develop
上级 2133b45a
......@@ -208,15 +208,27 @@ Attribute GetAttrValue(const proto::OpDesc::Attr& attr_desc);
class AttrReader {
public:
explicit AttrReader(const AttributeMap& attrs) : attrs_(attrs) {}
explicit AttrReader(const AttributeMap& attrs)
: attrs_(attrs), default_attrs_(nullptr) {}
AttrReader(const AttributeMap& attrs, const AttributeMap& default_attrs)
: attrs_(attrs), default_attrs_(&default_attrs) {}
template <typename T>
inline const T& Get(const std::string& name) const {
PADDLE_ENFORCE_NE(attrs_.count(name), 0,
auto it = attrs_.find(name);
bool found = it != attrs_.end();
if (!found) {
if (default_attrs_ != nullptr) {
it = default_attrs_->find(name);
found = it != default_attrs_->end();
}
}
PADDLE_ENFORCE_EQ(found, true,
platform::errors::NotFound(
"Attribute (%s) should be in AttributeMap.", name));
Attribute& attr = const_cast<Attribute&>(attrs_.at(name));
Attribute& attr = const_cast<Attribute&>(it->second);
ExtractAttribute<T> extract_attr(name);
T* attr_value = extract_attr(attr);
return *attr_value;
......@@ -224,6 +236,7 @@ class AttrReader {
private:
const AttributeMap& attrs_;
const AttributeMap* default_attrs_;
};
// check whether a value(attribute) fit a certain limit
......@@ -234,8 +247,8 @@ class GreaterThanChecker {
void operator()(const T& value) const {
PADDLE_ENFORCE_GT(
value, lower_bound_,
platform::errors::OutOfRange(
"Check for attribute value greater than a certain value failed."));
platform::errors::OutOfRange("Check for attribute value greater than "
"a certain value failed."));
}
private:
......@@ -332,8 +345,8 @@ class TypedAttrChecker {
TypedAttrChecker& SetDefault(const T& default_value) {
PADDLE_ENFORCE_EQ(
default_value_setter_.empty(), true,
platform::errors::AlreadyExists(
"Attribute (%s) has a default value and cannot be set repeatedly.",
platform::errors::AlreadyExists("Attribute (%s) has a default value "
"and cannot be set repeatedly.",
attr_name_));
default_value_setter_.push_back(DefaultValueSetter<T>(default_value));
return *this;
......@@ -345,8 +358,8 @@ class TypedAttrChecker {
return *this;
}
void operator()(AttributeMap* attr_map,
bool get_default_value_only = false) const {
void operator()(AttributeMap* attr_map, bool get_default_value_only = false,
bool only_check_exist_value = false) const {
if (get_default_value_only) {
if (!default_value_setter_.empty()) {
attr_map->emplace(attr_name_, default_value_setter_[0]());
......@@ -354,6 +367,16 @@ class TypedAttrChecker {
return;
}
if (only_check_exist_value) {
auto it = attr_map->find(attr_name_);
if (it != attr_map->end()) {
ExtractAttribute<T> extract_attr(attr_name_);
T* attr_value = extract_attr(it->second);
for (const auto& checker : value_checkers_) {
checker(*attr_value);
}
}
} else {
auto it = attr_map->find(attr_name_);
if (it == attr_map->end()) {
// user do not set this attr
......@@ -362,15 +385,16 @@ class TypedAttrChecker {
platform::errors::InvalidArgument(
"Attribute (%s) is not set correctly.", attr_name_));
// default_value_setter_ has no more than one element
attr_map->emplace(attr_name_, default_value_setter_[0]());
auto tmp = attr_map->emplace(attr_name_, default_value_setter_[0]());
it = tmp.first;
}
it = attr_map->find(attr_name_);
ExtractAttribute<T> extract_attr(attr_name_);
T* attr_value = extract_attr(it->second);
for (const auto& checker : value_checkers_) {
checker(*attr_value);
}
}
}
private:
std::string attr_name_;
......@@ -380,7 +404,7 @@ class TypedAttrChecker {
// check whether op's all attributes fit their own limits
class OpAttrChecker {
typedef std::function<void(AttributeMap*, bool)> AttrChecker;
typedef std::function<void(AttributeMap*, bool, bool)> AttrChecker;
public:
template <typename T>
......@@ -390,18 +414,19 @@ class OpAttrChecker {
return *(checker.target<TypedAttrChecker<T>>());
}
void Check(AttributeMap* attr_map, bool explicit_only = false) const {
void Check(AttributeMap* attr_map, bool explicit_only = false,
bool only_check_exist_value = false) const {
auto checker_num = attr_checkers_.size();
if (explicit_only) checker_num = explicit_checker_num_;
for (size_t i = 0; i < checker_num; ++i) {
attr_checkers_[i](attr_map, false);
attr_checkers_[i](attr_map, false, only_check_exist_value);
}
}
AttributeMap GetAttrsDefaultValuesMap() const {
AttributeMap GetDefaultAttrsMap() const {
AttributeMap default_values_map;
for (const auto& checker : attr_checkers_) {
checker(&default_values_map, true);
checker(&default_values_map, true, false);
}
return default_values_map;
}
......@@ -410,15 +435,26 @@ class OpAttrChecker {
explicit_checker_num_ = attr_checkers_.size();
}
void InitDefaultAttributeMap() {
for (const auto& checker : attr_checkers_) {
checker(&default_attrs_, true, false);
}
}
const AttributeMap& GetDefaultAttrMap() const { return default_attrs_; }
private:
std::vector<AttrChecker> attr_checkers_;
AttributeMap default_attrs_;
// in order to improve the efficiency of dynamic graph mode,
// we divede the attribute into explicit type and implicit type.
// for explicit attribute, we mean the attribute added in the customized
// op makers, usually it's defined in the overloaded Make method.
// for implicit attribute, we mean the attribute added outside of the Make
// method like "op_role", "op_role_var", and they are useless in dynamic graph
// method like "op_role", "op_role_var", and they are useless in dynamic
// graph
// mode
size_t explicit_checker_num_;
};
......
......@@ -781,10 +781,12 @@ void RegisterOperatorWithMetaInfo(
const imperative::NameVarBaseMap& var_base_map_in,
const imperative::NameVarBaseMap& var_base_map_out,
const framework::AttributeMap& attrs,
const framework::AttributeMap& default_attrs,
const std::map<std::string, std::string>& inplace_map) {
CustomGradOpMaker<paddle::imperative::OpBase> maker(
type, var_base_map_in, var_base_map_out, attrs, inplace_map,
grad_op_name, grad_op_inputs, grad_op_outputs);
maker.SetDygraphDefaultAttrsMap(default_attrs);
return maker();
};
......
......@@ -249,8 +249,10 @@ struct OpInfoFiller<T, kGradOpBaseMaker> {
const imperative::NameVarBaseMap& var_base_map_in,
const imperative::NameVarBaseMap& var_base_map_out,
const framework::AttributeMap& attrs,
const framework::AttributeMap& default_attrs,
const std::map<std::string, std::string>& inplace_map) {
T maker(type, var_base_map_in, var_base_map_out, attrs, inplace_map);
maker.SetDygraphDefaultAttrsMap(default_attrs);
return maker();
};
}
......
......@@ -219,6 +219,19 @@ class SingleGradOpMaker<imperative::OpBase>
public:
using GradOpBaseMakerBase::GradOpBaseMakerBase;
virtual const framework::Attribute& GetAttr(const std::string& name) const {
auto it = Attrs().find(name);
if (it == Attrs().end()) {
it = this->DefaultAttrsMap().find(name);
PADDLE_ENFORCE_EQ(it != this->DefaultAttrsMap().end(), true,
platform::errors::NotFound(
"Cannot find attribute [%s] in operator [%s]", name,
this->ForwardOpType()));
}
return it->second;
}
std::shared_ptr<imperative::GradOpNode> operator()() const final {
auto node = this->NewGradNode();
auto& inplace_map = this->GetInplaceMap();
......@@ -228,6 +241,7 @@ class SingleGradOpMaker<imperative::OpBase>
{
imperative::TracedGradOp traced_grad_op(node);
try {
traced_grad_op.SetDefaultAttrsMap(this->DefaultAttrsMap());
this->Apply(&traced_grad_op);
} catch (platform::EnforceNotMet& exception) {
framework::AppendErrorOpHint(traced_grad_op.Type(), &exception);
......
......@@ -61,7 +61,7 @@ AttrCompat& AttrCompat::IsLeftDefault() {
return *this;
}
const OpInfo& op_info = OpInfoMap::Instance().Get(op_name);
const AttributeMap attrs = op_info.Checker()->GetAttrsDefaultValuesMap();
const AttributeMap attrs = op_info.Checker()->GetDefaultAttrsMap();
if (attrs.find(attr_name_) == attrs.end()) {
LOG(WARNING) << "Op (" << op_name << ") has no default attr:" << attr_name_;
conditions_.emplace_back([](const Attribute& attr) { return false; });
......
......@@ -66,6 +66,7 @@ void OpProtoAndCheckerMaker::operator()(proto::OpProto* proto,
op_checker_ = attr_checker;
Make();
op_checker_->RecordExplicitCheckerNum();
op_checker_->InitDefaultAttributeMap();
AddAttr<int>(OpRoleAttrName(), "The role of this operator")
.InEnum(
......
......@@ -71,6 +71,7 @@ using DygraphGradOpMakerFN =
const imperative::NameVarBaseMap& /*var_base_map_in*/,
const imperative::NameVarBaseMap& /*var_base_map_out*/,
const framework::AttributeMap& /*attributes*/,
const framework::AttributeMap& /*default attributes*/,
const std::map<std::string, std::string>& /*inplace_map*/)>;
using InferVarTypeFN =
......
......@@ -474,10 +474,11 @@ void BasicEngine::Execute() {
try {
if (tmp_ins_ptr == nullptr) {
OpBase::Run(cur_op.InnerOp(), bwd_ins, tmp_outs, cur_op.Attrs(),
cur_op.place());
cur_op.DefaultAttrsMap(), cur_op.place());
} else {
OpBase::Run(cur_op.InnerOp(), *tmp_ins_ptr, tmp_outs,
cur_op.Attrs(), cur_op.place());
cur_op.Attrs(), cur_op.DefaultAttrsMap(),
cur_op.place());
}
} catch (platform::EnforceNotMet& exception) {
Clear();
......
......@@ -113,9 +113,18 @@ class GradOpBaseMakerBase {
return vec_temp;
}
// Only for dygraph
void SetDygraphDefaultAttrsMap(const framework::AttributeMap& default_attrs) {
default_attrs_ = &default_attrs;
}
const framework::AttributeMap& DefaultAttrsMap() const {
return *default_attrs_;
}
const framework::AttributeMap& Attrs() const { return attrs_; }
const framework::Attribute& GetAttr(const std::string& name) const {
virtual const framework::Attribute& GetAttr(const std::string& name) const {
auto it = attrs_.find(name);
PADDLE_ENFORCE_EQ(
it != attrs_.end(), true,
......@@ -199,6 +208,7 @@ class GradOpBaseMakerBase {
const NameVarBaseMap& var_base_map_in_;
const NameVarBaseMap& var_base_map_out_;
const framework::AttributeMap& attrs_;
const framework::AttributeMap* default_attrs_;
const std::map<std::string, std::string>& inplace_map_;
};
......@@ -285,6 +295,10 @@ class TracedGradOp {
return op_->SetAttrMap(attrs);
}
void SetDefaultAttrsMap(const framework::AttributeMap& attrs) {
return op_->SetDefaultAttrsMap(attrs);
}
void SetAttr(const std::string& name, const framework::Attribute& v) {
op_->SetAttr(name, v);
}
......
......@@ -35,11 +35,13 @@ class DygraphExecutionContext : public framework::ExecutionContext {
const framework::RuntimeContext& ctx,
const NameVarMap<VarType>& var_base_map_in,
const NameVarMap<VarType>& var_base_map_out,
const framework::AttributeMap& attrs)
const framework::AttributeMap& attrs,
const framework::AttributeMap& default_attrs)
: ExecutionContext(op, scope, device_context, ctx),
var_base_map_in_(var_base_map_in),
var_base_map_out_(var_base_map_out),
attrs_(attrs) {}
attrs_(attrs),
default_attrs_(default_attrs) {}
std::string InputName(const std::string& name) const override {
auto it = var_base_map_in_.find(name);
......@@ -92,7 +94,7 @@ class DygraphExecutionContext : public framework::ExecutionContext {
}
bool HasAttr(const std::string& name) const override {
return attrs_.count(name) != 0;
return attrs_.count(name) != 0 || default_attrs_.count(name) != 0;
}
const framework::AttributeMap& Attrs() const override { return attrs_; }
......@@ -100,9 +102,14 @@ class DygraphExecutionContext : public framework::ExecutionContext {
const framework::Attribute& GetAttr(const std::string& name) const override {
auto it = attrs_.find(name);
PADDLE_ENFORCE_NE(
it, attrs_.end(),
platform::errors::NotFound("can not find [%s] in attrs", name));
if (it == attrs_.end()) {
it = default_attrs_.find(name);
if (it == default_attrs_.end()) {
PADDLE_THROW(platform::errors::NotFound(
"Can not find [%s] in attributes of op %s.", name,
this->GetOp().Type()));
}
}
return it->second;
}
......@@ -192,6 +199,7 @@ class DygraphExecutionContext : public framework::ExecutionContext {
const NameVarMap<VarType>& var_base_map_in_;
const NameVarMap<VarType>& var_base_map_out_;
const framework::AttributeMap& attrs_;
const framework::AttributeMap& default_attrs_;
};
} // namespace imperative
......
......@@ -35,10 +35,12 @@ class DygraphInferShapeContext : public framework::InferShapeContext {
DygraphInferShapeContext(const NameVarMap<VarType>* in,
const NameVarMap<VarType>* out,
const framework::AttributeMap* attr,
const framework::AttributeMap* default_attr,
const std::string op_type)
: var_base_map_in_(in),
var_base_map_out_(out),
attrs_(attr),
default_attrs_(default_attr),
op_type_(op_type) {}
bool HasInput(const std::string& name) const override {
......@@ -101,7 +103,7 @@ class DygraphInferShapeContext : public framework::InferShapeContext {
}
framework::AttrReader Attrs() const override {
return framework::AttrReader(*attrs_);
return framework::AttrReader(*attrs_, *default_attrs_);
}
std::vector<std::string> Inputs(const std::string& name) const override {
......@@ -395,6 +397,7 @@ class DygraphInferShapeContext : public framework::InferShapeContext {
const NameVarMap<VarType>* var_base_map_in_;
const NameVarMap<VarType>* var_base_map_out_;
const framework::AttributeMap* attrs_;
const framework::AttributeMap* default_attrs_;
const std::string op_type_;
};
......
......@@ -32,20 +32,28 @@ class RuntimeInferVarTypeContext : public framework::InferVarTypeContext {
public:
RuntimeInferVarTypeContext(const NameVarMap<VarType>& inputs,
const NameVarMap<VarType>& outputs,
const framework::AttributeMap& attrs_map)
const framework::AttributeMap& attrs_map,
const framework::AttributeMap& default_attrs_map)
: InferVarTypeContext(nullptr, nullptr),
inputs_(inputs),
outputs_(outputs),
attrs_(attrs_map) {}
attrs_(attrs_map),
default_attrs_(default_attrs_map) {}
virtual ~RuntimeInferVarTypeContext() {}
framework::Attribute GetAttr(const std::string& name) const override {
auto iter = attrs_.find(name);
PADDLE_ENFORCE_EQ(
iter != attrs_.end(), true,
platform::errors::NotFound("Cannot find attribute %s", name));
return iter->second;
auto it = attrs_.find(name);
if (it == attrs_.end()) {
it = default_attrs_.find(name);
if (it == default_attrs_.end()) {
PADDLE_THROW(platform::errors::NotFound(
"Can not find [%s] in attributes.", name));
}
}
return it->second;
}
bool HasInput(const std::string& name) const override {
......@@ -233,6 +241,7 @@ class RuntimeInferVarTypeContext : public framework::InferVarTypeContext {
const NameVarMap<VarType>& inputs_;
const NameVarMap<VarType>& outputs_;
const framework::AttributeMap& attrs_;
const framework::AttributeMap& default_attrs_;
};
} // namespace imperative
......
......@@ -329,6 +329,7 @@ static void OpBaseRunImpl(const framework::OperatorBase& op,
const NameVarMap<VarType>& ins,
const NameVarMap<VarType>& outs,
const framework::AttributeMap& attrs,
const framework::AttributeMap& default_attrs,
const platform::Place& place) {
auto* op_kernel = dynamic_cast<const framework::OperatorWithKernel*>(&op);
PADDLE_ENFORCE_NOT_NULL(
......@@ -336,7 +337,8 @@ static void OpBaseRunImpl(const framework::OperatorBase& op,
"Only support operator with kernel in Dygraph mode."));
auto& info = op.Info();
if (info.infer_var_type_) {
RuntimeInferVarTypeContext<VarType> infer_var_type_ctx(ins, outs, attrs);
RuntimeInferVarTypeContext<VarType> infer_var_type_ctx(ins, outs, attrs,
default_attrs);
info.infer_var_type_(&infer_var_type_ctx);
}
......@@ -369,13 +371,14 @@ static void OpBaseRunImpl(const framework::OperatorBase& op,
* after the execution of op, but the original input is directly
* overwritten in the previous dynamic graph implemention.
*/
auto prepared_op = PreparedOp::Prepare(ins, outs, *op_kernel, place, attrs);
auto prepared_op =
PreparedOp::Prepare(ins, outs, *op_kernel, place, attrs, default_attrs);
auto tmp_ins_ptr =
PrepareData<VarType>(*op_kernel, ins, prepared_op.kernel_type());
if (tmp_ins_ptr == nullptr) {
prepared_op.Run(ins, outs, attrs);
prepared_op.Run(ins, outs, attrs, default_attrs);
} else {
prepared_op.Run(*tmp_ins_ptr, outs, attrs);
prepared_op.Run(*tmp_ins_ptr, outs, attrs, default_attrs);
}
VLOG(4) << LayerDebugString(op.Type(), ins, outs);
......@@ -395,16 +398,18 @@ void OpBase::Run(const framework::OperatorBase& op,
const NameVarMap<VarBase>& ins,
const NameVarMap<VarBase>& outs,
const framework::AttributeMap& attrs,
const framework::AttributeMap& default_attrs,
const platform::Place& place) {
OpBaseRunImpl<VarBase>(op, ins, outs, attrs, place);
OpBaseRunImpl<VarBase>(op, ins, outs, attrs, default_attrs, place);
}
void OpBase::Run(const framework::OperatorBase& op,
const NameVarMap<VariableWrapper>& ins,
const NameVarMap<VariableWrapper>& outs,
const framework::AttributeMap& attrs,
const framework::AttributeMap& default_attrs,
const platform::Place& place) {
OpBaseRunImpl<VariableWrapper>(op, ins, outs, attrs, place);
OpBaseRunImpl<VariableWrapper>(op, ins, outs, attrs, default_attrs, place);
}
void ClearNoNeedBufferInputs(OpBase* op) {
......@@ -446,15 +451,15 @@ void ClearNoNeedBufferInputs(OpBase* op) {
std::shared_ptr<GradOpNode> CreateGradOpNode(
const framework::OperatorBase& op, const NameVarBaseMap& ins,
const NameVarBaseMap& outs, const framework::AttributeMap& attrs,
const platform::Place& place,
const framework::AttributeMap& default_attrs, const platform::Place& place,
const std::map<std::string, std::string>& inplace_map) {
const auto& info = op.Info();
if (!info.dygraph_grad_op_maker_) {
return nullptr;
}
auto grad_node =
info.dygraph_grad_op_maker_(op.Type(), ins, outs, attrs, inplace_map);
auto grad_node = info.dygraph_grad_op_maker_(op.Type(), ins, outs, attrs,
default_attrs, inplace_map);
if (grad_node && !grad_node->empty()) {
for (auto& grad_op : *grad_node) {
grad_op.SetId(OpBase::GenerateUniqueId());
......
......@@ -108,7 +108,7 @@ class VarBase {
void ClearGradVarBase() { grad_var_ = nullptr; }
void SetGradVarBase(VarBase& grad_var) {
void SetGradVarBase(const VarBase& grad_var) {
MutableGradVarBase()->CopyFrom(grad_var, true);
}
......@@ -283,7 +283,7 @@ class Layer {
std::shared_ptr<GradOpNode> CreateGradOpNode(
const framework::OperatorBase& op, const NameVarBaseMap& ins,
const NameVarBaseMap& outs, const framework::AttributeMap& attrs,
const platform::Place& place,
const framework::AttributeMap& default_attrs, const platform::Place& place,
const std::map<std::string, std::string>& inplace_map);
void ClearNoNeedBufferInputs(OpBase* op);
......
......@@ -50,6 +50,10 @@ class OpBase {
const framework::AttributeMap& Attrs() const { return attrs_; }
const framework::AttributeMap& DefaultAttrsMap() const {
return *default_attrs_;
}
const framework::OpInfo& Info() const {
PADDLE_ENFORCE_NOT_NULL(op_, platform::errors::PreconditionNotMet(
"OpBase::Info() should be called after "
......@@ -99,6 +103,10 @@ class OpBase {
void SetAttrMap(const framework::AttributeMap& attrs) { attrs_ = attrs; }
void SetDefaultAttrsMap(const framework::AttributeMap& default_attrs) {
default_attrs_ = &default_attrs;
}
void SetAttr(const std::string& name, const framework::Attribute& v) {
attrs_[name] = v;
}
......@@ -110,14 +118,23 @@ class OpBase {
const framework::AttributeMap& Attrs() { return attrs_; }
bool HasAttr(const std::string& name) const { return attrs_.count(name) > 0; }
const framework::AttributeMap& DefaultAttrsMap() { return *default_attrs_; }
bool HasAttr(const std::string& name) const {
return attrs_.count(name) > 0 || default_attrs_->count(name) > 0;
}
const framework::Attribute& GetAttr(const std::string& name) const {
auto it = attrs_.find(name);
if (it != attrs_.end()) {
return it->second;
} else {
auto it_default = default_attrs_->find(name);
PADDLE_ENFORCE_NE(
it, attrs_.end(),
it_default, default_attrs_->end(),
platform::errors::NotFound("can not find attribute [%s]", name));
return it->second;
return it_default->second;
}
}
template <typename T>
......@@ -156,12 +173,14 @@ class OpBase {
const NameVarMap<VarBase>& ins,
const NameVarMap<VarBase>& outs,
const framework::AttributeMap& attrs,
const framework::AttributeMap& default_attrs,
const platform::Place& place);
static void Run(const framework::OperatorBase& op,
const NameVarMap<VariableWrapper>& ins,
const NameVarMap<VariableWrapper>& outs,
const framework::AttributeMap& attrs,
const framework::AttributeMap& default_attrs,
const platform::Place& place);
private:
......@@ -174,6 +193,7 @@ class OpBase {
NameVarMap<VariableWrapper> ins_;
NameVarMap<VariableWrapper> outs_;
framework::AttributeMap attrs_;
const framework::AttributeMap* default_attrs_;
std::unique_ptr<framework::OperatorBase> op_;
platform::Place place_;
size_t id_{-1UL};
......
......@@ -884,11 +884,13 @@ void PartialGradTask::RunEachOp(OpBase *op) {
}
// Run op
OpBase::Run(op->InnerOp(), tmp_ins, tmp_outs, op->Attrs(), op->place());
OpBase::Run(op->InnerOp(), tmp_ins, tmp_outs, op->Attrs(),
op->DefaultAttrsMap(), op->place());
if (create_graph_) {
auto double_grad_node = CreateGradOpNode(op->InnerOp(), tmp_ins, tmp_outs,
op->Attrs(), op->place(), {});
auto double_grad_node =
CreateGradOpNode(op->InnerOp(), tmp_ins, tmp_outs, op->Attrs(),
op->DefaultAttrsMap(), op->place(), {});
PADDLE_ENFORCE_NOT_NULL(
double_grad_node,
platform::errors::NotFound("The Op %s doesn't have any grad op. If you "
......
......@@ -91,7 +91,8 @@ PreparedOp PrepareImpl(const NameVarMap<VarType>& ins,
const NameVarMap<VarType>& outs,
const framework::OperatorWithKernel& op,
const platform::Place& place,
const framework::AttributeMap& attrs) {
const framework::AttributeMap& attrs,
const framework::AttributeMap& default_attrs) {
platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
auto* dev_ctx = pool.Get(place);
......@@ -108,9 +109,9 @@ PreparedOp PrepareImpl(const NameVarMap<VarType>& ins,
#endif
// 1. get expected kernel key
auto expected_kernel_key =
op.GetExpectedKernelType(DygraphExecutionContext<VarType>(
op, framework::Scope(), *dev_ctx, ctx, ins, outs, attrs));
auto expected_kernel_key = op.GetExpectedKernelType(
DygraphExecutionContext<VarType>(op, framework::Scope(), *dev_ctx, ctx,
ins, outs, attrs, default_attrs));
VLOG(3) << "expected_kernel_key:" << expected_kernel_key;
// 2. check if op[type] has kernel registered.
......@@ -148,16 +149,19 @@ PreparedOp PreparedOp::Prepare(const NameVarMap<VarBase>& ins,
const NameVarMap<VarBase>& outs,
const framework::OperatorWithKernel& op,
const platform::Place& place,
const framework::AttributeMap& attrs) {
return PrepareImpl<VarBase>(ins, outs, op, place, attrs);
const framework::AttributeMap& attrs,
const framework::AttributeMap& default_attrs) {
return PrepareImpl<VarBase>(ins, outs, op, place, attrs, default_attrs);
}
PreparedOp PreparedOp::Prepare(const NameVarMap<VariableWrapper>& ins,
const NameVarMap<VariableWrapper>& outs,
const framework::OperatorWithKernel& op,
const platform::Place& place,
const framework::AttributeMap& attrs) {
return PrepareImpl<VariableWrapper>(ins, outs, op, place, attrs);
const framework::AttributeMap& attrs,
const framework::AttributeMap& default_attrs) {
return PrepareImpl<VariableWrapper>(ins, outs, op, place, attrs,
default_attrs);
}
template <typename VarType>
......@@ -166,17 +170,18 @@ static void PreparedOpRunImpl(
const framework::OpKernelType& kernel_type,
const framework::OperatorWithKernel::OpKernelFunc& func,
platform::DeviceContext* dev_ctx, const NameVarMap<VarType>& ins,
const NameVarMap<VarType>& outs, const framework::AttributeMap& attrs) {
const NameVarMap<VarType>& outs, const framework::AttributeMap& attrs,
const framework::AttributeMap& default_attrs) {
// TODO(zjl): remove scope in dygraph
framework::Scope scope;
DygraphInferShapeContext<VarType> infer_shape_ctx(&ins, &outs, &attrs,
op.Type());
&default_attrs, op.Type());
static_cast<const framework::OperatorWithKernel&>(op).InferShape(
&infer_shape_ctx);
func(DygraphExecutionContext<VarType>(op, scope, *dev_ctx, ctx, ins, outs,
attrs));
attrs, default_attrs));
if (FLAGS_check_nan_inf) {
framework::details::CheckOpHasNanOrInfInDygraph<VarType>(
......@@ -202,16 +207,18 @@ static void PreparedOpRunImpl(
void PreparedOp::Run(const NameVarMap<VarBase>& ins,
const NameVarMap<VarBase>& outs,
const framework::AttributeMap& attrs) {
const framework::AttributeMap& attrs,
const framework::AttributeMap& default_attrs) {
PreparedOpRunImpl<VarBase>(op_, ctx_, kernel_type_, func_, dev_ctx_, ins,
outs, attrs);
outs, attrs, default_attrs);
}
void PreparedOp::Run(const NameVarMap<VariableWrapper>& ins,
const NameVarMap<VariableWrapper>& outs,
const framework::AttributeMap& attrs) {
const framework::AttributeMap& attrs,
const framework::AttributeMap& default_attrs) {
PreparedOpRunImpl<VariableWrapper>(op_, ctx_, kernel_type_, func_, dev_ctx_,
ins, outs, attrs);
ins, outs, attrs, default_attrs);
}
} // namespace imperative
......
......@@ -151,20 +151,24 @@ class PreparedOp {
const NameVarMap<VarBase>& outs,
const framework::OperatorWithKernel& op,
const platform::Place& place,
const framework::AttributeMap& attrs);
const framework::AttributeMap& attrs,
const framework::AttributeMap& default_attrs);
static PreparedOp Prepare(const NameVarMap<VariableWrapper>& ins,
const NameVarMap<VariableWrapper>& outs,
const framework::OperatorWithKernel& op,
const platform::Place& place,
const framework::AttributeMap& attrs);
const framework::AttributeMap& attrs,
const framework::AttributeMap& default_attrs);
void Run(const NameVarMap<VarBase>& in, const NameVarMap<VarBase>& out,
const framework::AttributeMap& attrs);
const framework::AttributeMap& attrs,
const framework::AttributeMap& default_attrs);
void Run(const NameVarMap<VariableWrapper>& ins,
const NameVarMap<VariableWrapper>& outs,
const framework::AttributeMap& attrs);
const framework::AttributeMap& attrs,
const framework::AttributeMap& default_attrs);
const framework::OpKernelType& kernel_type() const { return kernel_type_; }
......
......@@ -43,10 +43,12 @@ template <typename VarType>
class TestRuntimeInferVarTypeContext
: public RuntimeInferVarTypeContext<VarType> {
public:
TestRuntimeInferVarTypeContext(const NameVarMap<VarType>& inputs,
const NameVarMap<VarType>& outputs,
const framework::AttributeMap& attrs_map)
: RuntimeInferVarTypeContext<VarType>(inputs, outputs, attrs_map) {}
TestRuntimeInferVarTypeContext(
const NameVarMap<VarType>& inputs, const NameVarMap<VarType>& outputs,
const framework::AttributeMap& attrs_map,
const framework::AttributeMap& default_attrs_map)
: RuntimeInferVarTypeContext<VarType>(inputs, outputs, attrs_map,
default_attrs_map) {}
bool HasVar(const std::string& name) const {
return RuntimeInferVarTypeContext<VarType>::HasVar(name);
......@@ -125,7 +127,7 @@ TEST(test_layer, test_runtime_context) {
auto* ctx =
new imperative::TestRuntimeInferVarTypeContext<imperative::VarBase>(
ins, outs, attrs);
ins, outs, attrs, {});
ASSERT_TRUE(ctx->HasInput("X"));
ASSERT_TRUE(ctx->HasOutput("Out"));
......@@ -358,7 +360,7 @@ TEST(test_layer, test_dygraph_execution_context) {
framework::Scope scope;
DygraphExecutionContext<imperative::VarBase> dy_exe_context(
*(op.get()), scope, *dev_ctx, ctx, ins, outs, concat_att_map);
*(op.get()), scope, *dev_ctx, ctx, ins, outs, concat_att_map, {});
ASSERT_EQ(dy_exe_context.InputSize("X"), 1u);
ASSERT_EQ(dy_exe_context.InputName("X"), "vin");
......@@ -386,7 +388,7 @@ TEST(test_layer, test_dygraph_infershape_context) {
concat_att_map["axis"] = 1;
DygraphInferShapeContext<imperative::VarBase> infer_shape_ctx(
&ins, &outs, &concat_att_map, "dummy");
&ins, &outs, &concat_att_map, {}, "dummy");
bool have_x = infer_shape_ctx.HasOutputs("Out");
ASSERT_EQ(have_x, true);
......
......@@ -93,7 +93,7 @@ TEST(test_prepare_op, test_prepare_op) {
ASSERT_NO_FATAL_FAILURE(PreparedOp preparedOp = PreparedOp::Prepare(
ins, outs,
dynamic_cast<framework::OperatorWithKernel&>(*op),
place, split_attr_map));
place, split_attr_map, {}));
}
const framework::Tensor* GetTensorFromVar(const framework::Variable& var);
......@@ -144,7 +144,7 @@ TEST(test_prepare_op, test_prepare_data) {
// test if it can be transformed to GPU place
auto prepared_op = PreparedOp::Prepare(
ins, outs, dynamic_cast<framework::OperatorWithKernel&>(*op), gpu_place,
attr_map);
attr_map, {});
PrepareData<imperative::VarBase>(
dynamic_cast<framework::OperatorWithKernel&>(*op), ins,
prepared_op.kernel_type());
......@@ -193,7 +193,7 @@ void TestPrepareDataSamePlace(framework::AttributeMap attr_map) {
// test if it never transferred on GPU place
auto prepared_op = PreparedOp::Prepare(
ins, outs, dynamic_cast<framework::OperatorWithKernel&>(*op), cpu_place,
attr_map);
attr_map, {});
PrepareData<imperative::VarBase>(
dynamic_cast<framework::OperatorWithKernel&>(*op), ins,
prepared_op.kernel_type());
......
......@@ -154,9 +154,14 @@ void Tracer::TraceOp(const std::string& type, const NameVarBaseMap& ins,
const auto& op_info = op->Info();
auto* attr_checker = op_info.Checker();
if (attr_checker) {
attr_checker->Check(&attrs, true);
attr_checker->Check(&attrs, true, /*only_check_exist_value=*/true);
}
static paddle::framework::AttributeMap empty_attrs_map = {};
const paddle::framework::AttributeMap& default_attrs =
attr_checker == nullptr ? empty_attrs_map
: attr_checker->GetDefaultAttrMap();
NameVarBaseMap new_ins = ins;
if (enable_autocast_) {
VLOG(5) << "Auto mixed precision run operator: " << type;
......@@ -181,7 +186,7 @@ void Tracer::TraceOp(const std::string& type, const NameVarBaseMap& ins,
#endif
}
OpBase::Run(*op, new_ins, outs, attrs, place);
OpBase::Run(*op, new_ins, outs, attrs, default_attrs, place);
} catch (platform::EnforceNotMet& exception) {
framework::AppendErrorOpHint(type, &exception);
throw std::move(exception);
......@@ -204,7 +209,8 @@ void Tracer::TraceOp(const std::string& type, const NameVarBaseMap& ins,
}
if (ComputeRequiredGrad(new_ins, outs, trace_backward)) {
CreateGradOpNode(*op, new_ins, outs, attrs, place, inplace_map);
CreateGradOpNode(*op, new_ins, outs, attrs, default_attrs, place,
inplace_map);
} else {
VLOG(3) << "No Grad to track for Op: " << type;
}
......
......@@ -48,7 +48,7 @@ class DygraphInferShapeTest {
void SetOpType(const std::string& op_type) { op_type_ = op_type; }
void Run(std::function<void(framework::InferShapeContext* ctx)> infer_shape) {
imperative::DygraphInferShapeContext<imperative::VarBase> ctx(
&ins_, &outs_, &attrs_, op_type_);
&ins_, &outs_, &attrs_, {}, op_type_);
infer_shape(&ctx);
for (const auto& pair : expected_dims_) {
auto out = outs_[pair.first][0];
......
......@@ -1308,7 +1308,7 @@ All parameter, weight, gradient are variables in Paddle.
if (info != nullptr) {
if (info->HasOpProtoAndChecker()) {
auto op_checker = info->Checker();
res = op_checker->GetAttrsDefaultValuesMap();
res = op_checker->GetDefaultAttrsMap();
}
}
return res;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册