Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PaddleDetection
提交
5046869e
P
PaddleDetection
项目概览
PaddlePaddle
/
PaddleDetection
大约 1 年 前同步成功
通知
695
Star
11112
Fork
2696
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
184
列表
看板
标记
里程碑
合并请求
40
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleDetection
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
184
Issue
184
列表
看板
标记
里程碑
合并请求
40
合并请求
40
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
5046869e
编写于
2月 12, 2018
作者:
Y
Yu Yang
提交者:
GitHub
2月 12, 2018
浏览文件
操作
浏览文件
下载
差异文件
Merge pull request #8287 from tonyyang-svail/operator_set_device
Correctly handle cuda place for operators
上级
7757a8ad
40c7972d
变更
41
隐藏空白更改
内联
并排
Showing
41 changed file
with
214 addition
and
114 deletion
+214
-114
paddle/fluid/framework/op_registry_test.cc
paddle/fluid/framework/op_registry_test.cc
+8
-2
paddle/fluid/framework/operator.cc
paddle/fluid/framework/operator.cc
+14
-2
paddle/fluid/framework/operator.h
paddle/fluid/framework/operator.h
+10
-5
paddle/fluid/framework/operator_test.cc
paddle/fluid/framework/operator_test.cc
+8
-3
paddle/fluid/operators/array_to_lod_tensor_op.cc
paddle/fluid/operators/array_to_lod_tensor_op.cc
+4
-2
paddle/fluid/operators/assign_op.cc
paddle/fluid/operators/assign_op.cc
+4
-2
paddle/fluid/operators/beam_search_decode_op.cc
paddle/fluid/operators/beam_search_decode_op.cc
+4
-2
paddle/fluid/operators/beam_search_op.h
paddle/fluid/operators/beam_search_op.h
+3
-2
paddle/fluid/operators/cond_op.cc
paddle/fluid/operators/cond_op.cc
+1
-1
paddle/fluid/operators/cond_op.h
paddle/fluid/operators/cond_op.h
+3
-2
paddle/fluid/operators/conditional_block_op.cc
paddle/fluid/operators/conditional_block_op.cc
+8
-4
paddle/fluid/operators/create_reader_op.cc
paddle/fluid/operators/create_reader_op.cc
+12
-6
paddle/fluid/operators/feed_op.cc
paddle/fluid/operators/feed_op.cc
+4
-2
paddle/fluid/operators/fetch_op.cc
paddle/fluid/operators/fetch_op.cc
+3
-2
paddle/fluid/operators/fill_constant_op.cc
paddle/fluid/operators/fill_constant_op.cc
+4
-2
paddle/fluid/operators/fill_op.cc
paddle/fluid/operators/fill_op.cc
+4
-2
paddle/fluid/operators/get_places_op.cc
paddle/fluid/operators/get_places_op.cc
+4
-2
paddle/fluid/operators/increment_op.cc
paddle/fluid/operators/increment_op.cc
+3
-2
paddle/fluid/operators/is_empty_op.cc
paddle/fluid/operators/is_empty_op.cc
+3
-2
paddle/fluid/operators/load_combine_op.cc
paddle/fluid/operators/load_combine_op.cc
+4
-2
paddle/fluid/operators/load_op.cc
paddle/fluid/operators/load_op.cc
+4
-2
paddle/fluid/operators/lod_array_length_op.cc
paddle/fluid/operators/lod_array_length_op.cc
+4
-2
paddle/fluid/operators/lod_rank_table_op.cc
paddle/fluid/operators/lod_rank_table_op.cc
+4
-2
paddle/fluid/operators/lod_tensor_to_array_op.cc
paddle/fluid/operators/lod_tensor_to_array_op.cc
+4
-2
paddle/fluid/operators/max_sequence_len_op.cc
paddle/fluid/operators/max_sequence_len_op.cc
+3
-2
paddle/fluid/operators/merge_lod_tensor_op.cc
paddle/fluid/operators/merge_lod_tensor_op.cc
+4
-2
paddle/fluid/operators/nccl_op.cc
paddle/fluid/operators/nccl_op.cc
+3
-2
paddle/fluid/operators/net_op.h
paddle/fluid/operators/net_op.h
+14
-14
paddle/fluid/operators/net_op_test.cc
paddle/fluid/operators/net_op_test.cc
+4
-1
paddle/fluid/operators/parallel_do_op.cc
paddle/fluid/operators/parallel_do_op.cc
+6
-4
paddle/fluid/operators/print_op.cc
paddle/fluid/operators/print_op.cc
+3
-2
paddle/fluid/operators/read_op.cc
paddle/fluid/operators/read_op.cc
+4
-2
paddle/fluid/operators/recurrent_op.cc
paddle/fluid/operators/recurrent_op.cc
+6
-4
paddle/fluid/operators/reorder_lod_tensor_by_rank_op.cc
paddle/fluid/operators/reorder_lod_tensor_by_rank_op.cc
+4
-2
paddle/fluid/operators/rnn_memory_helper_op.cc
paddle/fluid/operators/rnn_memory_helper_op.cc
+8
-4
paddle/fluid/operators/save_combine_op.cc
paddle/fluid/operators/save_combine_op.cc
+4
-2
paddle/fluid/operators/save_op.cc
paddle/fluid/operators/save_op.cc
+4
-2
paddle/fluid/operators/shrink_rnn_memory_op.cc
paddle/fluid/operators/shrink_rnn_memory_op.cc
+6
-4
paddle/fluid/operators/split_lod_tensor_op.cc
paddle/fluid/operators/split_lod_tensor_op.cc
+4
-2
paddle/fluid/operators/tensor_array_read_write_op.cc
paddle/fluid/operators/tensor_array_read_write_op.cc
+7
-4
paddle/fluid/operators/while_op.cc
paddle/fluid/operators/while_op.cc
+6
-4
未找到文件。
paddle/fluid/framework/op_registry_test.cc
浏览文件 @
5046869e
...
...
@@ -25,7 +25,10 @@ namespace framework {
class
CosineOp
:
public
OperatorBase
{
public:
using
OperatorBase
::
OperatorBase
;
void
Run
(
const
Scope
&
scope
,
const
platform
::
Place
&
place
)
const
override
{}
private:
void
RunImpl
(
const
Scope
&
scope
,
const
platform
::
Place
&
place
)
const
override
{}
};
class
CosineOpProtoAndCheckerMaker
:
public
OpProtoAndCheckerMaker
{
...
...
@@ -44,7 +47,10 @@ class CosineOpProtoAndCheckerMaker : public OpProtoAndCheckerMaker {
class
MyTestOp
:
public
OperatorBase
{
public:
using
OperatorBase
::
OperatorBase
;
void
Run
(
const
Scope
&
scope
,
const
platform
::
Place
&
place
)
const
override
{}
private:
void
RunImpl
(
const
Scope
&
scope
,
const
platform
::
Place
&
place
)
const
override
{}
};
class
MyTestOpProtoAndCheckerMaker
:
public
OpProtoAndCheckerMaker
{
...
...
paddle/fluid/framework/operator.cc
浏览文件 @
5046869e
...
...
@@ -64,6 +64,18 @@ static LoD GetLoD(const Scope& scope, const std::string& name) {
}
}
void
OperatorBase
::
Run
(
const
Scope
&
scope
,
const
platform
::
Place
&
place
)
{
if
(
platform
::
is_gpu_place
(
place
))
{
#ifndef PADDLE_WITH_CUDA
PADDLE_THROW
(
"Cannot run operator on place %s"
,
place
);
#else
auto
dev_id
=
boost
::
get
<
platform
::
CUDAPlace
>
(
place
).
device
;
platform
::
SetDeviceId
(
dev_id
);
#endif
}
RunImpl
(
scope
,
place
);
}
std
::
string
OperatorBase
::
Input
(
const
std
::
string
&
name
)
const
{
auto
&
ins
=
Inputs
(
name
);
PADDLE_ENFORCE_LE
(
ins
.
size
(),
1UL
,
...
...
@@ -479,8 +491,8 @@ class RuntimeInferShapeContext : public InferShapeContext {
const
Scope
&
scope_
;
};
void
OperatorWithKernel
::
Run
(
const
Scope
&
scope
,
const
platform
::
Place
&
place
)
const
{
void
OperatorWithKernel
::
Run
Impl
(
const
Scope
&
scope
,
const
platform
::
Place
&
place
)
const
{
RuntimeInferShapeContext
infer_shape_ctx
(
*
this
,
scope
);
this
->
InferShape
(
&
infer_shape_ctx
);
platform
::
DeviceContextPool
&
pool
=
platform
::
DeviceContextPool
::
Instance
();
...
...
paddle/fluid/framework/operator.h
浏览文件 @
5046869e
...
...
@@ -89,8 +89,9 @@ class OperatorBase {
std
::
string
DebugString
()
const
{
return
DebugStringEx
(
nullptr
);
}
/// Net will call this function to Run an op.
virtual
void
Run
(
const
Scope
&
scope
,
const
platform
::
Place
&
place
)
const
=
0
;
/// Net will call this interface function to Run an op.
// The implementation should be written at RunImpl
void
Run
(
const
Scope
&
scope
,
const
platform
::
Place
&
place
);
// FIXME(typhoonzero): this is only used for recv_op to stop event_loop.
virtual
void
Stop
()
{}
...
...
@@ -144,6 +145,8 @@ class OperatorBase {
private:
void
GenerateTemporaryNames
();
void
CheckAllInputOutputSet
()
const
;
virtual
void
RunImpl
(
const
Scope
&
scope
,
const
platform
::
Place
&
place
)
const
=
0
;
};
// Macro for define a clone method.
...
...
@@ -168,10 +171,13 @@ class OperatorBase {
class
NOP
:
public
OperatorBase
{
public:
using
OperatorBase
::
OperatorBase
;
void
Run
(
const
Scope
&
scope
,
const
platform
::
Place
&
place
)
const
override
{}
std
::
unique_ptr
<
OperatorBase
>
Clone
()
const
override
{
return
std
::
unique_ptr
<
OperatorBase
>
(
new
NOP
(
*
this
));
}
private:
void
RunImpl
(
const
Scope
&
scope
,
const
platform
::
Place
&
place
)
const
override
{}
};
class
ExecutionContext
{
...
...
@@ -363,8 +369,6 @@ class OperatorWithKernel : public OperatorBase {
const
VariableNameMap
&
outputs
,
const
AttributeMap
&
attrs
)
:
OperatorBase
(
type
,
inputs
,
outputs
,
attrs
)
{}
void
Run
(
const
Scope
&
scope
,
const
platform
::
Place
&
place
)
const
final
;
static
std
::
unordered_map
<
std
::
string
/* op_type */
,
OpKernelMap
>&
AllOpKernels
()
{
static
std
::
unordered_map
<
std
::
string
,
OpKernelMap
>
g_all_op_kernels
;
...
...
@@ -393,6 +397,7 @@ class OperatorWithKernel : public OperatorBase {
// indicate kernel DataType by input data. Defaultly all input data must be
// same.
proto
::
DataType
IndicateDataType
(
const
ExecutionContext
&
ctx
)
const
;
void
RunImpl
(
const
Scope
&
scope
,
const
platform
::
Place
&
place
)
const
final
;
};
extern
bool
OpSupportGPU
(
const
std
::
string
&
op_type
);
...
...
paddle/fluid/framework/operator_test.cc
浏览文件 @
5046869e
...
...
@@ -28,7 +28,10 @@ class OpWithoutKernelTest : public OperatorBase {
OpWithoutKernelTest
(
const
std
::
string
&
type
,
const
VariableNameMap
&
inputs
,
const
VariableNameMap
&
outputs
,
const
AttributeMap
&
attrs
)
:
OperatorBase
(
type
,
inputs
,
outputs
,
attrs
),
x
(
1
)
{}
void
Run
(
const
Scope
&
scope
,
const
platform
::
Place
&
place
)
const
override
{
private:
void
RunImpl
(
const
Scope
&
scope
,
const
platform
::
Place
&
place
)
const
override
{
++
op_run_num
;
ASSERT_EQ
(
static_cast
<
int
>
(
inputs_
.
size
()),
1
);
ASSERT_EQ
(
static_cast
<
int
>
(
outputs_
.
size
()),
1
);
...
...
@@ -259,8 +262,10 @@ class OperatorClone : public paddle::framework::OperatorBase {
const
paddle
::
framework
::
VariableNameMap
&
outputs
,
const
paddle
::
framework
::
AttributeMap
&
attrs
)
:
OperatorBase
(
type
,
inputs
,
outputs
,
attrs
)
{}
void
Run
(
const
paddle
::
framework
::
Scope
&
scope
,
const
paddle
::
platform
::
Place
&
place
)
const
override
{}
private:
void
RunImpl
(
const
paddle
::
framework
::
Scope
&
scope
,
const
paddle
::
platform
::
Place
&
place
)
const
override
{}
};
TEST
(
Operator
,
Clone
)
{
...
...
paddle/fluid/operators/array_to_lod_tensor_op.cc
浏览文件 @
5046869e
...
...
@@ -31,8 +31,10 @@ class ArrayToLoDTensorOp : public framework::OperatorBase {
const
framework
::
VariableNameMap
&
outputs
,
const
framework
::
AttributeMap
&
attrs
)
:
OperatorBase
(
type
,
inputs
,
outputs
,
attrs
)
{}
void
Run
(
const
framework
::
Scope
&
scope
,
const
platform
::
Place
&
dev_place
)
const
override
{
private:
void
RunImpl
(
const
framework
::
Scope
&
scope
,
const
platform
::
Place
&
dev_place
)
const
override
{
auto
&
x
=
scope
.
FindVar
(
Input
(
"X"
))
->
Get
<
framework
::
LoDTensorArray
>
();
auto
&
rank_table
=
scope
.
FindVar
(
Input
(
"RankTable"
))
->
Get
<
framework
::
LoDRankTable
>
();
...
...
paddle/fluid/operators/assign_op.cc
浏览文件 @
5046869e
...
...
@@ -71,8 +71,10 @@ class AssignOp : public framework::OperatorBase {
const
framework
::
VariableNameMap
&
outputs
,
const
framework
::
AttributeMap
&
attrs
)
:
OperatorBase
(
type
,
inputs
,
outputs
,
attrs
)
{}
void
Run
(
const
framework
::
Scope
&
scope
,
const
platform
::
Place
&
place
)
const
override
{
private:
void
RunImpl
(
const
framework
::
Scope
&
scope
,
const
platform
::
Place
&
place
)
const
override
{
auto
*
x
=
scope
.
FindVar
(
Input
(
"X"
));
if
(
x
==
nullptr
)
{
return
;
...
...
paddle/fluid/operators/beam_search_decode_op.cc
浏览文件 @
5046869e
...
...
@@ -55,8 +55,10 @@ class BeamSearchDecodeOp : public framework::OperatorBase {
const
framework
::
VariableNameMap
&
outputs
,
const
framework
::
AttributeMap
&
attrs
)
:
OperatorBase
(
type
,
inputs
,
outputs
,
attrs
)
{}
void
Run
(
const
framework
::
Scope
&
scope
,
const
platform
::
Place
&
dev_place
)
const
override
{
private:
void
RunImpl
(
const
framework
::
Scope
&
scope
,
const
platform
::
Place
&
dev_place
)
const
override
{
platform
::
DeviceContextPool
&
pool
=
platform
::
DeviceContextPool
::
Instance
();
auto
&
dev_ctx
=
*
pool
.
Get
(
dev_place
);
...
...
paddle/fluid/operators/beam_search_op.h
浏览文件 @
5046869e
...
...
@@ -204,8 +204,9 @@ class BeamSearchOp : public framework::OperatorBase {
PADDLE_THROW
(
"Not Implemented"
);
}
void
Run
(
const
framework
::
Scope
&
scope
,
const
platform
::
Place
&
dev_place
)
const
override
{
private:
void
RunImpl
(
const
framework
::
Scope
&
scope
,
const
platform
::
Place
&
dev_place
)
const
override
{
auto
ids_var
=
scope
.
FindVar
(
Input
(
"ids"
));
auto
scores_var
=
scope
.
FindVar
(
Input
(
"scores"
));
auto
pre_ids_var
=
scope
.
FindVar
(
Input
(
"pre_ids"
));
...
...
paddle/fluid/operators/cond_op.cc
浏览文件 @
5046869e
...
...
@@ -193,7 +193,7 @@ void CondOp::MergeDataFromSubnet(const framework::Scope& scope,
}
}
void
CondOp
::
Run
(
const
Scope
&
scope
,
const
platform
::
Place
&
place
)
const
{
void
CondOp
::
Run
Impl
(
const
Scope
&
scope
,
const
platform
::
Place
&
place
)
const
{
// get device context from pool
platform
::
DeviceContextPool
&
pool
=
platform
::
DeviceContextPool
::
Instance
();
auto
&
dev_ctx
=
*
pool
.
Get
(
place
);
...
...
paddle/fluid/operators/cond_op.h
浏览文件 @
5046869e
...
...
@@ -77,8 +77,9 @@ class CondOp : public framework::OperatorBase {
sub_net_op_
[
FALSE_BRANCH
]
=
std
::
move
(
net
);
}
void
Run
(
const
framework
::
Scope
&
scope
,
const
platform
::
Place
&
place
)
const
override
;
private:
void
RunImpl
(
const
framework
::
Scope
&
scope
,
const
platform
::
Place
&
place
)
const
override
;
private:
const
int
TRUE_BRANCH
=
0
;
...
...
paddle/fluid/operators/conditional_block_op.cc
浏览文件 @
5046869e
...
...
@@ -65,8 +65,10 @@ class ConditionalBlockOp : public ConditionalOp {
const
framework
::
VariableNameMap
&
outputs
,
const
framework
::
AttributeMap
&
attrs
)
:
ConditionalOp
(
type
,
inputs
,
outputs
,
attrs
)
{}
void
Run
(
const
framework
::
Scope
&
scope
,
const
platform
::
Place
&
dev_place
)
const
override
{
private:
void
RunImpl
(
const
framework
::
Scope
&
scope
,
const
platform
::
Place
&
dev_place
)
const
override
{
auto
xs
=
InputTensors
(
scope
);
bool
need_run
;
...
...
@@ -128,8 +130,10 @@ class ConditionalBlockGradOp : public ConditionalOp {
const
framework
::
VariableNameMap
&
outputs
,
const
framework
::
AttributeMap
&
attrs
)
:
ConditionalOp
(
type
,
inputs
,
outputs
,
attrs
)
{}
void
Run
(
const
framework
::
Scope
&
scope
,
const
platform
::
Place
&
dev_place
)
const
override
{
private:
void
RunImpl
(
const
framework
::
Scope
&
scope
,
const
platform
::
Place
&
dev_place
)
const
override
{
auto
xs
=
this
->
InputTensors
(
scope
);
bool
need_run
;
...
...
paddle/fluid/operators/create_reader_op.cc
浏览文件 @
5046869e
...
...
@@ -106,8 +106,10 @@ template <typename T>
class
CreateRandomDataGeneratorOp
:
public
framework
::
OperatorBase
{
public:
using
framework
::
OperatorBase
::
OperatorBase
;
void
Run
(
const
framework
::
Scope
&
scope
,
const
platform
::
Place
&
dev_place
)
const
override
{
private:
void
RunImpl
(
const
framework
::
Scope
&
scope
,
const
platform
::
Place
&
dev_place
)
const
override
{
const
auto
&
shape_concat
=
Attr
<
std
::
vector
<
int
>>
(
"shape_concat"
);
const
auto
&
ranks
=
Attr
<
std
::
vector
<
int
>>
(
"ranks"
);
PADDLE_ENFORCE
(
!
shape_concat
.
empty
()
&&
!
ranks
.
empty
());
...
...
@@ -155,8 +157,10 @@ class CreateRandomDataGeneratorOpMaker
class
CreateShuffleReaderOp
:
public
framework
::
OperatorBase
{
public:
using
framework
::
OperatorBase
::
OperatorBase
;
void
Run
(
const
framework
::
Scope
&
scope
,
const
platform
::
Place
&
dev_place
)
const
override
{
private:
void
RunImpl
(
const
framework
::
Scope
&
scope
,
const
platform
::
Place
&
dev_place
)
const
override
{
const
auto
&
underlying_reader
=
scope
.
FindVar
(
Input
(
"UnderlyingReader"
))
->
Get
<
framework
::
ReaderHolder
>
();
auto
*
out
=
scope
.
FindVar
(
Output
(
"Out"
))
...
...
@@ -187,8 +191,10 @@ class CreateShuffleReaderOpMaker : public framework::OpProtoAndCheckerMaker {
class
CreateBatchReaderOp
:
public
framework
::
OperatorBase
{
public:
using
framework
::
OperatorBase
::
OperatorBase
;
void
Run
(
const
framework
::
Scope
&
scope
,
const
platform
::
Place
&
dev_place
)
const
override
{
private:
void
RunImpl
(
const
framework
::
Scope
&
scope
,
const
platform
::
Place
&
dev_place
)
const
override
{
const
auto
&
underlying_reader
=
scope
.
FindVar
(
Input
(
"UnderlyingReader"
))
->
Get
<
framework
::
ReaderHolder
>
();
auto
*
out
=
scope
.
FindVar
(
Output
(
"Out"
))
...
...
paddle/fluid/operators/feed_op.cc
浏览文件 @
5046869e
...
...
@@ -24,8 +24,10 @@ class FeedOp : public framework::OperatorBase {
const
framework
::
VariableNameMap
&
outputs
,
const
framework
::
AttributeMap
&
attrs
)
:
OperatorBase
(
type
,
inputs
,
outputs
,
attrs
)
{}
void
Run
(
const
framework
::
Scope
&
scope
,
const
platform
::
Place
&
place
)
const
override
{
private:
void
RunImpl
(
const
framework
::
Scope
&
scope
,
const
platform
::
Place
&
place
)
const
override
{
auto
feed_var_name
=
Input
(
"X"
);
auto
*
feed_var
=
scope
.
FindVar
(
feed_var_name
);
...
...
paddle/fluid/operators/fetch_op.cc
浏览文件 @
5046869e
...
...
@@ -26,8 +26,9 @@ class FetchOp : public framework::OperatorBase {
const
framework
::
AttributeMap
&
attrs
)
:
OperatorBase
(
type
,
inputs
,
outputs
,
attrs
)
{}
void
Run
(
const
framework
::
Scope
&
scope
,
const
platform
::
Place
&
place
)
const
override
{
private:
void
RunImpl
(
const
framework
::
Scope
&
scope
,
const
platform
::
Place
&
place
)
const
override
{
auto
fetch_var_name
=
Input
(
"X"
);
auto
*
fetch_var
=
scope
.
FindVar
(
fetch_var_name
);
PADDLE_ENFORCE
(
fetch_var
!=
nullptr
,
...
...
paddle/fluid/operators/fill_constant_op.cc
浏览文件 @
5046869e
...
...
@@ -33,8 +33,10 @@ class FillConstantInferShape : public framework::InferShapeBase {
class
FillConstantOp
:
public
framework
::
OperatorBase
{
public:
using
framework
::
OperatorBase
::
OperatorBase
;
void
Run
(
const
framework
::
Scope
&
scope
,
const
platform
::
Place
&
dev_place
)
const
override
{
private:
void
RunImpl
(
const
framework
::
Scope
&
scope
,
const
platform
::
Place
&
dev_place
)
const
override
{
auto
data_type
=
static_cast
<
framework
::
proto
::
DataType
>
(
Attr
<
int
>
(
"dtype"
));
auto
value
=
Attr
<
float
>
(
"value"
);
...
...
paddle/fluid/operators/fill_op.cc
浏览文件 @
5046869e
...
...
@@ -42,8 +42,10 @@ class FillOp : public framework::OperatorBase {
const
framework
::
VariableNameMap
&
outputs
,
const
framework
::
AttributeMap
&
attrs
)
:
OperatorBase
(
type
,
inputs
,
outputs
,
attrs
)
{}
void
Run
(
const
framework
::
Scope
&
scope
,
const
platform
::
Place
&
place
)
const
override
{
private:
void
RunImpl
(
const
framework
::
Scope
&
scope
,
const
platform
::
Place
&
place
)
const
override
{
auto
&
out
=
detail
::
Ref
(
detail
::
Ref
(
scope
.
FindVar
(
Output
(
"Out"
)),
"Cannot find variable %s"
,
Output
(
"Out"
))
...
...
paddle/fluid/operators/get_places_op.cc
浏览文件 @
5046869e
...
...
@@ -37,8 +37,10 @@ class GetPlacesOp : public framework::OperatorBase {
const
framework
::
VariableNameMap
&
outputs
,
const
framework
::
AttributeMap
&
attrs
)
:
OperatorBase
(
type
,
inputs
,
outputs
,
attrs
)
{}
void
Run
(
const
framework
::
Scope
&
scope
,
const
platform
::
Place
&
place
)
const
override
{
private:
void
RunImpl
(
const
framework
::
Scope
&
scope
,
const
platform
::
Place
&
place
)
const
override
{
bool
is_gpu
;
if
(
Attr
<
std
::
string
>
(
"device_type"
)
==
"AUTO"
)
{
is_gpu
=
platform
::
is_gpu_place
(
place
);
...
...
paddle/fluid/operators/increment_op.cc
浏览文件 @
5046869e
...
...
@@ -51,8 +51,9 @@ class IncrementOp : public framework::OperatorBase {
const
framework
::
AttributeMap
&
attrs
)
:
OperatorBase
(
type
,
inputs
,
outputs
,
attrs
)
{}
void
Run
(
const
framework
::
Scope
&
scope
,
const
platform
::
Place
&
place
)
const
override
{
private:
void
RunImpl
(
const
framework
::
Scope
&
scope
,
const
platform
::
Place
&
place
)
const
override
{
auto
&
x
=
scope
.
FindVar
(
Input
(
"X"
))
->
Get
<
framework
::
LoDTensor
>
();
auto
&
out
=
*
scope
.
FindVar
(
Output
(
"Out"
))
->
GetMutable
<
framework
::
LoDTensor
>
();
...
...
paddle/fluid/operators/is_empty_op.cc
浏览文件 @
5046869e
...
...
@@ -28,8 +28,9 @@ class IsEmptyOp : public framework::OperatorBase {
const
framework
::
AttributeMap
&
attrs
)
:
OperatorBase
(
type
,
inputs
,
outputs
,
attrs
)
{}
void
Run
(
const
framework
::
Scope
&
scope
,
const
platform
::
Place
&
place
)
const
override
{
private:
void
RunImpl
(
const
framework
::
Scope
&
scope
,
const
platform
::
Place
&
place
)
const
override
{
// get input
auto
*
var
=
scope
.
FindVar
(
Input
(
kInput
));
PADDLE_ENFORCE_NOT_NULL
(
var
);
...
...
paddle/fluid/operators/load_combine_op.cc
浏览文件 @
5046869e
...
...
@@ -26,8 +26,10 @@ class LoadCombineOp : public framework::OperatorBase {
const
framework
::
VariableNameMap
&
outputs
,
const
framework
::
AttributeMap
&
attrs
)
:
OperatorBase
(
type
,
inputs
,
outputs
,
attrs
)
{}
void
Run
(
const
framework
::
Scope
&
scope
,
const
platform
::
Place
&
place
)
const
override
{
private:
void
RunImpl
(
const
framework
::
Scope
&
scope
,
const
platform
::
Place
&
place
)
const
override
{
auto
filename
=
Attr
<
std
::
string
>
(
"file_path"
);
std
::
ifstream
fin
(
filename
);
...
...
paddle/fluid/operators/load_op.cc
浏览文件 @
5046869e
...
...
@@ -25,8 +25,10 @@ class LoadOp : public framework::OperatorBase {
const
framework
::
VariableNameMap
&
outputs
,
const
framework
::
AttributeMap
&
attrs
)
:
OperatorBase
(
type
,
inputs
,
outputs
,
attrs
)
{}
void
Run
(
const
framework
::
Scope
&
scope
,
const
platform
::
Place
&
place
)
const
override
{
private:
void
RunImpl
(
const
framework
::
Scope
&
scope
,
const
platform
::
Place
&
place
)
const
override
{
auto
filename
=
Attr
<
std
::
string
>
(
"file_path"
);
std
::
ifstream
fin
(
filename
);
PADDLE_ENFORCE
(
static_cast
<
bool
>
(
fin
),
"Cannot open file %s for load op"
,
...
...
paddle/fluid/operators/lod_array_length_op.cc
浏览文件 @
5046869e
...
...
@@ -25,8 +25,10 @@ class LoDArrayLengthOp : public framework::OperatorBase {
const
framework
::
VariableNameMap
&
outputs
,
const
framework
::
AttributeMap
&
attrs
)
:
OperatorBase
(
type
,
inputs
,
outputs
,
attrs
)
{}
void
Run
(
const
framework
::
Scope
&
scope
,
const
platform
::
Place
&
place
)
const
override
{
private:
void
RunImpl
(
const
framework
::
Scope
&
scope
,
const
platform
::
Place
&
place
)
const
override
{
auto
&
x
=
scope
.
FindVar
(
Input
(
"X"
))
->
Get
<
framework
::
LoDTensorArray
>
();
auto
&
out
=
*
scope
.
FindVar
(
Output
(
"Out"
))
->
GetMutable
<
framework
::
LoDTensor
>
();
...
...
paddle/fluid/operators/lod_rank_table_op.cc
浏览文件 @
5046869e
...
...
@@ -23,8 +23,10 @@ class LoDRankTableOp : public framework::OperatorBase {
const
framework
::
VariableNameMap
&
outputs
,
const
framework
::
AttributeMap
&
attrs
)
:
OperatorBase
(
type
,
inputs
,
outputs
,
attrs
)
{}
void
Run
(
const
framework
::
Scope
&
scope
,
const
platform
::
Place
&
dev_place
)
const
override
{
private:
void
RunImpl
(
const
framework
::
Scope
&
scope
,
const
platform
::
Place
&
dev_place
)
const
override
{
auto
x
=
scope
.
FindVar
(
Input
(
"X"
))
->
Get
<
framework
::
LoDTensor
>
();
auto
*
out
=
scope
.
FindVar
(
Output
(
"Out"
))
->
GetMutable
<
framework
::
LoDRankTable
>
();
...
...
paddle/fluid/operators/lod_tensor_to_array_op.cc
浏览文件 @
5046869e
...
...
@@ -32,8 +32,10 @@ class LoDTensorToArrayOp : public framework::OperatorBase {
const
framework
::
VariableNameMap
&
outputs
,
const
framework
::
AttributeMap
&
attrs
)
:
OperatorBase
(
type
,
inputs
,
outputs
,
attrs
)
{}
void
Run
(
const
framework
::
Scope
&
scope
,
const
platform
::
Place
&
place
)
const
override
{
private:
void
RunImpl
(
const
framework
::
Scope
&
scope
,
const
platform
::
Place
&
place
)
const
override
{
auto
&
x
=
detail
::
Ref
(
scope
.
FindVar
(
Input
(
"X"
)),
"Cannot find input %s"
,
Input
(
"X"
))
.
Get
<
framework
::
LoDTensor
>
();
...
...
paddle/fluid/operators/max_sequence_len_op.cc
浏览文件 @
5046869e
...
...
@@ -27,8 +27,9 @@ class MaxSeqenceLenOp : public framework::OperatorBase {
const
framework
::
AttributeMap
&
attrs
)
:
OperatorBase
(
type
,
inputs
,
outputs
,
attrs
)
{}
void
Run
(
const
framework
::
Scope
&
scope
,
const
platform
::
Place
&
dev_place
)
const
override
{
private:
void
RunImpl
(
const
framework
::
Scope
&
scope
,
const
platform
::
Place
&
dev_place
)
const
override
{
auto
&
rank_table
=
scope
.
FindVar
(
Input
(
"RankTable"
))
->
Get
<
framework
::
LoDRankTable
>
();
auto
*
out
=
...
...
paddle/fluid/operators/merge_lod_tensor_op.cc
浏览文件 @
5046869e
...
...
@@ -27,8 +27,10 @@ class MergeLoDTensorOp : public framework::OperatorBase {
const
framework
::
VariableNameMap
&
outputs
,
const
framework
::
AttributeMap
&
attrs
)
:
OperatorBase
(
type
,
inputs
,
outputs
,
attrs
)
{}
void
Run
(
const
framework
::
Scope
&
scope
,
const
platform
::
Place
&
dev_place
)
const
override
{
private:
void
RunImpl
(
const
framework
::
Scope
&
scope
,
const
platform
::
Place
&
dev_place
)
const
override
{
// get device context from pool
platform
::
DeviceContextPool
&
pool
=
platform
::
DeviceContextPool
::
Instance
();
auto
&
dev_ctx
=
*
pool
.
Get
(
dev_place
);
...
...
paddle/fluid/operators/nccl_op.cc
浏览文件 @
5046869e
...
...
@@ -26,8 +26,9 @@ class NCCLInitOp : public framework::OperatorBase {
const
framework
::
AttributeMap
&
attrs
)
:
OperatorBase
(
type
,
inputs
,
outputs
,
attrs
)
{}
void
Run
(
const
framework
::
Scope
&
scope
,
const
platform
::
Place
&
place
)
const
override
{
private:
void
RunImpl
(
const
framework
::
Scope
&
scope
,
const
platform
::
Place
&
place
)
const
override
{
const
auto
&
name
=
Output
(
"Communicator"
);
PADDLE_ENFORCE_NOT_NULL
(
scope
.
FindVar
(
name
),
"Can not find variable '%s' in the scope."
,
name
);
...
...
paddle/fluid/operators/net_op.h
浏览文件 @
5046869e
...
...
@@ -57,20 +57,6 @@ class NetOp : public framework::OperatorBase {
this
->
CompleteAddOp
();
}
/**
* @brief Run the network.
*
* Run all the operators with the `scope`, if no scope is provided, default
* scope will be used instead. If no OpContext is provicded, default context
* will be used.
*/
void
Run
(
const
framework
::
Scope
&
scope
,
const
platform
::
Place
&
place
)
const
override
{
for
(
auto
&
op
:
ops_
)
{
op
->
Run
(
scope
,
place
);
}
}
bool
SupportGPU
()
const
override
{
for
(
auto
&
op
:
ops_
)
{
if
(
!
op
->
SupportGPU
())
{
...
...
@@ -117,6 +103,20 @@ class NetOp : public framework::OperatorBase {
std
::
vector
<
std
::
unique_ptr
<
framework
::
OperatorBase
>>
ops_
;
private:
/**
* @brief Run the network.
*
* Run all the operators with the `scope`, if no scope is provided, default
* scope will be used instead. If no OpContext is provicded, default context
* will be used.
*/
void
RunImpl
(
const
framework
::
Scope
&
scope
,
const
platform
::
Place
&
place
)
const
override
{
for
(
auto
&
op
:
ops_
)
{
op
->
Run
(
scope
,
place
);
}
}
bool
add_op_done_
{
false
};
std
::
set
<
std
::
string
>
intermediate_outputs_
;
...
...
paddle/fluid/operators/net_op_test.cc
浏览文件 @
5046869e
...
...
@@ -26,7 +26,10 @@ class TestOp : public framework::OperatorBase {
public:
using
framework
::
OperatorBase
::
OperatorBase
;
DEFINE_OP_CLONE_METHOD
(
TestOp
);
void
Run
(
const
Scope
&
scope
,
const
platform
::
Place
&
place
)
const
override
{
private:
void
RunImpl
(
const
Scope
&
scope
,
const
platform
::
Place
&
place
)
const
override
{
++
run_cnt
;
}
};
...
...
paddle/fluid/operators/parallel_do_op.cc
浏览文件 @
5046869e
...
...
@@ -118,8 +118,9 @@ class ParallelDoOp : public framework::OperatorBase {
const
framework
::
AttributeMap
&
attrs
)
:
framework
::
OperatorBase
(
type
,
inputs
,
outputs
,
attrs
)
{}
void
Run
(
const
framework
::
Scope
&
scope
,
const
platform
::
Place
&
place
)
const
override
{
private:
void
RunImpl
(
const
framework
::
Scope
&
scope
,
const
platform
::
Place
&
place
)
const
override
{
// get device context from pool
platform
::
DeviceContextPool
&
pool
=
platform
::
DeviceContextPool
::
Instance
();
auto
&
dev_ctx
=
*
pool
.
Get
(
place
);
...
...
@@ -207,8 +208,9 @@ class ParallelDoGradOp : public framework::OperatorBase {
const
framework
::
AttributeMap
&
attrs
)
:
framework
::
OperatorBase
(
type
,
inputs
,
outputs
,
attrs
)
{}
void
Run
(
const
framework
::
Scope
&
scope
,
const
platform
::
Place
&
place
)
const
override
{
private:
void
RunImpl
(
const
framework
::
Scope
&
scope
,
const
platform
::
Place
&
place
)
const
override
{
auto
*
block
=
Attr
<
framework
::
BlockDesc
*>
(
kParallelBlock
);
auto
*
program
=
block
->
Program
();
...
...
paddle/fluid/operators/print_op.cc
浏览文件 @
5046869e
...
...
@@ -130,8 +130,9 @@ class TensorPrintOp : public framework::OperatorBase {
PADDLE_THROW
(
"Not implemented."
);
}
void
Run
(
const
framework
::
Scope
&
scope
,
const
platform
::
Place
&
place
)
const
override
{
private:
void
RunImpl
(
const
framework
::
Scope
&
scope
,
const
platform
::
Place
&
place
)
const
override
{
const
framework
::
Variable
*
in_var_ptr
=
nullptr
;
std
::
string
phase
=
kForward
;
std
::
string
printed_var_name
=
""
;
...
...
paddle/fluid/operators/read_op.cc
浏览文件 @
5046869e
...
...
@@ -54,8 +54,10 @@ class ReadInferVarType : public framework::VarTypeInference {
class
ReadOp
:
public
framework
::
OperatorBase
{
public:
using
framework
::
OperatorBase
::
OperatorBase
;
void
Run
(
const
framework
::
Scope
&
scope
,
const
platform
::
Place
&
dev_place
)
const
override
{
private:
void
RunImpl
(
const
framework
::
Scope
&
scope
,
const
platform
::
Place
&
dev_place
)
const
override
{
framework
::
ReaderHolder
*
reader
=
scope
.
FindVar
(
Input
(
"Reader"
))
->
GetMutable
<
framework
::
ReaderHolder
>
();
if
(
!
reader
->
HasNext
())
{
...
...
paddle/fluid/operators/recurrent_op.cc
浏览文件 @
5046869e
...
...
@@ -226,8 +226,9 @@ class RecurrentOp : public RecurrentBase {
const
framework
::
AttributeMap
&
attrs
)
:
RecurrentBase
(
type
,
inputs
,
outputs
,
attrs
)
{}
void
Run
(
const
framework
::
Scope
&
scope
,
const
platform
::
Place
&
place
)
const
override
{
private:
void
RunImpl
(
const
framework
::
Scope
&
scope
,
const
platform
::
Place
&
place
)
const
override
{
auto
seq_len
=
static_cast
<
size_t
>
(
this
->
GetSequenceLength
(
scope
));
VLOG
(
3
)
<<
"Static RNN input sequence length = "
<<
seq_len
;
StepScopes
scopes
=
CreateStepScopes
(
scope
,
seq_len
);
...
...
@@ -315,8 +316,9 @@ class RecurrentGradOp : public RecurrentBase {
const
framework
::
AttributeMap
&
attrs
)
:
RecurrentBase
(
type
,
inputs
,
outputs
,
attrs
)
{}
void
Run
(
const
framework
::
Scope
&
scope
,
const
platform
::
Place
&
place
)
const
override
{
private:
void
RunImpl
(
const
framework
::
Scope
&
scope
,
const
platform
::
Place
&
place
)
const
override
{
auto
seq_len
=
static_cast
<
size_t
>
(
GetSequenceLength
(
scope
));
StepScopes
scopes
=
CreateStepScopes
(
scope
,
seq_len
);
auto
reverse
=
Attr
<
bool
>
(
kReverse
);
...
...
paddle/fluid/operators/reorder_lod_tensor_by_rank_op.cc
浏览文件 @
5046869e
...
...
@@ -75,8 +75,10 @@ class ReorderLoDTensorByRankTableBase : public framework::OperatorBase {
const
framework
::
VariableNameMap
&
outputs
,
const
framework
::
AttributeMap
&
attrs
)
:
OperatorBase
(
type
,
inputs
,
outputs
,
attrs
)
{}
void
Run
(
const
framework
::
Scope
&
scope
,
const
platform
::
Place
&
place
)
const
override
{
private:
void
RunImpl
(
const
framework
::
Scope
&
scope
,
const
platform
::
Place
&
place
)
const
override
{
auto
&
x
=
detail
::
Ref
(
scope
.
FindVar
(
Input
(
"X"
)),
"Cannot find input lod tensor variable %s"
,
Input
(
"X"
))
...
...
paddle/fluid/operators/rnn_memory_helper_op.cc
浏览文件 @
5046869e
...
...
@@ -24,8 +24,10 @@ class RNNMemoryHelperOp : public framework::OperatorBase {
const
framework
::
VariableNameMap
&
outputs
,
const
framework
::
AttributeMap
&
attrs
)
:
OperatorBase
(
type
,
inputs
,
outputs
,
attrs
)
{}
void
Run
(
const
framework
::
Scope
&
scope
,
const
platform
::
Place
&
dev_place
)
const
override
{
private:
void
RunImpl
(
const
framework
::
Scope
&
scope
,
const
platform
::
Place
&
dev_place
)
const
override
{
auto
mem_var_name
=
Input
(
"X"
);
auto
*
mem_var
=
scope
.
FindVar
(
mem_var_name
);
PADDLE_ENFORCE
(
mem_var
!=
nullptr
,
...
...
@@ -76,8 +78,10 @@ class RNNMemoryHelperGradOp : public framework::OperatorBase {
const
framework
::
VariableNameMap
&
outputs
,
const
framework
::
AttributeMap
&
attrs
)
:
OperatorBase
(
type
,
inputs
,
outputs
,
attrs
)
{}
void
Run
(
const
framework
::
Scope
&
scope
,
const
platform
::
Place
&
dev_place
)
const
override
{
private:
void
RunImpl
(
const
framework
::
Scope
&
scope
,
const
platform
::
Place
&
dev_place
)
const
override
{
auto
out_grad_var_name
=
Input
(
framework
::
GradVarName
(
"Out"
));
auto
*
out_grad_var
=
scope
.
FindVar
(
out_grad_var_name
);
...
...
paddle/fluid/operators/save_combine_op.cc
浏览文件 @
5046869e
...
...
@@ -63,8 +63,10 @@ class SaveCombineOp : public framework::OperatorBase {
const
framework
::
VariableNameMap
&
outputs
,
const
framework
::
AttributeMap
&
attrs
)
:
OperatorBase
(
type
,
inputs
,
outputs
,
attrs
)
{}
void
Run
(
const
framework
::
Scope
&
scope
,
const
platform
::
Place
&
place
)
const
override
{
private:
void
RunImpl
(
const
framework
::
Scope
&
scope
,
const
platform
::
Place
&
place
)
const
override
{
auto
filename
=
Attr
<
std
::
string
>
(
"file_path"
);
auto
overwrite
=
Attr
<
bool
>
(
"overwrite"
);
...
...
paddle/fluid/operators/save_op.cc
浏览文件 @
5046869e
...
...
@@ -62,8 +62,10 @@ class SaveOp : public framework::OperatorBase {
const
framework
::
VariableNameMap
&
outputs
,
const
framework
::
AttributeMap
&
attrs
)
:
OperatorBase
(
type
,
inputs
,
outputs
,
attrs
)
{}
void
Run
(
const
framework
::
Scope
&
scope
,
const
platform
::
Place
&
place
)
const
override
{
private:
void
RunImpl
(
const
framework
::
Scope
&
scope
,
const
platform
::
Place
&
place
)
const
override
{
auto
filename
=
Attr
<
std
::
string
>
(
"file_path"
);
auto
overwrite
=
Attr
<
bool
>
(
"overwrite"
);
...
...
paddle/fluid/operators/shrink_rnn_memory_op.cc
浏览文件 @
5046869e
...
...
@@ -27,8 +27,9 @@ class ShrinkRNNMemoryOp : public ArrayOp {
const
framework
::
AttributeMap
&
attrs
)
:
ArrayOp
(
type
,
inputs
,
outputs
,
attrs
)
{}
void
Run
(
const
framework
::
Scope
&
scope
,
const
platform
::
Place
&
place
)
const
override
{
private:
void
RunImpl
(
const
framework
::
Scope
&
scope
,
const
platform
::
Place
&
place
)
const
override
{
auto
*
x_var
=
scope
.
FindVar
(
Input
(
"X"
));
PADDLE_ENFORCE
(
x_var
!=
nullptr
,
"Input X must be set"
);
auto
&
x_tensor
=
x_var
->
Get
<
framework
::
LoDTensor
>
();
...
...
@@ -108,8 +109,9 @@ class ShrinkRNNMemoryGradOp : public ArrayOp {
const
framework
::
AttributeMap
&
attrs
)
:
ArrayOp
(
type
,
inputs
,
outputs
,
attrs
)
{}
void
Run
(
const
framework
::
Scope
&
scope
,
const
platform
::
Place
&
place
)
const
override
{
private:
void
RunImpl
(
const
framework
::
Scope
&
scope
,
const
platform
::
Place
&
place
)
const
override
{
auto
*
dout_var
=
scope
.
FindVar
(
Input
(
framework
::
GradVarName
(
"Out"
)));
auto
*
dx_var
=
scope
.
FindVar
(
Output
(
framework
::
GradVarName
(
"X"
)));
PADDLE_ENFORCE
(
dx_var
!=
nullptr
,
"Input Gradient should not be nullptr"
);
...
...
paddle/fluid/operators/split_lod_tensor_op.cc
浏览文件 @
5046869e
...
...
@@ -33,8 +33,10 @@ class SplitLoDTensorOp : public framework::OperatorBase {
const
framework
::
VariableNameMap
&
outputs
,
const
framework
::
AttributeMap
&
attrs
)
:
OperatorBase
(
type
,
inputs
,
outputs
,
attrs
)
{}
void
Run
(
const
framework
::
Scope
&
scope
,
const
platform
::
Place
&
dev_place
)
const
override
{
private:
void
RunImpl
(
const
framework
::
Scope
&
scope
,
const
platform
::
Place
&
dev_place
)
const
override
{
auto
&
x
=
scope
.
FindVar
(
Input
(
"X"
))
->
Get
<
framework
::
LoDTensor
>
();
auto
&
mask
=
scope
.
FindVar
(
Input
(
"Mask"
))
->
Get
<
framework
::
LoDTensor
>
();
auto
*
out_true
=
...
...
paddle/fluid/operators/tensor_array_read_write_op.cc
浏览文件 @
5046869e
...
...
@@ -24,8 +24,9 @@ class WriteToArrayOp : public ArrayOp {
const
framework
::
AttributeMap
&
attrs
)
:
ArrayOp
(
type
,
inputs
,
outputs
,
attrs
)
{}
void
Run
(
const
framework
::
Scope
&
scope
,
const
platform
::
Place
&
place
)
const
override
{
private:
void
RunImpl
(
const
framework
::
Scope
&
scope
,
const
platform
::
Place
&
place
)
const
override
{
auto
*
x
=
scope
.
FindVar
(
Input
(
"X"
));
if
(
x
==
nullptr
)
return
;
auto
&
x_tensor
=
x
->
Get
<
framework
::
LoDTensor
>
();
...
...
@@ -122,8 +123,10 @@ class ReadFromArrayOp : public ArrayOp {
const
framework
::
VariableNameMap
&
outputs
,
const
framework
::
AttributeMap
&
attrs
)
:
ArrayOp
(
type
,
inputs
,
outputs
,
attrs
)
{}
void
Run
(
const
framework
::
Scope
&
scope
,
const
platform
::
Place
&
place
)
const
override
{
private:
void
RunImpl
(
const
framework
::
Scope
&
scope
,
const
platform
::
Place
&
place
)
const
override
{
auto
*
x
=
scope
.
FindVar
(
Input
(
"X"
));
PADDLE_ENFORCE
(
x
!=
nullptr
,
"X must be set"
);
auto
&
x_array
=
x
->
Get
<
framework
::
LoDTensorArray
>
();
...
...
paddle/fluid/operators/while_op.cc
浏览文件 @
5046869e
...
...
@@ -39,8 +39,9 @@ class WhileOp : public framework::OperatorBase {
const
framework
::
AttributeMap
&
attrs
)
:
framework
::
OperatorBase
(
type
,
inputs
,
outputs
,
attrs
)
{}
void
Run
(
const
framework
::
Scope
&
scope
,
const
platform
::
Place
&
dev_place
)
const
override
{
private:
void
RunImpl
(
const
framework
::
Scope
&
scope
,
const
platform
::
Place
&
dev_place
)
const
override
{
PADDLE_ENFORCE_NOT_NULL
(
scope
.
FindVar
(
Input
(
kCondition
)));
auto
&
cond
=
scope
.
FindVar
(
Input
(
kCondition
))
->
Get
<
LoDTensor
>
();
PADDLE_ENFORCE_EQ
(
cond
.
dims
(),
paddle
::
framework
::
make_ddim
({
1
}));
...
...
@@ -99,8 +100,9 @@ class WhileGradOp : public framework::OperatorBase {
const
framework
::
AttributeMap
&
attrs
)
:
framework
::
OperatorBase
(
type
,
inputs
,
outputs
,
attrs
)
{}
void
Run
(
const
framework
::
Scope
&
scope
,
const
platform
::
Place
&
dev_place
)
const
override
{
private:
void
RunImpl
(
const
framework
::
Scope
&
scope
,
const
platform
::
Place
&
dev_place
)
const
override
{
// get device context from pool
platform
::
DeviceContextPool
&
pool
=
platform
::
DeviceContextPool
::
Instance
();
auto
&
dev_ctx
=
*
pool
.
Get
(
dev_place
);
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录