未验证 提交 7da5368d 编写于 作者: J Jacek Czaja 提交者: GitHub

Make GetBlob assume elements are cached (#38336)

* First set of fixes

* - Make GetBlob more likely to find blobs

* - Lint
上级 3629cd27
...@@ -43,8 +43,8 @@ void TransformData(const OpKernelType &expected_kernel_type, ...@@ -43,8 +43,8 @@ void TransformData(const OpKernelType &expected_kernel_type,
Tensor in; Tensor in;
in.ShareDataWith(input_tensor); in.ShareDataWith(input_tensor);
Tensor out; Tensor out;
DataLayout lin = kernel_type_for_var.data_layout_; const DataLayout lin = kernel_type_for_var.data_layout_;
DataLayout lout = expected_kernel_type.data_layout_; const DataLayout lout = expected_kernel_type.data_layout_;
// do layout transform // do layout transform
if (NeedTransformLayout(lout, lin)) { if (NeedTransformLayout(lout, lin)) {
......
...@@ -806,11 +806,10 @@ class ConvMKLDNNOpKernel : public framework::OpKernel<T> { ...@@ -806,11 +806,10 @@ class ConvMKLDNNOpKernel : public framework::OpKernel<T> {
ctx.Attr<std::vector<float>>("Scale_weights"); ctx.Attr<std::vector<float>>("Scale_weights");
const bool is_multi_channel = scale_weights_data.size() > 1; const bool is_multi_channel = scale_weights_data.size() > 1;
const int& groups = ctx.Attr<int>("groups"); const int& groups = ctx.Attr<int>("groups");
const bool& is_test = ctx.Attr<bool>("is_test");
int mask_reorder = int mask_reorder =
is_multi_channel ? ((groups != 1) ? (1 << 1) + (1 << 0) : 1 << 0) : 0; is_multi_channel ? ((groups != 1) ? (1 << 1) + (1 << 0) : 1 << 0) : 0;
auto weights_memory_p = handler.AcquireWeightsMemoryWithReorder( auto weights_memory_p = handler.AcquireWeightsMemoryWithReorder(
filter, groups, false, is_test, scale_weights_data, mask_reorder); filter, groups, false, true, scale_weights_data, mask_reorder);
std::shared_ptr<dnnl::memory> dst_memory_p; std::shared_ptr<dnnl::memory> dst_memory_p;
if (fuse_residual_conn) { if (fuse_residual_conn) {
...@@ -842,7 +841,7 @@ class ConvMKLDNNOpKernel : public framework::OpKernel<T> { ...@@ -842,7 +841,7 @@ class ConvMKLDNNOpKernel : public framework::OpKernel<T> {
auto p_scales_tuple = handler.get_int8_bias_scales(ctx); auto p_scales_tuple = handler.get_int8_bias_scales(ctx);
auto bias_memory_p = handler.AcquireBiasMemoryWithReorder( auto bias_memory_p = handler.AcquireBiasMemoryWithReorder(
bias, is_test, std::get<1>(*p_scales_tuple), bias, true, std::get<1>(*p_scales_tuple),
std::get<0>(*p_scales_tuple)); std::get<0>(*p_scales_tuple));
args.insert({DNNL_ARG_BIAS, *bias_memory_p}); args.insert({DNNL_ARG_BIAS, *bias_memory_p});
} }
......
...@@ -869,6 +869,15 @@ unsigned int MKLDNNDeviceContext::GetCachedObjectsNumber(void) const { ...@@ -869,6 +869,15 @@ unsigned int MKLDNNDeviceContext::GetCachedObjectsNumber(void) const {
return num_entries; return num_entries;
} }
// TODO(jczaja): Replace with C++20 equivalents when applicable
// Branch-prediction hints. likely(expr) / unlikely(expr) evaluate to the
// truth value of `expr` (normalized to 0 or 1) and, where the compiler
// offers __builtin_expect, tell it which outcome to optimize for.
// Selection is keyed on the compiler rather than on the target OS, so
// GCC-compatible compilers targeting Windows keep the hint and compilers
// without the builtin fall back to a plain truth test.
#if defined(__GNUC__) || defined(__clang__)
#define likely(expr) (__builtin_expect(!!(expr), 1))
#define unlikely(expr) (__builtin_expect(!!(expr), 0))
#else
#define likely(expr) (!!(expr))
#define unlikely(expr) (!!(expr))
#endif
MKLDNNDeviceContext::BlobPtr_t<void> MKLDNNDeviceContext::GetBlob( MKLDNNDeviceContext::BlobPtr_t<void> MKLDNNDeviceContext::GetBlob(
const std::string& name) const { const std::string& name) const {
BlobMap* pMap = p_blobmap_.get(); BlobMap* pMap = p_blobmap_.get();
...@@ -881,7 +890,10 @@ MKLDNNDeviceContext::BlobPtr_t<void> MKLDNNDeviceContext::GetBlob( ...@@ -881,7 +890,10 @@ MKLDNNDeviceContext::BlobPtr_t<void> MKLDNNDeviceContext::GetBlob(
// Find ShapeBlob for current mkldnn session id firstly // Find ShapeBlob for current mkldnn session id firstly
auto map_it = pMap->find(sid); auto map_it = pMap->find(sid);
if (map_it == pMap->end()) { // (jczaja): After first iteration of model's execution we
// should have (mostly) all elements cached, so lookup misses are unlikely
// (though more frequent with dynamic shapes)
if (unlikely(map_it == pMap->end())) {
VLOG(2) << "GetBlob: sid=" << sid << ", miss sid\n"; VLOG(2) << "GetBlob: sid=" << sid << ", miss sid\n";
return nullptr; return nullptr;
} }
...@@ -889,7 +901,7 @@ MKLDNNDeviceContext::BlobPtr_t<void> MKLDNNDeviceContext::GetBlob( ...@@ -889,7 +901,7 @@ MKLDNNDeviceContext::BlobPtr_t<void> MKLDNNDeviceContext::GetBlob(
// Find KeyBlob for current input shape secondly // Find KeyBlob for current input shape secondly
auto sBlob_it = sBlob->find(tls().cur_input_shape_str); auto sBlob_it = sBlob->find(tls().cur_input_shape_str);
if (sBlob_it == sBlob->end()) { if (unlikely(sBlob_it == sBlob->end())) {
VLOG(2) << "GetBlob: sid=" << tls().cur_input_shape_str VLOG(2) << "GetBlob: sid=" << tls().cur_input_shape_str
<< ", miss input_shape_str\n"; << ", miss input_shape_str\n";
return nullptr; return nullptr;
...@@ -899,7 +911,7 @@ MKLDNNDeviceContext::BlobPtr_t<void> MKLDNNDeviceContext::GetBlob( ...@@ -899,7 +911,7 @@ MKLDNNDeviceContext::BlobPtr_t<void> MKLDNNDeviceContext::GetBlob(
// Find Blob via name // Find Blob via name
auto key_it = pBlob->find(name); auto key_it = pBlob->find(name);
if (key_it == pBlob->end()) { if (unlikely(key_it == pBlob->end())) {
VLOG(2) << "GetBlob sid=" << sid << ", miss blob=" << name << "\n"; VLOG(2) << "GetBlob sid=" << sid << ", miss blob=" << name << "\n";
return nullptr; return nullptr;
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册