未验证 提交 22708640 编写于 作者: W wangchaochaohu 提交者: GitHub

Fusion group optimize for cuda codegen(#23940)

上级 94dfb7d7
......@@ -71,6 +71,8 @@ static bool HasInput(Node* n, std::string name) {
std::vector<OperationExpression> CodeGenerator::ConvertToExpressions(
SubGraph* subgraph) {
std::unordered_map<std::string, int> var_ids = EncodeVarNodes(subgraph);
std::vector<Node*> intermediate_out_nodes =
subgraph->GetIntermediateOutVarNodes();
std::vector<OperationExpression> expressions;
for (auto* node : subgraph->SortedNodes()) {
if (node && node->IsOp() && node->Op()) {
......@@ -81,7 +83,8 @@ std::vector<OperationExpression> CodeGenerator::ConvertToExpressions(
// - X, Y in forward operations
// - X, Y, Out, out@GRAD in backward operations
std::vector<int> input_ids;
auto operation = OperationMap::Instance().Get(op->Type());
std::string op_name = op->Type();
auto operation = OperationMap::Instance().Get(op_name);
std::vector<std::string> input_names = operation.input_names;
for (auto& name : input_names) {
......@@ -105,6 +108,7 @@ std::vector<OperationExpression> CodeGenerator::ConvertToExpressions(
std::vector<int> output_ids;
std::vector<std::string> output_names =
OperationMap::Instance().Get(op->Type()).output_names;
std::unordered_map<int, bool> intermediate_state;
for (auto& name : output_names) {
PADDLE_ENFORCE_NE(
......@@ -112,12 +116,21 @@ std::vector<OperationExpression> CodeGenerator::ConvertToExpressions(
platform::errors::InvalidArgument(
"Output(%s) of operation %s is not set.", name, op->Type()));
output_ids.push_back(var_ids[op->Output(name)[0]]);
bool enable_intermediate = false;
for (auto* n : intermediate_out_nodes) {
if (n->Name() == op->Output(name)[0]) {
enable_intermediate = true;
break;
}
}
intermediate_state[var_ids[op->Output(name)[0]]] = enable_intermediate;
}
std::string lhs_type = ExtractDataType(node->outputs);
std::string rhs_type = ExtractDataType(node->inputs);
auto expression = OperationExpression(node->Name(), input_ids, output_ids,
rhs_type, lhs_type);
auto expression =
OperationExpression(node->Name(), input_ids, output_ids, rhs_type,
lhs_type, intermediate_state);
expression.SetAttr(attr);
expressions.push_back(expression);
}
......@@ -133,13 +146,17 @@ std::string CodeGenerator::Generate(
// TODO(liuyiqun): Check whether all expressions are elementwise operations.
std::set<int> input_ids = std::move(DistilInputIds(expressions));
std::set<int> output_ids = std::move(DistilOutputIds(expressions));
std::set<int> intermediate_ids =
std::move(DistilIntermediateIds(expressions));
std::unordered_map<int, std::string> dtypes =
std::move(DistilDtypes(expressions));
TemplateVariable template_var;
template_var.Add("func_name", func_name);
template_var.Add("parameters", EmitParameters(input_ids, output_ids, dtypes));
template_var.Add("parameters", EmitParameters(input_ids, output_ids,
intermediate_ids, dtypes));
template_var.Add("compute_body",
EmitComputeBody(expressions, input_ids, output_ids, dtypes));
EmitComputeBody(expressions, input_ids, output_ids,
intermediate_ids, dtypes));
std::set<std::string> all_dtype;
for (const auto& type : dtypes) {
......@@ -185,6 +202,19 @@ std::set<int> CodeGenerator::DistilOutputIds(
return output_ids;
}
std::set<int> CodeGenerator::DistilIntermediateIds(
const std::vector<OperationExpression>& expressions) {
std::set<int> intermediate_ids;
// Use std::set to remove the reptead id and get a ordered list.
for (size_t i = 0; i < expressions.size(); i++) {
for (auto id : expressions[i].GetOutputIds()) {
auto intermediate_state = expressions[i].GetIntermediateState();
if (intermediate_state[id]) intermediate_ids.insert(id);
}
}
return intermediate_ids;
}
std::unordered_map<int, std::string> CodeGenerator::DistilDtypes(
const std::vector<OperationExpression>& expressions) {
std::unordered_map<int, std::string> dtypes;
......@@ -218,6 +248,7 @@ std::unordered_map<int, std::string> CodeGenerator::DistilDtypes(
// we get the parameter list code for the expression information
std::string CodeGenerator::EmitParameters(
const std::set<int>& input_ids, const std::set<int>& output_ids,
const std::set<int>& intermediate_ids,
const std::unordered_map<int, std::string>& dtypes) const {
std::stringstream ret;
ret << "int N, ";
......@@ -226,25 +257,28 @@ std::string CodeGenerator::EmitParameters(
// from the input list.
for (auto id : input_ids) {
if (output_ids.find(id) == output_ids.end()) {
ret << dtypes.at(id) << "* " << ArgName(id) << ", ";
ret << "const " << dtypes.at(id) << "* __restrict__ " << ArgName(id)
<< ", ";
}
}
size_t index = 0;
for (auto id : output_ids) {
ret << dtypes.at(id) << "* " << ArgName(id);
if (index != output_ids.size() - 1) {
ret << ", ";
if (intermediate_ids.find(id) == intermediate_ids.end()) {
ret << dtypes.at(id) << "* " << ArgName(id);
if (index != output_ids.size() - 1) {
ret << ", ";
}
}
index++;
}
return ret.str();
}
std::string CodeGenerator::EmitComputeBody(
const std::vector<OperationExpression>& expressions,
const std::set<int>& input_ids, const std::set<int>& output_ids,
const std::set<int>& intermediate_ids,
const std::unordered_map<int, std::string>& dtypes) const {
std::ostringstream compute;
std::unordered_set<int> used;
......@@ -258,14 +292,17 @@ std::string CodeGenerator::EmitComputeBody(
for (auto id : input_ids) {
if (output_ids.find(id) == output_ids.end() &&
used.find(id) != used.end()) {
load << dtypes.at(id) << " " << TmpName(id) << " = " << VarName(id)
load << dtypes.at(id) << " " << TmpName(id) << " = "
<< "__ldg(&" << VarName(id) << ")"
<< ";";
}
}
// Store temporal variables to memory.
std::ostringstream store;
for (auto id : output_ids) {
store << VarName(id) << " = " << TmpName(id) << ";";
if (intermediate_ids.find(id) == intermediate_ids.end()) {
store << VarName(id) << " = " << TmpName(id) << ";";
}
}
return load.str() + compute.str() + store.str();
......@@ -285,32 +322,7 @@ std::unordered_map<std::string, int> CodeGenerator::EncodeVarNodes(
var_ids[in->Name()] = id++;
}
}
// Numbering internal vars.
for (auto* node : subgraph->SortedNodes()) {
if (node && node->IsVar() && node->Var()) {
bool is_found = false;
for (auto* in : input_var_nodes) {
if (node == in) {
is_found = true;
break;
}
}
if (is_found) {
continue;
}
for (auto* out : output_var_nodes) {
if (node == out) {
is_found = true;
break;
}
}
PADDLE_ENFORCE_EQ(
is_found, true,
platform::errors::Unimplemented(
"Subgraph with internal var nodes (%s) is not supported yet.",
node->Name()));
}
}
// Encoding output vars.
for (auto* out : output_var_nodes) {
VLOG(3) << "Ecoding output names:" << out->Name() << ", id:" << id;
......
......@@ -43,17 +43,21 @@ class CodeGenerator {
const std::vector<OperationExpression>& expressions);
std::set<int> DistilOutputIds(
const std::vector<OperationExpression>& expressions);
std::set<int> DistilIntermediateIds(
const std::vector<OperationExpression>& expressions);
std::unordered_map<int, std::string> DistilDtypes(
const std::vector<OperationExpression>& expressions);
// we get the parameter list code for the expression information
std::string EmitParameters(
const std::set<int>& input_ids, const std::set<int>& output_ids,
const std::set<int>& intermediate_ids,
const std::unordered_map<int, std::string>& dtypes) const;
std::string EmitComputeBody(
const std::vector<OperationExpression>& expressions,
const std::set<int>& input_ids, const std::set<int>& output_ids,
const std::set<int>& intermediate_ids,
const std::unordered_map<int, std::string>& dtypes) const;
// Encode all var nodes in the subgraph with an unique number.
......
......@@ -149,8 +149,6 @@ std::string OperationExpression::GetRHS(std::unordered_set<int>* used,
"Expected %d-th input id > 0 for operation < %s "
">. Received %d.",
index, op_type_, input_ids_[index]));
// TODO(wangchaochaohu): Here fp16 convert to float to do comupte, we
// need to add general fp16 compute later.
var_name = TmpName(input_ids_[index]);
rhs.replace(pos, length + 3, var_name);
used->insert(input_ids_[index]);
......
......@@ -46,14 +46,19 @@ class OperationExpression {
public:
explicit OperationExpression(std::string op_type, std::vector<int> input_ids,
std::vector<int> output_ids,
std::string rhs_type, std::string lhs_type)
std::string rhs_type, std::string lhs_type,
std::unordered_map<int, bool> intermediate_state)
: op_type_(op_type),
input_ids_(input_ids),
output_ids_(output_ids),
rhs_type_(rhs_type),
lhs_type_(lhs_type) {}
lhs_type_(lhs_type),
intermediate_state_(intermediate_state) {}
std::string GetOpType() const { return op_type_; }
std::unordered_map<int, bool> GetIntermediateState() const {
return intermediate_state_;
}
std::vector<int> GetInputIds() const { return input_ids_; }
std::vector<int> GetOutputIds() const { return output_ids_; }
std::string GetRHSType() const { return rhs_type_; }
......@@ -78,6 +83,7 @@ class OperationExpression {
AttributeMap attr_;
std::string rhs_type_;
std::string lhs_type_;
std::unordered_map<int, bool> intermediate_state_;
};
class TemplateVariable {
......
......@@ -269,6 +269,22 @@ __CUDA_FP16_DECL__ __half hsqrt(const __half a) {
__APPROX_FCAST(sqrt);
}
#if defined(__cplusplus) && (__CUDA_ARCH__ >= 320 || !defined(__CUDA_ARCH__))
#if (defined(_MSC_VER) && defined(_WIN64)) || defined(__LP64__) || defined(__CUDACC_RTC__)
#define __LDG_PTR "l"
#else
#define __LDG_PTR "r"
#endif /*(defined(_MSC_VER) && defined(_WIN64)) || defined(__LP64__) || defined(__CUDACC_RTC__)*/
__CUDA_FP16_DECL__ __half __ldg(const __half *ptr)
{
__half ret;
asm ("ld.global.nc.b16 %0, [%1];" : "=h"(__HALF_TO_US(ret)) : __LDG_PTR(ptr));
return ret;
}
#undef __LDG_PTR
#endif /*defined(__cplusplus) && (__CUDA_ARCH__ >= 320 || !defined(__CUDA_ARCH__))*/
__device__ inline __half Exp(const __half x) { return hexp(x); }
__device__ inline __half Log(const __half x) { return hlog(x); }
__device__ inline __half Sqrt(const __half x) { return hsqrt(x); }
......
......@@ -48,13 +48,18 @@ int FusionGroupPass::DetectFusionGroup(Graph* graph, int type) const {
int num_subgraphs = 0;
size_t min_subgraph_size = 2;
bool save_intermediate_out = true;
bool save_intermediate_out = false;
for (auto& vec : subgraphs) {
fusion_group::SubGraph subgraph(
type, "", save_intermediate_out,
std::unordered_set<Node*>(vec.begin(), vec.end()));
VLOG(3) << "subgraph: {\n" << DebugString(subgraph.SortedNodes()) << "}\n";
// In elementwise fused kernel, memory is the bound of execution,
// here we remove the output id to use less memory and less time.
if (subgraph.RemoveIntermediateOut()) {
subgraph.DetectIntermediateOutWithGraph(graph);
}
if (subgraph.IsValid(min_subgraph_size)) {
subgraph.SetFuncName("FusedElementwise" + std::to_string(index++));
if (GenerateCode(&subgraph)) {
......@@ -106,6 +111,8 @@ void FusionGroupPass::InsertFusionGroupOp(
subgraph->GetInputVarNodes();
const std::vector<Node*>& output_vars_of_subgraph =
subgraph->GetOutputVarNodes();
const std::vector<Node*> intermediate_vars_of_subgraph =
subgraph->GetIntermediateOutVarNodes();
std::unordered_set<Node*> external_nodes;
OpDesc op_desc;
......@@ -122,9 +129,18 @@ void FusionGroupPass::InsertFusionGroupOp(
std::vector<std::string> output_names;
std::vector<std::string> outs_data_types;
std::vector<Node*> output_var_without_intermediate;
for (auto* n : output_vars_of_subgraph) {
output_names.push_back(n->Name());
outs_data_types.push_back(DataTypeToString(n->Var()->GetDataType()));
auto it_input =
find(input_vars_of_subgraph.begin(), input_vars_of_subgraph.end(), n);
auto it_intermediate = find(intermediate_vars_of_subgraph.begin(),
intermediate_vars_of_subgraph.end(), n);
if (it_intermediate == intermediate_vars_of_subgraph.end() &&
it_input == input_vars_of_subgraph.end()) {
output_names.push_back(n->Name());
outs_data_types.push_back(DataTypeToString(n->Var()->GetDataType()));
output_var_without_intermediate.push_back(n);
}
external_nodes.insert(n);
}
......@@ -141,7 +157,7 @@ void FusionGroupPass::InsertFusionGroupOp(
IR_NODE_LINK_TO(in, fusion_group_node);
}
for (auto* out : output_vars_of_subgraph) {
for (auto* out : output_var_without_intermediate) {
IR_NODE_LINK_TO(fusion_group_node, out);
}
......
......@@ -54,6 +54,7 @@ void OperationMap::Insert(int type, int num_operands, std::string op_type,
std::string grad_op_type = op_type + "_grad";
// grad_inputs = inputs + outputs + grad of outputs
std::vector<std::string> grad_input_names = input_names;
for (auto name : output_names) {
grad_input_names.push_back(name);
}
......
......@@ -19,6 +19,8 @@ limitations under the License. */
#include <unordered_set>
#include <vector>
#include "paddle/fluid/framework/ir/fusion_group/operation.h"
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/graph_traits.h"
#include "paddle/fluid/framework/ir/node.h"
#include "paddle/fluid/framework/ir/subgraph_detector.h"
......@@ -64,6 +66,7 @@ class SubGraph {
}
int GetType() const { return type_; }
bool RemoveIntermediateOut() { return !save_intermediate_out_; }
void SetFuncName(std::string func_name) { func_name_ = func_name; }
std::string GetFuncName() const { return func_name_; }
......@@ -133,30 +136,45 @@ class SubGraph {
}
}
}
return output_vars_all;
}
if (save_intermediate_out_) {
return output_vars_all;
}
std::vector<Node*> GetIntermediateOutVarNodes() {
return intermediate_out_nodes_;
}
std::vector<Node*> output_vars_outside;
for (auto* n : output_vars_all) {
// If one of the var_node's outputs is the input of some operator
// outside the subgraph, it is considered the output var node of the
// subgraph.
bool is_found = true;
if (n->outputs.size() == 0U) {
is_found = false;
}
for (auto* out : n->outputs) {
if (!Has(out)) {
is_found = false;
void DetectIntermediateOutWithGraph(Graph* graph) {
auto graph_nodes = graph->Nodes();
for (auto* n : SortedNodes()) {
bool enable_remove = true;
if (n && n->IsVar() && n->Var()) {
bool leaf_graph = true;
for (auto* node : graph_nodes) {
if (node->IsOp()) {
auto inputs = node->inputs;
for (auto* in : inputs) {
if (in == n) {
if (!Has(node)) enable_remove = false;
leaf_graph = false;
}
}
}
if (!enable_remove) {
break;
}
}
if (leaf_graph) enable_remove = false;
} else {
enable_remove = false;
}
if (!is_found) {
output_vars_outside.push_back(n);
if (enable_remove) {
intermediate_out_nodes_.push_back(n);
}
}
return output_vars_outside;
}
private:
......@@ -218,6 +236,7 @@ class SubGraph {
bool save_intermediate_out_{true};
std::unordered_set<Node*> nodes_set_;
std::vector<Node*> intermediate_out_nodes_{};
bool is_sorted_{false};
std::vector<Node*> sorted_nodes_;
};
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册