提交 dba69287 编写于 作者: T tensor-tang

fix lod tensor

test=develop
上级 6094a723
...@@ -74,7 +74,7 @@ def create_lod_tensor(data, recursive_seq_lens, place): ...@@ -74,7 +74,7 @@ def create_lod_tensor(data, recursive_seq_lens, place):
assert [ assert [
new_recursive_seq_lens new_recursive_seq_lens
] == recursive_seq_lens, "data and recursive_seq_lens do not match" ] == recursive_seq_lens, "data and recursive_seq_lens do not match"
flattened_data = np.concatenate(data, axis=0).astype("int64") flattened_data = np.concatenate(data, axis=0)
flattened_data = flattened_data.reshape([len(flattened_data), 1]) flattened_data = flattened_data.reshape([len(flattened_data), 1])
return create_lod_tensor(flattened_data, recursive_seq_lens, place) return create_lod_tensor(flattened_data, recursive_seq_lens, place)
elif isinstance(data, np.ndarray): elif isinstance(data, np.ndarray):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册