diff --git a/paddle/fluid/framework/device_worker.h b/paddle/fluid/framework/device_worker.h
index 2a6c2fa5b386aff3c63fc9c23b40271587f6c946..a7a8663ec3b1c436104f53b6db833bd26f6722f0 100644
--- a/paddle/fluid/framework/device_worker.h
+++ b/paddle/fluid/framework/device_worker.h
@@ -147,6 +147,8 @@ class HogwildWorker : public CPUWorkerBase {
   std::vector<std::string> op_names_;
   std::vector<OperatorBase*> ops_;
   Scope* thread_scope_;
+  HogwildWorkerParameter param_;
+  std::vector<std::string> skip_ops_;
 };
 
 class DownpourWorker : public HogwildWorker {
diff --git a/paddle/fluid/framework/hogwild_worker.cc b/paddle/fluid/framework/hogwild_worker.cc
index 1f5389c9c5e5389b2a4a54f33799a19fcdda1901..d1a262f7d08caf7a57adb54ad118438bc7c9d60c 100644
--- a/paddle/fluid/framework/hogwild_worker.cc
+++ b/paddle/fluid/framework/hogwild_worker.cc
@@ -22,6 +22,12 @@ namespace framework {
 
 void HogwildWorker::Initialize(const TrainerDesc& desc) {
   fetch_config_ = desc.fetch_config();
+  param_ = desc.hogwild_param();
+  skip_ops_.resize(param_.skip_ops_size());
+  LOG(WARNING) << "skip op size: " << skip_ops_.size();
+  for (size_t i = 0; i < param_.skip_ops_size(); ++i) {
+    skip_ops_[i] = param_.skip_ops(i);
+  }
 }
 
 void HogwildWorker::CreateThreadOperators(const ProgramDesc& program) {
@@ -92,9 +98,18 @@ void HogwildWorker::TrainFilesWithProfiler() {
     read_time += timeline.ElapsedSec();
     total_time += timeline.ElapsedSec();
     for (size_t i = 0; i < ops_.size(); ++i) {
+      bool need_skip = false;
+      for (auto t = 0u; t < skip_ops_.size(); ++t) {
+        if (ops_[i]->Type().find(skip_ops_[t]) != std::string::npos) {
+          need_skip = true;
+          break;
+        }
+      }
       timeline.Start();
       VLOG(3) << "Going to run op " << op_name[i];
-      ops_[i]->Run(*thread_scope_, place_);
+      if (!need_skip) {
+        ops_[i]->Run(*thread_scope_, place_);
+      }
       VLOG(3) << "Op " << op_name[i] << " Finished";
       timeline.Pause();
       op_total_time[i] += timeline.ElapsedSec();
@@ -127,7 +142,16 @@ void HogwildWorker::TrainFiles() {
   int cur_batch;
   while ((cur_batch = device_reader_->Next()) > 0) {
     for (auto& op : ops_) {
-      op->Run(*thread_scope_, place_);
+      bool need_skip = false;
+      for (auto t = 0u; t < skip_ops_.size(); ++t) {
+        if (op->Type().find(skip_ops_[t]) != std::string::npos) {
+          need_skip = true;
+          break;
+        }
+      }
+      if (!need_skip) {
+        op->Run(*thread_scope_, place_);
+      }
     }
 
     PrintFetchVars();
diff --git a/python/paddle/fluid/device_worker.py b/python/paddle/fluid/device_worker.py
index b636725eb359cf3fe0dd9f93146df297b552b958..21d50749f6cdcf9c3fa13e687d59ada67334355d 100644
--- a/python/paddle/fluid/device_worker.py
+++ b/python/paddle/fluid/device_worker.py
@@ -89,6 +89,9 @@
             trainer_desc(TrainerDesc): a TrainerDesc object
         """
         trainer_desc.device_worker_name = "HogwildWorker"
+        if self.infer_:
+            # just ignore feed op for inference model
+            trainer_desc.hogwild_param.skip_ops.extend(["feed"])
 
 
 class DownpourSGD(DeviceWorker):
diff --git a/python/paddle/fluid/executor.py b/python/paddle/fluid/executor.py
index d609b88fe5e755c8e2d6a065843324898978df38..c75a613d9a74886977c4be751f811280a81af772 100644
--- a/python/paddle/fluid/executor.py
+++ b/python/paddle/fluid/executor.py
@@ -659,10 +659,12 @@
     def infer_from_dataset(self,
                            program=None,
                            dataset=None,
-                           fetch_list=None,
                            scope=None,
                            thread=0,
-                           opt_info=None):
+                           debug=False,
+                           fetch_list=None,
+                           fetch_info=None,
+                           print_period=100):
         """
         The document of infer_from_dataset is almost the same as
         train_from_dataset, except that in distributed training,
@@ -711,8 +713,8 @@
             fetch_list=fetch_list,
             fetch_info=fetch_info,
             print_period=print_period)
-        trainer._gen_trainer_desc()
         trainer._set_infer(True)
+        trainer._gen_trainer_desc()
         dataset._prepare_to_run()
         if debug:
             self._dump_debug_info(program=program, trainer=trainer)
diff --git a/python/paddle/fluid/trainer_desc.py b/python/paddle/fluid/trainer_desc.py
index 5a19907ab699e7d9f58f8f589334fddec3c1abaf..9b6ec8fb2e10c2cfa02a1e5109505d3192a6e887 100644
--- a/python/paddle/fluid/trainer_desc.py
+++ b/python/paddle/fluid/trainer_desc.py
@@ -98,7 +98,7 @@ class DistMultiTrainer(TrainerDesc):
         super(DistMultiTrainer, self)._gen_trainer_desc()
         self.proto_desc.class_name = "DistMultiTrainer"
         if self.program_ == None:
-            print("None program")
+            raise RuntimeError("None Program")
         self.device_worker_._set_infer(self.infer_)
         self.device_worker_._set_program(self.program_)
         self.device_worker_._gen_worker_desc(self.proto_desc)
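
Note: a minimal usage sketch (not part of the patch) of the reworked infer_from_dataset signature. The inference program, the prepared dataset, and the "loss" fetch variable are assumptions used only for illustration; everything else follows the parameters added above.

    import paddle.fluid as fluid

    place = fluid.CPUPlace()
    exe = fluid.Executor(place)
    exe.run(fluid.default_startup_program())

    # `dataset` is a prepared fluid dataset and `loss` a variable fetched from
    # the inference program; both are assumed to be built elsewhere.
    exe.infer_from_dataset(program=fluid.default_main_program(),
                           dataset=dataset,
                           debug=False,
                           fetch_list=[loss],
                           fetch_info=["loss"],
                           print_period=100)

Because infer_from_dataset now calls trainer._set_infer(True) before _gen_trainer_desc(), the Hogwild device worker sees infer_ set and appends "feed" to hogwild_param.skip_ops, so HogwildWorker skips feed ops while running the rest of the inference program.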