提交 77599415 编写于 作者: T tensor-tang

enable read dataset

上级 c00843f4
...@@ -14,7 +14,12 @@ limitations under the License. */ ...@@ -14,7 +14,12 @@ limitations under the License. */
#include <sys/time.h> #include <sys/time.h>
#include <time.h> #include <time.h>
#include <fstream>
#include <iostream>
#include <sstream>
#include <string>
#include <thread> // NOLINT #include <thread> // NOLINT
#include <vector>
#include "gflags/gflags.h" #include "gflags/gflags.h"
#include "gtest/gtest.h" #include "gtest/gtest.h"
#include "paddle/fluid/inference/tests/test_helper.h" #include "paddle/fluid/inference/tests/test_helper.h"
...@@ -31,16 +36,37 @@ inline double get_current_ms() { ...@@ -31,16 +36,37 @@ inline double get_current_ms() {
return 1e+3 * time.tv_sec + 1e-3 * time.tv_usec; return 1e+3 * time.tv_sec + 1e-3 * time.tv_usec;
} }
void read_data(
std::vector<std::vector<int64_t>>* out,
const std::string& filename = "/home/tangjian/paddle-tj/out.ids.txt") {
using namespace std; // NOLINT
fstream fin(filename);
string line;
out->clear();
while (getline(fin, line)) {
istringstream iss(line);
vector<int64_t> ids;
string field;
while (getline(iss, field, ' ')) {
ids.push_back(stoi(field));
}
out->push_back(ids);
}
}
TEST(inference, understand_sentiment) { TEST(inference, understand_sentiment) {
if (FLAGS_dirname.empty()) { if (FLAGS_dirname.empty()) {
LOG(FATAL) << "Usage: ./example --dirname=path/to/your/model"; LOG(FATAL) << "Usage: ./example --dirname=path/to/your/model";
} }
std::vector<std::vector<int64_t>> inputdatas;
read_data(&inputdatas);
LOG(INFO) << "---------- dataset size: " << inputdatas.size();
LOG(INFO) << "FLAGS_dirname: " << FLAGS_dirname << std::endl; LOG(INFO) << "FLAGS_dirname: " << FLAGS_dirname << std::endl;
std::string dirname = FLAGS_dirname; std::string dirname = FLAGS_dirname;
const bool model_combined = false; const bool model_combined = false;
int total_work = 100; int total_work = 10;
int num_threads = 10; int num_threads = 2;
int work_per_thread = total_work / num_threads; int work_per_thread = total_work / num_threads;
std::vector<std::unique_ptr<std::thread>> infer_threads; std::vector<std::unique_ptr<std::thread>> infer_threads;
for (int i = 0; i < num_threads; ++i) { for (int i = 0; i < num_threads; ++i) {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册