Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
3a8de910
P
Paddle
项目概览
PaddlePaddle
/
Paddle
大约 2 年 前同步成功
通知
2325
Star
20933
Fork
5424
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
3a8de910
编写于
10月 09, 2017
作者:
F
fengjiayi
浏览文件
操作
浏览文件
下载
差异文件
Merge branch 'develop' of
https://github.com/PaddlePaddle/Paddle
into dev_add_attr_test
上级
5984cbca
383faaf7
变更
56
显示空白变更内容
内联
并排
Showing
56 changed file
with
245 addition
and
121 deletion
+245
-121
paddle/framework/backward.cc
paddle/framework/backward.cc
+1
-1
paddle/framework/op_desc.h
paddle/framework/op_desc.h
+5
-0
paddle/framework/op_registry.cc
paddle/framework/op_registry.cc
+8
-3
paddle/framework/op_registry.h
paddle/framework/op_registry.h
+1
-1
paddle/framework/operator.cc
paddle/framework/operator.cc
+4
-4
paddle/framework/operator.h
paddle/framework/operator.h
+23
-32
paddle/framework/operator_test.cc
paddle/framework/operator_test.cc
+1
-1
paddle/framework/shape_inference.h
paddle/framework/shape_inference.h
+3
-3
paddle/framework/tensor.h
paddle/framework/tensor.h
+13
-0
paddle/framework/tensor_impl.h
paddle/framework/tensor_impl.h
+23
-0
paddle/framework/tensor_test.cc
paddle/framework/tensor_test.cc
+87
-0
paddle/operators/accuracy_op.cc
paddle/operators/accuracy_op.cc
+1
-1
paddle/operators/activation_op.cc
paddle/operators/activation_op.cc
+2
-2
paddle/operators/adadelta_op.cc
paddle/operators/adadelta_op.cc
+1
-1
paddle/operators/adagrad_op.cc
paddle/operators/adagrad_op.cc
+1
-1
paddle/operators/adamax_op.cc
paddle/operators/adamax_op.cc
+1
-1
paddle/operators/clip_op.cc
paddle/operators/clip_op.cc
+2
-2
paddle/operators/concat_op.cc
paddle/operators/concat_op.cc
+2
-2
paddle/operators/conv2d_op.cc
paddle/operators/conv2d_op.cc
+2
-2
paddle/operators/cos_sim_op.cc
paddle/operators/cos_sim_op.cc
+2
-2
paddle/operators/crop_op.cc
paddle/operators/crop_op.cc
+2
-2
paddle/operators/cross_entropy_op.cc
paddle/operators/cross_entropy_op.cc
+2
-2
paddle/operators/dropout_op.cc
paddle/operators/dropout_op.cc
+2
-2
paddle/operators/elementwise_op.h
paddle/operators/elementwise_op.h
+2
-2
paddle/operators/fill_zeros_like_op.cc
paddle/operators/fill_zeros_like_op.cc
+1
-1
paddle/operators/gather_op.cc
paddle/operators/gather_op.cc
+2
-2
paddle/operators/gaussian_random_op.cc
paddle/operators/gaussian_random_op.cc
+1
-1
paddle/operators/lookup_table_op.cc
paddle/operators/lookup_table_op.cc
+2
-2
paddle/operators/lstm_unit_op.cc
paddle/operators/lstm_unit_op.cc
+2
-2
paddle/operators/mean_op.cc
paddle/operators/mean_op.cc
+2
-2
paddle/operators/minus_op.cc
paddle/operators/minus_op.cc
+1
-1
paddle/operators/modified_huber_loss_op.cc
paddle/operators/modified_huber_loss_op.cc
+2
-2
paddle/operators/mul_op.cc
paddle/operators/mul_op.cc
+2
-2
paddle/operators/multiplex_op.cc
paddle/operators/multiplex_op.cc
+2
-2
paddle/operators/pad_op.cc
paddle/operators/pad_op.cc
+2
-2
paddle/operators/pool_op.cc
paddle/operators/pool_op.cc
+2
-2
paddle/operators/prelu_op.cc
paddle/operators/prelu_op.cc
+2
-2
paddle/operators/rank_loss_op.cc
paddle/operators/rank_loss_op.cc
+2
-2
paddle/operators/reduce_op.cc
paddle/operators/reduce_op.cc
+2
-2
paddle/operators/reshape_op.cc
paddle/operators/reshape_op.cc
+2
-2
paddle/operators/rmsprop_op.cc
paddle/operators/rmsprop_op.cc
+1
-1
paddle/operators/scale_op.cc
paddle/operators/scale_op.cc
+1
-1
paddle/operators/scatter_op.cc
paddle/operators/scatter_op.cc
+2
-2
paddle/operators/sequence_pool_op.cc
paddle/operators/sequence_pool_op.cc
+2
-2
paddle/operators/sequence_softmax_op.cc
paddle/operators/sequence_softmax_op.cc
+2
-2
paddle/operators/sgd_op.cc
paddle/operators/sgd_op.cc
+1
-1
paddle/operators/sigmoid_cross_entropy_with_logits_op.cc
paddle/operators/sigmoid_cross_entropy_with_logits_op.cc
+2
-2
paddle/operators/smooth_l1_loss_op.cc
paddle/operators/smooth_l1_loss_op.cc
+2
-2
paddle/operators/softmax_op.cc
paddle/operators/softmax_op.cc
+2
-2
paddle/operators/softmax_with_cross_entropy_op.cc
paddle/operators/softmax_with_cross_entropy_op.cc
+2
-2
paddle/operators/split_op.cc
paddle/operators/split_op.cc
+1
-1
paddle/operators/squared_l2_distance_op.cc
paddle/operators/squared_l2_distance_op.cc
+2
-2
paddle/operators/sum_op.cc
paddle/operators/sum_op.cc
+1
-1
paddle/operators/top_k_op.cc
paddle/operators/top_k_op.cc
+1
-1
paddle/operators/transpose_op.cc
paddle/operators/transpose_op.cc
+2
-2
paddle/operators/uniform_random_op.cc
paddle/operators/uniform_random_op.cc
+1
-1
未找到文件。
paddle/framework/backward.cc
浏览文件 @
3a8de910
...
...
@@ -302,7 +302,7 @@ std::vector<std::unique_ptr<OpDescBind>> MakeOpGrad(
return
grad_op_descs
;
// empty vector
}
grad_op_descs
=
OpRegistry
::
CreateGradOpDescs
(
*
op_desc
);
grad_op_descs
=
OpRegistry
::
CreateGradOpDescs
(
op_desc
.
get
()
);
std
::
list
<
std
::
unique_ptr
<
OpDescBind
>>
pending_fill_zeros_ops
;
for
(
auto
&
desc
:
grad_op_descs
)
{
...
...
paddle/framework/op_desc.h
浏览文件 @
3a8de910
...
...
@@ -97,6 +97,11 @@ class OpDescBind {
const
VariableNameMap
&
Outputs
()
const
{
return
outputs_
;
}
AttributeMap
*
MutableAttrMap
()
{
this
->
need_update_
=
true
;
return
&
this
->
attrs_
;
}
private:
template
<
typename
MapType
>
static
std
::
vector
<
typename
MapType
::
key_type
>
MapKeys
(
const
MapType
&
map
)
{
...
...
paddle/framework/op_registry.cc
浏览文件 @
3a8de910
...
...
@@ -60,9 +60,14 @@ std::unique_ptr<OperatorBase> OpRegistry::CreateOp(const OpDescBind& op_desc) {
}
std
::
vector
<
std
::
unique_ptr
<
OpDescBind
>>
OpRegistry
::
CreateGradOpDescs
(
const
OpDescBind
&
op_desc
)
{
auto
&
info
=
OpInfoMap
::
Instance
().
Get
(
op_desc
.
Type
());
return
info
.
grad_op_maker_
(
op_desc
);
OpDescBind
*
op_desc
)
{
auto
&
info
=
OpInfoMap
::
Instance
().
Get
(
op_desc
->
Type
());
if
(
info
.
Checker
()
!=
nullptr
)
{
info
.
Checker
()
->
Check
(
*
op_desc
->
MutableAttrMap
());
}
return
info
.
grad_op_maker_
(
*
op_desc
);
}
}
// namespace framework
...
...
paddle/framework/op_registry.h
浏览文件 @
3a8de910
...
...
@@ -80,7 +80,7 @@ class OpRegistry {
static
std
::
unique_ptr
<
OperatorBase
>
CreateOp
(
const
OpDesc
&
op_desc
);
static
std
::
vector
<
std
::
unique_ptr
<
OpDescBind
>>
CreateGradOpDescs
(
const
OpDescBind
&
op_desc
);
OpDescBind
*
op_desc
);
static
std
::
unique_ptr
<
OperatorBase
>
CreateOp
(
const
OpDescBind
&
op_desc
);
};
...
...
paddle/framework/operator.cc
浏览文件 @
3a8de910
...
...
@@ -205,13 +205,13 @@ void OperatorBase::GenerateTemporaryNames() {
}
template
<
>
const
Tensor
*
InferShape
Context
::
Input
<
Tensor
>
(
const
std
::
string
&
name
)
const
{
const
Tensor
*
Execution
Context
::
Input
<
Tensor
>
(
const
std
::
string
&
name
)
const
{
auto
*
var
=
InputVar
(
name
);
return
var
==
nullptr
?
nullptr
:
GetTensorFromVar
(
var
);
}
template
<
>
const
std
::
vector
<
const
Tensor
*>
InferShape
Context
::
MultiInput
<
Tensor
>
(
const
std
::
vector
<
const
Tensor
*>
Execution
Context
::
MultiInput
<
Tensor
>
(
const
std
::
string
&
name
)
const
{
auto
names
=
op
().
Inputs
(
name
);
std
::
vector
<
const
Tensor
*>
res
;
...
...
@@ -225,13 +225,13 @@ const std::vector<const Tensor*> InferShapeContext::MultiInput<Tensor>(
}
template
<
>
Tensor
*
InferShape
Context
::
Output
<
Tensor
>
(
const
std
::
string
&
name
)
const
{
Tensor
*
Execution
Context
::
Output
<
Tensor
>
(
const
std
::
string
&
name
)
const
{
auto
var
=
OutputVar
(
name
);
return
var
==
nullptr
?
nullptr
:
var
->
GetMutable
<
LoDTensor
>
();
}
template
<
>
std
::
vector
<
Tensor
*>
InferShape
Context
::
MultiOutput
<
Tensor
>
(
std
::
vector
<
Tensor
*>
Execution
Context
::
MultiOutput
<
Tensor
>
(
const
std
::
string
&
name
)
const
{
auto
names
=
op
().
Outputs
(
name
);
std
::
vector
<
Tensor
*>
res
;
...
...
paddle/framework/operator.h
浏览文件 @
3a8de910
...
...
@@ -57,7 +57,6 @@ inline std::string GradVarName(const std::string& var_name) {
}
class
OperatorBase
;
class
InferShapeContext
;
class
ExecutionContext
;
extern
const
Tensor
*
GetTensorFromVar
(
const
Variable
*
var
);
...
...
@@ -169,10 +168,11 @@ class NOP : public OperatorBase {
}
};
class
InferShape
Context
{
class
Execution
Context
{
public:
InferShapeContext
(
const
OperatorBase
&
op
,
const
Scope
&
scope
)
:
op_
(
op
),
scope_
(
scope
)
{}
ExecutionContext
(
const
OperatorBase
&
op
,
const
Scope
&
scope
,
const
platform
::
DeviceContext
&
device_context
)
:
op_
(
op
),
scope_
(
scope
),
device_context_
(
device_context
)
{}
const
OperatorBase
&
op
()
const
{
return
op_
;
}
...
...
@@ -278,31 +278,6 @@ class InferShapeContext {
out_tensor
->
set_lod
(
in_tensor
.
lod
());
}
private:
const
OperatorBase
&
op_
;
const
Scope
&
scope_
;
};
template
<>
const
Tensor
*
InferShapeContext
::
Input
<
Tensor
>
(
const
std
::
string
&
name
)
const
;
template
<>
const
std
::
vector
<
const
Tensor
*>
InferShapeContext
::
MultiInput
<
Tensor
>
(
const
std
::
string
&
name
)
const
;
template
<>
Tensor
*
InferShapeContext
::
Output
<
Tensor
>
(
const
std
::
string
&
name
)
const
;
template
<>
std
::
vector
<
Tensor
*>
InferShapeContext
::
MultiOutput
<
Tensor
>
(
const
std
::
string
&
name
)
const
;
class
ExecutionContext
:
public
InferShapeContext
{
public:
ExecutionContext
(
const
OperatorBase
&
op
,
const
Scope
&
scope
,
const
platform
::
DeviceContext
&
device_context
)
:
InferShapeContext
(
op
,
scope
),
device_context_
(
device_context
)
{}
template
<
typename
PlaceType
,
typename
DeviceType
=
typename
platform
::
EigenDeviceConverter
<
PlaceType
>::
EigenDeviceType
>
...
...
@@ -315,10 +290,26 @@ class ExecutionContext : public InferShapeContext {
}
private:
const
OperatorBase
&
op_
;
const
Scope
&
scope_
;
const
platform
::
DeviceContext
&
device_context_
;
};
class
CompileTimeInferShapeContext
:
public
InferShapeContextBase
{
template
<>
const
Tensor
*
ExecutionContext
::
Input
<
Tensor
>
(
const
std
::
string
&
name
)
const
;
template
<>
const
std
::
vector
<
const
Tensor
*>
ExecutionContext
::
MultiInput
<
Tensor
>
(
const
std
::
string
&
name
)
const
;
template
<>
Tensor
*
ExecutionContext
::
Output
<
Tensor
>
(
const
std
::
string
&
name
)
const
;
template
<>
std
::
vector
<
Tensor
*>
ExecutionContext
::
MultiOutput
<
Tensor
>
(
const
std
::
string
&
name
)
const
;
class
CompileTimeInferShapeContext
:
public
InferShapeContext
{
public:
CompileTimeInferShapeContext
(
const
OpDescBind
&
op
,
const
BlockDescBind
&
block
)
:
op_
(
op
),
block_
(
block
)
{}
...
...
@@ -414,7 +405,7 @@ class CompileTimeInferShapeContext : public InferShapeContextBase {
const
BlockDescBind
&
block_
;
};
class
RuntimeInferShapeContext
:
public
InferShapeContext
Base
{
class
RuntimeInferShapeContext
:
public
InferShapeContext
{
public:
RuntimeInferShapeContext
(
const
OperatorBase
&
op
,
const
Scope
&
scope
)
:
op_
(
op
),
scope_
(
scope
)
{}
...
...
@@ -612,7 +603,7 @@ class OperatorWithKernel : public OperatorBase {
});
}
virtual
void
InferShape
(
InferShapeContext
Base
*
ctx
)
const
=
0
;
virtual
void
InferShape
(
InferShapeContext
*
ctx
)
const
=
0
;
protected:
// indicate kernel DataType by input data. Defaultly all input data must be
...
...
paddle/framework/operator_test.cc
浏览文件 @
3a8de910
...
...
@@ -113,7 +113,7 @@ class OpWithKernelTest : public OperatorWithKernel {
using
OperatorWithKernel
::
OperatorWithKernel
;
protected:
void
InferShape
(
framework
::
InferShapeContext
Base
*
ctx
)
const
override
{}
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{}
DataType
IndicateDataType
(
const
ExecutionContext
&
ctx
)
const
override
{
return
DataType
::
FP32
;
}
...
...
paddle/framework/shape_inference.h
浏览文件 @
3a8de910
...
...
@@ -20,11 +20,11 @@ namespace paddle {
namespace
framework
{
// TODO(longfei): Once after both CompileTimeInferShapeContext and
// RuntimeInferShapeContext get merged, we can rename InferShapeContext
Base
into
// RuntimeInferShapeContext get merged, we can rename InferShapeContext into
// InferShapeContext so to replace the current InferShapeContext.
class
InferShapeContext
Base
{
class
InferShapeContext
{
public:
virtual
~
InferShapeContext
Base
()
{}
virtual
~
InferShapeContext
()
{}
virtual
bool
HasInput
(
const
std
::
string
&
name
)
const
=
0
;
virtual
bool
HasOutput
(
const
std
::
string
&
name
)
const
=
0
;
...
...
paddle/framework/tensor.h
浏览文件 @
3a8de910
...
...
@@ -95,6 +95,19 @@ class Tensor {
template
<
typename
T
>
inline
void
CopyFrom
(
const
Tensor
&
src
,
const
platform
::
Place
&
dst_place
);
/**
* @brief Copy the content of an external vector to a tensor.
*
* @param[in] src The external vector.
* @param[in] ctx The device context contains place where to store.
*
* * @note CopyFromVector assumes that the tensor has been resized
* before invoking.
*/
template
<
typename
T
>
inline
void
CopyFromVector
(
const
std
::
vector
<
T
>&
src
,
const
platform
::
Place
&
dst_place
);
/**
* @brief Return the slice of the tensor.
*
...
...
paddle/framework/tensor_impl.h
浏览文件 @
3a8de910
...
...
@@ -123,6 +123,29 @@ inline void Tensor::CopyFrom(const Tensor& src,
#endif
}
template
<
typename
T
>
inline
void
Tensor
::
CopyFromVector
(
const
std
::
vector
<
T
>&
src
,
const
platform
::
Place
&
dst_place
)
{
auto
src_ptr
=
static_cast
<
const
void
*>
(
src
.
data
());
platform
::
CPUPlace
src_place
;
auto
dst_ptr
=
static_cast
<
void
*>
(
mutable_data
<
T
>
(
dst_place
));
auto
size
=
src
.
size
()
*
sizeof
(
T
);
if
(
platform
::
is_cpu_place
(
dst_place
))
{
memory
::
Copy
(
boost
::
get
<
platform
::
CPUPlace
>
(
dst_place
),
dst_ptr
,
src_place
,
src_ptr
,
size
);
}
#ifdef PADDLE_WITH_CUDA
else
if
(
platform
::
is_gpu_place
(
dst_place
))
{
memory
::
Copy
(
boost
::
get
<
platform
::
GPUPlace
>
(
dst_place
),
dst_ptr
,
src_place
,
src_ptr
,
size
,
0
);
}
PADDLE_ENFORCE
(
cudaStreamSynchronize
(
0
),
"cudaStreamSynchronize failed in Tensor CopyFromVector"
);
#endif
}
template
<
typename
T
>
inline
Tensor
Tensor
::
Slice
(
const
int
&
begin_idx
,
const
int
&
end_idx
)
const
{
check_memory_size
<
T
>
();
...
...
paddle/framework/tensor_test.cc
浏览文件 @
3a8de910
...
...
@@ -263,6 +263,93 @@ TEST(Tensor, CopyFrom) {
#endif
}
TEST
(
Tensor
,
CopyFromVector
)
{
using
namespace
paddle
::
framework
;
using
namespace
paddle
::
platform
;
{
std
::
vector
<
int
>
src_vec
=
{
1
,
2
,
3
,
4
,
5
,
6
,
7
,
8
,
9
};
Tensor
cpu_tensor
;
// Copy to CPU Tensor
cpu_tensor
.
Resize
(
make_ddim
({
3
,
3
}));
auto
cpu_place
=
new
paddle
::
platform
::
CPUPlace
();
cpu_tensor
.
CopyFromVector
<
int
>
(
src_vec
,
*
cpu_place
);
// Compare Tensors
const
int
*
cpu_ptr
=
cpu_tensor
.
data
<
int
>
();
const
int
*
src_ptr
=
src_vec
.
data
();
ASSERT_NE
(
src_ptr
,
cpu_ptr
);
for
(
size_t
i
=
0
;
i
<
9
;
++
i
)
{
EXPECT_EQ
(
src_ptr
[
i
],
cpu_ptr
[
i
]);
}
src_vec
.
erase
(
src_vec
.
begin
(),
src_vec
.
begin
()
+
5
);
cpu_tensor
.
Resize
(
make_ddim
({
2
,
2
}));
cpu_tensor
.
CopyFromVector
<
int
>
(
src_vec
,
*
cpu_place
);
cpu_ptr
=
cpu_tensor
.
data
<
int
>
();
src_ptr
=
src_vec
.
data
();
ASSERT_NE
(
src_ptr
,
cpu_ptr
);
for
(
size_t
i
=
0
;
i
<
5
;
++
i
)
{
EXPECT_EQ
(
src_ptr
[
i
],
cpu_ptr
[
i
]);
}
delete
cpu_place
;
}
#ifdef PADDLE_WITH_CUDA
{
std
::
vector
<
int
>
src_vec
=
{
1
,
2
,
3
,
4
,
5
,
6
,
7
,
8
,
9
};
Tensor
cpu_tensor
;
Tensor
gpu_tensor
;
Tensor
dst_tensor
;
// Copy to CPU Tensor
cpu_tensor
.
Resize
(
make_ddim
({
3
,
3
}));
auto
cpu_place
=
new
paddle
::
platform
::
CPUPlace
();
cpu_tensor
.
CopyFromVector
<
int
>
(
src_vec
,
*
cpu_place
);
// Copy to GPUTensor
gpu_tensor
.
Resize
(
make_ddim
({
3
,
3
}));
auto
gpu_place
=
new
paddle
::
platform
::
GPUPlace
();
gpu_tensor
.
CopyFromVector
<
int
>
(
src_vec
,
*
gpu_place
);
// Copy from GPU to CPU tensor for comparison
dst_tensor
.
CopyFrom
<
int
>
(
gpu_tensor
,
*
cpu_place
);
// Compare Tensors
const
int
*
src_ptr
=
src_vec
.
data
();
const
int
*
cpu_ptr
=
cpu_tensor
.
data
<
int
>
();
const
int
*
dst_ptr
=
dst_tensor
.
data
<
int
>
();
ASSERT_NE
(
src_ptr
,
cpu_ptr
);
ASSERT_NE
(
src_ptr
,
dst_ptr
);
for
(
size_t
i
=
0
;
i
<
9
;
++
i
)
{
EXPECT_EQ
(
src_ptr
[
i
],
cpu_ptr
[
i
]);
EXPECT_EQ
(
src_ptr
[
i
],
dst_ptr
[
i
]);
}
src_vec
.
erase
(
src_vec
.
begin
(),
src_vec
.
begin
()
+
5
);
cpu_tensor
.
Resize
(
make_ddim
({
2
,
2
}));
cpu_tensor
.
CopyFromVector
<
int
>
(
src_vec
,
*
cpu_place
);
gpu_tensor
.
Resize
(
make_ddim
({
2
,
2
}));
gpu_tensor
.
CopyFromVector
<
int
>
(
src_vec
,
*
gpu_place
);
dst_tensor
.
CopyFrom
<
int
>
(
gpu_tensor
,
*
cpu_place
);
src_ptr
=
src_vec
.
data
();
cpu_ptr
=
cpu_tensor
.
data
<
int
>
();
dst_ptr
=
dst_tensor
.
data
<
int
>
();
ASSERT_NE
(
src_ptr
,
cpu_ptr
);
ASSERT_NE
(
src_ptr
,
dst_ptr
);
for
(
size_t
i
=
0
;
i
<
5
;
++
i
)
{
EXPECT_EQ
(
src_ptr
[
i
],
cpu_ptr
[
i
]);
EXPECT_EQ
(
src_ptr
[
i
],
dst_ptr
[
i
]);
}
delete
cpu_place
;
delete
gpu_place
;
}
#endif
}
TEST
(
Tensor
,
ReshapeToMatrix
)
{
using
namespace
paddle
::
framework
;
using
namespace
paddle
::
platform
;
...
...
paddle/operators/accuracy_op.cc
浏览文件 @
3a8de910
...
...
@@ -22,7 +22,7 @@ class AccuracyOp : public framework::OperatorWithKernel {
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
void
InferShape
(
framework
::
InferShapeContext
Base
*
ctx
)
const
override
{
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Inference"
),
"Input(Inference) of AccuracyOp should not be null."
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Label"
),
...
...
paddle/operators/activation_op.cc
浏览文件 @
3a8de910
...
...
@@ -22,7 +22,7 @@ class ActivationOp : public framework::OperatorWithKernel {
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
void
InferShape
(
framework
::
InferShapeContext
Base
*
ctx
)
const
override
{
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
ctx
->
SetOutputDim
(
"Y"
,
ctx
->
GetInputDim
(
"X"
));
ctx
->
ShareLoD
(
"X"
,
/*->*/
"Y"
);
}
...
...
@@ -33,7 +33,7 @@ class ActivationOpGrad : public framework::OperatorWithKernel {
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
void
InferShape
(
framework
::
InferShapeContext
Base
*
ctx
)
const
override
{
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
ctx
->
SetOutputDim
(
framework
::
GradVarName
(
"X"
),
ctx
->
GetInputDim
(
"Y"
));
}
};
...
...
paddle/operators/adadelta_op.cc
浏览文件 @
3a8de910
...
...
@@ -22,7 +22,7 @@ class AdadeltaOp : public framework::OperatorWithKernel {
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
void
InferShape
(
framework
::
InferShapeContext
Base
*
ctx
)
const
override
{
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Param"
),
"Input(Param) of AdadeltaOp should not be null."
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Grad"
),
...
...
paddle/operators/adagrad_op.cc
浏览文件 @
3a8de910
...
...
@@ -22,7 +22,7 @@ class AdagradOp : public framework::OperatorWithKernel {
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
void
InferShape
(
framework
::
InferShapeContext
Base
*
ctx
)
const
override
{
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Param"
),
"Input(Param) of AdagradOp should not be null."
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Grad"
),
...
...
paddle/operators/adamax_op.cc
浏览文件 @
3a8de910
...
...
@@ -22,7 +22,7 @@ class AdamaxOp : public framework::OperatorWithKernel {
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
void
InferShape
(
framework
::
InferShapeContext
Base
*
ctx
)
const
override
{
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Param"
),
"Input(Param) of AdamaxOp should not be null."
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Grad"
),
...
...
paddle/operators/clip_op.cc
浏览文件 @
3a8de910
...
...
@@ -22,7 +22,7 @@ class ClipOp : public framework::OperatorWithKernel {
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
void
InferShape
(
framework
::
InferShapeContext
Base
*
ctx
)
const
override
{
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"X"
),
"Input(X) of ClipOp should not be null."
);
PADDLE_ENFORCE
(
ctx
->
HasOutput
(
"Out"
),
...
...
@@ -61,7 +61,7 @@ class ClipOpGrad : public framework::OperatorWithKernel {
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
void
InferShape
(
framework
::
InferShapeContext
Base
*
ctx
)
const
override
{
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"X"
),
"Input(X) should not be null"
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
framework
::
GradVarName
(
"Out"
)),
"Input(Out@GRAD) should not be null"
);
...
...
paddle/operators/concat_op.cc
浏览文件 @
3a8de910
...
...
@@ -24,7 +24,7 @@ class ConcatOp : public framework::OperatorWithKernel {
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
void
InferShape
(
framework
::
InferShapeContext
Base
*
ctx
)
const
override
{
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE_GE
(
ctx
->
Inputs
(
"X"
).
size
(),
1UL
,
"Inputs(X) of ConcatOp should be empty."
)
PADDLE_ENFORCE
(
ctx
->
HasOutput
(
"Out"
),
...
...
@@ -83,7 +83,7 @@ class ConcatOpGrad : public framework::OperatorWithKernel {
:
OperatorWithKernel
(
type
,
inputs
,
outputs
,
attrs
)
{}
protected:
void
InferShape
(
framework
::
InferShapeContext
Base
*
ctx
)
const
override
{
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
ctx
->
SetOutputsDim
(
framework
::
GradVarName
(
"X"
),
ctx
->
GetInputsDim
(
"X"
));
}
};
...
...
paddle/operators/conv2d_op.cc
浏览文件 @
3a8de910
...
...
@@ -27,7 +27,7 @@ class Conv2DOp : public framework::OperatorWithKernel {
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
void
InferShape
(
framework
::
InferShapeContext
Base
*
ctx
)
const
override
{
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Input"
),
"Input(Input) of Conv2DOp should not be null."
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Filter"
),
...
...
@@ -106,7 +106,7 @@ class Conv2DOpGrad : public framework::OperatorWithKernel {
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
void
InferShape
(
framework
::
InferShapeContext
Base
*
ctx
)
const
override
{
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
auto
in_dims
=
ctx
->
GetInputDim
(
"Input"
);
auto
filter_dims
=
ctx
->
GetInputDim
(
"Filter"
);
if
(
ctx
->
HasOutput
(
framework
::
GradVarName
(
"Input"
)))
{
...
...
paddle/operators/cos_sim_op.cc
浏览文件 @
3a8de910
...
...
@@ -24,7 +24,7 @@ class CosSimOp : public framework::OperatorWithKernel {
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
void
InferShape
(
framework
::
InferShapeContext
Base
*
ctx
)
const
override
{
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
// notnull check
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"X"
),
"Input(X) of CosSimOp should not be null."
);
...
...
@@ -98,7 +98,7 @@ class CosSimOpGrad : public framework::OperatorWithKernel {
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
void
InferShape
(
framework
::
InferShapeContext
Base
*
ctx
)
const
override
{
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
// notnull check
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"X"
),
"Input(X) must not be null."
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Y"
),
"Input(Y) must not be null."
);
...
...
paddle/operators/crop_op.cc
浏览文件 @
3a8de910
...
...
@@ -25,7 +25,7 @@ class CropOp : public framework::OperatorWithKernel {
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
void
InferShape
(
framework
::
InferShapeContext
Base
*
ctx
)
const
override
{
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"X"
),
"Input(X) of CropOp should not be null."
);
PADDLE_ENFORCE
(
ctx
->
HasOutput
(
"Out"
),
...
...
@@ -115,7 +115,7 @@ class CropOpGrad : public framework::OperatorWithKernel {
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
void
InferShape
(
framework
::
InferShapeContext
Base
*
ctx
)
const
override
{
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"X"
),
"Input(X) should not be null"
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
framework
::
GradVarName
(
"Out"
)),
"Input(Out@GRAD) should not be null"
);
...
...
paddle/operators/cross_entropy_op.cc
浏览文件 @
3a8de910
...
...
@@ -22,7 +22,7 @@ class CrossEntropyOp : public framework::OperatorWithKernel {
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
void
InferShape
(
framework
::
InferShapeContext
Base
*
ctx
)
const
override
{
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"X"
),
"Input(X) should be not null."
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Label"
),
"Input(Label) should be not null."
);
PADDLE_ENFORCE
(
ctx
->
HasOutput
(
"Y"
),
"Output(Y) should be not null."
);
...
...
@@ -60,7 +60,7 @@ class CrossEntropyGradientOp : public framework::OperatorWithKernel {
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
void
InferShape
(
framework
::
InferShapeContext
Base
*
ctx
)
const
override
{
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"X"
),
"Input(X) should be not null."
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Label"
),
"Input(Label) should be not null."
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
framework
::
GradVarName
(
"Y"
)),
...
...
paddle/operators/dropout_op.cc
浏览文件 @
3a8de910
...
...
@@ -24,7 +24,7 @@ class DropoutOp : public framework::OperatorWithKernel {
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
void
InferShape
(
framework
::
InferShapeContext
Base
*
ctx
)
const
override
{
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"X"
),
"Input(X) must not be null."
);
PADDLE_ENFORCE_GE
(
ctx
->
Attrs
().
Get
<
float
>
(
"dropout_prob"
),
0
);
PADDLE_ENFORCE_LE
(
ctx
->
Attrs
().
Get
<
float
>
(
"dropout_prob"
),
1
);
...
...
@@ -70,7 +70,7 @@ class DropoutOpGrad : public framework::OperatorWithKernel {
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
void
InferShape
(
framework
::
InferShapeContext
Base
*
ctx
)
const
override
{
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE_EQ
(
ctx
->
Attrs
().
Get
<
bool
>
(
"is_training"
),
1
,
"GradOp is only callable when is_training is true"
);
...
...
paddle/operators/elementwise_op.h
浏览文件 @
3a8de910
...
...
@@ -25,7 +25,7 @@ class ElementwiseOp : public framework::OperatorWithKernel {
protected:
using
Tensor
=
framework
::
Tensor
;
void
InferShape
(
framework
::
InferShapeContext
Base
*
ctx
)
const
override
{
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"X"
),
"Input(X) of elementwise op should not be null"
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Y"
),
...
...
@@ -106,7 +106,7 @@ class ElementwiseOpGrad : public framework::OperatorWithKernel {
using
Tensor
=
framework
::
Tensor
;
protected:
void
InferShape
(
framework
::
InferShapeContext
Base
*
ctx
)
const
override
{
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"X"
),
"Input(X) should not be null"
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Y"
),
"Input(Y) should not be null"
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
framework
::
GradVarName
(
"Out"
)),
...
...
paddle/operators/fill_zeros_like_op.cc
浏览文件 @
3a8de910
...
...
@@ -22,7 +22,7 @@ class FillZerosLikeOp : public framework::OperatorWithKernel {
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
void
InferShape
(
framework
::
InferShapeContext
Base
*
ctx
)
const
override
{
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"X"
),
"Input(X) of FillZerosLikeOp should not be null."
);
PADDLE_ENFORCE
(
ctx
->
HasOutput
(
"Y"
),
...
...
paddle/operators/gather_op.cc
浏览文件 @
3a8de910
...
...
@@ -23,7 +23,7 @@ class GatherOp : public framework::OperatorWithKernel {
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
void
InferShape
(
framework
::
InferShapeContext
Base
*
ctx
)
const
override
{
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"X"
),
"Input(X) of GatherOp should not be null."
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Index"
),
...
...
@@ -51,7 +51,7 @@ class GatherGradOp : public framework::OperatorWithKernel {
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
void
InferShape
(
framework
::
InferShapeContext
Base
*
ctx
)
const
override
{
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
ctx
->
SetOutputDim
(
framework
::
GradVarName
(
"X"
),
ctx
->
GetInputDim
(
"X"
));
}
...
...
paddle/operators/gaussian_random_op.cc
浏览文件 @
3a8de910
...
...
@@ -43,7 +43,7 @@ class GaussianRandomOp : public framework::OperatorWithKernel {
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
void
InferShape
(
framework
::
InferShapeContext
Base
*
ctx
)
const
override
{
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasOutput
(
"Out"
),
"Output(Out) of GaussianRandomOp should not be null."
);
auto
dims
=
ctx
->
Attrs
().
Get
<
std
::
vector
<
int
>>
(
"dims"
);
...
...
paddle/operators/lookup_table_op.cc
浏览文件 @
3a8de910
...
...
@@ -22,7 +22,7 @@ class LookupTableOp : public framework::OperatorWithKernel {
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
void
InferShape
(
framework
::
InferShapeContext
Base
*
ctx
)
const
override
{
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"W"
),
"Input(W) of LookupTableOp should not be null."
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Ids"
),
...
...
@@ -70,7 +70,7 @@ class LookupTableOpGrad : public framework::OperatorWithKernel {
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
void
InferShape
(
framework
::
InferShapeContext
Base
*
ctx
)
const
override
{
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
auto
table_dims
=
ctx
->
GetInputDim
(
"W"
);
ctx
->
SetOutputDim
(
framework
::
GradVarName
(
"W"
),
table_dims
);
}
...
...
paddle/operators/lstm_unit_op.cc
浏览文件 @
3a8de910
...
...
@@ -22,7 +22,7 @@ class LstmUnitOp : public framework::OperatorWithKernel {
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
void
InferShape
(
framework
::
InferShapeContext
Base
*
ctx
)
const
override
{
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"X"
),
"Input(X) of LSTM should not be null."
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"C_prev"
),
"Input(C_prev) of LSTM should not be null."
);
...
...
@@ -77,7 +77,7 @@ class LstmUnitGradOp : public framework::OperatorWithKernel {
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
void
InferShape
(
framework
::
InferShapeContext
Base
*
ctx
)
const
override
{
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
framework
::
GradVarName
(
"C"
)),
"Input(C@GRAD) should not be null"
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
framework
::
GradVarName
(
"H"
)),
...
...
paddle/operators/mean_op.cc
浏览文件 @
3a8de910
...
...
@@ -22,7 +22,7 @@ class MeanOp : public framework::OperatorWithKernel {
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
void
InferShape
(
framework
::
InferShapeContext
Base
*
ctx
)
const
override
{
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"X"
),
"Input(X) of MeanOp should not be null."
);
PADDLE_ENFORCE
(
ctx
->
HasOutput
(
"Out"
),
...
...
@@ -47,7 +47,7 @@ class MeanGradOp : public framework::OperatorWithKernel {
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
void
InferShape
(
framework
::
InferShapeContext
Base
*
ctx
)
const
override
{
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
ctx
->
SetOutputDim
(
framework
::
GradVarName
(
"X"
),
ctx
->
GetInputDim
(
"X"
));
}
};
...
...
paddle/operators/minus_op.cc
浏览文件 @
3a8de910
...
...
@@ -26,7 +26,7 @@ class MinusOp : public framework::OperatorWithKernel {
:
OperatorWithKernel
(
type
,
inputs
,
outputs
,
attrs
)
{}
protected:
void
InferShape
(
framework
::
InferShapeContext
Base
*
ctx
)
const
override
{
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"X"
),
"Input(X) of MinusOp should not be null."
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Y"
),
...
...
paddle/operators/modified_huber_loss_op.cc
浏览文件 @
3a8de910
...
...
@@ -22,7 +22,7 @@ class ModifiedHuberLossOp : public framework::OperatorWithKernel {
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
void
InferShape
(
framework
::
InferShapeContext
Base
*
ctx
)
const
override
{
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"X"
),
"X must be initialized."
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Y"
),
"Y must be initialized."
);
...
...
@@ -74,7 +74,7 @@ class ModifiedHuberLossGradOp : public framework::OperatorWithKernel {
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
void
InferShape
(
framework
::
InferShapeContext
Base
*
ctx
)
const
override
{
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"X"
),
"X must be initialized."
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Y"
),
"Y must be initialized."
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"IntermediateVal"
),
...
...
paddle/operators/mul_op.cc
浏览文件 @
3a8de910
...
...
@@ -24,7 +24,7 @@ class MulOp : public framework::OperatorWithKernel {
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
void
InferShape
(
framework
::
InferShapeContext
Base
*
ctx
)
const
override
{
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"X"
),
"Input(X) of MulOp should not be null."
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Y"
),
"Input(Y) of MulOp should not be null."
);
PADDLE_ENFORCE
(
ctx
->
HasOutput
(
"Out"
),
...
...
@@ -97,7 +97,7 @@ class MulOpGrad : public framework::OperatorWithKernel {
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
void
InferShape
(
framework
::
InferShapeContext
Base
*
ctx
)
const
override
{
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"X"
),
"Input(X) should not be null"
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Y"
),
"Input(Y) should not be null"
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
framework
::
GradVarName
(
"Out"
)),
...
...
paddle/operators/multiplex_op.cc
浏览文件 @
3a8de910
...
...
@@ -24,7 +24,7 @@ class MultiplexOp : public framework::OperatorWithKernel {
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
void
InferShape
(
framework
::
InferShapeContext
Base
*
ctx
)
const
override
{
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Ids"
),
"Input(Ids) shouldn't be null."
);
PADDLE_ENFORCE
(
!
ctx
->
Inputs
(
"X"
).
empty
(),
"MultiInput(X) shouldn't be empty."
);
...
...
@@ -90,7 +90,7 @@ class MultiplexGradOp : public framework::OperatorWithKernel {
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
void
InferShape
(
framework
::
InferShapeContext
Base
*
ctx
)
const
override
{
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
!
ctx
->
Inputs
(
"X"
).
empty
(),
"Input(X) should not be null."
);
PADDLE_ENFORCE
(
!
ctx
->
Outputs
(
framework
::
GradVarName
(
"X"
)).
empty
(),
"Output(X@Grad) should not be null."
);
...
...
paddle/operators/pad_op.cc
浏览文件 @
3a8de910
...
...
@@ -24,7 +24,7 @@ class PadOp : public framework::OperatorWithKernel {
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
void
InferShape
(
framework
::
InferShapeContext
Base
*
ctx
)
const
override
{
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"X"
),
"Input(X) of PadOp should not be null."
);
PADDLE_ENFORCE
(
ctx
->
HasOutput
(
"Out"
),
"Output(Out) of PadOp should not be null."
);
...
...
@@ -98,7 +98,7 @@ class PadOpGrad : public framework::OperatorWithKernel {
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
void
InferShape
(
framework
::
InferShapeContext
Base
*
ctx
)
const
override
{
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"X"
),
"Input(X) should not be null"
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
framework
::
GradVarName
(
"Out"
)),
"Input(Out@GRAD) should not be null"
);
...
...
paddle/operators/pool_op.cc
浏览文件 @
3a8de910
...
...
@@ -27,7 +27,7 @@ class PoolOp : public framework::OperatorWithKernel {
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
void
InferShape
(
framework
::
InferShapeContext
Base
*
ctx
)
const
override
{
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"X"
),
"X(Input) of Pooling should not be null."
);
PADDLE_ENFORCE
(
ctx
->
HasOutput
(
"Out"
),
...
...
@@ -74,7 +74,7 @@ class PoolOpGrad : public framework::OperatorWithKernel {
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
void
InferShape
(
framework
::
InferShapeContext
Base
*
ctx
)
const
override
{
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"X"
),
"X(Input) of Pooling should not be null."
);
PADDLE_ENFORCE
(
ctx
->
HasOutput
(
framework
::
GradVarName
(
"X"
)),
...
...
paddle/operators/prelu_op.cc
浏览文件 @
3a8de910
...
...
@@ -26,7 +26,7 @@ class PReluOp : public framework::OperatorWithKernel {
:
OperatorWithKernel
(
type
,
inputs
,
outputs
,
attrs
)
{}
protected:
void
InferShape
(
framework
::
InferShapeContext
Base
*
ctx
)
const
override
{
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"X"
),
"Input(X) should not be null"
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Alpha"
),
"Input(Alpha) should not be null"
);
PADDLE_ENFORCE
(
product
(
ctx
->
GetInputDim
(
"Alpha"
))
==
1
,
...
...
@@ -63,7 +63,7 @@ class PReluGradOp : public framework::OperatorWithKernel {
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
void
InferShape
(
framework
::
InferShapeContext
Base
*
ctx
)
const
override
{
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"X"
),
"Input(X) must not be null."
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
framework
::
GradVarName
(
"Out"
)),
"Input(Out@GRAD) should not be null"
);
...
...
paddle/operators/rank_loss_op.cc
浏览文件 @
3a8de910
...
...
@@ -25,7 +25,7 @@ class RankLossOp : public framework::OperatorWithKernel {
:
OperatorWithKernel
(
type
,
inputs
,
outputs
,
attrs
)
{}
protected:
void
InferShape
(
framework
::
InferShapeContext
Base
*
ctx
)
const
override
{
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
// input check
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Label"
),
"Input(Label) shouldn't be null"
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Left"
),
"Input(Left) shouldn't be null"
);
...
...
@@ -90,7 +90,7 @@ class RankLossGradOp : public framework::OperatorWithKernel {
:
OperatorWithKernel
(
type
,
inputs
,
outputs
,
attrs
)
{}
protected:
void
InferShape
(
framework
::
InferShapeContext
Base
*
ctx
)
const
override
{
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Label"
),
"Input(Label) shouldn't be null."
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Left"
),
"Input(Left) shouldn't be null."
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Right"
),
"Input(Right) shouldn't be null."
);
...
...
paddle/operators/reduce_op.cc
浏览文件 @
3a8de910
...
...
@@ -24,7 +24,7 @@ class ReduceOp : public framework::OperatorWithKernel {
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
void
InferShape
(
framework
::
InferShapeContext
Base
*
ctx
)
const
override
{
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"X"
),
"Input(X) of ReduceOp should not be null."
);
PADDLE_ENFORCE
(
ctx
->
HasOutput
(
"Out"
),
...
...
@@ -58,7 +58,7 @@ class ReduceGradOp : public framework::OperatorWithKernel {
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
void
InferShape
(
framework
::
InferShapeContext
Base
*
ctx
)
const
override
{
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"X"
),
"Input(X) should not be null."
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
framework
::
GradVarName
(
"Out"
)),
"Input(Out@GRAD) should not be null."
);
...
...
paddle/operators/reshape_op.cc
浏览文件 @
3a8de910
...
...
@@ -26,7 +26,7 @@ class ReshapeOp : public framework::OperatorWithKernel {
:
OperatorWithKernel
(
type
,
inputs
,
outputs
,
attrs
)
{}
protected:
void
InferShape
(
framework
::
InferShapeContext
Base
*
ctx
)
const
override
{
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
// input check
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"X"
),
"Input(X) of ReshapeOp should not be null."
);
...
...
@@ -94,7 +94,7 @@ class ReshapeGradOp : public framework::OperatorWithKernel {
:
OperatorWithKernel
(
type
,
inputs
,
outputs
,
attrs
)
{}
protected:
void
InferShape
(
framework
::
InferShapeContext
Base
*
ctx
)
const
override
{
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"X"
),
"Input(X) shouldn't be null."
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
framework
::
GradVarName
(
"Out"
)),
"Input(Out@GRAD) shouldn't be null."
);
...
...
paddle/operators/rmsprop_op.cc
浏览文件 @
3a8de910
...
...
@@ -22,7 +22,7 @@ class RmspropOp : public framework::OperatorWithKernel {
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
void
InferShape
(
framework
::
InferShapeContext
Base
*
ctx
)
const
override
{
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Param"
),
"Input(Param) of RmspropOp should not be null."
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"MeanSquare"
),
...
...
paddle/operators/scale_op.cc
浏览文件 @
3a8de910
...
...
@@ -26,7 +26,7 @@ class ScaleOp : public framework::OperatorWithKernel {
:
OperatorWithKernel
(
type
,
inputs
,
outputs
,
attrs
)
{}
protected:
void
InferShape
(
framework
::
InferShapeContext
Base
*
ctx
)
const
override
{
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"X"
),
"Input(X) of ScaleOp should not be null."
);
PADDLE_ENFORCE
(
ctx
->
HasOutput
(
"Out"
),
...
...
paddle/operators/scatter_op.cc
浏览文件 @
3a8de910
...
...
@@ -23,7 +23,7 @@ class ScatterOp : public framework::OperatorWithKernel {
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
void
InferShape
(
framework
::
InferShapeContext
Base
*
ctx
)
const
override
{
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Ref"
),
"Input(Ref) of ScatterOp should not be null."
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Index"
),
...
...
@@ -60,7 +60,7 @@ class ScatterGradOp : public framework::OperatorWithKernel {
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
void
InferShape
(
framework
::
InferShapeContext
Base
*
ctx
)
const
override
{
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
ctx
->
SetOutputDim
(
framework
::
GradVarName
(
"Updates"
),
ctx
->
GetInputDim
(
"Updates"
));
ctx
->
SetOutputDim
(
framework
::
GradVarName
(
"Ref"
),
ctx
->
GetInputDim
(
"Ref"
));
...
...
paddle/operators/sequence_pool_op.cc
浏览文件 @
3a8de910
...
...
@@ -22,7 +22,7 @@ class SequencePoolOp : public framework::OperatorWithKernel {
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
void
InferShape
(
framework
::
InferShapeContext
Base
*
ctx
)
const
override
{
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"X"
),
"Input(X) of SequencePoolOp should not be null."
);
PADDLE_ENFORCE
(
ctx
->
HasOutput
(
"Out"
),
...
...
@@ -74,7 +74,7 @@ class SequencePoolGradOp : public framework::OperatorWithKernel {
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
void
InferShape
(
framework
::
InferShapeContext
Base
*
ctx
)
const
override
{
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
framework
::
GradVarName
(
"Out"
)),
"Gradient of Out should not be null."
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"X"
),
"The input X should not be null."
);
...
...
paddle/operators/sequence_softmax_op.cc
浏览文件 @
3a8de910
...
...
@@ -22,7 +22,7 @@ class SequenceSoftmaxOp : public framework::OperatorWithKernel {
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
void
InferShape
(
framework
::
InferShapeContext
Base
*
ctx
)
const
override
{
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"X"
),
"Input(X) of SequenceSoftmaxOp should not be null."
);
PADDLE_ENFORCE
(
ctx
->
HasOutput
(
"Out"
),
...
...
@@ -67,7 +67,7 @@ class SequenceSoftmaxGradOp : public framework::OperatorWithKernel {
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
void
InferShape
(
framework
::
InferShapeContext
Base
*
ctx
)
const
override
{
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Out"
),
"Input(Out) of SequenceSoftmaxGradOp should not be null."
);
PADDLE_ENFORCE
(
...
...
paddle/operators/sgd_op.cc
浏览文件 @
3a8de910
...
...
@@ -22,7 +22,7 @@ class SGDOp : public framework::OperatorWithKernel {
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
void
InferShape
(
framework
::
InferShapeContext
Base
*
ctx
)
const
override
{
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Param"
),
"Input(Param) of SGDOp should not be null."
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Grad"
),
...
...
paddle/operators/sigmoid_cross_entropy_with_logits_op.cc
浏览文件 @
3a8de910
...
...
@@ -24,7 +24,7 @@ class SigmoidCrossEntropyWithLogitsOp : public framework::OperatorWithKernel {
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
void
InferShape
(
framework
::
InferShapeContext
Base
*
ctx
)
const
override
{
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"X"
),
"Input(X) should be not null."
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Labels"
),
"Input(Labels) should be not null."
);
...
...
@@ -53,7 +53,7 @@ class SigmoidCrossEntropyWithLogitsGradOp
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
void
InferShape
(
framework
::
InferShapeContext
Base
*
ctx
)
const
override
{
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"X"
),
"Input(X) should be not null."
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Labels"
),
"Input(Labels) should be not null."
);
...
...
paddle/operators/smooth_l1_loss_op.cc
浏览文件 @
3a8de910
...
...
@@ -22,7 +22,7 @@ class SmoothL1LossOp : public framework::OperatorWithKernel {
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
void
InferShape
(
framework
::
InferShapeContext
Base
*
ctx
)
const
override
{
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"X"
),
"X must be initialized."
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Y"
),
"Y must be initialized."
);
...
...
@@ -94,7 +94,7 @@ class SmoothL1LossGradOp : public framework::OperatorWithKernel {
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
void
InferShape
(
framework
::
InferShapeContext
Base
*
ctx
)
const
override
{
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
auto
in_dims
=
ctx
->
GetInputDim
(
"X"
);
auto
out_dims
=
ctx
->
GetInputDim
(
framework
::
GradVarName
(
"Out"
));
...
...
paddle/operators/softmax_op.cc
浏览文件 @
3a8de910
...
...
@@ -22,7 +22,7 @@ class SoftmaxOp : public framework::OperatorWithKernel {
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
void
InferShape
(
framework
::
InferShapeContext
Base
*
ctx
)
const
override
{
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"X"
),
"Input(X) of SoftmaxOp should not be null."
);
PADDLE_ENFORCE
(
ctx
->
HasOutput
(
"Y"
),
...
...
@@ -69,7 +69,7 @@ class SoftmaxOpGrad : public framework::OperatorWithKernel {
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
void
InferShape
(
framework
::
InferShapeContext
Base
*
ctx
)
const
override
{
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Y"
),
"Input(Y) should be not null."
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
framework
::
GradVarName
(
"Y"
)),
"Input(Y@GRAD) should be not null."
);
...
...
paddle/operators/softmax_with_cross_entropy_op.cc
浏览文件 @
3a8de910
...
...
@@ -83,7 +83,7 @@ class SoftmaxWithCrossEntropyOp : public framework::OperatorWithKernel {
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
void
InferShape
(
framework
::
InferShapeContext
Base
*
ctx
)
const
override
{
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Logits"
),
"Input(Logits) should be not null."
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Label"
),
"Input(Label) should be not null."
);
...
...
@@ -128,7 +128,7 @@ class SoftmaxWithCrossEntropyOpGrad : public framework::OperatorWithKernel {
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
void
InferShape
(
framework
::
InferShapeContext
Base
*
ctx
)
const
override
{
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
framework
::
GradVarName
(
"Loss"
)),
"Input(Loss@Grad) should not be null."
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Softmax"
),
...
...
paddle/operators/split_op.cc
浏览文件 @
3a8de910
...
...
@@ -24,7 +24,7 @@ class SplitOp : public framework::OperatorWithKernel {
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
void
InferShape
(
framework
::
InferShapeContext
Base
*
ctx
)
const
override
{
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"X"
),
"Input(X) of SplitOp should not be null."
);
PADDLE_ENFORCE_GE
(
ctx
->
Outputs
(
"Out"
).
size
(),
1UL
,
...
...
paddle/operators/squared_l2_distance_op.cc
浏览文件 @
3a8de910
...
...
@@ -22,7 +22,7 @@ class SquaredL2DistanceOp : public framework::OperatorWithKernel {
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
void
InferShape
(
framework
::
InferShapeContext
Base
*
ctx
)
const
override
{
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"X"
),
"Input(X) of SquaredL2DistanceOp should not be null."
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Y"
),
...
...
@@ -86,7 +86,7 @@ class SquaredL2DistanceGradOp : public framework::OperatorWithKernel {
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
void
InferShape
(
framework
::
InferShapeContext
Base
*
ctx
)
const
override
{
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
framework
::
GradVarName
(
"Out"
)),
"Gradient of Out should not be null"
);
auto
out_dims
=
ctx
->
GetInputDim
(
framework
::
GradVarName
(
"Out"
));
...
...
paddle/operators/sum_op.cc
浏览文件 @
3a8de910
...
...
@@ -22,7 +22,7 @@ class SumOp : public framework::OperatorWithKernel {
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
void
InferShape
(
framework
::
InferShapeContext
Base
*
ctx
)
const
override
{
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInputs
(
"X"
),
"Inputs(X) should not be null"
);
auto
x_dims
=
ctx
->
GetInputsDim
(
"X"
);
PADDLE_ENFORCE
(
ctx
->
HasOutput
(
"Out"
),
...
...
paddle/operators/top_k_op.cc
浏览文件 @
3a8de910
...
...
@@ -22,7 +22,7 @@ class TopkOp : public framework::OperatorWithKernel {
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
void
InferShape
(
framework
::
InferShapeContext
Base
*
ctx
)
const
override
{
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"X"
),
"Input(X) of TopkOp should not be null."
);
PADDLE_ENFORCE
(
ctx
->
HasOutput
(
"Out"
),
...
...
paddle/operators/transpose_op.cc
浏览文件 @
3a8de910
...
...
@@ -24,7 +24,7 @@ class TransposeOp : public framework::OperatorWithKernel {
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
void
InferShape
(
framework
::
InferShapeContext
Base
*
ctx
)
const
override
{
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"X"
),
"Input(X) should not be null"
);
PADDLE_ENFORCE
(
ctx
->
HasOutput
(
"Out"
),
"Output(Out) should not be null"
);
auto
x_dims
=
ctx
->
GetInputDim
(
"X"
);
...
...
@@ -93,7 +93,7 @@ class TransposeOpGrad : public framework::OperatorWithKernel {
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
void
InferShape
(
framework
::
InferShapeContext
Base
*
ctx
)
const
override
{
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"X"
),
"Input(X) should not be null"
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
framework
::
GradVarName
(
"Out"
)),
"Input(Out@GRAD) should not be null"
);
...
...
paddle/operators/uniform_random_op.cc
浏览文件 @
3a8de910
...
...
@@ -47,7 +47,7 @@ class UniformRandomOp : public framework::OperatorWithKernel {
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
void
InferShape
(
framework
::
InferShapeContext
Base
*
ctx
)
const
override
{
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasOutput
(
"Out"
),
"Output(Out) of UniformRandomOp should not be null."
);
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录