未验证 提交 6dd70b9b 编写于 作者: H Huihuang Zheng 提交者: GitHub

[Cherry-pick] Add Static Variable Clone (#30208) #30270

Cherry-pick of PR #30208 , this PR added clone method for static Variable so that this interface will be same as dygraph. It fixed some bugs in dy2stat where users called clone of dygraph Tensor.
上级 fb66355e
......@@ -80,9 +80,12 @@ class TraceBackFrame(OriginInfo):
self.source_code = source_code
def formated_message(self):
# self.source_code may be empty in some functions.
# For example, decorator generated function
return ' File "{}", line {}, in {}\n\t{}'.format(
self.location.filepath, self.location.lineno, self.function_name,
self.source_code.lstrip())
self.source_code.lstrip()
if isinstance(self.source_code, str) else self.source_code)
class ErrorData(object):
......
......@@ -246,8 +246,10 @@ def _static_only_(func):
def _fake_interface_only_(func):
def __impl__(*args, **kwargs):
raise AssertionError(
"'%s' should be called by imperative Varible in imperative mode, please use fluid.dygraph.guard() as context to run it in imperative mode"
% func.__name__)
"'%s' should be called by imperative Varible in imperative mode, please run it in dygraph "
"mode. You can turn off paddle.enable_static() if you are in static mode, or turn off "
"ProgramTranslator if you are using @paddle.jit.to_static" %
func.__name__)
return __impl__
......@@ -1625,6 +1627,40 @@ class Variable(object):
"""
return self.desc.type()
def clone(self):
"""
Returns a new static Variable, which is the clone of the original static
Variable. It remains in the current graph, that is, the cloned Variable
provides gradient propagation. Calling ``out = tensor.clone()`` is same
as ``out = assign(tensor)`` .
Returns:
Variable: The cloned Variable.
Examples:
.. code-block:: python
import paddle
paddle.enable_static()
# create a static Variable
x = paddle.static.data(name='x', shape=[3, 2, 1])
# create a cloned Variable
y = x.clone()
"""
output = self.block.create_var(
name=unique_name.generate_with_ignorable_key(self.name + "_clone"),
dtype=self.dtype,
type=self.type,
persistable=self.persistable,
stop_gradient=self.stop_gradient)
self.block.append_op(
type='assign', inputs={'X': [self]}, outputs={'Out': [output]})
return output
def _set_error_clip(self, error_clip):
"""
Set the error_clip.
......
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import numpy
import paddle
import unittest
@paddle.jit.to_static
def tensor_clone(x):
x = paddle.to_tensor(x)
y = x.clone()
return y
class TestTensorClone(unittest.TestCase):
def _run(self, to_static):
prog_trans = paddle.jit.ProgramTranslator()
prog_trans.enable(to_static)
x = paddle.ones([1, 2, 3])
return tensor_clone(x).numpy()
def test_tensor_clone(self):
dygraph_res = self._run(to_static=False)
static_res = self._run(to_static=True)
self.assertTrue(
numpy.allclose(dygraph_res, static_res),
msg='dygraph res is {}\nstatic_res is {}'.format(dygraph_res,
static_res))
@paddle.jit.to_static
def tensor_numpy(x):
x = paddle.to_tensor(x)
x.clear_gradient()
return x
class TestTensorDygraphOnlyMethodError(unittest.TestCase):
def _run(self, to_static):
prog_trans = paddle.jit.ProgramTranslator()
prog_trans.enable(to_static)
x = paddle.zeros([2, 2])
y = tensor_numpy(x)
return y.numpy()
def test_to_static_numpy_report_error(self):
dygraph_res = self._run(to_static=False)
with self.assertRaises(AssertionError):
static_res = self._run(to_static=True)
if __name__ == '__main__':
unittest.main()
......@@ -157,9 +157,12 @@ class Test_Detach(unittest.TestCase):
except Exception as e:
# Here is to check
assert type(e) == AssertionError
assert str(
e
) == "'detach' should be called by imperative Varible in imperative mode, please use fluid.dygraph.guard() as context to run it in imperative mode"
assert str(e) == (
"'detach' should be called by imperative Varible "
"in imperative mode, please run it in dygraph mode. You can "
"turn off paddle.enable_static() if you are in static mode, "
"or turn off ProgramTranslator if you are using "
"@paddle.jit.to_static")
class TestInplace(unittest.TestCase):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册