Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
73a8b78a
P
Paddle
项目概览
PaddlePaddle
/
Paddle
1 年多 前同步成功
通知
2302
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
73a8b78a
编写于
10月 16, 2017
作者:
Y
Yu Yang
提交者:
GitHub
10月 16, 2017
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Correct OpWithKernel's infershape (#4847)
They are public now
上级
fafc54d6
变更
53
隐藏空白更改
内联
并排
Showing
53 changed file
with
15 addition
and
90 deletion
+15
-90
paddle/operators/accuracy_op.cc
paddle/operators/accuracy_op.cc
+0
-1
paddle/operators/activation_op.cc
paddle/operators/activation_op.cc
+0
-2
paddle/operators/adadelta_op.cc
paddle/operators/adadelta_op.cc
+0
-1
paddle/operators/adagrad_op.cc
paddle/operators/adagrad_op.cc
+0
-1
paddle/operators/adam_op.cc
paddle/operators/adam_op.cc
+0
-1
paddle/operators/adamax_op.cc
paddle/operators/adamax_op.cc
+0
-1
paddle/operators/clip_op.cc
paddle/operators/clip_op.cc
+0
-2
paddle/operators/concat_op.cc
paddle/operators/concat_op.cc
+0
-2
paddle/operators/conv2d_op.h
paddle/operators/conv2d_op.h
+0
-2
paddle/operators/conv_shift_op.cc
paddle/operators/conv_shift_op.cc
+0
-2
paddle/operators/cos_sim_op.cc
paddle/operators/cos_sim_op.cc
+0
-2
paddle/operators/crop_op.cc
paddle/operators/crop_op.cc
+0
-2
paddle/operators/cross_entropy_op.cc
paddle/operators/cross_entropy_op.cc
+2
-2
paddle/operators/decayed_adagrad_op.cc
paddle/operators/decayed_adagrad_op.cc
+0
-1
paddle/operators/dropout_op.cc
paddle/operators/dropout_op.cc
+0
-2
paddle/operators/elementwise_op.h
paddle/operators/elementwise_op.h
+0
-2
paddle/operators/fill_constant_op.cc
paddle/operators/fill_constant_op.cc
+1
-1
paddle/operators/fill_zeros_like_op.cc
paddle/operators/fill_zeros_like_op.cc
+0
-1
paddle/operators/gather_op.cc
paddle/operators/gather_op.cc
+2
-2
paddle/operators/gaussian_random_op.cc
paddle/operators/gaussian_random_op.cc
+1
-1
paddle/operators/gru_unit_op.cc
paddle/operators/gru_unit_op.cc
+0
-2
paddle/operators/lookup_table_op.cc
paddle/operators/lookup_table_op.cc
+2
-2
paddle/operators/lstm_unit_op.cc
paddle/operators/lstm_unit_op.cc
+0
-2
paddle/operators/margin_rank_loss_op.cc
paddle/operators/margin_rank_loss_op.cc
+0
-2
paddle/operators/mean_op.cc
paddle/operators/mean_op.cc
+0
-2
paddle/operators/minus_op.cc
paddle/operators/minus_op.cc
+0
-1
paddle/operators/modified_huber_loss_op.cc
paddle/operators/modified_huber_loss_op.cc
+0
-2
paddle/operators/mul_op.cc
paddle/operators/mul_op.cc
+0
-2
paddle/operators/multiplex_op.cc
paddle/operators/multiplex_op.cc
+2
-2
paddle/operators/pad_op.cc
paddle/operators/pad_op.cc
+0
-2
paddle/operators/pool_op.h
paddle/operators/pool_op.h
+0
-2
paddle/operators/pool_with_index_op.cc
paddle/operators/pool_with_index_op.cc
+0
-2
paddle/operators/prelu_op.cc
paddle/operators/prelu_op.cc
+0
-2
paddle/operators/rank_loss_op.cc
paddle/operators/rank_loss_op.cc
+0
-2
paddle/operators/reduce_op.cc
paddle/operators/reduce_op.cc
+0
-2
paddle/operators/reshape_op.cc
paddle/operators/reshape_op.cc
+0
-2
paddle/operators/rmsprop_op.cc
paddle/operators/rmsprop_op.cc
+0
-1
paddle/operators/scale_op.cc
paddle/operators/scale_op.cc
+0
-2
paddle/operators/scatter_op.cc
paddle/operators/scatter_op.cc
+2
-2
paddle/operators/sequence_concat_op.cc
paddle/operators/sequence_concat_op.cc
+0
-2
paddle/operators/sequence_pool_op.cc
paddle/operators/sequence_pool_op.cc
+0
-2
paddle/operators/sequence_softmax_op.cc
paddle/operators/sequence_softmax_op.cc
+0
-2
paddle/operators/sgd_op.cc
paddle/operators/sgd_op.cc
+0
-1
paddle/operators/sigmoid_cross_entropy_with_logits_op.cc
paddle/operators/sigmoid_cross_entropy_with_logits_op.cc
+0
-2
paddle/operators/smooth_l1_loss_op.cc
paddle/operators/smooth_l1_loss_op.cc
+0
-2
paddle/operators/softmax_op.cc
paddle/operators/softmax_op.cc
+0
-2
paddle/operators/softmax_with_cross_entropy_op.cc
paddle/operators/softmax_with_cross_entropy_op.cc
+2
-2
paddle/operators/split_op.cc
paddle/operators/split_op.cc
+0
-1
paddle/operators/squared_l2_distance_op.cc
paddle/operators/squared_l2_distance_op.cc
+0
-2
paddle/operators/sum_op.cc
paddle/operators/sum_op.cc
+0
-1
paddle/operators/top_k_op.cc
paddle/operators/top_k_op.cc
+0
-1
paddle/operators/transpose_op.cc
paddle/operators/transpose_op.cc
+0
-2
paddle/operators/uniform_random_op.cc
paddle/operators/uniform_random_op.cc
+1
-1
未找到文件。
paddle/operators/accuracy_op.cc
浏览文件 @
73a8b78a
...
@@ -21,7 +21,6 @@ class AccuracyOp : public framework::OperatorWithKernel {
...
@@ -21,7 +21,6 @@ class AccuracyOp : public framework::OperatorWithKernel {
public:
public:
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Inference"
),
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Inference"
),
"Input(Inference) of AccuracyOp should not be null."
);
"Input(Inference) of AccuracyOp should not be null."
);
...
...
paddle/operators/activation_op.cc
浏览文件 @
73a8b78a
...
@@ -21,7 +21,6 @@ class ActivationOp : public framework::OperatorWithKernel {
...
@@ -21,7 +21,6 @@ class ActivationOp : public framework::OperatorWithKernel {
public:
public:
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
ctx
->
SetOutputDim
(
"Y"
,
ctx
->
GetInputDim
(
"X"
));
ctx
->
SetOutputDim
(
"Y"
,
ctx
->
GetInputDim
(
"X"
));
ctx
->
ShareLoD
(
"X"
,
/*->*/
"Y"
);
ctx
->
ShareLoD
(
"X"
,
/*->*/
"Y"
);
...
@@ -32,7 +31,6 @@ class ActivationOpGrad : public framework::OperatorWithKernel {
...
@@ -32,7 +31,6 @@ class ActivationOpGrad : public framework::OperatorWithKernel {
public:
public:
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
ctx
->
SetOutputDim
(
framework
::
GradVarName
(
"X"
),
ctx
->
GetInputDim
(
"Y"
));
ctx
->
SetOutputDim
(
framework
::
GradVarName
(
"X"
),
ctx
->
GetInputDim
(
"Y"
));
}
}
...
...
paddle/operators/adadelta_op.cc
浏览文件 @
73a8b78a
...
@@ -21,7 +21,6 @@ class AdadeltaOp : public framework::OperatorWithKernel {
...
@@ -21,7 +21,6 @@ class AdadeltaOp : public framework::OperatorWithKernel {
public:
public:
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Param"
),
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Param"
),
"Input(Param) of AdadeltaOp should not be null."
);
"Input(Param) of AdadeltaOp should not be null."
);
...
...
paddle/operators/adagrad_op.cc
浏览文件 @
73a8b78a
...
@@ -21,7 +21,6 @@ class AdagradOp : public framework::OperatorWithKernel {
...
@@ -21,7 +21,6 @@ class AdagradOp : public framework::OperatorWithKernel {
public:
public:
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Param"
),
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Param"
),
"Input(Param) of AdagradOp should not be null."
);
"Input(Param) of AdagradOp should not be null."
);
...
...
paddle/operators/adam_op.cc
浏览文件 @
73a8b78a
...
@@ -21,7 +21,6 @@ class AdamOp : public framework::OperatorWithKernel {
...
@@ -21,7 +21,6 @@ class AdamOp : public framework::OperatorWithKernel {
public:
public:
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Param"
),
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Param"
),
"Input(Param) of AdamOp should not be null."
);
"Input(Param) of AdamOp should not be null."
);
...
...
paddle/operators/adamax_op.cc
浏览文件 @
73a8b78a
...
@@ -21,7 +21,6 @@ class AdamaxOp : public framework::OperatorWithKernel {
...
@@ -21,7 +21,6 @@ class AdamaxOp : public framework::OperatorWithKernel {
public:
public:
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Param"
),
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Param"
),
"Input(Param) of AdamaxOp should not be null."
);
"Input(Param) of AdamaxOp should not be null."
);
...
...
paddle/operators/clip_op.cc
浏览文件 @
73a8b78a
...
@@ -21,7 +21,6 @@ class ClipOp : public framework::OperatorWithKernel {
...
@@ -21,7 +21,6 @@ class ClipOp : public framework::OperatorWithKernel {
public:
public:
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"X"
),
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"X"
),
"Input(X) of ClipOp should not be null."
);
"Input(X) of ClipOp should not be null."
);
...
@@ -60,7 +59,6 @@ class ClipOpGrad : public framework::OperatorWithKernel {
...
@@ -60,7 +59,6 @@ class ClipOpGrad : public framework::OperatorWithKernel {
public:
public:
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"X"
),
"Input(X) should not be null"
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"X"
),
"Input(X) should not be null"
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
framework
::
GradVarName
(
"Out"
)),
PADDLE_ENFORCE
(
ctx
->
HasInput
(
framework
::
GradVarName
(
"Out"
)),
...
...
paddle/operators/concat_op.cc
浏览文件 @
73a8b78a
...
@@ -23,7 +23,6 @@ class ConcatOp : public framework::OperatorWithKernel {
...
@@ -23,7 +23,6 @@ class ConcatOp : public framework::OperatorWithKernel {
public:
public:
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE_GE
(
ctx
->
Inputs
(
"X"
).
size
(),
1UL
,
PADDLE_ENFORCE_GE
(
ctx
->
Inputs
(
"X"
).
size
(),
1UL
,
"Inputs(X) of ConcatOp should be empty."
)
"Inputs(X) of ConcatOp should be empty."
)
...
@@ -82,7 +81,6 @@ class ConcatOpGrad : public framework::OperatorWithKernel {
...
@@ -82,7 +81,6 @@ class ConcatOpGrad : public framework::OperatorWithKernel {
const
framework
::
AttributeMap
&
attrs
)
const
framework
::
AttributeMap
&
attrs
)
:
OperatorWithKernel
(
type
,
inputs
,
outputs
,
attrs
)
{}
:
OperatorWithKernel
(
type
,
inputs
,
outputs
,
attrs
)
{}
protected:
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
ctx
->
SetOutputsDim
(
framework
::
GradVarName
(
"X"
),
ctx
->
GetInputsDim
(
"X"
));
ctx
->
SetOutputsDim
(
framework
::
GradVarName
(
"X"
),
ctx
->
GetInputsDim
(
"X"
));
}
}
...
...
paddle/operators/conv2d_op.h
浏览文件 @
73a8b78a
...
@@ -44,7 +44,6 @@ class Conv2DOp : public framework::OperatorWithKernel {
...
@@ -44,7 +44,6 @@ class Conv2DOp : public framework::OperatorWithKernel {
public:
public:
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
;
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
;
};
};
...
@@ -52,7 +51,6 @@ class Conv2DOpGrad : public framework::OperatorWithKernel {
...
@@ -52,7 +51,6 @@ class Conv2DOpGrad : public framework::OperatorWithKernel {
public:
public:
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
;
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
;
};
};
...
...
paddle/operators/conv_shift_op.cc
浏览文件 @
73a8b78a
...
@@ -27,7 +27,6 @@ class ConvShiftOp : public framework::OperatorWithKernel {
...
@@ -27,7 +27,6 @@ class ConvShiftOp : public framework::OperatorWithKernel {
public:
public:
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"X"
),
"Input(X) should be not null."
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"X"
),
"Input(X) should be not null."
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Y"
),
"Input(Y) should be not null."
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Y"
),
"Input(Y) should be not null."
);
...
@@ -54,7 +53,6 @@ class ConvShiftGradOp : public framework::OperatorWithKernel {
...
@@ -54,7 +53,6 @@ class ConvShiftGradOp : public framework::OperatorWithKernel {
public:
public:
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"X"
),
"Input(X) should be not null."
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"X"
),
"Input(X) should be not null."
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Y"
),
"Input(Y) should be not null."
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Y"
),
"Input(Y) should be not null."
);
...
...
paddle/operators/cos_sim_op.cc
浏览文件 @
73a8b78a
...
@@ -23,7 +23,6 @@ class CosSimOp : public framework::OperatorWithKernel {
...
@@ -23,7 +23,6 @@ class CosSimOp : public framework::OperatorWithKernel {
public:
public:
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
// notnull check
// notnull check
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"X"
),
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"X"
),
...
@@ -97,7 +96,6 @@ class CosSimOpGrad : public framework::OperatorWithKernel {
...
@@ -97,7 +96,6 @@ class CosSimOpGrad : public framework::OperatorWithKernel {
public:
public:
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
// notnull check
// notnull check
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"X"
),
"Input(X) must not be null."
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"X"
),
"Input(X) must not be null."
);
...
...
paddle/operators/crop_op.cc
浏览文件 @
73a8b78a
...
@@ -24,7 +24,6 @@ class CropOp : public framework::OperatorWithKernel {
...
@@ -24,7 +24,6 @@ class CropOp : public framework::OperatorWithKernel {
public:
public:
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"X"
),
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"X"
),
"Input(X) of CropOp should not be null."
);
"Input(X) of CropOp should not be null."
);
...
@@ -114,7 +113,6 @@ class CropOpGrad : public framework::OperatorWithKernel {
...
@@ -114,7 +113,6 @@ class CropOpGrad : public framework::OperatorWithKernel {
public:
public:
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"X"
),
"Input(X) should not be null"
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"X"
),
"Input(X) should not be null"
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
framework
::
GradVarName
(
"Out"
)),
PADDLE_ENFORCE
(
ctx
->
HasInput
(
framework
::
GradVarName
(
"Out"
)),
...
...
paddle/operators/cross_entropy_op.cc
浏览文件 @
73a8b78a
...
@@ -21,7 +21,6 @@ class CrossEntropyOp : public framework::OperatorWithKernel {
...
@@ -21,7 +21,6 @@ class CrossEntropyOp : public framework::OperatorWithKernel {
public:
public:
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"X"
),
"Input(X) should be not null."
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"X"
),
"Input(X) should be not null."
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Label"
),
"Input(Label) should be not null."
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Label"
),
"Input(Label) should be not null."
);
...
@@ -48,6 +47,7 @@ class CrossEntropyOp : public framework::OperatorWithKernel {
...
@@ -48,6 +47,7 @@ class CrossEntropyOp : public framework::OperatorWithKernel {
ctx
->
ShareLoD
(
"X"
,
/*->*/
"Y"
);
ctx
->
ShareLoD
(
"X"
,
/*->*/
"Y"
);
}
}
protected:
// CrossEntropy's data type just determined by "X"
// CrossEntropy's data type just determined by "X"
framework
::
DataType
IndicateDataType
(
framework
::
DataType
IndicateDataType
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
...
@@ -59,7 +59,6 @@ class CrossEntropyGradientOp : public framework::OperatorWithKernel {
...
@@ -59,7 +59,6 @@ class CrossEntropyGradientOp : public framework::OperatorWithKernel {
public:
public:
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"X"
),
"Input(X) should be not null."
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"X"
),
"Input(X) should be not null."
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Label"
),
"Input(Label) should be not null."
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Label"
),
"Input(Label) should be not null."
);
...
@@ -94,6 +93,7 @@ class CrossEntropyGradientOp : public framework::OperatorWithKernel {
...
@@ -94,6 +93,7 @@ class CrossEntropyGradientOp : public framework::OperatorWithKernel {
ctx
->
SetOutputDim
(
framework
::
GradVarName
(
"X"
),
x_dims
);
ctx
->
SetOutputDim
(
framework
::
GradVarName
(
"X"
),
x_dims
);
}
}
protected:
// CrossEntropy's data type just determined by "X"
// CrossEntropy's data type just determined by "X"
framework
::
DataType
IndicateDataType
(
framework
::
DataType
IndicateDataType
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
...
...
paddle/operators/decayed_adagrad_op.cc
浏览文件 @
73a8b78a
...
@@ -21,7 +21,6 @@ class DecayedAdagradOp : public framework::OperatorWithKernel {
...
@@ -21,7 +21,6 @@ class DecayedAdagradOp : public framework::OperatorWithKernel {
public:
public:
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Param"
),
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Param"
),
"Input(Param) of DecayedAdagradOp should not be null."
);
"Input(Param) of DecayedAdagradOp should not be null."
);
...
...
paddle/operators/dropout_op.cc
浏览文件 @
73a8b78a
...
@@ -23,7 +23,6 @@ class DropoutOp : public framework::OperatorWithKernel {
...
@@ -23,7 +23,6 @@ class DropoutOp : public framework::OperatorWithKernel {
public:
public:
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"X"
),
"Input(X) must not be null."
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"X"
),
"Input(X) must not be null."
);
PADDLE_ENFORCE_GE
(
ctx
->
Attrs
().
Get
<
float
>
(
"dropout_prob"
),
0
);
PADDLE_ENFORCE_GE
(
ctx
->
Attrs
().
Get
<
float
>
(
"dropout_prob"
),
0
);
...
@@ -69,7 +68,6 @@ class DropoutOpGrad : public framework::OperatorWithKernel {
...
@@ -69,7 +68,6 @@ class DropoutOpGrad : public framework::OperatorWithKernel {
public:
public:
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE_EQ
(
ctx
->
Attrs
().
Get
<
bool
>
(
"is_training"
),
1
,
PADDLE_ENFORCE_EQ
(
ctx
->
Attrs
().
Get
<
bool
>
(
"is_training"
),
1
,
"GradOp is only callable when is_training is true"
);
"GradOp is only callable when is_training is true"
);
...
...
paddle/operators/elementwise_op.h
浏览文件 @
73a8b78a
...
@@ -23,7 +23,6 @@ class ElementwiseOp : public framework::OperatorWithKernel {
...
@@ -23,7 +23,6 @@ class ElementwiseOp : public framework::OperatorWithKernel {
public:
public:
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
using
Tensor
=
framework
::
Tensor
;
using
Tensor
=
framework
::
Tensor
;
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"X"
),
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"X"
),
...
@@ -105,7 +104,6 @@ class ElementwiseOpGrad : public framework::OperatorWithKernel {
...
@@ -105,7 +104,6 @@ class ElementwiseOpGrad : public framework::OperatorWithKernel {
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
using
Tensor
=
framework
::
Tensor
;
using
Tensor
=
framework
::
Tensor
;
protected:
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"X"
),
"Input(X) should not be null"
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"X"
),
"Input(X) should not be null"
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Y"
),
"Input(Y) should not be null"
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Y"
),
"Input(Y) should not be null"
);
...
...
paddle/operators/fill_constant_op.cc
浏览文件 @
73a8b78a
...
@@ -21,7 +21,6 @@ class FillConstantOp : public framework::OperatorWithKernel {
...
@@ -21,7 +21,6 @@ class FillConstantOp : public framework::OperatorWithKernel {
public:
public:
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasOutput
(
"Out"
),
PADDLE_ENFORCE
(
ctx
->
HasOutput
(
"Out"
),
"Output(Out) of FillConstantOp should not be null."
);
"Output(Out) of FillConstantOp should not be null."
);
...
@@ -33,6 +32,7 @@ class FillConstantOp : public framework::OperatorWithKernel {
...
@@ -33,6 +32,7 @@ class FillConstantOp : public framework::OperatorWithKernel {
ctx
->
SetOutputDim
(
"Out"
,
dims
);
ctx
->
SetOutputDim
(
"Out"
,
dims
);
}
}
protected:
framework
::
DataType
IndicateDataType
(
framework
::
DataType
IndicateDataType
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
return
static_cast
<
framework
::
DataType
>
(
ctx
.
Attr
<
int
>
(
"data_type"
));
return
static_cast
<
framework
::
DataType
>
(
ctx
.
Attr
<
int
>
(
"data_type"
));
...
...
paddle/operators/fill_zeros_like_op.cc
浏览文件 @
73a8b78a
...
@@ -21,7 +21,6 @@ class FillZerosLikeOp : public framework::OperatorWithKernel {
...
@@ -21,7 +21,6 @@ class FillZerosLikeOp : public framework::OperatorWithKernel {
public:
public:
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"X"
),
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"X"
),
"Input(X) of FillZerosLikeOp should not be null."
);
"Input(X) of FillZerosLikeOp should not be null."
);
...
...
paddle/operators/gather_op.cc
浏览文件 @
73a8b78a
...
@@ -22,7 +22,6 @@ class GatherOp : public framework::OperatorWithKernel {
...
@@ -22,7 +22,6 @@ class GatherOp : public framework::OperatorWithKernel {
public:
public:
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"X"
),
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"X"
),
"Input(X) of GatherOp should not be null."
);
"Input(X) of GatherOp should not be null."
);
...
@@ -40,6 +39,7 @@ class GatherOp : public framework::OperatorWithKernel {
...
@@ -40,6 +39,7 @@ class GatherOp : public framework::OperatorWithKernel {
ctx
->
SetOutputDim
(
"Out"
,
output_dims
);
ctx
->
SetOutputDim
(
"Out"
,
output_dims
);
}
}
protected:
framework
::
DataType
IndicateDataType
(
framework
::
DataType
IndicateDataType
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
return
framework
::
ToDataType
(
ctx
.
Input
<
Tensor
>
(
"X"
)
->
type
());
return
framework
::
ToDataType
(
ctx
.
Input
<
Tensor
>
(
"X"
)
->
type
());
...
@@ -50,11 +50,11 @@ class GatherGradOp : public framework::OperatorWithKernel {
...
@@ -50,11 +50,11 @@ class GatherGradOp : public framework::OperatorWithKernel {
public:
public:
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
ctx
->
SetOutputDim
(
framework
::
GradVarName
(
"X"
),
ctx
->
GetInputDim
(
"X"
));
ctx
->
SetOutputDim
(
framework
::
GradVarName
(
"X"
),
ctx
->
GetInputDim
(
"X"
));
}
}
protected:
framework
::
DataType
IndicateDataType
(
framework
::
DataType
IndicateDataType
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
return
framework
::
ToDataType
(
ctx
.
Input
<
Tensor
>
(
"X"
)
->
type
());
return
framework
::
ToDataType
(
ctx
.
Input
<
Tensor
>
(
"X"
)
->
type
());
...
...
paddle/operators/gaussian_random_op.cc
浏览文件 @
73a8b78a
...
@@ -42,7 +42,6 @@ class GaussianRandomOp : public framework::OperatorWithKernel {
...
@@ -42,7 +42,6 @@ class GaussianRandomOp : public framework::OperatorWithKernel {
public:
public:
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasOutput
(
"Out"
),
PADDLE_ENFORCE
(
ctx
->
HasOutput
(
"Out"
),
"Output(Out) of GaussianRandomOp should not be null."
);
"Output(Out) of GaussianRandomOp should not be null."
);
...
@@ -57,6 +56,7 @@ class GaussianRandomOp : public framework::OperatorWithKernel {
...
@@ -57,6 +56,7 @@ class GaussianRandomOp : public framework::OperatorWithKernel {
ctx
->
SetOutputDim
(
"Out"
,
framework
::
make_ddim
(
temp
));
ctx
->
SetOutputDim
(
"Out"
,
framework
::
make_ddim
(
temp
));
}
}
protected:
framework
::
DataType
IndicateDataType
(
framework
::
DataType
IndicateDataType
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
return
static_cast
<
framework
::
DataType
>
(
Attr
<
int
>
(
"data_type"
));
return
static_cast
<
framework
::
DataType
>
(
Attr
<
int
>
(
"data_type"
));
...
...
paddle/operators/gru_unit_op.cc
浏览文件 @
73a8b78a
...
@@ -23,7 +23,6 @@ class GRUUnitOp : public framework::OperatorWithKernel {
...
@@ -23,7 +23,6 @@ class GRUUnitOp : public framework::OperatorWithKernel {
public:
public:
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Input"
),
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Input"
),
"Input(%s) of GRUUnitOp should not be null."
,
"Input"
);
"Input(%s) of GRUUnitOp should not be null."
,
"Input"
);
...
@@ -131,7 +130,6 @@ class GRUUnitGradOp : public framework::OperatorWithKernel {
...
@@ -131,7 +130,6 @@ class GRUUnitGradOp : public framework::OperatorWithKernel {
public:
public:
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Input"
),
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Input"
),
"Input(%s) of GRUUnitGradOp should not be null."
,
"Input"
);
"Input(%s) of GRUUnitGradOp should not be null."
,
"Input"
);
...
...
paddle/operators/lookup_table_op.cc
浏览文件 @
73a8b78a
...
@@ -21,7 +21,6 @@ class LookupTableOp : public framework::OperatorWithKernel {
...
@@ -21,7 +21,6 @@ class LookupTableOp : public framework::OperatorWithKernel {
public:
public:
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"W"
),
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"W"
),
"Input(W) of LookupTableOp should not be null."
);
"Input(W) of LookupTableOp should not be null."
);
...
@@ -37,6 +36,7 @@ class LookupTableOp : public framework::OperatorWithKernel {
...
@@ -37,6 +36,7 @@ class LookupTableOp : public framework::OperatorWithKernel {
ctx
->
ShareLoD
(
"Ids"
,
/*->*/
"Out"
);
ctx
->
ShareLoD
(
"Ids"
,
/*->*/
"Out"
);
}
}
protected:
framework
::
DataType
IndicateDataType
(
framework
::
DataType
IndicateDataType
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
return
framework
::
ToDataType
(
ctx
.
Input
<
Tensor
>
(
"W"
)
->
type
());
return
framework
::
ToDataType
(
ctx
.
Input
<
Tensor
>
(
"W"
)
->
type
());
...
@@ -69,12 +69,12 @@ class LookupTableOpGrad : public framework::OperatorWithKernel {
...
@@ -69,12 +69,12 @@ class LookupTableOpGrad : public framework::OperatorWithKernel {
public:
public:
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
auto
table_dims
=
ctx
->
GetInputDim
(
"W"
);
auto
table_dims
=
ctx
->
GetInputDim
(
"W"
);
ctx
->
SetOutputDim
(
framework
::
GradVarName
(
"W"
),
table_dims
);
ctx
->
SetOutputDim
(
framework
::
GradVarName
(
"W"
),
table_dims
);
}
}
protected:
framework
::
DataType
IndicateDataType
(
framework
::
DataType
IndicateDataType
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
return
framework
::
ToDataType
(
ctx
.
Input
<
Tensor
>
(
"W"
)
->
type
());
return
framework
::
ToDataType
(
ctx
.
Input
<
Tensor
>
(
"W"
)
->
type
());
...
...
paddle/operators/lstm_unit_op.cc
浏览文件 @
73a8b78a
...
@@ -21,7 +21,6 @@ class LstmUnitOp : public framework::OperatorWithKernel {
...
@@ -21,7 +21,6 @@ class LstmUnitOp : public framework::OperatorWithKernel {
public:
public:
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"X"
),
"Input(X) of LSTM should not be null."
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"X"
),
"Input(X) of LSTM should not be null."
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"C_prev"
),
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"C_prev"
),
...
@@ -76,7 +75,6 @@ class LstmUnitGradOp : public framework::OperatorWithKernel {
...
@@ -76,7 +75,6 @@ class LstmUnitGradOp : public framework::OperatorWithKernel {
public:
public:
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
framework
::
GradVarName
(
"C"
)),
PADDLE_ENFORCE
(
ctx
->
HasInput
(
framework
::
GradVarName
(
"C"
)),
"Input(C@GRAD) should not be null"
);
"Input(C@GRAD) should not be null"
);
...
...
paddle/operators/margin_rank_loss_op.cc
浏览文件 @
73a8b78a
...
@@ -21,7 +21,6 @@ class MarginRankLossOp : public framework::OperatorWithKernel {
...
@@ -21,7 +21,6 @@ class MarginRankLossOp : public framework::OperatorWithKernel {
public:
public:
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
// input check
// input check
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Label"
),
"Input(Label) shouldn't be null."
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Label"
),
"Input(Label) shouldn't be null."
);
...
@@ -94,7 +93,6 @@ class MarginRankLossGradOp : public framework::OperatorWithKernel {
...
@@ -94,7 +93,6 @@ class MarginRankLossGradOp : public framework::OperatorWithKernel {
public:
public:
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Label"
),
"Input(Label) shouldn't be null."
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Label"
),
"Input(Label) shouldn't be null."
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"X1"
),
"Input(X1) shouldn't be null."
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"X1"
),
"Input(X1) shouldn't be null."
);
...
...
paddle/operators/mean_op.cc
浏览文件 @
73a8b78a
...
@@ -21,7 +21,6 @@ class MeanOp : public framework::OperatorWithKernel {
...
@@ -21,7 +21,6 @@ class MeanOp : public framework::OperatorWithKernel {
public:
public:
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"X"
),
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"X"
),
"Input(X) of MeanOp should not be null."
);
"Input(X) of MeanOp should not be null."
);
...
@@ -46,7 +45,6 @@ class MeanGradOp : public framework::OperatorWithKernel {
...
@@ -46,7 +45,6 @@ class MeanGradOp : public framework::OperatorWithKernel {
public:
public:
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
ctx
->
SetOutputDim
(
framework
::
GradVarName
(
"X"
),
ctx
->
GetInputDim
(
"X"
));
ctx
->
SetOutputDim
(
framework
::
GradVarName
(
"X"
),
ctx
->
GetInputDim
(
"X"
));
}
}
...
...
paddle/operators/minus_op.cc
浏览文件 @
73a8b78a
...
@@ -25,7 +25,6 @@ class MinusOp : public framework::OperatorWithKernel {
...
@@ -25,7 +25,6 @@ class MinusOp : public framework::OperatorWithKernel {
const
framework
::
AttributeMap
&
attrs
)
const
framework
::
AttributeMap
&
attrs
)
:
OperatorWithKernel
(
type
,
inputs
,
outputs
,
attrs
)
{}
:
OperatorWithKernel
(
type
,
inputs
,
outputs
,
attrs
)
{}
protected:
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"X"
),
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"X"
),
"Input(X) of MinusOp should not be null."
);
"Input(X) of MinusOp should not be null."
);
...
...
paddle/operators/modified_huber_loss_op.cc
浏览文件 @
73a8b78a
...
@@ -21,7 +21,6 @@ class ModifiedHuberLossOp : public framework::OperatorWithKernel {
...
@@ -21,7 +21,6 @@ class ModifiedHuberLossOp : public framework::OperatorWithKernel {
public:
public:
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"X"
),
"X must be initialized."
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"X"
),
"X must be initialized."
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Y"
),
"Y must be initialized."
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Y"
),
"Y must be initialized."
);
...
@@ -73,7 +72,6 @@ class ModifiedHuberLossGradOp : public framework::OperatorWithKernel {
...
@@ -73,7 +72,6 @@ class ModifiedHuberLossGradOp : public framework::OperatorWithKernel {
public:
public:
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"X"
),
"X must be initialized."
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"X"
),
"X must be initialized."
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Y"
),
"Y must be initialized."
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Y"
),
"Y must be initialized."
);
...
...
paddle/operators/mul_op.cc
浏览文件 @
73a8b78a
...
@@ -23,7 +23,6 @@ class MulOp : public framework::OperatorWithKernel {
...
@@ -23,7 +23,6 @@ class MulOp : public framework::OperatorWithKernel {
public:
public:
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"X"
),
"Input(X) of MulOp should not be null."
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"X"
),
"Input(X) of MulOp should not be null."
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Y"
),
"Input(Y) of MulOp should not be null."
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Y"
),
"Input(Y) of MulOp should not be null."
);
...
@@ -96,7 +95,6 @@ class MulOpGrad : public framework::OperatorWithKernel {
...
@@ -96,7 +95,6 @@ class MulOpGrad : public framework::OperatorWithKernel {
public:
public:
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"X"
),
"Input(X) should not be null"
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"X"
),
"Input(X) should not be null"
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Y"
),
"Input(Y) should not be null"
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Y"
),
"Input(Y) should not be null"
);
...
...
paddle/operators/multiplex_op.cc
浏览文件 @
73a8b78a
...
@@ -23,7 +23,6 @@ class MultiplexOp : public framework::OperatorWithKernel {
...
@@ -23,7 +23,6 @@ class MultiplexOp : public framework::OperatorWithKernel {
public:
public:
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Ids"
),
"Input(Ids) shouldn't be null."
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Ids"
),
"Input(Ids) shouldn't be null."
);
PADDLE_ENFORCE
(
!
ctx
->
Inputs
(
"X"
).
empty
(),
PADDLE_ENFORCE
(
!
ctx
->
Inputs
(
"X"
).
empty
(),
...
@@ -51,6 +50,7 @@ class MultiplexOp : public framework::OperatorWithKernel {
...
@@ -51,6 +50,7 @@ class MultiplexOp : public framework::OperatorWithKernel {
ctx
->
SetOutputDim
(
"Out"
,
in_dim
);
ctx
->
SetOutputDim
(
"Out"
,
in_dim
);
}
}
protected:
framework
::
DataType
IndicateDataType
(
framework
::
DataType
IndicateDataType
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
return
framework
::
ToDataType
(
ctx
.
MultiInput
<
Tensor
>
(
"X"
)[
0
]
->
type
());
return
framework
::
ToDataType
(
ctx
.
MultiInput
<
Tensor
>
(
"X"
)[
0
]
->
type
());
...
@@ -89,7 +89,6 @@ class MultiplexGradOp : public framework::OperatorWithKernel {
...
@@ -89,7 +89,6 @@ class MultiplexGradOp : public framework::OperatorWithKernel {
public:
public:
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
!
ctx
->
Inputs
(
"X"
).
empty
(),
"Input(X) should not be null."
);
PADDLE_ENFORCE
(
!
ctx
->
Inputs
(
"X"
).
empty
(),
"Input(X) should not be null."
);
PADDLE_ENFORCE
(
!
ctx
->
Outputs
(
framework
::
GradVarName
(
"X"
)).
empty
(),
PADDLE_ENFORCE
(
!
ctx
->
Outputs
(
framework
::
GradVarName
(
"X"
)).
empty
(),
...
@@ -105,6 +104,7 @@ class MultiplexGradOp : public framework::OperatorWithKernel {
...
@@ -105,6 +104,7 @@ class MultiplexGradOp : public framework::OperatorWithKernel {
ctx
->
SetOutputsDim
(
framework
::
GradVarName
(
"X"
),
d_ins
);
ctx
->
SetOutputsDim
(
framework
::
GradVarName
(
"X"
),
d_ins
);
}
}
protected:
framework
::
DataType
IndicateDataType
(
framework
::
DataType
IndicateDataType
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
return
framework
::
ToDataType
(
ctx
.
MultiInput
<
Tensor
>
(
"X"
)[
0
]
->
type
());
return
framework
::
ToDataType
(
ctx
.
MultiInput
<
Tensor
>
(
"X"
)[
0
]
->
type
());
...
...
paddle/operators/pad_op.cc
浏览文件 @
73a8b78a
...
@@ -23,7 +23,6 @@ class PadOp : public framework::OperatorWithKernel {
...
@@ -23,7 +23,6 @@ class PadOp : public framework::OperatorWithKernel {
public:
public:
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"X"
),
"Input(X) of PadOp should not be null."
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"X"
),
"Input(X) of PadOp should not be null."
);
PADDLE_ENFORCE
(
ctx
->
HasOutput
(
"Out"
),
PADDLE_ENFORCE
(
ctx
->
HasOutput
(
"Out"
),
...
@@ -97,7 +96,6 @@ class PadOpGrad : public framework::OperatorWithKernel {
...
@@ -97,7 +96,6 @@ class PadOpGrad : public framework::OperatorWithKernel {
public:
public:
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"X"
),
"Input(X) should not be null"
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"X"
),
"Input(X) should not be null"
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
framework
::
GradVarName
(
"Out"
)),
PADDLE_ENFORCE
(
ctx
->
HasInput
(
framework
::
GradVarName
(
"Out"
)),
...
...
paddle/operators/pool_op.h
浏览文件 @
73a8b78a
...
@@ -28,7 +28,6 @@ class PoolOp : public framework::OperatorWithKernel {
...
@@ -28,7 +28,6 @@ class PoolOp : public framework::OperatorWithKernel {
public:
public:
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
;
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
;
};
};
...
@@ -36,7 +35,6 @@ class PoolOpGrad : public framework::OperatorWithKernel {
...
@@ -36,7 +35,6 @@ class PoolOpGrad : public framework::OperatorWithKernel {
public:
public:
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
;
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
;
};
};
...
...
paddle/operators/pool_with_index_op.cc
浏览文件 @
73a8b78a
...
@@ -27,7 +27,6 @@ class MaxPoolWithIndexOp : public framework::OperatorWithKernel {
...
@@ -27,7 +27,6 @@ class MaxPoolWithIndexOp : public framework::OperatorWithKernel {
public:
public:
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"X"
),
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"X"
),
"X(Input) of Pooling should not be null."
);
"X(Input) of Pooling should not be null."
);
...
@@ -72,7 +71,6 @@ class MaxPoolWithIndexOpGrad : public framework::OperatorWithKernel {
...
@@ -72,7 +71,6 @@ class MaxPoolWithIndexOpGrad : public framework::OperatorWithKernel {
public:
public:
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Mask"
),
"Input(Mask) must not be null."
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Mask"
),
"Input(Mask) must not be null."
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"X"
),
"Input(X) must not be null."
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"X"
),
"Input(X) must not be null."
);
...
...
paddle/operators/prelu_op.cc
浏览文件 @
73a8b78a
...
@@ -25,7 +25,6 @@ class PReluOp : public framework::OperatorWithKernel {
...
@@ -25,7 +25,6 @@ class PReluOp : public framework::OperatorWithKernel {
const
framework
::
AttributeMap
&
attrs
)
const
framework
::
AttributeMap
&
attrs
)
:
OperatorWithKernel
(
type
,
inputs
,
outputs
,
attrs
)
{}
:
OperatorWithKernel
(
type
,
inputs
,
outputs
,
attrs
)
{}
protected:
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"X"
),
"Input(X) should not be null"
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"X"
),
"Input(X) should not be null"
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Alpha"
),
"Input(Alpha) should not be null"
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Alpha"
),
"Input(Alpha) should not be null"
);
...
@@ -62,7 +61,6 @@ class PReluGradOp : public framework::OperatorWithKernel {
...
@@ -62,7 +61,6 @@ class PReluGradOp : public framework::OperatorWithKernel {
public:
public:
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"X"
),
"Input(X) must not be null."
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"X"
),
"Input(X) must not be null."
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
framework
::
GradVarName
(
"Out"
)),
PADDLE_ENFORCE
(
ctx
->
HasInput
(
framework
::
GradVarName
(
"Out"
)),
...
...
paddle/operators/rank_loss_op.cc
浏览文件 @
73a8b78a
...
@@ -24,7 +24,6 @@ class RankLossOp : public framework::OperatorWithKernel {
...
@@ -24,7 +24,6 @@ class RankLossOp : public framework::OperatorWithKernel {
const
framework
::
AttributeMap
&
attrs
)
const
framework
::
AttributeMap
&
attrs
)
:
OperatorWithKernel
(
type
,
inputs
,
outputs
,
attrs
)
{}
:
OperatorWithKernel
(
type
,
inputs
,
outputs
,
attrs
)
{}
protected:
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
// input check
// input check
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Label"
),
"Input(Label) shouldn't be null"
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Label"
),
"Input(Label) shouldn't be null"
);
...
@@ -89,7 +88,6 @@ class RankLossGradOp : public framework::OperatorWithKernel {
...
@@ -89,7 +88,6 @@ class RankLossGradOp : public framework::OperatorWithKernel {
const
framework
::
AttributeMap
&
attrs
)
const
framework
::
AttributeMap
&
attrs
)
:
OperatorWithKernel
(
type
,
inputs
,
outputs
,
attrs
)
{}
:
OperatorWithKernel
(
type
,
inputs
,
outputs
,
attrs
)
{}
protected:
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Label"
),
"Input(Label) shouldn't be null."
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Label"
),
"Input(Label) shouldn't be null."
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Left"
),
"Input(Left) shouldn't be null."
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Left"
),
"Input(Left) shouldn't be null."
);
...
...
paddle/operators/reduce_op.cc
浏览文件 @
73a8b78a
...
@@ -23,7 +23,6 @@ class ReduceOp : public framework::OperatorWithKernel {
...
@@ -23,7 +23,6 @@ class ReduceOp : public framework::OperatorWithKernel {
public:
public:
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"X"
),
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"X"
),
"Input(X) of ReduceOp should not be null."
);
"Input(X) of ReduceOp should not be null."
);
...
@@ -57,7 +56,6 @@ class ReduceGradOp : public framework::OperatorWithKernel {
...
@@ -57,7 +56,6 @@ class ReduceGradOp : public framework::OperatorWithKernel {
public:
public:
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"X"
),
"Input(X) should not be null."
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"X"
),
"Input(X) should not be null."
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
framework
::
GradVarName
(
"Out"
)),
PADDLE_ENFORCE
(
ctx
->
HasInput
(
framework
::
GradVarName
(
"Out"
)),
...
...
paddle/operators/reshape_op.cc
浏览文件 @
73a8b78a
...
@@ -25,7 +25,6 @@ class ReshapeOp : public framework::OperatorWithKernel {
...
@@ -25,7 +25,6 @@ class ReshapeOp : public framework::OperatorWithKernel {
const
framework
::
AttributeMap
&
attrs
)
const
framework
::
AttributeMap
&
attrs
)
:
OperatorWithKernel
(
type
,
inputs
,
outputs
,
attrs
)
{}
:
OperatorWithKernel
(
type
,
inputs
,
outputs
,
attrs
)
{}
protected:
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
// input check
// input check
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"X"
),
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"X"
),
...
@@ -93,7 +92,6 @@ class ReshapeGradOp : public framework::OperatorWithKernel {
...
@@ -93,7 +92,6 @@ class ReshapeGradOp : public framework::OperatorWithKernel {
const
framework
::
AttributeMap
&
attrs
)
const
framework
::
AttributeMap
&
attrs
)
:
OperatorWithKernel
(
type
,
inputs
,
outputs
,
attrs
)
{}
:
OperatorWithKernel
(
type
,
inputs
,
outputs
,
attrs
)
{}
protected:
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"X"
),
"Input(X) shouldn't be null."
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"X"
),
"Input(X) shouldn't be null."
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
framework
::
GradVarName
(
"Out"
)),
PADDLE_ENFORCE
(
ctx
->
HasInput
(
framework
::
GradVarName
(
"Out"
)),
...
...
paddle/operators/rmsprop_op.cc
浏览文件 @
73a8b78a
...
@@ -21,7 +21,6 @@ class RmspropOp : public framework::OperatorWithKernel {
...
@@ -21,7 +21,6 @@ class RmspropOp : public framework::OperatorWithKernel {
public:
public:
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Param"
),
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Param"
),
"Input(Param) of RmspropOp should not be null."
);
"Input(Param) of RmspropOp should not be null."
);
...
...
paddle/operators/scale_op.cc
浏览文件 @
73a8b78a
...
@@ -25,7 +25,6 @@ class ScaleOp : public framework::OperatorWithKernel {
...
@@ -25,7 +25,6 @@ class ScaleOp : public framework::OperatorWithKernel {
const
framework
::
AttributeMap
&
attrs
)
const
framework
::
AttributeMap
&
attrs
)
:
OperatorWithKernel
(
type
,
inputs
,
outputs
,
attrs
)
{}
:
OperatorWithKernel
(
type
,
inputs
,
outputs
,
attrs
)
{}
protected:
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"X"
),
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"X"
),
"Input(X) of ScaleOp should not be null."
);
"Input(X) of ScaleOp should not be null."
);
...
@@ -56,7 +55,6 @@ class ScaleGradMaker : public framework::SingleGradOpDescMaker {
...
@@ -56,7 +55,6 @@ class ScaleGradMaker : public framework::SingleGradOpDescMaker {
public:
public:
using
framework
::
SingleGradOpDescMaker
::
SingleGradOpDescMaker
;
using
framework
::
SingleGradOpDescMaker
::
SingleGradOpDescMaker
;
protected:
std
::
unique_ptr
<
framework
::
OpDescBind
>
Apply
()
const
override
{
std
::
unique_ptr
<
framework
::
OpDescBind
>
Apply
()
const
override
{
auto
*
grad_op
=
new
framework
::
OpDescBind
();
auto
*
grad_op
=
new
framework
::
OpDescBind
();
grad_op
->
SetType
(
"scale"
);
grad_op
->
SetType
(
"scale"
);
...
...
paddle/operators/scatter_op.cc
浏览文件 @
73a8b78a
...
@@ -22,7 +22,6 @@ class ScatterOp : public framework::OperatorWithKernel {
...
@@ -22,7 +22,6 @@ class ScatterOp : public framework::OperatorWithKernel {
public:
public:
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Ref"
),
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Ref"
),
"Input(Ref) of ScatterOp should not be null."
);
"Input(Ref) of ScatterOp should not be null."
);
...
@@ -49,6 +48,7 @@ class ScatterOp : public framework::OperatorWithKernel {
...
@@ -49,6 +48,7 @@ class ScatterOp : public framework::OperatorWithKernel {
ctx
->
SetOutputDim
(
"Out"
,
ref_dims
);
ctx
->
SetOutputDim
(
"Out"
,
ref_dims
);
}
}
protected:
framework
::
DataType
IndicateDataType
(
framework
::
DataType
IndicateDataType
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
return
framework
::
ToDataType
(
ctx
.
Input
<
Tensor
>
(
"Ref"
)
->
type
());
return
framework
::
ToDataType
(
ctx
.
Input
<
Tensor
>
(
"Ref"
)
->
type
());
...
@@ -59,13 +59,13 @@ class ScatterGradOp : public framework::OperatorWithKernel {
...
@@ -59,13 +59,13 @@ class ScatterGradOp : public framework::OperatorWithKernel {
public:
public:
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
ctx
->
SetOutputDim
(
framework
::
GradVarName
(
"Updates"
),
ctx
->
SetOutputDim
(
framework
::
GradVarName
(
"Updates"
),
ctx
->
GetInputDim
(
"Updates"
));
ctx
->
GetInputDim
(
"Updates"
));
ctx
->
SetOutputDim
(
framework
::
GradVarName
(
"Ref"
),
ctx
->
GetInputDim
(
"Ref"
));
ctx
->
SetOutputDim
(
framework
::
GradVarName
(
"Ref"
),
ctx
->
GetInputDim
(
"Ref"
));
}
}
protected:
framework
::
DataType
IndicateDataType
(
framework
::
DataType
IndicateDataType
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
return
framework
::
ToDataType
(
ctx
.
Input
<
Tensor
>
(
"Ref"
)
->
type
());
return
framework
::
ToDataType
(
ctx
.
Input
<
Tensor
>
(
"Ref"
)
->
type
());
...
...
paddle/operators/sequence_concat_op.cc
浏览文件 @
73a8b78a
...
@@ -21,7 +21,6 @@ class SequenceConcatOp : public framework::OperatorWithKernel {
...
@@ -21,7 +21,6 @@ class SequenceConcatOp : public framework::OperatorWithKernel {
public:
public:
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInputs
(
"X"
),
PADDLE_ENFORCE
(
ctx
->
HasInputs
(
"X"
),
"Inputs(X) of SequenceConcatOp should not be null."
);
"Inputs(X) of SequenceConcatOp should not be null."
);
...
@@ -105,7 +104,6 @@ class SequenceConcatGradOp : public framework::OperatorWithKernel {
...
@@ -105,7 +104,6 @@ class SequenceConcatGradOp : public framework::OperatorWithKernel {
public:
public:
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
framework
::
GradVarName
(
"Out"
)),
PADDLE_ENFORCE
(
ctx
->
HasInput
(
framework
::
GradVarName
(
"Out"
)),
"The gradient of Out should not be null."
);
"The gradient of Out should not be null."
);
...
...
paddle/operators/sequence_pool_op.cc
浏览文件 @
73a8b78a
...
@@ -21,7 +21,6 @@ class SequencePoolOp : public framework::OperatorWithKernel {
...
@@ -21,7 +21,6 @@ class SequencePoolOp : public framework::OperatorWithKernel {
public:
public:
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"X"
),
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"X"
),
"Input(X) of SequencePoolOp should not be null."
);
"Input(X) of SequencePoolOp should not be null."
);
...
@@ -72,7 +71,6 @@ class SequencePoolGradOp : public framework::OperatorWithKernel {
...
@@ -72,7 +71,6 @@ class SequencePoolGradOp : public framework::OperatorWithKernel {
public:
public:
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
framework
::
GradVarName
(
"Out"
)),
PADDLE_ENFORCE
(
ctx
->
HasInput
(
framework
::
GradVarName
(
"Out"
)),
"Gradient of Out should not be null."
);
"Gradient of Out should not be null."
);
...
...
paddle/operators/sequence_softmax_op.cc
浏览文件 @
73a8b78a
...
@@ -21,7 +21,6 @@ class SequenceSoftmaxOp : public framework::OperatorWithKernel {
...
@@ -21,7 +21,6 @@ class SequenceSoftmaxOp : public framework::OperatorWithKernel {
public:
public:
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"X"
),
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"X"
),
"Input(X) of SequenceSoftmaxOp should not be null."
);
"Input(X) of SequenceSoftmaxOp should not be null."
);
...
@@ -66,7 +65,6 @@ class SequenceSoftmaxGradOp : public framework::OperatorWithKernel {
...
@@ -66,7 +65,6 @@ class SequenceSoftmaxGradOp : public framework::OperatorWithKernel {
public:
public:
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Out"
),
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Out"
),
"Input(Out) of SequenceSoftmaxGradOp should not be null."
);
"Input(Out) of SequenceSoftmaxGradOp should not be null."
);
...
...
paddle/operators/sgd_op.cc
浏览文件 @
73a8b78a
...
@@ -21,7 +21,6 @@ class SGDOp : public framework::OperatorWithKernel {
...
@@ -21,7 +21,6 @@ class SGDOp : public framework::OperatorWithKernel {
public:
public:
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Param"
),
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Param"
),
"Input(Param) of SGDOp should not be null."
);
"Input(Param) of SGDOp should not be null."
);
...
...
paddle/operators/sigmoid_cross_entropy_with_logits_op.cc
浏览文件 @
73a8b78a
...
@@ -23,7 +23,6 @@ class SigmoidCrossEntropyWithLogitsOp : public framework::OperatorWithKernel {
...
@@ -23,7 +23,6 @@ class SigmoidCrossEntropyWithLogitsOp : public framework::OperatorWithKernel {
public:
public:
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"X"
),
"Input(X) should be not null."
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"X"
),
"Input(X) should be not null."
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Labels"
),
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Labels"
),
...
@@ -52,7 +51,6 @@ class SigmoidCrossEntropyWithLogitsGradOp
...
@@ -52,7 +51,6 @@ class SigmoidCrossEntropyWithLogitsGradOp
public:
public:
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"X"
),
"Input(X) should be not null."
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"X"
),
"Input(X) should be not null."
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Labels"
),
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Labels"
),
...
...
paddle/operators/smooth_l1_loss_op.cc
浏览文件 @
73a8b78a
...
@@ -21,7 +21,6 @@ class SmoothL1LossOp : public framework::OperatorWithKernel {
...
@@ -21,7 +21,6 @@ class SmoothL1LossOp : public framework::OperatorWithKernel {
public:
public:
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"X"
),
"X must be initialized."
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"X"
),
"X must be initialized."
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Y"
),
"Y must be initialized."
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Y"
),
"Y must be initialized."
);
...
@@ -93,7 +92,6 @@ class SmoothL1LossGradOp : public framework::OperatorWithKernel {
...
@@ -93,7 +92,6 @@ class SmoothL1LossGradOp : public framework::OperatorWithKernel {
public:
public:
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
auto
in_dims
=
ctx
->
GetInputDim
(
"X"
);
auto
in_dims
=
ctx
->
GetInputDim
(
"X"
);
auto
out_dims
=
ctx
->
GetInputDim
(
framework
::
GradVarName
(
"Out"
));
auto
out_dims
=
ctx
->
GetInputDim
(
framework
::
GradVarName
(
"Out"
));
...
...
paddle/operators/softmax_op.cc
浏览文件 @
73a8b78a
...
@@ -21,7 +21,6 @@ class SoftmaxOp : public framework::OperatorWithKernel {
...
@@ -21,7 +21,6 @@ class SoftmaxOp : public framework::OperatorWithKernel {
public:
public:
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"X"
),
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"X"
),
"Input(X) of SoftmaxOp should not be null."
);
"Input(X) of SoftmaxOp should not be null."
);
...
@@ -68,7 +67,6 @@ class SoftmaxOpGrad : public framework::OperatorWithKernel {
...
@@ -68,7 +67,6 @@ class SoftmaxOpGrad : public framework::OperatorWithKernel {
public:
public:
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Y"
),
"Input(Y) should be not null."
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Y"
),
"Input(Y) should be not null."
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
framework
::
GradVarName
(
"Y"
)),
PADDLE_ENFORCE
(
ctx
->
HasInput
(
framework
::
GradVarName
(
"Y"
)),
...
...
paddle/operators/softmax_with_cross_entropy_op.cc
浏览文件 @
73a8b78a
...
@@ -82,7 +82,6 @@ class SoftmaxWithCrossEntropyOp : public framework::OperatorWithKernel {
...
@@ -82,7 +82,6 @@ class SoftmaxWithCrossEntropyOp : public framework::OperatorWithKernel {
public:
public:
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Logits"
),
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Logits"
),
"Input(Logits) should be not null."
);
"Input(Logits) should be not null."
);
...
@@ -117,6 +116,7 @@ class SoftmaxWithCrossEntropyOp : public framework::OperatorWithKernel {
...
@@ -117,6 +116,7 @@ class SoftmaxWithCrossEntropyOp : public framework::OperatorWithKernel {
ctx
->
ShareLoD
(
"Logits"
,
/*->*/
"Loss"
);
ctx
->
ShareLoD
(
"Logits"
,
/*->*/
"Loss"
);
}
}
protected:
framework
::
DataType
IndicateDataType
(
framework
::
DataType
IndicateDataType
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
return
framework
::
ToDataType
(
ctx
.
Input
<
Tensor
>
(
"Logits"
)
->
type
());
return
framework
::
ToDataType
(
ctx
.
Input
<
Tensor
>
(
"Logits"
)
->
type
());
...
@@ -127,7 +127,6 @@ class SoftmaxWithCrossEntropyOpGrad : public framework::OperatorWithKernel {
...
@@ -127,7 +127,6 @@ class SoftmaxWithCrossEntropyOpGrad : public framework::OperatorWithKernel {
public:
public:
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
framework
::
GradVarName
(
"Loss"
)),
PADDLE_ENFORCE
(
ctx
->
HasInput
(
framework
::
GradVarName
(
"Loss"
)),
"Input(Loss@Grad) should not be null."
);
"Input(Loss@Grad) should not be null."
);
...
@@ -156,6 +155,7 @@ class SoftmaxWithCrossEntropyOpGrad : public framework::OperatorWithKernel {
...
@@ -156,6 +155,7 @@ class SoftmaxWithCrossEntropyOpGrad : public framework::OperatorWithKernel {
ctx
->
GetInputDim
(
"Softmax"
));
ctx
->
GetInputDim
(
"Softmax"
));
}
}
protected:
framework
::
DataType
IndicateDataType
(
framework
::
DataType
IndicateDataType
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
return
framework
::
ToDataType
(
return
framework
::
ToDataType
(
...
...
paddle/operators/split_op.cc
浏览文件 @
73a8b78a
...
@@ -23,7 +23,6 @@ class SplitOp : public framework::OperatorWithKernel {
...
@@ -23,7 +23,6 @@ class SplitOp : public framework::OperatorWithKernel {
public:
public:
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"X"
),
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"X"
),
"Input(X) of SplitOp should not be null."
);
"Input(X) of SplitOp should not be null."
);
...
...
paddle/operators/squared_l2_distance_op.cc
浏览文件 @
73a8b78a
...
@@ -21,7 +21,6 @@ class SquaredL2DistanceOp : public framework::OperatorWithKernel {
...
@@ -21,7 +21,6 @@ class SquaredL2DistanceOp : public framework::OperatorWithKernel {
public:
public:
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"X"
),
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"X"
),
"Input(X) of SquaredL2DistanceOp should not be null."
);
"Input(X) of SquaredL2DistanceOp should not be null."
);
...
@@ -85,7 +84,6 @@ class SquaredL2DistanceGradOp : public framework::OperatorWithKernel {
...
@@ -85,7 +84,6 @@ class SquaredL2DistanceGradOp : public framework::OperatorWithKernel {
public:
public:
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
framework
::
GradVarName
(
"Out"
)),
PADDLE_ENFORCE
(
ctx
->
HasInput
(
framework
::
GradVarName
(
"Out"
)),
"Gradient of Out should not be null"
);
"Gradient of Out should not be null"
);
...
...
paddle/operators/sum_op.cc
浏览文件 @
73a8b78a
...
@@ -21,7 +21,6 @@ class SumOp : public framework::OperatorWithKernel {
...
@@ -21,7 +21,6 @@ class SumOp : public framework::OperatorWithKernel {
public:
public:
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInputs
(
"X"
),
"Inputs(X) should not be null"
);
PADDLE_ENFORCE
(
ctx
->
HasInputs
(
"X"
),
"Inputs(X) should not be null"
);
auto
x_dims
=
ctx
->
GetInputsDim
(
"X"
);
auto
x_dims
=
ctx
->
GetInputsDim
(
"X"
);
...
...
paddle/operators/top_k_op.cc
浏览文件 @
73a8b78a
...
@@ -21,7 +21,6 @@ class TopkOp : public framework::OperatorWithKernel {
...
@@ -21,7 +21,6 @@ class TopkOp : public framework::OperatorWithKernel {
public:
public:
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"X"
),
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"X"
),
"Input(X) of TopkOp should not be null."
);
"Input(X) of TopkOp should not be null."
);
...
...
paddle/operators/transpose_op.cc
浏览文件 @
73a8b78a
...
@@ -23,7 +23,6 @@ class TransposeOp : public framework::OperatorWithKernel {
...
@@ -23,7 +23,6 @@ class TransposeOp : public framework::OperatorWithKernel {
public:
public:
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"X"
),
"Input(X) should not be null"
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"X"
),
"Input(X) should not be null"
);
PADDLE_ENFORCE
(
ctx
->
HasOutput
(
"Out"
),
"Output(Out) should not be null"
);
PADDLE_ENFORCE
(
ctx
->
HasOutput
(
"Out"
),
"Output(Out) should not be null"
);
...
@@ -92,7 +91,6 @@ class TransposeOpGrad : public framework::OperatorWithKernel {
...
@@ -92,7 +91,6 @@ class TransposeOpGrad : public framework::OperatorWithKernel {
public:
public:
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"X"
),
"Input(X) should not be null"
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"X"
),
"Input(X) should not be null"
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
framework
::
GradVarName
(
"Out"
)),
PADDLE_ENFORCE
(
ctx
->
HasInput
(
framework
::
GradVarName
(
"Out"
)),
...
...
paddle/operators/uniform_random_op.cc
浏览文件 @
73a8b78a
...
@@ -46,7 +46,6 @@ class UniformRandomOp : public framework::OperatorWithKernel {
...
@@ -46,7 +46,6 @@ class UniformRandomOp : public framework::OperatorWithKernel {
public:
public:
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasOutput
(
"Out"
),
PADDLE_ENFORCE
(
ctx
->
HasOutput
(
"Out"
),
"Output(Out) of UniformRandomOp should not be null."
);
"Output(Out) of UniformRandomOp should not be null."
);
...
@@ -63,6 +62,7 @@ class UniformRandomOp : public framework::OperatorWithKernel {
...
@@ -63,6 +62,7 @@ class UniformRandomOp : public framework::OperatorWithKernel {
ctx
->
SetOutputDim
(
"Out"
,
framework
::
make_ddim
(
temp
));
ctx
->
SetOutputDim
(
"Out"
,
framework
::
make_ddim
(
temp
));
}
}
protected:
framework
::
DataType
IndicateDataType
(
framework
::
DataType
IndicateDataType
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
return
static_cast
<
framework
::
DataType
>
(
Attr
<
int
>
(
"data_type"
));
return
static_cast
<
framework
::
DataType
>
(
Attr
<
int
>
(
"data_type"
));
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录