未验证 提交 ba90e052 编写于 作者: T Tao Luo 提交者: GitHub

Merge pull request #15917 from jczaja/prv-tensor-mkldnn-ops

[MKL-DNN] Adjusting ops to Tensor modifications
...@@ -77,8 +77,7 @@ class EltwiseAddMKLDNNKernel : public framework::OpKernel<T> { ...@@ -77,8 +77,7 @@ class EltwiseAddMKLDNNKernel : public framework::OpKernel<T> {
} else { } else {
functor.RunMidWise(n, pre, post); functor.RunMidWise(n, pre, post);
} }
z->set_layout(DataLayout::kMKLDNN); z->set_mkldnn_prim_desc(x->get_mkldnn_prim_desc());
z->set_format(x->format());
} else { } else {
PADDLE_ENFORCE(x->layout() == DataLayout::kMKLDNN && PADDLE_ENFORCE(x->layout() == DataLayout::kMKLDNN &&
x->format() != memory::format::format_undef, x->format() != memory::format::format_undef,
...@@ -116,7 +115,8 @@ class EltwiseAddMKLDNNKernel : public framework::OpKernel<T> { ...@@ -116,7 +115,8 @@ class EltwiseAddMKLDNNKernel : public framework::OpKernel<T> {
auto sum_pd = sum::primitive_desc(dst_md, scales, srcs_pd); auto sum_pd = sum::primitive_desc(dst_md, scales, srcs_pd);
// create mkldnn memory for dst // create mkldnn memory for dst
memory dst_memory = memory(sum_pd.dst_primitive_desc(), z_data); auto dst_mem_pd = sum_pd.dst_primitive_desc();
memory dst_memory = memory(dst_mem_pd, z_data);
std::vector<primitive::at> inputs; std::vector<primitive::at> inputs;
inputs.push_back(srcs[0]); inputs.push_back(srcs[0]);
...@@ -129,9 +129,7 @@ class EltwiseAddMKLDNNKernel : public framework::OpKernel<T> { ...@@ -129,9 +129,7 @@ class EltwiseAddMKLDNNKernel : public framework::OpKernel<T> {
pipeline.push_back(sum_prim); pipeline.push_back(sum_prim);
stream(stream::kind::eager).submit(pipeline).wait(); stream(stream::kind::eager).submit(pipeline).wait();
z->set_layout(DataLayout::kMKLDNN); z->set_mkldnn_prim_desc(dst_mem_pd);
z->set_format(
(memory::format)dst_memory.get_primitive_desc().desc().data.format);
} }
} }
}; };
...@@ -152,24 +150,19 @@ class EltwiseAddMKLDNNGradKernel : public ElemwiseGradKernel<T> { ...@@ -152,24 +150,19 @@ class EltwiseAddMKLDNNGradKernel : public ElemwiseGradKernel<T> {
auto* out = dout; auto* out = dout;
auto *x = dout, *y = dout; auto *x = dout, *y = dout;
auto set_mkldnn_format = [](Tensor* in, const Tensor* out) {
in->set_layout(DataLayout::kMKLDNN);
in->set_format(out->format());
};
if (dx != nullptr && dy != nullptr && dx->dims() == dy->dims()) { if (dx != nullptr && dy != nullptr && dx->dims() == dy->dims()) {
if (dx->dims() == dy->dims()) { if (dx->dims() == dy->dims()) {
auto blas = math::GetBlas<paddle::platform::CPUDeviceContext, T>(ctx); auto blas = math::GetBlas<paddle::platform::CPUDeviceContext, T>(ctx);
if (dx) { if (dx) {
blas.VCOPY(dout->numel(), dout->data<T>(), blas.VCOPY(dout->numel(), dout->data<T>(),
dx->mutable_data<T>(ctx.GetPlace())); dx->mutable_data<T>(ctx.GetPlace()));
set_mkldnn_format(dx, dout); dx->set_mkldnn_prim_desc(dout->get_mkldnn_prim_desc());
} }
if (dy) { if (dy) {
blas.VCOPY(dout->numel(), dout->data<T>(), blas.VCOPY(dout->numel(), dout->data<T>(),
dy->mutable_data<T>(ctx.GetPlace())); dy->mutable_data<T>(ctx.GetPlace()));
set_mkldnn_format(dy, dout); dy->set_mkldnn_prim_desc(dout->get_mkldnn_prim_desc());
} }
} }
} else { } else {
......
...@@ -96,8 +96,7 @@ void eltwise_forward(const framework::ExecutionContext &ctx, ...@@ -96,8 +96,7 @@ void eltwise_forward(const framework::ExecutionContext &ctx,
std::vector<int> src_tz = framework::vectorize2int(x->dims()); std::vector<int> src_tz = framework::vectorize2int(x->dims());
auto src_format = auto src_format = x->format();
src_tz.size() == 2 ? mkldnn::memory::format::nc : x->format();
const std::string key = gethash(src_tz, algorithm); const std::string key = gethash(src_tz, algorithm);
const std::string key_src_data = const std::string key_src_data =
...@@ -127,10 +126,8 @@ void eltwise_forward(const framework::ExecutionContext &ctx, ...@@ -127,10 +126,8 @@ void eltwise_forward(const framework::ExecutionContext &ctx,
if (p_fwd == nullptr) { if (p_fwd == nullptr) {
// create mkldnn memory for input X // create mkldnn memory for input X
auto src_md = platform::MKLDNNMemDesc(
src_tz, platform::MKLDNNGetDataType<T>(), src_format);
auto src_memory = std::shared_ptr<memory>( auto src_memory = std::shared_ptr<memory>(
new memory({src_md, mkldnn_engine}, to_void_cast(x_data))); new memory(x->get_mkldnn_prim_desc(), to_void_cast(x_data)));
// save src_memory to be referred in backward path // save src_memory to be referred in backward path
dev_ctx.SetBlob(key_src_mem, src_memory); dev_ctx.SetBlob(key_src_mem, src_memory);
...@@ -177,8 +174,7 @@ void eltwise_forward(const framework::ExecutionContext &ctx, ...@@ -177,8 +174,7 @@ void eltwise_forward(const framework::ExecutionContext &ctx,
pipeline.push_back(*p_fwd); pipeline.push_back(*p_fwd);
stream(stream::kind::eager).submit(pipeline).wait(); stream(stream::kind::eager).submit(pipeline).wait();
y->set_layout(DataLayout::kMKLDNN); y->set_mkldnn_prim_desc(dst_memory->get_primitive_desc());
y->set_format(GetMKLDNNFormat(*dst_memory));
} }
template <typename T> template <typename T>
...@@ -196,9 +192,6 @@ void eltwise_grad(const framework::ExecutionContext &ctx, ...@@ -196,9 +192,6 @@ void eltwise_grad(const framework::ExecutionContext &ctx,
std::vector<int> diff_dst_tz = framework::vectorize2int(diff_y->dims()); std::vector<int> diff_dst_tz = framework::vectorize2int(diff_y->dims());
auto diff_y_format =
diff_dst_tz.size() == 2 ? mkldnn::memory::format::nc : diff_y->format();
const std::string key = gethash(diff_dst_tz, algorithm); const std::string key = gethash(diff_dst_tz, algorithm);
const std::string key_src_data = const std::string key_src_data =
key + ctx.op().Input("Out") + "@eltwise_fwd_src_data"; key + ctx.op().Input("Out") + "@eltwise_fwd_src_data";
...@@ -210,8 +203,8 @@ void eltwise_grad(const framework::ExecutionContext &ctx, ...@@ -210,8 +203,8 @@ void eltwise_grad(const framework::ExecutionContext &ctx,
key + std::to_string(*p_src_layout) + "@eltwise_fwd_src_mem"; key + std::to_string(*p_src_layout) + "@eltwise_fwd_src_mem";
const std::string key_fwd_pd = const std::string key_fwd_pd =
key + std::to_string(*p_src_layout) + "@eltwise_fwd_pd"; key + std::to_string(*p_src_layout) + "@eltwise_fwd_pd";
const std::string key_with_layouts = const std::string key_with_layouts = key + std::to_string(*p_src_layout) +
key + std::to_string(*p_src_layout) + "-" + std::to_string(diff_y_format); "-" + std::to_string(diff_y->format());
const std::string key_diff_src_mem = const std::string key_diff_src_mem =
key_with_layouts + "@eltwise_diff_src_mem"; key_with_layouts + "@eltwise_diff_src_mem";
const std::string key_diff_dst_mem = const std::string key_diff_dst_mem =
...@@ -234,10 +227,8 @@ void eltwise_grad(const framework::ExecutionContext &ctx, ...@@ -234,10 +227,8 @@ void eltwise_grad(const framework::ExecutionContext &ctx,
if (p_grad == nullptr) { if (p_grad == nullptr) {
// create mkldnn memory for input diff_y // create mkldnn memory for input diff_y
auto diff_dst_md = platform::MKLDNNMemDesc(
diff_dst_tz, platform::MKLDNNGetDataType<T>(), diff_y_format);
auto diff_dst_memory = std::shared_ptr<memory>( auto diff_dst_memory = std::shared_ptr<memory>(
new memory({diff_dst_md, mkldnn_engine}, to_void_cast(diff_y_data))); new memory(diff_y->get_mkldnn_prim_desc(), to_void_cast(diff_y_data)));
dev_ctx.SetBlob(key_diff_dst_mem, diff_dst_memory); dev_ctx.SetBlob(key_diff_dst_mem, diff_dst_memory);
// retrieve eltwise primitive desc from device context // retrieve eltwise primitive desc from device context
...@@ -281,8 +272,7 @@ void eltwise_grad(const framework::ExecutionContext &ctx, ...@@ -281,8 +272,7 @@ void eltwise_grad(const framework::ExecutionContext &ctx,
pipeline.push_back(*p_grad); pipeline.push_back(*p_grad);
stream(stream::kind::eager).submit(pipeline).wait(); stream(stream::kind::eager).submit(pipeline).wait();
diff_x->set_layout(DataLayout::kMKLDNN); diff_x->set_mkldnn_prim_desc(diff_src_memory->get_primitive_desc());
diff_x->set_format(GetMKLDNNFormat(*diff_src_memory));
} }
template <typename T, mkldnn::algorithm algorithm> template <typename T, mkldnn::algorithm algorithm>
......
...@@ -206,17 +206,14 @@ class BatchNormMKLDNNOpKernel : public paddle::framework::OpKernel<T> { ...@@ -206,17 +206,14 @@ class BatchNormMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
if (fuse_with_relu) flags |= mkldnn::fuse_bn_relu; if (fuse_with_relu) flags |= mkldnn::fuse_bn_relu;
// create mkldnn memory from input x tensor // create mkldnn memory from input x tensor
mkldnn::memory::format input_format =
platform::MKLDNNFormatForSize(src_tz.size(), x->format());
// keys for backward pass // keys for backward pass
const std::string key = BatchNormMKLDNNHandler::GetHash( const std::string key = BatchNormMKLDNNHandler::GetHash(
src_tz, epsilon, flags, global_stats, input_format, src_tz, epsilon, flags, global_stats, x->format(),
ctx.op().Output("SavedMean")); ctx.op().Output("SavedMean"));
const std::string key_batch_norm_fwd_pd = key + "@bn_fwd_pd"; const std::string key_batch_norm_fwd_pd = key + "@bn_fwd_pd";
auto user_src_md = platform::MKLDNNMemDesc( auto user_src_md = x->get_mkldnn_prim_desc().desc();
{src_tz}, platform::MKLDNNGetDataType<T>(), input_format);
// create primitive descriptor for batch norm forward // create primitive descriptor for batch norm forward
using bn_fwd_types = bn_type_traits<mkldnn::batch_normalization_forward>; using bn_fwd_types = bn_type_traits<mkldnn::batch_normalization_forward>;
...@@ -230,8 +227,8 @@ class BatchNormMKLDNNOpKernel : public paddle::framework::OpKernel<T> { ...@@ -230,8 +227,8 @@ class BatchNormMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
BatchNormMKLDNNHandler handler(batch_norm_fwd_pd, dev_ctx, mkldnn_engine, BatchNormMKLDNNHandler handler(batch_norm_fwd_pd, dev_ctx, mkldnn_engine,
key); key);
auto src_memory = auto src_memory = handler.AcquireSrcMemory(x->get_mkldnn_prim_desc(),
handler.AcquireSrcMemory(user_src_md, to_void_cast(x_data)); to_void_cast(x_data));
// crate mkldnn memory for weights(scale/shift) // crate mkldnn memory for weights(scale/shift)
auto scaleshift_memory = auto scaleshift_memory =
...@@ -265,8 +262,7 @@ class BatchNormMKLDNNOpKernel : public paddle::framework::OpKernel<T> { ...@@ -265,8 +262,7 @@ class BatchNormMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
variance_memory, false); variance_memory, false);
} }
y->set_layout(DataLayout::kMKLDNN); y->set_mkldnn_prim_desc(dst_memory->get_primitive_desc());
y->set_format(platform::GetMKLDNNFormat(*dst_memory));
std::vector<mkldnn::primitive> pipeline; std::vector<mkldnn::primitive> pipeline;
pipeline.push_back(*batch_norm_p); pipeline.push_back(*batch_norm_p);
...@@ -336,9 +332,6 @@ class BatchNormMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> { ...@@ -336,9 +332,6 @@ class BatchNormMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
using bn_bwd_types = bn_type_traits<mkldnn::batch_normalization_backward>; using bn_bwd_types = bn_type_traits<mkldnn::batch_normalization_backward>;
mkldnn::memory::format dst_format =
platform::MKLDNNFormatForSize(src_tz.size(), diff_y->format());
mkldnn::memory::format input_format = mkldnn::memory::format input_format =
platform::MKLDNNFormatForSize(src_tz.size(), x->format()); platform::MKLDNNFormatForSize(src_tz.size(), x->format());
...@@ -346,14 +339,14 @@ class BatchNormMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> { ...@@ -346,14 +339,14 @@ class BatchNormMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
// keys from forward pass // keys from forward pass
const std::string key = BatchNormMKLDNNHandler::GetHash( const std::string key = BatchNormMKLDNNHandler::GetHash(
src_tz, epsilon, flags, false, input_format, src_tz, epsilon, flags, false, x->format(),
ctx.op().Input("SavedMean")); ctx.op().Input("SavedMean"));
const std::string key_batch_norm_fwd_pd = key + "@bn_fwd_pd"; const std::string key_batch_norm_fwd_pd = key + "@bn_fwd_pd";
// keys for primitives reuse // keys for primitives reuse
const std::string key_with_hash = const std::string key_with_hash =
key + BatchNormMKLDNNHandler::GetHash(src_tz, epsilon, flags, false, key + BatchNormMKLDNNHandler::GetHash(src_tz, epsilon, flags, false,
input_format); x->format());
const std::string key_batch_norm_bwd_p = const std::string key_batch_norm_bwd_p =
key_with_hash + "@batch_norm_bwd_p"; key_with_hash + "@batch_norm_bwd_p";
const std::string key_batch_norm_src_mem_p = const std::string key_batch_norm_src_mem_p =
...@@ -373,9 +366,8 @@ class BatchNormMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> { ...@@ -373,9 +366,8 @@ class BatchNormMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
primitive reorder_diff_dst; primitive reorder_diff_dst;
bool is_diff_dst_reordered = false; bool is_diff_dst_reordered = false;
auto user_diff_dst_memory = memory( auto user_diff_dst_memory =
{{{diff_dst_tz}, memory::data_type::f32, dst_format}, mkldnn_engine}, memory(diff_y->get_mkldnn_prim_desc(), to_void_cast(diff_y_data));
to_void_cast(diff_y_data));
// MKLDNN requires a single piece of memory for scale and shift/bias data // MKLDNN requires a single piece of memory for scale and shift/bias data
const size_t scaleshift_size = 2 * ic; const size_t scaleshift_size = 2 * ic;
...@@ -459,10 +451,7 @@ class BatchNormMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> { ...@@ -459,10 +451,7 @@ class BatchNormMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
dev_ctx.SetBlob(key_batch_norm_diff_dst_mem_p, diff_dst_memory); dev_ctx.SetBlob(key_batch_norm_diff_dst_mem_p, diff_dst_memory);
// set layout/format of output tensors // set layout/format of output tensors
diff_x->set_layout(DataLayout::kMKLDNN); diff_x->set_mkldnn_prim_desc(diff_src_memory->get_primitive_desc());
diff_x->set_format((memory::format)diff_src_memory->get_primitive_desc()
.desc()
.data.format);
} else { } else {
// primitives already exist // primitives already exist
UpdateMemoryData(dev_ctx, key_batch_norm_src_mem_p, to_void_cast(x_data)); UpdateMemoryData(dev_ctx, key_batch_norm_src_mem_p, to_void_cast(x_data));
...@@ -487,10 +476,7 @@ class BatchNormMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> { ...@@ -487,10 +476,7 @@ class BatchNormMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
} }
// set layout/format of output tensors // set layout/format of output tensors
diff_x->set_layout(DataLayout::kMKLDNN); diff_x->set_mkldnn_prim_desc(diff_src_memory->get_primitive_desc());
diff_x->set_format((memory::format)diff_src_memory->get_primitive_desc()
.desc()
.data.format);
} }
// execute optional reorder and batch_norm backward primitive // execute optional reorder and batch_norm backward primitive
......
...@@ -47,11 +47,6 @@ static memory::primitive_desc CreateMemPrimDesc(const Tensor& input, ...@@ -47,11 +47,6 @@ static memory::primitive_desc CreateMemPrimDesc(const Tensor& input,
return mem_prim_desc; return mem_prim_desc;
} }
static mkldnn::memory::format GetDstMemFormat(
const concat::primitive_desc& concat_pd) {
return (memory::format)concat_pd.dst_primitive_desc().desc().data.format;
}
static platform::CPUPlace GetCpuPlace( static platform::CPUPlace GetCpuPlace(
const paddle::framework::ExecutionContext& ctx) { const paddle::framework::ExecutionContext& ctx) {
auto place = ctx.GetPlace(); auto place = ctx.GetPlace();
...@@ -139,8 +134,7 @@ class ConcatMKLDNNOpKernel : public paddle::framework::OpKernel<T> { ...@@ -139,8 +134,7 @@ class ConcatMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
auto concat = prim_creator.CreateConcatPrimitive(concat_pd, output, place); auto concat = prim_creator.CreateConcatPrimitive(concat_pd, output, place);
stream(stream::kind::eager).submit({concat}).wait(); stream(stream::kind::eager).submit({concat}).wait();
output->set_layout(DataLayout::kMKLDNN); output->set_mkldnn_prim_desc(concat_pd.dst_primitive_desc());
output->set_format(GetDstMemFormat(concat_pd));
} }
}; };
} // namespace operators } // namespace operators
......
...@@ -282,8 +282,7 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> { ...@@ -282,8 +282,7 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
pipeline.push_back(*conv_p); pipeline.push_back(*conv_p);
stream(stream::kind::eager).submit(pipeline).wait(); stream(stream::kind::eager).submit(pipeline).wait();
auto dst_mpd = dst_memory_p->get_primitive_desc(); output->set_mkldnn_prim_desc(dst_memory_p->get_primitive_desc());
output->set_mkldnn_prim_desc(dst_mpd);
} }
void ComputeINT8(const paddle::framework::ExecutionContext& ctx) const { void ComputeINT8(const paddle::framework::ExecutionContext& ctx) const {
const bool is_test = ctx.Attr<bool>("is_test"); const bool is_test = ctx.Attr<bool>("is_test");
...@@ -972,8 +971,7 @@ class ConvMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> { ...@@ -972,8 +971,7 @@ class ConvMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
pipeline.push_back(*conv_bwd_data_p); pipeline.push_back(*conv_bwd_data_p);
input_grad->set_layout(DataLayout::kMKLDNN); input_grad->set_mkldnn_prim_desc(diff_src_memory_p->get_primitive_desc());
input_grad->set_format(GetMKLDNNFormat(*diff_src_memory_p));
} }
stream(stream::kind::eager).submit(pipeline).wait(); stream(stream::kind::eager).submit(pipeline).wait();
} }
......
...@@ -221,8 +221,7 @@ class ConvTransposeMKLDNNOpKernel : public paddle::framework::OpKernel<T> { ...@@ -221,8 +221,7 @@ class ConvTransposeMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
pipeline.push_back(*conv_p); pipeline.push_back(*conv_p);
mkldnn::stream(mkldnn::stream::kind::eager).submit(pipeline).wait(); mkldnn::stream(mkldnn::stream::kind::eager).submit(pipeline).wait();
output->set_layout(DataLayout::kMKLDNN); output->set_mkldnn_prim_desc(dst_memory_p->get_primitive_desc());
output->set_format(platform::GetMKLDNNFormat(*dst_memory_p));
} }
private: private:
......
...@@ -81,10 +81,7 @@ class LRNMKLDNNOpKernel : public paddle::framework::OpKernel<T> { ...@@ -81,10 +81,7 @@ class LRNMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
auto e_mid = framework::EigenTensor<T, 4>::From(*mid); auto e_mid = framework::EigenTensor<T, 4>::From(*mid);
e_mid = e_mid.constant(k); e_mid = e_mid.constant(k);
auto dims = paddle::framework::vectorize2int(x->dims()); auto src_md = x->get_mkldnn_prim_desc().desc();
auto src_md = paddle::platform::MKLDNNMemDesc(
dims, mkldnn::memory::data_type::f32, x->format());
auto forward_desc = mkldnn::lrn_forward::desc{mkldnn::prop_kind::forward, auto forward_desc = mkldnn::lrn_forward::desc{mkldnn::prop_kind::forward,
mkldnn::lrn_across_channels, mkldnn::lrn_across_channels,
...@@ -94,7 +91,7 @@ class LRNMKLDNNOpKernel : public paddle::framework::OpKernel<T> { ...@@ -94,7 +91,7 @@ class LRNMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
beta, beta,
k}; k};
auto src_memory_pd = mkldnn::memory::primitive_desc{src_md, mkldnn_engine}; auto src_memory_pd = x->get_mkldnn_prim_desc();
if (!is_test) { if (!is_test) {
const std::string key = ctx.op().Output("Out"); const std::string key = ctx.op().Output("Out");
...@@ -111,16 +108,15 @@ class LRNMKLDNNOpKernel : public paddle::framework::OpKernel<T> { ...@@ -111,16 +108,15 @@ class LRNMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
src_memory->set_data_handle( src_memory->set_data_handle(
static_cast<void*>(const_cast<T*>(input_data))); static_cast<void*>(const_cast<T*>(input_data)));
auto dst_memory = mkldnn::memory(forward_pd->dst_primitive_desc(), auto dst_memory_pd = forward_pd->dst_primitive_desc();
static_cast<void*>(output_data)); auto dst_memory =
mkldnn::memory(dst_memory_pd, static_cast<void*>(output_data));
auto workspace_memory = insert_to_context<mkldnn::memory>( auto workspace_memory = insert_to_context<mkldnn::memory>(
key_workspace_memory, dev_ctx, key_workspace_memory, dev_ctx,
forward_pd->workspace_primitive_desc()); forward_pd->workspace_primitive_desc());
run_primitive(*forward_pd, *src_memory, *workspace_memory, dst_memory); run_primitive(*forward_pd, *src_memory, *workspace_memory, dst_memory);
out->set_mkldnn_prim_desc(dst_memory_pd);
out->set_layout(framework::DataLayout::kMKLDNN);
out->set_format(platform::GetMKLDNNFormat(dst_memory));
} else { } else {
auto forward_pd = auto forward_pd =
mkldnn::lrn_forward::primitive_desc{forward_desc, mkldnn_engine}; mkldnn::lrn_forward::primitive_desc{forward_desc, mkldnn_engine};
...@@ -128,13 +124,12 @@ class LRNMKLDNNOpKernel : public paddle::framework::OpKernel<T> { ...@@ -128,13 +124,12 @@ class LRNMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
src_memory_pd, static_cast<void*>(const_cast<T*>(input_data))}; src_memory_pd, static_cast<void*>(const_cast<T*>(input_data))};
auto workspace_memory = auto workspace_memory =
mkldnn::memory{forward_pd.workspace_primitive_desc()}; mkldnn::memory{forward_pd.workspace_primitive_desc()};
auto dst_memory_pd = forward_pd.dst_primitive_desc();
auto dst_memory = mkldnn::memory(forward_pd.dst_primitive_desc(), auto dst_memory = mkldnn::memory(forward_pd.dst_primitive_desc(),
static_cast<void*>(output_data)); static_cast<void*>(output_data));
run_primitive(forward_pd, src_memory, workspace_memory, dst_memory); run_primitive(forward_pd, src_memory, workspace_memory, dst_memory);
out->set_mkldnn_prim_desc(dst_memory_pd);
out->set_layout(framework::DataLayout::kMKLDNN);
out->set_format(platform::GetMKLDNNFormat(dst_memory));
} }
} }
}; };
......
...@@ -158,6 +158,14 @@ class SoftmaxMKLDNNKernel : public paddle::framework::OpKernel<T> { ...@@ -158,6 +158,14 @@ class SoftmaxMKLDNNKernel : public paddle::framework::OpKernel<T> {
auto softmax_p = auto softmax_p =
handler.AcquireSoftmax(softmax_dst_memory_p, softmax_src_memory_p); handler.AcquireSoftmax(softmax_dst_memory_p, softmax_src_memory_p);
// We cannot use softmax_dst_memory_p to get prim desc as
// it contains flattened dims (2D) while output tensor can
// have 2,3,4+ dims
auto output_mem_pd = paddle::platform::create_prim_desc_from_dims(
paddle::framework::vectorize2int(output->dims()),
mkldnn::memory::format::blocked);
output->set_mkldnn_prim_desc(output_mem_pd);
std::vector<primitive> pipeline{ std::vector<primitive> pipeline{
*(static_cast<softmax_forward::primitive*>(softmax_p.get()))}; *(static_cast<softmax_forward::primitive*>(softmax_p.get()))};
stream(stream::kind::eager).submit(pipeline).wait(); stream(stream::kind::eager).submit(pipeline).wait();
......
...@@ -106,12 +106,12 @@ class SumMKLDNNOpKernel : public paddle::framework::OpKernel<T> { ...@@ -106,12 +106,12 @@ class SumMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
memory::desc(dst_tz, memory::data_type::f32, memory::format::any); memory::desc(dst_tz, memory::data_type::f32, memory::format::any);
auto sum_pd = sum::primitive_desc(dst_md, scales, srcs_mpd); auto sum_pd = sum::primitive_desc(dst_md, scales, srcs_mpd);
auto dst_mem_pd = sum_pd.dst_primitive_desc();
std::shared_ptr<memory> dst_mem; std::shared_ptr<memory> dst_mem;
if (in_place) { if (in_place) {
dst_mem.reset(new memory(sum_pd.dst_primitive_desc())); dst_mem.reset(new memory(dst_mem_pd));
} else { } else {
dst_mem.reset(new memory(sum_pd.dst_primitive_desc(), output_data)); dst_mem.reset(new memory(dst_mem_pd, output_data));
} }
std::vector<mkldnn::primitive::at> inputs; std::vector<mkldnn::primitive::at> inputs;
for (size_t i = 0; i < srcs_mem.size(); ++i) { for (size_t i = 0; i < srcs_mem.size(); ++i) {
...@@ -136,8 +136,7 @@ class SumMKLDNNOpKernel : public paddle::framework::OpKernel<T> { ...@@ -136,8 +136,7 @@ class SumMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
if (in_place) pipeline.push_back(reorder_prim); if (in_place) pipeline.push_back(reorder_prim);
stream(stream::kind::eager).submit(pipeline).wait(); stream(stream::kind::eager).submit(pipeline).wait();
output->set_layout(DataLayout::kMKLDNN); output->set_mkldnn_prim_desc(dst_mem_pd);
output->set_format(output_format);
} else { // Fallback to naive version } else { // Fallback to naive version
// TODO(@mozga-intel) Add MKLDNN SelectedRows & LoDTensorArray support // TODO(@mozga-intel) Add MKLDNN SelectedRows & LoDTensorArray support
SumKernel<CPUDeviceContext, T> reference_kernel; SumKernel<CPUDeviceContext, T> reference_kernel;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册