Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
26da689d
P
Paddle
项目概览
PaddlePaddle
/
Paddle
1 年多 前同步成功
通知
2302
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
26da689d
编写于
5月 18, 2023
作者:
H
huangjiyi
提交者:
GitHub
5月 18, 2023
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
move fusion_group kernel to phi (#53781)
上级
0bed2203
变更
15
显示空白变更内容
内联
并排
Showing
15 changed file
with
216 addition
and
205 deletion
+216
-205
paddle/fluid/framework/ir/fusion_group/CMakeLists.txt
paddle/fluid/framework/ir/fusion_group/CMakeLists.txt
+2
-2
paddle/fluid/framework/ir/fusion_group/code_generator_tester.cc
.../fluid/framework/ir/fusion_group/code_generator_tester.cc
+2
-2
paddle/fluid/framework/ir/fusion_group/fusion_group_pass.cc
paddle/fluid/framework/ir/fusion_group/fusion_group_pass.cc
+8
-10
paddle/fluid/operators/fused/CMakeLists.txt
paddle/fluid/operators/fused/CMakeLists.txt
+1
-1
paddle/fluid/operators/fused/fusion_group_op.cc
paddle/fluid/operators/fused/fusion_group_op.cc
+1
-1
paddle/fluid/operators/fused/fusion_group_op.h
paddle/fluid/operators/fused/fusion_group_op.h
+0
-99
paddle/fluid/platform/CMakeLists.txt
paddle/fluid/platform/CMakeLists.txt
+1
-5
paddle/fluid/platform/device_code_test.cc
paddle/fluid/platform/device_code_test.cc
+24
-20
paddle/phi/backends/CMakeLists.txt
paddle/phi/backends/CMakeLists.txt
+4
-0
paddle/phi/backends/device_code.cc
paddle/phi/backends/device_code.cc
+29
-29
paddle/phi/backends/device_code.h
paddle/phi/backends/device_code.h
+16
-18
paddle/phi/kernels/CMakeLists.txt
paddle/phi/kernels/CMakeLists.txt
+5
-0
paddle/phi/kernels/fusion/gpu/fusion_group_kernel.cu
paddle/phi/kernels/fusion/gpu/fusion_group_kernel.cu
+104
-0
paddle/phi/ops/compat/fusion_group_sig.cc
paddle/phi/ops/compat/fusion_group_sig.cc
+14
-12
test/cpp/fluid/fused/fusion_group_op_test.cc
test/cpp/fluid/fused/fusion_group_op_test.cc
+5
-6
未找到文件。
paddle/fluid/framework/ir/fusion_group/CMakeLists.txt
浏览文件 @
26da689d
...
@@ -6,13 +6,13 @@ if(WITH_GPU OR WITH_ROCM)
...
@@ -6,13 +6,13 @@ if(WITH_GPU OR WITH_ROCM)
cc_test
(
cc_test
(
test_code_generator
test_code_generator
SRCS code_generator_tester.cc
SRCS code_generator_tester.cc
DEPS code_generator
device_code
lod_tensor graph_viz_pass
)
DEPS code_generator
phi_backends
lod_tensor graph_viz_pass
)
endif
()
endif
()
cc_library
(
cc_library
(
fusion_group_pass
fusion_group_pass
SRCS fusion_group_pass.cc elementwise_group_detector.cc
SRCS fusion_group_pass.cc elementwise_group_detector.cc
DEPS subgraph_detector fuse_pass_base code_generator
device_code
)
DEPS subgraph_detector fuse_pass_base code_generator
phi_backends
)
cc_test
(
cc_test
(
test_fusion_group_pass
test_fusion_group_pass
SRCS fusion_group_pass_tester.cc
SRCS fusion_group_pass_tester.cc
...
...
paddle/fluid/framework/ir/fusion_group/code_generator_tester.cc
浏览文件 @
26da689d
...
@@ -20,8 +20,8 @@ limitations under the License. */
...
@@ -20,8 +20,8 @@ limitations under the License. */
#include "paddle/fluid/framework/ir/fusion_group/code_generator.h"
#include "paddle/fluid/framework/ir/fusion_group/code_generator.h"
#include "paddle/fluid/framework/ir/fusion_group/operation.h"
#include "paddle/fluid/framework/ir/fusion_group/operation.h"
#include "paddle/fluid/framework/ir/pass_tester_helper.h"
#include "paddle/fluid/framework/ir/pass_tester_helper.h"
#include "paddle/fluid/platform/device_code.h"
#include "paddle/fluid/platform/float16.h"
#include "paddle/fluid/platform/float16.h"
#include "paddle/phi/backends/device_code.h"
namespace
phi
{
namespace
phi
{
class
DenseTensor
;
class
DenseTensor
;
...
@@ -182,7 +182,7 @@ void TestMainImpl(std::string func_name,
...
@@ -182,7 +182,7 @@ void TestMainImpl(std::string func_name,
std
::
type_index
(
typeid
(
paddle
::
platform
::
float16
));
std
::
type_index
(
typeid
(
paddle
::
platform
::
float16
));
paddle
::
platform
::
CUDAPlace
place
=
paddle
::
platform
::
CUDAPlace
(
0
);
paddle
::
platform
::
CUDAPlace
place
=
paddle
::
platform
::
CUDAPlace
(
0
);
p
addle
::
platform
::
CUDA
DeviceCode
device_code
(
place
,
func_name
,
code_str
);
p
hi
::
GPU
DeviceCode
device_code
(
place
,
func_name
,
code_str
);
#ifdef PADDLE_WITH_HIP
#ifdef PADDLE_WITH_HIP
device_code
.
Compile
(
true
);
device_code
.
Compile
(
true
);
#else
#else
...
...
paddle/fluid/framework/ir/fusion_group/fusion_group_pass.cc
浏览文件 @
26da689d
...
@@ -19,12 +19,10 @@ limitations under the License. */
...
@@ -19,12 +19,10 @@ limitations under the License. */
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
#include "paddle/fluid/framework/ir/pass_tester_helper.h"
#include "paddle/fluid/framework/ir/pass_tester_helper.h"
#include "paddle/fluid/framework/op_proto_maker.h"
#include "paddle/fluid/framework/op_proto_maker.h"
#include "paddle/fluid/platform/device_code.h"
#include "paddle/phi/backends/device_code.h"
namespace
paddle
{
namespace
phi
{
namespace
platform
{
class
DeviceCodePool
;
class
DeviceCodePool
;
}
// namespace platform
}
// namespace phi
}
// namespace paddle
namespace
paddle
{
namespace
paddle
{
namespace
framework
{
namespace
framework
{
...
@@ -36,7 +34,7 @@ void FusionGroupPass::ApplyImpl(ir::Graph* graph) const {
...
@@ -36,7 +34,7 @@ void FusionGroupPass::ApplyImpl(ir::Graph* graph) const {
FusePassBase
::
Init
(
"fusion_group_pass"
,
graph
);
FusePassBase
::
Init
(
"fusion_group_pass"
,
graph
);
if
(
Get
<
bool
>
(
"use_gpu"
))
{
if
(
Get
<
bool
>
(
"use_gpu"
))
{
// TODO(liuyiqun): open this check.
// TODO(liuyiqun): open this check.
// if (!p
latform::CUDA
DeviceCode::IsAvailable()) {
// if (!p
hi::GPU
DeviceCode::IsAvailable()) {
// LOG(WARNING)
// LOG(WARNING)
// << "Disable fusion_group because CUDA Driver or NVRTC is not
// << "Disable fusion_group because CUDA Driver or NVRTC is not
// avaiable.";
// avaiable.";
...
@@ -54,7 +52,7 @@ void FusionGroupPass::ApplyImpl(ir::Graph* graph) const {
...
@@ -54,7 +52,7 @@ void FusionGroupPass::ApplyImpl(ir::Graph* graph) const {
int
FusionGroupPass
::
DetectFusionGroup
(
Graph
*
graph
,
int
type
)
const
{
int
FusionGroupPass
::
DetectFusionGroup
(
Graph
*
graph
,
int
type
)
const
{
// TODO(liuyiqun): supported different places
// TODO(liuyiqun): supported different places
platform
::
CUDAPlace
place
=
platform
::
CUDAPlace
(
0
);
platform
::
CUDAPlace
place
=
platform
::
CUDAPlace
(
0
);
int
index
=
p
latform
::
DeviceCodePool
::
Init
({
place
}).
size
(
place
);
int
index
=
p
hi
::
DeviceCodePool
::
Init
({
place
}).
size
(
place
);
std
::
vector
<
std
::
vector
<
Node
*>>
subgraphs
=
std
::
vector
<
std
::
vector
<
Node
*>>
subgraphs
=
fusion_group
::
ElementwiseGroupDetector
()(
graph
);
fusion_group
::
ElementwiseGroupDetector
()(
graph
);
...
@@ -88,11 +86,11 @@ bool FusionGroupPass::GenerateCode(fusion_group::SubGraph* subgraph) const {
...
@@ -88,11 +86,11 @@ bool FusionGroupPass::GenerateCode(fusion_group::SubGraph* subgraph) const {
// TODO(liuyiqun): supported different places
// TODO(liuyiqun): supported different places
platform
::
CUDAPlace
place
=
platform
::
CUDAPlace
(
0
);
platform
::
CUDAPlace
place
=
platform
::
CUDAPlace
(
0
);
std
::
unique_ptr
<
p
latform
::
CUDA
DeviceCode
>
device_code
(
std
::
unique_ptr
<
p
hi
::
GPU
DeviceCode
>
device_code
(
new
p
latform
::
CUDA
DeviceCode
(
place
,
subgraph
->
GetFuncName
(),
code_str
));
new
p
hi
::
GPU
DeviceCode
(
place
,
subgraph
->
GetFuncName
(),
code_str
));
bool
is_compiled
=
device_code
->
Compile
();
bool
is_compiled
=
device_code
->
Compile
();
if
(
is_compiled
)
{
if
(
is_compiled
)
{
p
latform
::
DeviceCodePool
&
pool
=
platform
::
DeviceCodePool
::
Init
({
place
});
p
hi
::
DeviceCodePool
&
pool
=
phi
::
DeviceCodePool
::
Init
({
place
});
pool
.
Set
(
std
::
move
(
device_code
));
pool
.
Set
(
std
::
move
(
device_code
));
}
}
return
is_compiled
;
return
is_compiled
;
...
...
paddle/fluid/operators/fused/CMakeLists.txt
浏览文件 @
26da689d
...
@@ -73,7 +73,7 @@ if(WITH_GPU OR WITH_ROCM)
...
@@ -73,7 +73,7 @@ if(WITH_GPU OR WITH_ROCM)
op_library
(
fused_gate_attention_op
)
op_library
(
fused_gate_attention_op
)
# fusion_group
# fusion_group
if
(
NOT APPLE AND NOT WIN32
)
if
(
NOT APPLE AND NOT WIN32
)
op_library
(
fusion_group_op
DEPS device_code
)
op_library
(
fusion_group_op
)
endif
()
endif
()
# fused_bn_add_activation
# fused_bn_add_activation
# HIP not support bn act fuse in MIOPEN
# HIP not support bn act fuse in MIOPEN
...
...
paddle/fluid/operators/fused/fusion_group_op.cc
浏览文件 @
26da689d
...
@@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
...
@@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
See the License for the specific language governing permissions and
limitations under the License. */
limitations under the License. */
#include "paddle/fluid/
operators/fused/fusion_group_op
.h"
#include "paddle/fluid/
framework/op_registry
.h"
namespace
paddle
{
namespace
paddle
{
namespace
operators
{
namespace
operators
{
...
...
paddle/fluid/operators/fused/fusion_group_op.h
已删除
100644 → 0
浏览文件 @
0bed2203
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <string>
#include <vector>
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/platform/device_code.h"
namespace
paddle
{
namespace
operators
{
template
<
typename
DeviceContext
>
static
void
MutableMultiTypeData
(
std
::
vector
<
phi
::
DenseTensor
*>*
var
,
const
std
::
vector
<
int
>&
data_type
,
const
DeviceContext
&
dev_ctx
,
const
platform
::
Place
&
place
)
{
for
(
size_t
i
=
0
;
i
<
var
->
size
();
i
++
)
{
if
(
data_type
[
i
]
==
framework
::
proto
::
VarType
::
FP32
)
{
dev_ctx
.
template
Alloc
<
float
>((
*
var
)[
i
],
(
*
var
)[
i
]
->
numel
()
*
sizeof
(
float
));
}
else
if
(
data_type
[
i
]
==
framework
::
proto
::
VarType
::
FP16
)
{
dev_ctx
.
template
Alloc
<
paddle
::
platform
::
float16
>(
(
*
var
)[
i
],
(
*
var
)[
i
]
->
numel
()
*
sizeof
(
paddle
::
platform
::
float16
));
}
else
if
(
data_type
[
i
]
==
framework
::
proto
::
VarType
::
FP64
)
{
dev_ctx
.
template
Alloc
<
double
>((
*
var
)[
i
],
(
*
var
)[
i
]
->
numel
()
*
sizeof
(
double
));
}
}
}
template
<
typename
T
,
typename
DeviceContext
>
class
FusionGroupKernel
:
public
framework
::
OpKernel
<
T
>
{
public:
void
Compute
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
auto
ins
=
ctx
.
MultiInput
<
phi
::
DenseTensor
>
(
"Inputs"
);
auto
outs
=
ctx
.
MultiOutput
<
phi
::
DenseTensor
>
(
"Outs"
);
int
type
=
ctx
.
Attr
<
int
>
(
"type"
);
const
auto
&
outs_dtype
=
ctx
.
Attr
<
std
::
vector
<
int
>>
(
"outs_dtype"
);
const
auto
&
inputs_dtype
=
ctx
.
Attr
<
std
::
vector
<
int
>>
(
"inputs_dtype"
);
size_t
num_ins
=
ins
.
size
();
size_t
num_outs
=
outs
.
size
();
auto
place
=
ctx
.
GetPlace
();
auto
&
dev_ctx
=
ctx
.
template
device_context
<
DeviceContext
>();
MutableMultiTypeData
(
&
outs
,
outs_dtype
,
dev_ctx
,
place
);
std
::
string
func_name
=
ctx
.
Attr
<
std
::
string
>
(
"func_name"
);
platform
::
DeviceCode
*
dev_code
=
platform
::
DeviceCodePool
::
Instance
().
Get
(
place
,
func_name
);
VLOG
(
3
)
<<
"func_name: "
<<
func_name
;
if
(
type
==
0
)
{
size_t
n
=
ins
[
0
]
->
numel
();
std
::
vector
<
void
*>
args
;
args
.
push_back
(
&
n
);
std
::
vector
<
const
void
*>
ptrs
(
num_ins
+
num_outs
);
for
(
size_t
i
=
0
;
i
<
num_ins
;
++
i
)
{
if
(
inputs_dtype
[
i
]
==
framework
::
proto
::
VarType
::
FP16
)
{
ptrs
[
i
]
=
ins
[
i
]
->
data
<
paddle
::
platform
::
float16
>
();
}
else
if
(
inputs_dtype
[
i
]
==
framework
::
proto
::
VarType
::
FP32
)
{
ptrs
[
i
]
=
ins
[
i
]
->
data
<
float
>
();
}
else
if
(
inputs_dtype
[
i
]
==
framework
::
proto
::
VarType
::
FP64
)
{
ptrs
[
i
]
=
ins
[
i
]
->
data
<
double
>
();
}
args
.
push_back
(
&
ptrs
[
i
]);
}
for
(
size_t
j
=
0
;
j
<
num_outs
;
++
j
)
{
if
(
outs_dtype
[
j
]
==
framework
::
proto
::
VarType
::
FP16
)
{
ptrs
[
num_ins
+
j
]
=
outs
[
j
]
->
data
<
paddle
::
platform
::
float16
>
();
}
else
if
(
outs_dtype
[
j
]
==
framework
::
proto
::
VarType
::
FP32
)
{
ptrs
[
num_ins
+
j
]
=
outs
[
j
]
->
data
<
float
>
();
}
else
if
(
outs_dtype
[
j
]
==
framework
::
proto
::
VarType
::
FP64
)
{
ptrs
[
num_ins
+
j
]
=
outs
[
j
]
->
data
<
double
>
();
}
args
.
push_back
(
&
ptrs
[
num_ins
+
j
]);
}
dev_code
->
Launch
(
n
,
&
args
);
}
}
};
}
// namespace operators
}
// namespace paddle
paddle/fluid/platform/CMakeLists.txt
浏览文件 @
26da689d
...
@@ -356,15 +356,11 @@ if(WITH_ROCM)
...
@@ -356,15 +356,11 @@ if(WITH_ROCM)
endif
()
endif
()
if
(
NOT APPLE AND NOT WIN32
)
if
(
NOT APPLE AND NOT WIN32
)
cc_library
(
device_code
SRCS device_code.cc
DEPS device_context
)
if
(
WITH_GPU OR WITH_ROCM
)
if
(
WITH_GPU OR WITH_ROCM
)
cc_test
(
cc_test
(
device_code_test
device_code_test
SRCS device_code_test.cc
SRCS device_code_test.cc
DEPS
device_code
lod_tensor
)
DEPS
phi_backends
lod_tensor
)
endif
()
endif
()
endif
()
endif
()
...
...
paddle/fluid/platform/device_code_test.cc
浏览文件 @
26da689d
...
@@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
...
@@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
See the License for the specific language governing permissions and
limitations under the License. */
limitations under the License. */
#include "paddle/
fluid/platform
/device_code.h"
#include "paddle/
phi/backends
/device_code.h"
#include <utility>
#include <utility>
...
@@ -47,14 +47,13 @@ void saxpy_kernel(float a, float *x, float* y, float* z, size_t n) {
...
@@ -47,14 +47,13 @@ void saxpy_kernel(float a, float *x, float* y, float* z, size_t n) {
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
TEST
(
DeviceCode
,
cuda
)
{
TEST
(
DeviceCode
,
cuda
)
{
if
(
!
paddle
::
platform
::
dynload
::
HasNVRTC
()
||
if
(
!
phi
::
dynload
::
HasNVRTC
()
||
!
phi
::
dynload
::
HasCUDADriver
())
{
!
paddle
::
platform
::
dynload
::
HasCUDADriver
())
{
return
;
return
;
}
}
paddle
::
framework
::
InitDevices
({
0
});
paddle
::
framework
::
InitDevices
({
0
});
p
addle
::
platform
::
CUDAPlace
place
=
paddle
::
platform
::
CUDA
Place
(
0
);
p
hi
::
GPUPlace
place
=
phi
::
GPU
Place
(
0
);
p
addle
::
platform
::
CUDA
DeviceCode
code
(
place
,
"saxpy_kernel"
,
saxpy_code
);
p
hi
::
GPU
DeviceCode
code
(
place
,
"saxpy_kernel"
,
saxpy_code
);
phi
::
DenseTensor
cpu_x
;
phi
::
DenseTensor
cpu_x
;
phi
::
DenseTensor
cpu_y
;
phi
::
DenseTensor
cpu_y
;
...
@@ -63,8 +62,12 @@ TEST(DeviceCode, cuda) {
...
@@ -63,8 +62,12 @@ TEST(DeviceCode, cuda) {
float
scale
=
2
;
float
scale
=
2
;
auto
dims
=
auto
dims
=
phi
::
make_ddim
({
static_cast
<
int64_t
>
(
256
),
static_cast
<
int64_t
>
(
1024
)});
phi
::
make_ddim
({
static_cast
<
int64_t
>
(
256
),
static_cast
<
int64_t
>
(
1024
)});
cpu_x
.
mutable_data
<
float
>
(
dims
,
paddle
::
platform
::
CPUPlace
());
phi
::
DeviceContextPool
&
pool
=
phi
::
DeviceContextPool
::
Instance
();
cpu_y
.
mutable_data
<
float
>
(
dims
,
paddle
::
platform
::
CPUPlace
());
auto
*
cpu_ctx
=
reinterpret_cast
<
phi
::
CPUContext
*>
(
pool
.
Get
(
phi
::
CPUPlace
()));
cpu_x
.
Resize
(
dims
);
cpu_ctx
->
template
Alloc
<
float
>(
&
cpu_x
);
cpu_y
.
Resize
(
dims
);
cpu_ctx
->
template
Alloc
<
float
>(
&
cpu_y
);
size_t
n
=
cpu_x
.
numel
();
size_t
n
=
cpu_x
.
numel
();
for
(
size_t
i
=
0
;
i
<
n
;
++
i
)
{
for
(
size_t
i
=
0
;
i
<
n
;
++
i
)
{
...
@@ -78,9 +81,13 @@ TEST(DeviceCode, cuda) {
...
@@ -78,9 +81,13 @@ TEST(DeviceCode, cuda) {
phi
::
DenseTensor
y
;
phi
::
DenseTensor
y
;
phi
::
DenseTensor
z
;
phi
::
DenseTensor
z
;
float
*
x_data
=
x
.
mutable_data
<
float
>
(
dims
,
place
);
auto
*
dev_ctx
=
reinterpret_cast
<
phi
::
GPUContext
*>
(
pool
.
Get
(
place
));
float
*
y_data
=
y
.
mutable_data
<
float
>
(
dims
,
place
);
x
.
Resize
(
dims
);
float
*
z_data
=
z
.
mutable_data
<
float
>
(
dims
,
place
);
float
*
x_data
=
dev_ctx
->
template
Alloc
<
float
>(
&
x
);
y
.
Resize
(
dims
);
float
*
y_data
=
dev_ctx
->
template
Alloc
<
float
>(
&
y
);
z
.
Resize
(
dims
);
float
*
z_data
=
dev_ctx
->
template
Alloc
<
float
>(
&
z
);
paddle
::
framework
::
TensorCopySync
(
cpu_x
,
place
,
&
x
);
paddle
::
framework
::
TensorCopySync
(
cpu_x
,
place
,
&
x
);
paddle
::
framework
::
TensorCopySync
(
cpu_y
,
place
,
&
y
);
paddle
::
framework
::
TensorCopySync
(
cpu_y
,
place
,
&
y
);
...
@@ -92,36 +99,33 @@ TEST(DeviceCode, cuda) {
...
@@ -92,36 +99,33 @@ TEST(DeviceCode, cuda) {
code
.
SetWorkloadPerThread
(
1
);
code
.
SetWorkloadPerThread
(
1
);
code
.
Launch
(
n
,
&
args
);
code
.
Launch
(
n
,
&
args
);
auto
*
dev_ctx
=
paddle
::
platform
::
DeviceContextPool
::
Instance
().
Get
(
place
);
dev_ctx
->
Wait
();
dev_ctx
->
Wait
();
paddle
::
framework
::
TensorCopySync
(
z
,
p
addle
::
platform
::
CPUPlace
(),
&
cpu_z
);
paddle
::
framework
::
TensorCopySync
(
z
,
p
hi
::
CPUPlace
(),
&
cpu_z
);
for
(
size_t
i
=
0
;
i
<
n
;
i
++
)
{
for
(
size_t
i
=
0
;
i
<
n
;
i
++
)
{
EXPECT_EQ
(
cpu_z
.
data
<
float
>
()[
i
],
static_cast
<
float
>
(
i
)
*
scale
+
0.5
);
EXPECT_EQ
(
cpu_z
.
data
<
float
>
()[
i
],
static_cast
<
float
>
(
i
)
*
scale
+
0.5
);
}
}
}
}
TEST
(
DeviceCodePool
,
cuda
)
{
TEST
(
DeviceCodePool
,
cuda
)
{
if
(
!
paddle
::
platform
::
dynload
::
HasNVRTC
()
||
if
(
!
phi
::
dynload
::
HasNVRTC
()
||
!
phi
::
dynload
::
HasCUDADriver
())
{
!
paddle
::
platform
::
dynload
::
HasCUDADriver
())
{
return
;
return
;
}
}
paddle
::
framework
::
InitDevices
({
0
});
paddle
::
framework
::
InitDevices
({
0
});
paddle
::
platform
::
CUDAPlace
place
=
paddle
::
platform
::
CUDAPlace
(
0
);
phi
::
GPUPlace
place
=
phi
::
GPUPlace
(
0
);
paddle
::
platform
::
DeviceCodePool
&
pool
=
phi
::
DeviceCodePool
&
pool
=
phi
::
DeviceCodePool
::
Init
({
place
});
paddle
::
platform
::
DeviceCodePool
::
Init
({
place
});
size_t
num_device_codes_before
=
pool
.
size
(
place
);
size_t
num_device_codes_before
=
pool
.
size
(
place
);
EXPECT_EQ
(
num_device_codes_before
,
0UL
);
EXPECT_EQ
(
num_device_codes_before
,
0UL
);
std
::
unique_ptr
<
p
addle
::
platform
::
DeviceCode
>
code
(
std
::
unique_ptr
<
p
hi
::
DeviceCode
>
code
(
new
p
addle
::
platform
::
CUDA
DeviceCode
(
place
,
"saxpy_kernel"
,
saxpy_code
));
new
p
hi
::
GPU
DeviceCode
(
place
,
"saxpy_kernel"
,
saxpy_code
));
LOG
(
INFO
)
<<
"origin ptr: "
<<
code
.
get
();
LOG
(
INFO
)
<<
"origin ptr: "
<<
code
.
get
();
pool
.
Set
(
std
::
move
(
code
));
pool
.
Set
(
std
::
move
(
code
));
size_t
num_device_codes_after
=
pool
.
size
(
place
);
size_t
num_device_codes_after
=
pool
.
size
(
place
);
EXPECT_EQ
(
num_device_codes_after
,
1UL
);
EXPECT_EQ
(
num_device_codes_after
,
1UL
);
p
addle
::
platform
::
DeviceCode
*
code_get
=
pool
.
Get
(
place
,
"saxpy_kernel"
);
p
hi
::
DeviceCode
*
code_get
=
pool
.
Get
(
place
,
"saxpy_kernel"
);
LOG
(
INFO
)
<<
"get ptr: "
<<
code_get
;
LOG
(
INFO
)
<<
"get ptr: "
<<
code_get
;
}
}
#endif
#endif
paddle/phi/backends/CMakeLists.txt
浏览文件 @
26da689d
...
@@ -14,6 +14,10 @@ if(WITH_XBYAK)
...
@@ -14,6 +14,10 @@ if(WITH_XBYAK)
list
(
APPEND BACKENDS_DEPS xbyak
)
list
(
APPEND BACKENDS_DEPS xbyak
)
endif
()
endif
()
if
(
NOT APPLE AND NOT WIN32
)
list
(
APPEND BACKENDS_SRCS device_code.cc
)
endif
()
if
(
WITH_GPU OR WITH_ROCM
)
if
(
WITH_GPU OR WITH_ROCM
)
list
(
APPEND BACKENDS_SRCS gpu/gpu_context.cc gpu/gpu_info.cc
list
(
APPEND BACKENDS_SRCS gpu/gpu_context.cc gpu/gpu_info.cc
gpu/gpu_resources.cc
)
gpu/gpu_resources.cc
)
...
...
paddle/
fluid/platform
/device_code.cc
→
paddle/
phi/backends
/device_code.cc
浏览文件 @
26da689d
...
@@ -12,20 +12,22 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
...
@@ -12,20 +12,22 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
See the License for the specific language governing permissions and
limitations under the License. */
limitations under the License. */
#include "paddle/
fluid/platform
/device_code.h"
#include "paddle/
phi/backends
/device_code.h"
#include <glog/logging.h>
#include <sys/stat.h>
#include <sys/stat.h>
#include <algorithm>
#include <algorithm>
#include <set>
#include <set>
#include <utility>
#include <utility>
#include "paddle/fluid/platform/enforce.h"
#include "paddle/phi/backends/context_pool.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/core/flags.h"
PHI_DECLARE_string
(
cuda_dir
);
PHI_DECLARE_string
(
cuda_dir
);
namespace
paddle
{
namespace
phi
{
namespace
platform
{
DeviceCodePool
*
DeviceCodePool
::
pool
=
nullptr
;
DeviceCodePool
*
DeviceCodePool
::
pool
=
nullptr
;
...
@@ -35,7 +37,7 @@ void DeviceCodePool::Set(std::unique_ptr<DeviceCode>&& code) {
...
@@ -35,7 +37,7 @@ void DeviceCodePool::Set(std::unique_ptr<DeviceCode>&& code) {
auto
iter
=
device_codes_
.
find
(
place
);
auto
iter
=
device_codes_
.
find
(
place
);
if
(
iter
==
device_codes_
.
end
())
{
if
(
iter
==
device_codes_
.
end
())
{
PADDLE_THROW
(
p
latform
::
errors
::
NotFound
(
PADDLE_THROW
(
p
hi
::
errors
::
NotFound
(
"Place %s is not supported for runtime compiling."
,
place
));
"Place %s is not supported for runtime compiling."
,
place
));
}
}
...
@@ -43,18 +45,18 @@ void DeviceCodePool::Set(std::unique_ptr<DeviceCode>&& code) {
...
@@ -43,18 +45,18 @@ void DeviceCodePool::Set(std::unique_ptr<DeviceCode>&& code) {
codes_map
.
emplace
(
name
,
std
::
move
(
code
));
codes_map
.
emplace
(
name
,
std
::
move
(
code
));
}
}
platform
::
DeviceCode
*
DeviceCodePool
::
Get
(
const
platform
::
Place
&
place
,
DeviceCode
*
DeviceCodePool
::
Get
(
const
phi
::
Place
&
place
,
const
std
::
string
&
name
)
{
const
std
::
string
&
name
)
{
auto
iter
=
device_codes_
.
find
(
place
);
auto
iter
=
device_codes_
.
find
(
place
);
if
(
iter
==
device_codes_
.
end
())
{
if
(
iter
==
device_codes_
.
end
())
{
PADDLE_THROW
(
p
latform
::
errors
::
NotFound
(
PADDLE_THROW
(
p
hi
::
errors
::
NotFound
(
"Place %s is not supported for runtime compiling."
,
place
));
"Place %s is not supported for runtime compiling."
,
place
));
}
}
auto
&
codes_map
=
iter
->
second
;
auto
&
codes_map
=
iter
->
second
;
auto
code_iter
=
codes_map
.
find
(
name
);
auto
code_iter
=
codes_map
.
find
(
name
);
if
(
code_iter
==
codes_map
.
end
())
{
if
(
code_iter
==
codes_map
.
end
())
{
PADDLE_THROW
(
p
latform
::
errors
::
NotFound
(
PADDLE_THROW
(
p
hi
::
errors
::
NotFound
(
"Device code named %s for place %s does not exist."
,
"Device code named %s for place %s does not exist."
,
name
.
c_str
(),
name
.
c_str
(),
place
));
place
));
...
@@ -63,7 +65,7 @@ platform::DeviceCode* DeviceCodePool::Get(const platform::Place& place,
...
@@ -63,7 +65,7 @@ platform::DeviceCode* DeviceCodePool::Get(const platform::Place& place,
return
code_iter
->
second
.
get
();
return
code_iter
->
second
.
get
();
}
}
DeviceCodePool
::
DeviceCodePool
(
const
std
::
vector
<
p
latform
::
Place
>&
places
)
{
DeviceCodePool
::
DeviceCodePool
(
const
std
::
vector
<
p
hi
::
Place
>&
places
)
{
PADDLE_ENFORCE_GT
(
places
.
size
(),
PADDLE_ENFORCE_GT
(
places
.
size
(),
0
,
0
,
errors
::
InvalidArgument
(
errors
::
InvalidArgument
(
...
@@ -75,11 +77,11 @@ DeviceCodePool::DeviceCodePool(const std::vector<platform::Place>& places) {
...
@@ -75,11 +77,11 @@ DeviceCodePool::DeviceCodePool(const std::vector<platform::Place>& places) {
set
.
insert
(
p
);
set
.
insert
(
p
);
}
}
for
(
auto
&
p
:
set
)
{
for
(
auto
&
p
:
set
)
{
if
(
is_gpu_place
(
p
)
)
{
if
(
p
.
GetType
()
==
phi
::
AllocationType
::
GPU
)
{
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
device_codes_
.
emplace
(
p
,
DeviceCodeMap
());
device_codes_
.
emplace
(
p
,
DeviceCodeMap
());
#else
#else
PADDLE_THROW
(
p
latform
::
errors
::
PreconditionNotMet
(
PADDLE_THROW
(
p
hi
::
errors
::
PreconditionNotMet
(
"CUDAPlace or HIPPlace is not supported, please re-compile with "
"CUDAPlace or HIPPlace is not supported, please re-compile with "
"WITH_GPU=ON or WITH_ROCM=ON."
));
"WITH_GPU=ON or WITH_ROCM=ON."
));
#endif
#endif
...
@@ -87,7 +89,7 @@ DeviceCodePool::DeviceCodePool(const std::vector<platform::Place>& places) {
...
@@ -87,7 +89,7 @@ DeviceCodePool::DeviceCodePool(const std::vector<platform::Place>& places) {
}
}
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
CUDA
DeviceCode
::
CheckAvailableStatus
();
GPU
DeviceCode
::
CheckAvailableStatus
();
#endif
#endif
}
}
...
@@ -114,8 +116,8 @@ static bool CheckCUDADriverResult(CUresult result,
...
@@ -114,8 +116,8 @@ static bool CheckCUDADriverResult(CUresult result,
return
true
;
return
true
;
}
}
bool
CUDA
DeviceCode
::
available_
=
false
;
bool
GPU
DeviceCode
::
available_
=
false
;
void
CUDA
DeviceCode
::
CheckAvailableStatus
()
{
void
GPU
DeviceCode
::
CheckAvailableStatus
()
{
available_
=
false
;
available_
=
false
;
if
(
!
dynload
::
HasNVRTC
()
||
!
dynload
::
HasCUDADriver
())
{
if
(
!
dynload
::
HasNVRTC
()
||
!
dynload
::
HasCUDADriver
())
{
LOG_FIRST_N
(
WARNING
,
1
)
LOG_FIRST_N
(
WARNING
,
1
)
...
@@ -215,12 +217,12 @@ static std::string FindCUDAIncludePath() {
...
@@ -215,12 +217,12 @@ static std::string FindCUDAIncludePath() {
return
""
;
return
""
;
}
}
CUDADeviceCode
::
CUDA
DeviceCode
(
const
Place
&
place
,
GPUDeviceCode
::
GPU
DeviceCode
(
const
Place
&
place
,
const
std
::
string
&
name
,
const
std
::
string
&
name
,
const
std
::
string
&
kernel
)
{
const
std
::
string
&
kernel
)
{
if
(
!
is_gpu_place
(
place
)
)
{
if
(
place
.
GetType
()
!=
phi
::
AllocationType
::
GPU
)
{
PADDLE_THROW
(
p
latform
::
errors
::
PermissionDenied
(
PADDLE_THROW
(
p
hi
::
errors
::
PermissionDenied
(
"
CUDA
DeviceCode can only launch on GPU place."
));
"
GPU
DeviceCode can only launch on GPU place."
));
}
}
place_
=
place
;
place_
=
place
;
...
@@ -232,7 +234,7 @@ CUDADeviceCode::CUDADeviceCode(const Place& place,
...
@@ -232,7 +234,7 @@ CUDADeviceCode::CUDADeviceCode(const Place& place,
#endif
#endif
}
}
bool
CUDA
DeviceCode
::
Compile
(
bool
include_path
)
{
bool
GPU
DeviceCode
::
Compile
(
bool
include_path
)
{
is_compiled_
=
false
;
is_compiled_
=
false
;
if
(
!
dynload
::
HasNVRTC
()
||
!
dynload
::
HasCUDADriver
())
{
if
(
!
dynload
::
HasNVRTC
()
||
!
dynload
::
HasCUDADriver
())
{
LOG_FIRST_N
(
WARNING
,
1
)
LOG_FIRST_N
(
WARNING
,
1
)
...
@@ -403,7 +405,7 @@ bool CUDADeviceCode::Compile(bool include_path) {
...
@@ -403,7 +405,7 @@ bool CUDADeviceCode::Compile(bool include_path) {
return
true
;
return
true
;
}
}
void
CUDA
DeviceCode
::
Launch
(
const
size_t
n
,
std
::
vector
<
void
*>*
args
)
const
{
void
GPU
DeviceCode
::
Launch
(
const
size_t
n
,
std
::
vector
<
void
*>*
args
)
const
{
PADDLE_ENFORCE_EQ
(
PADDLE_ENFORCE_EQ
(
is_compiled_
,
is_compiled_
,
true
,
true
,
...
@@ -454,7 +456,7 @@ void CUDADeviceCode::Launch(const size_t n, std::vector<void*>* args) const {
...
@@ -454,7 +456,7 @@ void CUDADeviceCode::Launch(const size_t n, std::vector<void*>* args) const {
}
}
#ifdef PADDLE_WITH_HIP
#ifdef PADDLE_WITH_HIP
bool
CUDA
DeviceCode
::
CheckNVRTCResult
(
hiprtcResult
result
,
bool
GPU
DeviceCode
::
CheckNVRTCResult
(
hiprtcResult
result
,
std
::
string
function
)
{
std
::
string
function
)
{
if
(
result
!=
HIPRTC_SUCCESS
)
{
if
(
result
!=
HIPRTC_SUCCESS
)
{
LOG_FIRST_N
(
WARNING
,
1
)
LOG_FIRST_N
(
WARNING
,
1
)
...
@@ -463,8 +465,7 @@ bool CUDADeviceCode::CheckNVRTCResult(hiprtcResult result,
...
@@ -463,8 +465,7 @@ bool CUDADeviceCode::CheckNVRTCResult(hiprtcResult result,
return
false
;
return
false
;
}
}
#else
#else
bool
CUDADeviceCode
::
CheckNVRTCResult
(
nvrtcResult
result
,
bool
GPUDeviceCode
::
CheckNVRTCResult
(
nvrtcResult
result
,
std
::
string
function
)
{
std
::
string
function
)
{
if
(
result
!=
NVRTC_SUCCESS
)
{
if
(
result
!=
NVRTC_SUCCESS
)
{
LOG_FIRST_N
(
WARNING
,
1
)
LOG_FIRST_N
(
WARNING
,
1
)
<<
"Call "
<<
function
<<
" for < "
<<
name_
<<
"Call "
<<
function
<<
" for < "
<<
name_
...
@@ -476,5 +477,4 @@ bool CUDADeviceCode::CheckNVRTCResult(nvrtcResult result,
...
@@ -476,5 +477,4 @@ bool CUDADeviceCode::CheckNVRTCResult(nvrtcResult result,
}
}
#endif
#endif
}
// namespace platform
}
// namespace phi
}
// namespace paddle
paddle/
fluid/platform
/device_code.h
→
paddle/
phi/backends
/device_code.h
浏览文件 @
26da689d
...
@@ -20,18 +20,18 @@ limitations under the License. */
...
@@ -20,18 +20,18 @@ limitations under the License. */
#include <unordered_map>
#include <unordered_map>
#include <vector>
#include <vector>
#include "paddle/fluid/platform/device_context.h"
#include "paddle/phi/common/place.h"
#include "paddle/phi/core/enforce.h"
#ifdef PADDLE_WITH_CUDA
#ifdef PADDLE_WITH_CUDA
#include "paddle/
fluid/platform
/dynload/cuda_driver.h"
#include "paddle/
phi/backends
/dynload/cuda_driver.h"
#include "paddle/
fluid/platform
/dynload/nvrtc.h"
#include "paddle/
phi/backends
/dynload/nvrtc.h"
#endif
#endif
#ifdef PADDLE_WITH_HIP
#ifdef PADDLE_WITH_HIP
#include "paddle/
fluid/platform
/dynload/hiprtc.h"
#include "paddle/
phi/backends
/dynload/hiprtc.h"
#include "paddle/
fluid/platform
/dynload/rocm_driver.h"
#include "paddle/
phi/backends
/dynload/rocm_driver.h"
#endif
#endif
namespace
paddle
{
namespace
phi
{
namespace
platform
{
class
DeviceCode
{
class
DeviceCode
{
public:
public:
...
@@ -49,9 +49,9 @@ class DeviceCode {
...
@@ -49,9 +49,9 @@ class DeviceCode {
};
};
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
class
CUDA
DeviceCode
:
public
DeviceCode
{
class
GPU
DeviceCode
:
public
DeviceCode
{
public:
public:
explicit
CUDA
DeviceCode
(
const
Place
&
place
,
explicit
GPU
DeviceCode
(
const
Place
&
place
,
const
std
::
string
&
name
,
const
std
::
string
&
name
,
const
std
::
string
&
kernel
);
const
std
::
string
&
kernel
);
bool
Compile
(
bool
include_path
=
false
)
override
;
bool
Compile
(
bool
include_path
=
false
)
override
;
...
@@ -94,7 +94,7 @@ class DeviceCodePool {
...
@@ -94,7 +94,7 @@ class DeviceCodePool {
using
DeviceCodeMap
=
using
DeviceCodeMap
=
std
::
unordered_map
<
std
::
string
,
std
::
unique_ptr
<
DeviceCode
>>
;
std
::
unordered_map
<
std
::
string
,
std
::
unique_ptr
<
DeviceCode
>>
;
explicit
DeviceCodePool
(
const
std
::
vector
<
platform
::
Place
>&
places
);
explicit
DeviceCodePool
(
const
std
::
vector
<
Place
>&
places
);
static
DeviceCodePool
&
Instance
()
{
static
DeviceCodePool
&
Instance
()
{
PADDLE_ENFORCE_NOT_NULL
(
PADDLE_ENFORCE_NOT_NULL
(
...
@@ -104,7 +104,7 @@ class DeviceCodePool {
...
@@ -104,7 +104,7 @@ class DeviceCodePool {
return
*
pool
;
return
*
pool
;
}
}
static
DeviceCodePool
&
Init
(
const
std
::
vector
<
platform
::
Place
>&
places
)
{
static
DeviceCodePool
&
Init
(
const
std
::
vector
<
Place
>&
places
)
{
if
(
pool
==
nullptr
)
{
if
(
pool
==
nullptr
)
{
pool
=
new
DeviceCodePool
(
places
);
pool
=
new
DeviceCodePool
(
places
);
}
}
...
@@ -113,10 +113,9 @@ class DeviceCodePool {
...
@@ -113,10 +113,9 @@ class DeviceCodePool {
void
Set
(
std
::
unique_ptr
<
DeviceCode
>&&
code
);
void
Set
(
std
::
unique_ptr
<
DeviceCode
>&&
code
);
platform
::
DeviceCode
*
Get
(
const
platform
::
Place
&
place
,
DeviceCode
*
Get
(
const
Place
&
place
,
const
std
::
string
&
name
);
const
std
::
string
&
name
);
size_t
size
(
const
platform
::
Place
&
place
)
const
{
size_t
size
(
const
Place
&
place
)
const
{
auto
iter
=
device_codes_
.
find
(
place
);
auto
iter
=
device_codes_
.
find
(
place
);
if
(
iter
==
device_codes_
.
end
())
{
if
(
iter
==
device_codes_
.
end
())
{
return
0
;
return
0
;
...
@@ -130,5 +129,4 @@ class DeviceCodePool {
...
@@ -130,5 +129,4 @@ class DeviceCodePool {
DISABLE_COPY_AND_ASSIGN
(
DeviceCodePool
);
DISABLE_COPY_AND_ASSIGN
(
DeviceCodePool
);
};
};
}
// namespace platform
}
// namespace phi
}
// namespace paddle
paddle/phi/kernels/CMakeLists.txt
浏览文件 @
26da689d
...
@@ -153,6 +153,11 @@ if(WITH_CUTLASS)
...
@@ -153,6 +153,11 @@ if(WITH_CUTLASS)
list
(
APPEND kernel_cu
${
cutlass_cu
}
)
list
(
APPEND kernel_cu
${
cutlass_cu
}
)
endif
()
endif
()
if
(
APPLE OR WIN32
)
list
(
REMOVE_ITEM kernel_cu
"
${
CMAKE_CURRENT_SOURCE_DIR
}
/fusion/gpu/fusion_group_kernel.cu"
)
endif
()
if
(
WITH_MKLDNN
)
if
(
WITH_MKLDNN
)
file
(
file
(
GLOB
GLOB
...
...
paddle/phi/kernels/fusion/gpu/fusion_group_kernel.cu
0 → 100644
浏览文件 @
26da689d
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "glog/logging.h"
#include "paddle/phi/backends/device_code.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/core/utils/data_type.h"
namespace
phi
{
namespace
fusion
{
template
<
typename
DeviceContext
>
static
void
MutableMultiTypeData
(
std
::
vector
<
phi
::
DenseTensor
*>*
var
,
const
std
::
vector
<
int
>&
data_type
,
const
DeviceContext
&
dev_ctx
)
{
for
(
size_t
i
=
0
;
i
<
var
->
size
();
i
++
)
{
if
(
data_type
[
i
]
==
phi
::
TransToProtoVarType
(
phi
::
DataType
::
FLOAT32
))
{
dev_ctx
.
template
Alloc
<
float
>((
*
var
)[
i
],
(
*
var
)[
i
]
->
numel
()
*
sizeof
(
float
));
}
else
if
(
data_type
[
i
]
==
phi
::
TransToProtoVarType
(
phi
::
DataType
::
FLOAT16
))
{
dev_ctx
.
template
Alloc
<
phi
::
dtype
::
float16
>(
(
*
var
)[
i
],
(
*
var
)[
i
]
->
numel
()
*
sizeof
(
phi
::
dtype
::
float16
));
}
else
if
(
data_type
[
i
]
==
phi
::
TransToProtoVarType
(
phi
::
DataType
::
FLOAT64
))
{
dev_ctx
.
template
Alloc
<
double
>((
*
var
)[
i
],
(
*
var
)[
i
]
->
numel
()
*
sizeof
(
double
));
}
}
}
template
<
typename
T
,
typename
Context
>
void
FusionGroupKernel
(
const
Context
&
dev_ctx
,
const
std
::
vector
<
const
DenseTensor
*>&
ins
,
const
std
::
vector
<
int
>&
outs_dtype
,
const
std
::
vector
<
int
>&
inputs_dtype
,
const
std
::
string
&
func_name
,
int
type
,
std
::
vector
<
DenseTensor
*>
outs
)
{
size_t
num_ins
=
ins
.
size
();
size_t
num_outs
=
outs
.
size
();
MutableMultiTypeData
(
&
outs
,
outs_dtype
,
dev_ctx
);
phi
::
DeviceCode
*
dev_code
=
phi
::
DeviceCodePool
::
Instance
().
Get
(
dev_ctx
.
GetPlace
(),
func_name
);
VLOG
(
3
)
<<
"func_name: "
<<
func_name
;
if
(
type
==
0
)
{
size_t
n
=
ins
[
0
]
->
numel
();
std
::
vector
<
void
*>
args
;
args
.
push_back
(
&
n
);
std
::
vector
<
const
void
*>
ptrs
(
num_ins
+
num_outs
);
for
(
size_t
i
=
0
;
i
<
num_ins
;
++
i
)
{
if
(
inputs_dtype
[
i
]
==
phi
::
TransToProtoVarType
(
phi
::
DataType
::
FLOAT16
))
{
ptrs
[
i
]
=
ins
[
i
]
->
data
<
phi
::
dtype
::
float16
>
();
}
else
if
(
inputs_dtype
[
i
]
==
phi
::
TransToProtoVarType
(
phi
::
DataType
::
FLOAT32
))
{
ptrs
[
i
]
=
ins
[
i
]
->
data
<
float
>
();
}
else
if
(
inputs_dtype
[
i
]
==
phi
::
TransToProtoVarType
(
phi
::
DataType
::
FLOAT64
))
{
ptrs
[
i
]
=
ins
[
i
]
->
data
<
double
>
();
}
args
.
push_back
(
&
ptrs
[
i
]);
}
for
(
size_t
j
=
0
;
j
<
num_outs
;
++
j
)
{
if
(
outs_dtype
[
j
]
==
phi
::
TransToProtoVarType
(
phi
::
DataType
::
FLOAT16
))
{
ptrs
[
num_ins
+
j
]
=
outs
[
j
]
->
data
<
phi
::
dtype
::
float16
>
();
}
else
if
(
outs_dtype
[
j
]
==
phi
::
TransToProtoVarType
(
phi
::
DataType
::
FLOAT32
))
{
ptrs
[
num_ins
+
j
]
=
outs
[
j
]
->
data
<
float
>
();
}
else
if
(
outs_dtype
[
j
]
==
phi
::
TransToProtoVarType
(
phi
::
DataType
::
FLOAT64
))
{
ptrs
[
num_ins
+
j
]
=
outs
[
j
]
->
data
<
double
>
();
}
args
.
push_back
(
&
ptrs
[
num_ins
+
j
]);
}
dev_code
->
Launch
(
n
,
&
args
);
}
}
}
// namespace fusion
}
// namespace phi
PD_REGISTER_KERNEL
(
fusion_group
,
GPU
,
ALL_LAYOUT
,
phi
::
fusion
::
FusionGroupKernel
,
float
,
double
,
phi
::
dtype
::
float16
)
{}
paddle/
fluid/operators/fused/fusion_group_op.cu
.cc
→
paddle/
phi/ops/compat/fusion_group_sig
.cc
浏览文件 @
26da689d
/* Copyright (c) 20
19
PaddlePaddle Authors. All Rights Reserved.
/* Copyright (c) 20
23
PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
you may not use this file except in compliance with the License.
...
@@ -12,16 +12,18 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
...
@@ -12,16 +12,18 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
See the License for the specific language governing permissions and
limitations under the License. */
limitations under the License. */
#include "paddle/
fluid/operators/fused/fusion_group_op
.h"
#include "paddle/
phi/core/compat/op_utils
.h"
#include "paddle/fluid/platform/float16.h"
namespace
phi
{
namespace
ops
=
paddle
::
operators
;
KernelSignature
FusionGroupOpArgumentMapping
(
namespace
plat
=
paddle
::
platform
;
const
ArgumentMappingContext
&
ctx
)
{
PD_REGISTER_STRUCT_KERNEL
(
fusion_group
,
return
KernelSignature
(
"fusion_group"
,
GPU
,
{
"Inputs"
},
ALL_LAYOUT
,
{
"outs_dtype"
,
"inputs_dtype"
,
"func_name"
,
"type"
},
ops
::
FusionGroupKernel
,
{
"Outs"
});
float
,
}
double
,
plat
::
float16
)
{}
}
// namespace phi
PD_REGISTER_ARG_MAPPING_FN
(
fusion_group
,
phi
::
FusionGroupOpArgumentMapping
);
test/cpp/fluid/fused/fusion_group_op_test.cc
浏览文件 @
26da689d
...
@@ -17,8 +17,8 @@ limitations under the License. */
...
@@ -17,8 +17,8 @@ limitations under the License. */
#include "paddle/fluid/framework/op_proto_maker.h"
#include "paddle/fluid/framework/op_proto_maker.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/platform/device_code.h"
#include "paddle/fluid/platform/init.h"
#include "paddle/fluid/platform/init.h"
#include "paddle/phi/backends/device_code.h"
namespace
paddle
{
namespace
paddle
{
namespace
operators
{
namespace
operators
{
...
@@ -93,11 +93,10 @@ framework::OpDesc* CreateFusionGroupOp(
...
@@ -93,11 +93,10 @@ framework::OpDesc* CreateFusionGroupOp(
void
PrepareDeviceCode
(
platform
::
Place
place
,
void
PrepareDeviceCode
(
platform
::
Place
place
,
std
::
string
func_name
,
std
::
string
func_name
,
std
::
string
cuda_kernel_str
)
{
std
::
string
cuda_kernel_str
)
{
paddle
::
platform
::
DeviceCodePool
&
pool
=
phi
::
DeviceCodePool
&
pool
=
phi
::
DeviceCodePool
::
Init
({
place
});
paddle
::
platform
::
DeviceCodePool
::
Init
({
place
});
std
::
unique_ptr
<
p
addle
::
platform
::
DeviceCode
>
code
(
std
::
unique_ptr
<
p
hi
::
DeviceCode
>
code
(
new
p
addle
::
platform
::
CUDA
DeviceCode
(
place
,
func_name
,
cuda_kernel_str
));
new
p
hi
::
GPU
DeviceCode
(
place
,
func_name
,
cuda_kernel_str
));
code
->
Compile
();
code
->
Compile
();
pool
.
Set
(
std
::
move
(
code
));
pool
.
Set
(
std
::
move
(
code
));
}
}
...
@@ -183,7 +182,7 @@ void TestMain(const std::vector<std::string>& input_names,
...
@@ -183,7 +182,7 @@ void TestMain(const std::vector<std::string>& input_names,
}
}
TEST
(
FusionGroupOp
,
elementwise
)
{
TEST
(
FusionGroupOp
,
elementwise
)
{
if
(
!
p
latform
::
dynload
::
HasNVRTC
()
||
!
platform
::
dynload
::
HasCUDADriver
())
{
if
(
!
p
hi
::
dynload
::
HasNVRTC
()
||
!
phi
::
dynload
::
HasCUDADriver
())
{
return
;
return
;
}
}
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录