提交 7d564356 编写于 作者: M mozga-intel

MKLDNN layout: Support for batch norm operator

上级 b7c683b8
...@@ -19,10 +19,17 @@ limitations under the License. */ ...@@ -19,10 +19,17 @@ limitations under the License. */
namespace paddle { namespace paddle {
namespace operators { namespace operators {
using Tensor = framework::Tensor; using batch_norm_bwd = mkldnn::batch_normalization_backward;
using batch_norm_fwd = mkldnn::batch_normalization_forward;
using framework::DataLayout;
using framework::Tensor;
using mkldnn::memory;
using mkldnn::primitive;
using mkldnn::reorder;
using mkldnn::stream;
using paddle::platform::MKLDNNDeviceContext; using paddle::platform::MKLDNNDeviceContext;
using paddle::platform::MKLDNNMemDesc; using paddle::platform::MKLDNNMemDesc;
using mkldnn::memory; using platform::to_void_cast;
template <typename T> template <typename T>
using EigenArrayMap = using EigenArrayMap =
...@@ -64,21 +71,12 @@ void run_batch_norm_op(Args &&... args) { ...@@ -64,21 +71,12 @@ void run_batch_norm_op(Args &&... args) {
mkldnn::stream(mkldnn::stream::kind::eager).submit(pipeline).wait(); mkldnn::stream(mkldnn::stream::kind::eager).submit(pipeline).wait();
} }
template <typename T>
inline void *cast_const_to_void(const T *t) {
return static_cast<void *>(const_cast<T *>(t));
}
} // namespace } // namespace
template <typename T> template <typename T>
class BatchNormMKLDNNOpKernel : public paddle::framework::OpKernel<T> { class BatchNormMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext &ctx) const override { void Compute(const framework::ExecutionContext &ctx) const override {
auto data_layout_str = ctx.Attr<std::string>("data_layout");
auto data_layout = framework::StringToDataLayout(data_layout_str);
PADDLE_ENFORCE(data_layout == framework::DataLayout::kNCHW,
"MKLDNN batch normalization handles only NCHW data layout");
const float epsilon = ctx.Attr<float>("epsilon"); const float epsilon = ctx.Attr<float>("epsilon");
const float momentum = ctx.Attr<float>("momentum"); const float momentum = ctx.Attr<float>("momentum");
const bool is_test = ctx.Attr<bool>("is_test"); const bool is_test = ctx.Attr<bool>("is_test");
...@@ -99,41 +97,53 @@ class BatchNormMKLDNNOpKernel : public paddle::framework::OpKernel<T> { ...@@ -99,41 +97,53 @@ class BatchNormMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
const auto *scale = ctx.Input<Tensor>("Scale"); const auto *scale = ctx.Input<Tensor>("Scale");
const auto *shift = ctx.Input<Tensor>("Bias"); const auto *shift = ctx.Input<Tensor>("Bias");
y->mutable_data<T>(ctx.GetPlace()); PADDLE_ENFORCE(x->layout() == DataLayout::kMKLDNN &&
mean_out->mutable_data<T>(ctx.GetPlace()); x->format() != memory::format::format_undef,
variance_out->mutable_data<T>(ctx.GetPlace()); "Wrong layout/format set for Input x tensor");
const T *x_data = x->data<T>();
const T *mean_data = mean->data<T>();
const T *variance_data = variance->data<T>();
T *y_data = y->mutable_data<T>(ctx.GetPlace());
T *mean_out_data = mean_out->mutable_data<T>(ctx.GetPlace());
T *variance_out_data = variance_out->mutable_data<T>(ctx.GetPlace());
T *batch_mean_data = nullptr;
T *batch_variance_data = nullptr;
if (!is_test) { if (!is_test) {
batch_mean->mutable_data<T>(ctx.GetPlace()); batch_mean_data = batch_mean->mutable_data<T>(ctx.GetPlace());
batch_variance->mutable_data<T>(ctx.GetPlace()); batch_variance_data = batch_variance->mutable_data<T>(ctx.GetPlace());
} }
auto propagation = is_test == true ? mkldnn::prop_kind::forward_scoring auto propagation = is_test == true ? mkldnn::prop_kind::forward_scoring
: mkldnn::prop_kind::forward_training; : mkldnn::prop_kind::forward_training;
auto dims = paddle::framework::vectorize2int(x->dims()); auto src_tz = paddle::framework::vectorize2int(x->dims());
auto scale_tz = paddle::framework::vectorize2int(scale->dims());
auto src_md = PADDLE_ENFORCE(scale_tz.size() == 1, "Dims of scale tensor is NOT 1");
MKLDNNMemDesc(dims, memory::data_type::f32, memory::format::nchw); const unsigned int ic = scale_tz[0];
auto dst_md =
MKLDNNMemDesc(dims, memory::data_type::f32, memory::format::nchw);
auto src_pd = mkldnn::memory::primitive_desc{src_md, mkldnn_engine};
auto dst_pd = mkldnn::memory::primitive_desc{dst_md, mkldnn_engine};
auto src = mkldnn::memory{src_pd, cast_const_to_void(x->data<T>())};
auto dst = mkldnn::memory{dst_pd, y->data<T>()};
unsigned flags = mkldnn::use_scale_shift; unsigned flags = mkldnn::use_scale_shift;
if (is_test) flags |= mkldnn::use_global_stats; if (is_test) flags |= mkldnn::use_global_stats;
// create mkldnn memory from input x tensor
auto src_memory =
memory({{{src_tz}, memory::data_type::f32, x->format()}, mkldnn_engine},
to_void_cast(x_data));
// create primitive descriptor for batch norm forward
using bn_fwd_types = bn_type_traits<mkldnn::batch_normalization_forward>; using bn_fwd_types = bn_type_traits<mkldnn::batch_normalization_forward>;
auto batch_norm_fwd_desc = auto batch_norm_fwd_desc = bn_fwd_types::op_desc{
bn_fwd_types::op_desc{propagation, src_md, epsilon, flags}; propagation, src_memory.get_primitive_desc().desc(), epsilon, flags};
auto batch_norm_fwd_pd = std::shared_ptr<batch_norm_fwd::primitive_desc> batch_norm_fwd_pd =
bn_fwd_types::op_prim{batch_norm_fwd_desc, mkldnn_engine}; std::shared_ptr<batch_norm_fwd::primitive_desc>(
new batch_norm_fwd::primitive_desc(batch_norm_fwd_desc,
mkldnn_engine));
const unsigned int ic = dims[1]; // Save the pd to be used in backward pass
const std::string key = ctx.op().Output("SavedMean");
const std::string key_batch_norm_fwd_pd = key + "@bn_fwd_pd";
dev_ctx.SetBlob(key_batch_norm_fwd_pd, batch_norm_fwd_pd);
// MKLDNN requires a single piece of memory for scale and shift/bias data // MKLDNN requires a single piece of memory for scale and shift/bias data
const size_t scaleshift_size = 2 * ic; const size_t scaleshift_size = 2 * ic;
...@@ -143,73 +153,58 @@ class BatchNormMKLDNNOpKernel : public paddle::framework::OpKernel<T> { ...@@ -143,73 +153,58 @@ class BatchNormMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
copy_to_weights(scale->data<T>(), scale->data<T>() + ic, shift->data<T>(), copy_to_weights(scale->data<T>(), scale->data<T>() + ic, shift->data<T>(),
shift->data<T>() + ic, &scaleshift_data); shift->data<T>() + ic, &scaleshift_data);
auto scaleshift_memory = mkldnn::memory{ // crate mkldnn memory for weights(scale/shift)
batch_norm_fwd_pd.weights_primitive_desc(), scaleshift_data.data()}; auto scaleshift_memory = memory(batch_norm_fwd_pd->weights_primitive_desc(),
scaleshift_data.data());
if (is_test) { // create mkldnn memory for output y tensor
auto mean_memory = mkldnn::memory{batch_norm_fwd_pd.mean_primitive_desc(), auto dst_memory = memory(batch_norm_fwd_pd->dst_primitive_desc(), y_data);
cast_const_to_void(mean->data<T>())};
if (is_test) {
// create mkldnn memory for stats (as input)
auto mean_memory = memory(batch_norm_fwd_pd->mean_primitive_desc(),
to_void_cast(mean_data));
auto variance_memory = auto variance_memory =
mkldnn::memory{batch_norm_fwd_pd.variance_primitive_desc(), memory(batch_norm_fwd_pd->variance_primitive_desc(),
cast_const_to_void(variance->data<T>())}; to_void_cast(variance_data));
run_batch_norm_op<typename bn_fwd_types::op_type>( run_batch_norm_op<typename bn_fwd_types::op_type>(
batch_norm_fwd_pd, src, (const mkldnn::primitive::at &)mean_memory, *batch_norm_fwd_pd, src_memory,
(const mkldnn::primitive::at &)mean_memory,
(const mkldnn::primitive::at &)variance_memory, scaleshift_memory, (const mkldnn::primitive::at &)variance_memory, scaleshift_memory,
dst); dst_memory);
} else { } else {
// create mkldnn memory for stats (as output)
auto mean_memory = auto mean_memory =
mkldnn::memory{batch_norm_fwd_pd.mean_primitive_desc(), memory(batch_norm_fwd_pd->mean_primitive_desc(), batch_mean_data);
cast_const_to_void(batch_mean->data<T>())}; auto variance_memory = memory(
batch_norm_fwd_pd->variance_primitive_desc(), batch_variance_data);
auto variance_memory =
mkldnn::memory{batch_norm_fwd_pd.variance_primitive_desc(),
cast_const_to_void(batch_variance->data<T>())};
run_batch_norm_op<bn_fwd_types::op_type>(batch_norm_fwd_pd, src, run_batch_norm_op<bn_fwd_types::op_type>(*batch_norm_fwd_pd, src_memory,
scaleshift_memory, dst, scaleshift_memory, dst_memory,
mean_memory, variance_memory); mean_memory, variance_memory);
} }
if (!is_test) { if (!is_test) {
const unsigned int in = dims[0]; // mkldnn only compute stats for current batch
const unsigned int sample_size = x->numel() / in / ic; // so we need compute momentum stats via Eigen lib
EigenVectorArrayMap<T> batch_mean_e(batch_mean_data, ic);
// saved_xx is use just in this batch of data EigenVectorArrayMap<T> batch_variance_e(batch_variance_data, ic);
EigenVectorArrayMap<T> saved_mean_e( ConstEigenVectorArrayMap<T> mean_e(mean_data, ic);
batch_mean->mutable_data<T>(ctx.GetPlace()), ic); ConstEigenVectorArrayMap<T> variance_e{variance_data, ic};
EigenVectorArrayMap<T> saved_variance_e(
batch_variance->mutable_data<T>(ctx.GetPlace()), ic); EigenVectorArrayMap<T> running_mean_e(mean_out_data, ic);
saved_mean_e.setZero(); EigenVectorArrayMap<T> running_variance_e(variance_out_data, ic);
saved_variance_e.setZero();
const unsigned int x_arr_size = in * ic;
ConstEigenArrayMap<T> x_arr(x->data<T>(), sample_size, x_arr_size);
for (unsigned int nc = 0; nc < x_arr_size; ++nc) {
saved_mean_e(nc % ic) += x_arr.col(nc).sum();
}
saved_mean_e /= in * sample_size;
for (unsigned int nc = 0; nc < x_arr_size; ++nc) {
saved_variance_e(nc % ic) +=
(x_arr.col(nc) - saved_mean_e(nc % ic)).matrix().squaredNorm();
}
saved_variance_e /= in * sample_size;
ConstEigenVectorArrayMap<T> mean_arr{mean->data<T>(), ic};
ConstEigenVectorArrayMap<T> variance_arr{variance->data<T>(), ic};
EigenVectorArrayMap<T> running_mean_arr(
mean_out->mutable_data<T>(ctx.GetPlace()), ic);
EigenVectorArrayMap<T> running_var_arr(
variance_out->mutable_data<T>(ctx.GetPlace()), ic);
auto one_minus_momentum = 1. - momentum; auto one_minus_momentum = 1. - momentum;
running_mean_arr = running_mean_e = mean_e * momentum + batch_mean_e * one_minus_momentum;
mean_arr * momentum + saved_mean_e * one_minus_momentum; running_variance_e =
running_var_arr = variance_e * momentum + batch_variance_e * one_minus_momentum;
variance_arr * momentum + saved_variance_e * one_minus_momentum;
} }
y->set_layout(DataLayout::kMKLDNN);
y->set_format(
(memory::format)dst_memory.get_primitive_desc().desc().data.format);
} }
}; };
...@@ -217,11 +212,6 @@ template <typename T> ...@@ -217,11 +212,6 @@ template <typename T>
class BatchNormMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> { class BatchNormMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
public: public:
void Compute(const paddle::framework::ExecutionContext &ctx) const override { void Compute(const paddle::framework::ExecutionContext &ctx) const override {
auto data_layout_str = ctx.Attr<std::string>("data_layout");
auto data_layout = framework::StringToDataLayout(data_layout_str);
PADDLE_ENFORCE(data_layout == framework::DataLayout::kNCHW,
"MKLDNN batch normalization handles only NCHW data layout");
auto &dev_ctx = ctx.template device_context<MKLDNNDeviceContext>(); auto &dev_ctx = ctx.template device_context<MKLDNNDeviceContext>();
auto mkldnn_engine = dev_ctx.GetEngine(); auto mkldnn_engine = dev_ctx.GetEngine();
...@@ -238,88 +228,132 @@ class BatchNormMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> { ...@@ -238,88 +228,132 @@ class BatchNormMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
auto *diff_scale = ctx.Output<Tensor>(framework::GradVarName("Scale")); auto *diff_scale = ctx.Output<Tensor>(framework::GradVarName("Scale"));
auto *diff_shift = ctx.Output<Tensor>(framework::GradVarName("Bias")); auto *diff_shift = ctx.Output<Tensor>(framework::GradVarName("Bias"));
diff_x->mutable_data<T>(ctx.GetPlace()); PADDLE_ENFORCE(diff_y->layout() == DataLayout::kMKLDNN &&
diff_scale->mutable_data<T>(ctx.GetPlace()); diff_y->format() != memory::format::format_undef,
diff_shift->mutable_data<T>(ctx.GetPlace()); "Wrong layout/format set for Input diff_y tensor");
const T *x_data = x->data<T>();
const T *diff_y_data = diff_y->data<T>();
const T *batch_mean_data = batch_mean->data<T>();
const T *batch_variance_data = batch_variance->data<T>();
const T *scale_data = scale->data<T>();
const T *shift_data = shift->data<T>();
T *diff_x_data = diff_x->mutable_data<T>(ctx.GetPlace());
T *diff_scale_data = diff_scale->mutable_data<T>(ctx.GetPlace());
T *diff_shift_data = diff_shift->mutable_data<T>(ctx.GetPlace());
auto src_tz = paddle::framework::vectorize2int(x->dims());
auto diff_src_tz = src_tz;
auto dst_tz = src_tz;
auto diff_dst_tz = dst_tz;
auto scale_tz = paddle::framework::vectorize2int(scale->dims());
PADDLE_ENFORCE(scale_tz.size() == 1, "Dims of scale tensor is NOT 1");
const unsigned int ic = scale_tz[0];
// Retrieve bn_fwd_pd from device context
const std::string key = ctx.op().Input("SavedMean");
const std::string key_batch_norm_fwd_pd = key + "@bn_fwd_pd";
auto batch_norm_fwd_pd =
std::static_pointer_cast<batch_norm_fwd::primitive_desc>(
dev_ctx.GetBlob(key_batch_norm_fwd_pd));
PADDLE_ENFORCE(batch_norm_fwd_pd != nullptr,
"Fail to find batch_norm_fwd_pd in device context");
auto dims = paddle::framework::vectorize2int(x->dims()); using bn_bwd_types = bn_type_traits<mkldnn::batch_normalization_backward>;
unsigned flags = mkldnn::use_scale_shift | !mkldnn::use_global_stats;
auto src_md = // create mkldnn memory from input diff_y tensor
MKLDNNMemDesc(dims, memory::data_type::f32, memory::format::nchw); auto user_diff_dst_memory =
auto dst_md = memory({{{diff_dst_tz}, memory::data_type::f32, diff_y->format()},
MKLDNNMemDesc(dims, memory::data_type::f32, memory::format::nchw); mkldnn_engine},
auto diff_src_md = to_void_cast(diff_y_data));
MKLDNNMemDesc(dims, memory::data_type::f32, memory::format::nchw);
auto diff_dst_md =
MKLDNNMemDesc(dims, memory::data_type::f32, memory::format::nchw);
using bn_bwd_types = bn_type_traits<mkldnn::batch_normalization_backward>; // create mkldnn memory from input x tensor
using bn_fwd_types = bn_type_traits<mkldnn::batch_normalization_forward>; auto src_memory =
memory({{{src_tz}, memory::data_type::f32, x->format()}, mkldnn_engine},
to_void_cast(x_data));
auto batch_norm_fwd_desc = bn_fwd_types::op_desc{ // for diff_dst, try to use same format as dst in forward pass
mkldnn::prop_kind::forward_training, src_md, epsilon, flags}; auto diff_dst_pd = batch_norm_fwd_pd.get()->dst_primitive_desc();
auto batch_norm_fwd_pd = auto diff_dst_md = diff_dst_pd.desc();
bn_fwd_types::op_prim{batch_norm_fwd_desc, mkldnn_engine};
// create primitive descriptor for batch norm backward
unsigned flags = mkldnn::use_scale_shift;
auto batch_norm_bwd_desc = bn_bwd_types::op_desc{ auto batch_norm_bwd_desc = bn_bwd_types::op_desc{
mkldnn::prop_kind::backward, diff_dst_md, dst_md, epsilon, flags}; mkldnn::prop_kind::backward, diff_dst_md,
src_memory.get_primitive_desc().desc(), epsilon, flags};
auto batch_norm_bwd_pd = bn_bwd_types::op_prim{ auto batch_norm_bwd_pd = bn_bwd_types::op_prim{
batch_norm_bwd_desc, mkldnn_engine, batch_norm_fwd_pd}; batch_norm_bwd_desc, mkldnn_engine, *batch_norm_fwd_pd};
auto src = mkldnn::memory{{src_md, mkldnn_engine}, // reorder user_diff_dst if it's not in preferred format
cast_const_to_void(x->data<T>())}; auto diff_dst_memory = user_diff_dst_memory;
primitive reorder_diff_dst;
auto mean = mkldnn::memory{batch_norm_bwd_pd.mean_primitive_desc(), bool is_diff_dst_reordered = false;
cast_const_to_void(batch_mean->data<T>())}; if (diff_dst_pd != user_diff_dst_memory.get_primitive_desc()) {
diff_dst_memory = memory(diff_dst_pd);
auto variance = reorder_diff_dst = reorder(user_diff_dst_memory, diff_dst_memory);
mkldnn::memory{batch_norm_bwd_pd.variance_primitive_desc(), is_diff_dst_reordered = true;
cast_const_to_void(batch_variance->data<T>())}; }
auto diff_dst = mkldnn::memory{{diff_dst_md, mkldnn_engine},
cast_const_to_void(diff_y->data<T>())};
const unsigned int ic = dims[1]; // create mkldnn memory for input tensors (src/mean/variance)
auto mean_memory = memory(batch_norm_bwd_pd.mean_primitive_desc(),
to_void_cast(batch_mean_data));
auto variance_memory = memory(batch_norm_bwd_pd.variance_primitive_desc(),
to_void_cast(batch_variance_data));
// MKLDNN requires a single piece of memory for scale and shift/bias data
const size_t scaleshift_size = 2 * ic; const size_t scaleshift_size = 2 * ic;
std::vector<T> scaleshift_data; std::vector<T> scaleshift_data;
scaleshift_data.reserve(scaleshift_size); scaleshift_data.reserve(scaleshift_size);
copy_to_weights(scale->data<T>(), scale->data<T>() + ic, shift->data<T>(), copy_to_weights(scale_data, scale_data + ic, shift_data, shift_data + ic,
shift->data<T>() + ic, &scaleshift_data); &scaleshift_data);
auto scaleshift_memory = mkldnn::memory{ // create mkldnn memory for input tensors (scale/shift)
batch_norm_bwd_pd.weights_primitive_desc(), scaleshift_data.data()}; auto scaleshift_memory = memory(batch_norm_bwd_pd.weights_primitive_desc(),
scaleshift_data.data());
// create mkldnn memory for output diff weights (combined scale/shift)
std::vector<T> diff_scaleshift_data; std::vector<T> diff_scaleshift_data;
diff_scaleshift_data.reserve(scaleshift_size); diff_scaleshift_data.reserve(scaleshift_size);
copy_to_weights(diff_scale->data<T>(), diff_scale->data<T>() + ic,
diff_shift->data<T>(), diff_shift->data<T>() + ic,
&diff_scaleshift_data);
auto diff_scaleshift_memory = auto diff_scaleshift_memory =
mkldnn::memory{batch_norm_bwd_pd.diff_weights_primitive_desc(), memory(batch_norm_bwd_pd.diff_weights_primitive_desc(),
diff_scaleshift_data.data()}; diff_scaleshift_data.data());
auto diff_src = mkldnn::memory{{diff_src_md, mkldnn_engine}, // here assume diff_src is in the same format of src
static_cast<void *>(diff_x->data<T>())}; auto diff_src_memory = memory(src_memory.get_primitive_desc(), diff_x_data);
run_batch_norm_op<bn_bwd_types::op_type>( // finally create batch_norm backward primitive
batch_norm_bwd_pd, src, mean, variance, diff_dst, scaleshift_memory, auto batch_norm_bwd_prim =
diff_src, diff_scaleshift_memory); batch_norm_bwd(batch_norm_bwd_pd, src_memory, mean_memory,
variance_memory, diff_dst_memory, scaleshift_memory,
diff_src_memory, diff_scaleshift_memory);
// execute optional reorder and batch_norm backward primitive
std::vector<primitive> pipeline;
if (is_diff_dst_reordered) pipeline.push_back(reorder_diff_dst);
pipeline.push_back(batch_norm_bwd_prim);
stream(stream::kind::eager).submit(pipeline).wait();
// copy back diff sacle/shift to output tensors (diff scale/shift)
diff_scaleshift_data.resize(scaleshift_size);
auto it = std::begin(diff_scaleshift_data); auto it = std::begin(diff_scaleshift_data);
std::copy(it, std::next(it, ic), diff_scale->data<T>()); std::copy(it, std::next(it, ic), diff_scale_data);
std::copy(std::next(it, ic), std::end(diff_scaleshift_data), std::copy(std::next(it, ic), std::end(diff_scaleshift_data),
diff_shift->data<T>()); diff_shift_data);
// set layout/format of output tensors
diff_x->set_layout(DataLayout::kMKLDNN);
diff_x->set_format((memory::format)diff_src_memory.get_primitive_desc()
.desc()
.data.format);
} }
}; };
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle
namespace ops = paddle::operators; namespace ops = paddle::operators;
REGISTER_OP_KERNEL(batch_norm, MKLDNN, paddle::platform::CPUPlace, REGISTER_OP_KERNEL(batch_norm, MKLDNN, ::paddle::platform::CPUPlace,
ops::BatchNormMKLDNNOpKernel<float>); ops::BatchNormMKLDNNOpKernel<float>);
REGISTER_OP_KERNEL(batch_norm_grad, MKLDNN, paddle::platform::CPUPlace, REGISTER_OP_KERNEL(batch_norm_grad, MKLDNN, ::paddle::platform::CPUPlace,
ops::BatchNormMKLDNNGradOpKernel<float>); ops::BatchNormMKLDNNGradOpKernel<float>);
...@@ -110,19 +110,19 @@ class BatchNormOp : public framework::OperatorWithKernel { ...@@ -110,19 +110,19 @@ class BatchNormOp : public framework::OperatorWithKernel {
ctx.Input<Tensor>("Variance")->type()), ctx.Input<Tensor>("Variance")->type()),
"Variance input should be of float type"); "Variance input should be of float type");
framework::LibraryType library_{framework::LibraryType::kPlain};
// TODO(pzelazko-intel): enable MKLDNN layout when it's ready // TODO(pzelazko-intel): enable MKLDNN layout when it's ready
framework::LibraryType library = framework::LibraryType::kPlain;
framework::DataLayout layout = framework::DataLayout::kAnyLayout; framework::DataLayout layout = framework::DataLayout::kAnyLayout;
#ifdef PADDLE_WITH_MKLDNN #ifdef PADDLE_WITH_MKLDNN
if (library_ == framework::LibraryType::kPlain && if (library == framework::LibraryType::kPlain &&
platform::CanMKLDNNBeUsed(ctx)) { platform::CanMKLDNNBeUsed(ctx)) {
library_ = framework::LibraryType::kMKLDNN; library = framework::LibraryType::kMKLDNN;
layout = framework::DataLayout::kMKLDNN; layout = framework::DataLayout::kMKLDNN;
} }
#endif #endif
return framework::OpKernelType(input_data_type, ctx.GetPlace(), layout, return framework::OpKernelType(input_data_type, ctx.GetPlace(), layout,
library_); library);
} }
}; };
...@@ -368,19 +368,21 @@ class BatchNormGradOp : public framework::OperatorWithKernel { ...@@ -368,19 +368,21 @@ class BatchNormGradOp : public framework::OperatorWithKernel {
PADDLE_THROW("can't find Y@GRAD"); PADDLE_THROW("can't find Y@GRAD");
} }
framework::LibraryType library_{framework::LibraryType::kPlain};
// TODO(pzelazko-intel): enable MKLDNN layout when it's ready // TODO(pzelazko-intel): enable MKLDNN layout when it's ready
framework::DataLayout layout_ = framework::DataLayout::kAnyLayout; framework::LibraryType library = framework::LibraryType::kPlain;
framework::DataLayout layout = framework::DataLayout::kAnyLayout;
#ifdef PADDLE_WITH_MKLDNN #ifdef PADDLE_WITH_MKLDNN
if (library_ == framework::LibraryType::kPlain && if (library == framework::LibraryType::kPlain &&
platform::CanMKLDNNBeUsed(ctx)) { platform::CanMKLDNNBeUsed(ctx)) {
library_ = framework::LibraryType::kMKLDNN; library = framework::LibraryType::kMKLDNN;
layout_ = framework::DataLayout::kMKLDNN; layout = framework::DataLayout::kMKLDNN;
} }
#endif #endif
return framework::OpKernelType( return framework::OpKernelType(
framework::ToDataType(ctx.Input<Tensor>("X")->type()), ctx.GetPlace(), framework::ToDataType(ctx.Input<Tensor>("X")->type()), ctx.GetPlace(),
layout_, library_); layout, library);
} }
}; };
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册