Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
8426beb4
P
Paddle
项目概览
PaddlePaddle
/
Paddle
1 年多 前同步成功
通知
2305
Star
20932
Fork
5423
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
8426beb4
编写于
7月 07, 2017
作者:
D
dongzhihong
浏览文件
操作
浏览文件
下载
差异文件
Merge remote-tracking branch 'origin/develop' into save_state
上级
40295b9e
7b810553
变更
6
隐藏空白更改
内联
并排
Showing
6 changed file
with
498 addition
and
1 deletion
+498
-1
cmake/external/glog.cmake
cmake/external/glog.cmake
+2
-0
doc_theme/templates/layout.html
doc_theme/templates/layout.html
+1
-1
paddle/framework/CMakeLists.txt
paddle/framework/CMakeLists.txt
+1
-0
paddle/framework/attr_checker.h
paddle/framework/attr_checker.h
+119
-0
paddle/framework/op_registry.h
paddle/framework/op_registry.h
+253
-0
paddle/framework/op_registry_test.cc
paddle/framework/op_registry_test.cc
+122
-0
未找到文件。
cmake/external/glog.cmake
浏览文件 @
8426beb4
...
...
@@ -38,12 +38,14 @@ ExternalProject_Add(
CMAKE_ARGS -DCMAKE_CXX_FLAGS=
${
CMAKE_CXX_FLAGS
}
CMAKE_ARGS -DCMAKE_C_FLAGS=
${
CMAKE_C_FLAGS
}
CMAKE_ARGS -DCMAKE_INSTALL_PREFIX=
${
GLOG_INSTALL_DIR
}
CMAKE_ARGS -DCMAKE_INSTALL_LIBDIR=
${
GLOG_INSTALL_DIR
}
/lib
CMAKE_ARGS -DCMAKE_POSITION_INDEPENDENT_CODE=ON
CMAKE_ARGS -DWITH_GFLAGS=ON
CMAKE_ARGS -Dgflags_DIR=
${
GFLAGS_INSTALL_DIR
}
/lib/cmake/gflags
CMAKE_ARGS -DBUILD_TESTING=OFF
CMAKE_ARGS -DCMAKE_BUILD_TYPE=Release
CMAKE_CACHE_ARGS -DCMAKE_INSTALL_PREFIX:PATH=
${
GLOG_INSTALL_DIR
}
-DCMAKE_INSTALL_LIBDIR:PATH=
${
GLOG_INSTALL_DIR
}
/lib
-DCMAKE_POSITION_INDEPENDENT_CODE:BOOL=ON
-DCMAKE_BUILD_TYPE:STRING=Release
)
...
...
doc_theme/templates/layout.html
浏览文件 @
8426beb4
...
...
@@ -101,7 +101,7 @@
</div>
<div
class=
"site-nav-links"
>
<div
class=
"site-menu"
>
<a
class=
"fork-on-github"
href=
"https://github.com/PaddlePaddle/Paddle"
target=
"_blank"
><i
class=
"fa fa-github"
></i>
Fo
l
k me on Github
</a>
<a
class=
"fork-on-github"
href=
"https://github.com/PaddlePaddle/Paddle"
target=
"_blank"
><i
class=
"fa fa-github"
></i>
Fo
r
k me on Github
</a>
<div
class=
"language-switcher dropdown"
>
<a
type=
"button"
data-toggle=
"dropdown"
>
<span>
English
</span>
...
...
paddle/framework/CMakeLists.txt
浏览文件 @
8426beb4
...
...
@@ -11,6 +11,7 @@ proto_library(op_proto SRCS op_proto.proto DEPS attr_type)
cc_test
(
op_proto_test SRCS op_proto_test.cc DEPS op_proto protobuf
)
proto_library
(
op_desc SRCS op_desc.proto DEPS attr_type
)
cc_test
(
op_desc_test SRCS op_desc_test.cc DEPS op_desc protobuf
)
cc_test
(
op_registry_test SRCS op_registry_test.cc DEPS op_proto op_desc
)
py_proto_compile
(
framework_py_proto SRCS attr_type.proto op_proto.proto op_desc.proto
)
# Generate an empty __init__.py to make framework_py_proto as a valid python module.
add_custom_target
(
framework_py_proto_init ALL COMMAND
${
CMAKE_COMMAND
}
-E touch __init__.py
)
...
...
paddle/framework/attr_checker.h
0 → 100644
浏览文件 @
8426beb4
#pragma once
#include <boost/variant.hpp>
#include <functional>
#include <string>
#include <unordered_map>
#include <vector>
#include "paddle/framework/enforce.h"
namespace
paddle
{
namespace
framework
{
typedef
boost
::
variant
<
boost
::
blank
,
int
,
float
,
std
::
string
,
std
::
vector
<
int
>
,
std
::
vector
<
float
>
,
std
::
vector
<
std
::
string
>>
Attribute
;
typedef
std
::
unordered_map
<
std
::
string
,
Attribute
>
AttributeMap
;
// check whether a value(attribute) fit a certain limit
template
<
typename
T
>
class
LargerThanChecker
{
public:
LargerThanChecker
(
T
lower_bound
)
:
lower_bound_
(
lower_bound
)
{}
void
operator
()(
T
&
value
)
const
{
PADDLE_ENFORCE
(
value
>
lower_bound_
,
"larger_than check fail"
);
}
private:
T
lower_bound_
;
};
// we can provide users more common Checker, like 'LessThanChecker',
// 'BetweenChecker'...
template
<
typename
T
>
class
DefaultValueSetter
{
public:
DefaultValueSetter
(
T
default_value
)
:
default_value_
(
default_value
)
{}
void
operator
()(
T
&
value
)
const
{
value
=
default_value_
;
}
private:
T
default_value_
;
};
// check whether a certain attribute fit its limits
// an attribute can have more than one limits
template
<
typename
T
>
class
TypedAttrChecker
{
typedef
std
::
function
<
void
(
T
&
)
>
ValueChecker
;
public:
TypedAttrChecker
(
const
std
::
string
&
attr_name
)
:
attr_name_
(
attr_name
)
{}
TypedAttrChecker
&
LargerThan
(
const
T
&
lower_bound
)
{
value_checkers_
.
push_back
(
LargerThanChecker
<
T
>
(
lower_bound
));
return
*
this
;
}
// we can add more common limits, like LessThan(), Between()...
TypedAttrChecker
&
SetDefault
(
const
T
&
default_value
)
{
PADDLE_ENFORCE
(
default_value_setter_
.
empty
(),
"%s can't have more than one default value!"
,
attr_name_
);
default_value_setter_
.
push_back
(
DefaultValueSetter
<
T
>
(
default_value
));
return
*
this
;
}
// allow users provide their own checker
TypedAttrChecker
&
AddCustomChecker
(
const
ValueChecker
&
checker
)
{
value_checkers_
.
push_back
(
checker
);
return
*
this
;
}
void
operator
()(
AttributeMap
&
attr_map
)
const
{
if
(
!
attr_map
.
count
(
attr_name_
))
{
// user do not set this attr
PADDLE_ENFORCE
(
!
default_value_setter_
.
empty
(),
"Attribute '%s' is required!"
,
attr_name_
);
// default_value_setter_ has no more than one element
T
val
;
(
default_value_setter_
[
0
])(
val
);
attr_map
[
attr_name_
]
=
val
;
}
Attribute
&
attr
=
attr_map
.
at
(
attr_name_
);
T
&
attr_value
=
boost
::
get
<
T
>
(
attr
);
for
(
const
auto
&
checker
:
value_checkers_
)
{
checker
(
attr_value
);
}
}
private:
std
::
string
attr_name_
;
std
::
vector
<
ValueChecker
>
value_checkers_
;
std
::
vector
<
ValueChecker
>
default_value_setter_
;
};
// check whether op's all attributes fit their own limits
class
OpAttrChecker
{
typedef
std
::
function
<
void
(
AttributeMap
&
)
>
AttrChecker
;
public:
template
<
typename
T
>
TypedAttrChecker
<
T
>&
AddAttrChecker
(
const
std
::
string
&
attr_name
)
{
attr_checkers_
.
push_back
(
TypedAttrChecker
<
T
>
(
attr_name
));
AttrChecker
&
checker
=
attr_checkers_
.
back
();
return
*
(
checker
.
target
<
TypedAttrChecker
<
T
>>
());
}
void
Check
(
AttributeMap
&
attr_map
)
const
{
for
(
const
auto
&
checker
:
attr_checkers_
)
{
checker
(
attr_map
);
}
}
private:
std
::
vector
<
AttrChecker
>
attr_checkers_
;
};
}
// namespace framework
}
// namespace paddle
paddle/framework/op_registry.h
0 → 100644
浏览文件 @
8426beb4
#pragma once
#include "paddle/framework/attr_checker.h"
//#include "paddle/framework/op_base.h"
#include "paddle/framework/op_desc.pb.h"
#include "paddle/framework/op_proto.pb.h"
namespace
paddle
{
namespace
framework
{
//==================For test================//
class
OpBase
{
public:
std
::
vector
<
std
::
string
>
inputs_
;
std
::
vector
<
std
::
string
>
outputs_
;
AttributeMap
attr_map_
;
virtual
std
::
string
Run
()
const
=
0
;
virtual
~
OpBase
()
{}
};
//=========================================//
// helper class to set attribute type
struct
AttrTypeHelper
{
template
<
typename
T
>
static
void
SetAttrType
(
AttrProto
*
attr
);
static
Attribute
GetAttrValue
(
const
AttrDesc
&
attr_desc
)
{
switch
(
attr_desc
.
type
())
{
case
paddle
::
framework
::
AttrType
::
INT
:
{
return
attr_desc
.
i
();
}
case
paddle
::
framework
::
AttrType
::
FLOAT
:
{
return
attr_desc
.
f
();
}
case
paddle
::
framework
::
AttrType
::
STRING
:
{
return
attr_desc
.
s
();
}
case
paddle
::
framework
::
AttrType
::
INTS
:
{
std
::
vector
<
int
>
val
(
attr_desc
.
ints_size
());
for
(
int
i
=
0
;
i
<
attr_desc
.
ints_size
();
++
i
)
{
val
[
i
]
=
attr_desc
.
ints
(
i
);
}
return
val
;
}
case
paddle
::
framework
::
AttrType
::
FLOATS
:
{
std
::
vector
<
float
>
val
(
attr_desc
.
floats_size
());
for
(
int
i
=
0
;
i
<
attr_desc
.
floats_size
();
++
i
)
{
val
[
i
]
=
attr_desc
.
floats
(
i
);
}
return
val
;
}
case
paddle
::
framework
::
AttrType
::
STRINGS
:
{
std
::
vector
<
std
::
string
>
val
(
attr_desc
.
strings_size
());
for
(
int
i
=
0
;
i
<
attr_desc
.
strings_size
();
++
i
)
{
val
[
i
]
=
attr_desc
.
strings
(
i
);
}
return
val
;
}
}
PADDLE_ENFORCE
(
false
,
"Unknown OpDesc::AttrDesc::type !"
);
return
boost
::
blank
();
}
};
template
<
>
void
AttrTypeHelper
::
SetAttrType
<
int
>
(
AttrProto
*
attr
)
{
attr
->
set_type
(
paddle
::
framework
::
AttrType
::
INT
);
}
template
<
>
void
AttrTypeHelper
::
SetAttrType
<
float
>
(
AttrProto
*
attr
)
{
attr
->
set_type
(
paddle
::
framework
::
AttrType
::
FLOAT
);
}
template
<
>
void
AttrTypeHelper
::
SetAttrType
<
std
::
string
>
(
AttrProto
*
attr
)
{
attr
->
set_type
(
paddle
::
framework
::
AttrType
::
STRING
);
}
template
<
>
void
AttrTypeHelper
::
SetAttrType
<
std
::
vector
<
int
>>
(
AttrProto
*
attr
)
{
attr
->
set_type
(
paddle
::
framework
::
AttrType
::
INTS
);
}
template
<
>
void
AttrTypeHelper
::
SetAttrType
<
std
::
vector
<
float
>>
(
AttrProto
*
attr
)
{
attr
->
set_type
(
paddle
::
framework
::
AttrType
::
FLOATS
);
}
template
<
>
void
AttrTypeHelper
::
SetAttrType
<
std
::
vector
<
std
::
string
>>
(
AttrProto
*
attr
)
{
attr
->
set_type
(
paddle
::
framework
::
AttrType
::
STRINGS
);
}
// this class not only make proto but also init attribute checkers.
class
OpProtoAndCheckerMaker
{
public:
OpProtoAndCheckerMaker
(
OpProto
*
proto
,
OpAttrChecker
*
op_checker
)
:
proto_
(
proto
),
op_checker_
(
op_checker
)
{}
protected:
void
AddInput
(
const
std
::
string
&
name
,
const
std
::
string
&
comment
)
{
auto
input
=
proto_
->
mutable_inputs
()
->
Add
();
*
(
input
->
mutable_name
())
=
name
;
*
(
input
->
mutable_comment
())
=
comment
;
}
void
AddOutput
(
const
std
::
string
&
name
,
const
std
::
string
&
comment
)
{
auto
output
=
proto_
->
mutable_outputs
()
->
Add
();
*
(
output
->
mutable_name
())
=
name
;
*
(
output
->
mutable_comment
())
=
comment
;
}
template
<
typename
T
>
TypedAttrChecker
<
T
>&
AddAttr
(
const
std
::
string
&
name
,
const
std
::
string
&
comment
)
{
auto
attr
=
proto_
->
mutable_attrs
()
->
Add
();
*
(
attr
->
mutable_name
())
=
name
;
*
(
attr
->
mutable_comment
())
=
comment
;
AttrTypeHelper
::
SetAttrType
<
T
>
(
attr
);
return
op_checker_
->
AddAttrChecker
<
T
>
(
name
);
}
void
AddType
(
const
std
::
string
&
op_type
)
{
proto_
->
set_type
(
op_type
);
}
void
AddComment
(
const
std
::
string
&
comment
)
{
*
(
proto_
->
mutable_comment
())
=
comment
;
}
OpProto
*
proto_
;
OpAttrChecker
*
op_checker_
;
};
class
OpRegistry
{
typedef
std
::
function
<
OpBase
*
()
>
OpCreator
;
public:
template
<
typename
OpType
,
typename
ProtoMakerType
>
static
void
RegisterOp
(
const
std
::
string
&
op_type
)
{
creators_
[
op_type
]
=
[]()
{
return
new
OpType
;
};
OpProto
&
op_proto
=
protos_
[
op_type
];
OpAttrChecker
&
op_checker
=
op_checkers_
[
op_type
];
ProtoMakerType
(
&
op_proto
,
&
op_checker
);
PADDLE_ENFORCE
(
op_proto
.
IsInitialized
()
==
true
,
"Fail to initialize %s's OpProto !"
,
op_type
);
}
static
OpBase
*
CreateOp
(
const
OpDesc
&
op_desc
)
{
std
::
string
op_type
=
op_desc
.
type
();
OpBase
*
op
=
(
creators_
.
at
(
op_type
))();
(
op
->
inputs_
).
resize
(
op_desc
.
inputs_size
());
for
(
int
i
=
0
;
i
<
op_desc
.
inputs_size
();
++
i
)
{
(
op
->
inputs_
)[
i
]
=
op_desc
.
inputs
(
i
);
}
(
op
->
outputs_
).
resize
(
op_desc
.
outputs_size
());
for
(
int
i
=
0
;
i
<
op_desc
.
outputs_size
();
++
i
)
{
(
op
->
outputs_
)[
i
]
=
op_desc
.
outputs
(
i
);
}
for
(
int
i
=
0
;
i
<
op_desc
.
attrs_size
();
++
i
)
{
const
AttrDesc
&
ith_attr
=
op_desc
.
attrs
(
i
);
std
::
string
name
=
ith_attr
.
name
();
(
op
->
attr_map_
)[
name
]
=
AttrTypeHelper
::
GetAttrValue
(
ith_attr
);
}
const
OpAttrChecker
&
op_checker
=
op_checkers_
.
at
(
op_type
);
op_checker
.
Check
(
op
->
attr_map_
);
return
op
;
}
private:
static
std
::
unordered_map
<
std
::
string
,
OpCreator
>
creators_
;
static
std
::
unordered_map
<
std
::
string
,
OpProto
>
protos_
;
static
std
::
unordered_map
<
std
::
string
,
OpAttrChecker
>
op_checkers_
;
};
std
::
unordered_map
<
std
::
string
,
std
::
function
<
OpBase
*
()
>>
OpRegistry
::
creators_
;
std
::
unordered_map
<
std
::
string
,
OpProto
>
OpRegistry
::
protos_
;
std
::
unordered_map
<
std
::
string
,
OpAttrChecker
>
OpRegistry
::
op_checkers_
;
template
<
typename
OpType
,
typename
ProtoMakerType
>
class
OpRegisterHelper
{
public:
OpRegisterHelper
(
std
::
string
op_type
)
{
OpRegistry
::
RegisterOp
<
OpType
,
ProtoMakerType
>
(
op_type
);
}
};
#define REGISTER_OP(__op_class, __op_maker_class, __op_type) \
class __op_class##Register { \
private: \
const static OpRegisterHelper<__op_class, __op_maker_class> reg; \
}; \
const OpRegisterHelper<__op_class, __op_maker_class> \
__op_class##Register::reg(#__op_type);
// Demos
class
CosineOp
:
public
OpBase
{
public:
virtual
std
::
string
Run
()
const
{
std
::
string
msg
=
"CosineOp runs! scale = "
+
std
::
to_string
(
boost
::
get
<
float
>
(
attr_map_
.
at
(
"scale"
)));
return
msg
;
}
};
class
CosineOpProtoAndCheckerMaker
:
public
OpProtoAndCheckerMaker
{
public:
CosineOpProtoAndCheckerMaker
(
OpProto
*
proto
,
OpAttrChecker
*
op_checker
)
:
OpProtoAndCheckerMaker
(
proto
,
op_checker
)
{
AddInput
(
"input"
,
"input of cosine op"
);
AddOutput
(
"output"
,
"output of cosine op"
);
AddAttr
<
float
>
(
"scale"
,
"scale of cosine op"
)
.
SetDefault
(
1.0
)
.
LargerThan
(
0.0
);
AddType
(
"cos"
);
AddComment
(
"This is cos op"
);
}
};
REGISTER_OP
(
CosineOp
,
CosineOpProtoAndCheckerMaker
,
cos_sim
)
class
MyTestOp
:
public
OpBase
{
public:
virtual
std
::
string
Run
()
const
{
std
::
string
msg
=
"MyTestOp runs! test_attr = "
+
std
::
to_string
(
boost
::
get
<
int
>
(
attr_map_
.
at
(
"test_attr"
)));
return
msg
;
}
};
class
MyTestOpProtoAndCheckerMaker
:
public
OpProtoAndCheckerMaker
{
public:
MyTestOpProtoAndCheckerMaker
(
OpProto
*
proto
,
OpAttrChecker
*
op_checker
)
:
OpProtoAndCheckerMaker
(
proto
,
op_checker
)
{
AddInput
(
"input"
,
"input of cosine op"
);
AddOutput
(
"output"
,
"output of cosine op"
);
auto
my_checker
=
[](
int
i
)
{
PADDLE_ENFORCE
(
i
%
2
==
0
,
"'test_attr' must be even!"
);
};
AddAttr
<
int
>
(
"test_attr"
,
"a simple test attribute"
)
.
AddCustomChecker
(
my_checker
);
AddType
(
"my_test_op"
);
AddComment
(
"This is my_test op"
);
}
};
REGISTER_OP
(
MyTestOp
,
MyTestOpProtoAndCheckerMaker
,
my_test_op
)
}
// namespace framework
}
// namespace paddle
paddle/framework/op_registry_test.cc
0 → 100644
浏览文件 @
8426beb4
#include "paddle/framework/op_registry.h"
#include <gtest/gtest.h>
TEST
(
OpRegistry
,
CreateOp
)
{
paddle
::
framework
::
OpDesc
op_desc
;
op_desc
.
set_type
(
"cos_sim"
);
op_desc
.
add_inputs
(
"aa"
);
op_desc
.
add_outputs
(
"bb"
);
auto
attr
=
op_desc
.
mutable_attrs
()
->
Add
();
attr
->
set_name
(
"scale"
);
attr
->
set_type
(
paddle
::
framework
::
AttrType
::
FLOAT
);
attr
->
set_f
(
3.3
);
paddle
::
framework
::
OpBase
*
op
=
paddle
::
framework
::
OpRegistry
::
CreateOp
(
op_desc
);
std
::
string
debug_str
=
op
->
Run
();
std
::
string
str
=
"CosineOp runs! scale = "
+
std
::
to_string
(
3.3
);
ASSERT_EQ
(
str
.
size
(),
debug_str
.
size
());
for
(
size_t
i
=
0
;
i
<
debug_str
.
length
();
++
i
)
{
ASSERT_EQ
(
debug_str
[
i
],
str
[
i
]);
}
}
TEST
(
OpRegistry
,
IllegalAttr
)
{
paddle
::
framework
::
OpDesc
op_desc
;
op_desc
.
set_type
(
"cos_sim"
);
op_desc
.
add_inputs
(
"aa"
);
op_desc
.
add_outputs
(
"bb"
);
auto
attr
=
op_desc
.
mutable_attrs
()
->
Add
();
attr
->
set_name
(
"scale"
);
attr
->
set_type
(
paddle
::
framework
::
AttrType
::
FLOAT
);
attr
->
set_f
(
-
2.0
);
bool
caught
=
false
;
try
{
paddle
::
framework
::
OpBase
*
op
__attribute__
((
unused
))
=
paddle
::
framework
::
OpRegistry
::
CreateOp
(
op_desc
);
}
catch
(
paddle
::
framework
::
EnforceNotMet
err
)
{
caught
=
true
;
std
::
string
msg
=
"larger_than check fail"
;
const
char
*
err_msg
=
err
.
what
();
for
(
size_t
i
=
0
;
i
<
msg
.
length
();
++
i
)
{
ASSERT_EQ
(
err_msg
[
i
],
msg
[
i
]);
}
}
ASSERT_TRUE
(
caught
);
}
TEST
(
OpRegistry
,
DefaultValue
)
{
paddle
::
framework
::
OpDesc
op_desc
;
op_desc
.
set_type
(
"cos_sim"
);
op_desc
.
add_inputs
(
"aa"
);
op_desc
.
add_outputs
(
"bb"
);
paddle
::
framework
::
OpBase
*
op
=
paddle
::
framework
::
OpRegistry
::
CreateOp
(
op_desc
);
std
::
string
debug_str
=
op
->
Run
();
float
default_value
=
1.0
;
std
::
string
str
=
"CosineOp runs! scale = "
+
std
::
to_string
(
default_value
);
ASSERT_EQ
(
str
.
size
(),
debug_str
.
size
());
for
(
size_t
i
=
0
;
i
<
debug_str
.
length
();
++
i
)
{
ASSERT_EQ
(
debug_str
[
i
],
str
[
i
]);
}
}
TEST
(
OpRegistry
,
CustomChecker
)
{
paddle
::
framework
::
OpDesc
op_desc
;
op_desc
.
set_type
(
"my_test_op"
);
op_desc
.
add_inputs
(
"ii"
);
op_desc
.
add_outputs
(
"oo"
);
// attr 'test_attr' is not set
bool
caught
=
false
;
try
{
paddle
::
framework
::
OpBase
*
op
__attribute__
((
unused
))
=
paddle
::
framework
::
OpRegistry
::
CreateOp
(
op_desc
);
}
catch
(
paddle
::
framework
::
EnforceNotMet
err
)
{
caught
=
true
;
std
::
string
msg
=
"Attribute 'test_attr' is required!"
;
const
char
*
err_msg
=
err
.
what
();
for
(
size_t
i
=
0
;
i
<
msg
.
length
();
++
i
)
{
ASSERT_EQ
(
err_msg
[
i
],
msg
[
i
]);
}
}
ASSERT_TRUE
(
caught
);
// set 'test_attr' set to an illegal value
auto
attr
=
op_desc
.
mutable_attrs
()
->
Add
();
attr
->
set_name
(
"test_attr"
);
attr
->
set_type
(
paddle
::
framework
::
AttrType
::
INT
);
attr
->
set_i
(
3
);
caught
=
false
;
try
{
paddle
::
framework
::
OpBase
*
op
__attribute__
((
unused
))
=
paddle
::
framework
::
OpRegistry
::
CreateOp
(
op_desc
);
}
catch
(
paddle
::
framework
::
EnforceNotMet
err
)
{
caught
=
true
;
std
::
string
msg
=
"'test_attr' must be even!"
;
const
char
*
err_msg
=
err
.
what
();
for
(
size_t
i
=
0
;
i
<
msg
.
length
();
++
i
)
{
ASSERT_EQ
(
err_msg
[
i
],
msg
[
i
]);
}
}
ASSERT_TRUE
(
caught
);
// set 'test_attr' set to a legal value
op_desc
.
mutable_attrs
()
->
Clear
();
attr
=
op_desc
.
mutable_attrs
()
->
Add
();
attr
->
set_name
(
"test_attr"
);
attr
->
set_type
(
paddle
::
framework
::
AttrType
::
INT
);
attr
->
set_i
(
4
);
paddle
::
framework
::
OpBase
*
op
=
paddle
::
framework
::
OpRegistry
::
CreateOp
(
op_desc
);
std
::
string
debug_str
=
op
->
Run
();
std
::
string
str
=
"MyTestOp runs! test_attr = "
+
std
::
to_string
(
4
);
ASSERT_EQ
(
str
.
size
(),
debug_str
.
size
());
for
(
size_t
i
=
0
;
i
<
debug_str
.
length
();
++
i
)
{
ASSERT_EQ
(
debug_str
[
i
],
str
[
i
]);
}
}
\ No newline at end of file
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录