Unverified Commit 532079a2 authored by Chen Weihang, committed by GitHub

API (CompiledProgram) error message enhancement (#23559)

* api compiled program error polish, test=develop

* fix coverage problem, test=develop

* fix details & add unittests, test=develop

* add test for coverage, test=develop
Parent 73f421f7
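Before the diff, a quick illustration of the user-visible effect on the Python side; a minimal sketch mirroring the new unit test at the bottom of this commit (the string argument is just an arbitrary wrong-typed value):

import paddle.fluid as fluid

# Minimal sketch: constructing a CompiledProgram from something that is
# neither a Program nor a Graph now raises TypeError (previously
# ValueError) with an explicit message.
try:
    fluid.CompiledProgram("not_a_program_or_graph")
except TypeError as e:
    print(e)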
@@ -69,10 +69,11 @@ std::map<std::string, std::vector<ir::Node *>> Graph::InitFromProgram(
     std::unordered_set<std::string> out_arg_set;
     for (auto &each_var_name : op->OutputArgumentNames()) {
       if (each_var_name != kEmptyVarName) {
-        PADDLE_ENFORCE(out_arg_set.count(each_var_name) == 0,
-                       "Program is wrong. %s occurs in output of %s several "
-                       "times.",
-                       each_var_name, op->Type());
+        PADDLE_ENFORCE_EQ(out_arg_set.count(each_var_name), 0,
+                          platform::errors::InvalidArgument(
+                              "The input Program is invalid. Variable %s occurs"
+                              " in output of %s multiple times.",
+                              each_var_name, op->Type()));
         out_arg_set.insert(each_var_name);
       }
@@ -121,10 +122,10 @@ void Graph::ResolveHazard(
           (*it_new)->inputs.empty() ? nullptr : (*it_new)->inputs[0];
       const auto &read_ops = (*it_old)->outputs;
-      PADDLE_ENFORCE(
-          write_op,
-          string::Sprintf("The write_op of var %s should not be empty.",
-                          (*it_new)->Name()));
+      PADDLE_ENFORCE_NOT_NULL(
+          write_op, platform::errors::NotFound(
+                        "The generating operator of variable %s is null.",
+                        (*it_new)->Name()));
       // Add write after write dependence
       ir::Node *upstream_op =
@@ -174,6 +175,8 @@ std::shared_ptr<Graph> Graph::Clone() {
   cloned_graph->num_node_created_ = 0;
   std::unordered_map<ir::Node *, ir::Node *> origin_to_cloned;
   for (auto *n : this->node_set_) {
+    PADDLE_ENFORCE_NOT_NULL(n, platform::errors::InvalidArgument(
+                                   "The node to be cloned is nullptr."));
     ir::Node *cloned_node = nullptr;
     if (n->IsCtrlVar()) {
       cloned_node = cloned_graph->CreateControlDepVar();
@@ -184,11 +187,11 @@ std::shared_ptr<Graph> Graph::Clone() {
     } else if (n->IsOp()) {
       cloned_node = cloned_graph->CreateOpNode(n->Op());
     }
-    if (cloned_node) {
-      origin_to_cloned[n] = cloned_node;
-    } else {
-      PADDLE_THROW("The cloned node's type is not supported!");
-    }
+    PADDLE_ENFORCE_NOT_NULL(
+        cloned_node,
+        platform::errors::InvalidArgument(
+            "Failed to clone new node from original node in graph."));
+    origin_to_cloned[n] = cloned_node;
   }
   for (auto *n : this->node_set_) {
     for (auto it = n->inputs.begin(); it != n->inputs.end(); it++) {
...
@@ -95,15 +95,17 @@ class Graph {
   template <typename AttrType>
   AttrType &Get(const std::string &attr_name) const {
-    PADDLE_ENFORCE_EQ(Has(attr_name), true, "%s attr not registered for graph.",
-                      attr_name);
+    PADDLE_ENFORCE_EQ(
+        Has(attr_name), true,
+        platform::errors::PreconditionNotMet(
+            "%s attribute not registered for current graph.", attr_name));
     try {
       return *boost::any_cast<AttrType *>(attrs_.at(attr_name));
     } catch (boost::bad_any_cast &) {
-      PADDLE_THROW(
-          "Invalid attribute type of %s error, expected: %s, actual: %s",
-          attr_name, typeid(AttrType *).name(),
-          attrs_.at(attr_name).type().name());
+      PADDLE_THROW(platform::errors::InvalidArgument(
+          "Invalid attribute type of %s, expected: %s, received: %s.",
+          attr_name, platform::demangle(typeid(AttrType *).name()),  // NOLINT
+          platform::demangle(attrs_.at(attr_name).type().name())));
     }
   }
@@ -112,7 +114,8 @@ class Graph {
     PADDLE_ENFORCE_EQ(
         attrs_.count(attr_name), 0,
         platform::errors::AlreadyExists(
-            "The attribute %s has been set in the graph.", attr_name));
+            "The attribute %s to be set already exists in the graph.",
+            attr_name));
     attrs_[attr_name] = attr;
     attr_dels_[attr_name] = [attr, attr_name]() {
       VLOG(3) << "deleting " << attr_name;
@@ -124,8 +127,9 @@ class Graph {
   void SetNotOwned(const std::string &attr_name, AttrType *attr) {
     PADDLE_ENFORCE_EQ(
         attrs_.count(attr_name), 0,
-        platform::errors::AlreadyExists(
-            "The attribute %s has been set in the graph.", attr_name));
+        platform::errors::AlreadyExists("The attribute %s to be set (not owned) "
+                                        "already exists in the graph.",
+                                        attr_name));
     attrs_[attr_name] = attr;
     attr_dels_[attr_name] = []() {};
   }
@@ -134,7 +138,8 @@ class Graph {
     PADDLE_ENFORCE_NE(
         attrs_.count(attr_name), 0,
         platform::errors::NotFound(
-            "The attribute %s has not been set in the graph.", attr_name));
+            "The attribute %s to be erased does not exist in the graph.",
+            attr_name));
     attr_dels_[attr_name]();
     attrs_.erase(attr_name);
     attr_dels_.erase(attr_name);
@@ -144,7 +149,9 @@ class Graph {
   // Create a normal variable with non-null VarDesc.
   ir::Node *CreateVarNode(VarDesc *var_desc) {
-    PADDLE_ENFORCE_NOT_NULL(var_desc);
+    PADDLE_ENFORCE_NOT_NULL(
+        var_desc, platform::errors::InvalidArgument(
+                      "The VarDesc used to create variable node is null."));
     auto *x = AddNode(new ir::Node(var_desc));
     x->SetId(num_node_created_++);
     return x;
@@ -152,7 +159,9 @@ class Graph {
   // Create a normal runnable operator with OpDesc.
   ir::Node *CreateOpNode(OpDesc *op_desc) {
-    PADDLE_ENFORCE_NOT_NULL(op_desc);
+    PADDLE_ENFORCE_NOT_NULL(
+        op_desc, platform::errors::InvalidArgument(
+                     "The OpDesc used to create operator node is null."));
     auto *x = AddNode(new ir::Node(op_desc));
     x->SetId(num_node_created_++);
     return x;
@@ -192,7 +201,9 @@ class Graph {
   }
   std::unique_ptr<ir::Node> RemoveNode(ir::Node *node) {
-    PADDLE_ENFORCE_EQ(node_set_.find(node) != node_set_.end(), true);
+    PADDLE_ENFORCE_EQ(node_set_.find(node) != node_set_.end(), true,
+                      platform::errors::PreconditionNotMet(
+                          "The node to be removed does not exist."));
     std::unique_ptr<ir::Node> ret;
     ret.reset(nodes_.at(node).release());
     nodes_.erase(node);
@@ -218,7 +229,9 @@ class Graph {
   // This method takes ownership of `node`.
   ir::Node *AddNode(ir::Node *node) {
-    PADDLE_ENFORCE_EQ(node_set_.find(node) == node_set_.end(), true);
+    PADDLE_ENFORCE_EQ(node_set_.find(node) == node_set_.end(), true,
+                      platform::errors::PreconditionNotMet(
+                          "The node to be added already exists."));
     nodes_[node].reset(node);
     node_set_.insert(node);
     return node;
...
@@ -139,8 +139,9 @@ class CompiledProgram(object):
             self._graph = core.Graph(program_or_graph.desc)
             self._program = program_or_graph
         else:
-            raise ValueError("Wrong program_to_graph type: %s" %
-                             type(program_or_graph))
+            raise TypeError(
+                "The type of program_to_graph parameter is wrong, expected Graph or Program, but received %s"
+                % type(program_or_graph))
         self._scope = None
         self._place = None
@@ -258,8 +259,8 @@ class CompiledProgram(object):
                                   feed={"X": test_data},
                                   fetch_list=[loss.name])
        """
-        assert not self._is_data_parallel, "Already compiled with parallel."
-        assert not self._is_inference, "Cannot compile both data parallel and inference"
+        assert not self._is_data_parallel, "Already compiled with parallel, cannot be recompiled."
+        assert not self._is_inference, "Cannot compile with both data parallel and inference."
         self._is_data_parallel = True
         # FIXME(zcd): Currently, the build_strategy can be set during creating
         # CompiledProgram or calling with_data_parallel, and it may be confusing,
@@ -272,7 +273,7 @@ class CompiledProgram(object):
         self._places = places
         if _has_backward_op(self._graph):
-            assert self._loss_name is not None, "The loss_name should be set here."
+            assert self._loss_name is not None, "The loss name of CompiledProgram is None. The loss name should be set if CompiledProgram contains a backward part."
         if self._places is not None:
             if not isinstance(self._places, (list, tuple)):
@@ -288,8 +289,8 @@ class CompiledProgram(object):
        Returns:
            self
        """
-        assert not self._is_data_parallel, "Cannot compile both data parallel and inference"
-        assert not self._is_inference, "Already compiled with inference"
+        assert not self._is_data_parallel, "Cannot compile with both data parallel and inference."
+        assert not self._is_inference, "Already compiled with inference, cannot be recompiled."
         assert any([
             isinstance(config, InferNativeConfig),
@@ -300,30 +301,29 @@ class CompiledProgram(object):
         return self
     def _with_distributed(self):
-        raise NotImplementedError()
+        raise NotImplementedError(
+            "Subclasses of CompiledProgram should implement the _with_distributed method."
+        )
     def _compile_data_parallel(self, places, use_cuda=False, scope=None):
         if self._share_vars_from:
             if scope:
                 sys.stderr.write("share_vars_from is set, scope is ignored.\n")
-            if not self._is_data_parallel:
-                raise ValueError(
-                    "Currently, only data parallel mode need share_vars_from.")
             if not self._share_vars_from._is_data_parallel:
-                raise ValueError("share_vars_from is not data parallel. Cannot "
-                                 "share vars from it.")
+                raise ValueError(
+                    "The shared Program is not data parallel, cannot "
+                    "share variables from it.")
             if self._share_vars_from._executor is None:
                 raise ValueError(
-                    "share_vars_from is not compiled and run, so there is no "
-                    "var to share.")
+                    "The shared Program is not compiled and executed, so there "
+                    "are no variables to share.")
             self._local_scopes = self._share_vars_from._executor.local_scopes()
         else:
             assert scope is not None, ""
             self._local_scopes = []
         assert isinstance(places, tuple) or isinstance(places, list), \
-            "Currently , The places type only should be list or tuple, \n" \
-            "but the input type is {}.".format(type(places))
+            "Currently, the places type can only be list or tuple, but the input type is {}.".format(type(places))
         if self._build_strategy is None:
             self._build_strategy = BuildStrategy()
@@ -354,7 +354,7 @@ class CompiledProgram(object):
             tps = self._program._trainers_endpoints
             assert self._build_strategy.num_trainers == len(
-                tps), "num_trainers == len(end_points)"
+                tps), "The number of trainers is not equal to the number of endpoints."
             self._build_strategy.trainers_endpoints = tps
         if self._program:
@@ -366,11 +366,11 @@ class CompiledProgram(object):
             self._build_strategy.enable_sequential_execution = True
         if self._program is not None and self._program._enable_dgc:
-            assert use_cuda, "DGC only used under cuda"
+            assert use_cuda, "DGC can only be used in a CUDA environment."
             assert self._build_strategy.num_trainers * len(
-                places) > 1, "DGC is not useful for single card training"
+                places) > 1, "DGC is not available for single card training."
             assert self._build_strategy.reduce_strategy == BuildStrategy.ReduceStrategy.AllReduce, "DGC \
-                only used for AllReduce BuildStrategy"
+                can only be used with AllReduce BuildStrategy."
             # DGC doesn't support fuse for now, close fuse.
             self._build_strategy.fuse_all_reduce_ops = False
@@ -411,9 +411,9 @@ class CompiledProgram(object):
        """
        if self._compiled:
            if scope and self._scope != scope:
-                raise ValueError("Cannot compile with different scope")
+                raise ValueError("Cannot compile program with a different scope.")
            if place and not self._place._equals(place):
-                raise ValueError("Cannot compile with different place")
+                raise ValueError("Cannot compile program with a different place.")
            return self
        self._compiled = True
@@ -448,9 +448,9 @@ class CompiledProgram(object):
         if has_set_place:
             for p in place_list:
                 assert p._type() == place._type(), \
-                    "Place type not match. You may set the wrong type of places"
+                    "Place type does not match. You may have set the wrong type of places."
         else:
             place_list = cuda_places() if isinstance(
                 place, core.CUDAPlace) else cpu_places()
-        assert place_list, "no place for execution"
+        assert place_list, "No places for execution."
         return place_list
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function

import unittest
import numpy as np

import paddle.fluid as fluid
from paddle.fluid import core
from test_imperative_base import new_program_scope
from simple_nets import simple_fc_net
class TestCompiledProgram(unittest.TestCase):
    def setUp(self):
        self.seed = 100
        self.img = np.random.random(size=(16, 784)).astype('float32')
        self.label = np.random.randint(
            low=0, high=10, size=[16, 1], dtype=np.int64)
        with new_program_scope():
            fluid.default_startup_program().random_seed = self.seed
            fluid.default_main_program().random_seed = self.seed
            place = fluid.CUDAPlace(0) if core.is_compiled_with_cuda(
            ) else fluid.CPUPlace()
            exe = fluid.Executor(place)

            loss = simple_fc_net()
            exe.run(fluid.default_startup_program())

            loss_data, = exe.run(fluid.default_main_program(),
                                 feed={"image": self.img,
                                       "label": self.label},
                                 fetch_list=[loss.name])
            self.loss = loss_data[0]

    def test_compiled_program_base(self):
        with new_program_scope():
            fluid.default_startup_program().random_seed = self.seed
            fluid.default_main_program().random_seed = self.seed
            place = fluid.CUDAPlace(0) if core.is_compiled_with_cuda(
            ) else fluid.CPUPlace()
            exe = fluid.Executor(place)

            loss = simple_fc_net()
            exe.run(fluid.default_startup_program())

            compiled_prog = fluid.CompiledProgram(fluid.default_main_program())
            loss_data, = exe.run(compiled_prog,
                                 feed={"image": self.img,
                                       "label": self.label},
                                 fetch_list=[loss.name])
            self.assertTrue(np.array_equal(loss_data[0], self.loss))

    def test_compiled_program_with_data_parallel(self):
        with new_program_scope():
            fluid.default_startup_program().random_seed = self.seed
            fluid.default_main_program().random_seed = self.seed
            place = fluid.CUDAPlace(0) if core.is_compiled_with_cuda(
            ) else fluid.CPUPlace()
            exe = fluid.Executor(place)

            loss = simple_fc_net()
            exe.run(fluid.default_startup_program())

            compiled_prog = fluid.CompiledProgram(fluid.default_main_program(
            )).with_data_parallel(
                loss_name=loss.name, places=[place])
            loss_data, = exe.run(compiled_prog,
                                 feed={"image": self.img,
                                       "label": self.label},
                                 fetch_list=[loss.name])
            self.assertTrue(np.array_equal(loss_data[0], self.loss))
class TestCompiledProgramError(unittest.TestCase):
    def test_program_or_graph_error(self):
        self.assertRaises(TypeError, fluid.CompiledProgram, "program")

    def build_simple_model(self):
        img = fluid.layers.data(
            name='image', shape=[1, 28, 28], dtype='float32')
        label = fluid.layers.data(name='label', shape=[1], dtype='int64')
        prediction = fluid.layers.fc(input=img, size=10, act='softmax')
        loss = fluid.layers.cross_entropy(input=prediction, label=label)
        avg_loss = fluid.layers.mean(loss)

    def compile_program_not_compiled(self):
        with fluid.program_guard(fluid.Program()):
            # build model
            self.build_simple_model()
            # compile program
            program = fluid.default_main_program()
            compiled_program = fluid.CompiledProgram(
                program).with_data_parallel()
            return compiled_program

    def compile_program(self):
        with fluid.program_guard(fluid.Program()):
            # build model
            self.build_simple_model()
            # compile program
            program = fluid.default_main_program()
            compiled_program = fluid.CompiledProgram(program)
            scope = fluid.global_scope()
            place = fluid.CPUPlace()
            compiled_program._compile(scope, place)
            return compiled_program, scope, place

    def test_compile_scope_error(self):
        compiled_program, _, place = self.compile_program()
        new_scope = core.Scope()
        with self.assertRaises(ValueError):
            compiled_program._compile(new_scope, place)

    def test_compile_place_error(self):
        # needs to create a place different from the one used in compile_program
        if core.is_compiled_with_cuda():
            compiled_program, scope, _ = self.compile_program()
            new_place = fluid.CUDAPlace(0)
            with self.assertRaises(ValueError):
                compiled_program._compile(scope, new_place)

    def test_share_vars_from_error_no_parallel(self):
        with fluid.program_guard(fluid.Program()):
            source_program, _, _ = self.compile_program()
            self.build_simple_model()
            # compile program
            program = fluid.default_main_program()
            compiled_program = fluid.CompiledProgram(
                program).with_data_parallel(share_vars_from=source_program)
            scope = fluid.global_scope()
            place = fluid.CPUPlace()
            with self.assertRaises(ValueError):
                compiled_program._compile(scope, place)

    def test_share_vars_from_error_no_executor(self):
        with fluid.program_guard(fluid.Program()):
            source_program = self.compile_program_not_compiled()
            self.build_simple_model()
            # compile program
            program = fluid.default_main_program()
            compiled_program = fluid.CompiledProgram(
                program).with_data_parallel(share_vars_from=source_program)
            scope = fluid.global_scope()
            place = fluid.CPUPlace()
            with self.assertRaises(ValueError):
                compiled_program._compile(scope, place)


if __name__ == '__main__':
    unittest.main()