提交 751a826c 编写于 作者: X xiaolil1

fix conv int8 bugs with debug log

上级 fcbe4898
......@@ -173,6 +173,7 @@ class ConvMKLDNNHandler : public platform::MKLDNNHandler {
dev_ctx_.SetBlob(prim_key, conv_p);
} else {
std::cout<<"4 is reuse = "<<is_reusing_;
is_reusing_ = true;
}
return conv_p;
......@@ -186,6 +187,7 @@ class ConvMKLDNNHandler : public platform::MKLDNNHandler {
auto prim_key = key_ + "@conv_p";
auto conv_p = std::static_pointer_cast<mkldnn::convolution_forward>(
dev_ctx_.GetBlob(prim_key));
//is_reusing_ = false;
PADDLE_ENFORCE((conv_p != nullptr) || (is_reusing_ == false),
"Fail to find convolution primitive in device context");
if (conv_p == nullptr) {
......@@ -195,6 +197,7 @@ class ConvMKLDNNHandler : public platform::MKLDNNHandler {
dev_ctx_.SetBlob(prim_key, conv_p);
} else {
std::cout<<"5 is reuse = "<<is_reusing_;
is_reusing_ = true;
}
return conv_p;
......@@ -376,40 +379,64 @@ std::cout<<"this is conv int8 op .............."<<std::endl;
ctx.op().Output("Output"));
const std::string key_conv_pd = key + "@conv_pd";
std::vector<primitive> pipeline;
std::cout<<key_conv_pd<<std::endl;
std::vector<primitive> pipeline;
std::cout<<"log1....."<<std::endl;
auto user_src_md = platform::MKLDNNMemDesc(
{src_tz}, platform::MKLDNNGetDataType<T>(), input->format());
{src_tz}, platform::MKLDNNGetDataType<float>(), mkldnn::memory::format::nChw16c);
std::cout<<"log2....."<<std::endl;
auto user_weights_md = platform::MKLDNNMemDesc(
{weights_tz}, platform::MKLDNNGetDataType<float>(),
(g == 1) ? filter->format() : mkldnn::memory::format::goihw);
std::cout<<"log3....."<<std::endl;
/* create memory descriptor for convolution without specified format
* ('any') which lets a primitive (convolution in this case) choose
* the memory format preferred for best performance
*/
std::string data_format = ctx.Attr<std::string>("data_format");
auto chosen_memory_format =
auto chosen_memory_format =
platform::data_format_to_memory_format(data_format);
//std::shared_ptr<mkldnn::memory::desc> src_md;
//std::shared_ptr<mkldnn::memory::desc> weights_md;
//std::shared_ptr<mkldnn::memory::desc> dst_md;
std::vector<int> bias_tz;
//if(is_INT8){
// src_md.reset(new platform::MKLDNNMemDesc(
// src_tz, memory::data_type::u8, chosen_memory_format));
// weights_md.reset(new platform::MKLDNNMemDesc(
// weights_tz, memory::data_type::s8,
// (g == 1) ? chosen_memory_format : mkldnn::memory::format::goihw));
// dst_md.reset(new platform::MKLDNNMemDesc(
// dst_tz,
// fuse_relu?memory::data_type::u8:memory::data_type::s8,
// chosen_memory_format));
//} else{
// src_md.reset(new platform::MKLDNNMemDesc(
// src_tz, platform::MKLDNNGetDataType<T>(), chosen_memory_format));
// weights_md.reset(new platform::MKLDNNMemDesc(
// weights_tz, platform::MKLDNNGetDataType<T>(),
// (g == 1) ? chosen_memory_format : mkldnn::memory::format::goihw));
// dst_md.reset(new platform::MKLDNNMemDesc(
// dst_tz, platform::MKLDNNGetDataType<T>(), chosen_memory_format));
//}
auto src_md = platform::MKLDNNMemDesc(
src_tz, platform::MKLDNNGetDataType<T>(), chosen_memory_format);
src_tz, platform::MKLDNNGetDataType<float>(), chosen_memory_format);
auto weights_md = platform::MKLDNNMemDesc(
weights_tz, platform::MKLDNNGetDataType<T>(),
weights_tz, platform::MKLDNNGetDataType<float>(),
(g == 1) ? chosen_memory_format : mkldnn::memory::format::goihw);
std::vector<int> bias_tz; // TODO(mgallus): avoid empty vector creation.
// Currently used whenever bias is != nullptr.
auto dst_md = platform::MKLDNNMemDesc(
dst_tz, platform::MKLDNNGetDataType<T>(), chosen_memory_format);
dst_tz, platform::MKLDNNGetDataType<float>(), chosen_memory_format);
if(is_INT8){
src_md = platform::MKLDNNMemDesc(
src_tz, memory::data_type::u8, chosen_memory_format);
weights_md = platform::MKLDNNMemDesc(
weights_tz, memory::data_type::s8,
weights_tz, memory::data_type::s8,
(g == 1) ? chosen_memory_format : mkldnn::memory::format::goihw);
dst_md = platform::MKLDNNMemDesc(
dst_tz,
dst_tz,
fuse_relu?memory::data_type::u8:memory::data_type::s8,
chosen_memory_format);
}
......@@ -467,7 +494,7 @@ std::cout<<"this is conv int8 op .............."<<std::endl;
int count = is_multi_channel? (g>1? weights_tz[1]*weights_tz[0] : weights_tz[0]) : 1;
std::vector<float> scale_weights_data(count);
for(int i=0; i<count; i++){
scale_weights_data[i] = *(scale_weights->data<T>() + i);
scale_weights_data[i] = *(scale_weights->data<float>() + i);
}
auto weights_memory_p = handler.AcquireWeightsMemoryFromPrimitive(
user_weights_memory_p, pipeline, is_test, is_INT8, scale_weights_data, mask_reorder);
......@@ -526,8 +553,8 @@ std::cout<<"input fmt = "<<input->format()<<" output fmt = "<<output->format()<
{bias_tz}, platform::MKLDNNGetDataType<float>(), memory::format::x);
auto user_bias_memory_p =
handler.AcquireBiasMemory(user_bias_md, to_void_cast<float>(bias_data));
auto bias_memory_p =
handler.AcquireBiasMemoryFromPrimitive(user_bias_memory_p, pipeline);
std::shared_ptr<mkldnn::memory> bias_memory_p;// =
//handler.AcquireBiasMemoryFromPrimitive(user_bias_memory_p, pipeline);
if(is_INT8){
int mask_reorder = is_multi_channel? 0 : 1<<0;
int count = is_multi_channel? (g>1? weights_tz[1]*weights_tz[0] : weights_tz[0]) : 1;
......@@ -535,9 +562,12 @@ std::cout<<"input fmt = "<<input->format()<<" output fmt = "<<output->format()<
for(int i=0; i<count; i++){
scale_bias_data[i] = (*scale_in->data<float>()) * (*(scale_weights->data<float>() + i));
}
auto bias_memory_p =
bias_memory_p =
handler.AcquireBiasMemoryFromPrimitive(user_bias_memory_p, pipeline, is_INT8, scale_bias_data, mask_reorder);
}
} else{
bias_memory_p =
handler.AcquireBiasMemoryFromPrimitive(user_bias_memory_p, pipeline);
}
conv_p = handler.AcquireConvolution(src_memory_p, weights_memory_p,
bias_memory_p, dst_memory_p);
} else {
......
......@@ -70,6 +70,7 @@ inline mkldnn::memory::desc MKLDNNMemDesc(const std::vector<int>& dims,
mkldnn::memory::data_type data_type,
mkldnn::memory::format format) {
mkldnn::memory::dims tz = dims;
std::cout<<"this is MKLDNNMemDesc"<<" data_type"<<data_type<<" format"<<format<<std::endl;
return mkldnn::memory::desc({tz}, data_type, format);
}
......@@ -163,6 +164,7 @@ std::cout<<"mem_p == null"<<std::endl;
mem_p->set_data_handle(ptr);
// Mark that reusing happenned. All primitives from operator instance
// should be reused or none of them. So we check consistency
std::cout<<"1 is reuse = "<<is_reusing_;
is_reusing_ = true;
}
std::cout<<"mdp fmt = "<<mdp.desc().data.format<<" mem_p fmt = "<<mem_p->get_primitive_desc().desc().data.format<<std::endl;
......@@ -188,6 +190,7 @@ std::cout<<"mem_p == null"<<std::endl;
mem_p->set_data_handle(ptr);
// Mark that reusing happenned. All primitives from operator instance
// should be reused or none of them. So we check consistency
std::cout<<"2 is reuse = "<<is_reusing_;
is_reusing_ = true;
}
std::cout<<"md fmt = "<<md.data.format<<" mem_p fmt = "<<mem_p->get_primitive_desc().desc().data.format<<std::endl;
......@@ -239,6 +242,7 @@ std::cout<<"md fmt = "<<md.data.format<<" mem_p fmt = "<<mem_p->get_primitive_
if (reorder_p != nullptr) {
pipeline.push_back(*reorder_p);
}
std::cout<<"3 is reuse = "<<is_reusing_;
is_reusing_ = true;
}
return target_memory_p;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册