From e45e6dc6763e9c1b8abb7704478399cd55117afc Mon Sep 17 00:00:00 2001
From: cc <52520497+juncaipeng@users.noreply.github.com>
Date: Fri, 24 Jul 2020 11:14:00 +0800
Subject: [PATCH] Support dygraph quantized model, test=develop (#3974)

---
 .../mir/fusion/quant_dequant_fuse_pass.cc     | 17 +++--
 .../core/mir/fusion/quant_dequant_op_fuser.cc | 45 +++++++------
 lite/core/mir/fusion/quant_dequant_op_fuser.h | 11 +++-
 lite/operators/CMakeLists.txt                 |  1 +
 .../fake_quantize_dequantize_abs_max.cc       | 25 +++++++
 .../fake_quantize_dequantize_abs_max.h        | 65 +++++++++++++++++++
 lite/operators/op_params.h                    |  7 ++
 7 files changed, 141 insertions(+), 30 deletions(-)
 create mode 100644 lite/operators/fake_quantize_dequantize_abs_max.cc
 create mode 100644 lite/operators/fake_quantize_dequantize_abs_max.h

diff --git a/lite/core/mir/fusion/quant_dequant_fuse_pass.cc b/lite/core/mir/fusion/quant_dequant_fuse_pass.cc
index ea8400b0bb..da42d6d0c7 100644
--- a/lite/core/mir/fusion/quant_dequant_fuse_pass.cc
+++ b/lite/core/mir/fusion/quant_dequant_fuse_pass.cc
@@ -34,20 +34,25 @@ void QuantDequantFusePass::Apply(const std::unique_ptr<SSAGraph>& graph) {
   }
 
   // fuse quantized node and dequant node
-  for (auto& op_type :
-       {"conv2d", "mul", "depthwise_conv2d", "conv2d_transpose"}) {
+  std::vector<std::string> quantized_op_types = {
+      "conv2d", "depthwise_conv2d", "conv2d_transpose", "mul"};
+  for (auto& op_type : quantized_op_types) {
     fusion::DequantOpFuser fuser(op_type);
     fuser(graph.get());
   }
-
-  for (auto& op_type : {"conv2d", "depthwise_conv2d", "conv2d_transpose"}) {
+  for (auto& op_type : quantized_op_types) {
     fusion::ChannelWiseDequantOpFuser fuser(op_type);
     fuser(graph.get());
   }
 
   // process quant_dequant_node
-  fusion::DeleteQuantDequantOpFuser dqd_fuser;
-  dqd_fuser(graph.get());
+  std::vector<std::string> quant_dequant_op_types = {
+      "fake_quantize_dequantize_abs_max",
+      "fake_quantize_dequantize_moving_average_abs_max"};
+  for (auto& op_type : quant_dequant_op_types) {
+    fusion::DeleteQuantDequantOpFuser dqd_fuser(op_type);
+    dqd_fuser(graph.get());
+  }
 }
 
 }  // namespace mir
diff --git a/lite/core/mir/fusion/quant_dequant_op_fuser.cc b/lite/core/mir/fusion/quant_dequant_op_fuser.cc
index 76796468da..758a85c840 100644
--- a/lite/core/mir/fusion/quant_dequant_op_fuser.cc
+++ b/lite/core/mir/fusion/quant_dequant_op_fuser.cc
@@ -175,11 +175,7 @@ void DequantOpFuser::InsertNewNode(SSAGraph* graph,
   for (int i = 0; i < weight_scale_size; i++) {
     weight_scale.push_back(whole_weight_scale);
   }
-
-  // Arm CPU does not support conv2d_transpose
-  if (quantized_op_type_ != "conv2d_transpose") {
-    op_desc.SetAttr("enable_int8", true);
-  }
+  op_desc.SetAttr("enable_int8", true);
   op_desc.SetInputScale(weight_name, weight_scale);
 
   // change the weight from the float type to int8 type.
@@ -284,7 +280,6 @@ void ChannelWiseDequantOpFuser::InsertNewNode(SSAGraph* graph,
     op_desc.SetInput("X", {quantized_op_input->arg()->name});
     op_desc.SetOutput("Out", {dequant_op_out->arg()->name});
   }
-  // Arm CPU does not support conv2d_transpose
   if (quantized_op_type_ != "conv2d_transpose") {
     op_desc.SetAttr("enable_int8", true);
   }
@@ -320,30 +315,33 @@ cpp::OpDesc ChannelWiseDequantOpFuser::GenOpDesc(const key2nodes_t& matched) {
 }
 
 void DeleteQuantDequantOpFuser::BuildPattern() {
-  std::string quant_dequant_op_type =
-      "fake_quantize_dequantize_moving_average_abs_max";
-  auto* input_scale_node =
-      VarNode("input_scale_node")
-          ->assert_is_op_input(quant_dequant_op_type, "InScale");
-  auto* input_act_node =
-      VarNode("input_act_node")->assert_is_op_input(quant_dequant_op_type, "X");
-  auto* quant_dequant_node = OpNode("quant_dequant_node", quant_dequant_op_type)
-                                 ->assert_is_op(quant_dequant_op_type);
+  auto* input_act_node = VarNode("input_act_node")
+                             ->assert_is_op_input(quant_dequant_op_type_, "X");
+  auto* quant_dequant_node =
+      OpNode("quant_dequant_node", quant_dequant_op_type_)
+          ->assert_is_op(quant_dequant_op_type_);
   auto* output_scale_node =
       VarNode("output_scale_node")
-
          ->assert_is_op_output(quant_dequant_op_type, "OutScale");
+          ->assert_is_op_output(quant_dequant_op_type_, "OutScale");
   auto* output_act_node = VarNode("output_act_node")
-                              ->assert_is_op_output(quant_dequant_op_type, "Out");
-
-  quant_dequant_node->LinksFrom({input_scale_node, input_act_node});
+                              ->assert_is_op_output(quant_dequant_op_type_, "Out");
+
+  if (quant_dequant_op_type_ ==
+      "fake_quantize_dequantize_moving_average_abs_max") {
+    auto* input_scale_node =
+        VarNode("input_scale_node")
+            ->assert_is_op_input(quant_dequant_op_type_, "InScale");
+    quant_dequant_node->LinksFrom({input_scale_node, input_act_node});
+  } else {
+    quant_dequant_node->LinksFrom({input_act_node});
+  }
   output_scale_node->LinksFrom({quant_dequant_node});
   output_act_node->LinksFrom({quant_dequant_node});
 }
 
 void DeleteQuantDequantOpFuser::InsertNewNode(SSAGraph* graph,
                                               const key2nodes_t& matched) {
-  auto* input_scale_node = matched.at("input_scale_node");
   auto* input_act_node = matched.at("input_act_node");
   auto* quant_dequant_node = matched.at("quant_dequant_node");
   auto* output_scale_node = matched.at("output_scale_node");
@@ -373,7 +371,12 @@ void DeleteQuantDequantOpFuser::InsertNewNode(SSAGraph* graph,
   }
   // delete nodes and edges
   std::set<const Node*> nodes2rm = {
-      input_scale_node, quant_dequant_node, output_scale_node, output_act_node};
+      quant_dequant_node, output_scale_node, output_act_node};
+  if (quant_dequant_op_type_ ==
+      "fake_quantize_dequantize_moving_average_abs_max") {
+    auto* input_scale_node = matched.at("input_scale_node");
+    nodes2rm.insert(input_scale_node);
+  }
   GraphSafeRemoveNodes(graph, nodes2rm);
 }
 
diff --git a/lite/core/mir/fusion/quant_dequant_op_fuser.h b/lite/core/mir/fusion/quant_dequant_op_fuser.h
index ac3ac112b3..c2dd1e5191 100644
--- a/lite/core/mir/fusion/quant_dequant_op_fuser.h
+++ b/lite/core/mir/fusion/quant_dequant_op_fuser.h
@@ -86,17 +86,22 @@ class ChannelWiseDequantOpFuser : public FuseBase {
   std::string quantized_op_type_{};
 };
 
-/* The pattern like
"fake_quantize_dequantize_moving_average_abs_max +
- * quantized_op" can be deteted by this fuser. The fuser modifies the input
- * scale for the quantized_op and deletes the fake_quant_dequant_op.
+/* The pattern like "fake_quantize_dequantize_op + quantized_op" can be
+ * detected by this fuser. The fuser modifies the input scale for the
+ * quantized_op and deletes the fake_quant_dequant_op.
  */
 class DeleteQuantDequantOpFuser : public FuseBase {
  public:
+  explicit DeleteQuantDequantOpFuser(const std::string& quant_dequant_op_type)
+      : quant_dequant_op_type_(quant_dequant_op_type) {}
   void BuildPattern() override;
   void InsertNewNode(SSAGraph* graph, const key2nodes_t& matched) override;
 
  private:
   cpp::OpDesc GenOpDesc(const key2nodes_t& matched) override;
+
+ private:
+  std::string quant_dequant_op_type_{};
 };
 
 }  // namespace fusion
