// custom_device_py.cc
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/pybind/custom_device_py.h"

#include <memory>
#include <string>
#include <vector>

#include "paddle/fluid/platform/device_context.h"
#include "paddle/phi/backends/device_manager.h"
#include "paddle/phi/backends/event.h"
#include "paddle/phi/backends/stream.h"

namespace py = pybind11;

namespace paddle {
namespace pybind {
void BindCustomDevicePy(py::module *m_ptr) {
  auto &m = *m_ptr;
  // Bind Methods
  // Python: _get_current_custom_device_stream(device_type, device_id=-1)
  // Looks up the context registered in DeviceContextPool for the requested
  // custom place and returns the stream that context reports via GetStream().
  // A device_id of -1 is replaced by phi::DeviceManager::GetDevice(device_type).
  m.def(
      "_get_current_custom_device_stream",
      [](const std::string &device_type, int device_id) {
#ifndef PADDLE_WITH_CUSTOM_DEVICE
        PADDLE_THROW(platform::errors::Unavailable(
            "Paddle is not compiled with CustomDevice. "
            "Cannot visit _get_current_custom_device_stream."));
#else
        const auto resolved_id =
            device_id == -1 ? phi::DeviceManager::GetDevice(device_type)
                            : device_id;
        const paddle::platform::CustomPlace target(device_type, resolved_id);

        auto *pooled_ctx =
            paddle::platform::DeviceContextPool::Instance().Get(target);
        const auto *custom_ctx =
            static_cast<const phi::CustomContext *>(pooled_ctx);
        return custom_ctx->GetStream();
#endif
      },
      py::return_value_policy::reference,
      py::arg("device_type"),
      py::arg("device_id") = -1);
  // Python: _set_current_custom_device_stream(device_type, device_id=-1,
  //                                           stream=None)
  // Installs `stream` on the context registered in DeviceContextPool for the
  // requested custom place and hands the same stream back to the caller.
  // A device_id of -1 is replaced by phi::DeviceManager::GetDevice(device_type).
  m.def(
      "_set_current_custom_device_stream",
      [](const std::string &device_type,
         int device_id,
         std::shared_ptr<phi::stream::Stream> stream) {
#ifndef PADDLE_WITH_CUSTOM_DEVICE
        PADDLE_THROW(platform::errors::Unavailable(
            "Paddle is not compiled with CustomDevice. "
            "Cannot visit _set_current_custom_device_stream."));
#else
        const auto resolved_id =
            device_id == -1 ? phi::DeviceManager::GetDevice(device_type)
                            : device_id;
        const paddle::platform::CustomPlace target(device_type, resolved_id);

        auto *pooled_ctx =
            paddle::platform::DeviceContextPool::Instance().Get(target);
        auto *custom_ctx = static_cast<phi::CustomContext *>(pooled_ctx);
        custom_ctx->SetStream(stream);
        return stream;
#endif
      },
      py::arg("device_type"),
      py::arg("device_id") = -1,
      py::arg("stream") = nullptr);
  m.def("_synchronize_custom_device",
        [](const std::string &device_type, int device_id) {
#ifdef PADDLE_WITH_CUSTOM_DEVICE
          auto place = paddle::platform::CustomPlace(
              device_type,
              device_id == -1 ? phi::DeviceManager::GetDevice(device_type)
                              : device_id);
          phi::DeviceManager::SynchronizeDevice(place);
#else
        PADDLE_THROW(platform::errors::Unavailable(
            "Paddle is not compiled with CustomDevice. "
            "Cannot visit _synchronize_custom_device."));
#endif
        });

  // Python class CustomDeviceStream, held by std::shared_ptr. The docstring
  // below mirrors the two __init__ overloads bound further down the chain:
  // (device: CustomPlace, priority=2, blocking=False) and
  // (device: str, device_id=-1, priority=2, blocking=False).
  py::class_<phi::stream::Stream, std::shared_ptr<phi::stream::Stream>>(
      m, "CustomDeviceStream", R"DOC(
      The handle of the custom device stream.

      Parameters:
        device(paddle.CustomPlace()|str): The device on which to allocate the stream.

        device_id(int, optional): The id of the device on which to allocate the stream.
        Only accepted when device is a str. If device_id is -1, the current device
        of that device type is used. Default: -1.

        priority(int, optional): The priority of stream. The priority can be 1(high) or 2(normal).
        Default: 2.

        blocking(bool, optional): Whether the stream is executed synchronously. Default: False.

      Examples:
        .. code-block:: python

            # required: custom_device
            import paddle
            s3 = paddle.device.custom.Stream('custom_cpu')
            s2 = paddle.device.custom.Stream('custom_cpu', 0)
            s1 = paddle.device.custom.Stream(paddle.CustomPlace('custom_cpu'))
            s1 = paddle.device.custom.Stream(paddle.CustomPlace('custom_cpu'), 1)
            s1 = paddle.device.custom.Stream(paddle.CustomPlace('custom_cpu'), 1, True)

  )DOC")
      // __init__(device: CustomPlace, priority=2, blocking=False)
      // Constructs the Stream in the storage pybind11 provides, then
      // initializes it on `place`. blocking=True selects kDefaultFlag,
      // blocking=False selects kStreamNonBlocking.
      .def(
          "__init__",
          [](phi::stream::Stream &self,
             const platform::CustomPlace &place,
             int priority,
             bool blocking) {
#ifndef PADDLE_WITH_CUSTOM_DEVICE
            PADDLE_THROW(platform::errors::Unavailable(
                "Paddle is not compiled with CustomDevice. "
                "Cannot visit CustomDeviceStream."));
#else
            using StreamT = phi::stream::Stream;
            new (&self) StreamT();

            const auto wanted_priority =
                static_cast<StreamT::Priority>(priority);
            const auto wanted_flag = static_cast<StreamT::Flag>(
                blocking ? StreamT::Flag::kDefaultFlag
                         : StreamT::Flag::kStreamNonBlocking);
            self.Init(place, wanted_priority, wanted_flag);
#endif
          },
          py::arg("device"),
          py::arg("priority") = 2,
          py::arg("blocking") = false)
      // __init__(device: str, device_id=-1, priority=2, blocking=False)
      // Same as the CustomPlace overload, but the place is built from a device
      // type name and id. A device_id of -1 is replaced by
      // phi::DeviceManager::GetDevice(device_type).
      .def(
          "__init__",
          [](phi::stream::Stream &self,
             const std::string &device_type,
             int device_id,
             int priority,
             bool blocking) {
#ifndef PADDLE_WITH_CUSTOM_DEVICE
            PADDLE_THROW(platform::errors::Unavailable(
                "Paddle is not compiled with CustomDevice. "
                "Cannot visit CustomDeviceStream."));
#else
            using StreamT = phi::stream::Stream;
            new (&self) StreamT();

            const auto resolved_id =
                device_id == -1 ? phi::DeviceManager::GetDevice(device_type)
                                : device_id;
            const phi::CustomPlace target(device_type, resolved_id);
            const auto wanted_priority =
                static_cast<StreamT::Priority>(priority);
            const auto wanted_flag = static_cast<StreamT::Flag>(
                blocking ? StreamT::Flag::kDefaultFlag
                         : StreamT::Flag::kStreamNonBlocking);
            self.Init(target, wanted_priority, wanted_flag);
#endif
          },
          py::arg("device"),
          py::arg("device_id") = -1,
          py::arg("priority") = 2,
          py::arg("blocking") = false)
      // wait_event(event): forwards to Stream::WaitEvent.
      // pybind11 converts a Python None into a null pointer for pointer
      // parameters, so None is rejected here with an InvalidArgument error
      // instead of being forwarded as nullptr.
      .def(
          "wait_event",
          [](const phi::stream::Stream &self, phi::event::Event *event) {
#ifdef PADDLE_WITH_CUSTOM_DEVICE
            if (event == nullptr) {
              PADDLE_THROW(platform::errors::InvalidArgument(
                  "The event passed to CustomDeviceStream.wait_event "
                  "must not be None."));
            }
            self.WaitEvent(event);
#else
            PADDLE_THROW(platform::errors::Unavailable(
                "Paddle is not compiled with CustomDevice. "
                "Cannot visit CustomDeviceStream."));
#endif
          },
          R"DOC(
      Makes all future work submitted to stream wait for all work captured in event.

      Parameters:
        event(CustomDeviceEvent): The event to wait on.

      Examples:
        .. code-block:: python

          # required: custom_device
          import paddle
          place = paddle.CustomPlace('custom_cpu', 0)
          s = paddle.device.custom.Stream(place)
          event = paddle.device.custom.Event(place)
          s.wait_event(event)

           )DOC")
      // wait_stream(stream): creates a temporary event on this stream's place,
      // records it on `other`, and makes this stream wait on that event.
      // pybind11 converts a Python None into a null pointer for pointer
      // parameters, so None is rejected here with an InvalidArgument error
      // instead of being handed to Event::Record as nullptr.
      .def(
          "wait_stream",
          [](const phi::stream::Stream &self, phi::stream::Stream *other) {
#ifdef PADDLE_WITH_CUSTOM_DEVICE
            if (other == nullptr) {
              PADDLE_THROW(platform::errors::InvalidArgument(
                  "The stream passed to CustomDeviceStream.wait_stream "
                  "must not be None."));
            }
            phi::event::Event event;
            event.Init(self.GetPlace());
            event.Record(other);
            self.WaitEvent(&event);
#else
            PADDLE_THROW(platform::errors::Unavailable(
                "Paddle is not compiled with CustomDevice. "
                "Cannot visit CustomDeviceStream."));
#endif
          },
          R"DOC(
      Synchronizes with the given stream.

      Parameters:
        stream(CustomDeviceStream): The stream to synchronize with.

      Examples:
        .. code-block:: python

            # required: custom_device
            import paddle
            place = paddle.CustomPlace('custom_cpu', 0)
            s1 = paddle.device.custom.Stream(place)
            s2 = paddle.device.custom.Stream(place)
            s1.wait_stream(s2)

           )DOC")
      // query(): returns whatever Stream::Query() reports for this stream.
      .def(
          "query",
          [](const phi::stream::Stream &self) {
#ifndef PADDLE_WITH_CUSTOM_DEVICE
            PADDLE_THROW(platform::errors::Unavailable(
                "Paddle is not compiled with CustomDevice. "
                "Cannot visit CustomDeviceStream."));
#else
            const auto finished = self.Query();
            return finished;
#endif
          },
          R"DOC(
      Return the status whether if all operations in stream have completed.

      Returns: A boolean value.

      Examples:
        .. code-block:: python

            # required: custom_device
            import paddle
            place = paddle.CustomPlace('custom_cpu', 0)
            s = paddle.device.custom.Stream(place)
            is_done = s.query()

           )DOC")
      // synchronize(): forwards to Stream::Synchronize() and returns nothing.
      .def(
          "synchronize",
          [](const phi::stream::Stream &self) {
#ifndef PADDLE_WITH_CUSTOM_DEVICE
            PADDLE_THROW(platform::errors::Unavailable(
                "Paddle is not compiled with CustomDevice. "
                "Cannot visit CustomDeviceStream."));
#else
            self.Synchronize();
#endif
          },
          R"DOC(
      Waits for stream tasks to complete.

      Examples:
        .. code-block:: python

            # required: custom_device
            import paddle
            place = paddle.CustomPlace('custom_cpu', 0)
            s = paddle.device.custom.Stream(place)
            s.synchronize()

           )DOC")
      // record_event(event=None): records `event` on this stream and returns
      // it. When no event is supplied a new one is created on this stream's
      // place. The new event is held by a unique_ptr until Init and Record
      // have both succeeded, so it is freed if either of them throws; only
      // then is ownership released to the raw pointer returned to pybind11.
      .def(
          "record_event",
          [](const phi::stream::Stream &self, phi::event::Event *event) {
#ifdef PADDLE_WITH_CUSTOM_DEVICE
            std::unique_ptr<phi::event::Event> fresh_event;
            if (event == nullptr) {
              fresh_event = std::make_unique<phi::event::Event>();
              fresh_event->Init(self.GetPlace());
              event = fresh_event.get();
            }
            event->Record(&self);
            // No-op when the caller supplied the event (fresh_event is empty).
            static_cast<void>(fresh_event.release());
            return event;
#else
            PADDLE_THROW(platform::errors::Unavailable(
                "Paddle is not compiled with CustomDevice. "
                "Cannot visit CustomDeviceStream."));
#endif
          },
          R"DOC(
      Record an event in the stream.

      Parameters:
          event(CustomDeviceEvent, optional): The event to be recorded. If event is None, a new event is created.
          Default: None.

      Returns:
          The recorded event.

      Examples:
        .. code-block:: python

            # required: custom_device
            import paddle
            place = paddle.CustomPlace('custom_cpu', 0)
            s = paddle.device.custom.Stream(place)
            event = s.record_event()

           )DOC",
          py::arg("event") = nullptr)
      // raw_stream (read-only property): logs the value of
      // Stream::raw_stream() at VLOG level 10 and returns it reinterpreted as
      // an integer (std::uintptr_t).
      .def_property_readonly(
          "raw_stream",
          [](const phi::stream::Stream &self) {
#ifndef PADDLE_WITH_CUSTOM_DEVICE
            PADDLE_THROW(platform::errors::Unavailable(
                "Paddle is not compiled with CustomDevice. "
                "Cannot visit CustomDeviceStream."));
#else
            VLOG(10) << self.raw_stream();
            const auto as_integer =
                reinterpret_cast<std::uintptr_t>(self.raw_stream());
            return as_integer;
#endif
          },
          R"DOC(
      return the raw stream of type CustomDeviceStream as type int.

      Examples:
        .. code-block:: python

            # required: custom_device
            import paddle
            import ctypes
            stream  = paddle.device.custom.current_stream().raw_stream
            print(stream)

            ptr = ctypes.c_void_p(stream)  # convert back to void*
            print(ptr)

           )DOC")
      // place (read-only property): returns a copy of the stream's place,
      // viewed as a phi::CustomPlace.
      .def_property_readonly("place", [](const phi::stream::Stream &self) {
#ifndef PADDLE_WITH_CUSTOM_DEVICE
        PADDLE_THROW(platform::errors::Unavailable(
            "Paddle is not compiled with CustomDevice. "
            "Cannot visit CustomDeviceStream."));
#else
        const auto &stream_place = self.GetPlace();
        return reinterpret_cast<const phi::CustomPlace &>(stream_place);
#endif
      });

