未验证 提交 a82ce2b1 编写于 作者: H Huihuang Zheng 提交者: GitHub

API/OP (ConditionalBlock) error message enhancement (#23480)

API/OP (ConditionalBlock) error message enhancement (#23480)
上级 4489f0d3
......@@ -51,7 +51,9 @@ class ConditionalBlockInferOp : public ConditionalOp {
if (need_run) {
auto *scope_var = scope.FindVar(Output("Scope"));
PADDLE_ENFORCE(scope_var != nullptr, "Must set scope");
PADDLE_ENFORCE_NOT_NULL(
scope_var, platform::errors::PreconditionNotMet(
"Scope must be set in ConditionalBlockInferOp."));
auto *scopes = scope_var->GetMutable<std::vector<framework::Scope *>>();
scopes->resize(1);
scopes->front() = &scope.NewScope();
......
......@@ -56,7 +56,9 @@ class ConditionalBlockOp : public ConditionalOp {
if (need_run) {
auto *scope_var = scope.FindVar(Output(ConditionalOp::kScope));
PADDLE_ENFORCE(scope_var != nullptr, "Must set scope");
PADDLE_ENFORCE_NOT_NULL(
scope_var, platform::errors::PreconditionNotMet(
"Scope must be set in conditional_block_op."));
auto *scopes = scope_var->GetMutable<std::vector<framework::Scope *>>();
scopes->resize(1);
scopes->front() = &scope.NewScope();
......@@ -79,7 +81,7 @@ class ConditionalBlockInferShape : public framework::InferShapeBase {
void operator()(framework::InferShapeContext *context) const override {
PADDLE_ENFORCE_EQ(context->HasInputs(ConditionalOp::kCondition), true,
platform::errors::InvalidArgument(
"conditional_block_op must have condition input"));
"conditional_block_op must have condition input."));
}
};
......@@ -116,13 +118,13 @@ class ConditionalBlockGradOp : public ConditionalOp {
}
auto *scope_var = scope.FindVar(Input(ConditionalOp::kScope));
PADDLE_ENFORCE_NE(scope_var, nullptr,
platform::errors::InvalidArgument(
"Scope must be set in conditional block op"));
PADDLE_ENFORCE_NOT_NULL(
scope_var, platform::errors::PreconditionNotMet(
"Scope must be set in conditional block op."));
auto &scopes = scope_var->Get<std::vector<framework::Scope *>>();
PADDLE_ENFORCE_GT(scopes.size(), 0,
platform::errors::InvalidArgument(
"Scope must be set in conditional block op"));
"Scope must be set in conditional block op."));
framework::Scope &cur_scope = *scopes[0];
framework::Executor exec(dev_place);
......@@ -192,7 +194,7 @@ class ConditionalBlockGradOp : public ConditionalOp {
PADDLE_ENFORCE_EQ(outside_var->IsType<framework::LoDTensor>(), true,
platform::errors::InvalidArgument(
"Type of outside_var %s is NOT LoDTensor, which "
"doesn't match input_var %s",
"doesn't match input_var %s.",
outside_grad_name, input_name));
AssignZeroToOutsideTensor(
place, scope, input_var->Get<framework::LoDTensor>(),
......@@ -202,7 +204,7 @@ class ConditionalBlockGradOp : public ConditionalOp {
true,
platform::errors::InvalidArgument(
"Type of outside_var %s is NOT LoDTensorArray, "
"which doesn't match input_var %s",
"which doesn't match input_var %s.",
outside_grad_name, input_name));
const auto &input_tensors = input_var->Get<framework::LoDTensorArray>();
auto *outside_tensors =
......@@ -210,7 +212,7 @@ class ConditionalBlockGradOp : public ConditionalOp {
PADDLE_ENFORCE_EQ(input_tensors.size(), outside_tensors->size(),
platform::errors::InvalidArgument(
"LoDTensorArray outside_var %s doen't have same "
"size as input_var %s",
"size as input_var %s.",
outside_grad_name, input_name));
for (size_t j = 0; j < input_tensors.size(); ++j) {
AssignZeroToOutsideTensor(place, scope, input_tensors[j],
......@@ -220,7 +222,7 @@ class ConditionalBlockGradOp : public ConditionalOp {
// TODO(huihuangzheng): add support for SelectedRows
PADDLE_THROW(platform::errors::InvalidArgument(
"Conditional block grad op doesn't support non-LoDTensor output "
"now"));
"now."));
}
}
}
......@@ -245,7 +247,10 @@ class ConditionalBlockGradOp : public ConditionalOp {
class ConditionalBlockGradInferShape : public framework::InferShapeBase {
public:
void operator()(framework::InferShapeContext *context) const override {
PADDLE_ENFORCE(context->HasInputs(ConditionalOp::kCondition));
PADDLE_ENFORCE_EQ(
context->HasInputs(ConditionalOp::kCondition), true,
platform::errors::InvalidArgument(
"Condition must be set in conditional_block_grad_op."));
if (context->HasInputs(ConditionalOp::kInputs) &&
context->HasOutputs(framework::GradVarName(ConditionalOp::kInputs))) {
context->SetOutputsDim(framework::GradVarName(ConditionalOp::kInputs),
......
......@@ -49,7 +49,9 @@ class ConditionalOp : public framework::OperatorBase {
xs.begin(), xs.end(), retv.begin(),
[&scope](const std::string &var_name) -> const framework::LoDTensor * {
auto *var = scope.FindVar(var_name);
PADDLE_ENFORCE(var != nullptr, "Cannot find variable %s", var_name);
PADDLE_ENFORCE_NOT_NULL(
var, platform::errors::InvalidArgument("Cannot find variable %s",
var_name));
return &var->Get<framework::LoDTensor>();
});
return retv;
......@@ -57,15 +59,17 @@ class ConditionalOp : public framework::OperatorBase {
bool ScalarCondition(
const std::vector<const framework::LoDTensor *> &ips) const {
if (!(ips.size() == 1UL && ips[0]->IsInitialized())) {
PADDLE_THROW("should have one initialized input as condition");
}
PADDLE_ENFORCE_EQ(
ips.size() == 1UL && ips[0]->IsInitialized(), true,
platform::errors::InvalidArgument(
"condition should have one initialized input as condition"));
PADDLE_ENFORCE(ips[0]->type() == framework::proto::VarType::BOOL &&
PADDLE_ENFORCE_EQ(ips[0]->type() == framework::proto::VarType::BOOL &&
ips[0]->numel() == 1,
true, platform::errors::InvalidArgument(
"condition input's data type should be bool, "
"numel should be 1, actual numel is %d",
ips[0]->numel());
ips[0]->numel()));
bool res = false;
if (platform::is_gpu_place(ips[0]->place())) {
#ifdef PADDLE_WITH_CUDA
......
......@@ -31,7 +31,12 @@ static bool IsMatchedConditionalBlockOpAndConditionalBlockGradOp(
static void FindAllConditionalBlockAndConditionalBlockGradOp(
const framework::ProgramDesc &program, std::vector<OpVariant> *fwd_ops,
std::vector<OpVariant> *bwd_ops) {
PADDLE_ENFORCE_GE(fwd_ops->size(), bwd_ops->size());
PADDLE_ENFORCE_GE(
fwd_ops->size(), bwd_ops->size(),
platform::errors::InvalidArgument(
"Size of forward ops must be greater or equal to backward ops. The "
"number of forward ops is %d and the number of backward ops is %d",
fwd_ops->size(), bwd_ops->size()));
for (size_t i = 1; i < program.Size(); ++i) {
auto &block = program.Block(i);
......@@ -47,7 +52,11 @@ static void FindAllConditionalBlockAndConditionalBlockGradOp(
PADDLE_ENFORCE_GE(
fwd_ops->size(), bwd_ops->size(),
"There are extra conditional_block_grad ops in the graph or program");
platform::errors::InvalidArgument(
"There are more conditional_block_grad ops than "
"conditional_block ops in the graph or program. The number of "
"forward ops is %d and the number of backward ops is %d",
fwd_ops->size(), bwd_ops->size()));
}
static void SetSkipVarsForConditionalBlockOp(OpVariant *fwd_op,
......@@ -102,14 +111,17 @@ static void PrepareSafeEagerDeletionOnConditionalOpAndConditionalGradOpImpl(
for (auto &fwd_op : ifelse_op_set) {
if (IsMatchedConditionalBlockOpAndConditionalBlockGradOp(fwd_op,
bwd_op)) {
PADDLE_ENFORCE(matched_fwd_op == nullptr,
"Found multiple matched conditional_block ops");
PADDLE_ENFORCE_EQ(matched_fwd_op, nullptr,
platform::errors::PreconditionNotMet(
"Found multiple matched conditional_block ops."));
matched_fwd_op = &fwd_op;
}
}
PADDLE_ENFORCE_NOT_NULL(matched_fwd_op,
"Cannot find matched forward conditional_block op");
PADDLE_ENFORCE_NOT_NULL(
matched_fwd_op,
platform::errors::PreconditionNotMet(
"Cannot find matched forward conditional_block op."));
SetSkipVarsForConditionalBlockOp(const_cast<OpVariant *>(matched_fwd_op),
&bwd_op);
......
......@@ -1812,8 +1812,7 @@ class ConditionalBlockGuard(BlockGuard):
"""
def __init__(self, block):
if not isinstance(block, ConditionalBlock):
raise TypeError("block should be conditional block")
check_type(block, "block", ConditionalBlock, "ConditionalBlockGuard")
super(ConditionalBlockGuard, self).__init__(block.helper.main_program)
self.block = block
......@@ -1855,8 +1854,7 @@ class ConditionalBlock(object):
def __init__(self, inputs, is_scalar_condition=False, name=None):
for each_input in inputs:
if not isinstance(each_input, Variable):
raise TypeError("Each input should be variable")
check_type(each_input, "input", Variable, "ConditionalBlock")
self.inputs = inputs
self.is_scalar_condition = is_scalar_condition
self.helper = LayerHelper('conditional_block', name=name)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册