提交 ab89c546 编写于 作者: X xiaolil1

Enable both FP32 and INT8 initialization

上级 ce7add88
...@@ -369,10 +369,68 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> { ...@@ -369,10 +369,68 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
ctx.op().Output("Output")); ctx.op().Output("Output"));
const std::string key_conv_pd = key + "@conv_pd"; const std::string key_conv_pd = key + "@conv_pd";
bool is_INT8 = ctx.HasInput("Scale_in")? true : false;
bool need_s8_to_u8 = false;
if (fuse_residual_conn && is_INT8 && fuse_relu) {
need_s8_to_u8 = true;
}
std::shared_ptr<mkldnn::convolution_forward> conv_p;
std::shared_ptr<mkldnn::memory> src_memory_p;
std::shared_ptr<mkldnn::memory> dst_memory_p;
std::vector<primitive> pipeline; std::vector<primitive> pipeline;
bool is_INT8 = ctx.HasInput("Scale_in")? true : false; auto prim_key = key + "@conv_p";
auto dst_key = key + "@dst_mem_p";
auto src_key = key + "@src_mem_p";
conv_p = std::static_pointer_cast<mkldnn::convolution_forward>(dev_ctx.GetBlob(prim_key));
src_memory_p = std::static_pointer_cast<mkldnn::memory>(dev_ctx.GetBlob(src_key));
dst_memory_p = std::static_pointer_cast<mkldnn::memory>(dev_ctx.GetBlob(dst_key));
if (src_memory_p) {
src_memory_p->set_data_handle(to_void_cast<T>(input_data));
}
std::shared_ptr<mkldnn::convolution_forward::primitive_desc> conv_pd;
conv_pd = std::static_pointer_cast<mkldnn::convolution_forward::primitive_desc>(dev_ctx.GetBlob(key_conv_pd));
std::shared_ptr<ConvMKLDNNHandler> handler;
if(conv_pd){
handler.reset(new ConvMKLDNNHandler(conv_pd, dev_ctx, mkldnn_engine, key));
}
if (!is_INT8 && dst_memory_p){
if (fuse_residual_conn) {
auto residual_param = ctx.Input<Tensor>("ResidualData");
auto residual_param_data = residual_param->data<T>();
if (residual_param->format() != handler->GetDstFormat()) {
auto output_data =
output->mutable_data<T>(ctx.GetPlace(), ::paddle::memory::Allocator::kDefault, handler->GetDstMemorySize());
auto residual_data_tz =
paddle::framework::vectorize2int(residual_param->dims());
auto residual_data_type =
paddle::framework::ToMKLDNNDataType(residual_param->type());
auto user_residual_md = platform::MKLDNNMemDesc(
residual_data_tz, residual_data_type, residual_param->format());
auto user_residual_memory_p = handler->AcquireResidualDataMemory(
user_residual_md, to_void_cast<T>(residual_param_data));
dst_memory_p = handler->AcquireDstMemoryFromResidualDataMemory(
user_residual_memory_p, to_void_cast<T>(output_data), pipeline);
} else {
output->ShareDataWith(*residual_param);
auto output_data = output->mutable_data<T>(ctx.GetPlace());
dst_memory_p->set_data_handle(to_void_cast<T>(output_data));
}
} else {
auto output_data =
output->mutable_data<T>(ctx.GetPlace(), ::paddle::memory::Allocator::kDefault, handler->GetDstMemorySize());
dst_memory_p->set_data_handle(to_void_cast<T>(output_data));
}
}
if(!is_INT8){ if(!is_INT8){
if(conv_p == nullptr){
auto user_src_md = platform::MKLDNNMemDesc( auto user_src_md = platform::MKLDNNMemDesc(
{src_tz}, platform::MKLDNNGetDataType<T>(), input->format()); {src_tz}, platform::MKLDNNGetDataType<T>(), input->format());
auto user_weights_md = platform::MKLDNNMemDesc( auto user_weights_md = platform::MKLDNNMemDesc(
...@@ -398,8 +456,6 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> { ...@@ -398,8 +456,6 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
dst_tz, platform::MKLDNNGetDataType<T>(), chosen_memory_format); dst_tz, platform::MKLDNNGetDataType<T>(), chosen_memory_format);
// create a conv primitive descriptor and save it for usage in backward // create a conv primitive descriptor and save it for usage in backward
std::shared_ptr<mkldnn::convolution_forward::primitive_desc> conv_pd;
if (bias) { if (bias) {
bias_tz = paddle::framework::vectorize2int(bias->dims()); bias_tz = paddle::framework::vectorize2int(bias->dims());
auto bias_md = platform::MKLDNNMemDesc( auto bias_md = platform::MKLDNNMemDesc(
...@@ -415,22 +471,20 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> { ...@@ -415,22 +471,20 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
// Save conv_pd/src_memory/weights_memory for backward pass // Save conv_pd/src_memory/weights_memory for backward pass
dev_ctx.SetBlob(key_conv_pd, conv_pd); dev_ctx.SetBlob(key_conv_pd, conv_pd);
ConvMKLDNNHandler handler(conv_pd, dev_ctx, mkldnn_engine, key); handler.reset(new ConvMKLDNNHandler(conv_pd, dev_ctx, mkldnn_engine, key));
// create mkldnn memory from input tensors (data/weights) // create mkldnn memory from input tensors (data/weights)
auto user_src_memory_p = auto user_src_memory_p =
handler.AcquireSrcMemory(user_src_md, to_void_cast<T>(input_data)); handler->AcquireSrcMemory(user_src_md, to_void_cast<T>(input_data));
auto user_weights_memory_p = handler.AcquireWeightsMemory( auto user_weights_memory_p = handler->AcquireWeightsMemory(
user_weights_md, to_void_cast<float>(filter_data)); user_weights_md, to_void_cast<float>(filter_data));
// create reorder primitive if the input format is not the preferred one // create reorder primitive if the input format is not the preferred one
auto src_memory_p = src_memory_p =
handler.AcquireSrcMemoryFromPrimitive(user_src_memory_p, pipeline); handler->AcquireSrcMemoryFromPrimitive(user_src_memory_p, pipeline);
auto weights_memory_p = handler.AcquireWeightsMemoryFromPrimitive( auto weights_memory_p = handler->AcquireWeightsMemoryFromPrimitive(
user_weights_memory_p, pipeline, is_test); user_weights_memory_p, pipeline, is_test);
std::shared_ptr<mkldnn::memory> dst_memory_p;
if (fuse_residual_conn) { if (fuse_residual_conn) {
auto residual_param = ctx.Input<Tensor>("ResidualData"); auto residual_param = ctx.Input<Tensor>("ResidualData");
auto residual_param_data = residual_param->data<T>(); auto residual_param_data = residual_param->data<T>();
...@@ -442,9 +496,9 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> { ...@@ -442,9 +496,9 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
"Output and elementwise parameter need to have the " "Output and elementwise parameter need to have the "
"same dimension sizes"); "same dimension sizes");
if (residual_param->format() != handler.GetDstFormat()) { if (residual_param->format() != handler->GetDstFormat()) {
auto output_data = auto output_data =
output->mutable_data<T>(ctx.GetPlace(), ::paddle::memory::Allocator::kDefault, handler.GetDstMemorySize()); output->mutable_data<T>(ctx.GetPlace(), ::paddle::memory::Allocator::kDefault, handler->GetDstMemorySize());
auto residual_data_tz = auto residual_data_tz =
paddle::framework::vectorize2int(residual_param->dims()); paddle::framework::vectorize2int(residual_param->dims());
auto residual_data_type = auto residual_data_type =
...@@ -452,70 +506,54 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> { ...@@ -452,70 +506,54 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
auto user_residual_md = platform::MKLDNNMemDesc( auto user_residual_md = platform::MKLDNNMemDesc(
residual_data_tz, residual_data_type, residual_param->format()); residual_data_tz, residual_data_type, residual_param->format());
auto user_residual_memory_p = handler.AcquireResidualDataMemory( auto user_residual_memory_p = handler->AcquireResidualDataMemory(
user_residual_md, to_void_cast<T>(residual_param_data)); user_residual_md, to_void_cast<T>(residual_param_data));
dst_memory_p = handler.AcquireDstMemoryFromResidualDataMemory( dst_memory_p = handler->AcquireDstMemoryFromResidualDataMemory(
user_residual_memory_p, to_void_cast<T>(output_data), pipeline); user_residual_memory_p, to_void_cast<T>(output_data), pipeline);
} else { } else {
output->ShareDataWith(*residual_param); output->ShareDataWith(*residual_param);
auto output_data = output->mutable_data<T>(ctx.GetPlace()); auto output_data = output->mutable_data<T>(ctx.GetPlace());
dst_memory_p = dst_memory_p =
handler.AcquireDstMemoryFromPrimitive(to_void_cast<T>(output_data)); handler->AcquireDstMemoryFromPrimitive(to_void_cast<T>(output_data));
} }
} else { } else {
auto output_data = auto output_data =
output->mutable_data<T>(ctx.GetPlace(), ::paddle::memory::Allocator::kDefault, handler.GetDstMemorySize()); output->mutable_data<T>(ctx.GetPlace(), ::paddle::memory::Allocator::kDefault, handler->GetDstMemorySize());
dst_memory_p = dst_memory_p =
handler.AcquireDstMemoryFromPrimitive(to_void_cast<T>(output_data)); handler->AcquireDstMemoryFromPrimitive(to_void_cast<T>(output_data));
} }
// create convolution op primitive // create convolution op primitive
std::shared_ptr<mkldnn::convolution_forward> conv_p;
if (bias) { if (bias) {
const T* bias_data = bias->data<T>(); const T* bias_data = bias->data<T>();
auto user_bias_md = platform::MKLDNNMemDesc( auto user_bias_md = platform::MKLDNNMemDesc(
{bias_tz}, platform::MKLDNNGetDataType<T>(), memory::format::x); {bias_tz}, platform::MKLDNNGetDataType<T>(), memory::format::x);
auto user_bias_memory_p = auto user_bias_memory_p =
handler.AcquireBiasMemory(user_bias_md, to_void_cast<T>(bias_data)); handler->AcquireBiasMemory(user_bias_md, to_void_cast<T>(bias_data));
auto bias_memory_p = auto bias_memory_p =
handler.AcquireBiasMemoryFromPrimitive(user_bias_memory_p, pipeline); handler->AcquireBiasMemoryFromPrimitive(user_bias_memory_p, pipeline, is_test);
conv_p = handler.AcquireConvolution(src_memory_p, weights_memory_p, conv_p = handler->AcquireConvolution(src_memory_p, weights_memory_p,
bias_memory_p, dst_memory_p); bias_memory_p, dst_memory_p);
} else { } else {
conv_p = handler.AcquireConvolution(src_memory_p, weights_memory_p, conv_p = handler->AcquireConvolution(src_memory_p, weights_memory_p,
dst_memory_p); dst_memory_p);
} }
// push primitive to stream and wait until it's executed // push primitive to stream and wait until it's executed
pipeline.push_back(*conv_p); pipeline.push_back(*conv_p);
stream(stream::kind::eager).submit(pipeline).wait(); stream(stream::kind::eager).submit(pipeline).wait();
output->set_layout(DataLayout::kMKLDNN); output->set_layout(DataLayout::kMKLDNN);
output->set_format(GetMKLDNNFormat(*dst_memory_p)); output->set_format(GetMKLDNNFormat(*dst_memory_p));
} else {
pipeline.push_back(*conv_p);
stream(stream::kind::eager).submit(pipeline).wait();
} else{ output->set_layout(DataLayout::kMKLDNN);
bool need_s8_to_u8 = false; output->set_format(GetMKLDNNFormat(*dst_memory_p));
if (fuse_residual_conn && fuse_relu) {
need_s8_to_u8 = true;
}
std::shared_ptr<mkldnn::convolution_forward> conv_p;
std::shared_ptr<mkldnn::memory> src_memory_p;
std::shared_ptr<mkldnn::memory> dst_memory_p;
std::vector<primitive> pipeline;
auto prim_key = key + "@conv_p";
auto dst_key = key + "@dst_mem_p";
auto src_key = key + "@src_mem_p";
conv_p = std::static_pointer_cast<mkldnn::convolution_forward>(dev_ctx.GetBlob(prim_key));
src_memory_p = std::static_pointer_cast<mkldnn::memory>(dev_ctx.GetBlob(src_key));
dst_memory_p = std::static_pointer_cast<mkldnn::memory>(dev_ctx.GetBlob(dst_key));
if (src_memory_p) {
src_memory_p->set_data_handle(to_void_cast<T>(input_data));
} }
} else{
if(conv_p == nullptr){ if(conv_p == nullptr){
auto* scale_in = ctx.HasInput("Scale_in") ? ctx.Input<Tensor>("Scale_in") : nullptr; auto* scale_in = ctx.HasInput("Scale_in") ? ctx.Input<Tensor>("Scale_in") : nullptr;
auto* scale_in_eltwise = ctx.HasInput("Scale_in_eltwise")? ctx.Input<Tensor>("Scale_in_eltwise") : nullptr; auto* scale_in_eltwise = ctx.HasInput("Scale_in_eltwise")? ctx.Input<Tensor>("Scale_in_eltwise") : nullptr;
...@@ -621,8 +659,6 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> { ...@@ -621,8 +659,6 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
auto dst_md = platform::MKLDNNMemDesc(dst_tz, dst_dt, chosen_memory_format); auto dst_md = platform::MKLDNNMemDesc(dst_tz, dst_dt, chosen_memory_format);
// create a conv primitive descriptor and save it for usage in backward // create a conv primitive descriptor and save it for usage in backward
std::shared_ptr<mkldnn::convolution_forward::primitive_desc> conv_pd;
if (bias) { if (bias) {
auto bias_md = platform::MKLDNNMemDesc( auto bias_md = platform::MKLDNNMemDesc(
bias_tz, memory::data_type::s32, memory::format::x); bias_tz, memory::data_type::s32, memory::format::x);
...@@ -639,21 +675,21 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> { ...@@ -639,21 +675,21 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
// Save conv_pd/src_memory/weights_memory for backward pass // Save conv_pd/src_memory/weights_memory for backward pass
dev_ctx.SetBlob(key_conv_pd, conv_pd); dev_ctx.SetBlob(key_conv_pd, conv_pd);
ConvMKLDNNHandler handler(conv_pd, dev_ctx, mkldnn_engine, key); handler.reset(new ConvMKLDNNHandler(conv_pd, dev_ctx, mkldnn_engine, key));
// create mkldnn memory from input tensors (data/weights) // create mkldnn memory from input tensors (data/weights)
auto user_src_memory_p = auto user_src_memory_p =
handler.AcquireSrcMemory(user_src_md, to_void_cast<T>(input_data)); handler->AcquireSrcMemory(user_src_md, to_void_cast<T>(input_data));
auto user_weights_memory_p = handler.AcquireWeightsMemory( auto user_weights_memory_p = handler->AcquireWeightsMemory(
user_weights_md, to_void_cast<float>(filter_data)); user_weights_md, to_void_cast<float>(filter_data));
// create reorder primitive if the input format is not the preferred one // create reorder primitive if the input format is not the preferred one
src_memory_p = src_memory_p =
handler.AcquireSrcMemoryFromPrimitive(user_src_memory_p, pipeline); handler->AcquireSrcMemoryFromPrimitive(user_src_memory_p, pipeline);
std::shared_ptr<mkldnn::memory> weights_memory_p; std::shared_ptr<mkldnn::memory> weights_memory_p;
int mask_reorder = is_multi_channel? ((g!= 1) ? (1<<1)+(1<<0) : 1<<0) : 0; int mask_reorder = is_multi_channel? ((g!= 1) ? (1<<1)+(1<<0) : 1<<0) : 0;
weights_memory_p = handler.AcquireWeightsMemoryFromPrimitive( weights_memory_p = handler->AcquireWeightsMemoryFromPrimitive(
user_weights_memory_p, pipeline, is_test, is_INT8, scale_weights_data, mask_reorder); user_weights_memory_p, pipeline, is_test, is_INT8, scale_weights_data, mask_reorder);
if(fuse_residual_conn) { if(fuse_residual_conn) {
...@@ -662,27 +698,27 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> { ...@@ -662,27 +698,27 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
"Output and elementwise parameter need to have the " "Output and elementwise parameter need to have the "
"same dimension sizes"); "same dimension sizes");
auto residual_dt = paddle::framework::ToMKLDNNDataType(residual_param->type()); auto residual_dt = paddle::framework::ToMKLDNNDataType(residual_param->type());
PADDLE_ENFORCE_EQ(residual_param->format(), handler.GetDstFormat(), PADDLE_ENFORCE_EQ(residual_param->format(), handler->GetDstFormat(),
"Conv input dimension and filter dimension should be the same."); "Conv input dimension and filter dimension should be the same.");
output->ShareDataWith(*residual_param); output->ShareDataWith(*residual_param);
if(residual_dt == mkldnn::memory::data_type::u8){ if(residual_dt == mkldnn::memory::data_type::u8){
uint8_t* output_data = output->mutable_data<uint8_t>(ctx.GetPlace()); uint8_t* output_data = output->mutable_data<uint8_t>(ctx.GetPlace());
dst_memory_p = dst_memory_p =
handler.AcquireDstMemoryFromPrimitive(to_void_cast<uint8_t>(output_data)); handler->AcquireDstMemoryFromPrimitive(to_void_cast<uint8_t>(output_data));
} else{ } else{
int8_t* output_data = output->mutable_data<int8_t>(ctx.GetPlace()); int8_t* output_data = output->mutable_data<int8_t>(ctx.GetPlace());
dst_memory_p = dst_memory_p =
handler.AcquireDstMemoryFromPrimitive(to_void_cast<int8_t>(output_data)); handler->AcquireDstMemoryFromPrimitive(to_void_cast<int8_t>(output_data));
} }
} else { } else {
if(fuse_relu){ if(fuse_relu){
uint8_t* output_data = output->mutable_data<uint8_t>(ctx.GetPlace(), ::paddle::memory::Allocator::kDefault, handler.GetDstMemorySize()); uint8_t* output_data = output->mutable_data<uint8_t>(ctx.GetPlace(), ::paddle::memory::Allocator::kDefault, handler->GetDstMemorySize());
dst_memory_p = dst_memory_p =
handler.AcquireDstMemoryFromPrimitive(to_void_cast<uint8_t>(output_data)); handler->AcquireDstMemoryFromPrimitive(to_void_cast<uint8_t>(output_data));
} else{ } else{
int8_t* output_data = output->mutable_data<int8_t>(ctx.GetPlace(), ::paddle::memory::Allocator::kDefault, handler.GetDstMemorySize()); int8_t* output_data = output->mutable_data<int8_t>(ctx.GetPlace(), ::paddle::memory::Allocator::kDefault, handler->GetDstMemorySize());
dst_memory_p = dst_memory_p =
handler.AcquireDstMemoryFromPrimitive(to_void_cast<int8_t>(output_data)); handler->AcquireDstMemoryFromPrimitive(to_void_cast<int8_t>(output_data));
} }
} }
...@@ -694,7 +730,7 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> { ...@@ -694,7 +730,7 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
auto user_bias_md = platform::MKLDNNMemDesc( auto user_bias_md = platform::MKLDNNMemDesc(
{bias_tz}, platform::MKLDNNGetDataType<float>(), memory::format::x); {bias_tz}, platform::MKLDNNGetDataType<float>(), memory::format::x);
auto user_bias_memory_p = auto user_bias_memory_p =
handler.AcquireBiasMemory(user_bias_md, to_void_cast<float>(bias_data)); handler->AcquireBiasMemory(user_bias_md, to_void_cast<float>(bias_data));
std::shared_ptr<mkldnn::memory> bias_memory_p; std::shared_ptr<mkldnn::memory> bias_memory_p;
int mask_reorder = is_multi_channel? 1<<0 : 1; int mask_reorder = is_multi_channel? 1<<0 : 1;
if(!scale_reuse){ if(!scale_reuse){
...@@ -709,11 +745,11 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> { ...@@ -709,11 +745,11 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
scale_bias_data = GetScaleMap(scale_map, scale_bias_key); scale_bias_data = GetScaleMap(scale_map, scale_bias_key);
} }
bias_memory_p = bias_memory_p =
handler.AcquireBiasMemoryFromPrimitive(user_bias_memory_p, pipeline, is_test, is_INT8, scale_bias_data, mask_reorder); handler->AcquireBiasMemoryFromPrimitive(user_bias_memory_p, pipeline, is_test, is_INT8, scale_bias_data, mask_reorder);
conv_p = handler.AcquireConvolution(src_memory_p, weights_memory_p, conv_p = handler->AcquireConvolution(src_memory_p, weights_memory_p,
bias_memory_p, dst_memory_p); bias_memory_p, dst_memory_p);
} else { } else {
conv_p = handler.AcquireConvolution(src_memory_p, weights_memory_p, conv_p = handler->AcquireConvolution(src_memory_p, weights_memory_p,
dst_memory_p); dst_memory_p);
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册