未验证 提交 07790ba1 编写于 作者: J Jacek Czaja 提交者: GitHub

[oneDNN] Reimplemented elementwise_add grad (#29747)

* - Reimplemented elementwise_add grad

- lint

* - fix after review

* - Fix to fix after review
上级 6ef8129d
......@@ -33,27 +33,45 @@ class EltwiseAddMKLDNNGradKernel : public ElemwiseGradKernel<T> {
ElemwiseGradKernel<T>::Compute(ctx);
using Tensor = framework::Tensor;
auto& dev_ctx =
ctx.template device_context<paddle::platform::MKLDNNDeviceContext>();
const auto& onednn_engine = dev_ctx.GetEngine();
auto* dout = ctx.Input<Tensor>(framework::GradVarName("Out"));
auto* dx = ctx.Output<Tensor>(framework::GradVarName("X"));
auto* dy = ctx.Output<Tensor>(framework::GradVarName("Y"));
auto set_mkldnn_format = [](Tensor* in, const Tensor* out) {
in->set_layout(DataLayout::kMKLDNN);
in->set_format(out->format());
};
auto tz = paddle::framework::vectorize<int64_t>(dout->dims());
memory::data_type dout_type = framework::ToMKLDNNDataType(dout->type());
std::string key = platform::CreateKey(dev_ctx, tz, dout->format(),
dout->format(), dout_type);
platform::ReorderMKLDNNHandler handler(tz, dout->type(), dout_type, dev_ctx,
onednn_engine, key);
mkldnn::stream astream(onednn_engine);
auto reorder_src_memory_p = handler.AcquireSrcMemory(
dout->format(), platform::to_void_cast(dout->data<T>()));
// TODO(jczaja): Double check if vcopy works for blocked data
auto blas = math::GetBlas<paddle::platform::CPUDeviceContext, T>(ctx);
if (dx) {
blas.VCOPY(dout->numel(), dout->data<T>(),
dx->mutable_data<T>(ctx.GetPlace()));
set_mkldnn_format(dx, dout);
auto reorder_dst_memory_p =
handler.AcquireDstMemory(dx, dout->format(), ctx.GetPlace());
auto reorder_p =
handler.AcquireReorder(reorder_dst_memory_p, reorder_src_memory_p);
platform::RecordEvent record_reorder("int_reorder",
platform::EventRole::kUniqueOp);
reorder_p->execute(astream, *reorder_src_memory_p, *reorder_dst_memory_p);
astream.wait();
}
if (dy) {
blas.VCOPY(dout->numel(), dout->data<T>(),
dy->mutable_data<T>(ctx.GetPlace()));
set_mkldnn_format(dy, dout);
auto reorder_dst_memory_p =
handler.AcquireDstMemory(dy, dout->format(), ctx.GetPlace());
auto reorder_p =
handler.AcquireReorder(reorder_dst_memory_p, reorder_src_memory_p);
platform::RecordEvent record_reorder("int_reorder",
platform::EventRole::kUniqueOp);
reorder_p->execute(astream, *reorder_src_memory_p, *reorder_dst_memory_p);
astream.wait();
}
}
};
......
......@@ -1054,13 +1054,14 @@ class ReorderMKLDNNHandler : public MKLDNNHandler {
std::static_pointer_cast<mkldnn::memory>(dev_ctx_.GetBlob(local_key));
if (mem_p == nullptr) {
auto dst_md = platform::MKLDNNMemDesc(dims_, dtype_, fmt);
auto dst_data = output->mutable_data(place, vtype_);
auto dst_data = output->mutable_data(place, vtype_, dst_md.get_size());
mem_p = std::make_shared<mkldnn::memory>(dst_md, engine_, dst_data);
dev_ctx_.SetBlob(local_key, mem_p);
} else {
auto dst_data = output->mutable_data(place, vtype_);
// Even if memory object exists , we may be using it for diffrent tensor
auto dst_data =
output->mutable_data(place, vtype_, mem_p->get_desc().get_size());
mem_p->set_data_handle(dst_data);
}
return mem_p;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册