Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
机器未来
Paddle
提交
eada00c2
P
Paddle
项目概览
机器未来
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
eada00c2
编写于
4月 17, 2019
作者:
S
superjomn
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
init optimizer
上级
239d716b
变更
15
隐藏空白更改
内联
并排
Showing
15 changed file
with
217 addition
and
74 deletion
+217
-74
paddle/fluid/lite/core/CMakeLists.txt
paddle/fluid/lite/core/CMakeLists.txt
+7
-0
paddle/fluid/lite/core/kernel.h
paddle/fluid/lite/core/kernel.h
+18
-3
paddle/fluid/lite/core/mir/CMakeLists.txt
paddle/fluid/lite/core/mir/CMakeLists.txt
+3
-2
paddle/fluid/lite/core/mir/generate_program_pass.cc
paddle/fluid/lite/core/mir/generate_program_pass.cc
+14
-0
paddle/fluid/lite/core/mir/generate_program_pass.h
paddle/fluid/lite/core/mir/generate_program_pass.h
+2
-1
paddle/fluid/lite/core/mir/io_complement_pass.cc
paddle/fluid/lite/core/mir/io_complement_pass.cc
+15
-0
paddle/fluid/lite/core/mir/io_complement_pass.h
paddle/fluid/lite/core/mir/io_complement_pass.h
+2
-1
paddle/fluid/lite/core/mir/pass_manager.cc
paddle/fluid/lite/core/mir/pass_manager.cc
+0
-1
paddle/fluid/lite/core/mir/pass_manager.h
paddle/fluid/lite/core/mir/pass_manager.h
+12
-5
paddle/fluid/lite/core/mir/ssa_graph_test.cc
paddle/fluid/lite/core/mir/ssa_graph_test.cc
+2
-52
paddle/fluid/lite/core/mir/static_kernel_pick_pass.cc
paddle/fluid/lite/core/mir/static_kernel_pick_pass.cc
+46
-0
paddle/fluid/lite/core/mir/static_kernel_pick_pass.h
paddle/fluid/lite/core/mir/static_kernel_pick_pass.h
+53
-1
paddle/fluid/lite/core/op_registry.h
paddle/fluid/lite/core/op_registry.h
+4
-2
paddle/fluid/lite/core/optimizer.h
paddle/fluid/lite/core/optimizer.h
+7
-6
paddle/fluid/lite/core/types.h
paddle/fluid/lite/core/types.h
+32
-0
未找到文件。
paddle/fluid/lite/core/CMakeLists.txt
浏览文件 @
eada00c2
...
@@ -9,8 +9,14 @@ cc_library(op_executor_lite SRCS op_executor.cc DEPS scope_lite tensor_lite op_l
...
@@ -9,8 +9,14 @@ cc_library(op_executor_lite SRCS op_executor.cc DEPS scope_lite tensor_lite op_l
#TODO(Superjomn) remove these dependencies from original framework
#TODO(Superjomn) remove these dependencies from original framework
proto_desc
)
proto_desc
)
cc_library
(
kernel_executor_lite SRCS kernel_executor.cc DEPS mir_ssa_graph kernel_lite
)
cc_library
(
kernel_executor_lite SRCS kernel_executor.cc DEPS mir_ssa_graph kernel_lite
)
cc_library
(
types_lite SRCS types.cc
)
cc_library
(
type_system SRCS type_system.cc DEPS tensor_lite
)
cc_library
(
type_system SRCS type_system.cc DEPS tensor_lite
)
cc_library
(
optimizer_lite SRCS optimizer.cc DEPS mir_pass_manager
)
cc_library
(
optimizer_lite SRCS optimizer.cc DEPS mir_pass_manager
)
cc_library
(
program_fake_utils SRCS program_fake_utils.cc DEPS mir_ssa_graph
scope_lite op_registry_lite proto_desc op_lite
ops_lite
host_kernels
)
cc_test
(
test_scope_lite SRCS scope_test.cc DEPS scope_lite
)
cc_test
(
test_scope_lite SRCS scope_test.cc DEPS scope_lite
)
cc_test
(
test_kernel_lite SRCS kernel_test.cc DEPS target_wrapper_x86
)
cc_test
(
test_kernel_lite SRCS kernel_test.cc DEPS target_wrapper_x86
)
...
@@ -18,5 +24,6 @@ cc_test(test_op_lite SRCS op_lite_test.cc DEPS op_lite)
...
@@ -18,5 +24,6 @@ cc_test(test_op_lite SRCS op_lite_test.cc DEPS op_lite)
cc_test
(
test_tensor_lite SRCS tensor_test.cc
)
cc_test
(
test_tensor_lite SRCS tensor_test.cc
)
cc_test
(
test_op_executor_lite SRCS op_executor_test.cc DEPS op_executor_lite ops_lite host_kernels
)
cc_test
(
test_op_executor_lite SRCS op_executor_test.cc DEPS op_executor_lite ops_lite host_kernels
)
cc_test
(
test_type_system SRCS type_system_test.cc DEPS type_system
)
cc_test
(
test_type_system SRCS type_system_test.cc DEPS type_system
)
cc_test
(
test_optimizer_lite SRCS optimizer_test.cc DEPS mir_pass_manager program_fake_utils mir_passes
)
add_subdirectory
(
mir
)
add_subdirectory
(
mir
)
paddle/fluid/lite/core/kernel.h
浏览文件 @
eada00c2
...
@@ -15,7 +15,9 @@
...
@@ -15,7 +15,9 @@
#pragma once
#pragma once
#include <map>
#include <map>
#include <set>
#include <string>
#include <string>
#include <vector>
#include "paddle/fluid/framework/op_desc.h"
#include "paddle/fluid/framework/op_desc.h"
#include "paddle/fluid/lite/core/context.h"
#include "paddle/fluid/lite/core/context.h"
#include "paddle/fluid/lite/core/target_wrapper.h"
#include "paddle/fluid/lite/core/target_wrapper.h"
...
@@ -48,17 +50,24 @@ class KernelBase {
...
@@ -48,17 +50,24 @@ class KernelBase {
return
param_
.
get
<
Param
>
();
return
param_
.
get
<
Param
>
();
}
}
void
set_op_type
(
const
std
::
string
&
type
)
{
op_type_
=
type
;
}
const
std
::
string
&
op_type
()
const
{
return
op_type_
;
}
void
Torch
()
{}
void
Torch
()
{}
virtual
TargetType
target
()
const
=
0
;
virtual
TargetType
target
()
const
=
0
;
virtual
PrecisionType
precision
()
const
=
0
;
virtual
PrecisionType
precision
()
const
=
0
;
virtual
DataLayoutType
layout
()
const
=
0
;
virtual
DataLayoutType
layout
()
const
=
0
;
virtual
std
::
string
name
()
const
=
0
;
virtual
~
KernelBase
()
=
default
;
virtual
~
KernelBase
()
=
default
;
protected:
protected:
core
::
any_context_t
context_
;
core
::
any_context_t
context_
;
mutable
operators
::
param_t
param_
;
mutable
operators
::
param_t
param_
;
// The corresponding op type.
std
::
string
op_type_
;
};
};
/*
/*
...
@@ -73,8 +82,9 @@ struct ParamType {
...
@@ -73,8 +82,9 @@ struct ParamType {
Place
tensor_place
{};
Place
tensor_place
{};
const
Type
*
type_
;
const
Type
*
type_
;
ParamType
()
=
default
;
explicit
ParamType
()
=
default
;
ParamType
(
size_t
element_type_hash
)
:
element_type_hash
(
element_type_hash
)
{}
explicit
ParamType
(
size_t
element_type_hash
)
:
element_type_hash
(
element_type_hash
)
{}
ParamType
(
size_t
element_type_hash
,
const
Place
&
place
)
ParamType
(
size_t
element_type_hash
,
const
Place
&
place
)
:
element_type_hash
(
element_type_hash
),
tensor_place
(
place
)
{}
:
element_type_hash
(
element_type_hash
),
tensor_place
(
place
)
{}
ParamType
(
const
Type
*
type
)
:
type_
(
type
)
{}
ParamType
(
const
Type
*
type
)
:
type_
(
type
)
{}
...
@@ -135,7 +145,8 @@ class ParamTypeRegistry {
...
@@ -135,7 +145,8 @@ class ParamTypeRegistry {
* PRECISION(kFloat)});
* PRECISION(kFloat)});
*/
*/
struct
NewInstance
{
struct
NewInstance
{
NewInstance
(
const
std
::
string
&
kernel_type
)
:
kernel_type_
(
kernel_type
)
{}
explicit
NewInstance
(
const
std
::
string
&
kernel_type
)
:
kernel_type_
(
kernel_type
)
{}
NewInstance
&
BindInput
(
int
offset
,
const
ParamType
&
ptype
)
{
NewInstance
&
BindInput
(
int
offset
,
const
ParamType
&
ptype
)
{
ParamTypeRegistry
::
Global
().
Register
<
IO
::
kInput
>
(
ParamTypeRegistry
::
Global
().
Register
<
IO
::
kInput
>
(
...
@@ -205,6 +216,10 @@ class OpKernel : public KernelBase {
...
@@ -205,6 +216,10 @@ class OpKernel : public KernelBase {
TargetType
target
()
const
override
{
return
Target
;
}
TargetType
target
()
const
override
{
return
Target
;
}
PrecisionType
precision
()
const
override
{
return
Precision
;
}
PrecisionType
precision
()
const
override
{
return
Precision
;
}
DataLayoutType
layout
()
const
override
{
return
DataLayout
;
}
DataLayoutType
layout
()
const
override
{
return
DataLayout
;
}
std
::
string
name
()
const
override
{
return
op_type
()
+
":"
+
TargetToStr
(
Target
)
+
"/"
+
PrecisionToStr
(
Precision
)
+
"/"
+
DataLayoutToStr
(
DataLayout
);
}
void
Touch
()
{}
void
Touch
()
{}
...
...
paddle/fluid/lite/core/mir/CMakeLists.txt
浏览文件 @
eada00c2
...
@@ -9,13 +9,14 @@ cc_library(mir_passes
...
@@ -9,13 +9,14 @@ cc_library(mir_passes
graph_visualize_pass.cc
graph_visualize_pass.cc
generate_program_pass.cc
generate_program_pass.cc
demo_pass.cc
demo_pass.cc
DEPS mir_pass
)
DEPS mir_pass
types_lite
)
cc_test
(
test_mir_pass_manager SRCS pass_manager_test.cc DEPS mir_pass_manager mir_passes
)
cc_test
(
test_mir_pass_manager SRCS pass_manager_test.cc DEPS mir_pass_manager mir_passes
)
cc_test
(
test_ssa_graph SRCS ssa_graph_test.cc DEPS
cc_test
(
test_ssa_graph SRCS ssa_graph_test.cc DEPS
mir_ssa_graph scope_lite op_lite
mir_ssa_graph scope_lite op_lite
proto_desc
ops_lite
ops_lite
host_kernels
host_kernels
mir_passes
mir_passes
mir_pass_manager
mir_pass_manager
program_fake_utils
)
)
paddle/fluid/lite/core/mir/generate_program_pass.cc
浏览文件 @
eada00c2
...
@@ -11,3 +11,17 @@
...
@@ -11,3 +11,17 @@
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// See the License for the specific language governing permissions and
// limitations under the License.
// limitations under the License.
#include "paddle/fluid/lite/core/mir/generate_program_pass.h"
#include "paddle/fluid/lite/core/mir/pass_registry.h"
namespace
paddle
{
namespace
lite
{
namespace
mir
{
void
GenerateProgramPass
::
Apply
(
std
::
unique_ptr
<
mir
::
SSAGraph
>
&
graph
)
{}
}
// namespace mir
}
// namespace lite
}
// namespace paddle
REGISTER_MIR_PASS
(
generate_program_pass
,
paddle
::
lite
::
mir
::
GenerateProgramPass
);
paddle/fluid/lite/core/mir/generate_program_pass.h
浏览文件 @
eada00c2
...
@@ -24,8 +24,9 @@ namespace mir {
...
@@ -24,8 +24,9 @@ namespace mir {
* GenerateProgramPass will build the execution program for executor from a mir
* GenerateProgramPass will build the execution program for executor from a mir
* graph.
* graph.
*/
*/
class
GenerateProgramPass
:
public
Pass
{
class
GenerateProgramPass
:
public
P
rogramP
ass
{
public:
public:
void
Apply
(
std
::
unique_ptr
<
mir
::
SSAGraph
>
&
graph
)
override
;
};
};
}
// namespace mir
}
// namespace mir
...
...
paddle/fluid/lite/core/mir/io_complement_pass.cc
浏览文件 @
eada00c2
...
@@ -11,3 +11,18 @@
...
@@ -11,3 +11,18 @@
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// See the License for the specific language governing permissions and
// limitations under the License.
// limitations under the License.
#include "paddle/fluid/lite/core/mir/io_complement_pass.h"
#include "paddle/fluid/lite/core/mir/pass_registry.h"
namespace
paddle
{
namespace
lite
{
namespace
mir
{
void
IoComplementPass
::
Apply
(
std
::
unique_ptr
<
mir
::
SSAGraph
>
&
graph
)
{}
}
// namespace mir
}
// namespace lite
}
// namespace paddle
REGISTER_MIR_PASS
(
io_complement_pass
,
paddle
::
lite
::
mir
::
IoComplementPass
);
paddle/fluid/lite/core/mir/io_complement_pass.h
浏览文件 @
eada00c2
...
@@ -24,8 +24,9 @@ namespace mir {
...
@@ -24,8 +24,9 @@ namespace mir {
* IoComplementPass complement the necessary instruction to make data
* IoComplementPass complement the necessary instruction to make data
* transferring or transformation between different places.
* transferring or transformation between different places.
*/
*/
class
IoComplementPass
:
public
Pass
{
class
IoComplementPass
:
public
P
rogramP
ass
{
public:
public:
void
Apply
(
std
::
unique_ptr
<
mir
::
SSAGraph
>
&
graph
)
override
;
};
};
}
// namespace mir
}
// namespace mir
...
...
paddle/fluid/lite/core/mir/pass_manager.cc
浏览文件 @
eada00c2
...
@@ -13,7 +13,6 @@
...
@@ -13,7 +13,6 @@
// limitations under the License.
// limitations under the License.
#include "paddle/fluid/lite/core/mir/pass_manager.h"
#include "paddle/fluid/lite/core/mir/pass_manager.h"
#include "paddle/fluid/lite/core/mir/pass_registry.h"
namespace
paddle
{
namespace
paddle
{
namespace
lite
{
namespace
lite
{
...
...
paddle/fluid/lite/core/mir/pass_manager.h
浏览文件 @
eada00c2
...
@@ -32,16 +32,17 @@ class PassManager {
...
@@ -32,16 +32,17 @@ class PassManager {
PassManager
();
PassManager
();
void
Run
()
{
void
Run
(
std
::
unique_ptr
<
SSAGraph
>&
graph
)
{
for
(
auto
&
pass
:
passes_
)
{
for
(
auto
&
pass
:
passes_
)
{
LOG
(
INFO
)
<<
"Running MIR pass "
<<
pass
->
name
();
LOG
(
INFO
)
<<
"Running MIR pass "
<<
pass
->
name
();
pass
->
Apply
(
graph
_
);
pass
->
Apply
(
graph
);
}
}
}
}
bool
AddNewPass
(
const
std
::
string
&
name
,
Pass
*
pass
)
{
bool
AddNewPass
(
const
std
::
string
&
name
,
Pass
*
pass
)
{
passes_
.
emplace_back
(
pass
);
passes_
.
emplace_back
(
pass
);
pass_map_
.
emplace
(
name
,
passes_
.
back
().
get
());
pass_map_
.
emplace
(
name
,
passes_
.
back
().
get
());
passes_
.
back
()
->
set_name
(
name
);
return
true
;
return
true
;
}
}
...
@@ -65,12 +66,18 @@ class PassManager {
...
@@ -65,12 +66,18 @@ class PassManager {
Pass
*
LookUp
(
const
std
::
string
&
key
)
{
Pass
*
LookUp
(
const
std
::
string
&
key
)
{
auto
it
=
pass_map_
.
find
(
key
);
auto
it
=
pass_map_
.
find
(
key
);
CHECK
(
it
!=
pass_map_
.
end
());
if
(
it
!=
pass_map_
.
end
())
return
it
->
second
;
return
it
->
second
;
return
nullptr
;
}
template
<
typename
PassTy
>
PassTy
*
LookUp
(
const
std
::
string
&
key
)
{
auto
it
=
pass_map_
.
find
(
key
);
if
(
it
!=
pass_map_
.
end
())
return
dynamic_cast
<
PassTy
*>
(
it
->
second
);
return
nullptr
;
}
}
private:
private:
std
::
unique_ptr
<
mir
::
SSAGraph
>
graph_
;
std
::
list
<
std
::
unique_ptr
<
mir
::
Pass
>>
passes_
;
std
::
list
<
std
::
unique_ptr
<
mir
::
Pass
>>
passes_
;
std
::
map
<
std
::
string
,
mir
::
Pass
*>
pass_map_
;
std
::
map
<
std
::
string
,
mir
::
Pass
*>
pass_map_
;
};
};
...
...
paddle/fluid/lite/core/mir/ssa_graph_test.cc
浏览文件 @
eada00c2
...
@@ -16,7 +16,9 @@
...
@@ -16,7 +16,9 @@
#include <gtest/gtest.h>
#include <gtest/gtest.h>
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/lite/core/mir/graph_visualize_pass.h"
#include "paddle/fluid/lite/core/mir/graph_visualize_pass.h"
#include "paddle/fluid/lite/core/mir/passes.h"
#include "paddle/fluid/lite/core/op_registry.h"
#include "paddle/fluid/lite/core/op_registry.h"
#include "paddle/fluid/lite/core/program_fake_utils.h"
namespace
paddle
{
namespace
paddle
{
namespace
lite
{
namespace
lite
{
...
@@ -32,58 +34,6 @@ void BuildFc(framework::ProgramDesc* desc, const std::string& x,
...
@@ -32,58 +34,6 @@ void BuildFc(framework::ProgramDesc* desc, const std::string& x,
fc
->
SetOutput
(
"Out"
,
{
out
});
fc
->
SetOutput
(
"Out"
,
{
out
});
}
}
Program
FakeProgram
()
{
Program
program
;
program
.
scope
=
new
lite
::
Scope
;
auto
add_fc
=
[
&
](
int
id
,
std
::
string
x
)
{
// create variables
std
::
string
w1
=
"w"
+
std
::
to_string
(
id
);
std
::
string
b1
=
"b"
+
std
::
to_string
(
id
);
std
::
string
out1
=
"out"
+
std
::
to_string
(
id
);
auto
w1v
=
program
.
scope
->
Var
(
w1
)
->
GetMutable
<
Tensor
>
();
auto
b1v
=
program
.
scope
->
Var
(
b1
)
->
GetMutable
<
Tensor
>
();
auto
out1v
=
program
.
scope
->
Var
(
out1
)
->
GetMutable
<
Tensor
>
();
framework
::
OpDesc
desc
;
desc
.
SetInput
(
"Input"
,
{
x
});
desc
.
SetInput
(
"W"
,
{
w1
});
desc
.
SetInput
(
"Bias"
,
{
b1
});
desc
.
SetOutput
(
"Out"
,
{
out1
});
desc
.
SetType
(
"fc"
);
desc
.
SetAttr
(
"in_num_col_dims"
,
1
);
desc
.
Flush
();
// add to input
program
.
tmp_vars
.
push_back
(
w1
);
program
.
tmp_vars
.
push_back
(
b1
);
auto
fc_op
=
LiteOpRegistry
::
Global
().
Create
(
"fc"
);
fc_op
->
PickKernel
({
Place
{
TARGET
(
kHost
),
PRECISION
(
kFloat
)}});
fc_op
->
Attach
(
desc
,
program
.
scope
);
program
.
ops
.
emplace_back
(
std
::
move
(
fc_op
));
w1v
->
Resize
({
100
,
100
});
b1v
->
Resize
({
100
,
1
});
out1v
->
Resize
({
100
,
100
});
return
out1
;
};
// x1, w1, b1 -fc-> out1
// out1, w2, b2 -fc-> out2
std
::
string
x
=
"x"
;
program
.
tmp_vars
.
push_back
(
x
);
auto
*
xv
=
program
.
scope
->
Var
(
x
)
->
GetMutable
<
Tensor
>
();
xv
->
Resize
({
100
,
100
});
for
(
int
i
=
0
;
i
<
3
;
i
++
)
{
x
=
add_fc
(
i
,
x
);
}
return
program
;
}
TEST
(
SSAGraph
,
test
)
{
TEST
(
SSAGraph
,
test
)
{
auto
program
=
FakeProgram
();
auto
program
=
FakeProgram
();
SSAGraph
graph
;
SSAGraph
graph
;
...
...
paddle/fluid/lite/core/mir/static_kernel_pick_pass.cc
浏览文件 @
eada00c2
...
@@ -11,3 +11,49 @@
...
@@ -11,3 +11,49 @@
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// See the License for the specific language governing permissions and
// limitations under the License.
// limitations under the License.
#include "paddle/fluid/lite/core/mir/static_kernel_pick_pass.h"
#include <algorithm>
#include <utility>
#include <vector>
#include "paddle/fluid/lite/core/mir/pass_registry.h"
namespace
paddle
{
namespace
lite
{
namespace
mir
{
bool
KernelScoreCmp
(
const
std
::
pair
<
size_t
,
std
::
unique_ptr
<
KernelBase
>>&
a
,
const
std
::
pair
<
size_t
,
std
::
unique_ptr
<
KernelBase
>>&
b
)
{
return
a
.
first
>
b
.
first
;
}
void
StaticKernelPickPass
::
Apply
(
std
::
unique_ptr
<
mir
::
SSAGraph
>&
graph
)
{
CHECK
(
kernel_pick_factors_
.
AnyFactorConsidered
())
<<
"kernel_pick_factors should be specified first"
;
CHECK
(
graph
)
<<
"graph not valid"
;
// sort kernels by the factors.
for
(
auto
&
node
:
graph
->
mutable_nodes
())
{
if
(
!
node
.
IsInstruct
())
continue
;
auto
&
instruct
=
node
.
AsInstruct
();
std
::
vector
<
std
::
pair
<
size_t
,
std
::
unique_ptr
<
KernelBase
>>>
scored
;
for
(
auto
&&
kernel
:
instruct
.
valid_kernels
)
{
scored
.
emplace_back
(
KernelGrade
(
*
kernel
),
std
::
move
(
kernel
));
}
std
::
sort
(
scored
.
begin
(),
scored
.
end
(),
KernelScoreCmp
);
// Move kernel back
// Just keep a single best kernel.
// TODO(Superjomn) reconsider this.
instruct
.
valid_kernels
.
clear
();
instruct
.
valid_kernels
.
emplace_back
(
std
::
move
(
scored
.
front
().
second
));
LOG
(
INFO
)
<<
"pick "
<<
instruct
.
valid_kernels
.
front
()
->
name
();
}
}
}
// namespace mir
}
// namespace lite
}
// namespace paddle
REGISTER_MIR_PASS
(
static_kernel_pick_pass
,
paddle
::
lite
::
mir
::
StaticKernelPickPass
);
paddle/fluid/lite/core/mir/static_kernel_pick_pass.h
浏览文件 @
eada00c2
...
@@ -14,13 +14,65 @@
...
@@ -14,13 +14,65 @@
#pragma once
#pragma once
#include <limits>
#include "paddle/fluid/lite/core/mir/pass.h"
#include "paddle/fluid/lite/core/mir/pass.h"
#include "paddle/fluid/lite/core/types.h"
namespace
paddle
{
namespace
paddle
{
namespace
lite
{
namespace
lite
{
namespace
mir
{
namespace
mir
{
class
StaticKernelPickPass
:
public
mir
::
Pass
{};
/*
* StaticKernelPickPass is a simple strategy for picking the kernel for each
* Operator using operator developer defined rule, there are many other tactics
* such as considering IO or kernel execution latency and we will implement them
* latter.
*
* There are two argument for this pass:
* - place, the target place.
* - kernel_pick_factors, the factors to consider in picking kernels.
* Set them first before execute the pass.
*/
class
StaticKernelPickPass
:
public
mir
::
InstructionPass
{
public:
void
Apply
(
std
::
unique_ptr
<
mir
::
SSAGraph
>&
graph
)
override
;
const
Place
&
place
()
const
{
return
place_
;
}
const
core
::
KernelPickFactor
&
kernel_pick_factors
()
const
{
return
kernel_pick_factors_
;
}
core
::
KernelPickFactor
*
mutable_kernel_pick_factors
()
{
return
&
kernel_pick_factors_
;
}
private:
// Score the kernel.
size_t
KernelGrade
(
const
lite
::
KernelBase
&
kernel
)
{
size_t
score
{};
const
int
kMax
=
std
::
numeric_limits
<
core
::
KernelPickFactor
::
value_type
>::
max
();
if
(
kernel_pick_factors_
.
IsTargetConsidered
()
&&
place
().
target
==
kernel
.
target
())
{
score
+=
kMax
/
static_cast
<
int
>
(
core
::
KernelPickFactor
::
Factor
::
TargetFirst
);
}
if
(
kernel_pick_factors_
.
IsPrecisionConsidered
()
&&
place
().
precision
==
kernel
.
precision
())
{
score
+=
kMax
/
static_cast
<
int
>
(
core
::
KernelPickFactor
::
Factor
::
PrecisionFirst
);
}
// The data layout is not considered, for the input and output arguments
// might have different data layout.
// TODO(Superjomn) reconsider the idea of taking the data layout as a kernel
// specification.
return
score
;
}
private:
core
::
KernelPickFactor
kernel_pick_factors_
;
Place
place_
;
};
}
// namespace mir
}
// namespace mir
}
// namespace lite
}
// namespace lite
...
...
paddle/fluid/lite/core/op_registry.h
浏览文件 @
eada00c2
...
@@ -120,8 +120,10 @@ class KernelRegistor : public lite::Registor<KernelType> {
...
@@ -120,8 +120,10 @@ class KernelRegistor : public lite::Registor<KernelType> {
LOG
(
INFO
)
<<
"Register kernel "
<<
op_type
<<
" for "
LOG
(
INFO
)
<<
"Register kernel "
<<
op_type
<<
" for "
<<
TargetToStr
(
target
)
<<
" "
<<
PrecisionToStr
(
precision
);
<<
TargetToStr
(
target
)
<<
" "
<<
PrecisionToStr
(
precision
);
KernelRegistry
::
Global
().
Register
<
target
,
precision
>
(
KernelRegistry
::
Global
().
Register
<
target
,
precision
>
(
op_type
,
[
&
]()
->
std
::
unique_ptr
<
KernelType
>
{
op_type
,
[
&
,
op_type
]()
->
std
::
unique_ptr
<
KernelType
>
{
return
std
::
unique_ptr
<
KernelType
>
(
new
KernelType
);
std
::
unique_ptr
<
KernelType
>
x
(
new
KernelType
);
x
->
set_op_type
(
op_type
);
return
x
;
});
});
})
{}
})
{}
};
};
...
...
paddle/fluid/lite/core/optimizer.h
浏览文件 @
eada00c2
...
@@ -16,7 +16,6 @@
...
@@ -16,7 +16,6 @@
#include <string>
#include <string>
#include <vector>
#include <vector>
#include "paddle/fluid/lite/core/mir/pass_manager.h"
#include "paddle/fluid/lite/core/mir/pass_manager.h"
#include "paddle/fluid/lite/core/mir/pass_manager.h"
#include "paddle/fluid/lite/core/mir/ssa_graph.h"
#include "paddle/fluid/lite/core/mir/ssa_graph.h"
namespace
paddle
{
namespace
paddle
{
...
@@ -28,17 +27,19 @@ namespace lite {
...
@@ -28,17 +27,19 @@ namespace lite {
*/
*/
class
Optimizer
{
class
Optimizer
{
public:
public:
void
Run
(
std
::
unique_ptr
<
mir
::
Program
>&&
program
,
void
Run
(
mir
::
Program
&&
program
,
const
std
::
vector
<
Place
>&
valid_places
,
const
std
::
vector
<
Place
>&
valid_places
,
const
std
::
vector
<
std
::
string
>&
passes
=
{})
{
const
std
::
vector
<
std
::
string
>&
passes
=
{})
{
CHECK
(
!
graph_
)
<<
"duplicate optimize found"
;
CHECK
(
!
graph_
)
<<
"duplicate optimize found"
;
graph_
.
reset
(
new
mir
::
SSAGraph
);
graph_
.
reset
(
new
mir
::
SSAGraph
);
graph_
->
Build
(
*
program
,
valid_places
);
graph_
->
Build
(
program
,
valid_places
);
RunPasses
();
RunPasses
();
}
}
// Generate a new program based on the mir graph.
// Generate a new program based on the mir graph.
std
::
unique_ptr
<
mir
::
Program
>
GenProgram
()
{}
std
::
unique_ptr
<
mir
::
Program
>
GenProgram
()
{
std
::
unique_ptr
<
mir
::
Program
>
res
;
return
res
;
}
// Generate C++ code which combines the inference program, model and weights.
// Generate C++ code which combines the inference program, model and weights.
void
GenCode
(
const
std
::
string
&
code_dir
);
void
GenCode
(
const
std
::
string
&
code_dir
);
...
@@ -50,7 +51,7 @@ class Optimizer {
...
@@ -50,7 +51,7 @@ class Optimizer {
protected:
protected:
// Run the default passes registered in the PassManager.
// Run the default passes registered in the PassManager.
void
RunPasses
()
{
mir
::
PassManager
::
Global
().
Run
();
}
void
RunPasses
()
{
mir
::
PassManager
::
Global
().
Run
(
graph_
);
}
// Specify the passes and run them.
// Specify the passes and run them.
void
RunPasses
(
std
::
vector
<
std
::
string
>&
passes
);
void
RunPasses
(
std
::
vector
<
std
::
string
>&
passes
);
...
...
paddle/fluid/lite/core/types.h
浏览文件 @
eada00c2
...
@@ -25,6 +25,38 @@ using any_context_t = variant<Context<TARGET(kX86)>, //
...
@@ -25,6 +25,38 @@ using any_context_t = variant<Context<TARGET(kX86)>, //
Context
<
TARGET
(
kCUDA
)
>
//
Context
<
TARGET
(
kCUDA
)
>
//
>
;
>
;
// Factors that impact the kernel picking strategy. Multiple factors can be
// considered together by using statement like 'factor1 | factor2'
class
KernelPickFactor
{
public:
using
value_type
=
unsigned
char
;
enum
class
Factor
:
int
{
// The following factors are sorted by priority.
TargetFirst
=
1
,
PrecisionFirst
=
1
<<
1
,
DataLayoutFirst
=
1
<<
2
,
DeviceFirst
=
1
<<
3
,
};
// Has any factors considered.
bool
AnyFactorConsidered
()
const
{
return
data_
;
}
KernelPickFactor
&
ConsiderTarget
();
KernelPickFactor
&
ConsiderPrecision
();
KernelPickFactor
&
ConsiderDataLayout
();
KernelPickFactor
&
ConsiderDevice
();
bool
IsTargetConsidered
()
const
;
bool
IsPrecisionConsidered
()
const
;
bool
IsDataLayoutConsidered
()
const
;
bool
IsDeviceConsidered
()
const
{
return
data_
&
static_cast
<
int
>
(
Factor
::
DeviceFirst
);
}
private:
unsigned
char
data_
{};
};
struct
dim2
{
struct
dim2
{
int
x
{};
int
x
{};
int
y
{};
int
y
{};
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录