Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
ed8f44ea
P
Paddle
项目概览
PaddlePaddle
/
Paddle
大约 1 年 前同步成功
通知
2298
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
ed8f44ea
编写于
9月 06, 2019
作者:
W
wangchaochaohu
提交者:
GitHub
9月 06, 2019
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
codegen for fused elementwise operation (#19520)
* test=develop codegen for fused elementwise operation * fix test=develop
上级
25c0eb28
变更
7
隐藏空白更改
内联
并排
Showing
7 changed file
with
313 addition
and
0 deletion
+313
-0
paddle/fluid/framework/ir/CMakeLists.txt
paddle/fluid/framework/ir/CMakeLists.txt
+3
-0
paddle/fluid/framework/ir/codegen.cc
paddle/fluid/framework/ir/codegen.cc
+96
-0
paddle/fluid/framework/ir/codegen.h
paddle/fluid/framework/ir/codegen.h
+36
-0
paddle/fluid/framework/ir/codegen_helper.cc
paddle/fluid/framework/ir/codegen_helper.cc
+61
-0
paddle/fluid/framework/ir/codegen_helper.h
paddle/fluid/framework/ir/codegen_helper.h
+70
-0
paddle/fluid/framework/ir/codegen_test.cc
paddle/fluid/framework/ir/codegen_test.cc
+43
-0
paddle/fluid/operators/math.h
paddle/fluid/operators/math.h
+4
-0
未找到文件。
paddle/fluid/framework/ir/CMakeLists.txt
浏览文件 @
ed8f44ea
...
@@ -30,6 +30,8 @@ function(pass_library TARGET DEST)
...
@@ -30,6 +30,8 @@ function(pass_library TARGET DEST)
endif
()
endif
()
endfunction
()
endfunction
()
cc_library
(
codegen SRCS codegen.cc DEPS codegen_helper
)
cc_library
(
codegen_helper SRCS codegen_helper.cc DEPS graph node graph_helper
)
cc_library
(
node SRCS node.cc DEPS proto_desc
)
cc_library
(
node SRCS node.cc DEPS proto_desc
)
cc_library
(
graph SRCS graph.cc DEPS node pretty_log
)
cc_library
(
graph SRCS graph.cc DEPS node pretty_log
)
cc_library
(
graph_helper SRCS graph_helper.cc DEPS graph
)
cc_library
(
graph_helper SRCS graph_helper.cc DEPS graph
)
...
@@ -107,6 +109,7 @@ set(GLOB_PASS_LIB ${PASS_LIBRARY} CACHE INTERNAL "Global PASS library")
...
@@ -107,6 +109,7 @@ set(GLOB_PASS_LIB ${PASS_LIBRARY} CACHE INTERNAL "Global PASS library")
cc_library
(
pass_builder SRCS pass_builder.cc DEPS pass
)
cc_library
(
pass_builder SRCS pass_builder.cc DEPS pass
)
cc_test
(
codegen_test SRCS codegen_test.cc DEPS codegen_helper codegen
)
cc_test
(
node_test SRCS node_test.cc DEPS node
)
cc_test
(
node_test SRCS node_test.cc DEPS node
)
cc_test
(
pass_test SRCS pass_test.cc DEPS graph pass graph_helper
)
cc_test
(
pass_test SRCS pass_test.cc DEPS graph pass graph_helper
)
cc_test
(
graph_test SRCS graph_test.cc DEPS graph graph_helper op_registry
)
cc_test
(
graph_test SRCS graph_test.cc DEPS graph graph_helper op_registry
)
...
...
paddle/fluid/framework/ir/codegen.cc
0 → 100644
浏览文件 @
ed8f44ea
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/framework/ir/codegen.h"
#include <set>
#include <sstream>
#include "paddle/fluid/framework/ir/codegen_helper.h"
namespace
paddle
{
namespace
framework
{
namespace
ir
{
// we get the parameter list code for the expression information
std
::
string
CodeGen
::
GetDeclarationCode
(
std
::
vector
<
OperationExpression
>
expression
)
{
std
::
stringstream
ret
;
ret
<<
"fuse_kernel"
;
ret
<<
R"((int N )"
;
std
::
set
<
int
>
input_ids
;
std
::
set
<
int
>
output_ids
;
std
::
vector
<
int
>
last_output_idis
;
for
(
size_t
i
=
0
;
i
<
expression
.
size
();
i
++
)
{
std
::
vector
<
int
>
tmp_input
=
expression
[
i
].
GetInputIds
();
for
(
size_t
j
=
0
;
j
<
tmp_input
.
size
();
j
++
)
{
int
id
=
tmp_input
[
j
];
input_ids
.
insert
(
id
);
}
int
tmp_output
=
expression
[
i
].
GetOutputId
();
output_ids
.
insert
(
tmp_output
);
}
std
::
set
<
int
>::
iterator
it
=
input_ids
.
begin
();
while
(
it
!=
input_ids
.
end
())
{
int
var_index
=
*
it
;
if
(
output_ids
.
find
(
var_index
)
!=
output_ids
.
end
())
{
input_ids
.
erase
(
it
++
);
}
else
{
it
++
;
}
}
for
(
it
=
input_ids
.
begin
();
it
!=
input_ids
.
end
();
it
++
)
{
int
var_index
=
*
it
;
ret
<<
R"(, const T* var)"
<<
var_index
;
}
for
(
it
=
output_ids
.
begin
();
it
!=
output_ids
.
end
();
it
++
)
{
int
var_index
=
*
it
;
ret
<<
R"(, T* var)"
<<
var_index
;
}
ret
<<
R"())"
;
return
ret
.
str
();
}
// Returns the single generated statement that defines `offset` as a copy of
// the loop variable `idx`, prefixed with `indentation` and terminated by a
// newline.
std::string CodeGen::GetOffsetCode() {
  std::string code(indentation);
  code += "int offset = idx;";
  code += '\n';
  return code;
}
std
::
string
CodeGen
::
GetComputeCode
(
std
::
vector
<
OperationExpression
>
expression
)
{
// get the right experssion code using suffix expression
std
::
stringstream
ret
;
for
(
size_t
i
=
0
;
i
<
expression
.
size
();
i
++
)
{
ret
<<
expression
[
i
].
GetExpression
();
}
return
ret
.
str
();
}
// in order to get the right result of expression, we need to calculate, we
// store the expression as
// suffix Expressions using vector
std
::
string
CodeGen
::
GetKernelCode
(
std
::
vector
<
OperationExpression
>
expression
)
{
auto
declaration_code
=
GetDeclarationCode
(
expression
);
auto
offset_code
=
GetOffsetCode
();
auto
compute_code
=
GetComputeCode
(
expression
);
auto
cuda_kernel
=
const_kernel_start
+
declaration_code
+
const_kernel_mid
+
offset_code
+
compute_code
+
const_kernel_end
;
return
cuda_kernel
;
}
}
// namespace ir
}
// namespace framework
}
// namespace paddle
paddle/fluid/framework/ir/codegen.h
0 → 100644
浏览文件 @
ed8f44ea
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <string>
#include <vector>
#include "paddle/fluid/framework/ir/codegen_helper.h"
namespace
paddle
{
namespace
framework
{
namespace
ir
{
// Produces the source text of a fused kernel from a list of
// OperationExpression objects.
class CodeGen {
 public:
  // Returns the full kernel source for `expression`.
  std::string GetKernelCode(std::vector<OperationExpression> expression);

 private:
  // Returns the kernel name and parameter list derived from `expression`.
  std::string GetDeclarationCode(std::vector<OperationExpression> expression);

  // Returns the statement that defines the per-element offset.
  std::string GetOffsetCode();

  // Returns the concatenated statements, one per element of `expression`.
  std::string GetComputeCode(std::vector<OperationExpression> expression);
};
}
// namespace ir
}
// namespace framework
}
// namespace paddle
paddle/fluid/framework/ir/codegen_helper.cc
0 → 100644
浏览文件 @
ed8f44ea
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License. */
#include "paddle/fluid/framework/ir/codegen_helper.h"
#include <algorithm>
#include <sstream>
#include <string>
#include <vector>
namespace
paddle
{
namespace
framework
{
namespace
ir
{
// Stores copies of the input ids, the output id and the operation name that
// GetExpression() later uses as the lookup key.
OperationExpression::OperationExpression(std::vector<int> input_ids,
                                         int output_id,
                                         std::string search_operation)
    : input_ids_(input_ids),
      output_id_(output_id),
      search_operation_(search_operation) {}
// Renders this operation as one line of generated code:
//   <indentation>var<output_id>[offset] = <right-hand side>;\n
//
// The right-hand side is the template stored in operator_cuda_table under
// search_operation_. Each "$" in the template is replaced, left to right, by
// the next entry of input_ids_ followed by "[offset]".
//
// Returns an empty string after writing a message to std::cerr when the
// operation has no template, or when the template has more "$" placeholders
// than there are input ids.
//
// NOTE(review): the original comment states that input and output ids are
// unique within a fused group; presumably the code that builds the group
// guarantees this - confirm against the caller.
std::string OperationExpression::GetExpression() {
  auto entry = operator_cuda_table.find(search_operation_);
  if (entry == operator_cuda_table.end()) {
    std::cerr << "Not supportted operation, " << search_operation_
              << std::endl;
    return std::string();
  }

  std::string rhs = entry->second;
  const std::string placeholder = "$";
  std::string::size_type next_input = 0;
  std::string::size_type pos = rhs.find(placeholder);
  while (pos != std::string::npos) {
    // Guard the index: a template with more placeholders than inputs would
    // otherwise read past the end of input_ids_.
    if (next_input >= input_ids_.size()) {
      std::cerr << "Too few input ids for operation, " << search_operation_
                << std::endl;
      return std::string();
    }
    const std::string operand =
        std::to_string(input_ids_[next_input]) + R"([offset])";
    rhs.replace(pos, placeholder.length(), operand);
    // Resume after the inserted text; it never contains a placeholder.
    pos = rhs.find(placeholder, pos + operand.length());
    ++next_input;
  }

  std::stringstream ret;
  ret << std::string(indentation) << "var" << std::to_string(output_id_)
      << R"([offset])" << R"( = )" << rhs << R"(;)" << std::endl;
  return ret.str();
}
}
// namespace ir
}
// namespace framework
}
// namespace paddle
paddle/fluid/framework/ir/codegen_helper.h
0 → 100644
浏览文件 @
ed8f44ea
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <iostream>
#include <string>
#include <unordered_map>
#include <vector>
namespace
paddle
{
namespace
framework
{
namespace
ir
{
// Maps an operator name to the expression template used for its right-hand
// side in generated code. Every "$" in a template is a placeholder that
// OperationExpression::GetExpression() replaces with an input id followed by
// "[offset]".
static std::unordered_map<std::string, std::string> operator_cuda_table = {
    // Binary element-wise arithmetic.
    {"elementwise_add", "var$ + var$"},
    {"elementwise_sub", "var$ - var$"},
    {"elementwise_mul", "var$ * var$"},
    {"elementwise_div", "var$ / var$"},
    // Binary element-wise selection.
    {"elementwise_min", "real_min(var$, var$)"},
    {"elementwise_max", "real_max(var$, var$)"},
    // Unary activations.
    {"relu", "real_max(var$, 0)"},
    {"sigmoid", "1.0 / (1.0 + real_exp(-var$))"},
};
// One operation of a fused computation: the ids of the variables it reads,
// the id of the variable it writes, and the operator name used to look up its
// code template. A fused computation is composed of one or more of these.
class OperationExpression {
 public:
  // input_ids: ids of the operands, in the order the template consumes them.
  // output_id: id of the variable that receives the result.
  // search_operation: operator name used as the template lookup key.
  OperationExpression(std::vector<int> input_ids, int output_id,
                      std::string search_operation);

  // Returns the generated statement for this operation.
  std::string GetExpression();

  // Returns a copy of the operand ids.
  std::vector<int> GetInputIds() const { return input_ids_; }

  // Returns the id of the result variable.
  int GetOutputId() const { return output_id_; }

 private:
  std::vector<int> input_ids_;
  int output_id_;
  std::string search_operation_;
};
// Whitespace prefix placed in front of each generated statement.
static
const
char
indentation
[]
=
R"( )"
;
// Text emitted before the kernel name and parameter list: the template
// header and the function qualifiers.
// NOTE(review): this combines `template <typename T>` with `extern "C"`;
// C++ does not allow a template to have C linkage - confirm the generated
// source is accepted by the compiler that consumes it.
static
const
char
const_kernel_start
[]
=
R"(
template <typename T>
extern "C" __global__ void
)"
;
// Text emitted after the parameter list: opens the function body and a loop
// in which idx starts at blockIdx.x * blockDim.x + threadIdx.x, advances by
// gridDim.x * blockDim.x, and stops once it reaches N.
static
const
char
const_kernel_mid
[]
=
R"(
{
for(int idx = blockIdx.x * blockDim.x + threadIdx.x;
idx < N;
idx += gridDim.x * blockDim.x) {
)"
;
// Text emitted last: closes the loop and the function body opened by
// const_kernel_mid.
static
const
char
const_kernel_end
[]
=
R"(
}
}
)"
;
}
// namespace ir
}
// namespace framework
}
// namespace paddle
paddle/fluid/framework/ir/codegen_test.cc
0 → 100644
浏览文件 @
ed8f44ea
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/ir/codegen.h"
#include <gtest/gtest.h>
#include <string>
#include <vector>
#include "paddle/fluid/framework/ir/codegen_helper.h"
#ifdef PADDLE_WITH_CUDA
// Generates a kernel for three chained operations,
//   var3 = var1 * var2;  var5 = var3 + var4;  var6 = sigmoid(var5)
// and checks the generated text instead of only printing it.
TEST(codegen, cuda) {
  std::vector<int> mul_input{1, 2};
  std::vector<int> add_input{3, 4};
  std::vector<int> sigmoid_input{5};
  int mul_out = 3;
  int add_out = 5;
  int sigmoid_out = 6;
  std::string op1 = "elementwise_mul";
  std::string op2 = "elementwise_add";
  std::string op3 = "sigmoid";
  paddle::framework::ir::OperationExpression opexp1(mul_input, mul_out, op1);
  paddle::framework::ir::OperationExpression opexp2(add_input, add_out, op2);
  paddle::framework::ir::OperationExpression opexp3(sigmoid_input, sigmoid_out,
                                                    op3);

  std::vector<paddle::framework::ir::OperationExpression> fused_op = {
      opexp1, opexp2, opexp3};
  paddle::framework::ir::CodeGen codegen;

  std::string result = codegen.GetKernelCode(fused_op);
  std::cout << result << std::endl;

  EXPECT_NE(result.find("fuse_kernel"), std::string::npos);
  // var4 is read but never produced, so it is a read-only parameter.
  EXPECT_NE(result.find("const T* var4"), std::string::npos);
  // var3 is produced by the multiply, so it must not be declared as an input.
  EXPECT_EQ(result.find("const T* var3"), std::string::npos);
  EXPECT_NE(result.find("var3[offset] = var1[offset] * var2[offset];"),
            std::string::npos);
  EXPECT_NE(result.find("real_exp(-var5[offset])"), std::string::npos);
}
#endif
paddle/fluid/operators/math.h
浏览文件 @
ed8f44ea
...
@@ -38,5 +38,9 @@ inline HOSTDEVICE float real_log(float x) { return ::logf(x); }
...
@@ -38,5 +38,9 @@ inline HOSTDEVICE float real_log(float x) { return ::logf(x); }
inline
HOSTDEVICE
double
real_log
(
double
x
)
{
return
::
log
(
x
);
}
inline
HOSTDEVICE
double
real_log
(
double
x
)
{
return
::
log
(
x
);
}
// Single-precision overload: returns the result of ::fminf(x, y).
inline HOSTDEVICE float real_min(float x, float y) {
  const float smaller = ::fminf(x, y);
  return smaller;
}
// Double-precision overload: returns the result of ::fmin(x, y).
inline HOSTDEVICE double real_min(double x, double y) {
  const double smaller = ::fmin(x, y);
  return smaller;
}
}
// namespace operators
}
// namespace operators
}
// namespace paddle
}
// namespace paddle
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录