From b6bf650aedd53ff2e52728e2e60306688eb597a7 Mon Sep 17 00:00:00 2001 From: fwenguang <95677191+fwenguang@users.noreply.github.com> Date: Fri, 31 Dec 2021 15:12:09 +0800 Subject: [PATCH] [MLU]support calling mlu op from python interface (#38292) * [MLU]support calling mlu op from python interface * [MLU]fix * fix * [mlu]fix mlu_places * [mlu]fix required mlu * fix * [MLU]fix tensor copy * [mlu] fix MLUPlace call path --- paddle/fluid/framework/tensor_util.cc | 59 ++++++++ paddle/fluid/imperative/prepared_operator.cc | 10 ++ .../memory/allocation/allocator_facade.cc | 2 +- paddle/fluid/pybind/pybind.cc | 143 +++++++++++++++++- paddle/fluid/pybind/tensor_py.h | 38 ++++- python/CMakeLists.txt | 2 + python/paddle/__init__.py | 2 + python/paddle/device/__init__.py | 71 ++++++++- python/paddle/fluid/__init__.py | 3 +- python/paddle/fluid/framework.py | 85 ++++++++++- python/paddle/fluid/io.py | 8 + python/paddle/framework/__init__.py | 1 + python/paddle/static/__init__.py | 2 + 13 files changed, 409 insertions(+), 17 deletions(-) diff --git a/paddle/fluid/framework/tensor_util.cc b/paddle/fluid/framework/tensor_util.cc index d655e3e8e53..f0e5a447fd2 100644 --- a/paddle/fluid/framework/tensor_util.cc +++ b/paddle/fluid/framework/tensor_util.cc @@ -357,6 +357,36 @@ void TensorCopy(const Tensor& src, const platform::Place& dst_place, "Copying from %s to %s is not supported.", src_place, dst_place)); } #endif +#ifdef PADDLE_WITH_MLU + else if (platform::is_mlu_place(src_place) && // NOLINT + platform::is_cpu_place(dst_place)) { + auto src_mlu_place = BOOST_GET_CONST(platform::MLUPlace, src_place); + auto dst_cpu_place = BOOST_GET_CONST(platform::CPUPlace, dst_place); + auto stream = + reinterpret_cast(ctx).stream(); + memory::Copy(dst_cpu_place, dst_ptr, src_mlu_place, src_ptr, size, stream); + } + else if (platform::is_cpu_place(src_place) && // NOLINT + platform::is_mlu_place(dst_place)) { + auto src_cpu_place = BOOST_GET_CONST(platform::CPUPlace, src_place); + auto dst_mlu_place = BOOST_GET_CONST(platform::MLUPlace, dst_place); + auto stream = + reinterpret_cast(ctx).stream(); + memory::Copy(dst_mlu_place, dst_ptr, src_cpu_place, src_ptr, size, stream); + } + else if (platform::is_mlu_place(src_place) && // NOLINT + platform::is_mlu_place(dst_place)) { + auto src_mlu_place = BOOST_GET_CONST(platform::MLUPlace, src_place); + auto dst_mlu_place = BOOST_GET_CONST(platform::MLUPlace, dst_place); + auto stream = + reinterpret_cast(ctx).stream(); + memory::Copy(dst_mlu_place, dst_ptr, src_mlu_place, src_ptr, size, stream); + } + else { // NOLINT + PADDLE_THROW(platform::errors::Unimplemented( + "Copying from %s to %s is not supported.", src_place, dst_place)); + } +#endif } void TensorCopy(const Tensor& src, const platform::Place& dst_place, @@ -526,6 +556,35 @@ void TensorCopySync(const Tensor& src, const platform::Place& dst_place, "Copy from %s to %s is not supported.", src_place, dst_place)); } #endif +#ifdef PADDLE_WITH_MLU + else if (platform::is_mlu_place(src_place) && // NOLINT + platform::is_cpu_place(dst_place)) { + memory::Copy(BOOST_GET_CONST(platform::CPUPlace, dst_place), dst_ptr, + BOOST_GET_CONST(platform::MLUPlace, src_place), src_ptr, size, + nullptr); + } + else if (platform::is_cpu_place(src_place) && // NOLINT + platform::is_mlu_place(dst_place)) { + memory::Copy(BOOST_GET_CONST(platform::MLUPlace, dst_place), dst_ptr, + BOOST_GET_CONST(platform::CPUPlace, src_place), src_ptr, size, + nullptr); + } + else if (platform::is_mlu_place(src_place) && // NOLINT + platform::is_mlu_place(dst_place)) { + if (src_ptr == dst_ptr) { + VLOG(3) << "Skip copy the same data async from " << src_place << " to " + << dst_place; + return; + } + memory::Copy(BOOST_GET_CONST(platform::MLUPlace, dst_place), dst_ptr, + BOOST_GET_CONST(platform::MLUPlace, src_place), src_ptr, size, + nullptr); + } + else { // NOLINT + PADDLE_THROW(platform::errors::Unimplemented( + "Copy from %s to %s is not supported.", src_place, dst_place)); + } +#endif } template diff --git a/paddle/fluid/imperative/prepared_operator.cc b/paddle/fluid/imperative/prepared_operator.cc index 1ed1716f12e..c5623a8f4f2 100644 --- a/paddle/fluid/imperative/prepared_operator.cc +++ b/paddle/fluid/imperative/prepared_operator.cc @@ -217,6 +217,16 @@ PreparedOp PrepareImpl(const NameVarMap& ins, expected_kernel_key.place_ = platform::CPUPlace(); kernel_iter = kernels.find(expected_kernel_key); } +#endif +#ifdef PADDLE_WITH_MLU + if (kernel_iter == kernels.end() && + is_mlu_place(expected_kernel_key.place_)) { + VLOG(3) << "missing MLU kernel: " << op.Type() + << ", expected_kernel_key:" << expected_kernel_key + << ", fallbacking to CPU one!"; + expected_kernel_key.place_ = platform::CPUPlace(); + kernel_iter = kernels.find(expected_kernel_key); + } #endif // TODO(jiabin): Add operator.cc's line 1000 part back when we need that // case diff --git a/paddle/fluid/memory/allocation/allocator_facade.cc b/paddle/fluid/memory/allocation/allocator_facade.cc index 473a2d28877..9bc2f5461f3 100644 --- a/paddle/fluid/memory/allocation/allocator_facade.cc +++ b/paddle/fluid/memory/allocation/allocator_facade.cc @@ -690,7 +690,7 @@ class AllocatorFacadePrivate { #ifdef PADDLE_WITH_MLU int device_count = platform::GetMLUDeviceCount(); for (int i = 0; i < device_count; ++i) { - platform::XPUPlace p(i); + platform::MLUPlace p(i); system_allocators_[p] = std::make_shared(p); } #endif diff --git a/paddle/fluid/pybind/pybind.cc b/paddle/fluid/pybind/pybind.cc index 46a679b0c97..b5845a1ef96 100644 --- a/paddle/fluid/pybind/pybind.cc +++ b/paddle/fluid/pybind/pybind.cc @@ -137,6 +137,10 @@ limitations under the License. */ #include "paddle/fluid/platform/ipu_info.h" #endif +#ifdef PADDLE_WITH_MLU +#include "paddle/fluid/platform/device/mlu/mlu_info.h" +#endif + #ifdef PADDLE_WITH_CRYPTO #include "paddle/fluid/pybind/crypto.h" #endif @@ -164,6 +168,7 @@ PyTypeObject *g_cpuplace_pytype = nullptr; PyTypeObject *g_xpuplace_pytype = nullptr; PyTypeObject *g_npuplace_pytype = nullptr; PyTypeObject *g_cudapinnedplace_pytype = nullptr; +PyTypeObject *g_mluplace_pytype = nullptr; PyTypeObject *g_framework_tensor_pytype = nullptr; bool IsCompiledWithCUDA() { @@ -230,6 +235,14 @@ bool IsCompiledWithCINN() { #endif } +bool IsCompiledWithMLU() { +#ifndef PADDLE_WITH_MLU + return false; +#else + return true; +#endif +} + bool IsCompiledWithHETERPS() { #ifndef PADDLE_WITH_HETERPS return false; @@ -295,10 +308,9 @@ OpSupportedInfos(const std::string &place, [](unsigned char c) { return std::toupper(c); }); using fn_type = std::add_pointer::type; std::unordered_map is_target_place{ - {"GPU", &platform::is_gpu_place}, - {"CPU", &platform::is_cpu_place}, - {"XPU", &platform::is_xpu_place}, - {"NPU", &platform::is_npu_place}, + {"GPU", &platform::is_gpu_place}, {"CPU", &platform::is_cpu_place}, + {"XPU", &platform::is_xpu_place}, {"NPU", &platform::is_npu_place}, + {"MLU", &platform::is_mlu_place}, }; PADDLE_ENFORCE_NE( is_target_place.count(query_place), 0, @@ -769,6 +781,10 @@ PYBIND11_MODULE(core_noavx, m) { [](framework::Tensor &self, paddle::platform::NPUPlace &place) { self.mutable_data(place); }) + .def("_alloc_float", + [](framework::Tensor &self, paddle::platform::MLUPlace &place) { + self.mutable_data(place); + }) .def("_alloc_double", [](framework::Tensor &self, paddle::platform::CPUPlace &place) { self.mutable_data(place); @@ -785,6 +801,10 @@ PYBIND11_MODULE(core_noavx, m) { [](framework::Tensor &self, paddle::platform::CUDAPlace &place) { self.mutable_data(place); }) + .def("_alloc_int", + [](framework::Tensor &self, paddle::platform::MLUPlace &place) { + self.mutable_data(place); + }) .def("_alloc_int", [](framework::Tensor &self, paddle::platform::CUDAPinnedPlace &place) { @@ -815,6 +835,11 @@ PYBIND11_MODULE(core_noavx, m) { paddle::framework::proto::VarType::Type type) { return reinterpret_cast(self.mutable_data(place, type)); }) + .def("_mutable_data", + [](framework::Tensor &self, paddle::platform::MLUPlace &place, + paddle::framework::proto::VarType::Type type) { + return reinterpret_cast(self.mutable_data(place, type)); + }) .def("_clear", &framework::Tensor::clear) .def("_mutable_data", [](framework::Tensor &self, paddle::platform::NPUPlace &place, @@ -831,6 +856,8 @@ PYBIND11_MODULE(core_noavx, m) { py::arg("tensor"), py::arg("place"), py::arg("batch_size") = -1) .def("_copy_from", &TensorCopyFrom, py::arg("tensor"), py::arg("place"), py::arg("batch_size") = -1) + .def("_copy_from", &TensorCopyFrom, + py::arg("tensor"), py::arg("place"), py::arg("batch_size") = -1) .def("_copy_from", &TensorCopyFrom, py::arg("tensor"), py::arg("place"), py::arg("batch_size") = -1) .def("set", SetTensorFromPyArray, @@ -843,6 +870,8 @@ PYBIND11_MODULE(core_noavx, m) { py::arg("array"), py::arg("place"), py::arg("zero_copy") = false) .def("set", SetTensorFromPyArray, py::arg("array"), py::arg("place"), py::arg("zero_copy") = false) + .def("set", SetTensorFromPyArray, + py::arg("array"), py::arg("place"), py::arg("zero_copy") = false) .def("set", SetTensorFromPyArray, py::arg("array"), py::arg("place"), py::arg("zero_copy") = false, R"DOC( @@ -850,7 +879,7 @@ PYBIND11_MODULE(core_noavx, m) { Args: lod (numpy.ndarray): The data to set. - place (CPUPlace|CUDAPlace|XPUPlace|IPUPlace|CUDAPinnedPlace|NPUPlace): The place where the + place (CPUPlace|CUDAPlace|XPUPlace|IPUPlace|CUDAPinnedPlace|NPUPlace|MLUPlace): The place where the LoDTensor is to be set. zero_copy (bool, optional): Whether to share memory with the input numpy array. This parameter only works with CPUPlace. Default: False. @@ -1619,6 +1648,18 @@ All parameter, weight, gradient are variables in Paddle. "Please recompile or reinstall Paddle with XPU support.")); #else return new paddle::platform::XPUDeviceContext(place); +#endif + }) + .def_static("create", + [](paddle::platform::MLUPlace& place) + -> paddle::platform::DeviceContext* { +#ifndef PADDLE_WITH_MLU + PADDLE_THROW( + platform::errors::PermissionDenied( + "Cannot use MLUPlace in CPU/GPU version, " + "Please recompile or reinstall Paddle with MLU support.")); +#else + return new paddle::platform::MLUDeviceContext(place); #endif }) .def_static("create", @@ -1736,6 +1777,7 @@ All parameter, weight, gradient are variables in Paddle. .def("_equals", &IsSamePlace) .def("_equals", &IsSamePlace) .def("_equals", &IsSamePlace) + .def("_equals", &IsSamePlace) .def("_equals", &IsSamePlace) .def("_get_device_id", @@ -2004,6 +2046,75 @@ All parameter, weight, gradient are variables in Paddle. #endif .def("__str__", string::to_string); + // MLUPlace + py::class_ mluplace(m, "MLUPlace", R"DOC( + MLUPlace is a descriptor of a device. + It represents a MLU device on which a tensor will be allocated and a model will run. + + Examples: + .. code-block:: python + import paddle + # required: mlu + mlu_place = paddle.MLUPlace(0) + + )DOC"); + g_mluplace_pytype = reinterpret_cast(mluplace.ptr()); + mluplace + .def("__init__", + [](platform::MLUPlace &self, int dev_id) { +#ifdef PADDLE_WITH_MLU + if (UNLIKELY(dev_id < 0)) { + LOG(ERROR) << string::Sprintf( + "Invalid MLUPlace(%d), device id must be 0 or " + "positive integer", + dev_id); + std::exit(-1); + } + if (UNLIKELY(dev_id >= platform::GetMLUDeviceCount())) { + if (platform::GetMLUDeviceCount() == 0) { + LOG(ERROR) << "Cannot use MLU because there is no MLU " + "detected on your " + "machine."; + std::exit(-1); + } else { + LOG(ERROR) << string::Sprintf( + "Invalid MLUPlace(%d), must inside [0, %d), because MLU " + "number on your machine is %d", + dev_id, platform::GetMLUDeviceCount(), + platform::GetMLUDeviceCount()); + std::exit(-1); + } + } + new (&self) platform::MLUPlace(dev_id); +#else + LOG(ERROR) << string::Sprintf( + "Cannot use MLU because you have installed CPU/GPU/... " + "version " + "PaddlePaddle.\n" + "If you want to use MLU, please try to install MLU version " + "PaddlePaddle by: pip install paddlepaddle-mlu\n" + "If you only have CPU, please change MLUPlace(%d) to be " + "CPUPlace().\n", + dev_id); + std::exit(-1); +#endif + }) + .def("_type", &PlaceIndex) +#ifdef PADDLE_WITH_MLU + .def("_equals", &IsSamePlace) + .def("_equals", &IsSamePlace) + .def("_equals", &IsSamePlace) + .def("_equals", &IsSamePlace) + .def("_equals", &IsSamePlace) + .def("_equals", &IsSamePlace) + .def("_equals", &IsSamePlace) + .def("_equals", + &IsSamePlace) + .def("get_device_id", + [](const platform::MLUPlace &self) { return self.GetDeviceId(); }) +#endif + .def("__str__", string::to_string); + py::class_ platformplace(m, "Place"); g_place_pytype = reinterpret_cast(platformplace.ptr()); platformplace.def(py::init<>()) @@ -2015,6 +2126,7 @@ All parameter, weight, gradient are variables in Paddle. .def("_equals", &IsSamePlace) .def("_equals", &IsSamePlace) .def("_equals", &IsSamePlace) + .def("_equals", &IsSamePlace) .def("is_gpu_place", [](platform::Place &self) { return platform::is_gpu_place(self); }) .def("is_cpu_place", @@ -2029,6 +2141,8 @@ All parameter, weight, gradient are variables in Paddle. [](platform::Place &self) { return platform::is_cuda_pinned_place(self); }) + .def("is_mlu_place", + [](platform::Place &self) { return platform::is_mlu_place(self); }) .def("gpu_device_id", [](platform::Place &self) { return BOOST_GET_CONST(platform::CUDAPlace, self).device; @@ -2045,6 +2159,10 @@ All parameter, weight, gradient are variables in Paddle. [](platform::Place &self) { return BOOST_GET_CONST(platform::IPUPlace, self).device; }) + .def("mlu_device_id", + [](platform::Place &self) { + return BOOST_GET_CONST(platform::MLUPlace, self).device; + }) .def("set_place", [](platform::Place &self, const platform::Place &other) { self = other; }) .def("set_place", @@ -2072,6 +2190,10 @@ All parameter, weight, gradient are variables in Paddle. [](platform::Place &self, const platform::IPUPlace &ipu_place) { self = ipu_place; }) + .def("set_place", + [](platform::Place &self, const platform::MLUPlace &mlu_place) { + self = mlu_place; + }) .def("__repr__", string::to_string) .def("__str__", string::to_string); @@ -2120,6 +2242,12 @@ All parameter, weight, gradient are variables in Paddle. pybind11::gil_scoped_release release; self.Run(scope, place); }) + .def("run", + [](OperatorBase &self, const Scope &scope, + const platform::MLUPlace &place) { + pybind11::gil_scoped_release release; + self.Run(scope, place); + }) .def("type", [](const OperatorBase &op) -> std::string { return op.Type(); }) .def("outputs", @@ -2307,6 +2435,7 @@ All parameter, weight, gradient are variables in Paddle. m.def("is_compiled_with_xpu", IsCompiledWithXPU); m.def("is_compiled_with_mkldnn", IsCompiledWithMKLDNN); m.def("is_compiled_with_cinn", IsCompiledWithCINN); + m.def("is_compiled_with_mlu", IsCompiledWithMLU); m.def("_is_compiled_with_heterps", IsCompiledWithHETERPS); m.def("supports_bfloat16", SupportsBfloat16); m.def("supports_bfloat16_fast_performance", SupportsBfloat16FastPerformance); @@ -2627,6 +2756,10 @@ All parameter, weight, gradient are variables in Paddle. m.def("get_ipu_device_count", platform::GetIPUDeviceCount); #endif +#ifdef PADDLE_WITH_MLU + m.def("get_mlu_device_count", platform::GetMLUDeviceCount); +#endif + py::enum_(m, "TracerOption", py::arithmetic()) .value("kDefault", platform::TracerOption::kDefault) .value("kOpDetail", platform::TracerOption::kOpDetail) diff --git a/paddle/fluid/pybind/tensor_py.h b/paddle/fluid/pybind/tensor_py.h index df9ba02eadf..9d3a858d1bd 100644 --- a/paddle/fluid/pybind/tensor_py.h +++ b/paddle/fluid/pybind/tensor_py.h @@ -345,6 +345,18 @@ void SetTensorFromPyArrayT( PADDLE_THROW(platform::errors::PermissionDenied( "Cannot use NPUPlace in CPU/GPU/XPU version. " "Please recompile or reinstall Paddle with NPU support.")); +#endif + } else if (paddle::platform::is_mlu_place(place)) { +#ifdef PADDLE_WITH_MLU + platform::Place tmp_place = place; + platform::MLUDeviceGuard guard( + BOOST_GET_CONST(platform::MLUPlace, tmp_place).device); + auto dst = self->mutable_data(place); + paddle::platform::MLUMemcpyH2DSync(dst, array.data(), array.nbytes()); +#else + PADDLE_THROW(platform::errors::PermissionDenied( + "Cannot use MLUPlace in CPU/GPU version, " + "Please recompile or reinstall Paddle with MLU support.")); #endif } else { #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) @@ -702,6 +714,7 @@ inline py::array TensorToPyArray(const framework::Tensor &tensor, bool is_gpu_tensor = platform::is_gpu_place(tensor.place()); bool is_xpu_tensor = platform::is_xpu_place(tensor.place()); bool is_npu_tensor = platform::is_npu_place(tensor.place()); + bool is_mlu_tensor = platform::is_mlu_place(tensor.place()); const auto &tensor_dims = tensor.dims(); auto tensor_dtype = tensor.type(); size_t sizeof_dtype = framework::SizeOfType(tensor_dtype); @@ -720,7 +733,7 @@ inline py::array TensorToPyArray(const framework::Tensor &tensor, std::string py_dtype_str = details::TensorDTypeToPyDTypeStr(tensor.type()); - if (!is_gpu_tensor && !is_xpu_tensor && !is_npu_tensor) { + if (!is_gpu_tensor && !is_xpu_tensor && !is_npu_tensor && !is_mlu_tensor) { if (!need_deep_copy) { auto base = py::cast(std::move(tensor)); return py::array(py::dtype(py_dtype_str.c_str()), py_dims, py_strides, @@ -816,6 +829,29 @@ inline py::array TensorToPyArray(const framework::Tensor &tensor, PADDLE_THROW(platform::errors::PermissionDenied( "Cannot use NPUPlace in CPU/GPU/XPU version, " "Please recompile or reinstall Paddle with NPU support.")); +#endif + } else if (is_mlu_tensor) { +#ifdef PADDLE_WITH_MLU + py::array py_arr(py::dtype(py_dtype_str.c_str()), py_dims, py_strides); + PADDLE_ENFORCE_EQ(py_arr.writeable(), true, + platform::errors::InvalidArgument( + "PyArray is not writable, in which case memory leak " + "or double free would occur")); + PADDLE_ENFORCE_EQ( + py_arr.owndata(), true, + platform::errors::InvalidArgument( + "PyArray does not own data, in which case memory leak " + "or double free would occur")); + + size_t copy_bytes = sizeof_dtype * numel; + auto p = BOOST_GET_CONST(platform::MLUPlace, tensor.place()); + paddle::memory::Copy(platform::CPUPlace(), py_arr.mutable_data(), p, + tensor_buf_ptr, copy_bytes, nullptr); + return py_arr; +#else + PADDLE_THROW(platform::errors::PermissionDenied( + "Cannot use MLUPlace in CPU/GPU/XPU/NPU version, " + "Please recompile or reinstall Paddle with MLU support.")); #endif } PADDLE_THROW(platform::errors::Unimplemented("Place is not supported")); diff --git a/python/CMakeLists.txt b/python/CMakeLists.txt index b7a601f53fd..0fecd7c8c36 100644 --- a/python/CMakeLists.txt +++ b/python/CMakeLists.txt @@ -6,6 +6,8 @@ set(PY_FILES paddle/__init__.py if(WITH_GPU) SET(PACKAGE_NAME "paddlepaddle-gpu") +elseif(WITH_MLU) + SET(PACKAGE_NAME "paddlepaddle-mlu") elseif(WITH_ROCM) SET(PACKAGE_NAME "paddlepaddle-rocm") elseif(WITH_ASCEND_CL) diff --git a/python/paddle/__init__.py b/python/paddle/__init__.py index 771a9053fc2..e3171c4f3bb 100755 --- a/python/paddle/__init__.py +++ b/python/paddle/__init__.py @@ -291,6 +291,7 @@ from .framework import IPUPlace # noqa: F401 from .framework import CUDAPlace # noqa: F401 from .framework import NPUPlace # noqa: F401 from .framework import CUDAPinnedPlace # noqa: F401 +from .framework import MLUPlace # noqa: F401 from .autograd import grad # noqa: F401 from .autograd import no_grad # noqa: F401 @@ -322,6 +323,7 @@ from .fluid.framework import set_flags # noqa: F401 from .device import is_compiled_with_xpu # noqa: F401 from .device import is_compiled_with_npu # noqa: F401 from .device import is_compiled_with_ipu # noqa: F401 +from .device import is_compiled_with_mlu # noqa: F401 from .device import XPUPlace # noqa: F401 from .fluid.dygraph.base import enable_dygraph as disable_static # noqa: F401 diff --git a/python/paddle/device/__init__.py b/python/paddle/device/__init__.py index 0a11d59d69c..d102473fef7 100644 --- a/python/paddle/device/__init__.py +++ b/python/paddle/device/__init__.py @@ -29,12 +29,14 @@ __all__ = [ # noqa 'get_device', 'XPUPlace', 'IPUPlace', + 'MLUPlace', 'is_compiled_with_xpu', 'is_compiled_with_ipu', 'is_compiled_with_cinn', 'is_compiled_with_cuda', 'is_compiled_with_rocm', - 'is_compiled_with_npu' + 'is_compiled_with_npu', + 'is_compiled_with_mlu' ] _cudnn_version = None @@ -120,6 +122,41 @@ def XPUPlace(dev_id): return core.XPUPlace(dev_id) +def is_compiled_with_mlu(): + """ + Whether paddle was built with WITH_MLU=ON to support Cambricon MLU + + Returns (bool): whether paddle was built with WITH_MLU=ON + + Examples: + .. code-block:: python + + # required: mlu + + import paddle + support_mlu = paddle.device.is_compiled_with_mlu() + """ + return core.is_compiled_with_mlu() + + +def MLUPlace(dev_id): + """ + Return a Cambricon MLU Place + + Parameters: + dev_id(int): MLU device id + + Examples: + .. code-block:: python + + # required: mlu + + import paddle + place = paddle.device.MLUPlace(0) + """ + return core.MLUPlace(dev_id) + + def get_cudnn_version(): """ This funciton return the version of cudnn. the retuen value is int which represents the @@ -181,13 +218,21 @@ def _convert_to_place(device): "The device should not be 'ipu', " \ "since PaddlePaddle is not compiled with IPU") place = core.IPUPlace() + elif lower_device == 'mlu': + if not core.is_compiled_with_mlu(): + raise ValueError("The device should not be 'mlu', " + "since PaddlePaddle is not compiled with MLU") + selected_mlus = os.getenv("FLAGS_selected_mlus", "0").split(",") + device_id = int(selected_mlus[0]) + place = core.MLUPlace(device_id) else: avaliable_gpu_device = re.match(r'gpu:\d+', lower_device) avaliable_xpu_device = re.match(r'xpu:\d+', lower_device) avaliable_npu_device = re.match(r'npu:\d+', lower_device) - if not avaliable_gpu_device and not avaliable_xpu_device and not avaliable_npu_device: + avaliable_mlu_device = re.match(r'mlu:\d+', lower_device) + if not avaliable_gpu_device and not avaliable_xpu_device and not avaliable_npu_device and not avaliable_mlu_device: raise ValueError( - "The device must be a string which is like 'cpu', 'gpu', 'gpu:x', 'xpu', 'xpu:x', 'npu', 'npu:x' or ipu" + "The device must be a string which is like 'cpu', 'gpu', 'gpu:x', 'xpu', 'xpu:x', 'mlu', 'mlu:x', 'npu', 'npu:x' or ipu" ) if avaliable_gpu_device: if not core.is_compiled_with_cuda(): @@ -216,19 +261,28 @@ def _convert_to_place(device): device_id = device_info_list[1] device_id = int(device_id) place = core.NPUPlace(device_id) + if avaliable_mlu_device: + if not core.is_compiled_with_mlu(): + raise ValueError( + "The device should not be {}, since PaddlePaddle is " + "not compiled with mlu".format(avaliable_mlu_device)) + device_info_list = device.split(':', 1) + device_id = device_info_list[1] + device_id = int(device_id) + place = core.MLUPlace(device_id) return place def set_device(device): """ - Paddle supports running calculations on various types of devices, including CPU, GPU, XPU, NPU and IPU. + Paddle supports running calculations on various types of devices, including CPU, GPU, XPU, NPU, MLU and IPU. They are represented by string identifiers. This function can specify the global device which the OP will run. Parameters: device(str): This parameter determines the specific running device. - It can be ``cpu``, ``gpu``, ``xpu``, ``npu``, ``gpu:x``, ``xpu:x``, ``npu:x`` and ``ipu``, - where ``x`` is the index of the GPUs, XPUs or NPUs. + It can be ``cpu``, ``gpu``, ``xpu``, ``npu``, ``mlu``, ``gpu:x``, ``xpu:x``, ``npu:x``, ``mlu:x`` and ``ipu``, + where ``x`` is the index of the GPUs, XPUs, NPUs or MLUs. Examples: @@ -249,7 +303,7 @@ def set_device(device): def get_device(): """ This funciton can get the current global device of the program is running. - It's a string which is like 'cpu', 'gpu:x', 'xpu:x' and 'npu:x'. if the global device is not + It's a string which is like 'cpu', 'gpu:x', 'xpu:x', 'mlu:x' and 'npu:x'. if the global device is not set, it will return a string which is 'gpu:x' when cuda is avaliable or it will return a string which is 'cpu' when cuda is not avaliable. @@ -277,6 +331,9 @@ def get_device(): elif isinstance(place, core.IPUPlace): num_devices = core.get_ipu_device_count() device = "ipus:{{0-{}}}".format(num_devices - 1) + elif isinstance(place, core.MLUPlace): + device_id = place.get_device_id() + device = 'mlu:' + str(device_id) else: raise ValueError("The device specification {} is invalid".format(place)) diff --git a/python/paddle/fluid/__init__.py b/python/paddle/fluid/__init__.py index d8ee875e768..cd8f9f85458 100644 --- a/python/paddle/fluid/__init__.py +++ b/python/paddle/fluid/__init__.py @@ -71,7 +71,7 @@ from . import distribute_lookup_table from .param_attr import ParamAttr, WeightNormParamAttr from .data_feeder import DataFeeder from .core import LoDTensor, LoDTensorArray, Scope, _Scope -from .core import CPUPlace, XPUPlace, CUDAPlace, CUDAPinnedPlace, NPUPlace, IPUPlace +from .core import CPUPlace, XPUPlace, CUDAPlace, CUDAPinnedPlace, NPUPlace, IPUPlace, MLUPlace from .incubate import fleet from .transpiler import DistributeTranspiler, \ memory_optimize, release_memory, DistributeTranspilerConfig @@ -133,6 +133,7 @@ __all__ = framework.__all__ + executor.__all__ + \ 'CUDAPinnedPlace', 'NPUPlace', 'IPUPlace', + 'MLUPlace', 'Tensor', 'ParamAttr', 'WeightNormParamAttr', diff --git a/python/paddle/fluid/framework.py b/python/paddle/fluid/framework.py index fd2a9387487..73407ef834e 100644 --- a/python/paddle/fluid/framework.py +++ b/python/paddle/fluid/framework.py @@ -51,6 +51,7 @@ __all__ = [ 'cuda_places', 'cpu_places', 'xpu_places', + 'mlu_places', 'cuda_pinned_places', 'in_dygraph_mode', 'is_compiled_with_cinn', @@ -347,6 +348,18 @@ def _current_expected_place(): "You are using XPU version Paddle, but your XPU device is not set properly. CPU device will be used by default." ) _global_expected_place_ = core.CPUPlace() + elif core.is_compiled_with_mlu(): + try: + device_count = core.get_mlu_device_count() + except Exception as e: + device_count = 0 + if device_count > 0: + _global_expected_place_ = core.MLUPlace(0) + else: + warnings.warn( + "You are using MLU version Paddle, but your MLU device is not set properly. CPU device will be used by default." + ) + _global_expected_place_ = core.CPUPlace() else: _global_expected_place_ = core.CPUPlace() @@ -426,6 +439,15 @@ def _npu_ids(): return device_ids +def _mlu_ids(): + mlus_env = os.getenv("FLAGS_selected_mlus") + if mlus_env: + device_ids = [int(s) for s in mlus_env.split(",")] + else: + device_ids = six.moves.range(core.get_mlu_device_count()) + return device_ids + + def is_compiled_with_xpu(): """ Whether this whl package can be used to run the model on XPU. @@ -721,6 +743,48 @@ def cuda_pinned_places(device_count=None): return [core.CUDAPinnedPlace()] * device_count +def mlu_places(device_ids=None): + """ + **Note**: + For multi-card tasks, please use `FLAGS_selected_mlus` environment variable to set the visible MLU device. + This function creates a list of :code:`paddle.device.MLUPlace` objects. + If :code:`device_ids` is None, environment variable of + :code:`FLAGS_selected_mlus` would be checked first. For example, if + :code:`FLAGS_selected_mlus=0,1,2`, the returned list would + be [paddle.device.MLUPlace(0), paddle.device.MLUPlace(1), paddle.device.MLUPlace(2)]. + If :code:`FLAGS_selected_mlus` is not set, all visible + mlu places would be returned. + If :code:`device_ids` is not None, it should be the device + ids of MLUs. For example, if :code:`device_ids=[0,1,2]`, + the returned list would be + [paddle.device.MLUPlace(0), paddle.device.MLUPlace(1), paddle.device.MLUPlace(2)]. + + Parameters: + device_ids (list or tuple of int, optional): list of MLU device ids. + + Returns: + list of paddle.device.MLUPlace: Created MLU place list. + + Examples: + .. code-block:: python + + # required: mlu + + import paddle + import paddle.static as static + + paddle.enable_static() + mlu_places = static.mlu_places() + """ + assert core.is_compiled_with_mlu(), \ + "Not compiled with MLU" + if device_ids is None: + device_ids = _mlu_ids() + elif not isinstance(device_ids, (list, tuple)): + device_ids = [device_ids] + return [core.MLUPlace(dev_id) for dev_id in device_ids] + + class NameScope(object): def __init__(self, name="", parent=None): self._children = dict() @@ -2090,6 +2154,10 @@ class Variable(object): p = core.Place() p.set_place(t._place()) place = core.NPUPlace(p.npu_device_id()) + elif p.is_mlu_place(): + p = core.Place() + p.set_place(t._place()) + place = core.MLUPlace(p.mlu_device_id()) else: p = core.Place() p.set_place(t._place()) @@ -6768,7 +6836,8 @@ def _get_paddle_place(place): if place is None: return place if isinstance(place, (core.Place, core.XPUPlace, core.CPUPlace, - core.CUDAPinnedPlace, core.CUDAPlace, core.NPUPlace)): + core.CUDAPinnedPlace, core.CUDAPlace, core.NPUPlace, + core.MLUPlace)): return place if not isinstance(place, str): @@ -6823,8 +6892,20 @@ def _get_paddle_place(place): device_id = int(device_id) return core.NPUPlace(device_id) + # MLU + avaliable_mlu_place = re.match(r'mlu:\d+', place) + if avaliable_mlu_place: + if not core.is_compiled_with_mlu(): + raise ValueError( + "The device should not be {}, since PaddlePaddle is " \ + "not compiled with MLU".format(avaliable_mlu_place)) + place_info_list = place.split(':', 1) + device_id = place_info_list[1] + device_id = int(device_id) + return core.MLUPlace(device_id) + raise ValueError( - "Paddle supports CPUPlace, CUDAPlace,CUDAPinnedPlace, XPUPlace and NPUPlace, but received {}.". + "Paddle supports CPUPlace, CUDAPlace,CUDAPinnedPlace, XPUPlace, MLUPlace and NPUPlace, but received {}.". format(place)) diff --git a/python/paddle/fluid/io.py b/python/paddle/fluid/io.py index e110c47d790..4bbc0ba03c9 100644 --- a/python/paddle/fluid/io.py +++ b/python/paddle/fluid/io.py @@ -2100,6 +2100,10 @@ def load(program, model_path, executor=None, var_list=None): p = paddle.fluid.core.Place() p.set_place(t._place()) place = paddle.fluid.NPUPlace(p.npu_device_id()) + elif p.is_mlu_place(): + p = paddle.fluid.core.Place() + p.set_place(t._place()) + place = paddle.fluid.MLUPlace(p.mlu_device_id()) else: p = paddle.fluid.core.Place() p.set_place(t._place()) @@ -2394,6 +2398,10 @@ def set_program_state(program, state_dict): p = paddle.fluid.core.Place() p.set_place(ten_place) py_place = paddle.fluid.NPUPlace(p.npu_device_id()) + elif ten_place.is_mlu_place(): + p = paddle.fluid.core.Place() + p.set_place(ten_place) + py_place = paddle.fluid.MLUPlace(p.mlu_device_id()) ten.set(new_para_np, py_place) diff --git a/python/paddle/framework/__init__.py b/python/paddle/framework/__init__.py index a081da60c68..a0503322806 100644 --- a/python/paddle/framework/__init__.py +++ b/python/paddle/framework/__init__.py @@ -28,6 +28,7 @@ from ..fluid.core import IPUPlace # noqa: F401 from ..fluid.core import CUDAPlace # noqa: F401 from ..fluid.core import CUDAPinnedPlace # noqa: F401 from ..fluid.core import NPUPlace # noqa: F401 +from ..fluid.core import MLUPlace # noqa: F401 from ..fluid.core import VarBase # noqa: F401 from paddle.fluid import core # noqa: F401 diff --git a/python/paddle/static/__init__.py b/python/paddle/static/__init__.py index 92aa5000dfa..f18b77997a5 100644 --- a/python/paddle/static/__init__.py +++ b/python/paddle/static/__init__.py @@ -45,6 +45,7 @@ from ..fluid.framework import program_guard # noqa: F401 from ..fluid.framework import cpu_places # noqa: F401 from ..fluid.framework import cuda_places # noqa: F401 from ..fluid.framework import xpu_places # noqa: F401 +from ..fluid.framework import mlu_places # noqa: F401 from ..fluid.framework import npu_places # noqa: F401 from ..fluid.framework import Variable # noqa: F401 from ..fluid.layers.control_flow import Print # noqa: F401 @@ -103,6 +104,7 @@ __all__ = [ #noqa 'cuda_places', 'xpu_places', 'npu_places', + 'mlu_places', 'Variable', 'create_global_var', 'accuracy', -- GitLab