未验证 提交 c372a763 编写于 作者: H Huihuang Zheng 提交者: GitHub

Add Static Variable Clone (#30208)

Add clone method for static Variable so that this interface will be same as dygraph. It fixed some bugs in dy2stat
上级 fee42441
......@@ -80,9 +80,12 @@ class TraceBackFrame(OriginInfo):
self.source_code = source_code
def formated_message(self):
# self.source_code may be empty in some functions.
# For example, decorator generated function
return ' File "{}", line {}, in {}\n\t{}'.format(
self.location.filepath, self.location.lineno, self.function_name,
self.source_code.lstrip())
self.source_code.lstrip()
if isinstance(self.source_code, str) else self.source_code)
class ErrorData(object):
......
......@@ -246,8 +246,10 @@ def _static_only_(func):
def _fake_interface_only_(func):
def __impl__(*args, **kwargs):
raise AssertionError(
"'%s' should be called by imperative Varible in imperative mode, please use fluid.dygraph.guard() as context to run it in imperative mode"
% func.__name__)
"'%s' should be called by imperative Varible in imperative mode, please run it in dygraph "
"mode. You can turn off paddle.enable_static() if you are in static mode, or turn off "
"ProgramTranslator if you are using @paddle.jit.to_static" %
func.__name__)
return __impl__
......@@ -1629,6 +1631,40 @@ class Variable(object):
"""
return self.desc.type()
def clone(self):
"""
Returns a new static Variable, which is the clone of the original static
Variable. It remains in the current graph, that is, the cloned Variable
provides gradient propagation. Calling ``out = tensor.clone()`` is same
as ``out = assign(tensor)`` .
Returns:
Variable: The cloned Variable.
Examples:
.. code-block:: python
import paddle
paddle.enable_static()
# create a static Variable
x = paddle.static.data(name='x', shape=[3, 2, 1])
# create a cloned Variable
y = x.clone()
"""
output = self.block.create_var(
name=unique_name.generate_with_ignorable_key(self.name + "_clone"),
dtype=self.dtype,
type=self.type,
persistable=self.persistable,
stop_gradient=self.stop_gradient)
self.block.append_op(
type='assign', inputs={'X': [self]}, outputs={'Out': [output]})
return output
def _set_error_clip(self, error_clip):
"""
Set the error_clip.
......
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import numpy
import paddle
import unittest
@paddle.jit.to_static
def tensor_clone(x):
x = paddle.to_tensor(x)
y = x.clone()
return y
class TestTensorClone(unittest.TestCase):
def _run(self, to_static):
prog_trans = paddle.jit.ProgramTranslator()
prog_trans.enable(to_static)
x = paddle.ones([1, 2, 3])
return tensor_clone(x).numpy()
def test_tensor_clone(self):
dygraph_res = self._run(to_static=False)
static_res = self._run(to_static=True)
self.assertTrue(
numpy.allclose(dygraph_res, static_res),
msg='dygraph res is {}\nstatic_res is {}'.format(dygraph_res,
static_res))
@paddle.jit.to_static
def tensor_numpy(x):
x = paddle.to_tensor(x)
x.clear_gradient()
return x
class TestTensorDygraphOnlyMethodError(unittest.TestCase):
def _run(self, to_static):
prog_trans = paddle.jit.ProgramTranslator()
prog_trans.enable(to_static)
x = paddle.zeros([2, 2])
y = tensor_numpy(x)
return y.numpy()
def test_to_static_numpy_report_error(self):
dygraph_res = self._run(to_static=False)
with self.assertRaises(AssertionError):
static_res = self._run(to_static=True)
if __name__ == '__main__':
unittest.main()
......@@ -157,9 +157,12 @@ class Test_Detach(unittest.TestCase):
except Exception as e:
# Here is to check
assert type(e) == AssertionError
assert str(
e
) == "'detach' should be called by imperative Varible in imperative mode, please use fluid.dygraph.guard() as context to run it in imperative mode"
assert str(e) == (
"'detach' should be called by imperative Varible "
"in imperative mode, please run it in dygraph mode. You can "
"turn off paddle.enable_static() if you are in static mode, "
"or turn off ProgramTranslator if you are using "
"@paddle.jit.to_static")
class TestInplace(unittest.TestCase):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册