Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
机器未来
Paddle
提交
be441f7d
P
Paddle
项目概览
机器未来
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
be441f7d
编写于
7月 12, 2017
作者:
Q
Qiao Longfei
提交者:
GitHub
7月 12, 2017
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
test OpKernel (#2820)
Add unit test for OpKernel
上级
555b0a72
变更
4
隐藏空白更改
内联
并排
Showing
4 changed file
with
61 addition
and
11 deletion
+61
-11
paddle/framework/CMakeLists.txt
paddle/framework/CMakeLists.txt
+1
-1
paddle/framework/op_registry.h
paddle/framework/op_registry.h
+7
-7
paddle/framework/op_registry_test.cc
paddle/framework/op_registry_test.cc
+2
-2
paddle/framework/operator_test.cc
paddle/framework/operator_test.cc
+51
-1
未找到文件。
paddle/framework/CMakeLists.txt
浏览文件 @
be441f7d
...
...
@@ -12,7 +12,7 @@ cc_test(op_proto_test SRCS op_proto_test.cc DEPS op_proto protobuf)
proto_library
(
op_desc SRCS op_desc.proto DEPS attr_type
)
cc_test
(
op_desc_test SRCS op_desc_test.cc DEPS op_desc protobuf
)
cc_library
(
operator SRCS operator.cc DEPS op_desc protobuf
)
cc_test
(
operator_test SRCS operator_test.cc DEPS operator op_registry
)
cc_test
(
operator_test SRCS operator_test.cc DEPS operator op_registry
place
)
cc_library
(
op_registry SRCS op_registry.cc DEPS op_proto op_desc
)
cc_test
(
op_registry_test SRCS op_registry_test.cc DEPS op_registry operator
)
py_proto_compile
(
framework_py_proto SRCS attr_type.proto op_proto.proto op_desc.proto
)
...
...
paddle/framework/op_registry.h
浏览文件 @
be441f7d
...
...
@@ -147,13 +147,13 @@ class OpRegisterHelper {
}
};
#define REGISTER_OP(
__op_class, __op_maker_class, __op_type)
\
class
__op_class##Register {
\
private: \
const static OpRegisterHelper<
__op_class, __op_maker_class> reg;
\
}; \
const OpRegisterHelper<
__op_class, __op_maker_class>
\
__op_class##Register::reg(#__op_type);
#define REGISTER_OP(
type, op_class, op_maker_class)
\
class
op_class##Register {
\
private:
\
const static OpRegisterHelper<
op_class, op_maker_class> reg;
\
};
\
const OpRegisterHelper<
op_class, op_maker_class> op_class##Register::reg(
\
#type)
}
// namespace framework
}
// namespace paddle
paddle/framework/op_registry_test.cc
浏览文件 @
be441f7d
...
...
@@ -26,7 +26,7 @@ class CosineOpProtoAndCheckerMaker : public OpProtoAndCheckerMaker {
}
};
REGISTER_OP
(
CosineOp
,
CosineOpProtoAndCheckerMaker
,
cos_sim
)
REGISTER_OP
(
cos_sim
,
CosineOp
,
CosineOpProtoAndCheckerMaker
);
class
MyTestOp
:
public
OperatorBase
{
public:
...
...
@@ -53,7 +53,7 @@ class MyTestOpProtoAndCheckerMaker : public OpProtoAndCheckerMaker {
}
};
REGISTER_OP
(
MyTestOp
,
MyTestOpProtoAndCheckerMaker
,
my_test_op
)
REGISTER_OP
(
my_test_op
,
MyTestOp
,
MyTestOpProtoAndCheckerMaker
);
}
// namespace framework
}
// namespace paddle
...
...
paddle/framework/operator_test.cc
浏览文件 @
be441f7d
...
...
@@ -45,7 +45,7 @@ class OperatorTestProtoAndCheckerMaker : public OpProtoAndCheckerMaker {
}
};
REGISTER_OP
(
OperatorTest
,
OperatorTestProtoAndCheckerMaker
,
test_operator
)
REGISTER_OP
(
test_operator
,
OperatorTest
,
OperatorTestProtoAndCheckerMaker
);
TEST
(
OperatorBase
,
all
)
{
OpDesc
op_desc
;
...
...
@@ -69,5 +69,55 @@ TEST(OperatorBase, all) {
delete
op
;
}
class
OpKernelTestProtoAndCheckerMaker
:
public
OpProtoAndCheckerMaker
{
public:
OpKernelTestProtoAndCheckerMaker
(
OpProto
*
proto
,
OpAttrChecker
*
op_checker
)
:
OpProtoAndCheckerMaker
(
proto
,
op_checker
)
{
AddInput
(
"input"
,
"input of test op"
);
AddOutput
(
"output"
,
"output of test op"
);
AddAttr
<
float
>
(
"scale"
,
"scale of cosine op"
)
.
SetDefault
(
1.0
)
.
LargerThan
(
0.0
);
AddType
(
"test_operator"
);
AddComment
(
"This is test op"
);
}
};
class
OpWithKernelTest
:
public
OperatorWithKernel
{
public:
void
InferShape
(
const
std
::
shared_ptr
<
Scope
>&
scope
)
const
override
{}
};
class
CPUKernelTest
:
public
OpKernel
{
public:
void
Compute
(
const
KernelContext
&
context
)
const
{
float
scale
=
context
.
op_
.
GetAttr
<
float
>
(
"scale"
);
ASSERT_NEAR
(
scale
,
3.14
,
1e-5
);
std
::
cout
<<
"this is cpu kernel"
<<
std
::
endl
;
std
::
cout
<<
context
.
op_
.
DebugString
()
<<
std
::
endl
;
}
};
REGISTER_OP
(
op_with_kernel
,
OpWithKernelTest
,
OpKernelTestProtoAndCheckerMaker
);
REGISTER_OP_KERNEL
(
op_with_kernel
,
platform
::
CPUPlace
,
CPUKernelTest
);
TEST
(
OpKernel
,
all
)
{
OpDesc
op_desc
;
op_desc
.
set_type
(
"op_with_kernel"
);
*
op_desc
.
mutable_inputs
()
->
Add
()
=
"IN1"
;
*
op_desc
.
mutable_outputs
()
->
Add
()
=
"OUT1"
;
auto
attr
=
op_desc
.
mutable_attrs
()
->
Add
();
attr
->
set_name
(
"scale"
);
attr
->
set_type
(
paddle
::
framework
::
AttrType
::
FLOAT
);
attr
->
set_f
(
3.14
);
platform
::
CPUDeviceContext
cpu_device_context
;
auto
scope
=
std
::
make_shared
<
Scope
>
();
OperatorBase
*
op
=
paddle
::
framework
::
OpRegistry
::
CreateOp
(
op_desc
);
op
->
Run
(
scope
,
cpu_device_context
);
delete
op
;
}
}
// namespace framework
}
// namespace paddle
\ No newline at end of file
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录