diff --git a/paddle/fluid/operators/mkldnn/fc_mkldnn_op.cc b/paddle/fluid/operators/mkldnn/fc_mkldnn_op.cc index 7404972ea7cca0177a157c127055abaaf7e91046..1d6dcad6e40e8ffeedd86d9168401483ac02ddaf 100644 --- a/paddle/fluid/operators/mkldnn/fc_mkldnn_op.cc +++ b/paddle/fluid/operators/mkldnn/fc_mkldnn_op.cc @@ -16,10 +16,7 @@ limitations under the License. */ #include "paddle/fluid/operators/fc_op.h" #include "paddle/fluid/platform/mkldnn_helper.h" - -namespace phi { -class DenseTensor; -} // namespace phi +#include "paddle/fluid/platform/mkldnn_reuse.h" namespace paddle { namespace operators { @@ -34,388 +31,127 @@ using framework::DDim; using framework::ExecutionContext; using framework::LoDTensor; using framework::Tensor; +using phi::vectorize; using platform::GetMKLDNNFormat; using platform::MKLDNNDeviceContext; +using platform::MKLDNNGetDataType; using platform::to_void_cast; +template +constexpr bool IsInt8() { + return std::is_same::value || std::is_same::value; +} + template -class FCPrimitiveFactory { +class FCMKLDNNHandler + : public platform::MKLDNNHandlerNoCachingT { public: - explicit FCPrimitiveFactory(const dnnl::engine& engine) : engine_(engine) {} - - void ExecuteFcPrimitive(const LoDTensor* input, - const Tensor* weights, - const Tensor* bias, - LoDTensor* output, - const MKLDNNDeviceContext& dev_ctx, - const ExecutionContext& ctx) { - RecomputeOutputDims(ctx, input, weights, output); - // If primitive has already been created and cached, don't create new one, - // but update input and output data pointers and return it. - if (fc_) { - UpdateDataPointers(ctx, output, input); - this->Execute(); - return; - } // Otherwise, create a new one. - - auto in_col_dims = ctx.Attr("in_num_col_dims"); - PADDLE_ENFORCE_LE( - in_col_dims, - 2, - platform::errors::Unimplemented( - "DNNL FC doesn't support in_num_col_dims parameter to " - "be higher than " - "2.")); - if (in_col_dims == 2) { - PADDLE_ENFORCE_EQ( - input->dims().size(), - 3, - platform::errors::Unimplemented( - "DNNL FC only supports in_num_col_dims equal to 2 when " - "3 dim input is provided.")); - PADDLE_ENFORCE_EQ( - input->format(), - MKLDNNMemoryFormat::ncw, - platform::errors::Unimplemented( - "DNNL FC only supports in_num_col_dims equal to 2 when " - "input format is equal to ncw.")); + FCMKLDNNHandler(const paddle::framework::ExecutionContext& ctx, + const platform::MKLDNNDeviceContext& dev_ctx, + const Tensor* x, + const Tensor* weights, + const Tensor* bias, + Tensor* out, + const int in_num_col_dims, + dnnl::engine mkldnn_engine, + platform::Place cpu_place) + : platform::MKLDNNHandlerNoCachingT( + mkldnn_engine, cpu_place), + dev_ctx_(dev_ctx) { + this->memory_key_ = ctx.InputName("W"); + + auto x_vec_dims = phi::vectorize(x->dims()); + auto weights_vec_dims = phi::vectorize(weights->dims()); + + int MB = 1; + for (int i = 0; i < in_num_col_dims; ++i) { + MB *= x_vec_dims[i]; } - weights_ = CreateWeightsMemory(weights); - - // Since MKL-DNN has a lot of limitations on what the input/weights/output - // dimensions should be, to simplify the code, the creation of primitive - // descriptor has been divided into separate cases, based on the number - // of input dimensions. - size_t input_dim_num = input->dims().size(); - paddle::optional fc_prim_desc; - memory::desc usr_weights_desc = {}; - switch (input_dim_num) { - case 2: - fc_prim_desc = - Create2DFcPrimDescriptor(input, weights, bias, output, ctx); - usr_weights_desc = Create2DUserWeightsDesc(); - break; - case 3: - fc_prim_desc = - Create3DFcPrimDescriptor(input, weights, bias, output, ctx); - usr_weights_desc = Create3DUserWeightsDesc(weights); - break; - case 4: - fc_prim_desc = - Create4DFcPrimDescriptor(input, weights, bias, output, ctx); - usr_weights_desc = Create4DUserWeightsDesc(input, weights); - break; - default: - PADDLE_THROW(platform::errors::Unimplemented( - "DNNL FC doesn't support input dims different than 2, 3, 4.")); - break; + int IC = 1; + for (size_t i = in_num_col_dims; i < x_vec_dims.size(); ++i) { + IC *= x_vec_dims[i]; } - input_ = CreateMemory(fc_prim_desc->src_desc(), input); - // Update weights format inside of its memory - weights_ = Reorder( - usr_weights_desc, usr_weights_desc, weights_->get_data_handle()); - - // Quantize weights and reorder to format chosen by FC primitive descriptor. - QuantizeWeights(ctx, fc_prim_desc->weights_desc()); - bias_ = CreateMemoryToBeCached(fc_prim_desc->bias_desc(), bias); - // If int8 is desired, quantize bias into 32-bit signed int - QuantizeBias(*fc_prim_desc, ctx); + int OC = weights_vec_dims[1]; - // Store weights and bias in the mkldnn cache - CacheWeightsAndBias(dev_ctx, ctx); + dnnl::memory::desc bias_md; - // Based on format determined by inner_product, create output in desired - // memory format - output_ = CreateDstMemory(*fc_prim_desc, ctx, output); + auto src_md = dnnl::memory::desc( + {MB, IC}, MKLDNNGetDataType(), dnnl::memory::format_tag::any); + auto weights_md = dnnl::memory::desc( + {OC, IC}, MKLDNNGetDataType(), dnnl::memory::format_tag::any); + auto dst_md = dnnl::memory::desc( + {MB, OC}, MKLDNNGetDataType(), dnnl::memory::format_tag::any); + if (bias) { + bias_md = dnnl::memory::desc({bias->numel()}, + MKLDNNGetDataType(), + dnnl::memory::format_tag::a); + } - // Return MKL-DNN primitive ready to be fed into pipeline and executed - fc_ = inner_product_forward(*fc_prim_desc); - this->Execute(); - } + dnnl::primitive_attr attrs; + HandlePostOps(ctx, &attrs); - void Execute() { - auto& astream = platform::MKLDNNDeviceContext::tls().get_stream(); - if (bias_) { - fc_->execute(astream, - {{DNNL_ARG_SRC, *input_}, - {DNNL_ARG_WEIGHTS, *weights_}, - {DNNL_ARG_BIAS, *bias_}, - {DNNL_ARG_DST, *output_}}); - } else { - fc_->execute(astream, - {{DNNL_ARG_SRC, *input_}, - {DNNL_ARG_WEIGHTS, *weights_}, - {DNNL_ARG_DST, *output_}}); - } - astream.wait(); + this->AcquireForwardPrimitiveDescriptor(attrs, + prop_kind::forward_inference, + src_md, + weights_md, + bias_md, + dst_md); } private: - // DNNL always returns 2-dimensional data block as a result of computing - // inner product. Hence the format 'nc' is always set for its output - // primitive. Therefore, function SetOutputFormat is needed to choose - // an appropriate format based on the number of input dimensions and - // format of an input tensor. - void SetOutputFormat(MKLDNNMemoryFormat in_format, Tensor* out) { - int dim_num = out->dims().size(); - // In case of 2 dims, we set the only possible format, nc - if (dim_num == 2) { - out->set_format(MKLDNNMemoryFormat::nc); - out->set_mem_desc({phi::vectorize(out->dims()), - platform::MKLDNNGetDataType(), - out->format()}); - // In case of 3 dims, we generate a format that is based on number - // of output dims and the layout of input format (nchw or nhwc). - } else if (dim_num == 3) { - if (in_format == MKLDNNMemoryFormat::nwc || - in_format == MKLDNNMemoryFormat::nhwc) { - out->set_format( - platform::MKLDNNFormatForSize(dim_num, MKLDNNMemoryFormat::nhwc)); - } else { - out->set_format( - platform::MKLDNNFormatForSize(dim_num, MKLDNNMemoryFormat::nchw)); - } - // In any other case we overwrite the output format with the input one. - } else { - out->set_format(in_format); - } - } + void HandlePostOps(const paddle::framework::ExecutionContext& ctx, + dnnl::primitive_attr* attrs) { + static std::unordered_map algo_map = { + {"relu", dnnl::algorithm::eltwise_relu}, + {"gelu", dnnl::algorithm::eltwise_gelu}, + {"gelu_tanh", dnnl::algorithm::eltwise_gelu_tanh}, + {"gelu_erf", dnnl::algorithm::eltwise_gelu_erf}, + {"tanh", dnnl::algorithm::eltwise_tanh}, + {"sigmoid", dnnl::algorithm::eltwise_logistic}, + {"hard_swish", dnnl::algorithm::eltwise_hardswish}, + {"mish", dnnl::algorithm::eltwise_mish}}; - void UpdateDataPointers(const ExecutionContext& ctx, - Tensor* out, - const Tensor* in) { - input_->set_data_handle(to_void_cast(in->data())); - output_->set_data_handle(out->mutable_data(ctx.GetPlace())); - // If the primitive exists, but the output tensor has changed its - // variable, update its format to what has been determined in first - // call to CreateFcPrimitive method. - if (out->format() == MKLDNNMemoryFormat::undef) { - SetOutputFormat(in->format(), out); + std::vector output_shift_scale; + float scale = 1.0f; + if (IsInt8()) { + std::tie(output_shift_scale, scale) = ComputeOutputShiftScale(ctx); + int mask = CreateMask(1, output_shift_scale.size() > 1); + attrs->set_output_scales(mask, output_shift_scale); } - } - - dnnl::inner_product_forward::primitive_desc Create2DFcPrimDescriptor( - const LoDTensor* input, - const Tensor* weights, - const Tensor* bias, - LoDTensor* output, - const ExecutionContext& ctx) { - auto src_desc = CreateMemDescriptor(input, MKLDNNMemoryFormat::any); - auto weight_dims = Get2DWeightDimsForDNNL(weights); - auto weights_desc = - CreateMemDescriptor(weight_dims, MKLDNNMemoryFormat::any); - auto bias_desc = CreateMemDescriptor(bias, MKLDNNMemoryFormat::x); - auto dst_desc = CreateMemDescriptor(output, MKLDNNMemoryFormat::any); - const auto attrs = CreateFCAttrs(ctx); - return CreateFcPrimDesc(src_desc, weights_desc, bias_desc, dst_desc, attrs); - } - std::vector Get2DWeightDimsForDNNL(const Tensor* weights) { - auto dims = phi::vectorize(weights->dims()); - std::swap(dims[0], dims[1]); // swap input dim with output dim - return dims; - } - - memory::desc Create2DUserWeightsDesc() { return weights_->get_desc(); } - - dnnl::inner_product_forward::primitive_desc Create3DFcPrimDescriptor( - const LoDTensor* input, - const Tensor* weights, - const Tensor* bias, - LoDTensor* output, - const ExecutionContext& ctx) { - auto input_dims = phi::vectorize(input->dims()); - std::vector new_input_dims = { - input_dims[0] * input_dims[1], input_dims[2], 1}; - auto src_desc = - CreateMemDescriptor(new_input_dims, MKLDNNMemoryFormat::any); - - auto weight_dims = Get3DWeightDimsForDNNL(weights); - auto weights_desc = - CreateMemDescriptor(weight_dims, MKLDNNMemoryFormat::any); - - auto bias_desc = CreateMemDescriptor(bias, MKLDNNMemoryFormat::x); - - auto dst_dims = {input_dims[0] * input_dims[1], weight_dims[0]}; - auto dst_desc = - CreateMemDescriptor(dst_dims, MKLDNNMemoryFormat::any); - const auto attrs = CreateFCAttrs(ctx); - return CreateFcPrimDesc(src_desc, weights_desc, bias_desc, dst_desc, attrs); - } - - std::vector Get3DWeightDimsForDNNL(const Tensor* weights) { - auto paddle_w_dims = phi::vectorize(weights->dims()); - return {paddle_w_dims[1], paddle_w_dims[0], 1}; - } - - memory::desc Create3DUserWeightsDesc(const Tensor* weights) { - auto dims = Get3DWeightDimsForDNNL(weights); - return CreateMemDescriptor(dims, MKLDNNMemoryFormat::oiw); - } - - dnnl::inner_product_forward::primitive_desc Create4DFcPrimDescriptor( - const LoDTensor* input, - const Tensor* weights, - const Tensor* bias, - LoDTensor* output, - const ExecutionContext& ctx) { - auto src_desc = CreateMemDescriptor(input, MKLDNNMemoryFormat::any); - // Since MKL-DNN doesn't support 4D column-major data formats in - // inner_product primitive, transpose the weights to be in - // row-major format - auto dims = Get4DWeightDimsForDNNL(input, weights); - auto weights_desc = CreateMemDescriptor(dims, MKLDNNMemoryFormat::any); - auto bias_desc = CreateMemDescriptor(bias, MKLDNNMemoryFormat::x); - auto dst_desc = CreateMemDescriptor(output, MKLDNNMemoryFormat::any); - const auto attrs = CreateFCAttrs(ctx); - return CreateFcPrimDesc(src_desc, weights_desc, bias_desc, dst_desc, attrs); - } + dnnl::post_ops post_ops; - std::vector Get4DWeightDimsForDNNL(const LoDTensor* input, - const Tensor* weights) { - auto old_w_dims = phi::vectorize(weights->dims()); - auto old_in_dims = phi::vectorize(input->dims()); - auto dims = {old_w_dims[1], old_in_dims[1], old_in_dims[2], old_in_dims[3]}; - return dims; - } - - memory::desc Create4DUserWeightsDesc(const LoDTensor* input, - const Tensor* weights) { - auto dims = Get4DWeightDimsForDNNL(input, weights); - return CreateMemDescriptor(dims, MKLDNNMemoryFormat::oihw); - } - - // Convert data from one data format to another - std::shared_ptr Reorder(const memory::desc& src_desc, - const memory::desc& dst_desc, - void* src_data) { - auto src_mem = memory(src_desc, engine_, src_data); - auto dst_mem = std::make_shared(dst_desc, engine_); - - auto reorder = dnnl::reorder(src_mem, *dst_mem); - auto& astream = platform::MKLDNNDeviceContext::tls().get_stream(); - - { - platform::RecordEvent record_reorder( - "int_reorder", - platform::TracerEventType::UserDefined, - 2, - platform::EventRole::kUniqueOp); - reorder.execute(astream, src_mem, *dst_mem); - astream.wait(); + constexpr float sum_scale = 1.0f; + if (ctx.HasAttr("fuse_residual_connection") && + ctx.Attr("fuse_residual_connection")) { + post_ops.append_sum(sum_scale); } - return dst_mem; - } + std::string activation_type = ctx.Attr("activation_type"); - // Convert data from one data format to another and rescale it. - // If the desired data type is (un)signed int8, quantization occurs here. - std::shared_ptr ReorderWithScale( - const std::shared_ptr src_mem, - const memory::desc& dst_md, - const std::vector& scale_data) { - auto dst_mem = std::make_shared(dst_md, engine_); - dnnl::primitive_attr attributes; - // According to MKL-DNN's documentation mask determines along which - // dimensions should the scale be applied. - // 0 - Single scale applied to whole tensor - // 1 - Apply Scale along a slice of each dimension which index is 1. - // In case of weights quantization, that dimension is output, - // becuase we perform per-output-channel quantization - int mask = CreateMask(0, scale_data.size() > 1); - attributes.set_output_scales(mask, scale_data); - auto reorder = dnnl::reorder(*src_mem, *dst_mem, attributes); + if (activation_type.empty() == false) { + constexpr float alpha = 0.0f; + constexpr float beta = 0.0f; - auto& astream = platform::MKLDNNDeviceContext::tls().get_stream(); - { - platform::RecordEvent record_reorder( - "int_reorder", - platform::TracerEventType::UserDefined, - 2, - platform::EventRole::kUniqueOp); - reorder.execute(astream, - {{DNNL_ARG_FROM, *src_mem}, {DNNL_ARG_TO, *dst_mem}}); - astream.wait(); + post_ops.append_eltwise(scale, algo_map[activation_type], alpha, beta); } - - return dst_mem; - } - - template - static dnnl::memory::desc CreateMemDescriptor( - const std::vector& dims, MKLDNNMemoryFormat format) { - return platform::MKLDNNMemDesc( - dims, platform::MKLDNNGetDataType(), format); - } - - template - static dnnl::memory::desc CreateMemDescriptor(const Tensor* tensor, - MKLDNNMemoryFormat format) { - auto dims = phi::vectorize(tensor->dims()); - return CreateMemDescriptor(dims, format); - } - - template - dnnl::memory CreateMemory(const dnnl::memory::desc& desc, - const Tensor* tensor) { - return CreateMemory(desc, platform::to_void_cast(tensor->data())); - } - - dnnl::memory CreateMemory(const dnnl::memory::desc& desc, void* data) { - return memory(desc, engine_, data); - } - - template - std::shared_ptr CreateMemoryToBeCached( - const dnnl::memory::desc& desc, const Tensor* tensor) { - return CreateMemoryToBeCached(desc, - platform::to_void_cast(tensor->data())); - } - - std::shared_ptr CreateMemoryToBeCached( - const dnnl::memory::desc& desc, void* data) { - return std::make_shared(desc, engine_, data); - } - - // Create weights memory and transform to default MKL-DNN format - std::shared_ptr CreateWeightsMemory(const Tensor* weights) { - auto dims = phi::vectorize(weights->dims()); - std::swap(dims[0], dims[1]); // Correct output dimensions - auto src_desc = CreateMemDescriptor(dims, MKLDNNMemoryFormat::io); - auto dst_desc = CreateMemDescriptor(dims, MKLDNNMemoryFormat::oi); - // Transpose weights through MKL-DNN's reorder from io to oi format. - return Reorder(src_desc, - dst_desc, - platform::to_void_cast(weights->data())); - } - - void CacheWeightsAndBias(const MKLDNNDeviceContext& dev_ctx, - const ExecutionContext& ctx) { - std::string key = platform::CreateKey(dev_ctx); - key = platform::ExtendKeyWithThreadInfoIfNeeded(dev_ctx, key); - - const std::string weights_key = key + ctx.InputName("W"); - const std::string bias_key = key + ctx.InputName("Bias"); - dev_ctx.SetBlob(weights_key, weights_); - dev_ctx.SetBlob(bias_key, bias_); + attrs->set_post_ops(post_ops); } // Compute the bias scales so that its values correspond to the // scale of data being an output of weights and input multiplication - std::vector ComputeBiasScales(const ExecutionContext& ctx) { - auto scale_in_data = ctx.Attr("Scale_in"); - auto scale_weights_data = ctx.Attr>("Scale_weights"); - const size_t weight_scales_num = scale_weights_data.size(); - std::vector bias_scales(weight_scales_num); + std::vector ComputeBiasScales( + const float scale_in, const std::vector& scale_weights) { + std::vector bias_scales(scale_weights.size()); -#pragma omp parallel for - for (size_t i = 0; i < weight_scales_num; i++) { - if (scale_weights_data[i] == 0.0) + for (size_t i = 0; i < bias_scales.size(); ++i) { + if (scale_weights[i] == 0.0) bias_scales[i] = 1.0f; else - bias_scales[i] = scale_in_data * scale_weights_data[i]; + bias_scales[i] = scale_in * scale_weights[i]; } return bias_scales; @@ -444,7 +180,6 @@ class FCPrimitiveFactory { const size_t weight_scales_num = scale_weights_data.size(); std::vector output_shift_scale(weight_scales_num); -#pragma omp parallel for for (size_t i = 0; i < weight_scales_num; i++) { if (scale_weights_data[i] == 0.0) output_shift_scale[i] = inner_scale; @@ -464,131 +199,218 @@ class FCPrimitiveFactory { return is_multi_channel_quantizied ? 1 << slice_dimension : 0; } - void QuantizeWeights(const ExecutionContext& ctx, memory::desc dst) { - weights_ = ReorderWithScale( - weights_, dst, ctx.Attr>("Scale_weights")); - } + std::shared_ptr AcquireMemoryWithReorderAndAttrs( + const dnnl::memory::desc& user_md, + const dnnl::memory::desc& target_md, + void* ptr, + const dnnl::primitive_attr& attrs) { + std::shared_ptr target_memory_p; - void QuantizeBias(const inner_product_forward::primitive_desc& fc_prim_desc, - const ExecutionContext& ctx) { - auto bias_scales = ComputeBiasScales(ctx); - bias_ = ReorderWithScale(bias_, fc_prim_desc.bias_desc(), bias_scales); - } + auto user_memory_p = + std::make_shared(user_md, this->engine_, ptr); + target_memory_p = std::make_shared(target_md, this->engine_); + auto reorder_p = std::make_shared( + *user_memory_p, *target_memory_p, attrs); - dnnl::primitive_attr CreateFCAttrs(const ExecutionContext& ctx) { - dnnl::primitive_attr attributes; - dnnl::post_ops post_operations; + auto& astream = platform::MKLDNNDeviceContext::tls().get_stream(); + reorder_p->execute( + astream, + {{DNNL_ARG_FROM, *user_memory_p}, {DNNL_ARG_TO, *target_memory_p}}); + astream.wait(); - std::vector output_shift_scale; - float scale; - std::tie(output_shift_scale, scale) = ComputeOutputShiftScale(ctx); - int mask = CreateMask(1, output_shift_scale.size() > 1); - attributes.set_output_scales(mask, output_shift_scale); + return target_memory_p; + } - float sum_scale = 1.0f; - if (ctx.HasAttr("fuse_residual_connection") && - ctx.Attr("fuse_residual_connection")) { - post_operations.append_sum(sum_scale); - } + std::string memory_key_; + const platform::MKLDNNDeviceContext& dev_ctx_; - if (ctx.Attr("activation_type") == "relu") { - constexpr float negative_slope = 0.0f; - constexpr float placeholder = 1.0f; // beta - post_operations.append_eltwise( - scale, dnnl::algorithm::eltwise_relu, negative_slope, placeholder); - } else if (ctx.Attr("activation_type") == "gelu") { - constexpr float alpha = 0.0f; - constexpr float beta = 0.0f; - post_operations.append_eltwise( - scale, dnnl::algorithm::eltwise_gelu, alpha, beta); - } else if (ctx.Attr("activation_type") == "gelu_tanh") { - constexpr float alpha = 0.0f; - constexpr float beta = 0.0f; - post_operations.append_eltwise( - scale, dnnl::algorithm::eltwise_gelu_tanh, alpha, beta); - } else if (ctx.Attr("activation_type") == "gelu_erf") { - constexpr float alpha = 0.0f; - constexpr float beta = 0.0f; - post_operations.append_eltwise( - scale, dnnl::algorithm::eltwise_gelu_erf, alpha, beta); - } else if (ctx.Attr("activation_type") == "tanh") { - constexpr float alpha = 0.0f; - constexpr float beta = 0.0f; - post_operations.append_eltwise( - scale, dnnl::algorithm::eltwise_tanh, alpha, beta); - } else if (ctx.Attr("activation_type") == "sigmoid") { - constexpr float alpha = 0.0f; - constexpr float beta = 0.0f; - post_operations.append_eltwise( - scale, dnnl::algorithm::eltwise_logistic, alpha, beta); - } else if (ctx.Attr("activation_type") == "mish") { - constexpr float alpha = 0.0f; - constexpr float beta = 0.0f; - post_operations.append_eltwise( - scale, dnnl::algorithm::eltwise_mish, alpha, beta); - } else if (ctx.Attr("activation_type") == "hard_swish") { - constexpr float alpha = 0.0f; - constexpr float beta = 0.0f; - post_operations.append_eltwise( - scale, dnnl::algorithm::eltwise_hardswish, alpha, beta); + public: + std::shared_ptr AcquireSrcMemoryWithReorder(const Tensor* x) { + const T_in* x_data = x->data(); + + auto user_md = x->mem_desc(); + if (x->dims().size() != 2) { + // reshape restrictions are always satisfied because in case of 3 or 4 dim + // input, plain layout is enforced + user_md = user_md.reshape(this->fwd_pd_->src_desc().dims()); } - attributes.set_post_ops(post_operations); - return attributes; + return this->AcquireMemoryWithReorder( + user_md, this->fwd_pd_->src_desc(), to_void_cast(x_data)); } - dnnl::inner_product_forward::primitive_desc CreateFcPrimDesc( - const dnnl::memory::desc& input_desc, - const dnnl::memory::desc& weights_desc, - const dnnl::memory::desc& bias_desc, - const dnnl::memory::desc& dst_desc, - const dnnl::primitive_attr& attrs) { - auto fc_desc = inner_product_forward::desc(prop_kind::forward_scoring, - input_desc, - weights_desc, - bias_desc, - dst_desc); + std::shared_ptr AcquireBiasMemoryWithReorder( + const Tensor* bias, + const float scale_in, + const std::vector& scale_weights) { + const float* bias_data = bias->data(); + + if (IsInt8() == false) { + // for BF16/FP32 bias is 1D and has no scales, so reorder is not needed + return this->AcquireMemoryFromPrimitive(this->fwd_pd_->bias_desc(), + to_void_cast(bias_data)); + } else { + const std::string bias_key = this->memory_key_ + "@bias"; + auto memory_p = std::static_pointer_cast( + this->dev_ctx_.GetBlob(bias_key)); + + if (!memory_p) { + const auto& scale_data = ComputeBiasScales(scale_in, scale_weights); + dnnl::primitive_attr attrs; + + int mask = CreateMask(0, scale_data.size() > 1); + attrs.set_output_scales(mask, scale_data); + + auto user_md = dnnl::memory::desc({bias->dims()[0]}, + MKLDNNGetDataType(), + dnnl::memory::format_tag::a); + + memory_p = this->AcquireMemoryWithReorderAndAttrs( + user_md, + this->fwd_pd_->bias_desc(), + to_void_cast(bias_data), + attrs); + } + return memory_p; + } + } + + std::shared_ptr AcquireWeightsMemoryWithReorder( + const Tensor* weights, const std::vector& scale_data) { + const std::string weights_key = this->memory_key_ + "@weights"; + auto memory_p = std::static_pointer_cast( + this->dev_ctx_.GetBlob(weights_key)); - return inner_product_forward::primitive_desc(fc_desc, attrs, engine_); + if (!memory_p) { + const float* weights_data = weights->data(); + auto weights_dims = this->fwd_pd_->weights_desc().dims(); + + auto user_md = dnnl::memory::desc(weights_dims, + MKLDNNGetDataType(), + dnnl::memory::format_tag::io); + + if (IsInt8()) { + dnnl::primitive_attr attrs; + int mask = CreateMask(0, scale_data.size() > 1); + attrs.set_output_scales(mask, scale_data); + + memory_p = this->AcquireMemoryWithReorderAndAttrs( + user_md, + this->fwd_pd_->weights_desc(), + to_void_cast(weights_data), + attrs); + } else { + memory_p = + this->AcquireMemoryWithReorder(user_md, + this->fwd_pd_->weights_desc(), + to_void_cast(weights_data)); + } + + this->dev_ctx_.SetBlob(weights_key, memory_p); + } + return memory_p; } - // Create output memory based on output tensor and inner_product - // primitive descriptor format chosen for output - dnnl::memory CreateDstMemory( - const dnnl::inner_product_forward::primitive_desc& fc_prim_desc, - const ExecutionContext& ctx, - Tensor* output) { + std::shared_ptr AcquireCustomDstMemory( + const ExecutionContext& ctx, Tensor* out) { if (ctx.HasAttr("fuse_residual_connection") && ctx.Attr("fuse_residual_connection")) { auto* residual_param = ctx.Output("ResidualData"); PADDLE_ENFORCE_EQ( - output->dims(), + out->dims(), residual_param->dims(), platform::errors::InvalidArgument( "Output and elementwise parameter need to have the " "same dimension sizes, but got output's dimension = %d" " and residual param's dimension =%d .", - output->dims().size(), + out->dims().size(), residual_param->dims().size())); - output->ShareDataWith(*residual_param); + out->ShareDataWith(*residual_param); } + return this->template AcquireDstMemory(out); + } +}; - auto dst_desc = fc_prim_desc.dst_desc(); - auto buffer_size = dst_desc.get_size(); - T_out* output_data = - output->mutable_data(ctx.GetPlace(), buffer_size); - memory dst_mem(dst_desc, engine_, to_void_cast(output_data)); - SetOutputFormat(ctx.Input("Input")->format(), output); +template +class FCMKLDNNKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + bool force_fp32_output = ctx.Attr("force_fp32_output"); + bool fuse_relu = ctx.Attr("activation_type") == "relu"; - return dst_mem; + if (force_fp32_output) { + this->RunKernel(ctx); + } else if (IsInt8()) { + if (fuse_relu) { + this->RunKernel(ctx); + } else { + this->RunKernel(ctx); + } + } else { + this->RunKernel(ctx); + } + } + + template + void RunKernel(const framework::ExecutionContext& ctx) const { + const auto& dev_ctx = + ctx.template device_context(); + const auto& mkldnn_engine = dev_ctx.GetEngine(); + + const auto* x = ctx.Input("Input"); + const auto* weights = ctx.Input("W"); + const auto* bias = ctx.Input("Bias"); + auto out = ctx.Output("Out"); + + auto in_col_dims = ctx.Attr("in_num_col_dims"); + + const float scale_in = ctx.Attr("Scale_in"); + const auto& scale_weights = ctx.Attr>("Scale_weights"); + + RecomputeOutputDims(ctx, x, weights, out); + + FCMKLDNNHandler handler(ctx, + dev_ctx, + x, + weights, + bias, + out, + in_col_dims, + mkldnn_engine, + ctx.GetPlace()); + + auto src_memory_p = handler.AcquireSrcMemoryWithReorder(x); + auto weights_memory_p = + handler.AcquireWeightsMemoryWithReorder(weights, scale_weights); + auto dst_memory_p = handler.AcquireCustomDstMemory(ctx, out); + + auto fc_p = handler.AcquireForwardPrimitive(); + auto& astream = paddle::platform::MKLDNNDeviceContext::tls().get_stream(); + + std::unordered_map fc_args = { + {DNNL_ARG_SRC, *src_memory_p}, + {DNNL_ARG_WEIGHTS, *weights_memory_p}, + {DNNL_ARG_DST, *dst_memory_p}}; + + if (bias) { + auto bias_memory_p = + handler.AcquireBiasMemoryWithReorder(bias, scale_in, scale_weights); + fc_args.insert({DNNL_ARG_BIAS, *bias_memory_p}); + } + + fc_p->execute(astream, fc_args); + astream.wait(); + + out->set_mem_desc( + dst_memory_p->get_desc().reshape(phi::vectorize(out->dims()))); } void RecomputeOutputDims(const ExecutionContext& ctx, - const LoDTensor* input, - const Tensor* w, - LoDTensor* output) { + const LoDTensor* x, + const Tensor* weights, + LoDTensor* out) const { int in_num_col_dims = ctx.Attr("in_num_col_dims"); bool padding_weights = ctx.Attr("padding_weights"); PADDLE_ENFORCE_EQ(padding_weights, @@ -596,102 +418,16 @@ class FCPrimitiveFactory { platform::errors::PermissionDenied( "Weight padding in fc can not be used in MKLDNN.")); std::vector output_dims; - FCOutputSize(input->dims(), - w->dims(), + FCOutputSize(x->dims(), + weights->dims(), output_dims, in_num_col_dims, padding_weights); - output->Resize(phi::make_ddim(output_dims)); - output->set_lod(input->lod()); + out->Resize(phi::make_ddim(output_dims)); + out->set_lod(x->lod()); } - - private: - const dnnl::engine& engine_; - paddle::optional input_; - paddle::optional output_; - std::shared_ptr bias_; - std::shared_ptr weights_; - paddle::optional fc_; }; -// Attempt to fetch cached primitive factory based on provided parameters -// of input format, weight dimensions and output name. -// If not cached, create a new one. -template -static std::shared_ptr> -GetPrimitiveFactory(const MKLDNNDeviceContext& dev_ctx, - const std::string& key) { - auto prim_creator = - std::static_pointer_cast>( - dev_ctx.GetBlob(key)); - if (prim_creator == nullptr) { - prim_creator = std::make_shared>( - dev_ctx.GetEngine()); - dev_ctx.SetBlob(key, prim_creator); - } - - return prim_creator; -} - -// Choose appropriate primitive factory implementation based on inferred -// output type (uint8, int8 or float). -template -static void ExecuteFc(const ExecutionContext& ctx, - const LoDTensor* input, - const Tensor* w, - const Tensor* bias, - LoDTensor* output, - bool fuse_relu, - bool force_fp32_output) { - auto& dev_ctx = ctx.template device_context(); - std::string prim_key = platform::CreateKey(dev_ctx, - input->format(), - input->dims()[0], - phi::vectorize(w->dims()), - ctx.OutputName("Out")); - prim_key = platform::ExtendKeyWithThreadInfoIfNeeded(dev_ctx, prim_key); - - constexpr bool is_int8 = - std::is_same::value || std::is_same::value; - bool is_bfloat16 = std::is_same::value; - if ((!is_int8 && !is_bfloat16) || force_fp32_output) { - GetPrimitiveFactory(dev_ctx, prim_key) - ->ExecuteFcPrimitive(input, w, bias, output, dev_ctx, ctx); - } else if (is_bfloat16) { - GetPrimitiveFactory(dev_ctx, prim_key) - ->ExecuteFcPrimitive(input, w, bias, output, dev_ctx, ctx); - } else if (fuse_relu) { - GetPrimitiveFactory(dev_ctx, prim_key) - ->ExecuteFcPrimitive(input, w, bias, output, dev_ctx, ctx); - } else { - GetPrimitiveFactory(dev_ctx, prim_key) - ->ExecuteFcPrimitive(input, w, bias, output, dev_ctx, ctx); - } -} - -template -class FCMKLDNNOpKernel : public framework::OpKernel { - public: - void Compute(const paddle::framework::ExecutionContext& ctx) const override { - PADDLE_ENFORCE_EQ( - platform::is_cpu_place(ctx.GetPlace()), - true, - platform::errors::PreconditionNotMet("FC MKL-DNN must use CPUPlace.")); - platform::MKLDNNDeviceContext::tls().log_lib_version(); - auto input = ctx.Input("Input"); - auto w = ctx.Input("W"); - auto bias = ctx.Input("Bias"); - auto output = ctx.Output("Out"); - - bool fuse_relu = ctx.Attr("activation_type") == "relu"; - bool force_fp32_output = ctx.Attr("force_fp32_output"); - - ExecuteFc( - ctx, input, w, bias, output, fuse_relu, force_fp32_output); - - output->set_layout(DataLayout::kMKLDNN); - } -}; } // namespace operators } // namespace paddle @@ -704,7 +440,7 @@ REGISTER_OP_KERNEL_WITH_CUSTOM_TYPE(fc, ::paddle::platform::CPUPlace, FP32, ops::kFCMKLDNNFP32, - ops::FCMKLDNNOpKernel); + ops::FCMKLDNNKernel); REGISTER_OP_KERNEL_WITH_CUSTOM_TYPE( fc, @@ -712,19 +448,19 @@ REGISTER_OP_KERNEL_WITH_CUSTOM_TYPE( ::paddle::platform::CPUPlace, BF16, ops::kFCMKLDNNFP32, - ops::FCMKLDNNOpKernel); + ops::FCMKLDNNKernel); REGISTER_OP_KERNEL_WITH_CUSTOM_TYPE(fc, MKLDNN, ::paddle::platform::CPUPlace, U8, ops::kFCMKLDNNINT8, - ops::FCMKLDNNOpKernel); + ops::FCMKLDNNKernel); REGISTER_OP_KERNEL_WITH_CUSTOM_TYPE(fc, MKLDNN, ::paddle::platform::CPUPlace, S8, ops::kFCMKLDNNINT8, - ops::FCMKLDNNOpKernel); + ops::FCMKLDNNKernel); diff --git a/python/paddle/fluid/tests/unittests/mkldnn/test_fc_int8_mkldnn_op.py b/python/paddle/fluid/tests/unittests/mkldnn/test_fc_int8_mkldnn_op.py new file mode 100644 index 0000000000000000000000000000000000000000..ebbbab7f3b33c580fa8d3646fcbdb6dcdb25d4ae --- /dev/null +++ b/python/paddle/fluid/tests/unittests/mkldnn/test_fc_int8_mkldnn_op.py @@ -0,0 +1,101 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +import unittest +import numpy as np +from paddle.fluid.tests.unittests.op_test import OpTest, OpTestTool + + +@OpTestTool.skip_if_not_cpu() +class TestFCINT8OneDNNOp(OpTest): + + def setUp(self): + self.op_type = "fc" + self._cpu_only = True + self.configure() + self.generate_data() + self.set_inputs() + + self.attrs = { + 'use_mkldnn': True, + 'Scale_in': self.x_scale, + 'Scale_weights': [self.y_scale], + 'Scale_out': self.out_scale, + 'force_fp32_output': self.force_fp32_output + } + + if self.force_fp32_output: + out = self.out_float + else: + out = self.out + + self.outputs = {'Out': out} + + def configure(self): + self.use_bias = True + self.force_fp32_output = False + + def set_inputs(self): + self.inputs = {'Input': self.x, 'W': self.y_float, 'Bias': self.bias} + + def quantize(self, tensor): + scale = 63. / np.abs(np.amax(tensor)) + quantized = np.round(scale * tensor).astype("int8") + return scale, quantized + + def generate_data(self): + self.x_float = np.random.random((10, 5)).astype("float32") * 10 + self.x_scale, self.x = self.quantize(self.x_float) + + self.y_float = np.random.random((5, 10)).astype("float32") * 10 + self.y_scale, self.y = self.quantize(self.y_float) + + self.out_float = np.dot(self.x_float, self.y_float) + if self.use_bias: + self.bias = np.random.random((10)).astype("float32") * 10 + self.out_float += self.bias + + self.out_scale, self.out = self.quantize(self.out_float) + + def test_check_output(self): + int_atol = 2 + self.check_output(int_atol) + + +class TestFCINT8NoBiasOneDNNOp(TestFCINT8OneDNNOp): + + def configure(self): + self.use_bias = False + self.force_fp32_output = False + + def set_inputs(self): + self.inputs = { + 'Input': self.x, + 'W': self.y_float, + } + + +class TestFCINT8ForceFP32OutputOneDNNOp(TestFCINT8NoBiasOneDNNOp): + + def configure(self): + self.use_bias = False + self.force_fp32_output = True + + +if __name__ == "__main__": + import paddle + paddle.enable_static() + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/mkldnn/test_fc_mkldnn_op.py b/python/paddle/fluid/tests/unittests/mkldnn/test_fc_mkldnn_op.py index dfae1c514bafda2dee5071152c32f856fb321432..0c1d9bef032bc3d5c17e1696f7ecf5dbb9899cc9 100644 --- a/python/paddle/fluid/tests/unittests/mkldnn/test_fc_mkldnn_op.py +++ b/python/paddle/fluid/tests/unittests/mkldnn/test_fc_mkldnn_op.py @@ -73,4 +73,6 @@ class TestFCMKLDNNOp1(TestFCMKLDNNOp): if __name__ == "__main__": + import paddle + paddle.enable_static() unittest.main()