diff --git a/paddle/fluid/operators/batch_norm_mkldnn_op.cc b/paddle/fluid/operators/batch_norm_mkldnn_op.cc index 9ab2179b5fe689762704039c5f67dd080e530aa5..de641cb08e4cc3322cc8387d873f2aaab279e1dd 100644 --- a/paddle/fluid/operators/batch_norm_mkldnn_op.cc +++ b/paddle/fluid/operators/batch_norm_mkldnn_op.cc @@ -37,6 +37,95 @@ struct bn_type_traits { using op_prim = typename op_type::primitive_desc; }; +class BatchNormMKLDNNHandler : public platform::MKLDNNHandler { + public: + BatchNormMKLDNNHandler( + std::shared_ptr batch_norm_pd, + const platform::MKLDNNDeviceContext &dev_ctx, mkldnn::engine engine, + const std::string &base_key) + : platform::MKLDNNHandler(dev_ctx, engine, base_key) { + batch_norm_pd_ = batch_norm_pd; + } + + std::shared_ptr AcquireScaleshiftMemoryFromPrimitive(void *ptr) { + return this->AcquireMemoryFromPrimitive( + batch_norm_pd_->weights_primitive_desc(), ptr, "@scaleshift_mem_p"); + } + + std::shared_ptr AcquireMeanMemoryFromPrimitive(void *ptr) { + return this->AcquireMemoryFromPrimitive( + batch_norm_pd_->mean_primitive_desc(), ptr, "@mean_mem_p"); + } + + std::shared_ptr AcquireVarianceMemoryFromPrimitive(void *ptr) { + return this->AcquireMemoryFromPrimitive( + batch_norm_pd_->variance_primitive_desc(), ptr, "@variance_mem_p"); + } + + std::shared_ptr AcquireTestTrainingBatchNormFwd( + std::shared_ptr src_memory, + std::shared_ptr scaleshift_memory, + std::shared_ptr dst_memory, std::shared_ptr mean_memory, + std::shared_ptr variance_memory, bool is_test) { + auto prim_key = key_ + "@batch_norm_p"; + auto batch_norm_p = + std::static_pointer_cast(dev_ctx_.GetBlob(prim_key)); + + PADDLE_ENFORCE((batch_norm_p != nullptr) || !is_reusing_, + "Fail to find batch norm primitive in device context"); + + if (batch_norm_p == nullptr) { + if (is_test) { + batch_norm_p = std::make_shared( + *batch_norm_pd_, *src_memory, + (const mkldnn::primitive::at &)*mean_memory, + (const mkldnn::primitive::at &)*variance_memory, *scaleshift_memory, + *dst_memory); + } else { + batch_norm_p = std::make_shared( + *batch_norm_pd_, *src_memory, *scaleshift_memory, *dst_memory, + *mean_memory, *variance_memory); + } + + dev_ctx_.SetBlob(prim_key, batch_norm_p); + } else { + is_reusing_ = true; + } + + return batch_norm_p; + } + + static std::string GetHash(const memory::dims &input_dims, float epsilon, + unsigned flag, bool is_test, memory::format format, + const std::string &suffix = "") { + auto dims2str = [](const memory::dims &operand_dims) { + std::string dstr = ""; + for (size_t i = 0; i < operand_dims.size(); ++i) { + dstr += std::to_string(operand_dims[i]) + "-"; + } + return dstr; + }; + return dims2str(input_dims) + std::to_string(epsilon) + + std::to_string(flag) + std::to_string(is_test) + + std::to_string(format) + suffix; + } + + private: + std::shared_ptr batch_norm_pd_; +}; + +std::shared_ptr UpdateMemoryData( + const platform::MKLDNNDeviceContext &dev_ctx, const std::string &key, + void *new_ptr) { + auto mem = std::static_pointer_cast(dev_ctx.GetBlob(key)); + PADDLE_ENFORCE( + mem != nullptr, + (std::string("Fail to find memory in device context [key: ") + key + "]") + .c_str()); + mem->set_data_handle(new_ptr); + return mem; +} + template void copy_to_weights(T scale_begin, T scale_end, T shift_begin, T shift_end, Container *c) { @@ -48,15 +137,6 @@ void copy_to_weights(T scale_begin, T scale_end, T shift_begin, T shift_end, std::inserter(*c, std::next(it, std::distance(scale_begin, scale_end)))); } -template -void run_batch_norm_op(Args &&... args) { - Op batch_norm_op{args...}; - - std::vector pipeline; - pipeline.push_back(batch_norm_op); - mkldnn::stream(mkldnn::stream::kind::eager).submit(pipeline).wait(); -} - } // namespace template @@ -110,6 +190,14 @@ class BatchNormMKLDNNOpKernel : public paddle::framework::OpKernel { PADDLE_ENFORCE(scale_tz.size() == 1, "Dims of scale tensor is NOT 1"); const unsigned int ic = scale_tz[0]; + // MKLDNN requires a single piece of memory for scale and shift/bias data + const size_t scaleshift_size = 2 * ic; + std::vector scaleshift_data; + scaleshift_data.reserve(scaleshift_size); + + copy_to_weights(scale->data(), scale->data() + ic, shift->data(), + shift->data() + ic, &scaleshift_data); + unsigned flags = mkldnn::use_scale_shift; if (is_test) flags |= mkldnn::use_global_stats; if (fuse_with_relu) flags |= mkldnn::fuse_bn_relu; @@ -118,64 +206,69 @@ class BatchNormMKLDNNOpKernel : public paddle::framework::OpKernel { mkldnn::memory::format input_format = platform::MKLDNNFormatForSize(src_tz.size(), x->format()); - auto src_memory = memory( - {{{src_tz}, memory::data_type::f32, input_format}, mkldnn_engine}, - to_void_cast(x_data)); + // keys for backward pass + const std::string key = BatchNormMKLDNNHandler::GetHash( + src_tz, epsilon, flags, is_test, input_format, + ctx.op().Output("SavedMean")); + const std::string key_batch_norm_fwd_pd = key + "@bn_fwd_pd"; + + auto user_src_md = platform::MKLDNNMemDesc( + {src_tz}, platform::MKLDNNGetDataType(), input_format); // create primitive descriptor for batch norm forward using bn_fwd_types = bn_type_traits; - auto batch_norm_fwd_desc = bn_fwd_types::op_desc{ - propagation, src_memory.get_primitive_desc().desc(), epsilon, flags}; - std::shared_ptr batch_norm_fwd_pd = - std::shared_ptr( - new batch_norm_fwd::primitive_desc(batch_norm_fwd_desc, - mkldnn_engine)); - - // Save the pd to be used in backward pass - const std::string key = ctx.op().Output("SavedMean"); - const std::string key_batch_norm_fwd_pd = key + "@bn_fwd_pd"; + auto batch_norm_fwd_desc = + bn_fwd_types::op_desc{propagation, user_src_md, epsilon, flags}; + auto batch_norm_fwd_pd = std::make_shared( + batch_norm_fwd_desc, mkldnn_engine); + // Save conv_pd/src_memory/weights_memory for backward pass dev_ctx.SetBlob(key_batch_norm_fwd_pd, batch_norm_fwd_pd); - // MKLDNN requires a single piece of memory for scale and shift/bias data - const size_t scaleshift_size = 2 * ic; - std::vector scaleshift_data; - scaleshift_data.reserve(scaleshift_size); + BatchNormMKLDNNHandler handler(batch_norm_fwd_pd, dev_ctx, mkldnn_engine, + key); - copy_to_weights(scale->data(), scale->data() + ic, shift->data(), - shift->data() + ic, &scaleshift_data); + auto src_memory = + handler.AcquireSrcMemory(user_src_md, to_void_cast(x_data)); // crate mkldnn memory for weights(scale/shift) - auto scaleshift_memory = memory(batch_norm_fwd_pd->weights_primitive_desc(), - scaleshift_data.data()); + auto scaleshift_memory = + handler.AcquireScaleshiftMemoryFromPrimitive(scaleshift_data.data()); // create mkldnn memory for output y tensor - auto dst_memory = memory(batch_norm_fwd_pd->dst_primitive_desc(), y_data); + auto dst_memory = handler.AcquireDstMemory( + batch_norm_fwd_pd->dst_primitive_desc().desc(), y_data); + std::shared_ptr batch_norm_p; if (is_test) { // create mkldnn memory for stats (as input) - auto mean_memory = memory(batch_norm_fwd_pd->mean_primitive_desc(), - to_void_cast(mean_data)); - auto variance_memory = - memory(batch_norm_fwd_pd->variance_primitive_desc(), - to_void_cast(variance_data)); - - run_batch_norm_op( - *batch_norm_fwd_pd, src_memory, - (const mkldnn::primitive::at &)mean_memory, - (const mkldnn::primitive::at &)variance_memory, scaleshift_memory, - dst_memory); + std::shared_ptr mean_memory = + handler.AcquireMeanMemoryFromPrimitive(to_void_cast(mean_data)); + std::shared_ptr variance_memory = + handler.AcquireVarianceMemoryFromPrimitive( + to_void_cast(variance_data)); + + batch_norm_p = handler.AcquireTestTrainingBatchNormFwd( + src_memory, scaleshift_memory, dst_memory, mean_memory, + variance_memory, true); } else { // create mkldnn memory for stats (as output) - auto mean_memory = - memory(batch_norm_fwd_pd->mean_primitive_desc(), batch_mean_data); - auto variance_memory = memory( - batch_norm_fwd_pd->variance_primitive_desc(), batch_variance_data); - - run_batch_norm_op(*batch_norm_fwd_pd, src_memory, - scaleshift_memory, dst_memory, - mean_memory, variance_memory); + std::shared_ptr mean_memory = + handler.AcquireMeanMemoryFromPrimitive(batch_mean_data); + std::shared_ptr variance_memory = + handler.AcquireVarianceMemoryFromPrimitive(batch_variance_data); + + batch_norm_p = handler.AcquireTestTrainingBatchNormFwd( + src_memory, scaleshift_memory, dst_memory, mean_memory, + variance_memory, false); } + y->set_layout(DataLayout::kMKLDNN); + y->set_format(platform::GetMKLDNNFormat(*dst_memory)); + + std::vector pipeline; + pipeline.push_back(*batch_norm_p); + mkldnn::stream(mkldnn::stream::kind::eager).submit(pipeline).wait(); + if (!is_test) { // mkldnn only compute stats for current batch // so we need compute momentum stats via Eigen lib @@ -192,10 +285,6 @@ class BatchNormMKLDNNOpKernel : public paddle::framework::OpKernel { running_variance_e = variance_e * momentum + batch_variance_e * one_minus_momentum; } - - y->set_layout(DataLayout::kMKLDNN); - y->set_format( - (memory::format)dst_memory.get_primitive_desc().desc().data.format); } }; @@ -242,61 +331,48 @@ class BatchNormMKLDNNGradOpKernel : public paddle::framework::OpKernel { const unsigned int ic = scale_tz[0]; - // Retrieve bn_fwd_pd from device context - const std::string key = ctx.op().Input("SavedMean"); - const std::string key_batch_norm_fwd_pd = key + "@bn_fwd_pd"; - auto batch_norm_fwd_pd = - std::static_pointer_cast( - dev_ctx.GetBlob(key_batch_norm_fwd_pd)); - PADDLE_ENFORCE(batch_norm_fwd_pd != nullptr, - "Fail to find batch_norm_fwd_pd in device context"); - using bn_bwd_types = bn_type_traits; - // create mkldnn memory from input diff_y tensor - mkldnn::memory::format dst_format = platform::MKLDNNFormatForSize(src_tz.size(), diff_y->format()); - auto user_diff_dst_memory = memory( - {{{diff_dst_tz}, memory::data_type::f32, dst_format}, mkldnn_engine}, - to_void_cast(diff_y_data)); - - // create mkldnn memory from input x tensor mkldnn::memory::format input_format = platform::MKLDNNFormatForSize(src_tz.size(), x->format()); - auto src_memory = memory( - {{{src_tz}, memory::data_type::f32, input_format}, mkldnn_engine}, - to_void_cast(x_data)); + unsigned flags = mkldnn::use_scale_shift; - // for diff_dst, try to use same format as dst in forward pass - auto diff_dst_pd = batch_norm_fwd_pd.get()->dst_primitive_desc(); - auto diff_dst_md = diff_dst_pd.desc(); + // keys from forward pass + const std::string key = BatchNormMKLDNNHandler::GetHash( + src_tz, epsilon, flags, false, input_format, + ctx.op().Input("SavedMean")); + const std::string key_batch_norm_fwd_pd = key + "@bn_fwd_pd"; + + // keys for primitives reuse + const std::string key_with_hash = + key + BatchNormMKLDNNHandler::GetHash(src_tz, epsilon, flags, false, + input_format); + const std::string key_batch_norm_bwd_p = + key_with_hash + "@batch_norm_bwd_p"; + const std::string key_batch_norm_src_mem_p = + key_with_hash + "@batch_norm_bwd_src_mem_p"; + const std::string key_batch_norm_mean_mem_p = + key_with_hash + "@batch_norm_bwd_mean_mem_p"; + const std::string key_batch_norm_variance_mem_p = + key_with_hash + "@batch_norm_bwd_variance_mem_p"; + const std::string key_batch_norm_scaleshift_mem_p = + key_with_hash + "@batch_norm_bwd_scaleshift_mem_p"; + const std::string key_batch_norm_diff_scaleshift_mem_p = + key_with_hash + "@batch_norm_bwd_diff_scaleshift_mem_p"; + const std::string key_batch_norm_diff_src_mem_p = + key_with_hash + "@batch_norm_bwd_diff_src_mem_p"; + const std::string key_batch_norm_diff_dst_mem_p = + key_with_hash + "@batch_norm_bwd_diff_dst_mem_p"; - // create primitive descriptor for batch norm backward - unsigned flags = mkldnn::use_scale_shift; - auto batch_norm_bwd_desc = bn_bwd_types::op_desc{ - mkldnn::prop_kind::backward, diff_dst_md, - src_memory.get_primitive_desc().desc(), epsilon, flags}; - auto batch_norm_bwd_pd = bn_bwd_types::op_prim{ - batch_norm_bwd_desc, mkldnn_engine, *batch_norm_fwd_pd}; - - // reorder user_diff_dst if it's not in preferred format - auto diff_dst_memory = user_diff_dst_memory; primitive reorder_diff_dst; bool is_diff_dst_reordered = false; - if (diff_dst_pd != user_diff_dst_memory.get_primitive_desc()) { - diff_dst_memory = memory(diff_dst_pd); - reorder_diff_dst = reorder(user_diff_dst_memory, diff_dst_memory); - is_diff_dst_reordered = true; - } - - // create mkldnn memory for input tensors (src/mean/variance) - auto mean_memory = memory(batch_norm_bwd_pd.mean_primitive_desc(), - to_void_cast(batch_mean_data)); - auto variance_memory = memory(batch_norm_bwd_pd.variance_primitive_desc(), - to_void_cast(batch_variance_data)); + auto user_diff_dst_memory = memory( + {{{diff_dst_tz}, memory::data_type::f32, dst_format}, mkldnn_engine}, + to_void_cast(diff_y_data)); // MKLDNN requires a single piece of memory for scale and shift/bias data const size_t scaleshift_size = 2 * ic; @@ -306,30 +382,118 @@ class BatchNormMKLDNNGradOpKernel : public paddle::framework::OpKernel { copy_to_weights(scale_data, scale_data + ic, shift_data, shift_data + ic, &scaleshift_data); - // create mkldnn memory for input tensors (scale/shift) - auto scaleshift_memory = memory(batch_norm_bwd_pd.weights_primitive_desc(), - scaleshift_data.data()); - - // create mkldnn memory for output diff weights (combined scale/shift) std::vector diff_scaleshift_data; diff_scaleshift_data.reserve(scaleshift_size); - auto diff_scaleshift_memory = - memory(batch_norm_bwd_pd.diff_weights_primitive_desc(), - diff_scaleshift_data.data()); - // here assume diff_src is in the same format of src - auto diff_src_memory = memory(src_memory.get_primitive_desc(), diff_x_data); + auto batch_norm_fwd_pd = + std::static_pointer_cast( + dev_ctx.GetBlob(key_batch_norm_fwd_pd)); + PADDLE_ENFORCE(batch_norm_fwd_pd != nullptr, + "Fail to find batch_norm_fwd_pd in device context"); - // finally create batch_norm backward primitive - auto batch_norm_bwd_prim = - batch_norm_bwd(batch_norm_bwd_pd, src_memory, mean_memory, - variance_memory, diff_dst_memory, scaleshift_memory, - diff_src_memory, diff_scaleshift_memory); + auto batch_norm_bwd_p = std::static_pointer_cast( + dev_ctx.GetBlob(key_batch_norm_bwd_p)); + + if (batch_norm_bwd_p == nullptr) { + auto src_memory = std::shared_ptr(new memory( + {{{src_tz}, memory::data_type::f32, input_format}, mkldnn_engine}, + to_void_cast(x_data))); + + // for diff_dst, try to use same format as dst in forward pass + auto diff_dst_pd = batch_norm_fwd_pd.get()->dst_primitive_desc(); + auto diff_dst_md = diff_dst_pd.desc(); + + // create primitive descriptor for batch norm backward + auto batch_norm_bwd_desc = bn_bwd_types::op_desc{ + mkldnn::prop_kind::backward, diff_dst_md, + src_memory->get_primitive_desc().desc(), epsilon, flags}; + auto batch_norm_bwd_pd = bn_bwd_types::op_prim{ + batch_norm_bwd_desc, mkldnn_engine, *batch_norm_fwd_pd}; + + // reorder user_diff_dst if it's not in preferred format + auto diff_dst_memory = std::make_shared(user_diff_dst_memory); + if (diff_dst_pd != user_diff_dst_memory.get_primitive_desc()) { + diff_dst_memory = std::make_shared(diff_dst_pd); + reorder_diff_dst = reorder(user_diff_dst_memory, *diff_dst_memory); + is_diff_dst_reordered = true; + } + + // create mkldnn memory for input tensors (src/mean/variance) + auto mean_memory = + std::make_shared(batch_norm_bwd_pd.mean_primitive_desc(), + to_void_cast(batch_mean_data)); + auto variance_memory = + std::make_shared(batch_norm_bwd_pd.variance_primitive_desc(), + to_void_cast(batch_variance_data)); + + // create mkldnn memory for input tensors (scale/shift) + auto scaleshift_memory = std::make_shared( + batch_norm_bwd_pd.weights_primitive_desc(), scaleshift_data.data()); + + // create mkldnn memory for output diff weights (combined scale/shift) + auto diff_scaleshift_memory = std::make_shared( + batch_norm_bwd_pd.diff_weights_primitive_desc(), + diff_scaleshift_data.data()); + + // here assume diff_src is in the same format of src + auto diff_src_memory = std::make_shared( + src_memory->get_primitive_desc(), diff_x_data); + + // finally create batch_norm backward primitive + batch_norm_bwd_p = std::make_shared( + batch_norm_bwd_pd, *src_memory, *mean_memory, *variance_memory, + *diff_dst_memory, *scaleshift_memory, *diff_src_memory, + *diff_scaleshift_memory); + + dev_ctx.SetBlob(key_batch_norm_bwd_p, batch_norm_bwd_p); + dev_ctx.SetBlob(key_batch_norm_src_mem_p, src_memory); + dev_ctx.SetBlob(key_batch_norm_mean_mem_p, mean_memory); + dev_ctx.SetBlob(key_batch_norm_variance_mem_p, variance_memory); + dev_ctx.SetBlob(key_batch_norm_scaleshift_mem_p, scaleshift_memory); + dev_ctx.SetBlob(key_batch_norm_diff_scaleshift_mem_p, + diff_scaleshift_memory); + dev_ctx.SetBlob(key_batch_norm_diff_src_mem_p, diff_src_memory); + dev_ctx.SetBlob(key_batch_norm_diff_dst_mem_p, diff_dst_memory); + + // set layout/format of output tensors + diff_x->set_layout(DataLayout::kMKLDNN); + diff_x->set_format((memory::format)diff_src_memory->get_primitive_desc() + .desc() + .data.format); + } else { + // primitives already exist + UpdateMemoryData(dev_ctx, key_batch_norm_src_mem_p, to_void_cast(x_data)); + UpdateMemoryData(dev_ctx, key_batch_norm_mean_mem_p, + to_void_cast(batch_mean_data)); + UpdateMemoryData(dev_ctx, key_batch_norm_variance_mem_p, + to_void_cast(batch_variance_data)); + UpdateMemoryData(dev_ctx, key_batch_norm_scaleshift_mem_p, + scaleshift_data.data()); + UpdateMemoryData(dev_ctx, key_batch_norm_diff_scaleshift_mem_p, + diff_scaleshift_data.data()); + auto diff_src_memory = UpdateMemoryData( + dev_ctx, key_batch_norm_diff_src_mem_p, to_void_cast(diff_x_data)); + auto diff_dst_memory = UpdateMemoryData( + dev_ctx, key_batch_norm_diff_dst_mem_p, to_void_cast(diff_y_data)); + + // reorder user_diff_dst if it's not in preferred format + if (diff_dst_memory->get_primitive_desc() != + user_diff_dst_memory.get_primitive_desc()) { + reorder_diff_dst = reorder(user_diff_dst_memory, *diff_dst_memory); + is_diff_dst_reordered = true; + } + + // set layout/format of output tensors + diff_x->set_layout(DataLayout::kMKLDNN); + diff_x->set_format((memory::format)diff_src_memory->get_primitive_desc() + .desc() + .data.format); + } // execute optional reorder and batch_norm backward primitive std::vector pipeline; if (is_diff_dst_reordered) pipeline.push_back(reorder_diff_dst); - pipeline.push_back(batch_norm_bwd_prim); + pipeline.push_back(*batch_norm_bwd_p); stream(stream::kind::eager).submit(pipeline).wait(); // copy back diff sacle/shift to output tensors (diff scale/shift) @@ -338,12 +502,6 @@ class BatchNormMKLDNNGradOpKernel : public paddle::framework::OpKernel { std::copy(it, std::next(it, ic), diff_scale_data); std::copy(std::next(it, ic), std::end(diff_scaleshift_data), diff_shift_data); - - // set layout/format of output tensors - diff_x->set_layout(DataLayout::kMKLDNN); - diff_x->set_format((memory::format)diff_src_memory.get_primitive_desc() - .desc() - .data.format); } }; } // namespace operators