diff --git a/paddle/fluid/operators/mkldnn/fc_mkldnn_op.cc b/paddle/fluid/operators/mkldnn/fc_mkldnn_op.cc index a50cc22e5bb0def54b057dcc23d2f6751eecc478..40737f4cd029b47dbd03069a2e4d29ad33121eb9 100644 --- a/paddle/fluid/operators/mkldnn/fc_mkldnn_op.cc +++ b/paddle/fluid/operators/mkldnn/fc_mkldnn_op.cc @@ -44,6 +44,7 @@ class FCPrimitiveFactory { void ExecuteFcPrimitive(const LoDTensor* input, const Tensor* weights, const Tensor* bias, LoDTensor* output, + const MKLDNNDeviceContext& dev_ctx, const ExecutionContext& ctx) { RecomputeOutputDims(ctx, input, weights, output); // If primitive has already been created and cached, don't create new one, @@ -74,8 +75,8 @@ class FCPrimitiveFactory { "input format is equal to ncw.")); } - // Transform weights to default MKL-DNN format - weights_ = TransposeWeights(weights); + weights_ = CreateWeightsMemory(weights); + // Since MKL-DNN has a lot of limitations on what the input/weights/output // dimensions should be, to simplify the code, the creation of primitive // descriptor has been divided into separate cases, based on the number @@ -112,10 +113,13 @@ class FCPrimitiveFactory { // Quantize weights and reorder to format chosen by FC primitive descriptor. QuantizeWeights(ctx, fc_prim_desc->weights_desc()); - bias_ = CreateMemory(fc_prim_desc->bias_desc(), bias); + bias_ = CreateMemoryToBeCached(fc_prim_desc->bias_desc(), bias); // If int8 is desired, quantize bias into 32-bit signed int QuantizeBias(*fc_prim_desc, ctx); + // Store weights and bias in the mkldnn cache + CacheWeightsAndBias(dev_ctx, ctx); + // Based on format determined by inner_product, create output in desired // memory format output_ = CreateDstMemory(*fc_prim_desc, ctx, output); @@ -262,14 +266,15 @@ class FCPrimitiveFactory { } // Convert data from one data format to another - mkldnn::memory Reorder(const memory::desc& src_desc, - const memory::desc& dst_desc, void* src_data) { + std::shared_ptr Reorder(const memory::desc& src_desc, + const memory::desc& dst_desc, + void* src_data) { auto src_mem = memory(src_desc, engine_, src_data); - auto dst_mem = memory(dst_desc, engine_); + auto dst_mem = std::make_shared(dst_desc, engine_); - auto reorder = mkldnn::reorder(src_mem, dst_mem); + auto reorder = mkldnn::reorder(src_mem, *dst_mem); mkldnn::stream astream(engine_); - reorder.execute(astream, src_mem, dst_mem); + reorder.execute(astream, src_mem, *dst_mem); astream.wait(); return dst_mem; @@ -277,9 +282,10 @@ class FCPrimitiveFactory { // Convert data from one data format to another and rescale it. // If the desired data type is (un)signed int8, quantization occurs here. - mkldnn::memory Reorder(const memory& src_mem, const memory::desc& dst_md, - const std::vector& scale_data) { - mkldnn::memory dst_mem = mkldnn::memory(dst_md, engine_); + std::shared_ptr ReorderWithScale( + const std::shared_ptr src_mem, const memory::desc& dst_md, + const std::vector& scale_data) { + auto dst_mem = std::make_shared(dst_md, engine_); mkldnn::primitive_attr attributes; // According to MKL-DNN's documentation mask determines along which // dimensions should the scale be applied. @@ -289,11 +295,11 @@ class FCPrimitiveFactory { // becuase we perform per-output-channel quantization int mask = CreateMask(0, scale_data.size() > 1); attributes.set_output_scales(mask, scale_data); - auto reorder = mkldnn::reorder(src_mem, dst_mem, attributes); + auto reorder = mkldnn::reorder(*src_mem, *dst_mem, attributes); mkldnn::stream astream(engine_); reorder.execute(astream, - {{MKLDNN_ARG_FROM, src_mem}, {MKLDNN_ARG_TO, dst_mem}}); + {{MKLDNN_ARG_FROM, *src_mem}, {MKLDNN_ARG_TO, *dst_mem}}); astream.wait(); return dst_mem; @@ -323,16 +329,38 @@ class FCPrimitiveFactory { return memory(desc, engine_, data); } - // Transpose weights through MKL-DNN's reorder from io to oi format. - mkldnn::memory TransposeWeights(const Tensor* weights) { + template + std::shared_ptr CreateMemoryToBeCached( + const mkldnn::memory::desc& desc, const Tensor* tensor) { + return CreateMemoryToBeCached(desc, + platform::to_void_cast(tensor->data())); + } + + std::shared_ptr CreateMemoryToBeCached( + const mkldnn::memory::desc& desc, void* data) { + return std::make_shared(desc, engine_, data); + } + + // Create weights memory and transform to default MKL-DNN format + std::shared_ptr CreateWeightsMemory(const Tensor* weights) { auto dims = framework::vectorize(weights->dims()); std::swap(dims[0], dims[1]); // Correct output dimensions auto src_desc = CreateMemDescriptor(dims, MKLDNNMemoryFormat::io); auto dst_desc = CreateMemDescriptor(dims, MKLDNNMemoryFormat::oi); + // Transpose weights through MKL-DNN's reorder from io to oi format. return Reorder(src_desc, dst_desc, platform::to_void_cast(weights->data())); } + void CacheWeightsAndBias(const MKLDNNDeviceContext& dev_ctx, + const ExecutionContext& ctx) { + const std::string key = platform::CreateKey(platform::ThreadIDasStr()); + const std::string weights_key = key + ctx.InputName("W"); + const std::string bias_key = key + ctx.InputName("Bias"); + dev_ctx.SetBlob(weights_key, weights_); + dev_ctx.SetBlob(bias_key, bias_); + } + // Compute the bias scales so that its values correspond to the // scale of data being an output of weights and input multiplication std::vector ComputeBiasScales(const ExecutionContext& ctx) { @@ -388,14 +416,14 @@ class FCPrimitiveFactory { } void QuantizeWeights(const ExecutionContext& ctx, memory::desc dst) { - weights_ = - Reorder(*weights_, dst, ctx.Attr>("Scale_weights")); + weights_ = ReorderWithScale(weights_, dst, + ctx.Attr>("Scale_weights")); } void QuantizeBias(const inner_product_forward::primitive_desc& fc_prim_desc, const ExecutionContext& ctx) { auto bias_scales = ComputeBiasScales(ctx); - bias_ = Reorder(*bias_, fc_prim_desc.bias_desc(), bias_scales); + bias_ = ReorderWithScale(bias_, fc_prim_desc.bias_desc(), bias_scales); } // Fuse relu into FC with activation type attribute has been set to 'relu' @@ -463,10 +491,10 @@ class FCPrimitiveFactory { private: const mkldnn::engine& engine_; - boost::optional bias_; boost::optional input_; boost::optional output_; - boost::optional weights_; + std::shared_ptr bias_; + std::shared_ptr weights_; boost::optional fc_; }; @@ -476,19 +504,13 @@ class FCPrimitiveFactory { template static std::shared_ptr> GetPrimitiveFactory(const MKLDNNDeviceContext& dev_ctx, - const ExecutionContext& ctx, const Tensor* input, - const Tensor* weights, - const mkldnn::engine& mkldnn_engine) { - const std::string key = platform::CreateKey( - platform::ThreadIDasStr(), input->format(), input->dims()[0], - framework::vectorize(weights->dims()), ctx.OutputName("Out")); - + const std::string& key) { auto prim_creator = std::static_pointer_cast>( dev_ctx.GetBlob(key)); if (prim_creator == nullptr) { - prim_creator = - std::make_shared>(mkldnn_engine); + prim_creator = std::make_shared>( + dev_ctx.GetEngine()); dev_ctx.SetBlob(key, prim_creator); } @@ -498,24 +520,24 @@ GetPrimitiveFactory(const MKLDNNDeviceContext& dev_ctx, // Choose appropriate primitive factory implementation based on inferred // output type (uint8, int8 or float). template -static void ExecuteFc(const MKLDNNDeviceContext& dev_ctx, - const ExecutionContext& ctx, const LoDTensor* input, +static void ExecuteFc(const ExecutionContext& ctx, const LoDTensor* input, const Tensor* w, const Tensor* bias, LoDTensor* output, - const mkldnn::engine& mkldnn_engine, bool fuse_relu, - bool force_fp32_output) { + bool fuse_relu, bool force_fp32_output) { + auto& dev_ctx = ctx.template device_context(); + const std::string prim_key = platform::CreateKey( + platform::ThreadIDasStr(), input->format(), input->dims()[0], + framework::vectorize(w->dims()), ctx.OutputName("Out")); constexpr bool is_int8 = std::is_same::value || std::is_same::value; if (!is_int8 || force_fp32_output) { - GetPrimitiveFactory(dev_ctx, ctx, input, w, mkldnn_engine) - ->ExecuteFcPrimitive(input, w, bias, output, ctx); + GetPrimitiveFactory(dev_ctx, prim_key) + ->ExecuteFcPrimitive(input, w, bias, output, dev_ctx, ctx); } else if (fuse_relu) { - GetPrimitiveFactory(dev_ctx, ctx, input, w, - mkldnn_engine) - ->ExecuteFcPrimitive(input, w, bias, output, ctx); + GetPrimitiveFactory(dev_ctx, prim_key) + ->ExecuteFcPrimitive(input, w, bias, output, dev_ctx, ctx); } else { - GetPrimitiveFactory(dev_ctx, ctx, input, w, - mkldnn_engine) - ->ExecuteFcPrimitive(input, w, bias, output, ctx); + GetPrimitiveFactory(dev_ctx, prim_key) + ->ExecuteFcPrimitive(input, w, bias, output, dev_ctx, ctx); } } @@ -526,9 +548,6 @@ class FCMKLDNNOpKernel : public framework::OpKernel { PADDLE_ENFORCE_EQ( platform::is_cpu_place(ctx.GetPlace()), true, platform::errors::PreconditionNotMet("FC MKL-DNN must use CPUPlace.")); - auto& dev_ctx = ctx.template device_context(); - const auto& mkldnn_engine = dev_ctx.GetEngine(); - auto input = ctx.Input("Input"); auto w = ctx.Input("W"); auto bias = ctx.Input("Bias"); @@ -537,8 +556,8 @@ class FCMKLDNNOpKernel : public framework::OpKernel { bool fuse_relu = ctx.Attr("activation_type") == "relu"; bool force_fp32_output = ctx.Attr("force_fp32_output"); - ExecuteFc(dev_ctx, ctx, input, w, bias, output, mkldnn_engine, - fuse_relu, force_fp32_output); + ExecuteFc(ctx, input, w, bias, output, fuse_relu, + force_fp32_output); output->set_layout(DataLayout::kMKLDNN); }