diff --git a/paddle/fluid/operators/mkldnn/batch_norm_mkldnn_op.cc b/paddle/fluid/operators/mkldnn/batch_norm_mkldnn_op.cc
index f3209151b359aaba52d8bd5259013d79f130096d..6b1c870c3c148e169c32a7c04134a9134845079c 100644
--- a/paddle/fluid/operators/mkldnn/batch_norm_mkldnn_op.cc
+++ b/paddle/fluid/operators/mkldnn/batch_norm_mkldnn_op.cc
@@ -19,136 +19,103 @@ limitations under the License. */
 namespace paddle {
 namespace operators {
 
-using batch_norm_bwd = mkldnn::batch_normalization_backward;
-using batch_norm_fwd = mkldnn::batch_normalization_forward;
 using mkldnn::memory;
 using mkldnn::primitive;
 using mkldnn::reorder;
 using mkldnn::stream;
 using paddle::platform::MKLDNNDeviceContext;
-using paddle::platform::MKLDNNMemDesc;
 using platform::to_void_cast;
 
-namespace {
 template <typename T>
-struct bn_type_traits {
-  using op_type = T;
-  using op_desc = typename op_type::desc;
-  using op_prim = typename op_type::primitive_desc;
-};
-
-class BatchNormMKLDNNHandler : public platform::MKLDNNHandler {
+class BatchNormMKLDNNHandler
+    : public platform::MKLDNNHandlerT<T, mkldnn::batch_normalization_forward,
+                                      mkldnn::batch_normalization_backward> {
  public:
-  BatchNormMKLDNNHandler(const platform::MKLDNNDeviceContext &dev_ctx,
-                         mkldnn::engine engine, const std::string &base_key)
-      : platform::MKLDNNHandler(dev_ctx, engine, base_key) {}
+  BatchNormMKLDNNHandler(const std::vector<int> &dims, const float &epsilon,
+                         const unsigned &flags, const bool &global_stats,
+                         const MKLDNNMemoryFormat fmt,
+                         const platform::MKLDNNDeviceContext &dev_ctx,
+                         platform::Place cpu_place,
+                         const std::string &uniq_name)
+      : platform::MKLDNNHandlerT<T, mkldnn::batch_normalization_forward,
+                                 mkldnn::batch_normalization_backward>(
+            dev_ctx, dev_ctx.GetEngine(), cpu_place,
+            platform::CreateKey(dims, epsilon, flags, global_stats, fmt,
+                                uniq_name)) {
+    auto md = mkldnn::memory::desc(dims, platform::MKLDNNGetDataType<T>(), fmt);
+
+    this->AcquireForwardPrimitiveDescriptor(
+        global_stats == true ? mkldnn::prop_kind::forward_scoring
+                             : mkldnn::prop_kind::forward_training,
+        md, epsilon, flags);
+  }
+  BatchNormMKLDNNHandler(const std::vector<int> &dims, const float &epsilon,
+                         const unsigned &flags,
+                         const MKLDNNMemoryFormat diff_fmt,
+                         const MKLDNNMemoryFormat src_fmt,
+                         const platform::MKLDNNDeviceContext &dev_ctx,
+                         platform::Place cpu_place,
+                         const std::string &uniq_name)
+      : platform::MKLDNNHandlerT<T, mkldnn::batch_normalization_forward,
+                                 mkldnn::batch_normalization_backward>(
+            dev_ctx, dev_ctx.GetEngine(), cpu_place,
+            platform::CreateKey(dims, epsilon, flags, false, src_fmt,
+                                uniq_name)) {
+    auto diff_dst_md =
+        mkldnn::memory::desc(dims, platform::MKLDNNGetDataType<T>(), diff_fmt);
+    auto src_md =
+        mkldnn::memory::desc(dims, platform::MKLDNNGetDataType<T>(), src_fmt);
+
+    this->AcquireBackwardPrimitiveDescriptor(
+        mkldnn::prop_kind::backward, diff_dst_md, src_md, epsilon, flags);
+  }
 
-  std::shared_ptr<memory> AcquireScaleshiftMemoryFromPrimitive(void *ptr) {
+  std::shared_ptr<memory> AcquireScaleShiftMemory(T *scaleshift_data) {
     return this->AcquireMemoryFromPrimitive(
-        batch_norm_pd_->weights_primitive_desc(), ptr, "@scaleshift_mem_p");
+        this->fwd_pd_->weights_primitive_desc(), scaleshift_data,
+        "@scaleshift_mem_p");
   }
 
-  std::shared_ptr<memory> AcquireMeanMemoryFromPrimitive(void *ptr) {
+  std::shared_ptr<memory> AcquireDiffScaleShiftMemory(
+      T *diff_scaleshift_data) {
     return this->AcquireMemoryFromPrimitive(
-        batch_norm_pd_->mean_primitive_desc(), ptr, "@mean_mem_p");
+        this->bwd_pd_->diff_weights_primitive_desc(), diff_scaleshift_data,
+        "@diff_scaleshift_mem_p");
   }
 
-  std::shared_ptr<memory> AcquireVarianceMemoryFromPrimitive(void *ptr) {
+  std::shared_ptr<memory> AcquireMeanMemory(
+      const framework::Tensor *mean) {
+    const T *mean_data = mean->data<T>();
     return this->AcquireMemoryFromPrimitive(
-        batch_norm_pd_->variance_primitive_desc(), ptr, "@variance_mem_p");
+        this->fwd_pd_->mean_primitive_desc(), to_void_cast<T>(mean_data),
+        "@mean_mem_p");
   }
 
-  template <typename T>
-  std::shared_ptr<memory> AcquireDstMemoryFromPrimitive(
-      framework::Tensor *output, platform::Place place) {
-    T *ptr = output->mutable_data<T>(
-        place, batch_norm_pd_->dst_primitive_desc().get_size());
+  std::shared_ptr<memory> AcquireMeanMemory(framework::Tensor *mean) {
+    T *mean_data = mean->mutable_data<T>(
+        this->place_, this->fwd_pd_->mean_primitive_desc().get_size());
     return this->AcquireMemoryFromPrimitive(
-        batch_norm_pd_->dst_primitive_desc(), ptr, "@dst_mem_p");
+        this->fwd_pd_->mean_primitive_desc(), mean_data, "@mean_mem_p");
   }
 
-  std::shared_ptr<batch_norm_fwd::primitive_desc>
-  AcquireBatchNormPrimitiveDescriptor(const batch_norm_fwd::desc &bn_fwd_desc,
-                                      const mkldnn::engine &engine) {
-    // BatchNorm PD has to be passed to Grad op that
-    // may be executed by diffrent thread, hence
-    // for that one we use key that does not contain TID
-    const std::string key_batch_norm_fwd_pd = key_common_ + "@bn_fwd_pd";
-    batch_norm_pd_ = std::static_pointer_cast<batch_norm_fwd::primitive_desc>(
-        dev_ctx_.GetBlob(key_batch_norm_fwd_pd));
-
-    if (batch_norm_pd_ == nullptr) {
-      static std::mutex acquire_barrier;
-      std::lock_guard<std::mutex> block_threads_until_finish_this_job(
-          acquire_barrier);
-      batch_norm_pd_ = std::static_pointer_cast<batch_norm_fwd::primitive_desc>(
-          dev_ctx_.GetBlob(key_batch_norm_fwd_pd));
-      if (batch_norm_pd_ == nullptr) {
-        batch_norm_pd_.reset(
-            new batch_norm_fwd::primitive_desc(bn_fwd_desc, engine));
-        dev_ctx_.SetBlob(key_batch_norm_fwd_pd, batch_norm_pd_);
-      }
-    }
-    return batch_norm_pd_;
+  std::shared_ptr<memory> AcquireVarianceMemory(
+      const framework::Tensor *variance) {
+    const T *variance_data = variance->data<T>();
+    return this->AcquireMemoryFromPrimitive(
+        this->fwd_pd_->variance_primitive_desc(),
"@variance_mem_p"); } - std::shared_ptr AcquireTestTrainingBatchNormFwd( - std::shared_ptr src_memory, - std::shared_ptr scaleshift_memory, - std::shared_ptr dst_memory, std::shared_ptr mean_memory, - std::shared_ptr variance_memory, bool is_test) { - auto prim_key = key_ + "@batch_norm_p"; - auto batch_norm_p = - std::static_pointer_cast(dev_ctx_.GetBlob(prim_key)); - - if (batch_norm_p == nullptr) { - if (is_test) { - batch_norm_p = std::make_shared( - *batch_norm_pd_, *src_memory, - (const mkldnn::primitive::at &)*mean_memory, - (const mkldnn::primitive::at &)*variance_memory, *scaleshift_memory, - *dst_memory); - } else { - batch_norm_p = std::make_shared( - *batch_norm_pd_, *src_memory, *scaleshift_memory, *dst_memory, - *mean_memory, *variance_memory); - } - - dev_ctx_.SetBlob(prim_key, batch_norm_p); - } - - return batch_norm_p; + std::shared_ptr AcquireVarianceMemory( + framework::Tensor *variance) { + T *variance_data = variance->mutable_data( + this->place_, this->fwd_pd_->variance_primitive_desc().get_size()); + return this->AcquireMemoryFromPrimitive( + this->fwd_pd_->variance_primitive_desc(), variance_data, + "@variance_mem_p"); } - - private: - std::shared_ptr batch_norm_pd_; }; -std::shared_ptr UpdateMemoryData( - const platform::MKLDNNDeviceContext &dev_ctx, const std::string &key, - void *new_ptr) { - auto mem = std::static_pointer_cast(dev_ctx.GetBlob(key)); - PADDLE_ENFORCE( - mem != nullptr, - (std::string("Fail to find memory in device context [key: ") + key + "]") - .c_str()); - mem->set_data_handle(new_ptr); - return mem; -} - -template -void copy_to_weights(T scale_begin, T scale_end, T shift_begin, T shift_end, - Container *c) { - auto it = std::begin(*c); - - std::copy(scale_begin, scale_end, std::inserter(*c, it)); - std::copy( - shift_begin, shift_end, - std::inserter(*c, std::next(it, std::distance(scale_begin, scale_end)))); -} - -} // namespace - template class BatchNormMKLDNNOpKernel : public paddle::framework::OpKernel { public: @@ -158,14 +125,14 @@ class BatchNormMKLDNNOpKernel : public paddle::framework::OpKernel { const bool is_test = ctx.Attr("is_test"); const bool use_global_stats = ctx.Attr("use_global_stats"); const bool fuse_with_relu = ctx.Attr("fuse_with_relu"); - bool global_stats = is_test || use_global_stats; - const auto *x = ctx.Input("X"); - const auto *mean = ctx.Input("Mean"); - const auto *variance = ctx.Input("Variance"); + bool global_stats = is_test || use_global_stats; auto &dev_ctx = ctx.template device_context(); - auto mkldnn_engine = dev_ctx.GetEngine(); + + const auto *x = ctx.Input("X"); + const auto *scale = ctx.Input("Scale"); + const auto *shift = ctx.Input("Bias"); auto *y = ctx.Output("Y"); auto *mean_out = ctx.Output("MeanOut"); @@ -173,102 +140,61 @@ class BatchNormMKLDNNOpKernel : public paddle::framework::OpKernel { auto *batch_mean = ctx.Output("SavedMean"); auto *batch_variance = ctx.Output("SavedVariance"); - const auto *scale = ctx.Input("Scale"); - const auto *shift = ctx.Input("Bias"); - PADDLE_ENFORCE_EQ(x->layout(), DataLayout::kMKLDNN, "Wrong layout set for X tensor"); PADDLE_ENFORCE_NE(x->format(), MKLDNNMemoryFormat::format_undef, "Wrong format set for X tensor"); - const T *x_data = x->data(); - const T *mean_data = mean->data(); - const T *variance_data = variance->data(); - T *mean_out_data = mean_out->mutable_data(ctx.GetPlace()); - T *variance_out_data = variance_out->mutable_data(ctx.GetPlace()); - T *batch_mean_data = nullptr; - T *batch_variance_data = nullptr; - - if (!global_stats) { - 
-      batch_mean_data = batch_mean->mutable_data<T>(ctx.GetPlace());
-      batch_variance_data = batch_variance->mutable_data<T>(ctx.GetPlace());
-    }
-
-    auto propagation = global_stats == true
-                           ? mkldnn::prop_kind::forward_scoring
-                           : mkldnn::prop_kind::forward_training;
-
     auto src_tz = paddle::framework::vectorize<int>(x->dims());
     auto scale_tz = paddle::framework::vectorize<int>(scale->dims());
     PADDLE_ENFORCE(scale_tz.size() == 1, "Dims of scale tensor is NOT 1");
-    const unsigned int ic = scale_tz[0];
+    const unsigned int C = scale_tz[0];
 
     // MKLDNN requires a single piece of memory for scale and shift/bias data
-    const size_t scaleshift_size = 2 * ic;
-    std::vector<T> scaleshift_data;
-    scaleshift_data.reserve(scaleshift_size);
-
-    copy_to_weights(scale->data<T>(), scale->data<T>() + ic, shift->data<T>(),
-                    shift->data<T>() + ic, &scaleshift_data);
-
-    unsigned flags = mkldnn::use_scale_shift;
-    if (global_stats) flags |= mkldnn::use_global_stats;
-    if (fuse_with_relu) flags |= mkldnn::fuse_bn_relu;
-
-    // create mkldnn memory from input x tensor
-    MKLDNNMemoryFormat input_format =
-        platform::MKLDNNFormatForSize(src_tz.size(), x->format());
-
-    // keys for backward pass
-    const std::string key =
-        platform::CreateKey(src_tz, epsilon, flags, global_stats, input_format,
-                            ctx.op().Output("SavedMean"));
-    BatchNormMKLDNNHandler handler(dev_ctx, mkldnn_engine, key);
-
-    auto user_src_md = platform::MKLDNNMemDesc(
-        {src_tz}, platform::MKLDNNGetDataType<T>(), input_format);
-
-    // create primitive descriptor for batch norm forward
-    using bn_fwd_types = bn_type_traits<mkldnn::batch_normalization_forward>;
-    auto batch_norm_fwd_desc =
-        bn_fwd_types::op_desc{propagation, user_src_md, epsilon, flags};
-
-    auto batch_norm_fwd_pd = handler.AcquireBatchNormPrimitiveDescriptor(
-        batch_norm_fwd_desc, mkldnn_engine);
-
-    auto src_memory =
-        handler.AcquireSrcMemory(user_src_md, to_void_cast(x_data));
-
-    // crate mkldnn memory for weights(scale/shift)
+    std::vector<T> scaleshift_data(scale->data<T>(), scale->data<T>() + C);
+    scaleshift_data.reserve(2 * C);
+    scaleshift_data.insert(scaleshift_data.end(), shift->data<T>(),
+                           shift->data<T>() + C);
+
+    // Flags are added by bitwise OR operation
+    unsigned flags = mkldnn::use_scale_shift;                      // 001
+    if (global_stats) flags |= mkldnn::use_global_stats;           // 010
+    if (fuse_with_relu && is_test) flags |= mkldnn::fuse_bn_relu;  // 100
+
+    BatchNormMKLDNNHandler<T> handler(
+        src_tz, epsilon, flags, global_stats,
+        platform::MKLDNNFormatForSize(src_tz.size(), x->format()), dev_ctx,
+        ctx.GetPlace(), ctx.op().Output("SavedMean"));
+
+    auto src_memory = handler.AcquireSrcMemory(x);
     auto scaleshift_memory =
-        handler.AcquireScaleshiftMemoryFromPrimitive(scaleshift_data.data());
-
-    // create mkldnn memory for output y tensor
-    auto dst_memory =
-        handler.AcquireDstMemoryFromPrimitive<T>(y, ctx.GetPlace());
+        handler.AcquireScaleShiftMemory(scaleshift_data.data());
+    auto dst_memory = handler.AcquireDstMemory(y);
 
-    std::shared_ptr<batch_norm_fwd> batch_norm_p;
+    std::shared_ptr<mkldnn::batch_normalization_forward> batch_norm_p;
 
     if (global_stats) {
-      // create mkldnn memory for stats (as input)
-      std::shared_ptr<memory> mean_memory =
-          handler.AcquireMeanMemoryFromPrimitive(to_void_cast(mean_data));
+      // mean and variance are taken from input Tensor
+      const auto *mean = ctx.Input<Tensor>("Mean");
+      const auto *variance = ctx.Input<Tensor>("Variance");
+
+      std::shared_ptr<memory> mean_memory = handler.AcquireMeanMemory(mean);
       std::shared_ptr<memory> variance_memory =
-          handler.AcquireVarianceMemoryFromPrimitive(
-              to_void_cast(variance_data));
+          handler.AcquireVarianceMemory(variance);
 
-      batch_norm_p = handler.AcquireTestTrainingBatchNormFwd(
-          src_memory, scaleshift_memory, dst_memory, mean_memory,
-          variance_memory, true);
+      batch_norm_p = handler.AcquireForwardPrimitive(
+          *src_memory, (const mkldnn::primitive::at &)*mean_memory,
+          (const mkldnn::primitive::at &)*variance_memory, *scaleshift_memory,
+          *dst_memory);
     } else {
-      // create mkldnn memory for stats (as output)
+      // mean and variance are calculated and saved in output Tensor
       std::shared_ptr<memory> mean_memory =
-          handler.AcquireMeanMemoryFromPrimitive(batch_mean_data);
+          handler.AcquireMeanMemory(batch_mean);
       std::shared_ptr<memory> variance_memory =
-          handler.AcquireVarianceMemoryFromPrimitive(batch_variance_data);
+          handler.AcquireVarianceMemory(batch_variance);
 
-      batch_norm_p = handler.AcquireTestTrainingBatchNormFwd(
-          src_memory, scaleshift_memory, dst_memory, mean_memory,
-          variance_memory, false);
+      batch_norm_p = handler.AcquireForwardPrimitive(
+          *src_memory, *scaleshift_memory, *dst_memory, *mean_memory,
+          *variance_memory);
     }
 
     y->set_layout(DataLayout::kMKLDNN);
@@ -281,18 +207,20 @@ class BatchNormMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
     if (!global_stats) {
       // mkldnn only compute stats for current batch
       // so we need compute momentum stats via Eigen lib
-      EigenVectorArrayMap<T> batch_mean_e(batch_mean_data, ic);
-      EigenVectorArrayMap<T> batch_variance_e(batch_variance_data, ic);
-      ConstEigenVectorArrayMap<T> mean_e(mean_data, ic);
-      ConstEigenVectorArrayMap<T> variance_e{variance_data, ic};
-
-      EigenVectorArrayMap<T> running_mean_e(mean_out_data, ic);
-      EigenVectorArrayMap<T> running_variance_e(variance_out_data, ic);
-
-      auto one_minus_momentum = 1. - momentum;
-      running_mean_e = mean_e * momentum + batch_mean_e * one_minus_momentum;
+      EigenVectorArrayMap<T> batch_mean_e(
+          batch_mean->mutable_data<T>(ctx.GetPlace()), C);
+      EigenVectorArrayMap<T> batch_variance_e(
+          batch_variance->mutable_data<T>(ctx.GetPlace()), C);
+
+      EigenVectorArrayMap<T> running_mean_e(
+          mean_out->mutable_data<T>(ctx.GetPlace()), C);
+      EigenVectorArrayMap<T> running_variance_e(
+          variance_out->mutable_data<T>(ctx.GetPlace()), C);
+
+      running_mean_e =
+          running_mean_e * momentum + batch_mean_e * (1. - momentum);
       running_variance_e =
-          variance_e * momentum + batch_variance_e * one_minus_momentum;
+          running_variance_e * momentum + batch_variance_e * (1. - momentum);
     }
   }
 };
@@ -311,7 +239,6 @@ class BatchNormMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
     const auto *shift = ctx.Input<Tensor>("Bias");
     const auto *batch_mean = ctx.Input<Tensor>("SavedMean");
    const auto *batch_variance = ctx.Input<Tensor>("SavedVariance");
-
     const auto *diff_y = ctx.Input<Tensor>(framework::GradVarName("Y"));
     auto *diff_x = ctx.Output<Tensor>(framework::GradVarName("X"));
     auto *diff_scale = ctx.Output<Tensor>(framework::GradVarName("Scale"));
@@ -322,27 +249,11 @@ class BatchNormMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
     PADDLE_ENFORCE_NE(diff_y->format(), MKLDNNMemoryFormat::format_undef,
                       "Wrong format set for Input diff_y tensor");
 
-    const T *x_data = x->data<T>();
-    const T *diff_y_data = diff_y->data<T>();
-    const T *batch_mean_data = batch_mean->data<T>();
-    const T *batch_variance_data = batch_variance->data<T>();
-    const T *scale_data = scale->data<T>();
-    const T *shift_data = shift->data<T>();
-    T *diff_x_data = diff_x->mutable_data<T>(ctx.GetPlace());
-
-    T *diff_scale_data = diff_scale->mutable_data<T>(ctx.GetPlace());
-    T *diff_shift_data = diff_shift->mutable_data<T>(ctx.GetPlace());
-
     auto src_tz = paddle::framework::vectorize<int>(x->dims());
-    auto diff_src_tz = src_tz;
-    auto dst_tz = src_tz;
-    auto diff_dst_tz = dst_tz;
     auto scale_tz = paddle::framework::vectorize<int>(scale->dims());
     PADDLE_ENFORCE(scale_tz.size() == 1, "Dims of scale tensor is NOT 1");
 
-    const unsigned int ic = scale_tz[0];
-
-    using bn_bwd_types = bn_type_traits<mkldnn::batch_normalization_backward>;
+    const unsigned int C = scale_tz[0];
 
     MKLDNNMemoryFormat dst_format =
         platform::MKLDNNFormatForSize(src_tz.size(), diff_y->format());
@@ -350,170 +261,52 @@ class BatchNormMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
     MKLDNNMemoryFormat input_format =
         platform::MKLDNNFormatForSize(src_tz.size(), x->format());
 
-    unsigned flags = mkldnn::use_scale_shift;
-
-    // keys from forward pass
-    const std::string key =
-        platform::CreateKey(src_tz, epsilon, flags, false, input_format,
-                            ctx.op().Input("SavedMean"));
-    const std::string key_batch_norm_fwd_pd = key + "@bn_fwd_pd";
-
-    // keys for primitives reuse
-    const std::string key_with_hash =
-        key + platform::CreateKey(src_tz, epsilon, flags, false, input_format);
-    const std::string key_batch_norm_bwd_p =
-        key_with_hash + "@batch_norm_bwd_p";
-    const std::string key_batch_norm_src_mem_p =
-        key_with_hash + "@batch_norm_bwd_src_mem_p";
-    const std::string key_batch_norm_mean_mem_p =
-        key_with_hash + "@batch_norm_bwd_mean_mem_p";
-    const std::string key_batch_norm_variance_mem_p =
-        key_with_hash + "@batch_norm_bwd_variance_mem_p";
-    const std::string key_batch_norm_scaleshift_mem_p =
-        key_with_hash + "@batch_norm_bwd_scaleshift_mem_p";
-    const std::string key_batch_norm_diff_scaleshift_mem_p =
-        key_with_hash + "@batch_norm_bwd_diff_scaleshift_mem_p";
-    const std::string key_batch_norm_diff_src_mem_p =
-        key_with_hash + "@batch_norm_bwd_diff_src_mem_p";
-    const std::string key_batch_norm_diff_dst_mem_p =
-        key_with_hash + "@batch_norm_bwd_diff_dst_mem_p";
-
-    primitive reorder_diff_dst;
-    bool is_diff_dst_reordered = false;
-    auto user_diff_dst_memory = memory(
-        {{{diff_dst_tz}, memory::data_type::f32, dst_format}, mkldnn_engine},
-        to_void_cast(diff_y_data));
+    BatchNormMKLDNNHandler<T> handler(
+        src_tz, epsilon, mkldnn::use_scale_shift, dst_format, input_format,
+        dev_ctx, ctx.GetPlace(), ctx.op().Input("SavedMean"));
 
     // MKLDNN requires a single piece of memory for scale and shift/bias data
-    const size_t scaleshift_size = 2 * ic;
-
-    std::vector<T> scaleshift_data;
+    const size_t scaleshift_size = 2 * C;
+    std::vector<T> scaleshift_data(scale->data<T>(), scale->data<T>() + C);
     scaleshift_data.reserve(scaleshift_size);
-
-    copy_to_weights(scale_data, scale_data + ic, shift_data, shift_data + ic,
-                    &scaleshift_data);
+    scaleshift_data.insert(scaleshift_data.end(), shift->data<T>(),
+                           shift->data<T>() + C);
 
     std::vector<T> diff_scaleshift_data;
     diff_scaleshift_data.reserve(scaleshift_size);
 
-    auto batch_norm_fwd_pd =
-        std::static_pointer_cast<batch_norm_fwd::primitive_desc>(
-            dev_ctx.GetBlob(key_batch_norm_fwd_pd));
-    PADDLE_ENFORCE(batch_norm_fwd_pd != nullptr,
-                   "Fail to find batch_norm_fwd_pd in device context");
-
-    auto batch_norm_bwd_p = std::static_pointer_cast<batch_norm_bwd>(
-        dev_ctx.GetBlob(key_batch_norm_bwd_p));
-
-    if (batch_norm_bwd_p == nullptr) {
-      auto src_memory = std::shared_ptr<memory>(new memory(
-          {{{src_tz}, memory::data_type::f32, input_format}, mkldnn_engine},
-          to_void_cast(x_data)));
-
-      // for diff_dst, try to use same format as dst in forward pass
-      auto diff_dst_pd = batch_norm_fwd_pd.get()->dst_primitive_desc();
-      auto diff_dst_md = diff_dst_pd.desc();
-
-      // create primitive descriptor for batch norm backward
-      auto batch_norm_bwd_desc = bn_bwd_types::op_desc{
-          mkldnn::prop_kind::backward, diff_dst_md,
-          src_memory->get_primitive_desc().desc(), epsilon, flags};
-      auto batch_norm_bwd_pd = bn_bwd_types::op_prim{
-          batch_norm_bwd_desc, mkldnn_engine, *batch_norm_fwd_pd};
-
-      // reorder user_diff_dst if it's not in preferred format
-      auto diff_dst_memory = std::make_shared<memory>(user_diff_dst_memory);
-      if (diff_dst_pd != user_diff_dst_memory.get_primitive_desc()) {
-        diff_dst_memory = std::make_shared<memory>(diff_dst_pd);
-        reorder_diff_dst = reorder(user_diff_dst_memory, *diff_dst_memory);
-        is_diff_dst_reordered = true;
-      }
-
-      // create mkldnn memory for input tensors (src/mean/variance)
-      auto mean_memory =
-          std::make_shared<memory>(batch_norm_bwd_pd.mean_primitive_desc(),
-                                   to_void_cast(batch_mean_data));
-      auto variance_memory =
-          std::make_shared<memory>(batch_norm_bwd_pd.variance_primitive_desc(),
-                                   to_void_cast(batch_variance_data));
-
-      // create mkldnn memory for input tensors (scale/shift)
-      auto scaleshift_memory = std::make_shared<memory>(
-          batch_norm_bwd_pd.weights_primitive_desc(), scaleshift_data.data());
-
-      // create mkldnn memory for output diff weights (combined scale/shift)
-      auto diff_scaleshift_memory = std::make_shared<memory>(
-          batch_norm_bwd_pd.diff_weights_primitive_desc(),
-          diff_scaleshift_data.data());
-
-      // here assume diff_src is in the same format of src
-      auto diff_src_memory = std::make_shared<memory>(
-          src_memory->get_primitive_desc(), diff_x_data);
-
-      // finally create batch_norm backward primitive
-      batch_norm_bwd_p = std::make_shared<batch_norm_bwd>(
-          batch_norm_bwd_pd, *src_memory, *mean_memory, *variance_memory,
-          *diff_dst_memory, *scaleshift_memory, *diff_src_memory,
-          *diff_scaleshift_memory);
-
-      dev_ctx.SetBlob(key_batch_norm_bwd_p, batch_norm_bwd_p);
-      dev_ctx.SetBlob(key_batch_norm_src_mem_p, src_memory);
-      dev_ctx.SetBlob(key_batch_norm_mean_mem_p, mean_memory);
-      dev_ctx.SetBlob(key_batch_norm_variance_mem_p, variance_memory);
-      dev_ctx.SetBlob(key_batch_norm_scaleshift_mem_p, scaleshift_memory);
-      dev_ctx.SetBlob(key_batch_norm_diff_scaleshift_mem_p,
-                      diff_scaleshift_memory);
-      dev_ctx.SetBlob(key_batch_norm_diff_src_mem_p, diff_src_memory);
-      dev_ctx.SetBlob(key_batch_norm_diff_dst_mem_p, diff_dst_memory);
-
-      // set layout/format of output tensors
-      diff_x->set_layout(DataLayout::kMKLDNN);
-      diff_x->set_format(
-          (MKLDNNMemoryFormat)diff_src_memory->get_primitive_desc()
-              .desc()
-              .data.format);
-    } else {
-      // primitives already exist
-      UpdateMemoryData(dev_ctx, key_batch_norm_src_mem_p, to_void_cast(x_data));
-      UpdateMemoryData(dev_ctx, key_batch_norm_mean_mem_p,
-                       to_void_cast(batch_mean_data));
-      UpdateMemoryData(dev_ctx, key_batch_norm_variance_mem_p,
-                       to_void_cast(batch_variance_data));
-      UpdateMemoryData(dev_ctx, key_batch_norm_scaleshift_mem_p,
-                       scaleshift_data.data());
-      UpdateMemoryData(dev_ctx, key_batch_norm_diff_scaleshift_mem_p,
-                       diff_scaleshift_data.data());
-      auto diff_src_memory = UpdateMemoryData(
-          dev_ctx, key_batch_norm_diff_src_mem_p, to_void_cast(diff_x_data));
-      auto diff_dst_memory = UpdateMemoryData(
-          dev_ctx, key_batch_norm_diff_dst_mem_p, to_void_cast(diff_y_data));
-
-      // reorder user_diff_dst if it's not in preferred format
-      if (diff_dst_memory->get_primitive_desc() !=
-          user_diff_dst_memory.get_primitive_desc()) {
-        reorder_diff_dst = reorder(user_diff_dst_memory, *diff_dst_memory);
-        is_diff_dst_reordered = true;
-      }
-
-      // set layout/format of output tensors
-      diff_x->set_layout(DataLayout::kMKLDNN);
-      diff_x->set_format(
-          (MKLDNNMemoryFormat)diff_src_memory->get_primitive_desc()
-              .desc()
-              .data.format);
-    }
+    auto src_memory = handler.AcquireSrcMemory(x);
+    auto mean_memory = handler.AcquireMeanMemory(batch_mean);
+    auto variance_memory = handler.AcquireVarianceMemory(batch_variance);
+    auto diff_dst_memory = handler.AcquireDiffDstMemory(diff_y);
+    auto scaleshift_memory =
+        handler.AcquireScaleShiftMemory(scaleshift_data.data());
+    auto diff_src_memory = handler.AcquireDiffSrcMemory(diff_x);
+    auto diff_scaleshift_memory =
+        handler.AcquireDiffScaleShiftMemory(diff_scaleshift_data.data());
+
+    // finally create batch_norm backward primitive
+    auto batch_norm_bwd_p = handler.AcquireBackwardPrimitive(
+        *src_memory, *mean_memory, *variance_memory, *diff_dst_memory,
+        *scaleshift_memory, *diff_src_memory, *diff_scaleshift_memory);
 
-    // execute optional reorder and batch_norm backward primitive
     std::vector<primitive> pipeline;
-    if (is_diff_dst_reordered) pipeline.push_back(reorder_diff_dst);
     pipeline.push_back(*batch_norm_bwd_p);
     stream(stream::kind::eager).submit(pipeline).wait();
 
+    T *diff_scale_data = diff_scale->mutable_data<T>(ctx.GetPlace());
+    T *diff_shift_data = diff_shift->mutable_data<T>(ctx.GetPlace());
+
     // copy back diff sacle/shift to output tensors (diff scale/shift)
     diff_scaleshift_data.resize(scaleshift_size);
     auto it = std::begin(diff_scaleshift_data);
-    std::copy(it, std::next(it, ic), diff_scale_data);
-    std::copy(std::next(it, ic), std::end(diff_scaleshift_data),
+    std::copy(it, std::next(it, C), diff_scale_data);
+    std::copy(std::next(it, C), std::end(diff_scaleshift_data),
              diff_shift_data);
+
+    // set layout/format of output tensors
+    diff_x->set_layout(DataLayout::kMKLDNN);
+    diff_x->set_format(platform::GetMKLDNNFormat(*diff_src_memory));
   }
 };
 }  // namespace operators
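
Reviewer note on the momentum arithmetic in the forward kernel: MKL-DNN only computes statistics for the current batch, so the patch keeps an exponential moving average of the running statistics via Eigen array maps, and the rewritten update now reads the running value from the output tensor's buffer on both sides of the assignment (the old code read it from the separate mean_data/variance_data input pointers). Below is a minimal standalone sketch of that per-channel update in plain C++ rather than Eigen maps, so it compiles on its own; the function and variable names are illustrative only and are not part of the patch.

#include <cstddef>
#include <iostream>
#include <vector>

// Per channel c: running[c] = running[c] * momentum + batch[c] * (1 - momentum),
// mirroring: running_mean_e = running_mean_e * momentum +
//            batch_mean_e * (1. - momentum);
void UpdateRunningStat(std::vector<float> *running,
                       const std::vector<float> &batch, float momentum) {
  for (std::size_t c = 0; c < running->size(); ++c) {
    (*running)[c] = (*running)[c] * momentum + batch[c] * (1.0f - momentum);
  }
}

int main() {
  std::vector<float> running_mean = {0.0f, 1.0f};  // running stats so far (MeanOut)
  std::vector<float> batch_mean = {2.0f, 3.0f};    // current batch stats (SavedMean)
  UpdateRunningStat(&running_mean, batch_mean, 0.9f);
  // Prints "0.2 1.2": 0*0.9 + 2*0.1 and 1*0.9 + 3*0.1.
  std::cout << running_mean[0] << " " << running_mean[1] << std::endl;
  return 0;
}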