未验证 提交 eb529128 编写于 作者: 王明冬 提交者: GitHub

fix the attributes error in transpose.pbtxt,test=develop. (#33770)

上级 2c4cc68f
...@@ -34,10 +34,13 @@ MatmulTransposeReshapeMKLDNNPass::MatmulTransposeReshapeMKLDNNPass() { ...@@ -34,10 +34,13 @@ MatmulTransposeReshapeMKLDNNPass::MatmulTransposeReshapeMKLDNNPass() {
.IsTensor() .IsTensor()
.End() .End()
.AddAttr("alpha") // unconstrained. can be any float value. .AddAttr("alpha") // unconstrained. can be any float value.
.IsType<float>()
.End() .End()
.AddAttr("transpose_X") // unconstrained. can be any bool value. .AddAttr("transpose_X") // unconstrained. can be any bool value.
.IsType<bool>()
.End() .End()
.AddAttr("transpose_Y") // unconstrained. can be any bool value. .AddAttr("transpose_Y") // unconstrained. can be any bool value.
.IsType<bool>()
.End(); .End();
AddOpCompat(OpCompat("transpose2")) AddOpCompat(OpCompat("transpose2"))
...@@ -51,9 +54,7 @@ MatmulTransposeReshapeMKLDNNPass::MatmulTransposeReshapeMKLDNNPass() { ...@@ -51,9 +54,7 @@ MatmulTransposeReshapeMKLDNNPass::MatmulTransposeReshapeMKLDNNPass() {
.IsTensor() .IsTensor()
.End() .End()
.AddAttr("axis") // ints .AddAttr("axis") // ints
.End() .IsType<std::vector<int>>()
.AddAttr("data_format")
.IsStringIn({"NHWC", "NCHW", "AnyLayout"})
.End(); .End();
AddOpCompat(OpCompat("reshape2")) AddOpCompat(OpCompat("reshape2"))
...@@ -75,6 +76,7 @@ MatmulTransposeReshapeMKLDNNPass::MatmulTransposeReshapeMKLDNNPass() { ...@@ -75,6 +76,7 @@ MatmulTransposeReshapeMKLDNNPass::MatmulTransposeReshapeMKLDNNPass() {
.IsTensor() .IsTensor()
.End() .End()
.AddAttr("shape") // ints .AddAttr("shape") // ints
.IsType<std::vector<int>>()
.End(); .End();
} }
void MatmulTransposeReshapeMKLDNNPass::ApplyImpl(ir::Graph *graph) const { void MatmulTransposeReshapeMKLDNNPass::ApplyImpl(ir::Graph *graph) const {
......
...@@ -28,7 +28,6 @@ void SetOp(ProgramDesc *prog, const std::string &type, ...@@ -28,7 +28,6 @@ void SetOp(ProgramDesc *prog, const std::string &type,
op->SetOutput("Out", {outputs[0]}); op->SetOutput("Out", {outputs[0]});
if (type == "transpose2") { if (type == "transpose2") {
op->SetAttr("axis", std::vector<int>({0, 2, 1, 3})); op->SetAttr("axis", std::vector<int>({0, 2, 1, 3}));
op->SetAttr("data_format", std::string("NCHW"));
op->SetOutput("XShape", {outputs[1]}); op->SetOutput("XShape", {outputs[1]});
} }
if (type == "reshape2") { if (type == "reshape2") {
......
...@@ -10,12 +10,12 @@ def { ...@@ -10,12 +10,12 @@ def {
name: "axis" name: "axis"
type: INTS type: INTS
} }
}
extra {
attrs { attrs {
name: "data_format" name: "data_format"
type: STRING type: STRING
} }
}
extra {
attrs { attrs {
name: "use_mkldnn" name: "use_mkldnn"
type: BOOLEAN type: BOOLEAN
......
...@@ -13,12 +13,12 @@ def { ...@@ -13,12 +13,12 @@ def {
name: "axis" name: "axis"
type: INTS type: INTS
} }
}
extra {
attrs { attrs {
name: "data_format" name: "data_format"
type: STRING type: STRING
} }
}
extra {
attrs { attrs {
name: "use_mkldnn" name: "use_mkldnn"
type: BOOLEAN type: BOOLEAN
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册