未验证 提交 ba90e052 编写于 作者: T Tao Luo 提交者: GitHub

Merge pull request #15917 from jczaja/prv-tensor-mkldnn-ops

[MKL-DNN] Adjusting ops to Tensor modifications
......@@ -77,8 +77,7 @@ class EltwiseAddMKLDNNKernel : public framework::OpKernel<T> {
} else {
functor.RunMidWise(n, pre, post);
}
z->set_layout(DataLayout::kMKLDNN);
z->set_format(x->format());
z->set_mkldnn_prim_desc(x->get_mkldnn_prim_desc());
} else {
PADDLE_ENFORCE(x->layout() == DataLayout::kMKLDNN &&
x->format() != memory::format::format_undef,
......@@ -116,7 +115,8 @@ class EltwiseAddMKLDNNKernel : public framework::OpKernel<T> {
auto sum_pd = sum::primitive_desc(dst_md, scales, srcs_pd);
// create mkldnn memory for dst
memory dst_memory = memory(sum_pd.dst_primitive_desc(), z_data);
auto dst_mem_pd = sum_pd.dst_primitive_desc();
memory dst_memory = memory(dst_mem_pd, z_data);
std::vector<primitive::at> inputs;
inputs.push_back(srcs[0]);
......@@ -129,9 +129,7 @@ class EltwiseAddMKLDNNKernel : public framework::OpKernel<T> {
pipeline.push_back(sum_prim);
stream(stream::kind::eager).submit(pipeline).wait();
z->set_layout(DataLayout::kMKLDNN);
z->set_format(
(memory::format)dst_memory.get_primitive_desc().desc().data.format);
z->set_mkldnn_prim_desc(dst_mem_pd);
}
}
};
......@@ -152,24 +150,19 @@ class EltwiseAddMKLDNNGradKernel : public ElemwiseGradKernel<T> {
auto* out = dout;
auto *x = dout, *y = dout;
auto set_mkldnn_format = [](Tensor* in, const Tensor* out) {
in->set_layout(DataLayout::kMKLDNN);
in->set_format(out->format());
};
if (dx != nullptr && dy != nullptr && dx->dims() == dy->dims()) {
if (dx->dims() == dy->dims()) {
auto blas = math::GetBlas<paddle::platform::CPUDeviceContext, T>(ctx);
if (dx) {
blas.VCOPY(dout->numel(), dout->data<T>(),
dx->mutable_data<T>(ctx.GetPlace()));
set_mkldnn_format(dx, dout);
dx->set_mkldnn_prim_desc(dout->get_mkldnn_prim_desc());
}
if (dy) {
blas.VCOPY(dout->numel(), dout->data<T>(),
dy->mutable_data<T>(ctx.GetPlace()));
set_mkldnn_format(dy, dout);
dy->set_mkldnn_prim_desc(dout->get_mkldnn_prim_desc());
}
}
} else {
......
......@@ -96,8 +96,7 @@ void eltwise_forward(const framework::ExecutionContext &ctx,
std::vector<int> src_tz = framework::vectorize2int(x->dims());
auto src_format =
src_tz.size() == 2 ? mkldnn::memory::format::nc : x->format();
auto src_format = x->format();
const std::string key = gethash(src_tz, algorithm);
const std::string key_src_data =
......@@ -127,10 +126,8 @@ void eltwise_forward(const framework::ExecutionContext &ctx,
if (p_fwd == nullptr) {
// create mkldnn memory for input X
auto src_md = platform::MKLDNNMemDesc(
src_tz, platform::MKLDNNGetDataType<T>(), src_format);
auto src_memory = std::shared_ptr<memory>(
new memory({src_md, mkldnn_engine}, to_void_cast(x_data)));
new memory(x->get_mkldnn_prim_desc(), to_void_cast(x_data)));
// save src_memory to be referred in backward path
dev_ctx.SetBlob(key_src_mem, src_memory);
......@@ -177,8 +174,7 @@ void eltwise_forward(const framework::ExecutionContext &ctx,
pipeline.push_back(*p_fwd);
stream(stream::kind::eager).submit(pipeline).wait();
y->set_layout(DataLayout::kMKLDNN);
y->set_format(GetMKLDNNFormat(*dst_memory));
y->set_mkldnn_prim_desc(dst_memory->get_primitive_desc());
}
template <typename T>
......@@ -196,9 +192,6 @@ void eltwise_grad(const framework::ExecutionContext &ctx,
std::vector<int> diff_dst_tz = framework::vectorize2int(diff_y->dims());
auto diff_y_format =
diff_dst_tz.size() == 2 ? mkldnn::memory::format::nc : diff_y->format();
const std::string key = gethash(diff_dst_tz, algorithm);
const std::string key_src_data =
key + ctx.op().Input("Out") + "@eltwise_fwd_src_data";
......@@ -210,8 +203,8 @@ void eltwise_grad(const framework::ExecutionContext &ctx,
key + std::to_string(*p_src_layout) + "@eltwise_fwd_src_mem";
const std::string key_fwd_pd =
key + std::to_string(*p_src_layout) + "@eltwise_fwd_pd";
const std::string key_with_layouts =
key + std::to_string(*p_src_layout) + "-" + std::to_string(diff_y_format);
const std::string key_with_layouts = key + std::to_string(*p_src_layout) +
"-" + std::to_string(diff_y->format());
const std::string key_diff_src_mem =
key_with_layouts + "@eltwise_diff_src_mem";
const std::string key_diff_dst_mem =
......@@ -234,10 +227,8 @@ void eltwise_grad(const framework::ExecutionContext &ctx,
if (p_grad == nullptr) {
// create mkldnn memory for input diff_y
auto diff_dst_md = platform::MKLDNNMemDesc(
diff_dst_tz, platform::MKLDNNGetDataType<T>(), diff_y_format);
auto diff_dst_memory = std::shared_ptr<memory>(
new memory({diff_dst_md, mkldnn_engine}, to_void_cast(diff_y_data)));
new memory(diff_y->get_mkldnn_prim_desc(), to_void_cast(diff_y_data)));
dev_ctx.SetBlob(key_diff_dst_mem, diff_dst_memory);
// retrieve eltwise primitive desc from device context
......@@ -281,8 +272,7 @@ void eltwise_grad(const framework::ExecutionContext &ctx,
pipeline.push_back(*p_grad);
stream(stream::kind::eager).submit(pipeline).wait();
diff_x->set_layout(DataLayout::kMKLDNN);
diff_x->set_format(GetMKLDNNFormat(*diff_src_memory));
diff_x->set_mkldnn_prim_desc(diff_src_memory->get_primitive_desc());
}
template <typename T, mkldnn::algorithm algorithm>
......
......@@ -206,17 +206,14 @@ class BatchNormMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
if (fuse_with_relu) flags |= mkldnn::fuse_bn_relu;
// create mkldnn memory from input x tensor
mkldnn::memory::format input_format =
platform::MKLDNNFormatForSize(src_tz.size(), x->format());
// keys for backward pass
const std::string key = BatchNormMKLDNNHandler::GetHash(
src_tz, epsilon, flags, global_stats, input_format,
src_tz, epsilon, flags, global_stats, x->format(),
ctx.op().Output("SavedMean"));
const std::string key_batch_norm_fwd_pd = key + "@bn_fwd_pd";
auto user_src_md = platform::MKLDNNMemDesc(
{src_tz}, platform::MKLDNNGetDataType<T>(), input_format);
auto user_src_md = x->get_mkldnn_prim_desc().desc();
// create primitive descriptor for batch norm forward
using bn_fwd_types = bn_type_traits<mkldnn::batch_normalization_forward>;
......@@ -230,8 +227,8 @@ class BatchNormMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
BatchNormMKLDNNHandler handler(batch_norm_fwd_pd, dev_ctx, mkldnn_engine,
key);
auto src_memory =
handler.AcquireSrcMemory(user_src_md, to_void_cast(x_data));
auto src_memory = handler.AcquireSrcMemory(x->get_mkldnn_prim_desc(),
to_void_cast(x_data));
// crate mkldnn memory for weights(scale/shift)
auto scaleshift_memory =
......@@ -265,8 +262,7 @@ class BatchNormMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
variance_memory, false);
}
y->set_layout(DataLayout::kMKLDNN);
y->set_format(platform::GetMKLDNNFormat(*dst_memory));
y->set_mkldnn_prim_desc(dst_memory->get_primitive_desc());
std::vector<mkldnn::primitive> pipeline;
pipeline.push_back(*batch_norm_p);
......@@ -336,9 +332,6 @@ class BatchNormMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
using bn_bwd_types = bn_type_traits<mkldnn::batch_normalization_backward>;
mkldnn::memory::format dst_format =
platform::MKLDNNFormatForSize(src_tz.size(), diff_y->format());
mkldnn::memory::format input_format =
platform::MKLDNNFormatForSize(src_tz.size(), x->format());
......@@ -346,14 +339,14 @@ class BatchNormMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
// keys from forward pass
const std::string key = BatchNormMKLDNNHandler::GetHash(
src_tz, epsilon, flags, false, input_format,
src_tz, epsilon, flags, false, x->format(),
ctx.op().Input("SavedMean"));
const std::string key_batch_norm_fwd_pd = key + "@bn_fwd_pd";
// keys for primitives reuse
const std::string key_with_hash =
key + BatchNormMKLDNNHandler::GetHash(src_tz, epsilon, flags, false,
input_format);
x->format());
const std::string key_batch_norm_bwd_p =
key_with_hash + "@batch_norm_bwd_p";
const std::string key_batch_norm_src_mem_p =
......@@ -373,9 +366,8 @@ class BatchNormMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
primitive reorder_diff_dst;
bool is_diff_dst_reordered = false;
auto user_diff_dst_memory = memory(
{{{diff_dst_tz}, memory::data_type::f32, dst_format}, mkldnn_engine},
to_void_cast(diff_y_data));
auto user_diff_dst_memory =
memory(diff_y->get_mkldnn_prim_desc(), to_void_cast(diff_y_data));
// MKLDNN requires a single piece of memory for scale and shift/bias data
const size_t scaleshift_size = 2 * ic;
......@@ -459,10 +451,7 @@ class BatchNormMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
dev_ctx.SetBlob(key_batch_norm_diff_dst_mem_p, diff_dst_memory);
// set layout/format of output tensors
diff_x->set_layout(DataLayout::kMKLDNN);
diff_x->set_format((memory::format)diff_src_memory->get_primitive_desc()
.desc()
.data.format);
diff_x->set_mkldnn_prim_desc(diff_src_memory->get_primitive_desc());
} else {
// primitives already exist
UpdateMemoryData(dev_ctx, key_batch_norm_src_mem_p, to_void_cast(x_data));
......@@ -487,10 +476,7 @@ class BatchNormMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
}
// set layout/format of output tensors
diff_x->set_layout(DataLayout::kMKLDNN);
diff_x->set_format((memory::format)diff_src_memory->get_primitive_desc()
.desc()
.data.format);
diff_x->set_mkldnn_prim_desc(diff_src_memory->get_primitive_desc());
}
// execute optional reorder and batch_norm backward primitive
......
......@@ -47,11 +47,6 @@ static memory::primitive_desc CreateMemPrimDesc(const Tensor& input,
return mem_prim_desc;
}
static mkldnn::memory::format GetDstMemFormat(
const concat::primitive_desc& concat_pd) {
return (memory::format)concat_pd.dst_primitive_desc().desc().data.format;
}
static platform::CPUPlace GetCpuPlace(
const paddle::framework::ExecutionContext& ctx) {
auto place = ctx.GetPlace();
......@@ -139,8 +134,7 @@ class ConcatMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
auto concat = prim_creator.CreateConcatPrimitive(concat_pd, output, place);
stream(stream::kind::eager).submit({concat}).wait();
output->set_layout(DataLayout::kMKLDNN);
output->set_format(GetDstMemFormat(concat_pd));
output->set_mkldnn_prim_desc(concat_pd.dst_primitive_desc());
}
};
} // namespace operators
......
......@@ -282,8 +282,7 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
pipeline.push_back(*conv_p);
stream(stream::kind::eager).submit(pipeline).wait();
auto dst_mpd = dst_memory_p->get_primitive_desc();
output->set_mkldnn_prim_desc(dst_mpd);
output->set_mkldnn_prim_desc(dst_memory_p->get_primitive_desc());
}
void ComputeINT8(const paddle::framework::ExecutionContext& ctx) const {
const bool is_test = ctx.Attr<bool>("is_test");
......@@ -972,8 +971,7 @@ class ConvMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
pipeline.push_back(*conv_bwd_data_p);
input_grad->set_layout(DataLayout::kMKLDNN);
input_grad->set_format(GetMKLDNNFormat(*diff_src_memory_p));
input_grad->set_mkldnn_prim_desc(diff_src_memory_p->get_primitive_desc());
}
stream(stream::kind::eager).submit(pipeline).wait();
}
......
......@@ -221,8 +221,7 @@ class ConvTransposeMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
pipeline.push_back(*conv_p);
mkldnn::stream(mkldnn::stream::kind::eager).submit(pipeline).wait();
output->set_layout(DataLayout::kMKLDNN);
output->set_format(platform::GetMKLDNNFormat(*dst_memory_p));
output->set_mkldnn_prim_desc(dst_memory_p->get_primitive_desc());
}
private:
......
......@@ -81,10 +81,7 @@ class LRNMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
auto e_mid = framework::EigenTensor<T, 4>::From(*mid);
e_mid = e_mid.constant(k);
auto dims = paddle::framework::vectorize2int(x->dims());
auto src_md = paddle::platform::MKLDNNMemDesc(
dims, mkldnn::memory::data_type::f32, x->format());
auto src_md = x->get_mkldnn_prim_desc().desc();
auto forward_desc = mkldnn::lrn_forward::desc{mkldnn::prop_kind::forward,
mkldnn::lrn_across_channels,
......@@ -94,7 +91,7 @@ class LRNMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
beta,
k};
auto src_memory_pd = mkldnn::memory::primitive_desc{src_md, mkldnn_engine};
auto src_memory_pd = x->get_mkldnn_prim_desc();
if (!is_test) {
const std::string key = ctx.op().Output("Out");
......@@ -111,16 +108,15 @@ class LRNMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
src_memory->set_data_handle(
static_cast<void*>(const_cast<T*>(input_data)));
auto dst_memory = mkldnn::memory(forward_pd->dst_primitive_desc(),
static_cast<void*>(output_data));
auto dst_memory_pd = forward_pd->dst_primitive_desc();
auto dst_memory =
mkldnn::memory(dst_memory_pd, static_cast<void*>(output_data));
auto workspace_memory = insert_to_context<mkldnn::memory>(
key_workspace_memory, dev_ctx,
forward_pd->workspace_primitive_desc());
run_primitive(*forward_pd, *src_memory, *workspace_memory, dst_memory);
out->set_layout(framework::DataLayout::kMKLDNN);
out->set_format(platform::GetMKLDNNFormat(dst_memory));
out->set_mkldnn_prim_desc(dst_memory_pd);
} else {
auto forward_pd =
mkldnn::lrn_forward::primitive_desc{forward_desc, mkldnn_engine};
......@@ -128,13 +124,12 @@ class LRNMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
src_memory_pd, static_cast<void*>(const_cast<T*>(input_data))};
auto workspace_memory =
mkldnn::memory{forward_pd.workspace_primitive_desc()};
auto dst_memory_pd = forward_pd.dst_primitive_desc();
auto dst_memory = mkldnn::memory(forward_pd.dst_primitive_desc(),
static_cast<void*>(output_data));
run_primitive(forward_pd, src_memory, workspace_memory, dst_memory);
out->set_layout(framework::DataLayout::kMKLDNN);
out->set_format(platform::GetMKLDNNFormat(dst_memory));
out->set_mkldnn_prim_desc(dst_memory_pd);
}
}
};
......
......@@ -158,6 +158,14 @@ class SoftmaxMKLDNNKernel : public paddle::framework::OpKernel<T> {
auto softmax_p =
handler.AcquireSoftmax(softmax_dst_memory_p, softmax_src_memory_p);
// We cannot use softmax_dst_memory_p to get prim desc as
// it contains flattened dims (2D) while output tensor can
// have 2,3,4+ dims
auto output_mem_pd = paddle::platform::create_prim_desc_from_dims(
paddle::framework::vectorize2int(output->dims()),
mkldnn::memory::format::blocked);
output->set_mkldnn_prim_desc(output_mem_pd);
std::vector<primitive> pipeline{
*(static_cast<softmax_forward::primitive*>(softmax_p.get()))};
stream(stream::kind::eager).submit(pipeline).wait();
......
......@@ -106,12 +106,12 @@ class SumMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
memory::desc(dst_tz, memory::data_type::f32, memory::format::any);
auto sum_pd = sum::primitive_desc(dst_md, scales, srcs_mpd);
auto dst_mem_pd = sum_pd.dst_primitive_desc();
std::shared_ptr<memory> dst_mem;
if (in_place) {
dst_mem.reset(new memory(sum_pd.dst_primitive_desc()));
dst_mem.reset(new memory(dst_mem_pd));
} else {
dst_mem.reset(new memory(sum_pd.dst_primitive_desc(), output_data));
dst_mem.reset(new memory(dst_mem_pd, output_data));
}
std::vector<mkldnn::primitive::at> inputs;
for (size_t i = 0; i < srcs_mem.size(); ++i) {
......@@ -136,8 +136,7 @@ class SumMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
if (in_place) pipeline.push_back(reorder_prim);
stream(stream::kind::eager).submit(pipeline).wait();
output->set_layout(DataLayout::kMKLDNN);
output->set_format(output_format);
output->set_mkldnn_prim_desc(dst_mem_pd);
} else { // Fallback to naive version
// TODO(@mozga-intel) Add MKLDNN SelectedRows & LoDTensorArray support
SumKernel<CPUDeviceContext, T> reference_kernel;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册