diff --git a/paddle/fluid/memory/memcpy.cc b/paddle/fluid/memory/memcpy.cc
index f09cbfc3bef163552eaf3d190539bfc438296117..05f46dd39602385d2af9303e042f574fd26078ad 100644
--- a/paddle/fluid/memory/memcpy.cc
+++ b/paddle/fluid/memory/memcpy.cc
@@ -1442,6 +1442,28 @@ void Copy(phi::Place dst_place,
     return Copy(place_dst, dst, place_src, src, num);
   }
 #endif
+#ifdef PADDLE_WITH_CUSTOM_DEVICE
+  else if (src_place.GetType() == phi::AllocationType::CPU &&  // NOLINT
+           dst_place.GetType() == phi::AllocationType::CUSTOM) {
+    platform::CustomPlace place_dst(dst_place.GetDeviceType(),
+                                    dst_place.GetDeviceId());
+    platform::CPUPlace place_src;
+    return Copy(place_dst, dst, place_src, src, num, nullptr);
+  } else if (src_place.GetType() == phi::AllocationType::CUSTOM &&
+             dst_place.GetType() == phi::AllocationType::CPU) {
+    platform::CustomPlace place_src(src_place.GetDeviceType(),
+                                    src_place.GetDeviceId());
+    platform::CPUPlace place_dst;
+    return Copy(place_dst, dst, place_src, src, num, nullptr);
+  } else if (src_place.GetType() == phi::AllocationType::CUSTOM &&
+             dst_place.GetType() == phi::AllocationType::CUSTOM) {
+    platform::CustomPlace place_src(src_place.GetDeviceType(),
+                                    src_place.GetDeviceId());
+    platform::CustomPlace place_dst(dst_place.GetDeviceType(),
+                                    dst_place.GetDeviceId());
+    return Copy(place_dst, dst, place_src, src, num, nullptr);
+  }
+#endif
 }

 // NOTE: Only for (CPUPlace) -> (CPUPlace and PinnedPlace).
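Note on the hunk above: the three new branches route host-to-device, device-to-host, and device-to-device copies for pluggable backends through the existing CustomPlace overloads; the trailing nullptr passes no stream, so the copy is issued synchronously. A minimal sketch of what this enables — not part of the patch; the "custom_cpu" device type and device id 0 are illustrative values:

    #include "paddle/fluid/memory/memcpy.h"

    // Route a host -> plugin-device copy of `num` bytes through the
    // phi::Place-based Copy overload patched above. Assumes a custom
    // runtime named "custom_cpu" has been registered with Paddle.
    void CopyHostToCustom(void *dst, const void *src, size_t num) {
      phi::Place dst_place(phi::AllocationType::CUSTOM, 0, "custom_cpu");
      phi::Place src_place(phi::AllocationType::CPU);
      paddle::memory::Copy(dst_place, dst, src_place, src, num);
    }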
" + "The number on CustomDevice is %d, on CPU is %d", + custom_device.size(), + cpu.size())); + } + + std::vector custom_device_ptrs; + custom_device_ptrs.reserve(cpu.size()); + for (size_t i = 0; i < cpu.size(); ++i) { + custom_device[i].Resize(cpu[i].dims()); + custom_device[i].set_layout(cpu[i].layout()); + custom_device_ptrs.emplace_back( + custom_device[i].mutable_data(place_, cpu[i].type())); + } + + phi::DeviceManager::SetDevice(place_); + phi::DeviceManager::GetDeviceWithPlace(place_)->RecordEvent( + custom_device_events_[i].get(), custom_device_compute_stream_.get()); + phi::DeviceManager::GetDeviceWithPlace(place_)->StreamWaitEvent( + custom_device_stream_.get(), custom_device_events_[i].get()); + + platform::RecordEvent record_event("BufferedReader:MemoryCopy", + platform::TracerEventType::UserDefined, + 1); + for (size_t i = 0; i < cpu.size(); ++i) { + auto cpu_place = cpu[i].place(); + auto cpu_ptr = cpu[i].data(); + auto custom_device_ptr = custom_device_ptrs[i]; + auto size = + cpu[i].numel() * paddle::framework::DataTypeSize(cpu[i].dtype()); + if ((platform::is_custom_place(cpu_place))) { + memory::Copy(place_, custom_device_ptr, cpu_place, cpu_ptr, size); + custom_device_stream_->Synchronize(); + } else { + memory::Copy(place_, custom_device_ptr, cpu_place, cpu_ptr, size); + } + custom_device[i].set_lod(cpu[i].lod()); + } + custom_device_stream_->Synchronize(); + } +#endif return i; })); } @@ -449,6 +523,8 @@ void BufferedReader::ReadNextImpl(std::vector *out) { *out = std::move(mlu_buffer_[i]); } else if (platform::is_xpu_place(place_)) { *out = std::move(xpu_buffer_[i]); + } else if (platform::is_custom_place(place_)) { + *out = std::move(custom_device_buffer_[i]); } else { *out = std::move(cpu_buffer_[i]); } diff --git a/paddle/fluid/operators/reader/buffered_reader.h b/paddle/fluid/operators/reader/buffered_reader.h index 94c2fb12486bc00078a2501a496dc961029fbecd..06aaf4c12057da32febfaa4b1fc8d93bca5f5c0f 100644 --- a/paddle/fluid/operators/reader/buffered_reader.h +++ b/paddle/fluid/operators/reader/buffered_reader.h @@ -37,7 +37,10 @@ #include "paddle/fluid/platform/device/xpu/xpu_info.h" #include "paddle/fluid/platform/device/xpu/xpu_resource_pool.h" #endif - +#ifdef PADDLE_WITH_CUSTOM_DEVICE +#include "paddle/phi/backends/event.h" +#include "paddle/phi/backends/stream.h" +#endif namespace paddle { namespace operators { namespace reader { @@ -82,6 +85,7 @@ class BufferedReader : public framework::DecoratedReader { std::vector npu_buffer_; std::vector mlu_buffer_; std::vector xpu_buffer_; + std::vector custom_device_buffer_; size_t prev_pos_{-1UL}; #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) gpuStream_t compute_stream_; @@ -106,6 +110,12 @@ class BufferedReader : public framework::DecoratedReader { std::shared_ptr stream_; std::vector> events_; #endif + +#ifdef PADDLE_WITH_CUSTOM_DEVICE + std::shared_ptr custom_device_compute_stream_; + std::shared_ptr custom_device_stream_; + std::vector> custom_device_events_; +#endif }; } // namespace reader diff --git a/paddle/phi/backends/device_guard.h b/paddle/phi/backends/device_guard.h index eb14236d251b34a83e22d56bc7406b21cd5a84c5..668951f8a1c981acb8d16a9f2781f806156c1dc6 100644 --- a/paddle/phi/backends/device_guard.h +++ b/paddle/phi/backends/device_guard.h @@ -13,6 +13,8 @@ // limitations under the License. 
diff --git a/paddle/phi/backends/device_guard.h b/paddle/phi/backends/device_guard.h
index eb14236d251b34a83e22d56bc7406b21cd5a84c5..668951f8a1c981acb8d16a9f2781f806156c1dc6 100644
--- a/paddle/phi/backends/device_guard.h
+++ b/paddle/phi/backends/device_guard.h
@@ -13,6 +13,8 @@
 // limitations under the License.

 #pragma once
+#ifdef PADDLE_WITH_CUSTOM_DEVICE
+
 #include "paddle/phi/backends/device_manager.h"

 namespace phi {
@@ -44,3 +46,5 @@ class DeviceGuard {
 };

 }  // namespace phi
+
+#endif
diff --git a/paddle/phi/backends/device_manager.cc b/paddle/phi/backends/device_manager.cc
index 35339aed0f3e1cd87ac65855e0255fa3277a6bfb..ffaf42a0cf4e63a39fbcf3e7c4ccb8f558db030c 100644
--- a/paddle/phi/backends/device_manager.cc
+++ b/paddle/phi/backends/device_manager.cc
@@ -394,8 +394,10 @@ DeviceManager& DeviceManager::Instance() {
 }

 void DeviceManager::Clear() {
-  Instance().device_map_.clear();
-  Instance().device_impl_map_.clear();
+  // TODO(wangran16): fix coredump when using npu plugin
+
+  // Instance().device_map_.clear();
+  // Instance().device_impl_map_.clear();
 }

 std::vector<std::string> ListAllLibraries(const std::string& library_dir) {
diff --git a/paddle/phi/backends/event.cc b/paddle/phi/backends/event.cc
index 43077d280f360aa65baec850a4f0fac31cf53a4d..b594d919abc18663c4b81a7423bcca50e2a91d1c 100644
--- a/paddle/phi/backends/event.cc
+++ b/paddle/phi/backends/event.cc
@@ -35,7 +35,11 @@ Event::~Event() { Destroy(); }

 bool Event::Init(const Place& place, Flag flags) {
   place_ = place;
-  DeviceGuard guard(place_);
+  device_ = phi::DeviceManager::GetDeviceWithPlace(place);
+
+  // note(wangran16): bind device to the current thread. fix npu plugin null
+  // context bug.
+  phi::DeviceManager::SetDevice(place_);
   device_->CreateEvent(this, flags);
   VLOG(3) << "Init Event: " << event_ << ", place: " << place_
           << ", flag:" << static_cast<int>(flags);
@@ -45,7 +49,7 @@ bool Event::Init(const Place& place, Flag flags) {

 void Event::Destroy() {
   if (own_data_) {
-    DeviceGuard guard(place_);
+    phi::DeviceManager::SetDevice(place_);
     device_->DestroyEvent(this);
     own_data_ = false;
   }
diff --git a/paddle/phi/backends/event.h b/paddle/phi/backends/event.h
index 0866adcf39afa6259ccc424d93eb086427b39679..8de223528f8fdffe16809c8744a56ff1cc71824d 100644
--- a/paddle/phi/backends/event.h
+++ b/paddle/phi/backends/event.h
@@ -36,6 +36,7 @@ class Event {
     Interprocess = 0x4,
   };

+  Event() = default;
   // For compatible
   Event(const Place& place, event_t event);
   ~Event();
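The stream.cc diff below makes the same substitution already applied in event.cc above: the RAII DeviceGuard is replaced by a bare SetDevice. The practical difference, sketched for clarity (an illustration, assuming DeviceGuard restores the previously bound device in its destructor, as RAII guards conventionally do):

    {
      phi::DeviceGuard guard(place);         // binds `place`, remembers the old device
      // ... create the event/stream ...
    }                                        // destructor rebinds the old device

    phi::DeviceManager::SetDevice(place);    // binds `place` to this thread and leaves
                                             // it bound; per the notes in the patch,
                                             // the NPU plugin needs the binding to
                                             // outlive initialization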
diff --git a/paddle/phi/backends/stream.cc b/paddle/phi/backends/stream.cc
index f8b15bdbd9e633ff49694664194d8abc6a17744e..bad57c5238ec834364a9e4a16f1398fc51cdf905 100644
--- a/paddle/phi/backends/stream.cc
+++ b/paddle/phi/backends/stream.cc
@@ -40,7 +40,10 @@ bool Stream::Init(const Place& place,
                   const Flag& flag) {
   place_ = place;
   device_ = phi::DeviceManager::GetDeviceWithPlace(place);
-  DeviceGuard guard(place_);
+
+  // note(wangran16): bind device to the current thread. fix npu plugin null
+  // context bug.
+  phi::DeviceManager::SetDevice(place_);
   device_->CreateStream(this, priority, flag);

   callback_manager_.reset(new CallbackManager(this));
@@ -80,7 +83,7 @@ void Stream::WaitCallback() const { callback_manager_->Wait(); }

 void Stream::Destroy() {
   if (own_data_) {
-    DeviceGuard guard(place_);
+    phi::DeviceManager::SetDevice(place_);
     device_->DestroyStream(this);
     own_data_ = false;
   }
diff --git a/python/paddle/fluid/tests/CMakeLists.txt b/python/paddle/fluid/tests/CMakeLists.txt
index 6acee6dc11c89aa47bfe25248c5369f4e1452d19..92e29202b28b876fb058ffd9dfe1f28b1d0247a0 100644
--- a/python/paddle/fluid/tests/CMakeLists.txt
+++ b/python/paddle/fluid/tests/CMakeLists.txt
@@ -12,5 +12,6 @@ add_subdirectory(unittests)
 add_subdirectory(book)
 add_subdirectory(custom_op)
 add_subdirectory(custom_kernel)
+add_subdirectory(custom_runtime)

 set_tests_properties(test_beam_search_decoder PROPERTIES TIMEOUT 120)
diff --git a/python/paddle/fluid/tests/custom_runtime/CMakeLists.txt b/python/paddle/fluid/tests/custom_runtime/CMakeLists.txt
new file mode 100644
index 0000000000000000000000000000000000000000..acd441c867787cd3592c25ef8bf6fc3fc13db185
--- /dev/null
+++ b/python/paddle/fluid/tests/custom_runtime/CMakeLists.txt
@@ -0,0 +1,3 @@
+if(WITH_CUSTOM_DEVICE)
+  py_test(test_custom_device_data_loader SRCS test_custom_device_data_loader.py)
+endif()
diff --git a/python/paddle/fluid/tests/custom_runtime/__init__.py b/python/paddle/fluid/tests/custom_runtime/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..97043fd7ba6885aac81cad5a49924c23c67d4d47
--- /dev/null
+++ b/python/paddle/fluid/tests/custom_runtime/__init__.py
@@ -0,0 +1,13 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
diff --git a/python/paddle/fluid/tests/custom_runtime/custom_cpu_runtime.cc b/python/paddle/fluid/tests/custom_runtime/custom_cpu_runtime.cc
new file mode 100644
index 0000000000000000000000000000000000000000..18762625c0fe27964174aa5bb2e66830736fd899
--- /dev/null
+++ b/python/paddle/fluid/tests/custom_runtime/custom_cpu_runtime.cc
@@ -0,0 +1,215 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <cstdio>
+#include <cstdlib>
+#include <cstring>
+#include <iostream>
+
+#include "paddle/phi/backends/device_ext.h"
+
+#define MEMORY_FRACTION 0.5f
+
+C_Status Init() { return C_SUCCESS; }
+
+C_Status InitDevice(const C_Device device) { return C_SUCCESS; }
+
+C_Status SetDevice(const C_Device device) { return C_SUCCESS; }
+
+C_Status GetDevice(const C_Device device) {
+  device->id = 0;
+  return C_SUCCESS;
+}
+
+C_Status DestroyDevice(const C_Device device) { return C_SUCCESS; }
+
+C_Status Finalize() { return C_SUCCESS; }
+
+C_Status GetDevicesCount(size_t *count) {
+  *count = 1;
+  return C_SUCCESS;
+}
+
+C_Status GetDevicesList(size_t *devices) {
+  devices[0] = 0;
+  return C_SUCCESS;
+}
+
+C_Status MemCpy(const C_Device device,
+                void *dst,
+                const void *src,
+                size_t size) {
+  memcpy(dst, src, size);
+  return C_SUCCESS;
+}
+
+C_Status AsyncMemCpy(const C_Device device,
+                     C_Stream stream,
+                     void *dst,
+                     const void *src,
+                     size_t size) {
+  memcpy(dst, src, size);
+  return C_SUCCESS;
+}
+
+C_Status MemCpyP2P(const C_Device dst_device,
+                   const C_Device src_device,
+                   void *dst,
+                   const void *src,
+                   size_t size) {
+  memcpy(dst, src, size);
+  return C_SUCCESS;
+}
+
+C_Status AsyncMemCpyP2P(const C_Device dst_device,
+                        const C_Device src_device,
+                        C_Stream stream,
+                        void *dst,
+                        const void *src,
+                        size_t size) {
+  memcpy(dst, src, size);
+  return C_SUCCESS;
+}
+
+C_Status Allocate(const C_Device device, void **ptr, size_t size) {
+  auto data = malloc(size);
+  if (data) {
+    *ptr = data;
+    return C_SUCCESS;
+  } else {
+    *ptr = nullptr;
+  }
+  return C_FAILED;
+}
+
+C_Status Deallocate(const C_Device device, void *ptr, size_t size) {
+  free(ptr);
+  return C_SUCCESS;
+}
+
+C_Status CreateStream(const C_Device device, C_Stream *stream) {
+  *stream = nullptr;
+  return C_SUCCESS;
+}
+
+C_Status DestroyStream(const C_Device device, C_Stream stream) {
+  return C_SUCCESS;
+}
+
+C_Status CreateEvent(const C_Device device, C_Event *event) {
+  return C_SUCCESS;
+}
+
+C_Status RecordEvent(const C_Device device, C_Stream stream, C_Event event) {
+  return C_SUCCESS;
+}
+
+C_Status DestroyEvent(const C_Device device, C_Event event) {
+  return C_SUCCESS;
+}
+
+C_Status SyncDevice(const C_Device device) { return C_SUCCESS; }
+
+C_Status SyncStream(const C_Device device, C_Stream stream) {
+  return C_SUCCESS;
+}
+
+C_Status SyncEvent(const C_Device device, C_Event event) { return C_SUCCESS; }
+
+C_Status StreamWaitEvent(const C_Device device,
+                         C_Stream stream,
+                         C_Event event) {
+  return C_SUCCESS;
+}
+
+C_Status VisibleDevices(size_t *devices) { return C_SUCCESS; }
+
+C_Status DeviceMemStats(const C_Device device,
+                        size_t *total_memory,
+                        size_t *free_memory) {
+  FILE *fp;
+  char buffer[1024];
+  size_t byte_read;
+  char *pos;
+
+  fp = fopen("/proc/meminfo", "r");
+  byte_read = fread(buffer, 1, sizeof(buffer) - 1, fp);
+  fclose(fp);
+  buffer[byte_read] = '\0';
+  pos = strstr(buffer, "MemTotal:");
+  sscanf(pos, "MemTotal: %lu kB", total_memory);
+  pos = strstr(pos, "MemFree:");
+  sscanf(pos, "MemFree: %lu kB", free_memory);
+  *total_memory = *total_memory * 1024;
+  *free_memory = *free_memory * 1024;
+  *free_memory = *free_memory * MEMORY_FRACTION;
+
+  return C_SUCCESS;
+}
+
+C_Status DeviceMinChunkSize(const C_Device device, size_t *size) {
+  *size = 512;
+  return C_SUCCESS;
+}
+
+void InitPlugin(CustomRuntimeParams *params) {
+  PADDLE_CUSTOM_RUNTIME_CHECK_VERSION(params);
+  params->device_type = "custom_cpu";
+  params->sub_device_type = "v0.1";
+
+  memset(reinterpret_cast<void *>(params->interface),
+         0,
+         sizeof(C_DeviceInterface));
+
+  params->interface->initialize = Init;
+  params->interface->finalize = Finalize;
+
+  params->interface->init_device = InitDevice;
+  params->interface->set_device = SetDevice;
+  params->interface->get_device = GetDevice;
+  params->interface->deinit_device = DestroyDevice;
+
+  params->interface->create_stream = CreateStream;
+  params->interface->destroy_stream = DestroyStream;
+
+  params->interface->create_event = CreateEvent;
+  params->interface->destroy_event = DestroyEvent;
+  params->interface->record_event = RecordEvent;
+
+  params->interface->synchronize_device = SyncDevice;
+  params->interface->synchronize_stream = SyncStream;
+  params->interface->synchronize_event = SyncEvent;
+  params->interface->stream_wait_event = StreamWaitEvent;
+
+  params->interface->memory_copy_h2d = MemCpy;
+  params->interface->memory_copy_d2d = MemCpy;
+  params->interface->memory_copy_d2h = MemCpy;
+  params->interface->memory_copy_p2p = MemCpyP2P;
+  params->interface->async_memory_copy_h2d = AsyncMemCpy;
+  params->interface->async_memory_copy_d2d = AsyncMemCpy;
+  params->interface->async_memory_copy_d2h = AsyncMemCpy;
+  params->interface->async_memory_copy_p2p = AsyncMemCpyP2P;
+  params->interface->device_memory_allocate = Allocate;
+  params->interface->host_memory_allocate = Allocate;
+  params->interface->unified_memory_allocate = Allocate;
+  params->interface->device_memory_deallocate = Deallocate;
+  params->interface->host_memory_deallocate = Deallocate;
+  params->interface->unified_memory_deallocate = Deallocate;
+
+  params->interface->get_device_count = GetDevicesCount;
+  params->interface->get_device_list = GetDevicesList;
+  params->interface->device_memory_stats = DeviceMemStats;
+  params->interface->device_min_chunk_size = DeviceMinChunkSize;
+}
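The stub above satisfies the C runtime interface with synchronous host memcpy calls, which is all a CPU-backed test device needs. On real hardware the async hooks would enqueue work on the given stream instead of blocking; a sketch, where vendorMemcpyAsync and vendorStream_t stand in for a hypothetical vendor API:

    C_Status AsyncMemCpy(const C_Device device, C_Stream stream,
                         void *dst, const void *src, size_t size) {
      // Enqueue the copy on the plugin-owned stream; completion is then
      // observed through RecordEvent / StreamWaitEvent / SyncStream rather
      // than by blocking here. vendorMemcpyAsync is illustrative only.
      vendorMemcpyAsync(dst, src, size,
                        reinterpret_cast<vendorStream_t>(stream));
      return C_SUCCESS;
    }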
diff --git a/python/paddle/fluid/tests/custom_runtime/custom_cpu_setup.py b/python/paddle/fluid/tests/custom_runtime/custom_cpu_setup.py
new file mode 100644
index 0000000000000000000000000000000000000000..82accb2ad00df4b3ba3727c8e34656a234ce2b14
--- /dev/null
+++ b/python/paddle/fluid/tests/custom_runtime/custom_cpu_setup.py
@@ -0,0 +1,82 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import site
+from paddle.fluid import core
+from distutils.sysconfig import get_python_lib
+from distutils.core import setup, Extension
+from setuptools.command.build_ext import build_ext
+
+
+# refer: https://note.qidong.name/2018/03/setup-warning-strict-prototypes
+# Avoid a gcc warning below:
+# cc1plus: warning: command line option ‘-Wstrict-prototypes’ is valid
+# for C/ObjC but not for C++
+class BuildExt(build_ext):
+
+    def build_extensions(self):
+        if '-Wstrict-prototypes' in self.compiler.compiler_so:
+            self.compiler.compiler_so.remove('-Wstrict-prototypes')
+        super(BuildExt, self).build_extensions()
+
+
+# cc flags
+paddle_extra_compile_args = [
+    '-std=c++14',
+    '-shared',
+    '-fPIC',
+    '-Wno-parentheses',
+    '-DPADDLE_WITH_CUSTOM_KERNEL',
+    '-DPADDLE_WITH_CUSTOM_DEVICE',
+]
+if core.is_compiled_with_npu():
+    paddle_extra_compile_args += ['-D_GLIBCXX_USE_CXX11_ABI=0']
+
+# include path
+site_packages_path = site.getsitepackages()
+include_dirs = list(
+    map(lambda path: os.path.join(path, 'paddle', 'include'),
+        site_packages_path))
+
+# include path third_party
+compile_third_party_path = os.path.join(os.environ['PADDLE_ROOT'],
+                                        'build/third_party')
+include_dirs += [
+    os.path.join(compile_third_party_path, 'boost/src/extern_boost'),  # boost
+    os.path.join(compile_third_party_path, 'install/gflags/include'),  # gflags
+    os.path.join(compile_third_party_path, 'install/glog/include'),  # glog
+]
+
+# libs path
+library_dirs = list(
+    map(lambda path: os.path.join(path, 'paddle', 'fluid'),
+        site_packages_path))
+
+# libs
+libs = [':core_avx.so']
+if not core.has_avx_core and core.has_noavx_core:
+    libs = [':core_noavx.so']
+
+custom_cpu_plugin_so = Extension('custom_cpu_runtime',
+                                 sources=['custom_cpu_runtime.cc'],
+                                 include_dirs=include_dirs,
+                                 library_dirs=library_dirs,
+                                 libraries=libs,
+                                 extra_compile_args=paddle_extra_compile_args)
+
+setup(name='custom_kernel_dot',
+      version='1.0',
+      description='custom kernel for compiling',
+      cmdclass={'build_ext': BuildExt},
+      ext_modules=[custom_cpu_plugin_so])
diff --git a/python/paddle/fluid/tests/custom_runtime/test_custom_device_data_loader.py b/python/paddle/fluid/tests/custom_runtime/test_custom_device_data_loader.py
new file mode 100644
index 0000000000000000000000000000000000000000..775c3f487d596b974647d63f16d3e1c76afdb64b
--- /dev/null
+++ b/python/paddle/fluid/tests/custom_runtime/test_custom_device_data_loader.py
@@ -0,0 +1,66 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import sys
+import site
+import unittest
+import numpy as np
+
+
+class TestCustomDeviceDataLoader(unittest.TestCase):
+
+    def setUp(self):
+        # compile so and set to current path
+        cur_dir = os.path.dirname(os.path.abspath(__file__))
+
+        # --inplace to place output so file to current dir
+        cmd = 'cd {} && {} custom_cpu_setup.py build_ext --inplace'.format(
+            cur_dir, sys.executable)
+        os.system(cmd)
+
+        # set environment for loading and registering compiled custom kernels
+        # only valid in current process
+        os.environ['CUSTOM_DEVICE_ROOT'] = cur_dir
+
+    def test_custom_device_dataloader(self):
+        import paddle
+
+        paddle.set_device('custom_cpu')
+        dataset = paddle.vision.datasets.MNIST(
+            mode='test',
+            transform=paddle.vision.transforms.Compose([
+                paddle.vision.transforms.CenterCrop(20),
+                paddle.vision.transforms.RandomResizedCrop(14),
+                paddle.vision.transforms.Normalize(),
+                paddle.vision.transforms.ToTensor()
+            ]))
+        loader = paddle.io.DataLoader(dataset,
+                                      batch_size=32,
+                                      num_workers=1,
+                                      shuffle=True)
+        for image, label in loader:
+            self.assertTrue(image.place.is_custom_place())
+            self.assertTrue(label.place.is_custom_place())
+            break
+
+    def tearDown(self):
+        del os.environ['CUSTOM_DEVICE_ROOT']
+
+
+if __name__ == '__main__':
+    if os.name == 'nt' or sys.platform.startswith('darwin'):
+        # only support Linux now
+        exit()
+    unittest.main()