Commit 87a8caa8 authored by tangwei12, committed by Dong Daxiang

Refactor fetch handler (#21264) (#21537)

* fix fetch handler problem and refactor
When a user defines a FetchHandler class, he or she should initialize the handler
with a variable dict. The key of the variable dict is a user-defined name,
and the value is a Variable generated from the Python API.

For each fetch, the user should implement the handler function, in which
fetched_result_dict is available so that the fetched values can be accessed
via the user-defined keys.
Parent 20a09375
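A minimal sketch of the interface described above, mirroring the unit test further down in this diff (the names prog, auc_var, and AucFetchHandler are illustrative, not part of this commit):

import paddle.fluid as fluid
from paddle.fluid.framework import Program

# Build a Variable the same way the unit test below does; in real training
# this would be an output of the user's network (e.g. an auc variable).
prog = Program()
auc_var = prog.current_block().create_var(
    name='auc', type=fluid.core.VarDesc.VarType.FP32)

class AucFetchHandler(fluid.executor.FetchHandler):
    def handler(self, res_dict):
        # res_dict maps the user-defined key back to a numpy array,
        # or to None while the variable is not yet available in the scope.
        if res_dict.get("auc") is not None:
            print("auc: {}".format(res_dict["auc"]))

handler = AucFetchHandler(var_dict={"auc": auc_var}, period_secs=30)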
@@ -186,6 +186,9 @@ void Executor::RunFromDataset(std::shared_ptr<TrainerBase> trainer) {
   // training and finalize training
   VLOG(3) << "Trainer starts to run";
   trainer->Run();
+}
+
+void Executor::ReleaseTrainer(std::shared_ptr<TrainerBase> trainer) {
   VLOG(3) << "Trainer going to finalize";
   trainer->Finalize();
 }
...
@@ -126,6 +126,8 @@ class Executor {
                       Scope* scope, Dataset* dataset);
   void RunFromDataset(std::shared_ptr<TrainerBase> trainer);
+  void ReleaseTrainer(std::shared_ptr<TrainerBase> trainer);
+
   const platform::Place GetPlace() const { return place_; }

  private:
...
@@ -77,14 +77,12 @@ void MultiTrainer::Run() {
                                      workers_[thidx].get()));
     }
   }
+}
+
+void MultiTrainer::Finalize() {
   for (auto& th : threads_) {
     th.join();
   }
+  root_scope_->DropKids();
 }
-
-void MultiTrainer::Finalize() { root_scope_->DropKids(); }
 }  // end namespace framework
 }  // end namespace paddle
@@ -1364,10 +1364,13 @@ All parameter, weight, gradient are variables in Paddle.
       .def("close", &Executor::Close)
       .def("run_from_dataset", &Executor::RunFromDataset,
            py::call_guard<py::gil_scoped_release>())
+      .def("release_trainer", &Executor::ReleaseTrainer,
+           py::call_guard<py::gil_scoped_release>())
       .def("init_for_dataset",
            [](Executor &self, const ProgramDesc &prog,
               const std::string &trainer_desc, Scope *scope,
               Dataset *dataset) -> std::shared_ptr<TrainerBase> {
+             pybind11::gil_scoped_release release;
              return self.InitForDataset(prog, trainer_desc, scope, dataset);
            })
       .def("run_from_dataset",
...
@@ -395,23 +395,28 @@ def _as_lodtensor(data, place):
 class FetchHandler(object):
-    def __init__(self, fetch_target_names, period_secs=60, return_np=True):
-        self.fetch_target_names = fetch_target_names
+    def __init__(self, var_dict=None, period_secs=60):
+        assert var_dict != None
+        self.var_dict = var_dict
         self.period_secs = period_secs
-        self.return_np = return_np

-    def handler(self, fetch_target_vars):
-        return
+    def handler(self, res_dict):
+        for key in res_dict:
+            if type(res_dict[key]) is np.ndarray:
+                sys.stdout.write("{}[0]: {} ".format(key, res_dict[key][0]))
+        sys.stdout.write("\n")

     @staticmethod
     def help():
         print("""
-class FetchHandlerExamlpe(FetchHandler):
-    def handler(self, fetch_target_vars):
-        b_auc = fetch_target_vars[0]
-        g_auc = fetch_target_vars[1]
-        print("b_auc: {}, g_auc: {} at time: {}".format(b_auc, g_auc, time.ctime()))
+class FetchHandlerExample(FetchHandler):
+    def handler(self, res_dict):
+        print(res_dict["auc"])
+        print("auc: {}, {}".format(res_dict["auc"], time.ctime()))
+
+auc = Variable()
+var_dict = {"auc": auc}
+handler = FetchHandlerExample(var_dict=var_dict)
 """)
@@ -1019,13 +1024,13 @@ class Executor(object):
            scope0 = trainer_instance.get_worker_scope(0)
            fetch_monitor = FetchHandlerMonitor(scope0, fetch_handler)
            fetch_monitor.start()
            self._default_executor.run_from_dataset(trainer_instance)
            fetch_monitor.stop()
+            self._default_executor.release_trainer(trainer_instance)
        else:
            self._default_executor.run_from_dataset(trainer_instance)
+            self._default_executor.release_trainer(trainer_instance)

        dataset._dynamic_adjust_after_train()
        dataset._finish_to_run()
...
@@ -17,6 +17,7 @@ from __future__ import print_function
 import time
 import unittest
 import numpy as np
+from paddle.fluid.framework import Program

 import paddle.fluid.core as core
 import paddle.fluid as fluid
@@ -29,20 +30,35 @@ class TestFetchHandler(unittest.TestCase):
         table = np.random.random((3, 10)).astype("float32")

+        prog = Program()
+        block = prog.current_block()
+        var_emb = block.create_var(name='emb', type=core.VarDesc.VarType.FP32)
+        var_emb3 = block.create_var(name='emb3', type=core.VarDesc.VarType.FP32)
+
         class FH(fluid.executor.FetchHandler):
-            def handler(self, fetch_target_vars):
-                assert len(fetch_target_vars) == 1
+            def handler(self, fetch_dict):
+                assert len(fetch_dict) == 1

         table_var = scope.var('emb').get_tensor()
         table_var.set(table, place)
-        fh = FH(['emb'], period_secs=2, return_np=True)
+        fh = FH(var_dict={'emb': var_emb}, period_secs=2)
         fm = fluid.trainer_factory.FetchHandlerMonitor(scope, fh)

         fm.start()
-        time.sleep(10)
+        time.sleep(3)
         fm.stop()

+        default_fh = fluid.executor.FetchHandler(
+            var_dict={'emb': var_emb,
+                      'emb2': None,
+                      'emb3': var_emb3},
+            period_secs=1)
+        default_fm = fluid.trainer_factory.FetchHandlerMonitor(scope,
+                                                               default_fh)
+        default_fm.start()
+        time.sleep(5)
+        default_fm.stop()
+

 if __name__ == "__main__":
     unittest.main()
@@ -15,11 +15,15 @@
 import threading
 import time
+import logging
 import numpy as np

+logging.basicConfig()
+
 from .trainer_desc import MultiTrainer, DistMultiTrainer, PipelineTrainer
 from .device_worker import Hogwild, DownpourSGD, Section
+from .framework import Variable
+from multiprocessing import Process, Manager

 __all__ = ["TrainerFactory", "FetchHandler", "FetchHandlerMonitor"]
@@ -93,68 +97,74 @@ class FetchHandlerMonitor(object):
     def __init__(self, scope, handler):
         self.fetch_instance = handler
         self.fetch_thread = threading.Thread(
-            target=self.handler_decorator,
-            args=(scope, self.fetch_instance.handler))
+            target=self.handler_launch_func, args=(scope, self.fetch_instance))
+        self.running_lock = threading.Lock()
         self.running = False

+    def handler_launch_func(self, scope, handler):
+        fetch_instance = handler
+        period_secs = fetch_instance.period_secs
+        var_name_to_key = {}
+        for key in fetch_instance.var_dict:
+            if isinstance(fetch_instance.var_dict[key], Variable):
+                var_name_to_key[fetch_instance.var_dict[key].name] = key
+            else:
+                logging.warning("the value of {} is not a Variable".format(key))
+                var_name_to_key["None.var"] = key
+        elapsed_secs = 0
+        while True:
+            self.running_lock.acquire()
+            if self.running == False:
+                break
+            if elapsed_secs < period_secs:
+                # TODO(guru4elephant): needs customized condition
+                time.sleep(1)
+                elapsed_secs += 1
+            else:
+                elapsed_secs = 0
+                fetch_dict = {}
+                for key in var_name_to_key:
+                    var = scope.find_var(key)
+                    fetch_dict[key] = var
+                    if var == None:
+                        logging.warning("{} value currently not available".
+                                        format(var_name_to_key[key]))
+                res_dict = {}
+                for key in fetch_dict:
+                    user_name = var_name_to_key[key]
+                    if fetch_dict[key] == None:
+                        res_dict[user_name] = None
+                        continue
+                    else:
+                        res_dict[user_name] = fetch_dict[key].get_tensor()
+
+                        lod = res_dict[user_name].lod()
+                        if len(lod) > 0:
+                            raise RuntimeError("Some of your fetched tensors \
+                                                hold LoD information. \
+                                                They can not be completely cast \
+                                                to Python ndarray. We can \
+                                                not return LoDTensor itself directly, \
+                                                please choose another targets")
+
+                        if res_dict[user_name]._is_initialized():
+                            res_dict[user_name] = np.array(res_dict[user_name])
+                        else:
+                            res_dict[user_name] = None
+                fetch_instance.handler(res_dict)
+            self.running_lock.release()
+
     def start(self):
         """
         start monitor,
         it will start a monitor thread.
         """
+        self.running_lock.acquire()
         self.running = True
+        self.running_lock.release()
         self.fetch_thread.setDaemon(True)
         self.fetch_thread.start()

-    def handler_decorator(self, fetch_scope, fetch_handler):
-        """
-        decorator of handler,
-        Args:
-            fetch_scope(Scope): fetch scope
-            fetch_handler(Handler): fetch handler
-        """
-        fetch_target_names = self.fetch_instance.fetch_target_names
-        period_secs = self.fetch_instance.period_secs
-
-        elapsed_secs = 0
-        while True:
-            while self.running and elapsed_secs >= period_secs:
-                elapsed_secs = 0
-                fetch_vars = [
-                    fetch_scope.find_var(varname)
-                    for varname in fetch_target_names
-                ]
-                if None in fetch_vars:
-                    continue
-                fetch_tensors = [var.get_tensor() for var in fetch_vars]
-
-                if self.fetch_instance.return_np:
-                    fetch_nps = []
-                    for tensor in fetch_tensors:
-                        lod = tensor.lod()
-                        if len(lod) > 0:
-                            raise RuntimeError(
-                                "Some of your fetched tensors hold LoD information. \
-                                They can not be completely cast to Python ndarray. We can not \
-                                return LoDTensor itself directly, please choose another targets"
-                            )
-                        if tensor._is_initialized():
-                            fetch_nps.append(np.array(tensor))
-                        else:
-                            fetch_nps.append(None)
-                    fetch_handler(fetch_nps)
-                else:
-                    fetch_handler(fetch_tensors)
-            else:
-                time.sleep(1)
-                elapsed_secs += 1
-
     def stop(self):
+        self.running_lock.acquire()
         self.running = False
+        self.running_lock.release()
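
For reference, a minimal self-contained sketch of driving FetchHandlerMonitor directly against a scope, mirroring the unit test above (the table contents, sleep durations, and variable name 'emb' are illustrative):

import time
import numpy as np
import paddle.fluid as fluid
from paddle.fluid.framework import Program

scope = fluid.core.Scope()
place = fluid.CPUPlace()

# Declare the Variable whose value we want to watch.
prog = Program()
var_emb = prog.current_block().create_var(
    name='emb', type=fluid.core.VarDesc.VarType.FP32)

# Put a tensor named 'emb' into the scope so the monitor has something to fetch.
table = np.random.random((3, 10)).astype("float32")
scope.var('emb').get_tensor().set(table, place)

# The default FetchHandler prints "emb[0]: ..." once per period.
fh = fluid.executor.FetchHandler(var_dict={'emb': var_emb}, period_secs=1)
fm = fluid.trainer_factory.FetchHandlerMonitor(scope, fh)
fm.start()
time.sleep(3)
fm.stop()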