diff --git a/lite/operators/CMakeLists.txt b/lite/operators/CMakeLists.txt
index 4e67acdb22..17abee4a21 100644
--- a/lite/operators/CMakeLists.txt
+++ b/lite/operators/CMakeLists.txt
@@ -89,6 +89,7 @@ add_operator(fake_quantize_range_abs_max_op extra SRCS fake_quantize_range_abs_m
 add_operator(sequence_expand_as_op_lite extra SRCS sequence_expand_as_op.cc DEPS ${op_DEPS})
 add_operator(assign_value_op basic SRCS assign_value_op.cc DEPS ${op_DEPS})
 add_operator(fake_quantize_dequantize_moving_avg_abs_max_op extra SRCS fake_quantize_dequantize_moving_avg_max_abs.cc DEPS ${op_DEPS})
+add_operator(fake_quantize_dequantize_abs_max_op extra SRCS fake_quantize_dequantize_abs_max.cc DEPS ${op_DEPS})
 add_operator(fake_channel_wise_dequantize_max_abs_op extra SRCS fake_channel_wise_dequantize_max_abs.cc DEPS ${op_DEPS})
 add_operator(split_lod_tensor_op_lite extra SRCS split_lod_tensor_op.cc DEPS ${op_DEPS})
 add_operator(merge_lod_tensor_op_lite extra SRCS merge_lod_tensor_op.cc DEPS ${op_DEPS})
diff --git a/lite/operators/fake_quantize_dequantize_abs_max.cc b/lite/operators/fake_quantize_dequantize_abs_max.cc
new file mode 100644
index
0000000000..354f5e9dcd
--- /dev/null
+++ b/lite/operators/fake_quantize_dequantize_abs_max.cc
@@ -0,0 +1,25 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "lite/operators/fake_quantize_dequantize_abs_max.h"
+#include "lite/core/op_registry.h"
+
+namespace paddle {
+namespace lite {
+namespace operators {}  // namespace operators
+}  // namespace lite
+}  // namespace paddle
+
+REGISTER_LITE_OP(fake_quantize_dequantize_abs_max,
+                 paddle::lite::operators::FakeQuantizeDequantizeAbsMaxOpLite);
diff --git a/lite/operators/fake_quantize_dequantize_abs_max.h b/lite/operators/fake_quantize_dequantize_abs_max.h
new file mode 100644
index 0000000000..7413b448ea
--- /dev/null
+++ b/lite/operators/fake_quantize_dequantize_abs_max.h
@@ -0,0 +1,65 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include <string>
+#include <vector>
+#include "lite/core/kernel.h"
+#include "lite/core/op_lite.h"
+#include "lite/core/scope.h"
+#include "lite/core/tensor.h"
+#include "lite/operators/op_params.h"
+#include "lite/utils/all.h"
+
+namespace paddle {
+namespace lite {
+namespace operators {
+
+class FakeQuantizeDequantizeAbsMaxOpLite : public OpLite {
+ public:
+  FakeQuantizeDequantizeAbsMaxOpLite() {}
+
+  explicit FakeQuantizeDequantizeAbsMaxOpLite(const std::string &type)
+      : OpLite(type) {}
+
+  bool CheckShape() const override { return true; }
+
+  bool InferShapeImpl() const override { return true; }
+
+  bool AttachImpl(const cpp::OpDesc &op_desc, lite::Scope *scope) override {
+    auto x = op_desc.Input("X").front();
+    auto out = op_desc.Output("Out").front();
+    auto out_scale = op_desc.Output("OutScale").front();
+
+    param_.x = scope->FindVar(x)->GetMutable<lite::Tensor>();
+    param_.out = scope->FindVar(out)->GetMutable<lite::Tensor>();
+    param_.out_scale = scope->FindVar(out_scale)->GetMutable<lite::Tensor>();
+    param_.bit_length = op_desc.GetAttr<int>("bit_length");
+    return true;
+  }
+
+  void AttachKernel(KernelBase *kernel) override { kernel->SetParam(param_); }
+
+  std::string DebugString() const override {
+    return "fake_quantize_dequantize_abs_max";
+  }
+
+ private:
+  mutable FakeQuantDequantAbsMaxParam param_;
+};
+
+}  // namespace operators
+}  // namespace lite
+}  // namespace paddle
diff --git a/lite/operators/op_params.h b/lite/operators/op_params.h
index 240cf65d26..ef728924c1 100644
--- a/lite/operators/op_params.h
+++ b/lite/operators/op_params.h
@@ -678,6 +678,13 @@ struct FakeChannelWiseDequantizeMaxAbsParam : ParamBase {
   std::vector<int> quant_bits;
 };
 
+struct FakeQuantDequantAbsMaxParam : ParamBase {
+  const lite::Tensor* x{};
+  lite::Tensor* out{};
+  lite::Tensor* out_scale{};
+  int bit_length;
+};
+
 /// ----------------------- sgd operators ----------------------
 struct SGDParam : ParamBase {
   int dtype{static_cast<int>(VarDescAPI::VarDataType::FP32)};
-- 
GitLab