diff --git a/paddle/fluid/operators/batch_norm_mkldnn_op.cc b/paddle/fluid/operators/batch_norm_mkldnn_op.cc index cd1fb754a10e62c71a5559144882e4380175c18f..de641cb08e4cc3322cc8387d873f2aaab279e1dd 100644 --- a/paddle/fluid/operators/batch_norm_mkldnn_op.cc +++ b/paddle/fluid/operators/batch_norm_mkldnn_op.cc @@ -62,56 +62,42 @@ class BatchNormMKLDNNHandler : public platform::MKLDNNHandler { batch_norm_pd_->variance_primitive_desc(), ptr, "@variance_mem_p"); } - std::shared_ptr AcquireTestBatchNormFwd( + std::shared_ptr AcquireTestTrainingBatchNormFwd( std::shared_ptr src_memory, - const mkldnn::primitive::at &mean_memory, - const mkldnn::primitive::at &variance_memory, std::shared_ptr scaleshift_memory, - std::shared_ptr dst_memory) { + std::shared_ptr dst_memory, std::shared_ptr mean_memory, + std::shared_ptr variance_memory, bool is_test) { auto prim_key = key_ + "@batch_norm_p"; auto batch_norm_p = std::static_pointer_cast(dev_ctx_.GetBlob(prim_key)); - PADDLE_ENFORCE( - (batch_norm_p != nullptr) || (is_reusing_ == false), - "Fail to find batch norm primitive for test in device context"); - if (batch_norm_p == nullptr) { - batch_norm_p = std::make_shared( - *batch_norm_pd_, *src_memory, mean_memory, variance_memory, - *scaleshift_memory, *dst_memory); - dev_ctx_.SetBlob(prim_key, batch_norm_p); - } else { - is_reusing_ = true; - } - return batch_norm_p; - } + PADDLE_ENFORCE((batch_norm_p != nullptr) || !is_reusing_, + "Fail to find batch norm primitive in device context"); - std::shared_ptr AcquireTrainingBatchNormFwd( - std::shared_ptr src_memory, - std::shared_ptr scaleshift_memory, - std::shared_ptr dst_memory, std::shared_ptr mean_memory, - std::shared_ptr variance_memory) { - auto prim_key = key_ + "@batch_norm_p"; - auto batch_norm_p = - std::static_pointer_cast(dev_ctx_.GetBlob(prim_key)); - PADDLE_ENFORCE( - (batch_norm_p != nullptr) || (is_reusing_ == false), - "Fail to find batch norm primitive for training in device context"); if (batch_norm_p == nullptr) { - batch_norm_p = std::make_shared( - *batch_norm_pd_, *src_memory, *scaleshift_memory, *dst_memory, - *mean_memory, *variance_memory); + if (is_test) { + batch_norm_p = std::make_shared( + *batch_norm_pd_, *src_memory, + (const mkldnn::primitive::at &)*mean_memory, + (const mkldnn::primitive::at &)*variance_memory, *scaleshift_memory, + *dst_memory); + } else { + batch_norm_p = std::make_shared( + *batch_norm_pd_, *src_memory, *scaleshift_memory, *dst_memory, + *mean_memory, *variance_memory); + } dev_ctx_.SetBlob(prim_key, batch_norm_p); } else { is_reusing_ = true; } + return batch_norm_p; } - // + static std::string GetHash(const memory::dims &input_dims, float epsilon, unsigned flag, bool is_test, memory::format format, - const std::string &suffix) { + const std::string &suffix = "") { auto dims2str = [](const memory::dims &operand_dims) { std::string dstr = ""; for (size_t i = 0; i < operand_dims.size(); ++i) { @@ -128,19 +114,6 @@ class BatchNormMKLDNNHandler : public platform::MKLDNNHandler { std::shared_ptr batch_norm_pd_; }; -std::string gethash(const memory::dims &input_dims, float epsilon, - unsigned flag, bool is_test, memory::format format) { - auto dims2str = [](const memory::dims &operand_dims) { - std::string dstr = ""; - for (size_t i = 0; i < operand_dims.size(); ++i) { - dstr += std::to_string(operand_dims[i]) + "-"; - } - return dstr; - }; - return dims2str(input_dims) + std::to_string(epsilon) + std::to_string(flag) + - std::to_string(is_test) + std::to_string(format); -} - std::shared_ptr UpdateMemoryData( const platform::MKLDNNDeviceContext &dev_ctx, const std::string &key, void *new_ptr) { @@ -274,10 +247,9 @@ class BatchNormMKLDNNOpKernel : public paddle::framework::OpKernel { handler.AcquireVarianceMemoryFromPrimitive( to_void_cast(variance_data)); - batch_norm_p = handler.AcquireTestBatchNormFwd( - src_memory, (const mkldnn::primitive::at &)*mean_memory, - (const mkldnn::primitive::at &)*variance_memory, scaleshift_memory, - dst_memory); + batch_norm_p = handler.AcquireTestTrainingBatchNormFwd( + src_memory, scaleshift_memory, dst_memory, mean_memory, + variance_memory, true); } else { // create mkldnn memory for stats (as output) std::shared_ptr mean_memory = @@ -285,9 +257,9 @@ class BatchNormMKLDNNOpKernel : public paddle::framework::OpKernel { std::shared_ptr variance_memory = handler.AcquireVarianceMemoryFromPrimitive(batch_variance_data); - batch_norm_p = handler.AcquireTrainingBatchNormFwd( + batch_norm_p = handler.AcquireTestTrainingBatchNormFwd( src_memory, scaleshift_memory, dst_memory, mean_memory, - variance_memory); + variance_memory, false); } y->set_layout(DataLayout::kMKLDNN); @@ -377,7 +349,8 @@ class BatchNormMKLDNNGradOpKernel : public paddle::framework::OpKernel { // keys for primitives reuse const std::string key_with_hash = - key + gethash(src_tz, epsilon, flags, false, input_format); + key + BatchNormMKLDNNHandler::GetHash(src_tz, epsilon, flags, false, + input_format); const std::string key_batch_norm_bwd_p = key_with_hash + "@batch_norm_bwd_p"; const std::string key_batch_norm_src_mem_p =