Unverified commit dafb0e3b, authored by Chen Weihang, committed by GitHub

Polish framework error message part 6 (#27257)

* polish framework error msg part 6

* polish lossed item

* fix failed unittest

* polish by reviewer comments
Parent e6e2e537
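Every hunk below applies the same mechanical pattern: an untyped PADDLE_ENFORCE / PADDLE_THROW call carrying only a bare format string is replaced by a typed check (PADDLE_ENFORCE_EQ / _NE / _LT / _LE / _NOT_NULL) whose message is wrapped in an error class from platform::errors (NotFound, AlreadyExists, InvalidArgument, Unavailable, PreconditionNotMet, Fatal). A minimal before/after sketch of that pattern, excerpted from one of the test hunks below and assuming the usual paddle/fluid/platform/enforce.h and errors.h headers are in scope:

// Before: untyped enforce, condition and message only.
PADDLE_ENFORCE(i % 2 == 0, "'test_attr' must be even!");

// After: typed comparison plus an explicit error category.
PADDLE_ENFORCE_EQ(i % 2, 0, platform::errors::InvalidArgument(
                                "'test_attr' must be even!"));

The typed form reports both compared values on failure and tags the error with a category, which is what "polishing" the framework error messages means in this series.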
@@ -69,7 +69,8 @@ class OpInfo {
   const OpCreator& Creator() const {
     PADDLE_ENFORCE_NOT_NULL(creator_,
-                            "Operator's Creator has not been registered");
+                            platform::errors::NotFound(
+                                "Operator's Creator has not been registered."));
     return creator_;
   }
@@ -79,11 +80,12 @@ class OpInfo {
     std::string type = proto_ ? proto_->type() : "unknown";
     PADDLE_ENFORCE_NOT_NULL(
         grad_op_maker_,
-        "Operator %s's GradOpMaker has not been "
-        "registered.\nPlease check whether %s_op has "
-        "grad_op.\nIf not, please set stop_gradient to True "
-        "for its input and output variables using var.stop_gradient=True.",
-        type.c_str(), type.c_str());
+        platform::errors::NotFound(
+            "Operator %s's GradOpMaker has not been "
+            "registered.\nPlease check whether (%s) operator has "
+            "gradient operator.\nIf not, please set stop_gradient to be True "
+            "for its input and output variables using var.stop_gradient=True.",
+            type.c_str(), type.c_str()));
     return grad_op_maker_;
   }
@@ -100,11 +102,12 @@ class OpInfo {
     std::string type = proto_ ? proto_->type() : "unknown";
     PADDLE_ENFORCE_NOT_NULL(
         dygraph_grad_op_maker_,
-        "Operator %s's DygraphGradOpMaker has not been "
-        "registered.\nPlease check whether %s_op has "
-        "grad_op.\nIf not, please set stop_gradient to True "
-        "for its input and output variables using var.stop_gradient=True.",
-        type.c_str(), type.c_str());
+        platform::errors::NotFound(
+            "Operator %s's DygraphGradOpMaker has not been "
+            "registered.\nPlease check whether (%s) operator has "
+            "gradient operator.\nIf not, please set stop_gradient to be True "
+            "for its input and output variables using var.stop_gradient=True.",
+            type.c_str(), type.c_str()));
     return dygraph_grad_op_maker_;
   }
@@ -130,14 +133,17 @@ class OpInfoMap {
   }
   void Insert(const std::string& type, const OpInfo& info) {
-    PADDLE_ENFORCE(!Has(type), "Operator %s has been registered", type);
+    PADDLE_ENFORCE_NE(Has(type), true,
+                      platform::errors::AlreadyExists(
+                          "Operator (%s) has been registered.", type));
     map_.insert({type, info});
   }
   const OpInfo& Get(const std::string& type) const {
     auto op_info_ptr = GetNullable(type);
-    PADDLE_ENFORCE_NOT_NULL(op_info_ptr, "Operator %s has not been registered",
-                            type);
+    PADDLE_ENFORCE_NOT_NULL(
+        op_info_ptr,
+        platform::errors::NotFound("Operator (%s) is not registered.", type));
     return *op_info_ptr;
   }
...
@@ -33,10 +33,18 @@ size_t OpKernelType::Hash::operator()(const OpKernelType& key) const {
   cur_loc += OpKernelType::kLibBits;
   int customized_value = key.customized_type_value_;
-  PADDLE_ENFORCE(customized_value < (1 << OpKernelType::kCustomizeBits));
+  PADDLE_ENFORCE_LT(customized_value, (1 << OpKernelType::kCustomizeBits),
+                    platform::errors::Unavailable(
+                        "Too many custom OpKernel attribute values, expected "
+                        "maximum value is %d, received value is %d.",
+                        (1 << OpKernelType::kCustomizeBits), customized_value));
   customized_value = customized_value << cur_loc;
   cur_loc += OpKernelType::kCustomizeBits;
-  PADDLE_ENFORCE(cur_loc < 64);
+  PADDLE_ENFORCE_LT(cur_loc, 64,
+                    platform::errors::Unavailable(
+                        "Too many OpKernel attribute values, expected maximum "
+                        "value is 64, received value is %d.",
+                        cur_loc));
   std::hash<int> hasher;
   return hasher(place + data_type + data_layout + library_type +
...
@@ -43,7 +43,9 @@ OpProtoAndCheckerMaker::VariableBuilder OpProtoAndCheckerMaker::AddOutput(
 void OpProtoAndCheckerMaker::CheckNoDuplicatedInOutAttrs() {
   std::unordered_set<std::string> names;
   auto checker = [&](const std::string& name) {
-    PADDLE_ENFORCE(!names.count(name), "[%s] is duplicated", name);
+    PADDLE_ENFORCE_EQ(
+        names.count(name), 0,
+        platform::errors::AlreadyExists("Attribute [%s] is duplicated.", name));
     names.insert(name);
   };
   for (auto& attr : proto_->attrs()) {
...
@@ -54,9 +54,10 @@ class Registrar {
 template <typename... ARGS>
 struct OperatorRegistrar : public Registrar {
   explicit OperatorRegistrar(const char* op_type) {
-    if (OpInfoMap::Instance().Has(op_type)) {
-      PADDLE_THROW("'%s' is registered more than once.", op_type);
-    }
+    PADDLE_ENFORCE_EQ(
+        OpInfoMap::Instance().Has(op_type), false,
+        platform::errors::AlreadyExists(
+            "Operator '%s' is registered more than once.", op_type));
     static_assert(sizeof...(ARGS) != 0,
                   "OperatorRegistrar should be invoked at least by OpClass");
     OpInfo info;
...
@@ -58,7 +58,8 @@ class MyTestOpProtoAndCheckerMaker : public OpProtoAndCheckerMaker {
     AddInput("input", "input of cosine op").AsDuplicable();
     AddOutput("output", "output of cosine op").AsIntermediate();
     auto my_checker = [](int i) {
-      PADDLE_ENFORCE(i % 2 == 0, "'test_attr' must be even!");
+      PADDLE_ENFORCE_EQ(i % 2, 0, platform::errors::InvalidArgument(
+                                      "'test_attr' must be even!"));
     };
     AddAttr<int>("test_attr", "a simple test attribute")
         .AddCustomChecker(my_checker);
...
@@ -152,10 +152,10 @@ class OpVersionRegistrar {
     return instance;
   }
   OpVersion& Register(const std::string& op_type) {
-    if (op_version_map_.find(op_type) != op_version_map_.end()) {
-      PADDLE_THROW("'%s' is registered in operator version more than once.",
-                   op_type);
-    }
+    PADDLE_ENFORCE_EQ(
+        op_version_map_.find(op_type), op_version_map_.end(),
+        platform::errors::AlreadyExists(
+            "'%s' is registered in operator version more than once.", op_type));
     op_version_map_.insert({op_type, OpVersion()});
     return op_version_map_[op_type];
   }
...
@@ -164,15 +164,20 @@ void OperatorBase::Run(const Scope& scope, const platform::Place& place) {
   VLOG(4) << place << " " << DebugStringEx(&scope);
   if (platform::is_gpu_place(place)) {
 #ifndef PADDLE_WITH_CUDA
-    PADDLE_THROW("Cannot run operator on place %s", place);
+    PADDLE_THROW(platform::errors::Unavailable(
+        "Cannot run operator on place %s, please recompile paddle or "
+        "reinstall Paddle with CUDA support.",
+        place));
 #else
     auto dev_id = BOOST_GET_CONST(platform::CUDAPlace, place).device;
     platform::SetDeviceId(dev_id);
 #endif
   } else if (platform::is_xpu_place(place)) {
 #ifndef PADDLE_WITH_XPU
-    PADDLE_THROW(platform::errors::Unimplemented(
-        "Cannot run operator on place %s", place));
+    PADDLE_THROW(platform::errors::Unavailable(
+        "Cannot run operator on place %s, please recompile paddle or "
+        "reinstall Paddle with XPU support.",
+        place));
 #else
     auto dev_id = BOOST_GET_CONST(platform::XPUPlace, place).device;
     platform::SetXPUDeviceId(dev_id);
@@ -214,7 +219,7 @@ std::string OperatorBase::Input(const std::string& name) const {
   auto& ins = Inputs(name);
   PADDLE_ENFORCE_LE(
       ins.size(), 1UL,
-      platform::errors::AlreadyExists(
+      platform::errors::InvalidArgument(
           "Operator %s's input %s should contain only one variable.", type_,
           name));
   return ins.empty() ? kEmptyVarName : ins[0];
@@ -223,8 +228,10 @@ std::string OperatorBase::Input(const std::string& name) const {
 const std::vector<std::string>& OperatorBase::Inputs(
     const std::string& name) const {
   auto it = inputs_.find(name);
-  PADDLE_ENFORCE(it != inputs_.end(), "Operator %s does not have the input %s.",
-                 type_, name);
+  PADDLE_ENFORCE_NE(
+      it, inputs_.end(),
+      platform::errors::NotFound("Operator %s does not have the input %s.",
+                                 type_, name));
   return it->second;
 }
@@ -238,17 +245,21 @@ bool OperatorBase::HasOutputs(const std::string& name) const {
 std::string OperatorBase::Output(const std::string& name) const {
   auto& outs = Outputs(name);
-  PADDLE_ENFORCE_LE(outs.size(), 1UL,
-                    "Operator %s's output %s should contain only one variable.",
-                    type_, name);
+  PADDLE_ENFORCE_LE(
+      outs.size(), 1UL,
+      platform::errors::InvalidArgument(
+          "Operator %s's output %s should contain only one variable.", type_,
+          name));
   return outs.empty() ? kEmptyVarName : outs[0];
 }
 const std::vector<std::string>& OperatorBase::Outputs(
     const std::string& name) const {
   auto it = outputs_.find(name);
-  PADDLE_ENFORCE(it != outputs_.end(),
-                 "Operator %s does not have an output called %s.", type_, name);
+  PADDLE_ENFORCE_NE(
+      it, outputs_.end(),
+      platform::errors::NotFound(
+          "Operator %s does not have an output called %s.", type_, name));
   return it->second;
 }
@@ -391,16 +402,19 @@ void OperatorBase::CheckAllInputOutputSet() const {
   for (auto& in : info_->Proto().inputs()) {
     if (!in.dispensable()) {
-      PADDLE_ENFORCE(inputs_.find(in.name()) != inputs_.end(),
-                     "Operator %s's input, %s, is not set", Type(), in.name());
+      PADDLE_ENFORCE_NE(
+          inputs_.find(in.name()), inputs_.end(),
+          platform::errors::NotFound("Operator %s's input (%s) is not set.",
+                                     Type(), in.name()));
     }
   }
   for (auto& out : info_->Proto().outputs()) {
     if (!out.dispensable()) {
-      PADDLE_ENFORCE(outputs_.find(out.name()) != outputs_.end(),
-                     "Operator %s's output, %s, is not set", Type(),
-                     out.name());
+      PADDLE_ENFORCE_NE(
+          outputs_.find(out.name()), outputs_.end(),
+          platform::errors::NotFound("Operator %s's output (%s) is not set.",
+                                     Type(), out.name()));
     }
   }
 }
@@ -428,8 +442,9 @@ const Tensor* GetLoDTensorOrSelectedRowsValueFromVar(const Variable& var) {
   } else if (var.IsType<SelectedRows>()) {
     return &(var.Get<SelectedRows>().value());
   } else {
-    PADDLE_THROW("Variable type_id %s, expect LoDTensor/SelectedRows.",
-                 ToTypeName(var.Type()));
+    PADDLE_THROW(platform::errors::InvalidArgument(
+        "Variable type is %s, expect LoDTensor or SelectedRows.",
+        ToTypeName(var.Type())));
   }
 }
@@ -439,8 +454,9 @@ Tensor* GetMutableLoDTensorOrSelectedRowsValueFromVar(Variable* var) {
   } else if (var->IsType<SelectedRows>()) {
     return var->GetMutable<SelectedRows>()->mutable_value();
   } else {
-    PADDLE_THROW("Variable type_id %s, expect LoDTensor/SelectedRows.",
-                 ToTypeName(var->Type()));
+    PADDLE_THROW(platform::errors::InvalidArgument(
+        "Variable type is %s, expect LoDTensor or SelectedRows.",
+        ToTypeName(var->Type())));
   }
 }
@@ -462,7 +478,7 @@ const Variable* ExecutionContext::InputVar(const std::string& name) const {
   PADDLE_ENFORCE_LE(
       it->second.size(), 1UL,
-      platform::errors::AlreadyExists(
+      platform::errors::InvalidArgument(
           "Operator %s's input %s should contain only one variable.",
           op_.Type(), name));
   return it->second.empty() ? nullptr : it->second[0];
@@ -472,9 +488,11 @@ Variable* ExecutionContext::OutputVar(const std::string& name) const {
   auto it = ctx_.outputs.find(name);
   if (it == ctx_.outputs.end()) return nullptr;
-  PADDLE_ENFORCE_LE(it->second.size(), 1UL,
-                    "Operator %s's output %s should contain only one variable.",
-                    op_.Type(), name);
+  PADDLE_ENFORCE_LE(
+      it->second.size(), 1UL,
+      platform::errors::InvalidArgument(
+          "Operator %s's output %s should contain only one variable.",
+          op_.Type(), name));
   return it->second.empty() ? nullptr : it->second[0];
 }
@@ -497,10 +515,11 @@ const std::vector<const Tensor*> ExecutionContext::MultiInput<Tensor>(
   std::transform(vars.begin(), vars.end(), std::back_inserter(res),
                  [&](const Variable* var) -> const Tensor* {
                    if (var == nullptr) return nullptr;
-                   PADDLE_ENFORCE(
-                       var->IsType<LoDTensor>(),
-                       "should be LoDTensor, but the received type is %s",
-                       ToTypeName(var->Type()));
+                   PADDLE_ENFORCE_EQ(var->IsType<LoDTensor>(), true,
+                                     platform::errors::InvalidArgument(
+                                         "Input variable should be LoDTensor, "
+                                         "but the received type is %s.",
+                                         ToTypeName(var->Type())));
                    return &(var->Get<LoDTensor>());
                  });
   return res;
@@ -558,8 +577,10 @@ class RuntimeInferShapeContext : public InferShapeContext {
     }
     const auto& in = it->second;
     if (in.size() == 0) return false;
-    PADDLE_ENFORCE_EQ(in.size(), 1UL,
-                      "Input %s should not have more than one inputs", name);
+    PADDLE_ENFORCE_EQ(
+        in.size(), 1UL,
+        platform::errors::InvalidArgument(
+            "Input %s should not contain more than one inputs.", name));
     return in[0] != nullptr;
   }
@@ -574,8 +595,10 @@ class RuntimeInferShapeContext : public InferShapeContext {
     if (out.size() == 0) {
       return false;
     }
-    PADDLE_ENFORCE_EQ(out.size(), 1UL,
-                      "Output %s should not have more than one outputs", name);
+    PADDLE_ENFORCE_EQ(
+        out.size(), 1UL,
+        platform::errors::InvalidArgument(
+            "Output %s should not contain more than one outputs.", name));
     return out[0] != nullptr;
   }
@@ -644,16 +667,31 @@ class RuntimeInferShapeContext : public InferShapeContext {
                 size_t j = 0) override {
     auto in_it = ctx_.inputs.find(in);
     auto out_it = ctx_.outputs.find(out);
-    PADDLE_ENFORCE(in_it != ctx_.inputs.end() && in_it->second.size() > i,
-                   "Inputs %s should have %llu argument", in, i);
-    PADDLE_ENFORCE(out_it != ctx_.outputs.end() && out_it->second.size() > j,
-                   "Outputs %s should have %llu argument", out, j);
+    PADDLE_ENFORCE_NE(
+        in_it, ctx_.inputs.end(),
+        platform::errors::NotFound("Input %s does not exist.", in));
+    PADDLE_ENFORCE_NE(
+        out_it, ctx_.outputs.end(),
+        platform::errors::NotFound("Output %s does not exist.", out));
+    PADDLE_ENFORCE_LT(i, in_it->second.size(),
+                      platform::errors::InvalidArgument(
+                          "The index of input dimension is out of range, "
+                          "excepted index less than %zu, but received %zu.",
+                          in_it->second.size(), i));
+    PADDLE_ENFORCE_LT(j, out_it->second.size(),
+                      platform::errors::InvalidArgument(
+                          "The index of output dimension is out of range, "
+                          "excepted index less than %zu, but received %zu.",
+                          out_it->second.size(), j));
     Variable* in_var = in_it->second[i];
     Variable* out_var = out_it->second[j];
-    PADDLE_ENFORCE(in_var->Type() == out_var->Type(),
-                   "The type of %s and %s is not the same.", in, out);
+    PADDLE_ENFORCE_EQ(
+        in_var->Type(), out_var->Type(),
+        platform::errors::InvalidArgument(
+            "The type of input (%s) and output (%s) are inconsistent.", in,
+            out));
     if (in_var->IsType<framework::SelectedRows>()) {
       auto& in_sele_rows = in_var->Get<framework::SelectedRows>();
@@ -666,9 +704,9 @@ class RuntimeInferShapeContext : public InferShapeContext {
       auto* out_lod_tensor = out_var->GetMutable<framework::LoDTensor>();
       out_lod_tensor->Resize(in_lod_tensor.dims());
     } else {
-      PADDLE_THROW(
-          "Currently, the input type of ShareDim only can be LoDTensor "
-          "or SelectedRows.");
+      PADDLE_THROW(platform::errors::Unimplemented(
+          "Currently, the input type of ShareDim only can be LoDTensor "
+          "or SelectedRows."));
     }
   }
@@ -721,16 +759,30 @@ class RuntimeInferShapeContext : public InferShapeContext {
                 size_t j = 0) const override {
     auto in_it = ctx_.inputs.find(in);
     auto out_it = ctx_.outputs.find(out);
-    PADDLE_ENFORCE(in_it != ctx_.inputs.end() && in_it->second.size() > i,
-                   "Inputs %s should have %llu argument", in, i);
-    PADDLE_ENFORCE(out_it != ctx_.outputs.end() && out_it->second.size() > j,
-                   "Outputs %s should have %llu argument", out, j);
+    PADDLE_ENFORCE_NE(
+        in_it, ctx_.inputs.end(),
+        platform::errors::NotFound("Input %s does not exist.", in));
+    PADDLE_ENFORCE_NE(
+        out_it, ctx_.outputs.end(),
+        platform::errors::NotFound("Output %s does not exist.", out));
+    PADDLE_ENFORCE_LT(i, in_it->second.size(),
+                      platform::errors::InvalidArgument(
+                          "The index of input dimension is out of range, "
+                          "excepted index less than %zu, but received %zu.",
+                          in_it->second.size(), i));
+    PADDLE_ENFORCE_LT(j, out_it->second.size(),
+                      platform::errors::InvalidArgument(
+                          "The index of output dimension is out of range, "
+                          "excepted index less than %zu, but received %zu.",
+                          out_it->second.size(), j));
     Variable* in_var = in_it->second.at(i);
     if (!in_var->IsType<LoDTensor>()) return;
     Variable* out_var = out_it->second.at(j);
-    PADDLE_ENFORCE(out_var->IsType<LoDTensor>(),
-                   "The %d-th output of Output(%s) must be LoDTensor.", j, out);
+    PADDLE_ENFORCE_EQ(
+        out_var->IsType<LoDTensor>(), true,
+        platform::errors::InvalidArgument(
+            "The %zu-th output of Output(%s) must be LoDTensor.", j, out));
     auto& in_tensor = in_var->Get<LoDTensor>();
     auto* out_tensor = out_var->GetMutable<LoDTensor>();
     out_tensor->set_lod(in_tensor.lod());
@@ -757,18 +809,18 @@ class RuntimeInferShapeContext : public InferShapeContext {
   }
   int32_t GetLoDLevel(const std::string& in, size_t i = 0) const override {
-    PADDLE_THROW(
-        "GetLoDLevel is only used in compile time. The calculation of "
-        "output's actual lod is different among operators so that should be "
-        "set in the runtime kernel.");
+    PADDLE_THROW(platform::errors::PreconditionNotMet(
+        "GetLoDLevel is only used in compile time. The calculation of "
+        "output's actual lod is different among operators so that should be "
+        "set in the runtime kernel."));
   }
   void SetLoDLevel(const std::string& out, int32_t lod_level,
                    size_t j = 0) const override {
-    PADDLE_THROW(
-        "SetLoDLevel is only used in compile time. The calculation of "
-        "output's actual lod is different among operators so that should be "
-        "set in the runtime kernel.");
+    PADDLE_THROW(platform::errors::PreconditionNotMet(
+        "SetLoDLevel is only used in compile time. The calculation of "
+        "output's actual lod is different among operators so that should be "
+        "set in the runtime kernel."));
   }
   bool IsRuntime() const override { return true; }
@@ -794,9 +846,11 @@ class RuntimeInferShapeContext : public InferShapeContext {
   DDim GetInputDim(const std::string& name) const override {
     const std::vector<Variable*>& vars = InputVars(name);
-    PADDLE_ENFORCE_EQ(vars.size(), 1UL,
-                      "Input(%s) should hold one element, but now it holds %d",
-                      name, vars.size());
+    PADDLE_ENFORCE_EQ(
+        vars.size(), 1UL,
+        platform::errors::InvalidArgument(
+            "Input(%s) should hold one element, but now it holds %zu elements.",
+            name, vars.size()));
     return this->GetDim(vars[0]);
   }
@@ -817,9 +871,11 @@ class RuntimeInferShapeContext : public InferShapeContext {
   void SetOutputDim(const std::string& name, const DDim& dim) override {
     auto& vars = OutputVars(name);
-    PADDLE_ENFORCE_EQ(vars.size(), 1UL,
-                      "Output(%s) should hold one element, but now it holds %d",
-                      name, vars.size());
+    PADDLE_ENFORCE_EQ(
+        vars.size(), 1UL,
+        platform::errors::InvalidArgument("Output(%s) should hold one element, "
+                                          "but now it holds %zu elements.",
+                                          name, vars.size()));
     SetDim(vars[0], dim);
   }
@@ -831,16 +887,17 @@ class RuntimeInferShapeContext : public InferShapeContext {
  protected:
   DDim GetDim(Variable* var) const {
-    PADDLE_ENFORCE_NOT_NULL(var);
+    PADDLE_ENFORCE_NOT_NULL(
+        var, platform::errors::InvalidArgument("Input variable is nullptr."));
     if (var->IsType<LoDTensor>()) {
       return var->Get<LoDTensor>().dims();
     } else if (var->IsType<SelectedRows>()) {
       return var->Get<SelectedRows>().GetCompleteDims();
     } else {
-      PADDLE_THROW(
-          "Only LoDTensor/SelectedRows support 'GetDim', but Variables "
-          "type_id is %s.",
-          ToTypeName(var->Type()));
+      PADDLE_THROW(platform::errors::InvalidArgument(
+          "Only LoDTensor or SelectedRows support 'GetDim', but input "
+          "Variable's type is %s.",
+          ToTypeName(var->Type())));
     }
   }
@@ -853,7 +910,8 @@ class RuntimeInferShapeContext : public InferShapeContext {
   }
   std::vector<DDim> GetRepeatedDims(const std::string& name) const override {
-    PADDLE_THROW("Only compile time support this method");
+    PADDLE_THROW(platform::errors::PreconditionNotMet(
+        "GetRepeatedDims method only ban be used in compile time."));
   }
   void SetDim(Variable* var, const DDim& dim) {
@@ -862,15 +920,22 @@ class RuntimeInferShapeContext : public InferShapeContext {
     } else if (var->IsType<SelectedRows>()) {
       var->GetMutable<SelectedRows>()->set_height(dim[0]);
     } else {
-      PADDLE_THROW("Variable type_id %s, expect LoDTensor/SelectedRows.",
-                   ToTypeName(var->Type()));
+      PADDLE_THROW(platform::errors::Unimplemented(
+          "Variable type error, expect LoDTensor or SelectedRows, but received "
+          "(%s).",
+          ToTypeName(var->Type())));
     }
   }
   void SetDims(const std::vector<Variable*>& vars,
                const std::vector<DDim>& dims) {
     size_t length = vars.size();
-    PADDLE_ENFORCE_EQ(length, dims.size());
+    PADDLE_ENFORCE_EQ(length, dims.size(),
+                      platform::errors::InvalidArgument(
+                          "The number of input variables do not match the "
+                          "number of input dimensions, the number of variables "
+                          "is %zu, the number of dimensions is %zu.",
+                          length, dims.size()));
     for (size_t i = 0; i < length; ++i) {
       if (vars[i] == nullptr) {
         continue;
@@ -881,7 +946,8 @@ class RuntimeInferShapeContext : public InferShapeContext {
   void SetRepeatedDims(const std::string& name,
                        const std::vector<DDim>& dims) override {
-    PADDLE_THROW("Only compile time support this method");
+    PADDLE_THROW(platform::errors::PreconditionNotMet(
+        "SetRepeatedDims method only can be used in compile time."));
   }
   std::vector<proto::VarType::Type> GetVarTypes(
@@ -901,16 +967,19 @@ class RuntimeInferShapeContext : public InferShapeContext {
  private:
   const std::vector<Variable*>& InputVars(const std::string& name) const {
     auto it = ctx_.inputs.find(name);
-    PADDLE_ENFORCE(it != ctx_.inputs.end(),
-                   "Operator %s does not have the input %s.", op_.Type(), name);
+    PADDLE_ENFORCE_NE(
+        it, ctx_.inputs.end(),
+        platform::errors::NotFound(
+            "Operator (%s) does not have the input (%s).", op_.Type(), name));
     return it->second;
   }
   const std::vector<Variable*>& OutputVars(const std::string& name) const {
     auto it = ctx_.outputs.find(name);
-    PADDLE_ENFORCE(it != ctx_.outputs.end(),
-                   "Operator %s does not have the outputs %s.", op_.Type(),
-                   name);
+    PADDLE_ENFORCE_NE(
+        it, ctx_.outputs.end(),
+        platform::errors::NotFound(
+            "Operator (%s) does not have the outputs (%s).", op_.Type(), name));
     return it->second;
   }
@@ -928,10 +997,14 @@ static void CheckTensorNANOrInf(const std::string& op_type,
       tensor.type() != proto::VarType::FP64) {
     return;
   }
-  PADDLE_ENFORCE(!framework::TensorContainsInf(tensor),
-                 "Operator %s output Tensor %s contains Inf", op_type, name);
-  PADDLE_ENFORCE(!framework::TensorContainsNAN(tensor),
-                 "Operator %s output Tensor %s contains NAN", op_type, name);
+  PADDLE_ENFORCE_NE(
+      framework::TensorContainsInf(tensor), true,
+      platform::errors::Fatal("Operator %s output Tensor %s contains Inf.",
+                              op_type, name));
+  PADDLE_ENFORCE_NE(
+      framework::TensorContainsNAN(tensor), true,
+      platform::errors::Fatal("Operator %s output Tensor %s contains NAN.",
+                              op_type, name));
 }
 void OperatorWithKernel::RuntimeInferShape(const Scope& scope,
@@ -1074,10 +1147,11 @@ void OperatorWithKernel::ChooseKernel(const RuntimeContext& ctx,
   // check if op[type] has kernel registered.
   auto& all_op_kernels = AllOpKernels();
   auto kernels_iter = all_op_kernels.find(type_);
-  if (kernels_iter == all_op_kernels.end()) {
-    PADDLE_THROW(
-        "There are no kernels which are registered in the %s operator.", type_);
-  }
+  PADDLE_ENFORCE_NE(
+      kernels_iter, all_op_kernels.end(),
+      platform::errors::Unavailable(
+          "There are no kernels which are registered in the %s operator.",
+          type_));
   OpKernelMap& kernels = kernels_iter->second;
@@ -1131,10 +1205,10 @@ void OperatorWithKernel::ChooseKernel(const RuntimeContext& ctx,
     kernel_iter = kernels.find(expected_kernel_key);
   }
 #endif
-  if (kernel_iter == kernels.end()) {
-    PADDLE_THROW("op %s does not have kernel for %s", type_,
-                 KernelTypeToString(expected_kernel_key));
-  }
+  PADDLE_ENFORCE_NE(kernel_iter, kernels.end(),
+                    platform::errors::NotFound(
+                        "Operator (%s) does not have kernel for %s.", type_,
+                        KernelTypeToString(expected_kernel_key)));
   std::lock_guard<std::mutex> lock(cache_update_mutex_);
   if (kernel_type_.get() == nullptr || kernel_func_.get() == nullptr) {
@@ -1149,13 +1223,14 @@ void OperatorWithKernel::TransferInplaceVarsBack(
   for (auto& var_name : inplace_vars) {
     VLOG(3) << "share inplace var " + var_name + " back to it's original scope";
     auto* origin_var = scope.FindVar(var_name);
-    PADDLE_ENFORCE_NOT_NULL(origin_var, "The var[%s] should not be nullptr.",
-                            var_name);
+    PADDLE_ENFORCE_NOT_NULL(origin_var,
+                            platform::errors::InvalidArgument(
+                                "The variable[%s] is nullptr.", var_name));
    auto* original_tensor =
         GetMutableLoDTensorOrSelectedRowsValueFromVar(origin_var);
     auto* var = transfer_scope.FindVar(var_name);
-    PADDLE_ENFORCE_NOT_NULL(var, "The var[%s] should not be nullptr.",
-                            var_name);
+    PADDLE_ENFORCE_NOT_NULL(var, platform::errors::InvalidArgument(
+                                     "The variable[%s] is nullptr.", var_name));
     auto* transformed_tensor = GetLoDTensorOrSelectedRowsValueFromVar(*var);
     auto original_dims = original_tensor->dims();
     original_tensor->ShareDataWith(*transformed_tensor);
@@ -1380,9 +1455,11 @@ proto::VarType::Type OperatorWithKernel::IndicateVarDataType(
   ParseInputDataType(ctx, name, &data_type);
   PADDLE_ENFORCE_NE(
       data_type, dafault_data_type,
-      "The Input Variable(%s) of %s Op used to determine kernel data type "
-      "is empty or not LoDTensor or SelectedRows or LoDTensorArray.",
-      name, Type());
+      platform::errors::InvalidArgument(
+          "The Input Variable(%s) of (%s) Operator used to determine kernel "
+          "data type is empty or not LoDTensor or SelectedRows or "
+          "LoDTensorArray.",
+          name, Type()));
   return data_type;
 }
...
@@ -495,9 +495,9 @@ TEST(IndicateVarDataTypeTest, other) {
       EXPECT_TRUE(
           ex_msg.find(
               "The Input Variable(Other) of "
-              "indicate_other_data_type_test Op used to "
+              "(indicate_other_data_type_test) Operator used to "
               "determine kernel data type "
-              "is empty or not LoDTensor or SelectedRows or LoDTensorArray") !=
+              "is empty or not LoDTensor or SelectedRows or LoDTensorArray.") !=
           std::string::npos);
     }
     ASSERT_TRUE(caught);
...