Unverified commit fe321f9a, authored by zyfncg and committed by GitHub

Remove extra attribute in OpMaker (#44310)

* add runtime config in phi

* add runtime attr for op desc and op

* fix no proto error

* adjust opdesc set_attr impl

* try to remove conv_op extra attrs

* add init runtime attr map

* change extra header path

* fix runtime_attr

* fix trace_op

* fix bug of pass

* fix merge conflict

* fix dygraph attrs

* fix bug of pass

* fix dygraph bug

* fix unittest module

* delete extra attr default

* fix dropout kernel

* polish code

* fix extra output of instance_norm

* fix merge conflict

* fix op_desc bug

* add extra attr in yaml for conv3d_transpose

* don't remove extra input and output

* fix save_inference_model

* fix bug of batch_norm

* revert some change

* polish log

* polish code

* add code comment
Co-authored-by: Chen Weihang <chenweihang@baidu.com>
Parent a7c4facb
@@ -5,7 +5,7 @@ paddle/fluid/API_PR.spec
 paddle/fluid/eager/api/generated/*
 paddle/fluid/op_use_default_grad_maker_DEV.spec
 paddle/fluid/op_use_default_grad_maker_PR.spec
-paddle/fluid/operators/ops_extra_info.h
+paddle/fluid/operators/ops_extra_info.cc
 paddle/phi/api/backward/backward_api.h
 paddle/phi/api/backward/sparse_bw_api.h
 paddle/phi/api/include/api.h
......
@@ -348,7 +348,7 @@ cc_test(
 cc_library(
   op_proto_maker
   SRCS op_proto_maker.cc
-  DEPS framework_proto attribute glog)
+  DEPS framework_proto attribute ops_extra_info glog)
 cc_test(
   op_proto_maker_test
   SRCS op_proto_maker_test.cc
@@ -483,6 +483,7 @@ cc_library(
   proto_desc
   SRCS var_desc.cc op_desc.cc block_desc.cc program_desc.cc
   DEPS attribute
+       ops_extra_info
       shape_inference
       op_info
       operator
@@ -498,7 +499,7 @@ endif()
 cc_library(
   op_registry
   SRCS op_registry.cc
-  DEPS op_proto_maker op_info operator glog proto_desc)
+  DEPS op_proto_maker op_info operator ops_extra_info glog proto_desc)
 cc_library(
   op_call_stack
......
@@ -325,10 +325,13 @@ class OpAttrChecker {
     explicit_checker_num_ = attr_checkers_.size();
   }
 
-  void InitDefaultAttributeMap() {
+  void InitDefaultAttributeMap(const AttributeMap* extra_attr_map) {
     for (const auto& checker : attr_checkers_) {
       checker(&default_attrs_, true, false);
     }
+    if (extra_attr_map) {
+      default_attrs_.insert(extra_attr_map->begin(), extra_attr_map->end());
+    }
   }
 
   const AttributeMap& GetDefaultAttrMap() const { return default_attrs_; }
......
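Editorial note: the key property `InitDefaultAttributeMap` relies on here is that `std::unordered_map::insert` never overwrites an existing key, so an explicitly registered default always wins over an extra-attribute default. A minimal standalone sketch of that behavior (plain `int` attributes as stand-ins, not Paddle types):

```cpp
// Sketch only: insert() skips keys that already exist, so explicit
// defaults are kept and only missing extra attributes are merged in.
#include <cassert>
#include <string>
#include <unordered_map>

int main() {
  using AttributeMap = std::unordered_map<std::string, int>;
  AttributeMap default_attrs{{"workspace_size_MB", 512}};  // explicit default
  AttributeMap extra_attrs{{"workspace_size_MB", 4096}, {"use_mkldnn", 0}};

  // Same call pattern as InitDefaultAttributeMap above.
  default_attrs.insert(extra_attrs.begin(), extra_attrs.end());

  assert(default_attrs["workspace_size_MB"] == 512);  // kept, not overwritten
  assert(default_attrs["use_mkldnn"] == 0);           // merged in
  return 0;
}
```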
@@ -207,9 +207,9 @@ struct OpInfoFiller<T, kOpProtoAndCheckerMaker> {
             "OpAttrChecker of %s has been registered.", op_type));
     info->proto_ = new proto::OpProto;
     info->checker_ = new OpAttrChecker();
+    info->proto_->set_type(op_type);
     T maker;
     maker(info->proto_, info->checker_);
-    info->proto_->set_type(op_type);
     PADDLE_ENFORCE_EQ(
         info->proto_->IsInitialized(),
         true,
......
@@ -161,6 +161,10 @@ class GradOpDescMakerBase {
     return fwd_op_.GetAttrMap();
   }
 
+  const std::unordered_map<std::string, Attribute>& RuntimeAttrs() const {
+    return fwd_op_.GetRuntimeAttrMap();
+  }
+
   const Attribute& GetAttr(const std::string& name) const {
     auto& map = fwd_op_.GetAttrMap();
     auto it = map.find(name);
@@ -209,6 +213,7 @@ class SingleGradOpMaker<OpDesc> : public GradOpDescMakerBase {
       retv.emplace_back(new OpDesc());
       try {
         this->Apply(retv.front().get());
+        retv.front()->SetRuntimeAttrMap(this->RuntimeAttrs());
       } catch (platform::EnforceNotMet& exception) {
         framework::AppendErrorOpHint(retv.front().get()->Type(), &exception);
         throw std::move(exception);
......
@@ -298,6 +298,9 @@ void ConvBiasFusePass::ApplyImpl(ir::Graph* graph) const {
     for (auto& attr : conv->Op()->GetAttrMap()) {
       desc.SetAttr(attr.first, attr.second);
     }
+    for (auto& attr : conv->Op()->GetRuntimeAttrMap()) {
+      desc.SetAttr(attr.first, attr.second);
+    }
     auto conv_bias_node = g->CreateOpNode(&desc);
     IR_NODE_LINK_TO(subgraph.at(conv_input), conv_bias_node);
......
@@ -18,6 +18,7 @@ set(STANDALONE_EXECUTOR_DEPS
     scope
     framework_proto
     data_feed_proto
+    ops_extra_info
     heter_service_proto
     trainer_desc_proto
     glog
......
@@ -21,6 +21,7 @@
 #include "paddle/fluid/operators/controlflow/conditional_block_op_helper.h"
 #include "paddle/fluid/operators/controlflow/recurrent_op_helper.h"
 #include "paddle/fluid/operators/controlflow/while_op_helper.h"
+#include "paddle/fluid/operators/ops_extra_info.h"
 #include "paddle/phi/core/kernel_context.h"
 #include "paddle/phi/core/kernel_factory.h"
@@ -242,20 +243,30 @@ void build_variable_scope(const framework::BlockDesc& block,
 void create_all_ops(const framework::BlockDesc& block,
                     std::vector<std::unique_ptr<OperatorBase>>* ops) {
   for (auto& op : block.AllOps()) {
-    VLOG(3) << "CreateOp from : " << op->Type();
+    auto op_type = op->Type();
+    VLOG(1) << "CreateOp from : " << op_type;
 
-    auto& info = OpInfoMap::Instance().Get(op->Type());
+    auto& info = OpInfoMap::Instance().Get(op_type);
 
     const VariableNameMap& inputs_names = op->Inputs();
     const VariableNameMap& outputs_names = op->Outputs();
 
     AttributeMap op_attr_map = op->GetAttrMap();
+    AttributeMap op_runtime_attr_map = op->GetRuntimeAttrMap();
 
     if (info.Checker() != nullptr) {
       info.Checker()->Check(&op_attr_map);
     }
+
+    const auto& extra_attr_checkers =
+        operators::ExtraInfoUtils::Instance().GetExtraAttrsChecker(op_type);
+    for (const auto& checker : extra_attr_checkers) {
+      checker(&op_runtime_attr_map, false);
+    }
+
     auto op_base =
-        info.Creator()(op->Type(), inputs_names, outputs_names, op_attr_map);
+        info.Creator()(op_type, inputs_names, outputs_names, op_attr_map);
+    op_base->SetRuntimeAttributeMap(op_runtime_attr_map);
 #ifdef PADDLE_WITH_MKLDNN
     if (FLAGS_use_mkldnn) {
......
@@ -23,6 +23,7 @@ limitations under the License. */
 #include "paddle/fluid/framework/operator.h"
 #include "paddle/fluid/framework/shape_inference.h"
 #include "paddle/fluid/framework/var_type_inference.h"
+#include "paddle/fluid/operators/ops_extra_info.h"
 #include "paddle/utils/blank.h"
 
 namespace paddle {
@@ -409,6 +410,13 @@ class CompileTimeInferShapeContext : public InferShapeContext {
   const BlockDesc &block_;
 };
 
+static void InitRuntimeAttributeMapByOpExtraInfo(const std::string &op_type,
+                                                 AttributeMap *runtime_attrs) {
+  const auto &extra_attr_map =
+      operators::ExtraInfoUtils::Instance().GetExtraAttrsMap(op_type);
+  runtime_attrs->insert(extra_attr_map.begin(), extra_attr_map.end());
+}
+
 OpDesc::OpDesc(const std::string &type,
                const VariableNameMap &inputs,
                const VariableNameMap &outputs,
@@ -419,6 +427,7 @@ OpDesc::OpDesc(const std::string &type,
   attrs_ = attrs;
   need_update_ = true;
   block_ = nullptr;
+  InitRuntimeAttributeMapByOpExtraInfo(type, &runtime_attrs_);
 }
 
 OpDesc::OpDesc(const OpDesc &other) {
@@ -441,6 +450,8 @@ void OpDesc::CopyFrom(const OpDesc &op_desc) {
   inputs_ = op_desc.inputs_;
   outputs_ = op_desc.outputs_;
   attrs_ = op_desc.attrs_;
+  runtime_attrs_ = op_desc.runtime_attrs_;
+  // The record of original_id_ is only for auto parallel.
   original_id_ = op_desc.original_id_;
   if (op_desc.dist_attr_) {
     dist_attr_.reset(new OperatorDistAttr(*op_desc.dist_attr_));
@@ -473,8 +484,9 @@ OpDesc::OpDesc(const proto::OpDesc &desc, BlockDesc *block)
     }
   }
   // restore attrs_
+  InitRuntimeAttributeMapByOpExtraInfo(desc.type(), &runtime_attrs_);
   for (const proto::OpDesc::Attr &attr : desc_.attrs()) {
-    std::string attr_name = attr.name();
+    const std::string &attr_name = attr.name();
     // The sub_block referred to by the BLOCK attr hasn't been added
     // to ProgramDesc class yet, we skip setting BLOCK/BLOCKS/VAR/VARS attr
     // here.
@@ -483,7 +495,12 @@ OpDesc::OpDesc(const proto::OpDesc &desc, BlockDesc *block)
         attr_type != proto::AttrType::BLOCKS &&
         attr_type != proto::AttrType::VAR &&
         attr_type != proto::AttrType::VARS) {
-      attrs_[attr_name] = GetAttrValue(attr);
+      auto iter = runtime_attrs_.find(attr_name);
+      if (iter == runtime_attrs_.end()) {
+        attrs_[attr_name] = GetAttrValue(attr);
+      } else {
+        iter->second = GetAttrValue(attr);
+      }
     }
   }
   this->block_ = block;
@@ -622,7 +639,13 @@ std::vector<std::string> OpDesc::AttrNames(bool with_attr_var) const {
 bool OpDesc::HasAttr(const std::string &name, bool with_attr_var) const {
   auto iter = attrs_.find(name);
-  bool is_found = iter != attrs_.end();
+  bool is_found = true;
+  if (iter == attrs_.end()) {
+    iter = runtime_attrs_.find(name);
+    if (iter == runtime_attrs_.end()) {
+      is_found = false;
+    }
+  }
   if (with_attr_var) {
     return is_found;
   }
@@ -631,10 +654,19 @@ bool OpDesc::HasAttr(const std::string &name, bool with_attr_var) const {
 void OpDesc::RemoveAttr(const std::string &name) {
   attrs_.erase(name);
+  runtime_attrs_.erase(name);
   need_update_ = true;
 }
 
 void OpDesc::SetAttr(const std::string &name, const Attribute &v) {
+  AttributeMap *attrs_ptr = &(this->attrs_);
+  const auto &extra_attr_map =
+      operators::ExtraInfoUtils::Instance().GetExtraAttrsMap(Type());
+  auto extra_attr_iter = extra_attr_map.find(name);
+  if (extra_attr_iter != extra_attr_map.end()) {
+    attrs_ptr = &(this->runtime_attrs_);
+  }
   // NOTICE(minqiyang): pybind11 will take the empty list in python as
   // the std::vector<int> type in C++; so we have to change the attr's type
   // here if we meet this issue
@@ -647,25 +679,25 @@ void OpDesc::SetAttr(const std::string &name, const Attribute &v) {
       case proto::AttrType::BOOLEANS: {
         VLOG(11) << "SetAttr: " << Type() << ", " << name
                  << " from INTS to BOOLEANS";
-        this->attrs_[name] = std::vector<bool>();
+        attrs_ptr->operator[](name) = std::vector<bool>();
         break;
       }
       case proto::AttrType::INTS: {
         VLOG(11) << "SetAttr: " << Type() << ", " << name
                  << " from INTS to INTS";
-        this->attrs_[name] = std::vector<int>();
+        attrs_ptr->operator[](name) = std::vector<int>();
         break;
       }
       case proto::AttrType::LONGS: {
         VLOG(11) << "SetAttr: " << Type() << ", " << name
                  << " from LONGS to LONGS";
-        this->attrs_[name] = std::vector<int64_t>();
+        attrs_ptr->operator[](name) = std::vector<int64_t>();
        break;
       }
       case proto::AttrType::FLOATS: {
         VLOG(11) << "SetAttr: " << Type() << ", " << name
                  << " from INTS to FLOATS";
-        this->attrs_[name] = std::vector<float>();
+        attrs_ptr->operator[](name) = std::vector<float>();
         break;
       }
       case proto::AttrType::FLOAT64S: {
@@ -677,13 +709,13 @@ void OpDesc::SetAttr(const std::string &name, const Attribute &v) {
       case proto::AttrType::STRINGS: {
         VLOG(11) << "SetAttr: " << Type() << ", " << name
                  << " from INTS to STRINGS";
-        this->attrs_[name] = std::vector<std::string>();
+        attrs_ptr->operator[](name) = std::vector<std::string>();
         break;
       }
       case proto::AttrType::BLOCKS: {
         VLOG(11) << "SetAttr: " << Type() << ", " << name
                  << " from INTS to BLOCKS";
-        this->SetBlocksAttr(name, std::vector<BlockDesc *>());
+        attrs_ptr->operator[](name) = std::vector<BlockDesc *>();
         return;
       }
       default:
@@ -695,14 +727,23 @@ void OpDesc::SetAttr(const std::string &name, const Attribute &v) {
   }
 
   // In order to set bool attr properly
-  if (attr_type == proto::AttrType::INT && HasProtoAttr(name) &&
-      GetProtoAttr(name).type() == proto::AttrType::BOOLEAN) {
-    this->attrs_[name] = static_cast<bool>(PADDLE_GET_CONST(int, v));
-    need_update_ = true;
-    return;
+  if (attr_type == proto::AttrType::INT) {
+    if (HasProtoAttr(name) &&
+        GetProtoAttr(name).type() == proto::AttrType::BOOLEAN) {
+      attrs_ptr->operator[](name) = static_cast<bool>(PADDLE_GET_CONST(int, v));
+      need_update_ = true;
+      return;
+    }
+    if (extra_attr_iter != extra_attr_map.end() &&
+        static_cast<proto::AttrType>(extra_attr_iter->second.index() - 1) ==
+            proto::AttrType::BOOLEAN) {
+      attrs_ptr->operator[](name) = static_cast<bool>(PADDLE_GET_CONST(int, v));
+      need_update_ = true;
+      return;
+    }
   }
 
-  this->attrs_[name] = v;
+  attrs_ptr->operator[](name) = v;
   need_update_ = true;
 }
@@ -733,8 +774,17 @@ void OpDesc::SetAttrMap(
   need_update_ = true;
 }
 
+void OpDesc::SetRuntimeAttrMap(
+    const std::unordered_map<std::string, Attribute> &attr_map) {
+  runtime_attrs_ = attr_map;
+  need_update_ = true;
+}
+
 Attribute OpDesc::GetAttr(const std::string &name, bool with_attr_var) const {
   auto it = attrs_.find(name);
+  if (it == attrs_.end()) {
+    it = runtime_attrs_.find(name);
+  }
   PADDLE_ENFORCE_NE(
       it,
       attrs_.end(),
@@ -802,6 +852,8 @@ const std::unordered_map<std::string, Attribute> &OpDesc::GetAttrMap() const {
   return attrs_;
 }
 
+const AttributeMap &OpDesc::GetRuntimeAttrMap() const { return runtime_attrs_; }
+
 void OpDesc::Rename(const std::string &old_name, const std::string &new_name) {
   RenameInput(old_name, new_name);
   RenameOutput(old_name, new_name);
@@ -925,6 +977,15 @@ void OpDesc::Flush() {
     }
 
     this->desc_.mutable_attrs()->Clear();
+    auto set_attr_desc = [this](const std::string &attr_name,
+                                const Attribute &attr) -> void {
+      auto *attr_desc = desc_.add_attrs();
+      attr_desc->set_name(attr_name);
+      attr_desc->set_type(static_cast<proto::AttrType>(attr.index() - 1));
+      SetAttrDescVisitor visitor(attr_desc);
+      paddle::visit(visitor, attr);
+    };
+
     std::vector<std::pair<std::string, Attribute>> sorted_attrs{attrs_.begin(),
                                                                 attrs_.end()};
     std::sort(
@@ -932,13 +993,12 @@ void OpDesc::Flush() {
         sorted_attrs.end(),
         [](std::pair<std::string, Attribute> a,
           std::pair<std::string, Attribute> b) { return a.first < b.first; });
+
     for (auto &attr : sorted_attrs) {
-      auto *attr_desc = desc_.add_attrs();
-      attr_desc->set_name(attr.first);
-      attr_desc->set_type(
-          static_cast<proto::AttrType>(attr.second.index() - 1));
-      SetAttrDescVisitor visitor(attr_desc);
-      paddle::visit(visitor, attr.second);
+      set_attr_desc(attr.first, attr.second);
+    }
+    for (auto &attr : runtime_attrs_) {
+      set_attr_desc(attr.first, attr.second);
     }
 
     need_update_ = false;
@@ -1155,7 +1215,7 @@ bool CompileTimeInferShapeContext::HasOutputs(const std::string &name,
 }
 
 AttrReader CompileTimeInferShapeContext::Attrs() const {
-  return AttrReader(op_.GetAttrMap());
+  return AttrReader(op_.GetAttrMap(), op_.GetRuntimeAttrMap());
 }
 
 std::vector<std::string> CompileTimeInferShapeContext::Inputs(
......
@@ -144,6 +144,10 @@ class OpDesc {
   // Only be used in C++
   void SetAttrMap(const AttributeMap &attr_map);
 
+  void SetRuntimeAttrMap(const AttributeMap &attr_map);
+
+  const AttributeMap &GetRuntimeAttrMap() const;
+
   std::vector<std::string> InputNames(bool with_attr_var = false) const {
     return MapKeys(inputs_);
   }
@@ -221,6 +225,12 @@ class OpDesc {
   VariableNameMap outputs_;
   // attribute name => all original attrs
   AttributeMap attrs_;
+  // runtime_attrs_ holds the attributes that are used for dispatching a
+  // kernel (use_mkldnn, use_cudnn, ...) or for passing extra configuration
+  // to special heterogeneous kernels (workspace_size_MB, ...).
+  // The attributes in runtime_attrs_ are set by the framework (e.g. by a
+  // PASS), not through the Python API.
+  AttributeMap runtime_attrs_;
   // need_update_ indicate there some local changes not be synchronized. If
   // local changes should be synchronized, need_update_ should be set to true.
......
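Editorial note: the pattern running through most of these hunks is a two-tier lookup — the regular attribute map is consulted first, and `runtime_attrs_` acts as a fallback. A minimal standalone sketch of that lookup (simplified stand-in types, not the real `OpDesc`):

```cpp
// Sketch only: regular attrs win; runtime (extra) attrs are the fallback.
#include <iostream>
#include <stdexcept>
#include <string>
#include <unordered_map>

using Attribute = int;  // stand-in for the paddle::variant-based Attribute
using AttributeMap = std::unordered_map<std::string, Attribute>;

const Attribute& GetAttr(const AttributeMap& attrs,
                         const AttributeMap& runtime_attrs,
                         const std::string& name) {
  auto it = attrs.find(name);
  if (it == attrs.end()) {
    it = runtime_attrs.find(name);
    if (it == runtime_attrs.end()) {
      throw std::out_of_range("attribute not found: " + name);
    }
  }
  return it->second;
}

int main() {
  AttributeMap attrs{{"groups", 1}};
  AttributeMap runtime_attrs{{"use_mkldnn", 0}};
  std::cout << GetAttr(attrs, runtime_attrs, "groups") << "\n";      // 1
  std::cout << GetAttr(attrs, runtime_attrs, "use_mkldnn") << "\n";  // 0
  return 0;
}
```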
@@ -12,8 +12,7 @@ See the License for the specific language governing permissions and
 limitations under the License. */
 
 #include "paddle/fluid/framework/op_proto_maker.h"
-
-#include <string>
+#include "paddle/fluid/operators/ops_extra_info.h"
 
 #include "paddle/fluid/platform/enforce.h"
@@ -67,7 +66,16 @@ void OpProtoAndCheckerMaker::operator()(proto::OpProto* proto,
   op_checker_ = attr_checker;
   Make();
   op_checker_->RecordExplicitCheckerNum();
-  op_checker_->InitDefaultAttributeMap();
+
+  const AttributeMap* extra_attrs_ptr = nullptr;
+  const std::string& op_type = proto->type();
+  const auto& extra_attr_map =
+      operators::ExtraInfoUtils::Instance().GetExtraAttrsMap(op_type);
+  if (!extra_attr_map.empty()) {
+    extra_attrs_ptr = &extra_attr_map;
+  }
+  op_checker_->InitDefaultAttributeMap(extra_attrs_ptr);
 
   AddAttr<int>(OpRoleAttrName(), "The role of this operator")
       .InEnum(
......
@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
 limitations under the License. */
 
 #include "paddle/fluid/framework/op_registry.h"
+#include "paddle/fluid/operators/ops_extra_info.h"
 
 #include "glog/logging.h"
 
@@ -25,15 +26,47 @@ std::unique_ptr<OperatorBase> OpRegistry::CreateOp(
     const VariableNameMap& outputs,
     const AttributeMap& attrs,
     bool attr_check) {
+  AttributeMap standard_attrs;
+  AttributeMap runtime_attrs =
+      paddle::operators::ExtraInfoUtils::Instance().GetExtraAttrsMap(type);
+  for (auto& attr : attrs) {
+    auto it = runtime_attrs.find(attr.first);
+    if (it != runtime_attrs.end()) {
+      it->second = attr.second;
+    } else {
+      standard_attrs[attr.first] = attr.second;
+    }
+  }
+  auto& info = OpInfoMap::Instance().Get(type);
+  if (attr_check && info.Checker() != nullptr) {
+    info.Checker()->Check(&standard_attrs);
+  }
+  auto op_base = std::unique_ptr<OperatorBase>(
+      info.Creator()(type, inputs, outputs, standard_attrs));
+  op_base->SetRuntimeAttributeMap(runtime_attrs);
+  return op_base;
+}
+
+std::unique_ptr<OperatorBase> OpRegistry::CreateOp(
+    const std::string& type,
+    const VariableNameMap& inputs,
+    const VariableNameMap& outputs,
+    const AttributeMap& attrs,
+    const AttributeMap& runtime_attrs,
+    bool attr_check) {
+  std::unique_ptr<OperatorBase> op_base;
   auto& info = OpInfoMap::Instance().Get(type);
   if (attr_check && info.Checker() != nullptr) {
     auto tmp_attrs = attrs;
     info.Checker()->Check(&tmp_attrs);
-    return std::unique_ptr<OperatorBase>(
-        info.Creator()(type, inputs, outputs, tmp_attrs));
+    op_base = std::unique_ptr<OperatorBase>(
+        info.Creator()(type, inputs, outputs, tmp_attrs));
+  } else {
+    op_base = std::unique_ptr<OperatorBase>(
+        info.Creator()(type, inputs, outputs, attrs));
   }
-  return std::unique_ptr<OperatorBase>(
-      info.Creator()(type, inputs, outputs, attrs));
+  op_base->SetRuntimeAttributeMap(runtime_attrs);
+  return op_base;
 }
 
 static VariableNameMap ConvertOpDescVarsToVarNameMap(
@@ -59,18 +92,27 @@ std::unique_ptr<OperatorBase> OpRegistry::CreateOp(
   VariableNameMap inputs = ConvertOpDescVarsToVarNameMap(op_desc.inputs());
   VariableNameMap outputs = ConvertOpDescVarsToVarNameMap(op_desc.outputs());
   AttributeMap attrs;
+  AttributeMap extra_attrs =
+      paddle::operators::ExtraInfoUtils::Instance().GetExtraAttrsMap(
+          op_desc.type());
   for (auto& attr : op_desc.attrs()) {
-    attrs[attr.name()] = GetAttrValue(attr);
+    auto it = extra_attrs.find(attr.name());
+    if (it != extra_attrs.end()) {
+      it->second = GetAttrValue(attr);
+    } else {
+      attrs[attr.name()] = GetAttrValue(attr);
+    }
   }
 
-  return CreateOp(op_desc.type(), inputs, outputs, attrs);
+  return CreateOp(op_desc.type(), inputs, outputs, attrs, extra_attrs);
 }
 
 std::unique_ptr<OperatorBase> OpRegistry::CreateOp(const OpDesc& op_desc) {
   return CreateOp(op_desc.Type(),
                   op_desc.Inputs(),
                   op_desc.Outputs(),
-                  op_desc.GetAttrMap());
+                  op_desc.GetAttrMap(),
+                  op_desc.GetRuntimeAttrMap());
 }
 
 }  // namespace framework
......
@@ -132,6 +132,13 @@ class OpRegistry {
       const VariableNameMap& outputs,
       const AttributeMap& attrs,
       bool attr_check = true);
+  static std::unique_ptr<OperatorBase> CreateOp(
+      const std::string& type,
+      const VariableNameMap& inputs,
+      const VariableNameMap& outputs,
+      const AttributeMap& attrs,
+      const AttributeMap& runtime_attrs,
+      bool attr_check = true);
   static std::unique_ptr<OperatorBase> CreateOp(const proto::OpDesc& op_desc);
......
@@ -1389,10 +1389,9 @@ bool OperatorWithKernel::SupportsKernelType(
 
 bool OperatorWithKernel::CanMKLDNNBeUsed(const framework::ExecutionContext& ctx,
                                          proto::VarType::Type data_type) const {
-  const auto& attrs_map = ctx.Attrs();
-  auto iter = attrs_map.find("use_mkldnn");
-  bool use_mkldnn_ctx = iter != attrs_map.end() &&
-                        PADDLE_GET_CONST(bool, iter->second) &&
+  const std::string use_mkldnn_attr = "use_mkldnn";
+  bool use_mkldnn_ctx = ctx.HasAttr(use_mkldnn_attr) &&
+                        ctx.Attr<bool>(use_mkldnn_attr) &&
                         platform::is_cpu_place(ctx.GetPlace());
   return use_mkldnn_ctx && this->SupportsMKLDNN(data_type);
 }
@@ -2881,12 +2880,16 @@ void OperatorWithKernel::BuildPhiKernelContext(
         }
       } break;
       default: {
-        PADDLE_ENFORCE_NE(
-            attr_iter,
-            Attrs().end(),
-            platform::errors::NotFound("(%s) is not found in AttributeMap when "
-                                       "buildind static KernelContext.",
-                                       attr_names[i]));
+        if (attr_iter == Attrs().end()) {
+          attr_iter = RuntimeAttrs().find(attr_names[i]);
+          PADDLE_ENFORCE_NE(attr_iter,
+                            RuntimeAttrs().end(),
+                            platform::errors::NotFound(
+                                "(%s) is not found in AttributeMap when "
+                                "building the static KernelContext.",
+                                attr_names[i]));
+        }
         switch (attr_defs[i].type_index) {
           case phi::AttributeType::FLOAT32:
             phi_kernel_context->EmplaceBackAttr(
......
@@ -177,7 +177,9 @@ class OperatorBase {
 
   const std::string& Type() const { return type_; }
 
-  bool HasAttr(const std::string& name) const { return attrs_.count(name); }
+  bool HasAttr(const std::string& name) const {
+    return attrs_.count(name) || runtime_attrs_.count(name);
+  }
 
   template <typename T>
   inline const T& Attr(const std::string& name) const {
     PADDLE_ENFORCE_NE(
@@ -196,6 +198,10 @@ class OperatorBase {
     attrs_[name] = v;
   }
   const AttributeMap& Attrs() const { return attrs_; }
+  const AttributeMap& RuntimeAttrs() const { return runtime_attrs_; }
+  void SetRuntimeAttributeMap(const AttributeMap& runtime_attrs) {
+    runtime_attrs_ = runtime_attrs;
+  }
 
   const VariableNameMap& Inputs() const { return inputs_; }
   const VariableNameMap& Outputs() const { return outputs_; }
@@ -250,6 +256,12 @@ class OperatorBase {
   // IG (Inputs Gradients)
   VariableNameMap outputs_;
   AttributeMap attrs_;
+  // NOTE: runtime_attrs_ holds the attributes that are used for dispatching
+  // a kernel (use_mkldnn, use_cudnn, ...) or for passing extra configuration
+  // to special heterogeneous kernels (workspace_size_MB, ...).
+  // The attributes in runtime_attrs_ are set by the framework (e.g. by a
+  // PASS), not through the Python API.
+  AttributeMap runtime_attrs_;
 
   // OpInfo
   const OpInfo* info_;
@@ -302,7 +314,12 @@ class ExecutionContext {
   }
 
   virtual const Attribute& GetAttr(const std::string& name) const {
-    return op_.Attrs().at(name);
+    auto iter = op_.Attrs().find(name);
+    if (iter == op_.Attrs().end()) {
+      return op_.RuntimeAttrs().at(name);
+    } else {
+      return iter->second;
+    }
   }
 
   virtual bool HasInput(const std::string& name) const;
......
@@ -73,7 +73,8 @@ cc_library(
       denormal
       garbage_collector
       var_helper
-      layout_autotune)
+      layout_autotune
+      ops_extra_info)
 cc_library(
   basic_engine
   SRCS basic_engine.cc
......
@@ -23,6 +23,7 @@
 #include "paddle/fluid/imperative/execution_context.h"
 #include "paddle/fluid/imperative/layout_autotune.h"
 #include "paddle/fluid/imperative/op_base.h"
+#include "paddle/fluid/operators/ops_extra_info.h"
 #include "paddle/fluid/platform/denormal.h"
 #include "paddle/fluid/platform/device/device_wrapper.h"
 #include "paddle/fluid/platform/profiler.h"
@@ -240,6 +241,11 @@ void Tracer::TraceOpImpl(const std::string& type,
     if (attr_checker) {
       attr_checker->Check(&attrs, true, /*only_check_exist_value=*/true);
     }
+    const auto& extra_attr_checkers =
+        operators::ExtraInfoUtils::Instance().GetExtraAttrsChecker(type);
+    for (const auto& checker : extra_attr_checkers) {
+      checker(&attrs, true);
+    }
 
     static paddle::framework::AttributeMap empty_attrs_map = {};
     const paddle::framework::AttributeMap& default_attrs =
......
@@ -149,6 +149,7 @@ if (WITH_DGC)
 endif()
 
 cc_library(common_infer_shape_functions SRCS common_infer_shape_functions.cc DEPS operator)
+cc_library(ops_extra_info SRCS ops_extra_info.cc DEPS attribute cudnn_workspace_helper)
 
 set(COMMON_OP_DEPS ${COMMON_OP_DEPS} selected_rows_functor selected_rows_utils lapack_function
 lod_tensor maxouting unpooling pooling lod_rank_table context_project
......
@@ -348,82 +348,6 @@ void Conv2DOpMaker::Make() {
                             "dilations(h_dilation, w_dilation) of "
                             "convolution operator.")
       .SetDefault({1, 1});
-  AddAttr<bool>(
-      "use_cudnn",
-      "(bool, default false) Only used in cudnn kernel, need install cudnn")
-      .SetDefault(false)
-      .AsExtra();
-  AddAttr<bool>("fuse_relu_before_depthwise_conv",
-                "(bool, default false) Only used in cuda depthwise kernel")
-      .SetDefault(false)
-      .AsExtra();
-  AddAttr<bool>("use_mkldnn",
-                "(bool, default false) Only used in mkldnn kernel")
-      .SetDefault(false)
-      .AsExtra();
-  AddAttr<bool>(
-      "use_quantizer",
-      "(bool, default false) "
-      "This parameter is no longer used. Use 'mkldnn_data_type' instead.")
-      .SetDefault(false)
-      .AsExtra();
-  AddAttr<std::string>(
-      "mkldnn_data_type",
-      "(string, default \"float32\"). Data type of mkldnn kernel")
-      .SetDefault("float32")
-      .InEnum({"float32", "int8", "bfloat16"})
-      .AsExtra();
-  AddAttr<bool>("fuse_relu", "(bool, default false) Only used in mkldnn kernel")
-      .SetDefault(false)
-      .AsExtra();
-  AddAttr<std::string>("fuse_activation",
-                       "(string, default \"\") Only used in mkldnn kernel")
-      .SetDefault("")
-      .AsExtra();
-  AddAttr<float>("fuse_alpha",
-                 "(float, default 0.0) Only used in mkldnn kernel")
-      .SetDefault(0.0f)
-      .AsExtra();
-  AddAttr<float>("fuse_beta", "(float, default 0.0) Only used in mkldnn kernel")
-      .SetDefault(0.0f)
-      .AsExtra();
-  AddAttr<bool>(
-      "use_addto",
-      "(bool, default false) If use addto strategy or not, only used in "
-      "cudnn kernel")
-      .SetDefault(false)
-      .AsExtra();
-  AddAttr<bool>("fuse_residual_connection",
-                "(bool, default false) Only used in mkldnn kernel. Used "
-                "whenever convolution output is as an input to residual "
-                "connection.")
-      .SetDefault(false)
-      .AsExtra();
-  AddAttr<float>("Scale_in",
-                 "Scale_in to be used for int8 input data."
-                 "Only used with MKL-DNN INT8.")
-      .SetDefault(1.0f)
-      .AsExtra();
-  AddAttr<float>("Scale_out",
-                 "Scale_out to be used for int8 output data."
-                 "Only used with MKL-DNN INT8.")
-      .SetDefault(1.0f)
-      .AsExtra();
-  AddAttr<float>("Scale_in_eltwise",
-                 "Scale_in_eltwise to be used for int8 eltwise input data."
-                 "Only used with MKL-DNN INT8.")
-      .SetDefault(1.0f)
-      .AsExtra();
-  AddAttr<std::vector<float>>("Scale_weights",
-                              "Scale_weights to be used for int8 weights data."
-                              "Only used with MKL-DNN INT8.")
-      .SetDefault({1.0f})
-      .AsExtra();
-  AddAttr<bool>("force_fp32_output",
-                "(bool, default false) Force INT8 kernel output FP32, only "
-                "used in MKL-DNN INT8")
-      .SetDefault(false)
-      .AsExtra();
   AddAttr<std::string>(
       "data_format",
       "(string, default NCHW) Only used in "
......
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once

#include <functional>
#include <string>
#include <unordered_map>
#include <vector>

#include "paddle/fluid/framework/attribute.h"

namespace paddle {
namespace operators {

template <typename T>
struct ExtraAttrChecker {
  ExtraAttrChecker(const std::string& attr_name, T default_value)
      : attr_name(attr_name), default_val(default_value) {}

  void operator()(framework::AttributeMap* attr_map,
                  bool only_check_exist_value) {
    auto it = attr_map->find(attr_name);
    if (it == attr_map->end()) {
      if (!only_check_exist_value) {
        attr_map->emplace(attr_name, default_val);
      }
      return;
    }
    framework::ExtractAttribute<T> extract_attr(attr_name);
    extract_attr(it->second);
  }

  const std::string& attr_name;
  T default_val;
};

class ExtraInfoUtils {
 public:
  static ExtraInfoUtils& Instance() {
    static ExtraInfoUtils extra_info_utils;
    return extra_info_utils;
  }

  const std::unordered_map<std::string, paddle::framework::AttributeMap>&
  GetAllExtraAttrsMap() const {
    return g_extra_attrs_map_;
  }

  const paddle::framework::AttributeMap& GetExtraAttrsMap(
      const std::string& op_type) const {
    auto iter = g_extra_attrs_map_.find(op_type);
    if (iter != g_extra_attrs_map_.end()) {
      return iter->second;
    }
    return empty_extra_attrs_map_;
  }

  const std::vector<std::function<void(framework::AttributeMap*, bool)>>&
  GetExtraAttrsChecker(const std::string& op_type) const {
    auto iter = g_extra_attrs_checker_.find(op_type);
    if (iter != g_extra_attrs_checker_.end()) {
      return iter->second;
    }
    return empty_extra_attrs_checker_;
  }

 private:
  ExtraInfoUtils();

  std::unordered_map<std::string, paddle::framework::AttributeMap>
      g_extra_attrs_map_;
  paddle::framework::AttributeMap empty_extra_attrs_map_{};
  std::unordered_map<
      std::string,
      std::vector<std::function<void(framework::AttributeMap*, bool)>>>
      g_extra_attrs_checker_;
  std::vector<std::function<void(framework::AttributeMap*, bool)>>
      empty_extra_attrs_checker_{};
};

}  // namespace operators
}  // namespace paddle
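Editorial note: the `only_check_exist_value` flag above selects between two modes — validate an attribute only if the user supplied it (the dygraph tracer path), or also fill in the default when it is missing (the static-graph path). A trimmed standalone re-creation of that pattern (plain `int` attributes instead of `framework::AttributeMap`, and without `ExtractAttribute` validation):

```cpp
// Sketch only: two modes of an extra-attribute checker.
#include <cassert>
#include <string>
#include <unordered_map>
#include <utility>

using AttributeMap = std::unordered_map<std::string, int>;

template <typename T>
struct ExtraAttrChecker {
  ExtraAttrChecker(std::string name, T default_value)
      : attr_name(std::move(name)), default_val(default_value) {}

  void operator()(AttributeMap* attr_map, bool only_check_exist_value) {
    auto it = attr_map->find(attr_name);
    if (it == attr_map->end()) {
      // Fill the default only when full checking was requested; the
      // existing-value mode leaves missing attributes absent.
      if (!only_check_exist_value) {
        attr_map->emplace(attr_name, default_val);
      }
      return;
    }
    // The real checker would run ExtractAttribute<T> here to
    // validate/convert the stored value.
  }

  std::string attr_name;
  T default_val;
};

int main() {
  AttributeMap attrs;
  ExtraAttrChecker<int>("workspace_size_MB", 512)(&attrs, true);
  assert(attrs.empty());  // existing-value mode adds nothing
  ExtraAttrChecker<int>("workspace_size_MB", 512)(&attrs, false);
  assert(attrs.at("workspace_size_MB") == 512);  // default filled in
  return 0;
}
```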
@@ -29,6 +29,7 @@
 #include "paddle/fluid/framework/variable.h"
 #include "paddle/fluid/imperative/tracer.h"
 #include "paddle/fluid/imperative/type_defs.h"
+#include "paddle/fluid/operators/ops_extra_info.h"
 #include "paddle/fluid/pybind/imperative.h"
 
 namespace py = pybind11;
@@ -961,6 +962,15 @@ void InitOpsAttrTypeMap() {
       OpAttrTypeMap::Instance().Map()[iter->first][attr.name()] = attr.type();
     }
   }
+  const auto& extra_attr_maps =
+      operators::ExtraInfoUtils::Instance().GetAllExtraAttrsMap();
+  for (const auto& extra_attrs : extra_attr_maps) {
+    for (auto& attr : extra_attrs.second) {
+      OpAttrTypeMap::Instance().Map()[extra_attrs.first][attr.first] =
+          static_cast<paddle::framework::proto::AttrType>(attr.second.index() -
+                                                          1);
+    }
+  }
 }
 
 ssize_t GetIdxFromCoreOpsInfoMap(
......
@@ -71,6 +71,7 @@ limitations under the License. */
 #include "paddle/fluid/memory/allocation/mmap_allocator.h"
 #include "paddle/fluid/operators/activation_op.h"
 #include "paddle/fluid/operators/common_infer_shape_functions.h"
+#include "paddle/fluid/operators/ops_extra_info.h"
 #include "paddle/fluid/operators/py_func_op.h"
 #include "paddle/fluid/platform/cpu_helper.h"
 #include "paddle/fluid/platform/cpu_info.h"
@@ -1068,6 +1069,23 @@ All parameter, weight, gradient are variables in Paddle.
         }
         return res;
       });
+  m.def(
+      "get_op_extra_attrs",
+      [](const std::string &op_type)
+          -> const paddle::framework::AttributeMap & {
+        return operators::ExtraInfoUtils::Instance().GetExtraAttrsMap(op_type);
+      });
+  m.def(
+      "get_attrtibute_type",
+      [](const std::string &op_type,
+         const std::string &attr_name) -> paddle::framework::proto::AttrType {
+        const auto &default_val =
+            operators::ExtraInfoUtils::Instance().GetExtraAttrsMap(op_type).at(
+                attr_name);
+        return static_cast<paddle::framework::proto::AttrType>(
+            default_val.index() - 1);
+      });
   m.def("get_grad_op_desc",
        [](const OpDesc &op_desc,
           const std::unordered_set<std::string> &no_grad_set,
......
@@ -100,7 +100,7 @@ set(ops_extra_info_gen_file
 set(api_compat_yaml_file
     ${CMAKE_SOURCE_DIR}/paddle/phi/api/yaml/api_compat.yaml)
 set(ops_extra_info_file
-    ${CMAKE_SOURCE_DIR}/paddle/fluid/operators/ops_extra_info.h)
+    ${CMAKE_SOURCE_DIR}/paddle/fluid/operators/ops_extra_info.cc)
 
 if(NOT PYTHONINTERP_FOUND)
   find_package(PythonInterp REQUIRED)
......
+# - api : conv3d_transpose
+#   backward : conv3d_transpose_grad
+#   extra :
+#     attrs : [bool use_cudnn = true, bool use_mkldnn = false, int workspace_size_MB = platform::GetDefaultConvWorkspaceSizeLimitMB()]
+
 - api : atan2
   inputs :
     {x : X1, y : X2}
@@ -23,19 +28,20 @@
   out : Out
 
 - api : conv2d
+  backward : conv2d_grad
   extra :
     attrs : [bool use_cudnn = false, bool fuse_relu_before_depthwise_conv = false, bool use_mkldnn = false,
             bool use_quantizer = false, str mkldnn_data_type = "float32", bool fuse_relu = false,
-            str fuse_activation = "", bool fuse_alpha = false, bool fuse_beta = false, bool use_addto = false,
+            str fuse_activation = "", float fuse_alpha = 0.0f, float fuse_beta = 0.0f, bool use_addto = false,
             bool fuse_residual_connection = false, float Scale_in = 1.0f, float Scale_out = 1.0f,
             float Scale_in_eltwise = 1.0f, 'float[] Scale_weights = {1.0f}', bool force_fp32_output = false,
             int workspace_size_MB = 512, bool exhaustive_search = false]
 
-- api : conv2d
+- api : conv2d_fusion
   extra :
     attrs : [bool use_cudnn = false, bool fuse_relu_before_depthwise_conv = false, bool use_mkldnn = false,
             bool use_quantizer = false, str mkldnn_data_type = "float32", bool fuse_relu = false,
-            str fuse_activation = "", bool fuse_alpha = false, bool fuse_beta = false, bool use_addto = false,
+            str fuse_activation = "", float fuse_alpha = 0.0f, float fuse_beta = 0.0f, bool use_addto = false,
             bool fuse_residual_connection = false, float Scale_in = 1.0f, float Scale_out = 1.0f,
             float Scale_in_eltwise = 1.0f, 'float[] Scale_weights = {1.0f}', bool force_fp32_output = false,
             int workspace_size_MB = 512, bool exhaustive_search = false]
@@ -48,6 +54,16 @@
   outputs :
     out : Out
 
+- api : depthwise_conv2d
+  backward : depthwise_conv2d_grad
+  extra :
+    attrs : [bool use_cudnn = false, bool fuse_relu_before_depthwise_conv = false, bool use_mkldnn = false,
+            bool use_quantizer = false, str mkldnn_data_type = "float32", bool fuse_relu = false,
+            str fuse_activation = "", float fuse_alpha = 0.0f, float fuse_beta = 0.0f, bool use_addto = false,
+            bool fuse_residual_connection = false, float Scale_in = 1.0f, float Scale_out = 1.0f,
+            float Scale_in_eltwise = 1.0f, 'float[] Scale_weights = {1.0f}', bool force_fp32_output = false,
+            int workspace_size_MB = 512, bool exhaustive_search = false]
+
 - api : diag
   op_name : diag_v2
   grad_op_name : diag_v2_grad
......
@@ -18,17 +18,27 @@ import re
 import argparse
 
 
-def map_code_template(attrs_str):
-    return f"""
-#include "paddle/fluid/framework/attribute.h"
+def map_code_template(attrs_str, attrs_checker_str):
+    return f"""// This file is generated by paddle/phi/api/yaml/generator/ops_extra_info_gen.py
+#include "paddle/fluid/operators/ops_extra_info.h"
+
+#include "paddle/fluid/platform/cudnn_workspace_helper.h"
 
 namespace paddle {{
-
-const static std::unordered_map<std::string, paddle::framework::AttributeMap> extra_attrs_map = {{
-{attrs_str}
-}};
+namespace operators {{
 
+ExtraInfoUtils::ExtraInfoUtils() {{
+  g_extra_attrs_map_ = {{
+    {attrs_str}
+  }};
+  g_extra_attrs_checker_ = {{
+    {attrs_checker_str}
+  }};
+}}
+
+}}  // namespace operators
 }}  // namespace paddle
 """
@@ -61,6 +71,7 @@ def generate_extra_info(api_compat_yaml_path, ops_extra_info_path):
         compat_apis = yaml.safe_load(f)
 
     extra_map_str_list = []
+    extra_checker_str_list = []
 
     for api_compat_args in compat_apis:
         if 'extra' in api_compat_args:
@@ -68,8 +79,12 @@ def generate_extra_info(api_compat_yaml_path, ops_extra_info_path):
             # TODO(chenweihang): add inputs and outputs
             if 'attrs' in extra_args_map:
                 attr_map_list = []
+                attr_checker_func_list = []
                 for attr in extra_args_map['attrs']:
                     attr_type, attr_name, default_val = parse_attr(attr)
+                    attr_checker_func_list.append(
+                        f"[](framework::AttributeMap* attr_map, bool only_check_exist_value)-> void {{ ExtraAttrChecker<{attr_type}>(\"{attr_name}\", {default_val})(attr_map, only_check_exist_value);}}"
+                    )
                     if attr_type.startswith("std::vector"):
                         attr_map_list.append(
                             f"{{\"{attr_name}\", {attr_type}{default_val}}}")
@@ -78,12 +93,26 @@ def generate_extra_info(api_compat_yaml_path, ops_extra_info_path):
                             f"{{\"{attr_name}\", {attr_type}{{{default_val}}}}}"
                         )
                 api_extra_attr_map = ", ".join(attr_map_list)
+                api_extra_attr_checkers = ",\n      ".join(
+                    attr_checker_func_list)
                 extra_map_str_list.append(
                     f"{{\"{api_compat_args['api']}\", {{ {api_extra_attr_map} }}}}"
                 )
+                extra_checker_str_list.append(
+                    f"{{\"{api_compat_args['api']}\", {{ {api_extra_attr_checkers} }}}}"
+                )
+                if 'backward' in api_compat_args:
+                    extra_map_str_list.append(
+                        f"{{\"{api_compat_args['backward']}\", {{ {api_extra_attr_map} }}}}"
+                    )
+                    extra_checker_str_list.append(
+                        f"{{\"{api_compat_args['backward']}\", {{ {api_extra_attr_checkers} }}}}"
+                    )
 
     ops_extra_info_file = open(ops_extra_info_path, 'w')
-    ops_extra_info_file.write(map_code_template(",\n".join(extra_map_str_list)))
+    ops_extra_info_file.write(
+        map_code_template(",\n    ".join(extra_map_str_list),
+                          ",\n    ".join(extra_checker_str_list)))
     ops_extra_info_file.close()
@@ -96,7 +125,7 @@ def main():
     parser.add_argument('--ops_extra_info_path',
                         help='output of generated extra_param_info code file',
-                        default='paddle/fluid/operators/ops_extra_info.h')
+                        default='paddle/fluid/operators/ops_extra_info.cc')
 
     options = parser.parse_args()
......
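Editorial note: for each yaml entry, the generator now emits a constructor body rather than a bare map. A standalone sketch of the shape of the generated `ops_extra_info.cc` (simplified stand-in types; the `dropout`/`fix_seed` entry is an assumed illustration, not copied from the real generated file):

```cpp
// Sketch only: the ExtraInfoUtils constructor body is what the Python
// generator fills in -- one default-value entry and one checker-list entry
// per `api` (and per `backward` op, when present).
#include <functional>
#include <string>
#include <unordered_map>
#include <vector>

using AttributeMap = std::unordered_map<std::string, float>;  // stand-in
using Checker = std::function<void(AttributeMap*, bool)>;

class ExtraInfoUtils {
 public:
  static ExtraInfoUtils& Instance() {
    static ExtraInfoUtils utils;
    return utils;
  }
  std::unordered_map<std::string, AttributeMap> g_extra_attrs_map_;
  std::unordered_map<std::string, std::vector<Checker>> g_extra_attrs_checker_;

 private:
  ExtraInfoUtils();  // the body below plays the role of the generated code
};

ExtraInfoUtils::ExtraInfoUtils() {
  g_extra_attrs_map_ = {
      {"dropout", {{"fix_seed", 0.0f}}},       // assumed example entry
      {"dropout_grad", {{"fix_seed", 0.0f}}},  // backward op gets a copy
  };
  g_extra_attrs_checker_ = {
      {"dropout",
       {[](AttributeMap* m, bool only_check_exist_value) {
         if (!only_check_exist_value) m->emplace("fix_seed", 0.0f);
       }}},
  };
}

int main() {
  return ExtraInfoUtils::Instance().g_extra_attrs_map_.empty() ? 1 : 0;
}
```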
@@ -21,6 +21,7 @@
 #include "paddle/phi/common/int_array.h"
 #include "paddle/phi/common/layout.h"
 #include "paddle/phi/common/scalar.h"
+#include "paddle/utils/flat_hash_map.h"
 #include "paddle/utils/variant.h"
 
 namespace phi {
@@ -47,4 +48,6 @@ using Attribute = paddle::variant<bool,
                                   DataLayout,
                                   Place>;
 
+using RuntimeAttrs = paddle::flat_hash_map<std::string, Attribute>;
+
 }  // namespace phi
@@ -138,6 +138,8 @@ class KernelContext {
   template <typename AttrType>
   const AttrType& AttrAt(size_t idx) const;
 
+  const RuntimeAttrs& GetRuntimeAttrs() const { return runtime_attrs_; }
+
   size_t InputsSize() const { return inputs_.size(); }
   size_t OutputsSize() const { return outputs_.size(); }
   size_t AttrsSize() const { return attrs_.size(); }
@@ -152,6 +154,8 @@ class KernelContext {
   paddle::small_vector<std::pair<int, int>, kInputSmallVectorSize> input_range_;
   paddle::small_vector<std::pair<int, int>, kOutputSmallVectorSize>
       output_range_;
+
+  RuntimeAttrs runtime_attrs_;
 };
 
 }  // namespace phi
@@ -205,6 +205,8 @@ struct KernelArgsParseFunctor<Return_ (*)(Args_...)> {
         args_def->AppendAttribute(AttributeType::DATA_LAYOUT);
       } else if (arg_type == std::type_index(typeid(Place))) {
         args_def->AppendAttribute(AttributeType::PLACE);
+      } else if (arg_type == std::type_index(typeid(RuntimeAttrs))) {
+        // do nothing
       } else {
         PADDLE_THROW(phi::errors::Unavailable(
             "Unsupported kernel argument type `%s`.", arg_type.name()));
......
@@ -321,6 +321,22 @@ struct KernelImpl<Return (*)(DevCtx, Args...), kernel_fn> {
   PD_SPECIALIZE_KernelCallHelper_FOR_OUTPUT(StringTensor);
   PD_SPECIALIZE_KernelCallHelper_FOR_MULTI_OUTPUT(StringTensor);
 
+  template <typename... Tail>
+  struct KernelCallHelper<const RuntimeAttrs&, Tail...> {
+    template <int dev_ctx_idx,
+              int in_idx,
+              int attr_idx,
+              int out_idx,
+              typename... PreviousArgs>
+    static void Compute(KernelContext* ctx, PreviousArgs&... pargs) {
+      const auto& runtime_attrs = ctx->GetRuntimeAttrs();
+      KernelCallHelper<Tail...>::
+          template Compute<dev_ctx_idx, in_idx, attr_idx, out_idx>(
+              ctx, pargs..., runtime_attrs);
+    }
+  };
+
   /* End case */
   template <typename T>
   struct KernelCallHelper<TypeTag<T>> {
......
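Editorial note: the new `KernelCallHelper` specialization forwards the whole runtime-attribute map from the context without consuming an input/attr index, so the ordinary argument bookkeeping is unaffected. A compilable miniature of that recursion (stand-in types, not phi code):

```cpp
// Sketch only: a const RuntimeAttrs& kernel parameter is satisfied from
// the context itself, then recursion continues to the end case.
#include <iostream>
#include <string>
#include <unordered_map>

using RuntimeAttrs = std::unordered_map<std::string, int>;

struct KernelContext {
  RuntimeAttrs runtime_attrs_;
  const RuntimeAttrs& GetRuntimeAttrs() const { return runtime_attrs_; }
};

// The kernel we ultimately want to invoke.
void MyKernel(int x, const RuntimeAttrs& attrs) {
  std::cout << x << " / " << attrs.size() << " runtime attrs\n";
}

template <typename... Args>
struct KernelCallHelper;

// Injects RuntimeAttrs from the context, mirroring the phi specialization:
// no input/attr index is advanced.
template <typename... Tail>
struct KernelCallHelper<const RuntimeAttrs&, Tail...> {
  template <typename... PreviousArgs>
  static void Compute(KernelContext* ctx, PreviousArgs&... pargs) {
    const auto& runtime_attrs = ctx->GetRuntimeAttrs();
    KernelCallHelper<Tail...>::Compute(ctx, pargs..., runtime_attrs);
  }
};

struct TypeTag {};

// End case: all arguments gathered, call the kernel body.
template <>
struct KernelCallHelper<TypeTag> {
  template <typename... PreviousArgs>
  static void Compute(KernelContext* /*ctx*/, PreviousArgs&... pargs) {
    MyKernel(pargs...);
  }
};

int main() {
  KernelContext ctx;
  ctx.runtime_attrs_ = {{"use_mkldnn", 0}};
  int x = 7;
  KernelCallHelper<const RuntimeAttrs&, TypeTag>::Compute(&ctx, x);
  return 0;
}
```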
...@@ -2840,6 +2840,7 @@ class Operator(object): ...@@ -2840,6 +2840,7 @@ class Operator(object):
arg.op = self arg.op = self
self.desc.set_output(out_proto.name, out_arg_names) self.desc.set_output(out_proto.name, out_arg_names)
extra_attrs_map = core.get_op_extra_attrs(type)
if op_attrs is not None: if op_attrs is not None:
if not isinstance(op_attrs, dict): if not isinstance(op_attrs, dict):
raise TypeError("'attrs' should be a dict.") raise TypeError("'attrs' should be a dict.")
...@@ -2850,6 +2851,13 @@ class Operator(object): ...@@ -2850,6 +2851,13 @@ class Operator(object):
continue continue
attr_val = op_attrs[attr_name] attr_val = op_attrs[attr_name]
self._update_desc_attr(attr_name, attr_val) self._update_desc_attr(attr_name, attr_val)
for attr_name in extra_attrs_map.keys():
if (attr_name
not in op_attrs) or (op_attrs[attr_name] is None):
self._update_desc_attr(attr_name,
extra_attrs_map[attr_name])
else:
self._update_desc_attr(attr_name, op_attrs[attr_name])
# proto.attrs doesn't include ipu_index # proto.attrs doesn't include ipu_index
if core.is_compiled_with_ipu(): if core.is_compiled_with_ipu():
@@ -5821,17 +5829,29 @@ class Program(object):
         ]
         res._sync_with_cpp()
 
+        # Note: The op_role and op_role_var can't be deleted currently,
+        # and we will try to remove them in the future.
+        common_clipped_attrs_list = ['op_namescope', 'op_callstack']
+
         for i in six.moves.range(res.desc.num_blocks()):
             block = res.desc.block(i)
             for var in block.all_vars():
                 var.clear_is_parameter()
                 var.clear_stop_gradient()
-            if not clip_extra:
-                continue
             for op_idx in range(0, block.op_size()):
                 op = block.op(op_idx)
                 if op.type() not in OpProtoHolder.instance().op_proto_map:
                     continue
+                if not clip_extra:
+                    continue
+                extra_attrs_map = core.get_op_extra_attrs(op.type())
+                for name in op.attr_names():
+                    if name in extra_attrs_map:
+                        op.remove_attr(name)
+                        continue
                 proto = OpProtoHolder.instance().get_op_proto(op.type())
                 remove_input_list = []
                 for name in op.input_names():
@@ -5845,8 +5865,9 @@ class Program(object):
                             break
                     if not find:
                         remove_input_list.append(name)
-                for name in remove_input_list:
-                    op.remove_input(name)
+                # The extra input of op will be removed in the future
+                # for name in remove_input_list:
+                #     op.remove_input(name)
 
                 remove_output_list = []
                 for name in op.output_names():
@@ -5860,10 +5881,10 @@ class Program(object):
                             break
                     if not find:
                         remove_output_list.append(name)
-                for name in remove_output_list:
-                    op.remove_output(name)
+                # The extra output of op will be removed in the future
+                # for name in remove_output_list:
+                #     op.remove_output(name)
 
-                remove_attr_list = []
                 op_quant_name = core.op_proto_and_checker_maker.kOpWithQuantAttrName(
                 )
                 quant = bool(op.attr(op_quant_name)
@@ -5873,18 +5894,21 @@ class Program(object):
                     "activation_bits", "bit_length", "quantize_weight_bits",
                     "weight_quant_scale"
                 ]
+                remove_attr_list = []
                 for name in op.attr_names():
                     if quant:
                         if name in quant_attrs:
                             continue
                         if name.endswith("_threshold"):
                             continue
+                    if name in common_clipped_attrs_list:
+                        remove_attr_list.append(name)
+                        continue
                     find = False
                     for attr_proto in proto.attrs:
                         if attr_proto.name != name:
                             continue
-                        if attr_proto.extra:
-                            remove_attr_list.append(name)
                         find = True
                         break
                     if not find:
...
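Taken together, the clip_extra branch now prunes three kinds of attributes when a program is saved for inference: registered extra attrs, the commonly clipped bookkeeping attrs, and attrs unknown to the op proto, while quantization attrs stay protected. A condensed sketch of that decision over plain dicts (all inputs here are assumed examples, not the real registries):

def attrs_to_remove(attr_names, extra_attrs, proto_attr_names, quant=False,
                    quant_attrs=(), common_clipped=("op_namescope",
                                                    "op_callstack")):
    """Approximation of the pruning policy in _remove_training_info."""
    removed = []
    for name in attr_names:
        if name in extra_attrs:
            removed.append(name)        # registered extra attr
        elif quant and (name in quant_attrs or name.endswith("_threshold")):
            continue                    # keep quantization bookkeeping
        elif name in common_clipped or name not in proto_attr_names:
            removed.append(name)        # clipped or unknown to the proto
    return removed


print(attrs_to_remove(["use_mkldnn", "op_namescope", "strides"],
                      extra_attrs={"use_mkldnn"},
                      proto_attr_names={"strides"}))
# ['use_mkldnn', 'op_namescope']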
@@ -52,6 +52,7 @@ class OpDescCreationMethod(object):
             raise TypeError(
                 "Type of op_proto should be OpProto in PaddlePaddle.")
         self.__op_proto__ = op_proto
+        self.__extra_attrs__ = core.get_op_extra_attrs(op_proto.type)
 
     def __call__(self, *args, **kwargs):
         """
@@ -130,6 +131,40 @@ class OpDescCreationMethod(object):
                     raise NotImplementedError(
                         "A not supported attribute type: %s." %
                         (str(attr.type)))
+        for attr_name, default_val in self.__extra_attrs__.items():
+            user_defined_attr = kwargs.get(attr_name, None)
+            if user_defined_attr is not None:
+                attr_type = int(
+                    core.get_attrtibute_type(op_desc.type, attr_name))
+                new_attr = op_desc.attrs.add()
+                new_attr.name = attr_name
+                new_attr.type = attr_type
+                if isinstance(user_defined_attr, np.ndarray):
+                    user_defined_attr = user_defined_attr.tolist()
+                if attr_type == framework_pb2.INT:
+                    new_attr.i = user_defined_attr
+                elif attr_type == framework_pb2.FLOAT:
+                    new_attr.f = user_defined_attr
+                elif attr_type == framework_pb2.LONG:
+                    new_attr.l = user_defined_attr
+                elif attr_type == framework_pb2.STRING:
+                    new_attr.s = user_defined_attr
+                elif attr_type == framework_pb2.BOOLEAN:
+                    new_attr.b = user_defined_attr
+                elif attr_type == framework_pb2.INTS:
+                    new_attr.ints.extend(user_defined_attr)
+                elif attr_type == framework_pb2.FLOATS:
+                    new_attr.floats.extend(user_defined_attr)
+                elif attr_type == framework_pb2.STRINGS:
+                    new_attr.strings.extend(user_defined_attr)
+                elif attr_type == framework_pb2.BOOLEANS:
+                    new_attr.bools.extend(user_defined_attr)
+                elif attr_type == framework_pb2.LONGS:
+                    new_attr.longs.extend(user_defined_attr)
+                else:
+                    raise NotImplementedError(
+                        "A not supported attribute type: %s." %
+                        (str(attr_type)))
+
         return op_desc
...
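The long elif chain above is the usual protobuf attribute dispatch: each AttrType enum maps to one field of the AttrDesc message. As a design note, the same mapping can be table-driven; a sketch under the assumption that the framework_pb2 enums and fields are exactly the ones used above:

import paddle.fluid.proto.framework_pb2 as framework_pb2  # path as used by paddle.fluid.op

SCALAR_FIELDS = {
    framework_pb2.INT: "i",
    framework_pb2.FLOAT: "f",
    framework_pb2.LONG: "l",
    framework_pb2.STRING: "s",
    framework_pb2.BOOLEAN: "b",
}
REPEATED_FIELDS = {
    framework_pb2.INTS: "ints",
    framework_pb2.FLOATS: "floats",
    framework_pb2.STRINGS: "strings",
    framework_pb2.BOOLEANS: "bools",
    framework_pb2.LONGS: "longs",
}


def set_attr_value(new_attr, attr_type, value):
    """Write `value` into the proto field matching `attr_type`."""
    if attr_type in SCALAR_FIELDS:
        setattr(new_attr, SCALAR_FIELDS[attr_type], value)
    elif attr_type in REPEATED_FIELDS:
        getattr(new_attr, REPEATED_FIELDS[attr_type]).extend(value)
    else:
        raise NotImplementedError("A not supported attribute type: %s." %
                                  str(attr_type))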
@@ -147,12 +182,13 @@ class OpDescCreationMethod(object):
 class OpInfo(object):
 
-    def __init__(self, name, method, inputs, outputs, attrs):
+    def __init__(self, name, method, inputs, outputs, attrs, extra_attrs):
         self.name = name
         self.method = method
         self.inputs = inputs
         self.outputs = outputs
         self.attrs = attrs
+        self.extra_attrs = extra_attrs
 
 
 def create_op_creation_method(op_proto):
@@ -165,13 +201,16 @@ def create_op_creation_method(op_proto):
         opdesc = method(*args, **kwargs)
         return core.Operator.create(opdesc.SerializeToString())
 
+    extra_attrs_map = core.get_op_extra_attrs(op_proto.type)
+
     return OpInfo(method=__impl__,
                   name=op_proto.type,
                   inputs=[(var.name, var.duplicable)
                           for var in op_proto.inputs],
                   outputs=[(var.name, var.duplicable)
                            for var in op_proto.outputs],
-                  attrs=[attr.name for attr in op_proto.attrs])
+                  attrs=[attr.name for attr in op_proto.attrs],
+                  extra_attrs=[item for item in extra_attrs_map.keys()])
 
 
 class OperatorFactory(object):
@@ -222,6 +261,9 @@ class OperatorFactory(object):
     def get_op_attr_names(self, type):
         return self.get_op_info(type).attrs
 
+    def get_op_extra_attr_names(self, type):
+        return self.get_op_info(type).extra_attrs
+
 
 class __RecurrentOp__(object):
     __proto__ = None
...
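With the factory extended, callers can ask which attributes of an op are extra. For example (the concrete attr names printed here are assumptions; the real list comes from the ops_extra_info registry):

from paddle.fluid.op import Operator  # module-level OperatorFactory instance

extra_names = Operator.get_op_extra_attr_names("conv2d")
print(extra_names)  # e.g. ['use_mkldnn', 'workspace_size_MB', ...] (assumed)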
@@ -64,6 +64,10 @@ def create_op(scope, op_type, inputs, outputs, attrs, cache_list=None):
         if attr_name in attrs:
             kwargs[attr_name] = attrs[attr_name]
 
+    for extra_attr_name in Operator.get_op_extra_attr_names(op_type):
+        if extra_attr_name in attrs:
+            kwargs[extra_attr_name] = attrs[extra_attr_name]
+
     return Operator(op_type, **kwargs)
...
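This helper change matters because extra attrs no longer appear in get_op_attr_names; without the second loop they would be dropped silently when tests build an Operator. A self-contained sketch of the kwargs assembly (the attr names are assumed for illustration, not the real dropout registry):

def build_kwargs(attrs, attr_names, extra_attr_names):
    """Collect both OpMaker attrs and extra attrs, as create_op now does."""
    kwargs = {}
    for name in attr_names:          # attrs declared in OpMaker
        if name in attrs:
            kwargs[name] = attrs[name]
    for name in extra_attr_names:    # attrs moved out of OpMaker by this PR
        if name in attrs:
            kwargs[name] = attrs[name]
    return kwargs


print(build_kwargs({"dropout_prob": 0.5, "fix_seed": True},
                   attr_names=["dropout_prob"],
                   extra_attr_names=["fix_seed"]))
# {'dropout_prob': 0.5, 'fix_seed': True}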
@@ -240,7 +240,7 @@ if [ "${HAS_MODIFIED_DECLARATIONS}" != "" ] && [ "${GIT_PR_ID}" != "" ]; then
 HAS_MODIFIED_API_COMPAT_YAML=`git diff --name-only upstream/$BRANCH | grep "paddle/phi/api/yaml/api_compat.yaml" || true`
 if [ "${HAS_MODIFIED_API_COMPAT_YAML}" != "" ] && [ "${GIT_PR_ID}" != "" ]; then
-    echo_line="You must be approved by chenwhql or zyfncg for paddle/phi/api/yaml/api_compat.yaml, which manages the extra params of Op and name mapping between Yaml and OpMaker. In order to ensure compatibility of framework, this file isn't allowed to be modified at will!\n"
+    echo_line="You must be approved by chenwhql or zyfncg for paddle/phi/api/yaml/api_compat.yaml changes, which manages the extra params of Op and name mapping between Yaml and OpMaker. In order to ensure compatibility of framework, this file isn't allowed to be modified at will!\n"
     check_approval 1 chenwhql zyfncg
 fi
...