提交 709d9e3c 编写于 作者: J Jacek Czaja

- Added reusing MKL-DNN primitives for Transpose MKL-DNN op

test=develop
上级 b37fb7a6
...@@ -32,7 +32,7 @@ class TransposeMKLDNNOpKernel : public paddle::framework::OpKernel<T> { ...@@ -32,7 +32,7 @@ class TransposeMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
const bool is_test = ctx.Attr<bool>("is_test"); const bool is_test = ctx.Attr<bool>("is_test");
PADDLE_ENFORCE( PADDLE_ENFORCE(
is_test == true, is_test == true,
"ConvTransposeMKLDNN works only for inference!. Set is_test = True"); "TransposeMKLDNN works only for inference!. Set is_test = True");
auto& dev_ctx = auto& dev_ctx =
ctx.template device_context<paddle::platform::MKLDNNDeviceContext>(); ctx.template device_context<paddle::platform::MKLDNNDeviceContext>();
const auto& mkldnn_engine = dev_ctx.GetEngine(); const auto& mkldnn_engine = dev_ctx.GetEngine();
...@@ -47,69 +47,24 @@ class TransposeMKLDNNOpKernel : public paddle::framework::OpKernel<T> { ...@@ -47,69 +47,24 @@ class TransposeMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
return; return;
} }
std::vector<int> nchw_axis(ndims, 0);
for (size_t i = 0; i < nchw_axis.size(); ++i) {
nchw_axis[i] = i;
}
std::vector<int> nchw_tz = paddle::framework::vectorize2int(input->dims()); std::vector<int> nchw_tz = paddle::framework::vectorize2int(input->dims());
std::string data_format = ctx.Attr<std::string>("data_format");
auto src_md =
input->format() != mkldnn::memory::format::nchw
? platform::MKLDNNMemDesc(nchw_tz, platform::MKLDNNGetDataType<T>(),
input->format())
: Axis2MemoryDesc(nchw_tz, nchw_axis);
this->TransposeKernel(ctx.GetPlace(), Axis2MemoryDesc(nchw_tz, axis),
src_md, output, input_data, nchw_tz, mkldnn_engine);
}
protected:
mkldnn::memory::desc Axis2MemoryDesc(std::vector<int>& nchw_tz,
std::vector<int>& axis) const {
mkldnn_memory_desc_t mem_fmt;
mem_fmt.primitive_kind = mkldnn_memory;
mem_fmt.ndims = axis.size();
for (unsigned int i = 0; i < nchw_tz.size(); ++i) {
mem_fmt.dims[i] = nchw_tz[i]; // logical dimensions (nchw format,
// regardless physical layout)
}
mem_fmt.data_type = mkldnn_f32;
mem_fmt.format = mkldnn_blocked;
unsigned int total_stride = 1;
for (int i = nchw_tz.size() - 1; i >= 0; --i) {
mem_fmt.layout_desc.blocking.padding_dims[i] =
nchw_tz[i]; // logical dimensions (nchw format, regardless physical
// layout)
mem_fmt.layout_desc.blocking.block_dims[i] = 1;
mem_fmt.layout_desc.blocking.offset_padding_to_data[i] = 0; // no offset
mem_fmt.layout_desc.blocking.strides[0][axis[i]] = total_stride;
mem_fmt.layout_desc.blocking.strides[1][axis[i]] = 1;
total_stride *= nchw_tz[axis[i]];
}
mem_fmt.layout_desc.blocking.offset_padding = 0; // no initial offset
return mem_fmt;
}
void TransposeKernel(platform::Place place, mkldnn::memory::desc md_o, const std::string key = platform::TransposeMKLDNNHandler::GetHash(
mkldnn::memory::desc md_i, Tensor* output, nchw_tz, axis, ctx.op().Output("Out"));
const T* data_i, std::vector<int>& nchw_dims,
const mkldnn::engine& eng) const {
// Make Memory primitive descriptors
auto mpd_o = mkldnn::memory::primitive_desc(md_o, eng);
auto mpd_i = mkldnn::memory::primitive_desc(md_i, eng);
auto data_o = output->mutable_data<T>( platform::TransposeMKLDNNHandler handler(nchw_tz, axis, dev_ctx,
place, paddle::memory::Allocator::kDefault, mpd_o.get_size()); mkldnn_engine, key);
auto src = mkldnn::memory(mpd_i, (T*)(data_i)); auto transpose_src_memory_p = handler.AcquireSrcMemory(
auto dst = mkldnn::memory(mpd_o, data_o); input->format(), platform::to_void_cast<T>(input_data));
auto transpose_dst_memory_p =
handler.AcquireDstMemory(output, ctx.GetPlace());
auto transpose_p = handler.AcquireTranspose(transpose_dst_memory_p,
transpose_src_memory_p);
auto r = mkldnn::reorder(src, dst); std::vector<mkldnn::primitive> pipeline;
mkldnn::stream(mkldnn::stream::kind::eager).submit({r}).wait(); pipeline.push_back(*transpose_p);
mkldnn::stream(mkldnn::stream::kind::eager).submit(pipeline).wait();
} }
}; };
......
...@@ -197,6 +197,130 @@ class MKLDNNHandler { ...@@ -197,6 +197,130 @@ class MKLDNNHandler {
bool is_reusing_; bool is_reusing_;
}; };
class TransposeMKLDNNHandler : public MKLDNNHandler {
public:
TransposeMKLDNNHandler(std::vector<int>& dims, std::vector<int>& axis,
const platform::MKLDNNDeviceContext& dev_ctx,
mkldnn::engine engine, const std::string& base_key)
: platform::MKLDNNHandler(dev_ctx, engine, base_key),
dims_(dims),
axis_(axis),
logical_axis_(dims.size(), 0) {}
std::shared_ptr<mkldnn::memory> AcquireSrcMemory(
const mkldnn::memory::format& fmt, void* ptr) {
auto local_key = key_ + "@user_src_mem_p";
auto mem_p =
std::static_pointer_cast<mkldnn::memory>(dev_ctx_.GetBlob(local_key));
PADDLE_ENFORCE((mem_p != nullptr) || (is_reusing_ == false),
" find mem primitive in device context");
if (mem_p == nullptr) {
// Make memory descriptor using input format, unless it
// cannot be trusted (nchw) then make up memory fmt manually
for (size_t i = 0; i < logical_axis_.size(); ++i) {
logical_axis_[i] = i;
}
auto src_md = fmt != mkldnn::memory::format::nchw
? platform::MKLDNNMemDesc(
dims_, platform::MKLDNNGetDataType<float>(), fmt)
: Axis2MemoryDesc(dims_, logical_axis_);
mem_p = std::make_shared<mkldnn::memory>(
mkldnn::memory::primitive_desc{src_md, engine_}, ptr);
dev_ctx_.SetBlob(local_key, mem_p);
} else {
mem_p->set_data_handle(ptr);
// Mark that reusing happenned. All primitives from operator instance
// should be reused or none of them. So we check consistency
is_reusing_ = true;
}
return mem_p;
}
std::shared_ptr<mkldnn::memory> AcquireDstMemory(framework::Tensor* output,
platform::Place place) {
auto local_key = key_ + "@user_dst_mem_p";
auto mem_p =
std::static_pointer_cast<mkldnn::memory>(dev_ctx_.GetBlob(local_key));
PADDLE_ENFORCE((mem_p != nullptr) || (is_reusing_ == false),
" find mem primitive in device context");
if (mem_p == nullptr) {
auto dst_mdp = mkldnn::memory::primitive_desc{
Axis2MemoryDesc(dims_, axis_), engine_};
auto dst_data = output->mutable_data<float>(
place, paddle::memory::Allocator::kDefault, dst_mdp.get_size());
mem_p = std::make_shared<mkldnn::memory>(dst_mdp, dst_data);
dev_ctx_.SetBlob(local_key, mem_p);
} else {
auto dst_data = output->mutable_data<float>(place);
mem_p->set_data_handle(dst_data);
// Mark that reusing happenned. All primitives from operator instance
// should be reused or none of them. So we check consistency
is_reusing_ = true;
}
return mem_p;
}
std::shared_ptr<mkldnn::reorder> AcquireTranspose(
std::shared_ptr<mkldnn::memory> dst_memory_p,
std::shared_ptr<mkldnn::memory> src_memory_p) {
auto prim_key = key_ + "@transpose_p";
auto transpose_p =
std::static_pointer_cast<mkldnn::reorder>(dev_ctx_.GetBlob(prim_key));
PADDLE_ENFORCE((transpose_p != nullptr) || (is_reusing_ == false),
"Fail to find convolution primitive in device context");
if (transpose_p == nullptr) {
transpose_p =
std::make_shared<mkldnn::reorder>(*(src_memory_p), *(dst_memory_p));
dev_ctx_.SetBlob(prim_key, transpose_p);
} else {
is_reusing_ = true;
}
return transpose_p;
}
static std::string GetHash(std::vector<int>& shape, // NOLINT
std::vector<int>& axis, // NOLINT
const std::string& suffix) {
return dims2str(shape) + dims2str(axis) + suffix;
}
protected:
mkldnn_memory_desc_t Axis2MemoryDesc(std::vector<int>& nchw_tz,
std::vector<int>& axis) {
mkldnn_memory_desc_t mem_fmt;
mem_fmt.primitive_kind = mkldnn_memory;
mem_fmt.ndims = axis.size();
for (unsigned int i = 0; i < nchw_tz.size(); ++i) {
mem_fmt.dims[i] = nchw_tz[i]; // logical dimensions (nchw format,
// regardless physical layout)
}
mem_fmt.data_type = mkldnn_f32;
mem_fmt.format = mkldnn_blocked;
unsigned int total_stride = 1;
for (int i = nchw_tz.size() - 1; i >= 0; --i) {
mem_fmt.layout_desc.blocking.padding_dims[i] =
nchw_tz[i]; // logical dimensions (nchw format, regardless physical
// layout)
mem_fmt.layout_desc.blocking.block_dims[i] = 1;
mem_fmt.layout_desc.blocking.offset_padding_to_data[i] = 0; // no offset
mem_fmt.layout_desc.blocking.strides[0][axis[i]] = total_stride;
mem_fmt.layout_desc.blocking.strides[1][axis[i]] = 1;
total_stride *= nchw_tz[axis[i]];
}
mem_fmt.layout_desc.blocking.offset_padding = 0; // no initial offset
return mem_fmt;
}
private:
std::vector<int> dims_;
std::vector<int> axis_;
std::vector<int> logical_axis_;
};
template <class forward_t, class backward_data_t, class backward_weights_t> template <class forward_t, class backward_data_t, class backward_weights_t>
class ConvMKLDNNTemplateHandler : public MKLDNNHandler { class ConvMKLDNNTemplateHandler : public MKLDNNHandler {
public: public:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册