Unverified · Commit e0841332, authored by xiongkun, committed by GitHub

fix error (#56572)

Parent 99795a13
@@ -589,7 +589,7 @@ def _setitem_impl_(var, item, value):
             ProgramTranslator,
         )
 
-        ProgramTranslator.get_instance()._params_map.add(
+        ProgramTranslator.get_instance()._inplace_map.add(
            cur_block.program, var.desc.id(), output
        )
@@ -935,7 +935,7 @@ def _setitem_static(x, indices, values):
        if not paddle.in_dynamic_mode():
            # map var to the new output
-            paddle.jit.api.ProgramTranslator.get_instance()._params_map.add(
+            paddle.jit.api.ProgramTranslator.get_instance()._inplace_map.add(
                cur_block.program, x.desc.id(), output
            )
        return output
@@ -1008,7 +1008,7 @@ def _setitem_static(x, indices, values):
        )
        if not paddle.in_dynamic_mode():
            # map var to the new output
-            paddle.jit.api.ProgramTranslator.get_instance()._params_map.add(
+            paddle.jit.api.ProgramTranslator.get_instance()._inplace_map.add(
                cur_block.program, x.desc.id(), output
            )
        return output
......
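In these hunks, `_setitem_impl_` and `_setitem_static` now record in-place writes through the renamed `_inplace_map` instead of `_params_map`. As a rough illustration of the bookkeeping involved (the toy classes below are assumptions made for the sketch, not Paddle's real API), the map keys entries by program identity and redirects later reads of the old variable id to the replacement output:

# Minimal sketch (toy names, not Paddle's API) of the per-program mapping that
# _inplace_map maintains: after an in-place __setitem__ is lowered in static
# mode, later reads of the old variable id resolve to the new output variable.
class TinyInplaceMap:
    def __init__(self):
        # {program_key: {old_var_id: replacement_var}}
        self.params_dict = {}

    def add(self, program, var_id, new_var):
        self.params_dict.setdefault(id(program), {})[var_id] = new_var

    def get(self, program, var_id):
        return self.params_dict.get(id(program), {}).get(var_id)


class FakeProgram:
    """Stand-in for a static Program; only its identity matters here."""


prog = FakeProgram()
inplace_map = TinyInplaceMap()
inplace_map.add(prog, var_id=1, new_var="x_after_setitem")
assert inplace_map.get(prog, 1) == "x_after_setitem"  # convert_load-style lookup
assert inplace_map.get(prog, 2) is None               # untouched vars are unaffected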
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import astor
 from paddle.utils import gast
......
@@ -53,7 +53,7 @@ def convert_load(x):
        from paddle.jit.dy2static.program_translator import ProgramTranslator
 
-        new_var = ProgramTranslator.get_instance()._params_map.get(
+        new_var = ProgramTranslator.get_instance()._inplace_map.get(
            cur_block.program, x.desc.id()
        )
        if new_var is not None:
@@ -381,9 +381,13 @@ def _run_paddle_cond(
    _convert_tensor_arrray_if_necessary(helper, push_pop_names)
    pred = cast_bool_if_necessary(pred)
    init_args = helper.get(return_name_ids)
+    from paddle.jit.dy2static.program_translator import ProgramTranslator
+
+    inplace_map = ProgramTranslator.get_instance()._inplace_map
 
    def new_true_fn():
        # init args may contain mutable python container like [var, 2], we copy then like in while_loop
+        inplace_map_checkpoint = inplace_map.save_checkpoint()
        helper.set(
            return_name_ids,
            paddle.utils.copy_mutable_vars(init_args),
@@ -392,21 +396,22 @@ def _run_paddle_cond(
        # IfExpr will return a non-None return value, so we just return ret.
        # We assume normal return has no return value.
        if ret is None:
-            return helper.get(return_name_ids)
-        else:
-            return ret
+            ret = helper.get(return_name_ids)
+        inplace_map.restore_checkpoint(inplace_map_checkpoint)
+        return ret
 
    def new_false_fn():
        # init args may contain mutable python container like [var, 2], we copy then like in while_loop
+        inplace_map_checkpoint = inplace_map.save_checkpoint()
        helper.set(
            return_name_ids,
            paddle.utils.copy_mutable_vars(init_args),
        )
        ret = false_fn()
        if ret is None:
-            return helper.get(return_name_ids)
-        else:
-            return ret
+            ret = helper.get(return_name_ids)
+        inplace_map.restore_checkpoint(inplace_map_checkpoint)
+        return ret
 
    try:
        cond_outs = paddle.static.nn.cond(
......
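`_run_paddle_cond` now snapshots the inplace map before tracing each branch body and restores it afterwards, so a mapping created while tracing the true branch (for example from `x[:] = 0`) does not leak into the false branch. A self-contained sketch of that save/restore pattern, using a toy map instead of the real `InplaceMap`:

# Toy illustration of the checkpoint pattern around cond branches (an assumed
# simplification, not the actual Paddle helper): each branch starts from the
# same snapshot and restores it before returning.
class TinyInplaceMap:
    def __init__(self):
        self.params_dict = {}

    def save_checkpoint(self):
        return dict(self.params_dict.items())  # shallow copy, like InplaceMap

    def restore_checkpoint(self, checkpoint):
        self.params_dict = checkpoint


inplace_map = TinyInplaceMap()

def new_true_fn():
    ckpt = inplace_map.save_checkpoint()
    inplace_map.params_dict["x"] = "x_zeroed"  # e.g. x[:] = 0 traced in this branch
    ret = "true-branch outputs"
    inplace_map.restore_checkpoint(ckpt)       # do not leak into the false branch
    return ret

def new_false_fn():
    ckpt = inplace_map.save_checkpoint()
    ret = "false-branch outputs"
    inplace_map.restore_checkpoint(ckpt)
    return ret

new_true_fn()
new_false_fn()
assert "x" not in inplace_map.params_dict  # neither branch leaked its mapping

Note that `save_checkpoint` returns `dict(self.params_dict.items())`, a shallow copy: the checkpoint rolls back which entries exist, while objects already recorded before the checkpoint are still shared.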
@@ -1256,6 +1256,14 @@ class ConcreteProgram:
        )
 
 
+def _program_hash(program):
+    """
+    because program is not deleted while calling from_func_spec.
+    so it's ok to use id(program)
+    """
+    return id(program)
+
+
 class ParametersRecorder:
    def __init__(self):
        self.params_dict = {}
@@ -1263,35 +1271,28 @@ class ParametersRecorder:
    @synchronized
    def add(self, program, param):
        """use the default_program as key, append param the parameter list."""
-        key = self._program_hash(program)
+        key = _program_hash(program)
        if key not in self.params_dict:
            self.params_dict[key] = set()
 
        params = self.params_dict[key]
        params.add(param)
 
    def pop(self, program):
-        params = self.params_dict.get(self._program_hash(program))
+        params = self.params_dict.get(_program_hash(program))
        if params is None:
            return []
-        del self.params_dict[self._program_hash(program)]
+        del self.params_dict[_program_hash(program)]
        return list(params)
 
-    def _program_hash(self, program):
-        """
-        because program is not deleted while calling from_func_spec.
-        so it's ok to use id(program)
-        """
-        return id(program)
-
 
-class ParametersMap:
+class InplaceMap:
    def __init__(self):
        self.params_dict = {}
 
    @synchronized
    def add(self, program, id, param):
        """use the default_program as key, append param the parameter list."""
-        key = self._program_hash(program)
+        key = _program_hash(program)
        if key not in self.params_dict:
            self.params_dict[key] = {}
@@ -1299,7 +1300,7 @@ class ParametersMap:
        params[id] = param
 
    def get(self, program, id):
-        params = self.params_dict.get(self._program_hash(program))
+        params = self.params_dict.get(_program_hash(program))
        if params is None:
            return None
        if id not in params:
@@ -1313,12 +1314,19 @@ class ParametersMap:
            params[var.desc.id()] = root_var
        return root_var
 
-    def _program_hash(self, program):
-        """
-        because program is not deleted while calling from_func_spec.
-        so it's ok to use id(program)
-        """
-        return id(program)
+    def restore_checkpoint(self, checkpoint):
+        # InplaceMap is a nested effect.
+        # when enter a block, we should save a checkpoint
+        # when exit a block, we should restore a checkpoint
+        # for example:
+        # if cond > 0:
+        #     x [:] = 0
+        # return x
+        # x[:] only effect current cond block, we should restore in false block.
+        self.params_dict = checkpoint
+
+    def save_checkpoint(self):
+        return dict(self.params_dict.items())
 
 
 class FallbackProgramLayer:
@@ -1582,7 +1590,7 @@ class ProgramTranslator:
        self._initialized = True
        self._program_cache = ProgramCache()
        self._params_recorder = ParametersRecorder()
-        self._params_map = ParametersMap()
+        self._inplace_map = InplaceMap()
        self.enable_to_static = True
 
    def enable(self, enable_to_static):
......
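`_program_hash` is hoisted out of the recorder classes into a module-level helper shared by `ParametersRecorder` and the renamed `InplaceMap`; as its docstring notes, keying by `id(program)` is only safe because the program object stays alive during `from_func_spec`. A toy demonstration of that keying (names below are hypothetical, not Paddle code):

# Toy demonstration of keying a table by id(program): the key is stable only
# while the Program object stays alive, which is why the docstring stresses
# that the program is not deleted while from_func_spec runs.
def _program_hash(program):
    return id(program)


class FakeProgram:
    """Stand-in object; two live instances always get distinct ids."""


prog_a, prog_b = FakeProgram(), FakeProgram()
table = {_program_hash(prog_a): {"params": {"w", "b"}}}
assert _program_hash(prog_a) in table       # same live object, same key
assert _program_hash(prog_b) not in table   # a different program gets its own bucket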