Unverified commit 3be7f495, authored by xiongkun, committed by GitHub

[ Dy2Static ] transfer list into tensor array at runtime. (#45594)

* 1. Make the list transformer work in JIT form.
  2. Fix some bugs in tensor_array, such as append.
  3. Enhance the function analysis visitor to recognize push/pop.
  4. Add a setter/getter helper to deal with 2+ name sets.

* Fix CI errors:
  1. Add to_tensor_array logic in convert_cond.
  2. Fix IfExpr error.
  3. Fix errors when return_names or push_pop_names is None.
  4. Fix slice error in a[i]=1 where a is a tensor_array.
  5. Add pop interface to Variable.
Parent 31b92305
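For orientation, here is a hedged sketch (hypothetical user code, not part of the commit) of the pattern this change targets: a plain Python list mutated inside converted control flow is now turned into a tensor array at runtime.

    import paddle

    @paddle.jit.to_static
    def collect(x):
        a = []  # plain Python list; becomes a tensor array in static mode
        i = 0
        while i < 3:
            a.append(x + i)  # rewritten into a tensor-array append
            i += 1
        return a[2]

    # expected: a Tensor holding 3.0
    print(collect(paddle.to_tensor(1.0)))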
@@ -75,7 +75,7 @@ static void CheckOutputVarStatus(const paddle::framework::Variable &src_var,
   PADDLE_ENFORCE_EQ(dst_tensor.defined(),
                     true,
                     paddle::platform::errors::InvalidArgument(
-                        "dst_tensor shall be defined."));
+                        "dst_tensor `%s` shall be defined.", name));
   if (dst_tensor.is_dense_tensor()) {
     auto &src_tensor = src_var.Get<phi::DenseTensor>();
......
@@ -93,7 +93,7 @@ class DygraphToStaticAst(BaseTransformer):
             EarlyReturnTransformer,
             BasicApiTransformer,  # Basic Api
             TensorShapeTransformer,  # Tensor.shape -> layers.shape(Tensor)
-            ListTransformer,  # List used in control flow
+            #ListTransformer,  # List used in control flow
             BreakContinueTransformer,  # break/continue in loops
             ReturnTransformer,  # return in functions
             LogicalTransformer,  # logical and/or/not
......
@@ -25,6 +25,7 @@ from paddle.fluid.layers import cast, control_flow, logical_and, logical_not, lo
 from paddle.fluid.layers.control_flow import cond, while_loop, less_than, increment
 from paddle.fluid.dygraph.dygraph_to_static.return_transformer import RETURN_NO_VALUE_VAR_NAME
 from paddle.fluid.dygraph.dygraph_to_static.utils import UndefinedVar, Dygraph2StaticException
+from paddle.fluid.dygraph.dygraph_to_static.utils import GetterSetterHelper
 from paddle.fluid.layers.utils import copy_mutable_vars
@@ -74,7 +75,12 @@ def _unpack_by_structure_paddle(target, structure):
     return ret

-def convert_while_loop(cond, body, getter, setter):
+def convert_while_loop(cond,
+                       body,
+                       getter,
+                       setter,
+                       return_name_ids=None,
+                       push_pop_names=None):
     """
     A function representation of a Python ``while`` statement.
@@ -91,21 +97,41 @@ def convert_while_loop(cond, body, getter, setter):
     # If loop_vars is changed during cond callable, then it causes bug, but current logical_and/logical_not/... doesn't change the loop_vars.
     pred = cond()
     if isinstance(pred, Variable):
-        _run_paddle_while(cond, body, getter, setter)
+        _run_paddle_while(cond, body, getter, setter, return_name_ids,
+                          push_pop_names)
     else:
         _run_py_while(cond, body, getter, setter)
-def _run_paddle_while(cond, body, getter, setter):
+def _convert_tensor_arrray_if_necessary(setterhelper, push_pop_names):
+    push_pop_vars = setterhelper.get(push_pop_names)
+    if push_pop_vars is None:
+        return
+
+    def maybe_to_tensor_array(v):
+        if isinstance(v, list):
+            return create_array("float32", initialized_list=v)
+        else:
+            return v
+
+    setterhelper.set(push_pop_names,
+                     [maybe_to_tensor_array(v) for v in push_pop_vars])
+
+
+def _run_paddle_while(cond, body, getter, setter, return_name_ids,
+                      push_pop_names):
     # NOTE: loop_vars of Paddle op `control_flow.while_loop` must be Paddle Tensors.
+    helper = GetterSetterHelper(getter, setter, return_name_ids, push_pop_names)
+    _convert_tensor_arrray_if_necessary(helper, push_pop_names)
+
     def new_body_fn(*args):
         """ wrap the body() and add a return value for `while_loop`;
             the args may differ from getter().
         """
         mutable_loop_vars = args
-        setter(mutable_loop_vars)
+        helper.set(return_name_ids, mutable_loop_vars)
         body()
-        return getter()
+        return helper.get(return_name_ids)

     def new_cond_fn(*args):
         """ cond is a zero-args function, which is not
@@ -116,12 +142,13 @@ def _run_paddle_while(cond, body, getter, setter):
     # UndefinedVar will become data layer not check variable with value=NO_VALUE_MAGIC.
     loop_vars = [
         to_static_variable(var) if not isinstance(var, UndefinedVar) else var
-        for var in getter()
+        for var in helper.get(return_name_ids)
     ]
-    setter(loop_vars)  # change the non-local var to variable
+    helper.set(return_name_ids,
+               loop_vars)  # change the non-local var to variable
     # variable maybe modified to inner var. change it into
     loop_vars = control_flow.while_loop(new_cond_fn, new_body_fn, loop_vars)
-    setter(loop_vars)  # change back to loop_vars
+    helper.set(return_name_ids, loop_vars)
     return loop_vars
@@ -263,8 +290,13 @@ def _run_py_logical_not(x):
     return not x

-def convert_ifelse(pred, true_fn, false_fn, get_args, set_args,
-                   return_name_ids):
+def convert_ifelse(pred,
+                   true_fn,
+                   false_fn,
+                   get_args,
+                   set_args,
+                   return_name_ids,
+                   push_pop_names=None):
     """
     A function representation of a Python ``if/else`` statement.
@@ -274,6 +306,7 @@ def convert_ifelse(pred, true_fn, false_fn, get_args, set_args,
         false_fn(callable): A callable to be performed if ``pred`` is false.
         get_args(callable): Get all arguments needed by true_fn and false_fn.
         set_args(callable): Update the arguments modified in true_fn and false_fn.
+        return_name_ids(list[string]): the returned names.
     Returns:
         ``true_fn()`` if the predicate ``pred`` is true else ``false_fn()`` .
@@ -281,7 +314,7 @@ def convert_ifelse(pred, true_fn, false_fn, get_args, set_args,
     """
     if isinstance(pred, Variable):
         out = _run_paddle_cond(pred, true_fn, false_fn, get_args, set_args,
-                               return_name_ids)
+                               return_name_ids, push_pop_names)
     else:
         out = _run_py_ifelse(pred, true_fn, false_fn, get_args, set_args,
                              return_name_ids)
@@ -290,27 +323,30 @@ def convert_ifelse(pred, true_fn, false_fn, get_args, set_args,
 def _run_paddle_cond(pred, true_fn, false_fn, get_args, set_args,
-                     return_name_ids):
+                     return_name_ids, push_pop_names):
     """
     Paddle cond API will evaluate both true_fn and false_fn codes.
     """
+    helper = GetterSetterHelper(get_args, set_args, return_name_ids,
+                                push_pop_names)
+    _convert_tensor_arrray_if_necessary(helper, push_pop_names)
     pred = cast_bool_if_necessary(pred)
-    init_args = get_args()
+    init_args = helper.get(return_name_ids)

     def new_true_fn():
         # init args may contain mutable python containers like [var, 2]; we copy them as in while_loop
-        set_args(copy_mutable_vars(init_args))
+        helper.set(return_name_ids, copy_mutable_vars(init_args))
         ret = true_fn()
         # IfExpr will return a non-None return value, so we just return ret.
         # We assume normal return has no return value.
-        if ret is None: return get_args()
+        if ret is None: return helper.get(return_name_ids)
         else: return ret

     def new_false_fn():
         # init args may contain mutable python containers like [var, 2]; we copy them as in while_loop
-        set_args(copy_mutable_vars(init_args))
+        helper.set(return_name_ids, copy_mutable_vars(init_args))
         ret = false_fn()
-        if ret is None: return get_args()
+        if ret is None: return helper.get(return_name_ids)
         else: return ret

     try:
@@ -327,6 +363,8 @@ def _run_paddle_cond(pred, true_fn, false_fn, get_args, set_args,
             "Your if/else have different number of return value. TODO: add link to modifty. {}"
             .format(str(e)))
         raise e
+    get_args = lambda: helper.get(return_name_ids)
+    set_args = lambda vs: helper.set(return_name_ids, vs)
     return _recover_args_state(cond_outs, get_args, set_args, return_name_ids)
......
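In effect, `_convert_tensor_arrray_if_necessary` rewrites any pushed/popped binding that is still a Python list into a LoD tensor array before the control-flow op runs. A standalone sketch of the per-value rule (helper name hypothetical; `create_array(..., initialized_list=...)` is the same call the diff uses):

    from paddle.fluid.layers import create_array

    def to_tensor_array_if_list(v):
        # lists become float32 tensor arrays seeded with their current
        # elements; everything else passes through unchanged
        return create_array("float32", initialized_list=v) if isinstance(v, list) else v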
@@ -35,6 +35,7 @@ from paddle.fluid.dygraph.dygraph_to_static.utils import create_nonlocal_stmt_no
 from paddle.fluid.dygraph.dygraph_to_static.utils import create_get_args_node, create_set_args_node
 from paddle.fluid.dygraph.dygraph_to_static.base_transformer import BaseTransformer
 from paddle.fluid.dygraph.dygraph_to_static.utils import FOR_ITER_INDEX_PREFIX, FOR_ITER_TUPLE_PREFIX, FOR_ITER_TUPLE_INDEX_PREFIX, FOR_ITER_VAR_LEN_PREFIX, FOR_ITER_VAR_NAME_PREFIX, FOR_ITER_ZIP_TO_LIST_PREFIX, FOR_ITER_TARGET_PREFIX, FOR_ITER_ITERATOR_PREFIX
+from paddle.fluid.dygraph.dygraph_to_static.utils import GetterSetterHelper, create_name_str

 TRUE_FUNC_PREFIX = 'true_fn'
 FALSE_FUNC_PREFIX = 'false_fn'
@@ -65,16 +66,16 @@ class IfElseTransformer(BaseTransformer):
     def visit_If(self, node):
         self.generic_visit(node)
-        new_vars_stmts, true_func_node, false_func_node, get_args_node, set_args_node, return_name_ids = transform_if_else(
+        true_func_node, false_func_node, get_args_node, set_args_node, return_name_ids, push_pop_ids = transform_if_else(
             node, self.root)

-        new_node = create_convert_ifelse_node(return_name_ids, node.test,
-                                              true_func_node, false_func_node,
-                                              get_args_node, set_args_node)
+        new_node = create_convert_ifelse_node(return_name_ids, push_pop_ids,
+                                              node.test, true_func_node,
+                                              false_func_node, get_args_node,
+                                              set_args_node)

-        return new_vars_stmts + [
-            get_args_node, set_args_node, true_func_node, false_func_node
-        ] + [new_node]
+        return [get_args_node, set_args_node, true_func_node, false_func_node
+                ] + [new_node]

     def visit_Call(self, node):
         # Remove `numpy()` statement, like `Tensor.numpy()[i]` -> `Tensor[i]`
@@ -91,7 +92,7 @@ class IfElseTransformer(BaseTransformer):
         """
         self.generic_visit(node)
-        new_node = create_convert_ifelse_node(None, node.test, node.body,
+        new_node = create_convert_ifelse_node(None, None, node.test, node.body,
                                               node.orelse, None, None, True)
         # Note: A blank line will be added separately if transform gast.Expr
         # into source code. Using gast.Expr.value instead to avoid syntax error
@@ -306,16 +307,7 @@ def transform_if_else(node, root):
     # TODO(liym27): Consider variable like `self.a` modified in if/else node.
     return_name_ids = sorted(list(node.pd_scope.modified_vars()))
-    # NOTE: Python can create variable only in if body or only in else body, and use it out of if/else.
-    # E.g.
-    #
-    #   if x > 5:
-    #       a = 10
-    #   print(a)
-    #
-    # Create static variable for those variables
-    create_new_vars_in_parent_stmts = []
+    push_pop_ids = sorted(list(node.pd_scope.variadic_length_vars()))
     nonlocal_names = list(return_name_ids)
     nonlocal_names.sort()
     # NOTE: All var in return_name_ids should be in nonlocal_names.
@@ -359,13 +351,15 @@ def transform_if_else(node, root):
         input_args=empty_arg_node,
         return_name_ids=[])

-    get_args_node = create_get_args_node(nonlocal_names)
-    set_args_node = create_set_args_node(nonlocal_names)
+    helper = GetterSetterHelper(None, None, nonlocal_names, push_pop_ids)
+    get_args_node = create_get_args_node(helper.union())
+    set_args_node = create_set_args_node(helper.union())

-    return create_new_vars_in_parent_stmts, true_func_node, false_func_node, get_args_node, set_args_node, return_name_ids
+    return true_func_node, false_func_node, get_args_node, set_args_node, return_name_ids, push_pop_ids

 def create_convert_ifelse_node(return_name_ids,
+                               push_pop_ids,
                                pred,
                                true_func,
                                false_func,
@@ -377,17 +371,6 @@ def create_convert_ifelse_node(return_name_ids,
         pred, true_fn, false_fn, get_args, set_args, return_name_ids)`
     to replace original `python if/else` statement.
     """

-    def create_name_str(name_ids):
-        """
-        Return "('x', 'y')" for [x, y]
-        """
-        if not name_ids:
-            return 'None'
-        names_str = ["'%s'" % name for name in name_ids]
-        return "(%s, )" % ','.join(names_str)
-
     if is_if_expr:
         true_func_source = "lambda : {}".format(ast_to_source_code(true_func))
         false_func_source = "lambda : {}".format(ast_to_source_code(false_func))
@@ -397,7 +380,7 @@ def create_convert_ifelse_node(return_name_ids,
     convert_ifelse_layer = gast.parse(
         '_jst.IfElse('
-        '{pred}, {true_fn}, {false_fn}, {get_args}, {set_args}, {return_name_ids})'
+        '{pred}, {true_fn}, {false_fn}, {get_args}, {set_args}, {return_name_ids}, push_pop_names={push_pop_ids})'
         .format(
             pred=ast_to_source_code(pred),
             true_fn=true_func_source,
@@ -406,6 +389,7 @@ def create_convert_ifelse_node(return_name_ids,
             'lambda: None',  #TODO: better way to deal with this
             set_args=set_args_func.name
             if not is_if_expr else 'lambda args: None',
-            return_name_ids=create_name_str(return_name_ids))).body[0]
+            return_name_ids=create_name_str(return_name_ids),
+            push_pop_ids=create_name_str(push_pop_ids))).body[0]

     return convert_ifelse_layer
@@ -33,6 +33,7 @@ from paddle.fluid.dygraph.dygraph_to_static.base_transformer import BaseTransfor
 from paddle.fluid.dygraph.dygraph_to_static.base_transformer import RenameTransformer
 from paddle.fluid.dygraph.dygraph_to_static.base_transformer import ForLoopTuplePreTransformer
 from paddle.fluid.dygraph.dygraph_to_static.base_transformer import ForNodeVisitor
+from paddle.fluid.dygraph.dygraph_to_static.utils import GetterSetterHelper, create_name_str

 __all__ = ['LoopTransformer', 'NameVisitor']
@@ -43,8 +44,8 @@ FOR_CONDITION_PREFIX = 'for_loop_condition'
 FOR_BODY_PREFIX = 'for_loop_body'

-def create_while_nodes(condition_name, body_name, loop_var_names, getter_name,
-                       setter_name):
+def create_while_nodes(condition_name, body_name, loop_var_names,
+                       push_pop_names, getter_name, setter_name):
     """
     Returns a list of gast.Node which represents the calling of Paddle
     controlflow while_loop.
@@ -84,9 +85,9 @@ def create_while_nodes(condition_name, body_name, loop_var_names, getter_name,
             assign_loop_var_names.append(name)

     while_func_name = "_jst.While"
-    while_node_str = "{}({}, {}, {}, {})".format(while_func_name,
-                                                 condition_name, body_name,
-                                                 getter_name, setter_name)
+    while_node_str = "{}({}, {}, {}, {}, return_name_ids={}, push_pop_names={})".format(
+        while_func_name, condition_name, body_name, getter_name, setter_name,
+        create_name_str(loop_var_names), create_name_str(push_pop_names))
     while_node = gast.parse(while_node_str).body[0]

     ret = [while_node]
@@ -539,6 +540,7 @@ class LoopTransformer(BaseTransformer):
         # 2. get original loop vars
         loop_var_names, create_var_names = node.pd_scope.modified_vars(
         ), node.pd_scope.created_vars()
+        push_pop_names = list(node.pd_scope.variadic_length_vars())
         # TODO: Remove the bunch of code? We have the unique format `for A in B:`
         # NOTE: in 'for x in var' or 'for i, x in enumerate(var)' cases,
         # we need append new loop var & remove useless loop var
@@ -607,12 +609,13 @@ class LoopTransformer(BaseTransformer):
             type_comment=None)
         new_stmts.append(body_func_node)

-        get_args_node = create_get_args_node(nonlocal_names)
-        set_args_node = create_set_args_node(nonlocal_names)
+        helper = GetterSetterHelper(None, None, nonlocal_names, push_pop_names)
+        get_args_node = create_get_args_node(helper.union())
+        set_args_node = create_set_args_node(helper.union())
         # 7. create & append while loop node
         while_loop_nodes = create_while_nodes(condition_func_node.name,
                                               body_func_node.name,
-                                              nonlocal_names,
+                                              nonlocal_names, push_pop_names,
                                               get_args_node.name,
                                               set_args_node.name)
         new_stmts.extend([get_args_node, set_args_node])
@@ -623,6 +626,7 @@ class LoopTransformer(BaseTransformer):
     def get_while_stmt_nodes(self, node):
         loop_var_names, create_var_names = node.pd_scope.modified_vars(
         ), node.pd_scope.created_vars()
+        push_pop_names = list(node.pd_scope.variadic_length_vars())
         new_stmts = []

         # create non-local statement for body and cond.
@@ -675,12 +679,14 @@ class LoopTransformer(BaseTransformer):
             returns=None,
             type_comment=None)
         new_stmts.append(body_func_node)
-        get_args_node = create_get_args_node(nonlocal_names)
-        set_args_node = create_set_args_node(nonlocal_names)
+
+        helper = GetterSetterHelper(None, None, nonlocal_names, push_pop_names)
+        get_args_node = create_get_args_node(helper.union())
+        set_args_node = create_set_args_node(helper.union())
         while_loop_nodes = create_while_nodes(condition_func_node.name,
                                               body_func_node.name,
-                                              nonlocal_names,
+                                              nonlocal_names, push_pop_names,
                                               get_args_node.name,
                                               set_args_node.name)
         new_stmts.extend([get_args_node, set_args_node])
......
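With these changes, create_while_nodes emits a _jst.While call of roughly the following shape (identifier names illustrative; the tuple literals are what create_name_str produces):

    _jst.While(for_loop_condition_0, for_loop_body_0, get_args_0, set_args_0,
               return_name_ids=('i', ), push_pop_names=('a', ))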
@@ -33,6 +33,8 @@ from paddle.fluid.data_feeder import convert_dtype
 from paddle.fluid import core
 from paddle.fluid.layer_helper import LayerHelper
 from paddle.fluid.layers import assign
+import collections
+from functools import reduce

 # Note(Aurelius): Do not forget the dot `.` to distinguish other
 # module such as paddlenlp.
@@ -1020,8 +1022,9 @@ class NameScope:
         self.args = set()
         self.father = None  # point to the nearest function name scope.
         self.w_vars = set()  # all qualified + normal names been stored
-        self.created = set(
-        )  # useful for control flow compatibility. may be remove later
+        self.created = set()  # useful for control flow compatibility
+        # may be remove later.
+        self.push_pop_vars = set()  # we call push and pop in the vars

     def set_father(self, father):
         self.father = father
@@ -1040,6 +1043,9 @@ class NameScope:
         # may be globals / non-locals / args / qualified names and created_vars
         return self.w_vars

+    def variadic_length_vars(self):
+        return self.push_pop_vars
+
     def control_flow_vars(self):
         valid_names = self.w_vars
         tmp = self.father.global_vars & valid_names,
@@ -1053,17 +1059,25 @@ class NameScope:
         self.nonlocals |= name_scope.nonlocals
         self.args |= name_scope.args
         self.w_vars |= name_scope.w_vars
+        self.push_pop_vars |= name_scope.push_pop_vars

 class FunctionNameLivenessAnalysis(gast.NodeVisitor):
     """ analyze the liveness of a function.
         every variable stored in this scope will be collected,
-        in addition with global/nonlocal information.
+        in addition with global/nonlocal information and
+        push_pop information.

         1. global variables are stored in node.var_globals.
         2. nonlocal variables are stored in node.var_nonlocals.
         3. arguments are stored in node.var_args.
+        4. if a variable's push and pop attributes are called,
+           it will be collected in push_pop_vars. These are
+           used for the transformation to tensor_array.
+           NOTE: push_pop_vars **may not** be in w_vars:
+           a.push(0) doesn't modify the variable a, only the
+           content of a.

         For example:
@@ -1073,8 +1087,12 @@ class FunctionNameLivenessAnalysis(gast.NodeVisitor):
                 nonlocal x,y
                 print(a)
                 i = k
+                b = []
+                c = [1,2,3]
                 for m in range(10):
                     q = 12
+                    b.push(1)
+                    c.pop()

         After this visitor we have:
         # node is the FunctionDef node with name: "func"
@@ -1082,7 +1100,8 @@ class FunctionNameLivenessAnalysis(gast.NodeVisitor):
             globals = ['i', 'j'],
             nonlocals = ['x', 'y'],
             args = ['args', 'kargs'],
-            wr_vars = ['a', 'i', 'q', 'm']
+            wr_vars = ['a', 'i', 'q', 'm', 'c', 'b']
+            push_pop_vars = ['b', 'c']
         )
     """
@@ -1137,7 +1156,7 @@ class FunctionNameLivenessAnalysis(gast.NodeVisitor):
                 self._get_argument_names(node))

         def post_func():
-            """ NOTE: why we need to merge w_vars here?
+            """ NOTE: why we need to merge w_vars and push_pop_vars here?
                 because we run ifelse_transformer after loop_transformer: loops are changed into functions, but we know such a function will be called inside an if, so we add its w_vars to the father function scope.
             """
             from paddle.fluid.dygraph.dygraph_to_static.loop_transformer import WHILE_CONDITION_PREFIX, WHILE_BODY_PREFIX, FOR_CONDITION_PREFIX, FOR_BODY_PREFIX
@@ -1155,6 +1174,8 @@ class FunctionNameLivenessAnalysis(gast.NodeVisitor):
             if self._father_name_scope() and is_control_flow_def_node():
                 self._father_name_scope().w_vars |= self._current_name_scope(
                 ).w_vars
+                self._father_name_scope(
+                ).push_pop_vars |= self._current_name_scope().push_pop_vars

         self._visit_scope_node(node, pre_func, post_func)
@@ -1210,6 +1231,17 @@ class FunctionNameLivenessAnalysis(gast.NodeVisitor):
         name = ast_to_source_code(node).strip()
         self._current_name_scope().w_vars.add(name)

+    def visit_Call(self, node):
+        self.generic_visit(node)
+        if not isinstance(node.func, gast.Attribute):
+            return
+        variadic_length_method = ['append', 'pop']
+        if node.func.attr not in variadic_length_method:
+            return
+        # we don't treat push and pop as write operators, just as a[i] = 10 doesn't modify `a` itself
+        name = ast_to_source_code(node.func.value).strip()
+        self._current_name_scope().push_pop_vars.add(name)
+
     def _get_argument_names(self, node):
         """ get all argument names in the functiondef node.
             this node is local to the function and shouldn't
@@ -1315,3 +1347,57 @@ def create_nonlocal_stmt_nodes(names):
         return []
     func_code = "nonlocal {}".format(','.join(names))
     return [gast.parse(func_code).body[0]]
+
+
+class GetterSetterHelper:
+    """ we have two classes of names in setter and getter function:
+        w_vars(loop_vars) + push_pop_vars
+        To simplify the setter logic in convert_while and convert_cond,
+        we extract the helper class here.
+    """
+
+    def __init__(self, getter_func, setter_func, *name_lists):
+        name_lists = map(lambda x: [] if x is None else x, name_lists)
+        name_sets = map(lambda x: set(x), name_lists)
+        self._union = list(reduce(lambda x, y: x | y, name_sets, set()))
+        self._union.sort()
+        self.getter = getter_func
+        self.setter = setter_func
+        self.name2id = {name: idx for idx, name in enumerate(self._union)}
+
+    def union(self):
+        return self._union
+
+    def get(self, names):
+        if names is None: names = []
+        vars = self.getter()
+        if vars is None: return tuple()
+        for n in names:
+            assert n in self.name2id, "the name `{}` not in name union set`{}`.".format(
+                n, self.name2id.keys())
+        return tuple(map(lambda n: vars[self.name2id[n]], names))
+
+    def set(self, names, values):
+        if names is None: names = []
+        if values is None: values = []
+        vars = self.getter()
+        if vars is None: return
+        for n in names:
+            assert n in self.name2id, "the name `{}` not in name union set`{}`.".format(
+                n, self.name2id.keys())
+        vars = list(vars)
+        indices = list(map(lambda n: self.name2id[n], names))
+        for i, v in zip(indices, values):
+            vars[i] = v
+        self.setter(vars)
+
+
+def create_name_str(name_ids):
+    """
+    Return "('x', 'y')" for [x, y]
+    """
+    if not name_ids:
+        return 'None'
+
+    names_str = ["'%s'" % name for name in name_ids]
+    return "(%s, )" % ','.join(names_str)
@@ -21,6 +21,7 @@ from .. import core
 from ..framework import Variable, unique_name, static_only
 from .layer_function_generator import OpProtoHolder
 from .control_flow import array_write, array_length
+from paddle.fluid.dygraph.base import in_declarative_mode

 _supported_int_dtype_ = [
     core.VarDesc.VarType.BOOL,
""" """
if not isinstance(var, Variable): if not isinstance(var, Variable):
raise TypeError( if in_declarative_mode():
"Required input var should be Variable, but received {}".format( """ in dy2static mode, x may be tensorable values such as int, float, np.array
type(var))) """
from paddle.tensor.creation import to_tensor
var = to_tensor(var)
else:
raise TypeError(
"Required input var should be Variable, but received {}".
format(type(var)))
if self.type != core.VarDesc.VarType.LOD_TENSOR_ARRAY: if self.type != core.VarDesc.VarType.LOD_TENSOR_ARRAY:
raise TypeError( raise TypeError(
"Only Variable with VarType.LOD_TENSOR_ARRAY support `append` method, but received type: {}" "Only Variable with VarType.LOD_TENSOR_ARRAY support `append` method, but received type: {}"
.format(self.type)) .format(self.type))
array_write(x=var, i=array_length(self), array=self) array_write(x=var, i=array_length(self), array=self)
@static_only
def pop(self, *args):
"""
**Notes**:
**The type variable must be LoD Tensor Array.
"""
from paddle.fluid.dygraph.dygraph_to_static.convert_operators import _run_paddle_pop
if self.type != core.VarDesc.VarType.LOD_TENSOR_ARRAY:
raise TypeError(
"Only Variable with VarType.LOD_TENSOR_ARRAY support `append` method, but received type: {}"
.format(self.type))
return _run_paddle_pop(self, *args)
def _scalar_op_(var, scale, bias): def _scalar_op_(var, scale, bias):
block = current_block(var) block = current_block(var)
out = create_new_tmp_var(block, var.dtype) out = create_new_tmp_var(block, var.dtype)
@@ -389,6 +409,7 @@ def monkey_patch_variable():
         ('cpu', cpu),
         ('cuda', cuda),
         ('append', append),
+        ('pop', pop),
         ('dim', lambda x: len(x.shape)),
         ('ndimension', lambda x: len(x.shape)),
         ('ndim', _ndim_),
......
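A short static-graph sketch of the two patched methods (a hedged example; it assumes the fluid API of this release, and `pop`'s argument handling lives in `_run_paddle_pop`, which is not shown in this diff):

    import paddle
    from paddle.fluid.layers import create_array

    paddle.enable_static()
    arr = create_array("float32")      # Variable of type LOD_TENSOR_ARRAY
    arr.append(paddle.full([1], 1.0))  # patched append -> array_write
    arr.append(paddle.full([1], 2.0))
    last = arr.pop()                   # patched pop, added by this commit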
@@ -39,6 +39,19 @@ class JudgeVisitor(gast.NodeVisitor):
         self.generic_visit(node)

+class JudgePushPopVisitor(gast.NodeVisitor):
+
+    def __init__(self, push_pop_vars):
+        self.pp_var = push_pop_vars
+
+    def visit_FunctionDef(self, node):
+        scope = node.pd_scope
+        expected = self.pp_var.get(node.name, set())
+        assert scope.push_pop_vars == expected, "Not Equals in function:{} . expect {} , but get {}".format(
+            node.name, expected, scope.push_pop_vars)
+        self.generic_visit(node)
+
 def test_normal_0(x):

     def func():
@@ -88,9 +101,67 @@ def test_nonlocal(x, *args, **kargs):
     return x

+def test_push_pop_1(x, *args, **kargs):
+    """ push_pop_vars in main_function is : `l`, `k`
+    """
+    l = []
+    k = []
+    for i in range(10):
+        l.append(i)
+        k.pop(i)
+    return l
+
+
+def test_push_pop_2(x, *args, **kargs):
+    """ push_pop_vars in main_function is : `k`
+    """
+    l = []
+    k = []
+
+    def func():
+        l.append(0)
+
+    for i in range(10):
+        k.append(i)
+    return l, k
+
+
+def test_push_pop_3(x, *args, **kargs):
+    """ push_pop_vars in main_function is : `k`
+        NOTE: One may expect `k` and `l`, because l
+        is nonlocal. Name binding analysis is
+        not implemented yet.
+    """
+    l = []
+    k = []
+
+    def func():
+        nonlocal l
+        l.append(0)
+
+    for i in range(10):
+        k.append(i)
+    return l, k
+
+
+def test_push_pop_4(x, *args, **kargs):
+    """ push_pop_vars in main_function is : `l`, `k`
+    """
+    l = []
+    k = []
+    for i in range(10):
+        for j in range(10):
+            if True:
+                l.append(j)
+            else:
+                k.pop()
+    return l, k
+
 class TestClosureAnalysis(unittest.TestCase):

     def setUp(self):
+        self.judge_type = "var and w_vars"
         self.init_dygraph_func()
     def init_dygraph_func(self):

@@ -132,12 +203,20 @@ class TestClosureAnalysis(unittest.TestCase):
         ]

     def test_main(self):
-        for mod, ans, func in zip(self.modified_var, self.answer,
-                                  self.all_dygraph_funcs):
-            test_func = inspect.getsource(func)
-            gast_root = gast.parse(test_func)
-            name_visitor = FunctionNameLivenessAnalysis(gast_root)
-            JudgeVisitor(ans, mod).visit(gast_root)
+        if self.judge_type == 'push_pop_vars':
+            for push_pop_vars, func in zip(self.push_pop_vars,
+                                           self.all_dygraph_funcs):
+                test_func = inspect.getsource(func)
+                gast_root = gast.parse(test_func)
+                name_visitor = FunctionNameLivenessAnalysis(gast_root)
+                JudgePushPopVisitor(push_pop_vars).visit(gast_root)
+        else:
+            for mod, ans, func in zip(self.modified_var, self.answer,
+                                      self.all_dygraph_funcs):
+                test_func = inspect.getsource(func)
+                gast_root = gast.parse(test_func)
+                name_visitor = FunctionNameLivenessAnalysis(gast_root)
+                JudgeVisitor(ans, mod).visit(gast_root)
 def TestClosureAnalysis_Attribute_func():

@@ -158,5 +237,25 @@ class TestClosureAnalysis_Attribute(TestClosureAnalysis):
         }]

+class TestClosureAnalysis_PushPop(TestClosureAnalysis):
+
+    def init_dygraph_func(self):
+        self.judge_type = "push_pop_vars"
+        self.all_dygraph_funcs = [
+            test_push_pop_1, test_push_pop_2, test_push_pop_3, test_push_pop_4
+        ]
+        self.push_pop_vars = [{
+            "test_push_pop_1": set({'l', 'k'}),
+        }, {
+            "test_push_pop_2": set({'k'}),
+            "func": set("l"),
+        }, {
+            "test_push_pop_3": set({'k'}),
+            "func": set("l"),
+        }, {
+            "test_push_pop_4": set({'k', 'l'}),
+        }]
+
 if __name__ == '__main__':
     unittest.main()
@@ -254,13 +254,13 @@ class TestListWithoutControlFlow(unittest.TestCase):
                 dy_res,
                 rtol=1e-05,
                 err_msg='dygraph_res is {}\nstatic_res is {}'.format(
-                    stat_res, dy_res))
+                    dy_res, stat_res))

 class TestListInIf(TestListWithoutControlFlow):

     def init_dygraph_func(self):
-        self.all_dygraph_funcs = [test_list_append_in_if, test_list_pop_in_if]
+        self.all_dygraph_funcs = [test_list_append_in_if]

 class TestListInWhileLoop(TestListWithoutControlFlow):
......
@@ -89,9 +89,12 @@ class StaticCode1():
            x_v = x_v + 1
            return

-        _jst.IfElse(
-            paddle.mean(x_v)[0] > 5, true_fn_0, false_fn_0, get_args_0,
-            set_args_0, ('x_v', ))
+        _jst.IfElse(paddle.mean(x_v)[0] > 5,
+                    true_fn_0,
+                    false_fn_0,
+                    get_args_0,
+                    set_args_0, ('x_v', ),
+                    push_pop_names=None)

        def get_args_1():
            nonlocal __return_0, __return_1, __return_value_0, loss

@@ -114,9 +117,13 @@ class StaticCode1():
            __return_value_0 = x_v
            return

-        _jst.IfElse(label is not None, true_fn_1, false_fn_1, get_args_1,
-                    set_args_1,
-                    ('__return_0', '__return_1', '__return_value_0', 'loss'))
+        _jst.IfElse(label is not None,
+                    true_fn_1,
+                    false_fn_1,
+                    get_args_1,
+                    set_args_1,
+                    ('__return_0', '__return_1', '__return_value_0', 'loss'),
+                    push_pop_names=None)
        return __return_value_0
@@ -146,9 +153,12 @@ class StaticCode2():
            x_v = x_v + 1
            return

-        _jst.IfElse(
-            paddle.mean(x_v)[0] > 5, true_fn_2, false_fn_2, get_args_2,
-            set_args_2, ('x_v', ))
+        _jst.IfElse(paddle.mean(x_v)[0] > 5,
+                    true_fn_2,
+                    false_fn_2,
+                    get_args_2,
+                    set_args_2, ('x_v', ),
+                    push_pop_names=None)

        def get_args_3():
            nonlocal __return_2, __return_3, __return_value_1, loss

@@ -171,9 +181,13 @@ class StaticCode2():
            __return_value_1 = x_v
            return

-        _jst.IfElse(label is not None, true_fn_3, false_fn_3, get_args_3,
-                    set_args_3,
-                    ('__return_2', '__return_3', '__return_value_1', 'loss'))
+        _jst.IfElse(label is not None,
+                    true_fn_3,
+                    false_fn_3,
+                    get_args_3,
+                    set_args_3,
+                    ('__return_2', '__return_3', '__return_value_1', 'loss'),
+                    push_pop_names=None)
        return __return_value_1
@@ -195,7 +209,7 @@ class TestDygraphToStaticCode(unittest.TestCase):
    def test_decorator(self):
        program_translator = ProgramTranslator()
        code = program_translator.get_code(dyfunc_with_if_else)
-        #print(code)
+        print(code)
        answer = get_source_code(StaticCode1.dyfunc_with_if_else)
        self.assertEqual(
            answer.replace('\n', '').replace(' ', ''),

@@ -205,6 +219,7 @@ class TestDygraphToStaticCode(unittest.TestCase):
        answer = get_source_code(StaticCode2.dyfunc_with_if_else)
        program_translator = ProgramTranslator()
        code = program_translator.get_code(dyfunc_with_if_else)
+        print(code)
        self.assertEqual(
            answer.replace('\n', '').replace(' ', ''),
            code.replace('\n', '').replace(' ', ''))
......
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import print_function

import unittest
import paddle
from paddle.fluid.dygraph.dygraph_to_static.utils import GetterSetterHelper

vars = [1, 2, 3, 4, 5]


def getter():
    return vars


def setter(values):
    global vars
    vars = values


class TestGetterSetterHelper(unittest.TestCase):

    def test_1(self):
        helper = GetterSetterHelper(getter, setter, ['a', 'b', 'e'],
                                    ['d', 'f', 'e'])
        print(helper.union())
        expect_union = ['a', 'b', 'd', 'e', 'f']
        assert helper.union() == expect_union
        assert helper.get(expect_union) == (1, 2, 3, 4, 5)
        helper.set(['a', 'b'], [1, 1])
        assert vars == [1, 1, 3, 4, 5]
        helper.set(['f', 'e'], [12, 10])
        assert vars == [1, 1, 3, 10, 12]
        helper.set(None, None)
        assert vars == [1, 1, 3, 10, 12]
        assert helper.get(None) == tuple()
        assert helper.get([]) == tuple()


if __name__ == '__main__':
    unittest.main()
@@ -551,8 +551,37 @@ def _getitem_impl_(var, item):
     return out

+def _setitem_for_tensor_array(var, item, value):
+    """ branches for the tensor array setitem operation.
+        An item can be:
+        (1) int/Variable, a simple number or variable such as [1], [-2]
+        (2) Slice, represented by bounds such as [2:-1]
+        (3) Tuple, which includes the above two cases such as [2:-1, 1]
+        If item is case (1), we perform paddle.tensor.array_write;
+        in other cases, we raise a NotImplementedError.
+    """
+    from ..framework import LayerHelper, core, _non_static_mode
+    from .framework import Variable
+    assert not _non_static_mode(
+    ), "setitem for tensor_array must be called in static graph mode."
+    if isinstance(item, (Variable, int)):
+        from paddle.fluid.dygraph.dygraph_to_static.variable_trans_func import to_static_variable
+        from paddle import cast
+        from paddle.tensor import array_write
+        item = paddle.cast(to_static_variable(item), dtype='int64')
+        value = to_static_variable(value)
+        array_write(x=value, i=item, array=var)
+    else:
+        raise NotImplementedError(
+            "Only support __setitem__ by Int/Variable in tensor_array, but gets {}"
+            .format(type(item)))
+
+
 def _setitem_impl_(var, item, value):
     from .framework import default_main_program, Variable
+    from paddle.fluid import core
+    if var.type == core.VarDesc.VarType.LOD_TENSOR_ARRAY:
+        return _setitem_for_tensor_array(var, item, value)

     inputs = {'Input': var}
     if isinstance(item, list):
......
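Finally, the a[i] = value path added above, seen from user code (a hypothetical example; only int/Variable indices are supported for tensor arrays):

    import paddle

    @paddle.jit.to_static
    def overwrite(x):
        a = []
        i = 0
        while i < 3:
            a.append(x + i)
            i += 1
        a[1] = x + 10  # routed to _setitem_for_tensor_array -> array_write
        return a[1]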