Unverified commit f0d193a2, authored by W wangchaochaohu, committed by GitHub

Cast fusion for fusion group (#22876)

* Add support for expression type conversion and for the cast Op in fusion group (a hedged sketch of the intended generated code follows below)
Parent 29a7a52d
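For orientation, here is a hedged sketch of the fused device code this change aims to produce when a cast sits inside the group. It is not verbatim generator output and is not part of this commit's files: the kernel name, the grid-stride loop form, and the variable ids are assumptions, while the arg*/tmp* naming follows the ArgName/TmpName/VarName helpers edited below.

extern "C" __global__ void fused_mul_cast_example(int N, float* arg0,
                                                  float* arg1, double* arg3) {
  for (int idx = blockIdx.x * blockDim.x + threadIdx.x; idx < N;
       idx += gridDim.x * blockDim.x) {
    float tmp0 = arg0[idx];                   // load: "float tmp0 = arg0[idx];"
    float tmp1 = arg1[idx];
    float tmp2 = tmp0 * tmp1;                 // elementwise_mul in float
    double tmp3 = static_cast<double>(tmp2);  // cast: lhs_type "double", rhs_type "float"
    arg3[idx] = tmp3;                         // store: "arg3[idx] = tmp3;"
  }
}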
......@@ -24,6 +24,21 @@ namespace framework {
namespace ir {
namespace fusion_group {
std::string ExtractDataType(const std::vector<Node*> nodes) {
std::string dtype_str = "float";
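// Note: the nodes passed in are expected to share one data type (the group
// detector later in this diff only admits nodes whose inputs/outputs each
// share a data type), so inspecting the last node is sufficient; "float" is
// only the fallback.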
auto data_type = nodes.back()->Var()->GetDataType();
if (data_type == proto::VarType::FP32) {
dtype_str = "float";
} else if (data_type == proto::VarType::FP64) {
dtype_str = "double";
} else if (data_type == proto::VarType::FP16) {
dtype_str = "float16";
}
return dtype_str;
}
CodeGenerator::CodeGenerator() {
// Only support elementwise operations now.
code_templates_.resize(1);
......@@ -34,8 +49,7 @@ CodeGenerator::CodeGenerator() {
std::string CodeGenerator::Generate(SubGraph* subgraph) {
std::vector<OperationExpression> expressions = ConvertToExpressions(subgraph);
return Generate(subgraph->GetFuncName(), subgraph->GetDataType(),
expressions);
return Generate(subgraph->GetFuncName(), expressions);
}
static bool HasInput(Node* n, std::string name) {
......@@ -95,8 +109,11 @@ std::vector<OperationExpression> CodeGenerator::ConvertToExpressions(
"Output(%s) of operation %s is not set.", name, op->Type()));
output_ids.push_back(var_ids[op->Output(name)[0]]);
}
expressions.push_back(
OperationExpression(node->Name(), input_ids, output_ids));
std::string lhs_type = ExtractDataType(node->outputs);
std::string rhs_type = ExtractDataType(node->inputs);
expressions.emplace_back(OperationExpression(
node->Name(), input_ids, output_ids, rhs_type, lhs_type));
}
}
return expressions;
......@@ -105,25 +122,32 @@ std::vector<OperationExpression> CodeGenerator::ConvertToExpressions(
// In order to get the right result of the expression, we calculate and store
// the expressions in suffix (postfix) order using a vector.
std::string CodeGenerator::Generate(
std::string func_name, std::string dtype,
std::string func_name,
const std::vector<OperationExpression>& expressions) {
// TODO(liuyiqun): Check whether all expressions are elementwise operations.
std::set<int> input_ids = DistilInputIds(expressions);
std::set<int> output_ids = DistilOutputIds(expressions);
std::unordered_map<int, std::string> dtypes = DistilDtypes(expressions);
TemplateVariable template_var;
template_var.Add("func_name", func_name);
template_var.Add("parameters", EmitParameters(input_ids, output_ids, dtype));
template_var.Add("parameters", EmitParameters(input_ids, output_ids, dtypes));
template_var.Add("compute_body",
EmitComputeBody(expressions, input_ids, output_ids, dtype));
std::string predefined_cuda_functions;
if (dtype == "float") {
predefined_cuda_functions = predefined_cuda_functions_fp32;
} else if (dtype == "double") {
predefined_cuda_functions = predefined_cuda_functions_fp64;
} else if (dtype == "float16") {
predefined_cuda_functions = predefined_cuda_functions_fp16;
EmitComputeBody(expressions, input_ids, output_ids, dtypes));
std::set<std::string> all_dtype;
for (const auto& type : dtypes) {
all_dtype.insert(type.second);
}
std::string predefined_cuda_functions = "";
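// Note: when float16 is present, the fp16 helper block is expected to already
// provide the float helpers as well, so the fp32 block is only added when
// float is used without float16.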
if (all_dtype.find("float") != all_dtype.end() &&
all_dtype.find("float16") == all_dtype.end()) {
predefined_cuda_functions += predefined_cuda_functions_fp32;
}
if (all_dtype.find("double") != all_dtype.end()) {
predefined_cuda_functions += predefined_cuda_functions_fp64;
}
if (all_dtype.find("float16") != all_dtype.end()) {
predefined_cuda_functions += predefined_cuda_functions_fp16;
}
return predefined_cuda_functions + code_templates_[0].Format(template_var);
}
......@@ -154,10 +178,40 @@ std::set<int> CodeGenerator::DistilOutputIds(
return output_ids;
}
std::unordered_map<int, std::string> CodeGenerator::DistilDtypes(
const std::vector<OperationExpression>& expressions) {
std::unordered_map<int, std::string> dtypes;
for (const auto& expression : expressions) {
for (auto id : expression.GetInputIds()) {
auto dtype = expression.GetRHSType();
if (dtypes.find(id) == dtypes.end()) {
dtypes[id] = dtype;
} else {
PADDLE_ENFORCE_EQ(
dtypes[id], dtype,
platform::errors::PreconditionNotMet(
"In fusion group, Same Node id must have same date type"));
}
}
for (auto id : expression.GetOutputIds()) {
auto dtype = expression.GetLHSType();
if (dtypes.find(id) == dtypes.end()) {
dtypes[id] = dtype;
} else {
PADDLE_ENFORCE_EQ(
dtypes[id], dtype,
platform::errors::PreconditionNotMet(
"In fusion group, Same Node id must have same date type"));
}
}
}
return dtypes;
}
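As a side note, a minimal standalone sketch of the invariant DistilDtypes enforces: each variable id maps to exactly one dtype string, with a cast op being the case where an expression's rhs_type (inputs) differs from its lhs_type (output). The ToyExpr struct, the ids, and the record helper below are illustrative stand-ins, not Paddle APIs.

#include <cassert>
#include <string>
#include <unordered_map>
#include <vector>

struct ToyExpr {            // stand-in for OperationExpression (illustrative)
  std::vector<int> input_ids;
  std::vector<int> output_ids;
  std::string rhs_type;     // dtype of the inputs
  std::string lhs_type;     // dtype of the outputs
};

int main() {
  // tmp2 = tmp0 + tmp1 (float); tmp3 = cast_to_double(tmp2)
  std::vector<ToyExpr> exprs = {{{0, 1}, {2}, "float", "float"},
                                {{2}, {3}, "float", "double"}};
  std::unordered_map<int, std::string> dtypes;
  auto record = [&dtypes](int id, const std::string& dtype) {
    auto it = dtypes.find(id);
    if (it == dtypes.end()) {
      dtypes[id] = dtype;
    } else {
      assert(it->second == dtype && "same node id must have same data type");
    }
  };
  for (const auto& e : exprs) {
    for (int id : e.input_ids) record(id, e.rhs_type);
    for (int id : e.output_ids) record(id, e.lhs_type);
  }
  // Expected result: {0:"float", 1:"float", 2:"float", 3:"double"}
  assert(dtypes[3] == "double");
  return 0;
}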
// Generate the parameter list code from the expression information.
std::string CodeGenerator::EmitParameters(const std::set<int>& input_ids,
const std::set<int>& output_ids,
std::string dtype) {
std::string CodeGenerator::EmitParameters(
const std::set<int>& input_ids, const std::set<int>& output_ids,
std::unordered_map<int, std::string> dtypes) {
std::stringstream ret;
ret << "int N, ";
......@@ -165,13 +219,13 @@ std::string CodeGenerator::EmitParameters(const std::set<int>& input_ids,
// from the input list.
for (auto id : input_ids) {
if (output_ids.find(id) == output_ids.end()) {
ret << dtype << "* " << ArgName(id) << ", ";
ret << dtypes[id] << "* " << ArgName(id) << ", ";
}
}
size_t index = 0;
for (auto id : output_ids) {
ret << dtype << "* " << ArgName(id);
ret << dtypes[id] << "* " << ArgName(id);
if (index != output_ids.size() - 1) {
ret << ", ";
}
......@@ -184,13 +238,12 @@ std::string CodeGenerator::EmitParameters(const std::set<int>& input_ids,
std::string CodeGenerator::EmitComputeBody(
const std::vector<OperationExpression>& expressions,
const std::set<int>& input_ids, const std::set<int>& output_ids,
std::string dtype) {
std::unordered_map<int, std::string> dtypes) {
std::ostringstream compute;
std::unordered_set<int> used;
std::string compute_dtype = (dtype == "float16") ? "float" : dtype;
for (size_t i = 0; i < expressions.size(); i++) {
VLOG(3) << DebugString(expressions[i]);
compute << expressions[i].GetExpression(compute_dtype, &used);
compute << expressions[i].GetExpression(&used);
}
// Load inputs into temporary variables.
......@@ -198,23 +251,13 @@ std::string CodeGenerator::EmitComputeBody(
for (auto id : input_ids) {
if (output_ids.find(id) == output_ids.end() &&
used.find(id) != used.end()) {
if (dtype == "float16") {
load << "float " << TmpName(id) << " = __half2float(" << ArgName(id)
<< "[idx]);";
} else {
load << dtype << " " << TmpName(id) << " = " << ArgName(id) << "[idx];";
}
load << dtypes[id] << " " << TmpName(id) << " = " << VarName(id) << ";";
}
}
// Store temporary variables back to memory.
std::ostringstream store;
for (auto id : output_ids) {
if (dtype == "float16") {
store << ArgName(id) << "[idx] = __float2half(" << TmpName(id) << ");";
} else {
store << ArgName(id) << "[idx] = " << TmpName(id) << ";";
}
store << VarName(id) << " = " << TmpName(id) << ";";
}
return load.str() + compute.str() + store.str();
......
......@@ -30,7 +30,7 @@ class CodeGenerator {
public:
CodeGenerator();
std::string Generate(std::string func_name, std::string dtype,
std::string Generate(std::string func_name,
const std::vector<OperationExpression>& expressions);
std::string Generate(SubGraph* subgraph);
......@@ -42,16 +42,18 @@ class CodeGenerator {
const std::vector<OperationExpression>& expressions);
std::set<int> DistilOutputIds(
const std::vector<OperationExpression>& expressions);
std::unordered_map<int, std::string> DistilDtypes(
const std::vector<OperationExpression>& expressions);
// Generate the parameter list code from the expression information.
std::string EmitParameters(const std::set<int>& input_ids,
const std::set<int>& output_ids,
std::string dtype);
std::unordered_map<int, std::string> dtypes);
std::string EmitComputeBody(
const std::vector<OperationExpression>& expressions,
const std::set<int>& input_ids, const std::set<int>& output_ids,
std::string dtype);
std::unordered_map<int, std::string> dtypes);
// Encode all var nodes in the subgraph with a unique number.
std::unordered_map<std::string, int> EncodeVarNodes(SubGraph* subgraph);
......
......@@ -50,10 +50,26 @@ static std::string ExpandMultivariateTemplate(const std::string rhs,
return sum_rhs;
}
// In order to avoid repeated __half2float calls for the same input, cache the
// converted value in a temporary float variable.
static std::string OptimzeFP16RHS(std::unordered_set<int>* used,
const int index,
const std::vector<int>& input_ids) {
std::stringstream ret;
if (used->find(input_ids[index]) == used->end()) {
ret << "float half2fp32_" + TmpName(input_ids[index]) + " = __half2float(" +
TmpName(input_ids[index]) + ");";
}
return ret.str();
}
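To show the effect of this caching together with the fp16 branches added below, here is a hedged sketch of the device-code shape for an fp16 elementwise_add. It is illustrative only: the committed generator emits a predefined float16 helper type rather than raw __half, and the kernel name and loop form are assumptions.

#include <cuda_fp16.h>

extern "C" __global__ void fused_fp16_add_example(int N, __half* arg0,
                                                  __half* arg1, __half* arg2) {
  for (int idx = blockIdx.x * blockDim.x + threadIdx.x; idx < N;
       idx += gridDim.x * blockDim.x) {
    __half tmp0 = arg0[idx];                    // load fp16 values
    __half tmp1 = arg1[idx];
    // Each fp16 input is widened to float exactly once (the cached
    // half2fp32_* temporaries that OptimzeFP16RHS emits).
    float half2fp32_tmp0 = __half2float(tmp0);
    float half2fp32_tmp1 = __half2float(tmp1);
    // Compute in float, then narrow back to fp16 for the store.
    __half tmp2 = __float2half(half2fp32_tmp0 + half2fp32_tmp1);
    arg2[idx] = tmp2;
  }
}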
std::string OperationExpression::GetRHS(std::unordered_set<int>* used,
std::string* half2fp32_statement,
size_t exprs_index) const {
auto rhs = OperationMap::Instance().Get(op_type_).exprs[exprs_index];
auto num_operands = OperationMap::Instance().Get(op_type_).num_operands;
if (num_operands == -1) {
size_t input_size = input_ids_.size();
rhs = ExpandMultivariateTemplate(rhs, input_size);
......@@ -78,7 +94,16 @@ std::string OperationExpression::GetRHS(std::unordered_set<int>* used,
platform::errors::InvalidArgument(
"Expected %d-th input id > 0 for operation < %s >. Received %d.",
index, op_type_, input_ids_[index]));
rhs.replace(pos, length + 3, TmpName(input_ids_[index]));
// TODO(wangchaochaohu): Here fp16 is converted to float to do the compute; we
// need to add general fp16 compute support later.
std::string var_name;
if (rhs_type_ == "float16") {
half2fp32_statement->append(OptimzeFP16RHS(used, index, input_ids_));
var_name = "half2fp32_" + TmpName(input_ids_[index]);
} else {
var_name = TmpName(input_ids_[index]);
}
rhs.replace(pos, length + 3, var_name);
used->insert(input_ids_[index]);
}
}
......@@ -87,7 +112,7 @@ std::string OperationExpression::GetRHS(std::unordered_set<int>* used,
std::string OperationExpression::GetLHS(size_t i) const {
std::stringstream ret;
ret << TmpName(output_ids_[i]);
ret << lhs_type_ << " " << TmpName(output_ids_[i]);
return ret.str();
}
......@@ -98,15 +123,29 @@ bool OperationExpression::IsSupport() const {
// We traverse the graph to get the group; all input and output ids are unique
// for the nodes that belong to the group.
std::string OperationExpression::GetExpression(
std::string dtype, std::unordered_set<int>* used) const {
std::unordered_set<int>* used) const {
std::string half2fp32_statement;
std::stringstream ret;
if (IsSupport()) {
for (size_t i = 0; i < output_ids_.size(); ++i) {
ret << dtype << " " << GetLHS(i) << " = " << GetRHS(used, i) << ";";
std::string cast_str = "";
if ((lhs_type_ == rhs_type_ && rhs_type_ != "float16") ||
(lhs_type_ != rhs_type_ && rhs_type_ == "float16")) {
ret << GetLHS(i) << " = " << GetRHS(used, &half2fp32_statement, i)
<< ";";
} else {
if ((lhs_type_ == rhs_type_ && rhs_type_ == "float16") ||
lhs_type_ == "float16") {
cast_str = "__float2half";
} else {
cast_str = "static_cast<" + lhs_type_ + ">";
}
ret << GetLHS(i) << " = " << cast_str << "("
<< GetRHS(used, &half2fp32_statement, i) << ");";
}
}
}
return ret.str();
return half2fp32_statement + ret.str();
}
} // namespace fusion_group
......
......@@ -30,29 +30,41 @@ namespace fusion_group {
static inline std::string ArgName(int index) {
return "arg" + std::to_string(index);
}
static inline std::string TmpName(int index) {
return "tmp" + std::to_string(index);
}
static inline std::string VarName(int index) {
return "arg" + std::to_string(index) + "[idx]";
}
class OperationExpression {
public:
explicit OperationExpression(std::string op_type, std::vector<int> input_ids,
std::vector<int> output_ids)
: op_type_(op_type), input_ids_(input_ids), output_ids_(output_ids) {}
std::vector<int> output_ids,
std::string rhs_type, std::string lhs_type)
: op_type_(op_type),
input_ids_(input_ids),
output_ids_(output_ids),
rhs_type_(rhs_type),
lhs_type_(lhs_type) {}
std::string GetOpType() const { return op_type_; }
std::vector<int> GetInputIds() const { return input_ids_; }
std::vector<int> GetOutputIds() const { return output_ids_; }
std::string GetRHSType() const { return rhs_type_; }
std::string GetLHSType() const { return lhs_type_; }
// Check whether this operation type is supported in OperationMap.
bool IsSupport() const;
std::string GetExpression(std::string dtype,
std::unordered_set<int>* used) const;
std::string GetExpression(std::unordered_set<int>* used) const;
private:
// TODO(wangchao): make the offset more flexible by adding stride and base offset.
std::string GetRHS(std::unordered_set<int>* used,
std::string* half2fp32_statement,
size_t exprs_index = 0) const;
std::string GetLHS(size_t i = 0) const;
......@@ -60,6 +72,8 @@ class OperationExpression {
std::string op_type_;
std::vector<int> input_ids_;
std::vector<int> output_ids_;
std::string rhs_type_;
std::string lhs_type_;
};
class TemplateVariable {
......
......@@ -288,7 +288,7 @@ void TestMain(std::string func_name,
std::string dtype) {
fusion_group::OperationMap::Init();
fusion_group::CodeGenerator code_generator;
std::string code_str = code_generator.Generate(func_name, dtype, expressions);
std::string code_str = code_generator.Generate(func_name, expressions);
VLOG(3) << code_str;
LOG(INFO) << "dtype: " << dtype;
......@@ -297,7 +297,7 @@ void TestMain(std::string func_name,
}
void TestMain(fusion_group::SubGraph* subgraph, std::vector<int> input_ids,
std::vector<int> output_ids) {
std::vector<int> output_ids, std::string dtype) {
fusion_group::OperationMap::Init();
fusion_group::CodeGenerator code_generator;
std::string code_str = code_generator.Generate(subgraph);
......@@ -307,26 +307,28 @@ void TestMain(fusion_group::SubGraph* subgraph, std::vector<int> input_ids,
std::vector<fusion_group::OperationExpression> expressions =
code_generator.ConvertToExpressions(subgraph);
LOG(INFO) << "dtype: " << subgraph->GetDataType();
TestElementwiseMain(subgraph->GetFuncName(), code_str, expressions, input_ids,
output_ids, subgraph->GetDataType());
output_ids, dtype);
}
TEST(code_generator, elementwise) {
// t2 = t0 * t1
// t4 = t2 + t3
// t6 = t4 - t5
// t7 = relu(t6)
// t8 = sigmoid(t7)
fusion_group::OperationExpression exp1("elementwise_mul", {0, 1}, {2});
fusion_group::OperationExpression exp2("elementwise_add", {2, 3}, {4});
fusion_group::OperationExpression exp3("elementwise_sub", {4, 5}, {6});
fusion_group::OperationExpression exp4("relu", {6}, {7});
fusion_group::OperationExpression exp5("sigmoid", {7}, {8});
std::vector<fusion_group::OperationExpression> expressions = {
exp1, exp2, exp3, exp4, exp5};
for (std::string dtype : {"float", "float16"}) {
// t2 = t0 * t1
// t4 = t2 + t3
// t6 = t4 - t5
// t7 = relu(t6)
// t8 = sigmoid(t7)
fusion_group::OperationExpression exp1("elementwise_mul", {0, 1}, {2},
dtype, dtype);
fusion_group::OperationExpression exp2("elementwise_add", {2, 3}, {4},
dtype, dtype);
fusion_group::OperationExpression exp3("elementwise_sub", {4, 5}, {6},
dtype, dtype);
fusion_group::OperationExpression exp4("relu", {6}, {7}, dtype, dtype);
fusion_group::OperationExpression exp5("sigmoid", {7}, {8}, dtype, dtype);
std::vector<fusion_group::OperationExpression> expressions = {
exp1, exp2, exp3, exp4, exp5};
// Expressions:
// Op(elementwise_mul), inputs:{0,1}, outputs:{2}
// Op(elementwise_add), inputs:{2,3}, outputs:{4}
......@@ -340,17 +342,18 @@ TEST(code_generator, elementwise) {
}
TEST(code_generator, elementwise_grad) {
// The var order: t0, t1, t2, t3, t0', t1', t2', t3'
// t2 = t0 * t1
// t3 = relu(t2)
// t2' = relu_grad(t2, t3, t3')
// t0', t1' = elementwise_mul_grad(t0, t1, t2, t2')
fusion_group::OperationExpression exp1("relu_grad", {-1, 3, 7}, {6});
fusion_group::OperationExpression exp2("elementwise_mul_grad", {0, 1, 2, 6},
{4, 5});
std::vector<fusion_group::OperationExpression> expressions = {exp1, exp2};
for (std::string dtype : {"float", "float16"}) {
// The var order: t0, t1, t2, t3, t0', t1', t2', t3'
// t2 = t0 * t1
// t3 = relu(t2)
// t2' = relu_grad(t2, t3, t3')
// t0', t1' = elementwise_mul_grad(t0, t1, t2, t2')
fusion_group::OperationExpression exp1("relu_grad", {-1, 3, 7}, {6}, dtype,
dtype);
fusion_group::OperationExpression exp2("elementwise_mul_grad", {0, 1, 2, 6},
{4, 5}, dtype, dtype);
std::vector<fusion_group::OperationExpression> expressions = {exp1, exp2};
// Expressions:
// Op(relu_grad), inputs:{2,3,7}, outputs:{6}
// Op(elementwise_mul_grad), inputs:{0,1,2,6}, outputs:{4,5}
......@@ -474,7 +477,7 @@ TEST(code_generator, subgraph) {
// Op(elementwise_add), inputs:{7,6}, outputs:{8}
std::vector<int> input_ids = {0, 1, 2, 3};
std::vector<int> output_ids = {4, 5, 6, 7, 8};
TestMain(&subgraph, input_ids, output_ids);
TestMain(&subgraph, input_ids, output_ids, dtype);
}
}
......@@ -493,7 +496,7 @@ TEST(code_generator, subgraph_grad) {
// Op(tanh_grad), inputs:{9,4,13}, outputs:{14}
std::vector<int> input_ids = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9};
std::vector<int> output_ids = {10, 11, 12, 13, 14, 15, 16, 17};
TestMain(&subgraph, input_ids, output_ids);
TestMain(&subgraph, input_ids, output_ids, dtype);
}
}
#endif
......@@ -60,6 +60,50 @@ static bool IsEqualAndNotEmpty(const std::vector<int64_t>& l,
return l.size() != 0U && r.size() != 0U && l == r;
}
bool GroupDetector::IsFusionGroupOp(const Node* n) {
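// An op node may join a fusion group only if every input/output var is a
// LoDTensor, all inputs share one data type, all outputs share one data type,
// and both types are fp32, fp64 or fp16.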
if (!(n && n->IsOp() && n->Op())) return false;
bool is_first = true;
proto::VarType::Type i_data_type = proto::VarType::FP32;
proto::VarType::Type o_data_type = proto::VarType::FP32;
for (auto* i_node : n->inputs) {
if (!i_node->Var()) return false;
if (i_node->Var()->GetType() != proto::VarType::LOD_TENSOR) {
return false;
}
if (is_first) {
i_data_type = i_node->Var()->GetDataType();
is_first = false;
} else {
if (i_data_type != i_node->Var()->GetDataType()) return false;
}
}
is_first = true;
for (auto* o_node : n->outputs) {
if (!o_node->Var()) return false;
if (o_node->Var()->GetType() != proto::VarType::LOD_TENSOR) {
return false;
}
if (is_first) {
o_data_type = o_node->Var()->GetDataType();
is_first = false;
} else {
if (o_data_type != o_node->Var()->GetDataType()) return false;
}
}
if (!(i_data_type == proto::VarType::FP32 ||
i_data_type == proto::VarType::FP64 ||
i_data_type == proto::VarType::FP16) ||
!(o_data_type == proto::VarType::FP32 ||
o_data_type == proto::VarType::FP64 ||
o_data_type == proto::VarType::FP16))
return false;
return true;
}
bool ElementwiseGroupDetector::IsElementwiseOp(const Node* n) {
if (IsSpecifiedOp(GetElementwiseOpTypes(), n)) {
std::vector<int64_t> shape_0;
......@@ -85,7 +129,9 @@ bool ElementwiseGroupDetector::IsElementwiseOp(const Node* n) {
std::vector<std::vector<Node*>> ElementwiseGroupDetector::operator()(
Graph* graph) {
auto teller = [&](const Node* n) -> bool { return IsElementwiseOp(n); };
auto teller = [&](const Node* n) -> bool {
return IsFusionGroupOp(n) && IsElementwiseOp(n);
};
return SubgraphDetector(graph, teller)();
}
......
......@@ -23,7 +23,12 @@ namespace framework {
namespace ir {
namespace fusion_group {
class ElementwiseGroupDetector {
class GroupDetector {
protected:
bool IsFusionGroupOp(const Node* n);
};
class ElementwiseGroupDetector : GroupDetector {
public:
std::vector<std::vector<Node*>> operator()(Graph* graph);
......
......@@ -110,18 +110,25 @@ void FusionGroupPass::InsertFusionGroupOp(
op_desc.SetType("fusion_group");
std::vector<std::string> input_names;
std::vector<std::string> inputs_data_types;
for (auto* n : input_vars_of_subgraph) {
input_names.push_back(n->Name());
inputs_data_types.push_back(DataTypeToString(n->Var()->GetDataType()));
external_nodes.insert(n);
}
op_desc.SetInput("Inputs", input_names);
std::vector<std::string> output_names;
std::vector<std::string> outs_data_types;
for (auto* n : output_vars_of_subgraph) {
output_names.push_back(n->Name());
outs_data_types.push_back(DataTypeToString(n->Var()->GetDataType()));
external_nodes.insert(n);
}
op_desc.SetOutput("Outs", output_names);
op_desc.SetAttr("inputs_data_type", inputs_data_types);
op_desc.SetAttr("outs_data_type", outs_data_types);
op_desc.SetAttr("type", subgraph->GetType());
op_desc.SetAttr("func_name", subgraph->GetFuncName());
op_desc.SetAttr(OpProtoAndCheckerMaker::OpRoleAttrName(),
......@@ -131,6 +138,7 @@ void FusionGroupPass::InsertFusionGroupOp(
for (auto* in : input_vars_of_subgraph) {
IR_NODE_LINK_TO(in, fusion_group_node);
}
for (auto* out : output_vars_of_subgraph) {
IR_NODE_LINK_TO(fusion_group_node, out);
}
......
......@@ -102,6 +102,13 @@ void OperationMap::InsertUnaryElementwiseOperations() {
// dx = dout * (1 - out * out)
insert_handler("tanh", "2.0 / (1.0 + real_exp(-2.0 * ${0})) - 1.0",
{"${2} * (1.0 - ${1} * ${1})"});
// cast
// out = static_cast<T>(x)
// dx = static_cast<T>(dout)
// TODO(wangchaochaohu): This is not the complete definition of the cast Op;
// we need to refine it later.
insert_handler("cast", "${0}", {"${0}"});
}
void OperationMap::InsertBinaryElementwiseOperations() {
......@@ -158,10 +165,12 @@ void OperationMap::InsertMultivariateElementwiseOperations() {
std::vector<std::string> grad_exprs) {
int type = 0;
int num_oprands = -1;
// Here num_oprands == -1 means that the number of inputs is variable.
Insert(type, num_oprands, op_type, expr, grad_exprs, {"X"}, {"Out"});
};
// Here [] marks a section that is repeated once per additional input (>= 0 times).
// If the input list size of the sum Op is 3, the template expands to
// ${0} + ${1} + ${2}
insert_handler("sum", "${0}[ + ${?}]", {});
}
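For illustration, a small standalone sketch of how such a bracketed template could be expanded for a variable number of inputs; ExpandMultivariateTemplate in code_generator_helper.cc is the real implementation and may differ in detail, and the function and variable names below are made up.

#include <iostream>
#include <string>

// Expand a bracketed template like "${0}[ + ${?}]" for n inputs.
static std::string ExpandSumTemplate(const std::string& tmpl, int n) {
  std::string::size_type lb = tmpl.find('[');
  std::string::size_type rb = tmpl.find(']');
  std::string prefix = tmpl.substr(0, lb);                  // "${0}"
  std::string repeated = tmpl.substr(lb + 1, rb - lb - 1);  // " + ${?}"
  std::string result = prefix;
  for (int i = 1; i < n; ++i) {
    std::string piece = repeated;
    piece.replace(piece.find("${?}"), 4, "${" + std::to_string(i) + "}");
    result += piece;
  }
  return result;
}

int main() {
  // With 3 inputs, "${0}[ + ${?}]" expands to "${0} + ${1} + ${2}".
  std::cout << ExpandSumTemplate("${0}[ + ${?}]", 3) << std::endl;
  return 0;
}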
......
......@@ -49,7 +49,6 @@ class SubGraph {
}
}
}
ExtractDataType();
}
bool IsValid(int min_subgraph_size) {
......@@ -61,11 +60,10 @@ class SubGraph {
return false;
}
return ExtractDataType();
return true;
}
int GetType() const { return type_; }
std::string GetDataType() const { return data_type_; }
void SetFuncName(std::string func_name) { func_name_ = func_name; }
std::string GetFuncName() const { return func_name_; }
......@@ -162,37 +160,6 @@ class SubGraph {
}
private:
bool ExtractDataType() {
bool is_first = true;
proto::VarType::Type data_type = proto::VarType::FP32;
for (auto* n : nodes_set_) {
if (n && n->IsVar() && n->Var()) {
if (n->Var()->GetType() != proto::VarType::LOD_TENSOR) {
// All var node in a subgraph should hold a LoDTensor.
return false;
}
if (is_first) {
data_type = n->Var()->GetDataType();
is_first = false;
} else if (n->Var()->GetDataType() != data_type) {
// DataType of VarDesc in a subgraph is not the same.
return false;
}
}
}
if (data_type == proto::VarType::FP32) {
data_type_ = "float";
} else if (data_type == proto::VarType::FP64) {
data_type_ = "double";
} else if (data_type == proto::VarType::FP16) {
data_type_ = "float16";
} else {
VLOG(2) << "Only support fp32, fp64 and fp16 in fusion_group.";
return false;
}
return true;
}
void TopologicalSort() {
if (!is_sorted_) {
std::unordered_map<Node*, std::vector<Node*>> inputs_map;
......
......@@ -21,7 +21,7 @@ class FusionGroupOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext *ctx) const override {
void InferShape(framework::InferShapeContext* ctx) const override {
const size_t num_ins = ctx->Inputs("Inputs").size();
const size_t num_outs = ctx->Outputs("Outs").size();
......@@ -58,6 +58,13 @@ class FusionGroupOp : public framework::OperatorWithKernel {
ctx->ShareLoD("Inputs", /*->*/ "Outs", 0, j);
}
}
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(framework::proto::VarType::FP32,
platform::CUDAPlace(0));
};
};
class FusionGroupOpMaker : public framework::OpProtoAndCheckerMaker {
......@@ -69,6 +76,12 @@ class FusionGroupOpMaker : public framework::OpProtoAndCheckerMaker {
AddOutput("Outs",
"(std::vector<LoDTensor>) The outputs of fusion_group op.")
.AsDuplicable();
AddAttr<std::vector<std::string>>(
"outs_data_type", "The data type of Outputs in fusion_group op.")
.SetDefault({});
AddAttr<std::vector<std::string>>(
"inputs_data_type", "The data type of Inputs in fusion_group op.")
.SetDefault({});
AddAttr<int>("type", "Fusion type.").SetDefault(0);
AddAttr<std::string>("func_name", "Name of the generated functions.")
.SetDefault("");
......
......@@ -22,6 +22,20 @@ limitations under the License. */
namespace paddle {
namespace operators {
static void MutableMultiTypeData(
std::vector<paddle::framework::LoDTensor*>* var,
const std::vector<std::string>& data_type, const platform::Place& place) {
for (size_t i = 0; i < (*var).size(); i++) {
if (data_type[i] == "float") {
(*var)[i]->mutable_data<float>(place);
} else if (data_type[i] == "double") {
(*var)[i]->mutable_data<double>(place);
} else if (data_type[i] == "::paddle::platform::float16") {
(*var)[i]->mutable_data<paddle::platform::float16>(place);
}
}
}
template <typename DeviceContext, typename T>
class FusionGroupKernel : public framework::OpKernel<T> {
public:
......@@ -29,14 +43,15 @@ class FusionGroupKernel : public framework::OpKernel<T> {
auto ins = ctx.MultiInput<framework::LoDTensor>("Inputs");
auto outs = ctx.MultiOutput<framework::LoDTensor>("Outs");
int type = ctx.Attr<int>("type");
auto outs_type = ctx.Attr<std::vector<std::string>>("outs_data_type");
auto inputs_type = ctx.Attr<std::vector<std::string>>("inputs_data_type");
size_t num_ins = ins.size();
size_t num_outs = outs.size();
auto place = ctx.GetPlace();
for (size_t i = 0; i < num_outs; ++i) {
outs[i]->mutable_data<T>(place);
}
MutableMultiTypeData(&outs, outs_type, place);
std::string func_name = ctx.Attr<std::string>("func_name");
platform::DeviceCode* dev_code =
......@@ -47,13 +62,25 @@ class FusionGroupKernel : public framework::OpKernel<T> {
size_t n = ins[0]->numel();
std::vector<void*> args;
args.push_back(&n);
std::vector<const T*> ptrs(num_ins + num_outs);
std::vector<const void*> ptrs(num_ins + num_outs);
for (size_t i = 0; i < num_ins; ++i) {
ptrs[i] = ins[i]->data<T>();
if (inputs_type[i] == "::paddle::platform::float16") {
ptrs[i] = ins[i]->data<paddle::platform::float16>();
} else if (inputs_type[i] == "double") {
ptrs[i] = ins[i]->data<double>();
} else if (inputs_type[i] == "float") {
ptrs[i] = ins[i]->data<float>();
}
args.push_back(&ptrs[i]);
}
for (size_t j = 0; j < num_outs; ++j) {
ptrs[num_ins + j] = outs[j]->data<T>();
if (outs_type[j] == "::paddle::platform::float16") {
ptrs[num_ins + j] = outs[j]->data<paddle::platform::float16>();
} else if (outs_type[j] == "double") {
ptrs[num_ins + j] = outs[j]->data<double>();
} else if (outs_type[j] == "float") {
ptrs[num_ins + j] = outs[j]->data<float>();
}
args.push_back(&ptrs[num_ins + j]);
}
dev_code->Launch(n, &args);
......
......@@ -57,7 +57,8 @@ framework::OpDesc* CreateFusionGroupOp(
const std::vector<std::string>& input_names,
const std::vector<std::vector<int64_t>>& input_shapes,
const std::vector<std::string>& output_names, int type,
std::string func_name) {
const std::vector<std::string>& inputs_data_type,
const std::vector<std::string>& outs_data_type, std::string func_name) {
EXPECT_EQ(input_names.size(), input_shapes.size());
for (size_t i = 0; i < input_names.size(); ++i) {
......@@ -76,6 +77,8 @@ framework::OpDesc* CreateFusionGroupOp(
op->SetType("fusion_group");
op->SetInput("Inputs", input_names);
op->SetOutput("Outs", output_names);
op->SetAttr("inputs_data_type", inputs_data_type);
op->SetAttr("outs_data_type", outs_data_type);
op->SetAttr("type", type);
op->SetAttr("func_name", func_name);
op->SetAttr(framework::OpProtoAndCheckerMaker::OpRoleAttrName(),
......@@ -130,6 +133,8 @@ void CheckOutputs(framework::Scope* scope,
void TestMain(const std::vector<std::string>& input_names,
const std::vector<std::vector<int64_t>>& input_shapes,
const std::vector<std::string>& output_names, int type,
const std::vector<std::string>& inputs_data_type,
const std::vector<std::string>& outs_data_type,
std::string func_name, std::string cuda_kernel_str,
CPUKernelFunc cpu_kernel_func) {
// Compile the device code
......@@ -139,8 +144,9 @@ void TestMain(const std::vector<std::string>& input_names,
// Create a ProgramDesc that has a fusion_group_op.
framework::ProgramDesc program;
framework::OpDesc* op_desc = CreateFusionGroupOp(
&program, input_names, input_shapes, output_names, type, func_name);
framework::OpDesc* op_desc =
CreateFusionGroupOp(&program, input_names, input_shapes, output_names,
type, inputs_data_type, outs_data_type, func_name);
auto fusion_group_op = framework::OpRegistry::CreateOp(*op_desc);
framework::Scope scope;
......@@ -210,8 +216,11 @@ void elementwise_cuda_kernel_0(size_t n, float *x, float* y, float* z) {
}
};
TestMain(input_names, input_shapes, output_names, 0,
"elementwise_cuda_kernel_0", kernel, elementwise_cpu_kernel_0);
std::vector<std::string> inputs_data_type(input_names.size(), "float");
std::vector<std::string> outs_data_type(output_names.size(), "float");
TestMain(input_names, input_shapes, output_names, 0, inputs_data_type,
outs_data_type, "elementwise_cuda_kernel_0", kernel,
elementwise_cpu_kernel_0);
}
} // namespace operators
......
......@@ -142,8 +142,8 @@ class PassTest(unittest.TestCase):
self.assertTrue(
np.allclose(
outs_opt[i], outs[i], atol=atol),
"Output < {} > has diff at {}".format(self.fetch_list[i].name,
str(place)))
"Output < {} > has diff at {}, expected {} but got {}".format(
self.fetch_list[i].name, str(place), outs_opt[i], outs[i]))
def _check_fused_ops(self, program):
'''
......
......@@ -125,17 +125,15 @@ class FusionGroupPassTestFP16(FusionGroupPassTest):
fluid.data(
name="data2", shape=[128, 128], dtype=dtype))
# subgraph with only 1 op node
tmp_0 = self.feed_vars[0] * self.feed_vars[1]
tmp_1 = layers.mul(tmp_0, self.feed_vars[2])
tmp_2 = layers.cast(tmp_0, dtype="float16")
tmp_3 = layers.cast(tmp_1, dtype="float16")
# subgraph with 2 op nodes
tmp_2 = layers.cast(tmp_0, dtype="float16")
tmp_4 = layers.relu(tmp_2 + tmp_3)
tmp_5 = layers.cast(tmp_4, dtype=dtype)
self.fetch_list = [tmp_5]
self.num_fused_ops = 1
self.fetch_list = [tmp_0, tmp_1, tmp_2, tmp_3, tmp_4, tmp_5]
self.num_fused_ops = 2
class FusionGroupPassSumTest(FusionGroupPassTest):
......@@ -147,9 +145,28 @@ class FusionGroupPassSumTest(FusionGroupPassTest):
tmp_1 = layers.sum([tmp_0, self.feed_vars[2], self.feed_vars[3]])
tmp_2 = layers.sum([tmp_1, self.feed_vars[4]])
self.fetch_list = [tmp_0, tmp_1]
self.fetch_list = [tmp_0, tmp_1, tmp_2]
self.num_fused_ops = 1
class FusionGroupPassCastTest(FusionGroupPassTest):
def build_program(self, dtype):
with fluid.program_guard(self.main_program, self.startup_program):
self.feed_vars = self._prepare_feed_vars([2, 2], dtype, 2)
tmp_0 = layers.elementwise_add(self.feed_vars[0], self.feed_vars[1])
tmp_1 = layers.cast(tmp_0, dtype="double")
tmp_2 = layers.cast(tmp_1, dtype="float32")
self.fetch_list = [tmp_0, tmp_1, tmp_2]
self.num_fused_ops = 1
def setUp(self):
self.build_program("float64")
self.feeds = self._feed_random_data(self.feed_vars)
self.pass_names = "fusion_group_pass"
self.fused_op_type = "fusion_group"
if __name__ == "__main__":
unittest.main()