Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
Crayon鑫
Paddle
提交
73a8b78a
P
Paddle
项目概览
Crayon鑫
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
73a8b78a
编写于
10月 16, 2017
作者:
Y
Yu Yang
提交者:
GitHub
10月 16, 2017
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Correct OpWithKernel's infershape (#4847)
They are public now
上级
fafc54d6
变更
53
显示空白变更内容
内联
并排
Showing
53 changed file
with
15 addition
and
90 deletion
+15
-90
paddle/operators/accuracy_op.cc
paddle/operators/accuracy_op.cc
+0
-1
paddle/operators/activation_op.cc
paddle/operators/activation_op.cc
+0
-2
paddle/operators/adadelta_op.cc
paddle/operators/adadelta_op.cc
+0
-1
paddle/operators/adagrad_op.cc
paddle/operators/adagrad_op.cc
+0
-1
paddle/operators/adam_op.cc
paddle/operators/adam_op.cc
+0
-1
paddle/operators/adamax_op.cc
paddle/operators/adamax_op.cc
+0
-1
paddle/operators/clip_op.cc
paddle/operators/clip_op.cc
+0
-2
paddle/operators/concat_op.cc
paddle/operators/concat_op.cc
+0
-2
paddle/operators/conv2d_op.h
paddle/operators/conv2d_op.h
+0
-2
paddle/operators/conv_shift_op.cc
paddle/operators/conv_shift_op.cc
+0
-2
paddle/operators/cos_sim_op.cc
paddle/operators/cos_sim_op.cc
+0
-2
paddle/operators/crop_op.cc
paddle/operators/crop_op.cc
+0
-2
paddle/operators/cross_entropy_op.cc
paddle/operators/cross_entropy_op.cc
+2
-2
paddle/operators/decayed_adagrad_op.cc
paddle/operators/decayed_adagrad_op.cc
+0
-1
paddle/operators/dropout_op.cc
paddle/operators/dropout_op.cc
+0
-2
paddle/operators/elementwise_op.h
paddle/operators/elementwise_op.h
+0
-2
paddle/operators/fill_constant_op.cc
paddle/operators/fill_constant_op.cc
+1
-1
paddle/operators/fill_zeros_like_op.cc
paddle/operators/fill_zeros_like_op.cc
+0
-1
paddle/operators/gather_op.cc
paddle/operators/gather_op.cc
+2
-2
paddle/operators/gaussian_random_op.cc
paddle/operators/gaussian_random_op.cc
+1
-1
paddle/operators/gru_unit_op.cc
paddle/operators/gru_unit_op.cc
+0
-2
paddle/operators/lookup_table_op.cc
paddle/operators/lookup_table_op.cc
+2
-2
paddle/operators/lstm_unit_op.cc
paddle/operators/lstm_unit_op.cc
+0
-2
paddle/operators/margin_rank_loss_op.cc
paddle/operators/margin_rank_loss_op.cc
+0
-2
paddle/operators/mean_op.cc
paddle/operators/mean_op.cc
+0
-2
paddle/operators/minus_op.cc
paddle/operators/minus_op.cc
+0
-1
paddle/operators/modified_huber_loss_op.cc
paddle/operators/modified_huber_loss_op.cc
+0
-2
paddle/operators/mul_op.cc
paddle/operators/mul_op.cc
+0
-2
paddle/operators/multiplex_op.cc
paddle/operators/multiplex_op.cc
+2
-2
paddle/operators/pad_op.cc
paddle/operators/pad_op.cc
+0
-2
paddle/operators/pool_op.h
paddle/operators/pool_op.h
+0
-2
paddle/operators/pool_with_index_op.cc
paddle/operators/pool_with_index_op.cc
+0
-2
paddle/operators/prelu_op.cc
paddle/operators/prelu_op.cc
+0
-2
paddle/operators/rank_loss_op.cc
paddle/operators/rank_loss_op.cc
+0
-2
paddle/operators/reduce_op.cc
paddle/operators/reduce_op.cc
+0
-2
paddle/operators/reshape_op.cc
paddle/operators/reshape_op.cc
+0
-2
paddle/operators/rmsprop_op.cc
paddle/operators/rmsprop_op.cc
+0
-1
paddle/operators/scale_op.cc
paddle/operators/scale_op.cc
+0
-2
paddle/operators/scatter_op.cc
paddle/operators/scatter_op.cc
+2
-2
paddle/operators/sequence_concat_op.cc
paddle/operators/sequence_concat_op.cc
+0
-2
paddle/operators/sequence_pool_op.cc
paddle/operators/sequence_pool_op.cc
+0
-2
paddle/operators/sequence_softmax_op.cc
paddle/operators/sequence_softmax_op.cc
+0
-2
paddle/operators/sgd_op.cc
paddle/operators/sgd_op.cc
+0
-1
paddle/operators/sigmoid_cross_entropy_with_logits_op.cc
paddle/operators/sigmoid_cross_entropy_with_logits_op.cc
+0
-2
paddle/operators/smooth_l1_loss_op.cc
paddle/operators/smooth_l1_loss_op.cc
+0
-2
paddle/operators/softmax_op.cc
paddle/operators/softmax_op.cc
+0
-2
paddle/operators/softmax_with_cross_entropy_op.cc
paddle/operators/softmax_with_cross_entropy_op.cc
+2
-2
paddle/operators/split_op.cc
paddle/operators/split_op.cc
+0
-1
paddle/operators/squared_l2_distance_op.cc
paddle/operators/squared_l2_distance_op.cc
+0
-2
paddle/operators/sum_op.cc
paddle/operators/sum_op.cc
+0
-1
paddle/operators/top_k_op.cc
paddle/operators/top_k_op.cc
+0
-1
paddle/operators/transpose_op.cc
paddle/operators/transpose_op.cc
+0
-2
paddle/operators/uniform_random_op.cc
paddle/operators/uniform_random_op.cc
+1
-1
未找到文件。
paddle/operators/accuracy_op.cc
浏览文件 @
73a8b78a
...
@@ -21,7 +21,6 @@ class AccuracyOp : public framework::OperatorWithKernel {
...
@@ -21,7 +21,6 @@ class AccuracyOp : public framework::OperatorWithKernel {
public:
public:
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Inference"
),
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Inference"
),
"Input(Inference) of AccuracyOp should not be null."
);
"Input(Inference) of AccuracyOp should not be null."
);
...
...
paddle/operators/activation_op.cc
浏览文件 @
73a8b78a
...
@@ -21,7 +21,6 @@ class ActivationOp : public framework::OperatorWithKernel {
...
@@ -21,7 +21,6 @@ class ActivationOp : public framework::OperatorWithKernel {
public:
public:
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
ctx
->
SetOutputDim
(
"Y"
,
ctx
->
GetInputDim
(
"X"
));
ctx
->
SetOutputDim
(
"Y"
,
ctx
->
GetInputDim
(
"X"
));
ctx
->
ShareLoD
(
"X"
,
/*->*/
"Y"
);
ctx
->
ShareLoD
(
"X"
,
/*->*/
"Y"
);
...
@@ -32,7 +31,6 @@ class ActivationOpGrad : public framework::OperatorWithKernel {
...
@@ -32,7 +31,6 @@ class ActivationOpGrad : public framework::OperatorWithKernel {
public:
public:
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
ctx
->
SetOutputDim
(
framework
::
GradVarName
(
"X"
),
ctx
->
GetInputDim
(
"Y"
));
ctx
->
SetOutputDim
(
framework
::
GradVarName
(
"X"
),
ctx
->
GetInputDim
(
"Y"
));
}
}
...
...
paddle/operators/adadelta_op.cc
浏览文件 @
73a8b78a
...
@@ -21,7 +21,6 @@ class AdadeltaOp : public framework::OperatorWithKernel {
...
@@ -21,7 +21,6 @@ class AdadeltaOp : public framework::OperatorWithKernel {
public:
public:
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Param"
),
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Param"
),
"Input(Param) of AdadeltaOp should not be null."
);
"Input(Param) of AdadeltaOp should not be null."
);
...
...
paddle/operators/adagrad_op.cc
浏览文件 @
73a8b78a
...
@@ -21,7 +21,6 @@ class AdagradOp : public framework::OperatorWithKernel {
...
@@ -21,7 +21,6 @@ class AdagradOp : public framework::OperatorWithKernel {
public:
public:
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Param"
),
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Param"
),
"Input(Param) of AdagradOp should not be null."
);
"Input(Param) of AdagradOp should not be null."
);
...
...
paddle/operators/adam_op.cc
浏览文件 @
73a8b78a
...
@@ -21,7 +21,6 @@ class AdamOp : public framework::OperatorWithKernel {
...
@@ -21,7 +21,6 @@ class AdamOp : public framework::OperatorWithKernel {
public:
public:
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Param"
),
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Param"
),
"Input(Param) of AdamOp should not be null."
);
"Input(Param) of AdamOp should not be null."
);
...
...
paddle/operators/adamax_op.cc
浏览文件 @
73a8b78a
...
@@ -21,7 +21,6 @@ class AdamaxOp : public framework::OperatorWithKernel {
...
@@ -21,7 +21,6 @@ class AdamaxOp : public framework::OperatorWithKernel {
public:
public:
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Param"
),
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Param"
),
"Input(Param) of AdamaxOp should not be null."
);
"Input(Param) of AdamaxOp should not be null."
);
...
...
paddle/operators/clip_op.cc
浏览文件 @
73a8b78a
...
@@ -21,7 +21,6 @@ class ClipOp : public framework::OperatorWithKernel {
...
@@ -21,7 +21,6 @@ class ClipOp : public framework::OperatorWithKernel {
public:
public:
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"X"
),
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"X"
),
"Input(X) of ClipOp should not be null."
);
"Input(X) of ClipOp should not be null."
);
...
@@ -60,7 +59,6 @@ class ClipOpGrad : public framework::OperatorWithKernel {
...
@@ -60,7 +59,6 @@ class ClipOpGrad : public framework::OperatorWithKernel {
public:
public:
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"X"
),
"Input(X) should not be null"
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"X"
),
"Input(X) should not be null"
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
framework
::
GradVarName
(
"Out"
)),
PADDLE_ENFORCE
(
ctx
->
HasInput
(
framework
::
GradVarName
(
"Out"
)),
...
...
paddle/operators/concat_op.cc
浏览文件 @
73a8b78a
...
@@ -23,7 +23,6 @@ class ConcatOp : public framework::OperatorWithKernel {
...
@@ -23,7 +23,6 @@ class ConcatOp : public framework::OperatorWithKernel {
public:
public:
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE_GE
(
ctx
->
Inputs
(
"X"
).
size
(),
1UL
,
PADDLE_ENFORCE_GE
(
ctx
->
Inputs
(
"X"
).
size
(),
1UL
,
"Inputs(X) of ConcatOp should be empty."
)
"Inputs(X) of ConcatOp should be empty."
)
...
@@ -82,7 +81,6 @@ class ConcatOpGrad : public framework::OperatorWithKernel {
...
@@ -82,7 +81,6 @@ class ConcatOpGrad : public framework::OperatorWithKernel {
const
framework
::
AttributeMap
&
attrs
)
const
framework
::
AttributeMap
&
attrs
)
:
OperatorWithKernel
(
type
,
inputs
,
outputs
,
attrs
)
{}
:
OperatorWithKernel
(
type
,
inputs
,
outputs
,
attrs
)
{}
protected:
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
ctx
->
SetOutputsDim
(
framework
::
GradVarName
(
"X"
),
ctx
->
GetInputsDim
(
"X"
));
ctx
->
SetOutputsDim
(
framework
::
GradVarName
(
"X"
),
ctx
->
GetInputsDim
(
"X"
));
}
}
...
...
paddle/operators/conv2d_op.h
浏览文件 @
73a8b78a
...
@@ -44,7 +44,6 @@ class Conv2DOp : public framework::OperatorWithKernel {
...
@@ -44,7 +44,6 @@ class Conv2DOp : public framework::OperatorWithKernel {
public:
public:
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
;
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
;
};
};
...
@@ -52,7 +51,6 @@ class Conv2DOpGrad : public framework::OperatorWithKernel {
...
@@ -52,7 +51,6 @@ class Conv2DOpGrad : public framework::OperatorWithKernel {
public:
public:
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
;
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
;
};
};
...
...
paddle/operators/conv_shift_op.cc
浏览文件 @
73a8b78a
...
@@ -27,7 +27,6 @@ class ConvShiftOp : public framework::OperatorWithKernel {
...
@@ -27,7 +27,6 @@ class ConvShiftOp : public framework::OperatorWithKernel {
public:
public:
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"X"
),
"Input(X) should be not null."
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"X"
),
"Input(X) should be not null."
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Y"
),
"Input(Y) should be not null."
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Y"
),
"Input(Y) should be not null."
);
...
@@ -54,7 +53,6 @@ class ConvShiftGradOp : public framework::OperatorWithKernel {
...
@@ -54,7 +53,6 @@ class ConvShiftGradOp : public framework::OperatorWithKernel {
public:
public:
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"X"
),
"Input(X) should be not null."
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"X"
),
"Input(X) should be not null."
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Y"
),
"Input(Y) should be not null."
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Y"
),
"Input(Y) should be not null."
);
...
...
paddle/operators/cos_sim_op.cc
浏览文件 @
73a8b78a
...
@@ -23,7 +23,6 @@ class CosSimOp : public framework::OperatorWithKernel {
...
@@ -23,7 +23,6 @@ class CosSimOp : public framework::OperatorWithKernel {
public:
public:
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
// notnull check
// notnull check
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"X"
),
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"X"
),
...
@@ -97,7 +96,6 @@ class CosSimOpGrad : public framework::OperatorWithKernel {
...
@@ -97,7 +96,6 @@ class CosSimOpGrad : public framework::OperatorWithKernel {
public:
public:
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
// notnull check
// notnull check
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"X"
),
"Input(X) must not be null."
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"X"
),
"Input(X) must not be null."
);
...
...
paddle/operators/crop_op.cc
浏览文件 @
73a8b78a
...
@@ -24,7 +24,6 @@ class CropOp : public framework::OperatorWithKernel {
...
@@ -24,7 +24,6 @@ class CropOp : public framework::OperatorWithKernel {
public:
public:
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"X"
),
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"X"
),
"Input(X) of CropOp should not be null."
);
"Input(X) of CropOp should not be null."
);
...
@@ -114,7 +113,6 @@ class CropOpGrad : public framework::OperatorWithKernel {
...
@@ -114,7 +113,6 @@ class CropOpGrad : public framework::OperatorWithKernel {
public:
public:
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"X"
),
"Input(X) should not be null"
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"X"
),
"Input(X) should not be null"
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
framework
::
GradVarName
(
"Out"
)),
PADDLE_ENFORCE
(
ctx
->
HasInput
(
framework
::
GradVarName
(
"Out"
)),
...
...
paddle/operators/cross_entropy_op.cc
浏览文件 @
73a8b78a
...
@@ -21,7 +21,6 @@ class CrossEntropyOp : public framework::OperatorWithKernel {
...
@@ -21,7 +21,6 @@ class CrossEntropyOp : public framework::OperatorWithKernel {
public:
public:
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"X"
),
"Input(X) should be not null."
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"X"
),
"Input(X) should be not null."
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Label"
),
"Input(Label) should be not null."
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Label"
),
"Input(Label) should be not null."
);
...
@@ -48,6 +47,7 @@ class CrossEntropyOp : public framework::OperatorWithKernel {
...
@@ -48,6 +47,7 @@ class CrossEntropyOp : public framework::OperatorWithKernel {
ctx
->
ShareLoD
(
"X"
,
/*->*/
"Y"
);
ctx
->
ShareLoD
(
"X"
,
/*->*/
"Y"
);
}
}
protected:
// CrossEntropy's data type just determined by "X"
// CrossEntropy's data type just determined by "X"
framework
::
DataType
IndicateDataType
(
framework
::
DataType
IndicateDataType
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
...
@@ -59,7 +59,6 @@ class CrossEntropyGradientOp : public framework::OperatorWithKernel {
...
@@ -59,7 +59,6 @@ class CrossEntropyGradientOp : public framework::OperatorWithKernel {
public:
public:
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"X"
),
"Input(X) should be not null."
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"X"
),
"Input(X) should be not null."
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Label"
),
"Input(Label) should be not null."
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Label"
),
"Input(Label) should be not null."
);
...
@@ -94,6 +93,7 @@ class CrossEntropyGradientOp : public framework::OperatorWithKernel {
...
@@ -94,6 +93,7 @@ class CrossEntropyGradientOp : public framework::OperatorWithKernel {
ctx
->
SetOutputDim
(
framework
::
GradVarName
(
"X"
),
x_dims
);
ctx
->
SetOutputDim
(
framework
::
GradVarName
(
"X"
),
x_dims
);
}
}
protected:
// CrossEntropy's data type just determined by "X"
// CrossEntropy's data type just determined by "X"
framework
::
DataType
IndicateDataType
(
framework
::
DataType
IndicateDataType
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
...
...
paddle/operators/decayed_adagrad_op.cc
浏览文件 @
73a8b78a
...
@@ -21,7 +21,6 @@ class DecayedAdagradOp : public framework::OperatorWithKernel {
...
@@ -21,7 +21,6 @@ class DecayedAdagradOp : public framework::OperatorWithKernel {
public:
public:
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Param"
),
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Param"
),
"Input(Param) of DecayedAdagradOp should not be null."
);
"Input(Param) of DecayedAdagradOp should not be null."
);
...
...
paddle/operators/dropout_op.cc
浏览文件 @
73a8b78a
...
@@ -23,7 +23,6 @@ class DropoutOp : public framework::OperatorWithKernel {
...
@@ -23,7 +23,6 @@ class DropoutOp : public framework::OperatorWithKernel {
public:
public:
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"X"
),
"Input(X) must not be null."
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"X"
),
"Input(X) must not be null."
);
PADDLE_ENFORCE_GE
(
ctx
->
Attrs
().
Get
<
float
>
(
"dropout_prob"
),
0
);
PADDLE_ENFORCE_GE
(
ctx
->
Attrs
().
Get
<
float
>
(
"dropout_prob"
),
0
);
...
@@ -69,7 +68,6 @@ class DropoutOpGrad : public framework::OperatorWithKernel {
...
@@ -69,7 +68,6 @@ class DropoutOpGrad : public framework::OperatorWithKernel {
public:
public:
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE_EQ
(
ctx
->
Attrs
().
Get
<
bool
>
(
"is_training"
),
1
,
PADDLE_ENFORCE_EQ
(
ctx
->
Attrs
().
Get
<
bool
>
(
"is_training"
),
1
,
"GradOp is only callable when is_training is true"
);
"GradOp is only callable when is_training is true"
);
...
...
paddle/operators/elementwise_op.h
浏览文件 @
73a8b78a
...
@@ -23,7 +23,6 @@ class ElementwiseOp : public framework::OperatorWithKernel {
...
@@ -23,7 +23,6 @@ class ElementwiseOp : public framework::OperatorWithKernel {
public:
public:
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
using
Tensor
=
framework
::
Tensor
;
using
Tensor
=
framework
::
Tensor
;
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"X"
),
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"X"
),
...
@@ -105,7 +104,6 @@ class ElementwiseOpGrad : public framework::OperatorWithKernel {
...
@@ -105,7 +104,6 @@ class ElementwiseOpGrad : public framework::OperatorWithKernel {
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
using
Tensor
=
framework
::
Tensor
;
using
Tensor
=
framework
::
Tensor
;
protected:
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"X"
),
"Input(X) should not be null"
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"X"
),
"Input(X) should not be null"
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Y"
),
"Input(Y) should not be null"
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Y"
),
"Input(Y) should not be null"
);
...
...
paddle/operators/fill_constant_op.cc
浏览文件 @
73a8b78a
...
@@ -21,7 +21,6 @@ class FillConstantOp : public framework::OperatorWithKernel {
...
@@ -21,7 +21,6 @@ class FillConstantOp : public framework::OperatorWithKernel {
public:
public:
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasOutput
(
"Out"
),
PADDLE_ENFORCE
(
ctx
->
HasOutput
(
"Out"
),
"Output(Out) of FillConstantOp should not be null."
);
"Output(Out) of FillConstantOp should not be null."
);
...
@@ -33,6 +32,7 @@ class FillConstantOp : public framework::OperatorWithKernel {
...
@@ -33,6 +32,7 @@ class FillConstantOp : public framework::OperatorWithKernel {
ctx
->
SetOutputDim
(
"Out"
,
dims
);
ctx
->
SetOutputDim
(
"Out"
,
dims
);
}
}
protected:
framework
::
DataType
IndicateDataType
(
framework
::
DataType
IndicateDataType
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
return
static_cast
<
framework
::
DataType
>
(
ctx
.
Attr
<
int
>
(
"data_type"
));
return
static_cast
<
framework
::
DataType
>
(
ctx
.
Attr
<
int
>
(
"data_type"
));
...
...
paddle/operators/fill_zeros_like_op.cc
浏览文件 @
73a8b78a
...
@@ -21,7 +21,6 @@ class FillZerosLikeOp : public framework::OperatorWithKernel {
...
@@ -21,7 +21,6 @@ class FillZerosLikeOp : public framework::OperatorWithKernel {
public:
public:
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"X"
),
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"X"
),
"Input(X) of FillZerosLikeOp should not be null."
);
"Input(X) of FillZerosLikeOp should not be null."
);
...
...
paddle/operators/gather_op.cc
浏览文件 @
73a8b78a
...
@@ -22,7 +22,6 @@ class GatherOp : public framework::OperatorWithKernel {
...
@@ -22,7 +22,6 @@ class GatherOp : public framework::OperatorWithKernel {
public:
public:
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"X"
),
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"X"
),
"Input(X) of GatherOp should not be null."
);
"Input(X) of GatherOp should not be null."
);
...
@@ -40,6 +39,7 @@ class GatherOp : public framework::OperatorWithKernel {
...
@@ -40,6 +39,7 @@ class GatherOp : public framework::OperatorWithKernel {
ctx
->
SetOutputDim
(
"Out"
,
output_dims
);
ctx
->
SetOutputDim
(
"Out"
,
output_dims
);
}
}
protected:
framework
::
DataType
IndicateDataType
(
framework
::
DataType
IndicateDataType
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
return
framework
::
ToDataType
(
ctx
.
Input
<
Tensor
>
(
"X"
)
->
type
());
return
framework
::
ToDataType
(
ctx
.
Input
<
Tensor
>
(
"X"
)
->
type
());
...
@@ -50,11 +50,11 @@ class GatherGradOp : public framework::OperatorWithKernel {
...
@@ -50,11 +50,11 @@ class GatherGradOp : public framework::OperatorWithKernel {
public:
public:
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
ctx
->
SetOutputDim
(
framework
::
GradVarName
(
"X"
),
ctx
->
GetInputDim
(
"X"
));
ctx
->
SetOutputDim
(
framework
::
GradVarName
(
"X"
),
ctx
->
GetInputDim
(
"X"
));
}
}
protected:
framework
::
DataType
IndicateDataType
(
framework
::
DataType
IndicateDataType
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
return
framework
::
ToDataType
(
ctx
.
Input
<
Tensor
>
(
"X"
)
->
type
());
return
framework
::
ToDataType
(
ctx
.
Input
<
Tensor
>
(
"X"
)
->
type
());
...
...
paddle/operators/gaussian_random_op.cc
浏览文件 @
73a8b78a
...
@@ -42,7 +42,6 @@ class GaussianRandomOp : public framework::OperatorWithKernel {
...
@@ -42,7 +42,6 @@ class GaussianRandomOp : public framework::OperatorWithKernel {
public:
public:
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasOutput
(
"Out"
),
PADDLE_ENFORCE
(
ctx
->
HasOutput
(
"Out"
),
"Output(Out) of GaussianRandomOp should not be null."
);
"Output(Out) of GaussianRandomOp should not be null."
);
...
@@ -57,6 +56,7 @@ class GaussianRandomOp : public framework::OperatorWithKernel {
...
@@ -57,6 +56,7 @@ class GaussianRandomOp : public framework::OperatorWithKernel {
ctx
->
SetOutputDim
(
"Out"
,
framework
::
make_ddim
(
temp
));
ctx
->
SetOutputDim
(
"Out"
,
framework
::
make_ddim
(
temp
));
}
}
protected:
framework
::
DataType
IndicateDataType
(
framework
::
DataType
IndicateDataType
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
return
static_cast
<
framework
::
DataType
>
(
Attr
<
int
>
(
"data_type"
));
return
static_cast
<
framework
::
DataType
>
(
Attr
<
int
>
(
"data_type"
));
...
...
paddle/operators/gru_unit_op.cc
浏览文件 @
73a8b78a
...
@@ -23,7 +23,6 @@ class GRUUnitOp : public framework::OperatorWithKernel {
...
@@ -23,7 +23,6 @@ class GRUUnitOp : public framework::OperatorWithKernel {
public:
public:
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Input"
),
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Input"
),
"Input(%s) of GRUUnitOp should not be null."
,
"Input"
);
"Input(%s) of GRUUnitOp should not be null."
,
"Input"
);
...
@@ -131,7 +130,6 @@ class GRUUnitGradOp : public framework::OperatorWithKernel {
...
@@ -131,7 +130,6 @@ class GRUUnitGradOp : public framework::OperatorWithKernel {
public:
public:
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Input"
),
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Input"
),
"Input(%s) of GRUUnitGradOp should not be null."
,
"Input"
);
"Input(%s) of GRUUnitGradOp should not be null."
,
"Input"
);
...
...
paddle/operators/lookup_table_op.cc
浏览文件 @
73a8b78a
...
@@ -21,7 +21,6 @@ class LookupTableOp : public framework::OperatorWithKernel {
...
@@ -21,7 +21,6 @@ class LookupTableOp : public framework::OperatorWithKernel {
public:
public:
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"W"
),
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"W"
),
"Input(W) of LookupTableOp should not be null."
);
"Input(W) of LookupTableOp should not be null."
);
...
@@ -37,6 +36,7 @@ class LookupTableOp : public framework::OperatorWithKernel {
...
@@ -37,6 +36,7 @@ class LookupTableOp : public framework::OperatorWithKernel {
ctx
->
ShareLoD
(
"Ids"
,
/*->*/
"Out"
);
ctx
->
ShareLoD
(
"Ids"
,
/*->*/
"Out"
);
}
}
protected:
framework
::
DataType
IndicateDataType
(
framework
::
DataType
IndicateDataType
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
return
framework
::
ToDataType
(
ctx
.
Input
<
Tensor
>
(
"W"
)
->
type
());
return
framework
::
ToDataType
(
ctx
.
Input
<
Tensor
>
(
"W"
)
->
type
());
...
@@ -69,12 +69,12 @@ class LookupTableOpGrad : public framework::OperatorWithKernel {
...
@@ -69,12 +69,12 @@ class LookupTableOpGrad : public framework::OperatorWithKernel {
public:
public:
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
auto
table_dims
=
ctx
->
GetInputDim
(
"W"
);
auto
table_dims
=
ctx
->
GetInputDim
(
"W"
);
ctx
->
SetOutputDim
(
framework
::
GradVarName
(
"W"
),
table_dims
);
ctx
->
SetOutputDim
(
framework
::
GradVarName
(
"W"
),
table_dims
);
}
}
protected:
framework
::
DataType
IndicateDataType
(
framework
::
DataType
IndicateDataType
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
return
framework
::
ToDataType
(
ctx
.
Input
<
Tensor
>
(
"W"
)
->
type
());
return
framework
::
ToDataType
(
ctx
.
Input
<
Tensor
>
(
"W"
)
->
type
());
...
...
paddle/operators/lstm_unit_op.cc
浏览文件 @
73a8b78a
...
@@ -21,7 +21,6 @@ class LstmUnitOp : public framework::OperatorWithKernel {
...
@@ -21,7 +21,6 @@ class LstmUnitOp : public framework::OperatorWithKernel {
public:
public:
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"X"
),
"Input(X) of LSTM should not be null."
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"X"
),
"Input(X) of LSTM should not be null."
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"C_prev"
),
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"C_prev"
),
...
@@ -76,7 +75,6 @@ class LstmUnitGradOp : public framework::OperatorWithKernel {
...
@@ -76,7 +75,6 @@ class LstmUnitGradOp : public framework::OperatorWithKernel {
public:
public:
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
framework
::
GradVarName
(
"C"
)),
PADDLE_ENFORCE
(
ctx
->
HasInput
(
framework
::
GradVarName
(
"C"
)),
"Input(C@GRAD) should not be null"
);
"Input(C@GRAD) should not be null"
);
...
...
paddle/operators/margin_rank_loss_op.cc
浏览文件 @
73a8b78a
...
@@ -21,7 +21,6 @@ class MarginRankLossOp : public framework::OperatorWithKernel {
...
@@ -21,7 +21,6 @@ class MarginRankLossOp : public framework::OperatorWithKernel {
public:
public:
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
// input check
// input check
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Label"
),
"Input(Label) shouldn't be null."
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Label"
),
"Input(Label) shouldn't be null."
);
...
@@ -94,7 +93,6 @@ class MarginRankLossGradOp : public framework::OperatorWithKernel {
...
@@ -94,7 +93,6 @@ class MarginRankLossGradOp : public framework::OperatorWithKernel {
public:
public:
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Label"
),
"Input(Label) shouldn't be null."
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Label"
),
"Input(Label) shouldn't be null."
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"X1"
),
"Input(X1) shouldn't be null."
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"X1"
),
"Input(X1) shouldn't be null."
);
...
...
paddle/operators/mean_op.cc
浏览文件 @
73a8b78a
...
@@ -21,7 +21,6 @@ class MeanOp : public framework::OperatorWithKernel {
...
@@ -21,7 +21,6 @@ class MeanOp : public framework::OperatorWithKernel {
public:
public:
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"X"
),
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"X"
),
"Input(X) of MeanOp should not be null."
);
"Input(X) of MeanOp should not be null."
);
...
@@ -46,7 +45,6 @@ class MeanGradOp : public framework::OperatorWithKernel {
...
@@ -46,7 +45,6 @@ class MeanGradOp : public framework::OperatorWithKernel {
public:
public:
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
ctx
->
SetOutputDim
(
framework
::
GradVarName
(
"X"
),
ctx
->
GetInputDim
(
"X"
));
ctx
->
SetOutputDim
(
framework
::
GradVarName
(
"X"
),
ctx
->
GetInputDim
(
"X"
));
}
}
...
...
paddle/operators/minus_op.cc
浏览文件 @
73a8b78a
...
@@ -25,7 +25,6 @@ class MinusOp : public framework::OperatorWithKernel {
...
@@ -25,7 +25,6 @@ class MinusOp : public framework::OperatorWithKernel {
const
framework
::
AttributeMap
&
attrs
)
const
framework
::
AttributeMap
&
attrs
)
:
OperatorWithKernel
(
type
,
inputs
,
outputs
,
attrs
)
{}
:
OperatorWithKernel
(
type
,
inputs
,
outputs
,
attrs
)
{}
protected:
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"X"
),
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"X"
),
"Input(X) of MinusOp should not be null."
);
"Input(X) of MinusOp should not be null."
);
...
...
paddle/operators/modified_huber_loss_op.cc
浏览文件 @
73a8b78a
...
@@ -21,7 +21,6 @@ class ModifiedHuberLossOp : public framework::OperatorWithKernel {
...
@@ -21,7 +21,6 @@ class ModifiedHuberLossOp : public framework::OperatorWithKernel {
public:
public:
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"X"
),
"X must be initialized."
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"X"
),
"X must be initialized."
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Y"
),
"Y must be initialized."
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Y"
),
"Y must be initialized."
);
...
@@ -73,7 +72,6 @@ class ModifiedHuberLossGradOp : public framework::OperatorWithKernel {
...
@@ -73,7 +72,6 @@ class ModifiedHuberLossGradOp : public framework::OperatorWithKernel {
public:
public:
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"X"
),
"X must be initialized."
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"X"
),
"X must be initialized."
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Y"
),
"Y must be initialized."
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Y"
),
"Y must be initialized."
);
...
...
paddle/operators/mul_op.cc
浏览文件 @
73a8b78a
...
@@ -23,7 +23,6 @@ class MulOp : public framework::OperatorWithKernel {
...
@@ -23,7 +23,6 @@ class MulOp : public framework::OperatorWithKernel {
public:
public:
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"X"
),
"Input(X) of MulOp should not be null."
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"X"
),
"Input(X) of MulOp should not be null."
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Y"
),
"Input(Y) of MulOp should not be null."
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Y"
),
"Input(Y) of MulOp should not be null."
);
...
@@ -96,7 +95,6 @@ class MulOpGrad : public framework::OperatorWithKernel {
...
@@ -96,7 +95,6 @@ class MulOpGrad : public framework::OperatorWithKernel {
public:
public:
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"X"
),
"Input(X) should not be null"
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"X"
),
"Input(X) should not be null"
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Y"
),
"Input(Y) should not be null"
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Y"
),
"Input(Y) should not be null"
);
...
...
paddle/operators/multiplex_op.cc
浏览文件 @
73a8b78a
...
@@ -23,7 +23,6 @@ class MultiplexOp : public framework::OperatorWithKernel {
...
@@ -23,7 +23,6 @@ class MultiplexOp : public framework::OperatorWithKernel {
public:
public:
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Ids"
),
"Input(Ids) shouldn't be null."
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Ids"
),
"Input(Ids) shouldn't be null."
);
PADDLE_ENFORCE
(
!
ctx
->
Inputs
(
"X"
).
empty
(),
PADDLE_ENFORCE
(
!
ctx
->
Inputs
(
"X"
).
empty
(),
...
@@ -51,6 +50,7 @@ class MultiplexOp : public framework::OperatorWithKernel {
...
@@ -51,6 +50,7 @@ class MultiplexOp : public framework::OperatorWithKernel {
ctx
->
SetOutputDim
(
"Out"
,
in_dim
);
ctx
->
SetOutputDim
(
"Out"
,
in_dim
);
}
}
protected:
framework
::
DataType
IndicateDataType
(
framework
::
DataType
IndicateDataType
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
return
framework
::
ToDataType
(
ctx
.
MultiInput
<
Tensor
>
(
"X"
)[
0
]
->
type
());
return
framework
::
ToDataType
(
ctx
.
MultiInput
<
Tensor
>
(
"X"
)[
0
]
->
type
());
...
@@ -89,7 +89,6 @@ class MultiplexGradOp : public framework::OperatorWithKernel {
...
@@ -89,7 +89,6 @@ class MultiplexGradOp : public framework::OperatorWithKernel {
public:
public:
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
!
ctx
->
Inputs
(
"X"
).
empty
(),
"Input(X) should not be null."
);
PADDLE_ENFORCE
(
!
ctx
->
Inputs
(
"X"
).
empty
(),
"Input(X) should not be null."
);
PADDLE_ENFORCE
(
!
ctx
->
Outputs
(
framework
::
GradVarName
(
"X"
)).
empty
(),
PADDLE_ENFORCE
(
!
ctx
->
Outputs
(
framework
::
GradVarName
(
"X"
)).
empty
(),
...
@@ -105,6 +104,7 @@ class MultiplexGradOp : public framework::OperatorWithKernel {
...
@@ -105,6 +104,7 @@ class MultiplexGradOp : public framework::OperatorWithKernel {
ctx
->
SetOutputsDim
(
framework
::
GradVarName
(
"X"
),
d_ins
);
ctx
->
SetOutputsDim
(
framework
::
GradVarName
(
"X"
),
d_ins
);
}
}
protected:
framework
::
DataType
IndicateDataType
(
framework
::
DataType
IndicateDataType
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
return
framework
::
ToDataType
(
ctx
.
MultiInput
<
Tensor
>
(
"X"
)[
0
]
->
type
());
return
framework
::
ToDataType
(
ctx
.
MultiInput
<
Tensor
>
(
"X"
)[
0
]
->
type
());
...
...
paddle/operators/pad_op.cc
浏览文件 @
73a8b78a
...
@@ -23,7 +23,6 @@ class PadOp : public framework::OperatorWithKernel {
...
@@ -23,7 +23,6 @@ class PadOp : public framework::OperatorWithKernel {
public:
public:
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"X"
),
"Input(X) of PadOp should not be null."
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"X"
),
"Input(X) of PadOp should not be null."
);
PADDLE_ENFORCE
(
ctx
->
HasOutput
(
"Out"
),
PADDLE_ENFORCE
(
ctx
->
HasOutput
(
"Out"
),
...
@@ -97,7 +96,6 @@ class PadOpGrad : public framework::OperatorWithKernel {
...
@@ -97,7 +96,6 @@ class PadOpGrad : public framework::OperatorWithKernel {
public:
public:
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"X"
),
"Input(X) should not be null"
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"X"
),
"Input(X) should not be null"
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
framework
::
GradVarName
(
"Out"
)),
PADDLE_ENFORCE
(
ctx
->
HasInput
(
framework
::
GradVarName
(
"Out"
)),
...
...
paddle/operators/pool_op.h
浏览文件 @
73a8b78a
...
@@ -28,7 +28,6 @@ class PoolOp : public framework::OperatorWithKernel {
...
@@ -28,7 +28,6 @@ class PoolOp : public framework::OperatorWithKernel {
public:
public:
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
;
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
;
};
};
...
@@ -36,7 +35,6 @@ class PoolOpGrad : public framework::OperatorWithKernel {
...
@@ -36,7 +35,6 @@ class PoolOpGrad : public framework::OperatorWithKernel {
public:
public:
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
;
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
;
};
};
...
...
paddle/operators/pool_with_index_op.cc
浏览文件 @
73a8b78a
...
@@ -27,7 +27,6 @@ class MaxPoolWithIndexOp : public framework::OperatorWithKernel {
...
@@ -27,7 +27,6 @@ class MaxPoolWithIndexOp : public framework::OperatorWithKernel {
public:
public:
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"X"
),
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"X"
),
"X(Input) of Pooling should not be null."
);
"X(Input) of Pooling should not be null."
);
...
@@ -72,7 +71,6 @@ class MaxPoolWithIndexOpGrad : public framework::OperatorWithKernel {
...
@@ -72,7 +71,6 @@ class MaxPoolWithIndexOpGrad : public framework::OperatorWithKernel {
public:
public:
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Mask"
),
"Input(Mask) must not be null."
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Mask"
),
"Input(Mask) must not be null."
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"X"
),
"Input(X) must not be null."
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"X"
),
"Input(X) must not be null."
);
...
...
paddle/operators/prelu_op.cc
浏览文件 @
73a8b78a
...
@@ -25,7 +25,6 @@ class PReluOp : public framework::OperatorWithKernel {
...
@@ -25,7 +25,6 @@ class PReluOp : public framework::OperatorWithKernel {
const
framework
::
AttributeMap
&
attrs
)
const
framework
::
AttributeMap
&
attrs
)
:
OperatorWithKernel
(
type
,
inputs
,
outputs
,
attrs
)
{}
:
OperatorWithKernel
(
type
,
inputs
,
outputs
,
attrs
)
{}
protected:
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"X"
),
"Input(X) should not be null"
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"X"
),
"Input(X) should not be null"
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Alpha"
),
"Input(Alpha) should not be null"
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Alpha"
),
"Input(Alpha) should not be null"
);
...
@@ -62,7 +61,6 @@ class PReluGradOp : public framework::OperatorWithKernel {
...
@@ -62,7 +61,6 @@ class PReluGradOp : public framework::OperatorWithKernel {
public:
public:
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"X"
),
"Input(X) must not be null."
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"X"
),
"Input(X) must not be null."
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
framework
::
GradVarName
(
"Out"
)),
PADDLE_ENFORCE
(
ctx
->
HasInput
(
framework
::
GradVarName
(
"Out"
)),
...
...
paddle/operators/rank_loss_op.cc
浏览文件 @
73a8b78a
...
@@ -24,7 +24,6 @@ class RankLossOp : public framework::OperatorWithKernel {
...
@@ -24,7 +24,6 @@ class RankLossOp : public framework::OperatorWithKernel {
const
framework
::
AttributeMap
&
attrs
)
const
framework
::
AttributeMap
&
attrs
)
:
OperatorWithKernel
(
type
,
inputs
,
outputs
,
attrs
)
{}
:
OperatorWithKernel
(
type
,
inputs
,
outputs
,
attrs
)
{}
protected:
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
// input check
// input check
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Label"
),
"Input(Label) shouldn't be null"
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Label"
),
"Input(Label) shouldn't be null"
);
...
@@ -89,7 +88,6 @@ class RankLossGradOp : public framework::OperatorWithKernel {
...
@@ -89,7 +88,6 @@ class RankLossGradOp : public framework::OperatorWithKernel {
const
framework
::
AttributeMap
&
attrs
)
const
framework
::
AttributeMap
&
attrs
)
:
OperatorWithKernel
(
type
,
inputs
,
outputs
,
attrs
)
{}
:
OperatorWithKernel
(
type
,
inputs
,
outputs
,
attrs
)
{}
protected:
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Label"
),
"Input(Label) shouldn't be null."
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Label"
),
"Input(Label) shouldn't be null."
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Left"
),
"Input(Left) shouldn't be null."
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Left"
),
"Input(Left) shouldn't be null."
);
...
...
paddle/operators/reduce_op.cc
浏览文件 @
73a8b78a
...
@@ -23,7 +23,6 @@ class ReduceOp : public framework::OperatorWithKernel {
...
@@ -23,7 +23,6 @@ class ReduceOp : public framework::OperatorWithKernel {
public:
public:
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"X"
),
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"X"
),
"Input(X) of ReduceOp should not be null."
);
"Input(X) of ReduceOp should not be null."
);
...
@@ -57,7 +56,6 @@ class ReduceGradOp : public framework::OperatorWithKernel {
...
@@ -57,7 +56,6 @@ class ReduceGradOp : public framework::OperatorWithKernel {
public:
public:
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"X"
),
"Input(X) should not be null."
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"X"
),
"Input(X) should not be null."
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
framework
::
GradVarName
(
"Out"
)),
PADDLE_ENFORCE
(
ctx
->
HasInput
(
framework
::
GradVarName
(
"Out"
)),
...
...
paddle/operators/reshape_op.cc
浏览文件 @
73a8b78a
...
@@ -25,7 +25,6 @@ class ReshapeOp : public framework::OperatorWithKernel {
...
@@ -25,7 +25,6 @@ class ReshapeOp : public framework::OperatorWithKernel {
const
framework
::
AttributeMap
&
attrs
)
const
framework
::
AttributeMap
&
attrs
)
:
OperatorWithKernel
(
type
,
inputs
,
outputs
,
attrs
)
{}
:
OperatorWithKernel
(
type
,
inputs
,
outputs
,
attrs
)
{}
protected:
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
// input check
// input check
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"X"
),
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"X"
),
...
@@ -93,7 +92,6 @@ class ReshapeGradOp : public framework::OperatorWithKernel {
...
@@ -93,7 +92,6 @@ class ReshapeGradOp : public framework::OperatorWithKernel {
const
framework
::
AttributeMap
&
attrs
)
const
framework
::
AttributeMap
&
attrs
)
:
OperatorWithKernel
(
type
,
inputs
,
outputs
,
attrs
)
{}
:
OperatorWithKernel
(
type
,
inputs
,
outputs
,
attrs
)
{}
protected:
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"X"
),
"Input(X) shouldn't be null."
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"X"
),
"Input(X) shouldn't be null."
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
framework
::
GradVarName
(
"Out"
)),
PADDLE_ENFORCE
(
ctx
->
HasInput
(
framework
::
GradVarName
(
"Out"
)),
...
...
paddle/operators/rmsprop_op.cc
浏览文件 @
73a8b78a
...
@@ -21,7 +21,6 @@ class RmspropOp : public framework::OperatorWithKernel {
...
@@ -21,7 +21,6 @@ class RmspropOp : public framework::OperatorWithKernel {
public:
public:
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Param"
),
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Param"
),
"Input(Param) of RmspropOp should not be null."
);
"Input(Param) of RmspropOp should not be null."
);
...
...
paddle/operators/scale_op.cc
浏览文件 @
73a8b78a
...
@@ -25,7 +25,6 @@ class ScaleOp : public framework::OperatorWithKernel {
...
@@ -25,7 +25,6 @@ class ScaleOp : public framework::OperatorWithKernel {
const
framework
::
AttributeMap
&
attrs
)
const
framework
::
AttributeMap
&
attrs
)
:
OperatorWithKernel
(
type
,
inputs
,
outputs
,
attrs
)
{}
:
OperatorWithKernel
(
type
,
inputs
,
outputs
,
attrs
)
{}
protected:
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"X"
),
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"X"
),
"Input(X) of ScaleOp should not be null."
);
"Input(X) of ScaleOp should not be null."
);
...
@@ -56,7 +55,6 @@ class ScaleGradMaker : public framework::SingleGradOpDescMaker {
...
@@ -56,7 +55,6 @@ class ScaleGradMaker : public framework::SingleGradOpDescMaker {
public:
public:
using
framework
::
SingleGradOpDescMaker
::
SingleGradOpDescMaker
;
using
framework
::
SingleGradOpDescMaker
::
SingleGradOpDescMaker
;
protected:
std
::
unique_ptr
<
framework
::
OpDescBind
>
Apply
()
const
override
{
std
::
unique_ptr
<
framework
::
OpDescBind
>
Apply
()
const
override
{
auto
*
grad_op
=
new
framework
::
OpDescBind
();
auto
*
grad_op
=
new
framework
::
OpDescBind
();
grad_op
->
SetType
(
"scale"
);
grad_op
->
SetType
(
"scale"
);
...
...
paddle/operators/scatter_op.cc
浏览文件 @
73a8b78a
...
@@ -22,7 +22,6 @@ class ScatterOp : public framework::OperatorWithKernel {
...
@@ -22,7 +22,6 @@ class ScatterOp : public framework::OperatorWithKernel {
public:
public:
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Ref"
),
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Ref"
),
"Input(Ref) of ScatterOp should not be null."
);
"Input(Ref) of ScatterOp should not be null."
);
...
@@ -49,6 +48,7 @@ class ScatterOp : public framework::OperatorWithKernel {
...
@@ -49,6 +48,7 @@ class ScatterOp : public framework::OperatorWithKernel {
ctx
->
SetOutputDim
(
"Out"
,
ref_dims
);
ctx
->
SetOutputDim
(
"Out"
,
ref_dims
);
}
}
protected:
framework
::
DataType
IndicateDataType
(
framework
::
DataType
IndicateDataType
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
return
framework
::
ToDataType
(
ctx
.
Input
<
Tensor
>
(
"Ref"
)
->
type
());
return
framework
::
ToDataType
(
ctx
.
Input
<
Tensor
>
(
"Ref"
)
->
type
());
...
@@ -59,13 +59,13 @@ class ScatterGradOp : public framework::OperatorWithKernel {
...
@@ -59,13 +59,13 @@ class ScatterGradOp : public framework::OperatorWithKernel {
public:
public:
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
ctx
->
SetOutputDim
(
framework
::
GradVarName
(
"Updates"
),
ctx
->
SetOutputDim
(
framework
::
GradVarName
(
"Updates"
),
ctx
->
GetInputDim
(
"Updates"
));
ctx
->
GetInputDim
(
"Updates"
));
ctx
->
SetOutputDim
(
framework
::
GradVarName
(
"Ref"
),
ctx
->
GetInputDim
(
"Ref"
));
ctx
->
SetOutputDim
(
framework
::
GradVarName
(
"Ref"
),
ctx
->
GetInputDim
(
"Ref"
));
}
}
protected:
framework
::
DataType
IndicateDataType
(
framework
::
DataType
IndicateDataType
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
return
framework
::
ToDataType
(
ctx
.
Input
<
Tensor
>
(
"Ref"
)
->
type
());
return
framework
::
ToDataType
(
ctx
.
Input
<
Tensor
>
(
"Ref"
)
->
type
());
...
...
paddle/operators/sequence_concat_op.cc
浏览文件 @
73a8b78a
...
@@ -21,7 +21,6 @@ class SequenceConcatOp : public framework::OperatorWithKernel {
...
@@ -21,7 +21,6 @@ class SequenceConcatOp : public framework::OperatorWithKernel {
public:
public:
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInputs
(
"X"
),
PADDLE_ENFORCE
(
ctx
->
HasInputs
(
"X"
),
"Inputs(X) of SequenceConcatOp should not be null."
);
"Inputs(X) of SequenceConcatOp should not be null."
);
...
@@ -105,7 +104,6 @@ class SequenceConcatGradOp : public framework::OperatorWithKernel {
...
@@ -105,7 +104,6 @@ class SequenceConcatGradOp : public framework::OperatorWithKernel {
public:
public:
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
framework
::
GradVarName
(
"Out"
)),
PADDLE_ENFORCE
(
ctx
->
HasInput
(
framework
::
GradVarName
(
"Out"
)),
"The gradient of Out should not be null."
);
"The gradient of Out should not be null."
);
...
...
paddle/operators/sequence_pool_op.cc
浏览文件 @
73a8b78a
...
@@ -21,7 +21,6 @@ class SequencePoolOp : public framework::OperatorWithKernel {
...
@@ -21,7 +21,6 @@ class SequencePoolOp : public framework::OperatorWithKernel {
public:
public:
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"X"
),
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"X"
),
"Input(X) of SequencePoolOp should not be null."
);
"Input(X) of SequencePoolOp should not be null."
);
...
@@ -72,7 +71,6 @@ class SequencePoolGradOp : public framework::OperatorWithKernel {
...
@@ -72,7 +71,6 @@ class SequencePoolGradOp : public framework::OperatorWithKernel {
public:
public:
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
framework
::
GradVarName
(
"Out"
)),
PADDLE_ENFORCE
(
ctx
->
HasInput
(
framework
::
GradVarName
(
"Out"
)),
"Gradient of Out should not be null."
);
"Gradient of Out should not be null."
);
...
...
paddle/operators/sequence_softmax_op.cc
浏览文件 @
73a8b78a
...
@@ -21,7 +21,6 @@ class SequenceSoftmaxOp : public framework::OperatorWithKernel {
...
@@ -21,7 +21,6 @@ class SequenceSoftmaxOp : public framework::OperatorWithKernel {
public:
public:
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"X"
),
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"X"
),
"Input(X) of SequenceSoftmaxOp should not be null."
);
"Input(X) of SequenceSoftmaxOp should not be null."
);
...
@@ -66,7 +65,6 @@ class SequenceSoftmaxGradOp : public framework::OperatorWithKernel {
...
@@ -66,7 +65,6 @@ class SequenceSoftmaxGradOp : public framework::OperatorWithKernel {
public:
public:
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Out"
),
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Out"
),
"Input(Out) of SequenceSoftmaxGradOp should not be null."
);
"Input(Out) of SequenceSoftmaxGradOp should not be null."
);
...
...
paddle/operators/sgd_op.cc
浏览文件 @
73a8b78a
...
@@ -21,7 +21,6 @@ class SGDOp : public framework::OperatorWithKernel {
...
@@ -21,7 +21,6 @@ class SGDOp : public framework::OperatorWithKernel {
public:
public:
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Param"
),
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Param"
),
"Input(Param) of SGDOp should not be null."
);
"Input(Param) of SGDOp should not be null."
);
...
...
paddle/operators/sigmoid_cross_entropy_with_logits_op.cc
浏览文件 @
73a8b78a
...
@@ -23,7 +23,6 @@ class SigmoidCrossEntropyWithLogitsOp : public framework::OperatorWithKernel {
...
@@ -23,7 +23,6 @@ class SigmoidCrossEntropyWithLogitsOp : public framework::OperatorWithKernel {
public:
public:
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"X"
),
"Input(X) should be not null."
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"X"
),
"Input(X) should be not null."
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Labels"
),
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Labels"
),
...
@@ -52,7 +51,6 @@ class SigmoidCrossEntropyWithLogitsGradOp
...
@@ -52,7 +51,6 @@ class SigmoidCrossEntropyWithLogitsGradOp
public:
public:
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"X"
),
"Input(X) should be not null."
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"X"
),
"Input(X) should be not null."
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Labels"
),
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Labels"
),
...
...
paddle/operators/smooth_l1_loss_op.cc
浏览文件 @
73a8b78a
...
@@ -21,7 +21,6 @@ class SmoothL1LossOp : public framework::OperatorWithKernel {
...
@@ -21,7 +21,6 @@ class SmoothL1LossOp : public framework::OperatorWithKernel {
public:
public:
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"X"
),
"X must be initialized."
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"X"
),
"X must be initialized."
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Y"
),
"Y must be initialized."
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Y"
),
"Y must be initialized."
);
...
@@ -93,7 +92,6 @@ class SmoothL1LossGradOp : public framework::OperatorWithKernel {
...
@@ -93,7 +92,6 @@ class SmoothL1LossGradOp : public framework::OperatorWithKernel {
public:
public:
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
auto
in_dims
=
ctx
->
GetInputDim
(
"X"
);
auto
in_dims
=
ctx
->
GetInputDim
(
"X"
);
auto
out_dims
=
ctx
->
GetInputDim
(
framework
::
GradVarName
(
"Out"
));
auto
out_dims
=
ctx
->
GetInputDim
(
framework
::
GradVarName
(
"Out"
));
...
...
paddle/operators/softmax_op.cc
浏览文件 @
73a8b78a
...
@@ -21,7 +21,6 @@ class SoftmaxOp : public framework::OperatorWithKernel {
...
@@ -21,7 +21,6 @@ class SoftmaxOp : public framework::OperatorWithKernel {
public:
public:
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"X"
),
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"X"
),
"Input(X) of SoftmaxOp should not be null."
);
"Input(X) of SoftmaxOp should not be null."
);
...
@@ -68,7 +67,6 @@ class SoftmaxOpGrad : public framework::OperatorWithKernel {
...
@@ -68,7 +67,6 @@ class SoftmaxOpGrad : public framework::OperatorWithKernel {
public:
public:
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Y"
),
"Input(Y) should be not null."
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Y"
),
"Input(Y) should be not null."
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
framework
::
GradVarName
(
"Y"
)),
PADDLE_ENFORCE
(
ctx
->
HasInput
(
framework
::
GradVarName
(
"Y"
)),
...
...
paddle/operators/softmax_with_cross_entropy_op.cc
浏览文件 @
73a8b78a
...
@@ -82,7 +82,6 @@ class SoftmaxWithCrossEntropyOp : public framework::OperatorWithKernel {
...
@@ -82,7 +82,6 @@ class SoftmaxWithCrossEntropyOp : public framework::OperatorWithKernel {
public:
public:
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Logits"
),
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Logits"
),
"Input(Logits) should be not null."
);
"Input(Logits) should be not null."
);
...
@@ -117,6 +116,7 @@ class SoftmaxWithCrossEntropyOp : public framework::OperatorWithKernel {
...
@@ -117,6 +116,7 @@ class SoftmaxWithCrossEntropyOp : public framework::OperatorWithKernel {
ctx
->
ShareLoD
(
"Logits"
,
/*->*/
"Loss"
);
ctx
->
ShareLoD
(
"Logits"
,
/*->*/
"Loss"
);
}
}
protected:
framework
::
DataType
IndicateDataType
(
framework
::
DataType
IndicateDataType
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
return
framework
::
ToDataType
(
ctx
.
Input
<
Tensor
>
(
"Logits"
)
->
type
());
return
framework
::
ToDataType
(
ctx
.
Input
<
Tensor
>
(
"Logits"
)
->
type
());
...
@@ -127,7 +127,6 @@ class SoftmaxWithCrossEntropyOpGrad : public framework::OperatorWithKernel {
...
@@ -127,7 +127,6 @@ class SoftmaxWithCrossEntropyOpGrad : public framework::OperatorWithKernel {
public:
public:
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
framework
::
GradVarName
(
"Loss"
)),
PADDLE_ENFORCE
(
ctx
->
HasInput
(
framework
::
GradVarName
(
"Loss"
)),
"Input(Loss@Grad) should not be null."
);
"Input(Loss@Grad) should not be null."
);
...
@@ -156,6 +155,7 @@ class SoftmaxWithCrossEntropyOpGrad : public framework::OperatorWithKernel {
...
@@ -156,6 +155,7 @@ class SoftmaxWithCrossEntropyOpGrad : public framework::OperatorWithKernel {
ctx
->
GetInputDim
(
"Softmax"
));
ctx
->
GetInputDim
(
"Softmax"
));
}
}
protected:
framework
::
DataType
IndicateDataType
(
framework
::
DataType
IndicateDataType
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
return
framework
::
ToDataType
(
return
framework
::
ToDataType
(
...
...
paddle/operators/split_op.cc
浏览文件 @
73a8b78a
...
@@ -23,7 +23,6 @@ class SplitOp : public framework::OperatorWithKernel {
...
@@ -23,7 +23,6 @@ class SplitOp : public framework::OperatorWithKernel {
public:
public:
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"X"
),
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"X"
),
"Input(X) of SplitOp should not be null."
);
"Input(X) of SplitOp should not be null."
);
...
...
paddle/operators/squared_l2_distance_op.cc
浏览文件 @
73a8b78a
...
@@ -21,7 +21,6 @@ class SquaredL2DistanceOp : public framework::OperatorWithKernel {
...
@@ -21,7 +21,6 @@ class SquaredL2DistanceOp : public framework::OperatorWithKernel {
public:
public:
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"X"
),
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"X"
),
"Input(X) of SquaredL2DistanceOp should not be null."
);
"Input(X) of SquaredL2DistanceOp should not be null."
);
...
@@ -85,7 +84,6 @@ class SquaredL2DistanceGradOp : public framework::OperatorWithKernel {
...
@@ -85,7 +84,6 @@ class SquaredL2DistanceGradOp : public framework::OperatorWithKernel {
public:
public:
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
framework
::
GradVarName
(
"Out"
)),
PADDLE_ENFORCE
(
ctx
->
HasInput
(
framework
::
GradVarName
(
"Out"
)),
"Gradient of Out should not be null"
);
"Gradient of Out should not be null"
);
...
...
paddle/operators/sum_op.cc
浏览文件 @
73a8b78a
...
@@ -21,7 +21,6 @@ class SumOp : public framework::OperatorWithKernel {
...
@@ -21,7 +21,6 @@ class SumOp : public framework::OperatorWithKernel {
public:
public:
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInputs
(
"X"
),
"Inputs(X) should not be null"
);
PADDLE_ENFORCE
(
ctx
->
HasInputs
(
"X"
),
"Inputs(X) should not be null"
);
auto
x_dims
=
ctx
->
GetInputsDim
(
"X"
);
auto
x_dims
=
ctx
->
GetInputsDim
(
"X"
);
...
...
paddle/operators/top_k_op.cc
浏览文件 @
73a8b78a
...
@@ -21,7 +21,6 @@ class TopkOp : public framework::OperatorWithKernel {
...
@@ -21,7 +21,6 @@ class TopkOp : public framework::OperatorWithKernel {
public:
public:
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"X"
),
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"X"
),
"Input(X) of TopkOp should not be null."
);
"Input(X) of TopkOp should not be null."
);
...
...
paddle/operators/transpose_op.cc
浏览文件 @
73a8b78a
...
@@ -23,7 +23,6 @@ class TransposeOp : public framework::OperatorWithKernel {
...
@@ -23,7 +23,6 @@ class TransposeOp : public framework::OperatorWithKernel {
public:
public:
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"X"
),
"Input(X) should not be null"
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"X"
),
"Input(X) should not be null"
);
PADDLE_ENFORCE
(
ctx
->
HasOutput
(
"Out"
),
"Output(Out) should not be null"
);
PADDLE_ENFORCE
(
ctx
->
HasOutput
(
"Out"
),
"Output(Out) should not be null"
);
...
@@ -92,7 +91,6 @@ class TransposeOpGrad : public framework::OperatorWithKernel {
...
@@ -92,7 +91,6 @@ class TransposeOpGrad : public framework::OperatorWithKernel {
public:
public:
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"X"
),
"Input(X) should not be null"
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"X"
),
"Input(X) should not be null"
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
framework
::
GradVarName
(
"Out"
)),
PADDLE_ENFORCE
(
ctx
->
HasInput
(
framework
::
GradVarName
(
"Out"
)),
...
...
paddle/operators/uniform_random_op.cc
浏览文件 @
73a8b78a
...
@@ -46,7 +46,6 @@ class UniformRandomOp : public framework::OperatorWithKernel {
...
@@ -46,7 +46,6 @@ class UniformRandomOp : public framework::OperatorWithKernel {
public:
public:
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasOutput
(
"Out"
),
PADDLE_ENFORCE
(
ctx
->
HasOutput
(
"Out"
),
"Output(Out) of UniformRandomOp should not be null."
);
"Output(Out) of UniformRandomOp should not be null."
);
...
@@ -63,6 +62,7 @@ class UniformRandomOp : public framework::OperatorWithKernel {
...
@@ -63,6 +62,7 @@ class UniformRandomOp : public framework::OperatorWithKernel {
ctx
->
SetOutputDim
(
"Out"
,
framework
::
make_ddim
(
temp
));
ctx
->
SetOutputDim
(
"Out"
,
framework
::
make_ddim
(
temp
));
}
}
protected:
framework
::
DataType
IndicateDataType
(
framework
::
DataType
IndicateDataType
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
return
static_cast
<
framework
::
DataType
>
(
Attr
<
int
>
(
"data_type"
));
return
static_cast
<
framework
::
DataType
>
(
Attr
<
int
>
(
"data_type"
));
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录