Unverified commit 3bc4b850, authored by Yiqun Liu, committed by GitHub

Enable recording whether the conv algo was obtained by exhaustive search, to fix an autotune cache bug. (#47065)
Parent commit: af4bdede
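In short: every cached conv algorithm now records whether it came from exhaustive search. A cache hit that was only produced by the heuristic no longer blocks a later exhaustive or auto-tuned search from overwriting it, which is the autotune cache bug this commit fixes. A minimal sketch of that policy follows (ConvAutoTuneResult mirrors the struct added in cache.h further down; ShouldSearchAgain is a hypothetical helper written here for illustration only):

#include <cstddef>
#include <cstdint>

// Sketch of the cache entry introduced by this commit (see cache.h below).
struct ConvAutoTuneResult {
  int64_t algo = 0;
  size_t workspace_size = 0;
  bool exhaustive_search = false;  // new: how the algorithm was obtained
};

// Hypothetical helper, not part of the PR: decides whether Find() still needs
// to run a search even though the key was found in the cache.
bool ShouldSearchAgain(bool found_in_cache, const ConvAutoTuneResult& cached,
                       bool exhaustive_search, bool use_autotune) {
  if (found_in_cache && cached.exhaustive_search) {
    return false;  // an exhaustively searched entry is kept as-is
  }
  // A heuristic (or missing) entry may be replaced by exhaustive/auto-tune search.
  return exhaustive_search || use_autotune || !found_in_cache;
}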
......@@ -40,9 +40,6 @@ using ScalingParamType = typename platform::CudnnDataType<T>::ScalingParamType;
template <typename AlgoT>
struct SearchResult {
SearchResult() {}
explicit SearchResult(const phi::autotune::DnnNode& node)
: algo(static_cast<AlgoT>(node.algo)),
workspace_size(node.workspace_size) {}
explicit SearchResult(AlgoT a) : algo(a) {}
explicit SearchResult(AlgoT a, float t, size_t size)
......@@ -51,12 +48,21 @@ struct SearchResult {
AlgoT algo = static_cast<AlgoT>(0);
float time = -1.f;
size_t workspace_size = 0;
bool exhaustive_search = false;
};
template <typename T>
static std::ostream& operator<<(std::ostream& out, const std::vector<T>& v) {
out << "[";
for (auto const& tmp : v) out << tmp << ",";
bool is_first = true;
for (auto const& tmp : v) {
if (is_first) {
out << tmp;
is_first = false;
} else {
out << ", " << tmp;
}
}
out << "]";
return out;
}
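A tiny usage sketch of the reworked vector printer above (made-up shape values; assumes it is compiled in a scope where this operator<< is visible):

#include <cstdint>
#include <sstream>
#include <vector>

void PrintDimsExample() {
  std::vector<int64_t> x_dims = {4, 224, 224, 3};
  std::ostringstream os;
  os << x_dims;  // uses the operator<< defined above
  // old loop: os.str() == "[4,224,224,3,]"   (trailing comma)
  // new loop: os.str() == "[4, 224, 224, 3]" (comma-space, no trailing comma)
}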
......@@ -109,7 +115,7 @@ struct ConvArgsBase {
auto w_shape = phi::vectorize(w->dims());
VLOG(10) << "[ConvArgs] x_dims=" << x_shape << ", w_dims=" << w_shape
<< ", strides=" << s << ", paddings=" << p << ", dilations=" << d
<< ",data= " << paddle::experimental::CppTypeToDataType<T>::Type()
<< ", data=" << paddle::experimental::CppTypeToDataType<T>::Type()
<< ", group=" << group
<< ", data layout=" << static_cast<int64_t>(data_layout);
......
......@@ -14,12 +14,11 @@ limitations under the License. */
#pragma once
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/operators/conv_base_helper.h"
#include "paddle/fluid/platform/cuda_graph_with_memory_pool.h"
#include "paddle/fluid/platform/device/gpu/gpu_info.h"
#include "paddle/fluid/platform/profiler.h"
#include "paddle/phi/kernels/autotune/switch_autotune.h"
#include "paddle/phi/kernels/funcs/eigen/common.h"
#include "paddle/phi/kernels/funcs/eigen/eigen_function.h"
namespace paddle {
......@@ -53,11 +52,9 @@ static void RemovePaddingSlice(const phi::GPUContext& context,
}
auto in_t =
framework::EigenTensor<T, D, Eigen::RowMajor, Eigen::DenseIndex>::From(
*input);
auto out_t =
framework::EigenTensor<T, D, Eigen::RowMajor, Eigen::DenseIndex>::From(
*out, new_out_dims);
phi::EigenTensor<T, D, Eigen::RowMajor, Eigen::DenseIndex>::From(*input);
auto out_t = phi::EigenTensor<T, D, Eigen::RowMajor, Eigen::DenseIndex>::From(
*out, new_out_dims);
phi::funcs::EigenSlice<std::decay_t<decltype(place)>, T, D>::Eval(
place, out_t, in_t, offsets, extents);
......@@ -161,6 +158,8 @@ struct SearchAlgorithmBase<cudnnConvolutionFwdAlgoPerf_t> {
constexpr static phi::autotune::AlgorithmType kAlgoType =
phi::autotune::AlgorithmType::kConvForward;
static const std::string GetPerfName() { return "ConvForward"; }
static size_t GetWorkspaceSize(const ConvArgs& args,
cudnnConvolutionFwdAlgo_t algo) {
size_t workspace_size = 0;
......@@ -334,6 +333,8 @@ struct SearchAlgorithmBase<cudnnConvolutionBwdDataAlgoPerf_t> {
constexpr static phi::autotune::AlgorithmType kAlgoType =
phi::autotune::AlgorithmType::kConvBackwardData;
static const std::string GetPerfName() { return "ConvBackwardData"; }
static size_t GetWorkspaceSize(const ConvArgs& args,
cudnnConvolutionBwdDataAlgo_t algo) {
size_t workspace_size = 0;
......@@ -514,6 +515,8 @@ struct SearchAlgorithmBase<cudnnConvolutionBwdFilterAlgoPerf_t> {
constexpr static phi::autotune::AlgorithmType kAlgoType =
phi::autotune::AlgorithmType::kConvBackwardFilter;
static const std::string GetPerfName() { return "ConvBackwardFilter"; }
static size_t GetWorkspaceSize(const ConvArgs& args,
cudnnConvolutionBwdFilterAlgo_t algo) {
platform::CUDAGraphCaptureModeGuard guard;
......@@ -752,11 +755,13 @@ struct SearchAlgorithm : public SearchAlgorithmBase<PerfT> {
using AlgoT = typename SearchAlgorithmBase<PerfT>::AlgoT;
template <typename T>
static SearchResult<AlgoT> Find(const ConvArgs& args,
static SearchResult<AlgoT> Find(const phi::GPUContext& ctx,
const ConvArgs& args,
bool exhaustive_search,
bool deterministic,
const phi::GPUContext& ctx) {
bool enable_autotune = true) {
SearchResult<AlgoT> result;
bool use_autotune = false;
auto dtype = platform::CudnnDataType<T>::type;
SetConvMathType(ctx, dtype, args.cdesc);
......@@ -764,33 +769,50 @@ struct SearchAlgorithm : public SearchAlgorithmBase<PerfT> {
result = SearchAlgorithmBase<PerfT>::FindAlgoDeterministic(args);
} else {
// 1. Once turning on exhaustive FLAGS, always get exhaustive_search.
// 2. Once turning on auto-tune, runn heuristic search(default) before
// 2. Once turning on auto-tune, run heuristic (default) before
// auto-tune process, run exhaustive_search during mentioned process.
// Auto tune is only enabled between specified range.
// 3. After auto-tune process, run cached algorithm if cached, run
// default mode for the rest.
auto key = args.Convert2ConvCacheKey<T>();
auto& cache = phi::autotune::AutoTuneCache::Instance().GetConv(
SearchAlgorithmBase<PerfT>::kAlgoType);
if (cache.Find(key)) {
bool find_in_cache = cache.Find(key);
if (find_in_cache) {
auto t = cache.Get(key);
result.algo = static_cast<AlgoT>(t.algo);
result.workspace_size = t.workspace_size;
} else {
bool use_autotune =
phi::autotune::AutoTuneStatus::Instance().UseAutoTune();
result.exhaustive_search = t.exhaustive_search;
}
if (!result.exhaustive_search) {
bool need_update_cache = false;
// In conv2d_transpose, enable_autotune is set to false because some
// algorithms picked by the exhaustive search method produce wrong results.
use_autotune = enable_autotune &&
phi::autotune::AutoTuneStatus::Instance().UseAutoTune();
if (exhaustive_search || use_autotune) {
// Once autotune is enabled, the autotuned result can rewrite the
// previous result in cache found by heuristic method.
result =
SearchAlgorithmBase<PerfT>::template FindAlgoExhaustiveSearch<T>(
args, ctx);
} else {
need_update_cache = true;
} else if (!find_in_cache) {
result = SearchAlgorithmBase<PerfT>::FindAlgoHeuristic(args, ctx);
need_update_cache = true;
}
if (need_update_cache) {
phi::autotune::ConvAutoTuneResult node(
static_cast<int64_t>(result.algo),
result.workspace_size,
exhaustive_search || use_autotune);
cache.Set(key, node);
}
phi::autotune::DnnNode node(static_cast<int64_t>(result.algo),
result.workspace_size);
cache.Set(key, node);
}
}
VLOG(3) << "[cuDNN Convoltion] exhaustive_search=" << exhaustive_search
VLOG(3) << "[cuDNN " << SearchAlgorithmBase<PerfT>::GetPerfName()
<< "] exhaustive_search=" << exhaustive_search
<< ", use_autotune=" << use_autotune
<< ", deterministic=" << deterministic
<< ", choose algo=" << result.algo
<< ", workspace=" << ToMegaBytes(result.workspace_size) << " MB";
......
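Read back-to-back, the decision logic that Find() ends up with after this change is the following (a condensed paraphrase of the hunk above, not standalone code: cache, key, result, args, ctx and the search helpers are those of the surrounding function; dtype setup, the deterministic branch and logging are elided):

bool find_in_cache = cache.Find(key);
if (find_in_cache) {
  auto t = cache.Get(key);
  result.algo = static_cast<AlgoT>(t.algo);
  result.workspace_size = t.workspace_size;
  result.exhaustive_search = t.exhaustive_search;  // new: remember how it was found
}
if (!result.exhaustive_search) {
  bool need_update_cache = false;
  if (exhaustive_search || use_autotune) {
    // May overwrite a heuristic entry that is already cached.
    result = SearchAlgorithmBase<PerfT>::template FindAlgoExhaustiveSearch<T>(args, ctx);
    need_update_cache = true;
  } else if (!find_in_cache) {
    result = SearchAlgorithmBase<PerfT>::FindAlgoHeuristic(args, ctx);
    need_update_cache = true;
  }
  if (need_update_cache) {
    cache.Set(key,
              phi::autotune::ConvAutoTuneResult(static_cast<int64_t>(result.algo),
                                                result.workspace_size,
                                                exhaustive_search || use_autotune));
  }
}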
......@@ -56,12 +56,14 @@ struct hash<std::vector<T>> {
namespace phi {
namespace autotune {
struct DnnNode {
DnnNode() {}
explicit DnnNode(int64_t a, size_t size) : algo(a), workspace_size(size) {}
struct ConvAutoTuneResult {
ConvAutoTuneResult() {}
ConvAutoTuneResult(int64_t a, size_t size, bool search)
: algo(a), workspace_size(size), exhaustive_search(search) {}
int64_t algo;
size_t workspace_size = 0;
bool exhaustive_search = false;
};
template <typename... Args>
......@@ -73,40 +75,41 @@ size_t GetKey(Args&&... args) {
struct ConvCacheKey {
ConvCacheKey() {}
explicit ConvCacheKey(const std::vector<int64_t>& x_dims,
const std::vector<int64_t>& w_dims,
const std::vector<int>& strides,
const std::vector<int>& paddings,
const std::vector<int>& dilations,
phi::DataType dtype,
int groups,
int64_t data_layout)
: x_dims_(x_dims),
w_dims_(w_dims),
strides_(strides),
paddings_(paddings),
dilations_(dilations),
dtype_(dtype),
groups_(groups),
data_layout_(data_layout) {}
ConvCacheKey(const std::vector<int64_t>& arg_x_dims,
const std::vector<int64_t>& arg_w_dims,
const std::vector<int>& arg_strides,
const std::vector<int>& arg_paddings,
const std::vector<int>& arg_dilations,
phi::DataType arg_dtype,
int arg_groups,
int64_t arg_data_layout)
: x_dims(arg_x_dims),
w_dims(arg_w_dims),
strides(arg_strides),
paddings(arg_paddings),
dilations(arg_dilations),
dtype(arg_dtype),
groups(arg_groups),
data_layout(arg_data_layout) {}
size_t hash_value() const {
return GetKey(x_dims_,
w_dims_,
strides_,
paddings_,
dilations_,
static_cast<int64_t>(dtype_),
groups_,
data_layout_);
return GetKey(x_dims,
w_dims,
strides,
paddings,
dilations,
static_cast<int64_t>(dtype),
groups,
data_layout);
}
std::vector<int64_t> x_dims_;
std::vector<int64_t> w_dims_;
std::vector<int> strides_;
std::vector<int> paddings_;
std::vector<int> dilations_;
phi::DataType dtype_;
int groups_;
int64_t data_layout_;
std::vector<int64_t> x_dims;
std::vector<int64_t> w_dims;
std::vector<int> strides;
std::vector<int> paddings;
std::vector<int> dilations;
phi::DataType dtype;
int groups;
int64_t data_layout;
};
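As a small usage sketch of the key type above (made-up values; assumes the autotune cache header is included and that phi::DataType::FLOAT32 is the dtype spelling):

#include <cstdint>
#include <vector>

void ConvCacheKeyExample() {
  // Identical shapes, strides, paddings and dilations; only data_layout differs.
  phi::autotune::ConvCacheKey key_a(
      {4, 3, 224, 224}, {32, 3, 3, 3}, {1, 1}, {0, 0}, {1, 1},
      phi::DataType::FLOAT32, /*groups=*/1, /*data_layout=*/0);
  phi::autotune::ConvCacheKey key_b(
      {4, 3, 224, 224}, {32, 3, 3, 3}, {1, 1}, {0, 0}, {1, 1},
      phi::DataType::FLOAT32, /*groups=*/1, /*data_layout=*/1);

  size_t hash_a = key_a.hash_value();  // folds all (now underscore-free) members via GetKey
  bool equal = phi::autotune::ConvCacheKeyEqual()(key_a, key_b);  // false: layouts differ
  (void)hash_a;
  (void)equal;
}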
struct ConvCacheKeyHash {
......@@ -118,14 +121,14 @@ struct ConvCacheKeyHash {
struct ConvCacheKeyEqual {
size_t operator()(const ConvCacheKey& first,
const ConvCacheKey& second) const {
if (first.x_dims_ != second.x_dims_) return false;
if (first.w_dims_ != second.w_dims_) return false;
if (first.strides_ != second.strides_) return false;
if (first.paddings_ != second.paddings_) return false;
if (first.dilations_ != second.dilations_) return false;
if (first.dtype_ != second.dtype_) return false;
if (first.groups_ != second.groups_) return false;
if (first.data_layout_ != second.data_layout_) return false;
if (first.x_dims != second.x_dims) return false;
if (first.w_dims != second.w_dims) return false;
if (first.strides != second.strides) return false;
if (first.paddings != second.paddings) return false;
if (first.dilations != second.dilations) return false;
if (first.dtype != second.dtype) return false;
if (first.groups != second.groups) return false;
if (first.data_layout != second.data_layout) return false;
return true;
}
......@@ -135,7 +138,7 @@ class CudnnAlgorithmsCacheMap {
public:
CudnnAlgorithmsCacheMap() : cache_mutex_(new std::mutex()) { hash_.clear(); }
DnnNode Get(const ConvCacheKey& key) {
ConvAutoTuneResult Get(const ConvCacheKey& key) {
std::lock_guard<std::mutex> lock(*cache_mutex_);
PADDLE_ENFORCE_NE(
hash_.find(key),
......@@ -163,7 +166,7 @@ class CudnnAlgorithmsCacheMap {
cache_misses_ = 0;
}
void Set(const ConvCacheKey& key, DnnNode algo) {
void Set(const ConvCacheKey& key, ConvAutoTuneResult algo) {
std::lock_guard<std::mutex> lock(*cache_mutex_);
if (hash_.size() > static_cast<size_t>(FLAGS_search_cache_max_number)) {
hash_.clear();
......@@ -188,7 +191,10 @@ class CudnnAlgorithmsCacheMap {
int64_t Size() const { return hash_.size(); }
private:
std::unordered_map<ConvCacheKey, DnnNode, ConvCacheKeyHash, ConvCacheKeyEqual>
std::unordered_map<ConvCacheKey,
ConvAutoTuneResult,
ConvCacheKeyHash,
ConvCacheKeyEqual>
hash_;
std::shared_ptr<std::mutex> cache_mutex_;
......@@ -293,21 +299,6 @@ class AutoTuneCache {
return cudnn_auto_tune_map_[static_cast<int64_t>(algo_type)];
}
CudnnAlgorithmsCacheMap& GetConvForward() {
return cudnn_auto_tune_map_[static_cast<int64_t>(
AlgorithmType::kConvForward)];
}
CudnnAlgorithmsCacheMap& GetConvBackwardData() {
return cudnn_auto_tune_map_[static_cast<int64_t>(
AlgorithmType::kConvBackwardData)];
}
CudnnAlgorithmsCacheMap& GetConvBackwardFilter() {
return cudnn_auto_tune_map_[static_cast<int64_t>(
AlgorithmType::kConvBackwardFilter)];
}
AlgorithmsCacheMap& GetTranspose() { return Get(AlgorithmType::kTranspose); }
void Clean() {
......
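With the per-direction getters above removed, callers migrate to the parameterized accessor; a short sketch (mirroring the test change in the next hunk, and assuming Instance() returns a reference as it is used elsewhere in this diff):

auto& autotune_cache = phi::autotune::AutoTuneCache::Instance();
// Was: auto& cache = autotune_cache.GetConvForward();
auto& cache = autotune_cache.GetConv(phi::autotune::AlgorithmType::kConvForward);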
......@@ -25,7 +25,8 @@ enum ConvAlgos { GEMMKernel = 0, CuDNNKernel_1 = 1, CuDNNKernel_2 = 2 };
TEST(AlgosCache, AlgosCache) {
auto autotune_cache = phi::autotune::AutoTuneCache::Instance();
auto& cache = autotune_cache.GetConvForward();
auto& cache =
autotune_cache.GetConv(phi::autotune::AlgorithmType::kConvForward);
std::vector<int64_t> x_shape = {4, 224, 224, 3};
std::vector<int64_t> w_shape = {32, 3, 3, 3};
......@@ -37,7 +38,8 @@ TEST(AlgosCache, AlgosCache) {
phi::autotune::ConvCacheKey key(
x_shape, w_shape, paddings, strides, dilations, dtype, 0, 0);
EXPECT_EQ(cache.Find(key), false);
phi::autotune::DnnNode node(static_cast<int64_t>(ConvAlgos::GEMMKernel), 0);
phi::autotune::ConvAutoTuneResult node(
static_cast<int64_t>(ConvAlgos::GEMMKernel), 0, false);
cache.Set(key, node);
EXPECT_EQ(cache.Size(), 1);
EXPECT_EQ(cache.Find(key), true);
......@@ -48,8 +50,8 @@ TEST(AlgosCache, AlgosCache) {
phi::autotune::ConvCacheKey key1(
x_shape, w_shape, paddings, strides, dilations, dtype, 0, 1);
EXPECT_EQ(cache.Find(key1), false);
phi::autotune::DnnNode node1(static_cast<int64_t>(ConvAlgos::CuDNNKernel_1),
0);
phi::autotune::ConvAutoTuneResult node1(
static_cast<int64_t>(ConvAlgos::CuDNNKernel_1), 0, false);
cache.Set(key1, node1);
EXPECT_EQ(cache.Size(), 2);
EXPECT_EQ(cache.CacheHits(), 1);
......
......@@ -336,7 +336,7 @@ void ConvCudnnGradGradKernel(
#else
using search1 =
paddle::operators::SearchAlgorithm<cudnnConvolutionFwdAlgoPerf_t>;
fwd_result1 = search1::Find<T>(args1, exhaustive_search, false, ctx);
fwd_result1 = search1::Find<T>(ctx, args1, exhaustive_search, false);
workspace_size = search1::GetWorkspaceSize(args1, fwd_result1.algo);
#endif
}
......@@ -364,7 +364,7 @@ void ConvCudnnGradGradKernel(
#else
using search2 =
paddle::operators::SearchAlgorithm<cudnnConvolutionFwdAlgoPerf_t>;
fwd_result2 = search2::Find<T>(args2, exhaustive_search, false, ctx);
fwd_result2 = search2::Find<T>(ctx, args2, exhaustive_search, false);
workspace_size = std::max(
workspace_size, search2::GetWorkspaceSize(args2, fwd_result2.algo));
#endif
......@@ -394,7 +394,7 @@ void ConvCudnnGradGradKernel(
using search3 =
paddle::operators::SearchAlgorithm<cudnnConvolutionBwdFilterAlgoPerf_t>;
filter_result =
search3::Find<T>(args3, exhaustive_search, deterministic, ctx);
search3::Find<T>(ctx, args3, exhaustive_search, deterministic);
workspace_size = std::max(
workspace_size, search3::GetWorkspaceSize(args3, filter_result.algo));
#endif
......@@ -424,7 +424,7 @@ void ConvCudnnGradGradKernel(
using search4 =
paddle::operators::SearchAlgorithm<cudnnConvolutionBwdDataAlgoPerf_t>;
data_result =
search4::Find<T>(args4, exhaustive_search, deterministic, ctx);
search4::Find<T>(ctx, args4, exhaustive_search, deterministic);
workspace_size = std::max(
workspace_size, search4::GetWorkspaceSize(args4, data_result.algo));
#endif
......
......@@ -373,7 +373,7 @@ void ConvCudnnGradKernel(const Context& ctx,
#else
using search1 =
paddle::operators::SearchAlgorithm<cudnnConvolutionBwdDataAlgoPerf_t>;
bwd_result = search1::Find<T>(args1, exhaustive_search, deterministic, ctx);
bwd_result = search1::Find<T>(ctx, args1, exhaustive_search, deterministic);
workspace_size_d = std::max(workspace_size_d, bwd_result.workspace_size);
#endif
}
......@@ -402,7 +402,7 @@ void ConvCudnnGradKernel(const Context& ctx,
using search2 =
paddle::operators::SearchAlgorithm<cudnnConvolutionBwdFilterAlgoPerf_t>;
filter_result =
search2::Find<T>(args2, exhaustive_search, deterministic, ctx);
search2::Find<T>(ctx, args2, exhaustive_search, deterministic);
VLOG(3) << "filter algo: " << filter_result.algo << ", time "
<< filter_result.time;
workspace_size_w = std::max(workspace_size_w, filter_result.workspace_size);
......
......@@ -25,7 +25,6 @@
#endif
#include "paddle/fluid/platform/cudnn_workspace_helper.h"
#include "paddle/fluid/platform/float16.h"
#include "paddle/fluid/platform/profiler.h"
#include "paddle/phi/common/bfloat16.h"
#include "paddle/phi/common/float16.h"
......@@ -56,8 +55,7 @@ void ConvCudnnKernel(const Context& ctx,
bool exhaustive_search = FLAGS_cudnn_exhaustive_search || exhaustive_search_t;
bool deterministic = FLAGS_cudnn_deterministic;
auto exhaustive_deterministic = exhaustive_search && deterministic;
PADDLE_ENFORCE_EQ(exhaustive_deterministic,
PADDLE_ENFORCE_EQ(exhaustive_search && deterministic,
false,
phi::errors::InvalidArgument(
"Cann't set exhaustive_search True and "
......@@ -315,7 +313,7 @@ void ConvCudnnKernel(const Context& ctx,
paddle::operators::SearchResult<cudnnConvolutionFwdAlgo_t> fwd_result;
using search =
paddle::operators::SearchAlgorithm<cudnnConvolutionFwdAlgoPerf_t>;
fwd_result = search::Find<T>(args, exhaustive_search, deterministic, ctx);
fwd_result = search::Find<T>(ctx, args, exhaustive_search, deterministic);
workspace_size = fwd_result.workspace_size;
#endif
......
......@@ -230,7 +230,7 @@ void ConvTransposeGradRawGPUDNNKernel(const Context& ctx,
#else
using search1 =
paddle::operators::SearchAlgorithm<cudnnConvolutionFwdAlgoPerf_t>;
fwd_result = search1::Find<T>(args1, false, deterministic, ctx);
fwd_result = search1::Find<T>(ctx, args1, false, deterministic, false);
workspace_size = std::max(
workspace_size, search1::GetWorkspaceSize(args1, fwd_result.algo));
#endif
......@@ -257,7 +257,7 @@ void ConvTransposeGradRawGPUDNNKernel(const Context& ctx,
#else
using search2 =
paddle::operators::SearchAlgorithm<cudnnConvolutionBwdFilterAlgoPerf_t>;
filter_result = search2::Find<T>(args2, false, deterministic, ctx);
filter_result = search2::Find<T>(ctx, args2, false, deterministic, false);
workspace_size = std::max(
workspace_size, search2::GetWorkspaceSize(args2, filter_result.algo));
#endif
......@@ -710,7 +710,7 @@ void Conv2dTransposeDoubleGradGPUDNNKernel(
#else
using search1 =
paddle::operators::SearchAlgorithm<cudnnConvolutionBwdDataAlgoPerf_t>;
bwd_result1 = search1::Find<T>(args1, false, deterministic, ctx);
bwd_result1 = search1::Find<T>(ctx, args1, false, deterministic, false);
workspace_size = search1::GetWorkspaceSize(args1, bwd_result1.algo);
#endif
......@@ -734,7 +734,7 @@ void Conv2dTransposeDoubleGradGPUDNNKernel(
#else
using search2 =
paddle::operators::SearchAlgorithm<cudnnConvolutionBwdDataAlgoPerf_t>;
bwd_result2 = search2::Find<T>(args2, false, deterministic, ctx);
bwd_result2 = search2::Find<T>(ctx, args2, false, deterministic, false);
workspace_size = std::max(
workspace_size, search2::GetWorkspaceSize(args2, bwd_result2.algo));
#endif
......@@ -761,7 +761,7 @@ void Conv2dTransposeDoubleGradGPUDNNKernel(
#else
using search3 =
paddle::operators::SearchAlgorithm<cudnnConvolutionBwdFilterAlgoPerf_t>;
filter_result = search3::Find<T>(args3, false, deterministic, ctx);
filter_result = search3::Find<T>(ctx, args3, false, deterministic, false);
workspace_size = std::max(
workspace_size, search3::GetWorkspaceSize(args3, filter_result.algo));
#endif
......@@ -789,7 +789,7 @@ void Conv2dTransposeDoubleGradGPUDNNKernel(
#else
using search4 =
paddle::operators::SearchAlgorithm<cudnnConvolutionFwdAlgoPerf_t>;
fwd_result = search4::Find<T>(args4, false, deterministic, ctx);
fwd_result = search4::Find<T>(ctx, args4, false, deterministic, false);
workspace_size = std::max(
workspace_size, search4::GetWorkspaceSize(args4, fwd_result.algo));
#endif
......
......@@ -230,7 +230,7 @@ void ConvTransposeRawGPUDNNKernel(const Context& ctx,
paddle::operators::SearchResult<cudnnConvolutionBwdDataAlgo_t> bwd_result;
using search =
paddle::operators::SearchAlgorithm<cudnnConvolutionBwdDataAlgoPerf_t>;
bwd_result = search::Find<T>(args, false, deterministic, ctx);
bwd_result = search::Find<T>(ctx, args, false, deterministic, false);
workspace_size =
std::max(workspace_size, search::GetWorkspaceSize(args, bwd_result.algo));
#endif
......
......@@ -72,15 +72,20 @@ def check_speed_result(case_name, develop_data, pr_data, pr_result):
"""
pr_gpu_time = pr_data.get("gpu_time")
develop_gpu_time = develop_data.get("gpu_time")
gpu_time_diff = (pr_gpu_time - develop_gpu_time) / develop_gpu_time
if develop_gpu_time != 0.0:
gpu_time_diff = (pr_gpu_time - develop_gpu_time) / develop_gpu_time
gpu_time_diff_str = "{:.5f}".format(gpu_time_diff * 100)
else:
gpu_time_diff = None
gpu_time_diff_str = ""
pr_total_time = pr_data.get("total")
develop_total_time = develop_data.get("total")
total_time_diff = (pr_total_time - develop_total_time) / develop_total_time
logging.info("------ OP: %s ------" % case_name)
logging.info("GPU time change: %.5f%% (develop: %.7f -> PR: %.7f)" %
(gpu_time_diff * 100, develop_gpu_time, pr_gpu_time))
logging.info("GPU time change: %s (develop: %.7f -> PR: %.7f)" %
(gpu_time_diff_str, develop_gpu_time, pr_gpu_time))
logging.info("Total time change: %.5f%% (develop: %.7f -> PR: %.7f)" %
(total_time_diff * 100, develop_total_time, pr_total_time))
logging.info("backward: %s" % pr_result.get("backward"))
......@@ -196,7 +201,8 @@ if __name__ == "__main__":
args.develop_logs_dir)
check_path_exists(args.pr_logs_dir)
for log_file in os.listdir(args.pr_logs_dir):
pr_log_files = os.listdir(args.pr_logs_dir)
for log_file in sorted(pr_log_files):
develop_result = develop_result_dict.get(log_file)
pr_result = parse_log_file(os.path.join(args.pr_logs_dir, log_file))
if develop_result is None or pr_result is None:
......