Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
Crayon鑫
Paddle
提交
621d1522
P
Paddle
项目概览
Crayon鑫
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
621d1522
编写于
4月 27, 2019
作者:
S
superjomn
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
make io_copy kernel pick works
上级
1fb93746
变更
51
隐藏空白更改
内联
并排
Showing
51 changed file
with
1218 addition
and
354 deletion
+1218
-354
paddle/fluid/framework/op_desc.h
paddle/fluid/framework/op_desc.h
+1
-0
paddle/fluid/lite/api/CMakeLists.txt
paddle/fluid/lite/api/CMakeLists.txt
+6
-2
paddle/fluid/lite/api/cxx_api.h
paddle/fluid/lite/api/cxx_api.h
+2
-1
paddle/fluid/lite/api/cxx_api_test.cc
paddle/fluid/lite/api/cxx_api_test.cc
+26
-6
paddle/fluid/lite/core/CMakeLists.txt
paddle/fluid/lite/core/CMakeLists.txt
+1
-0
paddle/fluid/lite/core/kernel.cc
paddle/fluid/lite/core/kernel.cc
+7
-0
paddle/fluid/lite/core/kernel.h
paddle/fluid/lite/core/kernel.h
+60
-8
paddle/fluid/lite/core/memory.cc
paddle/fluid/lite/core/memory.cc
+1
-1
paddle/fluid/lite/core/memory.h
paddle/fluid/lite/core/memory.h
+5
-2
paddle/fluid/lite/core/mir/CMakeLists.txt
paddle/fluid/lite/core/mir/CMakeLists.txt
+3
-0
paddle/fluid/lite/core/mir/argument_type_display_pass.cc
paddle/fluid/lite/core/mir/argument_type_display_pass.cc
+45
-0
paddle/fluid/lite/core/mir/generate_program_pass.cc
paddle/fluid/lite/core/mir/generate_program_pass.cc
+3
-0
paddle/fluid/lite/core/mir/io_complement_pass.cc
paddle/fluid/lite/core/mir/io_complement_pass.cc
+152
-18
paddle/fluid/lite/core/mir/io_complement_pass.h
paddle/fluid/lite/core/mir/io_complement_pass.h
+30
-1
paddle/fluid/lite/core/mir/io_copy_kernel_pick_pass.cc
paddle/fluid/lite/core/mir/io_copy_kernel_pick_pass.cc
+74
-0
paddle/fluid/lite/core/mir/node.h
paddle/fluid/lite/core/mir/node.h
+38
-10
paddle/fluid/lite/core/mir/passes.h
paddle/fluid/lite/core/mir/passes.h
+2
-0
paddle/fluid/lite/core/mir/ssa_graph.cc
paddle/fluid/lite/core/mir/ssa_graph.cc
+138
-0
paddle/fluid/lite/core/mir/ssa_graph.h
paddle/fluid/lite/core/mir/ssa_graph.h
+64
-82
paddle/fluid/lite/core/mir/static_kernel_pick_pass.cc
paddle/fluid/lite/core/mir/static_kernel_pick_pass.cc
+3
-2
paddle/fluid/lite/core/mir/static_kernel_pick_pass.h
paddle/fluid/lite/core/mir/static_kernel_pick_pass.h
+19
-2
paddle/fluid/lite/core/mir/variable_place_inference_pass.cc
paddle/fluid/lite/core/mir/variable_place_inference_pass.cc
+1
-1
paddle/fluid/lite/core/mir/variable_place_inference_pass.h
paddle/fluid/lite/core/mir/variable_place_inference_pass.h
+27
-20
paddle/fluid/lite/core/op_lite.cc
paddle/fluid/lite/core/op_lite.cc
+11
-6
paddle/fluid/lite/core/op_lite.h
paddle/fluid/lite/core/op_lite.h
+47
-32
paddle/fluid/lite/core/op_registry.cc
paddle/fluid/lite/core/op_registry.cc
+41
-14
paddle/fluid/lite/core/op_registry.h
paddle/fluid/lite/core/op_registry.h
+92
-59
paddle/fluid/lite/core/optimizer.cc
paddle/fluid/lite/core/optimizer.cc
+29
-0
paddle/fluid/lite/core/optimizer.h
paddle/fluid/lite/core/optimizer.h
+25
-1
paddle/fluid/lite/core/program.h
paddle/fluid/lite/core/program.h
+14
-8
paddle/fluid/lite/core/target_wrapper.h
paddle/fluid/lite/core/target_wrapper.h
+51
-11
paddle/fluid/lite/core/type_system.cc
paddle/fluid/lite/core/type_system.cc
+29
-8
paddle/fluid/lite/core/type_system.h
paddle/fluid/lite/core/type_system.h
+58
-22
paddle/fluid/lite/core/types.cc
paddle/fluid/lite/core/types.cc
+3
-0
paddle/fluid/lite/core/types.h
paddle/fluid/lite/core/types.h
+21
-2
paddle/fluid/lite/cuda/target_wrapper.cc
paddle/fluid/lite/cuda/target_wrapper.cc
+1
-11
paddle/fluid/lite/kernels/cuda/CMakeLists.txt
paddle/fluid/lite/kernels/cuda/CMakeLists.txt
+3
-1
paddle/fluid/lite/kernels/cuda/io_copy_compute.cc
paddle/fluid/lite/kernels/cuda/io_copy_compute.cc
+24
-3
paddle/fluid/lite/kernels/cuda/mul_compute.cc
paddle/fluid/lite/kernels/cuda/mul_compute.cc
+19
-0
paddle/fluid/lite/kernels/cuda/mul_compute.h
paddle/fluid/lite/kernels/cuda/mul_compute.h
+21
-2
paddle/fluid/lite/kernels/host/fc_compute.cc
paddle/fluid/lite/kernels/host/fc_compute.cc
+2
-2
paddle/fluid/lite/kernels/host/feed_compute.cc
paddle/fluid/lite/kernels/host/feed_compute.cc
+3
-2
paddle/fluid/lite/kernels/host/fetch_compute.cc
paddle/fluid/lite/kernels/host/fetch_compute.cc
+3
-2
paddle/fluid/lite/kernels/host/mul_compute.cc
paddle/fluid/lite/kernels/host/mul_compute.cc
+1
-1
paddle/fluid/lite/kernels/host/relu_compute.h
paddle/fluid/lite/kernels/host/relu_compute.h
+1
-1
paddle/fluid/lite/kernels/host/scale_compute.cc
paddle/fluid/lite/kernels/host/scale_compute.cc
+1
-1
paddle/fluid/lite/operators/io_copy_op.cc
paddle/fluid/lite/operators/io_copy_op.cc
+4
-1
paddle/fluid/lite/operators/mul_op.h
paddle/fluid/lite/operators/mul_op.h
+0
-3
paddle/fluid/lite/utils/factory.h
paddle/fluid/lite/utils/factory.h
+3
-4
paddle/fluid/lite/utils/varient.h
paddle/fluid/lite/utils/varient.h
+1
-1
paddle/fluid/lite/utils/varient_test.cc
paddle/fluid/lite/utils/varient_test.cc
+2
-0
未找到文件。
paddle/fluid/framework/op_desc.h
浏览文件 @
621d1522
...
...
@@ -42,6 +42,7 @@ class OpDesc {
void
CopyFrom
(
const
OpDesc
&
op_desc
);
proto
::
OpDesc
*
Proto
();
const
proto
::
OpDesc
&
ReadonlyProto
()
const
{
return
desc_
;
}
std
::
string
Type
()
const
{
return
desc_
.
type
();
}
...
...
paddle/fluid/lite/api/CMakeLists.txt
浏览文件 @
621d1522
cc_library
(
cxx_api_lite SRCS cxx_api.cc DEPS scope_lite op_executor_lite host_kernels ops_lite optimizer_lite target_wrapper_host
)
cc_library
(
cxx_api_lite SRCS cxx_api.cc DEPS scope_lite op_executor_lite host_kernels ops_lite optimizer_lite target_wrapper_host
)
cc_test
(
test_cxx_api_lite SRCS cxx_api_test.cc DEPS cxx_api_lite model_parser_lite
)
if
(
LITE_WITH_CUDA
)
cc_library
(
cxx_api_lite_cuda SRCS cxx_api.cc DEPS scope_lite op_executor_lite host_kernels ops_lite optimizer_lite target_wrapper_host target_wrapper_cuda kernels_cuda
)
nv_test
(
test_cxx_api_lite_cuda SRCS cxx_api_test.cc DEPS cxx_api_lite_cuda model_parser_lite
)
endif
()
paddle/fluid/lite/api/cxx_api.h
浏览文件 @
621d1522
...
...
@@ -29,7 +29,7 @@ class Predictor {
public:
Predictor
()
{
scope_
=
std
::
make_shared
<
Scope
>
();
}
void
Build
(
const
std
::
string
&
model_path
,
void
Build
(
const
std
::
string
&
model_path
,
const
Place
&
prefer_place
,
const
std
::
vector
<
Place
>&
valid_places
)
{
framework
::
proto
::
ProgramDesc
prog
;
LoadModel
(
model_path
,
scope_
.
get
(),
&
prog
);
...
...
@@ -38,6 +38,7 @@ class Predictor {
Program
program
(
prog_desc
,
scope_
,
valid_places
);
Optimizer
optimizer
;
optimizer
.
KernelPickPreferPlace
(
prefer_place
);
core
::
KernelPickFactor
factor
;
factor
.
ConsiderTarget
();
optimizer
.
Run
(
std
::
move
(
program
),
valid_places
,
factor
);
...
...
paddle/fluid/lite/api/cxx_api_test.cc
浏览文件 @
621d1522
...
...
@@ -23,8 +23,21 @@ namespace lite {
TEST
(
CXXApi
,
test
)
{
lite
::
Predictor
predictor
;
#ifndef LITE_WITH_CUDA
std
::
vector
<
Place
>
valid_places
({
Place
{
TARGET
(
kHost
),
PRECISION
(
kFloat
)}});
#else
std
::
vector
<
Place
>
valid_places
({
Place
{
TARGET
(
kHost
),
PRECISION
(
kFloat
),
DATALAYOUT
(
kNCHW
)},
Place
{
TARGET
(
kCUDA
),
PRECISION
(
kFloat
),
DATALAYOUT
(
kNCHW
)},
Place
{
TARGET
(
kCUDA
),
PRECISION
(
kAny
),
DATALAYOUT
(
kNCHW
)},
Place
{
TARGET
(
kHost
),
PRECISION
(
kAny
),
DATALAYOUT
(
kNCHW
)},
Place
{
TARGET
(
kCUDA
),
PRECISION
(
kAny
),
DATALAYOUT
(
kAny
)},
Place
{
TARGET
(
kHost
),
PRECISION
(
kAny
),
DATALAYOUT
(
kAny
)},
});
#endif
predictor
.
Build
(
"/home/chunwei/project2/models/model2"
,
{
Place
{
TARGET
(
kHost
),
PRECISION
(
kFloat
)}}
);
Place
{
TARGET
(
kCUDA
),
PRECISION
(
kFloat
)},
valid_places
);
auto
*
input_tensor
=
predictor
.
GetInput
(
0
);
input_tensor
->
Resize
({
100
,
100
});
...
...
@@ -54,8 +67,15 @@ USE_LITE_OP(fc);
USE_LITE_OP
(
scale
);
USE_LITE_OP
(
feed
);
USE_LITE_OP
(
fetch
);
USE_LITE_KERNEL
(
fc
,
kHost
,
kFloat
,
def
);
USE_LITE_KERNEL
(
mul
,
kHost
,
kFloat
,
def
);
USE_LITE_KERNEL
(
scale
,
kHost
,
kFloat
,
def
);
USE_LITE_KERNEL
(
feed
,
kHost
,
kFloat
,
def
);
USE_LITE_KERNEL
(
fetch
,
kHost
,
kFloat
,
def
);
USE_LITE_OP
(
io_copy
);
USE_LITE_KERNEL
(
fc
,
kHost
,
kFloat
,
kNCHW
,
def
);
USE_LITE_KERNEL
(
mul
,
kHost
,
kFloat
,
kNCHW
,
def
);
USE_LITE_KERNEL
(
scale
,
kHost
,
kFloat
,
kNCHW
,
def
);
USE_LITE_KERNEL
(
feed
,
kHost
,
kAny
,
kAny
,
def
);
USE_LITE_KERNEL
(
fetch
,
kHost
,
kAny
,
kAny
,
def
);
#ifdef LITE_WITH_CUDA
USE_LITE_KERNEL
(
mul
,
kCUDA
,
kFloat
,
kNCHW
,
def
);
USE_LITE_KERNEL
(
io_copy
,
kCUDA
,
kAny
,
kAny
,
host_to_device
);
USE_LITE_KERNEL
(
io_copy
,
kCUDA
,
kAny
,
kAny
,
device_to_host
);
#endif
paddle/fluid/lite/core/CMakeLists.txt
浏览文件 @
621d1522
...
...
@@ -27,5 +27,6 @@ cc_test(test_tensor_lite SRCS tensor_test.cc)
cc_test
(
test_op_executor_lite SRCS op_executor_test.cc DEPS op_executor_lite ops_lite host_kernels
)
cc_test
(
test_type_system SRCS type_system_test.cc DEPS type_system
)
cc_test
(
test_optimizer_lite SRCS optimizer_test.cc DEPS mir_pass_manager program_fake_utils mir_passes
)
cc_test
(
test_types_lite SRCS types_test.cc DEPS types_lite
)
add_subdirectory
(
mir
)
paddle/fluid/lite/core/kernel.cc
浏览文件 @
621d1522
...
...
@@ -17,6 +17,13 @@
namespace
paddle
{
namespace
lite
{
std
::
string
KernelBase
::
summary
()
const
{
std
::
stringstream
ss
;
ss
<<
op_type
()
<<
":"
<<
TargetToStr
(
target
())
<<
"/"
<<
PrecisionToStr
(
precision
())
<<
"/"
<<
DataLayoutToStr
(
layout
());
return
ss
.
str
();
}
bool
ParamTypeRegistry
::
KeyCmp
::
operator
()(
const
ParamTypeRegistry
::
key_t
&
a
,
const
ParamTypeRegistry
::
key_t
&
b
)
const
{
...
...
paddle/fluid/lite/core/kernel.h
浏览文件 @
621d1522
...
...
@@ -17,6 +17,7 @@
#include <map>
#include <memory>
#include <set>
#include <sstream>
#include <string>
#include <vector>
#include "paddle/fluid/framework/op_desc.h"
...
...
@@ -34,49 +35,100 @@ namespace lite {
// different targets.
class
KernelBase
{
public:
// type_infer_handler is used to inference a output type by considering the
// input types in the type system.
using
type_infer_handler_t
=
std
::
function
<
const
Type
*
(
const
std
::
map
<
std
::
string
,
const
Type
*>&
input_types
,
const
std
::
string
&
out_arg
)
>
;
virtual
void
Run
()
=
0
;
void
SetContext
(
std
::
unique_ptr
<
KernelContext
>&&
ctx
)
{
context_
=
std
::
move
(
ctx
);
}
template
<
typename
T
>
void
SetParam
(
T
param
)
{
param_
.
set
<
T
>
(
param
);
}
template
<
typename
P
>
P
&
Param
()
const
{
return
param_
.
get
<
P
>
();
}
// This is used in the kernels that takes 'kAny' places and inference the
// output place. For `ScaleCompute` and `IoCopyCompute`, their input types are
// declared as 'kAny' in some Place field, and the output is also `kAny`, but
// when in real execution, when takes some non-kAny type as input, the
// output's kAny-fields can be determained. For example, when the
// `ScaleCompute` takes `TensorFp32NCHWTy` as input, its output should be also
// `TensorFp32NCHWTy`. This type inference rule is different for each kernel,
// so we make it a virtual method.
// One can custom this handler to make a specific type inference rule for a
// kernel, or leave the default to force the kernel use the system's
// type-inference rules.
virtual
std
::
unique_ptr
<
type_infer_handler_t
>
GetTypeInferHandler
()
{
return
nullptr
;
}
void
set_op_type
(
const
std
::
string
&
type
)
{
op_type_
=
type
;
}
const
std
::
string
&
op_type
()
const
{
return
op_type_
;
}
void
Torch
()
{}
// Get input declaration type.
const
Type
*
GetInputDeclType
(
const
std
::
string
&
arg_name
)
{
CHECK
(
!
op_type_
.
empty
())
<<
"op_type should be set first"
;
const
auto
*
type
=
ParamTypeRegistry
::
Global
().
RetrieveInArgument
(
place
(),
GenParamTypeKey
(),
arg_name
);
CHECK
(
type
)
<<
"no type registered for kernel ["
<<
op_type_
<<
"] input argument ["
<<
arg_name
<<
"]"
<<
" with key "
<<
GenParamTypeKey
();
return
type
->
type
;
}
// Get output declaration type.
const
Type
*
GetOutputDeclType
(
const
std
::
string
&
arg_name
)
{
CHECK
(
!
op_type_
.
empty
())
<<
"op_type should be set first"
;
const
auto
*
type
=
ParamTypeRegistry
::
Global
().
RetrieveOutArgument
(
place
(),
GenParamTypeKey
(),
arg_name
);
CHECK
(
type
)
<<
"no type registered for kernel ["
<<
op_type_
<<
"] output argument ["
<<
arg_name
<<
"]"
;
return
type
->
type
;
}
void
set_alias
(
const
std
::
string
&
x
)
{
alias_
=
x
;
LOG
(
INFO
)
<<
"kernel "
<<
op_type
()
<<
" setting alias "
<<
alias
();
}
const
std
::
string
&
alias
()
const
{
return
alias_
;
}
virtual
Place
place
()
const
=
0
;
virtual
TargetType
target
()
const
=
0
;
virtual
PrecisionType
precision
()
const
=
0
;
virtual
DataLayoutType
layout
()
const
=
0
;
const
KernelContext
*
context
()
const
{
return
context_
.
get
();
}
virtual
std
::
string
name
()
const
=
0
;
virtual
~
KernelBase
()
=
default
;
// Short human-readable document.
std
::
string
summary
()
const
;
// Long human-readable document.
virtual
std
::
string
doc
()
const
{
return
""
;
}
std
::
string
DebugString
()
const
{
std
::
string
GenParamTypeKey
()
const
{
std
::
stringstream
ss
;
ss
<<
op_type
()
<<
":"
<<
TargetToStr
(
target
())
<<
"/"
<<
PrecisionToStr
(
precision
())
<<
"/"
<<
DataLayoutToStr
(
layout
())
;
LOG
(
INFO
)
<<
"alias : "
<<
alias_
;
ss
<<
op_type
()
<<
"/"
<<
alias_
;
return
ss
.
str
();
}
virtual
~
KernelBase
()
=
default
;
protected:
std
::
unique_ptr
<
KernelContext
>
context_
;
mutable
operators
::
param_t
param_
;
// The corresponding op type.
std
::
string
op_type_
;
std
::
string
op_type_
{};
std
::
string
alias_
{};
};
// Light-weight kernel implementation.
...
...
paddle/fluid/lite/core/memory.cc
浏览文件 @
621d1522
...
...
@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "memory.h"
#include "
paddle/fluid/lite/core/
memory.h"
namespace
paddle
{
namespace
framework
{
...
...
paddle/fluid/lite/core/memory.h
浏览文件 @
621d1522
...
...
@@ -14,7 +14,7 @@
#pragma once
#include <glog/logging.h>
#include "target_wrapper.h"
#include "
paddle/fluid/lite/core/
target_wrapper.h"
namespace
paddle
{
namespace
lite
{
...
...
@@ -26,9 +26,12 @@ static void* TargetMalloc(TargetType target, size_t size) {
case
TargetType
::
kX86
:
data
=
TargetWrapper
<
TARGET
(
kHost
)
>::
Malloc
(
size
);
break
;
#ifdef LITE_WITH_CUDA
case
TargetType
::
kCUDA
:
data
=
TargetWrapper
<
TARGET
(
kCUDA
)
>::
Malloc
(
size
);
data
=
TargetWrapper
<
TARGET
(
kCUDA
),
cudaStream_t
,
cudaEvent_t
>::
Malloc
(
size
);
break
;
#endif // LITE_WITH_CUDA
default:
LOG
(
FATAL
)
<<
"Unknown supported target "
<<
TargetToStr
(
target
);
}
...
...
paddle/fluid/lite/core/mir/CMakeLists.txt
浏览文件 @
621d1522
...
...
@@ -7,9 +7,12 @@ cc_library(mir_passes
SRCS static_kernel_pick_pass.cc
variable_place_inference_pass.cc
io_complement_pass.cc
io_copy_kernel_pick_pass.cc
graph_visualize_pass.cc
generate_program_pass.cc
argument_type_display_pass.cc
demo_pass.cc
runtime_context_assign_pass.cc
DEPS mir_pass types_lite
)
cc_test
(
test_mir_pass_manager SRCS pass_manager_test.cc DEPS mir_pass_manager mir_passes
)
...
...
paddle/fluid/lite/core/mir/argument_type_display_pass.cc
0 → 100644
浏览文件 @
621d1522
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/lite/core/mir/pass.h"
#include "paddle/fluid/lite/core/mir/pass_registry.h"
namespace
paddle
{
namespace
lite
{
namespace
mir
{
class
ArgumentTypeDisplayPass
:
public
DebugPass
{
public:
void
Apply
(
std
::
unique_ptr
<
mir
::
SSAGraph
>&
graph
)
override
{
LOG
(
INFO
)
<<
"== Argument types =="
;
for
(
auto
&
node
:
graph
->
mutable_nodes
())
{
if
(
!
node
.
IsArgument
())
continue
;
auto
*
type
=
node
.
AsArgument
().
type
;
if
(
type
)
{
LOG
(
INFO
)
<<
"* ARG "
<<
node
.
AsArgument
().
name
<<
" type: "
<<
*
type
;
}
else
{
LOG
(
INFO
)
<<
"* ARG "
<<
node
.
AsArgument
().
name
<<
" type: UNK"
;
}
}
LOG
(
INFO
)
<<
"---------------------"
;
}
};
}
// namespace mir
}
// namespace lite
}
// namespace paddle
REGISTER_MIR_PASS
(
argument_type_display_pass
,
paddle
::
lite
::
mir
::
ArgumentTypeDisplayPass
);
paddle/fluid/lite/core/mir/generate_program_pass.cc
浏览文件 @
621d1522
...
...
@@ -13,6 +13,7 @@
// limitations under the License.
#include "paddle/fluid/lite/core/mir/generate_program_pass.h"
#include "paddle/fluid/lite/core/mir/graph_visualize_pass.h"
#include "paddle/fluid/lite/core/mir/pass_registry.h"
namespace
paddle
{
...
...
@@ -20,9 +21,11 @@ namespace lite {
namespace
mir
{
void
GenerateProgramPass
::
Apply
(
std
::
unique_ptr
<
mir
::
SSAGraph
>&
graph
)
{
LOG
(
INFO
)
<<
"final program
\n
"
<<
Visualize
(
graph
.
get
());
for
(
auto
&
item
:
graph
->
InstructTopologicalOrder
())
{
if
(
item
->
IsInstruct
())
{
auto
&
instruct
=
item
->
AsInstruct
();
LOG
(
INFO
)
<<
instruct
;
insts_
.
emplace_back
(
instruct
.
op
,
std
::
move
(
instruct
.
valid_kernels
.
front
()));
}
...
...
paddle/fluid/lite/core/mir/io_complement_pass.cc
浏览文件 @
621d1522
...
...
@@ -13,6 +13,7 @@
// limitations under the License.
#include "paddle/fluid/lite/core/mir/io_complement_pass.h"
#include "paddle/fluid/lite/core/mir/graph_visualize_pass.h"
#include "paddle/fluid/lite/core/mir/pass_registry.h"
namespace
paddle
{
...
...
@@ -21,28 +22,161 @@ namespace mir {
void
IoComplementPass
::
Apply
(
std
::
unique_ptr
<
mir
::
SSAGraph
>&
graph
)
{
// Start from inputs of the graph, those should have place set.
std
::
list
<
Node
*>
nodes
;
for
(
auto
&
node
:
graph
->
mutable_nodes
())
{
if
(
!
node
.
IsInstruct
())
continue
;
auto
&
inst
=
node
.
AsInstruct
();
// inputs
for
(
auto
*
in
:
node
.
inlinks
)
{
CHECK
(
in
->
IsArgument
());
auto
name
=
in
->
AsArgument
().
name
;
std
::
string
tmp
;
CHECK
(
inst
.
op_info
->
GetInputArgname
(
name
,
&
tmp
));
auto
type
=
ParamTypeRegistry
::
Global
().
Retrieve
<
ParamTypeRegistry
::
IO
::
kInput
>
(
inst
.
place
,
inst
.
op_type
,
tmp
);
CHECK
(
type
)
<<
"no param type found for "
<<
inst
.
op_type
<<
":"
<<
name
<<
" "
<<
inst
.
place
;
CHECK
(
type
->
type
);
CHECK
(
in
->
AsArgument
().
type
);
if
(
!
TypeCompatible
(
*
type
->
type
,
*
in
->
AsArgument
().
type
))
{
LOG
(
INFO
)
<<
"found IO unmatched tensor: "
<<
in
->
AsArgument
().
name
;
nodes
.
push_back
(
&
node
);
}
CHECK
(
!
valid_places_
.
empty
());
for
(
auto
&
node
:
nodes
)
{
if
(
!
node
->
IsInstruct
())
continue
;
auto
inlinks
=
node
->
inlinks
;
for
(
auto
*
in
:
inlinks
)
{
ComplementInputs
(
graph
.
get
(),
node
,
in
);
}
}
// PickIoCopyKernel(graph.get());
LOG
(
INFO
)
<<
"
\n
"
<<
Visualize
(
graph
.
get
());
}
void
IoComplementPass
::
ComplementInputs
(
SSAGraph
*
graph
,
Node
*
inst_node
,
Node
*
in
)
{
// If this input is out of date.
if
(
inst_node
->
inlinks
.
end
()
==
std
::
find
(
inst_node
->
inlinks
.
begin
(),
inst_node
->
inlinks
.
end
(),
in
))
return
;
CHECK
(
inst_node
->
IsInstruct
());
auto
&
inst
=
inst_node
->
AsInstruct
();
CHECK
(
in
->
IsRoleSet
());
CHECK
(
in
->
IsArgument
());
auto
in_arg_name
=
in
->
AsArgument
().
name
;
std
::
string
tmp
;
CHECK
(
inst
.
op_info
()
->
GetInputArgname
(
in_arg_name
,
&
tmp
));
auto
decl_arg_type
=
inst
.
picked_kernel
().
GetInputDeclType
(
tmp
);
CHECK
(
in
->
AsArgument
().
type
);
if
(
!
TypeCompatibleTo
(
*
in
->
AsArgument
().
type
,
*
decl_arg_type
))
{
LOG
(
INFO
)
<<
"found IO unmatched tensor: "
<<
in
->
AsArgument
().
name
<<
" for kernel "
<<
inst
.
op
->
DebugString
()
<<
" "
<<
*
in
->
AsArgument
().
type
<<
" -> "
<<
*
decl_arg_type
;
// Add an IoCopy instruction to make the input compatible with other dist.
AddIoCopyInst
(
*
in
->
AsArgument
().
type
,
*
decl_arg_type
,
in
->
AsArgument
().
name
,
graph
,
inst_node
,
valid_places_
);
}
}
void
UpdateOpdescInputName
(
framework
::
OpDesc
*
desc
,
const
std
::
string
&
old_arg_name
,
const
std
::
string
&
new_arg_name
)
{
for
(
auto
&
item
:
*
desc
->
Proto
()
->
mutable_inputs
())
{
for
(
int
i
=
0
;
i
<
item
.
mutable_arguments
()
->
size
();
i
++
)
{
auto
*
arg
=
item
.
mutable_arguments
(
i
);
if
(
*
arg
==
old_arg_name
)
{
*
arg
=
new_arg_name
;
}
}
}
}
void
IoComplementPass
::
AddIoCopyInst
(
const
Type
&
from
,
const
Type
&
to
,
const
std
::
string
&
var
,
SSAGraph
*
graph
,
Node
*
inst_node
,
const
std
::
vector
<
Place
>&
valid_places
)
{
CHECK
(
!
valid_places
.
empty
())
<<
"valid_place should be set"
;
// var -> new_transform_op -> new_var -> inst
// So there will be a new Argument node and a new IoCopy Instruct Node.
auto
node_id
=
[
&
]
{
return
graph
->
nodes
().
size
();
};
auto
io_copy_output_name
=
var
+
"/trans/"
+
std
::
to_string
(
node_id
());
auto
*
io_copy_output_arg
=
graph
->
NewArgumentNode
(
io_copy_output_name
);
auto
*
io_copy_inst
=
graph
->
NewInstructNode
();
// create Op and kernels.
auto
io_copy_op
=
LiteOpRegistry
::
Global
().
Create
(
"io_copy"
);
// CHECK(io_copy_op);
// Create the new var manually.
inst_node
->
AsInstruct
().
op
->
scope
()
->
Var
(
io_copy_output_name
);
// Create IoCopy Instruction.
framework
::
OpDesc
op_desc
;
op_desc
.
SetType
(
"io_copy"
);
op_desc
.
SetInput
(
"Input"
,
{
var
});
op_desc
.
SetOutput
(
"Out"
,
{
io_copy_output_name
});
op_desc
.
Flush
();
io_copy_op
->
Attach
(
op_desc
,
inst_node
->
AsInstruct
().
op
->
scope
());
auto
kernels
=
io_copy_op
->
CreateKernels
(
valid_places
);
io_copy_inst
->
AsInstruct
(
"io_copy"
,
std
::
move
(
kernels
),
io_copy_op
);
// Remove the old link
RemoveDirectedLink
(
graph
->
Argument
(
var
),
inst_node
);
// Update the original instruction OpDesc.
// Update its input to the io_copy_output_name
auto
&
inst
=
inst_node
->
AsInstruct
();
auto
inst_program_desc
=
inst
.
op_info
()
->
desc
();
// Add new link, var -> new_inst, new_inst->newarg, newarg->inst
DirectedLink
(
graph
->
Argument
(
var
),
io_copy_inst
);
DirectedLink
(
io_copy_inst
,
io_copy_output_arg
);
DirectedLink
(
io_copy_output_arg
,
inst_node
);
// reset opdesc and update kernel information
auto
desc_dummy
=
inst_node
->
AsInstruct
().
op
->
op_info
()
->
desc
();
UpdateInputTo
(
&
desc_dummy
,
var
,
io_copy_output_name
);
framework
::
OpDesc
desc_fake
(
desc_dummy
,
nullptr
);
inst_node
->
AsInstruct
().
op
->
Attach
(
desc_fake
,
inst_node
->
AsInstruct
().
op
->
scope
());
std
::
string
tmp
;
if
(
inst_node
->
AsInstruct
().
op_info
()
->
GetInputArgname
(
"a"
,
&
tmp
))
{
CHECK
(
false
)
<<
"get old a "
<<
tmp
;
}
for
(
auto
&
kernel
:
inst_node
->
AsInstruct
().
valid_kernels
)
{
inst_node
->
AsInstruct
().
op
->
AttachKernel
(
kernel
.
get
());
}
graph
->
CheckValid
();
}
void
IoComplementPass
::
PickIoCopyKernel
(
SSAGraph
*
graph
)
{
for
(
auto
&
node
:
graph
->
mutable_nodes
())
{
if
(
node
.
IsInstruct
()
&&
node
.
AsInstruct
().
op_type
==
"io_copy"
)
{
auto
&
kernels
=
node
.
AsInstruct
().
valid_kernels
;
CHECK
(
!
kernels
.
empty
())
<<
"No valid kernels found for IoCopy Op"
;
for
(
auto
&
kernel
:
kernels
)
{
CHECK_EQ
(
node
.
inlinks
.
size
(),
1UL
);
CHECK_EQ
(
node
.
outlinks
.
size
(),
1UL
);
auto
*
inty
=
node
.
inlinks
.
front
()
->
AsArgument
().
type
;
auto
*
outy
=
node
.
outlinks
.
front
()
->
AsArgument
().
type
;
const
Type
*
in_arg_ty
=
kernel
->
GetInputDeclType
(
"Input"
);
if
(
TypeCompatibleTo
(
*
inty
,
*
in_arg_ty
))
{
const
Type
*
out_arg_ty
=
kernel
->
GetOutputDeclType
(
"Out"
);
// Both the input and output type matches, remove other kernels
// directly.
if
(
out_arg_ty
->
target
()
==
outy
->
target
())
{
LOG
(
INFO
)
<<
"get a IOCopy kernel"
;
auto
x
=
std
::
move
(
kernel
);
kernels
.
clear
();
kernels
.
emplace_back
(
std
::
move
(
x
));
break
;
}
}
}
}
}
// Check the compatiblity.
}
void
IoComplementPass
::
SetValidPlaces
(
const
std
::
vector
<
Place
>&
valid_places
)
{
CHECK
(
!
valid_places
.
empty
());
valid_places_
=
valid_places
;
}
}
// namespace mir
...
...
paddle/fluid/lite/core/mir/io_complement_pass.h
浏览文件 @
621d1522
...
...
@@ -15,18 +15,47 @@
#pragma once
#include "paddle/fluid/lite/core/mir/pass.h"
#include "paddle/fluid/lite/core/op_registry.h"
namespace
paddle
{
namespace
lite
{
namespace
mir
{
static
void
UpdateInputTo
(
framework
::
proto
::
OpDesc
*
desc
,
const
std
::
string
&
from
,
const
std
::
string
&
to
)
{
for
(
auto
&
item
:
*
desc
->
mutable_inputs
())
{
for
(
auto
&
input
:
*
item
.
mutable_arguments
())
{
if
(
input
==
from
)
{
LOG
(
INFO
)
<<
"** update input argument from "
<<
from
<<
" to "
<<
to
;
input
=
to
;
}
}
}
}
/*
* IoComplementPass complement the necessary instruction to make data
* transferring or transformation between different places.
*/
class
IoComplementPass
:
public
ProgramPass
{
public:
void
Apply
(
std
::
unique_ptr
<
mir
::
SSAGraph
>
&
graph
)
override
;
void
Apply
(
std
::
unique_ptr
<
mir
::
SSAGraph
>&
graph
)
override
;
void
ComplementInputs
(
SSAGraph
*
graph
,
Node
*
inst_node
,
Node
*
in
);
void
AddIoCopyInst
(
const
Type
&
from
,
const
Type
&
to
,
const
std
::
string
&
var
,
SSAGraph
*
graph
,
Node
*
inst_node
,
const
std
::
vector
<
Place
>&
valid_places
);
void
SetValidPlaces
(
const
std
::
vector
<
Place
>&
valid_places
);
// Pick the right kernel of IoCopy considering the input and output Type.
void
PickIoCopyKernel
(
SSAGraph
*
graph
);
const
std
::
vector
<
Place
>&
valid_places
()
const
{
return
valid_places_
;
};
private:
std
::
vector
<
Place
>
valid_places_
;
};
}
// namespace mir
...
...
paddle/fluid/lite/core/mir/io_copy_kernel_pick_pass.cc
0 → 100644
浏览文件 @
621d1522
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/lite/core/mir/pass.h"
#include "paddle/fluid/lite/core/mir/pass_registry.h"
namespace
paddle
{
namespace
lite
{
namespace
mir
{
class
IoCopyKernelPickPass
:
public
InstructionPass
{
public:
void
Apply
(
std
::
unique_ptr
<
mir
::
SSAGraph
>&
graph
)
override
{
for
(
auto
&
node
:
graph
->
mutable_nodes
())
{
if
(
!
node
.
IsInstruct
())
continue
;
auto
&
inst
=
node
.
AsInstruct
();
if
(
inst
.
op_type
!=
"io_copy"
)
continue
;
LOG
(
INFO
)
<<
"....> picking a IO COPY kernel"
;
auto
&
kernels
=
node
.
AsInstruct
().
valid_kernels
;
CHECK
(
!
kernels
.
empty
())
<<
"No valid kernels found for IoCopy Op"
;
const
auto
*
inty
=
node
.
inlinks
.
front
()
->
AsArgument
().
type
;
const
auto
*
outy
=
node
.
outlinks
.
front
()
->
AsArgument
().
type
;
LOG
(
INFO
)
<<
"input type "
<<
*
inty
;
LOG
(
INFO
)
<<
"output type "
<<
*
outy
;
bool
is_found
=
false
;
LOG
(
INFO
)
<<
"kernels size "
<<
kernels
.
size
();
for
(
auto
&
kernel
:
kernels
)
{
CHECK_EQ
(
node
.
inlinks
.
size
(),
1UL
);
CHECK_EQ
(
node
.
outlinks
.
size
(),
1UL
);
const
Type
*
in_arg_ty
=
kernel
->
GetInputDeclType
(
"Input"
);
const
Type
*
out_arg_ty
=
kernel
->
GetOutputDeclType
(
"Out"
);
LOG
(
INFO
)
<<
"checking kernel candidate "
<<
*
in_arg_ty
<<
"->"
<<
*
out_arg_ty
;
if
(
inty
->
target
()
==
in_arg_ty
->
target
())
{
// Both the input and output type matches, remove other kernels
// directly.
if
(
out_arg_ty
->
target
()
==
outy
->
target
())
{
LOG
(
INFO
)
<<
"get a IOCopy kernel"
;
auto
x
=
std
::
move
(
kernel
);
kernels
.
clear
();
kernels
.
emplace_back
(
std
::
move
(
x
));
is_found
=
true
;
break
;
}
}
}
CHECK
(
is_found
)
<<
"Can't find a IoCopy kernel for IO: "
<<
*
inty
<<
"->"
<<
*
outy
;
}
}
};
}
// namespace mir
}
// namespace lite
}
// namespace paddle
REGISTER_MIR_PASS
(
io_copy_kernel_pick_pass
,
paddle
::
lite
::
mir
::
IoCopyKernelPickPass
);
paddle/fluid/lite/core/mir/node.h
浏览文件 @
621d1522
...
...
@@ -34,30 +34,43 @@ class Node {
Node
()
=
default
;
enum
class
Role
{
kUnk
=
-
1
,
kArgument
,
kArgument
=
0
,
kInstruct
,
kNumRoles
/*should be last*/
kNumRoles
,
/*should be last*/
kUnk
,
};
struct
Instruct
{
std
::
string
op_type
;
Place
place
;
// The kernel instances this Instruct contains.
std
::
vector
<
std
::
unique_ptr
<
KernelBase
>>
valid_kernels
;
std
::
shared_ptr
<
OpInfo
>
op_info
;
// TODO(Superjomn) make this a shared_ptr for resource safety.
std
::
shared_ptr
<
OpLite
>
op
;
// we hold op to run InferShape
const
OpInfo
*
op_info
()
{
CHECK
(
op
);
return
op
->
op_info
();
}
Place
place
()
const
{
CHECK
(
!
valid_kernels
.
empty
());
return
valid_kernels
.
front
()
->
place
();
}
KernelBase
&
picked_kernel
()
{
CHECK
(
!
valid_kernels
.
empty
());
return
*
valid_kernels
.
front
();
}
friend
std
::
ostream
&
operator
<<
(
std
::
ostream
&
os
,
const
Instruct
&
other
)
{
os
<<
"Instruct "
<<
other
.
op_type
<<
" "
<<
other
.
place
();
return
os
;
}
};
struct
Argument
{
std
::
string
name
;
const
Type
*
type
;
const
Type
*
type
{}
;
// Weight is a special kind of argument, it is marked as weight explicitly
// so that some weight related optimization can take place.
bool
is_weight
{
false
};
...
...
@@ -71,13 +84,11 @@ class Node {
Instruct
&
AsInstruct
(
const
std
::
string
&
op_type
,
std
::
vector
<
std
::
unique_ptr
<
KernelBase
>>&&
kernels
,
const
std
::
shared_ptr
<
OpLite
>&
op
,
const
std
::
shared_ptr
<
lite
::
OpInfo
>&
op_info
)
{
const
std
::
shared_ptr
<
OpLite
>&
op
)
{
auto
&
x
=
AsInstruct
();
x
.
op_type
=
op_type
;
x
.
op
=
op
;
x
.
valid_kernels
=
std
::
move
(
kernels
);
x
.
op_info
=
op_info
;
return
x
;
}
...
...
@@ -100,8 +111,25 @@ class Node {
instruct_
.
reset
(
new
Instruct
);
return
*
instruct_
;
}
friend
std
::
ostream
&
operator
<<
(
std
::
ostream
&
os
,
Node
&
other
)
{
os
<<
static_cast
<
int
>
(
other
.
role_
)
<<
" "
;
if
(
!
other
.
IsRoleSet
())
{
os
<<
"Unk role node"
;
}
if
(
other
.
IsArgument
())
{
auto
&
arg
=
other
.
AsArgument
();
os
<<
"Argument "
<<
arg
.
name
;
}
if
(
other
.
IsInstruct
())
{
auto
&
arg
=
other
.
AsInstruct
();
os
<<
"Instruct "
<<
arg
.
op_type
;
}
return
os
;
}
// Check roles.
bool
IsRoleSet
()
const
{
return
role_
=
=
Role
::
kUnk
;
}
bool
IsRoleSet
()
const
{
return
role_
!
=
Role
::
kUnk
;
}
bool
IsInstruct
()
const
{
return
role_
==
Role
::
kInstruct
;
}
bool
IsArgument
()
const
{
return
role_
==
Role
::
kArgument
;
}
...
...
paddle/fluid/lite/core/mir/passes.h
浏览文件 @
621d1522
...
...
@@ -26,3 +26,5 @@ USE_MIR_PASS(static_kernel_pick_pass);
USE_MIR_PASS
(
variable_place_inference_pass
);
USE_MIR_PASS
(
io_complement_pass
);
USE_MIR_PASS
(
generate_program_pass
);
USE_MIR_PASS
(
io_copy_kernel_pick_pass
);
USE_MIR_PASS
(
argument_type_display_pass
);
paddle/fluid/lite/core/mir/ssa_graph.cc
浏览文件 @
621d1522
...
...
@@ -89,6 +89,144 @@ std::vector<mir::Node *> SSAGraph::InstructTopologicalOrder() {
return
res
;
}
void
SSAGraph
::
GraphCreateTmpVarNodes
(
const
Program
&
program
)
{
for
(
const
auto
&
name
:
program
.
tmp_vars
)
{
LOG
(
INFO
)
<<
"create arg node "
<<
name
;
node_storage_
.
emplace_back
();
auto
&
new_node
=
node_storage_
.
back
();
new_node
.
AsArgument
(
name
);
arguments_
[
name
]
=
&
new_node
;
}
}
void
SSAGraph
::
GraphCreateWeightVarNodes
(
const
Program
&
program
)
{
// create weight nodes.
for
(
const
auto
&
name
:
program
.
weights
)
{
LOG
(
INFO
)
<<
"create arg node "
<<
name
;
node_storage_
.
emplace_back
();
auto
&
new_node
=
node_storage_
.
back
();
new_node
.
AsArgument
(
name
);
arguments_
[
name
]
=
&
new_node
;
}
}
Node
*
SSAGraph
::
GraphCreateInstructNode
(
const
Program
&
program
,
const
std
::
shared_ptr
<
OpLite
>
&
op
,
const
std
::
vector
<
Place
>
&
valid_places
)
{
node_storage_
.
emplace_back
();
// TODO(Superjomn) remove one valid_places here.
op
->
SetValidPlaces
(
valid_places
);
auto
&
new_node
=
node_storage_
.
back
();
auto
kernels
=
op
->
CreateKernels
(
valid_places
);
node_storage_
.
back
().
AsInstruct
(
op
->
op_type_
,
std
::
move
(
kernels
),
op
);
CHECK
(
new_node
.
inlinks
.
empty
())
<<
"duplicate Build found"
;
CHECK
(
new_node
.
outlinks
.
empty
())
<<
"duplicate Build found"
;
return
&
node_storage_
.
back
();
}
void
SSAGraph
::
Build
(
const
Program
&
program
,
const
std
::
vector
<
Place
>
&
valid_places
)
{
CHECK
(
node_storage_
.
empty
());
GraphCreateTmpVarNodes
(
program
);
GraphCreateWeightVarNodes
(
program
);
CHECK
(
CheckNodesRoleSet
());
for
(
auto
&
op
:
program
.
ops
)
{
auto
*
op_node
=
GraphCreateInstructNode
(
program
,
op
,
valid_places
);
LOG
(
INFO
)
<<
"checking op "
<<
op
->
op_type_
;
for
(
const
std
::
string
&
name
:
op
->
op_info
()
->
input_names
())
{
auto
*
arg
=
Argument
(
name
);
LOG
(
INFO
)
<<
"input "
<<
name
;
CHECK
(
arg
->
IsRoleSet
());
DirectedLink
(
arg
,
op_node
);
}
for
(
const
std
::
string
&
name
:
op
->
op_info
()
->
output_names
())
{
if
(
!
arguments_
.
count
(
name
))
{
NewArgumentNode
(
name
);
}
LOG
(
INFO
)
<<
"output "
<<
name
;
auto
*
arg
=
arguments_
.
at
(
name
);
CHECK
(
arg
->
IsRoleSet
());
DirectedLink
(
op_node
,
arg
);
}
CHECK
(
CheckLinksRoleSet
());
}
MarkArgumentWeights
(
program
);
CheckValid
();
}
mir
::
Node
*
SSAGraph
::
Argument
(
const
std
::
string
&
name
)
{
auto
it
=
arguments_
.
find
(
name
);
CHECK
(
it
!=
arguments_
.
end
())
<<
"no argument called "
<<
name
;
return
it
->
second
;
}
std
::
vector
<
mir
::
Node
*>
SSAGraph
::
inputs
()
{
std
::
vector
<
mir
::
Node
*>
res
;
for
(
auto
&
node
:
node_storage_
)
{
if
(
node
.
inlinks
.
empty
())
{
res
.
push_back
(
&
node
);
}
}
return
res
;
}
std
::
vector
<
mir
::
Node
*>
SSAGraph
::
outputs
()
{
std
::
vector
<
mir
::
Node
*>
res
;
for
(
auto
&
node
:
node_storage_
)
{
if
(
node
.
outlinks
.
empty
())
{
res
.
push_back
(
&
node
);
}
}
return
res
;
}
mir
::
Node
*
SSAGraph
::
RetrieveArgument
(
const
std
::
string
&
arg
)
{
auto
it
=
arguments_
.
find
(
arg
);
if
(
it
!=
arguments_
.
end
())
{
return
it
->
second
;
}
return
nullptr
;
}
bool
SSAGraph
::
CheckNodesRoleSet
()
{
for
(
auto
&
node
:
mutable_nodes
())
{
CHECK_OR_FALSE
(
node
.
IsRoleSet
());
}
return
true
;
}
bool
SSAGraph
::
CheckLinksRoleSet
()
{
for
(
auto
&
node
:
mutable_nodes
())
{
CHECK_OR_FALSE
(
node
.
IsRoleSet
());
if
(
!
node
.
IsInstruct
())
continue
;
for
(
auto
*
x
:
node
.
inlinks
)
{
CHECK_OR_FALSE
(
x
->
IsRoleSet
());
CHECK_OR_FALSE
(
x
->
IsArgument
());
}
for
(
auto
*
x
:
node
.
outlinks
)
{
CHECK_OR_FALSE
(
x
->
IsRoleSet
());
CHECK_OR_FALSE
(
x
->
IsArgument
());
}
}
return
true
;
}
Node
*
SSAGraph
::
NewArgumentNode
(
const
std
::
string
&
name
)
{
node_storage_
.
emplace_back
();
CHECK
(
!
arguments_
.
count
(
name
))
<<
"duplicate argument called "
<<
name
;
arguments_
[
name
]
=
&
node_storage_
.
back
();
node_storage_
.
back
().
AsArgument
(
name
);
return
&
node_storage_
.
back
();
}
Node
*
SSAGraph
::
NewInstructNode
()
{
node_storage_
.
emplace_back
();
return
&
node_storage_
.
back
();
}
}
// namespace mir
}
// namespace lite
}
// namespace paddle
paddle/fluid/lite/core/mir/ssa_graph.h
浏览文件 @
621d1522
...
...
@@ -35,104 +35,44 @@ class SSAGraph : GraphBase {
public:
// @param program: the op program
// @param valid_places: the valid places user set for the system.
void
Build
(
const
Program
&
program
,
const
std
::
vector
<
Place
>
&
valid_places
)
{
// create temporary nodes.
for
(
const
auto
&
name
:
program
.
tmp_vars
)
{
node_storage_
.
emplace_back
();
auto
&
new_node
=
node_storage_
.
back
();
auto
&
arg
=
new_node
.
AsArgument
();
arg
.
name
=
name
;
arguments_
[
name
]
=
&
new_node
;
}
// create weight nodes.
for
(
const
auto
&
name
:
program
.
weights
)
{
node_storage_
.
emplace_back
();
auto
&
new_node
=
node_storage_
.
back
();
auto
&
arg
=
new_node
.
AsArgument
();
arg
.
name
=
name
;
arguments_
[
name
]
=
&
new_node
;
}
void
Build
(
const
Program
&
program
,
const
std
::
vector
<
Place
>
&
valid_places
);
for
(
auto
&
op
:
program
.
ops
)
{
node_storage_
.
emplace_back
();
// TODO(Superjomn) remove one valid_places here.
op
->
SetValidPlaces
(
valid_places
);
auto
&
new_node
=
node_storage_
.
back
();
auto
kernels
=
op
->
CreateKernels
(
valid_places
);
node_storage_
.
back
().
AsInstruct
(
op
->
op_type_
,
std
::
move
(
kernels
),
op
,
op
->
op_info
());
CHECK
(
new_node
.
inlinks
.
empty
())
<<
"duplicate Build found"
;
CHECK
(
new_node
.
outlinks
.
empty
())
<<
"duplicate Build found"
;
// collect inputs and outputs
for
(
const
std
::
string
&
name
:
op
->
op_info
()
->
input_names
())
{
auto
*
arg
=
Argument
(
name
);
new_node
.
inlinks
.
push_back
(
arg
);
arg
->
outlinks
.
push_back
(
&
new_node
);
}
for
(
const
std
::
string
&
name
:
op
->
op_info
()
->
output_names
())
{
if
(
!
arguments_
.
count
(
name
))
{
node_storage_
.
emplace_back
();
auto
&
new_node
=
node_storage_
.
back
();
auto
&
arg
=
new_node
.
AsArgument
(
name
);
arg
.
name
=
name
;
arguments_
.
emplace
(
name
,
&
new_node
);
}
auto
*
arg
=
arguments_
.
at
(
name
);
new_node
.
outlinks
.
push_back
(
arg
);
arg
->
inlinks
.
push_back
(
&
new_node
);
}
}
MarkArgumentWeights
(
program
);
}
mir
::
Node
*
Argument
(
const
std
::
string
&
name
)
{
auto
it
=
arguments_
.
find
(
name
);
CHECK
(
it
!=
arguments_
.
end
())
<<
"no argument called "
<<
name
;
return
it
->
second
;
}
mir
::
Node
*
Argument
(
const
std
::
string
&
name
);
std
::
vector
<
mir
::
Node
*>
InstructTopologicalOrder
();
// The inputs of the graph.
std
::
vector
<
mir
::
Node
*>
inputs
()
{
std
::
vector
<
mir
::
Node
*>
res
;
for
(
auto
&
node
:
node_storage_
)
{
if
(
node
.
inlinks
.
empty
())
{
res
.
push_back
(
&
node
);
}
}
return
res
;
}
std
::
vector
<
mir
::
Node
*>
inputs
();
// The outputs of the graph.
std
::
vector
<
mir
::
Node
*>
outputs
()
{
std
::
vector
<
mir
::
Node
*>
res
;
for
(
auto
&
node
:
node_storage_
)
{
if
(
node
.
outlinks
.
empty
())
{
res
.
push_back
(
&
node
);
}
}
return
res
;
}
std
::
vector
<
mir
::
Node
*>
outputs
();
const
std
::
list
<
mir
::
Node
>
&
nodes
()
const
{
return
node_storage_
;
}
std
::
list
<
mir
::
Node
>
&
mutable_nodes
()
{
return
node_storage_
;
}
mir
::
Node
*
RetrieveArgument
(
const
std
::
string
&
arg
)
{
auto
it
=
arguments_
.
find
(
arg
);
if
(
it
!=
arguments_
.
end
())
{
return
it
->
second
;
}
return
nullptr
;
mir
::
Node
*
RetrieveArgument
(
const
std
::
string
&
arg
);
Node
*
NewArgumentNode
(
const
std
::
string
&
name
);
Node
*
NewInstructNode
();
void
CheckValid
()
{
CHECK
(
CheckBidirectionalConnection
());
CHECK
(
CheckNodesRoleSet
());
CHECK
(
CheckLinksRoleSet
());
}
private:
void
GraphCreateTmpVarNodes
(
const
Program
&
program
);
void
GraphCreateWeightVarNodes
(
const
Program
&
program
);
Node
*
GraphCreateInstructNode
(
const
Program
&
program
,
const
std
::
shared_ptr
<
OpLite
>
&
op
,
const
std
::
vector
<
Place
>
&
valid_places
);
// Check the bidirectional connection.
bool
CheckBidirectionalConnection
();
bool
CheckNodesRoleSet
();
// Check all the items's role in inlinks and outlinks is set.
bool
CheckLinksRoleSet
();
void
MarkArgumentWeights
(
const
Program
&
program
)
{
for
(
const
auto
&
name
:
program
.
weights
)
{
...
...
@@ -152,6 +92,48 @@ class SSAGraph : GraphBase {
std
::
map
<
std
::
string
,
mir
::
Node
*>
arguments_
;
};
// Remove the link between a -> b.
static
void
RemoveDirectedLink
(
Node
*
a
,
Node
*
b
)
{
auto
it
=
std
::
find
(
b
->
inlinks
.
begin
(),
b
->
inlinks
.
end
(),
a
);
if
(
it
!=
b
->
inlinks
.
end
())
{
b
->
inlinks
.
erase
(
it
);
}
auto
it1
=
std
::
find
(
a
->
outlinks
.
begin
(),
a
->
outlinks
.
end
(),
b
);
if
(
it1
!=
a
->
outlinks
.
end
())
{
a
->
outlinks
.
erase
((
it1
));
}
}
// Link a -> b.
static
void
DirectedLink
(
Node
*
a
,
Node
*
b
)
{
// Eagerly remove first, to avoid duplicate link.
RemoveDirectedLink
(
a
,
b
);
a
->
outlinks
.
push_back
(
b
);
b
->
inlinks
.
push_back
(
a
);
}
static
void
LocalInferenceType
(
Node
*
a
,
Node
*
b
,
const
std
::
string
&
arg_name
)
{
// instr -> output argument
if
(
a
->
IsInstruct
()
&&
b
->
IsArgument
())
{
auto
&
inst
=
a
->
AsInstruct
();
auto
&
output
=
b
->
AsArgument
();
if
(
!
output
.
type
)
{
output
.
type
=
inst
.
picked_kernel
().
GetOutputDeclType
(
arg_name
);
}
}
// input argument -> instr
if
(
a
->
IsArgument
()
&&
b
->
IsInstruct
())
{
auto
&
input
=
a
->
AsArgument
();
auto
&
inst
=
b
->
AsInstruct
();
if
(
!
input
.
type
)
{
input
.
type
=
inst
.
picked_kernel
().
GetInputDeclType
(
arg_name
);
}
}
}
}
// namespace mir
}
// namespace lite
}
// namespace paddle
paddle/fluid/lite/core/mir/static_kernel_pick_pass.cc
浏览文件 @
621d1522
...
...
@@ -37,7 +37,9 @@ void StaticKernelPickPass::Apply(std::unique_ptr<mir::SSAGraph>& graph) {
auto
&
instruct
=
node
.
AsInstruct
();
std
::
vector
<
std
::
pair
<
size_t
,
std
::
unique_ptr
<
KernelBase
>>>
scored
;
for
(
auto
&&
kernel
:
instruct
.
valid_kernels
)
{
scored
.
emplace_back
(
KernelGrade
(
*
kernel
),
std
::
move
(
kernel
));
size_t
score
=
KernelGrade
(
*
kernel
);
LOG
(
INFO
)
<<
"kernel "
<<
kernel
->
summary
()
<<
" "
<<
score
;
scored
.
emplace_back
(
score
,
std
::
move
(
kernel
));
}
std
::
sort
(
scored
.
begin
(),
scored
.
end
(),
KernelScoreCmp
);
...
...
@@ -47,7 +49,6 @@ void StaticKernelPickPass::Apply(std::unique_ptr<mir::SSAGraph>& graph) {
// TODO(Superjomn) reconsider this.
instruct
.
valid_kernels
.
clear
();
instruct
.
valid_kernels
.
emplace_back
(
std
::
move
(
scored
.
front
().
second
));
instruct
.
place
=
instruct
.
valid_kernels
.
front
()
->
place
();
LOG
(
INFO
)
<<
"pick "
<<
instruct
.
valid_kernels
.
front
()
->
name
();
}
}
...
...
paddle/fluid/lite/core/mir/static_kernel_pick_pass.h
浏览文件 @
621d1522
...
...
@@ -37,6 +37,7 @@ class StaticKernelPickPass : public mir::InstructionPass {
public:
void
Apply
(
std
::
unique_ptr
<
mir
::
SSAGraph
>&
graph
)
override
;
void
SetPreferPlace
(
const
Place
&
place
)
{
place_
=
place
;
}
const
Place
&
place
()
const
{
return
place_
;
}
const
core
::
KernelPickFactor
&
kernel_pick_factors
()
const
{
return
kernel_pick_factors_
;
...
...
@@ -51,16 +52,32 @@ class StaticKernelPickPass : public mir::InstructionPass {
size_t
score
{};
const
int
kMax
=
std
::
numeric_limits
<
core
::
KernelPickFactor
::
value_type
>::
max
();
// The more important factor comes first
if
(
kernel_pick_factors_
.
IsTargetConsidered
()
&&
place
().
target
==
kernel
.
target
())
{
(
place
().
target
==
kernel
.
target
()
||
kernel
.
target
()
==
TARGET
(
kAny
)
||
place
().
target
==
TARGET
(
kAny
)))
{
score
+=
kMax
/
static_cast
<
int
>
(
core
::
KernelPickFactor
::
Factor
::
TargetFirst
);
}
if
(
kernel_pick_factors_
.
IsPrecisionConsidered
()
&&
place
().
precision
==
kernel
.
precision
())
{
(
place
().
precision
==
kernel
.
precision
()
||
kernel
.
precision
()
==
PRECISION
(
kAny
)
||
place
().
precision
==
PRECISION
(
kAny
)))
{
score
+=
kMax
/
static_cast
<
int
>
(
core
::
KernelPickFactor
::
Factor
::
PrecisionFirst
);
}
if
(
kernel_pick_factors_
.
IsDataLayoutConsidered
()
&&
(
place
().
layout
==
kernel
.
layout
()
||
kernel
.
layout
()
==
DATALAYOUT
(
kAny
)
||
place
().
layout
==
DATALAYOUT
(
kAny
)))
{
score
+=
kMax
/
static_cast
<
int
>
(
core
::
KernelPickFactor
::
Factor
::
DataLayoutFirst
);
}
LOG
(
INFO
)
<<
"picker tactic "
<<
kernel_pick_factors_
;
LOG
(
INFO
)
<<
"kernel place "
<<
kernel
.
place
();
LOG
(
INFO
)
<<
"picker place "
<<
place
();
LOG
(
INFO
)
<<
"score "
<<
score
;
// The data layout is not considered, for the input and output arguments
// might have different data layout.
...
...
paddle/fluid/lite/core/mir/variable_place_inference_pass.cc
浏览文件 @
621d1522
...
...
@@ -22,8 +22,8 @@ namespace mir {
void
VariablePlaceInferencePass
::
Apply
(
std
::
unique_ptr
<
mir
::
SSAGraph
>&
graph
)
{
MarkInputPlace
(
graph
.
get
());
InferenceArgumentPlace
(
graph
.
get
());
CheckAllArgumentTypeDetermined
(
graph
.
get
());
}
}
// namespace mir
...
...
paddle/fluid/lite/core/mir/variable_place_inference_pass.h
浏览文件 @
621d1522
...
...
@@ -31,6 +31,7 @@ class VariablePlaceInferencePass : public DebugPass {
private:
// Mark the place of input arguments.
void
MarkInputPlace
(
SSAGraph
*
graph
)
{
CHECK
(
!
graph
->
inputs
().
empty
())
<<
"graph's inputs should be set"
;
for
(
const
auto
&
v
:
graph
->
inputs
())
{
// the feed op might in the inputs
if
(
v
->
IsInstruct
())
{
...
...
@@ -39,54 +40,60 @@ class VariablePlaceInferencePass : public DebugPass {
}
// auto& arg = v->AsArgument();
// arg.place.target = argument_default_target_;
// LOG(INFO) << "get graph input " << arg.name << " " << *arg.type;
// arg.type.target = argument_default_target_;
// the other place description can't be determined yet, until their first
// usage by some kernel.
}
}
void
CheckAllArgumentTypeDetermined
(
SSAGraph
*
graph
)
{
for
(
auto
&
node
:
graph
->
mutable_nodes
())
{
if
(
node
.
IsArgument
())
{
CHECK
(
node
.
AsArgument
().
type
)
<<
"node "
<<
node
.
AsArgument
().
name
<<
" type not determined"
;
}
}
}
void
InferenceArgumentPlace
(
SSAGraph
*
graph
)
{
LOG
(
INFO
)
<<
"param-type-registry:
\n
"
<<
ParamTypeRegistry
::
Global
();
for
(
auto
&
x
:
graph
->
InstructTopologicalOrder
())
{
auto
&
inst
=
x
->
AsInstruct
();
CHECK
(
inst
.
place
.
is_valid
())
<<
"kernel's place should be set when loaded"
;
// The IoCopyOp is a tool operator, it won't support the type inference.
if
(
inst
.
op_type
==
"io_copy"
)
continue
;
// LOG(INFO) << "- inferencing type " <<
// deal with inputs
for
(
auto
&
arg_name
:
inst
.
op_info
->
input_argnames
())
{
auto
type
=
ParamTypeRegistry
::
Global
().
Retrieve
<
ParamTypeRegistry
::
IO
::
kInput
>
(
inst
.
place
,
inst
.
op_type
,
arg_name
);
CHECK
(
type
)
<<
"no param-type found for "
<<
inst
.
op_type
<<
":"
<<
arg_name
<<
" "
<<
inst
.
place
.
DebugString
();
auto
arg_names
=
inst
.
op_info
->
input_argument
().
at
(
arg_name
);
for
(
auto
&
arg_name
:
inst
.
op_info
()
->
input_argnames
())
{
LOG
(
INFO
)
<<
"-- input arg_name "
<<
arg_name
;
// check if inputs's place is set, if not set, update them with the
// kernel's declaration.
auto
type
=
inst
.
picked_kernel
().
GetInputDeclType
(
arg_name
);
auto
arg_names
=
inst
.
op_info
()
->
input_argument
().
at
(
arg_name
);
for
(
auto
&
arg_name
:
arg_names
)
{
LOG
(
INFO
)
<<
"--- var "
<<
arg_name
;
auto
*
node
=
graph
->
RetrieveArgument
(
arg_name
);
CHECK
(
node
)
<<
"argument "
<<
arg_name
<<
" not exists in the graph"
;
auto
&
arg_node
=
node
->
AsArgument
();
if
(
arg_node
.
type
)
continue
;
arg_node
.
type
=
type
->
type
;
arg_node
.
type
=
type
;
}
}
for
(
auto
&
arg_name
:
inst
.
op_info
->
output_argnames
())
{
auto
type
=
ParamTypeRegistry
::
Global
()
.
Retrieve
<
ParamTypeRegistry
::
IO
::
kOutput
>
(
inst
.
place
,
inst
.
op_type
,
arg_name
);
CHECK
(
type
)
<<
"no param-type found for "
<<
inst
.
op_type
<<
":"
<<
arg_name
<<
" "
<<
inst
.
place
.
DebugString
();
auto
arg_names
=
inst
.
op_info
->
output_argument
().
at
(
arg_name
);
for
(
auto
&
arg_name
:
inst
.
op_info
()
->
output_argnames
())
{
LOG
(
INFO
)
<<
"-- output arg_name "
<<
arg_name
;
auto
type
=
inst
.
picked_kernel
().
GetOutputDeclType
(
arg_name
);
auto
arg_names
=
inst
.
op_info
()
->
output_argument
().
at
(
arg_name
);
// check if outputs's place is set, if not set, update them with the
// kernel's declaration.
for
(
auto
&
arg_name
:
arg_names
)
{
LOG
(
INFO
)
<<
"--- var "
<<
arg_name
;
auto
*
node
=
graph
->
RetrieveArgument
(
arg_name
);
CHECK
(
node
)
<<
"argument "
<<
arg_name
<<
" not exists in the graph"
;
auto
&
arg_node
=
node
->
AsArgument
();
if
(
arg_node
.
type
)
continue
;
node
->
AsArgument
().
type
=
type
->
type
;
node
->
AsArgument
().
type
=
type
;
}
}
}
...
...
paddle/fluid/lite/core/op_lite.cc
浏览文件 @
621d1522
...
...
@@ -27,13 +27,15 @@ std::vector<std::unique_ptr<KernelBase>> OpLite::CreateKernels(
for
(
auto
place
:
places
)
{
auto
ks
=
KernelRegistry
::
Global
().
Create
(
(
kernel_type
.
empty
()
?
op_type_
:
kernel_type
),
place
.
target
,
place
.
precision
);
place
.
precision
,
place
.
layout
);
for
(
auto
&&
it
:
ks
)
{
AttachKernel
(
it
.
get
());
kernels
.
emplace_back
(
std
::
move
(
it
));
}
}
CHECK
(
!
kernels
.
empty
())
<<
"No kernel found for Op "
<<
op_type_
;
LOG
(
INFO
)
<<
"op "
<<
op_type_
<<
" get "
<<
kernels
.
size
()
<<
" kernels"
;
return
kernels
;
}
...
...
@@ -59,9 +61,10 @@ bool OpLite::Run() {
}
bool
OpLite
::
Attach
(
const
framework
::
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
{
CHECK
(
!
op_info_
)
<<
"op_info duplicate build found"
;
op_info_
=
std
::
make_shared
<
OpInfo
>
();
op_info_
->
Build
(
opdesc
);
CHECK
(
scope
);
scope_
=
scope
;
op_info_
.
reset
(
new
OpInfo
);
// Force clean the out-of-date infomation.
op_info_
->
Build
(
opdesc
.
ReadonlyProto
());
return
AttachImpl
(
opdesc
,
scope
);
}
...
...
@@ -79,7 +82,8 @@ Tensor *OpLite::GetMutableTensor(lite::Scope *scope,
return
var
->
GetMutable
<
lite
::
Tensor
>
();
}
bool
OpInfo
::
GetInputArgname
(
const
std
::
string
&
value_name
,
std
::
string
*
out
)
{
bool
OpInfo
::
GetInputArgname
(
const
std
::
string
&
value_name
,
std
::
string
*
out
)
const
{
for
(
auto
&
item
:
input_argument_
)
{
auto
it
=
std
::
find
(
item
.
second
.
begin
(),
item
.
second
.
end
(),
value_name
);
if
(
it
!=
item
.
second
.
end
())
{
...
...
@@ -89,7 +93,8 @@ bool OpInfo::GetInputArgname(const std::string &value_name, std::string *out) {
}
return
false
;
}
bool
OpInfo
::
GetOutputArgname
(
const
std
::
string
&
value_name
,
std
::
string
*
out
)
{
bool
OpInfo
::
GetOutputArgname
(
const
std
::
string
&
value_name
,
std
::
string
*
out
)
const
{
for
(
auto
&
item
:
output_argument_
)
{
auto
it
=
std
::
find
(
item
.
second
.
begin
(),
item
.
second
.
end
(),
value_name
);
if
(
it
!=
item
.
second
.
end
())
{
...
...
paddle/fluid/lite/core/op_lite.h
浏览文件 @
621d1522
...
...
@@ -81,19 +81,30 @@ class OpLite : public Registry {
// Run this operator.
virtual
bool
Run
();
// Link the external execution environ to internal context.
bool
Attach
(
const
framework
::
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
);
const
std
::
shared_ptr
<
OpInfo
>
&
op_info
()
const
{
return
op_info_
;
}
std
::
shared_ptr
<
OpInfo
>
&
mutable_op_info
()
{
return
op_info_
;
}
const
OpInfo
*
op_info
()
const
{
return
op_info_
.
get
()
;
}
OpInfo
*
mutable_op_info
()
{
return
op_info_
.
get
()
;
}
// Human-readable information.
virtual
std
::
string
DebugString
()
const
=
0
;
const
Place
&
kernel_place
()
const
{
return
kernel_place_
;
}
// NOTE This might be discarded.
void
PickKernel
(
const
std
::
vector
<
Place
>
&
valid_places
,
KernelStrategy
kernel_strategy
=
KernelStrategy
::
kStatic
);
// Create all the kernels for the valid targets.
std
::
vector
<
std
::
unique_ptr
<
KernelBase
>>
CreateKernels
(
const
std
::
vector
<
Place
>
&
places
,
const
std
::
string
&
kernel_type
=
""
);
lite
::
Scope
*
scope
()
{
return
scope_
;
}
// Assign op param to kernel.
virtual
void
AttachKernel
(
KernelBase
*
kernel
)
=
0
;
virtual
~
OpLite
()
=
default
;
protected:
...
...
@@ -101,9 +112,6 @@ class OpLite : public Registry {
virtual
bool
AttachImpl
(
const
framework
::
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
=
0
;
// Assign op param to kernel.
virtual
void
AttachKernel
(
KernelBase
*
kernel
)
=
0
;
// Specify the kernel to run by default. This will specify the value of
// `kernel_place_`.
virtual
void
StaticPickKernel
(
const
std
::
vector
<
Place
>
&
valid_targets
)
{
...
...
@@ -118,10 +126,6 @@ class OpLite : public Registry {
// some inputs are ready.
void
RecordOutputEvents
()
{}
// Create all the kernels for the valid targets.
std
::
vector
<
std
::
unique_ptr
<
KernelBase
>>
CreateKernels
(
const
std
::
vector
<
Place
>
&
places
,
const
std
::
string
&
kernel_type
=
""
);
const
Tensor
*
GetTensor
(
lite
::
Scope
*
scope
,
const
std
::
string
&
name
)
const
;
Tensor
*
GetMutableTensor
(
lite
::
Scope
*
scope
,
const
std
::
string
&
name
)
const
;
...
...
@@ -129,11 +133,12 @@ class OpLite : public Registry {
friend
class
mir
::
SSAGraph
;
protected:
lite
::
Scope
*
scope_
{};
std
::
unique_ptr
<
KernelBase
>
kernel_
;
std
::
string
op_type_
;
std
::
vector
<
Place
>
valid_places_
;
Place
kernel_place_
{
TARGET
(
kHost
),
PRECISION
(
kFloat
)};
std
::
shared
_ptr
<
OpInfo
>
op_info_
;
std
::
unique
_ptr
<
OpInfo
>
op_info_
;
};
/*
...
...
@@ -142,22 +147,30 @@ class OpLite : public Registry {
*/
class
OpInfo
{
public:
void
Build
(
const
framework
::
OpDesc
&
desc
)
{
// To avoid the bugs from legancy framework::OpDesc, we use the ProtoBuf
// message instead.
void
Build
(
const
framework
::
proto
::
OpDesc
&
desc
)
{
ExtractInputsAndOutputs
(
desc
);
CollectInputAndOutputArgnames
(
desc
);
CollectArguments
(
desc
);
desc_
.
reset
(
new
framework
::
proto
::
OpDesc
(
desc
));
}
const
framework
::
proto
::
OpDesc
&
desc
()
const
{
CHECK
(
desc_
)
<<
"desc has't set"
;
return
*
desc_
;
}
framework
::
proto
::
OpDesc
*
mutable_desc
()
{
return
desc_
.
get
();
}
const
std
::
list
<
std
::
string
>
&
input_names
()
const
{
return
input_names_
;
}
const
std
::
list
<
std
::
string
>
&
output_names
()
const
{
return
output_names_
;
}
const
std
::
map
<
std
::
string
,
std
::
list
<
std
::
string
>>
&
input_argument
()
{
const
std
::
map
<
std
::
string
,
std
::
list
<
std
::
string
>>
&
input_argument
()
const
{
return
input_argument_
;
}
const
std
::
map
<
std
::
string
,
std
::
list
<
std
::
string
>>
&
output_argument
()
{
const
std
::
map
<
std
::
string
,
std
::
list
<
std
::
string
>>
&
output_argument
()
const
{
return
output_argument_
;
}
bool
GetInputArgname
(
const
std
::
string
&
value_name
,
std
::
string
*
out
);
bool
GetOutputArgname
(
const
std
::
string
&
value_name
,
std
::
string
*
out
);
bool
GetInputArgname
(
const
std
::
string
&
value_name
,
std
::
string
*
out
)
const
;
bool
GetOutputArgname
(
const
std
::
string
&
value_name
,
std
::
string
*
out
)
const
;
const
std
::
list
<
std
::
string
>
&
input_argnames
()
const
{
return
input_argnames_
;
...
...
@@ -167,37 +180,37 @@ class OpInfo {
}
private:
void
ExtractInputsAndOutputs
(
const
framework
::
OpDesc
&
opdesc
)
{
for
(
const
auto
&
item
:
opdesc
.
I
nputs
())
{
for
(
const
auto
&
x
:
item
.
second
)
{
void
ExtractInputsAndOutputs
(
const
framework
::
proto
::
OpDesc
&
opdesc
)
{
for
(
const
auto
&
item
:
opdesc
.
i
nputs
())
{
for
(
const
auto
&
x
:
item
.
arguments
()
)
{
input_names_
.
push_back
(
x
);
}
}
for
(
const
auto
&
item
:
opdesc
.
O
utputs
())
{
for
(
const
auto
&
x
:
item
.
second
)
{
for
(
const
auto
&
item
:
opdesc
.
o
utputs
())
{
for
(
const
auto
&
x
:
item
.
arguments
()
)
{
output_names_
.
push_back
(
x
);
}
}
}
void
CollectInputAndOutputArgnames
(
const
framework
::
OpDesc
&
opdesc
)
{
for
(
const
auto
&
item
:
opdesc
.
InputName
s
())
{
input_argnames_
.
push_back
(
item
);
void
CollectInputAndOutputArgnames
(
const
framework
::
proto
::
OpDesc
&
opdesc
)
{
for
(
const
auto
&
item
:
opdesc
.
input
s
())
{
input_argnames_
.
push_back
(
item
.
parameter
()
);
}
for
(
const
auto
&
item
:
opdesc
.
OutputName
s
())
{
output_argnames_
.
push_back
(
item
);
for
(
const
auto
&
item
:
opdesc
.
output
s
())
{
output_argnames_
.
push_back
(
item
.
parameter
()
);
}
}
void
CollectArguments
(
const
framework
::
OpDesc
&
opdesc
)
{
for
(
const
auto
&
item
:
opdesc
.
I
nputs
())
{
for
(
auto
&
x
:
item
.
second
)
{
input_argument_
[
item
.
first
].
push_back
(
x
);
void
CollectArguments
(
const
framework
::
proto
::
OpDesc
&
opdesc
)
{
for
(
const
auto
&
item
:
opdesc
.
i
nputs
())
{
for
(
auto
&
x
:
item
.
arguments
()
)
{
input_argument_
[
item
.
parameter
()
].
push_back
(
x
);
}
}
for
(
const
auto
&
item
:
opdesc
.
O
utputs
())
{
for
(
auto
&
x
:
item
.
second
)
{
output_argument_
[
item
.
first
].
push_back
(
x
);
for
(
const
auto
&
item
:
opdesc
.
o
utputs
())
{
for
(
auto
&
x
:
item
.
arguments
()
)
{
output_argument_
[
item
.
parameter
()
].
push_back
(
x
);
}
}
}
...
...
@@ -209,6 +222,8 @@ class OpInfo {
std
::
list
<
std
::
string
>
output_argnames_
;
std
::
map
<
std
::
string
,
std
::
list
<
std
::
string
>>
input_argument_
;
std
::
map
<
std
::
string
,
std
::
list
<
std
::
string
>>
output_argument_
;
// NOTE too heavy.
std
::
unique_ptr
<
framework
::
proto
::
OpDesc
>
desc_
;
};
}
// namespace lite
...
...
paddle/fluid/lite/core/op_registry.cc
浏览文件 @
621d1522
...
...
@@ -18,13 +18,33 @@ namespace paddle {
namespace
lite
{
std
::
list
<
std
::
unique_ptr
<
KernelBase
>>
KernelRegistry
::
Create
(
const
std
::
string
&
op_type
,
TargetType
target
,
PrecisionType
precision
)
{
#define CREATE_KERNEL(target__) \
switch (precision) { \
case PRECISION(kFloat): \
return Create<TARGET(target__), PRECISION(kFloat)>(op_type); \
default: \
CHECK(false) << "not supported kernel place yet"; \
const
std
::
string
&
op_type
,
TargetType
target
,
PrecisionType
precision
,
DataLayoutType
layout
)
{
Place
place
{
target
,
precision
,
layout
};
LOG
(
INFO
)
<<
"creating "
<<
op_type
<<
" kernel for "
<<
place
;
#define CREATE_KERNEL1(target__, precision__) \
switch (layout) { \
case DATALAYOUT(kNCHW): \
return Create<TARGET(target__), PRECISION(precision__), \
DATALAYOUT(kNCHW)>(op_type); \
case DATALAYOUT(kAny): \
return Create<TARGET(target__), PRECISION(precision__), \
DATALAYOUT(kAny)>(op_type); \
default: \
LOG(FATAL) << "unsupported kernel layout " << DataLayoutToStr(layout); \
}
#define CREATE_KERNEL(target__) \
switch (precision) { \
case PRECISION(kFloat): \
CREATE_KERNEL1(target__, kFloat); \
case PRECISION(kInt8): \
CREATE_KERNEL1(target__, kInt8); \
case PRECISION(kAny): \
CREATE_KERNEL1(target__, kAny); \
default: \
CHECK(false) << "not supported kernel precision " \
<< PrecisionToStr(precision); \
}
switch
(
target
)
{
...
...
@@ -38,7 +58,7 @@ std::list<std::unique_ptr<KernelBase>> KernelRegistry::Create(
CREATE_KERNEL
(
kCUDA
);
}
break
;
default:
CHECK
(
false
)
<<
"not supported kernel
place"
;
CHECK
(
false
)
<<
"not supported kernel
target "
<<
TargetToStr
(
target
)
;
}
#undef CREATE_KERNEL
...
...
@@ -46,14 +66,21 @@ std::list<std::unique_ptr<KernelBase>> KernelRegistry::Create(
}
KernelRegistry
::
KernelRegistry
()
{
#define INIT_FOR(target__, precision__
)
\
#define INIT_FOR(target__, precision__
, layout__)
\
registries_[KernelRegistry::GetKernelOffset<TARGET(target__), \
PRECISION(precision__)>()] \
.set<KernelRegistryForTarget<TARGET(target__), PRECISION(precision__)> \
*>(&KernelRegistryForTarget<TARGET(target__), \
PRECISION(precision__)>::Global());
PRECISION(precision__), \
DATALAYOUT(layout__)>()] \
.set<KernelRegistryForTarget<TARGET(target__), PRECISION(precision__), \
DATALAYOUT(layout__)> *>( \
&KernelRegistryForTarget<TARGET(target__), PRECISION(precision__), \
DATALAYOUT(layout__)>::Global());
// Currently, just register 2 kernel targets.
INIT_FOR
(
kHost
,
kFloat
);
INIT_FOR
(
kCUDA
,
kFloat
,
kNCHW
);
INIT_FOR
(
kCUDA
,
kAny
,
kNCHW
);
INIT_FOR
(
kHost
,
kFloat
,
kNCHW
);
INIT_FOR
(
kHost
,
kAny
,
kNCHW
);
INIT_FOR
(
kHost
,
kAny
,
kAny
);
INIT_FOR
(
kCUDA
,
kAny
,
kAny
);
#undef INIT_FOR
}
...
...
paddle/fluid/lite/core/op_registry.h
浏览文件 @
621d1522
...
...
@@ -50,80 +50,108 @@ class OpLiteRegistor : public Registor<OpClass> {
})
{}
};
template
<
TargetType
Target
,
PrecisionType
Precision
>
template
<
TargetType
Target
,
PrecisionType
Precision
,
DataLayoutType
Layout
>
using
KernelRegistryForTarget
=
Factory
<
OpKernel
<
Target
,
Precision
>
,
std
::
unique_ptr
<
KernelBase
>>
;
Factory
<
OpKernel
<
Target
,
Precision
,
Layout
>
,
std
::
unique_ptr
<
KernelBase
>>
;
class
KernelRegistry
final
{
public:
using
any_kernel_registor_t
=
variant
<
KernelRegistryForTarget
<
TARGET
(
kCUDA
),
PRECISION
(
kFloat
)
>
*
,
//
KernelRegistryForTarget
<
TARGET
(
kCUDA
),
PRECISION
(
kInt8
)
>
*
,
//
KernelRegistryForTarget
<
TARGET
(
kX86
),
PRECISION
(
kFloat
)
>
*
,
//
KernelRegistryForTarget
<
TARGET
(
kX86
),
PRECISION
(
kInt8
)
>
*
,
//
KernelRegistryForTarget
<
TARGET
(
kHost
),
PRECISION
(
kFloat
)
>
*
//
variant
<
KernelRegistryForTarget
<
TARGET
(
kCUDA
),
PRECISION
(
kFloat
),
DATALAYOUT
(
kNCHW
)
>
*
,
//
KernelRegistryForTarget
<
TARGET
(
kCUDA
),
PRECISION
(
kInt8
),
DATALAYOUT
(
kNCHW
)
>
*
,
//
KernelRegistryForTarget
<
TARGET
(
kX86
),
PRECISION
(
kFloat
),
DATALAYOUT
(
kNCHW
)
>
*
,
//
KernelRegistryForTarget
<
TARGET
(
kX86
),
PRECISION
(
kInt8
),
DATALAYOUT
(
kNCHW
)
>
*
,
//
KernelRegistryForTarget
<
TARGET
(
kHost
),
PRECISION
(
kFloat
),
DATALAYOUT
(
kNCHW
)
>
*
,
//
KernelRegistryForTarget
<
TARGET
(
kHost
),
PRECISION
(
kAny
),
DATALAYOUT
(
kAny
)
>
*
,
//
KernelRegistryForTarget
<
TARGET
(
kCUDA
),
PRECISION
(
kAny
),
DATALAYOUT
(
kAny
)
>
*
//
>
;
KernelRegistry
();
static
KernelRegistry
&
Global
();
template
<
TargetType
Target
,
PrecisionType
Precision
>
template
<
TargetType
Target
,
PrecisionType
Precision
,
DataLayoutType
Layout
>
void
Register
(
const
std
::
string
&
name
,
typename
KernelRegistryForTarget
<
Target
,
Precision
>::
creator_t
&&
creator
)
{
using
kernel_registor_t
=
KernelRegistryForTarget
<
Target
,
Precision
>
;
registries_
[
GetKernelOffset
<
Target
,
Precision
>
()]
.
template
get
<
kernel_registor_t
*
>()
->
Register
(
name
,
std
::
move
(
creator
));
typename
KernelRegistryForTarget
<
Target
,
Precision
,
Layout
>::
creator_t
&&
creator
)
{
LOG
(
INFO
)
<<
"register for "
<<
TargetToStr
(
Target
)
<<
":"
<<
PrecisionToStr
(
Precision
)
<<
"//"
<<
GetKernelOffset
<
Target
,
Precision
,
Layout
>
();
using
kernel_registor_t
=
KernelRegistryForTarget
<
Target
,
Precision
,
Layout
>
;
auto
&
varient
=
registries_
[
GetKernelOffset
<
Target
,
Precision
,
Layout
>
()];
varient
.
template
get
<
kernel_registor_t
*
>()
->
Register
(
name
,
std
::
move
(
creator
));
}
template
<
TargetType
Target
,
PrecisionType
Precision
>
template
<
TargetType
Target
,
PrecisionType
Precision
,
DataLayoutType
Layout
>
std
::
list
<
std
::
unique_ptr
<
KernelBase
>>
Create
(
const
std
::
string
&
op_type
)
{
using
kernel_registor_t
=
KernelRegistryForTarget
<
Target
,
Precision
>
;
return
registries_
[
GetKernelOffset
<
Target
,
Precision
>
()]
using
kernel_registor_t
=
KernelRegistryForTarget
<
Target
,
Precision
,
Layout
>
;
return
registries_
[
GetKernelOffset
<
Target
,
Precision
,
Layout
>
()]
.
template
get
<
kernel_registor_t
*
>()
->
Creates
(
op_type
);
}
std
::
list
<
std
::
unique_ptr
<
KernelBase
>>
Create
(
const
std
::
string
&
op_type
,
TargetType
target
,
PrecisionType
precision
);
PrecisionType
precision
,
DataLayoutType
layout
);
// Get a kernel registry offset in all the registries.
template
<
TargetType
Target
,
PrecisionType
Precision
>
static
constexpr
int
GetKernelOffset
()
{
return
kNumTargets
*
static_cast
<
int
>
(
Target
)
+
static_cast
<
int
>
(
Precision
);
template
<
TargetType
Target
,
PrecisionType
Precision
,
DataLayoutType
Layout
>
static
int
GetKernelOffset
()
{
CHECK_LT
(
static_cast
<
int
>
(
Target
),
static_cast
<
int
>
(
TARGET
(
NUM
)));
CHECK_LT
(
static_cast
<
int
>
(
Precision
),
static_cast
<
int
>
(
PRECISION
(
NUM
)));
CHECK_LT
(
static_cast
<
int
>
(
Layout
),
static_cast
<
int
>
(
DATALAYOUT
(
NUM
)));
return
static_cast
<
int
>
(
Target
)
*
static_cast
<
int
>
(
PRECISION
(
NUM
))
*
static_cast
<
int
>
(
DATALAYOUT
(
NUM
))
+
//
static_cast
<
int
>
(
Precision
)
*
static_cast
<
int
>
(
DATALAYOUT
(
NUM
))
+
//
static_cast
<
int
>
(
Layout
);
}
std
::
string
DebugString
()
const
{
std
::
stringstream
ss
;
ss
<<
"KernelCreator<host, float>:"
<<
std
::
endl
;
ss
<<
registries_
[
GetKernelOffset
<
TARGET
(
kHost
),
PRECISION
(
kFloat
)
>
()]
.
get
<
KernelRegistryForTarget
<
TARGET
(
kHost
),
PRECISION
(
kFloat
)
>
*>
()
ss
<<
registries_
[
GetKernelOffset
<
TARGET
(
kHost
),
PRECISION
(
kFloat
),
DATALAYOUT
(
kAny
)
>
()]
.
get
<
KernelRegistryForTarget
<
TARGET
(
kHost
),
PRECISION
(
kFloat
),
DATALAYOUT
(
kNCHW
)
>
*>
()
->
DebugString
();
ss
<<
std
::
endl
;
return
ss
.
str
();
}
private:
mutable
std
::
array
<
any_kernel_registor_t
,
kNumTargets
*
kNumPrecisions
>
mutable
std
::
array
<
any_kernel_registor_t
,
static_cast
<
int
>
(
TARGET
(
NUM
))
*
static_cast
<
int
>
(
PRECISION
(
NUM
))
*
static_cast
<
int
>
(
DATALAYOUT
(
NUM
))
>
registries_
;
};
template
<
TargetType
target
,
PrecisionType
precision
,
typename
KernelType
>
template
<
TargetType
target
,
PrecisionType
precision
,
DataLayoutType
layout
,
typename
KernelType
>
class
KernelRegistor
:
public
lite
::
Registor
<
KernelType
>
{
public:
KernelRegistor
(
const
std
::
string
op_type
)
:
Registor
<
KernelType
>
([
&
]
{
KernelRegistor
(
const
std
::
string
&
op_type
,
const
std
::
string
&
alias
)
:
Registor
<
KernelType
>
([
=
]
{
LOG
(
INFO
)
<<
"Register kernel "
<<
op_type
<<
" for "
<<
TargetToStr
(
target
)
<<
" "
<<
PrecisionToStr
(
precision
);
KernelRegistry
::
Global
().
Register
<
target
,
precision
>
(
op_type
,
[
&
,
op_type
]()
->
std
::
unique_ptr
<
KernelType
>
{
<<
TargetToStr
(
target
)
<<
" "
<<
PrecisionToStr
(
precision
)
<<
" "
<<
DataLayoutToStr
(
layout
)
<<
" alias "
<<
alias
;
KernelRegistry
::
Global
().
Register
<
target
,
precision
,
layout
>
(
op_type
,
[
=
]()
->
std
::
unique_ptr
<
KernelType
>
{
std
::
unique_ptr
<
KernelType
>
x
(
new
KernelType
);
x
->
set_op_type
(
op_type
);
x
->
set_alias
(
alias
);
return
x
;
});
})
{}
...
...
@@ -151,35 +179,40 @@ class KernelRegistor : public lite::Registor<KernelType> {
#define LITE_KERNEL_REGISTER(op_type__, target__, precision__) \
op_type__##__##target__##__##precision__##__registor__
#define LITE_KERNEL_REGISTER_INSTANCE(op_type__, target__, precision__, \
alias__)
\
layout__, alias__)
\
op_type__##__##target__##__##precision__##__registor__instance__##alias__
#define LITE_KERNEL_REGISTER_FAKE(op_type__, target__, precision__, alias__) \
LITE_KERNEL_REGISTER_INSTANCE(op_type__, target__, precision__, alias__)
#define REGISTER_LITE_KERNEL(op_type__, target__, precision__, KernelClass, \
alias__) \
static paddle::lite::KernelRegistor<TARGET(target__), \
PRECISION(precision__), KernelClass> \
LITE_KERNEL_REGISTER_INSTANCE(op_type__, target__, precision__, \
alias__)(#op_type__); \
static KernelClass LITE_KERNEL_INSTANCE(op_type__, target__, precision__, \
alias__); \
int touch_##op_type__##target__##precision__##alias__() { \
LITE_KERNEL_INSTANCE(op_type__, target__, precision__, alias__).Touch(); \
return 0; \
} \
static bool LITE_KERNEL_PARAM_INSTANCE(op_type__, target__, precision__, \
alias__) __attribute__((unused)) = \
paddle::lite::ParamTypeRegistry::NewInstance<TARGET(target__), \
PRECISION(precision__)>( \
#op_type__)
#define USE_LITE_KERNEL(op_type__, target__, precision__, alias__) \
extern int touch_##op_type__##target__##precision__##alias__(); \
int op_type__##target__##precision__##alias__ __attribute__((unused)) = \
touch_##op_type__##target__##precision__##alias__();
#define LITE_KERNEL_INSTANCE(op_type__, target__, precision__, alias__) \
op_type__##target__##precision__##alias__
#define LITE_KERNEL_PARAM_INSTANCE(op_type__, target__, precision__, alias__) \
op_type__##target__##precision__##alias__##param_register
#define REGISTER_LITE_KERNEL(op_type__, target__, precision__, layout__, \
KernelClass, alias__) \
static paddle::lite::KernelRegistor<TARGET(target__), \
PRECISION(precision__), \
DATALAYOUT(layout__), KernelClass> \
LITE_KERNEL_REGISTER_INSTANCE(op_type__, target__, precision__, \
layout__, alias__)(#op_type__, #alias__); \
static KernelClass LITE_KERNEL_INSTANCE(op_type__, target__, precision__, \
layout__, alias__); \
int touch_##op_type__##target__##precision__##layout__##alias__() { \
LITE_KERNEL_INSTANCE(op_type__, target__, precision__, layout__, alias__) \
.Touch(); \
return 0; \
} \
static bool LITE_KERNEL_PARAM_INSTANCE(op_type__, target__, precision__, \
layout__, alias__) \
__attribute__((unused)) = paddle::lite::ParamTypeRegistry::NewInstance< \
TARGET(target__), PRECISION(precision__), DATALAYOUT(layout__)>( \
#op_type__ "/" #alias__)
#define USE_LITE_KERNEL(op_type__, target__, precision__, layout__, alias__) \
extern int touch_##op_type__##target__##precision__##layout__##alias__(); \
int op_type__##target__##precision__##layout__##alias__ \
__attribute__((unused)) = \
touch_##op_type__##target__##precision__##layout__##alias__();
#define LITE_KERNEL_INSTANCE(op_type__, target__, precision__, layout__, \
alias__) \
op_type__##target__##precision__##layout__##alias__
#define LITE_KERNEL_PARAM_INSTANCE(op_type__, target__, precision__, layout__, \
alias__) \
op_type__##target__##precision__##layout__##alias__##param_register
paddle/fluid/lite/core/optimizer.cc
浏览文件 @
621d1522
...
...
@@ -13,6 +13,7 @@
// limitations under the License.
#include "paddle/fluid/lite/core/optimizer.h"
#include "paddle/fluid/lite/core/mir/io_complement_pass.h"
#include "paddle/fluid/lite/core/mir/static_kernel_pick_pass.h"
namespace
paddle
{
...
...
@@ -25,5 +26,33 @@ void Optimizer::SpecifyKernelPickTactic(core::KernelPickFactor factor) {
*
pass
->
mutable_kernel_pick_factors
()
=
factor
;
}
void
Optimizer
::
RunPasses
()
{
std
::
vector
<
std
::
string
>
passes
({
"static_kernel_pick_pass"
,
//
"variable_place_inference_pass"
,
//
"argument_type_display_pass"
,
//
"io_complement_pass"
,
//
"argument_type_display_pass"
,
//
"variable_place_inference_pass"
,
//
"argument_type_display_pass"
,
//
"io_copy_kernel_pick_pass"
,
//
"variable_place_inference_pass"
,
//
});
for
(
auto
&
pass_type
:
passes
)
{
LOG
(
INFO
)
<<
".. running pass "
<<
pass_type
;
auto
*
pass
=
mir
::
PassManager
::
Global
().
LookUp
(
pass_type
);
CHECK
(
pass
);
if
(
pass
->
name
()
==
"io_complement_pass"
)
{
auto
*
_pass
=
dynamic_cast
<
mir
::
IoComplementPass
*>
(
pass
);
_pass
->
SetValidPlaces
(
valid_places_
);
CHECK
(
!
_pass
->
valid_places
().
empty
());
_pass
->
Apply
(
graph_
);
}
else
{
pass
->
Apply
(
graph_
);
}
}
// mir::PassManager::Global().Run(graph_);
}
}
// namespace lite
}
// namespace paddle
paddle/fluid/lite/core/optimizer.h
浏览文件 @
621d1522
...
...
@@ -16,8 +16,10 @@
#include <string>
#include <vector>
#include "paddle/fluid/lite/core/mir/generate_program_pass.h"
#include "paddle/fluid/lite/core/mir/io_complement_pass.h"
#include "paddle/fluid/lite/core/mir/pass_manager.h"
#include "paddle/fluid/lite/core/mir/ssa_graph.h"
#include "paddle/fluid/lite/core/mir/static_kernel_pick_pass.h"
#include "paddle/fluid/lite/core/program.h"
#include "paddle/fluid/lite/core/types.h"
...
...
@@ -33,25 +35,46 @@ class Optimizer {
void
Run
(
Program
&&
program
,
const
std
::
vector
<
Place
>&
valid_places
,
core
::
KernelPickFactor
kernel_pick_factor
,
const
std
::
vector
<
std
::
string
>&
passes
=
{})
{
valid_places_
=
valid_places
;
CHECK
(
!
valid_places
.
empty
())
<<
"At least one valid_place should be set"
;
CHECK
(
!
graph_
)
<<
"duplicate optimize found"
;
graph_
.
reset
(
new
mir
::
SSAGraph
);
graph_
->
Build
(
program
,
valid_places
);
SpecifyKernelPickTactic
(
kernel_pick_factor
);
// InitIoComplement();
RunPasses
();
exec_scope_
=
program
.
exec_scope
;
}
void
KernelPickPreferPlace
(
const
Place
&
place
)
{
auto
*
pass
=
mir
::
PassManager
::
Global
().
LookUp
<
mir
::
StaticKernelPickPass
>
(
"static_kernel_pick_pass"
);
CHECK
(
pass
);
pass
->
SetPreferPlace
(
place
);
}
// Generate a new program based on the mir graph.
std
::
unique_ptr
<
RuntimeProgram
>
GenRuntimeProgram
()
{
LOG
(
INFO
)
<<
"generate program"
;
std
::
unique_ptr
<
Program
>
res
;
auto
pass
=
mir
::
PassManager
::
Global
().
LookUp
<
mir
::
GenerateProgramPass
>
(
"generate_program_pass"
);
pass
->
Apply
(
graph_
);
auto
program
=
pass
->
GenProgram
();
CHECK
(
exec_scope_
);
program
->
set_exec_scope
(
exec_scope_
);
return
program
;
}
void
InitIoComplement
()
{
auto
*
pass
=
mir
::
PassManager
::
Global
().
LookUp
<
mir
::
IoComplementPass
>
(
"io_complement_pass"
);
CHECK
(
pass
);
CHECK
(
!
valid_places_
.
empty
());
LOG
(
INFO
)
<<
"valid_places.size "
<<
valid_places_
.
size
();
pass
->
SetValidPlaces
(
valid_places_
);
}
// Generate C++ code which combines the inference program, model and weights.
void
GenCode
(
const
std
::
string
&
code_dir
);
...
...
@@ -64,13 +87,14 @@ class Optimizer {
void
SpecifyKernelPickTactic
(
core
::
KernelPickFactor
factor
);
// Run the default passes registered in the PassManager.
void
RunPasses
()
{
mir
::
PassManager
::
Global
().
Run
(
graph_
);
}
void
RunPasses
()
;
// Specify the passes and run them.
void
RunPasses
(
std
::
vector
<
std
::
string
>&
passes
);
private:
std
::
unique_ptr
<
mir
::
SSAGraph
>
graph_
;
std
::
vector
<
Place
>
valid_places_
;
lite
::
Scope
*
exec_scope_
{};
};
...
...
paddle/fluid/lite/core/program.h
浏览文件 @
621d1522
...
...
@@ -84,13 +84,10 @@ struct Program {
tmp_vars
.
push_back
(
"fetch"
);
for
(
auto
var_desc
:
program
.
Block
(
0
).
AllVars
())
{
if
(
!
var_desc
->
Persistable
())
{
LOG
(
INFO
)
<<
"get tmp var "
<<
var_desc
->
Name
();
tmp_vars
.
push_back
(
var_desc
->
Name
());
auto
*
var
=
exec_scope
->
Var
(
var_desc
->
Name
());
LOG
(
INFO
)
<<
"create tmp var "
<<
var_desc
->
Name
()
<<
" "
<<
var
;
exec_scope
->
Var
(
var_desc
->
Name
());
}
else
{
if
(
var_desc
->
Name
()
==
"feed"
||
var_desc
->
Name
()
==
"fetch"
)
continue
;
LOG
(
INFO
)
<<
"get weight var "
<<
var_desc
->
Name
();
weights
.
push_back
(
var_desc
->
Name
());
}
}
...
...
@@ -105,15 +102,19 @@ struct Instruction {
void
Run
()
{
CHECK
(
op_
);
CHECK
(
kernel_
);
LOG
(
INFO
)
<<
"running kernel> "
<<
kernel_
->
DebugString
();
if
(
UNLIKELY
(
first_epoch_
))
{
first_epoch_
=
false
;
op_
->
CheckShape
(
);
CHECK
(
op_
->
CheckShape
()
);
}
op_
->
InferShape
();
kernel_
->
Run
();
}
friend
std
::
ostream
&
operator
<<
(
std
::
ostream
&
os
,
const
Instruction
&
other
)
{
os
<<
other
.
kernel_
->
summary
()
<<
"
\t
("
<<
other
.
kernel_
->
doc
()
<<
")"
;
return
os
;
}
private:
std
::
shared_ptr
<
OpLite
>
op_
;
std
::
unique_ptr
<
KernelBase
>
kernel_
;
...
...
@@ -125,11 +126,16 @@ struct Instruction {
*/
class
RuntimeProgram
{
public:
explicit
RuntimeProgram
(
std
::
vector
<
Instruction
>&&
instruction
)
:
instructions_
(
std
::
move
(
instruction
))
{}
explicit
RuntimeProgram
(
std
::
vector
<
Instruction
>&&
insts
)
:
instructions_
(
std
::
move
(
insts
))
{
if
(
insts
.
empty
())
{
LOG
(
ERROR
)
<<
"no instructions"
;
}
}
void
Run
()
{
for
(
auto
&
inst
:
instructions_
)
{
LOG
(
INFO
)
<<
">> Running kernel: "
<<
inst
;
inst
.
Run
();
}
}
...
...
paddle/fluid/lite/core/target_wrapper.h
浏览文件 @
621d1522
...
...
@@ -16,6 +16,10 @@
#include <glog/logging.h>
#include <iostream>
#include <sstream>
#ifdef LITE_WITH_CUDA
#include <cuda.h>
#include <cuda_runtime.h>
#endif
namespace
paddle
{
namespace
lite
{
...
...
@@ -26,20 +30,20 @@ enum class TargetType : int {
kX86
,
kCUDA
,
kAny
,
// any target
kLastAsPlaceHolder
,
NUM
,
// number of fields.
};
enum
class
PrecisionType
:
int
{
kUnk
=
0
,
kFloat
,
kInt8
,
kAny
,
// any precision
kLastAsPlaceHolder
,
NUM
,
// number of fields.
};
enum
class
DataLayoutType
:
int
{
kUnk
=
0
,
kNCHW
,
kAny
,
// any data layout
kLastAsPlaceHolder
,
NUM
,
// number of fields.
};
// Some helper macro to get a specific TargetType.
...
...
@@ -50,25 +54,29 @@ enum class DataLayoutType : int {
#define PRECISION_VAL(item__) static_cast<int>(PRECISION(item__))
#define DATALAYOUT(item__) paddle::lite::DataLayoutType::item__
constexpr
const
int
kNumPrecisions
=
PRECISION_VAL
(
kLastAsPlaceHolder
)
-
PRECISION_VAL
(
kFloat
);
constexpr
const
int
kNumTargets
=
TARGET_VAL
(
kLastAsPlaceHolder
)
-
TARGET_VAL
(
kHost
);
constexpr
const
int
kNumPrecisions
=
PRECISION_VAL
(
NUM
);
constexpr
const
int
kNumTargets
=
TARGET_VAL
(
NUM
);
static
const
std
::
string
target2string
[]
=
{
"unk"
,
"host"
,
"x86"
,
"cuda"
,
"any"
};
static
const
std
::
string
&
TargetToStr
(
TargetType
target
)
{
return
target2string
[
static_cast
<
int
>
(
target
)];
auto
x
=
static_cast
<
int
>
(
target
);
CHECK_LT
(
x
,
static_cast
<
int
>
(
TARGET
(
NUM
)));
return
target2string
[
x
];
}
static
const
std
::
string
precision2string
[]
=
{
"unk"
,
"float"
,
"int8"
,
"any"
};
static
const
std
::
string
&
PrecisionToStr
(
PrecisionType
precision
)
{
return
precision2string
[
static_cast
<
int
>
(
precision
)];
auto
x
=
static_cast
<
int
>
(
precision
);
CHECK_LT
(
x
,
static_cast
<
int
>
(
PRECISION
(
NUM
)));
return
precision2string
[
x
];
}
static
const
std
::
string
datalayout2string
[]
=
{
"unk"
,
"NCHW"
,
"any"
};
static
const
std
::
string
&
DataLayoutToStr
(
DataLayoutType
x
)
{
return
datalayout2string
[
static_cast
<
int
>
(
x
)];
static
const
std
::
string
&
DataLayoutToStr
(
DataLayoutType
layout
)
{
auto
x
=
static_cast
<
int
>
(
layout
);
CHECK_LT
(
x
,
static_cast
<
int
>
(
DATALAYOUT
(
NUM
)));
return
datalayout2string
[
x
];
}
/*
...
...
@@ -187,5 +195,37 @@ class TargetWrapper<TARGET(kHost)> {
}
};
#ifdef LITE_WITH_CUDA
// This interface should be specified by each kind of target.
template
<
>
class
TargetWrapper
<
TARGET
(
kCUDA
),
cudaStream_t
,
cudaEvent_t
>
{
public:
using
stream_t
=
cudaStream_t
;
using
event_t
=
cudaEvent_t
;
static
size_t
num_devices
()
{
return
0
;
}
static
size_t
maximum_stream
()
{
return
0
;
}
static
void
CreateStream
(
stream_t
*
stream
)
{}
static
void
DestroyStream
(
const
stream_t
&
stream
)
{}
static
void
CreateEvent
(
event_t
*
event
)
{}
static
void
DestroyEvent
(
const
event_t
&
event
)
{}
static
void
RecordEvent
(
const
event_t
&
event
)
{}
static
void
SyncEvent
(
const
event_t
&
event
)
{}
static
void
StreamSync
(
const
stream_t
&
stream
)
{}
static
void
*
Malloc
(
size_t
size
);
static
void
Free
(
void
*
ptr
);
static
void
MemcpySync
(
void
*
dst
,
const
void
*
src
,
size_t
size
,
IoDirection
dir
);
static
void
MemcpyAsync
(
void
*
dst
,
const
void
*
src
,
size_t
size
,
IoDirection
dir
,
const
stream_t
&
stream
);
};
#endif // LITE_WITH_CUDA
}
// namespace lite
}
// namespace paddle
paddle/fluid/lite/core/type_system.cc
浏览文件 @
621d1522
...
...
@@ -87,20 +87,41 @@ const Type* Type::Get<TensorAnyTy>(TargetType target) {
}
}
template
<
TargetType
Target
>
const
Type
*
GetTensorFp32NCHWTy
()
{
static
TensorFp32NCHWTy
x
(
Target
);
return
&
x
;
}
template
<
>
const
Type
*
Type
::
Get
<
TensorFp32NCHWTy
>
(
TargetType
target
)
{
switch
(
target
)
{
case
TargetType
::
kX86
:
return
Get
<
false
,
true
,
TargetType
::
kX86
,
PrecisionType
::
kFloat
,
DataLayoutType
::
kNCHW
>
();
case
TargetType
::
kHost
:
return
Get
<
false
,
true
,
TargetType
::
kHost
,
PrecisionType
::
kFloat
,
DataLayoutType
::
kNCHW
>
();
case
TARGET
(
kHost
):
return
GetTensorFp32NCHWTy
<
TARGET
(
kHost
)
>
();
case
TARGET
(
kCUDA
):
return
GetTensorFp32NCHWTy
<
TARGET
(
kCUDA
)
>
();
case
TARGET
(
kX86
):
return
GetTensorFp32NCHWTy
<
TARGET
(
kX86
)
>
();
default:
LOG
(
FATAL
)
<<
"unsupported target Type "
<<
TargetToStr
(
target
);
}
return
nullptr
;
}
const
Type
*
LookupType
(
DataTypeBase
::
ID
type_id
,
bool
is_unknown
,
bool
is_tensor
,
Place
place
)
{
using
id_t
=
DataTypeBase
::
ID
;
switch
(
type_id
)
{
case
id_t
::
Tensor_Any
:
return
Type
::
Get
<
TensorAnyTy
>
(
place
.
target
);
case
id_t
::
Tensor_Fp32_NCHW
:
return
Type
::
Get
<
TensorFp32NCHWTy
>
(
place
.
target
);
case
id_t
::
TensorList_Any
:
return
Type
::
Get
<
TensorListAnyTy
>
(
place
.
target
);
default:
LOG
(
FATAL
)
<<
"unsupported target "
<<
TargetToStr
(
target
);
return
nullptr
;
LOG
(
FATAL
)
<<
"unsupported type"
;
}
return
nullptr
;
}
// ------------------------- end GetType specification ------------------------
...
...
paddle/fluid/lite/core/type_system.h
浏览文件 @
621d1522
...
...
@@ -131,6 +131,23 @@ class Type : public DataTypeBase {
bool
operator
==
(
const
Type
&
other
)
{
return
id_
==
other
.
id
()
&&
place_
==
other
.
place
();
}
friend
std
::
ostream
&
operator
<<
(
std
::
ostream
&
os
,
const
Type
&
other
)
{
if
(
other
.
IsUnsupported
())
{
os
<<
"<Unsupported>"
;
return
os
;
}
if
(
other
.
IsVoid
())
{
os
<<
"<Void>"
;
return
os
;
}
if
(
other
.
is_tensor_
)
{
os
<<
"<Tensor:"
;
}
os
<<
TargetToStr
(
other
.
target
())
<<
"/"
<<
PrecisionToStr
(
other
.
precision
())
<<
"/"
<<
DataLayoutToStr
(
other
.
layout
())
<<
">"
;
return
os
;
}
// Can cast to another type. This is heavily used in MIR, by determine whether
// is is possible to add a instruction to transform a type to another.
...
...
@@ -163,29 +180,33 @@ class Type : public DataTypeBase {
};
// -------------------------------- compatible check ---------------------------
static
bool
TargetCompatible
(
const
Type
&
a
,
const
Type
&
b
)
{
return
(
a
.
IsVoid
()
||
b
.
IsVoid
())
||
//
a
.
target
()
==
b
.
target
();
static
bool
TargetCompatibleTo
(
const
Type
&
a
,
const
Type
&
b
)
{
return
a
.
IsVoid
()
||
//
(
a
.
IsTensor
()
&&
b
.
IsTensor
()
&&
(
a
.
target
()
==
b
.
target
()
||
//
b
.
target
()
==
TARGET
(
kAny
)));
}
static
bool
DataLayoutCompatible
(
const
Type
&
a
,
const
Type
&
b
)
{
return
(
a
.
IsVoid
()
||
b
.
IsVoid
())
||
//
(
a
.
IsTensor
()
&&
b
.
IsTensor
()
&&
a
.
layout
()
==
b
.
layout
());
static
bool
DataLayoutCompatibleTo
(
const
Type
&
a
,
const
Type
&
b
)
{
return
a
.
IsVoid
()
||
//
(
a
.
IsTensor
()
&&
b
.
IsTensor
()
&&
(
a
.
layout
()
==
b
.
layout
()
||
//
b
.
layout
()
==
DATALAYOUT
(
kAny
)));
}
static
bool
PrecisionCompatible
(
const
Type
&
a
,
const
Type
&
b
)
{
return
(
a
.
IsVoid
()
||
b
.
IsVoid
())
||
//
(
a
.
precision
()
==
b
.
precision
());
static
bool
PrecisionCompatibleTo
(
const
Type
&
a
,
const
Type
&
b
)
{
return
a
.
IsVoid
()
||
//
(
a
.
IsTensor
()
&&
b
.
IsTensor
()
&&
(
a
.
precision
()
==
b
.
precision
()
||
//
b
.
precision
()
==
PRECISION
(
kAny
)));
}
static
bool
DeviceCompatible
(
const
Type
&
a
,
const
Type
&
b
)
{
return
(
a
.
IsVoid
()
||
b
.
IsVoid
()
)
||
//
(
a
.
device
()
==
b
.
device
(
));
static
bool
DeviceCompatible
To
(
const
Type
&
a
,
const
Type
&
b
)
{
return
a
.
IsVoid
(
)
||
//
(
a
.
IsTensor
()
&&
b
.
IsTensor
()
&&
(
a
.
device
()
==
b
.
device
()
));
}
static
bool
TypeCompatible
(
const
Type
&
a
,
const
Type
&
b
)
{
return
TargetCompatible
(
a
,
b
)
&&
DataLayoutCompatible
(
a
,
b
)
&&
PrecisionCompatible
(
a
,
b
)
&&
DeviceCompatible
(
a
,
b
);
// Can type 'a' be passed to 'b' directly.
static
bool
TypeCompatibleTo
(
const
Type
&
a
,
const
Type
&
b
)
{
return
TargetCompatibleTo
(
a
,
b
)
&&
DataLayoutCompatibleTo
(
a
,
b
)
&&
PrecisionCompatibleTo
(
a
,
b
)
&&
DeviceCompatibleTo
(
a
,
b
);
}
// -------------------------------- predefined types ---------------------------
...
...
@@ -230,6 +251,9 @@ class TensorInt64NCHWTy : public Type {
:
Type
(
ID
::
Tensor_Int64_NCHW
,
"TensorInt64NCHW"
,
true
/*is_tensor*/
,
target
,
PrecisionType
::
kInt8
,
DataLayoutType
::
kNCHW
)
{}
};
const
Type
*
LookupType
(
DataTypeBase
::
ID
type_id
,
bool
is_unknown
,
bool
is_tensor
,
Place
place
);
// ------------------------- end predefined types ---------------------------
// NOTE TypeSystem has some overhead, and better to be used in analysis phase.
...
...
@@ -381,13 +405,15 @@ class ParamTypeRegistry {
CHECK
(
types_
.
count
(
key
));
}
template
<
IO
io
>
const
ParamType
*
Retrieve
(
const
Place
&
place
,
const
std
::
string
&
op_type
,
const
std
::
string
&
arg_name
)
{
KernelIdTy
key
{
op_type
,
place
,
io
,
arg_name
};
auto
it
=
types_
.
find
(
key
);
if
(
it
==
types_
.
end
())
return
nullptr
;
return
&
it
->
second
;
const
ParamType
*
RetrieveInArgument
(
const
Place
&
place
,
const
std
::
string
&
op_type
,
const
std
::
string
&
arg_name
)
{
return
Retrieve
<
IO
::
kInput
>
(
place
,
op_type
,
arg_name
);
}
const
ParamType
*
RetrieveOutArgument
(
const
Place
&
place
,
const
std
::
string
&
op_type
,
const
std
::
string
&
arg_name
)
{
return
Retrieve
<
IO
::
kOutput
>
(
place
,
op_type
,
arg_name
);
}
static
ParamTypeRegistry
&
Global
()
{
...
...
@@ -403,6 +429,16 @@ class ParamTypeRegistry {
return
os
;
}
protected:
template
<
IO
io
>
const
ParamType
*
Retrieve
(
const
Place
&
place
,
const
std
::
string
&
op_type
,
const
std
::
string
&
arg_name
)
{
KernelIdTy
key
{
op_type
,
place
,
io
,
arg_name
};
auto
it
=
types_
.
find
(
key
);
if
(
it
==
types_
.
end
())
return
nullptr
;
return
&
it
->
second
;
}
private:
ParamTypeRegistry
()
=
default
;
...
...
paddle/fluid/lite/core/types.cc
浏览文件 @
621d1522
...
...
@@ -43,6 +43,9 @@ bool KernelPickFactor::IsTargetConsidered() const {
bool
KernelPickFactor
::
IsDataLayoutConsidered
()
const
{
return
data_
&
static_cast
<
int
>
(
Factor
::
DataLayoutFirst
);
}
bool
KernelPickFactor
::
IsDeviceConsidered
()
const
{
return
data_
&
static_cast
<
int
>
(
Factor
::
DeviceFirst
);
}
}
// namespace core
}
// namespace lite
...
...
paddle/fluid/lite/core/types.h
浏览文件 @
621d1522
...
...
@@ -14,6 +14,7 @@
#pragma once
#include <stack>
#include "paddle/fluid/lite/core/context.h"
#include "paddle/fluid/lite/utils/all.h"
...
...
@@ -38,6 +39,7 @@ class KernelPickFactor {
bool
AnyFactorConsidered
()
const
{
return
data_
;
}
KernelPickFactor
&
ConsiderTarget
();
// Perfer a specific target, e.g. prefer CUDA kernels.
KernelPickFactor
&
ConsiderPrecision
();
KernelPickFactor
&
ConsiderDataLayout
();
KernelPickFactor
&
ConsiderDevice
();
...
...
@@ -45,12 +47,29 @@ class KernelPickFactor {
bool
IsTargetConsidered
()
const
;
bool
IsPrecisionConsidered
()
const
;
bool
IsDataLayoutConsidered
()
const
;
bool
IsDeviceConsidered
()
const
{
return
data_
&
static_cast
<
int
>
(
Factor
::
DeviceFirst
);
bool
IsDeviceConsidered
()
const
;
friend
std
::
ostream
&
operator
<<
(
std
::
ostream
&
os
,
const
KernelPickFactor
&
k
)
{
std
::
stack
<
bool
>
bits
;
auto
data
=
k
.
data_
;
while
(
data
)
{
bits
.
push
(
data
%
2
);
data
/=
2
;
}
int
nbits
=
bits
.
size
();
for
(
size_t
i
=
0
;
i
<
sizeof
(
data
)
*
8
-
nbits
;
i
++
)
{
os
<<
0
;
}
while
(
!
bits
.
empty
())
{
os
<<
bits
.
top
();
bits
.
pop
();
}
return
os
;
}
private:
unsigned
char
data_
{};
TargetType
target_
{
TARGET
(
kUnk
)};
};
struct
dim2
{
...
...
paddle/fluid/lite/cuda/target_wrapper.cc
浏览文件 @
621d1522
...
...
@@ -12,10 +12,6 @@
// See the License for the specific language governing permissions and
// limitations under the License.
//
// Created by chunwei on 19-2-23.
//
#include "paddle/fluid/lite/cuda/target_wrapper.h"
#include <glog/logging.h>
...
...
@@ -24,19 +20,14 @@ namespace lite {
using
TargetW
=
TargetWrapper
<
TARGET
(
kCUDA
),
cudaStream_t
,
cudaEvent_t
>
;
template
<
>
void
*
TargetW
::
Malloc
(
size_t
size
)
{
void
*
ptr
{};
CHECK_EQ
(
cudaSuccess
,
cudaMalloc
(
&
ptr
,
size
));
return
ptr
;
}
template
<
>
void
TargetW
::
Free
(
void
*
ptr
)
{
CHECK_EQ
(
cudaSuccess
,
cudaFree
(
ptr
));
}
void
TargetW
::
Free
(
void
*
ptr
)
{
CHECK_EQ
(
cudaSuccess
,
cudaFree
(
ptr
));
}
template
<
>
void
TargetW
::
MemcpySync
(
void
*
dst
,
const
void
*
src
,
size_t
size
,
IoDirection
dir
)
{
switch
(
dir
)
{
...
...
@@ -55,7 +46,6 @@ void TargetW::MemcpySync(void* dst, const void* src, size_t size,
}
}
template
<
>
void
TargetW
::
MemcpyAsync
(
void
*
dst
,
const
void
*
src
,
size_t
size
,
IoDirection
dir
,
const
stream_t
&
stream
)
{
switch
(
dir
)
{
...
...
paddle/fluid/lite/kernels/cuda/CMakeLists.txt
浏览文件 @
621d1522
cc
_library
(
mul_compute_cuda SRCS mul_compute.cc DEPS tensor_lite
)
nv
_library
(
mul_compute_cuda SRCS mul_compute.cc DEPS tensor_lite
)
cc_library
(
io_copy_compute_cuda SRCS io_copy_compute.cc DEPS tensor_lite
)
nv_library
(
kernels_cuda DEPS mul_compute_cuda io_copy_compute_cuda
)
paddle/fluid/lite/kernels/cuda/io_copy_compute.cc
浏览文件 @
621d1522
...
...
@@ -21,7 +21,7 @@ namespace lite {
namespace
kernels
{
namespace
cuda
{
using
TargetW
=
TargetWrapper
<
TARGET
(
k
Host
),
cudaStream_t
,
cudaEvent_t
>
;
using
TargetW
=
TargetWrapper
<
TARGET
(
k
CUDA
),
cudaStream_t
,
cudaEvent_t
>
;
// Host to CUDA memory.
void
CopyFromHostSync
(
void
*
target
,
const
void
*
source
,
size_t
size
)
{
...
...
@@ -51,6 +51,25 @@ class IoCopyHostToCudaCompute
auto
*
data
=
param
.
y
->
mutable_data
(
target
(),
param
.
x
->
memory_size
());
CopyFromHostSync
(
data
,
param
.
x
->
data
<
void
>
(),
param
.
x
->
memory_size
());
}
std
::
unique_ptr
<
type_infer_handler_t
>
GetTypeInferHandler
()
override
{
std
::
unique_ptr
<
type_infer_handler_t
>
res
(
new
type_infer_handler_t
);
*
res
=
[](
const
std
::
map
<
std
::
string
,
const
Type
*>&
inputs
,
const
std
::
string
&
out
)
->
const
Type
*
{
CHECK
(
!
inputs
.
empty
());
auto
*
type
=
inputs
.
at
(
"Input"
);
CHECK
(
type
->
target
()
==
TARGET
(
kHost
));
auto
out_place
=
type
->
place
();
out_place
.
target
=
TARGET
(
kCUDA
);
auto
*
out_type
=
LookupType
(
type
->
id
(),
type
->
IsUnsupported
(),
type
->
IsUnsupported
(),
out_place
);
return
out_type
;
};
return
res
;
}
std
::
string
doc
()
const
override
{
return
"Copy IO from HOST to CUDA"
;
}
};
/*
...
...
@@ -65,6 +84,8 @@ class IoCopyCudaToHostCompute
auto
*
data
=
param
.
y
->
mutable_data
(
TARGET
(
kHost
),
param
.
x
->
memory_size
());
CopyToHostSync
(
data
,
param
.
x
->
data
<
void
>
(),
param
.
x
->
memory_size
());
}
std
::
string
doc
()
const
override
{
return
"Copy IO from CUDA to HOST"
;
}
};
}
// namespace cuda
...
...
@@ -72,7 +93,7 @@ class IoCopyCudaToHostCompute
}
// namespace lite
}
// namespace paddle
REGISTER_LITE_KERNEL
(
io_copy
,
kCUDA
,
kAny
,
REGISTER_LITE_KERNEL
(
io_copy
,
kCUDA
,
kAny
,
kAny
,
paddle
::
lite
::
kernels
::
cuda
::
IoCopyHostToCudaCompute
,
host_to_device
)
.
BindInput
(
"Input"
,
{
paddle
::
lite
::
Type
::
Get
<
paddle
::
lite
::
TensorAnyTy
>
(
...
...
@@ -81,7 +102,7 @@ REGISTER_LITE_KERNEL(io_copy, kCUDA, kAny,
TARGET
(
kCUDA
))})
.
Finalize
();
REGISTER_LITE_KERNEL
(
io_copy
,
kCUDA
,
kAny
,
REGISTER_LITE_KERNEL
(
io_copy
,
kCUDA
,
kAny
,
kAny
,
paddle
::
lite
::
kernels
::
cuda
::
IoCopyCudaToHostCompute
,
device_to_host
)
.
BindInput
(
"Input"
,
{
paddle
::
lite
::
Type
::
Get
<
paddle
::
lite
::
TensorAnyTy
>
(
...
...
paddle/fluid/lite/kernels/cuda/mul_compute.cc
浏览文件 @
621d1522
...
...
@@ -13,3 +13,22 @@
// limitations under the License.
#include "paddle/fluid/lite/kernels/cuda/mul_compute.h"
#include "paddle/fluid/lite/core/op_registry.h"
namespace
paddle
{
namespace
lite
{
namespace
kernels
{
namespace
cuda
{}
// namespace cuda
}
// namespace kernels
}
// namespace lite
}
// namespace paddle
REGISTER_LITE_KERNEL
(
mul
,
kCUDA
,
kFloat
,
kNCHW
,
paddle
::
lite
::
kernels
::
cuda
::
MulCompute
,
def
)
.
BindInput
(
"X"
,
{
paddle
::
lite
::
Type
::
Get
<
paddle
::
lite
::
TensorFp32NCHWTy
>
(
TARGET
(
kCUDA
))})
.
BindInput
(
"Y"
,
{
paddle
::
lite
::
Type
::
Get
<
paddle
::
lite
::
TensorFp32NCHWTy
>
(
TARGET
(
kCUDA
))})
.
BindOutput
(
"Out"
,
{
paddle
::
lite
::
Type
::
Get
<
paddle
::
lite
::
TensorFp32NCHWTy
>
(
TARGET
(
kCUDA
))})
.
Finalize
();
paddle/fluid/lite/kernels/cuda/mul_compute.h
浏览文件 @
621d1522
...
...
@@ -16,6 +16,7 @@
#include "paddle/fluid/lite/core/kernel.h"
#include "paddle/fluid/lite/core/types.h"
#include "paddle/fluid/lite/cuda/blas.h"
#include "paddle/fluid/lite/operators/op_params.h"
namespace
paddle
{
namespace
lite
{
...
...
@@ -29,11 +30,29 @@ void mul_compute(const lite::cuda::Blas<float>& blas, const T* x, int x_h,
nullptr
,
out
,
0
);
}
class
MulCompute
:
public
OpKernel
<
TARGET
(
k
Host
),
PRECISION
(
kFloat
)
>
{
class
MulCompute
:
public
OpKernel
<
TARGET
(
k
CUDA
),
PRECISION
(
kFloat
)
>
{
public:
using
param_t
=
operators
::
MulParam
;
void
Run
()
override
{}
void
Run
()
override
{
CHECK
(
context_
)
<<
"running context should be set first"
;
auto
&
context
=
context_
->
AsCudaContext
();
CHECK
(
context
.
blas_fp32
)
<<
"blas should init first"
;
auto
&
blas
=
*
context
.
blas_fp32
;
const
auto
&
param
=
Param
<
operators
::
MulParam
>
();
CHECK
(
param
.
x
->
target
()
==
TARGET
(
kCUDA
));
auto
*
x
=
param
.
x
->
data
<
float
>
();
int
x_h
=
param
.
x
->
dims
()[
0
];
int
x_w
=
param
.
x
->
dims
()[
1
];
auto
*
y
=
param
.
y
->
data
<
float
>
();
int
y_h
=
param
.
y
->
dims
()[
0
];
int
y_w
=
param
.
y
->
dims
()[
1
];
auto
*
out
=
param
.
output
->
mutable_data
<
float
>
(
TARGET
(
kCUDA
));
mul_compute
<
float
>
(
blas
,
x
,
x_h
,
x_w
,
y
,
y_h
,
y_w
,
out
);
}
virtual
~
MulCompute
()
=
default
;
};
...
...
paddle/fluid/lite/kernels/host/fc_compute.cc
浏览文件 @
621d1522
...
...
@@ -51,8 +51,8 @@ void FcCompute::Run() {
}
// namespace lite
}
// namespace paddle
REGISTER_LITE_KERNEL
(
fc
,
kHost
,
kFloat
,
paddle
::
lite
::
kernels
::
host
::
FcCompute
,
def
)
REGISTER_LITE_KERNEL
(
fc
,
kHost
,
kFloat
,
kNCHW
,
paddle
::
lite
::
kernels
::
host
::
FcCompute
,
def
)
.
BindInput
(
"Input"
,
{
paddle
::
lite
::
Type
::
Get
<
paddle
::
lite
::
TensorFp32NCHWTy
>
(
TARGET
(
kHost
))})
...
...
paddle/fluid/lite/kernels/host/feed_compute.cc
浏览文件 @
621d1522
...
...
@@ -20,7 +20,8 @@ namespace lite {
namespace
kernels
{
namespace
host
{
class
FeedCompute
:
public
OpKernel
<
TARGET
(
kHost
),
PRECISION
(
kFloat
)
>
{
class
FeedCompute
:
public
OpKernel
<
TARGET
(
kHost
),
PRECISION
(
kAny
),
DATALAYOUT
(
kAny
)
>
{
public:
using
param_t
=
operators
::
FeedParam
;
...
...
@@ -38,7 +39,7 @@ class FeedCompute : public OpKernel<TARGET(kHost), PRECISION(kFloat)> {
}
// namespace lite
}
// namespace paddle
REGISTER_LITE_KERNEL
(
feed
,
kHost
,
k
Float
,
REGISTER_LITE_KERNEL
(
feed
,
kHost
,
k
Any
,
kAny
,
paddle
::
lite
::
kernels
::
host
::
FeedCompute
,
def
)
.
BindInput
(
"X"
,
{
paddle
::
lite
::
Type
::
Get
<
paddle
::
lite
::
TensorAnyTy
>
(
TARGET
(
kHost
))})
...
...
paddle/fluid/lite/kernels/host/fetch_compute.cc
浏览文件 @
621d1522
...
...
@@ -20,7 +20,8 @@ namespace lite {
namespace
kernels
{
namespace
host
{
class
FetchCompute
:
public
OpKernel
<
TARGET
(
kHost
),
PRECISION
(
kFloat
)
>
{
class
FetchCompute
:
public
OpKernel
<
TARGET
(
kHost
),
PRECISION
(
kAny
),
DATALAYOUT
(
kAny
)
>
{
public:
using
param_t
=
operators
::
FeedParam
;
...
...
@@ -41,7 +42,7 @@ class FetchCompute : public OpKernel<TARGET(kHost), PRECISION(kFloat)> {
}
// namespace lite
}
// namespace paddle
REGISTER_LITE_KERNEL
(
fetch
,
kHost
,
k
Float
,
REGISTER_LITE_KERNEL
(
fetch
,
kHost
,
k
Any
,
kAny
,
paddle
::
lite
::
kernels
::
host
::
FetchCompute
,
def
)
.
BindInput
(
"X"
,
{
paddle
::
lite
::
Type
::
Get
<
paddle
::
lite
::
TensorAnyTy
>
(
TARGET
(
kHost
))})
...
...
paddle/fluid/lite/kernels/host/mul_compute.cc
浏览文件 @
621d1522
...
...
@@ -67,7 +67,7 @@ class MulCompute : public OpKernel<TARGET(kHost), PRECISION(kFloat)> {
}
// namespace lite
}
// namespace paddle
REGISTER_LITE_KERNEL
(
mul
,
kHost
,
kFloat
,
REGISTER_LITE_KERNEL
(
mul
,
kHost
,
kFloat
,
kNCHW
,
paddle
::
lite
::
kernels
::
host
::
MulCompute
,
def
)
.
BindInput
(
"X"
,
{
paddle
::
lite
::
Type
::
Get
<
paddle
::
lite
::
TensorFp32NCHWTy
>
(
TARGET
(
kHost
))})
...
...
paddle/fluid/lite/kernels/host/relu_compute.h
浏览文件 @
621d1522
...
...
@@ -42,6 +42,6 @@ class ReluCompute : public OpKernel<TARGET(kHost), PRECISION(kFloat)> {
}
// namespace lite
}
// namespace paddle
REGISTER_LITE_KERNEL
(
relu
,
kHost
,
kFloat
,
REGISTER_LITE_KERNEL
(
relu
,
kHost
,
kFloat
,
kNCHW
,
paddle
::
lite
::
kernels
::
host
::
ReluCompute
,
def
)
.
Finalize
();
paddle/fluid/lite/kernels/host/scale_compute.cc
浏览文件 @
621d1522
...
...
@@ -50,7 +50,7 @@ class ScaleCompute : public OpKernel<TARGET(kHost), PRECISION(kFloat)> {
}
// namespace lite
}
// namespace paddle
REGISTER_LITE_KERNEL
(
scale
,
kHost
,
kFloat
,
REGISTER_LITE_KERNEL
(
scale
,
kHost
,
kFloat
,
kNCHW
,
paddle
::
lite
::
kernels
::
host
::
ScaleCompute
,
def
)
.
BindInput
(
"X"
,
{
paddle
::
lite
::
Type
::
Get
<
paddle
::
lite
::
TensorFp32NCHWTy
>
(
TARGET
(
kHost
))})
...
...
paddle/fluid/lite/operators/io_copy_op.cc
浏览文件 @
621d1522
...
...
@@ -24,7 +24,10 @@ bool IoCopyOp::CheckShape() const {
CHECK_OR_FALSE
(
param_
.
y
);
return
true
;
}
bool
IoCopyOp
::
InferShape
()
const
{
return
true
;
}
bool
IoCopyOp
::
InferShape
()
const
{
param_
.
y
->
Resize
(
param_
.
x
->
dims
());
return
true
;
}
bool
IoCopyOp
::
Run
()
{
return
OpLite
::
Run
();
}
bool
IoCopyOp
::
AttachImpl
(
const
paddle
::
framework
::
OpDesc
&
opdesc
,
paddle
::
lite
::
Scope
*
scope
)
{
...
...
paddle/fluid/lite/operators/mul_op.h
浏览文件 @
621d1522
...
...
@@ -51,9 +51,6 @@ class MulOpLite : public OpLite {
param_
.
x_num_col_dims
=
boost
::
get
<
int
>
(
op_desc
.
GetAttr
(
"x_num_col_dims"
));
param_
.
y_num_col_dims
=
boost
::
get
<
int
>
(
op_desc
.
GetAttr
(
"y_num_col_dims"
));
CHECK
(
kernel_
);
kernel_
->
SetParam
(
param_
);
return
true
;
}
...
...
paddle/fluid/lite/utils/factory.h
浏览文件 @
621d1522
...
...
@@ -13,6 +13,7 @@
// limitations under the License.
#pragma once
#include <glog/logging.h>
#include <iostream>
#include <list>
#include <memory>
...
...
@@ -48,8 +49,6 @@ class Factory {
}
void
Register
(
const
std
::
string
&
op_type
,
creator_t
&&
creator
)
{
CHECK
(
!
creators_
.
count
(
op_type
))
<<
"The op "
<<
op_type
<<
" has already registered"
;
creators_
[
op_type
].
emplace_back
(
std
::
move
(
creator
));
}
...
...
@@ -58,9 +57,9 @@ class Factory {
}
std
::
list
<
item_ptr_t
>
Creates
(
const
std
::
string
&
op_type
)
const
{
auto
it
=
creators_
.
find
(
op_type
);
CHECK
(
it
!=
creators_
.
end
())
<<
"no item called "
<<
op_type
;
std
::
list
<
item_ptr_t
>
res
;
auto
it
=
creators_
.
find
(
op_type
);
if
(
it
==
creators_
.
end
())
return
res
;
for
(
auto
&
c
:
it
->
second
)
{
res
.
emplace_back
(
c
());
}
...
...
paddle/fluid/lite/utils/varient.h
浏览文件 @
621d1522
...
...
@@ -99,7 +99,7 @@ struct variant {
size_t
type
()
{
return
type_id
;
}
void
valid
()
{
return
(
type_id
!=
invalid_type
());
}
bool
valid
()
{
return
(
type_id
!=
invalid_type
());
}
template
<
typename
T
,
typename
...
Args
>
void
set
(
Args
&&
...
args
)
{
...
...
paddle/fluid/lite/utils/varient_test.cc
浏览文件 @
621d1522
...
...
@@ -24,6 +24,8 @@ namespace utils {
TEST
(
varient
,
test
)
{
variant
<
int
,
float
>
a
;
// The initial state should be invalid.
ASSERT_FALSE
(
a
.
valid
());
a
.
set
<
int
>
(
1
);
ASSERT_EQ
(
a
.
get
<
int
>
(),
1
);
a
.
set
<
int
>
(
20
);
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录