Unverified commit 98aaf797 authored by J jakpiase, committed by GitHub

Reenabled reshape, squeeze and flatten oneDNN kernels (#48359)

* re-enabled reshape, squeeze and flatten kernels

* added formatting
Parent 11b9d85f
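The commit applies the same oneDNN dispatch to GetExpectedKernelType in each of the flatten, reshape and squeeze operators (forward and grad): when the build has MKLDNN support and CanMKLDNNBeUsed accepts the execution context, an ONEDNN-layout kernel type is returned; otherwise the default CPU kernel type is used. A minimal consolidated sketch of that pattern, with names taken from the diff below (grad ops read the data type from GradVarName("Out") instead of "X"):

framework::OpKernelType GetExpectedKernelType(
    const framework::ExecutionContext &ctx) const override {
  // Data type comes from the operator's input ("X" for forward ops).
  auto input_data_type =
      framework::OperatorWithKernel::IndicateVarDataType(ctx, "X");
#ifdef PADDLE_WITH_MKLDNN
  // Route to the oneDNN kernel only when the runtime context allows it.
  if (this->CanMKLDNNBeUsed(ctx, input_data_type)) {
    return framework::OpKernelType(input_data_type,
                                   ctx.GetPlace(),
                                   phi::DataLayout::ONEDNN,
                                   framework::LibraryType::kMKLDNN);
  }
#endif
  // Fall back to the default (plain layout) CPU kernel.
  return framework::OpKernelType(input_data_type, ctx.GetPlace());
}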
@@ -87,6 +87,16 @@ class FlattenOp : public framework::OperatorWithKernel {
const framework::ExecutionContext &ctx) const override {
auto input_data_type =
framework::OperatorWithKernel::IndicateVarDataType(ctx, "X");
#ifdef PADDLE_WITH_MKLDNN
if (this->CanMKLDNNBeUsed(ctx, input_data_type)) {
return framework::OpKernelType(input_data_type,
ctx.GetPlace(),
phi::DataLayout::ONEDNN,
framework::LibraryType::kMKLDNN);
}
#endif
return framework::OpKernelType(input_data_type, ctx.GetPlace());
}
};
@@ -159,6 +169,16 @@ class FlattenGradOp : public framework::OperatorWithKernel {
const framework::ExecutionContext &ctx) const override {
auto input_data_type = framework::OperatorWithKernel::IndicateVarDataType(
ctx, framework::GradVarName("Out"));
#ifdef PADDLE_WITH_MKLDNN
if (this->CanMKLDNNBeUsed(ctx, input_data_type)) {
return framework::OpKernelType(input_data_type,
ctx.GetPlace(),
phi::DataLayout::ONEDNN,
framework::LibraryType::kMKLDNN);
}
#endif
return framework::OpKernelType(input_data_type, ctx.GetPlace());
}
};
@@ -223,6 +243,16 @@ class Flatten2Op : public framework::OperatorWithKernel {
const framework::ExecutionContext &ctx) const override {
auto input_data_type =
framework::OperatorWithKernel::IndicateVarDataType(ctx, "X");
#ifdef PADDLE_WITH_MKLDNN
if (this->CanMKLDNNBeUsed(ctx, input_data_type)) {
return framework::OpKernelType(input_data_type,
ctx.GetPlace(),
phi::DataLayout::ONEDNN,
framework::LibraryType::kMKLDNN);
}
#endif
return framework::OpKernelType(input_data_type, ctx.GetPlace());
}
};
@@ -275,6 +305,16 @@ class Flatten2GradOp : public framework::OperatorWithKernel {
const framework::ExecutionContext &ctx) const override {
auto input_data_type = framework::OperatorWithKernel::IndicateVarDataType(
ctx, framework::GradVarName("Out"));
#ifdef PADDLE_WITH_MKLDNN
if (this->CanMKLDNNBeUsed(ctx, input_data_type)) {
return framework::OpKernelType(input_data_type,
ctx.GetPlace(),
phi::DataLayout::ONEDNN,
framework::LibraryType::kMKLDNN);
}
#endif
return framework::OpKernelType(input_data_type, ctx.GetPlace());
}
};
@@ -80,7 +80,7 @@ class ReshapeMKLDNNKernel : public framework::OpKernel<T> {
out->Resize(x_dims); // to match x numel, format is changed later
// reorder is done into a plain tag to allow usage with blocked formats
auto reorder_dst_memory_p = reorder_handler.AcquireDstMemory(
out, getPlainFormatTag(x), ctx.GetPlace());
out, phi::funcs::GetPlainOneDNNFormat(x_dims.size()), ctx.GetPlace());
auto reorder_p = reorder_handler.AcquireReorder(reorder_dst_memory_p,
reorder_src_memory_p);
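Before the buffer can be reinterpreted with the new dims, the kernel reorders the input from its (possibly blocked) oneDNN layout into a plain row-major layout of the same dims. The same step expressed directly against the dnnl API, as a standalone sketch for illustration only (the format tags and dims here are assumptions, not taken from the kernel):

#include <dnnl.hpp>

int main() {
  dnnl::engine eng(dnnl::engine::kind::cpu, 0);
  dnnl::stream strm(eng);

  // Source uses a blocked layout (aBcd8b); destination uses the plain
  // row-major tag (abcd) with the same dims, so its buffer can later be
  // viewed with different dims, as reshape/squeeze/flatten do.
  dnnl::memory::dims dims = {2, 3, 4, 5};
  auto src_md = dnnl::memory::desc(
      dims, dnnl::memory::data_type::f32, dnnl::memory::format_tag::aBcd8b);
  auto dst_md = dnnl::memory::desc(
      dims, dnnl::memory::data_type::f32, dnnl::memory::format_tag::abcd);

  dnnl::memory src_mem(src_md, eng);
  dnnl::memory dst_mem(dst_md, eng);

  // The reorder primitive copies the data, converting between the layouts.
  dnnl::reorder(src_mem, dst_mem).execute(strm, src_mem, dst_mem);
  strm.wait();
  return 0;
}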
@@ -194,31 +194,6 @@ class ReshapeMKLDNNKernel : public framework::OpKernel<T> {
}
protected:
static dnnl::memory::format_tag getPlainFormatTag(
const phi::DenseTensor* tensor) {
auto tensor_dims_size = tensor->dims().size();
PADDLE_ENFORCE_EQ(
tensor_dims_size <= 6 && tensor_dims_size >= 1,
true,
platform::errors::InvalidArgument(
"Dims for squeeze_grad oneDNN op must be in range <1, 6>"));
switch (tensor_dims_size) {
case 1:
return dnnl::memory::format_tag::a;
case 2:
return dnnl::memory::format_tag::ab;
case 3:
return dnnl::memory::format_tag::abc;
case 4:
return dnnl::memory::format_tag::abcd;
case 5:
return dnnl::memory::format_tag::abcde;
default:
return dnnl::memory::format_tag::abcdef;
}
}
static framework::DDim ValidateShape(const std::vector<int>& shape,
const framework::DDim& in_dims) {
const int64_t in_size = phi::product(in_dims);
@@ -348,7 +323,9 @@ class ReshapeGradMKLDNNKernel : public ReshapeMKLDNNKernel<T, op_name> {
auto reorder_src_memory_p = reorder_handler.AcquireSrcMemory(
dout->mem_desc(), phi::funcs::to_void_cast(dout->data<T>()));
auto reorder_dst_memory_p = reorder_handler.AcquireDstMemory(
dx, this->getPlainFormatTag(dout), ctx.GetPlace());
dx,
phi::funcs::GetPlainOneDNNFormat(dout_vec_dims.size()),
ctx.GetPlace());
auto reorder_p = reorder_handler.AcquireReorder(reorder_dst_memory_p,
reorder_src_memory_p);
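The deleted getPlainFormatTag helper is replaced by the shared phi::funcs::GetPlainOneDNNFormat, which takes the rank directly instead of a tensor pointer. Its exact implementation is not part of this diff, but judging from the deleted code it amounts to the same rank-to-plain-tag mapping; a minimal sketch of such a mapping (the function name here is hypothetical):

#include <dnnl.hpp>

// Map a tensor rank (1..6) to oneDNN's plain, row-major format tag.
static dnnl::memory::format_tag PlainFormatTagForRank(int rank) {
  switch (rank) {
    case 1: return dnnl::memory::format_tag::a;
    case 2: return dnnl::memory::format_tag::ab;
    case 3: return dnnl::memory::format_tag::abc;
    case 4: return dnnl::memory::format_tag::abcd;
    case 5: return dnnl::memory::format_tag::abcde;
    case 6: return dnnl::memory::format_tag::abcdef;
    default: return dnnl::memory::format_tag::undef;  // unsupported rank
  }
}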
@@ -258,6 +258,15 @@ class ReshapeOp : public framework::OperatorWithKernel {
auto input_data_type =
framework::OperatorWithKernel::IndicateVarDataType(ctx, "X");
#ifdef PADDLE_WITH_MKLDNN
if (this->CanMKLDNNBeUsed(ctx, input_data_type)) {
return framework::OpKernelType(input_data_type,
ctx.GetPlace(),
phi::DataLayout::ONEDNN,
framework::LibraryType::kMKLDNN);
}
#endif
return framework::OpKernelType(input_data_type, ctx.GetPlace());
}
@@ -615,6 +624,15 @@ class Reshape2GradOp : public framework::OperatorWithKernel {
auto input_data_type = framework::OperatorWithKernel::IndicateVarDataType(
ctx, framework::GradVarName("Out"));
#ifdef PADDLE_WITH_MKLDNN
if (this->CanMKLDNNBeUsed(ctx, input_data_type)) {
return framework::OpKernelType(input_data_type,
ctx.GetPlace(),
phi::DataLayout::ONEDNN,
framework::LibraryType::kMKLDNN);
}
#endif
return framework::OpKernelType(input_data_type, ctx.GetPlace());
}
@@ -125,13 +125,14 @@ class SqueezeOp : public framework::OperatorWithKernel {
auto input_data_type =
framework::OperatorWithKernel::IndicateVarDataType(ctx, "X");
// #ifdef PADDLE_WITH_MKLDNN
// if (this->CanMKLDNNBeUsed(ctx, input_data_type)) {
// return framework::OpKernelType(input_data_type, ctx.GetPlace(),
// phi::DataLayout::ONEDNN,
// framework::LibraryType::kMKLDNN);
// }
// #endif
#ifdef PADDLE_WITH_MKLDNN
if (this->CanMKLDNNBeUsed(ctx, input_data_type)) {
return framework::OpKernelType(input_data_type,
ctx.GetPlace(),
phi::DataLayout::ONEDNN,
framework::LibraryType::kMKLDNN);
}
#endif
return framework::OpKernelType(input_data_type, ctx.GetPlace());
}
};
@@ -152,13 +153,14 @@ class SqueezeGradOp : public framework::OperatorWithKernel {
auto input_data_type = framework::OperatorWithKernel::IndicateVarDataType(
ctx, framework::GradVarName("Out"));
// #ifdef PADDLE_WITH_MKLDNN
// if (this->CanMKLDNNBeUsed(ctx, input_data_type)) {
// return framework::OpKernelType(input_data_type, ctx.GetPlace(),
// phi::DataLayout::ONEDNN,
// framework::LibraryType::kMKLDNN);
// }
// #endif
#ifdef PADDLE_WITH_MKLDNN
if (this->CanMKLDNNBeUsed(ctx, input_data_type)) {
return framework::OpKernelType(input_data_type,
ctx.GetPlace(),
phi::DataLayout::ONEDNN,
framework::LibraryType::kMKLDNN);
}
#endif
return framework::OpKernelType(input_data_type, ctx.GetPlace());
}
};
@@ -219,13 +221,14 @@ class Squeeze2Op : public framework::OperatorWithKernel {
auto input_data_type =
framework::OperatorWithKernel::IndicateVarDataType(ctx, "X");
// #ifdef PADDLE_WITH_MKLDNN
// if (this->CanMKLDNNBeUsed(ctx, input_data_type)) {
// return framework::OpKernelType(input_data_type, ctx.GetPlace(),
// phi::DataLayout::ONEDNN,
// framework::LibraryType::kMKLDNN);
// }
// #endif
#ifdef PADDLE_WITH_MKLDNN
if (this->CanMKLDNNBeUsed(ctx, input_data_type)) {
return framework::OpKernelType(input_data_type,
ctx.GetPlace(),
phi::DataLayout::ONEDNN,
framework::LibraryType::kMKLDNN);
}
#endif
return framework::OpKernelType(input_data_type, ctx.GetPlace());
}
};
@@ -267,13 +270,14 @@ class Squeeze2GradOp : public framework::OperatorWithKernel {
auto input_data_type = framework::OperatorWithKernel::IndicateVarDataType(
ctx, framework::GradVarName("Out"));
// #ifdef PADDLE_WITH_MKLDNN
// if (this->CanMKLDNNBeUsed(ctx, input_data_type)) {
// return framework::OpKernelType(input_data_type, ctx.GetPlace(),
// phi::DataLayout::ONEDNN,
// framework::LibraryType::kMKLDNN);
// }
// #endif
#ifdef PADDLE_WITH_MKLDNN
if (this->CanMKLDNNBeUsed(ctx, input_data_type)) {
return framework::OpKernelType(input_data_type,
ctx.GetPlace(),
phi::DataLayout::ONEDNN,
framework::LibraryType::kMKLDNN);
}
#endif
return framework::OpKernelType(input_data_type, ctx.GetPlace());
}
};