提交 05b5ef29 编写于 作者: C cc 提交者: GitHub

Support quantizing softmax op, test=develop (#3051)

上级 1448efc0
@@ -45,7 +45,7 @@ void QuantDequantFusePass::Apply(const std::unique_ptr<SSAGraph>& graph) {
   }
   // delete quant_dequant_node
-  for (auto op_type : {"pool2d", "elementwise_add"}) {
+  for (auto op_type : {"pool2d", "softmax", "elementwise_add"}) {
     fusion::DeleteQuantDequantOpFuser fuser(op_type);
     fuser(graph.get());
   }
......
@@ -297,7 +297,7 @@ cpp::OpDesc ChannelWiseDequantOpFuser::GenOpDesc(const key2nodes_t& matched) {
 void DeleteQuantDequantOpFuser::BuildPattern() {
   std::string quant_dequant_op_type =
       "fake_quantize_dequantize_moving_average_abs_max";
-  if (quantized_op_type_ == "pool2d") {
+  if (quantized_op_type_ == "pool2d" || quantized_op_type_ == "softmax") {
     auto* input_scale_node =
         VarNode("input_scale_node")
             ->assert_is_op_input(quant_dequant_op_type, "InScale");
@@ -374,7 +374,7 @@ void DeleteQuantDequantOpFuser::BuildPattern() {
 void DeleteQuantDequantOpFuser::InsertNewNode(SSAGraph* graph,
                                               const key2nodes_t& matched) {
-  if (quantized_op_type_ == "pool2d") {
+  if (quantized_op_type_ == "pool2d" || quantized_op_type_ == "softmax") {
     auto* input_scale_node = matched.at("input_scale_node");
     auto* input_act_node = matched.at("input_act_node");
     auto* quant_dequant_node = matched.at("quant_dequant_node");
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册