未验证 提交 f71593d2 编写于 作者: S Steffy-zxf 提交者: GitHub

fix bug that was resulted by add object detection task (#577)

上级 ae9edc1c
......@@ -344,10 +344,6 @@ class BaseTask(object):
# set default phase
self.enter_phase("train")
@property
def base_main_program(self):
return self._base_main_program
@contextlib.contextmanager
def phase_guard(self, phase):
self.enter_phase(phase)
......@@ -397,7 +393,7 @@ class BaseTask(object):
self._build_env_start_event()
self.env.is_inititalized = True
self.env.main_program = clone_program(
self.base_main_program, for_test=False)
self._base_main_program, for_test=False)
self.env.startup_program = fluid.Program()
with fluid.program_guard(self.env.main_program,
......@@ -410,7 +406,6 @@ class BaseTask(object):
self.env.metrics = self._add_metrics()
if self.is_predict_phase or self.is_test_phase:
# Todo: paddle.fluid.core_avx.EnforceNotMet: Getting 'tensor_desc' is not supported by the type of var kCUDNNFwdAlgoCache. at
self.env.main_program = clone_program(
self.env.main_program, for_test=True)
hub.common.paddle_helper.set_op_attr(
......@@ -1063,10 +1058,8 @@ class BaseTask(object):
capacity=64,
use_double_buffer=True,
iterable=True)
data_reader = data_loader.set_sample_list_generator(
self.reader, self.places)
# data_reader = data_loader.set_batch_generator(
# self.reader, places=self.places)
data_reader = data_loader.set_batch_generator(
self.reader, places=self.places)
else:
data_feeder = fluid.DataFeeder(
feed_list=self.feed_list, place=self.place)
......@@ -1083,28 +1076,12 @@ class BaseTask(object):
step_run_state.run_step = 1
num_batch_examples = len(batch)
if self.return_numpy == 2:
fetch_result = self.exe.run(
self.main_program_to_be_run,
feed=batch,
fetch_list=self.fetch_list,
return_numpy=False)
# fetch_result = [x if isinstance(x,fluid.LoDTensor) else np.array(x) for x in fetch_result]
fetch_result = [
x if hasattr(x, 'recursive_sequence_lengths') else
np.array(x) for x in fetch_result
]
elif self.return_numpy:
fetch_result = self.exe.run(
self.main_program_to_be_run,
feed=batch,
fetch_list=self.fetch_list)
else:
fetch_result = self.exe.run(
self.main_program_to_be_run,
feed=batch,
fetch_list=self.fetch_list,
return_numpy=False)
fetch_result = self.exe.run(
self.main_program_to_be_run,
feed=batch,
fetch_list=self.fetch_list,
return_numpy=self.return_numpy)
if not self.return_numpy:
fetch_result = [np.array(x) for x in fetch_result]
for index, result in enumerate(fetch_result):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册