提交 82741fb6 编写于 作者: X xixiaoyao

fix bugs

上级 fa625b30
......@@ -19,7 +19,7 @@ import json
from paddle import fluid
import paddlepalm.utils.basic_helper as helper
from paddlepalm.utils import reader_helper, saver
from paddlepalm.distribute import gpu_dev_count
from paddlepalm.distribute import gpu_dev_count, data_feeder
# from paddlepalm.default_settings import *
DEBUG=False
......@@ -143,6 +143,7 @@ class Trainer(object):
self._train_init_prog = train_init_prog
with fluid.program_guard(train_prog, train_init_prog):
net_inputs = reader_helper.create_net_inputs(input_attrs, async=False)
self._net_inputs = net_inputs
# build backbone and task layers
# bb_output_vars = self._backbone.build(net_inputs, scope_name='__paddlepalm_')
......@@ -189,6 +190,7 @@ class Trainer(object):
bb_fetches = {k: v.name for k,v in bb_output_vars.items()}
task_fetches = {k: v.name for k,v in task_output_vars.items()}
self._fetches = task_fetches
# fetches = task_fetches
# fetches['__task_id'] = net_inputs['__task_id'].name
......@@ -269,6 +271,17 @@ class Trainer(object):
model_path,
main_program=self._train_init_prog)
def train(self, iterator, print_steps=5):
feed_batch_process_fn = reader_helper.create_feed_batch_process_fn(self._net_inputs)
distribute_feeder = data_feeder(iterator, feed_batch_process_fn)
fetch_names, fetch_list = zip(*self._fetches.items())
for feed, mask in distribute_feeder:
rt_outputs = self.exe.run(self._train_prog, feed=feed, fetch_list=fetch_list)
while mask.pop() == False:
rt_outputs.pop()
def _build_head(self, net_inputs, phase, scope=""):
if phase == 'train':
output_vars = self._task_head.build(net_inputs, scope_name=scope)
......
......@@ -22,6 +22,19 @@ from paddle import fluid
from paddle.fluid import layers
def create_feed_batch_process_fn(net_inputs):
def feed_batch_process_fn(data):
temp = {}
for q, var in net_inputs.items():
if isinstance(var, str) or isinstance(var, unicode):
temp[var] = data[q]
else:
temp[var.name] = data[q]
return temp
return feed_batch_process_fn
def _check_and_adapt_shape_dtype(rt_val, attr, message=""):
if not isinstance(rt_val, np.ndarray):
rt_val = np.array(rt_val)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册