From c372a763036db7d52a7283423f32e7a037a3b773 Mon Sep 17 00:00:00 2001 From: Huihuang Zheng Date: Mon, 11 Jan 2021 10:55:27 +0800 Subject: [PATCH] Add Static Variable Clone (#30208) Add clone method for static Variable so that this interface will be same as dygraph. It fixed some bugs in dy2stat --- .../fluid/dygraph/dygraph_to_static/error.py | 5 +- python/paddle/fluid/framework.py | 40 ++++++++++- .../dygraph_to_static/test_tensor_methods.py | 67 +++++++++++++++++++ .../fluid/tests/unittests/test_detach.py | 9 ++- 4 files changed, 115 insertions(+), 6 deletions(-) create mode 100644 python/paddle/fluid/tests/unittests/dygraph_to_static/test_tensor_methods.py diff --git a/python/paddle/fluid/dygraph/dygraph_to_static/error.py b/python/paddle/fluid/dygraph/dygraph_to_static/error.py index 350e0ad5d7..a994fbb107 100644 --- a/python/paddle/fluid/dygraph/dygraph_to_static/error.py +++ b/python/paddle/fluid/dygraph/dygraph_to_static/error.py @@ -80,9 +80,12 @@ class TraceBackFrame(OriginInfo): self.source_code = source_code def formated_message(self): + # self.source_code may be empty in some functions. + # For example, decorator generated function return ' File "{}", line {}, in {}\n\t{}'.format( self.location.filepath, self.location.lineno, self.function_name, - self.source_code.lstrip()) + self.source_code.lstrip() + if isinstance(self.source_code, str) else self.source_code) class ErrorData(object): diff --git a/python/paddle/fluid/framework.py b/python/paddle/fluid/framework.py index 39005d9a98..143b4a8f71 100644 --- a/python/paddle/fluid/framework.py +++ b/python/paddle/fluid/framework.py @@ -246,8 +246,10 @@ def _static_only_(func): def _fake_interface_only_(func): def __impl__(*args, **kwargs): raise AssertionError( - "'%s' should be called by imperative Varible in imperative mode, please use fluid.dygraph.guard() as context to run it in imperative mode" - % func.__name__) + "'%s' should be called by imperative Varible in imperative mode, please run it in dygraph " + "mode. You can turn off paddle.enable_static() if you are in static mode, or turn off " + "ProgramTranslator if you are using @paddle.jit.to_static" % + func.__name__) return __impl__ @@ -1629,6 +1631,40 @@ class Variable(object): """ return self.desc.type() + def clone(self): + """ + Returns a new static Variable, which is the clone of the original static + Variable. It remains in the current graph, that is, the cloned Variable + provides gradient propagation. Calling ``out = tensor.clone()`` is same + as ``out = assign(tensor)`` . + + Returns: + Variable: The cloned Variable. + + Examples: + .. code-block:: python + + import paddle + + paddle.enable_static() + + # create a static Variable + x = paddle.static.data(name='x', shape=[3, 2, 1]) + # create a cloned Variable + y = x.clone() + + """ + output = self.block.create_var( + name=unique_name.generate_with_ignorable_key(self.name + "_clone"), + dtype=self.dtype, + type=self.type, + persistable=self.persistable, + stop_gradient=self.stop_gradient) + + self.block.append_op( + type='assign', inputs={'X': [self]}, outputs={'Out': [output]}) + return output + def _set_error_clip(self, error_clip): """ Set the error_clip. diff --git a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_tensor_methods.py b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_tensor_methods.py new file mode 100644 index 0000000000..f06d48c963 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_tensor_methods.py @@ -0,0 +1,67 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +import numpy +import paddle +import unittest + + +@paddle.jit.to_static +def tensor_clone(x): + x = paddle.to_tensor(x) + y = x.clone() + return y + + +class TestTensorClone(unittest.TestCase): + def _run(self, to_static): + prog_trans = paddle.jit.ProgramTranslator() + prog_trans.enable(to_static) + x = paddle.ones([1, 2, 3]) + return tensor_clone(x).numpy() + + def test_tensor_clone(self): + dygraph_res = self._run(to_static=False) + static_res = self._run(to_static=True) + self.assertTrue( + numpy.allclose(dygraph_res, static_res), + msg='dygraph res is {}\nstatic_res is {}'.format(dygraph_res, + static_res)) + + +@paddle.jit.to_static +def tensor_numpy(x): + x = paddle.to_tensor(x) + x.clear_gradient() + return x + + +class TestTensorDygraphOnlyMethodError(unittest.TestCase): + def _run(self, to_static): + prog_trans = paddle.jit.ProgramTranslator() + prog_trans.enable(to_static) + x = paddle.zeros([2, 2]) + y = tensor_numpy(x) + return y.numpy() + + def test_to_static_numpy_report_error(self): + dygraph_res = self._run(to_static=False) + with self.assertRaises(AssertionError): + static_res = self._run(to_static=True) + + +if __name__ == '__main__': + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_detach.py b/python/paddle/fluid/tests/unittests/test_detach.py index 431c987a51..9a535f9e00 100644 --- a/python/paddle/fluid/tests/unittests/test_detach.py +++ b/python/paddle/fluid/tests/unittests/test_detach.py @@ -157,9 +157,12 @@ class Test_Detach(unittest.TestCase): except Exception as e: # Here is to check assert type(e) == AssertionError - assert str( - e - ) == "'detach' should be called by imperative Varible in imperative mode, please use fluid.dygraph.guard() as context to run it in imperative mode" + assert str(e) == ( + "'detach' should be called by imperative Varible " + "in imperative mode, please run it in dygraph mode. You can " + "turn off paddle.enable_static() if you are in static mode, " + "or turn off ProgramTranslator if you are using " + "@paddle.jit.to_static") class TestInplace(unittest.TestCase): -- GitLab