提交 50d3e6e9 编写于 作者: K Krzysztof Binias

Reusing primitives for forward Batch Norm operator

上级 ef7bd03a
......@@ -37,6 +37,122 @@ struct bn_type_traits {
using op_prim = typename op_type::primitive_desc;
};
class BatchNormMKLDNNHandler : public platform::MKLDNNHandler {
public:
BatchNormMKLDNNHandler(
std::shared_ptr<batch_norm_fwd::primitive_desc> batch_norm_pd,
const platform::MKLDNNDeviceContext &dev_ctx, mkldnn::engine engine,
const std::string &base_key)
: platform::MKLDNNHandler(dev_ctx, engine, base_key) {
batch_norm_pd_ = batch_norm_pd;
}
std::shared_ptr<memory> AcquireScaleshiftMemoryFromPrimitive(void *ptr) {
return this->AcquireMemoryFromPrimitive(
batch_norm_pd_->weights_primitive_desc(), ptr, "@scaleshift_mem_p");
}
std::shared_ptr<memory> AcquireMeanMemoryFromPrimitive(void *ptr) {
return this->AcquireMemoryFromPrimitive(
batch_norm_pd_->mean_primitive_desc(), ptr, "@mean_mem_p");
}
std::shared_ptr<memory> AcquireVarianceMemoryFromPrimitive(void *ptr) {
return this->AcquireMemoryFromPrimitive(
batch_norm_pd_->variance_primitive_desc(), ptr, "@variance_mem_p");
}
std::shared_ptr<batch_norm_fwd> AcquireTestBatchNormFwd(
std::shared_ptr<memory> src_memory,
const mkldnn::primitive::at &mean_memory,
const mkldnn::primitive::at &variance_memory,
std::shared_ptr<memory> scaleshift_memory,
std::shared_ptr<memory> dst_memory) {
auto prim_key = key_ + "@batch_norm_p";
auto batch_norm_p =
std::static_pointer_cast<batch_norm_fwd>(dev_ctx_.GetBlob(prim_key));
PADDLE_ENFORCE(
(batch_norm_p != nullptr) || (is_reusing_ == false),
"Fail to find batch norm primitive for test in device context");
if (batch_norm_p == nullptr) {
batch_norm_p = std::make_shared<batch_norm_fwd>(
*batch_norm_pd_, *src_memory, mean_memory, variance_memory,
*scaleshift_memory, *dst_memory);
dev_ctx_.SetBlob(prim_key, batch_norm_p);
} else {
is_reusing_ = true;
}
return batch_norm_p;
}
std::shared_ptr<batch_norm_fwd> AcquireTrainingBatchNormFwd(
std::shared_ptr<memory> src_memory,
std::shared_ptr<memory> scaleshift_memory,
std::shared_ptr<memory> dst_memory, std::shared_ptr<memory> mean_memory,
std::shared_ptr<memory> variance_memory) {
auto prim_key = key_ + "@batch_norm_p";
auto batch_norm_p =
std::static_pointer_cast<batch_norm_fwd>(dev_ctx_.GetBlob(prim_key));
PADDLE_ENFORCE(
(batch_norm_p != nullptr) || (is_reusing_ == false),
"Fail to find batch norm primitive for training in device context");
if (batch_norm_p == nullptr) {
batch_norm_p = std::make_shared<batch_norm_fwd>(
*batch_norm_pd_, *src_memory, *scaleshift_memory, *dst_memory,
*mean_memory, *variance_memory);
dev_ctx_.SetBlob(prim_key, batch_norm_p);
} else {
is_reusing_ = true;
}
return batch_norm_p;
}
//
static std::string GetHash(const memory::dims &input_dims, float epsilon,
unsigned flag, bool is_test, memory::format format,
const std::string &suffix) {
auto dims2str = [](const memory::dims &operand_dims) {
std::string dstr = "";
for (size_t i = 0; i < operand_dims.size(); ++i) {
dstr += std::to_string(operand_dims[i]) + "-";
}
return dstr;
};
return dims2str(input_dims) + std::to_string(epsilon) +
std::to_string(flag) + std::to_string(is_test) +
std::to_string(format) + suffix;
}
private:
std::shared_ptr<batch_norm_fwd::primitive_desc> batch_norm_pd_;
};
std::string gethash(const memory::dims &input_dims, float epsilon,
unsigned flag, bool is_test, memory::format format) {
auto dims2str = [](const memory::dims &operand_dims) {
std::string dstr = "";
for (size_t i = 0; i < operand_dims.size(); ++i) {
dstr += std::to_string(operand_dims[i]) + "-";
}
return dstr;
};
return dims2str(input_dims) + std::to_string(epsilon) + std::to_string(flag) +
std::to_string(is_test) + std::to_string(format);
}
std::shared_ptr<memory> UpdateMemoryData(
const platform::MKLDNNDeviceContext &dev_ctx, const std::string &key,
void *new_ptr) {
auto mem = std::static_pointer_cast<memory>(dev_ctx.GetBlob(key));
PADDLE_ENFORCE(
mem != nullptr,
(std::string("Fail to find memory in device context [key: ") + key + "]")
.c_str());
mem->set_data_handle(new_ptr);
return mem;
}
template <typename T, typename Container>
void copy_to_weights(T scale_begin, T scale_end, T shift_begin, T shift_end,
Container *c) {
......@@ -48,15 +164,6 @@ void copy_to_weights(T scale_begin, T scale_end, T shift_begin, T shift_end,
std::inserter(*c, std::next(it, std::distance(scale_begin, scale_end))));
}
template <typename Op, typename... Args>
void run_batch_norm_op(Args &&... args) {
Op batch_norm_op{args...};
std::vector<mkldnn::primitive> pipeline;
pipeline.push_back(batch_norm_op);
mkldnn::stream(mkldnn::stream::kind::eager).submit(pipeline).wait();
}
} // namespace
template <typename T>
......@@ -110,6 +217,14 @@ class BatchNormMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
PADDLE_ENFORCE(scale_tz.size() == 1, "Dims of scale tensor is NOT 1");
const unsigned int ic = scale_tz[0];
// MKLDNN requires a single piece of memory for scale and shift/bias data
const size_t scaleshift_size = 2 * ic;
std::vector<T> scaleshift_data;
scaleshift_data.reserve(scaleshift_size);
copy_to_weights(scale->data<T>(), scale->data<T>() + ic, shift->data<T>(),
shift->data<T>() + ic, &scaleshift_data);
unsigned flags = mkldnn::use_scale_shift;
if (is_test) flags |= mkldnn::use_global_stats;
if (fuse_with_relu) flags |= mkldnn::fuse_bn_relu;
......@@ -118,64 +233,70 @@ class BatchNormMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
mkldnn::memory::format input_format =
platform::MKLDNNFormatForSize(src_tz.size(), x->format());
auto src_memory = memory(
{{{src_tz}, memory::data_type::f32, input_format}, mkldnn_engine},
to_void_cast(x_data));
// keys for backward pass
const std::string key = BatchNormMKLDNNHandler::GetHash(
src_tz, epsilon, flags, is_test, input_format,
ctx.op().Output("SavedMean"));
const std::string key_batch_norm_fwd_pd = key + "@bn_fwd_pd";
auto user_src_md = platform::MKLDNNMemDesc(
{src_tz}, platform::MKLDNNGetDataType<T>(), input_format);
// create primitive descriptor for batch norm forward
using bn_fwd_types = bn_type_traits<mkldnn::batch_normalization_forward>;
auto batch_norm_fwd_desc = bn_fwd_types::op_desc{
propagation, src_memory.get_primitive_desc().desc(), epsilon, flags};
std::shared_ptr<batch_norm_fwd::primitive_desc> batch_norm_fwd_pd =
std::shared_ptr<batch_norm_fwd::primitive_desc>(
new batch_norm_fwd::primitive_desc(batch_norm_fwd_desc,
mkldnn_engine));
// Save the pd to be used in backward pass
const std::string key = ctx.op().Output("SavedMean");
const std::string key_batch_norm_fwd_pd = key + "@bn_fwd_pd";
auto batch_norm_fwd_desc =
bn_fwd_types::op_desc{propagation, user_src_md, epsilon, flags};
auto batch_norm_fwd_pd = std::make_shared<batch_norm_fwd::primitive_desc>(
batch_norm_fwd_desc, mkldnn_engine);
// Save conv_pd/src_memory/weights_memory for backward pass
dev_ctx.SetBlob(key_batch_norm_fwd_pd, batch_norm_fwd_pd);
// MKLDNN requires a single piece of memory for scale and shift/bias data
const size_t scaleshift_size = 2 * ic;
std::vector<T> scaleshift_data;
scaleshift_data.reserve(scaleshift_size);
BatchNormMKLDNNHandler handler(batch_norm_fwd_pd, dev_ctx, mkldnn_engine,
key);
copy_to_weights(scale->data<T>(), scale->data<T>() + ic, shift->data<T>(),
shift->data<T>() + ic, &scaleshift_data);
auto src_memory =
handler.AcquireSrcMemory(user_src_md, to_void_cast(x_data));
// crate mkldnn memory for weights(scale/shift)
auto scaleshift_memory = memory(batch_norm_fwd_pd->weights_primitive_desc(),
scaleshift_data.data());
auto scaleshift_memory =
handler.AcquireScaleshiftMemoryFromPrimitive(scaleshift_data.data());
// create mkldnn memory for output y tensor
auto dst_memory = memory(batch_norm_fwd_pd->dst_primitive_desc(), y_data);
auto dst_memory = handler.AcquireDstMemory(
batch_norm_fwd_pd->dst_primitive_desc().desc(), y_data);
std::shared_ptr<batch_norm_fwd> batch_norm_p;
if (is_test) {
// create mkldnn memory for stats (as input)
auto mean_memory = memory(batch_norm_fwd_pd->mean_primitive_desc(),
to_void_cast(mean_data));
auto variance_memory =
memory(batch_norm_fwd_pd->variance_primitive_desc(),
to_void_cast(variance_data));
run_batch_norm_op<typename bn_fwd_types::op_type>(
*batch_norm_fwd_pd, src_memory,
(const mkldnn::primitive::at &)mean_memory,
(const mkldnn::primitive::at &)variance_memory, scaleshift_memory,
std::shared_ptr<memory> mean_memory =
handler.AcquireMeanMemoryFromPrimitive(to_void_cast(mean_data));
std::shared_ptr<memory> variance_memory =
handler.AcquireVarianceMemoryFromPrimitive(
to_void_cast(variance_data));
batch_norm_p = handler.AcquireTestBatchNormFwd(
src_memory, (const mkldnn::primitive::at &)*mean_memory,
(const mkldnn::primitive::at &)*variance_memory, scaleshift_memory,
dst_memory);
} else {
// create mkldnn memory for stats (as output)
auto mean_memory =
memory(batch_norm_fwd_pd->mean_primitive_desc(), batch_mean_data);
auto variance_memory = memory(
batch_norm_fwd_pd->variance_primitive_desc(), batch_variance_data);
run_batch_norm_op<bn_fwd_types::op_type>(*batch_norm_fwd_pd, src_memory,
scaleshift_memory, dst_memory,
mean_memory, variance_memory);
std::shared_ptr<memory> mean_memory =
handler.AcquireMeanMemoryFromPrimitive(batch_mean_data);
std::shared_ptr<memory> variance_memory =
handler.AcquireVarianceMemoryFromPrimitive(batch_variance_data);
batch_norm_p = handler.AcquireTrainingBatchNormFwd(
src_memory, scaleshift_memory, dst_memory, mean_memory,
variance_memory);
}
y->set_layout(DataLayout::kMKLDNN);
y->set_format(platform::GetMKLDNNFormat(*dst_memory));
std::vector<mkldnn::primitive> pipeline;
pipeline.push_back(*batch_norm_p);
mkldnn::stream(mkldnn::stream::kind::eager).submit(pipeline).wait();
if (!is_test) {
// mkldnn only compute stats for current batch
// so we need compute momentum stats via Eigen lib
......@@ -192,10 +313,6 @@ class BatchNormMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
running_variance_e =
variance_e * momentum + batch_variance_e * one_minus_momentum;
}
y->set_layout(DataLayout::kMKLDNN);
y->set_format(
(memory::format)dst_memory.get_primitive_desc().desc().data.format);
}
};
......@@ -242,61 +359,47 @@ class BatchNormMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
const unsigned int ic = scale_tz[0];
// Retrieve bn_fwd_pd from device context
const std::string key = ctx.op().Input("SavedMean");
const std::string key_batch_norm_fwd_pd = key + "@bn_fwd_pd";
auto batch_norm_fwd_pd =
std::static_pointer_cast<batch_norm_fwd::primitive_desc>(
dev_ctx.GetBlob(key_batch_norm_fwd_pd));
PADDLE_ENFORCE(batch_norm_fwd_pd != nullptr,
"Fail to find batch_norm_fwd_pd in device context");
using bn_bwd_types = bn_type_traits<mkldnn::batch_normalization_backward>;
// create mkldnn memory from input diff_y tensor
mkldnn::memory::format dst_format =
platform::MKLDNNFormatForSize(src_tz.size(), diff_y->format());
auto user_diff_dst_memory = memory(
{{{diff_dst_tz}, memory::data_type::f32, dst_format}, mkldnn_engine},
to_void_cast(diff_y_data));
// create mkldnn memory from input x tensor
mkldnn::memory::format input_format =
platform::MKLDNNFormatForSize(src_tz.size(), x->format());
auto src_memory = memory(
{{{src_tz}, memory::data_type::f32, input_format}, mkldnn_engine},
to_void_cast(x_data));
unsigned flags = mkldnn::use_scale_shift;
// keys from forward pass
const std::string key = BatchNormMKLDNNHandler::GetHash(
src_tz, epsilon, flags, false, input_format,
ctx.op().Input("SavedMean"));
const std::string key_batch_norm_fwd_pd = key + "@bn_fwd_pd";
// for diff_dst, try to use same format as dst in forward pass
auto diff_dst_pd = batch_norm_fwd_pd.get()->dst_primitive_desc();
auto diff_dst_md = diff_dst_pd.desc();
// keys for primitives reuse
const std::string key_with_hash =
key + gethash(src_tz, epsilon, flags, false, input_format);
const std::string key_batch_norm_bwd_p =
key_with_hash + "@batch_norm_bwd_p";
const std::string key_batch_norm_src_mem_p =
key_with_hash + "@batch_norm_bwd_src_mem_p";
const std::string key_batch_norm_mean_mem_p =
key_with_hash + "@batch_norm_bwd_mean_mem_p";
const std::string key_batch_norm_variance_mem_p =
key_with_hash + "@batch_norm_bwd_variance_mem_p";
const std::string key_batch_norm_scaleshift_mem_p =
key_with_hash + "@batch_norm_bwd_scaleshift_mem_p";
const std::string key_batch_norm_diff_scaleshift_mem_p =
key_with_hash + "@batch_norm_bwd_diff_scaleshift_mem_p";
const std::string key_batch_norm_diff_src_mem_p =
key_with_hash + "@batch_norm_bwd_diff_src_mem_p";
const std::string key_batch_norm_diff_dst_mem_p =
key_with_hash + "@batch_norm_bwd_diff_dst_mem_p";
// create primitive descriptor for batch norm backward
unsigned flags = mkldnn::use_scale_shift;
auto batch_norm_bwd_desc = bn_bwd_types::op_desc{
mkldnn::prop_kind::backward, diff_dst_md,
src_memory.get_primitive_desc().desc(), epsilon, flags};
auto batch_norm_bwd_pd = bn_bwd_types::op_prim{
batch_norm_bwd_desc, mkldnn_engine, *batch_norm_fwd_pd};
// reorder user_diff_dst if it's not in preferred format
auto diff_dst_memory = user_diff_dst_memory;
primitive reorder_diff_dst;
bool is_diff_dst_reordered = false;
if (diff_dst_pd != user_diff_dst_memory.get_primitive_desc()) {
diff_dst_memory = memory(diff_dst_pd);
reorder_diff_dst = reorder(user_diff_dst_memory, diff_dst_memory);
is_diff_dst_reordered = true;
}
// create mkldnn memory for input tensors (src/mean/variance)
auto mean_memory = memory(batch_norm_bwd_pd.mean_primitive_desc(),
to_void_cast(batch_mean_data));
auto variance_memory = memory(batch_norm_bwd_pd.variance_primitive_desc(),
to_void_cast(batch_variance_data));
auto user_diff_dst_memory = memory(
{{{diff_dst_tz}, memory::data_type::f32, dst_format}, mkldnn_engine},
to_void_cast(diff_y_data));
// MKLDNN requires a single piece of memory for scale and shift/bias data
const size_t scaleshift_size = 2 * ic;
......@@ -306,30 +409,118 @@ class BatchNormMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
copy_to_weights(scale_data, scale_data + ic, shift_data, shift_data + ic,
&scaleshift_data);
// create mkldnn memory for input tensors (scale/shift)
auto scaleshift_memory = memory(batch_norm_bwd_pd.weights_primitive_desc(),
scaleshift_data.data());
// create mkldnn memory for output diff weights (combined scale/shift)
std::vector<T> diff_scaleshift_data;
diff_scaleshift_data.reserve(scaleshift_size);
auto diff_scaleshift_memory =
memory(batch_norm_bwd_pd.diff_weights_primitive_desc(),
diff_scaleshift_data.data());
// here assume diff_src is in the same format of src
auto diff_src_memory = memory(src_memory.get_primitive_desc(), diff_x_data);
auto batch_norm_fwd_pd =
std::static_pointer_cast<batch_norm_fwd::primitive_desc>(
dev_ctx.GetBlob(key_batch_norm_fwd_pd));
PADDLE_ENFORCE(batch_norm_fwd_pd != nullptr,
"Fail to find batch_norm_fwd_pd in device context");
// finally create batch_norm backward primitive
auto batch_norm_bwd_prim =
batch_norm_bwd(batch_norm_bwd_pd, src_memory, mean_memory,
variance_memory, diff_dst_memory, scaleshift_memory,
diff_src_memory, diff_scaleshift_memory);
auto batch_norm_bwd_p = std::static_pointer_cast<batch_norm_bwd>(
dev_ctx.GetBlob(key_batch_norm_bwd_p));
if (batch_norm_bwd_p == nullptr) {
auto src_memory = std::shared_ptr<memory>(new memory(
{{{src_tz}, memory::data_type::f32, input_format}, mkldnn_engine},
to_void_cast(x_data)));
// for diff_dst, try to use same format as dst in forward pass
auto diff_dst_pd = batch_norm_fwd_pd.get()->dst_primitive_desc();
auto diff_dst_md = diff_dst_pd.desc();
// create primitive descriptor for batch norm backward
auto batch_norm_bwd_desc = bn_bwd_types::op_desc{
mkldnn::prop_kind::backward, diff_dst_md,
src_memory->get_primitive_desc().desc(), epsilon, flags};
auto batch_norm_bwd_pd = bn_bwd_types::op_prim{
batch_norm_bwd_desc, mkldnn_engine, *batch_norm_fwd_pd};
// reorder user_diff_dst if it's not in preferred format
auto diff_dst_memory = std::make_shared<memory>(user_diff_dst_memory);
if (diff_dst_pd != user_diff_dst_memory.get_primitive_desc()) {
diff_dst_memory = std::make_shared<memory>(diff_dst_pd);
reorder_diff_dst = reorder(user_diff_dst_memory, *diff_dst_memory);
is_diff_dst_reordered = true;
}
// create mkldnn memory for input tensors (src/mean/variance)
auto mean_memory =
std::make_shared<memory>(batch_norm_bwd_pd.mean_primitive_desc(),
to_void_cast(batch_mean_data));
auto variance_memory =
std::make_shared<memory>(batch_norm_bwd_pd.variance_primitive_desc(),
to_void_cast(batch_variance_data));
// create mkldnn memory for input tensors (scale/shift)
auto scaleshift_memory = std::make_shared<memory>(
batch_norm_bwd_pd.weights_primitive_desc(), scaleshift_data.data());
// create mkldnn memory for output diff weights (combined scale/shift)
auto diff_scaleshift_memory = std::make_shared<memory>(
batch_norm_bwd_pd.diff_weights_primitive_desc(),
diff_scaleshift_data.data());
// here assume diff_src is in the same format of src
auto diff_src_memory = std::make_shared<memory>(
src_memory->get_primitive_desc(), diff_x_data);
// finally create batch_norm backward primitive
batch_norm_bwd_p = std::make_shared<batch_norm_bwd>(
batch_norm_bwd_pd, *src_memory, *mean_memory, *variance_memory,
*diff_dst_memory, *scaleshift_memory, *diff_src_memory,
*diff_scaleshift_memory);
dev_ctx.SetBlob(key_batch_norm_bwd_p, batch_norm_bwd_p);
dev_ctx.SetBlob(key_batch_norm_src_mem_p, src_memory);
dev_ctx.SetBlob(key_batch_norm_mean_mem_p, mean_memory);
dev_ctx.SetBlob(key_batch_norm_variance_mem_p, variance_memory);
dev_ctx.SetBlob(key_batch_norm_scaleshift_mem_p, scaleshift_memory);
dev_ctx.SetBlob(key_batch_norm_diff_scaleshift_mem_p,
diff_scaleshift_memory);
dev_ctx.SetBlob(key_batch_norm_diff_src_mem_p, diff_src_memory);
dev_ctx.SetBlob(key_batch_norm_diff_dst_mem_p, diff_dst_memory);
// set layout/format of output tensors
diff_x->set_layout(DataLayout::kMKLDNN);
diff_x->set_format((memory::format)diff_src_memory->get_primitive_desc()
.desc()
.data.format);
} else {
// primitives already exist
UpdateMemoryData(dev_ctx, key_batch_norm_src_mem_p, to_void_cast(x_data));
UpdateMemoryData(dev_ctx, key_batch_norm_mean_mem_p,
to_void_cast(batch_mean_data));
UpdateMemoryData(dev_ctx, key_batch_norm_variance_mem_p,
to_void_cast(batch_variance_data));
UpdateMemoryData(dev_ctx, key_batch_norm_scaleshift_mem_p,
scaleshift_data.data());
UpdateMemoryData(dev_ctx, key_batch_norm_diff_scaleshift_mem_p,
diff_scaleshift_data.data());
auto diff_src_memory = UpdateMemoryData(
dev_ctx, key_batch_norm_diff_src_mem_p, to_void_cast(diff_x_data));
auto diff_dst_memory = UpdateMemoryData(
dev_ctx, key_batch_norm_diff_dst_mem_p, to_void_cast(diff_y_data));
// reorder user_diff_dst if it's not in preferred format
if (diff_dst_memory->get_primitive_desc() !=
user_diff_dst_memory.get_primitive_desc()) {
reorder_diff_dst = reorder(user_diff_dst_memory, *diff_dst_memory);
is_diff_dst_reordered = true;
}
// set layout/format of output tensors
diff_x->set_layout(DataLayout::kMKLDNN);
diff_x->set_format((memory::format)diff_src_memory->get_primitive_desc()
.desc()
.data.format);
}
// execute optional reorder and batch_norm backward primitive
std::vector<primitive> pipeline;
if (is_diff_dst_reordered) pipeline.push_back(reorder_diff_dst);
pipeline.push_back(batch_norm_bwd_prim);
pipeline.push_back(*batch_norm_bwd_p);
stream(stream::kind::eager).submit(pipeline).wait();
// copy back diff sacle/shift to output tensors (diff scale/shift)
......@@ -338,12 +529,6 @@ class BatchNormMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
std::copy(it, std::next(it, ic), diff_scale_data);
std::copy(std::next(it, ic), std::end(diff_scaleshift_data),
diff_shift_data);
// set layout/format of output tensors
diff_x->set_layout(DataLayout::kMKLDNN);
diff_x->set_format((memory::format)diff_src_memory.get_primitive_desc()
.desc()
.data.format);
}
};
} // namespace operators
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册