未验证 提交 073af623 编写于 作者: K Kexin Zhao 提交者: GitHub

add print lod_tensor int64 option (#11644)

上级 833a6f36
......@@ -51,8 +51,6 @@ std::ostream &operator<<(std::ostream &os, const LoD &lod) {
}
std::ostream &operator<<(std::ostream &os, const LoDTensor &t) {
PADDLE_ENFORCE(t.type().hash_code() == typeid(float).hash_code());
if (!platform::is_cpu_place(t.place())) {
LoDTensor tt;
framework::TensorCopy(t, platform::CPUPlace(), &tt);
......@@ -70,7 +68,13 @@ std::ostream &operator<<(std::ostream &os, const LoDTensor &t) {
// only print first ten elements
int64_t size = t.numel() < 10 ? t.numel() : 10;
for (int64_t i = 0; i < size; ++i) {
os << t.data<float>()[i] << " ";
if (t.type().hash_code() == typeid(float).hash_code()) {
os << t.data<float>()[i] << " ";
} else if (t.type().hash_code() == typeid(int64_t).hash_code()) {
os << t.data<int64_t>()[i] << " ";
} else {
PADDLE_THROW("LoDTensor data type not in [float, int64_t]");
}
}
return os;
......
......@@ -26,6 +26,20 @@
namespace paddle {
namespace framework {
TEST(LoD, PrintLoDTensor) {
LoDTensor tensor1;
tensor1.mutable_data<float>(platform::CPUPlace());
tensor1.data<float>()[0] = 0.2;
tensor1.data<float>()[1] = 0.5;
LOG(INFO) << tensor1;
LoDTensor tensor2;
tensor2.mutable_data<int64_t>(platform::CPUPlace());
tensor2.data<int64_t>()[0] = 1;
tensor2.data<int64_t>()[1] = 2;
LOG(INFO) << tensor2;
}
TEST(LoD, data) {
LoD lod{{0, 1, 2}};
lod.push_back({0, 2, 4, 5});
......@@ -37,7 +51,7 @@ TEST(LoD, data) {
}
}
TEST(LodExpand, test) {
TEST(LoD, ExpandLoD) {
LoD lod{{0, 2}};
LoDTensor tensor;
tensor.set_lod(lod);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册