Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PaddleDetection
提交
65be35af
P
PaddleDetection
项目概览
PaddlePaddle
/
PaddleDetection
1 年多 前同步成功
通知
699
Star
11112
Fork
2696
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
184
列表
看板
标记
里程碑
合并请求
40
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleDetection
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
184
Issue
184
列表
看板
标记
里程碑
合并请求
40
合并请求
40
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
65be35af
编写于
6月 04, 2019
作者:
Y
Yan Chunwei
提交者:
GitHub
6月 04, 2019
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Lite/refactor cp desc (#17831)
上级
b7cf0984
变更
45
隐藏空白更改
内联
并排
Showing
45 changed file
with
774 addition
and
331 deletion
+774
-331
paddle/fluid/lite/core/CMakeLists.txt
paddle/fluid/lite/core/CMakeLists.txt
+3
-1
paddle/fluid/lite/core/mir/type_target_transform_pass.cc
paddle/fluid/lite/core/mir/type_target_transform_pass.cc
+5
-7
paddle/fluid/lite/core/mir/type_target_transform_pass.h
paddle/fluid/lite/core/mir/type_target_transform_pass.h
+3
-3
paddle/fluid/lite/core/mir/variable_place_inference_pass.h
paddle/fluid/lite/core/mir/variable_place_inference_pass.h
+2
-2
paddle/fluid/lite/core/op_lite.cc
paddle/fluid/lite/core/op_lite.cc
+3
-92
paddle/fluid/lite/core/op_lite.h
paddle/fluid/lite/core/op_lite.h
+57
-36
paddle/fluid/lite/core/program.cc
paddle/fluid/lite/core/program.cc
+4
-4
paddle/fluid/lite/core/program.h
paddle/fluid/lite/core/program.h
+6
-2
paddle/fluid/lite/model_parser/CMakeLists.txt
paddle/fluid/lite/model_parser/CMakeLists.txt
+4
-5
paddle/fluid/lite/model_parser/compatible_pb.cc
paddle/fluid/lite/model_parser/compatible_pb.cc
+111
-0
paddle/fluid/lite/model_parser/compatible_pb.h
paddle/fluid/lite/model_parser/compatible_pb.h
+7
-18
paddle/fluid/lite/model_parser/cpp/CMakeLists.txt
paddle/fluid/lite/model_parser/cpp/CMakeLists.txt
+1
-0
paddle/fluid/lite/model_parser/cpp/op_desc.cc
paddle/fluid/lite/model_parser/cpp/op_desc.cc
+66
-0
paddle/fluid/lite/model_parser/cpp/op_desc.h
paddle/fluid/lite/model_parser/cpp/op_desc.h
+126
-0
paddle/fluid/lite/model_parser/desc_apis.h
paddle/fluid/lite/model_parser/desc_apis.h
+85
-0
paddle/fluid/lite/model_parser/op_desc_test.cc
paddle/fluid/lite/model_parser/op_desc_test.cc
+107
-0
paddle/fluid/lite/model_parser/pb/op_desc.cc
paddle/fluid/lite/model_parser/pb/op_desc.cc
+80
-19
paddle/fluid/lite/model_parser/pb/op_desc.h
paddle/fluid/lite/model_parser/pb/op_desc.h
+38
-78
paddle/fluid/lite/operators/activation_ops.cc
paddle/fluid/lite/operators/activation_ops.cc
+2
-2
paddle/fluid/lite/operators/concat_op.cc
paddle/fluid/lite/operators/concat_op.cc
+2
-2
paddle/fluid/lite/operators/concat_op.h
paddle/fluid/lite/operators/concat_op.h
+1
-1
paddle/fluid/lite/operators/concat_op_test.cc
paddle/fluid/lite/operators/concat_op_test.cc
+1
-1
paddle/fluid/lite/operators/dropout_op.cc
paddle/fluid/lite/operators/dropout_op.cc
+6
-6
paddle/fluid/lite/operators/elementwise_ops.cc
paddle/fluid/lite/operators/elementwise_ops.cc
+5
-5
paddle/fluid/lite/operators/fc_op.h
paddle/fluid/lite/operators/fc_op.h
+2
-2
paddle/fluid/lite/operators/fc_op_test.cc
paddle/fluid/lite/operators/fc_op_test.cc
+1
-1
paddle/fluid/lite/operators/feed_op.cc
paddle/fluid/lite/operators/feed_op.cc
+2
-2
paddle/fluid/lite/operators/fetch_op.cc
paddle/fluid/lite/operators/fetch_op.cc
+2
-2
paddle/fluid/lite/operators/fill_constant_op.cc
paddle/fluid/lite/operators/fill_constant_op.cc
+5
-5
paddle/fluid/lite/operators/io_copy_op.cc
paddle/fluid/lite/operators/io_copy_op.cc
+2
-1
paddle/fluid/lite/operators/io_copy_op.h
paddle/fluid/lite/operators/io_copy_op.h
+1
-1
paddle/fluid/lite/operators/mean_op.cc
paddle/fluid/lite/operators/mean_op.cc
+3
-3
paddle/fluid/lite/operators/mul_op.cc
paddle/fluid/lite/operators/mul_op.cc
+1
-1
paddle/fluid/lite/operators/mul_op.h
paddle/fluid/lite/operators/mul_op.h
+4
-4
paddle/fluid/lite/operators/relu_op.cc
paddle/fluid/lite/operators/relu_op.cc
+1
-1
paddle/fluid/lite/operators/relu_op.h
paddle/fluid/lite/operators/relu_op.h
+1
-1
paddle/fluid/lite/operators/reshape_op.cc
paddle/fluid/lite/operators/reshape_op.cc
+5
-5
paddle/fluid/lite/operators/reshape_op.h
paddle/fluid/lite/operators/reshape_op.h
+2
-2
paddle/fluid/lite/operators/reshape_op_test.cc
paddle/fluid/lite/operators/reshape_op_test.cc
+4
-4
paddle/fluid/lite/operators/scale_op.cc
paddle/fluid/lite/operators/scale_op.cc
+4
-4
paddle/fluid/lite/operators/scale_op.h
paddle/fluid/lite/operators/scale_op.h
+1
-1
paddle/fluid/lite/operators/scale_op_test.cc
paddle/fluid/lite/operators/scale_op_test.cc
+2
-2
paddle/fluid/lite/operators/softmax_op.cc
paddle/fluid/lite/operators/softmax_op.cc
+4
-3
paddle/fluid/lite/operators/softmax_op.h
paddle/fluid/lite/operators/softmax_op.h
+1
-1
paddle/fluid/lite/operators/softmax_op_test.cc
paddle/fluid/lite/operators/softmax_op_test.cc
+1
-1
未找到文件。
paddle/fluid/lite/core/CMakeLists.txt
浏览文件 @
65be35af
...
@@ -25,7 +25,9 @@ cc_library(op_registry_lite SRCS op_registry.cc DEPS framework_proto_lite)
...
@@ -25,7 +25,9 @@ cc_library(op_registry_lite SRCS op_registry.cc DEPS framework_proto_lite)
cc_library
(
scope_lite SRCS scope.cc DEPS
${
tensor_lite
}
)
cc_library
(
scope_lite SRCS scope.cc DEPS
${
tensor_lite
}
)
cc_library
(
cpu_info_lite SRCS cpu_info.cc
)
cc_library
(
cpu_info_lite SRCS cpu_info.cc
)
cc_library
(
context_lite SRCS context.cc DEPS
${
tensor_lite
}
any_lite cpu_info_lite
)
cc_library
(
context_lite SRCS context.cc DEPS
${
tensor_lite
}
any_lite cpu_info_lite
)
cc_library
(
op_lite SRCS op_lite.cc DEPS scope_lite op_registry_lite compatible_pb_lite target_wrapper_lite
${
tensor_lite
}
)
cc_library
(
op_lite SRCS op_lite.cc DEPS scope_lite op_registry_lite target_wrapper_lite
cpp_op_desc_lite
${
tensor_lite
}
)
cc_library
(
types_lite SRCS types.cc
)
cc_library
(
types_lite SRCS types.cc
)
cc_library
(
type_system SRCS type_system.cc DEPS
${
tensor_lite
}
target_wrapper_lite
)
cc_library
(
type_system SRCS type_system.cc DEPS
${
tensor_lite
}
target_wrapper_lite
)
...
...
paddle/fluid/lite/core/mir/type_target_transform_pass.cc
浏览文件 @
65be35af
...
@@ -90,7 +90,7 @@ void TypeTargetTransformPass::AddIoCopyInst(
...
@@ -90,7 +90,7 @@ void TypeTargetTransformPass::AddIoCopyInst(
inst_node
->
AsStmt
().
op
->
scope
()
->
Var
(
io_copy_output_name
);
inst_node
->
AsStmt
().
op
->
scope
()
->
Var
(
io_copy_output_name
);
// Create IoCopy Instruction.
// Create IoCopy Instruction.
lite
::
OpDesc
op_desc
;
cpp
::
OpDesc
op_desc
;
op_desc
.
SetType
(
"io_copy"
);
op_desc
.
SetType
(
"io_copy"
);
op_desc
.
SetInput
(
"Input"
,
{
var
});
op_desc
.
SetInput
(
"Input"
,
{
var
});
op_desc
.
SetOutput
(
"Out"
,
{
io_copy_output_name
});
op_desc
.
SetOutput
(
"Out"
,
{
io_copy_output_name
});
...
@@ -104,8 +104,6 @@ void TypeTargetTransformPass::AddIoCopyInst(
...
@@ -104,8 +104,6 @@ void TypeTargetTransformPass::AddIoCopyInst(
// Update the original instruction OpDesc.
// Update the original instruction OpDesc.
// Update its input to the io_copy_output_name
// Update its input to the io_copy_output_name
auto
&
inst
=
inst_node
->
AsStmt
();
auto
inst_program_desc
=
inst
.
op_info
()
->
desc
();
// Add new link, var -> new_inst, new_inst->newarg, newarg->inst
// Add new link, var -> new_inst, new_inst->newarg, newarg->inst
DirectedLink
(
graph
->
Argument
(
var
),
io_copy_inst
);
DirectedLink
(
graph
->
Argument
(
var
),
io_copy_inst
);
...
@@ -113,11 +111,11 @@ void TypeTargetTransformPass::AddIoCopyInst(
...
@@ -113,11 +111,11 @@ void TypeTargetTransformPass::AddIoCopyInst(
DirectedLink
(
io_copy_output_arg
,
inst_node
);
DirectedLink
(
io_copy_output_arg
,
inst_node
);
// reset opdesc and update kernel information
// reset opdesc and update kernel information
auto
desc_dummy
=
inst_node
->
AsStmt
().
op
->
op_info
()
->
desc
();
UpdateInputTo
(
inst_node
->
AsStmt
().
op
->
mutable_op_info
(),
var
,
UpdateInputTo
(
&
desc_dummy
,
var
,
io_copy_output_name
);
io_copy_output_name
);
lite
::
OpDesc
desc_fake
(
desc_dummy
);
inst_node
->
AsStmt
().
op
->
Attach
(
*
inst_node
->
AsStmt
().
op
->
op_info
(),
inst_node
->
AsStmt
().
op
->
Attach
(
desc_fake
,
inst_node
->
AsStmt
().
op
->
scope
());
inst_node
->
AsStmt
().
op
->
scope
());
std
::
string
tmp
;
std
::
string
tmp
;
if
(
inst_node
->
AsStmt
().
op_info
()
->
GetInputArgname
(
"a"
,
&
tmp
))
{
if
(
inst_node
->
AsStmt
().
op_info
()
->
GetInputArgname
(
"a"
,
&
tmp
))
{
...
...
paddle/fluid/lite/core/mir/type_target_transform_pass.h
浏览文件 @
65be35af
...
@@ -24,10 +24,10 @@ namespace paddle {
...
@@ -24,10 +24,10 @@ namespace paddle {
namespace
lite
{
namespace
lite
{
namespace
mir
{
namespace
mir
{
static
void
UpdateInputTo
(
framework
::
proto
::
OpDesc
*
desc
,
static
void
UpdateInputTo
(
cpp
::
OpDesc
*
desc
,
const
std
::
string
&
from
,
const
std
::
string
&
from
,
const
std
::
string
&
to
)
{
const
std
::
string
&
to
)
{
for
(
auto
&
item
:
*
desc
->
mutable_inputs
())
{
for
(
auto
&
item
:
*
desc
->
mutable_inputs
())
{
for
(
auto
&
input
:
*
item
.
mutable_arguments
()
)
{
for
(
auto
&
input
:
item
.
second
)
{
if
(
input
==
from
)
{
if
(
input
==
from
)
{
input
=
to
;
input
=
to
;
}
}
...
...
paddle/fluid/lite/core/mir/variable_place_inference_pass.h
浏览文件 @
65be35af
...
@@ -65,7 +65,7 @@ class VariablePlaceInferencePass : public DebugPass {
...
@@ -65,7 +65,7 @@ class VariablePlaceInferencePass : public DebugPass {
// check if inputs's place is set, if not set, update them with the
// check if inputs's place is set, if not set, update them with the
// kernel's declaration.
// kernel's declaration.
auto
type
=
inst
.
picked_kernel
().
GetInputDeclType
(
arg_name
);
auto
type
=
inst
.
picked_kernel
().
GetInputDeclType
(
arg_name
);
auto
arg_names
=
inst
.
op_info
()
->
input
_argument
().
at
(
arg_name
);
auto
arg_names
=
inst
.
op_info
()
->
input
s
().
at
(
arg_name
);
for
(
auto
&
arg_name
:
arg_names
)
{
for
(
auto
&
arg_name
:
arg_names
)
{
VLOG
(
3
)
<<
"--- var "
<<
arg_name
;
VLOG
(
3
)
<<
"--- var "
<<
arg_name
;
...
@@ -82,7 +82,7 @@ class VariablePlaceInferencePass : public DebugPass {
...
@@ -82,7 +82,7 @@ class VariablePlaceInferencePass : public DebugPass {
for
(
auto
&
arg_name
:
inst
.
op_info
()
->
output_argnames
())
{
for
(
auto
&
arg_name
:
inst
.
op_info
()
->
output_argnames
())
{
VLOG
(
3
)
<<
"-- output arg_name "
<<
arg_name
;
VLOG
(
3
)
<<
"-- output arg_name "
<<
arg_name
;
auto
type
=
inst
.
picked_kernel
().
GetOutputDeclType
(
arg_name
);
auto
type
=
inst
.
picked_kernel
().
GetOutputDeclType
(
arg_name
);
auto
arg_names
=
inst
.
op_info
()
->
output
_argument
().
at
(
arg_name
);
auto
arg_names
=
inst
.
op_info
()
->
output
s
().
at
(
arg_name
);
// check if outputs's place is set, if not set, update them with the
// check if outputs's place is set, if not set, update them with the
// kernel's declaration.
// kernel's declaration.
for
(
auto
&
arg_name
:
arg_names
)
{
for
(
auto
&
arg_name
:
arg_names
)
{
...
...
paddle/fluid/lite/core/op_lite.cc
浏览文件 @
65be35af
...
@@ -68,13 +68,13 @@ bool OpLite::Run() {
...
@@ -68,13 +68,13 @@ bool OpLite::Run() {
return
true
;
return
true
;
}
}
bool
OpLite
::
Attach
(
const
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
{
bool
OpLite
::
Attach
(
const
cpp
::
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
{
// valid_places_.clear();
// valid_places_.clear();
CHECK
(
scope
!=
nullptr
);
CHECK
(
scope
!=
nullptr
);
// CHECK(!op_info_.get());
// CHECK(!op_info_.get());
scope_
=
scope
;
scope_
=
scope
;
op_info_
.
reset
(
new
OpInfo
);
// Force clean the out-of-date infomation.
op_info_
.
reset
(
op_info_
->
Build
(
opdesc
.
ReadonlyProto
());
new
OpInfo
(
opdesc
));
// Force clean the out-of-date infomation.
return
AttachImpl
(
opdesc
,
scope
);
return
AttachImpl
(
opdesc
,
scope
);
}
}
...
@@ -92,94 +92,5 @@ Tensor *OpLite::GetMutableTensor(lite::Scope *scope,
...
@@ -92,94 +92,5 @@ Tensor *OpLite::GetMutableTensor(lite::Scope *scope,
return
var
->
GetMutable
<
lite
::
Tensor
>
();
return
var
->
GetMutable
<
lite
::
Tensor
>
();
}
}
bool
OpInfo
::
GetInputArgname
(
const
std
::
string
&
value_name
,
std
::
string
*
out
)
const
{
for
(
auto
&
item
:
input_argument_
)
{
auto
it
=
std
::
find
(
item
.
second
.
begin
(),
item
.
second
.
end
(),
value_name
);
if
(
it
!=
item
.
second
.
end
())
{
*
out
=
item
.
first
;
return
true
;
}
}
return
false
;
}
bool
OpInfo
::
GetOutputArgname
(
const
std
::
string
&
value_name
,
std
::
string
*
out
)
const
{
for
(
auto
&
item
:
output_argument_
)
{
auto
it
=
std
::
find
(
item
.
second
.
begin
(),
item
.
second
.
end
(),
value_name
);
if
(
it
!=
item
.
second
.
end
())
{
*
out
=
item
.
first
;
return
true
;
}
}
return
false
;
}
void
OpInfo
::
ExtractInputsAndOutputs
(
const
framework
::
proto
::
OpDesc
&
opdesc
)
{
for
(
const
auto
&
item
:
opdesc
.
inputs
())
{
for
(
const
auto
&
x
:
item
.
arguments
())
{
input_names_
.
push_back
(
x
);
}
}
for
(
const
auto
&
item
:
opdesc
.
outputs
())
{
for
(
const
auto
&
x
:
item
.
arguments
())
{
output_names_
.
push_back
(
x
);
}
}
}
void
OpInfo
::
CollectInputAndOutputArgnames
(
const
framework
::
proto
::
OpDesc
&
opdesc
)
{
for
(
const
auto
&
item
:
opdesc
.
inputs
())
{
input_argnames_
.
push_back
(
item
.
parameter
());
}
for
(
const
auto
&
item
:
opdesc
.
outputs
())
{
output_argnames_
.
push_back
(
item
.
parameter
());
}
}
void
OpInfo
::
CollectArguments
(
const
framework
::
proto
::
OpDesc
&
opdesc
)
{
for
(
const
auto
&
item
:
opdesc
.
inputs
())
{
for
(
auto
&
x
:
item
.
arguments
())
{
input_argument_
[
item
.
parameter
()].
push_back
(
x
);
}
}
for
(
const
auto
&
item
:
opdesc
.
outputs
())
{
for
(
auto
&
x
:
item
.
arguments
())
{
output_argument_
[
item
.
parameter
()].
push_back
(
x
);
}
}
}
void
OpInfo
::
Build
(
const
framework
::
proto
::
OpDesc
&
desc
)
{
ExtractInputsAndOutputs
(
desc
);
CollectInputAndOutputArgnames
(
desc
);
CollectArguments
(
desc
);
desc_
.
reset
(
new
framework
::
proto
::
OpDesc
(
desc
));
}
const
std
::
map
<
std
::
string
,
std
::
list
<
std
::
string
>>
&
OpInfo
::
input_argument
()
const
{
return
input_argument_
;
}
const
std
::
map
<
std
::
string
,
std
::
list
<
std
::
string
>>
&
OpInfo
::
output_argument
()
const
{
return
output_argument_
;
}
const
std
::
list
<
std
::
string
>
&
OpInfo
::
input_argnames
()
const
{
return
input_argnames_
;
}
const
std
::
list
<
std
::
string
>
&
OpInfo
::
output_argnames
()
const
{
return
output_argnames_
;
}
const
framework
::
proto
::
OpDesc
&
OpInfo
::
desc
()
const
{
CHECK
(
desc_
)
<<
"desc has't set"
;
return
*
desc_
;
}
}
// namespace lite
}
// namespace lite
}
// namespace paddle
}
// namespace paddle
paddle/fluid/lite/core/op_lite.h
浏览文件 @
65be35af
...
@@ -23,7 +23,7 @@
...
@@ -23,7 +23,7 @@
#include "paddle/fluid/lite/core/context.h"
#include "paddle/fluid/lite/core/context.h"
#include "paddle/fluid/lite/core/kernel.h"
#include "paddle/fluid/lite/core/kernel.h"
#include "paddle/fluid/lite/core/scope.h"
#include "paddle/fluid/lite/core/scope.h"
#include "paddle/fluid/lite/model_parser/c
ompatible_pb
.h"
#include "paddle/fluid/lite/model_parser/c
pp/op_desc
.h"
namespace
paddle
{
namespace
paddle
{
namespace
lite
{
namespace
lite
{
...
@@ -71,7 +71,7 @@ class OpLite : public Registry {
...
@@ -71,7 +71,7 @@ class OpLite : public Registry {
virtual
bool
Run
();
virtual
bool
Run
();
// Link the external execution environ to internal context.
// Link the external execution environ to internal context.
bool
Attach
(
const
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
);
bool
Attach
(
const
cpp
::
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
);
const
OpInfo
*
op_info
()
const
{
return
op_info_
.
get
();
}
const
OpInfo
*
op_info
()
const
{
return
op_info_
.
get
();
}
OpInfo
*
mutable_op_info
()
{
return
op_info_
.
get
();
}
OpInfo
*
mutable_op_info
()
{
return
op_info_
.
get
();
}
...
@@ -94,7 +94,7 @@ class OpLite : public Registry {
...
@@ -94,7 +94,7 @@ class OpLite : public Registry {
protected:
protected:
// Attach it with the runtime environment.
// Attach it with the runtime environment.
virtual
bool
AttachImpl
(
const
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
=
0
;
virtual
bool
AttachImpl
(
const
cpp
::
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
=
0
;
// Specify the kernel to run by default. This will specify the value of
// Specify the kernel to run by default. This will specify the value of
// `kernel_place_`.
// `kernel_place_`.
...
@@ -144,40 +144,61 @@ class OpLite : public Registry {
...
@@ -144,40 +144,61 @@ class OpLite : public Registry {
* Operator Information, such as some description. It will be shared by all the
* Operator Information, such as some description. It will be shared by all the
* kernels of the same operator.
* kernels of the same operator.
*/
*/
class
OpInfo
{
class
OpInfo
:
public
cpp
::
OpDesc
{
public:
public:
// To avoid the bugs from legancy framework::OpDesc, we use the ProtoBuf
OpInfo
(
const
OpInfo
&
)
=
default
;
// message instead.
OpInfo
(
const
cpp
::
OpDesc
&
other
)
:
cpp
::
OpDesc
(
other
)
{}
void
Build
(
const
framework
::
proto
::
OpDesc
&
desc
);
// Collect all the input variable's name.
const
framework
::
proto
::
OpDesc
&
desc
()
const
;
std
::
vector
<
std
::
string
>
input_names
()
const
{
framework
::
proto
::
OpDesc
*
mutable_desc
()
{
return
desc_
.
get
();
}
std
::
vector
<
std
::
string
>
res
;
const
std
::
list
<
std
::
string
>
&
input_names
()
const
{
return
input_names_
;
}
for
(
auto
&
param
:
InputArgumentNames
())
{
const
std
::
list
<
std
::
string
>
&
output_names
()
const
{
return
output_names_
;
}
for
(
auto
&
x
:
Input
(
param
))
{
const
std
::
map
<
std
::
string
,
std
::
list
<
std
::
string
>>
&
input_argument
()
const
;
res
.
push_back
(
x
);
const
std
::
map
<
std
::
string
,
std
::
list
<
std
::
string
>>
&
output_argument
()
const
;
}
bool
GetInputArgname
(
const
std
::
string
&
value_name
,
std
::
string
*
out
)
const
;
}
bool
GetOutputArgname
(
const
std
::
string
&
value_name
,
std
::
string
*
out
)
const
;
return
res
;
}
const
std
::
list
<
std
::
string
>
&
input_argnames
()
const
;
const
std
::
list
<
std
::
string
>
&
output_argnames
()
const
;
// Collect all the output variable's name.
std
::
vector
<
std
::
string
>
output_names
()
const
{
private:
std
::
vector
<
std
::
string
>
res
;
void
ExtractInputsAndOutputs
(
const
framework
::
proto
::
OpDesc
&
opdesc
);
for
(
auto
&
param
:
OutputArgumentNames
())
{
for
(
auto
&
x
:
Output
(
param
))
{
void
CollectInputAndOutputArgnames
(
const
framework
::
proto
::
OpDesc
&
opdesc
);
res
.
push_back
(
x
);
}
void
CollectArguments
(
const
framework
::
proto
::
OpDesc
&
opdesc
);
}
return
res
;
private:
}
std
::
list
<
std
::
string
>
input_names_
;
std
::
list
<
std
::
string
>
output_names_
;
std
::
vector
<
std
::
string
>
input_argnames
()
const
{
std
::
list
<
std
::
string
>
input_argnames_
;
return
InputArgumentNames
();
std
::
list
<
std
::
string
>
output_argnames_
;
}
std
::
map
<
std
::
string
,
std
::
list
<
std
::
string
>>
input_argument_
;
std
::
map
<
std
::
string
,
std
::
list
<
std
::
string
>>
output_argument_
;
std
::
vector
<
std
::
string
>
output_argnames
()
const
{
// NOTE too heavy.
return
OutputArgumentNames
();
std
::
unique_ptr
<
framework
::
proto
::
OpDesc
>
desc_
;
}
bool
GetInputArgname
(
const
std
::
string
&
value_name
,
std
::
string
*
out
)
const
{
for
(
auto
&
item
:
inputs_
)
{
auto
it
=
std
::
find
(
item
.
second
.
begin
(),
item
.
second
.
end
(),
value_name
);
if
(
it
!=
item
.
second
.
end
())
{
*
out
=
item
.
first
;
return
true
;
}
}
return
false
;
}
bool
GetOutputArgname
(
const
std
::
string
&
value_name
,
std
::
string
*
out
)
const
{
for
(
auto
&
item
:
outputs_
)
{
auto
it
=
std
::
find
(
item
.
second
.
begin
(),
item
.
second
.
end
(),
value_name
);
if
(
it
!=
item
.
second
.
end
())
{
*
out
=
item
.
first
;
return
true
;
}
}
return
false
;
}
};
};
}
// namespace lite
}
// namespace lite
...
...
paddle/fluid/lite/core/program.cc
浏览文件 @
65be35af
...
@@ -39,11 +39,11 @@ std::string RuntimeProgram::SerializeProgram(
...
@@ -39,11 +39,11 @@ std::string RuntimeProgram::SerializeProgram(
auto
program_dummy
=
desc
;
auto
program_dummy
=
desc
;
program_dummy
.
mutable_blocks
(
0
)
->
clear_ops
();
program_dummy
.
mutable_blocks
(
0
)
->
clear_ops
();
for
(
auto
&
node
:
instructions_
)
{
for
(
auto
&
node
:
instructions_
)
{
auto
desc_dummy
=
node
.
op
()
->
op_info
()
->
desc
()
;
pb
::
OpDesc
pb_desc
;
OpDesc
desc
(
desc_dummy
);
TransformOpDescCppToPb
(
*
node
.
op
()
->
op_info
(),
&
pb_desc
);
desc
.
SetAttr
(
kKernelTypeAttr
,
node
.
kernel
()
->
SerializedKernelType
());
pb_
desc
.
SetAttr
(
kKernelTypeAttr
,
node
.
kernel
()
->
SerializedKernelType
());
// append new opdesc
// append new opdesc
*
program_dummy
.
mutable_blocks
(
0
)
->
add_ops
()
=
*
desc
.
Proto
();
*
program_dummy
.
mutable_blocks
(
0
)
->
add_ops
()
=
*
pb_
desc
.
Proto
();
}
}
return
program_dummy
.
SerializeAsString
();
return
program_dummy
.
SerializeAsString
();
}
}
...
...
paddle/fluid/lite/core/program.h
浏览文件 @
65be35af
...
@@ -22,6 +22,7 @@
...
@@ -22,6 +22,7 @@
#include "paddle/fluid/lite/core/mir/node.h"
#include "paddle/fluid/lite/core/mir/node.h"
#include "paddle/fluid/lite/core/op_lite.h"
#include "paddle/fluid/lite/core/op_lite.h"
#include "paddle/fluid/lite/core/op_registry.h"
#include "paddle/fluid/lite/core/op_registry.h"
#include "paddle/fluid/lite/model_parser/compatible_pb.h"
#ifdef LITE_WITH_PROFILE
#ifdef LITE_WITH_PROFILE
#include "paddle/fluid/lite/core/profile/basic_profiler.h"
#include "paddle/fluid/lite/core/profile/basic_profiler.h"
#endif // LITE_WITH_PROFILE
#endif // LITE_WITH_PROFILE
...
@@ -67,7 +68,7 @@ struct Program {
...
@@ -67,7 +68,7 @@ struct Program {
CHECK
(
ops
.
empty
())
<<
"Executor duplicate Build found"
;
CHECK
(
ops
.
empty
())
<<
"Executor duplicate Build found"
;
// Create operators.
// Create operators.
for
(
const
auto
&
proto_op_desc
:
program
.
blocks
(
0
).
ops
())
{
for
(
const
auto
&
proto_op_desc
:
program
.
blocks
(
0
).
ops
())
{
lite
::
OpDesc
op_desc
(
proto_op_desc
);
pb
::
OpDesc
op_desc
(
proto_op_desc
);
auto
op_type
=
op_desc
.
Type
();
auto
op_type
=
op_desc
.
Type
();
// if (op_type == "feed" || op_type == "fetch") continue;
// if (op_type == "feed" || op_type == "fetch") continue;
VLOG
(
4
)
<<
"create Op ["
<<
op_type
<<
"]"
;
VLOG
(
4
)
<<
"create Op ["
<<
op_type
<<
"]"
;
...
@@ -75,7 +76,10 @@ struct Program {
...
@@ -75,7 +76,10 @@ struct Program {
auto
op
=
LiteOpRegistry
::
Global
().
Create
(
op_type
);
auto
op
=
LiteOpRegistry
::
Global
().
Create
(
op_type
);
CHECK
(
op
)
<<
"no Op found for "
<<
op_type
;
CHECK
(
op
)
<<
"no Op found for "
<<
op_type
;
ops
.
emplace_back
(
std
::
move
(
op
));
ops
.
emplace_back
(
std
::
move
(
op
));
ops
.
back
()
->
Attach
(
op_desc
,
exec_scope
);
cpp
::
OpDesc
cpp_op_desc
;
TransformOpDescPbToCpp
(
op_desc
,
&
cpp_op_desc
);
ops
.
back
()
->
Attach
(
cpp_op_desc
,
exec_scope
);
}
}
}
}
...
...
paddle/fluid/lite/model_parser/CMakeLists.txt
浏览文件 @
65be35af
...
@@ -11,11 +11,7 @@ if(NOT LITE_WITH_LIGHT_WEIGHT_FRAMEWORK)
...
@@ -11,11 +11,7 @@ if(NOT LITE_WITH_LIGHT_WEIGHT_FRAMEWORK)
endif
()
endif
()
if
(
LITE_WITH_LIGHT_WEIGHT_FRAMEWORK
)
cc_library
(
compatible_pb_lite SRCS compatible_pb.cc DEPS op_desc_lite framework_proto_lite var_desc_lite
)
cc_library
(
compatible_pb_lite SRCS compatible_pb.cc DEPS op_desc_lite framework_proto_lite var_desc_lite
)
else
()
cc_library
(
compatible_pb_lite SRCS compatible_pb.cc DEPS framework_proto_lite proto_desc
)
endif
(
LITE_WITH_LIGHT_WEIGHT_FRAMEWORK
)
set
(
model_parser_deps variable_lite scope_lite
${
tensor_lite
}
scope_lite
set
(
model_parser_deps variable_lite scope_lite
${
tensor_lite
}
scope_lite
target_wrapper_host
target_wrapper_host
...
@@ -27,4 +23,7 @@ if (LITE_WITH_CUDA)
...
@@ -27,4 +23,7 @@ if (LITE_WITH_CUDA)
endif
()
endif
()
cc_library
(
model_parser_lite SRCS model_parser.cc DEPS
${
model_parser_deps
}
)
cc_library
(
model_parser_lite SRCS model_parser.cc DEPS
${
model_parser_deps
}
)
cc_test
(
test_op_desc_lite SRCS op_desc_test.cc DEPS cpp_op_desc_lite any_lite op_desc_lite compatible_pb_lite
)
add_subdirectory
(
pb
)
add_subdirectory
(
pb
)
add_subdirectory
(
cpp
)
paddle/fluid/lite/model_parser/compatible_pb.cc
浏览文件 @
65be35af
...
@@ -13,3 +13,114 @@
...
@@ -13,3 +13,114 @@
// limitations under the License.
// limitations under the License.
#include "paddle/fluid/lite/model_parser/compatible_pb.h"
#include "paddle/fluid/lite/model_parser/compatible_pb.h"
#include "compatible_pb.h"
namespace
paddle
{
namespace
lite
{
void
InputsPbToCpp
(
const
pb
::
OpDesc
&
pb_desc
,
cpp
::
OpDesc
*
cpp_desc
)
{
for
(
const
std
::
string
&
param
:
pb_desc
.
InputArgumentNames
())
{
cpp_desc
->
SetInput
(
param
,
pb_desc
.
Input
(
param
));
}
}
void
InputsCppToPb
(
const
cpp
::
OpDesc
&
cpp_desc
,
pb
::
OpDesc
*
pb_desc
)
{
for
(
const
std
::
string
&
param
:
cpp_desc
.
InputArgumentNames
())
{
pb_desc
->
SetInput
(
param
,
cpp_desc
.
Input
(
param
));
}
}
void
OutputsPbToCpp
(
const
pb
::
OpDesc
&
pb_desc
,
cpp
::
OpDesc
*
cpp_desc
)
{
for
(
const
std
::
string
&
param
:
pb_desc
.
OutputArgumentNames
())
{
cpp_desc
->
SetOutput
(
param
,
pb_desc
.
Output
(
param
));
}
}
void
OutputsCppToPb
(
const
cpp
::
OpDesc
&
cpp_desc
,
pb
::
OpDesc
*
pb_desc
)
{
for
(
const
std
::
string
&
param
:
cpp_desc
.
OutputArgumentNames
())
{
pb_desc
->
SetOutput
(
param
,
cpp_desc
.
Output
(
param
));
}
}
void
AttrsPbToCpp
(
const
pb
::
OpDesc
&
pb_desc
,
cpp
::
OpDesc
*
cpp_desc
)
{
using
AttrType
=
OpDescAPI
::
AttrType
;
auto
set_attr
=
[
&
](
const
std
::
string
&
name
,
AttrType
type
)
{
switch
(
type
)
{
case
AttrType
::
INT
:
cpp_desc
->
SetAttr
<
int32_t
>
(
name
,
pb_desc
.
GetAttr
<
int32_t
>
(
name
));
break
;
case
AttrType
::
FLOAT
:
cpp_desc
->
SetAttr
<
float
>
(
name
,
pb_desc
.
GetAttr
<
float
>
(
name
));
break
;
case
AttrType
::
STRING
:
cpp_desc
->
SetAttr
<
std
::
string
>
(
name
,
pb_desc
.
GetAttr
<
std
::
string
>
(
name
));
break
;
case
AttrType
::
INTS
:
cpp_desc
->
SetAttr
<
std
::
vector
<
int
>>
(
name
,
pb_desc
.
GetAttr
<
std
::
vector
<
int
>>
(
name
));
break
;
case
AttrType
::
FLOATS
:
cpp_desc
->
SetAttr
<
std
::
vector
<
float
>>
(
name
,
pb_desc
.
GetAttr
<
std
::
vector
<
float
>>
(
name
));
break
;
case
AttrType
::
BOOLEAN
:
cpp_desc
->
SetAttr
<
bool
>
(
name
,
pb_desc
.
GetAttr
<
bool
>
(
name
));
break
;
case
AttrType
::
STRINGS
:
cpp_desc
->
SetAttr
<
std
::
vector
<
std
::
string
>>
(
name
,
pb_desc
.
GetAttr
<
std
::
vector
<
std
::
string
>>
(
name
));
break
;
default:
LOG
(
FATAL
)
<<
"Unsupported attr type found "
<<
static_cast
<
int
>
(
type
);
}
};
for
(
const
auto
&
attr_name
:
pb_desc
.
AttrNames
())
{
auto
type
=
pb_desc
.
GetAttrType
(
attr_name
);
set_attr
(
attr_name
,
type
);
}
}
void
AttrsCppToPb
(
const
cpp
::
OpDesc
&
cpp_desc
,
pb
::
OpDesc
*
pb_desc
)
{
using
AttrType
=
OpDescAPI
::
AttrType
;
auto
set_attr
=
[
&
](
const
std
::
string
&
name
,
AttrType
type
)
{
switch
(
type
)
{
#define IMPL_ONE(type__, T) \
case AttrType::type__: \
pb_desc->SetAttr<T>(name, cpp_desc.GetAttr<T>(name)); \
break;
IMPL_ONE
(
INT
,
int32_t
);
IMPL_ONE
(
FLOAT
,
float
);
IMPL_ONE
(
STRING
,
std
::
string
);
IMPL_ONE
(
STRINGS
,
std
::
vector
<
std
::
string
>
);
IMPL_ONE
(
FLOATS
,
std
::
vector
<
float
>
);
IMPL_ONE
(
INTS
,
std
::
vector
<
int
>
);
IMPL_ONE
(
BOOLEAN
,
bool
);
default:
LOG
(
FATAL
)
<<
"Unsupported attr type found: "
<<
static_cast
<
int
>
(
type
);
}
};
#undef IMPL_ONE
for
(
const
auto
&
attr_name
:
cpp_desc
.
AttrNames
())
{
auto
type
=
cpp_desc
.
GetAttrType
(
attr_name
);
set_attr
(
attr_name
,
type
);
}
}
void
TransformOpDescPbToCpp
(
const
pb
::
OpDesc
&
pb_desc
,
cpp
::
OpDesc
*
cpp_desc
)
{
cpp_desc
->
SetType
(
pb_desc
.
Type
());
InputsPbToCpp
(
pb_desc
,
cpp_desc
);
OutputsPbToCpp
(
pb_desc
,
cpp_desc
);
AttrsPbToCpp
(
pb_desc
,
cpp_desc
);
}
void
TransformOpDescCppToPb
(
const
cpp
::
OpDesc
&
cpp_desc
,
pb
::
OpDesc
*
pb_desc
)
{
pb_desc
->
SetType
(
cpp_desc
.
Type
());
InputsCppToPb
(
cpp_desc
,
pb_desc
);
OutputsCppToPb
(
cpp_desc
,
pb_desc
);
AttrsCppToPb
(
cpp_desc
,
pb_desc
);
}
}
// namespace lite
}
// namespace paddle
paddle/fluid/lite/model_parser/compatible_pb.h
浏览文件 @
65be35af
...
@@ -20,39 +20,28 @@
...
@@ -20,39 +20,28 @@
* lite::pb::XXDesc.
* lite::pb::XXDesc.
*/
*/
#ifdef LITE_WITH_LIGHT_WEIGHT_FRAMEWORK
#include "paddle/fluid/lite/core/framework.pb.h"
#include "paddle/fluid/lite/core/framework.pb.h"
#include "paddle/fluid/lite/model_parser/cpp/op_desc.h"
#include "paddle/fluid/lite/model_parser/pb/op_desc.h"
#include "paddle/fluid/lite/model_parser/pb/op_desc.h"
#include "paddle/fluid/lite/model_parser/pb/var_desc.h"
#include "paddle/fluid/lite/model_parser/pb/var_desc.h"
#else
#include "paddle/fluid/framework/op_desc.h"
#include "paddle/fluid/framework/var_desc.h"
#endif // LITE_WITH_LIGHT_WEIGHT_FRAMEWORK
namespace
paddle
{
namespace
paddle
{
namespace
lite
{
namespace
lite
{
#ifdef LITE_WITH_LIGHT_WEIGHT_FRAMEWORK
using
Attribute
=
lite
::
pb
::
Attribute
;
using
Attribute
=
lite
::
pb
::
Attribute
;
using
OpDesc
=
lite
::
pb
::
OpDesc
;
using
OpDesc
=
lite
::
pb
::
OpDesc
;
using
VarDesc
=
lite
::
pb
::
VarDesc
;
using
VarDesc
=
lite
::
pb
::
VarDesc
;
#else // LITE_WITH_LIGHT_WEIGHT_FRAMEWORK
using
Attribute
=
framework
::
Attribute
;
using
OpDesc
=
framework
::
OpDesc
;
using
VarDesc
=
framework
::
VarDesc
;
#endif // LITE_WITH_LIGHT_WEIGHT_FRAMEWORK
#ifdef LITE_WITH_LIGHT_WEIGHT_FRAMEWORK
template
<
typename
T
>
template
<
typename
T
>
T
GetAttr
(
const
Attribute
&
x
)
{
T
GetAttr
(
const
Attribute
&
x
)
{
return
x
.
get
<
T
>
();
return
x
.
get
<
T
>
();
}
}
#else
template
<
typename
T
>
/// Transform an OpDesc from pb to cpp format.
T
GetAttr
(
const
Attribute
&
x
)
{
void
TransformOpDescPbToCpp
(
const
pb
::
OpDesc
&
pb_desc
,
cpp
::
OpDesc
*
cpp_desc
);
return
boost
::
get
<
T
>
(
x
);
}
/// Transform an OpDesc from cpp to pb format.
#endif // LITE_WITH_LIGHT_WEIGHT_FRAMEWORK
void
TransformOpDescCppToPb
(
const
cpp
::
OpDesc
&
cpp_desc
,
pb
::
OpDesc
*
pb_desc
);
}
// namespace lite
}
// namespace lite
}
// namespace paddle
}
// namespace paddle
paddle/fluid/lite/model_parser/cpp/CMakeLists.txt
0 → 100644
浏览文件 @
65be35af
cc_library
(
cpp_op_desc_lite SRCS op_desc.cc DEPS any_lite
)
paddle/fluid/lite/model_parser/cpp/op_desc.cc
0 → 100644
浏览文件 @
65be35af
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/lite/model_parser/cpp/op_desc.h"
#include <set>
namespace
paddle
{
namespace
lite
{
namespace
cpp
{
#define SET_ATTR_IMPL(T, repr__) \
template <> \
void OpDesc::SetAttr<T>(const std::string& name, const T& v) { \
attr_types_[name] = AttrType::repr__; \
attrs_[name].set<T>(v); \
}
SET_ATTR_IMPL
(
int32_t
,
INT
);
SET_ATTR_IMPL
(
float
,
FLOAT
);
SET_ATTR_IMPL
(
std
::
string
,
STRING
);
SET_ATTR_IMPL
(
bool
,
BOOLEAN
);
SET_ATTR_IMPL
(
std
::
vector
<
int
>
,
INTS
);
SET_ATTR_IMPL
(
std
::
vector
<
float
>
,
FLOATS
);
SET_ATTR_IMPL
(
std
::
vector
<
std
::
string
>
,
STRINGS
);
std
::
pair
<
OpDesc
::
attrs_t
::
const_iterator
,
OpDesc
::
attr_types_t
::
const_iterator
>
FindAttr
(
const
cpp
::
OpDesc
&
desc
,
const
std
::
string
&
name
)
{
auto
it
=
desc
.
attrs
().
find
(
name
);
CHECK
(
it
!=
desc
.
attrs
().
end
())
<<
"No attributes called "
<<
name
<<
" found"
;
auto
attr_it
=
desc
.
attr_types
().
find
(
name
);
CHECK
(
attr_it
!=
desc
.
attr_types
().
end
());
return
std
::
make_pair
(
it
,
attr_it
);
}
#define GET_IMPL_ONE(T, repr__) \
template <> \
T OpDesc::GetAttr<T>(const std::string& name) const { \
auto pair = FindAttr(*this, name); \
CHECK(pair.second->second == AttrType::repr__); \
return pair.first->second.get<T>(); \
}
GET_IMPL_ONE
(
int32_t
,
INT
);
GET_IMPL_ONE
(
float
,
FLOAT
);
GET_IMPL_ONE
(
std
::
string
,
STRING
);
GET_IMPL_ONE
(
bool
,
BOOLEAN
);
GET_IMPL_ONE
(
std
::
vector
<
int64_t
>
,
LONGS
);
GET_IMPL_ONE
(
std
::
vector
<
float
>
,
FLOATS
);
GET_IMPL_ONE
(
std
::
vector
<
int
>
,
INTS
);
GET_IMPL_ONE
(
std
::
vector
<
std
::
string
>
,
STRINGS
);
}
// namespace cpp
}
// namespace lite
}
// namespace paddle
paddle/fluid/lite/model_parser/cpp/op_desc.h
0 → 100644
浏览文件 @
65be35af
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <map>
#include <string>
#include <vector>
#include "paddle/fluid/lite/model_parser/desc_apis.h"
#include "paddle/fluid/lite/utils/any.h"
#include "paddle/fluid/lite/utils/varient.h"
namespace
paddle
{
namespace
lite
{
namespace
cpp
{
/*
* The cpp::OpDesc is the internal representation for Op. All the internal
* imprementation should use it, not the pb::OpDesc.
*/
class
OpDesc
:
public
OpDescAPI
{
public:
using
attrs_t
=
std
::
map
<
std
::
string
,
Any
>
;
using
attr_types_t
=
std
::
map
<
std
::
string
,
AttrType
>
;
protected:
std
::
string
type_
;
std
::
map
<
std
::
string
,
std
::
vector
<
std
::
string
>>
inputs_
;
std
::
map
<
std
::
string
,
std
::
vector
<
std
::
string
>>
outputs_
;
std
::
map
<
std
::
string
,
Any
>
attrs_
;
std
::
map
<
std
::
string
,
AttrType
>
attr_types_
;
public:
OpDesc
()
=
default
;
std
::
string
Type
()
const
override
{
return
type_
;
}
void
SetType
(
const
std
::
string
&
x
)
override
{
type_
=
x
;
}
const
std
::
map
<
std
::
string
,
std
::
vector
<
std
::
string
>>&
inputs
()
const
{
return
inputs_
;
}
const
std
::
map
<
std
::
string
,
std
::
vector
<
std
::
string
>>&
outputs
()
const
{
return
outputs_
;
}
std
::
map
<
std
::
string
,
std
::
vector
<
std
::
string
>>*
mutable_inputs
()
{
return
&
inputs_
;
}
std
::
map
<
std
::
string
,
std
::
vector
<
std
::
string
>>*
mutable_outputs
()
{
return
&
outputs_
;
}
std
::
vector
<
std
::
string
>
Input
(
const
std
::
string
&
param
)
const
override
{
auto
it
=
inputs_
.
find
(
param
);
CHECK
(
it
!=
inputs_
.
end
());
return
it
->
second
;
}
std
::
vector
<
std
::
string
>
InputArgumentNames
()
const
override
{
std
::
vector
<
std
::
string
>
res
;
for
(
const
auto
&
x
:
inputs_
)
res
.
push_back
(
x
.
first
);
return
res
;
}
std
::
vector
<
std
::
string
>
OutputArgumentNames
()
const
override
{
std
::
vector
<
std
::
string
>
res
;
for
(
const
auto
&
x
:
outputs_
)
res
.
push_back
(
x
.
first
);
return
res
;
}
std
::
vector
<
std
::
string
>
Output
(
const
std
::
string
&
param
)
const
override
{
auto
it
=
outputs_
.
find
(
param
);
CHECK
(
it
!=
outputs_
.
end
());
return
it
->
second
;
}
void
SetInput
(
const
std
::
string
&
param
,
const
std
::
vector
<
std
::
string
>&
args
)
override
{
inputs_
[
param
]
=
args
;
}
void
SetOutput
(
const
std
::
string
&
param
,
const
std
::
vector
<
std
::
string
>&
args
)
override
{
outputs_
[
param
]
=
args
;
}
bool
HasAttr
(
const
std
::
string
&
name
)
const
override
{
return
attrs_
.
count
(
name
);
}
AttrType
GetAttrType
(
const
std
::
string
&
name
)
const
override
{
auto
it
=
attr_types_
.
find
(
name
);
CHECK
(
it
!=
attr_types_
.
end
());
return
it
->
second
;
}
std
::
vector
<
std
::
string
>
AttrNames
()
const
override
{
std
::
vector
<
std
::
string
>
res
;
for
(
const
auto
&
x
:
attrs_
)
{
res
.
push_back
(
x
.
first
);
}
return
res
;
}
template
<
typename
T
>
void
SetAttr
(
const
std
::
string
&
name
,
const
T
&
v
);
template
<
typename
T
>
T
GetAttr
(
const
std
::
string
&
name
)
const
;
const
std
::
map
<
std
::
string
,
Any
>&
attrs
()
const
{
return
attrs_
;
}
const
std
::
map
<
std
::
string
,
AttrType
>&
attr_types
()
const
{
return
attr_types_
;
}
};
}
// namespace cpp
}
// namespace lite
}
// namespace paddle
paddle/fluid/lite/model_parser/desc_apis.h
0 → 100644
浏览文件 @
65be35af
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <map>
#include <string>
#include <vector>
namespace
paddle
{
namespace
lite
{
/*
* Compatible interfaces for all the different kinds of opdesc. All the OpDesc
* classes should implement this.
* NOTE Some interfaces are weried, we remain them unchanged to keep compatible
* with framework::OpDesc in Fluid framework.
*/
class
OpDescAPI
{
public:
// The AttrType is used to make the proto::AttrType portable.
enum
class
AttrType
{
INT
=
0
,
FLOAT
=
1
,
STRING
=
2
,
INTS
=
3
,
FLOATS
=
4
,
STRINGS
=
5
,
BOOLEAN
=
6
,
BOOLEANS
=
7
,
BLOCK
=
8
,
LONG
=
9
,
BLOCKS
=
10
,
LONGS
=
11
,
UNK
,
};
virtual
~
OpDescAPI
()
=
default
;
/// Get operator's type.
virtual
std
::
string
Type
()
const
=
0
;
/// Set operator's type.
virtual
void
SetType
(
const
std
::
string
&
type
)
=
0
;
/// Get arguments given the parameter.
virtual
std
::
vector
<
std
::
string
>
Input
(
const
std
::
string
&
param
)
const
=
0
;
/// Get parameters.
virtual
std
::
vector
<
std
::
string
>
InputArgumentNames
()
const
=
0
;
/// Get arguments given the parameter.
virtual
std
::
vector
<
std
::
string
>
Output
(
const
std
::
string
&
param
)
const
=
0
;
/// Get parameters.
virtual
std
::
vector
<
std
::
string
>
OutputArgumentNames
()
const
=
0
;
/// Set a input given the parameter and arguments.
virtual
void
SetInput
(
const
std
::
string
&
param
,
const
std
::
vector
<
std
::
string
>&
args
)
=
0
;
virtual
void
SetOutput
(
const
std
::
string
&
param
,
const
std
::
vector
<
std
::
string
>&
args
)
=
0
;
/// Tell whether this desc has an attribute.
virtual
bool
HasAttr
(
const
std
::
string
&
name
)
const
=
0
;
/// Get the type of an attribute.
virtual
AttrType
GetAttrType
(
const
std
::
string
&
name
)
const
=
0
;
virtual
std
::
vector
<
std
::
string
>
AttrNames
()
const
=
0
;
/// Set an attribute.
template
<
typename
T
>
void
SetAttr
(
const
std
::
string
&
name
,
const
T
&
v
);
/// Get an attribute.
template
<
typename
T
>
T
GetAttr
(
const
std
::
string
&
name
)
const
;
};
}
// namespace lite
}
// namespace paddle
paddle/fluid/lite/model_parser/op_desc_test.cc
0 → 100644
浏览文件 @
65be35af
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/lite/model_parser/cpp/op_desc.h"
#include <gtest/gtest.h>
#include "paddle/fluid/lite/model_parser/compatible_pb.h"
#include "paddle/fluid/lite/model_parser/pb/op_desc.h"
namespace
paddle
{
namespace
lite
{
template
<
typename
OpDesc
>
void
TestX
()
{
OpDesc
desc
;
desc
.
SetInput
(
"X"
,
{
"a"
,
"b"
});
auto
X
=
desc
.
Input
(
"X"
);
ASSERT_EQ
(
X
.
size
(),
2UL
);
ASSERT_EQ
(
X
[
0
],
"a"
);
ASSERT_EQ
(
X
[
1
],
"b"
);
desc
.
SetOutput
(
"Y"
,
{
"c"
,
"d"
});
auto
Y
=
desc
.
Output
(
"Y"
);
ASSERT_EQ
(
Y
.
size
(),
2UL
);
ASSERT_EQ
(
Y
[
0
],
"c"
);
ASSERT_EQ
(
Y
[
1
],
"d"
);
desc
.
template
SetAttr
<
int32_t
>(
"aint"
,
100
);
ASSERT_TRUE
(
desc
.
HasAttr
(
"aint"
));
ASSERT_FALSE
(
desc
.
HasAttr
(
"afloat"
));
ASSERT_EQ
(
desc
.
template
GetAttr
<
int32_t
>(
"aint"
),
100
);
}
TEST
(
OpDesc
,
Basic
)
{
TestX
<
pb
::
OpDesc
>
();
TestX
<
cpp
::
OpDesc
>
();
}
TEST
(
OpDesc
,
CppToPb
)
{
cpp
::
OpDesc
desc
;
desc
.
SetInput
(
"X"
,
{
"a"
,
"b"
});
desc
.
SetOutput
(
"Y"
,
{
"c"
,
"d"
});
desc
.
template
SetAttr
<
int32_t
>(
"aint"
,
100
);
pb
::
OpDesc
pb_desc
;
TransformOpDescCppToPb
(
desc
,
&
pb_desc
);
{
auto
&
desc
=
pb_desc
;
auto
X
=
desc
.
Input
(
"X"
);
ASSERT_EQ
(
X
.
size
(),
2UL
);
ASSERT_EQ
(
X
[
0
],
"a"
);
ASSERT_EQ
(
X
[
1
],
"b"
);
auto
Y
=
desc
.
Output
(
"Y"
);
ASSERT_EQ
(
Y
.
size
(),
2UL
);
ASSERT_EQ
(
Y
[
0
],
"c"
);
ASSERT_EQ
(
Y
[
1
],
"d"
);
ASSERT_TRUE
(
desc
.
HasAttr
(
"aint"
));
ASSERT_FALSE
(
desc
.
HasAttr
(
"afloat"
));
ASSERT_EQ
(
desc
.
template
GetAttr
<
int32_t
>(
"aint"
),
100
);
}
}
TEST
(
OpDesc
,
PbToCpp
)
{
pb
::
OpDesc
desc
;
desc
.
SetInput
(
"X"
,
{
"a"
,
"b"
});
desc
.
SetOutput
(
"Y"
,
{
"c"
,
"d"
});
desc
.
template
SetAttr
<
int32_t
>(
"aint"
,
100
);
cpp
::
OpDesc
cpp_desc
;
TransformOpDescPbToCpp
(
desc
,
&
cpp_desc
);
{
auto
&
desc
=
cpp_desc
;
auto
X
=
desc
.
Input
(
"X"
);
ASSERT_EQ
(
X
.
size
(),
2UL
);
ASSERT_EQ
(
X
[
0
],
"a"
);
ASSERT_EQ
(
X
[
1
],
"b"
);
auto
Y
=
desc
.
Output
(
"Y"
);
ASSERT_EQ
(
Y
.
size
(),
2UL
);
ASSERT_EQ
(
Y
[
0
],
"c"
);
ASSERT_EQ
(
Y
[
1
],
"d"
);
ASSERT_TRUE
(
desc
.
HasAttr
(
"aint"
));
ASSERT_FALSE
(
desc
.
HasAttr
(
"afloat"
));
ASSERT_EQ
(
desc
.
template
GetAttr
<
int32_t
>(
"aint"
),
100
);
}
}
}
// namespace lite
}
// namespace paddle
paddle/fluid/lite/model_parser/pb/op_desc.cc
浏览文件 @
65be35af
...
@@ -18,10 +18,9 @@ namespace paddle {
...
@@ -18,10 +18,9 @@ namespace paddle {
namespace
lite
{
namespace
lite
{
namespace
pb
{
namespace
pb
{
template
<
>
google
::
protobuf
::
internal
::
RepeatedPtrIterator
<
framework
::
proto
::
OpDesc_Attr
>
void
OpDesc
::
SetAttr
<
std
::
string
>
(
const
std
::
string
&
name
,
FindAttr
(
framework
::
proto
::
OpDesc
*
desc
,
const
std
::
string
&
name
)
{
const
std
::
string
&
v
)
{
auto
&
xs
=
*
desc
->
mutable_attrs
();
auto
&
xs
=
*
desc_
.
mutable_attrs
();
auto
it
=
std
::
find_if
(
auto
it
=
std
::
find_if
(
xs
.
begin
(),
xs
.
end
(),
xs
.
begin
(),
xs
.
end
(),
[
&
](
const
framework
::
proto
::
OpDesc_Attr
&
x
)
{
return
x
.
name
()
==
name
;
});
[
&
](
const
framework
::
proto
::
OpDesc_Attr
&
x
)
{
return
x
.
name
()
==
name
;
});
...
@@ -33,33 +32,95 @@ void OpDesc::SetAttr<std::string>(const std::string &name,
...
@@ -33,33 +32,95 @@ void OpDesc::SetAttr<std::string>(const std::string &name,
return
x
.
name
()
==
name
;
return
x
.
name
()
==
name
;
});
});
}
}
return
it
;
}
#define SET_IMPL_ONE(T, ty__, pb_f__) \
template <> \
void OpDesc::SetAttr<T>(const std::string &name, const T &v) { \
auto it = FindAttr(&desc_, name); \
it->set_type(framework::proto::ty__); \
it->set_##pb_f__(v); \
}
SET_IMPL_ONE
(
int
,
INT
,
i
);
SET_IMPL_ONE
(
float
,
FLOAT
,
f
);
SET_IMPL_ONE
(
bool
,
FLOAT
,
f
);
template
<
>
void
OpDesc
::
SetAttr
<
std
::
vector
<
int
>>
(
const
std
::
string
&
name
,
const
std
::
vector
<
int
>
&
v
)
{
auto
it
=
FindAttr
(
&
desc_
,
name
);
it
->
set_type
(
framework
::
proto
::
INTS
);
it
->
clear_ints
();
for
(
auto
&
i
:
v
)
{
it
->
add_ints
(
i
);
}
}
template
<
>
void
OpDesc
::
SetAttr
<
std
::
string
>
(
const
std
::
string
&
name
,
const
std
::
string
&
v
)
{
auto
it
=
FindAttr
(
&
desc_
,
name
);
it
->
set_type
(
framework
::
proto
::
STRING
);
it
->
set_type
(
framework
::
proto
::
STRING
);
it
->
set_s
(
v
.
c_str
());
it
->
set_s
(
v
.
c_str
());
}
}
template
<
>
template
<
>
void
OpDesc
::
SetAttr
<
std
::
vector
<
int
>>
(
const
std
::
string
&
name
,
void
OpDesc
::
SetAttr
<
std
::
vector
<
float
>>
(
const
std
::
string
&
name
,
const
std
::
vector
<
int
>
&
v
)
{
const
std
::
vector
<
float
>
&
v
)
{
auto
&
xs
=
*
desc_
.
mutable_attrs
();
auto
it
=
FindAttr
(
&
desc_
,
name
);
it
->
set_type
(
framework
::
proto
::
FLOATS
);
it
->
clear_floats
();
for
(
auto
&
i
:
v
)
{
it
->
add_floats
(
i
);
}
}
template
<
>
void
OpDesc
::
SetAttr
<
std
::
vector
<
std
::
string
>>
(
const
std
::
string
&
name
,
const
std
::
vector
<
std
::
string
>
&
v
)
{
auto
it
=
FindAttr
(
&
desc_
,
name
);
it
->
set_type
(
framework
::
proto
::
STRINGS
);
it
->
clear_strings
();
for
(
auto
&
i
:
v
)
{
it
->
add_strings
(
i
);
}
}
google
::
protobuf
::
internal
::
RepeatedPtrIterator
<
const
framework
::
proto
::
OpDesc_Attr
>
GetFindAttr
(
const
framework
::
proto
::
OpDesc
&
desc
,
const
std
::
string
&
name
)
{
auto
&
xs
=
desc
.
attrs
();
auto
it
=
std
::
find_if
(
auto
it
=
std
::
find_if
(
xs
.
begin
(),
xs
.
end
(),
xs
.
begin
(),
xs
.
end
(),
[
&
](
const
framework
::
proto
::
OpDesc_Attr
&
x
)
{
return
x
.
name
()
==
name
;
});
[
&
](
const
framework
::
proto
::
OpDesc_Attr
&
x
)
{
return
x
.
name
()
==
name
;
});
if
(
it
==
xs
.
end
())
{
return
it
;
auto
*
attr
=
xs
.
Add
();
}
attr
->
set_name
(
name
);
it
=
std
::
find_if
(
xs
.
begin
(),
xs
.
end
(),
#define GET_ATTR_IMPL(T, pb_f__) \
[
&
](
const
framework
::
proto
::
OpDesc_Attr
&
x
)
{
template <> \
return
x
.
name
()
==
name
;
T OpDesc::GetAttr<T>(const std::string &name) const { \
});
auto it = GetFindAttr(desc_, name); \
return it->pb_f__(); \
}
}
it
->
set_type
(
framework
::
proto
::
INTS
);
#define GET_ATTRS_IMPL(T, pb_f__) \
it
->
clear_ints
();
template <> \
for
(
auto
&
i
:
v
)
{
T OpDesc::GetAttr<T>(const std::string &name) const { \
it
->
add_ints
(
i
);
auto it = GetFindAttr(desc_, name); \
T res; \
for (const auto &v : it->pb_f__()) { \
res.push_back(v); \
} \
return res; \
}
}
}
GET_ATTR_IMPL
(
int32_t
,
i
);
GET_ATTR_IMPL
(
float
,
f
);
GET_ATTR_IMPL
(
bool
,
b
);
GET_ATTRS_IMPL
(
std
::
vector
<
int
>
,
ints
);
GET_ATTRS_IMPL
(
std
::
vector
<
float
>
,
floats
);
GET_ATTRS_IMPL
(
std
::
vector
<
std
::
string
>
,
strings
);
GET_ATTR_IMPL
(
std
::
string
,
s
);
}
// namespace pb
}
// namespace pb
}
// namespace lite
}
// namespace lite
...
...
paddle/fluid/lite/model_parser/pb/op_desc.h
浏览文件 @
65be35af
...
@@ -27,6 +27,7 @@
...
@@ -27,6 +27,7 @@
#include <unordered_map>
#include <unordered_map>
#include <vector>
#include <vector>
#include "paddle/fluid/lite/core/framework.pb.h"
#include "paddle/fluid/lite/core/framework.pb.h"
#include "paddle/fluid/lite/model_parser/desc_apis.h"
#include "paddle/fluid/lite/utils/all.h"
#include "paddle/fluid/lite/utils/all.h"
namespace
paddle
{
namespace
paddle
{
...
@@ -43,7 +44,7 @@ using VariableNameMap = std::map<std::string, std::vector<std::string>>;
...
@@ -43,7 +44,7 @@ using VariableNameMap = std::map<std::string, std::vector<std::string>>;
* except the desc_, to avoid the inconsistent state, which is normal in the
* except the desc_, to avoid the inconsistent state, which is normal in the
* original interface and results in bugs.
* original interface and results in bugs.
*/
*/
class
OpDesc
{
class
OpDesc
:
public
OpDescAPI
{
public:
public:
OpDesc
()
{}
OpDesc
()
{}
...
@@ -54,38 +55,38 @@ class OpDesc {
...
@@ -54,38 +55,38 @@ class OpDesc {
framework
::
proto
::
OpDesc
*
Proto
()
{
return
&
desc_
;
}
framework
::
proto
::
OpDesc
*
Proto
()
{
return
&
desc_
;
}
const
framework
::
proto
::
OpDesc
&
ReadonlyProto
()
const
{
return
desc_
;
}
const
framework
::
proto
::
OpDesc
&
ReadonlyProto
()
const
{
return
desc_
;
}
std
::
string
Type
()
const
{
return
desc_
.
type
();
}
std
::
string
Type
()
const
override
{
return
desc_
.
type
();
}
void
SetType
(
const
std
::
string
&
type
)
{
desc_
.
set_type
(
type
);
}
void
SetType
(
const
std
::
string
&
type
)
override
{
desc_
.
set_type
(
type
);
}
// Get the arguments of parameter called `param`
// Get the arguments of parameter called `param`
std
::
vector
<
std
::
string
>
Input
(
const
std
::
string
&
param
)
const
{
std
::
vector
<
std
::
string
>
Input
(
const
std
::
string
&
param
)
const
override
{
return
GetArguments
(
desc_
.
inputs
(),
param
);
return
GetArguments
(
desc_
.
inputs
(),
param
);
}
}
std
::
vector
<
std
::
string
>
InputArgumentNames
()
const
{
std
::
vector
<
std
::
string
>
InputArgumentNames
()
const
override
{
return
GetArgumentNames
(
desc_
.
inputs
());
return
GetArgumentNames
(
desc_
.
inputs
());
}
}
void
SetInput
(
const
std
::
string
&
param
,
void
SetInput
(
const
std
::
string
&
param
,
const
std
::
vector
<
std
::
string
>
&
args
)
{
const
std
::
vector
<
std
::
string
>
&
args
)
override
{
SetArgument
(
desc_
.
mutable_inputs
(),
param
,
args
);
SetArgument
(
desc_
.
mutable_inputs
(),
param
,
args
);
}
}
std
::
vector
<
std
::
string
>
Output
(
const
std
::
string
&
param
)
const
{
std
::
vector
<
std
::
string
>
Output
(
const
std
::
string
&
param
)
const
override
{
return
GetArguments
(
desc_
.
outputs
(),
param
);
return
GetArguments
(
desc_
.
outputs
(),
param
);
}
}
std
::
vector
<
std
::
string
>
OutputArgumentNames
()
const
{
std
::
vector
<
std
::
string
>
OutputArgumentNames
()
const
override
{
return
GetArgumentNames
(
desc_
.
outputs
());
return
GetArgumentNames
(
desc_
.
outputs
());
}
}
void
SetOutput
(
const
std
::
string
&
param
,
void
SetOutput
(
const
std
::
string
&
param
,
const
std
::
vector
<
std
::
string
>
&
args
)
{
const
std
::
vector
<
std
::
string
>
&
args
)
override
{
SetArgument
(
desc_
.
mutable_outputs
(),
param
,
args
);
SetArgument
(
desc_
.
mutable_outputs
(),
param
,
args
);
}
}
bool
HasAttr
(
const
std
::
string
&
name
)
const
{
bool
HasAttr
(
const
std
::
string
&
name
)
const
override
{
const
auto
&
xs
=
desc_
.
attrs
();
const
auto
&
xs
=
desc_
.
attrs
();
auto
it
=
std
::
find_if
(
xs
.
begin
(),
xs
.
end
(),
auto
it
=
std
::
find_if
(
xs
.
begin
(),
xs
.
end
(),
[
&
](
const
framework
::
proto
::
OpDesc_Attr
&
x
)
{
[
&
](
const
framework
::
proto
::
OpDesc_Attr
&
x
)
{
...
@@ -94,17 +95,38 @@ class OpDesc {
...
@@ -94,17 +95,38 @@ class OpDesc {
return
it
!=
xs
.
end
();
return
it
!=
xs
.
end
();
}
}
framework
::
proto
::
AttrType
GetAttrType
(
const
std
::
string
&
name
)
const
{
AttrType
GetAttrType
(
const
std
::
string
&
name
)
const
override
{
const
auto
&
xs
=
desc_
.
attrs
();
const
auto
&
xs
=
desc_
.
attrs
();
auto
it
=
std
::
find_if
(
xs
.
begin
(),
xs
.
end
(),
auto
it
=
std
::
find_if
(
xs
.
begin
(),
xs
.
end
(),
[
&
](
const
framework
::
proto
::
OpDesc_Attr
&
x
)
{
[
&
](
const
framework
::
proto
::
OpDesc_Attr
&
x
)
{
return
x
.
name
()
==
name
;
return
x
.
name
()
==
name
;
});
});
CHECK
(
it
!=
xs
.
end
());
CHECK
(
it
!=
xs
.
end
());
return
it
->
type
();
#define DEF_ONE(type__) \
case framework::proto::AttrType::type__: \
return AttrType::type__;
switch
(
it
->
type
())
{
DEF_ONE
(
INT
);
DEF_ONE
(
FLOAT
);
DEF_ONE
(
STRING
);
DEF_ONE
(
INTS
);
DEF_ONE
(
FLOATS
);
DEF_ONE
(
STRINGS
);
DEF_ONE
(
BOOLEAN
);
DEF_ONE
(
BOOLEANS
);
DEF_ONE
(
BLOCK
);
DEF_ONE
(
LONG
);
DEF_ONE
(
BLOCKS
);
DEF_ONE
(
LONGS
);
default:
LOG
(
ERROR
)
<<
"Unknown attribute type"
;
return
AttrType
::
UNK
;
}
#undef DEF_ONE
}
}
std
::
vector
<
std
::
string
>
AttrNames
()
const
{
std
::
vector
<
std
::
string
>
AttrNames
()
const
override
{
std
::
vector
<
std
::
string
>
res
;
std
::
vector
<
std
::
string
>
res
;
const
auto
&
xs
=
desc_
.
attrs
();
const
auto
&
xs
=
desc_
.
attrs
();
std
::
transform
(
std
::
transform
(
...
@@ -114,72 +136,10 @@ class OpDesc {
...
@@ -114,72 +136,10 @@ class OpDesc {
}
}
template
<
typename
T
>
template
<
typename
T
>
void
SetAttr
(
const
std
::
string
&
name
,
const
T
&
v
)
{
void
SetAttr
(
const
std
::
string
&
name
,
const
T
&
v
);
auto
&
xs
=
*
desc_
.
mutable_attrs
();
auto
it
=
std
::
find_if
(
xs
.
begin
(),
xs
.
end
(),
[
&
](
const
framework
::
proto
::
OpDesc_Attr
&
x
)
{
return
x
.
name
()
==
name
;
});
if
(
it
==
xs
.
end
())
{
auto
*
attr
=
xs
.
Add
();
attr
->
set_name
(
name
);
it
=
std
::
find_if
(
xs
.
begin
(),
xs
.
end
(),
[
&
](
const
framework
::
proto
::
OpDesc_Attr
&
x
)
{
return
x
.
name
()
==
name
;
});
}
size_t
hash
=
typeid
(
T
).
hash_code
();
template
<
typename
T
>
if
(
hash
==
typeid
(
int
).
hash_code
())
{
// NOLINT
T
GetAttr
(
const
std
::
string
&
name
)
const
;
it
->
set_type
(
framework
::
proto
::
INT
);
it
->
set_i
(
v
);
}
else
if
(
hash
==
typeid
(
float
).
hash_code
())
{
// NOLINT
it
->
set_type
(
framework
::
proto
::
FLOAT
);
it
->
set_f
(
v
);
}
else
if
(
hash
==
typeid
(
bool
).
hash_code
())
{
// NOLINT
it
->
set_type
(
framework
::
proto
::
BOOLEAN
);
it
->
set_b
(
v
);
}
else
{
LOG
(
FATAL
)
<<
"unsupport attr type"
;
}
}
Attribute
GetAttr
(
const
std
::
string
&
name
)
const
{
auto
&
xs
=
desc_
.
attrs
();
auto
it
=
std
::
find_if
(
xs
.
begin
(),
xs
.
end
(),
[
&
](
const
framework
::
proto
::
OpDesc_Attr
&
x
)
{
return
x
.
name
()
==
name
;
});
Attribute
res
;
CHECK
(
it
!=
xs
.
end
());
switch
(
it
->
type
())
{
case
framework
::
proto
::
INT
:
res
.
set
<
int
>
(
it
->
i
());
break
;
case
framework
::
proto
::
FLOAT
:
res
.
set
<
float
>
(
it
->
f
());
break
;
case
framework
::
proto
::
STRING
:
res
.
set
<
std
::
string
>
(
it
->
s
());
break
;
case
framework
::
proto
::
BOOLEAN
:
res
.
set
<
bool
>
(
it
->
b
());
break
;
case
framework
::
proto
::
INTS
:
{
std
::
vector
<
int
>
values
;
const
auto
&
ys
=
it
->
ints
();
std
::
transform
(
ys
.
begin
(),
ys
.
end
(),
std
::
back_inserter
(
values
),
[](
const
int
&
x
)
{
return
x
;
});
res
.
set
<
std
::
vector
<
int
>>
(
values
);
}
break
;
default:
LOG
(
FATAL
)
<<
"unsupported attr type"
;
}
return
res
;
}
private:
private:
std
::
vector
<
std
::
string
>
GetArguments
(
std
::
vector
<
std
::
string
>
GetArguments
(
...
...
paddle/fluid/lite/operators/activation_ops.cc
浏览文件 @
65be35af
...
@@ -33,7 +33,7 @@ class ActivationOp : public OpLite {
...
@@ -33,7 +33,7 @@ class ActivationOp : public OpLite {
return
true
;
return
true
;
}
}
bool
AttachImpl
(
const
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
override
{
bool
AttachImpl
(
const
cpp
::
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
override
{
auto
X_name
=
opdesc
.
Input
(
"X"
).
front
();
auto
X_name
=
opdesc
.
Input
(
"X"
).
front
();
auto
Out_name
=
opdesc
.
Output
(
"Out"
).
front
();
auto
Out_name
=
opdesc
.
Output
(
"Out"
).
front
();
...
@@ -66,7 +66,7 @@ class ActivationGradOp : public OpLite {
...
@@ -66,7 +66,7 @@ class ActivationGradOp : public OpLite {
return
true
;
return
true
;
}
}
bool
AttachImpl
(
const
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
override
{
bool
AttachImpl
(
const
cpp
::
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
override
{
auto
Out_grad_name
=
opdesc
.
Input
(
framework
::
GradVarName
(
"Out"
)).
front
();
auto
Out_grad_name
=
opdesc
.
Input
(
framework
::
GradVarName
(
"Out"
)).
front
();
auto
X_grad_name
=
opdesc
.
Output
(
framework
::
GradVarName
(
"X"
)).
front
();
auto
X_grad_name
=
opdesc
.
Output
(
framework
::
GradVarName
(
"X"
)).
front
();
...
...
paddle/fluid/lite/operators/concat_op.cc
浏览文件 @
65be35af
...
@@ -54,7 +54,7 @@ bool ConcatOpLite::InferShape() const {
...
@@ -54,7 +54,7 @@ bool ConcatOpLite::InferShape() const {
}
}
// TODO(Superjomn) replace framework::OpDesc with a lite one.
// TODO(Superjomn) replace framework::OpDesc with a lite one.
bool
ConcatOpLite
::
AttachImpl
(
const
OpDesc
&
op_desc
,
lite
::
Scope
*
scope
)
{
bool
ConcatOpLite
::
AttachImpl
(
const
cpp
::
OpDesc
&
op_desc
,
lite
::
Scope
*
scope
)
{
auto
inputs
=
op_desc
.
Input
(
"X"
);
auto
inputs
=
op_desc
.
Input
(
"X"
);
auto
out
=
op_desc
.
Output
(
"Out"
).
front
();
auto
out
=
op_desc
.
Output
(
"Out"
).
front
();
...
@@ -63,7 +63,7 @@ bool ConcatOpLite::AttachImpl(const OpDesc &op_desc, lite::Scope *scope) {
...
@@ -63,7 +63,7 @@ bool ConcatOpLite::AttachImpl(const OpDesc &op_desc, lite::Scope *scope) {
}
}
CHECK
(
scope
->
FindVar
(
out
));
CHECK
(
scope
->
FindVar
(
out
));
param_
.
output
=
scope
->
FindVar
(
out
)
->
GetMutable
<
lite
::
Tensor
>
();
param_
.
output
=
scope
->
FindVar
(
out
)
->
GetMutable
<
lite
::
Tensor
>
();
param_
.
axis
=
GetAttr
<
int
>
(
op_desc
.
GetAttr
(
"axis"
)
);
param_
.
axis
=
op_desc
.
GetAttr
<
int
>
(
"axis"
);
return
true
;
return
true
;
}
}
...
...
paddle/fluid/lite/operators/concat_op.h
浏览文件 @
65be35af
...
@@ -32,7 +32,7 @@ class ConcatOpLite : public OpLite {
...
@@ -32,7 +32,7 @@ class ConcatOpLite : public OpLite {
bool
InferShape
()
const
override
;
bool
InferShape
()
const
override
;
bool
AttachImpl
(
const
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
override
;
bool
AttachImpl
(
const
cpp
::
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
override
;
void
AttachKernel
(
KernelBase
*
kernel
)
override
{
kernel
->
SetParam
(
param_
);
}
void
AttachKernel
(
KernelBase
*
kernel
)
override
{
kernel
->
SetParam
(
param_
);
}
std
::
string
DebugString
()
const
override
{
return
"concat"
;
}
std
::
string
DebugString
()
const
override
{
return
"concat"
;
}
...
...
paddle/fluid/lite/operators/concat_op_test.cc
浏览文件 @
65be35af
...
@@ -42,7 +42,7 @@ TEST(concat_op_lite, test) {
...
@@ -42,7 +42,7 @@ TEST(concat_op_lite, test) {
}
}
// prepare op desc
// prepare op desc
lite
::
OpDesc
desc
;
cpp
::
OpDesc
desc
;
desc
.
SetType
(
"concat"
);
desc
.
SetType
(
"concat"
);
desc
.
SetInput
(
"X"
,
{
"x0"
,
"x1"
});
desc
.
SetInput
(
"X"
,
{
"x0"
,
"x1"
});
desc
.
SetOutput
(
"Out"
,
{
"output"
});
desc
.
SetOutput
(
"Out"
,
{
"output"
});
...
...
paddle/fluid/lite/operators/dropout_op.cc
浏览文件 @
65be35af
...
@@ -42,7 +42,7 @@ class DropoutOpLite : public OpLite {
...
@@ -42,7 +42,7 @@ class DropoutOpLite : public OpLite {
void
AttachKernel
(
KernelBase
*
kernel
)
override
{
kernel
->
SetParam
(
param_
);
}
void
AttachKernel
(
KernelBase
*
kernel
)
override
{
kernel
->
SetParam
(
param_
);
}
// TODO(Superjomn) replace framework::OpDesc with a lite one.
// TODO(Superjomn) replace framework::OpDesc with a lite one.
bool
AttachImpl
(
const
OpDesc
&
op_desc
,
lite
::
Scope
*
scope
)
override
{
bool
AttachImpl
(
const
cpp
::
OpDesc
&
op_desc
,
lite
::
Scope
*
scope
)
override
{
auto
input
=
op_desc
.
Input
(
"X"
).
front
();
auto
input
=
op_desc
.
Input
(
"X"
).
front
();
auto
out
=
op_desc
.
Output
(
"Out"
).
front
();
auto
out
=
op_desc
.
Output
(
"Out"
).
front
();
auto
Mask
=
op_desc
.
Output
(
"Mask"
).
front
();
auto
Mask
=
op_desc
.
Output
(
"Mask"
).
front
();
...
@@ -51,14 +51,14 @@ class DropoutOpLite : public OpLite {
...
@@ -51,14 +51,14 @@ class DropoutOpLite : public OpLite {
param_
.
output
=
GetMutableVar
<
lite
::
Tensor
>
(
scope
,
out
);
param_
.
output
=
GetMutableVar
<
lite
::
Tensor
>
(
scope
,
out
);
param_
.
mask
=
GetMutableVar
<
lite
::
Tensor
>
(
scope
,
Mask
);
param_
.
mask
=
GetMutableVar
<
lite
::
Tensor
>
(
scope
,
Mask
);
param_
.
dropout_prob
=
boost
::
get
<
float
>
(
op_desc
.
GetAttr
(
"dropout_prob"
)
);
param_
.
dropout_prob
=
op_desc
.
GetAttr
<
float
>
(
"dropout_prob"
);
if
(
op_desc
.
HasAttr
(
"axis"
))
{
if
(
op_desc
.
HasAttr
(
"axis"
))
{
param_
.
is_test
=
boost
::
get
<
bool
>
(
op_desc
.
GetAttr
(
"is_test"
)
);
param_
.
is_test
=
op_desc
.
GetAttr
<
bool
>
(
"is_test"
);
}
}
param_
.
fix_seed
=
boost
::
get
<
bool
>
(
op_desc
.
GetAttr
(
"fix_seed"
)
);
param_
.
fix_seed
=
op_desc
.
GetAttr
<
bool
>
(
"fix_seed"
);
param_
.
seed
=
boost
::
get
<
int
>
(
op_desc
.
GetAttr
(
"seed"
)
);
param_
.
seed
=
op_desc
.
GetAttr
<
int
>
(
"seed"
);
param_
.
dropout_implementation
=
param_
.
dropout_implementation
=
boost
::
get
<
int
>
(
op_desc
.
GetAttr
(
"dropout_implementation"
)
);
op_desc
.
GetAttr
<
int
>
(
"dropout_implementation"
);
return
true
;
return
true
;
}
}
...
...
paddle/fluid/lite/operators/elementwise_ops.cc
浏览文件 @
65be35af
...
@@ -36,7 +36,7 @@ class ElementwiseOp : public OpLite {
...
@@ -36,7 +36,7 @@ class ElementwiseOp : public OpLite {
return
true
;
return
true
;
}
}
bool
AttachImpl
(
const
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
override
{
bool
AttachImpl
(
const
cpp
::
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
override
{
auto
X_name
=
opdesc
.
Input
(
"X"
).
front
();
auto
X_name
=
opdesc
.
Input
(
"X"
).
front
();
auto
Y_name
=
opdesc
.
Input
(
"Y"
).
front
();
auto
Y_name
=
opdesc
.
Input
(
"Y"
).
front
();
auto
Out_name
=
opdesc
.
Output
(
"Out"
).
front
();
auto
Out_name
=
opdesc
.
Output
(
"Out"
).
front
();
...
@@ -44,7 +44,7 @@ class ElementwiseOp : public OpLite {
...
@@ -44,7 +44,7 @@ class ElementwiseOp : public OpLite {
param_
.
X
=
GetVar
<
lite
::
Tensor
>
(
scope
,
X_name
);
param_
.
X
=
GetVar
<
lite
::
Tensor
>
(
scope
,
X_name
);
param_
.
Y
=
GetVar
<
lite
::
Tensor
>
(
scope
,
Y_name
);
param_
.
Y
=
GetVar
<
lite
::
Tensor
>
(
scope
,
Y_name
);
param_
.
Out
=
GetMutableVar
<
lite
::
Tensor
>
(
scope
,
Out_name
);
param_
.
Out
=
GetMutableVar
<
lite
::
Tensor
>
(
scope
,
Out_name
);
param_
.
axis
=
boost
::
get
<
int
>
(
opdesc
.
GetAttr
(
"axis"
)
);
param_
.
axis
=
opdesc
.
GetAttr
<
int
>
(
"axis"
);
return
true
;
return
true
;
}
}
...
@@ -75,8 +75,8 @@ class ElementwiseGradExplicitOp : public OpLite {
...
@@ -75,8 +75,8 @@ class ElementwiseGradExplicitOp : public OpLite {
return
true
;
return
true
;
}
}
bool
AttachImpl
(
const
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
override
{
bool
AttachImpl
(
const
cpp
::
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
override
{
CHECK_EQ
(
opdesc
.
Inputs
().
size
(),
1UL
);
CHECK_EQ
(
opdesc
.
Input
ArgumentName
s
().
size
(),
1UL
);
auto
Out_name
=
opdesc
.
Input
(
framework
::
GradVarName
(
"Out"
)).
front
();
auto
Out_name
=
opdesc
.
Input
(
framework
::
GradVarName
(
"Out"
)).
front
();
auto
X_name
=
opdesc
.
Output
(
framework
::
GradVarName
(
"X"
)).
front
();
auto
X_name
=
opdesc
.
Output
(
framework
::
GradVarName
(
"X"
)).
front
();
auto
Y_name
=
opdesc
.
Output
(
framework
::
GradVarName
(
"Y"
)).
front
();
auto
Y_name
=
opdesc
.
Output
(
framework
::
GradVarName
(
"Y"
)).
front
();
...
@@ -84,7 +84,7 @@ class ElementwiseGradExplicitOp : public OpLite {
...
@@ -84,7 +84,7 @@ class ElementwiseGradExplicitOp : public OpLite {
param_
.
Out_grad
=
GetVar
<
lite
::
Tensor
>
(
scope
,
Out_name
);
param_
.
Out_grad
=
GetVar
<
lite
::
Tensor
>
(
scope
,
Out_name
);
param_
.
X_grad
=
GetMutableVar
<
lite
::
Tensor
>
(
scope
,
X_name
);
param_
.
X_grad
=
GetMutableVar
<
lite
::
Tensor
>
(
scope
,
X_name
);
param_
.
Y_grad
=
GetMutableVar
<
Tensor
>
(
scope
,
Y_name
);
param_
.
Y_grad
=
GetMutableVar
<
Tensor
>
(
scope
,
Y_name
);
param_
.
axis
=
boost
::
get
<
int
>
(
opdesc
.
GetAttr
(
"axis"
)
);
param_
.
axis
=
opdesc
.
GetAttr
<
int
>
(
"axis"
);
return
true
;
return
true
;
}
}
...
...
paddle/fluid/lite/operators/fc_op.h
浏览文件 @
65be35af
...
@@ -46,7 +46,7 @@ class FcOpLite : public OpLite {
...
@@ -46,7 +46,7 @@ class FcOpLite : public OpLite {
*/
*/
// TODO(Superjomn) replace framework::OpDesc with a lite one.
// TODO(Superjomn) replace framework::OpDesc with a lite one.
bool
AttachImpl
(
const
OpDesc
&
op_desc
,
lite
::
Scope
*
scope
)
override
{
bool
AttachImpl
(
const
cpp
::
OpDesc
&
op_desc
,
lite
::
Scope
*
scope
)
override
{
auto
input
=
op_desc
.
Input
(
"Input"
).
front
();
auto
input
=
op_desc
.
Input
(
"Input"
).
front
();
auto
W
=
op_desc
.
Input
(
"W"
).
front
();
auto
W
=
op_desc
.
Input
(
"W"
).
front
();
auto
bias
=
op_desc
.
Input
(
"Bias"
).
front
();
auto
bias
=
op_desc
.
Input
(
"Bias"
).
front
();
...
@@ -57,7 +57,7 @@ class FcOpLite : public OpLite {
...
@@ -57,7 +57,7 @@ class FcOpLite : public OpLite {
param_
.
bias
=
scope
->
FindVar
(
bias
)
->
GetMutable
<
lite
::
Tensor
>
();
param_
.
bias
=
scope
->
FindVar
(
bias
)
->
GetMutable
<
lite
::
Tensor
>
();
CHECK
(
scope
->
FindVar
(
out
));
CHECK
(
scope
->
FindVar
(
out
));
param_
.
output
=
scope
->
FindVar
(
out
)
->
GetMutable
<
lite
::
Tensor
>
();
param_
.
output
=
scope
->
FindVar
(
out
)
->
GetMutable
<
lite
::
Tensor
>
();
param_
.
in_num_col_dims
=
GetAttr
<
int
>
(
op_desc
.
GetAttr
(
"in_num_col_dims"
)
);
param_
.
in_num_col_dims
=
op_desc
.
GetAttr
<
int
>
(
"in_num_col_dims"
);
return
true
;
return
true
;
}
}
...
...
paddle/fluid/lite/operators/fc_op_test.cc
浏览文件 @
65be35af
...
@@ -47,7 +47,7 @@ TEST(fc_op_lite, TestX86) {
...
@@ -47,7 +47,7 @@ TEST(fc_op_lite, TestX86) {
}
}
// prepare op desc
// prepare op desc
lite
::
OpDesc
desc
;
cpp
::
OpDesc
desc
;
desc
.
SetType
(
"fc"
);
desc
.
SetType
(
"fc"
);
desc
.
SetInput
(
"Input"
,
{
"x"
});
desc
.
SetInput
(
"Input"
,
{
"x"
});
desc
.
SetInput
(
"W"
,
{
"w"
});
desc
.
SetInput
(
"W"
,
{
"w"
});
...
...
paddle/fluid/lite/operators/feed_op.cc
浏览文件 @
65be35af
...
@@ -34,7 +34,7 @@ class FeedOp : public OpLite {
...
@@ -34,7 +34,7 @@ class FeedOp : public OpLite {
void
AttachKernel
(
KernelBase
*
kernel
)
override
{
kernel
->
SetParam
(
param_
);
}
void
AttachKernel
(
KernelBase
*
kernel
)
override
{
kernel
->
SetParam
(
param_
);
}
protected:
protected:
bool
AttachImpl
(
const
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
override
{
bool
AttachImpl
(
const
cpp
::
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
override
{
auto
feed_var_name
=
opdesc
.
Input
(
"X"
).
front
();
auto
feed_var_name
=
opdesc
.
Input
(
"X"
).
front
();
auto
*
feed_var
=
scope
->
FindVar
(
feed_var_name
);
auto
*
feed_var
=
scope
->
FindVar
(
feed_var_name
);
CHECK
(
feed_var
);
CHECK
(
feed_var
);
...
@@ -48,7 +48,7 @@ class FeedOp : public OpLite {
...
@@ -48,7 +48,7 @@ class FeedOp : public OpLite {
// NOTE need boost here
// NOTE need boost here
// TODO(Superjomn) drop the need of framework::op_desc
// TODO(Superjomn) drop the need of framework::op_desc
param_
.
col
=
GetAttr
<
int
>
(
opdesc
.
GetAttr
(
"col"
)
);
param_
.
col
=
opdesc
.
GetAttr
<
int
>
(
"col"
);
return
true
;
return
true
;
}
}
...
...
paddle/fluid/lite/operators/fetch_op.cc
浏览文件 @
65be35af
...
@@ -33,7 +33,7 @@ class FetchOp : public OpLite {
...
@@ -33,7 +33,7 @@ class FetchOp : public OpLite {
void
AttachKernel
(
KernelBase
*
kernel
)
override
{
kernel
->
SetParam
(
param_
);
}
void
AttachKernel
(
KernelBase
*
kernel
)
override
{
kernel
->
SetParam
(
param_
);
}
protected:
protected:
bool
AttachImpl
(
const
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
override
{
bool
AttachImpl
(
const
cpp
::
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
override
{
auto
_x
=
opdesc
.
Input
(
"X"
).
front
();
auto
_x
=
opdesc
.
Input
(
"X"
).
front
();
auto
*
x
=
scope
->
FindVar
(
_x
);
auto
*
x
=
scope
->
FindVar
(
_x
);
CHECK
(
x
);
CHECK
(
x
);
...
@@ -43,7 +43,7 @@ class FetchOp : public OpLite {
...
@@ -43,7 +43,7 @@ class FetchOp : public OpLite {
auto
*
out
=
scope
->
FindVar
(
_out
);
auto
*
out
=
scope
->
FindVar
(
_out
);
param_
.
fetch_list
=
out
->
GetMutable
<
std
::
vector
<
lite
::
Tensor
>>
();
param_
.
fetch_list
=
out
->
GetMutable
<
std
::
vector
<
lite
::
Tensor
>>
();
param_
.
col
=
GetAttr
<
int
>
(
opdesc
.
GetAttr
(
"col"
)
);
param_
.
col
=
opdesc
.
GetAttr
<
int
>
(
"col"
);
return
true
;
return
true
;
}
}
...
...
paddle/fluid/lite/operators/fill_constant_op.cc
浏览文件 @
65be35af
...
@@ -33,14 +33,14 @@ class FillConstantOp : public OpLite {
...
@@ -33,14 +33,14 @@ class FillConstantOp : public OpLite {
return
true
;
return
true
;
}
}
bool
AttachImpl
(
const
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
override
{
bool
AttachImpl
(
const
cpp
::
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
override
{
auto
Out_name
=
opdesc
.
Output
(
"Out"
).
front
();
auto
Out_name
=
opdesc
.
Output
(
"Out"
).
front
();
param_
.
Out
=
GetMutableVar
<
Tensor
>
(
scope
,
Out_name
);
param_
.
Out
=
GetMutableVar
<
Tensor
>
(
scope
,
Out_name
);
param_
.
dtype
=
GetAttr
<
int
>
(
opdesc
.
GetAttr
(
"dtype"
)
);
param_
.
dtype
=
opdesc
.
GetAttr
<
int
>
(
"dtype"
);
param_
.
shape
=
GetAttr
<
std
::
vector
<
int64_t
>>
(
opdesc
.
GetAttr
(
"shape"
)
);
param_
.
shape
=
opdesc
.
GetAttr
<
std
::
vector
<
int64_t
>>
(
"shape"
);
param_
.
value
=
GetAttr
<
float
>
(
opdesc
.
GetAttr
(
"value"
)
);
param_
.
value
=
opdesc
.
GetAttr
<
float
>
(
"value"
);
param_
.
force_cpu
=
GetAttr
<
bool
>
(
opdesc
.
GetAttr
(
"force_cpu"
)
);
param_
.
force_cpu
=
opdesc
.
GetAttr
<
bool
>
(
"force_cpu"
);
return
true
;
return
true
;
}
}
...
...
paddle/fluid/lite/operators/io_copy_op.cc
浏览文件 @
65be35af
...
@@ -29,7 +29,8 @@ bool IoCopyOp::InferShape() const {
...
@@ -29,7 +29,8 @@ bool IoCopyOp::InferShape() const {
return
true
;
return
true
;
}
}
bool
IoCopyOp
::
Run
()
{
return
OpLite
::
Run
();
}
bool
IoCopyOp
::
Run
()
{
return
OpLite
::
Run
();
}
bool
IoCopyOp
::
AttachImpl
(
const
OpDesc
&
opdesc
,
paddle
::
lite
::
Scope
*
scope
)
{
bool
IoCopyOp
::
AttachImpl
(
const
cpp
::
OpDesc
&
opdesc
,
paddle
::
lite
::
Scope
*
scope
)
{
auto
x
=
opdesc
.
Input
(
"Input"
).
front
();
auto
x
=
opdesc
.
Input
(
"Input"
).
front
();
auto
out
=
opdesc
.
Output
(
"Out"
).
front
();
auto
out
=
opdesc
.
Output
(
"Out"
).
front
();
param_
.
x
=
GetTensor
(
scope
,
x
);
param_
.
x
=
GetTensor
(
scope
,
x
);
...
...
paddle/fluid/lite/operators/io_copy_op.h
浏览文件 @
65be35af
...
@@ -31,7 +31,7 @@ class IoCopyOp : public OpLite {
...
@@ -31,7 +31,7 @@ class IoCopyOp : public OpLite {
void
AttachKernel
(
KernelBase
*
kernel
)
override
{
kernel
->
SetParam
(
param_
);
}
void
AttachKernel
(
KernelBase
*
kernel
)
override
{
kernel
->
SetParam
(
param_
);
}
protected:
protected:
bool
AttachImpl
(
const
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
override
;
bool
AttachImpl
(
const
cpp
::
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
override
;
private:
private:
operators
::
IoCopyParam
param_
;
operators
::
IoCopyParam
param_
;
...
...
paddle/fluid/lite/operators/mean_op.cc
浏览文件 @
65be35af
...
@@ -37,7 +37,7 @@ class MeanOp : public OpLite {
...
@@ -37,7 +37,7 @@ class MeanOp : public OpLite {
return
true
;
return
true
;
}
}
bool
AttachImpl
(
const
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
override
{
bool
AttachImpl
(
const
cpp
::
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
override
{
auto
X_name
=
opdesc
.
Input
(
"X"
).
front
();
auto
X_name
=
opdesc
.
Input
(
"X"
).
front
();
auto
Out_name
=
opdesc
.
Output
(
"Out"
).
front
();
auto
Out_name
=
opdesc
.
Output
(
"Out"
).
front
();
...
@@ -72,8 +72,8 @@ class MeanGradOp : public OpLite {
...
@@ -72,8 +72,8 @@ class MeanGradOp : public OpLite {
return
true
;
return
true
;
}
}
bool
AttachImpl
(
const
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
override
{
bool
AttachImpl
(
const
cpp
::
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
override
{
CHECK_EQ
(
opdesc
.
Inputs
().
size
(),
3UL
);
CHECK_EQ
(
opdesc
.
Input
ArgumentName
s
().
size
(),
3UL
);
auto
X_name
=
opdesc
.
Input
(
"X"
).
front
();
auto
X_name
=
opdesc
.
Input
(
"X"
).
front
();
auto
Out_grad_name
=
opdesc
.
Input
(
framework
::
GradVarName
(
"Out"
)).
front
();
auto
Out_grad_name
=
opdesc
.
Input
(
framework
::
GradVarName
(
"Out"
)).
front
();
auto
X_grad_name
=
opdesc
.
Output
(
framework
::
GradVarName
(
"X"
)).
front
();
auto
X_grad_name
=
opdesc
.
Output
(
framework
::
GradVarName
(
"X"
)).
front
();
...
...
paddle/fluid/lite/operators/mul_op.cc
浏览文件 @
65be35af
...
@@ -85,7 +85,7 @@ bool MulGradOpLite::InferShape() const {
...
@@ -85,7 +85,7 @@ bool MulGradOpLite::InferShape() const {
return
true
;
return
true
;
}
}
bool
MulGradOpLite
::
AttachImpl
(
const
OpDesc
&
op_desc
,
lite
::
Scope
*
scope
)
{
bool
MulGradOpLite
::
AttachImpl
(
const
cpp
::
OpDesc
&
op_desc
,
lite
::
Scope
*
scope
)
{
auto
X_name
=
op_desc
.
Input
(
"X"
).
front
();
auto
X_name
=
op_desc
.
Input
(
"X"
).
front
();
auto
Y_name
=
op_desc
.
Input
(
"Y"
).
front
();
auto
Y_name
=
op_desc
.
Input
(
"Y"
).
front
();
auto
Out_grad_name
=
op_desc
.
Output
(
framework
::
GradVarName
(
"Out"
)).
front
();
auto
Out_grad_name
=
op_desc
.
Output
(
framework
::
GradVarName
(
"Out"
)).
front
();
...
...
paddle/fluid/lite/operators/mul_op.h
浏览文件 @
65be35af
...
@@ -37,7 +37,7 @@ class MulOpLite : public OpLite {
...
@@ -37,7 +37,7 @@ class MulOpLite : public OpLite {
void
AttachKernel
(
KernelBase
*
kernel
)
override
{
kernel
->
SetParam
(
param_
);
}
void
AttachKernel
(
KernelBase
*
kernel
)
override
{
kernel
->
SetParam
(
param_
);
}
// TODO(Superjomn) replace framework::OpDesc with a lite one.
// TODO(Superjomn) replace framework::OpDesc with a lite one.
bool
AttachImpl
(
const
OpDesc
&
op_desc
,
lite
::
Scope
*
scope
)
override
{
bool
AttachImpl
(
const
cpp
::
OpDesc
&
op_desc
,
lite
::
Scope
*
scope
)
override
{
auto
input
=
op_desc
.
Input
(
"X"
).
front
();
auto
input
=
op_desc
.
Input
(
"X"
).
front
();
auto
W
=
op_desc
.
Input
(
"Y"
).
front
();
auto
W
=
op_desc
.
Input
(
"Y"
).
front
();
auto
out
=
op_desc
.
Output
(
"Out"
).
front
();
auto
out
=
op_desc
.
Output
(
"Out"
).
front
();
...
@@ -49,8 +49,8 @@ class MulOpLite : public OpLite {
...
@@ -49,8 +49,8 @@ class MulOpLite : public OpLite {
param_
.
y
=
var
->
GetMutable
<
Tensor
>
();
param_
.
y
=
var
->
GetMutable
<
Tensor
>
();
CHECK
(
scope
->
FindVar
(
out
));
CHECK
(
scope
->
FindVar
(
out
));
param_
.
output
=
scope
->
FindVar
(
out
)
->
GetMutable
<
Tensor
>
();
param_
.
output
=
scope
->
FindVar
(
out
)
->
GetMutable
<
Tensor
>
();
param_
.
x_num_col_dims
=
GetAttr
<
int
>
(
op_desc
.
GetAttr
(
"x_num_col_dims"
)
);
param_
.
x_num_col_dims
=
op_desc
.
GetAttr
<
int
>
(
"x_num_col_dims"
);
param_
.
y_num_col_dims
=
GetAttr
<
int
>
(
op_desc
.
GetAttr
(
"y_num_col_dims"
)
);
param_
.
y_num_col_dims
=
op_desc
.
GetAttr
<
int
>
(
"y_num_col_dims"
);
return
true
;
return
true
;
}
}
...
@@ -73,7 +73,7 @@ class MulGradOpLite : public OpLite {
...
@@ -73,7 +73,7 @@ class MulGradOpLite : public OpLite {
void
AttachKernel
(
KernelBase
*
kernel
)
override
{
kernel
->
SetParam
(
param_
);
}
void
AttachKernel
(
KernelBase
*
kernel
)
override
{
kernel
->
SetParam
(
param_
);
}
bool
AttachImpl
(
const
OpDesc
&
op_desc
,
lite
::
Scope
*
scope
)
override
;
bool
AttachImpl
(
const
cpp
::
OpDesc
&
op_desc
,
lite
::
Scope
*
scope
)
override
;
std
::
string
DebugString
()
const
override
{
return
"mul_grad"
;
}
std
::
string
DebugString
()
const
override
{
return
"mul_grad"
;
}
...
...
paddle/fluid/lite/operators/relu_op.cc
浏览文件 @
65be35af
...
@@ -30,7 +30,7 @@ bool ReluOp::InferShape() const {
...
@@ -30,7 +30,7 @@ bool ReluOp::InferShape() const {
return
true
;
return
true
;
}
}
bool
ReluOp
::
AttachImpl
(
const
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
{
bool
ReluOp
::
AttachImpl
(
const
cpp
::
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
{
param_
.
input
=
const_cast
<
lite
::
Tensor
*>
(
param_
.
input
=
const_cast
<
lite
::
Tensor
*>
(
&
scope
->
FindVar
(
opdesc
.
Input
(
"Input"
).
front
())
->
Get
<
lite
::
Tensor
>
());
&
scope
->
FindVar
(
opdesc
.
Input
(
"Input"
).
front
())
->
Get
<
lite
::
Tensor
>
());
param_
.
output
=
param_
.
output
=
...
...
paddle/fluid/lite/operators/relu_op.h
浏览文件 @
65be35af
...
@@ -32,7 +32,7 @@ class ReluOp : public OpLite {
...
@@ -32,7 +32,7 @@ class ReluOp : public OpLite {
bool
InferShape
()
const
override
;
bool
InferShape
()
const
override
;
bool
AttachImpl
(
const
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
override
;
bool
AttachImpl
(
const
cpp
::
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
override
;
void
AttachKernel
(
KernelBase
*
kernel
)
override
{
kernel
->
SetParam
(
param_
);
}
void
AttachKernel
(
KernelBase
*
kernel
)
override
{
kernel
->
SetParam
(
param_
);
}
std
::
string
DebugString
()
const
override
{
return
"relu"
;
}
std
::
string
DebugString
()
const
override
{
return
"relu"
;
}
...
...
paddle/fluid/lite/operators/reshape_op.cc
浏览文件 @
65be35af
...
@@ -33,7 +33,7 @@ bool ReshapeOp::InferShape() const {
...
@@ -33,7 +33,7 @@ bool ReshapeOp::InferShape() const {
return
true
;
return
true
;
}
}
bool
ReshapeOp
::
AttachImpl
(
const
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
{
bool
ReshapeOp
::
AttachImpl
(
const
cpp
::
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
{
auto
x_var
=
scope
->
FindVar
(
opdesc
.
Input
(
"X"
).
front
());
auto
x_var
=
scope
->
FindVar
(
opdesc
.
Input
(
"X"
).
front
());
auto
output_var
=
scope
->
FindVar
(
opdesc
.
Output
(
"Out"
).
front
());
auto
output_var
=
scope
->
FindVar
(
opdesc
.
Output
(
"Out"
).
front
());
CHECK
(
x_var
);
CHECK
(
x_var
);
...
@@ -49,9 +49,9 @@ bool ReshapeOp::AttachImpl(const OpDesc &opdesc, lite::Scope *scope) {
...
@@ -49,9 +49,9 @@ bool ReshapeOp::AttachImpl(const OpDesc &opdesc, lite::Scope *scope) {
const_cast
<
lite
::
Tensor
*>
(
&
(
actual_shape_var
->
Get
<
lite
::
Tensor
>
()));
const_cast
<
lite
::
Tensor
*>
(
&
(
actual_shape_var
->
Get
<
lite
::
Tensor
>
()));
}
}
}
}
param_
.
shape
=
GetAttr
<
std
::
vector
<
int
>>
(
opdesc
.
GetAttr
(
"shape"
));
param_
.
shape
=
(
opdesc
.
GetAttr
<
std
::
vector
<
int
>>
(
"shape"
));
if
(
opdesc
.
HasAttr
(
"inplace"
))
{
if
(
opdesc
.
HasAttr
(
"inplace"
))
{
param_
.
inplace
=
GetAttr
<
bool
>
(
opdesc
.
GetAttr
(
"inplace"
)
);
param_
.
inplace
=
opdesc
.
GetAttr
<
bool
>
(
"inplace"
);
}
}
CHECK
(
param_
.
x
)
<<
"Input(X) of ReshapeOp should not be null."
;
CHECK
(
param_
.
x
)
<<
"Input(X) of ReshapeOp should not be null."
;
CHECK
(
param_
.
output
)
<<
"Output(Out) of ReshapeOp should not be null."
;
CHECK
(
param_
.
output
)
<<
"Output(Out) of ReshapeOp should not be null."
;
...
@@ -70,14 +70,14 @@ bool Reshape2Op::InferShape() const {
...
@@ -70,14 +70,14 @@ bool Reshape2Op::InferShape() const {
ReshapeOp
::
InferShape
();
ReshapeOp
::
InferShape
();
auto
x_dims
=
param_
.
x
->
dims
();
auto
x_dims
=
param_
.
x
->
dims
();
std
::
vector
<
DDim
::
value_type
>
xshape_dims
(
x_dims
.
size
()
+
1
,
0
);
std
::
vector
<
DDim
::
value_type
>
xshape_dims
(
x_dims
.
size
()
+
1
,
0
);
for
(
in
t
i
=
0
;
i
<
x_dims
.
size
();
i
++
)
{
for
(
size_
t
i
=
0
;
i
<
x_dims
.
size
();
i
++
)
{
xshape_dims
[
i
+
1
]
=
x_dims
[
i
];
xshape_dims
[
i
+
1
]
=
x_dims
[
i
];
}
}
param_
.
xshape
->
Resize
(
DDim
(
xshape_dims
));
param_
.
xshape
->
Resize
(
DDim
(
xshape_dims
));
return
true
;
return
true
;
}
}
bool
Reshape2Op
::
AttachImpl
(
const
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
{
bool
Reshape2Op
::
AttachImpl
(
const
cpp
::
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
{
ReshapeOp
::
AttachImpl
(
opdesc
,
scope
);
ReshapeOp
::
AttachImpl
(
opdesc
,
scope
);
auto
xshape_var
=
scope
->
FindVar
(
opdesc
.
Output
(
"XShape"
).
front
());
auto
xshape_var
=
scope
->
FindVar
(
opdesc
.
Output
(
"XShape"
).
front
());
CHECK
(
xshape_var
);
CHECK
(
xshape_var
);
...
...
paddle/fluid/lite/operators/reshape_op.h
浏览文件 @
65be35af
...
@@ -32,7 +32,7 @@ class ReshapeOp : public OpLite {
...
@@ -32,7 +32,7 @@ class ReshapeOp : public OpLite {
bool
InferShape
()
const
override
;
bool
InferShape
()
const
override
;
bool
AttachImpl
(
const
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
override
;
bool
AttachImpl
(
const
cpp
::
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
override
;
void
AttachKernel
(
KernelBase
*
kernel
)
override
{
kernel
->
SetParam
(
param_
);
}
void
AttachKernel
(
KernelBase
*
kernel
)
override
{
kernel
->
SetParam
(
param_
);
}
std
::
string
DebugString
()
const
override
{
return
"reshape"
;
}
std
::
string
DebugString
()
const
override
{
return
"reshape"
;
}
...
@@ -50,7 +50,7 @@ class Reshape2Op : public ReshapeOp {
...
@@ -50,7 +50,7 @@ class Reshape2Op : public ReshapeOp {
bool
InferShape
()
const
override
;
bool
InferShape
()
const
override
;
bool
AttachImpl
(
const
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
override
;
bool
AttachImpl
(
const
cpp
::
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
override
;
void
AttachKernel
(
KernelBase
*
kernel
)
override
{
kernel
->
SetParam
(
param_
);
}
void
AttachKernel
(
KernelBase
*
kernel
)
override
{
kernel
->
SetParam
(
param_
);
}
std
::
string
DebugString
()
const
override
{
return
"reshape2"
;
}
std
::
string
DebugString
()
const
override
{
return
"reshape2"
;
}
...
...
paddle/fluid/lite/operators/reshape_op_test.cc
浏览文件 @
65be35af
...
@@ -47,7 +47,7 @@ TEST(reshape_op_lite, test) {
...
@@ -47,7 +47,7 @@ TEST(reshape_op_lite, test) {
for
(
auto
&
has_actual_shape
:
{
true
,
false
})
{
for
(
auto
&
has_actual_shape
:
{
true
,
false
})
{
for
(
auto
&
inplace
:
{
true
,
false
})
{
for
(
auto
&
inplace
:
{
true
,
false
})
{
// prepare op desc
// prepare op desc
lite
::
OpDesc
desc
;
cpp
::
OpDesc
desc
;
desc
.
SetType
(
"reshape"
);
desc
.
SetType
(
"reshape"
);
desc
.
SetInput
(
"X"
,
{
"x"
});
desc
.
SetInput
(
"X"
,
{
"x"
});
if
(
has_actual_shape
)
{
if
(
has_actual_shape
)
{
...
@@ -68,7 +68,7 @@ TEST(reshape_op_lite, test) {
...
@@ -68,7 +68,7 @@ TEST(reshape_op_lite, test) {
// check output dims
// check output dims
auto
output_dims
=
output
->
dims
();
auto
output_dims
=
output
->
dims
();
CHECK_EQ
(
output_dims
.
size
(),
shape
.
second
.
size
());
CHECK_EQ
(
output_dims
.
size
(),
shape
.
second
.
size
());
for
(
in
t
i
=
0
;
i
<
output_dims
.
size
();
i
++
)
{
for
(
size_
t
i
=
0
;
i
<
output_dims
.
size
();
i
++
)
{
CHECK_EQ
(
output_dims
[
i
],
shape
.
second
[
i
]);
CHECK_EQ
(
output_dims
[
i
],
shape
.
second
[
i
]);
}
}
}
}
...
@@ -102,7 +102,7 @@ TEST(reshape2_op_lite, test) {
...
@@ -102,7 +102,7 @@ TEST(reshape2_op_lite, test) {
for
(
auto
&
has_actual_shape
:
{
true
,
false
})
{
for
(
auto
&
has_actual_shape
:
{
true
,
false
})
{
for
(
auto
&
inplace
:
{
true
,
false
})
{
for
(
auto
&
inplace
:
{
true
,
false
})
{
// prepare op desc
// prepare op desc
lite
::
OpDesc
desc
;
cpp
::
OpDesc
desc
;
desc
.
SetType
(
"reshape"
);
desc
.
SetType
(
"reshape"
);
desc
.
SetInput
(
"X"
,
{
"x"
});
desc
.
SetInput
(
"X"
,
{
"x"
});
if
(
has_actual_shape
)
{
if
(
has_actual_shape
)
{
...
@@ -132,7 +132,7 @@ TEST(reshape2_op_lite, test) {
...
@@ -132,7 +132,7 @@ TEST(reshape2_op_lite, test) {
auto
xshape_dims
=
xshape
->
dims
();
auto
xshape_dims
=
xshape
->
dims
();
CHECK_EQ
(
xshape_dims
.
size
(),
x_dims
.
size
()
+
1
);
CHECK_EQ
(
xshape_dims
.
size
(),
x_dims
.
size
()
+
1
);
CHECK_EQ
(
xshape_dims
[
0
],
0
);
CHECK_EQ
(
xshape_dims
[
0
],
0
);
for
(
in
t
i
=
0
;
i
<
x_dims
.
size
();
i
++
)
{
for
(
size_
t
i
=
0
;
i
<
x_dims
.
size
();
i
++
)
{
CHECK_EQ
(
xshape_dims
[
i
+
1
],
x_dims
[
i
]);
CHECK_EQ
(
xshape_dims
[
i
+
1
],
x_dims
[
i
]);
}
}
}
}
...
...
paddle/fluid/lite/operators/scale_op.cc
浏览文件 @
65be35af
...
@@ -29,14 +29,14 @@ bool ScaleOp::InferShape() const {
...
@@ -29,14 +29,14 @@ bool ScaleOp::InferShape() const {
return
true
;
return
true
;
}
}
bool
ScaleOp
::
AttachImpl
(
const
OpDesc
&
op_desc
,
lite
::
Scope
*
scope
)
{
bool
ScaleOp
::
AttachImpl
(
const
cpp
::
OpDesc
&
op_desc
,
lite
::
Scope
*
scope
)
{
auto
x
=
op_desc
.
Input
(
"X"
).
front
();
auto
x
=
op_desc
.
Input
(
"X"
).
front
();
auto
output
=
op_desc
.
Output
(
"Out"
).
front
();
auto
output
=
op_desc
.
Output
(
"Out"
).
front
();
param_
.
x
=
scope
->
FindVar
(
x
)
->
GetMutable
<
Tensor
>
();
param_
.
x
=
scope
->
FindVar
(
x
)
->
GetMutable
<
Tensor
>
();
param_
.
output
=
scope
->
FindVar
(
output
)
->
GetMutable
<
Tensor
>
();
param_
.
output
=
scope
->
FindVar
(
output
)
->
GetMutable
<
Tensor
>
();
param_
.
scale
=
GetAttr
<
float
>
(
op_desc
.
GetAttr
(
"scale"
)
);
param_
.
scale
=
op_desc
.
GetAttr
<
float
>
(
"scale"
);
param_
.
bias
=
GetAttr
<
float
>
(
op_desc
.
GetAttr
(
"bias"
)
);
param_
.
bias
=
op_desc
.
GetAttr
<
float
>
(
"bias"
);
param_
.
bias_after_scale
=
GetAttr
<
bool
>
(
op_desc
.
GetAttr
(
"bias_after_scale"
)
);
param_
.
bias_after_scale
=
op_desc
.
GetAttr
<
bool
>
(
"bias_after_scale"
);
CHECK
(
param_
.
x
);
CHECK
(
param_
.
x
);
CHECK
(
param_
.
output
);
CHECK
(
param_
.
output
);
return
true
;
return
true
;
...
...
paddle/fluid/lite/operators/scale_op.h
浏览文件 @
65be35af
...
@@ -32,7 +32,7 @@ class ScaleOp : public OpLite {
...
@@ -32,7 +32,7 @@ class ScaleOp : public OpLite {
bool
InferShape
()
const
override
;
bool
InferShape
()
const
override
;
bool
AttachImpl
(
const
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
override
;
bool
AttachImpl
(
const
cpp
::
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
override
;
void
AttachKernel
(
KernelBase
*
kernel
)
override
{
kernel
->
SetParam
(
param_
);
}
void
AttachKernel
(
KernelBase
*
kernel
)
override
{
kernel
->
SetParam
(
param_
);
}
std
::
string
DebugString
()
const
override
{
return
"scale"
;
}
std
::
string
DebugString
()
const
override
{
return
"scale"
;
}
...
...
paddle/fluid/lite/operators/scale_op_test.cc
浏览文件 @
65be35af
...
@@ -29,7 +29,7 @@ TEST(scale_op_lite, test) {
...
@@ -29,7 +29,7 @@ TEST(scale_op_lite, test) {
output
->
Resize
(
DDim
(
std
::
vector
<
int64_t
>
{
1
,
1
}));
output
->
Resize
(
DDim
(
std
::
vector
<
int64_t
>
{
1
,
1
}));
// prepare op desc
// prepare op desc
lite
::
OpDesc
desc
;
cpp
::
OpDesc
desc
;
desc
.
SetType
(
"scale"
);
desc
.
SetType
(
"scale"
);
desc
.
SetInput
(
"X"
,
{
"x"
});
desc
.
SetInput
(
"X"
,
{
"x"
});
desc
.
SetOutput
(
"Out"
,
{
"output"
});
desc
.
SetOutput
(
"Out"
,
{
"output"
});
...
@@ -48,7 +48,7 @@ TEST(scale_op_lite, test) {
...
@@ -48,7 +48,7 @@ TEST(scale_op_lite, test) {
auto
x_dims
=
x
->
dims
();
auto
x_dims
=
x
->
dims
();
auto
output_dims
=
output
->
dims
();
auto
output_dims
=
output
->
dims
();
CHECK_EQ
(
output_dims
.
size
(),
x_dims
.
size
());
CHECK_EQ
(
output_dims
.
size
(),
x_dims
.
size
());
for
(
in
t
i
=
0
;
i
<
output_dims
.
size
();
i
++
)
{
for
(
size_
t
i
=
0
;
i
<
output_dims
.
size
();
i
++
)
{
CHECK_EQ
(
output_dims
[
i
],
x_dims
[
i
]);
CHECK_EQ
(
output_dims
[
i
],
x_dims
[
i
]);
}
}
}
}
...
...
paddle/fluid/lite/operators/softmax_op.cc
浏览文件 @
65be35af
...
@@ -24,7 +24,8 @@ bool SoftmaxOp::CheckShape() const {
...
@@ -24,7 +24,8 @@ bool SoftmaxOp::CheckShape() const {
CHECK_OR_FALSE
(
param_
.
output
);
CHECK_OR_FALSE
(
param_
.
output
);
auto
x_dims
=
param_
.
x
->
dims
();
auto
x_dims
=
param_
.
x
->
dims
();
auto
x_rank
=
x_dims
.
size
();
auto
x_rank
=
x_dims
.
size
();
CHECK_OR_FALSE
(
param_
.
axis
>=
-
x_rank
&&
param_
.
axis
<
x_rank
);
CHECK_OR_FALSE
(
param_
.
axis
>=
-
static_cast
<
int
>
(
x_rank
)
&&
param_
.
axis
<
static_cast
<
int
>
(
x_rank
));
return
true
;
return
true
;
}
}
...
@@ -33,12 +34,12 @@ bool SoftmaxOp::InferShape() const {
...
@@ -33,12 +34,12 @@ bool SoftmaxOp::InferShape() const {
return
true
;
return
true
;
}
}
bool
SoftmaxOp
::
AttachImpl
(
const
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
{
bool
SoftmaxOp
::
AttachImpl
(
const
cpp
::
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
{
param_
.
x
=
const_cast
<
lite
::
Tensor
*>
(
param_
.
x
=
const_cast
<
lite
::
Tensor
*>
(
&
scope
->
FindVar
(
opdesc
.
Input
(
"X"
).
front
())
->
Get
<
lite
::
Tensor
>
());
&
scope
->
FindVar
(
opdesc
.
Input
(
"X"
).
front
())
->
Get
<
lite
::
Tensor
>
());
param_
.
output
=
param_
.
output
=
scope
->
FindVar
(
opdesc
.
Output
(
"Out"
).
front
())
->
GetMutable
<
lite
::
Tensor
>
();
scope
->
FindVar
(
opdesc
.
Output
(
"Out"
).
front
())
->
GetMutable
<
lite
::
Tensor
>
();
param_
.
axis
=
GetAttr
<
int
>
(
opdesc
.
GetAttr
(
"axis"
)
);
param_
.
axis
=
opdesc
.
GetAttr
<
int
>
(
"axis"
);
CHECK
(
param_
.
x
);
CHECK
(
param_
.
x
);
CHECK
(
param_
.
output
);
CHECK
(
param_
.
output
);
return
true
;
return
true
;
...
...
paddle/fluid/lite/operators/softmax_op.h
浏览文件 @
65be35af
...
@@ -32,7 +32,7 @@ class SoftmaxOp : public OpLite {
...
@@ -32,7 +32,7 @@ class SoftmaxOp : public OpLite {
bool
InferShape
()
const
override
;
bool
InferShape
()
const
override
;
bool
AttachImpl
(
const
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
override
;
bool
AttachImpl
(
const
cpp
::
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
override
;
void
AttachKernel
(
KernelBase
*
kernel
)
override
{
kernel
->
SetParam
(
param_
);
}
void
AttachKernel
(
KernelBase
*
kernel
)
override
{
kernel
->
SetParam
(
param_
);
}
std
::
string
DebugString
()
const
override
{
return
"softmax"
;
}
std
::
string
DebugString
()
const
override
{
return
"softmax"
;
}
...
...
paddle/fluid/lite/operators/softmax_op_test.cc
浏览文件 @
65be35af
...
@@ -37,7 +37,7 @@ TEST(softmax_op_lite, test) {
...
@@ -37,7 +37,7 @@ TEST(softmax_op_lite, test) {
}
}
// prepare op desc
// prepare op desc
lite
::
OpDesc
desc
;
cpp
::
OpDesc
desc
;
desc
.
SetType
(
"softmax"
);
desc
.
SetType
(
"softmax"
);
desc
.
SetInput
(
"X"
,
{
"x"
});
desc
.
SetInput
(
"X"
,
{
"x"
});
desc
.
SetOutput
(
"Out"
,
{
"output"
});
desc
.
SetOutput
(
"Out"
,
{
"output"
});
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录