Unverified commit 7849d58d authored by JYChen and committed by GitHub

disable __setitem__ in static mode & add API paddle.static.setitem with dy2st strategy (#53682)

* add paddle.static.setitem

* add some help doc

* support setitem

* support mechanism

* add more unittest

* remove useless code

* raise error in static setitem

* fix d2s UT

* remove static only for both-used code

* fix bool set_value in static, fix set_value_op UT

* fix unittests

* [May cause some error]: remove inplace-version check

* add two test case for dy2st

* fix function in vision

* fix dy2st setitem support, refine UT case

* fix slice in static_mode

* add ParametersMap

* remove pop

* modify place

* [fix]: variable is also a tensor

* rewrite some ut & remove slicetransformer in dy2st

* solve error in static-mode

* fix ut

* return a result for set_array_write

* fix test_set_value_op_xpu

* code is different in dynamic / static mode

---------
Co-authored-by: Aurelius84 <zhangliujie@baidu.com>
Co-authored-by: NotHaozi <zhangmenghao@baidu.com>
Parent 56d46ccc
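In short: static-graph Variables no longer support in-place slice assignment; the new out-of-place API must be used and its result kept. A minimal before/after sketch (tensor and values illustrative):

import paddle

paddle.enable_static()
x = paddle.ones([3, 4], dtype='float32')
# x[0, 0] = 42.0                              # now raises RuntimeError in static mode
x = paddle.static.setitem(x, (0, 0), 42.0)    # use the returned tensor instead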
......@@ -130,7 +130,12 @@ class Multinomial(distribution.Distribution):
logits, value = paddle.broadcast_tensors(
[paddle.log(self.probs), value]
)
logits[(value == 0) & (paddle.isinf(logits))] = 0
if paddle.in_dynamic_mode():
logits[(value == 0) & (paddle.isinf(logits))] = 0
else:
logits = paddle.static.setitem(
logits, (value == 0) & (paddle.isinf(logits)), 0
)
return (
paddle.lgamma(value.sum(-1) + 1)
......
......@@ -2295,7 +2295,14 @@ class Variable(metaclass=VariableMetaClass):
return _getitem_impl_(self, item)
def __setitem__(self, item, value):
return _setitem_impl_(self, item, value)
from .dygraph.base import in_declarative_mode
if in_declarative_mode():
return _setitem_impl_(self, item, value)
else:
raise RuntimeError(
"In static mode, the __setitem__ (looks like: x[indices] = values) should not be used. Please use x = paddle.static.setitem(x, indices, values)"
)
def get_value(self, scope=None):
"""
......
......@@ -215,7 +215,9 @@ class SliceInfo:
out = paddle.scatter(t_reshape, index_1d, value_1d)
if tensor_type is not None:
out = out.astype(tensor_type)
tensor_origin[:] = out.reshape(tensor_origin.shape)
tensor_origin = _setitem_impl_(
tensor_origin, ..., out.reshape(tensor_origin.shape)
)
return tensor_origin
......@@ -617,7 +619,7 @@ def _setitem_for_tensor_array(var, item, value):
If item is case (1), we perform paddle.tensor.array_write,
in other cases, we raise a NotImplementedError.
"""
from ..framework import LayerHelper, core
from .framework import Variable
assert (
......@@ -632,7 +634,7 @@ def _setitem_for_tensor_array(var, item, value):
item = paddle.cast(to_static_variable(item), dtype='int64')
value = to_static_variable(value)
array_write(x=value, i=item, array=var)
return array_write(x=value, i=item, array=var)
else:
raise NotImplementedError(
"Only support __setitem__ by Int/Variable in tensor_array, but gets {}".format(
......@@ -807,17 +809,31 @@ def _setitem_impl_(var, item, value):
if paddle.in_dynamic_mode():
var._bump_inplace_version()
output = var
else:
helper = paddle.fluid.layer_helper.LayerHelper('set_value', **locals())
output = helper.create_variable_for_type_inference(dtype=var.dtype)
cur_block = default_main_program().current_block()
cur_block.append_op(
type="set_value",
inputs=inputs,
outputs={'Out': var},
outputs={'Out': output},
attrs=attrs,
inplace_map={"Input": "Out"},
)
return var
if not paddle.in_dynamic_mode():
# map var to the new output
from paddle.jit.dy2static.program_translator import (
ProgramTranslator,
)
ProgramTranslator.get_instance()._params_map.add(
cur_block.program, var.desc.id(), output
)
return output
# the item is a tensor of bool
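The mapping above exists because set_value no longer writes into var in place under static mode: a fresh output Variable is created and var.desc.id() is mapped to it, so later reads of var (see convert_load below) are redirected to the new tensor. A hedged sketch of the round trip using the internal names from this diff (illustrative only, not public API):

import paddle
from paddle.fluid.variable_index import _setitem_impl_
from paddle.jit.dy2static.program_translator import ProgramTranslator

paddle.enable_static()
var = paddle.ones([2, 2])
out = _setitem_impl_(var, (0, 0), 1.0)  # appends a set_value op, returns a new Variable

block = paddle.static.default_main_program().current_block()
mapped = ProgramTranslator.get_instance()._params_map.get(block.program, var.desc.id())
# mapped is `out`: dy2st will hand back `out` wherever the code still reads `var`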
......@@ -848,11 +864,19 @@ def set_value_for_bool_tensor(var, item, value):
gather_val = gather_nd(var, idx)
gather_val_new = value - gather_val
out = scatter_nd_add(var, idx, gather_val_new)
var[:] = out
var = _setitem_impl_(var, ..., out)
return var
def idx_is_empty(var):
return var
from paddle.static.nn import cond
# If all the bool index is False, just do nothing
cond(item.any(), lambda: idx_not_empty(var, item, value))
var = cond(
item.any(),
lambda: idx_not_empty(var, item, value),
lambda: idx_is_empty(var),
)
return var
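The identity branch was added because paddle.static.nn.cond only yields a usable result when both branches return tensors of matching shape and dtype; without a false branch, the assignment result would be lost when the bool index is all-False. A minimal sketch of the two-branch pattern (pred and x illustrative):

import paddle

paddle.enable_static()
x = paddle.ones([3])
pred = x.sum() > 0  # boolean scalar tensor
out = paddle.static.nn.cond(pred, lambda: x + 1, lambda: x)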
......@@ -178,16 +178,23 @@ def minimize_lbfgs(
shape=[], fill_value=(head - 1).mod(history_size), dtype='int64'
)
def cond(i, q):
def cond(i, q, ai_vec):
return i != tail
def body(i, q):
ai_vec[i] = rhok_vec[i] * paddle.dot(sk_vec[i], q)
def body(i, q, ai_vec):
if paddle.in_dynamic_mode():
ai_vec[i] = rhok_vec[i] * paddle.dot(sk_vec[i], q)
else:
ai_vec = paddle.static.setitem(
ai_vec, i, rhok_vec[i] * paddle.dot(sk_vec[i], q)
)
q = q - ai_vec[i] * yk_vec[i]
i = (i - 1).mod(history_size)
return i, q
return i, q, ai_vec
paddle.static.nn.while_loop(cond=cond, body=body, loop_vars=[i, q])
paddle.static.nn.while_loop(
cond=cond, body=body, loop_vars=[i, q, ai_vec]
)
r = paddle.matmul(H0, q)
......@@ -234,10 +241,14 @@ def minimize_lbfgs(
lambda: paddle.full(shape=[1], fill_value=1000.0, dtype=dtype),
lambda: 1.0 / rhok_inv,
)
sk_vec[head] = sk
yk_vec[head] = yk
rhok_vec[head] = rhok
if paddle.in_dynamic_mode():
sk_vec[head] = sk
yk_vec[head] = yk
rhok_vec[head] = rhok
else:
sk_vec = paddle.static.setitem(sk_vec, head, sk)
yk_vec = paddle.static.setitem(yk_vec, head, yk)
rhok_vec = paddle.static.setitem(rhok_vec, head, rhok)
head = (head + 1) % history_size
def true_fn(tail):
......
......@@ -20,7 +20,7 @@ from paddle.fluid.dygraph.base import (
_convert_into_variable,
in_declarative_mode,
)
from paddle.fluid.framework import Variable, core
from paddle.fluid.framework import Variable, core, default_main_program
from paddle.fluid.layers import control_flow
from paddle.fluid.layers.control_flow import while_loop
......@@ -48,6 +48,19 @@ def convert_load(x):
TODO:(@xiongkun) may run convert_load in dygraph mode, which should be fixed.
"""
return _convert_into_variable(x)
# get the new output of the var
if in_declarative_mode() and isinstance(x, Variable):
cur_block = default_main_program().current_block()
from paddle.jit.dy2static.program_translator import ProgramTranslator
new_var = ProgramTranslator.get_instance()._params_map.get(
cur_block.program, x.desc.id()
)
if new_var is not None:
return new_var
return x
......
......@@ -1125,6 +1125,36 @@ class ParametersRecorder:
return id(program)
class ParametersMap:
def __init__(self):
self.params_dict = {}
@synchronized
def add(self, program, id, param):
"""use the default_program as key, append param the parameter list."""
key = self._program_hash(program)
if key not in self.params_dict:
self.params_dict[key] = {}
params = self.params_dict[key]
params[id] = param
def get(self, program, id):
params = self.params_dict.get(self._program_hash(program))
if params is None:
return None
if id in params.keys():
return params[id]
return None
def _program_hash(self, program):
"""
Because the program is not deleted while calling from_func_spec,
it is safe to use id(program) as the hash key.
"""
return id(program)
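A short usage sketch of ParametersMap in isolation; the stand-in program object is illustrative, since only id(program) is used as the key:

pm = ParametersMap()
prog = object()                  # any object works: the key is id(prog)
pm.add(prog, 7, "new_output")
assert pm.get(prog, 7) == "new_output"
assert pm.get(prog, 8) is None   # unknown variable ids fall through to None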
class FallbackProgramLayer:
__slots__ = [
'_instance',
......@@ -1386,6 +1416,7 @@ class ProgramTranslator:
self._initialized = True
self._program_cache = ProgramCache()
self._params_recorder = ParametersRecorder()
self._params_map = ParametersMap()
self.enable_to_static = True
def enable(self, enable_to_static):
......
......@@ -37,6 +37,7 @@ from .io import set_program_state # noqa: F401
from ..fluid import Scope # noqa: F401
from .input import data # noqa: F401
from .input import InputSpec # noqa: F401
from .input import setitem # noqa: F401
from ..tensor.creation import create_parameter # noqa: F401
from ..tensor.creation import create_global_var # noqa: F401
......
......@@ -18,6 +18,8 @@ from paddle.fluid.data_feeder import check_type
from paddle.fluid.framework import convert_np_dtype_to_dtype_, static_only
from paddle.fluid.layer_helper import LayerHelper
from ..fluid.variable_index import _setitem_impl_
__all__ = []
......@@ -342,3 +344,28 @@ class InputSpec:
def __ne__(self, other):
return not self == other
def setitem(x, index, value):
"""
x(Tensor): Input Tensor.
index(Scalar|Tuple|List|Tensor): The index where the value will be set.
value(Scalar|Tensor): The value to be set.
[How to write the index?]
1. ':' -> slice(),
(1) a[:]=v -> setitem(a, slice(None,None,None), v)
(2) a[1::2]=v -> setitem(a, slice(1,None,2), v)
2. If there are multiple indices for different axes, pack them in a TUPLE (not a LIST).
(1) a[1, 2]=v -> setitem(a, (1, 2), v)
(2) a[[1,2],[2,3]]=v -> setitem(a, ([1,2],[2,3]), v)
(3) a[1, :, 3]=v -> setitem(a, (1, slice(None,None,None), 3), v)
(4) a[1, ..., 2]=v -> setitem(a, (1, ..., 2), v)
3. You can always use a TUPLE as the index input, even if there is only one index.
(1) a[Tensor([10,10])]=v -> setitem(a, (Tensor([10,10]),), v)
(2) a[1]=v -> setitem(a, (1,), v)
"""
return _setitem_impl_(x, index, value)
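For reference, a hedged sketch of how Python slice syntax translates into the tuple-of-slices form this API expects (shapes illustrative):

import paddle

paddle.enable_static()
a = paddle.zeros([4, 8])
v = paddle.ones([4, 4])
# a[:, 4:] = v  becomes:
a = paddle.static.setitem(a, (slice(None), slice(4, None)), v)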
......@@ -5788,11 +5788,26 @@ def vander(x, n=None, increasing=False, name=None):
res = paddle.empty([x.shape[0], n], dtype=x.dtype)
if n > 0:
res[:, 0] = paddle.to_tensor([1], dtype=x.dtype)
if n > 1:
res[:, 1:] = x[:, None]
res[:, 1:] = paddle.cumprod(res[:, 1:], dim=-1)
if paddle.in_dynamic_mode():
if n > 0:
res[:, 0] = paddle.to_tensor([1], dtype=x.dtype)
if n > 1:
res[:, 1:] = x[:, None]
res[:, 1:] = paddle.cumprod(res[:, 1:], dim=-1)
else:
if n > 0:
res = paddle.static.setitem(
res, (slice(None), 0), paddle.to_tensor([1], dtype=x.dtype)
)
if n > 1:
res = paddle.static.setitem(
res, (slice(None), slice(1, None)), x[:, None]
)
res = paddle.static.setitem(
res,
(slice(None), slice(1, None)),
paddle.cumprod(res[:, 1:], dim=-1),
)
res = res[:, ::-1] if not increasing else res
return res
......
......@@ -222,12 +222,12 @@ def _affine_grid(theta, w, h, ow, oh):
base_grid = paddle.ones((1, oh, ow, 3), dtype=theta.dtype)
x_grid = paddle.linspace(-ow * 0.5 + d, ow * 0.5 + d - 1, ow)
base_grid[..., 0] = x_grid
if paddle.in_dynamic_mode():
y_grid = paddle.linspace(
-oh * 0.5 + d, oh * 0.5 + d - 1, oh
).unsqueeze_(-1)
base_grid[..., 0] = x_grid
base_grid[..., 1] = y_grid
tmp = paddle.to_tensor([0.5 * w, 0.5 * h])
else:
......@@ -236,7 +236,8 @@ def _affine_grid(theta, w, h, ow, oh):
y_grid = paddle.linspace(-oh * 0.5 + d, oh * 0.5 + d - 1, oh).unsqueeze(
-1
)
base_grid[..., 1] = y_grid
base_grid = paddle.static.setitem(base_grid, (..., 0), x_grid)
base_grid = paddle.static.setitem(base_grid, (..., 1), y_grid)
tmp = paddle.assign(np.array([0.5 * w, 0.5 * h], dtype="float32"))
scaled_theta = theta.transpose((0, 2, 1)) / tmp
......@@ -397,6 +398,17 @@ def rotate(
0.0,
]
matrix = paddle.to_tensor(matrix, place=img.place)
matrix[2] += (
matrix[0] * (-rotn_center[0] - post_trans[0])
+ matrix[1] * (-rotn_center[1] - post_trans[1])
+ rotn_center[0]
)
matrix[5] += (
matrix[3] * (-rotn_center[0] - post_trans[0])
+ matrix[4] * (-rotn_center[1] - post_trans[1])
+ rotn_center[1]
)
else:
angle = angle / 180 * math.pi
matrix = paddle.concat(
......@@ -409,16 +421,22 @@ def rotate(
paddle.zeros([1]),
]
)
matrix[2] += matrix[0] * (-rotn_center[0] - post_trans[0]) + matrix[1] * (
-rotn_center[1] - post_trans[1]
)
matrix[5] += matrix[3] * (-rotn_center[0] - post_trans[0]) + matrix[4] * (
-rotn_center[1] - post_trans[1]
)
matrix[2] += rotn_center[0]
matrix[5] += rotn_center[1]
matrix = paddle.static.setitem(
matrix,
2,
matrix[2]
+ matrix[0] * (-rotn_center[0] - post_trans[0])
+ matrix[1] * (-rotn_center[1] - post_trans[1])
+ rotn_center[0],
)
matrix = paddle.static.setitem(
matrix,
5,
matrix[5]
+ matrix[3] * (-rotn_center[0] - post_trans[0])
+ matrix[4] * (-rotn_center[1] - post_trans[1])
+ rotn_center[1],
)
matrix = matrix.reshape((1, 2, 3))
......@@ -621,7 +639,12 @@ def erase(img, i, j, h, w, v, inplace=False):
if not inplace:
img = img.clone()
img[..., i : i + h, j : j + w] = v
if paddle.in_dynamic_mode():
img[..., i : i + h, j : j + w] = v
else:
img = paddle.static.setitem(
img, (..., slice(i, i + h), slice(j, j + w)), v
)
return img
......
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import numpy as np
import paddle
class TestSetItemBase(unittest.TestCase):
def setUp(self) -> None:
pass
def init_data(self):
paddle.seed(2023)
x = paddle.randn([4, 8, 16, 32])
x.stop_gradient = False
return x
def init_func(self):
def foo(x):
y = x + 1
y[:, 2] = x[:, 2] + 99
return y
return foo
def test_case(self):
func = self.init_func()
dy_res = self.run_dygraph(func)
st_res = self.run_to_static(func)
for dy_out, st_out in zip(dy_res, st_res):
np.testing.assert_allclose(dy_out.numpy(), st_out.numpy())
def run_dygraph(self, func):
x = self.init_data()
y = func(x)
x_grad = paddle.grad(y, x)[0]
return y, x_grad
def run_to_static(self, func):
func = paddle.jit.to_static(func)
return self.run_dygraph(func)
class TestCase1(TestSetItemBase):
def init_func(self):
def foo(x):
y = x + 1
y[2] = x[2] + 99 # (2, )
return y
return foo
class TestCase2(TestSetItemBase):
def init_func(self):
def foo(x):
y = x + 1
y[:] = x[:] + 99 # slice(None,None,None)
return y
return foo
class TestCase3(TestSetItemBase):
def init_func(self):
def foo(x):
y = x + 1
y[1::2] = x[1::2] + 99 # slice(1,None,2)
return y
return foo
class TestCase4(TestSetItemBase):
def init_func(self):
def foo(x):
y = x + 1
y[1, 2] = x[1, 2] + 99 # (1, 2)
return y
return foo
class TestCase5(TestSetItemBase):
def init_func(self):
def foo(x):
y = x + 1
y[[1, 2], [2, 3]] = x[[1, 2], [2, 3]] + 99 # ([1,2],[2,3])
return y
return foo
class TestCase6(TestSetItemBase):
def init_func(self):
def foo(x):
y = x + 1
y[1, :, 3] = x[1, :, 3] + 99 # (1, slice(None,None,None), 3)
return y
return foo
class TestCase7(TestSetItemBase):
def init_func(self):
def foo(x):
y = x + 1
y[1, ..., 2] = x[1, ..., 2] + 99 # (1, ..., 2)
return y
return foo
class TestCase8(TestSetItemBase):
def init_func(self):
def foo(x):
y = x + 1
index = paddle.to_tensor([1, 2], dtype="int64")
y[index] = x[index] + 99 # Tensor([1,2])
return y
return foo
class TestCase9(TestSetItemBase):
def init_func(self):
def foo(x):
y = x + 1
one = paddle.to_tensor(1, dtype="int64")
two = paddle.to_tensor(2, dtype="int64")
y[one, :, :, 2] = x[1, :, :, two] + 100 # Tensor(1), Tensor(2)
return y
return foo
class TestCase10(TestSetItemBase):
def init_func(self):
def foo(x):
y = x + 1
y[..., 4:6] = y[..., 4:6] * 10000
return y
return foo
class TestCase11(TestSetItemBase):
# Test gradient of value tensor
def init_func(self):
def foo(x, value):
y = x + 1
y[2, 4] = value
return y
return foo
def run_dygraph(self, func):
x = self.init_data()
value = paddle.ones((16, 32))
value.stop_gradient = False
y = func(x, value)
x_grad, value_grad = paddle.grad(y, [x, value])
return y, x_grad, value_grad
if __name__ == '__main__':
unittest.main()
......@@ -129,8 +129,12 @@ class TestSliceWithoutControlFlow(unittest.TestCase):
return self._run(to_static=False)
def _run(self, to_static):
paddle.jit.enable_to_static(to_static)
res = self.dygraph_func(self.input)
func = (
paddle.jit.to_static(self.dygraph_func)
if to_static
else self.dygraph_func
)
res = func(self.input)
return res.numpy()
def run_static_mode(self):
......
......@@ -81,7 +81,8 @@ class TestSetValue(unittest.TestCase):
with paddle.static.program_guard(mp, sp):
x = paddle.ones([3, 4], dtype=paddle.int32)
patch = np.array([41, 42]).astype(np.int32)
x[:1, :2] = patch
index = (slice(None, 1), slice(None, 2))
x = paddle.static.setitem(x, index, patch)
x_input = np.ones([3, 4], dtype=np.int32)
x_output = x_input.copy()
......@@ -110,10 +111,12 @@ class TestSetValue(unittest.TestCase):
patch = np.array(
[np.iinfo(np.int64).max, np.iinfo(np.int64).min]
).astype(np.int64)
x[:1, :2] = patch
index = (slice(None, 1), slice(None, 2))
x = paddle.static.setitem(x, index, patch)
x_input = np.ones([3, 4], dtype=np.int64)
x_output = x_input.copy()
x_output[:1, :2] = patch
self.fetch_list = [x.name]
......@@ -142,7 +145,8 @@ class TestSetValue(unittest.TestCase):
patch = np.array(
[np.finfo(np.float32).max, np.finfo(np.float32).min]
).astype(np.float32)
x[:1, :2] = patch
index = (slice(None, 1), slice(None, 2))
x = paddle.static.setitem(x, index, patch)
x_input = np.ones([3, 4], dtype=np.float32)
x_output = x_input.copy()
......@@ -171,7 +175,8 @@ class TestSetValue(unittest.TestCase):
patch = np.array(
[np.finfo(np.float64).max, np.finfo(np.float64).min]
).astype(np.float64)
x[:1, :2] = patch
index = (slice(None, 1), slice(None, 2))
x = paddle.static.setitem(x, index, patch)
x_input = np.ones([3, 4], dtype=np.float64)
x_output = x_input.copy()
......@@ -200,7 +205,8 @@ class TestSetValue(unittest.TestCase):
patch = np.array(
[np.finfo(np.float16).max, np.finfo(np.float16).min]
).astype(np.float16)
x[:1, :2] = patch
index = (slice(None, 1), slice(None, 2))
x = paddle.static.setitem(x, index, patch)
x_input = np.ones([3, 4], dtype=np.float16)
x_output = x_input.copy()
......@@ -227,7 +233,8 @@ class TestSetValue(unittest.TestCase):
with paddle.static.program_guard(mp, sp):
x = paddle.ones([3, 4], dtype=paddle.bool)
patch = np.array([True, False])
x[:1, :2] = patch
index = (slice(None, 1), slice(None, 2))
x = paddle.static.setitem(x, index, patch)
x_input = np.ones([3, 4], dtype=bool)
x_output = x_input.copy()
......@@ -257,7 +264,8 @@ class TestSetValue(unittest.TestCase):
paddle.ones([3, 4], dtype=paddle.float32),
)
patch = np.array([42.1 + 42.1j, 42.2 + 42.2j]).astype(np.complex64)
x[:1, :2] = patch
index = (slice(None, 1), slice(None, 2))
x = paddle.static.setitem(x, index, patch)
x_input = (np.ones([3, 4]) + 1j * np.ones([3, 4])).astype(np.complex64)
x_output = x_input.copy()
......@@ -282,7 +290,8 @@ class TestSetValue(unittest.TestCase):
np.finfo(np.float64).min + 1j * np.finfo(np.float64).max,
]
).astype(np.complex128)
x[:1, :2] = patch
index = (slice(None, 1), slice(None, 2))
x = paddle.static.setitem(x, index, patch)
x_input = (np.ones([3, 4]) + 1j * np.ones([3, 4])).astype(np.complex128)
x_output = x_input.copy()
......
......@@ -1365,12 +1365,10 @@ class TestVarBaseSetitemBoolIndex(unittest.TestCase):
def _test(self, value):
paddle.disable_static()
self.assertEqual(self.tensor_x.inplace_version, 0)
id_origin = id(self.tensor_x)
index_1 = paddle.to_tensor(np.array([True, False, False, False]))
self.tensor_x[index_1] = value
self.assertEqual(self.tensor_x.inplace_version, 1)
if isinstance(value, (int, float)):
result = np.zeros((2, 3)).astype(self.dtype) + value
......@@ -1383,13 +1381,11 @@ class TestVarBaseSetitemBoolIndex(unittest.TestCase):
index_2 = paddle.to_tensor(np.array([False, True, False, False]))
self.tensor_x[index_2] = value
self.assertEqual(self.tensor_x.inplace_version, 2)
np.testing.assert_array_equal(self.tensor_x[1].numpy(), result)
self.assertEqual(id_origin, id(self.tensor_x))
index_3 = paddle.to_tensor(np.array([True, True, True, True]))
self.tensor_x[index_3] = value
self.assertEqual(self.tensor_x.inplace_version, 3)
np.testing.assert_array_equal(self.tensor_x[3].numpy(), result)
self.assertEqual(id_origin, id(self.tensor_x))
......
......@@ -844,8 +844,7 @@ class TestListIndex(unittest.TestCase):
name='value', shape=value_np.shape, dtype='float32'
)
x[index] = value
y = x
y = paddle.static.setitem(x, index, value)
place = paddle.fluid.CPUPlace()
prog = paddle.static.default_main_program()
......@@ -1042,9 +1041,8 @@ class TestListIndex(unittest.TestCase):
name='index_2', shape=index2.shape, dtype='int32'
)
x1[index_1, index_2] = value
x2[index_1] = value
x1_out = paddle.static.setitem(x1, (index_1, index_2), value)
x2_out = paddle.static.setitem(x2, index_1, value)
place = (
paddle.fluid.CPUPlace()
if not paddle.fluid.core.is_compiled_with_cuda()
......@@ -1055,7 +1053,7 @@ class TestListIndex(unittest.TestCase):
exe = paddle.static.Executor(place)
exe.run(paddle.static.default_startup_program())
fetch_list = [x1.name, x2.name]
fetch_list = [x1_out.name, x2_out.name]
setitem_pp = exe.run(
prog,
......@@ -1124,10 +1122,10 @@ class TestListIndex(unittest.TestCase):
name='x2', shape=array.shape, dtype='float32'
)
x1[index_mod1, index_mod2] = 1
x2[index_mod1] = 2.5
y1 = x1[index_mod2, index_mod1]
y2 = x2[index_mod2]
x1_out = paddle.static.setitem(x1, (index_mod1, index_mod2), 1)
x2_out = paddle.static.setitem(x2, index_mod1, 2.5)
y1 = x1_out[index_mod2, index_mod1]
y2 = x2_out[index_mod2]
place = (
paddle.fluid.CPUPlace()
if not paddle.fluid.core.is_compiled_with_cuda()
......@@ -1137,7 +1135,7 @@ class TestListIndex(unittest.TestCase):
prog = paddle.static.default_main_program()
exe = paddle.static.Executor(place)
exe.run(paddle.static.default_startup_program())
fetch_list = [x1.name, x2.name, y1.name, y2.name]
fetch_list = [x1_out.name, x2_out.name, y1.name, y2.name]
setitem_pp = exe.run(
prog,
......
......@@ -3141,7 +3141,7 @@ class TestSundryAPIStatic(unittest.TestCase):
x = paddle.arange(2 * 3 * 4 * 5).reshape((2, 3, 4, 5))
x.stop_gradient = False
out = x * 2
out[1, 2, 3, 4] = 10
out = paddle.static.setitem(out, (1, 2, 3, 4), 10)
paddle.static.append_backward(out.sum())
prog = paddle.static.default_main_program()
res = self.exe.run(prog, fetch_list=[out, x.grad_name])
......@@ -3162,7 +3162,7 @@ class TestSundryAPIStatic(unittest.TestCase):
x.stop_gradient = False
indice = paddle.full([], 1, dtype='int32')
out = x * 1
out[indice, indice] = 0.5
out = paddle.static.setitem(out, (indice, indice), 0.5)
paddle.static.append_backward(out.sum())
prog = paddle.static.default_main_program()
res = self.exe.run(prog, fetch_list=[out, x.grad_name])
......@@ -3181,7 +3181,7 @@ class TestSundryAPIStatic(unittest.TestCase):
v.stop_gradient = False
indice = paddle.full([], 1, dtype='int32')
out = x * 1
out[indice] = v
out = paddle.static.setitem(out, indice, v)
paddle.static.append_backward(out.sum())
prog = paddle.static.default_main_program()
res = self.exe.run(prog, fetch_list=[out, x.grad_name, v.grad_name])
......
This diff is collapsed.