Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
8426beb4
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
8426beb4
编写于
7月 07, 2017
作者:
D
dongzhihong
浏览文件
操作
浏览文件
下载
差异文件
Merge remote-tracking branch 'origin/develop' into save_state
上级
40295b9e
7b810553
变更
6
显示空白变更内容
内联
并排
Showing
6 changed file
with
498 addition
and
1 deletion
+498
-1
cmake/external/glog.cmake
cmake/external/glog.cmake
+2
-0
doc_theme/templates/layout.html
doc_theme/templates/layout.html
+1
-1
paddle/framework/CMakeLists.txt
paddle/framework/CMakeLists.txt
+1
-0
paddle/framework/attr_checker.h
paddle/framework/attr_checker.h
+119
-0
paddle/framework/op_registry.h
paddle/framework/op_registry.h
+253
-0
paddle/framework/op_registry_test.cc
paddle/framework/op_registry_test.cc
+122
-0
未找到文件。
cmake/external/glog.cmake
浏览文件 @
8426beb4
...
...
@@ -38,12 +38,14 @@ ExternalProject_Add(
CMAKE_ARGS -DCMAKE_CXX_FLAGS=
${
CMAKE_CXX_FLAGS
}
CMAKE_ARGS -DCMAKE_C_FLAGS=
${
CMAKE_C_FLAGS
}
CMAKE_ARGS -DCMAKE_INSTALL_PREFIX=
${
GLOG_INSTALL_DIR
}
CMAKE_ARGS -DCMAKE_INSTALL_LIBDIR=
${
GLOG_INSTALL_DIR
}
/lib
CMAKE_ARGS -DCMAKE_POSITION_INDEPENDENT_CODE=ON
CMAKE_ARGS -DWITH_GFLAGS=ON
CMAKE_ARGS -Dgflags_DIR=
${
GFLAGS_INSTALL_DIR
}
/lib/cmake/gflags
CMAKE_ARGS -DBUILD_TESTING=OFF
CMAKE_ARGS -DCMAKE_BUILD_TYPE=Release
CMAKE_CACHE_ARGS -DCMAKE_INSTALL_PREFIX:PATH=
${
GLOG_INSTALL_DIR
}
-DCMAKE_INSTALL_LIBDIR:PATH=
${
GLOG_INSTALL_DIR
}
/lib
-DCMAKE_POSITION_INDEPENDENT_CODE:BOOL=ON
-DCMAKE_BUILD_TYPE:STRING=Release
)
...
...
doc_theme/templates/layout.html
浏览文件 @
8426beb4
...
...
@@ -101,7 +101,7 @@
</div>
<div
class=
"site-nav-links"
>
<div
class=
"site-menu"
>
<a
class=
"fork-on-github"
href=
"https://github.com/PaddlePaddle/Paddle"
target=
"_blank"
><i
class=
"fa fa-github"
></i>
Fo
l
k me on Github
</a>
<a
class=
"fork-on-github"
href=
"https://github.com/PaddlePaddle/Paddle"
target=
"_blank"
><i
class=
"fa fa-github"
></i>
Fo
r
k me on Github
</a>
<div
class=
"language-switcher dropdown"
>
<a
type=
"button"
data-toggle=
"dropdown"
>
<span>
English
</span>
...
...
paddle/framework/CMakeLists.txt
浏览文件 @
8426beb4
...
...
@@ -11,6 +11,7 @@ proto_library(op_proto SRCS op_proto.proto DEPS attr_type)
cc_test
(
op_proto_test SRCS op_proto_test.cc DEPS op_proto protobuf
)
proto_library
(
op_desc SRCS op_desc.proto DEPS attr_type
)
cc_test
(
op_desc_test SRCS op_desc_test.cc DEPS op_desc protobuf
)
cc_test
(
op_registry_test SRCS op_registry_test.cc DEPS op_proto op_desc
)
py_proto_compile
(
framework_py_proto SRCS attr_type.proto op_proto.proto op_desc.proto
)
# Generate an empty __init__.py to make framework_py_proto as a valid python module.
add_custom_target
(
framework_py_proto_init ALL COMMAND
${
CMAKE_COMMAND
}
-E touch __init__.py
)
...
...
paddle/framework/attr_checker.h
0 → 100644
浏览文件 @
8426beb4
#pragma once
#include <boost/variant.hpp>
#include <functional>
#include <string>
#include <unordered_map>
#include <vector>
#include "paddle/framework/enforce.h"
namespace
paddle
{
namespace
framework
{
typedef
boost
::
variant
<
boost
::
blank
,
int
,
float
,
std
::
string
,
std
::
vector
<
int
>
,
std
::
vector
<
float
>
,
std
::
vector
<
std
::
string
>>
Attribute
;
typedef
std
::
unordered_map
<
std
::
string
,
Attribute
>
AttributeMap
;
// check whether a value(attribute) fit a certain limit
template
<
typename
T
>
class
LargerThanChecker
{
public:
LargerThanChecker
(
T
lower_bound
)
:
lower_bound_
(
lower_bound
)
{}
void
operator
()(
T
&
value
)
const
{
PADDLE_ENFORCE
(
value
>
lower_bound_
,
"larger_than check fail"
);
}
private:
T
lower_bound_
;
};
// we can provide users more common Checker, like 'LessThanChecker',
// 'BetweenChecker'...
template
<
typename
T
>
class
DefaultValueSetter
{
public:
DefaultValueSetter
(
T
default_value
)
:
default_value_
(
default_value
)
{}
void
operator
()(
T
&
value
)
const
{
value
=
default_value_
;
}
private:
T
default_value_
;
};
// check whether a certain attribute fit its limits
// an attribute can have more than one limits
template
<
typename
T
>
class
TypedAttrChecker
{
typedef
std
::
function
<
void
(
T
&
)
>
ValueChecker
;
public:
TypedAttrChecker
(
const
std
::
string
&
attr_name
)
:
attr_name_
(
attr_name
)
{}
TypedAttrChecker
&
LargerThan
(
const
T
&
lower_bound
)
{
value_checkers_
.
push_back
(
LargerThanChecker
<
T
>
(
lower_bound
));
return
*
this
;
}
// we can add more common limits, like LessThan(), Between()...
TypedAttrChecker
&
SetDefault
(
const
T
&
default_value
)
{
PADDLE_ENFORCE
(
default_value_setter_
.
empty
(),
"%s can't have more than one default value!"
,
attr_name_
);
default_value_setter_
.
push_back
(
DefaultValueSetter
<
T
>
(
default_value
));
return
*
this
;
}
// allow users provide their own checker
TypedAttrChecker
&
AddCustomChecker
(
const
ValueChecker
&
checker
)
{
value_checkers_
.
push_back
(
checker
);
return
*
this
;
}
void
operator
()(
AttributeMap
&
attr_map
)
const
{
if
(
!
attr_map
.
count
(
attr_name_
))
{
// user do not set this attr
PADDLE_ENFORCE
(
!
default_value_setter_
.
empty
(),
"Attribute '%s' is required!"
,
attr_name_
);
// default_value_setter_ has no more than one element
T
val
;
(
default_value_setter_
[
0
])(
val
);
attr_map
[
attr_name_
]
=
val
;
}
Attribute
&
attr
=
attr_map
.
at
(
attr_name_
);
T
&
attr_value
=
boost
::
get
<
T
>
(
attr
);
for
(
const
auto
&
checker
:
value_checkers_
)
{
checker
(
attr_value
);
}
}
private:
std
::
string
attr_name_
;
std
::
vector
<
ValueChecker
>
value_checkers_
;
std
::
vector
<
ValueChecker
>
default_value_setter_
;
};
// check whether op's all attributes fit their own limits
class
OpAttrChecker
{
typedef
std
::
function
<
void
(
AttributeMap
&
)
>
AttrChecker
;
public:
template
<
typename
T
>
TypedAttrChecker
<
T
>&
AddAttrChecker
(
const
std
::
string
&
attr_name
)
{
attr_checkers_
.
push_back
(
TypedAttrChecker
<
T
>
(
attr_name
));
AttrChecker
&
checker
=
attr_checkers_
.
back
();
return
*
(
checker
.
target
<
TypedAttrChecker
<
T
>>
());
}
void
Check
(
AttributeMap
&
attr_map
)
const
{
for
(
const
auto
&
checker
:
attr_checkers_
)
{
checker
(
attr_map
);
}
}
private:
std
::
vector
<
AttrChecker
>
attr_checkers_
;
};
}
// namespace framework
}
// namespace paddle
paddle/framework/op_registry.h
0 → 100644
浏览文件 @
8426beb4
#pragma once
#include "paddle/framework/attr_checker.h"
//#include "paddle/framework/op_base.h"
#include "paddle/framework/op_desc.pb.h"
#include "paddle/framework/op_proto.pb.h"
namespace
paddle
{
namespace
framework
{
//==================For test================//
class
OpBase
{
public:
std
::
vector
<
std
::
string
>
inputs_
;
std
::
vector
<
std
::
string
>
outputs_
;
AttributeMap
attr_map_
;
virtual
std
::
string
Run
()
const
=
0
;
virtual
~
OpBase
()
{}
};
//=========================================//
// helper class to set attribute type
struct
AttrTypeHelper
{
template
<
typename
T
>
static
void
SetAttrType
(
AttrProto
*
attr
);
static
Attribute
GetAttrValue
(
const
AttrDesc
&
attr_desc
)
{
switch
(
attr_desc
.
type
())
{
case
paddle
::
framework
::
AttrType
::
INT
:
{
return
attr_desc
.
i
();
}
case
paddle
::
framework
::
AttrType
::
FLOAT
:
{
return
attr_desc
.
f
();
}
case
paddle
::
framework
::
AttrType
::
STRING
:
{
return
attr_desc
.
s
();
}
case
paddle
::
framework
::
AttrType
::
INTS
:
{
std
::
vector
<
int
>
val
(
attr_desc
.
ints_size
());
for
(
int
i
=
0
;
i
<
attr_desc
.
ints_size
();
++
i
)
{
val
[
i
]
=
attr_desc
.
ints
(
i
);
}
return
val
;
}
case
paddle
::
framework
::
AttrType
::
FLOATS
:
{
std
::
vector
<
float
>
val
(
attr_desc
.
floats_size
());
for
(
int
i
=
0
;
i
<
attr_desc
.
floats_size
();
++
i
)
{
val
[
i
]
=
attr_desc
.
floats
(
i
);
}
return
val
;
}
case
paddle
::
framework
::
AttrType
::
STRINGS
:
{
std
::
vector
<
std
::
string
>
val
(
attr_desc
.
strings_size
());
for
(
int
i
=
0
;
i
<
attr_desc
.
strings_size
();
++
i
)
{
val
[
i
]
=
attr_desc
.
strings
(
i
);
}
return
val
;
}
}
PADDLE_ENFORCE
(
false
,
"Unknown OpDesc::AttrDesc::type !"
);
return
boost
::
blank
();
}
};
template
<
>
void
AttrTypeHelper
::
SetAttrType
<
int
>
(
AttrProto
*
attr
)
{
attr
->
set_type
(
paddle
::
framework
::
AttrType
::
INT
);
}
template
<
>
void
AttrTypeHelper
::
SetAttrType
<
float
>
(
AttrProto
*
attr
)
{
attr
->
set_type
(
paddle
::
framework
::
AttrType
::
FLOAT
);
}
template
<
>
void
AttrTypeHelper
::
SetAttrType
<
std
::
string
>
(
AttrProto
*
attr
)
{
attr
->
set_type
(
paddle
::
framework
::
AttrType
::
STRING
);
}
template
<
>
void
AttrTypeHelper
::
SetAttrType
<
std
::
vector
<
int
>>
(
AttrProto
*
attr
)
{
attr
->
set_type
(
paddle
::
framework
::
AttrType
::
INTS
);
}
template
<
>
void
AttrTypeHelper
::
SetAttrType
<
std
::
vector
<
float
>>
(
AttrProto
*
attr
)
{
attr
->
set_type
(
paddle
::
framework
::
AttrType
::
FLOATS
);
}
template
<
>
void
AttrTypeHelper
::
SetAttrType
<
std
::
vector
<
std
::
string
>>
(
AttrProto
*
attr
)
{
attr
->
set_type
(
paddle
::
framework
::
AttrType
::
STRINGS
);
}
// this class not only make proto but also init attribute checkers.
class
OpProtoAndCheckerMaker
{
public:
OpProtoAndCheckerMaker
(
OpProto
*
proto
,
OpAttrChecker
*
op_checker
)
:
proto_
(
proto
),
op_checker_
(
op_checker
)
{}
protected:
void
AddInput
(
const
std
::
string
&
name
,
const
std
::
string
&
comment
)
{
auto
input
=
proto_
->
mutable_inputs
()
->
Add
();
*
(
input
->
mutable_name
())
=
name
;
*
(
input
->
mutable_comment
())
=
comment
;
}
void
AddOutput
(
const
std
::
string
&
name
,
const
std
::
string
&
comment
)
{
auto
output
=
proto_
->
mutable_outputs
()
->
Add
();
*
(
output
->
mutable_name
())
=
name
;
*
(
output
->
mutable_comment
())
=
comment
;
}
template
<
typename
T
>
TypedAttrChecker
<
T
>&
AddAttr
(
const
std
::
string
&
name
,
const
std
::
string
&
comment
)
{
auto
attr
=
proto_
->
mutable_attrs
()
->
Add
();
*
(
attr
->
mutable_name
())
=
name
;
*
(
attr
->
mutable_comment
())
=
comment
;
AttrTypeHelper
::
SetAttrType
<
T
>
(
attr
);
return
op_checker_
->
AddAttrChecker
<
T
>
(
name
);
}
void
AddType
(
const
std
::
string
&
op_type
)
{
proto_
->
set_type
(
op_type
);
}
void
AddComment
(
const
std
::
string
&
comment
)
{
*
(
proto_
->
mutable_comment
())
=
comment
;
}
OpProto
*
proto_
;
OpAttrChecker
*
op_checker_
;
};
class
OpRegistry
{
typedef
std
::
function
<
OpBase
*
()
>
OpCreator
;
public:
template
<
typename
OpType
,
typename
ProtoMakerType
>
static
void
RegisterOp
(
const
std
::
string
&
op_type
)
{
creators_
[
op_type
]
=
[]()
{
return
new
OpType
;
};
OpProto
&
op_proto
=
protos_
[
op_type
];
OpAttrChecker
&
op_checker
=
op_checkers_
[
op_type
];
ProtoMakerType
(
&
op_proto
,
&
op_checker
);
PADDLE_ENFORCE
(
op_proto
.
IsInitialized
()
==
true
,
"Fail to initialize %s's OpProto !"
,
op_type
);
}
static
OpBase
*
CreateOp
(
const
OpDesc
&
op_desc
)
{
std
::
string
op_type
=
op_desc
.
type
();
OpBase
*
op
=
(
creators_
.
at
(
op_type
))();
(
op
->
inputs_
).
resize
(
op_desc
.
inputs_size
());
for
(
int
i
=
0
;
i
<
op_desc
.
inputs_size
();
++
i
)
{
(
op
->
inputs_
)[
i
]
=
op_desc
.
inputs
(
i
);
}
(
op
->
outputs_
).
resize
(
op_desc
.
outputs_size
());
for
(
int
i
=
0
;
i
<
op_desc
.
outputs_size
();
++
i
)
{
(
op
->
outputs_
)[
i
]
=
op_desc
.
outputs
(
i
);
}
for
(
int
i
=
0
;
i
<
op_desc
.
attrs_size
();
++
i
)
{
const
AttrDesc
&
ith_attr
=
op_desc
.
attrs
(
i
);
std
::
string
name
=
ith_attr
.
name
();
(
op
->
attr_map_
)[
name
]
=
AttrTypeHelper
::
GetAttrValue
(
ith_attr
);
}
const
OpAttrChecker
&
op_checker
=
op_checkers_
.
at
(
op_type
);
op_checker
.
Check
(
op
->
attr_map_
);
return
op
;
}
private:
static
std
::
unordered_map
<
std
::
string
,
OpCreator
>
creators_
;
static
std
::
unordered_map
<
std
::
string
,
OpProto
>
protos_
;
static
std
::
unordered_map
<
std
::
string
,
OpAttrChecker
>
op_checkers_
;
};
std
::
unordered_map
<
std
::
string
,
std
::
function
<
OpBase
*
()
>>
OpRegistry
::
creators_
;
std
::
unordered_map
<
std
::
string
,
OpProto
>
OpRegistry
::
protos_
;
std
::
unordered_map
<
std
::
string
,
OpAttrChecker
>
OpRegistry
::
op_checkers_
;
template
<
typename
OpType
,
typename
ProtoMakerType
>
class
OpRegisterHelper
{
public:
OpRegisterHelper
(
std
::
string
op_type
)
{
OpRegistry
::
RegisterOp
<
OpType
,
ProtoMakerType
>
(
op_type
);
}
};
#define REGISTER_OP(__op_class, __op_maker_class, __op_type) \
class __op_class##Register { \
private: \
const static OpRegisterHelper<__op_class, __op_maker_class> reg; \
}; \
const OpRegisterHelper<__op_class, __op_maker_class> \
__op_class##Register::reg(#__op_type);
// Demos
class
CosineOp
:
public
OpBase
{
public:
virtual
std
::
string
Run
()
const
{
std
::
string
msg
=
"CosineOp runs! scale = "
+
std
::
to_string
(
boost
::
get
<
float
>
(
attr_map_
.
at
(
"scale"
)));
return
msg
;
}
};
class
CosineOpProtoAndCheckerMaker
:
public
OpProtoAndCheckerMaker
{
public:
CosineOpProtoAndCheckerMaker
(
OpProto
*
proto
,
OpAttrChecker
*
op_checker
)
:
OpProtoAndCheckerMaker
(
proto
,
op_checker
)
{
AddInput
(
"input"
,
"input of cosine op"
);
AddOutput
(
"output"
,
"output of cosine op"
);
AddAttr
<
float
>
(
"scale"
,
"scale of cosine op"
)
.
SetDefault
(
1.0
)
.
LargerThan
(
0.0
);
AddType
(
"cos"
);
AddComment
(
"This is cos op"
);
}
};
REGISTER_OP
(
CosineOp
,
CosineOpProtoAndCheckerMaker
,
cos_sim
)
class
MyTestOp
:
public
OpBase
{
public:
virtual
std
::
string
Run
()
const
{
std
::
string
msg
=
"MyTestOp runs! test_attr = "
+
std
::
to_string
(
boost
::
get
<
int
>
(
attr_map_
.
at
(
"test_attr"
)));
return
msg
;
}
};
class
MyTestOpProtoAndCheckerMaker
:
public
OpProtoAndCheckerMaker
{
public:
MyTestOpProtoAndCheckerMaker
(
OpProto
*
proto
,
OpAttrChecker
*
op_checker
)
:
OpProtoAndCheckerMaker
(
proto
,
op_checker
)
{
AddInput
(
"input"
,
"input of cosine op"
);
AddOutput
(
"output"
,
"output of cosine op"
);
auto
my_checker
=
[](
int
i
)
{
PADDLE_ENFORCE
(
i
%
2
==
0
,
"'test_attr' must be even!"
);
};
AddAttr
<
int
>
(
"test_attr"
,
"a simple test attribute"
)
.
AddCustomChecker
(
my_checker
);
AddType
(
"my_test_op"
);
AddComment
(
"This is my_test op"
);
}
};
REGISTER_OP
(
MyTestOp
,
MyTestOpProtoAndCheckerMaker
,
my_test_op
)
}
// namespace framework
}
// namespace paddle
paddle/framework/op_registry_test.cc
0 → 100644
浏览文件 @
8426beb4
#include "paddle/framework/op_registry.h"
#include <gtest/gtest.h>
TEST
(
OpRegistry
,
CreateOp
)
{
paddle
::
framework
::
OpDesc
op_desc
;
op_desc
.
set_type
(
"cos_sim"
);
op_desc
.
add_inputs
(
"aa"
);
op_desc
.
add_outputs
(
"bb"
);
auto
attr
=
op_desc
.
mutable_attrs
()
->
Add
();
attr
->
set_name
(
"scale"
);
attr
->
set_type
(
paddle
::
framework
::
AttrType
::
FLOAT
);
attr
->
set_f
(
3.3
);
paddle
::
framework
::
OpBase
*
op
=
paddle
::
framework
::
OpRegistry
::
CreateOp
(
op_desc
);
std
::
string
debug_str
=
op
->
Run
();
std
::
string
str
=
"CosineOp runs! scale = "
+
std
::
to_string
(
3.3
);
ASSERT_EQ
(
str
.
size
(),
debug_str
.
size
());
for
(
size_t
i
=
0
;
i
<
debug_str
.
length
();
++
i
)
{
ASSERT_EQ
(
debug_str
[
i
],
str
[
i
]);
}
}
TEST
(
OpRegistry
,
IllegalAttr
)
{
paddle
::
framework
::
OpDesc
op_desc
;
op_desc
.
set_type
(
"cos_sim"
);
op_desc
.
add_inputs
(
"aa"
);
op_desc
.
add_outputs
(
"bb"
);
auto
attr
=
op_desc
.
mutable_attrs
()
->
Add
();
attr
->
set_name
(
"scale"
);
attr
->
set_type
(
paddle
::
framework
::
AttrType
::
FLOAT
);
attr
->
set_f
(
-
2.0
);
bool
caught
=
false
;
try
{
paddle
::
framework
::
OpBase
*
op
__attribute__
((
unused
))
=
paddle
::
framework
::
OpRegistry
::
CreateOp
(
op_desc
);
}
catch
(
paddle
::
framework
::
EnforceNotMet
err
)
{
caught
=
true
;
std
::
string
msg
=
"larger_than check fail"
;
const
char
*
err_msg
=
err
.
what
();
for
(
size_t
i
=
0
;
i
<
msg
.
length
();
++
i
)
{
ASSERT_EQ
(
err_msg
[
i
],
msg
[
i
]);
}
}
ASSERT_TRUE
(
caught
);
}
TEST
(
OpRegistry
,
DefaultValue
)
{
paddle
::
framework
::
OpDesc
op_desc
;
op_desc
.
set_type
(
"cos_sim"
);
op_desc
.
add_inputs
(
"aa"
);
op_desc
.
add_outputs
(
"bb"
);
paddle
::
framework
::
OpBase
*
op
=
paddle
::
framework
::
OpRegistry
::
CreateOp
(
op_desc
);
std
::
string
debug_str
=
op
->
Run
();
float
default_value
=
1.0
;
std
::
string
str
=
"CosineOp runs! scale = "
+
std
::
to_string
(
default_value
);
ASSERT_EQ
(
str
.
size
(),
debug_str
.
size
());
for
(
size_t
i
=
0
;
i
<
debug_str
.
length
();
++
i
)
{
ASSERT_EQ
(
debug_str
[
i
],
str
[
i
]);
}
}
TEST
(
OpRegistry
,
CustomChecker
)
{
paddle
::
framework
::
OpDesc
op_desc
;
op_desc
.
set_type
(
"my_test_op"
);
op_desc
.
add_inputs
(
"ii"
);
op_desc
.
add_outputs
(
"oo"
);
// attr 'test_attr' is not set
bool
caught
=
false
;
try
{
paddle
::
framework
::
OpBase
*
op
__attribute__
((
unused
))
=
paddle
::
framework
::
OpRegistry
::
CreateOp
(
op_desc
);
}
catch
(
paddle
::
framework
::
EnforceNotMet
err
)
{
caught
=
true
;
std
::
string
msg
=
"Attribute 'test_attr' is required!"
;
const
char
*
err_msg
=
err
.
what
();
for
(
size_t
i
=
0
;
i
<
msg
.
length
();
++
i
)
{
ASSERT_EQ
(
err_msg
[
i
],
msg
[
i
]);
}
}
ASSERT_TRUE
(
caught
);
// set 'test_attr' set to an illegal value
auto
attr
=
op_desc
.
mutable_attrs
()
->
Add
();
attr
->
set_name
(
"test_attr"
);
attr
->
set_type
(
paddle
::
framework
::
AttrType
::
INT
);
attr
->
set_i
(
3
);
caught
=
false
;
try
{
paddle
::
framework
::
OpBase
*
op
__attribute__
((
unused
))
=
paddle
::
framework
::
OpRegistry
::
CreateOp
(
op_desc
);
}
catch
(
paddle
::
framework
::
EnforceNotMet
err
)
{
caught
=
true
;
std
::
string
msg
=
"'test_attr' must be even!"
;
const
char
*
err_msg
=
err
.
what
();
for
(
size_t
i
=
0
;
i
<
msg
.
length
();
++
i
)
{
ASSERT_EQ
(
err_msg
[
i
],
msg
[
i
]);
}
}
ASSERT_TRUE
(
caught
);
// set 'test_attr' set to a legal value
op_desc
.
mutable_attrs
()
->
Clear
();
attr
=
op_desc
.
mutable_attrs
()
->
Add
();
attr
->
set_name
(
"test_attr"
);
attr
->
set_type
(
paddle
::
framework
::
AttrType
::
INT
);
attr
->
set_i
(
4
);
paddle
::
framework
::
OpBase
*
op
=
paddle
::
framework
::
OpRegistry
::
CreateOp
(
op_desc
);
std
::
string
debug_str
=
op
->
Run
();
std
::
string
str
=
"MyTestOp runs! test_attr = "
+
std
::
to_string
(
4
);
ASSERT_EQ
(
str
.
size
(),
debug_str
.
size
());
for
(
size_t
i
=
0
;
i
<
debug_str
.
length
();
++
i
)
{
ASSERT_EQ
(
debug_str
[
i
],
str
[
i
]);
}
}
\ No newline at end of file
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录