Unverified  Commit 657abd51 authored by jakpiase, committed by GitHub

OneDNN md-in-tensor refactoring part 4: Memory descriptor enabled for more ops (#42946)

* added support for md in more ops

* fixed typo
Parent c6f98fa0
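For context: the pattern throughout this commit replaces the set_layout()/set_format() pair, which stored only a layout enum plus a format tag on the tensor, with set_mem_desc(), which stores the complete oneDNN memory descriptor (dims, data type, and layout). A minimal sketch of what such a descriptor carries, assuming the oneDNN v2.x C++ API that Paddle used at the time (dims are illustrative):

#include <cassert>
#include "dnnl.hpp"

int main() {
  // dims + data type + format tag fully describe the buffer layout,
  // so no separate set_layout()/set_format() bookkeeping is needed
  dnnl::memory::desc md({8, 3, 224, 224}, dnnl::memory::data_type::f32,
                        dnnl::memory::format_tag::nchw);
  // the descriptor knows its own byte size
  assert(md.get_size() == 8 * 3 * 224 * 224 * sizeof(float));
  return 0;
}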
@@ -79,8 +79,10 @@ class FillConstantMKLDNNKernel : public framework::OpKernel<T> {
{DNNL_ARG_DST, *src0_memory_p}});
astream.wait();
-    out->set_layout(framework::DataLayout::kMKLDNN);
-    out->set_format(platform::GetPlainMKLDNNFormat(out->dims().size()));
+    // src0_memory_p's md was just to allow the usage of a binary
+    // primitive as a memset, and now we need to create a real one
+    out->set_mem_desc({phi::vectorize(shape), platform::MKLDNNGetDataType<T>(),
+                       platform::GetPlainMKLDNNFormat(shape.size())});
}
T CalculateFillValue(const framework::ExecutionContext& ctx) const {
@@ -124,7 +124,7 @@ class LRNMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
auto& astream = platform::MKLDNNDeviceContext::tls().get_stream();
if (!workspace_memory->get_desc().is_zero()) {
-      mid->set_format(platform::GetMKLDNNFormat(*workspace_memory));
+      mid->set_mem_desc(workspace_memory->get_desc());
lrn_p->execute(astream, {{DNNL_ARG_SRC, *src_memory},
{DNNL_ARG_DST, *dst_memory},
{DNNL_ARG_WORKSPACE, *workspace_memory}});
@@ -134,8 +134,7 @@ class LRNMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
}
astream.wait();
-    out->set_layout(framework::DataLayout::kMKLDNN);
-    out->set_format(platform::GetMKLDNNFormat(*dst_memory));
+    out->set_mem_desc(dst_memory->get_desc());
}
};
@@ -177,8 +176,7 @@ class LRNMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
{DNNL_ARG_WORKSPACE, *workspace}});
astream.wait();
-    in_x_grad->set_layout(framework::DataLayout::kMKLDNN);
-    in_x_grad->set_format(platform::GetMKLDNNFormat(*diff_src_memory));
+    in_x_grad->set_mem_desc(diff_src_memory->get_desc());
}
};
} // namespace operators
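For context: LRN forward in training mode produces an extra workspace tensor that the backward kernel consumes, which is why mid now carries the workspace's memory descriptor directly. A small sketch of where that descriptor comes from, assuming the oneDNN v2.x API (all dims and hyperparameters illustrative):

#include "dnnl.hpp"

int main() {
  dnnl::engine eng(dnnl::engine::kind::cpu, 0);
  dnnl::memory::desc src_md({1, 16, 8, 8}, dnnl::memory::data_type::f32,
                            dnnl::memory::format_tag::nchw);
  // forward_training may expose a workspace for the backward pass; with
  // forward_inference the descriptor would be zero, which is what the
  // kernel's is_zero() check above distinguishes
  dnnl::lrn_forward::desc d(dnnl::prop_kind::forward_training,
                            dnnl::algorithm::lrn_across_channels, src_md,
                            /*local_size=*/5, /*alpha=*/1e-4f,
                            /*beta=*/0.75f, /*k=*/1.0f);
  dnnl::lrn_forward::primitive_desc pd(d, eng);
  return pd.workspace_desc().is_zero() ? 1 : 0;
}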
@@ -175,19 +175,17 @@ class SliceGradMKLDNNKernel : public framework::OpKernel<T> {
dnnl::memory::data_type dout_type = framework::ToMKLDNNDataType(
framework::TransToProtoVarType(dout->dtype()));
-    dnnl::memory::desc md(dout_vec_dims, platform::MKLDNNGetDataType<T>(),
-                          dout->format());
-    dnnl::memory::format_tag reorder_format_tag =
-        platform::GetMKLDNNFormat(md.reshape(slice_dims));
platform::ReorderMKLDNNHandler reorder_handler(
slice_dims, framework::TransToProtoVarType(dout->dtype()), dout_type,
onednn_engine);
auto reorder_src_memory_p = reorder_handler.AcquireSrcMemory(
-        reorder_format_tag, platform::to_void_cast(dout->data<T>()));
+        dout->mem_desc().reshape(slice_dims),
+        platform::to_void_cast(dout->data<T>()));
auto reorder_dst_memory_p = reorder_handler.AcquireDstMemory(
-        dx, dx_vec_dims, reorder_format_tag, ctx.GetPlace());
+        dx, dx_vec_dims, platform::GetPlainMKLDNNFormat(dx_vec_dims.size()),
+        ctx.GetPlace());
memset(dx->data<T>(), 0, reorder_dst_memory_p->get_desc().get_size());
auto slice_mem_p = reorder_handler.AcquireSubmemory(slice_dims, offsets,
@@ -199,8 +197,7 @@ class SliceGradMKLDNNKernel : public framework::OpKernel<T> {
reorder_p->execute(astream, *reorder_src_memory_p, *slice_mem_p);
astream.wait();
-    dx->set_layout(framework::DataLayout::kMKLDNN);
-    dx->set_format(reorder_format_tag);
+    dx->set_mem_desc(reorder_dst_memory_p->get_desc());
}
};
} // namespace operators
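For context: dnnl::memory::desc::reshape() reinterprets the same dense buffer under new dims without moving any data, which is what lets the kernel above reuse dout's descriptor for the sliced view instead of re-deriving a format tag. A minimal sketch, assuming the oneDNN v2.x API (dims illustrative):

#include "dnnl.hpp"

int main() {
  dnnl::memory::desc md({2, 3, 4}, dnnl::memory::data_type::f32,
                        dnnl::memory::format_tag::abc);
  // same bytes, new logical dims; reshape() throws if the layout cannot
  // be reinterpreted without a reorder
  dnnl::memory::desc reshaped = md.reshape({6, 4});
  return reshaped.get_size() == md.get_size() ? 0 : 1;
}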
@@ -59,7 +59,7 @@ class StackMKLDNNHandler
// wrong output format deduction and suboptimal performance as a result
if (stack_axis != ndims) {
for (size_t i = 0; i < inputs.size(); ++i) {
-        srcs_md.emplace_back(memory::desc(input_dims, dt, inputs[i]->format()));
+        srcs_md.push_back(inputs[i]->mem_desc());
}
input_dims[stack_axis] *= inputs.size();
@@ -69,8 +69,7 @@ class StackMKLDNNHandler
extended_input_dims[stack_axis] = 1;
for (size_t i = 0; i < inputs.size(); ++i) {
-        srcs_md.emplace_back(memory::desc(input_dims, dt, inputs[i]->format())
-                                 .reshape(extended_input_dims));
+        srcs_md.push_back(inputs[i]->mem_desc().reshape(extended_input_dims));
}
// concat primitive choses suboptimal format tag because it cannot
@@ -130,9 +129,8 @@ class StackMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
concat_p->execute(astream, args);
astream.wait();
-    output->set_layout(DataLayout::kMKLDNN);
-    output->set_format(platform::GetMKLDNNFormat(
-        dst_mem->get_desc().reshape(phi::vectorize(output->dims()))));
+    output->set_mem_desc(
+        dst_mem->get_desc().reshape(phi::vectorize(output->dims())));
}
};
} // namespace operators
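For context: stack is implemented here as concat over an inserted axis of size 1, with mem_desc().reshape(extended_input_dims) supplying the per-input descriptors. A small self-contained sketch of that trick, assuming the oneDNN v2.x API (shapes and values illustrative):

#include <vector>
#include "dnnl.hpp"

int main() {
  dnnl::engine eng(dnnl::engine::kind::cpu, 0);
  dnnl::stream strm(eng);
  dnnl::memory::desc src_md({2, 3}, dnnl::memory::data_type::f32,
                            dnnl::memory::format_tag::ab);
  // view each 2x3 input as 1x2x3 so that concat on axis 0 acts as stack
  dnnl::memory::desc ext_md = src_md.reshape({1, 2, 3});
  std::vector<float> a(6, 1.0f), b(6, 2.0f), out(12, 0.0f);
  dnnl::memory src0(ext_md, eng, a.data()), src1(ext_md, eng, b.data());
  dnnl::concat::primitive_desc pd(/*concat_dimension=*/0, {ext_md, ext_md},
                                  eng);
  dnnl::memory dst(pd.dst_desc(), eng, out.data());
  dnnl::concat(pd).execute(strm, {{DNNL_ARG_MULTIPLE_SRC + 0, src0},
                                  {DNNL_ARG_MULTIPLE_SRC + 1, src1},
                                  {DNNL_ARG_DST, dst}});
  strm.wait();
  return (out[0] == 1.0f && out[6] == 2.0f) ? 0 : 1;
}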
@@ -60,17 +60,16 @@ class SumMKLDNNHandler
auto src_tz = dst_tz;
std::vector<dnnl::memory::desc> srcs_md;
+    srcs_md.reserve(in_vars.size());
for (size_t i = 0; i < in_vars.size(); i++) {
auto& input_it = in_vars[i]->Get<framework::LoDTensor>();
if (input_it.numel() == 0) {
continue;
}
-      MKLDNNMemoryFormat input_format = input_it.format();
-      srcs_md.push_back(dnnl::memory::desc(
-          src_tz, platform::MKLDNNGetDataType<T>(), input_format));
+      srcs_md.push_back(input_it.mem_desc());
++num_inputs_;
}
-    std::vector<float> scales(num_inputs_, 1.0);
+    std::vector<float> scales(num_inputs_, 1.0f);
auto dst_md = dnnl::memory::desc(dst_tz, platform::MKLDNNGetDataType<T>(),
MKLDNNMemoryFormat::any);
@@ -139,47 +138,27 @@ class SumMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
++input_index;
}
-    std::shared_ptr<dnnl::memory> dst_mem = nullptr;
+    std::unordered_map<int, dnnl::memory> args;
+    std::shared_ptr<dnnl::memory> dst_mem;
+    for (size_t i = 0; i < srcs_mem.size(); ++i) {
+      args.insert({DNNL_ARG_MULTIPLE_SRC + i, *(srcs_mem[i])});
+    }
if (in_place) {
-      dst_mem = handler.AcquireDstMemory();
-      output->mutable_data<T>(ctx.GetPlace());
+      dst_mem = srcs_mem[0];
} else {
dst_mem = handler.AcquireDstMemory(output);
}
+    args.insert({DNNL_ARG_DST, *dst_mem});
auto sum_p = handler.AcquireForwardPrimitive();
-    std::unordered_map<int, dnnl::memory> args;
-    for (size_t i = 0; i < srcs_mem.size(); ++i) {
-      args.insert({DNNL_ARG_MULTIPLE_SRC + i, *(srcs_mem[i])});
-    }
-    args.insert({DNNL_ARG_DST, *dst_mem});
auto& astream = platform::MKLDNNDeviceContext::tls().get_stream();
sum_p->execute(astream, args);
astream.wait();
-    // For in-place execution which sum does not have we need to fake it
-    // so from oneDNN dst memory we reorder data into input
-    if (in_place) {
-      auto& in_out = in_vars[0]->Get<framework::LoDTensor>();
-      auto output_tz = phi::vectorize<int64_t>(output->dims());
-      platform::ReorderMKLDNNHandler reorder_handler(
-          output_tz, framework::TransToProtoVarType(output->dtype()),
-          framework::ToMKLDNNDataType(
-              framework::TransToProtoVarType(in_out.dtype())),
-          dev_ctx.GetEngine());
-      auto target_mem = reorder_handler.AcquireDstMemory(
-          output, in_out.format(), ctx.GetPlace());
-      auto reorder_p = reorder_handler.AcquireReorder(target_mem, dst_mem);
-      reorder_p->execute(astream, *dst_mem, *target_mem);
-      astream.wait();
-    }
-    output->set_layout(framework::DataLayout::kMKLDNN);
-    output->set_format(platform::GetMKLDNNFormat(*dst_mem));
+    output->set_mem_desc(dst_mem->get_desc());
}
};
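For context: oneDNN's sum primitive supports true in-place execution when the destination aliases the first source, which is why the reorder-based workaround above could be dropped in favor of dst_mem = srcs_mem[0]. A minimal sketch, assuming the oneDNN v2.x API (shapes and values illustrative):

#include <vector>
#include "dnnl.hpp"

int main() {
  dnnl::engine eng(dnnl::engine::kind::cpu, 0);
  dnnl::stream strm(eng);
  dnnl::memory::desc md({4}, dnnl::memory::data_type::f32,
                        dnnl::memory::format_tag::a);
  std::vector<float> a{1, 2, 3, 4}, b{10, 20, 30, 40};
  dnnl::memory src0(md, eng, a.data()), src1(md, eng, b.data());
  dnnl::sum::primitive_desc pd(md, {1.0f, 1.0f}, {md, md}, eng);
  // dst aliases src0, mirroring "dst_mem = srcs_mem[0]" in the kernel
  dnnl::sum(pd).execute(strm, {{DNNL_ARG_MULTIPLE_SRC + 0, src0},
                               {DNNL_ARG_MULTIPLE_SRC + 1, src1},
                               {DNNL_ARG_DST, src0}});
  strm.wait();
  return (a[0] == 11.0f) ? 0 : 1;
}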