Unverified commit 173660be, authored by Jacek Czaja, committed by GitHub

[oneDNN] Cache oneDNN stream not to recreate in each oneDNN op (#30358)

Parent: ae0f88a9
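The change is mechanical and repeated across all oneDNN kernels below: instead of constructing a fresh `mkldnn::stream` on every kernel invocation, each call site now takes a reference to a stream created once per thread and cached in thread-local storage. A minimal sketch of the before/after shape at a call site (`prim` and `args` are placeholders for any oneDNN primitive and its argument map, not names from the patch):

```cpp
// Before: a new stream object was constructed on every op invocation.
//   mkldnn::stream astream(engine);
// After: a reference to the per-thread cached stream is fetched instead.
auto& astream = platform::MKLDNNDeviceContext::tls().get_stream();
prim->execute(astream, args);  // run the primitive on the cached stream
astream.wait();                // block until execution completes
```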
@@ -193,7 +193,7 @@ void innerTransDataLayoutFromMKLDNN(DataLayout in_layout, DataLayout out_layout,
   auto reorder_p =
       handler.AcquireReorder(reorder_dst_memory_p, reorder_src_memory_p);
-  mkldnn::stream astream(cpu_engine);
+  auto& astream = platform::MKLDNNDeviceContext::tls().get_stream();
   platform::RecordEvent record_reorder("ext_reorder",
                                        platform::EventRole::kUniqueOp);
   reorder_p->execute(astream, *reorder_src_memory_p, *reorder_dst_memory_p);
......
@@ -48,7 +48,7 @@ class EltwiseAddMKLDNNGradKernel : public ElemwiseGradKernel<T> {
     platform::ReorderMKLDNNHandler handler(tz, dout->type(), dout_type, dev_ctx,
                                            onednn_engine, key);
-    mkldnn::stream astream(onednn_engine);
+    auto& astream = platform::MKLDNNDeviceContext::tls().get_stream();
     auto reorder_src_memory_p = handler.AcquireSrcMemory(
         dout->format(), platform::to_void_cast(dout->data<T>()));
......
@@ -68,7 +68,7 @@ class EltwiseMKLDNNKernel : public framework::OpKernel<T> {
     const auto binary_prim = handler.AcquireForwardPrimitive();
-    mkldnn::stream astream(mkldnn_engine);
+    auto& astream = platform::MKLDNNDeviceContext::tls().get_stream();
     const std::unordered_map<int, dnnl::memory> args = {
         {DNNL_ARG_SRC_0, *src_x_memory},
......
@@ -246,7 +246,7 @@ class GRUMKLDNNHandler : public platform::MKLDNNHandlerT<T, dnnl::gru_forward> {
       memory_p = std::make_shared<dnnl::memory>(this->fwd_pd_->src_iter_desc(),
                                                 this->engine_);
-      dnnl::stream astream(this->engine_);
+      auto& astream = paddle::platform::MKLDNNDeviceContext::tls().get_stream();
       dnnl::reorder(user_h0_memory, *memory_p, attr_)
           .execute(astream, user_h0_memory, *memory_p);
@@ -284,7 +284,7 @@ class GRUMKLDNNHandler : public platform::MKLDNNHandlerT<T, dnnl::gru_forward> {
       memory_p = std::make_shared<dnnl::memory>(
          this->fwd_pd_->weights_layer_desc(), this->engine_);
-      dnnl::stream astream(this->engine_);
+      auto& astream = paddle::platform::MKLDNNDeviceContext::tls().get_stream();
       dnnl::reorder(user_memory, *memory_p, attr_)
           .execute(astream, user_memory, *memory_p);
@@ -337,7 +337,7 @@ class GRUMKLDNNHandler : public platform::MKLDNNHandlerT<T, dnnl::gru_forward> {
       memory_p = std::make_shared<dnnl::memory>(
          this->fwd_pd_->weights_iter_desc(), this->engine_);
-      dnnl::stream astream(this->engine_);
+      auto& astream = paddle::platform::MKLDNNDeviceContext::tls().get_stream();
       dnnl::reorder(user_memory, *memory_p, attr_)
           .execute(astream, user_memory, *memory_p);
@@ -469,7 +469,7 @@ class FusionGRUMKLDNNKernel : public framework::OpKernel<T> {
     auto gru_forward_p = handler.AcquireForwardPrimitive();
-    dnnl::stream astream(mkldnn_engine);
+    auto& astream = platform::MKLDNNDeviceContext::tls().get_stream();
     gru_forward_p->execute(astream, gru_args);
     astream.wait();
......
@@ -292,7 +292,7 @@ class MultiGRUHandler {
     auto gru_forward_p0 = AcquireGruPrimitive(layer, dir);
-    dnnl::stream astream(engine_);
+    auto& astream = paddle::platform::MKLDNNDeviceContext::tls().get_stream();
     gru_forward_p0->execute(astream, gru_args);
     astream.wait();
     return out_mem;
@@ -315,7 +315,7 @@ class MultiGRUHandler {
       memory_p = std::make_shared<dnnl::memory>(
          gru_pds_[{layer, dir}]->src_iter_desc(), engine_);
-      dnnl::stream astream(engine_);
+      auto& astream = paddle::platform::MKLDNNDeviceContext::tls().get_stream();
       dnnl::reorder(user_h0_memory, *memory_p, attrs_[2 * layer + (dir == R2L)])
           .execute(astream, user_h0_memory, *memory_p);
@@ -354,7 +354,7 @@ class MultiGRUHandler {
      memory_p = std::make_shared<dnnl::memory>(
          gru_pds_[{layer, dir}]->weights_layer_desc(), engine_);
-      dnnl::stream astream(engine_);
+      auto& astream = paddle::platform::MKLDNNDeviceContext::tls().get_stream();
       dnnl::reorder(user_memory, *memory_p, attrs_[2 * layer + (dir == R2L)])
           .execute(astream, user_memory, *memory_p);
@@ -410,7 +410,7 @@ class MultiGRUHandler {
      memory_p = std::make_shared<dnnl::memory>(
          gru_pds_[{layer, dir}]->weights_iter_desc(), engine_);
-      dnnl::stream astream(engine_);
+      auto& astream = paddle::platform::MKLDNNDeviceContext::tls().get_stream();
       dnnl::reorder(user_memory, *memory_p, attrs_[2 * layer + (dir == R2L)])
           .execute(astream, user_memory, *memory_p);
@@ -516,7 +516,7 @@ class MultiGRUHandler {
     auto concat_p = AcquireConcatPrimitive(layer);
-    dnnl::stream astream(engine_);
+    auto& astream = platform::MKLDNNDeviceContext::tls().get_stream();
     concat_p->execute(astream, concat_args);
     astream.wait();
     return out_mem;
......
@@ -112,7 +112,7 @@ void eltwise_forward(const framework::ExecutionContext &ctx,
   auto dst_memory_p = is_inplaced ? src_memory_p : handler.AcquireDstMemory(y);
   auto activation_p = handler.AcquireForwardPrimitive();
-  mkldnn::stream astream(dev_ctx.GetEngine());
+  auto &astream = paddle::platform::MKLDNNDeviceContext::tls().get_stream();
   activation_p->execute(astream, {{MKLDNN_ARG_FROM, *src_memory_p},
                                   {MKLDNN_ARG_TO, *dst_memory_p}});
   astream.wait();
@@ -158,7 +158,7 @@ void eltwise_grad(const framework::ExecutionContext &ctx,
   auto diff_src_memory_p = handler.AcquireDiffSrcMemory(diff_x);
   auto activation_backward_p = handler.AcquireBackwardPrimitive();
-  mkldnn::stream astream(dev_ctx.GetEngine());
+  auto &astream = paddle::platform::MKLDNNDeviceContext::tls().get_stream();
   activation_backward_p->execute(astream,
                                  {{MKLDNN_ARG_SRC, *src_memory_p},
                                   {MKLDNN_ARG_DIFF_DST, *diff_dst_memory_p},
......
@@ -220,7 +220,7 @@ class BatchNormMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
     y->set_layout(DataLayout::kMKLDNN);
     y->set_format(platform::GetMKLDNNFormat(*dst_memory));
-    mkldnn::stream astream(dev_ctx.GetEngine());
+    auto &astream = platform::MKLDNNDeviceContext::tls().get_stream();
     batch_norm_p->execute(astream,
                           {{MKLDNN_ARG_SRC, *src_memory},
                            {MKLDNN_ARG_SCALE_SHIFT, *scaleshift_memory},
@@ -321,7 +321,7 @@ class BatchNormMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
     // finally create batch_norm backward primitive
     auto batch_norm_bwd_p = handler.AcquireBackwardPrimitive();
-    mkldnn::stream astream(dev_ctx.GetEngine());
+    auto &astream = platform::MKLDNNDeviceContext::tls().get_stream();
     batch_norm_bwd_p->execute(
         astream, {{MKLDNN_ARG_SRC, *src_memory},
                   {MKLDNN_ARG_MEAN, *mean_memory},
......
@@ -202,7 +202,7 @@ class ConcatMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
           output->mutable_data<T>(place, concat_pd->dst_desc().get_size()));
     }
-    mkldnn::stream astream(mkldnn_engine);
+    auto& astream = platform::MKLDNNDeviceContext::tls().get_stream();
     std::unordered_map<int, memory> args;
     for (size_t i = 0; i < multi_input.size(); ++i) {
       args.insert({MKLDNN_ARG_MULTIPLE_SRC + i, (*srcs).at(i)});
......
@@ -471,7 +471,7 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
         args.insert({MKLDNN_ARG_BIAS, *bias_memory_p});
       }
-      mkldnn::stream astream(mkldnn_engine);
+      auto& astream = platform::MKLDNNDeviceContext::tls().get_stream();
       conv_p->execute(astream, args);
       astream.wait();
@@ -553,7 +553,7 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
     conv_p = std::static_pointer_cast<mkldnn::convolution_forward>(
         dev_ctx.GetBlob(prim_key));
-    mkldnn::stream astream(mkldnn_engine);
+    auto& astream = platform::MKLDNNDeviceContext::tls().get_stream();
     if (conv_p == nullptr || !is_test) {
       float fuse_alpha = ctx.Attr<float>("fuse_alpha");
@@ -1045,7 +1045,7 @@ class ConvMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
         user_weights_md, to_void_cast<T>(filter_data));
     auto user_diff_dst_memory_p = handler.AcquireDiffDstMemory(
         user_diff_dst_md, to_void_cast<T>(output_grad_data));
-    mkldnn::stream astream(mkldnn_engine);
+    auto& astream = platform::MKLDNNDeviceContext::tls().get_stream();
     if (filter_grad) {
       auto src_memory_p = handler.AcquireSrcMemoryFromWeightsPrimitive(
           user_src_memory_p, pipeline);
......
@@ -242,7 +242,7 @@ class ConvTransposeMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
     auto conv_p = handler.AcquireConvolution();
-    mkldnn::stream astream(mkldnn_engine);
+    auto& astream = platform::MKLDNNDeviceContext::tls().get_stream();
     if (bias) {
       const T* bias_data = bias->data<T>();
       auto user_bias_md = platform::MKLDNNMemDesc(
......
@@ -124,7 +124,7 @@ class DeQuantOpKernel : public framework::OpKernel<T> {
       dst_memory->set_data_handle(output->mutable_data<float>(ctx.GetPlace()));
     }
-    mkldnn::stream astream(engine);
+    auto& astream = platform::MKLDNNDeviceContext::tls().get_stream();
     reorder_p->execute(astream, *src_memory, *dst_memory);
     astream.wait();
......
@@ -137,7 +137,7 @@ class FCPrimitiveFactory {
   }

   void Execute() {
-    mkldnn::stream astream(engine_);
+    auto& astream = platform::MKLDNNDeviceContext::tls().get_stream();
     if (bias_) {
       fc_->execute(astream, {{MKLDNN_ARG_SRC, *input_},
                              {MKLDNN_ARG_WEIGHTS, *weights_},
@@ -280,7 +280,7 @@ class FCPrimitiveFactory {
     auto dst_mem = std::make_shared<memory>(dst_desc, engine_);
     auto reorder = mkldnn::reorder(src_mem, *dst_mem);
-    mkldnn::stream astream(engine_);
+    auto& astream = platform::MKLDNNDeviceContext::tls().get_stream();
     {
       platform::RecordEvent record_reorder("int_reorder",
@@ -309,7 +309,7 @@ class FCPrimitiveFactory {
     attributes.set_output_scales(mask, scale_data);
     auto reorder = mkldnn::reorder(*src_mem, *dst_mem, attributes);
-    mkldnn::stream astream(engine_);
+    auto& astream = platform::MKLDNNDeviceContext::tls().get_stream();
     {
       platform::RecordEvent record_reorder("int_reorder",
                                            platform::EventRole::kUniqueOp);
......
@@ -154,7 +154,7 @@ class InterpolateMKLDNNKernel : public framework::OpKernel<T> {
     auto resampling_prim = handler.AcquireForwardPrimitive();
     const std::unordered_map<int, dnnl::memory> args = {
         {DNNL_ARG_SRC, *src_memory_p}, {DNNL_ARG_DST, *dst_memory_p}};
-    mkldnn::stream astream(mkldnn_engine);
+    auto& astream = platform::MKLDNNDeviceContext::tls().get_stream();
     resampling_prim->execute(astream, args);
     astream.wait();
......
@@ -120,7 +120,7 @@ class LayerNormMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
     auto layer_norm_p = handler.AcquireForwardPrimitive();
-    dnnl::stream astream(dev_ctx.GetEngine());
+    auto& astream = platform::MKLDNNDeviceContext::tls().get_stream();
     std::unordered_map<int, dnnl::memory> args;
     args.insert({DNNL_ARG_SRC, *src_memory});
......
@@ -59,7 +59,7 @@ class LRNMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
     auto workspace_memory = handler.AcquireWorkspaceMemory(mid);
     mid->set_layout(framework::DataLayout::kMKLDNN);
-    mkldnn::stream astream(dev_ctx.GetEngine());
+    auto& astream = platform::MKLDNNDeviceContext::tls().get_stream();
     if (!workspace_memory->get_desc().is_zero()) {
       mid->set_format(platform::GetMKLDNNFormat(*workspace_memory));
       lrn_p->execute(astream, {{MKLDNN_ARG_SRC, *src_memory},
@@ -118,7 +118,7 @@ class LRNMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
     auto lrn_bwd = handler.AcquireBackwardPrimitive();
-    mkldnn::stream astream(dev_ctx.GetEngine());
+    auto& astream = platform::MKLDNNDeviceContext::tls().get_stream();
     lrn_bwd->execute(astream, {{MKLDNN_ARG_SRC, *src_memory},
                                {MKLDNN_ARG_DIFF_DST, *diff_dst_memory},
                                {MKLDNN_ARG_DIFF_SRC, *diff_src_memory},
......
@@ -109,7 +109,7 @@ class MulPrimitiveFactory {
     auto reorder = mkldnn::reorder(reorder_pd);
-    mkldnn::stream astream(engine_);
+    auto &astream = platform::MKLDNNDeviceContext::tls().get_stream();
     {
       platform::RecordEvent record_reorder("int_reorder",
                                            platform::EventRole::kUniqueOp);
@@ -184,7 +184,7 @@ class MulPrimitiveFactory {
   }

   void Execute() {
-    mkldnn::stream astream(engine_);
+    auto &astream = platform::MKLDNNDeviceContext::tls().get_stream();
     (*mul_).execute(astream, {{MKLDNN_ARG_SRC, *x_input_},
                               {MKLDNN_ARG_WEIGHTS, *y_input_},
                               {MKLDNN_ARG_DST, *output_}});
@@ -270,8 +270,7 @@ class MulPrimitiveFactory {
     auto reorder = mkldnn::reorder(src_mem, dst_mem);
-    mkldnn::stream astream(engine_);
+    auto &astream = platform::MKLDNNDeviceContext::tls().get_stream();
     {
       platform::RecordEvent record_reorder("int_reorder",
                                            platform::EventRole::kUniqueOp);
@@ -355,7 +354,7 @@ class MulMKLDNNKernel : public framework::OpKernel<XT> {
                           "Operator DNNL Mul must use CPUPlace"));
     platform::MKLDNNDeviceContext::tls().log_lib_version();
     auto &dev_ctx = ctx.template device_context<MKLDNNDeviceContext>();
-    const auto &mkldnn_engine = dev_ctx.GetEngine();
+    auto &mkldnn_engine = dev_ctx.GetEngine();
     const Tensor *x = ctx.Input<Tensor>("X");
     const Tensor *y = ctx.Input<Tensor>("Y");
......
@@ -51,7 +51,7 @@ class PoolMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
     auto pool_p = handler.AcquireForwardPrimitive();
-    mkldnn::stream astream(dev_ctx.GetEngine());
+    auto& astream = platform::MKLDNNDeviceContext::tls().get_stream();
     if ((ctx.Attr<bool>("is_test") == false) &&
         (ctx.Attr<std::string>("pooling_type") == "max")) {
       // Training
@@ -154,7 +154,7 @@ class PoolMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
     auto pool_bwd_p = handler.AcquireBackwardPrimitive();
-    mkldnn::stream astream(dev_ctx.GetEngine());
+    auto& astream = platform::MKLDNNDeviceContext::tls().get_stream();
     if (pooling_type == "max") {
       // Max - pooling needs Workspace
       auto workspace_memory = handler.AcquireWorkspaceMemory();
......
@@ -140,7 +140,7 @@ class QuantOpKernel : public framework::OpKernel<T> {
       }
     }
-    mkldnn::stream astream(engine);
+    auto& astream = platform::MKLDNNDeviceContext::tls().get_stream();
     {
       platform::RecordEvent record_reorder("int_reorder",
                                            platform::EventRole::kUniqueOp);
......
@@ -137,7 +137,7 @@ class ReQuantOpKernel : public framework::OpKernel<T> {
       }
     }
-    dnnl::stream astream(engine);
+    auto& astream = platform::MKLDNNDeviceContext::tls().get_stream();
     {
       platform::RecordEvent record_reorder("int_reorder",
                                            platform::EventRole::kUniqueOp);
......
@@ -117,7 +117,7 @@ class SoftmaxMKLDNNKernel : public paddle::framework::OpKernel<T> {
     auto softmax_p = handler.AcquireForwardPrimitive();
-    mkldnn::stream astream(dev_ctx.GetEngine());
+    auto& astream = paddle::platform::MKLDNNDeviceContext::tls().get_stream();
     softmax_p->execute(astream, {{DNNL_ARG_SRC, *softmax_src_memory_p},
                                  {DNNL_ARG_DST, *softmax_dst_memory_p}});
     astream.wait();
@@ -169,7 +169,7 @@ class SoftmaxMKLDNNGradKernel : public paddle::framework::OpKernel<T> {
     auto softmax_bwd_p = handler.AcquireBackwardPrimitive();
-    mkldnn::stream astream(dev_ctx.GetEngine());
+    auto& astream = platform::MKLDNNDeviceContext::tls().get_stream();
     softmax_bwd_p->execute(astream,
                            {{MKLDNN_ARG_DST, *dst_memory_p},
                             {MKLDNN_ARG_DIFF_DST, *diff_dst_memory_p},
......
@@ -178,7 +178,7 @@ class SumMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
     }
     args.insert({MKLDNN_ARG_DST, *dst_mem});
-    mkldnn::stream astream(dev_ctx.GetEngine());
+    auto& astream = platform::MKLDNNDeviceContext::tls().get_stream();
     sum_p->execute(astream, args);
     astream.wait();
......
@@ -61,7 +61,7 @@ class TransposeMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
     auto transpose_p = handler.AcquireTranspose(transpose_dst_memory_p,
                                                 transpose_src_memory_p);
-    mkldnn::stream astream(mkldnn_engine);
+    auto& astream = platform::MKLDNNDeviceContext::tls().get_stream();
     transpose_p->execute(astream, *transpose_src_memory_p,
                          *transpose_dst_memory_p);
     astream.wait();
@@ -116,7 +116,7 @@ class TransposeMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
     auto transpose_p = handler.AcquireTranspose(transpose_dst_memory_p,
                                                 transpose_src_memory_p);
-    mkldnn::stream astream(mkldnn_engine);
+    auto& astream = platform::MKLDNNDeviceContext::tls().get_stream();
     transpose_p->execute(astream, *transpose_src_memory_p,
                          *transpose_dst_memory_p);
     astream.wait();
......
@@ -458,20 +458,34 @@ Place CUDAPinnedDeviceContext::GetPlace() const { return place_; }
 #ifdef PADDLE_WITH_MKLDNN
 MKLDNNDeviceContext::MKLDNNDeviceContext(CPUPlace place)
-    : CPUDeviceContext(place),
-      engine_(mkldnn::engine::kind::cpu, 0),
-      p_blobmap_() {
+    : CPUDeviceContext(place), p_blobmap_() {
   p_blobmap_.reset(new BlobMap());
   p_mutex_.reset(new std::mutex());
 }

-MKLDNNDeviceContextThreadLocals::Body::Body() {
+MKLDNNDeviceContextThreadLocals::Body::Body()
+    : cur_engine(mkldnn::engine::kind::cpu, 0), cur_stream(cur_engine) {
   cur_mkldnn_session_id = kMKLDNNSessionID_Default;
   cur_input_shape_str = "";
   cur_input_shape_cache_capacity = 1;
   cur_paddle_data_layout = paddle::framework::DataLayout::kNCHW;
 }

+// When a thread finishes we clear the oneDNN cache.
+// This is needed when one executor is used by many threads,
+// e.g. test_analyzer_detect. The thread ID is not part of the caching key
+// (for the naive executor), so the cache has to be cleared when one thread
+// finishes and another is about to start inference.
+// TODO(jczaja): Ideally only the part of the cache related to the
+// terminating thread would be cleared.
+MKLDNNDeviceContextThreadLocals::Body::~Body() {
+  auto cpu_place = paddle::platform::CPUPlace();
+  platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
+  platform::MKLDNNDeviceContext* dev_ctx =
+      (platform::MKLDNNDeviceContext*)pool.Get(cpu_place);
+  dev_ctx->ResetBlobMap();
+}

 void MKLDNNDeviceContextThreadLocals::Body::set_cur_mkldnn_session_id(
     size_t sid) {
   cur_mkldnn_session_id = sid;
@@ -508,6 +522,14 @@ void MKLDNNDeviceContextThreadLocals::Body::log_lib_version(void) {
   }
 }

+const mkldnn::engine& MKLDNNDeviceContextThreadLocals::Body::get_engine(void) {
+  return cur_engine;
+}
+
+mkldnn::stream& MKLDNNDeviceContextThreadLocals::Body::get_stream(void) {
+  return cur_stream;
+}

 void MKLDNNDeviceContext::ResetBlobMap() {
   std::lock_guard<decltype(*p_mutex_)> lock(*p_mutex_);
   if (!block_next_cache_clearing_) {
......
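A standalone illustration of the per-thread caching introduced above, assuming only the oneDNN C++ API (`dnnl.hpp`); the type and function names here are illustrative, not Paddle's:

```cpp
#include "dnnl.hpp"

struct ThreadLocalOneDNN {
  // One engine and one stream per thread, constructed once and then
  // reused, mirroring Body() : cur_engine(...), cur_stream(cur_engine).
  dnnl::engine cur_engine{dnnl::engine::kind::cpu, 0};
  dnnl::stream cur_stream{cur_engine};  // bound to this thread's engine

  ~ThreadLocalOneDNN() {
    // Mirrors Body::~Body() above: on thread exit, clear any cache whose
    // keys do not include a thread id (the cache hook is omitted here).
  }
};

inline ThreadLocalOneDNN& tls_onednn() {
  static thread_local ThreadLocalOneDNN body;  // created on first use per thread
  return body;
}
```

Tying the stream's lifetime to a `thread_local` object is what lets every kernel fetch it by reference instead of paying the construction cost per op.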
@@ -525,8 +525,12 @@ class MKLDNNDeviceContextThreadLocals {
     // Recently registered data_format. This is needed to
     // know for converting MKL-DNN Tensor to non MKL-DNN
     paddle::framework::DataLayout cur_paddle_data_layout;
+    // MKL-DNN engine and stream used for execution of primitives (per-thread)
+    mkldnn::engine cur_engine;
+    mkldnn::stream cur_stream;

     Body();
+    ~Body();
     void set_cur_mkldnn_session_id(size_t sid);
     size_t get_cur_mkldnn_session_id(void);
     void set_cur_input_shape_str(std::string input_shape_str);
@@ -534,6 +538,8 @@ class MKLDNNDeviceContextThreadLocals {
     void set_cur_paddle_data_layout(framework::DataLayout dl);
     framework::DataLayout get_cur_paddle_data_layout(void);
     void log_lib_version(void);
+    const mkldnn::engine& get_engine(void);
+    mkldnn::stream& get_stream(void);
   };

   MKLDNNDeviceContextThreadLocals() = default;
   MKLDNNDeviceContextThreadLocals(const MKLDNNDeviceContextThreadLocals& c) =
@@ -572,7 +578,7 @@ class MKLDNNDeviceContext : public CPUDeviceContext {
   explicit MKLDNNDeviceContext(CPUPlace place);

   /* \brief  Get the active engine */
-  const mkldnn::engine& GetEngine() const { return engine_; }
+  const mkldnn::engine& GetEngine() const { return tls().get_engine(); }

   // Remove all entries from the blob map
   void ResetBlobMap();
@@ -605,7 +611,6 @@ class MKLDNNDeviceContext : public CPUDeviceContext {
   }

  private:
-  mkldnn::engine engine_;
   std::shared_ptr<BlobMap> p_blobmap_;
   std::shared_ptr<std::mutex> p_mutex_;
   bool block_next_cache_clearing_ = false;
......
@@ -188,7 +188,7 @@ MKLDNNGetDataType<paddle::platform::bfloat16>() {
 inline void Reorder(mkldnn::memory src, mkldnn::memory dst,
                     const mkldnn::engine& engine) {
   auto reorder_prim = mkldnn::reorder(src, dst);
-  mkldnn::stream astream(engine);
+  auto& astream = platform::MKLDNNDeviceContext::tls().get_stream();
   platform::RecordEvent record_reorder("int_reorder",
                                        platform::EventRole::kUniqueOp);
   reorder_prim.execute(astream, src, dst);
......
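For comparison, a self-contained oneDNN reorder driven by a thread-locally cached stream, written against the public dnnl API only (independent of the Paddle helpers above; the buffer shape and layouts are arbitrary examples):

```cpp
#include <vector>
#include "dnnl.hpp"

int main() {
  // Per-thread engine and stream, created once and reused on later calls.
  static thread_local dnnl::engine eng{dnnl::engine::kind::cpu, 0};
  static thread_local dnnl::stream strm{eng};

  // Reorder a 2x3 float buffer from row-major (nc) to column-major (cn).
  dnnl::memory::dims dims = {2, 3};
  dnnl::memory::desc src_md(dims, dnnl::memory::data_type::f32,
                            dnnl::memory::format_tag::nc);
  dnnl::memory::desc dst_md(dims, dnnl::memory::data_type::f32,
                            dnnl::memory::format_tag::cn);
  std::vector<float> src_buf = {0, 1, 2, 3, 4, 5};
  std::vector<float> dst_buf(src_buf.size());
  dnnl::memory src{src_md, eng, src_buf.data()};
  dnnl::memory dst{dst_md, eng, dst_buf.data()};

  dnnl::reorder(src, dst).execute(strm, src, dst);
  strm.wait();  // block until the reorder has finished
  return 0;
}
```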
@@ -232,7 +232,7 @@ class MKLDNNHandlerT {
       dev_ctx_.SetBlob(key_reorder_p, reorder_p);
     }

-    mkldnn::stream astream(engine_);
+    auto& astream = platform::MKLDNNDeviceContext::tls().get_stream();
     platform::RecordEvent record_reorder("int_reorder",
                                          platform::EventRole::kUniqueOp);
@@ -261,7 +261,7 @@ class MKLDNNHandlerT {
           std::make_shared<dnnl::reorder>(*user_memory_p, *target_memory_p);
       dev_ctx_.SetBlob(key_reorder_p, reorder_p);

-      mkldnn::stream astream(engine_);
+      auto& astream = platform::MKLDNNDeviceContext::tls().get_stream();
       platform::RecordEvent record_reorder("int_reorder",
                                            platform::EventRole::kUniqueOp);
       reorder_p->execute(astream, {{MKLDNN_ARG_FROM, *user_memory_p},
@@ -273,7 +273,7 @@ class MKLDNNHandlerT {
       dev_ctx_.SetBlob(user_key, user_memory_p);
       dev_ctx_.SetBlob(target_key, target_memory_p);
     } else if (!is_persistent) {
-      mkldnn::stream astream(engine_);
+      auto& astream = platform::MKLDNNDeviceContext::tls().get_stream();
       auto user_memory_p =
           std::static_pointer_cast<dnnl::memory>(dev_ctx_.GetBlob(user_key));
@@ -425,7 +425,7 @@ class MKLDNNHandler {
       auto reorder_p =
           std::make_shared<mkldnn::reorder>(*user_memory_p, *target_memory_p);
       dev_ctx_.SetBlob(key_reorder_p, reorder_p);
-      mkldnn::stream astream(engine_);
+      auto& astream = platform::MKLDNNDeviceContext::tls().get_stream();
       platform::RecordEvent record_reorder("int_reorder",
                                            platform::EventRole::kUniqueOp);
       reorder_p->execute(astream, {{MKLDNN_ARG_FROM, *user_memory_p},
@@ -451,7 +451,7 @@ class MKLDNNHandler {
     auto target_memory_p =
         std::static_pointer_cast<mkldnn::memory>(dev_ctx_.GetBlob(local_key));

-    mkldnn::stream astream(engine_);
+    auto& astream = platform::MKLDNNDeviceContext::tls().get_stream();
     if (target_memory_p == nullptr) {
       target_memory_p = user_memory_p;
......