From f480d474cdcdea22d25e7e49bd124b8c9a6d1949 Mon Sep 17 00:00:00 2001 From: juncaipeng <52520497+juncaipeng@users.noreply.github.com> Date: Tue, 22 Oct 2019 15:48:38 +0800 Subject: [PATCH] Optimize quant_dequant (#2215) * Add DeleteQuantOpFuser * Add fake_quantize_dequantize_moving_avg_abs_max_op * Add DeleteQuantDequantOpFuser --- lite/api/mobilenetv1_int8_test.cc | 19 +- .../mir/fusion/quant_dequant_fuse_pass.cc | 65 +---- .../core/mir/fusion/quant_dequant_op_fuser.cc | 231 +++++++++++++++++- lite/core/mir/fusion/quant_dequant_op_fuser.h | 43 +++- lite/operators/CMakeLists.txt | 1 + ..._quantize_dequantize_moving_avg_max_abs.cc | 26 ++ ...e_quantize_dequantize_moving_avg_max_abs.h | 69 ++++++ lite/operators/op_params.h | 5 + 8 files changed, 397 insertions(+), 62 deletions(-) create mode 100644 lite/operators/fake_quantize_dequantize_moving_avg_max_abs.cc create mode 100644 lite/operators/fake_quantize_dequantize_moving_avg_max_abs.h diff --git a/lite/api/mobilenetv1_int8_test.cc b/lite/api/mobilenetv1_int8_test.cc index 2a54042f43..fb4a98084c 100644 --- a/lite/api/mobilenetv1_int8_test.cc +++ b/lite/api/mobilenetv1_int8_test.cc @@ -14,6 +14,7 @@ #include #include +#include #include #include "lite/api/cxx_api.h" #include "lite/api/paddle_use_kernels.h" @@ -22,6 +23,10 @@ #include "lite/api/test_helper.h" #include "lite/core/op_registry.h" +DEFINE_string(input_img_txt_path, + "", + "if set input_img_txt_path, read the img filename as input."); + namespace paddle { namespace lite { @@ -36,8 +41,18 @@ void TestModel(const std::vector& valid_places) { input_tensor->Resize(DDim(std::vector({1, 3, 224, 224}))); auto* data = input_tensor->mutable_data(); auto item_size = input_tensor->dims().production(); - for (int i = 0; i < item_size; i++) { - data[i] = 1; + if (FLAGS_input_img_txt_path.empty()) { + for (int i = 0; i < item_size; i++) { + data[i] = 1; + } + } else { + std::fstream fs(FLAGS_input_img_txt_path, std::ios::in); + if (!fs.is_open()) { + LOG(FATAL) << 
"open input_img_txt error."; + } + for (int i = 0; i < item_size; i++) { + fs >> data[i]; + } } for (int i = 0; i < FLAGS_warmup; ++i) { diff --git a/lite/core/mir/fusion/quant_dequant_fuse_pass.cc b/lite/core/mir/fusion/quant_dequant_fuse_pass.cc index 92ef0180ac..5498c28922 100644 --- a/lite/core/mir/fusion/quant_dequant_fuse_pass.cc +++ b/lite/core/mir/fusion/quant_dequant_fuse_pass.cc @@ -15,7 +15,6 @@ #include "lite/core/mir/fusion/quant_dequant_fuse_pass.h" #include #include -#include #include #include "lite/api/paddle_place.h" #include "lite/core/mir/fusion/quant_dequant_op_fuser.h" @@ -26,63 +25,25 @@ namespace lite { namespace mir { void QuantDequantFusePass::Apply(const std::unique_ptr& graph) { - // obtain useful values and save to quantized_node, remove quant_nodes and - // releated nodes - std::unordered_set quant_types = { + // delete quant node + std::vector quant_op_types = { "fake_quantize_range_abs_max", "fake_quantize_moving_average_abs_max"}; - std::vector quant_nodes; - for (auto& cur_node : graph->mutable_nodes()) { - if (cur_node.IsStmt() && quant_types.count(cur_node.stmt()->op_type())) { - quant_nodes.push_back(&cur_node); - } - } - for (auto quant_node : quant_nodes) { - // find input nodes and output nodes - std::list input_nodes = quant_node->inlinks; - std::list output_nodes = quant_node->outlinks; - CHECK_EQ(input_nodes.size(), 2); - CHECK_EQ(output_nodes.size(), 2); - - bool front_is_scale = input_nodes.front()->arg()->is_weight; - Node* input_scale_node = - front_is_scale ? input_nodes.front() : input_nodes.back(); - Node* input_act_node = - front_is_scale ? input_nodes.back() : input_nodes.front(); - front_is_scale = output_nodes.front()->arg()->is_weight; - Node* output_scale_node = - front_is_scale ? output_nodes.front() : output_nodes.back(); - Node* output_act_node = - front_is_scale ? 
output_nodes.back() : output_nodes.front(); - - // relink nodes and save value to quantized_node - int bit_length = quant_node->stmt()->op_info()->GetAttr("bit_length"); - int range = ((1 << (bit_length - 1)) - 1); - auto* scope = quant_node->stmt()->op()->scope(); - auto scale_tensor = scope->FindVar(output_scale_node->arg()->name) - ->GetMutable(); - float scale_value = scale_tensor->data()[0] / range; - - auto outlinks = output_act_node->outlinks; - for (auto* quantized_node_ptr : outlinks) { - quantized_node_ptr->stmt()->mutable_op_info()->SetAttr("bit_length", - bit_length); - quantized_node_ptr->stmt()->mutable_op_info()->SetAttr( - "input_scale", scale_value); - IR_NODE_LINK_TO(input_act_node, quantized_node_ptr) - RemoveDirectedLink(output_act_node, quantized_node_ptr); - } - - // delete nodes and edges - std::unordered_set nodes2rm = { - input_scale_node, quant_node, output_scale_node, output_act_node}; - GraphSafeRemoveNodes(graph.get(), nodes2rm); + for (auto& op_type : quant_op_types) { + fusion::DeleteQuantOpFuser fuser(op_type); + fuser(graph.get()); } // fuse quantized node and dequant node - std::unordered_set quantized_op_types = { + std::vector quantized_op_types = { "conv2d", "mul", "depthwise_conv2d"}; for (auto& op_type : quantized_op_types) { - fusion::QuantDequantOpFuser fuser(op_type); + fusion::DequantOpFuser fuser(op_type); + fuser(graph.get()); + } + + // delete quant_dequant_node + for (auto op_type : {"pool2d", "elementwise_add"}) { + fusion::DeleteQuantDequantOpFuser fuser(op_type); fuser(graph.get()); } } diff --git a/lite/core/mir/fusion/quant_dequant_op_fuser.cc b/lite/core/mir/fusion/quant_dequant_op_fuser.cc index a0ede90446..c8b32d46e2 100644 --- a/lite/core/mir/fusion/quant_dequant_op_fuser.cc +++ b/lite/core/mir/fusion/quant_dequant_op_fuser.cc @@ -14,6 +14,7 @@ #include "lite/core/mir/fusion/quant_dequant_op_fuser.h" #include +#include #include #include "lite/utils/string.h" @@ -22,7 +23,61 @@ namespace lite { namespace mir { 
namespace fusion { -void QuantDequantOpFuser::BuildPattern() { +void DeleteQuantOpFuser::BuildPattern() { + auto* input_scale_node = VarNode("input_scale_node") + ->assert_is_op_input(quant_op_type_, "InScale"); + auto* input_act_node = + VarNode("input_act_node")->assert_is_op_input(quant_op_type_, "X"); + auto* quant_node = + OpNode("quant_node", quant_op_type_)->assert_is_op(quant_op_type_); + auto* output_scale_node = + VarNode("output_scale_node") + ->assert_is_op_output(quant_op_type_, "OutScale"); + auto* output_act_node = + VarNode("output_act_node")->assert_is_op_output(quant_op_type_, "Out"); + + quant_node->LinksFrom({input_scale_node, input_act_node}); + output_scale_node->LinksFrom({quant_node}); + output_act_node->LinksFrom({quant_node}); + VLOG(4) << "DeleteQuantOpFuser BuildPattern quant_op_type:" << quant_op_type_; +} + +void DeleteQuantOpFuser::InsertNewNode(SSAGraph* graph, + const key2nodes_t& matched) { + auto* input_scale_node = matched.at("input_scale_node"); + auto* input_act_node = matched.at("input_act_node"); + auto* quant_node = matched.at("quant_node"); + auto* output_scale_node = matched.at("output_scale_node"); + auto* output_act_node = matched.at("output_act_node"); + + // obtain values, save values and relink node + int bit_length = quant_node->stmt()->op_info()->GetAttr("bit_length"); + int range = ((1 << (bit_length - 1)) - 1); + auto* scope = quant_node->stmt()->op()->scope(); + auto* scale_tensor = scope->FindVar(output_scale_node->arg()->name) + ->GetMutable(); + float scale_value = scale_tensor->data()[0] / range; + + auto outlinks = output_act_node->outlinks; + for (auto* quantized_node : outlinks) { + auto* op_desc = quantized_node->stmt()->mutable_op_info(); + op_desc->SetAttr("bit_length", bit_length); + op_desc->SetAttr("input_scale", scale_value); + IR_NODE_LINK_TO(input_act_node, quantized_node) + } + + // delete nodes and edges + std::unordered_set nodes2rm = { + input_scale_node, quant_node, output_scale_node, 
output_act_node}; + GraphSafeRemoveNodes(graph, nodes2rm); +} + +cpp::OpDesc DeleteQuantOpFuser::GenOpDesc(const key2nodes_t& matched) { + cpp::OpDesc op_desc; + return op_desc; +} + +void DequantOpFuser::BuildPattern() { std::string weight_name = ""; if (op_type_ == "conv2d" || op_type_ == "depthwise_conv2d") { weight_name = "Filter"; @@ -55,10 +110,11 @@ void QuantDequantOpFuser::BuildPattern() { quantized_op_out->LinksFrom({quantized_op}); dequant_op->LinksFrom({quantized_op_out}); dequant_op_out->LinksFrom({dequant_op}); + VLOG(4) << "DeQuantOpFuser BuildPattern op_type:" << op_type_; } -void QuantDequantOpFuser::InsertNewNode(SSAGraph* graph, - const key2nodes_t& matched) { +void DequantOpFuser::InsertNewNode(SSAGraph* graph, + const key2nodes_t& matched) { auto* quant_op_input = matched.at("quantized_op_input"); auto* quantized_op_weight = matched.at("quantized_op_weight"); auto* quantized_op = matched.at("quantized_op"); @@ -127,7 +183,174 @@ void QuantDequantOpFuser::InsertNewNode(SSAGraph* graph, IR_NODE_LINK_TO(new_quantized_op_node, dequant_op_out); } -cpp::OpDesc QuantDequantOpFuser::GenOpDesc(const key2nodes_t& matched) { +cpp::OpDesc DequantOpFuser::GenOpDesc(const key2nodes_t& matched) { + cpp::OpDesc op_desc; + return op_desc; +} + +void DeleteQuantDequantOpFuser::BuildPattern() { + std::string quant_dequant_op_type = + "fake_quantize_dequantize_moving_average_abs_max"; + if (quantized_op_type_ == "pool2d") { + auto* input_scale_node = + VarNode("input_scale_node") + ->assert_is_op_input(quant_dequant_op_type, "InScale"); + auto* input_act_node = VarNode("input_act_node") + ->assert_is_op_input(quant_dequant_op_type, "X"); + auto* quant_dequant_node = + OpNode("quant_dequant_node", quant_dequant_op_type) + ->assert_is_op(quant_dequant_op_type); + auto* output_scale_node = + VarNode("output_scale_node") + ->assert_is_op_output(quant_dequant_op_type, "OutScale"); + auto* output_act_node = + VarNode("output_act_node") + 
->assert_is_op_output(quant_dequant_op_type, "Out"); + auto* quantized_node = OpNode("quantized_node", quantized_op_type_) + ->assert_is_op(quantized_op_type_); + + quant_dequant_node->LinksFrom({input_scale_node, input_act_node}); + output_scale_node->LinksFrom({quant_dequant_node}); + output_act_node->LinksFrom({quant_dequant_node}); + quantized_node->LinksFrom({output_act_node}); + } else if (quantized_op_type_ == "elementwise_add") { + auto* input_scale_left_node = + VarNode("input_scale_left_node") + ->assert_is_op_input(quant_dequant_op_type, "InScale"); + auto* input_act_left_node = + VarNode("input_act_left_node") + ->assert_is_op_input(quant_dequant_op_type, "X"); + auto* quant_dequant_left_node = + OpNode("quant_dequant_left_node", quant_dequant_op_type) + ->assert_is_op(quant_dequant_op_type); + auto* output_scale_left_node = + VarNode("output_scale_left_node") + ->assert_is_op_output(quant_dequant_op_type, "OutScale"); + auto* output_act_left_node = + VarNode("output_act_left_node") + ->assert_is_op_output(quant_dequant_op_type, "Out") + ->assert_is_op_input(quantized_op_type_, "X"); + quant_dequant_left_node->LinksFrom( + {input_scale_left_node, input_act_left_node}); + output_scale_left_node->LinksFrom({quant_dequant_left_node}); + output_act_left_node->LinksFrom({quant_dequant_left_node}); + + auto* input_scale_right_node = + VarNode("input_scale_right_node") + ->assert_is_op_input(quant_dequant_op_type, "InScale"); + auto* input_act_right_node = + VarNode("input_act_right_node") + ->assert_is_op_input(quant_dequant_op_type, "X"); + auto* quant_dequant_right_node = + OpNode("quant_dequant_right_node", quant_dequant_op_type) + ->assert_is_op(quant_dequant_op_type); + auto* output_scale_right_node = + VarNode("output_scale_right_node") + ->assert_is_op_output(quant_dequant_op_type, "OutScale"); + auto* output_act_right_node = + VarNode("output_act_right_node") + ->assert_is_op_output(quant_dequant_op_type, "Out") + 
->assert_is_op_input(quantized_op_type_, "Y"); + quant_dequant_right_node->LinksFrom( + {input_scale_right_node, input_act_right_node}); + output_scale_right_node->LinksFrom({quant_dequant_right_node}); + output_act_right_node->LinksFrom({quant_dequant_right_node}); + + auto* quantized_node = OpNode("quantized_node", quantized_op_type_) + ->assert_is_op(quantized_op_type_); + quantized_node->LinksFrom({output_act_left_node, output_act_right_node}); + } else { + LOG(FATAL) << "No support quantized_op_type:" << quantized_op_type_; + } + VLOG(4) << "DeleteQuantDequantOpFuser BuildPattern op_type:" + << quantized_op_type_; +} + +void DeleteQuantDequantOpFuser::InsertNewNode(SSAGraph* graph, + const key2nodes_t& matched) { + if (quantized_op_type_ == "pool2d") { + auto* input_scale_node = matched.at("input_scale_node"); + auto* input_act_node = matched.at("input_act_node"); + auto* quant_dequant_node = matched.at("quant_dequant_node"); + auto* output_scale_node = matched.at("output_scale_node"); + auto* output_act_node = matched.at("output_act_node"); + auto* quantized_node = matched.at("quantized_node"); + + // obtain values, save values and relink node + int bit_length = + quant_dequant_node->stmt()->op_info()->GetAttr("bit_length"); + int range = ((1 << (bit_length - 1)) - 1); + auto* scope = quant_dequant_node->stmt()->op()->scope(); + auto* scale_tensor = scope->FindVar(output_scale_node->arg()->name) + ->GetMutable(); + float scale_value = scale_tensor->data()[0] / range; + + auto* op_desc = quantized_node->stmt()->mutable_op_info(); + op_desc->SetAttr("bit_length", bit_length); + op_desc->SetAttr("input_scale", scale_value); + op_desc->SetInput("X", {input_act_node->arg()->name}); + IR_NODE_LINK_TO(input_act_node, quantized_node) + + // delete nodes and edges + std::unordered_set nodes2rm = {input_scale_node, + quant_dequant_node, + output_scale_node, + output_act_node}; + GraphSafeRemoveNodes(graph, nodes2rm); + } else if (quantized_op_type_ == 
"elementwise_add") { + auto* input_scale_left_node = matched.at("input_scale_left_node"); + auto* input_act_left_node = matched.at("input_act_left_node"); + auto* quant_dequant_left_node = matched.at("quant_dequant_left_node"); + auto* output_scale_left_node = matched.at("output_scale_left_node"); + auto* output_act_left_node = matched.at("output_act_left_node"); + + auto* input_scale_right_node = matched.at("input_scale_right_node"); + auto* input_act_right_node = matched.at("input_act_right_node"); + auto* quant_dequant_right_node = matched.at("quant_dequant_right_node"); + auto* output_scale_right_node = matched.at("output_scale_right_node"); + auto* output_act_right_node = matched.at("output_act_right_node"); + + auto* quantized_node = matched.at("quantized_node"); + + // obtain values, save values and relink node + int bit_length = + quant_dequant_left_node->stmt()->op_info()->GetAttr("bit_length"); + int range = ((1 << (bit_length - 1)) - 1); + auto* scope = quant_dequant_left_node->stmt()->op()->scope(); + auto* left_scale_tensor = + scope->FindVar(output_scale_left_node->arg()->name) + ->GetMutable(); + float left_scale_value = left_scale_tensor->data()[0] / range; + auto* right_scale_tensor = + scope->FindVar(output_scale_right_node->arg()->name) + ->GetMutable(); + float right_scale_value = right_scale_tensor->data()[0] / range; + + auto* op_desc = quantized_node->stmt()->mutable_op_info(); + op_desc->SetAttr("bit_length", bit_length); + op_desc->SetAttr("x_input_scale", left_scale_value); + op_desc->SetAttr("y_input_scale", right_scale_value); + op_desc->SetInput("X", {input_act_left_node->arg()->name}); + op_desc->SetInput("Y", {input_act_right_node->arg()->name}); + IR_NODE_LINK_TO(input_act_left_node, quantized_node) + IR_NODE_LINK_TO(input_act_right_node, quantized_node) + + // delete nodes and edges + std::unordered_set nodes2rm = {input_scale_left_node, + quant_dequant_left_node, + output_scale_left_node, + output_act_left_node, + 
input_scale_right_node, + quant_dequant_right_node, + output_scale_right_node, + output_act_right_node}; + GraphSafeRemoveNodes(graph, nodes2rm); + } else { + LOG(FATAL) << "No support quantized_op_type:" << quantized_op_type_; + } +} + +cpp::OpDesc DeleteQuantDequantOpFuser::GenOpDesc(const key2nodes_t& matched) { cpp::OpDesc op_desc; return op_desc; } diff --git a/lite/core/mir/fusion/quant_dequant_op_fuser.h b/lite/core/mir/fusion/quant_dequant_op_fuser.h index b635b58f2f..a56fb66577 100644 --- a/lite/core/mir/fusion/quant_dequant_op_fuser.h +++ b/lite/core/mir/fusion/quant_dequant_op_fuser.h @@ -34,11 +34,25 @@ namespace fusion { * the quantized_op. * In addition, the fuser delete fake_quant and fake_dequant op in the graph at * the last. - */ -class QuantDequantOpFuser : public FuseBase { +*/ + +class DeleteQuantOpFuser : public FuseBase { + public: + explicit DeleteQuantOpFuser(const std::string& quant_op_type) + : quant_op_type_(quant_op_type) {} + void BuildPattern() override; + void InsertNewNode(SSAGraph* graph, const key2nodes_t& matched) override; + + private: + cpp::OpDesc GenOpDesc(const key2nodes_t& matched) override; + + private: + std::string quant_op_type_{}; +}; + +class DequantOpFuser : public FuseBase { public: - explicit QuantDequantOpFuser(const std::string& op_type) - : op_type_(op_type) {} + explicit DequantOpFuser(const std::string& op_type) : op_type_(op_type) {} void BuildPattern() override; void InsertNewNode(SSAGraph* graph, const key2nodes_t& matched) override; @@ -49,6 +63,27 @@ class QuantDequantOpFuser : public FuseBase { std::string op_type_{}; }; +/* The pattern like "fake_quantize_dequantize_moving_average_abs_max + + * pool2d/elementwise_add" can be detected by this fuser. The fuser + * extracts the input_scale from fake_quant_dequant_op and saves it into + * the quantized_op. Besides, the fuser deletes fake_quant_dequant_op in + * the graph. 
+*/ + +class DeleteQuantDequantOpFuser : public FuseBase { + public: + explicit DeleteQuantDequantOpFuser(const std::string& quantized_op_type) + : quantized_op_type_(quantized_op_type) {} + void BuildPattern() override; + void InsertNewNode(SSAGraph* graph, const key2nodes_t& matched) override; + + private: + cpp::OpDesc GenOpDesc(const key2nodes_t& matched) override; + + private: + std::string quantized_op_type_{}; +}; + } // namespace fusion } // namespace mir } // namespace lite diff --git a/lite/operators/CMakeLists.txt b/lite/operators/CMakeLists.txt index 0a4beab7f2..2d23d8bb06 100644 --- a/lite/operators/CMakeLists.txt +++ b/lite/operators/CMakeLists.txt @@ -75,6 +75,7 @@ add_operator(fake_quantize_range_abs_max_op basic SRCS fake_quantize_range_abs_m add_operator(sequence_expand_as_op_lite basic SRCS sequence_expand_as_op.cc DEPS ${op_DEPS}) add_operator(range_op basic SRCS range_op.cc DEPS ${op_DEPS}) add_operator(assign_value_op basic SRCS assign_value_op.cc DEPS ${op_DEPS}) +add_operator(fake_quantize_dequantize_moving_avg_abs_max_op basic SRCS fake_quantize_dequantize_moving_avg_max_abs.cc DEPS ${op_DEPS}) # for OCR specific add_operator(while_op extra SRCS while_op.cc DEPS ${op_DEPS}) diff --git a/lite/operators/fake_quantize_dequantize_moving_avg_max_abs.cc b/lite/operators/fake_quantize_dequantize_moving_avg_max_abs.cc new file mode 100644 index 0000000000..5a86d3e468 --- /dev/null +++ b/lite/operators/fake_quantize_dequantize_moving_avg_max_abs.cc @@ -0,0 +1,26 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "lite/operators/fake_quantize_dequantize_moving_avg_max_abs.h" +#include "lite/core/op_registry.h" + +namespace paddle { +namespace lite { +namespace operators {} // namespace operators +} // namespace lite +} // namespace paddle + +REGISTER_LITE_OP( + fake_quantize_dequantize_moving_average_abs_max, + paddle::lite::operators::FakeQuantizeDequantizeMovingAvgMaxAbsOpLite); diff --git a/lite/operators/fake_quantize_dequantize_moving_avg_max_abs.h b/lite/operators/fake_quantize_dequantize_moving_avg_max_abs.h new file mode 100644 index 0000000000..8efa46c415 --- /dev/null +++ b/lite/operators/fake_quantize_dequantize_moving_avg_max_abs.h @@ -0,0 +1,69 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#pragma once + +#include +#include +#include "lite/core/kernel.h" +#include "lite/core/op_lite.h" +#include "lite/core/scope.h" +#include "lite/core/tensor.h" +#include "lite/operators/op_params.h" +#include "lite/utils/all.h" + +namespace paddle { +namespace lite { +namespace operators { + +class FakeQuantizeDequantizeMovingAvgMaxAbsOpLite : public OpLite { + public: + FakeQuantizeDequantizeMovingAvgMaxAbsOpLite() {} + + explicit FakeQuantizeDequantizeMovingAvgMaxAbsOpLite(const std::string &type) + : OpLite(type) {} + + bool CheckShape() const override { return true; } + + bool InferShape() const override { return true; } + + bool AttachImpl(const cpp::OpDesc &op_desc, lite::Scope *scope) override { + auto x = op_desc.Input("X").front(); + auto in_scale = op_desc.Input("InScale").front(); + + auto out = op_desc.Output("Out").front(); + auto out_scale = op_desc.Output("OutScale").front(); + + param_.x = scope->FindVar(x)->GetMutable(); + param_.in_scale = scope->FindVar(in_scale)->GetMutable(); + + param_.out = scope->FindVar(out)->GetMutable(); + param_.out_scale = scope->FindVar(out_scale)->GetMutable(); + param_.bit_length = op_desc.GetAttr("bit_length"); + return true; + } + + void AttachKernel(KernelBase *kernel) override { kernel->SetParam(param_); } + + std::string DebugString() const override { + return "fake_quantize_dequantize_moving_avg_max_abs"; + } + + private: + mutable FakeQuantizeMovingAvgMaxAbsParam param_; +}; + +} // namespace operators +} // namespace lite +} // namespace paddle diff --git a/lite/operators/op_params.h b/lite/operators/op_params.h index ecca911b76..3071f6f907 100644 --- a/lite/operators/op_params.h +++ b/lite/operators/op_params.h @@ -294,6 +294,8 @@ struct PoolParam { bool ceil_mode{false}; bool use_quantizer{false}; std::string data_format{"AnyLayout"}; + // for int8 + WITH_INT8_CONFIG }; // For Dropout op @@ -332,7 +334,10 @@ struct ElementwiseParam { const lite::Tensor* Y{}; lite::Tensor* Out{}; int axis{-1}; // for 
broadcasting. + // for int8 WITH_INT8_CONFIG + float x_input_scale{1.0}; + float y_input_scale{1.0}; }; struct ElementwiseGradParam { -- GitLab