未验证 提交 684de4b3 编写于 作者: 石晓伟 提交者: GitHub

add pten allocation places, test=develop (#37369)

上级 9ec1432d
...@@ -3,6 +3,9 @@ ...@@ -3,6 +3,9 @@
# float16.h/complex.h/bfloat16.h into pten # float16.h/complex.h/bfloat16.h into pten
include_directories(${PADDLE_SOURCE_DIR}/paddle/fluid/platform) include_directories(${PADDLE_SOURCE_DIR}/paddle/fluid/platform)
# paddle experimental common components
add_subdirectory(common)
# pten (low level) api headers: include # pten (low level) api headers: include
# pten (high level) api # pten (high level) api
add_subdirectory(api) add_subdirectory(api)
......
cc_library(pten_api_utils SRCS allocator.cc storage.cc tensor_utils.cc DEPS tensor_base convert_utils dense_tensor lod_tensor selected_rows place var_type_traits) cc_library(pten_api_utils SRCS allocator.cc storage.cc tensor_utils.cc place_utils.cc DEPS
tensor_base convert_utils dense_tensor lod_tensor selected_rows place var_type_traits pten_common)
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/pten/api/lib/utils/place_utils.h"
#include "paddle/pten/api/ext/exception.h"
namespace paddle {
namespace experimental {
Place ConvertToPtenPlace(const platform::Place& src) {
Place place;
if (platform::is_cpu_place(src)) {
place.Reset(Device(DeviceType::kHost, 0));
} else if (platform::is_gpu_place(src)) {
place.Reset(
Device(DeviceType::kCuda,
BOOST_GET_CONST(platform::CUDAPlace, src).GetDeviceId()));
} else if (platform::is_cuda_pinned_place(src)) {
place.Reset(Device(DeviceType::kCuda, 0), true);
} else if (platform::is_xpu_place(src)) {
place.Reset(Device(DeviceType::kXpu,
BOOST_GET_CONST(platform::XPUPlace, src).GetDeviceId()));
} else {
PD_THROW("Invalid platform place type.");
}
return place;
}
platform::Place ConvertToPlatformPlace(const Place& src) {
switch (src.device().type()) {
case DeviceType::kHost: {
return platform::CPUPlace();
}
case DeviceType::kCuda: {
if (src.is_pinned()) {
return platform::CUDAPinnedPlace();
} else {
return platform::CUDAPlace(src.device().id());
}
}
case DeviceType::kXpu: {
return platform::XPUPlace(src.device().id());
}
default:
PD_THROW("Invalid pten place type.");
}
return {};
}
} // namespace experimental
} // namespace paddle
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/fluid/platform/place.h"
#include "paddle/pten/common/place.h"
namespace paddle {
namespace experimental {
Place ConvertToPtenPlace(const platform::Place& src);
platform::Place ConvertToPlatformPlace(const Place& src);
} // namespace experimental
} // namespace paddle
cc_library(pten_common SRCS device.cc place.cc DEPS enforce)
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/pten/common/device.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/pten/api/ext/exception.h"
namespace paddle {
namespace experimental {
const char* DeviceTypeStr(DeviceType type) {
switch (type) {
case DeviceType::kUndef:
return "kUndef";
case DeviceType::kHost:
return "kUndef";
case DeviceType::kXpu:
return "kXpu";
case DeviceType::kCuda:
return "kCuda";
case DeviceType::kHip:
return "kHip";
case DeviceType::kNpu:
return "kNpu";
default:
PD_THROW("Invalid pten device type.");
}
return {};
}
Device::Device(DeviceType type, int8_t id) : type_(type), id_(id) {
PADDLE_ENFORCE_GE(
id,
0,
platform::errors::InvalidArgument(
"The device id needs to start from zero, but you passed in %d.", id));
}
Device::Device(DeviceType type) : type_(type), id_(0) {
PADDLE_ENFORCE_EQ(
type,
DeviceType::kHost,
platform::errors::InvalidArgument(
"The device id needs to start from zero, but you passed in %s.",
DeviceTypeStr(type)));
}
std::string Device::DebugString() const {
std::string str{"DeviceType:"};
return str + DeviceTypeStr(type_) + ", id: " + std::to_string(id_);
}
} // namespace experimental
} // namespace paddle
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <cstdint>
#include <string>
namespace paddle {
namespace experimental {
enum class DeviceType : int8_t {
kUndef = 0,
kHost = 1,
kXpu = 2,
kCuda = 3,
kHip = 4,
kNpu = 5,
};
const char* DeviceTypeStr(DeviceType type);
/// \brief The device is used to store hardware information. It has not yet
/// stored information related to the math acceleration library.
struct Device final {
public:
Device() = default;
Device(DeviceType type, int8_t id);
Device(DeviceType type);
DeviceType type() const noexcept { return type_; }
/// \brief Returns the index of the device. Here, -1 is used to indicate an
/// invalid value, and 0 to indicate a default value.
/// \return The index of the device.
int8_t id() const noexcept { return id_; }
void set_type(DeviceType type) noexcept { type_ = type; }
void set_id(int8_t id) noexcept { id_ = id; }
std::string DebugString() const;
private:
friend bool operator==(const Device&, const Device&) noexcept;
private:
DeviceType type_{DeviceType::kUndef};
int8_t id_{-1};
};
inline bool operator==(const Device& lhs, const Device& rhs) noexcept {
return (lhs.type_ == rhs.type_) && (lhs.id_ == rhs.id_);
}
} // namespace experimental
} // namespace paddle
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/pten/common/place.h"
#include "paddle/fluid/platform/enforce.h"
namespace paddle {
namespace experimental {
std::string Place::DebugString() const {
return device_.DebugString() + ", is_pinned: " + std::to_string(is_pinned_);
}
} // namespace experimental
} // namespace paddle
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <string>
#include "paddle/pten/common/device.h"
namespace paddle {
namespace experimental {
/// \brief The place is used to specify where the data is stored.
class Place final {
public:
Place() = default;
explicit Place(const Device& device) : device_(device) {}
Place(DeviceType type, int8_t id) : device_(type, id) {}
Place(DeviceType type) : device_(type) {}
Place(const Device& device, bool is_pinned) noexcept : device_(device),
is_pinned_(is_pinned) {
}
const Device& device() const noexcept { return device_; }
/// \brief Returns whether the memory is a locked page. The page lock
/// memory is actually located in the host memory, but it can only be
/// used by certain devices and can be directly transferred by DMA.
/// \return Whether the memory is a locked page.
bool is_pinned() const noexcept { return is_pinned_; }
void Reset(const Device& device, bool is_pinned = false) noexcept {
device_ = device;
is_pinned_ = is_pinned;
}
std::string DebugString() const;
private:
friend bool operator==(const Place&, const Place&) noexcept;
private:
Device device_;
bool is_pinned_{false};
};
inline bool operator==(const Place& lhs, const Place& rhs) noexcept {
return (lhs.device_ == rhs.device_) && (lhs.is_pinned_ == rhs.is_pinned_);
}
} // namespace experimental
} // namespace paddle
...@@ -7,6 +7,8 @@ endif() ...@@ -7,6 +7,8 @@ endif()
cc_test(test_pten_exception SRCS test_pten_exception.cc DEPS gtest) cc_test(test_pten_exception SRCS test_pten_exception.cc DEPS gtest)
cc_test(test_framework_storage SRCS test_storage.cc DEPS pten_api_utils) cc_test(test_framework_storage SRCS test_storage.cc DEPS pten_api_utils)
cc_test(test_framework_tensor_utils SRCS test_tensor_utils.cc DEPS pten_api_utils) cc_test(test_framework_tensor_utils SRCS test_tensor_utils.cc DEPS pten_api_utils)
cc_test(test_framework_place_utils storage SRCS test_place_utils.cc DEPS pten_api_utils)
cc_test(test_mean_api SRCS test_mean_api.cc DEPS pten_tensor pten_api pten_api_utils) cc_test(test_mean_api SRCS test_mean_api.cc DEPS pten_tensor pten_api pten_api_utils)
cc_test(test_dot_api SRCS test_dot_api.cc DEPS pten_tensor pten_api pten_api_utils) cc_test(test_dot_api SRCS test_dot_api.cc DEPS pten_tensor pten_api pten_api_utils)
cc_test(test_matmul_api SRCS test_matmul_api.cc DEPS pten_tensor pten_api pten_api_utils) cc_test(test_matmul_api SRCS test_matmul_api.cc DEPS pten_tensor pten_api pten_api_utils)
......
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "gtest/gtest.h"
#include "paddle/pten/api/lib/utils/place_utils.h"
namespace paddle {
namespace experimental {
namespace tests {
TEST(place_utils, cpu_place) {
auto pd_place = platform::CPUPlace();
Place pten_place = ConvertToPtenPlace(pd_place);
CHECK_EQ(pten_place.device().id(), 0);
CHECK(pten_place.device().type() == DeviceType::kHost);
CHECK(pten_place.is_pinned() == false);
auto pd_place_1 = ConvertToPlatformPlace(pten_place);
CHECK(platform::is_cpu_place(pd_place_1));
CHECK(pd_place == BOOST_GET_CONST(platform::CPUPlace, pd_place_1));
CHECK(pten_place == ConvertToPtenPlace(pd_place_1));
}
TEST(place_utils, cuda_place) {
auto pd_place = platform::CUDAPlace(1);
Place pten_place = ConvertToPtenPlace(pd_place);
CHECK_EQ(pten_place.device().id(), 1);
CHECK(pten_place.device().type() == DeviceType::kCuda);
CHECK(pten_place.is_pinned() == false);
auto pd_place_1 = ConvertToPlatformPlace(pten_place);
CHECK(platform::is_gpu_place(pd_place_1));
CHECK(pd_place == BOOST_GET_CONST(platform::CUDAPlace, pd_place_1));
CHECK(pten_place == ConvertToPtenPlace(pd_place_1));
}
TEST(place_utils, cuda_pinned_place) {
auto pd_place = platform::CUDAPinnedPlace();
Place pten_place = ConvertToPtenPlace(pd_place);
CHECK_EQ(pten_place.device().id(), 0);
CHECK(pten_place.device().type() == DeviceType::kCuda);
CHECK(pten_place.is_pinned() == true);
auto pd_place_1 = ConvertToPlatformPlace(pten_place);
CHECK(platform::is_cuda_pinned_place(pd_place_1));
CHECK(pd_place == BOOST_GET_CONST(platform::CUDAPinnedPlace, pd_place_1));
CHECK(pten_place == ConvertToPtenPlace(pd_place_1));
}
TEST(place_utils, xpu_place) {
auto pd_place = platform::XPUPlace(1);
Place pten_place = ConvertToPtenPlace(pd_place);
CHECK_EQ(pten_place.device().id(), 1);
CHECK(pten_place.device().type() == DeviceType::kXpu);
CHECK(pten_place.is_pinned() == false);
auto pd_place_1 = ConvertToPlatformPlace(pten_place);
CHECK(platform::is_xpu_place(pd_place_1));
CHECK(pd_place == BOOST_GET_CONST(platform::XPUPlace, pd_place_1));
CHECK(pten_place == ConvertToPtenPlace(pd_place_1));
}
} // namespace tests
} // namespace experimental
} // namespace paddle
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册