未验证 提交 e0841332 编写于 作者: X xiongkun 提交者: GitHub

fix error (#56572)

上级 99795a13
......@@ -589,7 +589,7 @@ def _setitem_impl_(var, item, value):
ProgramTranslator,
)
ProgramTranslator.get_instance()._params_map.add(
ProgramTranslator.get_instance()._inplace_map.add(
cur_block.program, var.desc.id(), output
)
......@@ -935,7 +935,7 @@ def _setitem_static(x, indices, values):
if not paddle.in_dynamic_mode():
# map var to the new output
paddle.jit.api.ProgramTranslator.get_instance()._params_map.add(
paddle.jit.api.ProgramTranslator.get_instance()._inplace_map.add(
cur_block.program, x.desc.id(), output
)
return output
......@@ -1008,7 +1008,7 @@ def _setitem_static(x, indices, values):
)
if not paddle.in_dynamic_mode():
# map var to the new output
paddle.jit.api.ProgramTranslator.get_instance()._params_map.add(
paddle.jit.api.ProgramTranslator.get_instance()._inplace_map.add(
cur_block.program, x.desc.id(), output
)
return output
......
......@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import astor
from paddle.utils import gast
......
......@@ -53,7 +53,7 @@ def convert_load(x):
from paddle.jit.dy2static.program_translator import ProgramTranslator
new_var = ProgramTranslator.get_instance()._params_map.get(
new_var = ProgramTranslator.get_instance()._inplace_map.get(
cur_block.program, x.desc.id()
)
if new_var is not None:
......@@ -381,9 +381,13 @@ def _run_paddle_cond(
_convert_tensor_arrray_if_necessary(helper, push_pop_names)
pred = cast_bool_if_necessary(pred)
init_args = helper.get(return_name_ids)
from paddle.jit.dy2static.program_translator import ProgramTranslator
inplace_map = ProgramTranslator.get_instance()._inplace_map
def new_true_fn():
# init args may contain mutable python container like [var, 2], we copy then like in while_loop
inplace_map_checkpoint = inplace_map.save_checkpoint()
helper.set(
return_name_ids,
paddle.utils.copy_mutable_vars(init_args),
......@@ -392,21 +396,22 @@ def _run_paddle_cond(
# IfExpr will return a non-None return value, so we just return ret.
# We assume normal return has no return value.
if ret is None:
return helper.get(return_name_ids)
else:
return ret
ret = helper.get(return_name_ids)
inplace_map.restore_checkpoint(inplace_map_checkpoint)
return ret
def new_false_fn():
# init args may contain mutable python container like [var, 2], we copy then like in while_loop
inplace_map_checkpoint = inplace_map.save_checkpoint()
helper.set(
return_name_ids,
paddle.utils.copy_mutable_vars(init_args),
)
ret = false_fn()
if ret is None:
return helper.get(return_name_ids)
else:
return ret
ret = helper.get(return_name_ids)
inplace_map.restore_checkpoint(inplace_map_checkpoint)
return ret
try:
cond_outs = paddle.static.nn.cond(
......
......@@ -1256,6 +1256,14 @@ class ConcreteProgram:
)
def _program_hash(program):
"""
because program is not deleted while calling from_func_spec.
so it's ok to use id(program)
"""
return id(program)
class ParametersRecorder:
def __init__(self):
self.params_dict = {}
......@@ -1263,35 +1271,28 @@ class ParametersRecorder:
@synchronized
def add(self, program, param):
"""use the default_program as key, append param the parameter list."""
key = self._program_hash(program)
key = _program_hash(program)
if key not in self.params_dict:
self.params_dict[key] = set()
params = self.params_dict[key]
params.add(param)
def pop(self, program):
params = self.params_dict.get(self._program_hash(program))
params = self.params_dict.get(_program_hash(program))
if params is None:
return []
del self.params_dict[self._program_hash(program)]
del self.params_dict[_program_hash(program)]
return list(params)
def _program_hash(self, program):
"""
because program is not deleted while calling from_func_spec.
so it's ok to use id(program)
"""
return id(program)
class ParametersMap:
class InplaceMap:
def __init__(self):
self.params_dict = {}
@synchronized
def add(self, program, id, param):
"""use the default_program as key, append param the parameter list."""
key = self._program_hash(program)
key = _program_hash(program)
if key not in self.params_dict:
self.params_dict[key] = {}
......@@ -1299,7 +1300,7 @@ class ParametersMap:
params[id] = param
def get(self, program, id):
params = self.params_dict.get(self._program_hash(program))
params = self.params_dict.get(_program_hash(program))
if params is None:
return None
if id not in params:
......@@ -1313,12 +1314,19 @@ class ParametersMap:
params[var.desc.id()] = root_var
return root_var
def _program_hash(self, program):
"""
because program is not deleted while calling from_func_spec.
so it's ok to use id(program)
"""
return id(program)
def restore_checkpoint(self, checkpoint):
# InplaceMap is a nested effect.
# when enter a block, we should save a checkpoint
# when exit a block, we should restore a checkpoint
# for example:
# if cond > 0:
# x [:] = 0
# return x
# x[:] only effect current cond block, we should restore in false block.
self.params_dict = checkpoint
def save_checkpoint(self):
return dict(self.params_dict.items())
class FallbackProgramLayer:
......@@ -1582,7 +1590,7 @@ class ProgramTranslator:
self._initialized = True
self._program_cache = ProgramCache()
self._params_recorder = ParametersRecorder()
self._params_map = ParametersMap()
self._inplace_map = InplaceMap()
self.enable_to_static = True
def enable(self, enable_to_static):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册