diff --git a/paddle/fluid/operators/activation_mkldnn_op.cc b/paddle/fluid/operators/activation_mkldnn_op.cc
index 6f2919255043a19eebb24ea8d2db3ba527777401..482095030fd58382714f98d83f6cc86c7bf5640d 100644
--- a/paddle/fluid/operators/activation_mkldnn_op.cc
+++ b/paddle/fluid/operators/activation_mkldnn_op.cc
@@ -15,6 +15,7 @@
 #include "mkldnn.hpp"
 #include "paddle/fluid/operators/activation_op.h"
 #include "paddle/fluid/operators/mkldnn_activation_op.h"
+#include "paddle/fluid/platform/mkldnn_helper.h"
 
 namespace paddle {
 namespace operators {
@@ -25,9 +26,14 @@ using paddle::platform::MKLDNNDeviceContext;
 namespace {
 std::string gethash(const mkldnn::memory::dims &operand_dims,
                     const mkldnn::algorithm algorithm) {
-  return std::string(std::to_string(operand_dims[0]) + "-" +
-                     std::to_string(operand_dims[1]) + "-" +
-                     std::to_string(algorithm));
+  auto dim2str = [](const mkldnn::memory::dims &operand_dims) {
+    std::string dstr = "";
+    for (size_t i = 0; i < operand_dims.size(); ++i) {
+      dstr += std::to_string(operand_dims[i]) + "-";
+    }
+    return dstr;
+  };
+  return dim2str(operand_dims) + std::to_string(algorithm);
 }
 
 template <typename T, typename ExecContext>
@@ -44,7 +50,7 @@ void eltwise_forward(const ExecContext &ctx, mkldnn::algorithm algorithm,
   const auto *src_data = src->template data<T>();
 
   auto *dst = ctx.template Output<Tensor>("Out");
-  const T *dst_data = dst->template mutable_data<T>(ctx.GetPlace());
+  T *dst_data = dst->template mutable_data<T>(ctx.GetPlace());
 
   // get memory dim
   PADDLE_ENFORCE(src->dims().size() == 2 || src->dims().size() == 4,
@@ -52,15 +58,14 @@ void eltwise_forward(const ExecContext &ctx, mkldnn::algorithm algorithm,
   std::vector<int> src_tz = framework::vectorize2int(src->dims());
 
   const std::string key = gethash(src_tz, algorithm);
-  const std::string key_src_mem = key + "@eltwise_src_mem";
-  const std::string key_dst_mem = key + "@eltwise_dst_mem";
+  const std::string key_src_mem = key + "@eltwise_fwd_src_mem";
+  const std::string key_dst_mem = key + "@eltwise_fwd_dst_mem";
   const std::string key_fwd = key + "@eltwise_fwd";
 
-  std::shared_ptr<void> p_src_mem = dev_ctx.GetBlob(key_src_mem);
-  std::shared_ptr<void> p_dst_mem = dev_ctx.GetBlob(key_dst_mem);
-  std::shared_ptr<void> p_fwd = dev_ctx.GetBlob(key_fwd);
+  auto p_fwd = std::static_pointer_cast<mkldnn::eltwise_forward>(
+      dev_ctx.GetBlob(key_fwd));
 
-  if (p_src_mem == nullptr || p_dst_mem == nullptr || p_fwd == nullptr) {
+  if (p_fwd == nullptr) {
     // create memory description
     auto data_md = src_tz.size() == 2
                        ? platform::MKLDNNMemDesc(src_tz, mkldnn::memory::f32,
@@ -69,35 +74,40 @@ void eltwise_forward(const ExecContext &ctx, mkldnn::algorithm algorithm,
                                                  mkldnn::memory::format::nchw);
 
     // create memory primitives
-    p_src_mem = std::make_shared<mkldnn::memory>(
-        mkldnn::memory({data_md, mkldnn_engine},
-                       static_cast<void *>(const_cast<float *>(src_data))));
+    auto p_src_mem = std::make_shared<mkldnn::memory>(mkldnn::memory(
+        {data_md, mkldnn_engine}, platform::to_void_cast(src_data)));
     dev_ctx.SetBlob(key_src_mem, p_src_mem);
 
-    p_dst_mem = std::make_shared<mkldnn::memory>(
-        mkldnn::memory({data_md, mkldnn_engine},
-                       static_cast<void *>(const_cast<float *>(dst_data))));
+    auto p_dst_mem = std::make_shared<mkldnn::memory>(mkldnn::memory(
+        {data_md, mkldnn_engine}, platform::to_void_cast(dst_data)));
     dev_ctx.SetBlob(key_dst_mem, p_dst_mem);
 
     auto fwd_desc = mkldnn::eltwise_forward::desc(
         mkldnn::prop_kind::forward_training, algorithm, data_md, alpha, beta);
     auto p_fwd_pd = std::make_shared<mkldnn::eltwise_forward::primitive_desc>(
         fwd_desc, mkldnn_engine);
+    const std::string key_fwd_pd = key + "eltwise_fwd_pd";
+    dev_ctx.SetBlob(key_fwd_pd, p_fwd_pd);
     p_fwd = std::make_shared<mkldnn::eltwise_forward>(
-        *(p_fwd_pd.get()), *(static_cast<mkldnn::memory *>(p_src_mem.get())),
-        *(static_cast<mkldnn::memory *>(p_dst_mem.get())));
+        *p_fwd_pd, *(p_src_mem.get()), *(p_dst_mem.get()));
     dev_ctx.SetBlob(key_fwd, p_fwd);
   } else {
-    std::static_pointer_cast<mkldnn::memory>(p_src_mem)->set_data_handle(
-        reinterpret_cast<void *>(const_cast<T *>(src_data)));
-
-    std::static_pointer_cast<mkldnn::memory>(p_dst_mem)->set_data_handle(
-        reinterpret_cast<void *>(const_cast<T *>(dst_data)));
+    // primitives already exist
+    auto p_src_mem =
+        std::static_pointer_cast<mkldnn::memory>(dev_ctx.GetBlob(key_src_mem));
+    PADDLE_ENFORCE(p_src_mem != nullptr,
+                   "Fail to find eltwise p_src_mem in device context.");
+    auto p_dst_mem =
+        std::static_pointer_cast<mkldnn::memory>(dev_ctx.GetBlob(key_dst_mem));
+    PADDLE_ENFORCE(p_dst_mem != nullptr,
+                   "Fail to find eltwise p_dst_mem in device context.");
+
+    p_src_mem->set_data_handle(platform::to_void_reinterpret_cast(src_data));
+    p_dst_mem->set_data_handle(dst_data);
   }
 
   // push primitive to stream and wait until it's executed
-  std::vector<mkldnn::primitive> pipeline = {
-      *(static_cast<mkldnn::eltwise_forward *>(p_fwd.get()))};
+  std::vector<mkldnn::primitive> pipeline = {*(p_fwd.get())};
   mkldnn::stream(mkldnn::stream::kind::eager).submit(pipeline).wait();
 }
 
@@ -121,47 +131,64 @@ void eltwise_grad(const ExecContext &ctx, mkldnn::algorithm algorithm,
   std::vector<int> src_tz = framework::vectorize2int(out->dims());
 
   const std::string key = gethash(src_tz, algorithm);
-  const std::string key_src_mem = key + "@eltwise_src_mem";
-  const std::string key_dst_mem = key + "@eltwise_dst_mem";
-  const std::string key_fwd = key + "@eltwise_fwd";
-
-  // create memory description
-  auto data_md = src_tz.size() == 2
-                     ? platform::MKLDNNMemDesc(src_tz, mkldnn::memory::f32,
-                                               mkldnn::memory::format::nc)
-                     : platform::MKLDNNMemDesc(src_tz, mkldnn::memory::f32,
-                                               mkldnn::memory::format::nchw);
-
-  // retrieve source memory from device context
-  const std::shared_ptr<void> src_mem = dev_ctx.GetBlob(key_src_mem);
-  auto *p_src_mem = static_cast<mkldnn::memory *>(src_mem.get());
-
-  // create memory primitives
-  auto diff_src_memory =
-      mkldnn::memory({data_md, mkldnn_engine},
-                     static_cast<void *>(const_cast<float *>(diff_src)));
-  auto diff_dst_memory =
-      mkldnn::memory({data_md, mkldnn_engine},
-                     static_cast<void *>(const_cast<float *>(diff_dst)));
-
-  auto backward_desc =
-      mkldnn::eltwise_backward::desc(algorithm, data_md, data_md, alpha, beta);
-
-  // retrieve eltwise primitive desc from device context
-  const std::shared_ptr<void> forward_pd = dev_ctx.GetBlob(key_fwd);
-  PADDLE_ENFORCE(forward_pd != nullptr,
-                 "Fail to find eltwise_pd in device context");
-  auto *p_forward_pd =
-      static_cast<mkldnn::eltwise_forward::primitive_desc *>(forward_pd.get());
-
-  auto eltwise_bwd_prim_desc = mkldnn::eltwise_backward::primitive_desc(
-      backward_desc, mkldnn_engine, *p_forward_pd);
-
-  auto eltwise_bwd = mkldnn::eltwise_backward(eltwise_bwd_prim_desc, *p_src_mem,
-                                              diff_dst_memory, diff_src_memory);
+  const std::string key_diff_src_mem = key + "@eltwise_diff_src_mem";
+  const std::string key_diff_dst_mem = key + "@eltwise_diff_dst_mem";
+  const std::string key_grad = key + "@eltwise_grad";
+
+  auto p_grad = std::static_pointer_cast<mkldnn::eltwise_backward>(
+      dev_ctx.GetBlob(key_grad));
+
+  if (p_grad == nullptr) {
+    // create memory description
+    auto data_md = src_tz.size() == 2
+                       ? platform::MKLDNNMemDesc(src_tz, mkldnn::memory::f32,
+                                                 mkldnn::memory::format::nc)
+                       : platform::MKLDNNMemDesc(src_tz, mkldnn::memory::f32,
+                                                 mkldnn::memory::format::nchw);
+
+    // create memory primitives
+    std::shared_ptr<void> p_diff_src_mem =
+        std::make_shared<mkldnn::memory>(mkldnn::memory(
+            {data_md, mkldnn_engine}, platform::to_void_cast(diff_src)));
+    dev_ctx.SetBlob(key_diff_src_mem, p_diff_src_mem);
+    std::shared_ptr<void> p_diff_dst_mem =
+        std::make_shared<mkldnn::memory>(mkldnn::memory(
+            {data_md, mkldnn_engine}, platform::to_void_cast(diff_dst)));
+    dev_ctx.SetBlob(key_diff_dst_mem, p_diff_dst_mem);
+
+    auto bwd_desc = mkldnn::eltwise_backward::desc(algorithm, data_md, data_md,
+                                                   alpha, beta);
+
+    const std::string key_fwd_pd = key + "eltwise_fwd_pd";
+    auto *p_fwd_pd = static_cast<mkldnn::eltwise_forward::primitive_desc *>(
+        dev_ctx.GetBlob(key_fwd_pd).get());
+
+    auto eltwise_bwd_prim_desc = mkldnn::eltwise_backward::primitive_desc(
+        bwd_desc, mkldnn_engine, *p_fwd_pd);
+
+    const std::string key_src_mem = key + "@eltwise_fwd_src_mem";
+    const std::shared_ptr<void> p_src_mem = dev_ctx.GetBlob(key_src_mem);
+
+    p_grad = std::make_shared<mkldnn::eltwise_backward>(
+        eltwise_bwd_prim_desc, *static_cast<mkldnn::memory *>(p_src_mem.get()),
+        *(static_cast<mkldnn::memory *>(p_diff_dst_mem.get())),
+        *(static_cast<mkldnn::memory *>(p_diff_src_mem.get())));
+  } else {
+    // primitives already exist
+    auto p_diff_src_mem = std::static_pointer_cast<mkldnn::memory>(
+        dev_ctx.GetBlob(key_diff_src_mem));
+    auto p_diff_dst_mem = std::static_pointer_cast<mkldnn::memory>(
+        dev_ctx.GetBlob(key_diff_dst_mem));
+
+    p_diff_src_mem->set_data_handle(
+        platform::to_void_reinterpret_cast(diff_src));
+    p_diff_dst_mem->set_data_handle(
+        platform::to_void_reinterpret_cast(diff_dst));
+  }
 
   // push primitive to stream and wait until it's executed
-  std::vector<mkldnn::primitive> pipeline = {eltwise_bwd};
+  std::vector<mkldnn::primitive> pipeline = {*(p_grad.get())};
   mkldnn::stream(mkldnn::stream::kind::eager).submit(pipeline).wait();
 }
 }  // anonymous namespace
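Note on the pattern above: both eltwise_forward and eltwise_grad now follow the same keyed-cache scheme. They build a hash from the input dims and the algorithm, look the primitive up with dev_ctx.GetBlob(), and either create it and SetBlob() it, or recover it with std::static_pointer_cast and only refresh the data handles. The stand-alone sketch below illustrates that scheme; BlobCache and FakePrimitive are hypothetical stand-ins for MKLDNNDeviceContext and mkldnn::eltwise_forward, not Paddle code.

#include <iostream>
#include <memory>
#include <string>
#include <unordered_map>

// Hypothetical stand-in for MKLDNNDeviceContext's blob store.
struct BlobCache {
  void SetBlob(const std::string &key, std::shared_ptr<void> blob) {
    blobs_[key] = std::move(blob);
  }
  std::shared_ptr<void> GetBlob(const std::string &key) const {
    auto it = blobs_.find(key);
    return it == blobs_.end() ? nullptr : it->second;
  }

 private:
  std::unordered_map<std::string, std::shared_ptr<void>> blobs_;
};

// Hypothetical stand-in for a cached MKLDNN primitive.
struct FakePrimitive {
  int id;
};

int main() {
  BlobCache dev_ctx;
  // Key mirrors gethash(): every dim joined with '-', then the algorithm id.
  const std::string key = "1-3-224-224-" + std::to_string(1) + "@eltwise_fwd";

  // First call: nothing cached yet, so create the primitive and store it.
  auto p_fwd = std::static_pointer_cast<FakePrimitive>(dev_ctx.GetBlob(key));
  if (p_fwd == nullptr) {
    p_fwd = std::make_shared<FakePrimitive>(FakePrimitive{42});
    dev_ctx.SetBlob(key, p_fwd);
  }

  // Later call with the same key: the cached primitive is found and reused.
  auto p_cached = std::static_pointer_cast<FakePrimitive>(dev_ctx.GetBlob(key));
  std::cout << (p_cached == p_fwd) << std::endl;  // prints 1
  return 0;
}

In the patch itself the same idea is applied to the mkldnn::memory objects and to the forward primitive_desc (stored under key + "eltwise_fwd_pd"), which eltwise_grad later retrieves to build its backward primitive.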
diff --git a/paddle/fluid/platform/mkldnn_helper.h b/paddle/fluid/platform/mkldnn_helper.h
index 56ed5912a15437b72b769610912c7493d77e5964..f1187620d81ff3bc1deef2106edb54d6199fa927 100644
--- a/paddle/fluid/platform/mkldnn_helper.h
+++ b/paddle/fluid/platform/mkldnn_helper.h
@@ -38,6 +38,11 @@ void* to_void_cast(const Type* t) {
   return static_cast<void*>(const_cast<Type*>(t));
 }
 
+template <typename Type>
+void* to_void_reinterpret_cast(const Type* t) {
+  return reinterpret_cast<void*>(const_cast<Type*>(t));
+}
+
 template <typename Type>
 using tf_desc = typename Type::desc;
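For context, the helper added to mkldnn_helper.h only wraps the cast chain that the removed call sites spelled out by hand. A minimal stand-alone sketch follows; the main() is illustrative only and not part of the patch.

#include <iostream>

// Copies of the two helpers, reproduced outside the Paddle headers so the
// difference between them is visible in isolation.
template <typename Type>
void *to_void_cast(const Type *t) {
  // const_cast drops constness, static_cast converts the pointer to void*.
  return static_cast<void *>(const_cast<Type *>(t));
}

template <typename Type>
void *to_void_reinterpret_cast(const Type *t) {
  // Same intent, but via reinterpret_cast, matching the old
  // reinterpret_cast<void *>(const_cast<T *>(...)) call sites it replaces.
  return reinterpret_cast<void *>(const_cast<Type *>(t));
}

int main() {
  const float data[4] = {1.f, 2.f, 3.f, 4.f};
  // Replaces static_cast<void *>(const_cast<float *>(data)) at call sites.
  void *a = to_void_cast(data);
  // Replaces reinterpret_cast<void *>(const_cast<float *>(data)) at call sites.
  void *b = to_void_reinterpret_cast(data);
  std::cout << (a == b) << std::endl;  // prints 1: both yield the same address
  return 0;
}

Keeping both helpers leaves existing to_void_cast users untouched while the set_data_handle call sites keep the reinterpret_cast semantics they had before the patch.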