提交 50326563 编写于 作者: L Leo Zhao 提交者: Tao Luo

enable mkldnn primitive reuse for platform reorder (#17826)

test=develop
上级 7611208a
......@@ -13,11 +13,13 @@
// limitations under the License.
#include "paddle/fluid/framework/data_layout_transform.h"
#include <string>
#include <vector>
#include "paddle/fluid/operators/math/math_function.h"
#ifdef PADDLE_WITH_MKLDNN
#include "paddle/fluid/platform/mkldnn_helper.h"
#include "paddle/fluid/platform/mkldnn_reuse.h"
#endif
namespace paddle {
......@@ -145,7 +147,6 @@ void TransDataLayoutFromMKLDNN(const OpKernelType& kernel_type_for_var,
memory::data_type in_type = ToMKLDNNDataType(in.type());
PADDLE_ENFORCE(in_type != memory::data_type::data_undef,
"Input tensor type is not supported: %s", in.type());
memory::data_type out_type = in_type;
auto in_format = platform::MKLDNNFormatForSize(in_tz.size(), in.format());
auto out_format =
......@@ -156,14 +157,21 @@ void TransDataLayoutFromMKLDNN(const OpKernelType& kernel_type_for_var,
if (in_format != out_format) {
void* in_data = GetDataFromTensor(in, in_type);
auto out_data = out->mutable_data(expected_kernel_type.place_, in.type());
const std::string key = platform::ReorderMKLDNNHandler::GetHash(
in_tz, in_format, out_format, std::to_string(in_type));
auto in_memory =
memory({{{in_tz}, in_type, in_format}, cpu_engine}, in_data);
auto out_memory =
memory({{{out_tz}, out_type, out_format}, cpu_engine}, out_data);
platform::ReorderMKLDNNHandler handler(in_tz, in.type(), in_type, *dev_ctx,
cpu_engine, key);
platform::Reorder(in_memory, out_memory);
auto reorder_src_memory_p = handler.AcquireSrcMemory(in_format, in_data);
auto reorder_dst_memory_p =
handler.AcquireDstMemory(out, out_format, expected_kernel_type.place_);
auto reorder_p =
handler.AcquireReorder(reorder_dst_memory_p, reorder_src_memory_p);
std::vector<mkldnn::primitive> pipeline;
pipeline.push_back(*reorder_p);
mkldnn::stream(mkldnn::stream::kind::eager).submit(pipeline).wait();
} else {
out->ShareDataWith(in);
}
......
......@@ -400,6 +400,93 @@ class TransposeMKLDNNHandler : public MKLDNNHandler {
std::vector<int> logical_axis_;
};
class ReorderMKLDNNHandler : public MKLDNNHandler {
public:
ReorderMKLDNNHandler(std::vector<int>& dims, // NOLINT
framework::proto::VarType::Type vtype,
mkldnn::memory::data_type dtype,
const platform::MKLDNNDeviceContext& dev_ctx,
mkldnn::engine engine, const std::string& base_key)
: platform::MKLDNNHandler(dev_ctx, engine, base_key),
dims_(dims),
vtype_(vtype),
dtype_(dtype) {}
std::shared_ptr<mkldnn::memory> AcquireSrcMemory(
const mkldnn::memory::format& fmt, void* ptr) {
auto local_key = key_ + "@user_src_mem_p";
auto mem_p =
std::static_pointer_cast<mkldnn::memory>(dev_ctx_.GetBlob(local_key));
PADDLE_ENFORCE((mem_p != nullptr) || (is_reusing_ == false),
" find mem primitive in device context");
if (mem_p == nullptr) {
auto src_md = platform::MKLDNNMemDesc(dims_, dtype_, fmt);
mem_p = std::make_shared<mkldnn::memory>(
mkldnn::memory::primitive_desc{src_md, engine_}, ptr);
dev_ctx_.SetBlob(local_key, mem_p);
} else {
mem_p->set_data_handle(ptr);
is_reusing_ = true;
}
return mem_p;
}
std::shared_ptr<mkldnn::memory> AcquireDstMemory(
framework::Tensor* output, const mkldnn::memory::format& fmt,
platform::Place place) {
auto local_key = key_ + "@user_dst_mem_p";
auto mem_p =
std::static_pointer_cast<mkldnn::memory>(dev_ctx_.GetBlob(local_key));
PADDLE_ENFORCE((mem_p != nullptr) || (is_reusing_ == false),
" find mem primitive in device context");
if (mem_p == nullptr) {
auto dst_md = platform::MKLDNNMemDesc(dims_, dtype_, fmt);
auto dst_mdp = mkldnn::memory::primitive_desc{dst_md, engine_};
auto dst_data = output->mutable_data(place, vtype_);
mem_p = std::make_shared<mkldnn::memory>(dst_mdp, dst_data);
dev_ctx_.SetBlob(local_key, mem_p);
} else {
auto dst_data = output->mutable_data(place, vtype_);
mem_p->set_data_handle(dst_data);
is_reusing_ = true;
}
return mem_p;
}
std::shared_ptr<mkldnn::reorder> AcquireReorder(
std::shared_ptr<mkldnn::memory> dst_memory_p,
std::shared_ptr<mkldnn::memory> src_memory_p) {
auto prim_key = key_ + "@reorder_p";
auto reorder_p =
std::static_pointer_cast<mkldnn::reorder>(dev_ctx_.GetBlob(prim_key));
PADDLE_ENFORCE((reorder_p != nullptr) || (is_reusing_ == false),
"Fail to find convolution primitive in device context");
if (reorder_p == nullptr) {
reorder_p =
std::make_shared<mkldnn::reorder>(*(src_memory_p), *(dst_memory_p));
dev_ctx_.SetBlob(prim_key, reorder_p);
} else {
is_reusing_ = true;
}
return reorder_p;
}
static std::string GetHash(std::vector<int>& shape, // NOLINT
mkldnn::memory::format in_fmt,
mkldnn::memory::format out_fmt,
const std::string& suffix) {
return dims2str(shape) + std::to_string(in_fmt) + "->" +
std::to_string(out_fmt) + "#" + suffix;
}
private:
std::vector<int> dims_;
framework::proto::VarType::Type vtype_;
mkldnn::memory::data_type dtype_;
};
template <typename T>
struct convolutional_algorithm;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册