From d4c4c53da189e52848db3d2b4e13ffb1af8adc71 Mon Sep 17 00:00:00 2001 From: zyfncg Date: Tue, 6 Sep 2022 11:14:32 +0800 Subject: [PATCH] Clear extra attributes of matmul_v2 in OpMaker (#45708) * set use_cudnn=true for conv2d * clear opmaker of matmul_v2 * fix bug of set_attr * add extra attr checker in infer_shape --- paddle/fluid/framework/op_desc.cc | 12 ++++++-- paddle/fluid/operators/matmul_v2_op.cc | 38 -------------------------- paddle/phi/api/yaml/api_compat.yaml | 7 +++++ 3 files changed, 16 insertions(+), 41 deletions(-) diff --git a/paddle/fluid/framework/op_desc.cc b/paddle/fluid/framework/op_desc.cc index 749630b009..fca4ff253d 100644 --- a/paddle/fluid/framework/op_desc.cc +++ b/paddle/fluid/framework/op_desc.cc @@ -661,10 +661,13 @@ void OpDesc::RemoveAttr(const std::string &name) { void OpDesc::SetAttr(const std::string &name, const Attribute &v) { AttributeMap *attrs_ptr = &(this->attrs_); + bool is_runtime_attr = false; + const auto &extra_attr_map = operators::ExtraInfoUtils::Instance().GetExtraAttrsMap(Type()); auto extra_attr_iter = extra_attr_map.find(name); if (extra_attr_iter != extra_attr_map.end()) { + is_runtime_attr = true; attrs_ptr = &(this->runtime_attrs_); } // NOTICE(minqiyang): pybind11 will take the empty list in python as @@ -674,8 +677,11 @@ void OpDesc::SetAttr(const std::string &name, const Attribute &v) { if (attr_type == proto::AttrType::INTS && PADDLE_GET_CONST(std::vector, v).size() == 0u) { // Find current attr via attr name and set the correct attribute value - const proto::OpProto::Attr &attr = GetProtoAttr(name); - switch (attr.type()) { + auto attr_type = + is_runtime_attr + ? static_cast(extra_attr_iter->second.index() - 1) + : GetProtoAttr(name).type(); + switch (attr_type) { case proto::AttrType::BOOLEANS: { VLOG(11) << "SetAttr: " << Type() << ", " << name << " from INTS to BOOLEANS"; @@ -720,7 +726,7 @@ void OpDesc::SetAttr(const std::string &name, const Attribute &v) { } default: PADDLE_THROW(platform::errors::Unimplemented( - "Unsupported attribute type (code %d).", attr.type())); + "Unsupported attribute type (code %d).", attr_type)); } need_update_ = true; return; diff --git a/paddle/fluid/operators/matmul_v2_op.cc b/paddle/fluid/operators/matmul_v2_op.cc index 8c045630af..209bf6d1f6 100644 --- a/paddle/fluid/operators/matmul_v2_op.cc +++ b/paddle/fluid/operators/matmul_v2_op.cc @@ -194,44 +194,6 @@ class MatMulV2OpMaker : public framework::OpProtoAndCheckerMaker { "Set true to transpose the last two dimensions of Y before " "doing multiplication") .SetDefault(false); - AddAttr>( - "fused_reshape_Out", - R"DOC(When MKLDNN matmul_v2_transpose_reshape fuse activated, " - "it's a shape atribute of fused reshape for `Out` output.)DOC") - .SetDefault({}) - .AsExtra(); - AddAttr>( - "fused_transpose_Out", - R"DOC(When MKLDNN matmul_v2_transpose_reshape fuse activated, " - "it's a axis atribute of fused transpose for `Out` output.)DOC") - .SetDefault({}) - .AsExtra(); - AddAttr("use_mkldnn", - "(bool, default false) Only used in mkldnn kernel") - .SetDefault(false) - .AsExtra(); - AddAttr( - "mkldnn_data_type", - "(string, default \"float32\"). Data type of mkldnn kernel") - .SetDefault("float32") - .InEnum({"float32", "bfloat16"}) - .AsExtra(); - AddAttr>("fused_reshape_X", - R"DOC(Shape of fused reshape of `X` input.)DOC") - .SetDefault({}) - .AsExtra(); - AddAttr>("fused_reshape_Y", - R"DOC(Shape of fused reshape of `Y` input.)DOC") - .SetDefault({}) - .AsExtra(); - AddAttr>("fused_transpose_X", - R"DOC(Axis of fused transpose of `X` input.)DOC") - .SetDefault({}) - .AsExtra(); - AddAttr>("fused_transpose_Y", - R"DOC(Axis of fused transpose of `Y` input.)DOC") - .SetDefault({}) - .AsExtra(); AddComment( R"DOC(Matrix multiplication Out = X * Y. A has shape (d0, d1 ... M, K), B has shape (d0, d1 ... K, N), Out has shape ((d0, d1 ... M, N)). diff --git a/paddle/phi/api/yaml/api_compat.yaml b/paddle/phi/api/yaml/api_compat.yaml index 14d54df783..310538f036 100644 --- a/paddle/phi/api/yaml/api_compat.yaml +++ b/paddle/phi/api/yaml/api_compat.yaml @@ -249,6 +249,13 @@ extra : attrs : [bool use_mkldnn = false, bool is_test = false] +- api : matmul (matmul_v2) + backward : matmul_grad (matmul_v2_grad) + extra : + attrs : [bool use_mkldnn = false, 'int[] fused_reshape_Out = {}', 'int[] fused_transpose_Out = {}', + str mkldnn_data_type = "float32", 'int[] fused_reshape_X = {}', 'int[] fused_reshape_Y = {}', + 'int[] fused_transpose_X = {}', 'int[] fused_transpose_Y = {}',] + - api : mv inputs : {x : X, vec : Vec} -- GitLab