Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
Crayon鑫
Paddle
提交
c249556d
P
Paddle
项目概览
Crayon鑫
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
c249556d
编写于
11月 25, 2021
作者:
Z
Zhen Wang
提交者:
GitHub
11月 25, 2021
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Pass the stream created by Paddle to CINN. (#37337)
上级
a4ef88ed
变更
5
隐藏空白更改
内联
并排
Showing
5 changed file
with
105 addition
and
15 deletion
+105
-15
paddle/fluid/framework/paddle2cinn/cinn_compiler.cc
paddle/fluid/framework/paddle2cinn/cinn_compiler.cc
+7
-6
paddle/fluid/framework/paddle2cinn/cinn_compiler.h
paddle/fluid/framework/paddle2cinn/cinn_compiler.h
+4
-3
paddle/fluid/operators/cinn_launch_op.cc
paddle/fluid/operators/cinn_launch_op.cc
+3
-2
paddle/fluid/operators/cinn_launch_op.cu.cc
paddle/fluid/operators/cinn_launch_op.cu.cc
+50
-0
paddle/fluid/operators/cinn_launch_op.h
paddle/fluid/operators/cinn_launch_op.h
+41
-4
未找到文件。
paddle/fluid/framework/paddle2cinn/cinn_compiler.cc
浏览文件 @
c249556d
...
...
@@ -66,7 +66,7 @@ CinnCompiler* CinnCompiler::GetInstance() {
const
CinnCompiledObject
&
CinnCompiler
::
Compile
(
const
Graph
&
graph
,
const
std
::
map
<
std
::
string
,
const
LoDTensor
*>&
input_tensors
,
const
Target
&
target
)
{
const
Target
&
target
,
void
*
stream
)
{
VLOG
(
1
)
<<
"-- The graph to be compiled is:
\n
"
<<
VizGraph
(
graph
);
CinnCacheKey
cur_key
(
graph
,
input_tensors
,
target
.
arch_str
());
bool
exist
=
false
;
...
...
@@ -77,7 +77,7 @@ const CinnCompiledObject& CinnCompiler::Compile(
if
(
!
exist
)
{
std
::
int64_t
compiled_num
=
real_compiled_num_
.
fetch_add
(
1
);
auto
compiled_res
=
CompileGraph
(
graph
,
input_tensors
,
target
,
compiled_num
);
CompileGraph
(
graph
,
input_tensors
,
target
,
compiled_num
,
stream
);
AutoWRLock
w_guard
{
&
rwlock_
};
if
(
!
cache_
.
count
(
cur_key
))
{
cache_
[
cur_key
]
=
std
::
move
(
compiled_res
);
...
...
@@ -91,9 +91,9 @@ const CinnCompiledObject& CinnCompiler::Compile(
const
CinnCompiledObject
&
CinnCompiler
::
Compile
(
const
std
::
string
&
compilation_key
,
const
std
::
map
<
std
::
string
,
const
LoDTensor
*>&
input_tensors
,
const
Target
&
target
)
{
const
Target
&
target
,
void
*
stream
)
{
const
auto
&
graph
=
FindGraph
(
compilation_key
);
return
Compile
(
graph
,
input_tensors
,
target
);
return
Compile
(
graph
,
input_tensors
,
target
,
stream
);
}
std
::
string
CinnCompiler
::
AddGraph
(
std
::
unique_ptr
<
Graph
>
graph
)
{
...
...
@@ -189,7 +189,7 @@ void CinnCompiler::Clear() {
std
::
unique_ptr
<
CinnCompiledObject
>
CinnCompiler
::
CompileGraph
(
const
ir
::
Graph
&
graph
,
const
std
::
map
<
std
::
string
,
const
LoDTensor
*>&
input_tensors
,
const
Target
&
target
,
std
::
int64_t
compiled_num
)
const
{
const
Target
&
target
,
std
::
int64_t
compiled_num
,
void
*
stream
)
const
{
CinnGraphSymbolization
symbol
{
compiled_num
,
graph
,
target
,
input_tensors
};
auto
frontend_program
=
symbol
();
ProgramPass
::
Apply
(
&
frontend_program
,
target
,
{
"Decomposer"
});
...
...
@@ -209,7 +209,8 @@ std::unique_ptr<CinnCompiledObject> CinnCompiler::CompileGraph(
std
::
make_unique
<
GraphCompiler
>
(
target
,
scope
,
cinn_graph
);
GraphCompiler
::
CompileOptions
options
;
options
.
with_instantiate_variables
=
false
;
auto
compiled_res
=
graph_compiler
->
Build
(
options
,
std
::
move
(
fetch_ids
));
auto
compiled_res
=
graph_compiler
->
Build
(
options
,
std
::
move
(
fetch_ids
),
stream
);
auto
compiled_obj
=
std
::
make_unique
<
CinnCompiledObject
>
();
*
compiled_obj
=
{
std
::
move
(
graph_compiler
),
std
::
move
(
compiled_res
.
runtime_program
),
scope
,
...
...
paddle/fluid/framework/paddle2cinn/cinn_compiler.h
浏览文件 @
c249556d
...
...
@@ -55,12 +55,12 @@ class CinnCompiler {
const
CinnCompiledObject
&
Compile
(
const
ir
::
Graph
&
graph
,
const
std
::
map
<
std
::
string
,
const
LoDTensor
*>&
input_tensors
,
const
::
cinn
::
common
::
Target
&
target
);
const
::
cinn
::
common
::
Target
&
target
,
void
*
stream
=
nullptr
);
const
CinnCompiledObject
&
Compile
(
const
std
::
string
&
compilation_key
,
const
std
::
map
<
std
::
string
,
const
LoDTensor
*>&
input_tensors
,
const
::
cinn
::
common
::
Target
&
target
);
const
::
cinn
::
common
::
Target
&
target
,
void
*
stream
=
nullptr
);
std
::
string
AddGraph
(
std
::
unique_ptr
<
ir
::
Graph
>
graph
);
...
...
@@ -83,7 +83,8 @@ class CinnCompiler {
std
::
unique_ptr
<
CinnCompiledObject
>
CompileGraph
(
const
ir
::
Graph
&
graph
,
const
std
::
map
<
std
::
string
,
const
LoDTensor
*>&
input_tensors
,
const
::
cinn
::
common
::
Target
&
target
,
std
::
int64_t
compiled_num
)
const
;
const
::
cinn
::
common
::
Target
&
target
,
std
::
int64_t
compiled_num
,
void
*
stream
=
nullptr
)
const
;
std
::
unordered_map
<
std
::
string
,
std
::
unique_ptr
<
ir
::
Graph
>>
graphs_
;
std
::
unordered_map
<
CinnCacheKey
,
std
::
unique_ptr
<
CinnCompiledObject
>
,
...
...
paddle/fluid/operators/cinn_launch_op.cc
浏览文件 @
c249556d
...
...
@@ -13,6 +13,7 @@
// limitations under the License.
#include "paddle/fluid/operators/cinn_launch_op.h"
#include <vector>
#include "paddle/fluid/string/string_helper.h"
DECLARE_bool
(
cudnn_deterministic
);
...
...
@@ -65,8 +66,8 @@ void DebugCinnCompiledResult(const CinnCompiledObject& result) {
}
void
LaunchCinnExecution
(
const
CinnCompiledObject
&
compiled_obj
,
const
CinnLaunchContext
&
context
)
{
compiled_obj
.
runtime_program
->
Execute
(
&
context
.
FinalizeArguments
());
const
CinnLaunchContext
&
context
,
void
*
stream
)
{
compiled_obj
.
runtime_program
->
Execute
(
&
context
.
FinalizeArguments
()
,
stream
);
}
void
SetCinnRuntimeFlags
()
{
...
...
paddle/fluid/operators/cinn_launch_op.cu.cc
浏览文件 @
c249556d
...
...
@@ -13,6 +13,56 @@ See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/cinn_launch_op.h"
#include <memory>
#include <vector>
#include "cinn/runtime/cinn_runtime.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/platform/device_context.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/type_defs.h"
#ifdef PADDLE_WITH_CUDA
#include <cuda_runtime.h>
#endif
namespace
paddle
{
namespace
operators
{
namespace
details
{
#ifdef PADDLE_WITH_CUDA
void
CUDART_CB
ReleaseScope
(
void
*
data
)
{
auto
*
temp_scope
=
static_cast
<
framework
::
Scope
*>
(
data
);
delete
temp_scope
;
}
void
CUDART_CB
ReleaseBuffers
(
void
*
data
)
{
auto
*
buffers
=
static_cast
<
std
::
vector
<
std
::
unique_ptr
<
cinn_buffer_t
>>*>
(
data
);
delete
buffers
;
}
template
<
>
void
ReleaseResource
<
platform
::
CUDADeviceContext
>
(
const
std
::
vector
<
void
*>&
resources
,
void
*
stream
)
{
PADDLE_ENFORCE_CUDA_SUCCESS
(
cudaLaunchHostFunc
(
static_cast
<
gpuStream_t
>
(
stream
),
ReleaseScope
,
resources
[
0
]));
PADDLE_ENFORCE_CUDA_SUCCESS
(
cudaLaunchHostFunc
(
static_cast
<
gpuStream_t
>
(
stream
),
ReleaseBuffers
,
resources
[
1
]));
}
template
<
>
void
*
GetStream
<
platform
::
CUDADeviceContext
>
(
const
framework
::
ExecutionContext
&
ctx
)
{
const
auto
&
dev_ctx
=
ctx
.
template
device_context
<
platform
::
CUDADeviceContext
>();
return
dev_ctx
.
stream
();
}
#endif
}
// namespace details
}
// namespace operators
}
// namespace paddle
/* see [Why use single type kernel] */
REGISTER_OP_CUDA_KERNEL
(
cinn_launch
,
...
...
paddle/fluid/operators/cinn_launch_op.h
浏览文件 @
c249556d
...
...
@@ -67,6 +67,10 @@ class CinnLaunchContext {
// Finalize all execution arguments and return them
const
std
::
map
<
std
::
string
,
cinn_pod_value_t
>&
FinalizeArguments
()
const
;
std
::
vector
<
std
::
unique_ptr
<
cinn_buffer_t
>>
HandoverBuffers
()
{
return
std
::
move
(
hold_buffers_
);
}
private:
// Get CinnTensor with CINN variable name
CinnTensor
GetCinnTensor
(
const
std
::
string
&
var_name
);
...
...
@@ -110,10 +114,35 @@ void DebugCinnCompiledResult(const CinnCompiledObject& result);
// Launch cinn to execute compiled executable program and wait done
void
LaunchCinnExecution
(
const
CinnCompiledObject
&
compiled_obj
,
const
CinnLaunchContext
&
context
);
const
CinnLaunchContext
&
context
,
void
*
stream
);
// Set cinn FLAGS (such as FLAGS_cinn_cudnn_deterministic) with paddle's FLAGS.
void
SetCinnRuntimeFlags
();
template
<
typename
DeviceContext
>
void
ReleaseResource
(
const
std
::
vector
<
void
*>&
resources
,
void
*
stream
)
{
auto
*
temp_scope
=
static_cast
<
framework
::
Scope
*>
(
resources
[
0
]);
auto
*
buffers
=
static_cast
<
std
::
vector
<
std
::
unique_ptr
<
cinn_buffer_t
>>*>
(
resources
[
1
]);
delete
temp_scope
;
delete
buffers
;
}
template
<
typename
DeviceContext
>
void
*
GetStream
(
const
framework
::
ExecutionContext
&
ctx
)
{
return
nullptr
;
}
#ifdef PADDLE_WITH_CUDA
template
<
>
void
ReleaseResource
<
platform
::
CUDADeviceContext
>
(
const
std
::
vector
<
void
*>&
resources
,
void
*
stream
);
template
<
>
void
*
GetStream
<
platform
::
CUDADeviceContext
>
(
const
framework
::
ExecutionContext
&
ctx
);
#endif
}
// namespace details
template
<
typename
DeviceContext
,
typename
T
>
...
...
@@ -122,6 +151,7 @@ class CinnLaunchOpKernel : public framework::OpKernel<T> {
void
Compute
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
const
auto
&
scope
=
ctx
.
scope
();
const
auto
&
place
=
ctx
.
GetPlace
();
void
*
stream
=
details
::
GetStream
<
DeviceContext
>
(
ctx
);
// Step 1. Find graph object and prepare input
PADDLE_ENFORCE_EQ
(
ctx
.
HasAttr
(
kCompilationKey
),
true
,
platform
::
errors
::
NotFound
(
...
...
@@ -146,7 +176,7 @@ class CinnLaunchOpKernel : public framework::OpKernel<T> {
// Step 2. Get compilation result of the graph
auto
target
=
details
::
PlaceToCinnTarget
(
place
);
const
auto
&
cinn_compiled_object
=
CinnCompiler
::
GetInstance
()
->
Compile
(
compilation_key
,
inputs_name2tensor
,
target
);
compilation_key
,
inputs_name2tensor
,
target
,
stream
);
details
::
DebugCinnCompiledResult
(
cinn_compiled_object
);
auto
launch_context
=
...
...
@@ -199,7 +229,7 @@ class CinnLaunchOpKernel : public framework::OpKernel<T> {
// names, because they will not be used outside the graph
// and should be destructed after computation finished.
auto
internal_variable_names
=
launch_context
->
GetInternalVariableNames
();
auto
temp_scope
=
scope
.
NewTmpScop
e
();
framework
::
Scope
*
temp_scope
=
scope
.
NewTmpScope
().
releas
e
();
for
(
const
auto
&
var_name
:
internal_variable_names
)
{
auto
*
tensor
=
temp_scope
->
Var
(
var_name
)
->
GetMutable
<
LoDTensor
>
();
launch_context
->
MutableTensorData
(
var_name
,
place
,
tensor
,
true
);
...
...
@@ -210,8 +240,15 @@ class CinnLaunchOpKernel : public framework::OpKernel<T> {
details
::
SetCinnRuntimeFlags
();
// Step 5. Launch CINN to execute the compiled executable program
details
::
LaunchCinnExecution
(
cinn_compiled_object
,
*
launch_context
);
VLOG
(
4
)
<<
"Run Cinn compiled executable program with stream: "
<<
stream
;
details
::
LaunchCinnExecution
(
cinn_compiled_object
,
*
launch_context
,
stream
);
VLOG
(
4
)
<<
"CinnLaunchOp launch execution done."
;
// Step 6. Release some resources, such as `temp_scope` and cinn_buffers.
auto
*
buffers_holder
=
new
std
::
vector
<
std
::
unique_ptr
<
cinn_buffer_t
>>
{
launch_context
->
HandoverBuffers
()};
details
::
ReleaseResource
<
DeviceContext
>
({
temp_scope
,
buffers_holder
},
stream
);
}
};
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录