未验证 提交 53125c2f 编写于 作者: J joanna.wozna.intel 提交者: GitHub

Model converter to dot file (#23169)

上级 72c370c8
...@@ -113,6 +113,13 @@ function(save_qat_nlp_model_test target qat_model_dir fp32_model_save_path int8_ ...@@ -113,6 +113,13 @@ function(save_qat_nlp_model_test target qat_model_dir fp32_model_save_path int8_
--int8_model_save_path ${int8_model_save_path}) --int8_model_save_path ${int8_model_save_path})
endfunction() endfunction()
function(convert_model2dot_test target model_path save_graph_dir save_graph_name)
py_test(${target} SRCS ${CMAKE_CURRENT_SOURCE_DIR}/convert_model2dot.py
ARGS --model_path ${model_path}
--save_graph_dir ${save_graph_dir}
--save_graph_name ${save_graph_name})
endfunction()
if(WIN32) if(WIN32)
list(REMOVE_ITEM TEST_OPS test_light_nas) list(REMOVE_ITEM TEST_OPS test_light_nas)
list(REMOVE_ITEM TEST_OPS test_post_training_quantization_mobilenetv1) list(REMOVE_ITEM TEST_OPS test_post_training_quantization_mobilenetv1)
...@@ -277,6 +284,10 @@ if(LINUX AND WITH_MKLDNN) ...@@ -277,6 +284,10 @@ if(LINUX AND WITH_MKLDNN)
set(QAT2_FP32_ERNIE_SAVE_PATH "${QAT_INSTALL_DIR}/Ernie_qat2_fp32") set(QAT2_FP32_ERNIE_SAVE_PATH "${QAT_INSTALL_DIR}/Ernie_qat2_fp32")
save_qat_nlp_model_test(save_qat2_model_ernie ${QAT2_ERNIE_MODEL_DIR}/Ernie_qat/float ${QAT2_FP32_ERNIE_SAVE_PATH} ${QAT2_INT8_ERNIE_SAVE_PATH}) save_qat_nlp_model_test(save_qat2_model_ernie ${QAT2_ERNIE_MODEL_DIR}/Ernie_qat/float ${QAT2_FP32_ERNIE_SAVE_PATH} ${QAT2_INT8_ERNIE_SAVE_PATH})
# Convert QAT2 model to dot and pdf files
set(QAT2_INT8_ERNIE_DOT_SAVE_PATH "${QAT_INSTALL_DIR}/Ernie_qat2_int8_dot_file")
convert_model2dot_test(convert_model2dot_ernie ${QAT2_ERNIE_MODEL_DIR}/Ernie_qat/float ${QAT2_INT8_ERNIE_DOT_SAVE_PATH} "Ernie_qat2_int8")
endif() endif()
# Since the tests for QAT FP32 & INT8 comparison support only testing on Linux # Since the tests for QAT FP32 & INT8 comparison support only testing on Linux
......
# copyright (c) 2020 paddlepaddle authors. all rights reserved.
#
# licensed under the apache license, version 2.0 (the "license");
# you may not use this file except in compliance with the license.
# you may obtain a copy of the license at
#
# http://www.apache.org/licenses/license-2.0
#
# unless required by applicable law or agreed to in writing, software
# distributed under the license is distributed on an "as is" basis,
# without warranties or conditions of any kind, either express or implied.
# see the license for the specific language governing permissions and
# limitations under the license.
import unittest
import os
import sys
import argparse
import paddle.fluid as fluid
from paddle.fluid.framework import IrGraph
from paddle.fluid import core
def parse_args():
parser = argparse.ArgumentParser()
parser.add_argument(
'--model_path', type=str, default='', help='A path to a model.')
parser.add_argument(
'--save_graph_dir',
type=str,
default='',
help='A path to save the graph.')
parser.add_argument(
'--save_graph_name',
type=str,
default='',
help='A name to save the graph. Default - name from model path will be used'
)
test_args, args = parser.parse_known_args(namespace=unittest)
return test_args, sys.argv[:1] + args
def generate_dot_for_model(model_path, save_graph_dir, save_graph_name):
place = fluid.CPUPlace()
exe = fluid.Executor(place)
inference_scope = fluid.executor.global_scope()
with fluid.scope_guard(inference_scope):
if os.path.exists(os.path.join(model_path, '__model__')):
[inference_program, feed_target_names,
fetch_targets] = fluid.io.load_inference_model(model_path, exe)
else:
[inference_program, feed_target_names,
fetch_targets] = fluid.io.load_inference_model(model_path, exe,
'model', 'params')
graph = IrGraph(core.Graph(inference_program.desc), for_test=True)
if not os.path.exists(save_graph_dir):
os.makedirs(save_graph_dir)
model_name = os.path.basename(os.path.normpath(save_graph_dir))
if save_graph_name is '':
save_graph_name = model_name
graph.draw(save_graph_dir, save_graph_name, graph.all_op_nodes())
print(
"Success! Generated dot and pdf files for {0} model, that can be found at {1} named {2}.\n".
format(model_name, save_graph_dir, save_graph_name))
if __name__ == '__main__':
global test_args
test_args, remaining_args = parse_args()
generate_dot_for_model(test_args.model_path, test_args.save_graph_dir,
test_args.save_graph_name)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册