未验证 提交 e4c2a854 编写于 作者: J Jacek Czaja 提交者: GitHub

[oneDNN] Disable caching of Reorder operation (#35664)

* - Reorder disabling caching

* - compilation fix

* - another compilation fix

* - another compilation fix

* - compilation fix

* - Fix

* - yet another compilation fix

* - surprisingly, another compilation fix

* - lint

* - fix after review

* - fix
上级 d411a038
文件已添加
...@@ -179,11 +179,9 @@ void innerTransDataLayoutFromMKLDNN(DataLayout in_layout, DataLayout out_layout, ...@@ -179,11 +179,9 @@ void innerTransDataLayoutFromMKLDNN(DataLayout in_layout, DataLayout out_layout,
if ((in_format != out_format) || always_copy) { if ((in_format != out_format) || always_copy) {
void* in_data = GetDataFromTensor(in, in_type); void* in_data = GetDataFromTensor(in, in_type);
std::string key =
platform::CreateKey(*dev_ctx, in_tz, in_format, out_format, in_type);
platform::ReorderMKLDNNHandler handler(in_tz, in.type(), in_type, *dev_ctx, platform::ReorderMKLDNNHandler handler(in_tz, in.type(), in_type,
cpu_engine, key); cpu_engine);
auto reorder_src_memory_p = handler.AcquireSrcMemory(in_format, in_data); auto reorder_src_memory_p = handler.AcquireSrcMemory(in_format, in_data);
auto reorder_dst_memory_p = auto reorder_dst_memory_p =
......
...@@ -43,10 +43,8 @@ class EltwiseAddMKLDNNGradKernel : public ElemwiseGradKernel<T> { ...@@ -43,10 +43,8 @@ class EltwiseAddMKLDNNGradKernel : public ElemwiseGradKernel<T> {
auto tz = paddle::framework::vectorize<int64_t>(dout->dims()); auto tz = paddle::framework::vectorize<int64_t>(dout->dims());
memory::data_type dout_type = framework::ToMKLDNNDataType(dout->type()); memory::data_type dout_type = framework::ToMKLDNNDataType(dout->type());
std::string key = platform::CreateKey(dev_ctx, tz, dout->format(), platform::ReorderMKLDNNHandler handler(tz, dout->type(), dout_type,
dout->format(), dout_type); onednn_engine);
platform::ReorderMKLDNNHandler handler(tz, dout->type(), dout_type, dev_ctx,
onednn_engine, key);
auto& astream = platform::MKLDNNDeviceContext::tls().get_stream(); auto& astream = platform::MKLDNNDeviceContext::tls().get_stream();
auto reorder_src_memory_p = handler.AcquireSrcMemory( auto reorder_src_memory_p = handler.AcquireSrcMemory(
......
...@@ -43,11 +43,9 @@ class CastMKLDNNKernel : public framework::OpKernel<T> { ...@@ -43,11 +43,9 @@ class CastMKLDNNKernel : public framework::OpKernel<T> {
auto x_tz = framework::vectorize(x->dims()); auto x_tz = framework::vectorize(x->dims());
std::string key = platform::ReorderMKLDNNHandler reorder_handler(x_tz, x_paddle_type, x_type,
platform::CreateKey(dev_ctx, x_tz, x->format(), x->format(), x_type); out_paddle_type, out_type,
platform::ReorderMKLDNNHandler reorder_handler( dev_ctx.GetEngine());
x_tz, x_paddle_type, x_type, out_paddle_type, out_type, dev_ctx,
dev_ctx.GetEngine(), key);
auto reorder_src_memory_p = reorder_handler.AcquireSrcMemory( auto reorder_src_memory_p = reorder_handler.AcquireSrcMemory(
x->format(), platform::to_void_cast(x->data<T>())); x->format(), platform::to_void_cast(x->data<T>()));
......
...@@ -1125,12 +1125,8 @@ class ConvMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> { ...@@ -1125,12 +1125,8 @@ class ConvMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
mkldnn::memory::format_tag out_format = mkldnn::memory::format_tag out_format =
weights_tz.size() == 6 ? mkldnn::memory::format_tag::goidhw weights_tz.size() == 6 ? mkldnn::memory::format_tag::goidhw
: mkldnn::memory::format_tag::goihw; : mkldnn::memory::format_tag::goihw;
std::string key = platform::CreateKey(dev_ctx, weights_tz, filter_fmt, platform::ReorderMKLDNNHandler handler(weights_tz, filter->type(),
out_format, in_type); in_type, mkldnn_engine);
key = platform::ExtendKeyWithThreadInfoIfNeeded(dev_ctx, key);
platform::ReorderMKLDNNHandler handler(
weights_tz, filter->type(), in_type, dev_ctx, mkldnn_engine, key);
auto reorder_dst_memory_p = auto reorder_dst_memory_p =
handler.AcquireDstMemory(filter_grad, out_format, ctx.GetPlace()); handler.AcquireDstMemory(filter_grad, out_format, ctx.GetPlace());
......
...@@ -114,10 +114,8 @@ class ExpandGradMKLDNNKernel : public paddle::framework::OpKernel<T> { ...@@ -114,10 +114,8 @@ class ExpandGradMKLDNNKernel : public paddle::framework::OpKernel<T> {
if (dout_vec_dims == dx_vec_dims) { if (dout_vec_dims == dx_vec_dims) {
mkldnn::memory::data_type dout_type = mkldnn::memory::data_type dout_type =
paddle::framework::ToMKLDNNDataType(dout->type()); paddle::framework::ToMKLDNNDataType(dout->type());
std::string key = paddle::platform::CreateKey(
dev_ctx, dout_vec_dims, dout->format(), dout->format(), dout_type);
paddle::platform::ReorderMKLDNNHandler reorder_handler( paddle::platform::ReorderMKLDNNHandler reorder_handler(
dout_vec_dims, dout->type(), dout_type, dev_ctx, onednn_engine, key); dout_vec_dims, dout->type(), dout_type, onednn_engine);
auto reorder_src_memory_p = reorder_handler.AcquireSrcMemory( auto reorder_src_memory_p = reorder_handler.AcquireSrcMemory(
dout->format(), paddle::platform::to_void_cast(dout->data<T>())); dout->format(), paddle::platform::to_void_cast(dout->data<T>()));
......
...@@ -58,11 +58,8 @@ static Tensor FoldFirstAndLastDims(const MKLDNNDeviceContext& dev_ctx, ...@@ -58,11 +58,8 @@ static Tensor FoldFirstAndLastDims(const MKLDNNDeviceContext& dev_ctx,
memory::data_type input_type = memory::data_type input_type =
paddle::framework::ToMKLDNNDataType(input->type()); paddle::framework::ToMKLDNNDataType(input->type());
std::string key = paddle::platform::CreateKey(
dev_ctx, input_dims, input->format(), input->format(), input_type);
paddle::platform::ReorderMKLDNNHandler reorder_handler( paddle::platform::ReorderMKLDNNHandler reorder_handler(
output_dims, input->type(), input_type, dev_ctx, dev_ctx.GetEngine(), output_dims, input->type(), input_type, dev_ctx.GetEngine());
key);
auto reorder_src_memory_p = reorder_handler.AcquireSrcMemory( auto reorder_src_memory_p = reorder_handler.AcquireSrcMemory(
memory::format_tag::abc, memory::format_tag::abc,
......
...@@ -93,10 +93,8 @@ class ReshapeMKLDNNKernel : public framework::OpKernel<T> { ...@@ -93,10 +93,8 @@ class ReshapeMKLDNNKernel : public framework::OpKernel<T> {
} }
mkldnn::memory::data_type x_type = framework::ToMKLDNNDataType(x->type()); mkldnn::memory::data_type x_type = framework::ToMKLDNNDataType(x->type());
std::string key = platform::ReorderMKLDNNHandler reorder_handler(x_vec_dims, x->type(),
platform::CreateKey(dev_ctx, x_vec_dims, x->format(), x_type); x_type, onednn_engine);
platform::ReorderMKLDNNHandler reorder_handler(
x_vec_dims, x->type(), x_type, dev_ctx, onednn_engine, key);
auto reorder_src_memory_p = reorder_handler.AcquireSrcMemory( auto reorder_src_memory_p = reorder_handler.AcquireSrcMemory(
x->format(), platform::to_void_cast(x->data<T>())); x->format(), platform::to_void_cast(x->data<T>()));
...@@ -253,11 +251,8 @@ class ReshapeGradMKLDNNKernel : public ReshapeMKLDNNKernel<T> { ...@@ -253,11 +251,8 @@ class ReshapeGradMKLDNNKernel : public ReshapeMKLDNNKernel<T> {
mkldnn::memory::data_type dout_type = mkldnn::memory::data_type dout_type =
framework::ToMKLDNNDataType(dout->type()); framework::ToMKLDNNDataType(dout->type());
std::string key = platform::ReorderMKLDNNHandler reorder_handler(dout_vec_dims, dout->type(),
platform::CreateKey(dev_ctx, dout_vec_dims, this->getPlainFormatTag(dx), dout_type, onednn_engine);
dx->format(), dout_type);
platform::ReorderMKLDNNHandler reorder_handler(
dout_vec_dims, dout->type(), dout_type, dev_ctx, onednn_engine, key);
auto reorder_src_memory_p = reorder_handler.AcquireSrcMemory( auto reorder_src_memory_p = reorder_handler.AcquireSrcMemory(
dout->format(), platform::to_void_cast(dout->data<T>())); dout->format(), platform::to_void_cast(dout->data<T>()));
......
...@@ -98,18 +98,16 @@ class SliceMKLDNNKernel : public framework::OpKernel<T> { ...@@ -98,18 +98,16 @@ class SliceMKLDNNKernel : public framework::OpKernel<T> {
out->Resize(framework::make_ddim(slice_dims)); out->Resize(framework::make_ddim(slice_dims));
mkldnn::memory::data_type x_type = framework::ToMKLDNNDataType(x->type()); mkldnn::memory::data_type x_type = framework::ToMKLDNNDataType(x->type());
auto key = platform::CreateKey(dev_ctx, x_vec_dims, axes, starts, ends,
x->format(), x_type);
platform::ReorderMKLDNNHandler reorder_handler( platform::ReorderMKLDNNHandler reorder_handler(x_vec_dims, x->type(),
x_vec_dims, x->type(), x_type, dev_ctx, onednn_engine, key); x_type, onednn_engine);
auto reorder_src_memory_p = reorder_handler.AcquireSrcMemory( auto reorder_src_memory_p = reorder_handler.AcquireSrcMemory(
x->format(), platform::to_void_cast(x->data<T>())); x->format(), platform::to_void_cast(x->data<T>()));
auto slice_mem_p = reorder_handler.AcquireSubmemory(slice_dims, offsets, auto slice_mem_p = reorder_handler.AcquireSubmemory(slice_dims, offsets,
reorder_src_memory_p); reorder_src_memory_p);
auto reorder_dst_memory_p = reorder_handler.AcquireDstMemory( auto reorder_dst_memory_p = reorder_handler.AcquireDstMemory(
out, slice_dims, 0, get_plain_format_tag(x), ctx.GetPlace()); out, slice_dims, get_plain_format_tag(x), ctx.GetPlace());
auto reorder_p = auto reorder_p =
reorder_handler.AcquireReorder(reorder_dst_memory_p, slice_mem_p); reorder_handler.AcquireReorder(reorder_dst_memory_p, slice_mem_p);
...@@ -201,16 +199,13 @@ class SliceGradMKLDNNKernel : public framework::OpKernel<T> { ...@@ -201,16 +199,13 @@ class SliceGradMKLDNNKernel : public framework::OpKernel<T> {
mkldnn::memory::format_tag reorder_format_tag = mkldnn::memory::format_tag reorder_format_tag =
platform::GetMKLDNNFormat(md.reshape(slice_dims)); platform::GetMKLDNNFormat(md.reshape(slice_dims));
auto key = platform::CreateKey(dev_ctx, dout_vec_dims, axes, starts, ends, platform::ReorderMKLDNNHandler reorder_handler(slice_dims, dout->type(),
reorder_format_tag, dout_type); dout_type, onednn_engine);
platform::ReorderMKLDNNHandler reorder_handler(
slice_dims, dout->type(), dout_type, dev_ctx, onednn_engine, key);
auto reorder_src_memory_p = reorder_handler.AcquireSrcMemory( auto reorder_src_memory_p = reorder_handler.AcquireSrcMemory(
reorder_format_tag, platform::to_void_cast(dout->data<T>())); reorder_format_tag, platform::to_void_cast(dout->data<T>()));
auto reorder_dst_memory_p = reorder_handler.AcquireDstMemory( auto reorder_dst_memory_p = reorder_handler.AcquireDstMemory(
dx, dx_vec_dims, 0, reorder_format_tag, ctx.GetPlace()); dx, dx_vec_dims, reorder_format_tag, ctx.GetPlace());
memset(dx->data<T>(), 0, reorder_dst_memory_p->get_desc().get_size()); memset(dx->data<T>(), 0, reorder_dst_memory_p->get_desc().get_size());
auto slice_mem_p = reorder_handler.AcquireSubmemory(slice_dims, offsets, auto slice_mem_p = reorder_handler.AcquireSubmemory(slice_dims, offsets,
......
...@@ -91,27 +91,25 @@ class SplitMKLDNNKernel : public framework::OpKernel<T> { ...@@ -91,27 +91,25 @@ class SplitMKLDNNKernel : public framework::OpKernel<T> {
auto x_vec_dims = framework::vectorize(x_dims); auto x_vec_dims = framework::vectorize(x_dims);
mkldnn::memory::data_type x_type = framework::ToMKLDNNDataType(x->type()); mkldnn::memory::data_type x_type = framework::ToMKLDNNDataType(x->type());
auto key = platform::CreateKey(dev_ctx, x_vec_dims, axis, num, sections,
x->format(), x_type);
auto& astream = platform::MKLDNNDeviceContext::tls().get_stream(); auto& astream = platform::MKLDNNDeviceContext::tls().get_stream();
std::vector<int64_t> offset(x_vec_dims.size(), 0); std::vector<int64_t> offset(x_vec_dims.size(), 0);
platform::ReorderMKLDNNHandler reorder_handler( platform::ReorderMKLDNNHandler reorder_handler(x_vec_dims, x->type(),
x_vec_dims, x->type(), x_type, dev_ctx, onednn_engine, key); x_type, onednn_engine);
auto reorder_src_memory_p = reorder_handler.AcquireSrcMemory( auto reorder_src_memory_p = reorder_handler.AcquireSrcMemory(
x->format(), platform::to_void_cast(x->data<T>())); x->format(), platform::to_void_cast(x->data<T>()));
for (size_t i = 0; i < outs_number; ++i) { for (size_t i = 0; i < outs_number; ++i) {
auto out_vec_dims = framework::vectorize(outs[i]->dims()); auto out_vec_dims = framework::vectorize(outs[i]->dims());
auto slice_mem_p = reorder_handler.AcquireSubmemory( auto slice_mem_p = reorder_handler.AcquireSubmemory(out_vec_dims, offset,
out_vec_dims, offset, reorder_src_memory_p, i); reorder_src_memory_p);
auto reorder_dst_memory_p = reorder_handler.AcquireDstMemory( auto reorder_dst_memory_p = reorder_handler.AcquireDstMemory(
outs[i], out_vec_dims, i, x->format(), ctx.GetPlace()); outs[i], out_vec_dims, x->format(), ctx.GetPlace());
auto reorder_p = auto reorder_p =
reorder_handler.AcquireReorder(reorder_dst_memory_p, slice_mem_p, i); reorder_handler.AcquireReorder(reorder_dst_memory_p, slice_mem_p);
reorder_p->execute(astream, *slice_mem_p, *reorder_dst_memory_p); reorder_p->execute(astream, *slice_mem_p, *reorder_dst_memory_p);
......
...@@ -155,15 +155,11 @@ class SumMKLDNNOpKernel : public paddle::framework::OpKernel<T> { ...@@ -155,15 +155,11 @@ class SumMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
// For in-place execution which sum does not have we need to fake it // For in-place execution which sum does not have we need to fake it
// so from oneDNN dst memory we reorder data into input // so from oneDNN dst memory we reorder data into input
if (in_place) { if (in_place) {
const std::string reorder_key =
platform::CreateKey(dev_ctx, framework::vectorize(output->dims()),
ctx.OutputName("Out") + "-I");
auto& in_out = in_vars[0]->Get<framework::LoDTensor>(); auto& in_out = in_vars[0]->Get<framework::LoDTensor>();
auto output_tz = framework::vectorize<int64_t>(output->dims()); auto output_tz = framework::vectorize<int64_t>(output->dims());
platform::ReorderMKLDNNHandler reorder_handler( platform::ReorderMKLDNNHandler reorder_handler(
output_tz, output->type(), framework::ToMKLDNNDataType(in_out.type()), output_tz, output->type(), framework::ToMKLDNNDataType(in_out.type()),
dev_ctx, dev_ctx.GetEngine(), reorder_key); dev_ctx.GetEngine());
auto target_mem = reorder_handler.AcquireDstMemory( auto target_mem = reorder_handler.AcquireDstMemory(
output, in_out.format(), ctx.GetPlace()); output, in_out.format(), ctx.GetPlace());
......
...@@ -71,10 +71,8 @@ class ReduceMKLDNNKernel : public framework::OpKernel<T> { ...@@ -71,10 +71,8 @@ class ReduceMKLDNNKernel : public framework::OpKernel<T> {
if (input_dims == output_dims) { if (input_dims == output_dims) {
mkldnn::memory::data_type input_type = mkldnn::memory::data_type input_type =
framework::ToMKLDNNDataType(input->type()); framework::ToMKLDNNDataType(input->type());
std::string key = platform::CreateKey( platform::ReorderMKLDNNHandler reorder_handler(input_dims, input->type(),
dev_ctx, input_dims, input->format(), input->format(), input_type); input_type, onednn_engine);
platform::ReorderMKLDNNHandler reorder_handler(
input_dims, input->type(), input_type, dev_ctx, onednn_engine, key);
auto reorder_src_memory_p = reorder_handler.AcquireSrcMemory( auto reorder_src_memory_p = reorder_handler.AcquireSrcMemory(
input->format(), platform::to_void_cast(input->data<T>())); input->format(), platform::to_void_cast(input->data<T>()));
......
...@@ -1071,138 +1071,73 @@ class ActivationMKLDNNHandler ...@@ -1071,138 +1071,73 @@ class ActivationMKLDNNHandler
} }
}; };
class ReorderMKLDNNHandler : public MKLDNNHandler { class ReorderMKLDNNHandler {
public: public:
ReorderMKLDNNHandler(std::vector<int64_t>& dims, // NOLINT ReorderMKLDNNHandler(std::vector<int64_t>& dims, // NOLINT
framework::proto::VarType::Type vtype, framework::proto::VarType::Type vtype,
mkldnn::memory::data_type dtype, mkldnn::memory::data_type dtype, mkldnn::engine engine)
const platform::MKLDNNDeviceContext& dev_ctx, : dims_(dims),
mkldnn::engine engine, const std::string& base_key)
: platform::MKLDNNHandler(dev_ctx, engine, base_key),
dims_(dims),
vtype_(vtype), vtype_(vtype),
vtype_dst_(vtype), vtype_dst_(vtype),
dtype_(dtype), dtype_(dtype),
dtype_dst_(dtype) {} dtype_dst_(dtype),
engine_(engine) {}
ReorderMKLDNNHandler(std::vector<int64_t>& dims, // NOLINT ReorderMKLDNNHandler(std::vector<int64_t>& dims, // NOLINT
framework::proto::VarType::Type vtype, framework::proto::VarType::Type vtype,
mkldnn::memory::data_type dtype, mkldnn::memory::data_type dtype,
framework::proto::VarType::Type vtype_dst, framework::proto::VarType::Type vtype_dst,
mkldnn::memory::data_type dtype_dst, mkldnn::memory::data_type dtype_dst,
const platform::MKLDNNDeviceContext& dev_ctx, mkldnn::engine engine)
mkldnn::engine engine, const std::string& base_key) : dims_(dims),
: platform::MKLDNNHandler(dev_ctx, engine, base_key),
dims_(dims),
vtype_(vtype), vtype_(vtype),
vtype_dst_(vtype_dst), vtype_dst_(vtype_dst),
dtype_(dtype), dtype_(dtype),
dtype_dst_(dtype_dst) {} dtype_dst_(dtype_dst),
engine_(engine) {}
std::shared_ptr<mkldnn::memory> AcquireSrcMemory( std::shared_ptr<mkldnn::memory> AcquireSrcMemory(
const MKLDNNMemoryFormat& fmt, void* ptr) { const MKLDNNMemoryFormat& fmt, void* ptr) {
return this->AcquireMemory(dims_, dtype_, fmt, ptr, "@user_src_mem_p"); auto md = mkldnn::memory::desc(dims_, dtype_, fmt);
return std::make_shared<mkldnn::memory>(md, engine_, ptr);
} }
std::shared_ptr<mkldnn::memory> AcquireSubmemory( std::shared_ptr<mkldnn::memory> AcquireSubmemory(
const std::vector<int64_t>& dims, const std::vector<int64_t>& offset, const std::vector<int64_t>& dims, const std::vector<int64_t>& offset,
const std::shared_ptr<mkldnn::memory>& mem_p, int submemory_number = 0) { const std::shared_ptr<mkldnn::memory>& mem_p) {
std::string local_key = key_; auto sub_md = mem_p->get_desc().submemory_desc(dims, {offset});
local_key.append("@submem") auto sub_mem_p = std::make_shared<mkldnn::memory>(sub_md, engine_,
.append(std::to_string(submemory_number)) mem_p->get_data_handle());
.append("_p");
auto sub_mem_p =
std::static_pointer_cast<mkldnn::memory>(dev_ctx_.GetBlob(local_key));
if (sub_mem_p == nullptr) {
auto sub_md = mem_p->get_desc().submemory_desc(dims, {offset});
sub_mem_p = std::make_shared<mkldnn::memory>(sub_md, engine_,
mem_p->get_data_handle());
dev_ctx_.SetBlob(local_key, sub_mem_p);
} else {
sub_mem_p->set_data_handle(mem_p->get_data_handle());
}
return sub_mem_p; return sub_mem_p;
} }
std::shared_ptr<mkldnn::memory> AcquireDstMemory( std::shared_ptr<mkldnn::memory> AcquireDstMemory(
framework::Tensor* output, const MKLDNNMemoryFormat& fmt, framework::Tensor* output, const MKLDNNMemoryFormat& fmt,
platform::Place place) { platform::Place place) {
auto local_key = key_ + "@user_dst_mem_p"; auto dst_md = platform::MKLDNNMemDesc(dims_, dtype_dst_, fmt);
auto mem_p = auto dst_data = output->mutable_data(place, vtype_dst_, dst_md.get_size());
std::static_pointer_cast<mkldnn::memory>(dev_ctx_.GetBlob(local_key)); return std::make_shared<mkldnn::memory>(dst_md, engine_, dst_data);
if (mem_p == nullptr) {
auto dst_md = platform::MKLDNNMemDesc(dims_, dtype_dst_, fmt);
auto dst_data =
output->mutable_data(place, vtype_dst_, dst_md.get_size());
mem_p = std::make_shared<mkldnn::memory>(dst_md, engine_, dst_data);
dev_ctx_.SetBlob(local_key, mem_p);
} else {
// Even if memory object exists , we may be using it for diffrent tensor
auto dst_data =
output->mutable_data(place, vtype_dst_, mem_p->get_desc().get_size());
mem_p->set_data_handle(dst_data);
}
return mem_p;
} }
std::shared_ptr<mkldnn::memory> AcquireDstMemory( std::shared_ptr<mkldnn::memory> AcquireDstMemory(
framework::Tensor* output, const std::vector<int64_t>& dims, framework::Tensor* output, const std::vector<int64_t>& dims,
const int memory_number, const MKLDNNMemoryFormat& fmt, const MKLDNNMemoryFormat& fmt, platform::Place place) {
platform::Place place) { auto dst_md = platform::MKLDNNMemDesc(dims, dtype_dst_, fmt);
auto local_key = auto dst_data = output->mutable_data(place, vtype_dst_, dst_md.get_size());
key_ + "@user_dst_mem" + std::to_string(memory_number) + "_p"; return std::make_shared<mkldnn::memory>(dst_md, engine_, dst_data);
auto mem_p =
std::static_pointer_cast<mkldnn::memory>(dev_ctx_.GetBlob(local_key));
if (mem_p == nullptr) {
auto dst_md = platform::MKLDNNMemDesc(dims, dtype_dst_, fmt);
auto dst_data =
output->mutable_data(place, vtype_dst_, dst_md.get_size());
mem_p = std::make_shared<mkldnn::memory>(dst_md, engine_, dst_data);
dev_ctx_.SetBlob(local_key, mem_p);
} else {
// Even if memory object exists , we may be using it for diffrent tensor
auto dst_data =
output->mutable_data(place, vtype_dst_, mem_p->get_desc().get_size());
mem_p->set_data_handle(dst_data);
}
return mem_p;
}
std::shared_ptr<mkldnn::reorder> AcquireReorder(
std::shared_ptr<mkldnn::memory> dst_memory_p,
std::shared_ptr<mkldnn::memory> src_memory_p, int reorder_number) {
auto prim_key = key_ + "@reorder" + std::to_string(reorder_number) + "_p";
auto reorder_p =
std::static_pointer_cast<mkldnn::reorder>(dev_ctx_.GetBlob(prim_key));
if (reorder_p == nullptr) {
reorder_p =
std::make_shared<mkldnn::reorder>(*(src_memory_p), *(dst_memory_p));
dev_ctx_.SetBlob(prim_key, reorder_p);
}
return reorder_p;
} }
std::shared_ptr<mkldnn::reorder> AcquireReorder( std::shared_ptr<mkldnn::reorder> AcquireReorder(
std::shared_ptr<mkldnn::memory> dst_memory_p, std::shared_ptr<mkldnn::memory> dst_memory_p,
std::shared_ptr<mkldnn::memory> src_memory_p) { std::shared_ptr<mkldnn::memory> src_memory_p) {
auto prim_key = key_ + "@reorder_p"; return std::make_shared<mkldnn::reorder>(*(src_memory_p), *(dst_memory_p));
auto reorder_p =
std::static_pointer_cast<mkldnn::reorder>(dev_ctx_.GetBlob(prim_key));
if (reorder_p == nullptr) {
reorder_p =
std::make_shared<mkldnn::reorder>(*(src_memory_p), *(dst_memory_p));
dev_ctx_.SetBlob(prim_key, reorder_p);
}
return reorder_p;
} }
private: private:
std::vector<int64_t> dims_; std::vector<int64_t> dims_;
framework::proto::VarType::Type vtype_, vtype_dst_; framework::proto::VarType::Type vtype_, vtype_dst_;
mkldnn::memory::data_type dtype_, dtype_dst_; mkldnn::memory::data_type dtype_, dtype_dst_;
mkldnn::engine engine_;
}; };
template <typename T> template <typename T>
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册