提交 5a44a789 编写于 作者: S Steffy-zxf 提交者: wuzewu

Update tasks whose output type is list (#55)

上级 a271703a
......@@ -163,7 +163,7 @@ class BasicTask(object):
with fluid.program_guard(self.env.main_program,
self._base_startup_program):
with fluid.unique_name.guard(self.env.UNG):
self.env.output = self._build_net()
self.env.outputs = self._build_net()
if self.is_train_phase or self.is_test_phase:
self.env.label = self._add_label()
self.env.loss = self._add_loss()
......@@ -200,8 +200,11 @@ class BasicTask(object):
self.env.main_program = t_program
self.env.loss = self.env.main_program.global_block().vars[
self.env.loss.name]
self.env.output = self.env.main_program.global_block().vars[
self.env.output.name]
outputs_name = [var.name for var in self.env.outputs]
self.env.outputs = [
self.env.main_program.global_block().vars[name]
for name in outputs_name
]
metrics_name = [var.name for var in self.env.metrics]
self.env.metrics = [
self.env.main_program.global_block().vars[name]
......@@ -344,10 +347,10 @@ class BasicTask(object):
return self.env.label
@property
def output(self):
def outputs(self):
if not self.env.is_inititalized:
self._build_env()
return self.env.output
return self.env.outputs
@property
def metrics(self):
......@@ -378,7 +381,7 @@ class BasicTask(object):
def fetch_list(self):
if self.is_train_phase or self.is_test_phase:
return [metric.name for metric in self.metrics] + [self.loss.name]
return [self.output.name]
return [output.name for output in self.outputs]
def _build_env_start_event(self):
pass
......@@ -611,18 +614,18 @@ class ClassifierTask(BasicTask):
name="cls_out_b", initializer=fluid.initializer.Constant(0.)),
act="softmax")
return logits
return [logits]
def _add_label(self):
return fluid.layers.data(name="label", dtype="int64", shape=[1])
def _add_loss(self):
ce_loss = fluid.layers.cross_entropy(
input=self.output, label=self.label)
input=self.outputs[0], label=self.label)
return fluid.layers.mean(x=ce_loss)
def _add_metrics(self):
return [fluid.layers.accuracy(input=self.output, label=self.label)]
return [fluid.layers.accuracy(input=self.outputs[0], label=self.label)]
def _build_env_end_event(self):
with self.log_writer.mode(self.phase) as logw:
......@@ -720,7 +723,7 @@ class TextClassifierTask(ClassifierTask):
name="cls_out_b", initializer=fluid.initializer.Constant(0.)),
act="softmax")
return logits
return [logits]
class SequenceLabelTask(BasicTask):
......@@ -773,7 +776,7 @@ class SequenceLabelTask(BasicTask):
logits = fluid.layers.flatten(logits, axis=2)
logits = fluid.layers.softmax(logits)
self.num_labels = logits.shape[1]
return logits
return [logits]
def _add_label(self):
label = fluid.layers.data(
......@@ -782,7 +785,8 @@ class SequenceLabelTask(BasicTask):
def _add_loss(self):
labels = fluid.layers.flatten(self.label, axis=2)
ce_loss = fluid.layers.cross_entropy(input=self.output, label=labels)
ce_loss = fluid.layers.cross_entropy(
input=self.outputs[0], label=labels)
loss = fluid.layers.mean(x=ce_loss)
return loss
......@@ -865,7 +869,7 @@ class SequenceLabelTask(BasicTask):
return [metric.name for metric in self.metrics] + [self.loss.name]
elif self.is_predict_phase:
return [self.ret_infers.name] + [self.seq_len.name]
return [self.output.name]
return [output.name for output in self.outputs]
class MultiLabelClassifierTask(ClassifierTask):
......@@ -928,7 +932,7 @@ class MultiLabelClassifierTask(ClassifierTask):
label_split = fluid.layers.split(self.label, self.num_classes, dim=-1)
total_loss = fluid.layers.fill_constant(
shape=[1], value=0.0, dtype='float64')
for index, probs in enumerate(self.output):
for index, probs in enumerate(self.outputs):
ce_loss = fluid.layers.cross_entropy(
input=probs, label=label_split[index])
total_loss += fluid.layers.reduce_sum(ce_loss)
......@@ -939,7 +943,7 @@ class MultiLabelClassifierTask(ClassifierTask):
label_split = fluid.layers.split(self.label, self.num_classes, dim=-1)
# metrics change to auc of every class
eval_list = []
for index, probs in enumerate(self.output):
for index, probs in enumerate(self.outputs):
current_auc, _, _ = fluid.layers.auc(
input=probs, label=label_split[index])
eval_list.append(current_auc)
......@@ -1015,4 +1019,4 @@ class MultiLabelClassifierTask(ClassifierTask):
def fetch_list(self):
if self.is_train_phase or self.is_test_phase:
return [metric.name for metric in self.metrics] + [self.loss.name]
return self.output
return self.outputs
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册