From 22708640191e0c1984fbfac87a6785d69f9958a7 Mon Sep 17 00:00:00 2001
From: wangchaochaohu
Date: Sun, 26 Apr 2020 15:57:15 +0800
Subject: [PATCH] Fusion group optimize for cuda codegen(#23940)

---
 .../ir/fusion_group/code_generator.cc         | 88 +++++++++++--------
 .../ir/fusion_group/code_generator.h          |  4 +
 .../ir/fusion_group/code_generator_helper.cc  |  2 -
 .../ir/fusion_group/code_generator_helper.h   | 10 ++-
 .../ir/fusion_group/cuda_resources.h          | 16 ++++
 .../ir/fusion_group/fusion_group_pass.cc      | 24 ++++-
 .../framework/ir/fusion_group/operation.cc    |  1 +
 .../framework/ir/fusion_group/subgraph.h      | 55 ++++++++----
 8 files changed, 136 insertions(+), 64 deletions(-)

diff --git a/paddle/fluid/framework/ir/fusion_group/code_generator.cc b/paddle/fluid/framework/ir/fusion_group/code_generator.cc
index f88e872dbad..438a06f0f9a 100644
--- a/paddle/fluid/framework/ir/fusion_group/code_generator.cc
+++ b/paddle/fluid/framework/ir/fusion_group/code_generator.cc
@@ -71,6 +71,8 @@ static bool HasInput(Node* n, std::string name) {
 std::vector<OperationExpression> CodeGenerator::ConvertToExpressions(
     SubGraph* subgraph) {
   std::unordered_map<std::string, int> var_ids = EncodeVarNodes(subgraph);
+  std::vector<Node*> intermediate_out_nodes =
+      subgraph->GetIntermediateOutVarNodes();
   std::vector<OperationExpression> expressions;
   for (auto* node : subgraph->SortedNodes()) {
     if (node && node->IsOp() && node->Op()) {
@@ -81,7 +83,8 @@ std::vector<OperationExpression> CodeGenerator::ConvertToExpressions(
       // - X, Y in forward operations
       // - X, Y, Out, out@GRAD in backward operations
       std::vector<int> input_ids;
-      auto operation = OperationMap::Instance().Get(op->Type());
+      std::string op_name = op->Type();
+      auto operation = OperationMap::Instance().Get(op_name);
       std::vector<std::string> input_names = operation.input_names;
 
       for (auto& name : input_names) {
@@ -105,18 +108,28 @@ std::vector<OperationExpression> CodeGenerator::ConvertToExpressions(
       std::vector<int> output_ids;
       std::vector<std::string> output_names =
          OperationMap::Instance().Get(op->Type()).output_names;
+      std::unordered_map<int, bool> intermediate_state;
       for (auto& name : output_names) {
         PADDLE_ENFORCE_NE(
             op->Output(name).size(), 0,
             platform::errors::InvalidArgument(
                 "Output(%s) of operation %s is not set.", name, op->Type()));
         output_ids.push_back(var_ids[op->Output(name)[0]]);
+        bool enable_intermediate = false;
+        for (auto* n : intermediate_out_nodes) {
+          if (n->Name() == op->Output(name)[0]) {
+            enable_intermediate = true;
+            break;
+          }
+        }
+        intermediate_state[var_ids[op->Output(name)[0]]] = enable_intermediate;
       }
 
       std::string lhs_type = ExtractDataType(node->outputs);
       std::string rhs_type = ExtractDataType(node->inputs);
-      auto expression = OperationExpression(node->Name(), input_ids, output_ids,
-                                            rhs_type, lhs_type);
+      auto expression =
+          OperationExpression(node->Name(), input_ids, output_ids, rhs_type,
+                              lhs_type, intermediate_state);
       expression.SetAttr(attr);
       expressions.push_back(expression);
     }
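
For orientation (an editorial sketch, not part of the patch): the intermediate_state map built in the hunk above records, for every encoded output var id, whether that output is consumed only inside the subgraph. A minimal illustration with hypothetical var ids (the real ids are assigned by EncodeVarNodes):

    #include <unordered_map>

    int main() {
      // Suppose the fused subgraph computes: tmp = relu(x); y = tmp * z;
      // and tmp has no consumer outside the subgraph. With hypothetical
      // var ids x=0, z=1, tmp=2, y=3, ConvertToExpressions would build:
      std::unordered_map<int, bool> intermediate_state = {
          {2, true},   // tmp: intermediate; no kernel argument, no store
          {3, false},  // y: real output; stays an argument and is stored
      };
      return intermediate_state.at(2) ? 0 : 1;
    }
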
@@ -133,13 +146,17 @@ std::string CodeGenerator::Generate(
   // TODO(liuyiqun): Check whether all expressions are elementwise operations.
   std::set<int> input_ids = std::move(DistilInputIds(expressions));
   std::set<int> output_ids = std::move(DistilOutputIds(expressions));
+  std::set<int> intermediate_ids =
+      std::move(DistilIntermediateIds(expressions));
   std::unordered_map<int, std::string> dtypes =
       std::move(DistilDtypes(expressions));
   TemplateVariable template_var;
   template_var.Add("func_name", func_name);
-  template_var.Add("parameters", EmitParameters(input_ids, output_ids, dtypes));
+  template_var.Add("parameters", EmitParameters(input_ids, output_ids,
+                                                intermediate_ids, dtypes));
   template_var.Add("compute_body",
-                   EmitComputeBody(expressions, input_ids, output_ids, dtypes));
+                   EmitComputeBody(expressions, input_ids, output_ids,
+                                   intermediate_ids, dtypes));
 
   std::set<std::string> all_dtype;
   for (const auto& type : dtypes) {
@@ -185,6 +202,19 @@ std::set<int> CodeGenerator::DistilOutputIds(
   return output_ids;
 }
 
+std::set<int> CodeGenerator::DistilIntermediateIds(
+    const std::vector<OperationExpression>& expressions) {
+  std::set<int> intermediate_ids;
+  // Use std::set to remove the repeated ids and get an ordered list.
+  for (size_t i = 0; i < expressions.size(); i++) {
+    for (auto id : expressions[i].GetOutputIds()) {
+      auto intermediate_state = expressions[i].GetIntermediateState();
+      if (intermediate_state[id]) intermediate_ids.insert(id);
+    }
+  }
+  return intermediate_ids;
+}
+
 std::unordered_map<int, std::string> CodeGenerator::DistilDtypes(
     const std::vector<OperationExpression>& expressions) {
   std::unordered_map<int, std::string> dtypes;
@@ -218,6 +248,7 @@ std::unordered_map<int, std::string> CodeGenerator::DistilDtypes(
 // we get the parameter list code for the expression information
 std::string CodeGenerator::EmitParameters(
     const std::set<int>& input_ids, const std::set<int>& output_ids,
+    const std::set<int>& intermediate_ids,
     const std::unordered_map<int, std::string>& dtypes) const {
   std::stringstream ret;
   ret << "int N, ";
@@ -226,25 +257,28 @@ std::string CodeGenerator::EmitParameters(
   // from the input list.
   for (auto id : input_ids) {
     if (output_ids.find(id) == output_ids.end()) {
-      ret << dtypes.at(id) << "* " << ArgName(id) << ", ";
+      ret << "const " << dtypes.at(id) << "* __restrict__ " << ArgName(id)
+          << ", ";
     }
   }
 
   size_t index = 0;
   for (auto id : output_ids) {
-    ret << dtypes.at(id) << "* " << ArgName(id);
-    if (index != output_ids.size() - 1) {
-      ret << ", ";
+    if (intermediate_ids.find(id) == intermediate_ids.end()) {
+      ret << dtypes.at(id) << "* " << ArgName(id);
+      if (index != output_ids.size() - 1) {
+        ret << ", ";
+      }
     }
     index++;
   }
-
   return ret.str();
 }
 
 std::string CodeGenerator::EmitComputeBody(
     const std::vector<OperationExpression>& expressions,
     const std::set<int>& input_ids, const std::set<int>& output_ids,
+    const std::set<int>& intermediate_ids,
     const std::unordered_map<int, std::string>& dtypes) const {
   std::ostringstream compute;
   std::unordered_set<int> used;
@@ -258,14 +292,17 @@
   for (auto id : input_ids) {
     if (output_ids.find(id) == output_ids.end() &&
         used.find(id) != used.end()) {
-      load << dtypes.at(id) << " " << TmpName(id) << " = " << VarName(id)
-           << ";";
+      load << dtypes.at(id) << " " << TmpName(id) << " = "
+           << "__ldg(&" << VarName(id) << ")"
+           << ";";
     }
   }
   // Store temporal variables to memory.
   std::ostringstream store;
   for (auto id : output_ids) {
-    store << VarName(id) << " = " << TmpName(id) << ";";
+    if (intermediate_ids.find(id) == intermediate_ids.end()) {
+      store << VarName(id) << " = " << TmpName(id) << ";";
+    }
   }
 
   return load.str() + compute.str() + store.str();
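
Taken together, the EmitParameters and EmitComputeBody changes shrink both the kernel signature and its global-memory traffic: inputs become const pointers qualified with __restrict__ and are read through __ldg, while intermediate outputs lose their pointer argument and their store entirely. A hedged sketch of what the generator might now emit for the tmp = relu(x); y = tmp * z example from earlier (hypothetical names following the ArgName/TmpName/VarName conventions, float dtype; the real loop template lives elsewhere in the codegen):

    extern "C" __global__ void FusedElementwise0(
        int N, const float* __restrict__ arg0, const float* __restrict__ arg1,
        float* arg3) {
      for (int idx = blockIdx.x * blockDim.x + threadIdx.x; idx < N;
           idx += gridDim.x * blockDim.x) {
        float tmp0 = __ldg(&arg0[idx]);    // inputs go through the read-only cache
        float tmp1 = __ldg(&arg1[idx]);
        float tmp2 = tmp0 > 0 ? tmp0 : 0;  // relu: lives only in a register
        float tmp3 = tmp2 * tmp1;          // elementwise_mul
        arg3[idx] = tmp3;                  // only the real output is stored
      }
    }
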
@@ -285,32 +322,7 @@ std::unordered_map<std::string, int> CodeGenerator::EncodeVarNodes(
       var_ids[in->Name()] = id++;
     }
   }
-  // Numbering internal vars.
-  for (auto* node : subgraph->SortedNodes()) {
-    if (node && node->IsVar() && node->Var()) {
-      bool is_found = false;
-      for (auto* in : input_var_nodes) {
-        if (node == in) {
-          is_found = true;
-          break;
-        }
-      }
-      if (is_found) {
-        continue;
-      }
-      for (auto* out : output_var_nodes) {
-        if (node == out) {
-          is_found = true;
-          break;
-        }
-      }
-      PADDLE_ENFORCE_EQ(
-          is_found, true,
-          platform::errors::Unimplemented(
-              "Subgraph with internal var nodes (%s) is not supported yet.",
-              node->Name()));
-    }
-  }
+  // Encoding output vars.
   for (auto* out : output_var_nodes) {
     VLOG(3) << "Encoding output names:" << out->Name() << ", id:" << id;
diff --git a/paddle/fluid/framework/ir/fusion_group/code_generator.h b/paddle/fluid/framework/ir/fusion_group/code_generator.h
index 670fc47a6af..2b18657bbcf 100644
--- a/paddle/fluid/framework/ir/fusion_group/code_generator.h
+++ b/paddle/fluid/framework/ir/fusion_group/code_generator.h
@@ -43,17 +43,21 @@ class CodeGenerator {
       const std::vector<OperationExpression>& expressions);
   std::set<int> DistilOutputIds(
       const std::vector<OperationExpression>& expressions);
+  std::set<int> DistilIntermediateIds(
+      const std::vector<OperationExpression>& expressions);
   std::unordered_map<int, std::string> DistilDtypes(
       const std::vector<OperationExpression>& expressions);
 
   // we get the parameter list code for the expression information
   std::string EmitParameters(
       const std::set<int>& input_ids, const std::set<int>& output_ids,
+      const std::set<int>& intermediate_ids,
       const std::unordered_map<int, std::string>& dtypes) const;
 
   std::string EmitComputeBody(
       const std::vector<OperationExpression>& expressions,
       const std::set<int>& input_ids, const std::set<int>& output_ids,
+      const std::set<int>& intermediate_ids,
       const std::unordered_map<int, std::string>& dtypes) const;
 
   // Encode all var nodes in the subgraph with a unique number.
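
The DistilIntermediateIds declared above pulls the per-expression state out of OperationExpression, whose constructor gains the extra map in the next file. A minimal construction sketch, assuming the fusion_group headers are on the include path and reusing the hypothetical ids from earlier:

    #include "paddle/fluid/framework/ir/fusion_group/code_generator_helper.h"

    using paddle::framework::ir::fusion_group::OperationExpression;

    // Hypothetical op: y (id 3) = tmp (id 2) * z (id 1). y escapes the
    // subgraph, so its intermediate state is false and it keeps its store.
    OperationExpression expr("elementwise_mul", /*input_ids=*/{2, 1},
                             /*output_ids=*/{3}, /*rhs_type=*/"float",
                             /*lhs_type=*/"float",
                             /*intermediate_state=*/{{3, false}});
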
diff --git a/paddle/fluid/framework/ir/fusion_group/code_generator_helper.cc b/paddle/fluid/framework/ir/fusion_group/code_generator_helper.cc
index ed6007d36e0..69a78e1c753 100644
--- a/paddle/fluid/framework/ir/fusion_group/code_generator_helper.cc
+++ b/paddle/fluid/framework/ir/fusion_group/code_generator_helper.cc
@@ -149,8 +149,6 @@ std::string OperationExpression::GetRHS(std::unordered_set<int>* used,
                           "Expected %d-th input id > 0 for operation < %s "
                           ">. Received %d.",
                           index, op_type_, input_ids_[index]));
-      // TODO(wangchaochaohu): Here fp16 convert to float to do comupte, we
-      // need to add general fp16 compute later.
       var_name = TmpName(input_ids_[index]);
       rhs.replace(pos, length + 3, var_name);
       used->insert(input_ids_[index]);
diff --git a/paddle/fluid/framework/ir/fusion_group/code_generator_helper.h b/paddle/fluid/framework/ir/fusion_group/code_generator_helper.h
index 2f4a5f5901f..63197c00ff1 100644
--- a/paddle/fluid/framework/ir/fusion_group/code_generator_helper.h
+++ b/paddle/fluid/framework/ir/fusion_group/code_generator_helper.h
@@ -46,14 +46,19 @@ class OperationExpression {
  public:
   explicit OperationExpression(std::string op_type, std::vector<int> input_ids,
                                std::vector<int> output_ids,
-                               std::string rhs_type, std::string lhs_type)
+                               std::string rhs_type, std::string lhs_type,
+                               std::unordered_map<int, bool> intermediate_state)
       : op_type_(op_type),
         input_ids_(input_ids),
         output_ids_(output_ids),
         rhs_type_(rhs_type),
-        lhs_type_(lhs_type) {}
+        lhs_type_(lhs_type),
+        intermediate_state_(intermediate_state) {}
 
   std::string GetOpType() const { return op_type_; }
+  std::unordered_map<int, bool> GetIntermediateState() const {
+    return intermediate_state_;
+  }
   std::vector<int> GetInputIds() const { return input_ids_; }
   std::vector<int> GetOutputIds() const { return output_ids_; }
   std::string GetRHSType() const { return rhs_type_; }
@@ -78,6 +83,7 @@ class OperationExpression {
   AttributeMap attr_;
   std::string rhs_type_;
   std::string lhs_type_;
+  std::unordered_map<int, bool> intermediate_state_;
 };
 
 class TemplateVariable {
diff --git a/paddle/fluid/framework/ir/fusion_group/cuda_resources.h b/paddle/fluid/framework/ir/fusion_group/cuda_resources.h
index 945b7929db9..6514b87b06e 100644
--- a/paddle/fluid/framework/ir/fusion_group/cuda_resources.h
+++ b/paddle/fluid/framework/ir/fusion_group/cuda_resources.h
@@ -269,6 +269,22 @@ __CUDA_FP16_DECL__ __half hsqrt(const __half a) {
   __APPROX_FCAST(sqrt);
 }
 
+#if defined(__cplusplus) && (__CUDA_ARCH__ >= 320 || !defined(__CUDA_ARCH__))
+#if (defined(_MSC_VER) && defined(_WIN64)) || defined(__LP64__) || defined(__CUDACC_RTC__)
+#define __LDG_PTR "l"
+#else
+#define __LDG_PTR "r"
+#endif /*(defined(_MSC_VER) && defined(_WIN64)) || defined(__LP64__) || defined(__CUDACC_RTC__)*/
+__CUDA_FP16_DECL__ __half __ldg(const __half *ptr)
+{
+    __half ret;
+    asm ("ld.global.nc.b16 %0, [%1];" : "=h"(__HALF_TO_US(ret)) : __LDG_PTR(ptr));
+    return ret;
+}
+
+#undef __LDG_PTR
+#endif /*defined(__cplusplus) && (__CUDA_ARCH__ >= 320 || !defined(__CUDA_ARCH__))*/
+
 __device__ inline __half Exp(const __half x) { return hexp(x); }
 __device__ inline __half Log(const __half x) { return hlog(x); }
 __device__ inline __half Sqrt(const __half x) { return hsqrt(x); }
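
The __ldg overload above follows the pattern in CUDA's cuda_fp16.hpp, so that the NVRTC-compiled fused kernels can issue read-only-cache loads (ld.global.nc) for __half as well; for built-in scalar types the intrinsic already exists under the same architecture guard. A standalone toy kernel showing the same access pattern (illustrative only, not from the patch):

    __global__ void scale(int n, const float* __restrict__ x, float* y,
                          float alpha) {
      int i = blockIdx.x * blockDim.x + threadIdx.x;
      if (i < n) {
        // __ldg routes the load through the read-only data cache, which is
        // exactly what the generated fused kernels now do for every input.
        y[i] = alpha * __ldg(&x[i]);
      }
    }
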
diff --git a/paddle/fluid/framework/ir/fusion_group/fusion_group_pass.cc b/paddle/fluid/framework/ir/fusion_group/fusion_group_pass.cc
index a34b27b6418..7ac14a698aa 100644
--- a/paddle/fluid/framework/ir/fusion_group/fusion_group_pass.cc
+++ b/paddle/fluid/framework/ir/fusion_group/fusion_group_pass.cc
@@ -48,13 +48,18 @@ int FusionGroupPass::DetectFusionGroup(Graph* graph, int type) const {
 
   int num_subgraphs = 0;
   size_t min_subgraph_size = 2;
-  bool save_intermediate_out = true;
+  bool save_intermediate_out = false;
   for (auto& vec : subgraphs) {
     fusion_group::SubGraph subgraph(
         type, "", save_intermediate_out,
         std::unordered_set<Node*>(vec.begin(), vec.end()));
     VLOG(3) << "subgraph: {\n" << DebugString(subgraph.SortedNodes()) << "}\n";
+    // An elementwise fused kernel is memory-bound, so we remove the
+    // intermediate output ids here to use less memory and less time.
+    if (subgraph.RemoveIntermediateOut()) {
+      subgraph.DetectIntermediateOutWithGraph(graph);
+    }
     if (subgraph.IsValid(min_subgraph_size)) {
       subgraph.SetFuncName("FusedElementwise" + std::to_string(index++));
       if (GenerateCode(&subgraph)) {
@@ -106,6 +111,8 @@ void FusionGroupPass::InsertFusionGroupOp(
       subgraph->GetInputVarNodes();
   const std::vector<Node*>& output_vars_of_subgraph =
       subgraph->GetOutputVarNodes();
+  const std::vector<Node*> intermediate_vars_of_subgraph =
+      subgraph->GetIntermediateOutVarNodes();
   std::unordered_set<Node*> external_nodes;
 
   OpDesc op_desc;
@@ -122,9 +129,18 @@ void FusionGroupPass::InsertFusionGroupOp(
 
   std::vector<std::string> output_names;
   std::vector<std::string> outs_data_types;
+  std::vector<Node*> output_var_without_intermediate;
   for (auto* n : output_vars_of_subgraph) {
-    output_names.push_back(n->Name());
-    outs_data_types.push_back(DataTypeToString(n->Var()->GetDataType()));
+    auto it_input =
+        find(input_vars_of_subgraph.begin(), input_vars_of_subgraph.end(), n);
+    auto it_intermediate = find(intermediate_vars_of_subgraph.begin(),
+                                intermediate_vars_of_subgraph.end(), n);
+    if (it_intermediate == intermediate_vars_of_subgraph.end() &&
+        it_input == input_vars_of_subgraph.end()) {
+      output_names.push_back(n->Name());
+      outs_data_types.push_back(DataTypeToString(n->Var()->GetDataType()));
+      output_var_without_intermediate.push_back(n);
+    }
     external_nodes.insert(n);
   }
 
@@ -141,7 +157,7 @@ void FusionGroupPass::InsertFusionGroupOp(
     IR_NODE_LINK_TO(in, fusion_group_node);
   }
 
-  for (auto* out : output_vars_of_subgraph) {
+  for (auto* out : output_var_without_intermediate) {
     IR_NODE_LINK_TO(fusion_group_node, out);
   }
 
diff --git a/paddle/fluid/framework/ir/fusion_group/operation.cc b/paddle/fluid/framework/ir/fusion_group/operation.cc
index ef715384573..b127d132baf 100644
--- a/paddle/fluid/framework/ir/fusion_group/operation.cc
+++ b/paddle/fluid/framework/ir/fusion_group/operation.cc
@@ -54,6 +54,7 @@ void OperationMap::Insert(int type, int num_operands, std::string op_type,
   std::string grad_op_type = op_type + "_grad";
   // grad_inputs = inputs + outputs + grad of outputs
   std::vector<std::string> grad_input_names = input_names;
+
   for (auto name : output_names) {
     grad_input_names.push_back(name);
   }
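
DetectIntermediateOutWithGraph in the next file encodes the rule these pass changes rely on: a var node of the subgraph is an intermediate output iff it has at least one consumer op in the whole graph and every such consumer belongs to the subgraph. A self-contained sketch of that predicate with simplified stand-in types (not the actual ir::Node API):

    #include <unordered_set>
    #include <vector>

    struct OpNode;  // stand-in for an operator node

    struct VarNode {
      std::vector<const OpNode*> consumers;  // op nodes reading this var
    };

    bool IsIntermediateOut(const VarNode& var,
                           const std::unordered_set<const OpNode*>& subgraph_ops) {
      if (var.consumers.empty()) return false;  // a graph leaf must be written out
      for (const OpNode* op : var.consumers) {
        if (subgraph_ops.count(op) == 0) return false;  // escapes the subgraph
      }
      return true;
    }
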
diff --git a/paddle/fluid/framework/ir/fusion_group/subgraph.h b/paddle/fluid/framework/ir/fusion_group/subgraph.h
index 029166cbe17..40707b7c8bb 100644
--- a/paddle/fluid/framework/ir/fusion_group/subgraph.h
+++ b/paddle/fluid/framework/ir/fusion_group/subgraph.h
@@ -19,6 +19,8 @@ limitations under the License. */
 #include <utility>
 #include <vector>
 #include "paddle/fluid/framework/ir/fusion_group/operation.h"
+#include "paddle/fluid/framework/ir/graph.h"
+#include "paddle/fluid/framework/ir/graph_traits.h"
 #include "paddle/fluid/framework/ir/node.h"
 #include "paddle/fluid/framework/ir/subgraph_detector.h"
 
@@ -64,6 +66,7 @@ class SubGraph {
   }
 
   int GetType() const { return type_; }
+  bool RemoveIntermediateOut() { return !save_intermediate_out_; }
 
   void SetFuncName(std::string func_name) { func_name_ = func_name; }
   std::string GetFuncName() const { return func_name_; }
@@ -133,30 +136,45 @@ class SubGraph {
         }
       }
     }
+    return output_vars_all;
+  }
 
-    if (save_intermediate_out_) {
-      return output_vars_all;
-    }
+  std::vector<Node*> GetIntermediateOutVarNodes() {
+    return intermediate_out_nodes_;
+  }
 
-    std::vector<Node*> output_vars_outside;
-    for (auto* n : output_vars_all) {
-      // If one of the var_node's outputs is the input of some operator
-      // outside the subgraph, it is considered the output var node of the
-      // subgraph.
-      bool is_found = true;
-      if (n->outputs.size() == 0U) {
-        is_found = false;
-      }
-      for (auto* out : n->outputs) {
-        if (!Has(out)) {
-          is_found = false;
+  void DetectIntermediateOutWithGraph(Graph* graph) {
+    auto graph_nodes = graph->Nodes();
+
+    for (auto* n : SortedNodes()) {
+      bool enable_remove = true;
+
+      if (n && n->IsVar() && n->Var()) {
+        bool leaf_graph = true;
+        for (auto* node : graph_nodes) {
+          if (node->IsOp()) {
+            auto inputs = node->inputs;
+            for (auto* in : inputs) {
+              if (in == n) {
+                if (!Has(node)) enable_remove = false;
+                leaf_graph = false;
+              }
+            }
+          }
+          if (!enable_remove) {
+            break;
+          }
         }
+        if (leaf_graph) enable_remove = false;
+
+      } else {
+        enable_remove = false;
       }
-      if (!is_found) {
-        output_vars_outside.push_back(n);
+
+      if (enable_remove) {
+        intermediate_out_nodes_.push_back(n);
       }
     }
-    return output_vars_outside;
   }
 
  private:
@@ -218,6 +236,7 @@ class SubGraph {
   bool save_intermediate_out_{true};
 
   std::unordered_set<Node*> nodes_set_;
+  std::vector<Node*> intermediate_out_nodes_{};
   bool is_sorted_{false};
   std::vector<Node*> sorted_nodes_;
 };
-- 
GitLab