未验证 提交 e45e6dc6 编写于 作者: C cc 提交者: GitHub

Support dygraph quantized model, test=develop (#3974)

上级 769ba40b
......@@ -34,20 +34,25 @@ void QuantDequantFusePass::Apply(const std::unique_ptr<SSAGraph>& graph) {
}
// fuse quantized node and dequant node
for (auto& op_type :
{"conv2d", "mul", "depthwise_conv2d", "conv2d_transpose"}) {
std::vector<std::string> quantized_op_types = {
"conv2d", "depthwise_conv2d", "conv2d_transpose", "mul"};
for (auto& op_type : quantized_op_types) {
fusion::DequantOpFuser fuser(op_type);
fuser(graph.get());
}
for (auto& op_type : {"conv2d", "depthwise_conv2d", "conv2d_transpose"}) {
for (auto& op_type : quantized_op_types) {
fusion::ChannelWiseDequantOpFuser fuser(op_type);
fuser(graph.get());
}
// process quant_dequant_node
fusion::DeleteQuantDequantOpFuser dqd_fuser;
dqd_fuser(graph.get());
std::vector<std::string> quant_dequant_op_types = {
"fake_quantize_dequantize_abs_max",
"fake_quantize_dequantize_moving_average_abs_max"};
for (auto& op_type : quant_dequant_op_types) {
fusion::DeleteQuantDequantOpFuser dqd_fuser(op_type);
dqd_fuser(graph.get());
}
}
} // namespace mir
......
......@@ -175,11 +175,7 @@ void DequantOpFuser::InsertNewNode(SSAGraph* graph,
for (int i = 0; i < weight_scale_size; i++) {
weight_scale.push_back(whole_weight_scale);
}
// Arm CPU does not support conv2d_transpose
if (quantized_op_type_ != "conv2d_transpose") {
op_desc.SetAttr("enable_int8", true);
}
op_desc.SetAttr("enable_int8", true);
op_desc.SetInputScale(weight_name, weight_scale);
// change the weight from the float type to int8 type.
......@@ -284,7 +280,6 @@ void ChannelWiseDequantOpFuser::InsertNewNode(SSAGraph* graph,
op_desc.SetInput("X", {quantized_op_input->arg()->name});
op_desc.SetOutput("Out", {dequant_op_out->arg()->name});
}
// Arm CPU does not support conv2d_transpose
if (quantized_op_type_ != "conv2d_transpose") {
op_desc.SetAttr("enable_int8", true);
}
......@@ -320,30 +315,33 @@ cpp::OpDesc ChannelWiseDequantOpFuser::GenOpDesc(const key2nodes_t& matched) {
}
void DeleteQuantDequantOpFuser::BuildPattern() {
std::string quant_dequant_op_type =
"fake_quantize_dequantize_moving_average_abs_max";
auto* input_scale_node =
VarNode("input_scale_node")
->assert_is_op_input(quant_dequant_op_type, "InScale");
auto* input_act_node =
VarNode("input_act_node")->assert_is_op_input(quant_dequant_op_type, "X");
auto* quant_dequant_node = OpNode("quant_dequant_node", quant_dequant_op_type)
->assert_is_op(quant_dequant_op_type);
auto* input_act_node = VarNode("input_act_node")
->assert_is_op_input(quant_dequant_op_type_, "X");
auto* quant_dequant_node =
OpNode("quant_dequant_node", quant_dequant_op_type_)
->assert_is_op(quant_dequant_op_type_);
auto* output_scale_node =
VarNode("output_scale_node")
->assert_is_op_output(quant_dequant_op_type, "OutScale");
->assert_is_op_output(quant_dequant_op_type_, "OutScale");
auto* output_act_node =
VarNode("output_act_node")
->assert_is_op_output(quant_dequant_op_type, "Out");
quant_dequant_node->LinksFrom({input_scale_node, input_act_node});
->assert_is_op_output(quant_dequant_op_type_, "Out");
if (quant_dequant_op_type_ ==
"fake_quantize_dequantize_moving_average_abs_max") {
auto* input_scale_node =
VarNode("input_scale_node")
->assert_is_op_input(quant_dequant_op_type_, "InScale");
quant_dequant_node->LinksFrom({input_scale_node, input_act_node});
} else {
quant_dequant_node->LinksFrom({input_act_node});
}
output_scale_node->LinksFrom({quant_dequant_node});
output_act_node->LinksFrom({quant_dequant_node});
}
void DeleteQuantDequantOpFuser::InsertNewNode(SSAGraph* graph,
const key2nodes_t& matched) {
auto* input_scale_node = matched.at("input_scale_node");
auto* input_act_node = matched.at("input_act_node");
auto* quant_dequant_node = matched.at("quant_dequant_node");
auto* output_scale_node = matched.at("output_scale_node");
......@@ -373,7 +371,12 @@ void DeleteQuantDequantOpFuser::InsertNewNode(SSAGraph* graph,
}
// delete nodes and edges
std::set<const Node*> nodes2rm = {
input_scale_node, quant_dequant_node, output_scale_node, output_act_node};
quant_dequant_node, output_scale_node, output_act_node};
if (quant_dequant_op_type_ ==
"fake_quantize_dequantize_moving_average_abs_max") {
auto* input_scale_node = matched.at("input_scale_node");
nodes2rm.insert(input_scale_node);
}
GraphSafeRemoveNodes(graph, nodes2rm);
}
......
......@@ -86,17 +86,22 @@ class ChannelWiseDequantOpFuser : public FuseBase {
std::string quantized_op_type_{};
};
/* The pattern like "fake_quantize_dequantize_moving_average_abs_max +
 * quantized_op" can be detected by this fuser. The fuser modifies the input
 * scale for the quantized_op and deletes the fake_quant_dequant_op.
/* The pattern like "fake_quantize_dequantize_op + quantized_op" can be
 * detected by this fuser. The fuser modifies the input scale for the
 * quantized_op and deletes the fake_quant_dequant_op.
*/
// Fuser that removes a fake quant-dequant op from the graph. It matches the
// pattern built in BuildPattern() for the op type given at construction time
// and deletes the matched nodes in InsertNewNode().
class DeleteQuantDequantOpFuser : public FuseBase {
 public:
  // quant_dequant_op_type: the fake quant-dequant op to match, e.g.
  // "fake_quantize_dequantize_abs_max" or
  // "fake_quantize_dequantize_moving_average_abs_max".
  explicit DeleteQuantDequantOpFuser(const std::string& quant_dequant_op_type)
      : quant_dequant_op_type_(quant_dequant_op_type) {}
  void BuildPattern() override;
  void InsertNewNode(SSAGraph* graph, const key2nodes_t& matched) override;

 private:
  cpp::OpDesc GenOpDesc(const key2nodes_t& matched) override;

 private:
  // Op type this fuser instance matches; set once in the constructor.
  std::string quant_dequant_op_type_{};
};
} // namespace fusion
......
......@@ -89,6 +89,7 @@ add_operator(fake_quantize_range_abs_max_op extra SRCS fake_quantize_range_abs_m
add_operator(sequence_expand_as_op_lite extra SRCS sequence_expand_as_op.cc DEPS ${op_DEPS})
add_operator(assign_value_op basic SRCS assign_value_op.cc DEPS ${op_DEPS})
add_operator(fake_quantize_dequantize_moving_avg_abs_max_op extra SRCS fake_quantize_dequantize_moving_avg_max_abs.cc DEPS ${op_DEPS})
add_operator(fake_quantize_dequantize_abs_max_op extra SRCS fake_quantize_dequantize_abs_max.cc DEPS ${op_DEPS})
add_operator(fake_channel_wise_dequantize_max_abs_op extra SRCS fake_channel_wise_dequantize_max_abs.cc DEPS ${op_DEPS})
add_operator(split_lod_tensor_op_lite extra SRCS split_lod_tensor_op.cc DEPS ${op_DEPS})
add_operator(merge_lod_tensor_op_lite extra SRCS merge_lod_tensor_op.cc DEPS ${op_DEPS})
......
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/operators/fake_quantize_dequantize_abs_max.h"
#include "lite/core/op_registry.h"
namespace paddle {
namespace lite {
// Intentionally empty: FakeQuantizeDequantizeAbsMaxOpLite is implemented
// entirely in its header; this TU only exists to register the op.
namespace operators {} // namespace operators
} // namespace lite
} // namespace paddle

// Make the op available to the framework under its Paddle op-type string.
REGISTER_LITE_OP(fake_quantize_dequantize_abs_max,
paddle::lite::operators::FakeQuantizeDequantizeAbsMaxOpLite);
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <string>
#include <vector>
#include "lite/core/kernel.h"
#include "lite/core/op_lite.h"
#include "lite/core/scope.h"
#include "lite/core/tensor.h"
#include "lite/operators/op_params.h"
#include "lite/utils/all.h"
namespace paddle {
namespace lite {
namespace operators {
// Lite wrapper for the fake_quantize_dequantize_abs_max op. The op is a
// training-graph artifact: shape checking and shape inference are no-ops,
// and the class only wires the variables named in the OpDesc ("X", "Out",
// "OutScale") plus the "bit_length" attribute into the kernel param.
class FakeQuantizeDequantizeAbsMaxOpLite : public OpLite {
 public:
  FakeQuantizeDequantizeAbsMaxOpLite() {}

  explicit FakeQuantizeDequantizeAbsMaxOpLite(const std::string &type)
      : OpLite(type) {}

  // Nothing to validate: the op passes its input through unchanged in shape.
  bool CheckShape() const override { return true; }

  bool InferShapeImpl() const override { return true; }

  // Resolve the op's variable names in `scope` and fill `param_`.
  // Assumes "X", "Out" and "OutScale" each carry at least one argument —
  // guaranteed by the exporter that produced the quantized model.
  bool AttachImpl(const cpp::OpDesc &op_desc, lite::Scope *scope) override {
    const auto input_name = op_desc.Input("X").front();
    const auto output_name = op_desc.Output("Out").front();
    const auto scale_name = op_desc.Output("OutScale").front();

    param_.x = scope->FindVar(input_name)->GetMutable<lite::Tensor>();
    param_.out = scope->FindVar(output_name)->GetMutable<lite::Tensor>();
    param_.out_scale = scope->FindVar(scale_name)->GetMutable<lite::Tensor>();
    param_.bit_length = op_desc.GetAttr<int>("bit_length");
    return true;
  }

  // Hand the collected parameters to the selected kernel.
  void AttachKernel(KernelBase *kernel) override { kernel->SetParam(param_); }

  std::string DebugString() const override {
    return "fake_quantize_dequantize_abs_max";
  }

 private:
  mutable FakeQuantDequantAbsMaxParam param_;
};
} // namespace operators
} // namespace lite
} // namespace paddle
......@@ -678,6 +678,13 @@ struct FakeChannelWiseDequantizeMaxAbsParam : ParamBase {
std::vector<int> quant_bits;
};
// Parameters for the fake_quantize_dequantize_abs_max op.
// Filled by FakeQuantizeDequantizeAbsMaxOpLite::AttachImpl from the op's
// "X" input, "Out"/"OutScale" outputs and "bit_length" attribute.
// All tensor pointers are non-owning (the variables live in the Scope).
struct FakeQuantDequantAbsMaxParam : ParamBase {
  const lite::Tensor* x{};        // input activation
  lite::Tensor* out{};            // quantize-then-dequantize result
  lite::Tensor* out_scale{};      // computed scale output
  // Quantization bit width from the op attr. Zero-initialized so a param
  // that is read before AttachImpl runs does not contain an indeterminate
  // value (the other members are already brace-initialized).
  int bit_length{0};
};
/// ----------------------- sgd operators ----------------------
struct SGDParam : ParamBase {
int dtype{static_cast<int>(VarDescAPI::VarDataType::FP32)};
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册