diff --git a/paddle/fluid/lite/CMakeLists.txt b/paddle/fluid/lite/CMakeLists.txt
index 4de031077f730422399a305a3f5e031ca198c3ab..dfd85d00a4cd328b9f85d12b5aa6923d54232888 100644
--- a/paddle/fluid/lite/CMakeLists.txt
+++ b/paddle/fluid/lite/CMakeLists.txt
@@ -187,6 +187,7 @@ add_subdirectory(model_parser)
 add_subdirectory(utils)
 add_subdirectory(api)
 add_subdirectory(gen_code)
+add_subdirectory(tools)
 
 if (WITH_TESTING)
diff --git a/paddle/fluid/lite/api/cxx_api.cc b/paddle/fluid/lite/api/cxx_api.cc
index 16a5cc891668f604b8f1bdc459473499e8a8a551..e86ddf04d256c045627e41a0db58aad0effbe116 100644
--- a/paddle/fluid/lite/api/cxx_api.cc
+++ b/paddle/fluid/lite/api/cxx_api.cc
@@ -50,18 +50,22 @@ const lite::Tensor *Predictor::GetOutput(size_t offset) {
 }
 
 void Predictor::Build(const std::string &model_path, const Place &prefer_place,
-                      const std::vector<Place> &valid_places) {
+                      const std::vector<Place> &valid_places,
+                      const std::vector<std::string> &passes) {
   LoadModel(model_path, scope_.get(), &program_desc_);
-  Build(program_desc_, prefer_place, valid_places);
+  Build(program_desc_, prefer_place, valid_places, passes);
 }
 
 const framework::proto::ProgramDesc &Predictor::program_desc() const {
   return program_desc_;
 }
 
+const RuntimeProgram &Predictor::runtime_program() const { return *program_; }
+
 void Predictor::Build(const framework::proto::ProgramDesc &desc,
                       const Place &prefer_place,
-                      const std::vector<Place> &valid_places) {
+                      const std::vector<Place> &valid_places,
+                      const std::vector<std::string> &passes) {
   program_desc_ = desc;
   Program program(desc, scope_, valid_places);
 
@@ -69,7 +73,7 @@ void Predictor::Build(const framework::proto::ProgramDesc &desc,
   core::KernelPickFactor factor;
   factor.ConsiderTarget();
   factor.ConsiderPrecision();
-  optimizer_.Run(std::move(program), valid_places, factor);
+  optimizer_.Run(std::move(program), valid_places, factor, passes);
   program_ = optimizer_.GenRuntimeProgram();
 }
 
diff --git a/paddle/fluid/lite/api/cxx_api.h b/paddle/fluid/lite/api/cxx_api.h
index 5434bc18eb634a7c2136a64f4afdb490db92119d..da728b2dceb91808fdf2b3a31f74b3bec0ccb96a 100644
--- a/paddle/fluid/lite/api/cxx_api.h
+++ b/paddle/fluid/lite/api/cxx_api.h
@@ -39,10 +39,12 @@ class Predictor {
 
   // Build from a model, with places set for hardware config.
   void Build(const std::string& model_path, const Place& prefer_place,
-             const std::vector<Place>& valid_places);
+             const std::vector<Place>& valid_places,
+             const std::vector<std::string>& passes = {});
 
   void Build(const framework::proto::ProgramDesc& desc,
-             const Place& prefer_place, const std::vector<Place>& valid_places);
+             const Place& prefer_place, const std::vector<Place>& valid_places,
+             const std::vector<std::string>& passes = {});
 
   // Run the predictor for a single batch of data.
   void Run() { program_->Run(); }
@@ -53,9 +55,9 @@ class Predictor {
   // Get offset-th col of fetch results.
   const lite::Tensor* GetOutput(size_t offset);
 
-  // Return the program desc for debug.
   const framework::proto::ProgramDesc& program_desc() const;
   const lite::Tensor* GetTensor(const std::string& name) const;
+  const RuntimeProgram& runtime_program() const;
 
   // This method is disabled in mobile, for unnecessary dependencies required.
void SaveModel(const std::string& dir); diff --git a/paddle/fluid/lite/core/program.h b/paddle/fluid/lite/core/program.h index 46da1815f197a2107a2ab3c3d844f1c4d87b44f2..5aef3c280b4a039eaf72da2c7aacff8f11d9a467 100644 --- a/paddle/fluid/lite/core/program.h +++ b/paddle/fluid/lite/core/program.h @@ -154,6 +154,8 @@ class RuntimeProgram { size_t num_instructions() const { return instructions_.size(); } + const std::vector& instructions() const { return instructions_; } + protected: std::string SerializeProgram(const framework::proto::ProgramDesc& desc); void SaveParams(const std::string& dir, diff --git a/paddle/fluid/lite/kernels/use_kernels.h b/paddle/fluid/lite/kernels/use_kernels.h index 09395abab523accd0bc4f95c75d0b9b23f1e8999..3e32f05da3c68fddcde27fe275f1f7020217f66f 100644 --- a/paddle/fluid/lite/kernels/use_kernels.h +++ b/paddle/fluid/lite/kernels/use_kernels.h @@ -37,6 +37,7 @@ USE_LITE_KERNEL(pool2d, kARM, kFloat, kNCHW, def); USE_LITE_KERNEL(relu, kARM, kFloat, kNCHW, def); USE_LITE_KERNEL(transpose, kARM, kFloat, kNCHW, def); USE_LITE_KERNEL(transpose2, kARM, kFloat, kNCHW, def); +USE_LITE_KERNEL(batch_norm, kARM, kFloat, kNCHW, def); #endif #ifdef LITE_WITH_X86 diff --git a/paddle/fluid/lite/tools/CMakeLists.txt b/paddle/fluid/lite/tools/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..71bebdf6f8cc949d0f851ae6cc45fed14c492154 --- /dev/null +++ b/paddle/fluid/lite/tools/CMakeLists.txt @@ -0,0 +1 @@ +add_subdirectory(debug) diff --git a/paddle/fluid/lite/tools/build.sh b/paddle/fluid/lite/tools/build.sh index fe956a0384554ea2d2d065c5bd231cbd6d646ecb..8ecc02e0874f38bb2e4f75b32d2d87dfea832f66 100755 --- a/paddle/fluid/lite/tools/build.sh +++ b/paddle/fluid/lite/tools/build.sh @@ -10,10 +10,17 @@ NUM_CORES_FOR_COMPILE=8 # for code gen, a source file is generated after a test, but is dependended by some targets in cmake. # here we fake an empty file to make cmake works. -function prepare_for_codegen { +function prepare_workspace { # in build directory - mkdir -p ./paddle/fluid/lite/gen_code - touch ./paddle/fluid/lite/gen_code/__generated_code__.cc + # 1. Prepare gen_code file + GEN_CODE_PATH_PREFIX=paddle/fluid/lite/gen_code + mkdir -p ./${GEN_CODE_PATH_PREFIX} + touch ./${GEN_CODE_PATH_PREFIX}/__generated_code__.cc + + # 2.Prepare debug tool + DEBUG_TOOL_PATH_PREFIX=paddle/fluid/lite/tools/debug + mkdir -p ./${DEBUG_TOOL_PATH_PREFIX} + cp ../${DEBUG_TOOL_PATH_PREFIX}/analysis_tool.py ./${DEBUG_TOOL_PATH_PREFIX}/ } function check_need_ci { @@ -21,7 +28,7 @@ function check_need_ci { } function cmake_x86 { - prepare_for_codegen + prepare_workspace cmake .. -DWITH_GPU=OFF -DWITH_MKLDNN=OFF -DLITE_WITH_X86=ON ${common_flags} } @@ -44,7 +51,7 @@ function cmake_opencl { # This method is only called in CI. function cmake_x86_for_CI { - prepare_for_codegen # fake an empty __generated_code__.cc to pass cmake. + prepare_workspace # fake an empty __generated_code__.cc to pass cmake. cmake .. -DWITH_GPU=OFF -DWITH_MKLDNN=OFF -DLITE_WITH_X86=ON ${common_flags} -DLITE_WITH_PROFILE=ON # Compile and execute the gen_code related test, so it will generate some code, and make the compilation reasonable. @@ -56,7 +63,7 @@ function cmake_x86_for_CI { } function cmake_gpu { - prepare_for_codegen + prepare_workspace cmake .. 
" -DWITH_GPU=ON {common_flags} -DLITE_WITH_GPU=ON" } @@ -164,6 +171,7 @@ function test_arm_model { } function cmake_arm { + prepare_workspace # $1: ARM_TARGET_OS in "android" , "armlinux" # $2: ARM_TARGET_ARCH_ABI in "armv8", "armv7" ,"armv7hf" # $3: ARM_TARGET_LANG in "gcc" "clang" diff --git a/paddle/fluid/lite/tools/debug/CMakeLists.txt b/paddle/fluid/lite/tools/debug/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..da033d61408470789bf5a6d8569f0312d07377e6 --- /dev/null +++ b/paddle/fluid/lite/tools/debug/CMakeLists.txt @@ -0,0 +1,12 @@ +cc_library(debug_utils_lite SRCS debug_utils.cc) + +lite_cc_binary(lite_model_debug_tool SRCS model_debug_tool.cc + DEPS + cxx_api_lite + debug_utils_lite + model_parser_lite + target_wrapper_host + mir_passes + ${ops_lite} ${host_kernels} + X86_DEPS ${x86_kernels} + ARM_DEPS ${arm_kernels}) diff --git a/paddle/fluid/lite/tools/debug/analysis_tool.py b/paddle/fluid/lite/tools/debug/analysis_tool.py new file mode 100644 index 0000000000000000000000000000000000000000..73430ee4117c778b845b85926cc2ed5e8839ff22 --- /dev/null +++ b/paddle/fluid/lite/tools/debug/analysis_tool.py @@ -0,0 +1,403 @@ +# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+'''
+Fluid model analysis tools
+'''
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import argparse
+import logging
+import os
+import subprocess
+import sys
+from collections import OrderedDict
+from functools import reduce  # used below; only a builtin on Python 2
+from operator import mul
+
+# Simple logging config
+logging.basicConfig(
+    level=logging.DEBUG,
+    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
+logger = logging.getLogger(__name__)
+
+import numpy as np
+import paddle.fluid as fluid
+from paddle.fluid import debugger
+from paddle.fluid import core
+
+# Command arguments
+parser = argparse.ArgumentParser()
+parser.add_argument(
+    "--model_dir", type=str, required=True, help="Model dir path")
+parser.add_argument(
+    "--input_file", default="", type=str, help="Input data file path")
+parser.add_argument(
+    "--topo_file",
+    type=str,
+    required=True,
+    help="Runtime topology order output file path")
+parser.add_argument(
+    "--tensor_file",
+    default="",
+    type=str,
+    required=True,
+    help="Tensor file path")
+parser.add_argument(
+    "--tensor_names",
+    default="",
+    type=str,
+    help="If tensor_names is not empty, only these tensors will be compared")
+parser.add_argument(
+    "--separator",
+    default=",",
+    type=str,
+    help="Default separator, used in string split")
+parser.add_argument(
+    "--output_tensor",
+    default=0,
+    type=int,
+    help="Whether to dump fluid runtime tensors")
+parser.add_argument(
+    "--tensor_output_file",
+    default="./tensor_output_py",
+    type=str,
+    help="Fluid runtime tensors dump file path")
+parser.add_argument(
+    "--tensor_output_length",
+    default=-1,
+    type=int,
+    help="Output tensor data length; dims size will be used if tensor_output_length < 0"
+)
+parser.add_argument(
+    "--only_first",
+    default=1,
+    type=int,
+    help="Whether to only output the first mismatched var's info")
+parser.add_argument(
+    "--output_file",
+    default="./diff.txt",
+    type=str,
+    help="Diff info dump file path")
+parser.add_argument(
+    "--threshold", default=1e-5, type=float, help="Float value diff threshold")
+
+
+# Helper functions
+def load_file(filename, delim=None):
+    """
+    Load a file line by line, optionally splitting each line by `delim`
+    """
+    with open(filename) as fd:
+        for line in fd:
+            line = line.strip()
+            assert line != ""
+            if delim:
+                line = line.split(delim)
+            yield line
+
+
+class FluidModelExecutor(object):
+    """
+    A fluid inference model executor
+    """
+
+    def __init__(self, model_dir, input_file):
+        self.model_dir = model_dir
+        self.place = fluid.CPUPlace()
+        self.exe = fluid.Executor(self.place)
+        self.scope = fluid.core.Scope()
+        self.input_data = self._load_input_file(input_file)
+
+        self.program, self.feed_target_names, self.fetch_targets = self._load_inference_model(
+        )
+
+    def infer_var_list(self,
+                       arg_names=None,
+                       out_data_len=-1,
+                       dump_tensor=False,
+                       dump_tensor_file=''):
+        """
+        Fetch the tensors of the variables given in arg_names
+        """
+        with fluid.scope_guard(self.scope):
+            global_block = self.program.global_block()
+            feed_list = self._prepare_feed_data(global_block,
+                                                self.feed_target_names)
+            fetch_targets = self._fetch_tmp_vars(global_block, arg_names)
+            results = self.exe.run(program=self.program,
+                                   feed=feed_list,
+                                   fetch_list=fetch_targets,
+                                   return_numpy=False)
+            return self._get_results(
+                results,
+                fetch_targets,
+                arg_names=arg_names,
+                need_save=dump_tensor,
+                save_path=dump_tensor_file,
+                out_data_len=out_data_len)
+
+    def draw_graph(self, output_path='./', filename='debug'):
+        """
+        Draw the graph with graphviz
+        """
+        dot_path = os.path.join(output_path, filename + '.dot')
+        pdf_path = os.path.join(output_path, filename + '.pdf')
+        debugger.draw_block_graphviz(self.program.global_block(), path=dot_path)
+        cmd = ["dot", "-Tpdf", dot_path, "-o", pdf_path]
+        subprocess.Popen(
+            cmd,
+            stdin=subprocess.PIPE,
+            stdout=subprocess.PIPE,
+            stderr=subprocess.PIPE)
+
+    def _prepare_feed_data(self, block, feed_target_names):
+        feed_dict = dict()
+
+        def fill_data(np_dtype, col, shape):
+            if self.input_data:
+                input_size = reduce(mul, shape)
+                assert len(self.input_data[0]) > col
+                data = self.input_data[0][col].split(' ')
+                assert len(data) == input_size
+                return np.array(
+                    map(np_dtype, data), dtype=np_dtype).reshape(shape)
+            else:
+                return np.ones(shape, dtype=np_dtype)
+
+        # TODO(sangoly): support multiple feed fields
+        assert len(feed_target_names) == 1
+        for idx, name in enumerate(feed_target_names):
+            var = block.var(name)
+            np_shape = list(var.shape)
+            # TODO(sangoly): support batch
+            if np_shape[0] == -1:
+                np_shape[0] = 1
+            if var.dtype == core.VarDesc.VarType.INT32:
+                feed_dict[name] = fill_data(np.int32, idx, np_shape)
+            elif var.dtype == core.VarDesc.VarType.INT64:
+                feed_dict[name] = fill_data(np.int64, idx, np_shape)
+            elif var.dtype == core.VarDesc.VarType.FP16:
+                feed_dict[name] = fill_data(np.float16, idx, np_shape)
+            elif var.dtype == core.VarDesc.VarType.FP32:
+                feed_dict[name] = fill_data(np.float32, idx, np_shape)
+            elif var.dtype == core.VarDesc.VarType.FP64:
+                feed_dict[name] = fill_data(np.float64, idx, np_shape)
+            else:
+                raise TypeError("Data type is not supported")
+        return feed_dict
+
+    def _load_input_file(self, input_file=None):
+        input_data = []
+        if not input_file:
+            return input_data
+        logger.info("Loading input file %s ..." % input_file)
+        for line in load_file(input_file, "\t"):
+            input_data.append(line)
+        return input_data
+
+    def _load_inference_model(self):
+        with fluid.scope_guard(self.scope):
+            model_abs_path = os.path.join(self.model_dir, 'model')
+            param_abs_path = os.path.join(self.model_dir, 'params')
+            if os.path.exists(model_abs_path) and os.path.exists(
+                    param_abs_path):
+                return fluid.io.load_inference_model(self.model_dir, self.exe,
+                                                     'model', 'params')
+            else:
+                return fluid.io.load_inference_model(self.model_dir, self.exe)
+
+    def _fetch_tmp_vars(self, block, var_names_list=None):
+        fetch_var = block.var('fetch')
+        old_fetch_names = set([var.name for var in self.fetch_targets])
+        new_fetch_vars = [block.var(name) for name in old_fetch_names]
+        i = len(new_fetch_vars)
+        if var_names_list is None:
+            var_names_list = block.vars.keys()
+        for var_name in var_names_list:
+            if var_name in old_fetch_names: continue
+            new_fetch_vars.append(block.var(var_name))
+            block.append_op(
+                type='fetch',
+                inputs={'X': [var_name]},
+                outputs={'Out': [fetch_var]},
+                attrs={'col': i})
+            i = i + 1
+        return new_fetch_vars
+
+    def _get_results(self,
+                     results,
+                     new_fetch_targets,
+                     need_save=False,
+                     arg_names=None,
+                     save_path='',
+                     out_data_len=10):
+        res = OrderedDict()
+        old_fetch_names = set([var.name for var in self.fetch_targets])
+        if need_save:
+            out_fd = open(save_path, 'w')
+        for idx, result in enumerate(results):
+            name = new_fetch_targets[idx].name
+            dim = [v if v >= 0 else 1 for v in new_fetch_targets[idx].shape]
+            size = min(reduce(mul, dim),
+                       out_data_len) if out_data_len > 0 else reduce(mul, dim)
+            values = list(np.array(result).flatten())[:size]
+            res[name] = {"dim": dim, "values": values}
+            if need_save:
+                if arg_names and name not in arg_names: continue
+                dim_str = '{' + ','.join(map(str, dim)) + '}'
+                out_fd.write('\t'.join(
+                    [name, dim_str, ' '.join(map(str, values))]) + '\n')
+        if need_save:
+            out_fd.close()
+        return res
+
+
+class Analyser(object):
+    """
+    A Fluid model analysis tool
+    """
+
+    def __init__(self, args):
+        self.args = args
+        self.tensors = OrderedDict()
+        self.topo = {}
+        self.input = []
+        logger.info("Loading fluid inference model %s ..." % args.model_dir)
+        self.predictor = FluidModelExecutor(args.model_dir, args.input_file)
+
+    def analysis(self):
+        """
+        Analyser work function
+        """
+        self._load_topo_file()
+        self._load_tensor_file()
+        arg_names = self.args.tensor_names.split(',') if self.args.tensor_names != "" \
+            else self.tensors.keys()
+        infer_results = self.predictor.infer_var_list(
+            out_data_len=self.args.tensor_output_length,
+            arg_names=arg_names,
+            dump_tensor=self.args.output_tensor,
+            dump_tensor_file=self.args.tensor_output_file)
+        if self.args.tensor_names == "":
+            self._check_diff_nodes(infer_results)
+
+    def _parse_topo_field(self, field):
+        params = [item.split(':')[1].strip() for item in field[1:-1].split(' ')]
+        params = [item.split('#') for item in params if item != ""]
+        return [item for lst in params for item in lst]
+
+    def _load_topo_file(self):
+        if self.args.topo_file == "":
+            raise ValueError("Topo file path is empty")
+        logger.info("Loading topo file %s ..." % self.args.topo_file)
+        for line in load_file(self.args.topo_file, '\t'):
+            op_type, inputs, outputs = line
+            for name in self._parse_topo_field(outputs):
+                if name not in self.topo:
+                    self.topo[name] = []
+                self.topo[name].append(line)
+
+    def _load_tensor_file(self):
+        if self.args.tensor_file == "":
+            raise ValueError("Tensor file path is empty")
+        logger.info("Loading tensor file %s ..." % self.args.tensor_file)
+        for line in load_file(self.args.tensor_file, "\t"):
+            name, dim, values = line
+            dim = map(int, dim[1:-1].split(','))
+            values = map(float, values.split(' '))
+
+            dim_size = reduce(mul, dim)
+            value_size = len(values)
+            assert dim_size == value_size, \
+                "Dim size mismatch with data: %d vs %d" % (dim_size, value_size)
+
+            self.tensors[name] = {"dim": dim, "values": values}
+
+    def _check_diff_nodes(self, results):
+        """
+        NOTE: the C++ debug tool dumps tensors in runtime topology order,
+        so we can locate the first op (or one of the first) whose results go wrong
+        """
+        assert len(self.tensors) == len(results), \
+            "Fluid output tensors' size mismatches `tensor_file`"
+        diff_vars = []
+        flag = False
+        for k in self.tensors:
+            if k not in results:
+                raise KeyError("Have not found infer result for `%s`" % k)
+            if len(self.tensors[k]['values']) != len(results[k]['values']):
+                raise ValueError(
+                    "Argname: %s size mismatch with `tensor_file`: %d vs %d" %
+                    (k, len(self.tensors[k]['values']),
+                     len(results[k]['values'])))
+            for i in range(len(self.tensors[k]['values'])):
+                if abs(self.tensors[k]['values'][i] - results[k]['values'][
+                        i]) > self.args.threshold:
+                    diff_vars.append(k)
+                    if self.args.only_first:
+                        flag = True
+                    break
+            if flag: break
+        self._output_diff_nodes(results, diff_vars)
+
+    def _output_diff_nodes(self, results, diff_vars):
+        def output_param_info(inputs, outputs, infos, fd):
+            def tensor_repr(name):
+                return '\t'.join([
+                    name, '{' + ','.join(map(str, infos[name]['dim'])) + '}',
+                    ' '.join(map(str, infos[name]['values']))
+                ])
+
+            for name in self._parse_topo_field(inputs):
+                if name not in infos: continue
+                fd.write(tensor_repr(name) + '\n')
+            for name in self._parse_topo_field(outputs):
+                if name not in infos: continue
+                fd.write(tensor_repr(name) + '\n')
+
+        if len(diff_vars) == 0:
+            logger.info("No diff found. Congratulations!")
+            return
+        logger.info("Total diff vars: %d" % len(diff_vars))
+        with open(self.args.output_file, 'w') as fd:
+            for var in diff_vars:
+                if var not in self.topo:
+                    raise KeyError("%s not in any op's output params, " % var +
+                                   "please check your model and input")
+                fd.write(
+                    '>>>>>>>>>>>>>>>>>>DIFF VARIABLE: %s<<<<<<<<<<<<<<<<<<<\n' %
+                    var)
+                for idx, (op_type, inputs,
+                          outputs) in enumerate(self.topo[var]):
+                    op_repr = '\t'.join([op_type, inputs, outputs])
+                    logger.info("dump diff info: ------------ %s" % op_repr)
+                    fd.write(op_repr + '\n')
+                    fd.write(
+                        "--------------- Tensor File info ---------------\n")
+                    output_param_info(inputs, outputs, self.tensors, fd)
+                    fd.write(
+                        "--------------- Fluid Tensor info ---------------\n")
+                    output_param_info(inputs, outputs, results, fd)
+                    fd.write("\n\n")
+
+
+if __name__ == "__main__":
+    args = parser.parse_args()
+    analyser = Analyser(args)
+    analyser.analysis()
diff --git a/paddle/fluid/lite/tools/debug/check_model.sh b/paddle/fluid/lite/tools/debug/check_model.sh
new file mode 100755
index 0000000000000000000000000000000000000000..67b898ad068ee9b48d3ea1e091e54e08b6523533
--- /dev/null
+++ b/paddle/fluid/lite/tools/debug/check_model.sh
@@ -0,0 +1,182 @@
+#!/bin/bash
+
+############################# Arguments ############################
+# For both the C++ and Python tools
+BUILD_ROOT_DIR=""                # CMake build root path, used for LD_LIBRARY_PATH
+MODEL_DIR=""                     # Model dir path
+INPUT_FILE=""                    # Input data file; only the first record will be used.
+                                 # If the path is empty, an all-ones input will be used.
+CPP_TOPO_FILE=./topo_file.txt    # Runtime program topology info. Written by the C++ debug tool, read by the Python debug tool.
+CPP_TENSOR_FILE=./tensor_cpp.txt # The C++ debug tool's tensor outputs in runtime topology order.
+                                 # Written by the C++ debug tool, read by the Python debug tool.
+TENSOR_NAMES=""                  # If not empty, only dump the tensors of the arguments whose names are
+                                 # listed here. Separated by ','.
+TENSOR_OUTPUT_LENGTH=-1          # Output tensor data length. The tensor's dim size will be used if this value < 0.
+
+# For the C++ debug tool
+CPP_OUTPUT_TOPO=1                # Whether to output topology info.
+CPP_OUTPUT_VARS=1                # Whether to output temporary vars' tensors.
+CPP_OUTPUT_WEIGHTS=1             # Whether to output weight vars' tensors.
+CPP_ARM_THREAD_NUM=1             # ARM thread num, used by ARM device info.
+                                 # Only takes effect with the compile option LITE_WITH_ARM.
+
+# For the Python debug tool
+PY_THRESHOLD=0.00001             # The numerical threshold used to judge the [C++ vs Python] runtime diff.
+PY_TENSOR_FILE=./tensor_py.txt   # Store the Python debug tool's tensor outputs.
+PY_OUTPUT_FILE=./diff.txt        # Store the diff op/var info for debugging.
+PY_ONLY_OUTPUT_FIRST_DIFF=1      # Whether to only output the first differing var's info in runtime topology order.
+PY_OUTPUT_TENSOR=1               # Whether to output the tensors of the vars in CPP_TENSOR_FILE/TENSOR_NAMES.
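For reference, a typical two-stage run of this script might look like the sketch below (the build and model paths are hypothetical); the C++ stage has to run before the Python stage so that CPP_TOPO_FILE and CPP_TENSOR_FILE exist:

    ./check_model.sh --build_root_dir=/path/to/build --model_dir=/path/to/model debug_cpp_stage
    ./check_model.sh --build_root_dir=/path/to/build --model_dir=/path/to/model debug_py_stage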
+ +############################# MAIN ################################# +function print_usage { + echo -e "\nUSAGE:" + echo -e "debug_cpp_stage -> debug_py_stage" + echo + echo "----------------------------------------" + echo -e "debug_cpp_stage:" + echo -e "run_debug.sh [--option=value]* debug_cpp_stage" + echo -e "See run_debug.sh#run_cpp_debug_tool for detail" + echo + echo -e "debug_py_stage:" + echo -e "run_debug.sh [--option=value]* debug_py_stage" + echo -e "See run_debug.sh#run_py_debug_tool for detail" + echo "----------------------------------------" +} + +function check_enviroment { + if [ "X${BUILD_ROOT_DIR}" == "X" ]; then + echo -e "\nOption: --build_root_dir=xxx is required.\n"; + exit 1 + fi + if [ "X${MODEL_DIR}" == "X" ]; then + echo -e "\nOption: --model_dir=xxx is required.\n"; + exit 1 + fi +} + +function run_cpp_debug_tool { + check_enviroment + + local tool_name="lite_model_debug_tool" + local tool_path=$(find ${BUILD_ROOT_DIR} -type f -name ${tool_name}) + if [ "X${tool_path}" == "X" ]; then + echo -e "\nERROR: ${tool_name} not found in ${BUILD_ROOT_DIR}.\n" + exit 1 + fi + echo "Find Cpp-debug-tool path: ${tool_path}" + export LD_LIBRARY_PATH="$LD_LIBRARY_PATH:$BUILD_ROOT_DIR/third_party/install/mklml/lib" + ${tool_path} \ + --model_dir=$MODEL_DIR \ + --input_file=$INPUT_FILE \ + --topo_output_file=$CPP_TOPO_FILE \ + --output_topo=$CPP_OUTPUT_TOPO \ + --tensor_output_file=$CPP_TENSOR_FILE \ + --output_vars=$CPP_OUTPUT_VARS \ + --output_weights=$CPP_OUTPUT_WEIGHTS \ + --tensor_names=$TENSOR_NAMES \ + --tensor_output_length=$TENSOR_OUTPUT_LENGTH \ + --arm_thread_num=$CPP_ARM_THREAD_NUM +} + +function run_py_debug_tool { + check_enviroment + + local tool_name="analysis_tool.py" + local tool_path=$(find ${BUILD_ROOT_DIR} -type f -name ${tool_name}) + if [ "X${tool_path}" == "X" ]; then + echo -e "\nERROR: ${tool_name} not found in ${BUILD_ROOT_DIR}.\n" + return + fi + echo "Find Py-debug-tool path: ${tool_path}" + python ${tool_path} \ + --model_dir=$MODEL_DIR \ + --input_file=$INPUT_FILE \ + --topo_file=$CPP_TOPO_FILE \ + --tensor_file=$CPP_TENSOR_FILE \ + --tensor_names=$TENSOR_NAMES \ + --output_tensor=$PY_OUTPUT_TENSOR \ + --tensor_output_file=$PY_TENSOR_FILE \ + --tensor_output_length=$TENSOR_OUTPUT_LENGTH \ + --only_first=$PY_ONLY_OUTPUT_FIRST_DIFF \ + --output_file=$PY_OUTPUT_FILE \ + --threshold=$PY_THRESHOLD +} + +function main { + # Parse command line. 
+ for i in "$@"; do + case $i in + --model_dir=*) + MODEL_DIR="${i#*=}" + shift + ;; + --input_file=*) + INPUT_FILE="${i#*=}" + shift + ;; + --cpp_topo_file=*) + CPP_TOPO_FILE="${i#*=}" + shift + ;; + --cpp_tensor_file=*) + CPP_TENSOR_FILE="${i#*=}" + shift + ;; + --tensor_names=*) + TENSOR_NAMES="${i#*=}" + shift + ;; + --tensor_output_length=*) + TENSOR_OUTPUT_LENGTH="${i#*=}" + shift + ;; + --cpp_output_vars=*) + CPP_OUTPUT_VARS="${i#*=}" + shift + ;; + --cpp_output_weights=*) + CPP_OUTPUT_WEIGHTS="${i#*=}" + shift + ;; + --py_threshold=*) + PY_THRESHOLD="${i#*=}" + shift + ;; + --py_tensor_file=*) + PY_TENSOR_FILE="${i#*=}" + shift + ;; + --py_output_file=*) + PY_OUTPUT_FILE="${i#*=}" + shift + ;; + --py_only_output_first_diff=*) + PY_ONLY_OUTPUT_FIRST_DIFF="${i#*=}" + shift + ;; + --py_output_tensor=*) + PY_OUTPUT_TENSOR="${i#*=}" + shift + ;; + --build_root_dir=*) + BUILD_ROOT_DIR="${i#*=}" + shift + ;; + debug_cpp_stage) + run_cpp_debug_tool + shift + ;; + debug_py_stage) + run_py_debug_tool + shift + ;; + *) + # unknown option + print_usage + exit 1 + ;; + esac + done +} + +main $@ diff --git a/paddle/fluid/lite/tools/debug/debug_utils.cc b/paddle/fluid/lite/tools/debug/debug_utils.cc new file mode 100644 index 0000000000000000000000000000000000000000..1dd8a5998607ced5cb724b32d9bcdc9bbc27c42b --- /dev/null +++ b/paddle/fluid/lite/tools/debug/debug_utils.cc @@ -0,0 +1,15 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/lite/tools/debug/debug_utils.h" diff --git a/paddle/fluid/lite/tools/debug/debug_utils.h b/paddle/fluid/lite/tools/debug/debug_utils.h new file mode 100644 index 0000000000000000000000000000000000000000..9f1843b09edfab6400318a810558fc3aaaa36a38 --- /dev/null +++ b/paddle/fluid/lite/tools/debug/debug_utils.h @@ -0,0 +1,329 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#pragma once +#include +#include +#include +#include +#include +#include +#include +#include +#include "paddle/fluid/lite/api/cxx_api.h" +#include "paddle/fluid/lite/core/compatible_tensor.h" +#include "paddle/fluid/lite/utils/string.h" + +DEFINE_string(model_dir, "", "Model dir path"); +DEFINE_string(input_file, "", "Input datas file path"); +DEFINE_string(topo_output_file, "", "Runtime topology order output file path"); +DEFINE_bool(output_topo, true, "Dump runtime topology or not"); +DEFINE_string(tensor_output_file, "", "Tensor output file path"); +DEFINE_bool(output_vars, true, "Dump vars or not"); +DEFINE_bool(output_weights, true, "Dump weight tensors or not"); +DEFINE_string( + tensor_names, "", + "If tensor_names is not empty, then only this tensors will be dump"); +DEFINE_int32(tensor_output_length, -1, + "Output tensor data length, dims size will be used if " + "output_tensor_length < 0"); +DEFINE_int32(arm_thread_num, 1, "Arm thread nums, 1 as default"); +DEFINE_string(separator, ",", "Deafult separator, use in string split"); + +namespace paddle { +namespace lite { +namespace tools { +namespace debug { + +struct DebugConfig { + // arguments + std::string model_dir; + std::string topo_output_file; + std::string tensor_output_file; + std::string input_file; + std::vector tensor_names; + bool output_weights; + bool output_topo; + bool output_vars; + int tensor_output_length; + int arm_thread_num; + + std::unordered_map var_descs; + std::vector> input_values; +}; + +template +std::vector Split2Vector(const std::string& input, + const std::string& separator) { + std::vector tgt; + std::vector inputs = Split(input, separator); + tgt.resize(inputs.size()); + std::stringstream ss; + for (int i = 0; i < inputs.size(); ++i) { + ss << inputs[i] << " "; + } + for (int i = 0; i < inputs.size(); ++i) { + ss >> tgt[i]; + } + return tgt; +} + +void CollectFeedVarsInfo(std::unordered_map* feed_vars_info, + const framework::proto::ProgramDesc& desc) { + CHECK(feed_vars_info); + for (const auto& proto_op_desc : desc.blocks(0).ops()) { + lite::OpDesc op_desc(proto_op_desc); + auto op_type = op_desc.Type(); + if (op_type == "feed") { + (*feed_vars_info) + .emplace(op_desc.GetAttr("col"), op_desc.Output("Out").front()); + } + } +} +template +void FillTensorData(lite::Tensor* tensor, const DebugConfig& conf, int col) { + CHECK(tensor); + auto dim_size = tensor->dims().production(); + auto* data = tensor->mutable_data(); + if (conf.input_values.size() > 0) { + CHECK(col < conf.input_values[0].size()) + << "Input data fields out of index. 
field_len: " + << conf.input_values[0].size() << " col: " << col; + std::vector input_data( + std::move(Split2Vector(conf.input_values[0][col], " "))); + CHECK(input_data.size() == dim_size) + << "Input data field[" << col + << "] mismatch TensorDim: " << input_data.size() << " vs " << dim_size; + for (int i = 0; i < dim_size; i++) { + data[i] = input_data[i]; + } + } else { + LOG(INFO) << "------------> Use all-ones input"; + for (int i = 0; i < dim_size; i++) { + data[i] = 1; + } + } +} + +void CheckDim(std::vector* dim) { + CHECK(dim); + for (int i = 0; i < dim->size(); ++i) { + if ((*dim)[i] < 0) (*dim)[i] = -(*dim)[i]; + } +} + +void PrepareModelInputTensor(const DebugConfig& conf, lite::Scope* scope, + const framework::proto::ProgramDesc& desc) { + CHECK(scope); + + std::unordered_map feed_vars_info; + CollectFeedVarsInfo(&feed_vars_info, desc); + auto* feed_var = + scope->FindVar("feed")->GetMutable>(); + feed_var->resize(feed_vars_info.size()); + + for (auto& item : feed_vars_info) { + auto& var_desc = conf.var_descs.at(item.second); + auto val_type = var_desc.GetDataType(); + auto dim = var_desc.GetShape(); + CheckDim(&dim); + auto* input_tensor = &feed_var->at(item.first); + input_tensor->Resize(DDim(dim)); + switch (val_type) { +#define FILL_TENSOR_BY_TYPE_ONCE(pb_type__, type__) \ + case framework::proto::VarType::pb_type__: \ + FillTensorData(input_tensor, conf, item.first); \ + break + + FILL_TENSOR_BY_TYPE_ONCE(UINT8, uint8_t); + FILL_TENSOR_BY_TYPE_ONCE(INT8, int8_t); + FILL_TENSOR_BY_TYPE_ONCE(INT16, int16_t); + FILL_TENSOR_BY_TYPE_ONCE(INT32, int32_t); + FILL_TENSOR_BY_TYPE_ONCE(INT64, int64_t); + FILL_TENSOR_BY_TYPE_ONCE(FP32, float); + FILL_TENSOR_BY_TYPE_ONCE(FP64, double); + + default: + LOG(FATAL) << "Unsupported data type: " << static_cast(val_type); +#undef FILL_TENSOR_BY_TYPE_ONCE + } + } +} + +void ParseInputFile(DebugConfig* conf) { + CHECK(conf); + if (conf->input_file.empty()) return; + auto& inputs = conf->input_values; + std::ifstream fd(conf->input_file); + CHECK(fd.is_open()) << "Open input file: " << conf->input_file << " failed!"; + std::string line; + while (std::getline(fd, line)) { + inputs.emplace_back(std::move(Split(line, FLAGS_separator))); + } + LOG(INFO) << "Load data:" << inputs.size() << " items"; +} + +void ParseConfig(DebugConfig* conf) { + CHECK(conf); +#define CHECK_NON_EMPTY(name__) \ + CHECK(!FLAGS_##name__.empty()) << "Option " << #name__ << " can't be empty." 
+ CHECK_NON_EMPTY(model_dir); + if (FLAGS_output_topo) { + CHECK_NON_EMPTY(topo_output_file); + } + if (FLAGS_output_vars || FLAGS_output_weights) { + CHECK_NON_EMPTY(tensor_output_file); + } +#undef CHECK_NON_EMPTY + conf->model_dir = FLAGS_model_dir; + conf->topo_output_file = FLAGS_topo_output_file; + conf->tensor_output_file = FLAGS_tensor_output_file; + conf->input_file = FLAGS_input_file; + conf->output_weights = FLAGS_output_weights; + conf->output_vars = FLAGS_output_vars; + conf->output_topo = FLAGS_output_topo; + conf->tensor_output_length = FLAGS_tensor_output_length; + conf->arm_thread_num = FLAGS_arm_thread_num; + + if (!FLAGS_tensor_names.empty()) { + conf->tensor_names = Split(FLAGS_tensor_names, FLAGS_separator); + } + + ParseInputFile(conf); +} + +void CollectAndDumpTopoInfo(const std::vector& instructions, + const DebugConfig& conf) { + if (!conf.output_topo) return; + LOG(INFO) << "----------------- dump topo file"; + std::ofstream os(conf.topo_output_file); + CHECK(os.is_open()); + for (auto& inst : instructions) { + auto* op_info = inst.op()->op_info(); + CHECK(op_info); + os << op_info->Type() << "\t"; + os << "("; +#define DUMP_TOPO_INFO_ONCE(name__) \ + { \ + auto argnames = op_info->name__##ArgumentNames(); \ + for (int i = 0; i < argnames.size(); ++i) { \ + os << argnames[i] << ":"; \ + auto vars = op_info->name__(argnames[i]); \ + for (int j = 0; j < vars.size(); ++j) { \ + os << vars[j]; \ + if (j != vars.size() - 1) os << "#"; \ + } \ + if (i != argnames.size() - 1) os << " "; \ + } \ + } + DUMP_TOPO_INFO_ONCE(Input); + os << ")\t("; + DUMP_TOPO_INFO_ONCE(Output); + os << ")\n"; +#undef DUMP_TOPO_INFO_ONCE + } + os.close(); +} + +void CollectVarDescs(std::unordered_map* var_descs, + const framework::proto::ProgramDesc& desc) { + CHECK(var_descs); + CHECK(!desc.blocks().empty()); + std::unordered_set weights; + for (auto proto_var_desc : desc.blocks(0).vars()) { + lite::VarDesc var_desc(proto_var_desc); + (*var_descs).emplace(var_desc.Name(), std::move(var_desc)); + } +} + +std::unordered_set CollectUnusedVars( + const std::vector& instructions) { + std::unordered_set unused; + std::unordered_set all_inputs; + for (auto& inst : instructions) { + for (const auto& name : inst.op()->op_info()->input_names()) { + all_inputs.insert(name); + } + } + + for (auto& inst : instructions) { + for (const auto& name : inst.op()->op_info()->output_names()) { + if (all_inputs.count(name) == 0) unused.insert(name); + } + } + + return unused; +} + +std::string GetTensorRepr(const lite::Tensor& tensor, int out_data_len) { + std::stringstream ss; + auto size = tensor.dims().production(); + if (out_data_len >= 0) { + size = std::min(size, static_cast(out_data_len)); + } + for (int i = 0; i < size; i++) { + ss << tensor.template data()[i]; + if (i != size - 1) ss << " "; + } + return ss.str(); +} + +void CollectAndDumpTensorInfo(const std::vector& instructions, + const framework::proto::ProgramDesc& desc, + const DebugConfig& conf) { + CHECK(instructions.size() > 0) << "No instruction found"; + const auto* scope = const_cast(instructions[0].op())->scope(); + std::ofstream os(conf.tensor_output_file); + CHECK(os.is_open()); + + std::unordered_set dump_vars; +#define DUMP_TENSOR_ONCE(name__) \ + LOG(INFO) << "----------------- dump tensor: " << name__; \ + auto& tensor = scope->FindVar(name__)->Get(); \ + os << name__ << "\t" << tensor.dims() << "\t" \ + << GetTensorRepr(tensor, conf.tensor_output_length) << "\n"; \ + dump_vars.insert(name__) + +#define DUMP_OP_TENSOR_ONCE(name__, 
skip__) \ + for (const auto& name : inst.op()->op_info()->name__##_names()) { \ + bool is_weight = conf.var_descs.at(name).Persistable(); \ + if (unused.count(name) != 0 || name == #skip__ || \ + (!conf.output_weights && is_weight) || \ + (!conf.output_vars && !is_weight) || dump_vars.count(name) != 0) \ + continue; \ + DUMP_TENSOR_ONCE(name); \ + } + + if (conf.tensor_names.size() == 0) { + std::unordered_set unused( + std::move(CollectUnusedVars(instructions))); + + for (auto& inst : instructions) { + DUMP_OP_TENSOR_ONCE(input, feed); + DUMP_OP_TENSOR_ONCE(output, fetch); + } + } else { + for (const auto& name : conf.tensor_names) { + DUMP_TENSOR_ONCE(name); + } + } +#undef DUMP_OP_TENSOR_ONCE +#undef DUMP_TENSOR_ONCE + os.close(); +} + +} // namespace debug +} // namespace tools +} // namespace lite +} // namespace paddle diff --git a/paddle/fluid/lite/tools/debug/model_debug_tool.cc b/paddle/fluid/lite/tools/debug/model_debug_tool.cc new file mode 100644 index 0000000000000000000000000000000000000000..7e29e0bd41341c0e9bcb97ecf8fae251b9a3af7f --- /dev/null +++ b/paddle/fluid/lite/tools/debug/model_debug_tool.cc @@ -0,0 +1,94 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
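The tool below drives the new `Predictor::Build` overload that takes an explicit pass list and then walks `runtime_program().instructions()`; stripped to its essentials, the flow is roughly the following sketch (the model path is hypothetical and the pass list is abridged, the tool passes the full sequence):

    lite::Predictor predictor;
    std::vector<Place> valid_places{Place{TARGET(kHost), PRECISION(kFloat)}};
    std::vector<std::string> passes{"static_kernel_pick_pass",
                                    "variable_place_inference_pass",
                                    "runtime_context_assign_pass"};
    predictor.Build("/path/to/model", valid_places[0], valid_places, passes);
    predictor.Run();
    const auto& insts = predictor.runtime_program().instructions();
    LOG(INFO) << "num instructions: " << insts.size();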
+ +#include +#include +#include +#include "paddle/fluid/lite/api/cxx_api.h" +#include "paddle/fluid/lite/core/mir/use_passes.h" +#include "paddle/fluid/lite/core/op_registry.h" +#include "paddle/fluid/lite/kernels/use_kernels.h" +#include "paddle/fluid/lite/operators/use_ops.h" +#include "paddle/fluid/lite/tools/debug/debug_utils.h" + +namespace paddle { +namespace lite { +namespace tools { +namespace debug { + +void Run(DebugConfig* conf) { + CHECK(conf); +#ifdef LITE_WITH_ARM + DeviceInfo::Init(); + DeviceInfo::Global().SetRunMode(LITE_POWER_HIGH, conf->arm_thread_num); +#endif + lite::Predictor predictor; + std::vector valid_places({ + Place{TARGET(kHost), PRECISION(kFloat)}, +#ifdef LITE_WITH_ARM + Place{TARGET(kARM), PRECISION(kFloat)}, +#endif +#ifdef LITE_WITH_X86 + Place{TARGET(kX86), PRECISION(kFloat)}, +#endif + }); + + std::vector passes{{ + "static_kernel_pick_pass", "variable_place_inference_pass", + "type_target_transform_pass", "variable_place_inference_pass", + "io_copy_kernel_pick_pass", "variable_place_inference_pass", + "runtime_context_assign_pass", + }}; + + predictor.Build(conf->model_dir, +#ifdef LITE_WITH_ARM + Place{TARGET(kARM), PRECISION(kFloat)}, +#endif +#ifdef LITE_WITH_X86 + Place{TARGET(kX86), PRECISION(kFloat)}, +#endif + valid_places, passes); + + auto& instructions = predictor.runtime_program().instructions(); + auto& program_desc = predictor.program_desc(); + auto* scope = const_cast(instructions[0].op())->scope(); + + CollectVarDescs(&(conf->var_descs), program_desc); + PrepareModelInputTensor(*conf, scope, program_desc); + predictor.Run(); + + CollectAndDumpTopoInfo(instructions, *conf); + CollectAndDumpTensorInfo(instructions, program_desc, *conf); + + // TODO(sangoly): Maybe add some profile info here + auto* out = predictor.GetOutput(0); + LOG(INFO) << out << " memory size " << out->data_size(); + LOG(INFO) << "out " << out->data()[0]; + LOG(INFO) << "dims " << out->dims(); + LOG(INFO) << "out data size: " << out->data_size(); +} + +} // namespace debug +} // namespace tools +} // namespace lite +} // namespace paddle + +int main(int argc, char** argv) { + google::ParseCommandLineFlags(&argc, &argv, true); + paddle::lite::tools::debug::DebugConfig conf; + paddle::lite::tools::debug::ParseConfig(&conf); + paddle::lite::tools::debug::Run(&conf); + + return 0; +} diff --git a/paddle/fluid/lite/utils/string.h b/paddle/fluid/lite/utils/string.h index 5e918bf5f841b3f8d18ccf9ff94721534ec6a698..b32291ec3551459edfbed789379af63cc3db8ada 100644 --- a/paddle/fluid/lite/utils/string.h +++ b/paddle/fluid/lite/utils/string.h @@ -74,14 +74,21 @@ static std::string Repr(const std::vector& v) { return "{" + Join(tmp, ",") + "}"; } -static std::vector Split(const std::string& s, char delim) { - std::stringstream ss(s); - std::string line; - std::vector res; - while (std::getline(ss, line, delim)) { - res.push_back(line); +static std::vector Split(const std::string& original, + const std::string& separator) { + std::vector results; + std::string::size_type pos1, pos2; + pos2 = original.find(separator); + pos1 = 0; + while (std::string::npos != pos2) { + results.push_back(original.substr(pos1, pos2 - pos1)); + pos1 = pos2 + separator.size(); + pos2 = original.find(separator, pos1); } - return res; + if (pos1 != original.length()) { + results.push_back(original.substr(pos1)); + } + return results; } } // namespace lite
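The reworked `Split` takes a string separator instead of a single character; a minimal standalone sketch of its behaviour (interior empty fields are kept, a trailing empty field is dropped):

    #include <cassert>
    #include <string>
    #include <vector>
    #include "paddle/fluid/lite/utils/string.h"

    int main() {
      using paddle::lite::Split;
      std::vector<std::string> ops = Split("conv2d::relu::pool2d", "::");  // multi-char separator
      assert(ops.size() == 3 && ops[1] == "relu");
      std::vector<std::string> cols = Split("a,,b,", ",");
      assert(cols.size() == 3 && cols[1].empty());  // interior empty kept, trailing dropped
      return 0;
    }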