Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
机器未来
Paddle
提交
b53cdc9e
P
Paddle
项目概览
机器未来
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
b53cdc9e
编写于
3月 03, 2022
作者:
P
phlrain
浏览文件
操作
浏览文件
下载
差异文件
Merge branch 'develop' of
https://github.com/PaddlePaddle/Paddle
into move_yolo_box_to_phi
上级
9e00395a
97ccaa79
变更
66
隐藏空白更改
内联
并排
Showing
66 changed file
with
682 addition
and
724 deletion
+682
-724
paddle/fluid/eager/auto_code_generator/final_state_generator/eager_gen.py
...er/auto_code_generator/final_state_generator/eager_gen.py
+3
-1
paddle/fluid/eager/auto_code_generator/final_state_generator/python_c_gen.py
...auto_code_generator/final_state_generator/python_c_gen.py
+6
-18
paddle/fluid/framework/CMakeLists.txt
paddle/fluid/framework/CMakeLists.txt
+1
-2
paddle/fluid/framework/custom_kernel.cc
paddle/fluid/framework/custom_kernel.cc
+0
-47
paddle/fluid/framework/custom_kernel.h
paddle/fluid/framework/custom_kernel.h
+0
-26
paddle/fluid/framework/garbage_collector.cc
paddle/fluid/framework/garbage_collector.cc
+5
-5
paddle/fluid/framework/garbage_collector.h
paddle/fluid/framework/garbage_collector.h
+3
-3
paddle/fluid/framework/operator.cc
paddle/fluid/framework/operator.cc
+1
-1
paddle/fluid/imperative/tracer.cc
paddle/fluid/imperative/tracer.cc
+1
-1
paddle/fluid/inference/api/CMakeLists.txt
paddle/fluid/inference/api/CMakeLists.txt
+1
-1
paddle/fluid/inference/tensorrt/engine.h
paddle/fluid/inference/tensorrt/engine.h
+2
-0
paddle/fluid/memory/allocation/allocator_facade.cc
paddle/fluid/memory/allocation/allocator_facade.cc
+7
-8
paddle/fluid/memory/allocation/custom_allocator.cc
paddle/fluid/memory/allocation/custom_allocator.cc
+3
-4
paddle/fluid/memory/allocation/naive_best_fit_allocator.cc
paddle/fluid/memory/allocation/naive_best_fit_allocator.cc
+8
-9
paddle/fluid/memory/detail/buddy_allocator.cc
paddle/fluid/memory/detail/buddy_allocator.cc
+2
-2
paddle/fluid/memory/detail/system_allocator.cc
paddle/fluid/memory/detail/system_allocator.cc
+3
-3
paddle/fluid/memory/memcpy.cc
paddle/fluid/memory/memcpy.cc
+10
-10
paddle/fluid/operators/detection/multiclass_nms_op.cc
paddle/fluid/operators/detection/multiclass_nms_op.cc
+2
-3
paddle/fluid/operators/elementwise/elementwise_op_npu_test.cc
...le/fluid/operators/elementwise/elementwise_op_npu_test.cc
+1
-1
paddle/fluid/operators/elementwise/elementwise_pow_op_xpu.cc
paddle/fluid/operators/elementwise/elementwise_pow_op_xpu.cc
+0
-1
paddle/fluid/operators/elementwise/elementwise_sub_op.cc
paddle/fluid/operators/elementwise/elementwise_sub_op.cc
+7
-48
paddle/fluid/operators/elementwise/elementwise_sub_op.cu
paddle/fluid/operators/elementwise/elementwise_sub_op.cu
+0
-63
paddle/fluid/operators/elementwise/elementwise_sub_op.h
paddle/fluid/operators/elementwise/elementwise_sub_op.h
+0
-96
paddle/fluid/operators/elementwise/elementwise_sub_op_npu.cc
paddle/fluid/operators/elementwise/elementwise_sub_op_npu.cc
+1
-1
paddle/fluid/operators/elementwise/elementwise_sub_op_xpu.cc
paddle/fluid/operators/elementwise/elementwise_sub_op_xpu.cc
+0
-1
paddle/fluid/operators/tensorrt/tensorrt_engine_op.h
paddle/fluid/operators/tensorrt/tensorrt_engine_op.h
+31
-5
paddle/fluid/platform/CMakeLists.txt
paddle/fluid/platform/CMakeLists.txt
+1
-1
paddle/fluid/platform/device/CMakeLists.txt
paddle/fluid/platform/device/CMakeLists.txt
+0
-20
paddle/fluid/platform/device/custom/CMakeLists.txt
paddle/fluid/platform/device/custom/CMakeLists.txt
+0
-4
paddle/fluid/platform/device/custom/enforce_custom.h
paddle/fluid/platform/device/custom/enforce_custom.h
+4
-1
paddle/fluid/platform/device/device_wrapper.h
paddle/fluid/platform/device/device_wrapper.h
+5
-5
paddle/fluid/platform/device_context.cc
paddle/fluid/platform/device_context.cc
+1
-1
paddle/fluid/platform/device_context.h
paddle/fluid/platform/device_context.h
+3
-3
paddle/fluid/platform/init.cc
paddle/fluid/platform/init.cc
+6
-6
paddle/fluid/pybind/eager_utils.cc
paddle/fluid/pybind/eager_utils.cc
+92
-8
paddle/fluid/pybind/eager_utils.h
paddle/fluid/pybind/eager_utils.h
+10
-0
paddle/fluid/pybind/pybind.cc
paddle/fluid/pybind/pybind.cc
+7
-7
paddle/fluid/pybind/tensor_py.h
paddle/fluid/pybind/tensor_py.h
+2
-2
paddle/phi/backends/CMakeLists.txt
paddle/phi/backends/CMakeLists.txt
+7
-0
paddle/phi/backends/callback_manager.cc
paddle/phi/backends/callback_manager.cc
+5
-7
paddle/phi/backends/callback_manager.h
paddle/phi/backends/callback_manager.h
+2
-4
paddle/phi/backends/custom/CMakeLists.txt
paddle/phi/backends/custom/CMakeLists.txt
+2
-0
paddle/phi/backends/custom/custom_context.cc
paddle/phi/backends/custom/custom_context.cc
+5
-5
paddle/phi/backends/custom/custom_device.cc
paddle/phi/backends/custom/custom_device.cc
+103
-64
paddle/phi/backends/custom/custom_device_test.cc
paddle/phi/backends/custom/custom_device_test.cc
+14
-16
paddle/phi/backends/custom/fake_cpu_device.h
paddle/phi/backends/custom/fake_cpu_device.h
+15
-7
paddle/phi/backends/device_base.cc
paddle/phi/backends/device_base.cc
+54
-34
paddle/phi/backends/device_base.h
paddle/phi/backends/device_base.h
+28
-16
paddle/phi/backends/device_ext.h
paddle/phi/backends/device_ext.h
+62
-27
paddle/phi/backends/device_guard.cc
paddle/phi/backends/device_guard.cc
+3
-5
paddle/phi/backends/device_guard.h
paddle/phi/backends/device_guard.h
+5
-7
paddle/phi/backends/device_manager.cc
paddle/phi/backends/device_manager.cc
+56
-44
paddle/phi/backends/device_manager.h
paddle/phi/backends/device_manager.h
+26
-17
paddle/phi/backends/event.cc
paddle/phi/backends/event.cc
+6
-8
paddle/phi/backends/event.h
paddle/phi/backends/event.h
+2
-4
paddle/phi/backends/stream.cc
paddle/phi/backends/stream.cc
+9
-10
paddle/phi/backends/stream.h
paddle/phi/backends/stream.h
+5
-6
paddle/phi/core/CMakeLists.txt
paddle/phi/core/CMakeLists.txt
+1
-1
paddle/phi/core/compat/convert_utils.cc
paddle/phi/core/compat/convert_utils.cc
+2
-4
paddle/phi/core/custom_kernel.cc
paddle/phi/core/custom_kernel.cc
+24
-0
paddle/phi/core/custom_kernel.h
paddle/phi/core/custom_kernel.h
+2
-0
paddle/phi/kernels/math_kernel.cc
paddle/phi/kernels/math_kernel.cc
+2
-1
paddle/phi/ops/compat/elementwise_sig.cc
paddle/phi/ops/compat/elementwise_sig.cc
+9
-0
python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_emb_eltwise_layernorm.py
...ts/ir/inference/test_trt_convert_emb_eltwise_layernorm.py
+2
-14
python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_gather.py
...d/tests/unittests/ir/inference/test_trt_convert_gather.py
+1
-1
python/setup.py.in
python/setup.py.in
+1
-4
未找到文件。
paddle/fluid/eager/auto_code_generator/final_state_generator/eager_gen.py
浏览文件 @
b53cdc9e
...
@@ -31,7 +31,9 @@ yaml_types_mapping = {
...
@@ -31,7 +31,9 @@ yaml_types_mapping = {
'int64_t[]'
:
'std::vector<int64_t>'
,
'int[]'
:
'std::vector<int>'
,
'int64_t[]'
:
'std::vector<int64_t>'
,
'int[]'
:
'std::vector<int>'
,
'Tensor'
:
'Tensor'
,
'Tensor'
:
'Tensor'
,
'Tensor[]'
:
'std::vector<Tensor>'
,
'Tensor[]'
:
'std::vector<Tensor>'
,
'Tensor[Tensor[]]'
:
'std::vector<std::vector<Tensor>>'
'Tensor[Tensor[]]'
:
'std::vector<std::vector<Tensor>>'
,
'Scalar'
:
'Scalar'
,
'ScalarArray'
:
'ScalarArray'
}
}
...
...
paddle/fluid/eager/auto_code_generator/final_state_generator/python_c_gen.py
浏览文件 @
b53cdc9e
...
@@ -14,7 +14,7 @@
...
@@ -14,7 +14,7 @@
import
os
import
os
import
argparse
import
argparse
from
eager_gen
import
ReadFwdFile
,
ParseDispensable
,
IsVectorTensorType
,
GetForwardFunctionName
,
ParseYamlForward
,
DetermineForwardPositionMap
from
eager_gen
import
yaml_types_mapping
,
ReadFwdFile
,
ParseDispensable
,
IsVectorTensorType
,
GetForwardFunctionName
,
ParseYamlForward
,
DetermineForwardPositionMap
atype_to_parsing_function
=
{
atype_to_parsing_function
=
{
"bool"
:
"CastPyArg2Boolean"
,
"bool"
:
"CastPyArg2Boolean"
,
...
@@ -27,21 +27,9 @@ atype_to_parsing_function = {
...
@@ -27,21 +27,9 @@ atype_to_parsing_function = {
"long[]"
:
"CastPyArg2Longs"
,
"long[]"
:
"CastPyArg2Longs"
,
"float[]"
:
"CastPyArg2Floats"
,
"float[]"
:
"CastPyArg2Floats"
,
"double[]"
:
"CastPyArg2Float64s"
,
"double[]"
:
"CastPyArg2Float64s"
,
"string[]"
:
"CastPyArg2Strings"
"string[]"
:
"CastPyArg2Strings"
,
}
"Scalar"
:
"CastPyArg2Scalar"
,
"ScalarArray"
:
"CastPyArg2ScalarArray"
atype_to_cxx_type
=
{
"bool"
:
"bool"
,
"int"
:
"int"
,
"long"
:
"long"
,
"float"
:
"float"
,
"string"
:
"std::string"
,
"bool[]"
:
"std::vector<bool>"
,
"int[]"
:
"std::vector<int>"
,
"long[]"
:
"std::vector<long>"
,
"float[]"
:
"std::vector<float>"
,
"double[]"
:
"std::vector<double>"
,
"string[]"
:
"std::vector<std::string>"
}
}
...
@@ -56,10 +44,10 @@ def ParseArguments():
...
@@ -56,10 +44,10 @@ def ParseArguments():
def
GetCxxType
(
atype
):
def
GetCxxType
(
atype
):
if
atype
not
in
atype_to_cxx_type
.
keys
():
if
atype
not
in
yaml_types_mapping
.
keys
():
assert
False
assert
False
return
atype_to_cxx_type
[
atype
]
return
yaml_types_mapping
[
atype
]
def
FindParsingFunctionFromAttributeType
(
atype
):
def
FindParsingFunctionFromAttributeType
(
atype
):
...
...
paddle/fluid/framework/CMakeLists.txt
浏览文件 @
b53cdc9e
...
@@ -440,11 +440,10 @@ message(STATUS "branch: ${PADDLE_BRANCH}")
...
@@ -440,11 +440,10 @@ message(STATUS "branch: ${PADDLE_BRANCH}")
configure_file
(
commit.h.in commit.h
)
configure_file
(
commit.h.in commit.h
)
cc_library
(
custom_operator SRCS custom_operator.cc DEPS tensor attribute framework_proto op_registry operator dynamic_loader string_helper phi_tensor op_meta_info phi_api
)
cc_library
(
custom_operator SRCS custom_operator.cc DEPS tensor attribute framework_proto op_registry operator dynamic_loader string_helper phi_tensor op_meta_info phi_api
)
cc_library
(
custom_kernel SRCS custom_kernel.cc DEPS op_registry phi_custom_kernel phi_tensor_raw
)
#cc_binary(test_executor SRCS test_executor.cc DEPS executor op_registry ${GLOB_OP_LIB} ${GLOB_OPERATOR_DEPS} )
#cc_binary(test_executor SRCS test_executor.cc DEPS executor op_registry ${GLOB_OP_LIB} ${GLOB_OPERATOR_DEPS} )
#cc_binary(new_executor SRCS new_exec_test.cc DEPS operator op_registry executor ${GLOB_OP_LIB} ${GLOB_OPERATOR_DEPS} profiler)
#cc_binary(new_executor SRCS new_exec_test.cc DEPS operator op_registry executor ${GLOB_OP_LIB} ${GLOB_OPERATOR_DEPS} profiler)
set
(
FLUID_FRAMEWORK_MODULES proto_desc memory lod_tensor executor data_feed_proto layer dynamic_loader custom_operator custom_kernel
)
set
(
FLUID_FRAMEWORK_MODULES proto_desc memory lod_tensor executor data_feed_proto layer dynamic_loader custom_operator
phi_
custom_kernel
)
cc_library
(
paddle_framework DEPS
${
FLUID_FRAMEWORK_MODULES
}
)
cc_library
(
paddle_framework DEPS
${
FLUID_FRAMEWORK_MODULES
}
)
...
...
paddle/fluid/framework/custom_kernel.cc
已删除
100644 → 0
浏览文件 @
9e00395a
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#if defined _WIN32 || defined __APPLE__
#else
#define _LINUX
#endif
#include "paddle/fluid/framework/custom_kernel.h"
#include "paddle/phi/core/custom_kernel.h"
namespace
paddle
{
namespace
framework
{
void
LoadCustomKernelLib
(
const
std
::
string
&
dso_lib_path
,
void
*
dso_handle
)
{
#ifdef _LINUX
typedef
phi
::
CustomKernelMap
&
get_custom_kernel_map_t
();
auto
*
func
=
reinterpret_cast
<
get_custom_kernel_map_t
*>
(
dlsym
(
dso_handle
,
"PD_GetCustomKernelMap"
));
if
(
func
==
nullptr
)
{
LOG
(
WARNING
)
<<
"Skipped lib ["
<<
dso_lib_path
<<
"]: fail to find "
<<
"PD_GetCustomKernelMap symbol in this lib."
;
return
;
}
auto
&
custom_kernel_map
=
func
();
phi
::
RegisterCustomKernels
(
custom_kernel_map
);
LOG
(
INFO
)
<<
"Successed in loading custom kernels in lib: "
<<
dso_lib_path
;
#else
VLOG
(
3
)
<<
"Unsupported: Custom kernel is only implemented on Linux."
;
#endif
return
;
}
}
// namespace framework
}
// namespace paddle
paddle/fluid/framework/custom_kernel.h
已删除
100644 → 0
浏览文件 @
9e00395a
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <string>
namespace
paddle
{
namespace
framework
{
// Load custom kernel lib and register
void
LoadCustomKernelLib
(
const
std
::
string
&
dso_lib_path
,
void
*
dso_handle
);
}
// namespace framework
}
// namespace paddle
paddle/fluid/framework/garbage_collector.cc
浏览文件 @
b53cdc9e
...
@@ -231,19 +231,19 @@ void CustomDeviceUnsafeFastGarbageCollector::ClearCallback(
...
@@ -231,19 +231,19 @@ void CustomDeviceUnsafeFastGarbageCollector::ClearCallback(
CustomStreamGarbageCollector
::
CustomStreamGarbageCollector
(
CustomStreamGarbageCollector
::
CustomStreamGarbageCollector
(
const
platform
::
CustomPlace
&
place
,
size_t
max_memory_size
)
const
platform
::
CustomPlace
&
place
,
size_t
max_memory_size
)
:
GarbageCollector
(
place
,
max_memory_size
)
{
:
GarbageCollector
(
place
,
max_memory_size
)
{
p
latform
::
DeviceGuard
guard
(
place
);
p
hi
::
DeviceGuard
guard
(
place
);
stream_
.
reset
(
new
p
latform
::
stream
::
Stream
);
stream_
.
reset
(
new
p
hi
::
stream
::
Stream
);
stream_
->
Init
(
place
);
stream_
->
Init
(
place
);
callback_manager_
.
reset
(
new
p
latform
::
CallbackManager
(
stream_
.
get
()));
callback_manager_
.
reset
(
new
p
hi
::
CallbackManager
(
stream_
.
get
()));
}
}
CustomStreamGarbageCollector
::~
CustomStreamGarbageCollector
()
{
CustomStreamGarbageCollector
::~
CustomStreamGarbageCollector
()
{
p
latform
::
DeviceGuard
guard
(
this
->
dev_ctx_
->
GetPlace
());
p
hi
::
DeviceGuard
guard
(
this
->
dev_ctx_
->
GetPlace
());
stream_
->
Synchronize
();
stream_
->
Synchronize
();
stream_
->
Destroy
();
stream_
->
Destroy
();
}
}
p
latform
::
stream
::
Stream
*
CustomStreamGarbageCollector
::
stream
()
const
{
p
hi
::
stream
::
Stream
*
CustomStreamGarbageCollector
::
stream
()
const
{
return
stream_
.
get
();
return
stream_
.
get
();
}
}
...
...
paddle/fluid/framework/garbage_collector.h
浏览文件 @
b53cdc9e
...
@@ -230,14 +230,14 @@ class CustomStreamGarbageCollector : public GarbageCollector {
...
@@ -230,14 +230,14 @@ class CustomStreamGarbageCollector : public GarbageCollector {
void
Wait
()
const
override
;
void
Wait
()
const
override
;
p
latform
::
stream
::
Stream
*
stream
()
const
;
p
hi
::
stream
::
Stream
*
stream
()
const
;
protected:
protected:
void
ClearCallback
(
const
std
::
function
<
void
()
>
&
callback
)
override
;
void
ClearCallback
(
const
std
::
function
<
void
()
>
&
callback
)
override
;
private:
private:
std
::
unique_ptr
<
p
latform
::
stream
::
Stream
>
stream_
;
std
::
unique_ptr
<
p
hi
::
stream
::
Stream
>
stream_
;
std
::
unique_ptr
<
p
latform
::
CallbackManager
>
callback_manager_
;
std
::
unique_ptr
<
p
hi
::
CallbackManager
>
callback_manager_
;
};
};
#endif
#endif
...
...
paddle/fluid/framework/operator.cc
浏览文件 @
b53cdc9e
...
@@ -254,7 +254,7 @@ void OperatorBase::Run(const Scope& scope, const platform::Place& place) {
...
@@ -254,7 +254,7 @@ void OperatorBase::Run(const Scope& scope, const platform::Place& place) {
"reinstall Paddle with CustomDevice support."
,
"reinstall Paddle with CustomDevice support."
,
place
));
place
));
#else
#else
p
latform
::
DeviceManager
::
SetDevice
(
place
);
p
hi
::
DeviceManager
::
SetDevice
(
place
);
#endif
#endif
}
}
...
...
paddle/fluid/imperative/tracer.cc
浏览文件 @
b53cdc9e
...
@@ -253,7 +253,7 @@ void Tracer::TraceOp(const std::string& type, const NameVarMap<VarType>& ins,
...
@@ -253,7 +253,7 @@ void Tracer::TraceOp(const std::string& type, const NameVarMap<VarType>& ins,
#endif
#endif
}
else
if
(
platform
::
is_custom_place
(
place
))
{
}
else
if
(
platform
::
is_custom_place
(
place
))
{
#ifdef PADDLE_WITH_CUSTOM_DEVICE
#ifdef PADDLE_WITH_CUSTOM_DEVICE
p
latform
::
DeviceManager
::
SetDevice
(
place
);
p
hi
::
DeviceManager
::
SetDevice
(
place
);
#else
#else
PADDLE_THROW
(
platform
::
errors
::
PreconditionNotMet
(
PADDLE_THROW
(
platform
::
errors
::
PreconditionNotMet
(
"PaddlePaddle should compile with CustomDevice if use "
"PaddlePaddle should compile with CustomDevice if use "
...
...
paddle/fluid/inference/api/CMakeLists.txt
浏览文件 @
b53cdc9e
...
@@ -31,7 +31,7 @@ cc_library(paddle_infer_contrib SRCS paddle_infer_contrib.cc DEPS zero_copy_tens
...
@@ -31,7 +31,7 @@ cc_library(paddle_infer_contrib SRCS paddle_infer_contrib.cc DEPS zero_copy_tens
cc_library
(
paddle_pass_builder SRCS paddle_pass_builder.cc
)
cc_library
(
paddle_pass_builder SRCS paddle_pass_builder.cc
)
set
(
paddle_inference_api_deps lod_tensor scope reset_tensor_array
set
(
paddle_inference_api_deps lod_tensor scope reset_tensor_array
analysis_config paddle_infer_contrib zero_copy_tensor trainer_desc_proto custom_operator custom_kernel
)
analysis_config paddle_infer_contrib zero_copy_tensor trainer_desc_proto custom_operator
phi_
custom_kernel
)
if
(
WITH_CRYPTO
)
if
(
WITH_CRYPTO
)
list
(
APPEND paddle_inference_api_deps paddle_crypto
)
list
(
APPEND paddle_inference_api_deps paddle_crypto
)
...
...
paddle/fluid/inference/tensorrt/engine.h
浏览文件 @
b53cdc9e
...
@@ -54,6 +54,8 @@ TRT_DT FluidDataType2TRT(FluidDT type) {
...
@@ -54,6 +54,8 @@ TRT_DT FluidDataType2TRT(FluidDT type) {
return
TRT_DT
::
kFLOAT
;
return
TRT_DT
::
kFLOAT
;
case
FluidDT
::
VarType_Type_INT32
:
case
FluidDT
::
VarType_Type_INT32
:
return
TRT_DT
::
kINT32
;
return
TRT_DT
::
kINT32
;
case
FluidDT
::
VarType_Type_FP16
:
return
TRT_DT
::
kHALF
;
default:
default:
return
TRT_DT
::
kINT32
;
return
TRT_DT
::
kINT32
;
}
}
...
...
paddle/fluid/memory/allocation/allocator_facade.cc
浏览文件 @
b53cdc9e
...
@@ -193,10 +193,10 @@ class AllocatorFacadePrivate {
...
@@ -193,10 +193,10 @@ class AllocatorFacadePrivate {
}
}
#endif
#endif
#ifdef PADDLE_WITH_CUSTOM_DEVICE
#ifdef PADDLE_WITH_CUSTOM_DEVICE
auto
device_types
=
p
latform
::
DeviceManager
::
GetAllCustomDeviceTypes
();
auto
device_types
=
p
hi
::
DeviceManager
::
GetAllCustomDeviceTypes
();
for
(
const
auto
&
dev_type
:
device_types
)
{
for
(
const
auto
&
dev_type
:
device_types
)
{
for
(
size_t
dev_id
=
0
;
for
(
size_t
dev_id
=
0
;
dev_id
<
p
latform
::
DeviceManager
::
GetDeviceCount
(
dev_type
);
dev_id
<
p
hi
::
DeviceManager
::
GetDeviceCount
(
dev_type
);
++
dev_id
)
{
++
dev_id
)
{
InitNaiveBestFitCustomDeviceAllocator
(
InitNaiveBestFitCustomDeviceAllocator
(
platform
::
CustomPlace
(
dev_type
,
dev_id
));
platform
::
CustomPlace
(
dev_type
,
dev_id
));
...
@@ -240,10 +240,10 @@ class AllocatorFacadePrivate {
...
@@ -240,10 +240,10 @@ class AllocatorFacadePrivate {
}
}
#endif
#endif
#ifdef PADDLE_WITH_CUSTOM_DEVICE
#ifdef PADDLE_WITH_CUSTOM_DEVICE
auto
device_types
=
p
latform
::
DeviceManager
::
GetAllCustomDeviceTypes
();
auto
device_types
=
p
hi
::
DeviceManager
::
GetAllCustomDeviceTypes
();
for
(
const
auto
&
dev_type
:
device_types
)
{
for
(
const
auto
&
dev_type
:
device_types
)
{
for
(
size_t
dev_id
=
0
;
for
(
size_t
dev_id
=
0
;
dev_id
<
p
latform
::
DeviceManager
::
GetDeviceCount
(
dev_type
);
dev_id
<
p
hi
::
DeviceManager
::
GetDeviceCount
(
dev_type
);
++
dev_id
)
{
++
dev_id
)
{
InitAutoGrowthCustomDeviceAllocator
(
InitAutoGrowthCustomDeviceAllocator
(
platform
::
CustomPlace
(
dev_type
,
dev_id
),
allow_free_idle_chunk
);
platform
::
CustomPlace
(
dev_type
,
dev_id
),
allow_free_idle_chunk
);
...
@@ -738,7 +738,7 @@ class AllocatorFacadePrivate {
...
@@ -738,7 +738,7 @@ class AllocatorFacadePrivate {
auto
custom_allocator
=
auto
custom_allocator
=
std
::
make_shared
<
paddle
::
memory
::
allocation
::
CustomAllocator
>
(
p
);
std
::
make_shared
<
paddle
::
memory
::
allocation
::
CustomAllocator
>
(
p
);
allocators_
[
p
]
=
std
::
make_shared
<
AutoGrowthBestFitAllocator
>
(
allocators_
[
p
]
=
std
::
make_shared
<
AutoGrowthBestFitAllocator
>
(
custom_allocator
,
p
latform
::
DeviceManager
::
GetMinChunkSize
(
p
),
custom_allocator
,
p
hi
::
DeviceManager
::
GetMinChunkSize
(
p
),
allow_free_idle_chunk
);
allow_free_idle_chunk
);
}
}
#endif
#endif
...
@@ -814,11 +814,10 @@ class AllocatorFacadePrivate {
...
@@ -814,11 +814,10 @@ class AllocatorFacadePrivate {
}
}
#endif
#endif
#ifdef PADDLE_WITH_CUSTOM_DEVICE
#ifdef PADDLE_WITH_CUSTOM_DEVICE
auto
device_types
=
p
latform
::
DeviceManager
::
GetAllCustomDeviceTypes
();
auto
device_types
=
p
hi
::
DeviceManager
::
GetAllCustomDeviceTypes
();
for
(
const
auto
&
dev_type
:
device_types
)
{
for
(
const
auto
&
dev_type
:
device_types
)
{
for
(
size_t
dev_id
=
0
;
for
(
size_t
dev_id
=
0
;
dev_id
<
platform
::
DeviceManager
::
GetDeviceCount
(
dev_type
);
dev_id
<
phi
::
DeviceManager
::
GetDeviceCount
(
dev_type
);
dev_id
++
)
{
dev_id
++
)
{
places
.
emplace_back
(
platform
::
CustomPlace
(
dev_type
,
dev_id
));
places
.
emplace_back
(
platform
::
CustomPlace
(
dev_type
,
dev_id
));
}
}
}
}
...
...
paddle/fluid/memory/allocation/custom_allocator.cc
浏览文件 @
b53cdc9e
...
@@ -32,17 +32,16 @@ void CustomAllocator::FreeImpl(phi::Allocation* allocation) {
...
@@ -32,17 +32,16 @@ void CustomAllocator::FreeImpl(phi::Allocation* allocation) {
}
}
phi
::
Allocation
*
CustomAllocator
::
AllocateImpl
(
size_t
size
)
{
phi
::
Allocation
*
CustomAllocator
::
AllocateImpl
(
size_t
size
)
{
std
::
call_once
(
once_flag_
,
std
::
call_once
(
once_flag_
,
[
this
]
{
phi
::
DeviceManager
::
SetDevice
(
place_
);
});
[
this
]
{
platform
::
DeviceManager
::
SetDevice
(
place_
);
});
void
*
ptr
=
void
*
ptr
=
p
latform
::
DeviceManager
::
GetDeviceWithPlace
(
place_
)
->
MemoryAllocate
(
size
);
p
hi
::
DeviceManager
::
GetDeviceWithPlace
(
place_
)
->
MemoryAllocate
(
size
);
if
(
LIKELY
(
ptr
))
{
if
(
LIKELY
(
ptr
))
{
return
new
Allocation
(
ptr
,
size
,
place_
);
return
new
Allocation
(
ptr
,
size
,
place_
);
}
}
size_t
avail
,
total
;
size_t
avail
,
total
;
p
latform
::
DeviceManager
::
MemoryStats
(
place_
,
&
total
,
&
avail
);
p
hi
::
DeviceManager
::
MemoryStats
(
place_
,
&
total
,
&
avail
);
auto
dev_type
=
platform
::
PlaceHelper
::
GetDeviceType
(
place_
);
auto
dev_type
=
platform
::
PlaceHelper
::
GetDeviceType
(
place_
);
auto
dev_id
=
platform
::
PlaceHelper
::
GetDeviceId
(
place_
);
auto
dev_id
=
platform
::
PlaceHelper
::
GetDeviceId
(
place_
);
...
...
paddle/fluid/memory/allocation/naive_best_fit_allocator.cc
浏览文件 @
b53cdc9e
...
@@ -739,7 +739,7 @@ class BuddyAllocatorList {
...
@@ -739,7 +739,7 @@ class BuddyAllocatorList {
private:
private:
explicit
BuddyAllocatorList
(
const
std
::
string
&
device_type
)
explicit
BuddyAllocatorList
(
const
std
::
string
&
device_type
)
:
device_type_
(
device_type
)
{
:
device_type_
(
device_type
)
{
auto
devices
=
p
latform
::
DeviceManager
::
GetDeviceList
(
device_type
);
auto
devices
=
p
hi
::
DeviceManager
::
GetDeviceList
(
device_type
);
for
(
auto
dev_id
:
devices
)
{
for
(
auto
dev_id
:
devices
)
{
init_flags_
[
dev_id
].
reset
(
new
std
::
once_flag
());
init_flags_
[
dev_id
].
reset
(
new
std
::
once_flag
());
}
}
...
@@ -766,15 +766,15 @@ class BuddyAllocatorList {
...
@@ -766,15 +766,15 @@ class BuddyAllocatorList {
device_type_
,
dev_id
));
device_type_
,
dev_id
));
std
::
call_once
(
*
init_flags_
[
dev_id
],
[
this
,
dev_id
]
{
std
::
call_once
(
*
init_flags_
[
dev_id
],
[
this
,
dev_id
]
{
p
latform
::
DeviceManager
::
SetDevice
(
device_type_
,
dev_id
);
p
hi
::
DeviceManager
::
SetDevice
(
device_type_
,
dev_id
);
platform
::
CustomPlace
place
(
device_type_
,
dev_id
);
platform
::
CustomPlace
place
(
device_type_
,
dev_id
);
allocators_
[
dev_id
].
reset
(
new
BuddyAllocator
(
allocators_
[
dev_id
].
reset
(
new
BuddyAllocator
(
std
::
unique_ptr
<
detail
::
SystemAllocator
>
(
std
::
unique_ptr
<
detail
::
SystemAllocator
>
(
new
detail
::
CustomAllocator
(
device_type_
,
dev_id
)),
new
detail
::
CustomAllocator
(
device_type_
,
dev_id
)),
p
latform
::
DeviceManager
::
GetMinChunkSize
(
place
),
p
hi
::
DeviceManager
::
GetMinChunkSize
(
place
),
p
latform
::
DeviceManager
::
GetMaxChunkSize
(
place
),
p
hi
::
DeviceManager
::
GetMaxChunkSize
(
place
),
p
latform
::
DeviceManager
::
GetExtraPaddingSize
(
place
),
device_type_
));
p
hi
::
DeviceManager
::
GetExtraPaddingSize
(
place
),
device_type_
));
});
});
return
allocators_
[
dev_id
].
get
();
return
allocators_
[
dev_id
].
get
();
...
@@ -808,9 +808,9 @@ void *Alloc<platform::CustomPlace>(const platform::CustomPlace &place,
...
@@ -808,9 +808,9 @@ void *Alloc<platform::CustomPlace>(const platform::CustomPlace &place,
auto
*
ptr
=
buddy_allocator
->
Alloc
(
size
);
auto
*
ptr
=
buddy_allocator
->
Alloc
(
size
);
if
(
ptr
==
nullptr
)
{
if
(
ptr
==
nullptr
)
{
p
latform
::
DeviceGuard
guard
(
place
);
p
hi
::
DeviceGuard
guard
(
place
);
size_t
avail
,
total
;
size_t
avail
,
total
;
p
latform
::
DeviceManager
::
MemoryStats
(
place
,
&
total
,
&
avail
);
p
hi
::
DeviceManager
::
MemoryStats
(
place
,
&
total
,
&
avail
);
PADDLE_THROW
(
platform
::
errors
::
ResourceExhausted
(
PADDLE_THROW
(
platform
::
errors
::
ResourceExhausted
(
"Cannot allocate %s in %s:%d, avaliable %s, total %s, used "
"Cannot allocate %s in %s:%d, avaliable %s, total %s, used "
"%s. "
,
"%s. "
,
...
@@ -819,8 +819,7 @@ void *Alloc<platform::CustomPlace>(const platform::CustomPlace &place,
...
@@ -819,8 +819,7 @@ void *Alloc<platform::CustomPlace>(const platform::CustomPlace &place,
string
::
HumanReadableSize
(
total
-
avail
)));
string
::
HumanReadableSize
(
total
-
avail
)));
}
else
{
}
else
{
if
(
FLAGS_init_allocated_mem
)
{
if
(
FLAGS_init_allocated_mem
)
{
platform
::
DeviceManager
::
GetDeviceWithPlace
(
place
)
->
MemorySet
(
ptr
,
0xEF
,
phi
::
DeviceManager
::
GetDeviceWithPlace
(
place
)
->
MemorySet
(
ptr
,
0xEF
,
size
);
size
);
}
}
}
}
VLOG
(
10
)
<<
" pointer="
<<
ptr
;
VLOG
(
10
)
<<
" pointer="
<<
ptr
;
...
...
paddle/fluid/memory/detail/buddy_allocator.cc
浏览文件 @
b53cdc9e
...
@@ -43,11 +43,11 @@ BuddyAllocator::BuddyAllocator(
...
@@ -43,11 +43,11 @@ BuddyAllocator::BuddyAllocator(
#ifdef PADDLE_WITH_CUSTOM_DEVICE
#ifdef PADDLE_WITH_CUSTOM_DEVICE
if
(
!
dev_type
.
empty
())
{
if
(
!
dev_type
.
empty
())
{
init_allocate_size_func_
=
[
dev_type
]()
{
init_allocate_size_func_
=
[
dev_type
]()
{
return
p
latform
::
DeviceManager
::
GetInitAllocSize
(
return
p
hi
::
DeviceManager
::
GetInitAllocSize
(
platform
::
PlaceHelper
::
CreatePlace
(
dev_type
));
platform
::
PlaceHelper
::
CreatePlace
(
dev_type
));
};
};
re_allocate_size_func_
=
[
dev_type
]()
{
re_allocate_size_func_
=
[
dev_type
]()
{
return
p
latform
::
DeviceManager
::
GetReallocSize
(
return
p
hi
::
DeviceManager
::
GetReallocSize
(
platform
::
PlaceHelper
::
CreatePlace
(
dev_type
));
platform
::
PlaceHelper
::
CreatePlace
(
dev_type
));
};
};
}
else
{
}
else
{
...
...
paddle/fluid/memory/detail/system_allocator.cc
浏览文件 @
b53cdc9e
...
@@ -438,7 +438,7 @@ void* CustomAllocator::Alloc(size_t* index, size_t size) {
...
@@ -438,7 +438,7 @@ void* CustomAllocator::Alloc(size_t* index, size_t size) {
void
*
p
;
void
*
p
;
auto
place
=
platform
::
CustomPlace
(
dev_type_
,
dev_id_
);
auto
place
=
platform
::
CustomPlace
(
dev_type_
,
dev_id_
);
auto
device
=
p
latform
::
DeviceManager
::
GetDeviceWithPlace
(
place
);
auto
device
=
p
hi
::
DeviceManager
::
GetDeviceWithPlace
(
place
);
p
=
device
->
MemoryAllocate
(
size
);
p
=
device
->
MemoryAllocate
(
size
);
if
(
LIKELY
(
p
))
{
if
(
LIKELY
(
p
))
{
VLOG
(
4
)
<<
"CustomAllocator::Alloc "
<<
p
<<
" size "
<<
size
;
VLOG
(
4
)
<<
"CustomAllocator::Alloc "
<<
p
<<
" size "
<<
size
;
...
@@ -447,7 +447,7 @@ void* CustomAllocator::Alloc(size_t* index, size_t size) {
...
@@ -447,7 +447,7 @@ void* CustomAllocator::Alloc(size_t* index, size_t size) {
}
else
{
}
else
{
size_t
avail
,
total
;
size_t
avail
,
total
;
p
latform
::
DeviceManager
::
MemoryStats
(
place
,
&
total
,
&
avail
);
p
hi
::
DeviceManager
::
MemoryStats
(
place
,
&
total
,
&
avail
);
PADDLE_THROW_BAD_ALLOC
(
platform
::
errors
::
ResourceExhausted
(
PADDLE_THROW_BAD_ALLOC
(
platform
::
errors
::
ResourceExhausted
(
"
\n\n
Out of memory error on %s %d. "
"
\n\n
Out of memory error on %s %d. "
"total memory is %s, used memory is %s, "
"total memory is %s, used memory is %s, "
...
@@ -470,7 +470,7 @@ void CustomAllocator::Free(void* p, size_t size, size_t index) {
...
@@ -470,7 +470,7 @@ void CustomAllocator::Free(void* p, size_t size, size_t index) {
size
,
plug_alloc_size
));
size
,
plug_alloc_size
));
plug_alloc_size
-=
size
;
plug_alloc_size
-=
size
;
auto
place
=
platform
::
CustomPlace
(
dev_type_
,
dev_id_
);
auto
place
=
platform
::
CustomPlace
(
dev_type_
,
dev_id_
);
auto
device
=
p
latform
::
DeviceManager
::
GetDeviceWithPlace
(
place
);
auto
device
=
p
hi
::
DeviceManager
::
GetDeviceWithPlace
(
place
);
device
->
MemoryDeallocate
(
p
,
size
);
device
->
MemoryDeallocate
(
p
,
size
);
}
}
...
...
paddle/fluid/memory/memcpy.cc
浏览文件 @
b53cdc9e
...
@@ -44,9 +44,9 @@ void Copy<platform::CPUPlace, platform::CustomPlace>(
...
@@ -44,9 +44,9 @@ void Copy<platform::CPUPlace, platform::CustomPlace>(
VLOG
(
4
)
<<
"memory::Copy "
<<
num
<<
" Bytes from "
<<
src_place
<<
" to "
VLOG
(
4
)
<<
"memory::Copy "
<<
num
<<
" Bytes from "
<<
src_place
<<
" to "
<<
dst_place
<<
", stream="
<<
stream
;
<<
dst_place
<<
", stream="
<<
stream
;
p
latform
::
DeviceManager
::
SetDevice
(
src_place
);
p
hi
::
DeviceManager
::
SetDevice
(
src_place
);
p
latform
::
stream
::
Stream
stream_wrapper
(
src_place
,
stream
);
p
hi
::
stream
::
Stream
stream_wrapper
(
src_place
,
stream
);
p
latform
::
DeviceManager
::
GetDeviceWithPlace
(
src_place
)
->
MemoryCopyD2H
(
p
hi
::
DeviceManager
::
GetDeviceWithPlace
(
src_place
)
->
MemoryCopyD2H
(
dst
,
src
,
num
,
&
stream_wrapper
);
dst
,
src
,
num
,
&
stream_wrapper
);
}
}
...
@@ -62,9 +62,9 @@ void Copy<platform::CustomPlace, platform::CPUPlace>(
...
@@ -62,9 +62,9 @@ void Copy<platform::CustomPlace, platform::CPUPlace>(
VLOG
(
4
)
<<
"memory::Copy "
<<
num
<<
" Bytes from "
<<
src_place
<<
" to "
VLOG
(
4
)
<<
"memory::Copy "
<<
num
<<
" Bytes from "
<<
src_place
<<
" to "
<<
dst_place
<<
", stream="
<<
stream
;
<<
dst_place
<<
", stream="
<<
stream
;
p
latform
::
DeviceManager
::
SetDevice
(
dst_place
);
p
hi
::
DeviceManager
::
SetDevice
(
dst_place
);
p
latform
::
stream
::
Stream
stream_wrapper
(
dst_place
,
stream
);
p
hi
::
stream
::
Stream
stream_wrapper
(
dst_place
,
stream
);
p
latform
::
DeviceManager
::
GetDeviceWithPlace
(
dst_place
)
->
MemoryCopyH2D
(
p
hi
::
DeviceManager
::
GetDeviceWithPlace
(
dst_place
)
->
MemoryCopyH2D
(
dst
,
src
,
num
,
&
stream_wrapper
);
dst
,
src
,
num
,
&
stream_wrapper
);
}
}
...
@@ -82,16 +82,16 @@ void Copy<platform::CustomPlace, platform::CustomPlace>(
...
@@ -82,16 +82,16 @@ void Copy<platform::CustomPlace, platform::CustomPlace>(
<<
dst_place
<<
", stream="
<<
stream
;
<<
dst_place
<<
", stream="
<<
stream
;
if
(
src_type
==
dst_type
)
{
if
(
src_type
==
dst_type
)
{
p
latform
::
DeviceManager
::
SetDevice
(
src_place
);
p
hi
::
DeviceManager
::
SetDevice
(
src_place
);
p
latform
::
stream
::
Stream
stream_wrapper
(
src_place
,
stream
);
p
hi
::
stream
::
Stream
stream_wrapper
(
src_place
,
stream
);
auto
src_id
=
platform
::
PlaceHelper
::
GetDeviceId
(
src_place
);
auto
src_id
=
platform
::
PlaceHelper
::
GetDeviceId
(
src_place
);
auto
dst_id
=
platform
::
PlaceHelper
::
GetDeviceId
(
dst_place
);
auto
dst_id
=
platform
::
PlaceHelper
::
GetDeviceId
(
dst_place
);
if
(
src_id
==
dst_id
)
{
if
(
src_id
==
dst_id
)
{
p
latform
::
DeviceManager
::
GetDeviceWithPlace
(
src_place
)
->
MemoryCopyD2D
(
p
hi
::
DeviceManager
::
GetDeviceWithPlace
(
src_place
)
->
MemoryCopyD2D
(
dst
,
src
,
num
,
&
stream_wrapper
);
dst
,
src
,
num
,
&
stream_wrapper
);
}
else
{
}
else
{
p
latform
::
DeviceManager
::
GetDeviceWithPlace
(
src_place
)
->
MemoryCopyP2P
(
p
hi
::
DeviceManager
::
GetDeviceWithPlace
(
src_place
)
->
MemoryCopyP2P
(
dst_place
,
dst
,
src
,
num
,
&
stream_wrapper
);
dst_place
,
dst
,
src
,
num
,
&
stream_wrapper
);
}
}
}
else
{
}
else
{
...
...
paddle/fluid/operators/detection/multiclass_nms_op.cc
浏览文件 @
b53cdc9e
...
@@ -93,7 +93,7 @@ class MultiClassNMSOp : public framework::OperatorWithKernel {
...
@@ -93,7 +93,7 @@ class MultiClassNMSOp : public framework::OperatorWithKernel {
// Here the box_dims[0] is not the real dimension of output.
// Here the box_dims[0] is not the real dimension of output.
// It will be rewritten in the computing kernel.
// It will be rewritten in the computing kernel.
if
(
score_size
==
3
)
{
if
(
score_size
==
3
)
{
ctx
->
SetOutputDim
(
"Out"
,
{
box_dims
[
1
]
,
box_dims
[
2
]
+
2
});
ctx
->
SetOutputDim
(
"Out"
,
{
-
1
,
box_dims
[
2
]
+
2
});
}
else
{
}
else
{
ctx
->
SetOutputDim
(
"Out"
,
{
-
1
,
box_dims
[
2
]
+
2
});
ctx
->
SetOutputDim
(
"Out"
,
{
-
1
,
box_dims
[
2
]
+
2
});
}
}
...
@@ -545,11 +545,10 @@ class MultiClassNMS2Op : public MultiClassNMSOp {
...
@@ -545,11 +545,10 @@ class MultiClassNMS2Op : public MultiClassNMSOp {
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
MultiClassNMSOp
::
InferShape
(
ctx
);
MultiClassNMSOp
::
InferShape
(
ctx
);
auto
box_dims
=
ctx
->
GetInputDim
(
"BBoxes"
);
auto
score_dims
=
ctx
->
GetInputDim
(
"Scores"
);
auto
score_dims
=
ctx
->
GetInputDim
(
"Scores"
);
auto
score_size
=
score_dims
.
size
();
auto
score_size
=
score_dims
.
size
();
if
(
score_size
==
3
)
{
if
(
score_size
==
3
)
{
ctx
->
SetOutputDim
(
"Index"
,
{
box_dims
[
1
]
,
1
});
ctx
->
SetOutputDim
(
"Index"
,
{
-
1
,
1
});
}
else
{
}
else
{
ctx
->
SetOutputDim
(
"Index"
,
{
-
1
,
1
});
ctx
->
SetOutputDim
(
"Index"
,
{
-
1
,
1
});
}
}
...
...
paddle/fluid/operators/elementwise/elementwise_op_npu_test.cc
浏览文件 @
b53cdc9e
...
@@ -33,7 +33,7 @@ namespace p = paddle::platform;
...
@@ -33,7 +33,7 @@ namespace p = paddle::platform;
USE_OP_ITSELF
(
elementwise_add
);
USE_OP_ITSELF
(
elementwise_add
);
USE_OP_DEVICE_KERNEL
(
elementwise_add
,
NPU
);
USE_OP_DEVICE_KERNEL
(
elementwise_add
,
NPU
);
USE_OP
(
elementwise_sub
);
USE_OP
_ITSELF
(
elementwise_sub
);
USE_OP_DEVICE_KERNEL
(
elementwise_sub
,
NPU
);
USE_OP_DEVICE_KERNEL
(
elementwise_sub
,
NPU
);
template
<
typename
T
>
template
<
typename
T
>
...
...
paddle/fluid/operators/elementwise/elementwise_pow_op_xpu.cc
浏览文件 @
b53cdc9e
...
@@ -14,7 +14,6 @@ limitations under the License. */
...
@@ -14,7 +14,6 @@ limitations under the License. */
#ifdef PADDLE_WITH_XPU
#ifdef PADDLE_WITH_XPU
#include "paddle/fluid/operators/elementwise/elementwise_op.h"
#include "paddle/fluid/operators/elementwise/elementwise_op.h"
#include "paddle/fluid/operators/elementwise/elementwise_sub_op.h"
#include "paddle/fluid/operators/elementwise/elementwise_xpu.h"
#include "paddle/fluid/operators/elementwise/elementwise_xpu.h"
#include "xpu/refactor/math.h"
#include "xpu/refactor/math.h"
...
...
paddle/fluid/operators/elementwise/elementwise_sub_op.cc
浏览文件 @
b53cdc9e
...
@@ -12,8 +12,6 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
...
@@ -12,8 +12,6 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
See the License for the specific language governing permissions and
limitations under the License. */
limitations under the License. */
#include "paddle/fluid/operators/elementwise/elementwise_sub_op.h"
#include <string>
#include <string>
#include "paddle/fluid/operators/elementwise/elementwise_op.h"
#include "paddle/fluid/operators/elementwise/elementwise_op.h"
...
@@ -78,10 +76,16 @@ class ElementwiseSubDoubleGradMaker : public framework::SingleGradOpMaker<T> {
...
@@ -78,10 +76,16 @@ class ElementwiseSubDoubleGradMaker : public framework::SingleGradOpMaker<T> {
}
// namespace paddle
}
// namespace paddle
REGISTER_ELEMWISE_GRAD_MAKER
(
elementwise_sub
,
Sub
);
REGISTER_ELEMWISE_GRAD_MAKER
(
elementwise_sub
,
Sub
);
REGISTER_ELEMWISE_EXPLICIT_OP_WITHOUT_GRAD
(
elementwise_sub
,
Sub
);
namespace
ops
=
paddle
::
operators
;
namespace
ops
=
paddle
::
operators
;
REGISTER_OPERATOR
(
elementwise_sub
,
::
paddle
::
operators
::
ElementwiseOp
,
::
paddle
::
operators
::
ElementwiseSubOpMaker
,
::
paddle
::
operators
::
ElementwiseOpInferVarType
,
elementwise_subGradMaker
<::
paddle
::
framework
::
OpDesc
>
,
elementwise_subGradMaker
<::
paddle
::
imperative
::
OpBase
>
,
::
paddle
::
operators
::
ElementwiseOpInplaceInferer
);
REGISTER_OPERATOR
(
REGISTER_OPERATOR
(
elementwise_sub_grad
,
ops
::
ElementwiseOpGrad
,
elementwise_sub_grad
,
ops
::
ElementwiseOpGrad
,
ops
::
ElementwiseGradOpInplaceInferer
,
ops
::
ElementwiseGradNoBufVarsInferer
,
ops
::
ElementwiseGradOpInplaceInferer
,
ops
::
ElementwiseGradNoBufVarsInferer
,
...
@@ -92,51 +96,6 @@ REGISTER_OPERATOR(elementwise_sub_grad_grad,
...
@@ -92,51 +96,6 @@ REGISTER_OPERATOR(elementwise_sub_grad_grad,
ops
::
ElementwiseDoubleGradOpInplaceInferer
,
ops
::
ElementwiseDoubleGradOpInplaceInferer
,
ops
::
ElementwiseDoubleGradNoBufVarsInferer
);
ops
::
ElementwiseDoubleGradNoBufVarsInferer
);
REGISTER_OP_CPU_KERNEL
(
elementwise_sub
,
ops
::
ElementwiseSubKernel
<
paddle
::
platform
::
CPUDeviceContext
,
float
>
,
ops
::
ElementwiseSubKernel
<
paddle
::
platform
::
CPUDeviceContext
,
double
>
,
ops
::
ElementwiseSubKernel
<
paddle
::
platform
::
CPUDeviceContext
,
int16_t
>
,
ops
::
ElementwiseSubKernel
<
paddle
::
platform
::
CPUDeviceContext
,
int
>
,
ops
::
ElementwiseSubKernel
<
paddle
::
platform
::
CPUDeviceContext
,
int64_t
>
,
ops
::
ElementwiseSubKernel
<
paddle
::
platform
::
CPUDeviceContext
,
paddle
::
platform
::
bfloat16
>
,
ops
::
ElementwiseSubKernel
<
paddle
::
platform
::
CPUDeviceContext
,
paddle
::
platform
::
complex
<
float
>>
,
ops
::
ElementwiseSubKernel
<
paddle
::
platform
::
CPUDeviceContext
,
paddle
::
platform
::
complex
<
double
>>
);
REGISTER_OP_CPU_KERNEL
(
elementwise_sub_grad
,
ops
::
ElementwiseSubGradKernel
<
paddle
::
platform
::
CPUDeviceContext
,
float
>
,
ops
::
ElementwiseSubGradKernel
<
paddle
::
platform
::
CPUDeviceContext
,
double
>
,
ops
::
ElementwiseSubGradKernel
<
paddle
::
platform
::
CPUDeviceContext
,
int16_t
>
,
ops
::
ElementwiseSubGradKernel
<
paddle
::
platform
::
CPUDeviceContext
,
int
>
,
ops
::
ElementwiseSubGradKernel
<
paddle
::
platform
::
CPUDeviceContext
,
int64_t
>
,
ops
::
ElementwiseSubGradKernel
<
paddle
::
platform
::
CPUDeviceContext
,
paddle
::
platform
::
bfloat16
>
,
ops
::
ElementwiseSubGradKernel
<
paddle
::
platform
::
CPUDeviceContext
,
paddle
::
platform
::
complex
<
float
>>
,
ops
::
ElementwiseSubGradKernel
<
paddle
::
platform
::
CPUDeviceContext
,
paddle
::
platform
::
complex
<
double
>>
);
REGISTER_OP_CPU_KERNEL
(
elementwise_sub_grad_grad
,
ops
::
ElementwiseSubDoubleGradKernel
<
paddle
::
platform
::
CPUDeviceContext
,
float
>
,
ops
::
ElementwiseSubDoubleGradKernel
<
paddle
::
platform
::
CPUDeviceContext
,
double
>
,
ops
::
ElementwiseSubDoubleGradKernel
<
paddle
::
platform
::
CPUDeviceContext
,
int16_t
>
,
ops
::
ElementwiseSubDoubleGradKernel
<
paddle
::
platform
::
CPUDeviceContext
,
int
>
,
ops
::
ElementwiseSubDoubleGradKernel
<
paddle
::
platform
::
CPUDeviceContext
,
int64_t
>
,
ops
::
ElementwiseSubDoubleGradKernel
<
paddle
::
platform
::
CPUDeviceContext
,
paddle
::
platform
::
bfloat16
>
,
ops
::
ElementwiseSubDoubleGradKernel
<
paddle
::
platform
::
CPUDeviceContext
,
paddle
::
platform
::
complex
<
float
>>
,
ops
::
ElementwiseSubDoubleGradKernel
<
paddle
::
platform
::
CPUDeviceContext
,
paddle
::
platform
::
complex
<
double
>>
);
REGISTER_OP_VERSION
(
elementwise_sub
)
REGISTER_OP_VERSION
(
elementwise_sub
)
.
AddCheckpoint
(
.
AddCheckpoint
(
R"ROC(Register elementwise_sub for adding the attribute of Scale_y)ROC"
,
R"ROC(Register elementwise_sub for adding the attribute of Scale_y)ROC"
,
...
...
paddle/fluid/operators/elementwise/elementwise_sub_op.cu
已删除
100644 → 0
浏览文件 @
9e00395a
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/elementwise/elementwise_sub_op.h"
namespace
ops
=
paddle
::
operators
;
namespace
plat
=
paddle
::
platform
;
REGISTER_OP_CUDA_KERNEL
(
elementwise_sub
,
ops
::
ElementwiseSubKernel
<
paddle
::
platform
::
CUDADeviceContext
,
float
>
,
ops
::
ElementwiseSubKernel
<
paddle
::
platform
::
CUDADeviceContext
,
paddle
::
platform
::
float16
>
,
ops
::
ElementwiseSubKernel
<
paddle
::
platform
::
CUDADeviceContext
,
paddle
::
platform
::
bfloat16
>
,
ops
::
ElementwiseSubKernel
<
paddle
::
platform
::
CUDADeviceContext
,
double
>
,
ops
::
ElementwiseSubKernel
<
paddle
::
platform
::
CUDADeviceContext
,
int
>
,
ops
::
ElementwiseSubKernel
<
paddle
::
platform
::
CUDADeviceContext
,
int64_t
>
,
ops
::
ElementwiseSubKernel
<
paddle
::
platform
::
CUDADeviceContext
,
paddle
::
platform
::
complex
<
float
>>
,
ops
::
ElementwiseSubKernel
<
paddle
::
platform
::
CUDADeviceContext
,
paddle
::
platform
::
complex
<
double
>>
);
REGISTER_OP_CUDA_KERNEL
(
elementwise_sub_grad
,
ops
::
ElementwiseSubGradKernel
<
paddle
::
platform
::
CUDADeviceContext
,
float
>
,
ops
::
ElementwiseSubGradKernel
<
paddle
::
platform
::
CUDADeviceContext
,
paddle
::
platform
::
float16
>
,
ops
::
ElementwiseSubGradKernel
<
paddle
::
platform
::
CUDADeviceContext
,
paddle
::
platform
::
bfloat16
>
,
ops
::
ElementwiseSubGradKernel
<
paddle
::
platform
::
CUDADeviceContext
,
double
>
,
ops
::
ElementwiseSubGradKernel
<
paddle
::
platform
::
CUDADeviceContext
,
int
>
,
ops
::
ElementwiseSubGradKernel
<
paddle
::
platform
::
CUDADeviceContext
,
int64_t
>
,
ops
::
ElementwiseSubGradKernel
<
paddle
::
platform
::
CUDADeviceContext
,
paddle
::
platform
::
complex
<
float
>>
,
ops
::
ElementwiseSubGradKernel
<
paddle
::
platform
::
CUDADeviceContext
,
paddle
::
platform
::
complex
<
double
>>
);
REGISTER_OP_CUDA_KERNEL
(
elementwise_sub_grad_grad
,
ops
::
ElementwiseSubDoubleGradKernel
<
paddle
::
platform
::
CUDADeviceContext
,
float
>
,
ops
::
ElementwiseSubDoubleGradKernel
<
paddle
::
platform
::
CUDADeviceContext
,
double
>
,
ops
::
ElementwiseSubDoubleGradKernel
<
paddle
::
platform
::
CUDADeviceContext
,
int
>
,
ops
::
ElementwiseSubDoubleGradKernel
<
paddle
::
platform
::
CUDADeviceContext
,
int64_t
>
,
ops
::
ElementwiseSubDoubleGradKernel
<
paddle
::
platform
::
CUDADeviceContext
,
paddle
::
platform
::
bfloat16
>
,
ops
::
ElementwiseSubDoubleGradKernel
<
paddle
::
platform
::
CUDADeviceContext
,
paddle
::
platform
::
complex
<
float
>>
,
ops
::
ElementwiseSubDoubleGradKernel
<
paddle
::
platform
::
CUDADeviceContext
,
paddle
::
platform
::
complex
<
double
>>
);
paddle/fluid/operators/elementwise/elementwise_sub_op.h
已删除
100644 → 0
浏览文件 @
9e00395a
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/fluid/operators/elementwise/elementwise_op.h"
#include "paddle/fluid/platform/place.h"
#include "paddle/phi/kernels/elementwise_grad_kernel.h"
#include "paddle/phi/kernels/math_kernel.h"
namespace
paddle
{
namespace
operators
{
template
<
typename
DeviceContext
,
typename
T
>
class
ElementwiseSubKernel
:
public
framework
::
OpKernel
<
T
>
{
public:
void
Compute
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
auto
*
x
=
ctx
.
Input
<
framework
::
LoDTensor
>
(
"X"
);
auto
*
y
=
ctx
.
Input
<
framework
::
LoDTensor
>
(
"Y"
);
auto
*
z
=
ctx
.
Output
<
framework
::
LoDTensor
>
(
"Out"
);
z
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
auto
&
dev_ctx
=
ctx
.
device_context
<
DeviceContext
>
();
int
axis
=
ctx
.
Attr
<
int
>
(
"axis"
);
phi
::
SubtractRawKernel
<
T
>
(
static_cast
<
const
typename
framework
::
ConvertToPhiContext
<
DeviceContext
>::
TYPE
&>
(
dev_ctx
),
*
x
,
*
y
,
axis
,
z
);
}
};
template
<
typename
DeviceContext
,
typename
T
>
class
ElementwiseSubGradKernel
:
public
ElemwiseGradKernel
<
T
>
{
public:
void
Compute
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
ElemwiseGradKernel
<
T
>::
Compute
(
ctx
);
using
Tensor
=
framework
::
Tensor
;
auto
*
x
=
ctx
.
Input
<
Tensor
>
(
"X"
);
auto
*
y
=
ctx
.
Input
<
Tensor
>
(
"Y"
);
auto
*
dout
=
ctx
.
Input
<
Tensor
>
(
framework
::
GradVarName
(
"Out"
));
auto
*
dx
=
ctx
.
Output
<
Tensor
>
(
framework
::
GradVarName
(
"X"
));
auto
*
dy
=
ctx
.
Output
<
Tensor
>
(
framework
::
GradVarName
(
"Y"
));
int
axis
=
ctx
.
Attr
<
int
>
(
"axis"
);
auto
&
dev_ctx
=
ctx
.
device_context
<
DeviceContext
>
();
phi
::
SubtractGradKernel
<
T
>
(
static_cast
<
const
typename
framework
::
ConvertToPhiContext
<
DeviceContext
>::
TYPE
&>
(
dev_ctx
),
*
x
,
*
y
,
*
dout
,
axis
,
dx
,
dy
);
}
};
template
<
typename
DeviceContext
,
typename
T
>
class
ElementwiseSubDoubleGradKernel
:
public
framework
::
OpKernel
<
T
>
{
public:
void
Compute
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
using
Tensor
=
framework
::
Tensor
;
auto
*
y
=
ctx
.
Input
<
Tensor
>
(
"Y"
);
auto
*
dout
=
ctx
.
Input
<
Tensor
>
(
"DOut"
);
auto
*
ddx
=
ctx
.
Input
<
Tensor
>
(
"DDX"
);
auto
*
ddy
=
ctx
.
Input
<
Tensor
>
(
"DDY"
);
auto
*
ddout
=
ctx
.
Output
<
Tensor
>
(
"DDOut"
);
int
axis
=
ctx
.
Attr
<
int
>
(
"axis"
);
auto
&
dev_ctx
=
ctx
.
device_context
<
DeviceContext
>
();
paddle
::
optional
<
const
phi
::
DenseTensor
&>
ddx_optional
=
paddle
::
none
;
paddle
::
optional
<
const
phi
::
DenseTensor
&>
ddy_optional
=
paddle
::
none
;
if
(
ddx
!=
nullptr
)
{
ddx_optional
=
*
ddx
;
}
if
(
ddy
!=
nullptr
)
{
ddy_optional
=
*
ddy
;
}
phi
::
SubtractDoubleGradKernel
<
T
>
(
static_cast
<
const
typename
framework
::
ConvertToPhiContext
<
DeviceContext
>::
TYPE
&>
(
dev_ctx
),
*
y
,
ddx_optional
,
ddy_optional
,
*
dout
,
axis
,
ddout
);
}
};
}
// namespace operators
}
// namespace paddle
paddle/fluid/operators/elementwise/elementwise_sub_op_npu.cc
浏览文件 @
b53cdc9e
...
@@ -15,7 +15,7 @@ limitations under the License. */
...
@@ -15,7 +15,7 @@ limitations under the License. */
#include <memory>
#include <memory>
#include <string>
#include <string>
#include "paddle/fluid/operators/elementwise/elementwise_
sub_
op.h"
#include "paddle/fluid/operators/elementwise/elementwise_op.h"
#include "paddle/fluid/platform/device/npu/npu_op_runner.h"
#include "paddle/fluid/platform/device/npu/npu_op_runner.h"
namespace
paddle
{
namespace
paddle
{
...
...
paddle/fluid/operators/elementwise/elementwise_sub_op_xpu.cc
浏览文件 @
b53cdc9e
...
@@ -13,7 +13,6 @@ See the License for the specific language governing permissions and
...
@@ -13,7 +13,6 @@ See the License for the specific language governing permissions and
limitations under the License. */
limitations under the License. */
#ifdef PADDLE_WITH_XPU
#ifdef PADDLE_WITH_XPU
#include "paddle/fluid/operators/elementwise/elementwise_sub_op.h"
#include "paddle/fluid/operators/elementwise/elementwise_op.h"
#include "paddle/fluid/operators/elementwise/elementwise_op.h"
#include "paddle/fluid/operators/elementwise/elementwise_xpu.h"
#include "paddle/fluid/operators/elementwise/elementwise_xpu.h"
#include "xpu/refactor/math.h"
#include "xpu/refactor/math.h"
...
...
paddle/fluid/operators/tensorrt/tensorrt_engine_op.h
浏览文件 @
b53cdc9e
...
@@ -79,6 +79,28 @@ static void RuntimeStaticShapeCheck(std::vector<int64_t> runtime_input_shape,
...
@@ -79,6 +79,28 @@ static void RuntimeStaticShapeCheck(std::vector<int64_t> runtime_input_shape,
model_input_shape_str
,
runtime_input_shape_str
));
model_input_shape_str
,
runtime_input_shape_str
));
}
}
static
paddle
::
experimental
::
DataType
TRT2FluidDataType
(
nvinfer1
::
DataType
type
)
{
switch
(
type
)
{
case
nvinfer1
::
DataType
::
kFLOAT
:
return
paddle
::
experimental
::
DataType
::
FLOAT32
;
case
nvinfer1
::
DataType
::
kINT32
:
return
paddle
::
experimental
::
DataType
::
INT32
;
case
nvinfer1
::
DataType
::
kHALF
:
return
paddle
::
experimental
::
DataType
::
FLOAT16
;
case
nvinfer1
::
DataType
::
kINT8
:
return
paddle
::
experimental
::
DataType
::
INT8
;
#if IS_TRT_VERSION_GE(7000)
case
nvinfer1
::
DataType
::
kBOOL
:
return
paddle
::
experimental
::
DataType
::
BOOL
;
#endif
default:
PADDLE_THROW
(
platform
::
errors
::
InvalidArgument
(
"unknown fluid datatype in Fluid op converter"
));
return
paddle
::
experimental
::
DataType
::
FLOAT32
;
}
}
static
void
RuntimeDynamicShapeCheck
(
static
void
RuntimeDynamicShapeCheck
(
const
std
::
string
&
x
,
const
std
::
vector
<
int32_t
>
&
runtime_input_shape
,
const
std
::
string
&
x
,
const
std
::
vector
<
int32_t
>
&
runtime_input_shape
,
const
std
::
vector
<
int32_t
>
&
min_input_shape
,
const
std
::
vector
<
int32_t
>
&
min_input_shape
,
...
@@ -520,9 +542,12 @@ class TensorRTEngineOp : public framework::OperatorBase {
...
@@ -520,9 +542,12 @@ class TensorRTEngineOp : public framework::OperatorBase {
buffers
[
bind_index
]
=
static_cast
<
void
*>
(
t
.
data
<
int64_t
>
());
buffers
[
bind_index
]
=
static_cast
<
void
*>
(
t
.
data
<
int64_t
>
());
}
else
if
(
type
==
framework
::
proto
::
VarType
::
INT32
)
{
}
else
if
(
type
==
framework
::
proto
::
VarType
::
INT32
)
{
buffers
[
bind_index
]
=
static_cast
<
void
*>
(
t
.
data
<
int32_t
>
());
buffers
[
bind_index
]
=
static_cast
<
void
*>
(
t
.
data
<
int32_t
>
());
}
else
if
(
type
==
framework
::
proto
::
VarType
::
FP16
)
{
buffers
[
bind_index
]
=
static_cast
<
void
*>
(
t
.
data
<
float16
>
());
}
else
{
}
else
{
PADDLE_THROW
(
platform
::
errors
::
Fatal
(
PADDLE_THROW
(
"The TRT Engine OP only support float/int32_t/int64_t input."
));
platform
::
errors
::
Fatal
(
"The TRT Engine OP only support "
"float/int32_t/int64_t/float16 input."
));
}
}
}
}
...
@@ -570,9 +595,10 @@ class TensorRTEngineOp : public framework::OperatorBase {
...
@@ -570,9 +595,10 @@ class TensorRTEngineOp : public framework::OperatorBase {
"than the number of bindings, but got binding "
"than the number of bindings, but got binding "
"index = %d, number of bindings = %d."
,
"index = %d, number of bindings = %d."
,
bind_index
,
num_bindings
));
bind_index
,
num_bindings
));
buffers
[
bind_index
]
=
auto
trt_type
=
engine
->
engine
()
->
getBindingDataType
(
bind_index
);
static_cast
<
void
*>
(
fluid_t
->
mutable_data
<
float
>
(
dev_place
));
// get adr and set type
buffers
[
bind_index
]
=
static_cast
<
void
*>
(
fluid_t
->
mutable_data
(
dev_place
,
TRT2FluidDataType
(
trt_type
)));
output_index
+=
1
;
output_index
+=
1
;
}
}
...
...
paddle/fluid/platform/CMakeLists.txt
浏览文件 @
b53cdc9e
...
@@ -117,7 +117,7 @@ endif()
...
@@ -117,7 +117,7 @@ endif()
cc_library
(
cudnn_workspace_helper SRCS cudnn_workspace_helper.cc DEPS boost
)
cc_library
(
cudnn_workspace_helper SRCS cudnn_workspace_helper.cc DEPS boost
)
# seperate init from device_context to avoid cycle dependencies
# seperate init from device_context to avoid cycle dependencies
cc_library
(
init SRCS init.cc DEPS device_context custom_kernel
)
cc_library
(
init SRCS init.cc DEPS device_context
phi_
custom_kernel
)
# memcpy depends on device_context, here add deps individually for
# memcpy depends on device_context, here add deps individually for
# avoiding cycle dependencies
# avoiding cycle dependencies
...
...
paddle/fluid/platform/device/CMakeLists.txt
浏览文件 @
b53cdc9e
IF
(
WITH_CUSTOM_DEVICE
)
cc_library
(
callback_manager SRCS callback_manager.cc DEPS enforce place
)
cc_library
(
device_guard SRCS device_guard.cc DEPS enforce place
)
cc_library
(
stream SRCS stream.cc DEPS callback_manager
)
cc_library
(
event SRCS event.cc DEPS enforce place
)
cc_library
(
device_base SRCS device_base.cc DEPS stream event callback_manager device_guard device_context flags
)
ENDIF
()
set
(
DEV_LIBS custom_device
)
set
(
DEV_LIBS custom_device
)
...
@@ -37,11 +25,3 @@ ENDIF()
...
@@ -37,11 +25,3 @@ ENDIF()
IF
(
WITH_MLU
)
IF
(
WITH_MLU
)
add_subdirectory
(
mlu
)
add_subdirectory
(
mlu
)
ENDIF
()
ENDIF
()
# CUSTOM
IF
(
WITH_CUSTOM_DEVICE
)
add_subdirectory
(
custom
)
cc_library
(
device_manager SRCS device_manager.cc DEPS custom_device
)
set
(
GLOB_DEV_LIB device_manager custom_device CACHE INTERNAL
"Global DEV library"
)
ENDIF
()
paddle/fluid/platform/device/custom/CMakeLists.txt
已删除
100644 → 0
浏览文件 @
9e00395a
IF
(
WITH_CUSTOM_DEVICE
)
cc_library
(
custom_device SRCS custom_device.cc DEPS device_base device_context
)
cc_test
(
custom_device_test SRCS custom_device_test.cc DEPS device_manager device_context
)
ENDIF
()
paddle/fluid/platform/device/custom/enforce_custom.h
浏览文件 @
b53cdc9e
...
@@ -14,7 +14,10 @@ limitations under the License. */
...
@@ -14,7 +14,10 @@ limitations under the License. */
#pragma once
#pragma once
#ifdef PADDLE_WITH_CUSTOM_DEVICE
#ifdef PADDLE_WITH_CUSTOM_DEVICE
#include "paddle/fluid/platform/device/device_ext.h"
#include <string>
#include "paddle/fluid/platform/enforce.h"
#include "paddle/phi/backends/device_ext.h"
namespace
paddle
{
namespace
paddle
{
namespace
platform
{
namespace
platform
{
...
...
paddle/fluid/platform/device/device_wrapper.h
浏览文件 @
b53cdc9e
...
@@ -40,10 +40,10 @@ limitations under the License. */
...
@@ -40,10 +40,10 @@ limitations under the License. */
#endif
#endif
#ifdef PADDLE_WITH_CUSTOM_DEVICE
#ifdef PADDLE_WITH_CUSTOM_DEVICE
#include "paddle/fluid/platform/device/callback_manager.h"
#include "paddle/fluid/platform/device/custom/enforce_custom.h"
#include "paddle/fluid/platform/device/custom/enforce_custom.h"
#include "paddle/fluid/platform/device/device_guard.h"
#include "paddle/phi/backends/callback_manager.h"
#include "paddle/fluid/platform/device/device_manager.h"
#include "paddle/phi/backends/device_guard.h"
#include "paddle/fluid/platform/device/event.h"
#include "paddle/phi/backends/device_manager.h"
#include "paddle/fluid/platform/device/stream.h"
#include "paddle/phi/backends/event.h"
#include "paddle/phi/backends/stream.h"
#endif
#endif
paddle/fluid/platform/device_context.cc
浏览文件 @
b53cdc9e
...
@@ -903,7 +903,7 @@ MKLDNNDeviceContext::BlobPtr_t<void> MKLDNNDeviceContext::GetBlob(
...
@@ -903,7 +903,7 @@ MKLDNNDeviceContext::BlobPtr_t<void> MKLDNNDeviceContext::GetBlob(
CustomDeviceContext
::
CustomDeviceContext
(
CustomPlace
place
)
CustomDeviceContext
::
CustomDeviceContext
(
CustomPlace
place
)
:
phi
::
CustomContext
(
place
)
{
:
phi
::
CustomContext
(
place
)
{
Init
();
Init
();
stream_
.
reset
(
new
p
latform
::
stream
::
Stream
(
place
,
stream
()));
stream_
.
reset
(
new
p
hi
::
stream
::
Stream
(
place
,
stream
()));
}
}
CustomDeviceContext
::~
CustomDeviceContext
()
{}
CustomDeviceContext
::~
CustomDeviceContext
()
{}
...
...
paddle/fluid/platform/device_context.h
浏览文件 @
b53cdc9e
...
@@ -72,8 +72,8 @@ limitations under the License. */
...
@@ -72,8 +72,8 @@ limitations under the License. */
#include "paddle/fluid/platform/device/npu/npu_stream.h"
#include "paddle/fluid/platform/device/npu/npu_stream.h"
#endif
#endif
#include "paddle/
fluid/platform/device
/device_ext.h"
#include "paddle/
phi/backends
/device_ext.h"
#include "paddle/
fluid/platform/device
/stream.h"
#include "paddle/
phi/backends
/stream.h"
#if !defined(PADDLE_WITH_XPU_KP) || defined(__xpu_on_host__)
#if !defined(PADDLE_WITH_XPU_KP) || defined(__xpu_on_host__)
#include "unsupported/Eigen/CXX11/Tensor"
#include "unsupported/Eigen/CXX11/Tensor"
...
@@ -838,7 +838,7 @@ class CustomDeviceContext : public phi::CustomContext {
...
@@ -838,7 +838,7 @@ class CustomDeviceContext : public phi::CustomContext {
void
WaitStreamCallback
()
const
{
return
stream_
->
WaitCallback
();
}
void
WaitStreamCallback
()
const
{
return
stream_
->
WaitCallback
();
}
private:
private:
std
::
shared_ptr
<
p
latform
::
stream
::
Stream
>
stream_
;
std
::
shared_ptr
<
p
hi
::
stream
::
Stream
>
stream_
;
};
};
template
<
>
template
<
>
struct
DefaultDeviceContextType
<
platform
::
CustomPlace
>
{
struct
DefaultDeviceContextType
<
platform
::
CustomPlace
>
{
...
...
paddle/fluid/platform/init.cc
浏览文件 @
b53cdc9e
...
@@ -55,7 +55,7 @@ limitations under the License. */
...
@@ -55,7 +55,7 @@ limitations under the License. */
#include "paddle/fluid/platform/device/ipu/ipu_info.h"
#include "paddle/fluid/platform/device/ipu/ipu_info.h"
#endif
#endif
#include "paddle/
fluid/framework
/custom_kernel.h"
#include "paddle/
phi/core
/custom_kernel.h"
DECLARE_int32
(
paddle_num_threads
);
DECLARE_int32
(
paddle_num_threads
);
PADDLE_DEFINE_EXPORTED_int32
(
PADDLE_DEFINE_EXPORTED_int32
(
...
@@ -145,7 +145,7 @@ void InitCupti() {
...
@@ -145,7 +145,7 @@ void InitCupti() {
#ifdef PADDLE_WITH_CUSTOM_DEVICE
#ifdef PADDLE_WITH_CUSTOM_DEVICE
void
LoadCustomDevice
(
const
std
::
string
&
library_dir
)
{
void
LoadCustomDevice
(
const
std
::
string
&
library_dir
)
{
LOG
(
INFO
)
<<
"Try loading custom device libs from: ["
<<
library_dir
<<
"]"
;
LOG
(
INFO
)
<<
"Try loading custom device libs from: ["
<<
library_dir
<<
"]"
;
std
::
vector
<
std
::
string
>
libs
=
p
latform
::
ListAllLibraries
(
library_dir
);
std
::
vector
<
std
::
string
>
libs
=
p
hi
::
ListAllLibraries
(
library_dir
);
for
(
const
auto
&
lib_path
:
libs
)
{
for
(
const
auto
&
lib_path
:
libs
)
{
auto
dso_handle
=
dlopen
(
lib_path
.
c_str
(),
RTLD_NOW
);
auto
dso_handle
=
dlopen
(
lib_path
.
c_str
(),
RTLD_NOW
);
PADDLE_ENFORCE_NOT_NULL
(
PADDLE_ENFORCE_NOT_NULL
(
...
@@ -153,8 +153,8 @@ void LoadCustomDevice(const std::string &library_dir) {
...
@@ -153,8 +153,8 @@ void LoadCustomDevice(const std::string &library_dir) {
platform
::
errors
::
InvalidArgument
(
platform
::
errors
::
InvalidArgument
(
"Fail to open library: %s with error: %s"
,
lib_path
,
dlerror
()));
"Fail to open library: %s with error: %s"
,
lib_path
,
dlerror
()));
p
latform
::
LoadCustomRuntimeLib
(
lib_path
,
dso_handle
);
p
hi
::
LoadCustomRuntimeLib
(
lib_path
,
dso_handle
);
framework
::
LoadCustomKernelLib
(
lib_path
,
dso_handle
);
phi
::
LoadCustomKernelLib
(
lib_path
,
dso_handle
);
}
}
LOG
(
INFO
)
<<
"Finished in LoadCustomDevice with libs_path: ["
<<
library_dir
LOG
(
INFO
)
<<
"Finished in LoadCustomDevice with libs_path: ["
<<
library_dir
<<
"]"
;
<<
"]"
;
...
@@ -259,9 +259,9 @@ void InitDevices(const std::vector<int> devices) {
...
@@ -259,9 +259,9 @@ void InitDevices(const std::vector<int> devices) {
LOG
(
INFO
)
<<
"ENV [CUSTOM_DEVICE_ROOT]="
<<
custom_kernel_root
;
LOG
(
INFO
)
<<
"ENV [CUSTOM_DEVICE_ROOT]="
<<
custom_kernel_root
;
LoadCustomDevice
(
custom_kernel_root
);
LoadCustomDevice
(
custom_kernel_root
);
auto
device_types
=
p
latform
::
DeviceManager
::
GetAllCustomDeviceTypes
();
auto
device_types
=
p
hi
::
DeviceManager
::
GetAllCustomDeviceTypes
();
for
(
auto
&
dev_type
:
device_types
)
{
for
(
auto
&
dev_type
:
device_types
)
{
auto
device_count
=
p
latform
::
DeviceManager
::
GetDeviceCount
(
dev_type
);
auto
device_count
=
p
hi
::
DeviceManager
::
GetDeviceCount
(
dev_type
);
LOG
(
INFO
)
<<
"CustomDevice: "
<<
dev_type
LOG
(
INFO
)
<<
"CustomDevice: "
<<
dev_type
<<
", visible devices count: "
<<
device_count
;
<<
", visible devices count: "
<<
device_count
;
for
(
size_t
i
=
0
;
i
<
device_count
;
i
++
)
{
for
(
size_t
i
=
0
;
i
<
device_count
;
i
++
)
{
...
...
paddle/fluid/pybind/eager_utils.cc
浏览文件 @
b53cdc9e
...
@@ -587,14 +587,9 @@ paddle::optional<paddle::experimental::Tensor> GetOptionalTensorFromArgs(
...
@@ -587,14 +587,9 @@ paddle::optional<paddle::experimental::Tensor> GetOptionalTensorFromArgs(
reinterpret_cast
<
TensorObject
*>
(
obj
)
->
tensor
);
reinterpret_cast
<
TensorObject
*>
(
obj
)
->
tensor
);
}
}
// For Intermediate State Dygraph,
static
paddle
::
experimental
::
Tensor
&
GetTensorFromPyObject
(
// we use an uninitialized Tensor to represent dispensable Tensor
const
std
::
string
&
op_type
,
const
std
::
string
&
arg_name
,
PyObject
*
obj
,
paddle
::
experimental
::
Tensor
&
GetTensorFromArgs
(
const
std
::
string
&
op_type
,
ssize_t
arg_idx
,
bool
dispensable
)
{
const
std
::
string
&
arg_name
,
PyObject
*
args
,
ssize_t
arg_idx
,
bool
dispensable
)
{
PyObject
*
obj
=
PyTuple_GET_ITEM
(
args
,
arg_idx
);
if
(
PyTuple_Check
(
obj
))
{
if
(
PyTuple_Check
(
obj
))
{
obj
=
PyTuple_GET_ITEM
(
obj
,
0
);
obj
=
PyTuple_GET_ITEM
(
obj
,
0
);
}
}
...
@@ -612,6 +607,16 @@ paddle::experimental::Tensor& GetTensorFromArgs(const std::string& op_type,
...
@@ -612,6 +607,16 @@ paddle::experimental::Tensor& GetTensorFromArgs(const std::string& op_type,
return
reinterpret_cast
<
TensorObject
*>
(
obj
)
->
tensor
;
return
reinterpret_cast
<
TensorObject
*>
(
obj
)
->
tensor
;
}
}
// For Intermediate State Dygraph,
// we use an uninitialized Tensor to represent dispensable Tensor
paddle
::
experimental
::
Tensor
&
GetTensorFromArgs
(
const
std
::
string
&
op_type
,
const
std
::
string
&
arg_name
,
PyObject
*
args
,
ssize_t
arg_idx
,
bool
dispensable
)
{
PyObject
*
obj
=
PyTuple_GET_ITEM
(
args
,
arg_idx
);
return
GetTensorFromPyObject
(
op_type
,
arg_name
,
obj
,
arg_idx
,
dispensable
);
}
std
::
vector
<
paddle
::
experimental
::
Tensor
>
GetTensorListFromArgs
(
std
::
vector
<
paddle
::
experimental
::
Tensor
>
GetTensorListFromArgs
(
const
std
::
string
&
op_type
,
const
std
::
string
&
arg_name
,
PyObject
*
args
,
const
std
::
string
&
op_type
,
const
std
::
string
&
arg_name
,
PyObject
*
args
,
ssize_t
arg_idx
,
bool
dispensable
)
{
ssize_t
arg_idx
,
bool
dispensable
)
{
...
@@ -746,5 +751,84 @@ std::vector<paddle::experimental::Tensor*> GetTensorPtrListFromArgs(
...
@@ -746,5 +751,84 @@ std::vector<paddle::experimental::Tensor*> GetTensorPtrListFromArgs(
return
result
;
return
result
;
}
}
paddle
::
experimental
::
Scalar
CastPyArg2Scalar
(
PyObject
*
obj
,
const
std
::
string
&
op_type
,
ssize_t
arg_pos
)
{
if
(
obj
==
Py_None
)
{
PADDLE_THROW
(
platform
::
errors
::
InvalidArgument
(
"%s(): argument (position %d) must be "
"bool, but got %s"
,
op_type
,
arg_pos
+
1
,
((
PyTypeObject
*
)
obj
->
ob_type
)
->
tp_name
));
// NOLINT
}
// obj could be: int, float, bool, paddle.Tensor
PyTypeObject
*
type
=
obj
->
ob_type
;
auto
type_name
=
std
::
string
(
type
->
tp_name
);
if
(
type_name
==
"int"
)
{
int
value
=
CastPyArg2Int
(
obj
,
op_type
,
arg_pos
);
return
paddle
::
experimental
::
Scalar
(
value
);
}
else
if
(
type_name
==
"float"
)
{
float
value
=
CastPyArg2Float
(
obj
,
op_type
,
arg_pos
);
return
paddle
::
experimental
::
Scalar
(
value
);
}
else
if
(
type_name
==
"bool"
)
{
bool
value
=
CastPyArg2Boolean
(
obj
,
op_type
,
arg_pos
);
return
paddle
::
experimental
::
Scalar
(
value
);
}
else
if
(
type_name
==
"paddle.Tensor"
)
{
paddle
::
experimental
::
Tensor
&
value
=
GetTensorFromPyObject
(
op_type
,
""
/*arg_name*/
,
obj
,
arg_pos
,
false
/*dispensable*/
);
return
paddle
::
experimental
::
Scalar
(
value
);
}
else
{
PADDLE_THROW
(
platform
::
errors
::
InvalidArgument
(
"%s(): argument (position %d) must be "
"bool, but got %s"
,
op_type
,
arg_pos
+
1
,
((
PyTypeObject
*
)
obj
->
ob_type
)
->
tp_name
));
// NOLINT
}
// Fake a Scalar
return
paddle
::
experimental
::
Scalar
(
1.0
);
}
paddle
::
experimental
::
ScalarArray
CastPyArg2ScalarArray
(
PyObject
*
obj
,
const
std
::
string
&
op_type
,
ssize_t
arg_pos
)
{
// In case of ScalarArray, only two possible PyObjects:
// 1. list of int
// 2. Tensor
if
(
obj
==
Py_None
)
{
PADDLE_THROW
(
platform
::
errors
::
InvalidArgument
(
"%s(): argument (position %d) must be "
"bool, but got %s"
,
op_type
,
arg_pos
+
1
,
((
PyTypeObject
*
)
obj
->
ob_type
)
->
tp_name
));
// NOLINT
}
// obj could be: int, float, bool, paddle.Tensor
PyTypeObject
*
type
=
obj
->
ob_type
;
auto
type_name
=
std
::
string
(
type
->
tp_name
);
if
(
type_name
==
"list"
)
{
std
::
vector
<
int
>
value
=
CastPyArg2Ints
(
obj
,
op_type
,
arg_pos
);
return
paddle
::
experimental
::
ScalarArray
(
value
);
}
else
if
(
type_name
==
"paddle.Tensor"
)
{
paddle
::
experimental
::
Tensor
&
value
=
GetTensorFromPyObject
(
op_type
,
""
/*arg_name*/
,
obj
,
arg_pos
,
false
/*dispensable*/
);
return
paddle
::
experimental
::
ScalarArray
(
value
);
}
else
{
PADDLE_THROW
(
platform
::
errors
::
InvalidArgument
(
"%s(): argument (position %d) must be "
"bool, but got %s"
,
op_type
,
arg_pos
+
1
,
((
PyTypeObject
*
)
obj
->
ob_type
)
->
tp_name
));
// NOLINT
}
// Fake a ScalarArray
return
paddle
::
experimental
::
ScalarArray
({
1
});
}
}
// namespace pybind
}
// namespace pybind
}
// namespace paddle
}
// namespace paddle
paddle/fluid/pybind/eager_utils.h
浏览文件 @
b53cdc9e
...
@@ -11,7 +11,10 @@ limitations under the License. */
...
@@ -11,7 +11,10 @@ limitations under the License. */
#pragma once
#pragma once
#include <Python.h>
#include <Python.h>
#include "paddle/phi/common/scalar.h"
#include "paddle/phi/common/scalar_array.h"
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/dense_tensor.h"
#include "pybind11/pybind11.h"
#include "pybind11/pybind11.h"
#include "pybind11/stl.h"
#include "pybind11/stl.h"
namespace
paddle
{
namespace
paddle
{
...
@@ -90,6 +93,13 @@ PyObject* ToPyObject(const std::tuple<Args...>& out) {
...
@@ -90,6 +93,13 @@ PyObject* ToPyObject(const std::tuple<Args...>& out) {
return
result
;
return
result
;
}
}
paddle
::
experimental
::
Scalar
CastPyArg2Scalar
(
PyObject
*
obj
,
const
std
::
string
&
op_type
,
ssize_t
arg_pos
);
paddle
::
experimental
::
ScalarArray
CastPyArg2ScalarArray
(
PyObject
*
obj
,
const
std
::
string
&
op_type
,
ssize_t
arg_pos
);
paddle
::
optional
<
paddle
::
experimental
::
Tensor
>
GetOptionalTensorFromArgs
(
paddle
::
optional
<
paddle
::
experimental
::
Tensor
>
GetOptionalTensorFromArgs
(
const
std
::
string
&
op_type
,
const
std
::
string
&
arg_name
,
PyObject
*
args
,
const
std
::
string
&
op_type
,
const
std
::
string
&
arg_name
,
PyObject
*
args
,
ssize_t
arg_idx
,
bool
dispensable
=
false
);
ssize_t
arg_idx
,
bool
dispensable
=
false
);
...
...
paddle/fluid/pybind/pybind.cc
浏览文件 @
b53cdc9e
...
@@ -1668,7 +1668,7 @@ All parameter, weight, gradient are variables in Paddle.
...
@@ -1668,7 +1668,7 @@ All parameter, weight, gradient are variables in Paddle.
m
.
def
(
"get_all_device_type"
,
[]()
{
m
.
def
(
"get_all_device_type"
,
[]()
{
std
::
vector
<
std
::
string
>
device_types
;
std
::
vector
<
std
::
string
>
device_types
;
#ifdef PADDLE_WITH_CUSTOM_DEVICE
#ifdef PADDLE_WITH_CUSTOM_DEVICE
device_types
=
p
latform
::
DeviceManager
::
GetAllDeviceTypes
();
device_types
=
p
hi
::
DeviceManager
::
GetAllDeviceTypes
();
#else
#else
LOG
(
WARNING
)
<<
string
::
Sprintf
(
LOG
(
WARNING
)
<<
string
::
Sprintf
(
"Cannot use get_all_device_type because you have installed"
"Cannot use get_all_device_type because you have installed"
...
@@ -1682,7 +1682,7 @@ All parameter, weight, gradient are variables in Paddle.
...
@@ -1682,7 +1682,7 @@ All parameter, weight, gradient are variables in Paddle.
m
.
def
(
"get_all_custom_device_type"
,
[]()
{
m
.
def
(
"get_all_custom_device_type"
,
[]()
{
std
::
vector
<
std
::
string
>
device_types
;
std
::
vector
<
std
::
string
>
device_types
;
#ifdef PADDLE_WITH_CUSTOM_DEVICE
#ifdef PADDLE_WITH_CUSTOM_DEVICE
device_types
=
p
latform
::
DeviceManager
::
GetAllCustomDeviceTypes
();
device_types
=
p
hi
::
DeviceManager
::
GetAllCustomDeviceTypes
();
#else
#else
LOG
(
WARNING
)
<<
string
::
Sprintf
(
LOG
(
WARNING
)
<<
string
::
Sprintf
(
"Cannot use get_all_custom_device_type because you have installed"
"Cannot use get_all_custom_device_type because you have installed"
...
@@ -1696,7 +1696,7 @@ All parameter, weight, gradient are variables in Paddle.
...
@@ -1696,7 +1696,7 @@ All parameter, weight, gradient are variables in Paddle.
m
.
def
(
"get_available_device"
,
[]
{
m
.
def
(
"get_available_device"
,
[]
{
std
::
vector
<
std
::
string
>
devices
;
std
::
vector
<
std
::
string
>
devices
;
#ifdef PADDLE_WITH_CUSTOM_DEVICE
#ifdef PADDLE_WITH_CUSTOM_DEVICE
devices
=
p
latform
::
DeviceManager
::
GetAllDeviceList
();
devices
=
p
hi
::
DeviceManager
::
GetAllDeviceList
();
#else
#else
LOG
(
WARNING
)
<<
string
::
Sprintf
(
LOG
(
WARNING
)
<<
string
::
Sprintf
(
"Cannot use get_available_device because you have installed"
"Cannot use get_available_device because you have installed"
...
@@ -1710,7 +1710,7 @@ All parameter, weight, gradient are variables in Paddle.
...
@@ -1710,7 +1710,7 @@ All parameter, weight, gradient are variables in Paddle.
m
.
def
(
"get_available_custom_device"
,
[]
{
m
.
def
(
"get_available_custom_device"
,
[]
{
std
::
vector
<
std
::
string
>
devices
;
std
::
vector
<
std
::
string
>
devices
;
#ifdef PADDLE_WITH_CUSTOM_DEVICE
#ifdef PADDLE_WITH_CUSTOM_DEVICE
devices
=
p
latform
::
DeviceManager
::
GetAllCustomDeviceList
();
devices
=
p
hi
::
DeviceManager
::
GetAllCustomDeviceList
();
#else
#else
LOG
(
WARNING
)
<<
string
::
Sprintf
(
LOG
(
WARNING
)
<<
string
::
Sprintf
(
"Cannot use get_available_custom_device because you have "
"Cannot use get_available_custom_device because you have "
...
@@ -1747,10 +1747,10 @@ All parameter, weight, gradient are variables in Paddle.
...
@@ -1747,10 +1747,10 @@ All parameter, weight, gradient are variables in Paddle.
std
::
exit
(
-
1
);
std
::
exit
(
-
1
);
}
}
if
(
LIKELY
(
p
latform
::
DeviceManager
::
HasDeviceType
(
device_type
)
&&
if
(
LIKELY
(
p
hi
::
DeviceManager
::
HasDeviceType
(
device_type
)
&&
p
latform
::
DeviceManager
::
IsCustom
(
device_type
)))
{
p
hi
::
DeviceManager
::
IsCustom
(
device_type
)))
{
int
dev_count
=
static_cast
<
int
>
(
int
dev_count
=
static_cast
<
int
>
(
p
latform
::
DeviceManager
::
GetDeviceCount
(
device_type
));
p
hi
::
DeviceManager
::
GetDeviceCount
(
device_type
));
if
(
UNLIKELY
(
dev_id
>=
dev_count
))
{
if
(
UNLIKELY
(
dev_id
>=
dev_count
))
{
if
(
dev_count
==
0
)
{
if
(
dev_count
==
0
)
{
LOG
(
ERROR
)
<<
"Cannot use "
<<
device_type
LOG
(
ERROR
)
<<
"Cannot use "
<<
device_type
...
...
paddle/fluid/pybind/tensor_py.h
浏览文件 @
b53cdc9e
...
@@ -393,10 +393,10 @@ void SetTensorFromPyArrayT(
...
@@ -393,10 +393,10 @@ void SetTensorFromPyArrayT(
}
else
if
(
paddle
::
platform
::
is_custom_place
(
place
))
{
}
else
if
(
paddle
::
platform
::
is_custom_place
(
place
))
{
#ifdef PADDLE_WITH_CUSTOM_DEVICE
#ifdef PADDLE_WITH_CUSTOM_DEVICE
platform
::
Place
tmp_place
=
place
;
platform
::
Place
tmp_place
=
place
;
p
latform
::
DeviceGuard
guard
(
tmp_place
);
p
hi
::
DeviceGuard
guard
(
tmp_place
);
auto
dst
=
self
->
mutable_data
<
T
>
(
place
);
auto
dst
=
self
->
mutable_data
<
T
>
(
place
);
p
latform
::
DeviceManager
::
GetDeviceWithPlace
(
tmp_place
)
->
MemoryCopyH2D
(
p
hi
::
DeviceManager
::
GetDeviceWithPlace
(
tmp_place
)
->
MemoryCopyH2D
(
reinterpret_cast
<
void
*>
(
dst
),
reinterpret_cast
<
void
*>
(
dst
),
const_cast
<
void
*>
(
reinterpret_cast
<
const
void
*>
(
array
.
data
())),
const_cast
<
void
*>
(
reinterpret_cast
<
const
void
*>
(
array
.
data
())),
array
.
nbytes
());
array
.
nbytes
());
...
...
paddle/phi/backends/CMakeLists.txt
浏览文件 @
b53cdc9e
...
@@ -24,4 +24,11 @@ endif()
...
@@ -24,4 +24,11 @@ endif()
if
(
WITH_CUSTOM_DEVICE
)
if
(
WITH_CUSTOM_DEVICE
)
add_dependencies
(
phi_context custom_context
)
add_dependencies
(
phi_context custom_context
)
cc_library
(
callback_manager SRCS callback_manager.cc DEPS enforce place
)
cc_library
(
device_guard SRCS device_guard.cc DEPS enforce place
)
cc_library
(
stream SRCS stream.cc DEPS callback_manager
)
cc_library
(
event SRCS event.cc DEPS enforce place
)
cc_library
(
device_base SRCS device_base.cc DEPS stream event callback_manager device_guard device_context flags
)
cc_library
(
device_manager SRCS device_manager.cc DEPS custom_device
)
set
(
GLOB_DEV_LIB device_manager custom_device CACHE INTERNAL
"Global DEV library"
)
endif
()
endif
()
paddle/
fluid/platform/device
/callback_manager.cc
→
paddle/
phi/backends
/callback_manager.cc
浏览文件 @
b53cdc9e
...
@@ -12,12 +12,11 @@
...
@@ -12,12 +12,11 @@
// See the License for the specific language governing permissions and
// See the License for the specific language governing permissions and
// limitations under the License.
// limitations under the License.
#include "paddle/
fluid/platform/device
/callback_manager.h"
#include "paddle/
phi/backends
/callback_manager.h"
#include "paddle/fluid/platform/device/device_wrapper.h"
#include "paddle/fluid/platform/device/device_wrapper.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/enforce.h"
namespace
paddle
{
namespace
phi
{
namespace
platform
{
CallbackManager
::
CallbackManager
(
stream
::
Stream
*
stream
)
CallbackManager
::
CallbackManager
(
stream
::
Stream
*
stream
)
:
stream_
(
stream
),
thread_pool_
(
1
)
{}
:
stream_
(
stream
),
thread_pool_
(
1
)
{}
...
@@ -32,12 +31,12 @@ void CallbackManager::AddCallback(std::function<void()> callback) const {
...
@@ -32,12 +31,12 @@ void CallbackManager::AddCallback(std::function<void()> callback) const {
});
});
});
});
p
latform
::
DeviceManager
::
GetDeviceWithPlace
(
stream_
->
GetPlace
())
p
hi
::
DeviceManager
::
GetDeviceWithPlace
(
stream_
->
GetPlace
())
->
AddCallback
(
stream_
,
func
);
->
AddCallback
(
stream_
,
func
);
}
}
void
CallbackManager
::
Wait
()
const
{
void
CallbackManager
::
Wait
()
const
{
p
latform
::
DeviceManager
::
GetDeviceWithPlace
(
stream_
->
GetPlace
())
p
hi
::
DeviceManager
::
GetDeviceWithPlace
(
stream_
->
GetPlace
())
->
SynchronizeStream
(
stream_
);
->
SynchronizeStream
(
stream_
);
{
{
...
@@ -48,5 +47,4 @@ void CallbackManager::Wait() const {
...
@@ -48,5 +47,4 @@ void CallbackManager::Wait() const {
}
}
}
}
}
// namespace platform
}
// namespace phi
}
// namespace paddle
paddle/
fluid/platform/device
/callback_manager.h
→
paddle/
phi/backends
/callback_manager.h
浏览文件 @
b53cdc9e
...
@@ -32,8 +32,7 @@
...
@@ -32,8 +32,7 @@
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/enforce.h"
namespace
paddle
{
namespace
phi
{
namespace
platform
{
namespace
stream
{
namespace
stream
{
class
Stream
;
class
Stream
;
...
@@ -58,5 +57,4 @@ class CallbackManager {
...
@@ -58,5 +57,4 @@ class CallbackManager {
mutable
std
::
future
<
void
>
last_future_
;
mutable
std
::
future
<
void
>
last_future_
;
};
};
}
// namespace platform
}
// namespace phi
}
// namespace paddle
paddle/phi/backends/custom/CMakeLists.txt
浏览文件 @
b53cdc9e
if
(
WITH_CUSTOM_DEVICE
)
if
(
WITH_CUSTOM_DEVICE
)
cc_library
(
custom_context SRCS custom_context.cc DEPS phi_device_context device_manager
)
cc_library
(
custom_context SRCS custom_context.cc DEPS phi_device_context device_manager
)
cc_library
(
custom_device SRCS custom_device.cc DEPS device_base device_context
)
cc_test
(
custom_device_test SRCS custom_device_test.cc DEPS device_manager device_context
)
endif
()
endif
()
paddle/phi/backends/custom/custom_context.cc
浏览文件 @
b53cdc9e
...
@@ -14,8 +14,8 @@ limitations under the License. */
...
@@ -14,8 +14,8 @@ limitations under the License. */
#include "paddle/phi/backends/custom/custom_context.h"
#include "paddle/phi/backends/custom/custom_context.h"
#include "paddle/
fluid/platform/device
/device_guard.h"
#include "paddle/
phi/backends
/device_guard.h"
#include "paddle/
fluid/platform/device
/stream.h"
#include "paddle/
phi/backends
/stream.h"
namespace
phi
{
namespace
phi
{
...
@@ -25,8 +25,8 @@ struct CustomContext::Impl {
...
@@ -25,8 +25,8 @@ struct CustomContext::Impl {
~
Impl
()
{}
~
Impl
()
{}
void
Init
()
{
void
Init
()
{
p
addle
::
platform
::
DeviceGuard
guard
(
place_
);
p
hi
::
DeviceGuard
guard
(
place_
);
stream_
.
reset
(
new
p
addle
::
platform
::
stream
::
Stream
());
stream_
.
reset
(
new
p
hi
::
stream
::
Stream
());
stream_
->
Init
(
place_
);
stream_
->
Init
(
place_
);
}
}
...
@@ -40,7 +40,7 @@ struct CustomContext::Impl {
...
@@ -40,7 +40,7 @@ struct CustomContext::Impl {
Place
place_
;
Place
place_
;
std
::
shared_ptr
<
p
addle
::
platform
::
stream
::
Stream
>
stream_
;
std
::
shared_ptr
<
p
hi
::
stream
::
Stream
>
stream_
;
};
};
void
CustomContext
::
Init
()
{
impl_
->
Init
();
}
void
CustomContext
::
Init
()
{
impl_
->
Init
();
}
...
...
paddle/
fluid/platform/device
/custom/custom_device.cc
→
paddle/
phi/backends
/custom/custom_device.cc
浏览文件 @
b53cdc9e
...
@@ -12,23 +12,28 @@
...
@@ -12,23 +12,28 @@
// See the License for the specific language governing permissions and
// See the License for the specific language governing permissions and
// limitations under the License.
// limitations under the License.
#include "paddle/fluid/platform/device/device_base.h"
#include "paddle/fluid/platform/device/custom/enforce_custom.h"
#include "paddle/fluid/platform/device/device_wrapper.h"
#include "paddle/fluid/platform/device/event.h"
#include "paddle/fluid/platform/device/stream.h"
#include "paddle/fluid/platform/device_context.h"
#include "paddle/fluid/platform/device_context.h"
#include "paddle/phi/backends/callback_manager.h"
#include "paddle/phi/backends/device_base.h"
#include "paddle/phi/backends/device_guard.h"
#include "paddle/phi/backends/device_manager.h"
#include "paddle/phi/backends/event.h"
#include "paddle/phi/backends/stream.h"
static
bool
operator
==
(
const
C_Device_st
&
d1
,
const
C_Device_st
&
d2
)
{
static
bool
operator
==
(
const
C_Device_st
&
d1
,
const
C_Device_st
&
d2
)
{
return
d1
.
id
==
d2
.
id
;
return
d1
.
id
==
d2
.
id
;
}
}
namespace
paddle
{
namespace
phi
{
namespace
platform
{
class
CustomDevice
:
public
DeviceInterface
{
class
CustomDevice
:
public
DeviceInterface
{
public:
public:
CustomDevice
(
const
std
::
string
&
type
,
int
priority
,
bool
is_custom
,
CustomDevice
(
const
std
::
string
&
type
,
std
::
unique_ptr
<
C_DeviceInterface
>
pimpl
,
void
*
dso_handle
)
int
priority
,
bool
is_custom
,
std
::
unique_ptr
<
C_DeviceInterface
>
pimpl
,
void
*
dso_handle
)
:
DeviceInterface
(
type
,
priority
,
is_custom
),
:
DeviceInterface
(
type
,
priority
,
is_custom
),
pimpl_
(
std
::
move
(
pimpl
)),
pimpl_
(
std
::
move
(
pimpl
)),
dso_handle_
(
dso_handle
)
{
dso_handle_
(
dso_handle
)
{
...
@@ -122,14 +127,15 @@ class CustomDevice : public DeviceInterface {
...
@@ -122,14 +127,15 @@ class CustomDevice : public DeviceInterface {
return
device
.
id
;
return
device
.
id
;
}
}
void
CreateStream
(
size_t
dev_id
,
stream
::
Stream
*
stream
,
void
CreateStream
(
size_t
dev_id
,
stream
::
Stream
*
stream
,
const
stream
::
Stream
::
Priority
&
priority
=
const
stream
::
Stream
::
Priority
&
priority
=
stream
::
Stream
::
Priority
::
kNormal
,
stream
::
Stream
::
Priority
::
kNormal
,
const
stream
::
Stream
::
Flag
&
flag
=
const
stream
::
Stream
::
Flag
&
flag
=
stream
::
Stream
::
Flag
::
kDefaultFlag
)
override
{
stream
::
Stream
::
Flag
::
kDefaultFlag
)
override
{
if
(
priority
!=
stream
::
Stream
::
Priority
::
kNormal
||
if
(
priority
!=
stream
::
Stream
::
Priority
::
kNormal
||
flag
!=
stream
::
Stream
::
Flag
::
kDefaultFlag
)
{
flag
!=
stream
::
Stream
::
Flag
::
kDefaultFlag
)
{
PADDLE_THROW
(
p
latform
::
errors
::
Unavailable
(
PADDLE_THROW
(
p
hi
::
errors
::
Unavailable
(
"priority != stream::Stream::Priority::kNormal || flag != "
"priority != stream::Stream::Priority::kNormal || flag != "
"stream::Stream::Flag::kDefaultFlag is not allowed on "
"stream::Stream::Flag::kDefaultFlag is not allowed on "
"CustomDevice."
));
"CustomDevice."
));
...
@@ -162,23 +168,28 @@ class CustomDevice : public DeviceInterface {
...
@@ -162,23 +168,28 @@ class CustomDevice : public DeviceInterface {
SynchronizeStream
(
dev_id
,
stream
);
SynchronizeStream
(
dev_id
,
stream
);
return
true
;
return
true
;
}
}
if
(
pimpl_
->
query_stream
(
device
,
reinterpret_cast
<
C_Stream
>
(
if
(
pimpl_
->
query_stream
(
stream
->
raw_stream
()))
==
C_SUCCESS
)
{
device
,
reinterpret_cast
<
C_Stream
>
(
stream
->
raw_stream
()))
==
C_SUCCESS
)
{
return
true
;
return
true
;
}
}
return
false
;
return
false
;
}
}
void
AddCallback
(
size_t
dev_id
,
stream
::
Stream
*
stream
,
void
AddCallback
(
size_t
dev_id
,
stream
::
Stream
*
stream
,
stream
::
Stream
::
Callback
*
callback
)
override
{
stream
::
Stream
::
Callback
*
callback
)
override
{
if
(
!
pimpl_
->
stream_add_callback
)
{
if
(
!
pimpl_
->
stream_add_callback
)
{
PADDLE_THROW
(
p
latform
::
errors
::
Unavailable
(
PADDLE_THROW
(
p
hi
::
errors
::
Unavailable
(
"AddCallback is not supported on %s."
,
Type
()));
"AddCallback is not supported on %s."
,
Type
()));
}
else
{
}
else
{
const
auto
device
=
&
devices_pool
[
dev_id
];
const
auto
device
=
&
devices_pool
[
dev_id
];
PADDLE_ENFORCE_CUSTOM_DEVICE_SUCCESS
(
pimpl_
->
stream_add_callback
(
PADDLE_ENFORCE_CUSTOM_DEVICE_SUCCESS
(
pimpl_
->
stream_add_callback
(
device
,
reinterpret_cast
<
C_Stream
>
(
stream
->
raw_stream
()),
device
,
[](
C_Device
device
,
C_Stream
stream
,
void
*
user_data
,
reinterpret_cast
<
C_Stream
>
(
stream
->
raw_stream
()),
[](
C_Device
device
,
C_Stream
stream
,
void
*
user_data
,
C_Status
*
status
)
{
C_Status
*
status
)
{
std
::
unique_ptr
<
std
::
function
<
void
()
>>
func
(
std
::
unique_ptr
<
std
::
function
<
void
()
>>
func
(
reinterpret_cast
<
std
::
function
<
void
()
>*>
(
user_data
));
reinterpret_cast
<
std
::
function
<
void
()
>*>
(
user_data
));
...
@@ -188,7 +199,8 @@ class CustomDevice : public DeviceInterface {
...
@@ -188,7 +199,8 @@ class CustomDevice : public DeviceInterface {
}
}
}
}
void
CreateEvent
(
size_t
dev_id
,
event
::
Event
*
event
,
void
CreateEvent
(
size_t
dev_id
,
event
::
Event
*
event
,
event
::
Event
::
Flag
flags
)
override
{
event
::
Event
::
Flag
flags
)
override
{
const
auto
device
=
&
devices_pool
[
dev_id
];
const
auto
device
=
&
devices_pool
[
dev_id
];
C_Event
c_event
;
C_Event
c_event
;
...
@@ -205,13 +217,15 @@ class CustomDevice : public DeviceInterface {
...
@@ -205,13 +217,15 @@ class CustomDevice : public DeviceInterface {
device
,
reinterpret_cast
<
C_Event
>
(
event
->
raw_event
())));
device
,
reinterpret_cast
<
C_Event
>
(
event
->
raw_event
())));
}
}
void
RecordEvent
(
size_t
dev_id
,
const
event
::
Event
*
event
,
void
RecordEvent
(
size_t
dev_id
,
const
event
::
Event
*
event
,
const
stream
::
Stream
*
stream
)
override
{
const
stream
::
Stream
*
stream
)
override
{
const
auto
device
=
&
devices_pool
[
dev_id
];
const
auto
device
=
&
devices_pool
[
dev_id
];
PADDLE_ENFORCE_CUSTOM_DEVICE_SUCCESS
(
pimpl_
->
record_event
(
PADDLE_ENFORCE_CUSTOM_DEVICE_SUCCESS
(
device
,
reinterpret_cast
<
C_Stream
>
(
stream
->
raw_stream
()),
pimpl_
->
record_event
(
device
,
reinterpret_cast
<
C_Event
>
(
event
->
raw_event
())));
reinterpret_cast
<
C_Stream
>
(
stream
->
raw_stream
()),
reinterpret_cast
<
C_Event
>
(
event
->
raw_event
())));
}
}
void
SynchronizeEvent
(
size_t
dev_id
,
const
event
::
Event
*
event
)
override
{
void
SynchronizeEvent
(
size_t
dev_id
,
const
event
::
Event
*
event
)
override
{
...
@@ -228,78 +242,93 @@ class CustomDevice : public DeviceInterface {
...
@@ -228,78 +242,93 @@ class CustomDevice : public DeviceInterface {
SynchronizeEvent
(
dev_id
,
event
);
SynchronizeEvent
(
dev_id
,
event
);
return
true
;
return
true
;
}
}
if
(
pimpl_
->
query_event
(
device
,
reinterpret_cast
<
C_Event
>
(
if
(
pimpl_
->
query_event
(
device
,
event
->
raw_event
()))
==
C_SUCCESS
)
{
reinterpret_cast
<
C_Event
>
(
event
->
raw_event
()))
==
C_SUCCESS
)
{
return
true
;
return
true
;
}
}
return
false
;
return
false
;
}
}
void
StreamWaitEvent
(
size_t
dev_id
,
const
stream
::
Stream
*
stream
,
void
StreamWaitEvent
(
size_t
dev_id
,
const
stream
::
Stream
*
stream
,
const
event
::
Event
*
event
)
override
{
const
event
::
Event
*
event
)
override
{
const
auto
device
=
&
devices_pool
[
dev_id
];
const
auto
device
=
&
devices_pool
[
dev_id
];
PADDLE_ENFORCE_CUSTOM_DEVICE_SUCCESS
(
pimpl_
->
stream_wait_event
(
PADDLE_ENFORCE_CUSTOM_DEVICE_SUCCESS
(
pimpl_
->
stream_wait_event
(
device
,
reinterpret_cast
<
C_Stream
>
(
stream
->
raw_stream
()),
device
,
reinterpret_cast
<
C_Stream
>
(
stream
->
raw_stream
()),
reinterpret_cast
<
C_Event
>
(
event
->
raw_event
())));
reinterpret_cast
<
C_Event
>
(
event
->
raw_event
())));
}
}
void
MemoryCopyH2D
(
size_t
dev_id
,
void
*
dst
,
const
void
*
src
,
size_t
size
,
void
MemoryCopyH2D
(
size_t
dev_id
,
void
*
dst
,
const
void
*
src
,
size_t
size
,
const
stream
::
Stream
*
stream
=
nullptr
)
override
{
const
stream
::
Stream
*
stream
=
nullptr
)
override
{
const
auto
device
=
&
devices_pool
[
dev_id
];
const
auto
device
=
&
devices_pool
[
dev_id
];
auto
place
=
platform
::
CustomPlace
(
Type
(),
dev_id
);
auto
place
=
CustomPlace
(
Type
(),
dev_id
);
if
(
stream
&&
stream
->
raw_stream
()
&&
pimpl_
->
async_memory_copy_h2d
)
{
if
(
stream
&&
stream
->
raw_stream
()
&&
pimpl_
->
async_memory_copy_h2d
)
{
C_Stream
c_stream
=
reinterpret_cast
<
C_Stream
>
(
stream
->
raw_stream
());
C_Stream
c_stream
=
reinterpret_cast
<
C_Stream
>
(
stream
->
raw_stream
());
PADDLE_ENFORCE_CUSTOM_DEVICE_SUCCESS
(
PADDLE_ENFORCE_CUSTOM_DEVICE_SUCCESS
(
pimpl_
->
async_memory_copy_h2d
(
device
,
c_stream
,
dst
,
src
,
size
));
pimpl_
->
async_memory_copy_h2d
(
device
,
c_stream
,
dst
,
src
,
size
));
}
else
{
}
else
{
platform
::
DeviceContextPool
&
pool
=
p
addle
::
p
latform
::
DeviceContextPool
&
pool
=
platform
::
DeviceContextPool
::
Instance
();
p
addle
::
p
latform
::
DeviceContextPool
::
Instance
();
pool
.
Get
(
place
)
->
Wait
();
pool
.
Get
(
place
)
->
Wait
();
PADDLE_ENFORCE_CUSTOM_DEVICE_SUCCESS
(
PADDLE_ENFORCE_CUSTOM_DEVICE_SUCCESS
(
pimpl_
->
memory_copy_h2d
(
device
,
dst
,
src
,
size
));
pimpl_
->
memory_copy_h2d
(
device
,
dst
,
src
,
size
));
}
}
}
}
void
MemoryCopyD2H
(
size_t
dev_id
,
void
*
dst
,
const
void
*
src
,
size_t
size
,
void
MemoryCopyD2H
(
size_t
dev_id
,
void
*
dst
,
const
void
*
src
,
size_t
size
,
const
stream
::
Stream
*
stream
=
nullptr
)
override
{
const
stream
::
Stream
*
stream
=
nullptr
)
override
{
const
auto
device
=
&
devices_pool
[
dev_id
];
const
auto
device
=
&
devices_pool
[
dev_id
];
auto
place
=
platform
::
CustomPlace
(
Type
(),
dev_id
);
auto
place
=
CustomPlace
(
Type
(),
dev_id
);
if
(
stream
&&
stream
->
raw_stream
()
&&
pimpl_
->
async_memory_copy_d2h
)
{
if
(
stream
&&
stream
->
raw_stream
()
&&
pimpl_
->
async_memory_copy_d2h
)
{
C_Stream
c_stream
=
reinterpret_cast
<
C_Stream
>
(
stream
->
raw_stream
());
C_Stream
c_stream
=
reinterpret_cast
<
C_Stream
>
(
stream
->
raw_stream
());
PADDLE_ENFORCE_CUSTOM_DEVICE_SUCCESS
(
PADDLE_ENFORCE_CUSTOM_DEVICE_SUCCESS
(
pimpl_
->
async_memory_copy_d2h
(
device
,
c_stream
,
dst
,
src
,
size
));
pimpl_
->
async_memory_copy_d2h
(
device
,
c_stream
,
dst
,
src
,
size
));
}
else
{
}
else
{
platform
::
DeviceContextPool
&
pool
=
p
addle
::
p
latform
::
DeviceContextPool
&
pool
=
platform
::
DeviceContextPool
::
Instance
();
p
addle
::
p
latform
::
DeviceContextPool
::
Instance
();
pool
.
Get
(
place
)
->
Wait
();
pool
.
Get
(
place
)
->
Wait
();
PADDLE_ENFORCE_CUSTOM_DEVICE_SUCCESS
(
PADDLE_ENFORCE_CUSTOM_DEVICE_SUCCESS
(
pimpl_
->
memory_copy_d2h
(
device
,
dst
,
src
,
size
));
pimpl_
->
memory_copy_d2h
(
device
,
dst
,
src
,
size
));
}
}
}
}
void
MemoryCopyD2D
(
size_t
dev_id
,
void
*
dst
,
const
void
*
src
,
size_t
size
,
void
MemoryCopyD2D
(
size_t
dev_id
,
void
*
dst
,
const
void
*
src
,
size_t
size
,
const
stream
::
Stream
*
stream
=
nullptr
)
override
{
const
stream
::
Stream
*
stream
=
nullptr
)
override
{
const
auto
device
=
&
devices_pool
[
dev_id
];
const
auto
device
=
&
devices_pool
[
dev_id
];
auto
place
=
platform
::
CustomPlace
(
Type
(),
dev_id
);
auto
place
=
CustomPlace
(
Type
(),
dev_id
);
if
(
stream
&&
stream
->
raw_stream
()
&&
pimpl_
->
async_memory_copy_d2d
)
{
if
(
stream
&&
stream
->
raw_stream
()
&&
pimpl_
->
async_memory_copy_d2d
)
{
C_Stream
c_stream
=
reinterpret_cast
<
C_Stream
>
(
stream
->
raw_stream
());
C_Stream
c_stream
=
reinterpret_cast
<
C_Stream
>
(
stream
->
raw_stream
());
PADDLE_ENFORCE_CUSTOM_DEVICE_SUCCESS
(
PADDLE_ENFORCE_CUSTOM_DEVICE_SUCCESS
(
pimpl_
->
async_memory_copy_d2d
(
device
,
c_stream
,
dst
,
src
,
size
));
pimpl_
->
async_memory_copy_d2d
(
device
,
c_stream
,
dst
,
src
,
size
));
}
else
{
}
else
{
platform
::
DeviceContextPool
&
pool
=
p
addle
::
p
latform
::
DeviceContextPool
&
pool
=
platform
::
DeviceContextPool
::
Instance
();
p
addle
::
p
latform
::
DeviceContextPool
::
Instance
();
pool
.
Get
(
place
)
->
Wait
();
pool
.
Get
(
place
)
->
Wait
();
PADDLE_ENFORCE_CUSTOM_DEVICE_SUCCESS
(
PADDLE_ENFORCE_CUSTOM_DEVICE_SUCCESS
(
pimpl_
->
memory_copy_d2d
(
device
,
dst
,
src
,
size
));
pimpl_
->
memory_copy_d2d
(
device
,
dst
,
src
,
size
));
}
}
}
}
void
MemoryCopyP2P
(
const
Place
&
dst_place
,
void
*
dst
,
size_t
src_dev_id
,
void
MemoryCopyP2P
(
const
Place
&
dst_place
,
const
void
*
src
,
size_t
size
,
void
*
dst
,
size_t
src_dev_id
,
const
void
*
src
,
size_t
size
,
const
stream
::
Stream
*
stream
=
nullptr
)
override
{
const
stream
::
Stream
*
stream
=
nullptr
)
override
{
int
dst_dev_id
=
PlaceToId
(
dst_place
);
int
dst_dev_id
=
PlaceToId
(
dst_place
);
auto
dst_device
=
&
devices_pool
[
dst_dev_id
];
auto
dst_device
=
&
devices_pool
[
dst_dev_id
];
...
@@ -310,8 +339,12 @@ class CustomDevice : public DeviceInterface {
...
@@ -310,8 +339,12 @@ class CustomDevice : public DeviceInterface {
MemoryCopyP2P
(
dst_place
,
dst
,
src_dev_id
,
src
,
size
);
MemoryCopyP2P
(
dst_place
,
dst
,
src_dev_id
,
src
,
size
);
}
else
{
}
else
{
PADDLE_ENFORCE_CUSTOM_DEVICE_SUCCESS
(
pimpl_
->
async_memory_copy_p2p
(
PADDLE_ENFORCE_CUSTOM_DEVICE_SUCCESS
(
pimpl_
->
async_memory_copy_p2p
(
dst_device
,
src_device
,
dst_device
,
reinterpret_cast
<
C_Stream
>
(
stream
->
raw_stream
()),
dst
,
src
,
size
));
src_device
,
reinterpret_cast
<
C_Stream
>
(
stream
->
raw_stream
()),
dst
,
src
,
size
));
}
}
}
else
{
}
else
{
if
(
!
pimpl_
->
memory_copy_p2p
)
{
if
(
!
pimpl_
->
memory_copy_p2p
)
{
...
@@ -319,9 +352,9 @@ class CustomDevice : public DeviceInterface {
...
@@ -319,9 +352,9 @@ class CustomDevice : public DeviceInterface {
MemoryCopyD2H
(
src_dev_id
,
tmp
.
get
(),
src
,
size
);
MemoryCopyD2H
(
src_dev_id
,
tmp
.
get
(),
src
,
size
);
MemoryCopyH2D
(
dst_dev_id
,
dst
,
tmp
.
get
(),
size
);
MemoryCopyH2D
(
dst_dev_id
,
dst
,
tmp
.
get
(),
size
);
}
else
{
}
else
{
auto
src_place
=
platform
::
CustomPlace
(
Type
(),
src_dev_id
);
auto
src_place
=
CustomPlace
(
Type
(),
src_dev_id
);
platform
::
DeviceContextPool
&
pool
=
p
addle
::
p
latform
::
DeviceContextPool
&
pool
=
platform
::
DeviceContextPool
::
Instance
();
p
addle
::
p
latform
::
DeviceContextPool
::
Instance
();
pool
.
Get
(
src_place
)
->
Wait
();
pool
.
Get
(
src_place
)
->
Wait
();
PADDLE_ENFORCE_CUSTOM_DEVICE_SUCCESS
(
PADDLE_ENFORCE_CUSTOM_DEVICE_SUCCESS
(
pimpl_
->
memory_copy_p2p
(
dst_device
,
src_device
,
dst
,
src
,
size
));
pimpl_
->
memory_copy_p2p
(
dst_device
,
src_device
,
dst
,
src
,
size
));
...
@@ -350,8 +383,8 @@ class CustomDevice : public DeviceInterface {
...
@@ -350,8 +383,8 @@ class CustomDevice : public DeviceInterface {
const
auto
device
=
&
devices_pool
[
dev_id
];
const
auto
device
=
&
devices_pool
[
dev_id
];
if
(
!
pimpl_
->
unified_memory_allocate
)
{
if
(
!
pimpl_
->
unified_memory_allocate
)
{
PADDLE_THROW
(
p
latform
::
errors
::
Unavailable
(
PADDLE_THROW
(
p
hi
::
errors
::
Unavailable
(
"MemoryAlloc
Kind::
Host is not supported on %s."
,
Type
()));
"MemoryAlloc
ate
Host is not supported on %s."
,
Type
()));
}
else
{
}
else
{
PADDLE_ENFORCE_CUSTOM_DEVICE_SUCCESS
(
PADDLE_ENFORCE_CUSTOM_DEVICE_SUCCESS
(
pimpl_
->
host_memory_allocate
(
device
,
&
ptr
,
size
));
pimpl_
->
host_memory_allocate
(
device
,
&
ptr
,
size
));
...
@@ -363,8 +396,8 @@ class CustomDevice : public DeviceInterface {
...
@@ -363,8 +396,8 @@ class CustomDevice : public DeviceInterface {
const
auto
device
=
&
devices_pool
[
dev_id
];
const
auto
device
=
&
devices_pool
[
dev_id
];
if
(
!
pimpl_
->
host_memory_deallocate
)
{
if
(
!
pimpl_
->
host_memory_deallocate
)
{
PADDLE_THROW
(
p
latform
::
errors
::
Unavailable
(
PADDLE_THROW
(
p
hi
::
errors
::
Unavailable
(
"Memory
AllocKind::
Host is not supported on %s."
,
Type
()));
"Memory
Deallocate
Host is not supported on %s."
,
Type
()));
}
else
{
}
else
{
PADDLE_ENFORCE_CUSTOM_DEVICE_SUCCESS
(
PADDLE_ENFORCE_CUSTOM_DEVICE_SUCCESS
(
pimpl_
->
host_memory_deallocate
(
device
,
ptr
,
size
));
pimpl_
->
host_memory_deallocate
(
device
,
ptr
,
size
));
...
@@ -376,8 +409,8 @@ class CustomDevice : public DeviceInterface {
...
@@ -376,8 +409,8 @@ class CustomDevice : public DeviceInterface {
const
auto
device
=
&
devices_pool
[
dev_id
];
const
auto
device
=
&
devices_pool
[
dev_id
];
if
(
!
pimpl_
->
unified_memory_allocate
)
{
if
(
!
pimpl_
->
unified_memory_allocate
)
{
PADDLE_THROW
(
p
latform
::
errors
::
Unavailable
(
PADDLE_THROW
(
p
hi
::
errors
::
Unavailable
(
"MemoryAlloc
Kind::
Unified is not supported on %s."
,
Type
()));
"MemoryAlloc
ate
Unified is not supported on %s."
,
Type
()));
}
else
{
}
else
{
PADDLE_ENFORCE_CUSTOM_DEVICE_SUCCESS
(
PADDLE_ENFORCE_CUSTOM_DEVICE_SUCCESS
(
pimpl_
->
unified_memory_allocate
(
device
,
&
ptr
,
size
));
pimpl_
->
unified_memory_allocate
(
device
,
&
ptr
,
size
));
...
@@ -389,15 +422,17 @@ class CustomDevice : public DeviceInterface {
...
@@ -389,15 +422,17 @@ class CustomDevice : public DeviceInterface {
const
auto
device
=
&
devices_pool
[
dev_id
];
const
auto
device
=
&
devices_pool
[
dev_id
];
if
(
!
pimpl_
->
unified_memory_deallocate
)
{
if
(
!
pimpl_
->
unified_memory_deallocate
)
{
PADDLE_THROW
(
p
latform
::
errors
::
Unavailable
(
PADDLE_THROW
(
p
hi
::
errors
::
Unavailable
(
"Memory
AllocKind::Host
is not supported on %s."
,
Type
()));
"Memory
DeallocateUnified
is not supported on %s."
,
Type
()));
}
else
{
}
else
{
PADDLE_ENFORCE_CUSTOM_DEVICE_SUCCESS
(
PADDLE_ENFORCE_CUSTOM_DEVICE_SUCCESS
(
pimpl_
->
unified_memory_deallocate
(
device
,
ptr
,
size
));
pimpl_
->
unified_memory_deallocate
(
device
,
ptr
,
size
));
}
}
}
}
void
MemorySet
(
size_t
dev_id
,
void
*
ptr
,
uint8_t
value
,
void
MemorySet
(
size_t
dev_id
,
void
*
ptr
,
uint8_t
value
,
size_t
size
)
override
{
size_t
size
)
override
{
const
auto
device
=
&
devices_pool
[
dev_id
];
const
auto
device
=
&
devices_pool
[
dev_id
];
...
@@ -532,10 +567,12 @@ class CustomDevice : public DeviceInterface {
...
@@ -532,10 +567,12 @@ class CustomDevice : public DeviceInterface {
inline
int
PlaceToId
(
const
Place
&
place
)
{
inline
int
PlaceToId
(
const
Place
&
place
)
{
int
dev_id
=
PlaceToIdNoCheck
(
place
);
int
dev_id
=
PlaceToIdNoCheck
(
place
);
PADDLE_ENFORCE_NE
(
devices_pool
.
find
(
dev_id
),
devices_pool
.
end
(),
PADDLE_ENFORCE_NE
(
devices_pool
.
find
(
dev_id
),
platform
::
errors
::
NotFound
(
devices_pool
.
end
(),
phi
::
errors
::
NotFound
(
"Cannot found %s %d, please check visible devices"
,
"Cannot found %s %d, please check visible devices"
,
Type
(),
dev_id
));
Type
(),
dev_id
));
return
dev_id
;
return
dev_id
;
}
}
...
@@ -623,11 +660,14 @@ typedef bool (*RegisterDevicePluginFn)(CustomRuntimeParams* runtime_params);
...
@@ -623,11 +660,14 @@ typedef bool (*RegisterDevicePluginFn)(CustomRuntimeParams* runtime_params);
void
LoadCustomRuntimeLib
(
const
CustomRuntimeParams
&
runtime_params
,
void
LoadCustomRuntimeLib
(
const
CustomRuntimeParams
&
runtime_params
,
std
::
unique_ptr
<
C_DeviceInterface
>
device_interface
,
std
::
unique_ptr
<
C_DeviceInterface
>
device_interface
,
const
std
::
string
&
dso_lib_path
,
void
*
dso_handle
)
{
const
std
::
string
&
dso_lib_path
,
void
*
dso_handle
)
{
if
(
ValidCustomCustomRuntimeParams
(
&
runtime_params
))
{
if
(
ValidCustomCustomRuntimeParams
(
&
runtime_params
))
{
auto
device
=
auto
device
=
std
::
make_unique
<
CustomDevice
>
(
runtime_params
.
device_type
,
std
::
make_unique
<
CustomDevice
>
(
runtime_params
.
device_type
,
255
,
true
,
255
,
std
::
move
(
device_interface
),
dso_handle
);
true
,
std
::
move
(
device_interface
),
dso_handle
);
if
(
false
==
DeviceManager
::
Register
(
std
::
move
(
device
)))
{
if
(
false
==
DeviceManager
::
Register
(
std
::
move
(
device
)))
{
LOG
(
WARNING
)
<<
"Skipped lib ["
<<
dso_lib_path
LOG
(
WARNING
)
<<
"Skipped lib ["
<<
dso_lib_path
<<
"]. Register failed!!! there may be a "
<<
"]. Register failed!!! there may be a "
...
@@ -665,10 +705,9 @@ void LoadCustomRuntimeLib(const std::string& dso_lib_path, void* dso_handle) {
...
@@ -665,10 +705,9 @@ void LoadCustomRuntimeLib(const std::string& dso_lib_path, void* dso_handle) {
"compatibility between PaddlePaddle and Custom Runtime."
;
"compatibility between PaddlePaddle and Custom Runtime."
;
return
;
return
;
}
}
LoadCustomRuntimeLib
(
runtime_params
,
std
::
move
(
device_interface
),
LoadCustomRuntimeLib
(
dso_lib_path
,
dso_handle
);
runtime_params
,
std
::
move
(
device_interface
),
dso_lib_path
,
dso_handle
);
LOG
(
INFO
)
<<
"Successed in loading custom runtime in lib: "
<<
dso_lib_path
;
LOG
(
INFO
)
<<
"Successed in loading custom runtime in lib: "
<<
dso_lib_path
;
}
}
}
// namespace platform
}
// namespace phi
}
// namespace paddle
paddle/
fluid/platform/device
/custom/custom_device_test.cc
→
paddle/
phi/backends
/custom/custom_device_test.cc
浏览文件 @
b53cdc9e
...
@@ -17,9 +17,9 @@
...
@@ -17,9 +17,9 @@
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/framework/tensor_util.h"
#include "paddle/fluid/framework/tensor_util.h"
#include "paddle/fluid/platform/device/custom/fake_cpu_device.h"
#include "paddle/fluid/platform/device/device_manager.h"
#include "paddle/fluid/platform/device_context.h"
#include "paddle/fluid/platform/device_context.h"
#include "paddle/phi/backends/custom/fake_cpu_device.h"
#include "paddle/phi/backends/device_manager.h"
void
RegisterDevice
()
{
void
RegisterDevice
()
{
CustomRuntimeParams
runtime_params
;
CustomRuntimeParams
runtime_params
;
...
@@ -30,23 +30,22 @@ void RegisterDevice() {
...
@@ -30,23 +30,22 @@ void RegisterDevice() {
runtime_params
.
interface
->
size
=
sizeof
(
C_DeviceInterface
);
runtime_params
.
interface
->
size
=
sizeof
(
C_DeviceInterface
);
InitFakeCPUDevice
(
&
runtime_params
);
InitFakeCPUDevice
(
&
runtime_params
);
p
addle
::
platform
::
LoadCustomRuntimeLib
(
p
hi
::
LoadCustomRuntimeLib
(
runtime_params
,
std
::
move
(
device_interface
),
""
,
nullptr
);
runtime_params
,
std
::
move
(
device_interface
),
""
,
nullptr
);
}
}
void
InitDevice
()
{
void
InitDevice
()
{
RegisterDevice
();
RegisterDevice
();
EXPECT_GT
(
static_cast
<
int
>
(
EXPECT_GT
(
static_cast
<
int
>
(
phi
::
DeviceManager
::
GetAllDeviceTypes
().
size
()),
paddle
::
platform
::
DeviceManager
::
GetAllDeviceTypes
().
size
()),
0
);
0
);
auto
place
=
paddle
::
platform
::
CustomPlace
(
DEVICE_TYPE
,
0
);
auto
place
=
paddle
::
platform
::
CustomPlace
(
DEVICE_TYPE
,
0
);
auto
device
=
p
addle
::
platform
::
DeviceManager
::
GetDeviceWithPlace
(
place
);
auto
device
=
p
hi
::
DeviceManager
::
GetDeviceWithPlace
(
place
);
EXPECT_NE
(
device
,
nullptr
);
EXPECT_NE
(
device
,
nullptr
);
std
::
vector
<
paddle
::
platform
::
Place
>
places
;
std
::
vector
<
paddle
::
platform
::
Place
>
places
;
auto
device_types
=
p
addle
::
platform
::
DeviceManager
::
GetAllDeviceTypes
();
auto
device_types
=
p
hi
::
DeviceManager
::
GetAllDeviceTypes
();
for
(
auto
dev_type
:
device_types
)
{
for
(
auto
dev_type
:
device_types
)
{
auto
devices
=
p
addle
::
platform
::
DeviceManager
::
GetDeviceList
(
dev_type
);
auto
devices
=
p
hi
::
DeviceManager
::
GetDeviceList
(
dev_type
);
for
(
auto
dev_id
:
devices
)
{
for
(
auto
dev_id
:
devices
)
{
places
.
push_back
(
places
.
push_back
(
paddle
::
platform
::
PlaceHelper
::
CreatePlace
(
dev_type
,
dev_id
));
paddle
::
platform
::
PlaceHelper
::
CreatePlace
(
dev_type
,
dev_id
));
...
@@ -60,14 +59,14 @@ void InitDevice() {
...
@@ -60,14 +59,14 @@ void InitDevice() {
void
TestDeviceInterface
(
const
paddle
::
platform
::
Place
&
place
)
{
void
TestDeviceInterface
(
const
paddle
::
platform
::
Place
&
place
)
{
std
::
cout
<<
"TestDeviceInterface on "
<<
place
<<
std
::
endl
;
std
::
cout
<<
"TestDeviceInterface on "
<<
place
<<
std
::
endl
;
if
(
paddle
::
platform
::
is_custom_place
(
place
))
{
if
(
paddle
::
platform
::
is_custom_place
(
place
))
{
auto
device
=
p
addle
::
platform
::
DeviceManager
::
GetDeviceWithPlace
(
place
);
auto
device
=
p
hi
::
DeviceManager
::
GetDeviceWithPlace
(
place
);
auto
dev_type
=
paddle
::
platform
::
PlaceHelper
::
GetDeviceType
(
place
);
auto
dev_type
=
paddle
::
platform
::
PlaceHelper
::
GetDeviceType
(
place
);
auto
p1
=
device
->
MemoryAllocate
(
auto
p1
=
paddle
::
platform
::
DeviceManager
::
GetMinChunkSize
(
place
));
device
->
MemoryAllocate
(
phi
::
DeviceManager
::
GetMinChunkSize
(
place
));
EXPECT_NE
(
p1
,
nullptr
);
EXPECT_NE
(
p1
,
nullptr
);
p
addle
::
platform
::
DeviceManager
::
SetDevice
(
place
);
p
hi
::
DeviceManager
::
SetDevice
(
place
);
auto
dev_id
=
p
addle
::
platform
::
DeviceManager
::
GetDevice
(
dev_type
);
auto
dev_id
=
p
hi
::
DeviceManager
::
GetDevice
(
dev_type
);
EXPECT_EQ
(
dev_id
,
place
.
GetDeviceId
());
EXPECT_EQ
(
dev_id
,
place
.
GetDeviceId
());
}
}
}
}
...
@@ -168,11 +167,10 @@ void TestTensorUtils(const paddle::platform::Place& place) {
...
@@ -168,11 +167,10 @@ void TestTensorUtils(const paddle::platform::Place& place) {
TEST
(
CustomDevice
,
Tensor
)
{
TEST
(
CustomDevice
,
Tensor
)
{
InitDevice
();
InitDevice
();
auto
dev_types
=
p
addle
::
platform
::
DeviceManager
::
GetAllDeviceTypes
();
auto
dev_types
=
p
hi
::
DeviceManager
::
GetAllDeviceTypes
();
for
(
const
auto
&
dev_type
:
dev_types
)
{
for
(
const
auto
&
dev_type
:
dev_types
)
{
std
::
cout
<<
"Test on "
<<
dev_type
<<
std
::
endl
;
std
::
cout
<<
"Test on "
<<
dev_type
<<
std
::
endl
;
EXPECT_GT
(
static_cast
<
int
>
(
EXPECT_GT
(
static_cast
<
int
>
(
phi
::
DeviceManager
::
GetDeviceCount
(
dev_type
)),
paddle
::
platform
::
DeviceManager
::
GetDeviceCount
(
dev_type
)),
0
);
0
);
auto
place
=
paddle
::
platform
::
PlaceHelper
::
CreatePlace
(
dev_type
);
auto
place
=
paddle
::
platform
::
PlaceHelper
::
CreatePlace
(
dev_type
);
...
...
paddle/
fluid/platform/device
/custom/fake_cpu_device.h
→
paddle/
phi/backends
/custom/fake_cpu_device.h
浏览文件 @
b53cdc9e
...
@@ -13,7 +13,7 @@
...
@@ -13,7 +13,7 @@
// limitations under the License.
// limitations under the License.
#pragma once
#pragma once
#include "paddle/
fluid/platform/device
/device_ext.h"
#include "paddle/
phi/backends
/device_ext.h"
constexpr
size_t
global_total_memory
=
1024
*
1024UL
;
constexpr
size_t
global_total_memory
=
1024
*
1024UL
;
static
size_t
global_free_memory
=
global_total_memory
;
static
size_t
global_free_memory
=
global_total_memory
;
...
@@ -43,14 +43,19 @@ C_Status GetDevicesList(size_t *device) {
...
@@ -43,14 +43,19 @@ C_Status GetDevicesList(size_t *device) {
return
C_SUCCESS
;
return
C_SUCCESS
;
}
}
C_Status
MemCpy
(
const
C_Device
device
,
void
*
dst
,
const
void
*
src
,
C_Status
MemCpy
(
const
C_Device
device
,
void
*
dst
,
const
void
*
src
,
size_t
size
)
{
size_t
size
)
{
memcpy
(
dst
,
src
,
size
);
memcpy
(
dst
,
src
,
size
);
return
C_SUCCESS
;
return
C_SUCCESS
;
}
}
C_Status
AsyncMemCpy
(
const
C_Device
device
,
C_Stream
stream
,
void
*
dst
,
C_Status
AsyncMemCpy
(
const
C_Device
device
,
const
void
*
src
,
size_t
size
)
{
C_Stream
stream
,
void
*
dst
,
const
void
*
src
,
size_t
size
)
{
memcpy
(
dst
,
src
,
size
);
memcpy
(
dst
,
src
,
size
);
return
C_SUCCESS
;
return
C_SUCCESS
;
}
}
...
@@ -100,14 +105,16 @@ C_Status SyncStream(const C_Device device, C_Stream stream) {
...
@@ -100,14 +105,16 @@ C_Status SyncStream(const C_Device device, C_Stream stream) {
C_Status
SyncEvent
(
const
C_Device
device
,
C_Event
event
)
{
return
C_SUCCESS
;
}
C_Status
SyncEvent
(
const
C_Device
device
,
C_Event
event
)
{
return
C_SUCCESS
;
}
C_Status
StreamWaitEvent
(
const
C_Device
device
,
C_Stream
stream
,
C_Status
StreamWaitEvent
(
const
C_Device
device
,
C_Stream
stream
,
C_Event
event
)
{
C_Event
event
)
{
return
C_SUCCESS
;
return
C_SUCCESS
;
}
}
C_Status
VisibleDevices
(
size_t
*
devices
)
{
return
C_SUCCESS
;
}
C_Status
VisibleDevices
(
size_t
*
devices
)
{
return
C_SUCCESS
;
}
C_Status
DeviceMemStats
(
const
C_Device
device
,
size_t
*
total_memory
,
C_Status
DeviceMemStats
(
const
C_Device
device
,
size_t
*
total_memory
,
size_t
*
free_memory
)
{
size_t
*
free_memory
)
{
*
total_memory
=
global_total_memory
;
*
total_memory
=
global_total_memory
;
*
free_memory
=
global_free_memory
;
*
free_memory
=
global_free_memory
;
...
@@ -139,7 +146,8 @@ void InitFakeCPUDevice(CustomRuntimeParams *params) {
...
@@ -139,7 +146,8 @@ void InitFakeCPUDevice(CustomRuntimeParams *params) {
params
->
version
.
minor
=
PADDLE_CUSTOM_RUNTIME_MINOR_VERSION
;
params
->
version
.
minor
=
PADDLE_CUSTOM_RUNTIME_MINOR_VERSION
;
params
->
version
.
patch
=
PADDLE_CUSTOM_RUNTIME_PATCH_VERSION
;
params
->
version
.
patch
=
PADDLE_CUSTOM_RUNTIME_PATCH_VERSION
;
memset
(
reinterpret_cast
<
void
*>
(
params
->
interface
),
0
,
memset
(
reinterpret_cast
<
void
*>
(
params
->
interface
),
0
,
sizeof
(
C_DeviceInterface
));
sizeof
(
C_DeviceInterface
));
params
->
interface
->
initialize
=
Init
;
params
->
interface
->
initialize
=
Init
;
...
...
paddle/
fluid/platform/device
/device_base.cc
→
paddle/
phi/backends
/device_base.cc
浏览文件 @
b53cdc9e
...
@@ -12,7 +12,7 @@
...
@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// See the License for the specific language governing permissions and
// limitations under the License.
// limitations under the License.
#include "paddle/
fluid/platform/device
/device_base.h"
#include "paddle/
phi/backends
/device_base.h"
#include "gflags/gflags.h"
#include "gflags/gflags.h"
DECLARE_double
(
fraction_of_gpu_memory_to_use
);
DECLARE_double
(
fraction_of_gpu_memory_to_use
);
...
@@ -21,26 +21,25 @@ DECLARE_uint64(reallocate_gpu_memory_in_mb);
...
@@ -21,26 +21,25 @@ DECLARE_uint64(reallocate_gpu_memory_in_mb);
constexpr
static
float
fraction_reserve_gpu_memory
=
0.05
f
;
constexpr
static
float
fraction_reserve_gpu_memory
=
0.05
f
;
namespace
paddle
{
namespace
phi
{
namespace
platform
{
#define INTERFACE_UNIMPLEMENT
\
#define INTERFACE_UNIMPLEMENT \
PADDLE_THROW(p
latform
::errors::Unimplemented( \
PADDLE_THROW(p
hi
::errors::Unimplemented( \
"%s is not implemented on %s device.", __func__, Type()));
"%s is not implemented on %s device.", __func__, Type()));
// info
// info
size_t
DeviceInterface
::
GetComputeCapability
()
{
size_t
DeviceInterface
::
GetComputeCapability
()
{
VLOG
(
10
)
<<
Type
()
+
" get compute capability "
<<
0
;
VLOG
(
10
)
<<
Type
()
<<
" get compute capability "
<<
0
;
return
0
;
return
0
;
}
}
size_t
DeviceInterface
::
GetRuntimeVersion
()
{
size_t
DeviceInterface
::
GetRuntimeVersion
()
{
VLOG
(
10
)
<<
Type
()
+
" get runtime version "
<<
0
;
VLOG
(
10
)
<<
Type
()
<<
" get runtime version "
<<
0
;
return
0
;
return
0
;
}
}
size_t
DeviceInterface
::
GetDriverVersion
()
{
size_t
DeviceInterface
::
GetDriverVersion
()
{
VLOG
(
10
)
<<
Type
()
+
" get driver version "
<<
0
;
VLOG
(
10
)
<<
Type
()
<<
" get driver version "
<<
0
;
return
0
;
return
0
;
}
}
...
@@ -62,7 +61,8 @@ void DeviceInterface::SetDevice(size_t dev_id) { INTERFACE_UNIMPLEMENT; }
...
@@ -62,7 +61,8 @@ void DeviceInterface::SetDevice(size_t dev_id) { INTERFACE_UNIMPLEMENT; }
int
DeviceInterface
::
GetDevice
()
{
INTERFACE_UNIMPLEMENT
;
}
int
DeviceInterface
::
GetDevice
()
{
INTERFACE_UNIMPLEMENT
;
}
// stream manage
// stream manage
void
DeviceInterface
::
CreateStream
(
size_t
dev_id
,
stream
::
Stream
*
stream
,
void
DeviceInterface
::
CreateStream
(
size_t
dev_id
,
stream
::
Stream
*
stream
,
const
stream
::
Stream
::
Priority
&
priority
,
const
stream
::
Stream
::
Priority
&
priority
,
const
stream
::
Stream
::
Flag
&
flag
)
{
const
stream
::
Stream
::
Flag
&
flag
)
{
INTERFACE_UNIMPLEMENT
;
INTERFACE_UNIMPLEMENT
;
...
@@ -82,7 +82,8 @@ bool DeviceInterface::QueryStream(size_t dev_id, const stream::Stream* stream) {
...
@@ -82,7 +82,8 @@ bool DeviceInterface::QueryStream(size_t dev_id, const stream::Stream* stream) {
return
true
;
return
true
;
}
}
void
DeviceInterface
::
AddCallback
(
size_t
dev_id
,
stream
::
Stream
*
stream
,
void
DeviceInterface
::
AddCallback
(
size_t
dev_id
,
stream
::
Stream
*
stream
,
stream
::
Stream
::
Callback
*
callback
)
{
stream
::
Stream
::
Callback
*
callback
)
{
INTERFACE_UNIMPLEMENT
;
INTERFACE_UNIMPLEMENT
;
}
}
...
@@ -94,7 +95,8 @@ void DeviceInterface::StreamWaitEvent(size_t dev_id,
...
@@ -94,7 +95,8 @@ void DeviceInterface::StreamWaitEvent(size_t dev_id,
}
}
// event manage
// event manage
void
DeviceInterface
::
CreateEvent
(
size_t
dev_id
,
event
::
Event
*
event
,
void
DeviceInterface
::
CreateEvent
(
size_t
dev_id
,
event
::
Event
*
event
,
event
::
Event
::
Flag
flags
)
{
event
::
Event
::
Flag
flags
)
{
INTERFACE_UNIMPLEMENT
;
INTERFACE_UNIMPLEMENT
;
}
}
...
@@ -103,7 +105,8 @@ void DeviceInterface::DestroyEvent(size_t dev_id, event::Event* event) {
...
@@ -103,7 +105,8 @@ void DeviceInterface::DestroyEvent(size_t dev_id, event::Event* event) {
INTERFACE_UNIMPLEMENT
;
INTERFACE_UNIMPLEMENT
;
}
}
void
DeviceInterface
::
RecordEvent
(
size_t
dev_id
,
const
event
::
Event
*
event
,
void
DeviceInterface
::
RecordEvent
(
size_t
dev_id
,
const
event
::
Event
*
event
,
const
stream
::
Stream
*
stream
)
{
const
stream
::
Stream
*
stream
)
{
INTERFACE_UNIMPLEMENT
;
INTERFACE_UNIMPLEMENT
;
}
}
...
@@ -119,23 +122,35 @@ bool DeviceInterface::QueryEvent(size_t dev_id, const event::Event* event) {
...
@@ -119,23 +122,35 @@ bool DeviceInterface::QueryEvent(size_t dev_id, const event::Event* event) {
}
}
// memery manage
// memery manage
void
DeviceInterface
::
MemoryCopyH2D
(
size_t
dev_id
,
void
*
dst
,
const
void
*
src
,
void
DeviceInterface
::
MemoryCopyH2D
(
size_t
dev_id
,
size_t
size
,
const
stream
::
Stream
*
stream
)
{
void
*
dst
,
const
void
*
src
,
size_t
size
,
const
stream
::
Stream
*
stream
)
{
INTERFACE_UNIMPLEMENT
;
INTERFACE_UNIMPLEMENT
;
}
}
void
DeviceInterface
::
MemoryCopyD2H
(
size_t
dev_id
,
void
*
dst
,
const
void
*
src
,
void
DeviceInterface
::
MemoryCopyD2H
(
size_t
dev_id
,
size_t
size
,
const
stream
::
Stream
*
stream
)
{
void
*
dst
,
const
void
*
src
,
size_t
size
,
const
stream
::
Stream
*
stream
)
{
INTERFACE_UNIMPLEMENT
;
INTERFACE_UNIMPLEMENT
;
}
}
void
DeviceInterface
::
MemoryCopyD2D
(
size_t
dev_id
,
void
*
dst
,
const
void
*
src
,
void
DeviceInterface
::
MemoryCopyD2D
(
size_t
dev_id
,
size_t
size
,
const
stream
::
Stream
*
stream
)
{
void
*
dst
,
const
void
*
src
,
size_t
size
,
const
stream
::
Stream
*
stream
)
{
INTERFACE_UNIMPLEMENT
;
INTERFACE_UNIMPLEMENT
;
}
}
void
DeviceInterface
::
MemoryCopyP2P
(
const
Place
&
dst_place
,
void
*
dst
,
void
DeviceInterface
::
MemoryCopyP2P
(
const
Place
&
dst_place
,
size_t
src_id
,
const
void
*
src
,
size_t
size
,
void
*
dst
,
size_t
src_id
,
const
void
*
src
,
size_t
size
,
const
stream
::
Stream
*
stream
)
{
const
stream
::
Stream
*
stream
)
{
INTERFACE_UNIMPLEMENT
;
INTERFACE_UNIMPLEMENT
;
}
}
...
@@ -154,7 +169,8 @@ void* DeviceInterface::MemoryAllocateHost(size_t dev_id, size_t size) {
...
@@ -154,7 +169,8 @@ void* DeviceInterface::MemoryAllocateHost(size_t dev_id, size_t size) {
return
nullptr
;
return
nullptr
;
}
}
void
DeviceInterface
::
MemoryDeallocateHost
(
size_t
dev_id
,
void
*
ptr
,
void
DeviceInterface
::
MemoryDeallocateHost
(
size_t
dev_id
,
void
*
ptr
,
size_t
size
)
{
size_t
size
)
{
INTERFACE_UNIMPLEMENT
;
INTERFACE_UNIMPLEMENT
;
}
}
...
@@ -164,12 +180,15 @@ void* DeviceInterface::MemoryAllocateUnified(size_t dev_id, size_t size) {
...
@@ -164,12 +180,15 @@ void* DeviceInterface::MemoryAllocateUnified(size_t dev_id, size_t size) {
return
nullptr
;
return
nullptr
;
}
}
void
DeviceInterface
::
MemoryDeallocateUnified
(
size_t
dev_id
,
void
*
ptr
,
void
DeviceInterface
::
MemoryDeallocateUnified
(
size_t
dev_id
,
void
*
ptr
,
size_t
size
)
{
size_t
size
)
{
INTERFACE_UNIMPLEMENT
;
INTERFACE_UNIMPLEMENT
;
}
}
void
DeviceInterface
::
MemorySet
(
size_t
dev_id
,
void
*
ptr
,
uint8_t
value
,
void
DeviceInterface
::
MemorySet
(
size_t
dev_id
,
void
*
ptr
,
uint8_t
value
,
size_t
size
)
{
size_t
size
)
{
INTERFACE_UNIMPLEMENT
;
INTERFACE_UNIMPLEMENT
;
}
}
...
@@ -184,8 +203,9 @@ size_t DeviceInterface::GetMinChunkSize(size_t dev_id) {
...
@@ -184,8 +203,9 @@ size_t DeviceInterface::GetMinChunkSize(size_t dev_id) {
size_t
DeviceInterface
::
AllocSize
(
size_t
dev_id
,
bool
realloc
)
{
size_t
DeviceInterface
::
AllocSize
(
size_t
dev_id
,
bool
realloc
)
{
size_t
available_to_alloc
=
AvailableAllocSize
(
dev_id
);
size_t
available_to_alloc
=
AvailableAllocSize
(
dev_id
);
PADDLE_ENFORCE_GT
(
available_to_alloc
,
0
,
PADDLE_ENFORCE_GT
(
available_to_alloc
,
platform
::
errors
::
ResourceExhausted
(
0
,
phi
::
errors
::
ResourceExhausted
(
"Not enough available %s memory."
,
Type
()));
"Not enough available %s memory."
,
Type
()));
// If FLAGS_initial_gpu_memory_in_mb is 0, then initial memory will be
// If FLAGS_initial_gpu_memory_in_mb is 0, then initial memory will be
// allocated by fraction
// allocated by fraction
...
@@ -194,8 +214,9 @@ size_t DeviceInterface::AllocSize(size_t dev_id, bool realloc) {
...
@@ -194,8 +214,9 @@ size_t DeviceInterface::AllocSize(size_t dev_id, bool realloc) {
size_t
alloc_bytes
=
size_t
alloc_bytes
=
(
flag_mb
>
0ul
?
flag_mb
<<
20
:
available_to_alloc
*
(
flag_mb
>
0ul
?
flag_mb
<<
20
:
available_to_alloc
*
FLAGS_fraction_of_gpu_memory_to_use
);
FLAGS_fraction_of_gpu_memory_to_use
);
PADDLE_ENFORCE_GE
(
available_to_alloc
,
alloc_bytes
,
PADDLE_ENFORCE_GE
(
available_to_alloc
,
platform
::
errors
::
ResourceExhausted
(
alloc_bytes
,
phi
::
errors
::
ResourceExhausted
(
"Not enough available %s memory."
,
Type
()));
"Not enough available %s memory."
,
Type
()));
return
alloc_bytes
;
return
alloc_bytes
;
}
}
...
@@ -217,33 +238,32 @@ size_t DeviceInterface::AvailableAllocSize(size_t dev_id) {
...
@@ -217,33 +238,32 @@ size_t DeviceInterface::AvailableAllocSize(size_t dev_id) {
size_t
DeviceInterface
::
GetInitAllocSize
(
size_t
dev_id
)
{
size_t
DeviceInterface
::
GetInitAllocSize
(
size_t
dev_id
)
{
size_t
init_alloc_size
=
AllocSize
(
dev_id
,
false
);
size_t
init_alloc_size
=
AllocSize
(
dev_id
,
false
);
VLOG
(
10
)
<<
Type
()
+
" init alloc size "
<<
(
init_alloc_size
>>
20
)
<<
"M"
;
VLOG
(
10
)
<<
Type
()
<<
" init alloc size "
<<
(
init_alloc_size
>>
20
)
<<
"M"
;
return
init_alloc_size
;
return
init_alloc_size
;
}
}
size_t
DeviceInterface
::
GetReallocSize
(
size_t
dev_id
)
{
size_t
DeviceInterface
::
GetReallocSize
(
size_t
dev_id
)
{
size_t
realloc_size
=
AllocSize
(
dev_id
,
true
);
size_t
realloc_size
=
AllocSize
(
dev_id
,
true
);
VLOG
(
10
)
<<
Type
()
+
" realloc size "
<<
(
realloc_size
>>
20
)
<<
"M"
;
VLOG
(
10
)
<<
Type
()
<<
" realloc size "
<<
(
realloc_size
>>
20
)
<<
"M"
;
return
realloc_size
;
return
realloc_size
;
}
}
size_t
DeviceInterface
::
GetMaxAllocSize
(
size_t
dev_id
)
{
size_t
DeviceInterface
::
GetMaxAllocSize
(
size_t
dev_id
)
{
size_t
max_alloc_size
=
size_t
max_alloc_size
=
std
::
max
(
GetInitAllocSize
(
dev_id
),
GetReallocSize
(
dev_id
));
std
::
max
(
GetInitAllocSize
(
dev_id
),
GetReallocSize
(
dev_id
));
VLOG
(
10
)
<<
Type
()
+
" max alloc size "
<<
(
max_alloc_size
>>
20
)
<<
"M"
;
VLOG
(
10
)
<<
Type
()
<<
" max alloc size "
<<
(
max_alloc_size
>>
20
)
<<
"M"
;
return
max_alloc_size
;
return
max_alloc_size
;
}
}
size_t
DeviceInterface
::
GetMaxChunkSize
(
size_t
dev_id
)
{
size_t
DeviceInterface
::
GetMaxChunkSize
(
size_t
dev_id
)
{
size_t
max_chunk_size
=
GetMaxAllocSize
(
dev_id
);
size_t
max_chunk_size
=
GetMaxAllocSize
(
dev_id
);
VLOG
(
10
)
<<
Type
()
+
" max chunk size "
<<
(
max_chunk_size
>>
20
)
<<
"M"
;
VLOG
(
10
)
<<
Type
()
<<
" max chunk size "
<<
(
max_chunk_size
>>
20
)
<<
"M"
;
return
max_chunk_size
;
return
max_chunk_size
;
}
}
size_t
DeviceInterface
::
GetExtraPaddingSize
(
size_t
dev_id
)
{
size_t
DeviceInterface
::
GetExtraPaddingSize
(
size_t
dev_id
)
{
VLOG
(
10
)
<<
Type
()
+
" extra padding size "
<<
0
;
VLOG
(
10
)
<<
Type
()
<<
" extra padding size "
<<
0
;
return
0
;
return
0
;
}
}
}
// namespace platform
}
// namespace phi
}
// namespace paddle
paddle/
fluid/platform/device
/device_base.h
→
paddle/
phi/backends
/device_base.h
浏览文件 @
b53cdc9e
...
@@ -14,11 +14,10 @@
...
@@ -14,11 +14,10 @@
#pragma once
#pragma once
#ifdef PADDLE_WITH_CUSTOM_DEVICE
#ifdef PADDLE_WITH_CUSTOM_DEVICE
#include "paddle/
fluid/platform/device
/event.h"
#include "paddle/
phi/backends
/event.h"
#include "paddle/
fluid/platform/device
/stream.h"
#include "paddle/
phi/backends
/stream.h"
namespace
paddle
{
namespace
phi
{
namespace
platform
{
class
DeviceInterface
{
// Driver / Runtime
class
DeviceInterface
{
// Driver / Runtime
public:
public:
...
@@ -66,7 +65,8 @@ class DeviceInterface { // Driver / Runtime
...
@@ -66,7 +65,8 @@ class DeviceInterface { // Driver / Runtime
// Stream
// Stream
// ! Create an asynchronous stream
// ! Create an asynchronous stream
virtual
void
CreateStream
(
virtual
void
CreateStream
(
size_t
dev_id
,
stream
::
Stream
*
stream
,
size_t
dev_id
,
stream
::
Stream
*
stream
,
const
stream
::
Stream
::
Priority
&
priority
=
const
stream
::
Stream
::
Priority
&
priority
=
stream
::
Stream
::
Priority
::
kNormal
,
stream
::
Stream
::
Priority
::
kNormal
,
const
stream
::
Stream
::
Flag
&
flag
=
stream
::
Stream
::
Flag
::
kDefaultFlag
);
const
stream
::
Stream
::
Flag
&
flag
=
stream
::
Stream
::
Flag
::
kDefaultFlag
);
...
@@ -81,19 +81,22 @@ class DeviceInterface { // Driver / Runtime
...
@@ -81,19 +81,22 @@ class DeviceInterface { // Driver / Runtime
virtual
bool
QueryStream
(
size_t
dev_id
,
const
stream
::
Stream
*
stream
);
virtual
bool
QueryStream
(
size_t
dev_id
,
const
stream
::
Stream
*
stream
);
// ! Add a callback to a compute stream.
// ! Add a callback to a compute stream.
virtual
void
AddCallback
(
size_t
dev_id
,
stream
::
Stream
*
stream
,
virtual
void
AddCallback
(
size_t
dev_id
,
stream
::
Stream
*
stream
,
stream
::
Stream
::
Callback
*
callback
);
stream
::
Stream
::
Callback
*
callback
);
// Event
// Event
// ! Create an event.
// ! Create an event.
virtual
void
CreateEvent
(
size_t
dev_id
,
event
::
Event
*
event
,
virtual
void
CreateEvent
(
size_t
dev_id
,
event
::
Event
*
event
,
event
::
Event
::
Flag
flags
);
event
::
Event
::
Flag
flags
);
// ! Destroy an event.
// ! Destroy an event.
virtual
void
DestroyEvent
(
size_t
dev_id
,
event
::
Event
*
event
);
virtual
void
DestroyEvent
(
size_t
dev_id
,
event
::
Event
*
event
);
// ! Records an event.
// ! Records an event.
virtual
void
RecordEvent
(
size_t
dev_id
,
const
event
::
Event
*
event
,
virtual
void
RecordEvent
(
size_t
dev_id
,
const
event
::
Event
*
event
,
const
stream
::
Stream
*
stream
);
const
stream
::
Stream
*
stream
);
// ! Waits for event to complete.
// ! Waits for event to complete.
...
@@ -102,24 +105,34 @@ class DeviceInterface { // Driver / Runtime
...
@@ -102,24 +105,34 @@ class DeviceInterface { // Driver / Runtime
virtual
bool
QueryEvent
(
size_t
dev_id
,
const
event
::
Event
*
event
);
virtual
bool
QueryEvent
(
size_t
dev_id
,
const
event
::
Event
*
event
);
// ! Make a compute stream wait on an event
// ! Make a compute stream wait on an event
virtual
void
StreamWaitEvent
(
size_t
dev_id
,
const
stream
::
Stream
*
stream
,
virtual
void
StreamWaitEvent
(
size_t
dev_id
,
const
stream
::
Stream
*
stream
,
const
event
::
Event
*
event
);
const
event
::
Event
*
event
);
// Memory
// Memory
virtual
void
MemoryCopyH2D
(
size_t
dev_id
,
void
*
dst
,
const
void
*
src
,
virtual
void
MemoryCopyH2D
(
size_t
dev_id
,
void
*
dst
,
const
void
*
src
,
size_t
size
,
size_t
size
,
const
stream
::
Stream
*
stream
=
nullptr
);
const
stream
::
Stream
*
stream
=
nullptr
);
virtual
void
MemoryCopyD2H
(
size_t
dev_id
,
void
*
dst
,
const
void
*
src
,
virtual
void
MemoryCopyD2H
(
size_t
dev_id
,
void
*
dst
,
const
void
*
src
,
size_t
size
,
size_t
size
,
const
stream
::
Stream
*
stream
=
nullptr
);
const
stream
::
Stream
*
stream
=
nullptr
);
virtual
void
MemoryCopyD2D
(
size_t
dev_id
,
void
*
dst
,
const
void
*
src
,
virtual
void
MemoryCopyD2D
(
size_t
dev_id
,
void
*
dst
,
const
void
*
src
,
size_t
size
,
size_t
size
,
const
stream
::
Stream
*
stream
=
nullptr
);
const
stream
::
Stream
*
stream
=
nullptr
);
virtual
void
MemoryCopyP2P
(
const
Place
&
dst_place
,
void
*
dst
,
size_t
src_id
,
virtual
void
MemoryCopyP2P
(
const
Place
&
dst_place
,
const
void
*
src
,
size_t
size
,
void
*
dst
,
size_t
src_id
,
const
void
*
src
,
size_t
size
,
const
stream
::
Stream
*
stream
=
nullptr
);
const
stream
::
Stream
*
stream
=
nullptr
);
virtual
void
*
MemoryAllocate
(
size_t
dev_id
,
size_t
size
);
virtual
void
*
MemoryAllocate
(
size_t
dev_id
,
size_t
size
);
...
@@ -160,7 +173,6 @@ class DeviceInterface { // Driver / Runtime
...
@@ -160,7 +173,6 @@ class DeviceInterface { // Driver / Runtime
size_t
AvailableAllocSize
(
size_t
dev_id
);
size_t
AvailableAllocSize
(
size_t
dev_id
);
};
};
}
// namespace platform
}
// namespace phi
}
// namespace paddle
#endif
#endif
paddle/
fluid/platform/device
/device_ext.h
→
paddle/
phi/backends
/device_ext.h
浏览文件 @
b53cdc9e
...
@@ -40,7 +40,9 @@ typedef struct C_Stream_st* C_Stream;
...
@@ -40,7 +40,9 @@ typedef struct C_Stream_st* C_Stream;
typedef
struct
C_Event_st
*
C_Event
;
typedef
struct
C_Event_st
*
C_Event
;
typedef
void
(
*
C_Callback
)(
C_Device
device
,
C_Stream
stream
,
void
*
user_data
,
typedef
void
(
*
C_Callback
)(
C_Device
device
,
C_Stream
stream
,
void
*
user_data
,
C_Status
*
status
);
C_Status
*
status
);
struct
C_DeviceInterface
{
struct
C_DeviceInterface
{
...
@@ -124,8 +126,10 @@ struct C_DeviceInterface {
...
@@ -124,8 +126,10 @@ struct C_DeviceInterface {
* @param[C_Callback] callback
* @param[C_Callback] callback
* @param[void*] user_data
* @param[void*] user_data
*/
*/
C_Status
(
*
stream_add_callback
)(
const
C_Device
device
,
C_Stream
stream
,
C_Status
(
*
stream_add_callback
)(
const
C_Device
device
,
C_Callback
callback
,
void
*
user_data
);
C_Stream
stream
,
C_Callback
callback
,
void
*
user_data
);
/**
/**
* @brief Create an event
* @brief Create an event
...
@@ -142,7 +146,8 @@ struct C_DeviceInterface {
...
@@ -142,7 +146,8 @@ struct C_DeviceInterface {
* @param[C_Stream] stream
* @param[C_Stream] stream
* @param[C_Event] event
* @param[C_Event] event
*/
*/
C_Status
(
*
record_event
)(
const
C_Device
device
,
C_Stream
stream
,
C_Status
(
*
record_event
)(
const
C_Device
device
,
C_Stream
stream
,
C_Event
event
);
C_Event
event
);
/**
/**
...
@@ -191,7 +196,8 @@ struct C_DeviceInterface {
...
@@ -191,7 +196,8 @@ struct C_DeviceInterface {
* @param[C_Stream] stream
* @param[C_Stream] stream
* @param[C_Event] event
* @param[C_Event] event
*/
*/
C_Status
(
*
stream_wait_event
)(
const
C_Device
device
,
C_Stream
stream
,
C_Status
(
*
stream_wait_event
)(
const
C_Device
device
,
C_Stream
stream
,
C_Event
event
);
C_Event
event
);
void
*
reserved_dev_api
[
8
];
void
*
reserved_dev_api
[
8
];
...
@@ -207,7 +213,8 @@ struct C_DeviceInterface {
...
@@ -207,7 +213,8 @@ struct C_DeviceInterface {
* @param[void**] ptr Plugin allocate an address and fill it
* @param[void**] ptr Plugin allocate an address and fill it
* @param[size_t] size
* @param[size_t] size
*/
*/
C_Status
(
*
device_memory_allocate
)(
const
C_Device
device
,
void
**
ptr
,
C_Status
(
*
device_memory_allocate
)(
const
C_Device
device
,
void
**
ptr
,
size_t
size
);
size_t
size
);
/**
/**
...
@@ -217,7 +224,8 @@ struct C_DeviceInterface {
...
@@ -217,7 +224,8 @@ struct C_DeviceInterface {
* @param[void*] ptr
* @param[void*] ptr
* @param[size_t] size
* @param[size_t] size
*/
*/
C_Status
(
*
device_memory_deallocate
)(
const
C_Device
device
,
void
*
ptr
,
C_Status
(
*
device_memory_deallocate
)(
const
C_Device
device
,
void
*
ptr
,
size_t
size
);
size_t
size
);
/**
/**
...
@@ -228,8 +236,10 @@ struct C_DeviceInterface {
...
@@ -228,8 +236,10 @@ struct C_DeviceInterface {
* @param[unsigned char] value
* @param[unsigned char] value
* @param[size_t] size
* @param[size_t] size
*/
*/
C_Status
(
*
device_memory_set
)(
const
C_Device
device
,
void
*
ptr
,
C_Status
(
*
device_memory_set
)(
const
C_Device
device
,
unsigned
char
value
,
size_t
size
);
void
*
ptr
,
unsigned
char
value
,
size_t
size
);
/**
/**
* @brief Host memory allocate
* @brief Host memory allocate
...
@@ -238,7 +248,8 @@ struct C_DeviceInterface {
...
@@ -238,7 +248,8 @@ struct C_DeviceInterface {
* @param[void**] ptr Plugin allocate an address and fill it
* @param[void**] ptr Plugin allocate an address and fill it
* @param[size_t] size
* @param[size_t] size
*/
*/
C_Status
(
*
host_memory_allocate
)(
const
C_Device
device
,
void
**
ptr
,
C_Status
(
*
host_memory_allocate
)(
const
C_Device
device
,
void
**
ptr
,
size_t
size
);
size_t
size
);
/**
/**
...
@@ -248,7 +259,8 @@ struct C_DeviceInterface {
...
@@ -248,7 +259,8 @@ struct C_DeviceInterface {
* @param[void*] ptr
* @param[void*] ptr
* @param[size_t] size
* @param[size_t] size
*/
*/
C_Status
(
*
host_memory_deallocate
)(
const
C_Device
device
,
void
*
ptr
,
C_Status
(
*
host_memory_deallocate
)(
const
C_Device
device
,
void
*
ptr
,
size_t
size
);
size_t
size
);
/**
/**
...
@@ -258,7 +270,8 @@ struct C_DeviceInterface {
...
@@ -258,7 +270,8 @@ struct C_DeviceInterface {
* @param[void**] ptr Plugin allocate an address and fill it
* @param[void**] ptr Plugin allocate an address and fill it
* @param[size_t] size
* @param[size_t] size
*/
*/
C_Status
(
*
unified_memory_allocate
)(
const
C_Device
device
,
void
**
ptr
,
C_Status
(
*
unified_memory_allocate
)(
const
C_Device
device
,
void
**
ptr
,
size_t
size
);
size_t
size
);
/**
/**
...
@@ -268,7 +281,8 @@ struct C_DeviceInterface {
...
@@ -268,7 +281,8 @@ struct C_DeviceInterface {
* @param[void*] ptr
* @param[void*] ptr
* @param[size_t] size
* @param[size_t] size
*/
*/
C_Status
(
*
unified_memory_deallocate
)(
const
C_Device
device
,
void
*
ptr
,
C_Status
(
*
unified_memory_deallocate
)(
const
C_Device
device
,
void
*
ptr
,
size_t
size
);
size_t
size
);
/**
/**
...
@@ -279,7 +293,9 @@ struct C_DeviceInterface {
...
@@ -279,7 +293,9 @@ struct C_DeviceInterface {
* @param[void*] src
* @param[void*] src
* @param[size_t] size
* @param[size_t] size
*/
*/
C_Status
(
*
memory_copy_h2d
)(
const
C_Device
device
,
void
*
dst
,
const
void
*
src
,
C_Status
(
*
memory_copy_h2d
)(
const
C_Device
device
,
void
*
dst
,
const
void
*
src
,
size_t
size
);
size_t
size
);
/**
/**
...
@@ -290,7 +306,9 @@ struct C_DeviceInterface {
...
@@ -290,7 +306,9 @@ struct C_DeviceInterface {
* @param[void*] src
* @param[void*] src
* @param[size_t] size
* @param[size_t] size
*/
*/
C_Status
(
*
memory_copy_d2h
)(
const
C_Device
device
,
void
*
dst
,
const
void
*
src
,
C_Status
(
*
memory_copy_d2h
)(
const
C_Device
device
,
void
*
dst
,
const
void
*
src
,
size_t
size
);
size_t
size
);
/**
/**
...
@@ -301,7 +319,9 @@ struct C_DeviceInterface {
...
@@ -301,7 +319,9 @@ struct C_DeviceInterface {
* @param[void*] src
* @param[void*] src
* @param[size_t] size
* @param[size_t] size
*/
*/
C_Status
(
*
memory_copy_d2d
)(
const
C_Device
device
,
void
*
dst
,
const
void
*
src
,
C_Status
(
*
memory_copy_d2d
)(
const
C_Device
device
,
void
*
dst
,
const
void
*
src
,
size_t
size
);
size_t
size
);
/**
/**
...
@@ -314,8 +334,10 @@ struct C_DeviceInterface {
...
@@ -314,8 +334,10 @@ struct C_DeviceInterface {
* @param[size_t] size
* @param[size_t] size
*/
*/
C_Status
(
*
memory_copy_p2p
)(
const
C_Device
dst_device
,
C_Status
(
*
memory_copy_p2p
)(
const
C_Device
dst_device
,
const
C_Device
src_device
,
void
*
dst
,
const
C_Device
src_device
,
const
void
*
src
,
size_t
size
);
void
*
dst
,
const
void
*
src
,
size_t
size
);
/**
/**
* @brief Asynchonrize memory copy from host to device
* @brief Asynchonrize memory copy from host to device
...
@@ -326,8 +348,11 @@ struct C_DeviceInterface {
...
@@ -326,8 +348,11 @@ struct C_DeviceInterface {
* @param[void*] src
* @param[void*] src
* @param[size_t] size
* @param[size_t] size
*/
*/
C_Status
(
*
async_memory_copy_h2d
)(
const
C_Device
device
,
C_Stream
stream
,
C_Status
(
*
async_memory_copy_h2d
)(
const
C_Device
device
,
void
*
dst
,
const
void
*
src
,
size_t
size
);
C_Stream
stream
,
void
*
dst
,
const
void
*
src
,
size_t
size
);
/**
/**
* @brief Asynchonrize memory copy from device to host
* @brief Asynchonrize memory copy from device to host
...
@@ -338,8 +363,11 @@ struct C_DeviceInterface {
...
@@ -338,8 +363,11 @@ struct C_DeviceInterface {
* @param[void*] src
* @param[void*] src
* @param[size_t] size
* @param[size_t] size
*/
*/
C_Status
(
*
async_memory_copy_d2h
)(
const
C_Device
device
,
C_Stream
stream
,
C_Status
(
*
async_memory_copy_d2h
)(
const
C_Device
device
,
void
*
dst
,
const
void
*
src
,
size_t
size
);
C_Stream
stream
,
void
*
dst
,
const
void
*
src
,
size_t
size
);
/**
/**
* @brief Asynchonrize memory copy from device to device
* @brief Asynchonrize memory copy from device to device
...
@@ -350,8 +378,11 @@ struct C_DeviceInterface {
...
@@ -350,8 +378,11 @@ struct C_DeviceInterface {
* @param[void*] src
* @param[void*] src
* @param[size_t] size
* @param[size_t] size
*/
*/
C_Status
(
*
async_memory_copy_d2d
)(
const
C_Device
device
,
C_Stream
stream
,
C_Status
(
*
async_memory_copy_d2d
)(
const
C_Device
device
,
void
*
dst
,
const
void
*
src
,
size_t
size
);
C_Stream
stream
,
void
*
dst
,
const
void
*
src
,
size_t
size
);
/**
/**
* @brief Peer asynchonrize memory copy from host to device
* @brief Peer asynchonrize memory copy from host to device
...
@@ -363,8 +394,11 @@ struct C_DeviceInterface {
...
@@ -363,8 +394,11 @@ struct C_DeviceInterface {
* @param[size_t] size
* @param[size_t] size
*/
*/
C_Status
(
*
async_memory_copy_p2p
)(
const
C_Device
dst_device
,
C_Status
(
*
async_memory_copy_p2p
)(
const
C_Device
dst_device
,
const
C_Device
src_device
,
C_Stream
stream
,
const
C_Device
src_device
,
void
*
dst
,
const
void
*
src
,
size_t
size
);
C_Stream
stream
,
void
*
dst
,
const
void
*
src
,
size_t
size
);
void
*
reserved_mem_api
[
8
];
void
*
reserved_mem_api
[
8
];
...
@@ -394,7 +428,8 @@ struct C_DeviceInterface {
...
@@ -394,7 +428,8 @@ struct C_DeviceInterface {
* @param[size_t*] free_memory
* @param[size_t*] free_memory
* @param[size_t*] used_memory
* @param[size_t*] used_memory
*/
*/
C_Status
(
*
device_memory_stats
)(
const
C_Device
device
,
size_t
*
total_memory
,
C_Status
(
*
device_memory_stats
)(
const
C_Device
device
,
size_t
*
total_memory
,
size_t
*
free_memory
);
size_t
*
free_memory
);
/**
/**
...
...
paddle/
fluid/platform/device
/device_guard.cc
→
paddle/
phi/backends
/device_guard.cc
浏览文件 @
b53cdc9e
...
@@ -12,11 +12,9 @@
...
@@ -12,11 +12,9 @@
// See the License for the specific language governing permissions and
// See the License for the specific language governing permissions and
// limitations under the License.
// limitations under the License.
#include "paddle/
fluid/platform/device
/device_guard.h"
#include "paddle/
phi/backends
/device_guard.h"
namespace
paddle
{
namespace
phi
{
namespace
platform
{
// Even this source file does not contains any code, it is better to keep this
// Even this source file does not contains any code, it is better to keep this
// source file for cmake dependency.
// source file for cmake dependency.
}
// namespace platform
}
// namespace phi
}
// namespace paddle
paddle/
fluid/platform/device
/device_guard.h
→
paddle/
phi/backends
/device_guard.h
浏览文件 @
b53cdc9e
...
@@ -13,17 +13,16 @@
...
@@ -13,17 +13,16 @@
// limitations under the License.
// limitations under the License.
#pragma once
#pragma once
#include "paddle/
fluid/platform/device
/device_manager.h"
#include "paddle/
phi/backends
/device_manager.h"
namespace
paddle
{
namespace
phi
{
namespace
platform
{
class
DeviceGuard
{
class
DeviceGuard
{
public:
public:
explicit
inline
DeviceGuard
(
const
Place
&
place
)
explicit
inline
DeviceGuard
(
const
Place
&
place
)
:
dev_type_
(
PlaceHelper
::
GetDeviceType
(
place
))
{
:
dev_type_
(
place
.
GetDeviceType
(
))
{
prev_id
=
DeviceManager
::
GetDevice
(
dev_type_
);
prev_id
=
DeviceManager
::
GetDevice
(
dev_type_
);
cur_id
=
PlaceHelper
::
GetDeviceId
(
place
);
cur_id
=
place
.
GetDeviceId
(
);
if
(
cur_id
!=
prev_id
)
{
if
(
cur_id
!=
prev_id
)
{
DeviceManager
::
SetDevice
(
dev_type_
,
cur_id
);
DeviceManager
::
SetDevice
(
dev_type_
,
cur_id
);
...
@@ -44,5 +43,4 @@ class DeviceGuard {
...
@@ -44,5 +43,4 @@ class DeviceGuard {
std
::
string
dev_type_
;
std
::
string
dev_type_
;
};
};
}
// namespace platform
}
// namespace phi
}
// namespace paddle
paddle/
fluid/platform/device
/device_manager.cc
→
paddle/
phi/backends
/device_manager.cc
浏览文件 @
b53cdc9e
...
@@ -13,7 +13,7 @@
...
@@ -13,7 +13,7 @@
// limitations under the License.
// limitations under the License.
#ifdef PADDLE_WITH_CUSTOM_DEVICE
#ifdef PADDLE_WITH_CUSTOM_DEVICE
#include "paddle/
fluid/platform/device
/device_manager.h"
#include "paddle/
phi/backends
/device_manager.h"
#if !defined(_WIN32)
#if !defined(_WIN32)
#include <dirent.h>
#include <dirent.h>
...
@@ -24,8 +24,7 @@
...
@@ -24,8 +24,7 @@
#include <functional>
#include <functional>
#include <regex>
#include <regex>
namespace
paddle
{
namespace
phi
{
namespace
platform
{
void
Device
::
CreateStream
(
stream
::
Stream
*
stream
,
void
Device
::
CreateStream
(
stream
::
Stream
*
stream
,
const
stream
::
Stream
::
Priority
&
priority
,
const
stream
::
Stream
::
Priority
&
priority
,
...
@@ -76,23 +75,32 @@ void Device::StreamWaitEvent(const stream::Stream* stream,
...
@@ -76,23 +75,32 @@ void Device::StreamWaitEvent(const stream::Stream* stream,
impl_
->
StreamWaitEvent
(
dev_id_
,
stream
,
event
);
impl_
->
StreamWaitEvent
(
dev_id_
,
stream
,
event
);
}
}
void
Device
::
MemoryCopyH2D
(
void
*
dst
,
const
void
*
src
,
size_t
size
,
void
Device
::
MemoryCopyH2D
(
void
*
dst
,
const
void
*
src
,
size_t
size
,
const
stream
::
Stream
*
stream
)
{
const
stream
::
Stream
*
stream
)
{
impl_
->
MemoryCopyH2D
(
dev_id_
,
dst
,
src
,
size
,
stream
);
impl_
->
MemoryCopyH2D
(
dev_id_
,
dst
,
src
,
size
,
stream
);
}
}
void
Device
::
MemoryCopyD2H
(
void
*
dst
,
const
void
*
src
,
size_t
size
,
void
Device
::
MemoryCopyD2H
(
void
*
dst
,
const
void
*
src
,
size_t
size
,
const
stream
::
Stream
*
stream
)
{
const
stream
::
Stream
*
stream
)
{
impl_
->
MemoryCopyD2H
(
dev_id_
,
dst
,
src
,
size
,
stream
);
impl_
->
MemoryCopyD2H
(
dev_id_
,
dst
,
src
,
size
,
stream
);
}
}
void
Device
::
MemoryCopyD2D
(
void
*
dst
,
const
void
*
src
,
size_t
size
,
void
Device
::
MemoryCopyD2D
(
void
*
dst
,
const
void
*
src
,
size_t
size
,
const
stream
::
Stream
*
stream
)
{
const
stream
::
Stream
*
stream
)
{
impl_
->
MemoryCopyD2D
(
dev_id_
,
dst
,
src
,
size
,
stream
);
impl_
->
MemoryCopyD2D
(
dev_id_
,
dst
,
src
,
size
,
stream
);
}
}
void
Device
::
MemoryCopyP2P
(
const
Place
&
dst_place
,
void
*
dst
,
const
void
*
src
,
void
Device
::
MemoryCopyP2P
(
const
Place
&
dst_place
,
size_t
size
,
const
stream
::
Stream
*
stream
)
{
void
*
dst
,
const
void
*
src
,
size_t
size
,
const
stream
::
Stream
*
stream
)
{
impl_
->
MemoryCopyP2P
(
dst_place
,
dst
,
dev_id_
,
src
,
size
,
stream
);
impl_
->
MemoryCopyP2P
(
dst_place
,
dst
,
dev_id_
,
src
,
size
,
stream
);
}
}
...
@@ -173,7 +181,7 @@ DeviceInterface* DeviceManager::GetDeviceInterfaceWithType(
...
@@ -173,7 +181,7 @@ DeviceInterface* DeviceManager::GetDeviceInterfaceWithType(
}
else
{
}
else
{
LOG
(
ERROR
)
<<
"GetDeviceInterfaceWithType - "
<<
device_type
<<
" Failed
\n
"
;
LOG
(
ERROR
)
<<
"GetDeviceInterfaceWithType - "
<<
device_type
<<
" Failed
\n
"
;
PADDLE_THROW
(
PADDLE_THROW
(
p
latform
::
errors
::
Fatal
(
"Unregistered device type %s."
,
device_type
));
p
hi
::
errors
::
Fatal
(
"Unregistered device type %s."
,
device_type
));
return
nullptr
;
return
nullptr
;
}
}
}
}
...
@@ -182,17 +190,21 @@ Device* DeviceManager::GetDeviceWithPlace(const Place& place) {
...
@@ -182,17 +190,21 @@ Device* DeviceManager::GetDeviceWithPlace(const Place& place) {
phi
::
AutoRDLock
lock
(
&
_global_device_manager_rw_lock
);
phi
::
AutoRDLock
lock
(
&
_global_device_manager_rw_lock
);
auto
&
dev_map
=
Instance
().
device_map_
;
auto
&
dev_map
=
Instance
().
device_map_
;
auto
dev_type
=
PlaceHelper
::
GetDeviceType
(
place
);
auto
dev_type
=
place
.
GetDeviceType
();
auto
dev_id
=
PlaceHelper
::
GetDeviceId
(
place
);
auto
dev_id
=
place
.
GetDeviceId
();
PADDLE_ENFORCE_NE
(
dev_map
.
find
(
dev_type
),
dev_map
.
end
(),
PADDLE_ENFORCE_NE
(
platform
::
errors
::
NotFound
(
dev_map
.
find
(
dev_type
),
"Unable to find Device with type %s."
,
dev_type
));
dev_map
.
end
(),
phi
::
errors
::
NotFound
(
"Unable to find Device with type %s."
,
dev_type
));
auto
&
dev_vec
=
dev_map
[
dev_type
];
auto
&
dev_vec
=
dev_map
[
dev_type
];
PADDLE_ENFORCE_LT
(
PADDLE_ENFORCE_LT
(
dev_id
,
dev_vec
.
size
(),
dev_id
,
platform
::
errors
::
OutOfRange
(
dev_vec
.
size
(),
phi
::
errors
::
OutOfRange
(
"The visible devices count of type %s is %d, but dev_id is %d."
,
"The visible devices count of type %s is %d, but dev_id is %d."
,
dev_type
,
dev_vec
.
size
(),
dev_id
));
dev_type
,
dev_vec
.
size
(),
dev_id
));
return
dev_vec
[
dev_id
].
get
();
return
dev_vec
[
dev_id
].
get
();
}
}
...
@@ -277,22 +289,22 @@ void DeviceManager::Finalize(const std::string& device_type) {
...
@@ -277,22 +289,22 @@ void DeviceManager::Finalize(const std::string& device_type) {
}
}
void
DeviceManager
::
SynchronizeDevice
(
const
Place
&
place
)
{
void
DeviceManager
::
SynchronizeDevice
(
const
Place
&
place
)
{
auto
device_type
=
PlaceHelper
::
GetDeviceType
(
place
);
auto
device_type
=
place
.
GetDeviceType
(
);
auto
device_id
=
PlaceHelper
::
GetDeviceId
(
place
);
auto
device_id
=
place
.
GetDeviceId
(
);
auto
dev_impl
=
GetDeviceInterfaceWithType
(
device_type
);
auto
dev_impl
=
GetDeviceInterfaceWithType
(
device_type
);
dev_impl
->
SynchronizeDevice
(
device_id
);
dev_impl
->
SynchronizeDevice
(
device_id
);
}
}
void
DeviceManager
::
InitDevice
(
const
Place
&
place
)
{
void
DeviceManager
::
InitDevice
(
const
Place
&
place
)
{
auto
device_type
=
PlaceHelper
::
GetDeviceType
(
place
);
auto
device_type
=
place
.
GetDeviceType
(
);
auto
device_id
=
PlaceHelper
::
GetDeviceId
(
place
);
auto
device_id
=
place
.
GetDeviceId
(
);
auto
dev_impl
=
GetDeviceInterfaceWithType
(
device_type
);
auto
dev_impl
=
GetDeviceInterfaceWithType
(
device_type
);
dev_impl
->
InitDevice
(
device_id
);
dev_impl
->
InitDevice
(
device_id
);
}
}
void
DeviceManager
::
DeInitDevice
(
const
Place
&
place
)
{
void
DeviceManager
::
DeInitDevice
(
const
Place
&
place
)
{
auto
device_type
=
PlaceHelper
::
GetDeviceType
(
place
);
auto
device_type
=
place
.
GetDeviceType
(
);
auto
device_id
=
PlaceHelper
::
GetDeviceId
(
place
);
auto
device_id
=
place
.
GetDeviceId
(
);
auto
dev_impl
=
GetDeviceInterfaceWithType
(
device_type
);
auto
dev_impl
=
GetDeviceInterfaceWithType
(
device_type
);
dev_impl
->
DeInitDevice
(
device_id
);
dev_impl
->
DeInitDevice
(
device_id
);
}
}
...
@@ -304,8 +316,8 @@ void DeviceManager::SetDevice(const std::string& device_type,
...
@@ -304,8 +316,8 @@ void DeviceManager::SetDevice(const std::string& device_type,
}
}
void
DeviceManager
::
SetDevice
(
const
Place
&
place
)
{
void
DeviceManager
::
SetDevice
(
const
Place
&
place
)
{
auto
device_type
=
PlaceHelper
::
GetDeviceType
(
place
);
auto
device_type
=
place
.
GetDeviceType
(
);
auto
device_id
=
PlaceHelper
::
GetDeviceId
(
place
);
auto
device_id
=
place
.
GetDeviceId
(
);
DeviceManager
::
SetDevice
(
device_type
,
device_id
);
DeviceManager
::
SetDevice
(
device_type
,
device_id
);
}
}
...
@@ -315,51 +327,52 @@ int DeviceManager::GetDevice(const std::string& device_type) {
...
@@ -315,51 +327,52 @@ int DeviceManager::GetDevice(const std::string& device_type) {
}
}
size_t
DeviceManager
::
GetMinChunkSize
(
const
Place
&
place
)
{
size_t
DeviceManager
::
GetMinChunkSize
(
const
Place
&
place
)
{
auto
device_type
=
PlaceHelper
::
GetDeviceType
(
place
);
auto
device_type
=
place
.
GetDeviceType
(
);
auto
device_id
=
PlaceHelper
::
GetDeviceId
(
place
);
auto
device_id
=
place
.
GetDeviceId
(
);
auto
dev_impl
=
GetDeviceInterfaceWithType
(
device_type
);
auto
dev_impl
=
GetDeviceInterfaceWithType
(
device_type
);
return
dev_impl
->
GetMinChunkSize
(
device_id
);
return
dev_impl
->
GetMinChunkSize
(
device_id
);
}
}
size_t
DeviceManager
::
GetMaxChunkSize
(
const
Place
&
place
)
{
size_t
DeviceManager
::
GetMaxChunkSize
(
const
Place
&
place
)
{
auto
device_type
=
PlaceHelper
::
GetDeviceType
(
place
);
auto
device_type
=
place
.
GetDeviceType
(
);
auto
device_id
=
PlaceHelper
::
GetDeviceId
(
place
);
auto
device_id
=
place
.
GetDeviceId
(
);
auto
dev_impl
=
GetDeviceInterfaceWithType
(
device_type
);
auto
dev_impl
=
GetDeviceInterfaceWithType
(
device_type
);
return
dev_impl
->
GetMaxChunkSize
(
device_id
);
return
dev_impl
->
GetMaxChunkSize
(
device_id
);
}
}
size_t
DeviceManager
::
GetMaxAllocSize
(
const
Place
&
place
)
{
size_t
DeviceManager
::
GetMaxAllocSize
(
const
Place
&
place
)
{
auto
device_type
=
PlaceHelper
::
GetDeviceType
(
place
);
auto
device_type
=
place
.
GetDeviceType
(
);
auto
device_id
=
PlaceHelper
::
GetDeviceId
(
place
);
auto
device_id
=
place
.
GetDeviceId
(
);
auto
dev_impl
=
GetDeviceInterfaceWithType
(
device_type
);
auto
dev_impl
=
GetDeviceInterfaceWithType
(
device_type
);
return
dev_impl
->
GetMaxAllocSize
(
device_id
);
return
dev_impl
->
GetMaxAllocSize
(
device_id
);
}
}
size_t
DeviceManager
::
GetInitAllocSize
(
const
Place
&
place
)
{
size_t
DeviceManager
::
GetInitAllocSize
(
const
Place
&
place
)
{
auto
device_type
=
PlaceHelper
::
GetDeviceType
(
place
);
auto
device_type
=
place
.
GetDeviceType
(
);
auto
device_id
=
PlaceHelper
::
GetDeviceId
(
place
);
auto
device_id
=
place
.
GetDeviceId
(
);
auto
dev_impl
=
GetDeviceInterfaceWithType
(
device_type
);
auto
dev_impl
=
GetDeviceInterfaceWithType
(
device_type
);
return
dev_impl
->
GetInitAllocSize
(
device_id
);
return
dev_impl
->
GetInitAllocSize
(
device_id
);
}
}
size_t
DeviceManager
::
GetReallocSize
(
const
Place
&
place
)
{
size_t
DeviceManager
::
GetReallocSize
(
const
Place
&
place
)
{
auto
device_type
=
PlaceHelper
::
GetDeviceType
(
place
);
auto
device_type
=
place
.
GetDeviceType
(
);
auto
device_id
=
PlaceHelper
::
GetDeviceId
(
place
);
auto
device_id
=
place
.
GetDeviceId
(
);
auto
dev_impl
=
GetDeviceInterfaceWithType
(
device_type
);
auto
dev_impl
=
GetDeviceInterfaceWithType
(
device_type
);
return
dev_impl
->
GetReallocSize
(
device_id
);
return
dev_impl
->
GetReallocSize
(
device_id
);
}
}
size_t
DeviceManager
::
GetExtraPaddingSize
(
const
Place
&
place
)
{
size_t
DeviceManager
::
GetExtraPaddingSize
(
const
Place
&
place
)
{
auto
device_type
=
PlaceHelper
::
GetDeviceType
(
place
);
auto
device_type
=
place
.
GetDeviceType
(
);
auto
device_id
=
PlaceHelper
::
GetDeviceId
(
place
);
auto
device_id
=
place
.
GetDeviceId
(
);
auto
dev_impl
=
GetDeviceInterfaceWithType
(
device_type
);
auto
dev_impl
=
GetDeviceInterfaceWithType
(
device_type
);
return
dev_impl
->
GetExtraPaddingSize
(
device_id
);
return
dev_impl
->
GetExtraPaddingSize
(
device_id
);
}
}
void
DeviceManager
::
MemoryStats
(
const
Place
&
place
,
size_t
*
total
,
void
DeviceManager
::
MemoryStats
(
const
Place
&
place
,
size_t
*
total
,
size_t
*
free
)
{
size_t
*
free
)
{
auto
device_type
=
PlaceHelper
::
GetDeviceType
(
place
);
auto
device_type
=
place
.
GetDeviceType
(
);
auto
device_id
=
PlaceHelper
::
GetDeviceId
(
place
);
auto
device_id
=
place
.
GetDeviceId
(
);
auto
dev_impl
=
GetDeviceInterfaceWithType
(
device_type
);
auto
dev_impl
=
GetDeviceInterfaceWithType
(
device_type
);
dev_impl
->
MemoryStats
(
device_id
,
total
,
free
);
dev_impl
->
MemoryStats
(
device_id
,
total
,
free
);
}
}
...
@@ -393,8 +406,8 @@ std::vector<std::string> ListAllLibraries(const std::string& library_dir) {
...
@@ -393,8 +406,8 @@ std::vector<std::string> ListAllLibraries(const std::string& library_dir) {
}
else
{
}
else
{
while
((
ptr
=
readdir
(
dir
))
!=
nullptr
)
{
while
((
ptr
=
readdir
(
dir
))
!=
nullptr
)
{
std
::
string
filename
(
ptr
->
d_name
);
std
::
string
filename
(
ptr
->
d_name
);
if
(
std
::
regex_match
(
filename
.
begin
(),
filename
.
end
(),
results
,
if
(
std
::
regex_match
(
express
))
{
filename
.
begin
(),
filename
.
end
(),
results
,
express
))
{
libraries
.
push_back
(
library_dir
+
'/'
+
filename
);
libraries
.
push_back
(
library_dir
+
'/'
+
filename
);
VLOG
(
4
)
<<
"Found lib: "
<<
libraries
.
back
();
VLOG
(
4
)
<<
"Found lib: "
<<
libraries
.
back
();
}
}
...
@@ -405,6 +418,5 @@ std::vector<std::string> ListAllLibraries(const std::string& library_dir) {
...
@@ -405,6 +418,5 @@ std::vector<std::string> ListAllLibraries(const std::string& library_dir) {
return
libraries
;
return
libraries
;
}
}
}
// namespace platform
}
// namespace phi
}
// namespace paddle
#endif
#endif
paddle/
fluid/platform/device
/device_manager.h
→
paddle/
phi/backends
/device_manager.h
浏览文件 @
b53cdc9e
...
@@ -15,17 +15,16 @@
...
@@ -15,17 +15,16 @@
#pragma once
#pragma once
#ifdef PADDLE_WITH_CUSTOM_DEVICE
#ifdef PADDLE_WITH_CUSTOM_DEVICE
#include "paddle/
fluid/platform/device
/device_base.h"
#include "paddle/
phi/backends
/device_base.h"
#include "paddle/
fluid/platform/device
/device_ext.h"
#include "paddle/
phi/backends
/device_ext.h"
#include "paddle/
fluid/platform/device
/event.h"
#include "paddle/
phi/backends
/event.h"
#include "paddle/
fluid/platform/device
/stream.h"
#include "paddle/
phi/backends
/stream.h"
#include "paddle/
fluid/platform
/place.h"
#include "paddle/
phi/common
/place.h"
#include "paddle/phi/backends/dynload/port.h"
#include "paddle/phi/backends/dynload/port.h"
#include "paddle/phi/core/utils/rw_lock.h"
#include "paddle/phi/core/utils/rw_lock.h"
namespace
paddle
{
namespace
phi
{
namespace
platform
{
class
Device
final
{
class
Device
final
{
public:
public:
Device
(
size_t
dev_id
,
DeviceInterface
*
impl
)
:
dev_id_
(
dev_id
),
impl_
(
impl
)
{}
Device
(
size_t
dev_id
,
DeviceInterface
*
impl
)
:
dev_id_
(
dev_id
),
impl_
(
impl
)
{}
...
@@ -33,8 +32,9 @@ class Device final {
...
@@ -33,8 +32,9 @@ class Device final {
// Stream
// Stream
// ! Create an asynchronous stream
// ! Create an asynchronous stream
void
CreateStream
(
void
CreateStream
(
stream
::
Stream
*
stream
,
const
stream
::
Stream
::
Priority
&
priority
=
stream
::
Stream
*
stream
,
stream
::
Stream
::
Priority
::
kNormal
,
const
stream
::
Stream
::
Priority
&
priority
=
stream
::
Stream
::
Priority
::
kNormal
,
const
stream
::
Stream
::
Flag
&
flag
=
stream
::
Stream
::
Flag
::
kDefaultFlag
);
const
stream
::
Stream
::
Flag
&
flag
=
stream
::
Stream
::
Flag
::
kDefaultFlag
);
// ! Destroys an asynchronous stream.
// ! Destroys an asynchronous stream.
...
@@ -69,17 +69,26 @@ class Device final {
...
@@ -69,17 +69,26 @@ class Device final {
void
StreamWaitEvent
(
const
stream
::
Stream
*
stream
,
const
event
::
Event
*
event
);
void
StreamWaitEvent
(
const
stream
::
Stream
*
stream
,
const
event
::
Event
*
event
);
// Memory
// Memory
void
MemoryCopyH2D
(
void
*
dst
,
const
void
*
src
,
size_t
size
,
void
MemoryCopyH2D
(
void
*
dst
,
const
void
*
src
,
size_t
size
,
const
stream
::
Stream
*
stream
=
nullptr
);
const
stream
::
Stream
*
stream
=
nullptr
);
void
MemoryCopyD2H
(
void
*
dst
,
const
void
*
src
,
size_t
size
,
void
MemoryCopyD2H
(
void
*
dst
,
const
void
*
src
,
size_t
size
,
const
stream
::
Stream
*
stream
=
nullptr
);
const
stream
::
Stream
*
stream
=
nullptr
);
void
MemoryCopyD2D
(
void
*
dst
,
const
void
*
src
,
size_t
size
,
void
MemoryCopyD2D
(
void
*
dst
,
const
void
*
src
,
size_t
size
,
const
stream
::
Stream
*
stream
=
nullptr
);
const
stream
::
Stream
*
stream
=
nullptr
);
void
MemoryCopyP2P
(
const
Place
&
dst_place
,
void
*
dst
,
const
void
*
src
,
void
MemoryCopyP2P
(
const
Place
&
dst_place
,
size_t
size
,
const
stream
::
Stream
*
stream
=
nullptr
);
void
*
dst
,
const
void
*
src
,
size_t
size
,
const
stream
::
Stream
*
stream
=
nullptr
);
void
*
MemoryAllocate
(
size_t
size
);
void
*
MemoryAllocate
(
size_t
size
);
...
@@ -168,7 +177,8 @@ void LoadCustomRuntimeLib(const std::string& dso_lib_path, void* dso_handle);
...
@@ -168,7 +177,8 @@ void LoadCustomRuntimeLib(const std::string& dso_lib_path, void* dso_handle);
void
LoadCustomRuntimeLib
(
const
CustomRuntimeParams
&
runtime_params
,
void
LoadCustomRuntimeLib
(
const
CustomRuntimeParams
&
runtime_params
,
std
::
unique_ptr
<
C_DeviceInterface
>
device_interface
,
std
::
unique_ptr
<
C_DeviceInterface
>
device_interface
,
const
std
::
string
&
dso_lib_path
,
void
*
dso_handle
);
const
std
::
string
&
dso_lib_path
,
void
*
dso_handle
);
class
Registrar
{
class
Registrar
{
public:
public:
...
@@ -180,7 +190,6 @@ class Registrar {
...
@@ -180,7 +190,6 @@ class Registrar {
void
Touch
()
{}
void
Touch
()
{}
};
};
}
// namespace platform
}
// namespace phi
}
// namespace paddle
#endif
#endif
paddle/
fluid/platform/device
/event.cc
→
paddle/
phi/backends
/event.cc
浏览文件 @
b53cdc9e
...
@@ -12,13 +12,12 @@
...
@@ -12,13 +12,12 @@
// See the License for the specific language governing permissions and
// See the License for the specific language governing permissions and
// limitations under the License.
// limitations under the License.
#include "paddle/fluid/platform/device/event.h"
#include "paddle/phi/backends/event.h"
#include "paddle/fluid/platform/device/device_guard.h"
#include "paddle/fluid/platform/device/device_wrapper.h"
#include "paddle/fluid/platform/device/device_wrapper.h"
#include "paddle/fluid/platform/device/stream.h"
#include "paddle/phi/backends/device_guard.h"
#include "paddle/phi/backends/stream.h"
namespace
paddle
{
namespace
phi
{
namespace
platform
{
namespace
event
{
namespace
event
{
event_t
Event
::
raw_event
()
const
{
return
event_
;
}
event_t
Event
::
raw_event
()
const
{
return
event_
;
}
...
@@ -27,7 +26,7 @@ void Event::set_event(event_t event) { event_ = event; }
...
@@ -27,7 +26,7 @@ void Event::set_event(event_t event) { event_ = event; }
Event
::
Event
(
const
Place
&
place
,
event_t
event
)
Event
::
Event
(
const
Place
&
place
,
event_t
event
)
:
place_
(
place
),
:
place_
(
place
),
device_
(
p
latform
::
DeviceManager
::
GetDeviceWithPlace
(
place
)),
device_
(
p
hi
::
DeviceManager
::
GetDeviceWithPlace
(
place
)),
event_
(
event
),
event_
(
event
),
own_data_
(
false
)
{}
own_data_
(
false
)
{}
...
@@ -60,5 +59,4 @@ void Event::Synchonrize() const { device_->SynchronizeEvent(this); }
...
@@ -60,5 +59,4 @@ void Event::Synchonrize() const { device_->SynchronizeEvent(this); }
const
Place
&
Event
::
GetPlace
()
const
{
return
place_
;
}
const
Place
&
Event
::
GetPlace
()
const
{
return
place_
;
}
}
// namespace event
}
// namespace event
}
// namespace platform
}
// namespace phi
}
// namespace paddle
paddle/
fluid/platform/device
/event.h
→
paddle/
phi/backends
/event.h
浏览文件 @
b53cdc9e
...
@@ -15,8 +15,7 @@
...
@@ -15,8 +15,7 @@
#pragma once
#pragma once
#include "paddle/fluid/platform/place.h"
#include "paddle/fluid/platform/place.h"
namespace
paddle
{
namespace
phi
{
namespace
platform
{
class
Device
;
class
Device
;
...
@@ -57,5 +56,4 @@ class Event {
...
@@ -57,5 +56,4 @@ class Event {
};
};
}
// namespace event
}
// namespace event
}
// namespace platform
}
// namespace phi
}
// namespace paddle
paddle/
fluid/platform/device
/stream.cc
→
paddle/
phi/backends
/stream.cc
浏览文件 @
b53cdc9e
...
@@ -12,13 +12,12 @@
...
@@ -12,13 +12,12 @@
// See the License for the specific language governing permissions and
// See the License for the specific language governing permissions and
// limitations under the License.
// limitations under the License.
#include "paddle/fluid/platform/device/stream.h"
#include "paddle/phi/backends/stream.h"
#include "paddle/fluid/platform/device/device_guard.h"
#include "paddle/fluid/platform/device/device_wrapper.h"
#include "paddle/fluid/platform/device/device_wrapper.h"
#include "paddle/fluid/platform/device/event.h"
#include "paddle/phi/backends/device_guard.h"
#include "paddle/phi/backends/event.h"
namespace
paddle
{
namespace
phi
{
namespace
platform
{
namespace
stream
{
namespace
stream
{
Stream
::~
Stream
()
{
Destroy
();
}
Stream
::~
Stream
()
{
Destroy
();
}
...
@@ -30,15 +29,16 @@ void Stream::set_stream(stream_t stream) { stream_ = stream; }
...
@@ -30,15 +29,16 @@ void Stream::set_stream(stream_t stream) { stream_ = stream; }
// For compatiable
// For compatiable
Stream
::
Stream
(
const
Place
&
place
,
stream_t
stream
)
Stream
::
Stream
(
const
Place
&
place
,
stream_t
stream
)
:
place_
(
place
),
:
place_
(
place
),
device_
(
p
latform
::
DeviceManager
::
GetDeviceWithPlace
(
place
)),
device_
(
p
hi
::
DeviceManager
::
GetDeviceWithPlace
(
place
)),
stream_
(
stream
),
stream_
(
stream
),
callback_manager_
(
new
CallbackManager
(
this
)),
callback_manager_
(
new
CallbackManager
(
this
)),
own_data_
(
false
)
{}
own_data_
(
false
)
{}
bool
Stream
::
Init
(
const
Place
&
place
,
const
Priority
&
priority
,
bool
Stream
::
Init
(
const
Place
&
place
,
const
Priority
&
priority
,
const
Flag
&
flag
)
{
const
Flag
&
flag
)
{
place_
=
place
;
place_
=
place
;
device_
=
p
latform
::
DeviceManager
::
GetDeviceWithPlace
(
place
);
device_
=
p
hi
::
DeviceManager
::
GetDeviceWithPlace
(
place
);
DeviceGuard
guard
(
place_
);
DeviceGuard
guard
(
place_
);
device_
->
CreateStream
(
this
,
priority
,
flag
);
device_
->
CreateStream
(
this
,
priority
,
flag
);
...
@@ -92,5 +92,4 @@ void Stream::Synchronize() const { device_->SynchronizeStream(this); }
...
@@ -92,5 +92,4 @@ void Stream::Synchronize() const { device_->SynchronizeStream(this); }
const
Place
&
Stream
::
GetPlace
()
const
{
return
place_
;
}
const
Place
&
Stream
::
GetPlace
()
const
{
return
place_
;
}
}
// namespace stream
}
// namespace stream
}
// namespace platform
}
// namespace phi
}
// namespace paddle
paddle/
fluid/platform/device
/stream.h
→
paddle/
phi/backends
/stream.h
浏览文件 @
b53cdc9e
...
@@ -14,11 +14,10 @@
...
@@ -14,11 +14,10 @@
#pragma once
#pragma once
#include "paddle/fluid/platform/device/callback_manager.h"
#include "paddle/fluid/platform/place.h"
#include "paddle/fluid/platform/place.h"
#include "paddle/phi/backends/callback_manager.h"
namespace
paddle
{
namespace
phi
{
namespace
platform
{
class
Device
;
class
Device
;
...
@@ -49,7 +48,8 @@ class Stream {
...
@@ -49,7 +48,8 @@ class Stream {
~
Stream
();
~
Stream
();
const
stream_t
&
raw_stream
()
const
;
const
stream_t
&
raw_stream
()
const
;
void
set_stream
(
stream_t
stream
);
void
set_stream
(
stream_t
stream
);
bool
Init
(
const
Place
&
place
,
const
Priority
&
priority
=
Priority
::
kNormal
,
bool
Init
(
const
Place
&
place
,
const
Priority
&
priority
=
Priority
::
kNormal
,
const
Flag
&
flag
=
Flag
::
kDefaultFlag
);
const
Flag
&
flag
=
Flag
::
kDefaultFlag
);
template
<
typename
Callback
>
template
<
typename
Callback
>
void
AddCallback
(
Callback
&&
callback
)
const
{
void
AddCallback
(
Callback
&&
callback
)
const
{
...
@@ -75,5 +75,4 @@ class Stream {
...
@@ -75,5 +75,4 @@ class Stream {
};
};
}
// namespace stream
}
// namespace stream
}
// namespace platform
}
// namespace phi
}
// namespace paddle
paddle/phi/core/CMakeLists.txt
浏览文件 @
b53cdc9e
...
@@ -25,7 +25,7 @@ cc_library(infermeta_utils SRCS infermeta_utils.cc DEPS meta_tensor)
...
@@ -25,7 +25,7 @@ cc_library(infermeta_utils SRCS infermeta_utils.cc DEPS meta_tensor)
cc_library
(
selected_rows SRCS selected_rows_impl.cc DEPS dense_tensor phi_enforce ddim memcpy
)
cc_library
(
selected_rows SRCS selected_rows_impl.cc DEPS dense_tensor phi_enforce ddim memcpy
)
cc_library
(
phi_device_context SRCS device_context.cc DEPS dense_tensor selected_rows
)
cc_library
(
phi_device_context SRCS device_context.cc DEPS dense_tensor selected_rows
)
cc_library
(
phi_custom_kernel SRCS custom_kernel.cc DEPS kernel_factory convert_utils
)
cc_library
(
phi_custom_kernel SRCS custom_kernel.cc DEPS kernel_factory convert_utils
op_registry phi_tensor_raw
)
# Will remove once we implemented MKLDNN_Tensor
# Will remove once we implemented MKLDNN_Tensor
if
(
WITH_MKLDNN
)
if
(
WITH_MKLDNN
)
...
...
paddle/phi/core/compat/convert_utils.cc
浏览文件 @
b53cdc9e
...
@@ -20,7 +20,7 @@ limitations under the License. */
...
@@ -20,7 +20,7 @@ limitations under the License. */
#include "paddle/phi/core/compat/op_utils.h"
#include "paddle/phi/core/compat/op_utils.h"
#ifdef PADDLE_WITH_CUSTOM_DEVICE
#ifdef PADDLE_WITH_CUSTOM_DEVICE
#include "paddle/
fluid/platform/device
/device_manager.h"
#include "paddle/
phi/backends
/device_manager.h"
#endif
#endif
namespace
phi
{
namespace
phi
{
...
@@ -83,9 +83,7 @@ phi::Place TransToPhiPlace(const Backend& backend, bool set_device_id) {
...
@@ -83,9 +83,7 @@ phi::Place TransToPhiPlace(const Backend& backend, bool set_device_id) {
if
(
!
device_type
.
empty
())
{
if
(
!
device_type
.
empty
())
{
return
phi
::
CustomPlace
(
return
phi
::
CustomPlace
(
device_type
,
device_type
,
set_device_id
set_device_id
?
phi
::
DeviceManager
::
GetDevice
(
device_type
)
:
0
);
?
paddle
::
platform
::
DeviceManager
::
GetDevice
(
device_type
)
:
0
);
}
}
#endif
#endif
PADDLE_THROW
(
phi
::
errors
::
Unimplemented
(
PADDLE_THROW
(
phi
::
errors
::
Unimplemented
(
...
...
paddle/phi/core/custom_kernel.cc
浏览文件 @
b53cdc9e
...
@@ -12,6 +12,11 @@
...
@@ -12,6 +12,11 @@
// See the License for the specific language governing permissions and
// See the License for the specific language governing permissions and
// limitations under the License.
// limitations under the License.
#if defined _WIN32 || defined __APPLE__
#else
#define _LINUX
#endif
#include "paddle/phi/core/custom_kernel.h"
#include "paddle/phi/core/custom_kernel.h"
namespace
phi
{
namespace
phi
{
...
@@ -50,6 +55,25 @@ void RegisterCustomKernels(const CustomKernelMap& custom_kernel_map) {
...
@@ -50,6 +55,25 @@ void RegisterCustomKernels(const CustomKernelMap& custom_kernel_map) {
}
}
}
}
void
LoadCustomKernelLib
(
const
std
::
string
&
dso_lib_path
,
void
*
dso_handle
)
{
#ifdef _LINUX
typedef
phi
::
CustomKernelMap
&
get_custom_kernel_map_t
();
auto
*
func
=
reinterpret_cast
<
get_custom_kernel_map_t
*>
(
dlsym
(
dso_handle
,
"PD_GetCustomKernelMap"
));
if
(
func
==
nullptr
)
{
LOG
(
WARNING
)
<<
"Skipped lib ["
<<
dso_lib_path
<<
"]: fail to find "
<<
"PD_GetCustomKernelMap symbol in this lib."
;
return
;
}
auto
&
custom_kernel_map
=
func
();
phi
::
RegisterCustomKernels
(
custom_kernel_map
);
LOG
(
INFO
)
<<
"Successed in loading custom kernels in lib: "
<<
dso_lib_path
;
#else
VLOG
(
3
)
<<
"Unsupported: Custom kernel is only implemented on Linux."
;
#endif
return
;
}
}
// namespace phi
}
// namespace phi
#ifdef __cplusplus
#ifdef __cplusplus
...
...
paddle/phi/core/custom_kernel.h
浏览文件 @
b53cdc9e
...
@@ -46,4 +46,6 @@ class CustomKernelMap {
...
@@ -46,4 +46,6 @@ class CustomKernelMap {
*/
*/
void
RegisterCustomKernels
(
const
CustomKernelMap
&
custom_kernel_map
);
void
RegisterCustomKernels
(
const
CustomKernelMap
&
custom_kernel_map
);
// Load custom kernel lib and register
void
LoadCustomKernelLib
(
const
std
::
string
&
dso_lib_path
,
void
*
dso_handle
);
}
// namespace phi
}
// namespace phi
paddle/phi/kernels/math_kernel.cc
浏览文件 @
b53cdc9e
...
@@ -197,7 +197,8 @@ PD_REGISTER_KERNEL(subtract,
...
@@ -197,7 +197,8 @@ PD_REGISTER_KERNEL(subtract,
int64_t
,
int64_t
,
phi
::
dtype
::
float16
,
phi
::
dtype
::
float16
,
complex64
,
complex64
,
complex128
)
{}
complex128
,
phi
::
dtype
::
bfloat16
)
{}
PD_REGISTER_KERNEL
(
divide
,
PD_REGISTER_KERNEL
(
divide
,
GPU
,
GPU
,
ALL_LAYOUT
,
ALL_LAYOUT
,
...
...
paddle/phi/ops/compat/elementwise_sig.cc
浏览文件 @
b53cdc9e
...
@@ -100,6 +100,12 @@ KernelSignature ElementwiseSubGradOpArgumentMapping(
...
@@ -100,6 +100,12 @@ KernelSignature ElementwiseSubGradOpArgumentMapping(
return
KernelSignature
(
"unregistered"
,
{},
{},
{});
return
KernelSignature
(
"unregistered"
,
{},
{},
{});
}
}
KernelSignature
ElementwiseSubDoubleGradOpArgumentMapping
(
const
ArgumentMappingContext
&
ctx
)
{
return
KernelSignature
(
"subtract_double_grad"
,
{
"Y"
,
"DDX"
,
"DDY"
,
"DOut"
},
{
"axis"
},
{
"DDOut"
});
}
}
// namespace phi
}
// namespace phi
PD_REGISTER_BASE_KERNEL_NAME
(
elementwise_add
,
add
);
PD_REGISTER_BASE_KERNEL_NAME
(
elementwise_add
,
add
);
...
@@ -110,6 +116,7 @@ PD_REGISTER_BASE_KERNEL_NAME(elementwise_add_grad, add_grad);
...
@@ -110,6 +116,7 @@ PD_REGISTER_BASE_KERNEL_NAME(elementwise_add_grad, add_grad);
PD_REGISTER_BASE_KERNEL_NAME
(
elementwise_add_grad_grad
,
add_double_grad
);
PD_REGISTER_BASE_KERNEL_NAME
(
elementwise_add_grad_grad
,
add_double_grad
);
PD_REGISTER_BASE_KERNEL_NAME
(
elementwise_add_triple_grad
,
add_triple_grad
);
PD_REGISTER_BASE_KERNEL_NAME
(
elementwise_add_triple_grad
,
add_triple_grad
);
PD_REGISTER_BASE_KERNEL_NAME
(
elementwise_sub_grad
,
subtract_grad
);
PD_REGISTER_BASE_KERNEL_NAME
(
elementwise_sub_grad
,
subtract_grad
);
PD_REGISTER_BASE_KERNEL_NAME
(
elementwise_sub_grad_grad
,
subtract_double_grad
);
PD_REGISTER_ARG_MAPPING_FN
(
elementwise_add
,
PD_REGISTER_ARG_MAPPING_FN
(
elementwise_add
,
phi
::
ElementwiseAddOpArgumentMapping
);
phi
::
ElementwiseAddOpArgumentMapping
);
...
@@ -127,3 +134,5 @@ PD_REGISTER_ARG_MAPPING_FN(elementwise_add_triple_grad,
...
@@ -127,3 +134,5 @@ PD_REGISTER_ARG_MAPPING_FN(elementwise_add_triple_grad,
phi
::
ElementwiseAddTripleGradOpArgumentMapping
);
phi
::
ElementwiseAddTripleGradOpArgumentMapping
);
PD_REGISTER_ARG_MAPPING_FN
(
elementwise_sub_grad
,
PD_REGISTER_ARG_MAPPING_FN
(
elementwise_sub_grad
,
phi
::
ElementwiseSubGradOpArgumentMapping
);
phi
::
ElementwiseSubGradOpArgumentMapping
);
PD_REGISTER_ARG_MAPPING_FN
(
elementwise_sub_grad_grad
,
phi
::
ElementwiseSubDoubleGradOpArgumentMapping
);
python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_emb_eltwise_layernorm.py
浏览文件 @
b53cdc9e
...
@@ -244,28 +244,16 @@ class TrtConvertEmbEltwiseLayernormTest1(TrtLayerAutoScanTest):
...
@@ -244,28 +244,16 @@ class TrtConvertEmbEltwiseLayernormTest1(TrtLayerAutoScanTest):
self
.
trt_param
.
precision
=
paddle_infer
.
PrecisionType
.
Float32
self
.
trt_param
.
precision
=
paddle_infer
.
PrecisionType
.
Float32
yield
self
.
create_inference_config
(),
(
0
,
5
),
1e-5
yield
self
.
create_inference_config
(),
(
0
,
5
),
1e-5
self
.
trt_param
.
precision
=
paddle_infer
.
PrecisionType
.
Half
self
.
trt_param
.
precision
=
paddle_infer
.
PrecisionType
.
Half
yield
self
.
create_inference_config
(),
(
0
,
5
),
1e-5
yield
self
.
create_inference_config
(),
(
0
,
5
),
2e-2
# for dynamic_shape
# for dynamic_shape
generate_dynamic_shape
(
attrs
)
generate_dynamic_shape
(
attrs
)
self
.
trt_param
.
precision
=
paddle_infer
.
PrecisionType
.
Float32
self
.
trt_param
.
precision
=
paddle_infer
.
PrecisionType
.
Float32
yield
self
.
create_inference_config
(),
(
1
,
4
),
1e-5
yield
self
.
create_inference_config
(),
(
1
,
4
),
1e-5
self
.
trt_param
.
precision
=
paddle_infer
.
PrecisionType
.
Half
self
.
trt_param
.
precision
=
paddle_infer
.
PrecisionType
.
Half
yield
self
.
create_inference_config
(),
(
1
,
4
),
1e-5
yield
self
.
create_inference_config
(),
(
1
,
4
),
2e-2
def
add_skip_trt_case
(
self
):
def
teller1
(
program_config
,
predictor_config
):
if
self
.
trt_param
.
precision
==
paddle_infer
.
PrecisionType
.
Half
and
len
(
self
.
dynamic_shape
.
min_input_shape
)
!=
0
:
return
True
return
False
self
.
add_skip_case
(
teller1
,
SkipReasons
.
TRT_NOT_IMPLEMENTED
,
"The output has diff between gpu and trt when dynamic fp16 mode."
)
def
test
(
self
):
def
test
(
self
):
self
.
add_skip_trt_case
()
self
.
run_test
()
self
.
run_test
()
...
...
python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_gather.py
浏览文件 @
b53cdc9e
...
@@ -138,7 +138,7 @@ class TrtConvertGatherTest(TrtLayerAutoScanTest):
...
@@ -138,7 +138,7 @@ class TrtConvertGatherTest(TrtLayerAutoScanTest):
"index_data"
:
[
1
]
"index_data"
:
[
1
]
}
}
self
.
dynamic_shape
.
max_input_shape
=
{
self
.
dynamic_shape
.
max_input_shape
=
{
"input_data"
:
[
128
,
256
,
128
,
256
],
"input_data"
:
[
128
,
256
,
64
,
128
],
"index_data"
:
[
4
]
"index_data"
:
[
4
]
}
}
self
.
dynamic_shape
.
opt_input_shape
=
{
self
.
dynamic_shape
.
opt_input_shape
=
{
...
...
python/setup.py.in
浏览文件 @
b53cdc9e
...
@@ -579,8 +579,7 @@ headers = (
...
@@ -579,8 +579,7 @@ headers = (
list(find_files('*.h', '@PADDLE_SOURCE_DIR@/paddle/phi/core', recursive=True)) + # phi core headers
list(find_files('*.h', '@PADDLE_SOURCE_DIR@/paddle/phi/core', recursive=True)) + # phi core headers
list(find_files('*.h', '@PADDLE_SOURCE_DIR@/paddle/phi/backends', recursive=True)) + # phi backends headers
list(find_files('*.h', '@PADDLE_SOURCE_DIR@/paddle/phi/backends', recursive=True)) + # phi backends headers
# utila api headers
# utila api headers
list(find_files('*.h', '@PADDLE_SOURCE_DIR@/paddle/utils', recursive=True)) + # paddle utils headers
list(find_files('*.h', '@PADDLE_SOURCE_DIR@/paddle/utils', recursive=True))) # paddle utils headers
['@PADDLE_SOURCE_DIR@/paddle/fluid/platform/device/device_ext.h'])
if '${WITH_MKLDNN}' == 'ON':
if '${WITH_MKLDNN}' == 'ON':
headers += list(find_files('*', '${MKLDNN_INSTALL_DIR}/include')) # mkldnn
headers += list(find_files('*', '${MKLDNN_INSTALL_DIR}/include')) # mkldnn
...
@@ -625,8 +624,6 @@ class InstallHeaders(Command):
...
@@ -625,8 +624,6 @@ class InstallHeaders(Command):
elif 'third_party' not in header:
elif 'third_party' not in header:
# paddle headers
# paddle headers
install_dir = re.sub('@PADDLE_SOURCE_DIR@/', '', header)
install_dir = re.sub('@PADDLE_SOURCE_DIR@/', '', header)
if 'device_ext.h' in header:
install_dir = "paddle/"
else:
else:
# third_party
# third_party
install_dir = re.sub('${THIRD_PARTY_PATH}', 'third_party', header)
install_dir = re.sub('${THIRD_PARTY_PATH}', 'third_party', header)
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录