提交 50d3e6e9 编写于 作者: K Krzysztof Binias

Reusing primitives for forward Batch Norm operator

上级 ef7bd03a
...@@ -37,6 +37,122 @@ struct bn_type_traits { ...@@ -37,6 +37,122 @@ struct bn_type_traits {
using op_prim = typename op_type::primitive_desc; using op_prim = typename op_type::primitive_desc;
}; };
class BatchNormMKLDNNHandler : public platform::MKLDNNHandler {
public:
BatchNormMKLDNNHandler(
std::shared_ptr<batch_norm_fwd::primitive_desc> batch_norm_pd,
const platform::MKLDNNDeviceContext &dev_ctx, mkldnn::engine engine,
const std::string &base_key)
: platform::MKLDNNHandler(dev_ctx, engine, base_key) {
batch_norm_pd_ = batch_norm_pd;
}
std::shared_ptr<memory> AcquireScaleshiftMemoryFromPrimitive(void *ptr) {
return this->AcquireMemoryFromPrimitive(
batch_norm_pd_->weights_primitive_desc(), ptr, "@scaleshift_mem_p");
}
std::shared_ptr<memory> AcquireMeanMemoryFromPrimitive(void *ptr) {
return this->AcquireMemoryFromPrimitive(
batch_norm_pd_->mean_primitive_desc(), ptr, "@mean_mem_p");
}
std::shared_ptr<memory> AcquireVarianceMemoryFromPrimitive(void *ptr) {
return this->AcquireMemoryFromPrimitive(
batch_norm_pd_->variance_primitive_desc(), ptr, "@variance_mem_p");
}
std::shared_ptr<batch_norm_fwd> AcquireTestBatchNormFwd(
std::shared_ptr<memory> src_memory,
const mkldnn::primitive::at &mean_memory,
const mkldnn::primitive::at &variance_memory,
std::shared_ptr<memory> scaleshift_memory,
std::shared_ptr<memory> dst_memory) {
auto prim_key = key_ + "@batch_norm_p";
auto batch_norm_p =
std::static_pointer_cast<batch_norm_fwd>(dev_ctx_.GetBlob(prim_key));
PADDLE_ENFORCE(
(batch_norm_p != nullptr) || (is_reusing_ == false),
"Fail to find batch norm primitive for test in device context");
if (batch_norm_p == nullptr) {
batch_norm_p = std::make_shared<batch_norm_fwd>(
*batch_norm_pd_, *src_memory, mean_memory, variance_memory,
*scaleshift_memory, *dst_memory);
dev_ctx_.SetBlob(prim_key, batch_norm_p);
} else {
is_reusing_ = true;
}
return batch_norm_p;
}
std::shared_ptr<batch_norm_fwd> AcquireTrainingBatchNormFwd(
std::shared_ptr<memory> src_memory,
std::shared_ptr<memory> scaleshift_memory,
std::shared_ptr<memory> dst_memory, std::shared_ptr<memory> mean_memory,
std::shared_ptr<memory> variance_memory) {
auto prim_key = key_ + "@batch_norm_p";
auto batch_norm_p =
std::static_pointer_cast<batch_norm_fwd>(dev_ctx_.GetBlob(prim_key));
PADDLE_ENFORCE(
(batch_norm_p != nullptr) || (is_reusing_ == false),
"Fail to find batch norm primitive for training in device context");
if (batch_norm_p == nullptr) {
batch_norm_p = std::make_shared<batch_norm_fwd>(
*batch_norm_pd_, *src_memory, *scaleshift_memory, *dst_memory,
*mean_memory, *variance_memory);
dev_ctx_.SetBlob(prim_key, batch_norm_p);
} else {
is_reusing_ = true;
}
return batch_norm_p;
}
//
static std::string GetHash(const memory::dims &input_dims, float epsilon,
unsigned flag, bool is_test, memory::format format,
const std::string &suffix) {
auto dims2str = [](const memory::dims &operand_dims) {
std::string dstr = "";
for (size_t i = 0; i < operand_dims.size(); ++i) {
dstr += std::to_string(operand_dims[i]) + "-";
}
return dstr;
};
return dims2str(input_dims) + std::to_string(epsilon) +
std::to_string(flag) + std::to_string(is_test) +
std::to_string(format) + suffix;
}
private:
std::shared_ptr<batch_norm_fwd::primitive_desc> batch_norm_pd_;
};
std::string gethash(const memory::dims &input_dims, float epsilon,
unsigned flag, bool is_test, memory::format format) {
auto dims2str = [](const memory::dims &operand_dims) {
std::string dstr = "";
for (size_t i = 0; i < operand_dims.size(); ++i) {
dstr += std::to_string(operand_dims[i]) + "-";
}
return dstr;
};
return dims2str(input_dims) + std::to_string(epsilon) + std::to_string(flag) +
std::to_string(is_test) + std::to_string(format);
}
std::shared_ptr<memory> UpdateMemoryData(
const platform::MKLDNNDeviceContext &dev_ctx, const std::string &key,
void *new_ptr) {
auto mem = std::static_pointer_cast<memory>(dev_ctx.GetBlob(key));
PADDLE_ENFORCE(
mem != nullptr,
(std::string("Fail to find memory in device context [key: ") + key + "]")
.c_str());
mem->set_data_handle(new_ptr);
return mem;
}
template <typename T, typename Container> template <typename T, typename Container>
void copy_to_weights(T scale_begin, T scale_end, T shift_begin, T shift_end, void copy_to_weights(T scale_begin, T scale_end, T shift_begin, T shift_end,
Container *c) { Container *c) {
...@@ -48,15 +164,6 @@ void copy_to_weights(T scale_begin, T scale_end, T shift_begin, T shift_end, ...@@ -48,15 +164,6 @@ void copy_to_weights(T scale_begin, T scale_end, T shift_begin, T shift_end,
std::inserter(*c, std::next(it, std::distance(scale_begin, scale_end)))); std::inserter(*c, std::next(it, std::distance(scale_begin, scale_end))));
} }
template <typename Op, typename... Args>
void run_batch_norm_op(Args &&... args) {
Op batch_norm_op{args...};
std::vector<mkldnn::primitive> pipeline;
pipeline.push_back(batch_norm_op);
mkldnn::stream(mkldnn::stream::kind::eager).submit(pipeline).wait();
}
} // namespace } // namespace
template <typename T> template <typename T>
...@@ -110,6 +217,14 @@ class BatchNormMKLDNNOpKernel : public paddle::framework::OpKernel<T> { ...@@ -110,6 +217,14 @@ class BatchNormMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
PADDLE_ENFORCE(scale_tz.size() == 1, "Dims of scale tensor is NOT 1"); PADDLE_ENFORCE(scale_tz.size() == 1, "Dims of scale tensor is NOT 1");
const unsigned int ic = scale_tz[0]; const unsigned int ic = scale_tz[0];
// MKLDNN requires a single piece of memory for scale and shift/bias data
const size_t scaleshift_size = 2 * ic;
std::vector<T> scaleshift_data;
scaleshift_data.reserve(scaleshift_size);
copy_to_weights(scale->data<T>(), scale->data<T>() + ic, shift->data<T>(),
shift->data<T>() + ic, &scaleshift_data);
unsigned flags = mkldnn::use_scale_shift; unsigned flags = mkldnn::use_scale_shift;
if (is_test) flags |= mkldnn::use_global_stats; if (is_test) flags |= mkldnn::use_global_stats;
if (fuse_with_relu) flags |= mkldnn::fuse_bn_relu; if (fuse_with_relu) flags |= mkldnn::fuse_bn_relu;
...@@ -118,64 +233,70 @@ class BatchNormMKLDNNOpKernel : public paddle::framework::OpKernel<T> { ...@@ -118,64 +233,70 @@ class BatchNormMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
mkldnn::memory::format input_format = mkldnn::memory::format input_format =
platform::MKLDNNFormatForSize(src_tz.size(), x->format()); platform::MKLDNNFormatForSize(src_tz.size(), x->format());
auto src_memory = memory( // keys for backward pass
{{{src_tz}, memory::data_type::f32, input_format}, mkldnn_engine}, const std::string key = BatchNormMKLDNNHandler::GetHash(
to_void_cast(x_data)); src_tz, epsilon, flags, is_test, input_format,
ctx.op().Output("SavedMean"));
const std::string key_batch_norm_fwd_pd = key + "@bn_fwd_pd";
auto user_src_md = platform::MKLDNNMemDesc(
{src_tz}, platform::MKLDNNGetDataType<T>(), input_format);
// create primitive descriptor for batch norm forward // create primitive descriptor for batch norm forward
using bn_fwd_types = bn_type_traits<mkldnn::batch_normalization_forward>; using bn_fwd_types = bn_type_traits<mkldnn::batch_normalization_forward>;
auto batch_norm_fwd_desc = bn_fwd_types::op_desc{ auto batch_norm_fwd_desc =
propagation, src_memory.get_primitive_desc().desc(), epsilon, flags}; bn_fwd_types::op_desc{propagation, user_src_md, epsilon, flags};
std::shared_ptr<batch_norm_fwd::primitive_desc> batch_norm_fwd_pd = auto batch_norm_fwd_pd = std::make_shared<batch_norm_fwd::primitive_desc>(
std::shared_ptr<batch_norm_fwd::primitive_desc>( batch_norm_fwd_desc, mkldnn_engine);
new batch_norm_fwd::primitive_desc(batch_norm_fwd_desc, // Save conv_pd/src_memory/weights_memory for backward pass
mkldnn_engine));
// Save the pd to be used in backward pass
const std::string key = ctx.op().Output("SavedMean");
const std::string key_batch_norm_fwd_pd = key + "@bn_fwd_pd";
dev_ctx.SetBlob(key_batch_norm_fwd_pd, batch_norm_fwd_pd); dev_ctx.SetBlob(key_batch_norm_fwd_pd, batch_norm_fwd_pd);
// MKLDNN requires a single piece of memory for scale and shift/bias data BatchNormMKLDNNHandler handler(batch_norm_fwd_pd, dev_ctx, mkldnn_engine,
const size_t scaleshift_size = 2 * ic; key);
std::vector<T> scaleshift_data;
scaleshift_data.reserve(scaleshift_size);
copy_to_weights(scale->data<T>(), scale->data<T>() + ic, shift->data<T>(), auto src_memory =
shift->data<T>() + ic, &scaleshift_data); handler.AcquireSrcMemory(user_src_md, to_void_cast(x_data));
// crate mkldnn memory for weights(scale/shift) // crate mkldnn memory for weights(scale/shift)
auto scaleshift_memory = memory(batch_norm_fwd_pd->weights_primitive_desc(), auto scaleshift_memory =
scaleshift_data.data()); handler.AcquireScaleshiftMemoryFromPrimitive(scaleshift_data.data());
// create mkldnn memory for output y tensor // create mkldnn memory for output y tensor
auto dst_memory = memory(batch_norm_fwd_pd->dst_primitive_desc(), y_data); auto dst_memory = handler.AcquireDstMemory(
batch_norm_fwd_pd->dst_primitive_desc().desc(), y_data);
std::shared_ptr<batch_norm_fwd> batch_norm_p;
if (is_test) { if (is_test) {
// create mkldnn memory for stats (as input) // create mkldnn memory for stats (as input)
auto mean_memory = memory(batch_norm_fwd_pd->mean_primitive_desc(), std::shared_ptr<memory> mean_memory =
to_void_cast(mean_data)); handler.AcquireMeanMemoryFromPrimitive(to_void_cast(mean_data));
auto variance_memory = std::shared_ptr<memory> variance_memory =
memory(batch_norm_fwd_pd->variance_primitive_desc(), handler.AcquireVarianceMemoryFromPrimitive(
to_void_cast(variance_data)); to_void_cast(variance_data));
run_batch_norm_op<typename bn_fwd_types::op_type>( batch_norm_p = handler.AcquireTestBatchNormFwd(
*batch_norm_fwd_pd, src_memory, src_memory, (const mkldnn::primitive::at &)*mean_memory,
(const mkldnn::primitive::at &)mean_memory, (const mkldnn::primitive::at &)*variance_memory, scaleshift_memory,
(const mkldnn::primitive::at &)variance_memory, scaleshift_memory,
dst_memory); dst_memory);
} else { } else {
// create mkldnn memory for stats (as output) // create mkldnn memory for stats (as output)
auto mean_memory = std::shared_ptr<memory> mean_memory =
memory(batch_norm_fwd_pd->mean_primitive_desc(), batch_mean_data); handler.AcquireMeanMemoryFromPrimitive(batch_mean_data);
auto variance_memory = memory( std::shared_ptr<memory> variance_memory =
batch_norm_fwd_pd->variance_primitive_desc(), batch_variance_data); handler.AcquireVarianceMemoryFromPrimitive(batch_variance_data);
run_batch_norm_op<bn_fwd_types::op_type>(*batch_norm_fwd_pd, src_memory, batch_norm_p = handler.AcquireTrainingBatchNormFwd(
scaleshift_memory, dst_memory, src_memory, scaleshift_memory, dst_memory, mean_memory,
mean_memory, variance_memory); variance_memory);
} }
y->set_layout(DataLayout::kMKLDNN);
y->set_format(platform::GetMKLDNNFormat(*dst_memory));
std::vector<mkldnn::primitive> pipeline;
pipeline.push_back(*batch_norm_p);
mkldnn::stream(mkldnn::stream::kind::eager).submit(pipeline).wait();
if (!is_test) { if (!is_test) {
// mkldnn only compute stats for current batch // mkldnn only compute stats for current batch
// so we need compute momentum stats via Eigen lib // so we need compute momentum stats via Eigen lib
...@@ -192,10 +313,6 @@ class BatchNormMKLDNNOpKernel : public paddle::framework::OpKernel<T> { ...@@ -192,10 +313,6 @@ class BatchNormMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
running_variance_e = running_variance_e =
variance_e * momentum + batch_variance_e * one_minus_momentum; variance_e * momentum + batch_variance_e * one_minus_momentum;
} }
y->set_layout(DataLayout::kMKLDNN);
y->set_format(
(memory::format)dst_memory.get_primitive_desc().desc().data.format);
} }
}; };
...@@ -242,94 +359,168 @@ class BatchNormMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> { ...@@ -242,94 +359,168 @@ class BatchNormMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
const unsigned int ic = scale_tz[0]; const unsigned int ic = scale_tz[0];
// Retrieve bn_fwd_pd from device context
const std::string key = ctx.op().Input("SavedMean");
const std::string key_batch_norm_fwd_pd = key + "@bn_fwd_pd";
auto batch_norm_fwd_pd =
std::static_pointer_cast<batch_norm_fwd::primitive_desc>(
dev_ctx.GetBlob(key_batch_norm_fwd_pd));
PADDLE_ENFORCE(batch_norm_fwd_pd != nullptr,
"Fail to find batch_norm_fwd_pd in device context");
using bn_bwd_types = bn_type_traits<mkldnn::batch_normalization_backward>; using bn_bwd_types = bn_type_traits<mkldnn::batch_normalization_backward>;
// create mkldnn memory from input diff_y tensor
mkldnn::memory::format dst_format = mkldnn::memory::format dst_format =
platform::MKLDNNFormatForSize(src_tz.size(), diff_y->format()); platform::MKLDNNFormatForSize(src_tz.size(), diff_y->format());
mkldnn::memory::format input_format =
platform::MKLDNNFormatForSize(src_tz.size(), x->format());
unsigned flags = mkldnn::use_scale_shift;
// keys from forward pass
const std::string key = BatchNormMKLDNNHandler::GetHash(
src_tz, epsilon, flags, false, input_format,
ctx.op().Input("SavedMean"));
const std::string key_batch_norm_fwd_pd = key + "@bn_fwd_pd";
// keys for primitives reuse
const std::string key_with_hash =
key + gethash(src_tz, epsilon, flags, false, input_format);
const std::string key_batch_norm_bwd_p =
key_with_hash + "@batch_norm_bwd_p";
const std::string key_batch_norm_src_mem_p =
key_with_hash + "@batch_norm_bwd_src_mem_p";
const std::string key_batch_norm_mean_mem_p =
key_with_hash + "@batch_norm_bwd_mean_mem_p";
const std::string key_batch_norm_variance_mem_p =
key_with_hash + "@batch_norm_bwd_variance_mem_p";
const std::string key_batch_norm_scaleshift_mem_p =
key_with_hash + "@batch_norm_bwd_scaleshift_mem_p";
const std::string key_batch_norm_diff_scaleshift_mem_p =
key_with_hash + "@batch_norm_bwd_diff_scaleshift_mem_p";
const std::string key_batch_norm_diff_src_mem_p =
key_with_hash + "@batch_norm_bwd_diff_src_mem_p";
const std::string key_batch_norm_diff_dst_mem_p =
key_with_hash + "@batch_norm_bwd_diff_dst_mem_p";
primitive reorder_diff_dst;
bool is_diff_dst_reordered = false;
auto user_diff_dst_memory = memory( auto user_diff_dst_memory = memory(
{{{diff_dst_tz}, memory::data_type::f32, dst_format}, mkldnn_engine}, {{{diff_dst_tz}, memory::data_type::f32, dst_format}, mkldnn_engine},
to_void_cast(diff_y_data)); to_void_cast(diff_y_data));
// create mkldnn memory from input x tensor // MKLDNN requires a single piece of memory for scale and shift/bias data
mkldnn::memory::format input_format = const size_t scaleshift_size = 2 * ic;
platform::MKLDNNFormatForSize(src_tz.size(), x->format());
std::vector<T> scaleshift_data;
scaleshift_data.reserve(scaleshift_size);
copy_to_weights(scale_data, scale_data + ic, shift_data, shift_data + ic,
&scaleshift_data);
auto src_memory = memory( std::vector<T> diff_scaleshift_data;
diff_scaleshift_data.reserve(scaleshift_size);
auto batch_norm_fwd_pd =
std::static_pointer_cast<batch_norm_fwd::primitive_desc>(
dev_ctx.GetBlob(key_batch_norm_fwd_pd));
PADDLE_ENFORCE(batch_norm_fwd_pd != nullptr,
"Fail to find batch_norm_fwd_pd in device context");
auto batch_norm_bwd_p = std::static_pointer_cast<batch_norm_bwd>(
dev_ctx.GetBlob(key_batch_norm_bwd_p));
if (batch_norm_bwd_p == nullptr) {
auto src_memory = std::shared_ptr<memory>(new memory(
{{{src_tz}, memory::data_type::f32, input_format}, mkldnn_engine}, {{{src_tz}, memory::data_type::f32, input_format}, mkldnn_engine},
to_void_cast(x_data)); to_void_cast(x_data)));
// for diff_dst, try to use same format as dst in forward pass // for diff_dst, try to use same format as dst in forward pass
auto diff_dst_pd = batch_norm_fwd_pd.get()->dst_primitive_desc(); auto diff_dst_pd = batch_norm_fwd_pd.get()->dst_primitive_desc();
auto diff_dst_md = diff_dst_pd.desc(); auto diff_dst_md = diff_dst_pd.desc();
// create primitive descriptor for batch norm backward // create primitive descriptor for batch norm backward
unsigned flags = mkldnn::use_scale_shift;
auto batch_norm_bwd_desc = bn_bwd_types::op_desc{ auto batch_norm_bwd_desc = bn_bwd_types::op_desc{
mkldnn::prop_kind::backward, diff_dst_md, mkldnn::prop_kind::backward, diff_dst_md,
src_memory.get_primitive_desc().desc(), epsilon, flags}; src_memory->get_primitive_desc().desc(), epsilon, flags};
auto batch_norm_bwd_pd = bn_bwd_types::op_prim{ auto batch_norm_bwd_pd = bn_bwd_types::op_prim{
batch_norm_bwd_desc, mkldnn_engine, *batch_norm_fwd_pd}; batch_norm_bwd_desc, mkldnn_engine, *batch_norm_fwd_pd};
// reorder user_diff_dst if it's not in preferred format // reorder user_diff_dst if it's not in preferred format
auto diff_dst_memory = user_diff_dst_memory; auto diff_dst_memory = std::make_shared<memory>(user_diff_dst_memory);
primitive reorder_diff_dst;
bool is_diff_dst_reordered = false;
if (diff_dst_pd != user_diff_dst_memory.get_primitive_desc()) { if (diff_dst_pd != user_diff_dst_memory.get_primitive_desc()) {
diff_dst_memory = memory(diff_dst_pd); diff_dst_memory = std::make_shared<memory>(diff_dst_pd);
reorder_diff_dst = reorder(user_diff_dst_memory, diff_dst_memory); reorder_diff_dst = reorder(user_diff_dst_memory, *diff_dst_memory);
is_diff_dst_reordered = true; is_diff_dst_reordered = true;
} }
// create mkldnn memory for input tensors (src/mean/variance) // create mkldnn memory for input tensors (src/mean/variance)
auto mean_memory = memory(batch_norm_bwd_pd.mean_primitive_desc(), auto mean_memory =
std::make_shared<memory>(batch_norm_bwd_pd.mean_primitive_desc(),
to_void_cast(batch_mean_data)); to_void_cast(batch_mean_data));
auto variance_memory = memory(batch_norm_bwd_pd.variance_primitive_desc(), auto variance_memory =
std::make_shared<memory>(batch_norm_bwd_pd.variance_primitive_desc(),
to_void_cast(batch_variance_data)); to_void_cast(batch_variance_data));
// MKLDNN requires a single piece of memory for scale and shift/bias data
const size_t scaleshift_size = 2 * ic;
std::vector<T> scaleshift_data;
scaleshift_data.reserve(scaleshift_size);
copy_to_weights(scale_data, scale_data + ic, shift_data, shift_data + ic,
&scaleshift_data);
// create mkldnn memory for input tensors (scale/shift) // create mkldnn memory for input tensors (scale/shift)
auto scaleshift_memory = memory(batch_norm_bwd_pd.weights_primitive_desc(), auto scaleshift_memory = std::make_shared<memory>(
scaleshift_data.data()); batch_norm_bwd_pd.weights_primitive_desc(), scaleshift_data.data());
// create mkldnn memory for output diff weights (combined scale/shift) // create mkldnn memory for output diff weights (combined scale/shift)
std::vector<T> diff_scaleshift_data; auto diff_scaleshift_memory = std::make_shared<memory>(
diff_scaleshift_data.reserve(scaleshift_size); batch_norm_bwd_pd.diff_weights_primitive_desc(),
auto diff_scaleshift_memory =
memory(batch_norm_bwd_pd.diff_weights_primitive_desc(),
diff_scaleshift_data.data()); diff_scaleshift_data.data());
// here assume diff_src is in the same format of src // here assume diff_src is in the same format of src
auto diff_src_memory = memory(src_memory.get_primitive_desc(), diff_x_data); auto diff_src_memory = std::make_shared<memory>(
src_memory->get_primitive_desc(), diff_x_data);
// finally create batch_norm backward primitive // finally create batch_norm backward primitive
auto batch_norm_bwd_prim = batch_norm_bwd_p = std::make_shared<batch_norm_bwd>(
batch_norm_bwd(batch_norm_bwd_pd, src_memory, mean_memory, batch_norm_bwd_pd, *src_memory, *mean_memory, *variance_memory,
variance_memory, diff_dst_memory, scaleshift_memory, *diff_dst_memory, *scaleshift_memory, *diff_src_memory,
diff_src_memory, diff_scaleshift_memory); *diff_scaleshift_memory);
dev_ctx.SetBlob(key_batch_norm_bwd_p, batch_norm_bwd_p);
dev_ctx.SetBlob(key_batch_norm_src_mem_p, src_memory);
dev_ctx.SetBlob(key_batch_norm_mean_mem_p, mean_memory);
dev_ctx.SetBlob(key_batch_norm_variance_mem_p, variance_memory);
dev_ctx.SetBlob(key_batch_norm_scaleshift_mem_p, scaleshift_memory);
dev_ctx.SetBlob(key_batch_norm_diff_scaleshift_mem_p,
diff_scaleshift_memory);
dev_ctx.SetBlob(key_batch_norm_diff_src_mem_p, diff_src_memory);
dev_ctx.SetBlob(key_batch_norm_diff_dst_mem_p, diff_dst_memory);
// set layout/format of output tensors
diff_x->set_layout(DataLayout::kMKLDNN);
diff_x->set_format((memory::format)diff_src_memory->get_primitive_desc()
.desc()
.data.format);
} else {
// primitives already exist
UpdateMemoryData(dev_ctx, key_batch_norm_src_mem_p, to_void_cast(x_data));
UpdateMemoryData(dev_ctx, key_batch_norm_mean_mem_p,
to_void_cast(batch_mean_data));
UpdateMemoryData(dev_ctx, key_batch_norm_variance_mem_p,
to_void_cast(batch_variance_data));
UpdateMemoryData(dev_ctx, key_batch_norm_scaleshift_mem_p,
scaleshift_data.data());
UpdateMemoryData(dev_ctx, key_batch_norm_diff_scaleshift_mem_p,
diff_scaleshift_data.data());
auto diff_src_memory = UpdateMemoryData(
dev_ctx, key_batch_norm_diff_src_mem_p, to_void_cast(diff_x_data));
auto diff_dst_memory = UpdateMemoryData(
dev_ctx, key_batch_norm_diff_dst_mem_p, to_void_cast(diff_y_data));
// reorder user_diff_dst if it's not in preferred format
if (diff_dst_memory->get_primitive_desc() !=
user_diff_dst_memory.get_primitive_desc()) {
reorder_diff_dst = reorder(user_diff_dst_memory, *diff_dst_memory);
is_diff_dst_reordered = true;
}
// set layout/format of output tensors
diff_x->set_layout(DataLayout::kMKLDNN);
diff_x->set_format((memory::format)diff_src_memory->get_primitive_desc()
.desc()
.data.format);
}
// execute optional reorder and batch_norm backward primitive // execute optional reorder and batch_norm backward primitive
std::vector<primitive> pipeline; std::vector<primitive> pipeline;
if (is_diff_dst_reordered) pipeline.push_back(reorder_diff_dst); if (is_diff_dst_reordered) pipeline.push_back(reorder_diff_dst);
pipeline.push_back(batch_norm_bwd_prim); pipeline.push_back(*batch_norm_bwd_p);
stream(stream::kind::eager).submit(pipeline).wait(); stream(stream::kind::eager).submit(pipeline).wait();
// copy back diff sacle/shift to output tensors (diff scale/shift) // copy back diff sacle/shift to output tensors (diff scale/shift)
...@@ -338,12 +529,6 @@ class BatchNormMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> { ...@@ -338,12 +529,6 @@ class BatchNormMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
std::copy(it, std::next(it, ic), diff_scale_data); std::copy(it, std::next(it, ic), diff_scale_data);
std::copy(std::next(it, ic), std::end(diff_scaleshift_data), std::copy(std::next(it, ic), std::end(diff_scaleshift_data),
diff_shift_data); diff_shift_data);
// set layout/format of output tensors
diff_x->set_layout(DataLayout::kMKLDNN);
diff_x->set_format((memory::format)diff_src_memory.get_primitive_desc()
.desc()
.data.format);
} }
}; };
} // namespace operators } // namespace operators
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册