未验证 提交 ceb83562 编写于 作者: W Wilber 提交者: GitHub

fix conv_fusion in multi thread. (#55374)

上级 ce8455c0
...@@ -130,7 +130,6 @@ class CudnnConvDescManager { ...@@ -130,7 +130,6 @@ class CudnnConvDescManager {
XXH64_hash_t hash_key = XXH64_digest(state); XXH64_hash_t hash_key = XXH64_digest(state);
XXH64_freeState(state); XXH64_freeState(state);
if (!cudnn_conv_cache_.count(hash_key)) {
std::lock_guard<std::mutex> lock(cache_mutex_); std::lock_guard<std::mutex> lock(cache_mutex_);
if (!cudnn_conv_cache_.count(hash_key)) { if (!cudnn_conv_cache_.count(hash_key)) {
cudnn_conv_cache_[hash_key] = CudnnCacheInfo(); cudnn_conv_cache_[hash_key] = CudnnCacheInfo();
...@@ -158,7 +157,6 @@ class CudnnConvDescManager { ...@@ -158,7 +157,6 @@ class CudnnConvDescManager {
cudnn_conv_cache_[hash_key].workspace_size = workspace_size; cudnn_conv_cache_[hash_key].workspace_size = workspace_size;
cudnn_conv_cache_[hash_key].algo = algo; cudnn_conv_cache_[hash_key].algo = algo;
} }
}
return &cudnn_conv_cache_.at(hash_key); return &cudnn_conv_cache_.at(hash_key);
} }
...@@ -199,7 +197,6 @@ class CudnnConvDescManager { ...@@ -199,7 +197,6 @@ class CudnnConvDescManager {
XXH64_hash_t hash_key = XXH64_digest(state); XXH64_hash_t hash_key = XXH64_digest(state);
XXH64_freeState(state); XXH64_freeState(state);
if (!conv_attr_cache_.count(hash_key)) {
std::lock_guard<std::mutex> lock(attr_mutex_); std::lock_guard<std::mutex> lock(attr_mutex_);
if (!conv_attr_cache_.count(hash_key)) { if (!conv_attr_cache_.count(hash_key)) {
ConvAttrCacheInfo cache; ConvAttrCacheInfo cache;
...@@ -254,12 +251,10 @@ class CudnnConvDescManager { ...@@ -254,12 +251,10 @@ class CudnnConvDescManager {
} }
if (format == CUDNN_TENSOR_NCHW) { if (format == CUDNN_TENSOR_NCHW) {
input_pad[2 * i + 4] = paddings[2 * i] - padding_common[i]; input_pad[2 * i + 4] = paddings[2 * i] - padding_common[i];
input_pad[2 * i + 4 + 1] = input_pad[2 * i + 4 + 1] = paddings[2 * i + 1] - padding_common[i];
paddings[2 * i + 1] - padding_common[i];
} else { } else {
input_pad[2 * i + 2] = paddings[2 * i] - padding_common[i]; input_pad[2 * i + 2] = paddings[2 * i] - padding_common[i];
input_pad[2 * i + 2 + 1] = input_pad[2 * i + 2 + 1] = paddings[2 * i + 1] - padding_common[i];
paddings[2 * i + 1] - padding_common[i];
} }
} }
...@@ -283,7 +278,6 @@ class CudnnConvDescManager { ...@@ -283,7 +278,6 @@ class CudnnConvDescManager {
cache.paddings = padding_common; cache.paddings = padding_common;
conv_attr_cache_[hash_key] = cache; conv_attr_cache_[hash_key] = cache;
} }
}
return &conv_attr_cache_.at(hash_key); return &conv_attr_cache_.at(hash_key);
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册