Unverified commit f480d474, authored by juncaipeng, committed by GitHub

Optimize quant_dequant (#2215)

* Add DeleteQuantOpFuser
* Add fake_quantize_dequantize_moving_avg_abs_max_op
* Add DeleteQuantDequantOpFuser
Parent 4e05ea29
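A note on the arithmetic shared by all the fusers in this commit: each one reads the "bit_length" attribute from the fake-quant op, forms the signed integer range, and divides the stored OutScale value by it to get the "input_scale" it writes onto the consumer op. A minimal sketch of that computation (the 6.35f OutScale value is made up for illustration):

    #include <iostream>

    int main() {
      int bit_length = 8;                       // attribute on the quant op
      int range = (1 << (bit_length - 1)) - 1;  // 127 for 8-bit quantization
      float saved_scale = 6.35f;                // hypothetical OutScale value
      float input_scale = saved_scale / range;  // stored as "input_scale"
      std::cout << input_scale << std::endl;    // prints 0.05
      return 0;
    }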
@@ -14,6 +14,7 @@
 #include <gflags/gflags.h>
 #include <gtest/gtest.h>
+#include <fstream>
 #include <vector>
 #include "lite/api/cxx_api.h"
 #include "lite/api/paddle_use_kernels.h"
@@ -22,6 +23,10 @@
 #include "lite/api/test_helper.h"
 #include "lite/core/op_registry.h"
+
+DEFINE_string(input_img_txt_path,
+              "",
+              "if set input_img_txt_path, read the img filename as input.");
 
 namespace paddle {
 namespace lite {
@@ -36,9 +41,19 @@ void TestModel(const std::vector<Place>& valid_places) {
   input_tensor->Resize(DDim(std::vector<DDim::value_type>({1, 3, 224, 224})));
   auto* data = input_tensor->mutable_data<float>();
   auto item_size = input_tensor->dims().production();
+  if (FLAGS_input_img_txt_path.empty()) {
     for (int i = 0; i < item_size; i++) {
       data[i] = 1;
     }
+  } else {
+    std::fstream fs(FLAGS_input_img_txt_path, std::ios::in);
+    if (!fs.is_open()) {
+      LOG(FATAL) << "open input_img_txt error.";
+    }
+    for (int i = 0; i < item_size; i++) {
+      fs >> data[i];
+    }
+  }
 
   for (int i = 0; i < FLAGS_warmup; ++i) {
     predictor.Run();
......
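The new --input_img_txt_path flag switches the test input from a constant fill to values streamed in with fs >> data[i], so the file is expected to hold 1 * 3 * 224 * 224 whitespace-separated floats. A sketch of generating such a file (the filename and fill value are arbitrary placeholders):

    #include <fstream>

    int main() {
      std::ofstream fs("input_img.txt");  // hypothetical path
      const int item_size = 1 * 3 * 224 * 224;
      for (int i = 0; i < item_size; i++) {
        fs << 0.5f << " ";  // any float value per element
      }
      return 0;
    }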
@@ -15,7 +15,6 @@
 #include "lite/core/mir/fusion/quant_dequant_fuse_pass.h"
 #include <list>
 #include <memory>
-#include <unordered_set>
 #include <vector>
 #include "lite/api/paddle_place.h"
 #include "lite/core/mir/fusion/quant_dequant_op_fuser.h"
@@ -26,63 +25,25 @@ namespace lite {
 namespace mir {
 
 void QuantDequantFusePass::Apply(const std::unique_ptr<SSAGraph>& graph) {
-  // obtain useful values and save to quantized_node, remove quant_nodes and
-  // releated nodes
-  std::unordered_set<std::string> quant_types = {
+  // delete quant node
+  std::vector<std::string> quant_op_types = {
       "fake_quantize_range_abs_max", "fake_quantize_moving_average_abs_max"};
-  std::vector<Node*> quant_nodes;
-  for (auto& cur_node : graph->mutable_nodes()) {
-    if (cur_node.IsStmt() && quant_types.count(cur_node.stmt()->op_type())) {
-      quant_nodes.push_back(&cur_node);
-    }
-  }
-  for (auto quant_node : quant_nodes) {
-    // find input nodes and output nodes
-    std::list<Node*> input_nodes = quant_node->inlinks;
-    std::list<Node*> output_nodes = quant_node->outlinks;
-    CHECK_EQ(input_nodes.size(), 2);
-    CHECK_EQ(output_nodes.size(), 2);
-    bool front_is_scale = input_nodes.front()->arg()->is_weight;
-    Node* input_scale_node =
-        front_is_scale ? input_nodes.front() : input_nodes.back();
-    Node* input_act_node =
-        front_is_scale ? input_nodes.back() : input_nodes.front();
-    front_is_scale = output_nodes.front()->arg()->is_weight;
-    Node* output_scale_node =
-        front_is_scale ? output_nodes.front() : output_nodes.back();
-    Node* output_act_node =
-        front_is_scale ? output_nodes.back() : output_nodes.front();
-    // relink nodes and save value to quantized_node
-    int bit_length = quant_node->stmt()->op_info()->GetAttr<int>("bit_length");
-    int range = ((1 << (bit_length - 1)) - 1);
-    auto* scope = quant_node->stmt()->op()->scope();
-    auto scale_tensor = scope->FindVar(output_scale_node->arg()->name)
-                            ->GetMutable<lite::Tensor>();
-    float scale_value = scale_tensor->data<float>()[0] / range;
-    auto outlinks = output_act_node->outlinks;
-    for (auto* quantized_node_ptr : outlinks) {
-      quantized_node_ptr->stmt()->mutable_op_info()->SetAttr<int>("bit_length",
-                                                                  bit_length);
-      quantized_node_ptr->stmt()->mutable_op_info()->SetAttr<float>(
-          "input_scale", scale_value);
-      IR_NODE_LINK_TO(input_act_node, quantized_node_ptr)
-      RemoveDirectedLink(output_act_node, quantized_node_ptr);
-    }
-    // delete nodes and edges
-    std::unordered_set<const Node*> nodes2rm = {
-        input_scale_node, quant_node, output_scale_node, output_act_node};
-    GraphSafeRemoveNodes(graph.get(), nodes2rm);
+  for (auto& op_type : quant_op_types) {
+    fusion::DeleteQuantOpFuser fuser(op_type);
+    fuser(graph.get());
   }
 
   // fuse quantized node and dequant node
-  std::unordered_set<std::string> quantized_op_types = {
+  std::vector<std::string> quantized_op_types = {
       "conv2d", "mul", "depthwise_conv2d"};
   for (auto& op_type : quantized_op_types) {
-    fusion::QuantDequantOpFuser fuser(op_type);
+    fusion::DequantOpFuser fuser(op_type);
+    fuser(graph.get());
+  }
+
+  // delete quant_dequant_node
+  for (auto op_type : {"pool2d", "elementwise_add"}) {
+    fusion::DeleteQuantDequantOpFuser fuser(op_type);
     fuser(graph.get());
   }
 }
......
@@ -14,6 +14,7 @@
 #include "lite/core/mir/fusion/quant_dequant_op_fuser.h"
 #include <memory>
+#include <unordered_set>
 #include <vector>
 #include "lite/utils/string.h"
@@ -22,7 +23,61 @@ namespace lite {
 namespace mir {
 namespace fusion {
 
-void QuantDequantOpFuser::BuildPattern() {
+void DeleteQuantOpFuser::BuildPattern() {
+  auto* input_scale_node = VarNode("input_scale_node")
+                               ->assert_is_op_input(quant_op_type_, "InScale");
+  auto* input_act_node =
+      VarNode("input_act_node")->assert_is_op_input(quant_op_type_, "X");
+  auto* quant_node =
+      OpNode("quant_node", quant_op_type_)->assert_is_op(quant_op_type_);
+  auto* output_scale_node =
+      VarNode("output_scale_node")
+          ->assert_is_op_output(quant_op_type_, "OutScale");
+  auto* output_act_node =
+      VarNode("output_act_node")->assert_is_op_output(quant_op_type_, "Out");
+
+  quant_node->LinksFrom({input_scale_node, input_act_node});
+  output_scale_node->LinksFrom({quant_node});
+  output_act_node->LinksFrom({quant_node});
+  VLOG(4) << "DeleteQuantOpFuser BuildPattern quant_op_type:" << quant_op_type_;
+}
+
+void DeleteQuantOpFuser::InsertNewNode(SSAGraph* graph,
+                                       const key2nodes_t& matched) {
+  auto* input_scale_node = matched.at("input_scale_node");
+  auto* input_act_node = matched.at("input_act_node");
+  auto* quant_node = matched.at("quant_node");
+  auto* output_scale_node = matched.at("output_scale_node");
+  auto* output_act_node = matched.at("output_act_node");
+
+  // obtain values, save values and relink node
+  int bit_length = quant_node->stmt()->op_info()->GetAttr<int>("bit_length");
+  int range = ((1 << (bit_length - 1)) - 1);
+  auto* scope = quant_node->stmt()->op()->scope();
+  auto* scale_tensor = scope->FindVar(output_scale_node->arg()->name)
+                           ->GetMutable<lite::Tensor>();
+  float scale_value = scale_tensor->data<float>()[0] / range;
+
+  auto outlinks = output_act_node->outlinks;
+  for (auto* quantized_node : outlinks) {
+    auto* op_desc = quantized_node->stmt()->mutable_op_info();
+    op_desc->SetAttr<int>("bit_length", bit_length);
+    op_desc->SetAttr<float>("input_scale", scale_value);
+    IR_NODE_LINK_TO(input_act_node, quantized_node)
+  }
+
+  // delete nodes and edges
+  std::unordered_set<const Node*> nodes2rm = {
+      input_scale_node, quant_node, output_scale_node, output_act_node};
+  GraphSafeRemoveNodes(graph, nodes2rm);
+}
+
+cpp::OpDesc DeleteQuantOpFuser::GenOpDesc(const key2nodes_t& matched) {
+  cpp::OpDesc op_desc;
+  return op_desc;
+}
+
+void DequantOpFuser::BuildPattern() {
   std::string weight_name = "";
   if (op_type_ == "conv2d" || op_type_ == "depthwise_conv2d") {
     weight_name = "Filter";
@@ -55,9 +110,10 @@ void QuantDequantOpFuser::BuildPattern() {
   quantized_op_out->LinksFrom({quantized_op});
   dequant_op->LinksFrom({quantized_op_out});
   dequant_op_out->LinksFrom({dequant_op});
+  VLOG(4) << "DeQuantOpFuser BuildPattern op_type:" << op_type_;
 }
 
-void QuantDequantOpFuser::InsertNewNode(SSAGraph* graph,
-                                        const key2nodes_t& matched) {
+void DequantOpFuser::InsertNewNode(SSAGraph* graph,
+                                   const key2nodes_t& matched) {
   auto* quant_op_input = matched.at("quantized_op_input");
   auto* quantized_op_weight = matched.at("quantized_op_weight");
@@ -127,7 +183,174 @@ void QuantDequantOpFuser::InsertNewNode(SSAGraph* graph,
   IR_NODE_LINK_TO(new_quantized_op_node, dequant_op_out);
 }
 
-cpp::OpDesc QuantDequantOpFuser::GenOpDesc(const key2nodes_t& matched) {
+cpp::OpDesc DequantOpFuser::GenOpDesc(const key2nodes_t& matched) {
+  cpp::OpDesc op_desc;
+  return op_desc;
+}
+
+void DeleteQuantDequantOpFuser::BuildPattern() {
+  std::string quant_dequant_op_type =
+      "fake_quantize_dequantize_moving_average_abs_max";
+  if (quantized_op_type_ == "pool2d") {
+    auto* input_scale_node =
+        VarNode("input_scale_node")
+            ->assert_is_op_input(quant_dequant_op_type, "InScale");
+    auto* input_act_node = VarNode("input_act_node")
+                               ->assert_is_op_input(quant_dequant_op_type, "X");
+    auto* quant_dequant_node =
+        OpNode("quant_dequant_node", quant_dequant_op_type)
+            ->assert_is_op(quant_dequant_op_type);
+    auto* output_scale_node =
+        VarNode("output_scale_node")
+            ->assert_is_op_output(quant_dequant_op_type, "OutScale");
+    auto* output_act_node =
+        VarNode("output_act_node")
+            ->assert_is_op_output(quant_dequant_op_type, "Out");
+    auto* quantized_node = OpNode("quantized_node", quantized_op_type_)
+                               ->assert_is_op(quantized_op_type_);
+
+    quant_dequant_node->LinksFrom({input_scale_node, input_act_node});
+    output_scale_node->LinksFrom({quant_dequant_node});
+    output_act_node->LinksFrom({quant_dequant_node});
+    quantized_node->LinksFrom({output_act_node});
+  } else if (quantized_op_type_ == "elementwise_add") {
+    auto* input_scale_left_node =
+        VarNode("input_scale_left_node")
+            ->assert_is_op_input(quant_dequant_op_type, "InScale");
+    auto* input_act_left_node =
+        VarNode("input_act_left_node")
+            ->assert_is_op_input(quant_dequant_op_type, "X");
+    auto* quant_dequant_left_node =
+        OpNode("quant_dequant_left_node", quant_dequant_op_type)
+            ->assert_is_op(quant_dequant_op_type);
+    auto* output_scale_left_node =
+        VarNode("output_scale_left_node")
+            ->assert_is_op_output(quant_dequant_op_type, "OutScale");
+    auto* output_act_left_node =
+        VarNode("output_act_left_node")
+            ->assert_is_op_output(quant_dequant_op_type, "Out")
+            ->assert_is_op_input(quantized_op_type_, "X");
+    quant_dequant_left_node->LinksFrom(
+        {input_scale_left_node, input_act_left_node});
+    output_scale_left_node->LinksFrom({quant_dequant_left_node});
+    output_act_left_node->LinksFrom({quant_dequant_left_node});
+
+    auto* input_scale_right_node =
+        VarNode("input_scale_right_node")
+            ->assert_is_op_input(quant_dequant_op_type, "InScale");
+    auto* input_act_right_node =
+        VarNode("input_act_right_node")
+            ->assert_is_op_input(quant_dequant_op_type, "X");
+    auto* quant_dequant_right_node =
+        OpNode("quant_dequant_right_node", quant_dequant_op_type)
+            ->assert_is_op(quant_dequant_op_type);
+    auto* output_scale_right_node =
+        VarNode("output_scale_right_node")
+            ->assert_is_op_output(quant_dequant_op_type, "OutScale");
+    auto* output_act_right_node =
+        VarNode("output_act_right_node")
+            ->assert_is_op_output(quant_dequant_op_type, "Out")
+            ->assert_is_op_input(quantized_op_type_, "Y");
+    quant_dequant_right_node->LinksFrom(
+        {input_scale_right_node, input_act_right_node});
+    output_scale_right_node->LinksFrom({quant_dequant_right_node});
+    output_act_right_node->LinksFrom({quant_dequant_right_node});
+
+    auto* quantized_node = OpNode("quantized_node", quantized_op_type_)
+                               ->assert_is_op(quantized_op_type_);
+    quantized_node->LinksFrom({output_act_left_node, output_act_right_node});
+  } else {
+    LOG(FATAL) << "No support quantized_op_type:" << quantized_op_type_;
+  }
+  VLOG(4) << "DeleteQuantDequantOpFuser BuildPattern op_type:"
+          << quantized_op_type_;
+}
+
+void DeleteQuantDequantOpFuser::InsertNewNode(SSAGraph* graph,
+                                              const key2nodes_t& matched) {
+  if (quantized_op_type_ == "pool2d") {
+    auto* input_scale_node = matched.at("input_scale_node");
+    auto* input_act_node = matched.at("input_act_node");
+    auto* quant_dequant_node = matched.at("quant_dequant_node");
+    auto* output_scale_node = matched.at("output_scale_node");
+    auto* output_act_node = matched.at("output_act_node");
+    auto* quantized_node = matched.at("quantized_node");
+
+    // obtain values, save values and relink node
+    int bit_length =
+        quant_dequant_node->stmt()->op_info()->GetAttr<int>("bit_length");
+    int range = ((1 << (bit_length - 1)) - 1);
+    auto* scope = quant_dequant_node->stmt()->op()->scope();
+    auto* scale_tensor = scope->FindVar(output_scale_node->arg()->name)
+                             ->GetMutable<lite::Tensor>();
+    float scale_value = scale_tensor->data<float>()[0] / range;
+
+    auto* op_desc = quantized_node->stmt()->mutable_op_info();
+    op_desc->SetAttr<int>("bit_length", bit_length);
+    op_desc->SetAttr<float>("input_scale", scale_value);
+    op_desc->SetInput("X", {input_act_node->arg()->name});
+    IR_NODE_LINK_TO(input_act_node, quantized_node)
+
+    // delete nodes and edges
+    std::unordered_set<const Node*> nodes2rm = {input_scale_node,
+                                                quant_dequant_node,
+                                                output_scale_node,
+                                                output_act_node};
+    GraphSafeRemoveNodes(graph, nodes2rm);
+  } else if (quantized_op_type_ == "elementwise_add") {
+    auto* input_scale_left_node = matched.at("input_scale_left_node");
+    auto* input_act_left_node = matched.at("input_act_left_node");
+    auto* quant_dequant_left_node = matched.at("quant_dequant_left_node");
+    auto* output_scale_left_node = matched.at("output_scale_left_node");
+    auto* output_act_left_node = matched.at("output_act_left_node");
+    auto* input_scale_right_node = matched.at("input_scale_right_node");
+    auto* input_act_right_node = matched.at("input_act_right_node");
+    auto* quant_dequant_right_node = matched.at("quant_dequant_right_node");
+    auto* output_scale_right_node = matched.at("output_scale_right_node");
+    auto* output_act_right_node = matched.at("output_act_right_node");
+    auto* quantized_node = matched.at("quantized_node");
+
+    // obtain values, save values and relink node
+    int bit_length =
+        quant_dequant_left_node->stmt()->op_info()->GetAttr<int>("bit_length");
+    int range = ((1 << (bit_length - 1)) - 1);
+    auto* scope = quant_dequant_left_node->stmt()->op()->scope();
+    auto* left_scale_tensor =
+        scope->FindVar(output_scale_left_node->arg()->name)
+            ->GetMutable<lite::Tensor>();
+    float left_scale_value = left_scale_tensor->data<float>()[0] / range;
+    auto* right_scale_tensor =
+        scope->FindVar(output_scale_right_node->arg()->name)
+            ->GetMutable<lite::Tensor>();
+    float right_scale_value = right_scale_tensor->data<float>()[0] / range;
+
+    auto* op_desc = quantized_node->stmt()->mutable_op_info();
+    op_desc->SetAttr<int>("bit_length", bit_length);
+    op_desc->SetAttr<float>("x_input_scale", left_scale_value);
+    op_desc->SetAttr<float>("y_input_scale", right_scale_value);
+    op_desc->SetInput("X", {input_act_left_node->arg()->name});
+    op_desc->SetInput("Y", {input_act_right_node->arg()->name});
+    IR_NODE_LINK_TO(input_act_left_node, quantized_node)
+    IR_NODE_LINK_TO(input_act_right_node, quantized_node)
+
+    // delete nodes and edges
+    std::unordered_set<const Node*> nodes2rm = {input_scale_left_node,
+                                                quant_dequant_left_node,
+                                                output_scale_left_node,
+                                                output_act_left_node,
+                                                input_scale_right_node,
+                                                quant_dequant_right_node,
+                                                output_scale_right_node,
+                                                output_act_right_node};
+    GraphSafeRemoveNodes(graph, nodes2rm);
+  } else {
+    LOG(FATAL) << "No support quantized_op_type:" << quantized_op_type_;
+  }
+}
+
+cpp::OpDesc DeleteQuantDequantOpFuser::GenOpDesc(const key2nodes_t& matched) {
   cpp::OpDesc op_desc;
   return op_desc;
 }
......
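For context on what DeleteQuantDequantOpFuser removes: a fake quantize-dequantize op maps a float tensor back to float, losing only quantization precision, so deleting the node and recording the scale on the consumer preserves semantics. A sketch of the usual quantize-then-dequantize roundtrip (assuming the standard fake-quant formulation; the moving-average scale update is omitted):

    #include <cmath>
    #include <cstdio>

    float fake_quant_dequant(float x, float scale, int bit_length) {
      int range = (1 << (bit_length - 1)) - 1;
      float q = std::round(x / scale * range);  // quantize
      if (q > range) q = range;                 // clip to the int8 range
      if (q < -range) q = -range;
      return q / range * scale;                 // dequantize back to float
    }

    int main() {
      printf("%f\n", fake_quant_dequant(0.1234f, 1.0f, 8));  // ~0.125984
      return 0;
    }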
@@ -34,11 +34,25 @@ namespace fusion {
  * the quantized_op.
  * In addition, the fuser delete fake_quant and fake_dequant op in the graph at
  * the last.
  */
-class QuantDequantOpFuser : public FuseBase {
+class DeleteQuantOpFuser : public FuseBase {
+ public:
+  explicit DeleteQuantOpFuser(const std::string& quant_op_type)
+      : quant_op_type_(quant_op_type) {}
+
+  void BuildPattern() override;
+  void InsertNewNode(SSAGraph* graph, const key2nodes_t& matched) override;
+
+ private:
+  cpp::OpDesc GenOpDesc(const key2nodes_t& matched) override;
+
+ private:
+  std::string quant_op_type_{};
+};
+
+class DequantOpFuser : public FuseBase {
  public:
-  explicit QuantDequantOpFuser(const std::string& op_type)
-      : op_type_(op_type) {}
+  explicit DequantOpFuser(const std::string& op_type) : op_type_(op_type) {}
 
   void BuildPattern() override;
   void InsertNewNode(SSAGraph* graph, const key2nodes_t& matched) override;
@@ -49,6 +63,27 @@ class QuantDequantOpFuser : public FuseBase {
   std::string op_type_{};
 };
 
+/* The pattern like "fake_quantize_dequantize_moving_average_abs_max +
+ * pooled/elementwise_add" can be deteted by this fuser. The fuser
+ * extract the input_scale form fake_quant_dequant_op and save into
+ * the quantized_op. Besides, the fuser delete fake_quant_dequant_op in
+ * the graph.
+ */
+class DeleteQuantDequantOpFuser : public FuseBase {
+ public:
+  explicit DeleteQuantDequantOpFuser(const std::string& quantized_op_type)
+      : quantized_op_type_(quantized_op_type) {}
+
+  void BuildPattern() override;
+  void InsertNewNode(SSAGraph* graph, const key2nodes_t& matched) override;
+
+ private:
+  cpp::OpDesc GenOpDesc(const key2nodes_t& matched) override;
+
+ private:
+  std::string quantized_op_type_{};
+};
+
 }  // namespace fusion
 }  // namespace mir
 }  // namespace lite
......
@@ -75,6 +75,7 @@ add_operator(fake_quantize_range_abs_max_op basic SRCS fake_quantize_range_abs_m
 add_operator(sequence_expand_as_op_lite basic SRCS sequence_expand_as_op.cc DEPS ${op_DEPS})
 add_operator(range_op basic SRCS range_op.cc DEPS ${op_DEPS})
 add_operator(assign_value_op basic SRCS assign_value_op.cc DEPS ${op_DEPS})
+add_operator(fake_quantize_dequantize_moving_avg_abs_max_op basic SRCS fake_quantize_dequantize_moving_avg_max_abs.cc DEPS ${op_DEPS})
 
 # for OCR specific
 add_operator(while_op extra SRCS while_op.cc DEPS ${op_DEPS})
......
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/operators/fake_quantize_dequantize_moving_avg_max_abs.h"
#include "lite/core/op_registry.h"
namespace paddle {
namespace lite {
namespace operators {} // namespace operators
} // namespace lite
} // namespace paddle
REGISTER_LITE_OP(
fake_quantize_dequantize_moving_average_abs_max,
paddle::lite::operators::FakeQuantizeDequantizeMovingAvgMaxAbsOpLite);
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <string>
#include <vector>
#include "lite/core/kernel.h"
#include "lite/core/op_lite.h"
#include "lite/core/scope.h"
#include "lite/core/tensor.h"
#include "lite/operators/op_params.h"
#include "lite/utils/all.h"
namespace paddle {
namespace lite {
namespace operators {
class FakeQuantizeDequantizeMovingAvgMaxAbsOpLite : public OpLite {
public:
FakeQuantizeDequantizeMovingAvgMaxAbsOpLite() {}
explicit FakeQuantizeDequantizeMovingAvgMaxAbsOpLite(const std::string &type)
: OpLite(type) {}
bool CheckShape() const override { return true; }
bool InferShape() const override { return true; }
bool AttachImpl(const cpp::OpDesc &op_desc, lite::Scope *scope) override {
auto x = op_desc.Input("X").front();
auto in_scale = op_desc.Input("InScale").front();
auto out = op_desc.Output("Out").front();
auto out_scale = op_desc.Output("OutScale").front();
param_.x = scope->FindVar(x)->GetMutable<lite::Tensor>();
param_.in_scale = scope->FindVar(in_scale)->GetMutable<lite::Tensor>();
param_.out = scope->FindVar(out)->GetMutable<lite::Tensor>();
param_.out_scale = scope->FindVar(out_scale)->GetMutable<lite::Tensor>();
param_.bit_length = op_desc.GetAttr<int>("bit_length");
return true;
}
void AttachKernel(KernelBase *kernel) override { kernel->SetParam(param_); }
std::string DebugString() const override {
return "fake_quantize_dequantize_moving_avg_max_abs";
}
private:
mutable FakeQuantizeMovingAvgMaxAbsParam param_;
};
} // namespace operators
} // namespace lite
} // namespace paddle
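Note that the new op reuses FakeQuantizeMovingAvgMaxAbsParam rather than defining its own param struct. Judging from AttachImpl above, that struct must expose at least the following fields; this is an inferred sketch, not the actual definition in lite/operators/op_params.h:

    namespace lite { class Tensor; }  // stand-in forward declaration

    struct FakeQuantizeMovingAvgMaxAbsParamSketch {
      lite::Tensor* x{};          // "X" input
      lite::Tensor* in_scale{};   // "InScale" input
      lite::Tensor* out{};        // "Out" output
      lite::Tensor* out_scale{};  // "OutScale" output
      int bit_length{8};          // "bit_length" attribute
    };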
@@ -294,6 +294,8 @@ struct PoolParam {
   bool ceil_mode{false};
   bool use_quantizer{false};
   std::string data_format{"AnyLayout"};
+  // for int8
+  WITH_INT8_CONFIG
 };
 
 // For Dropout op
@@ -332,7 +334,10 @@ struct ElementwiseParam {
   const lite::Tensor* Y{};
   lite::Tensor* Out{};
   int axis{-1};  // for broadcasting.
+  // for int8
   WITH_INT8_CONFIG
+  float x_input_scale{1.0};
+  float y_input_scale{1.0};
 };
 
 struct ElementwiseGradParam {
......
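The two new ElementwiseParam fields give an int8 elementwise_add kernel one scale per operand, matching the x_input_scale/y_input_scale attributes written by DeleteQuantDequantOpFuser. A sketch of how a kernel could consume them (illustrative only, with made-up scale values, not the kernel shipped in Paddle-Lite):

    #include <cstdint>
    #include <cstdio>

    int main() {
      float x_input_scale = 0.05f;  // hypothetical fuser-derived scales
      float y_input_scale = 0.02f;
      int8_t x = 100, y = 50;       // quantized operands
      // Dequantize each operand with its own scale, then add in float.
      float sum = x * x_input_scale + y * y_input_scale;
      printf("%f\n", sum);  // 5.0 + 1.0 = 6.0
      return 0;
    }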