Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
机器未来
Paddle
提交
8b950a4f
P
Paddle
项目概览
机器未来
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
8b950a4f
编写于
4月 15, 2019
作者:
S
superjomn
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
add mir implementation
上级
f41d73b3
变更
27
隐藏空白更改
内联
并排
Showing
27 changed file
with
634 addition
and
65 deletion
+634
-65
paddle/fluid/lite/core/executor.h
paddle/fluid/lite/core/executor.h
+1
-1
paddle/fluid/lite/core/kernel.cc
paddle/fluid/lite/core/kernel.cc
+0
-5
paddle/fluid/lite/core/kernel.h
paddle/fluid/lite/core/kernel.h
+23
-6
paddle/fluid/lite/core/mir/CMakeLists.txt
paddle/fluid/lite/core/mir/CMakeLists.txt
+7
-2
paddle/fluid/lite/core/mir/demo_pass.cc
paddle/fluid/lite/core/mir/demo_pass.cc
+33
-0
paddle/fluid/lite/core/mir/node.cc
paddle/fluid/lite/core/mir/node.cc
+14
-0
paddle/fluid/lite/core/mir/node.h
paddle/fluid/lite/core/mir/node.h
+78
-6
paddle/fluid/lite/core/mir/pass.cc
paddle/fluid/lite/core/mir/pass.cc
+14
-0
paddle/fluid/lite/core/mir/pass.h
paddle/fluid/lite/core/mir/pass.h
+37
-0
paddle/fluid/lite/core/mir/pass_manager.cc
paddle/fluid/lite/core/mir/pass_manager.cc
+30
-0
paddle/fluid/lite/core/mir/pass_manager.h
paddle/fluid/lite/core/mir/pass_manager.h
+80
-0
paddle/fluid/lite/core/mir/pass_manager_test.cc
paddle/fluid/lite/core/mir/pass_manager_test.cc
+30
-0
paddle/fluid/lite/core/mir/pass_registry.cc
paddle/fluid/lite/core/mir/pass_registry.cc
+21
-0
paddle/fluid/lite/core/mir/pass_registry.h
paddle/fluid/lite/core/mir/pass_registry.h
+37
-0
paddle/fluid/lite/core/mir/ssa_graph.cc
paddle/fluid/lite/core/mir/ssa_graph.cc
+14
-0
paddle/fluid/lite/core/mir/ssa_graph.h
paddle/fluid/lite/core/mir/ssa_graph.h
+74
-0
paddle/fluid/lite/core/op_lite.cc
paddle/fluid/lite/core/op_lite.cc
+10
-0
paddle/fluid/lite/core/op_lite.h
paddle/fluid/lite/core/op_lite.h
+34
-10
paddle/fluid/lite/core/type_system.cc
paddle/fluid/lite/core/type_system.cc
+43
-0
paddle/fluid/lite/core/type_system.h
paddle/fluid/lite/core/type_system.h
+38
-27
paddle/fluid/lite/kernels/host/fc_compute.cc
paddle/fluid/lite/kernels/host/fc_compute.cc
+5
-2
paddle/fluid/lite/operators/fc_op.h
paddle/fluid/lite/operators/fc_op.h
+4
-1
paddle/fluid/lite/operators/fc_op_test.cc
paddle/fluid/lite/operators/fc_op_test.cc
+1
-1
paddle/fluid/lite/operators/mul_op.h
paddle/fluid/lite/operators/mul_op.h
+2
-1
paddle/fluid/lite/operators/relu_op.cc
paddle/fluid/lite/operators/relu_op.cc
+1
-1
paddle/fluid/lite/operators/relu_op.h
paddle/fluid/lite/operators/relu_op.h
+1
-1
paddle/fluid/lite/operators/scale_op.cc
paddle/fluid/lite/operators/scale_op.cc
+2
-1
未找到文件。
paddle/fluid/lite/core/executor.h
浏览文件 @
8b950a4f
...
@@ -52,7 +52,7 @@ class Executor {
...
@@ -52,7 +52,7 @@ class Executor {
ops_
.
emplace_back
(
LiteOpRegistry
::
Global
().
Create
(
op_type
));
ops_
.
emplace_back
(
LiteOpRegistry
::
Global
().
Create
(
op_type
));
// pick initial kernel
// pick initial kernel
ops_
.
back
()
->
PickKernel
(
valid_places_
);
ops_
.
back
()
->
PickKernel
(
valid_places_
);
ops_
.
back
()
->
Attach
(
*
op_desc
,
exec_scope_
);
ops_
.
back
()
->
Attach
Impl
(
*
op_desc
,
exec_scope_
);
}
}
}
}
...
...
paddle/fluid/lite/core/kernel.cc
浏览文件 @
8b950a4f
...
@@ -17,11 +17,6 @@
...
@@ -17,11 +17,6 @@
namespace
paddle
{
namespace
paddle
{
namespace
lite
{
namespace
lite
{
bool
operator
==
(
const
Place
&
a
,
const
Place
&
b
)
{
return
a
.
target
==
b
.
target
&&
a
.
precision
==
b
.
precision
&&
a
.
layout
==
b
.
layout
;
}
bool
operator
<
(
const
Place
&
a
,
const
Place
&
b
)
{
bool
operator
<
(
const
Place
&
a
,
const
Place
&
b
)
{
if
(
a
.
target
!=
b
.
target
)
if
(
a
.
target
!=
b
.
target
)
return
a
.
target
<
b
.
target
;
return
a
.
target
<
b
.
target
;
...
...
paddle/fluid/lite/core/kernel.h
浏览文件 @
8b950a4f
...
@@ -19,6 +19,7 @@
...
@@ -19,6 +19,7 @@
#include "paddle/fluid/framework/op_desc.h"
#include "paddle/fluid/framework/op_desc.h"
#include "paddle/fluid/lite/core/context.h"
#include "paddle/fluid/lite/core/context.h"
#include "paddle/fluid/lite/core/target_wrapper.h"
#include "paddle/fluid/lite/core/target_wrapper.h"
#include "paddle/fluid/lite/core/type_system.h"
#include "paddle/fluid/lite/core/types.h"
#include "paddle/fluid/lite/core/types.h"
#include "paddle/fluid/lite/operators/op_params.h"
#include "paddle/fluid/lite/operators/op_params.h"
#include "paddle/fluid/lite/utils/all.h"
#include "paddle/fluid/lite/utils/all.h"
...
@@ -51,6 +52,7 @@ class KernelBase {
...
@@ -51,6 +52,7 @@ class KernelBase {
virtual
TargetType
target
()
const
=
0
;
virtual
TargetType
target
()
const
=
0
;
virtual
PrecisionType
precision
()
const
=
0
;
virtual
PrecisionType
precision
()
const
=
0
;
virtual
DataLayoutType
layout
()
const
=
0
;
virtual
~
KernelBase
()
=
default
;
virtual
~
KernelBase
()
=
default
;
...
@@ -66,17 +68,21 @@ class KernelBase {
...
@@ -66,17 +68,21 @@ class KernelBase {
* registered in the `TypeSystem`.
* registered in the `TypeSystem`.
*/
*/
struct
ParamType
{
struct
ParamType
{
// For unsupported types.
size_t
element_type_hash
{};
size_t
element_type_hash
{};
Place
tensor_place
{};
Place
tensor_place
{};
const
Type
*
type_
;
ParamType
()
=
default
;
ParamType
()
=
default
;
ParamType
(
size_t
element_type_hash
)
:
element_type_hash
(
element_type_hash
)
{}
ParamType
(
size_t
element_type_hash
)
:
element_type_hash
(
element_type_hash
)
{}
ParamType
(
size_t
element_type_hash
,
const
Place
&
place
)
ParamType
(
size_t
element_type_hash
,
const
Place
&
place
)
:
element_type_hash
(
element_type_hash
),
tensor_place
(
place
)
{}
:
element_type_hash
(
element_type_hash
),
tensor_place
(
place
)
{}
ParamType
(
const
Type
*
type
)
:
type_
(
type
)
{}
};
};
/*
/*
* The data types of kernel parameters.
* The data types of kernel parameters. It is used to track the type of kernel's
* inputs and outputs.
*/
*/
struct
ParamTypes
{
struct
ParamTypes
{
std
::
vector
<
std
::
vector
<
ParamType
>>
inputs
;
std
::
vector
<
std
::
vector
<
ParamType
>>
inputs
;
...
@@ -115,6 +121,8 @@ struct ParamTypes {
...
@@ -115,6 +121,8 @@ struct ParamTypes {
*/
*/
class
ParamTypeRegistry
{
class
ParamTypeRegistry
{
public:
public:
enum
class
IO
:
int
{
kInput
=
0
,
kOutput
};
template
<
TargetType
target
,
PrecisionType
precision
,
template
<
TargetType
target
,
PrecisionType
precision
,
DataLayoutType
layout
=
DataLayoutType
::
kNCHW
>
DataLayoutType
layout
=
DataLayoutType
::
kNCHW
>
/*
/*
...
@@ -130,7 +138,12 @@ class ParamTypeRegistry {
...
@@ -130,7 +138,12 @@ class ParamTypeRegistry {
NewInstance
(
const
std
::
string
&
kernel_type
)
:
kernel_type_
(
kernel_type
)
{}
NewInstance
(
const
std
::
string
&
kernel_type
)
:
kernel_type_
(
kernel_type
)
{}
NewInstance
&
BindInput
(
int
offset
,
const
ParamType
&
ptype
)
{
NewInstance
&
BindInput
(
int
offset
,
const
ParamType
&
ptype
)
{
ParamTypeRegistry
::
Global
().
Register
(
ParamTypeRegistry
::
Global
().
Register
<
IO
::
kInput
>
(
kernel_type_
,
Place
{
target
,
precision
,
layout
},
offset
,
ptype
);
return
*
this
;
}
NewInstance
&
BindOutput
(
int
offset
,
const
ParamType
&
ptype
)
{
ParamTypeRegistry
::
Global
().
Register
<
IO
::
kOutput
>
(
kernel_type_
,
Place
{
target
,
precision
,
layout
},
offset
,
ptype
);
kernel_type_
,
Place
{
target
,
precision
,
layout
},
offset
,
ptype
);
return
*
this
;
return
*
this
;
}
}
...
@@ -141,8 +154,12 @@ class ParamTypeRegistry {
...
@@ -141,8 +154,12 @@ class ParamTypeRegistry {
std
::
string
kernel_type_
;
std
::
string
kernel_type_
;
};
};
template
<
IO
io
>
void
Register
(
const
std
::
string
&
kernel_type
,
const
Place
&
place
,
int
offset
,
void
Register
(
const
std
::
string
&
kernel_type
,
const
Place
&
place
,
int
offset
,
ParamType
data_type
)
{}
ParamType
data_type
)
{
KernelIdTy
key
{
kernel_type
,
place
,
io
,
offset
};
types_
[
key
]
=
data_type
;
}
ParamType
Retrive
(
const
Place
&
place
,
int
offset
);
ParamType
Retrive
(
const
Place
&
place
,
int
offset
);
...
@@ -155,16 +172,15 @@ class ParamTypeRegistry {
...
@@ -155,16 +172,15 @@ class ParamTypeRegistry {
ParamTypeRegistry
()
=
default
;
ParamTypeRegistry
()
=
default
;
public:
public:
enum
class
IO
:
int
{
kInput
=
0
,
kOutput
};
// Identification for a Kernel.
// Identification for a Kernel.
struct
KernelIdT
{
struct
KernelIdT
y
{
std
::
string
kernel_type
;
std
::
string
kernel_type
;
Place
place
;
Place
place
;
IO
io
;
IO
io
;
int
offset
;
int
offset
;
};
};
using
key_t
=
KernelIdT
;
using
key_t
=
KernelIdT
y
;
struct
KeyCmp
{
struct
KeyCmp
{
bool
operator
()(
const
key_t
&
a
,
const
key_t
&
b
)
const
;
bool
operator
()(
const
key_t
&
a
,
const
key_t
&
b
)
const
;
};
};
...
@@ -188,6 +204,7 @@ class OpKernel : public KernelBase {
...
@@ -188,6 +204,7 @@ class OpKernel : public KernelBase {
TargetType
target
()
const
override
{
return
Target
;
}
TargetType
target
()
const
override
{
return
Target
;
}
PrecisionType
precision
()
const
override
{
return
Precision
;
}
PrecisionType
precision
()
const
override
{
return
Precision
;
}
DataLayoutType
layout
()
const
override
{
return
DataLayout
;
}
void
Touch
()
{}
void
Touch
()
{}
...
...
paddle/fluid/lite/core/mir/CMakeLists.txt
浏览文件 @
8b950a4f
cc_library
(
mir_pass SRCS pass.cc
)
cc_library
(
mir_node SRCS node.cc
)
cc_library
(
mir_node SRCS node.cc
)
cc_library
(
mir_ssa_graph SRCS ssa_graph.cc
)
cc_library
(
mir_ssa_graph SRCS ssa_graph.cc DEPS mir_node
)
\ No newline at end of file
cc_library
(
mir_pass SRCS pass.cc DEPS mir_ssa_graph
)
cc_library
(
mir_pass_manager SRCS pass_manager.cc DEPS mir_pass mir_ssa_graph
)
cc_library
(
mir_pass_registry SRCS pass_registry.cc DEPS mir_pass_manager
)
cc_library
(
mir_demo_pass SRCS demo_pass.cc DEPS mir_pass
)
cc_test
(
test_mir_pass_manager SRCS pass_manager_test.cc DEPS mir_pass_manager mir_demo_pass
)
paddle/fluid/lite/core/mir/demo_pass.cc
0 → 100644
浏览文件 @
8b950a4f
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/lite/core/mir/pass.h"
#include "paddle/fluid/lite/core/mir/pass_registry.h"
namespace
paddle
{
namespace
lite
{
namespace
mir
{
class
DemoPass
:
public
mir
::
Pass
{
public:
void
Apply
(
std
::
unique_ptr
<
mir
::
SSAGraph
>&
graph
)
override
{}
};
bool
RegisterDemoPass
()
{
return
PassManager
::
Global
().
AddNewPass
(
"demo"
,
new
DemoPass
);
}
}
// namespace mir
}
// namespace lite
}
// namespace paddle
paddle/fluid/lite/core/mir/node.cc
浏览文件 @
8b950a4f
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/lite/core/mir/node.h"
#include "paddle/fluid/lite/core/mir/node.h"
paddle/fluid/lite/core/mir/node.h
浏览文件 @
8b950a4f
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <list>
#include <memory>
#include <string>
#include <vector>
#include "paddle/fluid/lite/core/kernel.h"
namespace
paddle
{
namespace
paddle
{
namespace
lite
{
namespace
lite
{
namespace
mir
{
namespace
mir
{
// Node in a MIR graph.
class
Node
{
class
Node
{
public:
public:
// Tell is instruction.
std
::
list
<
Node
*>
inlinks
;
bool
IsInstruct
()
const
;
std
::
list
<
Node
*>
outlinks
;
// Tell is an argument.
bool
IsArgument
()
const
;
Node
()
=
default
;
};
enum
class
Role
{
kUnk
=
-
1
,
kArgument
,
kInstruct
,
kNumRoles
/*should be last*/
};
struct
Instruct
{
std
::
string
op_type
;
Place
place
;
// The kernel instances this Instruct contains.
std
::
vector
<
std
::
unique_ptr
<
KernelBase
>>
valid_kernels
;
};
struct
Argument
{
std
::
string
name
;
Place
place
;
};
// Set roles.
Argument
&
AsArgument
()
{
if
(
role_
!=
Role
::
kUnk
)
{
CHECK
(
role_
==
Role
::
kArgument
);
return
*
argument_
;
}
role_
=
Role
::
kArgument
;
argument_
.
reset
(
new
Argument
);
return
*
argument_
;
}
Instruct
&
AsInstruct
()
{
if
(
role_
!=
Role
::
kUnk
)
{
CHECK
(
role_
==
Role
::
kInstruct
);
return
*
instruct_
;
}
role_
=
Role
::
kInstruct
;
instruct_
.
reset
(
new
Instruct
);
return
*
instruct_
;
}
// Check roles.
bool
IsRoleSet
()
const
{
return
role_
==
Role
::
kUnk
;
}
bool
IsInstruct
()
const
{
return
role_
==
Role
::
kInstruct
;
}
bool
IsArgument
()
const
{
return
role_
==
Role
::
kArgument
;
}
private:
// Either instruct_ or argument_ is used.
std
::
unique_ptr
<
Instruct
>
instruct_
;
std
::
unique_ptr
<
Argument
>
argument_
;
Role
role_
{
Role
::
kUnk
};
};
}
// namespace mir
}
// namespace mir
}
// namespace lite
}
// namespace lite
}
// namespace paddle
}
// namespace paddle
\ No newline at end of file
paddle/fluid/lite/core/mir/pass.cc
浏览文件 @
8b950a4f
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/lite/core/mir/pass.h"
#include "paddle/fluid/lite/core/mir/pass.h"
paddle/fluid/lite/core/mir/pass.h
浏览文件 @
8b950a4f
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/fluid/lite/core/mir/node.h"
#include "paddle/fluid/lite/core/mir/ssa_graph.h"
namespace
paddle
{
namespace
lite
{
namespace
mir
{
class
Pass
{
public:
virtual
void
Apply
(
std
::
unique_ptr
<
mir
::
SSAGraph
>&
graph
)
=
0
;
const
std
::
string
&
name
()
const
{
return
name_
;
}
virtual
~
Pass
()
=
default
;
private:
std
::
string
name_
;
};
}
// namespace mir
}
// namespace lite
}
// namespace paddle
paddle/fluid/lite/core/mir/pass_manager.cc
0 → 100644
浏览文件 @
8b950a4f
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/lite/core/mir/pass_manager.h"
#include "paddle/fluid/lite/core/mir/pass_registry.h"
namespace
paddle
{
namespace
lite
{
namespace
mir
{
PassManager
::
PassManager
()
{}
// Manually register here.
extern
bool
RegisterDemoPass
();
static
bool
xx
__attribute__
((
unused
))
=
RegisterDemoPass
();
}
// namespace mir
}
// namespace lite
}
// namespace paddle
paddle/fluid/lite/core/mir/pass_manager.h
0 → 100644
浏览文件 @
8b950a4f
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <list>
#include <map>
#include <string>
#include "paddle/fluid/lite/core/mir/pass.h"
namespace
paddle
{
namespace
lite
{
namespace
mir
{
class
PassManager
{
public:
static
PassManager
&
Global
()
{
static
PassManager
x
;
return
x
;
}
PassManager
();
void
Run
()
{
for
(
auto
&
pass
:
passes_
)
{
LOG
(
INFO
)
<<
"Running MIR pass "
<<
pass
->
name
();
pass
->
Apply
(
graph_
);
}
}
bool
AddNewPass
(
const
std
::
string
&
name
,
Pass
*
pass
)
{
passes_
.
emplace_back
(
pass
);
pass_map_
.
emplace
(
name
,
passes_
.
back
().
get
());
return
true
;
}
// Clear all the passes.
void
Clear
()
{
passes_
.
clear
();
}
std
::
list
<
std
::
unique_ptr
<
mir
::
Pass
>>::
iterator
passes_begin
()
{
return
passes_
.
begin
();
}
std
::
list
<
std
::
unique_ptr
<
mir
::
Pass
>>::
iterator
passes_end
()
{
return
passes_
.
end
();
}
std
::
list
<
std
::
unique_ptr
<
mir
::
Pass
>>::
const_iterator
passes_const_begin
()
const
{
return
passes_
.
begin
();
}
std
::
list
<
std
::
unique_ptr
<
mir
::
Pass
>>::
const_iterator
passes_const_end
()
const
{
return
passes_
.
end
();
}
Pass
*
LookUp
(
const
std
::
string
&
key
)
{
auto
it
=
pass_map_
.
find
(
key
);
CHECK
(
it
!=
pass_map_
.
end
());
return
it
->
second
;
}
private:
std
::
unique_ptr
<
mir
::
SSAGraph
>
graph_
;
std
::
list
<
std
::
unique_ptr
<
mir
::
Pass
>>
passes_
;
std
::
map
<
std
::
string
,
mir
::
Pass
*>
pass_map_
;
};
}
// namespace mir
}
// namespace lite
}
// namespace paddle
paddle/fluid/lite/core/mir/pass_manager_test.cc
0 → 100644
浏览文件 @
8b950a4f
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/lite/core/mir/pass_manager.h"
#include <gtest/gtest.h>
namespace
paddle
{
namespace
lite
{
namespace
mir
{
TEST
(
PassManager
,
test
)
{
auto
*
pass
=
PassManager
::
Global
().
LookUp
(
"demo"
);
LOG
(
INFO
)
<<
"pass: "
<<
pass
;
ASSERT_TRUE
(
pass
!=
nullptr
);
}
}
// namespace mir
}
// namespace lite
}
// namespace paddle
paddle/fluid/lite/core/mir/pass_registry.cc
0 → 100644
浏览文件 @
8b950a4f
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/lite/core/mir/pass_registry.h"
namespace
paddle
{
namespace
lite
{
namespace
mir
{}
// namespace mir
}
// namespace lite
}
// namespace paddle
paddle/fluid/lite/core/mir/pass_registry.h
0 → 100644
浏览文件 @
8b950a4f
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <string>
#include "paddle/fluid/lite/core/mir/pass.h"
#include "paddle/fluid/lite/core/mir/pass_manager.h"
namespace
paddle
{
namespace
lite
{
namespace
mir
{
class
PassRegistry
{
public:
PassRegistry
(
const
std
::
string
&
name
,
mir
::
Pass
*
pass
)
{
LOG
(
INFO
)
<<
"Registry add MIR pass "
<<
name
;
PassManager
::
Global
().
AddNewPass
(
name
,
pass
);
}
bool
Touch
()
const
{
return
true
;
}
};
}
// namespace mir
}
// namespace lite
}
// namespace paddle
paddle/fluid/lite/core/mir/ssa_graph.cc
浏览文件 @
8b950a4f
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/lite/core/mir/ssa_graph.h"
#include "paddle/fluid/lite/core/mir/ssa_graph.h"
paddle/fluid/lite/core/mir/ssa_graph.h
浏览文件 @
8b950a4f
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <list>
#include <map>
#include <string>
#include <vector>
#include "paddle/fluid/lite/core/mir/node.h"
#include "paddle/fluid/lite/core/op_lite.h"
namespace
paddle
{
namespace
lite
{
namespace
mir
{
// A program is used to represent a code program, in Paddle, a code program
// contains:
// - main block, which is a list of OpLite
// - scope: which contains all the weights
struct
Program
{
std
::
list
<
std
::
unique_ptr
<
OpLite
>>
ops
;
lite
::
Scope
*
scope
;
};
// An Graph for MIR. It is built from a list of Op and a scope.
class
GraphBase
{};
class
SSAGraph
:
GraphBase
{
public:
// @param program: the op program
// @param valid_places: the valid places user set for the system.
void
Build
(
const
Program
&
program
,
const
std
::
vector
<
Place
>
&
valid_places
)
{
for
(
auto
&
op
:
program
.
ops
)
{
node_storage_
.
emplace_back
();
// TODO(Superjomn) remove one valid_places here.
op
->
SetValidPlaces
(
valid_places
);
auto
&
new_node
=
node_storage_
.
back
();
auto
&
new_kernel
=
node_storage_
.
back
().
AsInstruct
();
new_kernel
.
valid_kernels
=
op
->
CreateKernels
(
valid_places
);
CHECK
(
new_node
.
inlinks
.
empty
())
<<
"duplicate Build found"
;
CHECK
(
new_node
.
outlinks
.
empty
())
<<
"duplicate Build found"
;
// collect inputs and outputs
for
(
const
std
::
string
&
name
:
op
->
input_names
())
{
new_node
.
inlinks
.
push_back
(
arguments_
.
at
(
name
));
}
for
(
const
std
::
string
&
name
:
op
->
output_names
())
{
new_node
.
outlinks
.
push_back
(
arguments_
.
at
(
name
));
}
}
}
std
::
vector
<
mir
::
Node
*>
TopoloticalOrder
()
const
;
private:
std
::
list
<
mir
::
Node
>
node_storage_
;
std
::
map
<
std
::
string
,
mir
::
Node
*>
arguments_
;
};
}
// namespace mir
}
// namespace lite
}
// namespace paddle
paddle/fluid/lite/core/op_lite.cc
浏览文件 @
8b950a4f
...
@@ -44,5 +44,15 @@ void OpLite::PickKernel(const std::vector<Place> &valid_places,
...
@@ -44,5 +44,15 @@ void OpLite::PickKernel(const std::vector<Place> &valid_places,
}
}
}
}
bool
OpLite
::
Run
()
{
CHECK
(
kernel_
);
SyncInputEvents
();
kernel_
->
Run
();
RecordOutputEvents
();
return
true
;
}
}
// namespace lite
}
// namespace lite
}
// namespace paddle
}
// namespace paddle
paddle/fluid/lite/core/op_lite.h
浏览文件 @
8b950a4f
...
@@ -36,6 +36,11 @@ struct Registry {
...
@@ -36,6 +36,11 @@ struct Registry {
void
Touch
()
{}
void
Touch
()
{}
};
};
namespace
mir
{
class
Node
;
class
SSAGraph
;
}
/**
/**
* The base class of an light-weight operators, currently just used in inference
* The base class of an light-weight operators, currently just used in inference
* to eliminate overhead of some operations in current framework.
* to eliminate overhead of some operations in current framework.
...
@@ -71,19 +76,13 @@ class OpLite : public Registry {
...
@@ -71,19 +76,13 @@ class OpLite : public Registry {
// Inference the outputs' shape.
// Inference the outputs' shape.
virtual
bool
InferShape
()
const
{
return
true
;
}
virtual
bool
InferShape
()
const
{
return
true
;
}
// Run this operator.
// Run this operator.
virtual
bool
Run
()
{
virtual
bool
Run
();
CHECK
(
kernel_
);
SyncInputEvents
();
kernel_
->
Run
();
RecordOutputEvents
();
bool
Attach
(
const
framework
::
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
{
return
true
;
ExtractInputsAndOutputs
(
opdesc
);
return
AttachImpl
(
opdesc
,
scope
);
}
}
// Attach it with the runtime environment.
virtual
bool
Attach
(
const
framework
::
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
=
0
;
// Human-readable information.
// Human-readable information.
virtual
std
::
string
DebugString
()
const
=
0
;
virtual
std
::
string
DebugString
()
const
=
0
;
...
@@ -92,9 +91,29 @@ class OpLite : public Registry {
...
@@ -92,9 +91,29 @@ class OpLite : public Registry {
void
PickKernel
(
const
std
::
vector
<
Place
>
&
valid_places
,
void
PickKernel
(
const
std
::
vector
<
Place
>
&
valid_places
,
KernelStrategy
kernel_strategy
=
KernelStrategy
::
kStatic
);
KernelStrategy
kernel_strategy
=
KernelStrategy
::
kStatic
);
const
std
::
list
<
std
::
string
>
&
input_names
()
const
{
return
input_names_
;
}
const
std
::
list
<
std
::
string
>
&
output_names
()
const
{
return
output_names_
;
}
virtual
~
OpLite
()
=
default
;
virtual
~
OpLite
()
=
default
;
protected:
protected:
// Attach it with the runtime environment.
virtual
bool
AttachImpl
(
const
framework
::
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
=
0
;
void
ExtractInputsAndOutputs
(
const
framework
::
OpDesc
&
opdesc
)
{
for
(
const
auto
&
item
:
opdesc
.
Inputs
())
{
for
(
const
auto
&
x
:
item
.
second
)
{
input_names_
.
push_back
(
x
);
}
}
for
(
const
auto
&
item
:
opdesc
.
Outputs
())
{
for
(
const
auto
&
x
:
item
.
second
)
{
output_names_
.
push_back
(
x
);
}
}
}
// Specify the kernel to run by default. This will specify the value of
// Specify the kernel to run by default. This will specify the value of
// `kernel_place_`.
// `kernel_place_`.
virtual
void
StaticPickKernel
(
const
std
::
vector
<
Place
>
&
valid_targets
)
{
virtual
void
StaticPickKernel
(
const
std
::
vector
<
Place
>
&
valid_targets
)
{
...
@@ -113,12 +132,17 @@ class OpLite : public Registry {
...
@@ -113,12 +132,17 @@ class OpLite : public Registry {
std
::
vector
<
std
::
unique_ptr
<
KernelBase
>>
CreateKernels
(
std
::
vector
<
std
::
unique_ptr
<
KernelBase
>>
CreateKernels
(
const
std
::
vector
<
Place
>
&
places
,
const
std
::
string
&
kernel_type
=
""
);
const
std
::
vector
<
Place
>
&
places
,
const
std
::
string
&
kernel_type
=
""
);
friend
class
mir
::
Node
;
friend
class
mir
::
SSAGraph
;
protected:
protected:
std
::
unique_ptr
<
OpContext
>
op_context_
;
std
::
unique_ptr
<
OpContext
>
op_context_
;
std
::
unique_ptr
<
KernelBase
>
kernel_
;
std
::
unique_ptr
<
KernelBase
>
kernel_
;
std
::
string
op_type_
;
std
::
string
op_type_
;
std
::
vector
<
Place
>
valid_places_
;
std
::
vector
<
Place
>
valid_places_
;
Place
kernel_place_
{
TARGET
(
kHost
),
PRECISION
(
kFloat
)};
Place
kernel_place_
{
TARGET
(
kHost
),
PRECISION
(
kFloat
)};
std
::
list
<
std
::
string
>
input_names_
;
std
::
list
<
std
::
string
>
output_names_
;
};
};
}
// namespace lite
}
// namespace lite
...
...
paddle/fluid/lite/core/type_system.cc
浏览文件 @
8b950a4f
...
@@ -13,3 +13,46 @@
...
@@ -13,3 +13,46 @@
// limitations under the License.
// limitations under the License.
#include "paddle/fluid/lite/core/type_system.h"
#include "paddle/fluid/lite/core/type_system.h"
namespace
paddle
{
namespace
lite
{
// ------------------------- GetType specification ----------------------------
template
<
>
const
Type
*
Type
::
Get
<
false
/*is_unsupported*/
,
false
/*is_tensor*/
,
TargetType
::
kHost
,
PrecisionType
::
kFloat
,
DataLayoutType
::
kNCHW
>
()
{
static
UnsupportedTy
x
;
return
&
x
;
}
template
<
>
const
Type
*
Type
::
Get
<
false
/*is_unsupported*/
,
true
/*is_tensor*/
,
TargetType
::
kX86
,
PrecisionType
::
kFloat
,
DataLayoutType
::
kNCHW
>
()
{
static
TensorFp32NCHWTy
x
(
TargetType
::
kX86
);
return
&
x
;
}
template
<
>
const
Type
*
Type
::
Get
<
UnsupportedTy
>
(
TargetType
target
,
int
device
)
{
return
Get
<
false
,
false
,
TargetType
::
kHost
,
PrecisionType
::
kFloat
,
DataLayoutType
::
kNCHW
>
();
}
template
<
>
const
Type
*
Type
::
Get
<
TensorFp32NCHWTy
>
(
TargetType
target
)
{
switch
(
target
)
{
case
TargetType
::
kX86
:
return
Get
<
false
,
true
,
TargetType
::
kX86
,
PrecisionType
::
kFloat
,
DataLayoutType
::
kNCHW
>
();
default:
LOG
(
FATAL
)
<<
"unsupported target "
<<
TargetToStr
(
target
);
return
nullptr
;
}
}
// ------------------------- end GetType specification ------------------------
}
// namespace lite
}
// namespace paddle
paddle/fluid/lite/core/type_system.h
浏览文件 @
8b950a4f
...
@@ -82,7 +82,7 @@ class DataTypeBase {
...
@@ -82,7 +82,7 @@ class DataTypeBase {
* Datatype with device info considered.
* Datatype with device info considered.
* NOTE A Type with different device is treated as different DeviceDataType.
* NOTE A Type with different device is treated as different DeviceDataType.
*/
*/
class
DeviceData
Type
:
public
DataTypeBase
{
class
Type
:
public
DataTypeBase
{
public:
public:
TargetType
target
()
const
{
return
place_
.
target
;
}
TargetType
target
()
const
{
return
place_
.
target
;
}
PrecisionType
precision
()
const
{
return
place_
.
precision
;
}
PrecisionType
precision
()
const
{
return
place_
.
precision
;
}
...
@@ -90,23 +90,31 @@ class DeviceDataType : public DataTypeBase {
...
@@ -90,23 +90,31 @@ class DeviceDataType : public DataTypeBase {
const
Place
&
place
()
const
{
return
place_
;
}
const
Place
&
place
()
const
{
return
place_
;
}
const
std
::
string
&
name
()
const
{
return
name_
;
}
const
std
::
string
&
name
()
const
{
return
name_
;
}
bool
operator
==
(
const
DeviceData
Type
&
other
)
{
bool
operator
==
(
const
Type
&
other
)
{
return
id_
==
other
.
id
()
&&
place_
==
other
.
place
();
return
id_
==
other
.
id
()
&&
place_
==
other
.
place
();
}
}
// Can cast to another type. This is heavily used in MIR, by determine whether
// Can cast to another type. This is heavily used in MIR, by determine whether
// is is possible to add a instruction to transform a type to another.
// is is possible to add a instruction to transform a type to another.
virtual
bool
TypeCastable
(
const
DeviceDataType
&
type
)
const
{
virtual
bool
TypeCastable
(
const
Type
&
type
)
const
{
return
id_
==
type
.
id
();
}
return
id_
==
type
.
id
();
}
template
<
bool
is_unknown
,
bool
is_tensor
=
true
,
TargetType
target
=
TargetType
::
kHost
,
PrecisionType
precision
=
PrecisionType
::
kFloat
,
DataLayoutType
layout
=
DataLayoutType
::
kNCHW
>
// Get a type.
static
const
Type
*
Get
();
template
<
typename
TypeTy
>
static
const
Type
*
Get
(
TargetType
target
=
TargetType
::
kHost
);
virtual
~
DeviceData
Type
()
=
default
;
virtual
~
Type
()
=
default
;
protected:
protected:
DeviceData
Type
(
ID
id
,
const
std
::
string
&
name
,
bool
is_tensor
,
Type
(
ID
id
,
const
std
::
string
&
name
,
bool
is_tensor
,
TargetType
target
=
TargetType
::
kHost
,
TargetType
target
=
TargetType
::
kHost
,
PrecisionType
precision
=
PrecisionType
::
kFloat
,
PrecisionType
precision
=
PrecisionType
::
kFloat
,
DataLayoutType
layout
=
DataLayoutType
::
kNCHW
)
DataLayoutType
layout
=
DataLayoutType
::
kNCHW
)
:
DataTypeBase
(
id
,
is_tensor
),
:
DataTypeBase
(
id
,
is_tensor
),
place_
{
target
,
precision
,
layout
},
place_
{
target
,
precision
,
layout
},
name_
(
name
)
{}
name_
(
name
)
{}
...
@@ -117,30 +125,33 @@ class DeviceDataType : public DataTypeBase {
...
@@ -117,30 +125,33 @@ class DeviceDataType : public DataTypeBase {
};
};
// -------------------------------- predefined types ---------------------------
// -------------------------------- predefined types ---------------------------
class
Void
:
public
DeviceDataType
{
// TODO(Superjomn) make all the Types' constructs protected to make sure there
// is only one instance across the system.
class
VoidTy
:
public
Type
{
public:
VoidTy
()
:
Type
(
ID
::
Void
,
"Void"
,
false
/*is_tensor*/
)
{}
};
class
UnsupportedTy
:
public
Type
{
public:
public:
Void
()
:
DeviceDataType
(
ID
::
Void
,
"Voi
d"
,
false
/*is_tensor*/
)
{}
UnsupportedTy
()
:
Type
(
ID
::
Unsupported
,
"Unsupporte
d"
,
false
/*is_tensor*/
)
{}
};
};
class
TensorFp32NCHW
:
public
DeviceData
Type
{
class
TensorFp32NCHW
Ty
:
public
Type
{
public:
public:
TensorFp32NCHW
(
TargetType
target
)
TensorFp32NCHWTy
(
TargetType
target
)
:
DeviceDataType
(
ID
::
Tensor_Fp32_NCHW
,
"TensorFp32NCHW"
,
:
Type
(
ID
::
Tensor_Fp32_NCHW
,
"TensorFp32NCHW"
,
true
/*is_tensor*/
,
target
,
true
/*is_tensor*/
,
target
,
PrecisionType
::
kFloat
,
PrecisionType
::
kFloat
,
DataLayoutType
::
kNCHW
)
{}
DataLayoutType
::
kNCHW
)
{}
};
};
class
TensorInt8NCHW
:
public
DeviceData
Type
{
class
TensorInt8NCHW
Ty
:
public
Type
{
public:
public:
TensorInt8NCHW
(
TargetType
target
)
TensorInt8NCHWTy
(
TargetType
target
)
:
DeviceDataType
(
ID
::
Tensor_Int8_NCHW
,
"TensorInt8NCHW"
,
:
Type
(
ID
::
Tensor_Int8_NCHW
,
"TensorInt8NCHW"
,
true
/*is_tensor*/
,
target
,
true
/*is_tensor*/
,
target
,
PrecisionType
::
kInt8
,
PrecisionType
::
kInt8
,
DataLayoutType
::
kNCHW
)
{}
DataLayoutType
::
kNCHW
)
{}
};
};
class
TensorInt64NCHW
:
public
DeviceData
Type
{
class
TensorInt64NCHW
Ty
:
public
Type
{
public:
public:
TensorInt64NCHW
(
TargetType
target
)
TensorInt64NCHWTy
(
TargetType
target
)
:
DeviceDataType
(
ID
::
Tensor_Int64_NCHW
,
"TensorInt64NCHW"
,
:
Type
(
ID
::
Tensor_Int64_NCHW
,
"TensorInt64NCHW"
,
true
/*is_tensor*/
,
true
/*is_tensor*/
,
target
,
PrecisionType
::
kInt8
,
target
,
PrecisionType
::
kInt8
,
DataLayoutType
::
kNCHW
)
{}
DataLayoutType
::
kNCHW
)
{}
};
};
// ------------------------- end predefined types ---------------------------
// ------------------------- end predefined types ---------------------------
...
...
paddle/fluid/lite/kernels/host/fc_compute.cc
浏览文件 @
8b950a4f
...
@@ -15,6 +15,7 @@
...
@@ -15,6 +15,7 @@
#include "paddle/fluid/lite/kernels/host/fc_compute.h"
#include "paddle/fluid/lite/kernels/host/fc_compute.h"
#include <Eigen/Core>
#include <Eigen/Core>
#include "paddle/fluid/lite/core/op_registry.h"
#include "paddle/fluid/lite/core/op_registry.h"
#include "paddle/fluid/lite/core/type_system.h"
namespace
paddle
{
namespace
paddle
{
namespace
lite
{
namespace
lite
{
...
@@ -51,6 +52,8 @@ PrecisionType FcCompute::precision() const { return PRECISION(kFloat); }
...
@@ -51,6 +52,8 @@ PrecisionType FcCompute::precision() const { return PRECISION(kFloat); }
}
// namespace paddle
}
// namespace paddle
REGISTER_LITE_KERNEL
(
fc
,
kHost
,
kFloat
,
paddle
::
lite
::
kernels
::
host
::
FcCompute
)
REGISTER_LITE_KERNEL
(
fc
,
kHost
,
kFloat
,
paddle
::
lite
::
kernels
::
host
::
FcCompute
)
.
BindInput
(
0
,
{
typeid
(
paddle
::
lite
::
Tensor
).
hash_code
(),
.
BindInput
(
0
,
{
paddle
::
lite
::
Type
::
Get
<
paddle
::
lite
::
TensorFp32NCHWTy
>
(
paddle
::
lite
::
Place
{
TARGET
(
kHost
),
PRECISION
(
kFloat
)}})
TARGET
(
kX86
))})
.
BindOutput
(
0
,
{
paddle
::
lite
::
Type
::
Get
<
paddle
::
lite
::
TensorFp32NCHWTy
>
(
TARGET
(
kX86
))})
.
Finalize
();
.
Finalize
();
paddle/fluid/lite/operators/fc_op.h
浏览文件 @
8b950a4f
...
@@ -12,6 +12,8 @@
...
@@ -12,6 +12,8 @@
// See the License for the specific language governing permissions and
// See the License for the specific language governing permissions and
// limitations under the License.
// limitations under the License.
#pragma once
#include <string>
#include <string>
#include <vector>
#include <vector>
#include "paddle/fluid/lite/core/kernel.h"
#include "paddle/fluid/lite/core/kernel.h"
...
@@ -44,7 +46,8 @@ class FcOpLite : public OpLite {
...
@@ -44,7 +46,8 @@ class FcOpLite : public OpLite {
*/
*/
// TODO(Superjomn) replace framework::OpDesc with a lite one.
// TODO(Superjomn) replace framework::OpDesc with a lite one.
bool
Attach
(
const
framework
::
OpDesc
&
op_desc
,
lite
::
Scope
*
scope
)
override
{
bool
AttachImpl
(
const
framework
::
OpDesc
&
op_desc
,
lite
::
Scope
*
scope
)
override
{
auto
input
=
op_desc
.
Input
(
"Input"
).
front
();
auto
input
=
op_desc
.
Input
(
"Input"
).
front
();
auto
W
=
op_desc
.
Input
(
"W"
).
front
();
auto
W
=
op_desc
.
Input
(
"W"
).
front
();
auto
bias
=
op_desc
.
Input
(
"Bias"
).
front
();
auto
bias
=
op_desc
.
Input
(
"Bias"
).
front
();
...
...
paddle/fluid/lite/operators/fc_op_test.cc
浏览文件 @
8b950a4f
...
@@ -61,7 +61,7 @@ TEST(fc_op_lite, test) {
...
@@ -61,7 +61,7 @@ TEST(fc_op_lite, test) {
fc
.
SetValidPlaces
({
Place
{
TARGET
(
kHost
),
PRECISION
(
kFloat
)}});
fc
.
SetValidPlaces
({
Place
{
TARGET
(
kHost
),
PRECISION
(
kFloat
)}});
fc
.
PickKernel
({
Place
{
TARGET
(
kHost
),
PRECISION
(
kFloat
)}});
fc
.
PickKernel
({
Place
{
TARGET
(
kHost
),
PRECISION
(
kFloat
)}});
fc
.
Attach
(
desc
,
&
scope
);
fc
.
Attach
Impl
(
desc
,
&
scope
);
fc
.
Run
();
fc
.
Run
();
for
(
int
i
=
0
;
i
<
10
*
20
;
i
++
)
{
for
(
int
i
=
0
;
i
<
10
*
20
;
i
++
)
{
...
...
paddle/fluid/lite/operators/mul_op.h
浏览文件 @
8b950a4f
...
@@ -37,7 +37,8 @@ class MulOpLite : public OpLite {
...
@@ -37,7 +37,8 @@ class MulOpLite : public OpLite {
bool
InferShape
()
const
override
;
bool
InferShape
()
const
override
;
// TODO(Superjomn) replace framework::OpDesc with a lite one.
// TODO(Superjomn) replace framework::OpDesc with a lite one.
bool
Attach
(
const
framework
::
OpDesc
&
op_desc
,
lite
::
Scope
*
scope
)
override
{
bool
AttachImpl
(
const
framework
::
OpDesc
&
op_desc
,
lite
::
Scope
*
scope
)
override
{
auto
input
=
op_desc
.
Input
(
"X"
).
front
();
auto
input
=
op_desc
.
Input
(
"X"
).
front
();
auto
W
=
op_desc
.
Input
(
"Y"
).
front
();
auto
W
=
op_desc
.
Input
(
"Y"
).
front
();
auto
out
=
op_desc
.
Output
(
"Out"
).
front
();
auto
out
=
op_desc
.
Output
(
"Out"
).
front
();
...
...
paddle/fluid/lite/operators/relu_op.cc
浏览文件 @
8b950a4f
...
@@ -31,7 +31,7 @@ bool ReluOp::InferShape() const {
...
@@ -31,7 +31,7 @@ bool ReluOp::InferShape() const {
return
true
;
return
true
;
}
}
bool
ReluOp
::
Attach
(
const
framework
::
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
{
bool
ReluOp
::
Attach
Impl
(
const
framework
::
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
{
param_
.
input
=
const_cast
<
Tensor
*>
(
param_
.
input
=
const_cast
<
Tensor
*>
(
&
scope
->
FindVar
(
opdesc
.
Input
(
"Input"
).
front
())
->
Get
<
Tensor
>
());
&
scope
->
FindVar
(
opdesc
.
Input
(
"Input"
).
front
())
->
Get
<
Tensor
>
());
param_
.
output
=
param_
.
output
=
...
...
paddle/fluid/lite/operators/relu_op.h
浏览文件 @
8b950a4f
...
@@ -32,7 +32,7 @@ class ReluOp : public OpLite {
...
@@ -32,7 +32,7 @@ class ReluOp : public OpLite {
bool
InferShape
()
const
override
;
bool
InferShape
()
const
override
;
bool
Attach
(
const
framework
::
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
override
;
bool
Attach
Impl
(
const
framework
::
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
override
;
std
::
string
DebugString
()
const
override
{
return
"tanh"
;
}
std
::
string
DebugString
()
const
override
{
return
"tanh"
;
}
...
...
paddle/fluid/lite/operators/scale_op.cc
浏览文件 @
8b950a4f
...
@@ -44,7 +44,8 @@ class ScaleOp : public OpLite {
...
@@ -44,7 +44,8 @@ class ScaleOp : public OpLite {
}
}
// TODO(Superjomn) replace framework::OpDesc with a lite one.
// TODO(Superjomn) replace framework::OpDesc with a lite one.
bool
Attach
(
const
framework
::
OpDesc
&
op_desc
,
lite
::
Scope
*
scope
)
override
{
bool
AttachImpl
(
const
framework
::
OpDesc
&
op_desc
,
lite
::
Scope
*
scope
)
override
{
auto
x
=
op_desc
.
Input
(
"X"
).
front
();
auto
x
=
op_desc
.
Input
(
"X"
).
front
();
auto
out
=
op_desc
.
Output
(
"Out"
).
front
();
auto
out
=
op_desc
.
Output
(
"Out"
).
front
();
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录