未验证 提交 0eb03ed7 编写于 作者: zhouweiwei2014's avatar zhouweiwei2014 提交者: GitHub

add new API: paddle.clone;Tensor.element_size;nn.utils.parameters_to_vector (#38020)

* add new API: paddle.clone;Tensor.element_size;nn.utils.parameters_to_vector

* fix comment
上级 aa059885
...@@ -126,6 +126,8 @@ if(WIN32) ...@@ -126,6 +126,8 @@ if(WIN32)
endforeach(flag_var) endforeach(flag_var)
endif() endif()
# NOTE(zhouwei): msvc max/min macro conflict with std::min/max, define NOMINMAX globally
add_definitions("-DNOMINMAX")
# windows build turn off warnings, use parallel compiling. # windows build turn off warnings, use parallel compiling.
foreach(flag_var foreach(flag_var
CMAKE_CXX_FLAGS CMAKE_CXX_FLAGS_DEBUG CMAKE_CXX_FLAGS_RELEASE CMAKE_CXX_FLAGS CMAKE_CXX_FLAGS_DEBUG CMAKE_CXX_FLAGS_RELEASE
......
...@@ -15,6 +15,7 @@ limitations under the License. */ ...@@ -15,6 +15,7 @@ limitations under the License. */
#include "paddle/fluid/framework/var_desc.h" #include "paddle/fluid/framework/var_desc.h"
#include "glog/logging.h" #include "glog/logging.h"
#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/platform/enforce.h" #include "paddle/fluid/platform/enforce.h"
namespace paddle { namespace paddle {
...@@ -116,6 +117,10 @@ proto::VarType::Type VarDesc::GetDataType() const { ...@@ -116,6 +117,10 @@ proto::VarType::Type VarDesc::GetDataType() const {
return tensor_desc().data_type(); return tensor_desc().data_type();
} }
size_t VarDesc::ElementSize() const {
return framework::SizeOfType(GetDataType());
}
std::vector<proto::VarType::Type> VarDesc::GetDataTypes() const { std::vector<proto::VarType::Type> VarDesc::GetDataTypes() const {
std::vector<proto::VarType::TensorDesc> descs = tensor_descs(); std::vector<proto::VarType::TensorDesc> descs = tensor_descs();
std::vector<proto::VarType::Type> res; std::vector<proto::VarType::Type> res;
......
...@@ -96,6 +96,8 @@ class VarDesc { ...@@ -96,6 +96,8 @@ class VarDesc {
proto::VarType::Type GetDataType() const; proto::VarType::Type GetDataType() const;
size_t ElementSize() const;
std::vector<proto::VarType::Type> GetDataTypes() const; std::vector<proto::VarType::Type> GetDataTypes() const;
void SetLoDLevel(int32_t lod_level); void SetLoDLevel(int32_t lod_level);
......
...@@ -25,6 +25,7 @@ ...@@ -25,6 +25,7 @@
#include <utility> #include <utility>
#include <vector> #include <vector>
#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/framework/operator.h" #include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/framework/type_defs.h" #include "paddle/fluid/framework/type_defs.h"
#include "paddle/fluid/framework/var_type.h" #include "paddle/fluid/framework/var_type.h"
...@@ -37,7 +38,6 @@ ...@@ -37,7 +38,6 @@
#include "paddle/fluid/platform/enforce.h" #include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/macros.h" #include "paddle/fluid/platform/macros.h"
#include "paddle/pten/include/core.h" #include "paddle/pten/include/core.h"
namespace paddle { namespace paddle {
namespace framework { namespace framework {
class Variable; class Variable;
...@@ -212,6 +212,8 @@ class VarBase { ...@@ -212,6 +212,8 @@ class VarBase {
framework::proto::VarType::Type DataType() const { return var_->DataType(); } framework::proto::VarType::Type DataType() const { return var_->DataType(); }
size_t ElementSize() const { return framework::SizeOfType(var_->DataType()); }
void SetForwardDataType(framework::proto::VarType::Type data_type) { void SetForwardDataType(framework::proto::VarType::Type data_type) {
var_->SetForwardDataType(data_type); var_->SetForwardDataType(data_type);
} }
......
...@@ -2013,6 +2013,29 @@ void BindImperative(py::module *m_ptr) { ...@@ -2013,6 +2013,29 @@ void BindImperative(py::module *m_ptr) {
auto *t = self->MutableVar()->GetMutable<framework::LoDTensor>(); auto *t = self->MutableVar()->GetMutable<framework::LoDTensor>();
return t->numel(); return t->numel();
}) })
.def("element_size", &imperative::VarBase::ElementSize, R"DOC(
Returns the size in bytes of an element in the Tensor.
Examples:
.. code-block:: python
import paddle
x = paddle.to_tensor(1, dtype='bool')
x.element_size() # 1
x = paddle.to_tensor(1, dtype='float16')
x.element_size() # 2
x = paddle.to_tensor(1, dtype='float32')
x.element_size() # 4
x = paddle.to_tensor(1, dtype='float64')
x.element_size() # 8
x = paddle.to_tensor(1, dtype='complex128')
x.element_size() # 16
)DOC")
.def_property("name", &imperative::VarBase::Name, .def_property("name", &imperative::VarBase::Name,
&imperative::VarBase::SetName) &imperative::VarBase::SetName)
.def_property("stop_gradient", .def_property("stop_gradient",
...@@ -2020,28 +2043,40 @@ void BindImperative(py::module *m_ptr) { ...@@ -2020,28 +2043,40 @@ void BindImperative(py::module *m_ptr) {
&imperative::VarBase::SetOverridedStopGradient) &imperative::VarBase::SetOverridedStopGradient)
.def_property("persistable", &imperative::VarBase::Persistable, .def_property("persistable", &imperative::VarBase::Persistable,
&imperative::VarBase::SetPersistable) &imperative::VarBase::SetPersistable)
.def_property_readonly( .def_property_readonly("shape",
"shape", [](imperative::VarBase &self) {
[](imperative::VarBase &self) { if (self.Var().IsType<framework::LoDTensor>()) {
if (self.Var().IsType<framework::LoDTensor>()) { return framework::vectorize<int>(
return framework::vectorize<int>( self.Var()
self.Var().Get<framework::LoDTensor>().dims()); .Get<framework::LoDTensor>()
} else if (self.Var().IsType<framework::SelectedRows>()) { .dims());
return framework::vectorize<int>( } else if (self.Var()
self.Var().Get<framework::SelectedRows>().value().dims()); .IsType<
} else if (self.Var().IsType<framework::Strings>()) { framework::SelectedRows>()) {
return std::vector<int>{static_cast<int>( return framework::vectorize<int>(
self.Var().Get<framework::Strings>().size())}; self.Var()
} else if (self.Var().IsType<framework::Vocab>()) { .Get<framework::SelectedRows>()
return std::vector<int>{ .value()
static_cast<int>(self.Var().Get<framework::Vocab>().size())}; .dims());
} else { } else if (self.Var()
VLOG(2) << "It is meaningless to get shape of " .IsType<framework::Strings>()) {
"variable type " return std::vector<int>{static_cast<int>(
<< GetTypeName(self); self.Var()
return std::vector<int>(); .Get<framework::Strings>()
} .size())};
}) } else if (self.Var()
.IsType<framework::Vocab>()) {
return std::vector<int>{static_cast<int>(
self.Var()
.Get<framework::Vocab>()
.size())};
} else {
VLOG(2) << "It is meaningless to get shape of "
"variable type "
<< GetTypeName(self);
return std::vector<int>();
}
})
.def_property_readonly("is_leaf", &imperative::VarBase::IsLeaf, .def_property_readonly("is_leaf", &imperative::VarBase::IsLeaf,
R"DOC( R"DOC(
Whether a Tensor is leaf Tensor. Whether a Tensor is leaf Tensor.
......
...@@ -179,6 +179,8 @@ void BindVarDsec(pybind11::module *m) { ...@@ -179,6 +179,8 @@ void BindVarDsec(pybind11::module *m) {
pybind11::return_value_policy::reference) pybind11::return_value_policy::reference)
.def("dtype", &pd::VarDesc::GetDataType, .def("dtype", &pd::VarDesc::GetDataType,
pybind11::return_value_policy::reference) pybind11::return_value_policy::reference)
.def("element_size", &pd::VarDesc::ElementSize,
pybind11::return_value_policy::reference)
.def("dtypes", &pd::VarDesc::GetDataTypes, .def("dtypes", &pd::VarDesc::GetDataTypes,
pybind11::return_value_policy::reference) pybind11::return_value_policy::reference)
.def("lod_level", &pd::VarDesc::GetLoDLevel) .def("lod_level", &pd::VarDesc::GetLoDLevel)
......
...@@ -91,6 +91,7 @@ from .tensor.creation import empty # noqa: F401 ...@@ -91,6 +91,7 @@ from .tensor.creation import empty # noqa: F401
from .tensor.creation import empty_like # noqa: F401 from .tensor.creation import empty_like # noqa: F401
from .tensor.creation import assign # noqa: F401 from .tensor.creation import assign # noqa: F401
from .tensor.creation import complex # noqa: F401 from .tensor.creation import complex # noqa: F401
from .tensor.creation import clone # noqa: F401
from .tensor.linalg import matmul # noqa: F401 from .tensor.linalg import matmul # noqa: F401
from .tensor.linalg import dot # noqa: F401 from .tensor.linalg import dot # noqa: F401
from .tensor.linalg import norm # noqa: F401 from .tensor.linalg import norm # noqa: F401
...@@ -587,4 +588,5 @@ __all__ = [ # noqa ...@@ -587,4 +588,5 @@ __all__ = [ # noqa
'fmin', 'fmin',
'moveaxis', 'moveaxis',
'repeat_interleave', 'repeat_interleave',
'clone',
] ]
...@@ -1396,6 +1396,33 @@ class Variable(object): ...@@ -1396,6 +1396,33 @@ class Variable(object):
__repr__ = __str__ __repr__ = __str__
def element_size(self):
"""
Returns the size in bytes of an element in the Tensor.
Examples:
.. code-block:: python
import paddle
paddle.enable_static()
x = paddle.static.data(name='x1', shape=[3, 2], dtype='bool')
x.element_size() # 1
x = paddle.static.data(name='x2', shape=[3, 2], dtype='int16')
x.element_size() # 2
x = paddle.static.data(name='x3', shape=[3, 2], dtype='float16')
x.element_size() # 2
x = paddle.static.data(name='x4', shape=[3, 2], dtype='float32')
x.element_size() # 4
x = paddle.static.data(name='x5', shape=[3, 2], dtype='float64')
x.element_size() # 8
"""
return self.desc.element_size()
@property @property
def stop_gradient(self): def stop_gradient(self):
""" """
......
...@@ -169,6 +169,31 @@ class TestAssignOApi(unittest.TestCase): ...@@ -169,6 +169,31 @@ class TestAssignOApi(unittest.TestCase):
self.assertTrue(np.allclose(result3.numpy(), np.array([1]))) self.assertTrue(np.allclose(result3.numpy(), np.array([1])))
paddle.enable_static() paddle.enable_static()
def test_clone(self):
paddle.disable_static()
x = paddle.ones([2])
x.stop_gradient = False
clone_x = paddle.clone(x)
y = clone_x**3
y.backward()
self.assertTrue(np.array_equal(x, [1, 1]), True)
self.assertTrue(np.array_equal(clone_x.grad.numpy(), [3, 3]), True)
self.assertTrue(np.array_equal(x.grad.numpy(), [3, 3]), True)
paddle.enable_static()
with program_guard(Program(), Program()):
x_np = np.random.randn(2, 3).astype('float32')
x = paddle.static.data("X", shape=[2, 3])
clone_x = paddle.clone(x)
exe = paddle.static.Executor()
y_np = exe.run(paddle.static.default_main_program(),
feed={'X': x_np},
fetch_list=[clone_x])[0]
self.assertTrue(np.array_equal(y_np, x_np), True)
class TestAssignOpErrorApi(unittest.TestCase): class TestAssignOpErrorApi(unittest.TestCase):
def test_errors(self): def test_errors(self):
......
...@@ -18,18 +18,19 @@ import unittest ...@@ -18,18 +18,19 @@ import unittest
import copy import copy
import paddle import paddle
from paddle.fluid.dygraph import guard from paddle.fluid.dygraph import guard
from paddle.fluid.framework import default_main_program from paddle.fluid.framework import default_main_program, Variable
import paddle.fluid.core as core import paddle.fluid.core as core
from paddle.fluid.executor import Executor from paddle.fluid.executor import Executor
import paddle.fluid.io as io import paddle.fluid.io as io
from paddle.fluid.initializer import ConstantInitializer from paddle.fluid.initializer import ConstantInitializer
import numpy as np import numpy as np
paddle.enable_static()
main_program = default_main_program() main_program = default_main_program()
class ParameterChecks(unittest.TestCase): class ParameterChecks(unittest.TestCase):
def check_parameter(self): def test_parameter(self):
shape = [784, 100] shape = [784, 100]
val = 1.0625 val = 1.0625
b = main_program.global_block() b = main_program.global_block()
...@@ -43,13 +44,13 @@ class ParameterChecks(unittest.TestCase): ...@@ -43,13 +44,13 @@ class ParameterChecks(unittest.TestCase):
self.assertEqual((784, 100), param.shape) self.assertEqual((784, 100), param.shape)
self.assertEqual(core.VarDesc.VarType.FP32, param.dtype) self.assertEqual(core.VarDesc.VarType.FP32, param.dtype)
self.assertEqual(0, param.block.idx) self.assertEqual(0, param.block.idx)
exe = Executor(core.CPUPlace()) exe = Executor(paddle.CPUPlace())
p = exe.run(main_program, fetch_list=[param])[0] p = exe.run(main_program, fetch_list=[param])[0]
self.assertTrue(np.allclose(p, np.ones(shape) * val)) self.assertTrue(np.array_equal(p, np.ones(shape) * val))
p = io.get_parameter_value_by_name('fc.w', exe, main_program) p = io.get_parameter_value_by_name('fc.w', exe, main_program)
self.assertTrue(np.allclose(np.array(p), np.ones(shape) * val)) self.assertTrue(np.array_equal(p, np.ones(shape) * val))
def check_parambase(self): def test_parambase(self):
with guard(): with guard():
linear = paddle.nn.Linear(10, 10) linear = paddle.nn.Linear(10, 10)
param = linear.weight param = linear.weight
...@@ -71,7 +72,7 @@ class ParameterChecks(unittest.TestCase): ...@@ -71,7 +72,7 @@ class ParameterChecks(unittest.TestCase):
pram_copy2 = copy.deepcopy(param, memo) pram_copy2 = copy.deepcopy(param, memo)
self.assertEqual(id(param_copy), id(pram_copy2)) self.assertEqual(id(param_copy), id(pram_copy2))
def check_exceptions(self): def test_exception(self):
b = main_program.global_block() b = main_program.global_block()
with self.assertRaises(ValueError): with self.assertRaises(ValueError):
b.create_parameter( b.create_parameter(
...@@ -86,16 +87,30 @@ class ParameterChecks(unittest.TestCase): ...@@ -86,16 +87,30 @@ class ParameterChecks(unittest.TestCase):
b.create_parameter( b.create_parameter(
name='test', shape=[-1], dtype='float32', initializer=None) name='test', shape=[-1], dtype='float32', initializer=None)
def test_parambase_to_vector(self):
with guard():
initializer = paddle.ParamAttr(
initializer=paddle.nn.initializer.Constant(3.))
linear1 = paddle.nn.Linear(10, 15, initializer)
class TestParameter(ParameterChecks): vec = paddle.nn.utils.parameters_to_vector(linear1.parameters())
def _test_parameter(self): self.assertEqual(linear1.weight.shape, [10, 15])
self.check_parameter() self.assertEqual(linear1.bias.shape, [15])
self.assertTrue(isinstance(vec, Variable))
def test_parambase(self): self.assertTrue(vec.shape, [165])
self.check_parambase()
def test_exceptions(self): linear2 = paddle.nn.Linear(10, 15)
self.check_exceptions() paddle.nn.utils.vector_to_parameters(vec, linear2.parameters())
self.assertEqual(linear2.weight.shape, [10, 15])
self.assertEqual(linear2.bias.shape, [15])
self.assertTrue(
np.array_equal(linear1.weight.numpy(), linear2.weight.numpy()),
True)
self.assertTrue(
np.array_equal(linear1.bias.numpy(), linear2.bias.numpy()),
True)
self.assertTrue(linear2.weight.is_leaf, True)
self.assertTrue(linear2.bias.is_leaf, True)
if __name__ == '__main__': if __name__ == '__main__':
......
...@@ -497,6 +497,41 @@ class TestVarBase(unittest.TestCase): ...@@ -497,6 +497,41 @@ class TestVarBase(unittest.TestCase):
var = fluid.dygraph.to_variable(self.array) var = fluid.dygraph.to_variable(self.array)
self.assertTrue(isinstance(str(var), str)) self.assertTrue(isinstance(str(var), str))
def test_element_size(self):
with fluid.dygraph.guard():
x = paddle.to_tensor(1, dtype='bool')
self.assertEqual(x.element_size(), 1)
x = paddle.to_tensor(1, dtype='float16')
self.assertEqual(x.element_size(), 2)
x = paddle.to_tensor(1, dtype='float32')
self.assertEqual(x.element_size(), 4)
x = paddle.to_tensor(1, dtype='float64')
self.assertEqual(x.element_size(), 8)
x = paddle.to_tensor(1, dtype='int8')
self.assertEqual(x.element_size(), 1)
x = paddle.to_tensor(1, dtype='int16')
self.assertEqual(x.element_size(), 2)
x = paddle.to_tensor(1, dtype='int32')
self.assertEqual(x.element_size(), 4)
x = paddle.to_tensor(1, dtype='int64')
self.assertEqual(x.element_size(), 8)
x = paddle.to_tensor(1, dtype='uint8')
self.assertEqual(x.element_size(), 1)
x = paddle.to_tensor(1, dtype='complex64')
self.assertEqual(x.element_size(), 8)
x = paddle.to_tensor(1, dtype='complex128')
self.assertEqual(x.element_size(), 16)
def test_backward(self): def test_backward(self):
with fluid.dygraph.guard(): with fluid.dygraph.guard():
var = fluid.dygraph.to_variable(self.array) var = fluid.dygraph.to_variable(self.array)
......
...@@ -63,6 +63,35 @@ class TestVariable(unittest.TestCase): ...@@ -63,6 +63,35 @@ class TestVariable(unittest.TestCase):
self.assertRaises(ValueError, self.assertRaises(ValueError,
lambda: b.create_var(name="fc.w", shape=(24, 100))) lambda: b.create_var(name="fc.w", shape=(24, 100)))
def test_element_size(self):
with fluid.program_guard(Program(), Program()):
x = paddle.static.data(name='x1', shape=[2], dtype='bool')
self.assertEqual(x.element_size(), 1)
x = paddle.static.data(name='x2', shape=[2], dtype='float16')
self.assertEqual(x.element_size(), 2)
x = paddle.static.data(name='x3', shape=[2], dtype='float32')
self.assertEqual(x.element_size(), 4)
x = paddle.static.data(name='x4', shape=[2], dtype='float64')
self.assertEqual(x.element_size(), 8)
x = paddle.static.data(name='x5', shape=[2], dtype='int8')
self.assertEqual(x.element_size(), 1)
x = paddle.static.data(name='x6', shape=[2], dtype='int16')
self.assertEqual(x.element_size(), 2)
x = paddle.static.data(name='x7', shape=[2], dtype='int32')
self.assertEqual(x.element_size(), 4)
x = paddle.static.data(name='x8', shape=[2], dtype='int64')
self.assertEqual(x.element_size(), 8)
x = paddle.static.data(name='x9', shape=[2], dtype='uint8')
self.assertEqual(x.element_size(), 1)
def test_step_scopes(self): def test_step_scopes(self):
prog = Program() prog = Program()
b = prog.current_block() b = prog.current_block()
......
...@@ -14,7 +14,8 @@ ...@@ -14,7 +14,8 @@
from .spectral_norm_hook import spectral_norm from .spectral_norm_hook import spectral_norm
from .weight_norm_hook import weight_norm, remove_weight_norm # noqa: F401 from .weight_norm_hook import weight_norm, remove_weight_norm # noqa: F401
from .transform_parameters import parameters_to_vector, vector_to_parameters # noqa: F401
__all__ = [ #noqa __all__ = [ #noqa
'weight_norm', 'remove_weight_norm', 'spectral_norm' 'weight_norm', 'remove_weight_norm', 'spectral_norm', 'parameters_to_vector', 'vector_to_parameters'
] ]
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from functools import reduce
import paddle
from paddle.fluid.framework import dygraph_only, _dygraph_tracer, _varbase_creator
from paddle import _C_ops
#input==output, inplace strategy of reshape has no cost almostly
def _inplace_reshape_dygraph(x, shape):
x_shape = _varbase_creator(dtype=x.dtype)
_dygraph_tracer().trace_op(
type="reshape2",
inputs={'X': x},
outputs={'Out': x,
'XShape': x_shape},
attrs={'shape': shape},
stop_gradient=True)
@dygraph_only
def parameters_to_vector(parameters, name=None):
"""
Flatten parameters to a 1-D Tensor.
Args:
parameters(Iterable[Tensor]): Iterable Tensors that are trainable parameters of a Layer.
name(str, optional): The default value is None. Normally there is no need for user to set this
property. For more information, please refer to :ref:`api_guide_Name`.
Returns:
A 1-D Tensor, which represents the parameters of a Layer.
Examples:
.. code-block:: python
import paddle
linear = paddle.nn.Linear(10, 15)
paddle.nn.utils.parameters_to_vector(linear.parameters())
# 1-D Tensor: [165]
"""
dtype = parameters[0].dtype
origin_shapes = []
for param in parameters:
origin_shapes.append(param.shape)
_inplace_reshape_dygraph(param, [-1])
out = _varbase_creator(dtype=dtype)
_dygraph_tracer().trace_op(
type='concat',
inputs={'X': parameters},
outputs={'Out': [out]},
attrs={'axis': 0},
stop_gradient=True)
for i, param in enumerate(parameters):
_inplace_reshape_dygraph(param, origin_shapes[i])
return out
@dygraph_only
def vector_to_parameters(vec, parameters, name=None):
"""
Transform a Tensor with 1-D shape to the parameters.
Args:
vec (Tensor): A Tensor with 1-D shape, which represents the parameters of a Layer.
parameters (Iterable[Tensor]): Iterable Tensors that are trainable parameters of a Layer.
name(str, optional): The default value is None. Normally there is no need for user to set this
property. For more information, please refer to :ref:`api_guide_Name`.
Examples:
.. code-block:: python
import paddle
weight_attr = paddle.ParamAttr(initializer=paddle.nn.initializer.Constant(3.))
linear1 = paddle.nn.Linear(10, 15, weight_attr)
vec = paddle.nn.utils.parameters_to_vector(linear1.parameters())
linear2 = paddle.nn.Linear(10, 15)
# copy weight of linear1 to linear2
paddle.nn.utils.vector_to_parameters(vec, linear2.parameters())
# weight: Tensor(shape=[10, 15], dtype=float32, place=CUDAPlace(0), stop_gradient=False,
# [[3. , ..., 3. ],
# [..., ..., ...],
# [3. , ..., 3. ]])
"""
origin_shapes = []
sections = []
for param in parameters:
shape = param.shape
origin_shapes.append(shape)
numel = reduce(lambda x, y: x * y, shape)
sections.append(numel)
_dygraph_tracer().trace_op(
type='split',
inputs={'X': [vec]},
outputs={'Out': parameters},
attrs={'axis': 0,
'sections': sections},
stop_gradient=True)
for i, param in enumerate(parameters):
_inplace_reshape_dygraph(param, origin_shapes[i])
return
...@@ -1158,8 +1158,7 @@ def empty_like(x, dtype=None, name=None): ...@@ -1158,8 +1158,7 @@ def empty_like(x, dtype=None, name=None):
def assign(x, output=None): def assign(x, output=None):
""" """
The OP copies the :attr:`x` to the :attr:`output`. The OP copies the :attr:`x` to the :attr:`output`.
Parameters: Parameters:
...@@ -1192,6 +1191,36 @@ def assign(x, output=None): ...@@ -1192,6 +1191,36 @@ def assign(x, output=None):
return tensor.assign(x, output) return tensor.assign(x, output)
def clone(x, name=None):
"""
Returns a copy of input Tensor. It will always have a Tensor copy.
In addition, This function is derivable, so gradients will flow back from the output to input.
Parameters:
x (Tensor): The input Tensor.
name(str, optional): The default value is None. Normally there is no need for user to set this
property. For more information, please refer to :ref:`api_guide_Name`.
Returns: A Tensor copied from ``input`` .
Examples:
.. code-block:: python
import paddle
x = paddle.ones([2])
x.stop_gradient = False
clone_x = paddle.clone(x)
y = clone_x**3
y.backward()
print(clone_x.grad) # [3]
print(x.grad) # [3]
"""
return x.clone()
#NOTE(zhiqiu): not public #NOTE(zhiqiu): not public
def _memcpy(input, place=None, output=None): def _memcpy(input, place=None, output=None):
""" """
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册