未验证 提交 d4c4c53d 编写于 作者: Z zyfncg 提交者: GitHub

Clear extra attributes of matmul_v2 in OpMaker (#45708)

* set use_cudnn=true for conv2d

* clear opmaker of matmul_v2

* fix bug of set_attr

* add extra attr checker in infer_shape
上级 22f042ba
...@@ -661,10 +661,13 @@ void OpDesc::RemoveAttr(const std::string &name) { ...@@ -661,10 +661,13 @@ void OpDesc::RemoveAttr(const std::string &name) {
void OpDesc::SetAttr(const std::string &name, const Attribute &v) { void OpDesc::SetAttr(const std::string &name, const Attribute &v) {
AttributeMap *attrs_ptr = &(this->attrs_); AttributeMap *attrs_ptr = &(this->attrs_);
bool is_runtime_attr = false;
const auto &extra_attr_map = const auto &extra_attr_map =
operators::ExtraInfoUtils::Instance().GetExtraAttrsMap(Type()); operators::ExtraInfoUtils::Instance().GetExtraAttrsMap(Type());
auto extra_attr_iter = extra_attr_map.find(name); auto extra_attr_iter = extra_attr_map.find(name);
if (extra_attr_iter != extra_attr_map.end()) { if (extra_attr_iter != extra_attr_map.end()) {
is_runtime_attr = true;
attrs_ptr = &(this->runtime_attrs_); attrs_ptr = &(this->runtime_attrs_);
} }
// NOTICE(minqiyang): pybind11 will take the empty list in python as // NOTICE(minqiyang): pybind11 will take the empty list in python as
...@@ -674,8 +677,11 @@ void OpDesc::SetAttr(const std::string &name, const Attribute &v) { ...@@ -674,8 +677,11 @@ void OpDesc::SetAttr(const std::string &name, const Attribute &v) {
if (attr_type == proto::AttrType::INTS && if (attr_type == proto::AttrType::INTS &&
PADDLE_GET_CONST(std::vector<int>, v).size() == 0u) { PADDLE_GET_CONST(std::vector<int>, v).size() == 0u) {
// Find current attr via attr name and set the correct attribute value // Find current attr via attr name and set the correct attribute value
const proto::OpProto::Attr &attr = GetProtoAttr(name); auto attr_type =
switch (attr.type()) { is_runtime_attr
? static_cast<proto::AttrType>(extra_attr_iter->second.index() - 1)
: GetProtoAttr(name).type();
switch (attr_type) {
case proto::AttrType::BOOLEANS: { case proto::AttrType::BOOLEANS: {
VLOG(11) << "SetAttr: " << Type() << ", " << name VLOG(11) << "SetAttr: " << Type() << ", " << name
<< " from INTS to BOOLEANS"; << " from INTS to BOOLEANS";
...@@ -720,7 +726,7 @@ void OpDesc::SetAttr(const std::string &name, const Attribute &v) { ...@@ -720,7 +726,7 @@ void OpDesc::SetAttr(const std::string &name, const Attribute &v) {
} }
default: default:
PADDLE_THROW(platform::errors::Unimplemented( PADDLE_THROW(platform::errors::Unimplemented(
"Unsupported attribute type (code %d).", attr.type())); "Unsupported attribute type (code %d).", attr_type));
} }
need_update_ = true; need_update_ = true;
return; return;
......
...@@ -194,44 +194,6 @@ class MatMulV2OpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -194,44 +194,6 @@ class MatMulV2OpMaker : public framework::OpProtoAndCheckerMaker {
"Set true to transpose the last two dimensions of Y before " "Set true to transpose the last two dimensions of Y before "
"doing multiplication") "doing multiplication")
.SetDefault(false); .SetDefault(false);
AddAttr<std::vector<int>>(
"fused_reshape_Out",
R"DOC(When MKLDNN matmul_v2_transpose_reshape fuse activated, "
"it's a shape atribute of fused reshape for `Out` output.)DOC")
.SetDefault({})
.AsExtra();
AddAttr<std::vector<int>>(
"fused_transpose_Out",
R"DOC(When MKLDNN matmul_v2_transpose_reshape fuse activated, "
"it's a axis atribute of fused transpose for `Out` output.)DOC")
.SetDefault({})
.AsExtra();
AddAttr<bool>("use_mkldnn",
"(bool, default false) Only used in mkldnn kernel")
.SetDefault(false)
.AsExtra();
AddAttr<std::string>(
"mkldnn_data_type",
"(string, default \"float32\"). Data type of mkldnn kernel")
.SetDefault("float32")
.InEnum({"float32", "bfloat16"})
.AsExtra();
AddAttr<std::vector<int>>("fused_reshape_X",
R"DOC(Shape of fused reshape of `X` input.)DOC")
.SetDefault({})
.AsExtra();
AddAttr<std::vector<int>>("fused_reshape_Y",
R"DOC(Shape of fused reshape of `Y` input.)DOC")
.SetDefault({})
.AsExtra();
AddAttr<std::vector<int>>("fused_transpose_X",
R"DOC(Axis of fused transpose of `X` input.)DOC")
.SetDefault({})
.AsExtra();
AddAttr<std::vector<int>>("fused_transpose_Y",
R"DOC(Axis of fused transpose of `Y` input.)DOC")
.SetDefault({})
.AsExtra();
AddComment( AddComment(
R"DOC(Matrix multiplication Out = X * Y. A has shape (d0, d1 ... M, K), R"DOC(Matrix multiplication Out = X * Y. A has shape (d0, d1 ... M, K),
B has shape (d0, d1 ... K, N), Out has shape ((d0, d1 ... M, N)). B has shape (d0, d1 ... K, N), Out has shape ((d0, d1 ... M, N)).
......
...@@ -249,6 +249,13 @@ ...@@ -249,6 +249,13 @@
extra : extra :
attrs : [bool use_mkldnn = false, bool is_test = false] attrs : [bool use_mkldnn = false, bool is_test = false]
- api : matmul (matmul_v2)
backward : matmul_grad (matmul_v2_grad)
extra :
attrs : [bool use_mkldnn = false, 'int[] fused_reshape_Out = {}', 'int[] fused_transpose_Out = {}',
str mkldnn_data_type = "float32", 'int[] fused_reshape_X = {}', 'int[] fused_reshape_Y = {}',
'int[] fused_transpose_X = {}', 'int[] fused_transpose_Y = {}',]
- api : mv - api : mv
inputs : inputs :
{x : X, vec : Vec} {x : X, vec : Vec}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册