Unverified · Commit 627091df · authored by cc, committed by GitHub

Support quantizing conv2d_transpose, test=develop (#3893)

Parent 405e5de6
......@@ -34,12 +34,13 @@ void QuantDequantFusePass::Apply(const std::unique_ptr<SSAGraph>& graph) {
}
// fuse quantized node and dequant node
for (auto& op_type : {"conv2d", "mul", "depthwise_conv2d"}) {
for (auto& op_type :
{"conv2d", "mul", "depthwise_conv2d", "conv2d_transpose"}) {
fusion::DequantOpFuser fuser(op_type);
fuser(graph.get());
}
for (auto& op_type : {"conv2d", "depthwise_conv2d"}) {
for (auto& op_type : {"conv2d", "depthwise_conv2d", "conv2d_transpose"}) {
fusion::ChannelWiseDequantOpFuser fuser(op_type);
fuser(graph.get());
}
......
......@@ -23,6 +23,20 @@ namespace lite {
namespace mir {
namespace fusion {
// Maps a quantized op type to the argument name under which that op
// stores its weight tensor: conv-like ops ("conv2d", "depthwise_conv2d",
// "conv2d_transpose") use "Filter", while mul-like ops ("mul", "matmul")
// use "Y". Returns an empty string for any unrecognized op type.
static std::string GetWeightArgname(const std::string& op_type) {
  if (op_type == "conv2d" || op_type == "depthwise_conv2d" ||
      op_type == "conv2d_transpose") {
    return "Filter";
  }
  if (op_type == "mul" || op_type == "matmul") {
    return "Y";
  }
  return std::string{};
}
void DeleteQuantOpFuser::BuildPattern() {
auto* input_scale_node = VarNode("input_scale_node")
->assert_is_op_input(quant_op_type_, "InScale");
......@@ -83,20 +97,13 @@ cpp::OpDesc DeleteQuantOpFuser::GenOpDesc(const key2nodes_t& matched) {
}
void DequantOpFuser::BuildPattern() {
std::string weight_name = "";
if (quantized_op_type_ == "conv2d" ||
quantized_op_type_ == "depthwise_conv2d") {
weight_name = "Filter";
} else {
weight_name = "Y";
}
std::string weight_argname = GetWeightArgname(quantized_op_type_);
auto* quantized_op_input = VarNode("quantized_op_input")
->assert_is_op_input(quantized_op_type_)
->AsInput();
auto* quantized_op_weight =
VarNode("quantized_op_weight")
->assert_is_op_input(quantized_op_type_, weight_name)
->assert_is_op_input(quantized_op_type_, weight_argname)
->AsInput();
auto* quantized_op = OpNode("quantized_op", quantized_op_type_)
->assert_is_op(quantized_op_type_)
......@@ -152,7 +159,8 @@ void DequantOpFuser::InsertNewNode(SSAGraph* graph,
std::vector<float> weight_scale;
int weight_scale_size = 0;
if (quantized_op_type_ == "conv2d" ||
quantized_op_type_ == "depthwise_conv2d") {
quantized_op_type_ == "depthwise_conv2d" ||
quantized_op_type_ == "conv2d_transpose") {
op_desc.SetInput("Input", {quantized_op_input->arg()->name});
op_desc.SetOutput("Output", {dequant_op_out->arg()->name});
// Conv weight shape: Cout * Cin * kh * hw, the weight_scale_size should
......@@ -199,12 +207,13 @@ cpp::OpDesc DequantOpFuser::GenOpDesc(const key2nodes_t& matched) {
void ChannelWiseDequantOpFuser::BuildPattern() {
std::string dequant_op_type = "fake_channel_wise_dequantize_max_abs";
std::string weight_argname = GetWeightArgname(quantized_op_type_);
auto* quantized_op_input = VarNode("quantized_op_input")
->assert_is_op_input(quantized_op_type_)
->AsInput();
auto* quantized_op_weight =
VarNode("quantized_op_weight")
->assert_is_op_input(quantized_op_type_, "Filter")
->assert_is_op_input(quantized_op_type_, weight_argname)
->AsInput();
auto* quantized_op = OpNode("quantized_op", quantized_op_type_)
->assert_is_op(quantized_op_type_)
......@@ -263,14 +272,17 @@ void ChannelWiseDequantOpFuser::InsertNewNode(SSAGraph* graph,
// set op desc
auto op_desc = *quantized_op->stmt()->op_info();
if (quantized_op_type_ == "conv2d" ||
quantized_op_type_ == "depthwise_conv2d") {
quantized_op_type_ == "depthwise_conv2d" ||
quantized_op_type_ == "conv2d_transpose") {
op_desc.SetInput("Input", {quantized_op_input->arg()->name});
op_desc.SetOutput("Output", {dequant_op_out->arg()->name});
} else if (quantized_op_type_ == "mul" || quantized_op_type_ == "matmul") {
op_desc.SetInput("X", {quantized_op_input->arg()->name});
op_desc.SetOutput("Out", {dequant_op_out->arg()->name});
}
op_desc.SetAttr("enable_int8", true);
if (quantized_op_type_ != "conv2d_transpose") {
op_desc.SetAttr("enable_int8", true);
}
op_desc.SetInputScale(weight_name, weight_scale);
// change the weight from the float type to int8 type.
......
Markdown is supported
0% loaded.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register to comment.