Unverified · Commit ba85799c · authored by juncaipeng · committed by GitHub

add channel_wise_dequantized_max_abs op and ChannelWiseDequantOpFuser (#2368)

* add channel_wise_dequantized_max_abs op and ChannelWiseDequantOpFuser, test=develop
Parent f077276f
@@ -34,13 +34,16 @@ void QuantDequantFusePass::Apply(const std::unique_ptr<SSAGraph>& graph) {
  }

  // fuse quantized node and dequant node
  for (auto& op_type : {"conv2d", "mul", "depthwise_conv2d"}) {
    fusion::DequantOpFuser fuser(op_type);
    fuser(graph.get());
  }
  for (auto& op_type : {"conv2d", "depthwise_conv2d"}) {
    fusion::ChannelWiseDequantOpFuser fuser(op_type);
    fuser(graph.get());
  }

  // delete quant_dequant_node
  for (auto op_type : {"pool2d", "elementwise_add"}) {
    fusion::DeleteQuantDequantOpFuser fuser(op_type);
......
@@ -79,23 +79,26 @@ cpp::OpDesc DeleteQuantOpFuser::GenOpDesc(const key2nodes_t& matched) {

void DequantOpFuser::BuildPattern() {
  std::string weight_name = "";
  if (quantized_op_type_ == "conv2d" ||
      quantized_op_type_ == "depthwise_conv2d") {
    weight_name = "Filter";
  } else {
    weight_name = "Y";
  }
  auto* quantized_op_input = VarNode("quantized_op_input")
                                 ->assert_is_op_input(quantized_op_type_)
                                 ->AsInput();
  auto* quantized_op_weight =
      VarNode("quantized_op_weight")
          ->assert_is_op_input(quantized_op_type_, weight_name)
          ->AsInput();
  auto* quantized_op = OpNode("quantized_op", quantized_op_type_)
                           ->assert_is_op(quantized_op_type_)
                           ->AsIntermediate();
  auto* quantized_op_out =
      VarNode("quantized_op_out")
          ->assert_is_op_output(quantized_op_type_)
          ->assert_is_op_input("fake_dequantize_max_abs", "X")
          ->AsIntermediate();
  auto* dequant_op = OpNode("dequant_op", "fake_dequantize_max_abs")
@@ -110,12 +113,13 @@ void DequantOpFuser::BuildPattern() {
  quantized_op_out->LinksFrom({quantized_op});
  dequant_op->LinksFrom({quantized_op_out});
  dequant_op_out->LinksFrom({dequant_op});

  VLOG(4) << "DeQuantOpFuser BuildPattern op_type:" << quantized_op_type_;
}

void DequantOpFuser::InsertNewNode(SSAGraph* graph,
                                   const key2nodes_t& matched) {
  auto* quantized_op_input = matched.at("quantized_op_input");
  auto* quantized_op_weight = matched.at("quantized_op_weight");
  auto* quantized_op = matched.at("quantized_op");
  auto* dequant_op = matched.at("dequant_op");
@@ -142,14 +146,15 @@ void DequantOpFuser::InsertNewNode(SSAGraph* graph,
      scope->FindVar(quantized_weight_var_name)->GetMutable<lite::Tensor>();
  std::vector<float> weight_scale;
  int weight_scale_size;
  if (quantized_op_type_ == "conv2d" ||
      quantized_op_type_ == "depthwise_conv2d") {
    op_desc.SetInput("Input", {quantized_op_input->arg()->name});
    op_desc.SetOutput("Output", {dequant_op_out->arg()->name});
    // Conv weight shape: Cout * Cin * kh * kw, so the weight_scale_size
    // should be Cout.
    weight_scale_size = quantized_weight_t->dims()[0];
  } else if (quantized_op_type_ == "mul") {
    op_desc.SetInput("X", {quantized_op_input->arg()->name});
    op_desc.SetOutput("Out", {dequant_op_out->arg()->name});
    // Fc weight shape: Cin * Cout, so the weight_scale_size should be Cout.
    weight_scale_size = quantized_weight_t->dims()[1];
@@ -174,11 +179,11 @@ void DequantOpFuser::InsertNewNode(SSAGraph* graph,
  quantized_weight_t->set_precision(PRECISION(kInt8));

  // new op and relink nodes
  auto new_quantized_op = LiteOpRegistry::Global().Create(quantized_op_type_);
  new_quantized_op->Attach(op_desc, scope);
  auto* new_quantized_op_node =
      graph->GraphCreateInstructNode(new_quantized_op, valid_places);
  IR_NODE_LINK_TO(quantized_op_input, new_quantized_op_node);
  IR_NODE_LINK_TO(quantized_op_weight, new_quantized_op_node);
  IR_NODE_LINK_TO(new_quantized_op_node, dequant_op_out);
}
@@ -188,6 +193,107 @@ cpp::OpDesc DequantOpFuser::GenOpDesc(const key2nodes_t& matched) {
  return op_desc;
}
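// DequantOpFuser above consumes a single per-layer scale produced by
// fake_dequantize_max_abs; ChannelWiseDequantOpFuser below instead reads one
// scale per output channel from the "Scales" input of
// fake_channel_wise_dequantize_max_abs and attaches the resulting vector as
// the "weight_scale" attribute of the fused op.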
void ChannelWiseDequantOpFuser::BuildPattern() {
  std::string dequant_op_type = "fake_channel_wise_dequantize_max_abs";
  auto* quantized_op_input = VarNode("quantized_op_input")
                                 ->assert_is_op_input(quantized_op_type_)
                                 ->AsInput();
  auto* quantized_op_weight =
      VarNode("quantized_op_weight")
          ->assert_is_op_input(quantized_op_type_, "Filter")
          ->AsInput();
  auto* quantized_op = OpNode("quantized_op", quantized_op_type_)
                           ->assert_is_op(quantized_op_type_)
                           ->AsIntermediate();
  auto* quantized_op_out = VarNode("quantized_op_out")
                               ->assert_is_op_output(quantized_op_type_)
                               ->assert_is_op_input(dequant_op_type, "X")
                               ->AsIntermediate();
  auto* dequant_op_channel_scale = VarNode("dequant_op_channel_scale")
                                       ->assert_is_op_input(dequant_op_type)
                                       ->AsIntermediate();
  auto* dequant_op = OpNode("dequant_op", dequant_op_type)
                         ->assert_is_op(dequant_op_type)
                         ->AsIntermediate();
  auto* dequant_op_out = VarNode("dequant_op_out")
                             ->assert_is_op_output(dequant_op_type, "Out")
                             ->AsOutput();

  quantized_op->LinksFrom({quantized_op_input, quantized_op_weight});
  quantized_op_out->LinksFrom({quantized_op});
  dequant_op->LinksFrom({quantized_op_out, dequant_op_channel_scale});
  dequant_op_out->LinksFrom({dequant_op});

  VLOG(4) << "ChannelWiseDequantOpFuser BuildPattern op_type:"
          << quantized_op_type_;
}
void ChannelWiseDequantOpFuser::InsertNewNode(SSAGraph* graph,
                                              const key2nodes_t& matched) {
  auto* quantized_op_input = matched.at("quantized_op_input");
  auto* quantized_op_weight = matched.at("quantized_op_weight");
  auto* quantized_op = matched.at("quantized_op");
  auto* dequant_op_channel_scale = matched.at("dequant_op_channel_scale");
  auto* dequant_op = matched.at("dequant_op");
  auto* dequant_op_out = matched.at("dequant_op_out");

  // obtain input_scale and weight_scale
  auto* scope = quantized_op->stmt()->op()->scope();
  auto& valid_places = quantized_op->stmt()->op()->valid_places();
  float input_scale =
      quantized_op->stmt()->op_info()->GetAttr<float>("input_scale");
  std::vector<float> weight_scale;
  std::vector<int> quant_bits =
      dequant_op->stmt()->op_info()->GetAttr<std::vector<int>>("quant_bits");
  int weight_bit_length = quant_bits[0];
  // For n-bit weights the quantized range is [-(2^(n-1) - 1), 2^(n-1) - 1],
  // e.g. 127 for 8 bits. The stored channel scales are max-abs values, so
  // each per-channel weight_scale is max_abs / range.
  int range = ((1 << (weight_bit_length - 1)) - 1);
  auto channel_scale_name = dequant_op_channel_scale->arg()->name;
  auto channel_scale_tensor =
      scope->FindVar(channel_scale_name)->GetMutable<lite::Tensor>();
  auto* channel_scale_data = channel_scale_tensor->data<float>();
  for (size_t i = 0; i < channel_scale_tensor->data_size(); i++) {
    weight_scale.push_back(channel_scale_data[i] / range);
  }

  // set op desc
  cpp::OpDesc op_desc = *quantized_op->stmt()->op_info();
  op_desc.SetInput("Input", {quantized_op_input->arg()->name});
  op_desc.SetOutput("Output", {dequant_op_out->arg()->name});
  op_desc.SetAttr("enable_int8", true);
  op_desc.SetAttr("input_scale", input_scale);
  op_desc.SetAttr("weight_scale", weight_scale);

  // change the weight from the float type to int8 type
  auto quantized_weight_var_name = quantized_op_weight->arg()->name;
  auto quantized_weight_t =
      scope->FindVar(quantized_weight_var_name)->GetMutable<lite::Tensor>();
  Tensor temp_tensor;
  temp_tensor.CopyDataFrom(*quantized_weight_t);
  float* temp_data = temp_tensor.mutable_data<float>();
  int8_t* quantized_weight_data = quantized_weight_t->mutable_data<int8_t>();
  // The weight tensor holds integer values stored as floats, so a plain
  // cast narrows them to int8 without changing the values.
  for (size_t i = 0; i < quantized_weight_t->data_size(); i++) {
    quantized_weight_data[i] = static_cast<int8_t>(temp_data[i]);
  }
  quantized_weight_t->set_persistable(true);
  quantized_weight_t->set_precision(PRECISION(kInt8));

  // new op and relink nodes
  auto new_quantized_op = LiteOpRegistry::Global().Create(quantized_op_type_);
  new_quantized_op->Attach(op_desc, scope);
  auto* new_quantized_op_node =
      graph->GraphCreateInstructNode(new_quantized_op, valid_places);
  IR_NODE_LINK_TO(quantized_op_input, new_quantized_op_node);
  IR_NODE_LINK_TO(quantized_op_weight, new_quantized_op_node);
  IR_NODE_LINK_TO(new_quantized_op_node, dequant_op_out);
}
cpp::OpDesc ChannelWiseDequantOpFuser::GenOpDesc(const key2nodes_t& matched) {
  cpp::OpDesc op_desc;
  return op_desc;
}
void DeleteQuantDequantOpFuser::BuildPattern() {
  std::string quant_dequant_op_type =
      "fake_quantize_dequantize_moving_average_abs_max";
......
@@ -24,18 +24,21 @@ namespace mir {
namespace fusion {
/* The model trained by fluid quantization is a simulation of real int8.
 * The quantized ops (conv2d, mul, depthwise_conv2d, etc.) have a fake_quant
 * op in front and a fake_dequant op behind.
 *
 * For int8 mode, the pattern "fake_quant + quantized_op + fake_dequant"
 * can be processed by the following three fusers. The fusers extract the
 * input_scale and weight_scale info from the fake_quant and fake_dequant
 * ops and fuse them into the quantized_op.
 * Finally, the fusers delete the fake_quant and fake_dequant ops from the
 * graph.
 */
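/* A sketch of the rewrite (illustrative only, not taken verbatim from the
 * pass), using conv2d with channel-wise dequantization as the example:
 *
 *   before:  input -> fake_quantize_moving_average_abs_max -> conv2d
 *                  -> fake_channel_wise_dequantize_max_abs -> output
 *
 *   after:   input -> conv2d -> output
 *            (conv2d now carries the enable_int8, input_scale and
 *             per-channel weight_scale attributes, and its Filter has been
 *             converted to int8)
 */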
/* DeleteQuantOpFuser processes the pattern
 * fake_quantize_range_abs_max/fake_quantize_moving_average_abs_max
 * + conv2d/mul/depthwise_conv2d.
 */
class DeleteQuantOpFuser : public FuseBase {
 public:
  explicit DeleteQuantOpFuser(const std::string& quant_op_type)
@@ -50,9 +53,12 @@ class DeleteQuantOpFuser : public FuseBase {
  std::string quant_op_type_{};
};
/* DequantOpFuser processes conv2d/depthwise_conv2d/mul +
 * fake_dequantize_max_abs.
 */
class DequantOpFuser : public FuseBase {
 public:
  explicit DequantOpFuser(const std::string& quantized_op_type)
      : quantized_op_type_(quantized_op_type) {}

  void BuildPattern() override;
  void InsertNewNode(SSAGraph* graph, const key2nodes_t& matched) override;
@@ -60,7 +66,24 @@ class DequantOpFuser : public FuseBase {
  cpp::OpDesc GenOpDesc(const key2nodes_t& matched) override;

 private:
  std::string quantized_op_type_{};
};
/* ChannelWiseDequantOpFuser processes conv2d/depthwise_conv2d +
 * fake_channel_wise_dequantize_max_abs.
 */
class ChannelWiseDequantOpFuser : public FuseBase {
 public:
  explicit ChannelWiseDequantOpFuser(const std::string& quantized_op_type)
      : quantized_op_type_(quantized_op_type) {}

  void BuildPattern() override;
  void InsertNewNode(SSAGraph* graph, const key2nodes_t& matched) override;

 private:
  cpp::OpDesc GenOpDesc(const key2nodes_t& matched) override;

 private:
  std::string quantized_op_type_{};
};
/* The pattern like "fake_quantize_dequantize_moving_average_abs_max +
......
@@ -76,6 +76,7 @@ add_operator(sequence_expand_as_op_lite extra SRCS sequence_expand_as_op.cc DEPS
add_operator(range_op extra SRCS range_op.cc DEPS ${op_DEPS})
add_operator(assign_value_op extra SRCS assign_value_op.cc DEPS ${op_DEPS})
add_operator(fake_quantize_dequantize_moving_avg_abs_max_op extra SRCS fake_quantize_dequantize_moving_avg_max_abs.cc DEPS ${op_DEPS})
add_operator(fake_channel_wise_dequantize_max_abs_op extra SRCS fake_channel_wise_dequantize_max_abs.cc DEPS ${op_DEPS})
add_operator(sequence_reshape_op_lite extra SRCS sequence_reshape_op.cc DEPS ${op_DEPS})
add_operator(reduce_sum_op_lite extra SRCS reduce_ops.cc DEPS ${op_DEPS})
......
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/operators/fake_channel_wise_dequantize_max_abs.h"
#include "lite/core/op_registry.h"
namespace paddle {
namespace lite {
namespace operators {} // namespace operators
} // namespace lite
} // namespace paddle
REGISTER_LITE_OP(
    fake_channel_wise_dequantize_max_abs,
    paddle::lite::operators::FakeChannelWiseDequantizeMaxAbsOpLite);
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <string>
#include <vector>
#include "lite/core/kernel.h"
#include "lite/core/op_lite.h"
#include "lite/core/scope.h"
#include "lite/core/tensor.h"
#include "lite/operators/op_params.h"
#include "lite/utils/all.h"
namespace paddle {
namespace lite {
namespace operators {
class FakeChannelWiseDequantizeMaxAbsOpLite : public OpLite {
 public:
  FakeChannelWiseDequantizeMaxAbsOpLite() {}

  explicit FakeChannelWiseDequantizeMaxAbsOpLite(const std::string &type)
      : OpLite(type) {}

  bool CheckShape() const override { return true; }

  bool InferShape() const override { return true; }

  bool AttachImpl(const cpp::OpDesc &op_desc, lite::Scope *scope) override {
    auto x = op_desc.Input("X").front();
    param_.x = scope->FindVar(x)->GetMutable<lite::Tensor>();

    auto args = op_desc.Input("Scales");
    for (auto arg : args) {
      auto *var = scope->FindVar(arg);
      if (var != nullptr) {
        param_.scale_tensors.push_back(var->GetMutable<lite::Tensor>());
      }
    }

    auto out = op_desc.Output("Out").front();
    param_.out = scope->FindVar(out)->GetMutable<lite::Tensor>();

    param_.quant_bits = op_desc.GetAttr<std::vector<int>>("quant_bits");
    return true;
  }

  void AttachKernel(KernelBase *kernel) override { kernel->SetParam(param_); }

  std::string DebugString() const override {
    return "fake_channel_wise_dequantize_max_abs";
  }

 private:
  mutable FakeChannelWiseDequantizeMaxAbsParam param_;
};
} // namespace operators
} // namespace lite
} // namespace paddle
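The op above only plumbs parameters through; no kernel is part of this diff. As a rough reference, the channel-wise dequantization it names could be computed as in the minimal sketch below, assuming a single scale tensor holding one max-abs value per output channel; the function name and data layout are illustrative, not from this commit.

// Dequantize an int8 tensor of shape [cout, inner_size] channel by channel.
#include <cstdint>
#include <vector>

std::vector<float> ChannelWiseDequantMaxAbs(const int8_t* x,
                                            const float* channel_scales,
                                            int cout,
                                            int inner_size,
                                            int quant_bits) {
  // For n-bit quantization the representable range is
  // [-(2^(n-1) - 1), 2^(n-1) - 1], e.g. 127 for 8 bits.
  const float range = static_cast<float>((1 << (quant_bits - 1)) - 1);
  std::vector<float> out(static_cast<size_t>(cout) * inner_size);
  for (int c = 0; c < cout; ++c) {
    // Each channel is scaled back by its own max-abs value over the range.
    const float scale = channel_scales[c] / range;
    for (int i = 0; i < inner_size; ++i) {
      out[static_cast<size_t>(c) * inner_size + i] =
          x[c * inner_size + i] * scale;
    }
  }
  return out;
}

This mirrors the arithmetic ChannelWiseDequantOpFuser bakes into the weight_scale attribute: each stored channel scale is a max-abs value, divided by 127 for 8-bit weights.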
@@ -427,6 +427,13 @@ struct FakeDequantizeMaxAbsParam {
  float max_range;
};

struct FakeChannelWiseDequantizeMaxAbsParam {
  const lite::Tensor* x{};
  std::vector<const lite::Tensor*> scale_tensors{};
  lite::Tensor* out{};
  std::vector<int> quant_bits;
};

/// ----------------------- sgd operators ----------------------
struct SGDParam {
  int dtype{static_cast<int>(VarDescAPI::VarDataType::FP32)};
......