Unverified commit 130db92a, authored by Paulina Gacek, committed by GitHub

Split quant (#47449)

* Split kernel registered, tests for uint/int added

* Split quantized

* Split output scales calculated only once

* NearestInterp test fix reversed

* DequantizeOutputs corrected
Parent: c7cd8d98
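For context, the operator being quantized throughout this diff is Paddle's split. A minimal fp32 sketch of its semantics (dynamic mode; the shape and sections mirror the oneDNN tests below):

import numpy as np
import paddle

x = paddle.to_tensor(np.random.random((4, 5, 6)).astype("float32"))
# sections [2, 1, 2] along axis 1, as in TestSplitSectionsOneDNNOp
a, b, c = paddle.split(x, num_or_sections=[2, 1, 2], axis=1)
print(a.shape, b.shape, c.shape)  # [4, 2, 6] [4, 1, 6] [4, 2, 6]

The passes changed below insert quantize/dequantize pairs around this op so that the oneDNN INT8 kernels registered at the end of the diff can execute it.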
......@@ -1112,7 +1112,7 @@ struct ResidualElementwise : public PatternBase {
};
// General struct for immutable ops:
// reshape, transpose, slice, shape, nearest-interp
// reshape, transpose, slice, shape, nearest-interp, split
// Forward pass for ops without weights.
// immutable_out is the result of the operator.
struct Immutable : public PatternBase {
......
......@@ -498,7 +498,8 @@ void ComputePropagateScalesMkldnnPass::ApplyImpl(ir::Graph* graph) const {
"slice",
"shape",
"nearest_interp",
"nearest_interp_v2"};
"nearest_interp_v2",
"split"};
StringPairMap var_quant_scales{};
......
......@@ -245,6 +245,54 @@ void CPUQuantizePass::DequantizeOutput(Graph* g,
  if (!scale_attr_name.empty()) op->Op()->SetAttr(scale_attr_name, scale);
}

void CPUQuantizePass::DequantizeOutputs(Graph* g,
                                        Node* op,
                                        std::string output_name,
                                        double scale_to_one,
                                        bool is_unsigned,
                                        std::string scale_attr_name) const {
  auto outputs = op->outputs;
  PADDLE_ENFORCE_GE(outputs.size(),
                    1,
                    platform::errors::InvalidArgument(
                        "OP(%s)'s number of outputs (%d) must be greater "
                        "than or equal to 1.",
                        op->Name(),
                        outputs.size()));
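  // For each output, rewire the edge op->out_i into
  // op -> dequantize_in_i -> dequantize_op -> out_i, so every INT8
  // result of this op is converted back to fp32 for its consumers.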
  std::vector<std::string> dequantize_in_node_names(outputs.size());
  unsigned max = is_unsigned ? U8_MAX : S8_MAX;
  float scale = scale_to_one * max;
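  // scale_to_one normalizes the tensor to [0, 1] (unsigned) or [-1, 1]
  // (signed); multiplying by the integer max gives the "Scale" attribute
  // consumed by the dequantize op.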
  for (size_t i = 0; i < outputs.size(); i++) {
    // Create the dequantize input variable.
    VarDesc dequantize_in_desc(patterns::PDNodeName("dequantize", "in"));
    Node* dequantize_in_node = g->CreateVarNode(&dequantize_in_desc);
    dequantize_in_node_names[i] = dequantize_in_node->Name();

    // Create a dequantize op node for this output.
    OpDesc deq_desc;
    deq_desc.SetType("dequantize");
    deq_desc.SetInput("Input",
                      std::vector<std::string>({dequantize_in_node_names[i]}));
    deq_desc.SetOutput("Output",
                       std::vector<std::string>({outputs[i]->Name()}));
    deq_desc.SetAttr("Scale", scale);
    deq_desc.SetAttr("is_negative_input", !is_unsigned);
    auto dequantize_op = g->CreateOpNode(&deq_desc);  // OpDesc will be copied.

    // Link the dequantize op in place of the direct op->output edge.
    UnlinkNodes(op, outputs[i]);
    IR_NODE_LINK_TO(op, dequantize_in_node);
    IR_NODE_LINK_TO(dequantize_in_node, dequantize_op);
    IR_NODE_LINK_TO(dequantize_op, outputs[i]);
  }

  // Point the op's outputs at the newly created dequantize inputs.
  op->Op()->SetOutput(output_name, dequantize_in_node_names);

  if (!scale_attr_name.empty()) op->Op()->SetAttr(scale_attr_name, scale);
}
bool CPUQuantizePass::AreScalesPresentForVarNames(
std::vector<std::string> names) const {
bool present = true;
......@@ -730,13 +778,17 @@ void CPUQuantizePass::QuantizeImmutable(Graph* graph,
bool is_output_unsigned{false};
auto output_scale =
GetScaleValueForNode(immutable_out, &is_output_unsigned);
DequantizeOutput(g,
immutable_op,
immutable_out,
"Out",
output_scale,
is_output_unsigned);
if (immutable_type == "split") { // ops with multiple outputs
DequantizeOutputs(
g, immutable_op, "Out", output_scale, is_output_unsigned);
} else {
DequantizeOutput(g,
immutable_op,
immutable_out,
"Out",
output_scale,
is_output_unsigned);
}
++quantize_immutable_count;
};
......@@ -1184,6 +1236,7 @@ void CPUQuantizePass::ApplyImpl(ir::Graph* graph) const {
QuantizeImmutable(graph, "slice", "Input");
QuantizeImmutable(graph, "nearest_interp", "X");
QuantizeImmutable(graph, "nearest_interp_v2", "X");
QuantizeImmutable(graph, "split", "X");
QuantizeElementwise(graph, "elementwise_add");
QuantizeElementwise(graph, "elementwise_mul");
QuantizeElementwise(graph, "elementwise_sub");
......
......@@ -91,6 +91,14 @@ class CPUQuantizePass : public FusePassBase {
bool is_unsigned,
std::string scale_attr_name = "") const;
// dequantize all outputs of the given name
void DequantizeOutputs(Graph* g,
Node* op,
std::string output_name,
double scale_to_one,
bool is_unsigned,
std::string scale_attr_name = "") const;
bool AreScalesPresentForVarNames(std::vector<std::string> names) const;
bool AreScalesPresentForNodes(std::initializer_list<Node*> nodes) const;
std::pair<bool, phi::DenseTensor> GetScaleDataByName(
......
......@@ -69,6 +69,9 @@ void SetOp(ProgramDesc* prog,
} else if (type == "slice") {
op->SetInput("Input", {inputs[0]});
op->SetOutput("Out", {outputs[0]});
} else if (type == "split") {
op->SetInput("X", {inputs[0]});
op->SetOutput("Out", {outputs});
} else if (type == "dropout") {
op->SetInput("X", {inputs[0]});
op->SetOutput("Out", {outputs[0]});
......@@ -556,8 +559,12 @@ void TestImmutableOpWithManyOutputs(const std::string tested_op) {
SCALE * S8_MAX);
}
const std::vector<std::string> immutables = {
"reshape2", "transpose2", "slice", "nearest_interp", "nearest_interp_v2"};
const std::vector<std::string> immutables = {"reshape2",
"transpose2",
"slice",
"nearest_interp",
"nearest_interp_v2",
"split"};
class TestImmutables : public testing::TestWithParam<std::string> {};
......
......@@ -42,7 +42,8 @@ void CPUQuantizePlacementPass::ApplyImpl(ir::Graph* graph) const {
"fusion_gru",
"fusion_lstm",
"multi_gru",
"slice"});
"slice",
"split"});
const auto& excluded_ids_list =
Get<std::unordered_set<int>>("quantize_excluded_op_ids");
const auto& op_types_list =
......
......@@ -131,7 +131,7 @@ void AnalysisPredictor::MkldnnQuantizer::CalculateScalesForOpOutputs(
is_unsigned = true;
} else if (op->Type() == "transpose2" || op->Type() == "reshape2" ||
op->Type() == "pool2d" || op->Type() == "nearest_interp" ||
op->Type() == "nearest_interp_v2") {
op->Type() == "nearest_interp_v2" || op->Type() == "split") {
auto input_var_name = op->Input("X")[0];
PADDLE_ENFORCE_NE(scales_.find(input_var_name),
scales_.end(),
......
......@@ -48,6 +48,9 @@ MkldnnQuantizerConfig::MkldnnQuantizerConfig() {
rules_["shape"]["Input"] = ScaleAlgo::KL;
rules_["shape"]["Out"] = ScaleAlgo::NONE;
rules_["split"]["X"] = ScaleAlgo::KL;
rules_["split"]["Out"] = ScaleAlgo::NONE;
rules_["fc"]["Input"] = ScaleAlgo::KL;
rules_["fc"]["W"] = ScaleAlgo::MAX_CH_T;
rules_["fc"]["Bias"] = ScaleAlgo::NONE;
......
......@@ -1134,7 +1134,8 @@ struct PD_INFER_DECL AnalysisConfig {
"fusion_gru",
"fusion_lstm",
"multi_gru",
"slice"};
"slice",
"split"};
// ipu related.
bool use_ipu_{false};
......
......@@ -198,7 +198,7 @@ Example:
"mkldnn_data_type",
"(string, default \"float32\"). Data type of mkldnn kernel")
.SetDefault("float32")
.InEnum({"float32", "bfloat16"});
.InEnum({"float32", "bfloat16", "int8", "uint8"});
}
};
......
......@@ -77,12 +77,20 @@ void SplitWithNumKernel(const Context& dev_ctx,
} // namespace phi
PD_REGISTER_KERNEL(
split, OneDNN, ONEDNN, phi::SplitKernel, float, phi::dtype::bfloat16) {}
PD_REGISTER_KERNEL(split,
OneDNN,
ONEDNN,
phi::SplitKernel,
float,
phi::dtype::bfloat16,
int8_t,
uint8_t) {}
PD_REGISTER_KERNEL(split_with_num,
OneDNN,
ONEDNN,
phi::SplitWithNumKernel,
float,
phi::dtype::bfloat16) {}
phi::dtype::bfloat16,
int8_t,
uint8_t) {}
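// Registering int8_t/uint8_t lets the oneDNN split kernels consume the
// quantized tensors produced by cpu_quantize_pass.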
......@@ -74,6 +74,7 @@ class Quant2Int8MkldnnPass(object):
'shape',
'nearest_interp',
'nearest_interp_v2',
'split',
]
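# The ops listed above are scale-immutable: they do not change tensor
# values, so one quantization scale is propagated across them unchanged.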
self._scale_ops = ['scale']
self._conv_ops = ['conv2d', 'depthwise_conv2d']
......@@ -284,6 +285,7 @@ class Quant2Int8MkldnnPass(object):
self._var_quant_scales[
input_name
] = self._var_quant_scales[output_name]
elif op.name() == 'concat':
output_name = op.output("Out")[0]
if output_name in self._var_quant_scales:
......
......@@ -19,11 +19,27 @@ from paddle.fluid.tests.unittests.op_test import OpTest
class TestSplitSectionsOneDNNOp(OpTest):
    def init_data(self):
        self.x = np.random.random((4, 5, 6)).astype("float32")

    def init_data_type(self):
        self.dtype = np.float32

    def init_x(self):
        if self.dtype == np.float32:
            self.x = np.random.random(self.input_shape).astype(self.dtype)
        elif self.dtype == np.int8:
            self.x = np.random.randint(-5, 5, self.input_shape).astype(
                self.dtype
            )
        else:  # uint8
            self.x = np.random.randint(0, 10, self.input_shape).astype(
                self.dtype
            )

    def init_test_case(self):
        self.input_shape = (4, 5, 6)
        self.init_x()
        self.axis = 1
        self.num = 0
        self.sections = [2, 1, 2]
        indices_or_sections = [2, 3]  # sections
        self.out = np.split(self.x, indices_or_sections, self.axis)
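        # Note: Paddle's `sections` lists chunk sizes ([2, 1, 2]), while
        # np.split takes cut indices; cutting a length-5 axis at [2, 3]
        # gives chunks of sizes 2, 1 and 2.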
......@@ -31,8 +47,8 @@ class TestSplitSectionsOneDNNOp(OpTest):
        self.op_type = "split"
        self.axis_tensor = None
        self.sections_tensor_list = None
        self.num = 0
        self.init_data()
        self.init_data_type()
        self.init_test_case()
        self.inputs = {'X': self.x}
        self.attrs = {'use_mkldnn': True, 'num': self.num}
......@@ -58,11 +74,12 @@ class TestSplitSectionsOneDNNOp(OpTest):
# test with attr(num)
class TestSplitNumOneDNNOp(TestSplitSectionsOneDNNOp):
    def init_data(self):
        self.x = np.random.random((4, 8, 5, 3)).astype("float32")

    def init_test_case(self):
        self.input_shape = (4, 8, 5, 3)
        self.init_x()
        self.axis = 1
        self.num = 4
        self.sections = []
        indices_or_sections = 4  # indices
        self.out = np.split(self.x, indices_or_sections, self.axis)
......@@ -71,20 +88,23 @@ class TestSplitNumOneDNNOp(TestSplitSectionsOneDNNOp):
class TestSplitNumAxisTensorOneDNNOp(TestSplitSectionsOneDNNOp):
    def init_data(self):
        self.x = np.random.random((4, 5, 6)).astype("float32")

    def init_test_case(self):
        self.input_shape = (4, 5, 6)
        self.init_x()
        self.axis = None
        self.sections = []
        self.num = 3
        self.axis_tensor = np.array([2]).astype("int32")
        indices_or_sections = 3  # indices
        self.out = np.split(self.x, indices_or_sections, 2)
# attr(sections) is a list containing Tensors
class TestSplitSectionsTensorOneDNNOp(TestSplitSectionsOneDNNOp):
    def init_data(self):
        self.x = np.random.random((4, 5, 6)).astype("float32")

    def init_test_case(self):
        self.input_shape = (4, 5, 6)
        self.init_x()
        self.num = 0
        self.axis = 1
        self.sections = [2, 1, 2]
        self.sections_tensor_list = []
......@@ -98,14 +118,47 @@ class TestSplitSectionsTensorOneDNNOp(TestSplitSectionsOneDNNOp):
class TestSplitOpUnknownSectionOneDNNOp(TestSplitSectionsOneDNNOp):
    def init_data(self):
        self.x = np.random.random((4, 5, 6)).astype("float32")

    def init_test_case(self):
        self.input_shape = (4, 5, 6)
        self.init_x()
        self.num = 0
        self.axis = 2
        self.sections = [2, 2, -1]
        indices_or_sections = [2, 4]  # sections
        self.out = np.split(self.x, indices_or_sections, self.axis)
def create_test_class(parent):
    '''
    Create int8 and uint8 versions of each test. The parent tests run on
    fp32 by default.
    '''

    class TestInt8Case(parent):
        def init_data_type(self):
            self.dtype = np.int8

        def test_check_grad(self):
            pass

    class TestUint8Case(parent):
        def init_data_type(self):
            self.dtype = np.uint8

        def test_check_grad(self):
            pass
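    # test_check_grad is a no-op above because gradients are not defined
    # for integer dtypes; only the forward pass is checked.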
    TestInt8Case.__name__ = "{0}_{1}".format(parent.__name__, "INT8")
    TestUint8Case.__name__ = "{0}_{1}".format(parent.__name__, "UINT8")
    globals()[TestInt8Case.__name__] = TestInt8Case
    globals()[TestUint8Case.__name__] = TestUint8Case


create_test_class(TestSplitNumOneDNNOp)
create_test_class(TestSplitNumAxisTensorOneDNNOp)
create_test_class(TestSplitSectionsTensorOneDNNOp)
create_test_class(TestSplitOpUnknownSectionOneDNNOp)
create_test_class(TestSplitSectionsOneDNNOp)
if __name__ == '__main__':
    paddle.enable_static()
    unittest.main()