diff --git a/python/paddle/fluid/contrib/slim/tests/CMakeLists.txt b/python/paddle/fluid/contrib/slim/tests/CMakeLists.txt
index f401d64b73b89327afa1a4b81de78118e1f7ce3b..095489bc736e480062ac5c0e1ea9fbd1bc687507 100644
--- a/python/paddle/fluid/contrib/slim/tests/CMakeLists.txt
+++ b/python/paddle/fluid/contrib/slim/tests/CMakeLists.txt
@@ -113,6 +113,13 @@ function(save_qat_nlp_model_test target qat_model_dir fp32_model_save_path int8_
             --int8_model_save_path ${int8_model_save_path})
 endfunction()
 
+function(convert_model2dot_test target model_path save_graph_dir save_graph_name)
+    py_test(${target} SRCS ${CMAKE_CURRENT_SOURCE_DIR}/convert_model2dot.py
+            ARGS --model_path ${model_path}
+            --save_graph_dir ${save_graph_dir}
+            --save_graph_name ${save_graph_name})
+endfunction()
+
 if(WIN32)
     list(REMOVE_ITEM TEST_OPS test_light_nas)
     list(REMOVE_ITEM TEST_OPS test_post_training_quantization_mobilenetv1)
@@ -277,6 +284,10 @@ if(LINUX AND WITH_MKLDNN)
     set(QAT2_FP32_ERNIE_SAVE_PATH "${QAT_INSTALL_DIR}/Ernie_qat2_fp32")
     save_qat_nlp_model_test(save_qat2_model_ernie ${QAT2_ERNIE_MODEL_DIR}/Ernie_qat/float ${QAT2_FP32_ERNIE_SAVE_PATH} ${QAT2_INT8_ERNIE_SAVE_PATH})
 
+    # Convert QAT2 model to dot and pdf files
+    set(QAT2_INT8_ERNIE_DOT_SAVE_PATH "${QAT_INSTALL_DIR}/Ernie_qat2_int8_dot_file")
+    convert_model2dot_test(convert_model2dot_ernie ${QAT2_ERNIE_MODEL_DIR}/Ernie_qat/float ${QAT2_INT8_ERNIE_DOT_SAVE_PATH} "Ernie_qat2_int8")
+
 endif()
 
 # Since the tests for QAT FP32 & INT8 comparison support only testing on Linux
diff --git a/python/paddle/fluid/contrib/slim/tests/convert_model2dot.py b/python/paddle/fluid/contrib/slim/tests/convert_model2dot.py
new file mode 100644
index 0000000000000000000000000000000000000000..877897c0a0e7282546727d56b54c0af506e18bc0
--- /dev/null
+++ b/python/paddle/fluid/contrib/slim/tests/convert_model2dot.py
@@ -0,0 +1,72 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+import os
+import sys
+import argparse
+import paddle.fluid as fluid
+from paddle.fluid.framework import IrGraph
+from paddle.fluid import core
+
+
+def parse_args():
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        '--model_path', type=str, default='', help='A path to a model.')
+    parser.add_argument(
+        '--save_graph_dir',
+        type=str,
+        default='',
+        help='A path to save the graph.')
+    parser.add_argument(
+        '--save_graph_name',
+        type=str,
+        default='',
+        help='A name for the saved graph. Default - the output directory name is used'
+    )
+
+    test_args, args = parser.parse_known_args(namespace=unittest)
+    return test_args, sys.argv[:1] + args
+
+
+def generate_dot_for_model(model_path, save_graph_dir, save_graph_name):
+    place = fluid.CPUPlace()
+    exe = fluid.Executor(place)
+    inference_scope = fluid.executor.global_scope()
+    with fluid.scope_guard(inference_scope):
+        if os.path.exists(os.path.join(model_path, '__model__')):
+            [inference_program, feed_target_names,
+             fetch_targets] = fluid.io.load_inference_model(model_path, exe)
+        else:
+            [inference_program, feed_target_names,
+             fetch_targets] = fluid.io.load_inference_model(model_path, exe,
+                                                            'model', 'params')
+        graph = IrGraph(core.Graph(inference_program.desc), for_test=True)
+        if not os.path.exists(save_graph_dir):
+            os.makedirs(save_graph_dir)
+        model_name = os.path.basename(os.path.normpath(save_graph_dir))
+        if save_graph_name == '':
+            save_graph_name = model_name
+        graph.draw(save_graph_dir, save_graph_name, graph.all_op_nodes())
+        print(
+            "Success! Generated dot and pdf files for the {0} model. They can be found in {1}, named {2}.\n".
+            format(model_name, save_graph_dir, save_graph_name))
+
+
+if __name__ == '__main__':
+    global test_args
+    test_args, remaining_args = parse_args()
+    generate_dot_for_model(test_args.model_path, test_args.save_graph_dir,
+                           test_args.save_graph_name)
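
For reference, a minimal usage sketch of the helper the new script defines, assuming convert_model2dot.py is importable from the working directory; the model and output paths below are placeholders, not values taken from this patch:

    # Hypothetical direct call to the helper from convert_model2dot.py;
    # the paths stand in for a real inference model directory and a
    # writable output directory.
    from convert_model2dot import generate_dot_for_model

    generate_dot_for_model(
        model_path='/path/to/Ernie_qat/float',   # dir with '__model__' or 'model'/'params'
        save_graph_dir='/tmp/ernie_dot_output',  # created if it does not exist
        save_graph_name='Ernie_qat2_int8')       # '' falls back to the output dir name

The CMake test added above exercises the same code path through the script's command-line flags (--model_path, --save_graph_dir, --save_graph_name).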