diff --git a/paddle/fluid/operators/conv_mkldnn_op.cc b/paddle/fluid/operators/conv_mkldnn_op.cc
index 5fff2fa84b81ead59e98687ddb2681f87db39011..be5b1e2b32774ab9ad6f6a498a1e028fe8fab33d 100644
--- a/paddle/fluid/operators/conv_mkldnn_op.cc
+++ b/paddle/fluid/operators/conv_mkldnn_op.cc
@@ -334,7 +334,13 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
     std::vector<int> dilations = ctx.Attr<std::vector<int>>("dilations");
     bool fuse_relu = ctx.Attr<bool>("fuse_relu");
     bool fuse_residual_conn = ctx.Attr<bool>("fuse_residual_connection");
+    bool force_fp32_output = ctx.Attr<bool>("force_fp32_output");
     int groups = ctx.Attr<int>("groups");
+//std::cout<<"force_fp32_output = "<<force_fp32_output<<std::endl;
@@ ... @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
         src_tz, weights_tz, strides, paddings, dilations, groups,
         ctx.op().Output("Output"));
     const std::string key_conv_pd = key + "@conv_pd";
-
+//std::cout<<...
@@ ... @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
         int8_t* output_data = output->mutable_data<int8_t>(ctx.GetPlace());
         dst_memory_p->set_data_handle(to_void_cast<int8_t>(output_data));
       }
-    } else {
+    } else if(!force_fp32_output){
       if(fuse_relu){
         uint8_t* output_data = output->mutable_data<uint8_t>(ctx.GetPlace(), ::paddle::memory::Allocator::kDefault, handler->GetDstMemorySize());
         dst_memory_p->set_data_handle(to_void_cast<uint8_t>(output_data));
@@ -452,6 +458,9 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
         int8_t* output_data = output->mutable_data<int8_t>(ctx.GetPlace(), ::paddle::memory::Allocator::kDefault, handler->GetDstMemorySize());
         dst_memory_p->set_data_handle(to_void_cast<int8_t>(output_data));
       }
+    } else {
+      float* output_data = output->mutable_data<float>(ctx.GetPlace(), ::paddle::memory::Allocator::kDefault, handler->GetDstMemorySize());
+      dst_memory_p->set_data_handle(to_void_cast<float>(output_data));
     }
   }
@@ -600,7 +609,7 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
     auto sum_scale_key = key + "@sum_scale";
     auto scale_in_eltwise_key = key + "@scale_in_eltwise";
     std::vector<float> scale_in_data;
-    std::vector<float> scale_out_data;
+    std::vector<float> scale_out_data = {1.0f};
     std::vector<float> scale_weights_data;
     std::vector<float> scale_in_eltwise_data;
     std::vector<float> output_shift_scale;
@@ -619,7 +628,8 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
     for(int i=0; i<count; i++){
       scale_weights_data[i] = *(scale_weights->data<float>() + i);
     }
-    scale_out_data = {*(scale_out->data<float>())};
+    if(!force_fp32_output)
+      scale_out_data = {*(scale_out->data<float>())};
     output_shift_scale.resize(count);
     #pragma omp parallel for if (count > 1)
     for(int i=0; i<count; i++){
@@ ... @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
     auto dst_dt = fuse_relu ?
         paddle::framework::ToMKLDNNDataType(std::type_index(typeid(unsigned char))) :
         paddle::framework::ToMKLDNNDataType(std::type_index(typeid(signed char)));
+    if(force_fp32_output){
+      dst_dt = paddle::framework::ToMKLDNNDataType(std::type_index(typeid(float)));
+    }
+
     if(fuse_residual_conn){
       auto residual = ctx.Input<Tensor>("ResidualData");
       auto residual_dt = paddle::framework::ToMKLDNNDataType(residual->type());
@@ -738,7 +752,7 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
         dst_memory_p =
             handler->AcquireDstMemoryFromPrimitive(to_void_cast<int8_t>(output_data));
       }
-    } else {
+    } else if(!force_fp32_output){
       if(fuse_relu){
         uint8_t* output_data = output->mutable_data<uint8_t>(ctx.GetPlace(), ::paddle::memory::Allocator::kDefault, handler->GetDstMemorySize());
         dst_memory_p =
@@ -748,6 +762,10 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
         dst_memory_p =
             handler->AcquireDstMemoryFromPrimitive(to_void_cast<int8_t>(output_data));
       }
+    } else {
+      float* output_data = output->mutable_data<float>(ctx.GetPlace(), ::paddle::memory::Allocator::kDefault, handler->GetDstMemorySize());
+      dst_memory_p =
+          handler->AcquireDstMemoryFromPrimitive(to_void_cast<float>(output_data));
     }

     // create convolution op primitive
@@ -793,6 +811,7 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
       output->set_layout(DataLayout::kMKLDNN);
       output->set_format(GetMKLDNNFormat(*dst_memory_p));
     } else {
+//std::cout<<"this is int8 init"<<std::endl;
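
Taken together, the hunks give the INT8 convolution kernel an fp32 escape hatch: when force_fp32_output is set, scale_out_data keeps its new {1.0f} default, dst_dt is overridden to float, and the destination buffer is allocated and handed to MKL-DNN as float* rather than int8_t*/uint8_t*. Below is a minimal stand-alone sketch of the scale arithmetic this enables. It assumes the conventional requantization formula output_shift_scale[i] = scale_out / (scale_in * scale_weights[i]) (the actual expression sits outside the visible hunks), and ComputeOutputShiftScale is a hypothetical helper for illustration, not a function in the patch.

#include <cstdio>
#include <vector>

// Illustrative sketch only -- not Paddle's API. Mirrors how the patch treats
// scales: scale_out_data defaults to {1.0f} and is only overwritten with the
// real INT8 output scale when force_fp32_output is false.
std::vector<float> ComputeOutputShiftScale(float scale_in,
                                           const std::vector<float>& scale_weights,
                                           float scale_out,
                                           bool force_fp32_output) {
  // Assumed requantization formula; see the hedge in the paragraph above.
  const float effective_scale_out = force_fp32_output ? 1.0f : scale_out;
  std::vector<float> output_shift_scale(scale_weights.size());
  for (size_t i = 0; i < scale_weights.size(); ++i) {
    // The INT32 accumulator carries a factor of scale_in * scale_weights[i];
    // dividing it out and multiplying by the output scale yields int8/uint8
    // output, while effective_scale_out == 1.0f yields plain fp32 output.
    output_shift_scale[i] = effective_scale_out / (scale_in * scale_weights[i]);
  }
  return output_shift_scale;
}

int main() {
  const std::vector<float> scale_weights = {127.0f, 63.5f};
  auto int8_path = ComputeOutputShiftScale(63.0f, scale_weights, 50.0f, false);
  auto fp32_path = ComputeOutputShiftScale(63.0f, scale_weights, 50.0f, true);
  std::printf("channel 0: int8 %.6f vs fp32 %.6f\n", int8_path[0], fp32_path[0]);
  return 0;
}

Under this assumption, pinning the output scale to 1.0f makes the primitive's output scaling map the INT32 accumulator straight back to fp32 values, which is why the patch only needs to neutralize scale_out and widen the destination buffer instead of adding a separate dequantize pass.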