Unverified · Commit f0d193a2, authored by wangchaochaohu, committed by GitHub

Cast fusion for fusion group (#22876)

* Add support for expression type conversion and cast Op support in fusion group
Parent 29a7a52d
@@ -24,6 +24,21 @@ namespace framework {
namespace ir {
namespace fusion_group {
std::string ExtractDataType(const std::vector<Node*> nodes) {
std::string dtype_str = "float";
auto data_type = nodes.back()->Var()->GetDataType();
if (data_type == proto::VarType::FP32) {
dtype_str = "float";
} else if (data_type == proto::VarType::FP64) {
dtype_str = "double";
} else if (data_type == proto::VarType::FP16) {
dtype_str = "float16";
}
return dtype_str;
}
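For orientation, a small sketch of how ExtractDataType is used further down in ConvertToExpressions (the variable names mirror the call sites below; nothing here is new behavior):

// Hypothetical call, mirroring the call sites in ConvertToExpressions:
std::string lhs_type = ExtractDataType(node->outputs);  // dtype of the op's output vars
std::string rhs_type = ExtractDataType(node->inputs);   // dtype of the op's input vars
// Note: only the last node's VarDesc is inspected, and any dtype other than
// FP32/FP64/FP16 silently falls back to "float".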
CodeGenerator::CodeGenerator() {
// Only support elementwise operations now.
code_templates_.resize(1);
@@ -34,8 +49,7 @@ CodeGenerator::CodeGenerator() {
std::string CodeGenerator::Generate(SubGraph* subgraph) {
std::vector<OperationExpression> expressions = ConvertToExpressions(subgraph);
- return Generate(subgraph->GetFuncName(), subgraph->GetDataType(),
-                 expressions);
+ return Generate(subgraph->GetFuncName(), expressions);
}

static bool HasInput(Node* n, std::string name) {
@@ -95,8 +109,11 @@ std::vector<OperationExpression> CodeGenerator::ConvertToExpressions(
"Output(%s) of operation %s is not set.", name, op->Type()));
output_ids.push_back(var_ids[op->Output(name)[0]]);
}
- expressions.push_back(
-     OperationExpression(node->Name(), input_ids, output_ids));
+ std::string lhs_type = ExtractDataType(node->outputs);
+ std::string rhs_type = ExtractDataType(node->inputs);
+ expressions.emplace_back(OperationExpression(
+     node->Name(), input_ids, output_ids, rhs_type, lhs_type));
}
}
return expressions;
@@ -105,25 +122,32 @@ std::vector<OperationExpression> CodeGenerator::ConvertToExpressions(
// In order to get the right result of expression, we need to calculate and
// store the expression as suffix Expressions using vector.
std::string CodeGenerator::Generate(
- std::string func_name, std::string dtype,
+ std::string func_name,
  const std::vector<OperationExpression>& expressions) {
// TODO(liuyiqun): Check whether all expressions are elementwise operations.
std::set<int> input_ids = DistilInputIds(expressions);
std::set<int> output_ids = DistilOutputIds(expressions);
+ std::unordered_map<int, std::string> dtypes = DistilDtypes(expressions);
TemplateVariable template_var;
template_var.Add("func_name", func_name);
- template_var.Add("parameters", EmitParameters(input_ids, output_ids, dtype));
+ template_var.Add("parameters", EmitParameters(input_ids, output_ids, dtypes));
template_var.Add("compute_body",
- EmitComputeBody(expressions, input_ids, output_ids, dtype));
+ EmitComputeBody(expressions, input_ids, output_ids, dtypes));
- std::string predefined_cuda_functions;
- if (dtype == "float") {
-   predefined_cuda_functions = predefined_cuda_functions_fp32;
- } else if (dtype == "double") {
-   predefined_cuda_functions = predefined_cuda_functions_fp64;
- } else if (dtype == "float16") {
-   predefined_cuda_functions = predefined_cuda_functions_fp16;
- }
+ std::set<std::string> all_dtype;
+ for (const auto& type : dtypes) {
+   all_dtype.insert(type.second);
+ }
+ std::string predefined_cuda_functions = "";
+ if (all_dtype.find("float") != all_dtype.end() &&
+     all_dtype.find("float16") == all_dtype.end()) {
+   predefined_cuda_functions += predefined_cuda_functions_fp32;
+ }
+ if (all_dtype.find("double") != all_dtype.end()) {
+   predefined_cuda_functions += predefined_cuda_functions_fp64;
+ }
+ if (all_dtype.find("float16") != all_dtype.end()) {
+   predefined_cuda_functions += predefined_cuda_functions_fp16;
+ }
return predefined_cuda_functions + code_templates_[0].Format(template_var);
}
@@ -154,10 +178,40 @@ std::set<int> CodeGenerator::DistilOutputIds(
return output_ids;
}
std::unordered_map<int, std::string> CodeGenerator::DistilDtypes(
const std::vector<OperationExpression>& expressions) {
std::unordered_map<int, std::string> dtypes;
for (const auto& expression : expressions) {
for (auto id : expression.GetInputIds()) {
auto dtype = expression.GetRHSType();
if (dtypes.find(id) == dtypes.end()) {
dtypes[id] = dtype;
} else {
PADDLE_ENFORCE_EQ(
dtypes[id], dtype,
platform::errors::PreconditionNotMet(
"In fusion group, Same Node id must have same date type"));
}
}
for (auto id : expression.GetOutputIds()) {
auto dtype = expression.GetLHSType();
if (dtypes.find(id) == dtypes.end()) {
dtypes[id] = dtype;
} else {
PADDLE_ENFORCE_EQ(
dtypes[id], dtype,
platform::errors::PreconditionNotMet(
"In fusion group, Same Node id must have same date type"));
}
}
}
return dtypes;
}
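To see what DistilDtypes feeds into the two emitters, consider a hypothetical single cast expression (ids and types are made up for the example):

// OperationExpression("cast", {0}, {1}, /*rhs_type=*/"float", /*lhs_type=*/"float16")
// DistilDtypes(expressions)        -> { {0, "float"}, {1, "float16"} }
// EmitParameters({0}, {1}, dtypes) -> "int N, float* arg0, float16* arg1"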
// we get the parameter list code for the expression information
- std::string CodeGenerator::EmitParameters(const std::set<int>& input_ids,
-                                           const std::set<int>& output_ids,
-                                           std::string dtype) {
+ std::string CodeGenerator::EmitParameters(
+     const std::set<int>& input_ids, const std::set<int>& output_ids,
+     std::unordered_map<int, std::string> dtypes) {
std::stringstream ret;
ret << "int N, ";
@@ -165,13 +219,13 @@ std::string CodeGenerator::EmitParameters(const std::set<int>& input_ids,
// from the input list.
for (auto id : input_ids) {
if (output_ids.find(id) == output_ids.end()) {
- ret << dtype << "* " << ArgName(id) << ", ";
+ ret << dtypes[id] << "* " << ArgName(id) << ", ";
}
}

size_t index = 0;
for (auto id : output_ids) {
- ret << dtype << "* " << ArgName(id);
+ ret << dtypes[id] << "* " << ArgName(id);
if (index != output_ids.size() - 1) {
ret << ", ";
}
@@ -184,13 +238,12 @@ std::string CodeGenerator::EmitParameters(const std::set<int>& input_ids,
std::string CodeGenerator::EmitComputeBody(
    const std::vector<OperationExpression>& expressions,
    const std::set<int>& input_ids, const std::set<int>& output_ids,
-   std::string dtype) {
+   std::unordered_map<int, std::string> dtypes) {
std::ostringstream compute;
std::unordered_set<int> used;
- std::string compute_dtype = (dtype == "float16") ? "float" : dtype;
for (size_t i = 0; i < expressions.size(); i++) {
VLOG(3) << DebugString(expressions[i]);
- compute << expressions[i].GetExpression(compute_dtype, &used);
+ compute << expressions[i].GetExpression(&used);
}

// Load input to temporal variables.
@@ -198,23 +251,13 @@ std::string CodeGenerator::EmitComputeBody(
for (auto id : input_ids) {
if (output_ids.find(id) == output_ids.end() &&
    used.find(id) != used.end()) {
- if (dtype == "float16") {
-   load << "float " << TmpName(id) << " = __half2float(" << ArgName(id)
-        << "[idx]);";
- } else {
-   load << dtype << " " << TmpName(id) << " = " << ArgName(id) << "[idx];";
- }
+ load << dtypes[id] << " " << TmpName(id) << " = " << VarName(id) << ";";
}
}

// Store temporal variables to memory.
std::ostringstream store;
for (auto id : output_ids) {
- if (dtype == "float16") {
-   store << ArgName(id) << "[idx] = __float2half(" << TmpName(id) << ");";
- } else {
-   store << ArgName(id) << "[idx] = " << TmpName(id) << ";";
- }
+ store << VarName(id) << " = " << TmpName(id) << ";";
}

return load.str() + compute.str() + store.str();
......
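Putting the pieces above together, here is a rough sketch of the device code the generator now emits for a single float16 elementwise_mul expression. The kernel wrapper, the grid-stride loop, and the float16 typedef come from code_templates_[0] and the predefined fp16 helpers, which are not part of this diff, so everything outside the load/compute/store lines is an assumption:

// Hypothetical output for OperationExpression("elementwise_mul", {0, 1}, {2},
//                                             "float16", "float16").
__global__ void fused_elementwise_fp16(int N, float16* arg0, float16* arg1,
                                       float16* arg2) {
  for (int idx = blockIdx.x * blockDim.x + threadIdx.x; idx < N;
       idx += gridDim.x * blockDim.x) {
    float16 tmp0 = arg0[idx];                    // load: dtypes[0] == "float16"
    float16 tmp1 = arg1[idx];                    // load: dtypes[1] == "float16"
    float half2fp32_tmp0 = __half2float(tmp0);   // OptimzeFP16RHS: widen each input once
    float half2fp32_tmp1 = __half2float(tmp1);
    float16 tmp2 = __float2half(half2fp32_tmp0 * half2fp32_tmp1);  // GetExpression
    arg2[idx] = tmp2;                            // store via VarName(id)
  }
}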
@@ -30,7 +30,7 @@ class CodeGenerator {
 public:
  CodeGenerator();

- std::string Generate(std::string func_name, std::string dtype,
+ std::string Generate(std::string func_name,
                       const std::vector<OperationExpression>& expressions);

  std::string Generate(SubGraph* subgraph);
@@ -42,16 +42,18 @@ class CodeGenerator {
      const std::vector<OperationExpression>& expressions);
  std::set<int> DistilOutputIds(
      const std::vector<OperationExpression>& expressions);
+ std::unordered_map<int, std::string> DistilDtypes(
+     const std::vector<OperationExpression>& expressions);

  // we get the parameter list code for the expression information
  std::string EmitParameters(const std::set<int>& input_ids,
                             const std::set<int>& output_ids,
-                            std::string dtype);
+                            std::unordered_map<int, std::string> dtypes);

  std::string EmitComputeBody(
      const std::vector<OperationExpression>& expressions,
      const std::set<int>& input_ids, const std::set<int>& output_ids,
-     std::string dtype);
+     std::unordered_map<int, std::string> dtypes);

  // Encode all var nodes in the subgraph with an unique number.
  std::unordered_map<std::string, int> EncodeVarNodes(SubGraph* subgraph);
......
@@ -50,10 +50,26 @@ static std::string ExpandMultivariateTemplate(const std::string rhs,
return sum_rhs;
}
// In order to avoid multiple __half2float function calls, we do this
// optimization
static std::string OptimzeFP16RHS(std::unordered_set<int>* used,
const int index,
const std::vector<int>& input_ids) {
std::stringstream ret;
if (used->find(input_ids[index]) == used->end()) {
ret << "float half2fp32_" + TmpName(input_ids[index]) + " = __half2float(" +
TmpName(input_ids[index]) + ");";
}
return ret.str();
}
std::string OperationExpression::GetRHS(std::unordered_set<int>* used,
+                                        std::string* half2fp32_statement,
                                        size_t exprs_index) const {
auto rhs = OperationMap::Instance().Get(op_type_).exprs[exprs_index];
auto num_operands = OperationMap::Instance().Get(op_type_).num_operands;

if (num_operands == -1) {
size_t input_size = input_ids_.size();
rhs = ExpandMultivariateTemplate(rhs, input_size);
@@ -78,7 +94,16 @@ std::string OperationExpression::GetRHS(std::unordered_set<int>* used,
platform::errors::InvalidArgument(
"Expected %d-th input id > 0 for operation < %s >. Received %d.",
index, op_type_, input_ids_[index]));
- rhs.replace(pos, length + 3, TmpName(input_ids_[index]));
+ // TODO(wangchaochaohu): Here fp16 is converted to float to do the compute;
+ // we need to add general fp16 compute support later.
+ std::string var_name;
+ if (rhs_type_ == "float16") {
+   half2fp32_statement->append(OptimzeFP16RHS(used, index, input_ids_));
+   var_name = "half2fp32_" + TmpName(input_ids_[index]);
+ } else {
+   var_name = TmpName(input_ids_[index]);
+ }
+ rhs.replace(pos, length + 3, var_name);
used->insert(input_ids_[index]);
}
}
@@ -87,7 +112,7 @@ std::string OperationExpression::GetRHS(std::unordered_set<int>* used,
std::string OperationExpression::GetLHS(size_t i) const {
std::stringstream ret;
- ret << TmpName(output_ids_[i]);
+ ret << lhs_type_ << " " << TmpName(output_ids_[i]);
return ret.str();
}
@@ -98,15 +123,29 @@ bool OperationExpression::IsSupport() const {
// we Traverse the graph and get the group , all input id and output id is
// unique for the node which belong the group
std::string OperationExpression::GetExpression(
- std::string dtype, std::unordered_set<int>* used) const {
+ std::unordered_set<int>* used) const {
+ std::string half2fp32_statement;
std::stringstream ret;
if (IsSupport()) {
for (size_t i = 0; i < output_ids_.size(); ++i) {
- ret << dtype << " " << GetLHS(i) << " = " << GetRHS(used, i) << ";";
+ std::string cast_str = "";
+ if ((lhs_type_ == rhs_type_ && rhs_type_ != "float16") ||
+     (lhs_type_ != rhs_type_ && rhs_type_ == "float16")) {
+   ret << GetLHS(i) << " = " << GetRHS(used, &half2fp32_statement, i)
+       << ";";
+ } else {
+   if ((lhs_type_ == rhs_type_ && rhs_type_ == "float16") ||
+       lhs_type_ == "float16") {
+     cast_str = "__float2half";
+   } else {
+     cast_str = "static_cast<" + lhs_type_ + ">";
+   }
+   ret << GetLHS(i) << " = " << cast_str << "("
+       << GetRHS(used, &half2fp32_statement, i) << ");";
+ }
}
}
- return ret.str();
+ return half2fp32_statement + ret.str();
}

} // namespace fusion_group
......
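As an illustration of the branch above, using only the cast Op's "${0}" template and the type names that ExtractDataType produces, the four relevant type combinations render roughly as follows (a sketch, not generator output):

// lhs_type_ == rhs_type_ == "float"          -> no cast wrapper:
float tmp1 = tmp0;
// lhs_type_ == "double", rhs_type_ == "float" -> static_cast branch:
double tmp1 = static_cast<double>(tmp0);
// lhs_type_ == "float16", rhs_type_ == "float" -> __float2half branch:
float16 tmp1 = __float2half(tmp0);
// lhs_type_ == rhs_type_ == "float16"        -> inputs widened first, result narrowed:
float half2fp32_tmp0 = __half2float(tmp0);
float16 tmp1 = __float2half(half2fp32_tmp0);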
@@ -30,29 +30,41 @@ namespace fusion_group {
static inline std::string ArgName(int index) {
return "arg" + std::to_string(index);
}

static inline std::string TmpName(int index) {
return "tmp" + std::to_string(index);
}
static inline std::string VarName(int index) {
return "arg" + std::to_string(index) + "[idx]";
}
class OperationExpression {
 public:
  explicit OperationExpression(std::string op_type, std::vector<int> input_ids,
-                              std::vector<int> output_ids)
-     : op_type_(op_type), input_ids_(input_ids), output_ids_(output_ids) {}
+                              std::vector<int> output_ids,
+                              std::string rhs_type, std::string lhs_type)
+     : op_type_(op_type),
+       input_ids_(input_ids),
+       output_ids_(output_ids),
+       rhs_type_(rhs_type),
+       lhs_type_(lhs_type) {}

  std::string GetOpType() const { return op_type_; }
  std::vector<int> GetInputIds() const { return input_ids_; }
  std::vector<int> GetOutputIds() const { return output_ids_; }
+ std::string GetRHSType() const { return rhs_type_; }
+ std::string GetLHSType() const { return lhs_type_; }

  // Check whether this operation type is supported in OperationMap.
  bool IsSupport() const;

- std::string GetExpression(std::string dtype,
-                           std::unordered_set<int>* used) const;
+ std::string GetExpression(std::unordered_set<int>* used) const;

 private:
  // TODO(wangchao): make offset more flexible we add stride and basic offset
  std::string GetRHS(std::unordered_set<int>* used,
+                    std::string* half2fp32_statement,
                     size_t exprs_index = 0) const;
  std::string GetLHS(size_t i = 0) const;
@@ -60,6 +72,8 @@ class OperationExpression {
  std::string op_type_;
  std::vector<int> input_ids_;
  std::vector<int> output_ids_;
+ std::string rhs_type_;
+ std::string lhs_type_;
};

class TemplateVariable {
......
@@ -288,7 +288,7 @@ void TestMain(std::string func_name,
              std::string dtype) {
fusion_group::OperationMap::Init();
fusion_group::CodeGenerator code_generator;
- std::string code_str = code_generator.Generate(func_name, dtype, expressions);
+ std::string code_str = code_generator.Generate(func_name, expressions);
VLOG(3) << code_str;

LOG(INFO) << "dtype: " << dtype;
@@ -297,7 +297,7 @@ void TestMain(std::string func_name,
}

void TestMain(fusion_group::SubGraph* subgraph, std::vector<int> input_ids,
-              std::vector<int> output_ids) {
+              std::vector<int> output_ids, std::string dtype) {
fusion_group::OperationMap::Init();
fusion_group::CodeGenerator code_generator;
std::string code_str = code_generator.Generate(subgraph);
@@ -307,26 +307,28 @@ void TestMain(fusion_group::SubGraph* subgraph, std::vector<int> input_ids,
std::vector<fusion_group::OperationExpression> expressions =
    code_generator.ConvertToExpressions(subgraph);
- LOG(INFO) << "dtype: " << subgraph->GetDataType();
TestElementwiseMain(subgraph->GetFuncName(), code_str, expressions, input_ids,
-                     output_ids, subgraph->GetDataType());
+                     output_ids, dtype);
}
TEST(code_generator, elementwise) {
- // t2 = t0 * t1
- // t4 = t2 + t3
- // t6 = t4 - t5
- // t7 = relu(t6)
- // t8 = sigmoid(t7)
- fusion_group::OperationExpression exp1("elementwise_mul", {0, 1}, {2});
- fusion_group::OperationExpression exp2("elementwise_add", {2, 3}, {4});
- fusion_group::OperationExpression exp3("elementwise_sub", {4, 5}, {6});
- fusion_group::OperationExpression exp4("relu", {6}, {7});
- fusion_group::OperationExpression exp5("sigmoid", {7}, {8});
- std::vector<fusion_group::OperationExpression> expressions = {
-     exp1, exp2, exp3, exp4, exp5};
for (std::string dtype : {"float", "float16"}) {
+ // t2 = t0 * t1
+ // t4 = t2 + t3
+ // t6 = t4 - t5
+ // t7 = relu(t6)
+ // t8 = sigmoid(t7)
+ fusion_group::OperationExpression exp1("elementwise_mul", {0, 1}, {2},
+                                        dtype, dtype);
+ fusion_group::OperationExpression exp2("elementwise_add", {2, 3}, {4},
+                                        dtype, dtype);
+ fusion_group::OperationExpression exp3("elementwise_sub", {4, 5}, {6},
+                                        dtype, dtype);
+ fusion_group::OperationExpression exp4("relu", {6}, {7}, dtype, dtype);
+ fusion_group::OperationExpression exp5("sigmoid", {7}, {8}, dtype, dtype);
+ std::vector<fusion_group::OperationExpression> expressions = {
+     exp1, exp2, exp3, exp4, exp5};
// Expressions:
// Op(elementwise_mul), inputs:{0,1}, outputs:{2}
// Op(elementwise_add), inputs:{2,3}, outputs:{4}
@@ -340,17 +342,18 @@ TEST(code_generator, elementwise) {
}

TEST(code_generator, elementwise_grad) {
- // The var order: t0, t1, t2, t3, t0', t1', t2', t3'
- // t2 = t0 * t1
- // t3 = relu(t2)
- // t2' = relu_grad(t2, t3, t3')
- // t0', t1' = elementwise_mul_grad(t0, t1, t2, t2')
- fusion_group::OperationExpression exp1("relu_grad", {-1, 3, 7}, {6});
- fusion_group::OperationExpression exp2("elementwise_mul_grad", {0, 1, 2, 6},
-                                        {4, 5});
- std::vector<fusion_group::OperationExpression> expressions = {exp1, exp2};
for (std::string dtype : {"float", "float16"}) {
+ // The var order: t0, t1, t2, t3, t0', t1', t2', t3'
+ // t2 = t0 * t1
+ // t3 = relu(t2)
+ // t2' = relu_grad(t2, t3, t3')
+ // t0', t1' = elementwise_mul_grad(t0, t1, t2, t2')
+ fusion_group::OperationExpression exp1("relu_grad", {-1, 3, 7}, {6}, dtype,
+                                        dtype);
+ fusion_group::OperationExpression exp2("elementwise_mul_grad", {0, 1, 2, 6},
+                                        {4, 5}, dtype, dtype);
+ std::vector<fusion_group::OperationExpression> expressions = {exp1, exp2};
// Expressions:
// Op(relu_grad), inputs:{2,3,7}, outputs:{6}
// Op(elementwise_mul_grad), inputs:{0,1,2,6}, outputs:{4,5}
@@ -474,7 +477,7 @@ TEST(code_generator, subgraph) {
// Op(elementwise_add), inputs:{7,6}, outputs:{8}
std::vector<int> input_ids = {0, 1, 2, 3};
std::vector<int> output_ids = {4, 5, 6, 7, 8};
- TestMain(&subgraph, input_ids, output_ids);
+ TestMain(&subgraph, input_ids, output_ids, dtype);
}
}
@@ -493,7 +496,7 @@ TEST(code_generator, subgraph_grad) {
// Op(tanh_grad), inputs:{9,4,13}, outputs:{14}
std::vector<int> input_ids = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9};
std::vector<int> output_ids = {10, 11, 12, 13, 14, 15, 16, 17};
- TestMain(&subgraph, input_ids, output_ids);
+ TestMain(&subgraph, input_ids, output_ids, dtype);
}
}
#endif
@@ -60,6 +60,50 @@ static bool IsEqualAndNotEmpty(const std::vector<int64_t>& l,
return l.size() != 0U && r.size() != 0U && l == r;
}
bool GroupDetector::IsFusionGroupOp(const Node* n) {
if (!(n && n->IsOp() && n->Op())) return false;
bool is_first = true;
proto::VarType::Type i_data_type = proto::VarType::FP32;
proto::VarType::Type o_data_type = proto::VarType::FP32;
for (auto* i_node : n->inputs) {
if (!i_node->Var()) return false;
if (i_node->Var()->GetType() != proto::VarType::LOD_TENSOR) {
return false;
}
if (is_first) {
i_data_type = i_node->Var()->GetDataType();
is_first = false;
} else {
if (i_data_type != i_node->Var()->GetDataType()) return false;
}
}
is_first = true;
for (auto* o_node : n->outputs) {
if (!o_node->Var()) return false;
if (o_node->Var()->GetType() != proto::VarType::LOD_TENSOR) {
return false;
}
if (is_first) {
o_data_type = o_node->Var()->GetDataType();
is_first = false;
} else {
if (o_data_type != o_node->Var()->GetDataType()) return false;
}
}
if (!(i_data_type == proto::VarType::FP32 ||
i_data_type == proto::VarType::FP64 ||
i_data_type == proto::VarType::FP16) ||
!(o_data_type == proto::VarType::FP32 ||
o_data_type == proto::VarType::FP64 ||
o_data_type == proto::VarType::FP16))
return false;
return true;
}
bool ElementwiseGroupDetector::IsElementwiseOp(const Node* n) {
if (IsSpecifiedOp(GetElementwiseOpTypes(), n)) {
std::vector<int64_t> shape_0;
@@ -85,7 +129,9 @@ bool ElementwiseGroupDetector::IsElementwiseOp(const Node* n) {
std::vector<std::vector<Node*>> ElementwiseGroupDetector::operator()(
    Graph* graph) {
- auto teller = [&](const Node* n) -> bool { return IsElementwiseOp(n); };
+ auto teller = [&](const Node* n) -> bool {
+   return IsFusionGroupOp(n) && IsElementwiseOp(n);
+ };
return SubgraphDetector(graph, teller)();
}
......
@@ -23,7 +23,12 @@ namespace framework {
namespace ir {
namespace fusion_group {

- class ElementwiseGroupDetector {
+ class GroupDetector {
+  protected:
+   bool IsFusionGroupOp(const Node* n);
+ };
+
+ class ElementwiseGroupDetector : GroupDetector {
 public:
  std::vector<std::vector<Node*>> operator()(Graph* graph);
......
@@ -110,18 +110,25 @@ void FusionGroupPass::InsertFusionGroupOp(
op_desc.SetType("fusion_group");
std::vector<std::string> input_names;
+ std::vector<std::string> inputs_data_types;
for (auto* n : input_vars_of_subgraph) {
input_names.push_back(n->Name());
+ inputs_data_types.push_back(DataTypeToString(n->Var()->GetDataType()));
external_nodes.insert(n);
}
op_desc.SetInput("Inputs", input_names);

std::vector<std::string> output_names;
+ std::vector<std::string> outs_data_types;
for (auto* n : output_vars_of_subgraph) {
output_names.push_back(n->Name());
+ outs_data_types.push_back(DataTypeToString(n->Var()->GetDataType()));
external_nodes.insert(n);
}
op_desc.SetOutput("Outs", output_names);
+ op_desc.SetAttr("inputs_data_type", inputs_data_types);
+ op_desc.SetAttr("outs_data_type", outs_data_types);
op_desc.SetAttr("type", subgraph->GetType());
op_desc.SetAttr("func_name", subgraph->GetFuncName());
op_desc.SetAttr(OpProtoAndCheckerMaker::OpRoleAttrName(),
@@ -131,6 +138,7 @@ void FusionGroupPass::InsertFusionGroupOp(
for (auto* in : input_vars_of_subgraph) {
IR_NODE_LINK_TO(in, fusion_group_node);
}

for (auto* out : output_vars_of_subgraph) {
IR_NODE_LINK_TO(fusion_group_node, out);
}
......
@@ -102,6 +102,13 @@ void OperationMap::InsertUnaryElementwiseOperations() {
// dx = dout * (1 - out * out)
insert_handler("tanh", "2.0 / (1.0 + real_exp(-2.0 * ${0})) - 1.0",
               {"${2} * (1.0 - ${1} * ${1})"});
// cast
// out = static_cast<T>(d)
// dx = static_cast<T>(d_out)
// TODO(wangchaochaohu): This is not the complete definition of the
// cast Op. We need to refine it later.
insert_handler("cast", "${0}", {"${0}"});
}

void OperationMap::InsertBinaryElementwiseOperations() {
@@ -158,10 +165,12 @@ void OperationMap::InsertMultivariateElementwiseOperations() {
    std::vector<std::string> grad_exprs) {
int type = 0;
int num_oprands = -1;
- // here ... represent the number of input is changed
Insert(type, num_oprands, op_type, expr, grad_exprs, {"X"}, {"Out"});
};
+ // Here [] indicates that the number of inputs can vary (>= 0).
+ // If the input list of the sum Op has 3 elements, the template expands to
+ // ${0} + ${1} + ${2}
insert_handler("sum", "${0}[ + ${?}]", {});
}
......
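For the multivariate sum, the bracketed part of the template is repeated once per extra input (ExpandMultivariateTemplate itself is not shown in this diff, so the expansion below is inferred from the comment above). For a hypothetical sum over ids {0, 1, 2} into {3} with float on both sides, the pipeline would emit roughly:

// Registered template:                  "${0}[ + ${?}]"
// Expanded for an input list of size 3: "${0} + ${1} + ${2}"
// After GetRHS substitutes tmp names and GetExpression prepends the typed LHS:
float tmp3 = tmp0 + tmp1 + tmp2;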
@@ -49,7 +49,6 @@ class SubGraph {
}
}
}
- ExtractDataType();
}

bool IsValid(int min_subgraph_size) {
@@ -61,11 +60,10 @@ class SubGraph {
return false;
}
- return ExtractDataType();
+ return true;
}

int GetType() const { return type_; }
- std::string GetDataType() const { return data_type_; }
void SetFuncName(std::string func_name) { func_name_ = func_name; }
std::string GetFuncName() const { return func_name_; }
@@ -162,37 +160,6 @@ class SubGraph {
}

 private:
bool ExtractDataType() {
bool is_first = true;
proto::VarType::Type data_type = proto::VarType::FP32;
for (auto* n : nodes_set_) {
if (n && n->IsVar() && n->Var()) {
if (n->Var()->GetType() != proto::VarType::LOD_TENSOR) {
// All var node in a subgraph should hold a LoDTensor.
return false;
}
if (is_first) {
data_type = n->Var()->GetDataType();
is_first = false;
} else if (n->Var()->GetDataType() != data_type) {
// DataType of VarDesc in a subgraph is not the same.
return false;
}
}
}
if (data_type == proto::VarType::FP32) {
data_type_ = "float";
} else if (data_type == proto::VarType::FP64) {
data_type_ = "double";
} else if (data_type == proto::VarType::FP16) {
data_type_ = "float16";
} else {
VLOG(2) << "Only support fp32, fp64 and fp16 in fusion_group.";
return false;
}
return true;
}
void TopologicalSort() {
if (!is_sorted_) {
std::unordered_map<Node*, std::vector<Node*>> inputs_map;
......
@@ -21,7 +21,7 @@ class FusionGroupOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

- void InferShape(framework::InferShapeContext *ctx) const override {
+ void InferShape(framework::InferShapeContext* ctx) const override {
const size_t num_ins = ctx->Inputs("Inputs").size();
const size_t num_outs = ctx->Outputs("Outs").size();
@@ -58,6 +58,13 @@ class FusionGroupOp : public framework::OperatorWithKernel {
ctx->ShareLoD("Inputs", /*->*/ "Outs", 0, j);
}
}
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(framework::proto::VarType::FP32,
platform::CUDAPlace(0));
};
};

class FusionGroupOpMaker : public framework::OpProtoAndCheckerMaker {
@@ -69,6 +76,12 @@ class FusionGroupOpMaker : public framework::OpProtoAndCheckerMaker {
AddOutput("Outs",
          "(std::vector<LoDTensor>) The outputs of fusion_group op.")
    .AsDuplicable();
AddAttr<std::vector<std::string>>(
"outs_data_type", "The data type of Outputs in fusion_group op.")
.SetDefault({});
AddAttr<std::vector<std::string>>(
"inputs_data_type", "The data type of Inputs in fusion_group op.")
.SetDefault({});
AddAttr<int>("type", "Fusion type.").SetDefault(0); AddAttr<int>("type", "Fusion type.").SetDefault(0);
AddAttr<std::string>("func_name", "Name of the generated functions.") AddAttr<std::string>("func_name", "Name of the generated functions.")
.SetDefault(""); .SetDefault("");
......
@@ -22,6 +22,20 @@ limitations under the License. */
namespace paddle {
namespace operators {
static void MutableMultiTypeData(
std::vector<paddle::framework::LoDTensor*>* var,
const std::vector<std::string>& data_type, const platform::Place& place) {
for (size_t i = 0; i < (*var).size(); i++) {
if (data_type[i] == "float") {
(*var)[i]->mutable_data<float>(place);
} else if (data_type[i] == "double") {
(*var)[i]->mutable_data<double>(place);
} else if (data_type[i] == "::paddle::platform::float16") {
(*var)[i]->mutable_data<paddle::platform::float16>(place);
}
}
}
template <typename DeviceContext, typename T>
class FusionGroupKernel : public framework::OpKernel<T> {
 public:
@@ -29,14 +43,15 @@ class FusionGroupKernel : public framework::OpKernel<T> {
auto ins = ctx.MultiInput<framework::LoDTensor>("Inputs");
auto outs = ctx.MultiOutput<framework::LoDTensor>("Outs");
int type = ctx.Attr<int>("type");
+ auto outs_type = ctx.Attr<std::vector<std::string>>("outs_data_type");
+ auto inputs_type = ctx.Attr<std::vector<std::string>>("inputs_data_type");

size_t num_ins = ins.size();
size_t num_outs = outs.size();

auto place = ctx.GetPlace();
- for (size_t i = 0; i < num_outs; ++i) {
-   outs[i]->mutable_data<T>(place);
- }
+ MutableMultiTypeData(&outs, outs_type, place);

std::string func_name = ctx.Attr<std::string>("func_name");
platform::DeviceCode* dev_code =
@@ -47,13 +62,25 @@ class FusionGroupKernel : public framework::OpKernel<T> {
size_t n = ins[0]->numel();
std::vector<void*> args;
args.push_back(&n);
- std::vector<const T*> ptrs(num_ins + num_outs);
+ std::vector<const void*> ptrs(num_ins + num_outs);
for (size_t i = 0; i < num_ins; ++i) {
- ptrs[i] = ins[i]->data<T>();
+ if (inputs_type[i] == "::paddle::platform::float16") {
+   ptrs[i] = ins[i]->data<paddle::platform::float16>();
+ } else if (inputs_type[i] == "double") {
+   ptrs[i] = ins[i]->data<double>();
+ } else if (inputs_type[i] == "float") {
+   ptrs[i] = ins[i]->data<float>();
+ }
args.push_back(&ptrs[i]);
}
for (size_t j = 0; j < num_outs; ++j) {
- ptrs[num_ins + j] = outs[j]->data<T>();
+ if (outs_type[j] == "::paddle::platform::float16") {
+   ptrs[num_ins + j] = outs[j]->data<paddle::platform::float16>();
+ } else if (outs_type[j] == "double") {
+   ptrs[num_ins + j] = outs[j]->data<double>();
+ } else if (outs_type[j] == "float") {
+   ptrs[num_ins + j] = outs[j]->data<float>();
+ }
args.push_back(&ptrs[num_ins + j]);
}
dev_code->Launch(n, &args);
......
@@ -57,7 +57,8 @@ framework::OpDesc* CreateFusionGroupOp(
    const std::vector<std::string>& input_names,
    const std::vector<std::vector<int64_t>>& input_shapes,
    const std::vector<std::string>& output_names, int type,
-   std::string func_name) {
+   const std::vector<std::string>& inputs_data_type,
+   const std::vector<std::string>& outs_data_type, std::string func_name) {
EXPECT_EQ(input_names.size(), input_shapes.size());

for (size_t i = 0; i < input_names.size(); ++i) {
@@ -76,6 +77,8 @@ framework::OpDesc* CreateFusionGroupOp(
op->SetType("fusion_group");
op->SetInput("Inputs", input_names);
op->SetOutput("Outs", output_names);
+ op->SetAttr("inputs_data_type", inputs_data_type);
+ op->SetAttr("outs_data_type", outs_data_type);
op->SetAttr("type", type);
op->SetAttr("func_name", func_name);
op->SetAttr(framework::OpProtoAndCheckerMaker::OpRoleAttrName(),
@@ -130,6 +133,8 @@ void CheckOutputs(framework::Scope* scope,
void TestMain(const std::vector<std::string>& input_names,
              const std::vector<std::vector<int64_t>>& input_shapes,
              const std::vector<std::string>& output_names, int type,
+             const std::vector<std::string>& inputs_data_type,
+             const std::vector<std::string>& outs_data_type,
              std::string func_name, std::string cuda_kernel_str,
              CPUKernelFunc cpu_kernel_func) {
// Compile the device code
@@ -139,8 +144,9 @@ void TestMain(const std::vector<std::string>& input_names,
// Create a ProgramDesc that has a fusion_group_op.
framework::ProgramDesc program;
- framework::OpDesc* op_desc = CreateFusionGroupOp(
-     &program, input_names, input_shapes, output_names, type, func_name);
+ framework::OpDesc* op_desc =
+     CreateFusionGroupOp(&program, input_names, input_shapes, output_names,
+                         type, inputs_data_type, outs_data_type, func_name);
auto fusion_group_op = framework::OpRegistry::CreateOp(*op_desc);

framework::Scope scope;
@@ -210,8 +216,11 @@ void elementwise_cuda_kernel_0(size_t n, float *x, float* y, float* z) {
}
};

- TestMain(input_names, input_shapes, output_names, 0,
-          "elementwise_cuda_kernel_0", kernel, elementwise_cpu_kernel_0);
+ std::vector<std::string> inputs_data_type(input_names.size(), "float");
+ std::vector<std::string> outs_data_type(output_names.size(), "float");
+ TestMain(input_names, input_shapes, output_names, 0, inputs_data_type,
+          outs_data_type, "elementwise_cuda_kernel_0", kernel,
+          elementwise_cpu_kernel_0);
}

} // namespace operators
......
@@ -142,8 +142,8 @@ class PassTest(unittest.TestCase):
self.assertTrue(
    np.allclose(
        outs_opt[i], outs[i], atol=atol),
-   "Output < {} > has diff at {}".format(self.fetch_list[i].name,
-                                         str(place)))
+   "Output < {} > has diff at {}, expected {} but got {}".format(
+       self.fetch_list[i].name, str(place), outs_opt[i], outs[i]))

def _check_fused_ops(self, program):
'''
......
@@ -125,17 +125,15 @@ class FusionGroupPassTestFP16(FusionGroupPassTest):
fluid.data(
    name="data2", shape=[128, 128], dtype=dtype))

- # subgraph with only 1 op node
tmp_0 = self.feed_vars[0] * self.feed_vars[1]
tmp_1 = layers.mul(tmp_0, self.feed_vars[2])
- tmp_2 = layers.cast(tmp_0, dtype="float16")
tmp_3 = layers.cast(tmp_1, dtype="float16")
- # subgraph with 2 op nodes
+ tmp_2 = layers.cast(tmp_0, dtype="float16")
tmp_4 = layers.relu(tmp_2 + tmp_3)
tmp_5 = layers.cast(tmp_4, dtype=dtype)

- self.fetch_list = [tmp_5]
- self.num_fused_ops = 1
+ self.fetch_list = [tmp_0, tmp_1, tmp_2, tmp_3, tmp_4, tmp_5]
+ self.num_fused_ops = 2

class FusionGroupPassSumTest(FusionGroupPassTest):
@@ -147,9 +145,28 @@ class FusionGroupPassSumTest(FusionGroupPassTest):
tmp_1 = layers.sum([tmp_0, self.feed_vars[2], self.feed_vars[3]])
tmp_2 = layers.sum([tmp_1, self.feed_vars[4]])

- self.fetch_list = [tmp_0, tmp_1]
+ self.fetch_list = [tmp_0, tmp_1, tmp_2]
self.num_fused_ops = 1
class FusionGroupPassCastTest(FusionGroupPassTest):
def build_program(self, dtype):
with fluid.program_guard(self.main_program, self.startup_program):
self.feed_vars = self._prepare_feed_vars([2, 2], dtype, 2)
tmp_0 = layers.elementwise_add(self.feed_vars[0], self.feed_vars[1])
tmp_1 = layers.cast(tmp_0, dtype="double")
tmp_2 = layers.cast(tmp_1, dtype="float32")
self.fetch_list = [tmp_0, tmp_1, tmp_2]
self.num_fused_ops = 1
def setUp(self):
self.build_program("float64")
self.feeds = self._feed_random_data(self.feed_vars)
self.pass_names = "fusion_group_pass"
self.fused_op_type = "fusion_group"
if __name__ == "__main__":
    unittest.main()