diff --git a/dnn/src/cuda/cudnn_wrapper.cpp b/dnn/src/cuda/cudnn_wrapper.cpp index 5101d4ada5a6074f74e15085e5e07b31c9b68205..dbc2c4dc47780342ea1ce2780d6d742261daf2b5 100644 --- a/dnn/src/cuda/cudnn_wrapper.cpp +++ b/dnn/src/cuda/cudnn_wrapper.cpp @@ -485,7 +485,7 @@ CudnnAlgoPack::conv_bwd_data_algos() { CudnnAlgoPack::Attr> algos = { DEF_ALGO(CUDNN_CONVOLUTION_BWD_DATA_ALGO_0, false, false), - DEF_ALGO(CUDNN_CONVOLUTION_BWD_DATA_ALGO_1, true, false), + DEF_ALGO(CUDNN_CONVOLUTION_BWD_DATA_ALGO_1, true, true), DEF_ALGO(CUDNN_CONVOLUTION_BWD_DATA_ALGO_FFT, true, true), DEF_ALGO(CUDNN_CONVOLUTION_BWD_DATA_ALGO_FFT_TILING, true, true), #if CUDNN_MAJOR >= 5 diff --git a/imperative/python/megengine/tools/accuracy_shake_var_tree.py b/imperative/python/megengine/tools/accuracy_shake_var_tree.py new file mode 100755 index 0000000000000000000000000000000000000000..3734437b152ba33d8b52edec0724174e259d179d --- /dev/null +++ b/imperative/python/megengine/tools/accuracy_shake_var_tree.py @@ -0,0 +1,130 @@ +#!/usr/bin/env python3 +# -*-coding=utf-8-*- + +# This tool is used to analyze the file generated by compare_binary_iodump.py. +# parse() can build a dependency tree with those varnodes +# where accuracy shake occurs and show the root varnodes. +# get_varNode()/get_dependence_list()/get_reference_list()/show_src_info() +# are some functions which are used to query dependencies between varnodes. +import argparse +import os + + +class varNode: + var_node_dict = {} + var_node_root_dict = {} + + def __init__(self, id, dependence_list, src_info): + self.src_info = src_info + if not id in varNode.var_node_dict.keys(): + self.id = id + self.dependence_list = [] + self.reference_list = [] + else: + self = varNode.var_node_dict[id] + + if dependence_list: + self.vitrual = False + self.is_root = True + else: + self.vitrual = True + self.is_root = False + + for i in dependence_list: + if not i in varNode.var_node_dict.keys(): + varNode.var_node_dict[i] = varNode(i, [], "") + + dv = varNode.var_node_dict[i] + self.dependence_list.append(dv) + if not dv.vitrual: + self.is_root = False + dv.reference_list.append(self) + + for i in self.reference_list: + i.is_root = False + varNode.var_node_root_dict.pop[i.id] + + if self.is_root: + varNode.var_node_root_dict[id] = self + + varNode.var_node_dict[id] = self + + @staticmethod + def get_varNode(id): + return varNode.var_node_dict[id] + + def get_dependence_list(self): + return self.dependence_list + + def get_reference_list(self): + return self.reference_list + + def show_src_info(self): + print(self.src_info) + + +def get_dependence(string, src_info): + start1 = "id:" + end1 = "," + e = 0 + + count = string.count(start1) + dependence_list = [] + for x in range(0, count): + s = string.find(start1, e) + e = string.find(end1, s) + sub_str = string[s:e] + if x == 0: + var = sub_str + else: + dependence_list.append(sub_str) + varNode(var, dependence_list, src_info) + + +def parse(filename): + with open(filename) as f: + varNode.var_node_dict.clear() + varNode.var_node_root_dict.clear() + line = f.readline() + s = ["", "", ""] + idx = 1 + while line: + if line.find("not equal: ") != -1: + s[2] = line + src_info = s[0] + "\n" + s[1] + "\n" + s[2] + get_dependence(s[0], src_info) + else: + if line.find("var={id:") != -1: + idx = idx ^ 1 + s[idx] = "" + s[idx] = s[idx] + line.strip() + line = f.readline() + + return varNode.var_node_root_dict + + +def main(): + parser = argparse.ArgumentParser( + description=( + "Analyze the outputs of compare_binary_iodump.py" + "Should save the outputs of compare_binary_iodump.py" + "as a file" + ), + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + ) + parser.add_argument( + "filename", help="file which save the outputs of compare_binary_iodump.py" + ) + args = parser.parse_args() + + parse(args.filename) + + print("varnode root:") + for key, value in varNode.var_node_root_dict.items(): + print(key) + print("detail info:") + value.show_src_info() + + +if __name__ == "__main__": + main() diff --git a/sdk/load-and-run/src/mgblar.cpp b/sdk/load-and-run/src/mgblar.cpp index 91608d6b2bd8493344aa0a1416fe20015dc05cc4..9859267f83c06219965985e0730d404cea3533a2 100644 --- a/sdk/load-and-run/src/mgblar.cpp +++ b/sdk/load-and-run/src/mgblar.cpp @@ -701,25 +701,19 @@ void run_test_st(Args &env) { mgb::gopt::set_opr_algo_workspace_limit_inplace(vars, env.workspace_limit); using S = opr::mixin::AlgoChooserHelper::ExecutionPolicy::Strategy; S strategy = S::HEURISTIC; + if (env.reproducible) { + strategy = S::REPRODUCIBLE; + } #if MGB_ENABLE_FASTRUN if (env.use_full_run) { - if (env.reproducible) { - strategy = S::PROFILE | S::REPRODUCIBLE; - } else { - strategy = S::PROFILE; - } + strategy = S::PROFILE | strategy; } else if (env.use_fast_run) { - strategy = S::PROFILE | S::OPTIMIZED; - if (env.reproducible){ - strategy = strategy | S::REPRODUCIBLE; - } - } else if (env.reproducible) { - strategy = S::HEURISTIC | S::REPRODUCIBLE; + strategy = S::PROFILE | S::OPTIMIZED | strategy; + } else { + strategy = S::HEURISTIC | strategy; } #else - if (env.reproducible) { - strategy = S::HEURISTIC | S::REPRODUCIBLE; - } + strategy = S::HEURISTIC | strategy; #endif mgb::gopt::modify_opr_algo_strategy_inplace(vars, strategy); if (!env.fast_run_cache_path.empty()) {