Unverified commit c9f4fcf3 authored by jakpiase, committed by GitHub

OneDNN md-in-tensor refactoring part 1: Added main changes for md-in-tensor (#41303)

* changes for md in tensor

* ci fix

* Temporarily limited dims for test

* ci fix

* removed unnecessary includes

* added reviewers suggestions

* checked out two files to avoid changing more than 19 in a single PR

* minor fix

* reverted one file to reduce files changed to 19
Parent 3e8b6bbc
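
Every hunk below applies the same pattern: stop storing just a dnnl::memory::format_tag on the Tensor and rebuilding a descriptor from dims + dtype + tag at each use site; store the full dnnl::memory::desc once and pass it around. A minimal standalone oneDNN sketch of the distinction (plain dnnl, illustrative dims, not Paddle code):

```cpp
#include "dnnl.hpp"

int main() {
  // A full memory descriptor carries dims, data type and blocking/strides,
  // so it describes the tensor completely on its own.
  dnnl::memory::desc md({8, 3, 224, 224}, dnnl::memory::data_type::f32,
                        dnnl::memory::format_tag::nChw16c);
  // Old style: keep only the tag on the tensor and re-derive a descriptor
  // later, guessing padded/blocked layouts along the way.
  // New style (this PR): keep `md` itself on the tensor via set_mem_desc()
  // and hand it straight to primitives - no guessing required.
  return md.get_size() > 0 ? 0 : 1;
}
```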
......@@ -134,15 +134,6 @@ void TransDataLayoutFromMKLDNN(const OpKernelType& kernel_type_for_var,
void innerTransDataLayoutFromMKLDNN(DataLayout in_layout, DataLayout out_layout,
const Tensor& in, Tensor* out,
platform::Place place, bool always_copy) {
PADDLE_ENFORCE_NE(in.format(), MKLDNNMemoryFormat::undef,
platform::errors::InvalidArgument(
"Input tensor format is invalid. Input tensor should "
"have specified memory format."));
PADDLE_ENFORCE_NE(in.format(), MKLDNNMemoryFormat::any,
platform::errors::InvalidArgument(
"Input tensor format is invalid. Input tensor should "
"have specified memory format."));
// Set default as NCHW in case not specified
out_layout =
out_layout == DataLayout::kAnyLayout ? DataLayout::kNCHW : out_layout;
......@@ -162,22 +153,24 @@ void innerTransDataLayoutFromMKLDNN(DataLayout in_layout, DataLayout out_layout,
"Input tensor type (%s) is not supported.",
DataTypeToString(framework::TransToProtoVarType(in.dtype()))));
auto in_format = platform::MKLDNNFormatForSize(in_tz.size(), in.format());
auto out_format =
platform::MKLDNNFormatForSize(in_tz.size(), ToMKLDNNFormat(out_layout));
dnnl::memory::desc out_mem_desc(out_tz, in_type, out_format);
// Output tensor has the same dims as input. Reorder doesn't change dims
out->set_mem_desc(out_mem_desc);
out->Resize(in.dims());
if ((in_format != out_format) || always_copy) {
if ((in.mem_desc() != out->mem_desc()) || always_copy) {
void* in_data = GetDataFromTensor(in, in_type);
platform::ReorderMKLDNNHandler handler(
in_tz, framework::TransToProtoVarType(in.dtype()), in_type, cpu_engine);
auto reorder_src_memory_p = handler.AcquireSrcMemory(in_format, in_data);
auto reorder_src_memory_p =
handler.AcquireSrcMemory(in.mem_desc(), in_data);
auto reorder_dst_memory_p =
handler.AcquireDstMemory(out, out_format, place);
handler.AcquireDstMemory(out, out->mem_desc(), place);
auto reorder_p =
handler.AcquireReorder(reorder_dst_memory_p, reorder_src_memory_p);
......
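
For context, the reorder the handler performs above reduces to plain oneDNN once both ends are described by full memory descriptors; a minimal standalone sketch (illustrative dims, not Paddle code):

```cpp
#include "dnnl.hpp"

// Reorder a buffer from NCHW to NHWC; this is the whole job the handler
// performs once src and dst both carry complete memory descriptors.
void reorder_nchw_to_nhwc() {
  dnnl::engine eng(dnnl::engine::kind::cpu, 0);
  dnnl::stream strm(eng);
  dnnl::memory::dims dims = {2, 3, 4, 5};
  dnnl::memory::desc src_md(dims, dnnl::memory::data_type::f32,
                            dnnl::memory::format_tag::nchw);
  dnnl::memory::desc dst_md(dims, dnnl::memory::data_type::f32,
                            dnnl::memory::format_tag::nhwc);
  dnnl::memory src(src_md, eng);
  dnnl::memory dst(dst_md, eng);
  dnnl::reorder(src, dst).execute(strm, src, dst);
  strm.wait();
}
```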
......@@ -70,8 +70,10 @@ void TransformData(const OpKernelType &expected_kernel_type,
paddle::platform::MKLDNNDeviceContext::tls()
.set_cur_paddle_data_layout(lin);
}
out.set_layout(DataLayout::kMKLDNN);
out.set_format(out_format);
dnnl::memory::desc out_mem_desc(
vectorize(out.dims()),
ToMKLDNNDataType(TransToProtoVarType(in.type())), out_format);
out.set_mem_desc(out_mem_desc);
} else {
// Case2 - transform from MKLDNN OPKernel to Non-MKLDNN OPKernel
// Do transform via MKLDNN lib
......@@ -121,8 +123,9 @@ void SetTensorToVariable(const Variable &in_var, const Tensor &tensor,
tran_lod_tensor->set_lod(in_lod_tensor.lod());
tran_lod_tensor->set_layout(in_lod_tensor.layout());
#ifdef PADDLE_WITH_MKLDNN
tran_lod_tensor->set_format(in_lod_tensor.format());
tran_lod_tensor->set_mem_desc(in_lod_tensor.mem_desc());
#endif
tran_lod_tensor->set_layout(in_lod_tensor.layout());
tran_lod_tensor->ShareDataWith(tensor);
} else if (in_var.IsType<phi::SelectedRows>()) {
auto &in_selected_rows = in_var.Get<phi::SelectedRows>();
......
......@@ -51,7 +51,7 @@ void TensorCopyImpl(const TENSOR& src, const platform::Place& dst_place,
auto src_place = src.place();
auto src_ptr = src.data();
#ifdef PADDLE_WITH_MKLDNN
dst->set_format(src.format());
dst->set_mem_desc(src.mem_desc());
// Due to padding, oneDNN tensors may be bigger in size
// than numel() * size(type())
auto dst_ptr =
......@@ -61,6 +61,7 @@ void TensorCopyImpl(const TENSOR& src, const platform::Place& dst_place,
#else
auto dst_ptr = dst->mutable_data(dst_place, src.dtype());
#endif
dst->set_layout(src.layout());
if (src_ptr == dst_ptr && src_place == dst_place) {
VLOG(3) << "Skip copy the same data async from " << src_place << " to "
<< dst_place;
......
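
The padding comment above is the reason the destination buffer is sized from the descriptor instead of numel(); a standalone illustration (assumed dims):

```cpp
#include "dnnl.hpp"

// Blocked layouts pad dimensions up to the block size, so the physical
// buffer can exceed the logical element count: here C = 3 is padded to 8.
dnnl::memory::desc md({1, 3, 5, 5}, dnnl::memory::data_type::f32,
                      dnnl::memory::format_tag::nChw8c);
// md.get_size() == 1 * 8 * 5 * 5 * 4 = 800 bytes,
// while numel() * sizeof(float) == 75 * 4 = 300 bytes.
```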
......@@ -47,9 +47,9 @@ class CastMKLDNNKernel : public framework::OpKernel<T> {
dev_ctx.GetEngine());
auto reorder_src_memory_p = reorder_handler.AcquireSrcMemory(
x->format(), platform::to_void_cast(x->data<T>()));
auto reorder_dst_memory_p =
reorder_handler.AcquireDstMemory(out, x->format(), dev_ctx.GetPlace());
x->mem_desc(), platform::to_void_cast(x->data<T>()));
auto reorder_dst_memory_p = reorder_handler.AcquireDstMemory(
out, x->mem_desc(), dev_ctx.GetPlace());
auto reorder_p = reorder_handler.AcquireReorder(reorder_dst_memory_p,
reorder_src_memory_p);
......@@ -58,7 +58,7 @@ class CastMKLDNNKernel : public framework::OpKernel<T> {
astream.wait();
out->set_layout(framework::DataLayout::kMKLDNN);
out->set_format(platform::GetMKLDNNFormat(*reorder_dst_memory_p));
out->set_mem_desc(reorder_dst_memory_p->get_desc());
}
};
} // namespace operators
......
......@@ -45,20 +45,19 @@ class ExpandMKLDNNKernel : public paddle::framework::OpKernel<T> {
out_new_dims[i] = out_new_dims[i] > 0 ? out_new_dims[i] : x_vec_dims[i];
}
dnnl::memory::format_tag x_format_tag = x->format();
dnnl::memory::desc x_mem_desc = x->mem_desc();
if (x_vec_dims.size() != out_new_dims.size()) {
x_format_tag =
GetExtendedFormatTag(x_vec_dims, out_new_dims.size(), x_format_tag);
x_mem_desc = GetExtendedMemoryDescriptor(x_mem_desc, x_vec_dims,
out_new_dims.size());
}
out->Resize(phi::make_ddim(out_new_dims));
out->set_format(x_format_tag);
paddle::platform::BroadcastDataMKLDNNHandler<T> handler(
dnnl::algorithm::binary_add, onednn_engine, ctx.GetPlace(), out, x,
0.0f, 1.0f, x_vec_dims);
0.0f, 1.0f, x_mem_desc);
auto src_memory_p = handler.AcquireSrcMemory(x);
auto dst_memory_p = handler.AcquireDstMemory(out);
auto dst_memory_p = handler.AcquireDstMemory(out); // acquires zeroed mem
auto binary_p = handler.AcquireForwardPrimitive();
const std::unordered_map<int, dnnl::memory> args = {
......@@ -70,22 +69,18 @@ class ExpandMKLDNNKernel : public paddle::framework::OpKernel<T> {
binary_p->execute(astream, args);
astream.wait();
out->set_layout(paddle::framework::DataLayout::kMKLDNN);
out->set_format(paddle::platform::GetMKLDNNFormat(*dst_memory_p));
out->set_mem_desc(dst_memory_p->get_desc());
}
private:
dnnl::memory::format_tag GetExtendedFormatTag(
std::vector<int64_t>& dims, int new_size, // NOLINT
dnnl::memory::format_tag format_tag) const {
dnnl::memory::desc md(dims, paddle::platform::MKLDNNGetDataType<T>(),
format_tag);
dnnl::memory::desc GetExtendedMemoryDescriptor(
const dnnl::memory::desc& x_mem_desc,
const std::vector<int64_t>& x_vec_dims, int new_size) const {
std::vector<int64_t> new_dims(new_size, 1);
std::copy(dims.begin(), dims.end(),
new_dims.begin() + new_size - dims.size());
std::copy(x_vec_dims.begin(), x_vec_dims.end(),
new_dims.begin() + new_size - x_vec_dims.size());
dims = std::move(new_dims);
return paddle::platform::GetMKLDNNFormat(md.reshape(dims));
return x_mem_desc.reshape(new_dims);
}
};
......
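
GetExtendedMemoryDescriptor leans on dnnl::memory::desc::reshape to raise the rank; a standalone equivalent under the same assumptions (helper name is illustrative):

```cpp
#include <algorithm>
#include <vector>
#include "dnnl.hpp"

// Prepend 1s to the dims and reshape, e.g. a {4, 5} descriptor becomes
// {1, 1, 4, 5} when new_size == 4; strides are recomputed consistently.
dnnl::memory::desc ExtendRank(const dnnl::memory::desc& md,
                              const std::vector<int64_t>& dims,
                              int new_size) {
  std::vector<int64_t> new_dims(new_size, 1);
  std::copy(dims.begin(), dims.end(),
            new_dims.begin() + new_size - dims.size());
  return md.reshape(new_dims);
}
```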
......@@ -15,26 +15,6 @@ limitations under the License. */
#include "paddle/fluid/operators/utils.h"
#include "paddle/fluid/platform/mkldnn_reuse.h"
static dnnl::memory::format_tag get_plain_format_tag(
const paddle::framework::Tensor* tensor) {
auto tensor_dims_size = tensor->dims().size();
switch (tensor_dims_size) {
case 1:
return dnnl::memory::format_tag::a;
case 2:
return dnnl::memory::format_tag::ab;
case 3:
return dnnl::memory::format_tag::abc;
case 4:
return dnnl::memory::format_tag::abcd;
case 5:
return dnnl::memory::format_tag::abcde;
}
return dnnl::memory::format_tag::abcdef;
}
namespace paddle {
namespace operators {
......@@ -105,11 +85,12 @@ class SliceMKLDNNKernel : public framework::OpKernel<T> {
onednn_engine);
auto reorder_src_memory_p = reorder_handler.AcquireSrcMemory(
x->format(), platform::to_void_cast(x->data<T>()));
x->mem_desc(), platform::to_void_cast(x->data<T>()));
auto slice_mem_p = reorder_handler.AcquireSubmemory(slice_dims, offsets,
reorder_src_memory_p);
auto reorder_dst_memory_p = reorder_handler.AcquireDstMemory(
out, slice_dims, get_plain_format_tag(x), ctx.GetPlace());
out, slice_dims, platform::GetPlainMKLDNNFormat(x_vec_dims.size()),
ctx.GetPlace());
auto reorder_p =
reorder_handler.AcquireReorder(reorder_dst_memory_p, slice_mem_p);
......@@ -133,9 +114,7 @@ class SliceMKLDNNKernel : public framework::OpKernel<T> {
astream.wait();
out->Resize(phi::make_ddim(new_out_dims));
out->set_layout(framework::DataLayout::kMKLDNN);
out->set_format(platform::GetMKLDNNFormat(
reorder_dst_memory_p->get_desc().reshape(new_out_dims)));
out->set_mem_desc(reorder_dst_memory_p->get_desc().reshape(new_out_dims));
}
};
template <typename T>
......
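
The slice itself comes from dnnl::memory::desc::submemory_desc, which AcquireSubmemory wraps; a minimal sketch with made-up dims and offsets:

```cpp
#include "dnnl.hpp"

// Carve a 2x8 view starting at row 1 out of a 4x8 row-major descriptor;
// the view keeps the parent's strides, so no data moves until the reorder.
dnnl::memory::desc src({4, 8}, dnnl::memory::data_type::f32,
                       dnnl::memory::format_tag::ab);
dnnl::memory::desc slice = src.submemory_desc({2, 8}, {1, 0});
```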
......@@ -101,7 +101,7 @@ class SplitMKLDNNKernel : public framework::OpKernel<T> {
x_vec_dims, framework::TransToProtoVarType(x->dtype()), x_type,
onednn_engine);
auto reorder_src_memory_p = reorder_handler.AcquireSrcMemory(
x->format(), platform::to_void_cast(x->data<T>()));
x->mem_desc(), platform::to_void_cast(x->data<T>()));
for (size_t i = 0; i < outs_number; ++i) {
auto out_vec_dims = phi::vectorize(outs[i]->dims());
......@@ -117,8 +117,7 @@ class SplitMKLDNNKernel : public framework::OpKernel<T> {
offset[axis] += num > 0 ? x->dims()[axis] / num : sections[i];
outs[i]->set_layout(framework::DataLayout::kMKLDNN);
outs[i]->set_format(platform::GetMKLDNNFormat(*reorder_dst_memory_p));
outs[i]->set_mem_desc(reorder_dst_memory_p->get_desc());
}
astream.wait();
}
......
......@@ -77,10 +77,10 @@ class ReduceMKLDNNKernel : public framework::OpKernel<T> {
input_type, onednn_engine);
auto reorder_src_memory_p = reorder_handler.AcquireSrcMemory(
input->format(), platform::to_void_cast(input->data<T>()));
input->mem_desc(), platform::to_void_cast(input->data<T>()));
auto reorder_dst_memory_p = reorder_handler.AcquireDstMemory(
output, input->format(), ctx.GetPlace());
output, input->mem_desc(), ctx.GetPlace());
auto reorder_p = reorder_handler.AcquireReorder(reorder_src_memory_p,
reorder_dst_memory_p);
......@@ -88,10 +88,8 @@ class ReduceMKLDNNKernel : public framework::OpKernel<T> {
reorder_p->execute(astream, *reorder_src_memory_p, *reorder_dst_memory_p);
astream.wait();
output->set_layout(framework::DataLayout::kMKLDNN);
output->set_format(
platform::GetMKLDNNFormat(reorder_dst_memory_p->get_desc().reshape(
phi::vectorize<int64_t>(output->dims()))));
output->set_mem_desc(reorder_dst_memory_p->get_desc().reshape(
phi::vectorize<int64_t>(output->dims())));
} else {
platform::ReductionMKLDNNHandler<T> handler(reduction_type, 0.0f, 0.0f,
onednn_engine, ctx.GetPlace(),
......@@ -107,10 +105,8 @@ class ReduceMKLDNNKernel : public framework::OpKernel<T> {
reduction_p->execute(astream, reduction_args);
astream.wait();
output->set_layout(framework::DataLayout::kMKLDNN);
output->set_format(
platform::GetMKLDNNFormat(dst_memory_p->get_desc().reshape(
phi::vectorize<int64_t>(output->dims()))));
output->set_mem_desc(dst_memory_p->get_desc().reshape(
phi::vectorize<int64_t>(output->dims())));
}
}
};
......@@ -128,37 +124,25 @@ class ReduceGradMKLDNNKernel : public framework::OpKernel<T> {
bool keep_dim = ctx.Attr<bool>("keep_dim");
bool reduce_all = ctx.Attr<bool>("reduce_all");
auto dims = ctx.Attr<std::vector<int>>("dim");
auto* input_dy = ctx.Input<Tensor>(framework::GradVarName("Out"));
auto* output_dx = ctx.Output<Tensor>(framework::GradVarName("X"));
const auto* dout = ctx.Input<Tensor>(framework::GradVarName("Out"));
auto* dx = ctx.Output<Tensor>(framework::GradVarName("X"));
dnnl::memory::format_tag x_format_tag;
auto input_dims =
CalculateReducedDims(output_dx, input_dy, dims, reduce_all, keep_dim);
auto output_dims = phi::vectorize(output_dx->dims());
const auto input_dims =
CalculateReducedDims(dx, dout, dims, reduce_all, keep_dim);
const auto output_dims = phi::vectorize(dx->dims());
if (input_dims != output_dims) {
auto input_dy_md = dnnl::memory::desc(phi::vectorize(input_dy->dims()),
platform::MKLDNNGetDataType<T>(),
input_dy->format());
auto input_dy_ex_md = input_dy_md.reshape(input_dims);
// TODO(jczaja): once MD is stored in Tensor we no longer need to guess
// formats
x_format_tag = platform::GetMKLDNNFormat(input_dy_ex_md);
auto dout_mem_desc = dout->mem_desc();
} else {
// There was no broadcasting then just simple copy is done
// same format used for input and output
x_format_tag = getPlainFormatTag(output_dx);
if (input_dims != output_dims) {
dout_mem_desc = dout_mem_desc.reshape(input_dims);
}
output_dx->set_format(x_format_tag);
platform::BroadcastDataMKLDNNHandler<T> handler(
binary_type, onednn_engine, ctx.GetPlace(), output_dx, input_dy,
scale_x, scale_y, input_dims);
binary_type, onednn_engine, ctx.GetPlace(), dx, dout, scale_x, scale_y,
dout_mem_desc);
const auto src_memory_p = handler.AcquireSrcMemory(input_dy);
const auto dst_memory_p = handler.AcquireDstMemory(output_dx);
const auto src_memory_p = handler.AcquireSrcMemory(dout);
const auto dst_memory_p = handler.AcquireDstMemory(dx);
const auto binary_prim = handler.AcquireForwardPrimitive();
const std::unordered_map<int, dnnl::memory> args = {
......@@ -170,29 +154,7 @@ class ReduceGradMKLDNNKernel : public framework::OpKernel<T> {
binary_prim->execute(astream, args);
astream.wait();
output_dx->set_layout(framework::DataLayout::kMKLDNN);
}
protected:
dnnl::memory::format_tag getPlainFormatTag(const Tensor* tensor) const {
auto tensor_dims_size = tensor->dims().size();
PADDLE_ENFORCE_EQ(
tensor_dims_size <= 5 && tensor_dims_size >= 1, true,
platform::errors::InvalidArgument(
"Dims for reduction_grad oneDNN op must be in range <1, 5>"));
switch (tensor_dims_size) {
case 1:
return dnnl::memory::format_tag::a;
case 2:
return dnnl::memory::format_tag::ab;
case 3:
return dnnl::memory::format_tag::abc;
case 4:
return dnnl::memory::format_tag::abcd;
}
return dnnl::memory::format_tag::abcde;
dx->set_mem_desc(dst_memory_p->get_desc());
}
};
......
......@@ -84,10 +84,8 @@ class SplitOp : public framework::OperatorWithKernel {
// reorders, because if a blocked dimension is not divisible by 8 or
// 16 (depending on which blocking format is used) a submemory cannot be
// created, so in that scenario a fallback is needed
auto tmp_md = dnnl::memory::desc(
phi::vectorize(ctx.Input<Tensor>("X")->dims()),
dnnl::memory::data_type::f32, ctx.Input<Tensor>("X")->format());
if (tmp_md.data.format_desc.blocking.inner_nblks == 0)
const auto x_md = ctx.Input<Tensor>("X")->mem_desc();
if (x_md.data.format_desc.blocking.inner_nblks == 0)
return framework::OpKernelType(input_data_type, ctx.GetPlace(),
framework::DataLayout::kMKLDNN,
framework::LibraryType::kMKLDNN);
......
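
The inner_nblks test above is the entire fallback criterion; spelled out as a helper (name is illustrative, logic as in the hunk above):

```cpp
#include "dnnl.hpp"

// inner_nblks == 0 means a plain (strided, non-blocked) layout, for which
// a submemory view can always be created; blocked layouts such as nChw16c
// may not be sliceable along the blocked axis, hence the plain fallback.
inline bool IsPlainLayout(const dnnl::memory::desc& md) {
  return md.data.format_desc.blocking.inner_nblks == 0;
}
```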
......@@ -625,20 +625,12 @@ class BinaryMKLDNNHandler
platform::errors::InvalidArgument(
"Wrong layout set for X tensor. Expected: %d (kMKLDNN), Actual: %d",
DataLayout::kMKLDNN, x->layout()));
PADDLE_ENFORCE_NE(x->format(), MKLDNNMemoryFormat::undef,
platform::errors::InvalidArgument(
"Wrong format set for X tensor : %d (undef)",
static_cast<unsigned int>(x->format())));
PADDLE_ENFORCE_EQ(
y->layout(), DataLayout::kMKLDNN,
platform::errors::InvalidArgument(
"Wrong layout set for Y tensor. Expected: %d (kMKLDNN), Actual: %d",
DataLayout::kMKLDNN, y->layout()));
PADDLE_ENFORCE_NE(y->format(), MKLDNNMemoryFormat::undef,
platform::errors::InvalidArgument(
"Wrong format set for Y tensor : %d (undef)",
static_cast<unsigned int>(y->format())));
const auto src_x_tz = phi::vectorize(x->dims());
const auto src_y_tz = phi::vectorize(y->dims());
......@@ -648,10 +640,8 @@ class BinaryMKLDNNHandler
const auto dst_tz = (z == nullptr) ? (rankdiff > 0 ? src_x_tz : src_y_tz)
: phi::vectorize(z->dims());
auto src0_md = dnnl::memory::desc(
src_x_tz, platform::MKLDNNGetDataType<T>(), x->format());
auto src1_md = dnnl::memory::desc(
src_y_tz, platform::MKLDNNGetDataType<T>(), y->format());
auto src0_md = x->mem_desc();
auto src1_md = y->mem_desc();
if (rankdiff > 0) { // Second input is of smaller rank than first
std::vector<int64_t> dims1_ex(rankdiff, 1);
dims1_ex.insert(next(dims1_ex.begin(), (axis == -1 ? rankdiff : axis)),
......@@ -730,21 +720,19 @@ class BroadcastDataMKLDNNHandler
const dnnl::engine engine,
platform::Place cpu_place, const Tensor* out,
const Tensor* x, float scale_x, float scale_y,
const std::vector<int64_t>& input_dims)
const dnnl::memory::desc& x_mem_desc)
: platform::MKLDNNHandlerNoCachingT<T, dnnl::binary>(engine, cpu_place) {
PADDLE_ENFORCE_EQ(
x->layout(), DataLayout::kMKLDNN,
platform::errors::InvalidArgument("Wrong layout set for X tensor."));
PADDLE_ENFORCE_NE(
x->format(), MKLDNNMemoryFormat::undef,
platform::errors::InvalidArgument("Wrong format set for X tensor."));
const auto src0_tz = phi::vectorize(out->dims());
const auto src0_md = dnnl::memory::desc(
src0_tz, platform::MKLDNNGetDataType<T>(), out->format());
const auto src1_md = dnnl::memory::desc(
input_dims, platform::MKLDNNGetDataType<T>(), out->format());
const auto src0_md =
dnnl::memory::desc(src0_tz, platform::MKLDNNGetDataType<T>(),
platform::GetPlainMKLDNNFormat(src0_tz.size()));
const auto src1_md = x_mem_desc;
dnnl::primitive_attr attributes;
attributes.set_scales(DNNL_ARG_SRC_0, 0, {scale_x});
......@@ -777,21 +765,16 @@ class ReductionMKLDNNHandler
PADDLE_ENFORCE_EQ(
x->layout(), DataLayout::kMKLDNN,
platform::errors::InvalidArgument("Wrong layout set for X tensor."));
PADDLE_ENFORCE_NE(
x->format(), MKLDNNMemoryFormat::undef,
platform::errors::InvalidArgument("Wrong format set for X tensor."));
const auto x_tz = phi::vectorize(x->dims());
const auto x_md =
dnnl::memory::desc(x_tz, platform::MKLDNNGetDataType<T>(), x->format());
const auto y_md =
memory::desc(y_tz, platform::MKLDNNGetDataType<T>(), x->format());
const auto y_md = memory::desc(y_tz, platform::MKLDNNGetDataType<T>(),
dnnl::memory::format_tag::any);
if (attr)
this->AcquireForwardPrimitiveDescriptor(attr, algo, x_md, y_md, p, eps);
this->AcquireForwardPrimitiveDescriptor(attr, algo, x->mem_desc(), y_md,
p, eps);
else
this->AcquireForwardPrimitiveDescriptor(algo, x_md, y_md, p, eps);
this->AcquireForwardPrimitiveDescriptor(algo, x->mem_desc(), y_md, p,
eps);
}
};
......@@ -911,7 +894,7 @@ class ActivationMKLDNNHandler
ActivationMKLDNNHandler(dnnl::algorithm algorithm,
const framework::ExecutionContext& ctx,
const dnnl::engine engine, Place cpu_place,
const framework::Tensor* in_x)
const framework::Tensor* x)
: platform::MKLDNNHandlerNoCachingT<T, dnnl::eltwise_forward,
dnnl::eltwise_backward>(engine,
cpu_place) {
......@@ -946,25 +929,15 @@ class ActivationMKLDNNHandler
}
}
PADDLE_ENFORCE(in_x->dims().size() >= 1 || in_x->dims().size() <= 6,
platform::errors::Unimplemented(
"Input dimension size can be 1, 2, 3, 4, "
"5, or 6, but now the dimension size is",
in_x->dims().size()));
auto src_tz = phi::vectorize<int64_t>(in_x->dims());
auto src_fmt = src_tz.size() == 2 ? MKLDNNMemoryFormat::nc : in_x->format();
auto md =
dnnl::memory::desc(src_tz, platform::MKLDNNGetDataType<T>(), src_fmt);
this->AcquireForwardPrimitiveDescriptor(dnnl::prop_kind::forward_training,
algorithm, md, alpha, beta);
algorithm, x->mem_desc(), alpha,
beta);
}
ActivationMKLDNNHandler(dnnl::algorithm algorithm,
const framework::ExecutionContext& ctx,
const dnnl::engine engine, Place cpu_place,
const framework::Tensor* in_x, const Tensor* out_grad)
const framework::Tensor* x, const Tensor* dout)
: platform::MKLDNNHandlerNoCachingT<T, dnnl::eltwise_forward,
dnnl::eltwise_backward>(engine,
cpu_place) {
......@@ -985,23 +958,11 @@ class ActivationMKLDNNHandler
: ctx.Attr<float>("max");
}
auto diff_dst_tz = phi::vectorize<int64_t>(out_grad->dims());
auto src_fmt =
diff_dst_tz.size() == 2 ? MKLDNNMemoryFormat::nc : in_x->format();
auto diff_fmt =
diff_dst_tz.size() == 2 ? MKLDNNMemoryFormat::nc : out_grad->format();
auto dims = phi::vectorize(in_x->dims());
auto diff_dst_md = platform::MKLDNNMemDesc(
dims, platform::MKLDNNGetDataType<T>(), diff_fmt);
auto src_md = platform::MKLDNNMemDesc(
dims, platform::MKLDNNGetDataType<T>(), src_fmt);
this->AcquireForwardPrimitiveDescriptor(dnnl::prop_kind::forward_training,
algorithm, src_md, alpha, beta);
this->AcquireBackwardPrimitiveDescriptor(algorithm, diff_dst_md, src_md,
alpha, beta);
algorithm, x->mem_desc(), alpha,
beta);
this->AcquireBackwardPrimitiveDescriptor(algorithm, dout->mem_desc(),
x->mem_desc(), alpha, beta);
}
std::shared_ptr<dnnl::memory> AcquireBackwardSrcMemory(
......@@ -1036,6 +997,11 @@ class ReorderMKLDNNHandler {
dtype_dst_(dtype_dst),
engine_(engine) {}
std::shared_ptr<dnnl::memory> AcquireSrcMemory(const dnnl::memory::desc& md,
void* ptr) {
return std::make_shared<dnnl::memory>(md, engine_, ptr);
}
std::shared_ptr<dnnl::memory> AcquireSrcMemory(const MKLDNNMemoryFormat& fmt,
void* ptr) {
auto md = dnnl::memory::desc(dims_, dtype_, fmt);
......@@ -1060,6 +1026,22 @@ class ReorderMKLDNNHandler {
return std::make_shared<dnnl::memory>(dst_md, engine_, dst_data);
}
std::shared_ptr<dnnl::memory> AcquireDstMemory(
framework::Tensor* output, const dnnl::memory::desc& src_md,
platform::Place place) {
if (vtype_dst_ == vtype_) {
auto dst_data = output->mutable_data(
place, framework::TransToPhiDataType(vtype_dst_), src_md.get_size());
return std::make_shared<dnnl::memory>(src_md, engine_, dst_data);
} else {
auto dst_md = src_md;
dst_md.data.data_type = static_cast<dnnl_data_type_t>(dtype_dst_);
auto dst_data = output->mutable_data(
place, framework::TransToPhiDataType(vtype_dst_), dst_md.get_size());
return std::make_shared<dnnl::memory>(dst_md, engine_, dst_data);
}
}
std::shared_ptr<dnnl::memory> AcquireDstMemory(
framework::Tensor* output, const std::vector<int64_t>& dims,
const MKLDNNMemoryFormat& fmt, platform::Place place) {
......
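
Taken together, the new mem_desc-based overloads give every kernel in this PR the same skeleton; a condensed sketch reusing names from the diffs above (not a complete kernel, variables assumed in scope):

```cpp
platform::ReorderMKLDNNHandler reorder_handler(
    x_vec_dims, framework::TransToProtoVarType(x->dtype()), x_type,
    onednn_engine);
// Describe the source with the tensor's own stored descriptor.
auto src_mem = reorder_handler.AcquireSrcMemory(
    x->mem_desc(), platform::to_void_cast(x->data<T>()));
auto dst_mem = reorder_handler.AcquireDstMemory(out, x->mem_desc(), place);
auto reorder_p = reorder_handler.AcquireReorder(dst_mem, src_mem);
reorder_p->execute(astream, *src_mem, *dst_mem);
astream.wait();
// Store the full descriptor back instead of a guessed format tag.
out->set_mem_desc(dst_mem->get_desc());
```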
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
// NOTE:
// The GetMKLDNNFormat function is here temporarily. It is needed because
// without it a forward declaration was causing an error when building
// with "-DWITH_TESTING=ON". This file will be deleted once the md-related
// refactoring is complete
namespace paddle {
namespace platform {
inline dnnl::memory::format_tag GetMKLDNNFormat(dnnl::memory::desc mem_desc) {
auto ndims = mem_desc.data.ndims;
auto strides = mem_desc.data.format_desc.blocking.strides;
auto inner_nblks = mem_desc.data.format_desc.blocking.inner_nblks;
auto inner_blks = mem_desc.data.format_desc.blocking.inner_blks;
auto inner_idxs = mem_desc.data.format_desc.blocking.inner_idxs;
if (ndims == 1) {
return dnnl::memory::format_tag::x;
} else if (ndims == 2) {
if (inner_nblks == 0) {
if (strides[0] >= strides[1]) {
return dnnl::memory::format_tag::nc;
} else {
return dnnl::memory::format_tag::cn;
}
}
} else if (ndims == 3) {
if (inner_nblks == 0) {
if (strides[0] >= strides[1] && strides[1] >= strides[2]) {
return dnnl::memory::format_tag::ncw;
} else if (strides[1] >= strides[0] && strides[0] >= strides[2]) {
return dnnl::memory::format_tag::ntc;
} else {
return dnnl::memory::format_tag::nwc;
}
}
} else if (ndims == 4) {
if (inner_nblks == 0) {
if (strides[0] >= strides[1] && strides[1] >= strides[2] &&
strides[2] >= strides[3]) {
return dnnl::memory::format_tag::nchw;
} else if (strides[2] >= strides[3] && strides[3] >= strides[1] &&
strides[1] >= strides[0]) {
return dnnl::memory::format_tag::cdba;
} else if (strides[3] >= strides[2] && strides[2] >= strides[0] &&
strides[0] >= strides[1]) {
return dnnl::memory::format_tag::dcab;
} else {
return dnnl::memory::format_tag::nhwc;
}
} else if (inner_nblks == 1) {
if (inner_blks[0] == 16 && inner_idxs[0] == 1) {
return dnnl::memory::format_tag::nChw16c;
} else if (inner_blks[0] == 8 && inner_idxs[0] == 1) {
return dnnl::memory::format_tag::nChw8c;
} else if (inner_blks[0] == 8 && inner_idxs[0] == 0) {
if (strides[0] >= strides[2] && strides[2] >= strides[3] &&
strides[3] >= strides[1]) {
return dnnl::memory::format_tag::Acdb8a;
}
} else if (inner_blks[0] == 4 && inner_idxs[0] == 1) {
return dnnl::memory::format_tag::nChw4c;
} else if (inner_blks[0] == 16 && inner_idxs[0] == 0) {
if (strides[0] >= strides[2] && strides[2] >= strides[3] &&
strides[3] >= strides[1]) {
return dnnl::memory::format_tag::Acdb16a;
}
}
} else if (inner_nblks == 2) {
if (inner_blks[0] == 16 && inner_blks[1] == 16) {
if (inner_idxs[0] == 1 && inner_idxs[1] == 0) {
return dnnl::memory::format_tag::OIhw16i16o;
}
} else if (inner_blks[0] == 8 && inner_blks[1] == 8) {
if (inner_idxs[0] == 1 && inner_idxs[1] == 0) {
return dnnl::memory::format_tag::OIhw8i8o;
}
}
}
} else if (ndims == 5) {
if (inner_nblks == 0) {
if (strides[0] >= strides[1] && strides[1] >= strides[2] &&
strides[2] >= strides[3] && strides[3] >= strides[4]) {
return dnnl::memory::format_tag::abcde;
} else if (strides[0] >= strides[2] && strides[2] >= strides[1] &&
strides[1] >= strides[3] && strides[3] >= strides[4]) {
return dnnl::memory::format_tag::acbde;
} else if (strides[0] >= strides[2] && strides[2] >= strides[3] &&
strides[3] >= strides[4] && strides[4] >= strides[1]) {
return dnnl::memory::format_tag::acdeb;
}
} else if (inner_nblks == 1) {
if (inner_blks[0] == 8 && inner_idxs[0] == 0) {
if (strides[0] >= strides[2] && strides[2] >= strides[3] &&
strides[3] >= strides[4] && strides[4] >= strides[1]) {
return dnnl::memory::format_tag::Acdeb8a;
}
if (strides[0] >= strides[1] && strides[1] >= strides[2] &&
strides[2] >= strides[3] && strides[3] >= strides[4]) {
return dnnl::memory::format_tag::Abcde8a;
}
} else if (inner_blks[0] == 8 && inner_idxs[0] == 1) {
if (strides[0] >= strides[1] && strides[1] >= strides[2] &&
strides[2] >= strides[3] && strides[3] >= strides[4]) {
return dnnl::memory::format_tag::aBcde8b;
}
} else if (inner_blks[0] == 16 && inner_idxs[0] == 0) {
if (strides[0] >= strides[2] && strides[2] >= strides[3] &&
strides[3] >= strides[4] && strides[4] >= strides[1]) {
return dnnl::memory::format_tag::Acdeb16a;
}
if (strides[0] >= strides[1] && strides[1] >= strides[2] &&
strides[2] >= strides[3] && strides[3] >= strides[4]) {
return dnnl::memory::format_tag::Abcde16a;
}
} else if (inner_blks[0] == 16 && inner_idxs[0] == 1) {
if (strides[0] >= strides[1] && strides[1] >= strides[2] &&
strides[2] >= strides[3] && strides[3] >= strides[4]) {
return dnnl::memory::format_tag::aBcde16b;
}
}
}
} else if (ndims == 6) {
if (inner_nblks == 0) {
if (strides[0] >= strides[1] && strides[1] >= strides[2] &&
strides[2] >= strides[3] && strides[3] >= strides[4] &&
strides[4] >= strides[5]) {
return dnnl::memory::format_tag::abcdef;
} else if (strides[0] >= strides[2] && strides[2] >= strides[1] &&
strides[1] >= strides[3] && strides[3] >= strides[4] &&
strides[4] >= strides[5]) {
return dnnl::memory::format_tag::acbdef;
}
}
}
// DEBUG CODE - KEEP UNTIL TENSOR.MEMORY_DESC IMPLEMENTED
// std::cout<<"@@@@@@@@@@ UNDEFINED FORMAT @@@@@@@@@@@@@@@@@@@"<<std::endl;
// std::cout<<"NDIMS: "<<ndims<<std::endl;
// std::cout<<"INNER_NBLKS: "<<inner_nblks<<std::endl;
// for (int i=0;i<ndims;++i) {
// std::cout<<"STRIDE["<<i<<"]: "<<strides[i]<<std::endl;
// }
// for (int i=0;i<inner_nblks;++i) {
// std::cout<<"INNER_BLKS["<<i<<"]: "<<inner_blks[i]<<std::endl;
// }
// for (int i=0;i<inner_nblks;++i) {
// std::cout<<"INNER_IDXS["<<i<<"]: "<<inner_idxs[i]<<std::endl;
// }
return dnnl::memory::format_tag::undef;
}
} // namespace platform
} // namespace paddle
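
A quick round trip through the helper above (illustrative dims; the include path is the one added later in this diff): NHWC strides for dims {2, 3, 4, 5} are {60, 1, 15, 3}, which the stride comparisons map back to nhwc.

```cpp
#include "dnnl.hpp"
#include "paddle/fluid/platform/mkldnn_utils.h"

// Build a descriptor from a known tag, then recover the tag from its
// strides/blocking via the helper defined above.
dnnl::memory::desc md({2, 3, 4, 5}, dnnl::memory::data_type::f32,
                      dnnl::memory::format_tag::nhwc);
auto tag = paddle::platform::GetMKLDNNFormat(md);  // == format_tag::nhwc
```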
......@@ -16,7 +16,7 @@ cc_library(tensor_base SRCS tensor_base.cc allocator.cc DEPS phi_enforce)
cc_library(tensor_meta SRCS tensor_meta.cc DEPS phi_enforce)
cc_library(lod_utils SRCS lod_utils.cc DEPS phi_enforce)
cc_library(dense_tensor SRCS dense_tensor.cc dense_tensor_impl.cc DEPS fluid_convert_utils tensor_meta tensor_base)
cc_library(dense_tensor SRCS dense_tensor.cc dense_tensor_impl.cc DEPS convert_utils fluid_convert_utils tensor_meta tensor_base)
cc_library(sparse_coo_tensor SRCS sparse_coo_tensor.cc DEPS tensor_meta tensor_base)
cc_library(sparse_csr_tensor SRCS sparse_csr_tensor.cc DEPS dense_tensor tensor_base)
cc_library(string_tensor SRCS string_tensor.cc DEPS convert_utils tensor_meta tensor_base)
......
......@@ -121,4 +121,24 @@ const std::string& TransToFluidOpName(const std::string& phi_kernel_name) {
return phi_kernel_name;
}
#ifdef PADDLE_WITH_MKLDNN
dnnl::memory::data_type TransToMKLDNNDataType(
const paddle::experimental::DataType& dtype) {
switch (dtype) {
case DataType::FLOAT32:
return dnnl::memory::data_type::f32;
case DataType::BFLOAT16:
return dnnl::memory::data_type::bf16;
case DataType::INT8:
return dnnl::memory::data_type::s8;
case DataType::UINT8:
return dnnl::memory::data_type::u8;
case DataType::INT32:
return dnnl::memory::data_type::s32;
default:
return dnnl::memory::data_type::undef;
}
}
#endif
} // namespace phi
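
Usage is a one-liner; unsupported types (e.g. FLOAT64) fall through to undef, which callers can treat as "not expressible in oneDNN":

```cpp
// Sketch: pick the oneDNN data type for a phi DataType.
auto dt = phi::TransToMKLDNNDataType(paddle::experimental::DataType::BFLOAT16);
// dt == dnnl::memory::data_type::bf16
```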
......@@ -20,6 +20,10 @@ limitations under the License. */
#include "paddle/phi/common/place.h"
#include "paddle/phi/core/tensor_meta.h"
#ifdef PADDLE_WITH_MKLDNN
#include "dnnl.hpp"
#endif
namespace phi {
const std::string& TransToPhiKernelName(const std::string& fluid_op_name);
......@@ -28,4 +32,9 @@ const std::string& TransToFluidOpName(const std::string& phi_kernel_name);
Backend TransToPhiBackend(const phi::Place& place);
phi::Place TransToPhiPlace(const Backend& backend, bool set_device_id = true);
#ifdef PADDLE_WITH_MKLDNN
dnnl::memory::data_type TransToMKLDNNDataType(
const paddle::experimental::DataType& dtype);
#endif
} // namespace phi
......@@ -57,6 +57,7 @@ DenseTensor::DenseTensor(const DenseTensor& other) : meta_(other.meta()) {
#ifdef PADDLE_WITH_MKLDNN
format_ = other.format_;
mem_desc_ = other.mem_desc_;
#endif
}
......@@ -66,6 +67,7 @@ DenseTensor& DenseTensor::operator=(const DenseTensor& other) {
inplace_version_counter_ = other.inplace_version_counter_;
#ifdef PADDLE_WITH_MKLDNN
format_ = other.format_;
mem_desc_ = other.mem_desc_;
#endif
return *this;
}
......@@ -74,6 +76,10 @@ DenseTensor& DenseTensor::operator=(DenseTensor&& other) {
meta_ = std::move(other.meta_);
std::swap(holder_, other.holder_);
std::swap(inplace_version_counter_, other.inplace_version_counter_);
#ifdef PADDLE_WITH_MKLDNN
format_ = other.format_;
mem_desc_ = other.mem_desc_;
#endif
return *this;
}
......
......@@ -207,6 +207,9 @@ following codes there.
* this field.
*/
dnnl::memory::format_tag format_ = dnnl::memory::format_tag::undef;
/// \brief memory descriptor of tensor which have layout set as kMKLDNN
dnnl::memory::desc mem_desc_;
#endif
#ifndef PADDLE_WITH_CUSTOM_KERNEL
......
......@@ -20,6 +20,7 @@ limitations under the License. */
Will be adjusted/removed/moved in the near future
*/
public:
/* @jim19930609: Remove dependency on protobuf after Tensor Unification.
*/
......@@ -127,7 +128,14 @@ following codes there.
#ifdef PADDLE_WITH_MKLDNN
public:
inline dnnl::memory::format_tag format() const { return format_; }
dnnl::memory::desc mem_desc() const;
inline void set_mem_desc(const dnnl::memory::desc& mem_desc) {
mem_desc_ = mem_desc;
meta_.layout = DataLayout::kMKLDNN;
}
dnnl::memory::format_tag format() const;
inline void set_format(const dnnl::memory::format_tag format) {
format_ = format;
......
......@@ -21,6 +21,10 @@ limitations under the License. */
#include "paddle/phi/api/lib/utils/storage.h"
#include "paddle/phi/core/compat/convert_utils.h"
#ifdef PADDLE_WITH_MKLDNN
#include "paddle/fluid/platform/mkldnn_utils.h"
#endif
namespace phi {
/* --------------------------- */
/* From framework::Tensor */
......@@ -354,6 +358,19 @@ std::vector<DenseTensor> DenseTensor::Chunk(int64_t chunks,
return Split(split_size, axis);
}
#ifdef PADDLE_WITH_MKLDNN
dnnl::memory::desc DenseTensor::mem_desc() const {
return mem_desc_ ? mem_desc_
: dnnl::memory::desc(phi::vectorize(meta_.dims),
phi::TransToMKLDNNDataType(meta_.dtype),
format_);
}
dnnl::memory::format_tag DenseTensor::format() const {
return mem_desc_ ? paddle::platform::GetMKLDNNFormat(mem_desc_) : format_;
}
#endif
DenseTensor& DenseTensor::ShareDataWith(const DenseTensor& src) {
src.check_memory_size();
// Preserve LoD
......
......@@ -53,7 +53,7 @@ class TestMkldnnShapeOp(MkldnnAutoScanTest):
@given(
in_shape=st.lists(
st.integers(
min_value=1, max_value=3), min_size=1, max_size=9),
min_value=1, max_value=3), min_size=1, max_size=6),
in_dtype=st.sampled_from([np.float32, np.uint16, np.int8, np.uint8]))
def test(self, *args, **kwargs):
self.run_test(quant=False, *args, **kwargs)
......