未验证 提交 e48dd92f 编写于 作者: F flame 提交者: GitHub

bug fix (#17392)

fix secure bug
上级 34369944
...@@ -36,6 +36,8 @@ struct DataReader { ...@@ -36,6 +36,8 @@ struct DataReader {
tensor.lod.front().push_back(data.size()); tensor.lod.front().push_back(data.size());
tensor.data.Resize(data.size() * sizeof(int64_t)); tensor.data.Resize(data.size() * sizeof(int64_t));
CHECK(tensor.data.data() != nullptr);
CHECK(data.data() != nullptr);
memcpy(tensor.data.data(), data.data(), data.size() * sizeof(int64_t)); memcpy(tensor.data.data(), data.data(), data.size() * sizeof(int64_t));
tensor.shape.push_back(data.size()); tensor.shape.push_back(data.size());
tensor.shape.push_back(1); tensor.shape.push_back(1);
...@@ -87,7 +89,12 @@ TEST(Analyzer_Text_Classification, profile) { ...@@ -87,7 +89,12 @@ TEST(Analyzer_Text_Classification, profile) {
CHECK_EQ(output.lod.size(), 0UL); CHECK_EQ(output.lod.size(), 0UL);
LOG(INFO) << "output.dtype: " << output.dtype; LOG(INFO) << "output.dtype: " << output.dtype;
std::stringstream ss; std::stringstream ss;
for (int i = 0; i < 5; i++) { int num_data = 1;
for (auto i : output.shape) {
num_data *= i;
}
for (int i = 0; i < num_data; i++) {
ss << static_cast<float *>(output.data.data())[i] << " "; ss << static_cast<float *>(output.data.data())[i] << " ";
} }
LOG(INFO) << "output.data summary: " << ss.str(); LOG(INFO) << "output.data summary: " << ss.str();
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册