未验证 提交 1917b380 编写于 作者: W wanghuancoder 提交者: GitHub

fix some errmsg report,in framework/ir/, about 21 files (#25525)

* fix error msg report in ir/, about 19 files, test=develop

* modified some unclear descriptions, test=develop

* modified some unclear descriptions, test=develop

* modify unit test pass_test.cc, because the error report in pass.cc is used by pass_test.cc, test=develop
上级 4ec1251a
......@@ -320,7 +320,7 @@ std::vector<Node *> FuseBatchNormActPass::ReplaceNode(
return node;
});
PADDLE_ENFORCE_EQ(has_replaced, true,
platform::errors::NotFound("Not find %s in the node list.",
platform::errors::NotFound("Not found %s in the node list.",
cur_node->Name()));
return new_list;
}
......
......@@ -42,7 +42,8 @@ void FuseElewiseAddActPass::ApplyImpl(ir::Graph *graph) const {
// ele_add(x, act(y))
ir::Graph *FuseElewiseAddActPass::FuseElewiseAddAct(
ir::Graph *graph, const std::unordered_set<std::string> &act_types) const {
PADDLE_ENFORCE(graph);
PADDLE_ENFORCE_NOT_NULL(
graph, platform::errors::InvalidArgument("Graph cannot be nullptr."));
FusePassBase::Init("elewise_add_act", graph);
GraphPatternDetector gpd;
......@@ -93,7 +94,8 @@ ir::Graph *FuseElewiseAddActPass::FuseElewiseAddAct(
// act(ele_add(x,y))
ir::Graph *FuseElewiseAddActPass::FuseActElewiseAdd(
ir::Graph *graph, const std::unordered_set<std::string> &act_types) const {
PADDLE_ENFORCE(graph);
PADDLE_ENFORCE_NOT_NULL(
graph, platform::errors::InvalidArgument("Graph cannot be nullptr."));
FusePassBase::Init("act_elewise_add", graph);
GraphPatternDetector gpd;
......@@ -145,7 +147,8 @@ ir::Graph *FuseElewiseAddActPass::FuseActElewiseAdd(
// ele_add_grad: in["Y", "Out@GRAD"], out["X@GRAD", "Y@GRAD"]
ir::Graph *FuseElewiseAddActPass::FuseElewiseAddActInplaceGrad(
ir::Graph *graph, const std::unordered_set<std::string> &act_types) const {
PADDLE_ENFORCE(graph);
PADDLE_ENFORCE_NOT_NULL(
graph, platform::errors::InvalidArgument("Graph cannot be nullptr."));
FusePassBase::Init("elewise_add_act_grad", graph);
GraphPatternDetector gpd;
......@@ -252,10 +255,11 @@ void FuseElewiseAddActPass::RemoveIntermediateOut(Graph *graph) const {
bool save_intermediate_out = BOOST_GET_CONST(
bool, cur_node->Op()->GetAttr("save_intermediate_out"));
auto intermediate_out_args = cur_node->Op()->Output("IntermediateOut");
PADDLE_ENFORCE(
save_intermediate_out && !intermediate_out_args.empty(),
"The %s should save the intermediate_out in the fusing stage.",
cur_node->Name());
PADDLE_ENFORCE_EQ(
(save_intermediate_out && !intermediate_out_args.empty()), true,
platform::errors::InvalidArgument(
"The %s should save the intermediate out in the fusing stage.",
cur_node->Name()));
// If the intermediate_out's output is empty, it should be removed.
auto cur_node_outputs = cur_node->outputs;
......@@ -271,10 +275,11 @@ void FuseElewiseAddActPass::RemoveIntermediateOut(Graph *graph) const {
} else if (cur_node->Name() == "fused_elemwise_activation_grad") {
auto intermediate_out_grad_args =
cur_node->Op()->Output(GradVarName("IntermediateOut"));
PADDLE_ENFORCE(
!intermediate_out_grad_args.empty(),
"The %s should save the intermediate_out in the fusing stage.",
cur_node->Name());
PADDLE_ENFORCE_EQ(
intermediate_out_grad_args.empty(), false,
platform::errors::InvalidArgument(
"The %s should save the intermediate out in the fusing stage.",
cur_node->Name()));
auto cur_node_outputs = cur_node->outputs;
// If the intermediate_out_g's output is empty, it should be removed.
for (auto &out : cur_node_outputs) {
......@@ -312,7 +317,11 @@ void FuseElewiseAddActPass::ReLinkNodes(Graph *graph,
nodes2delete.emplace(out);
}
} else {
PADDLE_ENFORCE(out == intermediate_out);
PADDLE_ENFORCE_EQ(
out, intermediate_out,
platform::errors::InvalidArgument(
"Output of op(%s) must be %s, but not %s.", op_1->Name(),
intermediate_out->Name(), out->Name()));
IR_OP_VAR_LINK(fused_op, out);
}
}
......@@ -347,8 +356,9 @@ std::vector<Node *> FuseElewiseAddActPass::ReplaceNode(
}
return node;
});
PADDLE_ENFORCE(has_replaced, "Not find %s in the node list.",
cur_node->Name());
PADDLE_ENFORCE_EQ(has_replaced, true,
platform::errors::NotFound("Not found %s in the node list.",
cur_node->Name()));
return new_list;
}
......
......@@ -25,14 +25,19 @@ void FusePassBase::Init(const std::string& repr, Graph* graph) const {
}
Scope* FusePassBase::param_scope() const {
PADDLE_ENFORCE(graph_->Has(kParamScopeAttr));
PADDLE_ENFORCE_EQ(graph_->Has(kParamScopeAttr), true,
platform::errors::InvalidArgument(
"Graph must have kParamScopeAttr attribute."));
auto& scope = graph_->Get<framework::Scope>(kParamScopeAttr);
return &scope;
}
void FusePassBase::AddStatis(int count_of_fused) const {
PADDLE_ENFORCE(graph_);
PADDLE_ENFORCE(!repr_.empty());
PADDLE_ENFORCE_NOT_NULL(
graph_, platform::errors::InvalidArgument("Graph cannot be nullptr."));
PADDLE_ENFORCE_EQ(repr_.empty(), false,
platform::errors::InvalidArgument(
"Fuse pass must be initialized with a name."));
if (!graph_->Has(kFuseStatisAttr)) {
graph_->Set(kFuseStatisAttr, new std::unordered_map<std::string, int>);
}
......
......@@ -31,7 +31,8 @@ void FuseReluDepthwiseConvPass::ApplyImpl(ir::Graph *graph) const {
ir::Graph *FuseReluDepthwiseConvPass::FuseReluDepthwiseConv(
ir::Graph *graph, bool only_forward) const {
PADDLE_ENFORCE(graph);
PADDLE_ENFORCE_NOT_NULL(
graph, platform::errors::InvalidArgument("Graph cannot be nullptr."));
if (only_forward)
FusePassBase::Init("relu_depthwise_conv_only_forward", graph);
else
......@@ -110,23 +111,45 @@ ir::Graph *FuseReluDepthwiseConvPass::FuseReluDepthwiseConv(
xg_var = subgraph.at(xg)->Var();
}
PADDLE_ENFORCE_EQ(layer_op->Input("Input").size(), 1UL);
PADDLE_ENFORCE_EQ(layer_op->Input("Input")[0], y_var->Name());
PADDLE_ENFORCE_EQ(layer_op->Input("Input").size(), 1UL,
platform::errors::InvalidArgument(
"Op(%s)'s input size(%d) must be 1.",
layer_op->Type(), layer_op->Input("Input").size()));
PADDLE_ENFORCE_EQ(
layer_op->Input("Input")[0], y_var->Name(),
platform::errors::InvalidArgument(
"Op(%s)'s input name(%s) must be %s.", layer_op->Type(),
layer_op->Input("Input")[0], y_var->Name()));
layer_op->SetInput("Input", {x_var->Name()});
subgraph.at(layer)->inputs.push_back(subgraph.at(x));
subgraph.at(x)->outputs.push_back(subgraph.at(layer));
VLOG(4) << "replace " << y_var->Name() << " -> " << x_var->Name();
if (!only_forward) {
PADDLE_ENFORCE_EQ(layer_g_op->Input("Input").size(), 1UL);
PADDLE_ENFORCE_EQ(layer_g_op->Input("Input")[0], y_var->Name());
PADDLE_ENFORCE_EQ(
layer_g_op->Input("Input").size(), 1UL,
platform::errors::InvalidArgument(
"Op(%s)'s input size(%d) must be 1.", layer_g_op->Type(),
layer_g_op->Input("Input").size()));
PADDLE_ENFORCE_EQ(
layer_g_op->Input("Input")[0], y_var->Name(),
platform::errors::InvalidArgument(
"Op(%s)'s input name(%s) must be %s.", layer_g_op->Type(),
layer_g_op->Input("Input")[0], y_var->Name()));
layer_g_op->SetInput("Input", {x_var->Name()});
subgraph.at(layer_g)->inputs.push_back(subgraph.at(x));
subgraph.at(x)->outputs.push_back(subgraph.at(layer_g));
PADDLE_ENFORCE_EQ(layer_g_op->Output(GradVarName("Input")).size(), 1UL);
PADDLE_ENFORCE_EQ(layer_g_op->Output(GradVarName("Input"))[0],
yg_var->Name());
PADDLE_ENFORCE_EQ(
layer_g_op->Output(GradVarName("Input")).size(), 1UL,
platform::errors::InvalidArgument(
"Op(%s)'s input size(%d) must be 1.", layer_g_op->Type(),
layer_g_op->Output(GradVarName("Input")).size()));
PADDLE_ENFORCE_EQ(
layer_g_op->Output(GradVarName("Input"))[0], yg_var->Name(),
platform::errors::InvalidArgument(
"Op(%s)'s input name(%s) must be %s.", layer_g_op->Type(),
layer_g_op->Output(GradVarName("Input"))[0], yg_var->Name()));
layer_g_op->SetOutput(GradVarName("Input"), {xg_var->Name()});
subgraph.at(layer_g)->outputs.push_back(subgraph.at(xg));
subgraph.at(xg)->inputs.push_back(subgraph.at(layer_g));
......
......@@ -136,7 +136,9 @@ bool FindCircleSubGraph(const Graph &graph,
std::vector<ir::Node *> TopologySortOperations(const Graph &graph) {
std::map<ir::Node *, std::set<ir::Node *, ir::NodeComp>, ir::NodeComp>
adj_list = BuildOperationAdjList(graph);
PADDLE_ENFORCE(!HasCircleInternal(adj_list, nullptr));
PADDLE_ENFORCE_EQ(HasCircleInternal(adj_list, nullptr), false,
platform::errors::InvalidArgument(
"Generated graph shouldn't contain cycle."));
std::unordered_set<ir::Node *> visited;
std::vector<ir::Node *> ret;
for (auto adj : adj_list) {
......@@ -161,7 +163,11 @@ BuildOperationAdjList(const Graph &graph) {
}
for (auto &var : n->inputs) {
for (auto &adj_n : var->inputs) {
PADDLE_ENFORCE(adj_n->NodeType() == ir::Node::Type::kOperation);
PADDLE_ENFORCE_EQ(
adj_n->NodeType(), ir::Node::Type::kOperation,
platform::errors::InvalidArgument(
"Node(%s)'s type(%d) must be kOperation type.", adj_n->Name(),
static_cast<int>(adj_n->NodeType())));
VLOG(4) << "adj " << adj_n->Name() << reinterpret_cast<void *>(adj_n)
<< " -> " << n->Name() << reinterpret_cast<void *>(n)
<< " via " << var->Name() << reinterpret_cast<void *>(var);
......@@ -184,7 +190,11 @@ std::map<ir::Node *, std::unordered_set<ir::Node *>> BuildOperationOutAdjList(
}
for (auto &var : n->outputs) {
for (auto &adj_n : var->outputs) {
PADDLE_ENFORCE(adj_n->NodeType() == ir::Node::Type::kOperation);
PADDLE_ENFORCE_EQ(
adj_n->NodeType(), ir::Node::Type::kOperation,
platform::errors::InvalidArgument(
"Node(%s)'s type(%d) must be kOperation type.", adj_n->Name(),
static_cast<int>(adj_n->NodeType())));
VLOG(40) << "adj " << adj_n->Name() << reinterpret_cast<void *>(adj_n)
<< " -> " << n->Name() << reinterpret_cast<void *>(n)
<< " via " << var->Name() << reinterpret_cast<void *>(var);
......@@ -359,7 +369,10 @@ size_t GraphNum(const Graph &graph) {
}
std::unique_ptr<std::ostream> fout(
new std::ofstream(FLAGS_print_sub_graph_dir));
PADDLE_ENFORCE(fout->good());
PADDLE_ENFORCE_EQ(fout->good(), true,
platform::errors::Unavailable(
"Can not open file %s for printing the graph.",
FLAGS_print_sub_graph_dir));
*fout << out.str();
}
}
......
......@@ -33,7 +33,8 @@ const char kSumGradOpName[] = "sum";
const char kOptimizerType[] = "sgd";
void LockFreeOptimizePass::ApplyImpl(ir::Graph* graph) const {
PADDLE_ENFORCE(graph);
PADDLE_ENFORCE_NOT_NULL(
graph, platform::errors::InvalidArgument("Graph cannot be nullptr."));
// We could collect all weights' name from SGD, where
// W1 <- SGD(W0, Grad0)
......@@ -41,7 +42,10 @@ void LockFreeOptimizePass::ApplyImpl(ir::Graph* graph) const {
for (auto* node : graph->Nodes()) {
if (IsOpNamed(node, kOptimizerType)) {
auto& param_out_vars = node->Op()->Output("ParamOut");
PADDLE_ENFORCE(param_out_vars.size() == 1u);
PADDLE_ENFORCE_EQ(
param_out_vars.size(), 1u,
platform::errors::InvalidArgument(
"In op(%s), find output(ParamOut) failed.", node->Name()));
weight_var_set.insert(param_out_vars[0]);
}
}
......@@ -95,12 +99,19 @@ void LockFreeOptimizePass::ApplyImpl(ir::Graph* graph) const {
VLOG(3) << "Found forward_op " << forward_op->Name();
PADDLE_ENFORCE(forward_op);
PADDLE_ENFORCE_NOT_NULL(
forward_op, platform::errors::NotFound(
"Can not find forward op for backword op(%s).",
backward_op->Name()));
Node* new_optimizer_node = CreateNewSGDNode(
graph, forward_op, backward_op, node, opt_node);
PADDLE_ENFORCE(new_optimizer_node);
PADDLE_ENFORCE_NOT_NULL(
new_optimizer_node,
platform::errors::InvalidArgument(
"Create new SGD node failed, backward op is %s.",
backward_op->Name()));
}
}
}
......@@ -144,11 +155,21 @@ void LockFreeOptimizePass::ApplyImpl(ir::Graph* graph) const {
ir::Node* LockFreeOptimizePass::CreateNewSGDNode(
ir::Graph* graph, ir::Node* forward_node, ir::Node* backward_node,
ir::Node* grad_sum_node, ir::Node* optimize_node) const {
PADDLE_ENFORCE(graph);
PADDLE_ENFORCE(forward_node);
PADDLE_ENFORCE(backward_node);
PADDLE_ENFORCE(grad_sum_node);
PADDLE_ENFORCE(optimize_node);
PADDLE_ENFORCE_NOT_NULL(graph,
platform::errors::InvalidArgument(
"Input argument graph cannot be nullptr."));
PADDLE_ENFORCE_NOT_NULL(
forward_node, platform::errors::InvalidArgument(
"Input argument forward_node cannot be nullptr."));
PADDLE_ENFORCE_NOT_NULL(
backward_node, platform::errors::InvalidArgument(
"Input argument backward_node cannot be nullptr."));
PADDLE_ENFORCE_NOT_NULL(
grad_sum_node, platform::errors::InvalidArgument(
"Input argument grad_sum_node cannot be nullptr."));
PADDLE_ENFORCE_NOT_NULL(
optimize_node, platform::errors::InvalidArgument(
"Input argument optimize_node cannot be nullptr."));
// find the grad var node between the grad sum node and backward_node
std::vector<ir::Node*> grad_vars =
......@@ -159,7 +180,8 @@ ir::Node* LockFreeOptimizePass::CreateNewSGDNode(
grad_node = node;
}
}
PADDLE_ENFORCE(grad_node);
PADDLE_ENFORCE_NOT_NULL(grad_node, platform::errors::NotFound(
"Can not find control dep variable."));
// create a new SGD node
OpDesc* old_desc = optimize_node->Op();
......@@ -212,8 +234,14 @@ ir::Node* LockFreeOptimizePass::CreateNewSGDNode(
}
// SGD must have only one param and LR in
PADDLE_ENFORCE(old_desc->Input("LearningRate").size() == 1u);
PADDLE_ENFORCE(old_desc->Input("Param").size() == 1u);
PADDLE_ENFORCE_EQ(
old_desc->Input("LearningRate").size(), 1u,
platform::errors::InvalidArgument(
"In op(%s), find input(LearningRate) failed.", old_desc->Type()));
PADDLE_ENFORCE_EQ(
old_desc->Input("Param").size(), 1u,
platform::errors::InvalidArgument("In op(%s), find input(Param) failed.",
old_desc->Type()));
// LR and weight nodes should be copied
for (Node* upstream_node : optimize_node->inputs) {
......@@ -245,9 +273,17 @@ std::vector<ir::Node*> LockFreeOptimizePass::FindConnectedNode(
void LockFreeOptimizePass::ReplaceUpstreamNode(
ir::Node* upstream_node, ir::Node* old_optimizer_node,
ir::Node* new_optimizer_node) const {
PADDLE_ENFORCE(upstream_node);
PADDLE_ENFORCE(old_optimizer_node);
PADDLE_ENFORCE(new_optimizer_node);
PADDLE_ENFORCE_NOT_NULL(
upstream_node, platform::errors::InvalidArgument(
"Input argument upstream_node cannot be nullptr."));
PADDLE_ENFORCE_NOT_NULL(
old_optimizer_node,
platform::errors::InvalidArgument(
"Input argument old_optimizer_node cannot be nullptr."));
PADDLE_ENFORCE_NOT_NULL(
new_optimizer_node,
platform::errors::InvalidArgument(
"Input argument new_optimizer_node cannot be nullptr."));
// Remove the old_optimizer_node from upstream_node's outputs vector
auto& output_node_vec = upstream_node->outputs;
......@@ -268,8 +304,14 @@ void LockFreeOptimizePass::ReplaceUpstreamNode(
void LockFreeOptimizePass::ReplaceAllDownstreamNode(
ir::Node* old_optimizer_node, ir::Node* new_optimizer_node) const {
PADDLE_ENFORCE(old_optimizer_node);
PADDLE_ENFORCE(new_optimizer_node);
PADDLE_ENFORCE_NOT_NULL(
old_optimizer_node,
platform::errors::InvalidArgument(
"Input argument old_optimizer_node cannot be nullptr."));
PADDLE_ENFORCE_NOT_NULL(
new_optimizer_node,
platform::errors::InvalidArgument(
"Input argument new_optimizer_node cannot be nullptr."));
for (ir::Node* downstream_node : old_optimizer_node->outputs) {
// Remove the old_optimizer_node from downstream_node's inputs vector
......@@ -292,8 +334,12 @@ void LockFreeOptimizePass::ReplaceAllDownstreamNode(
ir::Node* LockFreeOptimizePass::FindForwardOpViaBackwardOp(
ir::Graph* graph, ir::Node* backward_node) const {
PADDLE_ENFORCE(graph);
PADDLE_ENFORCE(backward_node);
PADDLE_ENFORCE_NOT_NULL(graph,
platform::errors::InvalidArgument(
"Input argument graph cannot be nullptr."));
PADDLE_ENFORCE_NOT_NULL(
backward_node, platform::errors::InvalidArgument(
"Input argument backward_node cannot be nullptr."));
// strip the suffix _grad of backward_node's name
std::string forward_op_name = backward_node->Name();
......
......@@ -87,34 +87,46 @@ class LockFreeOptimizePass : public Pass {
ir::Node* downstream_node) const;
inline bool IsOpNamed(ir::Node* node, const std::string& name) const {
PADDLE_ENFORCE(node);
PADDLE_ENFORCE_NOT_NULL(node,
platform::errors::InvalidArgument(
"Input argument node cannot be nullptr."));
return node->NodeType() == Node::Type::kOperation && node->Name() == name;
}
inline bool IsVarNamed(ir::Node* node, const std::string& name) const {
PADDLE_ENFORCE(node);
PADDLE_ENFORCE_NOT_NULL(node,
platform::errors::InvalidArgument(
"Input argument node cannot be nullptr."));
return node->NodeType() == Node::Type::kVariable && node->Name() == name;
}
inline bool IsVarNameEndsWith(ir::Node* node, const std::string& name) const {
PADDLE_ENFORCE(node);
PADDLE_ENFORCE_NOT_NULL(node,
platform::errors::InvalidArgument(
"Input argument node cannot be nullptr."));
return node->NodeType() == Node::Type::kVariable &&
boost::algorithm::ends_with(node->Name(), name);
}
inline bool IsVarNameContains(ir::Node* node, const std::string& name) const {
PADDLE_ENFORCE(node);
PADDLE_ENFORCE_NOT_NULL(node,
platform::errors::InvalidArgument(
"Input argument node cannot be nullptr."));
return node->NodeType() == Node::Type::kVariable &&
node->Name().find(name) != std::string::npos;
}
inline bool IsControlDepFrom(ir::Node* ctrl_dep_node, ir::Node* node) const {
PADDLE_ENFORCE(ctrl_dep_node);
PADDLE_ENFORCE(node);
PADDLE_ENFORCE_NOT_NULL(
ctrl_dep_node, platform::errors::InvalidArgument(
"Input argument ctrl_dep_node cannot be nullptr."));
PADDLE_ENFORCE_NOT_NULL(node,
platform::errors::InvalidArgument(
"Input argument node cannot be nullptr."));
return IsControlDepVar(*ctrl_dep_node) &&
ctrl_dep_node->inputs.size() >= 1u &&
......
......@@ -66,12 +66,18 @@ class Node {
std::string Name() const { return name_; }
VarDesc* Var() const {
PADDLE_ENFORCE_EQ(IsVar(), true);
PADDLE_ENFORCE_EQ(IsVar(), true,
platform::errors::InvalidArgument(
"Node(%s) must be kVariable type, but not %d.", name_,
static_cast<int>(type_)));
return var_desc_.get();
}
OpDesc* Op() const {
PADDLE_ENFORCE_EQ(IsOp(), true);
PADDLE_ENFORCE_EQ(IsOp(), true,
platform::errors::InvalidArgument(
"Node(%s) must be kOperation type, but not %d.",
name_, static_cast<int>(type_)));
return op_desc_.get();
}
......@@ -92,8 +98,9 @@ class Node {
try {
return *boost::any_cast<T*>(wrapper_);
} catch (boost::bad_any_cast&) {
PADDLE_THROW("Invalid wrapper type error, expected %s, actual %s",
typeid(T).name(), wrapper_type_.name());
PADDLE_THROW(platform::errors::InvalidArgument(
"Invalid wrapper type error, expected %s, actual %s.",
typeid(T).name(), wrapper_type_.name()));
}
}
......@@ -114,8 +121,9 @@ class Node {
}
void RenameVar(const std::string& new_name) {
PADDLE_ENFORCE(type_ == Type::kVariable && var_desc_,
"Must be type of variable");
PADDLE_ENFORCE_EQ(
type_ == Type::kVariable && var_desc_, true,
platform::errors::InvalidArgument("Node must be type of variable."));
name_ = new_name;
var_desc_->SetName(new_name);
}
......
......@@ -26,7 +26,8 @@ namespace ir {
Graph* Pass::Apply(Graph* graph) const {
CheckPrevPass();
PADDLE_ENFORCE(graph, "graph passed to Pass::Apply() cannot be empty.");
PADDLE_ENFORCE_NOT_NULL(
graph, platform::errors::InvalidArgument("Graph cannot be nullptr."));
for (const std::string& attr : required_pass_attrs_) {
PADDLE_ENFORCE_NE(
attrs_.find(attr), attrs_.end(),
......@@ -40,11 +41,14 @@ Graph* Pass::Apply(Graph* graph) const {
}
ApplyImpl(graph);
// TODO(panyx0718): Add more verifications.
PADDLE_ENFORCE(!HasCircle(*graph),
"Illegal Pass %s. Generated graph shouldn't have cycle.",
Type());
PADDLE_ENFORCE(VarDescIsConsistency(*graph),
"The VarDescs of persistable variable are not consistency.");
PADDLE_ENFORCE_EQ(
HasCircle(*graph), false,
platform::errors::InvalidArgument(
"Illegal pass %s. Generated graph shouldn't contain cycle.", Type()));
PADDLE_ENFORCE_EQ(
VarDescIsConsistency(*graph), true,
platform::errors::InvalidArgument(
"The VarDescs of persistable variable are not consistency."));
applied_ = true;
if (!graph->Has(kPassRecorder)) {
graph->Set<PassRecorder>(kPassRecorder, new PassRecorder);
......
......@@ -55,8 +55,9 @@ class Pass {
// Get a reference to the attributed previously set.
template <typename AttrType>
AttrType &Get(const std::string &attr_name) const {
PADDLE_ENFORCE(attrs_.find(attr_name) != attrs_.end(),
"%s attr not registered for pass.", attr_name);
PADDLE_ENFORCE_NE(attrs_.find(attr_name), attrs_.end(),
platform::errors::InvalidArgument(
"Attribute %s not registered for pass.", attr_name));
try {
return *boost::any_cast<AttrType *>(attrs_.at(attr_name));
} catch (boost::bad_any_cast &) {
......@@ -76,7 +77,7 @@ class Pass {
};
PADDLE_THROW(platform::errors::InvalidArgument(
"Invalid type for attritube %s, expected: %s, actual: %s", attr_name,
"Invalid type for attritube %s, expected: %s, actual: %s.", attr_name,
TypeToString(typeid(AttrType *)),
TypeToString(attrs_.at(attr_name).type())));
}
......@@ -101,9 +102,10 @@ class Pass {
template <typename AttrType>
void Set(const std::string &attr_name, AttrType *attr) {
if (default_pass_attrs_.count(attr_name) == 0) {
PADDLE_ENFORCE_EQ(attrs_.count(attr_name), 0,
platform::errors::InvalidArgument(
"Attribute %s already set in the pass", attr_name));
PADDLE_ENFORCE_EQ(
attrs_.count(attr_name), 0,
platform::errors::AlreadyExists(
"Attribute %s already set in the pass.", attr_name));
} else {
VLOG(3) << "Setting the attribute " << attr_name << " for the pass "
<< type_;
......@@ -119,15 +121,16 @@ class Pass {
// should delete the attribute.
template <typename AttrType>
void SetNotOwned(const std::string &attr_name, AttrType *attr) {
PADDLE_ENFORCE(attrs_.count(attr_name) == 0, "%s already set in the pass",
attr_name);
PADDLE_ENFORCE_EQ(attrs_.count(attr_name), 0,
platform::errors::AlreadyExists(
"Attribute %s already set in the pass.", attr_name));
attrs_[attr_name] = attr;
}
protected:
virtual void ApplyImpl(Graph *graph) const {
PADDLE_THROW(platform::errors::Unimplemented(
"The virtual Pass called is not implemented."));
"The virtual pass called is not implemented."));
}
// Some Pass must be placed before this Pass, and some
......@@ -198,8 +201,9 @@ class PassRegistry {
}
std::unique_ptr<Pass> Get(const std::string &pass_type) const {
PADDLE_ENFORCE(Has(pass_type), "Pass %s has not been registered",
pass_type);
PADDLE_ENFORCE_EQ(Has(pass_type), true,
platform::errors::InvalidArgument(
"Pass %s has not been registered.", pass_type));
return map_.at(pass_type)();
}
......@@ -213,8 +217,10 @@ class PassRegistry {
template <typename PassType>
struct PassRegistrar : public Registrar {
explicit PassRegistrar(const char *pass_type) {
PADDLE_ENFORCE(!PassRegistry::Instance().Has(pass_type),
"'%s' is registered more than once.", pass_type);
PADDLE_ENFORCE_EQ(
PassRegistry::Instance().Has(pass_type), false,
platform::errors::AlreadyExists(
"Pass '%s' is registered more than once.", pass_type));
PassRegistry::Instance().Insert(
pass_type, [this, pass_type]() -> std::unique_ptr<Pass> {
std::unique_ptr<Pass> pass(new PassType());
......
......@@ -28,13 +28,19 @@ std::shared_ptr<Pass> PassBuilder::AppendPass(const std::string& pass_type) {
}
void PassBuilder::RemovePass(size_t idx) {
PADDLE_ENFORCE(passes_.size() > idx);
PADDLE_ENFORCE_GT(
passes_.size(), idx,
platform::errors::InvalidArgument(
"Passes size is %d, %d is not a valid index.", passes_.size(), idx));
passes_.erase(passes_.begin() + idx);
}
std::shared_ptr<Pass> PassBuilder::InsertPass(size_t idx,
const std::string& pass_type) {
PADDLE_ENFORCE(passes_.size() >= idx);
PADDLE_ENFORCE_GE(
passes_.size(), idx,
platform::errors::InvalidArgument(
"Passes size is %d, %d is not a valid index.", passes_.size(), idx));
std::shared_ptr<Pass> pass(
ir::PassRegistry::Instance().Get(pass_type).release());
passes_.insert(passes_.begin() + idx, std::move(pass));
......
......@@ -119,7 +119,7 @@ TEST(PassTest, TestPassAttrCheck) {
} catch (paddle::platform::EnforceNotMet& e) {
exception = std::string(e.what());
}
ASSERT_TRUE(exception.find("shouldn't have cycle") != exception.npos);
ASSERT_TRUE(exception.find("shouldn't contain cycle") != exception.npos);
pass = PassRegistry::Instance().Get("test_pass");
pass->Set<int>("test_pass_attr", new int);
......
......@@ -43,9 +43,11 @@ void DeleteQuant(ir::Graph* graph, Scope* scope,
// ops linked from it
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
Graph* g) {
PADDLE_ENFORCE_EQ(subgraph.count(input_act_node), true,
PADDLE_ENFORCE_EQ(
subgraph.count(input_act_node), true,
platform::errors::NotFound(
"Input act node not found in Delete Quant fusion."));
"Input act node(%s) not found in QuantDequantFuse pass.",
input_act_node->name()));
Node* input_act = subgraph.at(input_act_node);
Node* input_scale = subgraph.at(pattern.GetPDNode("input_scale_node"));
Node* quant = subgraph.at(pattern.GetPDNode("quant_node"));
......@@ -58,7 +60,7 @@ void DeleteQuant(ir::Graph* graph, Scope* scope,
std::string input_scale_var_name = quant->Op()->Input("InScale").front();
PADDLE_ENFORCE_NOT_NULL(
scope, platform::errors::InvalidArgument(
"scope in DeleteQuantOpFuse pass should not be null."));
"Scope in QuantDequantFuse pass should not be null."));
const LoDTensor& input_scale_tensor =
scope->FindVar(input_scale_var_name)->Get<LoDTensor>();
PADDLE_ENFORCE_EQ(
......@@ -84,8 +86,8 @@ void DeleteQuant(ir::Graph* graph, Scope* scope,
} else if (quantized_op_type == "mul") {
op_desc->SetAttr("X_scale", scale_value);
} else {
PADDLE_THROW(platform::errors::InvalidArgument(
"Unsupported quantized op type %s", quantized_op_type));
PADDLE_THROW(platform::errors::Unimplemented(
"Unsupported quantized op type %s.", quantized_op_type));
}
op_desc->SetAttr("bit_length", bit_length);
op_desc->RenameInput(output_act_name, input_act_name);
......@@ -119,9 +121,9 @@ void FuseDequant(ir::Graph* graph, Scope* scope,
weight_name = "W";
input_name = "Input";
} else {
PADDLE_ENFORCE(
PADDLE_THROW(platform::errors::Unimplemented(
"QuantDequantFuse: We only support conv2d, conv2d_fusion, fc, mul for "
"now.");
"now."));
}
const std::string pattern_name = "dequant_fuse";
GraphPatternDetector gpd;
......@@ -141,8 +143,9 @@ void FuseDequant(ir::Graph* graph, Scope* scope,
Graph* g) {
PADDLE_ENFORCE_EQ(
subgraph.count(quantized_op_input), true,
platform::errors::NotFound(
"Quantized op input node not found in Delete Quant fusion."));
platform::errors::NotFound("Quantized op input node(%s) did not find "
"in QuantDequantFuse pass.",
quantized_op_input->name()));
Node* quantized_op_input_node = subgraph.at(quantized_op_input);
Node* quantized_op_weight_node =
subgraph.at(pattern.GetPDNode("quantized_op_weight"));
......@@ -165,7 +168,7 @@ void FuseDequant(ir::Graph* graph, Scope* scope,
PADDLE_ENFORCE_EQ(
scales_name.size(), 2,
platform::errors::InvalidArgument(
"Scales size in channel-wise dequantize op should be 2, got %d",
"Scales size in channel-wise dequantize op should be 2, got %d.",
scales_name.size()));
const LoDTensor& channel_scale_tensor =
scope->FindVar(scales_name[0])->Get<LoDTensor>();
......@@ -193,9 +196,10 @@ void FuseDequant(ir::Graph* graph, Scope* scope,
bool valid_scale_size =
(weight_scale.size() == 1 ||
weight_scale.size() == static_cast<size_t>(w_dims[0]));
PADDLE_ENFORCE_EQ(valid_scale_size, true,
PADDLE_ENFORCE_EQ(
valid_scale_size, true,
platform::errors::InvalidArgument(
"TRT int8 quant: invalid scale size"));
"TRT int8 quant: invalid scale size(%d).", weight_scale.size()));
float* quantized_weight_data =
weight_tensor->mutable_data<float>(platform::CPUPlace());
for (int j = 0; j < weight_tensor->numel(); j++) {
......
......@@ -278,11 +278,12 @@ static int BuildFusion(Graph* graph, const std::string& name_scope,
auto retrieve_node = [](const std::string& name,
const GraphPatternDetector::subgraph_t& subgraph,
const PDPattern& pat) -> Node* {
PADDLE_ENFORCE(subgraph.count(pat.RetrieveNode(name)),
"pattern has no Node called %s", name.c_str());
PADDLE_ENFORCE_GT(subgraph.count(pat.RetrieveNode(name)), 0,
platform::errors::NotFound(
"Pattern has no node called %s.", name.c_str()));
Node* p = subgraph.at(pat.RetrieveNode(name));
PADDLE_ENFORCE_NOT_NULL(
p, platform::errors::NotFound("subgraph has no node %s", name.c_str()));
PADDLE_ENFORCE_NOT_NULL(p, platform::errors::NotFound(
"Subgraph has no node %s.", name.c_str()));
return p;
};
......@@ -365,7 +366,8 @@ static int BuildFusion(Graph* graph, const std::string& name_scope,
}
void RepeatedFCReluFusePass::ApplyImpl(ir::Graph* graph) const {
PADDLE_ENFORCE_NOT_NULL(graph);
PADDLE_ENFORCE_NOT_NULL(
graph, platform::errors::InvalidArgument("Graph cannot be nullptr."));
FusePassBase::Init(name_scope_, graph);
int fusion_count = 0;
......
......@@ -55,9 +55,15 @@ void TestMain(int num_fc) {
VLOG(3) << DebugString(graph);
// Delete (num_fc_nodes_before - 1) fc ops
PADDLE_ENFORCE_EQ(num_nodes_before - (num_fc_nodes_before - 1) + 1,
num_nodes_after);
PADDLE_ENFORCE_EQ(num_fused_nodes_after, 1);
PADDLE_ENFORCE_EQ(
num_nodes_before - (num_fc_nodes_before - 1) + 1, num_nodes_after,
platform::errors::InvalidArgument(
"num_nodes_before = %d, num_fc_nodes_before = %d, num_nodes_after = "
"%d.",
num_nodes_before, num_fc_nodes_before, num_nodes_after));
PADDLE_ENFORCE_EQ(num_fused_nodes_after, 1,
platform::errors::InvalidArgument(
"num_fused_nodes_after = %d.", num_fused_nodes_after));
}
TEST(RepeatedFCReluFusePass, basic_3) { TestMain(3); }
......
......@@ -186,10 +186,12 @@ void SeqConcatFcFusePass::ApplyImpl(ir::Graph* graph) const {
BuildFCPattern(pattern, concat_out);
#define GET_NODE(id, pattern) \
PADDLE_ENFORCE(subgraph.count(pattern.RetrieveNode(#id)), \
"pattern has no Node called %s", #id); \
PADDLE_ENFORCE_GT( \
subgraph.count(pattern.RetrieveNode(#id)), 0, \
platform::errors::NotFound("Pattern has no node called %s.", #id)); \
auto* id = subgraph.at(pattern.RetrieveNode(#id)); \
PADDLE_ENFORCE_NOT_NULL(id, "subgraph has no node %s", #id);
PADDLE_ENFORCE_NOT_NULL( \
id, platform::errors::NotFound("Subgraph has no node %s.", #id));
int fuse_count{0};
......
......@@ -139,11 +139,12 @@ static int BuildFusion(Graph* graph, const std::string& name_scope,
auto retrieve_node = [](const std::string& name,
const GraphPatternDetector::subgraph_t& subgraph,
const PDPattern& pat) -> Node* {
PADDLE_ENFORCE(subgraph.count(pat.RetrieveNode(name)),
"pattern has no Node called %s", name.c_str());
PADDLE_ENFORCE_GT(subgraph.count(pat.RetrieveNode(name)), 0,
platform::errors::NotFound(
"Pattern has no node called %s.", name.c_str()));
Node* p = subgraph.at(pat.RetrieveNode(name));
PADDLE_ENFORCE_NOT_NULL(
p, platform::errors::NotFound("subgraph has no node %s", name.c_str()));
PADDLE_ENFORCE_NOT_NULL(p, platform::errors::NotFound(
"Subgraph has no node %s.", name.c_str()));
return p;
};
......
......@@ -47,7 +47,9 @@ void ShuffleChannelDetectPass::ApplyImpl(ir::Graph* graph) const {
Graph* g) {
GET_NODES;
PADDLE_ENFORCE(subgraph.count(x));
PADDLE_ENFORCE_GT(
subgraph.count(x), 0,
platform::errors::NotFound("Detector did not find input X."));
auto* input_node = subgraph.at(x);
auto reshape1_desc = reshape1_op->Op();
auto reshape2_desc = reshape2_op->Op();
......
......@@ -59,12 +59,25 @@ TEST(SimplifyWithBasicOpsPass, dropout) {
int num_scale_nodes_after = GetNumOpNodes(graph, "scale");
VLOG(3) << DebugString(graph);
PADDLE_ENFORCE_EQ(num_dropout_nodes_after, 0);
PADDLE_ENFORCE_EQ(
num_dropout_nodes_after, 0,
platform::errors::InvalidArgument("num_dropout_nodes_after = %d.",
num_dropout_nodes_after));
if (dropout_implementation == "downgrade_in_infer") {
PADDLE_ENFORCE_EQ(num_dropout_nodes_before,
num_scale_nodes_after - num_scale_nodes_before);
PADDLE_ENFORCE_EQ(
num_dropout_nodes_before,
num_scale_nodes_after - num_scale_nodes_before,
platform::errors::InvalidArgument(
"num_dropout_nodes_before = %d, num_scale_nodes_after = %d, "
"num_scale_nodes_before = %d.",
num_dropout_nodes_before, num_scale_nodes_after,
num_scale_nodes_before));
} else {
PADDLE_ENFORCE_EQ(num_scale_nodes_after - num_scale_nodes_before, 0);
PADDLE_ENFORCE_EQ(
num_scale_nodes_after - num_scale_nodes_before, 0,
platform::errors::InvalidArgument(
"num_scale_nodes_after = %d, num_scale_nodes_before = %d.",
num_scale_nodes_after, num_scale_nodes_before));
}
}
}
......
......@@ -300,10 +300,12 @@ static int BuildFusion(Graph* graph, const std::string& name_scope) {
auto retrieve_node = [](const std::string& name,
const GraphPatternDetector::subgraph_t& subgraph,
const PDPattern& pat) -> Node* {
PADDLE_ENFORCE(subgraph.count(pat.RetrieveNode(name)),
"pattern has no Node called %s", name.c_str());
PADDLE_ENFORCE_GT(subgraph.count(pat.RetrieveNode(name)), 0,
platform::errors::NotFound(
"Pattern has no node called %s.", name.c_str()));
Node* p = subgraph.at(pat.RetrieveNode(name));
PADDLE_ENFORCE_NOT_NULL(p, "subgraph has no node %s", name.c_str());
PADDLE_ENFORCE_NOT_NULL(p, platform::errors::NotFound(
"Subgraph has no node %s.", name.c_str()));
return p;
};
......
......@@ -51,15 +51,25 @@ void RunTransposeFlattenConcatFuse(ir::Graph *graph, int times) {
std::vector<Node *> nodes;
for (int i = 0; i < times; i++) {
PADDLE_ENFORCE(
subgraph.at(pattern.GetPDNode("transpose" + std::to_string(i))));
PADDLE_ENFORCE(
subgraph.at(pattern.GetPDNode("transpose_out" + std::to_string(i))));
PADDLE_ENFORCE(
subgraph.at(pattern.GetPDNode("flatten" + std::to_string(i))));
PADDLE_ENFORCE(
subgraph.at(pattern.GetPDNode("flatten_out" + std::to_string(i))));
PADDLE_ENFORCE(subgraph.at(input_nodes[i]));
PADDLE_ENFORCE_NOT_NULL(
subgraph.at(pattern.GetPDNode("transpose" + std::to_string(i))),
platform::errors::NotFound("Can not find transpose%d in subgraph.",
i));
PADDLE_ENFORCE_NOT_NULL(
subgraph.at(pattern.GetPDNode("transpose_out" + std::to_string(i))),
platform::errors::NotFound(
"Can not find transpose_out%d in subgraph.", i));
PADDLE_ENFORCE_NOT_NULL(
subgraph.at(pattern.GetPDNode("flatten" + std::to_string(i))),
platform::errors::NotFound("Can not find flatten%d in subgraph.", i));
PADDLE_ENFORCE_NOT_NULL(
subgraph.at(pattern.GetPDNode("flatten_out" + std::to_string(i))),
platform::errors::NotFound("Can not find flatten_out%d in subgraph.",
i));
PADDLE_ENFORCE_NOT_NULL(
subgraph.at(input_nodes[i]),
platform::errors::NotFound("Can not find %s in subgraph.",
input_nodes[i]->name()));
nodes.push_back(subgraph.at(input_nodes[i]));
nodes.push_back(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册