Commit 751a826c authored by xiaolil1

fix conv int8 bugs with debug log

Parent fcbe4898
@@ -173,6 +173,7 @@ class ConvMKLDNNHandler : public platform::MKLDNNHandler {
       dev_ctx_.SetBlob(prim_key, conv_p);
     } else {
+std::cout<<"4 is reuse = "<<is_reusing_;
       is_reusing_ = true;
     }
     return conv_p;
@@ -186,6 +187,7 @@ class ConvMKLDNNHandler : public platform::MKLDNNHandler {
     auto prim_key = key_ + "@conv_p";
     auto conv_p = std::static_pointer_cast<mkldnn::convolution_forward>(
         dev_ctx_.GetBlob(prim_key));
+//is_reusing_ = false;
     PADDLE_ENFORCE((conv_p != nullptr) || (is_reusing_ == false),
                    "Fail to find convolution primitive in device context");
     if (conv_p == nullptr) {
@@ -195,6 +197,7 @@ class ConvMKLDNNHandler : public platform::MKLDNNHandler {
       dev_ctx_.SetBlob(prim_key, conv_p);
     } else {
+std::cout<<"5 is reuse = "<<is_reusing_;
       is_reusing_ = true;
     }
     return conv_p;
@@ -376,40 +379,64 @@ std::cout<<"this is conv int8 op .............."<<std::endl;
                              ctx.op().Output("Output"));
     const std::string key_conv_pd = key + "@conv_pd";
-    std::vector<primitive> pipeline;
+std::cout<<key_conv_pd<<std::endl;
+    std::vector<primitive> pipeline;
+std::cout<<"log1....."<<std::endl;
     auto user_src_md = platform::MKLDNNMemDesc(
-        {src_tz}, platform::MKLDNNGetDataType<T>(), input->format());
+        {src_tz}, platform::MKLDNNGetDataType<float>(), mkldnn::memory::format::nChw16c);
+std::cout<<"log2....."<<std::endl;
     auto user_weights_md = platform::MKLDNNMemDesc(
         {weights_tz}, platform::MKLDNNGetDataType<float>(),
         (g == 1) ? filter->format() : mkldnn::memory::format::goihw);
+std::cout<<"log3....."<<std::endl;
     /* create memory descriptor for convolution without specified format
      * ('any') which lets a primitive (convolution in this case) choose
      * the memory format preferred for best performance
      */
     std::string data_format = ctx.Attr<std::string>("data_format");
     auto chosen_memory_format =
         platform::data_format_to_memory_format(data_format);
+    //std::shared_ptr<mkldnn::memory::desc> src_md;
+    //std::shared_ptr<mkldnn::memory::desc> weights_md;
+    //std::shared_ptr<mkldnn::memory::desc> dst_md;
+    std::vector<int> bias_tz;
+    //if(is_INT8){
+    //  src_md.reset(new platform::MKLDNNMemDesc(
+    //      src_tz, memory::data_type::u8, chosen_memory_format));
+    //  weights_md.reset(new platform::MKLDNNMemDesc(
+    //      weights_tz, memory::data_type::s8,
+    //      (g == 1) ? chosen_memory_format : mkldnn::memory::format::goihw));
+    //  dst_md.reset(new platform::MKLDNNMemDesc(
+    //      dst_tz,
+    //      fuse_relu?memory::data_type::u8:memory::data_type::s8,
+    //      chosen_memory_format));
+    //} else{
+    //  src_md.reset(new platform::MKLDNNMemDesc(
+    //      src_tz, platform::MKLDNNGetDataType<T>(), chosen_memory_format));
+    //  weights_md.reset(new platform::MKLDNNMemDesc(
+    //      weights_tz, platform::MKLDNNGetDataType<T>(),
+    //      (g == 1) ? chosen_memory_format : mkldnn::memory::format::goihw));
+    //  dst_md.reset(new platform::MKLDNNMemDesc(
+    //      dst_tz, platform::MKLDNNGetDataType<T>(), chosen_memory_format));
+    //}
     auto src_md = platform::MKLDNNMemDesc(
-        src_tz, platform::MKLDNNGetDataType<T>(), chosen_memory_format);
+        src_tz, platform::MKLDNNGetDataType<float>(), chosen_memory_format);
     auto weights_md = platform::MKLDNNMemDesc(
-        weights_tz, platform::MKLDNNGetDataType<T>(),
+        weights_tz, platform::MKLDNNGetDataType<float>(),
         (g == 1) ? chosen_memory_format : mkldnn::memory::format::goihw);
-    std::vector<int> bias_tz;  // TODO(mgallus): avoid empty vector creation.
-                               // Currently used whenever bias is != nullptr.
     auto dst_md = platform::MKLDNNMemDesc(
-        dst_tz, platform::MKLDNNGetDataType<T>(), chosen_memory_format);
+        dst_tz, platform::MKLDNNGetDataType<float>(), chosen_memory_format);
     if(is_INT8){
       src_md = platform::MKLDNNMemDesc(
           src_tz, memory::data_type::u8, chosen_memory_format);
       weights_md = platform::MKLDNNMemDesc(
           weights_tz, memory::data_type::s8,
           (g == 1) ? chosen_memory_format : mkldnn::memory::format::goihw);
       dst_md = platform::MKLDNNMemDesc(
           dst_tz,
           fuse_relu?memory::data_type::u8:memory::data_type::s8,
           chosen_memory_format);
     }
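Aside: the hunk above keeps FP32 memory descriptors by default and, when is_INT8 is set, switches to u8 activations, s8 weights, and a u8 destination when ReLU is fused (the fused ReLU makes the output non-negative) or s8 otherwise. Below is a minimal standalone sketch of that selection in plain C++ with no MKL-DNN dependency; the DataType enum and function names are illustrative and not part of the patch.

// Illustrative sketch only (not part of the patch): mirrors the INT8
// data-type choice above without any MKL-DNN dependency.
#include <iostream>

enum class DataType { f32, u8, s8 };

DataType SrcType(bool is_int8) { return is_int8 ? DataType::u8 : DataType::f32; }

DataType WeightsType(bool is_int8) { return is_int8 ? DataType::s8 : DataType::f32; }

// A fused ReLU guarantees a non-negative output, so u8 suffices; without it
// the quantized output keeps its sign and stays s8.
DataType DstType(bool is_int8, bool fuse_relu) {
  if (!is_int8) return DataType::f32;
  return fuse_relu ? DataType::u8 : DataType::s8;
}

int main() {
  std::cout << (DstType(true, true) == DataType::u8) << "\n";   // prints 1
  std::cout << (DstType(true, false) == DataType::s8) << "\n";  // prints 1
  return 0;
}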
@@ -467,7 +494,7 @@ std::cout<<"this is conv int8 op .............."<<std::endl;
       int count = is_multi_channel? (g>1? weights_tz[1]*weights_tz[0] : weights_tz[0]) : 1;
       std::vector<float> scale_weights_data(count);
       for(int i=0; i<count; i++){
-        scale_weights_data[i] = *(scale_weights->data<T>() + i);
+        scale_weights_data[i] = *(scale_weights->data<float>() + i);
       }
       auto weights_memory_p = handler.AcquireWeightsMemoryFromPrimitive(
           user_weights_memory_p, pipeline, is_test, is_INT8, scale_weights_data, mask_reorder);
@@ -526,8 +553,8 @@ std::cout<<"input fmt = "<<input->format()<<" output fmt = "<<output->format()<
           {bias_tz}, platform::MKLDNNGetDataType<float>(), memory::format::x);
       auto user_bias_memory_p =
           handler.AcquireBiasMemory(user_bias_md, to_void_cast<float>(bias_data));
-      auto bias_memory_p =
-          handler.AcquireBiasMemoryFromPrimitive(user_bias_memory_p, pipeline);
+      std::shared_ptr<mkldnn::memory> bias_memory_p;// =
+          //handler.AcquireBiasMemoryFromPrimitive(user_bias_memory_p, pipeline);
       if(is_INT8){
         int mask_reorder = is_multi_channel? 0 : 1<<0;
         int count = is_multi_channel? (g>1? weights_tz[1]*weights_tz[0] : weights_tz[0]) : 1;
@@ -535,9 +562,12 @@ std::cout<<"input fmt = "<<input->format()<<" output fmt = "<<output->format()<
         for(int i=0; i<count; i++){
           scale_bias_data[i] = (*scale_in->data<float>()) * (*(scale_weights->data<float>() + i));
         }
-        auto bias_memory_p =
+        bias_memory_p =
             handler.AcquireBiasMemoryFromPrimitive(user_bias_memory_p, pipeline, is_INT8, scale_bias_data, mask_reorder);
-      }
+      } else{
+        bias_memory_p =
+            handler.AcquireBiasMemoryFromPrimitive(user_bias_memory_p, pipeline);
+      }
       conv_p = handler.AcquireConvolution(src_memory_p, weights_memory_p,
                                           bias_memory_p, dst_memory_p);
     } else {
...
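Aside: in the INT8 branch above, the bias must be reordered with the same scales used to quantize the convolution, so each entry is the input scale multiplied by the per-output-channel weight scale (a single value when the weights are quantized per tensor). Below is a small standalone sketch of that product; the BiasScales helper is illustrative and not a function in the patch.

// Illustrative sketch only: scale_bias[i] = scale_in * scale_weights[i],
// matching the loop in the hunk above. Per-tensor weight quantization gives
// scale_weights a single entry; per-channel gives one entry per filter.
#include <cstddef>
#include <vector>

std::vector<float> BiasScales(float scale_in,
                              const std::vector<float>& scale_weights) {
  std::vector<float> scale_bias(scale_weights.size());
  for (std::size_t i = 0; i < scale_weights.size(); ++i) {
    scale_bias[i] = scale_in * scale_weights[i];
  }
  return scale_bias;
}

int main() {
  // e.g. input scale 0.5 and two per-channel weight scales
  std::vector<float> s = BiasScales(0.5f, {2.0f, 4.0f});  // yields {1.0f, 2.0f}
  return static_cast<int>(s.size());
}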
@@ -70,6 +70,7 @@ inline mkldnn::memory::desc MKLDNNMemDesc(const std::vector<int>& dims,
                                           mkldnn::memory::data_type data_type,
                                           mkldnn::memory::format format) {
   mkldnn::memory::dims tz = dims;
+std::cout<<"this is MKLDNNMemDesc"<<" data_type"<<data_type<<" format"<<format<<std::endl;
   return mkldnn::memory::desc({tz}, data_type, format);
 }
@@ -163,6 +164,7 @@ std::cout<<"mem_p == null"<<std::endl;
       mem_p->set_data_handle(ptr);
       // Mark that reusing happenned. All primitives from operator instance
       // should be reused or none of them. So we check consistency
+std::cout<<"1 is reuse = "<<is_reusing_;
       is_reusing_ = true;
     }
 std::cout<<"mdp fmt = "<<mdp.desc().data.format<<" mem_p fmt = "<<mem_p->get_primitive_desc().desc().data.format<<std::endl;
@@ -188,6 +190,7 @@ std::cout<<"mem_p == null"<<std::endl;
       mem_p->set_data_handle(ptr);
       // Mark that reusing happenned. All primitives from operator instance
       // should be reused or none of them. So we check consistency
+std::cout<<"2 is reuse = "<<is_reusing_;
       is_reusing_ = true;
     }
 std::cout<<"md fmt = "<<md.data.format<<" mem_p fmt = "<<mem_p->get_primitive_desc().desc().data.format<<std::endl;
@@ -239,6 +242,7 @@ std::cout<<"md fmt = "<<md.data.format<<" mem_p fmt = "<<mem_p->get_primitive_
     if (reorder_p != nullptr) {
       pipeline.push_back(*reorder_p);
     }
+std::cout<<"3 is reuse = "<<is_reusing_;
     is_reusing_ = true;
   }
   return target_memory_p;
...
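Aside: the is_reusing_ flag that the new debug prints trace enforces an all-or-nothing rule: once any primitive of a handler instance has been fetched from the device-context blob cache, every later lookup in that instance must also hit the cache, which is what the PADDLE_ENFORCE((p != nullptr) || (is_reusing_ == false), ...) checks guard. Below is a minimal standalone sketch of that invariant; the HandlerSketch type and its Acquire method are illustrative, not PaddlePaddle API.

// Illustrative sketch only: the reuse-consistency invariant behind is_reusing_.
#include <cassert>
#include <map>
#include <memory>
#include <string>

struct HandlerSketch {
  std::map<std::string, std::shared_ptr<int>> blobs_;  // stands in for dev_ctx_ blobs
  bool is_reusing_ = false;

  std::shared_ptr<int> Acquire(const std::string& key) {
    std::shared_ptr<int> p;
    auto it = blobs_.find(key);
    if (it != blobs_.end()) p = it->second;
    // Mirrors PADDLE_ENFORCE((p != nullptr) || (is_reusing_ == false), ...):
    // a cache miss after an earlier hit means the key scheme is inconsistent.
    assert(p != nullptr || is_reusing_ == false);
    if (p == nullptr) {
      p = std::make_shared<int>(0);
      blobs_[key] = p;
    } else {
      is_reusing_ = true;  // from now on every Acquire must be a cache hit
    }
    return p;
  }
};

int main() {
  HandlerSketch h;
  h.Acquire("conv_p");  // miss: created and cached
  h.Acquire("conv_p");  // hit: is_reusing_ becomes true
  return 0;
}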