未验证 提交 1456b02d 编写于 作者: J joanna.wozna.intel 提交者: GitHub

Quantize nearest_interp and nearest_interp_v2 (#38622)

* Quantize nearest_interp and nearest_interp_v2

* Check if avx_core supported

* Add depthwise_conv2d to supported quantization list
上级 a268c7ce
......@@ -1641,6 +1641,32 @@ PDNode *patterns::Slice::operator()() {
return slice_out;
}
// Builds the chain
//   prev_op -> nearest_interp_in -> nearest_interp_op -> nearest_interp_out
//           -> next_op
// where nearest_interp_op is either "nearest_interp" or "nearest_interp_v2",
// nearest_interp_in is its "X" input and nearest_interp_out its "Out" output.
// Returns the node that represents the "Out" variable.
PDNode *patterns::NearestInterp::operator()() {
  // Both operator flavours are matched by the same pattern.
  const std::unordered_set<std::string> interp_types{"nearest_interp",
                                                     "nearest_interp_v2"};

  // Nodes are created in the same order as before: producer, op, input,
  // output, consumer.
  auto *producer = pattern->NewNode(prev_op_repr())->assert_is_op();
  auto *interp_op = pattern->NewNode(nearest_interp_op_repr())
                        ->assert_is_ops(interp_types);
  auto *interp_x = pattern->NewNode(nearest_interp_in_repr())
                       ->AsInput()
                       ->assert_is_ops_input(interp_types, "X");
  auto *interp_y = pattern->NewNode(nearest_interp_out_repr())
                       ->AsOutput()
                       ->assert_is_ops_output(interp_types, "Out");
  auto *consumer = pattern->NewNode(next_op_repr())->assert_is_op();

  // Wire the chain together.
  producer->LinksTo({interp_x});
  interp_op->LinksFrom({interp_x}).LinksTo({interp_y});
  consumer->LinksFrom({interp_y});

  return interp_y;
}
PDNode *patterns::Matmul::operator()() {
auto matmul_op = pattern->NewNode(matmul_op_repr())->assert_is_op("matmul");
......@@ -2376,15 +2402,8 @@ PDNode *patterns::MultipleQuantize::operator()() {
// Builds a one-node pattern that matches any operator whose type is listed in
// quantize_enabled_op_types and returns that operator node.
//
// quantize_enabled_op_types: the final set of op types to match. The set is
//   used exactly as given: no built-in default list is substituted here, so
//   the caller must resolve defaults and validate the list before calling.
PDNode *patterns::QuantizePlacement::operator()(
    const std::unordered_set<std::string> &quantize_enabled_op_types) {
  return pattern->NewNode(op_repr())->assert_is_ops(quantize_enabled_op_types);
}
......
......@@ -995,6 +995,21 @@ struct Slice : public PatternBase {
PATTERN_DECL_NODE(next_op);
};
// Nearest Interp op
// Pattern around a single nearest-neighbour interpolation operator
// ("nearest_interp" or "nearest_interp_v2") together with the operator that
// precedes it and the operator that follows it.
// operator()() builds the pattern; nearest_interp_out is the operator's result
// and is the node returned by operator()().
struct NearestInterp : public PatternBase {
// All nodes of this pattern are registered under the "nearest_interp" name
// inside the given name_scope.
NearestInterp(PDPattern* pattern, const std::string& name_scope)
: PatternBase(pattern, name_scope, "nearest_interp") {}
PDNode* operator()();
// Operator placed before nearest_interp_in in the pattern.
PATTERN_DECL_NODE(prev_op);
// "X" input variable of the interpolation operator.
PATTERN_DECL_NODE(nearest_interp_in);
// The interpolation operator itself.
PATTERN_DECL_NODE(nearest_interp_op);
// "Out" output variable of the interpolation operator.
PATTERN_DECL_NODE(nearest_interp_out);
// Operator placed after nearest_interp_out in the pattern.
PATTERN_DECL_NODE(next_op);
};
// Matmul op
// Forward pass for matmul.
struct Matmul : public PatternBase {
......
......@@ -1053,6 +1053,67 @@ void CPUQuantizePass::QuantizeFusionLSTM(Graph* graph) const {
PrettyLogDetail("--- quantized %d fusion_lstm ops", quantize_count);
}
// Quantizes every nearest_interp / nearest_interp_v2 operator matched by
// patterns::NearestInterp: its "X" input is quantized and its "Out" output is
// dequantized. A matched operator is left untouched when
//   * it is not marked with the INT8 data type,
//   * neither the preceding op is dequantized nor the following op quantized,
//   * scales are missing for its input or output variable.
// The number of quantized operators is recorded via AddStatis and logged.
void CPUQuantizePass::QuantizeNearestInterp(Graph* graph) const {
  GraphPatternDetector gpd;
  patterns::NearestInterp nearest_interp_pattern{gpd.mutable_pattern(),
                                                 name_scope_};
  nearest_interp_pattern();

  int quantized_ops = 0;

  // NOTE: the first parameter has to stay named `subgraph`; the
  // GET_IR_NODE_FROM_SUBGRAPH invocations below are written against it.
  auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
                     Graph* g) {
    VLOG(4) << "Quantize nearest_interp op";
    GET_IR_NODE_FROM_SUBGRAPH(nearest_interp_op, nearest_interp_op,
                              nearest_interp_pattern);

    // Operators not marked for INT8 execution are skipped.
    const bool marked_int8 =
        platform::HasOpINT8DataType(nearest_interp_op->Op());
    if (!marked_int8) {
      LogQuantizationDisabled(nearest_interp_op);
      return;
    }

    GET_IR_NODE_FROM_SUBGRAPH(prev_op, prev_op, nearest_interp_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(next_op, next_op, nearest_interp_pattern);

    // Quantization is only applied when at least one neighbour is already
    // part of a quantized region.
    const bool has_quantized_neighbour =
        IsOpDequantized(prev_op) || IsOpQuantized(next_op);
    if (!has_quantized_neighbour) {
      LogCannotQuantizeOp(nearest_interp_op,
                          "There are no other quantized operators nearby, so "
                          "quantization is not recommended.");
      return;
    }

    GET_IR_NODE_FROM_SUBGRAPH(nearest_interp_in, nearest_interp_in,
                              nearest_interp_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(nearest_interp_out, nearest_interp_out,
                              nearest_interp_pattern);

    // Both the input and the output variable need a known scale.
    const bool scales_known =
        AreScalesPresentForNodes({nearest_interp_in, nearest_interp_out});
    if (!scales_known) {
      LogCannotQuantizeOp(nearest_interp_op);
      return;
    }

    // Quantize the "X" input.
    bool x_is_unsigned{false};
    auto x_scale = GetScaleValueForNode(nearest_interp_in, &x_is_unsigned);
    QuantizeInput(g, nearest_interp_op, nearest_interp_in, "X", x_scale,
                  x_is_unsigned);

    // Dequantize the "Out" output.
    bool out_is_unsigned{false};
    auto out_scale = GetScaleValueForNode(nearest_interp_out, &out_is_unsigned);
    DequantizeOutput(g, nearest_interp_op, nearest_interp_out, "Out",
                     out_scale, out_is_unsigned);

    ++quantized_ops;
  };

  gpd(graph, handler);
  AddStatis(quantized_ops);
  PrettyLogDetail("--- quantized %d nearest_interp ops", quantized_ops);
}
void CPUQuantizePass::ApplyImpl(ir::Graph* graph) const {
VLOG(3) << "Quantizing the graph.";
PADDLE_ENFORCE_NOT_NULL(
......@@ -1076,6 +1137,7 @@ void CPUQuantizePass::ApplyImpl(ir::Graph* graph) const {
QuantizeMultiGru(graph);
QuantizeFusionLSTM(graph);
QuantizeSlice(graph);
QuantizeNearestInterp(graph);
}
} // namespace ir
......
......@@ -62,6 +62,7 @@ class CPUQuantizePass : public FusePassBase {
void QuantizeMultiGru(Graph* graph) const;
void QuantizeFusionLSTM(Graph* graph) const;
void QuantizeSlice(Graph* graph) const;
void QuantizeNearestInterp(Graph* graph) const;
void QuantizeInput(Graph* g, Node* op, Node* input, std::string input_name,
double scale_to_one, bool is_input_unsigned,
......
......@@ -58,7 +58,8 @@ void SetOp(ProgramDesc* prog, const std::string& type, const std::string& name,
op->SetAttr("Scale_in", 1.0f);
op->SetAttr("Scale_out", 1.0f);
op->SetAttr("Scale_weights", std::vector<float>{1.0f});
} else if (type == "pool2d" || type == "transpose2" || type == "reshape2") {
} else if (type == "pool2d" || type == "transpose2" || type == "reshape2" ||
type == "nearest_interp" || type == "nearest_interp_v2") {
op->SetInput("X", {inputs[0]});
op->SetOutput("Out", {outputs[0]});
} else if (type == "slice") {
......@@ -434,6 +435,18 @@ TEST(CpuQuantizePass, sliceBetweenNonQuantizedOp) {
TestImmutableOpBetweenNonQuantizedOp("slice");
}
// nearest_interp and nearest_interp_v2 are run through the same two
// immutable-op test helpers that are used for slice.

// --- nearest_interp ---
TEST(CpuQuantizePass, nearestInterp) {
  TestImmutableOp("nearest_interp");
}

TEST(CpuQuantizePass, nearestInterpBetweenNonQuantizedOp) {
  TestImmutableOpBetweenNonQuantizedOp("nearest_interp");
}

// --- nearest_interp_v2 ---
TEST(CpuQuantizePass, nearestInterpV2) {
  TestImmutableOp("nearest_interp_v2");
}

TEST(CpuQuantizePass, nearestInterpV2BetweenNonQuantizedOp) {
  TestImmutableOpBetweenNonQuantizedOp("nearest_interp_v2");
}
static const std::initializer_list<std::string> variable_names_matmul = {
"a", "b", "c", "d", "e", "f"};
......
......@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/framework/ir/mkldnn/cpu_quantize_placement_pass.h"
#include <unordered_set>
namespace paddle {
......@@ -23,15 +24,34 @@ class Graph;
void CPUQuantizePlacementPass::ApplyImpl(ir::Graph* graph) const {
VLOG(3) << "Marks operators which are to be quantized.";
std::unordered_set<std::string> supported_op_types =
std::unordered_set<std::string>(
{"concat", "conv2d", "depthwise_conv2d", "elementwise_add", "fc",
"matmul", "nearest_interp", "nearest_interp_v2", "pool2d",
"prior_box", "reshape2", "transpose2", "fusion_gru", "fusion_lstm",
"multi_gru", "slice"});
const auto& excluded_ids_list =
Get<std::unordered_set<int>>("quantize_excluded_op_ids");
const auto& op_types_list =
Get<std::unordered_set<std::string>>("quantize_enabled_op_types");
if (!op_types_list.empty()) {
// Verify that all user-specified operators can be quantized.
for (const auto& op : op_types_list) {
PADDLE_ENFORCE_NE(
supported_op_types.count(op), 0,
platform::errors::InvalidArgument(
"Pass attribute quantize_enabled_op_types contains operator %s "
"that is not supported by OneDNN quantization.",
op));
}
supported_op_types = op_types_list;
}
Init(name_scope_, graph);
GraphPatternDetector gpd;
patterns::QuantizePlacement quantize_placement_pattern{gpd.mutable_pattern(),
"quantize_placement"};
quantize_placement_pattern(op_types_list);
quantize_placement_pattern(supported_op_types);
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
Graph* g) {
......@@ -46,16 +66,7 @@ void CPUQuantizePlacementPass::ApplyImpl(ir::Graph* graph) const {
return;
}
if (op->Op()->HasAttr("mkldnn_data_type") ||
op->Op()->HasProtoAttr("mkldnn_data_type")) {
// use_quantizer is no longer used
// assign value for compatibility
if (op->Op()->GetAttrIfExists<bool>("use_quantizer")) {
op->Op()->SetAttr("mkldnn_data_type", std::string("int8"));
}
op->Op()->SetAttr("mkldnn_data_type", std::string("int8"));
op->Op()->SetAttr("use_quantizer", true);
}
op->Op()->SetAttr("mkldnn_data_type", std::string("int8"));
};
gpd(graph, handler);
}
......
......@@ -140,6 +140,32 @@ TEST(QuantizerPlacementPass, default_attr_value) {
DefaultAttrTest(5);
}
void EnabledOpTypesTest(
std::initializer_list<std::string> quantize_enabled_op_types,
std::string missing_op) {
auto prog = BuildProgramDesc();
std::unique_ptr<ir::Graph> graph(new ir::Graph(prog));
auto pass = PassRegistry::Instance().Get("cpu_quantize_placement_pass");
pass->Set("quantize_enabled_op_types",
new std::unordered_set<std::string>(quantize_enabled_op_types));
try {
graph.reset(pass->Apply(graph.release()));
} catch (paddle::platform::EnforceNotMet& err) {
std::string ex_msg = err.what();
std::string expected_msg =
"Pass attribute quantize_enabled_op_types contains operator " +
missing_op + " that is not supported by OneDNN quantization.";
EXPECT_TRUE(ex_msg.find(expected_msg) != std::string::npos);
}
}
// Requesting quantization of an op type outside the supported list must be
// reported by the placement pass.
TEST(QuantizerPlacementPass, unsupported_op_type) {
  // Dropout op is not supported by OneDNN quantization
  const std::string unsupported_op{"dropout"};
  EnabledOpTypesTest({"conv2d", unsupported_op}, unsupported_op);
}
} // namespace ir
} // namespace framework
} // namespace paddle
......
......@@ -124,7 +124,8 @@ void AnalysisPredictor::MkldnnQuantizer::CalculateScalesForOpOutputs(
} else if (op->Type() == "relu") {
is_unsigned = true;
} else if (op->Type() == "transpose2" || op->Type() == "reshape2" ||
op->Type() == "pool2d") {
op->Type() == "pool2d" || op->Type() == "nearest_interp" ||
op->Type() == "nearest_interp_v2") {
auto input_var_name = op->Input("X")[0];
PADDLE_ENFORCE_NE(scales_.find(input_var_name), scales_.end(),
platform::errors::PreconditionNotMet(
......
......@@ -107,6 +107,18 @@ MkldnnQuantizerConfig::MkldnnQuantizerConfig() {
rules_["fusion_lstm"]["ReorderedC0"] = ScaleAlgo::NONE;
rules_["fusion_lstm"]["CheckedCell"] = ScaleAlgo::NONE;
rules_["fusion_lstm"]["Hidden"] = ScaleAlgo::KL;
rules_["nearest_interp"]["X"] = ScaleAlgo::KL;
rules_["nearest_interp"]["OutSize"] = ScaleAlgo::NONE;
rules_["nearest_interp"]["SizeTensor"] = ScaleAlgo::NONE;
rules_["nearest_interp"]["Scale"] = ScaleAlgo::NONE;
rules_["nearest_interp"]["Out"] = ScaleAlgo::NONE;
rules_["nearest_interp_v2"]["X"] = ScaleAlgo::KL;
rules_["nearest_interp_v2"]["OutSize"] = ScaleAlgo::NONE;
rules_["nearest_interp_v2"]["SizeTensor"] = ScaleAlgo::NONE;
rules_["nearest_interp_v2"]["Scale"] = ScaleAlgo::NONE;
rules_["nearest_interp_v2"]["Out"] = ScaleAlgo::NONE;
}
ScaleAlgo MkldnnQuantizerConfig::scale_algo(
......
......@@ -63,7 +63,8 @@ class Quant2Int8MkldnnPass(object):
self._op_ids_to_skip = _op_ids_to_skip if _op_ids_to_skip is not None else set(
[-1])
self._scale_immutable_ops = [
'transpose2', 'reshape2', 'pool2d', 'slice'
'transpose2', 'reshape2', 'pool2d', 'slice', 'nearest_interp',
'nearest_interp_v2'
]
self._scale_ops = ['scale']
self._conv_ops = ['conv2d', 'depthwise_conv2d']
......
......@@ -216,6 +216,141 @@ class TestQuant2Int8MkldnnPassConv2D(unittest.TestCase):
graph = quant2_int8_mkldnn_pass._update_activations(graph)
self.check_graph_after_pass(graph)
class TestQuant2Int8MkldnnPassNearestInterp(unittest.TestCase):
def op_name(self):
return "nearest_interp"
def setUp(self):
self.scope = fluid.Scope()
self.place = fluid.CPUPlace()
self.dtype = np.float32
self.use_cudnn = False
self.use_mkldnn = True
# conv2d
self.data_format = "ANYLAYOUT"
self.pad = [0, 0]
self.stride = [1, 1]
self.dilations = [1, 1]
self.groups = 1
self.input_size = [1, 3, 5, 5]
self.filter_size = [16, 3, 3, 3]
self.conv_output_size = [1, 16, 3, 3]
self.input = np.random.random(self.input_size).astype(self.dtype)
self.filter = np.random.random(self.filter_size).astype(self.dtype)
self.conv_output = np.ndarray(self.conv_output_size).astype(
self.dtype)
# nearest_interp
self.out_h = 1
self.out_w = 1
self.scale = 2.0
self.interp_method = 'nearest'
self.data_layout = 'NCHW'
self.nearest_interp_output_size = [1, 1, 2, 2]
self.nearest_interp_output = np.ndarray(
self.nearest_interp_output_size).astype(self.dtype)
# dropout
self.dropout_prob = 0.5
self.dropout_out = np.ndarray(
self.nearest_interp_output_size).astype(self.dtype)
self.dropout_mask = np.ndarray(self.nearest_interp_output_size)
self.quantized_ops = {
"conv2d", "nearest_interp", "nearest_interp_v2"
}
self.variables = {
"input": self.input,
"filter": self.filter,
"conv_output": self.conv_output,
"nearest_interp_output": self.nearest_interp_output,
"dropout_out": self.dropout_out,
'dropout_mask': self.dropout_mask
}
def prepare_program(self, program):
block = program.global_block()
for name in self.variables:
block.create_var(
name=name,
dtype="float32",
shape=self.variables[name].shape)
block.append_op(
type="conv2d",
inputs={
"Input": block.var('input'),
'Filter': block.var('filter')
},
outputs={"Output": block.var('conv_output')},
attrs={
'strides': self.stride,
'paddings': self.pad,
'groups': self.groups,
'dilations': self.dilations,
'use_cudnn': self.use_cudnn,
'use_mkldnn': self.use_mkldnn,
'data_format': self.data_format,
'fuse_relu': True
})
block.append_op(
type=self.op_name(),
inputs={"X": block.var('conv_output'), },
outputs={"Out": block.var('nearest_interp_output')},
attrs={
'interp_method': self.interp_method,
'out_h': self.out_h,
'out_w': self.out_w,
'scale': self.scale,
'data_layout': self.data_layout,
'use_mkldnn': self.use_mkldnn
})
block.append_op(
type='dropout',
inputs={"X": block.var('nearest_interp_output'), },
outputs={
'Out': block.var('dropout_out'),
'Mask': block.var('dropout_mask')
},
attrs={'dropout_prob': self.dropout_prob, })
def check_graph_after_pass(self, graph):
for op in graph.all_op_nodes():
if op.op().type() in self.quantized_ops:
self.assertTrue(op.op().has_attr("mkldnn_data_type"))
self.assertTrue(op.op().attr("mkldnn_data_type") == "int8")
def test_quant_update_activation(self):
program = fluid.Program()
with fluid.program_guard(program):
self.prepare_program(program)
graph = IrGraph(core.Graph(program.desc), for_test=True)
quant2_int8_mkldnn_pass = Quant2Int8MkldnnPass(
self.quantized_ops,
_scope=self.scope,
_place=self.place,
_core=core,
_debug=False)
input_scale_tensor = quant2_int8_mkldnn_pass._convert_scale2tensor(
np.array(self.scale).astype(np.float64))
output_scale_tensor = quant2_int8_mkldnn_pass._convert_scale2tensor(
np.array(1. / self.scale * self.scale).astype(np.float64))
var_scale = {
"input": (False, input_scale_tensor),
"filter": (False, input_scale_tensor),
"conv_output": (False, output_scale_tensor),
}
if core.avx_supported():
quant2_int8_mkldnn_pass._var_quant_scales = var_scale
graph = quant2_int8_mkldnn_pass._propagate_scales(graph)
graph = quant2_int8_mkldnn_pass._quantize_fp32_graph(graph)
self.check_graph_after_pass(graph)
# Variant of the nearest_interp quantization test for "nearest_interp_v2".
# NOTE(review): this class derives from unittest.TestCase directly and defines
# no test_* method, so it appears to run no test at all. It was presumably
# meant to derive from TestQuant2Int8MkldnnPassNearestInterp and reuse its
# setUp/prepare_program/test methods -- confirm that the attrs set there are
# valid for nearest_interp_v2 before changing the base class.
class TestQuant2Int8MkldnnPassNearestInterpV2(unittest.TestCase):
def op_name(self):
# Operator type this variant is meant to exercise.
return "nearest_interp_v2"
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册