提交 60b7bf6f 编写于 作者: D dongdaxiang

add infer_from_dataset for inference

上级 030c7e7e
...@@ -147,6 +147,8 @@ class HogwildWorker : public CPUWorkerBase { ...@@ -147,6 +147,8 @@ class HogwildWorker : public CPUWorkerBase {
std::vector<std::string> op_names_; std::vector<std::string> op_names_;
std::vector<OperatorBase*> ops_; std::vector<OperatorBase*> ops_;
Scope* thread_scope_; Scope* thread_scope_;
HogwildWorkerParameter param_;
std::vector<std::string> skip_ops_;
}; };
class DownpourWorker : public HogwildWorker { class DownpourWorker : public HogwildWorker {
......
...@@ -22,6 +22,12 @@ namespace framework { ...@@ -22,6 +22,12 @@ namespace framework {
void HogwildWorker::Initialize(const TrainerDesc& desc) { void HogwildWorker::Initialize(const TrainerDesc& desc) {
fetch_config_ = desc.fetch_config(); fetch_config_ = desc.fetch_config();
param_ = desc.hogwild_param();
skip_ops_.resize(param_.skip_ops_size());
LOG(WARNING) << "skip op size: " << skip_ops_.size();
for (size_t i = 0; i < param_.skip_ops_size(); ++i) {
skip_ops_[i] = param_.skip_ops(i);
}
} }
void HogwildWorker::CreateThreadOperators(const ProgramDesc& program) { void HogwildWorker::CreateThreadOperators(const ProgramDesc& program) {
...@@ -92,9 +98,18 @@ void HogwildWorker::TrainFilesWithProfiler() { ...@@ -92,9 +98,18 @@ void HogwildWorker::TrainFilesWithProfiler() {
read_time += timeline.ElapsedSec(); read_time += timeline.ElapsedSec();
total_time += timeline.ElapsedSec(); total_time += timeline.ElapsedSec();
for (size_t i = 0; i < ops_.size(); ++i) { for (size_t i = 0; i < ops_.size(); ++i) {
bool need_skip = false;
for (auto t = 0u; t < skip_ops_.size(); ++t) {
if (ops_[i]->Type().find(skip_ops_[t]) != std::string::npos) {
need_skip = true;
break;
}
}
timeline.Start(); timeline.Start();
VLOG(3) << "Going to run op " << op_name[i]; VLOG(3) << "Going to run op " << op_name[i];
ops_[i]->Run(*thread_scope_, place_); if (!need_skip) {
ops_[i]->Run(*thread_scope_, place_);
}
VLOG(3) << "Op " << op_name[i] << " Finished"; VLOG(3) << "Op " << op_name[i] << " Finished";
timeline.Pause(); timeline.Pause();
op_total_time[i] += timeline.ElapsedSec(); op_total_time[i] += timeline.ElapsedSec();
...@@ -127,7 +142,16 @@ void HogwildWorker::TrainFiles() { ...@@ -127,7 +142,16 @@ void HogwildWorker::TrainFiles() {
int cur_batch; int cur_batch;
while ((cur_batch = device_reader_->Next()) > 0) { while ((cur_batch = device_reader_->Next()) > 0) {
for (auto& op : ops_) { for (auto& op : ops_) {
op->Run(*thread_scope_, place_); bool need_skip = false;
for (auto t = 0u; t < skip_ops_.size(); ++t) {
if (op->Type().find(skip_ops_[t]) != std::string::npos) {
need_skip = true;
break;
}
}
if (!need_skip) {
op->Run(*thread_scope_, place_);
}
} }
PrintFetchVars(); PrintFetchVars();
......
...@@ -89,6 +89,9 @@ class Hogwild(DeviceWorker): ...@@ -89,6 +89,9 @@ class Hogwild(DeviceWorker):
trainer_desc(TrainerDesc): a TrainerDesc object trainer_desc(TrainerDesc): a TrainerDesc object
""" """
trainer_desc.device_worker_name = "HogwildWorker" trainer_desc.device_worker_name = "HogwildWorker"
if self.infer_:
# just ignore feed op for inference model
trainer_desc.hogwild_param.skip_ops.extend(["feed"])
class DownpourSGD(DeviceWorker): class DownpourSGD(DeviceWorker):
......
...@@ -659,10 +659,12 @@ class Executor(object): ...@@ -659,10 +659,12 @@ class Executor(object):
def infer_from_dataset(self, def infer_from_dataset(self,
program=None, program=None,
dataset=None, dataset=None,
fetch_list=None,
scope=None, scope=None,
thread=0, thread=0,
opt_info=None): debug=False,
fetch_list=None,
fetch_info=None,
print_period=100):
""" """
The document of infer_from_dataset is almost the same as The document of infer_from_dataset is almost the same as
train_from_dataset, except that in distributed training, train_from_dataset, except that in distributed training,
...@@ -711,8 +713,8 @@ class Executor(object): ...@@ -711,8 +713,8 @@ class Executor(object):
fetch_list=fetch_list, fetch_list=fetch_list,
fetch_info=fetch_info, fetch_info=fetch_info,
print_period=print_period) print_period=print_period)
trainer._gen_trainer_desc()
trainer._set_infer(True) trainer._set_infer(True)
trainer._gen_trainer_desc()
dataset._prepare_to_run() dataset._prepare_to_run()
if debug: if debug:
self._dump_debug_info(program=program, trainer=trainer) self._dump_debug_info(program=program, trainer=trainer)
......
...@@ -98,7 +98,7 @@ class DistMultiTrainer(TrainerDesc): ...@@ -98,7 +98,7 @@ class DistMultiTrainer(TrainerDesc):
super(DistMultiTrainer, self)._gen_trainer_desc() super(DistMultiTrainer, self)._gen_trainer_desc()
self.proto_desc.class_name = "DistMultiTrainer" self.proto_desc.class_name = "DistMultiTrainer"
if self.program_ == None: if self.program_ == None:
print("None program") raise RuntimeError("None Program")
self.device_worker_._set_infer(self.infer_) self.device_worker_._set_infer(self.infer_)
self.device_worker_._set_program(self.program_) self.device_worker_._set_program(self.program_)
self.device_worker_._gen_worker_desc(self.proto_desc) self.device_worker_._gen_worker_desc(self.proto_desc)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册