Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
s920243400
PaddleDetection
提交
ffbb0be2
P
PaddleDetection
项目概览
s920243400
/
PaddleDetection
与 Fork 源项目一致
Fork自
PaddlePaddle / PaddleDetection
通知
2
Star
0
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleDetection
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
ffbb0be2
编写于
8月 14, 2017
作者:
Y
Yu Yang
提交者:
GitHub
8月 14, 2017
浏览文件
操作
浏览文件
下载
差异文件
Merge pull request #3444 from reyoung/use_ctor_create_op
Using constructor to create an operator.
上级
fa54fb33
daaa45b4
变更
24
隐藏空白更改
内联
并排
Showing
24 changed file
with
156 addition
and
113 deletion
+156
-113
paddle/framework/backward_test.cc
paddle/framework/backward_test.cc
+4
-3
paddle/framework/grad_op_builder.cc
paddle/framework/grad_op_builder.cc
+20
-14
paddle/framework/grad_op_builder_test.cc
paddle/framework/grad_op_builder_test.cc
+1
-1
paddle/framework/op_registry.h
paddle/framework/op_registry.h
+16
-28
paddle/framework/op_registry_test.cc
paddle/framework/op_registry_test.cc
+2
-2
paddle/framework/operator.cc
paddle/framework/operator.cc
+17
-0
paddle/framework/operator.h
paddle/framework/operator.h
+5
-18
paddle/framework/operator_test.cc
paddle/framework/operator_test.cc
+7
-5
paddle/operators/add_op.cc
paddle/operators/add_op.cc
+5
-2
paddle/operators/cross_entropy_op.cc
paddle/operators/cross_entropy_op.cc
+6
-3
paddle/operators/fill_zeros_like_op.cc
paddle/operators/fill_zeros_like_op.cc
+2
-1
paddle/operators/gaussian_random_op.cc
paddle/operators/gaussian_random_op.cc
+2
-1
paddle/operators/mean_op.cc
paddle/operators/mean_op.cc
+6
-2
paddle/operators/mul_op.cc
paddle/operators/mul_op.cc
+5
-2
paddle/operators/net_op.cc
paddle/operators/net_op.cc
+6
-0
paddle/operators/net_op.h
paddle/operators/net_op.h
+3
-1
paddle/operators/net_op_test.cc
paddle/operators/net_op_test.cc
+11
-11
paddle/operators/recurrent_op.cc
paddle/operators/recurrent_op.cc
+10
-4
paddle/operators/recurrent_op.h
paddle/operators/recurrent_op.h
+7
-8
paddle/operators/rowwise_add_op.cc
paddle/operators/rowwise_add_op.cc
+3
-1
paddle/operators/sgd_op.cc
paddle/operators/sgd_op.cc
+3
-1
paddle/operators/sigmoid_op.cc
paddle/operators/sigmoid_op.cc
+6
-2
paddle/operators/softmax_op.cc
paddle/operators/softmax_op.cc
+6
-2
paddle/operators/uniform_random_op.cc
paddle/operators/uniform_random_op.cc
+3
-1
未找到文件。
paddle/framework/backward_test.cc
浏览文件 @
ffbb0be2
...
@@ -30,7 +30,7 @@ using DeviceContext = platform::DeviceContext;
...
@@ -30,7 +30,7 @@ using DeviceContext = platform::DeviceContext;
class
EmptyOp
:
public
OperatorBase
{
class
EmptyOp
:
public
OperatorBase
{
public:
public:
DEFINE_OPERATOR_CTOR
(
EmptyOp
,
OperatorBase
)
;
using
OperatorBase
::
OperatorBase
;
void
InferShape
(
const
Scope
&
scope
)
const
override
{}
void
InferShape
(
const
Scope
&
scope
)
const
override
{}
void
Run
(
const
Scope
&
scope
,
const
DeviceContext
&
dev_ctx
)
const
override
{}
void
Run
(
const
Scope
&
scope
,
const
DeviceContext
&
dev_ctx
)
const
override
{}
};
};
...
@@ -79,8 +79,9 @@ class NoGradOpMaker : public OpProtoAndCheckerMaker {
...
@@ -79,8 +79,9 @@ class NoGradOpMaker : public OpProtoAndCheckerMaker {
class
FcOp
:
public
operators
::
NetOp
{
class
FcOp
:
public
operators
::
NetOp
{
public:
public:
DEFINE_OPERATOR_CTOR
(
FcOp
,
operators
::
NetOp
)
FcOp
(
const
std
::
string
&
type
,
const
VarNameMap
&
inputs
,
void
Init
()
override
{
const
VarNameMap
&
outputs
,
const
AttributeMap
&
attrs
)
:
NetOp
(
type
,
inputs
,
outputs
,
attrs
)
{
AddOp
(
OpRegistry
::
CreateOp
(
"mul"
,
AddOp
(
OpRegistry
::
CreateOp
(
"mul"
,
{{
"X"
,
{
Input
(
"X"
)}},
{
"Y"
,
{
Input
(
"W"
)}}},
{{
"X"
,
{
Input
(
"X"
)}},
{
"Y"
,
{
Input
(
"W"
)}}},
{{
"Out"
,
{
Output
(
"mul_result"
)}}},
{}));
{{
"Out"
,
{
Output
(
"mul_result"
)}}},
{}));
...
...
paddle/framework/grad_op_builder.cc
浏览文件 @
ffbb0be2
...
@@ -20,13 +20,12 @@ namespace paddle {
...
@@ -20,13 +20,12 @@ namespace paddle {
namespace
framework
{
namespace
framework
{
enum
class
OpArgType
{
IN
,
OUT
};
enum
class
OpArgType
{
IN
,
OUT
};
static
void
TransOpArg
(
const
OperatorBase
*
src_op
,
OperatorBase
*
dst_op
,
static
void
TransOpArg
(
const
OperatorBase
*
src_op
,
const
OpArgType
&
src_type
,
const
OpArgType
&
dst_type
,
OperatorBase
::
VarNameMap
*
vars
,
bool
is_grad
)
{
const
OpArgType
&
src_type
,
bool
is_grad
)
{
const
auto
&
src_inout
=
const
auto
&
src_inout
=
src_type
==
OpArgType
::
IN
?
src_op
->
inputs_
:
src_op
->
outputs_
;
src_type
==
OpArgType
::
IN
?
src_op
->
inputs_
:
src_op
->
outputs_
;
auto
&
dst_inout
=
auto
&
dst_inout
=
*
vars
;
dst_type
==
OpArgType
::
IN
?
dst_op
->
inputs_
:
dst_op
->
outputs_
;
const
OpProto
&
proto
=
OpProtos
().
at
(
src_op
->
type_
);
const
OpProto
&
proto
=
OpProtos
().
at
(
src_op
->
type_
);
const
auto
&
src_arg_list
=
const
auto
&
src_arg_list
=
...
@@ -44,15 +43,22 @@ static void TransOpArg(const OperatorBase* src_op, OperatorBase* dst_op,
...
@@ -44,15 +43,22 @@ static void TransOpArg(const OperatorBase* src_op, OperatorBase* dst_op,
}
}
OperatorBase
*
BuildGradOp
(
const
OperatorBase
*
op
)
{
OperatorBase
*
BuildGradOp
(
const
OperatorBase
*
op
)
{
std
::
string
grad_op_type
=
OpRegistry
::
grad_ops
().
at
(
op
->
type_
);
auto
gop_type_it
=
OpRegistry
::
grad_ops
().
find
(
op
->
type_
);
OperatorBase
*
grad_op
=
OpRegistry
::
op_creators
().
at
(
grad_op_type
)();
PADDLE_ENFORCE
(
gop_type_it
!=
OpRegistry
::
grad_ops
().
end
(),
grad_op
->
type_
=
grad_op_type
;
"Operator %s do not register gradient type"
,
op
->
type_
);
grad_op
->
attrs_
=
op
->
attrs_
;
auto
&
grad_op_type
=
gop_type_it
->
second
;
TransOpArg
(
op
,
grad_op
,
OpArgType
::
IN
,
OpArgType
::
IN
,
false
);
// I
OperatorBase
::
VarNameMap
inputs
;
TransOpArg
(
op
,
grad_op
,
OpArgType
::
OUT
,
OpArgType
::
IN
,
false
);
// O
OperatorBase
::
VarNameMap
outputs
;
TransOpArg
(
op
,
grad_op
,
OpArgType
::
OUT
,
OpArgType
::
IN
,
true
);
// OG
TransOpArg
(
op
,
&
inputs
,
OpArgType
::
IN
,
false
);
// I
TransOpArg
(
op
,
grad_op
,
OpArgType
::
IN
,
OpArgType
::
OUT
,
true
);
// IG
TransOpArg
(
op
,
&
inputs
,
OpArgType
::
OUT
,
false
);
// O
return
grad_op
;
TransOpArg
(
op
,
&
inputs
,
OpArgType
::
OUT
,
true
);
// OG
TransOpArg
(
op
,
&
outputs
,
OpArgType
::
IN
,
true
);
// IG
auto
gop_it
=
OpRegistry
::
op_creators
().
find
(
grad_op_type
);
PADDLE_ENFORCE
(
gop_it
!=
OpRegistry
::
op_creators
().
end
(),
"Operator %s 's Gradient %s's creator cannot be found"
,
op
->
type_
,
grad_op_type
);
return
gop_it
->
second
(
grad_op_type
,
inputs
,
outputs
,
op
->
attrs_
);
}
}
}
// namespace framework
}
// namespace framework
...
...
paddle/framework/grad_op_builder_test.cc
浏览文件 @
ffbb0be2
...
@@ -10,7 +10,7 @@ namespace framework {
...
@@ -10,7 +10,7 @@ namespace framework {
class
NOP
:
public
OperatorBase
{
class
NOP
:
public
OperatorBase
{
public:
public:
DEFINE_OPERATOR_CTOR
(
NOP
,
OperatorBase
)
;
using
OperatorBase
::
OperatorBase
;
void
InferShape
(
const
Scope
&
scope
)
const
override
{}
void
InferShape
(
const
Scope
&
scope
)
const
override
{}
void
Run
(
const
Scope
&
scope
,
void
Run
(
const
Scope
&
scope
,
const
platform
::
DeviceContext
&
dev_ctx
)
const
override
{}
const
platform
::
DeviceContext
&
dev_ctx
)
const
override
{}
...
...
paddle/framework/op_registry.h
浏览文件 @
ffbb0be2
...
@@ -120,13 +120,19 @@ class OpProtoAndCheckerMaker {
...
@@ -120,13 +120,19 @@ class OpProtoAndCheckerMaker {
};
};
class
OpRegistry
{
class
OpRegistry
{
using
OpCreator
=
std
::
function
<
OperatorBase
*
()
>
;
using
VarNameMap
=
OperatorBase
::
VarNameMap
;
using
VarNameMap
=
OperatorBase
::
VarNameMap
;
using
OpCreator
=
std
::
function
<
OperatorBase
*
(
const
std
::
string
&
/*type*/
,
const
VarNameMap
&
/*inputs*/
,
const
VarNameMap
&
/*outputs*/
,
const
AttributeMap
&
/*attrs*/
)
>
;
public:
public:
template
<
typename
OpType
,
typename
ProtoMakerType
>
template
<
typename
OpType
,
typename
ProtoMakerType
>
static
void
RegisterOp
(
const
std
::
string
&
op_type
)
{
static
void
RegisterOp
(
const
std
::
string
&
op_type
)
{
op_creators
()[
op_type
]
=
[]
{
return
new
OpType
;
};
op_creators
()[
op_type
]
=
[](
const
std
::
string
&
type
,
const
VarNameMap
&
inputs
,
const
VarNameMap
&
outputs
,
const
AttributeMap
&
attrs
)
{
return
new
OpType
(
type
,
inputs
,
outputs
,
attrs
);
};
OpAttrChecker
&
op_checker
=
op_checkers
()[
op_type
];
OpAttrChecker
&
op_checker
=
op_checkers
()[
op_type
];
OpProto
&
op_proto
=
OpProtos
()[
op_type
];
OpProto
&
op_proto
=
OpProtos
()[
op_type
];
auto
maker
=
ProtoMakerType
(
&
op_proto
,
&
op_checker
);
auto
maker
=
ProtoMakerType
(
&
op_proto
,
&
op_checker
);
...
@@ -141,29 +147,25 @@ class OpRegistry {
...
@@ -141,29 +147,25 @@ class OpRegistry {
template
<
typename
GradOpType
>
template
<
typename
GradOpType
>
static
void
RegisterGradOp
(
const
std
::
string
&
op_type
,
static
void
RegisterGradOp
(
const
std
::
string
&
op_type
,
const
std
::
string
&
grad_op_type
)
{
const
std
::
string
&
grad_op_type
)
{
op_creators
()[
grad_op_type
]
=
[]
{
return
new
GradOpType
;
};
op_creators
()[
grad_op_type
]
=
[](
const
std
::
string
&
type
,
const
VarNameMap
&
inputs
,
const
VarNameMap
&
outputs
,
const
AttributeMap
&
attrs
)
{
return
new
GradOpType
(
type
,
inputs
,
outputs
,
attrs
);
};
grad_ops
()[
op_type
]
=
grad_op_type
;
grad_ops
()[
op_type
]
=
grad_op_type
;
}
}
static
std
::
shared_ptr
<
OperatorBase
>
CreateOp
(
const
std
::
string
&
type
,
static
std
::
shared_ptr
<
OperatorBase
>
CreateOp
(
const
std
::
string
&
type
,
const
VarNameMap
&
inputs
,
const
VarNameMap
&
inputs
,
const
VarNameMap
&
outputs
,
const
VarNameMap
&
outputs
,
const
AttributeMap
&
attrs
)
{
AttributeMap
attrs
)
{
auto
op_create_it
=
op_creators
().
find
(
type
);
auto
op_create_it
=
op_creators
().
find
(
type
);
PADDLE_ENFORCE
(
op_create_it
!=
op_creators
().
end
(),
PADDLE_ENFORCE
(
op_create_it
!=
op_creators
().
end
(),
"Operator %s cannot be found."
,
type
);
"Operator %s cannot be found."
,
type
);
op_checkers
().
at
(
type
).
Check
(
attrs
);
auto
op
=
op_create_it
->
second
();
auto
op
=
op_create_it
->
second
(
type
,
inputs
,
outputs
,
attrs
);
op
->
type_
=
type
;
op
->
inputs_
=
inputs
;
op
->
outputs_
=
outputs
;
op
->
attrs_
=
attrs
;
op_checkers
().
at
(
type
).
Check
(
op
->
attrs_
);
GenerateTempVariableName
(
op
);
op
->
Init
();
return
std
::
shared_ptr
<
OperatorBase
>
(
op
);
return
std
::
shared_ptr
<
OperatorBase
>
(
op
);
}
}
...
@@ -195,7 +197,6 @@ class OpRegistry {
...
@@ -195,7 +197,6 @@ class OpRegistry {
PADDLE_ENFORCE
(
!
op
.
IsNetOp
(),
PADDLE_ENFORCE
(
!
op
.
IsNetOp
(),
"Use framework::Backward to get backward ops"
);
"Use framework::Backward to get backward ops"
);
std
::
shared_ptr
<
OperatorBase
>
grad_op
(
BuildGradOp
(
&
op
));
std
::
shared_ptr
<
OperatorBase
>
grad_op
(
BuildGradOp
(
&
op
));
grad_op
->
Init
();
return
grad_op
;
return
grad_op
;
}
}
...
@@ -214,19 +215,6 @@ class OpRegistry {
...
@@ -214,19 +215,6 @@ class OpRegistry {
static
std
::
unordered_map
<
std
::
string
,
OpAttrChecker
>
op_checkers_
;
static
std
::
unordered_map
<
std
::
string
,
OpAttrChecker
>
op_checkers_
;
return
op_checkers_
;
return
op_checkers_
;
}
}
static
void
GenerateTempVariableName
(
OperatorBase
*
op
)
{
static
std
::
atomic
<
size_t
>
gUniqId
(
0UL
);
for
(
auto
&
output
:
op
->
outputs_
)
{
for
(
auto
&
output_name
:
output
.
second
)
{
if
(
output_name
==
kTempVarName
)
{
output_name
+=
op
->
type_
;
output_name
+=
"@"
;
output_name
+=
std
::
to_string
(
gUniqId
.
fetch_add
(
1
));
}
}
}
}
};
};
class
Registrar
{
class
Registrar
{
...
...
paddle/framework/op_registry_test.cc
浏览文件 @
ffbb0be2
...
@@ -7,7 +7,7 @@ namespace paddle {
...
@@ -7,7 +7,7 @@ namespace paddle {
namespace
framework
{
namespace
framework
{
class
CosineOp
:
public
OperatorBase
{
class
CosineOp
:
public
OperatorBase
{
public:
public:
DEFINE_OPERATOR_CTOR
(
CosineOp
,
OperatorBase
)
;
using
OperatorBase
::
OperatorBase
;
void
Run
(
const
Scope
&
scope
,
void
Run
(
const
Scope
&
scope
,
const
platform
::
DeviceContext
&
dev_ctx
)
const
override
{}
const
platform
::
DeviceContext
&
dev_ctx
)
const
override
{}
void
InferShape
(
const
Scope
&
scope
)
const
override
{}
void
InferShape
(
const
Scope
&
scope
)
const
override
{}
...
@@ -28,7 +28,7 @@ class CosineOpProtoAndCheckerMaker : public OpProtoAndCheckerMaker {
...
@@ -28,7 +28,7 @@ class CosineOpProtoAndCheckerMaker : public OpProtoAndCheckerMaker {
class
MyTestOp
:
public
OperatorBase
{
class
MyTestOp
:
public
OperatorBase
{
public:
public:
DEFINE_OPERATOR_CTOR
(
MyTestOp
,
OperatorBase
)
;
using
OperatorBase
::
OperatorBase
;
void
InferShape
(
const
Scope
&
scope
)
const
override
{}
void
InferShape
(
const
Scope
&
scope
)
const
override
{}
void
Run
(
const
Scope
&
scope
,
void
Run
(
const
Scope
&
scope
,
const
platform
::
DeviceContext
&
dev_ctx
)
const
override
{}
const
platform
::
DeviceContext
&
dev_ctx
)
const
override
{}
...
...
paddle/framework/operator.cc
浏览文件 @
ffbb0be2
...
@@ -122,6 +122,23 @@ void OperatorBase::Rename(const std::string& old_name,
...
@@ -122,6 +122,23 @@ void OperatorBase::Rename(const std::string& old_name,
}
}
}
}
OperatorBase
::
OperatorBase
(
const
std
::
string
&
type
,
const
OperatorBase
::
VarNameMap
&
inputs
,
const
OperatorBase
::
VarNameMap
&
outputs
,
const
AttributeMap
&
attrs
)
:
type_
(
type
),
inputs_
(
inputs
),
outputs_
(
outputs
),
attrs_
(
attrs
)
{
static
std
::
atomic
<
size_t
>
gUniqId
(
0UL
);
for
(
auto
&
output
:
outputs_
)
{
for
(
auto
&
output_name
:
output
.
second
)
{
if
(
output_name
==
kTempVarName
)
{
output_name
+=
type_
;
output_name
+=
"@"
;
output_name
+=
std
::
to_string
(
gUniqId
.
fetch_add
(
1
));
}
}
}
}
std
::
vector
<
std
::
string
>
OperatorBase
::
OutputVars
(
bool
has_intermediate
)
const
{
std
::
vector
<
std
::
string
>
OperatorBase
::
OutputVars
(
bool
has_intermediate
)
const
{
std
::
vector
<
std
::
string
>
ret_val
;
std
::
vector
<
std
::
string
>
ret_val
;
if
(
has_intermediate
)
{
if
(
has_intermediate
)
{
...
...
paddle/framework/operator.h
浏览文件 @
ffbb0be2
...
@@ -66,10 +66,8 @@ class OperatorBase {
...
@@ -66,10 +66,8 @@ class OperatorBase {
public:
public:
using
VarNameMap
=
std
::
map
<
std
::
string
,
std
::
vector
<
std
::
string
>>
;
using
VarNameMap
=
std
::
map
<
std
::
string
,
std
::
vector
<
std
::
string
>>
;
OperatorBase
()
=
default
;
OperatorBase
(
const
std
::
string
&
type
,
const
VarNameMap
&
inputs
,
OperatorBase
(
const
std
::
string
&
type
,
const
VarNameMap
&
inputs
,
const
VarNameMap
&
outputs
,
const
AttributeMap
&
attrs
)
const
VarNameMap
&
outputs
,
const
AttributeMap
&
attrs
);
:
type_
(
type
),
inputs_
(
inputs
),
outputs_
(
outputs
),
attrs_
(
attrs
)
{}
OperatorBase
(
const
OperatorBase
&
o
)
=
delete
;
OperatorBase
(
const
OperatorBase
&
o
)
=
delete
;
OperatorBase
&
operator
=
(
const
OperatorBase
&
o
)
=
delete
;
OperatorBase
&
operator
=
(
const
OperatorBase
&
o
)
=
delete
;
...
@@ -86,10 +84,6 @@ class OperatorBase {
...
@@ -86,10 +84,6 @@ class OperatorBase {
virtual
std
::
string
DebugString
()
const
;
virtual
std
::
string
DebugString
()
const
;
/// Init will be called after CreateOperator, you can put some initialization
/// logic here.
virtual
void
Init
()
{}
/// InferShape infer the size of Variables used by this Operator with
/// InferShape infer the size of Variables used by this Operator with
/// information inside scope
/// information inside scope
virtual
void
InferShape
(
const
Scope
&
scope
)
const
=
0
;
virtual
void
InferShape
(
const
Scope
&
scope
)
const
=
0
;
...
@@ -135,15 +129,6 @@ class OperatorBase {
...
@@ -135,15 +129,6 @@ class OperatorBase {
AttributeMap
attrs_
;
AttributeMap
attrs_
;
};
};
#define DEFINE_OPERATOR_CTOR(Class, ParentClass) \
public: \
Class() : ParentClass() {
/* TODO(yi): This constructor is to be removed. */
\
} \
Class(const std::string& type, const VarNameMap& inputs, \
const VarNameMap& outputs, \
const paddle::framework::AttributeMap& attrs) \
: ParentClass(type, inputs, outputs, attrs) {}
class
InferShapeContext
{
class
InferShapeContext
{
public:
public:
InferShapeContext
(
const
OperatorBase
&
op
,
const
Scope
&
scope
)
InferShapeContext
(
const
OperatorBase
&
op
,
const
Scope
&
scope
)
...
@@ -287,8 +272,6 @@ class OpKernel {
...
@@ -287,8 +272,6 @@ class OpKernel {
class
OperatorWithKernel
:
public
OperatorBase
{
class
OperatorWithKernel
:
public
OperatorBase
{
public:
public:
DEFINE_OPERATOR_CTOR
(
OperatorWithKernel
,
OperatorBase
)
struct
OpKernelKey
{
struct
OpKernelKey
{
platform
::
Place
place_
;
platform
::
Place
place_
;
...
@@ -312,6 +295,10 @@ class OperatorWithKernel : public OperatorBase {
...
@@ -312,6 +295,10 @@ class OperatorWithKernel : public OperatorBase {
using
OpKernelMap
=
using
OpKernelMap
=
std
::
unordered_map
<
OpKernelKey
,
std
::
unique_ptr
<
OpKernel
>
,
OpKernelHash
>
;
std
::
unordered_map
<
OpKernelKey
,
std
::
unique_ptr
<
OpKernel
>
,
OpKernelHash
>
;
OperatorWithKernel
(
const
std
::
string
&
type
,
const
VarNameMap
&
inputs
,
const
VarNameMap
&
outputs
,
const
AttributeMap
&
attrs
)
:
OperatorBase
(
type
,
inputs
,
outputs
,
attrs
)
{}
void
InferShape
(
const
Scope
&
scope
)
const
override
{
void
InferShape
(
const
Scope
&
scope
)
const
override
{
InferShape
(
InferShapeContext
(
*
this
,
scope
));
InferShape
(
InferShapeContext
(
*
this
,
scope
));
}
}
...
...
paddle/framework/operator_test.cc
浏览文件 @
ffbb0be2
...
@@ -22,10 +22,10 @@ namespace framework {
...
@@ -22,10 +22,10 @@ namespace framework {
static
int
op_run_num
=
0
;
static
int
op_run_num
=
0
;
class
OpWithoutKernelTest
:
public
OperatorBase
{
class
OpWithoutKernelTest
:
public
OperatorBase
{
DEFINE_OPERATOR_CTOR
(
OpWithoutKernelTest
,
framework
::
OperatorBase
)
public:
public:
void
Init
()
override
{
x
=
1
;
}
OpWithoutKernelTest
(
const
std
::
string
&
type
,
const
VarNameMap
&
inputs
,
const
VarNameMap
&
outputs
,
const
AttributeMap
&
attrs
)
:
OperatorBase
(
type
,
inputs
,
outputs
,
attrs
),
x
(
1
)
{}
void
InferShape
(
const
Scope
&
scope
)
const
override
{}
void
InferShape
(
const
Scope
&
scope
)
const
override
{}
void
Run
(
const
Scope
&
scope
,
void
Run
(
const
Scope
&
scope
,
const
platform
::
DeviceContext
&
dev_ctx
)
const
override
{
const
platform
::
DeviceContext
&
dev_ctx
)
const
override
{
...
@@ -38,7 +38,7 @@ class OpWithoutKernelTest : public OperatorBase {
...
@@ -38,7 +38,7 @@ class OpWithoutKernelTest : public OperatorBase {
}
}
public:
public:
float
x
=
0
;
int
x
{
0
}
;
};
};
class
OpeWithoutKernelTestProtoAndCheckerMaker
:
public
OpProtoAndCheckerMaker
{
class
OpeWithoutKernelTestProtoAndCheckerMaker
:
public
OpProtoAndCheckerMaker
{
...
@@ -109,7 +109,9 @@ class OpKernelTestProtoAndCheckerMaker : public OpProtoAndCheckerMaker {
...
@@ -109,7 +109,9 @@ class OpKernelTestProtoAndCheckerMaker : public OpProtoAndCheckerMaker {
static
int
cpu_kernel_run_num
=
0
;
static
int
cpu_kernel_run_num
=
0
;
class
OpWithKernelTest
:
public
OperatorWithKernel
{
class
OpWithKernelTest
:
public
OperatorWithKernel
{
DEFINE_OPERATOR_CTOR
(
OpWithKernelTest
,
framework
::
OperatorWithKernel
)
public:
using
OperatorWithKernel
::
OperatorWithKernel
;
protected:
protected:
void
InferShape
(
const
framework
::
InferShapeContext
&
ctx
)
const
override
{}
void
InferShape
(
const
framework
::
InferShapeContext
&
ctx
)
const
override
{}
};
};
...
...
paddle/operators/add_op.cc
浏览文件 @
ffbb0be2
...
@@ -18,7 +18,8 @@ namespace paddle {
...
@@ -18,7 +18,8 @@ namespace paddle {
namespace
operators
{
namespace
operators
{
class
AddOp
:
public
framework
::
OperatorWithKernel
{
class
AddOp
:
public
framework
::
OperatorWithKernel
{
DEFINE_OPERATOR_CTOR
(
AddOp
,
framework
::
OperatorWithKernel
)
public:
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
protected:
void
InferShape
(
const
framework
::
InferShapeContext
&
ctx
)
const
override
{
void
InferShape
(
const
framework
::
InferShapeContext
&
ctx
)
const
override
{
...
@@ -45,7 +46,9 @@ The equation is: Out = X + Y
...
@@ -45,7 +46,9 @@ The equation is: Out = X + Y
};
};
class
AddOpGrad
:
public
framework
::
OperatorWithKernel
{
class
AddOpGrad
:
public
framework
::
OperatorWithKernel
{
DEFINE_OPERATOR_CTOR
(
AddOpGrad
,
framework
::
OperatorWithKernel
)
public:
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
protected:
void
InferShape
(
const
framework
::
InferShapeContext
&
ctx
)
const
override
{}
void
InferShape
(
const
framework
::
InferShapeContext
&
ctx
)
const
override
{}
};
};
...
...
paddle/operators/cross_entropy_op.cc
浏览文件 @
ffbb0be2
...
@@ -18,7 +18,9 @@ namespace paddle {
...
@@ -18,7 +18,9 @@ namespace paddle {
namespace
operators
{
namespace
operators
{
class
OnehotCrossEntropyOp
:
public
framework
::
OperatorWithKernel
{
class
OnehotCrossEntropyOp
:
public
framework
::
OperatorWithKernel
{
DEFINE_OPERATOR_CTOR
(
OnehotCrossEntropyOp
,
framework
::
OperatorWithKernel
)
public:
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
protected:
void
InferShape
(
const
framework
::
InferShapeContext
&
ctx
)
const
override
{
void
InferShape
(
const
framework
::
InferShapeContext
&
ctx
)
const
override
{
auto
*
X
=
ctx
.
Input
<
Tensor
>
(
"X"
);
auto
*
X
=
ctx
.
Input
<
Tensor
>
(
"X"
);
...
@@ -32,8 +34,9 @@ class OnehotCrossEntropyOp : public framework::OperatorWithKernel {
...
@@ -32,8 +34,9 @@ class OnehotCrossEntropyOp : public framework::OperatorWithKernel {
};
};
class
OnehotCrossEntropyGradientOp
:
public
framework
::
OperatorWithKernel
{
class
OnehotCrossEntropyGradientOp
:
public
framework
::
OperatorWithKernel
{
DEFINE_OPERATOR_CTOR
(
OnehotCrossEntropyGradientOp
,
public:
framework
::
OperatorWithKernel
)
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
protected:
void
InferShape
(
const
framework
::
InferShapeContext
&
ctx
)
const
override
{
void
InferShape
(
const
framework
::
InferShapeContext
&
ctx
)
const
override
{
auto
X_grad
=
ctx
.
Output
<
Tensor
>
(
framework
::
GradVarName
(
"X"
));
auto
X_grad
=
ctx
.
Output
<
Tensor
>
(
framework
::
GradVarName
(
"X"
));
...
...
paddle/operators/fill_zeros_like_op.cc
浏览文件 @
ffbb0be2
...
@@ -18,7 +18,8 @@ namespace paddle {
...
@@ -18,7 +18,8 @@ namespace paddle {
namespace
operators
{
namespace
operators
{
class
FillZerosLikeOp
:
public
framework
::
OperatorWithKernel
{
class
FillZerosLikeOp
:
public
framework
::
OperatorWithKernel
{
DEFINE_OPERATOR_CTOR
(
FillZerosLikeOp
,
framework
::
OperatorWithKernel
);
public:
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
protected:
void
InferShape
(
const
framework
::
InferShapeContext
&
ctx
)
const
override
{
void
InferShape
(
const
framework
::
InferShapeContext
&
ctx
)
const
override
{
...
...
paddle/operators/gaussian_random_op.cc
浏览文件 @
ffbb0be2
...
@@ -43,7 +43,8 @@ class GaussianRandomKernel : public framework::OpKernel {
...
@@ -43,7 +43,8 @@ class GaussianRandomKernel : public framework::OpKernel {
};
};
class
GaussianRandomOp
:
public
framework
::
OperatorWithKernel
{
class
GaussianRandomOp
:
public
framework
::
OperatorWithKernel
{
DEFINE_OPERATOR_CTOR
(
GaussianRandomOp
,
framework
::
OperatorWithKernel
);
public:
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
protected:
void
InferShape
(
const
framework
::
InferShapeContext
&
context
)
const
override
{
void
InferShape
(
const
framework
::
InferShapeContext
&
context
)
const
override
{
...
...
paddle/operators/mean_op.cc
浏览文件 @
ffbb0be2
...
@@ -18,7 +18,9 @@ namespace paddle {
...
@@ -18,7 +18,9 @@ namespace paddle {
namespace
operators
{
namespace
operators
{
class
MeanOp
:
public
framework
::
OperatorWithKernel
{
class
MeanOp
:
public
framework
::
OperatorWithKernel
{
DEFINE_OPERATOR_CTOR
(
MeanOp
,
framework
::
OperatorWithKernel
)
public:
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
protected:
void
InferShape
(
const
framework
::
InferShapeContext
&
ctx
)
const
override
{
void
InferShape
(
const
framework
::
InferShapeContext
&
ctx
)
const
override
{
PADDLE_ENFORCE_NOT_NULL
(
ctx
.
InputVar
(
"X"
),
PADDLE_ENFORCE_NOT_NULL
(
ctx
.
InputVar
(
"X"
),
...
@@ -38,7 +40,9 @@ class MeanOpMaker : public framework::OpProtoAndCheckerMaker {
...
@@ -38,7 +40,9 @@ class MeanOpMaker : public framework::OpProtoAndCheckerMaker {
};
};
class
MeanGradOp
:
public
framework
::
OperatorWithKernel
{
class
MeanGradOp
:
public
framework
::
OperatorWithKernel
{
DEFINE_OPERATOR_CTOR
(
MeanGradOp
,
framework
::
OperatorWithKernel
)
public:
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
protected:
void
InferShape
(
const
framework
::
InferShapeContext
&
ctx
)
const
override
{
void
InferShape
(
const
framework
::
InferShapeContext
&
ctx
)
const
override
{
ctx
.
Output
<
Tensor
>
(
framework
::
GradVarName
(
"X"
))
ctx
.
Output
<
Tensor
>
(
framework
::
GradVarName
(
"X"
))
...
...
paddle/operators/mul_op.cc
浏览文件 @
ffbb0be2
...
@@ -19,7 +19,8 @@ namespace paddle {
...
@@ -19,7 +19,8 @@ namespace paddle {
namespace
operators
{
namespace
operators
{
class
MulOp
:
public
framework
::
OperatorWithKernel
{
class
MulOp
:
public
framework
::
OperatorWithKernel
{
DEFINE_OPERATOR_CTOR
(
MulOp
,
framework
::
OperatorWithKernel
);
public:
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
protected:
void
InferShape
(
const
framework
::
InferShapeContext
&
ctx
)
const
override
{
void
InferShape
(
const
framework
::
InferShapeContext
&
ctx
)
const
override
{
...
@@ -54,7 +55,9 @@ The equation is: Out = X * Y
...
@@ -54,7 +55,9 @@ The equation is: Out = X * Y
};
};
class
MulOpGrad
:
public
framework
::
OperatorWithKernel
{
class
MulOpGrad
:
public
framework
::
OperatorWithKernel
{
DEFINE_OPERATOR_CTOR
(
MulOpGrad
,
framework
::
OperatorWithKernel
)
public:
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
protected:
void
InferShape
(
const
framework
::
InferShapeContext
&
ctx
)
const
override
{}
void
InferShape
(
const
framework
::
InferShapeContext
&
ctx
)
const
override
{}
std
::
string
DebugString
()
const
override
{
std
::
string
DebugString
()
const
override
{
...
...
paddle/operators/net_op.cc
浏览文件 @
ffbb0be2
...
@@ -81,5 +81,11 @@ std::vector<std::string> NetOp::OutputVars(bool has_intermediate) const {
...
@@ -81,5 +81,11 @@ std::vector<std::string> NetOp::OutputVars(bool has_intermediate) const {
return
ret_val
;
return
ret_val
;
}
}
NetOp
::
NetOp
(
const
std
::
string
&
type
,
const
framework
::
OperatorBase
::
VarNameMap
&
inputs
,
const
framework
::
OperatorBase
::
VarNameMap
&
outputs
,
const
framework
::
AttributeMap
&
attrs
)
:
OperatorBase
(
type
,
inputs
,
outputs
,
attrs
)
{}
}
// namespace operators
}
// namespace operators
}
// namespace paddle
}
// namespace paddle
paddle/operators/net_op.h
浏览文件 @
ffbb0be2
...
@@ -37,7 +37,9 @@ namespace operators {
...
@@ -37,7 +37,9 @@ namespace operators {
class
NetOp
:
public
framework
::
OperatorBase
{
class
NetOp
:
public
framework
::
OperatorBase
{
public:
public:
static
const
char
kAll
[];
static
const
char
kAll
[];
DEFINE_OPERATOR_CTOR
(
NetOp
,
framework
::
OperatorBase
);
NetOp
()
:
framework
::
OperatorBase
(
"plain_net"
,
{},
{},
{})
{}
NetOp
(
const
std
::
string
&
type
,
const
VarNameMap
&
inputs
,
const
VarNameMap
&
outputs
,
const
framework
::
AttributeMap
&
attrs
);
/**
/**
* Infer all the operators' input and output variables' shapes, will be called
* Infer all the operators' input and output variables' shapes, will be called
...
...
paddle/operators/net_op_test.cc
浏览文件 @
ffbb0be2
...
@@ -12,7 +12,7 @@ static int run_cnt = 0;
...
@@ -12,7 +12,7 @@ static int run_cnt = 0;
class
TestOp
:
public
framework
::
OperatorBase
{
class
TestOp
:
public
framework
::
OperatorBase
{
public:
public:
DEFINE_OPERATOR_CTOR
(
TestOp
,
framework
::
OperatorBase
)
;
using
framework
::
OperatorBase
::
OperatorBase
;
void
InferShape
(
const
Scope
&
scope
)
const
override
{
++
infer_shape_cnt
;
}
void
InferShape
(
const
Scope
&
scope
)
const
override
{
++
infer_shape_cnt
;
}
void
Run
(
const
Scope
&
scope
,
void
Run
(
const
Scope
&
scope
,
const
platform
::
DeviceContext
&
dev_ctx
)
const
override
{
const
platform
::
DeviceContext
&
dev_ctx
)
const
override
{
...
@@ -22,7 +22,7 @@ class TestOp : public framework::OperatorBase {
...
@@ -22,7 +22,7 @@ class TestOp : public framework::OperatorBase {
class
EmptyOp
:
public
framework
::
OperatorBase
{
class
EmptyOp
:
public
framework
::
OperatorBase
{
public:
public:
DEFINE_OPERATOR_CTOR
(
EmptyOp
,
framework
::
OperatorBase
)
;
using
framework
::
OperatorBase
::
OperatorBase
;
void
InferShape
(
const
Scope
&
scope
)
const
override
{}
void
InferShape
(
const
Scope
&
scope
)
const
override
{}
void
Run
(
const
Scope
&
scope
,
const
DeviceContext
&
dev_ctx
)
const
override
{}
void
Run
(
const
Scope
&
scope
,
const
DeviceContext
&
dev_ctx
)
const
override
{}
};
};
...
@@ -44,14 +44,14 @@ TEST(OpKernel, all) {
...
@@ -44,14 +44,14 @@ TEST(OpKernel, all) {
auto
net
=
std
::
make_shared
<
NetOp
>
();
auto
net
=
std
::
make_shared
<
NetOp
>
();
ASSERT_NE
(
net
,
nullptr
);
ASSERT_NE
(
net
,
nullptr
);
auto
op1
=
std
::
make_shared
<
TestOp
>
();
auto
op1
=
std
::
shared_ptr
<
TestOp
>
(
op1
->
inputs_
=
{{
"X"
,
{
"x"
}},
{
"W"
,
{
"w1"
}},
{
"b"
,
{
"b1"
}}};
new
TestOp
(
"test"
,
{{
"X"
,
{
"x"
}},
{
"W"
,
{
"w1"
}},
{
"b"
,
{
"b1"
}}},
op1
->
outputs_
=
{{
"Out"
,
{
"y"
}}}
;
{{
"Out"
,
{
"y"
}}},
{}))
;
net
->
AddOp
(
op1
);
net
->
AddOp
(
op1
);
auto
op2
=
std
::
make_shared
<
TestOp
>
();
auto
op2
=
std
::
shared_ptr
<
TestOp
>
(
op2
->
inputs_
=
{{
"X"
,
{
"y"
}},
{
"W"
,
{
"w2"
}},
{
"b"
,
{
"b2"
}}};
new
TestOp
(
"test"
,
{{
"X"
,
{
"y"
}},
{
"W"
,
{
"w2"
}},
{
"b"
,
{
"b2"
}}},
op2
->
outputs_
=
{{
"Out"
,
{
"z"
}}}
;
{{
"Out"
,
{
"z"
}}},
{}))
;
net
->
AddOp
(
op2
);
net
->
AddOp
(
op2
);
net
->
CompleteAddOp
();
net
->
CompleteAddOp
();
...
@@ -67,9 +67,9 @@ TEST(OpKernel, all) {
...
@@ -67,9 +67,9 @@ TEST(OpKernel, all) {
TEST
(
NetOp
,
insert_op
)
{
TEST
(
NetOp
,
insert_op
)
{
NetOp
net
;
NetOp
net
;
auto
op1
=
std
::
make_shared
<
EmptyOp
>
();
auto
op1
=
std
::
shared_ptr
<
EmptyOp
>
(
op1
->
inputs_
=
{{
"X"
,
{
"x"
}},
{
"W"
,
{
"w1"
}},
{
"b"
,
{
"b1"
}}};
new
EmptyOp
(
"empty"
,
{{
"X"
,
{
"x"
}},
{
"W"
,
{
"w1"
}},
{
"b"
,
{
"b1"
}}},
op1
->
outputs_
=
{{
"Out"
,
{
"y"
}}}
;
{{
"Out"
,
{
"y"
}}},
{}))
;
net
.
AddOp
(
op1
);
net
.
AddOp
(
op1
);
net
.
InsertOp
(
0
,
op1
);
net
.
InsertOp
(
0
,
op1
);
ASSERT_EQ
(
2UL
,
net
.
ops_
.
size
());
ASSERT_EQ
(
2UL
,
net
.
ops_
.
size
());
...
...
paddle/operators/recurrent_op.cc
浏览文件 @
ffbb0be2
...
@@ -135,8 +135,11 @@ const rnn::ArgumentName RecurrentGradientOp::kArgName{
...
@@ -135,8 +135,11 @@ const rnn::ArgumentName RecurrentGradientOp::kArgName{
"inlink@grad"
,
"inlink_alias"
,
"outlink_alias"
,
"inlink@grad"
,
"inlink_alias"
,
"outlink_alias"
,
"memories"
,
"pre_memories"
,
"boot_memories@grad"
};
"memories"
,
"pre_memories"
,
"boot_memories@grad"
};
void
RecurrentOp
::
Init
()
{
RecurrentOp
::
RecurrentOp
(
const
std
::
string
&
type
,
OperatorBase
::
Init
();
const
framework
::
OperatorBase
::
VarNameMap
&
inputs
,
const
framework
::
OperatorBase
::
VarNameMap
&
outputs
,
const
framework
::
AttributeMap
&
attrs
)
:
OperatorBase
(
type
,
inputs
,
outputs
,
attrs
)
{
std
::
unique_ptr
<
rnn
::
Argument
>
arg
(
new
rnn
::
Argument
());
std
::
unique_ptr
<
rnn
::
Argument
>
arg
(
new
rnn
::
Argument
());
rnn
::
InitArgument
(
kArgName
,
arg
.
get
(),
*
this
);
rnn
::
InitArgument
(
kArgName
,
arg
.
get
(),
*
this
);
alg_
.
Init
(
std
::
move
(
arg
));
alg_
.
Init
(
std
::
move
(
arg
));
...
@@ -230,8 +233,11 @@ void RecurrentGradientAlgorithm::InferShape(const Scope& scope) const {
...
@@ -230,8 +233,11 @@ void RecurrentGradientAlgorithm::InferShape(const Scope& scope) const {
LinkBootMemoryGradients
(
step_scopes
[
0
],
true
/*infer_shape_mode*/
);
LinkBootMemoryGradients
(
step_scopes
[
0
],
true
/*infer_shape_mode*/
);
}
}
void
RecurrentGradientOp
::
Init
()
{
RecurrentGradientOp
::
RecurrentGradientOp
(
OperatorBase
::
Init
();
const
std
::
string
&
type
,
const
framework
::
OperatorBase
::
VarNameMap
&
inputs
,
const
framework
::
OperatorBase
::
VarNameMap
&
outputs
,
const
framework
::
AttributeMap
&
attrs
)
:
OperatorBase
(
type
,
inputs
,
outputs
,
attrs
)
{
std
::
unique_ptr
<
rnn
::
Argument
>
arg
(
new
rnn
::
Argument
());
std
::
unique_ptr
<
rnn
::
Argument
>
arg
(
new
rnn
::
Argument
());
rnn
::
InitArgument
(
kArgName
,
arg
.
get
(),
*
this
);
rnn
::
InitArgument
(
kArgName
,
arg
.
get
(),
*
this
);
alg_
.
Init
(
std
::
move
(
arg
));
alg_
.
Init
(
std
::
move
(
arg
));
...
...
paddle/operators/recurrent_op.h
浏览文件 @
ffbb0be2
...
@@ -101,13 +101,11 @@ class RecurrentGradientAlgorithm {
...
@@ -101,13 +101,11 @@ class RecurrentGradientAlgorithm {
class
RecurrentOp
final
:
public
framework
::
OperatorBase
{
class
RecurrentOp
final
:
public
framework
::
OperatorBase
{
public:
public:
DEFINE_OPERATOR_CTOR
(
RecurrentOp
,
framework
::
OperatorBase
);
RecurrentOp
(
const
std
::
string
&
type
,
const
VarNameMap
&
inputs
,
const
VarNameMap
&
outputs
,
const
framework
::
AttributeMap
&
attrs
);
void
Init
()
override
;
/**
/**
* InferShape must be called before Run.
* InferShape must be called before Run.
*/
*/
void
InferShape
(
const
framework
::
Scope
&
scope
)
const
override
{
void
InferShape
(
const
framework
::
Scope
&
scope
)
const
override
{
alg_
.
InferShape
(
scope
);
alg_
.
InferShape
(
scope
);
}
}
...
@@ -125,8 +123,9 @@ class RecurrentOp final : public framework::OperatorBase {
...
@@ -125,8 +123,9 @@ class RecurrentOp final : public framework::OperatorBase {
class
RecurrentGradientOp
final
:
public
framework
::
OperatorBase
{
class
RecurrentGradientOp
final
:
public
framework
::
OperatorBase
{
public:
public:
DEFINE_OPERATOR_CTOR
(
RecurrentGradientOp
,
framework
::
OperatorBase
)
RecurrentGradientOp
(
const
std
::
string
&
type
,
const
VarNameMap
&
inputs
,
void
Init
()
override
;
const
VarNameMap
&
outputs
,
const
framework
::
AttributeMap
&
attrs
);
/**
/**
* InferShape must be called before Run.
* InferShape must be called before Run.
...
...
paddle/operators/rowwise_add_op.cc
浏览文件 @
ffbb0be2
...
@@ -18,7 +18,9 @@ namespace paddle {
...
@@ -18,7 +18,9 @@ namespace paddle {
namespace
operators
{
namespace
operators
{
class
RowWiseAddOp
:
public
framework
::
OperatorWithKernel
{
class
RowWiseAddOp
:
public
framework
::
OperatorWithKernel
{
DEFINE_OPERATOR_CTOR
(
RowWiseAddOp
,
framework
::
OperatorWithKernel
)
public:
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
protected:
void
InferShape
(
const
framework
::
InferShapeContext
&
ctx
)
const
override
{
void
InferShape
(
const
framework
::
InferShapeContext
&
ctx
)
const
override
{
auto
dim0
=
ctx
.
Input
<
Tensor
>
(
"X"
)
->
dims
();
auto
dim0
=
ctx
.
Input
<
Tensor
>
(
"X"
)
->
dims
();
...
...
paddle/operators/sgd_op.cc
浏览文件 @
ffbb0be2
...
@@ -18,7 +18,9 @@ namespace paddle {
...
@@ -18,7 +18,9 @@ namespace paddle {
namespace
operators
{
namespace
operators
{
class
SGDOp
:
public
framework
::
OperatorWithKernel
{
class
SGDOp
:
public
framework
::
OperatorWithKernel
{
DEFINE_OPERATOR_CTOR
(
SGDOp
,
framework
::
OperatorWithKernel
)
public:
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
protected:
void
InferShape
(
const
framework
::
InferShapeContext
&
ctx
)
const
override
{
void
InferShape
(
const
framework
::
InferShapeContext
&
ctx
)
const
override
{
PADDLE_ENFORCE
(
PADDLE_ENFORCE
(
...
...
paddle/operators/sigmoid_op.cc
浏览文件 @
ffbb0be2
...
@@ -18,7 +18,9 @@ namespace paddle {
...
@@ -18,7 +18,9 @@ namespace paddle {
namespace
operators
{
namespace
operators
{
class
SigmoidOp
:
public
framework
::
OperatorWithKernel
{
class
SigmoidOp
:
public
framework
::
OperatorWithKernel
{
DEFINE_OPERATOR_CTOR
(
SigmoidOp
,
framework
::
OperatorWithKernel
)
public:
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
protected:
void
InferShape
(
const
framework
::
InferShapeContext
&
ctx
)
const
override
{
void
InferShape
(
const
framework
::
InferShapeContext
&
ctx
)
const
override
{
ctx
.
Output
<
Tensor
>
(
"Y"
)
->
Resize
(
ctx
.
Input
<
Tensor
>
(
"X"
)
->
dims
());
ctx
.
Output
<
Tensor
>
(
"Y"
)
->
Resize
(
ctx
.
Input
<
Tensor
>
(
"X"
)
->
dims
());
...
@@ -37,7 +39,9 @@ class SigmoidOpMaker : public framework::OpProtoAndCheckerMaker {
...
@@ -37,7 +39,9 @@ class SigmoidOpMaker : public framework::OpProtoAndCheckerMaker {
};
};
class
SigmoidOpGrad
:
public
framework
::
OperatorWithKernel
{
class
SigmoidOpGrad
:
public
framework
::
OperatorWithKernel
{
DEFINE_OPERATOR_CTOR
(
SigmoidOpGrad
,
framework
::
OperatorWithKernel
)
public:
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
protected:
void
InferShape
(
const
framework
::
InferShapeContext
&
ctx
)
const
override
{
void
InferShape
(
const
framework
::
InferShapeContext
&
ctx
)
const
override
{
ctx
.
Output
<
Tensor
>
(
0
)
->
Resize
(
ctx
.
Input
<
Tensor
>
(
0
)
->
dims
());
ctx
.
Output
<
Tensor
>
(
0
)
->
Resize
(
ctx
.
Input
<
Tensor
>
(
0
)
->
dims
());
...
...
paddle/operators/softmax_op.cc
浏览文件 @
ffbb0be2
...
@@ -18,7 +18,9 @@ namespace paddle {
...
@@ -18,7 +18,9 @@ namespace paddle {
namespace
operators
{
namespace
operators
{
class
SoftmaxOp
:
public
framework
::
OperatorWithKernel
{
class
SoftmaxOp
:
public
framework
::
OperatorWithKernel
{
DEFINE_OPERATOR_CTOR
(
SoftmaxOp
,
framework
::
OperatorWithKernel
)
public:
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
protected:
void
InferShape
(
const
framework
::
InferShapeContext
&
ctx
)
const
override
{
void
InferShape
(
const
framework
::
InferShapeContext
&
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
.
Input
<
Tensor
>
(
"X"
)
->
dims
().
size
()
==
2UL
,
PADDLE_ENFORCE
(
ctx
.
Input
<
Tensor
>
(
"X"
)
->
dims
().
size
()
==
2UL
,
...
@@ -39,7 +41,9 @@ class SoftmaxOpMaker : public framework::OpProtoAndCheckerMaker {
...
@@ -39,7 +41,9 @@ class SoftmaxOpMaker : public framework::OpProtoAndCheckerMaker {
};
};
class
SoftmaxOpGrad
:
public
framework
::
OperatorWithKernel
{
class
SoftmaxOpGrad
:
public
framework
::
OperatorWithKernel
{
DEFINE_OPERATOR_CTOR
(
SoftmaxOpGrad
,
framework
::
OperatorWithKernel
)
public:
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
protected:
void
InferShape
(
const
framework
::
InferShapeContext
&
ctx
)
const
override
{
void
InferShape
(
const
framework
::
InferShapeContext
&
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
.
InputVar
(
"Y"
)
!=
nullptr
,
"Input(Y) should not be null"
);
PADDLE_ENFORCE
(
ctx
.
InputVar
(
"Y"
)
!=
nullptr
,
"Input(Y) should not be null"
);
...
...
paddle/operators/uniform_random_op.cc
浏览文件 @
ffbb0be2
...
@@ -46,7 +46,9 @@ class CPUUniformRandomKernel : public framework::OpKernel {
...
@@ -46,7 +46,9 @@ class CPUUniformRandomKernel : public framework::OpKernel {
};
};
class
UniformRandomOp
:
public
framework
::
OperatorWithKernel
{
class
UniformRandomOp
:
public
framework
::
OperatorWithKernel
{
DEFINE_OPERATOR_CTOR
(
UniformRandomOp
,
framework
::
OperatorWithKernel
)
public:
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
protected:
void
InferShape
(
const
framework
::
InferShapeContext
&
ctx
)
const
override
{
void
InferShape
(
const
framework
::
InferShapeContext
&
ctx
)
const
override
{
PADDLE_ENFORCE
(
GetAttr
<
float
>
(
"min"
)
<
GetAttr
<
float
>
(
"max"
),
PADDLE_ENFORCE
(
GetAttr
<
float
>
(
"min"
)
<
GetAttr
<
float
>
(
"max"
),
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录