未验证 提交 4b89120b 编写于 作者: S Sławomir Siwek 提交者: GitHub

Remove mkldnn attributes from base ops (#42852)

* remove attrs from base op

* fix typos

* remove brelu

* undo removing code related to matmul

* remove whitespaces

* undo changes in matmul

* remove empty line
上级 94194275
...@@ -489,14 +489,6 @@ void QuantDequantMkldnnPass::UpdateActivations(ir::Graph* graph) const { ...@@ -489,14 +489,6 @@ void QuantDequantMkldnnPass::UpdateActivations(ir::Graph* graph) const {
std::string activation; std::string activation;
if (op_desc->GetAttrIfExists<bool>("fuse_relu")) { if (op_desc->GetAttrIfExists<bool>("fuse_relu")) {
activation = "relu"; activation = "relu";
} else if (op_desc->GetAttrIfExists<bool>("fuse_brelu")) {
activation = "relu6";
float alpha = 6.0;
if (op_desc->HasAttr("fuse_brelu_threshold")) {
alpha = BOOST_GET_CONST(float,
op_desc->GetAttr("fuse_brelu_threshold"));
}
op_node->Op()->SetAttr("fuse_alpha", alpha);
} }
op_node->Op()->SetAttr("fuse_activation", activation); op_node->Op()->SetAttr("fuse_activation", activation);
} }
......
...@@ -77,14 +77,6 @@ extra { ...@@ -77,14 +77,6 @@ extra {
name: "fuse_relu" name: "fuse_relu"
type: BOOLEAN type: BOOLEAN
} }
attrs {
name: "fuse_brelu"
type: BOOLEAN
}
attrs {
name: "fuse_brelu_threshold"
type: FLOAT
}
attrs { attrs {
name: "fuse_activation" name: "fuse_activation"
type: STRING type: STRING
...@@ -134,4 +126,3 @@ extra { ...@@ -134,4 +126,3 @@ extra {
type: BOOLEAN type: BOOLEAN
} }
} }
...@@ -69,14 +69,6 @@ extra { ...@@ -69,14 +69,6 @@ extra {
name: "fuse_relu" name: "fuse_relu"
type: BOOLEAN type: BOOLEAN
} }
attrs {
name: "fuse_brelu"
type: BOOLEAN
}
attrs {
name: "fuse_brelu_threshold"
type: FLOAT
}
attrs { attrs {
name: "fuse_activation" name: "fuse_activation"
type: STRING type: STRING
...@@ -126,4 +118,3 @@ extra { ...@@ -126,4 +118,3 @@ extra {
type: BOOLEAN type: BOOLEAN
} }
} }
...@@ -348,14 +348,6 @@ void Conv2DOpMaker::Make() { ...@@ -348,14 +348,6 @@ void Conv2DOpMaker::Make() {
AddAttr<bool>("fuse_relu", "(bool, default false) Only used in mkldnn kernel") AddAttr<bool>("fuse_relu", "(bool, default false) Only used in mkldnn kernel")
.SetDefault(false) .SetDefault(false)
.AsExtra(); .AsExtra();
AddAttr<bool>("fuse_brelu",
"(bool, default false) Only used in mkldnn kernel")
.SetDefault(false)
.AsExtra();
AddAttr<float>("fuse_brelu_threshold",
"(float, default false 6.0) Only used in mkldnn kernel")
.SetDefault(6.0f)
.AsExtra();
AddAttr<std::string>("fuse_activation", AddAttr<std::string>("fuse_activation",
"(string, default \"\") Only used in mkldnn kernel") "(string, default \"\") Only used in mkldnn kernel")
.SetDefault("") .SetDefault("")
......
...@@ -376,13 +376,6 @@ class Quant2Int8MkldnnPass(object): ...@@ -376,13 +376,6 @@ class Quant2Int8MkldnnPass(object):
activation = "" activation = ""
if op.op().has_attr("fuse_relu") and op.op().attr("fuse_relu"): if op.op().has_attr("fuse_relu") and op.op().attr("fuse_relu"):
activation = "relu" activation = "relu"
elif op.op().has_attr("fuse_brelu") and op.op().attr(
"fuse_brelu"):
activation = "relu6"
alpha = 6.0
if op.op().has_attr("fuse_brelu_threshold"):
alpha = op.op().attr("fuse_brelu_threshold")
op.set_attr("fuse_alpha", alpha)
op.set_attr("fuse_activation", activation) op.set_attr("fuse_activation", activation)
return graph return graph
......
...@@ -177,8 +177,7 @@ class TestQuant2Int8MkldnnPassConv2D(unittest.TestCase): ...@@ -177,8 +177,7 @@ class TestQuant2Int8MkldnnPassConv2D(unittest.TestCase):
'dilations': self.dilations, 'dilations': self.dilations,
'use_cudnn': self.use_cudnn, 'use_cudnn': self.use_cudnn,
'use_mkldnn': self.use_mkldnn, 'use_mkldnn': self.use_mkldnn,
'data_format': self.data_format, 'data_format': self.data_format
'fuse_brelu': True
}) })
def remove_fuse_activation_attribute(self, graph): def remove_fuse_activation_attribute(self, graph):
...@@ -196,9 +195,6 @@ class TestQuant2Int8MkldnnPassConv2D(unittest.TestCase): ...@@ -196,9 +195,6 @@ class TestQuant2Int8MkldnnPassConv2D(unittest.TestCase):
self.assertTrue(op.op().has_attr("fuse_activation")) self.assertTrue(op.op().has_attr("fuse_activation"))
if op.op().has_attr("fuse_relu") and op.op().attr("fuse_relu"): if op.op().has_attr("fuse_relu") and op.op().attr("fuse_relu"):
self.assertTrue(op.op().attr("fuse_activation") == "relu") self.assertTrue(op.op().attr("fuse_activation") == "relu")
if op.op().has_attr("fuse_brelu") and op.op().attr(
"fuse_brelu"):
self.assertTrue(op.op().attr("fuse_activation") == "relu6")
def test_quant_update_activation(self): def test_quant_update_activation(self):
program = fluid.Program() program = fluid.Program()
......
...@@ -61,7 +61,6 @@ class TestConv2DMKLDNNOp(TestConv2DOp): ...@@ -61,7 +61,6 @@ class TestConv2DMKLDNNOp(TestConv2DOp):
self.fuse_activation = "" self.fuse_activation = ""
self.fuse_alpha = 0 self.fuse_alpha = 0
self.fuse_beta = 0 self.fuse_beta = 0
self.fuse_brelu_threshold = 6.0
self.fuse_residual_connection = False self.fuse_residual_connection = False
self.input_residual_size = None self.input_residual_size = None
...@@ -99,7 +98,6 @@ class TestConv2DMKLDNNOp(TestConv2DOp): ...@@ -99,7 +98,6 @@ class TestConv2DMKLDNNOp(TestConv2DOp):
self.attrs['fuse_activation'] = self.fuse_activation self.attrs['fuse_activation'] = self.fuse_activation
self.attrs['fuse_alpha'] = self.fuse_alpha self.attrs['fuse_alpha'] = self.fuse_alpha
self.attrs['fuse_beta'] = self.fuse_beta self.attrs['fuse_beta'] = self.fuse_beta
self.attrs['fuse_brelu_threshold'] = self.fuse_brelu_threshold
self.attrs['fuse_residual_connection'] = self.fuse_residual_connection self.attrs['fuse_residual_connection'] = self.fuse_residual_connection
self.outputs['Output'] = output self.outputs['Output'] = output
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册