Unverified commit 1917b380, authored by wanghuancoder, committed by GitHub

fix some errmsg report in framework/ir/, about 21 files (#25525)

* fix error msg report in ir/, about 19 files, test=develop

* modified some unclear descriptions, test=develop

* modified some unclear descriptions, test=develop

* modify unit test pass_test.cc, because pass_test.cc checks the error message produced in pass.cc, test=develop
Parent 4ec1251a
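Every hunk in this commit follows the same pattern: a bare PADDLE_ENFORCE(cond) or PADDLE_ENFORCE(ptr), which failed with a generic message, is replaced by a typed check (PADDLE_ENFORCE_EQ, PADDLE_ENFORCE_NE, PADDLE_ENFORCE_GT, PADDLE_ENFORCE_NOT_NULL, or PADDLE_THROW) paired with a structured platform::errors::* payload (InvalidArgument, NotFound, AlreadyExists, Unimplemented, Unavailable) that names the failing value. The sketch below imitates that idea in a self-contained program; the MY_ENFORCE_* macros are simplified hypothetical stand-ins, not Paddle's real implementation (the real macros live in paddle/fluid/platform/enforce.h and additionally record the error class and a stack trace).

// Minimal stand-in for the PADDLE_ENFORCE_* pattern used in this commit.
// MY_ENFORCE_EQ / MY_ENFORCE_NOT_NULL are illustrative only.
#include <cstdio>
#include <sstream>
#include <stdexcept>
#include <string>

// Throw with a formatted, located message when the equality check fails.
#define MY_ENFORCE_EQ(a, b, msg)                                     \
  do {                                                               \
    if ((a) != (b)) {                                                \
      std::ostringstream os;                                         \
      os << "InvalidArgumentError: " << (msg) << " at " << __FILE__  \
         << ":" << __LINE__;                                         \
      throw std::runtime_error(os.str());                            \
    }                                                                \
  } while (0)

// Null-pointer check expressed through the equality check above.
#define MY_ENFORCE_NOT_NULL(ptr, msg) \
  MY_ENFORCE_EQ((ptr) != nullptr, true, msg)

int main() {
  int* graph = nullptr;
  try {
    // The old style was a bare check with no context at all;
    // the new style says what was expected and what went wrong:
    MY_ENFORCE_NOT_NULL(graph, "Graph cannot be nullptr.");
  } catch (const std::exception& e) {
    std::printf("%s\n", e.what());  // prints the structured message
  }
  return 0;
}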
@@ -320,7 +320,7 @@ std::vector<Node *> FuseBatchNormActPass::ReplaceNode(
     return node;
   });
   PADDLE_ENFORCE_EQ(has_replaced, true,
-                    platform::errors::NotFound("Not find %s in the node list.",
+                    platform::errors::NotFound("Not found %s in the node list.",
                                                cur_node->Name()));
   return new_list;
 }
...
@@ -42,7 +42,8 @@ void FuseElewiseAddActPass::ApplyImpl(ir::Graph *graph) const {
 // ele_add(x, act(y))
 ir::Graph *FuseElewiseAddActPass::FuseElewiseAddAct(
     ir::Graph *graph, const std::unordered_set<std::string> &act_types) const {
-  PADDLE_ENFORCE(graph);
+  PADDLE_ENFORCE_NOT_NULL(
+      graph, platform::errors::InvalidArgument("Graph cannot be nullptr."));
   FusePassBase::Init("elewise_add_act", graph);
   GraphPatternDetector gpd;
@@ -93,7 +94,8 @@ ir::Graph *FuseElewiseAddActPass::FuseElewiseAddAct(
 // act(ele_add(x,y))
 ir::Graph *FuseElewiseAddActPass::FuseActElewiseAdd(
     ir::Graph *graph, const std::unordered_set<std::string> &act_types) const {
-  PADDLE_ENFORCE(graph);
+  PADDLE_ENFORCE_NOT_NULL(
+      graph, platform::errors::InvalidArgument("Graph cannot be nullptr."));
   FusePassBase::Init("act_elewise_add", graph);
   GraphPatternDetector gpd;
@@ -145,7 +147,8 @@ ir::Graph *FuseElewiseAddActPass::FuseActElewiseAdd(
 // ele_add_grad: in["Y", "Out@GRAD"], out["X@GRAD", "Y@GRAD"]
 ir::Graph *FuseElewiseAddActPass::FuseElewiseAddActInplaceGrad(
     ir::Graph *graph, const std::unordered_set<std::string> &act_types) const {
-  PADDLE_ENFORCE(graph);
+  PADDLE_ENFORCE_NOT_NULL(
+      graph, platform::errors::InvalidArgument("Graph cannot be nullptr."));
   FusePassBase::Init("elewise_add_act_grad", graph);
   GraphPatternDetector gpd;
@@ -252,10 +255,11 @@ void FuseElewiseAddActPass::RemoveIntermediateOut(Graph *graph) const {
       bool save_intermediate_out = BOOST_GET_CONST(
           bool, cur_node->Op()->GetAttr("save_intermediate_out"));
       auto intermediate_out_args = cur_node->Op()->Output("IntermediateOut");
-      PADDLE_ENFORCE(
-          save_intermediate_out && !intermediate_out_args.empty(),
-          "The %s should save the intermediate_out in the fusing stage.",
-          cur_node->Name());
+      PADDLE_ENFORCE_EQ(
+          (save_intermediate_out && !intermediate_out_args.empty()), true,
+          platform::errors::InvalidArgument(
+              "The %s should save the intermediate out in the fusing stage.",
+              cur_node->Name()));
       // If the intermediate_out's output is empty, it should be removed.
       auto cur_node_outputs = cur_node->outputs;
@@ -271,10 +275,11 @@ void FuseElewiseAddActPass::RemoveIntermediateOut(Graph *graph) const {
     } else if (cur_node->Name() == "fused_elemwise_activation_grad") {
       auto intermediate_out_grad_args =
           cur_node->Op()->Output(GradVarName("IntermediateOut"));
-      PADDLE_ENFORCE(
-          !intermediate_out_grad_args.empty(),
-          "The %s should save the intermediate_out in the fusing stage.",
-          cur_node->Name());
+      PADDLE_ENFORCE_EQ(
+          intermediate_out_grad_args.empty(), false,
+          platform::errors::InvalidArgument(
+              "The %s should save the intermediate out in the fusing stage.",
+              cur_node->Name()));
       auto cur_node_outputs = cur_node->outputs;
       // If the intermediate_out_g's output is empty, it should be removed.
       for (auto &out : cur_node_outputs) {
@@ -312,7 +317,11 @@ void FuseElewiseAddActPass::ReLinkNodes(Graph *graph,
         nodes2delete.emplace(out);
       }
     } else {
-      PADDLE_ENFORCE(out == intermediate_out);
+      PADDLE_ENFORCE_EQ(
+          out, intermediate_out,
+          platform::errors::InvalidArgument(
+              "Output of op(%s) must be %s, but not %s.", op_1->Name(),
+              intermediate_out->Name(), out->Name()));
       IR_OP_VAR_LINK(fused_op, out);
     }
   }
@@ -347,8 +356,9 @@ std::vector<Node *> FuseElewiseAddActPass::ReplaceNode(
     }
     return node;
   });
-  PADDLE_ENFORCE(has_replaced, "Not find %s in the node list.",
-                 cur_node->Name());
+  PADDLE_ENFORCE_EQ(has_replaced, true,
+                    platform::errors::NotFound("Not found %s in the node list.",
+                                               cur_node->Name()));
   return new_list;
 }
...
@@ -25,14 +25,19 @@ void FusePassBase::Init(const std::string& repr, Graph* graph) const {
 }

 Scope* FusePassBase::param_scope() const {
-  PADDLE_ENFORCE(graph_->Has(kParamScopeAttr));
+  PADDLE_ENFORCE_EQ(graph_->Has(kParamScopeAttr), true,
+                    platform::errors::InvalidArgument(
+                        "Graph must have kParamScopeAttr attribute."));
   auto& scope = graph_->Get<framework::Scope>(kParamScopeAttr);
   return &scope;
 }

 void FusePassBase::AddStatis(int count_of_fused) const {
-  PADDLE_ENFORCE(graph_);
-  PADDLE_ENFORCE(!repr_.empty());
+  PADDLE_ENFORCE_NOT_NULL(
+      graph_, platform::errors::InvalidArgument("Graph cannot be nullptr."));
+  PADDLE_ENFORCE_EQ(repr_.empty(), false,
+                    platform::errors::InvalidArgument(
+                        "Fuse pass must be initialized with a name."));
   if (!graph_->Has(kFuseStatisAttr)) {
     graph_->Set(kFuseStatisAttr, new std::unordered_map<std::string, int>);
   }
...
@@ -31,7 +31,8 @@ void FuseReluDepthwiseConvPass::ApplyImpl(ir::Graph *graph) const {
 ir::Graph *FuseReluDepthwiseConvPass::FuseReluDepthwiseConv(
     ir::Graph *graph, bool only_forward) const {
-  PADDLE_ENFORCE(graph);
+  PADDLE_ENFORCE_NOT_NULL(
+      graph, platform::errors::InvalidArgument("Graph cannot be nullptr."));
   if (only_forward)
     FusePassBase::Init("relu_depthwise_conv_only_forward", graph);
   else
@@ -110,23 +111,45 @@ ir::Graph *FuseReluDepthwiseConvPass::FuseReluDepthwiseConv(
       xg_var = subgraph.at(xg)->Var();
     }

-    PADDLE_ENFORCE_EQ(layer_op->Input("Input").size(), 1UL);
-    PADDLE_ENFORCE_EQ(layer_op->Input("Input")[0], y_var->Name());
+    PADDLE_ENFORCE_EQ(layer_op->Input("Input").size(), 1UL,
+                      platform::errors::InvalidArgument(
+                          "Op(%s)'s input size(%d) must be 1.",
+                          layer_op->Type(), layer_op->Input("Input").size()));
+    PADDLE_ENFORCE_EQ(
+        layer_op->Input("Input")[0], y_var->Name(),
+        platform::errors::InvalidArgument(
+            "Op(%s)'s input name(%s) must be %s.", layer_op->Type(),
+            layer_op->Input("Input")[0], y_var->Name()));
     layer_op->SetInput("Input", {x_var->Name()});
     subgraph.at(layer)->inputs.push_back(subgraph.at(x));
     subgraph.at(x)->outputs.push_back(subgraph.at(layer));
     VLOG(4) << "replace " << y_var->Name() << " -> " << x_var->Name();

     if (!only_forward) {
-      PADDLE_ENFORCE_EQ(layer_g_op->Input("Input").size(), 1UL);
-      PADDLE_ENFORCE_EQ(layer_g_op->Input("Input")[0], y_var->Name());
+      PADDLE_ENFORCE_EQ(
+          layer_g_op->Input("Input").size(), 1UL,
+          platform::errors::InvalidArgument(
+              "Op(%s)'s input size(%d) must be 1.", layer_g_op->Type(),
+              layer_g_op->Input("Input").size()));
+      PADDLE_ENFORCE_EQ(
+          layer_g_op->Input("Input")[0], y_var->Name(),
+          platform::errors::InvalidArgument(
+              "Op(%s)'s input name(%s) must be %s.", layer_g_op->Type(),
+              layer_g_op->Input("Input")[0], y_var->Name()));
       layer_g_op->SetInput("Input", {x_var->Name()});
       subgraph.at(layer_g)->inputs.push_back(subgraph.at(x));
       subgraph.at(x)->outputs.push_back(subgraph.at(layer_g));

-      PADDLE_ENFORCE_EQ(layer_g_op->Output(GradVarName("Input")).size(), 1UL);
-      PADDLE_ENFORCE_EQ(layer_g_op->Output(GradVarName("Input"))[0],
-                        yg_var->Name());
+      PADDLE_ENFORCE_EQ(
+          layer_g_op->Output(GradVarName("Input")).size(), 1UL,
+          platform::errors::InvalidArgument(
+              "Op(%s)'s input size(%d) must be 1.", layer_g_op->Type(),
+              layer_g_op->Output(GradVarName("Input")).size()));
+      PADDLE_ENFORCE_EQ(
+          layer_g_op->Output(GradVarName("Input"))[0], yg_var->Name(),
+          platform::errors::InvalidArgument(
+              "Op(%s)'s input name(%s) must be %s.", layer_g_op->Type(),
+              layer_g_op->Output(GradVarName("Input"))[0], yg_var->Name()));
       layer_g_op->SetOutput(GradVarName("Input"), {xg_var->Name()});
       subgraph.at(layer_g)->outputs.push_back(subgraph.at(xg));
       subgraph.at(xg)->inputs.push_back(subgraph.at(layer_g));
...
@@ -136,7 +136,9 @@ bool FindCircleSubGraph(const Graph &graph,
 std::vector<ir::Node *> TopologySortOperations(const Graph &graph) {
   std::map<ir::Node *, std::set<ir::Node *, ir::NodeComp>, ir::NodeComp>
       adj_list = BuildOperationAdjList(graph);
-  PADDLE_ENFORCE(!HasCircleInternal(adj_list, nullptr));
+  PADDLE_ENFORCE_EQ(HasCircleInternal(adj_list, nullptr), false,
+                    platform::errors::InvalidArgument(
+                        "Generated graph shouldn't contain cycle."));
   std::unordered_set<ir::Node *> visited;
   std::vector<ir::Node *> ret;
   for (auto adj : adj_list) {
@@ -161,7 +163,11 @@ BuildOperationAdjList(const Graph &graph) {
     }
     for (auto &var : n->inputs) {
       for (auto &adj_n : var->inputs) {
-        PADDLE_ENFORCE(adj_n->NodeType() == ir::Node::Type::kOperation);
+        PADDLE_ENFORCE_EQ(
+            adj_n->NodeType(), ir::Node::Type::kOperation,
+            platform::errors::InvalidArgument(
+                "Node(%s)'s type(%d) must be kOperation type.", adj_n->Name(),
+                static_cast<int>(adj_n->NodeType())));
         VLOG(4) << "adj " << adj_n->Name() << reinterpret_cast<void *>(adj_n)
                 << " -> " << n->Name() << reinterpret_cast<void *>(n)
                 << " via " << var->Name() << reinterpret_cast<void *>(var);
@@ -184,7 +190,11 @@ std::map<ir::Node *, std::unordered_set<ir::Node *>> BuildOperationOutAdjList(
     }
     for (auto &var : n->outputs) {
       for (auto &adj_n : var->outputs) {
-        PADDLE_ENFORCE(adj_n->NodeType() == ir::Node::Type::kOperation);
+        PADDLE_ENFORCE_EQ(
+            adj_n->NodeType(), ir::Node::Type::kOperation,
+            platform::errors::InvalidArgument(
+                "Node(%s)'s type(%d) must be kOperation type.", adj_n->Name(),
+                static_cast<int>(adj_n->NodeType())));
         VLOG(40) << "adj " << adj_n->Name() << reinterpret_cast<void *>(adj_n)
                  << " -> " << n->Name() << reinterpret_cast<void *>(n)
                  << " via " << var->Name() << reinterpret_cast<void *>(var);
@@ -359,7 +369,10 @@ size_t GraphNum(const Graph &graph) {
     }
     std::unique_ptr<std::ostream> fout(
         new std::ofstream(FLAGS_print_sub_graph_dir));
-    PADDLE_ENFORCE(fout->good());
+    PADDLE_ENFORCE_EQ(fout->good(), true,
+                      platform::errors::Unavailable(
+                          "Can not open file %s for printing the graph.",
+                          FLAGS_print_sub_graph_dir));
     *fout << out.str();
   }
 }
...
@@ -33,7 +33,8 @@ const char kSumGradOpName[] = "sum";
 const char kOptimizerType[] = "sgd";

 void LockFreeOptimizePass::ApplyImpl(ir::Graph* graph) const {
-  PADDLE_ENFORCE(graph);
+  PADDLE_ENFORCE_NOT_NULL(
+      graph, platform::errors::InvalidArgument("Graph cannot be nullptr."));

   // We could collect all weights' name from SGD, where
   // W1 <- SGD(W0, Grad0)
@@ -41,7 +42,10 @@ void LockFreeOptimizePass::ApplyImpl(ir::Graph* graph) const {
   for (auto* node : graph->Nodes()) {
     if (IsOpNamed(node, kOptimizerType)) {
       auto& param_out_vars = node->Op()->Output("ParamOut");
-      PADDLE_ENFORCE(param_out_vars.size() == 1u);
+      PADDLE_ENFORCE_EQ(
+          param_out_vars.size(), 1u,
+          platform::errors::InvalidArgument(
+              "In op(%s), find output(ParamOut) failed.", node->Name()));
       weight_var_set.insert(param_out_vars[0]);
     }
   }
@@ -95,12 +99,19 @@ void LockFreeOptimizePass::ApplyImpl(ir::Graph* graph) const {
           VLOG(3) << "Found forward_op " << forward_op->Name();

-          PADDLE_ENFORCE(forward_op);
+          PADDLE_ENFORCE_NOT_NULL(
+              forward_op, platform::errors::NotFound(
+                              "Can not find forward op for backword op(%s).",
+                              backward_op->Name()));

           Node* new_optimizer_node = CreateNewSGDNode(
               graph, forward_op, backward_op, node, opt_node);

-          PADDLE_ENFORCE(new_optimizer_node);
+          PADDLE_ENFORCE_NOT_NULL(
+              new_optimizer_node,
+              platform::errors::InvalidArgument(
+                  "Create new SGD node failed, backward op is %s.",
+                  backward_op->Name()));
         }
       }
     }
@@ -144,11 +155,21 @@ void LockFreeOptimizePass::ApplyImpl(ir::Graph* graph) const {
 ir::Node* LockFreeOptimizePass::CreateNewSGDNode(
     ir::Graph* graph, ir::Node* forward_node, ir::Node* backward_node,
     ir::Node* grad_sum_node, ir::Node* optimize_node) const {
-  PADDLE_ENFORCE(graph);
-  PADDLE_ENFORCE(forward_node);
-  PADDLE_ENFORCE(backward_node);
-  PADDLE_ENFORCE(grad_sum_node);
-  PADDLE_ENFORCE(optimize_node);
+  PADDLE_ENFORCE_NOT_NULL(graph,
+                          platform::errors::InvalidArgument(
+                              "Input argument graph cannot be nullptr."));
+  PADDLE_ENFORCE_NOT_NULL(
+      forward_node, platform::errors::InvalidArgument(
+                        "Input argument forward_node cannot be nullptr."));
+  PADDLE_ENFORCE_NOT_NULL(
+      backward_node, platform::errors::InvalidArgument(
+                         "Input argument backward_node cannot be nullptr."));
+  PADDLE_ENFORCE_NOT_NULL(
+      grad_sum_node, platform::errors::InvalidArgument(
+                         "Input argument grad_sum_node cannot be nullptr."));
+  PADDLE_ENFORCE_NOT_NULL(
+      optimize_node, platform::errors::InvalidArgument(
+                         "Input argument optimize_node cannot be nullptr."));

   // find the grad var node between the grad sum node and backward_node
   std::vector<ir::Node*> grad_vars =
@@ -159,7 +180,8 @@ ir::Node* LockFreeOptimizePass::CreateNewSGDNode(
       grad_node = node;
     }
   }
-  PADDLE_ENFORCE(grad_node);
+  PADDLE_ENFORCE_NOT_NULL(grad_node, platform::errors::NotFound(
+                                         "Can not find control dep variable."));

   // create a new SGD node
   OpDesc* old_desc = optimize_node->Op();
@@ -212,8 +234,14 @@ ir::Node* LockFreeOptimizePass::CreateNewSGDNode(
   }

   // SGD must have only one param and LR in
-  PADDLE_ENFORCE(old_desc->Input("LearningRate").size() == 1u);
-  PADDLE_ENFORCE(old_desc->Input("Param").size() == 1u);
+  PADDLE_ENFORCE_EQ(
+      old_desc->Input("LearningRate").size(), 1u,
+      platform::errors::InvalidArgument(
+          "In op(%s), find input(LearningRate) failed.", old_desc->Type()));
+  PADDLE_ENFORCE_EQ(
+      old_desc->Input("Param").size(), 1u,
+      platform::errors::InvalidArgument("In op(%s), find input(Param) failed.",
+                                        old_desc->Type()));

   // LR and weight nodes should be copied
   for (Node* upstream_node : optimize_node->inputs) {
@@ -245,9 +273,17 @@ std::vector<ir::Node*> LockFreeOptimizePass::FindConnectedNode(
 void LockFreeOptimizePass::ReplaceUpstreamNode(
     ir::Node* upstream_node, ir::Node* old_optimizer_node,
     ir::Node* new_optimizer_node) const {
-  PADDLE_ENFORCE(upstream_node);
-  PADDLE_ENFORCE(old_optimizer_node);
-  PADDLE_ENFORCE(new_optimizer_node);
+  PADDLE_ENFORCE_NOT_NULL(
+      upstream_node, platform::errors::InvalidArgument(
+                         "Input argument upstream_node cannot be nullptr."));
+  PADDLE_ENFORCE_NOT_NULL(
+      old_optimizer_node,
+      platform::errors::InvalidArgument(
+          "Input argument old_optimizer_node cannot be nullptr."));
+  PADDLE_ENFORCE_NOT_NULL(
+      new_optimizer_node,
+      platform::errors::InvalidArgument(
+          "Input argument new_optimizer_node cannot be nullptr."));

   // Remove the old_optimizer_node from upstream_node's outputs vector
   auto& output_node_vec = upstream_node->outputs;
@@ -268,8 +304,14 @@ void LockFreeOptimizePass::ReplaceUpstreamNode(
 void LockFreeOptimizePass::ReplaceAllDownstreamNode(
     ir::Node* old_optimizer_node, ir::Node* new_optimizer_node) const {
-  PADDLE_ENFORCE(old_optimizer_node);
-  PADDLE_ENFORCE(new_optimizer_node);
+  PADDLE_ENFORCE_NOT_NULL(
+      old_optimizer_node,
+      platform::errors::InvalidArgument(
+          "Input argument old_optimizer_node cannot be nullptr."));
+  PADDLE_ENFORCE_NOT_NULL(
+      new_optimizer_node,
+      platform::errors::InvalidArgument(
+          "Input argument new_optimizer_node cannot be nullptr."));

   for (ir::Node* downstream_node : old_optimizer_node->outputs) {
     // Remove the old_optimizer_node from downstream_node's inputs vector
@@ -292,8 +334,12 @@ void LockFreeOptimizePass::ReplaceAllDownstreamNode(
 ir::Node* LockFreeOptimizePass::FindForwardOpViaBackwardOp(
     ir::Graph* graph, ir::Node* backward_node) const {
-  PADDLE_ENFORCE(graph);
-  PADDLE_ENFORCE(backward_node);
+  PADDLE_ENFORCE_NOT_NULL(graph,
+                          platform::errors::InvalidArgument(
+                              "Input argument graph cannot be nullptr."));
+  PADDLE_ENFORCE_NOT_NULL(
+      backward_node, platform::errors::InvalidArgument(
+                         "Input argument backward_node cannot be nullptr."));

   // strip the suffix _grad of backward_node's name
   std::string forward_op_name = backward_node->Name();
...
@@ -87,34 +87,46 @@ class LockFreeOptimizePass : public Pass {
                                   ir::Node* downstream_node) const;

   inline bool IsOpNamed(ir::Node* node, const std::string& name) const {
-    PADDLE_ENFORCE(node);
+    PADDLE_ENFORCE_NOT_NULL(node,
+                            platform::errors::InvalidArgument(
+                                "Input argument node cannot be nullptr."));

     return node->NodeType() == Node::Type::kOperation && node->Name() == name;
   }

   inline bool IsVarNamed(ir::Node* node, const std::string& name) const {
-    PADDLE_ENFORCE(node);
+    PADDLE_ENFORCE_NOT_NULL(node,
+                            platform::errors::InvalidArgument(
+                                "Input argument node cannot be nullptr."));

     return node->NodeType() == Node::Type::kVariable && node->Name() == name;
   }

   inline bool IsVarNameEndsWith(ir::Node* node, const std::string& name) const {
-    PADDLE_ENFORCE(node);
+    PADDLE_ENFORCE_NOT_NULL(node,
+                            platform::errors::InvalidArgument(
+                                "Input argument node cannot be nullptr."));

     return node->NodeType() == Node::Type::kVariable &&
            boost::algorithm::ends_with(node->Name(), name);
   }

   inline bool IsVarNameContains(ir::Node* node, const std::string& name) const {
-    PADDLE_ENFORCE(node);
+    PADDLE_ENFORCE_NOT_NULL(node,
+                            platform::errors::InvalidArgument(
+                                "Input argument node cannot be nullptr."));

     return node->NodeType() == Node::Type::kVariable &&
            node->Name().find(name) != std::string::npos;
   }

   inline bool IsControlDepFrom(ir::Node* ctrl_dep_node, ir::Node* node) const {
-    PADDLE_ENFORCE(ctrl_dep_node);
-    PADDLE_ENFORCE(node);
+    PADDLE_ENFORCE_NOT_NULL(
+        ctrl_dep_node, platform::errors::InvalidArgument(
+                           "Input argument ctrl_dep_node cannot be nullptr."));
+    PADDLE_ENFORCE_NOT_NULL(node,
+                            platform::errors::InvalidArgument(
+                                "Input argument node cannot be nullptr."));

     return IsControlDepVar(*ctrl_dep_node) &&
            ctrl_dep_node->inputs.size() >= 1u &&
...
@@ -66,12 +66,18 @@ class Node {
   std::string Name() const { return name_; }

   VarDesc* Var() const {
-    PADDLE_ENFORCE_EQ(IsVar(), true);
+    PADDLE_ENFORCE_EQ(IsVar(), true,
+                      platform::errors::InvalidArgument(
+                          "Node(%s) must be kVariable type, but not %d.", name_,
+                          static_cast<int>(type_)));
     return var_desc_.get();
   }

   OpDesc* Op() const {
-    PADDLE_ENFORCE_EQ(IsOp(), true);
+    PADDLE_ENFORCE_EQ(IsOp(), true,
+                      platform::errors::InvalidArgument(
+                          "Node(%s) must be kOperation type, but not %d.",
+                          name_, static_cast<int>(type_)));
     return op_desc_.get();
   }
@@ -92,8 +98,9 @@ class Node {
     try {
       return *boost::any_cast<T*>(wrapper_);
     } catch (boost::bad_any_cast&) {
-      PADDLE_THROW("Invalid wrapper type error, expected %s, actual %s",
-                   typeid(T).name(), wrapper_type_.name());
+      PADDLE_THROW(platform::errors::InvalidArgument(
+          "Invalid wrapper type error, expected %s, actual %s.",
+          typeid(T).name(), wrapper_type_.name()));
     }
   }
@@ -114,8 +121,9 @@ class Node {
   }

   void RenameVar(const std::string& new_name) {
-    PADDLE_ENFORCE(type_ == Type::kVariable && var_desc_,
-                   "Must be type of variable");
+    PADDLE_ENFORCE_EQ(
+        type_ == Type::kVariable && var_desc_, true,
+        platform::errors::InvalidArgument("Node must be type of variable."));
     name_ = new_name;
     var_desc_->SetName(new_name);
   }
...
@@ -26,7 +26,8 @@ namespace ir {
 Graph* Pass::Apply(Graph* graph) const {
   CheckPrevPass();
-  PADDLE_ENFORCE(graph, "graph passed to Pass::Apply() cannot be empty.");
+  PADDLE_ENFORCE_NOT_NULL(
+      graph, platform::errors::InvalidArgument("Graph cannot be nullptr."));
   for (const std::string& attr : required_pass_attrs_) {
     PADDLE_ENFORCE_NE(
         attrs_.find(attr), attrs_.end(),
@@ -40,11 +41,14 @@ Graph* Pass::Apply(Graph* graph) const {
   }
   ApplyImpl(graph);
   // TODO(panyx0718): Add more verifications.
-  PADDLE_ENFORCE(!HasCircle(*graph),
-                 "Illegal Pass %s. Generated graph shouldn't have cycle.",
-                 Type());
-  PADDLE_ENFORCE(VarDescIsConsistency(*graph),
-                 "The VarDescs of persistable variable are not consistency.");
+  PADDLE_ENFORCE_EQ(
+      HasCircle(*graph), false,
+      platform::errors::InvalidArgument(
+          "Illegal pass %s. Generated graph shouldn't contain cycle.", Type()));
+  PADDLE_ENFORCE_EQ(
+      VarDescIsConsistency(*graph), true,
+      platform::errors::InvalidArgument(
+          "The VarDescs of persistable variable are not consistency."));
   applied_ = true;
   if (!graph->Has(kPassRecorder)) {
     graph->Set<PassRecorder>(kPassRecorder, new PassRecorder);
...
@@ -55,8 +55,9 @@ class Pass {
   // Get a reference to the attributed previously set.
   template <typename AttrType>
   AttrType &Get(const std::string &attr_name) const {
-    PADDLE_ENFORCE(attrs_.find(attr_name) != attrs_.end(),
-                   "%s attr not registered for pass.", attr_name);
+    PADDLE_ENFORCE_NE(attrs_.find(attr_name), attrs_.end(),
+                      platform::errors::InvalidArgument(
+                          "Attribute %s not registered for pass.", attr_name));
     try {
       return *boost::any_cast<AttrType *>(attrs_.at(attr_name));
     } catch (boost::bad_any_cast &) {
@@ -76,7 +77,7 @@ class Pass {
     };
     PADDLE_THROW(platform::errors::InvalidArgument(
-        "Invalid type for attritube %s, expected: %s, actual: %s", attr_name,
+        "Invalid type for attritube %s, expected: %s, actual: %s.", attr_name,
         TypeToString(typeid(AttrType *)),
         TypeToString(attrs_.at(attr_name).type())));
   }
@@ -101,9 +102,10 @@ class Pass {
   template <typename AttrType>
   void Set(const std::string &attr_name, AttrType *attr) {
     if (default_pass_attrs_.count(attr_name) == 0) {
-      PADDLE_ENFORCE_EQ(attrs_.count(attr_name), 0,
-                        platform::errors::InvalidArgument(
-                            "Attribute %s already set in the pass", attr_name));
+      PADDLE_ENFORCE_EQ(
+          attrs_.count(attr_name), 0,
+          platform::errors::AlreadyExists(
+              "Attribute %s already set in the pass.", attr_name));
     } else {
       VLOG(3) << "Setting the attribute " << attr_name << " for the pass "
               << type_;
@@ -119,15 +121,16 @@ class Pass {
   // should delete the attribute.
   template <typename AttrType>
   void SetNotOwned(const std::string &attr_name, AttrType *attr) {
-    PADDLE_ENFORCE(attrs_.count(attr_name) == 0, "%s already set in the pass",
-                   attr_name);
+    PADDLE_ENFORCE_EQ(attrs_.count(attr_name), 0,
+                      platform::errors::AlreadyExists(
+                          "Attribute %s already set in the pass.", attr_name));
     attrs_[attr_name] = attr;
   }

  protected:
   virtual void ApplyImpl(Graph *graph) const {
     PADDLE_THROW(platform::errors::Unimplemented(
-        "The virtual Pass called is not implemented."));
+        "The virtual pass called is not implemented."));
   }

   // Some Pass must be placed before this Pass, and some
@@ -198,8 +201,9 @@ class PassRegistry {
   }

   std::unique_ptr<Pass> Get(const std::string &pass_type) const {
-    PADDLE_ENFORCE(Has(pass_type), "Pass %s has not been registered",
-                   pass_type);
+    PADDLE_ENFORCE_EQ(Has(pass_type), true,
+                      platform::errors::InvalidArgument(
+                          "Pass %s has not been registered.", pass_type));
     return map_.at(pass_type)();
   }
@@ -213,8 +217,10 @@ class PassRegistry {
 template <typename PassType>
 struct PassRegistrar : public Registrar {
   explicit PassRegistrar(const char *pass_type) {
-    PADDLE_ENFORCE(!PassRegistry::Instance().Has(pass_type),
-                   "'%s' is registered more than once.", pass_type);
+    PADDLE_ENFORCE_EQ(
+        PassRegistry::Instance().Has(pass_type), false,
+        platform::errors::AlreadyExists(
+            "Pass '%s' is registered more than once.", pass_type));
     PassRegistry::Instance().Insert(
         pass_type, [this, pass_type]() -> std::unique_ptr<Pass> {
           std::unique_ptr<Pass> pass(new PassType());
...
@@ -28,13 +28,19 @@ std::shared_ptr<Pass> PassBuilder::AppendPass(const std::string& pass_type) {
 }

 void PassBuilder::RemovePass(size_t idx) {
-  PADDLE_ENFORCE(passes_.size() > idx);
+  PADDLE_ENFORCE_GT(
+      passes_.size(), idx,
+      platform::errors::InvalidArgument(
+          "Passes size is %d, %d is not a valid index.", passes_.size(), idx));
   passes_.erase(passes_.begin() + idx);
 }

 std::shared_ptr<Pass> PassBuilder::InsertPass(size_t idx,
                                               const std::string& pass_type) {
-  PADDLE_ENFORCE(passes_.size() >= idx);
+  PADDLE_ENFORCE_GE(
+      passes_.size(), idx,
+      platform::errors::InvalidArgument(
+          "Passes size is %d, %d is not a valid index.", passes_.size(), idx));
   std::shared_ptr<Pass> pass(
       ir::PassRegistry::Instance().Get(pass_type).release());
   passes_.insert(passes_.begin() + idx, std::move(pass));
...
@@ -119,7 +119,7 @@ TEST(PassTest, TestPassAttrCheck) {
   } catch (paddle::platform::EnforceNotMet& e) {
     exception = std::string(e.what());
   }
-  ASSERT_TRUE(exception.find("shouldn't have cycle") != exception.npos);
+  ASSERT_TRUE(exception.find("shouldn't contain cycle") != exception.npos);

   pass = PassRegistry::Instance().Get("test_pass");
   pass->Set<int>("test_pass_attr", new int);
...
@@ -43,9 +43,11 @@ void DeleteQuant(ir::Graph* graph, Scope* scope,
   // ops linked from it
   auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
                      Graph* g) {
-    PADDLE_ENFORCE_EQ(subgraph.count(input_act_node), true,
-                      platform::errors::NotFound(
-                          "Input act node not found in Delete Quant fusion."));
+    PADDLE_ENFORCE_EQ(
+        subgraph.count(input_act_node), true,
+        platform::errors::NotFound(
+            "Input act node(%s) not found in QuantDequantFuse pass.",
+            input_act_node->name()));
     Node* input_act = subgraph.at(input_act_node);
     Node* input_scale = subgraph.at(pattern.GetPDNode("input_scale_node"));
     Node* quant = subgraph.at(pattern.GetPDNode("quant_node"));
@@ -58,7 +60,7 @@ void DeleteQuant(ir::Graph* graph, Scope* scope,
     std::string input_scale_var_name = quant->Op()->Input("InScale").front();
     PADDLE_ENFORCE_NOT_NULL(
         scope, platform::errors::InvalidArgument(
-                   "scope in DeleteQuantOpFuse pass should not be null."));
+                   "Scope in QuantDequantFuse pass should not be null."));
     const LoDTensor& input_scale_tensor =
         scope->FindVar(input_scale_var_name)->Get<LoDTensor>();
     PADDLE_ENFORCE_EQ(
@@ -84,8 +86,8 @@ void DeleteQuant(ir::Graph* graph, Scope* scope,
     } else if (quantized_op_type == "mul") {
       op_desc->SetAttr("X_scale", scale_value);
     } else {
-      PADDLE_THROW(platform::errors::InvalidArgument(
-          "Unsupported quantized op type %s", quantized_op_type));
+      PADDLE_THROW(platform::errors::Unimplemented(
+          "Unsupported quantized op type %s.", quantized_op_type));
     }
     op_desc->SetAttr("bit_length", bit_length);
     op_desc->RenameInput(output_act_name, input_act_name);
@@ -119,9 +121,9 @@ void FuseDequant(ir::Graph* graph, Scope* scope,
     weight_name = "W";
     input_name = "Input";
   } else {
-    PADDLE_ENFORCE(
-        "QuantDequantFuse: We only support conv2d, conv2d_fusion, fc, mul for "
-        "now.");
+    PADDLE_THROW(platform::errors::Unimplemented(
+        "QuantDequantFuse: We only support conv2d, conv2d_fusion, fc, mul for "
+        "now."));
   }
   const std::string pattern_name = "dequant_fuse";
   GraphPatternDetector gpd;
@@ -141,8 +143,9 @@ void FuseDequant(ir::Graph* graph, Scope* scope,
                      Graph* g) {
     PADDLE_ENFORCE_EQ(
         subgraph.count(quantized_op_input), true,
-        platform::errors::NotFound(
-            "Quantized op input node not found in Delete Quant fusion."));
+        platform::errors::NotFound("Quantized op input node(%s) did not find "
+                                   "in QuantDequantFuse pass.",
+                                   quantized_op_input->name()));
     Node* quantized_op_input_node = subgraph.at(quantized_op_input);
     Node* quantized_op_weight_node =
         subgraph.at(pattern.GetPDNode("quantized_op_weight"));
@@ -165,7 +168,7 @@ void FuseDequant(ir::Graph* graph, Scope* scope,
       PADDLE_ENFORCE_EQ(
           scales_name.size(), 2,
           platform::errors::InvalidArgument(
-              "Scales size in channel-wise dequantize op should be 2, got %d",
+              "Scales size in channel-wise dequantize op should be 2, got %d.",
              scales_name.size()));
       const LoDTensor& channel_scale_tensor =
           scope->FindVar(scales_name[0])->Get<LoDTensor>();
@@ -193,9 +196,10 @@ void FuseDequant(ir::Graph* graph, Scope* scope,
     bool valid_scale_size =
         (weight_scale.size() == 1 ||
          weight_scale.size() == static_cast<size_t>(w_dims[0]));
-    PADDLE_ENFORCE_EQ(valid_scale_size, true,
-                      platform::errors::InvalidArgument(
-                          "TRT int8 quant: invalid scale size"));
+    PADDLE_ENFORCE_EQ(
+        valid_scale_size, true,
+        platform::errors::InvalidArgument(
+            "TRT int8 quant: invalid scale size(%d).", weight_scale.size()));
     float* quantized_weight_data =
         weight_tensor->mutable_data<float>(platform::CPUPlace());
     for (int j = 0; j < weight_tensor->numel(); j++) {
...
@@ -278,11 +278,12 @@ static int BuildFusion(Graph* graph, const std::string& name_scope,
   auto retrieve_node = [](const std::string& name,
                           const GraphPatternDetector::subgraph_t& subgraph,
                           const PDPattern& pat) -> Node* {
-    PADDLE_ENFORCE(subgraph.count(pat.RetrieveNode(name)),
-                   "pattern has no Node called %s", name.c_str());
+    PADDLE_ENFORCE_GT(subgraph.count(pat.RetrieveNode(name)), 0,
+                      platform::errors::NotFound(
+                          "Pattern has no node called %s.", name.c_str()));
     Node* p = subgraph.at(pat.RetrieveNode(name));
-    PADDLE_ENFORCE_NOT_NULL(
-        p, platform::errors::NotFound("subgraph has no node %s", name.c_str()));
+    PADDLE_ENFORCE_NOT_NULL(p, platform::errors::NotFound(
+                                   "Subgraph has no node %s.", name.c_str()));
     return p;
   };
@@ -365,7 +366,8 @@ static int BuildFusion(Graph* graph, const std::string& name_scope,
 }

 void RepeatedFCReluFusePass::ApplyImpl(ir::Graph* graph) const {
-  PADDLE_ENFORCE_NOT_NULL(graph);
+  PADDLE_ENFORCE_NOT_NULL(
+      graph, platform::errors::InvalidArgument("Graph cannot be nullptr."));
   FusePassBase::Init(name_scope_, graph);

   int fusion_count = 0;
...
@@ -55,9 +55,15 @@ void TestMain(int num_fc) {
   VLOG(3) << DebugString(graph);

   // Delete (num_fc_nodes_before - 1) fc ops
-  PADDLE_ENFORCE_EQ(num_nodes_before - (num_fc_nodes_before - 1) + 1,
-                    num_nodes_after);
-  PADDLE_ENFORCE_EQ(num_fused_nodes_after, 1);
+  PADDLE_ENFORCE_EQ(
+      num_nodes_before - (num_fc_nodes_before - 1) + 1, num_nodes_after,
+      platform::errors::InvalidArgument(
+          "num_nodes_before = %d, num_fc_nodes_before = %d, num_nodes_after = "
+          "%d.",
+          num_nodes_before, num_fc_nodes_before, num_nodes_after));
+  PADDLE_ENFORCE_EQ(num_fused_nodes_after, 1,
+                    platform::errors::InvalidArgument(
+                        "num_fused_nodes_after = %d.", num_fused_nodes_after));
 }

 TEST(RepeatedFCReluFusePass, basic_3) { TestMain(3); }
...
@@ -186,10 +186,12 @@ void SeqConcatFcFusePass::ApplyImpl(ir::Graph* graph) const {
   BuildFCPattern(pattern, concat_out);

 #define GET_NODE(id, pattern)                                              \
-  PADDLE_ENFORCE(subgraph.count(pattern.RetrieveNode(#id)),                \
-                 "pattern has no Node called %s", #id);                    \
+  PADDLE_ENFORCE_GT(                                                       \
+      subgraph.count(pattern.RetrieveNode(#id)), 0,                        \
+      platform::errors::NotFound("Pattern has no node called %s.", #id));  \
   auto* id = subgraph.at(pattern.RetrieveNode(#id));                       \
-  PADDLE_ENFORCE_NOT_NULL(id, "subgraph has no node %s", #id);
+  PADDLE_ENFORCE_NOT_NULL(                                                 \
+      id, platform::errors::NotFound("Subgraph has no node %s.", #id));

   int fuse_count{0};
...
@@ -139,11 +139,12 @@ static int BuildFusion(Graph* graph, const std::string& name_scope,
   auto retrieve_node = [](const std::string& name,
                           const GraphPatternDetector::subgraph_t& subgraph,
                           const PDPattern& pat) -> Node* {
-    PADDLE_ENFORCE(subgraph.count(pat.RetrieveNode(name)),
-                   "pattern has no Node called %s", name.c_str());
+    PADDLE_ENFORCE_GT(subgraph.count(pat.RetrieveNode(name)), 0,
+                      platform::errors::NotFound(
+                          "Pattern has no node called %s.", name.c_str()));
     Node* p = subgraph.at(pat.RetrieveNode(name));
-    PADDLE_ENFORCE_NOT_NULL(
-        p, platform::errors::NotFound("subgraph has no node %s", name.c_str()));
+    PADDLE_ENFORCE_NOT_NULL(p, platform::errors::NotFound(
+                                   "Subgraph has no node %s.", name.c_str()));
     return p;
   };
...
@@ -47,7 +47,9 @@ void ShuffleChannelDetectPass::ApplyImpl(ir::Graph* graph) const {
                      Graph* g) {
     GET_NODES;

-    PADDLE_ENFORCE(subgraph.count(x));
+    PADDLE_ENFORCE_GT(
+        subgraph.count(x), 0,
+        platform::errors::NotFound("Detector did not find input X."));
     auto* input_node = subgraph.at(x);
     auto reshape1_desc = reshape1_op->Op();
     auto reshape2_desc = reshape2_op->Op();
...
@@ -59,12 +59,25 @@ TEST(SimplifyWithBasicOpsPass, dropout) {
     int num_scale_nodes_after = GetNumOpNodes(graph, "scale");
     VLOG(3) << DebugString(graph);

-    PADDLE_ENFORCE_EQ(num_dropout_nodes_after, 0);
+    PADDLE_ENFORCE_EQ(
+        num_dropout_nodes_after, 0,
+        platform::errors::InvalidArgument("num_dropout_nodes_after = %d.",
+                                          num_dropout_nodes_after));
     if (dropout_implementation == "downgrade_in_infer") {
-      PADDLE_ENFORCE_EQ(num_dropout_nodes_before,
-                        num_scale_nodes_after - num_scale_nodes_before);
+      PADDLE_ENFORCE_EQ(
+          num_dropout_nodes_before,
+          num_scale_nodes_after - num_scale_nodes_before,
+          platform::errors::InvalidArgument(
+              "num_dropout_nodes_before = %d, num_scale_nodes_after = %d, "
+              "num_scale_nodes_before = %d.",
+              num_dropout_nodes_before, num_scale_nodes_after,
+              num_scale_nodes_before));
     } else {
-      PADDLE_ENFORCE_EQ(num_scale_nodes_after - num_scale_nodes_before, 0);
+      PADDLE_ENFORCE_EQ(
+          num_scale_nodes_after - num_scale_nodes_before, 0,
+          platform::errors::InvalidArgument(
+              "num_scale_nodes_after = %d, num_scale_nodes_before = %d.",
+              num_scale_nodes_after, num_scale_nodes_before));
     }
   }
 }
...
@@ -300,10 +300,12 @@ static int BuildFusion(Graph* graph, const std::string& name_scope) {
   auto retrieve_node = [](const std::string& name,
                           const GraphPatternDetector::subgraph_t& subgraph,
                           const PDPattern& pat) -> Node* {
-    PADDLE_ENFORCE(subgraph.count(pat.RetrieveNode(name)),
-                   "pattern has no Node called %s", name.c_str());
+    PADDLE_ENFORCE_GT(subgraph.count(pat.RetrieveNode(name)), 0,
+                      platform::errors::NotFound(
+                          "Pattern has no node called %s.", name.c_str()));
     Node* p = subgraph.at(pat.RetrieveNode(name));
-    PADDLE_ENFORCE_NOT_NULL(p, "subgraph has no node %s", name.c_str());
+    PADDLE_ENFORCE_NOT_NULL(p, platform::errors::NotFound(
+                                   "Subgraph has no node %s.", name.c_str()));
     return p;
   };
...
@@ -51,15 +51,25 @@ void RunTransposeFlattenConcatFuse(ir::Graph *graph, int times) {
     std::vector<Node *> nodes;

     for (int i = 0; i < times; i++) {
-      PADDLE_ENFORCE(
-          subgraph.at(pattern.GetPDNode("transpose" + std::to_string(i))));
-      PADDLE_ENFORCE(
-          subgraph.at(pattern.GetPDNode("transpose_out" + std::to_string(i))));
-      PADDLE_ENFORCE(
-          subgraph.at(pattern.GetPDNode("flatten" + std::to_string(i))));
-      PADDLE_ENFORCE(
-          subgraph.at(pattern.GetPDNode("flatten_out" + std::to_string(i))));
-      PADDLE_ENFORCE(subgraph.at(input_nodes[i]));
+      PADDLE_ENFORCE_NOT_NULL(
+          subgraph.at(pattern.GetPDNode("transpose" + std::to_string(i))),
+          platform::errors::NotFound("Can not find transpose%d in subgraph.",
+                                     i));
+      PADDLE_ENFORCE_NOT_NULL(
+          subgraph.at(pattern.GetPDNode("transpose_out" + std::to_string(i))),
+          platform::errors::NotFound(
+              "Can not find transpose_out%d in subgraph.", i));
+      PADDLE_ENFORCE_NOT_NULL(
+          subgraph.at(pattern.GetPDNode("flatten" + std::to_string(i))),
+          platform::errors::NotFound("Can not find flatten%d in subgraph.", i));
+      PADDLE_ENFORCE_NOT_NULL(
+          subgraph.at(pattern.GetPDNode("flatten_out" + std::to_string(i))),
+          platform::errors::NotFound("Can not find flatten_out%d in subgraph.",
+                                     i));
+      PADDLE_ENFORCE_NOT_NULL(
+          subgraph.at(input_nodes[i]),
+          platform::errors::NotFound("Can not find %s in subgraph.",
+                                     input_nodes[i]->name()));

       nodes.push_back(subgraph.at(input_nodes[i]));
       nodes.push_back(
...