Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
机器未来
Paddle
提交
3a8de910
P
Paddle
项目概览
机器未来
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
体验新版 GitCode,发现更多精彩内容 >>
提交
3a8de910
编写于
10月 09, 2017
作者:
F
fengjiayi
浏览文件
操作
浏览文件
下载
差异文件
Merge branch 'develop' of
https://github.com/PaddlePaddle/Paddle
into dev_add_attr_test
上级
5984cbca
383faaf7
变更
56
隐藏空白更改
内联
并排
Showing
56 changed file
with
245 addition
and
121 deletion
+245
-121
paddle/framework/backward.cc
paddle/framework/backward.cc
+1
-1
paddle/framework/op_desc.h
paddle/framework/op_desc.h
+5
-0
paddle/framework/op_registry.cc
paddle/framework/op_registry.cc
+8
-3
paddle/framework/op_registry.h
paddle/framework/op_registry.h
+1
-1
paddle/framework/operator.cc
paddle/framework/operator.cc
+4
-4
paddle/framework/operator.h
paddle/framework/operator.h
+23
-32
paddle/framework/operator_test.cc
paddle/framework/operator_test.cc
+1
-1
paddle/framework/shape_inference.h
paddle/framework/shape_inference.h
+3
-3
paddle/framework/tensor.h
paddle/framework/tensor.h
+13
-0
paddle/framework/tensor_impl.h
paddle/framework/tensor_impl.h
+23
-0
paddle/framework/tensor_test.cc
paddle/framework/tensor_test.cc
+87
-0
paddle/operators/accuracy_op.cc
paddle/operators/accuracy_op.cc
+1
-1
paddle/operators/activation_op.cc
paddle/operators/activation_op.cc
+2
-2
paddle/operators/adadelta_op.cc
paddle/operators/adadelta_op.cc
+1
-1
paddle/operators/adagrad_op.cc
paddle/operators/adagrad_op.cc
+1
-1
paddle/operators/adamax_op.cc
paddle/operators/adamax_op.cc
+1
-1
paddle/operators/clip_op.cc
paddle/operators/clip_op.cc
+2
-2
paddle/operators/concat_op.cc
paddle/operators/concat_op.cc
+2
-2
paddle/operators/conv2d_op.cc
paddle/operators/conv2d_op.cc
+2
-2
paddle/operators/cos_sim_op.cc
paddle/operators/cos_sim_op.cc
+2
-2
paddle/operators/crop_op.cc
paddle/operators/crop_op.cc
+2
-2
paddle/operators/cross_entropy_op.cc
paddle/operators/cross_entropy_op.cc
+2
-2
paddle/operators/dropout_op.cc
paddle/operators/dropout_op.cc
+2
-2
paddle/operators/elementwise_op.h
paddle/operators/elementwise_op.h
+2
-2
paddle/operators/fill_zeros_like_op.cc
paddle/operators/fill_zeros_like_op.cc
+1
-1
paddle/operators/gather_op.cc
paddle/operators/gather_op.cc
+2
-2
paddle/operators/gaussian_random_op.cc
paddle/operators/gaussian_random_op.cc
+1
-1
paddle/operators/lookup_table_op.cc
paddle/operators/lookup_table_op.cc
+2
-2
paddle/operators/lstm_unit_op.cc
paddle/operators/lstm_unit_op.cc
+2
-2
paddle/operators/mean_op.cc
paddle/operators/mean_op.cc
+2
-2
paddle/operators/minus_op.cc
paddle/operators/minus_op.cc
+1
-1
paddle/operators/modified_huber_loss_op.cc
paddle/operators/modified_huber_loss_op.cc
+2
-2
paddle/operators/mul_op.cc
paddle/operators/mul_op.cc
+2
-2
paddle/operators/multiplex_op.cc
paddle/operators/multiplex_op.cc
+2
-2
paddle/operators/pad_op.cc
paddle/operators/pad_op.cc
+2
-2
paddle/operators/pool_op.cc
paddle/operators/pool_op.cc
+2
-2
paddle/operators/prelu_op.cc
paddle/operators/prelu_op.cc
+2
-2
paddle/operators/rank_loss_op.cc
paddle/operators/rank_loss_op.cc
+2
-2
paddle/operators/reduce_op.cc
paddle/operators/reduce_op.cc
+2
-2
paddle/operators/reshape_op.cc
paddle/operators/reshape_op.cc
+2
-2
paddle/operators/rmsprop_op.cc
paddle/operators/rmsprop_op.cc
+1
-1
paddle/operators/scale_op.cc
paddle/operators/scale_op.cc
+1
-1
paddle/operators/scatter_op.cc
paddle/operators/scatter_op.cc
+2
-2
paddle/operators/sequence_pool_op.cc
paddle/operators/sequence_pool_op.cc
+2
-2
paddle/operators/sequence_softmax_op.cc
paddle/operators/sequence_softmax_op.cc
+2
-2
paddle/operators/sgd_op.cc
paddle/operators/sgd_op.cc
+1
-1
paddle/operators/sigmoid_cross_entropy_with_logits_op.cc
paddle/operators/sigmoid_cross_entropy_with_logits_op.cc
+2
-2
paddle/operators/smooth_l1_loss_op.cc
paddle/operators/smooth_l1_loss_op.cc
+2
-2
paddle/operators/softmax_op.cc
paddle/operators/softmax_op.cc
+2
-2
paddle/operators/softmax_with_cross_entropy_op.cc
paddle/operators/softmax_with_cross_entropy_op.cc
+2
-2
paddle/operators/split_op.cc
paddle/operators/split_op.cc
+1
-1
paddle/operators/squared_l2_distance_op.cc
paddle/operators/squared_l2_distance_op.cc
+2
-2
paddle/operators/sum_op.cc
paddle/operators/sum_op.cc
+1
-1
paddle/operators/top_k_op.cc
paddle/operators/top_k_op.cc
+1
-1
paddle/operators/transpose_op.cc
paddle/operators/transpose_op.cc
+2
-2
paddle/operators/uniform_random_op.cc
paddle/operators/uniform_random_op.cc
+1
-1
未找到文件。
paddle/framework/backward.cc
浏览文件 @
3a8de910
...
...
@@ -302,7 +302,7 @@ std::vector<std::unique_ptr<OpDescBind>> MakeOpGrad(
return
grad_op_descs
;
// empty vector
}
grad_op_descs
=
OpRegistry
::
CreateGradOpDescs
(
*
op_desc
);
grad_op_descs
=
OpRegistry
::
CreateGradOpDescs
(
op_desc
.
get
()
);
std
::
list
<
std
::
unique_ptr
<
OpDescBind
>>
pending_fill_zeros_ops
;
for
(
auto
&
desc
:
grad_op_descs
)
{
...
...
paddle/framework/op_desc.h
浏览文件 @
3a8de910
...
...
@@ -97,6 +97,11 @@ class OpDescBind {
const
VariableNameMap
&
Outputs
()
const
{
return
outputs_
;
}
AttributeMap
*
MutableAttrMap
()
{
this
->
need_update_
=
true
;
return
&
this
->
attrs_
;
}
private:
template
<
typename
MapType
>
static
std
::
vector
<
typename
MapType
::
key_type
>
MapKeys
(
const
MapType
&
map
)
{
...
...
paddle/framework/op_registry.cc
浏览文件 @
3a8de910
...
...
@@ -60,9 +60,14 @@ std::unique_ptr<OperatorBase> OpRegistry::CreateOp(const OpDescBind& op_desc) {
}
std
::
vector
<
std
::
unique_ptr
<
OpDescBind
>>
OpRegistry
::
CreateGradOpDescs
(
const
OpDescBind
&
op_desc
)
{
auto
&
info
=
OpInfoMap
::
Instance
().
Get
(
op_desc
.
Type
());
return
info
.
grad_op_maker_
(
op_desc
);
OpDescBind
*
op_desc
)
{
auto
&
info
=
OpInfoMap
::
Instance
().
Get
(
op_desc
->
Type
());
if
(
info
.
Checker
()
!=
nullptr
)
{
info
.
Checker
()
->
Check
(
*
op_desc
->
MutableAttrMap
());
}
return
info
.
grad_op_maker_
(
*
op_desc
);
}
}
// namespace framework
...
...
paddle/framework/op_registry.h
浏览文件 @
3a8de910
...
...
@@ -80,7 +80,7 @@ class OpRegistry {
static
std
::
unique_ptr
<
OperatorBase
>
CreateOp
(
const
OpDesc
&
op_desc
);
static
std
::
vector
<
std
::
unique_ptr
<
OpDescBind
>>
CreateGradOpDescs
(
const
OpDescBind
&
op_desc
);
OpDescBind
*
op_desc
);
static
std
::
unique_ptr
<
OperatorBase
>
CreateOp
(
const
OpDescBind
&
op_desc
);
};
...
...
paddle/framework/operator.cc
浏览文件 @
3a8de910
...
...
@@ -205,13 +205,13 @@ void OperatorBase::GenerateTemporaryNames() {
}
template
<
>
const
Tensor
*
InferShape
Context
::
Input
<
Tensor
>
(
const
std
::
string
&
name
)
const
{
const
Tensor
*
Execution
Context
::
Input
<
Tensor
>
(
const
std
::
string
&
name
)
const
{
auto
*
var
=
InputVar
(
name
);
return
var
==
nullptr
?
nullptr
:
GetTensorFromVar
(
var
);
}
template
<
>
const
std
::
vector
<
const
Tensor
*>
InferShape
Context
::
MultiInput
<
Tensor
>
(
const
std
::
vector
<
const
Tensor
*>
Execution
Context
::
MultiInput
<
Tensor
>
(
const
std
::
string
&
name
)
const
{
auto
names
=
op
().
Inputs
(
name
);
std
::
vector
<
const
Tensor
*>
res
;
...
...
@@ -225,13 +225,13 @@ const std::vector<const Tensor*> InferShapeContext::MultiInput<Tensor>(
}
template
<
>
Tensor
*
InferShape
Context
::
Output
<
Tensor
>
(
const
std
::
string
&
name
)
const
{
Tensor
*
Execution
Context
::
Output
<
Tensor
>
(
const
std
::
string
&
name
)
const
{
auto
var
=
OutputVar
(
name
);
return
var
==
nullptr
?
nullptr
:
var
->
GetMutable
<
LoDTensor
>
();
}
template
<
>
std
::
vector
<
Tensor
*>
InferShape
Context
::
MultiOutput
<
Tensor
>
(
std
::
vector
<
Tensor
*>
Execution
Context
::
MultiOutput
<
Tensor
>
(
const
std
::
string
&
name
)
const
{
auto
names
=
op
().
Outputs
(
name
);
std
::
vector
<
Tensor
*>
res
;
...
...
paddle/framework/operator.h
浏览文件 @
3a8de910
...
...
@@ -57,7 +57,6 @@ inline std::string GradVarName(const std::string& var_name) {
}
class
OperatorBase
;
class
InferShapeContext
;
class
ExecutionContext
;
extern
const
Tensor
*
GetTensorFromVar
(
const
Variable
*
var
);
...
...
@@ -169,10 +168,11 @@ class NOP : public OperatorBase {
}
};
class
InferShape
Context
{
class
Execution
Context
{
public:
InferShapeContext
(
const
OperatorBase
&
op
,
const
Scope
&
scope
)
:
op_
(
op
),
scope_
(
scope
)
{}
ExecutionContext
(
const
OperatorBase
&
op
,
const
Scope
&
scope
,
const
platform
::
DeviceContext
&
device_context
)
:
op_
(
op
),
scope_
(
scope
),
device_context_
(
device_context
)
{}
const
OperatorBase
&
op
()
const
{
return
op_
;
}
...
...
@@ -278,31 +278,6 @@ class InferShapeContext {
out_tensor
->
set_lod
(
in_tensor
.
lod
());
}
private:
const
OperatorBase
&
op_
;
const
Scope
&
scope_
;
};
template
<>
const
Tensor
*
InferShapeContext
::
Input
<
Tensor
>
(
const
std
::
string
&
name
)
const
;
template
<>
const
std
::
vector
<
const
Tensor
*>
InferShapeContext
::
MultiInput
<
Tensor
>
(
const
std
::
string
&
name
)
const
;
template
<>
Tensor
*
InferShapeContext
::
Output
<
Tensor
>
(
const
std
::
string
&
name
)
const
;
template
<>
std
::
vector
<
Tensor
*>
InferShapeContext
::
MultiOutput
<
Tensor
>
(
const
std
::
string
&
name
)
const
;
class
ExecutionContext
:
public
InferShapeContext
{
public:
ExecutionContext
(
const
OperatorBase
&
op
,
const
Scope
&
scope
,
const
platform
::
DeviceContext
&
device_context
)
:
InferShapeContext
(
op
,
scope
),
device_context_
(
device_context
)
{}
template
<
typename
PlaceType
,
typename
DeviceType
=
typename
platform
::
EigenDeviceConverter
<
PlaceType
>::
EigenDeviceType
>
...
...
@@ -315,10 +290,26 @@ class ExecutionContext : public InferShapeContext {
}
private:
const
OperatorBase
&
op_
;
const
Scope
&
scope_
;
const
platform
::
DeviceContext
&
device_context_
;
};
class
CompileTimeInferShapeContext
:
public
InferShapeContextBase
{
template
<>
const
Tensor
*
ExecutionContext
::
Input
<
Tensor
>
(
const
std
::
string
&
name
)
const
;
template
<>
const
std
::
vector
<
const
Tensor
*>
ExecutionContext
::
MultiInput
<
Tensor
>
(
const
std
::
string
&
name
)
const
;
template
<>
Tensor
*
ExecutionContext
::
Output
<
Tensor
>
(
const
std
::
string
&
name
)
const
;
template
<>
std
::
vector
<
Tensor
*>
ExecutionContext
::
MultiOutput
<
Tensor
>
(
const
std
::
string
&
name
)
const
;
class
CompileTimeInferShapeContext
:
public
InferShapeContext
{
public:
CompileTimeInferShapeContext
(
const
OpDescBind
&
op
,
const
BlockDescBind
&
block
)
:
op_
(
op
),
block_
(
block
)
{}
...
...
@@ -414,7 +405,7 @@ class CompileTimeInferShapeContext : public InferShapeContextBase {
const
BlockDescBind
&
block_
;
};
class
RuntimeInferShapeContext
:
public
InferShapeContext
Base
{
class
RuntimeInferShapeContext
:
public
InferShapeContext
{
public:
RuntimeInferShapeContext
(
const
OperatorBase
&
op
,
const
Scope
&
scope
)
:
op_
(
op
),
scope_
(
scope
)
{}
...
...
@@ -612,7 +603,7 @@ class OperatorWithKernel : public OperatorBase {
});
}
virtual
void
InferShape
(
InferShapeContext
Base
*
ctx
)
const
=
0
;
virtual
void
InferShape
(
InferShapeContext
*
ctx
)
const
=
0
;
protected:
// indicate kernel DataType by input data. Defaultly all input data must be
...
...
paddle/framework/operator_test.cc
浏览文件 @
3a8de910
...
...
@@ -113,7 +113,7 @@ class OpWithKernelTest : public OperatorWithKernel {
using
OperatorWithKernel
::
OperatorWithKernel
;
protected:
void
InferShape
(
framework
::
InferShapeContext
Base
*
ctx
)
const
override
{}
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{}
DataType
IndicateDataType
(
const
ExecutionContext
&
ctx
)
const
override
{
return
DataType
::
FP32
;
}
...
...
paddle/framework/shape_inference.h
浏览文件 @
3a8de910
...
...
@@ -20,11 +20,11 @@ namespace paddle {
namespace
framework
{
// TODO(longfei): Once after both CompileTimeInferShapeContext and
// RuntimeInferShapeContext get merged, we can rename InferShapeContext
Base
into
// RuntimeInferShapeContext get merged, we can rename InferShapeContext into
// InferShapeContext so to replace the current InferShapeContext.
class
InferShapeContext
Base
{
class
InferShapeContext
{
public:
virtual
~
InferShapeContext
Base
()
{}
virtual
~
InferShapeContext
()
{}
virtual
bool
HasInput
(
const
std
::
string
&
name
)
const
=
0
;
virtual
bool
HasOutput
(
const
std
::
string
&
name
)
const
=
0
;
...
...
paddle/framework/tensor.h
浏览文件 @
3a8de910
...
...
@@ -95,6 +95,19 @@ class Tensor {
template
<
typename
T
>
inline
void
CopyFrom
(
const
Tensor
&
src
,
const
platform
::
Place
&
dst_place
);
/**
* @brief Copy the content of an external vector to a tensor.
*
* @param[in] src The external vector.
* @param[in] ctx The device context contains place where to store.
*
* * @note CopyFromVector assumes that the tensor has been resized
* before invoking.
*/
template
<
typename
T
>
inline
void
CopyFromVector
(
const
std
::
vector
<
T
>&
src
,
const
platform
::
Place
&
dst_place
);
/**
* @brief Return the slice of the tensor.
*
...
...
paddle/framework/tensor_impl.h
浏览文件 @
3a8de910
...
...
@@ -123,6 +123,29 @@ inline void Tensor::CopyFrom(const Tensor& src,
#endif
}
template
<
typename
T
>
inline
void
Tensor
::
CopyFromVector
(
const
std
::
vector
<
T
>&
src
,
const
platform
::
Place
&
dst_place
)
{
auto
src_ptr
=
static_cast
<
const
void
*>
(
src
.
data
());
platform
::
CPUPlace
src_place
;
auto
dst_ptr
=
static_cast
<
void
*>
(
mutable_data
<
T
>
(
dst_place
));
auto
size
=
src
.
size
()
*
sizeof
(
T
);
if
(
platform
::
is_cpu_place
(
dst_place
))
{
memory
::
Copy
(
boost
::
get
<
platform
::
CPUPlace
>
(
dst_place
),
dst_ptr
,
src_place
,
src_ptr
,
size
);
}
#ifdef PADDLE_WITH_CUDA
else
if
(
platform
::
is_gpu_place
(
dst_place
))
{
memory
::
Copy
(
boost
::
get
<
platform
::
GPUPlace
>
(
dst_place
),
dst_ptr
,
src_place
,
src_ptr
,
size
,
0
);
}
PADDLE_ENFORCE
(
cudaStreamSynchronize
(
0
),
"cudaStreamSynchronize failed in Tensor CopyFromVector"
);
#endif
}
template
<
typename
T
>
inline
Tensor
Tensor
::
Slice
(
const
int
&
begin_idx
,
const
int
&
end_idx
)
const
{
check_memory_size
<
T
>
();
...
...
paddle/framework/tensor_test.cc
浏览文件 @
3a8de910
...
...
@@ -263,6 +263,93 @@ TEST(Tensor, CopyFrom) {
#endif
}
TEST
(
Tensor
,
CopyFromVector
)
{
using
namespace
paddle
::
framework
;
using
namespace
paddle
::
platform
;
{
std
::
vector
<
int
>
src_vec
=
{
1
,
2
,
3
,
4
,
5
,
6
,
7
,
8
,
9
};
Tensor
cpu_tensor
;
// Copy to CPU Tensor
cpu_tensor
.
Resize
(
make_ddim
({
3
,
3
}));
auto
cpu_place
=
new
paddle
::
platform
::
CPUPlace
();
cpu_tensor
.
CopyFromVector
<
int
>
(
src_vec
,
*
cpu_place
);
// Compare Tensors
const
int
*
cpu_ptr
=
cpu_tensor
.
data
<
int
>
();
const
int
*
src_ptr
=
src_vec
.
data
();
ASSERT_NE
(
src_ptr
,
cpu_ptr
);
for
(
size_t
i
=
0
;
i
<
9
;
++
i
)
{
EXPECT_EQ
(
src_ptr
[
i
],
cpu_ptr
[
i
]);
}
src_vec
.
erase
(
src_vec
.
begin
(),
src_vec
.
begin
()
+
5
);
cpu_tensor
.
Resize
(
make_ddim
({
2
,
2
}));
cpu_tensor
.
CopyFromVector
<
int
>
(
src_vec
,
*
cpu_place
);
cpu_ptr
=
cpu_tensor
.
data
<
int
>
();
src_ptr
=
src_vec
.
data
();
ASSERT_NE
(
src_ptr
,
cpu_ptr
);
for
(
size_t
i
=
0
;
i
<
5
;
++
i
)
{
EXPECT_EQ
(
src_ptr
[
i
],
cpu_ptr
[
i
]);
}
delete
cpu_place
;
}
#ifdef PADDLE_WITH_CUDA
{
std
::
vector
<
int
>
src_vec
=
{
1
,
2
,
3
,
4
,
5
,
6
,
7
,
8
,
9
};
Tensor
cpu_tensor
;
Tensor
gpu_tensor
;
Tensor
dst_tensor
;
// Copy to CPU Tensor
cpu_tensor
.
Resize
(
make_ddim
({
3
,
3
}));
auto
cpu_place
=
new
paddle
::
platform
::
CPUPlace
();
cpu_tensor
.
CopyFromVector
<
int
>
(
src_vec
,
*
cpu_place
);
// Copy to GPUTensor
gpu_tensor
.
Resize
(
make_ddim
({
3
,
3
}));
auto
gpu_place
=
new
paddle
::
platform
::
GPUPlace
();
gpu_tensor
.
CopyFromVector
<
int
>
(
src_vec
,
*
gpu_place
);
// Copy from GPU to CPU tensor for comparison
dst_tensor
.
CopyFrom
<
int
>
(
gpu_tensor
,
*
cpu_place
);
// Compare Tensors
const
int
*
src_ptr
=
src_vec
.
data
();
const
int
*
cpu_ptr
=
cpu_tensor
.
data
<
int
>
();
const
int
*
dst_ptr
=
dst_tensor
.
data
<
int
>
();
ASSERT_NE
(
src_ptr
,
cpu_ptr
);
ASSERT_NE
(
src_ptr
,
dst_ptr
);
for
(
size_t
i
=
0
;
i
<
9
;
++
i
)
{
EXPECT_EQ
(
src_ptr
[
i
],
cpu_ptr
[
i
]);
EXPECT_EQ
(
src_ptr
[
i
],
dst_ptr
[
i
]);
}
src_vec
.
erase
(
src_vec
.
begin
(),
src_vec
.
begin
()
+
5
);
cpu_tensor
.
Resize
(
make_ddim
({
2
,
2
}));
cpu_tensor
.
CopyFromVector
<
int
>
(
src_vec
,
*
cpu_place
);
gpu_tensor
.
Resize
(
make_ddim
({
2
,
2
}));
gpu_tensor
.
CopyFromVector
<
int
>
(
src_vec
,
*
gpu_place
);
dst_tensor
.
CopyFrom
<
int
>
(
gpu_tensor
,
*
cpu_place
);
src_ptr
=
src_vec
.
data
();
cpu_ptr
=
cpu_tensor
.
data
<
int
>
();
dst_ptr
=
dst_tensor
.
data
<
int
>
();
ASSERT_NE
(
src_ptr
,
cpu_ptr
);
ASSERT_NE
(
src_ptr
,
dst_ptr
);
for
(
size_t
i
=
0
;
i
<
5
;
++
i
)
{
EXPECT_EQ
(
src_ptr
[
i
],
cpu_ptr
[
i
]);
EXPECT_EQ
(
src_ptr
[
i
],
dst_ptr
[
i
]);
}
delete
cpu_place
;
delete
gpu_place
;
}
#endif
}
TEST
(
Tensor
,
ReshapeToMatrix
)
{
using
namespace
paddle
::
framework
;
using
namespace
paddle
::
platform
;
...
...
paddle/operators/accuracy_op.cc
浏览文件 @
3a8de910
...
...
@@ -22,7 +22,7 @@ class AccuracyOp : public framework::OperatorWithKernel {
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
void
InferShape
(
framework
::
InferShapeContext
Base
*
ctx
)
const
override
{
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Inference"
),
"Input(Inference) of AccuracyOp should not be null."
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Label"
),
...
...
paddle/operators/activation_op.cc
浏览文件 @
3a8de910
...
...
@@ -22,7 +22,7 @@ class ActivationOp : public framework::OperatorWithKernel {
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
void
InferShape
(
framework
::
InferShapeContext
Base
*
ctx
)
const
override
{
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
ctx
->
SetOutputDim
(
"Y"
,
ctx
->
GetInputDim
(
"X"
));
ctx
->
ShareLoD
(
"X"
,
/*->*/
"Y"
);
}
...
...
@@ -33,7 +33,7 @@ class ActivationOpGrad : public framework::OperatorWithKernel {
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
void
InferShape
(
framework
::
InferShapeContext
Base
*
ctx
)
const
override
{
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
ctx
->
SetOutputDim
(
framework
::
GradVarName
(
"X"
),
ctx
->
GetInputDim
(
"Y"
));
}
};
...
...
paddle/operators/adadelta_op.cc
浏览文件 @
3a8de910
...
...
@@ -22,7 +22,7 @@ class AdadeltaOp : public framework::OperatorWithKernel {
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
void
InferShape
(
framework
::
InferShapeContext
Base
*
ctx
)
const
override
{
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Param"
),
"Input(Param) of AdadeltaOp should not be null."
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Grad"
),
...
...
paddle/operators/adagrad_op.cc
浏览文件 @
3a8de910
...
...
@@ -22,7 +22,7 @@ class AdagradOp : public framework::OperatorWithKernel {
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
void
InferShape
(
framework
::
InferShapeContext
Base
*
ctx
)
const
override
{
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Param"
),
"Input(Param) of AdagradOp should not be null."
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Grad"
),
...
...
paddle/operators/adamax_op.cc
浏览文件 @
3a8de910
...
...
@@ -22,7 +22,7 @@ class AdamaxOp : public framework::OperatorWithKernel {
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
void
InferShape
(
framework
::
InferShapeContext
Base
*
ctx
)
const
override
{
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Param"
),
"Input(Param) of AdamaxOp should not be null."
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Grad"
),
...
...
paddle/operators/clip_op.cc
浏览文件 @
3a8de910
...
...
@@ -22,7 +22,7 @@ class ClipOp : public framework::OperatorWithKernel {
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
void
InferShape
(
framework
::
InferShapeContext
Base
*
ctx
)
const
override
{
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"X"
),
"Input(X) of ClipOp should not be null."
);
PADDLE_ENFORCE
(
ctx
->
HasOutput
(
"Out"
),
...
...
@@ -61,7 +61,7 @@ class ClipOpGrad : public framework::OperatorWithKernel {
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
void
InferShape
(
framework
::
InferShapeContext
Base
*
ctx
)
const
override
{
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"X"
),
"Input(X) should not be null"
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
framework
::
GradVarName
(
"Out"
)),
"Input(Out@GRAD) should not be null"
);
...
...
paddle/operators/concat_op.cc
浏览文件 @
3a8de910
...
...
@@ -24,7 +24,7 @@ class ConcatOp : public framework::OperatorWithKernel {
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
void
InferShape
(
framework
::
InferShapeContext
Base
*
ctx
)
const
override
{
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE_GE
(
ctx
->
Inputs
(
"X"
).
size
(),
1UL
,
"Inputs(X) of ConcatOp should be empty."
)
PADDLE_ENFORCE
(
ctx
->
HasOutput
(
"Out"
),
...
...
@@ -83,7 +83,7 @@ class ConcatOpGrad : public framework::OperatorWithKernel {
:
OperatorWithKernel
(
type
,
inputs
,
outputs
,
attrs
)
{}
protected:
void
InferShape
(
framework
::
InferShapeContext
Base
*
ctx
)
const
override
{
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
ctx
->
SetOutputsDim
(
framework
::
GradVarName
(
"X"
),
ctx
->
GetInputsDim
(
"X"
));
}
};
...
...
paddle/operators/conv2d_op.cc
浏览文件 @
3a8de910
...
...
@@ -27,7 +27,7 @@ class Conv2DOp : public framework::OperatorWithKernel {
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
void
InferShape
(
framework
::
InferShapeContext
Base
*
ctx
)
const
override
{
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Input"
),
"Input(Input) of Conv2DOp should not be null."
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Filter"
),
...
...
@@ -106,7 +106,7 @@ class Conv2DOpGrad : public framework::OperatorWithKernel {
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
void
InferShape
(
framework
::
InferShapeContext
Base
*
ctx
)
const
override
{
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
auto
in_dims
=
ctx
->
GetInputDim
(
"Input"
);
auto
filter_dims
=
ctx
->
GetInputDim
(
"Filter"
);
if
(
ctx
->
HasOutput
(
framework
::
GradVarName
(
"Input"
)))
{
...
...
paddle/operators/cos_sim_op.cc
浏览文件 @
3a8de910
...
...
@@ -24,7 +24,7 @@ class CosSimOp : public framework::OperatorWithKernel {
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
void
InferShape
(
framework
::
InferShapeContext
Base
*
ctx
)
const
override
{
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
// notnull check
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"X"
),
"Input(X) of CosSimOp should not be null."
);
...
...
@@ -98,7 +98,7 @@ class CosSimOpGrad : public framework::OperatorWithKernel {
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
void
InferShape
(
framework
::
InferShapeContext
Base
*
ctx
)
const
override
{
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
// notnull check
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"X"
),
"Input(X) must not be null."
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Y"
),
"Input(Y) must not be null."
);
...
...
paddle/operators/crop_op.cc
浏览文件 @
3a8de910
...
...
@@ -25,7 +25,7 @@ class CropOp : public framework::OperatorWithKernel {
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
void
InferShape
(
framework
::
InferShapeContext
Base
*
ctx
)
const
override
{
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"X"
),
"Input(X) of CropOp should not be null."
);
PADDLE_ENFORCE
(
ctx
->
HasOutput
(
"Out"
),
...
...
@@ -115,7 +115,7 @@ class CropOpGrad : public framework::OperatorWithKernel {
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
void
InferShape
(
framework
::
InferShapeContext
Base
*
ctx
)
const
override
{
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"X"
),
"Input(X) should not be null"
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
framework
::
GradVarName
(
"Out"
)),
"Input(Out@GRAD) should not be null"
);
...
...
paddle/operators/cross_entropy_op.cc
浏览文件 @
3a8de910
...
...
@@ -22,7 +22,7 @@ class CrossEntropyOp : public framework::OperatorWithKernel {
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
void
InferShape
(
framework
::
InferShapeContext
Base
*
ctx
)
const
override
{
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"X"
),
"Input(X) should be not null."
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Label"
),
"Input(Label) should be not null."
);
PADDLE_ENFORCE
(
ctx
->
HasOutput
(
"Y"
),
"Output(Y) should be not null."
);
...
...
@@ -60,7 +60,7 @@ class CrossEntropyGradientOp : public framework::OperatorWithKernel {
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
void
InferShape
(
framework
::
InferShapeContext
Base
*
ctx
)
const
override
{
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"X"
),
"Input(X) should be not null."
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Label"
),
"Input(Label) should be not null."
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
framework
::
GradVarName
(
"Y"
)),
...
...
paddle/operators/dropout_op.cc
浏览文件 @
3a8de910
...
...
@@ -24,7 +24,7 @@ class DropoutOp : public framework::OperatorWithKernel {
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
void
InferShape
(
framework
::
InferShapeContext
Base
*
ctx
)
const
override
{
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"X"
),
"Input(X) must not be null."
);
PADDLE_ENFORCE_GE
(
ctx
->
Attrs
().
Get
<
float
>
(
"dropout_prob"
),
0
);
PADDLE_ENFORCE_LE
(
ctx
->
Attrs
().
Get
<
float
>
(
"dropout_prob"
),
1
);
...
...
@@ -70,7 +70,7 @@ class DropoutOpGrad : public framework::OperatorWithKernel {
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
void
InferShape
(
framework
::
InferShapeContext
Base
*
ctx
)
const
override
{
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE_EQ
(
ctx
->
Attrs
().
Get
<
bool
>
(
"is_training"
),
1
,
"GradOp is only callable when is_training is true"
);
...
...
paddle/operators/elementwise_op.h
浏览文件 @
3a8de910
...
...
@@ -25,7 +25,7 @@ class ElementwiseOp : public framework::OperatorWithKernel {
protected:
using
Tensor
=
framework
::
Tensor
;
void
InferShape
(
framework
::
InferShapeContext
Base
*
ctx
)
const
override
{
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"X"
),
"Input(X) of elementwise op should not be null"
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Y"
),
...
...
@@ -106,7 +106,7 @@ class ElementwiseOpGrad : public framework::OperatorWithKernel {
using
Tensor
=
framework
::
Tensor
;
protected:
void
InferShape
(
framework
::
InferShapeContext
Base
*
ctx
)
const
override
{
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"X"
),
"Input(X) should not be null"
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Y"
),
"Input(Y) should not be null"
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
framework
::
GradVarName
(
"Out"
)),
...
...
paddle/operators/fill_zeros_like_op.cc
浏览文件 @
3a8de910
...
...
@@ -22,7 +22,7 @@ class FillZerosLikeOp : public framework::OperatorWithKernel {
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
void
InferShape
(
framework
::
InferShapeContext
Base
*
ctx
)
const
override
{
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"X"
),
"Input(X) of FillZerosLikeOp should not be null."
);
PADDLE_ENFORCE
(
ctx
->
HasOutput
(
"Y"
),
...
...
paddle/operators/gather_op.cc
浏览文件 @
3a8de910
...
...
@@ -23,7 +23,7 @@ class GatherOp : public framework::OperatorWithKernel {
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
void
InferShape
(
framework
::
InferShapeContext
Base
*
ctx
)
const
override
{
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"X"
),
"Input(X) of GatherOp should not be null."
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Index"
),
...
...
@@ -51,7 +51,7 @@ class GatherGradOp : public framework::OperatorWithKernel {
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
void
InferShape
(
framework
::
InferShapeContext
Base
*
ctx
)
const
override
{
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
ctx
->
SetOutputDim
(
framework
::
GradVarName
(
"X"
),
ctx
->
GetInputDim
(
"X"
));
}
...
...
paddle/operators/gaussian_random_op.cc
浏览文件 @
3a8de910
...
...
@@ -43,7 +43,7 @@ class GaussianRandomOp : public framework::OperatorWithKernel {
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
void
InferShape
(
framework
::
InferShapeContext
Base
*
ctx
)
const
override
{
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasOutput
(
"Out"
),
"Output(Out) of GaussianRandomOp should not be null."
);
auto
dims
=
ctx
->
Attrs
().
Get
<
std
::
vector
<
int
>>
(
"dims"
);
...
...
paddle/operators/lookup_table_op.cc
浏览文件 @
3a8de910
...
...
@@ -22,7 +22,7 @@ class LookupTableOp : public framework::OperatorWithKernel {
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
void
InferShape
(
framework
::
InferShapeContext
Base
*
ctx
)
const
override
{
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"W"
),
"Input(W) of LookupTableOp should not be null."
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Ids"
),
...
...
@@ -70,7 +70,7 @@ class LookupTableOpGrad : public framework::OperatorWithKernel {
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
void
InferShape
(
framework
::
InferShapeContext
Base
*
ctx
)
const
override
{
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
auto
table_dims
=
ctx
->
GetInputDim
(
"W"
);
ctx
->
SetOutputDim
(
framework
::
GradVarName
(
"W"
),
table_dims
);
}
...
...
paddle/operators/lstm_unit_op.cc
浏览文件 @
3a8de910
...
...
@@ -22,7 +22,7 @@ class LstmUnitOp : public framework::OperatorWithKernel {
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
void
InferShape
(
framework
::
InferShapeContext
Base
*
ctx
)
const
override
{
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"X"
),
"Input(X) of LSTM should not be null."
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"C_prev"
),
"Input(C_prev) of LSTM should not be null."
);
...
...
@@ -77,7 +77,7 @@ class LstmUnitGradOp : public framework::OperatorWithKernel {
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
void
InferShape
(
framework
::
InferShapeContext
Base
*
ctx
)
const
override
{
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
framework
::
GradVarName
(
"C"
)),
"Input(C@GRAD) should not be null"
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
framework
::
GradVarName
(
"H"
)),
...
...
paddle/operators/mean_op.cc
浏览文件 @
3a8de910
...
...
@@ -22,7 +22,7 @@ class MeanOp : public framework::OperatorWithKernel {
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
void
InferShape
(
framework
::
InferShapeContext
Base
*
ctx
)
const
override
{
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"X"
),
"Input(X) of MeanOp should not be null."
);
PADDLE_ENFORCE
(
ctx
->
HasOutput
(
"Out"
),
...
...
@@ -47,7 +47,7 @@ class MeanGradOp : public framework::OperatorWithKernel {
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
void
InferShape
(
framework
::
InferShapeContext
Base
*
ctx
)
const
override
{
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
ctx
->
SetOutputDim
(
framework
::
GradVarName
(
"X"
),
ctx
->
GetInputDim
(
"X"
));
}
};
...
...
paddle/operators/minus_op.cc
浏览文件 @
3a8de910
...
...
@@ -26,7 +26,7 @@ class MinusOp : public framework::OperatorWithKernel {
:
OperatorWithKernel
(
type
,
inputs
,
outputs
,
attrs
)
{}
protected:
void
InferShape
(
framework
::
InferShapeContext
Base
*
ctx
)
const
override
{
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"X"
),
"Input(X) of MinusOp should not be null."
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Y"
),
...
...
paddle/operators/modified_huber_loss_op.cc
浏览文件 @
3a8de910
...
...
@@ -22,7 +22,7 @@ class ModifiedHuberLossOp : public framework::OperatorWithKernel {
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
void
InferShape
(
framework
::
InferShapeContext
Base
*
ctx
)
const
override
{
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"X"
),
"X must be initialized."
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Y"
),
"Y must be initialized."
);
...
...
@@ -74,7 +74,7 @@ class ModifiedHuberLossGradOp : public framework::OperatorWithKernel {
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
void
InferShape
(
framework
::
InferShapeContext
Base
*
ctx
)
const
override
{
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"X"
),
"X must be initialized."
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Y"
),
"Y must be initialized."
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"IntermediateVal"
),
...
...
paddle/operators/mul_op.cc
浏览文件 @
3a8de910
...
...
@@ -24,7 +24,7 @@ class MulOp : public framework::OperatorWithKernel {
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
void
InferShape
(
framework
::
InferShapeContext
Base
*
ctx
)
const
override
{
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"X"
),
"Input(X) of MulOp should not be null."
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Y"
),
"Input(Y) of MulOp should not be null."
);
PADDLE_ENFORCE
(
ctx
->
HasOutput
(
"Out"
),
...
...
@@ -97,7 +97,7 @@ class MulOpGrad : public framework::OperatorWithKernel {
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
void
InferShape
(
framework
::
InferShapeContext
Base
*
ctx
)
const
override
{
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"X"
),
"Input(X) should not be null"
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Y"
),
"Input(Y) should not be null"
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
framework
::
GradVarName
(
"Out"
)),
...
...
paddle/operators/multiplex_op.cc
浏览文件 @
3a8de910
...
...
@@ -24,7 +24,7 @@ class MultiplexOp : public framework::OperatorWithKernel {
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
void
InferShape
(
framework
::
InferShapeContext
Base
*
ctx
)
const
override
{
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Ids"
),
"Input(Ids) shouldn't be null."
);
PADDLE_ENFORCE
(
!
ctx
->
Inputs
(
"X"
).
empty
(),
"MultiInput(X) shouldn't be empty."
);
...
...
@@ -90,7 +90,7 @@ class MultiplexGradOp : public framework::OperatorWithKernel {
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
void
InferShape
(
framework
::
InferShapeContext
Base
*
ctx
)
const
override
{
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
!
ctx
->
Inputs
(
"X"
).
empty
(),
"Input(X) should not be null."
);
PADDLE_ENFORCE
(
!
ctx
->
Outputs
(
framework
::
GradVarName
(
"X"
)).
empty
(),
"Output(X@Grad) should not be null."
);
...
...
paddle/operators/pad_op.cc
浏览文件 @
3a8de910
...
...
@@ -24,7 +24,7 @@ class PadOp : public framework::OperatorWithKernel {
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
void
InferShape
(
framework
::
InferShapeContext
Base
*
ctx
)
const
override
{
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"X"
),
"Input(X) of PadOp should not be null."
);
PADDLE_ENFORCE
(
ctx
->
HasOutput
(
"Out"
),
"Output(Out) of PadOp should not be null."
);
...
...
@@ -98,7 +98,7 @@ class PadOpGrad : public framework::OperatorWithKernel {
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
void
InferShape
(
framework
::
InferShapeContext
Base
*
ctx
)
const
override
{
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"X"
),
"Input(X) should not be null"
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
framework
::
GradVarName
(
"Out"
)),
"Input(Out@GRAD) should not be null"
);
...
...
paddle/operators/pool_op.cc
浏览文件 @
3a8de910
...
...
@@ -27,7 +27,7 @@ class PoolOp : public framework::OperatorWithKernel {
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
void
InferShape
(
framework
::
InferShapeContext
Base
*
ctx
)
const
override
{
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"X"
),
"X(Input) of Pooling should not be null."
);
PADDLE_ENFORCE
(
ctx
->
HasOutput
(
"Out"
),
...
...
@@ -74,7 +74,7 @@ class PoolOpGrad : public framework::OperatorWithKernel {
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
void
InferShape
(
framework
::
InferShapeContext
Base
*
ctx
)
const
override
{
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"X"
),
"X(Input) of Pooling should not be null."
);
PADDLE_ENFORCE
(
ctx
->
HasOutput
(
framework
::
GradVarName
(
"X"
)),
...
...
paddle/operators/prelu_op.cc
浏览文件 @
3a8de910
...
...
@@ -26,7 +26,7 @@ class PReluOp : public framework::OperatorWithKernel {
:
OperatorWithKernel
(
type
,
inputs
,
outputs
,
attrs
)
{}
protected:
void
InferShape
(
framework
::
InferShapeContext
Base
*
ctx
)
const
override
{
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"X"
),
"Input(X) should not be null"
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Alpha"
),
"Input(Alpha) should not be null"
);
PADDLE_ENFORCE
(
product
(
ctx
->
GetInputDim
(
"Alpha"
))
==
1
,
...
...
@@ -63,7 +63,7 @@ class PReluGradOp : public framework::OperatorWithKernel {
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
void
InferShape
(
framework
::
InferShapeContext
Base
*
ctx
)
const
override
{
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"X"
),
"Input(X) must not be null."
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
framework
::
GradVarName
(
"Out"
)),
"Input(Out@GRAD) should not be null"
);
...
...
paddle/operators/rank_loss_op.cc
浏览文件 @
3a8de910
...
...
@@ -25,7 +25,7 @@ class RankLossOp : public framework::OperatorWithKernel {
:
OperatorWithKernel
(
type
,
inputs
,
outputs
,
attrs
)
{}
protected:
void
InferShape
(
framework
::
InferShapeContext
Base
*
ctx
)
const
override
{
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
// input check
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Label"
),
"Input(Label) shouldn't be null"
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Left"
),
"Input(Left) shouldn't be null"
);
...
...
@@ -90,7 +90,7 @@ class RankLossGradOp : public framework::OperatorWithKernel {
:
OperatorWithKernel
(
type
,
inputs
,
outputs
,
attrs
)
{}
protected:
void
InferShape
(
framework
::
InferShapeContext
Base
*
ctx
)
const
override
{
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Label"
),
"Input(Label) shouldn't be null."
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Left"
),
"Input(Left) shouldn't be null."
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Right"
),
"Input(Right) shouldn't be null."
);
...
...
paddle/operators/reduce_op.cc
浏览文件 @
3a8de910
...
...
@@ -24,7 +24,7 @@ class ReduceOp : public framework::OperatorWithKernel {
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
void
InferShape
(
framework
::
InferShapeContext
Base
*
ctx
)
const
override
{
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"X"
),
"Input(X) of ReduceOp should not be null."
);
PADDLE_ENFORCE
(
ctx
->
HasOutput
(
"Out"
),
...
...
@@ -58,7 +58,7 @@ class ReduceGradOp : public framework::OperatorWithKernel {
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
void
InferShape
(
framework
::
InferShapeContext
Base
*
ctx
)
const
override
{
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"X"
),
"Input(X) should not be null."
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
framework
::
GradVarName
(
"Out"
)),
"Input(Out@GRAD) should not be null."
);
...
...
paddle/operators/reshape_op.cc
浏览文件 @
3a8de910
...
...
@@ -26,7 +26,7 @@ class ReshapeOp : public framework::OperatorWithKernel {
:
OperatorWithKernel
(
type
,
inputs
,
outputs
,
attrs
)
{}
protected:
void
InferShape
(
framework
::
InferShapeContext
Base
*
ctx
)
const
override
{
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
// input check
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"X"
),
"Input(X) of ReshapeOp should not be null."
);
...
...
@@ -94,7 +94,7 @@ class ReshapeGradOp : public framework::OperatorWithKernel {
:
OperatorWithKernel
(
type
,
inputs
,
outputs
,
attrs
)
{}
protected:
void
InferShape
(
framework
::
InferShapeContext
Base
*
ctx
)
const
override
{
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"X"
),
"Input(X) shouldn't be null."
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
framework
::
GradVarName
(
"Out"
)),
"Input(Out@GRAD) shouldn't be null."
);
...
...
paddle/operators/rmsprop_op.cc
浏览文件 @
3a8de910
...
...
@@ -22,7 +22,7 @@ class RmspropOp : public framework::OperatorWithKernel {
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
void
InferShape
(
framework
::
InferShapeContext
Base
*
ctx
)
const
override
{
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Param"
),
"Input(Param) of RmspropOp should not be null."
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"MeanSquare"
),
...
...
paddle/operators/scale_op.cc
浏览文件 @
3a8de910
...
...
@@ -26,7 +26,7 @@ class ScaleOp : public framework::OperatorWithKernel {
:
OperatorWithKernel
(
type
,
inputs
,
outputs
,
attrs
)
{}
protected:
void
InferShape
(
framework
::
InferShapeContext
Base
*
ctx
)
const
override
{
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"X"
),
"Input(X) of ScaleOp should not be null."
);
PADDLE_ENFORCE
(
ctx
->
HasOutput
(
"Out"
),
...
...
paddle/operators/scatter_op.cc
浏览文件 @
3a8de910
...
...
@@ -23,7 +23,7 @@ class ScatterOp : public framework::OperatorWithKernel {
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
void
InferShape
(
framework
::
InferShapeContext
Base
*
ctx
)
const
override
{
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Ref"
),
"Input(Ref) of ScatterOp should not be null."
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Index"
),
...
...
@@ -60,7 +60,7 @@ class ScatterGradOp : public framework::OperatorWithKernel {
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
void
InferShape
(
framework
::
InferShapeContext
Base
*
ctx
)
const
override
{
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
ctx
->
SetOutputDim
(
framework
::
GradVarName
(
"Updates"
),
ctx
->
GetInputDim
(
"Updates"
));
ctx
->
SetOutputDim
(
framework
::
GradVarName
(
"Ref"
),
ctx
->
GetInputDim
(
"Ref"
));
...
...
paddle/operators/sequence_pool_op.cc
浏览文件 @
3a8de910
...
...
@@ -22,7 +22,7 @@ class SequencePoolOp : public framework::OperatorWithKernel {
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
void
InferShape
(
framework
::
InferShapeContext
Base
*
ctx
)
const
override
{
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"X"
),
"Input(X) of SequencePoolOp should not be null."
);
PADDLE_ENFORCE
(
ctx
->
HasOutput
(
"Out"
),
...
...
@@ -74,7 +74,7 @@ class SequencePoolGradOp : public framework::OperatorWithKernel {
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
void
InferShape
(
framework
::
InferShapeContext
Base
*
ctx
)
const
override
{
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
framework
::
GradVarName
(
"Out"
)),
"Gradient of Out should not be null."
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"X"
),
"The input X should not be null."
);
...
...
paddle/operators/sequence_softmax_op.cc
浏览文件 @
3a8de910
...
...
@@ -22,7 +22,7 @@ class SequenceSoftmaxOp : public framework::OperatorWithKernel {
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
void
InferShape
(
framework
::
InferShapeContext
Base
*
ctx
)
const
override
{
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"X"
),
"Input(X) of SequenceSoftmaxOp should not be null."
);
PADDLE_ENFORCE
(
ctx
->
HasOutput
(
"Out"
),
...
...
@@ -67,7 +67,7 @@ class SequenceSoftmaxGradOp : public framework::OperatorWithKernel {
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
void
InferShape
(
framework
::
InferShapeContext
Base
*
ctx
)
const
override
{
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Out"
),
"Input(Out) of SequenceSoftmaxGradOp should not be null."
);
PADDLE_ENFORCE
(
...
...
paddle/operators/sgd_op.cc
浏览文件 @
3a8de910
...
...
@@ -22,7 +22,7 @@ class SGDOp : public framework::OperatorWithKernel {
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
void
InferShape
(
framework
::
InferShapeContext
Base
*
ctx
)
const
override
{
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Param"
),
"Input(Param) of SGDOp should not be null."
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Grad"
),
...
...
paddle/operators/sigmoid_cross_entropy_with_logits_op.cc
浏览文件 @
3a8de910
...
...
@@ -24,7 +24,7 @@ class SigmoidCrossEntropyWithLogitsOp : public framework::OperatorWithKernel {
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
void
InferShape
(
framework
::
InferShapeContext
Base
*
ctx
)
const
override
{
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"X"
),
"Input(X) should be not null."
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Labels"
),
"Input(Labels) should be not null."
);
...
...
@@ -53,7 +53,7 @@ class SigmoidCrossEntropyWithLogitsGradOp
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
void
InferShape
(
framework
::
InferShapeContext
Base
*
ctx
)
const
override
{
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"X"
),
"Input(X) should be not null."
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Labels"
),
"Input(Labels) should be not null."
);
...
...
paddle/operators/smooth_l1_loss_op.cc
浏览文件 @
3a8de910
...
...
@@ -22,7 +22,7 @@ class SmoothL1LossOp : public framework::OperatorWithKernel {
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
void
InferShape
(
framework
::
InferShapeContext
Base
*
ctx
)
const
override
{
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"X"
),
"X must be initialized."
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Y"
),
"Y must be initialized."
);
...
...
@@ -94,7 +94,7 @@ class SmoothL1LossGradOp : public framework::OperatorWithKernel {
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
void
InferShape
(
framework
::
InferShapeContext
Base
*
ctx
)
const
override
{
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
auto
in_dims
=
ctx
->
GetInputDim
(
"X"
);
auto
out_dims
=
ctx
->
GetInputDim
(
framework
::
GradVarName
(
"Out"
));
...
...
paddle/operators/softmax_op.cc
浏览文件 @
3a8de910
...
...
@@ -22,7 +22,7 @@ class SoftmaxOp : public framework::OperatorWithKernel {
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
void
InferShape
(
framework
::
InferShapeContext
Base
*
ctx
)
const
override
{
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"X"
),
"Input(X) of SoftmaxOp should not be null."
);
PADDLE_ENFORCE
(
ctx
->
HasOutput
(
"Y"
),
...
...
@@ -69,7 +69,7 @@ class SoftmaxOpGrad : public framework::OperatorWithKernel {
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
void
InferShape
(
framework
::
InferShapeContext
Base
*
ctx
)
const
override
{
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Y"
),
"Input(Y) should be not null."
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
framework
::
GradVarName
(
"Y"
)),
"Input(Y@GRAD) should be not null."
);
...
...
paddle/operators/softmax_with_cross_entropy_op.cc
浏览文件 @
3a8de910
...
...
@@ -83,7 +83,7 @@ class SoftmaxWithCrossEntropyOp : public framework::OperatorWithKernel {
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
void
InferShape
(
framework
::
InferShapeContext
Base
*
ctx
)
const
override
{
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Logits"
),
"Input(Logits) should be not null."
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Label"
),
"Input(Label) should be not null."
);
...
...
@@ -128,7 +128,7 @@ class SoftmaxWithCrossEntropyOpGrad : public framework::OperatorWithKernel {
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
void
InferShape
(
framework
::
InferShapeContext
Base
*
ctx
)
const
override
{
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
framework
::
GradVarName
(
"Loss"
)),
"Input(Loss@Grad) should not be null."
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Softmax"
),
...
...
paddle/operators/split_op.cc
浏览文件 @
3a8de910
...
...
@@ -24,7 +24,7 @@ class SplitOp : public framework::OperatorWithKernel {
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
void
InferShape
(
framework
::
InferShapeContext
Base
*
ctx
)
const
override
{
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"X"
),
"Input(X) of SplitOp should not be null."
);
PADDLE_ENFORCE_GE
(
ctx
->
Outputs
(
"Out"
).
size
(),
1UL
,
...
...
paddle/operators/squared_l2_distance_op.cc
浏览文件 @
3a8de910
...
...
@@ -22,7 +22,7 @@ class SquaredL2DistanceOp : public framework::OperatorWithKernel {
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
void
InferShape
(
framework
::
InferShapeContext
Base
*
ctx
)
const
override
{
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"X"
),
"Input(X) of SquaredL2DistanceOp should not be null."
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Y"
),
...
...
@@ -86,7 +86,7 @@ class SquaredL2DistanceGradOp : public framework::OperatorWithKernel {
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
void
InferShape
(
framework
::
InferShapeContext
Base
*
ctx
)
const
override
{
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
framework
::
GradVarName
(
"Out"
)),
"Gradient of Out should not be null"
);
auto
out_dims
=
ctx
->
GetInputDim
(
framework
::
GradVarName
(
"Out"
));
...
...
paddle/operators/sum_op.cc
浏览文件 @
3a8de910
...
...
@@ -22,7 +22,7 @@ class SumOp : public framework::OperatorWithKernel {
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
void
InferShape
(
framework
::
InferShapeContext
Base
*
ctx
)
const
override
{
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInputs
(
"X"
),
"Inputs(X) should not be null"
);
auto
x_dims
=
ctx
->
GetInputsDim
(
"X"
);
PADDLE_ENFORCE
(
ctx
->
HasOutput
(
"Out"
),
...
...
paddle/operators/top_k_op.cc
浏览文件 @
3a8de910
...
...
@@ -22,7 +22,7 @@ class TopkOp : public framework::OperatorWithKernel {
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
void
InferShape
(
framework
::
InferShapeContext
Base
*
ctx
)
const
override
{
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"X"
),
"Input(X) of TopkOp should not be null."
);
PADDLE_ENFORCE
(
ctx
->
HasOutput
(
"Out"
),
...
...
paddle/operators/transpose_op.cc
浏览文件 @
3a8de910
...
...
@@ -24,7 +24,7 @@ class TransposeOp : public framework::OperatorWithKernel {
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
void
InferShape
(
framework
::
InferShapeContext
Base
*
ctx
)
const
override
{
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"X"
),
"Input(X) should not be null"
);
PADDLE_ENFORCE
(
ctx
->
HasOutput
(
"Out"
),
"Output(Out) should not be null"
);
auto
x_dims
=
ctx
->
GetInputDim
(
"X"
);
...
...
@@ -93,7 +93,7 @@ class TransposeOpGrad : public framework::OperatorWithKernel {
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
void
InferShape
(
framework
::
InferShapeContext
Base
*
ctx
)
const
override
{
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"X"
),
"Input(X) should not be null"
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
framework
::
GradVarName
(
"Out"
)),
"Input(Out@GRAD) should not be null"
);
...
...
paddle/operators/uniform_random_op.cc
浏览文件 @
3a8de910
...
...
@@ -47,7 +47,7 @@ class UniformRandomOp : public framework::OperatorWithKernel {
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
void
InferShape
(
framework
::
InferShapeContext
Base
*
ctx
)
const
override
{
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasOutput
(
"Out"
),
"Output(Out) of UniformRandomOp should not be null."
);
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录