提交 4b65af77 编写于 作者: A Adam 提交者: Tao Luo

MKLDNN BatchNorm operator refactor (#20012)

test=develop
上级 bda7eab7
...@@ -19,136 +19,103 @@ limitations under the License. */ ...@@ -19,136 +19,103 @@ limitations under the License. */
namespace paddle { namespace paddle {
namespace operators { namespace operators {
using batch_norm_bwd = mkldnn::batch_normalization_backward;
using batch_norm_fwd = mkldnn::batch_normalization_forward;
using mkldnn::memory; using mkldnn::memory;
using mkldnn::primitive; using mkldnn::primitive;
using mkldnn::reorder; using mkldnn::reorder;
using mkldnn::stream; using mkldnn::stream;
using paddle::platform::MKLDNNDeviceContext; using paddle::platform::MKLDNNDeviceContext;
using paddle::platform::MKLDNNMemDesc;
using platform::to_void_cast; using platform::to_void_cast;
namespace {
template <typename T> template <typename T>
struct bn_type_traits { class BatchNormMKLDNNHandler
using op_type = T; : public platform::MKLDNNHandlerT<T, mkldnn::batch_normalization_forward,
using op_desc = typename op_type::desc; mkldnn::batch_normalization_backward> {
using op_prim = typename op_type::primitive_desc;
};
class BatchNormMKLDNNHandler : public platform::MKLDNNHandler {
public: public:
BatchNormMKLDNNHandler(const platform::MKLDNNDeviceContext &dev_ctx, BatchNormMKLDNNHandler(const std::vector<int> &dims, const float &epsilon,
mkldnn::engine engine, const std::string &base_key) const unsigned &flags, const bool &global_stats,
: platform::MKLDNNHandler(dev_ctx, engine, base_key) {} const MKLDNNMemoryFormat fmt,
const platform::MKLDNNDeviceContext &dev_ctx,
platform::Place cpu_place,
const std::string &uniq_name)
: platform::MKLDNNHandlerT<T, mkldnn::batch_normalization_forward,
mkldnn::batch_normalization_backward>(
dev_ctx, dev_ctx.GetEngine(), cpu_place,
platform::CreateKey(dims, epsilon, flags, global_stats, fmt,
uniq_name)) {
auto md = mkldnn::memory::desc(dims, platform::MKLDNNGetDataType<T>(), fmt);
this->AcquireForwardPrimitiveDescriptor(
global_stats == true ? mkldnn::prop_kind::forward_scoring
: mkldnn::prop_kind::forward_training,
md, epsilon, flags);
}
BatchNormMKLDNNHandler(const std::vector<int> &dims, const float &epsilon,
const unsigned &flags,
const MKLDNNMemoryFormat diff_fmt,
const MKLDNNMemoryFormat src_fmt,
const platform::MKLDNNDeviceContext &dev_ctx,
platform::Place cpu_place,
const std::string &uniq_name)
: platform::MKLDNNHandlerT<T, mkldnn::batch_normalization_forward,
mkldnn::batch_normalization_backward>(
dev_ctx, dev_ctx.GetEngine(), cpu_place,
platform::CreateKey(dims, epsilon, flags, false, src_fmt,
uniq_name)) {
auto diff_dst_md =
mkldnn::memory::desc(dims, platform::MKLDNNGetDataType<T>(), diff_fmt);
auto src_md =
mkldnn::memory::desc(dims, platform::MKLDNNGetDataType<T>(), src_fmt);
this->AcquireBackwardPrimitiveDescriptor(
mkldnn::prop_kind::backward, diff_dst_md, src_md, epsilon, flags);
}
std::shared_ptr<memory> AcquireScaleshiftMemoryFromPrimitive(void *ptr) { std::shared_ptr<mkldnn::memory> AcquireScaleShiftMemory(T *scaleshift_data) {
return this->AcquireMemoryFromPrimitive( return this->AcquireMemoryFromPrimitive(
batch_norm_pd_->weights_primitive_desc(), ptr, "@scaleshift_mem_p"); this->fwd_pd_->weights_primitive_desc(), scaleshift_data,
"@scaleshift_mem_p");
} }
std::shared_ptr<memory> AcquireMeanMemoryFromPrimitive(void *ptr) { std::shared_ptr<mkldnn::memory> AcquireDiffScaleShiftMemory(
T *diff_scaleshift_data) {
return this->AcquireMemoryFromPrimitive( return this->AcquireMemoryFromPrimitive(
batch_norm_pd_->mean_primitive_desc(), ptr, "@mean_mem_p"); this->bwd_pd_->diff_weights_primitive_desc(), diff_scaleshift_data,
"@diff_scaleshift_mem_p");
} }
std::shared_ptr<memory> AcquireVarianceMemoryFromPrimitive(void *ptr) { std::shared_ptr<mkldnn::memory> AcquireMeanMemory(
const framework::Tensor *mean) {
const T *mean_data = mean->data<T>();
return this->AcquireMemoryFromPrimitive( return this->AcquireMemoryFromPrimitive(
batch_norm_pd_->variance_primitive_desc(), ptr, "@variance_mem_p"); this->fwd_pd_->mean_primitive_desc(), to_void_cast<T>(mean_data),
"@mean_mem_p");
} }
template <typename T> std::shared_ptr<mkldnn::memory> AcquireMeanMemory(framework::Tensor *mean) {
std::shared_ptr<mkldnn::memory> AcquireDstMemoryFromPrimitive( T *mean_data = mean->mutable_data<T>(
framework::Tensor *output, platform::Place place) { this->place_, this->fwd_pd_->mean_primitive_desc().get_size());
T *ptr = output->mutable_data<T>(
place, batch_norm_pd_->dst_primitive_desc().get_size());
return this->AcquireMemoryFromPrimitive( return this->AcquireMemoryFromPrimitive(
batch_norm_pd_->dst_primitive_desc(), ptr, "@dst_mem_p"); this->fwd_pd_->mean_primitive_desc(), mean_data, "@mean_mem_p");
} }
std::shared_ptr<batch_norm_fwd::primitive_desc> std::shared_ptr<mkldnn::memory> AcquireVarianceMemory(
AcquireBatchNormPrimitiveDescriptor(const batch_norm_fwd::desc &bn_fwd_desc, const framework::Tensor *variance) {
const mkldnn::engine &engine) { const T *variance_data = variance->data<T>();
// BatchNorm PD has to be passed to Grad op that return this->AcquireMemoryFromPrimitive(
// may be executed by diffrent thread, hence this->fwd_pd_->variance_primitive_desc(),
// for that one we use key that does not contain TID to_void_cast<T>(variance_data), "@variance_mem_p");
const std::string key_batch_norm_fwd_pd = key_common_ + "@bn_fwd_pd";
batch_norm_pd_ = std::static_pointer_cast<batch_norm_fwd::primitive_desc>(
dev_ctx_.GetBlob(key_batch_norm_fwd_pd));
if (batch_norm_pd_ == nullptr) {
static std::mutex acquire_barrier;
std::lock_guard<std::mutex> block_threads_until_finish_this_job(
acquire_barrier);
batch_norm_pd_ = std::static_pointer_cast<batch_norm_fwd::primitive_desc>(
dev_ctx_.GetBlob(key_batch_norm_fwd_pd));
if (batch_norm_pd_ == nullptr) {
batch_norm_pd_.reset(
new batch_norm_fwd::primitive_desc(bn_fwd_desc, engine));
dev_ctx_.SetBlob(key_batch_norm_fwd_pd, batch_norm_pd_);
}
}
return batch_norm_pd_;
} }
std::shared_ptr<batch_norm_fwd> AcquireTestTrainingBatchNormFwd( std::shared_ptr<mkldnn::memory> AcquireVarianceMemory(
std::shared_ptr<memory> src_memory, framework::Tensor *variance) {
std::shared_ptr<memory> scaleshift_memory, T *variance_data = variance->mutable_data<T>(
std::shared_ptr<memory> dst_memory, std::shared_ptr<memory> mean_memory, this->place_, this->fwd_pd_->variance_primitive_desc().get_size());
std::shared_ptr<memory> variance_memory, bool is_test) { return this->AcquireMemoryFromPrimitive(
auto prim_key = key_ + "@batch_norm_p"; this->fwd_pd_->variance_primitive_desc(), variance_data,
auto batch_norm_p = "@variance_mem_p");
std::static_pointer_cast<batch_norm_fwd>(dev_ctx_.GetBlob(prim_key));
if (batch_norm_p == nullptr) {
if (is_test) {
batch_norm_p = std::make_shared<batch_norm_fwd>(
*batch_norm_pd_, *src_memory,
(const mkldnn::primitive::at &)*mean_memory,
(const mkldnn::primitive::at &)*variance_memory, *scaleshift_memory,
*dst_memory);
} else {
batch_norm_p = std::make_shared<batch_norm_fwd>(
*batch_norm_pd_, *src_memory, *scaleshift_memory, *dst_memory,
*mean_memory, *variance_memory);
}
dev_ctx_.SetBlob(prim_key, batch_norm_p);
}
return batch_norm_p;
} }
private:
std::shared_ptr<batch_norm_fwd::primitive_desc> batch_norm_pd_;
}; };
std::shared_ptr<memory> UpdateMemoryData(
const platform::MKLDNNDeviceContext &dev_ctx, const std::string &key,
void *new_ptr) {
auto mem = std::static_pointer_cast<memory>(dev_ctx.GetBlob(key));
PADDLE_ENFORCE(
mem != nullptr,
(std::string("Fail to find memory in device context [key: ") + key + "]")
.c_str());
mem->set_data_handle(new_ptr);
return mem;
}
template <typename T, typename Container>
void copy_to_weights(T scale_begin, T scale_end, T shift_begin, T shift_end,
Container *c) {
auto it = std::begin(*c);
std::copy(scale_begin, scale_end, std::inserter(*c, it));
std::copy(
shift_begin, shift_end,
std::inserter(*c, std::next(it, std::distance(scale_begin, scale_end))));
}
} // namespace
template <typename T> template <typename T>
class BatchNormMKLDNNOpKernel : public paddle::framework::OpKernel<T> { class BatchNormMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
public: public:
...@@ -158,14 +125,14 @@ class BatchNormMKLDNNOpKernel : public paddle::framework::OpKernel<T> { ...@@ -158,14 +125,14 @@ class BatchNormMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
const bool is_test = ctx.Attr<bool>("is_test"); const bool is_test = ctx.Attr<bool>("is_test");
const bool use_global_stats = ctx.Attr<bool>("use_global_stats"); const bool use_global_stats = ctx.Attr<bool>("use_global_stats");
const bool fuse_with_relu = ctx.Attr<bool>("fuse_with_relu"); const bool fuse_with_relu = ctx.Attr<bool>("fuse_with_relu");
bool global_stats = is_test || use_global_stats;
const auto *x = ctx.Input<Tensor>("X"); bool global_stats = is_test || use_global_stats;
const auto *mean = ctx.Input<Tensor>("Mean");
const auto *variance = ctx.Input<Tensor>("Variance");
auto &dev_ctx = ctx.template device_context<MKLDNNDeviceContext>(); auto &dev_ctx = ctx.template device_context<MKLDNNDeviceContext>();
auto mkldnn_engine = dev_ctx.GetEngine();
const auto *x = ctx.Input<Tensor>("X");
const auto *scale = ctx.Input<Tensor>("Scale");
const auto *shift = ctx.Input<Tensor>("Bias");
auto *y = ctx.Output<Tensor>("Y"); auto *y = ctx.Output<Tensor>("Y");
auto *mean_out = ctx.Output<Tensor>("MeanOut"); auto *mean_out = ctx.Output<Tensor>("MeanOut");
...@@ -173,102 +140,61 @@ class BatchNormMKLDNNOpKernel : public paddle::framework::OpKernel<T> { ...@@ -173,102 +140,61 @@ class BatchNormMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
auto *batch_mean = ctx.Output<Tensor>("SavedMean"); auto *batch_mean = ctx.Output<Tensor>("SavedMean");
auto *batch_variance = ctx.Output<Tensor>("SavedVariance"); auto *batch_variance = ctx.Output<Tensor>("SavedVariance");
const auto *scale = ctx.Input<Tensor>("Scale");
const auto *shift = ctx.Input<Tensor>("Bias");
PADDLE_ENFORCE_EQ(x->layout(), DataLayout::kMKLDNN, PADDLE_ENFORCE_EQ(x->layout(), DataLayout::kMKLDNN,
"Wrong layout set for X tensor"); "Wrong layout set for X tensor");
PADDLE_ENFORCE_NE(x->format(), MKLDNNMemoryFormat::format_undef, PADDLE_ENFORCE_NE(x->format(), MKLDNNMemoryFormat::format_undef,
"Wrong format set for X tensor"); "Wrong format set for X tensor");
const T *x_data = x->data<T>();
const T *mean_data = mean->data<T>();
const T *variance_data = variance->data<T>();
T *mean_out_data = mean_out->mutable_data<T>(ctx.GetPlace());
T *variance_out_data = variance_out->mutable_data<T>(ctx.GetPlace());
T *batch_mean_data = nullptr;
T *batch_variance_data = nullptr;
if (!global_stats) {
batch_mean_data = batch_mean->mutable_data<T>(ctx.GetPlace());
batch_variance_data = batch_variance->mutable_data<T>(ctx.GetPlace());
}
auto propagation = global_stats == true
? mkldnn::prop_kind::forward_scoring
: mkldnn::prop_kind::forward_training;
auto src_tz = paddle::framework::vectorize<int>(x->dims()); auto src_tz = paddle::framework::vectorize<int>(x->dims());
auto scale_tz = paddle::framework::vectorize<int>(scale->dims()); auto scale_tz = paddle::framework::vectorize<int>(scale->dims());
PADDLE_ENFORCE(scale_tz.size() == 1, "Dims of scale tensor is NOT 1"); PADDLE_ENFORCE(scale_tz.size() == 1, "Dims of scale tensor is NOT 1");
const unsigned int ic = scale_tz[0]; const unsigned int C = scale_tz[0];
// MKLDNN requires a single piece of memory for scale and shift/bias data // MKLDNN requires a single piece of memory for scale and shift/bias data
const size_t scaleshift_size = 2 * ic; std::vector<T> scaleshift_data(scale->data<T>(), scale->data<T>() + C);
std::vector<T> scaleshift_data; scaleshift_data.reserve(2 * C);
scaleshift_data.reserve(scaleshift_size); scaleshift_data.insert(scaleshift_data.end(), shift->data<T>(),
shift->data<T>() + C);
copy_to_weights(scale->data<T>(), scale->data<T>() + ic, shift->data<T>(),
shift->data<T>() + ic, &scaleshift_data); // Flags are added by bitwise OR operation
unsigned flags = mkldnn::use_scale_shift; // 001
unsigned flags = mkldnn::use_scale_shift; if (global_stats) flags |= mkldnn::use_global_stats; // 010
if (global_stats) flags |= mkldnn::use_global_stats; if (fuse_with_relu && is_test) flags |= mkldnn::fuse_bn_relu; // 100
if (fuse_with_relu) flags |= mkldnn::fuse_bn_relu;
BatchNormMKLDNNHandler<T> handler(
// create mkldnn memory from input x tensor src_tz, epsilon, flags, global_stats,
MKLDNNMemoryFormat input_format = platform::MKLDNNFormatForSize(src_tz.size(), x->format()), dev_ctx,
platform::MKLDNNFormatForSize(src_tz.size(), x->format()); ctx.GetPlace(), ctx.op().Output("SavedMean"));
// keys for backward pass auto src_memory = handler.AcquireSrcMemory(x);
const std::string key =
platform::CreateKey(src_tz, epsilon, flags, global_stats, input_format,
ctx.op().Output("SavedMean"));
BatchNormMKLDNNHandler handler(dev_ctx, mkldnn_engine, key);
auto user_src_md = platform::MKLDNNMemDesc(
{src_tz}, platform::MKLDNNGetDataType<T>(), input_format);
// create primitive descriptor for batch norm forward
using bn_fwd_types = bn_type_traits<mkldnn::batch_normalization_forward>;
auto batch_norm_fwd_desc =
bn_fwd_types::op_desc{propagation, user_src_md, epsilon, flags};
auto batch_norm_fwd_pd = handler.AcquireBatchNormPrimitiveDescriptor(
batch_norm_fwd_desc, mkldnn_engine);
auto src_memory =
handler.AcquireSrcMemory(user_src_md, to_void_cast(x_data));
// crate mkldnn memory for weights(scale/shift)
auto scaleshift_memory = auto scaleshift_memory =
handler.AcquireScaleshiftMemoryFromPrimitive(scaleshift_data.data()); handler.AcquireScaleShiftMemory(scaleshift_data.data());
auto dst_memory = handler.AcquireDstMemory(y);
// create mkldnn memory for output y tensor
auto dst_memory =
handler.AcquireDstMemoryFromPrimitive<T>(y, ctx.GetPlace());
std::shared_ptr<batch_norm_fwd> batch_norm_p; std::shared_ptr<mkldnn::batch_normalization_forward> batch_norm_p;
if (global_stats) { if (global_stats) {
// create mkldnn memory for stats (as input) // mean and variance are taken from input Tensor
std::shared_ptr<memory> mean_memory = const auto *mean = ctx.Input<Tensor>("Mean");
handler.AcquireMeanMemoryFromPrimitive(to_void_cast(mean_data)); const auto *variance = ctx.Input<Tensor>("Variance");
std::shared_ptr<memory> mean_memory = handler.AcquireMeanMemory(mean);
std::shared_ptr<memory> variance_memory = std::shared_ptr<memory> variance_memory =
handler.AcquireVarianceMemoryFromPrimitive( handler.AcquireVarianceMemory(variance);
to_void_cast(variance_data));
batch_norm_p = handler.AcquireTestTrainingBatchNormFwd( batch_norm_p = handler.AcquireForwardPrimitive(
src_memory, scaleshift_memory, dst_memory, mean_memory, *src_memory, (const mkldnn::primitive::at &)*mean_memory,
variance_memory, true); (const mkldnn::primitive::at &)*variance_memory, *scaleshift_memory,
*dst_memory);
} else { } else {
// create mkldnn memory for stats (as output) // mean and variance are calculated and saved in output Tensor
std::shared_ptr<memory> mean_memory = std::shared_ptr<memory> mean_memory =
handler.AcquireMeanMemoryFromPrimitive(batch_mean_data); handler.AcquireMeanMemory(batch_mean);
std::shared_ptr<memory> variance_memory = std::shared_ptr<memory> variance_memory =
handler.AcquireVarianceMemoryFromPrimitive(batch_variance_data); handler.AcquireVarianceMemory(batch_variance);
batch_norm_p = handler.AcquireTestTrainingBatchNormFwd( batch_norm_p = handler.AcquireForwardPrimitive(
src_memory, scaleshift_memory, dst_memory, mean_memory, *src_memory, *scaleshift_memory, *dst_memory, *mean_memory,
variance_memory, false); *variance_memory);
} }
y->set_layout(DataLayout::kMKLDNN); y->set_layout(DataLayout::kMKLDNN);
...@@ -281,18 +207,20 @@ class BatchNormMKLDNNOpKernel : public paddle::framework::OpKernel<T> { ...@@ -281,18 +207,20 @@ class BatchNormMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
if (!global_stats) { if (!global_stats) {
// mkldnn only compute stats for current batch // mkldnn only compute stats for current batch
// so we need compute momentum stats via Eigen lib // so we need compute momentum stats via Eigen lib
EigenVectorArrayMap<T> batch_mean_e(batch_mean_data, ic); EigenVectorArrayMap<T> batch_mean_e(
EigenVectorArrayMap<T> batch_variance_e(batch_variance_data, ic); batch_mean->mutable_data<T>(ctx.GetPlace()), C);
ConstEigenVectorArrayMap<T> mean_e(mean_data, ic); EigenVectorArrayMap<T> batch_variance_e(
ConstEigenVectorArrayMap<T> variance_e{variance_data, ic}; batch_variance->mutable_data<T>(ctx.GetPlace()), C);
EigenVectorArrayMap<T> running_mean_e(mean_out_data, ic); EigenVectorArrayMap<T> running_mean_e(
EigenVectorArrayMap<T> running_variance_e(variance_out_data, ic); mean_out->mutable_data<T>(ctx.GetPlace()), C);
EigenVectorArrayMap<T> running_variance_e(
auto one_minus_momentum = 1. - momentum; variance_out->mutable_data<T>(ctx.GetPlace()), C);
running_mean_e = mean_e * momentum + batch_mean_e * one_minus_momentum;
running_mean_e =
running_mean_e * momentum + batch_mean_e * (1. - momentum);
running_variance_e = running_variance_e =
variance_e * momentum + batch_variance_e * one_minus_momentum; running_variance_e * momentum + batch_variance_e * (1. - momentum);
} }
} }
}; };
...@@ -311,7 +239,6 @@ class BatchNormMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> { ...@@ -311,7 +239,6 @@ class BatchNormMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
const auto *shift = ctx.Input<Tensor>("Bias"); const auto *shift = ctx.Input<Tensor>("Bias");
const auto *batch_mean = ctx.Input<Tensor>("SavedMean"); const auto *batch_mean = ctx.Input<Tensor>("SavedMean");
const auto *batch_variance = ctx.Input<Tensor>("SavedVariance"); const auto *batch_variance = ctx.Input<Tensor>("SavedVariance");
const auto *diff_y = ctx.Input<Tensor>(framework::GradVarName("Y")); const auto *diff_y = ctx.Input<Tensor>(framework::GradVarName("Y"));
auto *diff_x = ctx.Output<Tensor>(framework::GradVarName("X")); auto *diff_x = ctx.Output<Tensor>(framework::GradVarName("X"));
auto *diff_scale = ctx.Output<Tensor>(framework::GradVarName("Scale")); auto *diff_scale = ctx.Output<Tensor>(framework::GradVarName("Scale"));
...@@ -322,27 +249,11 @@ class BatchNormMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> { ...@@ -322,27 +249,11 @@ class BatchNormMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
PADDLE_ENFORCE_NE(diff_y->format(), MKLDNNMemoryFormat::format_undef, PADDLE_ENFORCE_NE(diff_y->format(), MKLDNNMemoryFormat::format_undef,
"Wrong format set for Input diff_y tensor"); "Wrong format set for Input diff_y tensor");
const T *x_data = x->data<T>();
const T *diff_y_data = diff_y->data<T>();
const T *batch_mean_data = batch_mean->data<T>();
const T *batch_variance_data = batch_variance->data<T>();
const T *scale_data = scale->data<T>();
const T *shift_data = shift->data<T>();
T *diff_x_data = diff_x->mutable_data<T>(ctx.GetPlace());
T *diff_scale_data = diff_scale->mutable_data<T>(ctx.GetPlace());
T *diff_shift_data = diff_shift->mutable_data<T>(ctx.GetPlace());
auto src_tz = paddle::framework::vectorize<int>(x->dims()); auto src_tz = paddle::framework::vectorize<int>(x->dims());
auto diff_src_tz = src_tz;
auto dst_tz = src_tz;
auto diff_dst_tz = dst_tz;
auto scale_tz = paddle::framework::vectorize<int>(scale->dims()); auto scale_tz = paddle::framework::vectorize<int>(scale->dims());
PADDLE_ENFORCE(scale_tz.size() == 1, "Dims of scale tensor is NOT 1"); PADDLE_ENFORCE(scale_tz.size() == 1, "Dims of scale tensor is NOT 1");
const unsigned int ic = scale_tz[0]; const unsigned int C = scale_tz[0];
using bn_bwd_types = bn_type_traits<mkldnn::batch_normalization_backward>;
MKLDNNMemoryFormat dst_format = MKLDNNMemoryFormat dst_format =
platform::MKLDNNFormatForSize(src_tz.size(), diff_y->format()); platform::MKLDNNFormatForSize(src_tz.size(), diff_y->format());
...@@ -350,170 +261,52 @@ class BatchNormMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> { ...@@ -350,170 +261,52 @@ class BatchNormMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
MKLDNNMemoryFormat input_format = MKLDNNMemoryFormat input_format =
platform::MKLDNNFormatForSize(src_tz.size(), x->format()); platform::MKLDNNFormatForSize(src_tz.size(), x->format());
unsigned flags = mkldnn::use_scale_shift; BatchNormMKLDNNHandler<T> handler(
src_tz, epsilon, mkldnn::use_scale_shift, dst_format, input_format,
// keys from forward pass dev_ctx, ctx.GetPlace(), ctx.op().Input("SavedMean"));
const std::string key =
platform::CreateKey(src_tz, epsilon, flags, false, input_format,
ctx.op().Input("SavedMean"));
const std::string key_batch_norm_fwd_pd = key + "@bn_fwd_pd";
// keys for primitives reuse
const std::string key_with_hash =
key + platform::CreateKey(src_tz, epsilon, flags, false, input_format);
const std::string key_batch_norm_bwd_p =
key_with_hash + "@batch_norm_bwd_p";
const std::string key_batch_norm_src_mem_p =
key_with_hash + "@batch_norm_bwd_src_mem_p";
const std::string key_batch_norm_mean_mem_p =
key_with_hash + "@batch_norm_bwd_mean_mem_p";
const std::string key_batch_norm_variance_mem_p =
key_with_hash + "@batch_norm_bwd_variance_mem_p";
const std::string key_batch_norm_scaleshift_mem_p =
key_with_hash + "@batch_norm_bwd_scaleshift_mem_p";
const std::string key_batch_norm_diff_scaleshift_mem_p =
key_with_hash + "@batch_norm_bwd_diff_scaleshift_mem_p";
const std::string key_batch_norm_diff_src_mem_p =
key_with_hash + "@batch_norm_bwd_diff_src_mem_p";
const std::string key_batch_norm_diff_dst_mem_p =
key_with_hash + "@batch_norm_bwd_diff_dst_mem_p";
primitive reorder_diff_dst;
bool is_diff_dst_reordered = false;
auto user_diff_dst_memory = memory(
{{{diff_dst_tz}, memory::data_type::f32, dst_format}, mkldnn_engine},
to_void_cast(diff_y_data));
// MKLDNN requires a single piece of memory for scale and shift/bias data // MKLDNN requires a single piece of memory for scale and shift/bias data
const size_t scaleshift_size = 2 * ic; const size_t scaleshift_size = 2 * C;
std::vector<T> scaleshift_data(scale->data<T>(), scale->data<T>() + C);
std::vector<T> scaleshift_data;
scaleshift_data.reserve(scaleshift_size); scaleshift_data.reserve(scaleshift_size);
copy_to_weights(scale_data, scale_data + ic, shift_data, shift_data + ic, scaleshift_data.insert(scaleshift_data.end(), shift->data<T>(),
&scaleshift_data); shift->data<T>() + C);
std::vector<T> diff_scaleshift_data; std::vector<T> diff_scaleshift_data;
diff_scaleshift_data.reserve(scaleshift_size); diff_scaleshift_data.reserve(scaleshift_size);
auto batch_norm_fwd_pd = auto src_memory = handler.AcquireSrcMemory(x);
std::static_pointer_cast<batch_norm_fwd::primitive_desc>( auto mean_memory = handler.AcquireMeanMemory(batch_mean);
dev_ctx.GetBlob(key_batch_norm_fwd_pd)); auto variance_memory = handler.AcquireVarianceMemory(batch_variance);
PADDLE_ENFORCE(batch_norm_fwd_pd != nullptr, auto diff_dst_memory = handler.AcquireDiffDstMemory(diff_y);
"Fail to find batch_norm_fwd_pd in device context"); auto scaleshift_memory =
handler.AcquireScaleShiftMemory(scaleshift_data.data());
auto batch_norm_bwd_p = std::static_pointer_cast<batch_norm_bwd>( auto diff_src_memory = handler.AcquireDiffSrcMemory(diff_x);
dev_ctx.GetBlob(key_batch_norm_bwd_p)); auto diff_scaleshift_memory =
handler.AcquireDiffScaleShiftMemory(diff_scaleshift_data.data());
if (batch_norm_bwd_p == nullptr) {
auto src_memory = std::shared_ptr<memory>(new memory( // finally create batch_norm backward primitive
{{{src_tz}, memory::data_type::f32, input_format}, mkldnn_engine}, auto batch_norm_bwd_p = handler.AcquireBackwardPrimitive(
to_void_cast(x_data))); *src_memory, *mean_memory, *variance_memory, *diff_dst_memory,
*scaleshift_memory, *diff_src_memory, *diff_scaleshift_memory);
// for diff_dst, try to use same format as dst in forward pass
auto diff_dst_pd = batch_norm_fwd_pd.get()->dst_primitive_desc();
auto diff_dst_md = diff_dst_pd.desc();
// create primitive descriptor for batch norm backward
auto batch_norm_bwd_desc = bn_bwd_types::op_desc{
mkldnn::prop_kind::backward, diff_dst_md,
src_memory->get_primitive_desc().desc(), epsilon, flags};
auto batch_norm_bwd_pd = bn_bwd_types::op_prim{
batch_norm_bwd_desc, mkldnn_engine, *batch_norm_fwd_pd};
// reorder user_diff_dst if it's not in preferred format
auto diff_dst_memory = std::make_shared<memory>(user_diff_dst_memory);
if (diff_dst_pd != user_diff_dst_memory.get_primitive_desc()) {
diff_dst_memory = std::make_shared<memory>(diff_dst_pd);
reorder_diff_dst = reorder(user_diff_dst_memory, *diff_dst_memory);
is_diff_dst_reordered = true;
}
// create mkldnn memory for input tensors (src/mean/variance)
auto mean_memory =
std::make_shared<memory>(batch_norm_bwd_pd.mean_primitive_desc(),
to_void_cast(batch_mean_data));
auto variance_memory =
std::make_shared<memory>(batch_norm_bwd_pd.variance_primitive_desc(),
to_void_cast(batch_variance_data));
// create mkldnn memory for input tensors (scale/shift)
auto scaleshift_memory = std::make_shared<memory>(
batch_norm_bwd_pd.weights_primitive_desc(), scaleshift_data.data());
// create mkldnn memory for output diff weights (combined scale/shift)
auto diff_scaleshift_memory = std::make_shared<memory>(
batch_norm_bwd_pd.diff_weights_primitive_desc(),
diff_scaleshift_data.data());
// here assume diff_src is in the same format of src
auto diff_src_memory = std::make_shared<memory>(
src_memory->get_primitive_desc(), diff_x_data);
// finally create batch_norm backward primitive
batch_norm_bwd_p = std::make_shared<batch_norm_bwd>(
batch_norm_bwd_pd, *src_memory, *mean_memory, *variance_memory,
*diff_dst_memory, *scaleshift_memory, *diff_src_memory,
*diff_scaleshift_memory);
dev_ctx.SetBlob(key_batch_norm_bwd_p, batch_norm_bwd_p);
dev_ctx.SetBlob(key_batch_norm_src_mem_p, src_memory);
dev_ctx.SetBlob(key_batch_norm_mean_mem_p, mean_memory);
dev_ctx.SetBlob(key_batch_norm_variance_mem_p, variance_memory);
dev_ctx.SetBlob(key_batch_norm_scaleshift_mem_p, scaleshift_memory);
dev_ctx.SetBlob(key_batch_norm_diff_scaleshift_mem_p,
diff_scaleshift_memory);
dev_ctx.SetBlob(key_batch_norm_diff_src_mem_p, diff_src_memory);
dev_ctx.SetBlob(key_batch_norm_diff_dst_mem_p, diff_dst_memory);
// set layout/format of output tensors
diff_x->set_layout(DataLayout::kMKLDNN);
diff_x->set_format(
(MKLDNNMemoryFormat)diff_src_memory->get_primitive_desc()
.desc()
.data.format);
} else {
// primitives already exist
UpdateMemoryData(dev_ctx, key_batch_norm_src_mem_p, to_void_cast(x_data));
UpdateMemoryData(dev_ctx, key_batch_norm_mean_mem_p,
to_void_cast(batch_mean_data));
UpdateMemoryData(dev_ctx, key_batch_norm_variance_mem_p,
to_void_cast(batch_variance_data));
UpdateMemoryData(dev_ctx, key_batch_norm_scaleshift_mem_p,
scaleshift_data.data());
UpdateMemoryData(dev_ctx, key_batch_norm_diff_scaleshift_mem_p,
diff_scaleshift_data.data());
auto diff_src_memory = UpdateMemoryData(
dev_ctx, key_batch_norm_diff_src_mem_p, to_void_cast(diff_x_data));
auto diff_dst_memory = UpdateMemoryData(
dev_ctx, key_batch_norm_diff_dst_mem_p, to_void_cast(diff_y_data));
// reorder user_diff_dst if it's not in preferred format
if (diff_dst_memory->get_primitive_desc() !=
user_diff_dst_memory.get_primitive_desc()) {
reorder_diff_dst = reorder(user_diff_dst_memory, *diff_dst_memory);
is_diff_dst_reordered = true;
}
// set layout/format of output tensors
diff_x->set_layout(DataLayout::kMKLDNN);
diff_x->set_format(
(MKLDNNMemoryFormat)diff_src_memory->get_primitive_desc()
.desc()
.data.format);
}
// execute optional reorder and batch_norm backward primitive
std::vector<primitive> pipeline; std::vector<primitive> pipeline;
if (is_diff_dst_reordered) pipeline.push_back(reorder_diff_dst);
pipeline.push_back(*batch_norm_bwd_p); pipeline.push_back(*batch_norm_bwd_p);
stream(stream::kind::eager).submit(pipeline).wait(); stream(stream::kind::eager).submit(pipeline).wait();
T *diff_scale_data = diff_scale->mutable_data<T>(ctx.GetPlace());
T *diff_shift_data = diff_shift->mutable_data<T>(ctx.GetPlace());
// copy back diff sacle/shift to output tensors (diff scale/shift) // copy back diff sacle/shift to output tensors (diff scale/shift)
diff_scaleshift_data.resize(scaleshift_size); diff_scaleshift_data.resize(scaleshift_size);
auto it = std::begin(diff_scaleshift_data); auto it = std::begin(diff_scaleshift_data);
std::copy(it, std::next(it, ic), diff_scale_data); std::copy(it, std::next(it, C), diff_scale_data);
std::copy(std::next(it, ic), std::end(diff_scaleshift_data), std::copy(std::next(it, C), std::end(diff_scaleshift_data),
diff_shift_data); diff_shift_data);
// set layout/format of output tensors
diff_x->set_layout(DataLayout::kMKLDNN);
diff_x->set_format(platform::GetMKLDNNFormat(*diff_src_memory));
} }
}; };
} // namespace operators } // namespace operators
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册