  // Python class CustomDeviceEvent, held by std::shared_ptr. The docstring
  // below mirrors the two __init__ overloads bound further down the chain:
  // (device: CustomPlace, enable_timing=False, blocking=False,
  //  interprocess=False) and the str overload that additionally takes
  // device_id=-1.
  py::class_<phi::event::Event, std::shared_ptr<phi::event::Event>>(
      m, "CustomDeviceEvent", R"DOC(
      The handle of the custom device event.

      Parameters:
        device(paddle.CustomPlace()|str): The device on which to allocate the event.

        device_id(int, optional): The id of the device on which to allocate the event.
        Only accepted when device is a str. If device_id is -1, the current device
        of that device type is used. Default: -1.

        enable_timing(bool, optional): Whether the event will measure time. Default: False.

        blocking(bool, optional): Whether the wait() func will be blocking. Default: False.

        interprocess(bool, optional): Whether the event can be shared between processes. Default: False.

      Examples:
        .. code-block:: python

            # required: custom_device
            import paddle
            place = paddle.CustomPlace('custom_cpu', 0)
            event = paddle.device.custom.Event(place)

  )DOC")
      // __init__(device: CustomPlace, enable_timing=False, blocking=False,
      //          interprocess=False)
      // Builds the event flag bitmask from the three booleans (DisableTiming
      // is set when timing is NOT requested), constructs the Event in the
      // storage pybind11 provides, and initializes it on `place`.
      .def(
          "__init__",
          [](phi::event::Event &self,
             const platform::CustomPlace &place,
             bool enable_timing,
             bool blocking,
             bool interprocess) {
#ifndef PADDLE_WITH_CUSTOM_DEVICE
            PADDLE_THROW(platform::errors::Unavailable(
                "Paddle is not compiled with CustomDevice. "
                "Cannot visit CustomDeviceEvent."));
#else
            using EventFlag = phi::event::Event::Flag;
            uint32_t flag_bits = 0;
            if (!enable_timing) {
              flag_bits |= static_cast<uint32_t>(EventFlag::DisableTiming);
            }
            if (blocking) {
              flag_bits |= static_cast<uint32_t>(EventFlag::BlockingSync);
            }
            if (interprocess) {
              flag_bits |= static_cast<uint32_t>(EventFlag::Interprocess);
            }
            auto flag = static_cast<EventFlag>(flag_bits);

            new (&self) phi::event::Event();
            self.Init(place, flag);
#endif
          },
          py::arg("device"),
          py::arg("enable_timing") = false,
          py::arg("blocking") = false,
          py::arg("interprocess") = false)
      // __init__(device: str, device_id=-1, enable_timing=False,
      //          blocking=False, interprocess=False)
      // Same as the CustomPlace overload, but the place is built from a device
      // type name and id. A device_id of -1 is replaced by
      // phi::DeviceManager::GetDevice(device_type).
      .def(
          "__init__",
          [](phi::event::Event &self,
             const std::string &device_type,
             int device_id,
             bool enable_timing,
             bool blocking,
             bool interprocess) {
#ifndef PADDLE_WITH_CUSTOM_DEVICE
            PADDLE_THROW(platform::errors::Unavailable(
                "Paddle is not compiled with CustomDevice. "
                "Cannot visit CustomDeviceEvent."));
#else
            using EventFlag = phi::event::Event::Flag;
            uint32_t flag_bits = 0;
            if (!enable_timing) {
              flag_bits |= static_cast<uint32_t>(EventFlag::DisableTiming);
            }
            if (blocking) {
              flag_bits |= static_cast<uint32_t>(EventFlag::BlockingSync);
            }
            if (interprocess) {
              flag_bits |= static_cast<uint32_t>(EventFlag::Interprocess);
            }
            auto flag = static_cast<EventFlag>(flag_bits);

            new (&self) phi::event::Event();

            const auto resolved_id =
                device_id == -1 ? phi::DeviceManager::GetDevice(device_type)
                                : device_id;
            const phi::CustomPlace target(device_type, resolved_id);
            self.Init(target, flag);
#endif
          },
          py::arg("device"),
          py::arg("device_id") = -1,
          py::arg("enable_timing") = false,
          py::arg("blocking") = false,
          py::arg("interprocess") = false)
      // record(stream=None): records this event on `stream`.
      // The `stream` argument is declared with a nullptr default so that the
      // documented no-argument form `event.record()` is accepted; without the
      // default pybind11 rejects the call before the lambda runs.
      .def(
          "record",
          [](phi::event::Event &self, phi::stream::Stream *stream) {
#ifdef PADDLE_WITH_CUSTOM_DEVICE
            if (stream == nullptr) {
              // No stream given: use the stream reported by the context that
              // DeviceContextPool holds for this event's place.
              stream = static_cast<const phi::CustomContext *>(
                           paddle::platform::DeviceContextPool::Instance().Get(
                               self.GetPlace()))
                           ->GetStream()
                           .get();
            }
            self.Record(stream);
#else
            PADDLE_THROW(platform::errors::Unavailable(
                "Paddle is not compiled with CustomDevice. "
                "Cannot visit CustomDeviceEvent."));
#endif
          },
          R"DOC(
          Records the event in the given stream.

          Parameters:
            stream(CustomDeviceStream, optional): The handle of custom device stream. If None, the stream is the current stream. Default: None.

          Examples:
            .. code-block:: python

              # required: custom_device
              import paddle
              place = paddle.CustomPlace('custom_cpu', 0)
              event = paddle.device.custom.Event(place)
              event.record()

        )DOC",
          py::arg("stream") = nullptr)
      // query(): returns whatever Event::Query() reports for this event.
      // The docstring example constructs the event through
      // paddle.device.custom, matching every other example in this class.
      .def(
          "query",
          [](const phi::event::Event &self) {
#ifdef PADDLE_WITH_CUSTOM_DEVICE
            return self.Query();
#else
            PADDLE_THROW(platform::errors::Unavailable(
                "Paddle is not compiled with CustomDevice. "
                "Cannot visit CustomDeviceEvent."));
#endif
          },
          R"DOC(
          Queries the event's status.

          Returns: A boolean which indicates all work currently captured by the event has been completed.

          Examples:
            .. code-block:: python

                # required: custom_device
                import paddle
                place = paddle.CustomPlace('custom_cpu', 0)
                event = paddle.device.custom.Event(place)
                is_done = event.query()

           )DOC")
      // synchronize(): forwards to Event::Synchronize() and returns nothing.
      .def(
          "synchronize",
          [](const phi::event::Event &self) {
#ifndef PADDLE_WITH_CUSTOM_DEVICE
            PADDLE_THROW(platform::errors::Unavailable(
                "Paddle is not compiled with CustomDevice. "
                "Cannot visit CustomDeviceEvent."));
#else
            self.Synchronize();
#endif
          },
          R"DOC(
            Waits for an event to complete.

            Examples:
              .. code-block:: python

                # required: custom_device
                import paddle
                place = paddle.CustomPlace('custom_cpu', 0)
                event = paddle.device.custom.Event(place)
                event.synchronize()

           )DOC")
      // raw_event (read-only property): logs the value of Event::raw_event()
      // at VLOG level 10 and returns it reinterpreted as an integer
      // (std::uintptr_t).
      .def_property_readonly(
          "raw_event",
          [](const phi::event::Event &self) {
#ifndef PADDLE_WITH_CUSTOM_DEVICE
            PADDLE_THROW(platform::errors::Unavailable(
                "Paddle is not compiled with CustomDevice. "
                "Cannot visit CustomDeviceEvent."));
#else
            VLOG(10) << self.raw_event();
            const auto as_integer =
                reinterpret_cast<std::uintptr_t>(self.raw_event());
            return as_integer;
#endif
          },
          R"DOC(
      return the raw event of type CustomDeviceEvent as type int.

      Examples:
        .. code-block:: python

            # required: custom_device
            import paddle
            import ctypes
            place = paddle.CustomPlace('custom_cpu', 0)
            event = paddle.device.custom.Event(place)
            raw_event = event.raw_event
            print(raw_event)

            ptr = ctypes.c_void_p(raw_event)  # convert back to void*
            print(ptr)

           )DOC")
      // place (read-only property): returns a copy of the event's place,
      // viewed as a phi::CustomPlace.
      .def_property_readonly("place", [](const phi::event::Event &self) {
#ifndef PADDLE_WITH_CUSTOM_DEVICE
        PADDLE_THROW(platform::errors::Unavailable(
            "Paddle is not compiled with CustomDevice. "
            "Cannot visit CustomDeviceEvent."));
#else
        const auto &event_place = self.GetPlace();
        return reinterpret_cast<const phi::CustomPlace &>(event_place);
#endif
      });
}
}  // namespace pybind
}  // namespace paddle