Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
73a8b78a
P
Paddle
项目概览
PaddlePaddle
/
Paddle
1 年多 前同步成功
通知
2302
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
73a8b78a
编写于
10月 16, 2017
作者:
Y
Yu Yang
提交者:
GitHub
10月 16, 2017
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Correct OpWithKernel's infershape (#4847)
They are public now
上级
fafc54d6
变更
53
隐藏空白更改
内联
并排
Showing
53 changed file
with
15 addition
and
90 deletion
+15
-90
paddle/operators/accuracy_op.cc
paddle/operators/accuracy_op.cc
+0
-1
paddle/operators/activation_op.cc
paddle/operators/activation_op.cc
+0
-2
paddle/operators/adadelta_op.cc
paddle/operators/adadelta_op.cc
+0
-1
paddle/operators/adagrad_op.cc
paddle/operators/adagrad_op.cc
+0
-1
paddle/operators/adam_op.cc
paddle/operators/adam_op.cc
+0
-1
paddle/operators/adamax_op.cc
paddle/operators/adamax_op.cc
+0
-1
paddle/operators/clip_op.cc
paddle/operators/clip_op.cc
+0
-2
paddle/operators/concat_op.cc
paddle/operators/concat_op.cc
+0
-2
paddle/operators/conv2d_op.h
paddle/operators/conv2d_op.h
+0
-2
paddle/operators/conv_shift_op.cc
paddle/operators/conv_shift_op.cc
+0
-2
paddle/operators/cos_sim_op.cc
paddle/operators/cos_sim_op.cc
+0
-2
paddle/operators/crop_op.cc
paddle/operators/crop_op.cc
+0
-2
paddle/operators/cross_entropy_op.cc
paddle/operators/cross_entropy_op.cc
+2
-2
paddle/operators/decayed_adagrad_op.cc
paddle/operators/decayed_adagrad_op.cc
+0
-1
paddle/operators/dropout_op.cc
paddle/operators/dropout_op.cc
+0
-2
paddle/operators/elementwise_op.h
paddle/operators/elementwise_op.h
+0
-2
paddle/operators/fill_constant_op.cc
paddle/operators/fill_constant_op.cc
+1
-1
paddle/operators/fill_zeros_like_op.cc
paddle/operators/fill_zeros_like_op.cc
+0
-1
paddle/operators/gather_op.cc
paddle/operators/gather_op.cc
+2
-2
paddle/operators/gaussian_random_op.cc
paddle/operators/gaussian_random_op.cc
+1
-1
paddle/operators/gru_unit_op.cc
paddle/operators/gru_unit_op.cc
+0
-2
paddle/operators/lookup_table_op.cc
paddle/operators/lookup_table_op.cc
+2
-2
paddle/operators/lstm_unit_op.cc
paddle/operators/lstm_unit_op.cc
+0
-2
paddle/operators/margin_rank_loss_op.cc
paddle/operators/margin_rank_loss_op.cc
+0
-2
paddle/operators/mean_op.cc
paddle/operators/mean_op.cc
+0
-2
paddle/operators/minus_op.cc
paddle/operators/minus_op.cc
+0
-1
paddle/operators/modified_huber_loss_op.cc
paddle/operators/modified_huber_loss_op.cc
+0
-2
paddle/operators/mul_op.cc
paddle/operators/mul_op.cc
+0
-2
paddle/operators/multiplex_op.cc
paddle/operators/multiplex_op.cc
+2
-2
paddle/operators/pad_op.cc
paddle/operators/pad_op.cc
+0
-2
paddle/operators/pool_op.h
paddle/operators/pool_op.h
+0
-2
paddle/operators/pool_with_index_op.cc
paddle/operators/pool_with_index_op.cc
+0
-2
paddle/operators/prelu_op.cc
paddle/operators/prelu_op.cc
+0
-2
paddle/operators/rank_loss_op.cc
paddle/operators/rank_loss_op.cc
+0
-2
paddle/operators/reduce_op.cc
paddle/operators/reduce_op.cc
+0
-2
paddle/operators/reshape_op.cc
paddle/operators/reshape_op.cc
+0
-2
paddle/operators/rmsprop_op.cc
paddle/operators/rmsprop_op.cc
+0
-1
paddle/operators/scale_op.cc
paddle/operators/scale_op.cc
+0
-2
paddle/operators/scatter_op.cc
paddle/operators/scatter_op.cc
+2
-2
paddle/operators/sequence_concat_op.cc
paddle/operators/sequence_concat_op.cc
+0
-2
paddle/operators/sequence_pool_op.cc
paddle/operators/sequence_pool_op.cc
+0
-2
paddle/operators/sequence_softmax_op.cc
paddle/operators/sequence_softmax_op.cc
+0
-2
paddle/operators/sgd_op.cc
paddle/operators/sgd_op.cc
+0
-1
paddle/operators/sigmoid_cross_entropy_with_logits_op.cc
paddle/operators/sigmoid_cross_entropy_with_logits_op.cc
+0
-2
paddle/operators/smooth_l1_loss_op.cc
paddle/operators/smooth_l1_loss_op.cc
+0
-2
paddle/operators/softmax_op.cc
paddle/operators/softmax_op.cc
+0
-2
paddle/operators/softmax_with_cross_entropy_op.cc
paddle/operators/softmax_with_cross_entropy_op.cc
+2
-2
paddle/operators/split_op.cc
paddle/operators/split_op.cc
+0
-1
paddle/operators/squared_l2_distance_op.cc
paddle/operators/squared_l2_distance_op.cc
+0
-2
paddle/operators/sum_op.cc
paddle/operators/sum_op.cc
+0
-1
paddle/operators/top_k_op.cc
paddle/operators/top_k_op.cc
+0
-1
paddle/operators/transpose_op.cc
paddle/operators/transpose_op.cc
+0
-2
paddle/operators/uniform_random_op.cc
paddle/operators/uniform_random_op.cc
+1
-1
未找到文件。
paddle/operators/accuracy_op.cc
浏览文件 @
73a8b78a
...
...
@@ -21,7 +21,6 @@ class AccuracyOp : public framework::OperatorWithKernel {
public:
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Inference"
),
"Input(Inference) of AccuracyOp should not be null."
);
...
...
paddle/operators/activation_op.cc
浏览文件 @
73a8b78a
...
...
@@ -21,7 +21,6 @@ class ActivationOp : public framework::OperatorWithKernel {
public:
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
ctx
->
SetOutputDim
(
"Y"
,
ctx
->
GetInputDim
(
"X"
));
ctx
->
ShareLoD
(
"X"
,
/*->*/
"Y"
);
...
...
@@ -32,7 +31,6 @@ class ActivationOpGrad : public framework::OperatorWithKernel {
public:
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
ctx
->
SetOutputDim
(
framework
::
GradVarName
(
"X"
),
ctx
->
GetInputDim
(
"Y"
));
}
...
...
paddle/operators/adadelta_op.cc
浏览文件 @
73a8b78a
...
...
@@ -21,7 +21,6 @@ class AdadeltaOp : public framework::OperatorWithKernel {
public:
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Param"
),
"Input(Param) of AdadeltaOp should not be null."
);
...
...
paddle/operators/adagrad_op.cc
浏览文件 @
73a8b78a
...
...
@@ -21,7 +21,6 @@ class AdagradOp : public framework::OperatorWithKernel {
public:
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Param"
),
"Input(Param) of AdagradOp should not be null."
);
...
...
paddle/operators/adam_op.cc
浏览文件 @
73a8b78a
...
...
@@ -21,7 +21,6 @@ class AdamOp : public framework::OperatorWithKernel {
public:
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Param"
),
"Input(Param) of AdamOp should not be null."
);
...
...
paddle/operators/adamax_op.cc
浏览文件 @
73a8b78a
...
...
@@ -21,7 +21,6 @@ class AdamaxOp : public framework::OperatorWithKernel {
public:
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Param"
),
"Input(Param) of AdamaxOp should not be null."
);
...
...
paddle/operators/clip_op.cc
浏览文件 @
73a8b78a
...
...
@@ -21,7 +21,6 @@ class ClipOp : public framework::OperatorWithKernel {
public:
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"X"
),
"Input(X) of ClipOp should not be null."
);
...
...
@@ -60,7 +59,6 @@ class ClipOpGrad : public framework::OperatorWithKernel {
public:
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"X"
),
"Input(X) should not be null"
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
framework
::
GradVarName
(
"Out"
)),
...
...
paddle/operators/concat_op.cc
浏览文件 @
73a8b78a
...
...
@@ -23,7 +23,6 @@ class ConcatOp : public framework::OperatorWithKernel {
public:
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE_GE
(
ctx
->
Inputs
(
"X"
).
size
(),
1UL
,
"Inputs(X) of ConcatOp should be empty."
)
...
...
@@ -82,7 +81,6 @@ class ConcatOpGrad : public framework::OperatorWithKernel {
const
framework
::
AttributeMap
&
attrs
)
:
OperatorWithKernel
(
type
,
inputs
,
outputs
,
attrs
)
{}
protected:
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
ctx
->
SetOutputsDim
(
framework
::
GradVarName
(
"X"
),
ctx
->
GetInputsDim
(
"X"
));
}
...
...
paddle/operators/conv2d_op.h
浏览文件 @
73a8b78a
...
...
@@ -44,7 +44,6 @@ class Conv2DOp : public framework::OperatorWithKernel {
public:
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
;
};
...
...
@@ -52,7 +51,6 @@ class Conv2DOpGrad : public framework::OperatorWithKernel {
public:
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
;
};
...
...
paddle/operators/conv_shift_op.cc
浏览文件 @
73a8b78a
...
...
@@ -27,7 +27,6 @@ class ConvShiftOp : public framework::OperatorWithKernel {
public:
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"X"
),
"Input(X) should be not null."
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Y"
),
"Input(Y) should be not null."
);
...
...
@@ -54,7 +53,6 @@ class ConvShiftGradOp : public framework::OperatorWithKernel {
public:
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"X"
),
"Input(X) should be not null."
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Y"
),
"Input(Y) should be not null."
);
...
...
paddle/operators/cos_sim_op.cc
浏览文件 @
73a8b78a
...
...
@@ -23,7 +23,6 @@ class CosSimOp : public framework::OperatorWithKernel {
public:
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
// notnull check
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"X"
),
...
...
@@ -97,7 +96,6 @@ class CosSimOpGrad : public framework::OperatorWithKernel {
public:
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
// notnull check
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"X"
),
"Input(X) must not be null."
);
...
...
paddle/operators/crop_op.cc
浏览文件 @
73a8b78a
...
...
@@ -24,7 +24,6 @@ class CropOp : public framework::OperatorWithKernel {
public:
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"X"
),
"Input(X) of CropOp should not be null."
);
...
...
@@ -114,7 +113,6 @@ class CropOpGrad : public framework::OperatorWithKernel {
public:
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"X"
),
"Input(X) should not be null"
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
framework
::
GradVarName
(
"Out"
)),
...
...
paddle/operators/cross_entropy_op.cc
浏览文件 @
73a8b78a
...
...
@@ -21,7 +21,6 @@ class CrossEntropyOp : public framework::OperatorWithKernel {
public:
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"X"
),
"Input(X) should be not null."
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Label"
),
"Input(Label) should be not null."
);
...
...
@@ -48,6 +47,7 @@ class CrossEntropyOp : public framework::OperatorWithKernel {
ctx
->
ShareLoD
(
"X"
,
/*->*/
"Y"
);
}
protected:
// CrossEntropy's data type just determined by "X"
framework
::
DataType
IndicateDataType
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
...
...
@@ -59,7 +59,6 @@ class CrossEntropyGradientOp : public framework::OperatorWithKernel {
public:
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"X"
),
"Input(X) should be not null."
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Label"
),
"Input(Label) should be not null."
);
...
...
@@ -94,6 +93,7 @@ class CrossEntropyGradientOp : public framework::OperatorWithKernel {
ctx
->
SetOutputDim
(
framework
::
GradVarName
(
"X"
),
x_dims
);
}
protected:
// CrossEntropy's data type just determined by "X"
framework
::
DataType
IndicateDataType
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
...
...
paddle/operators/decayed_adagrad_op.cc
浏览文件 @
73a8b78a
...
...
@@ -21,7 +21,6 @@ class DecayedAdagradOp : public framework::OperatorWithKernel {
public:
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Param"
),
"Input(Param) of DecayedAdagradOp should not be null."
);
...
...
paddle/operators/dropout_op.cc
浏览文件 @
73a8b78a
...
...
@@ -23,7 +23,6 @@ class DropoutOp : public framework::OperatorWithKernel {
public:
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"X"
),
"Input(X) must not be null."
);
PADDLE_ENFORCE_GE
(
ctx
->
Attrs
().
Get
<
float
>
(
"dropout_prob"
),
0
);
...
...
@@ -69,7 +68,6 @@ class DropoutOpGrad : public framework::OperatorWithKernel {
public:
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE_EQ
(
ctx
->
Attrs
().
Get
<
bool
>
(
"is_training"
),
1
,
"GradOp is only callable when is_training is true"
);
...
...
paddle/operators/elementwise_op.h
浏览文件 @
73a8b78a
...
...
@@ -23,7 +23,6 @@ class ElementwiseOp : public framework::OperatorWithKernel {
public:
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
using
Tensor
=
framework
::
Tensor
;
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"X"
),
...
...
@@ -105,7 +104,6 @@ class ElementwiseOpGrad : public framework::OperatorWithKernel {
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
using
Tensor
=
framework
::
Tensor
;
protected:
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"X"
),
"Input(X) should not be null"
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Y"
),
"Input(Y) should not be null"
);
...
...
paddle/operators/fill_constant_op.cc
浏览文件 @
73a8b78a
...
...
@@ -21,7 +21,6 @@ class FillConstantOp : public framework::OperatorWithKernel {
public:
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasOutput
(
"Out"
),
"Output(Out) of FillConstantOp should not be null."
);
...
...
@@ -33,6 +32,7 @@ class FillConstantOp : public framework::OperatorWithKernel {
ctx
->
SetOutputDim
(
"Out"
,
dims
);
}
protected:
framework
::
DataType
IndicateDataType
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
return
static_cast
<
framework
::
DataType
>
(
ctx
.
Attr
<
int
>
(
"data_type"
));
...
...
paddle/operators/fill_zeros_like_op.cc
浏览文件 @
73a8b78a
...
...
@@ -21,7 +21,6 @@ class FillZerosLikeOp : public framework::OperatorWithKernel {
public:
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"X"
),
"Input(X) of FillZerosLikeOp should not be null."
);
...
...
paddle/operators/gather_op.cc
浏览文件 @
73a8b78a
...
...
@@ -22,7 +22,6 @@ class GatherOp : public framework::OperatorWithKernel {
public:
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"X"
),
"Input(X) of GatherOp should not be null."
);
...
...
@@ -40,6 +39,7 @@ class GatherOp : public framework::OperatorWithKernel {
ctx
->
SetOutputDim
(
"Out"
,
output_dims
);
}
protected:
framework
::
DataType
IndicateDataType
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
return
framework
::
ToDataType
(
ctx
.
Input
<
Tensor
>
(
"X"
)
->
type
());
...
...
@@ -50,11 +50,11 @@ class GatherGradOp : public framework::OperatorWithKernel {
public:
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
ctx
->
SetOutputDim
(
framework
::
GradVarName
(
"X"
),
ctx
->
GetInputDim
(
"X"
));
}
protected:
framework
::
DataType
IndicateDataType
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
return
framework
::
ToDataType
(
ctx
.
Input
<
Tensor
>
(
"X"
)
->
type
());
...
...
paddle/operators/gaussian_random_op.cc
浏览文件 @
73a8b78a
...
...
@@ -42,7 +42,6 @@ class GaussianRandomOp : public framework::OperatorWithKernel {
public:
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasOutput
(
"Out"
),
"Output(Out) of GaussianRandomOp should not be null."
);
...
...
@@ -57,6 +56,7 @@ class GaussianRandomOp : public framework::OperatorWithKernel {
ctx
->
SetOutputDim
(
"Out"
,
framework
::
make_ddim
(
temp
));
}
protected:
framework
::
DataType
IndicateDataType
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
return
static_cast
<
framework
::
DataType
>
(
Attr
<
int
>
(
"data_type"
));
...
...
paddle/operators/gru_unit_op.cc
浏览文件 @
73a8b78a
...
...
@@ -23,7 +23,6 @@ class GRUUnitOp : public framework::OperatorWithKernel {
public:
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Input"
),
"Input(%s) of GRUUnitOp should not be null."
,
"Input"
);
...
...
@@ -131,7 +130,6 @@ class GRUUnitGradOp : public framework::OperatorWithKernel {
public:
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Input"
),
"Input(%s) of GRUUnitGradOp should not be null."
,
"Input"
);
...
...
paddle/operators/lookup_table_op.cc
浏览文件 @
73a8b78a
...
...
@@ -21,7 +21,6 @@ class LookupTableOp : public framework::OperatorWithKernel {
public:
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"W"
),
"Input(W) of LookupTableOp should not be null."
);
...
...
@@ -37,6 +36,7 @@ class LookupTableOp : public framework::OperatorWithKernel {
ctx
->
ShareLoD
(
"Ids"
,
/*->*/
"Out"
);
}
protected:
framework
::
DataType
IndicateDataType
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
return
framework
::
ToDataType
(
ctx
.
Input
<
Tensor
>
(
"W"
)
->
type
());
...
...
@@ -69,12 +69,12 @@ class LookupTableOpGrad : public framework::OperatorWithKernel {
public:
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
auto
table_dims
=
ctx
->
GetInputDim
(
"W"
);
ctx
->
SetOutputDim
(
framework
::
GradVarName
(
"W"
),
table_dims
);
}
protected:
framework
::
DataType
IndicateDataType
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
return
framework
::
ToDataType
(
ctx
.
Input
<
Tensor
>
(
"W"
)
->
type
());
...
...
paddle/operators/lstm_unit_op.cc
浏览文件 @
73a8b78a
...
...
@@ -21,7 +21,6 @@ class LstmUnitOp : public framework::OperatorWithKernel {
public:
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"X"
),
"Input(X) of LSTM should not be null."
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"C_prev"
),
...
...
@@ -76,7 +75,6 @@ class LstmUnitGradOp : public framework::OperatorWithKernel {
public:
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
framework
::
GradVarName
(
"C"
)),
"Input(C@GRAD) should not be null"
);
...
...
paddle/operators/margin_rank_loss_op.cc
浏览文件 @
73a8b78a
...
...
@@ -21,7 +21,6 @@ class MarginRankLossOp : public framework::OperatorWithKernel {
public:
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
// input check
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Label"
),
"Input(Label) shouldn't be null."
);
...
...
@@ -94,7 +93,6 @@ class MarginRankLossGradOp : public framework::OperatorWithKernel {
public:
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Label"
),
"Input(Label) shouldn't be null."
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"X1"
),
"Input(X1) shouldn't be null."
);
...
...
paddle/operators/mean_op.cc
浏览文件 @
73a8b78a
...
...
@@ -21,7 +21,6 @@ class MeanOp : public framework::OperatorWithKernel {
public:
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"X"
),
"Input(X) of MeanOp should not be null."
);
...
...
@@ -46,7 +45,6 @@ class MeanGradOp : public framework::OperatorWithKernel {
public:
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
ctx
->
SetOutputDim
(
framework
::
GradVarName
(
"X"
),
ctx
->
GetInputDim
(
"X"
));
}
...
...
paddle/operators/minus_op.cc
浏览文件 @
73a8b78a
...
...
@@ -25,7 +25,6 @@ class MinusOp : public framework::OperatorWithKernel {
const
framework
::
AttributeMap
&
attrs
)
:
OperatorWithKernel
(
type
,
inputs
,
outputs
,
attrs
)
{}
protected:
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"X"
),
"Input(X) of MinusOp should not be null."
);
...
...
paddle/operators/modified_huber_loss_op.cc
浏览文件 @
73a8b78a
...
...
@@ -21,7 +21,6 @@ class ModifiedHuberLossOp : public framework::OperatorWithKernel {
public:
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"X"
),
"X must be initialized."
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Y"
),
"Y must be initialized."
);
...
...
@@ -73,7 +72,6 @@ class ModifiedHuberLossGradOp : public framework::OperatorWithKernel {
public:
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"X"
),
"X must be initialized."
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Y"
),
"Y must be initialized."
);
...
...
paddle/operators/mul_op.cc
浏览文件 @
73a8b78a
...
...
@@ -23,7 +23,6 @@ class MulOp : public framework::OperatorWithKernel {
public:
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"X"
),
"Input(X) of MulOp should not be null."
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Y"
),
"Input(Y) of MulOp should not be null."
);
...
...
@@ -96,7 +95,6 @@ class MulOpGrad : public framework::OperatorWithKernel {
public:
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"X"
),
"Input(X) should not be null"
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Y"
),
"Input(Y) should not be null"
);
...
...
paddle/operators/multiplex_op.cc
浏览文件 @
73a8b78a
...
...
@@ -23,7 +23,6 @@ class MultiplexOp : public framework::OperatorWithKernel {
public:
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Ids"
),
"Input(Ids) shouldn't be null."
);
PADDLE_ENFORCE
(
!
ctx
->
Inputs
(
"X"
).
empty
(),
...
...
@@ -51,6 +50,7 @@ class MultiplexOp : public framework::OperatorWithKernel {
ctx
->
SetOutputDim
(
"Out"
,
in_dim
);
}
protected:
framework
::
DataType
IndicateDataType
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
return
framework
::
ToDataType
(
ctx
.
MultiInput
<
Tensor
>
(
"X"
)[
0
]
->
type
());
...
...
@@ -89,7 +89,6 @@ class MultiplexGradOp : public framework::OperatorWithKernel {
public:
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
!
ctx
->
Inputs
(
"X"
).
empty
(),
"Input(X) should not be null."
);
PADDLE_ENFORCE
(
!
ctx
->
Outputs
(
framework
::
GradVarName
(
"X"
)).
empty
(),
...
...
@@ -105,6 +104,7 @@ class MultiplexGradOp : public framework::OperatorWithKernel {
ctx
->
SetOutputsDim
(
framework
::
GradVarName
(
"X"
),
d_ins
);
}
protected:
framework
::
DataType
IndicateDataType
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
return
framework
::
ToDataType
(
ctx
.
MultiInput
<
Tensor
>
(
"X"
)[
0
]
->
type
());
...
...
paddle/operators/pad_op.cc
浏览文件 @
73a8b78a
...
...
@@ -23,7 +23,6 @@ class PadOp : public framework::OperatorWithKernel {
public:
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"X"
),
"Input(X) of PadOp should not be null."
);
PADDLE_ENFORCE
(
ctx
->
HasOutput
(
"Out"
),
...
...
@@ -97,7 +96,6 @@ class PadOpGrad : public framework::OperatorWithKernel {
public:
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"X"
),
"Input(X) should not be null"
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
framework
::
GradVarName
(
"Out"
)),
...
...
paddle/operators/pool_op.h
浏览文件 @
73a8b78a
...
...
@@ -28,7 +28,6 @@ class PoolOp : public framework::OperatorWithKernel {
public:
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
;
};
...
...
@@ -36,7 +35,6 @@ class PoolOpGrad : public framework::OperatorWithKernel {
public:
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
;
};
...
...
paddle/operators/pool_with_index_op.cc
浏览文件 @
73a8b78a
...
...
@@ -27,7 +27,6 @@ class MaxPoolWithIndexOp : public framework::OperatorWithKernel {
public:
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"X"
),
"X(Input) of Pooling should not be null."
);
...
...
@@ -72,7 +71,6 @@ class MaxPoolWithIndexOpGrad : public framework::OperatorWithKernel {
public:
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Mask"
),
"Input(Mask) must not be null."
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"X"
),
"Input(X) must not be null."
);
...
...
paddle/operators/prelu_op.cc
浏览文件 @
73a8b78a
...
...
@@ -25,7 +25,6 @@ class PReluOp : public framework::OperatorWithKernel {
const
framework
::
AttributeMap
&
attrs
)
:
OperatorWithKernel
(
type
,
inputs
,
outputs
,
attrs
)
{}
protected:
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"X"
),
"Input(X) should not be null"
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Alpha"
),
"Input(Alpha) should not be null"
);
...
...
@@ -62,7 +61,6 @@ class PReluGradOp : public framework::OperatorWithKernel {
public:
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"X"
),
"Input(X) must not be null."
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
framework
::
GradVarName
(
"Out"
)),
...
...
paddle/operators/rank_loss_op.cc
浏览文件 @
73a8b78a
...
...
@@ -24,7 +24,6 @@ class RankLossOp : public framework::OperatorWithKernel {
const
framework
::
AttributeMap
&
attrs
)
:
OperatorWithKernel
(
type
,
inputs
,
outputs
,
attrs
)
{}
protected:
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
// input check
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Label"
),
"Input(Label) shouldn't be null"
);
...
...
@@ -89,7 +88,6 @@ class RankLossGradOp : public framework::OperatorWithKernel {
const
framework
::
AttributeMap
&
attrs
)
:
OperatorWithKernel
(
type
,
inputs
,
outputs
,
attrs
)
{}
protected:
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Label"
),
"Input(Label) shouldn't be null."
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Left"
),
"Input(Left) shouldn't be null."
);
...
...
paddle/operators/reduce_op.cc
浏览文件 @
73a8b78a
...
...
@@ -23,7 +23,6 @@ class ReduceOp : public framework::OperatorWithKernel {
public:
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"X"
),
"Input(X) of ReduceOp should not be null."
);
...
...
@@ -57,7 +56,6 @@ class ReduceGradOp : public framework::OperatorWithKernel {
public:
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"X"
),
"Input(X) should not be null."
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
framework
::
GradVarName
(
"Out"
)),
...
...
paddle/operators/reshape_op.cc
浏览文件 @
73a8b78a
...
...
@@ -25,7 +25,6 @@ class ReshapeOp : public framework::OperatorWithKernel {
const
framework
::
AttributeMap
&
attrs
)
:
OperatorWithKernel
(
type
,
inputs
,
outputs
,
attrs
)
{}
protected:
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
// input check
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"X"
),
...
...
@@ -93,7 +92,6 @@ class ReshapeGradOp : public framework::OperatorWithKernel {
const
framework
::
AttributeMap
&
attrs
)
:
OperatorWithKernel
(
type
,
inputs
,
outputs
,
attrs
)
{}
protected:
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"X"
),
"Input(X) shouldn't be null."
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
framework
::
GradVarName
(
"Out"
)),
...
...
paddle/operators/rmsprop_op.cc
浏览文件 @
73a8b78a
...
...
@@ -21,7 +21,6 @@ class RmspropOp : public framework::OperatorWithKernel {
public:
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Param"
),
"Input(Param) of RmspropOp should not be null."
);
...
...
paddle/operators/scale_op.cc
浏览文件 @
73a8b78a
...
...
@@ -25,7 +25,6 @@ class ScaleOp : public framework::OperatorWithKernel {
const
framework
::
AttributeMap
&
attrs
)
:
OperatorWithKernel
(
type
,
inputs
,
outputs
,
attrs
)
{}
protected:
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"X"
),
"Input(X) of ScaleOp should not be null."
);
...
...
@@ -56,7 +55,6 @@ class ScaleGradMaker : public framework::SingleGradOpDescMaker {
public:
using
framework
::
SingleGradOpDescMaker
::
SingleGradOpDescMaker
;
protected:
std
::
unique_ptr
<
framework
::
OpDescBind
>
Apply
()
const
override
{
auto
*
grad_op
=
new
framework
::
OpDescBind
();
grad_op
->
SetType
(
"scale"
);
...
...
paddle/operators/scatter_op.cc
浏览文件 @
73a8b78a
...
...
@@ -22,7 +22,6 @@ class ScatterOp : public framework::OperatorWithKernel {
public:
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Ref"
),
"Input(Ref) of ScatterOp should not be null."
);
...
...
@@ -49,6 +48,7 @@ class ScatterOp : public framework::OperatorWithKernel {
ctx
->
SetOutputDim
(
"Out"
,
ref_dims
);
}
protected:
framework
::
DataType
IndicateDataType
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
return
framework
::
ToDataType
(
ctx
.
Input
<
Tensor
>
(
"Ref"
)
->
type
());
...
...
@@ -59,13 +59,13 @@ class ScatterGradOp : public framework::OperatorWithKernel {
public:
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
ctx
->
SetOutputDim
(
framework
::
GradVarName
(
"Updates"
),
ctx
->
GetInputDim
(
"Updates"
));
ctx
->
SetOutputDim
(
framework
::
GradVarName
(
"Ref"
),
ctx
->
GetInputDim
(
"Ref"
));
}
protected:
framework
::
DataType
IndicateDataType
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
return
framework
::
ToDataType
(
ctx
.
Input
<
Tensor
>
(
"Ref"
)
->
type
());
...
...
paddle/operators/sequence_concat_op.cc
浏览文件 @
73a8b78a
...
...
@@ -21,7 +21,6 @@ class SequenceConcatOp : public framework::OperatorWithKernel {
public:
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInputs
(
"X"
),
"Inputs(X) of SequenceConcatOp should not be null."
);
...
...
@@ -105,7 +104,6 @@ class SequenceConcatGradOp : public framework::OperatorWithKernel {
public:
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
framework
::
GradVarName
(
"Out"
)),
"The gradient of Out should not be null."
);
...
...
paddle/operators/sequence_pool_op.cc
浏览文件 @
73a8b78a
...
...
@@ -21,7 +21,6 @@ class SequencePoolOp : public framework::OperatorWithKernel {
public:
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"X"
),
"Input(X) of SequencePoolOp should not be null."
);
...
...
@@ -72,7 +71,6 @@ class SequencePoolGradOp : public framework::OperatorWithKernel {
public:
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
framework
::
GradVarName
(
"Out"
)),
"Gradient of Out should not be null."
);
...
...
paddle/operators/sequence_softmax_op.cc
浏览文件 @
73a8b78a
...
...
@@ -21,7 +21,6 @@ class SequenceSoftmaxOp : public framework::OperatorWithKernel {
public:
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"X"
),
"Input(X) of SequenceSoftmaxOp should not be null."
);
...
...
@@ -66,7 +65,6 @@ class SequenceSoftmaxGradOp : public framework::OperatorWithKernel {
public:
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Out"
),
"Input(Out) of SequenceSoftmaxGradOp should not be null."
);
...
...
paddle/operators/sgd_op.cc
浏览文件 @
73a8b78a
...
...
@@ -21,7 +21,6 @@ class SGDOp : public framework::OperatorWithKernel {
public:
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Param"
),
"Input(Param) of SGDOp should not be null."
);
...
...
paddle/operators/sigmoid_cross_entropy_with_logits_op.cc
浏览文件 @
73a8b78a
...
...
@@ -23,7 +23,6 @@ class SigmoidCrossEntropyWithLogitsOp : public framework::OperatorWithKernel {
public:
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"X"
),
"Input(X) should be not null."
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Labels"
),
...
...
@@ -52,7 +51,6 @@ class SigmoidCrossEntropyWithLogitsGradOp
public:
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"X"
),
"Input(X) should be not null."
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Labels"
),
...
...
paddle/operators/smooth_l1_loss_op.cc
浏览文件 @
73a8b78a
...
...
@@ -21,7 +21,6 @@ class SmoothL1LossOp : public framework::OperatorWithKernel {
public:
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"X"
),
"X must be initialized."
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Y"
),
"Y must be initialized."
);
...
...
@@ -93,7 +92,6 @@ class SmoothL1LossGradOp : public framework::OperatorWithKernel {
public:
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
auto
in_dims
=
ctx
->
GetInputDim
(
"X"
);
auto
out_dims
=
ctx
->
GetInputDim
(
framework
::
GradVarName
(
"Out"
));
...
...
paddle/operators/softmax_op.cc
浏览文件 @
73a8b78a
...
...
@@ -21,7 +21,6 @@ class SoftmaxOp : public framework::OperatorWithKernel {
public:
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"X"
),
"Input(X) of SoftmaxOp should not be null."
);
...
...
@@ -68,7 +67,6 @@ class SoftmaxOpGrad : public framework::OperatorWithKernel {
public:
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Y"
),
"Input(Y) should be not null."
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
framework
::
GradVarName
(
"Y"
)),
...
...
paddle/operators/softmax_with_cross_entropy_op.cc
浏览文件 @
73a8b78a
...
...
@@ -82,7 +82,6 @@ class SoftmaxWithCrossEntropyOp : public framework::OperatorWithKernel {
public:
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Logits"
),
"Input(Logits) should be not null."
);
...
...
@@ -117,6 +116,7 @@ class SoftmaxWithCrossEntropyOp : public framework::OperatorWithKernel {
ctx
->
ShareLoD
(
"Logits"
,
/*->*/
"Loss"
);
}
protected:
framework
::
DataType
IndicateDataType
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
return
framework
::
ToDataType
(
ctx
.
Input
<
Tensor
>
(
"Logits"
)
->
type
());
...
...
@@ -127,7 +127,6 @@ class SoftmaxWithCrossEntropyOpGrad : public framework::OperatorWithKernel {
public:
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
framework
::
GradVarName
(
"Loss"
)),
"Input(Loss@Grad) should not be null."
);
...
...
@@ -156,6 +155,7 @@ class SoftmaxWithCrossEntropyOpGrad : public framework::OperatorWithKernel {
ctx
->
GetInputDim
(
"Softmax"
));
}
protected:
framework
::
DataType
IndicateDataType
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
return
framework
::
ToDataType
(
...
...
paddle/operators/split_op.cc
浏览文件 @
73a8b78a
...
...
@@ -23,7 +23,6 @@ class SplitOp : public framework::OperatorWithKernel {
public:
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"X"
),
"Input(X) of SplitOp should not be null."
);
...
...
paddle/operators/squared_l2_distance_op.cc
浏览文件 @
73a8b78a
...
...
@@ -21,7 +21,6 @@ class SquaredL2DistanceOp : public framework::OperatorWithKernel {
public:
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"X"
),
"Input(X) of SquaredL2DistanceOp should not be null."
);
...
...
@@ -85,7 +84,6 @@ class SquaredL2DistanceGradOp : public framework::OperatorWithKernel {
public:
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
framework
::
GradVarName
(
"Out"
)),
"Gradient of Out should not be null"
);
...
...
paddle/operators/sum_op.cc
浏览文件 @
73a8b78a
...
...
@@ -21,7 +21,6 @@ class SumOp : public framework::OperatorWithKernel {
public:
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInputs
(
"X"
),
"Inputs(X) should not be null"
);
auto
x_dims
=
ctx
->
GetInputsDim
(
"X"
);
...
...
paddle/operators/top_k_op.cc
浏览文件 @
73a8b78a
...
...
@@ -21,7 +21,6 @@ class TopkOp : public framework::OperatorWithKernel {
public:
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"X"
),
"Input(X) of TopkOp should not be null."
);
...
...
paddle/operators/transpose_op.cc
浏览文件 @
73a8b78a
...
...
@@ -23,7 +23,6 @@ class TransposeOp : public framework::OperatorWithKernel {
public:
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"X"
),
"Input(X) should not be null"
);
PADDLE_ENFORCE
(
ctx
->
HasOutput
(
"Out"
),
"Output(Out) should not be null"
);
...
...
@@ -92,7 +91,6 @@ class TransposeOpGrad : public framework::OperatorWithKernel {
public:
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"X"
),
"Input(X) should not be null"
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
framework
::
GradVarName
(
"Out"
)),
...
...
paddle/operators/uniform_random_op.cc
浏览文件 @
73a8b78a
...
...
@@ -46,7 +46,6 @@ class UniformRandomOp : public framework::OperatorWithKernel {
public:
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasOutput
(
"Out"
),
"Output(Out) of UniformRandomOp should not be null."
);
...
...
@@ -63,6 +62,7 @@ class UniformRandomOp : public framework::OperatorWithKernel {
ctx
->
SetOutputDim
(
"Out"
,
framework
::
make_ddim
(
temp
));
}
protected:
framework
::
DataType
IndicateDataType
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
return
static_cast
<
framework
::
DataType
>
(
Attr
<
int
>
(
"data_type"
));
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录