提交 60b7bf6f 编写于 作者: D dongdaxiang

add infer_from_dataset for inference

上级 030c7e7e
......@@ -147,6 +147,8 @@ class HogwildWorker : public CPUWorkerBase {
std::vector<std::string> op_names_;
std::vector<OperatorBase*> ops_;
Scope* thread_scope_;
HogwildWorkerParameter param_;
std::vector<std::string> skip_ops_;
};
class DownpourWorker : public HogwildWorker {
......
......@@ -22,6 +22,12 @@ namespace framework {
// Caches the per-worker configuration carried by the trainer description.
//
// Copies the fetch config and the Hogwild-specific parameter message out of
// `desc`, then mirrors the repeated `skip_ops` field into `skip_ops_` so the
// run loops can consult a plain std::vector<std::string>.
//
// @param desc  Trainer description supplying `fetch_config()` and
//              `hogwild_param()`.
void HogwildWorker::Initialize(const TrainerDesc& desc) {
  fetch_config_ = desc.fetch_config();
  param_ = desc.hogwild_param();

  // Protobuf's generated `*_size()` accessor returns a signed int and the
  // indexed getter takes an int, so keep the loop index signed as well.
  // Comparing a size_t index against it is a signed/unsigned comparison that
  // breaks builds using -Wsign-compare with -Werror.
  const int skip_op_count = param_.skip_ops_size();
  skip_ops_.resize(skip_op_count);
  LOG(WARNING) << "skip op size: " << skip_ops_.size();
  for (int i = 0; i < skip_op_count; ++i) {
    skip_ops_[i] = param_.skip_ops(i);
  }
}
void HogwildWorker::CreateThreadOperators(const ProgramDesc& program) {
......@@ -92,9 +98,18 @@ void HogwildWorker::TrainFilesWithProfiler() {
read_time += timeline.ElapsedSec();
total_time += timeline.ElapsedSec();
for (size_t i = 0; i < ops_.size(); ++i) {
bool need_skip = false;
for (auto t = 0u; t < skip_ops_.size(); ++t) {
if (ops_[i]->Type().find(skip_ops_[t]) != std::string::npos) {
need_skip = true;
break;
}
}
timeline.Start();
VLOG(3) << "Going to run op " << op_name[i];
ops_[i]->Run(*thread_scope_, place_);
if (!need_skip) {
ops_[i]->Run(*thread_scope_, place_);
}
VLOG(3) << "Op " << op_name[i] << " Finished";
timeline.Pause();
op_total_time[i] += timeline.ElapsedSec();
......@@ -127,7 +142,16 @@ void HogwildWorker::TrainFiles() {
int cur_batch;
while ((cur_batch = device_reader_->Next()) > 0) {
for (auto& op : ops_) {
op->Run(*thread_scope_, place_);
bool need_skip = false;
for (auto t = 0u; t < skip_ops_.size(); ++t) {
if (op->Type().find(skip_ops_[t]) != std::string::npos) {
need_skip = true;
break;
}
}
if (!need_skip) {
op->Run(*thread_scope_, place_);
}
}
PrintFetchVars();
......
......@@ -89,6 +89,9 @@ class Hogwild(DeviceWorker):
trainer_desc(TrainerDesc): a TrainerDesc object
"""
trainer_desc.device_worker_name = "HogwildWorker"
if self.infer_:
# skip the feed op when running an inference model
trainer_desc.hogwild_param.skip_ops.extend(["feed"])
class DownpourSGD(DeviceWorker):
......
......@@ -659,10 +659,12 @@ class Executor(object):
def infer_from_dataset(self,
program=None,
dataset=None,
fetch_list=None,
scope=None,
thread=0,
opt_info=None):
debug=False,
fetch_list=None,
fetch_info=None,
print_period=100):
"""
The document of infer_from_dataset is almost the same as
train_from_dataset, except that in distributed training,
......@@ -711,8 +713,8 @@ class Executor(object):
fetch_list=fetch_list,
fetch_info=fetch_info,
print_period=print_period)
trainer._gen_trainer_desc()
trainer._set_infer(True)
trainer._gen_trainer_desc()
dataset._prepare_to_run()
if debug:
self._dump_debug_info(program=program, trainer=trainer)
......
......@@ -98,7 +98,7 @@ class DistMultiTrainer(TrainerDesc):
super(DistMultiTrainer, self)._gen_trainer_desc()
self.proto_desc.class_name = "DistMultiTrainer"
if self.program_ == None:
print("None program")
raise RuntimeError("None Program")
self.device_worker_._set_infer(self.infer_)
self.device_worker_._set_program(self.program_)
self.device_worker_._gen_worker_desc(self.proto_desc)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册