未验证 提交 1be6bf45 编写于 作者: Y Yiqun Liu 提交者: GitHub

Add assign to fusion_group and enhance inplace execution in fusion_group. (#26121)

上级 b2034c28
......@@ -68,11 +68,35 @@ static bool HasInput(Node* n, std::string name) {
return input_names_set.find(name) != input_names_set.end();
}
static Node* GetInputVar(Node* n, const std::string& name) {
PADDLE_ENFORCE_EQ(n && n->IsOp() && n->Op(), true,
platform::errors::InvalidArgument(
"Expected node %p to be an operator node.", n));
for (auto* in : n->inputs) {
if (in->Name() == name) {
return in;
}
}
return nullptr;
}
static Node* GetOutputVar(Node* n, const std::string& name) {
PADDLE_ENFORCE_EQ(n && n->IsOp() && n->Op(), true,
platform::errors::InvalidArgument(
"Expected node %p to be an operator node.", n));
for (auto* out : n->outputs) {
if (out->Name() == name) {
return out;
}
}
return nullptr;
}
std::vector<OperationExpression> CodeGenerator::ConvertToExpressions(
SubGraph* subgraph) {
std::unordered_map<std::string, int> var_ids = EncodeVarNodes(subgraph);
std::vector<Node*> intermediate_out_nodes =
subgraph->GetIntermediateOutVarNodes();
std::unordered_map<Node*, int> var_ids = EncodeVarNodes(subgraph);
std::unordered_set<Node*> intermediate_out_vars_set =
subgraph->GetIntermediateOutVarNodesSet();
std::vector<OperationExpression> expressions;
for (auto* node : subgraph->SortedNodes()) {
if (node && node->IsOp() && node->Op()) {
......@@ -92,11 +116,12 @@ std::vector<OperationExpression> CodeGenerator::ConvertToExpressions(
// "elementwise_add_grad", where "X", "Y" and "Out" are not used.
if ((HasInput(node, name) && op->Input(name).size() >= 1U)) {
for (size_t i = 0; i < op->Input(name).size(); i++) {
Node* input_var = GetInputVar(node, op->Input(name)[i]);
PADDLE_ENFORCE_NE(
var_ids.find(op->Input(name)[i]), var_ids.end(),
var_ids.find(input_var), var_ids.end(),
platform::errors::InvalidArgument(
"Input(%s) of operation %s is not set.", name, op->Type()));
input_ids.push_back(var_ids[op->Input(name)[i]]);
input_ids.push_back(var_ids[input_var]);
}
} else {
input_ids.push_back(-1);
......@@ -106,31 +131,29 @@ std::vector<OperationExpression> CodeGenerator::ConvertToExpressions(
// Output ids should be set in fixed order, like:
// - dx, dy in backward operations
std::vector<int> output_ids;
std::vector<int> intermediate_output_ids;
std::vector<std::string> output_names =
OperationMap::Instance().Get(op->Type()).output_names;
std::unordered_map<int, bool> intermediate_state;
for (auto& name : output_names) {
Node* output_var = GetOutputVar(node, op->Output(name)[0]);
PADDLE_ENFORCE_NE(
var_ids.find(op->Output(name)[0]), var_ids.end(),
var_ids.find(output_var), var_ids.end(),
platform::errors::InvalidArgument(
"Output(%s) of operation %s is not set.", name, op->Type()));
output_ids.push_back(var_ids[op->Output(name)[0]]);
bool enable_intermediate = false;
for (auto* n : intermediate_out_nodes) {
if (n->Name() == op->Output(name)[0]) {
enable_intermediate = true;
break;
}
output_ids.push_back(var_ids[output_var]);
if (!subgraph->SaveIntermediateOut() &&
intermediate_out_vars_set.find(output_var) !=
intermediate_out_vars_set.end()) {
intermediate_output_ids.push_back(var_ids[output_var]);
}
intermediate_state[var_ids[op->Output(name)[0]]] = enable_intermediate;
}
std::string lhs_type = ExtractDataType(node->outputs);
std::string rhs_type = ExtractDataType(node->inputs);
auto expression =
OperationExpression(node->Name(), input_ids, output_ids, rhs_type,
lhs_type, intermediate_state);
lhs_type, intermediate_output_ids);
expression.SetAttr(attr);
expressions.push_back(expression);
}
......@@ -146,17 +169,18 @@ std::string CodeGenerator::Generate(
// TODO(liuyiqun): Check whether all expressions are elementwise operations.
std::set<int> input_ids = std::move(DistilInputIds(expressions));
std::set<int> output_ids = std::move(DistilOutputIds(expressions));
std::set<int> intermediate_ids =
std::set<int> intermediate_output_ids =
std::move(DistilIntermediateIds(expressions));
std::unordered_map<int, std::string> dtypes =
std::move(DistilDtypes(expressions));
TemplateVariable template_var;
template_var.Add("func_name", func_name);
template_var.Add("parameters", EmitParameters(input_ids, output_ids,
intermediate_ids, dtypes));
template_var.Add(
"parameters",
EmitParameters(input_ids, output_ids, intermediate_output_ids, dtypes));
template_var.Add("compute_body",
EmitComputeBody(expressions, input_ids, output_ids,
intermediate_ids, dtypes));
intermediate_output_ids, dtypes));
std::set<std::string> all_dtype;
for (const auto& type : dtypes) {
......@@ -204,18 +228,14 @@ std::set<int> CodeGenerator::DistilOutputIds(
std::set<int> CodeGenerator::DistilIntermediateIds(
const std::vector<OperationExpression>& expressions) {
std::set<int> intermediate_ids;
std::set<int> intermediate_output_ids;
// Use std::set to remove the reptead id and get a ordered list.
for (size_t i = 0; i < expressions.size(); i++) {
for (auto id : expressions[i].GetOutputIds()) {
auto intermediate_state = expressions[i].GetIntermediateState();
if (intermediate_state.find(id) != intermediate_state.end() &&
intermediate_state[id]) {
intermediate_ids.insert(id);
}
for (auto id : expressions[i].GetIntermediateOutputIds()) {
intermediate_output_ids.insert(id);
}
}
return intermediate_ids;
return intermediate_output_ids;
}
std::unordered_map<int, std::string> CodeGenerator::DistilDtypes(
......@@ -316,26 +336,29 @@ std::string CodeGenerator::EmitComputeBody(
return load.str() + compute.str() + store.str();
}
std::unordered_map<std::string, int> CodeGenerator::EncodeVarNodes(
std::unordered_map<Node*, int> CodeGenerator::EncodeVarNodes(
SubGraph* subgraph) {
const auto& input_var_nodes = subgraph->GetInputVarNodes();
const auto& output_var_nodes = subgraph->GetOutputVarNodes();
// Encode all var nodes, including intermediate output var nodes.
const auto& output_var_nodes = subgraph->GetOutputVarNodes(true);
int id = 0;
std::unordered_map<std::string, int> var_ids;
std::unordered_map<Node*, int> var_ids;
// Numbering input vars.
for (auto* in : input_var_nodes) {
VLOG(3) << "Encoding input names:" << in->Name() << ", id:" << id;
if (var_ids.find(in->Name()) == var_ids.end()) {
var_ids[in->Name()] = id++;
VLOG(3) << "Encoding input names:" << in->Name() << "(" << in
<< "), id:" << id;
if (var_ids.find(in) == var_ids.end()) {
var_ids[in] = id++;
}
}
// Encoding output vars.
for (auto* out : output_var_nodes) {
VLOG(3) << "Ecoding output names:" << out->Name() << ", id:" << id;
if (var_ids.find(out->Name()) == var_ids.end()) {
var_ids[out->Name()] = id++;
VLOG(3) << "Ecoding output names:" << out->Name() << "(" << out
<< "), id:" << id;
if (var_ids.find(out) == var_ids.end()) {
var_ids[out] = id++;
}
}
return var_ids;
......
......@@ -61,7 +61,7 @@ class CodeGenerator {
const std::unordered_map<int, std::string>& dtypes) const;
// Encode all var nodes in the subgraph with an unique number.
std::unordered_map<std::string, int> EncodeVarNodes(SubGraph* subgraph);
std::unordered_map<Node*, int> EncodeVarNodes(SubGraph* subgraph);
private:
std::vector<CodeTemplate> code_templates_;
......
......@@ -48,20 +48,20 @@ class OperationExpression {
std::string op_type, const std::vector<int>& input_ids,
const std::vector<int>& output_ids, std::string rhs_type,
std::string lhs_type,
const std::unordered_map<int, bool>& intermediate_state = {})
const std::vector<int>& intermediate_output_ids = {})
: op_type_(op_type),
input_ids_(input_ids),
output_ids_(output_ids),
rhs_type_(rhs_type),
lhs_type_(lhs_type),
intermediate_state_(intermediate_state) {}
intermediate_output_ids_(intermediate_output_ids) {}
std::string GetOpType() const { return op_type_; }
std::unordered_map<int, bool> GetIntermediateState() const {
return intermediate_state_;
}
std::vector<int> GetInputIds() const { return input_ids_; }
std::vector<int> GetOutputIds() const { return output_ids_; }
std::vector<int> GetIntermediateOutputIds() const {
return intermediate_output_ids_;
}
std::string GetRHSType() const { return rhs_type_; }
std::string GetLHSType() const { return lhs_type_; }
void SetAttr(AttributeMap attr) { attr_ = attr; }
......@@ -84,7 +84,7 @@ class OperationExpression {
AttributeMap attr_;
std::string rhs_type_;
std::string lhs_type_;
std::unordered_map<int, bool> intermediate_state_;
std::vector<int> intermediate_output_ids_;
};
class TemplateVariable {
......
......@@ -144,7 +144,6 @@ void CheckOutput(const std::vector<OperationExpression>& expressions,
LOG(INFO) << "Precision check failed from i = " << id
<< ", expect: " << expect << ", actual: " << actual;
EXPECT_LT(fabs(actual - expect), eps);
break;
}
}
}
......@@ -465,7 +464,7 @@ TEST(code_generator, subgraph) {
for (std::string dtype : {"float", "__half"}) {
std::unique_ptr<paddle::framework::ir::Graph> graph =
BuildGraph(false, dtype);
fusion_group::SubGraph subgraph(0, "elementwise_kernel_1", false,
fusion_group::SubGraph subgraph(0, "elementwise_kernel_1", true,
graph->Nodes());
// Expressions generated by code_generator (they may be different):
......@@ -484,7 +483,7 @@ TEST(code_generator, subgraph_grad) {
for (std::string dtype : {"float", "__half"}) {
std::unique_ptr<paddle::framework::ir::Graph> graph =
BuildGraph(true, dtype);
fusion_group::SubGraph subgraph(0, "elementwise_grad_kernel_1", false,
fusion_group::SubGraph subgraph(0, "elementwise_grad_kernel_1", true,
DistilGradNodes(graph));
// Expressions generated by code_generator (they may be different):
......
......@@ -63,7 +63,7 @@ static bool IsEqualAndNotEmpty(const std::vector<int64_t>& l,
bool GroupDetector::CheckPrecondition(const Node* n) {
auto check_data_type = [&](const std::vector<Node*>& nodes) -> bool {
bool is_first = true;
proto::VarType::Type data_type_0;
proto::VarType::Type data_type_0 = proto::VarType::BOOL;
for (auto* n : nodes) {
if (n && n->IsVar() && n->Var()) {
if (n->Var()->GetType() != proto::VarType::LOD_TENSOR) {
......
......@@ -63,11 +63,6 @@ int FusionGroupPass::DetectFusionGroup(Graph* graph, int type) const {
std::unordered_set<Node*>(vec.begin(), vec.end()));
VLOG(3) << "subgraph: {\n" << DebugString(subgraph.SortedNodes()) << "}\n";
// In elementwise fused kernel, memory is the bound of execution,
// here we remove the output id to use less memory and less time.
if (subgraph.RemoveIntermediateOut()) {
subgraph.DetectIntermediateOutWithGraph(graph);
}
if (subgraph.IsValid(min_subgraph_size)) {
subgraph.SetFuncName("fused_elementwise_" + std::to_string(index++));
if (GenerateCode(&subgraph)) {
......@@ -115,57 +110,52 @@ static int ExtractOpRole(fusion_group::SubGraph* subgraph) {
void FusionGroupPass::InsertFusionGroupOp(
Graph* graph, fusion_group::SubGraph* subgraph) const {
const std::vector<Node*>& input_vars_of_subgraph =
subgraph->GetInputVarNodes();
const std::vector<Node*>& output_vars_of_subgraph =
subgraph->GetOutputVarNodes();
const std::vector<Node*> intermediate_vars_of_subgraph =
subgraph->GetIntermediateOutVarNodes();
const std::vector<Node*>& input_vars = subgraph->GetInputVarNodes();
const std::vector<Node*>& output_vars =
subgraph->GetOutputVarNodes(subgraph->SaveIntermediateOut());
std::unordered_set<Node*> external_nodes;
OpDesc op_desc;
op_desc.SetType("fusion_group");
// Prepare inputs.
std::vector<std::string> input_names;
std::vector<std::string> inputs_data_types;
for (auto* n : input_vars_of_subgraph) {
std::vector<int> input_dtypes;
std::unordered_set<Node*> output_vars_set(output_vars.begin(),
output_vars.end());
for (auto* n : input_vars) {
// It is not an output var node.
if (output_vars_set.find(n) == output_vars_set.end()) {
input_names.push_back(n->Name());
inputs_data_types.push_back(DataTypeToString(n->Var()->GetDataType()));
input_dtypes.push_back(n->Var()->GetDataType());
external_nodes.insert(n);
}
op_desc.SetInput("Inputs", input_names);
}
// Prepare outputs.
std::vector<std::string> output_names;
std::vector<std::string> outs_data_types;
std::vector<Node*> output_var_without_intermediate;
for (auto* n : output_vars_of_subgraph) {
auto it_input =
find(input_vars_of_subgraph.begin(), input_vars_of_subgraph.end(), n);
auto it_intermediate = find(intermediate_vars_of_subgraph.begin(),
intermediate_vars_of_subgraph.end(), n);
if (it_intermediate == intermediate_vars_of_subgraph.end() &&
it_input == input_vars_of_subgraph.end()) {
std::vector<int> output_dtypes;
for (auto* n : output_vars) {
output_names.push_back(n->Name());
outs_data_types.push_back(DataTypeToString(n->Var()->GetDataType()));
output_var_without_intermediate.push_back(n);
}
output_dtypes.push_back(n->Var()->GetDataType());
external_nodes.insert(n);
}
OpDesc op_desc;
op_desc.SetType("fusion_group");
op_desc.SetInput("Inputs", input_names);
op_desc.SetOutput("Outs", output_names);
op_desc.SetAttr("inputs_data_type", inputs_data_types);
op_desc.SetAttr("outs_data_type", outs_data_types);
op_desc.SetAttr("inputs_dtype", input_dtypes);
op_desc.SetAttr("outs_dtype", output_dtypes);
op_desc.SetAttr("type", subgraph->GetType());
op_desc.SetAttr("func_name", subgraph->GetFuncName());
op_desc.SetAttr(OpProtoAndCheckerMaker::OpRoleAttrName(),
ExtractOpRole(subgraph));
Node* fusion_group_node = graph->CreateOpNode(&op_desc);
for (auto* in : input_vars_of_subgraph) {
for (auto* in : input_vars) {
if (output_vars_set.find(in) == output_vars_set.end()) {
IR_NODE_LINK_TO(in, fusion_group_node);
}
for (auto* out : output_var_without_intermediate) {
}
for (auto* out : output_vars) {
IR_NODE_LINK_TO(fusion_group_node, out);
}
......
......@@ -105,12 +105,6 @@ void OperationMap::InsertUnaryElementwiseOperations() {
insert_handler("tanh", "%{2.0} / (%{1.0} + Exp(-%{2.0} * ${0})) - %{1.0}",
{"${2} * (%{1.0} - ${1} * ${1})"});
// cast:
// out = static_cast<T>(x)
// TODO(wangchaochaohu): This is not the compelete definition of
// cast Op, We need refine it later.
insert_handler("cast", "${0}", {});
// sqrt:
// out = x^(1/2)
// dx = dout * 0.5 / out
......@@ -121,6 +115,16 @@ void OperationMap::InsertUnaryElementwiseOperations() {
// dx = dout * 2.0 * x
insert_handler("square", "${0} * ${0}", {"${2} * %{2.0} * ${0}"});
// assign:
// out = x
insert_handler("assign", "${0}", {});
// cast:
// out = static_cast<T>(x)
// TODO(wangchaochaohu): This is not the compelete definition of
// cast Op, We need refine it later.
insert_handler("cast", "${0}", {});
// scale
// out = (bias_after_scale) ? scale * X + bias : scale(X + bias)
// here we use '=' operator to seperate th default value
......
......@@ -66,11 +66,12 @@ class SubGraph {
}
int GetType() const { return type_; }
bool RemoveIntermediateOut() { return !save_intermediate_out_; }
void SetFuncName(std::string func_name) { func_name_ = func_name; }
std::string GetFuncName() const { return func_name_; }
bool SaveIntermediateOut() const { return save_intermediate_out_; }
const std::unordered_set<Node*>& Nodes() const { return nodes_set_; }
const std::vector<Node*>& SortedNodes() {
if (!is_sorted_) {
......@@ -118,66 +119,88 @@ class SubGraph {
return input_vars;
}
std::vector<Node*> GetOutputVarNodes() {
std::vector<Node*> GetOutputVarNodes(bool with_intermediate_out) {
// The order of output nodes should be consistant anywhere..
std::vector<Node*> output_vars_all;
std::vector<Node*> output_vars;
for (auto* n : SortedNodes()) {
if (n && n->IsVar() && n->Var()) {
if (IsOutputOfInternalOp(n)) {
// If the var_node is the output of some op_node in the subgraph, it
// is considered the output var node of the subgraph.
bool is_found = false;
for (auto* in : n->inputs) {
if (Has(in)) {
is_found = true;
}
if (with_intermediate_out) {
output_vars.push_back(n);
} else {
if (n->outputs.empty() || IsInputOfExternalOp(n)) {
output_vars.push_back(n);
}
if (is_found) {
output_vars_all.push_back(n);
}
}
}
return output_vars_all;
return output_vars;
}
std::vector<Node*> GetIntermediateOutVarNodes() {
return intermediate_out_nodes_;
// Intermediate output var nodes: the output of some op_node in the
// subgraph, but not referenced outside the subgraph.
std::vector<Node*> intermediate_out_vars;
for (auto* n : SortedNodes()) {
if (IsOutputOfInternalOp(n) && IsInputOfInternalOp(n) &&
!IsInputOfExternalOp(n)) {
// When the outputs size is 0, it is also considered a intermidiate
// output. It maybe an unused output or the fetching vars, so that we
// cannot eleiminate it directly here.
intermediate_out_vars.push_back(n);
}
}
return intermediate_out_vars;
}
void DetectIntermediateOutWithGraph(Graph* graph) {
auto graph_nodes = graph->Nodes();
for (auto* n : SortedNodes()) {
bool enable_remove = true;
std::unordered_set<Node*> GetIntermediateOutVarNodesSet() {
std::vector<Node*> intermediate_out_vars = GetIntermediateOutVarNodes();
return std::unordered_set<Node*>(intermediate_out_vars.begin(),
intermediate_out_vars.end());
}
if (n && n->IsVar() && n->Var()) {
bool leaf_graph = true;
for (auto* node : graph_nodes) {
if (node->IsOp()) {
auto inputs = node->inputs;
for (auto* in : inputs) {
if (in && in->Name() == n->Name()) {
if (!Has(node)) enable_remove = false;
leaf_graph = false;
private:
bool IsInputOfInternalOp(Node* n) {
bool is_input_of_internal_op = false;
if (Has(n) && n && n->IsVar() && n->Var()) {
for (auto* out : n->outputs) {
if (Has(out)) {
is_input_of_internal_op = true;
break;
}
}
}
return is_input_of_internal_op;
}
if (!enable_remove) {
bool IsInputOfExternalOp(Node* n) {
// If n is the input any one node outside the subgraph.
bool is_input_of_external_op = false;
if (Has(n) && n && n->IsVar() && n->Var()) {
for (auto* out : n->outputs) {
if (!Has(out)) {
is_input_of_external_op = true;
break;
}
}
if (leaf_graph) enable_remove = false;
} else {
enable_remove = false;
}
return is_input_of_external_op;
}
if (enable_remove) {
intermediate_out_nodes_.push_back(n);
bool IsOutputOfInternalOp(Node* n) {
bool is_output_of_internal_op = false;
if (Has(n) && n && n->IsVar() && n->Var()) {
for (auto* in : n->inputs) {
if (Has(in)) {
is_output_of_internal_op = true;
break;
}
}
}
return is_output_of_internal_op;
}
private:
void TopologicalSort() {
if (!is_sorted_) {
std::unordered_map<Node*, std::vector<Node*>> inputs_map;
......@@ -236,7 +259,6 @@ class SubGraph {
bool save_intermediate_out_{true};
std::unordered_set<Node*> nodes_set_;
std::vector<Node*> intermediate_out_nodes_{};
bool is_sorted_{false};
std::vector<Node*> sorted_nodes_;
};
......
......@@ -22,8 +22,14 @@ class FusionGroupOp : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
const size_t num_ins = ctx->Inputs("Inputs").size();
const size_t num_outs = ctx->Outputs("Outs").size();
OP_INOUT_CHECK(ctx->HasInputs("Inputs"), "Input", "Inputs", "FusionGroup");
OP_INOUT_CHECK(ctx->HasOutputs("Outs"), "Output", "Outs", "FusionGroup");
auto input_names = ctx->Inputs("Inputs");
auto output_names = ctx->Outputs("Outs");
const size_t num_ins = input_names.size();
const size_t num_outs = output_names.size();
PADDLE_ENFORCE_GE(
num_ins, 1UL,
......@@ -42,9 +48,12 @@ class FusionGroupOp : public framework::OperatorWithKernel {
std::vector<framework::DDim> x_dims = ctx->GetInputsDim("Inputs");
if (type == 0) {
for (size_t i = 1; i < num_ins; ++i) {
PADDLE_ENFORCE_EQ(x_dims[0], x_dims[i],
PADDLE_ENFORCE_EQ(
x_dims[0], x_dims[i],
platform::errors::InvalidArgument(
"All the inputs' dims should be the same."));
"All the inputs' dims is expected to be the same. "
"But recieved [%s] (name: %s) vs [%s] (name: %s).",
x_dims[0], input_names[0], x_dims[i], input_names[i]));
}
std::vector<framework::DDim> out_dims;
for (size_t j = 0; j < num_outs; ++j) {
......@@ -76,11 +85,11 @@ class FusionGroupOpMaker : public framework::OpProtoAndCheckerMaker {
AddOutput("Outs",
"(std::vector<LoDTensor>) The outputs of fusion_group op.")
.AsDuplicable();
AddAttr<std::vector<std::string>>(
"outs_data_type", "The data type of Outputs in fusion_group op.")
AddAttr<std::vector<int>>("outs_dtype",
"The data type of Outputs in fusion_group op.")
.SetDefault({});
AddAttr<std::vector<std::string>>(
"inputs_data_type", "The data type of Inputs in fusion_group op.")
AddAttr<std::vector<int>>("inputs_dtype",
"The data type of Inputs in fusion_group op.")
.SetDefault({});
AddAttr<int>("type", "Fusion type.").SetDefault(0);
AddAttr<std::string>("func_name", "Name of the generated functions.")
......
......@@ -24,14 +24,14 @@ namespace operators {
static void MutableMultiTypeData(
std::vector<paddle::framework::LoDTensor*>* var,
const std::vector<std::string>& data_type, const platform::Place& place) {
const std::vector<int>& data_type, const platform::Place& place) {
for (size_t i = 0; i < var->size(); i++) {
if (data_type[i] == "float") {
if (data_type[i] == framework::proto::VarType::FP32) {
(*var)[i]->mutable_data<float>(place);
} else if (data_type[i] == "double") {
(*var)[i]->mutable_data<double>(place);
} else if (data_type[i] == "::paddle::platform::float16") {
} else if (data_type[i] == framework::proto::VarType::FP16) {
(*var)[i]->mutable_data<paddle::platform::float16>(place);
} else if (data_type[i] == framework::proto::VarType::FP64) {
(*var)[i]->mutable_data<double>(place);
}
}
}
......@@ -43,15 +43,15 @@ class FusionGroupKernel : public framework::OpKernel<T> {
auto ins = ctx.MultiInput<framework::LoDTensor>("Inputs");
auto outs = ctx.MultiOutput<framework::LoDTensor>("Outs");
int type = ctx.Attr<int>("type");
auto outs_type = ctx.Attr<std::vector<std::string>>("outs_data_type");
auto inputs_type = ctx.Attr<std::vector<std::string>>("inputs_data_type");
const auto& outs_dtype = ctx.Attr<std::vector<int>>("outs_dtype");
const auto& inputs_dtype = ctx.Attr<std::vector<int>>("inputs_dtype");
size_t num_ins = ins.size();
size_t num_outs = outs.size();
auto place = ctx.GetPlace();
MutableMultiTypeData(&outs, outs_type, place);
MutableMultiTypeData(&outs, outs_dtype, place);
std::string func_name = ctx.Attr<std::string>("func_name");
platform::DeviceCode* dev_code =
......@@ -64,22 +64,22 @@ class FusionGroupKernel : public framework::OpKernel<T> {
args.push_back(&n);
std::vector<const void*> ptrs(num_ins + num_outs);
for (size_t i = 0; i < num_ins; ++i) {
if (inputs_type[i] == "::paddle::platform::float16") {
if (inputs_dtype[i] == framework::proto::VarType::FP16) {
ptrs[i] = ins[i]->data<paddle::platform::float16>();
} else if (inputs_type[i] == "double") {
ptrs[i] = ins[i]->data<double>();
} else if (inputs_type[i] == "float") {
} else if (inputs_dtype[i] == framework::proto::VarType::FP32) {
ptrs[i] = ins[i]->data<float>();
} else if (inputs_dtype[i] == framework::proto::VarType::FP64) {
ptrs[i] = ins[i]->data<double>();
}
args.push_back(&ptrs[i]);
}
for (size_t j = 0; j < num_outs; ++j) {
if (outs_type[j] == "::paddle::platform::float16") {
if (outs_dtype[j] == framework::proto::VarType::FP16) {
ptrs[num_ins + j] = outs[j]->data<paddle::platform::float16>();
} else if (outs_type[j] == "double") {
ptrs[num_ins + j] = outs[j]->data<double>();
} else if (outs_type[j] == "float") {
} else if (outs_dtype[j] == framework::proto::VarType::FP32) {
ptrs[num_ins + j] = outs[j]->data<float>();
} else if (outs_dtype[j] == framework::proto::VarType::FP64) {
ptrs[num_ins + j] = outs[j]->data<double>();
}
args.push_back(&ptrs[num_ins + j]);
}
......
......@@ -57,10 +57,14 @@ framework::OpDesc* CreateFusionGroupOp(
const std::vector<std::string>& input_names,
const std::vector<std::vector<int64_t>>& input_shapes,
const std::vector<std::string>& output_names, int type,
const std::vector<std::string>& inputs_data_type,
const std::vector<std::string>& outs_data_type, std::string func_name) {
std::string func_name) {
EXPECT_EQ(input_names.size(), input_shapes.size());
std::vector<int> input_dtypes(input_names.size(),
framework::proto::VarType::FP32);
std::vector<int> output_dtypes(output_names.size(),
framework::proto::VarType::FP32);
for (size_t i = 0; i < input_names.size(); ++i) {
auto* var = program->MutableBlock(0)->Var(input_names[i]);
var->SetType(framework::proto::VarType::LOD_TENSOR);
......@@ -77,8 +81,8 @@ framework::OpDesc* CreateFusionGroupOp(
op->SetType("fusion_group");
op->SetInput("Inputs", input_names);
op->SetOutput("Outs", output_names);
op->SetAttr("inputs_data_type", inputs_data_type);
op->SetAttr("outs_data_type", outs_data_type);
op->SetAttr("inputs_dtype", input_dtypes);
op->SetAttr("outs_dtype", output_dtypes);
op->SetAttr("type", type);
op->SetAttr("func_name", func_name);
op->SetAttr(framework::OpProtoAndCheckerMaker::OpRoleAttrName(),
......@@ -133,8 +137,6 @@ void CheckOutputs(framework::Scope* scope,
void TestMain(const std::vector<std::string>& input_names,
const std::vector<std::vector<int64_t>>& input_shapes,
const std::vector<std::string>& output_names, int type,
const std::vector<std::string>& inputs_data_type,
const std::vector<std::string>& outs_data_type,
std::string func_name, std::string cuda_kernel_str,
CPUKernelFunc cpu_kernel_func) {
// Compile the device code
......@@ -144,9 +146,8 @@ void TestMain(const std::vector<std::string>& input_names,
// Create a ProgramDesc that has a fusion_group_op.
framework::ProgramDesc program;
framework::OpDesc* op_desc =
CreateFusionGroupOp(&program, input_names, input_shapes, output_names,
type, inputs_data_type, outs_data_type, func_name);
framework::OpDesc* op_desc = CreateFusionGroupOp(
&program, input_names, input_shapes, output_names, type, func_name);
auto fusion_group_op = framework::OpRegistry::CreateOp(*op_desc);
framework::Scope scope;
......@@ -216,11 +217,8 @@ void elementwise_cuda_kernel_0(size_t n, float *x, float* y, float* z) {
}
};
std::vector<std::string> inputs_data_type(input_names.size(), "float");
std::vector<std::string> outs_data_type(output_names.size(), "float");
TestMain(input_names, input_shapes, output_names, 0, inputs_data_type,
outs_data_type, "elementwise_cuda_kernel_0", kernel,
elementwise_cpu_kernel_0);
TestMain(input_names, input_shapes, output_names, 0,
"elementwise_cuda_kernel_0", kernel, elementwise_cpu_kernel_0);
}
} // namespace operators
......
......@@ -77,12 +77,13 @@ class FusionGroupPassTest(PassTest):
self.check_output_with_place(fluid.CUDAPlace(0))
class FusionGroupPassTest1(FusionGroupPassTest):
class FusionGroupPassComplicatedTest(FusionGroupPassTest):
def build_program(self, dtype):
with fluid.program_guard(self.main_program, self.startup_program):
self.feed_vars = self._prepare_feed_vars([32, 128], dtype, 5)
self.feed_vars = self._prepare_feed_vars([32, 64], dtype, 5)
tmp_0 = layers.assign(self.feed_vars[0])
one = layers.fill_constant(shape=[1], dtype=dtype, value=1.0)
tmp_0 = one * self.feed_vars[0]
# subgraph with 9 op nodes
tmp_1 = tmp_0 * layers.sigmoid(self.feed_vars[1]) + layers.sigmoid(
self.feed_vars[2]) * layers.tanh(self.feed_vars[3])
......@@ -94,7 +95,7 @@ class FusionGroupPassTest1(FusionGroupPassTest):
self.fetch_list = [tmp_2, self.grad(tmp_0)]
class FusionGroupPassTest2(FusionGroupPassTest):
class FusionGroupPassInplaceTest(FusionGroupPassTest):
def build_program(self, dtype):
with fluid.program_guard(self.main_program, self.startup_program):
self.feed_vars = self._prepare_feed_vars([32, 128], dtype, 3)
......@@ -103,15 +104,13 @@ class FusionGroupPassTest2(FusionGroupPassTest):
name="data3", shape=[128, 32], dtype=dtype))
# subgraph with 3 op node
tmp_0 = self.feed_vars[0] + self.feed_vars[1]
tmp_1 = layers.relu(self.feed_vars[2] * tmp_0)
# subgraph with 2 op nodes
tmp_2 = layers.relu(layers.sigmoid(self.feed_vars[3]))
tmp_3 = layers.mul(tmp_1, tmp_2)
tmp_0 = self.feed_vars[0] - self.feed_vars[1]
tmp_1 = tmp_0 * self.feed_vars[2]
tmp_2 = layers.assign(tmp_1, output=tmp_0)
tmp_3 = layers.mul(tmp_2, self.feed_vars[3])
self.append_gradients(tmp_3)
self.num_fused_ops = 2
self.fetch_list = [tmp_3, self.grad(tmp_1)]
self.num_fused_ops = 1
self.fetch_list = [tmp_3]
class FusionGroupPassTestFP64(FusionGroupPassTest):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册