未验证 提交 691ced87 编写于 作者: D Dong Daxiang 提交者: GitHub

Refactor fetch handler (#21264)

* fix fetch handler problem and refactor
when a user define FetchHandler class, he or she should initialize a handler
with variable dict. the key of a variable dict is a user defined name,
the value of a variable dict is a Varaible generated from python API.

For each fetching, a user should implement handler function in which
fetched_result_dict will be available and the user can access the fetched value
with user defined keys.
上级 f1b09ba3
......@@ -185,6 +185,9 @@ void Executor::RunFromDataset(std::shared_ptr<TrainerBase> trainer) {
// training and finalize training
VLOG(3) << "Trainer starts to run";
trainer->Run();
}
void Executor::ReleaseTrainer(std::shared_ptr<TrainerBase> trainer) {
VLOG(3) << "Trainer going to finalize";
trainer->Finalize();
}
......
......@@ -120,6 +120,8 @@ class Executor {
Scope* scope, Dataset* dataset);
void RunFromDataset(std::shared_ptr<TrainerBase> trainer);
void ReleaseTrainer(std::shared_ptr<TrainerBase> trainer);
const platform::Place GetPlace() const { return place_; }
private:
......
......@@ -77,14 +77,12 @@ void MultiTrainer::Run() {
workers_[thidx].get()));
}
}
}
void MultiTrainer::Finalize() {
for (auto& th : threads_) {
th.join();
}
root_scope_->DropKids();
}
void MultiTrainer::Finalize() { root_scope_->DropKids(); }
} // end namespace framework
} // end namespace paddle
......@@ -1324,10 +1324,13 @@ All parameter, weight, gradient are variables in Paddle.
.def("close", &Executor::Close)
.def("run_from_dataset", &Executor::RunFromDataset,
py::call_guard<py::gil_scoped_release>())
.def("release_trainer", &Executor::ReleaseTrainer,
py::call_guard<py::gil_scoped_release>())
.def("init_for_dataset",
[](Executor &self, const ProgramDesc &prog,
const std::string &trainer_desc, Scope *scope,
Dataset *dataset) -> std::shared_ptr<TrainerBase> {
pybind11::gil_scoped_release release;
return self.InitForDataset(prog, trainer_desc, scope, dataset);
})
.def("run_from_dataset",
......
......@@ -395,23 +395,28 @@ def _as_lodtensor(data, place):
class FetchHandler(object):
def __init__(self, fetch_target_names, period_secs=60, return_np=True):
self.fetch_target_names = fetch_target_names
def __init__(self, var_dict=None, period_secs=60):
assert var_dict != None
self.var_dict = var_dict
self.period_secs = period_secs
self.return_np = return_np
def handler(self, fetch_target_vars):
return
def handler(self, res_dict):
for key in res_dict:
if type(res_dict[key]) is np.ndarray:
sys.stdout.write("{}[0]: {} ".format(key, res_dict[key][0]))
sys.stdout.write("\n")
@staticmethod
def help():
print("""
class FetchHandlerExamlpe(FetchHandler):
def handler(self, fetch_target_vars):
b_auc = fetch_target_vars[0]
g_auc = fetch_target_vars[1]
print("b_auc: {}, g_auc: {} at time: {}".format(b_auc, g_auc, time.ctime()))
class FetchHandlerExample(FetchHandler):
def handler(self, res_dict):
print(res_dict["auc"])
print("auc: {}, {}".format(res_dict["auc"], time.ctime()))
auc = Variable()
var_dict = {"auc": auc}
handler = FetchHandlerExample(var_dict=var_dict)
""")
......@@ -1010,13 +1015,13 @@ class Executor(object):
scope0 = trainer_instance.get_worker_scope(0)
fetch_monitor = FetchHandlerMonitor(scope0, fetch_handler)
fetch_monitor.start()
self._default_executor.run_from_dataset(trainer_instance)
fetch_monitor.stop()
self._default_executor.release_trainer(trainer_instance)
else:
self._default_executor.run_from_dataset(trainer_instance)
self._default_executor.release_trainer(trainer_instance)
dataset._dynamic_adjust_after_train()
dataset._finish_to_run()
......
......@@ -17,6 +17,7 @@ from __future__ import print_function
import time
import unittest
import numpy as np
from paddle.fluid.framework import Program
import paddle.fluid.core as core
import paddle.fluid as fluid
......@@ -29,20 +30,35 @@ class TestFetchHandler(unittest.TestCase):
table = np.random.random((3, 10)).astype("float32")
prog = Program()
block = prog.current_block()
var_emb = block.create_var(name='emb', type=core.VarDesc.VarType.FP32)
var_emb3 = block.create_var(name='emb3', type=core.VarDesc.VarType.FP32)
class FH(fluid.executor.FetchHandler):
def handler(self, fetch_target_vars):
assert len(fetch_target_vars) == 1
def handler(self, fetch_dict):
assert len(fetch_dict) == 1
table_var = scope.var('emb').get_tensor()
table_var.set(table, place)
fh = FH(['emb'], period_secs=2, return_np=True)
fh = FH(var_dict={'emb': var_emb}, period_secs=2)
fm = fluid.trainer_factory.FetchHandlerMonitor(scope, fh)
fm.start()
time.sleep(10)
time.sleep(3)
fm.stop()
default_fh = fluid.executor.FetchHandler(
var_dict={'emb': var_emb,
'emb2': None,
'emb3': var_emb3},
period_secs=1)
default_fm = fluid.trainer_factory.FetchHandlerMonitor(scope,
default_fh)
default_fm.start()
time.sleep(5)
default_fm.stop()
if __name__ == "__main__":
unittest.main()
......@@ -15,11 +15,15 @@
import threading
import time
import logging
import numpy as np
logging.basicConfig()
from .trainer_desc import MultiTrainer, DistMultiTrainer, PipelineTrainer
from .device_worker import Hogwild, DownpourSGD, Section
from .framework import Variable
from multiprocessing import Process, Manager
__all__ = ["TrainerFactory", "FetchHandler", "FetchHandlerMonitor"]
......@@ -93,68 +97,74 @@ class FetchHandlerMonitor(object):
def __init__(self, scope, handler):
self.fetch_instance = handler
self.fetch_thread = threading.Thread(
target=self.handler_decorator,
args=(scope, self.fetch_instance.handler))
target=self.handler_launch_func, args=(scope, self.fetch_instance))
self.running_lock = threading.Lock()
self.running = False
def handler_launch_func(self, scope, handler):
fetch_instance = handler
period_secs = fetch_instance.period_secs
var_name_to_key = {}
for key in fetch_instance.var_dict:
if isinstance(fetch_instance.var_dict[key], Variable):
var_name_to_key[fetch_instance.var_dict[key].name] = key
else:
logging.warning("the value of {} is not a Variable".format(key))
var_name_to_key["None.var"] = key
elapsed_secs = 0
while True:
self.running_lock.acquire()
if self.running == False:
break
if elapsed_secs < period_secs:
# TODO(guru4elephant): needs customized condition
time.sleep(1)
elapsed_secs += 1
else:
elapsed_secs = 0
fetch_dict = {}
for key in var_name_to_key:
var = scope.find_var(key)
fetch_dict[key] = var
if var == None:
logging.warning("{} value currently not available".
format(var_name_to_key[key]))
res_dict = {}
for key in fetch_dict:
user_name = var_name_to_key[key]
if fetch_dict[key] == None:
res_dict[user_name] = None
continue
else:
res_dict[user_name] = fetch_dict[key].get_tensor()
lod = res_dict[user_name].lod()
if len(lod) > 0:
raise RuntimeError("Some of your fetched tensors \
hold LoD information. \
They can not be completely cast \
to Python ndarray. We can \
not return LoDTensor itself directly, \
please choose another targets")
if res_dict[user_name]._is_initialized():
res_dict[user_name] = np.array(res_dict[user_name])
else:
res_dict[user_name] = None
fetch_instance.handler(res_dict)
self.running_lock.release()
def start(self):
"""
start monitor,
it will start a monitor thread.
"""
self.running_lock.acquire()
self.running = True
self.running_lock.release()
self.fetch_thread.setDaemon(True)
self.fetch_thread.start()
def handler_decorator(self, fetch_scope, fetch_handler):
"""
decorator of handler,
Args:
fetch_scope(Scope): fetch scope
fetch_handler(Handler): fetch handler
"""
fetch_target_names = self.fetch_instance.fetch_target_names
period_secs = self.fetch_instance.period_secs
elapsed_secs = 0
while True:
while self.running and elapsed_secs >= period_secs:
elapsed_secs = 0
fetch_vars = [
fetch_scope.find_var(varname)
for varname in fetch_target_names
]
if None in fetch_vars:
continue
fetch_tensors = [var.get_tensor() for var in fetch_vars]
if self.fetch_instance.return_np:
fetch_nps = []
for tensor in fetch_tensors:
lod = tensor.lod()
if len(lod) > 0:
raise RuntimeError(
"Some of your fetched tensors hold LoD information. \
They can not be completely cast to Python ndarray. We can not \
return LoDTensor itself directly, please choose another targets"
)
if tensor._is_initialized():
fetch_nps.append(np.array(tensor))
else:
fetch_nps.append(None)
fetch_handler(fetch_nps)
else:
fetch_handler(fetch_tensors)
else:
time.sleep(1)
elapsed_secs += 1
def stop(self):
self.running_lock.acquire()
self.running = False
self.running_lock.release()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册