Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
87189665
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
87189665
编写于
7月 17, 2017
作者:
Q
qijun
浏览文件
操作
浏览文件
下载
差异文件
merge baidu/develop
上级
2a03e380
a0caf234
变更
5
隐藏空白更改
内联
并排
Showing
5 changed file
with
253 addition
and
38 deletion
+253
-38
paddle/framework/op_registry.h
paddle/framework/op_registry.h
+11
-0
paddle/framework/operator.cc
paddle/framework/operator.cc
+61
-3
paddle/framework/operator.h
paddle/framework/operator.h
+72
-26
paddle/framework/operator_test.cc
paddle/framework/operator_test.cc
+108
-8
paddle/operators/add_op.h
paddle/operators/add_op.h
+1
-1
未找到文件。
paddle/framework/op_registry.h
浏览文件 @
87189665
...
@@ -216,21 +216,32 @@ class OpRegistry {
...
@@ -216,21 +216,32 @@ class OpRegistry {
static
OperatorPtr
CreateOp
(
const
OpDesc
&
op_desc
)
{
static
OperatorPtr
CreateOp
(
const
OpDesc
&
op_desc
)
{
std
::
string
op_type
=
op_desc
.
type
();
std
::
string
op_type
=
op_desc
.
type
();
OperatorPtr
op
(
creators
().
at
(
op_type
)());
OperatorPtr
op
(
creators
().
at
(
op_type
)());
const
OpProto
&
op_proto
=
protos
().
at
(
op_type
);
// set op's inputs_ from desc.
op
->
type_
=
op_desc
.
type
();
op
->
type_
=
op_desc
.
type
();
op
->
inputs_
.
reserve
((
size_t
)
op_desc
.
inputs_size
());
op
->
inputs_
.
reserve
((
size_t
)
op_desc
.
inputs_size
());
std
::
copy
(
op_desc
.
inputs
().
begin
(),
op_desc
.
inputs
().
end
(),
std
::
copy
(
op_desc
.
inputs
().
begin
(),
op_desc
.
inputs
().
end
(),
std
::
back_inserter
(
op
->
inputs_
));
std
::
back_inserter
(
op
->
inputs_
));
// set op's outputs_ from desc.
op
->
outputs_
.
reserve
((
size_t
)
op_desc
.
outputs_size
());
op
->
outputs_
.
reserve
((
size_t
)
op_desc
.
outputs_size
());
std
::
copy
(
op_desc
.
outputs
().
begin
(),
op_desc
.
outputs
().
end
(),
std
::
copy
(
op_desc
.
outputs
().
begin
(),
op_desc
.
outputs
().
end
(),
std
::
back_inserter
(
op
->
outputs_
));
std
::
back_inserter
(
op
->
outputs_
));
// set op's attr;
for
(
auto
&
attr
:
op_desc
.
attrs
())
{
for
(
auto
&
attr
:
op_desc
.
attrs
())
{
op
->
attrs_
[
attr
.
name
()]
=
AttrTypeHelper
::
GetAttrValue
(
attr
);
op
->
attrs_
[
attr
.
name
()]
=
AttrTypeHelper
::
GetAttrValue
(
attr
);
}
}
op_checkers
().
at
(
op_type
).
Check
(
op
->
attrs_
);
op_checkers
().
at
(
op_type
).
Check
(
op
->
attrs_
);
// set argument offsets stored in op.
CreateInOutOffsetMap
(
op
,
op_proto
);
op
->
Init
();
op
->
Init
();
return
op
;
return
op
;
}
}
// init op.in_out_idxs_ to accelerate argument's offset lookup.
static
void
CreateInOutOffsetMap
(
OperatorPtr
op
,
const
OpProto
&
proto
)
{
op
->
CreateInOutOffsetMap
(
proto
);
}
static
std
::
unordered_map
<
std
::
string
,
OpProto
>&
protos
()
{
static
std
::
unordered_map
<
std
::
string
,
OpProto
>&
protos
()
{
static
std
::
unordered_map
<
std
::
string
,
OpProto
>
protos_
;
static
std
::
unordered_map
<
std
::
string
,
OpProto
>
protos_
;
return
protos_
;
return
protos_
;
...
...
paddle/framework/operator.cc
浏览文件 @
87189665
...
@@ -12,25 +12,83 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
...
@@ -12,25 +12,83 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
See the License for the specific language governing permissions and
limitations under the License. */
limitations under the License. */
#include <algorithm>
#include "paddle/framework/operator.h"
#include "paddle/framework/operator.h"
namespace
paddle
{
namespace
paddle
{
namespace
framework
{
namespace
framework
{
template
<
>
template
<
>
Eigen
::
DefaultDevice
*
OpKernel
::
KernelContext
::
GetEigenDevice
<
Eigen
::
DefaultDevice
*
KernelContext
::
GetEigenDevice
<
platform
::
CPUPlace
,
Eigen
::
DefaultDevice
>
()
const
{
platform
::
CPUPlace
,
Eigen
::
DefaultDevice
>
()
const
{
return
device_context_
.
get_eigen_device
<
Eigen
::
DefaultDevice
>
();
return
device_context_
.
get_eigen_device
<
Eigen
::
DefaultDevice
>
();
}
}
#ifndef PADDLE_ONLY_CPU
#ifndef PADDLE_ONLY_CPU
template
<
>
template
<
>
Eigen
::
GpuDevice
*
OpKernel
::
KernelContext
::
GetEigenDevice
<
Eigen
::
GpuDevice
*
platform
::
GPUPlace
,
Eigen
::
GpuDevice
>
()
const
{
KernelContext
::
GetEigenDevice
<
platform
::
GPUPlace
,
Eigen
::
GpuDevice
>
()
const
{
return
device_context_
.
get_eigen_device
<
Eigen
::
GpuDevice
>
();
return
device_context_
.
get_eigen_device
<
Eigen
::
GpuDevice
>
();
}
}
#endif
#endif
void
OperatorBase
::
CreateInOutOffsetMap
(
const
OpProto
&
proto
)
{
PADDLE_ENFORCE
(
in_out_idxs_
.
empty
(),
"duplicate call CreateInOutOffsetMap"
);
for
(
int
i
=
0
;
i
<
proto
.
inputs_size
();
i
++
)
{
const
auto
&
name
=
proto
.
inputs
()[
i
].
name
();
in_out_idxs_
[
name
]
=
i
;
}
for
(
int
i
=
0
;
i
<
proto
.
outputs_size
();
i
++
)
{
const
auto
&
name
=
proto
.
outputs
()[
i
].
name
();
in_out_idxs_
[
name
]
=
i
;
}
}
const
std
::
string
&
OperatorBase
::
Input
(
const
std
::
string
&
name
)
const
{
auto
it
=
in_out_idxs_
.
find
(
name
);
PADDLE_ENFORCE
(
it
!=
in_out_idxs_
.
end
(),
"no key [%s] in in_out_idxs_"
,
name
);
if
(
attrs_
.
count
(
"input_format"
)
==
0
)
{
return
inputs_
[
it
->
second
];
}
else
{
const
auto
&
input_format
=
GetAttr
<
std
::
vector
<
int
>>
(
"input_format"
);
int
idx
=
input_format
[
it
->
second
];
return
inputs_
.
at
(
idx
);
}
}
std
::
vector
<
std
::
string
>
OperatorBase
::
Inputs
(
const
std
::
string
&
name
)
const
{
auto
input_format
=
GetAttr
<
std
::
vector
<
int
>>
(
"input_format"
);
auto
offset
=
in_out_idxs_
.
at
(
name
);
return
std
::
vector
<
std
::
string
>
{
inputs_
.
begin
()
+
input_format
.
at
(
offset
),
inputs_
.
begin
()
+
input_format
.
at
(
offset
+
1
)};
}
const
std
::
string
&
OperatorBase
::
Output
(
const
std
::
string
&
name
)
const
{
auto
it
=
in_out_idxs_
.
find
(
name
);
PADDLE_ENFORCE
(
it
!=
in_out_idxs_
.
end
(),
"no key [%s] in in_out_idxs_"
,
name
);
if
(
attrs_
.
count
(
"output_format"
)
==
0
)
{
return
outputs_
[
it
->
second
];
}
else
{
const
auto
&
output_format
=
GetAttr
<
std
::
vector
<
int
>>
(
"output_format"
);
int
idx
=
output_format
[
it
->
second
];
return
outputs_
.
at
(
idx
);
}
}
std
::
vector
<
std
::
string
>
OperatorBase
::
Outputs
(
const
std
::
string
&
name
)
const
{
auto
output_format
=
GetAttr
<
std
::
vector
<
int
>>
(
"output_format"
);
auto
offset
=
in_out_idxs_
.
at
(
name
);
return
std
::
vector
<
std
::
string
>
{
outputs_
.
begin
()
+
output_format
.
at
(
offset
),
outputs_
.
begin
()
+
output_format
.
at
(
offset
+
1
)};
}
std
::
string
OperatorBase
::
DebugString
()
const
{
std
::
string
OperatorBase
::
DebugString
()
const
{
std
::
stringstream
ss
;
std
::
stringstream
ss
;
ss
<<
"=================
\n
"
;
ss
<<
"=================
\n
"
;
...
...
paddle/framework/operator.h
浏览文件 @
87189665
...
@@ -18,8 +18,10 @@ limitations under the License. */
...
@@ -18,8 +18,10 @@ limitations under the License. */
#include <string>
#include <string>
#include <unordered_map>
#include <unordered_map>
#include <vector>
#include <vector>
#include "paddle/framework/attr_checker.h"
#include "paddle/framework/attr_checker.h"
#include "paddle/framework/op_desc.pb.h"
#include "paddle/framework/op_desc.pb.h"
#include "paddle/framework/op_proto.pb.h"
#include "paddle/framework/scope.h"
#include "paddle/framework/scope.h"
#include "paddle/framework/tensor.h"
#include "paddle/framework/tensor.h"
#include "paddle/platform/device_context.h"
#include "paddle/platform/device_context.h"
...
@@ -77,11 +79,79 @@ class OperatorBase {
...
@@ -77,11 +79,79 @@ class OperatorBase {
virtual
void
Run
(
const
ScopePtr
&
scope
,
virtual
void
Run
(
const
ScopePtr
&
scope
,
const
platform
::
DeviceContext
&
dev_ctx
)
const
=
0
;
const
platform
::
DeviceContext
&
dev_ctx
)
const
=
0
;
// Get a input with argument's name described in `op_proto`
const
std
::
string
&
Input
(
const
std
::
string
&
name
)
const
;
// Get a input which has multiple variables.
// TODO add a vector_view to prevent memory copy.
std
::
vector
<
std
::
string
>
Inputs
(
const
std
::
string
&
name
)
const
;
// Get a output with argument's name described in `op_proto`
const
std
::
string
&
Output
(
const
std
::
string
&
name
)
const
;
// Get an output which has multiple variables.
// TODO add a vector_view to prevent memory copy.
std
::
vector
<
std
::
string
>
Outputs
(
const
std
::
string
&
name
)
const
;
// init in_out_idxs_ to accelerate argument's offset lookup.
void
CreateInOutOffsetMap
(
const
OpProto
&
proto
);
public:
public:
std
::
string
type_
;
std
::
string
type_
;
std
::
vector
<
std
::
string
>
inputs_
;
std
::
vector
<
std
::
string
>
inputs_
;
std
::
vector
<
std
::
string
>
outputs_
;
std
::
vector
<
std
::
string
>
outputs_
;
AttributeMap
attrs_
;
AttributeMap
attrs_
;
// store the arguments' offset described in op_desc.
std
::
unordered_map
<
std
::
string
,
int
>
in_out_idxs_
;
};
class
KernelContext
{
public:
KernelContext
(
const
OperatorBase
*
op
,
const
std
::
shared_ptr
<
Scope
>&
scope
,
const
platform
::
DeviceContext
&
device_context
)
:
op_
(
*
op
),
scope_
(
scope
),
device_context_
(
device_context
)
{}
const
Variable
*
Input
(
int
index
)
const
{
return
scope_
->
GetVariable
(
op_
.
inputs_
[
index
]);
}
Variable
*
Output
(
int
index
)
const
{
return
scope_
->
GetVariable
(
op_
.
outputs_
[
index
]);
}
const
Variable
*
Input
(
const
std
::
string
&
name
)
const
{
return
scope_
->
GetVariable
(
op_
.
Input
(
name
));
}
const
Variable
*
Output
(
const
std
::
string
&
name
)
const
{
return
scope_
->
GetVariable
(
op_
.
Output
(
name
));
}
const
std
::
vector
<
const
Variable
*>
Inputs
(
const
std
::
string
&
name
)
const
{
auto
names
=
op_
.
Inputs
(
name
);
std
::
vector
<
const
Variable
*>
res
;
std
::
transform
(
names
.
begin
(),
names
.
end
(),
res
.
begin
(),
[
this
](
const
std
::
string
&
name
)
{
return
scope_
->
GetVariable
(
name
);
});
return
res
;
}
const
std
::
vector
<
const
Variable
*>
Outputs
(
const
std
::
string
&
name
)
const
{
auto
names
=
op_
.
Outputs
(
name
);
std
::
vector
<
const
Variable
*>
res
;
std
::
transform
(
names
.
begin
(),
names
.
end
(),
res
.
begin
(),
[
this
](
const
std
::
string
&
name
)
{
return
scope_
->
GetVariable
(
name
);
});
return
res
;
}
template
<
typename
PlaceType
,
typename
DeviceType
=
typename
EigenDeviceConverter
<
PlaceType
>
::
EigenDeviceType
>
DeviceType
*
GetEigenDevice
()
const
;
platform
::
Place
GetPlace
()
const
{
return
device_context_
.
GetPlace
();
}
const
OperatorBase
&
op_
;
const
std
::
shared_ptr
<
Scope
>&
scope_
;
const
platform
::
DeviceContext
&
device_context_
;
};
};
class
OpKernel
{
class
OpKernel
{
...
@@ -92,31 +162,6 @@ class OpKernel {
...
@@ -92,31 +162,6 @@ class OpKernel {
* device resource such as CUDA stream, cublas handle, etc. from
* device resource such as CUDA stream, cublas handle, etc. from
* KernelContext. User should construct it before run the Operator.
* KernelContext. User should construct it before run the Operator.
*/
*/
class
KernelContext
{
public:
KernelContext
(
const
OperatorBase
*
op
,
const
ScopePtr
&
scope
,
const
platform
::
DeviceContext
&
device_context
)
:
op_
(
*
op
),
scope_
(
scope
),
device_context_
(
device_context
)
{}
const
Variable
*
Input
(
int
index
)
const
{
return
scope_
->
GetVariable
(
op_
.
inputs_
[
index
]);
}
Variable
*
Output
(
int
index
)
const
{
return
scope_
->
GetVariable
(
op_
.
outputs_
[
index
]);
}
template
<
typename
PlaceType
,
typename
DeviceType
=
typename
EigenDeviceConverter
<
PlaceType
>
::
EigenDeviceType
>
DeviceType
*
GetEigenDevice
()
const
;
platform
::
Place
GetPlace
()
const
{
return
device_context_
.
GetPlace
();
}
const
OperatorBase
&
op_
;
const
ScopePtr
&
scope_
;
const
platform
::
DeviceContext
&
device_context_
;
};
virtual
void
Compute
(
const
KernelContext
&
context
)
const
=
0
;
virtual
void
Compute
(
const
KernelContext
&
context
)
const
=
0
;
...
@@ -162,7 +207,7 @@ class OperatorWithKernel : public OperatorBase {
...
@@ -162,7 +207,7 @@ class OperatorWithKernel : public OperatorBase {
void
Run
(
const
ScopePtr
&
scope
,
void
Run
(
const
ScopePtr
&
scope
,
const
platform
::
DeviceContext
&
dev_ctx
)
const
final
{
const
platform
::
DeviceContext
&
dev_ctx
)
const
final
{
auto
&
opKernel
=
AllOpKernels
().
at
(
type_
).
at
(
OpKernelKey
(
dev_ctx
));
auto
&
opKernel
=
AllOpKernels
().
at
(
type_
).
at
(
OpKernelKey
(
dev_ctx
));
opKernel
->
Compute
(
OpKernel
::
KernelContext
(
this
,
scope
,
dev_ctx
));
opKernel
->
Compute
(
KernelContext
(
this
,
scope
,
dev_ctx
));
}
}
static
std
::
unordered_map
<
std
::
string
/* op_type */
,
OpKernelMap
>&
static
std
::
unordered_map
<
std
::
string
/* op_type */
,
OpKernelMap
>&
...
@@ -170,6 +215,7 @@ class OperatorWithKernel : public OperatorBase {
...
@@ -170,6 +215,7 @@ class OperatorWithKernel : public OperatorBase {
static
std
::
unordered_map
<
std
::
string
,
OpKernelMap
>
g_all_op_kernels
;
static
std
::
unordered_map
<
std
::
string
,
OpKernelMap
>
g_all_op_kernels
;
return
g_all_op_kernels
;
return
g_all_op_kernels
;
}
}
void
InferShape
(
const
std
::
shared_ptr
<
Scope
>&
scope
)
const
final
{
void
InferShape
(
const
std
::
shared_ptr
<
Scope
>&
scope
)
const
final
{
std
::
vector
<
const
Tensor
*>
ins
;
std
::
vector
<
const
Tensor
*>
ins
;
VarNamesToTensors
(
scope
,
inputs_
,
&
ins
);
VarNamesToTensors
(
scope
,
inputs_
,
&
ins
);
...
...
paddle/framework/operator_test.cc
浏览文件 @
87189665
...
@@ -30,7 +30,6 @@ class OpWithoutKernelTest : public OperatorBase {
...
@@ -30,7 +30,6 @@ class OpWithoutKernelTest : public OperatorBase {
op_run_num
++
;
op_run_num
++
;
ASSERT_EQ
((
int
)
inputs_
.
size
(),
1
);
ASSERT_EQ
((
int
)
inputs_
.
size
(),
1
);
ASSERT_EQ
((
int
)
outputs_
.
size
(),
1
);
ASSERT_EQ
((
int
)
outputs_
.
size
(),
1
);
ASSERT_NEAR
(
GetAttr
<
float
>
(
"scale"
),
3.14
,
1e-5
);
ASSERT_EQ
(
scope
->
GetVariable
(
inputs_
[
0
]),
nullptr
);
ASSERT_EQ
(
scope
->
GetVariable
(
inputs_
[
0
]),
nullptr
);
ASSERT_EQ
(
x
,
1
);
ASSERT_EQ
(
x
,
1
);
ASSERT_NE
(
scope
->
GetVariable
(
outputs_
[
0
]),
nullptr
);
ASSERT_NE
(
scope
->
GetVariable
(
outputs_
[
0
]),
nullptr
);
...
@@ -86,9 +85,11 @@ class OpKernelTestProtoAndCheckerMaker : public OpProtoAndCheckerMaker {
...
@@ -86,9 +85,11 @@ class OpKernelTestProtoAndCheckerMaker : public OpProtoAndCheckerMaker {
public:
public:
OpKernelTestProtoAndCheckerMaker
(
OpProto
*
proto
,
OpAttrChecker
*
op_checker
)
OpKernelTestProtoAndCheckerMaker
(
OpProto
*
proto
,
OpAttrChecker
*
op_checker
)
:
OpProtoAndCheckerMaker
(
proto
,
op_checker
)
{
:
OpProtoAndCheckerMaker
(
proto
,
op_checker
)
{
AddInput
(
"input"
,
"input of test op"
);
AddInput
(
"x"
,
"input of test op"
);
AddOutput
(
"output"
,
"output of test op"
);
AddOutput
(
"y"
,
"output of test op"
);
AddAttr
<
float
>
(
"scale"
,
"scale of cosine op"
);
AddAttr
<
float
>
(
"scale"
,
"scale of cosine op"
)
.
SetDefault
(
1.0
)
.
LargerThan
(
0.0
);
AddComment
(
"This is test op"
);
AddComment
(
"This is test op"
);
}
}
};
};
...
@@ -103,11 +104,65 @@ class OpWithKernelTest : public OperatorWithKernel {
...
@@ -103,11 +104,65 @@ class OpWithKernelTest : public OperatorWithKernel {
class
CPUKernelTest
:
public
OpKernel
{
class
CPUKernelTest
:
public
OpKernel
{
public:
public:
void
Compute
(
const
KernelContext
&
context
)
const
{
void
Compute
(
const
KernelContext
&
ctx
)
const
{
std
::
cout
<<
"this is cpu kernel"
<<
std
::
endl
;
std
::
cout
<<
ctx
.
op_
.
DebugString
()
<<
std
::
endl
;
cpu_kernel_run_num
++
;
cpu_kernel_run_num
++
;
ASSERT_EQ
((
int
)
context
.
op_
.
inputs_
.
size
(),
1
);
ASSERT_EQ
(
ctx
.
op_
.
Input
(
"x"
),
"IN1"
);
ASSERT_EQ
((
int
)
context
.
op_
.
outputs_
.
size
(),
1
);
ASSERT_EQ
(
ctx
.
op_
.
Output
(
"y"
),
"OUT1"
);
ASSERT_NEAR
(
context
.
op_
.
GetAttr
<
float
>
(
"scale"
),
3.14
,
1e-5
);
}
};
// multiple inputs test
class
OperatorMultiInputsTest
:
public
OperatorBase
{
public:
void
Init
()
override
{
x
=
1
;
}
void
InferShape
(
const
std
::
shared_ptr
<
Scope
>&
scope
)
const
override
{}
void
Run
(
const
std
::
shared_ptr
<
Scope
>&
scope
,
const
platform
::
DeviceContext
&
dev_ctx
)
const
override
{
ASSERT_EQ
(
scope
->
GetVariable
(
inputs_
[
0
]),
nullptr
);
ASSERT_EQ
(
x
,
1
);
ASSERT_NE
(
scope
->
GetVariable
(
outputs_
[
0
]),
nullptr
);
ASSERT_EQ
(
Input
(
"x"
),
"IN1"
);
ASSERT_EQ
(
Input
(
"y"
),
"OUT1"
);
}
public:
float
x
=
0
;
};
class
OpKernelTestMultiInputsProtoAndCheckerMaker
:
public
OpProtoAndCheckerMaker
{
public:
OpKernelTestMultiInputsProtoAndCheckerMaker
(
OpProto
*
proto
,
OpAttrChecker
*
op_checker
)
:
OpProtoAndCheckerMaker
(
proto
,
op_checker
)
{
AddInputs
(
"xs"
,
"inputs of test op"
);
AddInput
(
"k"
,
"input of test op"
);
AddOutputs
(
"ys"
,
"outputs of test op"
);
AddAttr
<
float
>
(
"scale"
,
"scale of cosine op"
)
.
SetDefault
(
1.0
)
.
LargerThan
(
0.0
);
AddComment
(
"This is test op"
);
}
};
class
CPUKernalMultiInputsTest
:
public
OpKernel
{
public:
void
Compute
(
const
KernelContext
&
ctx
)
const
{
auto
xs
=
ctx
.
op_
.
Inputs
(
"xs"
);
ASSERT_EQ
(
xs
.
size
(),
3UL
);
ASSERT_EQ
(
xs
[
0
],
"x0"
);
ASSERT_EQ
(
xs
[
1
],
"x1"
);
ASSERT_EQ
(
xs
[
2
],
"x2"
);
auto
k
=
ctx
.
op_
.
Input
(
"k"
);
ASSERT_EQ
(
k
,
"k0"
);
auto
ys
=
ctx
.
op_
.
Outputs
(
"ys"
);
ASSERT_EQ
(
ys
.
size
(),
2UL
);
ASSERT_EQ
(
ys
[
0
],
"y0"
);
ASSERT_EQ
(
ys
[
1
],
"y1"
);
}
}
};
};
...
@@ -118,6 +173,7 @@ REGISTER_OP(op_with_kernel, paddle::framework::OpWithKernelTest,
...
@@ -118,6 +173,7 @@ REGISTER_OP(op_with_kernel, paddle::framework::OpWithKernelTest,
paddle
::
framework
::
OpKernelTestProtoAndCheckerMaker
);
paddle
::
framework
::
OpKernelTestProtoAndCheckerMaker
);
REGISTER_OP_CPU_KERNEL
(
op_with_kernel
,
paddle
::
framework
::
CPUKernelTest
);
REGISTER_OP_CPU_KERNEL
(
op_with_kernel
,
paddle
::
framework
::
CPUKernelTest
);
// test with single input
TEST
(
OpKernel
,
all
)
{
TEST
(
OpKernel
,
all
)
{
paddle
::
framework
::
OpDesc
op_desc
;
paddle
::
framework
::
OpDesc
op_desc
;
op_desc
.
set_type
(
"op_with_kernel"
);
op_desc
.
set_type
(
"op_with_kernel"
);
...
@@ -137,3 +193,47 @@ TEST(OpKernel, all) {
...
@@ -137,3 +193,47 @@ TEST(OpKernel, all) {
op
->
Run
(
scope
,
cpu_device_context
);
op
->
Run
(
scope
,
cpu_device_context
);
ASSERT_EQ
(
paddle
::
framework
::
cpu_kernel_run_num
,
1
);
ASSERT_EQ
(
paddle
::
framework
::
cpu_kernel_run_num
,
1
);
}
}
REGISTER_OP
(
op_multi_inputs_with_kernel
,
paddle
::
framework
::
OpWithKernelTest
,
paddle
::
framework
::
OpKernelTestMultiInputsProtoAndCheckerMaker
);
REGISTER_OP_CPU_KERNEL
(
op_multi_inputs_with_kernel
,
paddle
::
framework
::
CPUKernalMultiInputsTest
);
// test with multi inputs
TEST
(
OpKernel
,
multi_inputs
)
{
using
namespace
paddle
::
framework
;
OpDesc
op_desc
;
op_desc
.
set_type
(
"op_multi_inputs_with_kernel"
);
*
op_desc
.
mutable_inputs
()
->
Add
()
=
"x0"
;
*
op_desc
.
mutable_inputs
()
->
Add
()
=
"x1"
;
*
op_desc
.
mutable_inputs
()
->
Add
()
=
"x2"
;
*
op_desc
.
mutable_inputs
()
->
Add
()
=
"k0"
;
*
op_desc
.
mutable_outputs
()
->
Add
()
=
"y0"
;
*
op_desc
.
mutable_outputs
()
->
Add
()
=
"y1"
;
auto
attr
=
op_desc
.
mutable_attrs
()
->
Add
();
attr
->
set_name
(
"scale"
);
attr
->
set_type
(
paddle
::
framework
::
AttrType
::
FLOAT
);
attr
->
set_f
(
3.14
);
auto
attr0
=
op_desc
.
mutable_attrs
()
->
Add
();
attr0
->
set_name
(
"input_format"
);
attr0
->
set_type
(
paddle
::
framework
::
AttrType
::
INTS
);
auto
input_format
=
attr0
->
mutable_ints
();
input_format
->
Add
(
0
);
// x0
input_format
->
Add
(
3
);
// k
input_format
->
Add
(
4
);
// end
auto
attr1
=
op_desc
.
mutable_attrs
()
->
Add
();
attr1
->
set_name
(
"output_format"
);
attr1
->
set_type
(
paddle
::
framework
::
AttrType
::
INTS
);
auto
output_format
=
attr1
->
mutable_ints
();
output_format
->
Add
(
0
);
// y0
output_format
->
Add
(
2
);
// y1
paddle
::
platform
::
CPUDeviceContext
cpu_device_context
;
auto
scope
=
std
::
make_shared
<
Scope
>
();
OperatorPtr
op
(
paddle
::
framework
::
OpRegistry
::
CreateOp
(
op_desc
));
op
->
Run
(
scope
,
cpu_device_context
);
}
paddle/operators/add_op.h
浏览文件 @
87189665
...
@@ -22,7 +22,7 @@ namespace operators {
...
@@ -22,7 +22,7 @@ namespace operators {
template
<
typename
Place
,
typename
T
>
template
<
typename
Place
,
typename
T
>
class
AddKernel
:
public
framework
::
OpKernel
{
class
AddKernel
:
public
framework
::
OpKernel
{
public:
public:
void
Compute
(
const
KernelContext
&
context
)
const
override
{
void
Compute
(
const
framework
::
KernelContext
&
context
)
const
override
{
auto
input0
=
context
.
Input
(
0
)
->
Get
<
framework
::
Tensor
>
();
auto
input0
=
context
.
Input
(
0
)
->
Get
<
framework
::
Tensor
>
();
auto
input1
=
context
.
Input
(
1
)
->
Get
<
framework
::
Tensor
>
();
auto
input1
=
context
.
Input
(
1
)
->
Get
<
framework
::
Tensor
>
();
auto
*
output
=
context
.
Output
(
0
)
->
GetMutable
<
framework
::
Tensor
>
();
auto
*
output
=
context
.
Output
(
0
)
->
GetMutable
<
framework
::
Tensor
>
();
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录