Unverified · Commit 657abd51, authored by J jakpiase, committed by GitHub

OneDNN md-in-tensor refactoring part 4: Memory descriptor enabled for more ops (#42946)

* added support for md in more ops

* fixed typo
Parent c6f98fa0
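Editorial note: every hunk below follows the same pattern — instead of tagging a tensor with a layout enum plus a format tag (`set_layout`/`set_format`), the kernels now attach the full oneDNN memory descriptor (`set_mem_desc`). As a minimal standalone sketch of why the descriptor is the richer handle (plain oneDNN 2.x C++ API, no Paddle code involved; shapes are illustrative):

```cpp
#include <iostream>
#include <dnnl.hpp>

int main() {
  // A memory descriptor carries dims, data type, and the exact layout
  // (strides/blocking) - a bare format tag cannot express all of that.
  dnnl::memory::desc md({2, 3, 4, 5}, dnnl::memory::data_type::f32,
                        dnnl::memory::format_tag::nchw);

  // reshape() changes only the logical dims; the bytes stay where they are.
  // Several kernels in this commit rely on exactly this mechanism.
  dnnl::memory::desc reshaped = md.reshape({6, 4, 5});

  std::cout << "same buffer size: "
            << (md.get_size() == reshaped.get_size()) << "\n";
  return 0;
}
```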
@@ -79,8 +79,10 @@ class FillConstantMKLDNNKernel : public framework::OpKernel<T> {
                          {DNNL_ARG_DST, *src0_memory_p}});
     astream.wait();
 
-    out->set_layout(framework::DataLayout::kMKLDNN);
-    out->set_format(platform::GetPlainMKLDNNFormat(out->dims().size()));
+    // src0_memory_p's md was just to allow the usage of a binary
+    // primitive as a memset, and now we need to create a real one
+    out->set_mem_desc({phi::vectorize(shape), platform::MKLDNNGetDataType<T>(),
+                       platform::GetPlainMKLDNNFormat(shape.size())});
   }
 
   T CalculateFillValue(const framework::ExecutionContext& ctx) const {
......
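For context on the hunk above: the kernel fills the output by running a binary primitive as a memset, so `src0_memory_p` only ever carried a throwaway descriptor; the real, plain (row-major) descriptor is built afterwards from the requested shape. A rough standalone sketch of building such a plain descriptor — `PlainFormatForRank` is a hypothetical stand-in for `platform::GetPlainMKLDNNFormat`, whose actual implementation may differ:

```cpp
#include <dnnl.hpp>

// Hypothetical stand-in for platform::GetPlainMKLDNNFormat: map a tensor
// rank to the corresponding plain (row-major) oneDNN format tag.
static dnnl::memory::format_tag PlainFormatForRank(size_t rank) {
  switch (rank) {
    case 1: return dnnl::memory::format_tag::a;
    case 2: return dnnl::memory::format_tag::ab;
    case 3: return dnnl::memory::format_tag::abc;
    case 4: return dnnl::memory::format_tag::abcd;
    case 5: return dnnl::memory::format_tag::abcde;
    default: return dnnl::memory::format_tag::undef;
  }
}

int main() {
  dnnl::memory::dims shape = {8, 16, 32};
  // The kind of descriptor the kernel now attaches to its output.
  dnnl::memory::desc out_md(shape, dnnl::memory::data_type::f32,
                            PlainFormatForRank(shape.size()));
  (void)out_md;
  return 0;
}
```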
@@ -124,7 +124,7 @@ class LRNMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
     auto& astream = platform::MKLDNNDeviceContext::tls().get_stream();
 
     if (!workspace_memory->get_desc().is_zero()) {
-      mid->set_format(platform::GetMKLDNNFormat(*workspace_memory));
+      mid->set_mem_desc(workspace_memory->get_desc());
       lrn_p->execute(astream, {{DNNL_ARG_SRC, *src_memory},
                                {DNNL_ARG_DST, *dst_memory},
                                {DNNL_ARG_WORKSPACE, *workspace_memory}});
@@ -134,8 +134,7 @@ class LRNMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
     }
 
     astream.wait();
 
-    out->set_layout(framework::DataLayout::kMKLDNN);
-    out->set_format(platform::GetMKLDNNFormat(*dst_memory));
+    out->set_mem_desc(dst_memory->get_desc());
   }
 };
@@ -177,8 +176,7 @@ class LRNMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
                              {DNNL_ARG_WORKSPACE, *workspace}});
     astream.wait();
 
-    in_x_grad->set_layout(framework::DataLayout::kMKLDNN);
-    in_x_grad->set_format(platform::GetMKLDNNFormat(*diff_src_memory));
+    in_x_grad->set_mem_desc(diff_src_memory->get_desc());
   }
 };
 
 }  // namespace operators
......
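The `is_zero()` check in the LRN hunks distinguishes the two modes: oneDNN only produces a workspace when the primitive was configured for training, so in inference the workspace descriptor stays empty and `mid` gets no memory descriptor. A small sketch of that check in isolation (plain oneDNN, outside Paddle):

```cpp
#include <cassert>
#include <dnnl.hpp>

int main() {
  // A default-constructed memory descriptor is "zero"; this is how the LRN
  // kernel detects that no workspace was produced (inference mode).
  dnnl::memory::desc empty_md;
  assert(empty_md.is_zero());

  dnnl::memory::desc real_md({1, 8, 4, 4}, dnnl::memory::data_type::f32,
                             dnnl::memory::format_tag::nchw);
  assert(!real_md.is_zero());
  return 0;
}
```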
@@ -175,19 +175,17 @@ class SliceGradMKLDNNKernel : public framework::OpKernel<T> {
     dnnl::memory::data_type dout_type = framework::ToMKLDNNDataType(
         framework::TransToProtoVarType(dout->dtype()));
-    dnnl::memory::desc md(dout_vec_dims, platform::MKLDNNGetDataType<T>(),
-                          dout->format());
-    dnnl::memory::format_tag reorder_format_tag =
-        platform::GetMKLDNNFormat(md.reshape(slice_dims));
 
     platform::ReorderMKLDNNHandler reorder_handler(
         slice_dims, framework::TransToProtoVarType(dout->dtype()), dout_type,
         onednn_engine);
 
     auto reorder_src_memory_p = reorder_handler.AcquireSrcMemory(
-        reorder_format_tag, platform::to_void_cast(dout->data<T>()));
+        dout->mem_desc().reshape(slice_dims),
+        platform::to_void_cast(dout->data<T>()));
     auto reorder_dst_memory_p = reorder_handler.AcquireDstMemory(
-        dx, dx_vec_dims, reorder_format_tag, ctx.GetPlace());
+        dx, dx_vec_dims, platform::GetPlainMKLDNNFormat(dx_vec_dims.size()),
+        ctx.GetPlace());
     memset(dx->data<T>(), 0, reorder_dst_memory_p->get_desc().get_size());
 
     auto slice_mem_p = reorder_handler.AcquireSubmemory(slice_dims, offsets,
@@ -199,8 +197,7 @@ class SliceGradMKLDNNKernel : public framework::OpKernel<T> {
     reorder_p->execute(astream, *reorder_src_memory_p, *slice_mem_p);
     astream.wait();
 
-    dx->set_layout(framework::DataLayout::kMKLDNN);
-    dx->set_format(reorder_format_tag);
+    dx->set_mem_desc(reorder_dst_memory_p->get_desc());
   }
 };
 
 }  // namespace operators
......
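The slice-grad kernel above zero-fills `dx`, takes a sub-memory view at the slice offsets, and reorders `dout` into that view; reshaping `dout`'s own `mem_desc` removes the need to rebuild a descriptor from a format tag. A standalone sketch of that scatter-via-submemory idea (oneDNN 2.x API; dims and offsets are made up for illustration):

```cpp
#include <vector>
#include <dnnl.hpp>

int main() {
  dnnl::engine eng(dnnl::engine::kind::cpu, 0);
  dnnl::stream s(eng);

  // Gradient of slice: scatter a small block back into a zeroed full tensor.
  dnnl::memory::desc full_md({4, 4}, dnnl::memory::data_type::f32,
                             dnnl::memory::format_tag::ab);
  dnnl::memory::desc slice_md({2, 2}, dnnl::memory::data_type::f32,
                              dnnl::memory::format_tag::ab);

  std::vector<float> src(4, 1.0f), dst(16, 0.0f);  // dst pre-zeroed
  dnnl::memory src_mem(slice_md, eng, src.data());

  // View into dst at offset (1, 1) - the analogue of AcquireSubmemory.
  dnnl::memory::desc sub_md = full_md.submemory_desc({2, 2}, {1, 1});
  dnnl::memory sub_mem(sub_md, eng, dst.data());

  // Reorder copies the slice into the strided view of the full buffer.
  dnnl::reorder(src_mem, sub_mem).execute(s, src_mem, sub_mem);
  s.wait();
  return 0;
}
```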
@@ -59,7 +59,7 @@ class StackMKLDNNHandler
     // wrong output format deduction and suboptimal performance as a result
     if (stack_axis != ndims) {
       for (size_t i = 0; i < inputs.size(); ++i) {
-        srcs_md.emplace_back(memory::desc(input_dims, dt, inputs[i]->format()));
+        srcs_md.push_back(inputs[i]->mem_desc());
       }
 
       input_dims[stack_axis] *= inputs.size();
@@ -69,8 +69,7 @@ class StackMKLDNNHandler
       extended_input_dims[stack_axis] = 1;
 
       for (size_t i = 0; i < inputs.size(); ++i) {
-        srcs_md.emplace_back(memory::desc(input_dims, dt, inputs[i]->format())
-                                 .reshape(extended_input_dims));
+        srcs_md.push_back(inputs[i]->mem_desc().reshape(extended_input_dims));
       }
 
     // concat primitive chooses a suboptimal format tag because it cannot
@@ -130,9 +129,8 @@ class StackMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
     concat_p->execute(astream, args);
     astream.wait();
 
-    output->set_layout(DataLayout::kMKLDNN);
-    output->set_format(platform::GetMKLDNNFormat(
-        dst_mem->get_desc().reshape(phi::vectorize(output->dims()))));
+    output->set_mem_desc(
+        dst_mem->get_desc().reshape(phi::vectorize(output->dims())));
   }
 };
 
 }  // namespace operators
......
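In the stack hunks above, stacking is implemented as a concat over source descriptors reshaped to carry an extra size-1 axis at `stack_axis`. Roughly, in plain oneDNN terms (2.x API; shapes are illustrative):

```cpp
#include <vector>
#include <dnnl.hpp>

int main() {
  dnnl::engine eng(dnnl::engine::kind::cpu, 0);

  // Stack two [2, 3] tensors along a new trailing axis: reshape each source
  // descriptor to [2, 3, 1], then concat on axis 2, as the handler does.
  dnnl::memory::desc src_md({2, 3}, dnnl::memory::data_type::f32,
                            dnnl::memory::format_tag::ab);
  std::vector<dnnl::memory::desc> srcs_md = {src_md.reshape({2, 3, 1}),
                                             src_md.reshape({2, 3, 1})};

  // Without an explicit dst_md, concat deduces the destination layout
  // itself - the "format deduction" issue the source comments mention.
  dnnl::concat::primitive_desc concat_pd(/*concat_dimension=*/2, srcs_md, eng);
  (void)concat_pd;
  return 0;
}
```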
@@ -60,17 +60,16 @@ class SumMKLDNNHandler
     auto src_tz = dst_tz;
     std::vector<dnnl::memory::desc> srcs_md;
+    srcs_md.reserve(in_vars.size());
     for (size_t i = 0; i < in_vars.size(); i++) {
       auto& input_it = in_vars[i]->Get<framework::LoDTensor>();
       if (input_it.numel() == 0) {
         continue;
       }
-      MKLDNNMemoryFormat input_format = input_it.format();
-      srcs_md.push_back(dnnl::memory::desc(
-          src_tz, platform::MKLDNNGetDataType<T>(), input_format));
+      srcs_md.push_back(input_it.mem_desc());
       ++num_inputs_;
     }
 
-    std::vector<float> scales(num_inputs_, 1.0);
+    std::vector<float> scales(num_inputs_, 1.0f);
 
     auto dst_md = dnnl::memory::desc(dst_tz, platform::MKLDNNGetDataType<T>(),
                                      MKLDNNMemoryFormat::any);
@@ -139,47 +138,27 @@ class SumMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
       ++input_index;
     }
 
-    std::shared_ptr<dnnl::memory> dst_mem = nullptr;
+    std::unordered_map<int, dnnl::memory> args;
+    std::shared_ptr<dnnl::memory> dst_mem;
+
+    for (size_t i = 0; i < srcs_mem.size(); ++i) {
+      args.insert({DNNL_ARG_MULTIPLE_SRC + i, *(srcs_mem[i])});
+    }
+
     if (in_place) {
-      dst_mem = handler.AcquireDstMemory();
-      output->mutable_data<T>(ctx.GetPlace());
+      dst_mem = srcs_mem[0];
     } else {
       dst_mem = handler.AcquireDstMemory(output);
     }
+    args.insert({DNNL_ARG_DST, *dst_mem});
 
     auto sum_p = handler.AcquireForwardPrimitive();
 
-    std::unordered_map<int, dnnl::memory> args;
-    for (size_t i = 0; i < srcs_mem.size(); ++i) {
-      args.insert({DNNL_ARG_MULTIPLE_SRC + i, *(srcs_mem[i])});
-    }
-    args.insert({DNNL_ARG_DST, *dst_mem});
-
     auto& astream = platform::MKLDNNDeviceContext::tls().get_stream();
     sum_p->execute(astream, args);
     astream.wait();
 
-    // For in-place execution which sum does not have we need to fake it
-    // so from oneDNN dst memory we reorder data into input
-    if (in_place) {
-      auto& in_out = in_vars[0]->Get<framework::LoDTensor>();
-      auto output_tz = phi::vectorize<int64_t>(output->dims());
-      platform::ReorderMKLDNNHandler reorder_handler(
-          output_tz, framework::TransToProtoVarType(output->dtype()),
-          framework::ToMKLDNNDataType(
-              framework::TransToProtoVarType(in_out.dtype())),
-          dev_ctx.GetEngine());
-      auto target_mem = reorder_handler.AcquireDstMemory(
-          output, in_out.format(), ctx.GetPlace());
-      auto reorder_p = reorder_handler.AcquireReorder(target_mem, dst_mem);
-      reorder_p->execute(astream, *dst_mem, *target_mem);
-      astream.wait();
-    }
-
-    output->set_layout(framework::DataLayout::kMKLDNN);
-    output->set_format(platform::GetMKLDNNFormat(*dst_mem));
+    output->set_mem_desc(dst_mem->get_desc());
   }
 };
......
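The sum hunk drops the old fake-in-place path (summing into a temporary, then reordering back into the input) and instead passes the first source memory directly as `DNNL_ARG_DST`, which oneDNN's sum primitive supports when the descriptors match. A condensed standalone version of that in-place pattern (oneDNN 2.x API; values are illustrative):

```cpp
#include <unordered_map>
#include <vector>
#include <dnnl.hpp>

int main() {
  dnnl::engine eng(dnnl::engine::kind::cpu, 0);
  dnnl::stream astream(eng);

  dnnl::memory::desc md({4}, dnnl::memory::data_type::f32,
                        dnnl::memory::format_tag::a);
  std::vector<float> a = {1, 2, 3, 4}, b = {10, 20, 30, 40};
  dnnl::memory mem_a(md, eng, a.data());
  dnnl::memory mem_b(md, eng, b.data());

  std::vector<float> scales = {1.0f, 1.0f};
  dnnl::sum::primitive_desc sum_pd(md, scales, {md, md}, eng);

  // In-place: reuse the first source as the destination, mirroring
  // `dst_mem = srcs_mem[0]` in the kernel above. a becomes a + b.
  std::unordered_map<int, dnnl::memory> args = {
      {DNNL_ARG_MULTIPLE_SRC + 0, mem_a},
      {DNNL_ARG_MULTIPLE_SRC + 1, mem_b},
      {DNNL_ARG_DST, mem_a}};
  dnnl::sum(sum_pd).execute(astream, args);
  astream.wait();
  return 0;
}
```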