diff --git a/paddle/fluid/operators/conv_mkldnn_op.cc b/paddle/fluid/operators/conv_mkldnn_op.cc
index 8aecb78a73c7c5ca7d5ee771aee3e92b1bdbfc60..bcc891ec76883c7a4e4451f48389d1f70dc6a422 100644
--- a/paddle/fluid/operators/conv_mkldnn_op.cc
+++ b/paddle/fluid/operators/conv_mkldnn_op.cc
@@ -33,6 +33,35 @@ using platform::GetMKLDNNFormat;
 
 template <typename T>
 class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
  public:
+  struct ConvInfo {
+    const paddle::framework::Tensor* input;
+    const paddle::framework::Tensor* bias;
+    const paddle::framework::Tensor* output;
+    const paddle::framework::Tensor* weight;
+    std::vector<int>* strides;
+    std::vector<int>* paddings;
+    std::vector<int>* dilations;
+    std::vector<int>* src_tz;
+    std::vector<int>* weights_tz;
+    std::vector<int>* dst_tz;
+    int g;
+  };
+  struct MkldnnInfo {
+    bool fuse_relu;
+    bool fuse_residual_conn;
+    bool force_fp32_output;
+    bool is_test;
+    const mkldnn::engine* mkldnn_engine;
+    std::vector<mkldnn::primitive>* pipeline;
+    const std::string* key_conv_pd;
+    std::string* key;
+    std::shared_ptr<platform::ConvMKLDNNHandler> handler;
+    std::shared_ptr<mkldnn::memory> src_memory_p;
+    std::shared_ptr<mkldnn::memory> user_src_memory_p;
+    std::shared_ptr<mkldnn::memory> dst_memory_p;
+    std::shared_ptr<mkldnn::convolution_forward> conv_p;
+    std::shared_ptr<mkldnn::convolution_forward::primitive_desc> conv_pd;
+  };
   void Compute(const paddle::framework::ExecutionContext& ctx) const override {
     PADDLE_ENFORCE(paddle::platform::is_cpu_place(ctx.GetPlace()),
@@ -85,7 +114,6 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
                    "dilation in convolution is not implemented yet");
 
     const T* input_data = input->data<T>();
-    const float* filter_data = filter->data<float>();
 
     std::vector<int> src_tz = paddle::framework::vectorize2int(input->dims());
     std::vector<int> weights_tz =
@@ -127,13 +155,13 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
       need_s8_to_u8 = true;
     }
 
-    std::shared_ptr<mkldnn::convolution_forward> conv_p;
-    std::shared_ptr<mkldnn::memory> src_memory_p;
-    std::shared_ptr<mkldnn::memory> user_src_memory_p;
-    std::shared_ptr<mkldnn::memory> dst_memory_p;
+    std::shared_ptr<mkldnn::convolution_forward> conv_p = nullptr;
+    std::shared_ptr<mkldnn::memory> src_memory_p = nullptr;
+    std::shared_ptr<mkldnn::memory> user_src_memory_p = nullptr;
+    std::shared_ptr<mkldnn::memory> dst_memory_p = nullptr;
     std::vector<mkldnn::primitive> pipeline;
-    std::shared_ptr<mkldnn::convolution_forward::primitive_desc> conv_pd;
-    std::shared_ptr<platform::ConvMKLDNNHandler> handler;
+    std::shared_ptr<mkldnn::convolution_forward::primitive_desc> conv_pd = nullptr;
+    std::shared_ptr<platform::ConvMKLDNNHandler> handler = nullptr;
 
     auto prim_key = key + "@conv_p";
     auto dst_key = key + "@dst_mem_p";
@@ -142,42 +170,38 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
     auto src_reorder_key = key + "@src_mem_preorder_p";
     conv_p = std::static_pointer_cast<mkldnn::convolution_forward>(dev_ctx.GetBlob(prim_key));
     if(conv_p == nullptr){
+      struct ConvInfo convinfo;
+      struct MkldnnInfo mkldnninfo;
+      convinfo.strides = &strides;
+      convinfo.paddings = &paddings;
+      convinfo.dilations = &dilations;
+      convinfo.src_tz = &src_tz;
+      convinfo.weights_tz = &weights_tz;
+      convinfo.dst_tz = &dst_tz;
+      convinfo.g = g;
+      mkldnninfo.fuse_relu = fuse_relu;
+      mkldnninfo.fuse_residual_conn = fuse_residual_conn;
+      mkldnninfo.force_fp32_output = force_fp32_output;
+      mkldnninfo.is_test = is_test;
+      mkldnninfo.mkldnn_engine = &mkldnn_engine;
+      mkldnninfo.handler = handler;
+      mkldnninfo.pipeline = &pipeline;
+      mkldnninfo.key_conv_pd = &key_conv_pd;
+      mkldnninfo.key = &key;
+      mkldnninfo.src_memory_p = src_memory_p;
+      mkldnninfo.user_src_memory_p = user_src_memory_p;
+      mkldnninfo.dst_memory_p = dst_memory_p;
+      mkldnninfo.conv_p = conv_p;
+      mkldnninfo.conv_pd = conv_pd;
      if(is_INT8){
-        CreateINT8Primitive(ctx, is_test, dev_ctx, mkldnn_engine, input, //filter,
-                            bias, output,
-                            strides, paddings,
-                            dilations, fuse_relu,
-                            fuse_residual_conn, input_data,
-                            filter_data, src_tz,
-                            weights_tz, g,
-                            dst_tz, key,
-                            dst_memory_p,
-                            pipeline,
-                            key_conv_pd,
-                            src_memory_p,
-                            user_src_memory_p,
-                            conv_p,
-                            conv_pd,
-                            handler,
-                            force_fp32_output);
+        CreateINT8Primitive(ctx, &dev_ctx, input, filter, bias, output, &convinfo, &mkldnninfo);
      }else{
-        CreateFP32Primitive(ctx, is_test, dev_ctx, mkldnn_engine, input, //filter,
-                            bias, output,
-                            strides, paddings,
-                            dilations, fuse_relu,
-                            fuse_residual_conn, input_data,
-                            filter_data, src_tz,
-                            weights_tz, g,
-                            dst_tz, key,
-                            dst_memory_p,
-                            pipeline,
-                            key_conv_pd,
-                            src_memory_p,
-                            user_src_memory_p,
-                            conv_p,
-                            conv_pd,
-                            handler);
+        CreateFP32Primitive(ctx, &dev_ctx, input, filter, bias, output, &convinfo, &mkldnninfo);
      }
+      //src_memory_p = mkldnninfo.src_memory_p;
+      //user_src_memory_p = mkldnninfo.user_src_memory_p;
+      dst_memory_p = mkldnninfo.dst_memory_p;
+      //conv_p = mkldnninfo.conv_p;
    } else {
      auto src_memory_reorder_p = std::static_pointer_cast<mkldnn::reorder>(dev_ctx.GetBlob(src_reorder_key));
      src_memory_p = std::static_pointer_cast<mkldnn::memory>(dev_ctx.GetBlob(src_key));
@@ -267,33 +291,18 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
 
  private:
   void CreateFP32Primitive(
-      paddle::framework::ExecutionContext ctx, bool is_test,
-      const paddle::platform::MKLDNNDeviceContext& dev_ctx,
-      const mkldnn::engine& mkldnn_engine,
-      const paddle::framework::Tensor* input,// const paddle::framework::Tensor* filter,
+      const paddle::framework::ExecutionContext& ctx,
+      const paddle::platform::MKLDNNDeviceContext* dev_ctx,
+      const paddle::framework::Tensor* input, const paddle::framework::Tensor* filter,
       const paddle::framework::Tensor* bias, paddle::framework::Tensor* output,
-      std::vector<int> strides, std::vector<int> paddings,
-      std::vector<int> dilations, bool fuse_relu,
-      bool fuse_residual_conn, const T* input_data,
-      const float* filter_data, std::vector<int> src_tz,
-      std::vector<int> weights_tz, int g,
-      std::vector<int> dst_tz, const std::string key,
-      std::shared_ptr<mkldnn::memory> &dst_memory_p,
-      std::vector<mkldnn::primitive>& pipeline,
-      const std::string &key_conv_pd,
-      std::shared_ptr<mkldnn::memory> src_memory_p,
-      std::shared_ptr<mkldnn::memory> user_src_memory_p,
-      std::shared_ptr<mkldnn::convolution_forward> conv_p,
-      std::shared_ptr<mkldnn::convolution_forward::primitive_desc> conv_pd,
-      std::shared_ptr<platform::ConvMKLDNNHandler> handler) const{
-
-    //const T* input_data = input->data<T>();
-
+      ConvInfo* convinfo, MkldnnInfo* mkldnninfo) const{
+    const T* input_data = input->data<T>();
+    const float* filter_data = filter->data<float>();
     auto user_src_md = platform::MKLDNNMemDesc(
-        {src_tz}, platform::MKLDNNGetDataType<T>(), input->format());
+        {*(convinfo->src_tz)}, platform::MKLDNNGetDataType<T>(), input->format());
     auto user_weights_md = platform::MKLDNNMemDesc(
-        {weights_tz}, platform::MKLDNNGetDataType<T>(),
-        (g == 1) ? mkldnn::memory::format::oihw : mkldnn::memory::format::goihw);
+        {*(convinfo->weights_tz)}, platform::MKLDNNGetDataType<T>(),
+        (convinfo->g == 1) ? mkldnn::memory::format::oihw : mkldnn::memory::format::goihw);
 
     /* create memory descriptor for convolution without specified format
      * ('any') which lets a primitive (convolution in this case) choose
@@ -304,46 +313,51 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
         platform::data_format_to_memory_format(data_format);
 
     auto src_md = platform::MKLDNNMemDesc(
-        src_tz, platform::MKLDNNGetDataType<T>(), chosen_memory_format);
+        *(convinfo->src_tz), platform::MKLDNNGetDataType<T>(), chosen_memory_format);
     auto weights_md = platform::MKLDNNMemDesc(
-        weights_tz, platform::MKLDNNGetDataType<T>(), chosen_memory_format);
+        *(convinfo->weights_tz), platform::MKLDNNGetDataType<T>(), chosen_memory_format);
     std::vector<int> bias_tz;  // TODO(mgallus): avoid empty vector creation.
                                // Currently used whenever bias is != nullptr.
 
     auto dst_md = platform::MKLDNNMemDesc(
-        dst_tz, platform::MKLDNNGetDataType<T>(), chosen_memory_format);
+        *(convinfo->dst_tz), platform::MKLDNNGetDataType<T>(), chosen_memory_format);
 
     // create a conv primitive descriptor and save it for usage in backward
     if (bias) {
       bias_tz = paddle::framework::vectorize2int(bias->dims());
       auto bias_md = platform::MKLDNNMemDesc(
           bias_tz, platform::MKLDNNGetDataType<T>(), memory::format::x);
-      conv_pd = ConvFwdPrimitiveDesc(src_md, weights_md, bias_md, dst_md,
-                                     strides, paddings, mkldnn_engine,
-                                     fuse_relu, fuse_residual_conn, is_test);
+      mkldnninfo->conv_pd = ConvFwdPrimitiveDesc(src_md, weights_md, bias_md, dst_md,
+                                     *convinfo->strides, *convinfo->paddings,
+                                     *mkldnninfo->mkldnn_engine,
+                                     mkldnninfo->fuse_relu, mkldnninfo->fuse_residual_conn,
+                                     mkldnninfo->is_test);
     } else {
-      conv_pd =
-          ConvFwdPrimitiveDesc(src_md, weights_md, dst_md, strides, paddings,
-                               mkldnn_engine, fuse_relu, fuse_residual_conn, is_test);
+      mkldnninfo->conv_pd =
+          ConvFwdPrimitiveDesc(src_md, weights_md, dst_md,
+                               *convinfo->strides, *convinfo->paddings,
+                               *mkldnninfo->mkldnn_engine,
+                               mkldnninfo->fuse_relu, mkldnninfo->fuse_residual_conn,
+                               mkldnninfo->is_test);
     }
     // Save conv_pd/src_memory/weights_memory for backward pass
-    dev_ctx.SetBlob(key_conv_pd, conv_pd);
+    dev_ctx->SetBlob(*mkldnninfo->key_conv_pd, mkldnninfo->conv_pd);
 
-    handler.reset(new platform::ConvMKLDNNHandler(conv_pd, dev_ctx, mkldnn_engine, key));
+    mkldnninfo->handler.reset(new platform::ConvMKLDNNHandler(mkldnninfo->conv_pd, *dev_ctx, *mkldnninfo->mkldnn_engine, *mkldnninfo->key));
 
     // create mkldnn memory from input tensors (data/weights)
-    user_src_memory_p =
-        handler->AcquireSrcMemory(user_src_md, to_void_cast<T>(input_data));
-    auto user_weights_memory_p = handler->AcquireWeightsMemory(
+    mkldnninfo->user_src_memory_p =
+        mkldnninfo->handler->AcquireSrcMemory(user_src_md, to_void_cast<T>(input_data));
+    auto user_weights_memory_p = mkldnninfo->handler->AcquireWeightsMemory(
         user_weights_md, to_void_cast<float>(filter_data));
 
     // create reorder primitive if the input format is not the preferred one
-    src_memory_p =
-        handler->AcquireSrcMemoryFromPrimitive(user_src_memory_p, pipeline);
-    auto weights_memory_p = handler->AcquireWeightsMemoryFromPrimitive(
-        user_weights_memory_p, pipeline, is_test);
+    mkldnninfo->src_memory_p =
+        mkldnninfo->handler->AcquireSrcMemoryFromPrimitive(mkldnninfo->user_src_memory_p, *mkldnninfo->pipeline);
+    auto weights_memory_p = mkldnninfo->handler->AcquireWeightsMemoryFromPrimitive(
+        user_weights_memory_p, *mkldnninfo->pipeline, mkldnninfo->is_test);
 
-    if (fuse_residual_conn) {
+    if (mkldnninfo->fuse_residual_conn) {
      auto residual_param = ctx.Input<Tensor>("ResidualData");
      auto residual_param_data = residual_param->data<T>();
 
@@ -354,9 +368,9 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
                         "Output and elementwise parameter need to have the "
                         "same dimension sizes");
 
-      if (residual_param->format() != handler->GetDstFormat()) {
+      if (residual_param->format() != mkldnninfo->handler->GetDstFormat()) {
         auto output_data =
-            output->mutable_data<T>(ctx.GetPlace(), ::paddle::memory::Allocator::kDefault, handler->GetDstMemorySize());
+            output->mutable_data<T>(ctx.GetPlace(), ::paddle::memory::Allocator::kDefault, mkldnninfo->handler->GetDstMemorySize());
         auto residual_data_tz =
             paddle::framework::vectorize2int(residual_param->dims());
         auto residual_data_type =
@@ -364,21 +378,21 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
 
         auto user_residual_md = platform::MKLDNNMemDesc(
             residual_data_tz, residual_data_type, residual_param->format());
-        auto user_residual_memory_p = handler->AcquireResidualDataMemory(
+        auto user_residual_memory_p = mkldnninfo->handler->AcquireResidualDataMemory(
             user_residual_md, to_void_cast<T>(residual_param_data));
 
-        dst_memory_p = handler->AcquireDstMemoryFromResidualDataMemory(
-            user_residual_memory_p, to_void_cast<T>(output_data), pipeline);
+        mkldnninfo->dst_memory_p = mkldnninfo->handler->AcquireDstMemoryFromResidualDataMemory(
+            user_residual_memory_p, to_void_cast<T>(output_data), *mkldnninfo->pipeline);
       } else {
         output->ShareDataWith(*residual_param);
         auto output_data = output->mutable_data<T>(ctx.GetPlace());
-        dst_memory_p =
-            handler->AcquireDstMemoryFromPrimitive(to_void_cast<T>(output_data));
+        mkldnninfo->dst_memory_p =
+            mkldnninfo->handler->AcquireDstMemoryFromPrimitive(to_void_cast<T>(output_data));
       }
     } else {
       auto output_data =
-          output->mutable_data<T>(ctx.GetPlace(), ::paddle::memory::Allocator::kDefault, handler->GetDstMemorySize());
-      dst_memory_p =
-          handler->AcquireDstMemoryFromPrimitive(to_void_cast<T>(output_data));
+          output->mutable_data<T>(ctx.GetPlace(), ::paddle::memory::Allocator::kDefault, mkldnninfo->handler->GetDstMemorySize());
+      mkldnninfo->dst_memory_p =
+          mkldnninfo->handler->AcquireDstMemoryFromPrimitive(to_void_cast<T>(output_data));
     }
 
     // create convolution op primitive
@@ -387,72 +401,41 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
       auto user_bias_md = platform::MKLDNNMemDesc(
           {bias_tz}, platform::MKLDNNGetDataType<T>(), memory::format::x);
       auto user_bias_memory_p =
-          handler->AcquireBiasMemory(user_bias_md, to_void_cast<float>(bias_data));
+          mkldnninfo->handler->AcquireBiasMemory(user_bias_md, to_void_cast<float>(bias_data));
       auto bias_memory_p =
-          handler->AcquireBiasMemoryFromPrimitive(user_bias_memory_p, pipeline, is_test);
-      conv_p = handler->AcquireConvolution(src_memory_p, weights_memory_p,
-                                           bias_memory_p, dst_memory_p);
+          mkldnninfo->handler->AcquireBiasMemoryFromPrimitive(user_bias_memory_p, *mkldnninfo->pipeline, mkldnninfo->is_test);
+      mkldnninfo->conv_p = mkldnninfo->handler->AcquireConvolution(
+          mkldnninfo->src_memory_p, weights_memory_p,
+          bias_memory_p, mkldnninfo->dst_memory_p);
     } else {
-      conv_p = handler->AcquireConvolution(src_memory_p, weights_memory_p,
-                                           dst_memory_p);
+      mkldnninfo->conv_p = mkldnninfo->handler->AcquireConvolution(
+          mkldnninfo->src_memory_p, weights_memory_p,
+          mkldnninfo->dst_memory_p);
     }
 
     // push primitive to stream and wait until it's executed
-    pipeline.push_back(*conv_p);
+    mkldnninfo->pipeline->push_back(*mkldnninfo->conv_p);
   };
 
   void CreateINT8Primitive(
-      const paddle::framework::ExecutionContext& ctx, bool is_test,
-      const paddle::platform::MKLDNNDeviceContext & dev_ctx,
-      const mkldnn::engine & mkldnn_engine,
-      const paddle::framework::Tensor* input, //const paddle::framework::Tensor* filter,
+      const paddle::framework::ExecutionContext& ctx,
+      const paddle::platform::MKLDNNDeviceContext* dev_ctx,
+      const paddle::framework::Tensor* input, const paddle::framework::Tensor* filter,
       const paddle::framework::Tensor* bias, paddle::framework::Tensor* output,
-      std::vector<int> strides, std::vector<int> paddings,
-      std::vector<int> dilations, bool fuse_relu,
-      bool fuse_residual_conn, const T* input_data,
-      const float* filter_data, std::vector<int> src_tz,
-      std::vector<int> weights_tz, int g,
-      std::vector<int> dst_tz, const std::string key,
-      std::shared_ptr<mkldnn::memory>& dst_memory_p,
-      std::vector<mkldnn::primitive>& pipeline,
-      const std::string &key_conv_pd,
-      std::shared_ptr<mkldnn::memory> src_memory_p,
-      std::shared_ptr<mkldnn::memory> user_src_memory_p,
-      std::shared_ptr<mkldnn::convolution_forward> conv_p,
-      std::shared_ptr<mkldnn::convolution_forward::primitive_desc> conv_pd,
-      std::shared_ptr<platform::ConvMKLDNNHandler> handler,
-      bool force_fp32_output) const {
-    //const T* input_data = input->data<T>();
+      ConvInfo* convinfo, MkldnnInfo* mkldnninfo) const {
+    const T* input_data = input->data<T>();
+    const float* filter_data = filter->data<float>();
     bool is_INT8 = true;
     auto scale_in_data = ctx.Attr<float>("Scale_in");
     auto scale_in_eltwise_data = ctx.Attr<float>("Scale_in_eltwise");
     auto scale_weights_data = ctx.Attr<std::vector<float>>("Scale_weights");
-    auto scale_out_data = force_fp32_output? 1.0f : ctx.Attr<float>("Scale_out");
+    auto scale_out_data = mkldnninfo->force_fp32_output? 1.0f : ctx.Attr<float>("Scale_out");
     bool is_multi_channel = scale_weights_data.size() > 1 ? true : false;
 
-    auto scale_in_key = key + "@scale_in";
-    auto scale_weights_key = key + "@scale_weights";
-    auto scale_out_key = key + "@scale_out";
-    auto output_shift_scale_key = key + "@output_shift_scale";
-    auto sum_scale_key = key + "@sum_scale";
-    auto scale_in_eltwise_key = key + "@scale_in_eltwise";
-    //std::vector<float> scale_in_data;
-    //std::vector<float> scale_out_data = {1.0f};
-    //std::vector<float> scale_weights_data;
-    //std::vector<float> scale_in_eltwise_data;
     std::vector<float> output_shift_scale;
     float sum_scale = 1.0f;
-
-    int count = is_multi_channel? (g>1? weights_tz[1]*weights_tz[0] : weights_tz[0]) : 1;
-    //scale_in_data = {scale_in};
-    //scale_weights_data.resize(count);
-    //#pragma omp parallel for if (count > 1)
-    //for(int i=0; i<count; i++){
-    //  scale_weights_data[i] = *(scale_weights->data<float>() + i);
-    //}
-    //if(!force_fp32_output)
-    //  scale_out_data = {*(scale_out->data<float>())};
+    int count = is_multi_channel? (convinfo->g>1? (*convinfo->weights_tz)[1] * (*convinfo->weights_tz)[0] : (*convinfo->weights_tz)[0]) : 1;
     output_shift_scale.resize(count);
     #pragma omp parallel for if (count > 1)
     for(int i=0; i<count; i++){
       if(scale_weights_data[i] == 0.0)
         output_shift_scale[i] = scale_out_data;
       else
         output_shift_scale[i] = scale_out_data / (scale_in_data * scale_weights_data[i]);
     }
-    if(fuse_residual_conn){
-      //scale_in_eltwise_data = {*(scale_in_eltwise->data<float>())};
+    if(mkldnninfo->fuse_residual_conn){
       sum_scale = scale_out_data / scale_in_eltwise_data;
     }
 
     auto user_src_md = platform::MKLDNNMemDesc(
-        {src_tz}, paddle::framework::ToMKLDNNDataType(input->type()), input->format());
+        {*convinfo->src_tz}, paddle::framework::ToMKLDNNDataType(input->type()), input->format());
     auto user_weights_md = platform::MKLDNNMemDesc(
-        {weights_tz}, platform::MKLDNNGetDataType<float>(),
-        (g == 1) ? mkldnn::memory::format::oihw : mkldnn::memory::format::goihw);
+        {*convinfo->weights_tz}, platform::MKLDNNGetDataType<float>(),
+        ((convinfo->g) == 1) ? mkldnn::memory::format::oihw : mkldnn::memory::format::goihw);
 
     /* create memory descriptor for convolution without specified format
      * ('any') which lets a primitive (convolution in this case) choose
@@ -483,123 +465,123 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
     auto bias_tz = paddle::framework::vectorize2int(bias->dims());
 
     auto src_md = platform::MKLDNNMemDesc(
-        src_tz, memory::data_type::u8, chosen_memory_format);
+        *convinfo->src_tz, memory::data_type::u8, chosen_memory_format);
     auto weights_md = platform::MKLDNNMemDesc(
-        weights_tz, memory::data_type::s8, chosen_memory_format);
+        *convinfo->weights_tz, memory::data_type::s8, chosen_memory_format);
 
-    auto dst_dt = fuse_relu?
+    auto dst_dt = mkldnninfo->fuse_relu?
                   paddle::framework::ToMKLDNNDataType(std::type_index(typeid(unsigned char))) :
                   paddle::framework::ToMKLDNNDataType(std::type_index(typeid(signed char)));
-    if(force_fp32_output){
+    if(mkldnninfo->force_fp32_output){
       dst_dt = paddle::framework::ToMKLDNNDataType(std::type_index(typeid(float)));
     }
-    if(fuse_residual_conn){
+    if(mkldnninfo->fuse_residual_conn){
       auto residual = ctx.Input<Tensor>("ResidualData");
       auto residual_dt = paddle::framework::ToMKLDNNDataType(residual->type());
       if(dst_dt != residual_dt)
         dst_dt = residual_dt;
     }
-    auto dst_md = platform::MKLDNNMemDesc(dst_tz, dst_dt, chosen_memory_format);
+    auto dst_md = platform::MKLDNNMemDesc(*convinfo->dst_tz, dst_dt, chosen_memory_format);
 
     // create a conv primitive descriptor and save it for usage in backward
     if (bias) {
       auto bias_md = platform::MKLDNNMemDesc(
           bias_tz, memory::data_type::s32, memory::format::x);
-      conv_pd = ConvFwdPrimitiveDesc(src_md, weights_md, bias_md, dst_md,
-                                     strides, paddings, mkldnn_engine,
-                                     fuse_relu, fuse_residual_conn,
-                                     output_shift_scale, sum_scale, is_test);
+      mkldnninfo->conv_pd = ConvFwdPrimitiveDesc(src_md, weights_md, bias_md, dst_md,
+          *convinfo->strides, *convinfo->paddings, *mkldnninfo->mkldnn_engine,
+          mkldnninfo->fuse_relu, mkldnninfo->fuse_residual_conn,
+          output_shift_scale, sum_scale, mkldnninfo->is_test);
     } else {
-      conv_pd =
-          ConvFwdPrimitiveDesc(src_md, weights_md, dst_md, strides, paddings,
-                               mkldnn_engine, fuse_relu, fuse_residual_conn,
-                               output_shift_scale, sum_scale, is_test);
+      mkldnninfo->conv_pd =
+          ConvFwdPrimitiveDesc(src_md, weights_md, dst_md, *convinfo->strides, *convinfo->paddings,
+                               *mkldnninfo->mkldnn_engine, mkldnninfo->fuse_relu, mkldnninfo->fuse_residual_conn,
+                               output_shift_scale, sum_scale, mkldnninfo->is_test);
     }
     // Save conv_pd/src_memory/weights_memory for backward pass
-    dev_ctx.SetBlob(key_conv_pd, conv_pd);
+    dev_ctx->SetBlob(*mkldnninfo->key_conv_pd, mkldnninfo->conv_pd);
 
-    handler.reset(new platform::ConvMKLDNNHandler(conv_pd, dev_ctx, mkldnn_engine, key));
+    mkldnninfo->handler.reset(new platform::ConvMKLDNNHandler(mkldnninfo->conv_pd, *dev_ctx, *mkldnninfo->mkldnn_engine, *mkldnninfo->key));
 
     // create mkldnn memory from input tensors (data/weights)
-    user_src_memory_p =
-        handler->AcquireSrcMemory(user_src_md, to_void_cast<T>(input_data));
-    auto user_weights_memory_p = handler->AcquireWeightsMemory(
+    mkldnninfo->user_src_memory_p =
+        mkldnninfo->handler->AcquireSrcMemory(user_src_md, to_void_cast<T>(input_data));
+    auto user_weights_memory_p = mkldnninfo->handler->AcquireWeightsMemory(
         user_weights_md, to_void_cast<float>(filter_data));
 
     // create reorder primitive if the input format is not the preferred one
-    src_memory_p =
-        handler->AcquireSrcMemoryFromPrimitive(user_src_memory_p, pipeline);
+    mkldnninfo->src_memory_p =
+        mkldnninfo->handler->AcquireSrcMemoryFromPrimitive(mkldnninfo->user_src_memory_p, *mkldnninfo->pipeline);
 
     std::shared_ptr<mkldnn::memory> weights_memory_p;
-    int mask_reorder = is_multi_channel? ((g!= 1) ? (1<<1)+(1<<0) : 1<<0) : 0;
-    weights_memory_p = handler->AcquireWeightsMemoryFromPrimitive(
-        user_weights_memory_p, pipeline, is_test, is_INT8, scale_weights_data, mask_reorder);
+    int mask_reorder = is_multi_channel? ((convinfo->g!= 1) ? (1<<1)+(1<<0) : 1<<0) : 0;
+    weights_memory_p = mkldnninfo->handler->AcquireWeightsMemoryFromPrimitive(
+        user_weights_memory_p, *mkldnninfo->pipeline, mkldnninfo->is_test, is_INT8, scale_weights_data, mask_reorder);
 
-    if(fuse_residual_conn) {
+    if(mkldnninfo->fuse_residual_conn) {
      auto residual_param = ctx.Input<Tensor>("ResidualData");
      PADDLE_ENFORCE_EQ(output->dims(), residual_param->dims(),
                        "Output and elementwise parameter need to have the "
                        "same dimension sizes");
      auto residual_dt = paddle::framework::ToMKLDNNDataType(residual_param->type());
-      PADDLE_ENFORCE_EQ(residual_param->format(), handler->GetDstFormat(),
+      PADDLE_ENFORCE_EQ(residual_param->format(), mkldnninfo->handler->GetDstFormat(),
                        "Conv input dimension and filter dimension should be the same.");
      output->ShareDataWith(*residual_param);
      if(residual_dt == mkldnn::memory::data_type::u8){
        uint8_t* output_data = output->mutable_data<uint8_t>(ctx.GetPlace());
-        dst_memory_p =
-            handler->AcquireDstMemoryFromPrimitive(to_void_cast<uint8_t>(output_data));
+        mkldnninfo->dst_memory_p =
+            mkldnninfo->handler->AcquireDstMemoryFromPrimitive(to_void_cast<uint8_t>(output_data));
      } else{
        int8_t* output_data = output->mutable_data<int8_t>(ctx.GetPlace());
-        dst_memory_p =
-            handler->AcquireDstMemoryFromPrimitive(to_void_cast<int8_t>(output_data));
+        mkldnninfo->dst_memory_p =
+            mkldnninfo->handler->AcquireDstMemoryFromPrimitive(to_void_cast<int8_t>(output_data));
      }
-    } else if(!force_fp32_output){
-      if(fuse_relu){
-        uint8_t* output_data = output->mutable_data<uint8_t>(ctx.GetPlace(), ::paddle::memory::Allocator::kDefault, handler->GetDstMemorySize());
-        dst_memory_p =
-            handler->AcquireDstMemoryFromPrimitive(to_void_cast<uint8_t>(output_data));
+    } else if(!mkldnninfo->force_fp32_output){
+      if(mkldnninfo->fuse_relu){
+        uint8_t* output_data = output->mutable_data<uint8_t>(ctx.GetPlace(), ::paddle::memory::Allocator::kDefault, mkldnninfo->handler->GetDstMemorySize());
+        mkldnninfo->dst_memory_p =
+            mkldnninfo->handler->AcquireDstMemoryFromPrimitive(to_void_cast<uint8_t>(output_data));
      } else{
-        int8_t* output_data = output->mutable_data<int8_t>(ctx.GetPlace(), ::paddle::memory::Allocator::kDefault, handler->GetDstMemorySize());
-        dst_memory_p =
-            handler->AcquireDstMemoryFromPrimitive(to_void_cast<int8_t>(output_data));
+        int8_t* output_data = output->mutable_data<int8_t>(ctx.GetPlace(), ::paddle::memory::Allocator::kDefault, mkldnninfo->handler->GetDstMemorySize());
+        mkldnninfo->dst_memory_p =
+            mkldnninfo->handler->AcquireDstMemoryFromPrimitive(to_void_cast<int8_t>(output_data));
      }
    } else {
-      float* output_data = output->mutable_data<float>(ctx.GetPlace(), ::paddle::memory::Allocator::kDefault, handler->GetDstMemorySize());
-      dst_memory_p =
-          handler->AcquireDstMemoryFromPrimitive(to_void_cast<float>(output_data));
+      float* output_data = output->mutable_data<float>(ctx.GetPlace(), ::paddle::memory::Allocator::kDefault, mkldnninfo->handler->GetDstMemorySize());
+      mkldnninfo->dst_memory_p =
+          mkldnninfo->handler->AcquireDstMemoryFromPrimitive(to_void_cast<float>(output_data));
    }
 
    // create convolution op primitive
    std::vector<float> scale_bias_data;
-    auto scale_bias_key = key + "@scale_bias";
+    auto scale_bias_key = *mkldnninfo->key + "@scale_bias";
    if (bias) {
      const float* bias_data = bias->data<float>();
      auto user_bias_md = platform::MKLDNNMemDesc(
          {bias_tz}, platform::MKLDNNGetDataType<float>(), memory::format::x);
      auto user_bias_memory_p =
-          handler->AcquireBiasMemory(user_bias_md, to_void_cast<float>(bias_data));
+          mkldnninfo->handler->AcquireBiasMemory(user_bias_md, to_void_cast<float>(bias_data));
      std::shared_ptr<mkldnn::memory> bias_memory_p;
      int mask_reorder = is_multi_channel? 1<<0 : 1;
-      int count = is_multi_channel? (g>1? weights_tz[1]*weights_tz[0] : weights_tz[0]) : 1;
+      int count = is_multi_channel? (convinfo->g>1? (*convinfo->weights_tz)[1] * (*convinfo->weights_tz)[0] : (*convinfo->weights_tz)[0]) : 1;
      scale_bias_data.resize(count);
      #pragma omp parallel for if (count > 1)
      for(int i=0; i<count; i++){
        scale_bias_data[i] = scale_in_data * scale_weights_data[i];
      }
      bias_memory_p =
-          handler->AcquireBiasMemoryFromPrimitive(user_bias_memory_p, pipeline, is_test, is_INT8, scale_bias_data, mask_reorder);
-      conv_p = handler->AcquireConvolution(src_memory_p, weights_memory_p,
-                                           bias_memory_p, dst_memory_p);
+          mkldnninfo->handler->AcquireBiasMemoryFromPrimitive(user_bias_memory_p, *mkldnninfo->pipeline, mkldnninfo->is_test, is_INT8, scale_bias_data, mask_reorder);
+      mkldnninfo->conv_p = mkldnninfo->handler->AcquireConvolution(mkldnninfo->src_memory_p, weights_memory_p,
+                                           bias_memory_p, mkldnninfo->dst_memory_p);
    } else {
-      conv_p = handler->AcquireConvolution(src_memory_p, weights_memory_p,
-                                           dst_memory_p);
+      mkldnninfo->conv_p = mkldnninfo->handler->AcquireConvolution(mkldnninfo->src_memory_p, weights_memory_p,
+                                           mkldnninfo->dst_memory_p);
    }
 
    // push primitive to stream and wait until it's executed
-    pipeline.push_back(*conv_p);
+    mkldnninfo->pipeline->push_back(*mkldnninfo->conv_p);
  };
 
  void AppendKey(std::string& key, mkldnn::memory::dims& input_dims,  // NOLINT
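
Reviewer note on the pattern applied in this diff: the two roughly twenty-argument private methods are collapsed into a pair of parameter objects, ConvInfo (read-only convolution geometry) and MkldnnInfo (MKL-DNN state that the callee fills in). Because MkldnnInfo holds the shared_ptr results by value, Compute() has to copy them back out after the call, which is what the `dst_memory_p = mkldnninfo.dst_memory_p;` line (and the commented-out copy-backs) do. Below is a minimal, self-contained sketch of this parameter-object shape; the types are simplified stand-ins for illustration only, not the actual Paddle or MKL-DNN APIs.

#include <iostream>
#include <memory>
#include <string>
#include <vector>

// Read-only geometry, mirroring ConvInfo in the diff.
struct ConvInfo {
  std::vector<int>* strides;
  std::vector<int>* paddings;
  int g;  // group count
};

// In/out primitive-creation state, mirroring MkldnnInfo: the callee
// writes its results into the shared_ptr members.
struct MkldnnInfo {
  bool fuse_relu;
  std::shared_ptr<std::string> conv_p;  // stand-in for the conv primitive
};

// Before the refactor this would have taken ~20 loose arguments; now the
// callee reads configuration from ConvInfo and publishes results via MkldnnInfo.
void CreatePrimitive(const ConvInfo* convinfo, MkldnnInfo* mkldnninfo) {
  mkldnninfo->conv_p = std::make_shared<std::string>(
      "conv(g=" + std::to_string(convinfo->g) +
      (mkldnninfo->fuse_relu ? ", relu)" : ")"));
}

int main() {
  std::vector<int> strides{1, 1}, paddings{0, 0};
  ConvInfo convinfo{&strides, &paddings, /*g=*/1};
  MkldnnInfo mkldnninfo{/*fuse_relu=*/true, nullptr};
  CreatePrimitive(&convinfo, &mkldnninfo);
  // As in Compute() above, the caller reads the result back out of the struct.
  std::cout << *mkldnninfo.conv_p << std::endl;
  return 0;
}

One design observation that follows from the sketch: if MkldnnInfo held references to the caller's locals (or if the Create* functions returned the struct), the manual copy-back step in Compute() would be unnecessary, and the stale-pointer risk hinted at by the commented-out `//src_memory_p = ...` lines would disappear.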