diff --git a/paddlepalm/mtl_controller.py b/paddlepalm/mtl_controller.py index 1d229c753a6815031e980a6025a03c671e88dd8c..53c57a09f2e7957f0e57b205fd594f6937a208e9 100755 --- a/paddlepalm/mtl_controller.py +++ b/paddlepalm/mtl_controller.py @@ -453,7 +453,7 @@ class Controller(object): # compute loss task_id_var = net_inputs['__task_id'] - task_id_vec = layers.one_hot(task_id_var, num_instances) + task_id_vec = fluid.one_hot(task_id_var, num_instances) losses = fluid.layers.concat([task_output_vars[inst.name+'/loss'] for inst in instances], axis=0) loss = layers.reduce_sum(task_id_vec * losses)