diff --git a/lite/core/mir/fusion/quant_dequant_fuse_pass.cc b/lite/core/mir/fusion/quant_dequant_fuse_pass.cc
index 8ec50b8112b6b853e83abf5c491163fa4475f2f4..ff5a7a1f25239d9dbfc79491bd137804b16b6cfa 100644
--- a/lite/core/mir/fusion/quant_dequant_fuse_pass.cc
+++ b/lite/core/mir/fusion/quant_dequant_fuse_pass.cc
@@ -34,13 +34,16 @@ void QuantDequantFusePass::Apply(const std::unique_ptr<SSAGraph>& graph) {
   }
 
   // fuse quantized node and dequant node
-  std::vector<std::string> quantized_op_types = {
-      "conv2d", "mul", "depthwise_conv2d"};
-  for (auto& op_type : quantized_op_types) {
+  for (auto& op_type : {"conv2d", "mul", "depthwise_conv2d"}) {
     fusion::DequantOpFuser fuser(op_type);
     fuser(graph.get());
   }
+  for (auto& op_type : {"conv2d", "depthwise_conv2d"}) {
+    fusion::ChannelWiseDequantOpFuser fuser(op_type);
+    fuser(graph.get());
+  }
+
 
   // delete quant_dequant_node
   for (auto op_type : {"pool2d", "elementwise_add"}) {
     fusion::DeleteQuantDequantOpFuser fuser(op_type);
diff --git a/lite/core/mir/fusion/quant_dequant_op_fuser.cc b/lite/core/mir/fusion/quant_dequant_op_fuser.cc
index c8b32d46e20586bddc0c1c61fd03cf2a082137e7..f823f45dc66f8ef6cc67cbb9b0d9860c86ec9340 100644
--- a/lite/core/mir/fusion/quant_dequant_op_fuser.cc
+++ b/lite/core/mir/fusion/quant_dequant_op_fuser.cc
@@ -79,23 +79,26 @@ cpp::OpDesc DeleteQuantOpFuser::GenOpDesc(const key2nodes_t& matched) {
 
 void DequantOpFuser::BuildPattern() {
   std::string weight_name = "";
-  if (op_type_ == "conv2d" || op_type_ == "depthwise_conv2d") {
+  if (quantized_op_type_ == "conv2d" ||
+      quantized_op_type_ == "depthwise_conv2d") {
     weight_name = "Filter";
   } else {
     weight_name = "Y";
   }
-  auto* quantized_op_input =
-      VarNode("quantized_op_input")->assert_is_op_input(op_type_)->AsInput();
-  auto* quantized_op_weight = VarNode("quantized_op_weight")
-                                  ->assert_is_op_input(op_type_, weight_name)
-                                  ->AsInput();
-  auto* quantized_op = OpNode("quantized_op", op_type_)
-                           ->assert_is_op(op_type_)
+  auto* quantized_op_input = VarNode("quantized_op_input")
+                                 ->assert_is_op_input(quantized_op_type_)
+                                 ->AsInput();
+  auto* quantized_op_weight =
+      VarNode("quantized_op_weight")
+          ->assert_is_op_input(quantized_op_type_, weight_name)
+          ->AsInput();
+  auto* quantized_op = OpNode("quantized_op", quantized_op_type_)
+                           ->assert_is_op(quantized_op_type_)
                            ->AsIntermediate();
   auto* quantized_op_out = VarNode("quantized_op_out")
-                               ->assert_is_op_output(op_type_)
+                               ->assert_is_op_output(quantized_op_type_)
                                ->assert_is_op_input("fake_dequantize_max_abs", "X")
                                ->AsIntermediate();
   auto* dequant_op = OpNode("dequant_op", "fake_dequantize_max_abs")
                          ->assert_is_op("fake_dequantize_max_abs")
                          ->AsIntermediate();
@@ -110,12 +113,13 @@
   quantized_op_out->LinksFrom({quantized_op});
   dequant_op->LinksFrom({quantized_op_out});
   dequant_op_out->LinksFrom({dequant_op});
-  VLOG(4) << "DeQuantOpFuser BuildPattern op_type:" << op_type_;
+
+  VLOG(4) << "DeQuantOpFuser BuildPattern op_type:" << quantized_op_type_;
 }
 
 void DequantOpFuser::InsertNewNode(SSAGraph* graph,
                                    const key2nodes_t& matched) {
-  auto* quant_op_input = matched.at("quantized_op_input");
+  auto* quantized_op_input = matched.at("quantized_op_input");
   auto* quantized_op_weight = matched.at("quantized_op_weight");
   auto* quantized_op = matched.at("quantized_op");
   auto* dequant_op = matched.at("dequant_op");
@@ -142,14 +146,15 @@ void DequantOpFuser::InsertNewNode(SSAGraph* graph,
       scope->FindVar(quantized_weight_var_name)->GetMutable<lite::Tensor>();
   std::vector<float> weight_scale;
   int weight_scale_size;
-  if (op_type_ == "conv2d" || op_type_ == "depthwise_conv2d") {
-    op_desc.SetInput("Input", {quant_op_input->arg()->name});
+  if (quantized_op_type_ == "conv2d" ||
+      quantized_op_type_ == "depthwise_conv2d") {
+    op_desc.SetInput("Input", {quantized_op_input->arg()->name});
     op_desc.SetOutput("Output", {dequant_op_out->arg()->name});
     // Conv weight shape: Cout * Cin * kh * hw, the weight_scale_size should
     // be Cout.
     weight_scale_size = quantized_weight_t->dims()[0];
-  } else if (op_type_ == "mul") {
-    op_desc.SetInput("X", {quant_op_input->arg()->name});
+  } else if (quantized_op_type_ == "mul") {
+    op_desc.SetInput("X", {quantized_op_input->arg()->name});
     op_desc.SetOutput("Out", {dequant_op_out->arg()->name});
     // Fc weight: Cin * Cout, the weight_scale_size should be Cout.
     weight_scale_size = quantized_weight_t->dims()[1];
@@ -174,11 +179,11 @@
   quantized_weight_t->set_precision(PRECISION(kInt8));
 
   // new op and relink nodes
-  auto new_quantized_op = LiteOpRegistry::Global().Create(op_type_);
+  auto new_quantized_op = LiteOpRegistry::Global().Create(quantized_op_type_);
   new_quantized_op->Attach(op_desc, scope);
   auto* new_quantized_op_node =
       graph->GraphCreateInstructNode(new_quantized_op, valid_places);
-  IR_NODE_LINK_TO(quant_op_input, new_quantized_op_node);
+  IR_NODE_LINK_TO(quantized_op_input, new_quantized_op_node);
   IR_NODE_LINK_TO(quantized_op_weight, new_quantized_op_node);
   IR_NODE_LINK_TO(new_quantized_op_node, dequant_op_out);
 }
@@ -188,6 +193,107 @@ cpp::OpDesc DequantOpFuser::GenOpDesc(const key2nodes_t& matched) {
   return op_desc;
 }
 
+void ChannelWiseDequantOpFuser::BuildPattern() {
+  std::string dequant_op_type = "fake_channel_wise_dequantize_max_abs";
+  auto* quantized_op_input = VarNode("quantized_op_input")
+                                 ->assert_is_op_input(quantized_op_type_)
+                                 ->AsInput();
+  auto* quantized_op_weight =
+      VarNode("quantized_op_weight")
+          ->assert_is_op_input(quantized_op_type_, "Filter")
+          ->AsInput();
+  auto* quantized_op = OpNode("quantized_op", quantized_op_type_)
+                           ->assert_is_op(quantized_op_type_)
+                           ->AsIntermediate();
+  auto* quantized_op_out = VarNode("quantized_op_out")
+                               ->assert_is_op_output(quantized_op_type_)
+                               ->assert_is_op_input(dequant_op_type, "X")
+                               ->AsIntermediate();
+  auto* dequant_op_channel_scale = VarNode("dequant_op_channel_scale")
+                                       ->assert_is_op_input(dequant_op_type)
+                                       ->AsIntermediate();
+  auto* dequant_op = OpNode("dequant_op", dequant_op_type)
+                         ->assert_is_op(dequant_op_type)
+                         ->AsIntermediate();
+  auto* dequant_op_out = VarNode("dequant_op_out")
+                             ->assert_is_op_output(dequant_op_type, "Out")
+                             ->AsOutput();
+
+  quantized_op->LinksFrom({quantized_op_input, quantized_op_weight});
+  quantized_op_out->LinksFrom({quantized_op});
+  dequant_op->LinksFrom({quantized_op_out, dequant_op_channel_scale});
+  dequant_op_out->LinksFrom({dequant_op});
+
+  VLOG(4) << "ChannelWiseDequantOpFuser BuildPattern op_type:"
+          << quantized_op_type_;
+}
+
+void ChannelWiseDequantOpFuser::InsertNewNode(SSAGraph* graph,
+                                              const key2nodes_t& matched) {
+  auto* quantized_op_input = matched.at("quantized_op_input");
+  auto* quantized_op_weight = matched.at("quantized_op_weight");
+  auto* quantized_op = matched.at("quantized_op");
+  auto* dequant_op_channel_scale = matched.at("dequant_op_channel_scale");
+  auto* dequant_op = matched.at("dequant_op");
+  auto* dequant_op_out = matched.at("dequant_op_out");
+
+  // obtain input_scale and weight_scale
+  auto* scope = quantized_op->stmt()->op()->scope();
+  auto& valid_places = quantized_op->stmt()->op()->valid_places();
+  float input_scale =
+      quantized_op->stmt()->op_info()->GetAttr<float>("input_scale");
+
+  std::vector<float> weight_scale;
+  std::vector<int> quant_bits =
+      dequant_op->stmt()->op_info()->GetAttr<std::vector<int>>("quant_bits");
+  int weight_bit_length = quant_bits[0];
+  int range = ((1 << (weight_bit_length - 1)) - 1);
+  auto channel_scale_name = dequant_op_channel_scale->arg()->name;
+  auto channel_scale_tensor =
+      scope->FindVar(channel_scale_name)->GetMutable<lite::Tensor>();
+  auto* channel_scale_data = channel_scale_tensor->data<float>();
+  for (int i = 0; i < channel_scale_tensor->data_size(); i++) {
+    weight_scale.push_back(channel_scale_data[i] / range);
+  }
+
+  // set op desc
+  cpp::OpDesc op_desc = *quantized_op->stmt()->op_info();
+  op_desc.SetInput("Input", {quantized_op_input->arg()->name});
+  op_desc.SetOutput("Output", {dequant_op_out->arg()->name});
+
+  op_desc.SetAttr("enable_int8", true);
+  op_desc.SetAttr("input_scale", input_scale);
+  op_desc.SetAttr("weight_scale", weight_scale);
+
+  // change the weight from the float type to int8 type.
+  auto quantized_weight_var_name = quantized_op_weight->arg()->name;
+  auto quantized_weight_t =
+      scope->FindVar(quantized_weight_var_name)->GetMutable<lite::Tensor>();
+  Tensor temp_tensor;
+  temp_tensor.CopyDataFrom(*quantized_weight_t);
+  float* temp_data = temp_tensor.mutable_data<float>();
+  int8_t* quantized_weight_data = quantized_weight_t->mutable_data<int8_t>();
+  for (size_t i = 0; i < quantized_weight_t->data_size(); i++) {
+    quantized_weight_data[i] = static_cast<int8_t>(temp_data[i]);
+  }
+  quantized_weight_t->set_persistable(true);
+  quantized_weight_t->set_precision(PRECISION(kInt8));
+
+  // new op and relink nodes
+  auto new_quantized_op = LiteOpRegistry::Global().Create(quantized_op_type_);
+  new_quantized_op->Attach(op_desc, scope);
+  auto* new_quantized_op_node =
+      graph->GraphCreateInstructNode(new_quantized_op, valid_places);
+  IR_NODE_LINK_TO(quantized_op_input, new_quantized_op_node);
+  IR_NODE_LINK_TO(quantized_op_weight, new_quantized_op_node);
+  IR_NODE_LINK_TO(new_quantized_op_node, dequant_op_out);
+}
+
+cpp::OpDesc ChannelWiseDequantOpFuser::GenOpDesc(const key2nodes_t& matched) {
+  cpp::OpDesc op_desc;
+  return op_desc;
+}
+
 void DeleteQuantDequantOpFuser::BuildPattern() {
   std::string quant_dequant_op_type =
       "fake_quantize_dequantize_moving_average_abs_max";
diff --git a/lite/core/mir/fusion/quant_dequant_op_fuser.h b/lite/core/mir/fusion/quant_dequant_op_fuser.h
index a56fb665770cb3d523c5666550e295ef51af8474..bef9f4d9573d049700736c166cd0d31b668f7eff 100644
--- a/lite/core/mir/fusion/quant_dequant_op_fuser.h
+++ b/lite/core/mir/fusion/quant_dequant_op_fuser.h
@@ -24,18 +24,21 @@ namespace mir {
 namespace fusion {
 
 /* The model trained by fluid quantization is a simulation of real int8.
- * The quantized Ops(conv2d, mul, depthwise conv2d etc) have fake_quantop
- * in front and fake_dequantop behind.
+ * The quantized ops (conv2d, mul, depthwise_conv2d, etc.) have a fake_quant op
+ * in front and a fake_dequant op behind.
  *
- * When in int8 mode, the pattern like "fake_quant + quantized_op +
- * fake_dequant"
- * can be detected by this fuser. The fuser extract the input_scale and
- * the weight_scale info from fake_quant, fake_dequant op and fuse those into
- * the quantized_op.
+ * For int8 mode, the pattern "fake_quant + quantized_op + fake_dequant"
+ * can be processed by the following three fusers. The fusers extract the
+ * input_scale and weight_scale info from the fake_quant and fake_dequant
+ * ops and fuse them into the quantized_op.
  * In addition, the fuser delete fake_quant and fake_dequant op in the graph at
  * the last.
 */
 
+/* DeleteQuantOpFuser processes
+ * fake_quantize_range_abs_max/fake_quantize_moving_average_abs_max
+ * + conv2d/mul/depthwise_conv2d.
+*/
 class DeleteQuantOpFuser : public FuseBase {
  public:
   explicit DeleteQuantOpFuser(const std::string& quant_op_type)
@@ -50,9 +53,12 @@ class DeleteQuantOpFuser : public FuseBase {
   std::string quant_op_type_{};
 };
 
+/* DequantOpFuser processes conv2d/depthwise_conv2d/mul + fake_dequantize_max_abs.
+*/
 class DequantOpFuser : public FuseBase {
  public:
-  explicit DequantOpFuser(const std::string& op_type) : op_type_(op_type) {}
+  explicit DequantOpFuser(const std::string& quantized_op_type)
+      : quantized_op_type_(quantized_op_type) {}
 
   void BuildPattern() override;
   void InsertNewNode(SSAGraph* graph, const key2nodes_t& matched) override;
@@ -60,7 +66,24 @@ class DequantOpFuser : public FuseBase {
   cpp::OpDesc GenOpDesc(const key2nodes_t& matched) override;
 
  private:
-  std::string op_type_{};
+  std::string quantized_op_type_{};
 };
 
+/* ChannelWiseDequantOpFuser processes conv2d/depthwise_conv2d +
+ * fake_channel_wise_dequantize_max_abs.
+*/
+class ChannelWiseDequantOpFuser : public FuseBase {
+ public:
+  explicit ChannelWiseDequantOpFuser(const std::string& quantized_op_type)
+      : quantized_op_type_(quantized_op_type) {}
+  void BuildPattern() override;
+  void InsertNewNode(SSAGraph* graph, const key2nodes_t& matched) override;
+
+ private:
+  cpp::OpDesc GenOpDesc(const key2nodes_t& matched) override;
+
+ private:
+  std::string quantized_op_type_{};
+};
+
 /* The pattern like "fake_quantize_dequantize_moving_average_abs_max +
diff --git a/lite/operators/CMakeLists.txt b/lite/operators/CMakeLists.txt
index 21b8ec278a6df16711bef3d1b3be34f77c52c9b3..49badbb27b00979117f9e75d1c66763a7be99837 100644
--- a/lite/operators/CMakeLists.txt
+++ b/lite/operators/CMakeLists.txt
@@ -76,6 +76,7 @@
 add_operator(sequence_expand_as_op_lite extra SRCS sequence_expand_as_op.cc DEPS ${op_DEPS})
 add_operator(range_op extra SRCS range_op.cc DEPS ${op_DEPS})
 add_operator(assign_value_op extra SRCS assign_value_op.cc DEPS ${op_DEPS})
 add_operator(fake_quantize_dequantize_moving_avg_abs_max_op extra SRCS fake_quantize_dequantize_moving_avg_max_abs.cc DEPS ${op_DEPS})
+add_operator(fake_channel_wise_dequantize_max_abs_op extra SRCS fake_channel_wise_dequantize_max_abs.cc DEPS ${op_DEPS})
 add_operator(sequence_reshape_op_lite extra SRCS sequence_reshape_op.cc DEPS ${op_DEPS})
 add_operator(reduce_sum_op_lite extra SRCS reduce_ops.cc DEPS ${op_DEPS})
diff --git a/lite/operators/fake_channel_wise_dequantize_max_abs.cc b/lite/operators/fake_channel_wise_dequantize_max_abs.cc
new file mode 100644
index 0000000000000000000000000000000000000000..6bf1a44bb5b1225e62acb1f6a32b43c7887cd906
--- /dev/null
+++ b/lite/operators/fake_channel_wise_dequantize_max_abs.cc
@@ -0,0 +1,26 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "lite/operators/fake_channel_wise_dequantize_max_abs.h"
+#include "lite/core/op_registry.h"
+
+namespace paddle {
+namespace lite {
+namespace operators {}  // namespace operators
+}  // namespace lite
+}  // namespace paddle
+
+REGISTER_LITE_OP(
+    fake_channel_wise_dequantize_max_abs,
+    paddle::lite::operators::FakeChannelWiseDequantizeMaxAbsOpLite);
diff --git a/lite/operators/fake_channel_wise_dequantize_max_abs.h b/lite/operators/fake_channel_wise_dequantize_max_abs.h
new file mode 100644
index 0000000000000000000000000000000000000000..43afb7791fe617af0c7ac496cc62a12e6cc548d2
--- /dev/null
+++ b/lite/operators/fake_channel_wise_dequantize_max_abs.h
@@ -0,0 +1,72 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include <string>
+#include <vector>
+#include "lite/core/kernel.h"
+#include "lite/core/op_lite.h"
+#include "lite/core/scope.h"
+#include "lite/core/tensor.h"
+#include "lite/operators/op_params.h"
+#include "lite/utils/all.h"
+
+namespace paddle {
+namespace lite {
+namespace operators {
+
+class FakeChannelWiseDequantizeMaxAbsOpLite : public OpLite {
+ public:
+  FakeChannelWiseDequantizeMaxAbsOpLite() {}
+
+  explicit FakeChannelWiseDequantizeMaxAbsOpLite(const std::string &type)
+      : OpLite(type) {}
+
+  bool CheckShape() const override { return true; }
+
+  bool InferShape() const override { return true; }
+
+  bool AttachImpl(const cpp::OpDesc &op_desc, lite::Scope *scope) override {
+    auto x = op_desc.Input("X").front();
+    param_.x = scope->FindVar(x)->GetMutable<lite::Tensor>();
+
+    auto args = op_desc.Input("Scales");
+    for (auto arg : args) {
+      auto *var = scope->FindVar(arg);
+      if (var != nullptr) {
+        param_.scale_tensors.push_back(var->GetMutable<lite::Tensor>());
+      }
+    }
+
+    auto out = op_desc.Output("Out").front();
+    param_.out = scope->FindVar(out)->GetMutable<lite::Tensor>();
+
+    param_.quant_bits = op_desc.GetAttr<std::vector<int>>("quant_bits");
+    return true;
+  }
+
+  void AttachKernel(KernelBase *kernel) override { kernel->SetParam(param_); }
+
+  std::string DebugString() const override {
+    return "fake_channel_wise_dequantize_max_abs";
+  }
+
+ private:
+  mutable FakeChannelWiseDequantizeMaxAbsParam param_;
+};
+
+}  // namespace operators
+}  // namespace lite
+}  // namespace paddle
diff --git a/lite/operators/op_params.h b/lite/operators/op_params.h
index 84abe5a2e61e451d7115c8d3cba0e891d55deed6..7ed7715c304377116fd42c9c17971545571c8678 100644
--- a/lite/operators/op_params.h
+++ b/lite/operators/op_params.h
@@ -427,6 +427,13 @@ struct FakeDequantizeMaxAbsParam {
   float max_range;
 };
 
+struct FakeChannelWiseDequantizeMaxAbsParam {
+  const lite::Tensor* x{};
+  std::vector<const lite::Tensor*> scale_tensors{};
+  lite::Tensor* out{};
+  std::vector<int> quant_bits;
+};
+
 /// ----------------------- sgd operators ----------------------
 struct SGDParam {
   int dtype{static_cast<int>(VarDescAPI::VarDataType::FP32)};
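
Reviewer note (not part of the patch): the heart of ChannelWiseDequantOpFuser::InsertNewNode is the per-channel scale math. fake_channel_wise_dequantize_max_abs stores one max-abs value per output channel, and the fuser divides each by the integer range (1 << (quant_bits[0] - 1)) - 1 to recover the real weight_scale. A minimal standalone sketch of that arithmetic, with a hypothetical helper name and no Lite types:

#include <vector>

// Mirror of the fuser's scale computation: for 8-bit weights,
// range = (1 << 7) - 1 = 127, so a stored channel scale of 63.5f
// becomes a weight_scale of 0.5f.
std::vector<float> ComputeWeightScales(
    const std::vector<float>& channel_scales, int weight_bit_length) {
  const int range = (1 << (weight_bit_length - 1)) - 1;
  std::vector<float> weight_scales;
  weight_scales.reserve(channel_scales.size());
  for (float s : channel_scales) {
    weight_scales.push_back(s / range);
  }
  return weight_scales;
}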
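
Relatedly, the weight rewrite in both DequantOpFuser and ChannelWiseDequantOpFuser assumes the fake-quantized float weights already hold integer values in [-range, range], so converting them to int8 is a pure storage-type change with no rounding or rescaling. A sketch under that assumption (helper name hypothetical):

#include <cstdint>
#include <vector>

// Narrow fake-quantized float weights to int8 storage. This is only
// correct if training-time fake quantization already rounded every
// value to an integer in [-127, 127]; otherwise the cast would truncate.
std::vector<int8_t> RepackWeights(const std::vector<float>& fake_quant_weights) {
  std::vector<int8_t> int8_weights;
  int8_weights.reserve(fake_quant_weights.size());
  for (float w : fake_quant_weights) {
    int8_weights.push_back(static_cast<int8_t>(w));
  }
  return int8_weights;
}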