From 61921084082e55041a3d1e6a34a735fba32e558b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Micha=C5=82=20Gallus?= Date: Fri, 3 Jan 2020 07:24:34 +0100 Subject: [PATCH] [DNNL] 3D Fully-Connected (#21746) --- paddle/fluid/framework/ir/fc_fuse_pass.cc | 7 +- .../inference/analysis/ir_pass_manager.cc | 38 ++- .../inference/analysis/ir_pass_manager.h | 6 +- paddle/fluid/operators/mkldnn/fc_mkldnn_op.cc | 263 ++++++++++-------- .../unittests/mkldnn/test_fc_mkldnn_op.py | 23 +- 5 files changed, 208 insertions(+), 129 deletions(-) diff --git a/paddle/fluid/framework/ir/fc_fuse_pass.cc b/paddle/fluid/framework/ir/fc_fuse_pass.cc index 8eccad1ee0e..ed575272f17 100644 --- a/paddle/fluid/framework/ir/fc_fuse_pass.cc +++ b/paddle/fluid/framework/ir/fc_fuse_pass.cc @@ -92,14 +92,15 @@ int FCFusePass::ApplyFCPattern(Graph* graph, bool with_relu) const { // This is to add padding for dimension 128 on concern of MKL performance auto* scope = param_scope(); auto* weight = scope->FindVar(w->Name())->GetMutable(); - auto place = weight->place(); - bool use_gpu = Get("use_gpu"); auto* weight_data = weight->data(); auto weight_dims = weight->dims(); int weight_num = product(weight_dims); int w_h = weight_dims[0]; int w_w = weight_dims[1]; - if (!use_gpu) { + bool use_gpu = Has("use_gpu") ? Get("use_gpu") : false; + bool use_fc_padding = + Has("use_fc_padding") ? Get("use_fc_padding") : true; + if (!use_gpu && use_fc_padding) { if (w_h % 128 == 0 && w_w % 128 == 0) { auto* weight_data_tmp = new float[weight_num]; for (int i = 0; i < w_h; i++) { diff --git a/paddle/fluid/inference/analysis/ir_pass_manager.cc b/paddle/fluid/inference/analysis/ir_pass_manager.cc index 174d6e3fc1c..c8486f5151c 100644 --- a/paddle/fluid/inference/analysis/ir_pass_manager.cc +++ b/paddle/fluid/inference/analysis/ir_pass_manager.cc @@ -158,11 +158,47 @@ void IRPassManager::CreatePasses(Argument *argument, } } +bool IRPassManager::HasPass(const std::string &pass_type) { + if (passes_.empty()) return false; + auto it = std::find_if( + passes_.begin(), passes_.end(), + [&](std::unique_ptr &pass) { return pass->Type() == pass_type; }); + return it != passes_.end(); +} + +std::unique_ptr &IRPassManager::GetPass(const std::string &pass_type) { + PADDLE_ENFORCE_EQ(passes_.empty(), false, + platform::errors::PreconditionNotMet( + "The list of passes cannot be empty.")); + auto it = std::find_if(passes_.begin(), passes_.end(), + [&](const std::unique_ptr &pass) { + return pass->Type() == pass_type; + }); + PADDLE_ENFORCE_NE(it, passes_.end(), + platform::errors::PermissionDenied( + "You cannot get pass which was not added earlier.")); + return *it; +} + +// Some passes depend on each other. This method serves for exchanging +// information between them. +void IRPassManager::UpdatePasses() { + // Update padding settings for fc_fuse_pass. 
Skipp adding padding for + // MKL-DNN-based FC + bool use_fc_padding = !HasPass("fc_mkldnn_pass"); + if (HasPass("fc_fuse_pass")) { + auto &fc_fuse_pass = GetPass("fc_fuse_pass"); + fc_fuse_pass->Set("use_fc_padding", new bool(use_fc_padding)); + } +} + std::unique_ptr IRPassManager::Apply(std::unique_ptr graph) { if (passes_.empty()) { return graph; } - PADDLE_ENFORCE(graph.get()); + PADDLE_ENFORCE_NOT_NULL(graph.get(), platform::errors::PreconditionNotMet( + "Graph cannot be NULL.")); + UpdatePasses(); // Apply all the passes for (const auto &pass : passes_) { if (pass->Type() != "graph_viz_pass" && !disable_logs_) { diff --git a/paddle/fluid/inference/analysis/ir_pass_manager.h b/paddle/fluid/inference/analysis/ir_pass_manager.h index f96b4a0f135..2366a0eaf09 100644 --- a/paddle/fluid/inference/analysis/ir_pass_manager.h +++ b/paddle/fluid/inference/analysis/ir_pass_manager.h @@ -39,6 +39,7 @@ namespace inference { namespace analysis { using framework::ProgramDesc; using framework::ir::Graph; +using framework::ir::Pass; class IRPassManager final { public: @@ -53,9 +54,12 @@ class IRPassManager final { private: void CreatePasses(Argument *argument, const std::vector &passes); + bool HasPass(const std::string &pass_type); + std::unique_ptr &GetPass(const std::string &pass_type); + void UpdatePasses(); std::unique_ptr graph_; - std::vector> passes_; + std::vector> passes_; bool disable_logs_{false}; }; diff --git a/paddle/fluid/operators/mkldnn/fc_mkldnn_op.cc b/paddle/fluid/operators/mkldnn/fc_mkldnn_op.cc index dfe9639b6cc..edc14add803 100644 --- a/paddle/fluid/operators/mkldnn/fc_mkldnn_op.cc +++ b/paddle/fluid/operators/mkldnn/fc_mkldnn_op.cc @@ -52,26 +52,56 @@ class FCPrimitiveFactory { UpdateDataPointers(ctx, output, input); this->Execute(); return; - } - auto src_desc = CreateMemDescriptor(input, input->format()); - input_ = CreateMemory(src_desc, input); + } // Otherwise, create a new one. - // Since MKL-DNN doesn't support 4D column-major data formats in - // inner_product - // primitive, transpose the weights to be in row-major format + // Transform weights to default MKL-DNN format weights_ = TransposeWeights(weights); - if (src_desc.data.ndims == 4) { - weights_ = CreateFourDimWeightsMemory(input, weights); + // Since MKL-DNN has a lot of limitations on what the input/weights/output + // dimensions should be, to simplify the code, the creation of primitive + // descriptor has been divided into separate cases, based on the number + // of input dimensions. 
+ size_t input_dim_num = input->dims().size(); + boost::optional fc_prim_desc; + memory::desc usr_weights_desc = {}; + switch (input_dim_num) { + case 2: + fc_prim_desc = + Create2DFcPrimDescriptor(input, weights, bias, output, ctx); + usr_weights_desc = Create2DUserWeightsDesc(); + break; + case 3: + fc_prim_desc = + Create3DFcPrimDescriptor(input, weights, bias, output, ctx); + usr_weights_desc = Create3DUserWeightsDesc(weights); + break; + case 4: + fc_prim_desc = + Create4DFcPrimDescriptor(input, weights, bias, output, ctx); + usr_weights_desc = Create4DUserWeightsDesc(input, weights); + break; + default: + PADDLE_THROW(platform::errors::Unimplemented( + "DNNL FC doesn't support input dims different than 2, 3, 4.")); + break; } - // If int8 data type is desired, weights are quantized to signed int8 - QuantizeWeights(ctx); + input_ = CreateMemory(fc_prim_desc->src_desc(), input); + // Update weights format inside of its memory + weights_ = Reorder(usr_weights_desc, usr_weights_desc, + weights_->get_data_handle()); - // Choose MKLDNNMemoryFormat::any so that MKL-DNN can determine itself what - // is the best format for output during the creation of inner product - // primitive descriptor - auto dst_desc = CreateMemDescriptor(output, MKLDNNMemoryFormat::any); + // Quantize weights and reorder to format chosen by FC primitive descriptor. + QuantizeWeights(ctx, fc_prim_desc->weights_desc()); + + bias_ = CreateMemory(fc_prim_desc->bias_desc(), bias); + // If int8 is desired, quantize bias into 32-bit signed int + QuantizeBias(*fc_prim_desc, ctx); - fc_ = CreateFcPrimitive(*input_, *weights_, dst_desc, bias, output, ctx); + // Based on format determined by inner_product, create output in desired + // memory format + output_ = CreateDstMemory(*fc_prim_desc, ctx, output); + + // Return MKL-DNN primitive ready to be fed into pipeline and executed + fc_ = inner_product_forward(*fc_prim_desc); this->Execute(); } @@ -99,26 +129,99 @@ class FCPrimitiveFactory { // variable, update its format to what has been determined in first // call to CreateFcPrimitive method. 
if (out->format() == MKLDNNMemoryFormat::undef) { - auto output_format = platform::GetMKLDNNFormat(*output_); - out->set_format((MKLDNNMemoryFormat)output_format); + MKLDNNMemoryFormat format; + auto data_type = input_->get_desc().data.data_type; + if (data_type == mkldnn_f32) + format = MKLDNNMemoryFormat::nchw; + else + format = MKLDNNMemoryFormat::nhwc; + + MKLDNNMemoryFormat selected = platform::MKLDNNFormatForSize( + framework::vectorize(out->dims()).size(), format); + + out->set_format(selected); } } - // Choose weight memory format based on input memory format - MKLDNNMemoryFormat MatchWeightFormat(MKLDNNMemoryFormat fmt) { - using format = MKLDNNMemoryFormat; - switch (fmt) { - case format::nChw16c: - return format::aBcd16b; - case format::nChw8c: - return format::aBcd8b; - case format::nchw: - return format::oihw; - case format::nhwc: - return format::hwio; - default: - return format::undef; - } + mkldnn::inner_product_forward::primitive_desc Create2DFcPrimDescriptor( + const LoDTensor* input, const Tensor* weights, const Tensor* bias, + LoDTensor* output, const ExecutionContext& ctx) { + auto src_desc = CreateMemDescriptor(input, input->format()); + auto weight_dims = Get2DWeightDimsForDNNL(weights); + auto weights_desc = + CreateMemDescriptor(weight_dims, MKLDNNMemoryFormat::any); + auto bias_desc = CreateMemDescriptor(bias, MKLDNNMemoryFormat::x); + auto dst_desc = CreateMemDescriptor(output, MKLDNNMemoryFormat::any); + const auto attrs = CreatePostOps(ctx); + return CreateFcPrimDesc(src_desc, weights_desc, bias_desc, dst_desc, attrs); + } + + std::vector Get2DWeightDimsForDNNL(const Tensor* weights) { + auto dims = framework::vectorize(weights->dims()); + std::swap(dims[0], dims[1]); // swap input dim with output dim + return dims; + } + + memory::desc Create2DUserWeightsDesc() { return weights_->get_desc(); } + + mkldnn::inner_product_forward::primitive_desc Create3DFcPrimDescriptor( + const LoDTensor* input, const Tensor* weights, const Tensor* bias, + LoDTensor* output, const ExecutionContext& ctx) { + auto input_dims = framework::vectorize(input->dims()); + std::vector new_input_dims = {input_dims[0] * input_dims[1], 1, + input_dims[2]}; + auto src_desc = CreateMemDescriptor(new_input_dims, input->format()); + + auto weight_dims = Get3DWeightDimsForDNNL(weights); + auto weights_desc = + CreateMemDescriptor(weight_dims, MKLDNNMemoryFormat::any); + + auto bias_desc = CreateMemDescriptor(bias, MKLDNNMemoryFormat::x); + + auto dst_dims = {input_dims[0] * input_dims[1], weight_dims[0]}; + auto dst_desc = + CreateMemDescriptor(dst_dims, MKLDNNMemoryFormat::any); + const auto attrs = CreatePostOps(ctx); + return CreateFcPrimDesc(src_desc, weights_desc, bias_desc, dst_desc, attrs); + } + + std::vector Get3DWeightDimsForDNNL(const Tensor* weights) { + auto paddle_w_dims = framework::vectorize(weights->dims()); + return {paddle_w_dims[1], 1, paddle_w_dims[0]}; + } + + memory::desc Create3DUserWeightsDesc(const Tensor* weights) { + auto dims = Get3DWeightDimsForDNNL(weights); + return CreateMemDescriptor(dims, MKLDNNMemoryFormat::oiw); + } + + mkldnn::inner_product_forward::primitive_desc Create4DFcPrimDescriptor( + const LoDTensor* input, const Tensor* weights, const Tensor* bias, + LoDTensor* output, const ExecutionContext& ctx) { + auto src_desc = CreateMemDescriptor(input, input->format()); + // Since MKL-DNN doesn't support 4D column-major data formats in + // inner_product primitive, transpose the weights to be in + // row-major format + auto dims = 
Get4DWeightDimsForDNNL(input, weights); + auto weights_desc = CreateMemDescriptor(dims, MKLDNNMemoryFormat::any); + auto bias_desc = CreateMemDescriptor(bias, MKLDNNMemoryFormat::x); + auto dst_desc = CreateMemDescriptor(output, MKLDNNMemoryFormat::any); + const auto attrs = CreatePostOps(ctx); + return CreateFcPrimDesc(src_desc, weights_desc, bias_desc, dst_desc, attrs); + } + + std::vector Get4DWeightDimsForDNNL(const LoDTensor* input, + const Tensor* weights) { + auto old_w_dims = framework::vectorize(weights->dims()); + auto old_in_dims = framework::vectorize(input->dims()); + auto dims = {old_w_dims[1], old_in_dims[1], old_in_dims[2], old_in_dims[3]}; + return dims; + } + + memory::desc Create4DUserWeightsDesc(const LoDTensor* input, + const Tensor* weights) { + auto dims = Get4DWeightDimsForDNNL(input, weights); + return CreateMemDescriptor(dims, MKLDNNMemoryFormat::oihw); } // Convert data from one data format to another @@ -247,12 +350,9 @@ class FCPrimitiveFactory { return is_multi_channel_quantizied ? 1 << slice_dimension : 0; } - void QuantizeWeights(const ExecutionContext& ctx) { - auto quantized_desc = weights_->get_desc(); - quantized_desc.data.data_type = - (mkldnn_data_type_t)platform::MKLDNNGetDataType(); - weights_ = Reorder(*weights_, quantized_desc, - ctx.Attr>("Scale_weights")); + void QuantizeWeights(const ExecutionContext& ctx, memory::desc dst) { + weights_ = + Reorder(*weights_, dst, ctx.Attr>("Scale_weights")); } void QuantizeBias(const inner_product_forward::primitive_desc& fc_prim_desc, @@ -282,43 +382,6 @@ class FCPrimitiveFactory { return attributes; } - inner_product_forward CreateFcPrimitive(const memory& src_memory, - const memory& weights_memory, - const memory::desc& dst_desc, - const Tensor* bias, Tensor* output, - const ExecutionContext& ctx) { - // Acquire descriptors needed for creation of inner_product primitive - // descriptor - const auto weights_desc = weights_memory.get_desc(); - const auto src_desc = src_memory.get_desc(); - // Based on provided attributes, create attributes used by MKL-DNN to - // enable fused post-op activations such as 'relu' - const auto attrs = CreatePostOps(ctx); - // If bias exists, create inner_product primitive with or without bias - if (bias) { - auto bias_desc = CreateMemDescriptor(bias, bias->format()); - bias_ = CreateMemory(bias_desc, bias); - // Create inner_product descriptor. At this point the format of output - // is determined. 
- auto fc_prim_desc = - CreateFcPrimDesc(src_desc, weights_desc, bias_desc, dst_desc, attrs); - // If int8 is desired, quantize bias into 32-bit signed int - QuantizeBias(fc_prim_desc, ctx); - - // Based on format determined by inner_product, create output in desired - // memory format - output_ = CreateDstMemory(fc_prim_desc, ctx, output); - - // Return MKL-DNN primitive ready to be fed into pipeline and executed - return inner_product_forward(fc_prim_desc); - } else { - auto fc_prim_desc = - CreateFcPrimDesc(src_desc, weights_desc, dst_desc, attrs); - output_ = CreateDstMemory(fc_prim_desc, ctx, output); - return inner_product_forward(fc_prim_desc); - } - } - mkldnn::inner_product_forward::primitive_desc CreateFcPrimDesc( const mkldnn::memory::desc& input_desc, const mkldnn::memory::desc& weights_desc, @@ -332,43 +395,6 @@ class FCPrimitiveFactory { return inner_product_forward::primitive_desc(fc_desc, attrs, engine_); } - mkldnn::inner_product_forward::primitive_desc CreateFcPrimDesc( - const mkldnn::memory::desc& input_desc, - const mkldnn::memory::desc& weights_desc, - const mkldnn::memory::desc& dst_desc, - const mkldnn::primitive_attr& attrs) { - auto fc_desc = inner_product_forward::desc(prop_kind::forward, input_desc, - weights_desc, dst_desc); - - return inner_product_forward::primitive_desc(fc_desc, attrs, engine_); - } - - // Since MKL-DNN requires the number of input dimensions to be - // equal to the number of weight dimensions, we have to convert - // weights to 4D memory if input is 4D. It also requires that - // all dimensions of weights and inputs agree, with an exception - // for the batch size and number of output channels (the first dim). - // In order to perform that we have to prepare the memory descriptor - // by hand, as MKL-DNN's reorder does not support conversion - // from one dimensionality to another. Hence, we set - // the first dimension of weights to resemble number of outputs - // and then we use the sizes of number of input channels as well - // as image width and height for latter dimensions. Then we create - // memories, find a format corresponding with input format and - // perform a converion. 
- mkldnn::memory CreateFourDimWeightsMemory(const Tensor* input, - const Tensor* weights) { - auto input_dims = framework::vectorize(input->dims()); - auto weight_dims = framework::vectorize(weights->dims()); - auto dims = {weight_dims[1], input_dims[1], input_dims[2], input_dims[3]}; - - auto dst_format = MatchWeightFormat(input->format()); - auto src_desc = CreateMemDescriptor(dims, MKLDNNMemoryFormat::oihw); - auto dst_desc = CreateMemDescriptor(dims, dst_format); - - return Reorder(src_desc, dst_desc, weights_->get_data_handle()); - } - // Create output memory based on output tensor and inner_product // primitive descriptor format chosen for output mkldnn::memory CreateDstMemory( @@ -379,7 +405,18 @@ class FCPrimitiveFactory { T_out* output_data = output->mutable_data(ctx.GetPlace(), buffer_size); memory dst_mem(dst_desc, engine_, to_void_cast(output_data)); - output->set_format(platform::GetMKLDNNFormat(dst_mem)); + + MKLDNNMemoryFormat format; + auto data_type = input_->get_desc().data.data_type; + if (data_type == mkldnn_f32) + format = MKLDNNMemoryFormat::nchw; + else + format = MKLDNNMemoryFormat::nhwc; + + MKLDNNMemoryFormat selected = platform::MKLDNNFormatForSize( + framework::vectorize(output->dims()).size(), format); + + output->set_format(selected); return dst_mem; } diff --git a/python/paddle/fluid/tests/unittests/mkldnn/test_fc_mkldnn_op.py b/python/paddle/fluid/tests/unittests/mkldnn/test_fc_mkldnn_op.py index fb54bf55543..e96b8cf8191 100644 --- a/python/paddle/fluid/tests/unittests/mkldnn/test_fc_mkldnn_op.py +++ b/python/paddle/fluid/tests/unittests/mkldnn/test_fc_mkldnn_op.py @@ -19,14 +19,8 @@ import numpy as np from paddle.fluid.tests.unittests.op_test import OpTest -def fully_connected_naive(input, weights, bias_data=None): - result = None - - if not bias_data: - result = np.dot(input, weights) - else: - result = np.dot(input, weights) + bias_data - +def fully_connected_naive(input, weights, bias_data): + result = np.dot(input, weights) + bias_data return result @@ -39,18 +33,24 @@ class MatrixGenerate: class TestFCMKLDNNOp(OpTest): def create_data(self): self.matrix = MatrixGenerate(1, 10, 15, 3, 3) + self.bias = np.random.random(15).astype("float32") def setUp(self): self.op_type = "fc" self._cpu_only = True self.use_mkldnn = True self.create_data() - self.inputs = {'Input': self.matrix.input, 'W': self.matrix.weights} + self.inputs = { + 'Input': self.matrix.input, + 'W': self.matrix.weights, + 'Bias': self.bias + } - self.attrs = {'use_mkldnn': self.use_mkldnn, } + self.attrs = {'use_mkldnn': self.use_mkldnn} self.outputs = { - 'Out': fully_connected_naive(self.matrix.input, self.matrix.weights) + 'Out': fully_connected_naive(self.matrix.input, self.matrix.weights, + self.bias) } def test_check_output(self): @@ -67,6 +67,7 @@ class TestFCMKLDNNOp(OpTest): class TestFCMKLDNNOp1(TestFCMKLDNNOp): def create_data(self): self.matrix = MatrixGenerate(2, 15, 48, 2, 2) + self.bias = np.random.random(48).astype("float32") if __name__ == "__main__": -- GitLab
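
Note on the 3D path added by this patch: Create3DFcPrimDescriptor collapses the first two input dimensions into one ({d0 * d1, 1, d2}) and views the weights as {out, 1, in} (Get3DWeightDimsForDNNL), so the 3D fully-connected behaves like a 2D FC applied to every (batch, step) row, producing a {d0 * d1, out} destination (dst_dims). The snippet below is a minimal NumPy sketch of that shape handling, written for illustration only; the example tensor shapes and the final reshape back to 3D are assumptions, not code taken from this patch.

    import numpy as np

    def fc_3d_reference(x, w, b):
        # x: [batch, seq, in_features]   (3D input)
        # w: [in_features, out_features] (Paddle weight layout)
        # b: [out_features]
        # Collapse the leading dims, as Create3DFcPrimDescriptor does with
        # new_input_dims = {d0 * d1, 1, d2}, then apply a plain 2D FC.
        batch, seq, in_f = x.shape
        flat = x.reshape(batch * seq, in_f)        # {d0 * d1, d2}
        out = flat @ w + b                          # {d0 * d1, out}, matches dst_dims
        # Viewing the result back as 3D is an assumption for illustration.
        return out.reshape(batch, seq, w.shape[1])

    # Illustrative shapes only (not from the unit test):
    x = np.random.rand(2, 4, 15).astype("float32")
    w = np.random.rand(15, 48).astype("float32")
    b = np.random.rand(48).astype("float32")
    print(fc_3d_reference(x, w, b).shape)  # (2, 4, 48)

In the patch itself the weights are additionally reordered into the format chosen by the inner_product primitive descriptor (QuantizeWeights with fc_prim_desc->weights_desc()), which the sketch above does not model.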