Commit b50c6a15 authored by xiaolil1

enable md reuse for INT8 and FP32 forward

Parent 53545a37
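The diff below caches mkldnn memory descriptors in a static map keyed per layer (key + "@user_src_md", key + "@weights_md_key", and so on), so the descriptors built on the first iteration are looked up instead of rebuilt on every later iteration, for both the INT8 and FP32 paths. What follows is a minimal, self-contained sketch of that get-or-create pattern; Descriptor, GetMd, SetMd and AcquireSrcMd are illustrative stand-ins, not names from the Paddle code.

// Minimal sketch of the md-reuse pattern this commit introduces.
// "Descriptor" stands in for mkldnn::memory::desc; the real kernel keys the
// map with key + "@user_src_md", key + "@weights_md_key", etc.
#include <iostream>
#include <memory>
#include <string>
#include <unordered_map>

struct Descriptor {            // placeholder for mkldnn::memory::desc
  std::string format;
};

using MdMap = std::unordered_map<std::string, std::shared_ptr<Descriptor>>;

std::shared_ptr<Descriptor> GetMd(const MdMap& md_map, const std::string& name) {
  auto it = md_map.find(name);
  return it != md_map.end() ? it->second : nullptr;   // nullptr signals "not cached yet"
}

void SetMd(MdMap& md_map, const std::string& name, std::shared_ptr<Descriptor> md) {
  md_map[name] = std::move(md);                       // insert or overwrite
}

std::shared_ptr<Descriptor> AcquireSrcMd(const std::string& key) {
  static MdMap md_map;                                // persists across kernel invocations
  const auto md_key = key + "@user_src_md";
  if (auto md = GetMd(md_map, md_key)) {
    return md;                                        // later iterations: reuse
  }
  auto md = std::make_shared<Descriptor>(Descriptor{"nchw"});  // first iteration: build
  SetMd(md_map, md_key, md);
  return md;
}

int main() {
  auto first  = AcquireSrcMd("conv1");
  auto second = AcquireSrcMd("conv1");
  std::cout << (first == second) << "\n";             // prints 1: same cached descriptor
  return 0;
}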
@@ -376,7 +376,7 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
     static std::unordered_map<std::string, std::vector<float>> scale_map;
     //scale_map.insert({key_conv_pd,{1.0f}});
     //scale_map[key_conv_pd]={0.1f};
-    bool scale_reuse = false;
+    bool scale_reuse = true;
     auto scale_in_key = key + "@scale_in";
     auto scale_weights_key = key + "@scale_weights";
     auto scale_out_key = key + "@scale_out";
@@ -389,14 +389,14 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
     std::vector<float> scale_in_eltwise_data;
     std::vector<float> output_shift_scale;
     std::vector<float> sum_scale = {1.0f};
-    std::vector<float> none_scale = {0};
+    std::vector<float> none_scale = {0.0f};
     if (is_INT8 && GetScaleMap(scale_map, scale_in_key) == none_scale){
-      scale_reuse = true;
+      scale_reuse = false;
     }
     //std::cout<<"scale_reuse = "<<scale_reuse<<std::endl;
     if(is_INT8){
-      if(scale_reuse){
+      if(!scale_reuse){
         //std::cout<<"load scale!!!!!!!!"<<std::endl;
         int count = is_multi_channel? (g>1? weights_tz[1]*weights_tz[0] : weights_tz[0]) : 1;
         scale_in_data = {*(scale_in->data<float>())};
@@ -440,13 +440,31 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
     }
 
+    static std::unordered_map<std::string, std::shared_ptr<mkldnn::memory::desc>> md_map;
+    bool md_reuse = true;
+    auto user_src_md_key = key + "@user_src_md";
+    if (GetMdMap(md_map, user_src_md_key) == nullptr){
+      md_reuse = false;  //we suppose all mds are reused if the first md is in the map.
+    }
+    auto user_weights_md_key = key + "@user_weights_md";
+    std::shared_ptr<mkldnn::memory::desc> user_src_md;
+    std::shared_ptr<mkldnn::memory::desc> user_weights_md;
     std::vector<primitive> pipeline;
-    auto user_src_md = platform::MKLDNNMemDesc(
-        {src_tz}, paddle::framework::ToMKLDNNDataType(input->type()), input->format());
-    auto user_weights_md = platform::MKLDNNMemDesc(
-        {weights_tz}, platform::MKLDNNGetDataType<float>(),
-        (g == 1) ? mkldnn::memory::format::oihw : mkldnn::memory::format::goihw);
+    //std::cout<<"md_reuse = "<<md_reuse<<std::endl;
+    if(!md_reuse){
+      //std::cout<<"create md.......... "<<std::endl;
+      user_src_md.reset(new mkldnn::memory::desc(platform::MKLDNNMemDesc(
+          {src_tz}, paddle::framework::ToMKLDNNDataType(input->type()), input->format())));
+      user_weights_md.reset(new mkldnn::memory::desc(platform::MKLDNNMemDesc(
+          {weights_tz}, platform::MKLDNNGetDataType<float>(),
+          (g == 1) ? mkldnn::memory::format::oihw : mkldnn::memory::format::goihw)));
+      SetMdMap(md_map, user_src_md_key, user_src_md);
+      SetMdMap(md_map, user_weights_md_key, user_weights_md);
+    } else{
+      user_src_md = GetMdMap(md_map, user_src_md_key);
+      user_weights_md = GetMdMap(md_map, user_weights_md_key);
+    }
 
     /* create memory descriptor for convolution without specified format
      * ('any') which lets a primitive (convolution in this case) choose
@@ -458,53 +476,93 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
     std::shared_ptr<mkldnn::convolution_forward::primitive_desc> conv_pd;
     auto bias_tz = paddle::framework::vectorize2int(bias->dims());
 
+    auto src_md_key = key + "@src_md";
+    auto weights_md_key = key + "@weights_md_key";
+    auto dst_md_key = key + "@dst_md_key";
+    auto bias_md_key = key + "@bias_md_key";
+    std::shared_ptr<mkldnn::memory::desc> src_md;
+    std::shared_ptr<mkldnn::memory::desc> weights_md;
+    std::shared_ptr<mkldnn::memory::desc> dst_md;
     if(is_INT8){
-      auto src_md = platform::MKLDNNMemDesc(
-          src_tz, memory::data_type::u8, chosen_memory_format);
-      auto weights_md = platform::MKLDNNMemDesc(
-          weights_tz, memory::data_type::s8,
-          (g == 1) ? chosen_memory_format : mkldnn::memory::format::goihw);
-      auto dst_dt = fuse_relu? paddle::framework::ToMKLDNNDataType(std::type_index(typeid(unsigned char))) : paddle::framework::ToMKLDNNDataType(std::type_index(typeid(signed char)));
-      if(fuse_residual_conn){
-        auto residual = ctx.Input<Tensor>("ResidualData");
-        auto residual_dt = paddle::framework::ToMKLDNNDataType(residual->type());
-        if(dst_dt != residual_dt)
-          dst_dt = residual_dt;
+      if(!md_reuse){
+        src_md.reset(new mkldnn::memory::desc(platform::MKLDNNMemDesc(
+            src_tz, memory::data_type::u8, chosen_memory_format)));
+        weights_md.reset(new mkldnn::memory::desc(platform::MKLDNNMemDesc(
+            weights_tz, memory::data_type::s8,
+            (g == 1) ? chosen_memory_format : mkldnn::memory::format::goihw)));
+        auto dst_dt = fuse_relu? paddle::framework::ToMKLDNNDataType(std::type_index(typeid(unsigned char))) : paddle::framework::ToMKLDNNDataType(std::type_index(typeid(signed char)));
+        if(fuse_residual_conn){
+          auto residual = ctx.Input<Tensor>("ResidualData");
+          auto residual_dt = paddle::framework::ToMKLDNNDataType(residual->type());
+          if(dst_dt != residual_dt)
+            dst_dt = residual_dt;
+        }
+        dst_md.reset(new mkldnn::memory::desc(platform::MKLDNNMemDesc(dst_tz, dst_dt, chosen_memory_format)));
+        SetMdMap(md_map, src_md_key, src_md);
+        SetMdMap(md_map, weights_md_key, weights_md);
+        SetMdMap(md_map, dst_md_key, dst_md);
+      } else{
+        src_md = GetMdMap(md_map, src_md_key);
+        weights_md = GetMdMap(md_map, weights_md_key);
+        dst_md = GetMdMap(md_map, dst_md_key);
       }
-      auto dst_md = platform::MKLDNNMemDesc(dst_tz, dst_dt, chosen_memory_format);
 
       // create a conv primitive descriptor and save it for usage in backward
       if (bias) {
-        auto bias_md = platform::MKLDNNMemDesc(
-            bias_tz, memory::data_type::s32, memory::format::x);
-        conv_pd = ConvFwdPrimitiveDesc(src_md, weights_md, bias_md, dst_md,
+        std::shared_ptr<mkldnn::memory::desc> bias_md;
+        if(!md_reuse){
+          bias_md.reset(new mkldnn::memory::desc(platform::MKLDNNMemDesc(
+              bias_tz, memory::data_type::s32, memory::format::x)));
+          SetMdMap(md_map, bias_md_key, bias_md);
+        } else{
+          bias_md = GetMdMap(md_map, bias_md_key);
+        }
+        conv_pd = ConvFwdPrimitiveDesc(*src_md, *weights_md, *bias_md, *dst_md,
                                        strides, paddings, mkldnn_engine,
                                        fuse_relu, fuse_residual_conn,
                                        output_shift_scale, sum_scale[0], is_test);
       } else {
         conv_pd =
-            ConvFwdPrimitiveDesc(src_md, weights_md, dst_md, strides, paddings,
+            ConvFwdPrimitiveDesc(*src_md, *weights_md, *dst_md, strides, paddings,
                                  mkldnn_engine, fuse_relu, fuse_residual_conn,
                                  output_shift_scale, sum_scale[0], is_test);
       }
     } else{
-      auto src_md = platform::MKLDNNMemDesc(
-          src_tz, platform::MKLDNNGetDataType<float>(), chosen_memory_format);
-      auto weights_md = platform::MKLDNNMemDesc(
-          weights_tz, platform::MKLDNNGetDataType<float>(),
-          (g == 1) ? chosen_memory_format : mkldnn::memory::format::goihw);
-      auto dst_md = platform::MKLDNNMemDesc(
-          dst_tz, platform::MKLDNNGetDataType<float>(), chosen_memory_format);
+      if(!md_reuse){
+        src_md.reset(new mkldnn::memory::desc(platform::MKLDNNMemDesc(
+            src_tz, platform::MKLDNNGetDataType<float>(), chosen_memory_format)));
+        weights_md.reset(new mkldnn::memory::desc(platform::MKLDNNMemDesc(
+            weights_tz, platform::MKLDNNGetDataType<float>(),
+            (g == 1) ? chosen_memory_format : mkldnn::memory::format::goihw)));
+        dst_md.reset(new mkldnn::memory::desc(platform::MKLDNNMemDesc(
+            dst_tz, platform::MKLDNNGetDataType<float>(), chosen_memory_format)));
+        SetMdMap(md_map, src_md_key, src_md);
+        SetMdMap(md_map, weights_md_key, weights_md);
+        SetMdMap(md_map, dst_md_key, dst_md);
+      } else{
+        src_md = GetMdMap(md_map, src_md_key);
+        weights_md = GetMdMap(md_map, weights_md_key);
+        dst_md = GetMdMap(md_map, dst_md_key);
+      }
 
       // create a conv primitive descriptor and save it for usage in backward
       if (bias) {
-        auto bias_md = platform::MKLDNNMemDesc(
-            bias_tz, platform::MKLDNNGetDataType<float>(), memory::format::x);
-        conv_pd = ConvFwdPrimitiveDesc(src_md, weights_md, bias_md, dst_md,
-                                       strides, paddings, mkldnn_engine,
-                                       fuse_relu, fuse_residual_conn, is_test);
+        std::shared_ptr<mkldnn::memory::desc> bias_md;
+        if(!md_reuse){
+          bias_md.reset(new mkldnn::memory::desc(platform::MKLDNNMemDesc(
+              bias_tz, platform::MKLDNNGetDataType<float>(), memory::format::x)));
+          SetMdMap(md_map, bias_md_key, bias_md);
+        } else{
+          bias_md = GetMdMap(md_map, bias_md_key);
+        }
+        conv_pd = ConvFwdPrimitiveDesc(*src_md, *weights_md, *bias_md, *dst_md,
+                                       strides, paddings, mkldnn_engine,
+                                       fuse_relu, fuse_residual_conn, is_test);
       } else {
         conv_pd =
-            ConvFwdPrimitiveDesc(src_md, weights_md, dst_md, strides, paddings,
+            ConvFwdPrimitiveDesc(*src_md, *weights_md, *dst_md, strides, paddings,
                                  mkldnn_engine, fuse_relu, fuse_residual_conn, is_test);
       }
     }
@@ -515,9 +573,9 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
     // create mkldnn memory from input tensors (data/weights)
     auto user_src_memory_p =
-        handler.AcquireSrcMemory(user_src_md, to_void_cast<T>(input_data));
+        handler.AcquireSrcMemory(*user_src_md, to_void_cast<T>(input_data));
     auto user_weights_memory_p = handler.AcquireWeightsMemory(
-        user_weights_md, to_void_cast<float>(filter_data));
+        *user_weights_md, to_void_cast<float>(filter_data));
 
     // create reorder primitive if the input format is not the preferred one
     auto src_memory_p =
@@ -535,6 +593,7 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
     std::shared_ptr<mkldnn::memory> dst_memory_p;
     bool need_s8_to_u8 = false;
+    auto user_residual_md_key = key + "@user_residual_md";
     if(fuse_residual_conn) {
       auto residual_param = ctx.Input<Tensor>("ResidualData");
       PADDLE_ENFORCE_EQ(output->dims(), residual_param->dims(),
@@ -542,42 +601,48 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
                         "same dimension sizes");
       auto residual_dt = paddle::framework::ToMKLDNNDataType(residual_param->type());
       if(residual_param->format() != handler.GetDstFormat()) {
-        auto residual_data_tz =
-            paddle::framework::vectorize2int(residual_param->dims());
-        auto residual_data_type =
-            paddle::framework::ToMKLDNNDataType(residual_param->type());
-        auto user_residual_md = platform::MKLDNNMemDesc(
-            residual_data_tz, residual_data_type, residual_param->format());
+        std::shared_ptr<mkldnn::memory::desc> user_residual_md;
+        if(!md_reuse){
+          auto residual_data_tz =
+              paddle::framework::vectorize2int(residual_param->dims());
+          auto residual_data_type =
+              paddle::framework::ToMKLDNNDataType(residual_param->type());
+          user_residual_md.reset(new mkldnn::memory::desc(platform::MKLDNNMemDesc(
+              residual_data_tz, residual_data_type, residual_param->format())));
+          SetMdMap(md_map, user_residual_md_key, user_residual_md);
+        } else{
+          user_residual_md = GetMdMap(md_map, user_residual_md_key);
+        }
         if(is_INT8){
           if(residual_dt == mkldnn::memory::data_type::u8){
             auto residual_param_data = residual_param->data<uint8_t>();
             auto user_residual_memory_p = handler.AcquireResidualDataMemory(
-                user_residual_md, to_void_cast<uint8_t>(residual_param_data));
+                *user_residual_md, to_void_cast<uint8_t>(residual_param_data));
             PADDLE_ENFORCE(
                 residual_param_data != nullptr,
                 "Provide data if you want MKLDNN conv+elementwise_add fusion");
             uint8_t* output_data = output->mutable_data<uint8_t>(ctx.GetPlace());
             dst_memory_p =
                 handler.AcquireDstMemoryFromResidualDataMemory(
                     user_residual_memory_p, to_void_cast<uint8_t>(output_data), pipeline);
           } else{
             auto residual_param_data = residual_param->data<int8_t>();
             auto user_residual_memory_p = handler.AcquireResidualDataMemory(
-                user_residual_md, to_void_cast<int8_t>(residual_param_data));
+                *user_residual_md, to_void_cast<int8_t>(residual_param_data));
             PADDLE_ENFORCE(
                 residual_param_data != nullptr,
                 "Provide data if you want MKLDNN conv+elementwise_add fusion");
             int8_t* output_data = output->mutable_data<int8_t>(ctx.GetPlace());
             dst_memory_p =
                 handler.AcquireDstMemoryFromResidualDataMemory(
                     user_residual_memory_p, to_void_cast<int8_t>(output_data), pipeline);
             if(fuse_relu)
               need_s8_to_u8 = true;
           }
         } else{
           auto residual_param_data = residual_param->data<T>();
           auto user_residual_memory_p = handler.AcquireResidualDataMemory(
-              user_residual_md, to_void_cast<T>(residual_param_data));
+              *user_residual_md, to_void_cast<T>(residual_param_data));
           PADDLE_ENFORCE(
               residual_param_data != nullptr,
               "Provide data if you want MKLDNN conv+elementwise_add fusion");
@@ -630,16 +695,23 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
     std::shared_ptr<mkldnn::convolution_forward> conv_p;
     std::vector<float> scale_bias_data;
     auto scale_bias_key = key + "@scale_bias";
+    auto user_bias_md_key = key + "@user_bias_md";
     if (bias) {
       const float* bias_data = bias->data<float>();
-      auto user_bias_md = platform::MKLDNNMemDesc(
-          {bias_tz}, platform::MKLDNNGetDataType<float>(), memory::format::x);
+      std::shared_ptr<mkldnn::memory::desc> user_bias_md;
+      if(!md_reuse){
+        user_bias_md.reset(new mkldnn::memory::desc(platform::MKLDNNMemDesc(
+            {bias_tz}, platform::MKLDNNGetDataType<float>(), memory::format::x)));
+        SetMdMap(md_map, user_bias_md_key, user_bias_md);
+      } else{
+        user_bias_md = GetMdMap(md_map, user_bias_md_key);
+      }
       auto user_bias_memory_p =
-          handler.AcquireBiasMemory(user_bias_md, to_void_cast<float>(bias_data));
+          handler.AcquireBiasMemory(*user_bias_md, to_void_cast<float>(bias_data));
       std::shared_ptr<mkldnn::memory> bias_memory_p;
       if(is_INT8){
         int mask_reorder = is_multi_channel? 1<<0 : 1;
-        if(scale_reuse){
+        if(!scale_reuse){
           int count = is_multi_channel? (g>1? weights_tz[1]*weights_tz[0] : weights_tz[0]) : 1;
           scale_bias_data.resize(count);
           #pragma omp parallel for if (count > 1)
@@ -689,13 +761,33 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
     return;
   }
 
-  std::vector<float> GetScaleMap(std::unordered_map<std::string, std::vector<float>> &scale_map,
+  std::vector<float> GetScaleMap(std::unordered_map<std::string, std::vector<float>> scale_map,
                                  const std::string& name) const {
     auto it = scale_map.find(name);
     if (it != scale_map.end()) {
       return (*it).second;
     }
-    return {0};
+    return {0.0f};
   }
 
+  void SetMdMap(std::unordered_map<std::string, std::shared_ptr<mkldnn::memory::desc>> &md_map,
+                const std::string& name, std::shared_ptr<mkldnn::memory::desc> md) const {
+    auto it = md_map.find(name);
+    if (it == md_map.end()) {
+      md_map[name] = md;  // create new blob
+    } else {
+      (*it).second = md;  // set data to existing blob
+    }
+    return;
+  }
+
+  std::shared_ptr<mkldnn::memory::desc> GetMdMap(std::unordered_map<std::string, std::shared_ptr<mkldnn::memory::desc>> md_map,
+                                                 const std::string& name) const {
+    auto it = md_map.find(name);
+    if (it != md_map.end()) {
+      return (*it).second;
+    }
+    return nullptr;
+  }
+
   mkldnn::primitive_attr CreatePostOps(bool fuse_relu, bool fuse_residual_conn,
......
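The INT8 scale handling in the hunks above follows the same caching idea: GetScaleMap returns the sentinel {0.0f} when a key has not been stored yet, and that miss flips scale_reuse to false so the scales are computed once and then reloaded on later iterations. A small sketch of that sentinel check follows; the 0.0234f value and the "conv1" key are placeholders, not real quantization data.

// Sketch of the scale-reuse sentinel check used in the INT8 path.
#include <string>
#include <unordered_map>
#include <vector>

std::vector<float> GetScale(
    const std::unordered_map<std::string, std::vector<float>>& scale_map,
    const std::string& name) {
  auto it = scale_map.find(name);
  return it != scale_map.end() ? it->second : std::vector<float>{0.0f};  // sentinel on miss
}

int main() {
  std::unordered_map<std::string, std::vector<float>> scale_map;
  const std::string scale_in_key = "conv1@scale_in";
  const std::vector<float> none_scale = {0.0f};

  bool scale_reuse = true;
  if (GetScale(scale_map, scale_in_key) == none_scale) {
    scale_reuse = false;                       // first pass: nothing cached yet
  }
  if (!scale_reuse) {
    scale_map[scale_in_key] = {0.0234f};       // compute and store the scale once
  }
  return 0;
}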