Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
9620df44
P
Paddle
项目概览
PaddlePaddle
/
Paddle
大约 1 年 前同步成功
通知
2298
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
体验新版 GitCode,发现更多精彩内容 >>
提交
9620df44
编写于
8月 04, 2017
作者:
Y
Yi Wang
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Reformat paddle/operators/* strictly following Google Style Guide
上级
559b0224
变更
26
显示空白变更内容
内联
并排
Showing
26 changed file
with
129 addition
and
170 deletion
+129
-170
paddle/operators/.clang-format
paddle/operators/.clang-format
+5
-0
paddle/operators/add_op.cc
paddle/operators/add_op.cc
+3
-3
paddle/operators/add_op.h
paddle/operators/add_op.h
+1
-1
paddle/operators/cross_entropy_op.cc
paddle/operators/cross_entropy_op.cc
+3
-4
paddle/operators/cross_entropy_op.h
paddle/operators/cross_entropy_op.h
+1
-1
paddle/operators/fc_op.cc
paddle/operators/fc_op.cc
+6
-8
paddle/operators/fill_zeros_like_op.cc
paddle/operators/fill_zeros_like_op.cc
+3
-4
paddle/operators/fill_zeros_like_op.h
paddle/operators/fill_zeros_like_op.h
+1
-1
paddle/operators/mean_op.cc
paddle/operators/mean_op.cc
+3
-3
paddle/operators/mean_op.h
paddle/operators/mean_op.h
+2
-2
paddle/operators/mul_op.cc
paddle/operators/mul_op.cc
+3
-3
paddle/operators/mul_op.h
paddle/operators/mul_op.h
+1
-1
paddle/operators/net_op.h
paddle/operators/net_op.h
+2
-2
paddle/operators/net_op_test.cc
paddle/operators/net_op_test.cc
+2
-2
paddle/operators/recurrent_op.cc
paddle/operators/recurrent_op.cc
+51
-80
paddle/operators/recurrent_op.h
paddle/operators/recurrent_op.h
+14
-18
paddle/operators/recurrent_op_test.cc
paddle/operators/recurrent_op_test.cc
+9
-13
paddle/operators/rowwise_add_op.cc
paddle/operators/rowwise_add_op.cc
+2
-2
paddle/operators/rowwise_add_op.h
paddle/operators/rowwise_add_op.h
+1
-1
paddle/operators/sgd_op.cc
paddle/operators/sgd_op.cc
+2
-2
paddle/operators/sgd_op.h
paddle/operators/sgd_op.h
+1
-1
paddle/operators/sigmoid_op.cc
paddle/operators/sigmoid_op.cc
+3
-3
paddle/operators/sigmoid_op.h
paddle/operators/sigmoid_op.h
+1
-1
paddle/operators/softmax_op.cc
paddle/operators/softmax_op.cc
+3
-3
paddle/operators/softmax_op.h
paddle/operators/softmax_op.h
+2
-2
paddle/operators/type_alias.h
paddle/operators/type_alias.h
+4
-9
未找到文件。
paddle/operators/.clang-format
0 → 100644
浏览文件 @
9620df44
---
Language: Cpp
BasedOnStyle: Google
Standard: Cpp11
...
paddle/operators/add_op.cc
浏览文件 @
9620df44
...
@@ -18,7 +18,7 @@ namespace paddle {
...
@@ -18,7 +18,7 @@ namespace paddle {
namespace
operators
{
namespace
operators
{
class
AddOp
:
public
OperatorWithKernel
{
class
AddOp
:
public
OperatorWithKernel
{
protected:
protected:
void
InferShape
(
const
InferShapeContext
&
ctx
)
const
override
{
void
InferShape
(
const
InferShapeContext
&
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
.
InputSize
()
==
2
,
"Input size of AddOp must be two"
);
PADDLE_ENFORCE
(
ctx
.
InputSize
()
==
2
,
"Input size of AddOp must be two"
);
PADDLE_ENFORCE
(
ctx
.
OutputSize
()
==
1
,
"Output size of AddOp must be one"
);
PADDLE_ENFORCE
(
ctx
.
OutputSize
()
==
1
,
"Output size of AddOp must be one"
);
...
@@ -33,7 +33,7 @@ protected:
...
@@ -33,7 +33,7 @@ protected:
};
};
class
AddOpMaker
:
public
OpProtoAndCheckerMaker
{
class
AddOpMaker
:
public
OpProtoAndCheckerMaker
{
public:
public:
AddOpMaker
(
OpProto
*
proto
,
OpAttrChecker
*
op_checker
)
AddOpMaker
(
OpProto
*
proto
,
OpAttrChecker
*
op_checker
)
:
OpProtoAndCheckerMaker
(
proto
,
op_checker
)
{
:
OpProtoAndCheckerMaker
(
proto
,
op_checker
)
{
AddInput
(
"X"
,
"The first input of add op"
);
AddInput
(
"X"
,
"The first input of add op"
);
...
@@ -48,7 +48,7 @@ The equation is: Out = X + Y
...
@@ -48,7 +48,7 @@ The equation is: Out = X + Y
};
};
class
AddOpGrad
:
public
OperatorWithKernel
{
class
AddOpGrad
:
public
OperatorWithKernel
{
protected:
protected:
void
InferShape
(
const
InferShapeContext
&
ctx
)
const
override
{}
void
InferShape
(
const
InferShapeContext
&
ctx
)
const
override
{}
};
};
...
...
paddle/operators/add_op.h
浏览文件 @
9620df44
...
@@ -20,7 +20,7 @@ namespace operators {
...
@@ -20,7 +20,7 @@ namespace operators {
template
<
typename
Place
,
typename
T
>
template
<
typename
Place
,
typename
T
>
class
AddKernel
:
public
OpKernel
{
class
AddKernel
:
public
OpKernel
{
public:
public:
void
Compute
(
const
ExecutionContext
&
context
)
const
override
{
void
Compute
(
const
ExecutionContext
&
context
)
const
override
{
auto
input0
=
context
.
Input
<
Tensor
>
(
0
);
auto
input0
=
context
.
Input
<
Tensor
>
(
0
);
auto
input1
=
context
.
Input
<
Tensor
>
(
1
);
auto
input1
=
context
.
Input
<
Tensor
>
(
1
);
...
...
paddle/operators/cross_entropy_op.cc
浏览文件 @
9620df44
...
@@ -18,7 +18,7 @@ namespace paddle {
...
@@ -18,7 +18,7 @@ namespace paddle {
namespace
operators
{
namespace
operators
{
class
OnehotCrossEntropyOp
:
public
OperatorWithKernel
{
class
OnehotCrossEntropyOp
:
public
OperatorWithKernel
{
protected:
protected:
void
InferShape
(
const
InferShapeContext
&
ctx
)
const
override
{
void
InferShape
(
const
InferShapeContext
&
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
.
InputSize
()
==
2
,
PADDLE_ENFORCE
(
ctx
.
InputSize
()
==
2
,
"Input size of OnehotCrossEntropyOp must be two"
);
"Input size of OnehotCrossEntropyOp must be two"
);
...
@@ -37,7 +37,7 @@ protected:
...
@@ -37,7 +37,7 @@ protected:
};
};
class
OnehotCrossEntropyOpMaker
:
public
OpProtoAndCheckerMaker
{
class
OnehotCrossEntropyOpMaker
:
public
OpProtoAndCheckerMaker
{
public:
public:
OnehotCrossEntropyOpMaker
(
OpProto
*
proto
,
OpAttrChecker
*
op_checker
)
OnehotCrossEntropyOpMaker
(
OpProto
*
proto
,
OpAttrChecker
*
op_checker
)
:
OpProtoAndCheckerMaker
(
proto
,
op_checker
)
{
:
OpProtoAndCheckerMaker
(
proto
,
op_checker
)
{
AddInput
(
"X"
,
"The first input of OnehotCrossEntropyOp"
);
AddInput
(
"X"
,
"The first input of OnehotCrossEntropyOp"
);
...
@@ -54,8 +54,7 @@ OnehotCrossEntropy Operator.
...
@@ -54,8 +54,7 @@ OnehotCrossEntropy Operator.
}
// namespace operators
}
// namespace operators
}
// namespace paddle
}
// namespace paddle
REGISTER_OP
(
onehot_cross_entropy
,
REGISTER_OP
(
onehot_cross_entropy
,
ops
::
OnehotCrossEntropyOp
,
ops
::
OnehotCrossEntropyOp
,
ops
::
OnehotCrossEntropyOpMaker
);
ops
::
OnehotCrossEntropyOpMaker
);
REGISTER_OP_CPU_KERNEL
(
onehot_cross_entropy
,
REGISTER_OP_CPU_KERNEL
(
onehot_cross_entropy
,
ops
::
OnehotCrossEntropyOpKernel
<
ops
::
CPUPlace
,
float
>
);
ops
::
OnehotCrossEntropyOpKernel
<
ops
::
CPUPlace
,
float
>
);
paddle/operators/cross_entropy_op.h
浏览文件 @
9620df44
...
@@ -20,7 +20,7 @@ namespace operators {
...
@@ -20,7 +20,7 @@ namespace operators {
template
<
typename
Place
,
typename
T
>
template
<
typename
Place
,
typename
T
>
class
OnehotCrossEntropyOpKernel
:
public
OpKernel
{
class
OnehotCrossEntropyOpKernel
:
public
OpKernel
{
public:
public:
constexpr
T
LOG_THRESHOLD
()
const
{
return
static_cast
<
T
>
(
1e-20
);
}
constexpr
T
LOG_THRESHOLD
()
const
{
return
static_cast
<
T
>
(
1e-20
);
}
void
Compute
(
const
ExecutionContext
&
ctx
)
const
override
{
void
Compute
(
const
ExecutionContext
&
ctx
)
const
override
{
...
...
paddle/operators/fc_op.cc
浏览文件 @
9620df44
...
@@ -18,31 +18,29 @@ namespace paddle {
...
@@ -18,31 +18,29 @@ namespace paddle {
namespace
operators
{
namespace
operators
{
class
FullyConnectedOp
:
public
NetOp
{
class
FullyConnectedOp
:
public
NetOp
{
public:
public:
void
Init
()
override
{
void
Init
()
override
{
AddOp
(
OpRegistry
::
CreateOp
(
"mul"
,
AddOp
(
OpRegistry
::
CreateOp
(
"mul"
,
{
{
Input
(
"X"
),
Input
(
"W"
),
Input
(
"X"
),
Input
(
"W"
),
},
},
{
Output
(
"before_act"
)},
{
Output
(
"before_act"
)},
{}));
{}));
auto
b
=
Input
(
"b"
);
auto
b
=
Input
(
"b"
);
if
(
b
!=
framework
::
kEmptyVarName
)
{
if
(
b
!=
framework
::
kEmptyVarName
)
{
AddOp
(
OpRegistry
::
CreateOp
(
"rowwise_add"
,
AddOp
(
OpRegistry
::
CreateOp
(
"rowwise_add"
,
{
Output
(
"before_act"
),
Input
(
"b"
)},
{
Output
(
"before_act"
),
Input
(
"b"
)},
{
Output
(
"before_act"
)},
{
Output
(
"before_act"
)},
{}));
{}));
}
}
auto
activation
=
GetAttr
<
std
::
string
>
(
"activation"
);
auto
activation
=
GetAttr
<
std
::
string
>
(
"activation"
);
AddOp
(
OpRegistry
::
CreateOp
(
AddOp
(
OpRegistry
::
CreateOp
(
activation
,
{
Output
(
"before_act"
)},
activation
,
{
Output
(
"before_act"
)},
{
Output
(
"Y"
)},
{}));
{
Output
(
"Y"
)},
{}));
CompleteAddOp
(
false
);
CompleteAddOp
(
false
);
}
}
};
};
class
FullyConnectedOpMaker
:
public
OpProtoAndCheckerMaker
{
class
FullyConnectedOpMaker
:
public
OpProtoAndCheckerMaker
{
public:
public:
FullyConnectedOpMaker
(
OpProto
*
proto
,
OpAttrChecker
*
op_checker
)
FullyConnectedOpMaker
(
OpProto
*
proto
,
OpAttrChecker
*
op_checker
)
:
OpProtoAndCheckerMaker
(
proto
,
op_checker
)
{
:
OpProtoAndCheckerMaker
(
proto
,
op_checker
)
{
AddInput
(
"X"
,
"the input of fc operator"
);
AddInput
(
"X"
,
"the input of fc operator"
);
...
...
paddle/operators/fill_zeros_like_op.cc
浏览文件 @
9620df44
...
@@ -20,7 +20,7 @@ namespace paddle {
...
@@ -20,7 +20,7 @@ namespace paddle {
namespace
operators
{
namespace
operators
{
class
FillZerosLikeOp
:
public
framework
::
OperatorWithKernel
{
class
FillZerosLikeOp
:
public
framework
::
OperatorWithKernel
{
protected:
protected:
void
InferShape
(
const
framework
::
InferShapeContext
&
ctx
)
const
override
{
void
InferShape
(
const
framework
::
InferShapeContext
&
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
.
InputSize
()
==
1UL
,
PADDLE_ENFORCE
(
ctx
.
InputSize
()
==
1UL
,
"Input size of FillZerosLikeOp must be one."
);
"Input size of FillZerosLikeOp must be one."
);
...
@@ -36,7 +36,7 @@ protected:
...
@@ -36,7 +36,7 @@ protected:
};
};
class
FillZerosLikeOpMaker
:
public
framework
::
OpProtoAndCheckerMaker
{
class
FillZerosLikeOpMaker
:
public
framework
::
OpProtoAndCheckerMaker
{
public:
public:
FillZerosLikeOpMaker
(
framework
::
OpProto
*
proto
,
FillZerosLikeOpMaker
(
framework
::
OpProto
*
proto
,
framework
::
OpAttrChecker
*
op_checker
)
framework
::
OpAttrChecker
*
op_checker
)
:
framework
::
OpProtoAndCheckerMaker
(
proto
,
op_checker
)
{
:
framework
::
OpProtoAndCheckerMaker
(
proto
,
op_checker
)
{
...
@@ -52,8 +52,7 @@ The output will have the same size with input.
...
@@ -52,8 +52,7 @@ The output will have the same size with input.
}
// namespace operators
}
// namespace operators
}
// namespace paddle
}
// namespace paddle
REGISTER_OP
(
fill_zeros_like
,
REGISTER_OP
(
fill_zeros_like
,
paddle
::
operators
::
FillZerosLikeOp
,
paddle
::
operators
::
FillZerosLikeOp
,
paddle
::
operators
::
FillZerosLikeOpMaker
);
paddle
::
operators
::
FillZerosLikeOpMaker
);
REGISTER_OP_CPU_KERNEL
(
REGISTER_OP_CPU_KERNEL
(
fill_zeros_like
,
fill_zeros_like
,
...
...
paddle/operators/fill_zeros_like_op.h
浏览文件 @
9620df44
...
@@ -22,7 +22,7 @@ namespace operators {
...
@@ -22,7 +22,7 @@ namespace operators {
template
<
typename
Place
,
typename
T
>
template
<
typename
Place
,
typename
T
>
class
FillZerosLikeKernel
:
public
framework
::
OpKernel
{
class
FillZerosLikeKernel
:
public
framework
::
OpKernel
{
public:
public:
void
Compute
(
const
framework
::
ExecutionContext
&
context
)
const
override
{
void
Compute
(
const
framework
::
ExecutionContext
&
context
)
const
override
{
auto
*
output
=
context
.
Output
<
framework
::
Tensor
>
(
0
);
auto
*
output
=
context
.
Output
<
framework
::
Tensor
>
(
0
);
output
->
mutable_data
<
T
>
(
context
.
GetPlace
());
output
->
mutable_data
<
T
>
(
context
.
GetPlace
());
...
...
paddle/operators/mean_op.cc
浏览文件 @
9620df44
...
@@ -18,7 +18,7 @@ namespace paddle {
...
@@ -18,7 +18,7 @@ namespace paddle {
namespace
operators
{
namespace
operators
{
class
MeanOp
:
public
OperatorWithKernel
{
class
MeanOp
:
public
OperatorWithKernel
{
protected:
protected:
void
InferShape
(
const
InferShapeContext
&
ctx
)
const
override
{
void
InferShape
(
const
InferShapeContext
&
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
.
InputSize
()
==
1
,
"Input size of AddOp must be one"
);
PADDLE_ENFORCE
(
ctx
.
InputSize
()
==
1
,
"Input size of AddOp must be one"
);
PADDLE_ENFORCE
(
ctx
.
OutputSize
()
==
1
,
"Output size of AddOp must be one"
);
PADDLE_ENFORCE
(
ctx
.
OutputSize
()
==
1
,
"Output size of AddOp must be one"
);
...
@@ -29,7 +29,7 @@ protected:
...
@@ -29,7 +29,7 @@ protected:
};
};
class
MeanOpMaker
:
public
OpProtoAndCheckerMaker
{
class
MeanOpMaker
:
public
OpProtoAndCheckerMaker
{
public:
public:
MeanOpMaker
(
OpProto
*
proto
,
OpAttrChecker
*
op_checker
)
MeanOpMaker
(
OpProto
*
proto
,
OpAttrChecker
*
op_checker
)
:
OpProtoAndCheckerMaker
(
proto
,
op_checker
)
{
:
OpProtoAndCheckerMaker
(
proto
,
op_checker
)
{
AddInput
(
"X"
,
"The input of mean op"
);
AddInput
(
"X"
,
"The input of mean op"
);
...
@@ -39,7 +39,7 @@ public:
...
@@ -39,7 +39,7 @@ public:
};
};
class
MeanGradOp
:
public
OperatorWithKernel
{
class
MeanGradOp
:
public
OperatorWithKernel
{
protected:
protected:
void
InferShape
(
const
InferShapeContext
&
ctx
)
const
override
{
void
InferShape
(
const
InferShapeContext
&
ctx
)
const
override
{
ctx
.
Output
<
Tensor
>
(
"X"
+
framework
::
kGradVarSuffix
)
ctx
.
Output
<
Tensor
>
(
"X"
+
framework
::
kGradVarSuffix
)
->
Resize
(
ctx
.
Input
<
Tensor
>
(
"X"
)
->
dims
());
->
Resize
(
ctx
.
Input
<
Tensor
>
(
"X"
)
->
dims
());
...
...
paddle/operators/mean_op.h
浏览文件 @
9620df44
...
@@ -20,7 +20,7 @@ namespace operators {
...
@@ -20,7 +20,7 @@ namespace operators {
template
<
typename
Place
,
typename
T
>
template
<
typename
Place
,
typename
T
>
class
MeanKernel
:
public
OpKernel
{
class
MeanKernel
:
public
OpKernel
{
public:
public:
void
Compute
(
const
ExecutionContext
&
context
)
const
override
{
void
Compute
(
const
ExecutionContext
&
context
)
const
override
{
auto
input
=
context
.
Input
<
Tensor
>
(
0
);
auto
input
=
context
.
Input
<
Tensor
>
(
0
);
auto
output
=
context
.
Output
<
Tensor
>
(
0
);
auto
output
=
context
.
Output
<
Tensor
>
(
0
);
...
@@ -37,7 +37,7 @@ public:
...
@@ -37,7 +37,7 @@ public:
template
<
typename
Place
,
typename
T
>
template
<
typename
Place
,
typename
T
>
class
MeanGradKernel
:
public
OpKernel
{
class
MeanGradKernel
:
public
OpKernel
{
public:
public:
void
Compute
(
const
ExecutionContext
&
context
)
const
override
{
void
Compute
(
const
ExecutionContext
&
context
)
const
override
{
auto
OG
=
context
.
Input
<
Tensor
>
(
"Out"
+
framework
::
kGradVarSuffix
);
auto
OG
=
context
.
Input
<
Tensor
>
(
"Out"
+
framework
::
kGradVarSuffix
);
PADDLE_ENFORCE
(
framework
::
product
(
OG
->
dims
())
==
1
,
PADDLE_ENFORCE
(
framework
::
product
(
OG
->
dims
())
==
1
,
...
...
paddle/operators/mul_op.cc
浏览文件 @
9620df44
...
@@ -18,7 +18,7 @@ namespace paddle {
...
@@ -18,7 +18,7 @@ namespace paddle {
namespace
operators
{
namespace
operators
{
class
MulOp
:
public
OperatorWithKernel
{
class
MulOp
:
public
OperatorWithKernel
{
protected:
protected:
void
InferShape
(
const
InferShapeContext
&
ctx
)
const
override
{
void
InferShape
(
const
InferShapeContext
&
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
.
InputSize
()
==
2
,
"The mul op must take two inputs"
);
PADDLE_ENFORCE
(
ctx
.
InputSize
()
==
2
,
"The mul op must take two inputs"
);
auto
dim0
=
ctx
.
Input
<
Tensor
>
(
0
)
->
dims
();
auto
dim0
=
ctx
.
Input
<
Tensor
>
(
0
)
->
dims
();
...
@@ -34,7 +34,7 @@ protected:
...
@@ -34,7 +34,7 @@ protected:
};
};
class
MulOpMaker
:
public
OpProtoAndCheckerMaker
{
class
MulOpMaker
:
public
OpProtoAndCheckerMaker
{
public:
public:
MulOpMaker
(
OpProto
*
proto
,
OpAttrChecker
*
op_checker
)
MulOpMaker
(
OpProto
*
proto
,
OpAttrChecker
*
op_checker
)
:
OpProtoAndCheckerMaker
(
proto
,
op_checker
)
{
:
OpProtoAndCheckerMaker
(
proto
,
op_checker
)
{
AddInput
(
"X"
,
"The first input of mul op"
);
AddInput
(
"X"
,
"The first input of mul op"
);
...
@@ -49,7 +49,7 @@ The equation is: Out = X * Y
...
@@ -49,7 +49,7 @@ The equation is: Out = X * Y
};
};
class
MulOpGrad
:
public
OperatorWithKernel
{
class
MulOpGrad
:
public
OperatorWithKernel
{
protected:
protected:
void
InferShape
(
const
InferShapeContext
&
ctx
)
const
override
{}
void
InferShape
(
const
InferShapeContext
&
ctx
)
const
override
{}
std
::
string
DebugString
()
const
override
{
std
::
string
DebugString
()
const
override
{
LOG
(
INFO
)
<<
"MulGrad"
;
LOG
(
INFO
)
<<
"MulGrad"
;
...
...
paddle/operators/mul_op.h
浏览文件 @
9620df44
...
@@ -21,7 +21,7 @@ namespace operators {
...
@@ -21,7 +21,7 @@ namespace operators {
template
<
typename
Place
,
typename
T
>
template
<
typename
Place
,
typename
T
>
class
MulKernel
:
public
OpKernel
{
class
MulKernel
:
public
OpKernel
{
public:
public:
void
Compute
(
const
ExecutionContext
&
context
)
const
override
{
void
Compute
(
const
ExecutionContext
&
context
)
const
override
{
Eigen
::
array
<
Eigen
::
IndexPair
<
Eigen
::
DenseIndex
>
,
1
>
dim_pair
=
{
Eigen
::
array
<
Eigen
::
IndexPair
<
Eigen
::
DenseIndex
>
,
1
>
dim_pair
=
{
{
Eigen
::
IndexPair
<
Eigen
::
DenseIndex
>
(
1
,
0
)}};
{
Eigen
::
IndexPair
<
Eigen
::
DenseIndex
>
(
1
,
0
)}};
...
...
paddle/operators/net_op.h
浏览文件 @
9620df44
...
@@ -40,7 +40,7 @@ namespace operators {
...
@@ -40,7 +40,7 @@ namespace operators {
* it defines.
* it defines.
*/
*/
class
NetOp
:
public
framework
::
OperatorBase
{
class
NetOp
:
public
framework
::
OperatorBase
{
public:
public:
/**
/**
* Infer all the operators' input and output variables' shapes, will be called
* Infer all the operators' input and output variables' shapes, will be called
* before every mini-batch
* before every mini-batch
...
@@ -90,7 +90,7 @@ public:
...
@@ -90,7 +90,7 @@ public:
std
::
vector
<
std
::
shared_ptr
<
OperatorBase
>>
ops_
;
std
::
vector
<
std
::
shared_ptr
<
OperatorBase
>>
ops_
;
private:
private:
bool
add_op_done_
{
false
};
bool
add_op_done_
{
false
};
template
<
typename
T
,
typename
KeyType
>
template
<
typename
T
,
typename
KeyType
>
...
...
paddle/operators/net_op_test.cc
浏览文件 @
9620df44
...
@@ -12,7 +12,7 @@ static int infer_shape_cnt = 0;
...
@@ -12,7 +12,7 @@ static int infer_shape_cnt = 0;
static
int
run_cnt
=
0
;
static
int
run_cnt
=
0
;
class
TestOp
:
public
OperatorBase
{
class
TestOp
:
public
OperatorBase
{
public:
public:
void
InferShape
(
const
framework
::
Scope
&
scope
)
const
override
{
void
InferShape
(
const
framework
::
Scope
&
scope
)
const
override
{
++
infer_shape_cnt
;
++
infer_shape_cnt
;
}
}
...
@@ -23,7 +23,7 @@ public:
...
@@ -23,7 +23,7 @@ public:
};
};
class
EmptyOp
:
public
OperatorBase
{
class
EmptyOp
:
public
OperatorBase
{
public:
public:
void
InferShape
(
const
Scope
&
scope
)
const
override
{}
void
InferShape
(
const
Scope
&
scope
)
const
override
{}
void
Run
(
const
Scope
&
scope
,
void
Run
(
const
Scope
&
scope
,
const
platform
::
DeviceContext
&
dev_ctx
)
const
override
{}
const
platform
::
DeviceContext
&
dev_ctx
)
const
override
{}
...
...
paddle/operators/recurrent_op.cc
浏览文件 @
9620df44
...
@@ -28,14 +28,12 @@ namespace operators {
...
@@ -28,14 +28,12 @@ namespace operators {
namespace
rnn
{
namespace
rnn
{
void
SegmentInputs
(
const
std
::
vector
<
Scope
*>&
step_scopes
,
void
SegmentInputs
(
const
std
::
vector
<
Scope
*>&
step_scopes
,
const
std
::
vector
<
Link
>&
inlinks
,
const
std
::
vector
<
Link
>&
inlinks
,
const
size_t
seq_len
,
const
size_t
seq_len
,
bool
infer_shape_mode
)
{
bool
infer_shape_mode
)
{
PADDLE_ENFORCE
(
!
inlinks
.
empty
(),
"no in links are provided."
);
PADDLE_ENFORCE
(
!
inlinks
.
empty
(),
"no in links are provided."
);
for
(
size_t
i
=
0
;
i
<
inlinks
.
size
();
++
i
)
{
for
(
size_t
i
=
0
;
i
<
inlinks
.
size
();
++
i
)
{
auto
input_var
=
step_scopes
[
0
]
->
FindVar
(
inlinks
[
i
].
external
);
auto
input_var
=
step_scopes
[
0
]
->
FindVar
(
inlinks
[
i
].
external
);
PADDLE_ENFORCE
(
input_var
!=
nullptr
,
PADDLE_ENFORCE
(
input_var
!=
nullptr
,
"input link [%s] is not in scope."
,
"input link [%s] is not in scope."
,
inlinks
[
i
].
external
);
inlinks
[
i
].
external
);
Tensor
*
input
=
input_var
->
GetMutable
<
Tensor
>
();
Tensor
*
input
=
input_var
->
GetMutable
<
Tensor
>
();
framework
::
DDim
dims
=
input
->
dims
();
framework
::
DDim
dims
=
input
->
dims
();
...
@@ -54,13 +52,11 @@ void SegmentInputs(const std::vector<Scope*>& step_scopes,
...
@@ -54,13 +52,11 @@ void SegmentInputs(const std::vector<Scope*>& step_scopes,
}
}
void
ConcatOutputs
(
const
std
::
vector
<
Scope
*>&
step_scopes
,
void
ConcatOutputs
(
const
std
::
vector
<
Scope
*>&
step_scopes
,
const
std
::
vector
<
Link
>&
outlinks
,
const
std
::
vector
<
Link
>&
outlinks
,
const
size_t
seq_len
,
const
size_t
seq_len
,
bool
infer_shape_mode
)
{
bool
infer_shape_mode
)
{
for
(
size_t
i
=
0
;
i
<
outlinks
.
size
();
i
++
)
{
for
(
size_t
i
=
0
;
i
<
outlinks
.
size
();
i
++
)
{
auto
output_var
=
step_scopes
[
0
]
->
FindVar
(
outlinks
[
i
].
external
);
auto
output_var
=
step_scopes
[
0
]
->
FindVar
(
outlinks
[
i
].
external
);
PADDLE_ENFORCE
(
output_var
!=
nullptr
,
PADDLE_ENFORCE
(
output_var
!=
nullptr
,
"output link [%s] is not in scope."
,
"output link [%s] is not in scope."
,
outlinks
[
i
].
external
);
outlinks
[
i
].
external
);
Tensor
*
output
=
output_var
->
GetMutable
<
Tensor
>
();
Tensor
*
output
=
output_var
->
GetMutable
<
Tensor
>
();
if
(
infer_shape_mode
)
{
if
(
infer_shape_mode
)
{
...
@@ -87,22 +83,16 @@ void ConcatOutputs(const std::vector<Scope*>& step_scopes,
...
@@ -87,22 +83,16 @@ void ConcatOutputs(const std::vector<Scope*>& step_scopes,
void
LinkMemories
(
const
std
::
vector
<
Scope
*>&
scopes
,
void
LinkMemories
(
const
std
::
vector
<
Scope
*>&
scopes
,
const
std
::
vector
<
rnn
::
MemoryAttr
>&
memories
,
const
std
::
vector
<
rnn
::
MemoryAttr
>&
memories
,
const
size_t
step_id
,
const
size_t
step_id
,
const
int
offset
,
const
int
offset
,
bool
infer_shape_mode
)
{
bool
infer_shape_mode
)
{
PADDLE_ENFORCE
(
step_id
<
scopes
.
size
(),
PADDLE_ENFORCE
(
step_id
<
scopes
.
size
(),
"step [%d] is out of range of step scopes' size [%d]"
,
"step [%d] is out of range of step scopes' size [%d]"
,
step_id
,
step_id
,
scopes
.
size
());
scopes
.
size
());
PADDLE_ENFORCE
(
static_cast
<
int
>
(
step_id
)
+
offset
>=
0
,
PADDLE_ENFORCE
(
static_cast
<
int
>
(
step_id
)
+
offset
>=
0
,
"offset [%d] must be large than -[%d]"
,
"offset [%d] must be large than -[%d]"
,
offset
,
step_id
);
offset
,
step_id
);
PADDLE_ENFORCE
(
step_id
+
offset
<
scopes
.
size
(),
PADDLE_ENFORCE
(
step_id
+
offset
<
scopes
.
size
(),
"offset [%d] is out of range, it must be less than (%d - %d)"
,
"offset [%d] is out of range, it must be less than (%d - %d)"
,
offset
,
offset
,
scopes
.
size
(),
step_id
);
scopes
.
size
(),
step_id
);
auto
scope
=
scopes
[
step_id
];
auto
scope
=
scopes
[
step_id
];
auto
linked_scope
=
scopes
[
step_id
+
offset
];
auto
linked_scope
=
scopes
[
step_id
+
offset
];
for
(
auto
&
attr
:
memories
)
{
for
(
auto
&
attr
:
memories
)
{
...
@@ -116,8 +106,7 @@ void LinkMemories(const std::vector<Scope*>& scopes,
...
@@ -116,8 +106,7 @@ void LinkMemories(const std::vector<Scope*>& scopes,
}
}
}
}
void
InitArgument
(
const
ArgumentName
&
name
,
void
InitArgument
(
const
ArgumentName
&
name
,
Argument
*
arg
,
Argument
*
arg
,
const
OperatorBase
&
op
)
{
const
OperatorBase
&
op
)
{
arg
->
step_net
=
op
.
Input
(
name
.
step_net
);
arg
->
step_net
=
op
.
Input
(
name
.
step_net
);
arg
->
step_scopes
=
op
.
Output
(
name
.
step_scopes
);
arg
->
step_scopes
=
op
.
Output
(
name
.
step_scopes
);
...
@@ -126,8 +115,7 @@ void InitArgument(const ArgumentName& name,
...
@@ -126,8 +115,7 @@ void InitArgument(const ArgumentName& name,
auto
inlink_alias
=
op
.
GetAttr
<
std
::
vector
<
std
::
string
>>
(
name
.
inlink_alias
);
auto
inlink_alias
=
op
.
GetAttr
<
std
::
vector
<
std
::
string
>>
(
name
.
inlink_alias
);
PADDLE_ENFORCE
(
inlinks
.
size
()
==
inlink_alias
.
size
(),
PADDLE_ENFORCE
(
inlinks
.
size
()
==
inlink_alias
.
size
(),
"the size of inlinks and inlink_alias don't match:%d,%d"
,
"the size of inlinks and inlink_alias don't match:%d,%d"
,
inlinks
.
size
(),
inlinks
.
size
(),
inlink_alias
.
size
());
inlink_alias
.
size
());
for
(
size_t
i
=
0
;
i
<
inlinks
.
size
();
++
i
)
{
for
(
size_t
i
=
0
;
i
<
inlinks
.
size
();
++
i
)
{
rnn
::
Link
link
;
rnn
::
Link
link
;
link
.
external
=
inlinks
[
i
];
link
.
external
=
inlinks
[
i
];
...
@@ -139,8 +127,7 @@ void InitArgument(const ArgumentName& name,
...
@@ -139,8 +127,7 @@ void InitArgument(const ArgumentName& name,
auto
outlink_alias
=
op
.
GetAttr
<
std
::
vector
<
std
::
string
>>
(
name
.
outlink_alias
);
auto
outlink_alias
=
op
.
GetAttr
<
std
::
vector
<
std
::
string
>>
(
name
.
outlink_alias
);
PADDLE_ENFORCE
(
outlinks
.
size
()
==
outlink_alias
.
size
(),
PADDLE_ENFORCE
(
outlinks
.
size
()
==
outlink_alias
.
size
(),
"the size of outlinks and outlink_alias don't match:%d,%d"
,
"the size of outlinks and outlink_alias don't match:%d,%d"
,
outlinks
.
size
(),
outlinks
.
size
(),
outlink_alias
.
size
());
outlink_alias
.
size
());
for
(
size_t
i
=
0
;
i
<
outlinks
.
size
();
++
i
)
{
for
(
size_t
i
=
0
;
i
<
outlinks
.
size
();
++
i
)
{
rnn
::
Link
link
;
rnn
::
Link
link
;
link
.
external
=
outlinks
[
i
];
link
.
external
=
outlinks
[
i
];
...
@@ -156,12 +143,10 @@ void InitArgument(const ArgumentName& name,
...
@@ -156,12 +143,10 @@ void InitArgument(const ArgumentName& name,
PADDLE_ENFORCE
(
memories
.
size
()
==
boot_memories
.
size
(),
PADDLE_ENFORCE
(
memories
.
size
()
==
boot_memories
.
size
(),
"the size of memories, boot_memories don't match:%d,%d"
,
"the size of memories, boot_memories don't match:%d,%d"
,
memories
.
size
(),
memories
.
size
(),
boot_memories
.
size
());
boot_memories
.
size
());
PADDLE_ENFORCE
(
pre_memories
.
size
()
==
boot_memories
.
size
(),
PADDLE_ENFORCE
(
pre_memories
.
size
()
==
boot_memories
.
size
(),
"the size of pre_memories, boot_memories don't match:%d,%d"
,
"the size of pre_memories, boot_memories don't match:%d,%d"
,
pre_memories
.
size
(),
pre_memories
.
size
(),
boot_memories
.
size
());
boot_memories
.
size
());
PADDLE_ENFORCE
(
memories
.
size
()
>
0
,
"more than 1 memories should be set"
);
PADDLE_ENFORCE
(
memories
.
size
()
>
0
,
"more than 1 memories should be set"
);
for
(
size_t
i
=
0
;
i
<
memories
.
size
();
++
i
)
{
for
(
size_t
i
=
0
;
i
<
memories
.
size
();
++
i
)
{
...
@@ -181,39 +166,39 @@ void RecurrentAlgorithm::InferShape(const Scope& scope) const {
...
@@ -181,39 +166,39 @@ void RecurrentAlgorithm::InferShape(const Scope& scope) const {
->
dims
()[
0
];
->
dims
()[
0
];
CreateScopes
(
scope
);
CreateScopes
(
scope
);
auto
step_scopes
=
GetStepScopes
(
scope
);
auto
step_scopes
=
GetStepScopes
(
scope
);
rnn
::
SegmentInputs
(
rnn
::
SegmentInputs
(
step_scopes
,
arg_
->
inlinks
,
seq_len_
,
step_scopes
,
arg_
->
inlinks
,
seq_len_
,
true
/*infer_shape_mode*/
);
true
/*infer_shape_mode*/
);
InitMemories
(
step_scopes
[
0
],
true
/*infer_shape_mode*/
);
InitMemories
(
step_scopes
[
0
],
true
/*infer_shape_mode*/
);
Variable
*
net
=
scope
.
FindVar
(
arg_
->
step_net
);
Variable
*
net
=
scope
.
FindVar
(
arg_
->
step_net
);
PADDLE_ENFORCE
(
net
!=
nullptr
,
"failed to get step net"
);
PADDLE_ENFORCE
(
net
!=
nullptr
,
"failed to get step net"
);
for
(
size_t
i
=
0
;
i
<
seq_len_
;
i
++
)
{
for
(
size_t
i
=
0
;
i
<
seq_len_
;
i
++
)
{
if
(
i
>
0
)
{
if
(
i
>
0
)
{
rnn
::
LinkMemories
(
rnn
::
LinkMemories
(
step_scopes
,
arg_
->
memories
,
i
,
-
1
,
step_scopes
,
arg_
->
memories
,
i
,
-
1
,
true
/*infer_shape_mode*/
);
true
/*infer_shape_mode*/
);
}
}
net
->
GetMutable
<
NetOp
>
()
->
InferShape
(
*
step_scopes
[
i
]);
net
->
GetMutable
<
NetOp
>
()
->
InferShape
(
*
step_scopes
[
i
]);
}
}
rnn
::
ConcatOutputs
(
rnn
::
ConcatOutputs
(
step_scopes
,
arg_
->
outlinks
,
seq_len_
,
step_scopes
,
arg_
->
outlinks
,
seq_len_
,
true
/*infer_shape_mode*/
);
true
/*infer_shape_mode*/
);
}
}
void
RecurrentAlgorithm
::
Run
(
const
Scope
&
scope
,
void
RecurrentAlgorithm
::
Run
(
const
Scope
&
scope
,
const
platform
::
DeviceContext
&
dev_ctx
)
const
{
const
platform
::
DeviceContext
&
dev_ctx
)
const
{
auto
step_scopes
=
GetStepScopes
(
scope
);
auto
step_scopes
=
GetStepScopes
(
scope
);
rnn
::
SegmentInputs
(
rnn
::
SegmentInputs
(
step_scopes
,
arg_
->
inlinks
,
seq_len_
,
step_scopes
,
arg_
->
inlinks
,
seq_len_
,
false
/*infer_shape_mode*/
);
false
/*infer_shape_mode*/
);
InitMemories
(
step_scopes
[
0
],
false
/*infer_shape_mode*/
);
InitMemories
(
step_scopes
[
0
],
false
/*infer_shape_mode*/
);
Variable
*
net
=
scope
.
FindVar
(
arg_
->
step_net
);
Variable
*
net
=
scope
.
FindVar
(
arg_
->
step_net
);
for
(
size_t
step_id
=
0
;
step_id
<
seq_len_
;
step_id
++
)
{
for
(
size_t
step_id
=
0
;
step_id
<
seq_len_
;
step_id
++
)
{
if
(
step_id
>
0
)
{
if
(
step_id
>
0
)
{
rnn
::
LinkMemories
(
rnn
::
LinkMemories
(
step_scopes
,
arg_
->
memories
,
step_id
,
-
1
,
step_scopes
,
arg_
->
memories
,
step_id
,
-
1
,
false
/*infer_shape_mode*/
);
false
/*infer_shape_mode*/
);
}
}
net
->
GetMutable
<
NetOp
>
()
->
Run
(
*
step_scopes
[
step_id
],
dev_ctx
);
net
->
GetMutable
<
NetOp
>
()
->
Run
(
*
step_scopes
[
step_id
],
dev_ctx
);
}
}
rnn
::
ConcatOutputs
(
rnn
::
ConcatOutputs
(
step_scopes
,
arg_
->
outlinks
,
seq_len_
,
step_scopes
,
arg_
->
outlinks
,
seq_len_
,
false
/*infer_shape_mode*/
);
false
/*infer_shape_mode*/
);
}
}
void
RecurrentAlgorithm
::
CreateScopes
(
const
Scope
&
scope
)
const
{
void
RecurrentAlgorithm
::
CreateScopes
(
const
Scope
&
scope
)
const
{
...
@@ -245,8 +230,7 @@ void RecurrentAlgorithm::InitMemories(Scope* step_scope,
...
@@ -245,8 +230,7 @@ void RecurrentAlgorithm::InitMemories(Scope* step_scope,
for
(
auto
&
attr
:
arg_
->
memories
)
{
for
(
auto
&
attr
:
arg_
->
memories
)
{
Tensor
*
pre_mem
=
step_scope
->
NewVar
(
attr
.
pre_var
)
->
GetMutable
<
Tensor
>
();
Tensor
*
pre_mem
=
step_scope
->
NewVar
(
attr
.
pre_var
)
->
GetMutable
<
Tensor
>
();
PADDLE_ENFORCE
(
step_scope
->
FindVar
(
attr
.
boot_var
)
!=
nullptr
,
PADDLE_ENFORCE
(
step_scope
->
FindVar
(
attr
.
boot_var
)
!=
nullptr
,
"memory [%s]'s boot variable [%s] not exists"
,
"memory [%s]'s boot variable [%s] not exists"
,
attr
.
var
,
attr
.
var
,
attr
.
boot_var
);
attr
.
boot_var
);
Tensor
*
boot_mem
=
step_scope
->
FindVar
(
attr
.
boot_var
)
->
GetMutable
<
Tensor
>
();
Tensor
*
boot_mem
=
step_scope
->
FindVar
(
attr
.
boot_var
)
->
GetMutable
<
Tensor
>
();
if
(
infer_shape_mode
)
{
if
(
infer_shape_mode
)
{
...
@@ -257,25 +241,15 @@ void RecurrentAlgorithm::InitMemories(Scope* step_scope,
...
@@ -257,25 +241,15 @@ void RecurrentAlgorithm::InitMemories(Scope* step_scope,
}
}
}
}
const
rnn
::
ArgumentName
RecurrentOp
::
kArgName
{
"step_net"
,
const
rnn
::
ArgumentName
RecurrentOp
::
kArgName
{
"step_scopes"
,
"step_net"
,
"step_scopes"
,
"inlinks"
,
"inlinks"
,
"outlinks"
,
"inlink_alias"
,
"outlink_alias"
,
"outlinks"
,
"memories"
,
"pre_memories"
,
"boot_memories"
};
"inlink_alias"
,
"outlink_alias"
,
const
rnn
::
ArgumentName
RecurrentGradientOp
::
kArgName
{
"memories"
,
"step_net"
,
"step_scopes"
,
"outlink@grad"
,
"pre_memories"
,
"inlink@grad"
,
"inlink_alias"
,
"outlink_alias"
,
"boot_memories"
};
"memories"
,
"pre_memories"
,
"boot_memories@grad"
};
const
rnn
::
ArgumentName
RecurrentGradientOp
::
kArgName
{
"step_net"
,
"step_scopes"
,
"outlink@grad"
,
"inlink@grad"
,
"inlink_alias"
,
"outlink_alias"
,
"memories"
,
"pre_memories"
,
"boot_memories@grad"
};
void
RecurrentOp
::
Init
()
{
void
RecurrentOp
::
Init
()
{
OperatorBase
::
Init
();
OperatorBase
::
Init
();
...
@@ -285,7 +259,7 @@ void RecurrentOp::Init() {
...
@@ -285,7 +259,7 @@ void RecurrentOp::Init() {
}
}
class
RecurrentAlgorithmProtoAndCheckerMaker
:
public
OpProtoAndCheckerMaker
{
class
RecurrentAlgorithmProtoAndCheckerMaker
:
public
OpProtoAndCheckerMaker
{
public:
public:
RecurrentAlgorithmProtoAndCheckerMaker
(
OpProto
*
proto
,
RecurrentAlgorithmProtoAndCheckerMaker
(
OpProto
*
proto
,
OpAttrChecker
*
op_checker
)
OpAttrChecker
*
op_checker
)
:
OpProtoAndCheckerMaker
(
proto
,
op_checker
)
{
:
OpProtoAndCheckerMaker
(
proto
,
op_checker
)
{
...
@@ -316,31 +290,29 @@ public:
...
@@ -316,31 +290,29 @@ public:
void
RecurrentGradientAlgorithm
::
Run
(
void
RecurrentGradientAlgorithm
::
Run
(
const
Scope
&
scope
,
const
platform
::
DeviceContext
&
dev_ctx
)
const
{
const
Scope
&
scope
,
const
platform
::
DeviceContext
&
dev_ctx
)
const
{
auto
step_scopes
=
GetStepScopes
(
scope
);
auto
step_scopes
=
GetStepScopes
(
scope
);
rnn
::
SegmentInputs
(
rnn
::
SegmentInputs
(
step_scopes
,
arg_
->
inlinks
,
seq_len_
,
step_scopes
,
arg_
->
inlinks
,
seq_len_
,
false
/*infer_shape_mode*/
);
false
/*infer_shape_mode*/
);
Variable
*
net
=
scope
.
FindVar
(
arg_
->
step_net
);
Variable
*
net
=
scope
.
FindVar
(
arg_
->
step_net
);
PADDLE_ENFORCE
(
net
!=
nullptr
,
"failed to get step net"
);
PADDLE_ENFORCE
(
net
!=
nullptr
,
"failed to get step net"
);
for
(
int
step_id
=
seq_len_
-
1
;
step_id
>=
0
;
--
step_id
)
{
for
(
int
step_id
=
seq_len_
-
1
;
step_id
>=
0
;
--
step_id
)
{
if
(
static_cast
<
size_t
>
(
step_id
)
!=
seq_len_
-
1
)
{
if
(
static_cast
<
size_t
>
(
step_id
)
!=
seq_len_
-
1
)
{
rnn
::
LinkMemories
(
rnn
::
LinkMemories
(
step_scopes
,
arg_
->
memories
,
step_id
,
1
,
step_scopes
,
arg_
->
memories
,
step_id
,
1
,
false
/*infer_shape_mode*/
);
false
/*infer_shape_mode*/
);
}
}
net
->
GetMutable
<
NetOp
>
()
->
Run
(
*
step_scopes
[
step_id
],
dev_ctx
);
net
->
GetMutable
<
NetOp
>
()
->
Run
(
*
step_scopes
[
step_id
],
dev_ctx
);
}
}
LinkBootMemoryGradients
(
step_scopes
[
0
],
false
);
LinkBootMemoryGradients
(
step_scopes
[
0
],
false
);
rnn
::
ConcatOutputs
(
rnn
::
ConcatOutputs
(
step_scopes
,
arg_
->
outlinks
,
seq_len_
,
step_scopes
,
arg_
->
outlinks
,
seq_len_
,
false
/*infer_shape_mode*/
);
false
/*infer_shape_mode*/
);
}
}
void
RecurrentGradientAlgorithm
::
LinkBootMemoryGradients
(
void
RecurrentGradientAlgorithm
::
LinkBootMemoryGradients
(
Scope
*
step_scope
,
bool
infer_shape_mode
)
const
{
Scope
*
step_scope
,
bool
infer_shape_mode
)
const
{
for
(
auto
&
attr
:
arg_
->
memories
)
{
for
(
auto
&
attr
:
arg_
->
memories
)
{
PADDLE_ENFORCE
(
step_scope
->
FindVar
(
attr
.
var
)
!=
nullptr
,
PADDLE_ENFORCE
(
step_scope
->
FindVar
(
attr
.
var
)
!=
nullptr
,
"memory variable [%s] does not exists"
,
"memory variable [%s] does not exists"
,
attr
.
var
);
attr
.
var
);
PADDLE_ENFORCE
(
step_scope
->
FindVar
(
attr
.
boot_var
)
!=
nullptr
,
PADDLE_ENFORCE
(
step_scope
->
FindVar
(
attr
.
boot_var
)
!=
nullptr
,
"boot variable [%s] does not exists"
,
"boot variable [%s] does not exists"
,
attr
.
boot_var
);
attr
.
boot_var
);
Tensor
*
mem_grad
=
step_scope
->
NewVar
(
attr
.
var
)
->
GetMutable
<
Tensor
>
();
Tensor
*
mem_grad
=
step_scope
->
NewVar
(
attr
.
var
)
->
GetMutable
<
Tensor
>
();
Tensor
*
boot_mem_grad
=
Tensor
*
boot_mem_grad
=
step_scope
->
NewVar
(
attr
.
boot_var
)
->
GetMutable
<
Tensor
>
();
step_scope
->
NewVar
(
attr
.
boot_var
)
->
GetMutable
<
Tensor
>
();
...
@@ -357,19 +329,19 @@ void RecurrentGradientAlgorithm::InferShape(const Scope& scope) const {
...
@@ -357,19 +329,19 @@ void RecurrentGradientAlgorithm::InferShape(const Scope& scope) const {
->
GetMutable
<
Tensor
>
()
->
GetMutable
<
Tensor
>
()
->
dims
()[
0
];
->
dims
()[
0
];
auto
step_scopes
=
GetStepScopes
(
scope
);
auto
step_scopes
=
GetStepScopes
(
scope
);
rnn
::
SegmentInputs
(
rnn
::
SegmentInputs
(
step_scopes
,
arg_
->
inlinks
,
seq_len_
,
step_scopes
,
arg_
->
inlinks
,
seq_len_
,
true
/*infer_shape_mode*/
);
true
/*infer_shape_mode*/
);
Variable
*
net
=
scope
.
FindVar
(
arg_
->
step_net
);
Variable
*
net
=
scope
.
FindVar
(
arg_
->
step_net
);
PADDLE_ENFORCE
(
net
!=
nullptr
,
"failed to get step net"
);
PADDLE_ENFORCE
(
net
!=
nullptr
,
"failed to get step net"
);
for
(
int
step_id
=
seq_len_
-
1
;
step_id
>=
0
;
--
step_id
)
{
for
(
int
step_id
=
seq_len_
-
1
;
step_id
>=
0
;
--
step_id
)
{
if
(
static_cast
<
size_t
>
(
step_id
)
!=
seq_len_
-
1
)
{
if
(
static_cast
<
size_t
>
(
step_id
)
!=
seq_len_
-
1
)
{
rnn
::
LinkMemories
(
rnn
::
LinkMemories
(
step_scopes
,
arg_
->
memories
,
step_id
,
1
,
step_scopes
,
arg_
->
memories
,
step_id
,
1
,
true
/*infer_shape_mode*/
);
true
/*infer_shape_mode*/
);
}
}
net
->
GetMutable
<
NetOp
>
()
->
InferShape
(
*
step_scopes
[
step_id
]);
net
->
GetMutable
<
NetOp
>
()
->
InferShape
(
*
step_scopes
[
step_id
]);
}
}
rnn
::
ConcatOutputs
(
rnn
::
ConcatOutputs
(
step_scopes
,
arg_
->
outlinks
,
seq_len_
,
step_scopes
,
arg_
->
outlinks
,
seq_len_
,
true
/*infer_shape_mode*/
);
true
/*infer_shape_mode*/
);
LinkBootMemoryGradients
(
step_scopes
[
0
],
true
/*infer_shape_mode*/
);
LinkBootMemoryGradients
(
step_scopes
[
0
],
true
/*infer_shape_mode*/
);
}
}
...
@@ -383,6 +355,5 @@ void RecurrentGradientOp::Init() {
...
@@ -383,6 +355,5 @@ void RecurrentGradientOp::Init() {
}
// namespace operators
}
// namespace operators
}
// namespace paddle
}
// namespace paddle
REGISTER_OP
(
recurrent_op
,
REGISTER_OP
(
recurrent_op
,
paddle
::
operators
::
RecurrentOp
,
paddle
::
operators
::
RecurrentOp
,
paddle
::
operators
::
RecurrentAlgorithmProtoAndCheckerMaker
);
paddle
::
operators
::
RecurrentAlgorithmProtoAndCheckerMaker
);
paddle/operators/recurrent_op.h
浏览文件 @
9620df44
...
@@ -69,23 +69,19 @@ struct ArgumentName {
...
@@ -69,23 +69,19 @@ struct ArgumentName {
* Prepare inputs for each step net.
* Prepare inputs for each step net.
*/
*/
void
SegmentInputs
(
const
std
::
vector
<
framework
::
Scope
*>&
step_scopes
,
void
SegmentInputs
(
const
std
::
vector
<
framework
::
Scope
*>&
step_scopes
,
const
std
::
vector
<
Link
>&
inlinks
,
const
std
::
vector
<
Link
>&
inlinks
,
const
size_t
seq_len
,
const
size_t
seq_len
,
bool
infer_shape_mode
);
bool
infer_shape_mode
);
/**
/**
* Process outputs of step nets and merge to variables.
* Process outputs of step nets and merge to variables.
*/
*/
void
ConcatOutputs
(
const
std
::
vector
<
framework
::
Scope
*>&
step_scopes
,
void
ConcatOutputs
(
const
std
::
vector
<
framework
::
Scope
*>&
step_scopes
,
const
std
::
vector
<
Link
>&
outlinks
,
const
std
::
vector
<
Link
>&
outlinks
,
const
size_t
seq_len
,
const
size_t
seq_len
,
bool
infer_shape_mode
);
bool
infer_shape_mode
);
void
LinkMemories
(
const
std
::
vector
<
framework
::
Scope
*>&
step_scopes
,
void
LinkMemories
(
const
std
::
vector
<
framework
::
Scope
*>&
step_scopes
,
const
std
::
vector
<
MemoryAttr
>&
memories
,
const
std
::
vector
<
MemoryAttr
>&
memories
,
const
size_t
step_id
,
const
size_t
step_id
,
const
int
offset
,
bool
infer_shape_mode
);
const
int
offset
,
bool
infer_shape_mode
);
void
InitArgument
(
const
ArgumentName
&
name
,
Argument
*
arg
);
void
InitArgument
(
const
ArgumentName
&
name
,
Argument
*
arg
);
...
@@ -100,7 +96,7 @@ void InitArgument(const ArgumentName& name, Argument* arg);
...
@@ -100,7 +96,7 @@ void InitArgument(const ArgumentName& name, Argument* arg);
// Refer to: https://arxiv.org/pdf/1502.02367.pdf
// Refer to: https://arxiv.org/pdf/1502.02367.pdf
class
RecurrentAlgorithm
{
class
RecurrentAlgorithm
{
public:
public:
void
Run
(
const
framework
::
Scope
&
scope
,
void
Run
(
const
framework
::
Scope
&
scope
,
const
platform
::
DeviceContext
&
dev_ctx
)
const
;
const
platform
::
DeviceContext
&
dev_ctx
)
const
;
...
@@ -111,7 +107,7 @@ public:
...
@@ -111,7 +107,7 @@ public:
*/
*/
void
InferShape
(
const
framework
::
Scope
&
scope
)
const
;
void
InferShape
(
const
framework
::
Scope
&
scope
)
const
;
protected:
protected:
/*
/*
* The step scopes will be stored in the father scope as a variable.
* The step scopes will be stored in the father scope as a variable.
*
*
...
@@ -128,7 +124,7 @@ protected:
...
@@ -128,7 +124,7 @@ protected:
void
InitMemories
(
framework
::
Scope
*
step_scopes
,
bool
infer_shape_mode
)
const
;
void
InitMemories
(
framework
::
Scope
*
step_scopes
,
bool
infer_shape_mode
)
const
;
private:
private:
std
::
unique_ptr
<
rnn
::
Argument
>
arg_
;
std
::
unique_ptr
<
rnn
::
Argument
>
arg_
;
mutable
size_t
seq_len_
;
mutable
size_t
seq_len_
;
};
};
...
@@ -144,7 +140,7 @@ class RecurrentGradientAlgorithm {
...
@@ -144,7 +140,7 @@ class RecurrentGradientAlgorithm {
* lot, and the latter is a wrapper acts like an dapter for it to make RNN an
* lot, and the latter is a wrapper acts like an dapter for it to make RNN an
* operator.
* operator.
*/
*/
public:
public:
void
Init
(
std
::
unique_ptr
<
rnn
::
Argument
>
arg
)
{
arg_
=
std
::
move
(
arg
);
}
void
Init
(
std
::
unique_ptr
<
rnn
::
Argument
>
arg
)
{
arg_
=
std
::
move
(
arg
);
}
void
Run
(
const
framework
::
Scope
&
scope
,
void
Run
(
const
framework
::
Scope
&
scope
,
...
@@ -158,20 +154,20 @@ public:
...
@@ -158,20 +154,20 @@ public:
*/
*/
void
InferShape
(
const
framework
::
Scope
&
scope
)
const
;
void
InferShape
(
const
framework
::
Scope
&
scope
)
const
;
protected:
protected:
inline
const
std
::
vector
<
framework
::
Scope
*>&
GetStepScopes
(
inline
const
std
::
vector
<
framework
::
Scope
*>&
GetStepScopes
(
const
framework
::
Scope
&
scope
)
const
{
const
framework
::
Scope
&
scope
)
const
{
return
*
scope
.
FindVar
(
arg_
->
step_scopes
)
return
*
scope
.
FindVar
(
arg_
->
step_scopes
)
->
GetMutable
<
std
::
vector
<
framework
::
Scope
*>>
();
->
GetMutable
<
std
::
vector
<
framework
::
Scope
*>>
();
}
}
private:
private:
std
::
unique_ptr
<
rnn
::
Argument
>
arg_
;
std
::
unique_ptr
<
rnn
::
Argument
>
arg_
;
mutable
size_t
seq_len_
;
mutable
size_t
seq_len_
;
};
};
class
RecurrentOp
final
:
public
framework
::
OperatorBase
{
class
RecurrentOp
final
:
public
framework
::
OperatorBase
{
public:
public:
void
Init
()
override
;
void
Init
()
override
;
/**
/**
...
@@ -188,12 +184,12 @@ public:
...
@@ -188,12 +184,12 @@ public:
static
const
rnn
::
ArgumentName
kArgName
;
static
const
rnn
::
ArgumentName
kArgName
;
private:
private:
RecurrentAlgorithm
alg_
;
RecurrentAlgorithm
alg_
;
};
};
class
RecurrentGradientOp
final
:
public
framework
::
OperatorBase
{
class
RecurrentGradientOp
final
:
public
framework
::
OperatorBase
{
public:
public:
void
Init
()
override
;
void
Init
()
override
;
/**
/**
...
@@ -210,7 +206,7 @@ public:
...
@@ -210,7 +206,7 @@ public:
static
const
rnn
::
ArgumentName
kArgName
;
static
const
rnn
::
ArgumentName
kArgName
;
private:
private:
RecurrentGradientAlgorithm
alg_
;
RecurrentGradientAlgorithm
alg_
;
};
};
...
...
paddle/operators/recurrent_op_test.cc
浏览文件 @
9620df44
...
@@ -29,7 +29,7 @@ using framework::make_ddim;
...
@@ -29,7 +29,7 @@ using framework::make_ddim;
using
framework
::
DDim
;
using
framework
::
DDim
;
class
RecurrentOpTest
:
public
::
testing
::
Test
{
class
RecurrentOpTest
:
public
::
testing
::
Test
{
protected:
protected:
virtual
void
SetUp
()
override
{
virtual
void
SetUp
()
override
{
CreateGlobalVariables
();
CreateGlobalVariables
();
CreateStepNet
();
CreateStepNet
();
...
@@ -174,7 +174,7 @@ TEST_F(RecurrentOpTest, Run) {
...
@@ -174,7 +174,7 @@ TEST_F(RecurrentOpTest, Run) {
}
}
class
RecurrentGradientAlgorithmTest
:
public
::
testing
::
Test
{
class
RecurrentGradientAlgorithmTest
:
public
::
testing
::
Test
{
protected:
protected:
virtual
void
SetUp
()
override
{
virtual
void
SetUp
()
override
{
CreateGlobalVariables
();
CreateGlobalVariables
();
CreateStepScopes
();
CreateStepScopes
();
...
@@ -277,13 +277,11 @@ protected:
...
@@ -277,13 +277,11 @@ protected:
LOG
(
INFO
)
<<
"create variable step_net"
;
LOG
(
INFO
)
<<
"create variable step_net"
;
Variable
*
var
=
scope_
.
NewVar
(
"step_net"
);
Variable
*
var
=
scope_
.
NewVar
(
"step_net"
);
auto
net
=
var
->
GetMutable
<
NetOp
>
();
auto
net
=
var
->
GetMutable
<
NetOp
>
();
net
->
AddOp
(
OpRegistry
::
CreateOp
(
"mul"
,
net
->
AddOp
(
OpRegistry
::
CreateOp
(
"mul"
,
{
"rnn/h_pre"
,
"rnn/w"
,
"rnn/s_grad"
},
{
"rnn/h_pre"
,
"rnn/w"
,
"rnn/s_grad"
},
{
"rnn/h_pre_grad"
,
"rnn/w_grad"
},
{}));
{
"rnn/h_pre_grad"
,
"rnn/w_grad"
},
{}));
net
->
AddOp
(
OpRegistry
::
CreateOp
(
net
->
AddOp
(
OpRegistry
::
CreateOp
(
"add_two"
,
{
"rnn/h_grad"
},
"add_two"
,
{
"rnn/h_grad"
},
{
"rnn/x_grad"
,
"rnn/s_grad"
},
{}));
{
"rnn/x_grad"
,
"rnn/s_grad"
},
{}));
net
->
CompleteAddOp
();
net
->
CompleteAddOp
();
}
}
...
@@ -297,9 +295,7 @@ protected:
...
@@ -297,9 +295,7 @@ protected:
inlink
.
internal
=
"rnn/x"
;
inlink
.
internal
=
"rnn/x"
;
auto
step_scopes
=
auto
step_scopes
=
scope_
.
FindVar
(
"step_scopes"
)
->
GetMutable
<
std
::
vector
<
Scope
*>>
();
scope_
.
FindVar
(
"step_scopes"
)
->
GetMutable
<
std
::
vector
<
Scope
*>>
();
rnn
::
SegmentInputs
(
*
step_scopes
,
rnn
::
SegmentInputs
(
*
step_scopes
,
std
::
vector
<
rnn
::
Link
>
{
inlink
},
10
,
std
::
vector
<
rnn
::
Link
>
{
inlink
},
10
,
true
/*infer_shape_mode*/
);
true
/*infer_shape_mode*/
);
}
}
...
@@ -314,8 +310,8 @@ protected:
...
@@ -314,8 +310,8 @@ protected:
auto
step_scopes
=
auto
step_scopes
=
scope_
.
FindVar
(
"step_scopes"
)
->
GetMutable
<
std
::
vector
<
Scope
*>>
();
scope_
.
FindVar
(
"step_scopes"
)
->
GetMutable
<
std
::
vector
<
Scope
*>>
();
for
(
int
i
=
1
;
i
<
10
;
++
i
)
{
for
(
int
i
=
1
;
i
<
10
;
++
i
)
{
rnn
::
LinkMemories
(
rnn
::
LinkMemories
(
*
step_scopes
,
memories
,
i
,
-
1
,
*
step_scopes
,
memories
,
i
,
-
1
,
true
/*infer_shape_mode*/
);
true
/*infer_shape_mode*/
);
}
}
}
}
...
...
paddle/operators/rowwise_add_op.cc
浏览文件 @
9620df44
...
@@ -17,7 +17,7 @@ namespace paddle {
...
@@ -17,7 +17,7 @@ namespace paddle {
namespace
operators
{
namespace
operators
{
class
RowWiseAddOp
:
public
OperatorWithKernel
{
class
RowWiseAddOp
:
public
OperatorWithKernel
{
protected:
protected:
void
InferShape
(
const
InferShapeContext
&
ctx
)
const
override
{
void
InferShape
(
const
InferShapeContext
&
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
.
InputSize
()
==
2UL
,
PADDLE_ENFORCE
(
ctx
.
InputSize
()
==
2UL
,
"Two inputs is needed by rowwise add"
);
"Two inputs is needed by rowwise add"
);
...
@@ -33,7 +33,7 @@ protected:
...
@@ -33,7 +33,7 @@ protected:
};
};
class
RowWiseAddOpMaker
:
public
OpProtoAndCheckerMaker
{
class
RowWiseAddOpMaker
:
public
OpProtoAndCheckerMaker
{
public:
public:
RowWiseAddOpMaker
(
OpProto
*
proto
,
OpAttrChecker
*
op_checker
)
RowWiseAddOpMaker
(
OpProto
*
proto
,
OpAttrChecker
*
op_checker
)
:
OpProtoAndCheckerMaker
(
proto
,
op_checker
)
{
:
OpProtoAndCheckerMaker
(
proto
,
op_checker
)
{
AddInput
(
"X"
,
"The left input of row-wise add op, must be matrix"
);
AddInput
(
"X"
,
"The left input of row-wise add op, must be matrix"
);
...
...
paddle/operators/rowwise_add_op.h
浏览文件 @
9620df44
...
@@ -20,7 +20,7 @@ namespace operators {
...
@@ -20,7 +20,7 @@ namespace operators {
template
<
typename
Place
,
typename
T
>
template
<
typename
Place
,
typename
T
>
class
RowWiseAddKernel
:
public
OpKernel
{
class
RowWiseAddKernel
:
public
OpKernel
{
public:
public:
void
Compute
(
const
ExecutionContext
&
context
)
const
override
{
void
Compute
(
const
ExecutionContext
&
context
)
const
override
{
auto
out
=
context
.
Output
<
Tensor
>
(
0
);
auto
out
=
context
.
Output
<
Tensor
>
(
0
);
out
->
mutable_data
<
T
>
(
context
.
GetPlace
());
out
->
mutable_data
<
T
>
(
context
.
GetPlace
());
...
...
paddle/operators/sgd_op.cc
浏览文件 @
9620df44
...
@@ -18,7 +18,7 @@ namespace paddle {
...
@@ -18,7 +18,7 @@ namespace paddle {
namespace
operators
{
namespace
operators
{
class
SGDOp
:
public
OperatorWithKernel
{
class
SGDOp
:
public
OperatorWithKernel
{
protected:
protected:
void
InferShape
(
const
InferShapeContext
&
ctx
)
const
override
{
void
InferShape
(
const
InferShapeContext
&
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
.
InputSize
()
==
2
,
"Input size of SGDOp must be two"
);
PADDLE_ENFORCE
(
ctx
.
InputSize
()
==
2
,
"Input size of SGDOp must be two"
);
PADDLE_ENFORCE
(
ctx
.
OutputSize
()
==
1
,
"Output size of SGDOp must be one"
);
PADDLE_ENFORCE
(
ctx
.
OutputSize
()
==
1
,
"Output size of SGDOp must be one"
);
...
@@ -32,7 +32,7 @@ protected:
...
@@ -32,7 +32,7 @@ protected:
};
};
class
SGDOpMaker
:
public
OpProtoAndCheckerMaker
{
class
SGDOpMaker
:
public
OpProtoAndCheckerMaker
{
public:
public:
SGDOpMaker
(
OpProto
*
proto
,
OpAttrChecker
*
op_checker
)
SGDOpMaker
(
OpProto
*
proto
,
OpAttrChecker
*
op_checker
)
:
OpProtoAndCheckerMaker
(
proto
,
op_checker
)
{
:
OpProtoAndCheckerMaker
(
proto
,
op_checker
)
{
AddInput
(
"param"
,
"input parameter"
);
AddInput
(
"param"
,
"input parameter"
);
...
...
paddle/operators/sgd_op.h
浏览文件 @
9620df44
...
@@ -20,7 +20,7 @@ namespace operators {
...
@@ -20,7 +20,7 @@ namespace operators {
template
<
typename
Place
,
typename
T
>
template
<
typename
Place
,
typename
T
>
class
SGDOpKernel
:
public
OpKernel
{
class
SGDOpKernel
:
public
OpKernel
{
public:
public:
void
Compute
(
const
ExecutionContext
&
ctx
)
const
override
{
void
Compute
(
const
ExecutionContext
&
ctx
)
const
override
{
auto
param
=
ctx
.
Input
<
Tensor
>
(
"param"
);
auto
param
=
ctx
.
Input
<
Tensor
>
(
"param"
);
auto
grad
=
ctx
.
Input
<
Tensor
>
(
"grad"
);
auto
grad
=
ctx
.
Input
<
Tensor
>
(
"grad"
);
...
...
paddle/operators/sigmoid_op.cc
浏览文件 @
9620df44
...
@@ -17,7 +17,7 @@ namespace paddle {
...
@@ -17,7 +17,7 @@ namespace paddle {
namespace
operators
{
namespace
operators
{
class
SigmoidOp
:
public
OperatorWithKernel
{
class
SigmoidOp
:
public
OperatorWithKernel
{
protected:
protected:
void
InferShape
(
const
InferShapeContext
&
ctx
)
const
override
{
void
InferShape
(
const
InferShapeContext
&
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
.
InputSize
()
==
1
,
"Sigmoid Op only have one input"
);
PADDLE_ENFORCE
(
ctx
.
InputSize
()
==
1
,
"Sigmoid Op only have one input"
);
PADDLE_ENFORCE
(
ctx
.
OutputSize
()
==
1
,
"Sigmoid Op only have one output"
);
PADDLE_ENFORCE
(
ctx
.
OutputSize
()
==
1
,
"Sigmoid Op only have one output"
);
...
@@ -26,7 +26,7 @@ protected:
...
@@ -26,7 +26,7 @@ protected:
};
};
class
SigmoidOpMaker
:
public
OpProtoAndCheckerMaker
{
class
SigmoidOpMaker
:
public
OpProtoAndCheckerMaker
{
public:
public:
SigmoidOpMaker
(
OpProto
*
proto
,
OpAttrChecker
*
op_checker
)
SigmoidOpMaker
(
OpProto
*
proto
,
OpAttrChecker
*
op_checker
)
:
OpProtoAndCheckerMaker
(
proto
,
op_checker
)
{
:
OpProtoAndCheckerMaker
(
proto
,
op_checker
)
{
AddInput
(
"X"
,
"sigmoid input"
);
AddInput
(
"X"
,
"sigmoid input"
);
...
@@ -36,7 +36,7 @@ public:
...
@@ -36,7 +36,7 @@ public:
};
};
class
SigmoidOpGrad
:
public
OperatorWithKernel
{
class
SigmoidOpGrad
:
public
OperatorWithKernel
{
protected:
protected:
void
InferShape
(
const
InferShapeContext
&
ctx
)
const
override
{}
void
InferShape
(
const
InferShapeContext
&
ctx
)
const
override
{}
std
::
string
DebugString
()
const
override
{
std
::
string
DebugString
()
const
override
{
LOG
(
INFO
)
<<
"SigmoidGrad"
;
LOG
(
INFO
)
<<
"SigmoidGrad"
;
...
...
paddle/operators/sigmoid_op.h
浏览文件 @
9620df44
...
@@ -21,7 +21,7 @@ namespace operators {
...
@@ -21,7 +21,7 @@ namespace operators {
template
<
typename
Place
,
typename
T
>
template
<
typename
Place
,
typename
T
>
class
SigmoidKernel
:
public
OpKernel
{
class
SigmoidKernel
:
public
OpKernel
{
public:
public:
void
Compute
(
const
ExecutionContext
&
context
)
const
override
{
void
Compute
(
const
ExecutionContext
&
context
)
const
override
{
auto
input
=
context
.
Input
<
Tensor
>
(
0
);
auto
input
=
context
.
Input
<
Tensor
>
(
0
);
auto
output
=
context
.
Output
<
Tensor
>
(
0
);
auto
output
=
context
.
Output
<
Tensor
>
(
0
);
...
...
paddle/operators/softmax_op.cc
浏览文件 @
9620df44
...
@@ -18,7 +18,7 @@ namespace paddle {
...
@@ -18,7 +18,7 @@ namespace paddle {
namespace
operators
{
namespace
operators
{
class
SoftmaxOp
:
public
OperatorWithKernel
{
class
SoftmaxOp
:
public
OperatorWithKernel
{
protected:
protected:
void
InferShape
(
const
InferShapeContext
&
ctx
)
const
override
{
void
InferShape
(
const
InferShapeContext
&
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
.
InputSize
()
==
1UL
,
PADDLE_ENFORCE
(
ctx
.
InputSize
()
==
1UL
,
"Only one input is need for softmax"
);
"Only one input is need for softmax"
);
...
@@ -31,7 +31,7 @@ protected:
...
@@ -31,7 +31,7 @@ protected:
};
};
class
SoftmaxOpMaker
:
public
OpProtoAndCheckerMaker
{
class
SoftmaxOpMaker
:
public
OpProtoAndCheckerMaker
{
public:
public:
SoftmaxOpMaker
(
OpProto
*
proto
,
OpAttrChecker
*
op_checker
)
SoftmaxOpMaker
(
OpProto
*
proto
,
OpAttrChecker
*
op_checker
)
:
OpProtoAndCheckerMaker
(
proto
,
op_checker
)
{
:
OpProtoAndCheckerMaker
(
proto
,
op_checker
)
{
AddInput
(
"X"
,
"input of softmax"
);
AddInput
(
"X"
,
"input of softmax"
);
...
@@ -41,7 +41,7 @@ public:
...
@@ -41,7 +41,7 @@ public:
};
};
class
SoftmaxOpGrad
:
public
OperatorWithKernel
{
class
SoftmaxOpGrad
:
public
OperatorWithKernel
{
protected:
protected:
void
InferShape
(
const
InferShapeContext
&
ctx
)
const
override
{
void
InferShape
(
const
InferShapeContext
&
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
.
InputSize
()
==
3UL
,
PADDLE_ENFORCE
(
ctx
.
InputSize
()
==
3UL
,
"Input of SoftmaxOpGrad should be 3, X, Y, YG"
);
"Input of SoftmaxOpGrad should be 3, X, Y, YG"
);
...
...
paddle/operators/softmax_op.h
浏览文件 @
9620df44
...
@@ -24,7 +24,7 @@ namespace operators {
...
@@ -24,7 +24,7 @@ namespace operators {
template
<
typename
Place
,
typename
T
>
template
<
typename
Place
,
typename
T
>
class
SoftmaxKernel
:
public
OpKernel
{
class
SoftmaxKernel
:
public
OpKernel
{
public:
public:
void
Compute
(
const
ExecutionContext
&
context
)
const
override
{
void
Compute
(
const
ExecutionContext
&
context
)
const
override
{
auto
input
=
context
.
Input
<
Tensor
>
(
"X"
);
auto
input
=
context
.
Input
<
Tensor
>
(
"X"
);
auto
output
=
context
.
Output
<
Tensor
>
(
"Y"
);
auto
output
=
context
.
Output
<
Tensor
>
(
"Y"
);
...
@@ -63,7 +63,7 @@ public:
...
@@ -63,7 +63,7 @@ public:
template
<
typename
Place
,
typename
T
>
template
<
typename
Place
,
typename
T
>
class
SoftmaxGradKernel
:
public
OpKernel
{
class
SoftmaxGradKernel
:
public
OpKernel
{
public:
public:
void
Compute
(
const
ExecutionContext
&
context
)
const
override
{
void
Compute
(
const
ExecutionContext
&
context
)
const
override
{
std
::
shared_ptr
<
Tensor
>
scale_
=
std
::
make_shared
<
Tensor
>
();
std
::
shared_ptr
<
Tensor
>
scale_
=
std
::
make_shared
<
Tensor
>
();
...
...
paddle/operators/type_alias.h
浏览文件 @
9620df44
...
@@ -26,21 +26,16 @@ using OperatorBase = framework::OperatorBase;
...
@@ -26,21 +26,16 @@ using OperatorBase = framework::OperatorBase;
using
InferShapeContext
=
framework
::
InferShapeContext
;
using
InferShapeContext
=
framework
::
InferShapeContext
;
using
ExecutionContext
=
framework
::
ExecutionContext
;
using
ExecutionContext
=
framework
::
ExecutionContext
;
using
Variable
=
framework
::
Variable
;
using
Variable
=
framework
::
Variable
;
template
<
typename
T
,
template
<
typename
T
,
int
MajorType
=
Eigen
::
RowMajor
,
int
MajorType
=
Eigen
::
RowMajor
,
typename
IndexType
=
Eigen
::
DenseIndex
>
typename
IndexType
=
Eigen
::
DenseIndex
>
using
EigenScalar
=
framework
::
EigenScalar
<
T
,
MajorType
,
IndexType
>
;
using
EigenScalar
=
framework
::
EigenScalar
<
T
,
MajorType
,
IndexType
>
;
template
<
typename
T
,
template
<
typename
T
,
int
MajorType
=
Eigen
::
RowMajor
,
int
MajorType
=
Eigen
::
RowMajor
,
typename
IndexType
=
Eigen
::
DenseIndex
>
typename
IndexType
=
Eigen
::
DenseIndex
>
using
EigenVector
=
framework
::
EigenVector
<
T
,
MajorType
,
IndexType
>
;
using
EigenVector
=
framework
::
EigenVector
<
T
,
MajorType
,
IndexType
>
;
template
<
typename
T
,
template
<
typename
T
,
int
MajorType
=
Eigen
::
RowMajor
,
int
MajorType
=
Eigen
::
RowMajor
,
typename
IndexType
=
Eigen
::
DenseIndex
>
typename
IndexType
=
Eigen
::
DenseIndex
>
using
EigenMatrix
=
framework
::
EigenMatrix
<
T
,
MajorType
,
IndexType
>
;
using
EigenMatrix
=
framework
::
EigenMatrix
<
T
,
MajorType
,
IndexType
>
;
template
<
typename
T
,
template
<
typename
T
,
size_t
D
,
int
MajorType
=
Eigen
::
RowMajor
,
size_t
D
,
int
MajorType
=
Eigen
::
RowMajor
,
typename
IndexType
=
Eigen
::
DenseIndex
>
typename
IndexType
=
Eigen
::
DenseIndex
>
using
EigenTensor
=
framework
::
EigenTensor
<
T
,
D
,
MajorType
,
IndexType
>
;
using
EigenTensor
=
framework
::
EigenTensor
<
T
,
D
,
MajorType
,
IndexType
>
;
using
Tensor
=
framework
::
Tensor
;
using
Tensor
=
framework
::
Tensor
;
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录