未验证 提交 db468d7d 编写于 作者: J jakpiase 提交者: GitHub

oneDNN md-in-tensor 2nd batch of changes (#41997)

上级 5c738223
...@@ -107,8 +107,7 @@ void eltwise_forward(const framework::ExecutionContext &ctx, ...@@ -107,8 +107,7 @@ void eltwise_forward(const framework::ExecutionContext &ctx,
astream, {{DNNL_ARG_FROM, *src_memory_p}, {DNNL_ARG_TO, *dst_memory_p}}); astream, {{DNNL_ARG_FROM, *src_memory_p}, {DNNL_ARG_TO, *dst_memory_p}});
astream.wait(); astream.wait();
out->set_layout(DataLayout::kMKLDNN); out->set_mem_desc(dst_memory_p->get_desc());
out->set_format(GetMKLDNNFormat(*dst_memory_p));
} }
template <typename T> template <typename T>
...@@ -136,8 +135,7 @@ void eltwise_grad(const framework::ExecutionContext &ctx, ...@@ -136,8 +135,7 @@ void eltwise_grad(const framework::ExecutionContext &ctx,
{DNNL_ARG_DIFF_SRC, *diff_src_memory_p}}); {DNNL_ARG_DIFF_SRC, *diff_src_memory_p}});
astream.wait(); astream.wait();
dx->set_layout(DataLayout::kMKLDNN); dx->set_mem_desc(diff_src_memory_p->get_desc());
dx->set_format(GetMKLDNNFormat(*diff_src_memory_p));
} }
template <typename T> template <typename T>
...@@ -165,8 +163,7 @@ void eltwise_grad_use_out(const framework::ExecutionContext &ctx, ...@@ -165,8 +163,7 @@ void eltwise_grad_use_out(const framework::ExecutionContext &ctx,
{DNNL_ARG_DIFF_SRC, *diff_src_memory_p}}); {DNNL_ARG_DIFF_SRC, *diff_src_memory_p}});
astream.wait(); astream.wait();
dx->set_layout(DataLayout::kMKLDNN); dx->set_mem_desc(diff_src_memory_p->get_desc());
dx->set_format(GetMKLDNNFormat(*diff_src_memory_p));
} }
template <typename T, dnnl::algorithm algorithm> template <typename T, dnnl::algorithm algorithm>
...@@ -347,6 +344,7 @@ namespace ops = paddle::operators; ...@@ -347,6 +344,7 @@ namespace ops = paddle::operators;
FOR_EACH_MKLDNN_KERNEL_FUNCTOR(REGISTER_ACTIVATION_MKLDNN_KERNEL); FOR_EACH_MKLDNN_KERNEL_FUNCTOR(REGISTER_ACTIVATION_MKLDNN_KERNEL);
// round eltwise primitive doesn't support BF16, nor does it support grad
REGISTER_ACTIVATION_MKLDNN_KERNEL_FWD_ONLY(round, RoundMKLDNNFunctor); REGISTER_ACTIVATION_MKLDNN_KERNEL_FWD_ONLY(round, RoundMKLDNNFunctor);
namespace ops = paddle::operators; namespace ops = paddle::operators;
......
...@@ -54,17 +54,6 @@ class BatchNormMKLDNNHandler : public platform::MKLDNNHandlerNoCachingT< ...@@ -54,17 +54,6 @@ class BatchNormMKLDNNHandler : public platform::MKLDNNHandlerNoCachingT<
std::vector<std::string> DataLayout_error_msg = {"kNHWC", "kNCHW", std::vector<std::string> DataLayout_error_msg = {"kNHWC", "kNCHW",
"kAnyLayout", "kMKLDNN"}; "kAnyLayout", "kMKLDNN"};
PADDLE_ENFORCE_EQ(
x->layout(), DataLayout::kMKLDNN,
platform::errors::InvalidArgument(
"Wrong layout set for X tensor. Expected layout is `kMKLDNN`, "
"But received %s.",
DataLayout_error_msg[static_cast<int>(DataLayout::kMKLDNN)]));
PADDLE_ENFORCE_NE(
x->format(), MKLDNNMemoryFormat::undef,
platform::errors::InvalidArgument("Wrong format set for X tensor"));
auto src_tz = phi::vectorize(x->dims());
// Flags are added by bitwise OR operation // Flags are added by bitwise OR operation
auto flags = dnnl::normalization_flags::use_scale_shift; // 001 auto flags = dnnl::normalization_flags::use_scale_shift; // 001
...@@ -73,14 +62,10 @@ class BatchNormMKLDNNHandler : public platform::MKLDNNHandlerNoCachingT< ...@@ -73,14 +62,10 @@ class BatchNormMKLDNNHandler : public platform::MKLDNNHandlerNoCachingT<
if (fuse_with_relu && test_mode) if (fuse_with_relu && test_mode)
flags |= dnnl::normalization_flags::fuse_norm_relu; // 100 flags |= dnnl::normalization_flags::fuse_norm_relu; // 100
auto md = dnnl::memory::desc(
src_tz, platform::MKLDNNGetDataType<T>(),
platform::MKLDNNFormatForSize(src_tz.size(), x->format()));
this->AcquireForwardPrimitiveDescriptor( this->AcquireForwardPrimitiveDescriptor(
global_stats == true ? dnnl::prop_kind::forward_scoring global_stats == true ? dnnl::prop_kind::forward_scoring
: dnnl::prop_kind::forward_training, : dnnl::prop_kind::forward_training,
md, epsilon, flags); x->mem_desc(), epsilon, flags);
} }
BatchNormMKLDNNHandler(const paddle::framework::ExecutionContext &ctx, BatchNormMKLDNNHandler(const paddle::framework::ExecutionContext &ctx,
...@@ -89,14 +74,6 @@ class BatchNormMKLDNNHandler : public platform::MKLDNNHandlerNoCachingT< ...@@ -89,14 +74,6 @@ class BatchNormMKLDNNHandler : public platform::MKLDNNHandlerNoCachingT<
: platform::MKLDNNHandlerNoCachingT<T, dnnl::batch_normalization_forward, : platform::MKLDNNHandlerNoCachingT<T, dnnl::batch_normalization_forward,
dnnl::batch_normalization_backward>( dnnl::batch_normalization_backward>(
mkldnn_engine, ctx.GetPlace()) { mkldnn_engine, ctx.GetPlace()) {
PADDLE_ENFORCE_EQ(out_grad->layout(), DataLayout::kMKLDNN,
platform::errors::InvalidArgument(
"Wrong layout set for Input out_grad tensor"));
PADDLE_ENFORCE_NE(out_grad->format(), MKLDNNMemoryFormat::undef,
platform::errors::InvalidArgument(
"Wrong format set for Input out_grad tensor"));
auto src_tz = phi::vectorize<int64_t>(in_x->dims());
auto scale_tz = phi::vectorize<int64_t>(scale->dims()); auto scale_tz = phi::vectorize<int64_t>(scale->dims());
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
scale_tz.size(), 1, scale_tz.size(), 1,
...@@ -104,26 +81,14 @@ class BatchNormMKLDNNHandler : public platform::MKLDNNHandlerNoCachingT< ...@@ -104,26 +81,14 @@ class BatchNormMKLDNNHandler : public platform::MKLDNNHandlerNoCachingT<
"Dims of scale tensor must be 1, but received scale's size is %d", "Dims of scale tensor must be 1, but received scale's size is %d",
scale_tz.size())); scale_tz.size()));
MKLDNNMemoryFormat diff_fmt =
platform::MKLDNNFormatForSize(src_tz.size(), out_grad->format());
MKLDNNMemoryFormat src_fmt =
platform::MKLDNNFormatForSize(src_tz.size(), in_x->format());
auto dims = phi::vectorize(in_x->dims());
auto diff_dst_md =
dnnl::memory::desc(dims, platform::MKLDNNGetDataType<T>(), diff_fmt);
auto src_md =
dnnl::memory::desc(dims, platform::MKLDNNGetDataType<T>(), src_fmt);
const float epsilon = ctx.Attr<float>("epsilon"); const float epsilon = ctx.Attr<float>("epsilon");
this->AcquireForwardPrimitiveDescriptor( this->AcquireForwardPrimitiveDescriptor(
dnnl::prop_kind::forward_training, src_md, epsilon, dnnl::prop_kind::forward_training, in_x->mem_desc(), epsilon,
dnnl::normalization_flags::use_scale_shift); dnnl::normalization_flags::use_scale_shift);
this->AcquireBackwardPrimitiveDescriptor( this->AcquireBackwardPrimitiveDescriptor(
dnnl::prop_kind::backward, diff_dst_md, src_md, epsilon, dnnl::prop_kind::backward, out_grad->mem_desc(), in_x->mem_desc(),
dnnl::normalization_flags::use_scale_shift); epsilon, dnnl::normalization_flags::use_scale_shift);
} }
std::shared_ptr<dnnl::memory> AcquireScaleShiftMemory(const Tensor *scale, std::shared_ptr<dnnl::memory> AcquireScaleShiftMemory(const Tensor *scale,
...@@ -227,8 +192,7 @@ class BatchNormMKLDNNOpKernel : public paddle::framework::OpKernel<T> { ...@@ -227,8 +192,7 @@ class BatchNormMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
variance_memory = handler.AcquireVarianceMemory(batch_variance); variance_memory = handler.AcquireVarianceMemory(batch_variance);
} }
y->set_layout(DataLayout::kMKLDNN); y->set_mem_desc(dst_memory->get_desc());
y->set_format(platform::GetMKLDNNFormat(*dst_memory));
auto &astream = platform::MKLDNNDeviceContext::tls().get_stream(); auto &astream = platform::MKLDNNDeviceContext::tls().get_stream();
batch_norm_p->execute(astream, {{DNNL_ARG_SRC, *src_memory}, batch_norm_p->execute(astream, {{DNNL_ARG_SRC, *src_memory},
...@@ -322,9 +286,8 @@ class BatchNormMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> { ...@@ -322,9 +286,8 @@ class BatchNormMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
std::copy(std::next(it, C), std::end(diff_scaleshift_data), std::copy(std::next(it, C), std::end(diff_scaleshift_data),
diff_shift_data); diff_shift_data);
// set layout/format of output tensors // set memory descriptor of out tensor
diff_x->set_layout(DataLayout::kMKLDNN); diff_x->set_mem_desc(diff_src_memory->get_desc());
diff_x->set_format(platform::GetMKLDNNFormat(*diff_src_memory));
} }
}; };
} // namespace operators } // namespace operators
......
...@@ -46,8 +46,7 @@ class ClipMKLDNNKernel : public paddle::framework::OpKernel<T> { ...@@ -46,8 +46,7 @@ class ClipMKLDNNKernel : public paddle::framework::OpKernel<T> {
{DNNL_ARG_TO, *dst_memory_p}}); {DNNL_ARG_TO, *dst_memory_p}});
astream.wait(); astream.wait();
out->set_layout(paddle::framework::DataLayout::kMKLDNN); out->set_mem_desc(dst_memory_p->get_desc());
out->set_format(paddle::platform::GetMKLDNNFormat(*dst_memory_p));
} }
}; };
...@@ -83,8 +82,7 @@ class ClipGradMKLDNNKernel : public paddle::framework::OpKernel<T> { ...@@ -83,8 +82,7 @@ class ClipGradMKLDNNKernel : public paddle::framework::OpKernel<T> {
{DNNL_ARG_DIFF_SRC, *diff_src_memory_p}}); {DNNL_ARG_DIFF_SRC, *diff_src_memory_p}});
astream.wait(); astream.wait();
dx->set_layout(paddle::framework::DataLayout::kMKLDNN); dx->set_mem_desc(diff_dst_memory_p->get_desc());
dx->set_format(paddle::platform::GetMKLDNNFormat(*diff_dst_memory_p));
} }
}; };
......
...@@ -68,8 +68,7 @@ class ConcatMKLDNNHandler ...@@ -68,8 +68,7 @@ class ConcatMKLDNNHandler
// Create memory descriptors for each of inputs // Create memory descriptors for each of inputs
for (size_t i = 0; i < inputs.size(); ++i) { for (size_t i = 0; i < inputs.size(); ++i) {
const auto dims = phi::vectorize<int64_t>(inputs[i]->dims()); srcs_md.push_back(inputs[i]->mem_desc());
srcs_md.emplace_back(memory::desc(dims, dt, inputs[i]->format()));
} }
auto dst_dims = phi::vectorize<int64_t>(output->dims()); auto dst_dims = phi::vectorize<int64_t>(output->dims());
...@@ -99,9 +98,6 @@ static void EnforceLayouts(const std::vector<const Tensor*> inputs) { ...@@ -99,9 +98,6 @@ static void EnforceLayouts(const std::vector<const Tensor*> inputs) {
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
input->layout(), DataLayout::kMKLDNN, input->layout(), DataLayout::kMKLDNN,
platform::errors::InvalidArgument("Wrong layout set for Input tensor")); platform::errors::InvalidArgument("Wrong layout set for Input tensor"));
PADDLE_ENFORCE_NE(
input->format(), MKLDNNMemoryFormat::undef,
platform::errors::InvalidArgument("Wrong format set for Input tensor"));
} }
} }
...@@ -147,8 +143,7 @@ class ConcatMKLDNNOpKernel : public paddle::framework::OpKernel<T> { ...@@ -147,8 +143,7 @@ class ConcatMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
concat_p->execute(astream, args); concat_p->execute(astream, args);
astream.wait(); astream.wait();
output->set_layout(DataLayout::kMKLDNN); output->set_mem_desc(dst_mem->get_desc());
output->set_format(platform::GetMKLDNNFormat(*dst_mem));
} }
}; };
...@@ -192,7 +187,7 @@ class ConcatGradMKLDNNOpKernel : public paddle::framework::OpKernel<T> { ...@@ -192,7 +187,7 @@ class ConcatGradMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
dout_vec_dims, framework::TransToProtoVarType(dout->dtype()), dout_type, dout_vec_dims, framework::TransToProtoVarType(dout->dtype()), dout_type,
onednn_engine); onednn_engine);
auto reorder_src_memory_p = reorder_handler.AcquireSrcMemory( auto reorder_src_memory_p = reorder_handler.AcquireSrcMemory(
dout->format(), platform::to_void_cast(dout->data<T>())); dout->mem_desc(), platform::to_void_cast(dout->data<T>()));
for (size_t i = 0; i < dx.size(); ++i) { for (size_t i = 0; i < dx.size(); ++i) {
if (out_var_names[i] != framework::kEmptyVarName && if (out_var_names[i] != framework::kEmptyVarName &&
...@@ -202,7 +197,8 @@ class ConcatGradMKLDNNOpKernel : public paddle::framework::OpKernel<T> { ...@@ -202,7 +197,8 @@ class ConcatGradMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
dx_vec_dims, offset, reorder_src_memory_p); dx_vec_dims, offset, reorder_src_memory_p);
auto reorder_dst_memory_p = reorder_handler.AcquireDstMemory( auto reorder_dst_memory_p = reorder_handler.AcquireDstMemory(
dx[i], dx_vec_dims, dout->format(), ctx.GetPlace()); dx[i], dx_vec_dims,
platform::GetPlainMKLDNNFormat(dx_vec_dims.size()), ctx.GetPlace());
auto reorder_p = auto reorder_p =
reorder_handler.AcquireReorder(reorder_dst_memory_p, slice_mem_p); reorder_handler.AcquireReorder(reorder_dst_memory_p, slice_mem_p);
...@@ -210,8 +206,7 @@ class ConcatGradMKLDNNOpKernel : public paddle::framework::OpKernel<T> { ...@@ -210,8 +206,7 @@ class ConcatGradMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
offset[axis] += dx[i]->dims()[axis]; offset[axis] += dx[i]->dims()[axis];
dx[i]->set_layout(framework::DataLayout::kMKLDNN); dx[i]->set_mem_desc(reorder_dst_memory_p->get_desc());
dx[i]->set_format(platform::GetMKLDNNFormat(*reorder_dst_memory_p));
} }
} }
astream.wait(); astream.wait();
......
...@@ -115,10 +115,11 @@ class ExpandGradMKLDNNKernel : public paddle::framework::OpKernel<T> { ...@@ -115,10 +115,11 @@ class ExpandGradMKLDNNKernel : public paddle::framework::OpKernel<T> {
dout_type, onednn_engine); dout_type, onednn_engine);
auto reorder_src_memory_p = reorder_handler.AcquireSrcMemory( auto reorder_src_memory_p = reorder_handler.AcquireSrcMemory(
dout->format(), paddle::platform::to_void_cast(dout->data<T>())); dout->mem_desc(), paddle::platform::to_void_cast(dout->data<T>()));
auto reorder_dst_memory_p = auto reorder_dst_memory_p = reorder_handler.AcquireDstMemory(
reorder_handler.AcquireDstMemory(dx, dout->format(), ctx.GetPlace()); dx, paddle::platform::GetPlainMKLDNNFormat(dx_vec_dims.size()),
ctx.GetPlace());
auto reorder_p = reorder_handler.AcquireReorder(reorder_src_memory_p, auto reorder_p = reorder_handler.AcquireReorder(reorder_src_memory_p,
reorder_dst_memory_p); reorder_dst_memory_p);
...@@ -126,9 +127,7 @@ class ExpandGradMKLDNNKernel : public paddle::framework::OpKernel<T> { ...@@ -126,9 +127,7 @@ class ExpandGradMKLDNNKernel : public paddle::framework::OpKernel<T> {
reorder_p->execute(astream, *reorder_src_memory_p, *reorder_dst_memory_p); reorder_p->execute(astream, *reorder_src_memory_p, *reorder_dst_memory_p);
astream.wait(); astream.wait();
dx->set_layout(paddle::framework::DataLayout::kMKLDNN); dx->set_mem_desc(reorder_dst_memory_p->get_desc());
dx->set_format(
paddle::platform::GetMKLDNNFormat(reorder_dst_memory_p->get_desc()));
} else { } else {
paddle::platform::ReductionMKLDNNHandler<T> handler( paddle::platform::ReductionMKLDNNHandler<T> handler(
dnnl::algorithm::reduction_sum, 0.0f, 0.0f, onednn_engine, dnnl::algorithm::reduction_sum, 0.0f, 0.0f, onednn_engine,
...@@ -145,8 +144,8 @@ class ExpandGradMKLDNNKernel : public paddle::framework::OpKernel<T> { ...@@ -145,8 +144,8 @@ class ExpandGradMKLDNNKernel : public paddle::framework::OpKernel<T> {
reduction_p->execute(astream, reduction_args); reduction_p->execute(astream, reduction_args);
astream.wait(); astream.wait();
dx->set_layout(paddle::framework::DataLayout::kMKLDNN); dx->set_layout(paddle::framework::DataLayout::kMKLDNN);
dx->set_format(paddle::platform::GetMKLDNNFormat( dx->set_mem_desc(
dst_memory_p->get_desc().reshape(vectorize<int64_t>(dx->dims())))); dst_memory_p->get_desc().reshape(vectorize<int64_t>(dx->dims())));
} }
} }
}; };
......
...@@ -16,7 +16,7 @@ limitations under the License. */ ...@@ -16,7 +16,7 @@ limitations under the License. */
#include "paddle/fluid/framework/generator.h" #include "paddle/fluid/framework/generator.h"
#include "paddle/fluid/operators/fill_constant_op.h" #include "paddle/fluid/operators/fill_constant_op.h"
#include "paddle/fluid/platform/mkldnn_helper.h" #include "paddle/fluid/platform/mkldnn_reuse.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
...@@ -42,8 +42,13 @@ class GaussianMKLDNNKernel : public paddle::framework::OpKernel<T> { ...@@ -42,8 +42,13 @@ class GaussianMKLDNNKernel : public paddle::framework::OpKernel<T> {
data[i] = dist(*engine); data[i] = dist(*engine);
} }
tensor->set_layout(DataLayout::kMKLDNN); dnnl::memory::desc out_mem_desc(
tensor->set_format(platform::GetPlainMKLDNNFormat(tensor->dims().size())); phi::vectorize(tensor->dims()),
framework::ToMKLDNNDataType(
framework::TransToProtoVarType(tensor->dtype())),
platform::GetPlainMKLDNNFormat(tensor->dims().size()));
tensor->set_mem_desc(out_mem_desc);
} }
}; };
} // namespace operators } // namespace operators
......
...@@ -34,17 +34,14 @@ class InterpolateMKLDNNHandler ...@@ -34,17 +34,14 @@ class InterpolateMKLDNNHandler
public: public:
InterpolateMKLDNNHandler(const dnnl::algorithm algo, InterpolateMKLDNNHandler(const dnnl::algorithm algo,
const dnnl::engine engine, platform::Place cpu_place, const dnnl::engine engine, platform::Place cpu_place,
const Tensor* x, Tensor* z) const Tensor* x, Tensor* out)
: platform::MKLDNNHandlerNoCachingT<T, dnnl::resampling_forward>( : platform::MKLDNNHandlerNoCachingT<T, dnnl::resampling_forward>(
engine, cpu_place) { engine, cpu_place) {
const auto src_x_tz = phi::vectorize(x->dims()); const auto dst_tz = phi::vectorize(out->dims());
const auto dst_tz = phi::vectorize(z->dims());
const auto src_md = dnnl::memory::desc(
src_x_tz, platform::MKLDNNGetDataType<T>(), x->format());
const auto dst_md = memory::desc(dst_tz, platform::MKLDNNGetDataType<T>(), const auto dst_md = memory::desc(dst_tz, platform::MKLDNNGetDataType<T>(),
MKLDNNMemoryFormat::any); MKLDNNMemoryFormat::any);
this->AcquireForwardPrimitiveDescriptor(dnnl::prop_kind::forward_inference, this->AcquireForwardPrimitiveDescriptor(dnnl::prop_kind::forward_inference,
algo, src_md, dst_md); algo, x->mem_desc(), dst_md);
} }
}; };
...@@ -133,7 +130,7 @@ class InterpolateMKLDNNKernel : public framework::OpKernel<T> { ...@@ -133,7 +130,7 @@ class InterpolateMKLDNNKernel : public framework::OpKernel<T> {
const auto& mkldnn_engine = dev_ctx.GetEngine(); const auto& mkldnn_engine = dev_ctx.GetEngine();
const auto* x = ctx.Input<Tensor>("X"); const auto* x = ctx.Input<Tensor>("X");
auto* z = ctx.Output<Tensor>("Out"); auto* out = ctx.Output<Tensor>("Out");
const auto interp_method = ctx.Attr<std::string>("interp_method"); const auto interp_method = ctx.Attr<std::string>("interp_method");
const dnnl::algorithm algo = (interp_method == "nearest") const dnnl::algorithm algo = (interp_method == "nearest")
...@@ -142,13 +139,13 @@ class InterpolateMKLDNNKernel : public framework::OpKernel<T> { ...@@ -142,13 +139,13 @@ class InterpolateMKLDNNKernel : public framework::OpKernel<T> {
const auto out_dims_vec = ComputeOutputShape(ctx); const auto out_dims_vec = ComputeOutputShape(ctx);
framework::DDim dim_out = phi::make_ddim(out_dims_vec); framework::DDim dim_out = phi::make_ddim(out_dims_vec);
z->Resize(dim_out); out->Resize(dim_out);
InterpolateMKLDNNHandler<T> handler(algo, mkldnn_engine, ctx.GetPlace(), x, InterpolateMKLDNNHandler<T> handler(algo, mkldnn_engine, ctx.GetPlace(), x,
z); out);
auto src_memory_p = handler.AcquireSrcMemory(x); auto src_memory_p = handler.AcquireSrcMemory(x);
auto dst_memory_p = handler.AcquireDstMemory(z); auto dst_memory_p = handler.AcquireDstMemory(out);
auto resampling_prim = handler.AcquireForwardPrimitive(); auto resampling_prim = handler.AcquireForwardPrimitive();
const std::unordered_map<int, dnnl::memory> args = { const std::unordered_map<int, dnnl::memory> args = {
...@@ -158,8 +155,7 @@ class InterpolateMKLDNNKernel : public framework::OpKernel<T> { ...@@ -158,8 +155,7 @@ class InterpolateMKLDNNKernel : public framework::OpKernel<T> {
resampling_prim->execute(astream, args); resampling_prim->execute(astream, args);
astream.wait(); astream.wait();
z->set_layout(DataLayout::kMKLDNN); out->set_mem_desc(dst_memory_p->get_desc());
z->set_format(platform::GetMKLDNNFormat(*dst_memory_p));
} }
}; };
......
...@@ -25,22 +25,21 @@ class LayerNormMKLDNNHandler : public platform::MKLDNNHandlerNoCachingT< ...@@ -25,22 +25,21 @@ class LayerNormMKLDNNHandler : public platform::MKLDNNHandlerNoCachingT<
public: public:
LayerNormMKLDNNHandler(const std::vector<int64_t>& dims, const float& epsilon, LayerNormMKLDNNHandler(const std::vector<int64_t>& dims, const float& epsilon,
const dnnl::normalization_flags& flags, const dnnl::normalization_flags& flags,
const bool& is_test, const MKLDNNMemoryFormat fmt, const bool& is_test, const Tensor* x,
const dnnl::engine engine, platform::Place cpu_place) const dnnl::engine engine, platform::Place cpu_place)
: platform::MKLDNNHandlerNoCachingT<T, dnnl::layer_normalization_forward>( : platform::MKLDNNHandlerNoCachingT<T, dnnl::layer_normalization_forward>(
engine, cpu_place) { engine, cpu_place) {
auto md = dnnl::memory::desc(dims, platform::MKLDNNGetDataType<T>(), fmt);
if (!is_test) { if (!is_test) {
// TODO(grygielski) Delete forcing stats_md after DNNL 1.2 is introduced // TODO(grygielski) Delete forcing stats_md after DNNL 1.2 is introduced
auto stats_md = dnnl::memory::desc( auto stats_md = dnnl::memory::desc(
{begin(dims), end(dims) - 1}, platform::MKLDNNGetDataType<float>(), {begin(dims), end(dims) - 1}, platform::MKLDNNGetDataType<float>(),
platform::MKLDNNFormatForSize(dims.size() - 1, platform::GetPlainMKLDNNFormat(dims.size() - 1));
MKLDNNMemoryFormat::nchw));
this->AcquireForwardPrimitiveDescriptor(dnnl::prop_kind::forward_training, this->AcquireForwardPrimitiveDescriptor(dnnl::prop_kind::forward_training,
md, stats_md, epsilon, flags); x->mem_desc(), stats_md, epsilon,
flags);
} else { } else {
this->AcquireForwardPrimitiveDescriptor( this->AcquireForwardPrimitiveDescriptor(
dnnl::prop_kind::forward_inference, md, epsilon, flags); dnnl::prop_kind::forward_inference, x->mem_desc(), epsilon, flags);
} }
} }
...@@ -83,7 +82,7 @@ class LayerNormMKLDNNOpKernel : public paddle::framework::OpKernel<T> { ...@@ -83,7 +82,7 @@ class LayerNormMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
auto* x = ctx.Input<Tensor>("X"); auto* x = ctx.Input<Tensor>("X");
auto* scale = ctx.Input<Tensor>("Scale"); auto* scale = ctx.Input<Tensor>("Scale");
auto* bias = ctx.Input<Tensor>("Bias"); auto* bias = ctx.Input<Tensor>("Bias");
auto* y = ctx.Output<Tensor>("Y"); auto* out = ctx.Output<Tensor>("Y");
const float epsilon = ctx.Attr<float>("epsilon"); const float epsilon = ctx.Attr<float>("epsilon");
const auto begin_norm_axis = ctx.Attr<int>("begin_norm_axis"); const auto begin_norm_axis = ctx.Attr<int>("begin_norm_axis");
...@@ -107,12 +106,11 @@ class LayerNormMKLDNNOpKernel : public paddle::framework::OpKernel<T> { ...@@ -107,12 +106,11 @@ class LayerNormMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
flags |= dnnl::normalization_flags::use_scale_shift; flags |= dnnl::normalization_flags::use_scale_shift;
} }
LayerNormMKLDNNHandler<T> handler(src_tz, epsilon, flags, is_test, LayerNormMKLDNNHandler<T> handler(src_tz, epsilon, flags, is_test, x,
x->format(), mkldnn_engine, mkldnn_engine, ctx.GetPlace());
ctx.GetPlace());
auto src_memory = handler.AcquireSrcMemory(x); auto src_memory = handler.AcquireSrcMemory(x);
auto dst_memory = handler.AcquireDstMemory(y); auto dst_memory = handler.AcquireDstMemory(out);
auto layer_norm_p = handler.AcquireForwardPrimitive(); auto layer_norm_p = handler.AcquireForwardPrimitive();
...@@ -140,8 +138,7 @@ class LayerNormMKLDNNOpKernel : public paddle::framework::OpKernel<T> { ...@@ -140,8 +138,7 @@ class LayerNormMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
layer_norm_p->execute(astream, args); layer_norm_p->execute(astream, args);
astream.wait(); astream.wait();
y->set_layout(phi::DataLayout::kMKLDNN); out->set_mem_desc(dst_memory->get_desc());
y->set_format(platform::GetMKLDNNFormat(*dst_memory));
} }
}; };
......
...@@ -28,12 +28,8 @@ class LogSoftmaxMKLDNNHandler ...@@ -28,12 +28,8 @@ class LogSoftmaxMKLDNNHandler
const int axis) const int axis)
: platform::MKLDNNHandlerNoCachingT<T, dnnl::logsoftmax_forward>( : platform::MKLDNNHandlerNoCachingT<T, dnnl::logsoftmax_forward>(
mkldnn_engine, cpu_place) { mkldnn_engine, cpu_place) {
const auto logsoftmax_tz = phi::vectorize(x->dims());
const auto md = dnnl::memory::desc(
logsoftmax_tz, platform::MKLDNNGetDataType<T>(), x->format());
this->AcquireForwardPrimitiveDescriptor(dnnl::prop_kind::forward_inference, this->AcquireForwardPrimitiveDescriptor(dnnl::prop_kind::forward_inference,
md, axis); x->mem_desc(), axis);
} }
}; };
...@@ -63,8 +59,7 @@ class LogSoftmaxMKLDNNKernel : public framework::OpKernel<T> { ...@@ -63,8 +59,7 @@ class LogSoftmaxMKLDNNKernel : public framework::OpKernel<T> {
{DNNL_ARG_DST, *dst_memory_p}}); {DNNL_ARG_DST, *dst_memory_p}});
astream.wait(); astream.wait();
out->set_layout(framework::DataLayout::kMKLDNN); out->set_mem_desc(dst_memory_p->get_desc());
out->set_format(x->format());
} }
}; };
} // namespace operators } // namespace operators
......
...@@ -44,15 +44,11 @@ class LRNMKLDNNHandler ...@@ -44,15 +44,11 @@ class LRNMKLDNNHandler
const float k = ctx.Attr<float>("k"); const float k = ctx.Attr<float>("k");
bool is_test = ctx.Attr<bool>("is_test"); bool is_test = ctx.Attr<bool>("is_test");
auto dims = phi::vectorize(input->dims());
auto src_md = dnnl::memory::desc(dims, platform::MKLDNNGetDataType<T>(),
input->format());
this->AcquireForwardPrimitiveDescriptor( this->AcquireForwardPrimitiveDescriptor(
is_test ? dnnl::prop_kind::forward_inference is_test ? dnnl::prop_kind::forward_inference
: dnnl::prop_kind::forward_training, : dnnl::prop_kind::forward_training,
dnnl::algorithm::lrn_across_channels, src_md, n, alpha, beta, k); dnnl::algorithm::lrn_across_channels, input->mem_desc(), n, alpha, beta,
k);
} }
LRNMKLDNNHandler(const framework::ExecutionContext& ctx, LRNMKLDNNHandler(const framework::ExecutionContext& ctx,
...@@ -72,20 +68,13 @@ class LRNMKLDNNHandler ...@@ -72,20 +68,13 @@ class LRNMKLDNNHandler
const float beta = ctx.Attr<float>("beta"); const float beta = ctx.Attr<float>("beta");
const float k = ctx.Attr<float>("k"); const float k = ctx.Attr<float>("k");
auto dims = phi::vectorize<int64_t>(in_x->dims());
auto src_md = dnnl::memory::desc(dims, platform::MKLDNNGetDataType<T>(),
in_x->format());
auto diff_md = dnnl::memory::desc(dims, platform::MKLDNNGetDataType<T>(),
out_grad->format());
this->AcquireForwardPrimitiveDescriptor( this->AcquireForwardPrimitiveDescriptor(
dnnl::prop_kind::forward_training, dnnl::algorithm::lrn_across_channels, dnnl::prop_kind::forward_training, dnnl::algorithm::lrn_across_channels,
src_md, n, alpha, beta, k); in_x->mem_desc(), n, alpha, beta, k);
this->AcquireBackwardPrimitiveDescriptor( this->AcquireBackwardPrimitiveDescriptor(
dnnl::algorithm::lrn_across_channels, src_md, diff_md, n, alpha, beta, dnnl::algorithm::lrn_across_channels, in_x->mem_desc(),
k); out_grad->mem_desc(), n, alpha, beta, k);
} }
std::shared_ptr<dnnl::memory> AcquireWorkspaceMemory(Tensor* workspace) { std::shared_ptr<dnnl::memory> AcquireWorkspaceMemory(Tensor* workspace) {
......
...@@ -41,13 +41,6 @@ class PoolingMKLDNNHandler ...@@ -41,13 +41,6 @@ class PoolingMKLDNNHandler
: platform::MKLDNNHandlerNoCachingT<T, dnnl::pooling_forward, : platform::MKLDNNHandlerNoCachingT<T, dnnl::pooling_forward,
dnnl::pooling_backward>( dnnl::pooling_backward>(
mkldnn_engine, ctx.GetPlace()) { mkldnn_engine, ctx.GetPlace()) {
PADDLE_ENFORCE_EQ(input->layout(), DataLayout::kMKLDNN,
platform::errors::InvalidArgument(
"Wrong layout set for Input tensor."));
PADDLE_ENFORCE_NE(input->format(), MKLDNNMemoryFormat::undef,
platform::errors::InvalidArgument(
"Wrong format set for Input tensor."));
const std::string pooling_type = ctx.Attr<std::string>("pooling_type"); const std::string pooling_type = ctx.Attr<std::string>("pooling_type");
std::vector<int> ksize_temp = ctx.Attr<std::vector<int>>("ksize"); std::vector<int> ksize_temp = ctx.Attr<std::vector<int>>("ksize");
...@@ -91,29 +84,18 @@ class PoolingMKLDNNHandler ...@@ -91,29 +84,18 @@ class PoolingMKLDNNHandler
phi::funcs::UpdatePadding(&paddings, global_pooling, 0, padding_algorithm, phi::funcs::UpdatePadding(&paddings, global_pooling, 0, padding_algorithm,
data_dims, strides, ksize); data_dims, strides, ksize);
const auto src_tz = phi::vectorize(input->dims());
const auto dst_tz = phi::vectorize(output->dims());
const auto is_test = ctx.Attr<bool>("is_test"); const auto is_test = ctx.Attr<bool>("is_test");
const bool ceil_mode = ctx.Attr<bool>("ceil_mode");
const auto exclude_padding = ctx.Attr<bool>("exclusive");
auto mkldnn_paddings = platform::ToMkldnnPadding(paddings);
const auto dt = framework::ToMKLDNNDataType( const auto dt = framework::ToMKLDNNDataType(
framework::TransToProtoVarType(input->dtype())); framework::TransToProtoVarType(input->dtype()));
const auto src_tz = phi::vectorize(input->dims());
const auto exclude_padding = ctx.Attr<bool>("exclusive"); const auto dst_tz = phi::vectorize(output->dims());
const auto src_md = dnnl::memory::desc(src_tz, dt, input->format());
/* create memory descriptor for pooling without specified format
* ('any') which lets a primitive (pooling in this case) choose
* the memory format preferred for best performance
*/
const auto dst_md = const auto dst_md =
platform::MKLDNNMemDesc(dst_tz, dt, MKLDNNMemoryFormat::any); platform::MKLDNNMemDesc(dst_tz, dt, MKLDNNMemoryFormat::any);
auto mkldnn_paddings = platform::ToMkldnnPadding(paddings);
const bool ceil_mode = ctx.Attr<bool>("ceil_mode");
if (ceil_mode) { if (ceil_mode) {
CorrectOutputSize(src_tz, dst_tz, ksize, paddings, strides, CorrectOutputSize(src_tz, dst_tz, ksize, paddings, strides,
mkldnn_paddings[1]); mkldnn_paddings[1]);
...@@ -128,7 +110,8 @@ class PoolingMKLDNNHandler ...@@ -128,7 +110,8 @@ class PoolingMKLDNNHandler
? dnnl::algorithm::pooling_max ? dnnl::algorithm::pooling_max
: (exclude_padding ? dnnl::algorithm::pooling_avg_exclude_padding : (exclude_padding ? dnnl::algorithm::pooling_avg_exclude_padding
: dnnl::algorithm::pooling_avg_include_padding), : dnnl::algorithm::pooling_avg_include_padding),
src_md, dst_md, strides, ksize, mkldnn_paddings[0], mkldnn_paddings[1]); input->mem_desc(), dst_md, strides, ksize, mkldnn_paddings[0],
mkldnn_paddings[1]);
} }
PoolingMKLDNNHandler(const paddle::framework::ExecutionContext& ctx, PoolingMKLDNNHandler(const paddle::framework::ExecutionContext& ctx,
...@@ -138,20 +121,6 @@ class PoolingMKLDNNHandler ...@@ -138,20 +121,6 @@ class PoolingMKLDNNHandler
: platform::MKLDNNHandlerNoCachingT<T, dnnl::pooling_forward, : platform::MKLDNNHandlerNoCachingT<T, dnnl::pooling_forward,
dnnl::pooling_backward>( dnnl::pooling_backward>(
mkldnn_engine, ctx.GetPlace()) { mkldnn_engine, ctx.GetPlace()) {
PADDLE_ENFORCE_EQ(
in_x->layout(), DataLayout::kMKLDNN,
platform::errors::InvalidArgument("Wrong layout set for Input tensor"));
PADDLE_ENFORCE_NE(
in_x->format(), MKLDNNMemoryFormat::undef,
platform::errors::InvalidArgument("Wrong format set for Input tensor"));
PADDLE_ENFORCE_EQ(out_grad->layout(), DataLayout::kMKLDNN,
platform::errors::InvalidArgument(
"Wrong layout set for Input output_grad tensor"));
PADDLE_ENFORCE_NE(out_grad->format(), MKLDNNMemoryFormat::undef,
platform::errors::InvalidArgument(
"Wrong format set for Input output_grad tensor"));
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
ctx.Attr<bool>("is_test"), false, ctx.Attr<bool>("is_test"), false,
platform::errors::InvalidArgument( platform::errors::InvalidArgument(
...@@ -187,10 +156,7 @@ class PoolingMKLDNNHandler ...@@ -187,10 +156,7 @@ class PoolingMKLDNNHandler
const auto dt = framework::ToMKLDNNDataType( const auto dt = framework::ToMKLDNNDataType(
framework::TransToProtoVarType(in_x->dtype())); framework::TransToProtoVarType(in_x->dtype()));
auto src_md = dnnl::memory::desc(src_tz, dt, in_x->format());
auto dst_md = dnnl::memory::desc(diff_dst_tz, dt, MKLDNNMemoryFormat::any); auto dst_md = dnnl::memory::desc(diff_dst_tz, dt, MKLDNNMemoryFormat::any);
auto diff_dst_md = dnnl::memory::desc(
diff_dst_tz, platform::MKLDNNGetDataType<T>(), out_grad->format());
auto diff_src_md = dnnl::memory::desc( auto diff_src_md = dnnl::memory::desc(
diff_src_tz, platform::MKLDNNGetDataType<T>(), MKLDNNMemoryFormat::any); diff_src_tz, platform::MKLDNNGetDataType<T>(), MKLDNNMemoryFormat::any);
...@@ -211,14 +177,15 @@ class PoolingMKLDNNHandler ...@@ -211,14 +177,15 @@ class PoolingMKLDNNHandler
? dnnl::algorithm::pooling_max ? dnnl::algorithm::pooling_max
: (exclude_padding ? dnnl::algorithm::pooling_avg_exclude_padding : (exclude_padding ? dnnl::algorithm::pooling_avg_exclude_padding
: dnnl::algorithm::pooling_avg_include_padding), : dnnl::algorithm::pooling_avg_include_padding),
src_md, dst_md, strides, ksize, mkldnn_paddings[0], mkldnn_paddings[1]); in_x->mem_desc(), dst_md, strides, ksize, mkldnn_paddings[0],
mkldnn_paddings[1]);
this->AcquireBackwardPrimitiveDescriptor( this->AcquireBackwardPrimitiveDescriptor(
pooling_type == "max" pooling_type == "max"
? dnnl::algorithm::pooling_max ? dnnl::algorithm::pooling_max
: (exclude_padding ? dnnl::algorithm::pooling_avg_exclude_padding : (exclude_padding ? dnnl::algorithm::pooling_avg_exclude_padding
: dnnl::algorithm::pooling_avg_include_padding), : dnnl::algorithm::pooling_avg_include_padding),
diff_src_md, diff_dst_md, strides, ksize, mkldnn_paddings[0], diff_src_md, out_grad->mem_desc(), strides, ksize, mkldnn_paddings[0],
mkldnn_paddings[1]); mkldnn_paddings[1]);
} }
...@@ -327,8 +294,7 @@ class PoolMKLDNNOpKernel : public paddle::framework::OpKernel<T> { ...@@ -327,8 +294,7 @@ class PoolMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
} }
astream.wait(); astream.wait();
output->set_layout(DataLayout::kMKLDNN); output->set_mem_desc(dst_memory->get_desc());
output->set_format(platform::GetMKLDNNFormat(*dst_memory));
} }
}; };
...@@ -369,8 +335,7 @@ class PoolMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> { ...@@ -369,8 +335,7 @@ class PoolMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
} }
astream.wait(); astream.wait();
in_x_grad->set_layout(DataLayout::kMKLDNN); in_x_grad->set_mem_desc(diff_src_memory->get_desc());
in_x_grad->set_format(platform::GetMKLDNNFormat(*diff_src_memory));
} // Compute() } // Compute()
}; };
......
...@@ -41,9 +41,6 @@ class PReluMKLDNNHandler ...@@ -41,9 +41,6 @@ class PReluMKLDNNHandler
platform::CreateKey(dev_ctx, phi::vectorize(x->dims()), platform::CreateKey(dev_ctx, phi::vectorize(x->dims()),
uniq_name)) { uniq_name)) {
if (unlikely(!this->isCached())) { if (unlikely(!this->isCached())) {
auto x_md = memory::desc(phi::vectorize(x->dims()),
MKLDNNGetDataType<T>(), x->format());
auto weights_dims = phi::vectorize(weights->dims()); auto weights_dims = phi::vectorize(weights->dims());
// weights must have same size as X only for "element" case // weights must have same size as X only for "element" case
...@@ -59,30 +56,28 @@ class PReluMKLDNNHandler ...@@ -59,30 +56,28 @@ class PReluMKLDNNHandler
memory::format_tag::any); memory::format_tag::any);
this->AcquireForwardPrimitiveDescriptor(dnnl::prop_kind::forward_training, this->AcquireForwardPrimitiveDescriptor(dnnl::prop_kind::forward_training,
x_md, weights_md); x->mem_desc(), weights_md);
if (!is_test) if (!is_test)
this->AcquireBackwardPrimitiveDescriptor(x_md, weights_md, x_md, this->AcquireBackwardPrimitiveDescriptor(x->mem_desc(), weights_md,
weights_md); x->mem_desc(), weights_md);
} }
} }
std::shared_ptr<memory> AcquireWeightsMemoryPossiblyWithReorder( std::shared_ptr<memory> AcquireWeightsMemoryPossiblyWithReorder(
const Tensor* input, const bool is_test) { const Tensor* weights, const bool is_test) {
const T* input_data = input->data<T>(); const T* weights_data = weights->data<T>();
// if weights are 1D, every format tag is correct, so we accept // if weights are 1D, every format tag is correct, so we accept
// format_tag::any's output and no reorder is needed // format_tag::any's output and no reorder is needed
if (input->dims().size() == 1) { if (weights->dims().size() == 1) {
return this->AcquireMemoryFromPrimitive(this->fwd_pd_->weights_desc(), return this->AcquireMemoryFromPrimitive(this->fwd_pd_->weights_desc(),
to_void_cast<T>(input_data), to_void_cast<T>(weights_data),
"@alpha_mem_p"); "@alpha_mem_p");
} }
auto user_weights_md = memory::desc(
phi::vectorize(input->dims()), MKLDNNGetDataType<T>(), input->format());
return this->AcquireMemoryWithReorder( return this->AcquireMemoryWithReorder(
user_weights_md, this->fwd_pd_->weights_desc(), weights->mem_desc(), this->fwd_pd_->weights_desc(),
to_void_cast<T>(input_data), "@alpha_mem_p", is_test); to_void_cast<T>(weights_data), "@alpha_mem_p", is_test);
} }
std::shared_ptr<memory> AcquireDiffWeightsMemory(Tensor* output) { std::shared_ptr<memory> AcquireDiffWeightsMemory(Tensor* output) {
...@@ -128,8 +123,7 @@ class PReluMKLDNNKernel : public framework::OpKernel<T> { ...@@ -128,8 +123,7 @@ class PReluMKLDNNKernel : public framework::OpKernel<T> {
{DNNL_ARG_DST, *dst_memory_p}}); {DNNL_ARG_DST, *dst_memory_p}});
astream.wait(); astream.wait();
out->set_layout(framework::DataLayout::kMKLDNN); out->set_mem_desc(dst_memory_p->get_desc());
out->set_format(GetMKLDNNFormat(*dst_memory_p));
} }
}; };
...@@ -174,8 +168,7 @@ class PReluGradMKLDNNKernel : public framework::OpKernel<T> { ...@@ -174,8 +168,7 @@ class PReluGradMKLDNNKernel : public framework::OpKernel<T> {
{DNNL_ARG_DIFF_WEIGHTS, *diff_weights_memory_p}}); {DNNL_ARG_DIFF_WEIGHTS, *diff_weights_memory_p}});
astream.wait(); astream.wait();
dx->set_layout(framework::DataLayout::kMKLDNN); dx->set_mem_desc(diff_src_memory_p->get_desc());
dx->set_format(GetMKLDNNFormat(*diff_src_memory_p));
} }
}; };
} // namespace operators } // namespace operators
......
...@@ -54,8 +54,7 @@ class ScaleMKLDNNKernel : public framework::OpKernel<T> { ...@@ -54,8 +54,7 @@ class ScaleMKLDNNKernel : public framework::OpKernel<T> {
{DNNL_ARG_TO, *dst_memory_p}}); {DNNL_ARG_TO, *dst_memory_p}});
astream.wait(); astream.wait();
out->set_layout(framework::DataLayout::kMKLDNN); out->set_mem_desc(dst_memory_p->get_desc());
out->set_format(platform::GetMKLDNNFormat(*dst_memory_p));
} }
}; };
} // namespace operators } // namespace operators
......
...@@ -13,7 +13,7 @@ See the License for the specific language governing permissions and ...@@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/platform/mkldnn_helper.h" #include "paddle/fluid/platform/mkldnn_reuse.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
...@@ -40,9 +40,13 @@ class ShapeMKLDNNKernel : public framework::OpKernel<T> { ...@@ -40,9 +40,13 @@ class ShapeMKLDNNKernel : public framework::OpKernel<T> {
out_data[i] = in_dims[i]; out_data[i] = in_dims[i];
} }
auto* out = ctx.Output<Tensor>("Out"); dnnl::memory::desc out_mem_desc(
out->set_layout(framework::DataLayout::kMKLDNN); phi::vectorize(out_t->dims()),
out->set_format(platform::GetPlainMKLDNNFormat(out->dims().size())); framework::ToMKLDNNDataType(
framework::TransToProtoVarType(out_t->dtype())),
platform::GetPlainMKLDNNFormat(out_t->dims().size()));
out_t->set_mem_desc(out_mem_desc);
} }
}; };
} // namespace operators } // namespace operators
......
...@@ -29,11 +29,8 @@ class ShuffleChannelMKLDNNHandler ...@@ -29,11 +29,8 @@ class ShuffleChannelMKLDNNHandler
: platform::MKLDNNHandlerNoCachingT<T, dnnl::shuffle_forward>(engine, : platform::MKLDNNHandlerNoCachingT<T, dnnl::shuffle_forward>(engine,
cpu_place) { cpu_place) {
static constexpr int channel_axis = 1; static constexpr int channel_axis = 1;
const auto md = dnnl::memory::desc(phi::vectorize(x->dims()),
MKLDNNGetDataType<T>(), x->format());
this->AcquireForwardPrimitiveDescriptor(dnnl::prop_kind::forward_training, this->AcquireForwardPrimitiveDescriptor(dnnl::prop_kind::forward_training,
md, channel_axis, group); x->mem_desc(), channel_axis, group);
} }
}; };
...@@ -64,8 +61,7 @@ class ShuffleChannelMKLDNNKernel : public framework::OpKernel<T> { ...@@ -64,8 +61,7 @@ class ShuffleChannelMKLDNNKernel : public framework::OpKernel<T> {
{DNNL_ARG_DST, *dst_memory_p}}); {DNNL_ARG_DST, *dst_memory_p}});
astream.wait(); astream.wait();
out->set_layout(framework::DataLayout::kMKLDNN); out->set_mem_desc(dst_memory_p->get_desc());
out->set_format(x->format());
} }
}; };
} // namespace operators } // namespace operators
......
...@@ -47,12 +47,8 @@ class SoftmaxMKLDNNHandler ...@@ -47,12 +47,8 @@ class SoftmaxMKLDNNHandler
platform::errors::InvalidArgument( platform::errors::InvalidArgument(
"The shape of input and output tensor must be identical.")); "The shape of input and output tensor must be identical."));
auto softmax_tz = phi::vectorize(input->dims()); this->AcquireForwardPrimitiveDescriptor(prop_kind::forward_scoring,
auto md = memory::desc(softmax_tz, platform::MKLDNNGetDataType<T>(), input->mem_desc(), axis);
input->format());
this->AcquireForwardPrimitiveDescriptor(prop_kind::forward_scoring, md,
axis);
} }
SoftmaxMKLDNNHandler(const framework::ExecutionContext& ctx, SoftmaxMKLDNNHandler(const framework::ExecutionContext& ctx,
...@@ -73,17 +69,11 @@ class SoftmaxMKLDNNHandler ...@@ -73,17 +69,11 @@ class SoftmaxMKLDNNHandler
auto dims = out_grad->dims(); // input and output share the same shape auto dims = out_grad->dims(); // input and output share the same shape
const int axis = const int axis =
phi::funcs::CanonicalAxis(ctx.Attr<int>("axis"), dims.size()); phi::funcs::CanonicalAxis(ctx.Attr<int>("axis"), dims.size());
auto softmax_tz = phi::vectorize<int64_t>(dims);
auto data_softmax_md = MKLDNNMemDesc(
softmax_tz, platform::MKLDNNGetDataType<T>(), out->format());
auto diff_softmax_md = MKLDNNMemDesc(
softmax_tz, platform::MKLDNNGetDataType<T>(), out_grad->format());
this->AcquireForwardPrimitiveDescriptor(prop_kind::forward_scoring, this->AcquireForwardPrimitiveDescriptor(prop_kind::forward_scoring,
data_softmax_md, axis); out->mem_desc(), axis);
this->AcquireBackwardPrimitiveDescriptor(diff_softmax_md, data_softmax_md, this->AcquireBackwardPrimitiveDescriptor(out_grad->mem_desc(),
axis); out->mem_desc(), axis);
} }
}; };
...@@ -128,9 +118,7 @@ class SoftmaxMKLDNNKernel : public paddle::framework::OpKernel<T> { ...@@ -128,9 +118,7 @@ class SoftmaxMKLDNNKernel : public paddle::framework::OpKernel<T> {
}); });
} }
output->set_layout(framework::DataLayout::kMKLDNN); output->set_mem_desc(softmax_dst_memory_p->get_desc());
// Softmax output format is the same as input one
output->set_format(input->format());
} }
}; };
...@@ -162,8 +150,7 @@ class SoftmaxMKLDNNGradKernel : public paddle::framework::OpKernel<T> { ...@@ -162,8 +150,7 @@ class SoftmaxMKLDNNGradKernel : public paddle::framework::OpKernel<T> {
{DNNL_ARG_DIFF_SRC, *diff_src_memory_p}}); {DNNL_ARG_DIFF_SRC, *diff_src_memory_p}});
astream.wait(); astream.wait();
in_x_grad->set_layout(framework::DataLayout::kMKLDNN); in_x_grad->set_mem_desc(diff_src_memory_p->get_desc());
in_x_grad->set_format(platform::GetMKLDNNFormat(*diff_src_memory_p));
} }
}; };
} // namespace operators } // namespace operators
......
...@@ -29,12 +29,11 @@ class SoftplusMKLDNNHandler ...@@ -29,12 +29,11 @@ class SoftplusMKLDNNHandler
: platform::MKLDNNHandlerNoCachingT<T, dnnl::binary>(engine, : platform::MKLDNNHandlerNoCachingT<T, dnnl::binary>(engine,
ctx.GetPlace()) { ctx.GetPlace()) {
auto x_tz = phi::vectorize(x->dims()); auto x_tz = phi::vectorize(x->dims());
auto x_md =
dnnl::memory::desc(x_tz, platform::MKLDNNGetDataType<T>(), x->format());
auto beta_tz = std::vector<int64_t>(x_tz.size(), 1); auto beta_tz = std::vector<int64_t>(x_tz.size(), 1);
auto beta_md = dnnl::memory::desc(beta_tz, platform::MKLDNNGetDataType<T>(), auto beta_md =
x->format()); dnnl::memory::desc(beta_tz, platform::MKLDNNGetDataType<T>(),
platform::GetPlainMKLDNNFormat(x_tz.size()));
dnnl::post_ops post_ops; dnnl::post_ops post_ops;
post_ops.append_eltwise(1.0f, dnnl::algorithm::eltwise_soft_relu, 0.0f, post_ops.append_eltwise(1.0f, dnnl::algorithm::eltwise_soft_relu, 0.0f,
...@@ -50,7 +49,8 @@ class SoftplusMKLDNNHandler ...@@ -50,7 +49,8 @@ class SoftplusMKLDNNHandler
attrs.set_post_ops(post_ops); attrs.set_post_ops(post_ops);
this->AcquireForwardPrimitiveDescriptor(attrs, dnnl::algorithm::binary_mul, this->AcquireForwardPrimitiveDescriptor(attrs, dnnl::algorithm::binary_mul,
x_md, beta_md, x_md); x->mem_desc(), beta_md,
x->mem_desc());
} }
std::shared_ptr<dnnl::memory> AcquireBetaMemory(const float* beta) { std::shared_ptr<dnnl::memory> AcquireBetaMemory(const float* beta) {
...@@ -129,8 +129,7 @@ void custom_softplus_eltwise_forward(const framework::ExecutionContext& ctx) { ...@@ -129,8 +129,7 @@ void custom_softplus_eltwise_forward(const framework::ExecutionContext& ctx) {
binary_p->execute(astream, args); binary_p->execute(astream, args);
astream.wait(); astream.wait();
out->set_layout(framework::DataLayout::kMKLDNN); out->set_mem_desc(dst_memory_p->get_desc());
out->set_format(platform::GetMKLDNNFormat(*dst_memory_p));
} }
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle
...@@ -28,18 +28,22 @@ class TestExpandV2OneDNNOp(OpTest): ...@@ -28,18 +28,22 @@ class TestExpandV2OneDNNOp(OpTest):
self.op_type = "expand_v2" self.op_type = "expand_v2"
self.init_data() self.init_data()
self.x = np.random.random(self.ori_shape).astype("float32") self.x = np.random.random(self.ori_shape).astype("float32")
self.set_inputs()
self.attrs = {'shape': self.shape, 'use_mkldnn': True} self.attrs = {'shape': self.shape, 'use_mkldnn': True}
self.set_inputs()
self.set_additional_inputs()
output = np.tile(self.x, self.expand_times) output = np.tile(self.x, self.expand_times)
self.outputs = {'Out': output} self.outputs = {'Out': output}
def set_inputs(self): def set_inputs(self):
self.inputs = {'X': self.x} self.inputs = {'X': self.x}
def set_additional_inputs(self):
pass
def init_data(self): def init_data(self):
self.ori_shape = [1, 140] self.ori_shape = [1, 1, 1, 140]
self.shape = [12, 140] self.shape = [2, 3, 4, 140]
self.expand_times = [12, 1] self.expand_times = [2, 3, 4, 1]
def test_check_output(self): def test_check_output(self):
self.check_output_with_place(core.CPUPlace()) self.check_output_with_place(core.CPUPlace())
...@@ -74,7 +78,7 @@ class TestExpandV2ExpandShapesTensor1OneDNNOp(TestExpandV2OneDNNOp): ...@@ -74,7 +78,7 @@ class TestExpandV2ExpandShapesTensor1OneDNNOp(TestExpandV2OneDNNOp):
self.ori_shape = [100, 1] self.ori_shape = [100, 1]
self.expand_times = [1, 2] self.expand_times = [1, 2]
self.expand_shape = [100, 2] self.expand_shape = [100, 2]
self.shape = [-1, -1] self.shape = [100, 2]
def calc_expand_shapes_tensor(self): def calc_expand_shapes_tensor(self):
self.expand_shapes_tensor = [] self.expand_shapes_tensor = []
...@@ -82,12 +86,9 @@ class TestExpandV2ExpandShapesTensor1OneDNNOp(TestExpandV2OneDNNOp): ...@@ -82,12 +86,9 @@ class TestExpandV2ExpandShapesTensor1OneDNNOp(TestExpandV2OneDNNOp):
self.expand_shapes_tensor.append(("x" + str(index), np.ones( self.expand_shapes_tensor.append(("x" + str(index), np.ones(
(1)).astype('int32') * ele)) (1)).astype('int32') * ele))
def set_inputs(self): def set_additional_inputs(self):
self.calc_expand_shapes_tensor() self.calc_expand_shapes_tensor()
self.inputs = { self.inputs['expand_shapes_tensor'] = self.expand_shapes_tensor
'X': self.x,
'expand_shapes_tensor': self.expand_shapes_tensor
}
class TestExpandV2ExpandShapesTensor2OneDNNOp( class TestExpandV2ExpandShapesTensor2OneDNNOp(
...@@ -104,13 +105,10 @@ class TestExpandV2ShapesTensorOneDNNOp(TestExpandV2OneDNNOp): ...@@ -104,13 +105,10 @@ class TestExpandV2ShapesTensorOneDNNOp(TestExpandV2OneDNNOp):
self.ori_shape = [100] self.ori_shape = [100]
self.expand_times = [2, 1] self.expand_times = [2, 1]
self.expand_shape = [2, 100] self.expand_shape = [2, 100]
self.shape = [-1, -1] self.shape = [2, 100]
def set_inputs(self): def set_additional_inputs(self):
self.inputs = { self.inputs['Shape'] = np.array(self.expand_shape).astype("int32")
'X': self.x,
'Shape': np.array(self.expand_shape).astype("int32")
}
# BF16 TESTS # BF16 TESTS
...@@ -118,6 +116,7 @@ def create_expand_v2_bf16_test_class(parent): ...@@ -118,6 +116,7 @@ def create_expand_v2_bf16_test_class(parent):
@OpTestTool.skip_if_not_cpu_bf16() @OpTestTool.skip_if_not_cpu_bf16()
class TestExpandV2BF16OneDNNOp(parent): class TestExpandV2BF16OneDNNOp(parent):
def set_inputs(self): def set_inputs(self):
self.attrs['mkldnn_data_type'] = 'bfloat16'
self.inputs = {"X": convert_float_to_uint16(self.x)} self.inputs = {"X": convert_float_to_uint16(self.x)}
def calculate_grads(self): def calculate_grads(self):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册