Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
机器未来
Paddle
提交
ae60589f
P
Paddle
项目概览
机器未来
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
ae60589f
编写于
6月 04, 2019
作者:
Y
Yan Chunwei
提交者:
GitHub
6月 04, 2019
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Lite/refactor cp desc (#17831)
上级
5868a6ca
变更
45
显示空白变更内容
内联
并排
Showing
45 changed file
with
774 addition
and
331 deletion
+774
-331
paddle/fluid/lite/core/CMakeLists.txt
paddle/fluid/lite/core/CMakeLists.txt
+3
-1
paddle/fluid/lite/core/mir/type_target_transform_pass.cc
paddle/fluid/lite/core/mir/type_target_transform_pass.cc
+5
-7
paddle/fluid/lite/core/mir/type_target_transform_pass.h
paddle/fluid/lite/core/mir/type_target_transform_pass.h
+3
-3
paddle/fluid/lite/core/mir/variable_place_inference_pass.h
paddle/fluid/lite/core/mir/variable_place_inference_pass.h
+2
-2
paddle/fluid/lite/core/op_lite.cc
paddle/fluid/lite/core/op_lite.cc
+3
-92
paddle/fluid/lite/core/op_lite.h
paddle/fluid/lite/core/op_lite.h
+57
-36
paddle/fluid/lite/core/program.cc
paddle/fluid/lite/core/program.cc
+4
-4
paddle/fluid/lite/core/program.h
paddle/fluid/lite/core/program.h
+6
-2
paddle/fluid/lite/model_parser/CMakeLists.txt
paddle/fluid/lite/model_parser/CMakeLists.txt
+4
-5
paddle/fluid/lite/model_parser/compatible_pb.cc
paddle/fluid/lite/model_parser/compatible_pb.cc
+111
-0
paddle/fluid/lite/model_parser/compatible_pb.h
paddle/fluid/lite/model_parser/compatible_pb.h
+7
-18
paddle/fluid/lite/model_parser/cpp/CMakeLists.txt
paddle/fluid/lite/model_parser/cpp/CMakeLists.txt
+1
-0
paddle/fluid/lite/model_parser/cpp/op_desc.cc
paddle/fluid/lite/model_parser/cpp/op_desc.cc
+66
-0
paddle/fluid/lite/model_parser/cpp/op_desc.h
paddle/fluid/lite/model_parser/cpp/op_desc.h
+126
-0
paddle/fluid/lite/model_parser/desc_apis.h
paddle/fluid/lite/model_parser/desc_apis.h
+85
-0
paddle/fluid/lite/model_parser/op_desc_test.cc
paddle/fluid/lite/model_parser/op_desc_test.cc
+107
-0
paddle/fluid/lite/model_parser/pb/op_desc.cc
paddle/fluid/lite/model_parser/pb/op_desc.cc
+80
-19
paddle/fluid/lite/model_parser/pb/op_desc.h
paddle/fluid/lite/model_parser/pb/op_desc.h
+38
-78
paddle/fluid/lite/operators/activation_ops.cc
paddle/fluid/lite/operators/activation_ops.cc
+2
-2
paddle/fluid/lite/operators/concat_op.cc
paddle/fluid/lite/operators/concat_op.cc
+2
-2
paddle/fluid/lite/operators/concat_op.h
paddle/fluid/lite/operators/concat_op.h
+1
-1
paddle/fluid/lite/operators/concat_op_test.cc
paddle/fluid/lite/operators/concat_op_test.cc
+1
-1
paddle/fluid/lite/operators/dropout_op.cc
paddle/fluid/lite/operators/dropout_op.cc
+6
-6
paddle/fluid/lite/operators/elementwise_ops.cc
paddle/fluid/lite/operators/elementwise_ops.cc
+5
-5
paddle/fluid/lite/operators/fc_op.h
paddle/fluid/lite/operators/fc_op.h
+2
-2
paddle/fluid/lite/operators/fc_op_test.cc
paddle/fluid/lite/operators/fc_op_test.cc
+1
-1
paddle/fluid/lite/operators/feed_op.cc
paddle/fluid/lite/operators/feed_op.cc
+2
-2
paddle/fluid/lite/operators/fetch_op.cc
paddle/fluid/lite/operators/fetch_op.cc
+2
-2
paddle/fluid/lite/operators/fill_constant_op.cc
paddle/fluid/lite/operators/fill_constant_op.cc
+5
-5
paddle/fluid/lite/operators/io_copy_op.cc
paddle/fluid/lite/operators/io_copy_op.cc
+2
-1
paddle/fluid/lite/operators/io_copy_op.h
paddle/fluid/lite/operators/io_copy_op.h
+1
-1
paddle/fluid/lite/operators/mean_op.cc
paddle/fluid/lite/operators/mean_op.cc
+3
-3
paddle/fluid/lite/operators/mul_op.cc
paddle/fluid/lite/operators/mul_op.cc
+1
-1
paddle/fluid/lite/operators/mul_op.h
paddle/fluid/lite/operators/mul_op.h
+4
-4
paddle/fluid/lite/operators/relu_op.cc
paddle/fluid/lite/operators/relu_op.cc
+1
-1
paddle/fluid/lite/operators/relu_op.h
paddle/fluid/lite/operators/relu_op.h
+1
-1
paddle/fluid/lite/operators/reshape_op.cc
paddle/fluid/lite/operators/reshape_op.cc
+5
-5
paddle/fluid/lite/operators/reshape_op.h
paddle/fluid/lite/operators/reshape_op.h
+2
-2
paddle/fluid/lite/operators/reshape_op_test.cc
paddle/fluid/lite/operators/reshape_op_test.cc
+4
-4
paddle/fluid/lite/operators/scale_op.cc
paddle/fluid/lite/operators/scale_op.cc
+4
-4
paddle/fluid/lite/operators/scale_op.h
paddle/fluid/lite/operators/scale_op.h
+1
-1
paddle/fluid/lite/operators/scale_op_test.cc
paddle/fluid/lite/operators/scale_op_test.cc
+2
-2
paddle/fluid/lite/operators/softmax_op.cc
paddle/fluid/lite/operators/softmax_op.cc
+4
-3
paddle/fluid/lite/operators/softmax_op.h
paddle/fluid/lite/operators/softmax_op.h
+1
-1
paddle/fluid/lite/operators/softmax_op_test.cc
paddle/fluid/lite/operators/softmax_op_test.cc
+1
-1
未找到文件。
paddle/fluid/lite/core/CMakeLists.txt
浏览文件 @
ae60589f
...
...
@@ -25,7 +25,9 @@ cc_library(op_registry_lite SRCS op_registry.cc DEPS framework_proto_lite)
cc_library
(
scope_lite SRCS scope.cc DEPS
${
tensor_lite
}
)
cc_library
(
cpu_info_lite SRCS cpu_info.cc
)
cc_library
(
context_lite SRCS context.cc DEPS
${
tensor_lite
}
any_lite cpu_info_lite
)
cc_library
(
op_lite SRCS op_lite.cc DEPS scope_lite op_registry_lite compatible_pb_lite target_wrapper_lite
${
tensor_lite
}
)
cc_library
(
op_lite SRCS op_lite.cc DEPS scope_lite op_registry_lite target_wrapper_lite
cpp_op_desc_lite
${
tensor_lite
}
)
cc_library
(
types_lite SRCS types.cc
)
cc_library
(
type_system SRCS type_system.cc DEPS
${
tensor_lite
}
target_wrapper_lite
)
...
...
paddle/fluid/lite/core/mir/type_target_transform_pass.cc
浏览文件 @
ae60589f
...
...
@@ -90,7 +90,7 @@ void TypeTargetTransformPass::AddIoCopyInst(
inst_node
->
AsStmt
().
op
->
scope
()
->
Var
(
io_copy_output_name
);
// Create IoCopy Instruction.
lite
::
OpDesc
op_desc
;
cpp
::
OpDesc
op_desc
;
op_desc
.
SetType
(
"io_copy"
);
op_desc
.
SetInput
(
"Input"
,
{
var
});
op_desc
.
SetOutput
(
"Out"
,
{
io_copy_output_name
});
...
...
@@ -104,8 +104,6 @@ void TypeTargetTransformPass::AddIoCopyInst(
// Update the original instruction OpDesc.
// Update its input to the io_copy_output_name
auto
&
inst
=
inst_node
->
AsStmt
();
auto
inst_program_desc
=
inst
.
op_info
()
->
desc
();
// Add new link, var -> new_inst, new_inst->newarg, newarg->inst
DirectedLink
(
graph
->
Argument
(
var
),
io_copy_inst
);
...
...
@@ -113,11 +111,11 @@ void TypeTargetTransformPass::AddIoCopyInst(
DirectedLink
(
io_copy_output_arg
,
inst_node
);
// reset opdesc and update kernel information
auto
desc_dummy
=
inst_node
->
AsStmt
().
op
->
op_info
()
->
desc
();
UpdateInputTo
(
&
desc_dummy
,
var
,
io_copy_output_name
);
UpdateInputTo
(
inst_node
->
AsStmt
().
op
->
mutable_op_info
(),
var
,
io_copy_output_name
);
lite
::
OpDesc
desc_fake
(
desc_dummy
);
inst_node
->
AsStmt
().
op
->
Attach
(
desc_fake
,
inst_node
->
AsStmt
().
op
->
scope
());
inst_node
->
AsStmt
().
op
->
Attach
(
*
inst_node
->
AsStmt
().
op
->
op_info
(),
inst_node
->
AsStmt
().
op
->
scope
());
std
::
string
tmp
;
if
(
inst_node
->
AsStmt
().
op_info
()
->
GetInputArgname
(
"a"
,
&
tmp
))
{
...
...
paddle/fluid/lite/core/mir/type_target_transform_pass.h
浏览文件 @
ae60589f
...
...
@@ -24,10 +24,10 @@ namespace paddle {
namespace
lite
{
namespace
mir
{
static
void
UpdateInputTo
(
framework
::
proto
::
OpDesc
*
desc
,
const
std
::
string
&
from
,
const
std
::
string
&
to
)
{
static
void
UpdateInputTo
(
cpp
::
OpDesc
*
desc
,
const
std
::
string
&
from
,
const
std
::
string
&
to
)
{
for
(
auto
&
item
:
*
desc
->
mutable_inputs
())
{
for
(
auto
&
input
:
*
item
.
mutable_arguments
()
)
{
for
(
auto
&
input
:
item
.
second
)
{
if
(
input
==
from
)
{
input
=
to
;
}
...
...
paddle/fluid/lite/core/mir/variable_place_inference_pass.h
浏览文件 @
ae60589f
...
...
@@ -65,7 +65,7 @@ class VariablePlaceInferencePass : public DebugPass {
// check if inputs's place is set, if not set, update them with the
// kernel's declaration.
auto
type
=
inst
.
picked_kernel
().
GetInputDeclType
(
arg_name
);
auto
arg_names
=
inst
.
op_info
()
->
input
_argument
().
at
(
arg_name
);
auto
arg_names
=
inst
.
op_info
()
->
input
s
().
at
(
arg_name
);
for
(
auto
&
arg_name
:
arg_names
)
{
VLOG
(
3
)
<<
"--- var "
<<
arg_name
;
...
...
@@ -82,7 +82,7 @@ class VariablePlaceInferencePass : public DebugPass {
for
(
auto
&
arg_name
:
inst
.
op_info
()
->
output_argnames
())
{
VLOG
(
3
)
<<
"-- output arg_name "
<<
arg_name
;
auto
type
=
inst
.
picked_kernel
().
GetOutputDeclType
(
arg_name
);
auto
arg_names
=
inst
.
op_info
()
->
output
_argument
().
at
(
arg_name
);
auto
arg_names
=
inst
.
op_info
()
->
output
s
().
at
(
arg_name
);
// check if outputs's place is set, if not set, update them with the
// kernel's declaration.
for
(
auto
&
arg_name
:
arg_names
)
{
...
...
paddle/fluid/lite/core/op_lite.cc
浏览文件 @
ae60589f
...
...
@@ -68,13 +68,13 @@ bool OpLite::Run() {
return
true
;
}
bool
OpLite
::
Attach
(
const
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
{
bool
OpLite
::
Attach
(
const
cpp
::
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
{
// valid_places_.clear();
CHECK
(
scope
!=
nullptr
);
// CHECK(!op_info_.get());
scope_
=
scope
;
op_info_
.
reset
(
new
OpInfo
);
// Force clean the out-of-date infomation.
op_info_
->
Build
(
opdesc
.
ReadonlyProto
());
op_info_
.
reset
(
new
OpInfo
(
opdesc
));
// Force clean the out-of-date infomation.
return
AttachImpl
(
opdesc
,
scope
);
}
...
...
@@ -92,94 +92,5 @@ Tensor *OpLite::GetMutableTensor(lite::Scope *scope,
return
var
->
GetMutable
<
lite
::
Tensor
>
();
}
bool
OpInfo
::
GetInputArgname
(
const
std
::
string
&
value_name
,
std
::
string
*
out
)
const
{
for
(
auto
&
item
:
input_argument_
)
{
auto
it
=
std
::
find
(
item
.
second
.
begin
(),
item
.
second
.
end
(),
value_name
);
if
(
it
!=
item
.
second
.
end
())
{
*
out
=
item
.
first
;
return
true
;
}
}
return
false
;
}
bool
OpInfo
::
GetOutputArgname
(
const
std
::
string
&
value_name
,
std
::
string
*
out
)
const
{
for
(
auto
&
item
:
output_argument_
)
{
auto
it
=
std
::
find
(
item
.
second
.
begin
(),
item
.
second
.
end
(),
value_name
);
if
(
it
!=
item
.
second
.
end
())
{
*
out
=
item
.
first
;
return
true
;
}
}
return
false
;
}
void
OpInfo
::
ExtractInputsAndOutputs
(
const
framework
::
proto
::
OpDesc
&
opdesc
)
{
for
(
const
auto
&
item
:
opdesc
.
inputs
())
{
for
(
const
auto
&
x
:
item
.
arguments
())
{
input_names_
.
push_back
(
x
);
}
}
for
(
const
auto
&
item
:
opdesc
.
outputs
())
{
for
(
const
auto
&
x
:
item
.
arguments
())
{
output_names_
.
push_back
(
x
);
}
}
}
void
OpInfo
::
CollectInputAndOutputArgnames
(
const
framework
::
proto
::
OpDesc
&
opdesc
)
{
for
(
const
auto
&
item
:
opdesc
.
inputs
())
{
input_argnames_
.
push_back
(
item
.
parameter
());
}
for
(
const
auto
&
item
:
opdesc
.
outputs
())
{
output_argnames_
.
push_back
(
item
.
parameter
());
}
}
void
OpInfo
::
CollectArguments
(
const
framework
::
proto
::
OpDesc
&
opdesc
)
{
for
(
const
auto
&
item
:
opdesc
.
inputs
())
{
for
(
auto
&
x
:
item
.
arguments
())
{
input_argument_
[
item
.
parameter
()].
push_back
(
x
);
}
}
for
(
const
auto
&
item
:
opdesc
.
outputs
())
{
for
(
auto
&
x
:
item
.
arguments
())
{
output_argument_
[
item
.
parameter
()].
push_back
(
x
);
}
}
}
void
OpInfo
::
Build
(
const
framework
::
proto
::
OpDesc
&
desc
)
{
ExtractInputsAndOutputs
(
desc
);
CollectInputAndOutputArgnames
(
desc
);
CollectArguments
(
desc
);
desc_
.
reset
(
new
framework
::
proto
::
OpDesc
(
desc
));
}
const
std
::
map
<
std
::
string
,
std
::
list
<
std
::
string
>>
&
OpInfo
::
input_argument
()
const
{
return
input_argument_
;
}
const
std
::
map
<
std
::
string
,
std
::
list
<
std
::
string
>>
&
OpInfo
::
output_argument
()
const
{
return
output_argument_
;
}
const
std
::
list
<
std
::
string
>
&
OpInfo
::
input_argnames
()
const
{
return
input_argnames_
;
}
const
std
::
list
<
std
::
string
>
&
OpInfo
::
output_argnames
()
const
{
return
output_argnames_
;
}
const
framework
::
proto
::
OpDesc
&
OpInfo
::
desc
()
const
{
CHECK
(
desc_
)
<<
"desc has't set"
;
return
*
desc_
;
}
}
// namespace lite
}
// namespace paddle
paddle/fluid/lite/core/op_lite.h
浏览文件 @
ae60589f
...
...
@@ -23,7 +23,7 @@
#include "paddle/fluid/lite/core/context.h"
#include "paddle/fluid/lite/core/kernel.h"
#include "paddle/fluid/lite/core/scope.h"
#include "paddle/fluid/lite/model_parser/c
ompatible_pb
.h"
#include "paddle/fluid/lite/model_parser/c
pp/op_desc
.h"
namespace
paddle
{
namespace
lite
{
...
...
@@ -71,7 +71,7 @@ class OpLite : public Registry {
virtual
bool
Run
();
// Link the external execution environ to internal context.
bool
Attach
(
const
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
);
bool
Attach
(
const
cpp
::
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
);
const
OpInfo
*
op_info
()
const
{
return
op_info_
.
get
();
}
OpInfo
*
mutable_op_info
()
{
return
op_info_
.
get
();
}
...
...
@@ -94,7 +94,7 @@ class OpLite : public Registry {
protected:
// Attach it with the runtime environment.
virtual
bool
AttachImpl
(
const
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
=
0
;
virtual
bool
AttachImpl
(
const
cpp
::
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
=
0
;
// Specify the kernel to run by default. This will specify the value of
// `kernel_place_`.
...
...
@@ -144,40 +144,61 @@ class OpLite : public Registry {
* Operator Information, such as some description. It will be shared by all the
* kernels of the same operator.
*/
class
OpInfo
{
class
OpInfo
:
public
cpp
::
OpDesc
{
public:
// To avoid the bugs from legancy framework::OpDesc, we use the ProtoBuf
// message instead.
void
Build
(
const
framework
::
proto
::
OpDesc
&
desc
);
const
framework
::
proto
::
OpDesc
&
desc
()
const
;
framework
::
proto
::
OpDesc
*
mutable_desc
()
{
return
desc_
.
get
();
}
const
std
::
list
<
std
::
string
>
&
input_names
()
const
{
return
input_names_
;
}
const
std
::
list
<
std
::
string
>
&
output_names
()
const
{
return
output_names_
;
}
const
std
::
map
<
std
::
string
,
std
::
list
<
std
::
string
>>
&
input_argument
()
const
;
const
std
::
map
<
std
::
string
,
std
::
list
<
std
::
string
>>
&
output_argument
()
const
;
bool
GetInputArgname
(
const
std
::
string
&
value_name
,
std
::
string
*
out
)
const
;
bool
GetOutputArgname
(
const
std
::
string
&
value_name
,
std
::
string
*
out
)
const
;
const
std
::
list
<
std
::
string
>
&
input_argnames
()
const
;
const
std
::
list
<
std
::
string
>
&
output_argnames
()
const
;
private:
void
ExtractInputsAndOutputs
(
const
framework
::
proto
::
OpDesc
&
opdesc
);
void
CollectInputAndOutputArgnames
(
const
framework
::
proto
::
OpDesc
&
opdesc
);
void
CollectArguments
(
const
framework
::
proto
::
OpDesc
&
opdesc
);
private:
std
::
list
<
std
::
string
>
input_names_
;
std
::
list
<
std
::
string
>
output_names_
;
std
::
list
<
std
::
string
>
input_argnames_
;
std
::
list
<
std
::
string
>
output_argnames_
;
std
::
map
<
std
::
string
,
std
::
list
<
std
::
string
>>
input_argument_
;
std
::
map
<
std
::
string
,
std
::
list
<
std
::
string
>>
output_argument_
;
// NOTE too heavy.
std
::
unique_ptr
<
framework
::
proto
::
OpDesc
>
desc_
;
OpInfo
(
const
OpInfo
&
)
=
default
;
OpInfo
(
const
cpp
::
OpDesc
&
other
)
:
cpp
::
OpDesc
(
other
)
{}
// Collect all the input variable's name.
std
::
vector
<
std
::
string
>
input_names
()
const
{
std
::
vector
<
std
::
string
>
res
;
for
(
auto
&
param
:
InputArgumentNames
())
{
for
(
auto
&
x
:
Input
(
param
))
{
res
.
push_back
(
x
);
}
}
return
res
;
}
// Collect all the output variable's name.
std
::
vector
<
std
::
string
>
output_names
()
const
{
std
::
vector
<
std
::
string
>
res
;
for
(
auto
&
param
:
OutputArgumentNames
())
{
for
(
auto
&
x
:
Output
(
param
))
{
res
.
push_back
(
x
);
}
}
return
res
;
}
std
::
vector
<
std
::
string
>
input_argnames
()
const
{
return
InputArgumentNames
();
}
std
::
vector
<
std
::
string
>
output_argnames
()
const
{
return
OutputArgumentNames
();
}
bool
GetInputArgname
(
const
std
::
string
&
value_name
,
std
::
string
*
out
)
const
{
for
(
auto
&
item
:
inputs_
)
{
auto
it
=
std
::
find
(
item
.
second
.
begin
(),
item
.
second
.
end
(),
value_name
);
if
(
it
!=
item
.
second
.
end
())
{
*
out
=
item
.
first
;
return
true
;
}
}
return
false
;
}
bool
GetOutputArgname
(
const
std
::
string
&
value_name
,
std
::
string
*
out
)
const
{
for
(
auto
&
item
:
outputs_
)
{
auto
it
=
std
::
find
(
item
.
second
.
begin
(),
item
.
second
.
end
(),
value_name
);
if
(
it
!=
item
.
second
.
end
())
{
*
out
=
item
.
first
;
return
true
;
}
}
return
false
;
}
};
}
// namespace lite
...
...
paddle/fluid/lite/core/program.cc
浏览文件 @
ae60589f
...
...
@@ -39,11 +39,11 @@ std::string RuntimeProgram::SerializeProgram(
auto
program_dummy
=
desc
;
program_dummy
.
mutable_blocks
(
0
)
->
clear_ops
();
for
(
auto
&
node
:
instructions_
)
{
auto
desc_dummy
=
node
.
op
()
->
op_info
()
->
desc
()
;
OpDesc
desc
(
desc_dummy
);
desc
.
SetAttr
(
kKernelTypeAttr
,
node
.
kernel
()
->
SerializedKernelType
());
pb
::
OpDesc
pb_desc
;
TransformOpDescCppToPb
(
*
node
.
op
()
->
op_info
(),
&
pb_desc
);
pb_
desc
.
SetAttr
(
kKernelTypeAttr
,
node
.
kernel
()
->
SerializedKernelType
());
// append new opdesc
*
program_dummy
.
mutable_blocks
(
0
)
->
add_ops
()
=
*
desc
.
Proto
();
*
program_dummy
.
mutable_blocks
(
0
)
->
add_ops
()
=
*
pb_
desc
.
Proto
();
}
return
program_dummy
.
SerializeAsString
();
}
...
...
paddle/fluid/lite/core/program.h
浏览文件 @
ae60589f
...
...
@@ -22,6 +22,7 @@
#include "paddle/fluid/lite/core/mir/node.h"
#include "paddle/fluid/lite/core/op_lite.h"
#include "paddle/fluid/lite/core/op_registry.h"
#include "paddle/fluid/lite/model_parser/compatible_pb.h"
#ifdef LITE_WITH_PROFILE
#include "paddle/fluid/lite/core/profile/basic_profiler.h"
#endif // LITE_WITH_PROFILE
...
...
@@ -67,7 +68,7 @@ struct Program {
CHECK
(
ops
.
empty
())
<<
"Executor duplicate Build found"
;
// Create operators.
for
(
const
auto
&
proto_op_desc
:
program
.
blocks
(
0
).
ops
())
{
lite
::
OpDesc
op_desc
(
proto_op_desc
);
pb
::
OpDesc
op_desc
(
proto_op_desc
);
auto
op_type
=
op_desc
.
Type
();
// if (op_type == "feed" || op_type == "fetch") continue;
VLOG
(
4
)
<<
"create Op ["
<<
op_type
<<
"]"
;
...
...
@@ -75,7 +76,10 @@ struct Program {
auto
op
=
LiteOpRegistry
::
Global
().
Create
(
op_type
);
CHECK
(
op
)
<<
"no Op found for "
<<
op_type
;
ops
.
emplace_back
(
std
::
move
(
op
));
ops
.
back
()
->
Attach
(
op_desc
,
exec_scope
);
cpp
::
OpDesc
cpp_op_desc
;
TransformOpDescPbToCpp
(
op_desc
,
&
cpp_op_desc
);
ops
.
back
()
->
Attach
(
cpp_op_desc
,
exec_scope
);
}
}
...
...
paddle/fluid/lite/model_parser/CMakeLists.txt
浏览文件 @
ae60589f
...
...
@@ -11,11 +11,7 @@ if(NOT LITE_WITH_LIGHT_WEIGHT_FRAMEWORK)
endif
()
if
(
LITE_WITH_LIGHT_WEIGHT_FRAMEWORK
)
cc_library
(
compatible_pb_lite SRCS compatible_pb.cc DEPS op_desc_lite framework_proto_lite var_desc_lite
)
else
()
cc_library
(
compatible_pb_lite SRCS compatible_pb.cc DEPS framework_proto_lite proto_desc
)
endif
(
LITE_WITH_LIGHT_WEIGHT_FRAMEWORK
)
cc_library
(
compatible_pb_lite SRCS compatible_pb.cc DEPS op_desc_lite framework_proto_lite var_desc_lite
)
set
(
model_parser_deps variable_lite scope_lite
${
tensor_lite
}
scope_lite
target_wrapper_host
...
...
@@ -27,4 +23,7 @@ if (LITE_WITH_CUDA)
endif
()
cc_library
(
model_parser_lite SRCS model_parser.cc DEPS
${
model_parser_deps
}
)
cc_test
(
test_op_desc_lite SRCS op_desc_test.cc DEPS cpp_op_desc_lite any_lite op_desc_lite compatible_pb_lite
)
add_subdirectory
(
pb
)
add_subdirectory
(
cpp
)
paddle/fluid/lite/model_parser/compatible_pb.cc
浏览文件 @
ae60589f
...
...
@@ -13,3 +13,114 @@
// limitations under the License.
#include "paddle/fluid/lite/model_parser/compatible_pb.h"
#include "compatible_pb.h"
namespace
paddle
{
namespace
lite
{
void
InputsPbToCpp
(
const
pb
::
OpDesc
&
pb_desc
,
cpp
::
OpDesc
*
cpp_desc
)
{
for
(
const
std
::
string
&
param
:
pb_desc
.
InputArgumentNames
())
{
cpp_desc
->
SetInput
(
param
,
pb_desc
.
Input
(
param
));
}
}
void
InputsCppToPb
(
const
cpp
::
OpDesc
&
cpp_desc
,
pb
::
OpDesc
*
pb_desc
)
{
for
(
const
std
::
string
&
param
:
cpp_desc
.
InputArgumentNames
())
{
pb_desc
->
SetInput
(
param
,
cpp_desc
.
Input
(
param
));
}
}
void
OutputsPbToCpp
(
const
pb
::
OpDesc
&
pb_desc
,
cpp
::
OpDesc
*
cpp_desc
)
{
for
(
const
std
::
string
&
param
:
pb_desc
.
OutputArgumentNames
())
{
cpp_desc
->
SetOutput
(
param
,
pb_desc
.
Output
(
param
));
}
}
void
OutputsCppToPb
(
const
cpp
::
OpDesc
&
cpp_desc
,
pb
::
OpDesc
*
pb_desc
)
{
for
(
const
std
::
string
&
param
:
cpp_desc
.
OutputArgumentNames
())
{
pb_desc
->
SetOutput
(
param
,
cpp_desc
.
Output
(
param
));
}
}
void
AttrsPbToCpp
(
const
pb
::
OpDesc
&
pb_desc
,
cpp
::
OpDesc
*
cpp_desc
)
{
using
AttrType
=
OpDescAPI
::
AttrType
;
auto
set_attr
=
[
&
](
const
std
::
string
&
name
,
AttrType
type
)
{
switch
(
type
)
{
case
AttrType
::
INT
:
cpp_desc
->
SetAttr
<
int32_t
>
(
name
,
pb_desc
.
GetAttr
<
int32_t
>
(
name
));
break
;
case
AttrType
::
FLOAT
:
cpp_desc
->
SetAttr
<
float
>
(
name
,
pb_desc
.
GetAttr
<
float
>
(
name
));
break
;
case
AttrType
::
STRING
:
cpp_desc
->
SetAttr
<
std
::
string
>
(
name
,
pb_desc
.
GetAttr
<
std
::
string
>
(
name
));
break
;
case
AttrType
::
INTS
:
cpp_desc
->
SetAttr
<
std
::
vector
<
int
>>
(
name
,
pb_desc
.
GetAttr
<
std
::
vector
<
int
>>
(
name
));
break
;
case
AttrType
::
FLOATS
:
cpp_desc
->
SetAttr
<
std
::
vector
<
float
>>
(
name
,
pb_desc
.
GetAttr
<
std
::
vector
<
float
>>
(
name
));
break
;
case
AttrType
::
BOOLEAN
:
cpp_desc
->
SetAttr
<
bool
>
(
name
,
pb_desc
.
GetAttr
<
bool
>
(
name
));
break
;
case
AttrType
::
STRINGS
:
cpp_desc
->
SetAttr
<
std
::
vector
<
std
::
string
>>
(
name
,
pb_desc
.
GetAttr
<
std
::
vector
<
std
::
string
>>
(
name
));
break
;
default:
LOG
(
FATAL
)
<<
"Unsupported attr type found "
<<
static_cast
<
int
>
(
type
);
}
};
for
(
const
auto
&
attr_name
:
pb_desc
.
AttrNames
())
{
auto
type
=
pb_desc
.
GetAttrType
(
attr_name
);
set_attr
(
attr_name
,
type
);
}
}
void
AttrsCppToPb
(
const
cpp
::
OpDesc
&
cpp_desc
,
pb
::
OpDesc
*
pb_desc
)
{
using
AttrType
=
OpDescAPI
::
AttrType
;
auto
set_attr
=
[
&
](
const
std
::
string
&
name
,
AttrType
type
)
{
switch
(
type
)
{
#define IMPL_ONE(type__, T) \
case AttrType::type__: \
pb_desc->SetAttr<T>(name, cpp_desc.GetAttr<T>(name)); \
break;
IMPL_ONE
(
INT
,
int32_t
);
IMPL_ONE
(
FLOAT
,
float
);
IMPL_ONE
(
STRING
,
std
::
string
);
IMPL_ONE
(
STRINGS
,
std
::
vector
<
std
::
string
>
);
IMPL_ONE
(
FLOATS
,
std
::
vector
<
float
>
);
IMPL_ONE
(
INTS
,
std
::
vector
<
int
>
);
IMPL_ONE
(
BOOLEAN
,
bool
);
default:
LOG
(
FATAL
)
<<
"Unsupported attr type found: "
<<
static_cast
<
int
>
(
type
);
}
};
#undef IMPL_ONE
for
(
const
auto
&
attr_name
:
cpp_desc
.
AttrNames
())
{
auto
type
=
cpp_desc
.
GetAttrType
(
attr_name
);
set_attr
(
attr_name
,
type
);
}
}
void
TransformOpDescPbToCpp
(
const
pb
::
OpDesc
&
pb_desc
,
cpp
::
OpDesc
*
cpp_desc
)
{
cpp_desc
->
SetType
(
pb_desc
.
Type
());
InputsPbToCpp
(
pb_desc
,
cpp_desc
);
OutputsPbToCpp
(
pb_desc
,
cpp_desc
);
AttrsPbToCpp
(
pb_desc
,
cpp_desc
);
}
void
TransformOpDescCppToPb
(
const
cpp
::
OpDesc
&
cpp_desc
,
pb
::
OpDesc
*
pb_desc
)
{
pb_desc
->
SetType
(
cpp_desc
.
Type
());
InputsCppToPb
(
cpp_desc
,
pb_desc
);
OutputsCppToPb
(
cpp_desc
,
pb_desc
);
AttrsCppToPb
(
cpp_desc
,
pb_desc
);
}
}
// namespace lite
}
// namespace paddle
paddle/fluid/lite/model_parser/compatible_pb.h
浏览文件 @
ae60589f
...
...
@@ -20,39 +20,28 @@
* lite::pb::XXDesc.
*/
#ifdef LITE_WITH_LIGHT_WEIGHT_FRAMEWORK
#include "paddle/fluid/lite/core/framework.pb.h"
#include "paddle/fluid/lite/model_parser/cpp/op_desc.h"
#include "paddle/fluid/lite/model_parser/pb/op_desc.h"
#include "paddle/fluid/lite/model_parser/pb/var_desc.h"
#else
#include "paddle/fluid/framework/op_desc.h"
#include "paddle/fluid/framework/var_desc.h"
#endif // LITE_WITH_LIGHT_WEIGHT_FRAMEWORK
namespace
paddle
{
namespace
lite
{
#ifdef LITE_WITH_LIGHT_WEIGHT_FRAMEWORK
using
Attribute
=
lite
::
pb
::
Attribute
;
using
OpDesc
=
lite
::
pb
::
OpDesc
;
using
VarDesc
=
lite
::
pb
::
VarDesc
;
#else // LITE_WITH_LIGHT_WEIGHT_FRAMEWORK
using
Attribute
=
framework
::
Attribute
;
using
OpDesc
=
framework
::
OpDesc
;
using
VarDesc
=
framework
::
VarDesc
;
#endif // LITE_WITH_LIGHT_WEIGHT_FRAMEWORK
#ifdef LITE_WITH_LIGHT_WEIGHT_FRAMEWORK
template
<
typename
T
>
T
GetAttr
(
const
Attribute
&
x
)
{
return
x
.
get
<
T
>
();
}
#else
template
<
typename
T
>
T
GetAttr
(
const
Attribute
&
x
)
{
return
boost
::
get
<
T
>
(
x
);
}
#endif // LITE_WITH_LIGHT_WEIGHT_FRAMEWORK
/// Transform an OpDesc from pb to cpp format.
void
TransformOpDescPbToCpp
(
const
pb
::
OpDesc
&
pb_desc
,
cpp
::
OpDesc
*
cpp_desc
);
/// Transform an OpDesc from cpp to pb format.
void
TransformOpDescCppToPb
(
const
cpp
::
OpDesc
&
cpp_desc
,
pb
::
OpDesc
*
pb_desc
);
}
// namespace lite
}
// namespace paddle
paddle/fluid/lite/model_parser/cpp/CMakeLists.txt
0 → 100644
浏览文件 @
ae60589f
cc_library
(
cpp_op_desc_lite SRCS op_desc.cc DEPS any_lite
)
paddle/fluid/lite/model_parser/cpp/op_desc.cc
0 → 100644
浏览文件 @
ae60589f
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/lite/model_parser/cpp/op_desc.h"
#include <set>
namespace
paddle
{
namespace
lite
{
namespace
cpp
{
#define SET_ATTR_IMPL(T, repr__) \
template <> \
void OpDesc::SetAttr<T>(const std::string& name, const T& v) { \
attr_types_[name] = AttrType::repr__; \
attrs_[name].set<T>(v); \
}
SET_ATTR_IMPL
(
int32_t
,
INT
);
SET_ATTR_IMPL
(
float
,
FLOAT
);
SET_ATTR_IMPL
(
std
::
string
,
STRING
);
SET_ATTR_IMPL
(
bool
,
BOOLEAN
);
SET_ATTR_IMPL
(
std
::
vector
<
int
>
,
INTS
);
SET_ATTR_IMPL
(
std
::
vector
<
float
>
,
FLOATS
);
SET_ATTR_IMPL
(
std
::
vector
<
std
::
string
>
,
STRINGS
);
std
::
pair
<
OpDesc
::
attrs_t
::
const_iterator
,
OpDesc
::
attr_types_t
::
const_iterator
>
FindAttr
(
const
cpp
::
OpDesc
&
desc
,
const
std
::
string
&
name
)
{
auto
it
=
desc
.
attrs
().
find
(
name
);
CHECK
(
it
!=
desc
.
attrs
().
end
())
<<
"No attributes called "
<<
name
<<
" found"
;
auto
attr_it
=
desc
.
attr_types
().
find
(
name
);
CHECK
(
attr_it
!=
desc
.
attr_types
().
end
());
return
std
::
make_pair
(
it
,
attr_it
);
}
#define GET_IMPL_ONE(T, repr__) \
template <> \
T OpDesc::GetAttr<T>(const std::string& name) const { \
auto pair = FindAttr(*this, name); \
CHECK(pair.second->second == AttrType::repr__); \
return pair.first->second.get<T>(); \
}
GET_IMPL_ONE
(
int32_t
,
INT
);
GET_IMPL_ONE
(
float
,
FLOAT
);
GET_IMPL_ONE
(
std
::
string
,
STRING
);
GET_IMPL_ONE
(
bool
,
BOOLEAN
);
GET_IMPL_ONE
(
std
::
vector
<
int64_t
>
,
LONGS
);
GET_IMPL_ONE
(
std
::
vector
<
float
>
,
FLOATS
);
GET_IMPL_ONE
(
std
::
vector
<
int
>
,
INTS
);
GET_IMPL_ONE
(
std
::
vector
<
std
::
string
>
,
STRINGS
);
}
// namespace cpp
}
// namespace lite
}
// namespace paddle
paddle/fluid/lite/model_parser/cpp/op_desc.h
0 → 100644
浏览文件 @
ae60589f
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <map>
#include <string>
#include <vector>
#include "paddle/fluid/lite/model_parser/desc_apis.h"
#include "paddle/fluid/lite/utils/any.h"
#include "paddle/fluid/lite/utils/varient.h"
namespace
paddle
{
namespace
lite
{
namespace
cpp
{
/*
* The cpp::OpDesc is the internal representation for Op. All the internal
* imprementation should use it, not the pb::OpDesc.
*/
class
OpDesc
:
public
OpDescAPI
{
public:
using
attrs_t
=
std
::
map
<
std
::
string
,
Any
>
;
using
attr_types_t
=
std
::
map
<
std
::
string
,
AttrType
>
;
protected:
std
::
string
type_
;
std
::
map
<
std
::
string
,
std
::
vector
<
std
::
string
>>
inputs_
;
std
::
map
<
std
::
string
,
std
::
vector
<
std
::
string
>>
outputs_
;
std
::
map
<
std
::
string
,
Any
>
attrs_
;
std
::
map
<
std
::
string
,
AttrType
>
attr_types_
;
public:
OpDesc
()
=
default
;
std
::
string
Type
()
const
override
{
return
type_
;
}
void
SetType
(
const
std
::
string
&
x
)
override
{
type_
=
x
;
}
const
std
::
map
<
std
::
string
,
std
::
vector
<
std
::
string
>>&
inputs
()
const
{
return
inputs_
;
}
const
std
::
map
<
std
::
string
,
std
::
vector
<
std
::
string
>>&
outputs
()
const
{
return
outputs_
;
}
std
::
map
<
std
::
string
,
std
::
vector
<
std
::
string
>>*
mutable_inputs
()
{
return
&
inputs_
;
}
std
::
map
<
std
::
string
,
std
::
vector
<
std
::
string
>>*
mutable_outputs
()
{
return
&
outputs_
;
}
std
::
vector
<
std
::
string
>
Input
(
const
std
::
string
&
param
)
const
override
{
auto
it
=
inputs_
.
find
(
param
);
CHECK
(
it
!=
inputs_
.
end
());
return
it
->
second
;
}
std
::
vector
<
std
::
string
>
InputArgumentNames
()
const
override
{
std
::
vector
<
std
::
string
>
res
;
for
(
const
auto
&
x
:
inputs_
)
res
.
push_back
(
x
.
first
);
return
res
;
}
std
::
vector
<
std
::
string
>
OutputArgumentNames
()
const
override
{
std
::
vector
<
std
::
string
>
res
;
for
(
const
auto
&
x
:
outputs_
)
res
.
push_back
(
x
.
first
);
return
res
;
}
std
::
vector
<
std
::
string
>
Output
(
const
std
::
string
&
param
)
const
override
{
auto
it
=
outputs_
.
find
(
param
);
CHECK
(
it
!=
outputs_
.
end
());
return
it
->
second
;
}
void
SetInput
(
const
std
::
string
&
param
,
const
std
::
vector
<
std
::
string
>&
args
)
override
{
inputs_
[
param
]
=
args
;
}
void
SetOutput
(
const
std
::
string
&
param
,
const
std
::
vector
<
std
::
string
>&
args
)
override
{
outputs_
[
param
]
=
args
;
}
bool
HasAttr
(
const
std
::
string
&
name
)
const
override
{
return
attrs_
.
count
(
name
);
}
AttrType
GetAttrType
(
const
std
::
string
&
name
)
const
override
{
auto
it
=
attr_types_
.
find
(
name
);
CHECK
(
it
!=
attr_types_
.
end
());
return
it
->
second
;
}
std
::
vector
<
std
::
string
>
AttrNames
()
const
override
{
std
::
vector
<
std
::
string
>
res
;
for
(
const
auto
&
x
:
attrs_
)
{
res
.
push_back
(
x
.
first
);
}
return
res
;
}
template
<
typename
T
>
void
SetAttr
(
const
std
::
string
&
name
,
const
T
&
v
);
template
<
typename
T
>
T
GetAttr
(
const
std
::
string
&
name
)
const
;
const
std
::
map
<
std
::
string
,
Any
>&
attrs
()
const
{
return
attrs_
;
}
const
std
::
map
<
std
::
string
,
AttrType
>&
attr_types
()
const
{
return
attr_types_
;
}
};
}
// namespace cpp
}
// namespace lite
}
// namespace paddle
paddle/fluid/lite/model_parser/desc_apis.h
0 → 100644
浏览文件 @
ae60589f
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <map>
#include <string>
#include <vector>
namespace
paddle
{
namespace
lite
{
/*
* Compatible interfaces for all the different kinds of opdesc. All the OpDesc
* classes should implement this.
* NOTE Some interfaces are weried, we remain them unchanged to keep compatible
* with framework::OpDesc in Fluid framework.
*/
class
OpDescAPI
{
public:
// The AttrType is used to make the proto::AttrType portable.
enum
class
AttrType
{
INT
=
0
,
FLOAT
=
1
,
STRING
=
2
,
INTS
=
3
,
FLOATS
=
4
,
STRINGS
=
5
,
BOOLEAN
=
6
,
BOOLEANS
=
7
,
BLOCK
=
8
,
LONG
=
9
,
BLOCKS
=
10
,
LONGS
=
11
,
UNK
,
};
virtual
~
OpDescAPI
()
=
default
;
/// Get operator's type.
virtual
std
::
string
Type
()
const
=
0
;
/// Set operator's type.
virtual
void
SetType
(
const
std
::
string
&
type
)
=
0
;
/// Get arguments given the parameter.
virtual
std
::
vector
<
std
::
string
>
Input
(
const
std
::
string
&
param
)
const
=
0
;
/// Get parameters.
virtual
std
::
vector
<
std
::
string
>
InputArgumentNames
()
const
=
0
;
/// Get arguments given the parameter.
virtual
std
::
vector
<
std
::
string
>
Output
(
const
std
::
string
&
param
)
const
=
0
;
/// Get parameters.
virtual
std
::
vector
<
std
::
string
>
OutputArgumentNames
()
const
=
0
;
/// Set a input given the parameter and arguments.
virtual
void
SetInput
(
const
std
::
string
&
param
,
const
std
::
vector
<
std
::
string
>&
args
)
=
0
;
virtual
void
SetOutput
(
const
std
::
string
&
param
,
const
std
::
vector
<
std
::
string
>&
args
)
=
0
;
/// Tell whether this desc has an attribute.
virtual
bool
HasAttr
(
const
std
::
string
&
name
)
const
=
0
;
/// Get the type of an attribute.
virtual
AttrType
GetAttrType
(
const
std
::
string
&
name
)
const
=
0
;
virtual
std
::
vector
<
std
::
string
>
AttrNames
()
const
=
0
;
/// Set an attribute.
template
<
typename
T
>
void
SetAttr
(
const
std
::
string
&
name
,
const
T
&
v
);
/// Get an attribute.
template
<
typename
T
>
T
GetAttr
(
const
std
::
string
&
name
)
const
;
};
}
// namespace lite
}
// namespace paddle
paddle/fluid/lite/model_parser/op_desc_test.cc
0 → 100644
浏览文件 @
ae60589f
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/lite/model_parser/cpp/op_desc.h"
#include <gtest/gtest.h>
#include "paddle/fluid/lite/model_parser/compatible_pb.h"
#include "paddle/fluid/lite/model_parser/pb/op_desc.h"
namespace
paddle
{
namespace
lite
{
template
<
typename
OpDesc
>
void
TestX
()
{
OpDesc
desc
;
desc
.
SetInput
(
"X"
,
{
"a"
,
"b"
});
auto
X
=
desc
.
Input
(
"X"
);
ASSERT_EQ
(
X
.
size
(),
2UL
);
ASSERT_EQ
(
X
[
0
],
"a"
);
ASSERT_EQ
(
X
[
1
],
"b"
);
desc
.
SetOutput
(
"Y"
,
{
"c"
,
"d"
});
auto
Y
=
desc
.
Output
(
"Y"
);
ASSERT_EQ
(
Y
.
size
(),
2UL
);
ASSERT_EQ
(
Y
[
0
],
"c"
);
ASSERT_EQ
(
Y
[
1
],
"d"
);
desc
.
template
SetAttr
<
int32_t
>(
"aint"
,
100
);
ASSERT_TRUE
(
desc
.
HasAttr
(
"aint"
));
ASSERT_FALSE
(
desc
.
HasAttr
(
"afloat"
));
ASSERT_EQ
(
desc
.
template
GetAttr
<
int32_t
>(
"aint"
),
100
);
}
TEST
(
OpDesc
,
Basic
)
{
TestX
<
pb
::
OpDesc
>
();
TestX
<
cpp
::
OpDesc
>
();
}
TEST
(
OpDesc
,
CppToPb
)
{
cpp
::
OpDesc
desc
;
desc
.
SetInput
(
"X"
,
{
"a"
,
"b"
});
desc
.
SetOutput
(
"Y"
,
{
"c"
,
"d"
});
desc
.
template
SetAttr
<
int32_t
>(
"aint"
,
100
);
pb
::
OpDesc
pb_desc
;
TransformOpDescCppToPb
(
desc
,
&
pb_desc
);
{
auto
&
desc
=
pb_desc
;
auto
X
=
desc
.
Input
(
"X"
);
ASSERT_EQ
(
X
.
size
(),
2UL
);
ASSERT_EQ
(
X
[
0
],
"a"
);
ASSERT_EQ
(
X
[
1
],
"b"
);
auto
Y
=
desc
.
Output
(
"Y"
);
ASSERT_EQ
(
Y
.
size
(),
2UL
);
ASSERT_EQ
(
Y
[
0
],
"c"
);
ASSERT_EQ
(
Y
[
1
],
"d"
);
ASSERT_TRUE
(
desc
.
HasAttr
(
"aint"
));
ASSERT_FALSE
(
desc
.
HasAttr
(
"afloat"
));
ASSERT_EQ
(
desc
.
template
GetAttr
<
int32_t
>(
"aint"
),
100
);
}
}
TEST
(
OpDesc
,
PbToCpp
)
{
pb
::
OpDesc
desc
;
desc
.
SetInput
(
"X"
,
{
"a"
,
"b"
});
desc
.
SetOutput
(
"Y"
,
{
"c"
,
"d"
});
desc
.
template
SetAttr
<
int32_t
>(
"aint"
,
100
);
cpp
::
OpDesc
cpp_desc
;
TransformOpDescPbToCpp
(
desc
,
&
cpp_desc
);
{
auto
&
desc
=
cpp_desc
;
auto
X
=
desc
.
Input
(
"X"
);
ASSERT_EQ
(
X
.
size
(),
2UL
);
ASSERT_EQ
(
X
[
0
],
"a"
);
ASSERT_EQ
(
X
[
1
],
"b"
);
auto
Y
=
desc
.
Output
(
"Y"
);
ASSERT_EQ
(
Y
.
size
(),
2UL
);
ASSERT_EQ
(
Y
[
0
],
"c"
);
ASSERT_EQ
(
Y
[
1
],
"d"
);
ASSERT_TRUE
(
desc
.
HasAttr
(
"aint"
));
ASSERT_FALSE
(
desc
.
HasAttr
(
"afloat"
));
ASSERT_EQ
(
desc
.
template
GetAttr
<
int32_t
>(
"aint"
),
100
);
}
}
}
// namespace lite
}
// namespace paddle
paddle/fluid/lite/model_parser/pb/op_desc.cc
浏览文件 @
ae60589f
...
...
@@ -18,10 +18,9 @@ namespace paddle {
namespace
lite
{
namespace
pb
{
template
<
>
void
OpDesc
::
SetAttr
<
std
::
string
>
(
const
std
::
string
&
name
,
const
std
::
string
&
v
)
{
auto
&
xs
=
*
desc_
.
mutable_attrs
();
google
::
protobuf
::
internal
::
RepeatedPtrIterator
<
framework
::
proto
::
OpDesc_Attr
>
FindAttr
(
framework
::
proto
::
OpDesc
*
desc
,
const
std
::
string
&
name
)
{
auto
&
xs
=
*
desc
->
mutable_attrs
();
auto
it
=
std
::
find_if
(
xs
.
begin
(),
xs
.
end
(),
[
&
](
const
framework
::
proto
::
OpDesc_Attr
&
x
)
{
return
x
.
name
()
==
name
;
});
...
...
@@ -33,33 +32,95 @@ void OpDesc::SetAttr<std::string>(const std::string &name,
return
x
.
name
()
==
name
;
});
}
return
it
;
}
#define SET_IMPL_ONE(T, ty__, pb_f__) \
template <> \
void OpDesc::SetAttr<T>(const std::string &name, const T &v) { \
auto it = FindAttr(&desc_, name); \
it->set_type(framework::proto::ty__); \
it->set_##pb_f__(v); \
}
SET_IMPL_ONE
(
int
,
INT
,
i
);
SET_IMPL_ONE
(
float
,
FLOAT
,
f
);
SET_IMPL_ONE
(
bool
,
FLOAT
,
f
);
template
<
>
void
OpDesc
::
SetAttr
<
std
::
vector
<
int
>>
(
const
std
::
string
&
name
,
const
std
::
vector
<
int
>
&
v
)
{
auto
it
=
FindAttr
(
&
desc_
,
name
);
it
->
set_type
(
framework
::
proto
::
INTS
);
it
->
clear_ints
();
for
(
auto
&
i
:
v
)
{
it
->
add_ints
(
i
);
}
}
template
<
>
void
OpDesc
::
SetAttr
<
std
::
string
>
(
const
std
::
string
&
name
,
const
std
::
string
&
v
)
{
auto
it
=
FindAttr
(
&
desc_
,
name
);
it
->
set_type
(
framework
::
proto
::
STRING
);
it
->
set_s
(
v
.
c_str
());
}
template
<
>
void
OpDesc
::
SetAttr
<
std
::
vector
<
int
>>
(
const
std
::
string
&
name
,
const
std
::
vector
<
int
>
&
v
)
{
auto
&
xs
=
*
desc_
.
mutable_attrs
();
void
OpDesc
::
SetAttr
<
std
::
vector
<
float
>>
(
const
std
::
string
&
name
,
const
std
::
vector
<
float
>
&
v
)
{
auto
it
=
FindAttr
(
&
desc_
,
name
);
it
->
set_type
(
framework
::
proto
::
FLOATS
);
it
->
clear_floats
();
for
(
auto
&
i
:
v
)
{
it
->
add_floats
(
i
);
}
}
template
<
>
void
OpDesc
::
SetAttr
<
std
::
vector
<
std
::
string
>>
(
const
std
::
string
&
name
,
const
std
::
vector
<
std
::
string
>
&
v
)
{
auto
it
=
FindAttr
(
&
desc_
,
name
);
it
->
set_type
(
framework
::
proto
::
STRINGS
);
it
->
clear_strings
();
for
(
auto
&
i
:
v
)
{
it
->
add_strings
(
i
);
}
}
google
::
protobuf
::
internal
::
RepeatedPtrIterator
<
const
framework
::
proto
::
OpDesc_Attr
>
GetFindAttr
(
const
framework
::
proto
::
OpDesc
&
desc
,
const
std
::
string
&
name
)
{
auto
&
xs
=
desc
.
attrs
();
auto
it
=
std
::
find_if
(
xs
.
begin
(),
xs
.
end
(),
[
&
](
const
framework
::
proto
::
OpDesc_Attr
&
x
)
{
return
x
.
name
()
==
name
;
});
if
(
it
==
xs
.
end
())
{
auto
*
attr
=
xs
.
Add
();
attr
->
set_name
(
name
);
it
=
std
::
find_if
(
xs
.
begin
(),
xs
.
end
(),
[
&
](
const
framework
::
proto
::
OpDesc_Attr
&
x
)
{
return
x
.
name
()
==
name
;
});
return
it
;
}
#define GET_ATTR_IMPL(T, pb_f__) \
template <> \
T OpDesc::GetAttr<T>(const std::string &name) const { \
auto it = GetFindAttr(desc_, name); \
return it->pb_f__(); \
}
it
->
set_type
(
framework
::
proto
::
INTS
);
it
->
clear_ints
();
for
(
auto
&
i
:
v
)
{
it
->
add_ints
(
i
);
#define GET_ATTRS_IMPL(T, pb_f__) \
template <> \
T OpDesc::GetAttr<T>(const std::string &name) const { \
auto it = GetFindAttr(desc_, name); \
T res; \
for (const auto &v : it->pb_f__()) { \
res.push_back(v); \
} \
return res; \
}
}
GET_ATTR_IMPL
(
int32_t
,
i
);
GET_ATTR_IMPL
(
float
,
f
);
GET_ATTR_IMPL
(
bool
,
b
);
GET_ATTRS_IMPL
(
std
::
vector
<
int
>
,
ints
);
GET_ATTRS_IMPL
(
std
::
vector
<
float
>
,
floats
);
GET_ATTRS_IMPL
(
std
::
vector
<
std
::
string
>
,
strings
);
GET_ATTR_IMPL
(
std
::
string
,
s
);
}
// namespace pb
}
// namespace lite
...
...
paddle/fluid/lite/model_parser/pb/op_desc.h
浏览文件 @
ae60589f
...
...
@@ -27,6 +27,7 @@
#include <unordered_map>
#include <vector>
#include "paddle/fluid/lite/core/framework.pb.h"
#include "paddle/fluid/lite/model_parser/desc_apis.h"
#include "paddle/fluid/lite/utils/all.h"
namespace
paddle
{
...
...
@@ -43,7 +44,7 @@ using VariableNameMap = std::map<std::string, std::vector<std::string>>;
* except the desc_, to avoid the inconsistent state, which is normal in the
* original interface and results in bugs.
*/
class
OpDesc
{
class
OpDesc
:
public
OpDescAPI
{
public:
OpDesc
()
{}
...
...
@@ -54,38 +55,38 @@ class OpDesc {
framework
::
proto
::
OpDesc
*
Proto
()
{
return
&
desc_
;
}
const
framework
::
proto
::
OpDesc
&
ReadonlyProto
()
const
{
return
desc_
;
}
std
::
string
Type
()
const
{
return
desc_
.
type
();
}
std
::
string
Type
()
const
override
{
return
desc_
.
type
();
}
void
SetType
(
const
std
::
string
&
type
)
{
desc_
.
set_type
(
type
);
}
void
SetType
(
const
std
::
string
&
type
)
override
{
desc_
.
set_type
(
type
);
}
// Get the arguments of parameter called `param`
std
::
vector
<
std
::
string
>
Input
(
const
std
::
string
&
param
)
const
{
std
::
vector
<
std
::
string
>
Input
(
const
std
::
string
&
param
)
const
override
{
return
GetArguments
(
desc_
.
inputs
(),
param
);
}
std
::
vector
<
std
::
string
>
InputArgumentNames
()
const
{
std
::
vector
<
std
::
string
>
InputArgumentNames
()
const
override
{
return
GetArgumentNames
(
desc_
.
inputs
());
}
void
SetInput
(
const
std
::
string
&
param
,
const
std
::
vector
<
std
::
string
>
&
args
)
{
const
std
::
vector
<
std
::
string
>
&
args
)
override
{
SetArgument
(
desc_
.
mutable_inputs
(),
param
,
args
);
}
std
::
vector
<
std
::
string
>
Output
(
const
std
::
string
&
param
)
const
{
std
::
vector
<
std
::
string
>
Output
(
const
std
::
string
&
param
)
const
override
{
return
GetArguments
(
desc_
.
outputs
(),
param
);
}
std
::
vector
<
std
::
string
>
OutputArgumentNames
()
const
{
std
::
vector
<
std
::
string
>
OutputArgumentNames
()
const
override
{
return
GetArgumentNames
(
desc_
.
outputs
());
}
void
SetOutput
(
const
std
::
string
&
param
,
const
std
::
vector
<
std
::
string
>
&
args
)
{
const
std
::
vector
<
std
::
string
>
&
args
)
override
{
SetArgument
(
desc_
.
mutable_outputs
(),
param
,
args
);
}
bool
HasAttr
(
const
std
::
string
&
name
)
const
{
bool
HasAttr
(
const
std
::
string
&
name
)
const
override
{
const
auto
&
xs
=
desc_
.
attrs
();
auto
it
=
std
::
find_if
(
xs
.
begin
(),
xs
.
end
(),
[
&
](
const
framework
::
proto
::
OpDesc_Attr
&
x
)
{
...
...
@@ -94,17 +95,38 @@ class OpDesc {
return
it
!=
xs
.
end
();
}
framework
::
proto
::
AttrType
GetAttrType
(
const
std
::
string
&
name
)
const
{
AttrType
GetAttrType
(
const
std
::
string
&
name
)
const
override
{
const
auto
&
xs
=
desc_
.
attrs
();
auto
it
=
std
::
find_if
(
xs
.
begin
(),
xs
.
end
(),
[
&
](
const
framework
::
proto
::
OpDesc_Attr
&
x
)
{
return
x
.
name
()
==
name
;
});
CHECK
(
it
!=
xs
.
end
());
return
it
->
type
();
#define DEF_ONE(type__) \
case framework::proto::AttrType::type__: \
return AttrType::type__;
switch
(
it
->
type
())
{
DEF_ONE
(
INT
);
DEF_ONE
(
FLOAT
);
DEF_ONE
(
STRING
);
DEF_ONE
(
INTS
);
DEF_ONE
(
FLOATS
);
DEF_ONE
(
STRINGS
);
DEF_ONE
(
BOOLEAN
);
DEF_ONE
(
BOOLEANS
);
DEF_ONE
(
BLOCK
);
DEF_ONE
(
LONG
);
DEF_ONE
(
BLOCKS
);
DEF_ONE
(
LONGS
);
default:
LOG
(
ERROR
)
<<
"Unknown attribute type"
;
return
AttrType
::
UNK
;
}
#undef DEF_ONE
}
std
::
vector
<
std
::
string
>
AttrNames
()
const
{
std
::
vector
<
std
::
string
>
AttrNames
()
const
override
{
std
::
vector
<
std
::
string
>
res
;
const
auto
&
xs
=
desc_
.
attrs
();
std
::
transform
(
...
...
@@ -114,72 +136,10 @@ class OpDesc {
}
template
<
typename
T
>
void
SetAttr
(
const
std
::
string
&
name
,
const
T
&
v
)
{
auto
&
xs
=
*
desc_
.
mutable_attrs
();
auto
it
=
std
::
find_if
(
xs
.
begin
(),
xs
.
end
(),
[
&
](
const
framework
::
proto
::
OpDesc_Attr
&
x
)
{
return
x
.
name
()
==
name
;
});
if
(
it
==
xs
.
end
())
{
auto
*
attr
=
xs
.
Add
();
attr
->
set_name
(
name
);
it
=
std
::
find_if
(
xs
.
begin
(),
xs
.
end
(),
[
&
](
const
framework
::
proto
::
OpDesc_Attr
&
x
)
{
return
x
.
name
()
==
name
;
});
}
size_t
hash
=
typeid
(
T
).
hash_code
();
if
(
hash
==
typeid
(
int
).
hash_code
())
{
// NOLINT
it
->
set_type
(
framework
::
proto
::
INT
);
it
->
set_i
(
v
);
}
else
if
(
hash
==
typeid
(
float
).
hash_code
())
{
// NOLINT
it
->
set_type
(
framework
::
proto
::
FLOAT
);
it
->
set_f
(
v
);
}
else
if
(
hash
==
typeid
(
bool
).
hash_code
())
{
// NOLINT
it
->
set_type
(
framework
::
proto
::
BOOLEAN
);
it
->
set_b
(
v
);
}
else
{
LOG
(
FATAL
)
<<
"unsupport attr type"
;
}
}
Attribute
GetAttr
(
const
std
::
string
&
name
)
const
{
auto
&
xs
=
desc_
.
attrs
();
auto
it
=
std
::
find_if
(
xs
.
begin
(),
xs
.
end
(),
[
&
](
const
framework
::
proto
::
OpDesc_Attr
&
x
)
{
return
x
.
name
()
==
name
;
});
void
SetAttr
(
const
std
::
string
&
name
,
const
T
&
v
);
Attribute
res
;
CHECK
(
it
!=
xs
.
end
());
switch
(
it
->
type
())
{
case
framework
::
proto
::
INT
:
res
.
set
<
int
>
(
it
->
i
());
break
;
case
framework
::
proto
::
FLOAT
:
res
.
set
<
float
>
(
it
->
f
());
break
;
case
framework
::
proto
::
STRING
:
res
.
set
<
std
::
string
>
(
it
->
s
());
break
;
case
framework
::
proto
::
BOOLEAN
:
res
.
set
<
bool
>
(
it
->
b
());
break
;
case
framework
::
proto
::
INTS
:
{
std
::
vector
<
int
>
values
;
const
auto
&
ys
=
it
->
ints
();
std
::
transform
(
ys
.
begin
(),
ys
.
end
(),
std
::
back_inserter
(
values
),
[](
const
int
&
x
)
{
return
x
;
});
res
.
set
<
std
::
vector
<
int
>>
(
values
);
}
break
;
default:
LOG
(
FATAL
)
<<
"unsupported attr type"
;
}
return
res
;
}
template
<
typename
T
>
T
GetAttr
(
const
std
::
string
&
name
)
const
;
private:
std
::
vector
<
std
::
string
>
GetArguments
(
...
...
paddle/fluid/lite/operators/activation_ops.cc
浏览文件 @
ae60589f
...
...
@@ -33,7 +33,7 @@ class ActivationOp : public OpLite {
return
true
;
}
bool
AttachImpl
(
const
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
override
{
bool
AttachImpl
(
const
cpp
::
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
override
{
auto
X_name
=
opdesc
.
Input
(
"X"
).
front
();
auto
Out_name
=
opdesc
.
Output
(
"Out"
).
front
();
...
...
@@ -66,7 +66,7 @@ class ActivationGradOp : public OpLite {
return
true
;
}
bool
AttachImpl
(
const
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
override
{
bool
AttachImpl
(
const
cpp
::
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
override
{
auto
Out_grad_name
=
opdesc
.
Input
(
framework
::
GradVarName
(
"Out"
)).
front
();
auto
X_grad_name
=
opdesc
.
Output
(
framework
::
GradVarName
(
"X"
)).
front
();
...
...
paddle/fluid/lite/operators/concat_op.cc
浏览文件 @
ae60589f
...
...
@@ -54,7 +54,7 @@ bool ConcatOpLite::InferShape() const {
}
// TODO(Superjomn) replace framework::OpDesc with a lite one.
bool
ConcatOpLite
::
AttachImpl
(
const
OpDesc
&
op_desc
,
lite
::
Scope
*
scope
)
{
bool
ConcatOpLite
::
AttachImpl
(
const
cpp
::
OpDesc
&
op_desc
,
lite
::
Scope
*
scope
)
{
auto
inputs
=
op_desc
.
Input
(
"X"
);
auto
out
=
op_desc
.
Output
(
"Out"
).
front
();
...
...
@@ -63,7 +63,7 @@ bool ConcatOpLite::AttachImpl(const OpDesc &op_desc, lite::Scope *scope) {
}
CHECK
(
scope
->
FindVar
(
out
));
param_
.
output
=
scope
->
FindVar
(
out
)
->
GetMutable
<
lite
::
Tensor
>
();
param_
.
axis
=
GetAttr
<
int
>
(
op_desc
.
GetAttr
(
"axis"
)
);
param_
.
axis
=
op_desc
.
GetAttr
<
int
>
(
"axis"
);
return
true
;
}
...
...
paddle/fluid/lite/operators/concat_op.h
浏览文件 @
ae60589f
...
...
@@ -32,7 +32,7 @@ class ConcatOpLite : public OpLite {
bool
InferShape
()
const
override
;
bool
AttachImpl
(
const
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
override
;
bool
AttachImpl
(
const
cpp
::
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
override
;
void
AttachKernel
(
KernelBase
*
kernel
)
override
{
kernel
->
SetParam
(
param_
);
}
std
::
string
DebugString
()
const
override
{
return
"concat"
;
}
...
...
paddle/fluid/lite/operators/concat_op_test.cc
浏览文件 @
ae60589f
...
...
@@ -42,7 +42,7 @@ TEST(concat_op_lite, test) {
}
// prepare op desc
lite
::
OpDesc
desc
;
cpp
::
OpDesc
desc
;
desc
.
SetType
(
"concat"
);
desc
.
SetInput
(
"X"
,
{
"x0"
,
"x1"
});
desc
.
SetOutput
(
"Out"
,
{
"output"
});
...
...
paddle/fluid/lite/operators/dropout_op.cc
浏览文件 @
ae60589f
...
...
@@ -42,7 +42,7 @@ class DropoutOpLite : public OpLite {
void
AttachKernel
(
KernelBase
*
kernel
)
override
{
kernel
->
SetParam
(
param_
);
}
// TODO(Superjomn) replace framework::OpDesc with a lite one.
bool
AttachImpl
(
const
OpDesc
&
op_desc
,
lite
::
Scope
*
scope
)
override
{
bool
AttachImpl
(
const
cpp
::
OpDesc
&
op_desc
,
lite
::
Scope
*
scope
)
override
{
auto
input
=
op_desc
.
Input
(
"X"
).
front
();
auto
out
=
op_desc
.
Output
(
"Out"
).
front
();
auto
Mask
=
op_desc
.
Output
(
"Mask"
).
front
();
...
...
@@ -51,14 +51,14 @@ class DropoutOpLite : public OpLite {
param_
.
output
=
GetMutableVar
<
lite
::
Tensor
>
(
scope
,
out
);
param_
.
mask
=
GetMutableVar
<
lite
::
Tensor
>
(
scope
,
Mask
);
param_
.
dropout_prob
=
boost
::
get
<
float
>
(
op_desc
.
GetAttr
(
"dropout_prob"
)
);
param_
.
dropout_prob
=
op_desc
.
GetAttr
<
float
>
(
"dropout_prob"
);
if
(
op_desc
.
HasAttr
(
"axis"
))
{
param_
.
is_test
=
boost
::
get
<
bool
>
(
op_desc
.
GetAttr
(
"is_test"
)
);
param_
.
is_test
=
op_desc
.
GetAttr
<
bool
>
(
"is_test"
);
}
param_
.
fix_seed
=
boost
::
get
<
bool
>
(
op_desc
.
GetAttr
(
"fix_seed"
)
);
param_
.
seed
=
boost
::
get
<
int
>
(
op_desc
.
GetAttr
(
"seed"
)
);
param_
.
fix_seed
=
op_desc
.
GetAttr
<
bool
>
(
"fix_seed"
);
param_
.
seed
=
op_desc
.
GetAttr
<
int
>
(
"seed"
);
param_
.
dropout_implementation
=
boost
::
get
<
int
>
(
op_desc
.
GetAttr
(
"dropout_implementation"
)
);
op_desc
.
GetAttr
<
int
>
(
"dropout_implementation"
);
return
true
;
}
...
...
paddle/fluid/lite/operators/elementwise_ops.cc
浏览文件 @
ae60589f
...
...
@@ -36,7 +36,7 @@ class ElementwiseOp : public OpLite {
return
true
;
}
bool
AttachImpl
(
const
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
override
{
bool
AttachImpl
(
const
cpp
::
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
override
{
auto
X_name
=
opdesc
.
Input
(
"X"
).
front
();
auto
Y_name
=
opdesc
.
Input
(
"Y"
).
front
();
auto
Out_name
=
opdesc
.
Output
(
"Out"
).
front
();
...
...
@@ -44,7 +44,7 @@ class ElementwiseOp : public OpLite {
param_
.
X
=
GetVar
<
lite
::
Tensor
>
(
scope
,
X_name
);
param_
.
Y
=
GetVar
<
lite
::
Tensor
>
(
scope
,
Y_name
);
param_
.
Out
=
GetMutableVar
<
lite
::
Tensor
>
(
scope
,
Out_name
);
param_
.
axis
=
boost
::
get
<
int
>
(
opdesc
.
GetAttr
(
"axis"
)
);
param_
.
axis
=
opdesc
.
GetAttr
<
int
>
(
"axis"
);
return
true
;
}
...
...
@@ -75,8 +75,8 @@ class ElementwiseGradExplicitOp : public OpLite {
return
true
;
}
bool
AttachImpl
(
const
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
override
{
CHECK_EQ
(
opdesc
.
Inputs
().
size
(),
1UL
);
bool
AttachImpl
(
const
cpp
::
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
override
{
CHECK_EQ
(
opdesc
.
Input
ArgumentName
s
().
size
(),
1UL
);
auto
Out_name
=
opdesc
.
Input
(
framework
::
GradVarName
(
"Out"
)).
front
();
auto
X_name
=
opdesc
.
Output
(
framework
::
GradVarName
(
"X"
)).
front
();
auto
Y_name
=
opdesc
.
Output
(
framework
::
GradVarName
(
"Y"
)).
front
();
...
...
@@ -84,7 +84,7 @@ class ElementwiseGradExplicitOp : public OpLite {
param_
.
Out_grad
=
GetVar
<
lite
::
Tensor
>
(
scope
,
Out_name
);
param_
.
X_grad
=
GetMutableVar
<
lite
::
Tensor
>
(
scope
,
X_name
);
param_
.
Y_grad
=
GetMutableVar
<
Tensor
>
(
scope
,
Y_name
);
param_
.
axis
=
boost
::
get
<
int
>
(
opdesc
.
GetAttr
(
"axis"
)
);
param_
.
axis
=
opdesc
.
GetAttr
<
int
>
(
"axis"
);
return
true
;
}
...
...
paddle/fluid/lite/operators/fc_op.h
浏览文件 @
ae60589f
...
...
@@ -46,7 +46,7 @@ class FcOpLite : public OpLite {
*/
// TODO(Superjomn) replace framework::OpDesc with a lite one.
bool
AttachImpl
(
const
OpDesc
&
op_desc
,
lite
::
Scope
*
scope
)
override
{
bool
AttachImpl
(
const
cpp
::
OpDesc
&
op_desc
,
lite
::
Scope
*
scope
)
override
{
auto
input
=
op_desc
.
Input
(
"Input"
).
front
();
auto
W
=
op_desc
.
Input
(
"W"
).
front
();
auto
bias
=
op_desc
.
Input
(
"Bias"
).
front
();
...
...
@@ -57,7 +57,7 @@ class FcOpLite : public OpLite {
param_
.
bias
=
scope
->
FindVar
(
bias
)
->
GetMutable
<
lite
::
Tensor
>
();
CHECK
(
scope
->
FindVar
(
out
));
param_
.
output
=
scope
->
FindVar
(
out
)
->
GetMutable
<
lite
::
Tensor
>
();
param_
.
in_num_col_dims
=
GetAttr
<
int
>
(
op_desc
.
GetAttr
(
"in_num_col_dims"
)
);
param_
.
in_num_col_dims
=
op_desc
.
GetAttr
<
int
>
(
"in_num_col_dims"
);
return
true
;
}
...
...
paddle/fluid/lite/operators/fc_op_test.cc
浏览文件 @
ae60589f
...
...
@@ -47,7 +47,7 @@ TEST(fc_op_lite, TestX86) {
}
// prepare op desc
lite
::
OpDesc
desc
;
cpp
::
OpDesc
desc
;
desc
.
SetType
(
"fc"
);
desc
.
SetInput
(
"Input"
,
{
"x"
});
desc
.
SetInput
(
"W"
,
{
"w"
});
...
...
paddle/fluid/lite/operators/feed_op.cc
浏览文件 @
ae60589f
...
...
@@ -34,7 +34,7 @@ class FeedOp : public OpLite {
void
AttachKernel
(
KernelBase
*
kernel
)
override
{
kernel
->
SetParam
(
param_
);
}
protected:
bool
AttachImpl
(
const
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
override
{
bool
AttachImpl
(
const
cpp
::
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
override
{
auto
feed_var_name
=
opdesc
.
Input
(
"X"
).
front
();
auto
*
feed_var
=
scope
->
FindVar
(
feed_var_name
);
CHECK
(
feed_var
);
...
...
@@ -48,7 +48,7 @@ class FeedOp : public OpLite {
// NOTE need boost here
// TODO(Superjomn) drop the need of framework::op_desc
param_
.
col
=
GetAttr
<
int
>
(
opdesc
.
GetAttr
(
"col"
)
);
param_
.
col
=
opdesc
.
GetAttr
<
int
>
(
"col"
);
return
true
;
}
...
...
paddle/fluid/lite/operators/fetch_op.cc
浏览文件 @
ae60589f
...
...
@@ -33,7 +33,7 @@ class FetchOp : public OpLite {
void
AttachKernel
(
KernelBase
*
kernel
)
override
{
kernel
->
SetParam
(
param_
);
}
protected:
bool
AttachImpl
(
const
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
override
{
bool
AttachImpl
(
const
cpp
::
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
override
{
auto
_x
=
opdesc
.
Input
(
"X"
).
front
();
auto
*
x
=
scope
->
FindVar
(
_x
);
CHECK
(
x
);
...
...
@@ -43,7 +43,7 @@ class FetchOp : public OpLite {
auto
*
out
=
scope
->
FindVar
(
_out
);
param_
.
fetch_list
=
out
->
GetMutable
<
std
::
vector
<
lite
::
Tensor
>>
();
param_
.
col
=
GetAttr
<
int
>
(
opdesc
.
GetAttr
(
"col"
)
);
param_
.
col
=
opdesc
.
GetAttr
<
int
>
(
"col"
);
return
true
;
}
...
...
paddle/fluid/lite/operators/fill_constant_op.cc
浏览文件 @
ae60589f
...
...
@@ -33,14 +33,14 @@ class FillConstantOp : public OpLite {
return
true
;
}
bool
AttachImpl
(
const
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
override
{
bool
AttachImpl
(
const
cpp
::
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
override
{
auto
Out_name
=
opdesc
.
Output
(
"Out"
).
front
();
param_
.
Out
=
GetMutableVar
<
Tensor
>
(
scope
,
Out_name
);
param_
.
dtype
=
GetAttr
<
int
>
(
opdesc
.
GetAttr
(
"dtype"
)
);
param_
.
shape
=
GetAttr
<
std
::
vector
<
int64_t
>>
(
opdesc
.
GetAttr
(
"shape"
)
);
param_
.
value
=
GetAttr
<
float
>
(
opdesc
.
GetAttr
(
"value"
)
);
param_
.
force_cpu
=
GetAttr
<
bool
>
(
opdesc
.
GetAttr
(
"force_cpu"
)
);
param_
.
dtype
=
opdesc
.
GetAttr
<
int
>
(
"dtype"
);
param_
.
shape
=
opdesc
.
GetAttr
<
std
::
vector
<
int64_t
>>
(
"shape"
);
param_
.
value
=
opdesc
.
GetAttr
<
float
>
(
"value"
);
param_
.
force_cpu
=
opdesc
.
GetAttr
<
bool
>
(
"force_cpu"
);
return
true
;
}
...
...
paddle/fluid/lite/operators/io_copy_op.cc
浏览文件 @
ae60589f
...
...
@@ -29,7 +29,8 @@ bool IoCopyOp::InferShape() const {
return
true
;
}
bool
IoCopyOp
::
Run
()
{
return
OpLite
::
Run
();
}
bool
IoCopyOp
::
AttachImpl
(
const
OpDesc
&
opdesc
,
paddle
::
lite
::
Scope
*
scope
)
{
bool
IoCopyOp
::
AttachImpl
(
const
cpp
::
OpDesc
&
opdesc
,
paddle
::
lite
::
Scope
*
scope
)
{
auto
x
=
opdesc
.
Input
(
"Input"
).
front
();
auto
out
=
opdesc
.
Output
(
"Out"
).
front
();
param_
.
x
=
GetTensor
(
scope
,
x
);
...
...
paddle/fluid/lite/operators/io_copy_op.h
浏览文件 @
ae60589f
...
...
@@ -31,7 +31,7 @@ class IoCopyOp : public OpLite {
void
AttachKernel
(
KernelBase
*
kernel
)
override
{
kernel
->
SetParam
(
param_
);
}
protected:
bool
AttachImpl
(
const
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
override
;
bool
AttachImpl
(
const
cpp
::
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
override
;
private:
operators
::
IoCopyParam
param_
;
...
...
paddle/fluid/lite/operators/mean_op.cc
浏览文件 @
ae60589f
...
...
@@ -37,7 +37,7 @@ class MeanOp : public OpLite {
return
true
;
}
bool
AttachImpl
(
const
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
override
{
bool
AttachImpl
(
const
cpp
::
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
override
{
auto
X_name
=
opdesc
.
Input
(
"X"
).
front
();
auto
Out_name
=
opdesc
.
Output
(
"Out"
).
front
();
...
...
@@ -72,8 +72,8 @@ class MeanGradOp : public OpLite {
return
true
;
}
bool
AttachImpl
(
const
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
override
{
CHECK_EQ
(
opdesc
.
Inputs
().
size
(),
3UL
);
bool
AttachImpl
(
const
cpp
::
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
override
{
CHECK_EQ
(
opdesc
.
Input
ArgumentName
s
().
size
(),
3UL
);
auto
X_name
=
opdesc
.
Input
(
"X"
).
front
();
auto
Out_grad_name
=
opdesc
.
Input
(
framework
::
GradVarName
(
"Out"
)).
front
();
auto
X_grad_name
=
opdesc
.
Output
(
framework
::
GradVarName
(
"X"
)).
front
();
...
...
paddle/fluid/lite/operators/mul_op.cc
浏览文件 @
ae60589f
...
...
@@ -85,7 +85,7 @@ bool MulGradOpLite::InferShape() const {
return
true
;
}
bool
MulGradOpLite
::
AttachImpl
(
const
OpDesc
&
op_desc
,
lite
::
Scope
*
scope
)
{
bool
MulGradOpLite
::
AttachImpl
(
const
cpp
::
OpDesc
&
op_desc
,
lite
::
Scope
*
scope
)
{
auto
X_name
=
op_desc
.
Input
(
"X"
).
front
();
auto
Y_name
=
op_desc
.
Input
(
"Y"
).
front
();
auto
Out_grad_name
=
op_desc
.
Output
(
framework
::
GradVarName
(
"Out"
)).
front
();
...
...
paddle/fluid/lite/operators/mul_op.h
浏览文件 @
ae60589f
...
...
@@ -37,7 +37,7 @@ class MulOpLite : public OpLite {
void
AttachKernel
(
KernelBase
*
kernel
)
override
{
kernel
->
SetParam
(
param_
);
}
// TODO(Superjomn) replace framework::OpDesc with a lite one.
bool
AttachImpl
(
const
OpDesc
&
op_desc
,
lite
::
Scope
*
scope
)
override
{
bool
AttachImpl
(
const
cpp
::
OpDesc
&
op_desc
,
lite
::
Scope
*
scope
)
override
{
auto
input
=
op_desc
.
Input
(
"X"
).
front
();
auto
W
=
op_desc
.
Input
(
"Y"
).
front
();
auto
out
=
op_desc
.
Output
(
"Out"
).
front
();
...
...
@@ -49,8 +49,8 @@ class MulOpLite : public OpLite {
param_
.
y
=
var
->
GetMutable
<
Tensor
>
();
CHECK
(
scope
->
FindVar
(
out
));
param_
.
output
=
scope
->
FindVar
(
out
)
->
GetMutable
<
Tensor
>
();
param_
.
x_num_col_dims
=
GetAttr
<
int
>
(
op_desc
.
GetAttr
(
"x_num_col_dims"
)
);
param_
.
y_num_col_dims
=
GetAttr
<
int
>
(
op_desc
.
GetAttr
(
"y_num_col_dims"
)
);
param_
.
x_num_col_dims
=
op_desc
.
GetAttr
<
int
>
(
"x_num_col_dims"
);
param_
.
y_num_col_dims
=
op_desc
.
GetAttr
<
int
>
(
"y_num_col_dims"
);
return
true
;
}
...
...
@@ -73,7 +73,7 @@ class MulGradOpLite : public OpLite {
void
AttachKernel
(
KernelBase
*
kernel
)
override
{
kernel
->
SetParam
(
param_
);
}
bool
AttachImpl
(
const
OpDesc
&
op_desc
,
lite
::
Scope
*
scope
)
override
;
bool
AttachImpl
(
const
cpp
::
OpDesc
&
op_desc
,
lite
::
Scope
*
scope
)
override
;
std
::
string
DebugString
()
const
override
{
return
"mul_grad"
;
}
...
...
paddle/fluid/lite/operators/relu_op.cc
浏览文件 @
ae60589f
...
...
@@ -30,7 +30,7 @@ bool ReluOp::InferShape() const {
return
true
;
}
bool
ReluOp
::
AttachImpl
(
const
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
{
bool
ReluOp
::
AttachImpl
(
const
cpp
::
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
{
param_
.
input
=
const_cast
<
lite
::
Tensor
*>
(
&
scope
->
FindVar
(
opdesc
.
Input
(
"Input"
).
front
())
->
Get
<
lite
::
Tensor
>
());
param_
.
output
=
...
...
paddle/fluid/lite/operators/relu_op.h
浏览文件 @
ae60589f
...
...
@@ -32,7 +32,7 @@ class ReluOp : public OpLite {
bool
InferShape
()
const
override
;
bool
AttachImpl
(
const
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
override
;
bool
AttachImpl
(
const
cpp
::
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
override
;
void
AttachKernel
(
KernelBase
*
kernel
)
override
{
kernel
->
SetParam
(
param_
);
}
std
::
string
DebugString
()
const
override
{
return
"relu"
;
}
...
...
paddle/fluid/lite/operators/reshape_op.cc
浏览文件 @
ae60589f
...
...
@@ -33,7 +33,7 @@ bool ReshapeOp::InferShape() const {
return
true
;
}
bool
ReshapeOp
::
AttachImpl
(
const
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
{
bool
ReshapeOp
::
AttachImpl
(
const
cpp
::
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
{
auto
x_var
=
scope
->
FindVar
(
opdesc
.
Input
(
"X"
).
front
());
auto
output_var
=
scope
->
FindVar
(
opdesc
.
Output
(
"Out"
).
front
());
CHECK
(
x_var
);
...
...
@@ -49,9 +49,9 @@ bool ReshapeOp::AttachImpl(const OpDesc &opdesc, lite::Scope *scope) {
const_cast
<
lite
::
Tensor
*>
(
&
(
actual_shape_var
->
Get
<
lite
::
Tensor
>
()));
}
}
param_
.
shape
=
GetAttr
<
std
::
vector
<
int
>>
(
opdesc
.
GetAttr
(
"shape"
));
param_
.
shape
=
(
opdesc
.
GetAttr
<
std
::
vector
<
int
>>
(
"shape"
));
if
(
opdesc
.
HasAttr
(
"inplace"
))
{
param_
.
inplace
=
GetAttr
<
bool
>
(
opdesc
.
GetAttr
(
"inplace"
)
);
param_
.
inplace
=
opdesc
.
GetAttr
<
bool
>
(
"inplace"
);
}
CHECK
(
param_
.
x
)
<<
"Input(X) of ReshapeOp should not be null."
;
CHECK
(
param_
.
output
)
<<
"Output(Out) of ReshapeOp should not be null."
;
...
...
@@ -70,14 +70,14 @@ bool Reshape2Op::InferShape() const {
ReshapeOp
::
InferShape
();
auto
x_dims
=
param_
.
x
->
dims
();
std
::
vector
<
DDim
::
value_type
>
xshape_dims
(
x_dims
.
size
()
+
1
,
0
);
for
(
in
t
i
=
0
;
i
<
x_dims
.
size
();
i
++
)
{
for
(
size_
t
i
=
0
;
i
<
x_dims
.
size
();
i
++
)
{
xshape_dims
[
i
+
1
]
=
x_dims
[
i
];
}
param_
.
xshape
->
Resize
(
DDim
(
xshape_dims
));
return
true
;
}
bool
Reshape2Op
::
AttachImpl
(
const
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
{
bool
Reshape2Op
::
AttachImpl
(
const
cpp
::
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
{
ReshapeOp
::
AttachImpl
(
opdesc
,
scope
);
auto
xshape_var
=
scope
->
FindVar
(
opdesc
.
Output
(
"XShape"
).
front
());
CHECK
(
xshape_var
);
...
...
paddle/fluid/lite/operators/reshape_op.h
浏览文件 @
ae60589f
...
...
@@ -32,7 +32,7 @@ class ReshapeOp : public OpLite {
bool
InferShape
()
const
override
;
bool
AttachImpl
(
const
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
override
;
bool
AttachImpl
(
const
cpp
::
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
override
;
void
AttachKernel
(
KernelBase
*
kernel
)
override
{
kernel
->
SetParam
(
param_
);
}
std
::
string
DebugString
()
const
override
{
return
"reshape"
;
}
...
...
@@ -50,7 +50,7 @@ class Reshape2Op : public ReshapeOp {
bool
InferShape
()
const
override
;
bool
AttachImpl
(
const
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
override
;
bool
AttachImpl
(
const
cpp
::
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
override
;
void
AttachKernel
(
KernelBase
*
kernel
)
override
{
kernel
->
SetParam
(
param_
);
}
std
::
string
DebugString
()
const
override
{
return
"reshape2"
;
}
...
...
paddle/fluid/lite/operators/reshape_op_test.cc
浏览文件 @
ae60589f
...
...
@@ -47,7 +47,7 @@ TEST(reshape_op_lite, test) {
for
(
auto
&
has_actual_shape
:
{
true
,
false
})
{
for
(
auto
&
inplace
:
{
true
,
false
})
{
// prepare op desc
lite
::
OpDesc
desc
;
cpp
::
OpDesc
desc
;
desc
.
SetType
(
"reshape"
);
desc
.
SetInput
(
"X"
,
{
"x"
});
if
(
has_actual_shape
)
{
...
...
@@ -68,7 +68,7 @@ TEST(reshape_op_lite, test) {
// check output dims
auto
output_dims
=
output
->
dims
();
CHECK_EQ
(
output_dims
.
size
(),
shape
.
second
.
size
());
for
(
in
t
i
=
0
;
i
<
output_dims
.
size
();
i
++
)
{
for
(
size_
t
i
=
0
;
i
<
output_dims
.
size
();
i
++
)
{
CHECK_EQ
(
output_dims
[
i
],
shape
.
second
[
i
]);
}
}
...
...
@@ -102,7 +102,7 @@ TEST(reshape2_op_lite, test) {
for
(
auto
&
has_actual_shape
:
{
true
,
false
})
{
for
(
auto
&
inplace
:
{
true
,
false
})
{
// prepare op desc
lite
::
OpDesc
desc
;
cpp
::
OpDesc
desc
;
desc
.
SetType
(
"reshape"
);
desc
.
SetInput
(
"X"
,
{
"x"
});
if
(
has_actual_shape
)
{
...
...
@@ -132,7 +132,7 @@ TEST(reshape2_op_lite, test) {
auto
xshape_dims
=
xshape
->
dims
();
CHECK_EQ
(
xshape_dims
.
size
(),
x_dims
.
size
()
+
1
);
CHECK_EQ
(
xshape_dims
[
0
],
0
);
for
(
in
t
i
=
0
;
i
<
x_dims
.
size
();
i
++
)
{
for
(
size_
t
i
=
0
;
i
<
x_dims
.
size
();
i
++
)
{
CHECK_EQ
(
xshape_dims
[
i
+
1
],
x_dims
[
i
]);
}
}
...
...
paddle/fluid/lite/operators/scale_op.cc
浏览文件 @
ae60589f
...
...
@@ -29,14 +29,14 @@ bool ScaleOp::InferShape() const {
return
true
;
}
bool
ScaleOp
::
AttachImpl
(
const
OpDesc
&
op_desc
,
lite
::
Scope
*
scope
)
{
bool
ScaleOp
::
AttachImpl
(
const
cpp
::
OpDesc
&
op_desc
,
lite
::
Scope
*
scope
)
{
auto
x
=
op_desc
.
Input
(
"X"
).
front
();
auto
output
=
op_desc
.
Output
(
"Out"
).
front
();
param_
.
x
=
scope
->
FindVar
(
x
)
->
GetMutable
<
Tensor
>
();
param_
.
output
=
scope
->
FindVar
(
output
)
->
GetMutable
<
Tensor
>
();
param_
.
scale
=
GetAttr
<
float
>
(
op_desc
.
GetAttr
(
"scale"
)
);
param_
.
bias
=
GetAttr
<
float
>
(
op_desc
.
GetAttr
(
"bias"
)
);
param_
.
bias_after_scale
=
GetAttr
<
bool
>
(
op_desc
.
GetAttr
(
"bias_after_scale"
)
);
param_
.
scale
=
op_desc
.
GetAttr
<
float
>
(
"scale"
);
param_
.
bias
=
op_desc
.
GetAttr
<
float
>
(
"bias"
);
param_
.
bias_after_scale
=
op_desc
.
GetAttr
<
bool
>
(
"bias_after_scale"
);
CHECK
(
param_
.
x
);
CHECK
(
param_
.
output
);
return
true
;
...
...
paddle/fluid/lite/operators/scale_op.h
浏览文件 @
ae60589f
...
...
@@ -32,7 +32,7 @@ class ScaleOp : public OpLite {
bool
InferShape
()
const
override
;
bool
AttachImpl
(
const
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
override
;
bool
AttachImpl
(
const
cpp
::
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
override
;
void
AttachKernel
(
KernelBase
*
kernel
)
override
{
kernel
->
SetParam
(
param_
);
}
std
::
string
DebugString
()
const
override
{
return
"scale"
;
}
...
...
paddle/fluid/lite/operators/scale_op_test.cc
浏览文件 @
ae60589f
...
...
@@ -29,7 +29,7 @@ TEST(scale_op_lite, test) {
output
->
Resize
(
DDim
(
std
::
vector
<
int64_t
>
{
1
,
1
}));
// prepare op desc
lite
::
OpDesc
desc
;
cpp
::
OpDesc
desc
;
desc
.
SetType
(
"scale"
);
desc
.
SetInput
(
"X"
,
{
"x"
});
desc
.
SetOutput
(
"Out"
,
{
"output"
});
...
...
@@ -48,7 +48,7 @@ TEST(scale_op_lite, test) {
auto
x_dims
=
x
->
dims
();
auto
output_dims
=
output
->
dims
();
CHECK_EQ
(
output_dims
.
size
(),
x_dims
.
size
());
for
(
in
t
i
=
0
;
i
<
output_dims
.
size
();
i
++
)
{
for
(
size_
t
i
=
0
;
i
<
output_dims
.
size
();
i
++
)
{
CHECK_EQ
(
output_dims
[
i
],
x_dims
[
i
]);
}
}
...
...
paddle/fluid/lite/operators/softmax_op.cc
浏览文件 @
ae60589f
...
...
@@ -24,7 +24,8 @@ bool SoftmaxOp::CheckShape() const {
CHECK_OR_FALSE
(
param_
.
output
);
auto
x_dims
=
param_
.
x
->
dims
();
auto
x_rank
=
x_dims
.
size
();
CHECK_OR_FALSE
(
param_
.
axis
>=
-
x_rank
&&
param_
.
axis
<
x_rank
);
CHECK_OR_FALSE
(
param_
.
axis
>=
-
static_cast
<
int
>
(
x_rank
)
&&
param_
.
axis
<
static_cast
<
int
>
(
x_rank
));
return
true
;
}
...
...
@@ -33,12 +34,12 @@ bool SoftmaxOp::InferShape() const {
return
true
;
}
bool
SoftmaxOp
::
AttachImpl
(
const
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
{
bool
SoftmaxOp
::
AttachImpl
(
const
cpp
::
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
{
param_
.
x
=
const_cast
<
lite
::
Tensor
*>
(
&
scope
->
FindVar
(
opdesc
.
Input
(
"X"
).
front
())
->
Get
<
lite
::
Tensor
>
());
param_
.
output
=
scope
->
FindVar
(
opdesc
.
Output
(
"Out"
).
front
())
->
GetMutable
<
lite
::
Tensor
>
();
param_
.
axis
=
GetAttr
<
int
>
(
opdesc
.
GetAttr
(
"axis"
)
);
param_
.
axis
=
opdesc
.
GetAttr
<
int
>
(
"axis"
);
CHECK
(
param_
.
x
);
CHECK
(
param_
.
output
);
return
true
;
...
...
paddle/fluid/lite/operators/softmax_op.h
浏览文件 @
ae60589f
...
...
@@ -32,7 +32,7 @@ class SoftmaxOp : public OpLite {
bool
InferShape
()
const
override
;
bool
AttachImpl
(
const
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
override
;
bool
AttachImpl
(
const
cpp
::
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
override
;
void
AttachKernel
(
KernelBase
*
kernel
)
override
{
kernel
->
SetParam
(
param_
);
}
std
::
string
DebugString
()
const
override
{
return
"softmax"
;
}
...
...
paddle/fluid/lite/operators/softmax_op_test.cc
浏览文件 @
ae60589f
...
...
@@ -37,7 +37,7 @@ TEST(softmax_op_lite, test) {
}
// prepare op desc
lite
::
OpDesc
desc
;
cpp
::
OpDesc
desc
;
desc
.
SetType
(
"softmax"
);
desc
.
SetInput
(
"X"
,
{
"x"
});
desc
.
SetOutput
(
"Out"
,
{
"output"
});
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录