提交 16dc8056 编写于 作者: W WangXi 提交者: sandyhouse

fix

上级 7215ca10
...@@ -41,7 +41,7 @@ class OffloadHelper(object): ...@@ -41,7 +41,7 @@ class OffloadHelper(object):
idx, idx,
type='cast', type='cast',
inputs={'X': src_var}, inputs={'X': src_var},
outputs={'Y': dst_var}, outputs={'Out': dst_var},
attrs={ attrs={
'in_dtype': src_var.dtype, 'in_dtype': src_var.dtype,
'out_dtype': dst_var.dtype, 'out_dtype': dst_var.dtype,
...@@ -166,7 +166,7 @@ class OffloadHelper(object): ...@@ -166,7 +166,7 @@ class OffloadHelper(object):
assert param in param_to_fp16 assert param in param_to_fp16
fp16_param_name = param_to_fp16[param] fp16_param_name = param_to_fp16[param]
fp16_param_var = block.var[fp16_param_name] fp16_param_var = block.var(fp16_param_name)
fp16_param_var.persistable = True fp16_param_var.persistable = True
self._insert_cast_op(block, idx + 1, param, self._insert_cast_op(block, idx + 1, param,
param_to_fp16[param]) param_to_fp16[param])
...@@ -177,7 +177,7 @@ class OffloadHelper(object): ...@@ -177,7 +177,7 @@ class OffloadHelper(object):
# step3.4: remove cast op # step3.4: remove cast op
if op.type == 'cast': if op.type == 'cast':
input_name = op.desc.input_arg_names[0] input_name = op.desc.input_arg_names()[0]
if input_name in param_to_idx: if input_name in param_to_idx:
block._remove_op(idx, sync=False) block._remove_op(idx, sync=False)
continue continue
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册