Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
eada00c2
P
Paddle
项目概览
PaddlePaddle
/
Paddle
大约 1 年 前同步成功
通知
2298
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
eada00c2
编写于
4月 17, 2019
作者:
S
superjomn
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
init optimizer
上级
239d716b
变更
15
隐藏空白更改
内联
并排
Showing
15 changed file
with
217 addition
and
74 deletion
+217
-74
paddle/fluid/lite/core/CMakeLists.txt
paddle/fluid/lite/core/CMakeLists.txt
+7
-0
paddle/fluid/lite/core/kernel.h
paddle/fluid/lite/core/kernel.h
+18
-3
paddle/fluid/lite/core/mir/CMakeLists.txt
paddle/fluid/lite/core/mir/CMakeLists.txt
+3
-2
paddle/fluid/lite/core/mir/generate_program_pass.cc
paddle/fluid/lite/core/mir/generate_program_pass.cc
+14
-0
paddle/fluid/lite/core/mir/generate_program_pass.h
paddle/fluid/lite/core/mir/generate_program_pass.h
+2
-1
paddle/fluid/lite/core/mir/io_complement_pass.cc
paddle/fluid/lite/core/mir/io_complement_pass.cc
+15
-0
paddle/fluid/lite/core/mir/io_complement_pass.h
paddle/fluid/lite/core/mir/io_complement_pass.h
+2
-1
paddle/fluid/lite/core/mir/pass_manager.cc
paddle/fluid/lite/core/mir/pass_manager.cc
+0
-1
paddle/fluid/lite/core/mir/pass_manager.h
paddle/fluid/lite/core/mir/pass_manager.h
+12
-5
paddle/fluid/lite/core/mir/ssa_graph_test.cc
paddle/fluid/lite/core/mir/ssa_graph_test.cc
+2
-52
paddle/fluid/lite/core/mir/static_kernel_pick_pass.cc
paddle/fluid/lite/core/mir/static_kernel_pick_pass.cc
+46
-0
paddle/fluid/lite/core/mir/static_kernel_pick_pass.h
paddle/fluid/lite/core/mir/static_kernel_pick_pass.h
+53
-1
paddle/fluid/lite/core/op_registry.h
paddle/fluid/lite/core/op_registry.h
+4
-2
paddle/fluid/lite/core/optimizer.h
paddle/fluid/lite/core/optimizer.h
+7
-6
paddle/fluid/lite/core/types.h
paddle/fluid/lite/core/types.h
+32
-0
未找到文件。
paddle/fluid/lite/core/CMakeLists.txt
浏览文件 @
eada00c2
...
@@ -9,8 +9,14 @@ cc_library(op_executor_lite SRCS op_executor.cc DEPS scope_lite tensor_lite op_l
...
@@ -9,8 +9,14 @@ cc_library(op_executor_lite SRCS op_executor.cc DEPS scope_lite tensor_lite op_l
#TODO(Superjomn) remove these dependencies from original framework
#TODO(Superjomn) remove these dependencies from original framework
proto_desc
)
proto_desc
)
cc_library
(
kernel_executor_lite SRCS kernel_executor.cc DEPS mir_ssa_graph kernel_lite
)
cc_library
(
kernel_executor_lite SRCS kernel_executor.cc DEPS mir_ssa_graph kernel_lite
)
cc_library
(
types_lite SRCS types.cc
)
cc_library
(
type_system SRCS type_system.cc DEPS tensor_lite
)
cc_library
(
type_system SRCS type_system.cc DEPS tensor_lite
)
cc_library
(
optimizer_lite SRCS optimizer.cc DEPS mir_pass_manager
)
cc_library
(
optimizer_lite SRCS optimizer.cc DEPS mir_pass_manager
)
cc_library
(
program_fake_utils SRCS program_fake_utils.cc DEPS mir_ssa_graph
scope_lite op_registry_lite proto_desc op_lite
ops_lite
host_kernels
)
cc_test
(
test_scope_lite SRCS scope_test.cc DEPS scope_lite
)
cc_test
(
test_scope_lite SRCS scope_test.cc DEPS scope_lite
)
cc_test
(
test_kernel_lite SRCS kernel_test.cc DEPS target_wrapper_x86
)
cc_test
(
test_kernel_lite SRCS kernel_test.cc DEPS target_wrapper_x86
)
...
@@ -18,5 +24,6 @@ cc_test(test_op_lite SRCS op_lite_test.cc DEPS op_lite)
...
@@ -18,5 +24,6 @@ cc_test(test_op_lite SRCS op_lite_test.cc DEPS op_lite)
cc_test
(
test_tensor_lite SRCS tensor_test.cc
)
cc_test
(
test_tensor_lite SRCS tensor_test.cc
)
cc_test
(
test_op_executor_lite SRCS op_executor_test.cc DEPS op_executor_lite ops_lite host_kernels
)
cc_test
(
test_op_executor_lite SRCS op_executor_test.cc DEPS op_executor_lite ops_lite host_kernels
)
cc_test
(
test_type_system SRCS type_system_test.cc DEPS type_system
)
cc_test
(
test_type_system SRCS type_system_test.cc DEPS type_system
)
cc_test
(
test_optimizer_lite SRCS optimizer_test.cc DEPS mir_pass_manager program_fake_utils mir_passes
)
add_subdirectory
(
mir
)
add_subdirectory
(
mir
)
paddle/fluid/lite/core/kernel.h
浏览文件 @
eada00c2
...
@@ -15,7 +15,9 @@
...
@@ -15,7 +15,9 @@
#pragma once
#pragma once
#include <map>
#include <map>
#include <set>
#include <string>
#include <string>
#include <vector>
#include "paddle/fluid/framework/op_desc.h"
#include "paddle/fluid/framework/op_desc.h"
#include "paddle/fluid/lite/core/context.h"
#include "paddle/fluid/lite/core/context.h"
#include "paddle/fluid/lite/core/target_wrapper.h"
#include "paddle/fluid/lite/core/target_wrapper.h"
...
@@ -48,17 +50,24 @@ class KernelBase {
...
@@ -48,17 +50,24 @@ class KernelBase {
return
param_
.
get
<
Param
>
();
return
param_
.
get
<
Param
>
();
}
}
void
set_op_type
(
const
std
::
string
&
type
)
{
op_type_
=
type
;
}
const
std
::
string
&
op_type
()
const
{
return
op_type_
;
}
void
Torch
()
{}
void
Torch
()
{}
virtual
TargetType
target
()
const
=
0
;
virtual
TargetType
target
()
const
=
0
;
virtual
PrecisionType
precision
()
const
=
0
;
virtual
PrecisionType
precision
()
const
=
0
;
virtual
DataLayoutType
layout
()
const
=
0
;
virtual
DataLayoutType
layout
()
const
=
0
;
virtual
std
::
string
name
()
const
=
0
;
virtual
~
KernelBase
()
=
default
;
virtual
~
KernelBase
()
=
default
;
protected:
protected:
core
::
any_context_t
context_
;
core
::
any_context_t
context_
;
mutable
operators
::
param_t
param_
;
mutable
operators
::
param_t
param_
;
// The corresponding op type.
std
::
string
op_type_
;
};
};
/*
/*
...
@@ -73,8 +82,9 @@ struct ParamType {
...
@@ -73,8 +82,9 @@ struct ParamType {
Place
tensor_place
{};
Place
tensor_place
{};
const
Type
*
type_
;
const
Type
*
type_
;
ParamType
()
=
default
;
explicit
ParamType
()
=
default
;
ParamType
(
size_t
element_type_hash
)
:
element_type_hash
(
element_type_hash
)
{}
explicit
ParamType
(
size_t
element_type_hash
)
:
element_type_hash
(
element_type_hash
)
{}
ParamType
(
size_t
element_type_hash
,
const
Place
&
place
)
ParamType
(
size_t
element_type_hash
,
const
Place
&
place
)
:
element_type_hash
(
element_type_hash
),
tensor_place
(
place
)
{}
:
element_type_hash
(
element_type_hash
),
tensor_place
(
place
)
{}
ParamType
(
const
Type
*
type
)
:
type_
(
type
)
{}
ParamType
(
const
Type
*
type
)
:
type_
(
type
)
{}
...
@@ -135,7 +145,8 @@ class ParamTypeRegistry {
...
@@ -135,7 +145,8 @@ class ParamTypeRegistry {
* PRECISION(kFloat)});
* PRECISION(kFloat)});
*/
*/
struct
NewInstance
{
struct
NewInstance
{
NewInstance
(
const
std
::
string
&
kernel_type
)
:
kernel_type_
(
kernel_type
)
{}
explicit
NewInstance
(
const
std
::
string
&
kernel_type
)
:
kernel_type_
(
kernel_type
)
{}
NewInstance
&
BindInput
(
int
offset
,
const
ParamType
&
ptype
)
{
NewInstance
&
BindInput
(
int
offset
,
const
ParamType
&
ptype
)
{
ParamTypeRegistry
::
Global
().
Register
<
IO
::
kInput
>
(
ParamTypeRegistry
::
Global
().
Register
<
IO
::
kInput
>
(
...
@@ -205,6 +216,10 @@ class OpKernel : public KernelBase {
...
@@ -205,6 +216,10 @@ class OpKernel : public KernelBase {
TargetType
target
()
const
override
{
return
Target
;
}
TargetType
target
()
const
override
{
return
Target
;
}
PrecisionType
precision
()
const
override
{
return
Precision
;
}
PrecisionType
precision
()
const
override
{
return
Precision
;
}
DataLayoutType
layout
()
const
override
{
return
DataLayout
;
}
DataLayoutType
layout
()
const
override
{
return
DataLayout
;
}
std
::
string
name
()
const
override
{
return
op_type
()
+
":"
+
TargetToStr
(
Target
)
+
"/"
+
PrecisionToStr
(
Precision
)
+
"/"
+
DataLayoutToStr
(
DataLayout
);
}
void
Touch
()
{}
void
Touch
()
{}
...
...
paddle/fluid/lite/core/mir/CMakeLists.txt
浏览文件 @
eada00c2
...
@@ -9,13 +9,14 @@ cc_library(mir_passes
...
@@ -9,13 +9,14 @@ cc_library(mir_passes
graph_visualize_pass.cc
graph_visualize_pass.cc
generate_program_pass.cc
generate_program_pass.cc
demo_pass.cc
demo_pass.cc
DEPS mir_pass
)
DEPS mir_pass
types_lite
)
cc_test
(
test_mir_pass_manager SRCS pass_manager_test.cc DEPS mir_pass_manager mir_passes
)
cc_test
(
test_mir_pass_manager SRCS pass_manager_test.cc DEPS mir_pass_manager mir_passes
)
cc_test
(
test_ssa_graph SRCS ssa_graph_test.cc DEPS
cc_test
(
test_ssa_graph SRCS ssa_graph_test.cc DEPS
mir_ssa_graph scope_lite op_lite
mir_ssa_graph scope_lite op_lite
proto_desc
ops_lite
ops_lite
host_kernels
host_kernels
mir_passes
mir_passes
mir_pass_manager
mir_pass_manager
program_fake_utils
)
)
paddle/fluid/lite/core/mir/generate_program_pass.cc
浏览文件 @
eada00c2
...
@@ -11,3 +11,17 @@
...
@@ -11,3 +11,17 @@
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// See the License for the specific language governing permissions and
// limitations under the License.
// limitations under the License.
#include "paddle/fluid/lite/core/mir/generate_program_pass.h"
#include "paddle/fluid/lite/core/mir/pass_registry.h"
namespace
paddle
{
namespace
lite
{
namespace
mir
{
void
GenerateProgramPass
::
Apply
(
std
::
unique_ptr
<
mir
::
SSAGraph
>
&
graph
)
{}
}
// namespace mir
}
// namespace lite
}
// namespace paddle
REGISTER_MIR_PASS
(
generate_program_pass
,
paddle
::
lite
::
mir
::
GenerateProgramPass
);
paddle/fluid/lite/core/mir/generate_program_pass.h
浏览文件 @
eada00c2
...
@@ -24,8 +24,9 @@ namespace mir {
...
@@ -24,8 +24,9 @@ namespace mir {
* GenerateProgramPass will build the execution program for executor from a mir
* GenerateProgramPass will build the execution program for executor from a mir
* graph.
* graph.
*/
*/
class
GenerateProgramPass
:
public
Pass
{
class
GenerateProgramPass
:
public
P
rogramP
ass
{
public:
public:
void
Apply
(
std
::
unique_ptr
<
mir
::
SSAGraph
>
&
graph
)
override
;
};
};
}
// namespace mir
}
// namespace mir
...
...
paddle/fluid/lite/core/mir/io_complement_pass.cc
浏览文件 @
eada00c2
...
@@ -11,3 +11,18 @@
...
@@ -11,3 +11,18 @@
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// See the License for the specific language governing permissions and
// limitations under the License.
// limitations under the License.
#include "paddle/fluid/lite/core/mir/io_complement_pass.h"
#include "paddle/fluid/lite/core/mir/pass_registry.h"
namespace
paddle
{
namespace
lite
{
namespace
mir
{
void
IoComplementPass
::
Apply
(
std
::
unique_ptr
<
mir
::
SSAGraph
>
&
graph
)
{}
}
// namespace mir
}
// namespace lite
}
// namespace paddle
REGISTER_MIR_PASS
(
io_complement_pass
,
paddle
::
lite
::
mir
::
IoComplementPass
);
paddle/fluid/lite/core/mir/io_complement_pass.h
浏览文件 @
eada00c2
...
@@ -24,8 +24,9 @@ namespace mir {
...
@@ -24,8 +24,9 @@ namespace mir {
* IoComplementPass complement the necessary instruction to make data
* IoComplementPass complement the necessary instruction to make data
* transferring or transformation between different places.
* transferring or transformation between different places.
*/
*/
class
IoComplementPass
:
public
Pass
{
class
IoComplementPass
:
public
P
rogramP
ass
{
public:
public:
void
Apply
(
std
::
unique_ptr
<
mir
::
SSAGraph
>
&
graph
)
override
;
};
};
}
// namespace mir
}
// namespace mir
...
...
paddle/fluid/lite/core/mir/pass_manager.cc
浏览文件 @
eada00c2
...
@@ -13,7 +13,6 @@
...
@@ -13,7 +13,6 @@
// limitations under the License.
// limitations under the License.
#include "paddle/fluid/lite/core/mir/pass_manager.h"
#include "paddle/fluid/lite/core/mir/pass_manager.h"
#include "paddle/fluid/lite/core/mir/pass_registry.h"
namespace
paddle
{
namespace
paddle
{
namespace
lite
{
namespace
lite
{
...
...
paddle/fluid/lite/core/mir/pass_manager.h
浏览文件 @
eada00c2
...
@@ -32,16 +32,17 @@ class PassManager {
...
@@ -32,16 +32,17 @@ class PassManager {
PassManager
();
PassManager
();
void
Run
()
{
void
Run
(
std
::
unique_ptr
<
SSAGraph
>&
graph
)
{
for
(
auto
&
pass
:
passes_
)
{
for
(
auto
&
pass
:
passes_
)
{
LOG
(
INFO
)
<<
"Running MIR pass "
<<
pass
->
name
();
LOG
(
INFO
)
<<
"Running MIR pass "
<<
pass
->
name
();
pass
->
Apply
(
graph
_
);
pass
->
Apply
(
graph
);
}
}
}
}
bool
AddNewPass
(
const
std
::
string
&
name
,
Pass
*
pass
)
{
bool
AddNewPass
(
const
std
::
string
&
name
,
Pass
*
pass
)
{
passes_
.
emplace_back
(
pass
);
passes_
.
emplace_back
(
pass
);
pass_map_
.
emplace
(
name
,
passes_
.
back
().
get
());
pass_map_
.
emplace
(
name
,
passes_
.
back
().
get
());
passes_
.
back
()
->
set_name
(
name
);
return
true
;
return
true
;
}
}
...
@@ -65,12 +66,18 @@ class PassManager {
...
@@ -65,12 +66,18 @@ class PassManager {
Pass
*
LookUp
(
const
std
::
string
&
key
)
{
Pass
*
LookUp
(
const
std
::
string
&
key
)
{
auto
it
=
pass_map_
.
find
(
key
);
auto
it
=
pass_map_
.
find
(
key
);
CHECK
(
it
!=
pass_map_
.
end
());
if
(
it
!=
pass_map_
.
end
())
return
it
->
second
;
return
it
->
second
;
return
nullptr
;
}
template
<
typename
PassTy
>
PassTy
*
LookUp
(
const
std
::
string
&
key
)
{
auto
it
=
pass_map_
.
find
(
key
);
if
(
it
!=
pass_map_
.
end
())
return
dynamic_cast
<
PassTy
*>
(
it
->
second
);
return
nullptr
;
}
}
private:
private:
std
::
unique_ptr
<
mir
::
SSAGraph
>
graph_
;
std
::
list
<
std
::
unique_ptr
<
mir
::
Pass
>>
passes_
;
std
::
list
<
std
::
unique_ptr
<
mir
::
Pass
>>
passes_
;
std
::
map
<
std
::
string
,
mir
::
Pass
*>
pass_map_
;
std
::
map
<
std
::
string
,
mir
::
Pass
*>
pass_map_
;
};
};
...
...
paddle/fluid/lite/core/mir/ssa_graph_test.cc
浏览文件 @
eada00c2
...
@@ -16,7 +16,9 @@
...
@@ -16,7 +16,9 @@
#include <gtest/gtest.h>
#include <gtest/gtest.h>
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/lite/core/mir/graph_visualize_pass.h"
#include "paddle/fluid/lite/core/mir/graph_visualize_pass.h"
#include "paddle/fluid/lite/core/mir/passes.h"
#include "paddle/fluid/lite/core/op_registry.h"
#include "paddle/fluid/lite/core/op_registry.h"
#include "paddle/fluid/lite/core/program_fake_utils.h"
namespace
paddle
{
namespace
paddle
{
namespace
lite
{
namespace
lite
{
...
@@ -32,58 +34,6 @@ void BuildFc(framework::ProgramDesc* desc, const std::string& x,
...
@@ -32,58 +34,6 @@ void BuildFc(framework::ProgramDesc* desc, const std::string& x,
fc
->
SetOutput
(
"Out"
,
{
out
});
fc
->
SetOutput
(
"Out"
,
{
out
});
}
}
Program
FakeProgram
()
{
Program
program
;
program
.
scope
=
new
lite
::
Scope
;
auto
add_fc
=
[
&
](
int
id
,
std
::
string
x
)
{
// create variables
std
::
string
w1
=
"w"
+
std
::
to_string
(
id
);
std
::
string
b1
=
"b"
+
std
::
to_string
(
id
);
std
::
string
out1
=
"out"
+
std
::
to_string
(
id
);
auto
w1v
=
program
.
scope
->
Var
(
w1
)
->
GetMutable
<
Tensor
>
();
auto
b1v
=
program
.
scope
->
Var
(
b1
)
->
GetMutable
<
Tensor
>
();
auto
out1v
=
program
.
scope
->
Var
(
out1
)
->
GetMutable
<
Tensor
>
();
framework
::
OpDesc
desc
;
desc
.
SetInput
(
"Input"
,
{
x
});
desc
.
SetInput
(
"W"
,
{
w1
});
desc
.
SetInput
(
"Bias"
,
{
b1
});
desc
.
SetOutput
(
"Out"
,
{
out1
});
desc
.
SetType
(
"fc"
);
desc
.
SetAttr
(
"in_num_col_dims"
,
1
);
desc
.
Flush
();
// add to input
program
.
tmp_vars
.
push_back
(
w1
);
program
.
tmp_vars
.
push_back
(
b1
);
auto
fc_op
=
LiteOpRegistry
::
Global
().
Create
(
"fc"
);
fc_op
->
PickKernel
({
Place
{
TARGET
(
kHost
),
PRECISION
(
kFloat
)}});
fc_op
->
Attach
(
desc
,
program
.
scope
);
program
.
ops
.
emplace_back
(
std
::
move
(
fc_op
));
w1v
->
Resize
({
100
,
100
});
b1v
->
Resize
({
100
,
1
});
out1v
->
Resize
({
100
,
100
});
return
out1
;
};
// x1, w1, b1 -fc-> out1
// out1, w2, b2 -fc-> out2
std
::
string
x
=
"x"
;
program
.
tmp_vars
.
push_back
(
x
);
auto
*
xv
=
program
.
scope
->
Var
(
x
)
->
GetMutable
<
Tensor
>
();
xv
->
Resize
({
100
,
100
});
for
(
int
i
=
0
;
i
<
3
;
i
++
)
{
x
=
add_fc
(
i
,
x
);
}
return
program
;
}
TEST
(
SSAGraph
,
test
)
{
TEST
(
SSAGraph
,
test
)
{
auto
program
=
FakeProgram
();
auto
program
=
FakeProgram
();
SSAGraph
graph
;
SSAGraph
graph
;
...
...
paddle/fluid/lite/core/mir/static_kernel_pick_pass.cc
浏览文件 @
eada00c2
...
@@ -11,3 +11,49 @@
...
@@ -11,3 +11,49 @@
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// See the License for the specific language governing permissions and
// limitations under the License.
// limitations under the License.
#include "paddle/fluid/lite/core/mir/static_kernel_pick_pass.h"
#include <algorithm>
#include <utility>
#include <vector>
#include "paddle/fluid/lite/core/mir/pass_registry.h"
namespace
paddle
{
namespace
lite
{
namespace
mir
{
bool
KernelScoreCmp
(
const
std
::
pair
<
size_t
,
std
::
unique_ptr
<
KernelBase
>>&
a
,
const
std
::
pair
<
size_t
,
std
::
unique_ptr
<
KernelBase
>>&
b
)
{
return
a
.
first
>
b
.
first
;
}
void
StaticKernelPickPass
::
Apply
(
std
::
unique_ptr
<
mir
::
SSAGraph
>&
graph
)
{
CHECK
(
kernel_pick_factors_
.
AnyFactorConsidered
())
<<
"kernel_pick_factors should be specified first"
;
CHECK
(
graph
)
<<
"graph not valid"
;
// sort kernels by the factors.
for
(
auto
&
node
:
graph
->
mutable_nodes
())
{
if
(
!
node
.
IsInstruct
())
continue
;
auto
&
instruct
=
node
.
AsInstruct
();
std
::
vector
<
std
::
pair
<
size_t
,
std
::
unique_ptr
<
KernelBase
>>>
scored
;
for
(
auto
&&
kernel
:
instruct
.
valid_kernels
)
{
scored
.
emplace_back
(
KernelGrade
(
*
kernel
),
std
::
move
(
kernel
));
}
std
::
sort
(
scored
.
begin
(),
scored
.
end
(),
KernelScoreCmp
);
// Move kernel back
// Just keep a single best kernel.
// TODO(Superjomn) reconsider this.
instruct
.
valid_kernels
.
clear
();
instruct
.
valid_kernels
.
emplace_back
(
std
::
move
(
scored
.
front
().
second
));
LOG
(
INFO
)
<<
"pick "
<<
instruct
.
valid_kernels
.
front
()
->
name
();
}
}
}
// namespace mir
}
// namespace lite
}
// namespace paddle
REGISTER_MIR_PASS
(
static_kernel_pick_pass
,
paddle
::
lite
::
mir
::
StaticKernelPickPass
);
paddle/fluid/lite/core/mir/static_kernel_pick_pass.h
浏览文件 @
eada00c2
...
@@ -14,13 +14,65 @@
...
@@ -14,13 +14,65 @@
#pragma once
#pragma once
#include <limits>
#include "paddle/fluid/lite/core/mir/pass.h"
#include "paddle/fluid/lite/core/mir/pass.h"
#include "paddle/fluid/lite/core/types.h"
namespace
paddle
{
namespace
paddle
{
namespace
lite
{
namespace
lite
{
namespace
mir
{
namespace
mir
{
class
StaticKernelPickPass
:
public
mir
::
Pass
{};
/*
* StaticKernelPickPass is a simple strategy for picking the kernel for each
* Operator using operator developer defined rule, there are many other tactics
* such as considering IO or kernel execution latency and we will implement them
* latter.
*
* There are two argument for this pass:
* - place, the target place.
* - kernel_pick_factors, the factors to consider in picking kernels.
* Set them first before execute the pass.
*/
class
StaticKernelPickPass
:
public
mir
::
InstructionPass
{
public:
void
Apply
(
std
::
unique_ptr
<
mir
::
SSAGraph
>&
graph
)
override
;
const
Place
&
place
()
const
{
return
place_
;
}
const
core
::
KernelPickFactor
&
kernel_pick_factors
()
const
{
return
kernel_pick_factors_
;
}
core
::
KernelPickFactor
*
mutable_kernel_pick_factors
()
{
return
&
kernel_pick_factors_
;
}
private:
// Score the kernel.
size_t
KernelGrade
(
const
lite
::
KernelBase
&
kernel
)
{
size_t
score
{};
const
int
kMax
=
std
::
numeric_limits
<
core
::
KernelPickFactor
::
value_type
>::
max
();
if
(
kernel_pick_factors_
.
IsTargetConsidered
()
&&
place
().
target
==
kernel
.
target
())
{
score
+=
kMax
/
static_cast
<
int
>
(
core
::
KernelPickFactor
::
Factor
::
TargetFirst
);
}
if
(
kernel_pick_factors_
.
IsPrecisionConsidered
()
&&
place
().
precision
==
kernel
.
precision
())
{
score
+=
kMax
/
static_cast
<
int
>
(
core
::
KernelPickFactor
::
Factor
::
PrecisionFirst
);
}
// The data layout is not considered, for the input and output arguments
// might have different data layout.
// TODO(Superjomn) reconsider the idea of taking the data layout as a kernel
// specification.
return
score
;
}
private:
core
::
KernelPickFactor
kernel_pick_factors_
;
Place
place_
;
};
}
// namespace mir
}
// namespace mir
}
// namespace lite
}
// namespace lite
...
...
paddle/fluid/lite/core/op_registry.h
浏览文件 @
eada00c2
...
@@ -120,8 +120,10 @@ class KernelRegistor : public lite::Registor<KernelType> {
...
@@ -120,8 +120,10 @@ class KernelRegistor : public lite::Registor<KernelType> {
LOG
(
INFO
)
<<
"Register kernel "
<<
op_type
<<
" for "
LOG
(
INFO
)
<<
"Register kernel "
<<
op_type
<<
" for "
<<
TargetToStr
(
target
)
<<
" "
<<
PrecisionToStr
(
precision
);
<<
TargetToStr
(
target
)
<<
" "
<<
PrecisionToStr
(
precision
);
KernelRegistry
::
Global
().
Register
<
target
,
precision
>
(
KernelRegistry
::
Global
().
Register
<
target
,
precision
>
(
op_type
,
[
&
]()
->
std
::
unique_ptr
<
KernelType
>
{
op_type
,
[
&
,
op_type
]()
->
std
::
unique_ptr
<
KernelType
>
{
return
std
::
unique_ptr
<
KernelType
>
(
new
KernelType
);
std
::
unique_ptr
<
KernelType
>
x
(
new
KernelType
);
x
->
set_op_type
(
op_type
);
return
x
;
});
});
})
{}
})
{}
};
};
...
...
paddle/fluid/lite/core/optimizer.h
浏览文件 @
eada00c2
...
@@ -16,7 +16,6 @@
...
@@ -16,7 +16,6 @@
#include <string>
#include <string>
#include <vector>
#include <vector>
#include "paddle/fluid/lite/core/mir/pass_manager.h"
#include "paddle/fluid/lite/core/mir/pass_manager.h"
#include "paddle/fluid/lite/core/mir/pass_manager.h"
#include "paddle/fluid/lite/core/mir/ssa_graph.h"
#include "paddle/fluid/lite/core/mir/ssa_graph.h"
namespace
paddle
{
namespace
paddle
{
...
@@ -28,17 +27,19 @@ namespace lite {
...
@@ -28,17 +27,19 @@ namespace lite {
*/
*/
class
Optimizer
{
class
Optimizer
{
public:
public:
void
Run
(
std
::
unique_ptr
<
mir
::
Program
>&&
program
,
void
Run
(
mir
::
Program
&&
program
,
const
std
::
vector
<
Place
>&
valid_places
,
const
std
::
vector
<
Place
>&
valid_places
,
const
std
::
vector
<
std
::
string
>&
passes
=
{})
{
const
std
::
vector
<
std
::
string
>&
passes
=
{})
{
CHECK
(
!
graph_
)
<<
"duplicate optimize found"
;
CHECK
(
!
graph_
)
<<
"duplicate optimize found"
;
graph_
.
reset
(
new
mir
::
SSAGraph
);
graph_
.
reset
(
new
mir
::
SSAGraph
);
graph_
->
Build
(
*
program
,
valid_places
);
graph_
->
Build
(
program
,
valid_places
);
RunPasses
();
RunPasses
();
}
}
// Generate a new program based on the mir graph.
// Generate a new program based on the mir graph.
std
::
unique_ptr
<
mir
::
Program
>
GenProgram
()
{}
std
::
unique_ptr
<
mir
::
Program
>
GenProgram
()
{
std
::
unique_ptr
<
mir
::
Program
>
res
;
return
res
;
}
// Generate C++ code which combines the inference program, model and weights.
// Generate C++ code which combines the inference program, model and weights.
void
GenCode
(
const
std
::
string
&
code_dir
);
void
GenCode
(
const
std
::
string
&
code_dir
);
...
@@ -50,7 +51,7 @@ class Optimizer {
...
@@ -50,7 +51,7 @@ class Optimizer {
protected:
protected:
// Run the default passes registered in the PassManager.
// Run the default passes registered in the PassManager.
void
RunPasses
()
{
mir
::
PassManager
::
Global
().
Run
();
}
void
RunPasses
()
{
mir
::
PassManager
::
Global
().
Run
(
graph_
);
}
// Specify the passes and run them.
// Specify the passes and run them.
void
RunPasses
(
std
::
vector
<
std
::
string
>&
passes
);
void
RunPasses
(
std
::
vector
<
std
::
string
>&
passes
);
...
...
paddle/fluid/lite/core/types.h
浏览文件 @
eada00c2
...
@@ -25,6 +25,38 @@ using any_context_t = variant<Context<TARGET(kX86)>, //
...
@@ -25,6 +25,38 @@ using any_context_t = variant<Context<TARGET(kX86)>, //
Context
<
TARGET
(
kCUDA
)
>
//
Context
<
TARGET
(
kCUDA
)
>
//
>
;
>
;
// Factors that impact the kernel picking strategy. Multiple factors can be
// considered together by using statement like 'factor1 | factor2'
class
KernelPickFactor
{
public:
using
value_type
=
unsigned
char
;
enum
class
Factor
:
int
{
// The following factors are sorted by priority.
TargetFirst
=
1
,
PrecisionFirst
=
1
<<
1
,
DataLayoutFirst
=
1
<<
2
,
DeviceFirst
=
1
<<
3
,
};
// Has any factors considered.
bool
AnyFactorConsidered
()
const
{
return
data_
;
}
KernelPickFactor
&
ConsiderTarget
();
KernelPickFactor
&
ConsiderPrecision
();
KernelPickFactor
&
ConsiderDataLayout
();
KernelPickFactor
&
ConsiderDevice
();
bool
IsTargetConsidered
()
const
;
bool
IsPrecisionConsidered
()
const
;
bool
IsDataLayoutConsidered
()
const
;
bool
IsDeviceConsidered
()
const
{
return
data_
&
static_cast
<
int
>
(
Factor
::
DeviceFirst
);
}
private:
unsigned
char
data_
{};
};
struct
dim2
{
struct
dim2
{
int
x
{};
int
x
{};
int
y
{};
int
y
{};
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录