Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
机器未来
Paddle
提交
e12ec95a
P
Paddle
项目概览
机器未来
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
e12ec95a
编写于
10月 09, 2017
作者:
Q
Qiao Longfei
提交者:
GitHub
10月 09, 2017
浏览文件
操作
浏览文件
下载
差异文件
Merge pull request #4630 from jacquesqiao/merge-infershapecontext
Merge infershapecontext and ExecutionContext
上级
0ff540cc
c0a34e1c
变更
48
隐藏空白更改
内联
并排
Showing
48 changed file
with
106 addition
and
115 deletion
+106
-115
paddle/framework/operator.cc
paddle/framework/operator.cc
+4
-4
paddle/framework/operator.h
paddle/framework/operator.h
+23
-32
paddle/framework/operator_test.cc
paddle/framework/operator_test.cc
+1
-1
paddle/framework/shape_inference.h
paddle/framework/shape_inference.h
+3
-3
paddle/operators/accuracy_op.cc
paddle/operators/accuracy_op.cc
+1
-1
paddle/operators/activation_op.cc
paddle/operators/activation_op.cc
+2
-2
paddle/operators/adadelta_op.cc
paddle/operators/adadelta_op.cc
+1
-1
paddle/operators/adagrad_op.cc
paddle/operators/adagrad_op.cc
+1
-1
paddle/operators/clip_op.cc
paddle/operators/clip_op.cc
+2
-2
paddle/operators/concat_op.cc
paddle/operators/concat_op.cc
+2
-2
paddle/operators/conv2d_op.cc
paddle/operators/conv2d_op.cc
+2
-2
paddle/operators/cos_sim_op.cc
paddle/operators/cos_sim_op.cc
+2
-2
paddle/operators/crop_op.cc
paddle/operators/crop_op.cc
+2
-2
paddle/operators/cross_entropy_op.cc
paddle/operators/cross_entropy_op.cc
+2
-2
paddle/operators/dropout_op.cc
paddle/operators/dropout_op.cc
+2
-2
paddle/operators/elementwise_op.h
paddle/operators/elementwise_op.h
+2
-2
paddle/operators/fill_zeros_like_op.cc
paddle/operators/fill_zeros_like_op.cc
+1
-1
paddle/operators/gather_op.cc
paddle/operators/gather_op.cc
+2
-2
paddle/operators/gaussian_random_op.cc
paddle/operators/gaussian_random_op.cc
+1
-1
paddle/operators/lookup_table_op.cc
paddle/operators/lookup_table_op.cc
+2
-2
paddle/operators/lstm_unit_op.cc
paddle/operators/lstm_unit_op.cc
+2
-2
paddle/operators/mean_op.cc
paddle/operators/mean_op.cc
+2
-2
paddle/operators/minus_op.cc
paddle/operators/minus_op.cc
+1
-1
paddle/operators/modified_huber_loss_op.cc
paddle/operators/modified_huber_loss_op.cc
+2
-2
paddle/operators/mul_op.cc
paddle/operators/mul_op.cc
+2
-2
paddle/operators/multiplex_op.cc
paddle/operators/multiplex_op.cc
+2
-2
paddle/operators/pad_op.cc
paddle/operators/pad_op.cc
+2
-2
paddle/operators/pool_op.cc
paddle/operators/pool_op.cc
+2
-2
paddle/operators/prelu_op.cc
paddle/operators/prelu_op.cc
+2
-2
paddle/operators/rank_loss_op.cc
paddle/operators/rank_loss_op.cc
+2
-2
paddle/operators/reduce_op.cc
paddle/operators/reduce_op.cc
+2
-2
paddle/operators/reshape_op.cc
paddle/operators/reshape_op.cc
+2
-2
paddle/operators/rmsprop_op.cc
paddle/operators/rmsprop_op.cc
+1
-1
paddle/operators/scale_op.cc
paddle/operators/scale_op.cc
+1
-1
paddle/operators/scatter_op.cc
paddle/operators/scatter_op.cc
+2
-2
paddle/operators/sequence_pool_op.cc
paddle/operators/sequence_pool_op.cc
+2
-2
paddle/operators/sequence_softmax_op.cc
paddle/operators/sequence_softmax_op.cc
+2
-2
paddle/operators/sgd_op.cc
paddle/operators/sgd_op.cc
+1
-1
paddle/operators/sigmoid_cross_entropy_with_logits_op.cc
paddle/operators/sigmoid_cross_entropy_with_logits_op.cc
+2
-2
paddle/operators/smooth_l1_loss_op.cc
paddle/operators/smooth_l1_loss_op.cc
+2
-2
paddle/operators/softmax_op.cc
paddle/operators/softmax_op.cc
+2
-2
paddle/operators/softmax_with_cross_entropy_op.cc
paddle/operators/softmax_with_cross_entropy_op.cc
+2
-2
paddle/operators/split_op.cc
paddle/operators/split_op.cc
+1
-1
paddle/operators/squared_l2_distance_op.cc
paddle/operators/squared_l2_distance_op.cc
+2
-2
paddle/operators/sum_op.cc
paddle/operators/sum_op.cc
+1
-1
paddle/operators/top_k_op.cc
paddle/operators/top_k_op.cc
+1
-1
paddle/operators/transpose_op.cc
paddle/operators/transpose_op.cc
+2
-2
paddle/operators/uniform_random_op.cc
paddle/operators/uniform_random_op.cc
+1
-1
未找到文件。
paddle/framework/operator.cc
浏览文件 @
e12ec95a
...
@@ -205,13 +205,13 @@ void OperatorBase::GenerateTemporaryNames() {
...
@@ -205,13 +205,13 @@ void OperatorBase::GenerateTemporaryNames() {
}
}
template
<
>
template
<
>
const
Tensor
*
InferShape
Context
::
Input
<
Tensor
>
(
const
std
::
string
&
name
)
const
{
const
Tensor
*
Execution
Context
::
Input
<
Tensor
>
(
const
std
::
string
&
name
)
const
{
auto
*
var
=
InputVar
(
name
);
auto
*
var
=
InputVar
(
name
);
return
var
==
nullptr
?
nullptr
:
GetTensorFromVar
(
var
);
return
var
==
nullptr
?
nullptr
:
GetTensorFromVar
(
var
);
}
}
template
<
>
template
<
>
const
std
::
vector
<
const
Tensor
*>
InferShape
Context
::
MultiInput
<
Tensor
>
(
const
std
::
vector
<
const
Tensor
*>
Execution
Context
::
MultiInput
<
Tensor
>
(
const
std
::
string
&
name
)
const
{
const
std
::
string
&
name
)
const
{
auto
names
=
op
().
Inputs
(
name
);
auto
names
=
op
().
Inputs
(
name
);
std
::
vector
<
const
Tensor
*>
res
;
std
::
vector
<
const
Tensor
*>
res
;
...
@@ -225,13 +225,13 @@ const std::vector<const Tensor*> InferShapeContext::MultiInput<Tensor>(
...
@@ -225,13 +225,13 @@ const std::vector<const Tensor*> InferShapeContext::MultiInput<Tensor>(
}
}
template
<
>
template
<
>
Tensor
*
InferShape
Context
::
Output
<
Tensor
>
(
const
std
::
string
&
name
)
const
{
Tensor
*
Execution
Context
::
Output
<
Tensor
>
(
const
std
::
string
&
name
)
const
{
auto
var
=
OutputVar
(
name
);
auto
var
=
OutputVar
(
name
);
return
var
==
nullptr
?
nullptr
:
var
->
GetMutable
<
LoDTensor
>
();
return
var
==
nullptr
?
nullptr
:
var
->
GetMutable
<
LoDTensor
>
();
}
}
template
<
>
template
<
>
std
::
vector
<
Tensor
*>
InferShape
Context
::
MultiOutput
<
Tensor
>
(
std
::
vector
<
Tensor
*>
Execution
Context
::
MultiOutput
<
Tensor
>
(
const
std
::
string
&
name
)
const
{
const
std
::
string
&
name
)
const
{
auto
names
=
op
().
Outputs
(
name
);
auto
names
=
op
().
Outputs
(
name
);
std
::
vector
<
Tensor
*>
res
;
std
::
vector
<
Tensor
*>
res
;
...
...
paddle/framework/operator.h
浏览文件 @
e12ec95a
...
@@ -57,7 +57,6 @@ inline std::string GradVarName(const std::string& var_name) {
...
@@ -57,7 +57,6 @@ inline std::string GradVarName(const std::string& var_name) {
}
}
class
OperatorBase
;
class
OperatorBase
;
class
InferShapeContext
;
class
ExecutionContext
;
class
ExecutionContext
;
extern
const
Tensor
*
GetTensorFromVar
(
const
Variable
*
var
);
extern
const
Tensor
*
GetTensorFromVar
(
const
Variable
*
var
);
...
@@ -169,10 +168,11 @@ class NOP : public OperatorBase {
...
@@ -169,10 +168,11 @@ class NOP : public OperatorBase {
}
}
};
};
class
InferShape
Context
{
class
Execution
Context
{
public:
public:
InferShapeContext
(
const
OperatorBase
&
op
,
const
Scope
&
scope
)
ExecutionContext
(
const
OperatorBase
&
op
,
const
Scope
&
scope
,
:
op_
(
op
),
scope_
(
scope
)
{}
const
platform
::
DeviceContext
&
device_context
)
:
op_
(
op
),
scope_
(
scope
),
device_context_
(
device_context
)
{}
const
OperatorBase
&
op
()
const
{
return
op_
;
}
const
OperatorBase
&
op
()
const
{
return
op_
;
}
...
@@ -278,31 +278,6 @@ class InferShapeContext {
...
@@ -278,31 +278,6 @@ class InferShapeContext {
out_tensor
->
set_lod
(
in_tensor
.
lod
());
out_tensor
->
set_lod
(
in_tensor
.
lod
());
}
}
private:
const
OperatorBase
&
op_
;
const
Scope
&
scope_
;
};
template
<>
const
Tensor
*
InferShapeContext
::
Input
<
Tensor
>
(
const
std
::
string
&
name
)
const
;
template
<>
const
std
::
vector
<
const
Tensor
*>
InferShapeContext
::
MultiInput
<
Tensor
>
(
const
std
::
string
&
name
)
const
;
template
<>
Tensor
*
InferShapeContext
::
Output
<
Tensor
>
(
const
std
::
string
&
name
)
const
;
template
<>
std
::
vector
<
Tensor
*>
InferShapeContext
::
MultiOutput
<
Tensor
>
(
const
std
::
string
&
name
)
const
;
class
ExecutionContext
:
public
InferShapeContext
{
public:
ExecutionContext
(
const
OperatorBase
&
op
,
const
Scope
&
scope
,
const
platform
::
DeviceContext
&
device_context
)
:
InferShapeContext
(
op
,
scope
),
device_context_
(
device_context
)
{}
template
<
typename
PlaceType
,
template
<
typename
PlaceType
,
typename
DeviceType
=
typename
platform
::
EigenDeviceConverter
<
typename
DeviceType
=
typename
platform
::
EigenDeviceConverter
<
PlaceType
>::
EigenDeviceType
>
PlaceType
>::
EigenDeviceType
>
...
@@ -315,10 +290,26 @@ class ExecutionContext : public InferShapeContext {
...
@@ -315,10 +290,26 @@ class ExecutionContext : public InferShapeContext {
}
}
private:
private:
const
OperatorBase
&
op_
;
const
Scope
&
scope_
;
const
platform
::
DeviceContext
&
device_context_
;
const
platform
::
DeviceContext
&
device_context_
;
};
};
class
CompileTimeInferShapeContext
:
public
InferShapeContextBase
{
template
<>
const
Tensor
*
ExecutionContext
::
Input
<
Tensor
>
(
const
std
::
string
&
name
)
const
;
template
<>
const
std
::
vector
<
const
Tensor
*>
ExecutionContext
::
MultiInput
<
Tensor
>
(
const
std
::
string
&
name
)
const
;
template
<>
Tensor
*
ExecutionContext
::
Output
<
Tensor
>
(
const
std
::
string
&
name
)
const
;
template
<>
std
::
vector
<
Tensor
*>
ExecutionContext
::
MultiOutput
<
Tensor
>
(
const
std
::
string
&
name
)
const
;
class
CompileTimeInferShapeContext
:
public
InferShapeContext
{
public:
public:
CompileTimeInferShapeContext
(
const
OpDescBind
&
op
,
const
BlockDescBind
&
block
)
CompileTimeInferShapeContext
(
const
OpDescBind
&
op
,
const
BlockDescBind
&
block
)
:
op_
(
op
),
block_
(
block
)
{}
:
op_
(
op
),
block_
(
block
)
{}
...
@@ -414,7 +405,7 @@ class CompileTimeInferShapeContext : public InferShapeContextBase {
...
@@ -414,7 +405,7 @@ class CompileTimeInferShapeContext : public InferShapeContextBase {
const
BlockDescBind
&
block_
;
const
BlockDescBind
&
block_
;
};
};
class
RuntimeInferShapeContext
:
public
InferShapeContext
Base
{
class
RuntimeInferShapeContext
:
public
InferShapeContext
{
public:
public:
RuntimeInferShapeContext
(
const
OperatorBase
&
op
,
const
Scope
&
scope
)
RuntimeInferShapeContext
(
const
OperatorBase
&
op
,
const
Scope
&
scope
)
:
op_
(
op
),
scope_
(
scope
)
{}
:
op_
(
op
),
scope_
(
scope
)
{}
...
@@ -612,7 +603,7 @@ class OperatorWithKernel : public OperatorBase {
...
@@ -612,7 +603,7 @@ class OperatorWithKernel : public OperatorBase {
});
});
}
}
virtual
void
InferShape
(
InferShapeContext
Base
*
ctx
)
const
=
0
;
virtual
void
InferShape
(
InferShapeContext
*
ctx
)
const
=
0
;
protected:
protected:
// indicate kernel DataType by input data. Defaultly all input data must be
// indicate kernel DataType by input data. Defaultly all input data must be
...
...
paddle/framework/operator_test.cc
浏览文件 @
e12ec95a
...
@@ -113,7 +113,7 @@ class OpWithKernelTest : public OperatorWithKernel {
...
@@ -113,7 +113,7 @@ class OpWithKernelTest : public OperatorWithKernel {
using
OperatorWithKernel
::
OperatorWithKernel
;
using
OperatorWithKernel
::
OperatorWithKernel
;
protected:
protected:
void
InferShape
(
framework
::
InferShapeContext
Base
*
ctx
)
const
override
{}
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{}
DataType
IndicateDataType
(
const
ExecutionContext
&
ctx
)
const
override
{
DataType
IndicateDataType
(
const
ExecutionContext
&
ctx
)
const
override
{
return
DataType
::
FP32
;
return
DataType
::
FP32
;
}
}
...
...
paddle/framework/shape_inference.h
浏览文件 @
e12ec95a
...
@@ -20,11 +20,11 @@ namespace paddle {
...
@@ -20,11 +20,11 @@ namespace paddle {
namespace
framework
{
namespace
framework
{
// TODO(longfei): Once after both CompileTimeInferShapeContext and
// TODO(longfei): Once after both CompileTimeInferShapeContext and
// RuntimeInferShapeContext get merged, we can rename InferShapeContext
Base
into
// RuntimeInferShapeContext get merged, we can rename InferShapeContext into
// InferShapeContext so to replace the current InferShapeContext.
// InferShapeContext so to replace the current InferShapeContext.
class
InferShapeContext
Base
{
class
InferShapeContext
{
public:
public:
virtual
~
InferShapeContext
Base
()
{}
virtual
~
InferShapeContext
()
{}
virtual
bool
HasInput
(
const
std
::
string
&
name
)
const
=
0
;
virtual
bool
HasInput
(
const
std
::
string
&
name
)
const
=
0
;
virtual
bool
HasOutput
(
const
std
::
string
&
name
)
const
=
0
;
virtual
bool
HasOutput
(
const
std
::
string
&
name
)
const
=
0
;
...
...
paddle/operators/accuracy_op.cc
浏览文件 @
e12ec95a
...
@@ -22,7 +22,7 @@ class AccuracyOp : public framework::OperatorWithKernel {
...
@@ -22,7 +22,7 @@ class AccuracyOp : public framework::OperatorWithKernel {
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
protected:
void
InferShape
(
framework
::
InferShapeContext
Base
*
ctx
)
const
override
{
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Inference"
),
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Inference"
),
"Input(Inference) of AccuracyOp should not be null."
);
"Input(Inference) of AccuracyOp should not be null."
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Label"
),
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Label"
),
...
...
paddle/operators/activation_op.cc
浏览文件 @
e12ec95a
...
@@ -22,7 +22,7 @@ class ActivationOp : public framework::OperatorWithKernel {
...
@@ -22,7 +22,7 @@ class ActivationOp : public framework::OperatorWithKernel {
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
protected:
void
InferShape
(
framework
::
InferShapeContext
Base
*
ctx
)
const
override
{
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
ctx
->
SetOutputDim
(
"Y"
,
ctx
->
GetInputDim
(
"X"
));
ctx
->
SetOutputDim
(
"Y"
,
ctx
->
GetInputDim
(
"X"
));
ctx
->
ShareLoD
(
"X"
,
/*->*/
"Y"
);
ctx
->
ShareLoD
(
"X"
,
/*->*/
"Y"
);
}
}
...
@@ -33,7 +33,7 @@ class ActivationOpGrad : public framework::OperatorWithKernel {
...
@@ -33,7 +33,7 @@ class ActivationOpGrad : public framework::OperatorWithKernel {
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
protected:
void
InferShape
(
framework
::
InferShapeContext
Base
*
ctx
)
const
override
{
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
ctx
->
SetOutputDim
(
framework
::
GradVarName
(
"X"
),
ctx
->
GetInputDim
(
"Y"
));
ctx
->
SetOutputDim
(
framework
::
GradVarName
(
"X"
),
ctx
->
GetInputDim
(
"Y"
));
}
}
};
};
...
...
paddle/operators/adadelta_op.cc
浏览文件 @
e12ec95a
...
@@ -22,7 +22,7 @@ class AdadeltaOp : public framework::OperatorWithKernel {
...
@@ -22,7 +22,7 @@ class AdadeltaOp : public framework::OperatorWithKernel {
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
protected:
void
InferShape
(
framework
::
InferShapeContext
Base
*
ctx
)
const
override
{
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Param"
),
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Param"
),
"Input(Param) of AdadeltaOp should not be null."
);
"Input(Param) of AdadeltaOp should not be null."
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Grad"
),
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Grad"
),
...
...
paddle/operators/adagrad_op.cc
浏览文件 @
e12ec95a
...
@@ -22,7 +22,7 @@ class AdagradOp : public framework::OperatorWithKernel {
...
@@ -22,7 +22,7 @@ class AdagradOp : public framework::OperatorWithKernel {
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
protected:
void
InferShape
(
framework
::
InferShapeContext
Base
*
ctx
)
const
override
{
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Param"
),
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Param"
),
"Input(Param) of AdagradOp should not be null."
);
"Input(Param) of AdagradOp should not be null."
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Grad"
),
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Grad"
),
...
...
paddle/operators/clip_op.cc
浏览文件 @
e12ec95a
...
@@ -22,7 +22,7 @@ class ClipOp : public framework::OperatorWithKernel {
...
@@ -22,7 +22,7 @@ class ClipOp : public framework::OperatorWithKernel {
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
protected:
void
InferShape
(
framework
::
InferShapeContext
Base
*
ctx
)
const
override
{
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"X"
),
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"X"
),
"Input(X) of ClipOp should not be null."
);
"Input(X) of ClipOp should not be null."
);
PADDLE_ENFORCE
(
ctx
->
HasOutput
(
"Out"
),
PADDLE_ENFORCE
(
ctx
->
HasOutput
(
"Out"
),
...
@@ -61,7 +61,7 @@ class ClipOpGrad : public framework::OperatorWithKernel {
...
@@ -61,7 +61,7 @@ class ClipOpGrad : public framework::OperatorWithKernel {
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
protected:
void
InferShape
(
framework
::
InferShapeContext
Base
*
ctx
)
const
override
{
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"X"
),
"Input(X) should not be null"
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"X"
),
"Input(X) should not be null"
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
framework
::
GradVarName
(
"Out"
)),
PADDLE_ENFORCE
(
ctx
->
HasInput
(
framework
::
GradVarName
(
"Out"
)),
"Input(Out@GRAD) should not be null"
);
"Input(Out@GRAD) should not be null"
);
...
...
paddle/operators/concat_op.cc
浏览文件 @
e12ec95a
...
@@ -24,7 +24,7 @@ class ConcatOp : public framework::OperatorWithKernel {
...
@@ -24,7 +24,7 @@ class ConcatOp : public framework::OperatorWithKernel {
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
protected:
void
InferShape
(
framework
::
InferShapeContext
Base
*
ctx
)
const
override
{
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE_GE
(
ctx
->
Inputs
(
"X"
).
size
(),
1UL
,
PADDLE_ENFORCE_GE
(
ctx
->
Inputs
(
"X"
).
size
(),
1UL
,
"Inputs(X) of ConcatOp should be empty."
)
"Inputs(X) of ConcatOp should be empty."
)
PADDLE_ENFORCE
(
ctx
->
HasOutput
(
"Out"
),
PADDLE_ENFORCE
(
ctx
->
HasOutput
(
"Out"
),
...
@@ -83,7 +83,7 @@ class ConcatOpGrad : public framework::OperatorWithKernel {
...
@@ -83,7 +83,7 @@ class ConcatOpGrad : public framework::OperatorWithKernel {
:
OperatorWithKernel
(
type
,
inputs
,
outputs
,
attrs
)
{}
:
OperatorWithKernel
(
type
,
inputs
,
outputs
,
attrs
)
{}
protected:
protected:
void
InferShape
(
framework
::
InferShapeContext
Base
*
ctx
)
const
override
{
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
ctx
->
SetOutputsDim
(
framework
::
GradVarName
(
"X"
),
ctx
->
GetInputsDim
(
"X"
));
ctx
->
SetOutputsDim
(
framework
::
GradVarName
(
"X"
),
ctx
->
GetInputsDim
(
"X"
));
}
}
};
};
...
...
paddle/operators/conv2d_op.cc
浏览文件 @
e12ec95a
...
@@ -27,7 +27,7 @@ class Conv2DOp : public framework::OperatorWithKernel {
...
@@ -27,7 +27,7 @@ class Conv2DOp : public framework::OperatorWithKernel {
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
protected:
void
InferShape
(
framework
::
InferShapeContext
Base
*
ctx
)
const
override
{
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Input"
),
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Input"
),
"Input(Input) of Conv2DOp should not be null."
);
"Input(Input) of Conv2DOp should not be null."
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Filter"
),
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Filter"
),
...
@@ -106,7 +106,7 @@ class Conv2DOpGrad : public framework::OperatorWithKernel {
...
@@ -106,7 +106,7 @@ class Conv2DOpGrad : public framework::OperatorWithKernel {
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
protected:
void
InferShape
(
framework
::
InferShapeContext
Base
*
ctx
)
const
override
{
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
auto
in_dims
=
ctx
->
GetInputDim
(
"Input"
);
auto
in_dims
=
ctx
->
GetInputDim
(
"Input"
);
auto
filter_dims
=
ctx
->
GetInputDim
(
"Filter"
);
auto
filter_dims
=
ctx
->
GetInputDim
(
"Filter"
);
if
(
ctx
->
HasOutput
(
framework
::
GradVarName
(
"Input"
)))
{
if
(
ctx
->
HasOutput
(
framework
::
GradVarName
(
"Input"
)))
{
...
...
paddle/operators/cos_sim_op.cc
浏览文件 @
e12ec95a
...
@@ -24,7 +24,7 @@ class CosSimOp : public framework::OperatorWithKernel {
...
@@ -24,7 +24,7 @@ class CosSimOp : public framework::OperatorWithKernel {
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
protected:
void
InferShape
(
framework
::
InferShapeContext
Base
*
ctx
)
const
override
{
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
// notnull check
// notnull check
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"X"
),
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"X"
),
"Input(X) of CosSimOp should not be null."
);
"Input(X) of CosSimOp should not be null."
);
...
@@ -98,7 +98,7 @@ class CosSimOpGrad : public framework::OperatorWithKernel {
...
@@ -98,7 +98,7 @@ class CosSimOpGrad : public framework::OperatorWithKernel {
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
protected:
void
InferShape
(
framework
::
InferShapeContext
Base
*
ctx
)
const
override
{
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
// notnull check
// notnull check
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"X"
),
"Input(X) must not be null."
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"X"
),
"Input(X) must not be null."
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Y"
),
"Input(Y) must not be null."
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Y"
),
"Input(Y) must not be null."
);
...
...
paddle/operators/crop_op.cc
浏览文件 @
e12ec95a
...
@@ -25,7 +25,7 @@ class CropOp : public framework::OperatorWithKernel {
...
@@ -25,7 +25,7 @@ class CropOp : public framework::OperatorWithKernel {
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
protected:
void
InferShape
(
framework
::
InferShapeContext
Base
*
ctx
)
const
override
{
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"X"
),
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"X"
),
"Input(X) of CropOp should not be null."
);
"Input(X) of CropOp should not be null."
);
PADDLE_ENFORCE
(
ctx
->
HasOutput
(
"Out"
),
PADDLE_ENFORCE
(
ctx
->
HasOutput
(
"Out"
),
...
@@ -115,7 +115,7 @@ class CropOpGrad : public framework::OperatorWithKernel {
...
@@ -115,7 +115,7 @@ class CropOpGrad : public framework::OperatorWithKernel {
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
protected:
void
InferShape
(
framework
::
InferShapeContext
Base
*
ctx
)
const
override
{
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"X"
),
"Input(X) should not be null"
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"X"
),
"Input(X) should not be null"
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
framework
::
GradVarName
(
"Out"
)),
PADDLE_ENFORCE
(
ctx
->
HasInput
(
framework
::
GradVarName
(
"Out"
)),
"Input(Out@GRAD) should not be null"
);
"Input(Out@GRAD) should not be null"
);
...
...
paddle/operators/cross_entropy_op.cc
浏览文件 @
e12ec95a
...
@@ -22,7 +22,7 @@ class CrossEntropyOp : public framework::OperatorWithKernel {
...
@@ -22,7 +22,7 @@ class CrossEntropyOp : public framework::OperatorWithKernel {
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
protected:
void
InferShape
(
framework
::
InferShapeContext
Base
*
ctx
)
const
override
{
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"X"
),
"Input(X) should be not null."
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"X"
),
"Input(X) should be not null."
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Label"
),
"Input(Label) should be not null."
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Label"
),
"Input(Label) should be not null."
);
PADDLE_ENFORCE
(
ctx
->
HasOutput
(
"Y"
),
"Output(Y) should be not null."
);
PADDLE_ENFORCE
(
ctx
->
HasOutput
(
"Y"
),
"Output(Y) should be not null."
);
...
@@ -60,7 +60,7 @@ class CrossEntropyGradientOp : public framework::OperatorWithKernel {
...
@@ -60,7 +60,7 @@ class CrossEntropyGradientOp : public framework::OperatorWithKernel {
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
protected:
void
InferShape
(
framework
::
InferShapeContext
Base
*
ctx
)
const
override
{
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"X"
),
"Input(X) should be not null."
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"X"
),
"Input(X) should be not null."
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Label"
),
"Input(Label) should be not null."
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Label"
),
"Input(Label) should be not null."
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
framework
::
GradVarName
(
"Y"
)),
PADDLE_ENFORCE
(
ctx
->
HasInput
(
framework
::
GradVarName
(
"Y"
)),
...
...
paddle/operators/dropout_op.cc
浏览文件 @
e12ec95a
...
@@ -24,7 +24,7 @@ class DropoutOp : public framework::OperatorWithKernel {
...
@@ -24,7 +24,7 @@ class DropoutOp : public framework::OperatorWithKernel {
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
protected:
void
InferShape
(
framework
::
InferShapeContext
Base
*
ctx
)
const
override
{
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"X"
),
"Input(X) must not be null."
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"X"
),
"Input(X) must not be null."
);
PADDLE_ENFORCE_GE
(
ctx
->
Attrs
().
Get
<
float
>
(
"dropout_prob"
),
0
);
PADDLE_ENFORCE_GE
(
ctx
->
Attrs
().
Get
<
float
>
(
"dropout_prob"
),
0
);
PADDLE_ENFORCE_LE
(
ctx
->
Attrs
().
Get
<
float
>
(
"dropout_prob"
),
1
);
PADDLE_ENFORCE_LE
(
ctx
->
Attrs
().
Get
<
float
>
(
"dropout_prob"
),
1
);
...
@@ -70,7 +70,7 @@ class DropoutOpGrad : public framework::OperatorWithKernel {
...
@@ -70,7 +70,7 @@ class DropoutOpGrad : public framework::OperatorWithKernel {
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
protected:
void
InferShape
(
framework
::
InferShapeContext
Base
*
ctx
)
const
override
{
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE_EQ
(
ctx
->
Attrs
().
Get
<
bool
>
(
"is_training"
),
1
,
PADDLE_ENFORCE_EQ
(
ctx
->
Attrs
().
Get
<
bool
>
(
"is_training"
),
1
,
"GradOp is only callable when is_training is true"
);
"GradOp is only callable when is_training is true"
);
...
...
paddle/operators/elementwise_op.h
浏览文件 @
e12ec95a
...
@@ -25,7 +25,7 @@ class ElementwiseOp : public framework::OperatorWithKernel {
...
@@ -25,7 +25,7 @@ class ElementwiseOp : public framework::OperatorWithKernel {
protected:
protected:
using
Tensor
=
framework
::
Tensor
;
using
Tensor
=
framework
::
Tensor
;
void
InferShape
(
framework
::
InferShapeContext
Base
*
ctx
)
const
override
{
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"X"
),
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"X"
),
"Input(X) of elementwise op should not be null"
);
"Input(X) of elementwise op should not be null"
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Y"
),
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Y"
),
...
@@ -106,7 +106,7 @@ class ElementwiseOpGrad : public framework::OperatorWithKernel {
...
@@ -106,7 +106,7 @@ class ElementwiseOpGrad : public framework::OperatorWithKernel {
using
Tensor
=
framework
::
Tensor
;
using
Tensor
=
framework
::
Tensor
;
protected:
protected:
void
InferShape
(
framework
::
InferShapeContext
Base
*
ctx
)
const
override
{
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"X"
),
"Input(X) should not be null"
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"X"
),
"Input(X) should not be null"
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Y"
),
"Input(Y) should not be null"
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Y"
),
"Input(Y) should not be null"
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
framework
::
GradVarName
(
"Out"
)),
PADDLE_ENFORCE
(
ctx
->
HasInput
(
framework
::
GradVarName
(
"Out"
)),
...
...
paddle/operators/fill_zeros_like_op.cc
浏览文件 @
e12ec95a
...
@@ -22,7 +22,7 @@ class FillZerosLikeOp : public framework::OperatorWithKernel {
...
@@ -22,7 +22,7 @@ class FillZerosLikeOp : public framework::OperatorWithKernel {
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
protected:
void
InferShape
(
framework
::
InferShapeContext
Base
*
ctx
)
const
override
{
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"X"
),
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"X"
),
"Input(X) of FillZerosLikeOp should not be null."
);
"Input(X) of FillZerosLikeOp should not be null."
);
PADDLE_ENFORCE
(
ctx
->
HasOutput
(
"Y"
),
PADDLE_ENFORCE
(
ctx
->
HasOutput
(
"Y"
),
...
...
paddle/operators/gather_op.cc
浏览文件 @
e12ec95a
...
@@ -23,7 +23,7 @@ class GatherOp : public framework::OperatorWithKernel {
...
@@ -23,7 +23,7 @@ class GatherOp : public framework::OperatorWithKernel {
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
protected:
void
InferShape
(
framework
::
InferShapeContext
Base
*
ctx
)
const
override
{
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"X"
),
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"X"
),
"Input(X) of GatherOp should not be null."
);
"Input(X) of GatherOp should not be null."
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Index"
),
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Index"
),
...
@@ -51,7 +51,7 @@ class GatherGradOp : public framework::OperatorWithKernel {
...
@@ -51,7 +51,7 @@ class GatherGradOp : public framework::OperatorWithKernel {
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
protected:
void
InferShape
(
framework
::
InferShapeContext
Base
*
ctx
)
const
override
{
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
ctx
->
SetOutputDim
(
framework
::
GradVarName
(
"X"
),
ctx
->
GetInputDim
(
"X"
));
ctx
->
SetOutputDim
(
framework
::
GradVarName
(
"X"
),
ctx
->
GetInputDim
(
"X"
));
}
}
...
...
paddle/operators/gaussian_random_op.cc
浏览文件 @
e12ec95a
...
@@ -43,7 +43,7 @@ class GaussianRandomOp : public framework::OperatorWithKernel {
...
@@ -43,7 +43,7 @@ class GaussianRandomOp : public framework::OperatorWithKernel {
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
protected:
void
InferShape
(
framework
::
InferShapeContext
Base
*
ctx
)
const
override
{
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasOutput
(
"Out"
),
PADDLE_ENFORCE
(
ctx
->
HasOutput
(
"Out"
),
"Output(Out) of GaussianRandomOp should not be null."
);
"Output(Out) of GaussianRandomOp should not be null."
);
auto
dims
=
ctx
->
Attrs
().
Get
<
std
::
vector
<
int
>>
(
"dims"
);
auto
dims
=
ctx
->
Attrs
().
Get
<
std
::
vector
<
int
>>
(
"dims"
);
...
...
paddle/operators/lookup_table_op.cc
浏览文件 @
e12ec95a
...
@@ -22,7 +22,7 @@ class LookupTableOp : public framework::OperatorWithKernel {
...
@@ -22,7 +22,7 @@ class LookupTableOp : public framework::OperatorWithKernel {
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
protected:
void
InferShape
(
framework
::
InferShapeContext
Base
*
ctx
)
const
override
{
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"W"
),
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"W"
),
"Input(W) of LookupTableOp should not be null."
);
"Input(W) of LookupTableOp should not be null."
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Ids"
),
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Ids"
),
...
@@ -70,7 +70,7 @@ class LookupTableOpGrad : public framework::OperatorWithKernel {
...
@@ -70,7 +70,7 @@ class LookupTableOpGrad : public framework::OperatorWithKernel {
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
protected:
void
InferShape
(
framework
::
InferShapeContext
Base
*
ctx
)
const
override
{
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
auto
table_dims
=
ctx
->
GetInputDim
(
"W"
);
auto
table_dims
=
ctx
->
GetInputDim
(
"W"
);
ctx
->
SetOutputDim
(
framework
::
GradVarName
(
"W"
),
table_dims
);
ctx
->
SetOutputDim
(
framework
::
GradVarName
(
"W"
),
table_dims
);
}
}
...
...
paddle/operators/lstm_unit_op.cc
浏览文件 @
e12ec95a
...
@@ -22,7 +22,7 @@ class LstmUnitOp : public framework::OperatorWithKernel {
...
@@ -22,7 +22,7 @@ class LstmUnitOp : public framework::OperatorWithKernel {
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
protected:
void
InferShape
(
framework
::
InferShapeContext
Base
*
ctx
)
const
override
{
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"X"
),
"Input(X) of LSTM should not be null."
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"X"
),
"Input(X) of LSTM should not be null."
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"C_prev"
),
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"C_prev"
),
"Input(C_prev) of LSTM should not be null."
);
"Input(C_prev) of LSTM should not be null."
);
...
@@ -77,7 +77,7 @@ class LstmUnitGradOp : public framework::OperatorWithKernel {
...
@@ -77,7 +77,7 @@ class LstmUnitGradOp : public framework::OperatorWithKernel {
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
protected:
void
InferShape
(
framework
::
InferShapeContext
Base
*
ctx
)
const
override
{
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
framework
::
GradVarName
(
"C"
)),
PADDLE_ENFORCE
(
ctx
->
HasInput
(
framework
::
GradVarName
(
"C"
)),
"Input(C@GRAD) should not be null"
);
"Input(C@GRAD) should not be null"
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
framework
::
GradVarName
(
"H"
)),
PADDLE_ENFORCE
(
ctx
->
HasInput
(
framework
::
GradVarName
(
"H"
)),
...
...
paddle/operators/mean_op.cc
浏览文件 @
e12ec95a
...
@@ -22,7 +22,7 @@ class MeanOp : public framework::OperatorWithKernel {
...
@@ -22,7 +22,7 @@ class MeanOp : public framework::OperatorWithKernel {
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
protected:
void
InferShape
(
framework
::
InferShapeContext
Base
*
ctx
)
const
override
{
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"X"
),
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"X"
),
"Input(X) of MeanOp should not be null."
);
"Input(X) of MeanOp should not be null."
);
PADDLE_ENFORCE
(
ctx
->
HasOutput
(
"Out"
),
PADDLE_ENFORCE
(
ctx
->
HasOutput
(
"Out"
),
...
@@ -47,7 +47,7 @@ class MeanGradOp : public framework::OperatorWithKernel {
...
@@ -47,7 +47,7 @@ class MeanGradOp : public framework::OperatorWithKernel {
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
protected:
void
InferShape
(
framework
::
InferShapeContext
Base
*
ctx
)
const
override
{
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
ctx
->
SetOutputDim
(
framework
::
GradVarName
(
"X"
),
ctx
->
GetInputDim
(
"X"
));
ctx
->
SetOutputDim
(
framework
::
GradVarName
(
"X"
),
ctx
->
GetInputDim
(
"X"
));
}
}
};
};
...
...
paddle/operators/minus_op.cc
浏览文件 @
e12ec95a
...
@@ -26,7 +26,7 @@ class MinusOp : public framework::OperatorWithKernel {
...
@@ -26,7 +26,7 @@ class MinusOp : public framework::OperatorWithKernel {
:
OperatorWithKernel
(
type
,
inputs
,
outputs
,
attrs
)
{}
:
OperatorWithKernel
(
type
,
inputs
,
outputs
,
attrs
)
{}
protected:
protected:
void
InferShape
(
framework
::
InferShapeContext
Base
*
ctx
)
const
override
{
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"X"
),
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"X"
),
"Input(X) of MinusOp should not be null."
);
"Input(X) of MinusOp should not be null."
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Y"
),
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Y"
),
...
...
paddle/operators/modified_huber_loss_op.cc
浏览文件 @
e12ec95a
...
@@ -22,7 +22,7 @@ class ModifiedHuberLossOp : public framework::OperatorWithKernel {
...
@@ -22,7 +22,7 @@ class ModifiedHuberLossOp : public framework::OperatorWithKernel {
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
protected:
void
InferShape
(
framework
::
InferShapeContext
Base
*
ctx
)
const
override
{
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"X"
),
"X must be initialized."
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"X"
),
"X must be initialized."
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Y"
),
"Y must be initialized."
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Y"
),
"Y must be initialized."
);
...
@@ -74,7 +74,7 @@ class ModifiedHuberLossGradOp : public framework::OperatorWithKernel {
...
@@ -74,7 +74,7 @@ class ModifiedHuberLossGradOp : public framework::OperatorWithKernel {
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
protected:
void
InferShape
(
framework
::
InferShapeContext
Base
*
ctx
)
const
override
{
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"X"
),
"X must be initialized."
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"X"
),
"X must be initialized."
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Y"
),
"Y must be initialized."
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Y"
),
"Y must be initialized."
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"IntermediateVal"
),
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"IntermediateVal"
),
...
...
paddle/operators/mul_op.cc
浏览文件 @
e12ec95a
...
@@ -24,7 +24,7 @@ class MulOp : public framework::OperatorWithKernel {
...
@@ -24,7 +24,7 @@ class MulOp : public framework::OperatorWithKernel {
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
protected:
void
InferShape
(
framework
::
InferShapeContext
Base
*
ctx
)
const
override
{
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"X"
),
"Input(X) of MulOp should not be null."
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"X"
),
"Input(X) of MulOp should not be null."
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Y"
),
"Input(Y) of MulOp should not be null."
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Y"
),
"Input(Y) of MulOp should not be null."
);
PADDLE_ENFORCE
(
ctx
->
HasOutput
(
"Out"
),
PADDLE_ENFORCE
(
ctx
->
HasOutput
(
"Out"
),
...
@@ -97,7 +97,7 @@ class MulOpGrad : public framework::OperatorWithKernel {
...
@@ -97,7 +97,7 @@ class MulOpGrad : public framework::OperatorWithKernel {
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
protected:
void
InferShape
(
framework
::
InferShapeContext
Base
*
ctx
)
const
override
{
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"X"
),
"Input(X) should not be null"
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"X"
),
"Input(X) should not be null"
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Y"
),
"Input(Y) should not be null"
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Y"
),
"Input(Y) should not be null"
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
framework
::
GradVarName
(
"Out"
)),
PADDLE_ENFORCE
(
ctx
->
HasInput
(
framework
::
GradVarName
(
"Out"
)),
...
...
paddle/operators/multiplex_op.cc
浏览文件 @
e12ec95a
...
@@ -24,7 +24,7 @@ class MultiplexOp : public framework::OperatorWithKernel {
...
@@ -24,7 +24,7 @@ class MultiplexOp : public framework::OperatorWithKernel {
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
protected:
void
InferShape
(
framework
::
InferShapeContext
Base
*
ctx
)
const
override
{
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Ids"
),
"Input(Ids) shouldn't be null."
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Ids"
),
"Input(Ids) shouldn't be null."
);
PADDLE_ENFORCE
(
!
ctx
->
Inputs
(
"X"
).
empty
(),
PADDLE_ENFORCE
(
!
ctx
->
Inputs
(
"X"
).
empty
(),
"MultiInput(X) shouldn't be empty."
);
"MultiInput(X) shouldn't be empty."
);
...
@@ -90,7 +90,7 @@ class MultiplexGradOp : public framework::OperatorWithKernel {
...
@@ -90,7 +90,7 @@ class MultiplexGradOp : public framework::OperatorWithKernel {
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
protected:
void
InferShape
(
framework
::
InferShapeContext
Base
*
ctx
)
const
override
{
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
!
ctx
->
Inputs
(
"X"
).
empty
(),
"Input(X) should not be null."
);
PADDLE_ENFORCE
(
!
ctx
->
Inputs
(
"X"
).
empty
(),
"Input(X) should not be null."
);
PADDLE_ENFORCE
(
!
ctx
->
Outputs
(
framework
::
GradVarName
(
"X"
)).
empty
(),
PADDLE_ENFORCE
(
!
ctx
->
Outputs
(
framework
::
GradVarName
(
"X"
)).
empty
(),
"Output(X@Grad) should not be null."
);
"Output(X@Grad) should not be null."
);
...
...
paddle/operators/pad_op.cc
浏览文件 @
e12ec95a
...
@@ -24,7 +24,7 @@ class PadOp : public framework::OperatorWithKernel {
...
@@ -24,7 +24,7 @@ class PadOp : public framework::OperatorWithKernel {
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
protected:
void
InferShape
(
framework
::
InferShapeContext
Base
*
ctx
)
const
override
{
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"X"
),
"Input(X) of PadOp should not be null."
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"X"
),
"Input(X) of PadOp should not be null."
);
PADDLE_ENFORCE
(
ctx
->
HasOutput
(
"Out"
),
PADDLE_ENFORCE
(
ctx
->
HasOutput
(
"Out"
),
"Output(Out) of PadOp should not be null."
);
"Output(Out) of PadOp should not be null."
);
...
@@ -98,7 +98,7 @@ class PadOpGrad : public framework::OperatorWithKernel {
...
@@ -98,7 +98,7 @@ class PadOpGrad : public framework::OperatorWithKernel {
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
protected:
void
InferShape
(
framework
::
InferShapeContext
Base
*
ctx
)
const
override
{
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"X"
),
"Input(X) should not be null"
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"X"
),
"Input(X) should not be null"
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
framework
::
GradVarName
(
"Out"
)),
PADDLE_ENFORCE
(
ctx
->
HasInput
(
framework
::
GradVarName
(
"Out"
)),
"Input(Out@GRAD) should not be null"
);
"Input(Out@GRAD) should not be null"
);
...
...
paddle/operators/pool_op.cc
浏览文件 @
e12ec95a
...
@@ -27,7 +27,7 @@ class PoolOp : public framework::OperatorWithKernel {
...
@@ -27,7 +27,7 @@ class PoolOp : public framework::OperatorWithKernel {
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
protected:
void
InferShape
(
framework
::
InferShapeContext
Base
*
ctx
)
const
override
{
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"X"
),
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"X"
),
"X(Input) of Pooling should not be null."
);
"X(Input) of Pooling should not be null."
);
PADDLE_ENFORCE
(
ctx
->
HasOutput
(
"Out"
),
PADDLE_ENFORCE
(
ctx
->
HasOutput
(
"Out"
),
...
@@ -74,7 +74,7 @@ class PoolOpGrad : public framework::OperatorWithKernel {
...
@@ -74,7 +74,7 @@ class PoolOpGrad : public framework::OperatorWithKernel {
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
protected:
void
InferShape
(
framework
::
InferShapeContext
Base
*
ctx
)
const
override
{
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"X"
),
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"X"
),
"X(Input) of Pooling should not be null."
);
"X(Input) of Pooling should not be null."
);
PADDLE_ENFORCE
(
ctx
->
HasOutput
(
framework
::
GradVarName
(
"X"
)),
PADDLE_ENFORCE
(
ctx
->
HasOutput
(
framework
::
GradVarName
(
"X"
)),
...
...
paddle/operators/prelu_op.cc
浏览文件 @
e12ec95a
...
@@ -26,7 +26,7 @@ class PReluOp : public framework::OperatorWithKernel {
...
@@ -26,7 +26,7 @@ class PReluOp : public framework::OperatorWithKernel {
:
OperatorWithKernel
(
type
,
inputs
,
outputs
,
attrs
)
{}
:
OperatorWithKernel
(
type
,
inputs
,
outputs
,
attrs
)
{}
protected:
protected:
void
InferShape
(
framework
::
InferShapeContext
Base
*
ctx
)
const
override
{
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"X"
),
"Input(X) should not be null"
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"X"
),
"Input(X) should not be null"
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Alpha"
),
"Input(Alpha) should not be null"
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Alpha"
),
"Input(Alpha) should not be null"
);
PADDLE_ENFORCE
(
product
(
ctx
->
GetInputDim
(
"Alpha"
))
==
1
,
PADDLE_ENFORCE
(
product
(
ctx
->
GetInputDim
(
"Alpha"
))
==
1
,
...
@@ -63,7 +63,7 @@ class PReluGradOp : public framework::OperatorWithKernel {
...
@@ -63,7 +63,7 @@ class PReluGradOp : public framework::OperatorWithKernel {
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
protected:
void
InferShape
(
framework
::
InferShapeContext
Base
*
ctx
)
const
override
{
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"X"
),
"Input(X) must not be null."
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"X"
),
"Input(X) must not be null."
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
framework
::
GradVarName
(
"Out"
)),
PADDLE_ENFORCE
(
ctx
->
HasInput
(
framework
::
GradVarName
(
"Out"
)),
"Input(Out@GRAD) should not be null"
);
"Input(Out@GRAD) should not be null"
);
...
...
paddle/operators/rank_loss_op.cc
浏览文件 @
e12ec95a
...
@@ -25,7 +25,7 @@ class RankLossOp : public framework::OperatorWithKernel {
...
@@ -25,7 +25,7 @@ class RankLossOp : public framework::OperatorWithKernel {
:
OperatorWithKernel
(
type
,
inputs
,
outputs
,
attrs
)
{}
:
OperatorWithKernel
(
type
,
inputs
,
outputs
,
attrs
)
{}
protected:
protected:
void
InferShape
(
framework
::
InferShapeContext
Base
*
ctx
)
const
override
{
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
// input check
// input check
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Label"
),
"Input(Label) shouldn't be null"
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Label"
),
"Input(Label) shouldn't be null"
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Left"
),
"Input(Left) shouldn't be null"
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Left"
),
"Input(Left) shouldn't be null"
);
...
@@ -90,7 +90,7 @@ class RankLossGradOp : public framework::OperatorWithKernel {
...
@@ -90,7 +90,7 @@ class RankLossGradOp : public framework::OperatorWithKernel {
:
OperatorWithKernel
(
type
,
inputs
,
outputs
,
attrs
)
{}
:
OperatorWithKernel
(
type
,
inputs
,
outputs
,
attrs
)
{}
protected:
protected:
void
InferShape
(
framework
::
InferShapeContext
Base
*
ctx
)
const
override
{
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Label"
),
"Input(Label) shouldn't be null."
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Label"
),
"Input(Label) shouldn't be null."
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Left"
),
"Input(Left) shouldn't be null."
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Left"
),
"Input(Left) shouldn't be null."
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Right"
),
"Input(Right) shouldn't be null."
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Right"
),
"Input(Right) shouldn't be null."
);
...
...
paddle/operators/reduce_op.cc
浏览文件 @
e12ec95a
...
@@ -24,7 +24,7 @@ class ReduceOp : public framework::OperatorWithKernel {
...
@@ -24,7 +24,7 @@ class ReduceOp : public framework::OperatorWithKernel {
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
protected:
void
InferShape
(
framework
::
InferShapeContext
Base
*
ctx
)
const
override
{
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"X"
),
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"X"
),
"Input(X) of ReduceOp should not be null."
);
"Input(X) of ReduceOp should not be null."
);
PADDLE_ENFORCE
(
ctx
->
HasOutput
(
"Out"
),
PADDLE_ENFORCE
(
ctx
->
HasOutput
(
"Out"
),
...
@@ -58,7 +58,7 @@ class ReduceGradOp : public framework::OperatorWithKernel {
...
@@ -58,7 +58,7 @@ class ReduceGradOp : public framework::OperatorWithKernel {
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
protected:
void
InferShape
(
framework
::
InferShapeContext
Base
*
ctx
)
const
override
{
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"X"
),
"Input(X) should not be null."
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"X"
),
"Input(X) should not be null."
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
framework
::
GradVarName
(
"Out"
)),
PADDLE_ENFORCE
(
ctx
->
HasInput
(
framework
::
GradVarName
(
"Out"
)),
"Input(Out@GRAD) should not be null."
);
"Input(Out@GRAD) should not be null."
);
...
...
paddle/operators/reshape_op.cc
浏览文件 @
e12ec95a
...
@@ -26,7 +26,7 @@ class ReshapeOp : public framework::OperatorWithKernel {
...
@@ -26,7 +26,7 @@ class ReshapeOp : public framework::OperatorWithKernel {
:
OperatorWithKernel
(
type
,
inputs
,
outputs
,
attrs
)
{}
:
OperatorWithKernel
(
type
,
inputs
,
outputs
,
attrs
)
{}
protected:
protected:
void
InferShape
(
framework
::
InferShapeContext
Base
*
ctx
)
const
override
{
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
// input check
// input check
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"X"
),
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"X"
),
"Input(X) of ReshapeOp should not be null."
);
"Input(X) of ReshapeOp should not be null."
);
...
@@ -94,7 +94,7 @@ class ReshapeGradOp : public framework::OperatorWithKernel {
...
@@ -94,7 +94,7 @@ class ReshapeGradOp : public framework::OperatorWithKernel {
:
OperatorWithKernel
(
type
,
inputs
,
outputs
,
attrs
)
{}
:
OperatorWithKernel
(
type
,
inputs
,
outputs
,
attrs
)
{}
protected:
protected:
void
InferShape
(
framework
::
InferShapeContext
Base
*
ctx
)
const
override
{
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"X"
),
"Input(X) shouldn't be null."
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"X"
),
"Input(X) shouldn't be null."
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
framework
::
GradVarName
(
"Out"
)),
PADDLE_ENFORCE
(
ctx
->
HasInput
(
framework
::
GradVarName
(
"Out"
)),
"Input(Out@GRAD) shouldn't be null."
);
"Input(Out@GRAD) shouldn't be null."
);
...
...
paddle/operators/rmsprop_op.cc
浏览文件 @
e12ec95a
...
@@ -22,7 +22,7 @@ class RmspropOp : public framework::OperatorWithKernel {
...
@@ -22,7 +22,7 @@ class RmspropOp : public framework::OperatorWithKernel {
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
protected:
void
InferShape
(
framework
::
InferShapeContext
Base
*
ctx
)
const
override
{
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Param"
),
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Param"
),
"Input(Param) of RmspropOp should not be null."
);
"Input(Param) of RmspropOp should not be null."
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"MeanSquare"
),
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"MeanSquare"
),
...
...
paddle/operators/scale_op.cc
浏览文件 @
e12ec95a
...
@@ -26,7 +26,7 @@ class ScaleOp : public framework::OperatorWithKernel {
...
@@ -26,7 +26,7 @@ class ScaleOp : public framework::OperatorWithKernel {
:
OperatorWithKernel
(
type
,
inputs
,
outputs
,
attrs
)
{}
:
OperatorWithKernel
(
type
,
inputs
,
outputs
,
attrs
)
{}
protected:
protected:
void
InferShape
(
framework
::
InferShapeContext
Base
*
ctx
)
const
override
{
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"X"
),
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"X"
),
"Input(X) of ScaleOp should not be null."
);
"Input(X) of ScaleOp should not be null."
);
PADDLE_ENFORCE
(
ctx
->
HasOutput
(
"Out"
),
PADDLE_ENFORCE
(
ctx
->
HasOutput
(
"Out"
),
...
...
paddle/operators/scatter_op.cc
浏览文件 @
e12ec95a
...
@@ -23,7 +23,7 @@ class ScatterOp : public framework::OperatorWithKernel {
...
@@ -23,7 +23,7 @@ class ScatterOp : public framework::OperatorWithKernel {
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
protected:
void
InferShape
(
framework
::
InferShapeContext
Base
*
ctx
)
const
override
{
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Ref"
),
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Ref"
),
"Input(Ref) of ScatterOp should not be null."
);
"Input(Ref) of ScatterOp should not be null."
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Index"
),
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Index"
),
...
@@ -60,7 +60,7 @@ class ScatterGradOp : public framework::OperatorWithKernel {
...
@@ -60,7 +60,7 @@ class ScatterGradOp : public framework::OperatorWithKernel {
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
protected:
void
InferShape
(
framework
::
InferShapeContext
Base
*
ctx
)
const
override
{
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
ctx
->
SetOutputDim
(
framework
::
GradVarName
(
"Updates"
),
ctx
->
SetOutputDim
(
framework
::
GradVarName
(
"Updates"
),
ctx
->
GetInputDim
(
"Updates"
));
ctx
->
GetInputDim
(
"Updates"
));
ctx
->
SetOutputDim
(
framework
::
GradVarName
(
"Ref"
),
ctx
->
GetInputDim
(
"Ref"
));
ctx
->
SetOutputDim
(
framework
::
GradVarName
(
"Ref"
),
ctx
->
GetInputDim
(
"Ref"
));
...
...
paddle/operators/sequence_pool_op.cc
浏览文件 @
e12ec95a
...
@@ -22,7 +22,7 @@ class SequencePoolOp : public framework::OperatorWithKernel {
...
@@ -22,7 +22,7 @@ class SequencePoolOp : public framework::OperatorWithKernel {
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
protected:
void
InferShape
(
framework
::
InferShapeContext
Base
*
ctx
)
const
override
{
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"X"
),
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"X"
),
"Input(X) of SequencePoolOp should not be null."
);
"Input(X) of SequencePoolOp should not be null."
);
PADDLE_ENFORCE
(
ctx
->
HasOutput
(
"Out"
),
PADDLE_ENFORCE
(
ctx
->
HasOutput
(
"Out"
),
...
@@ -74,7 +74,7 @@ class SequencePoolGradOp : public framework::OperatorWithKernel {
...
@@ -74,7 +74,7 @@ class SequencePoolGradOp : public framework::OperatorWithKernel {
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
protected:
void
InferShape
(
framework
::
InferShapeContext
Base
*
ctx
)
const
override
{
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
framework
::
GradVarName
(
"Out"
)),
PADDLE_ENFORCE
(
ctx
->
HasInput
(
framework
::
GradVarName
(
"Out"
)),
"Gradient of Out should not be null."
);
"Gradient of Out should not be null."
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"X"
),
"The input X should not be null."
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"X"
),
"The input X should not be null."
);
...
...
paddle/operators/sequence_softmax_op.cc
浏览文件 @
e12ec95a
...
@@ -22,7 +22,7 @@ class SequenceSoftmaxOp : public framework::OperatorWithKernel {
...
@@ -22,7 +22,7 @@ class SequenceSoftmaxOp : public framework::OperatorWithKernel {
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
protected:
void
InferShape
(
framework
::
InferShapeContext
Base
*
ctx
)
const
override
{
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"X"
),
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"X"
),
"Input(X) of SequenceSoftmaxOp should not be null."
);
"Input(X) of SequenceSoftmaxOp should not be null."
);
PADDLE_ENFORCE
(
ctx
->
HasOutput
(
"Out"
),
PADDLE_ENFORCE
(
ctx
->
HasOutput
(
"Out"
),
...
@@ -67,7 +67,7 @@ class SequenceSoftmaxGradOp : public framework::OperatorWithKernel {
...
@@ -67,7 +67,7 @@ class SequenceSoftmaxGradOp : public framework::OperatorWithKernel {
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
protected:
void
InferShape
(
framework
::
InferShapeContext
Base
*
ctx
)
const
override
{
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Out"
),
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Out"
),
"Input(Out) of SequenceSoftmaxGradOp should not be null."
);
"Input(Out) of SequenceSoftmaxGradOp should not be null."
);
PADDLE_ENFORCE
(
PADDLE_ENFORCE
(
...
...
paddle/operators/sgd_op.cc
浏览文件 @
e12ec95a
...
@@ -22,7 +22,7 @@ class SGDOp : public framework::OperatorWithKernel {
...
@@ -22,7 +22,7 @@ class SGDOp : public framework::OperatorWithKernel {
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
protected:
void
InferShape
(
framework
::
InferShapeContext
Base
*
ctx
)
const
override
{
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Param"
),
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Param"
),
"Input(Param) of SGDOp should not be null."
);
"Input(Param) of SGDOp should not be null."
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Grad"
),
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Grad"
),
...
...
paddle/operators/sigmoid_cross_entropy_with_logits_op.cc
浏览文件 @
e12ec95a
...
@@ -24,7 +24,7 @@ class SigmoidCrossEntropyWithLogitsOp : public framework::OperatorWithKernel {
...
@@ -24,7 +24,7 @@ class SigmoidCrossEntropyWithLogitsOp : public framework::OperatorWithKernel {
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
protected:
void
InferShape
(
framework
::
InferShapeContext
Base
*
ctx
)
const
override
{
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"X"
),
"Input(X) should be not null."
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"X"
),
"Input(X) should be not null."
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Labels"
),
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Labels"
),
"Input(Labels) should be not null."
);
"Input(Labels) should be not null."
);
...
@@ -53,7 +53,7 @@ class SigmoidCrossEntropyWithLogitsGradOp
...
@@ -53,7 +53,7 @@ class SigmoidCrossEntropyWithLogitsGradOp
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
protected:
void
InferShape
(
framework
::
InferShapeContext
Base
*
ctx
)
const
override
{
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"X"
),
"Input(X) should be not null."
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"X"
),
"Input(X) should be not null."
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Labels"
),
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Labels"
),
"Input(Labels) should be not null."
);
"Input(Labels) should be not null."
);
...
...
paddle/operators/smooth_l1_loss_op.cc
浏览文件 @
e12ec95a
...
@@ -22,7 +22,7 @@ class SmoothL1LossOp : public framework::OperatorWithKernel {
...
@@ -22,7 +22,7 @@ class SmoothL1LossOp : public framework::OperatorWithKernel {
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
protected:
void
InferShape
(
framework
::
InferShapeContext
Base
*
ctx
)
const
override
{
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"X"
),
"X must be initialized."
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"X"
),
"X must be initialized."
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Y"
),
"Y must be initialized."
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Y"
),
"Y must be initialized."
);
...
@@ -94,7 +94,7 @@ class SmoothL1LossGradOp : public framework::OperatorWithKernel {
...
@@ -94,7 +94,7 @@ class SmoothL1LossGradOp : public framework::OperatorWithKernel {
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
protected:
void
InferShape
(
framework
::
InferShapeContext
Base
*
ctx
)
const
override
{
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
auto
in_dims
=
ctx
->
GetInputDim
(
"X"
);
auto
in_dims
=
ctx
->
GetInputDim
(
"X"
);
auto
out_dims
=
ctx
->
GetInputDim
(
framework
::
GradVarName
(
"Out"
));
auto
out_dims
=
ctx
->
GetInputDim
(
framework
::
GradVarName
(
"Out"
));
...
...
paddle/operators/softmax_op.cc
浏览文件 @
e12ec95a
...
@@ -22,7 +22,7 @@ class SoftmaxOp : public framework::OperatorWithKernel {
...
@@ -22,7 +22,7 @@ class SoftmaxOp : public framework::OperatorWithKernel {
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
protected:
void
InferShape
(
framework
::
InferShapeContext
Base
*
ctx
)
const
override
{
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"X"
),
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"X"
),
"Input(X) of SoftmaxOp should not be null."
);
"Input(X) of SoftmaxOp should not be null."
);
PADDLE_ENFORCE
(
ctx
->
HasOutput
(
"Y"
),
PADDLE_ENFORCE
(
ctx
->
HasOutput
(
"Y"
),
...
@@ -69,7 +69,7 @@ class SoftmaxOpGrad : public framework::OperatorWithKernel {
...
@@ -69,7 +69,7 @@ class SoftmaxOpGrad : public framework::OperatorWithKernel {
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
protected:
void
InferShape
(
framework
::
InferShapeContext
Base
*
ctx
)
const
override
{
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Y"
),
"Input(Y) should be not null."
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Y"
),
"Input(Y) should be not null."
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
framework
::
GradVarName
(
"Y"
)),
PADDLE_ENFORCE
(
ctx
->
HasInput
(
framework
::
GradVarName
(
"Y"
)),
"Input(Y@GRAD) should be not null."
);
"Input(Y@GRAD) should be not null."
);
...
...
paddle/operators/softmax_with_cross_entropy_op.cc
浏览文件 @
e12ec95a
...
@@ -83,7 +83,7 @@ class SoftmaxWithCrossEntropyOp : public framework::OperatorWithKernel {
...
@@ -83,7 +83,7 @@ class SoftmaxWithCrossEntropyOp : public framework::OperatorWithKernel {
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
protected:
void
InferShape
(
framework
::
InferShapeContext
Base
*
ctx
)
const
override
{
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Logits"
),
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Logits"
),
"Input(Logits) should be not null."
);
"Input(Logits) should be not null."
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Label"
),
"Input(Label) should be not null."
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Label"
),
"Input(Label) should be not null."
);
...
@@ -128,7 +128,7 @@ class SoftmaxWithCrossEntropyOpGrad : public framework::OperatorWithKernel {
...
@@ -128,7 +128,7 @@ class SoftmaxWithCrossEntropyOpGrad : public framework::OperatorWithKernel {
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
protected:
void
InferShape
(
framework
::
InferShapeContext
Base
*
ctx
)
const
override
{
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
framework
::
GradVarName
(
"Loss"
)),
PADDLE_ENFORCE
(
ctx
->
HasInput
(
framework
::
GradVarName
(
"Loss"
)),
"Input(Loss@Grad) should not be null."
);
"Input(Loss@Grad) should not be null."
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Softmax"
),
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Softmax"
),
...
...
paddle/operators/split_op.cc
浏览文件 @
e12ec95a
...
@@ -24,7 +24,7 @@ class SplitOp : public framework::OperatorWithKernel {
...
@@ -24,7 +24,7 @@ class SplitOp : public framework::OperatorWithKernel {
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
protected:
void
InferShape
(
framework
::
InferShapeContext
Base
*
ctx
)
const
override
{
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"X"
),
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"X"
),
"Input(X) of SplitOp should not be null."
);
"Input(X) of SplitOp should not be null."
);
PADDLE_ENFORCE_GE
(
ctx
->
Outputs
(
"Out"
).
size
(),
1UL
,
PADDLE_ENFORCE_GE
(
ctx
->
Outputs
(
"Out"
).
size
(),
1UL
,
...
...
paddle/operators/squared_l2_distance_op.cc
浏览文件 @
e12ec95a
...
@@ -22,7 +22,7 @@ class SquaredL2DistanceOp : public framework::OperatorWithKernel {
...
@@ -22,7 +22,7 @@ class SquaredL2DistanceOp : public framework::OperatorWithKernel {
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
protected:
void
InferShape
(
framework
::
InferShapeContext
Base
*
ctx
)
const
override
{
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"X"
),
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"X"
),
"Input(X) of SquaredL2DistanceOp should not be null."
);
"Input(X) of SquaredL2DistanceOp should not be null."
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Y"
),
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Y"
),
...
@@ -86,7 +86,7 @@ class SquaredL2DistanceGradOp : public framework::OperatorWithKernel {
...
@@ -86,7 +86,7 @@ class SquaredL2DistanceGradOp : public framework::OperatorWithKernel {
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
protected:
void
InferShape
(
framework
::
InferShapeContext
Base
*
ctx
)
const
override
{
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
framework
::
GradVarName
(
"Out"
)),
PADDLE_ENFORCE
(
ctx
->
HasInput
(
framework
::
GradVarName
(
"Out"
)),
"Gradient of Out should not be null"
);
"Gradient of Out should not be null"
);
auto
out_dims
=
ctx
->
GetInputDim
(
framework
::
GradVarName
(
"Out"
));
auto
out_dims
=
ctx
->
GetInputDim
(
framework
::
GradVarName
(
"Out"
));
...
...
paddle/operators/sum_op.cc
浏览文件 @
e12ec95a
...
@@ -22,7 +22,7 @@ class SumOp : public framework::OperatorWithKernel {
...
@@ -22,7 +22,7 @@ class SumOp : public framework::OperatorWithKernel {
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
protected:
void
InferShape
(
framework
::
InferShapeContext
Base
*
ctx
)
const
override
{
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInputs
(
"X"
),
"Inputs(X) should not be null"
);
PADDLE_ENFORCE
(
ctx
->
HasInputs
(
"X"
),
"Inputs(X) should not be null"
);
auto
x_dims
=
ctx
->
GetInputsDim
(
"X"
);
auto
x_dims
=
ctx
->
GetInputsDim
(
"X"
);
PADDLE_ENFORCE
(
ctx
->
HasOutput
(
"Out"
),
PADDLE_ENFORCE
(
ctx
->
HasOutput
(
"Out"
),
...
...
paddle/operators/top_k_op.cc
浏览文件 @
e12ec95a
...
@@ -22,7 +22,7 @@ class TopkOp : public framework::OperatorWithKernel {
...
@@ -22,7 +22,7 @@ class TopkOp : public framework::OperatorWithKernel {
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
protected:
void
InferShape
(
framework
::
InferShapeContext
Base
*
ctx
)
const
override
{
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"X"
),
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"X"
),
"Input(X) of TopkOp should not be null."
);
"Input(X) of TopkOp should not be null."
);
PADDLE_ENFORCE
(
ctx
->
HasOutput
(
"Out"
),
PADDLE_ENFORCE
(
ctx
->
HasOutput
(
"Out"
),
...
...
paddle/operators/transpose_op.cc
浏览文件 @
e12ec95a
...
@@ -24,7 +24,7 @@ class TransposeOp : public framework::OperatorWithKernel {
...
@@ -24,7 +24,7 @@ class TransposeOp : public framework::OperatorWithKernel {
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
protected:
void
InferShape
(
framework
::
InferShapeContext
Base
*
ctx
)
const
override
{
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"X"
),
"Input(X) should not be null"
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"X"
),
"Input(X) should not be null"
);
PADDLE_ENFORCE
(
ctx
->
HasOutput
(
"Out"
),
"Output(Out) should not be null"
);
PADDLE_ENFORCE
(
ctx
->
HasOutput
(
"Out"
),
"Output(Out) should not be null"
);
auto
x_dims
=
ctx
->
GetInputDim
(
"X"
);
auto
x_dims
=
ctx
->
GetInputDim
(
"X"
);
...
@@ -93,7 +93,7 @@ class TransposeOpGrad : public framework::OperatorWithKernel {
...
@@ -93,7 +93,7 @@ class TransposeOpGrad : public framework::OperatorWithKernel {
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
protected:
void
InferShape
(
framework
::
InferShapeContext
Base
*
ctx
)
const
override
{
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"X"
),
"Input(X) should not be null"
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"X"
),
"Input(X) should not be null"
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
framework
::
GradVarName
(
"Out"
)),
PADDLE_ENFORCE
(
ctx
->
HasInput
(
framework
::
GradVarName
(
"Out"
)),
"Input(Out@GRAD) should not be null"
);
"Input(Out@GRAD) should not be null"
);
...
...
paddle/operators/uniform_random_op.cc
浏览文件 @
e12ec95a
...
@@ -47,7 +47,7 @@ class UniformRandomOp : public framework::OperatorWithKernel {
...
@@ -47,7 +47,7 @@ class UniformRandomOp : public framework::OperatorWithKernel {
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
protected:
void
InferShape
(
framework
::
InferShapeContext
Base
*
ctx
)
const
override
{
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasOutput
(
"Out"
),
PADDLE_ENFORCE
(
ctx
->
HasOutput
(
"Out"
),
"Output(Out) of UniformRandomOp should not be null."
);
"Output(Out) of UniformRandomOp should not be null."
);
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录