diff --git a/paddle/fluid/framework/tensor_util.cc b/paddle/fluid/framework/tensor_util.cc index d655e3e8e53e5ea3b66de73a8c520d91ac9e455e..f0e5a447fd2dac8bf7a9374e132bcfa3f96e3b78 100644 --- a/paddle/fluid/framework/tensor_util.cc +++ b/paddle/fluid/framework/tensor_util.cc @@ -357,6 +357,36 @@ void TensorCopy(const Tensor& src, const platform::Place& dst_place, "Copying from %s to %s is not supported.", src_place, dst_place)); } #endif +#ifdef PADDLE_WITH_MLU + else if (platform::is_mlu_place(src_place) && // NOLINT + platform::is_cpu_place(dst_place)) { + auto src_mlu_place = BOOST_GET_CONST(platform::MLUPlace, src_place); + auto dst_cpu_place = BOOST_GET_CONST(platform::CPUPlace, dst_place); + auto stream = + reinterpret_cast(ctx).stream(); + memory::Copy(dst_cpu_place, dst_ptr, src_mlu_place, src_ptr, size, stream); + } + else if (platform::is_cpu_place(src_place) && // NOLINT + platform::is_mlu_place(dst_place)) { + auto src_cpu_place = BOOST_GET_CONST(platform::CPUPlace, src_place); + auto dst_mlu_place = BOOST_GET_CONST(platform::MLUPlace, dst_place); + auto stream = + reinterpret_cast(ctx).stream(); + memory::Copy(dst_mlu_place, dst_ptr, src_cpu_place, src_ptr, size, stream); + } + else if (platform::is_mlu_place(src_place) && // NOLINT + platform::is_mlu_place(dst_place)) { + auto src_mlu_place = BOOST_GET_CONST(platform::MLUPlace, src_place); + auto dst_mlu_place = BOOST_GET_CONST(platform::MLUPlace, dst_place); + auto stream = + reinterpret_cast(ctx).stream(); + memory::Copy(dst_mlu_place, dst_ptr, src_mlu_place, src_ptr, size, stream); + } + else { // NOLINT + PADDLE_THROW(platform::errors::Unimplemented( + "Copying from %s to %s is not supported.", src_place, dst_place)); + } +#endif } void TensorCopy(const Tensor& src, const platform::Place& dst_place, @@ -526,6 +556,35 @@ void TensorCopySync(const Tensor& src, const platform::Place& dst_place, "Copy from %s to %s is not supported.", src_place, dst_place)); } #endif +#ifdef PADDLE_WITH_MLU + else if (platform::is_mlu_place(src_place) && // NOLINT + platform::is_cpu_place(dst_place)) { + memory::Copy(BOOST_GET_CONST(platform::CPUPlace, dst_place), dst_ptr, + BOOST_GET_CONST(platform::MLUPlace, src_place), src_ptr, size, + nullptr); + } + else if (platform::is_cpu_place(src_place) && // NOLINT + platform::is_mlu_place(dst_place)) { + memory::Copy(BOOST_GET_CONST(platform::MLUPlace, dst_place), dst_ptr, + BOOST_GET_CONST(platform::CPUPlace, src_place), src_ptr, size, + nullptr); + } + else if (platform::is_mlu_place(src_place) && // NOLINT + platform::is_mlu_place(dst_place)) { + if (src_ptr == dst_ptr) { + VLOG(3) << "Skip copy the same data async from " << src_place << " to " + << dst_place; + return; + } + memory::Copy(BOOST_GET_CONST(platform::MLUPlace, dst_place), dst_ptr, + BOOST_GET_CONST(platform::MLUPlace, src_place), src_ptr, size, + nullptr); + } + else { // NOLINT + PADDLE_THROW(platform::errors::Unimplemented( + "Copy from %s to %s is not supported.", src_place, dst_place)); + } +#endif } template diff --git a/paddle/fluid/imperative/prepared_operator.cc b/paddle/fluid/imperative/prepared_operator.cc index 1ed1716f12ebe91e6643ef67ac8cbd471435d248..c5623a8f4f2438b3794fea02062cb484cd548947 100644 --- a/paddle/fluid/imperative/prepared_operator.cc +++ b/paddle/fluid/imperative/prepared_operator.cc @@ -217,6 +217,16 @@ PreparedOp PrepareImpl(const NameVarMap& ins, expected_kernel_key.place_ = platform::CPUPlace(); kernel_iter = kernels.find(expected_kernel_key); } +#endif +#ifdef PADDLE_WITH_MLU + if (kernel_iter == kernels.end() && + is_mlu_place(expected_kernel_key.place_)) { + VLOG(3) << "missing MLU kernel: " << op.Type() + << ", expected_kernel_key:" << expected_kernel_key + << ", fallbacking to CPU one!"; + expected_kernel_key.place_ = platform::CPUPlace(); + kernel_iter = kernels.find(expected_kernel_key); + } #endif // TODO(jiabin): Add operator.cc's line 1000 part back when we need that // case diff --git a/paddle/fluid/memory/allocation/allocator_facade.cc b/paddle/fluid/memory/allocation/allocator_facade.cc index 473a2d28877a66d2d97b0f205ece35044e6e1b08..9bc2f5461f383fbeba509e6de7e5a81f7f7e2780 100644 --- a/paddle/fluid/memory/allocation/allocator_facade.cc +++ b/paddle/fluid/memory/allocation/allocator_facade.cc @@ -690,7 +690,7 @@ class AllocatorFacadePrivate { #ifdef PADDLE_WITH_MLU int device_count = platform::GetMLUDeviceCount(); for (int i = 0; i < device_count; ++i) { - platform::XPUPlace p(i); + platform::MLUPlace p(i); system_allocators_[p] = std::make_shared(p); } #endif diff --git a/paddle/fluid/pybind/pybind.cc b/paddle/fluid/pybind/pybind.cc index 46a679b0c97a0fd724f3591573bf8d22a9220ea0..b5845a1ef9628ac78016da1d56d5f45922aeb92e 100644 --- a/paddle/fluid/pybind/pybind.cc +++ b/paddle/fluid/pybind/pybind.cc @@ -137,6 +137,10 @@ limitations under the License. */ #include "paddle/fluid/platform/ipu_info.h" #endif +#ifdef PADDLE_WITH_MLU +#include "paddle/fluid/platform/device/mlu/mlu_info.h" +#endif + #ifdef PADDLE_WITH_CRYPTO #include "paddle/fluid/pybind/crypto.h" #endif @@ -164,6 +168,7 @@ PyTypeObject *g_cpuplace_pytype = nullptr; PyTypeObject *g_xpuplace_pytype = nullptr; PyTypeObject *g_npuplace_pytype = nullptr; PyTypeObject *g_cudapinnedplace_pytype = nullptr; +PyTypeObject *g_mluplace_pytype = nullptr; PyTypeObject *g_framework_tensor_pytype = nullptr; bool IsCompiledWithCUDA() { @@ -230,6 +235,14 @@ bool IsCompiledWithCINN() { #endif } +bool IsCompiledWithMLU() { +#ifndef PADDLE_WITH_MLU + return false; +#else + return true; +#endif +} + bool IsCompiledWithHETERPS() { #ifndef PADDLE_WITH_HETERPS return false; @@ -295,10 +308,9 @@ OpSupportedInfos(const std::string &place, [](unsigned char c) { return std::toupper(c); }); using fn_type = std::add_pointer::type; std::unordered_map is_target_place{ - {"GPU", &platform::is_gpu_place}, - {"CPU", &platform::is_cpu_place}, - {"XPU", &platform::is_xpu_place}, - {"NPU", &platform::is_npu_place}, + {"GPU", &platform::is_gpu_place}, {"CPU", &platform::is_cpu_place}, + {"XPU", &platform::is_xpu_place}, {"NPU", &platform::is_npu_place}, + {"MLU", &platform::is_mlu_place}, }; PADDLE_ENFORCE_NE( is_target_place.count(query_place), 0, @@ -769,6 +781,10 @@ PYBIND11_MODULE(core_noavx, m) { [](framework::Tensor &self, paddle::platform::NPUPlace &place) { self.mutable_data(place); }) + .def("_alloc_float", + [](framework::Tensor &self, paddle::platform::MLUPlace &place) { + self.mutable_data(place); + }) .def("_alloc_double", [](framework::Tensor &self, paddle::platform::CPUPlace &place) { self.mutable_data(place); @@ -785,6 +801,10 @@ PYBIND11_MODULE(core_noavx, m) { [](framework::Tensor &self, paddle::platform::CUDAPlace &place) { self.mutable_data(place); }) + .def("_alloc_int", + [](framework::Tensor &self, paddle::platform::MLUPlace &place) { + self.mutable_data(place); + }) .def("_alloc_int", [](framework::Tensor &self, paddle::platform::CUDAPinnedPlace &place) { @@ -815,6 +835,11 @@ PYBIND11_MODULE(core_noavx, m) { paddle::framework::proto::VarType::Type type) { return reinterpret_cast(self.mutable_data(place, type)); }) + .def("_mutable_data", + [](framework::Tensor &self, paddle::platform::MLUPlace &place, + paddle::framework::proto::VarType::Type type) { + return reinterpret_cast(self.mutable_data(place, type)); + }) .def("_clear", &framework::Tensor::clear) .def("_mutable_data", [](framework::Tensor &self, paddle::platform::NPUPlace &place, @@ -831,6 +856,8 @@ PYBIND11_MODULE(core_noavx, m) { py::arg("tensor"), py::arg("place"), py::arg("batch_size") = -1) .def("_copy_from", &TensorCopyFrom, py::arg("tensor"), py::arg("place"), py::arg("batch_size") = -1) + .def("_copy_from", &TensorCopyFrom, + py::arg("tensor"), py::arg("place"), py::arg("batch_size") = -1) .def("_copy_from", &TensorCopyFrom, py::arg("tensor"), py::arg("place"), py::arg("batch_size") = -1) .def("set", SetTensorFromPyArray, @@ -843,6 +870,8 @@ PYBIND11_MODULE(core_noavx, m) { py::arg("array"), py::arg("place"), py::arg("zero_copy") = false) .def("set", SetTensorFromPyArray, py::arg("array"), py::arg("place"), py::arg("zero_copy") = false) + .def("set", SetTensorFromPyArray, + py::arg("array"), py::arg("place"), py::arg("zero_copy") = false) .def("set", SetTensorFromPyArray, py::arg("array"), py::arg("place"), py::arg("zero_copy") = false, R"DOC( @@ -850,7 +879,7 @@ PYBIND11_MODULE(core_noavx, m) { Args: lod (numpy.ndarray): The data to set. - place (CPUPlace|CUDAPlace|XPUPlace|IPUPlace|CUDAPinnedPlace|NPUPlace): The place where the + place (CPUPlace|CUDAPlace|XPUPlace|IPUPlace|CUDAPinnedPlace|NPUPlace|MLUPlace): The place where the LoDTensor is to be set. zero_copy (bool, optional): Whether to share memory with the input numpy array. This parameter only works with CPUPlace. Default: False. @@ -1619,6 +1648,18 @@ All parameter, weight, gradient are variables in Paddle. "Please recompile or reinstall Paddle with XPU support.")); #else return new paddle::platform::XPUDeviceContext(place); +#endif + }) + .def_static("create", + [](paddle::platform::MLUPlace& place) + -> paddle::platform::DeviceContext* { +#ifndef PADDLE_WITH_MLU + PADDLE_THROW( + platform::errors::PermissionDenied( + "Cannot use MLUPlace in CPU/GPU version, " + "Please recompile or reinstall Paddle with MLU support.")); +#else + return new paddle::platform::MLUDeviceContext(place); #endif }) .def_static("create", @@ -1736,6 +1777,7 @@ All parameter, weight, gradient are variables in Paddle. .def("_equals", &IsSamePlace) .def("_equals", &IsSamePlace) .def("_equals", &IsSamePlace) + .def("_equals", &IsSamePlace) .def("_equals", &IsSamePlace) .def("_get_device_id", @@ -2004,6 +2046,75 @@ All parameter, weight, gradient are variables in Paddle. #endif .def("__str__", string::to_string); + // MLUPlace + py::class_ mluplace(m, "MLUPlace", R"DOC( + MLUPlace is a descriptor of a device. + It represents a MLU device on which a tensor will be allocated and a model will run. + + Examples: + .. code-block:: python + import paddle + # required: mlu + mlu_place = paddle.MLUPlace(0) + + )DOC"); + g_mluplace_pytype = reinterpret_cast(mluplace.ptr()); + mluplace + .def("__init__", + [](platform::MLUPlace &self, int dev_id) { +#ifdef PADDLE_WITH_MLU + if (UNLIKELY(dev_id < 0)) { + LOG(ERROR) << string::Sprintf( + "Invalid MLUPlace(%d), device id must be 0 or " + "positive integer", + dev_id); + std::exit(-1); + } + if (UNLIKELY(dev_id >= platform::GetMLUDeviceCount())) { + if (platform::GetMLUDeviceCount() == 0) { + LOG(ERROR) << "Cannot use MLU because there is no MLU " + "detected on your " + "machine."; + std::exit(-1); + } else { + LOG(ERROR) << string::Sprintf( + "Invalid MLUPlace(%d), must inside [0, %d), because MLU " + "number on your machine is %d", + dev_id, platform::GetMLUDeviceCount(), + platform::GetMLUDeviceCount()); + std::exit(-1); + } + } + new (&self) platform::MLUPlace(dev_id); +#else + LOG(ERROR) << string::Sprintf( + "Cannot use MLU because you have installed CPU/GPU/... " + "version " + "PaddlePaddle.\n" + "If you want to use MLU, please try to install MLU version " + "PaddlePaddle by: pip install paddlepaddle-mlu\n" + "If you only have CPU, please change MLUPlace(%d) to be " + "CPUPlace().\n", + dev_id); + std::exit(-1); +#endif + }) + .def("_type", &PlaceIndex) +#ifdef PADDLE_WITH_MLU + .def("_equals", &IsSamePlace) + .def("_equals", &IsSamePlace) + .def("_equals", &IsSamePlace) + .def("_equals", &IsSamePlace) + .def("_equals", &IsSamePlace) + .def("_equals", &IsSamePlace) + .def("_equals", &IsSamePlace) + .def("_equals", + &IsSamePlace) + .def("get_device_id", + [](const platform::MLUPlace &self) { return self.GetDeviceId(); }) +#endif + .def("__str__", string::to_string); + py::class_ platformplace(m, "Place"); g_place_pytype = reinterpret_cast(platformplace.ptr()); platformplace.def(py::init<>()) @@ -2015,6 +2126,7 @@ All parameter, weight, gradient are variables in Paddle. .def("_equals", &IsSamePlace) .def("_equals", &IsSamePlace) .def("_equals", &IsSamePlace) + .def("_equals", &IsSamePlace) .def("is_gpu_place", [](platform::Place &self) { return platform::is_gpu_place(self); }) .def("is_cpu_place", @@ -2029,6 +2141,8 @@ All parameter, weight, gradient are variables in Paddle. [](platform::Place &self) { return platform::is_cuda_pinned_place(self); }) + .def("is_mlu_place", + [](platform::Place &self) { return platform::is_mlu_place(self); }) .def("gpu_device_id", [](platform::Place &self) { return BOOST_GET_CONST(platform::CUDAPlace, self).device; @@ -2045,6 +2159,10 @@ All parameter, weight, gradient are variables in Paddle. [](platform::Place &self) { return BOOST_GET_CONST(platform::IPUPlace, self).device; }) + .def("mlu_device_id", + [](platform::Place &self) { + return BOOST_GET_CONST(platform::MLUPlace, self).device; + }) .def("set_place", [](platform::Place &self, const platform::Place &other) { self = other; }) .def("set_place", @@ -2072,6 +2190,10 @@ All parameter, weight, gradient are variables in Paddle. [](platform::Place &self, const platform::IPUPlace &ipu_place) { self = ipu_place; }) + .def("set_place", + [](platform::Place &self, const platform::MLUPlace &mlu_place) { + self = mlu_place; + }) .def("__repr__", string::to_string) .def("__str__", string::to_string); @@ -2120,6 +2242,12 @@ All parameter, weight, gradient are variables in Paddle. pybind11::gil_scoped_release release; self.Run(scope, place); }) + .def("run", + [](OperatorBase &self, const Scope &scope, + const platform::MLUPlace &place) { + pybind11::gil_scoped_release release; + self.Run(scope, place); + }) .def("type", [](const OperatorBase &op) -> std::string { return op.Type(); }) .def("outputs", @@ -2307,6 +2435,7 @@ All parameter, weight, gradient are variables in Paddle. m.def("is_compiled_with_xpu", IsCompiledWithXPU); m.def("is_compiled_with_mkldnn", IsCompiledWithMKLDNN); m.def("is_compiled_with_cinn", IsCompiledWithCINN); + m.def("is_compiled_with_mlu", IsCompiledWithMLU); m.def("_is_compiled_with_heterps", IsCompiledWithHETERPS); m.def("supports_bfloat16", SupportsBfloat16); m.def("supports_bfloat16_fast_performance", SupportsBfloat16FastPerformance); @@ -2627,6 +2756,10 @@ All parameter, weight, gradient are variables in Paddle. m.def("get_ipu_device_count", platform::GetIPUDeviceCount); #endif +#ifdef PADDLE_WITH_MLU + m.def("get_mlu_device_count", platform::GetMLUDeviceCount); +#endif + py::enum_(m, "TracerOption", py::arithmetic()) .value("kDefault", platform::TracerOption::kDefault) .value("kOpDetail", platform::TracerOption::kOpDetail) diff --git a/paddle/fluid/pybind/tensor_py.h b/paddle/fluid/pybind/tensor_py.h index df9ba02eadf43268ed9d7d7e874703eb9500df48..9d3a858d1bdbfb1bbe560e14fbf122df035fc84c 100644 --- a/paddle/fluid/pybind/tensor_py.h +++ b/paddle/fluid/pybind/tensor_py.h @@ -345,6 +345,18 @@ void SetTensorFromPyArrayT( PADDLE_THROW(platform::errors::PermissionDenied( "Cannot use NPUPlace in CPU/GPU/XPU version. " "Please recompile or reinstall Paddle with NPU support.")); +#endif + } else if (paddle::platform::is_mlu_place(place)) { +#ifdef PADDLE_WITH_MLU + platform::Place tmp_place = place; + platform::MLUDeviceGuard guard( + BOOST_GET_CONST(platform::MLUPlace, tmp_place).device); + auto dst = self->mutable_data(place); + paddle::platform::MLUMemcpyH2DSync(dst, array.data(), array.nbytes()); +#else + PADDLE_THROW(platform::errors::PermissionDenied( + "Cannot use MLUPlace in CPU/GPU version, " + "Please recompile or reinstall Paddle with MLU support.")); #endif } else { #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) @@ -702,6 +714,7 @@ inline py::array TensorToPyArray(const framework::Tensor &tensor, bool is_gpu_tensor = platform::is_gpu_place(tensor.place()); bool is_xpu_tensor = platform::is_xpu_place(tensor.place()); bool is_npu_tensor = platform::is_npu_place(tensor.place()); + bool is_mlu_tensor = platform::is_mlu_place(tensor.place()); const auto &tensor_dims = tensor.dims(); auto tensor_dtype = tensor.type(); size_t sizeof_dtype = framework::SizeOfType(tensor_dtype); @@ -720,7 +733,7 @@ inline py::array TensorToPyArray(const framework::Tensor &tensor, std::string py_dtype_str = details::TensorDTypeToPyDTypeStr(tensor.type()); - if (!is_gpu_tensor && !is_xpu_tensor && !is_npu_tensor) { + if (!is_gpu_tensor && !is_xpu_tensor && !is_npu_tensor && !is_mlu_tensor) { if (!need_deep_copy) { auto base = py::cast(std::move(tensor)); return py::array(py::dtype(py_dtype_str.c_str()), py_dims, py_strides, @@ -816,6 +829,29 @@ inline py::array TensorToPyArray(const framework::Tensor &tensor, PADDLE_THROW(platform::errors::PermissionDenied( "Cannot use NPUPlace in CPU/GPU/XPU version, " "Please recompile or reinstall Paddle with NPU support.")); +#endif + } else if (is_mlu_tensor) { +#ifdef PADDLE_WITH_MLU + py::array py_arr(py::dtype(py_dtype_str.c_str()), py_dims, py_strides); + PADDLE_ENFORCE_EQ(py_arr.writeable(), true, + platform::errors::InvalidArgument( + "PyArray is not writable, in which case memory leak " + "or double free would occur")); + PADDLE_ENFORCE_EQ( + py_arr.owndata(), true, + platform::errors::InvalidArgument( + "PyArray does not own data, in which case memory leak " + "or double free would occur")); + + size_t copy_bytes = sizeof_dtype * numel; + auto p = BOOST_GET_CONST(platform::MLUPlace, tensor.place()); + paddle::memory::Copy(platform::CPUPlace(), py_arr.mutable_data(), p, + tensor_buf_ptr, copy_bytes, nullptr); + return py_arr; +#else + PADDLE_THROW(platform::errors::PermissionDenied( + "Cannot use MLUPlace in CPU/GPU/XPU/NPU version, " + "Please recompile or reinstall Paddle with MLU support.")); #endif } PADDLE_THROW(platform::errors::Unimplemented("Place is not supported")); diff --git a/python/CMakeLists.txt b/python/CMakeLists.txt index b7a601f53fd8506dc81bcbd510a26c7b45933dde..0fecd7c8c36ee32f96b29ab5e4a91f862975580c 100644 --- a/python/CMakeLists.txt +++ b/python/CMakeLists.txt @@ -6,6 +6,8 @@ set(PY_FILES paddle/__init__.py if(WITH_GPU) SET(PACKAGE_NAME "paddlepaddle-gpu") +elseif(WITH_MLU) + SET(PACKAGE_NAME "paddlepaddle-mlu") elseif(WITH_ROCM) SET(PACKAGE_NAME "paddlepaddle-rocm") elseif(WITH_ASCEND_CL) diff --git a/python/paddle/__init__.py b/python/paddle/__init__.py index 771a9053fc264653462db750689aa102c2e0acb0..e3171c4f3bb2ebc28e18df69b5bc742280068e50 100755 --- a/python/paddle/__init__.py +++ b/python/paddle/__init__.py @@ -291,6 +291,7 @@ from .framework import IPUPlace # noqa: F401 from .framework import CUDAPlace # noqa: F401 from .framework import NPUPlace # noqa: F401 from .framework import CUDAPinnedPlace # noqa: F401 +from .framework import MLUPlace # noqa: F401 from .autograd import grad # noqa: F401 from .autograd import no_grad # noqa: F401 @@ -322,6 +323,7 @@ from .fluid.framework import set_flags # noqa: F401 from .device import is_compiled_with_xpu # noqa: F401 from .device import is_compiled_with_npu # noqa: F401 from .device import is_compiled_with_ipu # noqa: F401 +from .device import is_compiled_with_mlu # noqa: F401 from .device import XPUPlace # noqa: F401 from .fluid.dygraph.base import enable_dygraph as disable_static # noqa: F401 diff --git a/python/paddle/device/__init__.py b/python/paddle/device/__init__.py index 0a11d59d69c94c0345802b0f9d070aa23f4c0a24..d102473fef791124e0605008dd1844507c3b4a61 100644 --- a/python/paddle/device/__init__.py +++ b/python/paddle/device/__init__.py @@ -29,12 +29,14 @@ __all__ = [ # noqa 'get_device', 'XPUPlace', 'IPUPlace', + 'MLUPlace', 'is_compiled_with_xpu', 'is_compiled_with_ipu', 'is_compiled_with_cinn', 'is_compiled_with_cuda', 'is_compiled_with_rocm', - 'is_compiled_with_npu' + 'is_compiled_with_npu', + 'is_compiled_with_mlu' ] _cudnn_version = None @@ -120,6 +122,41 @@ def XPUPlace(dev_id): return core.XPUPlace(dev_id) +def is_compiled_with_mlu(): + """ + Whether paddle was built with WITH_MLU=ON to support Cambricon MLU + + Returns (bool): whether paddle was built with WITH_MLU=ON + + Examples: + .. code-block:: python + + # required: mlu + + import paddle + support_mlu = paddle.device.is_compiled_with_mlu() + """ + return core.is_compiled_with_mlu() + + +def MLUPlace(dev_id): + """ + Return a Cambricon MLU Place + + Parameters: + dev_id(int): MLU device id + + Examples: + .. code-block:: python + + # required: mlu + + import paddle + place = paddle.device.MLUPlace(0) + """ + return core.MLUPlace(dev_id) + + def get_cudnn_version(): """ This funciton return the version of cudnn. the retuen value is int which represents the @@ -181,13 +218,21 @@ def _convert_to_place(device): "The device should not be 'ipu', " \ "since PaddlePaddle is not compiled with IPU") place = core.IPUPlace() + elif lower_device == 'mlu': + if not core.is_compiled_with_mlu(): + raise ValueError("The device should not be 'mlu', " + "since PaddlePaddle is not compiled with MLU") + selected_mlus = os.getenv("FLAGS_selected_mlus", "0").split(",") + device_id = int(selected_mlus[0]) + place = core.MLUPlace(device_id) else: avaliable_gpu_device = re.match(r'gpu:\d+', lower_device) avaliable_xpu_device = re.match(r'xpu:\d+', lower_device) avaliable_npu_device = re.match(r'npu:\d+', lower_device) - if not avaliable_gpu_device and not avaliable_xpu_device and not avaliable_npu_device: + avaliable_mlu_device = re.match(r'mlu:\d+', lower_device) + if not avaliable_gpu_device and not avaliable_xpu_device and not avaliable_npu_device and not avaliable_mlu_device: raise ValueError( - "The device must be a string which is like 'cpu', 'gpu', 'gpu:x', 'xpu', 'xpu:x', 'npu', 'npu:x' or ipu" + "The device must be a string which is like 'cpu', 'gpu', 'gpu:x', 'xpu', 'xpu:x', 'mlu', 'mlu:x', 'npu', 'npu:x' or ipu" ) if avaliable_gpu_device: if not core.is_compiled_with_cuda(): @@ -216,19 +261,28 @@ def _convert_to_place(device): device_id = device_info_list[1] device_id = int(device_id) place = core.NPUPlace(device_id) + if avaliable_mlu_device: + if not core.is_compiled_with_mlu(): + raise ValueError( + "The device should not be {}, since PaddlePaddle is " + "not compiled with mlu".format(avaliable_mlu_device)) + device_info_list = device.split(':', 1) + device_id = device_info_list[1] + device_id = int(device_id) + place = core.MLUPlace(device_id) return place def set_device(device): """ - Paddle supports running calculations on various types of devices, including CPU, GPU, XPU, NPU and IPU. + Paddle supports running calculations on various types of devices, including CPU, GPU, XPU, NPU, MLU and IPU. They are represented by string identifiers. This function can specify the global device which the OP will run. Parameters: device(str): This parameter determines the specific running device. - It can be ``cpu``, ``gpu``, ``xpu``, ``npu``, ``gpu:x``, ``xpu:x``, ``npu:x`` and ``ipu``, - where ``x`` is the index of the GPUs, XPUs or NPUs. + It can be ``cpu``, ``gpu``, ``xpu``, ``npu``, ``mlu``, ``gpu:x``, ``xpu:x``, ``npu:x``, ``mlu:x`` and ``ipu``, + where ``x`` is the index of the GPUs, XPUs, NPUs or MLUs. Examples: @@ -249,7 +303,7 @@ def set_device(device): def get_device(): """ This funciton can get the current global device of the program is running. - It's a string which is like 'cpu', 'gpu:x', 'xpu:x' and 'npu:x'. if the global device is not + It's a string which is like 'cpu', 'gpu:x', 'xpu:x', 'mlu:x' and 'npu:x'. if the global device is not set, it will return a string which is 'gpu:x' when cuda is avaliable or it will return a string which is 'cpu' when cuda is not avaliable. @@ -277,6 +331,9 @@ def get_device(): elif isinstance(place, core.IPUPlace): num_devices = core.get_ipu_device_count() device = "ipus:{{0-{}}}".format(num_devices - 1) + elif isinstance(place, core.MLUPlace): + device_id = place.get_device_id() + device = 'mlu:' + str(device_id) else: raise ValueError("The device specification {} is invalid".format(place)) diff --git a/python/paddle/fluid/__init__.py b/python/paddle/fluid/__init__.py index d8ee875e768e524525a67042eb4952df53901f05..cd8f9f8545847d1c08588cbba1524e2f56331116 100644 --- a/python/paddle/fluid/__init__.py +++ b/python/paddle/fluid/__init__.py @@ -71,7 +71,7 @@ from . import distribute_lookup_table from .param_attr import ParamAttr, WeightNormParamAttr from .data_feeder import DataFeeder from .core import LoDTensor, LoDTensorArray, Scope, _Scope -from .core import CPUPlace, XPUPlace, CUDAPlace, CUDAPinnedPlace, NPUPlace, IPUPlace +from .core import CPUPlace, XPUPlace, CUDAPlace, CUDAPinnedPlace, NPUPlace, IPUPlace, MLUPlace from .incubate import fleet from .transpiler import DistributeTranspiler, \ memory_optimize, release_memory, DistributeTranspilerConfig @@ -133,6 +133,7 @@ __all__ = framework.__all__ + executor.__all__ + \ 'CUDAPinnedPlace', 'NPUPlace', 'IPUPlace', + 'MLUPlace', 'Tensor', 'ParamAttr', 'WeightNormParamAttr', diff --git a/python/paddle/fluid/framework.py b/python/paddle/fluid/framework.py index fd2a93874876ca852c2ad88f4e6ebb8e48892bcb..73407ef834e228f918a0d0bd488b7cc17e685077 100644 --- a/python/paddle/fluid/framework.py +++ b/python/paddle/fluid/framework.py @@ -51,6 +51,7 @@ __all__ = [ 'cuda_places', 'cpu_places', 'xpu_places', + 'mlu_places', 'cuda_pinned_places', 'in_dygraph_mode', 'is_compiled_with_cinn', @@ -347,6 +348,18 @@ def _current_expected_place(): "You are using XPU version Paddle, but your XPU device is not set properly. CPU device will be used by default." ) _global_expected_place_ = core.CPUPlace() + elif core.is_compiled_with_mlu(): + try: + device_count = core.get_mlu_device_count() + except Exception as e: + device_count = 0 + if device_count > 0: + _global_expected_place_ = core.MLUPlace(0) + else: + warnings.warn( + "You are using MLU version Paddle, but your MLU device is not set properly. CPU device will be used by default." + ) + _global_expected_place_ = core.CPUPlace() else: _global_expected_place_ = core.CPUPlace() @@ -426,6 +439,15 @@ def _npu_ids(): return device_ids +def _mlu_ids(): + mlus_env = os.getenv("FLAGS_selected_mlus") + if mlus_env: + device_ids = [int(s) for s in mlus_env.split(",")] + else: + device_ids = six.moves.range(core.get_mlu_device_count()) + return device_ids + + def is_compiled_with_xpu(): """ Whether this whl package can be used to run the model on XPU. @@ -721,6 +743,48 @@ def cuda_pinned_places(device_count=None): return [core.CUDAPinnedPlace()] * device_count +def mlu_places(device_ids=None): + """ + **Note**: + For multi-card tasks, please use `FLAGS_selected_mlus` environment variable to set the visible MLU device. + This function creates a list of :code:`paddle.device.MLUPlace` objects. + If :code:`device_ids` is None, environment variable of + :code:`FLAGS_selected_mlus` would be checked first. For example, if + :code:`FLAGS_selected_mlus=0,1,2`, the returned list would + be [paddle.device.MLUPlace(0), paddle.device.MLUPlace(1), paddle.device.MLUPlace(2)]. + If :code:`FLAGS_selected_mlus` is not set, all visible + mlu places would be returned. + If :code:`device_ids` is not None, it should be the device + ids of MLUs. For example, if :code:`device_ids=[0,1,2]`, + the returned list would be + [paddle.device.MLUPlace(0), paddle.device.MLUPlace(1), paddle.device.MLUPlace(2)]. + + Parameters: + device_ids (list or tuple of int, optional): list of MLU device ids. + + Returns: + list of paddle.device.MLUPlace: Created MLU place list. + + Examples: + .. code-block:: python + + # required: mlu + + import paddle + import paddle.static as static + + paddle.enable_static() + mlu_places = static.mlu_places() + """ + assert core.is_compiled_with_mlu(), \ + "Not compiled with MLU" + if device_ids is None: + device_ids = _mlu_ids() + elif not isinstance(device_ids, (list, tuple)): + device_ids = [device_ids] + return [core.MLUPlace(dev_id) for dev_id in device_ids] + + class NameScope(object): def __init__(self, name="", parent=None): self._children = dict() @@ -2090,6 +2154,10 @@ class Variable(object): p = core.Place() p.set_place(t._place()) place = core.NPUPlace(p.npu_device_id()) + elif p.is_mlu_place(): + p = core.Place() + p.set_place(t._place()) + place = core.MLUPlace(p.mlu_device_id()) else: p = core.Place() p.set_place(t._place()) @@ -6768,7 +6836,8 @@ def _get_paddle_place(place): if place is None: return place if isinstance(place, (core.Place, core.XPUPlace, core.CPUPlace, - core.CUDAPinnedPlace, core.CUDAPlace, core.NPUPlace)): + core.CUDAPinnedPlace, core.CUDAPlace, core.NPUPlace, + core.MLUPlace)): return place if not isinstance(place, str): @@ -6823,8 +6892,20 @@ def _get_paddle_place(place): device_id = int(device_id) return core.NPUPlace(device_id) + # MLU + avaliable_mlu_place = re.match(r'mlu:\d+', place) + if avaliable_mlu_place: + if not core.is_compiled_with_mlu(): + raise ValueError( + "The device should not be {}, since PaddlePaddle is " \ + "not compiled with MLU".format(avaliable_mlu_place)) + place_info_list = place.split(':', 1) + device_id = place_info_list[1] + device_id = int(device_id) + return core.MLUPlace(device_id) + raise ValueError( - "Paddle supports CPUPlace, CUDAPlace,CUDAPinnedPlace, XPUPlace and NPUPlace, but received {}.". + "Paddle supports CPUPlace, CUDAPlace,CUDAPinnedPlace, XPUPlace, MLUPlace and NPUPlace, but received {}.". format(place)) diff --git a/python/paddle/fluid/io.py b/python/paddle/fluid/io.py index e110c47d790f1e082f858f76da89bec4dc97f1f4..4bbc0ba03c9342afc4a0d2edee6c2b963ad6e0f8 100644 --- a/python/paddle/fluid/io.py +++ b/python/paddle/fluid/io.py @@ -2100,6 +2100,10 @@ def load(program, model_path, executor=None, var_list=None): p = paddle.fluid.core.Place() p.set_place(t._place()) place = paddle.fluid.NPUPlace(p.npu_device_id()) + elif p.is_mlu_place(): + p = paddle.fluid.core.Place() + p.set_place(t._place()) + place = paddle.fluid.MLUPlace(p.mlu_device_id()) else: p = paddle.fluid.core.Place() p.set_place(t._place()) @@ -2394,6 +2398,10 @@ def set_program_state(program, state_dict): p = paddle.fluid.core.Place() p.set_place(ten_place) py_place = paddle.fluid.NPUPlace(p.npu_device_id()) + elif ten_place.is_mlu_place(): + p = paddle.fluid.core.Place() + p.set_place(ten_place) + py_place = paddle.fluid.MLUPlace(p.mlu_device_id()) ten.set(new_para_np, py_place) diff --git a/python/paddle/framework/__init__.py b/python/paddle/framework/__init__.py index a081da60c68392d498f016668a2242f0110c942c..a0503322806e5825ca720740e93c07ecf6cb51fb 100644 --- a/python/paddle/framework/__init__.py +++ b/python/paddle/framework/__init__.py @@ -28,6 +28,7 @@ from ..fluid.core import IPUPlace # noqa: F401 from ..fluid.core import CUDAPlace # noqa: F401 from ..fluid.core import CUDAPinnedPlace # noqa: F401 from ..fluid.core import NPUPlace # noqa: F401 +from ..fluid.core import MLUPlace # noqa: F401 from ..fluid.core import VarBase # noqa: F401 from paddle.fluid import core # noqa: F401 diff --git a/python/paddle/static/__init__.py b/python/paddle/static/__init__.py index 92aa5000dfa58cbe899c8718ce6f3d356d283822..f18b77997a5e2bcb47341d48035dc59fdd8f657c 100644 --- a/python/paddle/static/__init__.py +++ b/python/paddle/static/__init__.py @@ -45,6 +45,7 @@ from ..fluid.framework import program_guard # noqa: F401 from ..fluid.framework import cpu_places # noqa: F401 from ..fluid.framework import cuda_places # noqa: F401 from ..fluid.framework import xpu_places # noqa: F401 +from ..fluid.framework import mlu_places # noqa: F401 from ..fluid.framework import npu_places # noqa: F401 from ..fluid.framework import Variable # noqa: F401 from ..fluid.layers.control_flow import Print # noqa: F401 @@ -103,6 +104,7 @@ __all__ = [ #noqa 'cuda_places', 'xpu_places', 'npu_places', + 'mlu_places', 'Variable', 'create_global_var', 'accuracy',