未验证 提交 ba85799c 编写于 作者: J juncaipeng 提交者: GitHub

add channel_wise_dequantized_max_abs op and ChannelWiseDequantOpFuser (#2368)

* add channel_wise_dequantized_max_abs op and ChannelWiseDequantOpFuser, test=develop
上级 f077276f
......@@ -34,13 +34,16 @@ void QuantDequantFusePass::Apply(const std::unique_ptr<SSAGraph>& graph) {
}
// fuse quantized node and dequant node
std::vector<std::string> quantized_op_types = {
"conv2d", "mul", "depthwise_conv2d"};
for (auto& op_type : quantized_op_types) {
for (auto& op_type : {"conv2d", "mul", "depthwise_conv2d"}) {
fusion::DequantOpFuser fuser(op_type);
fuser(graph.get());
}
for (auto& op_type : {"conv2d", "depthwise_conv2d"}) {
fusion::ChannelWiseDequantOpFuser fuser(op_type);
fuser(graph.get());
}
// delete quant_dequant_node
for (auto op_type : {"pool2d", "elementwise_add"}) {
fusion::DeleteQuantDequantOpFuser fuser(op_type);
......
......@@ -79,23 +79,26 @@ cpp::OpDesc DeleteQuantOpFuser::GenOpDesc(const key2nodes_t& matched) {
void DequantOpFuser::BuildPattern() {
std::string weight_name = "";
if (op_type_ == "conv2d" || op_type_ == "depthwise_conv2d") {
if (quantized_op_type_ == "conv2d" ||
quantized_op_type_ == "depthwise_conv2d") {
weight_name = "Filter";
} else {
weight_name = "Y";
}
auto* quantized_op_input =
VarNode("quantized_op_input")->assert_is_op_input(op_type_)->AsInput();
auto* quantized_op_weight = VarNode("quantized_op_weight")
->assert_is_op_input(op_type_, weight_name)
auto* quantized_op_input = VarNode("quantized_op_input")
->assert_is_op_input(quantized_op_type_)
->AsInput();
auto* quantized_op = OpNode("quantized_op", op_type_)
->assert_is_op(op_type_)
auto* quantized_op_weight =
VarNode("quantized_op_weight")
->assert_is_op_input(quantized_op_type_, weight_name)
->AsInput();
auto* quantized_op = OpNode("quantized_op", quantized_op_type_)
->assert_is_op(quantized_op_type_)
->AsIntermediate();
auto* quantized_op_out =
VarNode("quantized_op_out")
->assert_is_op_output(op_type_)
->assert_is_op_output(quantized_op_type_)
->assert_is_op_input("fake_dequantize_max_abs", "X")
->AsIntermediate();
auto* dequant_op = OpNode("dequant_op", "fake_dequantize_max_abs")
......@@ -110,12 +113,13 @@ void DequantOpFuser::BuildPattern() {
quantized_op_out->LinksFrom({quantized_op});
dequant_op->LinksFrom({quantized_op_out});
dequant_op_out->LinksFrom({dequant_op});
VLOG(4) << "DeQuantOpFuser BuildPattern op_type:" << op_type_;
VLOG(4) << "DeQuantOpFuser BuildPattern op_type:" << quantized_op_type_;
}
void DequantOpFuser::InsertNewNode(SSAGraph* graph,
const key2nodes_t& matched) {
auto* quant_op_input = matched.at("quantized_op_input");
auto* quantized_op_input = matched.at("quantized_op_input");
auto* quantized_op_weight = matched.at("quantized_op_weight");
auto* quantized_op = matched.at("quantized_op");
auto* dequant_op = matched.at("dequant_op");
......@@ -142,14 +146,15 @@ void DequantOpFuser::InsertNewNode(SSAGraph* graph,
scope->FindVar(quantized_weight_var_name)->GetMutable<lite::Tensor>();
std::vector<float> weight_scale;
int weight_scale_size;
if (op_type_ == "conv2d" || op_type_ == "depthwise_conv2d") {
op_desc.SetInput("Input", {quant_op_input->arg()->name});
if (quantized_op_type_ == "conv2d" ||
quantized_op_type_ == "depthwise_conv2d") {
op_desc.SetInput("Input", {quantized_op_input->arg()->name});
op_desc.SetOutput("Output", {dequant_op_out->arg()->name});
// Conv weight shape: Cout * Cin * kh * hw, the weight_scale_size should
// be Cout.
weight_scale_size = quantized_weight_t->dims()[0];
} else if (op_type_ == "mul") {
op_desc.SetInput("X", {quant_op_input->arg()->name});
} else if (quantized_op_type_ == "mul") {
op_desc.SetInput("X", {quantized_op_input->arg()->name});
op_desc.SetOutput("Out", {dequant_op_out->arg()->name});
// Fc weight: Cin * Cout, the weight_scale_size should be Cout.
weight_scale_size = quantized_weight_t->dims()[1];
......@@ -174,11 +179,11 @@ void DequantOpFuser::InsertNewNode(SSAGraph* graph,
quantized_weight_t->set_precision(PRECISION(kInt8));
// new op and relink nodes
auto new_quantized_op = LiteOpRegistry::Global().Create(op_type_);
auto new_quantized_op = LiteOpRegistry::Global().Create(quantized_op_type_);
new_quantized_op->Attach(op_desc, scope);
auto* new_quantized_op_node =
graph->GraphCreateInstructNode(new_quantized_op, valid_places);
IR_NODE_LINK_TO(quant_op_input, new_quantized_op_node);
IR_NODE_LINK_TO(quantized_op_input, new_quantized_op_node);
IR_NODE_LINK_TO(quantized_op_weight, new_quantized_op_node);
IR_NODE_LINK_TO(new_quantized_op_node, dequant_op_out);
}
......@@ -188,6 +193,107 @@ cpp::OpDesc DequantOpFuser::GenOpDesc(const key2nodes_t& matched) {
return op_desc;
}
void ChannelWiseDequantOpFuser::BuildPattern() {
std::string dequant_op_type = "fake_channel_wise_dequantize_max_abs";
auto* quantized_op_input = VarNode("quantized_op_input")
->assert_is_op_input(quantized_op_type_)
->AsInput();
auto* quantized_op_weight =
VarNode("quantized_op_weight")
->assert_is_op_input(quantized_op_type_, "Filter")
->AsInput();
auto* quantized_op = OpNode("quantized_op", quantized_op_type_)
->assert_is_op(quantized_op_type_)
->AsIntermediate();
auto* quantized_op_out = VarNode("quantized_op_out")
->assert_is_op_output(quantized_op_type_)
->assert_is_op_input(dequant_op_type, "X")
->AsIntermediate();
auto* dequant_op_channel_scale = VarNode("dequant_op_channel_scale")
->assert_is_op_input(dequant_op_type)
->AsIntermediate();
auto* dequant_op = OpNode("dequant_op", dequant_op_type)
->assert_is_op(dequant_op_type)
->AsIntermediate();
auto* dequant_op_out = VarNode("dequant_op_out")
->assert_is_op_output(dequant_op_type, "Out")
->AsOutput();
quantized_op->LinksFrom({quantized_op_input, quantized_op_weight});
quantized_op_out->LinksFrom({quantized_op});
dequant_op->LinksFrom({quantized_op_out, dequant_op_channel_scale});
dequant_op_out->LinksFrom({dequant_op});
VLOG(4) << "ChannelWiseDequantOpFuser BuildPattern op_type:"
<< quantized_op_type_;
}
void ChannelWiseDequantOpFuser::InsertNewNode(SSAGraph* graph,
const key2nodes_t& matched) {
auto* quantized_op_input = matched.at("quantized_op_input");
auto* quantized_op_weight = matched.at("quantized_op_weight");
auto* quantized_op = matched.at("quantized_op");
auto* dequant_op_channel_scale = matched.at("dequant_op_channel_scale");
auto* dequant_op = matched.at("dequant_op");
auto* dequant_op_out = matched.at("dequant_op_out");
// obtain input_scale and weight_scale
auto* scope = quantized_op->stmt()->op()->scope();
auto& valid_places = quantized_op->stmt()->op()->valid_places();
float input_scale =
quantized_op->stmt()->op_info()->GetAttr<float>("input_scale");
std::vector<float> weight_scale;
std::vector<int> quant_bits =
dequant_op->stmt()->op_info()->GetAttr<std::vector<int>>("quant_bits");
int weight_bit_length = quant_bits[0];
int range = ((1 << (weight_bit_length - 1)) - 1);
auto channel_scale_name = dequant_op_channel_scale->arg()->name;
auto channel_scale_tensor =
scope->FindVar(channel_scale_name)->GetMutable<lite::Tensor>();
auto* channel_scale_data = channel_scale_tensor->data<float>();
for (int i = 0; i < channel_scale_tensor->data_size(); i++) {
weight_scale.push_back(channel_scale_data[i] / range);
}
// set op desc
cpp::OpDesc op_desc = *quantized_op->stmt()->op_info();
op_desc.SetInput("Input", {quantized_op_input->arg()->name});
op_desc.SetOutput("Output", {dequant_op_out->arg()->name});
op_desc.SetAttr("enable_int8", true);
op_desc.SetAttr("input_scale", input_scale);
op_desc.SetAttr("weight_scale", weight_scale);
// change the weight from the float type to int8 type.
auto quantized_weight_var_name = quantized_op_weight->arg()->name;
auto quantized_weight_t =
scope->FindVar(quantized_weight_var_name)->GetMutable<lite::Tensor>();
Tensor temp_tensor;
temp_tensor.CopyDataFrom(*quantized_weight_t);
float* temp_data = temp_tensor.mutable_data<float>();
int8_t* quantized_weight_data = quantized_weight_t->mutable_data<int8_t>();
for (size_t i = 0; i < quantized_weight_t->data_size(); i++) {
quantized_weight_data[i] = static_cast<int8_t>(temp_data[i]);
}
quantized_weight_t->set_persistable(true);
quantized_weight_t->set_precision(PRECISION(kInt8));
// new op and relink nodes
auto new_quantized_op = LiteOpRegistry::Global().Create(quantized_op_type_);
new_quantized_op->Attach(op_desc, scope);
auto* new_quantized_op_node =
graph->GraphCreateInstructNode(new_quantized_op, valid_places);
IR_NODE_LINK_TO(quantized_op_input, new_quantized_op_node);
IR_NODE_LINK_TO(quantized_op_weight, new_quantized_op_node);
IR_NODE_LINK_TO(new_quantized_op_node, dequant_op_out);
}
cpp::OpDesc ChannelWiseDequantOpFuser::GenOpDesc(const key2nodes_t& matched) {
cpp::OpDesc op_desc;
return op_desc;
}
void DeleteQuantDequantOpFuser::BuildPattern() {
std::string quant_dequant_op_type =
"fake_quantize_dequantize_moving_average_abs_max";
......
......@@ -24,18 +24,21 @@ namespace mir {
namespace fusion {
/* The model trained by fluid quantization is a simulation of real int8.
* The quantized Ops(conv2d, mul, depthwise conv2d etc) have fake_quantop
* in front and fake_dequantop behind.
* The quantized Ops(conv2d, mul, depthwise conv2d etc) have fake_quant op
* in front and fake_dequant op behind.
*
* When in int8 mode, the pattern like "fake_quant + quantized_op +
* fake_dequant"
* can be detected by this fuser. The fuser extract the input_scale and
* the weight_scale info from fake_quant, fake_dequant op and fuse those into
* the quantized_op.
* For int8 mode, the pattern like "fake_quant + quantized_op + fake_dequant"
* can be processed by the following three fuser. The fuser extract the
* input_scale and the weight_scale info from fake_quant, fake_dequant op and
* fuse those into the quantized_op.
* In addition, the fuser delete fake_quant and fake_dequant op in the graph at
* the last.
*/
/* DeleteQuantOpFuser process
* fake_quantize_range_abs_max/fake_quantize_moving_average_abs_max
* + conv2d/mul/depthwise.
*/
class DeleteQuantOpFuser : public FuseBase {
public:
explicit DeleteQuantOpFuser(const std::string& quant_op_type)
......@@ -50,9 +53,12 @@ class DeleteQuantOpFuser : public FuseBase {
std::string quant_op_type_{};
};
/* DequantOpFuser process conv2d/depthwise_conv2d/mul + fake_dequantize_max_abs.
*/
class DequantOpFuser : public FuseBase {
public:
explicit DequantOpFuser(const std::string& op_type) : op_type_(op_type) {}
explicit DequantOpFuser(const std::string& quantized_op_type)
: quantized_op_type_(quantized_op_type) {}
void BuildPattern() override;
void InsertNewNode(SSAGraph* graph, const key2nodes_t& matched) override;
......@@ -60,7 +66,24 @@ class DequantOpFuser : public FuseBase {
cpp::OpDesc GenOpDesc(const key2nodes_t& matched) override;
private:
std::string op_type_{};
std::string quantized_op_type_{};
};
/* ChannelWiseDequantOpFuser process conv2d/depthwise_conv2d +
* fake_channel_wise_dequantize_max_abs.
*/
class ChannelWiseDequantOpFuser : public FuseBase {
public:
explicit ChannelWiseDequantOpFuser(const std::string& quantized_op_type)
: quantized_op_type_(quantized_op_type) {}
void BuildPattern() override;
void InsertNewNode(SSAGraph* graph, const key2nodes_t& matched) override;
private:
cpp::OpDesc GenOpDesc(const key2nodes_t& matched) override;
private:
std::string quantized_op_type_{};
};
/* The pattern like "fake_quantize_dequantize_moving_average_abs_max +
......
......@@ -76,6 +76,7 @@ add_operator(sequence_expand_as_op_lite extra SRCS sequence_expand_as_op.cc DEPS
add_operator(range_op extra SRCS range_op.cc DEPS ${op_DEPS})
add_operator(assign_value_op extra SRCS assign_value_op.cc DEPS ${op_DEPS})
add_operator(fake_quantize_dequantize_moving_avg_abs_max_op extra SRCS fake_quantize_dequantize_moving_avg_max_abs.cc DEPS ${op_DEPS})
add_operator(fake_channel_wise_dequantize_max_abs_op extra SRCS fake_channel_wise_dequantize_max_abs.cc DEPS ${op_DEPS})
add_operator(sequence_reshape_op_lite extra SRCS sequence_reshape_op.cc DEPS ${op_DEPS})
add_operator(reduce_sum_op_lite extra SRCS reduce_ops.cc DEPS ${op_DEPS})
......
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/operators/fake_channel_wise_dequantize_max_abs.h"
#include "lite/core/op_registry.h"
namespace paddle {
namespace lite {
namespace operators {} // namespace operators
} // namespace lite
} // namespace paddle
REGISTER_LITE_OP(
fake_channel_wise_dequantize_max_abs,
paddle::lite::operators::FakeChannelWiseDequantizeMaxAbsOpLite);
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <string>
#include <vector>
#include "lite/core/kernel.h"
#include "lite/core/op_lite.h"
#include "lite/core/scope.h"
#include "lite/core/tensor.h"
#include "lite/operators/op_params.h"
#include "lite/utils/all.h"
namespace paddle {
namespace lite {
namespace operators {
class FakeChannelWiseDequantizeMaxAbsOpLite : public OpLite {
public:
FakeChannelWiseDequantizeMaxAbsOpLite() {}
explicit FakeChannelWiseDequantizeMaxAbsOpLite(const std::string &type)
: OpLite(type) {}
bool CheckShape() const override { return true; }
bool InferShape() const override { return true; }
bool AttachImpl(const cpp::OpDesc &op_desc, lite::Scope *scope) override {
auto x = op_desc.Input("X").front();
param_.x = scope->FindVar(x)->GetMutable<lite::Tensor>();
auto args = op_desc.Input("Scales");
for (auto arg : args) {
auto *var = scope->FindVar(arg);
if (var != nullptr) {
param_.scale_tensors.push_back(var->GetMutable<lite::Tensor>());
}
}
auto out = op_desc.Output("Out").front();
param_.out = scope->FindVar(out)->GetMutable<lite::Tensor>();
param_.quant_bits = op_desc.GetAttr<std::vector<int>>("quant_bits");
return true;
}
void AttachKernel(KernelBase *kernel) override { kernel->SetParam(param_); }
std::string DebugString() const override {
return "fake_channel_wise_dequantize_max_abs";
}
private:
mutable FakeChannelWiseDequantizeMaxAbsParam param_;
};
} // namespace operators
} // namespace lite
} // namespace paddle
......@@ -427,6 +427,13 @@ struct FakeDequantizeMaxAbsParam {
float max_range;
};
struct FakeChannelWiseDequantizeMaxAbsParam {
const lite::Tensor* x{};
std::vector<const lite::Tensor*> scale_tensors{};
lite::Tensor* out{};
std::vector<int> quant_bits;
};
/// ----------------------- sgd operators ----------------------
struct SGDParam {
int dtype{static_cast<int>(VarDescAPI::VarDataType::FP32)};
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册