提交 3050a73f 编写于 作者: X xiaolil1

modify md reuse to decrease map find

上级 dd630819
...@@ -452,13 +452,17 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> { ...@@ -452,13 +452,17 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
} }
static std::unordered_map<std::string, std::shared_ptr<mkldnn::memory::desc>> md_map; static std::unordered_map<std::string, std::vector<std::shared_ptr<mkldnn::memory::desc>>> md_map;
bool md_reuse = true; bool md_reuse = true;
auto user_src_md_key = key + "@user_src_md"; std::vector<std::shared_ptr<mkldnn::memory::desc>> mds(8, nullptr);
if (GetMdMap(md_map, user_src_md_key) == nullptr){ std::vector<std::shared_ptr<mkldnn::memory::desc>> none_mds = {};
//auto user_src_md_key = key + "@user_src_md";
if (GetMdMap(md_map, key) == none_mds){
md_reuse = false; //we suppose all mds are reused if the first md is in the map. md_reuse = false; //we suppose all mds are reused if the first md is in the map.
} else{
mds = GetMdMap(md_map, key);
} }
auto user_weights_md_key = key + "@user_weights_md"; //auto user_weights_md_key = key + "@user_weights_md";
std::shared_ptr<mkldnn::memory::desc> user_src_md; std::shared_ptr<mkldnn::memory::desc> user_src_md;
std::shared_ptr<mkldnn::memory::desc> user_weights_md; std::shared_ptr<mkldnn::memory::desc> user_weights_md;
std::vector<primitive> pipeline; std::vector<primitive> pipeline;
...@@ -470,12 +474,16 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> { ...@@ -470,12 +474,16 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
user_weights_md.reset(new mkldnn::memory::desc(platform::MKLDNNMemDesc( user_weights_md.reset(new mkldnn::memory::desc(platform::MKLDNNMemDesc(
{weights_tz}, platform::MKLDNNGetDataType<float>(), {weights_tz}, platform::MKLDNNGetDataType<float>(),
(g == 1) ? mkldnn::memory::format::oihw : mkldnn::memory::format::goihw))); (g == 1) ? mkldnn::memory::format::oihw : mkldnn::memory::format::goihw)));
SetMdMap(md_map, user_src_md_key, user_src_md); mds[0] = user_src_md;
SetMdMap(md_map, user_weights_md_key, user_weights_md); mds[1] = user_weights_md;
//SetMdMap(md_map, user_src_md_key, user_src_md);
//SetMdMap(md_map, user_weights_md_key, user_weights_md);
} else{ } else{
user_src_md = GetMdMap(md_map, user_src_md_key); user_src_md = mds[0];
user_weights_md = GetMdMap(md_map, user_weights_md_key); user_weights_md = mds[1];
//user_src_md = GetMdMap(md_map, user_src_md_key);
//user_weights_md = GetMdMap(md_map, user_weights_md_key);
} }
/* create memory descriptor for convolution without specified format /* create memory descriptor for convolution without specified format
...@@ -489,10 +497,10 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> { ...@@ -489,10 +497,10 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
std::shared_ptr<mkldnn::convolution_forward::primitive_desc> conv_pd; std::shared_ptr<mkldnn::convolution_forward::primitive_desc> conv_pd;
auto bias_tz = paddle::framework::vectorize2int(bias->dims()); auto bias_tz = paddle::framework::vectorize2int(bias->dims());
auto src_md_key = key + "@src_md"; //auto src_md_key = key + "@src_md";
auto weights_md_key = key + "@weights_md_key"; //auto weights_md_key = key + "@weights_md_key";
auto dst_md_key = key + "@dst_md_key"; //auto dst_md_key = key + "@dst_md_key";
auto bias_md_key = key + "@bias_md_key"; //auto bias_md_key = key + "@bias_md_key";
std::shared_ptr<mkldnn::memory::desc> src_md; std::shared_ptr<mkldnn::memory::desc> src_md;
std::shared_ptr<mkldnn::memory::desc> weights_md; std::shared_ptr<mkldnn::memory::desc> weights_md;
std::shared_ptr<mkldnn::memory::desc> dst_md; std::shared_ptr<mkldnn::memory::desc> dst_md;
...@@ -512,13 +520,19 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> { ...@@ -512,13 +520,19 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
dst_dt = residual_dt; dst_dt = residual_dt;
} }
dst_md.reset(new mkldnn::memory::desc(platform::MKLDNNMemDesc(dst_tz, dst_dt, chosen_memory_format))); dst_md.reset(new mkldnn::memory::desc(platform::MKLDNNMemDesc(dst_tz, dst_dt, chosen_memory_format)));
SetMdMap(md_map, src_md_key, src_md); mds[2] = src_md;
SetMdMap(md_map, weights_md_key, weights_md); mds[3] = weights_md;
SetMdMap(md_map, dst_md_key, dst_md); mds[4] = dst_md;
//SetMdMap(md_map, src_md_key, src_md);
//SetMdMap(md_map, weights_md_key, weights_md);
//SetMdMap(md_map, dst_md_key, dst_md);
} else{ } else{
src_md = GetMdMap(md_map, src_md_key); src_md = mds[2];
weights_md = GetMdMap(md_map, weights_md_key); weights_md = mds[3];
dst_md = GetMdMap(md_map, dst_md_key); dst_md = mds[4];
//src_md = GetMdMap(md_map, src_md_key);
//weights_md = GetMdMap(md_map, weights_md_key);
//dst_md = GetMdMap(md_map, dst_md_key);
} }
// create a conv primitive descriptor and save it for usage in backward // create a conv primitive descriptor and save it for usage in backward
...@@ -527,9 +541,11 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> { ...@@ -527,9 +541,11 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
if(!md_reuse){ if(!md_reuse){
bias_md.reset(new mkldnn::memory::desc(platform::MKLDNNMemDesc( bias_md.reset(new mkldnn::memory::desc(platform::MKLDNNMemDesc(
bias_tz, memory::data_type::s32, memory::format::x))); bias_tz, memory::data_type::s32, memory::format::x)));
SetMdMap(md_map, bias_md_key, bias_md); mds[5] = bias_md;
//SetMdMap(md_map, bias_md_key, bias_md);
} else{ } else{
bias_md = GetMdMap(md_map, bias_md_key); bias_md = mds[5];
//bias_md = GetMdMap(md_map, bias_md_key);
} }
conv_pd = ConvFwdPrimitiveDesc(*src_md, *weights_md, *bias_md, *dst_md, conv_pd = ConvFwdPrimitiveDesc(*src_md, *weights_md, *bias_md, *dst_md,
...@@ -551,13 +567,19 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> { ...@@ -551,13 +567,19 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
(g == 1) ? chosen_memory_format : mkldnn::memory::format::goihw))); (g == 1) ? chosen_memory_format : mkldnn::memory::format::goihw)));
dst_md.reset(new mkldnn::memory::desc(platform::MKLDNNMemDesc( dst_md.reset(new mkldnn::memory::desc(platform::MKLDNNMemDesc(
dst_tz, platform::MKLDNNGetDataType<float>(), chosen_memory_format))); dst_tz, platform::MKLDNNGetDataType<float>(), chosen_memory_format)));
SetMdMap(md_map, src_md_key, src_md); mds[2] = src_md;
SetMdMap(md_map, weights_md_key, weights_md); mds[3] = weights_md;
SetMdMap(md_map, dst_md_key, dst_md); mds[4] = dst_md;
//SetMdMap(md_map, src_md_key, src_md);
//SetMdMap(md_map, weights_md_key, weights_md);
//SetMdMap(md_map, dst_md_key, dst_md);
} else{ } else{
src_md = GetMdMap(md_map, src_md_key); src_md = mds[2];
weights_md = GetMdMap(md_map, weights_md_key); weights_md = mds[3];
dst_md = GetMdMap(md_map, dst_md_key); dst_md = mds[4];
//src_md = GetMdMap(md_map, src_md_key);
//weights_md = GetMdMap(md_map, weights_md_key);
//dst_md = GetMdMap(md_map, dst_md_key);
} }
// create a conv primitive descriptor and save it for usage in backward // create a conv primitive descriptor and save it for usage in backward
if (bias) { if (bias) {
...@@ -565,9 +587,11 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> { ...@@ -565,9 +587,11 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
if(!md_reuse){ if(!md_reuse){
bias_md.reset(new mkldnn::memory::desc(platform::MKLDNNMemDesc( bias_md.reset(new mkldnn::memory::desc(platform::MKLDNNMemDesc(
bias_tz, platform::MKLDNNGetDataType<float>(), memory::format::x))); bias_tz, platform::MKLDNNGetDataType<float>(), memory::format::x)));
SetMdMap(md_map, bias_md_key, bias_md); mds[5] = bias_md;
//SetMdMap(md_map, bias_md_key, bias_md);
} else{ } else{
bias_md = GetMdMap(md_map, bias_md_key); bias_md = mds[5];
//bias_md = GetMdMap(md_map, bias_md_key);
} }
conv_pd = ConvFwdPrimitiveDesc(*src_md, *weights_md, *bias_md, *dst_md, conv_pd = ConvFwdPrimitiveDesc(*src_md, *weights_md, *bias_md, *dst_md,
strides, paddings, mkldnn_engine, strides, paddings, mkldnn_engine,
...@@ -605,7 +629,7 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> { ...@@ -605,7 +629,7 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
std::shared_ptr<mkldnn::memory> dst_memory_p; std::shared_ptr<mkldnn::memory> dst_memory_p;
bool need_s8_to_u8 = false; bool need_s8_to_u8 = false;
auto user_residual_md_key = key + "@user_residual_md"; //auto user_residual_md_key = key + "@user_residual_md";
if(fuse_residual_conn) { if(fuse_residual_conn) {
auto residual_param = ctx.Input<Tensor>("ResidualData"); auto residual_param = ctx.Input<Tensor>("ResidualData");
PADDLE_ENFORCE_EQ(output->dims(), residual_param->dims(), PADDLE_ENFORCE_EQ(output->dims(), residual_param->dims(),
...@@ -621,9 +645,11 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> { ...@@ -621,9 +645,11 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
paddle::framework::ToMKLDNNDataType(residual_param->type()); paddle::framework::ToMKLDNNDataType(residual_param->type());
user_residual_md.reset(new mkldnn::memory::desc(platform::MKLDNNMemDesc( user_residual_md.reset(new mkldnn::memory::desc(platform::MKLDNNMemDesc(
residual_data_tz, residual_data_type, residual_param->format()))); residual_data_tz, residual_data_type, residual_param->format())));
SetMdMap(md_map, user_residual_md_key, user_residual_md); mds[6] = user_residual_md;
//SetMdMap(md_map, user_residual_md_key, user_residual_md);
} else{ } else{
user_residual_md = GetMdMap(md_map, user_residual_md_key); user_residual_md = mds[6];
//user_residual_md = GetMdMap(md_map, user_residual_md_key);
} }
if(is_INT8){ if(is_INT8){
if(residual_dt == mkldnn::memory::data_type::u8){ if(residual_dt == mkldnn::memory::data_type::u8){
...@@ -706,16 +732,18 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> { ...@@ -706,16 +732,18 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
// create convolution op primitive // create convolution op primitive
std::shared_ptr<mkldnn::convolution_forward> conv_p; std::shared_ptr<mkldnn::convolution_forward> conv_p;
//auto scale_bias_key = key + "@scale_bias"; //auto scale_bias_key = key + "@scale_bias";
auto user_bias_md_key = key + "@user_bias_md"; //auto user_bias_md_key = key + "@user_bias_md";
if (bias) { if (bias) {
const float* bias_data = bias->data<float>(); const float* bias_data = bias->data<float>();
std::shared_ptr<mkldnn::memory::desc> user_bias_md; std::shared_ptr<mkldnn::memory::desc> user_bias_md;
if(!md_reuse){ if(!md_reuse){
user_bias_md.reset(new mkldnn::memory::desc(platform::MKLDNNMemDesc( user_bias_md.reset(new mkldnn::memory::desc(platform::MKLDNNMemDesc(
{bias_tz}, platform::MKLDNNGetDataType<float>(), memory::format::x))); {bias_tz}, platform::MKLDNNGetDataType<float>(), memory::format::x)));
SetMdMap(md_map, user_bias_md_key, user_bias_md); mds[7] = user_bias_md;
//SetMdMap(md_map, user_bias_md_key, user_bias_md);
} else{ } else{
user_bias_md = GetMdMap(md_map, user_bias_md_key); user_bias_md = mds[7];
//user_bias_md = GetMdMap(md_map, user_bias_md_key);
} }
auto user_bias_memory_p = auto user_bias_memory_p =
handler.AcquireBiasMemory(*user_bias_md, to_void_cast<float>(bias_data)); handler.AcquireBiasMemory(*user_bias_md, to_void_cast<float>(bias_data));
...@@ -748,6 +776,7 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> { ...@@ -748,6 +776,7 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
} }
SetScaleMap(scale_map, key, scale_datas); SetScaleMap(scale_map, key, scale_datas);
SetMdMap(md_map, key, mds);
// push primitive to stream and wait until it's executed // push primitive to stream and wait until it's executed
pipeline.push_back(*conv_p); pipeline.push_back(*conv_p);
...@@ -783,8 +812,8 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> { ...@@ -783,8 +812,8 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
return {{0.0f}}; return {{0.0f}};
} }
void SetMdMap(std::unordered_map<std::string, std::shared_ptr<mkldnn::memory::desc>> &md_map, void SetMdMap(std::unordered_map<std::string, std::vector<std::shared_ptr<mkldnn::memory::desc>>> &md_map,
const std::string& name, std::shared_ptr<mkldnn::memory::desc> mds) const { const std::string& name, std::vector<std::shared_ptr<mkldnn::memory::desc>> mds) const {
auto it = md_map.find(name); auto it = md_map.find(name);
if (it == md_map.end()) { if (it == md_map.end()) {
md_map[name] = mds; // create new blob md_map[name] = mds; // create new blob
...@@ -794,13 +823,13 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> { ...@@ -794,13 +823,13 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
return; return;
} }
std::shared_ptr<mkldnn::memory::desc> GetMdMap(std::unordered_map<std::string, std::shared_ptr<mkldnn::memory::desc>> md_map, std::vector<std::shared_ptr<mkldnn::memory::desc>> GetMdMap(std::unordered_map<std::string, std::vector<std::shared_ptr<mkldnn::memory::desc>>> md_map,
const std::string& name) const { const std::string& name) const {
auto it = md_map.find(name); auto it = md_map.find(name);
if (it != md_map.end()) { if (it != md_map.end()) {
return (*it).second; return (*it).second;
} }
return nullptr; return {};
} }
mkldnn::primitive_attr CreatePostOps(bool fuse_relu, bool fuse_residual_conn, mkldnn::primitive_attr CreatePostOps(bool fuse_relu, bool fuse_residual_conn,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册