Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
621d1522
P
Paddle
项目概览
PaddlePaddle
/
Paddle
大约 1 年 前同步成功
通知
2298
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
621d1522
编写于
4月 27, 2019
作者:
S
superjomn
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
make io_copy kernel pick works
上级
1fb93746
变更
51
隐藏空白更改
内联
并排
Showing
51 changed file
with
1218 addition
and
354 deletion
+1218
-354
paddle/fluid/framework/op_desc.h
paddle/fluid/framework/op_desc.h
+1
-0
paddle/fluid/lite/api/CMakeLists.txt
paddle/fluid/lite/api/CMakeLists.txt
+6
-2
paddle/fluid/lite/api/cxx_api.h
paddle/fluid/lite/api/cxx_api.h
+2
-1
paddle/fluid/lite/api/cxx_api_test.cc
paddle/fluid/lite/api/cxx_api_test.cc
+26
-6
paddle/fluid/lite/core/CMakeLists.txt
paddle/fluid/lite/core/CMakeLists.txt
+1
-0
paddle/fluid/lite/core/kernel.cc
paddle/fluid/lite/core/kernel.cc
+7
-0
paddle/fluid/lite/core/kernel.h
paddle/fluid/lite/core/kernel.h
+60
-8
paddle/fluid/lite/core/memory.cc
paddle/fluid/lite/core/memory.cc
+1
-1
paddle/fluid/lite/core/memory.h
paddle/fluid/lite/core/memory.h
+5
-2
paddle/fluid/lite/core/mir/CMakeLists.txt
paddle/fluid/lite/core/mir/CMakeLists.txt
+3
-0
paddle/fluid/lite/core/mir/argument_type_display_pass.cc
paddle/fluid/lite/core/mir/argument_type_display_pass.cc
+45
-0
paddle/fluid/lite/core/mir/generate_program_pass.cc
paddle/fluid/lite/core/mir/generate_program_pass.cc
+3
-0
paddle/fluid/lite/core/mir/io_complement_pass.cc
paddle/fluid/lite/core/mir/io_complement_pass.cc
+152
-18
paddle/fluid/lite/core/mir/io_complement_pass.h
paddle/fluid/lite/core/mir/io_complement_pass.h
+30
-1
paddle/fluid/lite/core/mir/io_copy_kernel_pick_pass.cc
paddle/fluid/lite/core/mir/io_copy_kernel_pick_pass.cc
+74
-0
paddle/fluid/lite/core/mir/node.h
paddle/fluid/lite/core/mir/node.h
+38
-10
paddle/fluid/lite/core/mir/passes.h
paddle/fluid/lite/core/mir/passes.h
+2
-0
paddle/fluid/lite/core/mir/ssa_graph.cc
paddle/fluid/lite/core/mir/ssa_graph.cc
+138
-0
paddle/fluid/lite/core/mir/ssa_graph.h
paddle/fluid/lite/core/mir/ssa_graph.h
+64
-82
paddle/fluid/lite/core/mir/static_kernel_pick_pass.cc
paddle/fluid/lite/core/mir/static_kernel_pick_pass.cc
+3
-2
paddle/fluid/lite/core/mir/static_kernel_pick_pass.h
paddle/fluid/lite/core/mir/static_kernel_pick_pass.h
+19
-2
paddle/fluid/lite/core/mir/variable_place_inference_pass.cc
paddle/fluid/lite/core/mir/variable_place_inference_pass.cc
+1
-1
paddle/fluid/lite/core/mir/variable_place_inference_pass.h
paddle/fluid/lite/core/mir/variable_place_inference_pass.h
+27
-20
paddle/fluid/lite/core/op_lite.cc
paddle/fluid/lite/core/op_lite.cc
+11
-6
paddle/fluid/lite/core/op_lite.h
paddle/fluid/lite/core/op_lite.h
+47
-32
paddle/fluid/lite/core/op_registry.cc
paddle/fluid/lite/core/op_registry.cc
+41
-14
paddle/fluid/lite/core/op_registry.h
paddle/fluid/lite/core/op_registry.h
+92
-59
paddle/fluid/lite/core/optimizer.cc
paddle/fluid/lite/core/optimizer.cc
+29
-0
paddle/fluid/lite/core/optimizer.h
paddle/fluid/lite/core/optimizer.h
+25
-1
paddle/fluid/lite/core/program.h
paddle/fluid/lite/core/program.h
+14
-8
paddle/fluid/lite/core/target_wrapper.h
paddle/fluid/lite/core/target_wrapper.h
+51
-11
paddle/fluid/lite/core/type_system.cc
paddle/fluid/lite/core/type_system.cc
+29
-8
paddle/fluid/lite/core/type_system.h
paddle/fluid/lite/core/type_system.h
+58
-22
paddle/fluid/lite/core/types.cc
paddle/fluid/lite/core/types.cc
+3
-0
paddle/fluid/lite/core/types.h
paddle/fluid/lite/core/types.h
+21
-2
paddle/fluid/lite/cuda/target_wrapper.cc
paddle/fluid/lite/cuda/target_wrapper.cc
+1
-11
paddle/fluid/lite/kernels/cuda/CMakeLists.txt
paddle/fluid/lite/kernels/cuda/CMakeLists.txt
+3
-1
paddle/fluid/lite/kernels/cuda/io_copy_compute.cc
paddle/fluid/lite/kernels/cuda/io_copy_compute.cc
+24
-3
paddle/fluid/lite/kernels/cuda/mul_compute.cc
paddle/fluid/lite/kernels/cuda/mul_compute.cc
+19
-0
paddle/fluid/lite/kernels/cuda/mul_compute.h
paddle/fluid/lite/kernels/cuda/mul_compute.h
+21
-2
paddle/fluid/lite/kernels/host/fc_compute.cc
paddle/fluid/lite/kernels/host/fc_compute.cc
+2
-2
paddle/fluid/lite/kernels/host/feed_compute.cc
paddle/fluid/lite/kernels/host/feed_compute.cc
+3
-2
paddle/fluid/lite/kernels/host/fetch_compute.cc
paddle/fluid/lite/kernels/host/fetch_compute.cc
+3
-2
paddle/fluid/lite/kernels/host/mul_compute.cc
paddle/fluid/lite/kernels/host/mul_compute.cc
+1
-1
paddle/fluid/lite/kernels/host/relu_compute.h
paddle/fluid/lite/kernels/host/relu_compute.h
+1
-1
paddle/fluid/lite/kernels/host/scale_compute.cc
paddle/fluid/lite/kernels/host/scale_compute.cc
+1
-1
paddle/fluid/lite/operators/io_copy_op.cc
paddle/fluid/lite/operators/io_copy_op.cc
+4
-1
paddle/fluid/lite/operators/mul_op.h
paddle/fluid/lite/operators/mul_op.h
+0
-3
paddle/fluid/lite/utils/factory.h
paddle/fluid/lite/utils/factory.h
+3
-4
paddle/fluid/lite/utils/varient.h
paddle/fluid/lite/utils/varient.h
+1
-1
paddle/fluid/lite/utils/varient_test.cc
paddle/fluid/lite/utils/varient_test.cc
+2
-0
未找到文件。
paddle/fluid/framework/op_desc.h
浏览文件 @
621d1522
...
...
@@ -42,6 +42,7 @@ class OpDesc {
void
CopyFrom
(
const
OpDesc
&
op_desc
);
proto
::
OpDesc
*
Proto
();
const
proto
::
OpDesc
&
ReadonlyProto
()
const
{
return
desc_
;
}
std
::
string
Type
()
const
{
return
desc_
.
type
();
}
...
...
paddle/fluid/lite/api/CMakeLists.txt
浏览文件 @
621d1522
cc_library
(
cxx_api_lite SRCS cxx_api.cc DEPS scope_lite op_executor_lite host_kernels ops_lite optimizer_lite target_wrapper_host
)
cc_library
(
cxx_api_lite SRCS cxx_api.cc DEPS scope_lite op_executor_lite host_kernels ops_lite optimizer_lite target_wrapper_host
)
cc_test
(
test_cxx_api_lite SRCS cxx_api_test.cc DEPS cxx_api_lite model_parser_lite
)
if
(
LITE_WITH_CUDA
)
cc_library
(
cxx_api_lite_cuda SRCS cxx_api.cc DEPS scope_lite op_executor_lite host_kernels ops_lite optimizer_lite target_wrapper_host target_wrapper_cuda kernels_cuda
)
nv_test
(
test_cxx_api_lite_cuda SRCS cxx_api_test.cc DEPS cxx_api_lite_cuda model_parser_lite
)
endif
()
paddle/fluid/lite/api/cxx_api.h
浏览文件 @
621d1522
...
...
@@ -29,7 +29,7 @@ class Predictor {
public:
Predictor
()
{
scope_
=
std
::
make_shared
<
Scope
>
();
}
void
Build
(
const
std
::
string
&
model_path
,
void
Build
(
const
std
::
string
&
model_path
,
const
Place
&
prefer_place
,
const
std
::
vector
<
Place
>&
valid_places
)
{
framework
::
proto
::
ProgramDesc
prog
;
LoadModel
(
model_path
,
scope_
.
get
(),
&
prog
);
...
...
@@ -38,6 +38,7 @@ class Predictor {
Program
program
(
prog_desc
,
scope_
,
valid_places
);
Optimizer
optimizer
;
optimizer
.
KernelPickPreferPlace
(
prefer_place
);
core
::
KernelPickFactor
factor
;
factor
.
ConsiderTarget
();
optimizer
.
Run
(
std
::
move
(
program
),
valid_places
,
factor
);
...
...
paddle/fluid/lite/api/cxx_api_test.cc
浏览文件 @
621d1522
...
...
@@ -23,8 +23,21 @@ namespace lite {
TEST
(
CXXApi
,
test
)
{
lite
::
Predictor
predictor
;
#ifndef LITE_WITH_CUDA
std
::
vector
<
Place
>
valid_places
({
Place
{
TARGET
(
kHost
),
PRECISION
(
kFloat
)}});
#else
std
::
vector
<
Place
>
valid_places
({
Place
{
TARGET
(
kHost
),
PRECISION
(
kFloat
),
DATALAYOUT
(
kNCHW
)},
Place
{
TARGET
(
kCUDA
),
PRECISION
(
kFloat
),
DATALAYOUT
(
kNCHW
)},
Place
{
TARGET
(
kCUDA
),
PRECISION
(
kAny
),
DATALAYOUT
(
kNCHW
)},
Place
{
TARGET
(
kHost
),
PRECISION
(
kAny
),
DATALAYOUT
(
kNCHW
)},
Place
{
TARGET
(
kCUDA
),
PRECISION
(
kAny
),
DATALAYOUT
(
kAny
)},
Place
{
TARGET
(
kHost
),
PRECISION
(
kAny
),
DATALAYOUT
(
kAny
)},
});
#endif
predictor
.
Build
(
"/home/chunwei/project2/models/model2"
,
{
Place
{
TARGET
(
kHost
),
PRECISION
(
kFloat
)}}
);
Place
{
TARGET
(
kCUDA
),
PRECISION
(
kFloat
)},
valid_places
);
auto
*
input_tensor
=
predictor
.
GetInput
(
0
);
input_tensor
->
Resize
({
100
,
100
});
...
...
@@ -54,8 +67,15 @@ USE_LITE_OP(fc);
USE_LITE_OP
(
scale
);
USE_LITE_OP
(
feed
);
USE_LITE_OP
(
fetch
);
USE_LITE_KERNEL
(
fc
,
kHost
,
kFloat
,
def
);
USE_LITE_KERNEL
(
mul
,
kHost
,
kFloat
,
def
);
USE_LITE_KERNEL
(
scale
,
kHost
,
kFloat
,
def
);
USE_LITE_KERNEL
(
feed
,
kHost
,
kFloat
,
def
);
USE_LITE_KERNEL
(
fetch
,
kHost
,
kFloat
,
def
);
USE_LITE_OP
(
io_copy
);
USE_LITE_KERNEL
(
fc
,
kHost
,
kFloat
,
kNCHW
,
def
);
USE_LITE_KERNEL
(
mul
,
kHost
,
kFloat
,
kNCHW
,
def
);
USE_LITE_KERNEL
(
scale
,
kHost
,
kFloat
,
kNCHW
,
def
);
USE_LITE_KERNEL
(
feed
,
kHost
,
kAny
,
kAny
,
def
);
USE_LITE_KERNEL
(
fetch
,
kHost
,
kAny
,
kAny
,
def
);
#ifdef LITE_WITH_CUDA
USE_LITE_KERNEL
(
mul
,
kCUDA
,
kFloat
,
kNCHW
,
def
);
USE_LITE_KERNEL
(
io_copy
,
kCUDA
,
kAny
,
kAny
,
host_to_device
);
USE_LITE_KERNEL
(
io_copy
,
kCUDA
,
kAny
,
kAny
,
device_to_host
);
#endif
paddle/fluid/lite/core/CMakeLists.txt
浏览文件 @
621d1522
...
...
@@ -27,5 +27,6 @@ cc_test(test_tensor_lite SRCS tensor_test.cc)
cc_test
(
test_op_executor_lite SRCS op_executor_test.cc DEPS op_executor_lite ops_lite host_kernels
)
cc_test
(
test_type_system SRCS type_system_test.cc DEPS type_system
)
cc_test
(
test_optimizer_lite SRCS optimizer_test.cc DEPS mir_pass_manager program_fake_utils mir_passes
)
cc_test
(
test_types_lite SRCS types_test.cc DEPS types_lite
)
add_subdirectory
(
mir
)
paddle/fluid/lite/core/kernel.cc
浏览文件 @
621d1522
...
...
@@ -17,6 +17,13 @@
namespace
paddle
{
namespace
lite
{
std
::
string
KernelBase
::
summary
()
const
{
std
::
stringstream
ss
;
ss
<<
op_type
()
<<
":"
<<
TargetToStr
(
target
())
<<
"/"
<<
PrecisionToStr
(
precision
())
<<
"/"
<<
DataLayoutToStr
(
layout
());
return
ss
.
str
();
}
bool
ParamTypeRegistry
::
KeyCmp
::
operator
()(
const
ParamTypeRegistry
::
key_t
&
a
,
const
ParamTypeRegistry
::
key_t
&
b
)
const
{
...
...
paddle/fluid/lite/core/kernel.h
浏览文件 @
621d1522
...
...
@@ -17,6 +17,7 @@
#include <map>
#include <memory>
#include <set>
#include <sstream>
#include <string>
#include <vector>
#include "paddle/fluid/framework/op_desc.h"
...
...
@@ -34,49 +35,100 @@ namespace lite {
// different targets.
class
KernelBase
{
public:
// type_infer_handler is used to inference a output type by considering the
// input types in the type system.
using
type_infer_handler_t
=
std
::
function
<
const
Type
*
(
const
std
::
map
<
std
::
string
,
const
Type
*>&
input_types
,
const
std
::
string
&
out_arg
)
>
;
virtual
void
Run
()
=
0
;
void
SetContext
(
std
::
unique_ptr
<
KernelContext
>&&
ctx
)
{
context_
=
std
::
move
(
ctx
);
}
template
<
typename
T
>
void
SetParam
(
T
param
)
{
param_
.
set
<
T
>
(
param
);
}
template
<
typename
P
>
P
&
Param
()
const
{
return
param_
.
get
<
P
>
();
}
// This is used in the kernels that takes 'kAny' places and inference the
// output place. For `ScaleCompute` and `IoCopyCompute`, their input types are
// declared as 'kAny' in some Place field, and the output is also `kAny`, but
// when in real execution, when takes some non-kAny type as input, the
// output's kAny-fields can be determained. For example, when the
// `ScaleCompute` takes `TensorFp32NCHWTy` as input, its output should be also
// `TensorFp32NCHWTy`. This type inference rule is different for each kernel,
// so we make it a virtual method.
// One can custom this handler to make a specific type inference rule for a
// kernel, or leave the default to force the kernel use the system's
// type-inference rules.
virtual
std
::
unique_ptr
<
type_infer_handler_t
>
GetTypeInferHandler
()
{
return
nullptr
;
}
void
set_op_type
(
const
std
::
string
&
type
)
{
op_type_
=
type
;
}
const
std
::
string
&
op_type
()
const
{
return
op_type_
;
}
void
Torch
()
{}
// Get input declaration type.
const
Type
*
GetInputDeclType
(
const
std
::
string
&
arg_name
)
{
CHECK
(
!
op_type_
.
empty
())
<<
"op_type should be set first"
;
const
auto
*
type
=
ParamTypeRegistry
::
Global
().
RetrieveInArgument
(
place
(),
GenParamTypeKey
(),
arg_name
);
CHECK
(
type
)
<<
"no type registered for kernel ["
<<
op_type_
<<
"] input argument ["
<<
arg_name
<<
"]"
<<
" with key "
<<
GenParamTypeKey
();
return
type
->
type
;
}
// Get output declaration type.
const
Type
*
GetOutputDeclType
(
const
std
::
string
&
arg_name
)
{
CHECK
(
!
op_type_
.
empty
())
<<
"op_type should be set first"
;
const
auto
*
type
=
ParamTypeRegistry
::
Global
().
RetrieveOutArgument
(
place
(),
GenParamTypeKey
(),
arg_name
);
CHECK
(
type
)
<<
"no type registered for kernel ["
<<
op_type_
<<
"] output argument ["
<<
arg_name
<<
"]"
;
return
type
->
type
;
}
void
set_alias
(
const
std
::
string
&
x
)
{
alias_
=
x
;
LOG
(
INFO
)
<<
"kernel "
<<
op_type
()
<<
" setting alias "
<<
alias
();
}
const
std
::
string
&
alias
()
const
{
return
alias_
;
}
virtual
Place
place
()
const
=
0
;
virtual
TargetType
target
()
const
=
0
;
virtual
PrecisionType
precision
()
const
=
0
;
virtual
DataLayoutType
layout
()
const
=
0
;
const
KernelContext
*
context
()
const
{
return
context_
.
get
();
}
virtual
std
::
string
name
()
const
=
0
;
virtual
~
KernelBase
()
=
default
;
// Short human-readable document.
std
::
string
summary
()
const
;
// Long human-readable document.
virtual
std
::
string
doc
()
const
{
return
""
;
}
std
::
string
DebugString
()
const
{
std
::
string
GenParamTypeKey
()
const
{
std
::
stringstream
ss
;
ss
<<
op_type
()
<<
":"
<<
TargetToStr
(
target
())
<<
"/"
<<
PrecisionToStr
(
precision
())
<<
"/"
<<
DataLayoutToStr
(
layout
())
;
LOG
(
INFO
)
<<
"alias : "
<<
alias_
;
ss
<<
op_type
()
<<
"/"
<<
alias_
;
return
ss
.
str
();
}
virtual
~
KernelBase
()
=
default
;
protected:
std
::
unique_ptr
<
KernelContext
>
context_
;
mutable
operators
::
param_t
param_
;
// The corresponding op type.
std
::
string
op_type_
;
std
::
string
op_type_
{};
std
::
string
alias_
{};
};
// Light-weight kernel implementation.
...
...
paddle/fluid/lite/core/memory.cc
浏览文件 @
621d1522
...
...
@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "memory.h"
#include "
paddle/fluid/lite/core/
memory.h"
namespace
paddle
{
namespace
framework
{
...
...
paddle/fluid/lite/core/memory.h
浏览文件 @
621d1522
...
...
@@ -14,7 +14,7 @@
#pragma once
#include <glog/logging.h>
#include "target_wrapper.h"
#include "
paddle/fluid/lite/core/
target_wrapper.h"
namespace
paddle
{
namespace
lite
{
...
...
@@ -26,9 +26,12 @@ static void* TargetMalloc(TargetType target, size_t size) {
case
TargetType
::
kX86
:
data
=
TargetWrapper
<
TARGET
(
kHost
)
>::
Malloc
(
size
);
break
;
#ifdef LITE_WITH_CUDA
case
TargetType
::
kCUDA
:
data
=
TargetWrapper
<
TARGET
(
kCUDA
)
>::
Malloc
(
size
);
data
=
TargetWrapper
<
TARGET
(
kCUDA
),
cudaStream_t
,
cudaEvent_t
>::
Malloc
(
size
);
break
;
#endif // LITE_WITH_CUDA
default:
LOG
(
FATAL
)
<<
"Unknown supported target "
<<
TargetToStr
(
target
);
}
...
...
paddle/fluid/lite/core/mir/CMakeLists.txt
浏览文件 @
621d1522
...
...
@@ -7,9 +7,12 @@ cc_library(mir_passes
SRCS static_kernel_pick_pass.cc
variable_place_inference_pass.cc
io_complement_pass.cc
io_copy_kernel_pick_pass.cc
graph_visualize_pass.cc
generate_program_pass.cc
argument_type_display_pass.cc
demo_pass.cc
runtime_context_assign_pass.cc
DEPS mir_pass types_lite
)
cc_test
(
test_mir_pass_manager SRCS pass_manager_test.cc DEPS mir_pass_manager mir_passes
)
...
...
paddle/fluid/lite/core/mir/argument_type_display_pass.cc
0 → 100644
浏览文件 @
621d1522
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/lite/core/mir/pass.h"
#include "paddle/fluid/lite/core/mir/pass_registry.h"
namespace
paddle
{
namespace
lite
{
namespace
mir
{
class
ArgumentTypeDisplayPass
:
public
DebugPass
{
public:
void
Apply
(
std
::
unique_ptr
<
mir
::
SSAGraph
>&
graph
)
override
{
LOG
(
INFO
)
<<
"== Argument types =="
;
for
(
auto
&
node
:
graph
->
mutable_nodes
())
{
if
(
!
node
.
IsArgument
())
continue
;
auto
*
type
=
node
.
AsArgument
().
type
;
if
(
type
)
{
LOG
(
INFO
)
<<
"* ARG "
<<
node
.
AsArgument
().
name
<<
" type: "
<<
*
type
;
}
else
{
LOG
(
INFO
)
<<
"* ARG "
<<
node
.
AsArgument
().
name
<<
" type: UNK"
;
}
}
LOG
(
INFO
)
<<
"---------------------"
;
}
};
}
// namespace mir
}
// namespace lite
}
// namespace paddle
REGISTER_MIR_PASS
(
argument_type_display_pass
,
paddle
::
lite
::
mir
::
ArgumentTypeDisplayPass
);
paddle/fluid/lite/core/mir/generate_program_pass.cc
浏览文件 @
621d1522
...
...
@@ -13,6 +13,7 @@
// limitations under the License.
#include "paddle/fluid/lite/core/mir/generate_program_pass.h"
#include "paddle/fluid/lite/core/mir/graph_visualize_pass.h"
#include "paddle/fluid/lite/core/mir/pass_registry.h"
namespace
paddle
{
...
...
@@ -20,9 +21,11 @@ namespace lite {
namespace
mir
{
void
GenerateProgramPass
::
Apply
(
std
::
unique_ptr
<
mir
::
SSAGraph
>&
graph
)
{
LOG
(
INFO
)
<<
"final program
\n
"
<<
Visualize
(
graph
.
get
());
for
(
auto
&
item
:
graph
->
InstructTopologicalOrder
())
{
if
(
item
->
IsInstruct
())
{
auto
&
instruct
=
item
->
AsInstruct
();
LOG
(
INFO
)
<<
instruct
;
insts_
.
emplace_back
(
instruct
.
op
,
std
::
move
(
instruct
.
valid_kernels
.
front
()));
}
...
...
paddle/fluid/lite/core/mir/io_complement_pass.cc
浏览文件 @
621d1522
...
...
@@ -13,6 +13,7 @@
// limitations under the License.
#include "paddle/fluid/lite/core/mir/io_complement_pass.h"
#include "paddle/fluid/lite/core/mir/graph_visualize_pass.h"
#include "paddle/fluid/lite/core/mir/pass_registry.h"
namespace
paddle
{
...
...
@@ -21,28 +22,161 @@ namespace mir {
void
IoComplementPass
::
Apply
(
std
::
unique_ptr
<
mir
::
SSAGraph
>&
graph
)
{
// Start from inputs of the graph, those should have place set.
std
::
list
<
Node
*>
nodes
;
for
(
auto
&
node
:
graph
->
mutable_nodes
())
{
if
(
!
node
.
IsInstruct
())
continue
;
auto
&
inst
=
node
.
AsInstruct
();
// inputs
for
(
auto
*
in
:
node
.
inlinks
)
{
CHECK
(
in
->
IsArgument
());
auto
name
=
in
->
AsArgument
().
name
;
std
::
string
tmp
;
CHECK
(
inst
.
op_info
->
GetInputArgname
(
name
,
&
tmp
));
auto
type
=
ParamTypeRegistry
::
Global
().
Retrieve
<
ParamTypeRegistry
::
IO
::
kInput
>
(
inst
.
place
,
inst
.
op_type
,
tmp
);
CHECK
(
type
)
<<
"no param type found for "
<<
inst
.
op_type
<<
":"
<<
name
<<
" "
<<
inst
.
place
;
CHECK
(
type
->
type
);
CHECK
(
in
->
AsArgument
().
type
);
if
(
!
TypeCompatible
(
*
type
->
type
,
*
in
->
AsArgument
().
type
))
{
LOG
(
INFO
)
<<
"found IO unmatched tensor: "
<<
in
->
AsArgument
().
name
;
nodes
.
push_back
(
&
node
);
}
CHECK
(
!
valid_places_
.
empty
());
for
(
auto
&
node
:
nodes
)
{
if
(
!
node
->
IsInstruct
())
continue
;
auto
inlinks
=
node
->
inlinks
;
for
(
auto
*
in
:
inlinks
)
{
ComplementInputs
(
graph
.
get
(),
node
,
in
);
}
}
// PickIoCopyKernel(graph.get());
LOG
(
INFO
)
<<
"
\n
"
<<
Visualize
(
graph
.
get
());
}
void
IoComplementPass
::
ComplementInputs
(
SSAGraph
*
graph
,
Node
*
inst_node
,
Node
*
in
)
{
// If this input is out of date.
if
(
inst_node
->
inlinks
.
end
()
==
std
::
find
(
inst_node
->
inlinks
.
begin
(),
inst_node
->
inlinks
.
end
(),
in
))
return
;
CHECK
(
inst_node
->
IsInstruct
());
auto
&
inst
=
inst_node
->
AsInstruct
();
CHECK
(
in
->
IsRoleSet
());
CHECK
(
in
->
IsArgument
());
auto
in_arg_name
=
in
->
AsArgument
().
name
;
std
::
string
tmp
;
CHECK
(
inst
.
op_info
()
->
GetInputArgname
(
in_arg_name
,
&
tmp
));
auto
decl_arg_type
=
inst
.
picked_kernel
().
GetInputDeclType
(
tmp
);
CHECK
(
in
->
AsArgument
().
type
);
if
(
!
TypeCompatibleTo
(
*
in
->
AsArgument
().
type
,
*
decl_arg_type
))
{
LOG
(
INFO
)
<<
"found IO unmatched tensor: "
<<
in
->
AsArgument
().
name
<<
" for kernel "
<<
inst
.
op
->
DebugString
()
<<
" "
<<
*
in
->
AsArgument
().
type
<<
" -> "
<<
*
decl_arg_type
;
// Add an IoCopy instruction to make the input compatible with other dist.
AddIoCopyInst
(
*
in
->
AsArgument
().
type
,
*
decl_arg_type
,
in
->
AsArgument
().
name
,
graph
,
inst_node
,
valid_places_
);
}
}
void
UpdateOpdescInputName
(
framework
::
OpDesc
*
desc
,
const
std
::
string
&
old_arg_name
,
const
std
::
string
&
new_arg_name
)
{
for
(
auto
&
item
:
*
desc
->
Proto
()
->
mutable_inputs
())
{
for
(
int
i
=
0
;
i
<
item
.
mutable_arguments
()
->
size
();
i
++
)
{
auto
*
arg
=
item
.
mutable_arguments
(
i
);
if
(
*
arg
==
old_arg_name
)
{
*
arg
=
new_arg_name
;
}
}
}
}
void
IoComplementPass
::
AddIoCopyInst
(
const
Type
&
from
,
const
Type
&
to
,
const
std
::
string
&
var
,
SSAGraph
*
graph
,
Node
*
inst_node
,
const
std
::
vector
<
Place
>&
valid_places
)
{
CHECK
(
!
valid_places
.
empty
())
<<
"valid_place should be set"
;
// var -> new_transform_op -> new_var -> inst
// So there will be a new Argument node and a new IoCopy Instruct Node.
auto
node_id
=
[
&
]
{
return
graph
->
nodes
().
size
();
};
auto
io_copy_output_name
=
var
+
"/trans/"
+
std
::
to_string
(
node_id
());
auto
*
io_copy_output_arg
=
graph
->
NewArgumentNode
(
io_copy_output_name
);
auto
*
io_copy_inst
=
graph
->
NewInstructNode
();
// create Op and kernels.
auto
io_copy_op
=
LiteOpRegistry
::
Global
().
Create
(
"io_copy"
);
// CHECK(io_copy_op);
// Create the new var manually.
inst_node
->
AsInstruct
().
op
->
scope
()
->
Var
(
io_copy_output_name
);
// Create IoCopy Instruction.
framework
::
OpDesc
op_desc
;
op_desc
.
SetType
(
"io_copy"
);
op_desc
.
SetInput
(
"Input"
,
{
var
});
op_desc
.
SetOutput
(
"Out"
,
{
io_copy_output_name
});
op_desc
.
Flush
();
io_copy_op
->
Attach
(
op_desc
,
inst_node
->
AsInstruct
().
op
->
scope
());
auto
kernels
=
io_copy_op
->
CreateKernels
(
valid_places
);
io_copy_inst
->
AsInstruct
(
"io_copy"
,
std
::
move
(
kernels
),
io_copy_op
);
// Remove the old link
RemoveDirectedLink
(
graph
->
Argument
(
var
),
inst_node
);
// Update the original instruction OpDesc.
// Update its input to the io_copy_output_name
auto
&
inst
=
inst_node
->
AsInstruct
();
auto
inst_program_desc
=
inst
.
op_info
()
->
desc
();
// Add new link, var -> new_inst, new_inst->newarg, newarg->inst
DirectedLink
(
graph
->
Argument
(
var
),
io_copy_inst
);
DirectedLink
(
io_copy_inst
,
io_copy_output_arg
);
DirectedLink
(
io_copy_output_arg
,
inst_node
);
// reset opdesc and update kernel information
auto
desc_dummy
=
inst_node
->
AsInstruct
().
op
->
op_info
()
->
desc
();
UpdateInputTo
(
&
desc_dummy
,
var
,
io_copy_output_name
);
framework
::
OpDesc
desc_fake
(
desc_dummy
,
nullptr
);
inst_node
->
AsInstruct
().
op
->
Attach
(
desc_fake
,
inst_node
->
AsInstruct
().
op
->
scope
());
std
::
string
tmp
;
if
(
inst_node
->
AsInstruct
().
op_info
()
->
GetInputArgname
(
"a"
,
&
tmp
))
{
CHECK
(
false
)
<<
"get old a "
<<
tmp
;
}
for
(
auto
&
kernel
:
inst_node
->
AsInstruct
().
valid_kernels
)
{
inst_node
->
AsInstruct
().
op
->
AttachKernel
(
kernel
.
get
());
}
graph
->
CheckValid
();
}
void
IoComplementPass
::
PickIoCopyKernel
(
SSAGraph
*
graph
)
{
for
(
auto
&
node
:
graph
->
mutable_nodes
())
{
if
(
node
.
IsInstruct
()
&&
node
.
AsInstruct
().
op_type
==
"io_copy"
)
{
auto
&
kernels
=
node
.
AsInstruct
().
valid_kernels
;
CHECK
(
!
kernels
.
empty
())
<<
"No valid kernels found for IoCopy Op"
;
for
(
auto
&
kernel
:
kernels
)
{
CHECK_EQ
(
node
.
inlinks
.
size
(),
1UL
);
CHECK_EQ
(
node
.
outlinks
.
size
(),
1UL
);
auto
*
inty
=
node
.
inlinks
.
front
()
->
AsArgument
().
type
;
auto
*
outy
=
node
.
outlinks
.
front
()
->
AsArgument
().
type
;
const
Type
*
in_arg_ty
=
kernel
->
GetInputDeclType
(
"Input"
);
if
(
TypeCompatibleTo
(
*
inty
,
*
in_arg_ty
))
{
const
Type
*
out_arg_ty
=
kernel
->
GetOutputDeclType
(
"Out"
);
// Both the input and output type matches, remove other kernels
// directly.
if
(
out_arg_ty
->
target
()
==
outy
->
target
())
{
LOG
(
INFO
)
<<
"get a IOCopy kernel"
;
auto
x
=
std
::
move
(
kernel
);
kernels
.
clear
();
kernels
.
emplace_back
(
std
::
move
(
x
));
break
;
}
}
}
}
}
// Check the compatiblity.
}
void
IoComplementPass
::
SetValidPlaces
(
const
std
::
vector
<
Place
>&
valid_places
)
{
CHECK
(
!
valid_places
.
empty
());
valid_places_
=
valid_places
;
}
}
// namespace mir
...
...
paddle/fluid/lite/core/mir/io_complement_pass.h
浏览文件 @
621d1522
...
...
@@ -15,18 +15,47 @@
#pragma once
#include "paddle/fluid/lite/core/mir/pass.h"
#include "paddle/fluid/lite/core/op_registry.h"
namespace
paddle
{
namespace
lite
{
namespace
mir
{
static
void
UpdateInputTo
(
framework
::
proto
::
OpDesc
*
desc
,
const
std
::
string
&
from
,
const
std
::
string
&
to
)
{
for
(
auto
&
item
:
*
desc
->
mutable_inputs
())
{
for
(
auto
&
input
:
*
item
.
mutable_arguments
())
{
if
(
input
==
from
)
{
LOG
(
INFO
)
<<
"** update input argument from "
<<
from
<<
" to "
<<
to
;
input
=
to
;
}
}
}
}
/*
* IoComplementPass complement the necessary instruction to make data
* transferring or transformation between different places.
*/
class
IoComplementPass
:
public
ProgramPass
{
public:
void
Apply
(
std
::
unique_ptr
<
mir
::
SSAGraph
>
&
graph
)
override
;
void
Apply
(
std
::
unique_ptr
<
mir
::
SSAGraph
>&
graph
)
override
;
void
ComplementInputs
(
SSAGraph
*
graph
,
Node
*
inst_node
,
Node
*
in
);
void
AddIoCopyInst
(
const
Type
&
from
,
const
Type
&
to
,
const
std
::
string
&
var
,
SSAGraph
*
graph
,
Node
*
inst_node
,
const
std
::
vector
<
Place
>&
valid_places
);
void
SetValidPlaces
(
const
std
::
vector
<
Place
>&
valid_places
);
// Pick the right kernel of IoCopy considering the input and output Type.
void
PickIoCopyKernel
(
SSAGraph
*
graph
);
const
std
::
vector
<
Place
>&
valid_places
()
const
{
return
valid_places_
;
};
private:
std
::
vector
<
Place
>
valid_places_
;
};
}
// namespace mir
...
...
paddle/fluid/lite/core/mir/io_copy_kernel_pick_pass.cc
0 → 100644
浏览文件 @
621d1522
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/lite/core/mir/pass.h"
#include "paddle/fluid/lite/core/mir/pass_registry.h"
namespace
paddle
{
namespace
lite
{
namespace
mir
{
class
IoCopyKernelPickPass
:
public
InstructionPass
{
public:
void
Apply
(
std
::
unique_ptr
<
mir
::
SSAGraph
>&
graph
)
override
{
for
(
auto
&
node
:
graph
->
mutable_nodes
())
{
if
(
!
node
.
IsInstruct
())
continue
;
auto
&
inst
=
node
.
AsInstruct
();
if
(
inst
.
op_type
!=
"io_copy"
)
continue
;
LOG
(
INFO
)
<<
"....> picking a IO COPY kernel"
;
auto
&
kernels
=
node
.
AsInstruct
().
valid_kernels
;
CHECK
(
!
kernels
.
empty
())
<<
"No valid kernels found for IoCopy Op"
;
const
auto
*
inty
=
node
.
inlinks
.
front
()
->
AsArgument
().
type
;
const
auto
*
outy
=
node
.
outlinks
.
front
()
->
AsArgument
().
type
;
LOG
(
INFO
)
<<
"input type "
<<
*
inty
;
LOG
(
INFO
)
<<
"output type "
<<
*
outy
;
bool
is_found
=
false
;
LOG
(
INFO
)
<<
"kernels size "
<<
kernels
.
size
();
for
(
auto
&
kernel
:
kernels
)
{
CHECK_EQ
(
node
.
inlinks
.
size
(),
1UL
);
CHECK_EQ
(
node
.
outlinks
.
size
(),
1UL
);
const
Type
*
in_arg_ty
=
kernel
->
GetInputDeclType
(
"Input"
);
const
Type
*
out_arg_ty
=
kernel
->
GetOutputDeclType
(
"Out"
);
LOG
(
INFO
)
<<
"checking kernel candidate "
<<
*
in_arg_ty
<<
"->"
<<
*
out_arg_ty
;
if
(
inty
->
target
()
==
in_arg_ty
->
target
())
{
// Both the input and output type matches, remove other kernels
// directly.
if
(
out_arg_ty
->
target
()
==
outy
->
target
())
{
LOG
(
INFO
)
<<
"get a IOCopy kernel"
;
auto
x
=
std
::
move
(
kernel
);
kernels
.
clear
();
kernels
.
emplace_back
(
std
::
move
(
x
));
is_found
=
true
;
break
;
}
}
}
CHECK
(
is_found
)
<<
"Can't find a IoCopy kernel for IO: "
<<
*
inty
<<
"->"
<<
*
outy
;
}
}
};
}
// namespace mir
}
// namespace lite
}
// namespace paddle
REGISTER_MIR_PASS
(
io_copy_kernel_pick_pass
,
paddle
::
lite
::
mir
::
IoCopyKernelPickPass
);
paddle/fluid/lite/core/mir/node.h
浏览文件 @
621d1522
...
...
@@ -34,30 +34,43 @@ class Node {
Node
()
=
default
;
enum
class
Role
{
kUnk
=
-
1
,
kArgument
,
kArgument
=
0
,
kInstruct
,
kNumRoles
/*should be last*/
kNumRoles
,
/*should be last*/
kUnk
,
};
struct
Instruct
{
std
::
string
op_type
;
Place
place
;
// The kernel instances this Instruct contains.
std
::
vector
<
std
::
unique_ptr
<
KernelBase
>>
valid_kernels
;
std
::
shared_ptr
<
OpInfo
>
op_info
;
// TODO(Superjomn) make this a shared_ptr for resource safety.
std
::
shared_ptr
<
OpLite
>
op
;
// we hold op to run InferShape
const
OpInfo
*
op_info
()
{
CHECK
(
op
);
return
op
->
op_info
();
}
Place
place
()
const
{
CHECK
(
!
valid_kernels
.
empty
());
return
valid_kernels
.
front
()
->
place
();
}
KernelBase
&
picked_kernel
()
{
CHECK
(
!
valid_kernels
.
empty
());
return
*
valid_kernels
.
front
();
}
friend
std
::
ostream
&
operator
<<
(
std
::
ostream
&
os
,
const
Instruct
&
other
)
{
os
<<
"Instruct "
<<
other
.
op_type
<<
" "
<<
other
.
place
();
return
os
;
}
};
struct
Argument
{
std
::
string
name
;
const
Type
*
type
;
const
Type
*
type
{}
;
// Weight is a special kind of argument, it is marked as weight explicitly
// so that some weight related optimization can take place.
bool
is_weight
{
false
};
...
...
@@ -71,13 +84,11 @@ class Node {
Instruct
&
AsInstruct
(
const
std
::
string
&
op_type
,
std
::
vector
<
std
::
unique_ptr
<
KernelBase
>>&&
kernels
,
const
std
::
shared_ptr
<
OpLite
>&
op
,
const
std
::
shared_ptr
<
lite
::
OpInfo
>&
op_info
)
{
const
std
::
shared_ptr
<
OpLite
>&
op
)
{
auto
&
x
=
AsInstruct
();
x
.
op_type
=
op_type
;
x
.
op
=
op
;
x
.
valid_kernels
=
std
::
move
(
kernels
);
x
.
op_info
=
op_info
;
return
x
;
}
...
...
@@ -100,8 +111,25 @@ class Node {
instruct_
.
reset
(
new
Instruct
);
return
*
instruct_
;
}
friend
std
::
ostream
&
operator
<<
(
std
::
ostream
&
os
,
Node
&
other
)
{
os
<<
static_cast
<
int
>
(
other
.
role_
)
<<
" "
;
if
(
!
other
.
IsRoleSet
())
{
os
<<
"Unk role node"
;
}
if
(
other
.
IsArgument
())
{
auto
&
arg
=
other
.
AsArgument
();
os
<<
"Argument "
<<
arg
.
name
;
}
if
(
other
.
IsInstruct
())
{
auto
&
arg
=
other
.
AsInstruct
();
os
<<
"Instruct "
<<
arg
.
op_type
;
}
return
os
;
}
// Check roles.
bool
IsRoleSet
()
const
{
return
role_
=
=
Role
::
kUnk
;
}
bool
IsRoleSet
()
const
{
return
role_
!
=
Role
::
kUnk
;
}
bool
IsInstruct
()
const
{
return
role_
==
Role
::
kInstruct
;
}
bool
IsArgument
()
const
{
return
role_
==
Role
::
kArgument
;
}
...
...
paddle/fluid/lite/core/mir/passes.h
浏览文件 @
621d1522
...
...
@@ -26,3 +26,5 @@ USE_MIR_PASS(static_kernel_pick_pass);
USE_MIR_PASS
(
variable_place_inference_pass
);
USE_MIR_PASS
(
io_complement_pass
);
USE_MIR_PASS
(
generate_program_pass
);
USE_MIR_PASS
(
io_copy_kernel_pick_pass
);
USE_MIR_PASS
(
argument_type_display_pass
);
paddle/fluid/lite/core/mir/ssa_graph.cc
浏览文件 @
621d1522
...
...
@@ -89,6 +89,144 @@ std::vector<mir::Node *> SSAGraph::InstructTopologicalOrder() {
return
res
;
}
void
SSAGraph
::
GraphCreateTmpVarNodes
(
const
Program
&
program
)
{
for
(
const
auto
&
name
:
program
.
tmp_vars
)
{
LOG
(
INFO
)
<<
"create arg node "
<<
name
;
node_storage_
.
emplace_back
();
auto
&
new_node
=
node_storage_
.
back
();
new_node
.
AsArgument
(
name
);
arguments_
[
name
]
=
&
new_node
;
}
}
void
SSAGraph
::
GraphCreateWeightVarNodes
(
const
Program
&
program
)
{
// create weight nodes.
for
(
const
auto
&
name
:
program
.
weights
)
{
LOG
(
INFO
)
<<
"create arg node "
<<
name
;
node_storage_
.
emplace_back
();
auto
&
new_node
=
node_storage_
.
back
();
new_node
.
AsArgument
(
name
);
arguments_
[
name
]
=
&
new_node
;
}
}
Node
*
SSAGraph
::
GraphCreateInstructNode
(
const
Program
&
program
,
const
std
::
shared_ptr
<
OpLite
>
&
op
,
const
std
::
vector
<
Place
>
&
valid_places
)
{
node_storage_
.
emplace_back
();
// TODO(Superjomn) remove one valid_places here.
op
->
SetValidPlaces
(
valid_places
);
auto
&
new_node
=
node_storage_
.
back
();
auto
kernels
=
op
->
CreateKernels
(
valid_places
);
node_storage_
.
back
().
AsInstruct
(
op
->
op_type_
,
std
::
move
(
kernels
),
op
);
CHECK
(
new_node
.
inlinks
.
empty
())
<<
"duplicate Build found"
;
CHECK
(
new_node
.
outlinks
.
empty
())
<<
"duplicate Build found"
;
return
&
node_storage_
.
back
();
}
void
SSAGraph
::
Build
(
const
Program
&
program
,
const
std
::
vector
<
Place
>
&
valid_places
)
{
CHECK
(
node_storage_
.
empty
());
GraphCreateTmpVarNodes
(
program
);
GraphCreateWeightVarNodes
(
program
);
CHECK
(
CheckNodesRoleSet
());
for
(
auto
&
op
:
program
.
ops
)
{
auto
*
op_node
=
GraphCreateInstructNode
(
program
,
op
,
valid_places
);
LOG
(
INFO
)
<<
"checking op "
<<
op
->
op_type_
;
for
(
const
std
::
string
&
name
:
op
->
op_info
()
->
input_names
())
{
auto
*
arg
=
Argument
(
name
);
LOG
(
INFO
)
<<
"input "
<<
name
;
CHECK
(
arg
->
IsRoleSet
());
DirectedLink
(
arg
,
op_node
);
}
for
(
const
std
::
string
&
name
:
op
->
op_info
()
->
output_names
())
{
if
(
!
arguments_
.
count
(
name
))
{
NewArgumentNode
(
name
);
}
LOG
(
INFO
)
<<
"output "
<<
name
;
auto
*
arg
=
arguments_
.
at
(
name
);
CHECK
(
arg
->
IsRoleSet
());
DirectedLink
(
op_node
,
arg
);
}
CHECK
(
CheckLinksRoleSet
());
}
MarkArgumentWeights
(
program
);
CheckValid
();
}
mir
::
Node
*
SSAGraph
::
Argument
(
const
std
::
string
&
name
)
{
auto
it
=
arguments_
.
find
(
name
);
CHECK
(
it
!=
arguments_
.
end
())
<<
"no argument called "
<<
name
;
return
it
->
second
;
}
std
::
vector
<
mir
::
Node
*>
SSAGraph
::
inputs
()
{
std
::
vector
<
mir
::
Node
*>
res
;
for
(
auto
&
node
:
node_storage_
)
{
if
(
node
.
inlinks
.
empty
())
{
res
.
push_back
(
&
node
);
}
}
return
res
;
}
std
::
vector
<
mir
::
Node
*>
SSAGraph
::
outputs
()
{
std
::
vector
<
mir
::
Node
*>
res
;
for
(
auto
&
node
:
node_storage_
)
{
if
(
node
.
outlinks
.
empty
())
{
res
.
push_back
(
&
node
);
}
}
return
res
;
}
mir
::
Node
*
SSAGraph
::
RetrieveArgument
(
const
std
::
string
&
arg
)
{
auto
it
=
arguments_
.
find
(
arg
);
if
(
it
!=
arguments_
.
end
())
{
return
it
->
second
;
}
return
nullptr
;
}
bool
SSAGraph
::
CheckNodesRoleSet
()
{
for
(
auto
&
node
:
mutable_nodes
())
{
CHECK_OR_FALSE
(
node
.
IsRoleSet
());
}
return
true
;
}
bool
SSAGraph
::
CheckLinksRoleSet
()
{
for
(
auto
&
node
:
mutable_nodes
())
{
CHECK_OR_FALSE
(
node
.
IsRoleSet
());
if
(
!
node
.
IsInstruct
())
continue
;
for
(
auto
*
x
:
node
.
inlinks
)
{
CHECK_OR_FALSE
(
x
->
IsRoleSet
());
CHECK_OR_FALSE
(
x
->
IsArgument
());
}
for
(
auto
*
x
:
node
.
outlinks
)
{
CHECK_OR_FALSE
(
x
->
IsRoleSet
());
CHECK_OR_FALSE
(
x
->
IsArgument
());
}
}
return
true
;
}
Node
*
SSAGraph
::
NewArgumentNode
(
const
std
::
string
&
name
)
{
node_storage_
.
emplace_back
();
CHECK
(
!
arguments_
.
count
(
name
))
<<
"duplicate argument called "
<<
name
;
arguments_
[
name
]
=
&
node_storage_
.
back
();
node_storage_
.
back
().
AsArgument
(
name
);
return
&
node_storage_
.
back
();
}
Node
*
SSAGraph
::
NewInstructNode
()
{
node_storage_
.
emplace_back
();
return
&
node_storage_
.
back
();
}
}
// namespace mir
}
// namespace lite
}
// namespace paddle
paddle/fluid/lite/core/mir/ssa_graph.h
浏览文件 @
621d1522
...
...
@@ -35,104 +35,44 @@ class SSAGraph : GraphBase {
public:
// @param program: the op program
// @param valid_places: the valid places user set for the system.
void
Build
(
const
Program
&
program
,
const
std
::
vector
<
Place
>
&
valid_places
)
{
// create temporary nodes.
for
(
const
auto
&
name
:
program
.
tmp_vars
)
{
node_storage_
.
emplace_back
();
auto
&
new_node
=
node_storage_
.
back
();
auto
&
arg
=
new_node
.
AsArgument
();
arg
.
name
=
name
;
arguments_
[
name
]
=
&
new_node
;
}
// create weight nodes.
for
(
const
auto
&
name
:
program
.
weights
)
{
node_storage_
.
emplace_back
();
auto
&
new_node
=
node_storage_
.
back
();
auto
&
arg
=
new_node
.
AsArgument
();
arg
.
name
=
name
;
arguments_
[
name
]
=
&
new_node
;
}
void
Build
(
const
Program
&
program
,
const
std
::
vector
<
Place
>
&
valid_places
);
for
(
auto
&
op
:
program
.
ops
)
{
node_storage_
.
emplace_back
();
// TODO(Superjomn) remove one valid_places here.
op
->
SetValidPlaces
(
valid_places
);
auto
&
new_node
=
node_storage_
.
back
();
auto
kernels
=
op
->
CreateKernels
(
valid_places
);
node_storage_
.
back
().
AsInstruct
(
op
->
op_type_
,
std
::
move
(
kernels
),
op
,
op
->
op_info
());
CHECK
(
new_node
.
inlinks
.
empty
())
<<
"duplicate Build found"
;
CHECK
(
new_node
.
outlinks
.
empty
())
<<
"duplicate Build found"
;
// collect inputs and outputs
for
(
const
std
::
string
&
name
:
op
->
op_info
()
->
input_names
())
{
auto
*
arg
=
Argument
(
name
);
new_node
.
inlinks
.
push_back
(
arg
);
arg
->
outlinks
.
push_back
(
&
new_node
);
}
for
(
const
std
::
string
&
name
:
op
->
op_info
()
->
output_names
())
{
if
(
!
arguments_
.
count
(
name
))
{
node_storage_
.
emplace_back
();
auto
&
new_node
=
node_storage_
.
back
();
auto
&
arg
=
new_node
.
AsArgument
(
name
);
arg
.
name
=
name
;
arguments_
.
emplace
(
name
,
&
new_node
);
}
auto
*
arg
=
arguments_
.
at
(
name
);
new_node
.
outlinks
.
push_back
(
arg
);
arg
->
inlinks
.
push_back
(
&
new_node
);
}
}
MarkArgumentWeights
(
program
);
}
mir
::
Node
*
Argument
(
const
std
::
string
&
name
)
{
auto
it
=
arguments_
.
find
(
name
);
CHECK
(
it
!=
arguments_
.
end
())
<<
"no argument called "
<<
name
;
return
it
->
second
;
}
mir
::
Node
*
Argument
(
const
std
::
string
&
name
);
std
::
vector
<
mir
::
Node
*>
InstructTopologicalOrder
();
// The inputs of the graph.
std
::
vector
<
mir
::
Node
*>
inputs
()
{
std
::
vector
<
mir
::
Node
*>
res
;
for
(
auto
&
node
:
node_storage_
)
{
if
(
node
.
inlinks
.
empty
())
{
res
.
push_back
(
&
node
);
}
}
return
res
;
}
std
::
vector
<
mir
::
Node
*>
inputs
();
// The outputs of the graph.
std
::
vector
<
mir
::
Node
*>
outputs
()
{
std
::
vector
<
mir
::
Node
*>
res
;
for
(
auto
&
node
:
node_storage_
)
{
if
(
node
.
outlinks
.
empty
())
{
res
.
push_back
(
&
node
);
}
}
return
res
;
}
std
::
vector
<
mir
::
Node
*>
outputs
();
const
std
::
list
<
mir
::
Node
>
&
nodes
()
const
{
return
node_storage_
;
}
std
::
list
<
mir
::
Node
>
&
mutable_nodes
()
{
return
node_storage_
;
}
mir
::
Node
*
RetrieveArgument
(
const
std
::
string
&
arg
)
{
auto
it
=
arguments_
.
find
(
arg
);
if
(
it
!=
arguments_
.
end
())
{
return
it
->
second
;
}
return
nullptr
;
mir
::
Node
*
RetrieveArgument
(
const
std
::
string
&
arg
);
Node
*
NewArgumentNode
(
const
std
::
string
&
name
);
Node
*
NewInstructNode
();
void
CheckValid
()
{
CHECK
(
CheckBidirectionalConnection
());
CHECK
(
CheckNodesRoleSet
());
CHECK
(
CheckLinksRoleSet
());
}
private:
void
GraphCreateTmpVarNodes
(
const
Program
&
program
);
void
GraphCreateWeightVarNodes
(
const
Program
&
program
);
Node
*
GraphCreateInstructNode
(
const
Program
&
program
,
const
std
::
shared_ptr
<
OpLite
>
&
op
,
const
std
::
vector
<
Place
>
&
valid_places
);
// Check the bidirectional connection.
bool
CheckBidirectionalConnection
();
bool
CheckNodesRoleSet
();
// Check all the items's role in inlinks and outlinks is set.
bool
CheckLinksRoleSet
();
void
MarkArgumentWeights
(
const
Program
&
program
)
{
for
(
const
auto
&
name
:
program
.
weights
)
{
...
...
@@ -152,6 +92,48 @@ class SSAGraph : GraphBase {
std
::
map
<
std
::
string
,
mir
::
Node
*>
arguments_
;
};
// Remove the link between a -> b.
static
void
RemoveDirectedLink
(
Node
*
a
,
Node
*
b
)
{
auto
it
=
std
::
find
(
b
->
inlinks
.
begin
(),
b
->
inlinks
.
end
(),
a
);
if
(
it
!=
b
->
inlinks
.
end
())
{
b
->
inlinks
.
erase
(
it
);
}
auto
it1
=
std
::
find
(
a
->
outlinks
.
begin
(),
a
->
outlinks
.
end
(),
b
);
if
(
it1
!=
a
->
outlinks
.
end
())
{
a
->
outlinks
.
erase
((
it1
));
}
}
// Link a -> b.
static
void
DirectedLink
(
Node
*
a
,
Node
*
b
)
{
// Eagerly remove first, to avoid duplicate link.
RemoveDirectedLink
(
a
,
b
);
a
->
outlinks
.
push_back
(
b
);
b
->
inlinks
.
push_back
(
a
);
}
static
void
LocalInferenceType
(
Node
*
a
,
Node
*
b
,
const
std
::
string
&
arg_name
)
{
// instr -> output argument
if
(
a
->
IsInstruct
()
&&
b
->
IsArgument
())
{
auto
&
inst
=
a
->
AsInstruct
();
auto
&
output
=
b
->
AsArgument
();
if
(
!
output
.
type
)
{
output
.
type
=
inst
.
picked_kernel
().
GetOutputDeclType
(
arg_name
);
}
}
// input argument -> instr
if
(
a
->
IsArgument
()
&&
b
->
IsInstruct
())
{
auto
&
input
=
a
->
AsArgument
();
auto
&
inst
=
b
->
AsInstruct
();
if
(
!
input
.
type
)
{
input
.
type
=
inst
.
picked_kernel
().
GetInputDeclType
(
arg_name
);
}
}
}
}
// namespace mir
}
// namespace lite
}
// namespace paddle
paddle/fluid/lite/core/mir/static_kernel_pick_pass.cc
浏览文件 @
621d1522
...
...
@@ -37,7 +37,9 @@ void StaticKernelPickPass::Apply(std::unique_ptr<mir::SSAGraph>& graph) {
auto
&
instruct
=
node
.
AsInstruct
();
std
::
vector
<
std
::
pair
<
size_t
,
std
::
unique_ptr
<
KernelBase
>>>
scored
;
for
(
auto
&&
kernel
:
instruct
.
valid_kernels
)
{
scored
.
emplace_back
(
KernelGrade
(
*
kernel
),
std
::
move
(
kernel
));
size_t
score
=
KernelGrade
(
*
kernel
);
LOG
(
INFO
)
<<
"kernel "
<<
kernel
->
summary
()
<<
" "
<<
score
;
scored
.
emplace_back
(
score
,
std
::
move
(
kernel
));
}
std
::
sort
(
scored
.
begin
(),
scored
.
end
(),
KernelScoreCmp
);
...
...
@@ -47,7 +49,6 @@ void StaticKernelPickPass::Apply(std::unique_ptr<mir::SSAGraph>& graph) {
// TODO(Superjomn) reconsider this.
instruct
.
valid_kernels
.
clear
();
instruct
.
valid_kernels
.
emplace_back
(
std
::
move
(
scored
.
front
().
second
));
instruct
.
place
=
instruct
.
valid_kernels
.
front
()
->
place
();
LOG
(
INFO
)
<<
"pick "
<<
instruct
.
valid_kernels
.
front
()
->
name
();
}
}
...
...
paddle/fluid/lite/core/mir/static_kernel_pick_pass.h
浏览文件 @
621d1522
...
...
@@ -37,6 +37,7 @@ class StaticKernelPickPass : public mir::InstructionPass {
public:
void
Apply
(
std
::
unique_ptr
<
mir
::
SSAGraph
>&
graph
)
override
;
void
SetPreferPlace
(
const
Place
&
place
)
{
place_
=
place
;
}
const
Place
&
place
()
const
{
return
place_
;
}
const
core
::
KernelPickFactor
&
kernel_pick_factors
()
const
{
return
kernel_pick_factors_
;
...
...
@@ -51,16 +52,32 @@ class StaticKernelPickPass : public mir::InstructionPass {
size_t
score
{};
const
int
kMax
=
std
::
numeric_limits
<
core
::
KernelPickFactor
::
value_type
>::
max
();
// The more important factor comes first
if
(
kernel_pick_factors_
.
IsTargetConsidered
()
&&
place
().
target
==
kernel
.
target
())
{
(
place
().
target
==
kernel
.
target
()
||
kernel
.
target
()
==
TARGET
(
kAny
)
||
place
().
target
==
TARGET
(
kAny
)))
{
score
+=
kMax
/
static_cast
<
int
>
(
core
::
KernelPickFactor
::
Factor
::
TargetFirst
);
}
if
(
kernel_pick_factors_
.
IsPrecisionConsidered
()
&&
place
().
precision
==
kernel
.
precision
())
{
(
place
().
precision
==
kernel
.
precision
()
||
kernel
.
precision
()
==
PRECISION
(
kAny
)
||
place
().
precision
==
PRECISION
(
kAny
)))
{
score
+=
kMax
/
static_cast
<
int
>
(
core
::
KernelPickFactor
::
Factor
::
PrecisionFirst
);
}
if
(
kernel_pick_factors_
.
IsDataLayoutConsidered
()
&&
(
place
().
layout
==
kernel
.
layout
()
||
kernel
.
layout
()
==
DATALAYOUT
(
kAny
)
||
place
().
layout
==
DATALAYOUT
(
kAny
)))
{
score
+=
kMax
/
static_cast
<
int
>
(
core
::
KernelPickFactor
::
Factor
::
DataLayoutFirst
);
}
LOG
(
INFO
)
<<
"picker tactic "
<<
kernel_pick_factors_
;
LOG
(
INFO
)
<<
"kernel place "
<<
kernel
.
place
();
LOG
(
INFO
)
<<
"picker place "
<<
place
();
LOG
(
INFO
)
<<
"score "
<<
score
;
// The data layout is not considered, for the input and output arguments
// might have different data layout.
...
...
paddle/fluid/lite/core/mir/variable_place_inference_pass.cc
浏览文件 @
621d1522
...
...
@@ -22,8 +22,8 @@ namespace mir {
void
VariablePlaceInferencePass
::
Apply
(
std
::
unique_ptr
<
mir
::
SSAGraph
>&
graph
)
{
MarkInputPlace
(
graph
.
get
());
InferenceArgumentPlace
(
graph
.
get
());
CheckAllArgumentTypeDetermined
(
graph
.
get
());
}
}
// namespace mir
...
...
paddle/fluid/lite/core/mir/variable_place_inference_pass.h
浏览文件 @
621d1522
...
...
@@ -31,6 +31,7 @@ class VariablePlaceInferencePass : public DebugPass {
private:
// Mark the place of input arguments.
void
MarkInputPlace
(
SSAGraph
*
graph
)
{
CHECK
(
!
graph
->
inputs
().
empty
())
<<
"graph's inputs should be set"
;
for
(
const
auto
&
v
:
graph
->
inputs
())
{
// the feed op might in the inputs
if
(
v
->
IsInstruct
())
{
...
...
@@ -39,54 +40,60 @@ class VariablePlaceInferencePass : public DebugPass {
}
// auto& arg = v->AsArgument();
// arg.place.target = argument_default_target_;
// LOG(INFO) << "get graph input " << arg.name << " " << *arg.type;
// arg.type.target = argument_default_target_;
// the other place description can't be determined yet, until their first
// usage by some kernel.
}
}
void
CheckAllArgumentTypeDetermined
(
SSAGraph
*
graph
)
{
for
(
auto
&
node
:
graph
->
mutable_nodes
())
{
if
(
node
.
IsArgument
())
{
CHECK
(
node
.
AsArgument
().
type
)
<<
"node "
<<
node
.
AsArgument
().
name
<<
" type not determined"
;
}
}
}
void
InferenceArgumentPlace
(
SSAGraph
*
graph
)
{
LOG
(
INFO
)
<<
"param-type-registry:
\n
"
<<
ParamTypeRegistry
::
Global
();
for
(
auto
&
x
:
graph
->
InstructTopologicalOrder
())
{
auto
&
inst
=
x
->
AsInstruct
();
CHECK
(
inst
.
place
.
is_valid
())
<<
"kernel's place should be set when loaded"
;
// The IoCopyOp is a tool operator, it won't support the type inference.
if
(
inst
.
op_type
==
"io_copy"
)
continue
;
// LOG(INFO) << "- inferencing type " <<
// deal with inputs
for
(
auto
&
arg_name
:
inst
.
op_info
->
input_argnames
())
{
auto
type
=
ParamTypeRegistry
::
Global
().
Retrieve
<
ParamTypeRegistry
::
IO
::
kInput
>
(
inst
.
place
,
inst
.
op_type
,
arg_name
);
CHECK
(
type
)
<<
"no param-type found for "
<<
inst
.
op_type
<<
":"
<<
arg_name
<<
" "
<<
inst
.
place
.
DebugString
();
auto
arg_names
=
inst
.
op_info
->
input_argument
().
at
(
arg_name
);
for
(
auto
&
arg_name
:
inst
.
op_info
()
->
input_argnames
())
{
LOG
(
INFO
)
<<
"-- input arg_name "
<<
arg_name
;
// check if inputs's place is set, if not set, update them with the
// kernel's declaration.
auto
type
=
inst
.
picked_kernel
().
GetInputDeclType
(
arg_name
);
auto
arg_names
=
inst
.
op_info
()
->
input_argument
().
at
(
arg_name
);
for
(
auto
&
arg_name
:
arg_names
)
{
LOG
(
INFO
)
<<
"--- var "
<<
arg_name
;
auto
*
node
=
graph
->
RetrieveArgument
(
arg_name
);
CHECK
(
node
)
<<
"argument "
<<
arg_name
<<
" not exists in the graph"
;
auto
&
arg_node
=
node
->
AsArgument
();
if
(
arg_node
.
type
)
continue
;
arg_node
.
type
=
type
->
type
;
arg_node
.
type
=
type
;
}
}
for
(
auto
&
arg_name
:
inst
.
op_info
->
output_argnames
())
{
auto
type
=
ParamTypeRegistry
::
Global
()
.
Retrieve
<
ParamTypeRegistry
::
IO
::
kOutput
>
(
inst
.
place
,
inst
.
op_type
,
arg_name
);
CHECK
(
type
)
<<
"no param-type found for "
<<
inst
.
op_type
<<
":"
<<
arg_name
<<
" "
<<
inst
.
place
.
DebugString
();
auto
arg_names
=
inst
.
op_info
->
output_argument
().
at
(
arg_name
);
for
(
auto
&
arg_name
:
inst
.
op_info
()
->
output_argnames
())
{
LOG
(
INFO
)
<<
"-- output arg_name "
<<
arg_name
;
auto
type
=
inst
.
picked_kernel
().
GetOutputDeclType
(
arg_name
);
auto
arg_names
=
inst
.
op_info
()
->
output_argument
().
at
(
arg_name
);
// check if outputs's place is set, if not set, update them with the
// kernel's declaration.
for
(
auto
&
arg_name
:
arg_names
)
{
LOG
(
INFO
)
<<
"--- var "
<<
arg_name
;
auto
*
node
=
graph
->
RetrieveArgument
(
arg_name
);
CHECK
(
node
)
<<
"argument "
<<
arg_name
<<
" not exists in the graph"
;
auto
&
arg_node
=
node
->
AsArgument
();
if
(
arg_node
.
type
)
continue
;
node
->
AsArgument
().
type
=
type
->
type
;
node
->
AsArgument
().
type
=
type
;
}
}
}
...
...
paddle/fluid/lite/core/op_lite.cc
浏览文件 @
621d1522
...
...
@@ -27,13 +27,15 @@ std::vector<std::unique_ptr<KernelBase>> OpLite::CreateKernels(
for
(
auto
place
:
places
)
{
auto
ks
=
KernelRegistry
::
Global
().
Create
(
(
kernel_type
.
empty
()
?
op_type_
:
kernel_type
),
place
.
target
,
place
.
precision
);
place
.
precision
,
place
.
layout
);
for
(
auto
&&
it
:
ks
)
{
AttachKernel
(
it
.
get
());
kernels
.
emplace_back
(
std
::
move
(
it
));
}
}
CHECK
(
!
kernels
.
empty
())
<<
"No kernel found for Op "
<<
op_type_
;
LOG
(
INFO
)
<<
"op "
<<
op_type_
<<
" get "
<<
kernels
.
size
()
<<
" kernels"
;
return
kernels
;
}
...
...
@@ -59,9 +61,10 @@ bool OpLite::Run() {
}
bool
OpLite
::
Attach
(
const
framework
::
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
{
CHECK
(
!
op_info_
)
<<
"op_info duplicate build found"
;
op_info_
=
std
::
make_shared
<
OpInfo
>
();
op_info_
->
Build
(
opdesc
);
CHECK
(
scope
);
scope_
=
scope
;
op_info_
.
reset
(
new
OpInfo
);
// Force clean the out-of-date infomation.
op_info_
->
Build
(
opdesc
.
ReadonlyProto
());
return
AttachImpl
(
opdesc
,
scope
);
}
...
...
@@ -79,7 +82,8 @@ Tensor *OpLite::GetMutableTensor(lite::Scope *scope,
return
var
->
GetMutable
<
lite
::
Tensor
>
();
}
bool
OpInfo
::
GetInputArgname
(
const
std
::
string
&
value_name
,
std
::
string
*
out
)
{
bool
OpInfo
::
GetInputArgname
(
const
std
::
string
&
value_name
,
std
::
string
*
out
)
const
{
for
(
auto
&
item
:
input_argument_
)
{
auto
it
=
std
::
find
(
item
.
second
.
begin
(),
item
.
second
.
end
(),
value_name
);
if
(
it
!=
item
.
second
.
end
())
{
...
...
@@ -89,7 +93,8 @@ bool OpInfo::GetInputArgname(const std::string &value_name, std::string *out) {
}
return
false
;
}
bool
OpInfo
::
GetOutputArgname
(
const
std
::
string
&
value_name
,
std
::
string
*
out
)
{
bool
OpInfo
::
GetOutputArgname
(
const
std
::
string
&
value_name
,
std
::
string
*
out
)
const
{
for
(
auto
&
item
:
output_argument_
)
{
auto
it
=
std
::
find
(
item
.
second
.
begin
(),
item
.
second
.
end
(),
value_name
);
if
(
it
!=
item
.
second
.
end
())
{
...
...
paddle/fluid/lite/core/op_lite.h
浏览文件 @
621d1522
...
...
@@ -81,19 +81,30 @@ class OpLite : public Registry {
// Run this operator.
virtual
bool
Run
();
// Link the external execution environ to internal context.
bool
Attach
(
const
framework
::
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
);
const
std
::
shared_ptr
<
OpInfo
>
&
op_info
()
const
{
return
op_info_
;
}
std
::
shared_ptr
<
OpInfo
>
&
mutable_op_info
()
{
return
op_info_
;
}
const
OpInfo
*
op_info
()
const
{
return
op_info_
.
get
()
;
}
OpInfo
*
mutable_op_info
()
{
return
op_info_
.
get
()
;
}
// Human-readable information.
virtual
std
::
string
DebugString
()
const
=
0
;
const
Place
&
kernel_place
()
const
{
return
kernel_place_
;
}
// NOTE This might be discarded.
void
PickKernel
(
const
std
::
vector
<
Place
>
&
valid_places
,
KernelStrategy
kernel_strategy
=
KernelStrategy
::
kStatic
);
// Create all the kernels for the valid targets.
std
::
vector
<
std
::
unique_ptr
<
KernelBase
>>
CreateKernels
(
const
std
::
vector
<
Place
>
&
places
,
const
std
::
string
&
kernel_type
=
""
);
lite
::
Scope
*
scope
()
{
return
scope_
;
}
// Assign op param to kernel.
virtual
void
AttachKernel
(
KernelBase
*
kernel
)
=
0
;
virtual
~
OpLite
()
=
default
;
protected:
...
...
@@ -101,9 +112,6 @@ class OpLite : public Registry {
virtual
bool
AttachImpl
(
const
framework
::
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
=
0
;
// Assign op param to kernel.
virtual
void
AttachKernel
(
KernelBase
*
kernel
)
=
0
;
// Specify the kernel to run by default. This will specify the value of
// `kernel_place_`.
virtual
void
StaticPickKernel
(
const
std
::
vector
<
Place
>
&
valid_targets
)
{
...
...
@@ -118,10 +126,6 @@ class OpLite : public Registry {
// some inputs are ready.
void
RecordOutputEvents
()
{}
// Create all the kernels for the valid targets.
std
::
vector
<
std
::
unique_ptr
<
KernelBase
>>
CreateKernels
(
const
std
::
vector
<
Place
>
&
places
,
const
std
::
string
&
kernel_type
=
""
);
const
Tensor
*
GetTensor
(
lite
::
Scope
*
scope
,
const
std
::
string
&
name
)
const
;
Tensor
*
GetMutableTensor
(
lite
::
Scope
*
scope
,
const
std
::
string
&
name
)
const
;
...
...
@@ -129,11 +133,12 @@ class OpLite : public Registry {
friend
class
mir
::
SSAGraph
;
protected:
lite
::
Scope
*
scope_
{};
std
::
unique_ptr
<
KernelBase
>
kernel_
;
std
::
string
op_type_
;
std
::
vector
<
Place
>
valid_places_
;
Place
kernel_place_
{
TARGET
(
kHost
),
PRECISION
(
kFloat
)};
std
::
shared
_ptr
<
OpInfo
>
op_info_
;
std
::
unique
_ptr
<
OpInfo
>
op_info_
;
};
/*
...
...
@@ -142,22 +147,30 @@ class OpLite : public Registry {
*/
class
OpInfo
{
public:
void
Build
(
const
framework
::
OpDesc
&
desc
)
{
// To avoid the bugs from legancy framework::OpDesc, we use the ProtoBuf
// message instead.
void
Build
(
const
framework
::
proto
::
OpDesc
&
desc
)
{
ExtractInputsAndOutputs
(
desc
);
CollectInputAndOutputArgnames
(
desc
);
CollectArguments
(
desc
);
desc_
.
reset
(
new
framework
::
proto
::
OpDesc
(
desc
));
}
const
framework
::
proto
::
OpDesc
&
desc
()
const
{
CHECK
(
desc_
)
<<
"desc has't set"
;
return
*
desc_
;
}
framework
::
proto
::
OpDesc
*
mutable_desc
()
{
return
desc_
.
get
();
}
const
std
::
list
<
std
::
string
>
&
input_names
()
const
{
return
input_names_
;
}
const
std
::
list
<
std
::
string
>
&
output_names
()
const
{
return
output_names_
;
}
const
std
::
map
<
std
::
string
,
std
::
list
<
std
::
string
>>
&
input_argument
()
{
const
std
::
map
<
std
::
string
,
std
::
list
<
std
::
string
>>
&
input_argument
()
const
{
return
input_argument_
;
}
const
std
::
map
<
std
::
string
,
std
::
list
<
std
::
string
>>
&
output_argument
()
{
const
std
::
map
<
std
::
string
,
std
::
list
<
std
::
string
>>
&
output_argument
()
const
{
return
output_argument_
;
}
bool
GetInputArgname
(
const
std
::
string
&
value_name
,
std
::
string
*
out
);
bool
GetOutputArgname
(
const
std
::
string
&
value_name
,
std
::
string
*
out
);
bool
GetInputArgname
(
const
std
::
string
&
value_name
,
std
::
string
*
out
)
const
;
bool
GetOutputArgname
(
const
std
::
string
&
value_name
,
std
::
string
*
out
)
const
;
const
std
::
list
<
std
::
string
>
&
input_argnames
()
const
{
return
input_argnames_
;
...
...
@@ -167,37 +180,37 @@ class OpInfo {
}
private:
void
ExtractInputsAndOutputs
(
const
framework
::
OpDesc
&
opdesc
)
{
for
(
const
auto
&
item
:
opdesc
.
I
nputs
())
{
for
(
const
auto
&
x
:
item
.
second
)
{
void
ExtractInputsAndOutputs
(
const
framework
::
proto
::
OpDesc
&
opdesc
)
{
for
(
const
auto
&
item
:
opdesc
.
i
nputs
())
{
for
(
const
auto
&
x
:
item
.
arguments
()
)
{
input_names_
.
push_back
(
x
);
}
}
for
(
const
auto
&
item
:
opdesc
.
O
utputs
())
{
for
(
const
auto
&
x
:
item
.
second
)
{
for
(
const
auto
&
item
:
opdesc
.
o
utputs
())
{
for
(
const
auto
&
x
:
item
.
arguments
()
)
{
output_names_
.
push_back
(
x
);
}
}
}
void
CollectInputAndOutputArgnames
(
const
framework
::
OpDesc
&
opdesc
)
{
for
(
const
auto
&
item
:
opdesc
.
InputName
s
())
{
input_argnames_
.
push_back
(
item
);
void
CollectInputAndOutputArgnames
(
const
framework
::
proto
::
OpDesc
&
opdesc
)
{
for
(
const
auto
&
item
:
opdesc
.
input
s
())
{
input_argnames_
.
push_back
(
item
.
parameter
()
);
}
for
(
const
auto
&
item
:
opdesc
.
OutputName
s
())
{
output_argnames_
.
push_back
(
item
);
for
(
const
auto
&
item
:
opdesc
.
output
s
())
{
output_argnames_
.
push_back
(
item
.
parameter
()
);
}
}
void
CollectArguments
(
const
framework
::
OpDesc
&
opdesc
)
{
for
(
const
auto
&
item
:
opdesc
.
I
nputs
())
{
for
(
auto
&
x
:
item
.
second
)
{
input_argument_
[
item
.
first
].
push_back
(
x
);
void
CollectArguments
(
const
framework
::
proto
::
OpDesc
&
opdesc
)
{
for
(
const
auto
&
item
:
opdesc
.
i
nputs
())
{
for
(
auto
&
x
:
item
.
arguments
()
)
{
input_argument_
[
item
.
parameter
()
].
push_back
(
x
);
}
}
for
(
const
auto
&
item
:
opdesc
.
O
utputs
())
{
for
(
auto
&
x
:
item
.
second
)
{
output_argument_
[
item
.
first
].
push_back
(
x
);
for
(
const
auto
&
item
:
opdesc
.
o
utputs
())
{
for
(
auto
&
x
:
item
.
arguments
()
)
{
output_argument_
[
item
.
parameter
()
].
push_back
(
x
);
}
}
}
...
...
@@ -209,6 +222,8 @@ class OpInfo {
std
::
list
<
std
::
string
>
output_argnames_
;
std
::
map
<
std
::
string
,
std
::
list
<
std
::
string
>>
input_argument_
;
std
::
map
<
std
::
string
,
std
::
list
<
std
::
string
>>
output_argument_
;
// NOTE too heavy.
std
::
unique_ptr
<
framework
::
proto
::
OpDesc
>
desc_
;
};
}
// namespace lite
...
...
paddle/fluid/lite/core/op_registry.cc
浏览文件 @
621d1522
...
...
@@ -18,13 +18,33 @@ namespace paddle {
namespace
lite
{
std
::
list
<
std
::
unique_ptr
<
KernelBase
>>
KernelRegistry
::
Create
(
const
std
::
string
&
op_type
,
TargetType
target
,
PrecisionType
precision
)
{
#define CREATE_KERNEL(target__) \
switch (precision) { \
case PRECISION(kFloat): \
return Create<TARGET(target__), PRECISION(kFloat)>(op_type); \
default: \
CHECK(false) << "not supported kernel place yet"; \
const
std
::
string
&
op_type
,
TargetType
target
,
PrecisionType
precision
,
DataLayoutType
layout
)
{
Place
place
{
target
,
precision
,
layout
};
LOG
(
INFO
)
<<
"creating "
<<
op_type
<<
" kernel for "
<<
place
;
#define CREATE_KERNEL1(target__, precision__) \
switch (layout) { \
case DATALAYOUT(kNCHW): \
return Create<TARGET(target__), PRECISION(precision__), \
DATALAYOUT(kNCHW)>(op_type); \
case DATALAYOUT(kAny): \
return Create<TARGET(target__), PRECISION(precision__), \
DATALAYOUT(kAny)>(op_type); \
default: \
LOG(FATAL) << "unsupported kernel layout " << DataLayoutToStr(layout); \
}
#define CREATE_KERNEL(target__) \
switch (precision) { \
case PRECISION(kFloat): \
CREATE_KERNEL1(target__, kFloat); \
case PRECISION(kInt8): \
CREATE_KERNEL1(target__, kInt8); \
case PRECISION(kAny): \
CREATE_KERNEL1(target__, kAny); \
default: \
CHECK(false) << "not supported kernel precision " \
<< PrecisionToStr(precision); \
}
switch
(
target
)
{
...
...
@@ -38,7 +58,7 @@ std::list<std::unique_ptr<KernelBase>> KernelRegistry::Create(
CREATE_KERNEL
(
kCUDA
);
}
break
;
default:
CHECK
(
false
)
<<
"not supported kernel
place"
;
CHECK
(
false
)
<<
"not supported kernel
target "
<<
TargetToStr
(
target
)
;
}
#undef CREATE_KERNEL
...
...
@@ -46,14 +66,21 @@ std::list<std::unique_ptr<KernelBase>> KernelRegistry::Create(
}
KernelRegistry
::
KernelRegistry
()
{
#define INIT_FOR(target__, precision__
)
\
#define INIT_FOR(target__, precision__
, layout__)
\
registries_[KernelRegistry::GetKernelOffset<TARGET(target__), \
PRECISION(precision__)>()] \
.set<KernelRegistryForTarget<TARGET(target__), PRECISION(precision__)> \
*>(&KernelRegistryForTarget<TARGET(target__), \
PRECISION(precision__)>::Global());
PRECISION(precision__), \
DATALAYOUT(layout__)>()] \
.set<KernelRegistryForTarget<TARGET(target__), PRECISION(precision__), \
DATALAYOUT(layout__)> *>( \
&KernelRegistryForTarget<TARGET(target__), PRECISION(precision__), \
DATALAYOUT(layout__)>::Global());
// Currently, just register 2 kernel targets.
INIT_FOR
(
kHost
,
kFloat
);
INIT_FOR
(
kCUDA
,
kFloat
,
kNCHW
);
INIT_FOR
(
kCUDA
,
kAny
,
kNCHW
);
INIT_FOR
(
kHost
,
kFloat
,
kNCHW
);
INIT_FOR
(
kHost
,
kAny
,
kNCHW
);
INIT_FOR
(
kHost
,
kAny
,
kAny
);
INIT_FOR
(
kCUDA
,
kAny
,
kAny
);
#undef INIT_FOR
}
...
...
paddle/fluid/lite/core/op_registry.h
浏览文件 @
621d1522
...
...
@@ -50,80 +50,108 @@ class OpLiteRegistor : public Registor<OpClass> {
})
{}
};
template
<
TargetType
Target
,
PrecisionType
Precision
>
template
<
TargetType
Target
,
PrecisionType
Precision
,
DataLayoutType
Layout
>
using
KernelRegistryForTarget
=
Factory
<
OpKernel
<
Target
,
Precision
>
,
std
::
unique_ptr
<
KernelBase
>>
;
Factory
<
OpKernel
<
Target
,
Precision
,
Layout
>
,
std
::
unique_ptr
<
KernelBase
>>
;
class
KernelRegistry
final
{
public:
using
any_kernel_registor_t
=
variant
<
KernelRegistryForTarget
<
TARGET
(
kCUDA
),
PRECISION
(
kFloat
)
>
*
,
//
KernelRegistryForTarget
<
TARGET
(
kCUDA
),
PRECISION
(
kInt8
)
>
*
,
//
KernelRegistryForTarget
<
TARGET
(
kX86
),
PRECISION
(
kFloat
)
>
*
,
//
KernelRegistryForTarget
<
TARGET
(
kX86
),
PRECISION
(
kInt8
)
>
*
,
//
KernelRegistryForTarget
<
TARGET
(
kHost
),
PRECISION
(
kFloat
)
>
*
//
variant
<
KernelRegistryForTarget
<
TARGET
(
kCUDA
),
PRECISION
(
kFloat
),
DATALAYOUT
(
kNCHW
)
>
*
,
//
KernelRegistryForTarget
<
TARGET
(
kCUDA
),
PRECISION
(
kInt8
),
DATALAYOUT
(
kNCHW
)
>
*
,
//
KernelRegistryForTarget
<
TARGET
(
kX86
),
PRECISION
(
kFloat
),
DATALAYOUT
(
kNCHW
)
>
*
,
//
KernelRegistryForTarget
<
TARGET
(
kX86
),
PRECISION
(
kInt8
),
DATALAYOUT
(
kNCHW
)
>
*
,
//
KernelRegistryForTarget
<
TARGET
(
kHost
),
PRECISION
(
kFloat
),
DATALAYOUT
(
kNCHW
)
>
*
,
//
KernelRegistryForTarget
<
TARGET
(
kHost
),
PRECISION
(
kAny
),
DATALAYOUT
(
kAny
)
>
*
,
//
KernelRegistryForTarget
<
TARGET
(
kCUDA
),
PRECISION
(
kAny
),
DATALAYOUT
(
kAny
)
>
*
//
>
;
KernelRegistry
();
static
KernelRegistry
&
Global
();
template
<
TargetType
Target
,
PrecisionType
Precision
>
template
<
TargetType
Target
,
PrecisionType
Precision
,
DataLayoutType
Layout
>
void
Register
(
const
std
::
string
&
name
,
typename
KernelRegistryForTarget
<
Target
,
Precision
>::
creator_t
&&
creator
)
{
using
kernel_registor_t
=
KernelRegistryForTarget
<
Target
,
Precision
>
;
registries_
[
GetKernelOffset
<
Target
,
Precision
>
()]
.
template
get
<
kernel_registor_t
*
>()
->
Register
(
name
,
std
::
move
(
creator
));
typename
KernelRegistryForTarget
<
Target
,
Precision
,
Layout
>::
creator_t
&&
creator
)
{
LOG
(
INFO
)
<<
"register for "
<<
TargetToStr
(
Target
)
<<
":"
<<
PrecisionToStr
(
Precision
)
<<
"//"
<<
GetKernelOffset
<
Target
,
Precision
,
Layout
>
();
using
kernel_registor_t
=
KernelRegistryForTarget
<
Target
,
Precision
,
Layout
>
;
auto
&
varient
=
registries_
[
GetKernelOffset
<
Target
,
Precision
,
Layout
>
()];
varient
.
template
get
<
kernel_registor_t
*
>()
->
Register
(
name
,
std
::
move
(
creator
));
}
template
<
TargetType
Target
,
PrecisionType
Precision
>
template
<
TargetType
Target
,
PrecisionType
Precision
,
DataLayoutType
Layout
>
std
::
list
<
std
::
unique_ptr
<
KernelBase
>>
Create
(
const
std
::
string
&
op_type
)
{
using
kernel_registor_t
=
KernelRegistryForTarget
<
Target
,
Precision
>
;
return
registries_
[
GetKernelOffset
<
Target
,
Precision
>
()]
using
kernel_registor_t
=
KernelRegistryForTarget
<
Target
,
Precision
,
Layout
>
;
return
registries_
[
GetKernelOffset
<
Target
,
Precision
,
Layout
>
()]
.
template
get
<
kernel_registor_t
*
>()
->
Creates
(
op_type
);
}
std
::
list
<
std
::
unique_ptr
<
KernelBase
>>
Create
(
const
std
::
string
&
op_type
,
TargetType
target
,
PrecisionType
precision
);
PrecisionType
precision
,
DataLayoutType
layout
);
// Get a kernel registry offset in all the registries.
template
<
TargetType
Target
,
PrecisionType
Precision
>
static
constexpr
int
GetKernelOffset
()
{
return
kNumTargets
*
static_cast
<
int
>
(
Target
)
+
static_cast
<
int
>
(
Precision
);
template
<
TargetType
Target
,
PrecisionType
Precision
,
DataLayoutType
Layout
>
static
int
GetKernelOffset
()
{
CHECK_LT
(
static_cast
<
int
>
(
Target
),
static_cast
<
int
>
(
TARGET
(
NUM
)));
CHECK_LT
(
static_cast
<
int
>
(
Precision
),
static_cast
<
int
>
(
PRECISION
(
NUM
)));
CHECK_LT
(
static_cast
<
int
>
(
Layout
),
static_cast
<
int
>
(
DATALAYOUT
(
NUM
)));
return
static_cast
<
int
>
(
Target
)
*
static_cast
<
int
>
(
PRECISION
(
NUM
))
*
static_cast
<
int
>
(
DATALAYOUT
(
NUM
))
+
//
static_cast
<
int
>
(
Precision
)
*
static_cast
<
int
>
(
DATALAYOUT
(
NUM
))
+
//
static_cast
<
int
>
(
Layout
);
}
std
::
string
DebugString
()
const
{
std
::
stringstream
ss
;
ss
<<
"KernelCreator<host, float>:"
<<
std
::
endl
;
ss
<<
registries_
[
GetKernelOffset
<
TARGET
(
kHost
),
PRECISION
(
kFloat
)
>
()]
.
get
<
KernelRegistryForTarget
<
TARGET
(
kHost
),
PRECISION
(
kFloat
)
>
*>
()
ss
<<
registries_
[
GetKernelOffset
<
TARGET
(
kHost
),
PRECISION
(
kFloat
),
DATALAYOUT
(
kAny
)
>
()]
.
get
<
KernelRegistryForTarget
<
TARGET
(
kHost
),
PRECISION
(
kFloat
),
DATALAYOUT
(
kNCHW
)
>
*>
()
->
DebugString
();
ss
<<
std
::
endl
;
return
ss
.
str
();
}
private:
mutable
std
::
array
<
any_kernel_registor_t
,
kNumTargets
*
kNumPrecisions
>
mutable
std
::
array
<
any_kernel_registor_t
,
static_cast
<
int
>
(
TARGET
(
NUM
))
*
static_cast
<
int
>
(
PRECISION
(
NUM
))
*
static_cast
<
int
>
(
DATALAYOUT
(
NUM
))
>
registries_
;
};
template
<
TargetType
target
,
PrecisionType
precision
,
typename
KernelType
>
template
<
TargetType
target
,
PrecisionType
precision
,
DataLayoutType
layout
,
typename
KernelType
>
class
KernelRegistor
:
public
lite
::
Registor
<
KernelType
>
{
public:
KernelRegistor
(
const
std
::
string
op_type
)
:
Registor
<
KernelType
>
([
&
]
{
KernelRegistor
(
const
std
::
string
&
op_type
,
const
std
::
string
&
alias
)
:
Registor
<
KernelType
>
([
=
]
{
LOG
(
INFO
)
<<
"Register kernel "
<<
op_type
<<
" for "
<<
TargetToStr
(
target
)
<<
" "
<<
PrecisionToStr
(
precision
);
KernelRegistry
::
Global
().
Register
<
target
,
precision
>
(
op_type
,
[
&
,
op_type
]()
->
std
::
unique_ptr
<
KernelType
>
{
<<
TargetToStr
(
target
)
<<
" "
<<
PrecisionToStr
(
precision
)
<<
" "
<<
DataLayoutToStr
(
layout
)
<<
" alias "
<<
alias
;
KernelRegistry
::
Global
().
Register
<
target
,
precision
,
layout
>
(
op_type
,
[
=
]()
->
std
::
unique_ptr
<
KernelType
>
{
std
::
unique_ptr
<
KernelType
>
x
(
new
KernelType
);
x
->
set_op_type
(
op_type
);
x
->
set_alias
(
alias
);
return
x
;
});
})
{}
...
...
@@ -151,35 +179,40 @@ class KernelRegistor : public lite::Registor<KernelType> {
#define LITE_KERNEL_REGISTER(op_type__, target__, precision__) \
op_type__##__##target__##__##precision__##__registor__
#define LITE_KERNEL_REGISTER_INSTANCE(op_type__, target__, precision__, \
alias__)
\
layout__, alias__)
\
op_type__##__##target__##__##precision__##__registor__instance__##alias__
#define LITE_KERNEL_REGISTER_FAKE(op_type__, target__, precision__, alias__) \
LITE_KERNEL_REGISTER_INSTANCE(op_type__, target__, precision__, alias__)
#define REGISTER_LITE_KERNEL(op_type__, target__, precision__, KernelClass, \
alias__) \
static paddle::lite::KernelRegistor<TARGET(target__), \
PRECISION(precision__), KernelClass> \
LITE_KERNEL_REGISTER_INSTANCE(op_type__, target__, precision__, \
alias__)(#op_type__); \
static KernelClass LITE_KERNEL_INSTANCE(op_type__, target__, precision__, \
alias__); \
int touch_##op_type__##target__##precision__##alias__() { \
LITE_KERNEL_INSTANCE(op_type__, target__, precision__, alias__).Touch(); \
return 0; \
} \
static bool LITE_KERNEL_PARAM_INSTANCE(op_type__, target__, precision__, \
alias__) __attribute__((unused)) = \
paddle::lite::ParamTypeRegistry::NewInstance<TARGET(target__), \
PRECISION(precision__)>( \
#op_type__)
#define USE_LITE_KERNEL(op_type__, target__, precision__, alias__) \
extern int touch_##op_type__##target__##precision__##alias__(); \
int op_type__##target__##precision__##alias__ __attribute__((unused)) = \
touch_##op_type__##target__##precision__##alias__();
#define LITE_KERNEL_INSTANCE(op_type__, target__, precision__, alias__) \
op_type__##target__##precision__##alias__
#define LITE_KERNEL_PARAM_INSTANCE(op_type__, target__, precision__, alias__) \
op_type__##target__##precision__##alias__##param_register
#define REGISTER_LITE_KERNEL(op_type__, target__, precision__, layout__, \
KernelClass, alias__) \
static paddle::lite::KernelRegistor<TARGET(target__), \
PRECISION(precision__), \
DATALAYOUT(layout__), KernelClass> \
LITE_KERNEL_REGISTER_INSTANCE(op_type__, target__, precision__, \
layout__, alias__)(#op_type__, #alias__); \
static KernelClass LITE_KERNEL_INSTANCE(op_type__, target__, precision__, \
layout__, alias__); \
int touch_##op_type__##target__##precision__##layout__##alias__() { \
LITE_KERNEL_INSTANCE(op_type__, target__, precision__, layout__, alias__) \
.Touch(); \
return 0; \
} \
static bool LITE_KERNEL_PARAM_INSTANCE(op_type__, target__, precision__, \
layout__, alias__) \
__attribute__((unused)) = paddle::lite::ParamTypeRegistry::NewInstance< \
TARGET(target__), PRECISION(precision__), DATALAYOUT(layout__)>( \
#op_type__ "/" #alias__)
#define USE_LITE_KERNEL(op_type__, target__, precision__, layout__, alias__) \
extern int touch_##op_type__##target__##precision__##layout__##alias__(); \
int op_type__##target__##precision__##layout__##alias__ \
__attribute__((unused)) = \
touch_##op_type__##target__##precision__##layout__##alias__();
#define LITE_KERNEL_INSTANCE(op_type__, target__, precision__, layout__, \
alias__) \
op_type__##target__##precision__##layout__##alias__
#define LITE_KERNEL_PARAM_INSTANCE(op_type__, target__, precision__, layout__, \
alias__) \
op_type__##target__##precision__##layout__##alias__##param_register
paddle/fluid/lite/core/optimizer.cc
浏览文件 @
621d1522
...
...
@@ -13,6 +13,7 @@
// limitations under the License.
#include "paddle/fluid/lite/core/optimizer.h"
#include "paddle/fluid/lite/core/mir/io_complement_pass.h"
#include "paddle/fluid/lite/core/mir/static_kernel_pick_pass.h"
namespace
paddle
{
...
...
@@ -25,5 +26,33 @@ void Optimizer::SpecifyKernelPickTactic(core::KernelPickFactor factor) {
*
pass
->
mutable_kernel_pick_factors
()
=
factor
;
}
void
Optimizer
::
RunPasses
()
{
std
::
vector
<
std
::
string
>
passes
({
"static_kernel_pick_pass"
,
//
"variable_place_inference_pass"
,
//
"argument_type_display_pass"
,
//
"io_complement_pass"
,
//
"argument_type_display_pass"
,
//
"variable_place_inference_pass"
,
//
"argument_type_display_pass"
,
//
"io_copy_kernel_pick_pass"
,
//
"variable_place_inference_pass"
,
//
});
for
(
auto
&
pass_type
:
passes
)
{
LOG
(
INFO
)
<<
".. running pass "
<<
pass_type
;
auto
*
pass
=
mir
::
PassManager
::
Global
().
LookUp
(
pass_type
);
CHECK
(
pass
);
if
(
pass
->
name
()
==
"io_complement_pass"
)
{
auto
*
_pass
=
dynamic_cast
<
mir
::
IoComplementPass
*>
(
pass
);
_pass
->
SetValidPlaces
(
valid_places_
);
CHECK
(
!
_pass
->
valid_places
().
empty
());
_pass
->
Apply
(
graph_
);
}
else
{
pass
->
Apply
(
graph_
);
}
}
// mir::PassManager::Global().Run(graph_);
}
}
// namespace lite
}
// namespace paddle
paddle/fluid/lite/core/optimizer.h
浏览文件 @
621d1522
...
...
@@ -16,8 +16,10 @@
#include <string>
#include <vector>
#include "paddle/fluid/lite/core/mir/generate_program_pass.h"
#include "paddle/fluid/lite/core/mir/io_complement_pass.h"
#include "paddle/fluid/lite/core/mir/pass_manager.h"
#include "paddle/fluid/lite/core/mir/ssa_graph.h"
#include "paddle/fluid/lite/core/mir/static_kernel_pick_pass.h"
#include "paddle/fluid/lite/core/program.h"
#include "paddle/fluid/lite/core/types.h"
...
...
@@ -33,25 +35,46 @@ class Optimizer {
void
Run
(
Program
&&
program
,
const
std
::
vector
<
Place
>&
valid_places
,
core
::
KernelPickFactor
kernel_pick_factor
,
const
std
::
vector
<
std
::
string
>&
passes
=
{})
{
valid_places_
=
valid_places
;
CHECK
(
!
valid_places
.
empty
())
<<
"At least one valid_place should be set"
;
CHECK
(
!
graph_
)
<<
"duplicate optimize found"
;
graph_
.
reset
(
new
mir
::
SSAGraph
);
graph_
->
Build
(
program
,
valid_places
);
SpecifyKernelPickTactic
(
kernel_pick_factor
);
// InitIoComplement();
RunPasses
();
exec_scope_
=
program
.
exec_scope
;
}
void
KernelPickPreferPlace
(
const
Place
&
place
)
{
auto
*
pass
=
mir
::
PassManager
::
Global
().
LookUp
<
mir
::
StaticKernelPickPass
>
(
"static_kernel_pick_pass"
);
CHECK
(
pass
);
pass
->
SetPreferPlace
(
place
);
}
// Generate a new program based on the mir graph.
std
::
unique_ptr
<
RuntimeProgram
>
GenRuntimeProgram
()
{
LOG
(
INFO
)
<<
"generate program"
;
std
::
unique_ptr
<
Program
>
res
;
auto
pass
=
mir
::
PassManager
::
Global
().
LookUp
<
mir
::
GenerateProgramPass
>
(
"generate_program_pass"
);
pass
->
Apply
(
graph_
);
auto
program
=
pass
->
GenProgram
();
CHECK
(
exec_scope_
);
program
->
set_exec_scope
(
exec_scope_
);
return
program
;
}
void
InitIoComplement
()
{
auto
*
pass
=
mir
::
PassManager
::
Global
().
LookUp
<
mir
::
IoComplementPass
>
(
"io_complement_pass"
);
CHECK
(
pass
);
CHECK
(
!
valid_places_
.
empty
());
LOG
(
INFO
)
<<
"valid_places.size "
<<
valid_places_
.
size
();
pass
->
SetValidPlaces
(
valid_places_
);
}
// Generate C++ code which combines the inference program, model and weights.
void
GenCode
(
const
std
::
string
&
code_dir
);
...
...
@@ -64,13 +87,14 @@ class Optimizer {
void
SpecifyKernelPickTactic
(
core
::
KernelPickFactor
factor
);
// Run the default passes registered in the PassManager.
void
RunPasses
()
{
mir
::
PassManager
::
Global
().
Run
(
graph_
);
}
void
RunPasses
()
;
// Specify the passes and run them.
void
RunPasses
(
std
::
vector
<
std
::
string
>&
passes
);
private:
std
::
unique_ptr
<
mir
::
SSAGraph
>
graph_
;
std
::
vector
<
Place
>
valid_places_
;
lite
::
Scope
*
exec_scope_
{};
};
...
...
paddle/fluid/lite/core/program.h
浏览文件 @
621d1522
...
...
@@ -84,13 +84,10 @@ struct Program {
tmp_vars
.
push_back
(
"fetch"
);
for
(
auto
var_desc
:
program
.
Block
(
0
).
AllVars
())
{
if
(
!
var_desc
->
Persistable
())
{
LOG
(
INFO
)
<<
"get tmp var "
<<
var_desc
->
Name
();
tmp_vars
.
push_back
(
var_desc
->
Name
());
auto
*
var
=
exec_scope
->
Var
(
var_desc
->
Name
());
LOG
(
INFO
)
<<
"create tmp var "
<<
var_desc
->
Name
()
<<
" "
<<
var
;
exec_scope
->
Var
(
var_desc
->
Name
());
}
else
{
if
(
var_desc
->
Name
()
==
"feed"
||
var_desc
->
Name
()
==
"fetch"
)
continue
;
LOG
(
INFO
)
<<
"get weight var "
<<
var_desc
->
Name
();
weights
.
push_back
(
var_desc
->
Name
());
}
}
...
...
@@ -105,15 +102,19 @@ struct Instruction {
void
Run
()
{
CHECK
(
op_
);
CHECK
(
kernel_
);
LOG
(
INFO
)
<<
"running kernel> "
<<
kernel_
->
DebugString
();
if
(
UNLIKELY
(
first_epoch_
))
{
first_epoch_
=
false
;
op_
->
CheckShape
(
);
CHECK
(
op_
->
CheckShape
()
);
}
op_
->
InferShape
();
kernel_
->
Run
();
}
friend
std
::
ostream
&
operator
<<
(
std
::
ostream
&
os
,
const
Instruction
&
other
)
{
os
<<
other
.
kernel_
->
summary
()
<<
"
\t
("
<<
other
.
kernel_
->
doc
()
<<
")"
;
return
os
;
}
private:
std
::
shared_ptr
<
OpLite
>
op_
;
std
::
unique_ptr
<
KernelBase
>
kernel_
;
...
...
@@ -125,11 +126,16 @@ struct Instruction {
*/
class
RuntimeProgram
{
public:
explicit
RuntimeProgram
(
std
::
vector
<
Instruction
>&&
instruction
)
:
instructions_
(
std
::
move
(
instruction
))
{}
explicit
RuntimeProgram
(
std
::
vector
<
Instruction
>&&
insts
)
:
instructions_
(
std
::
move
(
insts
))
{
if
(
insts
.
empty
())
{
LOG
(
ERROR
)
<<
"no instructions"
;
}
}
void
Run
()
{
for
(
auto
&
inst
:
instructions_
)
{
LOG
(
INFO
)
<<
">> Running kernel: "
<<
inst
;
inst
.
Run
();
}
}
...
...
paddle/fluid/lite/core/target_wrapper.h
浏览文件 @
621d1522
...
...
@@ -16,6 +16,10 @@
#include <glog/logging.h>
#include <iostream>
#include <sstream>
#ifdef LITE_WITH_CUDA
#include <cuda.h>
#include <cuda_runtime.h>
#endif
namespace
paddle
{
namespace
lite
{
...
...
@@ -26,20 +30,20 @@ enum class TargetType : int {
kX86
,
kCUDA
,
kAny
,
// any target
kLastAsPlaceHolder
,
NUM
,
// number of fields.
};
enum
class
PrecisionType
:
int
{
kUnk
=
0
,
kFloat
,
kInt8
,
kAny
,
// any precision
kLastAsPlaceHolder
,
NUM
,
// number of fields.
};
enum
class
DataLayoutType
:
int
{
kUnk
=
0
,
kNCHW
,
kAny
,
// any data layout
kLastAsPlaceHolder
,
NUM
,
// number of fields.
};
// Some helper macro to get a specific TargetType.
...
...
@@ -50,25 +54,29 @@ enum class DataLayoutType : int {
#define PRECISION_VAL(item__) static_cast<int>(PRECISION(item__))
#define DATALAYOUT(item__) paddle::lite::DataLayoutType::item__
constexpr
const
int
kNumPrecisions
=
PRECISION_VAL
(
kLastAsPlaceHolder
)
-
PRECISION_VAL
(
kFloat
);
constexpr
const
int
kNumTargets
=
TARGET_VAL
(
kLastAsPlaceHolder
)
-
TARGET_VAL
(
kHost
);
constexpr
const
int
kNumPrecisions
=
PRECISION_VAL
(
NUM
);
constexpr
const
int
kNumTargets
=
TARGET_VAL
(
NUM
);
static
const
std
::
string
target2string
[]
=
{
"unk"
,
"host"
,
"x86"
,
"cuda"
,
"any"
};
static
const
std
::
string
&
TargetToStr
(
TargetType
target
)
{
return
target2string
[
static_cast
<
int
>
(
target
)];
auto
x
=
static_cast
<
int
>
(
target
);
CHECK_LT
(
x
,
static_cast
<
int
>
(
TARGET
(
NUM
)));
return
target2string
[
x
];
}
static
const
std
::
string
precision2string
[]
=
{
"unk"
,
"float"
,
"int8"
,
"any"
};
static
const
std
::
string
&
PrecisionToStr
(
PrecisionType
precision
)
{
return
precision2string
[
static_cast
<
int
>
(
precision
)];
auto
x
=
static_cast
<
int
>
(
precision
);
CHECK_LT
(
x
,
static_cast
<
int
>
(
PRECISION
(
NUM
)));
return
precision2string
[
x
];
}
static
const
std
::
string
datalayout2string
[]
=
{
"unk"
,
"NCHW"
,
"any"
};
static
const
std
::
string
&
DataLayoutToStr
(
DataLayoutType
x
)
{
return
datalayout2string
[
static_cast
<
int
>
(
x
)];
static
const
std
::
string
&
DataLayoutToStr
(
DataLayoutType
layout
)
{
auto
x
=
static_cast
<
int
>
(
layout
);
CHECK_LT
(
x
,
static_cast
<
int
>
(
DATALAYOUT
(
NUM
)));
return
datalayout2string
[
x
];
}
/*
...
...
@@ -187,5 +195,37 @@ class TargetWrapper<TARGET(kHost)> {
}
};
#ifdef LITE_WITH_CUDA
// This interface should be specified by each kind of target.
template
<
>
class
TargetWrapper
<
TARGET
(
kCUDA
),
cudaStream_t
,
cudaEvent_t
>
{
public:
using
stream_t
=
cudaStream_t
;
using
event_t
=
cudaEvent_t
;
static
size_t
num_devices
()
{
return
0
;
}
static
size_t
maximum_stream
()
{
return
0
;
}
static
void
CreateStream
(
stream_t
*
stream
)
{}
static
void
DestroyStream
(
const
stream_t
&
stream
)
{}
static
void
CreateEvent
(
event_t
*
event
)
{}
static
void
DestroyEvent
(
const
event_t
&
event
)
{}
static
void
RecordEvent
(
const
event_t
&
event
)
{}
static
void
SyncEvent
(
const
event_t
&
event
)
{}
static
void
StreamSync
(
const
stream_t
&
stream
)
{}
static
void
*
Malloc
(
size_t
size
);
static
void
Free
(
void
*
ptr
);
static
void
MemcpySync
(
void
*
dst
,
const
void
*
src
,
size_t
size
,
IoDirection
dir
);
static
void
MemcpyAsync
(
void
*
dst
,
const
void
*
src
,
size_t
size
,
IoDirection
dir
,
const
stream_t
&
stream
);
};
#endif // LITE_WITH_CUDA
}
// namespace lite
}
// namespace paddle
paddle/fluid/lite/core/type_system.cc
浏览文件 @
621d1522
...
...
@@ -87,20 +87,41 @@ const Type* Type::Get<TensorAnyTy>(TargetType target) {
}
}
template
<
TargetType
Target
>
const
Type
*
GetTensorFp32NCHWTy
()
{
static
TensorFp32NCHWTy
x
(
Target
);
return
&
x
;
}
template
<
>
const
Type
*
Type
::
Get
<
TensorFp32NCHWTy
>
(
TargetType
target
)
{
switch
(
target
)
{
case
TargetType
::
kX86
:
return
Get
<
false
,
true
,
TargetType
::
kX86
,
PrecisionType
::
kFloat
,
DataLayoutType
::
kNCHW
>
();
case
TargetType
::
kHost
:
return
Get
<
false
,
true
,
TargetType
::
kHost
,
PrecisionType
::
kFloat
,
DataLayoutType
::
kNCHW
>
();
case
TARGET
(
kHost
):
return
GetTensorFp32NCHWTy
<
TARGET
(
kHost
)
>
();
case
TARGET
(
kCUDA
):
return
GetTensorFp32NCHWTy
<
TARGET
(
kCUDA
)
>
();
case
TARGET
(
kX86
):
return
GetTensorFp32NCHWTy
<
TARGET
(
kX86
)
>
();
default:
LOG
(
FATAL
)
<<
"unsupported target Type "
<<
TargetToStr
(
target
);
}
return
nullptr
;
}
const
Type
*
LookupType
(
DataTypeBase
::
ID
type_id
,
bool
is_unknown
,
bool
is_tensor
,
Place
place
)
{
using
id_t
=
DataTypeBase
::
ID
;
switch
(
type_id
)
{
case
id_t
::
Tensor_Any
:
return
Type
::
Get
<
TensorAnyTy
>
(
place
.
target
);
case
id_t
::
Tensor_Fp32_NCHW
:
return
Type
::
Get
<
TensorFp32NCHWTy
>
(
place
.
target
);
case
id_t
::
TensorList_Any
:
return
Type
::
Get
<
TensorListAnyTy
>
(
place
.
target
);
default:
LOG
(
FATAL
)
<<
"unsupported target "
<<
TargetToStr
(
target
);
return
nullptr
;
LOG
(
FATAL
)
<<
"unsupported type"
;
}
return
nullptr
;
}
// ------------------------- end GetType specification ------------------------
...
...
paddle/fluid/lite/core/type_system.h
浏览文件 @
621d1522
...
...
@@ -131,6 +131,23 @@ class Type : public DataTypeBase {
bool
operator
==
(
const
Type
&
other
)
{
return
id_
==
other
.
id
()
&&
place_
==
other
.
place
();
}
friend
std
::
ostream
&
operator
<<
(
std
::
ostream
&
os
,
const
Type
&
other
)
{
if
(
other
.
IsUnsupported
())
{
os
<<
"<Unsupported>"
;
return
os
;
}
if
(
other
.
IsVoid
())
{
os
<<
"<Void>"
;
return
os
;
}
if
(
other
.
is_tensor_
)
{
os
<<
"<Tensor:"
;
}
os
<<
TargetToStr
(
other
.
target
())
<<
"/"
<<
PrecisionToStr
(
other
.
precision
())
<<
"/"
<<
DataLayoutToStr
(
other
.
layout
())
<<
">"
;
return
os
;
}
// Can cast to another type. This is heavily used in MIR, by determine whether
// is is possible to add a instruction to transform a type to another.
...
...
@@ -163,29 +180,33 @@ class Type : public DataTypeBase {
};
// -------------------------------- compatible check ---------------------------
static
bool
TargetCompatible
(
const
Type
&
a
,
const
Type
&
b
)
{
return
(
a
.
IsVoid
()
||
b
.
IsVoid
())
||
//
a
.
target
()
==
b
.
target
();
static
bool
TargetCompatibleTo
(
const
Type
&
a
,
const
Type
&
b
)
{
return
a
.
IsVoid
()
||
//
(
a
.
IsTensor
()
&&
b
.
IsTensor
()
&&
(
a
.
target
()
==
b
.
target
()
||
//
b
.
target
()
==
TARGET
(
kAny
)));
}
static
bool
DataLayoutCompatible
(
const
Type
&
a
,
const
Type
&
b
)
{
return
(
a
.
IsVoid
()
||
b
.
IsVoid
())
||
//
(
a
.
IsTensor
()
&&
b
.
IsTensor
()
&&
a
.
layout
()
==
b
.
layout
());
static
bool
DataLayoutCompatibleTo
(
const
Type
&
a
,
const
Type
&
b
)
{
return
a
.
IsVoid
()
||
//
(
a
.
IsTensor
()
&&
b
.
IsTensor
()
&&
(
a
.
layout
()
==
b
.
layout
()
||
//
b
.
layout
()
==
DATALAYOUT
(
kAny
)));
}
static
bool
PrecisionCompatible
(
const
Type
&
a
,
const
Type
&
b
)
{
return
(
a
.
IsVoid
()
||
b
.
IsVoid
())
||
//
(
a
.
precision
()
==
b
.
precision
());
static
bool
PrecisionCompatibleTo
(
const
Type
&
a
,
const
Type
&
b
)
{
return
a
.
IsVoid
()
||
//
(
a
.
IsTensor
()
&&
b
.
IsTensor
()
&&
(
a
.
precision
()
==
b
.
precision
()
||
//
b
.
precision
()
==
PRECISION
(
kAny
)));
}
static
bool
DeviceCompatible
(
const
Type
&
a
,
const
Type
&
b
)
{
return
(
a
.
IsVoid
()
||
b
.
IsVoid
()
)
||
//
(
a
.
device
()
==
b
.
device
(
));
static
bool
DeviceCompatible
To
(
const
Type
&
a
,
const
Type
&
b
)
{
return
a
.
IsVoid
(
)
||
//
(
a
.
IsTensor
()
&&
b
.
IsTensor
()
&&
(
a
.
device
()
==
b
.
device
()
));
}
static
bool
TypeCompatible
(
const
Type
&
a
,
const
Type
&
b
)
{
return
TargetCompatible
(
a
,
b
)
&&
DataLayoutCompatible
(
a
,
b
)
&&
PrecisionCompatible
(
a
,
b
)
&&
DeviceCompatible
(
a
,
b
);
// Can type 'a' be passed to 'b' directly.
static
bool
TypeCompatibleTo
(
const
Type
&
a
,
const
Type
&
b
)
{
return
TargetCompatibleTo
(
a
,
b
)
&&
DataLayoutCompatibleTo
(
a
,
b
)
&&
PrecisionCompatibleTo
(
a
,
b
)
&&
DeviceCompatibleTo
(
a
,
b
);
}
// -------------------------------- predefined types ---------------------------
...
...
@@ -230,6 +251,9 @@ class TensorInt64NCHWTy : public Type {
:
Type
(
ID
::
Tensor_Int64_NCHW
,
"TensorInt64NCHW"
,
true
/*is_tensor*/
,
target
,
PrecisionType
::
kInt8
,
DataLayoutType
::
kNCHW
)
{}
};
const
Type
*
LookupType
(
DataTypeBase
::
ID
type_id
,
bool
is_unknown
,
bool
is_tensor
,
Place
place
);
// ------------------------- end predefined types ---------------------------
// NOTE TypeSystem has some overhead, and better to be used in analysis phase.
...
...
@@ -381,13 +405,15 @@ class ParamTypeRegistry {
CHECK
(
types_
.
count
(
key
));
}
template
<
IO
io
>
const
ParamType
*
Retrieve
(
const
Place
&
place
,
const
std
::
string
&
op_type
,
const
std
::
string
&
arg_name
)
{
KernelIdTy
key
{
op_type
,
place
,
io
,
arg_name
};
auto
it
=
types_
.
find
(
key
);
if
(
it
==
types_
.
end
())
return
nullptr
;
return
&
it
->
second
;
const
ParamType
*
RetrieveInArgument
(
const
Place
&
place
,
const
std
::
string
&
op_type
,
const
std
::
string
&
arg_name
)
{
return
Retrieve
<
IO
::
kInput
>
(
place
,
op_type
,
arg_name
);
}
const
ParamType
*
RetrieveOutArgument
(
const
Place
&
place
,
const
std
::
string
&
op_type
,
const
std
::
string
&
arg_name
)
{
return
Retrieve
<
IO
::
kOutput
>
(
place
,
op_type
,
arg_name
);
}
static
ParamTypeRegistry
&
Global
()
{
...
...
@@ -403,6 +429,16 @@ class ParamTypeRegistry {
return
os
;
}
protected:
template
<
IO
io
>
const
ParamType
*
Retrieve
(
const
Place
&
place
,
const
std
::
string
&
op_type
,
const
std
::
string
&
arg_name
)
{
KernelIdTy
key
{
op_type
,
place
,
io
,
arg_name
};
auto
it
=
types_
.
find
(
key
);
if
(
it
==
types_
.
end
())
return
nullptr
;
return
&
it
->
second
;
}
private:
ParamTypeRegistry
()
=
default
;
...
...
paddle/fluid/lite/core/types.cc
浏览文件 @
621d1522
...
...
@@ -43,6 +43,9 @@ bool KernelPickFactor::IsTargetConsidered() const {
bool
KernelPickFactor
::
IsDataLayoutConsidered
()
const
{
return
data_
&
static_cast
<
int
>
(
Factor
::
DataLayoutFirst
);
}
bool
KernelPickFactor
::
IsDeviceConsidered
()
const
{
return
data_
&
static_cast
<
int
>
(
Factor
::
DeviceFirst
);
}
}
// namespace core
}
// namespace lite
...
...
paddle/fluid/lite/core/types.h
浏览文件 @
621d1522
...
...
@@ -14,6 +14,7 @@
#pragma once
#include <stack>
#include "paddle/fluid/lite/core/context.h"
#include "paddle/fluid/lite/utils/all.h"
...
...
@@ -38,6 +39,7 @@ class KernelPickFactor {
bool
AnyFactorConsidered
()
const
{
return
data_
;
}
KernelPickFactor
&
ConsiderTarget
();
// Perfer a specific target, e.g. prefer CUDA kernels.
KernelPickFactor
&
ConsiderPrecision
();
KernelPickFactor
&
ConsiderDataLayout
();
KernelPickFactor
&
ConsiderDevice
();
...
...
@@ -45,12 +47,29 @@ class KernelPickFactor {
bool
IsTargetConsidered
()
const
;
bool
IsPrecisionConsidered
()
const
;
bool
IsDataLayoutConsidered
()
const
;
bool
IsDeviceConsidered
()
const
{
return
data_
&
static_cast
<
int
>
(
Factor
::
DeviceFirst
);
bool
IsDeviceConsidered
()
const
;
friend
std
::
ostream
&
operator
<<
(
std
::
ostream
&
os
,
const
KernelPickFactor
&
k
)
{
std
::
stack
<
bool
>
bits
;
auto
data
=
k
.
data_
;
while
(
data
)
{
bits
.
push
(
data
%
2
);
data
/=
2
;
}
int
nbits
=
bits
.
size
();
for
(
size_t
i
=
0
;
i
<
sizeof
(
data
)
*
8
-
nbits
;
i
++
)
{
os
<<
0
;
}
while
(
!
bits
.
empty
())
{
os
<<
bits
.
top
();
bits
.
pop
();
}
return
os
;
}
private:
unsigned
char
data_
{};
TargetType
target_
{
TARGET
(
kUnk
)};
};
struct
dim2
{
...
...
paddle/fluid/lite/cuda/target_wrapper.cc
浏览文件 @
621d1522
...
...
@@ -12,10 +12,6 @@
// See the License for the specific language governing permissions and
// limitations under the License.
//
// Created by chunwei on 19-2-23.
//
#include "paddle/fluid/lite/cuda/target_wrapper.h"
#include <glog/logging.h>
...
...
@@ -24,19 +20,14 @@ namespace lite {
using
TargetW
=
TargetWrapper
<
TARGET
(
kCUDA
),
cudaStream_t
,
cudaEvent_t
>
;
template
<
>
void
*
TargetW
::
Malloc
(
size_t
size
)
{
void
*
ptr
{};
CHECK_EQ
(
cudaSuccess
,
cudaMalloc
(
&
ptr
,
size
));
return
ptr
;
}
template
<
>
void
TargetW
::
Free
(
void
*
ptr
)
{
CHECK_EQ
(
cudaSuccess
,
cudaFree
(
ptr
));
}
void
TargetW
::
Free
(
void
*
ptr
)
{
CHECK_EQ
(
cudaSuccess
,
cudaFree
(
ptr
));
}
template
<
>
void
TargetW
::
MemcpySync
(
void
*
dst
,
const
void
*
src
,
size_t
size
,
IoDirection
dir
)
{
switch
(
dir
)
{
...
...
@@ -55,7 +46,6 @@ void TargetW::MemcpySync(void* dst, const void* src, size_t size,
}
}
template
<
>
void
TargetW
::
MemcpyAsync
(
void
*
dst
,
const
void
*
src
,
size_t
size
,
IoDirection
dir
,
const
stream_t
&
stream
)
{
switch
(
dir
)
{
...
...
paddle/fluid/lite/kernels/cuda/CMakeLists.txt
浏览文件 @
621d1522
cc
_library
(
mul_compute_cuda SRCS mul_compute.cc DEPS tensor_lite
)
nv
_library
(
mul_compute_cuda SRCS mul_compute.cc DEPS tensor_lite
)
cc_library
(
io_copy_compute_cuda SRCS io_copy_compute.cc DEPS tensor_lite
)
nv_library
(
kernels_cuda DEPS mul_compute_cuda io_copy_compute_cuda
)
paddle/fluid/lite/kernels/cuda/io_copy_compute.cc
浏览文件 @
621d1522
...
...
@@ -21,7 +21,7 @@ namespace lite {
namespace
kernels
{
namespace
cuda
{
using
TargetW
=
TargetWrapper
<
TARGET
(
k
Host
),
cudaStream_t
,
cudaEvent_t
>
;
using
TargetW
=
TargetWrapper
<
TARGET
(
k
CUDA
),
cudaStream_t
,
cudaEvent_t
>
;
// Host to CUDA memory.
void
CopyFromHostSync
(
void
*
target
,
const
void
*
source
,
size_t
size
)
{
...
...
@@ -51,6 +51,25 @@ class IoCopyHostToCudaCompute
auto
*
data
=
param
.
y
->
mutable_data
(
target
(),
param
.
x
->
memory_size
());
CopyFromHostSync
(
data
,
param
.
x
->
data
<
void
>
(),
param
.
x
->
memory_size
());
}
std
::
unique_ptr
<
type_infer_handler_t
>
GetTypeInferHandler
()
override
{
std
::
unique_ptr
<
type_infer_handler_t
>
res
(
new
type_infer_handler_t
);
*
res
=
[](
const
std
::
map
<
std
::
string
,
const
Type
*>&
inputs
,
const
std
::
string
&
out
)
->
const
Type
*
{
CHECK
(
!
inputs
.
empty
());
auto
*
type
=
inputs
.
at
(
"Input"
);
CHECK
(
type
->
target
()
==
TARGET
(
kHost
));
auto
out_place
=
type
->
place
();
out_place
.
target
=
TARGET
(
kCUDA
);
auto
*
out_type
=
LookupType
(
type
->
id
(),
type
->
IsUnsupported
(),
type
->
IsUnsupported
(),
out_place
);
return
out_type
;
};
return
res
;
}
std
::
string
doc
()
const
override
{
return
"Copy IO from HOST to CUDA"
;
}
};
/*
...
...
@@ -65,6 +84,8 @@ class IoCopyCudaToHostCompute
auto
*
data
=
param
.
y
->
mutable_data
(
TARGET
(
kHost
),
param
.
x
->
memory_size
());
CopyToHostSync
(
data
,
param
.
x
->
data
<
void
>
(),
param
.
x
->
memory_size
());
}
std
::
string
doc
()
const
override
{
return
"Copy IO from CUDA to HOST"
;
}
};
}
// namespace cuda
...
...
@@ -72,7 +93,7 @@ class IoCopyCudaToHostCompute
}
// namespace lite
}
// namespace paddle
REGISTER_LITE_KERNEL
(
io_copy
,
kCUDA
,
kAny
,
REGISTER_LITE_KERNEL
(
io_copy
,
kCUDA
,
kAny
,
kAny
,
paddle
::
lite
::
kernels
::
cuda
::
IoCopyHostToCudaCompute
,
host_to_device
)
.
BindInput
(
"Input"
,
{
paddle
::
lite
::
Type
::
Get
<
paddle
::
lite
::
TensorAnyTy
>
(
...
...
@@ -81,7 +102,7 @@ REGISTER_LITE_KERNEL(io_copy, kCUDA, kAny,
TARGET
(
kCUDA
))})
.
Finalize
();
REGISTER_LITE_KERNEL
(
io_copy
,
kCUDA
,
kAny
,
REGISTER_LITE_KERNEL
(
io_copy
,
kCUDA
,
kAny
,
kAny
,
paddle
::
lite
::
kernels
::
cuda
::
IoCopyCudaToHostCompute
,
device_to_host
)
.
BindInput
(
"Input"
,
{
paddle
::
lite
::
Type
::
Get
<
paddle
::
lite
::
TensorAnyTy
>
(
...
...
paddle/fluid/lite/kernels/cuda/mul_compute.cc
浏览文件 @
621d1522
...
...
@@ -13,3 +13,22 @@
// limitations under the License.
#include "paddle/fluid/lite/kernels/cuda/mul_compute.h"
#include "paddle/fluid/lite/core/op_registry.h"
namespace
paddle
{
namespace
lite
{
namespace
kernels
{
namespace
cuda
{}
// namespace cuda
}
// namespace kernels
}
// namespace lite
}
// namespace paddle
REGISTER_LITE_KERNEL
(
mul
,
kCUDA
,
kFloat
,
kNCHW
,
paddle
::
lite
::
kernels
::
cuda
::
MulCompute
,
def
)
.
BindInput
(
"X"
,
{
paddle
::
lite
::
Type
::
Get
<
paddle
::
lite
::
TensorFp32NCHWTy
>
(
TARGET
(
kCUDA
))})
.
BindInput
(
"Y"
,
{
paddle
::
lite
::
Type
::
Get
<
paddle
::
lite
::
TensorFp32NCHWTy
>
(
TARGET
(
kCUDA
))})
.
BindOutput
(
"Out"
,
{
paddle
::
lite
::
Type
::
Get
<
paddle
::
lite
::
TensorFp32NCHWTy
>
(
TARGET
(
kCUDA
))})
.
Finalize
();
paddle/fluid/lite/kernels/cuda/mul_compute.h
浏览文件 @
621d1522
...
...
@@ -16,6 +16,7 @@
#include "paddle/fluid/lite/core/kernel.h"
#include "paddle/fluid/lite/core/types.h"
#include "paddle/fluid/lite/cuda/blas.h"
#include "paddle/fluid/lite/operators/op_params.h"
namespace
paddle
{
namespace
lite
{
...
...
@@ -29,11 +30,29 @@ void mul_compute(const lite::cuda::Blas<float>& blas, const T* x, int x_h,
nullptr
,
out
,
0
);
}
class
MulCompute
:
public
OpKernel
<
TARGET
(
k
Host
),
PRECISION
(
kFloat
)
>
{
class
MulCompute
:
public
OpKernel
<
TARGET
(
k
CUDA
),
PRECISION
(
kFloat
)
>
{
public:
using
param_t
=
operators
::
MulParam
;
void
Run
()
override
{}
void
Run
()
override
{
CHECK
(
context_
)
<<
"running context should be set first"
;
auto
&
context
=
context_
->
AsCudaContext
();
CHECK
(
context
.
blas_fp32
)
<<
"blas should init first"
;
auto
&
blas
=
*
context
.
blas_fp32
;
const
auto
&
param
=
Param
<
operators
::
MulParam
>
();
CHECK
(
param
.
x
->
target
()
==
TARGET
(
kCUDA
));
auto
*
x
=
param
.
x
->
data
<
float
>
();
int
x_h
=
param
.
x
->
dims
()[
0
];
int
x_w
=
param
.
x
->
dims
()[
1
];
auto
*
y
=
param
.
y
->
data
<
float
>
();
int
y_h
=
param
.
y
->
dims
()[
0
];
int
y_w
=
param
.
y
->
dims
()[
1
];
auto
*
out
=
param
.
output
->
mutable_data
<
float
>
(
TARGET
(
kCUDA
));
mul_compute
<
float
>
(
blas
,
x
,
x_h
,
x_w
,
y
,
y_h
,
y_w
,
out
);
}
virtual
~
MulCompute
()
=
default
;
};
...
...
paddle/fluid/lite/kernels/host/fc_compute.cc
浏览文件 @
621d1522
...
...
@@ -51,8 +51,8 @@ void FcCompute::Run() {
}
// namespace lite
}
// namespace paddle
REGISTER_LITE_KERNEL
(
fc
,
kHost
,
kFloat
,
paddle
::
lite
::
kernels
::
host
::
FcCompute
,
def
)
REGISTER_LITE_KERNEL
(
fc
,
kHost
,
kFloat
,
kNCHW
,
paddle
::
lite
::
kernels
::
host
::
FcCompute
,
def
)
.
BindInput
(
"Input"
,
{
paddle
::
lite
::
Type
::
Get
<
paddle
::
lite
::
TensorFp32NCHWTy
>
(
TARGET
(
kHost
))})
...
...
paddle/fluid/lite/kernels/host/feed_compute.cc
浏览文件 @
621d1522
...
...
@@ -20,7 +20,8 @@ namespace lite {
namespace
kernels
{
namespace
host
{
class
FeedCompute
:
public
OpKernel
<
TARGET
(
kHost
),
PRECISION
(
kFloat
)
>
{
class
FeedCompute
:
public
OpKernel
<
TARGET
(
kHost
),
PRECISION
(
kAny
),
DATALAYOUT
(
kAny
)
>
{
public:
using
param_t
=
operators
::
FeedParam
;
...
...
@@ -38,7 +39,7 @@ class FeedCompute : public OpKernel<TARGET(kHost), PRECISION(kFloat)> {
}
// namespace lite
}
// namespace paddle
REGISTER_LITE_KERNEL
(
feed
,
kHost
,
k
Float
,
REGISTER_LITE_KERNEL
(
feed
,
kHost
,
k
Any
,
kAny
,
paddle
::
lite
::
kernels
::
host
::
FeedCompute
,
def
)
.
BindInput
(
"X"
,
{
paddle
::
lite
::
Type
::
Get
<
paddle
::
lite
::
TensorAnyTy
>
(
TARGET
(
kHost
))})
...
...
paddle/fluid/lite/kernels/host/fetch_compute.cc
浏览文件 @
621d1522
...
...
@@ -20,7 +20,8 @@ namespace lite {
namespace
kernels
{
namespace
host
{
class
FetchCompute
:
public
OpKernel
<
TARGET
(
kHost
),
PRECISION
(
kFloat
)
>
{
class
FetchCompute
:
public
OpKernel
<
TARGET
(
kHost
),
PRECISION
(
kAny
),
DATALAYOUT
(
kAny
)
>
{
public:
using
param_t
=
operators
::
FeedParam
;
...
...
@@ -41,7 +42,7 @@ class FetchCompute : public OpKernel<TARGET(kHost), PRECISION(kFloat)> {
}
// namespace lite
}
// namespace paddle
REGISTER_LITE_KERNEL
(
fetch
,
kHost
,
k
Float
,
REGISTER_LITE_KERNEL
(
fetch
,
kHost
,
k
Any
,
kAny
,
paddle
::
lite
::
kernels
::
host
::
FetchCompute
,
def
)
.
BindInput
(
"X"
,
{
paddle
::
lite
::
Type
::
Get
<
paddle
::
lite
::
TensorAnyTy
>
(
TARGET
(
kHost
))})
...
...
paddle/fluid/lite/kernels/host/mul_compute.cc
浏览文件 @
621d1522
...
...
@@ -67,7 +67,7 @@ class MulCompute : public OpKernel<TARGET(kHost), PRECISION(kFloat)> {
}
// namespace lite
}
// namespace paddle
REGISTER_LITE_KERNEL
(
mul
,
kHost
,
kFloat
,
REGISTER_LITE_KERNEL
(
mul
,
kHost
,
kFloat
,
kNCHW
,
paddle
::
lite
::
kernels
::
host
::
MulCompute
,
def
)
.
BindInput
(
"X"
,
{
paddle
::
lite
::
Type
::
Get
<
paddle
::
lite
::
TensorFp32NCHWTy
>
(
TARGET
(
kHost
))})
...
...
paddle/fluid/lite/kernels/host/relu_compute.h
浏览文件 @
621d1522
...
...
@@ -42,6 +42,6 @@ class ReluCompute : public OpKernel<TARGET(kHost), PRECISION(kFloat)> {
}
// namespace lite
}
// namespace paddle
REGISTER_LITE_KERNEL
(
relu
,
kHost
,
kFloat
,
REGISTER_LITE_KERNEL
(
relu
,
kHost
,
kFloat
,
kNCHW
,
paddle
::
lite
::
kernels
::
host
::
ReluCompute
,
def
)
.
Finalize
();
paddle/fluid/lite/kernels/host/scale_compute.cc
浏览文件 @
621d1522
...
...
@@ -50,7 +50,7 @@ class ScaleCompute : public OpKernel<TARGET(kHost), PRECISION(kFloat)> {
}
// namespace lite
}
// namespace paddle
REGISTER_LITE_KERNEL
(
scale
,
kHost
,
kFloat
,
REGISTER_LITE_KERNEL
(
scale
,
kHost
,
kFloat
,
kNCHW
,
paddle
::
lite
::
kernels
::
host
::
ScaleCompute
,
def
)
.
BindInput
(
"X"
,
{
paddle
::
lite
::
Type
::
Get
<
paddle
::
lite
::
TensorFp32NCHWTy
>
(
TARGET
(
kHost
))})
...
...
paddle/fluid/lite/operators/io_copy_op.cc
浏览文件 @
621d1522
...
...
@@ -24,7 +24,10 @@ bool IoCopyOp::CheckShape() const {
CHECK_OR_FALSE
(
param_
.
y
);
return
true
;
}
bool
IoCopyOp
::
InferShape
()
const
{
return
true
;
}
bool
IoCopyOp
::
InferShape
()
const
{
param_
.
y
->
Resize
(
param_
.
x
->
dims
());
return
true
;
}
bool
IoCopyOp
::
Run
()
{
return
OpLite
::
Run
();
}
bool
IoCopyOp
::
AttachImpl
(
const
paddle
::
framework
::
OpDesc
&
opdesc
,
paddle
::
lite
::
Scope
*
scope
)
{
...
...
paddle/fluid/lite/operators/mul_op.h
浏览文件 @
621d1522
...
...
@@ -51,9 +51,6 @@ class MulOpLite : public OpLite {
param_
.
x_num_col_dims
=
boost
::
get
<
int
>
(
op_desc
.
GetAttr
(
"x_num_col_dims"
));
param_
.
y_num_col_dims
=
boost
::
get
<
int
>
(
op_desc
.
GetAttr
(
"y_num_col_dims"
));
CHECK
(
kernel_
);
kernel_
->
SetParam
(
param_
);
return
true
;
}
...
...
paddle/fluid/lite/utils/factory.h
浏览文件 @
621d1522
...
...
@@ -13,6 +13,7 @@
// limitations under the License.
#pragma once
#include <glog/logging.h>
#include <iostream>
#include <list>
#include <memory>
...
...
@@ -48,8 +49,6 @@ class Factory {
}
void
Register
(
const
std
::
string
&
op_type
,
creator_t
&&
creator
)
{
CHECK
(
!
creators_
.
count
(
op_type
))
<<
"The op "
<<
op_type
<<
" has already registered"
;
creators_
[
op_type
].
emplace_back
(
std
::
move
(
creator
));
}
...
...
@@ -58,9 +57,9 @@ class Factory {
}
std
::
list
<
item_ptr_t
>
Creates
(
const
std
::
string
&
op_type
)
const
{
auto
it
=
creators_
.
find
(
op_type
);
CHECK
(
it
!=
creators_
.
end
())
<<
"no item called "
<<
op_type
;
std
::
list
<
item_ptr_t
>
res
;
auto
it
=
creators_
.
find
(
op_type
);
if
(
it
==
creators_
.
end
())
return
res
;
for
(
auto
&
c
:
it
->
second
)
{
res
.
emplace_back
(
c
());
}
...
...
paddle/fluid/lite/utils/varient.h
浏览文件 @
621d1522
...
...
@@ -99,7 +99,7 @@ struct variant {
size_t
type
()
{
return
type_id
;
}
void
valid
()
{
return
(
type_id
!=
invalid_type
());
}
bool
valid
()
{
return
(
type_id
!=
invalid_type
());
}
template
<
typename
T
,
typename
...
Args
>
void
set
(
Args
&&
...
args
)
{
...
...
paddle/fluid/lite/utils/varient_test.cc
浏览文件 @
621d1522
...
...
@@ -24,6 +24,8 @@ namespace utils {
TEST
(
varient
,
test
)
{
variant
<
int
,
float
>
a
;
// The initial state should be invalid.
ASSERT_FALSE
(
a
.
valid
());
a
.
set
<
int
>
(
1
);
ASSERT_EQ
(
a
.
get
<
int
>
(),
1
);
a
.
set
<
int
>
(
20
);
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录