Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
65be35af
P
Paddle
项目概览
PaddlePaddle
/
Paddle
1 年多 前同步成功
通知
2302
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
65be35af
编写于
6月 04, 2019
作者:
Y
Yan Chunwei
提交者:
GitHub
6月 04, 2019
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Lite/refactor cp desc (#17831)
上级
b7cf0984
变更
45
隐藏空白更改
内联
并排
Showing
45 changed file
with
774 addition
and
331 deletion
+774
-331
paddle/fluid/lite/core/CMakeLists.txt
paddle/fluid/lite/core/CMakeLists.txt
+3
-1
paddle/fluid/lite/core/mir/type_target_transform_pass.cc
paddle/fluid/lite/core/mir/type_target_transform_pass.cc
+5
-7
paddle/fluid/lite/core/mir/type_target_transform_pass.h
paddle/fluid/lite/core/mir/type_target_transform_pass.h
+3
-3
paddle/fluid/lite/core/mir/variable_place_inference_pass.h
paddle/fluid/lite/core/mir/variable_place_inference_pass.h
+2
-2
paddle/fluid/lite/core/op_lite.cc
paddle/fluid/lite/core/op_lite.cc
+3
-92
paddle/fluid/lite/core/op_lite.h
paddle/fluid/lite/core/op_lite.h
+57
-36
paddle/fluid/lite/core/program.cc
paddle/fluid/lite/core/program.cc
+4
-4
paddle/fluid/lite/core/program.h
paddle/fluid/lite/core/program.h
+6
-2
paddle/fluid/lite/model_parser/CMakeLists.txt
paddle/fluid/lite/model_parser/CMakeLists.txt
+4
-5
paddle/fluid/lite/model_parser/compatible_pb.cc
paddle/fluid/lite/model_parser/compatible_pb.cc
+111
-0
paddle/fluid/lite/model_parser/compatible_pb.h
paddle/fluid/lite/model_parser/compatible_pb.h
+7
-18
paddle/fluid/lite/model_parser/cpp/CMakeLists.txt
paddle/fluid/lite/model_parser/cpp/CMakeLists.txt
+1
-0
paddle/fluid/lite/model_parser/cpp/op_desc.cc
paddle/fluid/lite/model_parser/cpp/op_desc.cc
+66
-0
paddle/fluid/lite/model_parser/cpp/op_desc.h
paddle/fluid/lite/model_parser/cpp/op_desc.h
+126
-0
paddle/fluid/lite/model_parser/desc_apis.h
paddle/fluid/lite/model_parser/desc_apis.h
+85
-0
paddle/fluid/lite/model_parser/op_desc_test.cc
paddle/fluid/lite/model_parser/op_desc_test.cc
+107
-0
paddle/fluid/lite/model_parser/pb/op_desc.cc
paddle/fluid/lite/model_parser/pb/op_desc.cc
+80
-19
paddle/fluid/lite/model_parser/pb/op_desc.h
paddle/fluid/lite/model_parser/pb/op_desc.h
+38
-78
paddle/fluid/lite/operators/activation_ops.cc
paddle/fluid/lite/operators/activation_ops.cc
+2
-2
paddle/fluid/lite/operators/concat_op.cc
paddle/fluid/lite/operators/concat_op.cc
+2
-2
paddle/fluid/lite/operators/concat_op.h
paddle/fluid/lite/operators/concat_op.h
+1
-1
paddle/fluid/lite/operators/concat_op_test.cc
paddle/fluid/lite/operators/concat_op_test.cc
+1
-1
paddle/fluid/lite/operators/dropout_op.cc
paddle/fluid/lite/operators/dropout_op.cc
+6
-6
paddle/fluid/lite/operators/elementwise_ops.cc
paddle/fluid/lite/operators/elementwise_ops.cc
+5
-5
paddle/fluid/lite/operators/fc_op.h
paddle/fluid/lite/operators/fc_op.h
+2
-2
paddle/fluid/lite/operators/fc_op_test.cc
paddle/fluid/lite/operators/fc_op_test.cc
+1
-1
paddle/fluid/lite/operators/feed_op.cc
paddle/fluid/lite/operators/feed_op.cc
+2
-2
paddle/fluid/lite/operators/fetch_op.cc
paddle/fluid/lite/operators/fetch_op.cc
+2
-2
paddle/fluid/lite/operators/fill_constant_op.cc
paddle/fluid/lite/operators/fill_constant_op.cc
+5
-5
paddle/fluid/lite/operators/io_copy_op.cc
paddle/fluid/lite/operators/io_copy_op.cc
+2
-1
paddle/fluid/lite/operators/io_copy_op.h
paddle/fluid/lite/operators/io_copy_op.h
+1
-1
paddle/fluid/lite/operators/mean_op.cc
paddle/fluid/lite/operators/mean_op.cc
+3
-3
paddle/fluid/lite/operators/mul_op.cc
paddle/fluid/lite/operators/mul_op.cc
+1
-1
paddle/fluid/lite/operators/mul_op.h
paddle/fluid/lite/operators/mul_op.h
+4
-4
paddle/fluid/lite/operators/relu_op.cc
paddle/fluid/lite/operators/relu_op.cc
+1
-1
paddle/fluid/lite/operators/relu_op.h
paddle/fluid/lite/operators/relu_op.h
+1
-1
paddle/fluid/lite/operators/reshape_op.cc
paddle/fluid/lite/operators/reshape_op.cc
+5
-5
paddle/fluid/lite/operators/reshape_op.h
paddle/fluid/lite/operators/reshape_op.h
+2
-2
paddle/fluid/lite/operators/reshape_op_test.cc
paddle/fluid/lite/operators/reshape_op_test.cc
+4
-4
paddle/fluid/lite/operators/scale_op.cc
paddle/fluid/lite/operators/scale_op.cc
+4
-4
paddle/fluid/lite/operators/scale_op.h
paddle/fluid/lite/operators/scale_op.h
+1
-1
paddle/fluid/lite/operators/scale_op_test.cc
paddle/fluid/lite/operators/scale_op_test.cc
+2
-2
paddle/fluid/lite/operators/softmax_op.cc
paddle/fluid/lite/operators/softmax_op.cc
+4
-3
paddle/fluid/lite/operators/softmax_op.h
paddle/fluid/lite/operators/softmax_op.h
+1
-1
paddle/fluid/lite/operators/softmax_op_test.cc
paddle/fluid/lite/operators/softmax_op_test.cc
+1
-1
未找到文件。
paddle/fluid/lite/core/CMakeLists.txt
浏览文件 @
65be35af
...
...
@@ -25,7 +25,9 @@ cc_library(op_registry_lite SRCS op_registry.cc DEPS framework_proto_lite)
cc_library
(
scope_lite SRCS scope.cc DEPS
${
tensor_lite
}
)
cc_library
(
cpu_info_lite SRCS cpu_info.cc
)
cc_library
(
context_lite SRCS context.cc DEPS
${
tensor_lite
}
any_lite cpu_info_lite
)
cc_library
(
op_lite SRCS op_lite.cc DEPS scope_lite op_registry_lite compatible_pb_lite target_wrapper_lite
${
tensor_lite
}
)
cc_library
(
op_lite SRCS op_lite.cc DEPS scope_lite op_registry_lite target_wrapper_lite
cpp_op_desc_lite
${
tensor_lite
}
)
cc_library
(
types_lite SRCS types.cc
)
cc_library
(
type_system SRCS type_system.cc DEPS
${
tensor_lite
}
target_wrapper_lite
)
...
...
paddle/fluid/lite/core/mir/type_target_transform_pass.cc
浏览文件 @
65be35af
...
...
@@ -90,7 +90,7 @@ void TypeTargetTransformPass::AddIoCopyInst(
inst_node
->
AsStmt
().
op
->
scope
()
->
Var
(
io_copy_output_name
);
// Create IoCopy Instruction.
lite
::
OpDesc
op_desc
;
cpp
::
OpDesc
op_desc
;
op_desc
.
SetType
(
"io_copy"
);
op_desc
.
SetInput
(
"Input"
,
{
var
});
op_desc
.
SetOutput
(
"Out"
,
{
io_copy_output_name
});
...
...
@@ -104,8 +104,6 @@ void TypeTargetTransformPass::AddIoCopyInst(
// Update the original instruction OpDesc.
// Update its input to the io_copy_output_name
auto
&
inst
=
inst_node
->
AsStmt
();
auto
inst_program_desc
=
inst
.
op_info
()
->
desc
();
// Add new link, var -> new_inst, new_inst->newarg, newarg->inst
DirectedLink
(
graph
->
Argument
(
var
),
io_copy_inst
);
...
...
@@ -113,11 +111,11 @@ void TypeTargetTransformPass::AddIoCopyInst(
DirectedLink
(
io_copy_output_arg
,
inst_node
);
// reset opdesc and update kernel information
auto
desc_dummy
=
inst_node
->
AsStmt
().
op
->
op_info
()
->
desc
();
UpdateInputTo
(
&
desc_dummy
,
var
,
io_copy_output_name
);
UpdateInputTo
(
inst_node
->
AsStmt
().
op
->
mutable_op_info
(),
var
,
io_copy_output_name
);
lite
::
OpDesc
desc_fake
(
desc_dummy
);
inst_node
->
AsStmt
().
op
->
Attach
(
desc_fake
,
inst_node
->
AsStmt
().
op
->
scope
());
inst_node
->
AsStmt
().
op
->
Attach
(
*
inst_node
->
AsStmt
().
op
->
op_info
(),
inst_node
->
AsStmt
().
op
->
scope
());
std
::
string
tmp
;
if
(
inst_node
->
AsStmt
().
op_info
()
->
GetInputArgname
(
"a"
,
&
tmp
))
{
...
...
paddle/fluid/lite/core/mir/type_target_transform_pass.h
浏览文件 @
65be35af
...
...
@@ -24,10 +24,10 @@ namespace paddle {
namespace
lite
{
namespace
mir
{
static
void
UpdateInputTo
(
framework
::
proto
::
OpDesc
*
desc
,
const
std
::
string
&
from
,
const
std
::
string
&
to
)
{
static
void
UpdateInputTo
(
cpp
::
OpDesc
*
desc
,
const
std
::
string
&
from
,
const
std
::
string
&
to
)
{
for
(
auto
&
item
:
*
desc
->
mutable_inputs
())
{
for
(
auto
&
input
:
*
item
.
mutable_arguments
()
)
{
for
(
auto
&
input
:
item
.
second
)
{
if
(
input
==
from
)
{
input
=
to
;
}
...
...
paddle/fluid/lite/core/mir/variable_place_inference_pass.h
浏览文件 @
65be35af
...
...
@@ -65,7 +65,7 @@ class VariablePlaceInferencePass : public DebugPass {
// check if inputs's place is set, if not set, update them with the
// kernel's declaration.
auto
type
=
inst
.
picked_kernel
().
GetInputDeclType
(
arg_name
);
auto
arg_names
=
inst
.
op_info
()
->
input
_argument
().
at
(
arg_name
);
auto
arg_names
=
inst
.
op_info
()
->
input
s
().
at
(
arg_name
);
for
(
auto
&
arg_name
:
arg_names
)
{
VLOG
(
3
)
<<
"--- var "
<<
arg_name
;
...
...
@@ -82,7 +82,7 @@ class VariablePlaceInferencePass : public DebugPass {
for
(
auto
&
arg_name
:
inst
.
op_info
()
->
output_argnames
())
{
VLOG
(
3
)
<<
"-- output arg_name "
<<
arg_name
;
auto
type
=
inst
.
picked_kernel
().
GetOutputDeclType
(
arg_name
);
auto
arg_names
=
inst
.
op_info
()
->
output
_argument
().
at
(
arg_name
);
auto
arg_names
=
inst
.
op_info
()
->
output
s
().
at
(
arg_name
);
// check if outputs's place is set, if not set, update them with the
// kernel's declaration.
for
(
auto
&
arg_name
:
arg_names
)
{
...
...
paddle/fluid/lite/core/op_lite.cc
浏览文件 @
65be35af
...
...
@@ -68,13 +68,13 @@ bool OpLite::Run() {
return
true
;
}
bool
OpLite
::
Attach
(
const
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
{
bool
OpLite
::
Attach
(
const
cpp
::
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
{
// valid_places_.clear();
CHECK
(
scope
!=
nullptr
);
// CHECK(!op_info_.get());
scope_
=
scope
;
op_info_
.
reset
(
new
OpInfo
);
// Force clean the out-of-date infomation.
op_info_
->
Build
(
opdesc
.
ReadonlyProto
());
op_info_
.
reset
(
new
OpInfo
(
opdesc
));
// Force clean the out-of-date infomation.
return
AttachImpl
(
opdesc
,
scope
);
}
...
...
@@ -92,94 +92,5 @@ Tensor *OpLite::GetMutableTensor(lite::Scope *scope,
return
var
->
GetMutable
<
lite
::
Tensor
>
();
}
bool
OpInfo
::
GetInputArgname
(
const
std
::
string
&
value_name
,
std
::
string
*
out
)
const
{
for
(
auto
&
item
:
input_argument_
)
{
auto
it
=
std
::
find
(
item
.
second
.
begin
(),
item
.
second
.
end
(),
value_name
);
if
(
it
!=
item
.
second
.
end
())
{
*
out
=
item
.
first
;
return
true
;
}
}
return
false
;
}
bool
OpInfo
::
GetOutputArgname
(
const
std
::
string
&
value_name
,
std
::
string
*
out
)
const
{
for
(
auto
&
item
:
output_argument_
)
{
auto
it
=
std
::
find
(
item
.
second
.
begin
(),
item
.
second
.
end
(),
value_name
);
if
(
it
!=
item
.
second
.
end
())
{
*
out
=
item
.
first
;
return
true
;
}
}
return
false
;
}
void
OpInfo
::
ExtractInputsAndOutputs
(
const
framework
::
proto
::
OpDesc
&
opdesc
)
{
for
(
const
auto
&
item
:
opdesc
.
inputs
())
{
for
(
const
auto
&
x
:
item
.
arguments
())
{
input_names_
.
push_back
(
x
);
}
}
for
(
const
auto
&
item
:
opdesc
.
outputs
())
{
for
(
const
auto
&
x
:
item
.
arguments
())
{
output_names_
.
push_back
(
x
);
}
}
}
void
OpInfo
::
CollectInputAndOutputArgnames
(
const
framework
::
proto
::
OpDesc
&
opdesc
)
{
for
(
const
auto
&
item
:
opdesc
.
inputs
())
{
input_argnames_
.
push_back
(
item
.
parameter
());
}
for
(
const
auto
&
item
:
opdesc
.
outputs
())
{
output_argnames_
.
push_back
(
item
.
parameter
());
}
}
void
OpInfo
::
CollectArguments
(
const
framework
::
proto
::
OpDesc
&
opdesc
)
{
for
(
const
auto
&
item
:
opdesc
.
inputs
())
{
for
(
auto
&
x
:
item
.
arguments
())
{
input_argument_
[
item
.
parameter
()].
push_back
(
x
);
}
}
for
(
const
auto
&
item
:
opdesc
.
outputs
())
{
for
(
auto
&
x
:
item
.
arguments
())
{
output_argument_
[
item
.
parameter
()].
push_back
(
x
);
}
}
}
void
OpInfo
::
Build
(
const
framework
::
proto
::
OpDesc
&
desc
)
{
ExtractInputsAndOutputs
(
desc
);
CollectInputAndOutputArgnames
(
desc
);
CollectArguments
(
desc
);
desc_
.
reset
(
new
framework
::
proto
::
OpDesc
(
desc
));
}
const
std
::
map
<
std
::
string
,
std
::
list
<
std
::
string
>>
&
OpInfo
::
input_argument
()
const
{
return
input_argument_
;
}
const
std
::
map
<
std
::
string
,
std
::
list
<
std
::
string
>>
&
OpInfo
::
output_argument
()
const
{
return
output_argument_
;
}
const
std
::
list
<
std
::
string
>
&
OpInfo
::
input_argnames
()
const
{
return
input_argnames_
;
}
const
std
::
list
<
std
::
string
>
&
OpInfo
::
output_argnames
()
const
{
return
output_argnames_
;
}
const
framework
::
proto
::
OpDesc
&
OpInfo
::
desc
()
const
{
CHECK
(
desc_
)
<<
"desc has't set"
;
return
*
desc_
;
}
}
// namespace lite
}
// namespace paddle
paddle/fluid/lite/core/op_lite.h
浏览文件 @
65be35af
...
...
@@ -23,7 +23,7 @@
#include "paddle/fluid/lite/core/context.h"
#include "paddle/fluid/lite/core/kernel.h"
#include "paddle/fluid/lite/core/scope.h"
#include "paddle/fluid/lite/model_parser/c
ompatible_pb
.h"
#include "paddle/fluid/lite/model_parser/c
pp/op_desc
.h"
namespace
paddle
{
namespace
lite
{
...
...
@@ -71,7 +71,7 @@ class OpLite : public Registry {
virtual
bool
Run
();
// Link the external execution environ to internal context.
bool
Attach
(
const
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
);
bool
Attach
(
const
cpp
::
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
);
const
OpInfo
*
op_info
()
const
{
return
op_info_
.
get
();
}
OpInfo
*
mutable_op_info
()
{
return
op_info_
.
get
();
}
...
...
@@ -94,7 +94,7 @@ class OpLite : public Registry {
protected:
// Attach it with the runtime environment.
virtual
bool
AttachImpl
(
const
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
=
0
;
virtual
bool
AttachImpl
(
const
cpp
::
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
=
0
;
// Specify the kernel to run by default. This will specify the value of
// `kernel_place_`.
...
...
@@ -144,40 +144,61 @@ class OpLite : public Registry {
* Operator Information, such as some description. It will be shared by all the
* kernels of the same operator.
*/
class
OpInfo
{
class
OpInfo
:
public
cpp
::
OpDesc
{
public:
// To avoid the bugs from legancy framework::OpDesc, we use the ProtoBuf
// message instead.
void
Build
(
const
framework
::
proto
::
OpDesc
&
desc
);
const
framework
::
proto
::
OpDesc
&
desc
()
const
;
framework
::
proto
::
OpDesc
*
mutable_desc
()
{
return
desc_
.
get
();
}
const
std
::
list
<
std
::
string
>
&
input_names
()
const
{
return
input_names_
;
}
const
std
::
list
<
std
::
string
>
&
output_names
()
const
{
return
output_names_
;
}
const
std
::
map
<
std
::
string
,
std
::
list
<
std
::
string
>>
&
input_argument
()
const
;
const
std
::
map
<
std
::
string
,
std
::
list
<
std
::
string
>>
&
output_argument
()
const
;
bool
GetInputArgname
(
const
std
::
string
&
value_name
,
std
::
string
*
out
)
const
;
bool
GetOutputArgname
(
const
std
::
string
&
value_name
,
std
::
string
*
out
)
const
;
const
std
::
list
<
std
::
string
>
&
input_argnames
()
const
;
const
std
::
list
<
std
::
string
>
&
output_argnames
()
const
;
private:
void
ExtractInputsAndOutputs
(
const
framework
::
proto
::
OpDesc
&
opdesc
);
void
CollectInputAndOutputArgnames
(
const
framework
::
proto
::
OpDesc
&
opdesc
);
void
CollectArguments
(
const
framework
::
proto
::
OpDesc
&
opdesc
);
private:
std
::
list
<
std
::
string
>
input_names_
;
std
::
list
<
std
::
string
>
output_names_
;
std
::
list
<
std
::
string
>
input_argnames_
;
std
::
list
<
std
::
string
>
output_argnames_
;
std
::
map
<
std
::
string
,
std
::
list
<
std
::
string
>>
input_argument_
;
std
::
map
<
std
::
string
,
std
::
list
<
std
::
string
>>
output_argument_
;
// NOTE too heavy.
std
::
unique_ptr
<
framework
::
proto
::
OpDesc
>
desc_
;
OpInfo
(
const
OpInfo
&
)
=
default
;
OpInfo
(
const
cpp
::
OpDesc
&
other
)
:
cpp
::
OpDesc
(
other
)
{}
// Collect all the input variable's name.
std
::
vector
<
std
::
string
>
input_names
()
const
{
std
::
vector
<
std
::
string
>
res
;
for
(
auto
&
param
:
InputArgumentNames
())
{
for
(
auto
&
x
:
Input
(
param
))
{
res
.
push_back
(
x
);
}
}
return
res
;
}
// Collect all the output variable's name.
std
::
vector
<
std
::
string
>
output_names
()
const
{
std
::
vector
<
std
::
string
>
res
;
for
(
auto
&
param
:
OutputArgumentNames
())
{
for
(
auto
&
x
:
Output
(
param
))
{
res
.
push_back
(
x
);
}
}
return
res
;
}
std
::
vector
<
std
::
string
>
input_argnames
()
const
{
return
InputArgumentNames
();
}
std
::
vector
<
std
::
string
>
output_argnames
()
const
{
return
OutputArgumentNames
();
}
bool
GetInputArgname
(
const
std
::
string
&
value_name
,
std
::
string
*
out
)
const
{
for
(
auto
&
item
:
inputs_
)
{
auto
it
=
std
::
find
(
item
.
second
.
begin
(),
item
.
second
.
end
(),
value_name
);
if
(
it
!=
item
.
second
.
end
())
{
*
out
=
item
.
first
;
return
true
;
}
}
return
false
;
}
bool
GetOutputArgname
(
const
std
::
string
&
value_name
,
std
::
string
*
out
)
const
{
for
(
auto
&
item
:
outputs_
)
{
auto
it
=
std
::
find
(
item
.
second
.
begin
(),
item
.
second
.
end
(),
value_name
);
if
(
it
!=
item
.
second
.
end
())
{
*
out
=
item
.
first
;
return
true
;
}
}
return
false
;
}
};
}
// namespace lite
...
...
paddle/fluid/lite/core/program.cc
浏览文件 @
65be35af
...
...
@@ -39,11 +39,11 @@ std::string RuntimeProgram::SerializeProgram(
auto
program_dummy
=
desc
;
program_dummy
.
mutable_blocks
(
0
)
->
clear_ops
();
for
(
auto
&
node
:
instructions_
)
{
auto
desc_dummy
=
node
.
op
()
->
op_info
()
->
desc
()
;
OpDesc
desc
(
desc_dummy
);
desc
.
SetAttr
(
kKernelTypeAttr
,
node
.
kernel
()
->
SerializedKernelType
());
pb
::
OpDesc
pb_desc
;
TransformOpDescCppToPb
(
*
node
.
op
()
->
op_info
(),
&
pb_desc
);
pb_
desc
.
SetAttr
(
kKernelTypeAttr
,
node
.
kernel
()
->
SerializedKernelType
());
// append new opdesc
*
program_dummy
.
mutable_blocks
(
0
)
->
add_ops
()
=
*
desc
.
Proto
();
*
program_dummy
.
mutable_blocks
(
0
)
->
add_ops
()
=
*
pb_
desc
.
Proto
();
}
return
program_dummy
.
SerializeAsString
();
}
...
...
paddle/fluid/lite/core/program.h
浏览文件 @
65be35af
...
...
@@ -22,6 +22,7 @@
#include "paddle/fluid/lite/core/mir/node.h"
#include "paddle/fluid/lite/core/op_lite.h"
#include "paddle/fluid/lite/core/op_registry.h"
#include "paddle/fluid/lite/model_parser/compatible_pb.h"
#ifdef LITE_WITH_PROFILE
#include "paddle/fluid/lite/core/profile/basic_profiler.h"
#endif // LITE_WITH_PROFILE
...
...
@@ -67,7 +68,7 @@ struct Program {
CHECK
(
ops
.
empty
())
<<
"Executor duplicate Build found"
;
// Create operators.
for
(
const
auto
&
proto_op_desc
:
program
.
blocks
(
0
).
ops
())
{
lite
::
OpDesc
op_desc
(
proto_op_desc
);
pb
::
OpDesc
op_desc
(
proto_op_desc
);
auto
op_type
=
op_desc
.
Type
();
// if (op_type == "feed" || op_type == "fetch") continue;
VLOG
(
4
)
<<
"create Op ["
<<
op_type
<<
"]"
;
...
...
@@ -75,7 +76,10 @@ struct Program {
auto
op
=
LiteOpRegistry
::
Global
().
Create
(
op_type
);
CHECK
(
op
)
<<
"no Op found for "
<<
op_type
;
ops
.
emplace_back
(
std
::
move
(
op
));
ops
.
back
()
->
Attach
(
op_desc
,
exec_scope
);
cpp
::
OpDesc
cpp_op_desc
;
TransformOpDescPbToCpp
(
op_desc
,
&
cpp_op_desc
);
ops
.
back
()
->
Attach
(
cpp_op_desc
,
exec_scope
);
}
}
...
...
paddle/fluid/lite/model_parser/CMakeLists.txt
浏览文件 @
65be35af
...
...
@@ -11,11 +11,7 @@ if(NOT LITE_WITH_LIGHT_WEIGHT_FRAMEWORK)
endif
()
if
(
LITE_WITH_LIGHT_WEIGHT_FRAMEWORK
)
cc_library
(
compatible_pb_lite SRCS compatible_pb.cc DEPS op_desc_lite framework_proto_lite var_desc_lite
)
else
()
cc_library
(
compatible_pb_lite SRCS compatible_pb.cc DEPS framework_proto_lite proto_desc
)
endif
(
LITE_WITH_LIGHT_WEIGHT_FRAMEWORK
)
cc_library
(
compatible_pb_lite SRCS compatible_pb.cc DEPS op_desc_lite framework_proto_lite var_desc_lite
)
set
(
model_parser_deps variable_lite scope_lite
${
tensor_lite
}
scope_lite
target_wrapper_host
...
...
@@ -27,4 +23,7 @@ if (LITE_WITH_CUDA)
endif
()
cc_library
(
model_parser_lite SRCS model_parser.cc DEPS
${
model_parser_deps
}
)
cc_test
(
test_op_desc_lite SRCS op_desc_test.cc DEPS cpp_op_desc_lite any_lite op_desc_lite compatible_pb_lite
)
add_subdirectory
(
pb
)
add_subdirectory
(
cpp
)
paddle/fluid/lite/model_parser/compatible_pb.cc
浏览文件 @
65be35af
...
...
@@ -13,3 +13,114 @@
// limitations under the License.
#include "paddle/fluid/lite/model_parser/compatible_pb.h"
#include "compatible_pb.h"
namespace
paddle
{
namespace
lite
{
void
InputsPbToCpp
(
const
pb
::
OpDesc
&
pb_desc
,
cpp
::
OpDesc
*
cpp_desc
)
{
for
(
const
std
::
string
&
param
:
pb_desc
.
InputArgumentNames
())
{
cpp_desc
->
SetInput
(
param
,
pb_desc
.
Input
(
param
));
}
}
void
InputsCppToPb
(
const
cpp
::
OpDesc
&
cpp_desc
,
pb
::
OpDesc
*
pb_desc
)
{
for
(
const
std
::
string
&
param
:
cpp_desc
.
InputArgumentNames
())
{
pb_desc
->
SetInput
(
param
,
cpp_desc
.
Input
(
param
));
}
}
void
OutputsPbToCpp
(
const
pb
::
OpDesc
&
pb_desc
,
cpp
::
OpDesc
*
cpp_desc
)
{
for
(
const
std
::
string
&
param
:
pb_desc
.
OutputArgumentNames
())
{
cpp_desc
->
SetOutput
(
param
,
pb_desc
.
Output
(
param
));
}
}
void
OutputsCppToPb
(
const
cpp
::
OpDesc
&
cpp_desc
,
pb
::
OpDesc
*
pb_desc
)
{
for
(
const
std
::
string
&
param
:
cpp_desc
.
OutputArgumentNames
())
{
pb_desc
->
SetOutput
(
param
,
cpp_desc
.
Output
(
param
));
}
}
void
AttrsPbToCpp
(
const
pb
::
OpDesc
&
pb_desc
,
cpp
::
OpDesc
*
cpp_desc
)
{
using
AttrType
=
OpDescAPI
::
AttrType
;
auto
set_attr
=
[
&
](
const
std
::
string
&
name
,
AttrType
type
)
{
switch
(
type
)
{
case
AttrType
::
INT
:
cpp_desc
->
SetAttr
<
int32_t
>
(
name
,
pb_desc
.
GetAttr
<
int32_t
>
(
name
));
break
;
case
AttrType
::
FLOAT
:
cpp_desc
->
SetAttr
<
float
>
(
name
,
pb_desc
.
GetAttr
<
float
>
(
name
));
break
;
case
AttrType
::
STRING
:
cpp_desc
->
SetAttr
<
std
::
string
>
(
name
,
pb_desc
.
GetAttr
<
std
::
string
>
(
name
));
break
;
case
AttrType
::
INTS
:
cpp_desc
->
SetAttr
<
std
::
vector
<
int
>>
(
name
,
pb_desc
.
GetAttr
<
std
::
vector
<
int
>>
(
name
));
break
;
case
AttrType
::
FLOATS
:
cpp_desc
->
SetAttr
<
std
::
vector
<
float
>>
(
name
,
pb_desc
.
GetAttr
<
std
::
vector
<
float
>>
(
name
));
break
;
case
AttrType
::
BOOLEAN
:
cpp_desc
->
SetAttr
<
bool
>
(
name
,
pb_desc
.
GetAttr
<
bool
>
(
name
));
break
;
case
AttrType
::
STRINGS
:
cpp_desc
->
SetAttr
<
std
::
vector
<
std
::
string
>>
(
name
,
pb_desc
.
GetAttr
<
std
::
vector
<
std
::
string
>>
(
name
));
break
;
default:
LOG
(
FATAL
)
<<
"Unsupported attr type found "
<<
static_cast
<
int
>
(
type
);
}
};
for
(
const
auto
&
attr_name
:
pb_desc
.
AttrNames
())
{
auto
type
=
pb_desc
.
GetAttrType
(
attr_name
);
set_attr
(
attr_name
,
type
);
}
}
void
AttrsCppToPb
(
const
cpp
::
OpDesc
&
cpp_desc
,
pb
::
OpDesc
*
pb_desc
)
{
using
AttrType
=
OpDescAPI
::
AttrType
;
auto
set_attr
=
[
&
](
const
std
::
string
&
name
,
AttrType
type
)
{
switch
(
type
)
{
#define IMPL_ONE(type__, T) \
case AttrType::type__: \
pb_desc->SetAttr<T>(name, cpp_desc.GetAttr<T>(name)); \
break;
IMPL_ONE
(
INT
,
int32_t
);
IMPL_ONE
(
FLOAT
,
float
);
IMPL_ONE
(
STRING
,
std
::
string
);
IMPL_ONE
(
STRINGS
,
std
::
vector
<
std
::
string
>
);
IMPL_ONE
(
FLOATS
,
std
::
vector
<
float
>
);
IMPL_ONE
(
INTS
,
std
::
vector
<
int
>
);
IMPL_ONE
(
BOOLEAN
,
bool
);
default:
LOG
(
FATAL
)
<<
"Unsupported attr type found: "
<<
static_cast
<
int
>
(
type
);
}
};
#undef IMPL_ONE
for
(
const
auto
&
attr_name
:
cpp_desc
.
AttrNames
())
{
auto
type
=
cpp_desc
.
GetAttrType
(
attr_name
);
set_attr
(
attr_name
,
type
);
}
}
void
TransformOpDescPbToCpp
(
const
pb
::
OpDesc
&
pb_desc
,
cpp
::
OpDesc
*
cpp_desc
)
{
cpp_desc
->
SetType
(
pb_desc
.
Type
());
InputsPbToCpp
(
pb_desc
,
cpp_desc
);
OutputsPbToCpp
(
pb_desc
,
cpp_desc
);
AttrsPbToCpp
(
pb_desc
,
cpp_desc
);
}
void
TransformOpDescCppToPb
(
const
cpp
::
OpDesc
&
cpp_desc
,
pb
::
OpDesc
*
pb_desc
)
{
pb_desc
->
SetType
(
cpp_desc
.
Type
());
InputsCppToPb
(
cpp_desc
,
pb_desc
);
OutputsCppToPb
(
cpp_desc
,
pb_desc
);
AttrsCppToPb
(
cpp_desc
,
pb_desc
);
}
}
// namespace lite
}
// namespace paddle
paddle/fluid/lite/model_parser/compatible_pb.h
浏览文件 @
65be35af
...
...
@@ -20,39 +20,28 @@
* lite::pb::XXDesc.
*/
#ifdef LITE_WITH_LIGHT_WEIGHT_FRAMEWORK
#include "paddle/fluid/lite/core/framework.pb.h"
#include "paddle/fluid/lite/model_parser/cpp/op_desc.h"
#include "paddle/fluid/lite/model_parser/pb/op_desc.h"
#include "paddle/fluid/lite/model_parser/pb/var_desc.h"
#else
#include "paddle/fluid/framework/op_desc.h"
#include "paddle/fluid/framework/var_desc.h"
#endif // LITE_WITH_LIGHT_WEIGHT_FRAMEWORK
namespace
paddle
{
namespace
lite
{
#ifdef LITE_WITH_LIGHT_WEIGHT_FRAMEWORK
using
Attribute
=
lite
::
pb
::
Attribute
;
using
OpDesc
=
lite
::
pb
::
OpDesc
;
using
VarDesc
=
lite
::
pb
::
VarDesc
;
#else // LITE_WITH_LIGHT_WEIGHT_FRAMEWORK
using
Attribute
=
framework
::
Attribute
;
using
OpDesc
=
framework
::
OpDesc
;
using
VarDesc
=
framework
::
VarDesc
;
#endif // LITE_WITH_LIGHT_WEIGHT_FRAMEWORK
#ifdef LITE_WITH_LIGHT_WEIGHT_FRAMEWORK
template
<
typename
T
>
T
GetAttr
(
const
Attribute
&
x
)
{
return
x
.
get
<
T
>
();
}
#else
template
<
typename
T
>
T
GetAttr
(
const
Attribute
&
x
)
{
return
boost
::
get
<
T
>
(
x
);
}
#endif // LITE_WITH_LIGHT_WEIGHT_FRAMEWORK
/// Transform an OpDesc from pb to cpp format.
void
TransformOpDescPbToCpp
(
const
pb
::
OpDesc
&
pb_desc
,
cpp
::
OpDesc
*
cpp_desc
);
/// Transform an OpDesc from cpp to pb format.
void
TransformOpDescCppToPb
(
const
cpp
::
OpDesc
&
cpp_desc
,
pb
::
OpDesc
*
pb_desc
);
}
// namespace lite
}
// namespace paddle
paddle/fluid/lite/model_parser/cpp/CMakeLists.txt
0 → 100644
浏览文件 @
65be35af
cc_library
(
cpp_op_desc_lite SRCS op_desc.cc DEPS any_lite
)
paddle/fluid/lite/model_parser/cpp/op_desc.cc
0 → 100644
浏览文件 @
65be35af
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/lite/model_parser/cpp/op_desc.h"
#include <set>
namespace
paddle
{
namespace
lite
{
namespace
cpp
{
#define SET_ATTR_IMPL(T, repr__) \
template <> \
void OpDesc::SetAttr<T>(const std::string& name, const T& v) { \
attr_types_[name] = AttrType::repr__; \
attrs_[name].set<T>(v); \
}
SET_ATTR_IMPL
(
int32_t
,
INT
);
SET_ATTR_IMPL
(
float
,
FLOAT
);
SET_ATTR_IMPL
(
std
::
string
,
STRING
);
SET_ATTR_IMPL
(
bool
,
BOOLEAN
);
SET_ATTR_IMPL
(
std
::
vector
<
int
>
,
INTS
);
SET_ATTR_IMPL
(
std
::
vector
<
float
>
,
FLOATS
);
SET_ATTR_IMPL
(
std
::
vector
<
std
::
string
>
,
STRINGS
);
std
::
pair
<
OpDesc
::
attrs_t
::
const_iterator
,
OpDesc
::
attr_types_t
::
const_iterator
>
FindAttr
(
const
cpp
::
OpDesc
&
desc
,
const
std
::
string
&
name
)
{
auto
it
=
desc
.
attrs
().
find
(
name
);
CHECK
(
it
!=
desc
.
attrs
().
end
())
<<
"No attributes called "
<<
name
<<
" found"
;
auto
attr_it
=
desc
.
attr_types
().
find
(
name
);
CHECK
(
attr_it
!=
desc
.
attr_types
().
end
());
return
std
::
make_pair
(
it
,
attr_it
);
}
#define GET_IMPL_ONE(T, repr__) \
template <> \
T OpDesc::GetAttr<T>(const std::string& name) const { \
auto pair = FindAttr(*this, name); \
CHECK(pair.second->second == AttrType::repr__); \
return pair.first->second.get<T>(); \
}
GET_IMPL_ONE
(
int32_t
,
INT
);
GET_IMPL_ONE
(
float
,
FLOAT
);
GET_IMPL_ONE
(
std
::
string
,
STRING
);
GET_IMPL_ONE
(
bool
,
BOOLEAN
);
GET_IMPL_ONE
(
std
::
vector
<
int64_t
>
,
LONGS
);
GET_IMPL_ONE
(
std
::
vector
<
float
>
,
FLOATS
);
GET_IMPL_ONE
(
std
::
vector
<
int
>
,
INTS
);
GET_IMPL_ONE
(
std
::
vector
<
std
::
string
>
,
STRINGS
);
}
// namespace cpp
}
// namespace lite
}
// namespace paddle
paddle/fluid/lite/model_parser/cpp/op_desc.h
0 → 100644
浏览文件 @
65be35af
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <map>
#include <string>
#include <vector>
#include "paddle/fluid/lite/model_parser/desc_apis.h"
#include "paddle/fluid/lite/utils/any.h"
#include "paddle/fluid/lite/utils/varient.h"
namespace
paddle
{
namespace
lite
{
namespace
cpp
{
/*
* The cpp::OpDesc is the internal representation for Op. All the internal
* imprementation should use it, not the pb::OpDesc.
*/
class
OpDesc
:
public
OpDescAPI
{
public:
using
attrs_t
=
std
::
map
<
std
::
string
,
Any
>
;
using
attr_types_t
=
std
::
map
<
std
::
string
,
AttrType
>
;
protected:
std
::
string
type_
;
std
::
map
<
std
::
string
,
std
::
vector
<
std
::
string
>>
inputs_
;
std
::
map
<
std
::
string
,
std
::
vector
<
std
::
string
>>
outputs_
;
std
::
map
<
std
::
string
,
Any
>
attrs_
;
std
::
map
<
std
::
string
,
AttrType
>
attr_types_
;
public:
OpDesc
()
=
default
;
std
::
string
Type
()
const
override
{
return
type_
;
}
void
SetType
(
const
std
::
string
&
x
)
override
{
type_
=
x
;
}
const
std
::
map
<
std
::
string
,
std
::
vector
<
std
::
string
>>&
inputs
()
const
{
return
inputs_
;
}
const
std
::
map
<
std
::
string
,
std
::
vector
<
std
::
string
>>&
outputs
()
const
{
return
outputs_
;
}
std
::
map
<
std
::
string
,
std
::
vector
<
std
::
string
>>*
mutable_inputs
()
{
return
&
inputs_
;
}
std
::
map
<
std
::
string
,
std
::
vector
<
std
::
string
>>*
mutable_outputs
()
{
return
&
outputs_
;
}
std
::
vector
<
std
::
string
>
Input
(
const
std
::
string
&
param
)
const
override
{
auto
it
=
inputs_
.
find
(
param
);
CHECK
(
it
!=
inputs_
.
end
());
return
it
->
second
;
}
std
::
vector
<
std
::
string
>
InputArgumentNames
()
const
override
{
std
::
vector
<
std
::
string
>
res
;
for
(
const
auto
&
x
:
inputs_
)
res
.
push_back
(
x
.
first
);
return
res
;
}
std
::
vector
<
std
::
string
>
OutputArgumentNames
()
const
override
{
std
::
vector
<
std
::
string
>
res
;
for
(
const
auto
&
x
:
outputs_
)
res
.
push_back
(
x
.
first
);
return
res
;
}
std
::
vector
<
std
::
string
>
Output
(
const
std
::
string
&
param
)
const
override
{
auto
it
=
outputs_
.
find
(
param
);
CHECK
(
it
!=
outputs_
.
end
());
return
it
->
second
;
}
void
SetInput
(
const
std
::
string
&
param
,
const
std
::
vector
<
std
::
string
>&
args
)
override
{
inputs_
[
param
]
=
args
;
}
void
SetOutput
(
const
std
::
string
&
param
,
const
std
::
vector
<
std
::
string
>&
args
)
override
{
outputs_
[
param
]
=
args
;
}
bool
HasAttr
(
const
std
::
string
&
name
)
const
override
{
return
attrs_
.
count
(
name
);
}
AttrType
GetAttrType
(
const
std
::
string
&
name
)
const
override
{
auto
it
=
attr_types_
.
find
(
name
);
CHECK
(
it
!=
attr_types_
.
end
());
return
it
->
second
;
}
std
::
vector
<
std
::
string
>
AttrNames
()
const
override
{
std
::
vector
<
std
::
string
>
res
;
for
(
const
auto
&
x
:
attrs_
)
{
res
.
push_back
(
x
.
first
);
}
return
res
;
}
template
<
typename
T
>
void
SetAttr
(
const
std
::
string
&
name
,
const
T
&
v
);
template
<
typename
T
>
T
GetAttr
(
const
std
::
string
&
name
)
const
;
const
std
::
map
<
std
::
string
,
Any
>&
attrs
()
const
{
return
attrs_
;
}
const
std
::
map
<
std
::
string
,
AttrType
>&
attr_types
()
const
{
return
attr_types_
;
}
};
}
// namespace cpp
}
// namespace lite
}
// namespace paddle
paddle/fluid/lite/model_parser/desc_apis.h
0 → 100644
浏览文件 @
65be35af
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <map>
#include <string>
#include <vector>
namespace
paddle
{
namespace
lite
{
/*
* Compatible interfaces for all the different kinds of opdesc. All the OpDesc
* classes should implement this.
* NOTE Some interfaces are weried, we remain them unchanged to keep compatible
* with framework::OpDesc in Fluid framework.
*/
class
OpDescAPI
{
public:
// The AttrType is used to make the proto::AttrType portable.
enum
class
AttrType
{
INT
=
0
,
FLOAT
=
1
,
STRING
=
2
,
INTS
=
3
,
FLOATS
=
4
,
STRINGS
=
5
,
BOOLEAN
=
6
,
BOOLEANS
=
7
,
BLOCK
=
8
,
LONG
=
9
,
BLOCKS
=
10
,
LONGS
=
11
,
UNK
,
};
virtual
~
OpDescAPI
()
=
default
;
/// Get operator's type.
virtual
std
::
string
Type
()
const
=
0
;
/// Set operator's type.
virtual
void
SetType
(
const
std
::
string
&
type
)
=
0
;
/// Get arguments given the parameter.
virtual
std
::
vector
<
std
::
string
>
Input
(
const
std
::
string
&
param
)
const
=
0
;
/// Get parameters.
virtual
std
::
vector
<
std
::
string
>
InputArgumentNames
()
const
=
0
;
/// Get arguments given the parameter.
virtual
std
::
vector
<
std
::
string
>
Output
(
const
std
::
string
&
param
)
const
=
0
;
/// Get parameters.
virtual
std
::
vector
<
std
::
string
>
OutputArgumentNames
()
const
=
0
;
/// Set a input given the parameter and arguments.
virtual
void
SetInput
(
const
std
::
string
&
param
,
const
std
::
vector
<
std
::
string
>&
args
)
=
0
;
virtual
void
SetOutput
(
const
std
::
string
&
param
,
const
std
::
vector
<
std
::
string
>&
args
)
=
0
;
/// Tell whether this desc has an attribute.
virtual
bool
HasAttr
(
const
std
::
string
&
name
)
const
=
0
;
/// Get the type of an attribute.
virtual
AttrType
GetAttrType
(
const
std
::
string
&
name
)
const
=
0
;
virtual
std
::
vector
<
std
::
string
>
AttrNames
()
const
=
0
;
/// Set an attribute.
template
<
typename
T
>
void
SetAttr
(
const
std
::
string
&
name
,
const
T
&
v
);
/// Get an attribute.
template
<
typename
T
>
T
GetAttr
(
const
std
::
string
&
name
)
const
;
};
}
// namespace lite
}
// namespace paddle
paddle/fluid/lite/model_parser/op_desc_test.cc
0 → 100644
浏览文件 @
65be35af
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/lite/model_parser/cpp/op_desc.h"
#include <gtest/gtest.h>
#include "paddle/fluid/lite/model_parser/compatible_pb.h"
#include "paddle/fluid/lite/model_parser/pb/op_desc.h"
namespace
paddle
{
namespace
lite
{
template
<
typename
OpDesc
>
void
TestX
()
{
OpDesc
desc
;
desc
.
SetInput
(
"X"
,
{
"a"
,
"b"
});
auto
X
=
desc
.
Input
(
"X"
);
ASSERT_EQ
(
X
.
size
(),
2UL
);
ASSERT_EQ
(
X
[
0
],
"a"
);
ASSERT_EQ
(
X
[
1
],
"b"
);
desc
.
SetOutput
(
"Y"
,
{
"c"
,
"d"
});
auto
Y
=
desc
.
Output
(
"Y"
);
ASSERT_EQ
(
Y
.
size
(),
2UL
);
ASSERT_EQ
(
Y
[
0
],
"c"
);
ASSERT_EQ
(
Y
[
1
],
"d"
);
desc
.
template
SetAttr
<
int32_t
>(
"aint"
,
100
);
ASSERT_TRUE
(
desc
.
HasAttr
(
"aint"
));
ASSERT_FALSE
(
desc
.
HasAttr
(
"afloat"
));
ASSERT_EQ
(
desc
.
template
GetAttr
<
int32_t
>(
"aint"
),
100
);
}
TEST
(
OpDesc
,
Basic
)
{
TestX
<
pb
::
OpDesc
>
();
TestX
<
cpp
::
OpDesc
>
();
}
TEST
(
OpDesc
,
CppToPb
)
{
cpp
::
OpDesc
desc
;
desc
.
SetInput
(
"X"
,
{
"a"
,
"b"
});
desc
.
SetOutput
(
"Y"
,
{
"c"
,
"d"
});
desc
.
template
SetAttr
<
int32_t
>(
"aint"
,
100
);
pb
::
OpDesc
pb_desc
;
TransformOpDescCppToPb
(
desc
,
&
pb_desc
);
{
auto
&
desc
=
pb_desc
;
auto
X
=
desc
.
Input
(
"X"
);
ASSERT_EQ
(
X
.
size
(),
2UL
);
ASSERT_EQ
(
X
[
0
],
"a"
);
ASSERT_EQ
(
X
[
1
],
"b"
);
auto
Y
=
desc
.
Output
(
"Y"
);
ASSERT_EQ
(
Y
.
size
(),
2UL
);
ASSERT_EQ
(
Y
[
0
],
"c"
);
ASSERT_EQ
(
Y
[
1
],
"d"
);
ASSERT_TRUE
(
desc
.
HasAttr
(
"aint"
));
ASSERT_FALSE
(
desc
.
HasAttr
(
"afloat"
));
ASSERT_EQ
(
desc
.
template
GetAttr
<
int32_t
>(
"aint"
),
100
);
}
}
TEST
(
OpDesc
,
PbToCpp
)
{
pb
::
OpDesc
desc
;
desc
.
SetInput
(
"X"
,
{
"a"
,
"b"
});
desc
.
SetOutput
(
"Y"
,
{
"c"
,
"d"
});
desc
.
template
SetAttr
<
int32_t
>(
"aint"
,
100
);
cpp
::
OpDesc
cpp_desc
;
TransformOpDescPbToCpp
(
desc
,
&
cpp_desc
);
{
auto
&
desc
=
cpp_desc
;
auto
X
=
desc
.
Input
(
"X"
);
ASSERT_EQ
(
X
.
size
(),
2UL
);
ASSERT_EQ
(
X
[
0
],
"a"
);
ASSERT_EQ
(
X
[
1
],
"b"
);
auto
Y
=
desc
.
Output
(
"Y"
);
ASSERT_EQ
(
Y
.
size
(),
2UL
);
ASSERT_EQ
(
Y
[
0
],
"c"
);
ASSERT_EQ
(
Y
[
1
],
"d"
);
ASSERT_TRUE
(
desc
.
HasAttr
(
"aint"
));
ASSERT_FALSE
(
desc
.
HasAttr
(
"afloat"
));
ASSERT_EQ
(
desc
.
template
GetAttr
<
int32_t
>(
"aint"
),
100
);
}
}
}
// namespace lite
}
// namespace paddle
paddle/fluid/lite/model_parser/pb/op_desc.cc
浏览文件 @
65be35af
...
...
@@ -18,10 +18,9 @@ namespace paddle {
namespace
lite
{
namespace
pb
{
template
<
>
void
OpDesc
::
SetAttr
<
std
::
string
>
(
const
std
::
string
&
name
,
const
std
::
string
&
v
)
{
auto
&
xs
=
*
desc_
.
mutable_attrs
();
google
::
protobuf
::
internal
::
RepeatedPtrIterator
<
framework
::
proto
::
OpDesc_Attr
>
FindAttr
(
framework
::
proto
::
OpDesc
*
desc
,
const
std
::
string
&
name
)
{
auto
&
xs
=
*
desc
->
mutable_attrs
();
auto
it
=
std
::
find_if
(
xs
.
begin
(),
xs
.
end
(),
[
&
](
const
framework
::
proto
::
OpDesc_Attr
&
x
)
{
return
x
.
name
()
==
name
;
});
...
...
@@ -33,33 +32,95 @@ void OpDesc::SetAttr<std::string>(const std::string &name,
return
x
.
name
()
==
name
;
});
}
return
it
;
}
#define SET_IMPL_ONE(T, ty__, pb_f__) \
template <> \
void OpDesc::SetAttr<T>(const std::string &name, const T &v) { \
auto it = FindAttr(&desc_, name); \
it->set_type(framework::proto::ty__); \
it->set_##pb_f__(v); \
}
SET_IMPL_ONE
(
int
,
INT
,
i
);
SET_IMPL_ONE
(
float
,
FLOAT
,
f
);
SET_IMPL_ONE
(
bool
,
FLOAT
,
f
);
template
<
>
void
OpDesc
::
SetAttr
<
std
::
vector
<
int
>>
(
const
std
::
string
&
name
,
const
std
::
vector
<
int
>
&
v
)
{
auto
it
=
FindAttr
(
&
desc_
,
name
);
it
->
set_type
(
framework
::
proto
::
INTS
);
it
->
clear_ints
();
for
(
auto
&
i
:
v
)
{
it
->
add_ints
(
i
);
}
}
template
<
>
void
OpDesc
::
SetAttr
<
std
::
string
>
(
const
std
::
string
&
name
,
const
std
::
string
&
v
)
{
auto
it
=
FindAttr
(
&
desc_
,
name
);
it
->
set_type
(
framework
::
proto
::
STRING
);
it
->
set_s
(
v
.
c_str
());
}
template
<
>
void
OpDesc
::
SetAttr
<
std
::
vector
<
int
>>
(
const
std
::
string
&
name
,
const
std
::
vector
<
int
>
&
v
)
{
auto
&
xs
=
*
desc_
.
mutable_attrs
();
void
OpDesc
::
SetAttr
<
std
::
vector
<
float
>>
(
const
std
::
string
&
name
,
const
std
::
vector
<
float
>
&
v
)
{
auto
it
=
FindAttr
(
&
desc_
,
name
);
it
->
set_type
(
framework
::
proto
::
FLOATS
);
it
->
clear_floats
();
for
(
auto
&
i
:
v
)
{
it
->
add_floats
(
i
);
}
}
template
<
>
void
OpDesc
::
SetAttr
<
std
::
vector
<
std
::
string
>>
(
const
std
::
string
&
name
,
const
std
::
vector
<
std
::
string
>
&
v
)
{
auto
it
=
FindAttr
(
&
desc_
,
name
);
it
->
set_type
(
framework
::
proto
::
STRINGS
);
it
->
clear_strings
();
for
(
auto
&
i
:
v
)
{
it
->
add_strings
(
i
);
}
}
google
::
protobuf
::
internal
::
RepeatedPtrIterator
<
const
framework
::
proto
::
OpDesc_Attr
>
GetFindAttr
(
const
framework
::
proto
::
OpDesc
&
desc
,
const
std
::
string
&
name
)
{
auto
&
xs
=
desc
.
attrs
();
auto
it
=
std
::
find_if
(
xs
.
begin
(),
xs
.
end
(),
[
&
](
const
framework
::
proto
::
OpDesc_Attr
&
x
)
{
return
x
.
name
()
==
name
;
});
if
(
it
==
xs
.
end
())
{
auto
*
attr
=
xs
.
Add
();
attr
->
set_name
(
name
);
it
=
std
::
find_if
(
xs
.
begin
(),
xs
.
end
(),
[
&
](
const
framework
::
proto
::
OpDesc_Attr
&
x
)
{
return
x
.
name
()
==
name
;
});
return
it
;
}
#define GET_ATTR_IMPL(T, pb_f__) \
template <> \
T OpDesc::GetAttr<T>(const std::string &name) const { \
auto it = GetFindAttr(desc_, name); \
return it->pb_f__(); \
}
it
->
set_type
(
framework
::
proto
::
INTS
);
it
->
clear_ints
();
for
(
auto
&
i
:
v
)
{
it
->
add_ints
(
i
);
#define GET_ATTRS_IMPL(T, pb_f__) \
template <> \
T OpDesc::GetAttr<T>(const std::string &name) const { \
auto it = GetFindAttr(desc_, name); \
T res; \
for (const auto &v : it->pb_f__()) { \
res.push_back(v); \
} \
return res; \
}
}
GET_ATTR_IMPL
(
int32_t
,
i
);
GET_ATTR_IMPL
(
float
,
f
);
GET_ATTR_IMPL
(
bool
,
b
);
GET_ATTRS_IMPL
(
std
::
vector
<
int
>
,
ints
);
GET_ATTRS_IMPL
(
std
::
vector
<
float
>
,
floats
);
GET_ATTRS_IMPL
(
std
::
vector
<
std
::
string
>
,
strings
);
GET_ATTR_IMPL
(
std
::
string
,
s
);
}
// namespace pb
}
// namespace lite
...
...
paddle/fluid/lite/model_parser/pb/op_desc.h
浏览文件 @
65be35af
...
...
@@ -27,6 +27,7 @@
#include <unordered_map>
#include <vector>
#include "paddle/fluid/lite/core/framework.pb.h"
#include "paddle/fluid/lite/model_parser/desc_apis.h"
#include "paddle/fluid/lite/utils/all.h"
namespace
paddle
{
...
...
@@ -43,7 +44,7 @@ using VariableNameMap = std::map<std::string, std::vector<std::string>>;
* except the desc_, to avoid the inconsistent state, which is normal in the
* original interface and results in bugs.
*/
class
OpDesc
{
class
OpDesc
:
public
OpDescAPI
{
public:
OpDesc
()
{}
...
...
@@ -54,38 +55,38 @@ class OpDesc {
framework
::
proto
::
OpDesc
*
Proto
()
{
return
&
desc_
;
}
const
framework
::
proto
::
OpDesc
&
ReadonlyProto
()
const
{
return
desc_
;
}
std
::
string
Type
()
const
{
return
desc_
.
type
();
}
std
::
string
Type
()
const
override
{
return
desc_
.
type
();
}
void
SetType
(
const
std
::
string
&
type
)
{
desc_
.
set_type
(
type
);
}
void
SetType
(
const
std
::
string
&
type
)
override
{
desc_
.
set_type
(
type
);
}
// Get the arguments of parameter called `param`
std
::
vector
<
std
::
string
>
Input
(
const
std
::
string
&
param
)
const
{
std
::
vector
<
std
::
string
>
Input
(
const
std
::
string
&
param
)
const
override
{
return
GetArguments
(
desc_
.
inputs
(),
param
);
}
std
::
vector
<
std
::
string
>
InputArgumentNames
()
const
{
std
::
vector
<
std
::
string
>
InputArgumentNames
()
const
override
{
return
GetArgumentNames
(
desc_
.
inputs
());
}
void
SetInput
(
const
std
::
string
&
param
,
const
std
::
vector
<
std
::
string
>
&
args
)
{
const
std
::
vector
<
std
::
string
>
&
args
)
override
{
SetArgument
(
desc_
.
mutable_inputs
(),
param
,
args
);
}
std
::
vector
<
std
::
string
>
Output
(
const
std
::
string
&
param
)
const
{
std
::
vector
<
std
::
string
>
Output
(
const
std
::
string
&
param
)
const
override
{
return
GetArguments
(
desc_
.
outputs
(),
param
);
}
std
::
vector
<
std
::
string
>
OutputArgumentNames
()
const
{
std
::
vector
<
std
::
string
>
OutputArgumentNames
()
const
override
{
return
GetArgumentNames
(
desc_
.
outputs
());
}
void
SetOutput
(
const
std
::
string
&
param
,
const
std
::
vector
<
std
::
string
>
&
args
)
{
const
std
::
vector
<
std
::
string
>
&
args
)
override
{
SetArgument
(
desc_
.
mutable_outputs
(),
param
,
args
);
}
bool
HasAttr
(
const
std
::
string
&
name
)
const
{
bool
HasAttr
(
const
std
::
string
&
name
)
const
override
{
const
auto
&
xs
=
desc_
.
attrs
();
auto
it
=
std
::
find_if
(
xs
.
begin
(),
xs
.
end
(),
[
&
](
const
framework
::
proto
::
OpDesc_Attr
&
x
)
{
...
...
@@ -94,17 +95,38 @@ class OpDesc {
return
it
!=
xs
.
end
();
}
framework
::
proto
::
AttrType
GetAttrType
(
const
std
::
string
&
name
)
const
{
AttrType
GetAttrType
(
const
std
::
string
&
name
)
const
override
{
const
auto
&
xs
=
desc_
.
attrs
();
auto
it
=
std
::
find_if
(
xs
.
begin
(),
xs
.
end
(),
[
&
](
const
framework
::
proto
::
OpDesc_Attr
&
x
)
{
return
x
.
name
()
==
name
;
});
CHECK
(
it
!=
xs
.
end
());
return
it
->
type
();
#define DEF_ONE(type__) \
case framework::proto::AttrType::type__: \
return AttrType::type__;
switch
(
it
->
type
())
{
DEF_ONE
(
INT
);
DEF_ONE
(
FLOAT
);
DEF_ONE
(
STRING
);
DEF_ONE
(
INTS
);
DEF_ONE
(
FLOATS
);
DEF_ONE
(
STRINGS
);
DEF_ONE
(
BOOLEAN
);
DEF_ONE
(
BOOLEANS
);
DEF_ONE
(
BLOCK
);
DEF_ONE
(
LONG
);
DEF_ONE
(
BLOCKS
);
DEF_ONE
(
LONGS
);
default:
LOG
(
ERROR
)
<<
"Unknown attribute type"
;
return
AttrType
::
UNK
;
}
#undef DEF_ONE
}
std
::
vector
<
std
::
string
>
AttrNames
()
const
{
std
::
vector
<
std
::
string
>
AttrNames
()
const
override
{
std
::
vector
<
std
::
string
>
res
;
const
auto
&
xs
=
desc_
.
attrs
();
std
::
transform
(
...
...
@@ -114,72 +136,10 @@ class OpDesc {
}
template
<
typename
T
>
void
SetAttr
(
const
std
::
string
&
name
,
const
T
&
v
)
{
auto
&
xs
=
*
desc_
.
mutable_attrs
();
auto
it
=
std
::
find_if
(
xs
.
begin
(),
xs
.
end
(),
[
&
](
const
framework
::
proto
::
OpDesc_Attr
&
x
)
{
return
x
.
name
()
==
name
;
});
if
(
it
==
xs
.
end
())
{
auto
*
attr
=
xs
.
Add
();
attr
->
set_name
(
name
);
it
=
std
::
find_if
(
xs
.
begin
(),
xs
.
end
(),
[
&
](
const
framework
::
proto
::
OpDesc_Attr
&
x
)
{
return
x
.
name
()
==
name
;
});
}
void
SetAttr
(
const
std
::
string
&
name
,
const
T
&
v
);
size_t
hash
=
typeid
(
T
).
hash_code
();
if
(
hash
==
typeid
(
int
).
hash_code
())
{
// NOLINT
it
->
set_type
(
framework
::
proto
::
INT
);
it
->
set_i
(
v
);
}
else
if
(
hash
==
typeid
(
float
).
hash_code
())
{
// NOLINT
it
->
set_type
(
framework
::
proto
::
FLOAT
);
it
->
set_f
(
v
);
}
else
if
(
hash
==
typeid
(
bool
).
hash_code
())
{
// NOLINT
it
->
set_type
(
framework
::
proto
::
BOOLEAN
);
it
->
set_b
(
v
);
}
else
{
LOG
(
FATAL
)
<<
"unsupport attr type"
;
}
}
Attribute
GetAttr
(
const
std
::
string
&
name
)
const
{
auto
&
xs
=
desc_
.
attrs
();
auto
it
=
std
::
find_if
(
xs
.
begin
(),
xs
.
end
(),
[
&
](
const
framework
::
proto
::
OpDesc_Attr
&
x
)
{
return
x
.
name
()
==
name
;
});
Attribute
res
;
CHECK
(
it
!=
xs
.
end
());
switch
(
it
->
type
())
{
case
framework
::
proto
::
INT
:
res
.
set
<
int
>
(
it
->
i
());
break
;
case
framework
::
proto
::
FLOAT
:
res
.
set
<
float
>
(
it
->
f
());
break
;
case
framework
::
proto
::
STRING
:
res
.
set
<
std
::
string
>
(
it
->
s
());
break
;
case
framework
::
proto
::
BOOLEAN
:
res
.
set
<
bool
>
(
it
->
b
());
break
;
case
framework
::
proto
::
INTS
:
{
std
::
vector
<
int
>
values
;
const
auto
&
ys
=
it
->
ints
();
std
::
transform
(
ys
.
begin
(),
ys
.
end
(),
std
::
back_inserter
(
values
),
[](
const
int
&
x
)
{
return
x
;
});
res
.
set
<
std
::
vector
<
int
>>
(
values
);
}
break
;
default:
LOG
(
FATAL
)
<<
"unsupported attr type"
;
}
return
res
;
}
template
<
typename
T
>
T
GetAttr
(
const
std
::
string
&
name
)
const
;
private:
std
::
vector
<
std
::
string
>
GetArguments
(
...
...
paddle/fluid/lite/operators/activation_ops.cc
浏览文件 @
65be35af
...
...
@@ -33,7 +33,7 @@ class ActivationOp : public OpLite {
return
true
;
}
bool
AttachImpl
(
const
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
override
{
bool
AttachImpl
(
const
cpp
::
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
override
{
auto
X_name
=
opdesc
.
Input
(
"X"
).
front
();
auto
Out_name
=
opdesc
.
Output
(
"Out"
).
front
();
...
...
@@ -66,7 +66,7 @@ class ActivationGradOp : public OpLite {
return
true
;
}
bool
AttachImpl
(
const
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
override
{
bool
AttachImpl
(
const
cpp
::
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
override
{
auto
Out_grad_name
=
opdesc
.
Input
(
framework
::
GradVarName
(
"Out"
)).
front
();
auto
X_grad_name
=
opdesc
.
Output
(
framework
::
GradVarName
(
"X"
)).
front
();
...
...
paddle/fluid/lite/operators/concat_op.cc
浏览文件 @
65be35af
...
...
@@ -54,7 +54,7 @@ bool ConcatOpLite::InferShape() const {
}
// TODO(Superjomn) replace framework::OpDesc with a lite one.
bool
ConcatOpLite
::
AttachImpl
(
const
OpDesc
&
op_desc
,
lite
::
Scope
*
scope
)
{
bool
ConcatOpLite
::
AttachImpl
(
const
cpp
::
OpDesc
&
op_desc
,
lite
::
Scope
*
scope
)
{
auto
inputs
=
op_desc
.
Input
(
"X"
);
auto
out
=
op_desc
.
Output
(
"Out"
).
front
();
...
...
@@ -63,7 +63,7 @@ bool ConcatOpLite::AttachImpl(const OpDesc &op_desc, lite::Scope *scope) {
}
CHECK
(
scope
->
FindVar
(
out
));
param_
.
output
=
scope
->
FindVar
(
out
)
->
GetMutable
<
lite
::
Tensor
>
();
param_
.
axis
=
GetAttr
<
int
>
(
op_desc
.
GetAttr
(
"axis"
)
);
param_
.
axis
=
op_desc
.
GetAttr
<
int
>
(
"axis"
);
return
true
;
}
...
...
paddle/fluid/lite/operators/concat_op.h
浏览文件 @
65be35af
...
...
@@ -32,7 +32,7 @@ class ConcatOpLite : public OpLite {
bool
InferShape
()
const
override
;
bool
AttachImpl
(
const
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
override
;
bool
AttachImpl
(
const
cpp
::
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
override
;
void
AttachKernel
(
KernelBase
*
kernel
)
override
{
kernel
->
SetParam
(
param_
);
}
std
::
string
DebugString
()
const
override
{
return
"concat"
;
}
...
...
paddle/fluid/lite/operators/concat_op_test.cc
浏览文件 @
65be35af
...
...
@@ -42,7 +42,7 @@ TEST(concat_op_lite, test) {
}
// prepare op desc
lite
::
OpDesc
desc
;
cpp
::
OpDesc
desc
;
desc
.
SetType
(
"concat"
);
desc
.
SetInput
(
"X"
,
{
"x0"
,
"x1"
});
desc
.
SetOutput
(
"Out"
,
{
"output"
});
...
...
paddle/fluid/lite/operators/dropout_op.cc
浏览文件 @
65be35af
...
...
@@ -42,7 +42,7 @@ class DropoutOpLite : public OpLite {
void
AttachKernel
(
KernelBase
*
kernel
)
override
{
kernel
->
SetParam
(
param_
);
}
// TODO(Superjomn) replace framework::OpDesc with a lite one.
bool
AttachImpl
(
const
OpDesc
&
op_desc
,
lite
::
Scope
*
scope
)
override
{
bool
AttachImpl
(
const
cpp
::
OpDesc
&
op_desc
,
lite
::
Scope
*
scope
)
override
{
auto
input
=
op_desc
.
Input
(
"X"
).
front
();
auto
out
=
op_desc
.
Output
(
"Out"
).
front
();
auto
Mask
=
op_desc
.
Output
(
"Mask"
).
front
();
...
...
@@ -51,14 +51,14 @@ class DropoutOpLite : public OpLite {
param_
.
output
=
GetMutableVar
<
lite
::
Tensor
>
(
scope
,
out
);
param_
.
mask
=
GetMutableVar
<
lite
::
Tensor
>
(
scope
,
Mask
);
param_
.
dropout_prob
=
boost
::
get
<
float
>
(
op_desc
.
GetAttr
(
"dropout_prob"
)
);
param_
.
dropout_prob
=
op_desc
.
GetAttr
<
float
>
(
"dropout_prob"
);
if
(
op_desc
.
HasAttr
(
"axis"
))
{
param_
.
is_test
=
boost
::
get
<
bool
>
(
op_desc
.
GetAttr
(
"is_test"
)
);
param_
.
is_test
=
op_desc
.
GetAttr
<
bool
>
(
"is_test"
);
}
param_
.
fix_seed
=
boost
::
get
<
bool
>
(
op_desc
.
GetAttr
(
"fix_seed"
)
);
param_
.
seed
=
boost
::
get
<
int
>
(
op_desc
.
GetAttr
(
"seed"
)
);
param_
.
fix_seed
=
op_desc
.
GetAttr
<
bool
>
(
"fix_seed"
);
param_
.
seed
=
op_desc
.
GetAttr
<
int
>
(
"seed"
);
param_
.
dropout_implementation
=
boost
::
get
<
int
>
(
op_desc
.
GetAttr
(
"dropout_implementation"
)
);
op_desc
.
GetAttr
<
int
>
(
"dropout_implementation"
);
return
true
;
}
...
...
paddle/fluid/lite/operators/elementwise_ops.cc
浏览文件 @
65be35af
...
...
@@ -36,7 +36,7 @@ class ElementwiseOp : public OpLite {
return
true
;
}
bool
AttachImpl
(
const
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
override
{
bool
AttachImpl
(
const
cpp
::
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
override
{
auto
X_name
=
opdesc
.
Input
(
"X"
).
front
();
auto
Y_name
=
opdesc
.
Input
(
"Y"
).
front
();
auto
Out_name
=
opdesc
.
Output
(
"Out"
).
front
();
...
...
@@ -44,7 +44,7 @@ class ElementwiseOp : public OpLite {
param_
.
X
=
GetVar
<
lite
::
Tensor
>
(
scope
,
X_name
);
param_
.
Y
=
GetVar
<
lite
::
Tensor
>
(
scope
,
Y_name
);
param_
.
Out
=
GetMutableVar
<
lite
::
Tensor
>
(
scope
,
Out_name
);
param_
.
axis
=
boost
::
get
<
int
>
(
opdesc
.
GetAttr
(
"axis"
)
);
param_
.
axis
=
opdesc
.
GetAttr
<
int
>
(
"axis"
);
return
true
;
}
...
...
@@ -75,8 +75,8 @@ class ElementwiseGradExplicitOp : public OpLite {
return
true
;
}
bool
AttachImpl
(
const
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
override
{
CHECK_EQ
(
opdesc
.
Inputs
().
size
(),
1UL
);
bool
AttachImpl
(
const
cpp
::
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
override
{
CHECK_EQ
(
opdesc
.
Input
ArgumentName
s
().
size
(),
1UL
);
auto
Out_name
=
opdesc
.
Input
(
framework
::
GradVarName
(
"Out"
)).
front
();
auto
X_name
=
opdesc
.
Output
(
framework
::
GradVarName
(
"X"
)).
front
();
auto
Y_name
=
opdesc
.
Output
(
framework
::
GradVarName
(
"Y"
)).
front
();
...
...
@@ -84,7 +84,7 @@ class ElementwiseGradExplicitOp : public OpLite {
param_
.
Out_grad
=
GetVar
<
lite
::
Tensor
>
(
scope
,
Out_name
);
param_
.
X_grad
=
GetMutableVar
<
lite
::
Tensor
>
(
scope
,
X_name
);
param_
.
Y_grad
=
GetMutableVar
<
Tensor
>
(
scope
,
Y_name
);
param_
.
axis
=
boost
::
get
<
int
>
(
opdesc
.
GetAttr
(
"axis"
)
);
param_
.
axis
=
opdesc
.
GetAttr
<
int
>
(
"axis"
);
return
true
;
}
...
...
paddle/fluid/lite/operators/fc_op.h
浏览文件 @
65be35af
...
...
@@ -46,7 +46,7 @@ class FcOpLite : public OpLite {
*/
// TODO(Superjomn) replace framework::OpDesc with a lite one.
bool
AttachImpl
(
const
OpDesc
&
op_desc
,
lite
::
Scope
*
scope
)
override
{
bool
AttachImpl
(
const
cpp
::
OpDesc
&
op_desc
,
lite
::
Scope
*
scope
)
override
{
auto
input
=
op_desc
.
Input
(
"Input"
).
front
();
auto
W
=
op_desc
.
Input
(
"W"
).
front
();
auto
bias
=
op_desc
.
Input
(
"Bias"
).
front
();
...
...
@@ -57,7 +57,7 @@ class FcOpLite : public OpLite {
param_
.
bias
=
scope
->
FindVar
(
bias
)
->
GetMutable
<
lite
::
Tensor
>
();
CHECK
(
scope
->
FindVar
(
out
));
param_
.
output
=
scope
->
FindVar
(
out
)
->
GetMutable
<
lite
::
Tensor
>
();
param_
.
in_num_col_dims
=
GetAttr
<
int
>
(
op_desc
.
GetAttr
(
"in_num_col_dims"
)
);
param_
.
in_num_col_dims
=
op_desc
.
GetAttr
<
int
>
(
"in_num_col_dims"
);
return
true
;
}
...
...
paddle/fluid/lite/operators/fc_op_test.cc
浏览文件 @
65be35af
...
...
@@ -47,7 +47,7 @@ TEST(fc_op_lite, TestX86) {
}
// prepare op desc
lite
::
OpDesc
desc
;
cpp
::
OpDesc
desc
;
desc
.
SetType
(
"fc"
);
desc
.
SetInput
(
"Input"
,
{
"x"
});
desc
.
SetInput
(
"W"
,
{
"w"
});
...
...
paddle/fluid/lite/operators/feed_op.cc
浏览文件 @
65be35af
...
...
@@ -34,7 +34,7 @@ class FeedOp : public OpLite {
void
AttachKernel
(
KernelBase
*
kernel
)
override
{
kernel
->
SetParam
(
param_
);
}
protected:
bool
AttachImpl
(
const
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
override
{
bool
AttachImpl
(
const
cpp
::
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
override
{
auto
feed_var_name
=
opdesc
.
Input
(
"X"
).
front
();
auto
*
feed_var
=
scope
->
FindVar
(
feed_var_name
);
CHECK
(
feed_var
);
...
...
@@ -48,7 +48,7 @@ class FeedOp : public OpLite {
// NOTE need boost here
// TODO(Superjomn) drop the need of framework::op_desc
param_
.
col
=
GetAttr
<
int
>
(
opdesc
.
GetAttr
(
"col"
)
);
param_
.
col
=
opdesc
.
GetAttr
<
int
>
(
"col"
);
return
true
;
}
...
...
paddle/fluid/lite/operators/fetch_op.cc
浏览文件 @
65be35af
...
...
@@ -33,7 +33,7 @@ class FetchOp : public OpLite {
void
AttachKernel
(
KernelBase
*
kernel
)
override
{
kernel
->
SetParam
(
param_
);
}
protected:
bool
AttachImpl
(
const
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
override
{
bool
AttachImpl
(
const
cpp
::
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
override
{
auto
_x
=
opdesc
.
Input
(
"X"
).
front
();
auto
*
x
=
scope
->
FindVar
(
_x
);
CHECK
(
x
);
...
...
@@ -43,7 +43,7 @@ class FetchOp : public OpLite {
auto
*
out
=
scope
->
FindVar
(
_out
);
param_
.
fetch_list
=
out
->
GetMutable
<
std
::
vector
<
lite
::
Tensor
>>
();
param_
.
col
=
GetAttr
<
int
>
(
opdesc
.
GetAttr
(
"col"
)
);
param_
.
col
=
opdesc
.
GetAttr
<
int
>
(
"col"
);
return
true
;
}
...
...
paddle/fluid/lite/operators/fill_constant_op.cc
浏览文件 @
65be35af
...
...
@@ -33,14 +33,14 @@ class FillConstantOp : public OpLite {
return
true
;
}
bool
AttachImpl
(
const
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
override
{
bool
AttachImpl
(
const
cpp
::
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
override
{
auto
Out_name
=
opdesc
.
Output
(
"Out"
).
front
();
param_
.
Out
=
GetMutableVar
<
Tensor
>
(
scope
,
Out_name
);
param_
.
dtype
=
GetAttr
<
int
>
(
opdesc
.
GetAttr
(
"dtype"
)
);
param_
.
shape
=
GetAttr
<
std
::
vector
<
int64_t
>>
(
opdesc
.
GetAttr
(
"shape"
)
);
param_
.
value
=
GetAttr
<
float
>
(
opdesc
.
GetAttr
(
"value"
)
);
param_
.
force_cpu
=
GetAttr
<
bool
>
(
opdesc
.
GetAttr
(
"force_cpu"
)
);
param_
.
dtype
=
opdesc
.
GetAttr
<
int
>
(
"dtype"
);
param_
.
shape
=
opdesc
.
GetAttr
<
std
::
vector
<
int64_t
>>
(
"shape"
);
param_
.
value
=
opdesc
.
GetAttr
<
float
>
(
"value"
);
param_
.
force_cpu
=
opdesc
.
GetAttr
<
bool
>
(
"force_cpu"
);
return
true
;
}
...
...
paddle/fluid/lite/operators/io_copy_op.cc
浏览文件 @
65be35af
...
...
@@ -29,7 +29,8 @@ bool IoCopyOp::InferShape() const {
return
true
;
}
bool
IoCopyOp
::
Run
()
{
return
OpLite
::
Run
();
}
bool
IoCopyOp
::
AttachImpl
(
const
OpDesc
&
opdesc
,
paddle
::
lite
::
Scope
*
scope
)
{
bool
IoCopyOp
::
AttachImpl
(
const
cpp
::
OpDesc
&
opdesc
,
paddle
::
lite
::
Scope
*
scope
)
{
auto
x
=
opdesc
.
Input
(
"Input"
).
front
();
auto
out
=
opdesc
.
Output
(
"Out"
).
front
();
param_
.
x
=
GetTensor
(
scope
,
x
);
...
...
paddle/fluid/lite/operators/io_copy_op.h
浏览文件 @
65be35af
...
...
@@ -31,7 +31,7 @@ class IoCopyOp : public OpLite {
void
AttachKernel
(
KernelBase
*
kernel
)
override
{
kernel
->
SetParam
(
param_
);
}
protected:
bool
AttachImpl
(
const
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
override
;
bool
AttachImpl
(
const
cpp
::
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
override
;
private:
operators
::
IoCopyParam
param_
;
...
...
paddle/fluid/lite/operators/mean_op.cc
浏览文件 @
65be35af
...
...
@@ -37,7 +37,7 @@ class MeanOp : public OpLite {
return
true
;
}
bool
AttachImpl
(
const
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
override
{
bool
AttachImpl
(
const
cpp
::
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
override
{
auto
X_name
=
opdesc
.
Input
(
"X"
).
front
();
auto
Out_name
=
opdesc
.
Output
(
"Out"
).
front
();
...
...
@@ -72,8 +72,8 @@ class MeanGradOp : public OpLite {
return
true
;
}
bool
AttachImpl
(
const
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
override
{
CHECK_EQ
(
opdesc
.
Inputs
().
size
(),
3UL
);
bool
AttachImpl
(
const
cpp
::
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
override
{
CHECK_EQ
(
opdesc
.
Input
ArgumentName
s
().
size
(),
3UL
);
auto
X_name
=
opdesc
.
Input
(
"X"
).
front
();
auto
Out_grad_name
=
opdesc
.
Input
(
framework
::
GradVarName
(
"Out"
)).
front
();
auto
X_grad_name
=
opdesc
.
Output
(
framework
::
GradVarName
(
"X"
)).
front
();
...
...
paddle/fluid/lite/operators/mul_op.cc
浏览文件 @
65be35af
...
...
@@ -85,7 +85,7 @@ bool MulGradOpLite::InferShape() const {
return
true
;
}
bool
MulGradOpLite
::
AttachImpl
(
const
OpDesc
&
op_desc
,
lite
::
Scope
*
scope
)
{
bool
MulGradOpLite
::
AttachImpl
(
const
cpp
::
OpDesc
&
op_desc
,
lite
::
Scope
*
scope
)
{
auto
X_name
=
op_desc
.
Input
(
"X"
).
front
();
auto
Y_name
=
op_desc
.
Input
(
"Y"
).
front
();
auto
Out_grad_name
=
op_desc
.
Output
(
framework
::
GradVarName
(
"Out"
)).
front
();
...
...
paddle/fluid/lite/operators/mul_op.h
浏览文件 @
65be35af
...
...
@@ -37,7 +37,7 @@ class MulOpLite : public OpLite {
void
AttachKernel
(
KernelBase
*
kernel
)
override
{
kernel
->
SetParam
(
param_
);
}
// TODO(Superjomn) replace framework::OpDesc with a lite one.
bool
AttachImpl
(
const
OpDesc
&
op_desc
,
lite
::
Scope
*
scope
)
override
{
bool
AttachImpl
(
const
cpp
::
OpDesc
&
op_desc
,
lite
::
Scope
*
scope
)
override
{
auto
input
=
op_desc
.
Input
(
"X"
).
front
();
auto
W
=
op_desc
.
Input
(
"Y"
).
front
();
auto
out
=
op_desc
.
Output
(
"Out"
).
front
();
...
...
@@ -49,8 +49,8 @@ class MulOpLite : public OpLite {
param_
.
y
=
var
->
GetMutable
<
Tensor
>
();
CHECK
(
scope
->
FindVar
(
out
));
param_
.
output
=
scope
->
FindVar
(
out
)
->
GetMutable
<
Tensor
>
();
param_
.
x_num_col_dims
=
GetAttr
<
int
>
(
op_desc
.
GetAttr
(
"x_num_col_dims"
)
);
param_
.
y_num_col_dims
=
GetAttr
<
int
>
(
op_desc
.
GetAttr
(
"y_num_col_dims"
)
);
param_
.
x_num_col_dims
=
op_desc
.
GetAttr
<
int
>
(
"x_num_col_dims"
);
param_
.
y_num_col_dims
=
op_desc
.
GetAttr
<
int
>
(
"y_num_col_dims"
);
return
true
;
}
...
...
@@ -73,7 +73,7 @@ class MulGradOpLite : public OpLite {
void
AttachKernel
(
KernelBase
*
kernel
)
override
{
kernel
->
SetParam
(
param_
);
}
bool
AttachImpl
(
const
OpDesc
&
op_desc
,
lite
::
Scope
*
scope
)
override
;
bool
AttachImpl
(
const
cpp
::
OpDesc
&
op_desc
,
lite
::
Scope
*
scope
)
override
;
std
::
string
DebugString
()
const
override
{
return
"mul_grad"
;
}
...
...
paddle/fluid/lite/operators/relu_op.cc
浏览文件 @
65be35af
...
...
@@ -30,7 +30,7 @@ bool ReluOp::InferShape() const {
return
true
;
}
bool
ReluOp
::
AttachImpl
(
const
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
{
bool
ReluOp
::
AttachImpl
(
const
cpp
::
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
{
param_
.
input
=
const_cast
<
lite
::
Tensor
*>
(
&
scope
->
FindVar
(
opdesc
.
Input
(
"Input"
).
front
())
->
Get
<
lite
::
Tensor
>
());
param_
.
output
=
...
...
paddle/fluid/lite/operators/relu_op.h
浏览文件 @
65be35af
...
...
@@ -32,7 +32,7 @@ class ReluOp : public OpLite {
bool
InferShape
()
const
override
;
bool
AttachImpl
(
const
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
override
;
bool
AttachImpl
(
const
cpp
::
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
override
;
void
AttachKernel
(
KernelBase
*
kernel
)
override
{
kernel
->
SetParam
(
param_
);
}
std
::
string
DebugString
()
const
override
{
return
"relu"
;
}
...
...
paddle/fluid/lite/operators/reshape_op.cc
浏览文件 @
65be35af
...
...
@@ -33,7 +33,7 @@ bool ReshapeOp::InferShape() const {
return
true
;
}
bool
ReshapeOp
::
AttachImpl
(
const
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
{
bool
ReshapeOp
::
AttachImpl
(
const
cpp
::
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
{
auto
x_var
=
scope
->
FindVar
(
opdesc
.
Input
(
"X"
).
front
());
auto
output_var
=
scope
->
FindVar
(
opdesc
.
Output
(
"Out"
).
front
());
CHECK
(
x_var
);
...
...
@@ -49,9 +49,9 @@ bool ReshapeOp::AttachImpl(const OpDesc &opdesc, lite::Scope *scope) {
const_cast
<
lite
::
Tensor
*>
(
&
(
actual_shape_var
->
Get
<
lite
::
Tensor
>
()));
}
}
param_
.
shape
=
GetAttr
<
std
::
vector
<
int
>>
(
opdesc
.
GetAttr
(
"shape"
));
param_
.
shape
=
(
opdesc
.
GetAttr
<
std
::
vector
<
int
>>
(
"shape"
));
if
(
opdesc
.
HasAttr
(
"inplace"
))
{
param_
.
inplace
=
GetAttr
<
bool
>
(
opdesc
.
GetAttr
(
"inplace"
)
);
param_
.
inplace
=
opdesc
.
GetAttr
<
bool
>
(
"inplace"
);
}
CHECK
(
param_
.
x
)
<<
"Input(X) of ReshapeOp should not be null."
;
CHECK
(
param_
.
output
)
<<
"Output(Out) of ReshapeOp should not be null."
;
...
...
@@ -70,14 +70,14 @@ bool Reshape2Op::InferShape() const {
ReshapeOp
::
InferShape
();
auto
x_dims
=
param_
.
x
->
dims
();
std
::
vector
<
DDim
::
value_type
>
xshape_dims
(
x_dims
.
size
()
+
1
,
0
);
for
(
in
t
i
=
0
;
i
<
x_dims
.
size
();
i
++
)
{
for
(
size_
t
i
=
0
;
i
<
x_dims
.
size
();
i
++
)
{
xshape_dims
[
i
+
1
]
=
x_dims
[
i
];
}
param_
.
xshape
->
Resize
(
DDim
(
xshape_dims
));
return
true
;
}
bool
Reshape2Op
::
AttachImpl
(
const
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
{
bool
Reshape2Op
::
AttachImpl
(
const
cpp
::
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
{
ReshapeOp
::
AttachImpl
(
opdesc
,
scope
);
auto
xshape_var
=
scope
->
FindVar
(
opdesc
.
Output
(
"XShape"
).
front
());
CHECK
(
xshape_var
);
...
...
paddle/fluid/lite/operators/reshape_op.h
浏览文件 @
65be35af
...
...
@@ -32,7 +32,7 @@ class ReshapeOp : public OpLite {
bool
InferShape
()
const
override
;
bool
AttachImpl
(
const
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
override
;
bool
AttachImpl
(
const
cpp
::
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
override
;
void
AttachKernel
(
KernelBase
*
kernel
)
override
{
kernel
->
SetParam
(
param_
);
}
std
::
string
DebugString
()
const
override
{
return
"reshape"
;
}
...
...
@@ -50,7 +50,7 @@ class Reshape2Op : public ReshapeOp {
bool
InferShape
()
const
override
;
bool
AttachImpl
(
const
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
override
;
bool
AttachImpl
(
const
cpp
::
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
override
;
void
AttachKernel
(
KernelBase
*
kernel
)
override
{
kernel
->
SetParam
(
param_
);
}
std
::
string
DebugString
()
const
override
{
return
"reshape2"
;
}
...
...
paddle/fluid/lite/operators/reshape_op_test.cc
浏览文件 @
65be35af
...
...
@@ -47,7 +47,7 @@ TEST(reshape_op_lite, test) {
for
(
auto
&
has_actual_shape
:
{
true
,
false
})
{
for
(
auto
&
inplace
:
{
true
,
false
})
{
// prepare op desc
lite
::
OpDesc
desc
;
cpp
::
OpDesc
desc
;
desc
.
SetType
(
"reshape"
);
desc
.
SetInput
(
"X"
,
{
"x"
});
if
(
has_actual_shape
)
{
...
...
@@ -68,7 +68,7 @@ TEST(reshape_op_lite, test) {
// check output dims
auto
output_dims
=
output
->
dims
();
CHECK_EQ
(
output_dims
.
size
(),
shape
.
second
.
size
());
for
(
in
t
i
=
0
;
i
<
output_dims
.
size
();
i
++
)
{
for
(
size_
t
i
=
0
;
i
<
output_dims
.
size
();
i
++
)
{
CHECK_EQ
(
output_dims
[
i
],
shape
.
second
[
i
]);
}
}
...
...
@@ -102,7 +102,7 @@ TEST(reshape2_op_lite, test) {
for
(
auto
&
has_actual_shape
:
{
true
,
false
})
{
for
(
auto
&
inplace
:
{
true
,
false
})
{
// prepare op desc
lite
::
OpDesc
desc
;
cpp
::
OpDesc
desc
;
desc
.
SetType
(
"reshape"
);
desc
.
SetInput
(
"X"
,
{
"x"
});
if
(
has_actual_shape
)
{
...
...
@@ -132,7 +132,7 @@ TEST(reshape2_op_lite, test) {
auto
xshape_dims
=
xshape
->
dims
();
CHECK_EQ
(
xshape_dims
.
size
(),
x_dims
.
size
()
+
1
);
CHECK_EQ
(
xshape_dims
[
0
],
0
);
for
(
in
t
i
=
0
;
i
<
x_dims
.
size
();
i
++
)
{
for
(
size_
t
i
=
0
;
i
<
x_dims
.
size
();
i
++
)
{
CHECK_EQ
(
xshape_dims
[
i
+
1
],
x_dims
[
i
]);
}
}
...
...
paddle/fluid/lite/operators/scale_op.cc
浏览文件 @
65be35af
...
...
@@ -29,14 +29,14 @@ bool ScaleOp::InferShape() const {
return
true
;
}
bool
ScaleOp
::
AttachImpl
(
const
OpDesc
&
op_desc
,
lite
::
Scope
*
scope
)
{
bool
ScaleOp
::
AttachImpl
(
const
cpp
::
OpDesc
&
op_desc
,
lite
::
Scope
*
scope
)
{
auto
x
=
op_desc
.
Input
(
"X"
).
front
();
auto
output
=
op_desc
.
Output
(
"Out"
).
front
();
param_
.
x
=
scope
->
FindVar
(
x
)
->
GetMutable
<
Tensor
>
();
param_
.
output
=
scope
->
FindVar
(
output
)
->
GetMutable
<
Tensor
>
();
param_
.
scale
=
GetAttr
<
float
>
(
op_desc
.
GetAttr
(
"scale"
)
);
param_
.
bias
=
GetAttr
<
float
>
(
op_desc
.
GetAttr
(
"bias"
)
);
param_
.
bias_after_scale
=
GetAttr
<
bool
>
(
op_desc
.
GetAttr
(
"bias_after_scale"
)
);
param_
.
scale
=
op_desc
.
GetAttr
<
float
>
(
"scale"
);
param_
.
bias
=
op_desc
.
GetAttr
<
float
>
(
"bias"
);
param_
.
bias_after_scale
=
op_desc
.
GetAttr
<
bool
>
(
"bias_after_scale"
);
CHECK
(
param_
.
x
);
CHECK
(
param_
.
output
);
return
true
;
...
...
paddle/fluid/lite/operators/scale_op.h
浏览文件 @
65be35af
...
...
@@ -32,7 +32,7 @@ class ScaleOp : public OpLite {
bool
InferShape
()
const
override
;
bool
AttachImpl
(
const
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
override
;
bool
AttachImpl
(
const
cpp
::
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
override
;
void
AttachKernel
(
KernelBase
*
kernel
)
override
{
kernel
->
SetParam
(
param_
);
}
std
::
string
DebugString
()
const
override
{
return
"scale"
;
}
...
...
paddle/fluid/lite/operators/scale_op_test.cc
浏览文件 @
65be35af
...
...
@@ -29,7 +29,7 @@ TEST(scale_op_lite, test) {
output
->
Resize
(
DDim
(
std
::
vector
<
int64_t
>
{
1
,
1
}));
// prepare op desc
lite
::
OpDesc
desc
;
cpp
::
OpDesc
desc
;
desc
.
SetType
(
"scale"
);
desc
.
SetInput
(
"X"
,
{
"x"
});
desc
.
SetOutput
(
"Out"
,
{
"output"
});
...
...
@@ -48,7 +48,7 @@ TEST(scale_op_lite, test) {
auto
x_dims
=
x
->
dims
();
auto
output_dims
=
output
->
dims
();
CHECK_EQ
(
output_dims
.
size
(),
x_dims
.
size
());
for
(
in
t
i
=
0
;
i
<
output_dims
.
size
();
i
++
)
{
for
(
size_
t
i
=
0
;
i
<
output_dims
.
size
();
i
++
)
{
CHECK_EQ
(
output_dims
[
i
],
x_dims
[
i
]);
}
}
...
...
paddle/fluid/lite/operators/softmax_op.cc
浏览文件 @
65be35af
...
...
@@ -24,7 +24,8 @@ bool SoftmaxOp::CheckShape() const {
CHECK_OR_FALSE
(
param_
.
output
);
auto
x_dims
=
param_
.
x
->
dims
();
auto
x_rank
=
x_dims
.
size
();
CHECK_OR_FALSE
(
param_
.
axis
>=
-
x_rank
&&
param_
.
axis
<
x_rank
);
CHECK_OR_FALSE
(
param_
.
axis
>=
-
static_cast
<
int
>
(
x_rank
)
&&
param_
.
axis
<
static_cast
<
int
>
(
x_rank
));
return
true
;
}
...
...
@@ -33,12 +34,12 @@ bool SoftmaxOp::InferShape() const {
return
true
;
}
bool
SoftmaxOp
::
AttachImpl
(
const
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
{
bool
SoftmaxOp
::
AttachImpl
(
const
cpp
::
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
{
param_
.
x
=
const_cast
<
lite
::
Tensor
*>
(
&
scope
->
FindVar
(
opdesc
.
Input
(
"X"
).
front
())
->
Get
<
lite
::
Tensor
>
());
param_
.
output
=
scope
->
FindVar
(
opdesc
.
Output
(
"Out"
).
front
())
->
GetMutable
<
lite
::
Tensor
>
();
param_
.
axis
=
GetAttr
<
int
>
(
opdesc
.
GetAttr
(
"axis"
)
);
param_
.
axis
=
opdesc
.
GetAttr
<
int
>
(
"axis"
);
CHECK
(
param_
.
x
);
CHECK
(
param_
.
output
);
return
true
;
...
...
paddle/fluid/lite/operators/softmax_op.h
浏览文件 @
65be35af
...
...
@@ -32,7 +32,7 @@ class SoftmaxOp : public OpLite {
bool
InferShape
()
const
override
;
bool
AttachImpl
(
const
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
override
;
bool
AttachImpl
(
const
cpp
::
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
override
;
void
AttachKernel
(
KernelBase
*
kernel
)
override
{
kernel
->
SetParam
(
param_
);
}
std
::
string
DebugString
()
const
override
{
return
"softmax"
;
}
...
...
paddle/fluid/lite/operators/softmax_op_test.cc
浏览文件 @
65be35af
...
...
@@ -37,7 +37,7 @@ TEST(softmax_op_lite, test) {
}
// prepare op desc
lite
::
OpDesc
desc
;
cpp
::
OpDesc
desc
;
desc
.
SetType
(
"softmax"
);
desc
.
SetInput
(
"X"
,
{
"x"
});
desc
.
SetOutput
(
"Out"
,
{
"output"
});
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录