提交 2afe82fe 编写于 作者: Q Qiao Longfei

fix ctr reader read svm data

test=develop
上级 488719ba
...@@ -213,7 +213,7 @@ void ReadSvmData(const DataDesc& data_desc, std::shared_ptr<Reader> reader, ...@@ -213,7 +213,7 @@ void ReadSvmData(const DataDesc& data_desc, std::shared_ptr<Reader> reader,
framework::LoD lod{lod_data}; framework::LoD lod{lod_data};
lod_tensor.set_lod(lod); lod_tensor.set_lod(lod);
int64_t* tensor_data = lod_tensor.mutable_data<int64_t>( int64_t* tensor_data = lod_tensor.mutable_data<int64_t>(
framework::make_ddim({1, static_cast<int64_t>(batch_feasign.size())}), framework::make_ddim({static_cast<int64_t>(batch_feasign.size()), 1}),
platform::CPUPlace()); platform::CPUPlace());
memcpy(tensor_data, batch_feasign.data(), memcpy(tensor_data, batch_feasign.data(),
batch_feasign.size() * sizeof(int64_t)); batch_feasign.size() * sizeof(int64_t));
...@@ -223,7 +223,7 @@ void ReadSvmData(const DataDesc& data_desc, std::shared_ptr<Reader> reader, ...@@ -223,7 +223,7 @@ void ReadSvmData(const DataDesc& data_desc, std::shared_ptr<Reader> reader,
// insert label tensor // insert label tensor
framework::LoDTensor label_tensor; framework::LoDTensor label_tensor;
auto* label_tensor_data = label_tensor.mutable_data<int64_t>( auto* label_tensor_data = label_tensor.mutable_data<int64_t>(
framework::make_ddim({1, static_cast<int64_t>(batch_label.size())}), framework::make_ddim({static_cast<int64_t>(batch_label.size()), 1}),
platform::CPUPlace()); platform::CPUPlace());
memcpy(label_tensor_data, batch_label.data(), memcpy(label_tensor_data, batch_label.data(),
batch_label.size() * sizeof(int64_t)); batch_label.size() * sizeof(int64_t));
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册