Commit f1cebfef authored by Jacek Czaja

- more fixes

Parent d7652d5f
@@ -1205,8 +1205,8 @@ std::ostream& operator<<(std::ostream& os, const phi::DenseTensor& t) {
      << "\n";
 #ifdef PADDLE_WITH_MKLDNN
-  os << " - format: "
-     << dnnl_fmt_tag2str(static_cast<dnnl_format_tag_t>(t.format())) << "\n";
+  os << " - memory desc: "
+     << (t.mem_desc()) << "\n";
 #endif
   DenseTensor tensor;
...
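For context: the printout now dumps the tensor's full memory descriptor instead of a bare format tag. A minimal standalone sketch, assuming the oneDNN 2.x C++ API (illustrative code, not part of this commit), of why a single tag string under-describes a layout while a dnnl::memory::desc carries the full blocking information:

#include <iostream>
#include "dnnl.hpp"

int main() {
  using dt = dnnl::memory::data_type;
  using tag = dnnl::memory::format_tag;
  dnnl::memory::dims dims = {1, 32, 8, 8};  // NCHW-shaped tensor

  dnnl::memory::desc plain_md(dims, dt::f32, tag::nchw);       // plain layout
  dnnl::memory::desc blocked_md(dims, dt::f32, tag::nChw16c);  // blocked layout

  // Same dims and data type, but different physical arrangement; the
  // blocking details live in the descriptor, not in any single tag:
  std::cout << "plain inner_nblks:   "
            << plain_md.data.format_desc.blocking.inner_nblks << "\n";   // 0
  std::cout << "blocked inner_nblks: "
            << blocked_md.data.format_desc.blocking.inner_nblks << "\n"; // 1
  return 0;
}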
@@ -203,15 +203,15 @@ class MulPrimitiveFactory {
                           const ExecutionContext &ctx) {
     Tensor x_tmp;
     Tensor data_matrix;
-    MKLDNNMemoryFormat src_fmt = data->format();
     MKLDNNMemoryFormat dst_fmt;
-    auto src_mdesc = CreateMemDescriptor<T>(data, src_fmt);
-
-    if ((data->dims().size() == 4 &&
-         src_fmt != (dst_fmt = MKLDNNMemoryFormat::nchw)) ||
-        (data->dims().size() == 5 &&
-         src_fmt != (dst_fmt = MKLDNNMemoryFormat::ncdhw))) {
-      auto dst_mdesc = CreateMemDescriptor<T>(data, dst_fmt);
+    // This code is enforcing plain (non-blocked) memory arrangement
+    // in order to flatten (reduce dimensionality) of Tensor later
+    auto src_mdesc = data->mem_desc();
+    auto dst_mdesc = data->dims().size() >= 4 ? (data->dims().size() == 5 ?
+                     CreateMemDescriptor(data, MKLDNNMemoryFormat::ncdhw) :
+                     CreateMemDescriptor(data, MKLDNNMemoryFormat::nchw)) : src_mdesc;
+
+    if (src_mdesc != dst_mdesc) {
       x_tmp.mutable_data<T>(ctx.GetPlace(), data->memory_size());
       Reorder(src_mdesc,
...
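A minimal sketch of the "reorder to a plain layout before flattening" idea used here, assuming the oneDNN 2.x C++ API (names are illustrative, not Paddle's helpers): a blocked 4-D tensor is reordered into nchw so its buffer becomes contiguous in logical order and can then be viewed as 2-D.

#include "dnnl.hpp"

int main() {
  dnnl::engine eng(dnnl::engine::kind::cpu, 0);
  dnnl::stream s(eng);
  using dt = dnnl::memory::data_type;
  using tag = dnnl::memory::format_tag;

  dnnl::memory::dims dims = {2, 32, 4, 4};
  dnnl::memory::desc src_md(dims, dt::f32, tag::nChw16c);  // blocked source
  dnnl::memory::desc dst_md(dims, dt::f32, tag::nchw);     // plain target

  dnnl::memory src_mem(src_md, eng);
  dnnl::memory dst_mem(dst_md, eng);

  if (src_md != dst_md) {  // mirrors the src_mdesc != dst_mdesc check above
    dnnl::reorder(src_mem, dst_mem).execute(s, src_mem, dst_mem);
    s.wait();
  }
  // dst_mem now holds plain nchw data that can be flattened to {2, 32 * 4 * 4}.
  return 0;
}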
@@ -85,15 +85,13 @@ class ReQuantOpKernel : public framework::OpKernel<T> {
     const T* input_data = input->data<T>();

     if (reorder_p == nullptr) {
-      auto dst_tz = phi::vectorize(output->dims());
       auto src_dt = framework::ToMKLDNNDataType(
           framework::TransToProtoVarType(input->dtype()));
       auto dst_dt = with_shift ? framework::MKLDNNDataType::u8 : src_dt;

-      auto src_md = platform::MKLDNNMemDesc({src_tz}, src_dt, input->format());
       src_memory = std::make_shared<dnnl::memory>(
-          src_md, engine, to_void_cast<T>(input_data));
-      auto dst_md = platform::MKLDNNMemDesc({dst_tz}, dst_dt, input->format());
+          input->mem_desc(), engine, to_void_cast<T>(input_data));
+      auto dst_md = platform::MKLDNNMemDesc(
+          {src_tz}, dst_dt, input->mem_desc().strides);

       dnnl::primitive_attr attri;
       int mask = 0;
...
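The intent of the new dst_md is to keep the input's physical layout while switching only the data type, so requantization never reshuffles data. A sketch of that pattern with the raw oneDNN 2.x C++ API (illustrative; this is not Paddle's platform::MKLDNNMemDesc helper, and it reads strides from the C-level blocking struct rather than a .strides member):

#include "dnnl.hpp"

int main() {
  using dt = dnnl::memory::data_type;
  using tag = dnnl::memory::format_tag;

  dnnl::memory::dims dims = {2, 3, 4, 5};
  dnnl::memory::desc src_md(dims, dt::s8, tag::nhwc);  // strided s8 source

  // Read the source strides out of the blocking descriptor...
  const auto &blk = src_md.data.format_desc.blocking;
  dnnl::memory::dims strides(blk.strides, blk.strides + src_md.data.ndims);

  // ...and reuse them for a u8 destination: same layout, new data type.
  dnnl::memory::desc dst_md(dims, dt::u8, strides);
  return 0;
}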
@@ -83,7 +83,7 @@ class ReshapeMKLDNNKernel : public framework::OpKernel<T> {
         onednn_engine);

     auto reorder_src_memory_p = reorder_handler.AcquireSrcMemory(
-        x->format(), platform::to_void_cast(x->data<T>()));
+        x->mem_desc(), platform::to_void_cast(x->data<T>()));
     out->Resize(x_dims);  // to match x numel, format is changed later
     // reorder is done into a plain tag to allow usage with blocked formats
     auto reorder_dst_memory_p = reorder_handler.AcquireDstMemory(
@@ -356,7 +356,7 @@ class ReshapeGradMKLDNNKernel : public ReshapeMKLDNNKernel<T, op_name> {
         onednn_engine);

     auto reorder_src_memory_p = reorder_handler.AcquireSrcMemory(
-        dout->format(), platform::to_void_cast(dout->data<T>()));
+        dout->mem_desc(), platform::to_void_cast(dout->data<T>()));
     auto reorder_dst_memory_p = reorder_handler.AcquireDstMemory(
         dx, this->getPlainFormatTag(dout), ctx.GetPlace());
     auto reorder_p = reorder_handler.AcquireReorder(reorder_dst_memory_p,
...
@@ -161,11 +161,7 @@ class SliceOp : public framework::OperatorWithKernel {
       // reorders, because if blocked dimension is not divisible by 8 or
       // 16(depending on which blocking format is used) submemory cannot be
       // created, so in that scenario a fallback is needed
-      auto tmp_md = dnnl::memory::desc(
-          phi::vectorize(ctx.Input<phi::DenseTensor>("Input")->dims()),
-          dnnl::memory::data_type::f32,
-          ctx.Input<phi::DenseTensor>("Input")->format());
-      if (tmp_md.data.format_desc.blocking.inner_nblks == 0)
+      if (ctx.Input<phi::DenseTensor>("Input")->mem_desc().data.format_desc.blocking.inner_nblks == 0)
         return framework::OpKernelType(input_data_type,
                                        ctx.GetPlace(),
                                        framework::DataLayout::kMKLDNN,
@@ -336,13 +332,7 @@ class SliceOpGrad : public framework::OperatorWithKernel {
       // reorders, because if blocked dimension is not divisible by 8 or
       // 16(depending on which blocking format is used) submemory cannot be
       // created, so in that scenario a fallback is needed
-      auto tmp_md = dnnl::memory::desc(
-          phi::vectorize(
-              ctx.Input<phi::DenseTensor>(framework::GradVarName("Out"))
-                  ->dims()),
-          dnnl::memory::data_type::f32,
-          ctx.Input<phi::DenseTensor>(framework::GradVarName("Out"))->format());
-      if (tmp_md.data.format_desc.blocking.inner_nblks == 0)
+      if (ctx.Input<phi::DenseTensor>(framework::GradVarName("Out"))->mem_desc().data.format_desc.blocking.inner_nblks == 0)
         return framework::OpKernelType(input_data_type,
                                        ctx.GetPlace(),
                                        framework::DataLayout::kMKLDNN,
...
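Both slice hunks now ask the tensor's own descriptor whether its layout is blocked, instead of rebuilding a temporary f32 descriptor from dims plus a format tag. A sketch of that check, assuming the oneDNN 2.x C++ API (the helper name here is illustrative): inner_nblks == 0 means a plain, non-blocked layout, for which a submemory can always be created.

#include "dnnl.hpp"

static bool IsPlainLayout(const dnnl::memory::desc &md) {
  return md.data.format_desc.blocking.inner_nblks == 0;
}

int main() {
  using dt = dnnl::memory::data_type;
  using tag = dnnl::memory::format_tag;
  dnnl::memory::dims dims = {1, 20, 8, 8};  // channels not divisible by 16

  dnnl::memory::desc plain(dims, dt::f32, tag::nchw);
  dnnl::memory::desc blocked(dims, dt::f32, tag::nChw16c);

  bool use_mkldnn_slice = IsPlainLayout(plain);    // true: fast path is safe
  bool fallback_needed = !IsPlainLayout(blocked);  // true: slicing a blocked
                                                   // dim may fail, fall back
  (void)use_mkldnn_slice;
  (void)fallback_needed;
  return 0;
}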
@@ -56,7 +56,6 @@ DenseTensor::DenseTensor(const DenseTensor& other) : meta_(other.meta()) {
   inplace_version_counter_ = other.inplace_version_counter_;
 #ifdef PADDLE_WITH_MKLDNN
-  format_ = other.format_;
   mem_desc_ = other.mem_desc_;
 #endif
 }
@@ -66,7 +65,6 @@ DenseTensor& DenseTensor::operator=(const DenseTensor& other) {
   holder_ = other.holder_;
   inplace_version_counter_ = other.inplace_version_counter_;
 #ifdef PADDLE_WITH_MKLDNN
-  format_ = other.format_;
   mem_desc_ = other.mem_desc_;
 #endif
   return *this;
@@ -77,7 +75,6 @@ DenseTensor& DenseTensor::operator=(DenseTensor&& other) {
   std::swap(holder_, other.holder_);
   std::swap(inplace_version_counter_, other.inplace_version_counter_);
 #ifdef PADDLE_WITH_MKLDNN
-  format_ = other.format_;
   mem_desc_ = other.mem_desc_;
 #endif
   return *this;
...
@@ -227,16 +227,6 @@ In the final state, we should come up with a MKLDNN_Tensor and move the
 following codes there.
 */
 #ifdef PADDLE_WITH_MKLDNN
-  /**
-   * @brief the detail format of memory block which have layout as kMKLDNN
-   *
-   * @note MKLDNN lib support various memory format like nchw, nhwc, nChw8C,
-   *       nChw16c, etc. For a MKLDNN memory block, layout will be set as
-   *       DataLayout::kMKLDNN meanwhile detail memory format will be kept in
-   *       this field.
-   */
-  dnnl::memory::format_tag format_ = dnnl::memory::format_tag::undef;
-
   /// \brief memory descriptor of tensor which have layout set as kMKLDNN
   dnnl::memory::desc mem_desc_;
 #endif
...
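The removed format_ tag is redundant once a full mem_desc_ is stored: dims, data type, size, and blocking are all recoverable from the descriptor alone. A small sketch of that, assuming the oneDNN 2.x C++ API (illustrative, not Paddle code):

#include <cassert>
#include "dnnl.hpp"

int main() {
  using dt = dnnl::memory::data_type;
  using tag = dnnl::memory::format_tag;

  // One descriptor fully captures what format_ + dims used to describe.
  dnnl::memory::desc md({8, 16, 32, 32}, dt::f32, tag::nChw16c);

  assert(md.dims() == (dnnl::memory::dims{8, 16, 32, 32}));
  assert(md.data_type() == dt::f32);
  assert(md.data.format_desc.blocking.inner_nblks == 1);      // the 16c block
  assert(md.get_size() >= 8 * 16 * 32 * 32 * sizeof(float));  // padded size
  return 0;
}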
@@ -365,7 +365,6 @@ DenseTensor& DenseTensor::ShareDataWith(const DenseTensor& src) {
   meta_.layout = src.meta_.layout;
   meta_.offset = src.meta_.offset;
 #ifdef PADDLE_WITH_MKLDNN
-  format_ = src.format_;
   mem_desc_ = src.mem_desc_;
 #endif
   return *this;
...
@@ -49,7 +49,7 @@ void SplitKernel(const Context& dev_ctx,
         out_vec_dims, offset, reorder_src_memory_p);

     auto reorder_dst_memory_p = reorder_handler.AcquireDstMemory(
-        out[i], out_vec_dims, x.format(), dev_ctx.GetPlace());
+        out[i], slice_mem_p->get_desc(), dev_ctx.GetPlace());
     auto reorder_p =
         reorder_handler.AcquireReorder(reorder_dst_memory_p, slice_mem_p);
...
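Here each split output is described by the slice submemory's own descriptor (slice_mem_p->get_desc()) rather than by dims plus the source's format tag. A standalone sketch of split-via-submemory with the oneDNN 2.x C++ API (illustrative only; not Paddle's ReorderMKLDNNHandler):

#include "dnnl.hpp"

int main() {
  dnnl::engine eng(dnnl::engine::kind::cpu, 0);
  dnnl::stream s(eng);
  using dt = dnnl::memory::data_type;
  using tag = dnnl::memory::format_tag;

  // Split a {4, 6} tensor into two {4, 3} halves along axis 1.
  dnnl::memory::desc src_md({4, 6}, dt::f32, tag::nc);
  dnnl::memory src_mem(src_md, eng);

  for (int i = 0; i < 2; ++i) {
    dnnl::memory::dims offsets = {0, 3 * i};
    // Submemory descriptor: same buffer, offset view with parent strides.
    auto slice_md = src_md.submemory_desc({4, 3}, offsets);
    dnnl::memory slice_mem(slice_md, eng, src_mem.get_data_handle());

    // A dense destination; the kernel above instead hands
    // slice_mem_p->get_desc() to Paddle's reorder handler.
    dnnl::memory dst_mem(dnnl::memory::desc({4, 3}, dt::f32, tag::nc), eng);
    dnnl::reorder(slice_mem, dst_mem).execute(s, slice_mem, dst_mem);
  }
  s.wait();
  return 0;
}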