未验证 提交 13568f85 编写于 作者: C cc 提交者: GitHub

Support quantizing softmax op, test=develop (#3051)

上级 51d15cc1
......@@ -45,7 +45,7 @@ void QuantDequantFusePass::Apply(const std::unique_ptr<SSAGraph>& graph) {
}
// delete quant_dequant_node
for (auto op_type : {"pool2d", "elementwise_add"}) {
for (auto op_type : {"pool2d", "softmax", "elementwise_add"}) {
fusion::DeleteQuantDequantOpFuser fuser(op_type);
fuser(graph.get());
}
......
......@@ -297,7 +297,7 @@ cpp::OpDesc ChannelWiseDequantOpFuser::GenOpDesc(const key2nodes_t& matched) {
void DeleteQuantDequantOpFuser::BuildPattern() {
std::string quant_dequant_op_type =
"fake_quantize_dequantize_moving_average_abs_max";
if (quantized_op_type_ == "pool2d") {
if (quantized_op_type_ == "pool2d" || quantized_op_type_ == "softmax") {
auto* input_scale_node =
VarNode("input_scale_node")
->assert_is_op_input(quant_dequant_op_type, "InScale");
......@@ -374,7 +374,7 @@ void DeleteQuantDequantOpFuser::BuildPattern() {
void DeleteQuantDequantOpFuser::InsertNewNode(SSAGraph* graph,
const key2nodes_t& matched) {
if (quantized_op_type_ == "pool2d") {
if (quantized_op_type_ == "pool2d" || quantized_op_type_ == "softmax") {
auto* input_scale_node = matched.at("input_scale_node");
auto* input_act_node = matched.at("input_act_node");
auto* quant_dequant_node = matched.at("quant_dequant_node");
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册