Commit 7d564356 authored by mozga-intel

MKLDNN layout: Support for batch norm operator

Parent b7c683b8
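This change lets the batch norm kernels work directly on MKLDNN-formatted tensors: the forward kernel builds its memory descriptors from the input tensor's own format instead of hard-coding NCHW, caches the forward primitive descriptor in the device context (key: the "SavedMean" output name plus "@bn_fwd_pd") so the backward kernel can reuse it, and the backward kernel reorders diff_y when its format differs from the forward destination format. MKLDNN itself only produces per-batch statistics, so the running mean/variance are still blended on the CPU with Eigen. The snippet below is a minimal sketch of that momentum update, assuming plain float arrays of length ic; the helper name and parameters are illustrative, not symbols from this commit.

#include <Eigen/Core>

// Sketch: blend per-batch statistics (as computed by MKLDNN) into the
// running statistics, mirroring the Eigen update in the forward kernel.
void update_running_stats(const float *batch_mean, const float *batch_variance,
                          float *running_mean, float *running_variance,
                          int ic, float momentum) {
  Eigen::Map<const Eigen::ArrayXf> bm(batch_mean, ic);
  Eigen::Map<const Eigen::ArrayXf> bv(batch_variance, ic);
  Eigen::Map<Eigen::ArrayXf> rm(running_mean, ic);
  Eigen::Map<Eigen::ArrayXf> rv(running_variance, ic);
  const float one_minus_momentum = 1.0f - momentum;
  // running = momentum * previous_running + (1 - momentum) * batch
  rm = rm * momentum + bm * one_minus_momentum;
  rv = rv * momentum + bv * one_minus_momentum;
}

In the kernel this update only runs when is_test is false; it reads the previous running statistics from the Mean/Variance inputs and writes the blended values to MeanOut/VarianceOut, whereas the sketch updates in place purely to stay short.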
@@ -19,10 +19,17 @@ limitations under the License. */
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
using batch_norm_bwd = mkldnn::batch_normalization_backward;
using batch_norm_fwd = mkldnn::batch_normalization_forward;
using framework::DataLayout;
using framework::Tensor;
using mkldnn::memory;
using mkldnn::primitive;
using mkldnn::reorder;
using mkldnn::stream;
using paddle::platform::MKLDNNDeviceContext;
using paddle::platform::MKLDNNMemDesc;
using mkldnn::memory;
using platform::to_void_cast;
template <typename T>
using EigenArrayMap =
@@ -64,21 +71,12 @@ void run_batch_norm_op(Args &&... args) {
mkldnn::stream(mkldnn::stream::kind::eager).submit(pipeline).wait();
}
template <typename T>
inline void *cast_const_to_void(const T *t) {
return static_cast<void *>(const_cast<T *>(t));
}
} // namespace
template <typename T>
class BatchNormMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext &ctx) const override {
auto data_layout_str = ctx.Attr<std::string>("data_layout");
auto data_layout = framework::StringToDataLayout(data_layout_str);
PADDLE_ENFORCE(data_layout == framework::DataLayout::kNCHW,
"MKLDNN batch normalization handles only NCHW data layout");
const float epsilon = ctx.Attr<float>("epsilon");
const float momentum = ctx.Attr<float>("momentum");
const bool is_test = ctx.Attr<bool>("is_test");
@@ -99,41 +97,53 @@ class BatchNormMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
const auto *scale = ctx.Input<Tensor>("Scale");
const auto *shift = ctx.Input<Tensor>("Bias");
y->mutable_data<T>(ctx.GetPlace());
mean_out->mutable_data<T>(ctx.GetPlace());
variance_out->mutable_data<T>(ctx.GetPlace());
PADDLE_ENFORCE(x->layout() == DataLayout::kMKLDNN &&
x->format() != memory::format::format_undef,
"Wrong layout/format set for Input x tensor");
const T *x_data = x->data<T>();
const T *mean_data = mean->data<T>();
const T *variance_data = variance->data<T>();
T *y_data = y->mutable_data<T>(ctx.GetPlace());
T *mean_out_data = mean_out->mutable_data<T>(ctx.GetPlace());
T *variance_out_data = variance_out->mutable_data<T>(ctx.GetPlace());
T *batch_mean_data = nullptr;
T *batch_variance_data = nullptr;
if (!is_test) {
batch_mean->mutable_data<T>(ctx.GetPlace());
batch_variance->mutable_data<T>(ctx.GetPlace());
batch_mean_data = batch_mean->mutable_data<T>(ctx.GetPlace());
batch_variance_data = batch_variance->mutable_data<T>(ctx.GetPlace());
}
auto propagation = is_test == true ? mkldnn::prop_kind::forward_scoring
: mkldnn::prop_kind::forward_training;
auto dims = paddle::framework::vectorize2int(x->dims());
auto src_md =
MKLDNNMemDesc(dims, memory::data_type::f32, memory::format::nchw);
auto dst_md =
MKLDNNMemDesc(dims, memory::data_type::f32, memory::format::nchw);
auto src_pd = mkldnn::memory::primitive_desc{src_md, mkldnn_engine};
auto dst_pd = mkldnn::memory::primitive_desc{dst_md, mkldnn_engine};
auto src = mkldnn::memory{src_pd, cast_const_to_void(x->data<T>())};
auto dst = mkldnn::memory{dst_pd, y->data<T>()};
auto src_tz = paddle::framework::vectorize2int(x->dims());
auto scale_tz = paddle::framework::vectorize2int(scale->dims());
PADDLE_ENFORCE(scale_tz.size() == 1, "Dims of scale tensor should be 1");
const unsigned int ic = scale_tz[0];
unsigned flags = mkldnn::use_scale_shift;
if (is_test) flags |= mkldnn::use_global_stats;
// create mkldnn memory from input x tensor
auto src_memory =
memory({{{src_tz}, memory::data_type::f32, x->format()}, mkldnn_engine},
to_void_cast(x_data));
// create primitive descriptor for batch norm forward
using bn_fwd_types = bn_type_traits<mkldnn::batch_normalization_forward>;
auto batch_norm_fwd_desc =
bn_fwd_types::op_desc{propagation, src_md, epsilon, flags};
auto batch_norm_fwd_pd =
bn_fwd_types::op_prim{batch_norm_fwd_desc, mkldnn_engine};
auto batch_norm_fwd_desc = bn_fwd_types::op_desc{
propagation, src_memory.get_primitive_desc().desc(), epsilon, flags};
std::shared_ptr<batch_norm_fwd::primitive_desc> batch_norm_fwd_pd =
std::shared_ptr<batch_norm_fwd::primitive_desc>(
new batch_norm_fwd::primitive_desc(batch_norm_fwd_desc,
mkldnn_engine));
const unsigned int ic = dims[1];
// Save the pd to be used in backward pass
const std::string key = ctx.op().Output("SavedMean");
const std::string key_batch_norm_fwd_pd = key + "@bn_fwd_pd";
dev_ctx.SetBlob(key_batch_norm_fwd_pd, batch_norm_fwd_pd);
// MKLDNN requires a single piece of memory for scale and shift/bias data
const size_t scaleshift_size = 2 * ic;
@@ -143,73 +153,58 @@ class BatchNormMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
copy_to_weights(scale->data<T>(), scale->data<T>() + ic, shift->data<T>(),
shift->data<T>() + ic, &scaleshift_data);
auto scaleshift_memory = mkldnn::memory{
batch_norm_fwd_pd.weights_primitive_desc(), scaleshift_data.data()};
// create mkldnn memory for weights (scale/shift)
auto scaleshift_memory = memory(batch_norm_fwd_pd->weights_primitive_desc(),
scaleshift_data.data());
if (is_test) {
auto mean_memory = mkldnn::memory{batch_norm_fwd_pd.mean_primitive_desc(),
cast_const_to_void(mean->data<T>())};
// create mkldnn memory for output y tensor
auto dst_memory = memory(batch_norm_fwd_pd->dst_primitive_desc(), y_data);
if (is_test) {
// create mkldnn memory for stats (as input)
auto mean_memory = memory(batch_norm_fwd_pd->mean_primitive_desc(),
to_void_cast(mean_data));
auto variance_memory =
mkldnn::memory{batch_norm_fwd_pd.variance_primitive_desc(),
cast_const_to_void(variance->data<T>())};
memory(batch_norm_fwd_pd->variance_primitive_desc(),
to_void_cast(variance_data));
run_batch_norm_op<typename bn_fwd_types::op_type>(
batch_norm_fwd_pd, src, (const mkldnn::primitive::at &)mean_memory,
*batch_norm_fwd_pd, src_memory,
(const mkldnn::primitive::at &)mean_memory,
(const mkldnn::primitive::at &)variance_memory, scaleshift_memory,
dst);
dst_memory);
} else {
// create mkldnn memory for stats (as output)
auto mean_memory =
mkldnn::memory{batch_norm_fwd_pd.mean_primitive_desc(),
cast_const_to_void(batch_mean->data<T>())};
auto variance_memory =
mkldnn::memory{batch_norm_fwd_pd.variance_primitive_desc(),
cast_const_to_void(batch_variance->data<T>())};
memory(batch_norm_fwd_pd->mean_primitive_desc(), batch_mean_data);
auto variance_memory = memory(
batch_norm_fwd_pd->variance_primitive_desc(), batch_variance_data);
run_batch_norm_op<bn_fwd_types::op_type>(batch_norm_fwd_pd, src,
scaleshift_memory, dst,
run_batch_norm_op<bn_fwd_types::op_type>(*batch_norm_fwd_pd, src_memory,
scaleshift_memory, dst_memory,
mean_memory, variance_memory);
}
if (!is_test) {
const unsigned int in = dims[0];
const unsigned int sample_size = x->numel() / in / ic;
// saved_xx is use just in this batch of data
EigenVectorArrayMap<T> saved_mean_e(
batch_mean->mutable_data<T>(ctx.GetPlace()), ic);
EigenVectorArrayMap<T> saved_variance_e(
batch_variance->mutable_data<T>(ctx.GetPlace()), ic);
saved_mean_e.setZero();
saved_variance_e.setZero();
const unsigned int x_arr_size = in * ic;
ConstEigenArrayMap<T> x_arr(x->data<T>(), sample_size, x_arr_size);
for (unsigned int nc = 0; nc < x_arr_size; ++nc) {
saved_mean_e(nc % ic) += x_arr.col(nc).sum();
}
saved_mean_e /= in * sample_size;
for (unsigned int nc = 0; nc < x_arr_size; ++nc) {
saved_variance_e(nc % ic) +=
(x_arr.col(nc) - saved_mean_e(nc % ic)).matrix().squaredNorm();
}
saved_variance_e /= in * sample_size;
ConstEigenVectorArrayMap<T> mean_arr{mean->data<T>(), ic};
ConstEigenVectorArrayMap<T> variance_arr{variance->data<T>(), ic};
// mkldnn only computes stats for the current batch,
// so we need to compute the momentum-averaged stats via the Eigen library
EigenVectorArrayMap<T> batch_mean_e(batch_mean_data, ic);
EigenVectorArrayMap<T> batch_variance_e(batch_variance_data, ic);
ConstEigenVectorArrayMap<T> mean_e(mean_data, ic);
ConstEigenVectorArrayMap<T> variance_e{variance_data, ic};
EigenVectorArrayMap<T> running_mean_arr(
mean_out->mutable_data<T>(ctx.GetPlace()), ic);
EigenVectorArrayMap<T> running_var_arr(
variance_out->mutable_data<T>(ctx.GetPlace()), ic);
EigenVectorArrayMap<T> running_mean_e(mean_out_data, ic);
EigenVectorArrayMap<T> running_variance_e(variance_out_data, ic);
auto one_minus_momentum = 1. - momentum;
running_mean_arr =
mean_arr * momentum + saved_mean_e * one_minus_momentum;
running_var_arr =
variance_arr * momentum + saved_variance_e * one_minus_momentum;
running_mean_e = mean_e * momentum + batch_mean_e * one_minus_momentum;
running_variance_e =
variance_e * momentum + batch_variance_e * one_minus_momentum;
}
y->set_layout(DataLayout::kMKLDNN);
y->set_format(
(memory::format)dst_memory.get_primitive_desc().desc().data.format);
}
};
@@ -217,11 +212,6 @@ template <typename T>
class BatchNormMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
public:
void Compute(const paddle::framework::ExecutionContext &ctx) const override {
auto data_layout_str = ctx.Attr<std::string>("data_layout");
auto data_layout = framework::StringToDataLayout(data_layout_str);
PADDLE_ENFORCE(data_layout == framework::DataLayout::kNCHW,
"MKLDNN batch normalization handles only NCHW data layout");
auto &dev_ctx = ctx.template device_context<MKLDNNDeviceContext>();
auto mkldnn_engine = dev_ctx.GetEngine();
@@ -238,88 +228,132 @@ class BatchNormMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
auto *diff_scale = ctx.Output<Tensor>(framework::GradVarName("Scale"));
auto *diff_shift = ctx.Output<Tensor>(framework::GradVarName("Bias"));
diff_x->mutable_data<T>(ctx.GetPlace());
diff_scale->mutable_data<T>(ctx.GetPlace());
diff_shift->mutable_data<T>(ctx.GetPlace());
PADDLE_ENFORCE(diff_y->layout() == DataLayout::kMKLDNN &&
diff_y->format() != memory::format::format_undef,
"Wrong layout/format set for Input diff_y tensor");
const T *x_data = x->data<T>();
const T *diff_y_data = diff_y->data<T>();
const T *batch_mean_data = batch_mean->data<T>();
const T *batch_variance_data = batch_variance->data<T>();
const T *scale_data = scale->data<T>();
const T *shift_data = shift->data<T>();
T *diff_x_data = diff_x->mutable_data<T>(ctx.GetPlace());
T *diff_scale_data = diff_scale->mutable_data<T>(ctx.GetPlace());
T *diff_shift_data = diff_shift->mutable_data<T>(ctx.GetPlace());
auto src_tz = paddle::framework::vectorize2int(x->dims());
auto diff_src_tz = src_tz;
auto dst_tz = src_tz;
auto diff_dst_tz = dst_tz;
auto scale_tz = paddle::framework::vectorize2int(scale->dims());
PADDLE_ENFORCE(scale_tz.size() == 1, "Dims of scale tensor should be 1");
const unsigned int ic = scale_tz[0];
// Retrieve bn_fwd_pd from device context
const std::string key = ctx.op().Input("SavedMean");
const std::string key_batch_norm_fwd_pd = key + "@bn_fwd_pd";
auto batch_norm_fwd_pd =
std::static_pointer_cast<batch_norm_fwd::primitive_desc>(
dev_ctx.GetBlob(key_batch_norm_fwd_pd));
PADDLE_ENFORCE(batch_norm_fwd_pd != nullptr,
"Fail to find batch_norm_fwd_pd in device context");
auto dims = paddle::framework::vectorize2int(x->dims());
unsigned flags = mkldnn::use_scale_shift | !mkldnn::use_global_stats;
using bn_bwd_types = bn_type_traits<mkldnn::batch_normalization_backward>;
auto src_md =
MKLDNNMemDesc(dims, memory::data_type::f32, memory::format::nchw);
auto dst_md =
MKLDNNMemDesc(dims, memory::data_type::f32, memory::format::nchw);
auto diff_src_md =
MKLDNNMemDesc(dims, memory::data_type::f32, memory::format::nchw);
auto diff_dst_md =
MKLDNNMemDesc(dims, memory::data_type::f32, memory::format::nchw);
// create mkldnn memory from input diff_y tensor
auto user_diff_dst_memory =
memory({{{diff_dst_tz}, memory::data_type::f32, diff_y->format()},
mkldnn_engine},
to_void_cast(diff_y_data));
using bn_bwd_types = bn_type_traits<mkldnn::batch_normalization_backward>;
using bn_fwd_types = bn_type_traits<mkldnn::batch_normalization_forward>;
// create mkldnn memory from input x tensor
auto src_memory =
memory({{{src_tz}, memory::data_type::f32, x->format()}, mkldnn_engine},
to_void_cast(x_data));
auto batch_norm_fwd_desc = bn_fwd_types::op_desc{
mkldnn::prop_kind::forward_training, src_md, epsilon, flags};
auto batch_norm_fwd_pd =
bn_fwd_types::op_prim{batch_norm_fwd_desc, mkldnn_engine};
// for diff_dst, try to use same format as dst in forward pass
auto diff_dst_pd = batch_norm_fwd_pd.get()->dst_primitive_desc();
auto diff_dst_md = diff_dst_pd.desc();
// create primitive descriptor for batch norm backward
unsigned flags = mkldnn::use_scale_shift;
auto batch_norm_bwd_desc = bn_bwd_types::op_desc{
mkldnn::prop_kind::backward, diff_dst_md, dst_md, epsilon, flags};
mkldnn::prop_kind::backward, diff_dst_md,
src_memory.get_primitive_desc().desc(), epsilon, flags};
auto batch_norm_bwd_pd = bn_bwd_types::op_prim{
batch_norm_bwd_desc, mkldnn_engine, batch_norm_fwd_pd};
auto src = mkldnn::memory{{src_md, mkldnn_engine},
cast_const_to_void(x->data<T>())};
auto mean = mkldnn::memory{batch_norm_bwd_pd.mean_primitive_desc(),
cast_const_to_void(batch_mean->data<T>())};
auto variance =
mkldnn::memory{batch_norm_bwd_pd.variance_primitive_desc(),
cast_const_to_void(batch_variance->data<T>())};
auto diff_dst = mkldnn::memory{{diff_dst_md, mkldnn_engine},
cast_const_to_void(diff_y->data<T>())};
batch_norm_bwd_desc, mkldnn_engine, *batch_norm_fwd_pd};
// reorder user_diff_dst if it's not in preferred format
auto diff_dst_memory = user_diff_dst_memory;
primitive reorder_diff_dst;
bool is_diff_dst_reordered = false;
if (diff_dst_pd != user_diff_dst_memory.get_primitive_desc()) {
diff_dst_memory = memory(diff_dst_pd);
reorder_diff_dst = reorder(user_diff_dst_memory, diff_dst_memory);
is_diff_dst_reordered = true;
}
const unsigned int ic = dims[1];
// create mkldnn memory for input tensors (src/mean/variance)
auto mean_memory = memory(batch_norm_bwd_pd.mean_primitive_desc(),
to_void_cast(batch_mean_data));
auto variance_memory = memory(batch_norm_bwd_pd.variance_primitive_desc(),
to_void_cast(batch_variance_data));
// MKLDNN requires a single piece of memory for scale and shift/bias data
const size_t scaleshift_size = 2 * ic;
std::vector<T> scaleshift_data;
scaleshift_data.reserve(scaleshift_size);
copy_to_weights(scale->data<T>(), scale->data<T>() + ic, shift->data<T>(),
shift->data<T>() + ic, &scaleshift_data);
copy_to_weights(scale_data, scale_data + ic, shift_data, shift_data + ic,
&scaleshift_data);
auto scaleshift_memory = mkldnn::memory{
batch_norm_bwd_pd.weights_primitive_desc(), scaleshift_data.data()};
// create mkldnn memory for input tensors (scale/shift)
auto scaleshift_memory = memory(batch_norm_bwd_pd.weights_primitive_desc(),
scaleshift_data.data());
// create mkldnn memory for output diff weights (combined scale/shift)
std::vector<T> diff_scaleshift_data;
diff_scaleshift_data.reserve(scaleshift_size);
copy_to_weights(diff_scale->data<T>(), diff_scale->data<T>() + ic,
diff_shift->data<T>(), diff_shift->data<T>() + ic,
&diff_scaleshift_data);
auto diff_scaleshift_memory =
mkldnn::memory{batch_norm_bwd_pd.diff_weights_primitive_desc(),
diff_scaleshift_data.data()};
auto diff_src = mkldnn::memory{{diff_src_md, mkldnn_engine},
static_cast<void *>(diff_x->data<T>())};
run_batch_norm_op<bn_bwd_types::op_type>(
batch_norm_bwd_pd, src, mean, variance, diff_dst, scaleshift_memory,
diff_src, diff_scaleshift_memory);
memory(batch_norm_bwd_pd.diff_weights_primitive_desc(),
diff_scaleshift_data.data());
// here we assume diff_src is in the same format as src
auto diff_src_memory = memory(src_memory.get_primitive_desc(), diff_x_data);
// finally create batch_norm backward primitive
auto batch_norm_bwd_prim =
batch_norm_bwd(batch_norm_bwd_pd, src_memory, mean_memory,
variance_memory, diff_dst_memory, scaleshift_memory,
diff_src_memory, diff_scaleshift_memory);
// execute optional reorder and batch_norm backward primitive
std::vector<primitive> pipeline;
if (is_diff_dst_reordered) pipeline.push_back(reorder_diff_dst);
pipeline.push_back(batch_norm_bwd_prim);
stream(stream::kind::eager).submit(pipeline).wait();
// copy back diff scale/shift to the output tensors (diff scale/shift)
diff_scaleshift_data.resize(scaleshift_size);
auto it = std::begin(diff_scaleshift_data);
std::copy(it, std::next(it, ic), diff_scale->data<T>());
std::copy(it, std::next(it, ic), diff_scale_data);
std::copy(std::next(it, ic), std::end(diff_scaleshift_data),
diff_shift->data<T>());
diff_shift_data);
// set layout/format of output tensors
diff_x->set_layout(DataLayout::kMKLDNN);
diff_x->set_format((memory::format)diff_src_memory.get_primitive_desc()
.desc()
.data.format);
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP_KERNEL(batch_norm, MKLDNN, paddle::platform::CPUPlace,
REGISTER_OP_KERNEL(batch_norm, MKLDNN, ::paddle::platform::CPUPlace,
ops::BatchNormMKLDNNOpKernel<float>);
REGISTER_OP_KERNEL(batch_norm_grad, MKLDNN, paddle::platform::CPUPlace,
REGISTER_OP_KERNEL(batch_norm_grad, MKLDNN, ::paddle::platform::CPUPlace,
ops::BatchNormMKLDNNGradOpKernel<float>);
@@ -110,19 +110,19 @@ class BatchNormOp : public framework::OperatorWithKernel {
ctx.Input<Tensor>("Variance")->type()),
"Variance input should be of float type");
framework::LibraryType library_{framework::LibraryType::kPlain};
// TODO(pzelazko-intel): enable MKLDNN layout when it's ready
framework::LibraryType library = framework::LibraryType::kPlain;
framework::DataLayout layout = framework::DataLayout::kAnyLayout;
#ifdef PADDLE_WITH_MKLDNN
if (library_ == framework::LibraryType::kPlain &&
if (library == framework::LibraryType::kPlain &&
platform::CanMKLDNNBeUsed(ctx)) {
library_ = framework::LibraryType::kMKLDNN;
library = framework::LibraryType::kMKLDNN;
layout = framework::DataLayout::kMKLDNN;
}
#endif
return framework::OpKernelType(input_data_type, ctx.GetPlace(), layout,
library_);
library);
}
};
@@ -368,19 +368,21 @@ class BatchNormGradOp : public framework::OperatorWithKernel {
PADDLE_THROW("can't find Y@GRAD");
}
framework::LibraryType library_{framework::LibraryType::kPlain};
// TODO(pzelazko-intel): enable MKLDNN layout when it's ready
framework::DataLayout layout_ = framework::DataLayout::kAnyLayout;
framework::LibraryType library = framework::LibraryType::kPlain;
framework::DataLayout layout = framework::DataLayout::kAnyLayout;
#ifdef PADDLE_WITH_MKLDNN
if (library_ == framework::LibraryType::kPlain &&
if (library == framework::LibraryType::kPlain &&
platform::CanMKLDNNBeUsed(ctx)) {
library_ = framework::LibraryType::kMKLDNN;
layout_ = framework::DataLayout::kMKLDNN;
library = framework::LibraryType::kMKLDNN;
layout = framework::DataLayout::kMKLDNN;
}
#endif
return framework::OpKernelType(
framework::ToDataType(ctx.Input<Tensor>("X")->type()), ctx.GetPlace(),
layout_, library_);
layout, library);
}
};
......