未验证 提交 85b6bb58 编写于 作者: T Tao Luo 提交者: GitHub

Merge pull request #10747 from jczaja/prv-mkldnn-pooling-reuse

Reuse of pooling mkldnn primitives
......@@ -18,6 +18,26 @@ limitations under the License. */
namespace paddle {
namespace operators {
using mkldnn::memory; // Note: paddle has also "memory" namespace
using mkldnn::pooling_forward;
using mkldnn::pooling_backward;
// Generate keys for storing/retriving primitives for this operator
// TODO(jczaja): Make hashing function more optimial
static std::string gethash(memory::dims& input_dims, std::string& pooling_type,
std::vector<int>& ksize, std::vector<int>& strides,
std::vector<int>& paddings, std::string suffix) {
auto dims2str = [](memory::dims& operand_dims) {
std::string dstr = "";
for (size_t i = 0; i < operand_dims.size(); ++i) {
dstr += std::to_string(operand_dims[i]) + "-";
}
return dstr;
};
return dims2str(input_dims) + dims2str(ksize) + dims2str(strides) +
dims2str(paddings) + pooling_type + suffix;
}
template <typename T>
class PoolMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
public:
......@@ -34,10 +54,6 @@ class PoolMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
// Get an unique name from "argument" name of "Out" variable
// This name will be used as key when saving info into device context
const std::string key = ctx.op().Output("Out");
const std::string key_pool_pd = key + "@pool_pd";
const std::string key_pool_workspace_memory =
key + "@pool_workspace_memory";
std::string pooling_type = ctx.Attr<std::string>("pooling_type");
std::vector<int> ksize = ctx.Attr<std::vector<int>>("ksize");
......@@ -63,13 +79,28 @@ class PoolMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
std::vector<int> src_tz = paddle::framework::vectorize2int(input->dims());
std::vector<int> dst_tz = paddle::framework::vectorize2int(output->dims());
const std::string key = gethash(src_tz, pooling_type, ksize, strides,
paddings, ctx.op().Output("Out"));
const std::string key_pool_p = key + "@pool_p";
const std::string key_pool_pd = key + "@pool_pd";
const std::string key_pool_src_mem_p = key + "@pool_src_mem_p";
const std::string key_pool_dst_mem_p = key + "@pool_dst_mem_p";
const std::string key_pool_workspace_memory =
key + "@pool_workspace_memory";
auto pool_p =
std::static_pointer_cast<pooling_forward>(dev_ctx.GetBlob(key_pool_p));
if (pool_p == nullptr) {
// TODO(pzelazko-intel): support more formats
auto src_md = platform::MKLDNNMemDesc(src_tz, mkldnn::memory::f32,
auto src_md =
platform::MKLDNNMemDesc(src_tz, platform::MKLDNNGetDataType<T>(),
mkldnn::memory::format::nchw);
auto dst_md = platform::MKLDNNMemDesc(dst_tz, mkldnn::memory::f32,
auto dst_md =
platform::MKLDNNMemDesc(dst_tz, platform::MKLDNNGetDataType<T>(),
mkldnn::memory::format::nchw);
std::shared_ptr<mkldnn::pooling_forward::primitive_desc> pool_pd =
std::shared_ptr<pooling_forward::primitive_desc> pool_pd =
CreatePrimitiveDesc(src_md, dst_md, strides, paddings, ksize,
pooling_type, mkldnn_engine);
......@@ -82,18 +113,37 @@ class PoolMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
// save pool_workspace_memory to be referred in backward path
dev_ctx.SetBlob(key_pool_workspace_memory, workspace_memory);
auto src_memory =
mkldnn::memory({src_md, mkldnn_engine},
auto pool_src_memory_p = std::make_shared<memory>(
memory::primitive_desc{src_md, mkldnn_engine},
static_cast<void*>(const_cast<T*>(input_data)));
auto dst_memory =
mkldnn::memory({dst_md, mkldnn_engine},
static_cast<void*>(const_cast<T*>(output_data)));
dev_ctx.SetBlob(key_pool_src_mem_p, pool_src_memory_p);
auto pool_prim = mkldnn::pooling_forward(*pool_pd, src_memory, dst_memory,
auto pool_dst_memory_p = std::make_shared<memory>(
memory::primitive_desc{dst_md, mkldnn_engine},
static_cast<void*>(output_data));
dev_ctx.SetBlob(key_pool_dst_mem_p, pool_dst_memory_p);
pool_p = std::make_shared<pooling_forward>(
*pool_pd, *(pool_src_memory_p.get()), *(pool_dst_memory_p.get()),
*workspace_memory);
dev_ctx.SetBlob(key_pool_p, pool_p);
} else {
// Primitives already exist
auto pool_src_memory_p =
std::static_pointer_cast<memory>(dev_ctx.GetBlob(key_pool_src_mem_p));
PADDLE_ENFORCE(pool_src_memory_p != nullptr,
"Fail to find pooling src mem_p in device context");
auto pool_dst_memory_p =
std::static_pointer_cast<memory>(dev_ctx.GetBlob(key_pool_dst_mem_p));
PADDLE_ENFORCE(pool_dst_memory_p != nullptr,
"Fail to find pooling dst mem_p in device context");
pool_src_memory_p->set_data_handle(
reinterpret_cast<void*>(const_cast<T*>(input_data)));
pool_dst_memory_p->set_data_handle(output_data);
}
// push primitive to stream and wait until it's executed
std::vector<mkldnn::primitive> pipeline{pool_prim};
std::vector<mkldnn::primitive> pipeline{*(pool_p.get())};
mkldnn::stream(mkldnn::stream::kind::eager).submit(pipeline).wait();
}
......@@ -120,8 +170,9 @@ class PoolMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
mkldnn::memory::primitive_desc workspace_md =
pooling_type == "max"
? pool_pd->workspace_primitive_desc()
: mkldnn::memory::primitive_desc(
{{}, mkldnn::memory::f32, mkldnn::memory::format::nchw},
: mkldnn::memory::primitive_desc({{},
platform::MKLDNNGetDataType<T>(),
mkldnn::memory::format::nchw},
engine);
auto p_workspace_memory = new mkldnn::memory(workspace_md);
......@@ -140,13 +191,6 @@ class PoolMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
const Tensor* out_grad = ctx.Input<Tensor>(framework::GradVarName("Out"));
Tensor* in_x_grad = ctx.Output<Tensor>(framework::GradVarName("X"));
// Get an unique name from "argument" name of "Out" variable
// This name will be used as key when referring info from device context
const std::string key = ctx.op().Input("Out");
const std::string key_pool_pd = key + "@pool_pd";
const std::string key_pool_workspace_memory =
key + "@pool_workspace_memory";
std::string pooling_type = ctx.Attr<std::string>("pooling_type");
std::vector<int> ksize = ctx.Attr<std::vector<int>>("ksize");
std::vector<int> strides = ctx.Attr<std::vector<int>>("strides");
......@@ -171,11 +215,26 @@ class PoolMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
std::vector<int> diff_dst_tz =
paddle::framework::vectorize2int(out_grad->dims());
auto diff_src_md = platform::MKLDNNMemDesc(diff_src_tz, mkldnn::memory::f32,
// Get an unique name from "argument" name of "Out" variable
// This name will be used as key when referring info from device context
const std::string key = gethash(diff_src_tz, pooling_type, ksize, strides,
paddings, ctx.op().Input("Out"));
const std::string key_pool_bwd_p = key + "@pool_bwd_p";
const std::string key_pool_diff_src_mem_p = key + "@pool_diff_src_mem_p";
const std::string key_pool_diff_dst_mem_p = key + "@pool_diff_dst_mem_p";
const std::string key_pool_pd = key + "@pool_pd";
const std::string key_pool_workspace_memory =
key + "@pool_workspace_memory";
auto pool_bwd_p = std::static_pointer_cast<pooling_backward>(
dev_ctx.GetBlob(key_pool_bwd_p));
if (pool_bwd_p == nullptr) {
auto diff_src_md =
platform::MKLDNNMemDesc(diff_src_tz, platform::MKLDNNGetDataType<T>(),
mkldnn::memory::format::nchw);
auto diff_dst_md = platform::MKLDNNMemDesc(diff_dst_tz, mkldnn::memory::f32,
auto diff_dst_md =
platform::MKLDNNMemDesc(diff_dst_tz, platform::MKLDNNGetDataType<T>(),
mkldnn::memory::format::nchw);
// Retrieve pool_pd/pool_workspace_memory from device context
auto pool_pd =
std::static_pointer_cast<mkldnn::pooling_forward::primitive_desc>(
......@@ -188,6 +247,15 @@ class PoolMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
PADDLE_ENFORCE(workspace_memory != nullptr,
"Fail to find workspace_memory in device context");
auto pool_diff_src_memory_p = std::make_shared<memory>(memory(
{diff_src_md, mkldnn_engine}, static_cast<void*>(in_x_grad_data)));
dev_ctx.SetBlob(key_pool_diff_src_mem_p, pool_diff_src_memory_p);
auto pool_diff_dst_memory_p = std::make_shared<memory>(
memory({diff_dst_md, mkldnn_engine},
static_cast<void*>(const_cast<T*>(out_grad_data))));
dev_ctx.SetBlob(key_pool_diff_dst_mem_p, pool_diff_dst_memory_p);
auto pool_bwd_desc = mkldnn::pooling_backward::desc(
pooling_type == "max" ? mkldnn::algorithm::pooling_max
: mkldnn::algorithm::pooling_avg,
......@@ -196,18 +264,27 @@ class PoolMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
auto pool_bwd_pd = mkldnn::pooling_backward::primitive_desc(
pool_bwd_desc, mkldnn_engine, *pool_pd);
auto diff_src_memory =
mkldnn::memory({diff_src_md, mkldnn_engine},
static_cast<void*>(const_cast<T*>(in_x_grad_data)));
auto diff_dst_memory =
mkldnn::memory({diff_dst_md, mkldnn_engine},
static_cast<void*>(const_cast<T*>(out_grad_data)));
auto bwd_prim = mkldnn::pooling_backward(
pool_bwd_pd, diff_dst_memory, *workspace_memory, diff_src_memory);
pool_bwd_p = std::make_shared<pooling_backward>(
pool_bwd_pd, *(pool_diff_dst_memory_p.get()), *workspace_memory,
*(pool_diff_src_memory_p));
dev_ctx.SetBlob(key_pool_bwd_p, pool_bwd_p);
} else {
// Primitives already exist
auto pool_diff_src_memory_p = std::static_pointer_cast<memory>(
dev_ctx.GetBlob(key_pool_diff_src_mem_p));
PADDLE_ENFORCE(pool_diff_src_memory_p != nullptr,
"Fail to find pooling src mem_p in device context");
auto pool_diff_dst_memory_p = std::static_pointer_cast<memory>(
dev_ctx.GetBlob(key_pool_diff_dst_mem_p));
PADDLE_ENFORCE(pool_diff_dst_memory_p != nullptr,
"Fail to find pooling dst mem_p in device context");
pool_diff_src_memory_p->set_data_handle(
reinterpret_cast<void*>(in_x_grad_data));
pool_diff_dst_memory_p->set_data_handle(const_cast<T*>(out_grad_data));
}
// push primitive to stream and wait until it's executed
std::vector<mkldnn::primitive> pipeline{bwd_prim};
std::vector<mkldnn::primitive> pipeline{*(pool_bwd_p.get())};
mkldnn::stream(mkldnn::stream::kind::eager).submit(pipeline).wait();
} // Compute()
};
......
......@@ -71,5 +71,15 @@ inline bool CanMKLDNNBeUsed(const framework::ExecutionContext& ctx) {
return use_mkldnn && platform::is_cpu_place(ctx.GetPlace());
}
template <typename Type>
mkldnn::memory::data_type MKLDNNGetDataType() {
return mkldnn::memory::data_undef;
}
template <>
inline mkldnn::memory::data_type MKLDNNGetDataType<float>() {
return mkldnn::memory::f32;
}
} // namespace platform
} // namespace paddle
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册