Unverified commit 23728301 authored by cc, committed by GitHub

Save all scales as vector, test=develop (#3867)

Parent bc66d2be
......@@ -161,8 +161,7 @@ void ConvBNFuser::InsertNewNode(SSAGraph* graph, const key2nodes_t& matched) {
<< "INT8 mode: Conv should has weight_scale attr";
auto conv_weight_d = conv_weight_t->mutable_data<int8_t>();
// compute new conv_weight for int8
auto weight_scale =
conv_op_desc->GetInputScale<std::vector<float>>(weight_name);
auto weight_scale = conv_op_desc->GetInputScale(weight_name);
if (conv_type_ == "conv2d_transpose" && !depthwise) {
int c_size = conv_weight_t->dims()[1] * conv_weight_t->dims()[2] *
conv_weight_t->dims()[3];
......
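With the templated overload gone, `GetInputScale(weight_name)` always returns a `std::vector<float>`; for conv weights that is one scale per output channel. Below is a minimal standalone sketch, not part of the patch, of how such a per-channel scale vector is typically derived under symmetric int8 quantization (the 127 divisor assumes an 8-bit signed range):

```cpp
#include <algorithm>
#include <cmath>
#include <iostream>
#include <vector>

// Compute one scale per output channel: scale[c] = max|w| over channel c / 127.
// 'weights' is laid out [channels][elements_per_channel], as a conv filter
// flattened per output channel would be.
std::vector<float> PerChannelScales(
    const std::vector<std::vector<float>>& weights) {
  std::vector<float> scales;
  scales.reserve(weights.size());
  for (const auto& channel : weights) {
    float max_abs = 0.f;
    for (float w : channel) max_abs = std::max(max_abs, std::fabs(w));
    scales.push_back(max_abs / 127.f);
  }
  return scales;
}

int main() {
  std::vector<std::vector<float>> filter = {{0.5f, -1.0f}, {0.25f, 0.1f}};
  // Prints 1/127 and 0.25/127: one entry per output channel, matching the
  // vector now returned by GetInputScale(weight_name).
  for (float s : PerChannelScales(filter)) std::cout << s << "\n";
}
```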
......@@ -74,22 +74,15 @@ cpp::OpDesc FcFuser::GenOpDesc(const key2nodes_t& matched) {
auto op_desc = *matched.at("mul")->stmt()->op_info();
// Get the input scale from mul
float x_scale{};
std::vector<float> x_scale_vct;
std::vector<float> y_scale_vct;
auto y_var_node = matched.at("W")->arg();
auto input_x_name = op_desc.Input("X").front();
auto input_y_name = op_desc.Input("Y").front();
bool is_quantized_op = op_desc.HasInputScale(input_x_name) &&
op_desc.HasInputScale(input_y_name);
if (is_quantized_op) {
x_scale = op_desc.GetInputScale<float>(input_x_name);
if (y_var_node->is_weight) { // the scale of y is a vector
y_scale_vct =
op_desc.GetInputScale<std::vector<float>>(op_desc.Input("Y").front());
} else {
y_scale_vct.push_back( // the scale of y is scalar
op_desc.GetInputScale<float>(op_desc.Input("Y").front()));
}
x_scale_vct = op_desc.GetInputScale(input_x_name);
y_scale_vct = op_desc.GetInputScale(op_desc.Input("Y").front());
}
op_desc.mutable_inputs()->clear();
......@@ -108,12 +101,8 @@ cpp::OpDesc FcFuser::GenOpDesc(const key2nodes_t& matched) {
// Set the input scale into fc
if (is_quantized_op) {
op_desc.SetInputScale(matched.at("x")->arg()->name, x_scale);
if (y_var_node->is_weight) {
op_desc.SetInputScale(matched.at("W")->arg()->name, y_scale_vct);
} else {
op_desc.SetInputScale(matched.at("W")->arg()->name, y_scale_vct.front());
}
op_desc.SetInputScale(matched.at("x")->arg()->name, x_scale_vct);
op_desc.SetInputScale(matched.at("W")->arg()->name, y_scale_vct);
}
return op_desc;
......
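Because activation and weight scales are now both stored as `std::vector<float>`, `GenOpDesc` no longer needs the `is_weight` branch: a scalar activation scale simply becomes a one-element vector, and a per-channel weight scale is copied through unchanged. A rough standalone sketch of that uniform handling follows; `ScaleStore` and `SetInputScale` here are stand-ins for the op-info attribute map, not real Paddle-Lite symbols:

```cpp
#include <iostream>
#include <map>
#include <string>
#include <vector>

// Stand-in for op-info scale attributes: every scale is a vector,
// regardless of whether it came from an activation (size 1) or a weight.
using ScaleStore = std::map<std::string, std::vector<float>>;

void SetInputScale(ScaleStore* store, const std::string& name,
                   const std::vector<float>& scale) {
  (*store)[name] = scale;
}

int main() {
  ScaleStore fused_fc;
  float x_scale = 0.05f;                               // activation: scalar
  std::vector<float> w_scale = {0.01f, 0.02f, 0.03f};  // weight: per channel

  SetInputScale(&fused_fc, "x", {x_scale});  // wrap the scalar once
  SetInputScale(&fused_fc, "W", w_scale);    // pass the vector through

  for (const auto& kv : fused_fc)
    std::cout << kv.first << " has " << kv.second.size() << " scale(s)\n";
}
```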
......@@ -64,7 +64,7 @@ void DeleteQuantOpFuser::InsertNewNode(SSAGraph* graph,
for (auto* quantized_node : outlinks) {
// save input scale in quantized op by input argname + index
auto op_desc = *quantized_node->stmt()->mutable_op_info();
op_desc.SetInputScale(out_act_name, scale_value);
op_desc.SetInputScale(out_act_name, {scale_value});
op_desc.SetAttr<int>("bit_length", bit_length);
op_desc.UpdateAllInputs(out_act_name, in_act_name);
quantized_node->stmt()->ResetOp(op_desc, graph->valid_places());
......@@ -150,7 +150,7 @@ void DequantOpFuser::InsertNewNode(SSAGraph* graph,
auto quantized_weight_t =
scope->FindVar(quantized_weight_var_name)->GetMutable<lite::Tensor>();
std::vector<float> weight_scale;
int weight_scale_size;
int weight_scale_size = 0;
if (quantized_op_type_ == "conv2d" ||
quantized_op_type_ == "depthwise_conv2d") {
op_desc.SetInput("Input", {quantized_op_input->arg()->name});
......@@ -348,7 +348,7 @@ void DeleteQuantDequantOpFuser::InsertNewNode(SSAGraph* graph,
// Save quantization info in op_info attr
auto op_info = *quantized_node->stmt()->op_info();
op_info.SetAttr<int>("bit_length", bit_length);
op_info.SetInputScale(output_act_name, scale_value);
op_info.SetInputScale(output_act_name, {scale_value});
op_info.UpdateAllInputs(output_act_name, input_act_name);
quantized_node->stmt()->ResetOp(op_info, graph->valid_places());
......
......@@ -60,7 +60,7 @@ void QuantizedOpAttributesInferencePass::Apply(
auto& out_inst = out_op_node->AsStmt();
auto out_op_info = out_inst.op_info();
if (!out_op_info->HasInputScale(out_var_node_name)) continue;
auto input_scale = out_op_info->GetInputScale<float>(out_var_node_name);
auto input_scale = out_op_info->GetInputScale(out_var_node_name)[0];
if (!found) {
found = true;
output_scale = input_scale;
......@@ -69,14 +69,15 @@ void QuantizedOpAttributesInferencePass::Apply(
}
}
if (found) {
inst.mutable_op_info()->SetOutputScale(out_var_node_name, output_scale);
inst.mutable_op_info()->SetOutputScale(out_var_node_name,
{output_scale});
} else if (op_info->HasAttr("out_threshold")) {
// Only consider one output; there is only one out_threshold
int bit_length = op_info->GetAttr<int>("bit_length");
int range = (1 << (bit_length - 1)) - 1;
output_scale = op_info->GetAttr<float>("out_threshold");
inst.mutable_op_info()->SetOutputScale(out_var_node_name,
output_scale / range);
{output_scale / range});
} else {
is_quantized = false;
}
......
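When only an `out_threshold` attribute is available, the pass derives the output scale from the quantization bit length; for the usual 8-bit case the divisor is 127. A standalone arithmetic check of that conversion, illustrative only and mirroring the `range` formula above:

```cpp
#include <iostream>

int main() {
  int bit_length = 8;
  float out_threshold = 6.0f;                  // example threshold value
  int range = (1 << (bit_length - 1)) - 1;     // 127 for 8 bits
  float output_scale = out_threshold / range;  // ~0.04724
  std::cout << "range = " << range << ", scale = " << output_scale << "\n";
}
```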
......@@ -119,7 +119,7 @@ void StaticKernelPickPass::Apply(const std::unique_ptr<SSAGraph>& graph) {
instruct.mutable_op_info()->SetOutputScale(
out_node_name,
one_adj_instruct.op_info()->GetInputScale<float>(out_node_name));
one_adj_instruct.op_info()->GetInputScale(out_node_name));
auto update_desc = *instruct.mutable_op_info();
instruct.ResetOp(update_desc, graph->valid_places());
......
......@@ -463,7 +463,7 @@ void SubgraphFuser::InsertNewNode(SSAGraph *graph,
auto &any_inst = any_op_node->AsStmt();
if (any_inst.op_info()->HasInputScale(var_node_name)) {
input_data_scales.push_back(
any_inst.op_info()->GetInputScale<float>(var_node_name));
any_inst.op_info()->GetInputScale(var_node_name)[0]);
}
}
for (auto &var_node : output_var_nodes) {
......@@ -473,7 +473,7 @@ void SubgraphFuser::InsertNewNode(SSAGraph *graph,
auto &any_inst = any_op_node->AsStmt();
if (any_inst.op_info()->HasOutputScale(var_node_name)) {
output_data_scales.push_back(
any_inst.op_info()->GetOutputScale<float>(var_node_name));
any_inst.op_info()->GetOutputScale(var_node_name)[0]);
}
}
if (input_data_scales.size() > 0) {
......
......@@ -108,7 +108,7 @@ static bool InferScale(Node* var_node, Node* op_node, float* scale) {
found = InferScaleFromSubgraph(var_name, op_info, scale, false);
} else {
if (op_info->HasInputScale(var_name)) {
*scale = op_info->GetInputScale<float>(var_name);
*scale = op_info->GetInputScale(var_name)[0];
found = true;
} else {
// Obtain the output_scale from one of its previous Ops
......@@ -121,7 +121,7 @@ static bool InferScale(Node* var_node, Node* op_node, float* scale) {
found = InferScaleFromSubgraph(var_name, prev_op_info, scale, true);
} else {
if (prev_op_info->HasOutputScale(var_name)) {
*scale = prev_op_info->GetOutputScale<float>(var_name);
*scale = prev_op_info->GetOutputScale(var_name)[0];
found = true;
}
}
......
......@@ -22,7 +22,7 @@
namespace paddle {
namespace lite {
std::string int2string(int index) {
static std::string int2string(int index) {
const int BUFFER_LENGTH = 30;
char buffer[BUFFER_LENGTH];
int num = snprintf(buffer, sizeof(buffer), "%d", index);
......@@ -262,17 +262,6 @@ bool OpInfo::HasOutputScale(const std::string &output_name) const {
}
}
template <>
void OpInfo::SetInputScale(const std::string &input_name,
const float &scale_value) {
std::string argname;
int index;
CHECK(GetInputArgname(input_name, &argname));
CHECK(GetInputIndex(input_name, &index));
SetAttr<float>(argname + int2string(index) + "_scale", scale_value);
}
template <>
void OpInfo::SetInputScale(const std::string &input_name,
const std::vector<float> &scale_value) {
std::string argname;
......@@ -283,25 +272,31 @@ void OpInfo::SetInputScale(const std::string &input_name,
scale_value);
}
template <>
void OpInfo::SetOutputScale(const std::string &output_name,
const float &scale_value) {
const std::vector<float> &scale_value) {
std::string argname;
int index;
CHECK(GetOutputArgname(output_name, &argname));
CHECK(GetOutputIndex(output_name, &index));
SetAttr<float>(argname + int2string(index) + "_scale", scale_value);
SetAttr<std::vector<float>>(argname + int2string(index) + "_scale",
scale_value);
}
template <>
void OpInfo::SetOutputScale(const std::string &output_name,
const std::vector<float> &scale_value) {
std::vector<float> OpInfo::GetInputScale(const std::string &input_name) const {
std::string argname;
int index;
CHECK(GetInputArgname(input_name, &argname));
CHECK(GetInputIndex(input_name, &index));
return GetAttr<std::vector<float>>(argname + int2string(index) + "_scale");
}
std::vector<float> OpInfo::GetOutputScale(
const std::string &output_name) const {
std::string argname;
int index;
CHECK(GetOutputArgname(output_name, &argname));
CHECK(GetOutputIndex(output_name, &index));
SetAttr<std::vector<float>>(argname + int2string(index) + "_scale",
scale_value);
return GetAttr<std::vector<float>>(argname + int2string(index) + "_scale");
}
} // namespace lite
......
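Both setters and getters build the attribute key as `argname + index + "_scale"`, so the scale of the first `Input` argument is stored under `Input0_scale`. A standalone sketch of that key construction, with a simplified `snprintf`-based helper in the spirit of the now file-local `int2string`:

```cpp
#include <cstdio>
#include <iostream>
#include <string>

// Convert an index to a string with snprintf, as the file-local helper does
// (simplified: the real helper also checks the snprintf return value).
static std::string int2string(int index) {
  char buffer[30];
  std::snprintf(buffer, sizeof(buffer), "%d", index);
  return std::string(buffer);
}

int main() {
  std::string argname = "Input";
  int index = 0;
  std::string key = argname + int2string(index) + "_scale";
  std::cout << key << "\n";  // prints "Input0_scale"
}
```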
......@@ -30,8 +30,6 @@
namespace paddle {
namespace lite {
std::string int2string(int index);
// For registry factory.
struct Registry {
void Touch() {}
......@@ -231,38 +229,6 @@ class OpInfo : public cpp::OpDesc {
return OutputArgumentNames();
}
bool GetInputArgname(const std::string &value_name, std::string *out) const;
bool GetOutputArgname(const std::string &value_name, std::string *out) const;
bool GetInputIndex(const std::string &input_name, int *out) const;
bool GetOutputIndex(const std::string &output_name, int *out) const;
bool HasInputScale(const std::string &input_name) const;
bool HasOutputScale(const std::string &output_name) const;
template <typename T>
void SetInputScale(const std::string &input_name, const T &scale_value);
template <typename T>
void SetOutputScale(const std::string &output_name, const T &scale_value);
template <typename T>
T GetInputScale(const std::string &input_name) const {
std::string argname;
int index;
CHECK(GetInputArgname(input_name, &argname));
CHECK(GetInputIndex(input_name, &index));
return GetAttr<T>(argname + int2string(index) + "_scale");
}
template <typename T>
T GetOutputScale(const std::string &output_name) const {
std::string argname;
int index;
CHECK(GetOutputArgname(output_name, &argname));
CHECK(GetOutputIndex(output_name, &index));
return GetAttr<T>(argname + int2string(index) + "_scale");
}
void UpdateAllInputs(const std::string &from, const std::string &to) {
for (auto &item : inputs_) {
for (auto &var : item.second) {
......@@ -278,6 +244,26 @@ class OpInfo : public cpp::OpDesc {
}
}
}
bool GetInputArgname(const std::string &value_name, std::string *out) const;
bool GetOutputArgname(const std::string &value_name, std::string *out) const;
bool GetInputIndex(const std::string &input_name, int *out) const;
bool GetOutputIndex(const std::string &output_name, int *out) const;
bool HasInputScale(const std::string &input_name) const;
bool HasOutputScale(const std::string &output_name) const;
void SetInputScale(const std::string &input_name,
const std::vector<float> &scale_value);
void SetOutputScale(const std::string &output_name,
const std::vector<float> &scale_value);
// For conv2d, depthwise_conv2d and mul, the scale of the weight is a vector.
// Otherwise, all input and output scales are scalars, but we save them
// as vectors.
std::vector<float> GetInputScale(const std::string &input_name) const;
std::vector<float> GetOutputScale(const std::string &output_name) const;
};
} // namespace lite
......
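The header now exposes only the vector-based scale interface. A minimal mock of that reworked interface (illustrative only; `FakeOpInfo` is not a real Paddle-Lite class) showing the round trip: callers wrap a scalar activation scale in a brace-initialized one-element vector when storing it, and take element `[0]` when a single value is needed back:

```cpp
#include <iostream>
#include <map>
#include <string>
#include <vector>

// Mock of the vector-only scale interface declared in OpInfo.
class FakeOpInfo {
 public:
  void SetInputScale(const std::string& name,
                     const std::vector<float>& scale) {
    scales_[name] = scale;
  }
  std::vector<float> GetInputScale(const std::string& name) const {
    return scales_.at(name);
  }

 private:
  std::map<std::string, std::vector<float>> scales_;
};

int main() {
  FakeOpInfo info;
  info.SetInputScale("X", {0.031f});             // activation: one element
  info.SetInputScale("Filter", {0.01f, 0.02f});  // weight: per channel

  float x_scale = info.GetInputScale("X")[0];    // scalar consumers read [0]
  std::vector<float> w_scale = info.GetInputScale("Filter");
  std::cout << x_scale << " " << w_scale.size() << "\n";
}
```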
......@@ -103,11 +103,11 @@ int ConvConverter(void* ctx, OpLite* op, KernelBase* kernel) {
auto filter_name = op_info->Input("Filter").front();
auto output_name = op_info->Output("Output").front();
if (op_info->HasInputScale(input_name))
input_scale = op_info->GetInputScale<float>(input_name);
input_scale = op_info->GetInputScale(input_name)[0];
if (op_info->HasInputScale(filter_name))
weight_scale = op_info->GetInputScale<std::vector<float>>(filter_name);
weight_scale = op_info->GetInputScale(filter_name);
if (op_info->HasOutputScale(output_name)) {
output_scale = op_info->GetOutputScale<float>(output_name);
output_scale = op_info->GetOutputScale(output_name)[0];
}
VLOG(3) << "has output scale:" << output_scale;
} else {
......
......@@ -61,11 +61,11 @@ int FCConverter(void* ctx, OpLite* op, KernelBase* kernel) {
auto weight_name = op_info->Input("W").front();
auto out_name = op_info->Output("Out").front();
if (op_info->HasInputScale(input_name))
input_scale = op_info->GetInputScale<float>(input_name);
input_scale = op_info->GetInputScale(input_name)[0];
if (op_info->HasInputScale(weight_name))
w_scale = op_info->GetInputScale<std::vector<float>>(weight_name);
w_scale = op_info->GetInputScale(weight_name);
if (op_info->HasOutputScale(out_name))
out_scale = op_info->GetOutputScale<float>(out_name);
out_scale = op_info->GetOutputScale(out_name)[0];
} else {
return FAILED;
}
......
......@@ -94,9 +94,9 @@ int PoolConverter(void* ctx, OpLite* op, KernelBase* kernel) {
auto x_name = op_info->Input("X").front();
auto out_name = op_info->Output("Out").front();
if (op_info->HasInputScale(x_name))
x_scale = op_info->GetInputScale<float>(x_name);
x_scale = op_info->GetInputScale(x_name)[0];
if (op_info->HasOutputScale(out_name))
out_scale = op_info->GetOutputScale<float>(out_name);
out_scale = op_info->GetOutputScale(out_name)[0];
} else {
LOG(WARNING) << "Do not enable_int8";
return FAILED;
......
......@@ -52,9 +52,9 @@ int SoftmaxConverter(void* ctx, OpLite* op, KernelBase* kernel) {
auto x_name = op_info->Input("X").front();
auto out_name = op_info->Output("Out").front();
if (op_info->HasInputScale(x_name))
input_scale = op_info->GetInputScale<float>(x_name);
input_scale = op_info->GetInputScale(x_name)[0];
if (op_info->HasOutputScale(out_name))
out_scale = op_info->GetOutputScale<float>(out_name);
out_scale = op_info->GetOutputScale(out_name)[0];
} else {
LOG(WARNING) << "Do not enable_int8";
return FAILED;
......
......@@ -137,12 +137,11 @@ class ConvOpLite : public OpLite {
auto filter_name = op_info->Input("Filter").front();
auto output_name = op_info->Output("Output").front();
if (op_info->HasInputScale(input_name))
param_.input_scale = op_info->GetInputScale<float>(input_name);
param_.input_scale = op_info->GetInputScale(input_name)[0];
if (op_info->HasInputScale(filter_name))
param_.weight_scale =
op_info->GetInputScale<std::vector<float>>(filter_name);
param_.weight_scale = op_info->GetInputScale(filter_name);
if (op_info->HasOutputScale(output_name)) {
param_.output_scale = op_info->GetOutputScale<float>(output_name);
param_.output_scale = op_info->GetOutputScale(output_name)[0];
}
}
......
......@@ -109,12 +109,11 @@ bool FcOpLite::AttachImpl(const cpp::OpDesc& op_desc, lite::Scope* scope) {
auto weight_name = op_info->Input("W").front();
auto out_name = op_info->Output("Out").front();
if (op_info->HasInputScale(input_name))
param_.input_scale = op_info->GetInputScale<float>(input_name);
param_.input_scale = op_info->GetInputScale(input_name)[0];
if (op_info->HasInputScale(weight_name))
param_.weight_scale =
op_info->GetInputScale<std::vector<float>>(weight_name);
param_.weight_scale = op_info->GetInputScale(weight_name);
if (op_info->HasOutputScale(out_name))
param_.output_scale = op_info->GetOutputScale<float>(out_name);
param_.output_scale = op_info->GetOutputScale(out_name)[0];
}
return true;
}
......
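On the op side the parameter struct still distinguishes a scalar activation scale from the per-channel weight vector, so `AttachImpl` reads `[0]` for the former and assigns the whole vector for the latter. A hedged standalone sketch of that fill pattern; `FcParamSketch` is a simplified stand-in whose fields mirror the ones used above, not the real `FcParam` definition:

```cpp
#include <iostream>
#include <vector>

// Simplified stand-in for the fc op parameters touched above.
struct FcParamSketch {
  float input_scale{1.0f};
  std::vector<float> weight_scale;
  float output_scale{1.0f};
};

int main() {
  // Values as they would come back from the vector-only getters.
  std::vector<float> input_scale_vec = {0.05f};
  std::vector<float> weight_scale_vec = {0.01f, 0.02f, 0.03f};
  std::vector<float> output_scale_vec = {0.08f};

  FcParamSketch param;
  param.input_scale = input_scale_vec[0];  // scalar: take the only element
  param.weight_scale = weight_scale_vec;   // per-channel: keep the vector
  param.output_scale = output_scale_vec[0];

  std::cout << param.input_scale << " " << param.weight_scale.size() << " "
            << param.output_scale << "\n";
}
```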