未验证 提交 bf748f24 编写于 作者: J Jacek Czaja 提交者: GitHub

Implemented LRU based cache clearing (#36290)

- Lint

- Merge with develop

- lint
上级 59e425cd
......@@ -78,7 +78,8 @@ class ConvMKLDNNHandlerT
mkldnn::convolution_backward_weights>(
dev_ctx, mkldnn_engine, cpu_place,
platform::CreateKey(dev_ctx, framework::vectorize(input->dims()),
unique_name)) {
unique_name)),
is_test_(ctx.Attr<bool>("is_test")) {
if (!this->isCached()) {
PADDLE_ENFORCE_EQ(
input->layout(), framework::DataLayout::kMKLDNN,
......@@ -159,7 +160,6 @@ class ConvMKLDNNHandlerT
framework::slice_ddim(filter_dims, 2, filter_dims.size());
const auto ksize = framework::vectorize(filter_data_dims);
const bool is_test = ctx.Attr<bool>("is_test");
auto strides_temp = ctx.Attr<std::vector<int>>("strides");
std::vector<int64_t> strides(begin(strides_temp), end(strides_temp));
......@@ -214,9 +214,8 @@ class ConvMKLDNNHandlerT
const auto dst_md = platform::MKLDNNMemDesc(
dst_tz, platform::MKLDNNGetDataType<T_out>(), chosen_memory_format);
const auto fwd_prop_kind = is_test ? mkldnn::prop_kind::forward_inference
const auto fwd_prop_kind = is_test_ ? mkldnn::prop_kind::forward_inference
: mkldnn::prop_kind::forward_training;
float sum_scale = 1.0f;
std::vector<float> output_shift_scale;
if (platform::is_int8<T>())
......@@ -261,7 +260,8 @@ class ConvMKLDNNHandlerT
mkldnn::convolution_backward_weights>(
dev_ctx, dev_ctx.GetEngine(), cpu_place,
platform::CreateKey(dev_ctx, framework::vectorize(in->dims()),
unique_name)) {
unique_name)),
is_test_(false) {
if (!this->isBwdCached()) {
PADDLE_ENFORCE_EQ(
in->layout(), framework::DataLayout::kMKLDNN,
......@@ -291,7 +291,7 @@ class ConvMKLDNNHandlerT
"Wrong format set for output_grad tensor"));
PADDLE_ENFORCE_EQ(
ctx.Attr<bool>("is_test"), false,
is_test_, false,
platform::errors::InvalidArgument(
"is_test attribute should be set to False in training phase."));
......@@ -557,13 +557,14 @@ class ConvMKLDNNHandlerT
framework::vectorize(in_mem->dims()),
platform::MKLDNNGetDataType<T>(), in_mem->format());
return this->AcquireMemoryWithReorder(
user_mem_md, mem_md, platform::to_void_cast<T>(in_mem_data), key_mem);
user_mem_md, mem_md, platform::to_void_cast<T>(in_mem_data), key_mem,
is_test_);
} else {
const std::string target_key_suffix{key_mem_target};
const auto target_mem_p = this->AcquireMemory(target_key_suffix);
user_mem_p->set_data_handle(platform::to_void_cast<T>(in_mem_data));
if (user_mem_p != target_mem_p) {
this->AcquireReorder(user_mem_p, target_mem_p, key_mem);
this->AcquireReorder(user_mem_p, target_mem_p);
}
return target_mem_p;
}
......@@ -571,12 +572,11 @@ class ConvMKLDNNHandlerT
std::shared_ptr<mkldnn::memory> AcquireWeightsMemoryWithReorder(
const framework::Tensor* filter, const int groups, const bool is_conv3d,
const bool is_test, const std::vector<float>& scale_data = {1.0f},
int mask = 0) {
const std::vector<float>& scale_data = {1.0f}, int mask = 0) {
// This is workaround to make execution faster, delete
// if statement after including md inside Tensor
auto weights_mem_p = this->AcquireMemory("@weights_mem_p_target");
if (is_test && weights_mem_p) {
if (is_test_ && weights_mem_p) {
return weights_mem_p;
} else {
const K* filter_data = filter->data<K>();
......@@ -589,16 +589,16 @@ class ConvMKLDNNHandlerT
return this->AcquireMemoryWithReorder(
user_src_md, this->fwd_pd_->weights_desc(),
platform::to_void_cast<K>(filter_data), "@weights_mem_p", is_test, {},
scale_data, mask);
platform::to_void_cast<K>(filter_data), "@weights_mem_p", is_test_,
{}, scale_data, mask);
}
}
std::shared_ptr<mkldnn::memory> AcquireBiasMemoryWithReorder(
const framework::Tensor* bias, const bool is_test,
const framework::Tensor* bias,
const std::vector<float>& scale_data = {1.0f}, int mask = 0) {
auto bias_mem_p = this->AcquireMemory("@bias_mem_p_target");
if (is_test && bias_mem_p) {
if (is_test_ && bias_mem_p) {
return bias_mem_p;
} else {
const K* bias_data = bias->data<K>();
......@@ -608,7 +608,7 @@ class ConvMKLDNNHandlerT
return this->AcquireMemoryWithReorder(
user_bias_md, this->fwd_pd_->bias_desc(),
platform::to_void_cast<K>(bias_data), "@bias_mem_p", is_test, {},
platform::to_void_cast<K>(bias_data), "@bias_mem_p", is_test_, {},
scale_data, mask);
}
}
......@@ -641,7 +641,7 @@ class ConvMKLDNNHandlerT
platform::GetMKLDNNFormat(this->fwd_pd_->dst_desc())) {
auto residual_memory_p = this->AcquireResidualMemory(residual_param);
dst_memory_p = this->template AcquireDstMemory<T_out>(output);
this->AcquireReorder(residual_memory_p, dst_memory_p, "@residual_dst");
this->AcquireReorder(residual_memory_p, dst_memory_p);
} else {
// Changing ShareDataWith to TensorCopy results in performance drop
// on ResNet architectures
......@@ -651,6 +651,9 @@ class ConvMKLDNNHandlerT
}
return dst_memory_p;
}
private:
const bool is_test_;
};
} // anonymous namespace
......@@ -695,7 +698,6 @@ class ConvMKLDNNOpKernel : public framework::OpKernel<T> {
ctx.template device_context<platform::MKLDNNDeviceContext>();
const auto& mkldnn_engine = dev_ctx.GetEngine();
const bool is_test = ctx.Attr<bool>("is_test");
const bool is_conv3d = ctx.Attr<std::vector<int>>("strides").size() == 3U;
const bool fuse_residual_conn = ctx.Attr<bool>("fuse_residual_connection");
......@@ -712,7 +714,7 @@ class ConvMKLDNNOpKernel : public framework::OpKernel<T> {
auto src_memory_p = handler.AcquireSrcMemoryWithReorder(input);
auto weights_memory_p = handler.AcquireWeightsMemoryWithReorder(
filter, ctx.Attr<int>("groups"), is_conv3d, is_test);
filter, ctx.Attr<int>("groups"), is_conv3d);
std::shared_ptr<dnnl::memory> dst_memory_p;
if (fuse_residual_conn) {
......@@ -731,7 +733,7 @@ class ConvMKLDNNOpKernel : public framework::OpKernel<T> {
{MKLDNN_ARG_DST, *dst_memory_p}};
if (bias) {
auto bias_memory_p = handler.AcquireBiasMemoryWithReorder(bias, is_test);
auto bias_memory_p = handler.AcquireBiasMemoryWithReorder(bias);
args.insert({MKLDNN_ARG_BIAS, *bias_memory_p});
}
......@@ -783,11 +785,10 @@ class ConvMKLDNNOpKernel : public framework::OpKernel<T> {
ctx.Attr<std::vector<float>>("Scale_weights");
const bool is_multi_channel = scale_weights_data.size() > 1;
const int& groups = ctx.Attr<int>("groups");
const bool& is_test = ctx.Attr<bool>("is_test");
int mask_reorder =
is_multi_channel ? ((groups != 1) ? (1 << 1) + (1 << 0) : 1 << 0) : 0;
auto weights_memory_p = handler.AcquireWeightsMemoryWithReorder(
filter, groups, false, is_test, scale_weights_data, mask_reorder);
filter, groups, false, scale_weights_data, mask_reorder);
std::shared_ptr<dnnl::memory> dst_memory_p;
if (fuse_residual_conn) {
......@@ -822,7 +823,7 @@ class ConvMKLDNNOpKernel : public framework::OpKernel<T> {
handler.get_int8_bias_scales(ctx);
auto bias_memory_p = handler.AcquireBiasMemoryWithReorder(
bias, is_test, scale_bias_data, mask_reorder);
bias, scale_bias_data, mask_reorder);
args.insert({MKLDNN_ARG_BIAS, *bias_memory_p});
}
......
......@@ -51,10 +51,10 @@ class ConvTransposeMKLDNNHandlerT
: platform::MKLDNNHandlerT<T, mkldnn::deconvolution_forward>(
dev_ctx, mkldnn_engine, cpu_place,
platform::CreateKey(dev_ctx, framework::vectorize(input->dims()),
unique_name)) {
unique_name)),
is_test_(ctx.Attr<bool>("is_test")) {
if (!this->isCached()) {
const bool is_test = ctx.Attr<bool>("is_test");
PADDLE_ENFORCE_EQ(is_test, true,
PADDLE_ENFORCE_EQ(is_test_, true,
platform::errors::InvalidArgument(
"ConvTransposeMKLDNN works only for inference. "
"The attribute \'is_test\' value should be set to "
......@@ -169,7 +169,7 @@ class ConvTransposeMKLDNNHandlerT
const mkldnn::primitive_attr conv_trans_attr =
CreatePostOps(fuse_activation, fuse_alpha, fuse_beta);
auto fwd_prop_kind = is_test ? mkldnn::prop_kind::forward_inference
auto fwd_prop_kind = is_test_ ? mkldnn::prop_kind::forward_inference
: mkldnn::prop_kind::forward_training;
if (bias) {
std::vector<int64_t> bias_tz = framework::vectorize(bias->dims());
......@@ -231,18 +231,18 @@ class ConvTransposeMKLDNNHandlerT
const auto target_src_mem_p = this->AcquireMemory(target_key_suffix);
user_src_mem_p->set_data_handle(platform::to_void_cast<T>(input_data));
if (user_src_mem_p != target_src_mem_p) {
this->AcquireReorder(user_src_mem_p, target_src_mem_p, "@src_mem_p");
this->AcquireReorder(user_src_mem_p, target_src_mem_p);
}
return target_src_mem_p;
}
}
std::shared_ptr<mkldnn::memory> AcquireWeightsMemoryWithReorder(
const framework::Tensor* filter, const int& groups, const bool& is_test) {
const framework::Tensor* filter, const int& groups) {
// This is workaround to make execution faster, delete
// if statement after including md inside Tensor
auto weights_mem_p = this->AcquireMemory("@weights_mem_p_target");
if (is_test && weights_mem_p) {
if (is_test_ && weights_mem_p) {
return weights_mem_p;
} else {
const K* filter_data = filter->data<K>();
......@@ -277,15 +277,15 @@ class ConvTransposeMKLDNNHandlerT
return this->template AcquireMemoryWithReorder<K>(
user_src_md, this->fwd_pd_->weights_desc(),
platform::to_void_cast<K>(filter_data), "@weights_mem_p", is_test,
platform::to_void_cast<K>(filter_data), "@weights_mem_p", is_test_,
iohw2oihw_reorder);
}
}
std::shared_ptr<mkldnn::memory> AcquireBiasMemoryWithReorder(
const framework::Tensor* bias, const bool& is_test) {
const framework::Tensor* bias) {
auto bias_mem_p = this->AcquireMemory("@bias_mem_p_target");
if (is_test && bias_mem_p) {
if (is_test_ && bias_mem_p) {
return bias_mem_p;
} else {
const K* bias_data = bias->data<K>();
......@@ -294,9 +294,12 @@ class ConvTransposeMKLDNNHandlerT
MKLDNNMemoryFormat::x);
return this->AcquireMemoryWithReorder(
user_bias_md, this->fwd_pd_->bias_desc(),
platform::to_void_cast<K>(bias_data), "@bias_mem_p", is_test);
platform::to_void_cast<K>(bias_data), "@bias_mem_p", is_test_);
}
}
private:
const bool is_test_;
};
template <typename T, typename K>
......@@ -325,8 +328,6 @@ class ConvTransposeMKLDNNOpKernel : public framework::OpKernel<T> {
ctx.template device_context<platform::MKLDNNDeviceContext>();
const auto& mkldnn_engine = dev_ctx.GetEngine();
const bool is_test = ctx.Attr<bool>("is_test");
const auto* input = ctx.Input<Tensor>("Input");
const auto* filter = ctx.Input<Tensor>("Filter");
const auto* bias =
......@@ -340,7 +341,7 @@ class ConvTransposeMKLDNNOpKernel : public framework::OpKernel<T> {
output, unique_name);
auto src_memory_p = handler.AcquireSrcMemoryWithReorder(input);
auto weights_memory_p = handler.AcquireWeightsMemoryWithReorder(
filter, ctx.Attr<int>("groups"), is_test);
filter, ctx.Attr<int>("groups"));
std::shared_ptr<dnnl::memory> dst_memory_p =
handler.template AcquireDstMemory<T_out>(output);
......@@ -352,7 +353,7 @@ class ConvTransposeMKLDNNOpKernel : public framework::OpKernel<T> {
{MKLDNN_ARG_DST, *dst_memory_p}};
if (bias) {
auto bias_memory_p = handler.AcquireBiasMemoryWithReorder(bias, is_test);
auto bias_memory_p = handler.AcquireBiasMemoryWithReorder(bias);
args.insert({MKLDNN_ARG_BIAS, *bias_memory_p});
}
auto& astream = platform::MKLDNNDeviceContext::tls().get_stream();
......
......@@ -64,21 +64,11 @@ class QuantOpKernel : public framework::OpKernel<T> {
bool is_negative_input = ctx.Attr<bool>("is_negative_input");
bool bfloat16 = ctx.Attr<bool>("bfloat16");
std::string key =
platform::CreateKey(dev_ctx, src_tz, scale_data, scale_shift,
is_negative_input, ctx.OutputName("Output"));
key = platform::ExtendKeyWithThreadInfoIfNeeded(dev_ctx, key);
const std::string key_prim = key + "@r";
const std::string key_src_mem = key + "@s";
const std::string key_dst_mem = key + "@d";
// TODO(jczaja): Refactor with Acquire API
std::shared_ptr<mkldnn::memory> src_memory;
std::shared_ptr<mkldnn::memory> dst_memory;
std::shared_ptr<reorder> reorder_p;
reorder_p = std::static_pointer_cast<reorder>(dev_ctx.GetBlob(key_prim));
if (reorder_p == nullptr) {
std::string out_layout = ctx.Attr<std::string>("output_format");
MKLDNNMemoryFormat out_format =
platform::data_format_to_memory_format(out_layout);
......@@ -97,8 +87,8 @@ class QuantOpKernel : public framework::OpKernel<T> {
auto src_md = platform::MKLDNNMemDesc({src_tz}, memory::data_type::f32,
input->format());
src_memory = std::make_shared<mkldnn::memory>(
src_md, engine, to_void_cast<T>(input_data));
src_memory = std::make_shared<mkldnn::memory>(src_md, engine,
to_void_cast<T>(input_data));
std::shared_ptr<mkldnn::memory::desc> dst_md;
if (bfloat16) {
......@@ -108,38 +98,13 @@ class QuantOpKernel : public framework::OpKernel<T> {
platform::SetDstMemoryQuantized<int8_t>(ctx, output, dst_tz, engine,
dst_md, dst_memory, out_format);
} else {
platform::SetDstMemoryQuantized<uint8_t>(
ctx, output, dst_tz, engine, dst_md, dst_memory, out_format);
platform::SetDstMemoryQuantized<uint8_t>(ctx, output, dst_tz, engine,
dst_md, dst_memory, out_format);
}
auto reorder_pd = std::shared_ptr<reorder::primitive_desc>(
new reorder::primitive_desc(*src_memory, *dst_memory, attri));
reorder_p = std::shared_ptr<reorder>(new reorder(*reorder_pd));
dev_ctx.SetBlob(key_prim, reorder_p);
dev_ctx.SetBlob(key_src_mem, src_memory);
dev_ctx.SetBlob(key_dst_mem, dst_memory);
} else {
src_memory = std::static_pointer_cast<mkldnn::memory>(
dev_ctx.GetBlob(key_src_mem));
src_memory->set_data_handle(to_void_cast<T>(input_data));
dst_memory = std::static_pointer_cast<mkldnn::memory>(
dev_ctx.GetBlob(key_dst_mem));
auto place = ctx.GetPlace();
if (bfloat16) {
dst_memory->set_data_handle(
output->mutable_data<paddle::platform::bfloat16>(place));
} else if (with_shift || !is_negative_input) {
uint8_t* output_data = output->mutable_data<uint8_t>(ctx.GetPlace());
if (with_shift) std::memset(output_data, scale_shift, output->numel());
dst_memory->set_data_handle(output_data);
} else {
dst_memory->set_data_handle(
output->mutable_data<int8_t>(ctx.GetPlace()));
}
}
auto& astream = platform::MKLDNNDeviceContext::tls().get_stream();
{
platform::RecordEvent record_reorder("int_reorder",
......
......@@ -11,6 +11,12 @@ See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/platform/device_context.h"
#include <set>
#include <utility>
#ifdef _WIN32
#include <intrin.h>
#else
#include <x86intrin.h>
#endif
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
#include "paddle/fluid/memory/allocation/cuda_device_context_allocator.h"
......@@ -666,7 +672,7 @@ void MKLDNNDeviceContext::ResetBlobMap(void* ptr) {
// of this executor
for (auto& s : *p_exec_items_) {
for (auto& v : (*s.second)[ptr]) {
(v.first)->erase(v.second);
(v.first)->second.erase(v.second);
}
s.second->erase(ptr);
}
......@@ -677,11 +683,26 @@ void MKLDNNDeviceContext::ResetBlobMap(void* ptr) {
}
}
void MKLDNNDeviceContext::RemoveShapeEntriesWithExecutor(void) const {
p_exec_items_->erase(p_exec_items_->begin());
std::string MKLDNNDeviceContext::PickLeastUsedShape(
BlobPtr_t<ShapeBlob> sb) const {
auto ancient_one = sb->begin();
for (auto v = std::next(sb->begin()); v != sb->end(); ++v) {
if (v->second->first < ancient_one->second->first) {
ancient_one = v;
}
}
VLOG(2) << "num_shapes: " << sb->size()
<< ", remove all blobs of shape: " << ancient_one->first;
return ancient_one->first;
}
void MKLDNNDeviceContext::RemoveShapeEntriesWithExecutor(
std::string shape_to_be_removed) const {
p_exec_items_->erase(shape_to_be_removed);
}
void MKLDNNDeviceContext::LinkEntryWithExecutor(BlobPtr_t<KeyBlob> pblob,
void MKLDNNDeviceContext::LinkEntryWithExecutor(
BlobPtr_t<std::pair<unsigned long long, KeyBlob>> pblob,
KeyBlob::iterator it) const {
// Take current input shape from TLS
// Take current executor addess from TLS
......@@ -719,7 +740,7 @@ void MKLDNNDeviceContext::SetBlob(const std::string& name,
BlobPtr_t<void> data) const {
BlobMap* pMap = p_blobmap_.get();
BlobPtr_t<ShapeBlob> sBlob = nullptr;
BlobPtr_t<KeyBlob> pBlob = nullptr;
BlobPtr_t<std::pair<unsigned long long, KeyBlob>> pBlob = nullptr;
int sid = tls().get_cur_mkldnn_session_id();
......@@ -748,22 +769,24 @@ void MKLDNNDeviceContext::SetBlob(const std::string& name,
sBlob->size() &&
(sBlob->size() >=
static_cast<size_t>(tls().cur_input_shape_cache_capacity))) {
VLOG(2) << "sid=" << sid
<< ", remove all blobs of shape: " << sBlob->begin()->first;
sBlob->erase(sBlob->begin()->first);
RemoveShapeEntriesWithExecutor();
auto shape_to_be_erased = PickLeastUsedShape(sBlob);
sBlob->erase(shape_to_be_erased);
RemoveShapeEntriesWithExecutor(shape_to_be_erased);
}
pBlob = std::make_shared<KeyBlob>();
pBlob = std::make_shared<std::pair<unsigned long long, KeyBlob>>();
pBlob->first = __rdtsc();
(*sBlob)[tls().cur_input_shape_str] = pBlob;
} else {
pBlob = key_it->second;
// Update time stamp
pBlob->first = __rdtsc();
}
// Find Blob via name
auto blob_it = pBlob->find(name);
if (blob_it == pBlob->end()) {
auto el =
pBlob->insert(std::make_pair(name, data)); // (*pBlob)[name] = data;
auto blob_it = pBlob->second.find(name);
if (blob_it == pBlob->second.end()) {
auto el = pBlob->second.insert(
std::make_pair(name, data)); // (*pBlob)[name] = data;
// Register new element in per executor map
// to have easily erased when executor terminated
LinkEntryWithExecutor(pBlob, el.first);
......@@ -779,7 +802,7 @@ unsigned int MKLDNNDeviceContext::GetCachedObjectsNumber(void) const {
unsigned int num_entries = 0;
for (auto const& l3 : *p_blobmap_) {
for (auto const& l2 : *(l3.second)) {
num_entries += (l2.second)->size();
num_entries += (l2.second->second).size();
}
}
return num_entries;
......@@ -789,7 +812,7 @@ MKLDNNDeviceContext::BlobPtr_t<void> MKLDNNDeviceContext::GetBlob(
const std::string& name) const {
BlobMap* pMap = p_blobmap_.get();
BlobPtr_t<ShapeBlob> sBlob = nullptr;
BlobPtr_t<KeyBlob> pBlob = nullptr;
BlobPtr_t<std::pair<unsigned long long, KeyBlob>> pBlob = nullptr;
int sid = tls().get_cur_mkldnn_session_id();
......@@ -813,12 +836,14 @@ MKLDNNDeviceContext::BlobPtr_t<void> MKLDNNDeviceContext::GetBlob(
pBlob = sBlob_it->second;
// Find Blob via name
auto key_it = pBlob->find(name);
auto key_it = pBlob->second.find(name);
if (key_it == pBlob->end()) {
if (key_it == pBlob->second.end()) {
VLOG(2) << "GetBlob sid=" << sid << ", miss blob=" << name << "\n";
return nullptr;
}
// Update timestamp
sBlob_it->second->first = __rdtsc(); // TODO(windows)
VLOG(2) << "GetBlob sid=" << sid << ", get blob=" << name << "\n";
// lock will be automatically released when out of scope
......
......@@ -757,18 +757,20 @@ class MKLDNNDeviceContext : public CPUDeviceContext {
// Following three maps are used to cache MKLDNN primitives.
// There relations are:
// - BlobMap = Map<cur_thread_id, ShapeBlob>
// - ShapeBlob = Map<cur_input_shape_str, KeyBlob>
// - ShapeBlob = Map<cur_input_shape_str,<unsigned long long, KeyBlob>>
// - KeyBlob = Map<blob_name, blob>
using KeyBlob = umap_key_string_t<void>;
using ShapeBlob = umap_key_string_t<KeyBlob>;
using ShapeBlob = umap_key_string_t<std::pair<unsigned long long, KeyBlob>>;
using BlobMap = umap_value_smart_t<int, ShapeBlob>;
// Auxillary two-level structure (shape, executor) to easier control
// clearing cache objects related to specific executor
using ExecKey = void*;
using ExecMapCacheIterPair = std::pair<BlobPtr_t<KeyBlob>, KeyBlob::iterator>;
using ExecMapCacheIterPair =
std::pair<BlobPtr_t<std::pair<unsigned long long, KeyBlob>>,
KeyBlob::iterator>;
using ExecMap =
std::unordered_map<ExecKey, std::vector<ExecMapCacheIterPair>>;
using ExecShape = std::unordered_map<std::string, std::shared_ptr<ExecMap>>;
......@@ -779,8 +781,11 @@ class MKLDNNDeviceContext : public CPUDeviceContext {
const mkldnn::engine& GetEngine() const { return tls().get_engine(); }
// Register object to currently used executor's map
void LinkEntryWithExecutor(BlobPtr_t<KeyBlob>, KeyBlob::iterator) const;
void RemoveShapeEntriesWithExecutor(void) const;
void LinkEntryWithExecutor(
BlobPtr_t<std::pair<unsigned long long, KeyBlob>> pblob,
KeyBlob::iterator it) const;
void RemoveShapeEntriesWithExecutor(std::string) const;
std::string PickLeastUsedShape(BlobPtr_t<ShapeBlob> sb) const;
// Remove all entries from the blob map
void ResetBlobMap(void* ptr);
......
......@@ -500,18 +500,9 @@ class MKLDNNHandlerT {
}
void AcquireReorder(const std::shared_ptr<mkldnn::memory>& user_memory_p,
const std::shared_ptr<mkldnn::memory>& target_memory_p,
const std::string& suffix) {
const auto key_reorder_p = key_ + suffix + "reorder_p";
auto reorder_p = std::static_pointer_cast<mkldnn::reorder>(
dev_ctx_.GetBlob(key_reorder_p));
if (reorder_p == nullptr) {
reorder_p =
const std::shared_ptr<mkldnn::memory>& target_memory_p) {
auto reorder_p =
std::make_shared<mkldnn::reorder>(*user_memory_p, *target_memory_p);
dev_ctx_.SetBlob(key_reorder_p, reorder_p);
}
auto& astream = platform::MKLDNNDeviceContext::tls().get_stream();
......@@ -578,6 +569,8 @@ class MKLDNNHandlerT {
std::static_pointer_cast<dnnl::memory>(dev_ctx_.GetBlob(user_key));
user_memory_p->set_data_handle(ptr);
// TODO(jczaja): Here we detect if reorder is cached it means it is needed
// need to change this to get rid of keys
auto reorder_p = std::static_pointer_cast<mkldnn::reorder>(
dev_ctx_.GetBlob(key_reorder_p));
if (reorder_p != nullptr) {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册