未验证 提交 23df6c44 编写于 作者: Q Qiao Longfei 提交者: GitHub

Add get lod for debug (#7375)

* add GetLoD for debug

* add LoDToString

* optimize if

* typo

* add lod_tensor to operator's dependency
上级 3423022e
...@@ -47,7 +47,7 @@ cc_test(op_proto_maker_test SRCS op_proto_maker_test.cc DEPS op_proto_maker) ...@@ -47,7 +47,7 @@ cc_test(op_proto_maker_test SRCS op_proto_maker_test.cc DEPS op_proto_maker)
cc_library(op_info SRCS op_info.cc DEPS attribute framework_proto) cc_library(op_info SRCS op_info.cc DEPS attribute framework_proto)
cc_library(shape_inference SRCS shape_inference.cc DEPS ddim attribute device_context) cc_library(shape_inference SRCS shape_inference.cc DEPS ddim attribute device_context)
cc_library(operator SRCS operator.cc DEPS op_info device_context tensor scope glog cc_library(operator SRCS operator.cc DEPS op_info device_context tensor scope glog
shape_inference data_transform) shape_inference data_transform lod_tensor)
cc_test(operator_test SRCS operator_test.cc DEPS operator op_registry init) cc_test(operator_test SRCS operator_test.cc DEPS operator op_registry init)
cc_library(proto_desc SRCS var_desc.cc op_desc.cc block_desc.cc program_desc.cc DEPS shape_inference op_info operator glog) cc_library(proto_desc SRCS var_desc.cc op_desc.cc block_desc.cc program_desc.cc DEPS shape_inference op_info operator glog)
......
...@@ -69,6 +69,12 @@ std::ostream &operator<<(std::ostream &os, const LoDTensor &t) { ...@@ -69,6 +69,12 @@ std::ostream &operator<<(std::ostream &os, const LoDTensor &t) {
return os; return os;
} }
std::string LoDToString(const LoD &lod) {
std::ostringstream stream;
stream << lod;
return stream.str();
}
LoD SliceInLevel(const LoD &in, size_t level, size_t elem_begin, LoD SliceInLevel(const LoD &in, size_t level, size_t elem_begin,
size_t elem_end) { size_t elem_end) {
PADDLE_ENFORCE_LT(level, in.size()); PADDLE_ENFORCE_LT(level, in.size());
......
...@@ -60,6 +60,8 @@ using LoD = std::vector<Vector<size_t>>; ...@@ -60,6 +60,8 @@ using LoD = std::vector<Vector<size_t>>;
std::ostream& operator<<(std::ostream& os, const LoD& lod); std::ostream& operator<<(std::ostream& os, const LoD& lod);
std::ostream& operator<<(std::ostream& os, const LoDTensor& t); std::ostream& operator<<(std::ostream& os, const LoDTensor& t);
std::string LoDToString(const LoD& lod);
LoD SliceInLevel(const LoD& in, size_t level, size_t elem_begin, LoD SliceInLevel(const LoD& in, size_t level, size_t elem_begin,
size_t elem_end); size_t elem_end);
/* /*
......
...@@ -80,7 +80,9 @@ static DDim GetDims(const Scope& scope, const std::string& name) { ...@@ -80,7 +80,9 @@ static DDim GetDims(const Scope& scope, const std::string& name) {
Variable* var = scope.FindVar(name); Variable* var = scope.FindVar(name);
if (var == nullptr) { if (var == nullptr) {
return DDim({-1}); return DDim({-1});
} else if (var->IsType<LoDTensor>()) { }
if (var->IsType<LoDTensor>()) {
return var->Get<LoDTensor>().dims(); return var->Get<LoDTensor>().dims();
} else if (var->IsType<SelectedRows>()) { } else if (var->IsType<SelectedRows>()) {
return var->Get<SelectedRows>().GetCompleteDims(); return var->Get<SelectedRows>().GetCompleteDims();
...@@ -89,6 +91,21 @@ static DDim GetDims(const Scope& scope, const std::string& name) { ...@@ -89,6 +91,21 @@ static DDim GetDims(const Scope& scope, const std::string& name) {
} }
} }
static LoD GetLoD(const Scope& scope, const std::string& name) {
Variable* var = scope.FindVar(name);
auto default_lod = LoD({{}});
if (var == nullptr) {
return default_lod;
}
if (var->IsType<LoDTensor>()) {
return var->Get<LoDTensor>().lod();
} else {
return default_lod;
}
}
std::string OperatorBase::Input(const std::string& name) const { std::string OperatorBase::Input(const std::string& name) const {
auto& ins = Inputs(name); auto& ins = Inputs(name);
PADDLE_ENFORCE_LE(ins.size(), 1UL, PADDLE_ENFORCE_LE(ins.size(), 1UL,
...@@ -130,7 +147,8 @@ std::string OperatorBase::DebugStringEx(const Scope* scope) const { ...@@ -130,7 +147,8 @@ std::string OperatorBase::DebugStringEx(const Scope* scope) const {
for (size_t i = 0; i < input.second.size(); ++i) { for (size_t i = 0; i < input.second.size(); ++i) {
ss << input.second[i]; ss << input.second[i];
if (scope) { if (scope) {
ss << "(" << GetDims(*scope, input.second[i]) << ")"; ss << "[" << GetDims(*scope, input.second[i]) << "]";
ss << "(" << GetLoD(*scope, input.second[i]) << ")";
} }
if (i != input.second.size() - 1) { if (i != input.second.size() - 1) {
ss << ", "; ss << ", ";
...@@ -149,7 +167,8 @@ std::string OperatorBase::DebugStringEx(const Scope* scope) const { ...@@ -149,7 +167,8 @@ std::string OperatorBase::DebugStringEx(const Scope* scope) const {
for (size_t i = 0; i < output.second.size(); ++i) { for (size_t i = 0; i < output.second.size(); ++i) {
ss << output.second[i]; ss << output.second[i];
if (scope) { if (scope) {
ss << "(" << GetDims(*scope, output.second[i]) << ")"; ss << "[" << GetDims(*scope, output.second[i]) << "]";
ss << "(" << GetLoD(*scope, output.second[i]) << ")";
} }
if (i != output.second.size() - 1) { if (i != output.second.size() - 1) {
ss << ", "; ss << ", ";
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册