diff --git a/.gitignore b/.gitignore index 875de7af1beeace354bbc5cd1e0fffd05960ea81..4b53cfd8591686efc8c92a169620e75620ddbac1 100644 --- a/.gitignore +++ b/.gitignore @@ -5,7 +5,7 @@ paddle/fluid/API_PR.spec paddle/fluid/eager/api/generated/* paddle/fluid/op_use_default_grad_maker_DEV.spec paddle/fluid/op_use_default_grad_maker_PR.spec -paddle/fluid/operators/ops_extra_info.h +paddle/fluid/operators/ops_extra_info.cc paddle/phi/api/backward/backward_api.h paddle/phi/api/backward/sparse_bw_api.h paddle/phi/api/include/api.h diff --git a/paddle/fluid/framework/CMakeLists.txt b/paddle/fluid/framework/CMakeLists.txt index a3eb067426f5636bbdd6da5a9012d7eebb6ae33f..7a9c631941e047b089b5fc6faebdd149944c5aec 100755 --- a/paddle/fluid/framework/CMakeLists.txt +++ b/paddle/fluid/framework/CMakeLists.txt @@ -348,7 +348,7 @@ cc_test( cc_library( op_proto_maker SRCS op_proto_maker.cc - DEPS framework_proto attribute glog) + DEPS framework_proto attribute ops_extra_info glog) cc_test( op_proto_maker_test SRCS op_proto_maker_test.cc @@ -483,6 +483,7 @@ cc_library( proto_desc SRCS var_desc.cc op_desc.cc block_desc.cc program_desc.cc DEPS attribute + ops_extra_info shape_inference op_info operator @@ -498,7 +499,7 @@ endif() cc_library( op_registry SRCS op_registry.cc - DEPS op_proto_maker op_info operator glog proto_desc) + DEPS op_proto_maker op_info operator ops_extra_info glog proto_desc) cc_library( op_call_stack diff --git a/paddle/fluid/framework/attribute_checker.h b/paddle/fluid/framework/attribute_checker.h index f3650dc085d6ec84dc597f469629f4d29779e6ed..fbafe9c73a9cc636a72f2c66685122a0ef53e5c8 100644 --- a/paddle/fluid/framework/attribute_checker.h +++ b/paddle/fluid/framework/attribute_checker.h @@ -325,10 +325,13 @@ class OpAttrChecker { explicit_checker_num_ = attr_checkers_.size(); } - void InitDefaultAttributeMap() { + void InitDefaultAttributeMap(const AttributeMap* extra_attr_map) { for (const auto& checker : attr_checkers_) { checker(&default_attrs_, true, false); } + if (extra_attr_map) { + default_attrs_.insert(extra_attr_map->begin(), extra_attr_map->end()); + } } const AttributeMap& GetDefaultAttrMap() const { return default_attrs_; } diff --git a/paddle/fluid/framework/details/op_registry.h b/paddle/fluid/framework/details/op_registry.h index beedf6198aa1d31517f2c740af7ed88089fff1e6..132d73f5af6f45c3b5a32ae87598dc8afa75b7ed 100644 --- a/paddle/fluid/framework/details/op_registry.h +++ b/paddle/fluid/framework/details/op_registry.h @@ -207,9 +207,9 @@ struct OpInfoFiller { "OpAttrChecker of %s has been registered.", op_type)); info->proto_ = new proto::OpProto; info->checker_ = new OpAttrChecker(); + info->proto_->set_type(op_type); T maker; maker(info->proto_, info->checker_); - info->proto_->set_type(op_type); PADDLE_ENFORCE_EQ( info->proto_->IsInitialized(), true, diff --git a/paddle/fluid/framework/grad_op_desc_maker.h b/paddle/fluid/framework/grad_op_desc_maker.h index 7f6fc4690b8774263a11c2fa34514224736051d8..25960383904b6a8199b1603eaabb7b48706a9c9f 100644 --- a/paddle/fluid/framework/grad_op_desc_maker.h +++ b/paddle/fluid/framework/grad_op_desc_maker.h @@ -161,6 +161,10 @@ class GradOpDescMakerBase { return fwd_op_.GetAttrMap(); } + const std::unordered_map& RuntimeAttrs() const { + return fwd_op_.GetRuntimeAttrMap(); + } + const Attribute& GetAttr(const std::string& name) const { auto& map = fwd_op_.GetAttrMap(); auto it = map.find(name); @@ -209,6 +213,7 @@ class SingleGradOpMaker : public GradOpDescMakerBase { retv.emplace_back(new OpDesc()); try { this->Apply(retv.front().get()); + retv.front()->SetRuntimeAttrMap(this->RuntimeAttrs()); } catch (platform::EnforceNotMet& exception) { framework::AppendErrorOpHint(retv.front().get()->Type(), &exception); throw std::move(exception); diff --git a/paddle/fluid/framework/ir/mkldnn/conv_bias_mkldnn_fuse_pass.cc b/paddle/fluid/framework/ir/mkldnn/conv_bias_mkldnn_fuse_pass.cc index 1c68986a003461d7e4cb78a363fc562eae19a366..27b69d1cd2c390425c5fd1e11d12ebb11cde462e 100644 --- a/paddle/fluid/framework/ir/mkldnn/conv_bias_mkldnn_fuse_pass.cc +++ b/paddle/fluid/framework/ir/mkldnn/conv_bias_mkldnn_fuse_pass.cc @@ -298,6 +298,9 @@ void ConvBiasFusePass::ApplyImpl(ir::Graph* graph) const { for (auto& attr : conv->Op()->GetAttrMap()) { desc.SetAttr(attr.first, attr.second); } + for (auto& attr : conv->Op()->GetRuntimeAttrMap()) { + desc.SetAttr(attr.first, attr.second); + } auto conv_bias_node = g->CreateOpNode(&desc); IR_NODE_LINK_TO(subgraph.at(conv_input), conv_bias_node); diff --git a/paddle/fluid/framework/new_executor/CMakeLists.txt b/paddle/fluid/framework/new_executor/CMakeLists.txt index e01fcb68fbf3beb71f8d8e5f998db8b4616d5249..68312c7bb364943174630c9bada8a9b7aea1ed0e 100644 --- a/paddle/fluid/framework/new_executor/CMakeLists.txt +++ b/paddle/fluid/framework/new_executor/CMakeLists.txt @@ -18,6 +18,7 @@ set(STANDALONE_EXECUTOR_DEPS scope framework_proto data_feed_proto + ops_extra_info heter_service_proto trainer_desc_proto glog diff --git a/paddle/fluid/framework/new_executor/interpretercore_util.cc b/paddle/fluid/framework/new_executor/interpretercore_util.cc index 27606ca2b0c2daad8d38405de6e7c2493cfd2ce5..99b3b68cc2a89260b17c503222feb54c62f59219 100644 --- a/paddle/fluid/framework/new_executor/interpretercore_util.cc +++ b/paddle/fluid/framework/new_executor/interpretercore_util.cc @@ -21,6 +21,7 @@ #include "paddle/fluid/operators/controlflow/conditional_block_op_helper.h" #include "paddle/fluid/operators/controlflow/recurrent_op_helper.h" #include "paddle/fluid/operators/controlflow/while_op_helper.h" +#include "paddle/fluid/operators/ops_extra_info.h" #include "paddle/phi/core/kernel_context.h" #include "paddle/phi/core/kernel_factory.h" @@ -242,20 +243,30 @@ void build_variable_scope(const framework::BlockDesc& block, void create_all_ops(const framework::BlockDesc& block, std::vector>* ops) { for (auto& op : block.AllOps()) { - VLOG(3) << "CreateOp from : " << op->Type(); + auto op_type = op->Type(); + VLOG(1) << "CreateOp from : " << op_type; - auto& info = OpInfoMap::Instance().Get(op->Type()); + auto& info = OpInfoMap::Instance().Get(op_type); const VariableNameMap& inputs_names = op->Inputs(); const VariableNameMap& outputs_names = op->Outputs(); AttributeMap op_attr_map = op->GetAttrMap(); + AttributeMap op_runtime_attr_map = op->GetRuntimeAttrMap(); if (info.Checker() != nullptr) { info.Checker()->Check(&op_attr_map); } + + const auto& extra_attr_checkers = + operators::ExtraInfoUtils::Instance().GetExtraAttrsChecker(op_type); + for (const auto& checker : extra_attr_checkers) { + checker(&op_runtime_attr_map, false); + } + auto op_base = - info.Creator()(op->Type(), inputs_names, outputs_names, op_attr_map); + info.Creator()(op_type, inputs_names, outputs_names, op_attr_map); + op_base->SetRuntimeAttributeMap(op_runtime_attr_map); #ifdef PADDLE_WITH_MKLDNN if (FLAGS_use_mkldnn) { diff --git a/paddle/fluid/framework/op_desc.cc b/paddle/fluid/framework/op_desc.cc index 507f7cd166ea070f7e21229d3dcee6afd9cfb0b1..22061d7cb2a14fdb7c483afe3e1dcac969ce77c1 100644 --- a/paddle/fluid/framework/op_desc.cc +++ b/paddle/fluid/framework/op_desc.cc @@ -23,6 +23,7 @@ limitations under the License. */ #include "paddle/fluid/framework/operator.h" #include "paddle/fluid/framework/shape_inference.h" #include "paddle/fluid/framework/var_type_inference.h" +#include "paddle/fluid/operators/ops_extra_info.h" #include "paddle/utils/blank.h" namespace paddle { @@ -409,6 +410,13 @@ class CompileTimeInferShapeContext : public InferShapeContext { const BlockDesc &block_; }; +static void InitRuntimeAttributeMapByOpExtraInfo(const std::string &op_type, + AttributeMap *runtime_attrs) { + const auto &extra_attr_map = + operators::ExtraInfoUtils::Instance().GetExtraAttrsMap(op_type); + runtime_attrs->insert(extra_attr_map.begin(), extra_attr_map.end()); +} + OpDesc::OpDesc(const std::string &type, const VariableNameMap &inputs, const VariableNameMap &outputs, @@ -419,6 +427,7 @@ OpDesc::OpDesc(const std::string &type, attrs_ = attrs; need_update_ = true; block_ = nullptr; + InitRuntimeAttributeMapByOpExtraInfo(type, &runtime_attrs_); } OpDesc::OpDesc(const OpDesc &other) { @@ -441,6 +450,8 @@ void OpDesc::CopyFrom(const OpDesc &op_desc) { inputs_ = op_desc.inputs_; outputs_ = op_desc.outputs_; attrs_ = op_desc.attrs_; + runtime_attrs_ = op_desc.runtime_attrs_; + // The record of original_id_ is only for auto parallel. original_id_ = op_desc.original_id_; if (op_desc.dist_attr_) { dist_attr_.reset(new OperatorDistAttr(*op_desc.dist_attr_)); @@ -473,8 +484,9 @@ OpDesc::OpDesc(const proto::OpDesc &desc, BlockDesc *block) } } // restore attrs_ + InitRuntimeAttributeMapByOpExtraInfo(desc.type(), &runtime_attrs_); for (const proto::OpDesc::Attr &attr : desc_.attrs()) { - std::string attr_name = attr.name(); + const std::string &attr_name = attr.name(); // The sub_block referred to by the BLOCK attr hasn't been added // to ProgramDesc class yet, we skip setting BLOCK/BLOCKS/VAR/VARS attr // here. @@ -483,7 +495,12 @@ OpDesc::OpDesc(const proto::OpDesc &desc, BlockDesc *block) attr_type != proto::AttrType::BLOCKS && attr_type != proto::AttrType::VAR && attr_type != proto::AttrType::VARS) { - attrs_[attr_name] = GetAttrValue(attr); + auto iter = runtime_attrs_.find(attr_name); + if (iter == runtime_attrs_.end()) { + attrs_[attr_name] = GetAttrValue(attr); + } else { + iter->second = GetAttrValue(attr); + } } } this->block_ = block; @@ -622,7 +639,13 @@ std::vector OpDesc::AttrNames(bool with_attr_var) const { bool OpDesc::HasAttr(const std::string &name, bool with_attr_var) const { auto iter = attrs_.find(name); - bool is_found = iter != attrs_.end(); + bool is_found = true; + if (iter == attrs_.end()) { + iter = runtime_attrs_.find(name); + if (iter == runtime_attrs_.end()) { + is_found = false; + } + } if (with_attr_var) { return is_found; } @@ -631,10 +654,19 @@ bool OpDesc::HasAttr(const std::string &name, bool with_attr_var) const { void OpDesc::RemoveAttr(const std::string &name) { attrs_.erase(name); + runtime_attrs_.erase(name); need_update_ = true; } void OpDesc::SetAttr(const std::string &name, const Attribute &v) { + AttributeMap *attrs_ptr = &(this->attrs_); + + const auto &extra_attr_map = + operators::ExtraInfoUtils::Instance().GetExtraAttrsMap(Type()); + auto extra_attr_iter = extra_attr_map.find(name); + if (extra_attr_iter != extra_attr_map.end()) { + attrs_ptr = &(this->runtime_attrs_); + } // NOTICE(minqiyang): pybind11 will take the empty list in python as // the std::vector type in C++; so we have to change the attr's type // here if we meet this issue @@ -647,25 +679,25 @@ void OpDesc::SetAttr(const std::string &name, const Attribute &v) { case proto::AttrType::BOOLEANS: { VLOG(11) << "SetAttr: " << Type() << ", " << name << " from INTS to BOOLEANS"; - this->attrs_[name] = std::vector(); + attrs_ptr->operator[](name) = std::vector(); break; } case proto::AttrType::INTS: { VLOG(11) << "SetAttr: " << Type() << ", " << name << " from INTS to INTS"; - this->attrs_[name] = std::vector(); + attrs_ptr->operator[](name) = std::vector(); break; } case proto::AttrType::LONGS: { VLOG(11) << "SetAttr: " << Type() << ", " << name << " from LONGS to LONGS"; - this->attrs_[name] = std::vector(); + attrs_ptr->operator[](name) = std::vector(); break; } case proto::AttrType::FLOATS: { VLOG(11) << "SetAttr: " << Type() << ", " << name << " from INTS to FLOATS"; - this->attrs_[name] = std::vector(); + attrs_ptr->operator[](name) = std::vector(); break; } case proto::AttrType::FLOAT64S: { @@ -677,13 +709,13 @@ void OpDesc::SetAttr(const std::string &name, const Attribute &v) { case proto::AttrType::STRINGS: { VLOG(11) << "SetAttr: " << Type() << ", " << name << " from INTS to STRINGS"; - this->attrs_[name] = std::vector(); + attrs_ptr->operator[](name) = std::vector(); break; } case proto::AttrType::BLOCKS: { VLOG(11) << "SetAttr: " << Type() << ", " << name << " from INTS to BLOCKS"; - this->SetBlocksAttr(name, std::vector()); + attrs_ptr->operator[](name) = std::vector(); return; } default: @@ -695,14 +727,23 @@ void OpDesc::SetAttr(const std::string &name, const Attribute &v) { } // In order to set bool attr properly - if (attr_type == proto::AttrType::INT && HasProtoAttr(name) && - GetProtoAttr(name).type() == proto::AttrType::BOOLEAN) { - this->attrs_[name] = static_cast(PADDLE_GET_CONST(int, v)); - need_update_ = true; - return; + if (attr_type == proto::AttrType::INT) { + if (HasProtoAttr(name) && + GetProtoAttr(name).type() == proto::AttrType::BOOLEAN) { + attrs_ptr->operator[](name) = static_cast(PADDLE_GET_CONST(int, v)); + need_update_ = true; + return; + } + if (extra_attr_iter != extra_attr_map.end() && + static_cast(extra_attr_iter->second.index() - 1) == + proto::AttrType::BOOLEAN) { + attrs_ptr->operator[](name) = static_cast(PADDLE_GET_CONST(int, v)); + need_update_ = true; + return; + } } - this->attrs_[name] = v; + attrs_ptr->operator[](name) = v; need_update_ = true; } @@ -733,8 +774,17 @@ void OpDesc::SetAttrMap( need_update_ = true; } +void OpDesc::SetRuntimeAttrMap( + const std::unordered_map &attr_map) { + runtime_attrs_ = attr_map; + need_update_ = true; +} + Attribute OpDesc::GetAttr(const std::string &name, bool with_attr_var) const { auto it = attrs_.find(name); + if (it == attrs_.end()) { + it = runtime_attrs_.find(name); + } PADDLE_ENFORCE_NE( it, attrs_.end(), @@ -802,6 +852,8 @@ const std::unordered_map &OpDesc::GetAttrMap() const { return attrs_; } +const AttributeMap &OpDesc::GetRuntimeAttrMap() const { return runtime_attrs_; } + void OpDesc::Rename(const std::string &old_name, const std::string &new_name) { RenameInput(old_name, new_name); RenameOutput(old_name, new_name); @@ -925,6 +977,15 @@ void OpDesc::Flush() { } this->desc_.mutable_attrs()->Clear(); + auto set_attr_desc = [this](const std::string &attr_name, + const Attribute &attr) -> void { + auto *attr_desc = desc_.add_attrs(); + attr_desc->set_name(attr_name); + attr_desc->set_type(static_cast(attr.index() - 1)); + SetAttrDescVisitor visitor(attr_desc); + paddle::visit(visitor, attr); + }; + std::vector> sorted_attrs{attrs_.begin(), attrs_.end()}; std::sort( @@ -932,13 +993,12 @@ void OpDesc::Flush() { sorted_attrs.end(), [](std::pair a, std::pair b) { return a.first < b.first; }); + for (auto &attr : sorted_attrs) { - auto *attr_desc = desc_.add_attrs(); - attr_desc->set_name(attr.first); - attr_desc->set_type( - static_cast(attr.second.index() - 1)); - SetAttrDescVisitor visitor(attr_desc); - paddle::visit(visitor, attr.second); + set_attr_desc(attr.first, attr.second); + } + for (auto &attr : runtime_attrs_) { + set_attr_desc(attr.first, attr.second); } need_update_ = false; @@ -1155,7 +1215,7 @@ bool CompileTimeInferShapeContext::HasOutputs(const std::string &name, } AttrReader CompileTimeInferShapeContext::Attrs() const { - return AttrReader(op_.GetAttrMap()); + return AttrReader(op_.GetAttrMap(), op_.GetRuntimeAttrMap()); } std::vector CompileTimeInferShapeContext::Inputs( diff --git a/paddle/fluid/framework/op_desc.h b/paddle/fluid/framework/op_desc.h index a2f503f4b96531185178f5b15e9ba0f29bb4a0e6..7987a9ded475c8217b8659e51e8009ffdd4dfdba 100644 --- a/paddle/fluid/framework/op_desc.h +++ b/paddle/fluid/framework/op_desc.h @@ -144,6 +144,10 @@ class OpDesc { // Only be used in C++ void SetAttrMap(const AttributeMap &attr_map); + void SetRuntimeAttrMap(const AttributeMap &attr_map); + + const AttributeMap &GetRuntimeAttrMap() const; + std::vector InputNames(bool with_attr_var = false) const { return MapKeys(inputs_); } @@ -221,6 +225,12 @@ class OpDesc { VariableNameMap outputs_; // attribute name => all original attrs AttributeMap attrs_; + // runtime_attrs_ contains the attributes which used for dispatching kernel + // (use_mkldnn, use_cudnn, ...) or passing additional configuration for + // special heterogeneous kernel (workspace_size_MB, ...). + // The attributes in runtime_attrs_ are setted by framework (such as PASS), + // and not in the python api. + AttributeMap runtime_attrs_; // need_update_ indicate there some local changes not be synchronized. If // local changes should be synchronized, need_update_ should be set to true. diff --git a/paddle/fluid/framework/op_proto_maker.cc b/paddle/fluid/framework/op_proto_maker.cc index 2d6411aafa5f682173cedc351a8a75717590098f..5f75991b50671bd5293a89ff0020c103830f35a8 100644 --- a/paddle/fluid/framework/op_proto_maker.cc +++ b/paddle/fluid/framework/op_proto_maker.cc @@ -12,8 +12,7 @@ See the License for the specific language governing permissions and limitations under the License. */ #include "paddle/fluid/framework/op_proto_maker.h" - -#include +#include "paddle/fluid/operators/ops_extra_info.h" #include "paddle/fluid/platform/enforce.h" @@ -67,7 +66,16 @@ void OpProtoAndCheckerMaker::operator()(proto::OpProto* proto, op_checker_ = attr_checker; Make(); op_checker_->RecordExplicitCheckerNum(); - op_checker_->InitDefaultAttributeMap(); + + const AttributeMap* extra_attrs_ptr = nullptr; + const std::string& op_type = proto->type(); + + const auto& extra_attr_map = + operators::ExtraInfoUtils::Instance().GetExtraAttrsMap(op_type); + if (!extra_attr_map.empty()) { + extra_attrs_ptr = &extra_attr_map; + } + op_checker_->InitDefaultAttributeMap(extra_attrs_ptr); AddAttr(OpRoleAttrName(), "The role of this operator") .InEnum( diff --git a/paddle/fluid/framework/op_registry.cc b/paddle/fluid/framework/op_registry.cc index 21fca8a4d95d00584d26fe408d69cca2057e3188..a5ce5b704921b7a256fa094d6b0a20f66b5fdba8 100644 --- a/paddle/fluid/framework/op_registry.cc +++ b/paddle/fluid/framework/op_registry.cc @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. */ #include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/operators/ops_extra_info.h" #include "glog/logging.h" @@ -25,15 +26,47 @@ std::unique_ptr OpRegistry::CreateOp( const VariableNameMap& outputs, const AttributeMap& attrs, bool attr_check) { + AttributeMap standard_attrs; + AttributeMap runtime_attrs = + paddle::operators::ExtraInfoUtils::Instance().GetExtraAttrsMap(type); + for (auto& attr : attrs) { + auto it = runtime_attrs.find(attr.first); + if (it != runtime_attrs.end()) { + it->second = attr.second; + } else { + standard_attrs[attr.first] = attr.second; + } + } + auto& info = OpInfoMap::Instance().Get(type); + if (attr_check && info.Checker() != nullptr) { + info.Checker()->Check(&standard_attrs); + } + auto op_base = std::unique_ptr( + info.Creator()(type, inputs, outputs, standard_attrs)); + op_base->SetRuntimeAttributeMap(runtime_attrs); + return op_base; +} + +std::unique_ptr OpRegistry::CreateOp( + const std::string& type, + const VariableNameMap& inputs, + const VariableNameMap& outputs, + const AttributeMap& attrs, + const AttributeMap& runtime_attrs, + bool attr_check) { + std::unique_ptr op_base; auto& info = OpInfoMap::Instance().Get(type); if (attr_check && info.Checker() != nullptr) { auto tmp_attrs = attrs; info.Checker()->Check(&tmp_attrs); - return std::unique_ptr( + op_base = std::unique_ptr( info.Creator()(type, inputs, outputs, tmp_attrs)); + } else { + op_base = std::unique_ptr( + info.Creator()(type, inputs, outputs, attrs)); } - return std::unique_ptr( - info.Creator()(type, inputs, outputs, attrs)); + op_base->SetRuntimeAttributeMap(runtime_attrs); + return op_base; } static VariableNameMap ConvertOpDescVarsToVarNameMap( @@ -59,18 +92,27 @@ std::unique_ptr OpRegistry::CreateOp( VariableNameMap inputs = ConvertOpDescVarsToVarNameMap(op_desc.inputs()); VariableNameMap outputs = ConvertOpDescVarsToVarNameMap(op_desc.outputs()); AttributeMap attrs; + AttributeMap extra_attrs = + paddle::operators::ExtraInfoUtils::Instance().GetExtraAttrsMap( + op_desc.type()); for (auto& attr : op_desc.attrs()) { - attrs[attr.name()] = GetAttrValue(attr); + auto it = extra_attrs.find(attr.name()); + if (it != extra_attrs.end()) { + it->second = GetAttrValue(attr); + } else { + attrs[attr.name()] = GetAttrValue(attr); + } } - return CreateOp(op_desc.type(), inputs, outputs, attrs); + return CreateOp(op_desc.type(), inputs, outputs, attrs, extra_attrs); } std::unique_ptr OpRegistry::CreateOp(const OpDesc& op_desc) { return CreateOp(op_desc.Type(), op_desc.Inputs(), op_desc.Outputs(), - op_desc.GetAttrMap()); + op_desc.GetAttrMap(), + op_desc.GetRuntimeAttrMap()); } } // namespace framework diff --git a/paddle/fluid/framework/op_registry.h b/paddle/fluid/framework/op_registry.h index 53b77d538b3ed1db1509895de7561d17695ad1e0..2befc70b2d5ed5f93f86faf939842dee6f575cf7 100644 --- a/paddle/fluid/framework/op_registry.h +++ b/paddle/fluid/framework/op_registry.h @@ -132,6 +132,13 @@ class OpRegistry { const VariableNameMap& outputs, const AttributeMap& attrs, bool attr_check = true); + static std::unique_ptr CreateOp( + const std::string& type, + const VariableNameMap& inputs, + const VariableNameMap& outputs, + const AttributeMap& attrs, + const AttributeMap& runtime_attrs, + bool attr_check = true); static std::unique_ptr CreateOp(const proto::OpDesc& op_desc); diff --git a/paddle/fluid/framework/operator.cc b/paddle/fluid/framework/operator.cc index 71bd350af6eff599269c4fc86426b525ccecd495..04d51872852a7cf11dff9a0ae0e90a7e9e37d4e1 100644 --- a/paddle/fluid/framework/operator.cc +++ b/paddle/fluid/framework/operator.cc @@ -1389,10 +1389,9 @@ bool OperatorWithKernel::SupportsKernelType( bool OperatorWithKernel::CanMKLDNNBeUsed(const framework::ExecutionContext& ctx, proto::VarType::Type data_type) const { - const auto& attrs_map = ctx.Attrs(); - auto iter = attrs_map.find("use_mkldnn"); - bool use_mkldnn_ctx = iter != attrs_map.end() && - PADDLE_GET_CONST(bool, iter->second) && + const std::string use_mkldnn_attr = "use_mkldnn"; + bool use_mkldnn_ctx = ctx.HasAttr(use_mkldnn_attr) && + ctx.Attr(use_mkldnn_attr) && platform::is_cpu_place(ctx.GetPlace()); return use_mkldnn_ctx && this->SupportsMKLDNN(data_type); } @@ -2881,12 +2880,16 @@ void OperatorWithKernel::BuildPhiKernelContext( } } break; default: { - PADDLE_ENFORCE_NE( - attr_iter, - Attrs().end(), - platform::errors::NotFound("(%s) is not found in AttributeMap when " - "buildind static KernelContext.", - attr_names[i])); + if (attr_iter == Attrs().end()) { + attr_iter = RuntimeAttrs().find(attr_names[i]); + PADDLE_ENFORCE_NE(attr_iter, + RuntimeAttrs().end(), + platform::errors::NotFound( + "(%s) is not found in AttributeMap when " + "buildind static KernelContext.", + attr_names[i])); + } + switch (attr_defs[i].type_index) { case phi::AttributeType::FLOAT32: phi_kernel_context->EmplaceBackAttr( diff --git a/paddle/fluid/framework/operator.h b/paddle/fluid/framework/operator.h index 17ec9a1f93e723f5381e259669429db8209e9b3f..27ce31c25c0c444f4e0a39e4c68f2b0c6ffdd5cb 100644 --- a/paddle/fluid/framework/operator.h +++ b/paddle/fluid/framework/operator.h @@ -177,7 +177,9 @@ class OperatorBase { const std::string& Type() const { return type_; } - bool HasAttr(const std::string& name) const { return attrs_.count(name); } + bool HasAttr(const std::string& name) const { + return attrs_.count(name) || runtime_attrs_.count(name); + } template inline const T& Attr(const std::string& name) const { PADDLE_ENFORCE_NE( @@ -196,6 +198,10 @@ class OperatorBase { attrs_[name] = v; } const AttributeMap& Attrs() const { return attrs_; } + const AttributeMap& RuntimeAttrs() const { return runtime_attrs_; } + void SetRuntimeAttributeMap(const AttributeMap& runtime_attrs) { + runtime_attrs_ = runtime_attrs; + } const VariableNameMap& Inputs() const { return inputs_; } const VariableNameMap& Outputs() const { return outputs_; } @@ -250,6 +256,12 @@ class OperatorBase { // IG (Inputs Gradients) VariableNameMap outputs_; AttributeMap attrs_; + // NOTE: runtime_attrs_ contains the attributes which used for dispatching + // kernel (use_mkldnn, use_cudnn, ...) or passing additional configuration + // for special heterogeneous kernel (workspace_size_MB, ...). + // The attributes in runtime_attrs_ are setted by framework (such as PASS), + // and not in the python api. + AttributeMap runtime_attrs_; // OpInfo const OpInfo* info_; @@ -302,7 +314,12 @@ class ExecutionContext { } virtual const Attribute& GetAttr(const std::string& name) const { - return op_.Attrs().at(name); + auto iter = op_.Attrs().find(name); + if (iter == op_.Attrs().end()) { + return op_.RuntimeAttrs().at(name); + } else { + return iter->second; + } } virtual bool HasInput(const std::string& name) const; diff --git a/paddle/fluid/imperative/CMakeLists.txt b/paddle/fluid/imperative/CMakeLists.txt index 16aa364736a12db6df203e7369865c7f43bb2ee0..70f2830d12067d0b90fd1f3ab32163c52cdfcef2 100644 --- a/paddle/fluid/imperative/CMakeLists.txt +++ b/paddle/fluid/imperative/CMakeLists.txt @@ -73,7 +73,8 @@ cc_library( denormal garbage_collector var_helper - layout_autotune) + layout_autotune + ops_extra_info) cc_library( basic_engine SRCS basic_engine.cc diff --git a/paddle/fluid/imperative/tracer.cc b/paddle/fluid/imperative/tracer.cc index 07eb9ae6a8e5ebfe15bd4d56edef62d072039475..35eb3e9384200d29a5e94eb06b6270090ce8dc1e 100644 --- a/paddle/fluid/imperative/tracer.cc +++ b/paddle/fluid/imperative/tracer.cc @@ -23,6 +23,7 @@ #include "paddle/fluid/imperative/execution_context.h" #include "paddle/fluid/imperative/layout_autotune.h" #include "paddle/fluid/imperative/op_base.h" +#include "paddle/fluid/operators/ops_extra_info.h" #include "paddle/fluid/platform/denormal.h" #include "paddle/fluid/platform/device/device_wrapper.h" #include "paddle/fluid/platform/profiler.h" @@ -240,6 +241,11 @@ void Tracer::TraceOpImpl(const std::string& type, if (attr_checker) { attr_checker->Check(&attrs, true, /*only_check_exist_value=*/true); } + const auto& extra_attr_checkers = + operators::ExtraInfoUtils::Instance().GetExtraAttrsChecker(type); + for (const auto& checker : extra_attr_checkers) { + checker(&attrs, true); + } static paddle::framework::AttributeMap empty_attrs_map = {}; const paddle::framework::AttributeMap& default_attrs = diff --git a/paddle/fluid/operators/CMakeLists.txt b/paddle/fluid/operators/CMakeLists.txt index a722aca6a9f47586695130e0a509967682c45e52..39faf87406d5899b6bbdb69de18dd665557d776b 100644 --- a/paddle/fluid/operators/CMakeLists.txt +++ b/paddle/fluid/operators/CMakeLists.txt @@ -149,6 +149,7 @@ if (WITH_DGC) endif() cc_library(common_infer_shape_functions SRCS common_infer_shape_functions.cc DEPS operator) +cc_library(ops_extra_info SRCS ops_extra_info.cc DEPS attribute cudnn_workspace_helper) set(COMMON_OP_DEPS ${COMMON_OP_DEPS} selected_rows_functor selected_rows_utils lapack_function lod_tensor maxouting unpooling pooling lod_rank_table context_project diff --git a/paddle/fluid/operators/conv_op.cc b/paddle/fluid/operators/conv_op.cc index bfdd9ea5d1b78e55940725b9f17b9f160d85623e..ab222f3cb36c1a9cec2fb68761a5d3273ebc3796 100644 --- a/paddle/fluid/operators/conv_op.cc +++ b/paddle/fluid/operators/conv_op.cc @@ -348,82 +348,6 @@ void Conv2DOpMaker::Make() { "dilations(h_dilation, w_dilation) of " "convolution operator.") .SetDefault({1, 1}); - AddAttr( - "use_cudnn", - "(bool, default false) Only used in cudnn kernel, need install cudnn") - .SetDefault(false) - .AsExtra(); - AddAttr("fuse_relu_before_depthwise_conv", - "(bool, default false) Only used in cuda depthwise kernel") - .SetDefault(false) - .AsExtra(); - AddAttr("use_mkldnn", - "(bool, default false) Only used in mkldnn kernel") - .SetDefault(false) - .AsExtra(); - AddAttr( - "use_quantizer", - "(bool, default false) " - "This parameter is no longer used. Use 'mkldnn_data_type' instead.") - .SetDefault(false) - .AsExtra(); - AddAttr( - "mkldnn_data_type", - "(string, default \"float32\"). Data type of mkldnn kernel") - .SetDefault("float32") - .InEnum({"float32", "int8", "bfloat16"}) - .AsExtra(); - AddAttr("fuse_relu", "(bool, default false) Only used in mkldnn kernel") - .SetDefault(false) - .AsExtra(); - AddAttr("fuse_activation", - "(string, default \"\") Only used in mkldnn kernel") - .SetDefault("") - .AsExtra(); - AddAttr("fuse_alpha", - "(float, default 0.0) Only used in mkldnn kernel") - .SetDefault(0.0f) - .AsExtra(); - AddAttr("fuse_beta", "(float, default 0.0) Only used in mkldnn kernel") - .SetDefault(0.0f) - .AsExtra(); - AddAttr( - "use_addto", - "(bool, default false) If use addto strategy or not, only used in " - "cudnn kernel") - .SetDefault(false) - .AsExtra(); - AddAttr("fuse_residual_connection", - "(bool, default false) Only used in mkldnn kernel. Used " - "whenever convolution output is as an input to residual " - "connection.") - .SetDefault(false) - .AsExtra(); - AddAttr("Scale_in", - "Scale_in to be used for int8 input data." - "Only used with MKL-DNN INT8.") - .SetDefault(1.0f) - .AsExtra(); - AddAttr("Scale_out", - "Scale_out to be used for int8 output data." - "Only used with MKL-DNN INT8.") - .SetDefault(1.0f) - .AsExtra(); - AddAttr("Scale_in_eltwise", - "Scale_in_eltwise to be used for int8 eltwise input data." - "Only used with MKL-DNN INT8.") - .SetDefault(1.0f) - .AsExtra(); - AddAttr>("Scale_weights", - "Scale_weights to be used for int8 weights data." - "Only used with MKL-DNN INT8.") - .SetDefault({1.0f}) - .AsExtra(); - AddAttr("force_fp32_output", - "(bool, default false) Force INT8 kernel output FP32, only " - "used in MKL-DNN INT8") - .SetDefault(false) - .AsExtra(); AddAttr( "data_format", "(string, default NCHW) Only used in " diff --git a/paddle/fluid/operators/ops_extra_info.h b/paddle/fluid/operators/ops_extra_info.h new file mode 100644 index 0000000000000000000000000000000000000000..8f3780dd3a3ad8f752f4790e3723b0dce401e0a8 --- /dev/null +++ b/paddle/fluid/operators/ops_extra_info.h @@ -0,0 +1,89 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/fluid/framework/attribute.h" + +namespace paddle { +namespace operators { + +template +struct ExtraAttrChecker { + ExtraAttrChecker(const std::string& attr_name, T default_value) + : attr_name(attr_name), default_val(default_value) {} + + void operator()(framework::AttributeMap* attr_map, + bool only_check_exist_value) { + auto it = attr_map->find(attr_name); + if (it == attr_map->end()) { + if (!only_check_exist_value) { + attr_map->emplace(attr_name, default_val); + } + return; + } + framework::ExtractAttribute extract_attr(attr_name); + extract_attr(it->second); + } + + const std::string& attr_name; + T default_val; +}; + +class ExtraInfoUtils { + public: + static ExtraInfoUtils& Instance() { + static ExtraInfoUtils extra_info_utils; + return extra_info_utils; + } + + const std::unordered_map& + GetAllExtraAttrsMap() const { + return g_extra_attrs_map_; + } + + const paddle::framework::AttributeMap& GetExtraAttrsMap( + const std::string& op_type) const { + auto iter = g_extra_attrs_map_.find(op_type); + if (iter != g_extra_attrs_map_.end()) { + return iter->second; + } + return empty_extra_attrs_map_; + } + + const std::vector>& + GetExtraAttrsChecker(const std::string& op_type) const { + auto iter = g_extra_attrs_checker_.find(op_type); + if (iter != g_extra_attrs_checker_.end()) { + return iter->second; + } + return empty_extra_attrs_checker_; + } + + private: + ExtraInfoUtils(); + + std::unordered_map + g_extra_attrs_map_; + paddle::framework::AttributeMap empty_extra_attrs_map_{}; + std::unordered_map< + std::string, + std::vector>> + g_extra_attrs_checker_; + std::vector> + empty_extra_attrs_checker_{}; +}; + +} // namespace operators +} // namespace paddle diff --git a/paddle/fluid/pybind/op_function_common.cc b/paddle/fluid/pybind/op_function_common.cc index e7970f69e577507ffefaf0b84aa5984394702673..5680a84ca4fa5e88093d4c6340f0617519c55d8d 100644 --- a/paddle/fluid/pybind/op_function_common.cc +++ b/paddle/fluid/pybind/op_function_common.cc @@ -29,6 +29,7 @@ #include "paddle/fluid/framework/variable.h" #include "paddle/fluid/imperative/tracer.h" #include "paddle/fluid/imperative/type_defs.h" +#include "paddle/fluid/operators/ops_extra_info.h" #include "paddle/fluid/pybind/imperative.h" namespace py = pybind11; @@ -961,6 +962,15 @@ void InitOpsAttrTypeMap() { OpAttrTypeMap::Instance().Map()[iter->first][attr.name()] = attr.type(); } } + const auto& extra_attr_maps = + operators::ExtraInfoUtils::Instance().GetAllExtraAttrsMap(); + for (const auto& extra_attrs : extra_attr_maps) { + for (auto& attr : extra_attrs.second) { + OpAttrTypeMap::Instance().Map()[extra_attrs.first][attr.first] = + static_cast(attr.second.index() - + 1); + } + } } ssize_t GetIdxFromCoreOpsInfoMap( diff --git a/paddle/fluid/pybind/pybind.cc b/paddle/fluid/pybind/pybind.cc index ae4ee11bdc6b76e907854885b5baa50a1c01b5ef..acce7781a23e911609507e501c1307c710980a95 100644 --- a/paddle/fluid/pybind/pybind.cc +++ b/paddle/fluid/pybind/pybind.cc @@ -71,6 +71,7 @@ limitations under the License. */ #include "paddle/fluid/memory/allocation/mmap_allocator.h" #include "paddle/fluid/operators/activation_op.h" #include "paddle/fluid/operators/common_infer_shape_functions.h" +#include "paddle/fluid/operators/ops_extra_info.h" #include "paddle/fluid/operators/py_func_op.h" #include "paddle/fluid/platform/cpu_helper.h" #include "paddle/fluid/platform/cpu_info.h" @@ -1068,6 +1069,23 @@ All parameter, weight, gradient are variables in Paddle. } return res; }); + m.def( + "get_op_extra_attrs", + [](const std::string &op_type) + -> const paddle::framework::AttributeMap & { + return operators::ExtraInfoUtils::Instance().GetExtraAttrsMap(op_type); + }); + + m.def( + "get_attrtibute_type", + [](const std::string &op_type, + const std::string &attr_name) -> paddle::framework::proto::AttrType { + const auto &defalut_val = + operators::ExtraInfoUtils::Instance().GetExtraAttrsMap(op_type).at( + attr_name); + return static_cast( + defalut_val.index() - 1); + }); m.def("get_grad_op_desc", [](const OpDesc &op_desc, const std::unordered_set &no_grad_set, diff --git a/paddle/phi/api/lib/CMakeLists.txt b/paddle/phi/api/lib/CMakeLists.txt index dfbf5aaba46e53febc03e3886a754fb7fc714554..924f6f3526a6b8790f78fd8f085eff71ccfbf5d5 100644 --- a/paddle/phi/api/lib/CMakeLists.txt +++ b/paddle/phi/api/lib/CMakeLists.txt @@ -100,7 +100,7 @@ set(ops_extra_info_gen_file set(api_compat_yaml_file ${CMAKE_SOURCE_DIR}/paddle/phi/api/yaml/api_compat.yaml) set(ops_extra_info_file - ${CMAKE_SOURCE_DIR}/paddle/fluid/operators/ops_extra_info.h) + ${CMAKE_SOURCE_DIR}/paddle/fluid/operators/ops_extra_info.cc) if(NOT PYTHONINTERP_FOUND) find_package(PythonInterp REQUIRED) diff --git a/paddle/phi/api/yaml/api_compat.yaml b/paddle/phi/api/yaml/api_compat.yaml index 9ac1e8bd719bed3674ef7b4632823fe36ecd89f0..84c37cde318cc1903348c9e07f298aa56bdf4559 100644 --- a/paddle/phi/api/yaml/api_compat.yaml +++ b/paddle/phi/api/yaml/api_compat.yaml @@ -1,3 +1,8 @@ +# - api : conv3d_transpose +# backward : conv3d_transpose_grad +# extra : +# attrs : [bool use_cudnn = true, bool use_mkldnn = false, int workspace_size_MB = platform::GetDefaultConvWorkspaceSizeLimitMB()] + - api : atan2 inputs : {x : X1, y : X2} @@ -23,19 +28,20 @@ out : Out - api : conv2d + backward : conv2d_grad extra : attrs : [bool use_cudnn = false, bool fuse_relu_before_depthwise_conv = false, bool use_mkldnn = false, bool use_quantizer = false, str mkldnn_data_type = "float32", bool fuse_relu = false, - str fuse_activation = "", bool fuse_alpha = false, bool fuse_beta = false, bool use_addto = false, + str fuse_activation = "", float fuse_alpha = 0.0f, float fuse_beta = 0.0f, bool use_addto = false, bool fuse_residual_connection = false, float Scale_in = 1.0f, float Scale_out = 1.0f, float Scale_in_eltwise = 1.0f, 'float[] Scale_weights = {1.0f}', bool force_fp32_output = false, int workspace_size_MB = 512, bool exhaustive_search = false] -- api : conv2d +- api : conv2d_fusion extra : attrs : [bool use_cudnn = false, bool fuse_relu_before_depthwise_conv = false, bool use_mkldnn = false, bool use_quantizer = false, str mkldnn_data_type = "float32", bool fuse_relu = false, - str fuse_activation = "", bool fuse_alpha = false, bool fuse_beta = false, bool use_addto = false, + str fuse_activation = "", float fuse_alpha = 0.0f, float fuse_beta = 0.0f, bool use_addto = false, bool fuse_residual_connection = false, float Scale_in = 1.0f, float Scale_out = 1.0f, float Scale_in_eltwise = 1.0f, 'float[] Scale_weights = {1.0f}', bool force_fp32_output = false, int workspace_size_MB = 512, bool exhaustive_search = false] @@ -48,6 +54,16 @@ outputs : out : Out +- api : depthwise_conv2d + backward : depthwise_conv2d_grad + extra : + attrs : [bool use_cudnn = false, bool fuse_relu_before_depthwise_conv = false, bool use_mkldnn = false, + bool use_quantizer = false, str mkldnn_data_type = "float32", bool fuse_relu = false, + str fuse_activation = "", float fuse_alpha = 0.0f, float fuse_beta = 0.0f, bool use_addto = false, + bool fuse_residual_connection = false, float Scale_in = 1.0f, float Scale_out = 1.0f, + float Scale_in_eltwise = 1.0f, 'float[] Scale_weights = {1.0f}', bool force_fp32_output = false, + int workspace_size_MB = 512, bool exhaustive_search = false] + - api : diag op_name : diag_v2 grad_op_name : diag_v2_grad diff --git a/paddle/phi/api/yaml/generator/ops_extra_info_gen.py b/paddle/phi/api/yaml/generator/ops_extra_info_gen.py index ef5afbf595b961578f4fcc4c09b337e4437e49d1..675b889e4450cf14fa4eea736e390fe7ca688756 100644 --- a/paddle/phi/api/yaml/generator/ops_extra_info_gen.py +++ b/paddle/phi/api/yaml/generator/ops_extra_info_gen.py @@ -18,17 +18,27 @@ import re import argparse -def map_code_template(attrs_str): - return f""" -#include "paddle/fluid/framework/attribute.h" +def map_code_template(attrs_str, attrs_checker_str): + return f"""// This file is generated by paddle/phi/api/yaml/generator/ops_extra_info_gen.py +#include "paddle/fluid/operators/ops_extra_info.h" + +#include "paddle/fluid/platform/cudnn_workspace_helper.h" namespace paddle {{ -const static std::unordered_map extra_attrs_map = {{ -{attrs_str} -}}; +namespace operators {{ -}} // namespace paddle +ExtraInfoUtils::ExtraInfoUtils() {{ + g_extra_attrs_map_ = {{ + {attrs_str} + }}; + + g_extra_attrs_checker_ = {{ + {attrs_checker_str} + }}; +}} +}} // namespace operators +}} // namespace paddle """ @@ -61,6 +71,7 @@ def generate_extra_info(api_compat_yaml_path, ops_extra_info_path): compat_apis = yaml.safe_load(f) extra_map_str_list = [] + extra_checker_str_list = [] for api_compat_args in compat_apis: if 'extra' in api_compat_args: @@ -68,8 +79,12 @@ def generate_extra_info(api_compat_yaml_path, ops_extra_info_path): # TODO(chenweihang): add inputs and outputs if 'attrs' in extra_args_map: attr_map_list = [] + attr_checker_func_list = [] for attr in extra_args_map['attrs']: attr_type, attr_name, default_val = parse_attr(attr) + attr_checker_func_list.append( + f"[](framework::AttributeMap* attr_map, bool only_check_exist_value)-> void {{ ExtraAttrChecker<{attr_type}>(\"{attr_name}\", {default_val})(attr_map, only_check_exist_value);}}" + ) if attr_type.startswith("std::vector"): attr_map_list.append( f"{{\"{attr_name}\", {attr_type}{default_val}}}") @@ -78,12 +93,26 @@ def generate_extra_info(api_compat_yaml_path, ops_extra_info_path): f"{{\"{attr_name}\", {attr_type}{{{default_val}}}}}" ) api_extra_attr_map = ", ".join(attr_map_list) + api_extra_attr_checkers = ",\n ".join( + attr_checker_func_list) extra_map_str_list.append( f"{{\"{api_compat_args['api']}\", {{ {api_extra_attr_map} }}}}" ) + extra_checker_str_list.append( + f"{{\"{api_compat_args['api']}\", {{ {api_extra_attr_checkers} }}}}" + ) + if 'backward' in api_compat_args: + extra_map_str_list.append( + f"{{\"{api_compat_args['backward']}\", {{ {api_extra_attr_map} }}}}" + ) + extra_checker_str_list.append( + f"{{\"{api_compat_args['backward']}\", {{ {api_extra_attr_checkers} }}}}" + ) ops_extra_info_file = open(ops_extra_info_path, 'w') - ops_extra_info_file.write(map_code_template(",\n".join(extra_map_str_list))) + ops_extra_info_file.write( + map_code_template(",\n ".join(extra_map_str_list), + ",\n ".join(extra_checker_str_list))) ops_extra_info_file.close() @@ -96,7 +125,7 @@ def main(): parser.add_argument('--ops_extra_info_path', help='output of generated extra_prama_info code file', - default='paddle/fluid/operators/ops_extra_info.h') + default='paddle/fluid/operators/ops_extra_info.cc') options = parser.parse_args() diff --git a/paddle/phi/core/attribute.h b/paddle/phi/core/attribute.h index d1b2920335576a02055da867db91c56ac88727b2..d8d684b9030e9dc6ac062b930d92add7daa74745 100644 --- a/paddle/phi/core/attribute.h +++ b/paddle/phi/core/attribute.h @@ -21,6 +21,7 @@ #include "paddle/phi/common/int_array.h" #include "paddle/phi/common/layout.h" #include "paddle/phi/common/scalar.h" +#include "paddle/utils/flat_hash_map.h" #include "paddle/utils/variant.h" namespace phi { @@ -47,4 +48,6 @@ using Attribute = paddle::variant; +using RuntimeAttrs = paddle::flat_hash_map; + } // namespace phi diff --git a/paddle/phi/core/kernel_context.h b/paddle/phi/core/kernel_context.h index 830443fca8fbacbe80dcc05aedd9ca28bc314921..7b79138fe78a3bee7e2a145107f2c797e0b6b364 100644 --- a/paddle/phi/core/kernel_context.h +++ b/paddle/phi/core/kernel_context.h @@ -138,6 +138,8 @@ class KernelContext { template const AttrType& AttrAt(size_t idx) const; + const RuntimeAttrs& GetRuntimeAttrs() const { return runtime_attrs_; } + size_t InputsSize() const { return inputs_.size(); } size_t OutputsSize() const { return outputs_.size(); } size_t AttrsSize() const { return attrs_.size(); } @@ -152,6 +154,8 @@ class KernelContext { paddle::small_vector, kInputSmallVectorSize> input_range_; paddle::small_vector, kOutputSmallVectorSize> output_range_; + + RuntimeAttrs runtime_attrs_; }; } // namespace phi diff --git a/paddle/phi/core/kernel_registry.h b/paddle/phi/core/kernel_registry.h index 92cc3f511f66fe49cee44e04a58877c7b1ceefe7..1cba62a86ef01d63b2fc7bee005057cdda1fae2c 100644 --- a/paddle/phi/core/kernel_registry.h +++ b/paddle/phi/core/kernel_registry.h @@ -205,6 +205,8 @@ struct KernelArgsParseFunctor { args_def->AppendAttribute(AttributeType::DATA_LAYOUT); } else if (arg_type == std::type_index(typeid(Place))) { args_def->AppendAttribute(AttributeType::PLACE); + } else if (arg_type == std::type_index(typeid(RuntimeAttrs))) { + // do nothing } else { PADDLE_THROW(phi::errors::Unavailable( "Unsupported kernel argument type `%s`.", arg_type.name())); diff --git a/paddle/phi/core/kernel_utils.h b/paddle/phi/core/kernel_utils.h index 9206acfd515420181f3036b5617c094b600102c2..df850389ff453ba324d5e3fb751533109c11df20 100644 --- a/paddle/phi/core/kernel_utils.h +++ b/paddle/phi/core/kernel_utils.h @@ -321,6 +321,22 @@ struct KernelImpl { PD_SPECIALIZE_KernelCallHelper_FOR_OUTPUT(StringTensor); PD_SPECIALIZE_KernelCallHelper_FOR_MULTI_OUTPUT(StringTensor); + + template + struct KernelCallHelper { + template + static void Compute(KernelContext* ctx, PreviousArgs&... pargs) { + const auto& runtime_attrs = ctx->GetRuntimeAttrs(); + KernelCallHelper:: + template Compute( + ctx, pargs..., runtime_attrs); + } + }; + /* End case */ template struct KernelCallHelper> { diff --git a/python/paddle/fluid/framework.py b/python/paddle/fluid/framework.py index 2ed95eed2689ac99a25087cddc6473d899605e23..409eb020d39601b3a4ee81e880fafcf236c2d0ce 100644 --- a/python/paddle/fluid/framework.py +++ b/python/paddle/fluid/framework.py @@ -2840,6 +2840,7 @@ class Operator(object): arg.op = self self.desc.set_output(out_proto.name, out_arg_names) + extra_attrs_map = core.get_op_extra_attrs(type) if op_attrs is not None: if not isinstance(op_attrs, dict): raise TypeError("'attrs' should be a dict.") @@ -2850,6 +2851,13 @@ class Operator(object): continue attr_val = op_attrs[attr_name] self._update_desc_attr(attr_name, attr_val) + for attr_name in extra_attrs_map.keys(): + if (attr_name + not in op_attrs) or (op_attrs[attr_name] is None): + self._update_desc_attr(attr_name, + extra_attrs_map[attr_name]) + else: + self._update_desc_attr(attr_name, op_attrs[attr_name]) # proto.attrs doesn't include ipu_index if core.is_compiled_with_ipu(): @@ -5821,17 +5829,29 @@ class Program(object): ] res._sync_with_cpp() + # Note: The op_role and op_role_var cann't be deleted currently, + # and we will try to remove them in the future. + common_clipped_attrs_list = ['op_namescope', 'op_callstack'] + for i in six.moves.range(res.desc.num_blocks()): block = res.desc.block(i) for var in block.all_vars(): var.clear_is_parameter() var.clear_stop_gradient() - if not clip_extra: - continue for op_idx in range(0, block.op_size()): op = block.op(op_idx) if op.type() not in OpProtoHolder.instance().op_proto_map: continue + + if not clip_extra: + continue + + extra_attrs_map = core.get_op_extra_attrs(op.type()) + for name in op.attr_names(): + if name in extra_attrs_map: + op.remove_attr(name) + continue + proto = OpProtoHolder.instance().get_op_proto(op.type()) remove_input_list = [] for name in op.input_names(): @@ -5845,8 +5865,9 @@ class Program(object): break if not find: remove_input_list.append(name) - for name in remove_input_list: - op.remove_input(name) + # The extra input of op will be removed in the future + # for name in remove_input_list: + # op.remove_input(name) remove_output_list = [] for name in op.output_names(): @@ -5860,10 +5881,10 @@ class Program(object): break if not find: remove_output_list.append(name) - for name in remove_output_list: - op.remove_output(name) + # The extra input of op will be removed in the future + # for name in remove_output_list: + # op.remove_output(name) - remove_attr_list = [] op_quant_name = core.op_proto_and_checker_maker.kOpWithQuantAttrName( ) quant = bool(op.attr(op_quant_name) @@ -5873,18 +5894,21 @@ class Program(object): "activation_bits", "bit_length", "quantize_weight_bits", "weight_quant_scale" ] + remove_attr_list = [] for name in op.attr_names(): if quant: if name in quant_attrs: continue if name.endswith("_threshold"): continue + if name in common_clipped_attrs_list: + remove_attr_list.append(name) + continue + find = False for attr_proto in proto.attrs: if attr_proto.name != name: continue - if attr_proto.extra: - remove_attr_list.append(name) find = True break if not find: diff --git a/python/paddle/fluid/op.py b/python/paddle/fluid/op.py index 4581248d06ac230f4225b590c524867ba89d4f16..ca92ac44128953cf2839c2159c7df53a309ac252 100644 --- a/python/paddle/fluid/op.py +++ b/python/paddle/fluid/op.py @@ -52,6 +52,7 @@ class OpDescCreationMethod(object): raise TypeError( "Type of op_proto should be OpProto in PaddlePaddle.") self.__op_proto__ = op_proto + self.__extra_attrs__ = core.get_op_extra_attrs(op_proto.type) def __call__(self, *args, **kwargs): """ @@ -130,6 +131,40 @@ class OpDescCreationMethod(object): raise NotImplementedError( "A not supported attribute type: %s." % (str(attr.type))) + for attr_name, defalut_val in self.__extra_attrs__.items(): + user_defined_attr = kwargs.get(attr_name, None) + if user_defined_attr is not None: + attr_type = int( + core.get_attrtibute_type(op_desc.type, attr_name)) + new_attr = op_desc.attrs.add() + new_attr.name = attr_name + new_attr.type = attr_type + if isinstance(user_defined_attr, np.ndarray): + user_defined_attr = user_defined_attr.tolist() + if attr_type == framework_pb2.INT: + new_attr.i = user_defined_attr + elif attr_type == framework_pb2.FLOAT: + new_attr.f = user_defined_attr + elif attr_type == framework_pb2.LONG: + new_attr.l = user_defined_attr + elif attr_type == framework_pb2.STRING: + new_attr.s = user_defined_attr + elif attr_type == framework_pb2.BOOLEAN: + new_attr.b = user_defined_attr + elif attr_type == framework_pb2.INTS: + new_attr.ints.extend(user_defined_attr) + elif attr_type == framework_pb2.FLOATS: + new_attr.floats.extend(user_defined_attr) + elif attr_type == framework_pb2.STRINGS: + new_attr.strings.extend(user_defined_attr) + elif attr_type == framework_pb2.BOOLEANS: + new_attr.bools.extend(user_defined_attr) + elif attr_type == framework_pb2.LONGS: + new_attr.longs.extend(user_defined_attr) + else: + raise NotImplementedError( + "A not supported attribute type: %s." % + (str(attr_type))) return op_desc @@ -147,12 +182,13 @@ class OpDescCreationMethod(object): class OpInfo(object): - def __init__(self, name, method, inputs, outputs, attrs): + def __init__(self, name, method, inputs, outputs, attrs, extra_attrs): self.name = name self.method = method self.inputs = inputs self.outputs = outputs self.attrs = attrs + self.extra_attrs = extra_attrs def create_op_creation_method(op_proto): @@ -165,13 +201,16 @@ def create_op_creation_method(op_proto): opdesc = method(*args, **kwargs) return core.Operator.create(opdesc.SerializeToString()) + extra_attrs_map = core.get_op_extra_attrs(op_proto.type) + return OpInfo(method=__impl__, name=op_proto.type, inputs=[(var.name, var.duplicable) for var in op_proto.inputs], outputs=[(var.name, var.duplicable) for var in op_proto.outputs], - attrs=[attr.name for attr in op_proto.attrs]) + attrs=[attr.name for attr in op_proto.attrs], + extra_attrs=[item for item in extra_attrs_map.keys()]) class OperatorFactory(object): @@ -222,6 +261,9 @@ class OperatorFactory(object): def get_op_attr_names(self, type): return self.get_op_info(type).attrs + def get_op_extra_attr_names(self, type): + return self.get_op_info(type).extra_attrs + class __RecurrentOp__(object): __proto__ = None diff --git a/python/paddle/fluid/tests/unittests/testsuite.py b/python/paddle/fluid/tests/unittests/testsuite.py index e106f33c8a0680b0e2b291d121fff385c2c6f468..c3d271e018b6403fa23d7babc8e00f852ac8ec9f 100644 --- a/python/paddle/fluid/tests/unittests/testsuite.py +++ b/python/paddle/fluid/tests/unittests/testsuite.py @@ -64,6 +64,10 @@ def create_op(scope, op_type, inputs, outputs, attrs, cache_list=None): if attr_name in attrs: kwargs[attr_name] = attrs[attr_name] + for extra_attr_name in Operator.get_op_extra_attr_names(op_type): + if extra_attr_name in attrs: + kwargs[extra_attr_name] = attrs[extra_attr_name] + return Operator(op_type, **kwargs) diff --git a/tools/check_file_diff_approvals.sh b/tools/check_file_diff_approvals.sh index 6b0f53ac5236941489af5db18ac93839a5051575..976dbe1dc3598e3d0672a1f28ad69ef251963151 100644 --- a/tools/check_file_diff_approvals.sh +++ b/tools/check_file_diff_approvals.sh @@ -240,7 +240,7 @@ if [ "${HAS_MODIFIED_DECLARATIONS}" != "" ] && [ "${GIT_PR_ID}" != "" ]; then HAS_MODIFIED_API_COMPAT_YAML=`git diff --name-only upstream/$BRANCH | grep "paddle/phi/api/yaml/api_compat.yaml" || true` if [ "${HAS_MODIFIED_API_COMPAT_YAML}" != "" ] && [ "${GIT_PR_ID}" != "" ]; then - echo_line="You must be approved by chenwhql or zyfncg for paddle/phi/api/yaml/api_compat.yaml, which manages the extra params of Op and name mapping between Yaml and OpMaker. In order to ensure compatibility of framework, this file isn't allowed to be modified at will!\n" + echo_line="You must be approved by chenwhql or zyfncg for paddle/phi/api/yaml/api_compat.yaml changes, which manages the extra params of Op and name mapping between Yaml and OpMaker. In order to ensure compatibility of framework, this file isn't allowed to be modified at will!\n" check_approval 1 chenwhql zyfncg fi