diff --git a/paddle/fluid/operators/conv_base_helper.h b/paddle/fluid/operators/conv_base_helper.h
index 705fc1f5618b5f239d21039b43bac6a5d706af1a..00c24ead0c5341e7872fa74f70aab297367f8a19 100644
--- a/paddle/fluid/operators/conv_base_helper.h
+++ b/paddle/fluid/operators/conv_base_helper.h
@@ -40,9 +40,6 @@ using ScalingParamType = typename platform::CudnnDataType<T>::ScalingParamType;
 template <typename AlgoT>
 struct SearchResult {
   SearchResult() {}
-  explicit SearchResult(const phi::autotune::DnnNode& node)
-      : algo(static_cast<AlgoT>(node.algo)),
-        workspace_size(node.workspace_size) {}
 
   explicit SearchResult(AlgoT a) : algo(a) {}
   explicit SearchResult(AlgoT a, float t, size_t size)
@@ -51,12 +48,21 @@ struct SearchResult {
   AlgoT algo = static_cast<AlgoT>(0);
   float time = -1.f;
   size_t workspace_size = 0;
+  bool exhaustive_search = false;
 };
 
 template <typename T>
 static std::ostream& operator<<(std::ostream& out, const std::vector<T>& v) {
   out << "[";
-  for (auto const& tmp : v) out << tmp << ",";
+  bool is_first = true;
+  for (auto const& tmp : v) {
+    if (is_first) {
+      out << tmp;
+      is_first = false;
+    } else {
+      out << ", " << tmp;
+    }
+  }
   out << "]";
   return out;
 }
@@ -109,7 +115,7 @@ struct ConvArgsBase {
     auto w_shape = phi::vectorize(w->dims());
     VLOG(10) << "[ConvArgs] x_dims=" << x_shape << ", w_dims=" << w_shape
              << ", strides=" << s << ", paddings=" << p << ", dilations=" << d
-             << ",data= " << paddle::experimental::CppTypeToDataType<T>::Type()
+             << ", data=" << paddle::experimental::CppTypeToDataType<T>::Type()
              << ", group=" << group
              << ", data layout=" << static_cast<int64_t>(data_layout);
diff --git a/paddle/fluid/operators/conv_cudnn_helper.h b/paddle/fluid/operators/conv_cudnn_helper.h
index 8795b3fa14bcc9c801c53a7b42ee3f9b8ea5246e..0388665a15ef6ec99983d20ef32f5fb143d1bcee 100644
--- a/paddle/fluid/operators/conv_cudnn_helper.h
+++ b/paddle/fluid/operators/conv_cudnn_helper.h
@@ -14,12 +14,11 @@ limitations under the License.
 */
 #pragma once
 
-#include "paddle/fluid/framework/eigen.h"
 #include "paddle/fluid/operators/conv_base_helper.h"
 #include "paddle/fluid/platform/cuda_graph_with_memory_pool.h"
 #include "paddle/fluid/platform/device/gpu/gpu_info.h"
-#include "paddle/fluid/platform/profiler.h"
 #include "paddle/phi/kernels/autotune/switch_autotune.h"
+#include "paddle/phi/kernels/funcs/eigen/common.h"
 #include "paddle/phi/kernels/funcs/eigen/eigen_function.h"
 
 namespace paddle {
@@ -53,11 +52,9 @@ static void RemovePaddingSlice(const phi::GPUContext& context,
   }
 
   auto in_t =
-      framework::EigenTensor<T, D, Eigen::RowMajor, Eigen::DenseIndex>::From(
-          *input);
-  auto out_t =
-      framework::EigenTensor<T, D, Eigen::RowMajor, Eigen::DenseIndex>::From(
-          *out, new_out_dims);
+      phi::EigenTensor<T, D, Eigen::RowMajor, Eigen::DenseIndex>::From(*input);
+  auto out_t = phi::EigenTensor<T, D, Eigen::RowMajor, Eigen::DenseIndex>::From(
+      *out, new_out_dims);
 
   phi::funcs::EigenSlice<std::decay_t<decltype(place)>, T, D>::Eval(
       place, out_t, in_t, offsets, extents);
@@ -161,6 +158,8 @@ struct SearchAlgorithmBase<cudnnConvolutionFwdAlgoPerf_t> {
   constexpr static phi::autotune::AlgorithmType kAlgoType =
       phi::autotune::AlgorithmType::kConvForward;
 
+  static const std::string GetPerfName() { return "ConvForward"; }
+
   static size_t GetWorkspaceSize(const ConvArgs& args,
                                  cudnnConvolutionFwdAlgo_t algo) {
     size_t workspace_size = 0;
@@ -334,6 +333,8 @@ struct SearchAlgorithmBase<cudnnConvolutionBwdDataAlgoPerf_t> {
   constexpr static phi::autotune::AlgorithmType kAlgoType =
       phi::autotune::AlgorithmType::kConvBackwardData;
 
+  static const std::string GetPerfName() { return "ConvBackwardData"; }
+
   static size_t GetWorkspaceSize(const ConvArgs& args,
                                  cudnnConvolutionBwdDataAlgo_t algo) {
     size_t workspace_size = 0;
@@ -514,6 +515,8 @@ struct SearchAlgorithmBase<cudnnConvolutionBwdFilterAlgoPerf_t> {
   constexpr static phi::autotune::AlgorithmType kAlgoType =
       phi::autotune::AlgorithmType::kConvBackwardFilter;
 
+  static const std::string GetPerfName() { return "ConvBackwardFilter"; }
+
   static size_t GetWorkspaceSize(const ConvArgs& args,
                                  cudnnConvolutionBwdFilterAlgo_t algo) {
     platform::CUDAGraphCaptureModeGuard guard;
@@ -752,11 +755,13 @@ struct SearchAlgorithm : public SearchAlgorithmBase<PerfT> {
   using AlgoT = typename SearchAlgorithmBase<PerfT>::AlgoT;
 
   template <typename T>
-  static SearchResult<AlgoT> Find(const ConvArgs& args,
+  static SearchResult<AlgoT> Find(const phi::GPUContext& ctx,
+                                  const ConvArgs& args,
                                   bool exhaustive_search,
                                   bool deterministic,
-                                  const phi::GPUContext& ctx) {
+                                  bool enable_autotune = true) {
     SearchResult<AlgoT> result;
+    bool use_autotune = false;
     auto dtype = platform::CudnnDataType<T>::type;
     SetConvMathType(ctx, dtype, args.cdesc);
 
@@ -764,33 +769,50 @@ struct SearchAlgorithm : public SearchAlgorithmBase<PerfT> {
       result = SearchAlgorithmBase<PerfT>::FindAlgoDeterministic(args);
     } else {
       // 1. Once turning on exhaustive FLAGS, always get exhaustive_search.
-      // 2. Once turning on auto-tune, runn heuristic search(default) before
+      // 2. Once turning on auto-tune, run heuristic (default) before
       //    auto-tune process, run exhaustive_search during mentioned process.
+      //    Auto tune is only enabled between specified range.
       // 3. After auto-tune process, run cached algorithm if cached, run
       //    default mode for the rest.
       auto key = args.Convert2ConvCacheKey();
       auto& cache = phi::autotune::AutoTuneCache::Instance().GetConv(
           SearchAlgorithmBase<PerfT>::kAlgoType);
-      if (cache.Find(key)) {
+      bool find_in_cache = cache.Find(key);
+      if (find_in_cache) {
         auto t = cache.Get(key);
         result.algo = static_cast<AlgoT>(t.algo);
         result.workspace_size = t.workspace_size;
-      } else {
-        bool use_autotune =
-            phi::autotune::AutoTuneStatus::Instance().UseAutoTune();
+        result.exhaustive_search = t.exhaustive_search;
+      }
+      if (!result.exhaustive_search) {
+        bool need_update_cache = false;
+        // In conv2d_tranpose, enable_autotune is set to false because some
+        // algorithm picked by exhaustive search method produce wrong result.
+        use_autotune = enable_autotune &&
+                       phi::autotune::AutoTuneStatus::Instance().UseAutoTune();
         if (exhaustive_search || use_autotune) {
+          // Once autotune is enabled, the autotuned result can rewrite the
+          // previous result in cache found by heuristic method.
           result =
              SearchAlgorithmBase<PerfT>::template FindAlgoExhaustiveSearch<T>(
                   args, ctx);
-        } else {
+          need_update_cache = true;
+        } else if (!find_in_cache) {
           result = SearchAlgorithmBase<PerfT>::FindAlgoHeuristic(args, ctx);
+          need_update_cache = true;
+        }
+        if (need_update_cache) {
+          phi::autotune::ConvAutoTuneResult node(
+              static_cast<int64_t>(result.algo),
+              result.workspace_size,
+              exhaustive_search || use_autotune);
+          cache.Set(key, node);
         }
-        phi::autotune::DnnNode node(static_cast<int64_t>(result.algo),
-                                    result.workspace_size);
-        cache.Set(key, node);
       }
     }
-    VLOG(3) << "[cuDNN Convoltion] exhaustive_search=" << exhaustive_search
+    VLOG(3) << "[cuDNN " << SearchAlgorithmBase<PerfT>::GetPerfName()
+            << "] exhaustive_search=" << exhaustive_search
+            << ", use_autotune=" << use_autotune
             << ", deterministic=" << deterministic
             << ", choose algo=" << result.algo
             << ", workspace=" << ToMegaBytes(result.workspace_size) << " MB";
diff --git a/paddle/phi/kernels/autotune/cache.h b/paddle/phi/kernels/autotune/cache.h
index 949cae1532b3f72cc495487645348ef9b139e996..dc639e9f21ecfa7126b4a136acb2e38d5156f9fb 100644
--- a/paddle/phi/kernels/autotune/cache.h
+++ b/paddle/phi/kernels/autotune/cache.h
@@ -56,12 +56,14 @@ struct hash<std::vector<T>> {
 namespace phi {
 namespace autotune {
 
-struct DnnNode {
-  DnnNode() {}
-  explicit DnnNode(int64_t a, size_t size) : algo(a), workspace_size(size) {}
+struct ConvAutoTuneResult {
+  ConvAutoTuneResult() {}
+  ConvAutoTuneResult(int64_t a, size_t size, bool search)
+      : algo(a), workspace_size(size), exhaustive_search(search) {}
 
   int64_t algo;
   size_t workspace_size = 0;
+  bool exhaustive_search = false;
 };
 
 template <typename... Args>
@@ -73,40 +75,41 @@ size_t GetKey(Args&&... args) {
 struct ConvCacheKey {
   ConvCacheKey() {}
-  explicit ConvCacheKey(const std::vector<int64_t>& x_dims,
-                        const std::vector<int64_t>& w_dims,
-                        const std::vector<int64_t>& strides,
-                        const std::vector<int64_t>& paddings,
-                        const std::vector<int64_t>& dilations,
-                        phi::DataType dtype,
-                        int groups,
-                        int64_t data_layout)
-      : x_dims_(x_dims),
-        w_dims_(w_dims),
-        strides_(strides),
-        paddings_(paddings),
-        dilations_(dilations),
-        dtype_(dtype),
-        groups_(groups),
-        data_layout_(data_layout) {}
+  ConvCacheKey(const std::vector<int64_t>& arg_x_dims,
+               const std::vector<int64_t>& arg_w_dims,
+               const std::vector<int64_t>& arg_strides,
+               const std::vector<int64_t>& arg_paddings,
+               const std::vector<int64_t>& arg_dilations,
+               phi::DataType arg_dtype,
+               int arg_groups,
+               int64_t arg_data_layout)
+      : x_dims(arg_x_dims),
+        w_dims(arg_w_dims),
+        strides(arg_strides),
+        paddings(arg_paddings),
+        dilations(arg_dilations),
+        dtype(arg_dtype),
+        groups(arg_groups),
+        data_layout(arg_data_layout) {}
 
   size_t hash_value() const {
-    return GetKey(x_dims_,
-                  w_dims_,
-                  strides_,
-                  paddings_,
-                  dilations_,
-                  static_cast<int64_t>(dtype_),
-                  groups_,
-                  data_layout_);
+    return GetKey(x_dims,
+                  w_dims,
+                  strides,
+                  paddings,
+                  dilations,
+                  static_cast<int64_t>(dtype),
+                  groups,
+                  data_layout);
   }
-  std::vector<int64_t> x_dims_;
-  std::vector<int64_t> w_dims_;
-  std::vector<int64_t> strides_;
-  std::vector<int64_t> paddings_;
-  std::vector<int64_t> dilations_;
-  phi::DataType dtype_;
-  int groups_;
-  int64_t data_layout_;
+
+  std::vector<int64_t> x_dims;
+  std::vector<int64_t> w_dims;
+  std::vector<int64_t> strides;
+  std::vector<int64_t> paddings;
+  std::vector<int64_t> dilations;
+  phi::DataType dtype;
+  int groups;
+  int64_t data_layout;
 };
 
 struct ConvCacheKeyHash {
@@ -118,14 +121,14 @@ struct ConvCacheKeyHash {
 struct ConvCacheKeyEqual {
   size_t operator()(const ConvCacheKey& first,
                     const ConvCacheKey& second) const {
-    if (first.x_dims_ != second.x_dims_) return false;
-    if (first.w_dims_ != second.w_dims_) return false;
-    if (first.strides_ != second.strides_) return false;
-    if (first.paddings_ != second.paddings_) return false;
-    if (first.dilations_ != second.dilations_) return false;
-    if (first.dtype_ != second.dtype_) return false;
-    if (first.groups_ != second.groups_) return false;
-    if (first.data_layout_ != second.data_layout_) return false;
+    if (first.x_dims != second.x_dims) return false;
+    if (first.w_dims != second.w_dims) return false;
+    if (first.strides != second.strides) return false;
+    if (first.paddings != second.paddings) return false;
+    if (first.dilations != second.dilations) return false;
+    if (first.dtype != second.dtype) return false;
+    if (first.groups != second.groups) return false;
+    if (first.data_layout != second.data_layout) return false;
     return true;
   }
@@ -135,7 +138,7 @@ class CudnnAlgorithmsCacheMap {
  public:
   CudnnAlgorithmsCacheMap() : cache_mutex_(new std::mutex()) { hash_.clear(); }
 
-  DnnNode Get(const ConvCacheKey& key) {
+  ConvAutoTuneResult Get(const ConvCacheKey& key) {
     std::lock_guard<std::mutex> lock(*cache_mutex_);
     PADDLE_ENFORCE_NE(
         hash_.find(key),
@@ -163,7 +166,7 @@ class CudnnAlgorithmsCacheMap {
     cache_misses_ = 0;
   }
 
-  void Set(const ConvCacheKey& key, DnnNode algo) {
+  void Set(const ConvCacheKey& key, ConvAutoTuneResult algo) {
     std::lock_guard<std::mutex> lock(*cache_mutex_);
     if (hash_.size() > static_cast<size_t>(FLAGS_search_cache_max_number)) {
       hash_.clear();
@@ -188,7 +191,10 @@ class CudnnAlgorithmsCacheMap {
   int64_t Size() const { return hash_.size(); }
 
  private:
-  std::unordered_map<ConvCacheKey, DnnNode, ConvCacheKeyHash, ConvCacheKeyEqual>
+  std::unordered_map<ConvCacheKey,
+                     ConvAutoTuneResult,
+                     ConvCacheKeyHash,
+                     ConvCacheKeyEqual>
       hash_;
   std::shared_ptr<std::mutex> cache_mutex_;
@@ -293,21 +299,6 @@ class AutoTuneCache {
     return cudnn_auto_tune_map_[static_cast<int64_t>(algo_type)];
   }
 
-  CudnnAlgorithmsCacheMap& GetConvForward() {
-    return cudnn_auto_tune_map_[static_cast<int64_t>(
-        AlgorithmType::kConvForward)];
-  }
-
-  CudnnAlgorithmsCacheMap& GetConvBackwardData() {
-    return cudnn_auto_tune_map_[static_cast<int64_t>(
-        AlgorithmType::kConvBackwardData)];
-  }
-
-  CudnnAlgorithmsCacheMap& GetConvBackwardFilter() {
-    return cudnn_auto_tune_map_[static_cast<int64_t>(
-        AlgorithmType::kConvBackwardFilter)];
-  }
-
   AlgorithmsCacheMap& GetTranspose() { return Get(AlgorithmType::kTranspose); }
 
   void Clean() {
diff --git a/paddle/phi/kernels/autotune/cache_test.cc b/paddle/phi/kernels/autotune/cache_test.cc
index 29affd45f0f5c7bec74f8d4e013bc5950714d4ae..18454ad3e19977d6e98ef5b5613be52137d06940 100644
--- a/paddle/phi/kernels/autotune/cache_test.cc
+++ b/paddle/phi/kernels/autotune/cache_test.cc
@@ -25,7 +25,8 @@ enum ConvAlgos { GEMMKernel = 0, CuDNNKernel_1 = 1, CuDNNKernel_2 = 2 };
 
 TEST(AlgosCache, AlgosCache) {
   auto autotune_cache = phi::autotune::AutoTuneCache::Instance();
-  auto& cache = autotune_cache.GetConvForward();
+  auto& cache =
+      autotune_cache.GetConv(phi::autotune::AlgorithmType::kConvForward);
   std::vector<int64_t> x_shape = {4, 224, 224, 3};
   std::vector<int64_t> w_shape = {32, 3, 3, 3};
@@ -37,7 +38,8 @@ TEST(AlgosCache, AlgosCache) {
   phi::autotune::ConvCacheKey key(
       x_shape, w_shape, paddings, strides, dilations, dtype, 0, 0);
   EXPECT_EQ(cache.Find(key), false);
-  phi::autotune::DnnNode node(static_cast<int64_t>(ConvAlgos::GEMMKernel), 0);
+  phi::autotune::ConvAutoTuneResult node(
+      static_cast<int64_t>(ConvAlgos::GEMMKernel), 0, false);
   cache.Set(key, node);
   EXPECT_EQ(cache.Size(), 1);
   EXPECT_EQ(cache.Find(key), true);
@@ -48,8 +50,8 @@ TEST(AlgosCache, AlgosCache) {
   phi::autotune::ConvCacheKey key1(
       x_shape, w_shape, paddings, strides, dilations, dtype, 0, 1);
   EXPECT_EQ(cache.Find(key1), false);
-  phi::autotune::DnnNode node1(static_cast<int64_t>(ConvAlgos::CuDNNKernel_1),
-                               0);
+  phi::autotune::ConvAutoTuneResult node1(
+      static_cast<int64_t>(ConvAlgos::CuDNNKernel_1), 0, false);
   cache.Set(key1, node1);
   EXPECT_EQ(cache.Size(), 2);
   EXPECT_EQ(cache.CacheHits(), 1);
diff --git a/paddle/phi/kernels/gpudnn/conv_grad_grad_kernel.cu b/paddle/phi/kernels/gpudnn/conv_grad_grad_kernel.cu
index fb9580427e1f4535c7932fefdaf43febe418ff26..e61f58450b34f22cb76342798672cdd06241a067 100644
--- a/paddle/phi/kernels/gpudnn/conv_grad_grad_kernel.cu
+++ b/paddle/phi/kernels/gpudnn/conv_grad_grad_kernel.cu
@@ -336,7 +336,7 @@ void ConvCudnnGradGradKernel(
 #else
     using search1 =
         paddle::operators::SearchAlgorithm<cudnnConvolutionFwdAlgoPerf_t>;
-    fwd_result1 = search1::Find<T>(args1, exhaustive_search, false, ctx);
+    fwd_result1 = search1::Find<T>(ctx, args1, exhaustive_search, false);
     workspace_size = search1::GetWorkspaceSize(args1, fwd_result1.algo);
 #endif
   }
@@ -364,7 +364,7 @@ void ConvCudnnGradGradKernel(
 #else
       using search2 =
           paddle::operators::SearchAlgorithm<cudnnConvolutionFwdAlgoPerf_t>;
-      fwd_result2 = search2::Find<T>(args2, exhaustive_search, false, ctx);
+      fwd_result2 = search2::Find<T>(ctx, args2, exhaustive_search, false);
       workspace_size = std::max(
           workspace_size, search2::GetWorkspaceSize(args2, fwd_result2.algo));
 #endif
@@ -394,7 +394,7 @@ void ConvCudnnGradGradKernel(
       using search3 =
          paddle::operators::SearchAlgorithm<cudnnConvolutionBwdFilterAlgoPerf_t>;
       filter_result =
-          search3::Find<T>(args3, exhaustive_search, deterministic, ctx);
+          search3::Find<T>(ctx, args3, exhaustive_search, deterministic);
       workspace_size = std::max(
           workspace_size, search3::GetWorkspaceSize(args3, filter_result.algo));
 #endif
@@ -424,7 +424,7 @@ void ConvCudnnGradGradKernel(
       using search4 =
          paddle::operators::SearchAlgorithm<cudnnConvolutionBwdDataAlgoPerf_t>;
       data_result =
-          search4::Find<T>(args4, exhaustive_search, deterministic, ctx);
+          search4::Find<T>(ctx, args4, exhaustive_search, deterministic);
       workspace_size = std::max(
           workspace_size, search4::GetWorkspaceSize(args4, data_result.algo));
 #endif
diff --git a/paddle/phi/kernels/gpudnn/conv_grad_kernel.cu b/paddle/phi/kernels/gpudnn/conv_grad_kernel.cu
index bc7a8b4f3784015e3b97c07d4ab4775233eafdde..2d61ec6e62c9ca07b23114123990133eea9b01cd 100644
--- a/paddle/phi/kernels/gpudnn/conv_grad_kernel.cu
+++ b/paddle/phi/kernels/gpudnn/conv_grad_kernel.cu
@@ -373,7 +373,7 @@ void ConvCudnnGradKernel(const Context& ctx,
 #else
     using search1 =
         paddle::operators::SearchAlgorithm<cudnnConvolutionBwdDataAlgoPerf_t>;
-    bwd_result = search1::Find<T>(args1, exhaustive_search, deterministic, ctx);
+    bwd_result = search1::Find<T>(ctx, args1, exhaustive_search, deterministic);
     workspace_size_d = std::max(workspace_size_d, bwd_result.workspace_size);
 #endif
   }
@@ -402,7 +402,7 @@ void ConvCudnnGradKernel(const Context& ctx,
     using search2 =
        paddle::operators::SearchAlgorithm<cudnnConvolutionBwdFilterAlgoPerf_t>;
     filter_result =
-        search2::Find<T>(args2, exhaustive_search, deterministic, ctx);
+        search2::Find<T>(ctx, args2, exhaustive_search, deterministic);
     VLOG(3) << "filter algo: " << filter_result.algo << ", time "
             << filter_result.time;
     workspace_size_w = std::max(workspace_size_w, filter_result.workspace_size);
diff --git a/paddle/phi/kernels/gpudnn/conv_kernel.cu b/paddle/phi/kernels/gpudnn/conv_kernel.cu
index aa591a34a4399c3784f1343fe69fae48845fe280..7a6e8d8148fa150dee3483549926e793ba7dd439 100644
--- a/paddle/phi/kernels/gpudnn/conv_kernel.cu
+++ b/paddle/phi/kernels/gpudnn/conv_kernel.cu
@@ -25,7 +25,6 @@
 #endif
 
 #include "paddle/fluid/platform/cudnn_workspace_helper.h"
-#include "paddle/fluid/platform/float16.h"
 #include "paddle/fluid/platform/profiler.h"
 #include "paddle/phi/common/bfloat16.h"
 #include "paddle/phi/common/float16.h"
@@ -56,8 +55,7 @@ void ConvCudnnKernel(const Context& ctx,
   bool exhaustive_search = FLAGS_cudnn_exhaustive_search || exhaustive_search_t;
   bool deterministic = FLAGS_cudnn_deterministic;
-  auto exhaustive_deterministic = exhaustive_search && deterministic;
-  PADDLE_ENFORCE_EQ(exhaustive_deterministic,
+  PADDLE_ENFORCE_EQ(exhaustive_search && deterministic,
                     false,
                     phi::errors::InvalidArgument(
                         "Cann't set exhaustive_search True and "
@@ -315,7 +313,7 @@ void ConvCudnnKernel(const Context& ctx,
   paddle::operators::SearchResult<cudnnConvolutionFwdAlgo_t> fwd_result;
   using search =
       paddle::operators::SearchAlgorithm<cudnnConvolutionFwdAlgoPerf_t>;
-  fwd_result = search::Find<T>(args, exhaustive_search, deterministic, ctx);
+  fwd_result = search::Find<T>(ctx, args, exhaustive_search, deterministic);
   workspace_size = fwd_result.workspace_size;
 #endif
diff --git a/paddle/phi/kernels/gpudnn/conv_transpose_grad_kernel.cu b/paddle/phi/kernels/gpudnn/conv_transpose_grad_kernel.cu
index 3acb1604f4a610d73b551f2d6086960f0c0d6a04..d05bd58e33080a13585c08016ff700148bb44860 100644
--- a/paddle/phi/kernels/gpudnn/conv_transpose_grad_kernel.cu
+++ b/paddle/phi/kernels/gpudnn/conv_transpose_grad_kernel.cu
@@ -230,7 +230,7 @@ void ConvTransposeGradRawGPUDNNKernel(const Context& ctx,
 #else
     using search1 =
         paddle::operators::SearchAlgorithm<cudnnConvolutionFwdAlgoPerf_t>;
-    fwd_result = search1::Find<T>(args1, false, deterministic, ctx);
+    fwd_result = search1::Find<T>(ctx, args1, false, deterministic, false);
     workspace_size = std::max(
         workspace_size, search1::GetWorkspaceSize(args1, fwd_result.algo));
 #endif
@@ -257,7 +257,7 @@ void ConvTransposeGradRawGPUDNNKernel(const Context& ctx,
 #else
     using search2 =
        paddle::operators::SearchAlgorithm<cudnnConvolutionBwdFilterAlgoPerf_t>;
-    filter_result = search2::Find<T>(args2, false, deterministic, ctx);
+    filter_result = search2::Find<T>(ctx, args2, false, deterministic, false);
     workspace_size = std::max(
         workspace_size, search2::GetWorkspaceSize(args2, filter_result.algo));
 #endif
@@ -710,7 +710,7 @@ void Conv2dTransposeDoubleGradGPUDNNKernel(
 #else
     using search1 =
         paddle::operators::SearchAlgorithm<cudnnConvolutionBwdDataAlgoPerf_t>;
-    bwd_result1 = search1::Find<T>(args1, false, deterministic, ctx);
+    bwd_result1 = search1::Find<T>(ctx, args1, false, deterministic, false);
     workspace_size = search1::GetWorkspaceSize(args1, bwd_result1.algo);
 #endif
@@ -734,7 +734,7 @@ void Conv2dTransposeDoubleGradGPUDNNKernel(
 #else
     using search2 =
         paddle::operators::SearchAlgorithm<cudnnConvolutionBwdDataAlgoPerf_t>;
-    bwd_result2 = search2::Find<T>(args2, false, deterministic, ctx);
+    bwd_result2 = search2::Find<T>(ctx, args2, false, deterministic, false);
     workspace_size = std::max(
         workspace_size, search2::GetWorkspaceSize(args2, bwd_result2.algo));
 #endif
@@ -761,7 +761,7 @@ void Conv2dTransposeDoubleGradGPUDNNKernel(
 #else
     using search3 =
        paddle::operators::SearchAlgorithm<cudnnConvolutionBwdFilterAlgoPerf_t>;
-    filter_result = search3::Find<T>(args3, false, deterministic, ctx);
+    filter_result = search3::Find<T>(ctx, args3, false, deterministic, false);
     workspace_size = std::max(
         workspace_size, search3::GetWorkspaceSize(args3, filter_result.algo));
 #endif
@@ -789,7 +789,7 @@ void Conv2dTransposeDoubleGradGPUDNNKernel(
 #else
     using search4 =
         paddle::operators::SearchAlgorithm<cudnnConvolutionFwdAlgoPerf_t>;
-    fwd_result = search4::Find<T>(args4, false, deterministic, ctx);
+    fwd_result = search4::Find<T>(ctx, args4, false, deterministic, false);
     workspace_size = std::max(
         workspace_size, search4::GetWorkspaceSize(args4, fwd_result.algo));
 #endif
diff --git a/paddle/phi/kernels/gpudnn/conv_transpose_kernel.cu b/paddle/phi/kernels/gpudnn/conv_transpose_kernel.cu
index 6fc1e2eff135206b7baa03bcfb3cd958397fe86d..84332f0ccb892a9a7fed4ff37daf57f2efd4740c 100644
--- a/paddle/phi/kernels/gpudnn/conv_transpose_kernel.cu
+++ b/paddle/phi/kernels/gpudnn/conv_transpose_kernel.cu
@@ -230,7 +230,7 @@ void ConvTransposeRawGPUDNNKernel(const Context& ctx,
   paddle::operators::SearchResult<cudnnConvolutionBwdDataAlgo_t> bwd_result;
   using search =
       paddle::operators::SearchAlgorithm<cudnnConvolutionBwdDataAlgoPerf_t>;
-  bwd_result = search::Find<T>(args, false, deterministic, ctx);
+  bwd_result = search::Find<T>(ctx, args, false, deterministic, false);
   workspace_size =
       std::max(workspace_size, search::GetWorkspaceSize(args, bwd_result.algo));
 #endif
diff --git a/tools/check_op_benchmark_result.py b/tools/check_op_benchmark_result.py
index 73075125ac46b7c5a3974504586f079b6bc035fe..aaf194ff95ec587a115436cae8971f589cec7040 100644
--- a/tools/check_op_benchmark_result.py
+++ b/tools/check_op_benchmark_result.py
@@ -72,15 +72,20 @@ def check_speed_result(case_name, develop_data, pr_data, pr_result):
     """
     pr_gpu_time = pr_data.get("gpu_time")
     develop_gpu_time = develop_data.get("gpu_time")
-    gpu_time_diff = (pr_gpu_time - develop_gpu_time) / develop_gpu_time
+    if develop_gpu_time != 0.0:
+        gpu_time_diff = (pr_gpu_time - develop_gpu_time) / develop_gpu_time
+        gpu_time_diff_str = "{:.5f}".format(gpu_time_diff * 100)
+    else:
+        gpu_time_diff = None
+        gpu_time_diff_str = ""
 
     pr_total_time = pr_data.get("total")
     develop_total_time = develop_data.get("total")
     total_time_diff = (pr_total_time - develop_total_time) / develop_total_time
 
     logging.info("------ OP: %s ------" % case_name)
-    logging.info("GPU time change: %.5f%% (develop: %.7f -> PR: %.7f)" %
-                 (gpu_time_diff * 100, develop_gpu_time, pr_gpu_time))
+    logging.info("GPU time change: %s (develop: %.7f -> PR: %.7f)" %
+                 (gpu_time_diff_str, develop_gpu_time, pr_gpu_time))
     logging.info("Total time change: %.5f%% (develop: %.7f -> PR: %.7f)" %
                  (total_time_diff * 100, develop_total_time, pr_total_time))
     logging.info("backward: %s" % pr_result.get("backward"))
@@ -196,7 +201,8 @@ if __name__ == "__main__":
         args.develop_logs_dir)
     check_path_exists(args.pr_logs_dir)
 
-    for log_file in os.listdir(args.pr_logs_dir):
+    pr_log_files = os.listdir(args.pr_logs_dir)
+    for log_file in sorted(pr_log_files):
        develop_result = develop_result_dict.get(log_file)
        pr_result = parse_log_file(os.path.join(args.pr_logs_dir, log_file))
        if develop_result is None or pr_result is None:
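
Note on the caching policy introduced above: the reworked SearchAlgorithm::Find() records, together with each cached algorithm, whether it came from exhaustive search, so that a heuristic result cached before auto-tune starts can later be overwritten by the auto-tuned one, while an exhaustively searched entry is treated as final. The following is a minimal, self-contained C++14 sketch of just that control flow, not part of the patch: CachedAlgo and PickAlgo are hypothetical stand-ins for ConvAutoTuneResult and Find(), and the heuristic/exhaustive searches are replaced by callbacks.

#include <cstddef>
#include <cstdint>
#include <iostream>
#include <string>
#include <unordered_map>

struct CachedAlgo {
  int64_t algo = 0;
  size_t workspace_size = 0;
  bool exhaustive_search = false;  // mirrors the new ConvAutoTuneResult field
};

// Returns the algorithm to use for `key`, updating `cache` the way the
// patched Find() does.
CachedAlgo PickAlgo(std::unordered_map<std::string, CachedAlgo>* cache,
                    const std::string& key,
                    bool exhaustive_search,
                    bool use_autotune,
                    CachedAlgo (*heuristic)(),
                    CachedAlgo (*exhaustive)()) {
  CachedAlgo result;
  auto it = cache->find(key);
  const bool find_in_cache = (it != cache->end());
  if (find_in_cache) {
    result = it->second;
  }
  // An entry produced by exhaustive search is final; a heuristic entry may
  // still be overwritten once exhaustive search or auto-tune kicks in.
  if (!result.exhaustive_search) {
    bool need_update_cache = false;
    if (exhaustive_search || use_autotune) {
      result = exhaustive();
      result.exhaustive_search = true;
      need_update_cache = true;
    } else if (!find_in_cache) {
      result = heuristic();
      need_update_cache = true;
    }
    if (need_update_cache) {
      (*cache)[key] = result;
    }
  }
  return result;
}

int main() {
  std::unordered_map<std::string, CachedAlgo> cache;
  auto heuristic = []() { return CachedAlgo{1, 4096, false}; };
  auto exhaustive = []() { return CachedAlgo{7, 65536, false}; };

  // Before auto-tune starts: the heuristic result is cached, marked non-final.
  PickAlgo(&cache, "conv_fwd/4x224x224x3", false, false, heuristic, exhaustive);
  // Once auto-tune is active: exhaustive search overwrites the cached entry.
  CachedAlgo chosen =
      PickAlgo(&cache, "conv_fwd/4x224x224x3", false, true, heuristic, exhaustive);
  std::cout << "algo=" << chosen.algo
            << ", exhaustive=" << chosen.exhaustive_search << std::endl;
  return 0;
}

Treating exhaustive results as final is also why conv2d_transpose can pass enable_autotune = false: its cache entries then never get promoted by auto-tune, which avoids the wrong-result algorithms mentioned in the comment added to Find().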
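
For reference, cache.h keys the per-operator cache on ConvCacheKey through the ConvCacheKeyHash and ConvCacheKeyEqual functors shown in the hunks above. The sketch below shows the same std::unordered_map-with-custom-hash pattern in isolation; Key, KeyHash, and KeyEqual are made-up stand-ins, and the hash mixing is a generic hash_combine rather than Paddle's actual GetKey().

#include <cstddef>
#include <cstdint>
#include <functional>
#include <unordered_map>
#include <vector>

struct Key {
  std::vector<int64_t> x_dims;
  int groups = 1;
};

struct KeyHash {
  std::size_t operator()(const Key& k) const {
    std::size_t seed = std::hash<int>()(k.groups);
    for (int64_t d : k.x_dims) {
      // hash_combine-style mixing of each dimension into the seed.
      seed ^= std::hash<int64_t>()(d) + 0x9e3779b9 + (seed << 6) + (seed >> 2);
    }
    return seed;
  }
};

struct KeyEqual {
  bool operator()(const Key& a, const Key& b) const {
    return a.x_dims == b.x_dims && a.groups == b.groups;
  }
};

int main() {
  std::unordered_map<Key, int, KeyHash, KeyEqual> algo_cache;
  algo_cache[Key{{4, 224, 224, 3}, 1}] = 7;  // cache an algo id for this shape
  return algo_cache.count(Key{{4, 224, 224, 3}, 1}) == 1 ? 0 : 1;
}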