diff --git a/python/paddle/distributed/fleet/meta_optimizers/sharding/offload_helper.py b/python/paddle/distributed/fleet/meta_optimizers/sharding/offload_helper.py
index 0ed0599faaf7574bbb6ba0786d5572639feca23b..76803818453c929d1dbf503159c59e1325c8337e 100644
--- a/python/paddle/distributed/fleet/meta_optimizers/sharding/offload_helper.py
+++ b/python/paddle/distributed/fleet/meta_optimizers/sharding/offload_helper.py
@@ -41,7 +41,7 @@ class OffloadHelper(object):
             idx,
             type='cast',
             inputs={'X': src_var},
-            outputs={'Y': dst_var},
+            outputs={'Out': dst_var},
             attrs={
                 'in_dtype': src_var.dtype,
                 'out_dtype': dst_var.dtype,
@@ -166,7 +166,7 @@ class OffloadHelper(object):
                 assert param in param_to_fp16
                 fp16_param_name = param_to_fp16[param]
-                fp16_param_var = block.var[fp16_param_name]
+                fp16_param_var = block.var(fp16_param_name)
                 fp16_param_var.persistable = True
                 self._insert_cast_op(block, idx + 1, param, param_to_fp16[param])
@@ -177,7 +177,7 @@ class OffloadHelper(object):
             # step3.4: remove cast op
             if op.type == 'cast':
-                input_name = op.desc.input_arg_names[0]
+                input_name = op.desc.input_arg_names()[0]
                 if input_name in param_to_idx:
                     block._remove_op(idx, sync=False)
                     continue
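
For context, all three fixes touch PaddlePaddle's static-graph Block/OpDesc API: Block.var(name) is a method rather than a subscriptable mapping, the cast op exposes its result under the 'Out' output slot (not 'Y'), and OpDesc.input_arg_names() is a method that must be called. Below is a minimal sketch, assuming a local PaddlePaddle install; the program, tensor, and variable names are illustrative and not taken from this PR:

import paddle

paddle.enable_static()

main_prog = paddle.static.Program()
with paddle.static.program_guard(main_prog):
    x = paddle.static.data(name='x', shape=[4, 8], dtype='float32')
    y = paddle.cast(x, 'float16')

block = main_prog.global_block()

# Block.var(name) is a method call; block.var[name] fails because a method
# is not subscriptable (the bug fixed in the second hunk).
x_var = block.var('x')
print(x_var.dtype)

# The cast op writes its result to the 'Out' slot, so building the op with
# outputs={'Y': dst_var} (the pre-fix code) does not match the op's proto.
cast_op = next(op for op in block.ops if op.type == 'cast')
print(cast_op.output('Out'))

# input_arg_names is a method on OpDesc, so it must be called before indexing.
print(cast_op.desc.input_arg_names())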