提交 26323274 编写于 作者: J Jacek Czaja 提交者: tensor-tang

[MKL-DNN] Tensor modifications revert (#16462)

* Revert "[MKL-DNN] Fix to crash of Transformer when mkldnn is to be used (#16233)"

This reverts commit 13816dd4.
Apart from enabling transformer for MKL-DNN

* Revert "- MKL-DNN pooling updated to set_prim_desc"

This reverts commit c63f6b20.

Conflicts:
	paddle/fluid/operators/mkldnn/concat_mkldnn_op.cc

* Revert "[MKL-DNN] MKL-DNN specific Tensor modification (#15429)"

test=develop

This reverts commit dec9cf53.

* - concat compilation fix

- lint

test=develop

- Lint fixes

test=develop

- Lint fixes

test=develop

- Fix Transpose MKLDNN op

test=develop
上级 4143a1c2
...@@ -134,6 +134,11 @@ void TransDataLayoutFromMKLDNN(const OpKernelType& kernel_type_for_var, ...@@ -134,6 +134,11 @@ void TransDataLayoutFromMKLDNN(const OpKernelType& kernel_type_for_var,
out_layout = out_layout =
out_layout == DataLayout::kAnyLayout ? DataLayout::kNCHW : out_layout; out_layout == DataLayout::kAnyLayout ? DataLayout::kNCHW : out_layout;
auto& pool = platform::DeviceContextPool::Instance();
auto* dev_ctx = dynamic_cast<platform::MKLDNNDeviceContext*>(
pool.Get(expected_kernel_type.place_));
auto& cpu_engine = dev_ctx->GetEngine();
std::vector<int> in_tz = paddle::framework::vectorize2int(in.dims()); std::vector<int> in_tz = paddle::framework::vectorize2int(in.dims());
std::vector<int> out_tz = in_tz; std::vector<int> out_tz = in_tz;
...@@ -142,25 +147,29 @@ void TransDataLayoutFromMKLDNN(const OpKernelType& kernel_type_for_var, ...@@ -142,25 +147,29 @@ void TransDataLayoutFromMKLDNN(const OpKernelType& kernel_type_for_var,
"Input tensor type is not supported: %s", in.type()); "Input tensor type is not supported: %s", in.type());
memory::data_type out_type = in_type; memory::data_type out_type = in_type;
auto in_format = platform::MKLDNNFormatForSize(in_tz.size(), in.format());
auto out_format =
platform::MKLDNNFormatForSize(in_tz.size(), ToMKLDNNFormat(out_layout));
// output tensor has the same dims as input. Reorder don't change dims // output tensor has the same dims as input. Reorder don't change dims
out->Resize(in.dims()); out->Resize(in.dims());
// tempory mem pd fr out , to make reorder if (in_format != out_format) {
auto out_mem_pd = paddle::platform::create_prim_desc_from_dims(
paddle::framework::vectorize2int(out->dims()),
mkldnn::memory::format::blocked, out_type);
if (in.get_mkldnn_prim_desc() != out_mem_pd) {
void* in_data = GetDataFromTensor(in, in_type); void* in_data = GetDataFromTensor(in, in_type);
auto out_data = out->mutable_data(expected_kernel_type.place_, in.type()); auto out_data = out->mutable_data(expected_kernel_type.place_, in.type());
auto in_memory = memory(in.get_mkldnn_prim_desc(), in_data); auto in_memory =
auto out_memory = memory(out_mem_pd, out_data); memory({{{in_tz}, in_type, in_format}, cpu_engine}, in_data);
auto out_memory =
memory({{{out_tz}, out_type, out_format}, cpu_engine}, out_data);
platform::Reorder(in_memory, out_memory); platform::Reorder(in_memory, out_memory);
} else { } else {
out->ShareDataWith(in); out->ShareDataWith(in);
} }
out->set_layout(out_layout); out->set_layout(out_layout);
// reset format since the out tensor will be feed to non-MKLDNN OPkernel
out->set_format(memory::format::format_undef);
#endif #endif
} }
......
...@@ -51,31 +51,13 @@ void TransformData(const OpKernelType &expected_kernel_type, ...@@ -51,31 +51,13 @@ void TransformData(const OpKernelType &expected_kernel_type,
#ifdef PADDLE_WITH_MKLDNN #ifdef PADDLE_WITH_MKLDNN
// Case1 - transform from Non-MKLDNN OPKernel to MKLDNN OPKernel // Case1 - transform from Non-MKLDNN OPKernel to MKLDNN OPKernel
// Just set layout/format. No real transform occur // Just set layout/format. No real transform occur
auto out_format = platform::MKLDNNFormatForSize(in.dims().size(),
ToMKLDNNFormat(lin));
out.ShareDataWith(input_tensor); out.ShareDataWith(input_tensor);
// TODO(jczaja): Remove that once all mkldnn ops out.set_layout(DataLayout::kMKLDNN);
// are modified to work with mkldnn_blocked out.set_format(out_format);
auto mkldnn_fmt = [&](int rank) {
switch (rank) {
case 5:
return mkldnn::memory::format::ncdhw;
case 4:
return mkldnn::memory::format::nchw;
case 3:
return mkldnn::memory::format::ncw;
case 2:
return mkldnn::memory::format::nc;
case 1:
return mkldnn::memory::format::x;
default:
return mkldnn::memory::format::blocked;
}
};
auto out_mem_pd = paddle::platform::create_prim_desc_from_dims(
paddle::framework::vectorize2int(out.dims()),
mkldnn_fmt(out.dims().size()));
out.set_mkldnn_prim_desc(out_mem_pd);
#endif #endif
} else { } else {
// Case2 - transfrom from MKLDNN OPKernel to Non-MKLDNN OPKernel // Case2 - transfrom from MKLDNN OPKernel to Non-MKLDNN OPKernel
......
...@@ -18,6 +18,7 @@ limitations under the License. */ ...@@ -18,6 +18,7 @@ limitations under the License. */
#include <cstring> #include <cstring>
#include <memory> #include <memory>
#include <typeindex> #include <typeindex>
#include <utility>
#include <vector> #include <vector>
#include "paddle/fluid/framework/data_layout.h" #include "paddle/fluid/framework/data_layout.h"
#include "paddle/fluid/framework/ddim.h" #include "paddle/fluid/framework/ddim.h"
...@@ -27,10 +28,6 @@ limitations under the License. */ ...@@ -27,10 +28,6 @@ limitations under the License. */
#include "paddle/fluid/platform/enforce.h" #include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/place.h" #include "paddle/fluid/platform/place.h"
#ifdef PADDLE_WITH_MKLDNN
#include "paddle/fluid/platform/mkldnn_utils.h"
#endif
namespace paddle { namespace paddle {
namespace framework { namespace framework {
...@@ -41,34 +38,10 @@ class Tensor { ...@@ -41,34 +38,10 @@ class Tensor {
#ifdef PADDLE_WITH_MKLDNN #ifdef PADDLE_WITH_MKLDNN
public: public:
// TODO(jczaja): This is depracted and will be removed inline mkldnn::memory::format format() const { return format_; }
inline mkldnn::memory::format format() const {
if (layout_ == DataLayout::kMKLDNN) {
return static_cast<mkldnn::memory::format>(mem_pd_.desc().data.format);
} else {
return mkldnn::memory::format::format_undef;
}
}
// TODO(jczaja): This is depracted and will be removed inline void set_format(const mkldnn::memory::format format) {
inline void set_format( format_ = format;
const mkldnn::memory::format fmt,
mkldnn::memory::data_type data_type = mkldnn::memory::f32) {
mem_pd_ = paddle::platform::create_prim_desc_from_format(
paddle::framework::vectorize2int(dims()), fmt, data_type);
layout_ = DataLayout::kMKLDNN;
}
inline mkldnn::memory::primitive_desc get_mkldnn_prim_desc() const {
return mem_pd_;
}
inline void set_mkldnn_prim_desc(
const mkldnn::memory::primitive_desc& mem_pd) {
// Internally MKL-DNN is just copying (increasing reference counter)
// to shared_ptr. So asignment should be quite cheap
mem_pd_ = mem_pd;
layout_ = DataLayout::kMKLDNN;
} }
protected: protected:
...@@ -76,9 +49,12 @@ class Tensor { ...@@ -76,9 +49,12 @@ class Tensor {
* @brief the detail format of memory block which have layout as kMKLDNN * @brief the detail format of memory block which have layout as kMKLDNN
* *
* @note MKLDNN lib support various memory format like nchw, nhwc, nChw8C, * @note MKLDNN lib support various memory format like nchw, nhwc, nChw8C,
* nChw16c, etc. For a MKLDNN memory block, we store memory descriptor * nChw16c, etc. For a MKLDNN memory block, layout will be set as
* DataLayout::kMKLDNN meanwhile detail memory format will be kept in
* this field.
*/ */
mutable mkldnn::memory::primitive_desc mem_pd_;
mkldnn::memory::format format_ = mkldnn::memory::format::format_undef;
#endif #endif
public: public:
......
...@@ -44,11 +44,6 @@ void TensorCopy(const Tensor& src, const platform::Place& dst_place, ...@@ -44,11 +44,6 @@ void TensorCopy(const Tensor& src, const platform::Place& dst_place,
<< dst_place; << dst_place;
return; return;
} }
#ifdef PADDLE_WITH_MKLDNN
if (src.layout() == DataLayout::kMKLDNN) {
dst->set_mkldnn_prim_desc(src.get_mkldnn_prim_desc());
}
#endif
memory::Copy(boost::get<platform::CPUPlace>(dst_place), dst_ptr, memory::Copy(boost::get<platform::CPUPlace>(dst_place), dst_ptr,
boost::get<platform::CPUPlace>(src_place), src_ptr, size); boost::get<platform::CPUPlace>(src_place), src_ptr, size);
} }
......
...@@ -77,7 +77,8 @@ class EltwiseAddMKLDNNKernel : public framework::OpKernel<T> { ...@@ -77,7 +77,8 @@ class EltwiseAddMKLDNNKernel : public framework::OpKernel<T> {
} else { } else {
functor.RunMidWise(n, pre, post); functor.RunMidWise(n, pre, post);
} }
z->set_mkldnn_prim_desc(x->get_mkldnn_prim_desc()); z->set_layout(DataLayout::kMKLDNN);
z->set_format(x->format());
} else { } else {
PADDLE_ENFORCE(x->layout() == DataLayout::kMKLDNN && PADDLE_ENFORCE(x->layout() == DataLayout::kMKLDNN &&
x->format() != memory::format::format_undef, x->format() != memory::format::format_undef,
...@@ -115,8 +116,7 @@ class EltwiseAddMKLDNNKernel : public framework::OpKernel<T> { ...@@ -115,8 +116,7 @@ class EltwiseAddMKLDNNKernel : public framework::OpKernel<T> {
auto sum_pd = sum::primitive_desc(dst_md, scales, srcs_pd); auto sum_pd = sum::primitive_desc(dst_md, scales, srcs_pd);
// create mkldnn memory for dst // create mkldnn memory for dst
auto dst_mem_pd = sum_pd.dst_primitive_desc(); memory dst_memory = memory(sum_pd.dst_primitive_desc(), z_data);
memory dst_memory = memory(dst_mem_pd, z_data);
std::vector<primitive::at> inputs; std::vector<primitive::at> inputs;
inputs.push_back(srcs[0]); inputs.push_back(srcs[0]);
...@@ -129,7 +129,9 @@ class EltwiseAddMKLDNNKernel : public framework::OpKernel<T> { ...@@ -129,7 +129,9 @@ class EltwiseAddMKLDNNKernel : public framework::OpKernel<T> {
pipeline.push_back(sum_prim); pipeline.push_back(sum_prim);
stream(stream::kind::eager).submit(pipeline).wait(); stream(stream::kind::eager).submit(pipeline).wait();
z->set_mkldnn_prim_desc(dst_mem_pd); z->set_layout(DataLayout::kMKLDNN);
z->set_format(
(memory::format)dst_memory.get_primitive_desc().desc().data.format);
} }
} }
}; };
...@@ -150,19 +152,24 @@ class EltwiseAddMKLDNNGradKernel : public ElemwiseGradKernel<T> { ...@@ -150,19 +152,24 @@ class EltwiseAddMKLDNNGradKernel : public ElemwiseGradKernel<T> {
auto* out = dout; auto* out = dout;
auto *x = dout, *y = dout; auto *x = dout, *y = dout;
auto set_mkldnn_format = [](Tensor* in, const Tensor* out) {
in->set_layout(DataLayout::kMKLDNN);
in->set_format(out->format());
};
if (dx != nullptr && dy != nullptr && dx->dims() == dy->dims()) { if (dx != nullptr && dy != nullptr && dx->dims() == dy->dims()) {
if (dx->dims() == dy->dims()) { if (dx->dims() == dy->dims()) {
auto blas = math::GetBlas<paddle::platform::CPUDeviceContext, T>(ctx); auto blas = math::GetBlas<paddle::platform::CPUDeviceContext, T>(ctx);
if (dx) { if (dx) {
blas.VCOPY(dout->numel(), dout->data<T>(), blas.VCOPY(dout->numel(), dout->data<T>(),
dx->mutable_data<T>(ctx.GetPlace())); dx->mutable_data<T>(ctx.GetPlace()));
dx->set_mkldnn_prim_desc(dout->get_mkldnn_prim_desc()); set_mkldnn_format(dx, dout);
} }
if (dy) { if (dy) {
blas.VCOPY(dout->numel(), dout->data<T>(), blas.VCOPY(dout->numel(), dout->data<T>(),
dy->mutable_data<T>(ctx.GetPlace())); dy->mutable_data<T>(ctx.GetPlace()));
dy->set_mkldnn_prim_desc(dout->get_mkldnn_prim_desc()); set_mkldnn_format(dy, dout);
} }
} }
} else { } else {
......
...@@ -96,7 +96,8 @@ void eltwise_forward(const framework::ExecutionContext &ctx, ...@@ -96,7 +96,8 @@ void eltwise_forward(const framework::ExecutionContext &ctx,
std::vector<int> src_tz = framework::vectorize2int(x->dims()); std::vector<int> src_tz = framework::vectorize2int(x->dims());
auto src_format = x->format(); auto src_format =
src_tz.size() == 2 ? mkldnn::memory::format::nc : x->format();
const std::string key = gethash(src_tz, algorithm); const std::string key = gethash(src_tz, algorithm);
const std::string key_src_data = const std::string key_src_data =
...@@ -126,8 +127,10 @@ void eltwise_forward(const framework::ExecutionContext &ctx, ...@@ -126,8 +127,10 @@ void eltwise_forward(const framework::ExecutionContext &ctx,
if (p_fwd == nullptr) { if (p_fwd == nullptr) {
// create mkldnn memory for input X // create mkldnn memory for input X
auto src_md = platform::MKLDNNMemDesc(
src_tz, platform::MKLDNNGetDataType<T>(), src_format);
auto src_memory = std::shared_ptr<memory>( auto src_memory = std::shared_ptr<memory>(
new memory(x->get_mkldnn_prim_desc(), to_void_cast(x_data))); new memory({src_md, mkldnn_engine}, to_void_cast(x_data)));
// save src_memory to be referred in backward path // save src_memory to be referred in backward path
dev_ctx.SetBlob(key_src_mem, src_memory); dev_ctx.SetBlob(key_src_mem, src_memory);
...@@ -174,7 +177,8 @@ void eltwise_forward(const framework::ExecutionContext &ctx, ...@@ -174,7 +177,8 @@ void eltwise_forward(const framework::ExecutionContext &ctx,
pipeline.push_back(*p_fwd); pipeline.push_back(*p_fwd);
stream(stream::kind::eager).submit(pipeline).wait(); stream(stream::kind::eager).submit(pipeline).wait();
y->set_mkldnn_prim_desc(dst_memory->get_primitive_desc()); y->set_layout(DataLayout::kMKLDNN);
y->set_format(GetMKLDNNFormat(*dst_memory));
} }
template <typename T> template <typename T>
...@@ -192,6 +196,9 @@ void eltwise_grad(const framework::ExecutionContext &ctx, ...@@ -192,6 +196,9 @@ void eltwise_grad(const framework::ExecutionContext &ctx,
std::vector<int> diff_dst_tz = framework::vectorize2int(diff_y->dims()); std::vector<int> diff_dst_tz = framework::vectorize2int(diff_y->dims());
auto diff_y_format =
diff_dst_tz.size() == 2 ? mkldnn::memory::format::nc : diff_y->format();
const std::string key = gethash(diff_dst_tz, algorithm); const std::string key = gethash(diff_dst_tz, algorithm);
const std::string key_src_data = const std::string key_src_data =
key + ctx.op().Input("Out") + "@eltwise_fwd_src_data"; key + ctx.op().Input("Out") + "@eltwise_fwd_src_data";
...@@ -203,8 +210,8 @@ void eltwise_grad(const framework::ExecutionContext &ctx, ...@@ -203,8 +210,8 @@ void eltwise_grad(const framework::ExecutionContext &ctx,
key + std::to_string(*p_src_layout) + "@eltwise_fwd_src_mem"; key + std::to_string(*p_src_layout) + "@eltwise_fwd_src_mem";
const std::string key_fwd_pd = const std::string key_fwd_pd =
key + std::to_string(*p_src_layout) + "@eltwise_fwd_pd"; key + std::to_string(*p_src_layout) + "@eltwise_fwd_pd";
const std::string key_with_layouts = key + std::to_string(*p_src_layout) + const std::string key_with_layouts =
"-" + std::to_string(diff_y->format()); key + std::to_string(*p_src_layout) + "-" + std::to_string(diff_y_format);
const std::string key_diff_src_mem = const std::string key_diff_src_mem =
key_with_layouts + "@eltwise_diff_src_mem"; key_with_layouts + "@eltwise_diff_src_mem";
const std::string key_diff_dst_mem = const std::string key_diff_dst_mem =
...@@ -227,8 +234,10 @@ void eltwise_grad(const framework::ExecutionContext &ctx, ...@@ -227,8 +234,10 @@ void eltwise_grad(const framework::ExecutionContext &ctx,
if (p_grad == nullptr) { if (p_grad == nullptr) {
// create mkldnn memory for input diff_y // create mkldnn memory for input diff_y
auto diff_dst_md = platform::MKLDNNMemDesc(
diff_dst_tz, platform::MKLDNNGetDataType<T>(), diff_y_format);
auto diff_dst_memory = std::shared_ptr<memory>( auto diff_dst_memory = std::shared_ptr<memory>(
new memory(diff_y->get_mkldnn_prim_desc(), to_void_cast(diff_y_data))); new memory({diff_dst_md, mkldnn_engine}, to_void_cast(diff_y_data)));
dev_ctx.SetBlob(key_diff_dst_mem, diff_dst_memory); dev_ctx.SetBlob(key_diff_dst_mem, diff_dst_memory);
// retrieve eltwise primitive desc from device context // retrieve eltwise primitive desc from device context
...@@ -272,7 +281,8 @@ void eltwise_grad(const framework::ExecutionContext &ctx, ...@@ -272,7 +281,8 @@ void eltwise_grad(const framework::ExecutionContext &ctx,
pipeline.push_back(*p_grad); pipeline.push_back(*p_grad);
stream(stream::kind::eager).submit(pipeline).wait(); stream(stream::kind::eager).submit(pipeline).wait();
diff_x->set_mkldnn_prim_desc(diff_src_memory->get_primitive_desc()); diff_x->set_layout(DataLayout::kMKLDNN);
diff_x->set_format(GetMKLDNNFormat(*diff_src_memory));
} }
template <typename T, mkldnn::algorithm algorithm> template <typename T, mkldnn::algorithm algorithm>
......
...@@ -206,14 +206,17 @@ class BatchNormMKLDNNOpKernel : public paddle::framework::OpKernel<T> { ...@@ -206,14 +206,17 @@ class BatchNormMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
if (fuse_with_relu) flags |= mkldnn::fuse_bn_relu; if (fuse_with_relu) flags |= mkldnn::fuse_bn_relu;
// create mkldnn memory from input x tensor // create mkldnn memory from input x tensor
mkldnn::memory::format input_format =
platform::MKLDNNFormatForSize(src_tz.size(), x->format());
// keys for backward pass // keys for backward pass
const std::string key = BatchNormMKLDNNHandler::GetHash( const std::string key = BatchNormMKLDNNHandler::GetHash(
src_tz, epsilon, flags, global_stats, x->format(), src_tz, epsilon, flags, global_stats, input_format,
ctx.op().Output("SavedMean")); ctx.op().Output("SavedMean"));
const std::string key_batch_norm_fwd_pd = key + "@bn_fwd_pd"; const std::string key_batch_norm_fwd_pd = key + "@bn_fwd_pd";
auto user_src_md = x->get_mkldnn_prim_desc().desc(); auto user_src_md = platform::MKLDNNMemDesc(
{src_tz}, platform::MKLDNNGetDataType<T>(), input_format);
// create primitive descriptor for batch norm forward // create primitive descriptor for batch norm forward
using bn_fwd_types = bn_type_traits<mkldnn::batch_normalization_forward>; using bn_fwd_types = bn_type_traits<mkldnn::batch_normalization_forward>;
...@@ -227,8 +230,8 @@ class BatchNormMKLDNNOpKernel : public paddle::framework::OpKernel<T> { ...@@ -227,8 +230,8 @@ class BatchNormMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
BatchNormMKLDNNHandler handler(batch_norm_fwd_pd, dev_ctx, mkldnn_engine, BatchNormMKLDNNHandler handler(batch_norm_fwd_pd, dev_ctx, mkldnn_engine,
key); key);
auto src_memory = handler.AcquireSrcMemory(x->get_mkldnn_prim_desc(), auto src_memory =
to_void_cast(x_data)); handler.AcquireSrcMemory(user_src_md, to_void_cast(x_data));
// crate mkldnn memory for weights(scale/shift) // crate mkldnn memory for weights(scale/shift)
auto scaleshift_memory = auto scaleshift_memory =
...@@ -262,7 +265,8 @@ class BatchNormMKLDNNOpKernel : public paddle::framework::OpKernel<T> { ...@@ -262,7 +265,8 @@ class BatchNormMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
variance_memory, false); variance_memory, false);
} }
y->set_mkldnn_prim_desc(dst_memory->get_primitive_desc()); y->set_layout(DataLayout::kMKLDNN);
y->set_format(platform::GetMKLDNNFormat(*dst_memory));
std::vector<mkldnn::primitive> pipeline; std::vector<mkldnn::primitive> pipeline;
pipeline.push_back(*batch_norm_p); pipeline.push_back(*batch_norm_p);
...@@ -332,6 +336,9 @@ class BatchNormMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> { ...@@ -332,6 +336,9 @@ class BatchNormMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
using bn_bwd_types = bn_type_traits<mkldnn::batch_normalization_backward>; using bn_bwd_types = bn_type_traits<mkldnn::batch_normalization_backward>;
mkldnn::memory::format dst_format =
platform::MKLDNNFormatForSize(src_tz.size(), diff_y->format());
mkldnn::memory::format input_format = mkldnn::memory::format input_format =
platform::MKLDNNFormatForSize(src_tz.size(), x->format()); platform::MKLDNNFormatForSize(src_tz.size(), x->format());
...@@ -339,14 +346,14 @@ class BatchNormMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> { ...@@ -339,14 +346,14 @@ class BatchNormMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
// keys from forward pass // keys from forward pass
const std::string key = BatchNormMKLDNNHandler::GetHash( const std::string key = BatchNormMKLDNNHandler::GetHash(
src_tz, epsilon, flags, false, x->format(), src_tz, epsilon, flags, false, input_format,
ctx.op().Input("SavedMean")); ctx.op().Input("SavedMean"));
const std::string key_batch_norm_fwd_pd = key + "@bn_fwd_pd"; const std::string key_batch_norm_fwd_pd = key + "@bn_fwd_pd";
// keys for primitives reuse // keys for primitives reuse
const std::string key_with_hash = const std::string key_with_hash =
key + BatchNormMKLDNNHandler::GetHash(src_tz, epsilon, flags, false, key + BatchNormMKLDNNHandler::GetHash(src_tz, epsilon, flags, false,
x->format()); input_format);
const std::string key_batch_norm_bwd_p = const std::string key_batch_norm_bwd_p =
key_with_hash + "@batch_norm_bwd_p"; key_with_hash + "@batch_norm_bwd_p";
const std::string key_batch_norm_src_mem_p = const std::string key_batch_norm_src_mem_p =
...@@ -366,8 +373,9 @@ class BatchNormMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> { ...@@ -366,8 +373,9 @@ class BatchNormMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
primitive reorder_diff_dst; primitive reorder_diff_dst;
bool is_diff_dst_reordered = false; bool is_diff_dst_reordered = false;
auto user_diff_dst_memory = auto user_diff_dst_memory = memory(
memory(diff_y->get_mkldnn_prim_desc(), to_void_cast(diff_y_data)); {{{diff_dst_tz}, memory::data_type::f32, dst_format}, mkldnn_engine},
to_void_cast(diff_y_data));
// MKLDNN requires a single piece of memory for scale and shift/bias data // MKLDNN requires a single piece of memory for scale and shift/bias data
const size_t scaleshift_size = 2 * ic; const size_t scaleshift_size = 2 * ic;
...@@ -451,7 +459,10 @@ class BatchNormMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> { ...@@ -451,7 +459,10 @@ class BatchNormMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
dev_ctx.SetBlob(key_batch_norm_diff_dst_mem_p, diff_dst_memory); dev_ctx.SetBlob(key_batch_norm_diff_dst_mem_p, diff_dst_memory);
// set layout/format of output tensors // set layout/format of output tensors
diff_x->set_mkldnn_prim_desc(diff_src_memory->get_primitive_desc()); diff_x->set_layout(DataLayout::kMKLDNN);
diff_x->set_format((memory::format)diff_src_memory->get_primitive_desc()
.desc()
.data.format);
} else { } else {
// primitives already exist // primitives already exist
UpdateMemoryData(dev_ctx, key_batch_norm_src_mem_p, to_void_cast(x_data)); UpdateMemoryData(dev_ctx, key_batch_norm_src_mem_p, to_void_cast(x_data));
...@@ -476,7 +487,10 @@ class BatchNormMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> { ...@@ -476,7 +487,10 @@ class BatchNormMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
} }
// set layout/format of output tensors // set layout/format of output tensors
diff_x->set_mkldnn_prim_desc(diff_src_memory->get_primitive_desc()); diff_x->set_layout(DataLayout::kMKLDNN);
diff_x->set_format((memory::format)diff_src_memory->get_primitive_desc()
.desc()
.data.format);
} }
// execute optional reorder and batch_norm backward primitive // execute optional reorder and batch_norm backward primitive
......
...@@ -210,7 +210,8 @@ class ConcatMKLDNNOpKernel : public paddle::framework::OpKernel<T> { ...@@ -210,7 +210,8 @@ class ConcatMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
stream(stream::kind::eager).submit({*concat_p}).wait(); stream(stream::kind::eager).submit({*concat_p}).wait();
output->set_mkldnn_prim_desc(concat_pd->dst_primitive_desc()); output->set_layout(DataLayout::kMKLDNN);
output->set_format(GetDstMemFormat(*concat_pd));
} }
}; };
} // namespace operators } // namespace operators
......
...@@ -96,8 +96,12 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> { ...@@ -96,8 +96,12 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
auto* bias = ctx.HasInput("Bias") ? ctx.Input<Tensor>("Bias") : nullptr; auto* bias = ctx.HasInput("Bias") ? ctx.Input<Tensor>("Bias") : nullptr;
auto* output = ctx.Output<Tensor>("Output"); auto* output = ctx.Output<Tensor>("Output");
PADDLE_ENFORCE(input->layout() == DataLayout::kMKLDNN); PADDLE_ENFORCE(input->layout() == DataLayout::kMKLDNN &&
PADDLE_ENFORCE(filter->layout() == DataLayout::kMKLDNN); input->format() != memory::format::format_undef,
"Wrong layout/format set for Input tensor");
PADDLE_ENFORCE(filter->layout() == DataLayout::kMKLDNN &&
filter->format() != memory::format::format_undef,
"Wrong layout/format set for Filter tensor");
PADDLE_ENFORCE(input->dims().size() == 4 || input->dims().size() == 5, PADDLE_ENFORCE(input->dims().size() == 4 || input->dims().size() == 5,
"Input must be with 4 or 5 dimensions, i.e. NCHW or NCDHW"); "Input must be with 4 or 5 dimensions, i.e. NCHW or NCDHW");
PADDLE_ENFORCE(filter->dims().size() == 4 || filter->dims().size() == 5, PADDLE_ENFORCE(filter->dims().size() == 4 || filter->dims().size() == 5,
...@@ -144,19 +148,14 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> { ...@@ -144,19 +148,14 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
std::vector<primitive> pipeline; std::vector<primitive> pipeline;
// For convolution with groups we need to recreate primitive descriptor auto src_format = input->format();
// as Paddle tensor is not having group dims while mkldnn treats mkldnn::memory::format weights_format =
// group as another dimensions GetWeightsFormat(filter->format(), g, is_conv3d);
mkldnn::memory::primitive_desc user_weights_mpd =
filter->get_mkldnn_prim_desc(); auto user_src_md = platform::MKLDNNMemDesc(
if (g > 1) { {src_tz}, platform::MKLDNNGetDataType<T>(), src_format);
mkldnn::memory::format weights_format = auto user_weights_md = platform::MKLDNNMemDesc(
GetWeightsFormat(filter->format(), g, is_conv3d); {weights_tz}, platform::MKLDNNGetDataType<T>(), weights_format);
auto user_weights_md = platform::MKLDNNMemDesc(
{weights_tz}, platform::MKLDNNGetDataType<T>(), weights_format);
user_weights_mpd =
mkldnn::memory::primitive_desc(user_weights_md, mkldnn_engine);
}
/* create memory descriptor for convolution without specified format /* create memory descriptor for convolution without specified format
* ('any') which lets a primitive (convolution in this case) choose * ('any') which lets a primitive (convolution in this case) choose
...@@ -166,7 +165,7 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> { ...@@ -166,7 +165,7 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
auto chosen_memory_format = auto chosen_memory_format =
platform::data_format_to_memory_format(data_format); platform::data_format_to_memory_format(data_format);
mkldnn::memory::format weights_format = mkldnn::memory::format::any; weights_format = mkldnn::memory::format::any;
// Check the format for user's special output // Check the format for user's special output
if (chosen_memory_format != mkldnn::memory::format::any) { if (chosen_memory_format != mkldnn::memory::format::any) {
if (is_conv3d) { if (is_conv3d) {
...@@ -206,10 +205,10 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> { ...@@ -206,10 +205,10 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
platform::ConvMKLDNNHandler handler(conv_pd, dev_ctx, mkldnn_engine, key); platform::ConvMKLDNNHandler handler(conv_pd, dev_ctx, mkldnn_engine, key);
// create mkldnn memory from input tensors (data/weights) // create mkldnn memory from input tensors (data/weights)
auto user_src_memory_p = handler.AcquireSrcMemory( auto user_src_memory_p =
input->get_mkldnn_prim_desc(), to_void_cast<T>(input_data)); handler.AcquireSrcMemory(user_src_md, to_void_cast<T>(input_data));
auto user_weights_memory_p = handler.AcquireWeightsMemory( auto user_weights_memory_p = handler.AcquireWeightsMemory(
user_weights_mpd, to_void_cast<T>(filter_data)); user_weights_md, to_void_cast<T>(filter_data));
// create reorder primitive if the input format is not the preferred one // create reorder primitive if the input format is not the preferred one
auto src_memory_p = auto src_memory_p =
...@@ -282,7 +281,8 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> { ...@@ -282,7 +281,8 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
pipeline.push_back(*conv_p); pipeline.push_back(*conv_p);
stream(stream::kind::eager).submit(pipeline).wait(); stream(stream::kind::eager).submit(pipeline).wait();
output->set_mkldnn_prim_desc(dst_memory_p->get_primitive_desc()); output->set_layout(DataLayout::kMKLDNN);
output->set_format(GetMKLDNNFormat(*dst_memory_p));
} }
void ComputeINT8(const paddle::framework::ExecutionContext& ctx) const { void ComputeINT8(const paddle::framework::ExecutionContext& ctx) const {
const bool is_test = ctx.Attr<bool>("is_test"); const bool is_test = ctx.Attr<bool>("is_test");
...@@ -948,8 +948,8 @@ class ConvMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> { ...@@ -948,8 +948,8 @@ class ConvMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
// push primitive to stream and wait until it's executed // push primitive to stream and wait until it's executed
pipeline.push_back(*conv_bwd_weights_p); pipeline.push_back(*conv_bwd_weights_p);
auto filter_grad_mpd = diff_weights_memory_p->get_primitive_desc(); filter_grad->set_layout(DataLayout::kMKLDNN);
filter_grad->set_mkldnn_prim_desc(filter_grad_mpd); filter_grad->set_format(GetMKLDNNFormat(*diff_weights_memory_p));
} }
if (input_grad) { if (input_grad) {
...@@ -972,7 +972,8 @@ class ConvMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> { ...@@ -972,7 +972,8 @@ class ConvMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
pipeline.push_back(*conv_bwd_data_p); pipeline.push_back(*conv_bwd_data_p);
input_grad->set_mkldnn_prim_desc(diff_src_memory_p->get_primitive_desc()); input_grad->set_layout(DataLayout::kMKLDNN);
input_grad->set_format(GetMKLDNNFormat(*diff_src_memory_p));
} }
stream(stream::kind::eager).submit(pipeline).wait(); stream(stream::kind::eager).submit(pipeline).wait();
} }
......
...@@ -221,7 +221,8 @@ class ConvTransposeMKLDNNOpKernel : public paddle::framework::OpKernel<T> { ...@@ -221,7 +221,8 @@ class ConvTransposeMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
pipeline.push_back(*conv_p); pipeline.push_back(*conv_p);
mkldnn::stream(mkldnn::stream::kind::eager).submit(pipeline).wait(); mkldnn::stream(mkldnn::stream::kind::eager).submit(pipeline).wait();
output->set_mkldnn_prim_desc(dst_memory_p->get_primitive_desc()); output->set_layout(DataLayout::kMKLDNN);
output->set_format(platform::GetMKLDNNFormat(*dst_memory_p));
} }
private: private:
......
...@@ -42,12 +42,8 @@ class GaussianMKLDNNKernel : public paddle::framework::OpKernel<T> { ...@@ -42,12 +42,8 @@ class GaussianMKLDNNKernel : public paddle::framework::OpKernel<T> {
// The format of output is set as the mkldnn's format // The format of output is set as the mkldnn's format
// TODO(@mozga-intel) The format of matrix sets inside the another layers. // TODO(@mozga-intel) The format of matrix sets inside the another layers.
// TODO(jczaja): Remove this hack after checking performance on block layout tensor->set_layout(DataLayout::kMKLDNN);
tensor->set_format(mkldnn::memory::format::oihw);
auto tensor_mem_pd = paddle::platform::create_prim_desc_from_dims(
paddle::framework::vectorize2int(tensor->dims()),
mkldnn::memory::format::oihw);
tensor->set_mkldnn_prim_desc(tensor_mem_pd);
} }
}; };
} // namespace operators } // namespace operators
......
...@@ -81,7 +81,10 @@ class LRNMKLDNNOpKernel : public paddle::framework::OpKernel<T> { ...@@ -81,7 +81,10 @@ class LRNMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
auto e_mid = framework::EigenTensor<T, 4>::From(*mid); auto e_mid = framework::EigenTensor<T, 4>::From(*mid);
e_mid = e_mid.constant(k); e_mid = e_mid.constant(k);
auto src_md = x->get_mkldnn_prim_desc().desc(); auto dims = paddle::framework::vectorize2int(x->dims());
auto src_md = paddle::platform::MKLDNNMemDesc(
dims, mkldnn::memory::data_type::f32, x->format());
auto forward_desc = mkldnn::lrn_forward::desc{mkldnn::prop_kind::forward, auto forward_desc = mkldnn::lrn_forward::desc{mkldnn::prop_kind::forward,
mkldnn::lrn_across_channels, mkldnn::lrn_across_channels,
...@@ -91,7 +94,7 @@ class LRNMKLDNNOpKernel : public paddle::framework::OpKernel<T> { ...@@ -91,7 +94,7 @@ class LRNMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
beta, beta,
k}; k};
auto src_memory_pd = x->get_mkldnn_prim_desc(); auto src_memory_pd = mkldnn::memory::primitive_desc{src_md, mkldnn_engine};
if (!is_test) { if (!is_test) {
const std::string key = ctx.op().Output("Out"); const std::string key = ctx.op().Output("Out");
...@@ -108,15 +111,16 @@ class LRNMKLDNNOpKernel : public paddle::framework::OpKernel<T> { ...@@ -108,15 +111,16 @@ class LRNMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
src_memory->set_data_handle( src_memory->set_data_handle(
static_cast<void*>(const_cast<T*>(input_data))); static_cast<void*>(const_cast<T*>(input_data)));
auto dst_memory_pd = forward_pd->dst_primitive_desc(); auto dst_memory = mkldnn::memory(forward_pd->dst_primitive_desc(),
auto dst_memory = static_cast<void*>(output_data));
mkldnn::memory(dst_memory_pd, static_cast<void*>(output_data));
auto workspace_memory = insert_to_context<mkldnn::memory>( auto workspace_memory = insert_to_context<mkldnn::memory>(
key_workspace_memory, dev_ctx, key_workspace_memory, dev_ctx,
forward_pd->workspace_primitive_desc()); forward_pd->workspace_primitive_desc());
run_primitive(*forward_pd, *src_memory, *workspace_memory, dst_memory); run_primitive(*forward_pd, *src_memory, *workspace_memory, dst_memory);
out->set_mkldnn_prim_desc(dst_memory_pd);
out->set_layout(framework::DataLayout::kMKLDNN);
out->set_format(platform::GetMKLDNNFormat(dst_memory));
} else { } else {
auto forward_pd = auto forward_pd =
mkldnn::lrn_forward::primitive_desc{forward_desc, mkldnn_engine}; mkldnn::lrn_forward::primitive_desc{forward_desc, mkldnn_engine};
...@@ -124,12 +128,13 @@ class LRNMKLDNNOpKernel : public paddle::framework::OpKernel<T> { ...@@ -124,12 +128,13 @@ class LRNMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
src_memory_pd, static_cast<void*>(const_cast<T*>(input_data))}; src_memory_pd, static_cast<void*>(const_cast<T*>(input_data))};
auto workspace_memory = auto workspace_memory =
mkldnn::memory{forward_pd.workspace_primitive_desc()}; mkldnn::memory{forward_pd.workspace_primitive_desc()};
auto dst_memory_pd = forward_pd.dst_primitive_desc();
auto dst_memory = mkldnn::memory(forward_pd.dst_primitive_desc(), auto dst_memory = mkldnn::memory(forward_pd.dst_primitive_desc(),
static_cast<void*>(output_data)); static_cast<void*>(output_data));
run_primitive(forward_pd, src_memory, workspace_memory, dst_memory); run_primitive(forward_pd, src_memory, workspace_memory, dst_memory);
out->set_mkldnn_prim_desc(dst_memory_pd);
out->set_layout(framework::DataLayout::kMKLDNN);
out->set_format(platform::GetMKLDNNFormat(dst_memory));
} }
} }
}; };
......
...@@ -158,14 +158,6 @@ class SoftmaxMKLDNNKernel : public paddle::framework::OpKernel<T> { ...@@ -158,14 +158,6 @@ class SoftmaxMKLDNNKernel : public paddle::framework::OpKernel<T> {
auto softmax_p = auto softmax_p =
handler.AcquireSoftmax(softmax_dst_memory_p, softmax_src_memory_p); handler.AcquireSoftmax(softmax_dst_memory_p, softmax_src_memory_p);
// We cannot use softmax_dst_memory_p to get prim desc as
// it contains flattened dims (2D) while output tensor can
// have 2,3,4+ dims
auto output_mem_pd = paddle::platform::create_prim_desc_from_dims(
paddle::framework::vectorize2int(output->dims()),
mkldnn::memory::format::blocked);
output->set_mkldnn_prim_desc(output_mem_pd);
std::vector<primitive> pipeline{ std::vector<primitive> pipeline{
*(static_cast<softmax_forward::primitive*>(softmax_p.get()))}; *(static_cast<softmax_forward::primitive*>(softmax_p.get()))};
stream(stream::kind::eager).submit(pipeline).wait(); stream(stream::kind::eager).submit(pipeline).wait();
......
...@@ -106,12 +106,12 @@ class SumMKLDNNOpKernel : public paddle::framework::OpKernel<T> { ...@@ -106,12 +106,12 @@ class SumMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
memory::desc(dst_tz, memory::data_type::f32, memory::format::any); memory::desc(dst_tz, memory::data_type::f32, memory::format::any);
auto sum_pd = sum::primitive_desc(dst_md, scales, srcs_mpd); auto sum_pd = sum::primitive_desc(dst_md, scales, srcs_mpd);
auto dst_mem_pd = sum_pd.dst_primitive_desc();
std::shared_ptr<memory> dst_mem; std::shared_ptr<memory> dst_mem;
if (in_place) { if (in_place) {
dst_mem.reset(new memory(dst_mem_pd)); dst_mem.reset(new memory(sum_pd.dst_primitive_desc()));
} else { } else {
dst_mem.reset(new memory(dst_mem_pd, output_data)); dst_mem.reset(new memory(sum_pd.dst_primitive_desc(), output_data));
} }
std::vector<mkldnn::primitive::at> inputs; std::vector<mkldnn::primitive::at> inputs;
for (size_t i = 0; i < srcs_mem.size(); ++i) { for (size_t i = 0; i < srcs_mem.size(); ++i) {
...@@ -136,7 +136,8 @@ class SumMKLDNNOpKernel : public paddle::framework::OpKernel<T> { ...@@ -136,7 +136,8 @@ class SumMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
if (in_place) pipeline.push_back(reorder_prim); if (in_place) pipeline.push_back(reorder_prim);
stream(stream::kind::eager).submit(pipeline).wait(); stream(stream::kind::eager).submit(pipeline).wait();
output->set_mkldnn_prim_desc(dst_mem_pd); output->set_layout(DataLayout::kMKLDNN);
output->set_format(output_format);
} else { // Fallback to naive version } else { // Fallback to naive version
// TODO(@mozga-intel) Add MKLDNN SelectedRows & LoDTensorArray support // TODO(@mozga-intel) Add MKLDNN SelectedRows & LoDTensorArray support
SumKernel<CPUDeviceContext, T> reference_kernel; SumKernel<CPUDeviceContext, T> reference_kernel;
......
...@@ -52,7 +52,7 @@ class TransposeMKLDNNOpKernel : public paddle::framework::OpKernel<T> { ...@@ -52,7 +52,7 @@ class TransposeMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
mkldnn_engine, key); mkldnn_engine, key);
auto transpose_src_memory_p = handler.AcquireSrcMemory( auto transpose_src_memory_p = handler.AcquireSrcMemory(
input->get_mkldnn_prim_desc(), platform::to_void_cast<T>(input_data)); input->format(), platform::to_void_cast<T>(input_data));
auto transpose_dst_memory_p = auto transpose_dst_memory_p =
handler.AcquireDstMemory(output, ctx.GetPlace()); handler.AcquireDstMemory(output, ctx.GetPlace());
auto transpose_p = handler.AcquireTranspose(transpose_dst_memory_p, auto transpose_p = handler.AcquireTranspose(transpose_dst_memory_p,
...@@ -62,14 +62,8 @@ class TransposeMKLDNNOpKernel : public paddle::framework::OpKernel<T> { ...@@ -62,14 +62,8 @@ class TransposeMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
pipeline.push_back(*transpose_p); pipeline.push_back(*transpose_p);
mkldnn::stream(mkldnn::stream::kind::eager).submit(pipeline).wait(); mkldnn::stream(mkldnn::stream::kind::eager).submit(pipeline).wait();
// Transpose did change logical dimensions of Tensor, but reorder does not. output->set_layout(DataLayout::kNCHW);
// Reorder does change only physical layout eg. format , strides output->set_format(mkldnn::memory::format::format_undef);
// so we need to create new primitive descriptor with changed logical layout
// so it match output shape
auto output_mem_pd = paddle::platform::create_prim_desc_from_dims(
paddle::framework::vectorize2int(output->dims()),
mkldnn::memory::format::blocked);
output->set_mkldnn_prim_desc(output_mem_pd);
} }
}; };
...@@ -134,9 +128,8 @@ class TransposeMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> { ...@@ -134,9 +128,8 @@ class TransposeMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
platform::TransposeMKLDNNHandler handler(nchw_tz, reversed_axis, dev_ctx, platform::TransposeMKLDNNHandler handler(nchw_tz, reversed_axis, dev_ctx,
mkldnn_engine, key); mkldnn_engine, key);
auto transpose_src_memory_p = auto transpose_src_memory_p = handler.AcquireSrcMemory(
handler.AcquireSrcMemory(out_grad->get_mkldnn_prim_desc(), out_grad->format(), platform::to_void_cast<T>(out_grad_data));
platform::to_void_cast<T>(out_grad_data));
auto transpose_dst_memory_p = auto transpose_dst_memory_p =
handler.AcquireDstMemory(x_grad, ctx.GetPlace()); handler.AcquireDstMemory(x_grad, ctx.GetPlace());
auto transpose_p = handler.AcquireTranspose(transpose_dst_memory_p, auto transpose_p = handler.AcquireTranspose(transpose_dst_memory_p,
...@@ -145,15 +138,6 @@ class TransposeMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> { ...@@ -145,15 +138,6 @@ class TransposeMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
std::vector<mkldnn::primitive> pipeline; std::vector<mkldnn::primitive> pipeline;
pipeline.push_back(*transpose_p); pipeline.push_back(*transpose_p);
mkldnn::stream(mkldnn::stream::kind::eager).submit(pipeline).wait(); mkldnn::stream(mkldnn::stream::kind::eager).submit(pipeline).wait();
// Transpose did change logical dimensions of Tensor, but reorder does not.
// Reorder does change only physical layout eg. format , strides
// so we need to create new primitive descriptor with changed logical layout
// so it match output shape
auto x_grad_mem_pd = paddle::platform::create_prim_desc_from_dims(
paddle::framework::vectorize2int(x_grad->dims()),
mkldnn::memory::format::blocked);
x_grad->set_mkldnn_prim_desc(x_grad_mem_pd);
} }
}; };
......
...@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and ...@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#pragma once #pragma once
#include <memory>
#include <string> #include <string>
#include <vector> #include <vector>
#include "paddle/fluid/framework/data_layout_transform.h" #include "paddle/fluid/framework/data_layout_transform.h"
...@@ -39,45 +40,6 @@ class MKLDNNHandler { ...@@ -39,45 +40,6 @@ class MKLDNNHandler {
return this->AcquireMemory(md, ptr, "@user_src_mem_p"); return this->AcquireMemory(md, ptr, "@user_src_mem_p");
} }
// TODO(jczaja): extract common part and make AcquireMemory
std::shared_ptr<mkldnn::memory> AcquireSrcMemory(
const mkldnn::memory::primitive_desc& mpd, void* ptr) {
auto local_key = key_ + "@user_src_mem_p";
auto mem_p =
std::static_pointer_cast<mkldnn::memory>(dev_ctx_.GetBlob(local_key));
PADDLE_ENFORCE((mem_p != nullptr) || (is_reusing_ == false),
" find mem primitive in device context");
if (mem_p == nullptr) {
mem_p = std::make_shared<mkldnn::memory>(mpd, ptr);
dev_ctx_.SetBlob(local_key, mem_p);
} else {
mem_p->set_data_handle(ptr);
// Mark that reusing happenned. All primitives from operator instance
// should be reused or none of them. So we check consistency
is_reusing_ = true;
}
return mem_p;
}
std::shared_ptr<mkldnn::memory> AcquireWeightsMemory(
const mkldnn::memory::primitive_desc& mpd, void* ptr) {
auto local_key = key_ + "@user_weights_mem_p";
auto mem_p =
std::static_pointer_cast<mkldnn::memory>(dev_ctx_.GetBlob(local_key));
PADDLE_ENFORCE((mem_p != nullptr) || (is_reusing_ == false),
" find mem primitive in device context");
if (mem_p == nullptr) {
mem_p = std::make_shared<mkldnn::memory>(mpd, ptr);
dev_ctx_.SetBlob(local_key, mem_p);
} else {
mem_p->set_data_handle(ptr);
// Mark that reusing happenned. All primitives from operator instance
// should be reused or none of them. So we check consistency
is_reusing_ = true;
}
return mem_p;
}
std::shared_ptr<mkldnn::memory> AcquireWeightsMemory( std::shared_ptr<mkldnn::memory> AcquireWeightsMemory(
const mkldnn::memory::desc& md, void* ptr, const mkldnn::memory::desc& md, void* ptr,
user_function custom_func = {}) { user_function custom_func = {}) {
...@@ -315,7 +277,37 @@ class TransposeMKLDNNHandler : public MKLDNNHandler { ...@@ -315,7 +277,37 @@ class TransposeMKLDNNHandler : public MKLDNNHandler {
mkldnn::engine engine, const std::string& base_key) mkldnn::engine engine, const std::string& base_key)
: platform::MKLDNNHandler(dev_ctx, engine, base_key), : platform::MKLDNNHandler(dev_ctx, engine, base_key),
dims_(dims), dims_(dims),
axis_(axis) {} axis_(axis),
logical_axis_(dims.size(), 0) {}
std::shared_ptr<mkldnn::memory> AcquireSrcMemory(
const mkldnn::memory::format& fmt, void* ptr) {
auto local_key = key_ + "@user_src_mem_p";
auto mem_p =
std::static_pointer_cast<mkldnn::memory>(dev_ctx_.GetBlob(local_key));
PADDLE_ENFORCE((mem_p != nullptr) || (is_reusing_ == false),
" find mem primitive in device context");
if (mem_p == nullptr) {
// Make memory descriptor using input format, unless it
// cannot be trusted (nchw) then make up memory fmt manually
for (size_t i = 0; i < logical_axis_.size(); ++i) {
logical_axis_[i] = i;
}
auto src_md = fmt != mkldnn::memory::format::nchw
? platform::MKLDNNMemDesc(
dims_, platform::MKLDNNGetDataType<float>(), fmt)
: Axis2MemoryDesc(dims_, logical_axis_);
mem_p = std::make_shared<mkldnn::memory>(
mkldnn::memory::primitive_desc{src_md, engine_}, ptr);
dev_ctx_.SetBlob(local_key, mem_p);
} else {
mem_p->set_data_handle(ptr);
// Mark that reusing happenned. All primitives from operator instance
// should be reused or none of them. So we check consistency
is_reusing_ = true;
}
return mem_p;
}
std::shared_ptr<mkldnn::memory> AcquireDstMemory(framework::Tensor* output, std::shared_ptr<mkldnn::memory> AcquireDstMemory(framework::Tensor* output,
platform::Place place) { platform::Place place) {
...@@ -400,6 +392,7 @@ class TransposeMKLDNNHandler : public MKLDNNHandler { ...@@ -400,6 +392,7 @@ class TransposeMKLDNNHandler : public MKLDNNHandler {
private: private:
std::vector<int> dims_; std::vector<int> dims_;
std::vector<int> axis_; std::vector<int> axis_;
std::vector<int> logical_axis_;
}; };
template <class forward_t, class backward_data_t, class backward_weights_t> template <class forward_t, class backward_data_t, class backward_weights_t>
......
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <mkldnn.h>
#include <string>
namespace paddle {
namespace platform {
inline mkldnn::memory::primitive_desc create_prim_desc_from_dims(
const std::vector<int>& ltz, mkldnn::memory::format fmt,
mkldnn::memory::data_type data_type = mkldnn::memory::data_type::f32) {
mkldnn_memory_desc_t mem_fmt;
mem_fmt.primitive_kind = mkldnn_memory;
mem_fmt.ndims = ltz.size();
for (unsigned int i = 0; i < ltz.size(); ++i) {
mem_fmt.dims[i] = ltz[i]; // logical dimensions (nchw format,
// regardless physical layout)
}
mem_fmt.data_type = static_cast<mkldnn_data_type_t>(data_type);
mem_fmt.format = static_cast<mkldnn_memory_format_t>(fmt);
unsigned int total_stride = 1;
for (int i = ltz.size() - 1; i >= 0; --i) {
mem_fmt.layout_desc.blocking.padding_dims[i] =
ltz[i]; // logical dimensions (nchw format, regardless physical
// layout)
mem_fmt.layout_desc.blocking.block_dims[i] = 1;
mem_fmt.layout_desc.blocking.offset_padding_to_data[i] = 0; // no offset
mem_fmt.layout_desc.blocking.strides[0][i] = total_stride;
mem_fmt.layout_desc.blocking.strides[1][i] = 1;
total_stride *= ltz[i];
}
mem_fmt.layout_desc.blocking.offset_padding = 0; // no initial offset
auto& pool = platform::DeviceContextPool::Instance();
auto place = paddle::platform::CPUPlace();
auto* dev_ctx = dynamic_cast<platform::MKLDNNDeviceContext*>(pool.Get(place));
auto& cpu_engine = dev_ctx->GetEngine();
return mkldnn::memory::primitive_desc(mem_fmt, cpu_engine);
}
inline mkldnn::memory::primitive_desc create_prim_desc_from_format(
const std::vector<int>& ltz, const mkldnn::memory::format format,
const mkldnn::memory::data_type data_type) {
auto md = mkldnn::memory::desc({ltz}, data_type, format);
auto& pool = platform::DeviceContextPool::Instance();
auto place = paddle::platform::CPUPlace();
auto dev_ctx = dynamic_cast<platform::MKLDNNDeviceContext*>(pool.Get(place));
PADDLE_ENFORCE_NOT_NULL(dev_ctx, "Could not get valid device");
auto& cpu_engine = dev_ctx->GetEngine();
return mkldnn::memory::primitive_desc(md, cpu_engine);
}
} // namespace platform
} // namespace paddle
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册