提交 e94b26da 编写于 作者: A Adam 提交者: Tao Luo

using MKLDNNMemoryFormat = mkldnn::memory::format changes (#19568)

* using MKLDNNMemoryFormat = mkldnn::memory::format changes
test=develop

* PADDLE_ENFORCE update
test=develop
上级 e045aadf
...@@ -18,7 +18,6 @@ ...@@ -18,7 +18,6 @@
#include "paddle/fluid/operators/math/math_function.h" #include "paddle/fluid/operators/math/math_function.h"
#ifdef PADDLE_WITH_MKLDNN #ifdef PADDLE_WITH_MKLDNN
#include "paddle/fluid/platform/mkldnn_helper.h"
#include "paddle/fluid/platform/mkldnn_reuse.h" #include "paddle/fluid/platform/mkldnn_reuse.h"
#endif #endif
...@@ -135,8 +134,9 @@ void innerTransDataLayoutFromMKLDNN(DataLayout in_layout, DataLayout out_layout, ...@@ -135,8 +134,9 @@ void innerTransDataLayoutFromMKLDNN(DataLayout in_layout, DataLayout out_layout,
const Tensor& in, Tensor* out, const Tensor& in, Tensor* out,
platform::Place place) { platform::Place place) {
#ifdef PADDLE_WITH_MKLDNN #ifdef PADDLE_WITH_MKLDNN
PADDLE_ENFORCE(in.format() != memory::format::format_undef && PADDLE_ENFORCE_NE(in.format(), MKLDNNMemoryFormat::format_undef,
in.format() != memory::format::any, "Input tensor should have specified memory format");
PADDLE_ENFORCE_NE(in.format(), MKLDNNMemoryFormat::any,
"Input tensor should have specified memory format"); "Input tensor should have specified memory format");
// Set default as NCHW in case not specified // Set default as NCHW in case not specified
...@@ -183,7 +183,7 @@ void innerTransDataLayoutFromMKLDNN(DataLayout in_layout, DataLayout out_layout, ...@@ -183,7 +183,7 @@ void innerTransDataLayoutFromMKLDNN(DataLayout in_layout, DataLayout out_layout,
} }
out->set_layout(out_layout); out->set_layout(out_layout);
// reset format since the out tensor will be fed to non-MKLDNN OPkernel // reset format since the out tensor will be fed to non-MKLDNN OPkernel
out->set_format(memory::format::format_undef); out->set_format(MKLDNNMemoryFormat::format_undef);
#endif #endif
} }
......
...@@ -21,30 +21,33 @@ ...@@ -21,30 +21,33 @@
#include "paddle/fluid/framework/tensor.h" #include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/framework/variable.h" #include "paddle/fluid/framework/variable.h"
#ifdef PADDLE_WITH_MKLDNN
#include "paddle/fluid/platform/mkldnn_helper.h"
#endif
namespace paddle { namespace paddle {
namespace framework { namespace framework {
#ifdef PADDLE_WITH_MKLDNN #ifdef PADDLE_WITH_MKLDNN
using MKLDNNFormat = mkldnn::memory::format;
using MKLDNNDataType = mkldnn::memory::data_type; using MKLDNNDataType = mkldnn::memory::data_type;
inline MKLDNNFormat ToMKLDNNFormat(const DataLayout& layout) { inline MKLDNNMemoryFormat ToMKLDNNFormat(const DataLayout& layout) {
switch (layout) { switch (layout) {
case DataLayout::kNHWC: case DataLayout::kNHWC:
return MKLDNNFormat::nhwc; return MKLDNNMemoryFormat::nhwc;
case DataLayout::kNCHW: case DataLayout::kNCHW:
return MKLDNNFormat::nchw; return MKLDNNMemoryFormat::nchw;
default: default:
PADDLE_THROW("Fail to convert layout %s to MKLDNN format", PADDLE_THROW("Fail to convert layout %s to MKLDNN format",
DataLayoutToString(layout)); DataLayoutToString(layout));
} }
} }
inline DataLayout ToPaddleLayout(const MKLDNNFormat& format) { inline DataLayout ToPaddleLayout(const MKLDNNMemoryFormat& format) {
switch (format) { switch (format) {
case MKLDNNFormat::nhwc: case MKLDNNMemoryFormat::nhwc:
return DataLayout::kNHWC; return DataLayout::kNHWC;
case MKLDNNFormat::nchw: case MKLDNNMemoryFormat::nchw:
return DataLayout::kNCHW; return DataLayout::kNCHW;
default: default:
PADDLE_THROW("Fail to convert MKLDNN format to paddle layout"); PADDLE_THROW("Fail to convert MKLDNN format to paddle layout");
......
...@@ -55,20 +55,20 @@ class EltwiseAddMKLDNNKernel : public framework::OpKernel<T> { ...@@ -55,20 +55,20 @@ class EltwiseAddMKLDNNKernel : public framework::OpKernel<T> {
// broadcast operations need to be performed. // broadcast operations need to be performed.
if (x_dims != y_dims_untrimed) { if (x_dims != y_dims_untrimed) {
Tensor _x; Tensor _x;
mkldnn::memory::format format; MKLDNNMemoryFormat format;
std::vector<int> src_x_tz = framework::vectorize2int(x_dims); std::vector<int> src_x_tz = framework::vectorize2int(x_dims);
if ((src_x_tz.size() == 3 && if ((src_x_tz.size() == 3 &&
x->format() != (format = memory::format::ncw)) || x->format() != (format = MKLDNNMemoryFormat::ncw)) ||
(src_x_tz.size() == 4 && (src_x_tz.size() == 4 &&
x->format() != (format = memory::format::nchw)) || x->format() != (format = MKLDNNMemoryFormat::nchw)) ||
(src_x_tz.size() == 5 && (src_x_tz.size() == 5 &&
x->format() != (format = memory::format::ncdhw))) { x->format() != (format = MKLDNNMemoryFormat::ncdhw))) {
_x.Resize(x_dims); _x.Resize(x_dims);
mkldnn::memory::data_type in_type = platform::MKLDNNGetDataType<T>(); mkldnn::memory::data_type in_type = platform::MKLDNNGetDataType<T>();
auto out_format = platform::MKLDNNFormatForSize( auto out_format = platform::MKLDNNFormatForSize(
x_dims.size(), mkldnn::memory::format::nchw); x_dims.size(), MKLDNNMemoryFormat::nchw);
const std::string key = platform::ReorderMKLDNNHandler::GetHash( const std::string key = platform::ReorderMKLDNNHandler::GetHash(
src_x_tz, x->format(), out_format, std::to_string(in_type)); src_x_tz, x->format(), out_format, std::to_string(in_type));
...@@ -119,12 +119,15 @@ class EltwiseAddMKLDNNKernel : public framework::OpKernel<T> { ...@@ -119,12 +119,15 @@ class EltwiseAddMKLDNNKernel : public framework::OpKernel<T> {
z->set_layout(DataLayout::kMKLDNN); z->set_layout(DataLayout::kMKLDNN);
z->set_format(format); z->set_format(format);
} else { } else {
PADDLE_ENFORCE(x->layout() == DataLayout::kMKLDNN && PADDLE_ENFORCE_EQ(x->layout(), DataLayout::kMKLDNN,
x->format() != memory::format::format_undef, "Wrong layout set for X tensor");
"Wrong layout/format set for X tensor"); PADDLE_ENFORCE_NE(x->format(), MKLDNNMemoryFormat::format_undef,
PADDLE_ENFORCE(y->layout() == DataLayout::kMKLDNN && "Wrong format set for X tensor");
y->format() != memory::format::format_undef,
"Wrong layout/format set for Y tensor"); PADDLE_ENFORCE_EQ(y->layout(), DataLayout::kMKLDNN,
"Wrong layout set for Y tensor");
PADDLE_ENFORCE_NE(y->format(), MKLDNNMemoryFormat::format_undef,
"Wrong format set for Y tensor");
std::vector<int> src_x_tz = framework::vectorize2int(x_dims); std::vector<int> src_x_tz = framework::vectorize2int(x_dims);
std::vector<int> src_y_tz = framework::vectorize2int(y_dims_untrimed); std::vector<int> src_y_tz = framework::vectorize2int(y_dims_untrimed);
...@@ -148,7 +151,7 @@ class EltwiseAddMKLDNNKernel : public framework::OpKernel<T> { ...@@ -148,7 +151,7 @@ class EltwiseAddMKLDNNKernel : public framework::OpKernel<T> {
paddle::platform::to_void_cast(y_data)); paddle::platform::to_void_cast(y_data));
auto dst_md = memory::desc({dst_tz}, platform::MKLDNNGetDataType<T>(), auto dst_md = memory::desc({dst_tz}, platform::MKLDNNGetDataType<T>(),
memory::format::any); MKLDNNMemoryFormat::any);
auto sum_pd = handler.AcquireSumPrimitiveDescriptor( auto sum_pd = handler.AcquireSumPrimitiveDescriptor(
{src_x_memory, src_y_memory}, scales, dst_md); {src_x_memory, src_y_memory}, scales, dst_md);
...@@ -164,8 +167,9 @@ class EltwiseAddMKLDNNKernel : public framework::OpKernel<T> { ...@@ -164,8 +167,9 @@ class EltwiseAddMKLDNNKernel : public framework::OpKernel<T> {
stream(stream::kind::eager).submit(pipeline).wait(); stream(stream::kind::eager).submit(pipeline).wait();
z->set_layout(DataLayout::kMKLDNN); z->set_layout(DataLayout::kMKLDNN);
z->set_format( z->set_format((MKLDNNMemoryFormat)dst_memory->get_primitive_desc()
(memory::format)dst_memory->get_primitive_desc().desc().data.format); .desc()
.data.format);
} }
} }
}; };
......
...@@ -37,7 +37,7 @@ static void UpdateDataFormat(const framework::ExecutionContext& ctx, ...@@ -37,7 +37,7 @@ static void UpdateDataFormat(const framework::ExecutionContext& ctx,
if (ctx.op().HasAttr(attribute)) { if (ctx.op().HasAttr(attribute)) {
auto format_as_string = ctx.Attr<std::string>(attribute); auto format_as_string = ctx.Attr<std::string>(attribute);
auto format = StringToMKLDNNFormat(&format_as_string); auto format = StringToMKLDNNFormat(&format_as_string);
if (format != memory::format::any) { if (format != MKLDNNMemoryFormat::any) {
tensor->set_format(format); tensor->set_format(format);
} }
} }
...@@ -51,7 +51,8 @@ static void ReorderInput(framework::Tensor* tensor, ...@@ -51,7 +51,8 @@ static void ReorderInput(framework::Tensor* tensor,
auto dims = paddle::framework::vectorize2int(tensor->dims()); auto dims = paddle::framework::vectorize2int(tensor->dims());
framework::Tensor out_tensor; framework::Tensor out_tensor;
out_tensor.Resize(tensor->dims()); out_tensor.Resize(tensor->dims());
out_tensor.set_format(isFourDim ? memory::format::nchw : memory::format::nc); out_tensor.set_format(isFourDim ? MKLDNNMemoryFormat::nchw
: MKLDNNMemoryFormat::nc);
out_tensor.set_layout(tensor->layout()); out_tensor.set_layout(tensor->layout());
mkldnn::memory input_memory = { mkldnn::memory input_memory = {
{{dims, platform::MKLDNNGetDataType<T>(), tensor->format()}, engine}, {{dims, platform::MKLDNNGetDataType<T>(), tensor->format()}, engine},
...@@ -86,8 +87,8 @@ class ElementwiseMulMKLDNNKernel : public framework::OpKernel<T> { ...@@ -86,8 +87,8 @@ class ElementwiseMulMKLDNNKernel : public framework::OpKernel<T> {
const bool is_avx512_enabled = platform::MayIUse(platform::avx512f); const bool is_avx512_enabled = platform::MayIUse(platform::avx512f);
const bool are_dims_divisable = !(x_int_dims[1] % 16); const bool are_dims_divisable = !(x_int_dims[1] % 16);
const bool is_x_format_correct = x->format() == memory::format::nChw16c; const bool is_x_format_correct = x->format() == MKLDNNMemoryFormat::nChw16c;
const bool is_y_format_correct = y->format() == memory::format::nc; const bool is_y_format_correct = y->format() == MKLDNNMemoryFormat::nc;
if (is_x_format_correct && is_y_format_correct && are_dims_divisable && if (is_x_format_correct && is_y_format_correct && are_dims_divisable &&
is_avx512_enabled) { is_avx512_enabled) {
int pre, n, post; int pre, n, post;
...@@ -133,12 +134,12 @@ class ElementwiseMulMKLDNNKernel : public framework::OpKernel<T> { ...@@ -133,12 +134,12 @@ class ElementwiseMulMKLDNNKernel : public framework::OpKernel<T> {
} else { } else {
// Fallback to naive version: // Fallback to naive version:
const bool are_inputs_in_same_format = x->format() == y->format(); const bool are_inputs_in_same_format = x->format() == y->format();
const bool is_x_nchw = x->format() == memory::format::nchw; const bool is_x_nchw = x->format() == MKLDNNMemoryFormat::nchw;
const bool is_x_nc = x->format() == memory::format::nc; const bool is_x_nc = x->format() == MKLDNNMemoryFormat::nc;
const bool is_x_x = x->format() == memory::format::x; const bool is_x_x = x->format() == MKLDNNMemoryFormat::x;
const bool is_y_nchw = y->format() == memory::format::nchw; const bool is_y_nchw = y->format() == MKLDNNMemoryFormat::nchw;
const bool is_y_nc = y->format() == memory::format::nc; const bool is_y_nc = y->format() == MKLDNNMemoryFormat::nc;
const bool is_y_x = y->format() == memory::format::x; const bool is_y_x = y->format() == MKLDNNMemoryFormat::x;
if (!are_inputs_in_same_format) { if (!are_inputs_in_same_format) {
using platform::MKLDNNDeviceContext; using platform::MKLDNNDeviceContext;
auto& dev_ctx = ctx.template device_context<MKLDNNDeviceContext>(); auto& dev_ctx = ctx.template device_context<MKLDNNDeviceContext>();
......
...@@ -47,9 +47,10 @@ class MKLDNNActivationKernel ...@@ -47,9 +47,10 @@ class MKLDNNActivationKernel
public: public:
void Compute(const framework::ExecutionContext &ctx) const override { void Compute(const framework::ExecutionContext &ctx) const override {
const auto *x = ctx.Input<Tensor>("X"); const auto *x = ctx.Input<Tensor>("X");
PADDLE_ENFORCE(x->layout() == DataLayout::kMKLDNN && PADDLE_ENFORCE_EQ(x->layout(), DataLayout::kMKLDNN,
x->format() != memory::format::format_undef, "Wrong layout set for X tensor");
"Wrong layout/format set for Input x tensor"); PADDLE_ENFORCE_NE(x->format(), MKLDNNMemoryFormat::format_undef,
"Wrong format set for X tensor");
Functor functor; Functor functor;
functor(ctx); functor(ctx);
...@@ -62,12 +63,13 @@ class MKLDNNActivationGradKernel ...@@ -62,12 +63,13 @@ class MKLDNNActivationGradKernel
public: public:
void Compute(const framework::ExecutionContext &ctx) const override { void Compute(const framework::ExecutionContext &ctx) const override {
const auto *diff_y = ctx.Input<Tensor>(framework::GradVarName("Out")); const auto *diff_y = ctx.Input<Tensor>(framework::GradVarName("Out"));
PADDLE_ENFORCE(diff_y->layout() == DataLayout::kMKLDNN && PADDLE_ENFORCE_EQ(diff_y->layout(), DataLayout::kMKLDNN,
diff_y->format() != memory::format::format_undef, "Wrong layout set for Input OutGrad tensor");
"Wrong layout/format set for Input OutGrad tensor"); PADDLE_ENFORCE_NE(diff_y->format(), MKLDNNMemoryFormat::format_undef,
"Wrong format set for Input OutGrad tensor");
PADDLE_ENFORCE( PADDLE_ENFORCE_EQ(
!ctx.Attr<bool>("is_test"), ctx.Attr<bool>("is_test"), false,
"is_test attribute should be set to False in training phase."); "is_test attribute should be set to False in training phase.");
Functor functor; Functor functor;
...@@ -97,8 +99,7 @@ void eltwise_forward(const framework::ExecutionContext &ctx, ...@@ -97,8 +99,7 @@ void eltwise_forward(const framework::ExecutionContext &ctx,
std::vector<int> src_tz = framework::vectorize2int(x->dims()); std::vector<int> src_tz = framework::vectorize2int(x->dims());
auto src_format = auto src_format = src_tz.size() == 2 ? MKLDNNMemoryFormat::nc : x->format();
src_tz.size() == 2 ? mkldnn::memory::format::nc : x->format();
bool is_test = ctx.Attr<bool>("is_test"); bool is_test = ctx.Attr<bool>("is_test");
...@@ -152,10 +153,10 @@ void eltwise_grad(const framework::ExecutionContext &ctx, ...@@ -152,10 +153,10 @@ void eltwise_grad(const framework::ExecutionContext &ctx,
// diff_dst and src dims should be the same // diff_dst and src dims should be the same
auto src_format = auto src_format =
diff_dst_tz.size() == 2 ? mkldnn::memory::format::nc : x->format(); diff_dst_tz.size() == 2 ? MKLDNNMemoryFormat::nc : x->format();
auto diff_y_format = auto diff_y_format =
diff_dst_tz.size() == 2 ? mkldnn::memory::format::nc : diff_y->format(); diff_dst_tz.size() == 2 ? MKLDNNMemoryFormat::nc : diff_y->format();
auto diff_dst_md = platform::MKLDNNMemDesc( auto diff_dst_md = platform::MKLDNNMemDesc(
diff_dst_tz, platform::MKLDNNGetDataType<T>(), diff_y_format); diff_dst_tz, platform::MKLDNNGetDataType<T>(), diff_y_format);
......
...@@ -121,7 +121,8 @@ class BatchNormMKLDNNHandler : public platform::MKLDNNHandler { ...@@ -121,7 +121,8 @@ class BatchNormMKLDNNHandler : public platform::MKLDNNHandler {
} }
static std::string GetHash(const memory::dims &input_dims, float epsilon, static std::string GetHash(const memory::dims &input_dims, float epsilon,
unsigned flag, bool is_test, memory::format format, unsigned flag, bool is_test,
MKLDNNMemoryFormat format,
const std::string &suffix = "") { const std::string &suffix = "") {
auto dims2str = [](const memory::dims &operand_dims) { auto dims2str = [](const memory::dims &operand_dims) {
std::string dstr = ""; std::string dstr = "";
...@@ -191,9 +192,10 @@ class BatchNormMKLDNNOpKernel : public paddle::framework::OpKernel<T> { ...@@ -191,9 +192,10 @@ class BatchNormMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
const auto *scale = ctx.Input<Tensor>("Scale"); const auto *scale = ctx.Input<Tensor>("Scale");
const auto *shift = ctx.Input<Tensor>("Bias"); const auto *shift = ctx.Input<Tensor>("Bias");
PADDLE_ENFORCE(x->layout() == DataLayout::kMKLDNN && PADDLE_ENFORCE_EQ(x->layout(), DataLayout::kMKLDNN,
x->format() != memory::format::format_undef, "Wrong layout set for X tensor");
"Wrong layout/format set for Input x tensor"); PADDLE_ENFORCE_NE(x->format(), MKLDNNMemoryFormat::format_undef,
"Wrong format set for X tensor");
const T *x_data = x->data<T>(); const T *x_data = x->data<T>();
const T *mean_data = mean->data<T>(); const T *mean_data = mean->data<T>();
...@@ -230,7 +232,7 @@ class BatchNormMKLDNNOpKernel : public paddle::framework::OpKernel<T> { ...@@ -230,7 +232,7 @@ class BatchNormMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
if (fuse_with_relu) flags |= mkldnn::fuse_bn_relu; if (fuse_with_relu) flags |= mkldnn::fuse_bn_relu;
// create mkldnn memory from input x tensor // create mkldnn memory from input x tensor
mkldnn::memory::format input_format = MKLDNNMemoryFormat input_format =
platform::MKLDNNFormatForSize(src_tz.size(), x->format()); platform::MKLDNNFormatForSize(src_tz.size(), x->format());
// keys for backward pass // keys for backward pass
...@@ -331,9 +333,10 @@ class BatchNormMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> { ...@@ -331,9 +333,10 @@ class BatchNormMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
auto *diff_scale = ctx.Output<Tensor>(framework::GradVarName("Scale")); auto *diff_scale = ctx.Output<Tensor>(framework::GradVarName("Scale"));
auto *diff_shift = ctx.Output<Tensor>(framework::GradVarName("Bias")); auto *diff_shift = ctx.Output<Tensor>(framework::GradVarName("Bias"));
PADDLE_ENFORCE(diff_y->layout() == DataLayout::kMKLDNN && PADDLE_ENFORCE_EQ(diff_y->layout(), DataLayout::kMKLDNN,
diff_y->format() != memory::format::format_undef, "Wrong layout set for Input diff_y tensor");
"Wrong layout/format set for Input diff_y tensor"); PADDLE_ENFORCE_NE(diff_y->format(), MKLDNNMemoryFormat::format_undef,
"Wrong format set for Input diff_y tensor");
const T *x_data = x->data<T>(); const T *x_data = x->data<T>();
const T *diff_y_data = diff_y->data<T>(); const T *diff_y_data = diff_y->data<T>();
...@@ -357,10 +360,10 @@ class BatchNormMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> { ...@@ -357,10 +360,10 @@ class BatchNormMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
using bn_bwd_types = bn_type_traits<mkldnn::batch_normalization_backward>; using bn_bwd_types = bn_type_traits<mkldnn::batch_normalization_backward>;
mkldnn::memory::format dst_format = MKLDNNMemoryFormat dst_format =
platform::MKLDNNFormatForSize(src_tz.size(), diff_y->format()); platform::MKLDNNFormatForSize(src_tz.size(), diff_y->format());
mkldnn::memory::format input_format = MKLDNNMemoryFormat input_format =
platform::MKLDNNFormatForSize(src_tz.size(), x->format()); platform::MKLDNNFormatForSize(src_tz.size(), x->format());
unsigned flags = mkldnn::use_scale_shift; unsigned flags = mkldnn::use_scale_shift;
...@@ -481,7 +484,8 @@ class BatchNormMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> { ...@@ -481,7 +484,8 @@ class BatchNormMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
// set layout/format of output tensors // set layout/format of output tensors
diff_x->set_layout(DataLayout::kMKLDNN); diff_x->set_layout(DataLayout::kMKLDNN);
diff_x->set_format((memory::format)diff_src_memory->get_primitive_desc() diff_x->set_format(
(MKLDNNMemoryFormat)diff_src_memory->get_primitive_desc()
.desc() .desc()
.data.format); .data.format);
} else { } else {
...@@ -509,7 +513,8 @@ class BatchNormMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> { ...@@ -509,7 +513,8 @@ class BatchNormMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
// set layout/format of output tensors // set layout/format of output tensors
diff_x->set_layout(DataLayout::kMKLDNN); diff_x->set_layout(DataLayout::kMKLDNN);
diff_x->set_format((memory::format)diff_src_memory->get_primitive_desc() diff_x->set_format(
(MKLDNNMemoryFormat)diff_src_memory->get_primitive_desc()
.desc() .desc()
.data.format); .data.format);
} }
......
...@@ -30,11 +30,10 @@ using platform::to_void_cast; ...@@ -30,11 +30,10 @@ using platform::to_void_cast;
static void EnforceLayouts(const std::vector<const Tensor*> inputs) { static void EnforceLayouts(const std::vector<const Tensor*> inputs) {
for (auto* input : inputs) { for (auto* input : inputs) {
const bool is_layout_correct = input->layout() == DataLayout::kMKLDNN; PADDLE_ENFORCE_EQ(input->layout(), DataLayout::kMKLDNN,
const bool is_format_defined = "Wrong layout set for Input tensor");
input->format() != memory::format::format_undef; PADDLE_ENFORCE_NE(input->format(), MKLDNNMemoryFormat::format_undef,
PADDLE_ENFORCE(is_layout_correct && is_format_defined, "Wrong format set for Input tensor");
"Wrong layout/format set for Input tensor");
} }
} }
...@@ -48,9 +47,9 @@ static memory::primitive_desc CreateMemPrimDesc(const Tensor& input, ...@@ -48,9 +47,9 @@ static memory::primitive_desc CreateMemPrimDesc(const Tensor& input,
return mem_prim_desc; return mem_prim_desc;
} }
static mkldnn::memory::format GetDstMemFormat( static MKLDNNMemoryFormat GetDstMemFormat(
const concat::primitive_desc& concat_pd) { const concat::primitive_desc& concat_pd) {
return (memory::format)concat_pd.dst_primitive_desc().desc().data.format; return (MKLDNNMemoryFormat)concat_pd.dst_primitive_desc().desc().data.format;
} }
static platform::CPUPlace GetCpuPlace( static platform::CPUPlace GetCpuPlace(
...@@ -126,7 +125,7 @@ class ConcatPrimitiveFactory { ...@@ -126,7 +125,7 @@ class ConcatPrimitiveFactory {
memory::desc CreateDstMemDescriptor(Tensor* output, memory::desc CreateDstMemDescriptor(Tensor* output,
const memory::data_type& dt) { const memory::data_type& dt) {
auto dst_dims = paddle::framework::vectorize2int(output->dims()); auto dst_dims = paddle::framework::vectorize2int(output->dims());
return memory::desc(dst_dims, dt, memory::format::any); return memory::desc(dst_dims, dt, MKLDNNMemoryFormat::any);
} }
mkldnn::memory CreateDstMemory(const concat::primitive_desc& concat_pd, mkldnn::memory CreateDstMemory(const concat::primitive_desc& concat_pd,
......
...@@ -60,12 +60,12 @@ inline void GetWeightsTz(std::vector<int>& weights_tz, int groups, // NOLINT ...@@ -60,12 +60,12 @@ inline void GetWeightsTz(std::vector<int>& weights_tz, int groups, // NOLINT
} }
} }
inline mkldnn::memory::format GetWeightsFormat(mkldnn::memory::format format, inline MKLDNNMemoryFormat GetWeightsFormat(MKLDNNMemoryFormat format,
int groups, bool is_conv3d) { int groups, bool is_conv3d) {
if (is_conv3d) { if (is_conv3d) {
return (groups == 1) ? format : mkldnn::memory::format::goidhw; return (groups == 1) ? format : MKLDNNMemoryFormat::goidhw;
} else { } else {
return (groups == 1) ? format : mkldnn::memory::format::goihw; return (groups == 1) ? format : MKLDNNMemoryFormat::goihw;
} }
} }
...@@ -129,21 +129,37 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> { ...@@ -129,21 +129,37 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
auto* bias = ctx.HasInput("Bias") ? ctx.Input<Tensor>("Bias") : nullptr; auto* bias = ctx.HasInput("Bias") ? ctx.Input<Tensor>("Bias") : nullptr;
auto* output = ctx.Output<Tensor>("Output"); auto* output = ctx.Output<Tensor>("Output");
PADDLE_ENFORCE(input->layout() == DataLayout::kMKLDNN && PADDLE_ENFORCE_EQ(input->layout(), DataLayout::kMKLDNN,
input->format() != memory::format::format_undef, "Wrong layout set for Input tensor");
"Wrong layout/format set for Input tensor"); PADDLE_ENFORCE_NE(input->format(), MKLDNNMemoryFormat::format_undef,
PADDLE_ENFORCE(filter->layout() == DataLayout::kMKLDNN && "Wrong format set for Input tensor");
filter->format() != memory::format::format_undef,
"Wrong layout/format set for Filter tensor"); PADDLE_ENFORCE_EQ(filter->layout(), DataLayout::kMKLDNN,
PADDLE_ENFORCE(input->dims().size() == 4 || input->dims().size() == 5, "Wrong layout set for Filter tensor");
PADDLE_ENFORCE_NE(filter->format(), MKLDNNMemoryFormat::format_undef,
"Wrong format set for Filter tensor");
PADDLE_ENFORCE_GE(
input->dims().size(), 4,
"Input must be with 4 or 5 dimensions, i.e. NCHW or NCDHW");
PADDLE_ENFORCE_LE(
input->dims().size(), 5,
"Input must be with 4 or 5 dimensions, i.e. NCHW or NCDHW"); "Input must be with 4 or 5 dimensions, i.e. NCHW or NCDHW");
PADDLE_ENFORCE(filter->dims().size() == 4 || filter->dims().size() == 5,
PADDLE_ENFORCE_GE(
filter->dims().size(), 4,
"Filter must be with 4 or 5 dimensions, i.e. OIHW or OIDHW");
PADDLE_ENFORCE_LE(
filter->dims().size(), 5,
"Filter must be with 4 or 5 dimensions, i.e. OIHW or OIDHW"); "Filter must be with 4 or 5 dimensions, i.e. OIHW or OIDHW");
if (bias) { if (bias) {
PADDLE_ENFORCE(bias->layout() == DataLayout::kMKLDNN && PADDLE_ENFORCE_EQ(bias->layout(), DataLayout::kMKLDNN,
bias->format() != memory::format::format_undef, "Wrong layout set for Bias tensor");
"Wrong layout/format set for Bias tensor"); PADDLE_ENFORCE_NE(bias->format(), MKLDNNMemoryFormat::format_undef,
PADDLE_ENFORCE(bias->dims().size() == 1, "Wrong format set for Bias tensor");
PADDLE_ENFORCE_EQ(bias->dims().size(), 1,
"Bias must only have 1 dimension, i.e. X"); "Bias must only have 1 dimension, i.e. X");
} }
...@@ -182,7 +198,7 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> { ...@@ -182,7 +198,7 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
std::vector<primitive> pipeline; std::vector<primitive> pipeline;
auto src_format = input->format(); auto src_format = input->format();
mkldnn::memory::format weights_format = MKLDNNMemoryFormat weights_format =
GetWeightsFormat(filter->format(), g, is_conv3d); GetWeightsFormat(filter->format(), g, is_conv3d);
auto user_src_md = platform::MKLDNNMemDesc( auto user_src_md = platform::MKLDNNMemDesc(
...@@ -198,9 +214,9 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> { ...@@ -198,9 +214,9 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
auto chosen_memory_format = auto chosen_memory_format =
platform::data_format_to_memory_format(data_format); platform::data_format_to_memory_format(data_format);
weights_format = mkldnn::memory::format::any; weights_format = MKLDNNMemoryFormat::any;
// Check the format for user's special output // Check the format for user's special output
if (chosen_memory_format != mkldnn::memory::format::any) { if (chosen_memory_format != MKLDNNMemoryFormat::any) {
if (is_conv3d) { if (is_conv3d) {
chosen_memory_format = chosen_memory_format =
platform::MKLDNNFormatForSize(src_tz.size(), chosen_memory_format); platform::MKLDNNFormatForSize(src_tz.size(), chosen_memory_format);
...@@ -224,7 +240,7 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> { ...@@ -224,7 +240,7 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
if (bias) { if (bias) {
bias_tz = paddle::framework::vectorize2int(bias->dims()); bias_tz = paddle::framework::vectorize2int(bias->dims());
auto bias_md = platform::MKLDNNMemDesc( auto bias_md = platform::MKLDNNMemDesc(
bias_tz, platform::MKLDNNGetDataType<T>(), memory::format::x); bias_tz, platform::MKLDNNGetDataType<T>(), MKLDNNMemoryFormat::x);
conv_pd = handler.AcquireConvolutionPrimitiveDescriptor( conv_pd = handler.AcquireConvolutionPrimitiveDescriptor(
src_md, weights_md, bias_md, dst_md, strides, paddings, mkldnn_engine, src_md, weights_md, bias_md, dst_md, strides, paddings, mkldnn_engine,
fuse_activation, fuse_alpha, fuse_beta, fuse_residual_conn, fuse_activation, fuse_alpha, fuse_beta, fuse_residual_conn,
...@@ -295,7 +311,7 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> { ...@@ -295,7 +311,7 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
if (bias) { if (bias) {
const T* bias_data = bias->data<T>(); const T* bias_data = bias->data<T>();
auto user_bias_md = platform::MKLDNNMemDesc( auto user_bias_md = platform::MKLDNNMemDesc(
{bias_tz}, platform::MKLDNNGetDataType<T>(), memory::format::x); {bias_tz}, platform::MKLDNNGetDataType<T>(), MKLDNNMemoryFormat::x);
user_bias_memory_p = user_bias_memory_p =
handler.AcquireBiasMemory(user_bias_md, to_void_cast<T>(bias_data)); handler.AcquireBiasMemory(user_bias_md, to_void_cast<T>(bias_data));
...@@ -328,21 +344,37 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> { ...@@ -328,21 +344,37 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
auto* bias = ctx.HasInput("Bias") ? ctx.Input<Tensor>("Bias") : nullptr; auto* bias = ctx.HasInput("Bias") ? ctx.Input<Tensor>("Bias") : nullptr;
auto* output = ctx.Output<Tensor>("Output"); auto* output = ctx.Output<Tensor>("Output");
PADDLE_ENFORCE(input->layout() == DataLayout::kMKLDNN && PADDLE_ENFORCE_EQ(input->layout(), DataLayout::kMKLDNN,
input->format() != memory::format::format_undef, "Wrong layout set for Input tensor");
"Wrong layout/format set for Input tensor"); PADDLE_ENFORCE_NE(input->format(), MKLDNNMemoryFormat::format_undef,
PADDLE_ENFORCE(filter->layout() == DataLayout::kMKLDNN && "Wrong format set for Input tensor");
filter->format() != memory::format::format_undef,
"Wrong layout/format set for Filter tensor"); PADDLE_ENFORCE_EQ(filter->layout(), DataLayout::kMKLDNN,
PADDLE_ENFORCE(input->dims().size() == 4 || input->dims().size() == 5, "Wrong layout set for Filter tensor");
PADDLE_ENFORCE_NE(filter->format(), MKLDNNMemoryFormat::format_undef,
"Wrong format set for Filter tensor");
PADDLE_ENFORCE_GE(
input->dims().size(), 4,
"Input must be with 4 or 5 dimensions, i.e. NCHW or NCDHW"); "Input must be with 4 or 5 dimensions, i.e. NCHW or NCDHW");
PADDLE_ENFORCE(filter->dims().size() == 4 || filter->dims().size() == 5, PADDLE_ENFORCE_LE(
input->dims().size(), 5,
"Input must be with 4 or 5 dimensions, i.e. NCHW or NCDHW");
PADDLE_ENFORCE_GE(
filter->dims().size(), 4,
"Filter must be with 4 or 5 dimensions, i.e. OIHW or OIDHW"); "Filter must be with 4 or 5 dimensions, i.e. OIHW or OIDHW");
PADDLE_ENFORCE_LE(
filter->dims().size(), 5,
"Filter must be with 4 or 5 dimensions, i.e. OIHW or OIDHW");
if (bias) { if (bias) {
PADDLE_ENFORCE(bias->layout() == DataLayout::kMKLDNN && PADDLE_ENFORCE_EQ(bias->layout(), DataLayout::kMKLDNN,
bias->format() != memory::format::format_undef, "Wrong layout set for Bias tensor");
"Wrong layout/format set for Bias tensor"); PADDLE_ENFORCE_NE(bias->format(), MKLDNNMemoryFormat::format_undef,
PADDLE_ENFORCE(bias->dims().size() == 1, "Wrong format set for Bias tensor");
PADDLE_ENFORCE_EQ(bias->dims().size(), 1,
"Bias must only have 1 dimension, i.e. X"); "Bias must only have 1 dimension, i.e. X");
} }
...@@ -456,8 +488,7 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> { ...@@ -456,8 +488,7 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
platform::MKLDNNMemDesc({src_tz}, src_dt, input->format()); platform::MKLDNNMemDesc({src_tz}, src_dt, input->format());
auto user_weights_md = platform::MKLDNNMemDesc( auto user_weights_md = platform::MKLDNNMemDesc(
{weights_tz}, platform::MKLDNNGetDataType<K>(), {weights_tz}, platform::MKLDNNGetDataType<K>(),
((g) == 1) ? mkldnn::memory::format::oihw ((g) == 1) ? MKLDNNMemoryFormat::oihw : MKLDNNMemoryFormat::goihw);
: mkldnn::memory::format::goihw);
/* create memory descriptor for convolution without specified format /* create memory descriptor for convolution without specified format
* ('any') which lets a primitive (convolution in this case) choose * ('any') which lets a primitive (convolution in this case) choose
...@@ -485,7 +516,7 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> { ...@@ -485,7 +516,7 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
if (bias) { if (bias) {
bias_tz = paddle::framework::vectorize2int(bias->dims()); bias_tz = paddle::framework::vectorize2int(bias->dims());
auto bias_md = platform::MKLDNNMemDesc(bias_tz, memory::data_type::s32, auto bias_md = platform::MKLDNNMemDesc(bias_tz, memory::data_type::s32,
mkldnn::memory::format::x); MKLDNNMemoryFormat::x);
conv_pd = handler->AcquireConvolutionPrimitiveDescriptor( conv_pd = handler->AcquireConvolutionPrimitiveDescriptor(
src_md, weights_md, bias_md, dst_md, strides, paddings, src_md, weights_md, bias_md, dst_md, strides, paddings,
mkldnn_engine, fuse_activation, fuse_alpha, fuse_beta, mkldnn_engine, fuse_activation, fuse_alpha, fuse_beta,
...@@ -545,7 +576,7 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> { ...@@ -545,7 +576,7 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
if (bias) { if (bias) {
const K* bias_data = bias->data<K>(); const K* bias_data = bias->data<K>();
auto user_bias_md = platform::MKLDNNMemDesc( auto user_bias_md = platform::MKLDNNMemDesc(
{bias_tz}, platform::MKLDNNGetDataType<K>(), memory::format::x); {bias_tz}, platform::MKLDNNGetDataType<K>(), MKLDNNMemoryFormat::x);
auto user_bias_memory_p = handler->AcquireBiasMemory( auto user_bias_memory_p = handler->AcquireBiasMemory(
user_bias_md, to_void_cast<K>(bias_data)); user_bias_md, to_void_cast<K>(bias_data));
std::shared_ptr<mkldnn::memory> bias_memory_p; std::shared_ptr<mkldnn::memory> bias_memory_p;
...@@ -641,18 +672,23 @@ class ConvMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> { ...@@ -641,18 +672,23 @@ class ConvMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
Tensor* input_grad = ctx.Output<Tensor>(framework::GradVarName("Input")); Tensor* input_grad = ctx.Output<Tensor>(framework::GradVarName("Input"));
Tensor* filter_grad = ctx.Output<Tensor>(framework::GradVarName("Filter")); Tensor* filter_grad = ctx.Output<Tensor>(framework::GradVarName("Filter"));
PADDLE_ENFORCE(input->layout() == DataLayout::kMKLDNN && PADDLE_ENFORCE_EQ(input->layout(), DataLayout::kMKLDNN,
input->format() != memory::format::format_undef, "Wrong layout set for Input tensor");
"Wrong layout/format set for Input tensor"); PADDLE_ENFORCE_NE(input->format(), MKLDNNMemoryFormat::format_undef,
PADDLE_ENFORCE(filter->layout() == DataLayout::kMKLDNN && "Wrong format set for Input tensor");
filter->format() != memory::format::format_undef,
"Wrong layout/format set for Filter tensor");
PADDLE_ENFORCE(output_grad->layout() == DataLayout::kMKLDNN &&
output_grad->format() != memory::format::format_undef,
"Wrong layout/format set for output_grad tensor");
PADDLE_ENFORCE( PADDLE_ENFORCE_EQ(filter->layout(), DataLayout::kMKLDNN,
!ctx.Attr<bool>("is_test"), "Wrong layout set for Filter tensor");
PADDLE_ENFORCE_NE(filter->format(), MKLDNNMemoryFormat::format_undef,
"Wrong format set for Filter tensor");
PADDLE_ENFORCE_EQ(output_grad->layout(), DataLayout::kMKLDNN,
"Wrong layout set for output_grad tensor");
PADDLE_ENFORCE_NE(output_grad->format(), MKLDNNMemoryFormat::format_undef,
"Wrong format set for output_grad tensor");
PADDLE_ENFORCE_EQ(
ctx.Attr<bool>("is_test"), false,
"is_test attribute should be set to False in training phase."); "is_test attribute should be set to False in training phase.");
if (!input_grad && !filter_grad) return; if (!input_grad && !filter_grad) return;
...@@ -677,7 +713,7 @@ class ConvMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> { ...@@ -677,7 +713,7 @@ class ConvMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
std::vector<int> dst_tz = std::vector<int> dst_tz =
paddle::framework::vectorize2int(output_grad->dims()); paddle::framework::vectorize2int(output_grad->dims());
auto src_format = input->format(); auto src_format = input->format();
mkldnn::memory::format weights_format = MKLDNNMemoryFormat weights_format =
GetWeightsFormat(filter->format(), g, is_conv3d); GetWeightsFormat(filter->format(), g, is_conv3d);
// Get an unique name from "argument" name of "input" and "Filter" variable // Get an unique name from "argument" name of "input" and "Filter" variable
...@@ -706,9 +742,9 @@ class ConvMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> { ...@@ -706,9 +742,9 @@ class ConvMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
auto chosen_memory_format = auto chosen_memory_format =
platform::data_format_to_memory_format(data_format); platform::data_format_to_memory_format(data_format);
weights_format = mkldnn::memory::format::any; weights_format = MKLDNNMemoryFormat::any;
// Check the format for user's special output // Check the format for user's special output
if (chosen_memory_format != mkldnn::memory::format::any) { if (chosen_memory_format != MKLDNNMemoryFormat::any) {
if (is_conv3d) { if (is_conv3d) {
chosen_memory_format = chosen_memory_format =
platform::MKLDNNFormatForSize(src_tz.size(), chosen_memory_format); platform::MKLDNNFormatForSize(src_tz.size(), chosen_memory_format);
......
...@@ -45,22 +45,28 @@ class ConvTransposeMKLDNNOpKernel : public paddle::framework::OpKernel<T> { ...@@ -45,22 +45,28 @@ class ConvTransposeMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
auto* bias = ctx.HasInput("Bias") ? ctx.Input<Tensor>("Bias") : nullptr; auto* bias = ctx.HasInput("Bias") ? ctx.Input<Tensor>("Bias") : nullptr;
auto* output = ctx.Output<Tensor>("Output"); auto* output = ctx.Output<Tensor>("Output");
PADDLE_ENFORCE(input->layout() == DataLayout::kMKLDNN && PADDLE_ENFORCE_EQ(input->layout(), DataLayout::kMKLDNN,
input->format() != mkldnn::memory::format::format_undef, "Wrong layout set for Input tensor");
"Wrong layout/format set for Input tensor"); PADDLE_ENFORCE_NE(input->format(), MKLDNNMemoryFormat::format_undef,
PADDLE_ENFORCE(filter->layout() == DataLayout::kMKLDNN && "Wrong format set for Input tensor");
filter->format() != mkldnn::memory::format::format_undef,
"Wrong layout/format set for Filter tensor"); PADDLE_ENFORCE_EQ(filter->layout(), DataLayout::kMKLDNN,
PADDLE_ENFORCE(input->dims().size() == 4, "Wrong layout set for Filter tensor");
PADDLE_ENFORCE_NE(filter->format(), MKLDNNMemoryFormat::format_undef,
"Wrong format set for Filter tensor");
PADDLE_ENFORCE_EQ(input->dims().size(), 4,
"Input must be with 4 dimensions, i.e. NCHW"); "Input must be with 4 dimensions, i.e. NCHW");
PADDLE_ENFORCE(filter->dims().size() == 4, PADDLE_ENFORCE_EQ(filter->dims().size(), 4,
"Filter must be with 4 dimensions, i.e. OIHW"); "Filter must be with 4 dimensions, i.e. OIHW");
if (bias) { if (bias) {
PADDLE_ENFORCE(bias->layout() == DataLayout::kMKLDNN && PADDLE_ENFORCE_EQ(bias->layout(), DataLayout::kMKLDNN,
bias->format() != mkldnn::memory::format::format_undef, "Wrong layout set for Bias tensor");
"Wrong layout/format set for Bias tensor"); PADDLE_ENFORCE_NE(bias->format(), MKLDNNMemoryFormat::format_undef,
PADDLE_ENFORCE(bias->dims().size() == 1, "Wrong format set for Bias tensor");
PADDLE_ENFORCE_EQ(bias->dims().size(), 1,
"Bias must only have 1 dimension, i.e. X"); "Bias must only have 1 dimension, i.e. X");
} }
...@@ -129,10 +135,9 @@ class ConvTransposeMKLDNNOpKernel : public paddle::framework::OpKernel<T> { ...@@ -129,10 +135,9 @@ class ConvTransposeMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
auto user_src_md = platform::MKLDNNMemDesc( auto user_src_md = platform::MKLDNNMemDesc(
{src_tz}, platform::MKLDNNGetDataType<T>(), input->format()); {src_tz}, platform::MKLDNNGetDataType<T>(), input->format());
auto user_weights_md = auto user_weights_md = platform::MKLDNNMemDesc(
platform::MKLDNNMemDesc({weights_tz}, platform::MKLDNNGetDataType<T>(), {weights_tz}, platform::MKLDNNGetDataType<T>(),
(g == 1) ? mkldnn::memory::format::oihw (g == 1) ? MKLDNNMemoryFormat::oihw : MKLDNNMemoryFormat::goihw);
: mkldnn::memory::format::goihw);
/* create memory descriptor for convolution without specified format /* create memory descriptor for convolution without specified format
* ('any') which lets a primitive (convolution in this case) choose * ('any') which lets a primitive (convolution in this case) choose
...@@ -163,7 +168,7 @@ class ConvTransposeMKLDNNOpKernel : public paddle::framework::OpKernel<T> { ...@@ -163,7 +168,7 @@ class ConvTransposeMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
if (bias) { if (bias) {
bias_tz = paddle::framework::vectorize2int(bias->dims()); bias_tz = paddle::framework::vectorize2int(bias->dims());
auto bias_md = platform::MKLDNNMemDesc( auto bias_md = platform::MKLDNNMemDesc(
bias_tz, platform::MKLDNNGetDataType<T>(), mkldnn::memory::format::x); bias_tz, platform::MKLDNNGetDataType<T>(), MKLDNNMemoryFormat::x);
conv_transpose_pd = handler.AcquireConvolutionPrimitiveDescriptor( conv_transpose_pd = handler.AcquireConvolutionPrimitiveDescriptor(
src_md, weights_md, bias_md, dst_md, strides, paddings, mkldnn_engine, src_md, weights_md, bias_md, dst_md, strides, paddings, mkldnn_engine,
fuse_activation, fuse_alpha, fuse_beta, false, fwd_prop_kind); fuse_activation, fuse_alpha, fuse_beta, false, fwd_prop_kind);
...@@ -198,9 +203,8 @@ class ConvTransposeMKLDNNOpKernel : public paddle::framework::OpKernel<T> { ...@@ -198,9 +203,8 @@ class ConvTransposeMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
std::shared_ptr<mkldnn::deconvolution_forward> conv_p; std::shared_ptr<mkldnn::deconvolution_forward> conv_p;
if (bias) { if (bias) {
const T* bias_data = bias->data<T>(); const T* bias_data = bias->data<T>();
auto user_bias_md = auto user_bias_md = platform::MKLDNNMemDesc(
platform::MKLDNNMemDesc({bias_tz}, platform::MKLDNNGetDataType<T>(), {bias_tz}, platform::MKLDNNGetDataType<T>(), MKLDNNMemoryFormat::x);
mkldnn::memory::format::x);
auto user_bias_memory_p = handler.AcquireBiasMemory( auto user_bias_memory_p = handler.AcquireBiasMemory(
user_bias_md, platform::to_void_cast<T>(bias_data)); user_bias_md, platform::to_void_cast<T>(bias_data));
......
...@@ -63,7 +63,7 @@ class DeQuantOpKernel : public framework::OpKernel<T> { ...@@ -63,7 +63,7 @@ class DeQuantOpKernel : public framework::OpKernel<T> {
std::vector<int> dst_tz = paddle::framework::vectorize2int(output->dims()); std::vector<int> dst_tz = paddle::framework::vectorize2int(output->dims());
mkldnn::memory::data_type src_dt = mkldnn::memory::data_type src_dt =
paddle::framework::ToMKLDNNDataType(input->type()); paddle::framework::ToMKLDNNDataType(input->type());
mkldnn::memory::format src_fmt = input->format(); MKLDNNMemoryFormat src_fmt = input->format();
std::string key = CreateKey(ctx, src_dt, src_tz, reorder_scale[0]); std::string key = CreateKey(ctx, src_dt, src_tz, reorder_scale[0]);
const std::string key_prim = key + "@reorder_p"; const std::string key_prim = key + "@reorder_p";
const std::string key_src_mem = key + "@src_mem"; const std::string key_src_mem = key + "@src_mem";
...@@ -87,7 +87,7 @@ class DeQuantOpKernel : public framework::OpKernel<T> { ...@@ -87,7 +87,7 @@ class DeQuantOpKernel : public framework::OpKernel<T> {
std::shared_ptr<primitive::at>(new primitive::at(*src_memory)); std::shared_ptr<primitive::at>(new primitive::at(*src_memory));
auto dst_md = platform::MKLDNNMemDesc({dst_tz}, memory::data_type::f32, auto dst_md = platform::MKLDNNMemDesc({dst_tz}, memory::data_type::f32,
memory::format::nchw); MKLDNNMemoryFormat::nchw);
auto dst_pd = mkldnn::memory::primitive_desc(dst_md, engine); auto dst_pd = mkldnn::memory::primitive_desc(dst_md, engine);
dst_memory = std::make_shared<mkldnn::memory>( dst_memory = std::make_shared<mkldnn::memory>(
dst_pd, to_void_cast<float>(output_data)); dst_pd, to_void_cast<float>(output_data));
......
...@@ -59,7 +59,7 @@ class FCPrimitiveFactory { ...@@ -59,7 +59,7 @@ class FCPrimitiveFactory {
weights_ = CreateFourDimWeightsMemory(input, weights); weights_ = CreateFourDimWeightsMemory(input, weights);
} }
auto dst_desc = CreateMemDescriptor(output, memory::format::any); auto dst_desc = CreateMemDescriptor(output, MKLDNNMemoryFormat::any);
fc_ = CreateFcPrimitive(*input_, *weights_, dst_desc, bias, output, ctx); fc_ = CreateFcPrimitive(*input_, *weights_, dst_desc, bias, output, ctx);
return *fc_; return *fc_;
...@@ -70,14 +70,14 @@ class FCPrimitiveFactory { ...@@ -70,14 +70,14 @@ class FCPrimitiveFactory {
const Tensor* in) { const Tensor* in) {
input_->set_data_handle(const_cast<T*>(in->data<T>())); input_->set_data_handle(const_cast<T*>(in->data<T>()));
output_->set_data_handle(out->mutable_data<T>(ctx.GetPlace())); output_->set_data_handle(out->mutable_data<T>(ctx.GetPlace()));
if (out->format() == memory::format::format_undef) { if (out->format() == MKLDNNMemoryFormat::format_undef) {
auto output_format = output_->get_primitive_desc().desc().data.format; auto output_format = output_->get_primitive_desc().desc().data.format;
out->set_format((memory::format)output_format); out->set_format((MKLDNNMemoryFormat)output_format);
} }
} }
memory::format MatchWeightFormat(memory::format fmt) { MKLDNNMemoryFormat MatchWeightFormat(MKLDNNMemoryFormat fmt) {
using format = memory::format; using format = MKLDNNMemoryFormat;
switch (fmt) { switch (fmt) {
case format::nChw16c: case format::nChw16c:
return format::oIhw16i; return format::oIhw16i;
...@@ -102,13 +102,13 @@ class FCPrimitiveFactory { ...@@ -102,13 +102,13 @@ class FCPrimitiveFactory {
} }
static mkldnn::memory::desc CreateMemDescriptor(const std::vector<int>& dims, static mkldnn::memory::desc CreateMemDescriptor(const std::vector<int>& dims,
memory::format format) { MKLDNNMemoryFormat format) {
return platform::MKLDNNMemDesc(dims, platform::MKLDNNGetDataType<T>(), return platform::MKLDNNMemDesc(dims, platform::MKLDNNGetDataType<T>(),
format); format);
} }
static mkldnn::memory::desc CreateMemDescriptor(const Tensor* tensor, static mkldnn::memory::desc CreateMemDescriptor(const Tensor* tensor,
memory::format format) { MKLDNNMemoryFormat format) {
auto dims = framework::vectorize2int(tensor->dims()); auto dims = framework::vectorize2int(tensor->dims());
return CreateMemDescriptor(dims, format); return CreateMemDescriptor(dims, format);
} }
...@@ -126,8 +126,8 @@ class FCPrimitiveFactory { ...@@ -126,8 +126,8 @@ class FCPrimitiveFactory {
mkldnn::memory TransposeWeights(const Tensor* weights) { mkldnn::memory TransposeWeights(const Tensor* weights) {
auto dims = framework::vectorize2int(weights->dims()); auto dims = framework::vectorize2int(weights->dims());
std::swap(dims[0], dims[1]); // Correct output dimensions std::swap(dims[0], dims[1]); // Correct output dimensions
auto src_desc = CreateMemDescriptor(dims, memory::format::io); auto src_desc = CreateMemDescriptor(dims, MKLDNNMemoryFormat::io);
auto dst_desc = CreateMemDescriptor(dims, memory::format::oi); auto dst_desc = CreateMemDescriptor(dims, MKLDNNMemoryFormat::oi);
return Reorder(src_desc, dst_desc, weights->data<T>()); return Reorder(src_desc, dst_desc, weights->data<T>());
} }
...@@ -187,7 +187,7 @@ class FCPrimitiveFactory { ...@@ -187,7 +187,7 @@ class FCPrimitiveFactory {
auto dims = {weight_dims[1], input_dims[1], input_dims[2], input_dims[3]}; auto dims = {weight_dims[1], input_dims[1], input_dims[2], input_dims[3]};
auto dst_format = MatchWeightFormat(input->format()); auto dst_format = MatchWeightFormat(input->format());
auto src_desc = CreateMemDescriptor(dims, memory::format::oihw); auto src_desc = CreateMemDescriptor(dims, MKLDNNMemoryFormat::oihw);
auto dst_desc = CreateMemDescriptor(dims, dst_format); auto dst_desc = CreateMemDescriptor(dims, dst_format);
return Reorder(src_desc, dst_desc, weights_->get_data_handle()); return Reorder(src_desc, dst_desc, weights_->get_data_handle());
...@@ -199,7 +199,7 @@ class FCPrimitiveFactory { ...@@ -199,7 +199,7 @@ class FCPrimitiveFactory {
auto dst_prim_desc = fc_prim_desc.dst_primitive_desc(); auto dst_prim_desc = fc_prim_desc.dst_primitive_desc();
auto buffer_size = dst_prim_desc.get_size(); auto buffer_size = dst_prim_desc.get_size();
T* output_data = output->mutable_data<T>(ctx.GetPlace(), buffer_size); T* output_data = output->mutable_data<T>(ctx.GetPlace(), buffer_size);
output->set_format((memory::format)dst_prim_desc.desc().data.format); output->set_format((MKLDNNMemoryFormat)dst_prim_desc.desc().data.format);
return memory(dst_prim_desc, to_void_cast<T>(output_data)); return memory(dst_prim_desc, to_void_cast<T>(output_data));
} }
......
...@@ -62,10 +62,10 @@ class MulPrimitiveFactory { ...@@ -62,10 +62,10 @@ class MulPrimitiveFactory {
return *mul_; return *mul_;
} }
auto src_desc = CreateMemDescriptor<XT>(&x_matrix, memory::format::nc); auto src_desc = CreateMemDescriptor<XT>(&x_matrix, MKLDNNMemoryFormat::nc);
x_input_ = CreateMemory<XT>(src_desc, &x_matrix); x_input_ = CreateMemory<XT>(src_desc, &x_matrix);
y_input_ = TransposeInputY(&y_matrix); y_input_ = TransposeInputY(&y_matrix);
auto dst_desc = CreateMemDescriptor<OT>(output, memory::format::any); auto dst_desc = CreateMemDescriptor<OT>(output, MKLDNNMemoryFormat::any);
mul_ = CreateMulPrimitive(*x_input_, *y_input_, dst_desc, output, ctx); mul_ = CreateMulPrimitive(*x_input_, *y_input_, dst_desc, output, ctx);
return *mul_; return *mul_;
...@@ -77,14 +77,14 @@ class MulPrimitiveFactory { ...@@ -77,14 +77,14 @@ class MulPrimitiveFactory {
const ExecutionContext &ctx) { const ExecutionContext &ctx) {
Tensor x_tmp; Tensor x_tmp;
Tensor data_matrix; Tensor data_matrix;
memory::format src_fmt = data->format(); MKLDNNMemoryFormat src_fmt = data->format();
memory::format dst_fmt; MKLDNNMemoryFormat dst_fmt;
auto src_mdesc = CreateMemDescriptor<T>(data, src_fmt); auto src_mdesc = CreateMemDescriptor<T>(data, src_fmt);
if ((data->dims().size() == 4 && if ((data->dims().size() == 4 &&
src_fmt != (dst_fmt = memory::format::nchw)) || src_fmt != (dst_fmt = MKLDNNMemoryFormat::nchw)) ||
(data->dims().size() == 5 && (data->dims().size() == 5 &&
src_fmt != (dst_fmt = memory::format::ncdhw))) { src_fmt != (dst_fmt = MKLDNNMemoryFormat::ncdhw))) {
auto dst_mdesc = CreateMemDescriptor<T>(data, dst_fmt); auto dst_mdesc = CreateMemDescriptor<T>(data, dst_fmt);
x_tmp.mutable_data<T>(ctx.GetPlace(), data->memory_size()); x_tmp.mutable_data<T>(ctx.GetPlace(), data->memory_size());
...@@ -92,7 +92,7 @@ class MulPrimitiveFactory { ...@@ -92,7 +92,7 @@ class MulPrimitiveFactory {
to_void_cast<T>(x_tmp.data<T>())); to_void_cast<T>(x_tmp.data<T>()));
x_tmp.Resize(data->dims()); x_tmp.Resize(data->dims());
x_tmp.set_format((memory::format)dst_mdesc.data.format); x_tmp.set_format((MKLDNNMemoryFormat)dst_mdesc.data.format);
data_matrix = framework::ReshapeToMatrix(x_tmp, num_col_dims); data_matrix = framework::ReshapeToMatrix(x_tmp, num_col_dims);
} else { } else {
data_matrix = framework::ReshapeToMatrix(*data, num_col_dims); data_matrix = framework::ReshapeToMatrix(*data, num_col_dims);
...@@ -106,15 +106,15 @@ class MulPrimitiveFactory { ...@@ -106,15 +106,15 @@ class MulPrimitiveFactory {
x_input_->set_data_handle(to_void_cast<XT>(in->data<XT>())); x_input_->set_data_handle(to_void_cast<XT>(in->data<XT>()));
output_->set_data_handle(out->mutable_data<OT>(ctx.GetPlace())); output_->set_data_handle(out->mutable_data<OT>(ctx.GetPlace()));
if (out->format() == memory::format::format_undef) { if (out->format() == MKLDNNMemoryFormat::format_undef) {
auto output_format = output_->get_primitive_desc().desc().data.format; auto output_format = output_->get_primitive_desc().desc().data.format;
out->set_format((memory::format)output_format); out->set_format((MKLDNNMemoryFormat)output_format);
} }
} }
template <typename T> template <typename T>
memory::desc CreateMemDescriptor( memory::desc CreateMemDescriptor(
const Tensor *tensor, memory::format format, const Tensor *tensor, MKLDNNMemoryFormat format,
memory::data_type type = platform::MKLDNNGetDataType<T>()) { memory::data_type type = platform::MKLDNNGetDataType<T>()) {
auto dims = framework::vectorize2int(tensor->dims()); auto dims = framework::vectorize2int(tensor->dims());
return platform::MKLDNNMemDesc(dims, type, format); return platform::MKLDNNMemDesc(dims, type, format);
...@@ -122,7 +122,7 @@ class MulPrimitiveFactory { ...@@ -122,7 +122,7 @@ class MulPrimitiveFactory {
template <typename T> template <typename T>
memory::desc CreateMemDescriptor( memory::desc CreateMemDescriptor(
const std::vector<int> &dims, memory::format format, const std::vector<int> &dims, MKLDNNMemoryFormat format,
memory::data_type type = platform::MKLDNNGetDataType<T>()) { memory::data_type type = platform::MKLDNNGetDataType<T>()) {
return platform::MKLDNNMemDesc(dims, type, format); return platform::MKLDNNMemDesc(dims, type, format);
} }
...@@ -139,7 +139,7 @@ class MulPrimitiveFactory { ...@@ -139,7 +139,7 @@ class MulPrimitiveFactory {
auto buffer_size = dst_prim_desc.get_size(); auto buffer_size = dst_prim_desc.get_size();
OT *output_data = output->mutable_data<OT>(ctx.GetPlace(), buffer_size); OT *output_data = output->mutable_data<OT>(ctx.GetPlace(), buffer_size);
output->set_format((memory::format)dst_prim_desc.desc().data.format); output->set_format((MKLDNNMemoryFormat)dst_prim_desc.desc().data.format);
return memory(dst_prim_desc, to_void_cast<OT>(output_data)); return memory(dst_prim_desc, to_void_cast<OT>(output_data));
} }
...@@ -158,8 +158,8 @@ class MulPrimitiveFactory { ...@@ -158,8 +158,8 @@ class MulPrimitiveFactory {
memory TransposeInputY(const Tensor *input_y) { memory TransposeInputY(const Tensor *input_y) {
auto dims = framework::vectorize2int(input_y->dims()); auto dims = framework::vectorize2int(input_y->dims());
std::swap(dims[0], dims[1]); // Correct output dimensions std::swap(dims[0], dims[1]); // Correct output dimensions
auto src_desc = CreateMemDescriptor<YT>(dims, memory::format::io); auto src_desc = CreateMemDescriptor<YT>(dims, MKLDNNMemoryFormat::io);
auto dst_desc = CreateMemDescriptor<YT>(dims, memory::format::oi); auto dst_desc = CreateMemDescriptor<YT>(dims, MKLDNNMemoryFormat::oi);
return Reorder(src_desc, dst_desc, to_void_cast<YT>(input_y->data<YT>())); return Reorder(src_desc, dst_desc, to_void_cast<YT>(input_y->data<YT>()));
} }
...@@ -230,15 +230,15 @@ class QuantMulPrimitiveFactory : public MulPrimitiveFactory<XT, YT, OT> { ...@@ -230,15 +230,15 @@ class QuantMulPrimitiveFactory : public MulPrimitiveFactory<XT, YT, OT> {
return *(this->mul_); return *(this->mul_);
} }
auto src_desc = auto src_desc = this->template CreateMemDescriptor<XT>(
this->template CreateMemDescriptor<XT>(&x_matrix, memory::format::nc); &x_matrix, MKLDNNMemoryFormat::nc);
this->x_input_ = this->template CreateMemory<XT>(src_desc, &x_matrix); this->x_input_ = this->template CreateMemory<XT>(src_desc, &x_matrix);
const auto trans_y = this->TransposeInputY(&y_matrix); const auto trans_y = this->TransposeInputY(&y_matrix);
this->y_input_ = QuantInputY(trans_y, scale_y); this->y_input_ = QuantInputY(trans_y, scale_y);
auto dst_desc = auto dst_desc =
this->template CreateMemDescriptor<OT>(output, memory::format::any); this->template CreateMemDescriptor<OT>(output, MKLDNNMemoryFormat::any);
this->mul_ = CreateMulPrimitive(*(this->x_input_), *(this->y_input_), this->mul_ = CreateMulPrimitive(*(this->x_input_), *(this->y_input_),
dst_desc, output, ctx); dst_desc, output, ctx);
...@@ -270,9 +270,9 @@ class QuantMulPrimitiveFactory : public MulPrimitiveFactory<XT, YT, OT> { ...@@ -270,9 +270,9 @@ class QuantMulPrimitiveFactory : public MulPrimitiveFactory<XT, YT, OT> {
auto y_dims = std::vector<int>(dims, dims + ndims); auto y_dims = std::vector<int>(dims, dims + ndims);
auto user_y_desc = auto user_y_desc =
this->template CreateMemDescriptor<YT>(y_dims, memory::format::oi); this->template CreateMemDescriptor<YT>(y_dims, MKLDNNMemoryFormat::oi);
auto y_desc = auto y_desc = this->template CreateMemDescriptor<int8_t>(
this->template CreateMemDescriptor<int8_t>(y_dims, memory::format::oi); y_dims, MKLDNNMemoryFormat::oi);
return ReorderWithScale(user_y_desc, y_desc, input_y.get_data_handle(), return ReorderWithScale(user_y_desc, y_desc, input_y.get_data_handle(),
scale_y); scale_y);
......
...@@ -42,9 +42,10 @@ class PoolMKLDNNOpKernel : public paddle::framework::OpKernel<T> { ...@@ -42,9 +42,10 @@ class PoolMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
const Tensor* input = ctx.Input<Tensor>("X"); const Tensor* input = ctx.Input<Tensor>("X");
Tensor* output = ctx.Output<Tensor>("Out"); Tensor* output = ctx.Output<Tensor>("Out");
PADDLE_ENFORCE(input->layout() == DataLayout::kMKLDNN && PADDLE_ENFORCE_EQ(input->layout(), DataLayout::kMKLDNN,
input->format() != memory::format::format_undef, "Wrong layout set for Input tensor");
"Wrong layout/format set for Input tensor"); PADDLE_ENFORCE_NE(input->format(), MKLDNNMemoryFormat::format_undef,
"Wrong format set for Input tensor");
std::string pooling_type = ctx.Attr<std::string>("pooling_type"); std::string pooling_type = ctx.Attr<std::string>("pooling_type");
std::vector<int> ksize = ctx.Attr<std::vector<int>>("ksize"); std::vector<int> ksize = ctx.Attr<std::vector<int>>("ksize");
...@@ -72,7 +73,7 @@ class PoolMKLDNNOpKernel : public paddle::framework::OpKernel<T> { ...@@ -72,7 +73,7 @@ class PoolMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
std::vector<int> dst_tz = paddle::framework::vectorize2int(output->dims()); std::vector<int> dst_tz = paddle::framework::vectorize2int(output->dims());
auto input_format = input->format(); auto input_format = input->format();
memory::format output_format{memory::format::format_undef}; MKLDNNMemoryFormat output_format{MKLDNNMemoryFormat::format_undef};
mkldnn::memory::data_type dt = mkldnn::memory::data_type dt =
paddle::framework::ToMKLDNNDataType(input->type()); paddle::framework::ToMKLDNNDataType(input->type());
...@@ -95,8 +96,7 @@ class PoolMKLDNNOpKernel : public paddle::framework::OpKernel<T> { ...@@ -95,8 +96,7 @@ class PoolMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
* ('any') which lets a primitive (pooling in this case) choose * ('any') which lets a primitive (pooling in this case) choose
* the memory format preferred for best performance * the memory format preferred for best performance
*/ */
auto dst_md = auto dst_md = platform::MKLDNNMemDesc(dst_tz, dt, MKLDNNMemoryFormat::any);
platform::MKLDNNMemDesc(dst_tz, dt, mkldnn::memory::format::any);
auto pooling_pd = handler.AcquirePoolingPrimitiveDescriptor( auto pooling_pd = handler.AcquirePoolingPrimitiveDescriptor(
src_tz, dst_tz, src_md, dst_md, ksize, strides, paddings, src_tz, dst_tz, src_md, dst_md, ksize, strides, paddings,
...@@ -112,7 +112,7 @@ class PoolMKLDNNOpKernel : public paddle::framework::OpKernel<T> { ...@@ -112,7 +112,7 @@ class PoolMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
stream(stream::kind::eager).submit(pipeline).wait(); stream(stream::kind::eager).submit(pipeline).wait();
output_format = output_format =
(memory::format)dst_memory->get_primitive_desc().desc().data.format; (MKLDNNMemoryFormat)dst_memory->get_primitive_desc().desc().data.format;
output->set_layout(DataLayout::kMKLDNN); output->set_layout(DataLayout::kMKLDNN);
output->set_format(output_format); output->set_format(output_format);
...@@ -130,15 +130,18 @@ class PoolMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> { ...@@ -130,15 +130,18 @@ class PoolMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
const Tensor* out_grad = ctx.Input<Tensor>(framework::GradVarName("Out")); const Tensor* out_grad = ctx.Input<Tensor>(framework::GradVarName("Out"));
Tensor* in_x_grad = ctx.Output<Tensor>(framework::GradVarName("X")); Tensor* in_x_grad = ctx.Output<Tensor>(framework::GradVarName("X"));
PADDLE_ENFORCE(in_x->layout() == DataLayout::kMKLDNN && PADDLE_ENFORCE_EQ(in_x->layout(), DataLayout::kMKLDNN,
in_x->format() != memory::format::format_undef, "Wrong layout set for Input tensor");
"Wrong layout/format set for Input X tensor"); PADDLE_ENFORCE_NE(in_x->format(), MKLDNNMemoryFormat::format_undef,
PADDLE_ENFORCE(out_grad->layout() == DataLayout::kMKLDNN && "Wrong format set for Input tensor");
out_grad->format() != memory::format::format_undef,
"Wrong layout/format set for Input output_grad tensor");
PADDLE_ENFORCE( PADDLE_ENFORCE_EQ(out_grad->layout(), DataLayout::kMKLDNN,
!ctx.Attr<bool>("is_test"), "Wrong layout set for Input output_grad tensor");
PADDLE_ENFORCE_NE(out_grad->format(), MKLDNNMemoryFormat::format_undef,
"Wrong format set for Input output_grad tensor");
PADDLE_ENFORCE_EQ(
ctx.Attr<bool>("is_test"), false,
"is_test attribute should be set to False in training phase."); "is_test attribute should be set to False in training phase.");
std::string pooling_type = ctx.Attr<std::string>("pooling_type"); std::string pooling_type = ctx.Attr<std::string>("pooling_type");
...@@ -161,7 +164,7 @@ class PoolMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> { ...@@ -161,7 +164,7 @@ class PoolMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
const T* out_grad_data = out_grad->data<T>(); const T* out_grad_data = out_grad->data<T>();
T* in_x_grad_data = in_x_grad->mutable_data<T>(ctx.GetPlace()); T* in_x_grad_data = in_x_grad->mutable_data<T>(ctx.GetPlace());
memory::format in_x_grad_format{memory::format::format_undef}; MKLDNNMemoryFormat in_x_grad_format{MKLDNNMemoryFormat::format_undef};
std::vector<int> diff_src_tz = std::vector<int> diff_src_tz =
paddle::framework::vectorize2int(in_x_grad->dims()); paddle::framework::vectorize2int(in_x_grad->dims());
...@@ -186,9 +189,8 @@ class PoolMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> { ...@@ -186,9 +189,8 @@ class PoolMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
auto diff_dst_memory = handler.AcquireDiffDstMemory( auto diff_dst_memory = handler.AcquireDiffDstMemory(
diff_dst_md, to_void_cast<T>(out_grad_data)); diff_dst_md, to_void_cast<T>(out_grad_data));
auto diff_src_md = auto diff_src_md = platform::MKLDNNMemDesc(
platform::MKLDNNMemDesc(diff_src_tz, platform::MKLDNNGetDataType<T>(), diff_src_tz, platform::MKLDNNGetDataType<T>(), MKLDNNMemoryFormat::any);
mkldnn::memory::format::any);
auto bwd_pd = handler.AcquirePoolingBackwardPrimitiveDescriptor( auto bwd_pd = handler.AcquirePoolingBackwardPrimitiveDescriptor(
diff_dst_md, diff_src_md, ksize, strides, paddings); diff_dst_md, diff_src_md, ksize, strides, paddings);
...@@ -202,7 +204,7 @@ class PoolMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> { ...@@ -202,7 +204,7 @@ class PoolMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
pipeline.push_back(*pool_bwd_p); pipeline.push_back(*pool_bwd_p);
mkldnn::stream(mkldnn::stream::kind::eager).submit(pipeline).wait(); mkldnn::stream(mkldnn::stream::kind::eager).submit(pipeline).wait();
in_x_grad_format = (memory::format)diff_src_memory->get_primitive_desc() in_x_grad_format = (MKLDNNMemoryFormat)diff_src_memory->get_primitive_desc()
.desc() .desc()
.data.format; .data.format;
in_x_grad->set_layout(DataLayout::kMKLDNN); in_x_grad->set_layout(DataLayout::kMKLDNN);
......
...@@ -48,8 +48,8 @@ class ReQuantOpKernel : public framework::OpKernel<T> { ...@@ -48,8 +48,8 @@ class ReQuantOpKernel : public framework::OpKernel<T> {
mkldnn::memory::data_type src_dt = mkldnn::memory::data_type src_dt =
paddle::framework::ToMKLDNNDataType(input->type()); paddle::framework::ToMKLDNNDataType(input->type());
mkldnn::memory::data_type dst_dt = src_dt; mkldnn::memory::data_type dst_dt = src_dt;
mkldnn::memory::format src_fmt = memory::format::nhwc; MKLDNNMemoryFormat src_fmt = MKLDNNMemoryFormat::nhwc;
mkldnn::memory::format dst_fmt = memory::format::nhwc; MKLDNNMemoryFormat dst_fmt = MKLDNNMemoryFormat::nhwc;
const T* input_data = input->data<T>(); const T* input_data = input->data<T>();
T* output_data = output->mutable_data<T>(ctx.GetPlace()); T* output_data = output->mutable_data<T>(ctx.GetPlace());
......
...@@ -36,7 +36,7 @@ template <typename T> ...@@ -36,7 +36,7 @@ template <typename T>
class SoftmaxMKLDNNHandler : public platform::MKLDNNHandler { class SoftmaxMKLDNNHandler : public platform::MKLDNNHandler {
public: public:
SoftmaxMKLDNNHandler(const std::vector<int>& dims, SoftmaxMKLDNNHandler(const std::vector<int>& dims,
const mkldnn::memory::format fmt, const MKLDNNMemoryFormat fmt,
const platform::MKLDNNDeviceContext& dev_ctx, const platform::MKLDNNDeviceContext& dev_ctx,
mkldnn::engine engine, const std::string& base_key) mkldnn::engine engine, const std::string& base_key)
: platform::MKLDNNHandler(dev_ctx, engine, base_key), : platform::MKLDNNHandler(dev_ctx, engine, base_key),
...@@ -44,8 +44,8 @@ class SoftmaxMKLDNNHandler : public platform::MKLDNNHandler { ...@@ -44,8 +44,8 @@ class SoftmaxMKLDNNHandler : public platform::MKLDNNHandler {
fmt_(fmt) {} fmt_(fmt) {}
SoftmaxMKLDNNHandler(const std::vector<int>& dims, SoftmaxMKLDNNHandler(const std::vector<int>& dims,
const mkldnn::memory::format fmt, const MKLDNNMemoryFormat fmt,
const mkldnn::memory::format diff_fmt, const MKLDNNMemoryFormat diff_fmt,
const platform::MKLDNNDeviceContext& dev_ctx, const platform::MKLDNNDeviceContext& dev_ctx,
mkldnn::engine engine, const std::string& base_key) mkldnn::engine engine, const std::string& base_key)
: platform::MKLDNNHandler(dev_ctx, engine, base_key), : platform::MKLDNNHandler(dev_ctx, engine, base_key),
...@@ -165,8 +165,8 @@ class SoftmaxMKLDNNHandler : public platform::MKLDNNHandler { ...@@ -165,8 +165,8 @@ class SoftmaxMKLDNNHandler : public platform::MKLDNNHandler {
private: private:
std::vector<int> dims_; std::vector<int> dims_;
mkldnn::memory::format fmt_; MKLDNNMemoryFormat fmt_;
mkldnn::memory::format diff_fmt_; MKLDNNMemoryFormat diff_fmt_;
std::shared_ptr<mkldnn::softmax_forward::primitive_desc> fwd_pd_; std::shared_ptr<mkldnn::softmax_forward::primitive_desc> fwd_pd_;
}; };
...@@ -207,8 +207,8 @@ class SoftmaxMKLDNNKernel : public paddle::framework::OpKernel<T> { ...@@ -207,8 +207,8 @@ class SoftmaxMKLDNNKernel : public paddle::framework::OpKernel<T> {
const std::string key = const std::string key =
platform::MKLDNNHandler::GetHash(softmax_tz, ctx.op().Output("Out")); platform::MKLDNNHandler::GetHash(softmax_tz, ctx.op().Output("Out"));
SoftmaxMKLDNNHandler<T> handler(softmax_tz, mkldnn::memory::format::nc, SoftmaxMKLDNNHandler<T> handler(softmax_tz, MKLDNNMemoryFormat::nc, dev_ctx,
dev_ctx, mkldnn_engine, key); mkldnn_engine, key);
// Currently only NC data format is supported // Currently only NC data format is supported
auto softmax_src_memory_p = auto softmax_src_memory_p =
...@@ -288,8 +288,8 @@ class SoftmaxMKLDNNGradKernel : public paddle::framework::OpKernel<T> { ...@@ -288,8 +288,8 @@ class SoftmaxMKLDNNGradKernel : public paddle::framework::OpKernel<T> {
// TODO(jczaja): Add layouts support when there is a need to do so // TODO(jczaja): Add layouts support when there is a need to do so
// Two dimensional softmax does support NC format // Two dimensional softmax does support NC format
// Normalization is made after innermost dimension eg. C out of NC // Normalization is made after innermost dimension eg. C out of NC
SoftmaxMKLDNNHandler<T> handler(softmax_tz, mkldnn::memory::format::nc, SoftmaxMKLDNNHandler<T> handler(softmax_tz, MKLDNNMemoryFormat::nc,
mkldnn::memory::format::nc, dev_ctx, MKLDNNMemoryFormat::nc, dev_ctx,
mkldnn_engine, key); mkldnn_engine, key);
auto dst_memory_p = handler.AcquireDstMemory(to_void_cast<T>(dst_data)); auto dst_memory_p = handler.AcquireDstMemory(to_void_cast<T>(dst_data));
......
...@@ -65,27 +65,29 @@ class SumMKLDNNOpKernel : public paddle::framework::OpKernel<T> { ...@@ -65,27 +65,29 @@ class SumMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
std::vector<int> dst_tz = framework::vectorize2int(output->dims()); std::vector<int> dst_tz = framework::vectorize2int(output->dims());
auto src_tz = dst_tz; auto src_tz = dst_tz;
memory::format output_format{memory::format::format_undef}; MKLDNNMemoryFormat output_format{MKLDNNMemoryFormat::format_undef};
std::vector<float> scales; std::vector<float> scales;
std::vector<memory::primitive_desc> srcs_mpd; std::vector<memory::primitive_desc> srcs_mpd;
std::vector<mkldnn::memory> srcs_mem; std::vector<mkldnn::memory> srcs_mem;
PADDLE_ENFORCE(in_vars[0]->IsType<LoDTensor>(), PADDLE_ENFORCE_EQ(in_vars[0]->IsType<LoDTensor>(), true,
"Input[0] must be LoDTensors"); "Input[0] must be LoDTensors");
auto& input0 = in_vars[0]->Get<LoDTensor>(); auto& input0 = in_vars[0]->Get<LoDTensor>();
PADDLE_ENFORCE(input0.layout() == DataLayout::kMKLDNN && PADDLE_ENFORCE_EQ(input0.layout(), DataLayout::kMKLDNN,
input0.format() != memory::format::format_undef, "Wrong layout set for inputs[0] tensor");
"Wrong layout/format for inputs[0]"); PADDLE_ENFORCE_NE(input0.format(), MKLDNNMemoryFormat::format_undef,
"Wrong format set for inputs[0] tensor");
memory::format input_format = input0.format(); MKLDNNMemoryFormat input_format = input0.format();
for (int i = 0; i < N; i++) { for (int i = 0; i < N; i++) {
PADDLE_ENFORCE(in_vars[i]->IsType<LoDTensor>(), PADDLE_ENFORCE_EQ(in_vars[i]->IsType<LoDTensor>(), true,
"all inputs must be all LoDTensors"); "all inputs must be all LoDTensors");
auto& input = in_vars[i]->Get<LoDTensor>(); auto& input = in_vars[i]->Get<LoDTensor>();
PADDLE_ENFORCE(input.layout() == DataLayout::kMKLDNN && PADDLE_ENFORCE_EQ(input.layout(), DataLayout::kMKLDNN,
input.format() != memory::format::format_undef, "Wrong layout set for inputs");
"Wrong layout/format for inputs"); PADDLE_ENFORCE_NE(input.format(), MKLDNNMemoryFormat::format_undef,
"Wrong format set for inputs");
if (input.numel() == 0) { if (input.numel() == 0) {
continue; continue;
...@@ -103,7 +105,7 @@ class SumMKLDNNOpKernel : public paddle::framework::OpKernel<T> { ...@@ -103,7 +105,7 @@ class SumMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
} }
auto dst_md = auto dst_md =
memory::desc(dst_tz, memory::data_type::f32, memory::format::any); memory::desc(dst_tz, memory::data_type::f32, MKLDNNMemoryFormat::any);
auto sum_pd = sum::primitive_desc(dst_md, scales, srcs_mpd); auto sum_pd = sum::primitive_desc(dst_md, scales, srcs_mpd);
...@@ -119,7 +121,7 @@ class SumMKLDNNOpKernel : public paddle::framework::OpKernel<T> { ...@@ -119,7 +121,7 @@ class SumMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
} }
auto sum_prim = mkldnn::sum(sum_pd, inputs, *dst_mem); auto sum_prim = mkldnn::sum(sum_pd, inputs, *dst_mem);
output_format = (memory::format)platform::GetMKLDNNFormat(sum_pd); output_format = (MKLDNNMemoryFormat)platform::GetMKLDNNFormat(sum_pd);
primitive reorder_prim; primitive reorder_prim;
std::shared_ptr<memory> target_mem; std::shared_ptr<memory> target_mem;
......
...@@ -64,7 +64,7 @@ class TransposeMKLDNNOpKernel : public paddle::framework::OpKernel<T> { ...@@ -64,7 +64,7 @@ class TransposeMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
mkldnn::stream(mkldnn::stream::kind::eager).submit(pipeline).wait(); mkldnn::stream(mkldnn::stream::kind::eager).submit(pipeline).wait();
output->set_layout(DataLayout::kNCHW); output->set_layout(DataLayout::kNCHW);
output->set_format(mkldnn::memory::format::format_undef); output->set_format(MKLDNNMemoryFormat::format_undef);
} }
}; };
......
...@@ -21,6 +21,9 @@ limitations under the License. */ ...@@ -21,6 +21,9 @@ limitations under the License. */
#include "paddle/fluid/framework/operator.h" #include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/platform/place.h" #include "paddle/fluid/platform/place.h"
namespace paddle { namespace paddle {
#ifdef PADDLE_WITH_MKLDNN
using MKLDNNMemoryFormat = mkldnn::memory::format;
#endif
namespace platform { namespace platform {
using MKLDNNStream = mkldnn::stream; using MKLDNNStream = mkldnn::stream;
...@@ -69,7 +72,7 @@ tf_pd<Type> MKLDNNBwdPrimitiveDesc(const Engine& e, const Primitive& p, ...@@ -69,7 +72,7 @@ tf_pd<Type> MKLDNNBwdPrimitiveDesc(const Engine& e, const Primitive& p,
inline mkldnn::memory::desc MKLDNNMemDesc(const std::vector<int>& dims, inline mkldnn::memory::desc MKLDNNMemDesc(const std::vector<int>& dims,
mkldnn::memory::data_type data_type, mkldnn::memory::data_type data_type,
mkldnn::memory::format format) { MKLDNNMemoryFormat format) {
mkldnn::memory::dims tz = dims; mkldnn::memory::dims tz = dims;
return mkldnn::memory::desc({tz}, data_type, format); return mkldnn::memory::desc({tz}, data_type, format);
} }
...@@ -108,71 +111,71 @@ inline void Reorder(const mkldnn::memory& src, const mkldnn::memory& dst) { ...@@ -108,71 +111,71 @@ inline void Reorder(const mkldnn::memory& src, const mkldnn::memory& dst) {
mkldnn::stream(mkldnn::stream::kind::eager).submit(pipeline).wait(); mkldnn::stream(mkldnn::stream::kind::eager).submit(pipeline).wait();
} }
inline mkldnn::memory::format GetMKLDNNFormat(const mkldnn::memory memory) { inline MKLDNNMemoryFormat GetMKLDNNFormat(const mkldnn::memory memory) {
return static_cast<mkldnn::memory::format>( return static_cast<MKLDNNMemoryFormat>(
memory.get_primitive_desc().desc().data.format); memory.get_primitive_desc().desc().data.format);
} }
inline mkldnn::memory::format GetMKLDNNFormat( inline MKLDNNMemoryFormat GetMKLDNNFormat(
const mkldnn::sum::primitive_desc& memory) { const mkldnn::sum::primitive_desc& memory) {
return static_cast<mkldnn::memory::format>( return static_cast<MKLDNNMemoryFormat>(
memory.dst_primitive_desc().desc().data.format); memory.dst_primitive_desc().desc().data.format);
} }
inline mkldnn::memory::format MKLDNNFormatForSize( inline MKLDNNMemoryFormat MKLDNNFormatForSize(size_t dims_size,
size_t dims_size, mkldnn::memory::format data_format) { MKLDNNMemoryFormat data_format) {
if (dims_size == 1) { if (dims_size == 1) {
return mkldnn::memory::format::x; return MKLDNNMemoryFormat::x;
} else if (dims_size == 2) { } else if (dims_size == 2) {
return mkldnn::memory::format::nc; return MKLDNNMemoryFormat::nc;
} else if (dims_size == 3) { } else if (dims_size == 3) {
if (data_format == mkldnn::memory::format::nchw) { if (data_format == MKLDNNMemoryFormat::nchw) {
return mkldnn::memory::format::ncw; return MKLDNNMemoryFormat::ncw;
} else if (data_format == mkldnn::memory::format::nhwc) { } else if (data_format == MKLDNNMemoryFormat::nhwc) {
return mkldnn::memory::format::nwc; return MKLDNNMemoryFormat::nwc;
} }
} else if (dims_size == 4) { } else if (dims_size == 4) {
if (data_format == mkldnn::memory::format::goihw) { if (data_format == MKLDNNMemoryFormat::goihw) {
return mkldnn::memory::format::oihw; return MKLDNNMemoryFormat::oihw;
} }
} else if (dims_size == 5) { } else if (dims_size == 5) {
if (data_format == mkldnn::memory::format::goidhw) { if (data_format == MKLDNNMemoryFormat::goidhw) {
return mkldnn::memory::format::oidhw; return MKLDNNMemoryFormat::oidhw;
} }
if (data_format == mkldnn::memory::format::nchw) { if (data_format == MKLDNNMemoryFormat::nchw) {
return mkldnn::memory::format::ncdhw; return MKLDNNMemoryFormat::ncdhw;
} else if (data_format == mkldnn::memory::format::nhwc) { } else if (data_format == MKLDNNMemoryFormat::nhwc) {
return mkldnn::memory::format::ndhwc; return MKLDNNMemoryFormat::ndhwc;
} }
} }
return data_format; return data_format;
} }
inline mkldnn::memory::format data_format_to_memory_format( inline MKLDNNMemoryFormat data_format_to_memory_format(
const std::string& data_format) { const std::string& data_format) {
switch (framework::StringToDataLayout(data_format)) { switch (framework::StringToDataLayout(data_format)) {
case framework::DataLayout::kNHWC: case framework::DataLayout::kNHWC:
return mkldnn::memory::format::nhwc; return MKLDNNMemoryFormat::nhwc;
case framework::DataLayout::kNCHW: case framework::DataLayout::kNCHW:
return mkldnn::memory::format::nchw; return MKLDNNMemoryFormat::nchw;
default: default:
return mkldnn::memory::format::any; return MKLDNNMemoryFormat::any;
} }
} }
inline mkldnn::memory::format StringToMKLDNNFormat(std::string* format) { inline MKLDNNMemoryFormat StringToMKLDNNFormat(std::string* format) {
std::transform(format->begin(), format->end(), format->begin(), ::tolower); std::transform(format->begin(), format->end(), format->begin(), ::tolower);
if (!format->compare("nchw")) { if (!format->compare("nchw")) {
return mkldnn::memory::format::nchw; return MKLDNNMemoryFormat::nchw;
} else if (!format->compare("nchw16c")) { } else if (!format->compare("nchw16c")) {
return mkldnn::memory::format::nChw16c; return MKLDNNMemoryFormat::nChw16c;
} else if (!format->compare("nchw8c")) { } else if (!format->compare("nchw8c")) {
return mkldnn::memory::format::nChw8c; return MKLDNNMemoryFormat::nChw8c;
} else if (!format->compare("nhwc")) { } else if (!format->compare("nhwc")) {
return mkldnn::memory::format::nhwc; return MKLDNNMemoryFormat::nhwc;
} else { } else {
return mkldnn::memory::format::any; return MKLDNNMemoryFormat::any;
} }
} }
......
...@@ -121,7 +121,7 @@ class MKLDNNHandler { ...@@ -121,7 +121,7 @@ class MKLDNNHandler {
std::shared_ptr<mkldnn::memory> AcquireMemory( std::shared_ptr<mkldnn::memory> AcquireMemory(
const std::vector<int>& dims, const mkldnn::memory::data_type dtype, const std::vector<int>& dims, const mkldnn::memory::data_type dtype,
const mkldnn::memory::format& fmt, void* ptr, const std::string& suffix) { const MKLDNNMemoryFormat& fmt, void* ptr, const std::string& suffix) {
/*Generate key*/ /*Generate key*/
auto local_key = key_ + suffix; auto local_key = key_ + suffix;
auto mem_p = auto mem_p =
...@@ -236,7 +236,7 @@ class MKLDNNHandler { ...@@ -236,7 +236,7 @@ class MKLDNNHandler {
const mkldnn::memory::dims& weights_dims, const std::vector<int>& strides, const mkldnn::memory::dims& weights_dims, const std::vector<int>& strides,
const std::vector<int>& paddings, const std::vector<int>& dilations, const std::vector<int>& paddings, const std::vector<int>& dilations,
const int& groups, const mkldnn::memory::data_type& srcdt, const int& groups, const mkldnn::memory::data_type& srcdt,
const mkldnn::memory::format& format, const std::string& fuse_activation, const MKLDNNMemoryFormat& format, const std::string& fuse_activation,
const bool& residual, const std::string& suffix) { const bool& residual, const std::string& suffix) {
AppendKeyDims(key, input_dims); AppendKeyDims(key, input_dims);
...@@ -454,9 +454,8 @@ class ActivationMKLDNNHandler : public MKLDNNHandler { ...@@ -454,9 +454,8 @@ class ActivationMKLDNNHandler : public MKLDNNHandler {
static std::string GetHash(const memory::dims& input_dims, static std::string GetHash(const memory::dims& input_dims,
const mkldnn::algorithm algorithm, const mkldnn::algorithm algorithm,
const mkldnn::memory::format fmt, const MKLDNNMemoryFormat fmt, const float alpha,
const float alpha, const float beta, const float beta, const std::string& suffix) {
const std::string& suffix) {
std::string key; std::string key;
key.reserve(platform::MKLDNNHandler::MaxKeyLength); key.reserve(platform::MKLDNNHandler::MaxKeyLength);
platform::MKLDNNHandler::AppendKeyDims(&key, input_dims); platform::MKLDNNHandler::AppendKeyDims(&key, input_dims);
...@@ -606,7 +605,7 @@ class LRNMKLDNNHandler : public MKLDNNHandler { ...@@ -606,7 +605,7 @@ class LRNMKLDNNHandler : public MKLDNNHandler {
static std::string GetHash(const memory::dims& input_dims, const int n, static std::string GetHash(const memory::dims& input_dims, const int n,
const float alpha, const float beta, const float k, const float alpha, const float beta, const float k,
const memory::format& fmt, const MKLDNNMemoryFormat& fmt,
const std::string& suffix) { const std::string& suffix) {
std::string key; std::string key;
key.reserve(platform::MKLDNNHandler::MaxKeyLength); key.reserve(platform::MKLDNNHandler::MaxKeyLength);
...@@ -691,7 +690,7 @@ class PoolingMKLDNNHandler : public MKLDNNHandler { ...@@ -691,7 +690,7 @@ class PoolingMKLDNNHandler : public MKLDNNHandler {
pooling_type_ == "max" pooling_type_ == "max"
? fwd_pd_->workspace_primitive_desc() ? fwd_pd_->workspace_primitive_desc()
: mkldnn::memory::primitive_desc( : mkldnn::memory::primitive_desc(
{{}, dt_, mkldnn::memory::format::nchw}, engine_); {{}, dt_, MKLDNNMemoryFormat::nchw}, engine_);
// Pooling PD has to be passed to Grad op that // Pooling PD has to be passed to Grad op that
// may be executed by different thread, hence // may be executed by different thread, hence
// for that one we use key that does not contain TID // for that one we use key that does not contain TID
...@@ -801,7 +800,7 @@ class PoolingMKLDNNHandler : public MKLDNNHandler { ...@@ -801,7 +800,7 @@ class PoolingMKLDNNHandler : public MKLDNNHandler {
const memory::dims& input_dims, const std::string& pooling_type, const memory::dims& input_dims, const std::string& pooling_type,
const std::vector<int>& ksize, const std::vector<int>& strides, const std::vector<int>& ksize, const std::vector<int>& strides,
const std::vector<int>& paddings, const memory::data_type& dt, const std::vector<int>& paddings, const memory::data_type& dt,
const memory::format& fmt, const std::string& suffix) { const MKLDNNMemoryFormat& fmt, const std::string& suffix) {
std::string key; std::string key;
key.reserve(platform::MKLDNNHandler::MaxKeyLength); key.reserve(platform::MKLDNNHandler::MaxKeyLength);
platform::MKLDNNHandler::AppendKeyDims(&key, input_dims); platform::MKLDNNHandler::AppendKeyDims(&key, input_dims);
...@@ -855,7 +854,7 @@ class TransposeMKLDNNHandler : public MKLDNNHandler { ...@@ -855,7 +854,7 @@ class TransposeMKLDNNHandler : public MKLDNNHandler {
logical_axis_(dims.size(), 0) {} logical_axis_(dims.size(), 0) {}
std::shared_ptr<mkldnn::memory> AcquireSrcMemory( std::shared_ptr<mkldnn::memory> AcquireSrcMemory(
const mkldnn::memory::format& fmt, void* ptr) { const MKLDNNMemoryFormat& fmt, void* ptr) {
auto local_key = key_ + "@user_src_mem_p"; auto local_key = key_ + "@user_src_mem_p";
auto mem_p = auto mem_p =
std::static_pointer_cast<mkldnn::memory>(dev_ctx_.GetBlob(local_key)); std::static_pointer_cast<mkldnn::memory>(dev_ctx_.GetBlob(local_key));
...@@ -865,7 +864,7 @@ class TransposeMKLDNNHandler : public MKLDNNHandler { ...@@ -865,7 +864,7 @@ class TransposeMKLDNNHandler : public MKLDNNHandler {
for (size_t i = 0; i < logical_axis_.size(); ++i) { for (size_t i = 0; i < logical_axis_.size(); ++i) {
logical_axis_[i] = i; logical_axis_[i] = i;
} }
auto src_md = fmt != mkldnn::memory::format::nchw auto src_md = fmt != MKLDNNMemoryFormat::nchw
? platform::MKLDNNMemDesc( ? platform::MKLDNNMemDesc(
dims_, platform::MKLDNNGetDataType<float>(), fmt) dims_, platform::MKLDNNGetDataType<float>(), fmt)
: Axis2MemoryDesc(dims_, logical_axis_); : Axis2MemoryDesc(dims_, logical_axis_);
...@@ -967,12 +966,12 @@ class ReorderMKLDNNHandler : public MKLDNNHandler { ...@@ -967,12 +966,12 @@ class ReorderMKLDNNHandler : public MKLDNNHandler {
dtype_(dtype) {} dtype_(dtype) {}
std::shared_ptr<mkldnn::memory> AcquireSrcMemory( std::shared_ptr<mkldnn::memory> AcquireSrcMemory(
const mkldnn::memory::format& fmt, void* ptr) { const MKLDNNMemoryFormat& fmt, void* ptr) {
return this->AcquireMemory(dims_, dtype_, fmt, ptr, "@user_src_mem_p"); return this->AcquireMemory(dims_, dtype_, fmt, ptr, "@user_src_mem_p");
} }
std::shared_ptr<mkldnn::memory> AcquireDstMemory( std::shared_ptr<mkldnn::memory> AcquireDstMemory(
framework::Tensor* output, const mkldnn::memory::format& fmt, framework::Tensor* output, const MKLDNNMemoryFormat& fmt,
platform::Place place) { platform::Place place) {
auto local_key = key_ + "@user_dst_mem_p"; auto local_key = key_ + "@user_dst_mem_p";
auto mem_p = auto mem_p =
...@@ -1007,8 +1006,8 @@ class ReorderMKLDNNHandler : public MKLDNNHandler { ...@@ -1007,8 +1006,8 @@ class ReorderMKLDNNHandler : public MKLDNNHandler {
} }
static std::string GetHash(std::vector<int>& shape, // NOLINT static std::string GetHash(std::vector<int>& shape, // NOLINT
mkldnn::memory::format in_fmt, MKLDNNMemoryFormat in_fmt,
mkldnn::memory::format out_fmt, MKLDNNMemoryFormat out_fmt,
const std::string& suffix) { const std::string& suffix) {
return dims2str(shape) + std::to_string(in_fmt) + "->" + return dims2str(shape) + std::to_string(in_fmt) + "->" +
std::to_string(out_fmt) + "#" + suffix; std::to_string(out_fmt) + "#" + suffix;
...@@ -1071,8 +1070,8 @@ class ConvMKLDNNTemplateHandler : public MKLDNNHandler { ...@@ -1071,8 +1070,8 @@ class ConvMKLDNNTemplateHandler : public MKLDNNHandler {
return conv_pd_->dst_primitive_desc().get_size(); return conv_pd_->dst_primitive_desc().get_size();
} }
mkldnn::memory::format GetDstFormat() const { MKLDNNMemoryFormat GetDstFormat() const {
return static_cast<mkldnn::memory::format>( return static_cast<MKLDNNMemoryFormat>(
conv_pd_->dst_primitive_desc().desc().data.format); conv_pd_->dst_primitive_desc().desc().data.format);
} }
...@@ -1435,10 +1434,10 @@ static void SetDstMemoryQuantized( ...@@ -1435,10 +1434,10 @@ static void SetDstMemoryQuantized(
std::shared_ptr<mkldnn::memory>& dst_memory) { // NOLINT std::shared_ptr<mkldnn::memory>& dst_memory) { // NOLINT
T* output_data = output->mutable_data<T>(ctx.GetPlace()); T* output_data = output->mutable_data<T>(ctx.GetPlace());
const size_t dst_dims = dst_tz.size(); const size_t dst_dims = dst_tz.size();
memory::format dst_fmt; MKLDNNMemoryFormat dst_fmt;
PADDLE_ENFORCE(dst_dims <= 5, PADDLE_ENFORCE_LE(dst_dims, 5,
"Dst memory for quantization can not have dims > 5"); "Dst memory for quantization can not have dims > 5");
dst_fmt = platform::MKLDNNFormatForSize(dst_dims, memory::format::nhwc); dst_fmt = platform::MKLDNNFormatForSize(dst_dims, MKLDNNMemoryFormat::nhwc);
auto dst_md = platform::MKLDNNMemDesc( auto dst_md = platform::MKLDNNMemDesc(
{dst_tz}, paddle::framework::ToMKLDNNDataType( {dst_tz}, paddle::framework::ToMKLDNNDataType(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册