未验证 提交 fe321f9a 编写于 作者: Z zyfncg 提交者: GitHub

Remove extra attribute in OpMaker (#44310)

* add runtime config in phi

* add runtime attr for op desc and op

* fix no proto error

* adjust opdesc set_attr impl

* try to remove conv_op extra attrs

* add init runtime attr map

* change extra header path

* fix runtime_attr

* fix trace_op

* fix bug of pass

* fix merge conflict

* fix dygraph attrs

* fix bug of pass

* fix dygraph bug

* fix unittest module

* delete extra attr default

* fix dropout kernel

* polish code

* fix extra output of instance_norm

* fix merge confilct

* fix op_desc bug

* add extra attr in yaml for conv3d_transpose

* don't remove extra input and output

* fix save_inference_model

* fix bug of batch_norm

* revert some change

* polish log

* polish code

* add code comment
Co-authored-by: NChen Weihang <chenweihang@baidu.com>
上级 a7c4facb
......@@ -5,7 +5,7 @@ paddle/fluid/API_PR.spec
paddle/fluid/eager/api/generated/*
paddle/fluid/op_use_default_grad_maker_DEV.spec
paddle/fluid/op_use_default_grad_maker_PR.spec
paddle/fluid/operators/ops_extra_info.h
paddle/fluid/operators/ops_extra_info.cc
paddle/phi/api/backward/backward_api.h
paddle/phi/api/backward/sparse_bw_api.h
paddle/phi/api/include/api.h
......
......@@ -348,7 +348,7 @@ cc_test(
cc_library(
op_proto_maker
SRCS op_proto_maker.cc
DEPS framework_proto attribute glog)
DEPS framework_proto attribute ops_extra_info glog)
cc_test(
op_proto_maker_test
SRCS op_proto_maker_test.cc
......@@ -483,6 +483,7 @@ cc_library(
proto_desc
SRCS var_desc.cc op_desc.cc block_desc.cc program_desc.cc
DEPS attribute
ops_extra_info
shape_inference
op_info
operator
......@@ -498,7 +499,7 @@ endif()
cc_library(
op_registry
SRCS op_registry.cc
DEPS op_proto_maker op_info operator glog proto_desc)
DEPS op_proto_maker op_info operator ops_extra_info glog proto_desc)
cc_library(
op_call_stack
......
......@@ -325,10 +325,13 @@ class OpAttrChecker {
explicit_checker_num_ = attr_checkers_.size();
}
void InitDefaultAttributeMap() {
void InitDefaultAttributeMap(const AttributeMap* extra_attr_map) {
for (const auto& checker : attr_checkers_) {
checker(&default_attrs_, true, false);
}
if (extra_attr_map) {
default_attrs_.insert(extra_attr_map->begin(), extra_attr_map->end());
}
}
const AttributeMap& GetDefaultAttrMap() const { return default_attrs_; }
......
......@@ -207,9 +207,9 @@ struct OpInfoFiller<T, kOpProtoAndCheckerMaker> {
"OpAttrChecker of %s has been registered.", op_type));
info->proto_ = new proto::OpProto;
info->checker_ = new OpAttrChecker();
info->proto_->set_type(op_type);
T maker;
maker(info->proto_, info->checker_);
info->proto_->set_type(op_type);
PADDLE_ENFORCE_EQ(
info->proto_->IsInitialized(),
true,
......
......@@ -161,6 +161,10 @@ class GradOpDescMakerBase {
return fwd_op_.GetAttrMap();
}
const std::unordered_map<std::string, Attribute>& RuntimeAttrs() const {
return fwd_op_.GetRuntimeAttrMap();
}
const Attribute& GetAttr(const std::string& name) const {
auto& map = fwd_op_.GetAttrMap();
auto it = map.find(name);
......@@ -209,6 +213,7 @@ class SingleGradOpMaker<OpDesc> : public GradOpDescMakerBase {
retv.emplace_back(new OpDesc());
try {
this->Apply(retv.front().get());
retv.front()->SetRuntimeAttrMap(this->RuntimeAttrs());
} catch (platform::EnforceNotMet& exception) {
framework::AppendErrorOpHint(retv.front().get()->Type(), &exception);
throw std::move(exception);
......
......@@ -298,6 +298,9 @@ void ConvBiasFusePass::ApplyImpl(ir::Graph* graph) const {
for (auto& attr : conv->Op()->GetAttrMap()) {
desc.SetAttr(attr.first, attr.second);
}
for (auto& attr : conv->Op()->GetRuntimeAttrMap()) {
desc.SetAttr(attr.first, attr.second);
}
auto conv_bias_node = g->CreateOpNode(&desc);
IR_NODE_LINK_TO(subgraph.at(conv_input), conv_bias_node);
......
......@@ -18,6 +18,7 @@ set(STANDALONE_EXECUTOR_DEPS
scope
framework_proto
data_feed_proto
ops_extra_info
heter_service_proto
trainer_desc_proto
glog
......
......@@ -21,6 +21,7 @@
#include "paddle/fluid/operators/controlflow/conditional_block_op_helper.h"
#include "paddle/fluid/operators/controlflow/recurrent_op_helper.h"
#include "paddle/fluid/operators/controlflow/while_op_helper.h"
#include "paddle/fluid/operators/ops_extra_info.h"
#include "paddle/phi/core/kernel_context.h"
#include "paddle/phi/core/kernel_factory.h"
......@@ -242,20 +243,30 @@ void build_variable_scope(const framework::BlockDesc& block,
void create_all_ops(const framework::BlockDesc& block,
std::vector<std::unique_ptr<OperatorBase>>* ops) {
for (auto& op : block.AllOps()) {
VLOG(3) << "CreateOp from : " << op->Type();
auto op_type = op->Type();
VLOG(1) << "CreateOp from : " << op_type;
auto& info = OpInfoMap::Instance().Get(op->Type());
auto& info = OpInfoMap::Instance().Get(op_type);
const VariableNameMap& inputs_names = op->Inputs();
const VariableNameMap& outputs_names = op->Outputs();
AttributeMap op_attr_map = op->GetAttrMap();
AttributeMap op_runtime_attr_map = op->GetRuntimeAttrMap();
if (info.Checker() != nullptr) {
info.Checker()->Check(&op_attr_map);
}
const auto& extra_attr_checkers =
operators::ExtraInfoUtils::Instance().GetExtraAttrsChecker(op_type);
for (const auto& checker : extra_attr_checkers) {
checker(&op_runtime_attr_map, false);
}
auto op_base =
info.Creator()(op->Type(), inputs_names, outputs_names, op_attr_map);
info.Creator()(op_type, inputs_names, outputs_names, op_attr_map);
op_base->SetRuntimeAttributeMap(op_runtime_attr_map);
#ifdef PADDLE_WITH_MKLDNN
if (FLAGS_use_mkldnn) {
......
......@@ -23,6 +23,7 @@ limitations under the License. */
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/framework/shape_inference.h"
#include "paddle/fluid/framework/var_type_inference.h"
#include "paddle/fluid/operators/ops_extra_info.h"
#include "paddle/utils/blank.h"
namespace paddle {
......@@ -409,6 +410,13 @@ class CompileTimeInferShapeContext : public InferShapeContext {
const BlockDesc &block_;
};
static void InitRuntimeAttributeMapByOpExtraInfo(const std::string &op_type,
AttributeMap *runtime_attrs) {
const auto &extra_attr_map =
operators::ExtraInfoUtils::Instance().GetExtraAttrsMap(op_type);
runtime_attrs->insert(extra_attr_map.begin(), extra_attr_map.end());
}
OpDesc::OpDesc(const std::string &type,
const VariableNameMap &inputs,
const VariableNameMap &outputs,
......@@ -419,6 +427,7 @@ OpDesc::OpDesc(const std::string &type,
attrs_ = attrs;
need_update_ = true;
block_ = nullptr;
InitRuntimeAttributeMapByOpExtraInfo(type, &runtime_attrs_);
}
OpDesc::OpDesc(const OpDesc &other) {
......@@ -441,6 +450,8 @@ void OpDesc::CopyFrom(const OpDesc &op_desc) {
inputs_ = op_desc.inputs_;
outputs_ = op_desc.outputs_;
attrs_ = op_desc.attrs_;
runtime_attrs_ = op_desc.runtime_attrs_;
// The record of original_id_ is only for auto parallel.
original_id_ = op_desc.original_id_;
if (op_desc.dist_attr_) {
dist_attr_.reset(new OperatorDistAttr(*op_desc.dist_attr_));
......@@ -473,8 +484,9 @@ OpDesc::OpDesc(const proto::OpDesc &desc, BlockDesc *block)
}
}
// restore attrs_
InitRuntimeAttributeMapByOpExtraInfo(desc.type(), &runtime_attrs_);
for (const proto::OpDesc::Attr &attr : desc_.attrs()) {
std::string attr_name = attr.name();
const std::string &attr_name = attr.name();
// The sub_block referred to by the BLOCK attr hasn't been added
// to ProgramDesc class yet, we skip setting BLOCK/BLOCKS/VAR/VARS attr
// here.
......@@ -483,7 +495,12 @@ OpDesc::OpDesc(const proto::OpDesc &desc, BlockDesc *block)
attr_type != proto::AttrType::BLOCKS &&
attr_type != proto::AttrType::VAR &&
attr_type != proto::AttrType::VARS) {
auto iter = runtime_attrs_.find(attr_name);
if (iter == runtime_attrs_.end()) {
attrs_[attr_name] = GetAttrValue(attr);
} else {
iter->second = GetAttrValue(attr);
}
}
}
this->block_ = block;
......@@ -622,7 +639,13 @@ std::vector<std::string> OpDesc::AttrNames(bool with_attr_var) const {
bool OpDesc::HasAttr(const std::string &name, bool with_attr_var) const {
auto iter = attrs_.find(name);
bool is_found = iter != attrs_.end();
bool is_found = true;
if (iter == attrs_.end()) {
iter = runtime_attrs_.find(name);
if (iter == runtime_attrs_.end()) {
is_found = false;
}
}
if (with_attr_var) {
return is_found;
}
......@@ -631,10 +654,19 @@ bool OpDesc::HasAttr(const std::string &name, bool with_attr_var) const {
void OpDesc::RemoveAttr(const std::string &name) {
attrs_.erase(name);
runtime_attrs_.erase(name);
need_update_ = true;
}
void OpDesc::SetAttr(const std::string &name, const Attribute &v) {
AttributeMap *attrs_ptr = &(this->attrs_);
const auto &extra_attr_map =
operators::ExtraInfoUtils::Instance().GetExtraAttrsMap(Type());
auto extra_attr_iter = extra_attr_map.find(name);
if (extra_attr_iter != extra_attr_map.end()) {
attrs_ptr = &(this->runtime_attrs_);
}
// NOTICE(minqiyang): pybind11 will take the empty list in python as
// the std::vector<int> type in C++; so we have to change the attr's type
// here if we meet this issue
......@@ -647,25 +679,25 @@ void OpDesc::SetAttr(const std::string &name, const Attribute &v) {
case proto::AttrType::BOOLEANS: {
VLOG(11) << "SetAttr: " << Type() << ", " << name
<< " from INTS to BOOLEANS";
this->attrs_[name] = std::vector<bool>();
attrs_ptr->operator[](name) = std::vector<bool>();
break;
}
case proto::AttrType::INTS: {
VLOG(11) << "SetAttr: " << Type() << ", " << name
<< " from INTS to INTS";
this->attrs_[name] = std::vector<int>();
attrs_ptr->operator[](name) = std::vector<int>();
break;
}
case proto::AttrType::LONGS: {
VLOG(11) << "SetAttr: " << Type() << ", " << name
<< " from LONGS to LONGS";
this->attrs_[name] = std::vector<int64_t>();
attrs_ptr->operator[](name) = std::vector<int64_t>();
break;
}
case proto::AttrType::FLOATS: {
VLOG(11) << "SetAttr: " << Type() << ", " << name
<< " from INTS to FLOATS";
this->attrs_[name] = std::vector<float>();
attrs_ptr->operator[](name) = std::vector<float>();
break;
}
case proto::AttrType::FLOAT64S: {
......@@ -677,13 +709,13 @@ void OpDesc::SetAttr(const std::string &name, const Attribute &v) {
case proto::AttrType::STRINGS: {
VLOG(11) << "SetAttr: " << Type() << ", " << name
<< " from INTS to STRINGS";
this->attrs_[name] = std::vector<std::string>();
attrs_ptr->operator[](name) = std::vector<std::string>();
break;
}
case proto::AttrType::BLOCKS: {
VLOG(11) << "SetAttr: " << Type() << ", " << name
<< " from INTS to BLOCKS";
this->SetBlocksAttr(name, std::vector<BlockDesc *>());
attrs_ptr->operator[](name) = std::vector<BlockDesc *>();
return;
}
default:
......@@ -695,14 +727,23 @@ void OpDesc::SetAttr(const std::string &name, const Attribute &v) {
}
// In order to set bool attr properly
if (attr_type == proto::AttrType::INT && HasProtoAttr(name) &&
if (attr_type == proto::AttrType::INT) {
if (HasProtoAttr(name) &&
GetProtoAttr(name).type() == proto::AttrType::BOOLEAN) {
this->attrs_[name] = static_cast<bool>(PADDLE_GET_CONST(int, v));
attrs_ptr->operator[](name) = static_cast<bool>(PADDLE_GET_CONST(int, v));
need_update_ = true;
return;
}
if (extra_attr_iter != extra_attr_map.end() &&
static_cast<proto::AttrType>(extra_attr_iter->second.index() - 1) ==
proto::AttrType::BOOLEAN) {
attrs_ptr->operator[](name) = static_cast<bool>(PADDLE_GET_CONST(int, v));
need_update_ = true;
return;
}
}
this->attrs_[name] = v;
attrs_ptr->operator[](name) = v;
need_update_ = true;
}
......@@ -733,8 +774,17 @@ void OpDesc::SetAttrMap(
need_update_ = true;
}
void OpDesc::SetRuntimeAttrMap(
const std::unordered_map<std::string, Attribute> &attr_map) {
runtime_attrs_ = attr_map;
need_update_ = true;
}
Attribute OpDesc::GetAttr(const std::string &name, bool with_attr_var) const {
auto it = attrs_.find(name);
if (it == attrs_.end()) {
it = runtime_attrs_.find(name);
}
PADDLE_ENFORCE_NE(
it,
attrs_.end(),
......@@ -802,6 +852,8 @@ const std::unordered_map<std::string, Attribute> &OpDesc::GetAttrMap() const {
return attrs_;
}
const AttributeMap &OpDesc::GetRuntimeAttrMap() const { return runtime_attrs_; }
void OpDesc::Rename(const std::string &old_name, const std::string &new_name) {
RenameInput(old_name, new_name);
RenameOutput(old_name, new_name);
......@@ -925,6 +977,15 @@ void OpDesc::Flush() {
}
this->desc_.mutable_attrs()->Clear();
auto set_attr_desc = [this](const std::string &attr_name,
const Attribute &attr) -> void {
auto *attr_desc = desc_.add_attrs();
attr_desc->set_name(attr_name);
attr_desc->set_type(static_cast<proto::AttrType>(attr.index() - 1));
SetAttrDescVisitor visitor(attr_desc);
paddle::visit(visitor, attr);
};
std::vector<std::pair<std::string, Attribute>> sorted_attrs{attrs_.begin(),
attrs_.end()};
std::sort(
......@@ -932,13 +993,12 @@ void OpDesc::Flush() {
sorted_attrs.end(),
[](std::pair<std::string, Attribute> a,
std::pair<std::string, Attribute> b) { return a.first < b.first; });
for (auto &attr : sorted_attrs) {
auto *attr_desc = desc_.add_attrs();
attr_desc->set_name(attr.first);
attr_desc->set_type(
static_cast<proto::AttrType>(attr.second.index() - 1));
SetAttrDescVisitor visitor(attr_desc);
paddle::visit(visitor, attr.second);
set_attr_desc(attr.first, attr.second);
}
for (auto &attr : runtime_attrs_) {
set_attr_desc(attr.first, attr.second);
}
need_update_ = false;
......@@ -1155,7 +1215,7 @@ bool CompileTimeInferShapeContext::HasOutputs(const std::string &name,
}
AttrReader CompileTimeInferShapeContext::Attrs() const {
return AttrReader(op_.GetAttrMap());
return AttrReader(op_.GetAttrMap(), op_.GetRuntimeAttrMap());
}
std::vector<std::string> CompileTimeInferShapeContext::Inputs(
......
......@@ -144,6 +144,10 @@ class OpDesc {
// Only be used in C++
void SetAttrMap(const AttributeMap &attr_map);
void SetRuntimeAttrMap(const AttributeMap &attr_map);
const AttributeMap &GetRuntimeAttrMap() const;
std::vector<std::string> InputNames(bool with_attr_var = false) const {
return MapKeys(inputs_);
}
......@@ -221,6 +225,12 @@ class OpDesc {
VariableNameMap outputs_;
// attribute name => all original attrs
AttributeMap attrs_;
// runtime_attrs_ contains the attributes which used for dispatching kernel
// (use_mkldnn, use_cudnn, ...) or passing additional configuration for
// special heterogeneous kernel (workspace_size_MB, ...).
// The attributes in runtime_attrs_ are setted by framework (such as PASS),
// and not in the python api.
AttributeMap runtime_attrs_;
// need_update_ indicate there some local changes not be synchronized. If
// local changes should be synchronized, need_update_ should be set to true.
......
......@@ -12,8 +12,7 @@ See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/framework/op_proto_maker.h"
#include <string>
#include "paddle/fluid/operators/ops_extra_info.h"
#include "paddle/fluid/platform/enforce.h"
......@@ -67,7 +66,16 @@ void OpProtoAndCheckerMaker::operator()(proto::OpProto* proto,
op_checker_ = attr_checker;
Make();
op_checker_->RecordExplicitCheckerNum();
op_checker_->InitDefaultAttributeMap();
const AttributeMap* extra_attrs_ptr = nullptr;
const std::string& op_type = proto->type();
const auto& extra_attr_map =
operators::ExtraInfoUtils::Instance().GetExtraAttrsMap(op_type);
if (!extra_attr_map.empty()) {
extra_attrs_ptr = &extra_attr_map;
}
op_checker_->InitDefaultAttributeMap(extra_attrs_ptr);
AddAttr<int>(OpRoleAttrName(), "The role of this operator")
.InEnum(
......
......@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/ops_extra_info.h"
#include "glog/logging.h"
......@@ -25,15 +26,47 @@ std::unique_ptr<OperatorBase> OpRegistry::CreateOp(
const VariableNameMap& outputs,
const AttributeMap& attrs,
bool attr_check) {
AttributeMap standard_attrs;
AttributeMap runtime_attrs =
paddle::operators::ExtraInfoUtils::Instance().GetExtraAttrsMap(type);
for (auto& attr : attrs) {
auto it = runtime_attrs.find(attr.first);
if (it != runtime_attrs.end()) {
it->second = attr.second;
} else {
standard_attrs[attr.first] = attr.second;
}
}
auto& info = OpInfoMap::Instance().Get(type);
if (attr_check && info.Checker() != nullptr) {
info.Checker()->Check(&standard_attrs);
}
auto op_base = std::unique_ptr<OperatorBase>(
info.Creator()(type, inputs, outputs, standard_attrs));
op_base->SetRuntimeAttributeMap(runtime_attrs);
return op_base;
}
std::unique_ptr<OperatorBase> OpRegistry::CreateOp(
const std::string& type,
const VariableNameMap& inputs,
const VariableNameMap& outputs,
const AttributeMap& attrs,
const AttributeMap& runtime_attrs,
bool attr_check) {
std::unique_ptr<OperatorBase> op_base;
auto& info = OpInfoMap::Instance().Get(type);
if (attr_check && info.Checker() != nullptr) {
auto tmp_attrs = attrs;
info.Checker()->Check(&tmp_attrs);
return std::unique_ptr<OperatorBase>(
op_base = std::unique_ptr<OperatorBase>(
info.Creator()(type, inputs, outputs, tmp_attrs));
}
return std::unique_ptr<OperatorBase>(
} else {
op_base = std::unique_ptr<OperatorBase>(
info.Creator()(type, inputs, outputs, attrs));
}
op_base->SetRuntimeAttributeMap(runtime_attrs);
return op_base;
}
static VariableNameMap ConvertOpDescVarsToVarNameMap(
......@@ -59,18 +92,27 @@ std::unique_ptr<OperatorBase> OpRegistry::CreateOp(
VariableNameMap inputs = ConvertOpDescVarsToVarNameMap(op_desc.inputs());
VariableNameMap outputs = ConvertOpDescVarsToVarNameMap(op_desc.outputs());
AttributeMap attrs;
AttributeMap extra_attrs =
paddle::operators::ExtraInfoUtils::Instance().GetExtraAttrsMap(
op_desc.type());
for (auto& attr : op_desc.attrs()) {
auto it = extra_attrs.find(attr.name());
if (it != extra_attrs.end()) {
it->second = GetAttrValue(attr);
} else {
attrs[attr.name()] = GetAttrValue(attr);
}
}
return CreateOp(op_desc.type(), inputs, outputs, attrs);
return CreateOp(op_desc.type(), inputs, outputs, attrs, extra_attrs);
}
std::unique_ptr<OperatorBase> OpRegistry::CreateOp(const OpDesc& op_desc) {
return CreateOp(op_desc.Type(),
op_desc.Inputs(),
op_desc.Outputs(),
op_desc.GetAttrMap());
op_desc.GetAttrMap(),
op_desc.GetRuntimeAttrMap());
}
} // namespace framework
......
......@@ -132,6 +132,13 @@ class OpRegistry {
const VariableNameMap& outputs,
const AttributeMap& attrs,
bool attr_check = true);
static std::unique_ptr<OperatorBase> CreateOp(
const std::string& type,
const VariableNameMap& inputs,
const VariableNameMap& outputs,
const AttributeMap& attrs,
const AttributeMap& runtime_attrs,
bool attr_check = true);
static std::unique_ptr<OperatorBase> CreateOp(const proto::OpDesc& op_desc);
......
......@@ -1389,10 +1389,9 @@ bool OperatorWithKernel::SupportsKernelType(
bool OperatorWithKernel::CanMKLDNNBeUsed(const framework::ExecutionContext& ctx,
proto::VarType::Type data_type) const {
const auto& attrs_map = ctx.Attrs();
auto iter = attrs_map.find("use_mkldnn");
bool use_mkldnn_ctx = iter != attrs_map.end() &&
PADDLE_GET_CONST(bool, iter->second) &&
const std::string use_mkldnn_attr = "use_mkldnn";
bool use_mkldnn_ctx = ctx.HasAttr(use_mkldnn_attr) &&
ctx.Attr<bool>(use_mkldnn_attr) &&
platform::is_cpu_place(ctx.GetPlace());
return use_mkldnn_ctx && this->SupportsMKLDNN(data_type);
}
......@@ -2881,12 +2880,16 @@ void OperatorWithKernel::BuildPhiKernelContext(
}
} break;
default: {
PADDLE_ENFORCE_NE(
attr_iter,
Attrs().end(),
platform::errors::NotFound("(%s) is not found in AttributeMap when "
if (attr_iter == Attrs().end()) {
attr_iter = RuntimeAttrs().find(attr_names[i]);
PADDLE_ENFORCE_NE(attr_iter,
RuntimeAttrs().end(),
platform::errors::NotFound(
"(%s) is not found in AttributeMap when "
"buildind static KernelContext.",
attr_names[i]));
}
switch (attr_defs[i].type_index) {
case phi::AttributeType::FLOAT32:
phi_kernel_context->EmplaceBackAttr(
......
......@@ -177,7 +177,9 @@ class OperatorBase {
const std::string& Type() const { return type_; }
bool HasAttr(const std::string& name) const { return attrs_.count(name); }
bool HasAttr(const std::string& name) const {
return attrs_.count(name) || runtime_attrs_.count(name);
}
template <typename T>
inline const T& Attr(const std::string& name) const {
PADDLE_ENFORCE_NE(
......@@ -196,6 +198,10 @@ class OperatorBase {
attrs_[name] = v;
}
const AttributeMap& Attrs() const { return attrs_; }
const AttributeMap& RuntimeAttrs() const { return runtime_attrs_; }
void SetRuntimeAttributeMap(const AttributeMap& runtime_attrs) {
runtime_attrs_ = runtime_attrs;
}
const VariableNameMap& Inputs() const { return inputs_; }
const VariableNameMap& Outputs() const { return outputs_; }
......@@ -250,6 +256,12 @@ class OperatorBase {
// IG (Inputs Gradients)
VariableNameMap outputs_;
AttributeMap attrs_;
// NOTE: runtime_attrs_ contains the attributes which used for dispatching
// kernel (use_mkldnn, use_cudnn, ...) or passing additional configuration
// for special heterogeneous kernel (workspace_size_MB, ...).
// The attributes in runtime_attrs_ are setted by framework (such as PASS),
// and not in the python api.
AttributeMap runtime_attrs_;
// OpInfo
const OpInfo* info_;
......@@ -302,7 +314,12 @@ class ExecutionContext {
}
virtual const Attribute& GetAttr(const std::string& name) const {
return op_.Attrs().at(name);
auto iter = op_.Attrs().find(name);
if (iter == op_.Attrs().end()) {
return op_.RuntimeAttrs().at(name);
} else {
return iter->second;
}
}
virtual bool HasInput(const std::string& name) const;
......
......@@ -73,7 +73,8 @@ cc_library(
denormal
garbage_collector
var_helper
layout_autotune)
layout_autotune
ops_extra_info)
cc_library(
basic_engine
SRCS basic_engine.cc
......
......@@ -23,6 +23,7 @@
#include "paddle/fluid/imperative/execution_context.h"
#include "paddle/fluid/imperative/layout_autotune.h"
#include "paddle/fluid/imperative/op_base.h"
#include "paddle/fluid/operators/ops_extra_info.h"
#include "paddle/fluid/platform/denormal.h"
#include "paddle/fluid/platform/device/device_wrapper.h"
#include "paddle/fluid/platform/profiler.h"
......@@ -240,6 +241,11 @@ void Tracer::TraceOpImpl(const std::string& type,
if (attr_checker) {
attr_checker->Check(&attrs, true, /*only_check_exist_value=*/true);
}
const auto& extra_attr_checkers =
operators::ExtraInfoUtils::Instance().GetExtraAttrsChecker(type);
for (const auto& checker : extra_attr_checkers) {
checker(&attrs, true);
}
static paddle::framework::AttributeMap empty_attrs_map = {};
const paddle::framework::AttributeMap& default_attrs =
......
......@@ -149,6 +149,7 @@ if (WITH_DGC)
endif()
cc_library(common_infer_shape_functions SRCS common_infer_shape_functions.cc DEPS operator)
cc_library(ops_extra_info SRCS ops_extra_info.cc DEPS attribute cudnn_workspace_helper)
set(COMMON_OP_DEPS ${COMMON_OP_DEPS} selected_rows_functor selected_rows_utils lapack_function
lod_tensor maxouting unpooling pooling lod_rank_table context_project
......
......@@ -348,82 +348,6 @@ void Conv2DOpMaker::Make() {
"dilations(h_dilation, w_dilation) of "
"convolution operator.")
.SetDefault({1, 1});
AddAttr<bool>(
"use_cudnn",
"(bool, default false) Only used in cudnn kernel, need install cudnn")
.SetDefault(false)
.AsExtra();
AddAttr<bool>("fuse_relu_before_depthwise_conv",
"(bool, default false) Only used in cuda depthwise kernel")
.SetDefault(false)
.AsExtra();
AddAttr<bool>("use_mkldnn",
"(bool, default false) Only used in mkldnn kernel")
.SetDefault(false)
.AsExtra();
AddAttr<bool>(
"use_quantizer",
"(bool, default false) "
"This parameter is no longer used. Use 'mkldnn_data_type' instead.")
.SetDefault(false)
.AsExtra();
AddAttr<std::string>(
"mkldnn_data_type",
"(string, default \"float32\"). Data type of mkldnn kernel")
.SetDefault("float32")
.InEnum({"float32", "int8", "bfloat16"})
.AsExtra();
AddAttr<bool>("fuse_relu", "(bool, default false) Only used in mkldnn kernel")
.SetDefault(false)
.AsExtra();
AddAttr<std::string>("fuse_activation",
"(string, default \"\") Only used in mkldnn kernel")
.SetDefault("")
.AsExtra();
AddAttr<float>("fuse_alpha",
"(float, default 0.0) Only used in mkldnn kernel")
.SetDefault(0.0f)
.AsExtra();
AddAttr<float>("fuse_beta", "(float, default 0.0) Only used in mkldnn kernel")
.SetDefault(0.0f)
.AsExtra();
AddAttr<bool>(
"use_addto",
"(bool, default false) If use addto strategy or not, only used in "
"cudnn kernel")
.SetDefault(false)
.AsExtra();
AddAttr<bool>("fuse_residual_connection",
"(bool, default false) Only used in mkldnn kernel. Used "
"whenever convolution output is as an input to residual "
"connection.")
.SetDefault(false)
.AsExtra();
AddAttr<float>("Scale_in",
"Scale_in to be used for int8 input data."
"Only used with MKL-DNN INT8.")
.SetDefault(1.0f)
.AsExtra();
AddAttr<float>("Scale_out",
"Scale_out to be used for int8 output data."
"Only used with MKL-DNN INT8.")
.SetDefault(1.0f)
.AsExtra();
AddAttr<float>("Scale_in_eltwise",
"Scale_in_eltwise to be used for int8 eltwise input data."
"Only used with MKL-DNN INT8.")
.SetDefault(1.0f)
.AsExtra();
AddAttr<std::vector<float>>("Scale_weights",
"Scale_weights to be used for int8 weights data."
"Only used with MKL-DNN INT8.")
.SetDefault({1.0f})
.AsExtra();
AddAttr<bool>("force_fp32_output",
"(bool, default false) Force INT8 kernel output FP32, only "
"used in MKL-DNN INT8")
.SetDefault(false)
.AsExtra();
AddAttr<std::string>(
"data_format",
"(string, default NCHW) Only used in "
......
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/fluid/framework/attribute.h"
namespace paddle {
namespace operators {
template <typename T>
struct ExtraAttrChecker {
ExtraAttrChecker(const std::string& attr_name, T default_value)
: attr_name(attr_name), default_val(default_value) {}
void operator()(framework::AttributeMap* attr_map,
bool only_check_exist_value) {
auto it = attr_map->find(attr_name);
if (it == attr_map->end()) {
if (!only_check_exist_value) {
attr_map->emplace(attr_name, default_val);
}
return;
}
framework::ExtractAttribute<T> extract_attr(attr_name);
extract_attr(it->second);
}
const std::string& attr_name;
T default_val;
};
class ExtraInfoUtils {
public:
static ExtraInfoUtils& Instance() {
static ExtraInfoUtils extra_info_utils;
return extra_info_utils;
}
const std::unordered_map<std::string, paddle::framework::AttributeMap>&
GetAllExtraAttrsMap() const {
return g_extra_attrs_map_;
}
const paddle::framework::AttributeMap& GetExtraAttrsMap(
const std::string& op_type) const {
auto iter = g_extra_attrs_map_.find(op_type);
if (iter != g_extra_attrs_map_.end()) {
return iter->second;
}
return empty_extra_attrs_map_;
}
const std::vector<std::function<void(framework::AttributeMap*, bool)>>&
GetExtraAttrsChecker(const std::string& op_type) const {
auto iter = g_extra_attrs_checker_.find(op_type);
if (iter != g_extra_attrs_checker_.end()) {
return iter->second;
}
return empty_extra_attrs_checker_;
}
private:
ExtraInfoUtils();
std::unordered_map<std::string, paddle::framework::AttributeMap>
g_extra_attrs_map_;
paddle::framework::AttributeMap empty_extra_attrs_map_{};
std::unordered_map<
std::string,
std::vector<std::function<void(framework::AttributeMap*, bool)>>>
g_extra_attrs_checker_;
std::vector<std::function<void(framework::AttributeMap*, bool)>>
empty_extra_attrs_checker_{};
};
} // namespace operators
} // namespace paddle
......@@ -29,6 +29,7 @@
#include "paddle/fluid/framework/variable.h"
#include "paddle/fluid/imperative/tracer.h"
#include "paddle/fluid/imperative/type_defs.h"
#include "paddle/fluid/operators/ops_extra_info.h"
#include "paddle/fluid/pybind/imperative.h"
namespace py = pybind11;
......@@ -961,6 +962,15 @@ void InitOpsAttrTypeMap() {
OpAttrTypeMap::Instance().Map()[iter->first][attr.name()] = attr.type();
}
}
const auto& extra_attr_maps =
operators::ExtraInfoUtils::Instance().GetAllExtraAttrsMap();
for (const auto& extra_attrs : extra_attr_maps) {
for (auto& attr : extra_attrs.second) {
OpAttrTypeMap::Instance().Map()[extra_attrs.first][attr.first] =
static_cast<paddle::framework::proto::AttrType>(attr.second.index() -
1);
}
}
}
ssize_t GetIdxFromCoreOpsInfoMap(
......
......@@ -71,6 +71,7 @@ limitations under the License. */
#include "paddle/fluid/memory/allocation/mmap_allocator.h"
#include "paddle/fluid/operators/activation_op.h"
#include "paddle/fluid/operators/common_infer_shape_functions.h"
#include "paddle/fluid/operators/ops_extra_info.h"
#include "paddle/fluid/operators/py_func_op.h"
#include "paddle/fluid/platform/cpu_helper.h"
#include "paddle/fluid/platform/cpu_info.h"
......@@ -1068,6 +1069,23 @@ All parameter, weight, gradient are variables in Paddle.
}
return res;
});
m.def(
"get_op_extra_attrs",
[](const std::string &op_type)
-> const paddle::framework::AttributeMap & {
return operators::ExtraInfoUtils::Instance().GetExtraAttrsMap(op_type);
});
m.def(
"get_attrtibute_type",
[](const std::string &op_type,
const std::string &attr_name) -> paddle::framework::proto::AttrType {
const auto &defalut_val =
operators::ExtraInfoUtils::Instance().GetExtraAttrsMap(op_type).at(
attr_name);
return static_cast<paddle::framework::proto::AttrType>(
defalut_val.index() - 1);
});
m.def("get_grad_op_desc",
[](const OpDesc &op_desc,
const std::unordered_set<std::string> &no_grad_set,
......
......@@ -100,7 +100,7 @@ set(ops_extra_info_gen_file
set(api_compat_yaml_file
${CMAKE_SOURCE_DIR}/paddle/phi/api/yaml/api_compat.yaml)
set(ops_extra_info_file
${CMAKE_SOURCE_DIR}/paddle/fluid/operators/ops_extra_info.h)
${CMAKE_SOURCE_DIR}/paddle/fluid/operators/ops_extra_info.cc)
if(NOT PYTHONINTERP_FOUND)
find_package(PythonInterp REQUIRED)
......
# - api : conv3d_transpose
# backward : conv3d_transpose_grad
# extra :
# attrs : [bool use_cudnn = true, bool use_mkldnn = false, int workspace_size_MB = platform::GetDefaultConvWorkspaceSizeLimitMB()]
- api : atan2
inputs :
{x : X1, y : X2}
......@@ -23,19 +28,20 @@
out : Out
- api : conv2d
backward : conv2d_grad
extra :
attrs : [bool use_cudnn = false, bool fuse_relu_before_depthwise_conv = false, bool use_mkldnn = false,
bool use_quantizer = false, str mkldnn_data_type = "float32", bool fuse_relu = false,
str fuse_activation = "", bool fuse_alpha = false, bool fuse_beta = false, bool use_addto = false,
str fuse_activation = "", float fuse_alpha = 0.0f, float fuse_beta = 0.0f, bool use_addto = false,
bool fuse_residual_connection = false, float Scale_in = 1.0f, float Scale_out = 1.0f,
float Scale_in_eltwise = 1.0f, 'float[] Scale_weights = {1.0f}', bool force_fp32_output = false,
int workspace_size_MB = 512, bool exhaustive_search = false]
- api : conv2d
- api : conv2d_fusion
extra :
attrs : [bool use_cudnn = false, bool fuse_relu_before_depthwise_conv = false, bool use_mkldnn = false,
bool use_quantizer = false, str mkldnn_data_type = "float32", bool fuse_relu = false,
str fuse_activation = "", bool fuse_alpha = false, bool fuse_beta = false, bool use_addto = false,
str fuse_activation = "", float fuse_alpha = 0.0f, float fuse_beta = 0.0f, bool use_addto = false,
bool fuse_residual_connection = false, float Scale_in = 1.0f, float Scale_out = 1.0f,
float Scale_in_eltwise = 1.0f, 'float[] Scale_weights = {1.0f}', bool force_fp32_output = false,
int workspace_size_MB = 512, bool exhaustive_search = false]
......@@ -48,6 +54,16 @@
outputs :
out : Out
- api : depthwise_conv2d
backward : depthwise_conv2d_grad
extra :
attrs : [bool use_cudnn = false, bool fuse_relu_before_depthwise_conv = false, bool use_mkldnn = false,
bool use_quantizer = false, str mkldnn_data_type = "float32", bool fuse_relu = false,
str fuse_activation = "", float fuse_alpha = 0.0f, float fuse_beta = 0.0f, bool use_addto = false,
bool fuse_residual_connection = false, float Scale_in = 1.0f, float Scale_out = 1.0f,
float Scale_in_eltwise = 1.0f, 'float[] Scale_weights = {1.0f}', bool force_fp32_output = false,
int workspace_size_MB = 512, bool exhaustive_search = false]
- api : diag
op_name : diag_v2
grad_op_name : diag_v2_grad
......
......@@ -18,17 +18,27 @@ import re
import argparse
def map_code_template(attrs_str):
return f"""
#include "paddle/fluid/framework/attribute.h"
def map_code_template(attrs_str, attrs_checker_str):
return f"""// This file is generated by paddle/phi/api/yaml/generator/ops_extra_info_gen.py
#include "paddle/fluid/operators/ops_extra_info.h"
#include "paddle/fluid/platform/cudnn_workspace_helper.h"
namespace paddle {{
const static std::unordered_map<std::string, paddle::framework::AttributeMap> extra_attrs_map = {{
{attrs_str}
}};
namespace operators {{
}} // namespace paddle
ExtraInfoUtils::ExtraInfoUtils() {{
g_extra_attrs_map_ = {{
{attrs_str}
}};
g_extra_attrs_checker_ = {{
{attrs_checker_str}
}};
}}
}} // namespace operators
}} // namespace paddle
"""
......@@ -61,6 +71,7 @@ def generate_extra_info(api_compat_yaml_path, ops_extra_info_path):
compat_apis = yaml.safe_load(f)
extra_map_str_list = []
extra_checker_str_list = []
for api_compat_args in compat_apis:
if 'extra' in api_compat_args:
......@@ -68,8 +79,12 @@ def generate_extra_info(api_compat_yaml_path, ops_extra_info_path):
# TODO(chenweihang): add inputs and outputs
if 'attrs' in extra_args_map:
attr_map_list = []
attr_checker_func_list = []
for attr in extra_args_map['attrs']:
attr_type, attr_name, default_val = parse_attr(attr)
attr_checker_func_list.append(
f"[](framework::AttributeMap* attr_map, bool only_check_exist_value)-> void {{ ExtraAttrChecker<{attr_type}>(\"{attr_name}\", {default_val})(attr_map, only_check_exist_value);}}"
)
if attr_type.startswith("std::vector"):
attr_map_list.append(
f"{{\"{attr_name}\", {attr_type}{default_val}}}")
......@@ -78,12 +93,26 @@ def generate_extra_info(api_compat_yaml_path, ops_extra_info_path):
f"{{\"{attr_name}\", {attr_type}{{{default_val}}}}}"
)
api_extra_attr_map = ", ".join(attr_map_list)
api_extra_attr_checkers = ",\n ".join(
attr_checker_func_list)
extra_map_str_list.append(
f"{{\"{api_compat_args['api']}\", {{ {api_extra_attr_map} }}}}"
)
extra_checker_str_list.append(
f"{{\"{api_compat_args['api']}\", {{ {api_extra_attr_checkers} }}}}"
)
if 'backward' in api_compat_args:
extra_map_str_list.append(
f"{{\"{api_compat_args['backward']}\", {{ {api_extra_attr_map} }}}}"
)
extra_checker_str_list.append(
f"{{\"{api_compat_args['backward']}\", {{ {api_extra_attr_checkers} }}}}"
)
ops_extra_info_file = open(ops_extra_info_path, 'w')
ops_extra_info_file.write(map_code_template(",\n".join(extra_map_str_list)))
ops_extra_info_file.write(
map_code_template(",\n ".join(extra_map_str_list),
",\n ".join(extra_checker_str_list)))
ops_extra_info_file.close()
......@@ -96,7 +125,7 @@ def main():
parser.add_argument('--ops_extra_info_path',
help='output of generated extra_prama_info code file',
default='paddle/fluid/operators/ops_extra_info.h')
default='paddle/fluid/operators/ops_extra_info.cc')
options = parser.parse_args()
......
......@@ -21,6 +21,7 @@
#include "paddle/phi/common/int_array.h"
#include "paddle/phi/common/layout.h"
#include "paddle/phi/common/scalar.h"
#include "paddle/utils/flat_hash_map.h"
#include "paddle/utils/variant.h"
namespace phi {
......@@ -47,4 +48,6 @@ using Attribute = paddle::variant<bool,
DataLayout,
Place>;
using RuntimeAttrs = paddle::flat_hash_map<std::string, Attribute>;
} // namespace phi
......@@ -138,6 +138,8 @@ class KernelContext {
template <typename AttrType>
const AttrType& AttrAt(size_t idx) const;
const RuntimeAttrs& GetRuntimeAttrs() const { return runtime_attrs_; }
size_t InputsSize() const { return inputs_.size(); }
size_t OutputsSize() const { return outputs_.size(); }
size_t AttrsSize() const { return attrs_.size(); }
......@@ -152,6 +154,8 @@ class KernelContext {
paddle::small_vector<std::pair<int, int>, kInputSmallVectorSize> input_range_;
paddle::small_vector<std::pair<int, int>, kOutputSmallVectorSize>
output_range_;
RuntimeAttrs runtime_attrs_;
};
} // namespace phi
......@@ -205,6 +205,8 @@ struct KernelArgsParseFunctor<Return_ (*)(Args_...)> {
args_def->AppendAttribute(AttributeType::DATA_LAYOUT);
} else if (arg_type == std::type_index(typeid(Place))) {
args_def->AppendAttribute(AttributeType::PLACE);
} else if (arg_type == std::type_index(typeid(RuntimeAttrs))) {
// do nothing
} else {
PADDLE_THROW(phi::errors::Unavailable(
"Unsupported kernel argument type `%s`.", arg_type.name()));
......
......@@ -321,6 +321,22 @@ struct KernelImpl<Return (*)(DevCtx, Args...), kernel_fn> {
PD_SPECIALIZE_KernelCallHelper_FOR_OUTPUT(StringTensor);
PD_SPECIALIZE_KernelCallHelper_FOR_MULTI_OUTPUT(StringTensor);
template <typename... Tail>
struct KernelCallHelper<const RuntimeAttrs&, Tail...> {
template <int dev_ctx_idx,
int in_idx,
int attr_idx,
int out_idx,
typename... PreviousArgs>
static void Compute(KernelContext* ctx, PreviousArgs&... pargs) {
const auto& runtime_attrs = ctx->GetRuntimeAttrs();
KernelCallHelper<Tail...>::
template Compute<dev_ctx_idx, in_idx, attr_idx, out_idx>(
ctx, pargs..., runtime_attrs);
}
};
/* End case */
template <typename T>
struct KernelCallHelper<TypeTag<T>> {
......
......@@ -2840,6 +2840,7 @@ class Operator(object):
arg.op = self
self.desc.set_output(out_proto.name, out_arg_names)
extra_attrs_map = core.get_op_extra_attrs(type)
if op_attrs is not None:
if not isinstance(op_attrs, dict):
raise TypeError("'attrs' should be a dict.")
......@@ -2850,6 +2851,13 @@ class Operator(object):
continue
attr_val = op_attrs[attr_name]
self._update_desc_attr(attr_name, attr_val)
for attr_name in extra_attrs_map.keys():
if (attr_name
not in op_attrs) or (op_attrs[attr_name] is None):
self._update_desc_attr(attr_name,
extra_attrs_map[attr_name])
else:
self._update_desc_attr(attr_name, op_attrs[attr_name])
# proto.attrs doesn't include ipu_index
if core.is_compiled_with_ipu():
......@@ -5821,17 +5829,29 @@ class Program(object):
]
res._sync_with_cpp()
# Note: The op_role and op_role_var cann't be deleted currently,
# and we will try to remove them in the future.
common_clipped_attrs_list = ['op_namescope', 'op_callstack']
for i in six.moves.range(res.desc.num_blocks()):
block = res.desc.block(i)
for var in block.all_vars():
var.clear_is_parameter()
var.clear_stop_gradient()
if not clip_extra:
continue
for op_idx in range(0, block.op_size()):
op = block.op(op_idx)
if op.type() not in OpProtoHolder.instance().op_proto_map:
continue
if not clip_extra:
continue
extra_attrs_map = core.get_op_extra_attrs(op.type())
for name in op.attr_names():
if name in extra_attrs_map:
op.remove_attr(name)
continue
proto = OpProtoHolder.instance().get_op_proto(op.type())
remove_input_list = []
for name in op.input_names():
......@@ -5845,8 +5865,9 @@ class Program(object):
break
if not find:
remove_input_list.append(name)
for name in remove_input_list:
op.remove_input(name)
# The extra input of op will be removed in the future
# for name in remove_input_list:
# op.remove_input(name)
remove_output_list = []
for name in op.output_names():
......@@ -5860,10 +5881,10 @@ class Program(object):
break
if not find:
remove_output_list.append(name)
for name in remove_output_list:
op.remove_output(name)
# The extra input of op will be removed in the future
# for name in remove_output_list:
# op.remove_output(name)
remove_attr_list = []
op_quant_name = core.op_proto_and_checker_maker.kOpWithQuantAttrName(
)
quant = bool(op.attr(op_quant_name)
......@@ -5873,18 +5894,21 @@ class Program(object):
"activation_bits", "bit_length", "quantize_weight_bits",
"weight_quant_scale"
]
remove_attr_list = []
for name in op.attr_names():
if quant:
if name in quant_attrs:
continue
if name.endswith("_threshold"):
continue
if name in common_clipped_attrs_list:
remove_attr_list.append(name)
continue
find = False
for attr_proto in proto.attrs:
if attr_proto.name != name:
continue
if attr_proto.extra:
remove_attr_list.append(name)
find = True
break
if not find:
......
......@@ -52,6 +52,7 @@ class OpDescCreationMethod(object):
raise TypeError(
"Type of op_proto should be OpProto in PaddlePaddle.")
self.__op_proto__ = op_proto
self.__extra_attrs__ = core.get_op_extra_attrs(op_proto.type)
def __call__(self, *args, **kwargs):
"""
......@@ -130,6 +131,40 @@ class OpDescCreationMethod(object):
raise NotImplementedError(
"A not supported attribute type: %s." %
(str(attr.type)))
for attr_name, defalut_val in self.__extra_attrs__.items():
user_defined_attr = kwargs.get(attr_name, None)
if user_defined_attr is not None:
attr_type = int(
core.get_attrtibute_type(op_desc.type, attr_name))
new_attr = op_desc.attrs.add()
new_attr.name = attr_name
new_attr.type = attr_type
if isinstance(user_defined_attr, np.ndarray):
user_defined_attr = user_defined_attr.tolist()
if attr_type == framework_pb2.INT:
new_attr.i = user_defined_attr
elif attr_type == framework_pb2.FLOAT:
new_attr.f = user_defined_attr
elif attr_type == framework_pb2.LONG:
new_attr.l = user_defined_attr
elif attr_type == framework_pb2.STRING:
new_attr.s = user_defined_attr
elif attr_type == framework_pb2.BOOLEAN:
new_attr.b = user_defined_attr
elif attr_type == framework_pb2.INTS:
new_attr.ints.extend(user_defined_attr)
elif attr_type == framework_pb2.FLOATS:
new_attr.floats.extend(user_defined_attr)
elif attr_type == framework_pb2.STRINGS:
new_attr.strings.extend(user_defined_attr)
elif attr_type == framework_pb2.BOOLEANS:
new_attr.bools.extend(user_defined_attr)
elif attr_type == framework_pb2.LONGS:
new_attr.longs.extend(user_defined_attr)
else:
raise NotImplementedError(
"A not supported attribute type: %s." %
(str(attr_type)))
return op_desc
......@@ -147,12 +182,13 @@ class OpDescCreationMethod(object):
class OpInfo(object):
def __init__(self, name, method, inputs, outputs, attrs):
def __init__(self, name, method, inputs, outputs, attrs, extra_attrs):
self.name = name
self.method = method
self.inputs = inputs
self.outputs = outputs
self.attrs = attrs
self.extra_attrs = extra_attrs
def create_op_creation_method(op_proto):
......@@ -165,13 +201,16 @@ def create_op_creation_method(op_proto):
opdesc = method(*args, **kwargs)
return core.Operator.create(opdesc.SerializeToString())
extra_attrs_map = core.get_op_extra_attrs(op_proto.type)
return OpInfo(method=__impl__,
name=op_proto.type,
inputs=[(var.name, var.duplicable)
for var in op_proto.inputs],
outputs=[(var.name, var.duplicable)
for var in op_proto.outputs],
attrs=[attr.name for attr in op_proto.attrs])
attrs=[attr.name for attr in op_proto.attrs],
extra_attrs=[item for item in extra_attrs_map.keys()])
class OperatorFactory(object):
......@@ -222,6 +261,9 @@ class OperatorFactory(object):
def get_op_attr_names(self, type):
return self.get_op_info(type).attrs
def get_op_extra_attr_names(self, type):
return self.get_op_info(type).extra_attrs
class __RecurrentOp__(object):
__proto__ = None
......
......@@ -64,6 +64,10 @@ def create_op(scope, op_type, inputs, outputs, attrs, cache_list=None):
if attr_name in attrs:
kwargs[attr_name] = attrs[attr_name]
for extra_attr_name in Operator.get_op_extra_attr_names(op_type):
if extra_attr_name in attrs:
kwargs[extra_attr_name] = attrs[extra_attr_name]
return Operator(op_type, **kwargs)
......
......@@ -240,7 +240,7 @@ if [ "${HAS_MODIFIED_DECLARATIONS}" != "" ] && [ "${GIT_PR_ID}" != "" ]; then
HAS_MODIFIED_API_COMPAT_YAML=`git diff --name-only upstream/$BRANCH | grep "paddle/phi/api/yaml/api_compat.yaml" || true`
if [ "${HAS_MODIFIED_API_COMPAT_YAML}" != "" ] && [ "${GIT_PR_ID}" != "" ]; then
echo_line="You must be approved by chenwhql or zyfncg for paddle/phi/api/yaml/api_compat.yaml, which manages the extra params of Op and name mapping between Yaml and OpMaker. In order to ensure compatibility of framework, this file isn't allowed to be modified at will!\n"
echo_line="You must be approved by chenwhql or zyfncg for paddle/phi/api/yaml/api_compat.yaml changes, which manages the extra params of Op and name mapping between Yaml and OpMaker. In order to ensure compatibility of framework, this file isn't allowed to be modified at will!\n"
check_approval 1 chenwhql zyfncg
fi
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册