提交 fe581b0e · 作者: Adam · 提交者: Tao Luo

Minor GetMKLDNNFormat changes (#20055)

test=develop
上级 54e07994
......@@ -167,9 +167,7 @@ class EltwiseAddMKLDNNKernel : public framework::OpKernel<T> {
stream(stream::kind::eager).submit(pipeline).wait();
z->set_layout(DataLayout::kMKLDNN);
z->set_format((MKLDNNMemoryFormat)dst_memory->get_primitive_desc()
.desc()
.data.format);
z->set_format(platform::GetMKLDNNFormat(*dst_memory));
}
}
};
......
......@@ -47,11 +47,6 @@ static memory::primitive_desc CreateMemPrimDesc(const Tensor& input,
return mem_prim_desc;
}
// Extracts the destination memory format recorded in a concat primitive
// descriptor. Removed by this commit: the caller now uses the shared
// platform::GetMKLDNNFormat(*dst_mem) helper instead (see the replacement
// line in the ConcatMKLDNNOpKernel hunk below).
static MKLDNNMemoryFormat GetDstMemFormat(
const concat::primitive_desc& concat_pd) {
// data.format is the raw mkldnn C enum; the C-style cast converts it to the
// framework's MKLDNNMemoryFormat alias.
return (MKLDNNMemoryFormat)concat_pd.dst_primitive_desc().desc().data.format;
}
static platform::CPUPlace GetCpuPlace(
const paddle::framework::ExecutionContext& ctx) {
auto place = ctx.GetPlace();
......@@ -198,7 +193,7 @@ class ConcatMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
stream(stream::kind::eager).submit({*concat_p}).wait();
output->set_layout(DataLayout::kMKLDNN);
output->set_format(GetDstMemFormat(*concat_pd));
output->set_format(platform::GetMKLDNNFormat(*dst_mem));
}
};
} // namespace operators
......
......@@ -71,7 +71,7 @@ class FCPrimitiveFactory {
input_->set_data_handle(const_cast<T*>(in->data<T>()));
output_->set_data_handle(out->mutable_data<T>(ctx.GetPlace()));
if (out->format() == MKLDNNMemoryFormat::format_undef) {
auto output_format = output_->get_primitive_desc().desc().data.format;
auto output_format = platform::GetMKLDNNFormat(*output_);
out->set_format((MKLDNNMemoryFormat)output_format);
}
}
......@@ -199,8 +199,9 @@ class FCPrimitiveFactory {
auto dst_prim_desc = fc_prim_desc.dst_primitive_desc();
auto buffer_size = dst_prim_desc.get_size();
T* output_data = output->mutable_data<T>(ctx.GetPlace(), buffer_size);
output->set_format((MKLDNNMemoryFormat)dst_prim_desc.desc().data.format);
return memory(dst_prim_desc, to_void_cast<T>(output_data));
memory dst_mem(dst_prim_desc, to_void_cast<T>(output_data));
output->set_format(platform::GetMKLDNNFormat(dst_mem));
return dst_mem;
}
void RecomputeOutputDims(const ExecutionContext& ctx, const LoDTensor* input,
......
......@@ -77,13 +77,8 @@ class LRNMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
std::vector<mkldnn::primitive> pipeline = {*lrn_p};
mkldnn::stream(mkldnn::stream::kind::eager).submit(pipeline).wait();
auto output_format =
(mkldnn::memory::format)dst_memory->get_primitive_desc()
.desc()
.data.format;
out->set_layout(framework::DataLayout::kMKLDNN);
out->set_format(output_format);
out->set_format(platform::GetMKLDNNFormat(*dst_memory));
}
};
......@@ -129,13 +124,8 @@ class LRNMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
std::vector<mkldnn::primitive> pipeline = {*lrn_bwd};
mkldnn::stream(mkldnn::stream::kind::eager).submit(pipeline).wait();
auto output_format =
(mkldnn::memory::format)diff_src_memory->get_primitive_desc()
.desc()
.data.format;
x_grad->set_layout(framework::DataLayout::kMKLDNN);
x_grad->set_format(output_format);
x_grad->set_format(platform::GetMKLDNNFormat(*diff_src_memory));
}
};
} // namespace operators
......
......@@ -107,7 +107,7 @@ class MulPrimitiveFactory {
output_->set_data_handle(out->mutable_data<OT>(ctx.GetPlace()));
if (out->format() == MKLDNNMemoryFormat::format_undef) {
auto output_format = output_->get_primitive_desc().desc().data.format;
auto output_format = platform::GetMKLDNNFormat(*output_);
out->set_format((MKLDNNMemoryFormat)output_format);
}
}
......@@ -139,8 +139,9 @@ class MulPrimitiveFactory {
auto buffer_size = dst_prim_desc.get_size();
OT *output_data = output->mutable_data<OT>(ctx.GetPlace(), buffer_size);
output->set_format((MKLDNNMemoryFormat)dst_prim_desc.desc().data.format);
return memory(dst_prim_desc, to_void_cast<OT>(output_data));
memory dst_mem(dst_prim_desc, to_void_cast<OT>(output_data));
output->set_format(platform::GetMKLDNNFormat(dst_mem));
return dst_mem;
}
memory Reorder(const memory::desc &src_desc, const memory::desc &dst_desc,
......
......@@ -95,11 +95,8 @@ class PoolMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
std::vector<mkldnn::primitive> pipeline{*pool_p};
stream(stream::kind::eager).submit(pipeline).wait();
auto output_format =
(MKLDNNMemoryFormat)dst_memory->get_primitive_desc().desc().data.format;
output->set_layout(DataLayout::kMKLDNN);
output->set_format(output_format);
output->set_format(platform::GetMKLDNNFormat(*dst_memory));
}
};
......@@ -179,12 +176,8 @@ class PoolMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
pipeline.push_back(*pool_bwd_p);
mkldnn::stream(mkldnn::stream::kind::eager).submit(pipeline).wait();
auto in_x_grad_format =
(MKLDNNMemoryFormat)diff_src_memory->get_primitive_desc()
.desc()
.data.format;
in_x_grad->set_layout(DataLayout::kMKLDNN);
in_x_grad->set_format(in_x_grad_format);
in_x_grad->set_format(platform::GetMKLDNNFormat(*diff_src_memory));
} // Compute()
};
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册