提交 acfcb1c5 编写于 作者: W wangxiao

change to fluid.one_hot

上级 ec741693
...@@ -453,7 +453,7 @@ class Controller(object): ...@@ -453,7 +453,7 @@ class Controller(object):
# compute loss # compute loss
task_id_var = net_inputs['__task_id'] task_id_var = net_inputs['__task_id']
task_id_vec = layers.one_hot(task_id_var, num_instances) task_id_vec = fluid.one_hot(task_id_var, num_instances)
losses = fluid.layers.concat([task_output_vars[inst.name+'/loss'] for inst in instances], axis=0) losses = fluid.layers.concat([task_output_vars[inst.name+'/loss'] for inst in instances], axis=0)
loss = layers.reduce_sum(task_id_vec * losses) loss = layers.reduce_sum(task_id_vec * losses)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册