未验证 提交 c93331c5 编写于 作者: Z Zhen Wang 提交者: GitHub

Fix several bugs for enabling Paddle to train with CINN. (#36739)

* Update the content of `test_parallel_executor_run_cinn.py`.

* Fix some bugs in the topological sort and `CreateNewSubGraph`.

* Update the CINN commit id used by Paddle.

* Update the unit test to `add+relu`.

* Update according to reviewers' suggestion.
上级 c038cc7a
......@@ -27,7 +27,7 @@ add_definitions(-w)
include(ExternalProject)
set(CINN_SOURCE_DIR ${THIRD_PARTY_PATH}/CINN)
# TODO(zhhsplendid): Modify git tag after we have release tag
set(CINN_GIT_TAG e422c01b7875301996a2baf67a14ba61b0e6192a)
set(CINN_GIT_TAG cb030430d76f42f7310d09608f9b22959ecbcb51)
set(CINN_OPTIONAL_ARGS -DPY_VERSION=${PY_VERSION} -DWITH_CUDA=${WITH_GPU} -DWITH_CUDNN=${WITH_GPU} -DPUBLISH_LIBS=ON -DWITH_TESTING=ON)
set(CINN_BUILD_COMMAND $(MAKE) cinnapi -j)
ExternalProject_Add(
......
......@@ -52,6 +52,15 @@ class ParallelExecutorPassBuilder : public ir::PassBuilder {
ResolveOptionConfliction();
AppendPrintGraphPass("graph_viz_pass", "_original_graph");
#ifdef PADDLE_WITH_CINN
if (FLAGS_use_cinn) {
// Note: This pass is used to enable cinn.
AppendPass("build_cinn_pass");
AppendPrintGraphPass("graph_viz_pass", "_build_cinn_graph");
}
#endif
AppendPassWithCheck(strategy_.enable_sequential_execution_,
"sequential_execution_pass");
AppendPassWithCheck(strategy_.sync_batch_norm_, "sync_batch_norm_pass");
......@@ -74,13 +83,6 @@ class ParallelExecutorPassBuilder : public ir::PassBuilder {
// Note: This pass is used to check whether the multi_device_graph is right.
AppendPass("multi_devices_check_pass");
#ifdef PADDLE_WITH_CINN
if (FLAGS_use_cinn) {
// Note: This pass is used to enable cinn.
AppendPass("build_cinn_pass");
}
#endif
SetCollectiveContext();
}
......
cc_library(cinn_cache_key SRCS cinn_cache_key.cc DEPS boost graph graph_helper lod_tensor proto_desc)
cc_library(build_cinn_pass SRCS build_cinn_pass.cc DEPS pass subgraph_detector cinn_compiler)
cc_library(build_cinn_pass SRCS build_cinn_pass.cc DEPS pass subgraph_detector graph_pattern_detector cinn_compiler errors enforce)
cc_library(transform_desc SRCS transform_desc.cc DEPS proto_desc cinn)
cc_library(cinn_graph_symbolization SRCS cinn_graph_symbolization.cc DEPS lod_tensor graph graph_helper transform_desc cinn)
cc_library(cinn_graph_symbolization SRCS cinn_graph_symbolization.cc DEPS lod_tensor graph transform_desc cinn)
cc_library(cinn_compiler SRCS cinn_compiler.cc DEPS graph lod_tensor cinn_cache_key cinn_graph_symbolization cinn)
cc_test(cinn_cache_key_test SRCS cinn_cache_key_test.cc DEPS cinn_cache_key)
......
......@@ -26,9 +26,13 @@ limitations under the License. */
#include "cinn/frontend/op_mapper_registry.h"
#include "cinn/frontend/op_mappers/use_op_mappers.h"
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
#include "paddle/fluid/framework/ir/node.h"
#include "paddle/fluid/framework/ir/subgraph_detector.h"
#include "paddle/fluid/framework/op_proto_maker.h"
#include "paddle/fluid/framework/paddle2cinn/cinn_compiler.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/errors.h"
namespace paddle {
namespace framework {
......@@ -40,11 +44,28 @@ using framework::ir::Node;
using GraphNodeVec = std::vector<Node*>;
using GraphNodeSet = std::unordered_set<Node*>;
namespace {
int ExtractOpRole(const GraphNodeSet& cluster) {
std::unordered_set<int> op_roles;
std::string attr_name = OpProtoAndCheckerMaker::OpRoleAttrName();
for (auto* n : cluster) {
if (n->Op() && n->Op()->HasAttr(attr_name)) {
op_roles.insert(BOOST_GET_CONST(int, n->Op()->GetAttr(attr_name)));
}
}
if (op_roles.size() == 1U) {
return *(op_roles.begin());
} else {
return static_cast<int>(OpRole::kNotSpecified);
}
}
// Deal with subgraph's feed input var node:
// create a new input var node and it's feed op node
void AddFeedOpAndVar(const std::unordered_set<Node*>& feed_vars,
const GraphNodeSet& cluster,
const std::unordered_map<Node*, Node*>& old_op2new_op,
const std::unordered_map<Node*, Node*>& old_var2new_var,
Graph* graph) {
for (auto* old_var : feed_vars) {
// create feed op
......@@ -53,21 +74,19 @@ void AddFeedOpAndVar(const std::unordered_set<Node*>& feed_vars,
desc.SetOutput("Out", {old_var->Name()});
auto op = graph->CreateOpNode(&desc);
// create new feed var node (SSAGraph)
auto var = graph->CreateVarNode(old_var->Var());
// get new feed var node
auto* var = old_var2new_var.at(old_var);
// link feed op and feed var
op->outputs = {var};
var->inputs = {op};
IR_NODE_LINK_TO(op, var);
// link feed var to cluster op
for (auto* old_op : old_var->outputs) {
if (cluster.count(old_op)) {
var->outputs.emplace_back(old_op2new_op.at(old_op));
old_op2new_op.at(old_op)->inputs.emplace_back(var);
IR_NODE_LINK_TO(var, old_op2new_op.at(old_op));
}
// Do not need relink old op or old var here, they will be
// fixed in RemoveLinkFromCluster, here we just deal with
// fixed in RemoveSubGraphFromGraph, here we just deal with
// new subgraph's node.
}
}
......@@ -79,14 +98,14 @@ void AddFeedOpAndVar(const std::unordered_set<Node*>& feed_vars,
void AddParamVar(const std::unordered_set<Node*>& param_vars,
const GraphNodeSet& cluster,
const std::unordered_map<Node*, Node*>& old_op2new_op,
const std::unordered_map<Node*, Node*>& old_var2new_var,
Graph* graph) {
for (auto* old_var : param_vars) {
auto var = graph->CreateVarNode(old_var->Var());
auto* var = old_var2new_var.at(old_var);
for (auto* old_op : old_var->outputs) {
if (cluster.count(old_op)) {
var->outputs.emplace_back(old_op2new_op.at(old_op));
old_op2new_op.at(old_op)->inputs.emplace_back(var);
IR_NODE_LINK_TO(var, old_op2new_op.at(old_op));
}
}
}
......@@ -97,14 +116,14 @@ void AddParamVar(const std::unordered_set<Node*>& param_vars,
void AddOutputVar(const std::unordered_set<Node*>& output_vars,
const GraphNodeSet& cluster,
const std::unordered_map<Node*, Node*>& old_op2new_op,
const std::unordered_map<Node*, Node*>& old_var2new_var,
Graph* graph) {
for (auto* old_var : output_vars) {
auto var = graph->CreateVarNode(old_var->Var());
auto* var = old_var2new_var.at(old_var);
for (auto* old_op : old_var->inputs) {
if (cluster.count(old_op)) {
var->inputs.emplace_back(old_op2new_op.at(old_op));
old_op2new_op.at(old_op)->outputs.emplace_back(var);
IR_NODE_LINK_TO(old_op2new_op.at(old_op), var);
}
}
}
......@@ -128,14 +147,25 @@ std::unique_ptr<Graph> CreateNewSubGraph(const GraphNodeSet& cluster,
std::unordered_map<Node*, Node*> old_var2new_var;
for (auto* var : cluster_internals) {
Node* sub_node;
if (var->Var() == nullptr) {
sub_node = subgraph->CreateEmptyNode(var->Name(), var->NodeType());
} else {
sub_node = subgraph->CreateVarNode(var->Var());
PADDLE_ENFORCE_NOT_NULL(var->Var(),
platform::errors::PreconditionNotMet(
"The var desc of the node in cluster_internals "
"shouldn't be null."));
auto* sub_node = subgraph->CreateVarNode(var->Var());
old_var2new_var[var] = sub_node;
}
for (auto* var : cluster_inputs) {
if (var->Var()) {
auto* sub_node = subgraph->CreateVarNode(var->Var());
old_var2new_var[var] = sub_node;
}
}
for (auto* var : cluster_outputs) {
if (var->Var()) {
auto* sub_node = subgraph->CreateVarNode(var->Var());
old_var2new_var[var] = sub_node;
}
}
std::unordered_set<Node*> need_feed_vars;
std::unordered_set<Node *> param_vars, output_vars;
......@@ -144,8 +174,10 @@ std::unique_ptr<Graph> CreateNewSubGraph(const GraphNodeSet& cluster,
// out-graph.
for (auto* op : cluster) {
for (auto* var : op->inputs) {
if (cluster_internals.count(var)) {
old_op2new_op[op]->inputs.emplace_back(old_var2new_var[var]);
// one output var maybe an input of the cluster
if (cluster_internals.count(var) ||
(cluster_outputs.count(var) && old_var2new_var.count(var))) {
IR_NODE_LINK_TO(old_var2new_var.at(var), old_op2new_op.at(op));
} else if (cluster_inputs.count(var) && var->Var() != nullptr) {
if (var->Var()->IsParameter()) {
// Parameters have been preserved in scope, compared to feed var,
......@@ -162,7 +194,7 @@ std::unique_ptr<Graph> CreateNewSubGraph(const GraphNodeSet& cluster,
}
for (auto* var : op->outputs) {
if (cluster_internals.count(var)) {
old_op2new_op[op]->outputs.emplace_back(old_var2new_var[var]);
IR_NODE_LINK_TO(old_op2new_op.at(op), old_var2new_var.at(var));
} else if (cluster_outputs.count(var) && var->Var() != nullptr) {
// Create new output var node to guarantee the independency of
// subgraph. In other words, the subgraph has no connection with
......@@ -172,22 +204,12 @@ std::unique_ptr<Graph> CreateNewSubGraph(const GraphNodeSet& cluster,
}
}
AddFeedOpAndVar(need_feed_vars, cluster, old_op2new_op, subgraph.get());
AddParamVar(param_vars, cluster, old_op2new_op, subgraph.get());
AddOutputVar(output_vars, cluster, old_op2new_op, subgraph.get());
for (auto* var : cluster_internals) {
for (auto* op : var->inputs) {
if (cluster.count(op)) {
old_var2new_var[var]->inputs.emplace_back(old_op2new_op[op]);
}
}
for (auto* op : var->outputs) {
if (cluster.count(op)) {
old_var2new_var[var]->outputs.emplace_back(old_op2new_op[op]);
}
}
}
AddFeedOpAndVar(need_feed_vars, cluster, old_op2new_op, old_var2new_var,
subgraph.get());
AddParamVar(param_vars, cluster, old_op2new_op, old_var2new_var,
subgraph.get());
AddOutputVar(output_vars, cluster, old_op2new_op, old_var2new_var,
subgraph.get());
return subgraph;
}
......@@ -238,12 +260,26 @@ void AnalyseClusterVariables(const GraphNodeSet& cluster,
}
}
Node* AddSpecialOpToGraph(const GraphNodeSet& cluster_inputs,
void AddLinkToCinnOp(const GraphNodeSet& cluster_inputs,
const GraphNodeSet& cluster_outputs, Node* cinn_op_node) {
// add new link from cluster_inputs to cinn_op_node
for (auto* var_node : cluster_inputs) {
IR_NODE_LINK_TO(var_node, cinn_op_node);
}
// add new link from cinn_op_node to cluster_outputs
for (auto* var_node : cluster_outputs) {
IR_NODE_LINK_TO(cinn_op_node, var_node);
}
}
void AddCinnOpToGraph(const GraphNodeSet& cluster,
const GraphNodeSet& cluster_inputs,
const GraphNodeSet& cluster_outputs,
const std::string& compilation_key, Graph* graph) {
// add special cinn op
framework::OpDesc special_op_desc;
special_op_desc.SetType(kCinnLaunchOp);
// Add the cinn launch op
framework::OpDesc cinn_op_desc;
cinn_op_desc.SetType(kCinnLaunchOp);
std::vector<std::string> input_names;
std::for_each(cluster_inputs.begin(), cluster_inputs.end(),
[&input_names](Node* n) {
......@@ -251,7 +287,7 @@ Node* AddSpecialOpToGraph(const GraphNodeSet& cluster_inputs,
input_names.emplace_back(n->Name());
}
});
special_op_desc.SetInput("X", input_names);
cinn_op_desc.SetInput("X", input_names);
std::vector<std::string> output_names;
std::for_each(cluster_outputs.begin(), cluster_outputs.end(),
[&output_names](Node* n) {
......@@ -259,96 +295,42 @@ Node* AddSpecialOpToGraph(const GraphNodeSet& cluster_inputs,
output_names.emplace_back(n->Name());
}
});
special_op_desc.SetOutput("Out", output_names);
special_op_desc.SetAttr(kCompilationKey, compilation_key);
special_op_desc.Flush();
auto* special_op_node = graph->CreateOpNode(&special_op_desc);
special_op_node->inputs.assign(cluster_inputs.begin(), cluster_inputs.end());
special_op_node->outputs.assign(cluster_outputs.begin(),
cluster_outputs.end());
return special_op_node;
}
void AddLinkToSpecialOp(const GraphNodeSet& cluster_inputs,
const GraphNodeSet& cluster_outputs,
Node* special_op_node) {
// add new link from cluster_inputs to special_op_node
for (auto* var_node : cluster_inputs) {
var_node->outputs.push_back(special_op_node);
}
// add new link from special_op_node to cluster_outputs
for (auto* var_node : cluster_outputs) {
var_node->inputs.push_back(special_op_node);
}
}
void RemoveLinkFromCluster(const GraphNodeSet& cluster,
const GraphNodeSet& cluster_inputs,
const GraphNodeSet& cluster_outputs) {
// remove all nodes in cluster
auto get_preserved_ops = [&cluster](const GraphNodeVec& ops) {
GraphNodeVec nodes;
for (auto* op_node : ops) {
if (cluster.find(op_node) == cluster.end()) {
nodes.emplace_back(op_node);
}
}
return nodes;
};
// removing useless link from cluster_inputs to cluster
for (auto* var_node : cluster_inputs) {
auto preserved_ops = get_preserved_ops(var_node->outputs);
var_node->outputs.assign(preserved_ops.begin(), preserved_ops.end());
// According to SSA form, a var node must not be any two op's output,
// and the cluster_inputs var nodes is defined as an out-graph op's
// output, so the cluster_inputs var nodes are not any subgraph op's
// output. Do not reassign input list here.
}
// removing useless link from cluster to cluster_outputs
for (auto* var_node : cluster_outputs) {
auto preserved_ops = get_preserved_ops(var_node->inputs);
var_node->inputs.assign(preserved_ops.begin(), preserved_ops.end());
// Note that cluster_outputs var node maybe some subgraph op's input,
// here we need remove them.
preserved_ops = get_preserved_ops(var_node->outputs);
var_node->outputs.assign(preserved_ops.begin(), preserved_ops.end());
}
cinn_op_desc.SetOutput("Out", output_names);
cinn_op_desc.SetAttr(kCompilationKey, compilation_key);
cinn_op_desc.SetAttr(OpProtoAndCheckerMaker::OpRoleAttrName(),
ExtractOpRole(cluster));
cinn_op_desc.Flush();
auto* cinn_op_node = graph->CreateOpNode(&cinn_op_desc);
// Add new links from or to the the cinn launch op node
AddLinkToCinnOp(cluster_inputs, cluster_outputs, cinn_op_node);
}
// Removing cluster node and internals node from Graph
void RemoveSubGraphFromGraph(const GraphNodeSet& cluster,
const GraphNodeSet& cluster_internals,
Graph* graph) {
for (auto* op_node : cluster) {
graph->RemoveNode(op_node);
}
for (auto* var_node : cluster_internals) {
graph->RemoveNode(var_node);
}
const std::unordered_set<const Node*> const_cluster{cluster.cbegin(),
cluster.cend()};
const std::unordered_set<const Node*> const_internals{
cluster_internals.cbegin(), cluster_internals.cend()};
ir::GraphSafeRemoveNodes(graph, const_cluster);
ir::GraphSafeRemoveNodes(graph, const_internals);
}
// Replacing Cinn subgraph to a special op node, whose op_type is
// Replacing Cinn subgraph to a cinn op node, whose op_type is
// kCinnLaunchOp, and inputs ares cluster_inputs and outputs are
// cluster_outputs.
// Meanwhile, move all links of cluster to the special op.
void ReplaceSubGraphWithSpecialOpNode(const GraphNodeSet& cluster,
// Meanwhile, move all links of cluster to the cinn op.
void ReplaceSubGraphWithCinnOpNode(const GraphNodeSet& cluster,
const GraphNodeSet& cluster_inputs,
const GraphNodeSet& cluster_outputs,
const GraphNodeSet& cluster_internals,
const std::string& compilation_key,
Graph* graph) {
// First, add the special op node whose name is "kCinnLaunchOp" into graph
auto special_op_node = AddSpecialOpToGraph(cluster_inputs, cluster_outputs,
compilation_key, graph);
// Second, remove all graph's links which are from or to cluster nodes
RemoveLinkFromCluster(cluster, cluster_inputs, cluster_outputs);
// Third, add new links from or to the the special op node
AddLinkToSpecialOp(cluster_inputs, cluster_outputs, special_op_node);
// Finally, remove the cinn sub graph from graph
// Add the cinn op node whose name is "kCinnLaunchOp" into graph
AddCinnOpToGraph(cluster, cluster_inputs, cluster_outputs, compilation_key,
graph);
// Remove the cinn subgraph from graph
RemoveSubGraphFromGraph(cluster, cluster_internals, graph);
}
......@@ -376,12 +358,12 @@ void SearchAllSubgraphs(Graph* graph) {
// save it in CinnCompiler
std::string compilation_key = cinn_compiler->AddGraph(CreateNewSubGraph(
cluster_set, cluster_internals, cluster_inputs, cluster_outputs));
// Replace the found cluster to a new special op node
ReplaceSubGraphWithSpecialOpNode(cluster_set, cluster_inputs,
cluster_outputs, cluster_internals,
compilation_key, graph);
// Replace the found cluster to a new cinn op node
ReplaceSubGraphWithCinnOpNode(cluster_set, cluster_inputs, cluster_outputs,
cluster_internals, compilation_key, graph);
}
}
} // namespace
void BuildCinnPass::ApplyImpl(Graph* graph) const { SearchAllSubgraphs(graph); }
......
......@@ -20,7 +20,7 @@ namespace paddle {
namespace framework {
namespace paddle2cinn {
constexpr char kCinnLaunchOp[] = "CinnLaunchOp";
constexpr char kCinnLaunchOp[] = "cinn_launch";
constexpr char kCompilationKey[] = "compilation_key";
// A pass named BuildCinnPass, the function of this pass is:
......
......@@ -15,16 +15,18 @@ limitations under the License. */
#include "paddle/fluid/framework/paddle2cinn/cinn_graph_symbolization.h"
#include <algorithm>
#include <iterator>
#include <queue>
#include <unordered_map>
#include <unordered_set>
#include <vector>
#include "paddle/fluid/framework/ir/graph_helper.h"
#include "paddle/fluid/framework/paddle2cinn/transform_desc.h"
#include "paddle/fluid/framework/variable.h"
#include "cinn/frontend/op_mappers/use_op_mappers.h"
#include "cinn/frontend/var_type_utils.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/errors.h"
namespace paddle {
namespace framework {
......@@ -86,35 +88,93 @@ CinnGraphSymbolization::GetGraphInputParameterNames() const {
// Transform paddle scope to cinn, note that we only preserve the graph’s
// input parameter variable and ignore others.
std::shared_ptr<::cinn::hlir::framework::Scope>
CinnGraphSymbolization::CreateCinnScope(const FeedInfoMap& feed_map) const {
CinnGraphSymbolization::CreateCinnScope(const FeedInfoMap& feed_map) {
auto cinn_scope = ::cinn::hlir::framework::Scope::Create();
// get the graph's input parameter variable name list
auto parameter_names = GetGraphInputParameterNames();
for (const auto& param_name : parameter_names) {
VLOG(4) << "add param var [" << param_name << "] info scope";
// if cannot find var in graph input, skip.
// scope accepte the CINN format name, so here we need transform
// paddle format name to CINN format.
auto* cinn_var = cinn_scope->Var<CinnTensor>(
::cinn::utils::TransValidVarName(param_name));
auto valid_name = ::cinn::utils::TransValidVarName(param_name);
auto* cinn_var = cinn_scope->Var<CinnTensor>(valid_name);
auto& cinn_tensor = absl::get<CinnTensor>(*cinn_var);
// here we only need preserve dtype and shape, do not need preserve data
auto feed_info = feed_map.at(param_name);
cinn_tensor->set_type(feed_info.type);
cinn_tensor->Resize(::cinn::hlir::framework::Shape(feed_info.shape));
VLOG(4) << "add paddle param var [" << param_name
<< "] info cinn scope var[" << valid_name << "]";
var_model_to_program_map_[param_name] = valid_name;
}
return cinn_scope;
}
std::vector<Node*> CinnGraphSymbolization::TopologicalSort() const {
std::unordered_set<Node*> op_nodes;
std::for_each(graph_.Nodes().begin(), graph_.Nodes().end(),
[&op_nodes](Node* n) {
if (n->IsOp()) {
op_nodes.emplace(n);
}
});
std::unordered_map<Node*, std::unordered_map<Node*, size_t>> adj_list;
std::unordered_map<Node*, size_t> in_degrees;
for (auto* n : op_nodes) {
// the op's input is var
for (auto* in_var : n->inputs) {
// the var's input is op
for (auto* in_op : in_var->inputs) {
if (op_nodes.count(in_op)) {
++adj_list[in_op][n];
++in_degrees[n];
}
}
}
}
// find topology entries
std::queue<Node*> queue;
for (auto* n : op_nodes) {
if (!in_degrees[n]) {
queue.push(n);
}
}
// topological sorting
std::vector<Node*> sorted_ops;
while (!queue.empty()) {
auto* cur_op = queue.front();
queue.pop();
VLOG(4) << "topological sort insert: " << cur_op->Name() << " "
<< reinterpret_cast<void*>(cur_op) << " input "
<< cur_op->inputs.size();
sorted_ops.emplace_back(cur_op);
for (const auto& adj_pair : adj_list[cur_op]) {
in_degrees.at(adj_pair.first) -= adj_pair.second;
if (!in_degrees[adj_pair.first]) {
queue.push(adj_pair.first);
}
}
}
PADDLE_ENFORCE_EQ(sorted_ops.size(), op_nodes.size(),
platform::errors::PreconditionNotMet(
"The sorting graph contains cycles."));
return sorted_ops;
}
std::vector<std::unique_ptr<CinnOpDesc>>
CinnGraphSymbolization::TransformAllGraphOpToCinn() const {
std::vector<std::unique_ptr<CinnOpDesc>> cinn_op_descs;
const auto& sorted_ops = ir::TopologySortOperations(graph_);
auto sorted_ops = TopologicalSort();
for (auto* node : sorted_ops) {
cinn_op_descs.emplace_back(std::make_unique<CinnOpDesc>());
auto& cinn_desc = cinn_op_descs.back();
......
......@@ -102,6 +102,9 @@ class CinnGraphSymbolization {
// transform all paddle var desc in feed list into cinn_var_descs_
FeedInfoMap GetFeedInfoMapFromInput() const;
// get the topological sort of the graph_
std::vector<ir::Node*> TopologicalSort() const;
// transform all paddle op desc in graph into cinn op desc
std::vector<std::unique_ptr<CinnOpDesc>> TransformAllGraphOpToCinn() const;
......@@ -115,7 +118,7 @@ class CinnGraphSymbolization {
// create cinn scope and add parameter's feed info into scope
std::shared_ptr<::cinn::hlir::framework::Scope> CreateCinnScope(
const FeedInfoMap& feed_map) const;
const FeedInfoMap& feed_map);
// get the graph op's input persistable var name set
std::unordered_set<std::string> GetGraphInputParameterNames() const;
......
......@@ -268,7 +268,7 @@ TEST_F(CinnGraphSymbolizationTest, sortgraph) {
sort_names.emplace_back(desc->Type());
}
ASSERT_EQ(sort_names,
std::vector<std::string>({"feed", "mul", "feed", "add", "relu"}));
std::vector<std::string>({"feed", "feed", "mul", "add", "relu"}));
}
TEST_F(CinnGraphSymbolizationTest, runop) {
......
......@@ -16,14 +16,17 @@ from __future__ import print_function
import logging
import numpy as np
import os
import paddle
import shutil
import tempfile
import unittest
paddle.enable_static()
logging.basicConfig(
format='%(asctime)s - %(levelname)s - %(message)s', level=logging.INFO)
logger = logging.getLogger(__name__)
logger = logging.getLogger("paddle_with_cinn")
def set_cinn_flag(val):
......@@ -36,34 +39,79 @@ def set_cinn_flag(val):
return cinn_compiled
@unittest.skipIf(not set_cinn_flag(True), "Paddle is not compiled with CINN.")
class TestParallelExecutorRunCinn(unittest.TestCase):
def test_run_from_cinn(self):
main_program = paddle.static.Program()
startup_program = paddle.static.Program()
def reader(limit):
for i in range(limit):
yield np.ones([1, 28]).astype('float32') * (i * 3.14 / (i + 1)), \
np.array([i + 1]).astype('int64')
def rand_data(img, label, loop_num=10):
feed = []
data = reader(loop_num)
for _ in range(loop_num):
d, l = next(data)
feed.append({img: d, label: l})
return feed
def build_program(main_program, startup_program):
with paddle.static.program_guard(main_program, startup_program):
data = paddle.static.data(
name='X', shape=[None, 1], dtype='float32')
prediction = paddle.static.nn.fc(data, 2)
loss = paddle.mean(prediction)
adam = paddle.optimizer.Adam()
adam.minimize(loss)
img = paddle.static.data(name='img', shape=[1, 28], dtype='float32')
param = paddle.create_parameter(
name="bias",
shape=[1, 28],
dtype="float32",
attr=paddle.ParamAttr(initializer=paddle.nn.initializer.Assign(
np.ones([1, 28]).astype(np.float32))))
label = paddle.static.data(name="label", shape=[1], dtype='int64')
hidden = paddle.add(img, param)
prediction = paddle.nn.functional.relu(hidden)
loss = paddle.nn.functional.cross_entropy(input=prediction, label=label)
avg_loss = paddle.mean(loss)
adam = paddle.optimizer.Adam(learning_rate=0.001)
adam.minimize(avg_loss)
return img, label, avg_loss
def do_test(dot_save_dir):
startup_program = paddle.static.Program()
main_program = paddle.static.Program()
img, label, loss = build_program(main_program, startup_program)
place = paddle.CUDAPlace(0) if paddle.is_compiled_with_cuda(
) else paddle.CPUPlace()
exe = paddle.static.Executor(place)
exe.run(startup_program)
build_strategy = paddle.static.BuildStrategy()
build_strategy.debug_graphviz_path = os.path.join(dot_save_dir, "viz")
compiled_program = paddle.static.CompiledProgram(
main_program).with_data_parallel(loss_name=loss.name)
main_program, build_strategy).with_data_parallel(loss_name=loss.name)
batch_size = 16
x = np.random.random(size=(batch_size, 1)).astype('float32')
fetch = exe.run(compiled_program,
feed={'X': x},
fetch_list=[prediction.name],
iters = 1
feed = rand_data(img.name, label.name, iters)
for step in range(iters):
loss_v = exe.run(compiled_program,
feed=feed[step],
fetch_list=[loss],
return_merged=False)
logger.info("loss value = {}".format(loss_v))
@unittest.skipIf(not set_cinn_flag(True), "Paddle is not compiled with CINN.")
class TestParallelExecutorRunCinn(unittest.TestCase):
def setUp(self):
set_cinn_flag(True)
self.tmpdir = tempfile.mkdtemp(prefix="dots_")
def tearDown(self):
set_cinn_flag(False)
shutil.rmtree(self.tmpdir)
def test_run_with_cinn(self):
do_test(self.tmpdir)
if __name__ == '__main__':
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册