提交 cb0b81f9 编写于 作者: Y Yang Yang

add << lodtensor

上级 f879ef23
...@@ -43,6 +43,22 @@ std::ostream &operator<<(std::ostream &os, const LoD &lod) { ...@@ -43,6 +43,22 @@ std::ostream &operator<<(std::ostream &os, const LoD &lod) {
return os; return os;
} }
std::ostream &operator<<(std::ostream &os, const LoDTensor &t) {
PADDLE_ENFORCE(platform::is_cpu_place(t.place()));
PADDLE_ENFORCE(t.type().hash_code() == typeid(float).hash_code());
os << "dim: " << t.dims() << "\n";
os << "lod: " << t.lod() << "\n";
// only print first ten elements
int64_t size = t.numel() < 10 ? t.numel() : 10;
for (int64_t i = 0; i < size; ++i) {
os << t.data<float>()[i] << " ";
}
return os;
}
LoD SliceLevels(const LoD &in, size_t level_begin, size_t level_end) { LoD SliceLevels(const LoD &in, size_t level_begin, size_t level_end) {
LoD new_lod; LoD new_lod;
new_lod.reserve(level_end - level_begin); new_lod.reserve(level_end - level_begin);
......
...@@ -58,6 +58,7 @@ using Vector = thrust::host_vector< ...@@ -58,6 +58,7 @@ using Vector = thrust::host_vector<
using LoD = std::vector<Vector<size_t>>; using LoD = std::vector<Vector<size_t>>;
std::ostream& operator<<(std::ostream& os, const LoD& lod); std::ostream& operator<<(std::ostream& os, const LoD& lod);
std::ostream& operator<<(std::ostream& os, const LoDTensor& t);
/* /*
* Slice levels from a LoD. * Slice levels from a LoD.
......
...@@ -156,11 +156,12 @@ class ParallelDoGradOp : public OperatorBase { ...@@ -156,11 +156,12 @@ class ParallelDoGradOp : public OperatorBase {
for (auto &s : Inputs(framework::GradVarName(kOutputs))) { for (auto &s : Inputs(framework::GradVarName(kOutputs))) {
LOG(INFO) << s; LOG(INFO) << s;
LOG(INFO) << scope.FindVar(s)->Get<LoDTensor>().dims(); LOG(INFO) << scope.FindVar(s)->Get<LoDTensor>();
for (auto *sub_scope : sub_scopes) { for (auto *sub_scope : sub_scopes) {
LOG(INFO) << sub_scope->FindVar(s)->Get<LoDTensor>().dims(); LOG(INFO) << sub_scope->FindVar(s)->Get<LoDTensor>();
} }
} }
// exe run // exe run
for (int place_idx = 0; place_idx < places.size(); ++place_idx) { for (int place_idx = 0; place_idx < places.size(); ++place_idx) {
VLOG(3) << "Run " << place_idx; VLOG(3) << "Run " << place_idx;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册