diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/passes/add_dump_tensor_op.cc b/tensorflow/compiler/mlir/quantization/tensorflow/passes/add_dump_tensor_op.cc index 33cde9bf9860a800fdb25d9cd2b64a4e631c65d1..31b93fe272bb245896d59cebaad952fa795f5094 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/passes/add_dump_tensor_op.cc +++ b/tensorflow/compiler/mlir/quantization/tensorflow/passes/add_dump_tensor_op.cc @@ -18,14 +18,17 @@ limitations under the License. #include #include "llvm/ADT/StringRef.h" +#include "llvm/Support/CommandLine.h" #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project #include "mlir/Dialect/Quant/QuantOps.h" // from @llvm-project #include "mlir/IR/Attributes.h" // from @llvm-project #include "mlir/IR/Builders.h" // from @llvm-project #include "mlir/IR/BuiltinAttributes.h" // from @llvm-project +#include "mlir/IR/BuiltinOps.h" // from @llvm-project #include "mlir/IR/BuiltinTypes.h" // from @llvm-project #include "mlir/IR/MLIRContext.h" // from @llvm-project #include "mlir/IR/PatternMatch.h" // from @llvm-project +#include "mlir/IR/SymbolTable.h" // from @llvm-project #include "mlir/IR/Value.h" // from @llvm-project #include "mlir/Pass/Pass.h" // from @llvm-project #include "mlir/Pass/PassRegistry.h" // from @llvm-project @@ -37,6 +40,7 @@ limitations under the License. #include "tensorflow/compiler/mlir/lite/quantization/quantization_utils.h" #include "tensorflow/compiler/mlir/quantization/tensorflow/passes/passes.h" #include "tensorflow/compiler/mlir/quantization/tensorflow/passes/tf_quant_ops.h" +#include "tensorflow/compiler/mlir/quantization/tensorflow/quantization_options.pb.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_dialect.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" #include "tensorflow/core/platform/path.h" @@ -45,19 +49,30 @@ namespace mlir { namespace quant { namespace { +using DebuggerType = tensorflow::quantization::DebuggerOptions::DebuggerType; +using DebuggerOptions = tensorflow::quantization::DebuggerOptions; + constexpr StringRef kCompositeFuncPrefix = "composite_"; // AddDumpTensorOp pass adds DumpTensorOp - which saves entire value of its // input into a file - to quantizable layer's output. 
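For orientation, the pass below is driven from Python through the `DebuggerOptions` message in `QuantizationOptions`; this change adds the per-layer mode. A minimal sketch of enabling it is shown here — the proto module path, the `_ExperimentalMethod` alias, and the SavedModel paths are assumptions mirroring the integration test changed later in this diff, not verified API:

```python
# Sketch only: enable the new per-layer debugger when quantizing a SavedModel.
# Import paths and aliases are assumed to match the integration test in this diff.
import numpy as np

from tensorflow.compiler.mlir.quantization.tensorflow import quantization_options_pb2 as quant_opts_pb2
from tensorflow.compiler.mlir.quantization.tensorflow.python import quantize_model

_ExperimentalMethod = quant_opts_pb2.QuantizationMethod.ExperimentalMethod


def data_gen():
  # Representative dataset used for calibration; shapes are illustrative.
  for _ in range(8):
    yield {'input_tensor': np.random.uniform(size=(1, 3, 4, 3)).astype('f4')}


quantization_options = quant_opts_pb2.QuantizationOptions(
    quantization_method=quant_opts_pb2.QuantizationMethod(
        experimental_method=_ExperimentalMethod.STATIC_RANGE
    ),
    op_set=quant_opts_pb2.XLA,
    debugger_options=quant_opts_pb2.DebuggerOptions(
        # DEBUGGER_TYPE_WHOLE_MODEL keeps the previous behavior;
        # DEBUGGER_TYPE_PER_LAYER is what this change enables.
        debugger_type=quant_opts_pb2.DebuggerOptions.DEBUGGER_TYPE_PER_LAYER,
        log_dir_path='/tmp/dumps',
    ),
    tags={'serve'},  # tag_constants.SERVING
    signature_keys=['serving_default'],
)

quantize_model.quantize(
    '/tmp/input_saved_model',   # placeholder path
    '/tmp/output_saved_model',  # placeholder path
    quantization_options,
    representative_dataset=data_gen(),
)
```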
class AddDumpTensorOpPass - : public PassWrapper<AddDumpTensorOpPass, OperationPass<func::FuncOp>> { + : public PassWrapper<AddDumpTensorOpPass, OperationPass<ModuleOp>> { public: MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(AddDumpTensorOpPass) explicit AddDumpTensorOpPass() = default; - explicit AddDumpTensorOpPass(std::string log_dir_path) - : log_dir_path_(std::move(log_dir_path)) {} + explicit AddDumpTensorOpPass(DebuggerType debugger_type, + std::string log_dir_path) + : log_dir_path_(std::move(log_dir_path)) { + debugger_type_ = debugger_type; + } + + AddDumpTensorOpPass(const AddDumpTensorOpPass &other) { + debugger_type_ = other.debugger_type_; + log_dir_path_ = other.log_dir_path_; + } StringRef getArgument() const final { // This is the argument used to refer to the pass in the textual format (on @@ -79,6 +94,14 @@ class AddDumpTensorOpPass private: void runOnOperation() override; + Option<DebuggerType> debugger_type_{ + *this, "debugger_type", + llvm::cl::init(DebuggerOptions::DEBUGGER_TYPE_UNSPECIFIED), + llvm::cl::values(clEnumValN(DebuggerOptions::DEBUGGER_TYPE_WHOLE_MODEL, + "whole_model", "Whole model verify"), + clEnumValN(DebuggerOptions::DEBUGGER_TYPE_PER_LAYER, + "per_layer", "Per-layer verify"))}; + std::string log_dir_path_ = "/tmp/dumps"; }; @@ -86,10 +109,14 @@ class AddDumpTensorOp : public OpRewritePattern<TF::PartitionedCallOp> { public: // Does not take ownership of context, which must refer to a valid value that // outlives this object. - explicit AddDumpTensorOp(MLIRContext *context, std::string log_dir_path) - : OpRewritePattern<TF::PartitionedCallOp>(context), log_dir_path_(std::move(log_dir_path)) {} + explicit AddDumpTensorOp(MLIRContext *context, DebuggerType debugger_type, + std::string log_dir_path) + : OpRewritePattern<TF::PartitionedCallOp>(context), + debugger_type_(debugger_type), + log_dir_path_(std::move(log_dir_path)) {} private: + DebuggerType debugger_type_; std::string log_dir_path_; LogicalResult matchAndRewrite(TF::PartitionedCallOp call_op, @@ -116,13 +143,22 @@ class AddDumpTensorOp : public OpRewritePattern<TF::PartitionedCallOp> { auto folder_name = tensorflow::io::JoinPath(log_dir_path_, f_attr.getValue()); + // In whole-model mode, file_name is first set to unquantized_tensor_data.pb + // because the op is initially used by the unquantized dump model; once that + // model is saved, the file name is changed to quantized_tensor_data.pb. + // Per-layer mode skips that step, so file_name is set to + // quantized_tensor_data.pb directly. + // TODO: b/296933893 - Refactor the debugger code when no quantize option + // is added + auto file_name = debugger_type_ == DebuggerOptions::DEBUGGER_TYPE_PER_LAYER + ? "quantized_tensor_data.pb" + : "unquantized_tensor_data.pb"; + SmallVector<NamedAttribute> dump_attributes{ rewriter.getNamedAttr("log_dir_path", rewriter.getStringAttr(folder_name)), - // The file_name will be changed from unquantized_tensor_data.pb to - // quantized_tensor_data.pb after the calibration. - rewriter.getNamedAttr( - "file_name", rewriter.getStringAttr("unquantized_tensor_data.pb")), + rewriter.getNamedAttr("file_name", rewriter.getStringAttr(file_name)), // The op is disabled by default. Otherwise, values will be saved // during calibration. rewriter.getNamedAttr("enabled", rewriter.getBoolAttr(false)), @@ -131,6 +167,39 @@ class AddDumpTensorOp : public OpRewritePattern<TF::PartitionedCallOp> { rewriter.create<TF::DumpTensorOp>(call_op->getLoc(), TypeRange{}, result, dump_attributes); + // Per-layer mode. + if (debugger_type_ == DebuggerOptions::DEBUGGER_TYPE_PER_LAYER) { + auto module = call_op->getParentOfType<ModuleOp>(); + SymbolTable symbol_table(module); + + // Copy the composite function of the quantizable layer.
+ const mlir::func::FuncOp ref_func = dyn_cast_or_null<func::FuncOp>( + symbol_table.lookup(f_attr.getValue())); + mlir::func::FuncOp new_ref_func = + dyn_cast<func::FuncOp>(ref_func->clone()); + const StringAttr new_ref_func_name = symbol_table.insert(new_ref_func); + + // Create a PartitionedCallOp that calls the copied composite function. + // This PartitionedCallOp does not have kQuantTraitAttrName, and therefore + // won't get quantized. + auto ref_call_op = rewriter.create<TF::PartitionedCallOp>( + call_op.getLoc(), call_op.getResultTypes(), call_op.getOperands(), + FlatSymbolRefAttr::get(new_ref_func_name)); + + // Attach a DumpTensorOp to the output of the unquantized layer. + SmallVector<NamedAttribute> dump_attributes{ + rewriter.getNamedAttr("log_dir_path", + rewriter.getStringAttr(folder_name)), + rewriter.getNamedAttr("file_name", rewriter.getStringAttr( + "unquantized_tensor_data.pb")), + rewriter.getNamedAttr("enabled", rewriter.getBoolAttr(false)), + }; + + rewriter.create<TF::DumpTensorOp>(call_op->getLoc(), TypeRange{}, + ref_call_op.getResult(0), + dump_attributes); + } + return success(); } }; @@ -140,20 +209,21 @@ static PassRegistration<AddDumpTensorOpPass> pass; void AddDumpTensorOpPass::runOnOperation() { MLIRContext *ctx = &getContext(); RewritePatternSet patterns(ctx); - func::FuncOp func = getOperation(); + ModuleOp module = getOperation(); - patterns.add<AddDumpTensorOp>(ctx, log_dir_path_); - if (failed(applyPatternsAndFoldGreedily(func, std::move(patterns)))) { - func.emitError() << "quant-add-dump-tensor-op failed."; + patterns.add<AddDumpTensorOp>(ctx, debugger_type_, log_dir_path_); + if (failed(applyPatternsAndFoldGreedily(module, std::move(patterns)))) { + module.emitError() << "quant-add-dump-tensor-op failed."; signalPassFailure(); } } } // namespace -std::unique_ptr<OperationPass<func::FuncOp>> CreateAddDumpTensorOpPass( - std::string log_dir_path) { - return std::make_unique<AddDumpTensorOpPass>(std::move(log_dir_path)); +std::unique_ptr<OperationPass<ModuleOp>> CreateAddDumpTensorOpPass( + DebuggerType debugger_type, std::string log_dir_path) { + return std::make_unique<AddDumpTensorOpPass>(debugger_type, + std::move(log_dir_path)); } } // namespace quant diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/passes/passes.h b/tensorflow/compiler/mlir/quantization/tensorflow/passes/passes.h index 910a3abf126f0ba2798981028bba9557f76f78a9..988bd5e60bae616f46d7a28751768557e1e9a6fb 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/passes/passes.h +++ b/tensorflow/compiler/mlir/quantization/tensorflow/passes/passes.h @@ -234,7 +234,8 @@ std::unique_ptr<OperationPass<ModuleOp>> CreateQuantizeWeightsPass( const tensorflow::quantization::QuantizationOptions& quant_options); // Create a pass that inserts dump tensor to quantizable layer's output.
-std::unique_ptr> CreateAddDumpTensorOpPass( +std::unique_ptr> CreateAddDumpTensorOpPass( + tensorflow::quantization::DebuggerOptions::DebuggerType debugger_type, std::string log_dir_path); } // namespace quant diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/python/integration_test/quantize_model_test.py b/tensorflow/compiler/mlir/quantization/tensorflow/python/integration_test/quantize_model_test.py index 9237d72f94129a4f56d80b67932cb439763b3b86..be9fbbbbd3b6db0ca2e907b38069d2a27b739055 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/python/integration_test/quantize_model_test.py +++ b/tensorflow/compiler/mlir/quantization/tensorflow/python/integration_test/quantize_model_test.py @@ -5548,6 +5548,26 @@ class WeightOnlyQuantizationTest(quantize_model_test_base.QuantizedModelTest): class DebuggerTest(quantize_model_test_base.QuantizedModelTest): + def _run_model_in_sess(self, model_dir, tags, signature_key, sample_input): + with tensorflow.compat.v1.Session(graph=tensorflow.Graph()) as sess: + meta_graph = saved_model_loader.load(sess, tags, export_dir=model_dir) + signature_def = meta_graph.signature_def[signature_key] + + # DumpTensorOp only works in graph mode. + # Execute the model using session to run DumpTensorOp. + output_tensor_names = [ + output_tensor_info.name + for output_tensor_info in signature_def.outputs.values() + ] + + feed_dict = {} + for input_key, input_value in sample_input.items(): + input_tensor_name = signature_def.inputs[input_key].name + feed_dict[input_tensor_name] = input_value + + # Obtain the output of the model. + return sess.run(output_tensor_names, feed_dict=feed_dict)[0] + @parameterized.named_parameters( { 'testcase_name': 'none', @@ -5633,42 +5653,137 @@ class DebuggerTest(quantize_model_test_base.QuantizedModelTest): [unquantized_dump_model_path, 'unquantized_tensor_data.pb'], [self._output_saved_model_path, 'quantized_tensor_data.pb'], ]: - with tensorflow.compat.v1.Session(graph=tensorflow.Graph()) as sess: - meta_graph = saved_model_loader.load(sess, tags, export_dir=model_path) - signature_def = meta_graph.signature_def['serving_default'] - - # Verify that DumpTensor exists - self.assertTrue(self._contains_op(meta_graph.graph_def, 'DumpTensor')) - - # DumpTensorOp only works in graph mode. - # Execute the model using session to run DumpTensorOp. - output_tensor_names = [ - output_tensor_info.name - for output_tensor_info in signature_def.outputs.values() - ] + output_value = self._run_model_in_sess( + model_path, tags, 'serving_default', sample_input + ) - feed_dict = {} - for input_key, input_value in sample_input.items(): - input_tensor_name = signature_def.inputs[input_key].name - feed_dict[input_tensor_name] = input_value + # Find the dump file and parse it. + folder = os.path.join(log_dir_path, os.listdir(log_dir_path)[0]) + dump_file_path = os.path.join(log_dir_path, folder, file_name) - # Obtain the output of the model. - output_value = sess.run(output_tensor_names, feed_dict=feed_dict)[0] + dump_file_proto = tensor_pb2.TensorProto.FromString( + open(dump_file_path, 'rb').read() + ) - # Find the dump file and parse it. 
- folder = os.path.join(log_dir_path, os.listdir(log_dir_path)[0]) - dump_file_path = os.path.join(log_dir_path, folder, file_name) + dump_file_numpy = tensorflow.make_ndarray(dump_file_proto) - dump_file_proto = tensor_pb2.TensorProto.FromString( - open(dump_file_path, 'rb').read() - ) + # Since the model only has one conv2d and its output is directly used as + # the output of the model, output of the model and conv2d's dump value + # should be the same. + self.assertAllEqual(output_value, dump_file_numpy) + + @parameterized.named_parameters( + { + 'testcase_name': 'none', + 'activation_fn': None, + 'has_bias': False, + }, + { + 'testcase_name': 'relu', + 'activation_fn': nn_ops.relu, + 'has_bias': False, + }, + { + 'testcase_name': 'with_bias', + 'activation_fn': None, + 'has_bias': True, + }, + { + 'testcase_name': 'with_bias_and_relu', + 'activation_fn': nn_ops.relu, + 'has_bias': True, + }, + ) + def test_conv2d_ptq_model_per_layer_verify(self, activation_fn, has_bias): + input_shape = [None, None, None, 3] + filter_shape = [2, 3, 3, 2] + + model = self._create_conv2d_model( + input_shape, + filter_shape, + activation_fn=activation_fn, + has_bias=has_bias, + ) + saved_model_save.save(model, self._input_saved_model_path) + + def data_gen() -> repr_dataset.RepresentativeDataset: + for _ in range(8): + yield { + 'input_tensor': ops.convert_to_tensor( + np.random.uniform(low=0, high=150, size=(1, 3, 4, 3)).astype( + 'f4' + ) + ), + } + + tags = {tag_constants.SERVING} + + log_dir_path = self.create_tempdir().full_path + + quantization_options = quant_opts_pb2.QuantizationOptions( + quantization_method=quant_opts_pb2.QuantizationMethod( + experimental_method=_ExperimentalMethod.STATIC_RANGE + ), + op_set=quant_opts_pb2.XLA, + debugger_options=quant_opts_pb2.DebuggerOptions( + debugger_type=quant_opts_pb2.DebuggerOptions.DEBUGGER_TYPE_PER_LAYER, + log_dir_path=log_dir_path, + ), + tags=tags, + signature_keys=['serving_default'], + ) - dump_file_numpy = tensorflow.make_ndarray(dump_file_proto) + converted_model = quantize_model.quantize( + self._input_saved_model_path, + self._output_saved_model_path, + quantization_options, + representative_dataset=data_gen(), + ) + + self.assertIsNotNone(converted_model) + self.assertCountEqual( + converted_model.signatures._signatures.keys(), {'serving_default'} + ) + + sample_input = { + 'input_tensor': np.random.uniform(low=0, high=1, size=(16, 3, 4, 3)) + } + + unquantized_output_value = self._run_model_in_sess( + self._input_saved_model_path, tags, 'serving_default', sample_input + ) + quantized_output_value = self._run_model_in_sess( + self._output_saved_model_path, tags, 'serving_default', sample_input + ) + + # Find the both quantized and unquantized dump file. 
+ folder = os.path.join(log_dir_path, os.listdir(log_dir_path)[0]) + unquantized_dump_file_path = os.path.join( + log_dir_path, folder, 'unquantized_tensor_data.pb' + ) + quantized_dump_file_path = os.path.join( + log_dir_path, folder, 'quantized_tensor_data.pb' + ) + + unquantized_dump_file_proto = tensor_pb2.TensorProto.FromString( + open(unquantized_dump_file_path, 'rb').read() + ) + quantized_dump_file_proto = tensor_pb2.TensorProto.FromString( + open(quantized_dump_file_path, 'rb').read() + ) + + unquantized_dump_file_numpy = tensorflow.make_ndarray( + unquantized_dump_file_proto + ) + quantized_dump_file_numpy = tensorflow.make_ndarray( + quantized_dump_file_proto + ) - # Since the model only has one conv2d and its output is directly used as - # the output of the model, output of the model and conv2d's dump value - # should be the same. - self.assertAllEqual(output_value, dump_file_numpy) + # Since the model only has one conv2d and its output is directly used as + # the output of the model, output of the model and conv2d's dump value + # should be the same. + self.assertAllEqual(quantized_output_value, quantized_dump_file_numpy) + self.assertAllEqual(unquantized_output_value, unquantized_dump_file_numpy) @test_util.run_all_in_graph_and_eager_modes diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/python/quantize_model.py b/tensorflow/compiler/mlir/quantization/tensorflow/python/quantize_model.py index ff9beb2b83acaa1b819a882adbd119ce30589298..5dab89a29e740cee3b43b55f3c7f0b7c827b3a30 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/python/quantize_model.py +++ b/tensorflow/compiler/mlir/quantization/tensorflow/python/quantize_model.py @@ -1294,12 +1294,6 @@ def _populate_quantization_options_default_values( _populate_calibration_options(quantization_options) if quantization_options.HasField('debugger_options'): - if ( - quantization_options.debugger_options.debugger_type - == quant_opts_pb2.DebuggerOptions.DebuggerType.DEBUGGER_TYPE_PER_LAYER - ): - raise ValueError('Currently, only whole model debugger is supported.') - if not quantization_options.debugger_options.log_dir_path: quantization_options.debugger_options.log_dir_path = '/tmp/dumps' diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/quantize_passes.cc b/tensorflow/compiler/mlir/quantization/tensorflow/quantize_passes.cc index 11f07ae8234d2225353d0140103e348c55d32f92..4bdf385c4d0d857014d5357972f36f33b2b5857c 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/quantize_passes.cc +++ b/tensorflow/compiler/mlir/quantization/tensorflow/quantize_passes.cc @@ -145,7 +145,8 @@ void AddQuantizePtqPreCalibrationPasses( quantization_options.enable_two_input_tensors())); // TODO: b/295140328 - Add debugger support for weight only if (quantization_options.has_debugger_options()) { - pm.addNestedPass(mlir::quant::CreateAddDumpTensorOpPass( + pm.addPass(mlir::quant::CreateAddDumpTensorOpPass( + quantization_options.debugger_options().debugger_type(), quantization_options.debugger_options().log_dir_path())); } pm.addNestedPass( diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/tests/add_dump_tensor_op.mlir b/tensorflow/compiler/mlir/quantization/tensorflow/tests/add_dump_tensor_op.mlir index 07cc8b6efde5597164f90ed10f541a650e1d76f7..bddfbf19d470ec7757b9407594691d6fea6a2ddb 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/tests/add_dump_tensor_op.mlir +++ b/tensorflow/compiler/mlir/quantization/tensorflow/tests/add_dump_tensor_op.mlir @@ -1,4 +1,5 @@ -// RUN: 
tf-quant-opt %s -split-input-file -quant-add-dump-tensor-op | FileCheck %s +// RUN: tf-quant-opt %s -split-input-file -quant-add-dump-tensor-op='debugger_type=whole_model' | FileCheck --check-prefix=WholeModel %s +// RUN: tf-quant-opt %s -split-input-file -quant-add-dump-tensor-op='debugger_type=per_layer' | FileCheck --check-prefix=PerLayer %s module { func.func @conv(%arg0: tensor<1x2x2x3xf32>) -> (tensor<*xf32>, tensor<*xf32>) { @@ -21,13 +22,23 @@ module { func.return %2 : tensor<*xf32> } -// CHECK-LABEL: func @conv -// CHECK-DAG: %[[w:.*]] = "tf.Const"() {value = dense<{{\[\[\[\[}}1.600000e-01, 1.000000e-01 -// CHECK-DAG: %[[b:.*]] = "tf.Const"() {value = dense<[-2.000000e+00, 3.000000e+00 -// CHECK-DAG: %[[output0:.*]] = "tf.PartitionedCall"(%arg0, %[[w]], %[[b]]) {config = "", config_proto = "", executor_type = "", f = @composite_conv2d_with_bias_and_relu6_fn_2} -// CHECK-DAG: %[[output1:.*]] = "tf.PartitionedCall"(%arg0, %[[w]], %[[b]]) {_tfl_quant_trait = "fully_quantizable", config = "", config_proto = "", executor_type = "", f = @composite_conv2d_with_bias_and_relu6_fn_1} -// CHECK-DAG: "tf.DumpTensor"(%[[output1]]) {enabled = false, file_name = "unquantized_tensor_data.pb", log_dir_path = "/tmp/dumps/composite_conv2d_with_bias_and_relu6_fn_1"} : (tensor<*xf32>) -> () -// CHECK-DAG: return %[[output0]], %[[output1]] +// WholeModel-LABEL: func @conv +// WholeModel-DAG: %[[w:.*]] = "tf.Const"() {value = dense<{{\[\[\[\[}}1.600000e-01, 1.000000e-01 +// WholeModel-DAG: %[[b:.*]] = "tf.Const"() {value = dense<[-2.000000e+00, 3.000000e+00 +// WholeModel-DAG: %[[output0:.*]] = "tf.PartitionedCall"(%arg0, %[[w]], %[[b]]) {config = "", config_proto = "", executor_type = "", f = @composite_conv2d_with_bias_and_relu6_fn_2} +// WholeModel-DAG: %[[output1:.*]] = "tf.PartitionedCall"(%arg0, %[[w]], %[[b]]) {_tfl_quant_trait = "fully_quantizable", config = "", config_proto = "", executor_type = "", f = @composite_conv2d_with_bias_and_relu6_fn_1} +// WholeModel-DAG: "tf.DumpTensor"(%[[output1]]) {enabled = false, file_name = "unquantized_tensor_data.pb", log_dir_path = "/tmp/dumps/composite_conv2d_with_bias_and_relu6_fn_1"} : (tensor<*xf32>) -> () +// WholeModel-DAG: return %[[output0]], %[[output1]] + +// PerLayer-LABEL: func @conv +// PerLayer-DAG: %[[w:.*]] = "tf.Const"() {value = dense<{{\[\[\[\[}}1.600000e-01, 1.000000e-01 +// PerLayer-DAG: %[[b:.*]] = "tf.Const"() {value = dense<[-2.000000e+00, 3.000000e+00 +// PerLayer-DAG: %[[output0:.*]] = "tf.PartitionedCall"(%arg0, %[[w]], %[[b]]) {config = "", config_proto = "", executor_type = "", f = @composite_conv2d_with_bias_and_relu6_fn_2} +// PerLayer-DAG: %[[output1_quantized:.*]] = "tf.PartitionedCall"(%arg0, %[[w]], %[[b]]) {_tfl_quant_trait = "fully_quantizable", config = "", config_proto = "", executor_type = "", f = @composite_conv2d_with_bias_and_relu6_fn_1} +// PerLayer-DAG: %[[output1_unquantized:.*]] = "tf.PartitionedCall"(%arg0, %cst, %cst_0) {config = "", config_proto = "", executor_type = "", f = @composite_conv2d_with_bias_and_relu6_fn_1_0} +// PerLayer-DAG: "tf.DumpTensor"(%[[output1_quantized]]) {enabled = false, file_name = "quantized_tensor_data.pb", log_dir_path = "/tmp/dumps/composite_conv2d_with_bias_and_relu6_fn_1"} : (tensor<*xf32>) -> () +// PerLayer-DAG: "tf.DumpTensor"(%[[output1_unquantized]]) {enabled = false, file_name = "unquantized_tensor_data.pb", log_dir_path = "/tmp/dumps/composite_conv2d_with_bias_and_relu6_fn_1"} : (tensor<*xf32>) -> () +// PerLayer-DAG: return %[[output0]], %[[output1_quantized]] } // 
----- @@ -57,14 +68,29 @@ module { return %2 : tensor } -// CHECK-LABEL: func @multiple_conv2d -// CHECK-DAG: %[[b0:.*]] = "tf.Const"() {value = dense<0.000000e+00> -// CHECK-DAG: %[[b1:.*]] = "tf.Const"() {value = dense<1.000000e+00> -// CHECK-DAG: %[[w0:.*]] = "tf.Const"() {value = dense<{{\[\[\[\[}}0.193340182, 0.285152316 -// CHECK-DAG: %[[w1:.*]] = "tf.Const"() {value = dense<{{\[\[\[\[}}-0.174680978, -0.367524445 -// CHECK-DAG: %[[output0:.*]] = "tf.PartitionedCall"(%arg0, %[[w0]], %[[b0]]) {_tfl_quant_trait = "fully_quantizable", config = "", config_proto = "", executor_type = "", f = @composite_conv2d_with_bias_and_relu6_fn_2} -// CHECK-DAG: "tf.DumpTensor"(%[[output0]]) {enabled = false, file_name = "unquantized_tensor_data.pb", log_dir_path = "/tmp/dumps/composite_conv2d_with_bias_and_relu6_fn_2"} -// CHECK-DAG: %[[output1:.*]] = "tf.PartitionedCall"(%[[output0]], %[[w1]], %[[b1]]) {_tfl_quant_trait = "fully_quantizable", config = "", config_proto = "", executor_type = "", f = @composite_conv2d_with_bias_and_relu6_fn_1} -// CHECK-DAG: "tf.DumpTensor"(%[[output1]]) {enabled = false, file_name = "unquantized_tensor_data.pb", log_dir_path = "/tmp/dumps/composite_conv2d_with_bias_and_relu6_fn_1"} -// CHECK-DAG: return %[[output1]] +// WholeModel-LABEL: func @multiple_conv2d +// WholeModel-DAG: %[[b0:.*]] = "tf.Const"() {value = dense<0.000000e+00> +// WholeModel-DAG: %[[b1:.*]] = "tf.Const"() {value = dense<1.000000e+00> +// WholeModel-DAG: %[[w0:.*]] = "tf.Const"() {value = dense<{{\[\[\[\[}}0.193340182, 0.285152316 +// WholeModel-DAG: %[[w1:.*]] = "tf.Const"() {value = dense<{{\[\[\[\[}}-0.174680978, -0.367524445 +// WholeModel-DAG: %[[output0:.*]] = "tf.PartitionedCall"(%arg0, %[[w0]], %[[b0]]) {_tfl_quant_trait = "fully_quantizable", config = "", config_proto = "", executor_type = "", f = @composite_conv2d_with_bias_and_relu6_fn_2} +// WholeModel-DAG: "tf.DumpTensor"(%[[output0]]) {enabled = false, file_name = "unquantized_tensor_data.pb", log_dir_path = "/tmp/dumps/composite_conv2d_with_bias_and_relu6_fn_2"} +// WholeModel-DAG: %[[output1:.*]] = "tf.PartitionedCall"(%[[output0]], %[[w1]], %[[b1]]) {_tfl_quant_trait = "fully_quantizable", config = "", config_proto = "", executor_type = "", f = @composite_conv2d_with_bias_and_relu6_fn_1} +// WholeModel-DAG: "tf.DumpTensor"(%[[output1]]) {enabled = false, file_name = "unquantized_tensor_data.pb", log_dir_path = "/tmp/dumps/composite_conv2d_with_bias_and_relu6_fn_1"} +// WholeModel-DAG: return %[[output1]] + +// PerLayer-LABEL: func @multiple_conv2d +// PerLayer-DAG: %[[b0:.*]] = "tf.Const"() {value = dense<0.000000e+00> +// PerLayer-DAG: %[[b1:.*]] = "tf.Const"() {value = dense<1.000000e+00> +// PerLayer-DAG: %[[w0:.*]] = "tf.Const"() {value = dense<{{\[\[\[\[}}0.193340182, 0.285152316 +// PerLayer-DAG: %[[w1:.*]] = "tf.Const"() {value = dense<{{\[\[\[\[}}-0.174680978, -0.367524445 +// PerLayer-DAG: %[[output0_quantized:.*]] = "tf.PartitionedCall"(%arg0, %[[w0]], %[[b0]]) {_tfl_quant_trait = "fully_quantizable", config = "", config_proto = "", executor_type = "", f = @composite_conv2d_with_bias_and_relu6_fn_2} +// PerLayer-DAG: %[[output0_unquantized:.*]] = "tf.PartitionedCall"(%arg0, %[[w0]], %[[b0]]) {config = "", config_proto = "", executor_type = "", f = @composite_conv2d_with_bias_and_relu6_fn_2_0} +// PerLayer-DAG: "tf.DumpTensor"(%[[output0_quantized]]) {enabled = false, file_name = "quantized_tensor_data.pb", log_dir_path = "/tmp/dumps/composite_conv2d_with_bias_and_relu6_fn_2"} +// PerLayer-DAG: 
"tf.DumpTensor"(%[[output0_unquantized]]) {enabled = false, file_name = "unquantized_tensor_data.pb", log_dir_path = "/tmp/dumps/composite_conv2d_with_bias_and_relu6_fn_2"} +// PerLayer-DAG: %[[output1_quantized:.*]] = "tf.PartitionedCall"(%[[output0_quantized]], %[[w1]], %[[b1]]) {_tfl_quant_trait = "fully_quantizable", config = "", config_proto = "", executor_type = "", f = @composite_conv2d_with_bias_and_relu6_fn_1} +// PerLayer-DAG: %[[output1_unquantized:.*]] = "tf.PartitionedCall"(%[[output0_quantized]], %[[w1]], %[[b1]]) {config = "", config_proto = "", executor_type = "", f = @composite_conv2d_with_bias_and_relu6_fn_1_0} +// PerLayer-DAG: "tf.DumpTensor"(%[[output1_quantized]]) {enabled = false, file_name = "quantized_tensor_data.pb", log_dir_path = "/tmp/dumps/composite_conv2d_with_bias_and_relu6_fn_1"} +// PerLayer-DAG: "tf.DumpTensor"(%[[output1_unquantized]]) {enabled = false, file_name = "unquantized_tensor_data.pb", log_dir_path = "/tmp/dumps/composite_conv2d_with_bias_and_relu6_fn_1"} +// PerLayer-DAG: return %[[output1_quantized]] } diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/tests/quantize_composite_functions.mlir b/tensorflow/compiler/mlir/quantization/tensorflow/tests/quantize_composite_functions.mlir index 915aa2eb1c5c0afebf905de35f6542d9cc1826b9..0a35859628a7a19047eab675eab23b58899827bf 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/tests/quantize_composite_functions.mlir +++ b/tensorflow/compiler/mlir/quantization/tensorflow/tests/quantize_composite_functions.mlir @@ -199,3 +199,88 @@ module { // CHECK: Number of quantize layers added: 0 // CHECK: Number of dequantize layers added: 0 } + +// ----- + +module { + func.func @conv_with_dump(%arg0: tensor<*xf32>) -> tensor<*xf32> { + %cst = "tf.Const"() {device = "", value = dense<[-0.0192535277, -5.998660e-03]> : tensor<2xf32>} : () -> tensor<2xf32> + %cst_0 = "tf.Const"() {device = "", value = dense<[[[[-0.282878935, -0.211567819], [-0.248810023, -0.0989695191]], [[0.400888503, 0.0803082585], [-0.0671417713, -0.23301053]]], [[[0.345862567, 0.311298311], [-0.595954239, 0.202630222]], [[-0.606417357, -0.257358253], [-0.3036502, -0.35013032]]]]> : tensor<2x2x2x2xf32>} : () -> tensor<2x2x2x2xf32> + %cst_1 = "tf.Const"() {device = "", value = dense<[-0.0291469581, 0.0106381178]> : tensor<2xf32>} : () -> tensor<2xf32> + %cst_2 = "tf.Const"() {device = "", value = dense<[[[[0.208403707, 0.478067577], [0.593097508, -0.305721074]], [[-0.114346057, 0.583530128], [0.211413622, -0.606618404]]], [[[0.314416587, -0.260997623], [-0.375806928, 0.0813755393]], [[-0.208318114, 0.275989294], [-3.469230e-01, -0.406548172]]]]> : tensor<2x2x2x2xf32>} : () -> tensor<2x2x2x2xf32> + %0 = "quantfork.stats"(%arg0) {layerStats = dense<[1.878980e-02, 0.988373816]> : tensor<2xf32>} : (tensor<*xf32>) -> tensor<*xf32> + %1 = "tf.PartitionedCall"(%0, %cst_0, %cst) {_tfl_quant_trait = "fully_quantizable", config = "", config_proto = "", device = "", executor_type = "", f = @composite_conv2d_with_bias_and_relu6_fn_20} : (tensor<*xf32>, tensor<2x2x2x2xf32>, tensor<2xf32>) -> tensor<*xf32> + %2 = "quantfork.stats"(%1) {layerStats = dense<[0.000000e+00, 0.36084348]> : tensor<2xf32>} : (tensor<*xf32>) -> tensor<*xf32> + "tf.DumpTensor"(%2) {device = "", enabled = true, file_name = "quantized_tensor_data.pb", log_dir_path = "/tmp/dumps/composite_conv2d_with_bias_and_relu6_fn_2"} : (tensor<*xf32>) -> () + %3 = "tf.PartitionedCall"(%2, %cst_2, %cst_1) {_tfl_quant_trait = "fully_quantizable", config = "", config_proto = "", device 
= "", executor_type = "", f = @composite_conv2d_with_bias_and_relu6_fn_10} : (tensor<*xf32>, tensor<2x2x2x2xf32>, tensor<2xf32>) -> tensor<*xf32> + %4 = "quantfork.stats"(%3) {layerStats = dense<[0.000000e+00, 0.18486841]> : tensor<2xf32>} : (tensor<*xf32>) -> tensor<*xf32> + %5 = "tf.Identity"(%4) {device = ""} : (tensor<*xf32>) -> tensor<*xf32> + %6 = "quantfork.stats"(%5) {layerStats = dense<[0.000000e+00, 0.18486841]> : tensor<2xf32>} : (tensor<*xf32>) -> tensor<*xf32> + %7 = "tf.PartitionedCall"(%2, %cst_2, %cst_1) {config = "", config_proto = "", device = "", executor_type = "", f = @composite_conv2d_with_bias_and_relu6_fn_1_00} : (tensor<*xf32>, tensor<2x2x2x2xf32>, tensor<2xf32>) -> tensor<*xf32> + %8 = "quantfork.stats"(%7) {layerStats = dense<[0.000000e+00, 0.18486841]> : tensor<2xf32>} : (tensor<*xf32>) -> tensor<*xf32> + %9 = "tf.PartitionedCall"(%0, %cst_0, %cst) {config = "", config_proto = "", device = "", executor_type = "", f = @composite_conv2d_with_bias_and_relu6_fn_2_00} : (tensor<*xf32>, tensor<2x2x2x2xf32>, tensor<2xf32>) -> tensor<*xf32> + %10 = "quantfork.stats"(%9) {layerStats = dense<[0.000000e+00, 0.36084348]> : tensor<2xf32>} : (tensor<*xf32>) -> tensor<*xf32> + "tf.DumpTensor"(%10) {device = "", enabled = true, file_name = "unquantized_tensor_data.pb", log_dir_path = "/tmp/dumps/composite_conv2d_with_bias_and_relu6_fn_2"} : (tensor<*xf32>) -> () + "tf.DumpTensor"(%4) {device = "", enabled = true, file_name = "quantized_tensor_data.pb", log_dir_path = "/tmp/dumps/composite_conv2d_with_bias_and_relu6_fn_1"} : (tensor<*xf32>) -> () + "tf.DumpTensor"(%8) {device = "", enabled = true, file_name = "unquantized_tensor_data.pb", log_dir_path = "/tmp/dumps/composite_conv2d_with_bias_and_relu6_fn_1"} : (tensor<*xf32>) -> () + func.return %6 : tensor<*xf32> + } + + func.func private @composite_conv2d_with_bias_and_relu6_fn_10(%arg0: tensor<*xf32>, %arg1: tensor<*xf32>, %arg2: tensor<2xf32>) -> tensor<*xf32> attributes {tf_quant.composite_function} { + %0 = "tf.Conv2D"(%arg0, %arg1) {attr_map = "0:strides,1:use_cudnn_on_gpu,2:padding,3:explicit_paddings,4:dilations", data_format = "NHWC", device = "", dilations = [1, 1, 1, 1], explicit_paddings = [], padding = "VALID", strides = [1, 1, 2, 1], use_cudnn_on_gpu = true} : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32> + %1 = "tf.BiasAdd"(%0, %arg2) {data_format = "NHWC", device = ""} : (tensor<*xf32>, tensor<2xf32>) -> tensor<*xf32> + %2 = "tf.Relu6"(%1) : (tensor<*xf32>) -> tensor<*xf32> + func.return %2 : tensor<*xf32> + } + func.func private @composite_conv2d_with_bias_and_relu6_fn_1_00(%arg0: tensor<*xf32>, %arg1: tensor<*xf32>, %arg2: tensor<2xf32>) -> tensor<*xf32> attributes {tf_quant.composite_function} { + %0 = "tf.Conv2D"(%arg0, %arg1) {attr_map = "0:strides,1:use_cudnn_on_gpu,2:padding,3:explicit_paddings,4:dilations", data_format = "NHWC", device = "", dilations = [1, 1, 1, 1], explicit_paddings = [], padding = "VALID", strides = [1, 1, 2, 1], use_cudnn_on_gpu = true} : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32> + %1 = "tf.BiasAdd"(%0, %arg2) {data_format = "NHWC", device = ""} : (tensor<*xf32>, tensor<2xf32>) -> tensor<*xf32> + %2 = "tf.Relu6"(%1) : (tensor<*xf32>) -> tensor<*xf32> + func.return %2 : tensor<*xf32> + } + func.func private @composite_conv2d_with_bias_and_relu6_fn_20(%arg0: tensor<*xf32>, %arg1: tensor<*xf32>, %arg2: tensor<2xf32>) -> tensor<*xf32> attributes {tf_quant.composite_function} { + %0 = "tf.Conv2D"(%arg0, %arg1) {attr_map = 
"0:strides,1:use_cudnn_on_gpu,2:padding,3:explicit_paddings,4:dilations", data_format = "NHWC", device = "", dilations = [1, 1, 1, 1], explicit_paddings = [], padding = "VALID", strides = [1, 1, 2, 1], use_cudnn_on_gpu = true} : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32> + %1 = "tf.BiasAdd"(%0, %arg2) {data_format = "NHWC", device = ""} : (tensor<*xf32>, tensor<2xf32>) -> tensor<*xf32> + %2 = "tf.Relu6"(%1) : (tensor<*xf32>) -> tensor<*xf32> + func.return %2 : tensor<*xf32> + } + func.func private @composite_conv2d_with_bias_and_relu6_fn_2_00(%arg0: tensor<*xf32>, %arg1: tensor<*xf32>, %arg2: tensor<2xf32>) -> tensor<*xf32> attributes {tf_quant.composite_function} { + %0 = "tf.Conv2D"(%arg0, %arg1) {attr_map = "0:strides,1:use_cudnn_on_gpu,2:padding,3:explicit_paddings,4:dilations", data_format = "NHWC", device = "", dilations = [1, 1, 1, 1], explicit_paddings = [], padding = "VALID", strides = [1, 1, 2, 1], use_cudnn_on_gpu = true} : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32> + %1 = "tf.BiasAdd"(%0, %arg2) {data_format = "NHWC", device = ""} : (tensor<*xf32>, tensor<2xf32>) -> tensor<*xf32> + %2 = "tf.Relu6"(%1) : (tensor<*xf32>) -> tensor<*xf32> + func.return %2 : tensor<*xf32> + } + +// CHECK-LABE: @conv_with_dump +// CHECK-DAG: %[[w0_float:.*]] = "tf.Const"() {value = dense<{{\[\[\[\[}}-0.282878935, -0.211567819 +// CHECK-DAG: %[[b0_float:.*]] = "tf.Const"() {value = dense<[-0.0192535277, -5.998660e-03]> : tensor<2xf32>} : () -> tensor<2xf32> +// CHECK-DAG: %[[w1_float:.*]] = "tf.Const"() {value = dense<{{\[\[\[\[}}0.208403707, 0.478067577 +// CHECK-DAG: %[[b1_float:.*]] = "tf.Const"() {value = dense<[-0.0291469581, 0.0106381178]> : tensor<2xf32>} : () -> tensor<2xf32> +// CHECK-DAG: %[[w0_quantized:.*]] = "tf.Const"() {value = dense<{{\[\[\[\[}}-59, -44 +// CHECK-DAG: %[[b0_quantized:.*]] = "tf.Const"() {value = dense<[-1040, -324]> : tensor<2xi32>} : () -> tensor<2xi32> +// CHECK-DAG: %[[w1_quantized:.*]] = "tf.Const"() {value = dense<{{\[\[\[\[}}44, 100 +// CHECK-DAG: %[[b1_quantized:.*]] = "tf.Const"() {value = dense<[-4312, 1574]> : tensor<2xi32>} : () -> tensor<2xi32> +// CHECK-DAG: %[[in_scale:.*]] = "tf.Const"() {value = dense<0.00387597573> : tensor} : () -> tensor +// CHECK-DAG: %[[in_out_zp:.*]] = "tf.Const"() {value = dense<-128> : tensor} : () -> tensor +// CHECK-DAG: %[[w0_scale:.*]] = "tf.Const"() {value = dense<0.00477493973> : tensor} : () -> tensor +// CHECK-DAG: %[[w_b_zp:.*]] = "tf.Const"() {value = dense<0> : tensor} : () -> tensor +// CHECK-DAG: %[[b0_scale:.*]] = "tf.Const"() {value = dense<1.85075514E-5> : tensor} : () -> tensor +// CHECK-DAG: %[[mid_scale:.*]] = "tf.Const"() {value = dense<0.00141507247> : tensor} : () -> tensor +// CHECK-DAG: %[[w1_scale:.*]] = "tf.Const"() {value = dense<0.00477652298> : tensor} : () -> tensor +// CHECK-DAG: %[[b1_scale:.*]] = "tf.Const"() {value = dense<6.75912588E-6> : tensor} : () -> tensor +// CHECK-DAG: %[[out_scale:.*]] = "tf.Const"() {value = dense<7.24974147E-4> : tensor} : () -> tensor +// CHECK-DAG: %[[arg_quantized:.*]] = "tf.PartitionedCall"(%arg0, %[[in_scale]], %[[in_out_zp]]) {config = "", config_proto = "", executor_type = "", f = @quantize_i8} +// CHECK-DAG: %[[conv0_quantized:.*]] = "tf.PartitionedCall"(%[[arg_quantized]], %[[w0_quantized]], %[[b0_quantized]], %[[in_scale]], %[[in_out_zp]], %[[w0_scale]], %[[w_b_zp]], %[[b0_scale]], %[[w_b_zp]], %[[mid_scale]], %[[in_out_zp]]) {config = "", config_proto = "", executor_type = "", f = @quantized_conv2d_with_bias_and_relu6_fn_1} +// CHECK-DAG: 
%[[conv0_dequantized:.*]] = "tf.PartitionedCall"(%[[conv0_quantized]], %[[mid_scale]], %[[in_out_zp]]) {config = "", config_proto = "", executor_type = "", f = @dequantize_i8} +// CHECK-DAG: %[[conv1_quantized:.*]] = "tf.PartitionedCall"(%[[conv0_quantized]], %[[w1_quantized]], %[[b1_quantized]], %[[mid_scale]], %[[in_out_zp]], %[[w1_scale]], %[[w_b_zp]], %[[b1_scale]], %[[w_b_zp]], %[[out_scale]], %[[in_out_zp]]) {config = "", config_proto = "", executor_type = "", f = @quantized_conv2d_with_bias_and_relu6_fn_0} +// CHECK-DAG: %[[conv1_dequantized_0:.*]] = "tf.PartitionedCall"(%[[conv1_quantized]], %[[out_scale]], %[[in_out_zp]]) {config = "", config_proto = "", executor_type = "", f = @dequantize_i8} +// CHECK-DAG: %[[conv1_dequantized_1:.*]] = "tf.PartitionedCall"(%[[conv1_quantized]], %[[out_scale]], %[[in_out_zp]]) {config = "", config_proto = "", executor_type = "", f = @dequantize_i8} +// CHECK-DAG: %[[identity:.*]] = "tf.Identity"(%[[conv1_dequantized_1]]) +// CHECK-DAG: %[[conv0_float:.*]] = "tf.PartitionedCall"(%arg0, %[[w0_float]], %[[b0_float]]) {config = "", config_proto = "", device = "", executor_type = "", f = @composite_conv2d_with_bias_and_relu6_fn_2_00} +// CHECK-DAG: %[[conv1_float:.*]] = "tf.PartitionedCall"(%[[conv0_dequantized]], %[[w1_float]], %[[b1_float]]) {config = "", config_proto = "", device = "", executor_type = "", f = @composite_conv2d_with_bias_and_relu6_fn_1_00} +// CHECK-DAG: "tf.DumpTensor"(%[[conv0_dequantized]]) {device = "", enabled = true, file_name = "quantized_tensor_data.pb", log_dir_path = "/tmp/dumps/composite_conv2d_with_bias_and_relu6_fn_2"} +// CHECK-DAG: "tf.DumpTensor"(%[[conv0_float]]) {device = "", enabled = true, file_name = "unquantized_tensor_data.pb", log_dir_path = "/tmp/dumps/composite_conv2d_with_bias_and_relu6_fn_2"} +// CHECK-DAG: "tf.DumpTensor"(%[[conv1_dequantized_0]]) {device = "", enabled = true, file_name = "quantized_tensor_data.pb", log_dir_path = "/tmp/dumps/composite_conv2d_with_bias_and_relu6_fn_1"} +// CHECK-DAG: "tf.DumpTensor"(%[[conv1_float]]) {device = "", enabled = true, file_name = "unquantized_tensor_data.pb", log_dir_path = "/tmp/dumps/composite_conv2d_with_bias_and_relu6_fn_1"} +// CHECK-DAG: return %[[identity]] +} diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/tests/quantize_composite_functions_xla.mlir b/tensorflow/compiler/mlir/quantization/tensorflow/tests/quantize_composite_functions_xla.mlir index 03c030c1adf632343b0dced52574c1e043a74cb2..2a6ea22ddcc1d6c335c2eb54d5981d711128b3c8 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/tests/quantize_composite_functions_xla.mlir +++ b/tensorflow/compiler/mlir/quantization/tensorflow/tests/quantize_composite_functions_xla.mlir @@ -245,3 +245,86 @@ module attributes {tf.versions = {bad_consumers = [], min_consumer = 12 : i32, p // CHECK: Number of quantize layers added: 1 // CHECK: Number of dequantize layers added: 0 } + +// ----- + +module { + func.func @conv_with_dump(%arg0: tensor<*xf32>) -> tensor<*xf32> { + %cst = "tf.Const"() {device = "", value = dense<[-0.0192535277, -5.998660e-03]> : tensor<2xf32>} : () -> tensor<2xf32> + %cst_0 = "tf.Const"() {device = "", value = dense<[[[[-0.282878935, -0.211567819], [-0.248810023, -0.0989695191]], [[0.400888503, 0.0803082585], [-0.0671417713, -0.23301053]]], [[[0.345862567, 0.311298311], [-0.595954239, 0.202630222]], [[-0.606417357, -0.257358253], [-0.3036502, -0.35013032]]]]> : tensor<2x2x2x2xf32>} : () -> tensor<2x2x2x2xf32> + %cst_1 = "tf.Const"() {device = "", value = 
dense<[-0.0291469581, 0.0106381178]> : tensor<2xf32>} : () -> tensor<2xf32> + %cst_2 = "tf.Const"() {device = "", value = dense<[[[[0.208403707, 0.478067577], [0.593097508, -0.305721074]], [[-0.114346057, 0.583530128], [0.211413622, -0.606618404]]], [[[0.314416587, -0.260997623], [-0.375806928, 0.0813755393]], [[-0.208318114, 0.275989294], [-3.469230e-01, -0.406548172]]]]> : tensor<2x2x2x2xf32>} : () -> tensor<2x2x2x2xf32> + %0 = "quantfork.stats"(%arg0) {layerStats = dense<[1.878980e-02, 0.988373816]> : tensor<2xf32>} : (tensor<*xf32>) -> tensor<*xf32> + %1 = "tf.PartitionedCall"(%0, %cst_0, %cst) {_tfl_quant_trait = "fully_quantizable", config = "", config_proto = "", device = "", executor_type = "", f = @composite_conv2d_with_bias_and_relu6_fn_20} : (tensor<*xf32>, tensor<2x2x2x2xf32>, tensor<2xf32>) -> tensor<*xf32> + %2 = "quantfork.stats"(%1) {layerStats = dense<[0.000000e+00, 0.36084348]> : tensor<2xf32>} : (tensor<*xf32>) -> tensor<*xf32> + "tf.DumpTensor"(%2) {device = "", enabled = true, file_name = "quantized_tensor_data.pb", log_dir_path = "/tmp/dumps/composite_conv2d_with_bias_and_relu6_fn_2"} : (tensor<*xf32>) -> () + %3 = "tf.PartitionedCall"(%2, %cst_2, %cst_1) {_tfl_quant_trait = "fully_quantizable", config = "", config_proto = "", device = "", executor_type = "", f = @composite_conv2d_with_bias_and_relu6_fn_10} : (tensor<*xf32>, tensor<2x2x2x2xf32>, tensor<2xf32>) -> tensor<*xf32> + %4 = "quantfork.stats"(%3) {layerStats = dense<[0.000000e+00, 0.18486841]> : tensor<2xf32>} : (tensor<*xf32>) -> tensor<*xf32> + %5 = "tf.Identity"(%4) {device = ""} : (tensor<*xf32>) -> tensor<*xf32> + %6 = "quantfork.stats"(%5) {layerStats = dense<[0.000000e+00, 0.18486841]> : tensor<2xf32>} : (tensor<*xf32>) -> tensor<*xf32> + %7 = "tf.PartitionedCall"(%2, %cst_2, %cst_1) {config = "", config_proto = "", device = "", executor_type = "", f = @composite_conv2d_with_bias_and_relu6_fn_1_00} : (tensor<*xf32>, tensor<2x2x2x2xf32>, tensor<2xf32>) -> tensor<*xf32> + %8 = "quantfork.stats"(%7) {layerStats = dense<[0.000000e+00, 0.18486841]> : tensor<2xf32>} : (tensor<*xf32>) -> tensor<*xf32> + %9 = "tf.PartitionedCall"(%0, %cst_0, %cst) {config = "", config_proto = "", device = "", executor_type = "", f = @composite_conv2d_with_bias_and_relu6_fn_2_00} : (tensor<*xf32>, tensor<2x2x2x2xf32>, tensor<2xf32>) -> tensor<*xf32> + %10 = "quantfork.stats"(%9) {layerStats = dense<[0.000000e+00, 0.36084348]> : tensor<2xf32>} : (tensor<*xf32>) -> tensor<*xf32> + "tf.DumpTensor"(%10) {device = "", enabled = true, file_name = "unquantized_tensor_data.pb", log_dir_path = "/tmp/dumps/composite_conv2d_with_bias_and_relu6_fn_2"} : (tensor<*xf32>) -> () + "tf.DumpTensor"(%4) {device = "", enabled = true, file_name = "quantized_tensor_data.pb", log_dir_path = "/tmp/dumps/composite_conv2d_with_bias_and_relu6_fn_1"} : (tensor<*xf32>) -> () + "tf.DumpTensor"(%8) {device = "", enabled = true, file_name = "unquantized_tensor_data.pb", log_dir_path = "/tmp/dumps/composite_conv2d_with_bias_and_relu6_fn_1"} : (tensor<*xf32>) -> () + func.return %6 : tensor<*xf32> + } + + func.func private @composite_conv2d_with_bias_and_relu6_fn_10(%arg0: tensor<*xf32>, %arg1: tensor<*xf32>, %arg2: tensor<2xf32>) -> tensor<*xf32> attributes {tf_quant.composite_function} { + %0 = "tf.Conv2D"(%arg0, %arg1) {attr_map = "0:strides,1:use_cudnn_on_gpu,2:padding,3:explicit_paddings,4:dilations", data_format = "NHWC", device = "", dilations = [1, 1, 1, 1], explicit_paddings = [], padding = "VALID", strides = [1, 1, 2, 1], use_cudnn_on_gpu = true} : 
(tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32> + %1 = "tf.BiasAdd"(%0, %arg2) {data_format = "NHWC", device = ""} : (tensor<*xf32>, tensor<2xf32>) -> tensor<*xf32> + %2 = "tf.Relu6"(%1) : (tensor<*xf32>) -> tensor<*xf32> + func.return %2 : tensor<*xf32> + } + func.func private @composite_conv2d_with_bias_and_relu6_fn_1_00(%arg0: tensor<*xf32>, %arg1: tensor<*xf32>, %arg2: tensor<2xf32>) -> tensor<*xf32> attributes {tf_quant.composite_function} { + %0 = "tf.Conv2D"(%arg0, %arg1) {attr_map = "0:strides,1:use_cudnn_on_gpu,2:padding,3:explicit_paddings,4:dilations", data_format = "NHWC", device = "", dilations = [1, 1, 1, 1], explicit_paddings = [], padding = "VALID", strides = [1, 1, 2, 1], use_cudnn_on_gpu = true} : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32> + %1 = "tf.BiasAdd"(%0, %arg2) {data_format = "NHWC", device = ""} : (tensor<*xf32>, tensor<2xf32>) -> tensor<*xf32> + %2 = "tf.Relu6"(%1) : (tensor<*xf32>) -> tensor<*xf32> + func.return %2 : tensor<*xf32> + } + func.func private @composite_conv2d_with_bias_and_relu6_fn_20(%arg0: tensor<*xf32>, %arg1: tensor<*xf32>, %arg2: tensor<2xf32>) -> tensor<*xf32> attributes {tf_quant.composite_function} { + %0 = "tf.Conv2D"(%arg0, %arg1) {attr_map = "0:strides,1:use_cudnn_on_gpu,2:padding,3:explicit_paddings,4:dilations", data_format = "NHWC", device = "", dilations = [1, 1, 1, 1], explicit_paddings = [], padding = "VALID", strides = [1, 1, 2, 1], use_cudnn_on_gpu = true} : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32> + %1 = "tf.BiasAdd"(%0, %arg2) {data_format = "NHWC", device = ""} : (tensor<*xf32>, tensor<2xf32>) -> tensor<*xf32> + %2 = "tf.Relu6"(%1) : (tensor<*xf32>) -> tensor<*xf32> + func.return %2 : tensor<*xf32> + } + func.func private @composite_conv2d_with_bias_and_relu6_fn_2_00(%arg0: tensor<*xf32>, %arg1: tensor<*xf32>, %arg2: tensor<2xf32>) -> tensor<*xf32> attributes {tf_quant.composite_function} { + %0 = "tf.Conv2D"(%arg0, %arg1) {attr_map = "0:strides,1:use_cudnn_on_gpu,2:padding,3:explicit_paddings,4:dilations", data_format = "NHWC", device = "", dilations = [1, 1, 1, 1], explicit_paddings = [], padding = "VALID", strides = [1, 1, 2, 1], use_cudnn_on_gpu = true} : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32> + %1 = "tf.BiasAdd"(%0, %arg2) {data_format = "NHWC", device = ""} : (tensor<*xf32>, tensor<2xf32>) -> tensor<*xf32> + %2 = "tf.Relu6"(%1) : (tensor<*xf32>) -> tensor<*xf32> + func.return %2 : tensor<*xf32> + } + +// CHECK-LABE: @conv_with_dump +// CHECK-DAG: %[[w0_float:.*]] = "tf.Const"() {value = dense<{{\[\[\[\[}}-0.282878935, -0.211567819 +// CHECK-DAG: %[[b0_float:.*]] = "tf.Const"() {value = dense<[-0.0192535277, -5.998660e-03]> : tensor<2xf32>} : () -> tensor<2xf32> +// CHECK-DAG: %[[w1_float:.*]] = "tf.Const"() {value = dense<{{\[\[\[\[}}0.208403707, 0.478067577 +// CHECK-DAG: %[[b1_float:.*]] = "tf.Const"() {value = dense<[-0.0291469581, 0.0106381178]> : tensor<2xf32>} : () -> tensor<2xf32> +// CHECK-DAG: %[[w0_quantized:.*]] = "tf.Const"() {value = dense<{{\[\[\[\[}}-59, -44 +// CHECK-DAG: %[[b0_quantized:.*]] = "tf.Const"() {value = dense<[-1040, -324]> : tensor<2xi32>} : () -> tensor<2xi32> +// CHECK-DAG: %[[w1_quantized:.*]] = "tf.Const"() {value = dense<{{\[\[\[\[}}44, 100 +// CHECK-DAG: %[[b1_quantized:.*]] = "tf.Const"() {value = dense<[-4312, 1574]> : tensor<2xi32>} : () -> tensor<2xi32> +// CHECK-DAG: %[[in_scale:.*]] = "tf.Const"() {value = dense<0.00387597573> : tensor} : () -> tensor +// CHECK-DAG: %[[in_out_zp:.*]] = "tf.Const"() {value = dense<-128> : tensor} : () -> tensor +// CHECK-DAG: 
%[[w0_scale:.*]] = "tf.Const"() {value = dense<0.00477493973> : tensor} : () -> tensor +// CHECK-DAG: %[[w_b_zp:.*]] = "tf.Const"() {value = dense<0> : tensor} : () -> tensor +// CHECK-DAG: %[[b0_scale:.*]] = "tf.Const"() {value = dense<1.85075514E-5> : tensor} : () -> tensor +// CHECK-DAG: %[[mid_scale:.*]] = "tf.Const"() {value = dense<0.00141507247> : tensor} : () -> tensor +// CHECK-DAG: %[[w1_scale:.*]] = "tf.Const"() {value = dense<0.00477652298> : tensor} : () -> tensor +// CHECK-DAG: %[[b1_scale:.*]] = "tf.Const"() {value = dense<6.75912588E-6> : tensor} : () -> tensor +// CHECK-DAG: %[[out_scale:.*]] = "tf.Const"() {value = dense<7.24974147E-4> : tensor} : () -> tensor +// CHECK-DAG: %[[quantized:.*]] = "tf.PartitionedCall"(%arg0, %[[in_scale]], %[[in_out_zp]]) {config = "", config_proto = "", executor_type = "", f = @quantize_i8} +// CHECK-DAG: %[[conv0_dequantized:.*]] = "tf.PartitionedCall"(%[[quantized]], %[[w0_quantized]], %[[b0_quantized]], %[[in_scale]], %[[in_out_zp]], %[[w0_scale]], %[[w_b_zp]], %[[b0_scale]], %[[w_b_zp]], %[[mid_scale]], %[[in_out_zp]]) {config = "", config_proto = "", executor_type = "", f = @quantized_conv2d_with_bias_and_relu6_float_output_fn_1} +// CHECK-DAG: %[[conv0_quantized:.*]] = "tf.PartitionedCall"(%[[quantized]], %[[w0_quantized]], %[[b0_quantized]], %[[in_scale]], %[[in_out_zp]], %[[w0_scale]], %[[w_b_zp]], %[[b0_scale]], %[[w_b_zp]], %[[mid_scale]], %[[in_out_zp]]) {config = "", config_proto = "", executor_type = "", f = @quantized_conv2d_with_bias_and_relu6_fn_1} +// CHECK-DAG: %[[conv1_dequantized:.*]] = "tf.PartitionedCall"(%[[conv0_quantized]], %[[w1_quantized]], %[[b1_quantized]], %[[mid_scale]], %[[in_out_zp]], %[[w1_scale]], %[[w_b_zp]], %[[b1_scale]], %[[w_b_zp]], %[[out_scale]], %[[in_out_zp]]) {config = "", config_proto = "", executor_type = "", f = @quantized_conv2d_with_bias_and_relu6_float_output_fn_0} +// CHECK-DAG: %[[identity:.*]] = "tf.Identity"(%[[conv1_dequantized]]) +// CHECK-DAG: %[[conv0_float:.*]] = "tf.PartitionedCall"(%arg0, %[[w0_float]], %[[b0_float]]) {config = "", config_proto = "", device = "", executor_type = "", f = @composite_conv2d_with_bias_and_relu6_fn_2_00} +// CHECK-DAG: %[[conv1_float:.*]] = "tf.PartitionedCall"(%[[conv0_dequantized]], %[[w1_float]], %[[b1_float]]) {config = "", config_proto = "", device = "", executor_type = "", f = @composite_conv2d_with_bias_and_relu6_fn_1_00} +// CHECK-DAG: "tf.DumpTensor"(%[[conv0_dequantized]]) {device = "", enabled = true, file_name = "quantized_tensor_data.pb", log_dir_path = "/tmp/dumps/composite_conv2d_with_bias_and_relu6_fn_2"} +// CHECK-DAG: "tf.DumpTensor"(%[[conv0_float]]) {device = "", enabled = true, file_name = "unquantized_tensor_data.pb", log_dir_path = "/tmp/dumps/composite_conv2d_with_bias_and_relu6_fn_2"} +// CHECK-DAG: "tf.DumpTensor"(%[[conv1_dequantized]]) {device = "", enabled = true, file_name = "quantized_tensor_data.pb", log_dir_path = "/tmp/dumps/composite_conv2d_with_bias_and_relu6_fn_1"} +// CHECK-DAG: "tf.DumpTensor"(%[[conv1_float]]) {device = "", enabled = true, file_name = "unquantized_tensor_data.pb", log_dir_path = "/tmp/dumps/composite_conv2d_with_bias_and_relu6_fn_1"} +// CHECK-DAG: return %[[identity]] +}
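With per-layer verification, calibration leaves both dump files side by side under `log_dir_path/<composite function name>`, as the tests above parse out. A short sketch of comparing them offline follows; the layer directory name is illustrative (taken from the FileCheck output above), and the parsing mirrors the Python test code in this diff:

```python
# Sketch: compare the per-layer dumps written during calibration. The folder
# name matches the composite function; the file names come from the pass above.
import os

import numpy as np
import tensorflow as tf
from tensorflow.core.framework import tensor_pb2

# Illustrative path; one folder is created per composite function.
layer_dir = '/tmp/dumps/composite_conv2d_with_bias_and_relu6_fn_1'


def read_dump(file_name):
  with open(os.path.join(layer_dir, file_name), 'rb') as f:
    return tf.make_ndarray(tensor_pb2.TensorProto.FromString(f.read()))


unquantized = read_dump('unquantized_tensor_data.pb')
quantized = read_dump('quantized_tensor_data.pb')

# Per-layer quantization error for this composite function.
print('max abs error:', np.max(np.abs(quantized - unquantized)))
```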