未验证 提交 f00f4fcf 编写于 作者: Z Zeng Jinle 提交者: GitHub

add more copy_from method (#36978)

上级 d572fa27
......@@ -506,6 +506,17 @@ static int GetNCCLVersion() {
}
#endif
template <typename PlaceType>
static void TensorCopyFrom(framework::Tensor *dst, const framework::Tensor &src,
const PlaceType &place, int64_t batch_size) {
if (batch_size < 0) {
framework::TensorCopy(src, place, dst);
} else {
auto sliced = src.Slice(0, batch_size);
framework::TensorCopy(sliced, place, dst);
}
}
#ifdef PADDLE_WITH_AVX
PYBIND11_MODULE(core_avx, m) {
#else
......@@ -755,16 +766,17 @@ PYBIND11_MODULE(core_noavx, m) {
paddle::framework::proto::VarType::Type type) {
return reinterpret_cast<uintptr_t>(self.mutable_data(place, type));
})
.def("_copy_from",
[](framework::Tensor &self, const framework::Tensor &other,
const platform::Place &place, int64_t batch_size) {
if (batch_size < 0) {
framework::TensorCopy(other, place, &self);
} else {
auto sliced = other.Slice(0, batch_size);
framework::TensorCopy(sliced, place, &self);
}
},
.def("_copy_from", &TensorCopyFrom<paddle::platform::CPUPlace>,
py::arg("tensor"), py::arg("place"), py::arg("batch_size") = -1)
.def("_copy_from", &TensorCopyFrom<paddle::platform::XPUPlace>,
py::arg("tensor"), py::arg("place"), py::arg("batch_size") = -1)
.def("_copy_from", &TensorCopyFrom<paddle::platform::CUDAPlace>,
py::arg("tensor"), py::arg("place"), py::arg("batch_size") = -1)
.def("_copy_from", &TensorCopyFrom<paddle::platform::NPUPlace>,
py::arg("tensor"), py::arg("place"), py::arg("batch_size") = -1)
.def("_copy_from", &TensorCopyFrom<paddle::platform::CUDAPinnedPlace>,
py::arg("tensor"), py::arg("place"), py::arg("batch_size") = -1)
.def("_copy_from", &TensorCopyFrom<paddle::platform::Place>,
py::arg("tensor"), py::arg("place"), py::arg("batch_size") = -1)
.def("set", SetTensorFromPyArray<paddle::platform::CPUPlace>,
py::arg("array"), py::arg("place"), py::arg("zero_copy") = false)
......
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
import unittest
import numpy as np
from paddle.fluid.core import LoDTensor as Tensor
class TestTensorCopyFrom(unittest.TestCase):
def test_main(self):
place = paddle.CPUPlace()
np_value = np.random.random(size=[10, 30]).astype('float32')
t_src = Tensor()
t_src.set(np_value, place)
self.assertTrue(np.array_equal(np_value, t_src))
t_dst1 = Tensor()
t_dst1._copy_from(t_src, place)
self.assertTrue(np.array_equal(np_value, t_dst1))
t_dst2 = Tensor()
t_dst2._copy_from(t_src, place, 5)
self.assertTrue(np.array_equal(np.array(np_value[0:5]), t_dst2))
if __name__ == "__main__":
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册