Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
Crayon鑫
Paddle
提交
5046869e
P
Paddle
项目概览
Crayon鑫
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
5046869e
编写于
2月 12, 2018
作者:
Y
Yu Yang
提交者:
GitHub
2月 12, 2018
浏览文件
操作
浏览文件
下载
差异文件
Merge pull request #8287 from tonyyang-svail/operator_set_device
Correctly handle cuda place for operators
上级
7757a8ad
40c7972d
变更
41
隐藏空白更改
内联
并排
Showing
41 changed file
with
214 addition
and
114 deletion
+214
-114
paddle/fluid/framework/op_registry_test.cc
paddle/fluid/framework/op_registry_test.cc
+8
-2
paddle/fluid/framework/operator.cc
paddle/fluid/framework/operator.cc
+14
-2
paddle/fluid/framework/operator.h
paddle/fluid/framework/operator.h
+10
-5
paddle/fluid/framework/operator_test.cc
paddle/fluid/framework/operator_test.cc
+8
-3
paddle/fluid/operators/array_to_lod_tensor_op.cc
paddle/fluid/operators/array_to_lod_tensor_op.cc
+4
-2
paddle/fluid/operators/assign_op.cc
paddle/fluid/operators/assign_op.cc
+4
-2
paddle/fluid/operators/beam_search_decode_op.cc
paddle/fluid/operators/beam_search_decode_op.cc
+4
-2
paddle/fluid/operators/beam_search_op.h
paddle/fluid/operators/beam_search_op.h
+3
-2
paddle/fluid/operators/cond_op.cc
paddle/fluid/operators/cond_op.cc
+1
-1
paddle/fluid/operators/cond_op.h
paddle/fluid/operators/cond_op.h
+3
-2
paddle/fluid/operators/conditional_block_op.cc
paddle/fluid/operators/conditional_block_op.cc
+8
-4
paddle/fluid/operators/create_reader_op.cc
paddle/fluid/operators/create_reader_op.cc
+12
-6
paddle/fluid/operators/feed_op.cc
paddle/fluid/operators/feed_op.cc
+4
-2
paddle/fluid/operators/fetch_op.cc
paddle/fluid/operators/fetch_op.cc
+3
-2
paddle/fluid/operators/fill_constant_op.cc
paddle/fluid/operators/fill_constant_op.cc
+4
-2
paddle/fluid/operators/fill_op.cc
paddle/fluid/operators/fill_op.cc
+4
-2
paddle/fluid/operators/get_places_op.cc
paddle/fluid/operators/get_places_op.cc
+4
-2
paddle/fluid/operators/increment_op.cc
paddle/fluid/operators/increment_op.cc
+3
-2
paddle/fluid/operators/is_empty_op.cc
paddle/fluid/operators/is_empty_op.cc
+3
-2
paddle/fluid/operators/load_combine_op.cc
paddle/fluid/operators/load_combine_op.cc
+4
-2
paddle/fluid/operators/load_op.cc
paddle/fluid/operators/load_op.cc
+4
-2
paddle/fluid/operators/lod_array_length_op.cc
paddle/fluid/operators/lod_array_length_op.cc
+4
-2
paddle/fluid/operators/lod_rank_table_op.cc
paddle/fluid/operators/lod_rank_table_op.cc
+4
-2
paddle/fluid/operators/lod_tensor_to_array_op.cc
paddle/fluid/operators/lod_tensor_to_array_op.cc
+4
-2
paddle/fluid/operators/max_sequence_len_op.cc
paddle/fluid/operators/max_sequence_len_op.cc
+3
-2
paddle/fluid/operators/merge_lod_tensor_op.cc
paddle/fluid/operators/merge_lod_tensor_op.cc
+4
-2
paddle/fluid/operators/nccl_op.cc
paddle/fluid/operators/nccl_op.cc
+3
-2
paddle/fluid/operators/net_op.h
paddle/fluid/operators/net_op.h
+14
-14
paddle/fluid/operators/net_op_test.cc
paddle/fluid/operators/net_op_test.cc
+4
-1
paddle/fluid/operators/parallel_do_op.cc
paddle/fluid/operators/parallel_do_op.cc
+6
-4
paddle/fluid/operators/print_op.cc
paddle/fluid/operators/print_op.cc
+3
-2
paddle/fluid/operators/read_op.cc
paddle/fluid/operators/read_op.cc
+4
-2
paddle/fluid/operators/recurrent_op.cc
paddle/fluid/operators/recurrent_op.cc
+6
-4
paddle/fluid/operators/reorder_lod_tensor_by_rank_op.cc
paddle/fluid/operators/reorder_lod_tensor_by_rank_op.cc
+4
-2
paddle/fluid/operators/rnn_memory_helper_op.cc
paddle/fluid/operators/rnn_memory_helper_op.cc
+8
-4
paddle/fluid/operators/save_combine_op.cc
paddle/fluid/operators/save_combine_op.cc
+4
-2
paddle/fluid/operators/save_op.cc
paddle/fluid/operators/save_op.cc
+4
-2
paddle/fluid/operators/shrink_rnn_memory_op.cc
paddle/fluid/operators/shrink_rnn_memory_op.cc
+6
-4
paddle/fluid/operators/split_lod_tensor_op.cc
paddle/fluid/operators/split_lod_tensor_op.cc
+4
-2
paddle/fluid/operators/tensor_array_read_write_op.cc
paddle/fluid/operators/tensor_array_read_write_op.cc
+7
-4
paddle/fluid/operators/while_op.cc
paddle/fluid/operators/while_op.cc
+6
-4
未找到文件。
paddle/fluid/framework/op_registry_test.cc
浏览文件 @
5046869e
...
@@ -25,7 +25,10 @@ namespace framework {
...
@@ -25,7 +25,10 @@ namespace framework {
class
CosineOp
:
public
OperatorBase
{
class
CosineOp
:
public
OperatorBase
{
public:
public:
using
OperatorBase
::
OperatorBase
;
using
OperatorBase
::
OperatorBase
;
void
Run
(
const
Scope
&
scope
,
const
platform
::
Place
&
place
)
const
override
{}
private:
void
RunImpl
(
const
Scope
&
scope
,
const
platform
::
Place
&
place
)
const
override
{}
};
};
class
CosineOpProtoAndCheckerMaker
:
public
OpProtoAndCheckerMaker
{
class
CosineOpProtoAndCheckerMaker
:
public
OpProtoAndCheckerMaker
{
...
@@ -44,7 +47,10 @@ class CosineOpProtoAndCheckerMaker : public OpProtoAndCheckerMaker {
...
@@ -44,7 +47,10 @@ class CosineOpProtoAndCheckerMaker : public OpProtoAndCheckerMaker {
class
MyTestOp
:
public
OperatorBase
{
class
MyTestOp
:
public
OperatorBase
{
public:
public:
using
OperatorBase
::
OperatorBase
;
using
OperatorBase
::
OperatorBase
;
void
Run
(
const
Scope
&
scope
,
const
platform
::
Place
&
place
)
const
override
{}
private:
void
RunImpl
(
const
Scope
&
scope
,
const
platform
::
Place
&
place
)
const
override
{}
};
};
class
MyTestOpProtoAndCheckerMaker
:
public
OpProtoAndCheckerMaker
{
class
MyTestOpProtoAndCheckerMaker
:
public
OpProtoAndCheckerMaker
{
...
...
paddle/fluid/framework/operator.cc
浏览文件 @
5046869e
...
@@ -64,6 +64,18 @@ static LoD GetLoD(const Scope& scope, const std::string& name) {
...
@@ -64,6 +64,18 @@ static LoD GetLoD(const Scope& scope, const std::string& name) {
}
}
}
}
void
OperatorBase
::
Run
(
const
Scope
&
scope
,
const
platform
::
Place
&
place
)
{
if
(
platform
::
is_gpu_place
(
place
))
{
#ifndef PADDLE_WITH_CUDA
PADDLE_THROW
(
"Cannot run operator on place %s"
,
place
);
#else
auto
dev_id
=
boost
::
get
<
platform
::
CUDAPlace
>
(
place
).
device
;
platform
::
SetDeviceId
(
dev_id
);
#endif
}
RunImpl
(
scope
,
place
);
}
std
::
string
OperatorBase
::
Input
(
const
std
::
string
&
name
)
const
{
std
::
string
OperatorBase
::
Input
(
const
std
::
string
&
name
)
const
{
auto
&
ins
=
Inputs
(
name
);
auto
&
ins
=
Inputs
(
name
);
PADDLE_ENFORCE_LE
(
ins
.
size
(),
1UL
,
PADDLE_ENFORCE_LE
(
ins
.
size
(),
1UL
,
...
@@ -479,8 +491,8 @@ class RuntimeInferShapeContext : public InferShapeContext {
...
@@ -479,8 +491,8 @@ class RuntimeInferShapeContext : public InferShapeContext {
const
Scope
&
scope_
;
const
Scope
&
scope_
;
};
};
void
OperatorWithKernel
::
Run
(
const
Scope
&
scope
,
void
OperatorWithKernel
::
Run
Impl
(
const
Scope
&
scope
,
const
platform
::
Place
&
place
)
const
{
const
platform
::
Place
&
place
)
const
{
RuntimeInferShapeContext
infer_shape_ctx
(
*
this
,
scope
);
RuntimeInferShapeContext
infer_shape_ctx
(
*
this
,
scope
);
this
->
InferShape
(
&
infer_shape_ctx
);
this
->
InferShape
(
&
infer_shape_ctx
);
platform
::
DeviceContextPool
&
pool
=
platform
::
DeviceContextPool
::
Instance
();
platform
::
DeviceContextPool
&
pool
=
platform
::
DeviceContextPool
::
Instance
();
...
...
paddle/fluid/framework/operator.h
浏览文件 @
5046869e
...
@@ -89,8 +89,9 @@ class OperatorBase {
...
@@ -89,8 +89,9 @@ class OperatorBase {
std
::
string
DebugString
()
const
{
return
DebugStringEx
(
nullptr
);
}
std
::
string
DebugString
()
const
{
return
DebugStringEx
(
nullptr
);
}
/// Net will call this function to Run an op.
/// Net will call this interface function to Run an op.
virtual
void
Run
(
const
Scope
&
scope
,
const
platform
::
Place
&
place
)
const
=
0
;
// The implementation should be written at RunImpl
void
Run
(
const
Scope
&
scope
,
const
platform
::
Place
&
place
);
// FIXME(typhoonzero): this is only used for recv_op to stop event_loop.
// FIXME(typhoonzero): this is only used for recv_op to stop event_loop.
virtual
void
Stop
()
{}
virtual
void
Stop
()
{}
...
@@ -144,6 +145,8 @@ class OperatorBase {
...
@@ -144,6 +145,8 @@ class OperatorBase {
private:
private:
void
GenerateTemporaryNames
();
void
GenerateTemporaryNames
();
void
CheckAllInputOutputSet
()
const
;
void
CheckAllInputOutputSet
()
const
;
virtual
void
RunImpl
(
const
Scope
&
scope
,
const
platform
::
Place
&
place
)
const
=
0
;
};
};
// Macro for define a clone method.
// Macro for define a clone method.
...
@@ -168,10 +171,13 @@ class OperatorBase {
...
@@ -168,10 +171,13 @@ class OperatorBase {
class
NOP
:
public
OperatorBase
{
class
NOP
:
public
OperatorBase
{
public:
public:
using
OperatorBase
::
OperatorBase
;
using
OperatorBase
::
OperatorBase
;
void
Run
(
const
Scope
&
scope
,
const
platform
::
Place
&
place
)
const
override
{}
std
::
unique_ptr
<
OperatorBase
>
Clone
()
const
override
{
std
::
unique_ptr
<
OperatorBase
>
Clone
()
const
override
{
return
std
::
unique_ptr
<
OperatorBase
>
(
new
NOP
(
*
this
));
return
std
::
unique_ptr
<
OperatorBase
>
(
new
NOP
(
*
this
));
}
}
private:
void
RunImpl
(
const
Scope
&
scope
,
const
platform
::
Place
&
place
)
const
override
{}
};
};
class
ExecutionContext
{
class
ExecutionContext
{
...
@@ -363,8 +369,6 @@ class OperatorWithKernel : public OperatorBase {
...
@@ -363,8 +369,6 @@ class OperatorWithKernel : public OperatorBase {
const
VariableNameMap
&
outputs
,
const
AttributeMap
&
attrs
)
const
VariableNameMap
&
outputs
,
const
AttributeMap
&
attrs
)
:
OperatorBase
(
type
,
inputs
,
outputs
,
attrs
)
{}
:
OperatorBase
(
type
,
inputs
,
outputs
,
attrs
)
{}
void
Run
(
const
Scope
&
scope
,
const
platform
::
Place
&
place
)
const
final
;
static
std
::
unordered_map
<
std
::
string
/* op_type */
,
OpKernelMap
>&
static
std
::
unordered_map
<
std
::
string
/* op_type */
,
OpKernelMap
>&
AllOpKernels
()
{
AllOpKernels
()
{
static
std
::
unordered_map
<
std
::
string
,
OpKernelMap
>
g_all_op_kernels
;
static
std
::
unordered_map
<
std
::
string
,
OpKernelMap
>
g_all_op_kernels
;
...
@@ -393,6 +397,7 @@ class OperatorWithKernel : public OperatorBase {
...
@@ -393,6 +397,7 @@ class OperatorWithKernel : public OperatorBase {
// indicate kernel DataType by input data. Defaultly all input data must be
// indicate kernel DataType by input data. Defaultly all input data must be
// same.
// same.
proto
::
DataType
IndicateDataType
(
const
ExecutionContext
&
ctx
)
const
;
proto
::
DataType
IndicateDataType
(
const
ExecutionContext
&
ctx
)
const
;
void
RunImpl
(
const
Scope
&
scope
,
const
platform
::
Place
&
place
)
const
final
;
};
};
extern
bool
OpSupportGPU
(
const
std
::
string
&
op_type
);
extern
bool
OpSupportGPU
(
const
std
::
string
&
op_type
);
...
...
paddle/fluid/framework/operator_test.cc
浏览文件 @
5046869e
...
@@ -28,7 +28,10 @@ class OpWithoutKernelTest : public OperatorBase {
...
@@ -28,7 +28,10 @@ class OpWithoutKernelTest : public OperatorBase {
OpWithoutKernelTest
(
const
std
::
string
&
type
,
const
VariableNameMap
&
inputs
,
OpWithoutKernelTest
(
const
std
::
string
&
type
,
const
VariableNameMap
&
inputs
,
const
VariableNameMap
&
outputs
,
const
AttributeMap
&
attrs
)
const
VariableNameMap
&
outputs
,
const
AttributeMap
&
attrs
)
:
OperatorBase
(
type
,
inputs
,
outputs
,
attrs
),
x
(
1
)
{}
:
OperatorBase
(
type
,
inputs
,
outputs
,
attrs
),
x
(
1
)
{}
void
Run
(
const
Scope
&
scope
,
const
platform
::
Place
&
place
)
const
override
{
private:
void
RunImpl
(
const
Scope
&
scope
,
const
platform
::
Place
&
place
)
const
override
{
++
op_run_num
;
++
op_run_num
;
ASSERT_EQ
(
static_cast
<
int
>
(
inputs_
.
size
()),
1
);
ASSERT_EQ
(
static_cast
<
int
>
(
inputs_
.
size
()),
1
);
ASSERT_EQ
(
static_cast
<
int
>
(
outputs_
.
size
()),
1
);
ASSERT_EQ
(
static_cast
<
int
>
(
outputs_
.
size
()),
1
);
...
@@ -259,8 +262,10 @@ class OperatorClone : public paddle::framework::OperatorBase {
...
@@ -259,8 +262,10 @@ class OperatorClone : public paddle::framework::OperatorBase {
const
paddle
::
framework
::
VariableNameMap
&
outputs
,
const
paddle
::
framework
::
VariableNameMap
&
outputs
,
const
paddle
::
framework
::
AttributeMap
&
attrs
)
const
paddle
::
framework
::
AttributeMap
&
attrs
)
:
OperatorBase
(
type
,
inputs
,
outputs
,
attrs
)
{}
:
OperatorBase
(
type
,
inputs
,
outputs
,
attrs
)
{}
void
Run
(
const
paddle
::
framework
::
Scope
&
scope
,
const
paddle
::
platform
::
Place
&
place
)
const
override
{}
private:
void
RunImpl
(
const
paddle
::
framework
::
Scope
&
scope
,
const
paddle
::
platform
::
Place
&
place
)
const
override
{}
};
};
TEST
(
Operator
,
Clone
)
{
TEST
(
Operator
,
Clone
)
{
...
...
paddle/fluid/operators/array_to_lod_tensor_op.cc
浏览文件 @
5046869e
...
@@ -31,8 +31,10 @@ class ArrayToLoDTensorOp : public framework::OperatorBase {
...
@@ -31,8 +31,10 @@ class ArrayToLoDTensorOp : public framework::OperatorBase {
const
framework
::
VariableNameMap
&
outputs
,
const
framework
::
VariableNameMap
&
outputs
,
const
framework
::
AttributeMap
&
attrs
)
const
framework
::
AttributeMap
&
attrs
)
:
OperatorBase
(
type
,
inputs
,
outputs
,
attrs
)
{}
:
OperatorBase
(
type
,
inputs
,
outputs
,
attrs
)
{}
void
Run
(
const
framework
::
Scope
&
scope
,
const
platform
::
Place
&
dev_place
)
const
override
{
private:
void
RunImpl
(
const
framework
::
Scope
&
scope
,
const
platform
::
Place
&
dev_place
)
const
override
{
auto
&
x
=
scope
.
FindVar
(
Input
(
"X"
))
->
Get
<
framework
::
LoDTensorArray
>
();
auto
&
x
=
scope
.
FindVar
(
Input
(
"X"
))
->
Get
<
framework
::
LoDTensorArray
>
();
auto
&
rank_table
=
auto
&
rank_table
=
scope
.
FindVar
(
Input
(
"RankTable"
))
->
Get
<
framework
::
LoDRankTable
>
();
scope
.
FindVar
(
Input
(
"RankTable"
))
->
Get
<
framework
::
LoDRankTable
>
();
...
...
paddle/fluid/operators/assign_op.cc
浏览文件 @
5046869e
...
@@ -71,8 +71,10 @@ class AssignOp : public framework::OperatorBase {
...
@@ -71,8 +71,10 @@ class AssignOp : public framework::OperatorBase {
const
framework
::
VariableNameMap
&
outputs
,
const
framework
::
VariableNameMap
&
outputs
,
const
framework
::
AttributeMap
&
attrs
)
const
framework
::
AttributeMap
&
attrs
)
:
OperatorBase
(
type
,
inputs
,
outputs
,
attrs
)
{}
:
OperatorBase
(
type
,
inputs
,
outputs
,
attrs
)
{}
void
Run
(
const
framework
::
Scope
&
scope
,
const
platform
::
Place
&
place
)
const
override
{
private:
void
RunImpl
(
const
framework
::
Scope
&
scope
,
const
platform
::
Place
&
place
)
const
override
{
auto
*
x
=
scope
.
FindVar
(
Input
(
"X"
));
auto
*
x
=
scope
.
FindVar
(
Input
(
"X"
));
if
(
x
==
nullptr
)
{
if
(
x
==
nullptr
)
{
return
;
return
;
...
...
paddle/fluid/operators/beam_search_decode_op.cc
浏览文件 @
5046869e
...
@@ -55,8 +55,10 @@ class BeamSearchDecodeOp : public framework::OperatorBase {
...
@@ -55,8 +55,10 @@ class BeamSearchDecodeOp : public framework::OperatorBase {
const
framework
::
VariableNameMap
&
outputs
,
const
framework
::
VariableNameMap
&
outputs
,
const
framework
::
AttributeMap
&
attrs
)
const
framework
::
AttributeMap
&
attrs
)
:
OperatorBase
(
type
,
inputs
,
outputs
,
attrs
)
{}
:
OperatorBase
(
type
,
inputs
,
outputs
,
attrs
)
{}
void
Run
(
const
framework
::
Scope
&
scope
,
const
platform
::
Place
&
dev_place
)
const
override
{
private:
void
RunImpl
(
const
framework
::
Scope
&
scope
,
const
platform
::
Place
&
dev_place
)
const
override
{
platform
::
DeviceContextPool
&
pool
=
platform
::
DeviceContextPool
::
Instance
();
platform
::
DeviceContextPool
&
pool
=
platform
::
DeviceContextPool
::
Instance
();
auto
&
dev_ctx
=
*
pool
.
Get
(
dev_place
);
auto
&
dev_ctx
=
*
pool
.
Get
(
dev_place
);
...
...
paddle/fluid/operators/beam_search_op.h
浏览文件 @
5046869e
...
@@ -204,8 +204,9 @@ class BeamSearchOp : public framework::OperatorBase {
...
@@ -204,8 +204,9 @@ class BeamSearchOp : public framework::OperatorBase {
PADDLE_THROW
(
"Not Implemented"
);
PADDLE_THROW
(
"Not Implemented"
);
}
}
void
Run
(
const
framework
::
Scope
&
scope
,
private:
const
platform
::
Place
&
dev_place
)
const
override
{
void
RunImpl
(
const
framework
::
Scope
&
scope
,
const
platform
::
Place
&
dev_place
)
const
override
{
auto
ids_var
=
scope
.
FindVar
(
Input
(
"ids"
));
auto
ids_var
=
scope
.
FindVar
(
Input
(
"ids"
));
auto
scores_var
=
scope
.
FindVar
(
Input
(
"scores"
));
auto
scores_var
=
scope
.
FindVar
(
Input
(
"scores"
));
auto
pre_ids_var
=
scope
.
FindVar
(
Input
(
"pre_ids"
));
auto
pre_ids_var
=
scope
.
FindVar
(
Input
(
"pre_ids"
));
...
...
paddle/fluid/operators/cond_op.cc
浏览文件 @
5046869e
...
@@ -193,7 +193,7 @@ void CondOp::MergeDataFromSubnet(const framework::Scope& scope,
...
@@ -193,7 +193,7 @@ void CondOp::MergeDataFromSubnet(const framework::Scope& scope,
}
}
}
}
void
CondOp
::
Run
(
const
Scope
&
scope
,
const
platform
::
Place
&
place
)
const
{
void
CondOp
::
Run
Impl
(
const
Scope
&
scope
,
const
platform
::
Place
&
place
)
const
{
// get device context from pool
// get device context from pool
platform
::
DeviceContextPool
&
pool
=
platform
::
DeviceContextPool
::
Instance
();
platform
::
DeviceContextPool
&
pool
=
platform
::
DeviceContextPool
::
Instance
();
auto
&
dev_ctx
=
*
pool
.
Get
(
place
);
auto
&
dev_ctx
=
*
pool
.
Get
(
place
);
...
...
paddle/fluid/operators/cond_op.h
浏览文件 @
5046869e
...
@@ -77,8 +77,9 @@ class CondOp : public framework::OperatorBase {
...
@@ -77,8 +77,9 @@ class CondOp : public framework::OperatorBase {
sub_net_op_
[
FALSE_BRANCH
]
=
std
::
move
(
net
);
sub_net_op_
[
FALSE_BRANCH
]
=
std
::
move
(
net
);
}
}
void
Run
(
const
framework
::
Scope
&
scope
,
private:
const
platform
::
Place
&
place
)
const
override
;
void
RunImpl
(
const
framework
::
Scope
&
scope
,
const
platform
::
Place
&
place
)
const
override
;
private:
private:
const
int
TRUE_BRANCH
=
0
;
const
int
TRUE_BRANCH
=
0
;
...
...
paddle/fluid/operators/conditional_block_op.cc
浏览文件 @
5046869e
...
@@ -65,8 +65,10 @@ class ConditionalBlockOp : public ConditionalOp {
...
@@ -65,8 +65,10 @@ class ConditionalBlockOp : public ConditionalOp {
const
framework
::
VariableNameMap
&
outputs
,
const
framework
::
VariableNameMap
&
outputs
,
const
framework
::
AttributeMap
&
attrs
)
const
framework
::
AttributeMap
&
attrs
)
:
ConditionalOp
(
type
,
inputs
,
outputs
,
attrs
)
{}
:
ConditionalOp
(
type
,
inputs
,
outputs
,
attrs
)
{}
void
Run
(
const
framework
::
Scope
&
scope
,
const
platform
::
Place
&
dev_place
)
const
override
{
private:
void
RunImpl
(
const
framework
::
Scope
&
scope
,
const
platform
::
Place
&
dev_place
)
const
override
{
auto
xs
=
InputTensors
(
scope
);
auto
xs
=
InputTensors
(
scope
);
bool
need_run
;
bool
need_run
;
...
@@ -128,8 +130,10 @@ class ConditionalBlockGradOp : public ConditionalOp {
...
@@ -128,8 +130,10 @@ class ConditionalBlockGradOp : public ConditionalOp {
const
framework
::
VariableNameMap
&
outputs
,
const
framework
::
VariableNameMap
&
outputs
,
const
framework
::
AttributeMap
&
attrs
)
const
framework
::
AttributeMap
&
attrs
)
:
ConditionalOp
(
type
,
inputs
,
outputs
,
attrs
)
{}
:
ConditionalOp
(
type
,
inputs
,
outputs
,
attrs
)
{}
void
Run
(
const
framework
::
Scope
&
scope
,
const
platform
::
Place
&
dev_place
)
const
override
{
private:
void
RunImpl
(
const
framework
::
Scope
&
scope
,
const
platform
::
Place
&
dev_place
)
const
override
{
auto
xs
=
this
->
InputTensors
(
scope
);
auto
xs
=
this
->
InputTensors
(
scope
);
bool
need_run
;
bool
need_run
;
...
...
paddle/fluid/operators/create_reader_op.cc
浏览文件 @
5046869e
...
@@ -106,8 +106,10 @@ template <typename T>
...
@@ -106,8 +106,10 @@ template <typename T>
class
CreateRandomDataGeneratorOp
:
public
framework
::
OperatorBase
{
class
CreateRandomDataGeneratorOp
:
public
framework
::
OperatorBase
{
public:
public:
using
framework
::
OperatorBase
::
OperatorBase
;
using
framework
::
OperatorBase
::
OperatorBase
;
void
Run
(
const
framework
::
Scope
&
scope
,
const
platform
::
Place
&
dev_place
)
const
override
{
private:
void
RunImpl
(
const
framework
::
Scope
&
scope
,
const
platform
::
Place
&
dev_place
)
const
override
{
const
auto
&
shape_concat
=
Attr
<
std
::
vector
<
int
>>
(
"shape_concat"
);
const
auto
&
shape_concat
=
Attr
<
std
::
vector
<
int
>>
(
"shape_concat"
);
const
auto
&
ranks
=
Attr
<
std
::
vector
<
int
>>
(
"ranks"
);
const
auto
&
ranks
=
Attr
<
std
::
vector
<
int
>>
(
"ranks"
);
PADDLE_ENFORCE
(
!
shape_concat
.
empty
()
&&
!
ranks
.
empty
());
PADDLE_ENFORCE
(
!
shape_concat
.
empty
()
&&
!
ranks
.
empty
());
...
@@ -155,8 +157,10 @@ class CreateRandomDataGeneratorOpMaker
...
@@ -155,8 +157,10 @@ class CreateRandomDataGeneratorOpMaker
class
CreateShuffleReaderOp
:
public
framework
::
OperatorBase
{
class
CreateShuffleReaderOp
:
public
framework
::
OperatorBase
{
public:
public:
using
framework
::
OperatorBase
::
OperatorBase
;
using
framework
::
OperatorBase
::
OperatorBase
;
void
Run
(
const
framework
::
Scope
&
scope
,
const
platform
::
Place
&
dev_place
)
const
override
{
private:
void
RunImpl
(
const
framework
::
Scope
&
scope
,
const
platform
::
Place
&
dev_place
)
const
override
{
const
auto
&
underlying_reader
=
scope
.
FindVar
(
Input
(
"UnderlyingReader"
))
const
auto
&
underlying_reader
=
scope
.
FindVar
(
Input
(
"UnderlyingReader"
))
->
Get
<
framework
::
ReaderHolder
>
();
->
Get
<
framework
::
ReaderHolder
>
();
auto
*
out
=
scope
.
FindVar
(
Output
(
"Out"
))
auto
*
out
=
scope
.
FindVar
(
Output
(
"Out"
))
...
@@ -187,8 +191,10 @@ class CreateShuffleReaderOpMaker : public framework::OpProtoAndCheckerMaker {
...
@@ -187,8 +191,10 @@ class CreateShuffleReaderOpMaker : public framework::OpProtoAndCheckerMaker {
class
CreateBatchReaderOp
:
public
framework
::
OperatorBase
{
class
CreateBatchReaderOp
:
public
framework
::
OperatorBase
{
public:
public:
using
framework
::
OperatorBase
::
OperatorBase
;
using
framework
::
OperatorBase
::
OperatorBase
;
void
Run
(
const
framework
::
Scope
&
scope
,
const
platform
::
Place
&
dev_place
)
const
override
{
private:
void
RunImpl
(
const
framework
::
Scope
&
scope
,
const
platform
::
Place
&
dev_place
)
const
override
{
const
auto
&
underlying_reader
=
scope
.
FindVar
(
Input
(
"UnderlyingReader"
))
const
auto
&
underlying_reader
=
scope
.
FindVar
(
Input
(
"UnderlyingReader"
))
->
Get
<
framework
::
ReaderHolder
>
();
->
Get
<
framework
::
ReaderHolder
>
();
auto
*
out
=
scope
.
FindVar
(
Output
(
"Out"
))
auto
*
out
=
scope
.
FindVar
(
Output
(
"Out"
))
...
...
paddle/fluid/operators/feed_op.cc
浏览文件 @
5046869e
...
@@ -24,8 +24,10 @@ class FeedOp : public framework::OperatorBase {
...
@@ -24,8 +24,10 @@ class FeedOp : public framework::OperatorBase {
const
framework
::
VariableNameMap
&
outputs
,
const
framework
::
VariableNameMap
&
outputs
,
const
framework
::
AttributeMap
&
attrs
)
const
framework
::
AttributeMap
&
attrs
)
:
OperatorBase
(
type
,
inputs
,
outputs
,
attrs
)
{}
:
OperatorBase
(
type
,
inputs
,
outputs
,
attrs
)
{}
void
Run
(
const
framework
::
Scope
&
scope
,
const
platform
::
Place
&
place
)
const
override
{
private:
void
RunImpl
(
const
framework
::
Scope
&
scope
,
const
platform
::
Place
&
place
)
const
override
{
auto
feed_var_name
=
Input
(
"X"
);
auto
feed_var_name
=
Input
(
"X"
);
auto
*
feed_var
=
scope
.
FindVar
(
feed_var_name
);
auto
*
feed_var
=
scope
.
FindVar
(
feed_var_name
);
...
...
paddle/fluid/operators/fetch_op.cc
浏览文件 @
5046869e
...
@@ -26,8 +26,9 @@ class FetchOp : public framework::OperatorBase {
...
@@ -26,8 +26,9 @@ class FetchOp : public framework::OperatorBase {
const
framework
::
AttributeMap
&
attrs
)
const
framework
::
AttributeMap
&
attrs
)
:
OperatorBase
(
type
,
inputs
,
outputs
,
attrs
)
{}
:
OperatorBase
(
type
,
inputs
,
outputs
,
attrs
)
{}
void
Run
(
const
framework
::
Scope
&
scope
,
private:
const
platform
::
Place
&
place
)
const
override
{
void
RunImpl
(
const
framework
::
Scope
&
scope
,
const
platform
::
Place
&
place
)
const
override
{
auto
fetch_var_name
=
Input
(
"X"
);
auto
fetch_var_name
=
Input
(
"X"
);
auto
*
fetch_var
=
scope
.
FindVar
(
fetch_var_name
);
auto
*
fetch_var
=
scope
.
FindVar
(
fetch_var_name
);
PADDLE_ENFORCE
(
fetch_var
!=
nullptr
,
PADDLE_ENFORCE
(
fetch_var
!=
nullptr
,
...
...
paddle/fluid/operators/fill_constant_op.cc
浏览文件 @
5046869e
...
@@ -33,8 +33,10 @@ class FillConstantInferShape : public framework::InferShapeBase {
...
@@ -33,8 +33,10 @@ class FillConstantInferShape : public framework::InferShapeBase {
class
FillConstantOp
:
public
framework
::
OperatorBase
{
class
FillConstantOp
:
public
framework
::
OperatorBase
{
public:
public:
using
framework
::
OperatorBase
::
OperatorBase
;
using
framework
::
OperatorBase
::
OperatorBase
;
void
Run
(
const
framework
::
Scope
&
scope
,
const
platform
::
Place
&
dev_place
)
const
override
{
private:
void
RunImpl
(
const
framework
::
Scope
&
scope
,
const
platform
::
Place
&
dev_place
)
const
override
{
auto
data_type
=
auto
data_type
=
static_cast
<
framework
::
proto
::
DataType
>
(
Attr
<
int
>
(
"dtype"
));
static_cast
<
framework
::
proto
::
DataType
>
(
Attr
<
int
>
(
"dtype"
));
auto
value
=
Attr
<
float
>
(
"value"
);
auto
value
=
Attr
<
float
>
(
"value"
);
...
...
paddle/fluid/operators/fill_op.cc
浏览文件 @
5046869e
...
@@ -42,8 +42,10 @@ class FillOp : public framework::OperatorBase {
...
@@ -42,8 +42,10 @@ class FillOp : public framework::OperatorBase {
const
framework
::
VariableNameMap
&
outputs
,
const
framework
::
VariableNameMap
&
outputs
,
const
framework
::
AttributeMap
&
attrs
)
const
framework
::
AttributeMap
&
attrs
)
:
OperatorBase
(
type
,
inputs
,
outputs
,
attrs
)
{}
:
OperatorBase
(
type
,
inputs
,
outputs
,
attrs
)
{}
void
Run
(
const
framework
::
Scope
&
scope
,
const
platform
::
Place
&
place
)
const
override
{
private:
void
RunImpl
(
const
framework
::
Scope
&
scope
,
const
platform
::
Place
&
place
)
const
override
{
auto
&
out
=
auto
&
out
=
detail
::
Ref
(
detail
::
Ref
(
scope
.
FindVar
(
Output
(
"Out"
)),
detail
::
Ref
(
detail
::
Ref
(
scope
.
FindVar
(
Output
(
"Out"
)),
"Cannot find variable %s"
,
Output
(
"Out"
))
"Cannot find variable %s"
,
Output
(
"Out"
))
...
...
paddle/fluid/operators/get_places_op.cc
浏览文件 @
5046869e
...
@@ -37,8 +37,10 @@ class GetPlacesOp : public framework::OperatorBase {
...
@@ -37,8 +37,10 @@ class GetPlacesOp : public framework::OperatorBase {
const
framework
::
VariableNameMap
&
outputs
,
const
framework
::
VariableNameMap
&
outputs
,
const
framework
::
AttributeMap
&
attrs
)
const
framework
::
AttributeMap
&
attrs
)
:
OperatorBase
(
type
,
inputs
,
outputs
,
attrs
)
{}
:
OperatorBase
(
type
,
inputs
,
outputs
,
attrs
)
{}
void
Run
(
const
framework
::
Scope
&
scope
,
const
platform
::
Place
&
place
)
const
override
{
private:
void
RunImpl
(
const
framework
::
Scope
&
scope
,
const
platform
::
Place
&
place
)
const
override
{
bool
is_gpu
;
bool
is_gpu
;
if
(
Attr
<
std
::
string
>
(
"device_type"
)
==
"AUTO"
)
{
if
(
Attr
<
std
::
string
>
(
"device_type"
)
==
"AUTO"
)
{
is_gpu
=
platform
::
is_gpu_place
(
place
);
is_gpu
=
platform
::
is_gpu_place
(
place
);
...
...
paddle/fluid/operators/increment_op.cc
浏览文件 @
5046869e
...
@@ -51,8 +51,9 @@ class IncrementOp : public framework::OperatorBase {
...
@@ -51,8 +51,9 @@ class IncrementOp : public framework::OperatorBase {
const
framework
::
AttributeMap
&
attrs
)
const
framework
::
AttributeMap
&
attrs
)
:
OperatorBase
(
type
,
inputs
,
outputs
,
attrs
)
{}
:
OperatorBase
(
type
,
inputs
,
outputs
,
attrs
)
{}
void
Run
(
const
framework
::
Scope
&
scope
,
private:
const
platform
::
Place
&
place
)
const
override
{
void
RunImpl
(
const
framework
::
Scope
&
scope
,
const
platform
::
Place
&
place
)
const
override
{
auto
&
x
=
scope
.
FindVar
(
Input
(
"X"
))
->
Get
<
framework
::
LoDTensor
>
();
auto
&
x
=
scope
.
FindVar
(
Input
(
"X"
))
->
Get
<
framework
::
LoDTensor
>
();
auto
&
out
=
auto
&
out
=
*
scope
.
FindVar
(
Output
(
"Out"
))
->
GetMutable
<
framework
::
LoDTensor
>
();
*
scope
.
FindVar
(
Output
(
"Out"
))
->
GetMutable
<
framework
::
LoDTensor
>
();
...
...
paddle/fluid/operators/is_empty_op.cc
浏览文件 @
5046869e
...
@@ -28,8 +28,9 @@ class IsEmptyOp : public framework::OperatorBase {
...
@@ -28,8 +28,9 @@ class IsEmptyOp : public framework::OperatorBase {
const
framework
::
AttributeMap
&
attrs
)
const
framework
::
AttributeMap
&
attrs
)
:
OperatorBase
(
type
,
inputs
,
outputs
,
attrs
)
{}
:
OperatorBase
(
type
,
inputs
,
outputs
,
attrs
)
{}
void
Run
(
const
framework
::
Scope
&
scope
,
private:
const
platform
::
Place
&
place
)
const
override
{
void
RunImpl
(
const
framework
::
Scope
&
scope
,
const
platform
::
Place
&
place
)
const
override
{
// get input
// get input
auto
*
var
=
scope
.
FindVar
(
Input
(
kInput
));
auto
*
var
=
scope
.
FindVar
(
Input
(
kInput
));
PADDLE_ENFORCE_NOT_NULL
(
var
);
PADDLE_ENFORCE_NOT_NULL
(
var
);
...
...
paddle/fluid/operators/load_combine_op.cc
浏览文件 @
5046869e
...
@@ -26,8 +26,10 @@ class LoadCombineOp : public framework::OperatorBase {
...
@@ -26,8 +26,10 @@ class LoadCombineOp : public framework::OperatorBase {
const
framework
::
VariableNameMap
&
outputs
,
const
framework
::
VariableNameMap
&
outputs
,
const
framework
::
AttributeMap
&
attrs
)
const
framework
::
AttributeMap
&
attrs
)
:
OperatorBase
(
type
,
inputs
,
outputs
,
attrs
)
{}
:
OperatorBase
(
type
,
inputs
,
outputs
,
attrs
)
{}
void
Run
(
const
framework
::
Scope
&
scope
,
const
platform
::
Place
&
place
)
const
override
{
private:
void
RunImpl
(
const
framework
::
Scope
&
scope
,
const
platform
::
Place
&
place
)
const
override
{
auto
filename
=
Attr
<
std
::
string
>
(
"file_path"
);
auto
filename
=
Attr
<
std
::
string
>
(
"file_path"
);
std
::
ifstream
fin
(
filename
);
std
::
ifstream
fin
(
filename
);
...
...
paddle/fluid/operators/load_op.cc
浏览文件 @
5046869e
...
@@ -25,8 +25,10 @@ class LoadOp : public framework::OperatorBase {
...
@@ -25,8 +25,10 @@ class LoadOp : public framework::OperatorBase {
const
framework
::
VariableNameMap
&
outputs
,
const
framework
::
VariableNameMap
&
outputs
,
const
framework
::
AttributeMap
&
attrs
)
const
framework
::
AttributeMap
&
attrs
)
:
OperatorBase
(
type
,
inputs
,
outputs
,
attrs
)
{}
:
OperatorBase
(
type
,
inputs
,
outputs
,
attrs
)
{}
void
Run
(
const
framework
::
Scope
&
scope
,
const
platform
::
Place
&
place
)
const
override
{
private:
void
RunImpl
(
const
framework
::
Scope
&
scope
,
const
platform
::
Place
&
place
)
const
override
{
auto
filename
=
Attr
<
std
::
string
>
(
"file_path"
);
auto
filename
=
Attr
<
std
::
string
>
(
"file_path"
);
std
::
ifstream
fin
(
filename
);
std
::
ifstream
fin
(
filename
);
PADDLE_ENFORCE
(
static_cast
<
bool
>
(
fin
),
"Cannot open file %s for load op"
,
PADDLE_ENFORCE
(
static_cast
<
bool
>
(
fin
),
"Cannot open file %s for load op"
,
...
...
paddle/fluid/operators/lod_array_length_op.cc
浏览文件 @
5046869e
...
@@ -25,8 +25,10 @@ class LoDArrayLengthOp : public framework::OperatorBase {
...
@@ -25,8 +25,10 @@ class LoDArrayLengthOp : public framework::OperatorBase {
const
framework
::
VariableNameMap
&
outputs
,
const
framework
::
VariableNameMap
&
outputs
,
const
framework
::
AttributeMap
&
attrs
)
const
framework
::
AttributeMap
&
attrs
)
:
OperatorBase
(
type
,
inputs
,
outputs
,
attrs
)
{}
:
OperatorBase
(
type
,
inputs
,
outputs
,
attrs
)
{}
void
Run
(
const
framework
::
Scope
&
scope
,
const
platform
::
Place
&
place
)
const
override
{
private:
void
RunImpl
(
const
framework
::
Scope
&
scope
,
const
platform
::
Place
&
place
)
const
override
{
auto
&
x
=
scope
.
FindVar
(
Input
(
"X"
))
->
Get
<
framework
::
LoDTensorArray
>
();
auto
&
x
=
scope
.
FindVar
(
Input
(
"X"
))
->
Get
<
framework
::
LoDTensorArray
>
();
auto
&
out
=
auto
&
out
=
*
scope
.
FindVar
(
Output
(
"Out"
))
->
GetMutable
<
framework
::
LoDTensor
>
();
*
scope
.
FindVar
(
Output
(
"Out"
))
->
GetMutable
<
framework
::
LoDTensor
>
();
...
...
paddle/fluid/operators/lod_rank_table_op.cc
浏览文件 @
5046869e
...
@@ -23,8 +23,10 @@ class LoDRankTableOp : public framework::OperatorBase {
...
@@ -23,8 +23,10 @@ class LoDRankTableOp : public framework::OperatorBase {
const
framework
::
VariableNameMap
&
outputs
,
const
framework
::
VariableNameMap
&
outputs
,
const
framework
::
AttributeMap
&
attrs
)
const
framework
::
AttributeMap
&
attrs
)
:
OperatorBase
(
type
,
inputs
,
outputs
,
attrs
)
{}
:
OperatorBase
(
type
,
inputs
,
outputs
,
attrs
)
{}
void
Run
(
const
framework
::
Scope
&
scope
,
const
platform
::
Place
&
dev_place
)
const
override
{
private:
void
RunImpl
(
const
framework
::
Scope
&
scope
,
const
platform
::
Place
&
dev_place
)
const
override
{
auto
x
=
scope
.
FindVar
(
Input
(
"X"
))
->
Get
<
framework
::
LoDTensor
>
();
auto
x
=
scope
.
FindVar
(
Input
(
"X"
))
->
Get
<
framework
::
LoDTensor
>
();
auto
*
out
=
auto
*
out
=
scope
.
FindVar
(
Output
(
"Out"
))
->
GetMutable
<
framework
::
LoDRankTable
>
();
scope
.
FindVar
(
Output
(
"Out"
))
->
GetMutable
<
framework
::
LoDRankTable
>
();
...
...
paddle/fluid/operators/lod_tensor_to_array_op.cc
浏览文件 @
5046869e
...
@@ -32,8 +32,10 @@ class LoDTensorToArrayOp : public framework::OperatorBase {
...
@@ -32,8 +32,10 @@ class LoDTensorToArrayOp : public framework::OperatorBase {
const
framework
::
VariableNameMap
&
outputs
,
const
framework
::
VariableNameMap
&
outputs
,
const
framework
::
AttributeMap
&
attrs
)
const
framework
::
AttributeMap
&
attrs
)
:
OperatorBase
(
type
,
inputs
,
outputs
,
attrs
)
{}
:
OperatorBase
(
type
,
inputs
,
outputs
,
attrs
)
{}
void
Run
(
const
framework
::
Scope
&
scope
,
const
platform
::
Place
&
place
)
const
override
{
private:
void
RunImpl
(
const
framework
::
Scope
&
scope
,
const
platform
::
Place
&
place
)
const
override
{
auto
&
x
=
detail
::
Ref
(
scope
.
FindVar
(
Input
(
"X"
)),
"Cannot find input %s"
,
auto
&
x
=
detail
::
Ref
(
scope
.
FindVar
(
Input
(
"X"
)),
"Cannot find input %s"
,
Input
(
"X"
))
Input
(
"X"
))
.
Get
<
framework
::
LoDTensor
>
();
.
Get
<
framework
::
LoDTensor
>
();
...
...
paddle/fluid/operators/max_sequence_len_op.cc
浏览文件 @
5046869e
...
@@ -27,8 +27,9 @@ class MaxSeqenceLenOp : public framework::OperatorBase {
...
@@ -27,8 +27,9 @@ class MaxSeqenceLenOp : public framework::OperatorBase {
const
framework
::
AttributeMap
&
attrs
)
const
framework
::
AttributeMap
&
attrs
)
:
OperatorBase
(
type
,
inputs
,
outputs
,
attrs
)
{}
:
OperatorBase
(
type
,
inputs
,
outputs
,
attrs
)
{}
void
Run
(
const
framework
::
Scope
&
scope
,
private:
const
platform
::
Place
&
dev_place
)
const
override
{
void
RunImpl
(
const
framework
::
Scope
&
scope
,
const
platform
::
Place
&
dev_place
)
const
override
{
auto
&
rank_table
=
auto
&
rank_table
=
scope
.
FindVar
(
Input
(
"RankTable"
))
->
Get
<
framework
::
LoDRankTable
>
();
scope
.
FindVar
(
Input
(
"RankTable"
))
->
Get
<
framework
::
LoDRankTable
>
();
auto
*
out
=
auto
*
out
=
...
...
paddle/fluid/operators/merge_lod_tensor_op.cc
浏览文件 @
5046869e
...
@@ -27,8 +27,10 @@ class MergeLoDTensorOp : public framework::OperatorBase {
...
@@ -27,8 +27,10 @@ class MergeLoDTensorOp : public framework::OperatorBase {
const
framework
::
VariableNameMap
&
outputs
,
const
framework
::
VariableNameMap
&
outputs
,
const
framework
::
AttributeMap
&
attrs
)
const
framework
::
AttributeMap
&
attrs
)
:
OperatorBase
(
type
,
inputs
,
outputs
,
attrs
)
{}
:
OperatorBase
(
type
,
inputs
,
outputs
,
attrs
)
{}
void
Run
(
const
framework
::
Scope
&
scope
,
const
platform
::
Place
&
dev_place
)
const
override
{
private:
void
RunImpl
(
const
framework
::
Scope
&
scope
,
const
platform
::
Place
&
dev_place
)
const
override
{
// get device context from pool
// get device context from pool
platform
::
DeviceContextPool
&
pool
=
platform
::
DeviceContextPool
::
Instance
();
platform
::
DeviceContextPool
&
pool
=
platform
::
DeviceContextPool
::
Instance
();
auto
&
dev_ctx
=
*
pool
.
Get
(
dev_place
);
auto
&
dev_ctx
=
*
pool
.
Get
(
dev_place
);
...
...
paddle/fluid/operators/nccl_op.cc
浏览文件 @
5046869e
...
@@ -26,8 +26,9 @@ class NCCLInitOp : public framework::OperatorBase {
...
@@ -26,8 +26,9 @@ class NCCLInitOp : public framework::OperatorBase {
const
framework
::
AttributeMap
&
attrs
)
const
framework
::
AttributeMap
&
attrs
)
:
OperatorBase
(
type
,
inputs
,
outputs
,
attrs
)
{}
:
OperatorBase
(
type
,
inputs
,
outputs
,
attrs
)
{}
void
Run
(
const
framework
::
Scope
&
scope
,
private:
const
platform
::
Place
&
place
)
const
override
{
void
RunImpl
(
const
framework
::
Scope
&
scope
,
const
platform
::
Place
&
place
)
const
override
{
const
auto
&
name
=
Output
(
"Communicator"
);
const
auto
&
name
=
Output
(
"Communicator"
);
PADDLE_ENFORCE_NOT_NULL
(
scope
.
FindVar
(
name
),
PADDLE_ENFORCE_NOT_NULL
(
scope
.
FindVar
(
name
),
"Can not find variable '%s' in the scope."
,
name
);
"Can not find variable '%s' in the scope."
,
name
);
...
...
paddle/fluid/operators/net_op.h
浏览文件 @
5046869e
...
@@ -57,20 +57,6 @@ class NetOp : public framework::OperatorBase {
...
@@ -57,20 +57,6 @@ class NetOp : public framework::OperatorBase {
this
->
CompleteAddOp
();
this
->
CompleteAddOp
();
}
}
/**
* @brief Run the network.
*
* Run all the operators with the `scope`, if no scope is provided, default
* scope will be used instead. If no OpContext is provicded, default context
* will be used.
*/
void
Run
(
const
framework
::
Scope
&
scope
,
const
platform
::
Place
&
place
)
const
override
{
for
(
auto
&
op
:
ops_
)
{
op
->
Run
(
scope
,
place
);
}
}
bool
SupportGPU
()
const
override
{
bool
SupportGPU
()
const
override
{
for
(
auto
&
op
:
ops_
)
{
for
(
auto
&
op
:
ops_
)
{
if
(
!
op
->
SupportGPU
())
{
if
(
!
op
->
SupportGPU
())
{
...
@@ -117,6 +103,20 @@ class NetOp : public framework::OperatorBase {
...
@@ -117,6 +103,20 @@ class NetOp : public framework::OperatorBase {
std
::
vector
<
std
::
unique_ptr
<
framework
::
OperatorBase
>>
ops_
;
std
::
vector
<
std
::
unique_ptr
<
framework
::
OperatorBase
>>
ops_
;
private:
private:
/**
* @brief Run the network.
*
* Run all the operators with the `scope`, if no scope is provided, default
* scope will be used instead. If no OpContext is provicded, default context
* will be used.
*/
void
RunImpl
(
const
framework
::
Scope
&
scope
,
const
platform
::
Place
&
place
)
const
override
{
for
(
auto
&
op
:
ops_
)
{
op
->
Run
(
scope
,
place
);
}
}
bool
add_op_done_
{
false
};
bool
add_op_done_
{
false
};
std
::
set
<
std
::
string
>
intermediate_outputs_
;
std
::
set
<
std
::
string
>
intermediate_outputs_
;
...
...
paddle/fluid/operators/net_op_test.cc
浏览文件 @
5046869e
...
@@ -26,7 +26,10 @@ class TestOp : public framework::OperatorBase {
...
@@ -26,7 +26,10 @@ class TestOp : public framework::OperatorBase {
public:
public:
using
framework
::
OperatorBase
::
OperatorBase
;
using
framework
::
OperatorBase
::
OperatorBase
;
DEFINE_OP_CLONE_METHOD
(
TestOp
);
DEFINE_OP_CLONE_METHOD
(
TestOp
);
void
Run
(
const
Scope
&
scope
,
const
platform
::
Place
&
place
)
const
override
{
private:
void
RunImpl
(
const
Scope
&
scope
,
const
platform
::
Place
&
place
)
const
override
{
++
run_cnt
;
++
run_cnt
;
}
}
};
};
...
...
paddle/fluid/operators/parallel_do_op.cc
浏览文件 @
5046869e
...
@@ -118,8 +118,9 @@ class ParallelDoOp : public framework::OperatorBase {
...
@@ -118,8 +118,9 @@ class ParallelDoOp : public framework::OperatorBase {
const
framework
::
AttributeMap
&
attrs
)
const
framework
::
AttributeMap
&
attrs
)
:
framework
::
OperatorBase
(
type
,
inputs
,
outputs
,
attrs
)
{}
:
framework
::
OperatorBase
(
type
,
inputs
,
outputs
,
attrs
)
{}
void
Run
(
const
framework
::
Scope
&
scope
,
private:
const
platform
::
Place
&
place
)
const
override
{
void
RunImpl
(
const
framework
::
Scope
&
scope
,
const
platform
::
Place
&
place
)
const
override
{
// get device context from pool
// get device context from pool
platform
::
DeviceContextPool
&
pool
=
platform
::
DeviceContextPool
::
Instance
();
platform
::
DeviceContextPool
&
pool
=
platform
::
DeviceContextPool
::
Instance
();
auto
&
dev_ctx
=
*
pool
.
Get
(
place
);
auto
&
dev_ctx
=
*
pool
.
Get
(
place
);
...
@@ -207,8 +208,9 @@ class ParallelDoGradOp : public framework::OperatorBase {
...
@@ -207,8 +208,9 @@ class ParallelDoGradOp : public framework::OperatorBase {
const
framework
::
AttributeMap
&
attrs
)
const
framework
::
AttributeMap
&
attrs
)
:
framework
::
OperatorBase
(
type
,
inputs
,
outputs
,
attrs
)
{}
:
framework
::
OperatorBase
(
type
,
inputs
,
outputs
,
attrs
)
{}
void
Run
(
const
framework
::
Scope
&
scope
,
private:
const
platform
::
Place
&
place
)
const
override
{
void
RunImpl
(
const
framework
::
Scope
&
scope
,
const
platform
::
Place
&
place
)
const
override
{
auto
*
block
=
Attr
<
framework
::
BlockDesc
*>
(
kParallelBlock
);
auto
*
block
=
Attr
<
framework
::
BlockDesc
*>
(
kParallelBlock
);
auto
*
program
=
block
->
Program
();
auto
*
program
=
block
->
Program
();
...
...
paddle/fluid/operators/print_op.cc
浏览文件 @
5046869e
...
@@ -130,8 +130,9 @@ class TensorPrintOp : public framework::OperatorBase {
...
@@ -130,8 +130,9 @@ class TensorPrintOp : public framework::OperatorBase {
PADDLE_THROW
(
"Not implemented."
);
PADDLE_THROW
(
"Not implemented."
);
}
}
void
Run
(
const
framework
::
Scope
&
scope
,
private:
const
platform
::
Place
&
place
)
const
override
{
void
RunImpl
(
const
framework
::
Scope
&
scope
,
const
platform
::
Place
&
place
)
const
override
{
const
framework
::
Variable
*
in_var_ptr
=
nullptr
;
const
framework
::
Variable
*
in_var_ptr
=
nullptr
;
std
::
string
phase
=
kForward
;
std
::
string
phase
=
kForward
;
std
::
string
printed_var_name
=
""
;
std
::
string
printed_var_name
=
""
;
...
...
paddle/fluid/operators/read_op.cc
浏览文件 @
5046869e
...
@@ -54,8 +54,10 @@ class ReadInferVarType : public framework::VarTypeInference {
...
@@ -54,8 +54,10 @@ class ReadInferVarType : public framework::VarTypeInference {
class
ReadOp
:
public
framework
::
OperatorBase
{
class
ReadOp
:
public
framework
::
OperatorBase
{
public:
public:
using
framework
::
OperatorBase
::
OperatorBase
;
using
framework
::
OperatorBase
::
OperatorBase
;
void
Run
(
const
framework
::
Scope
&
scope
,
const
platform
::
Place
&
dev_place
)
const
override
{
private:
void
RunImpl
(
const
framework
::
Scope
&
scope
,
const
platform
::
Place
&
dev_place
)
const
override
{
framework
::
ReaderHolder
*
reader
=
framework
::
ReaderHolder
*
reader
=
scope
.
FindVar
(
Input
(
"Reader"
))
->
GetMutable
<
framework
::
ReaderHolder
>
();
scope
.
FindVar
(
Input
(
"Reader"
))
->
GetMutable
<
framework
::
ReaderHolder
>
();
if
(
!
reader
->
HasNext
())
{
if
(
!
reader
->
HasNext
())
{
...
...
paddle/fluid/operators/recurrent_op.cc
浏览文件 @
5046869e
...
@@ -226,8 +226,9 @@ class RecurrentOp : public RecurrentBase {
...
@@ -226,8 +226,9 @@ class RecurrentOp : public RecurrentBase {
const
framework
::
AttributeMap
&
attrs
)
const
framework
::
AttributeMap
&
attrs
)
:
RecurrentBase
(
type
,
inputs
,
outputs
,
attrs
)
{}
:
RecurrentBase
(
type
,
inputs
,
outputs
,
attrs
)
{}
void
Run
(
const
framework
::
Scope
&
scope
,
private:
const
platform
::
Place
&
place
)
const
override
{
void
RunImpl
(
const
framework
::
Scope
&
scope
,
const
platform
::
Place
&
place
)
const
override
{
auto
seq_len
=
static_cast
<
size_t
>
(
this
->
GetSequenceLength
(
scope
));
auto
seq_len
=
static_cast
<
size_t
>
(
this
->
GetSequenceLength
(
scope
));
VLOG
(
3
)
<<
"Static RNN input sequence length = "
<<
seq_len
;
VLOG
(
3
)
<<
"Static RNN input sequence length = "
<<
seq_len
;
StepScopes
scopes
=
CreateStepScopes
(
scope
,
seq_len
);
StepScopes
scopes
=
CreateStepScopes
(
scope
,
seq_len
);
...
@@ -315,8 +316,9 @@ class RecurrentGradOp : public RecurrentBase {
...
@@ -315,8 +316,9 @@ class RecurrentGradOp : public RecurrentBase {
const
framework
::
AttributeMap
&
attrs
)
const
framework
::
AttributeMap
&
attrs
)
:
RecurrentBase
(
type
,
inputs
,
outputs
,
attrs
)
{}
:
RecurrentBase
(
type
,
inputs
,
outputs
,
attrs
)
{}
void
Run
(
const
framework
::
Scope
&
scope
,
private:
const
platform
::
Place
&
place
)
const
override
{
void
RunImpl
(
const
framework
::
Scope
&
scope
,
const
platform
::
Place
&
place
)
const
override
{
auto
seq_len
=
static_cast
<
size_t
>
(
GetSequenceLength
(
scope
));
auto
seq_len
=
static_cast
<
size_t
>
(
GetSequenceLength
(
scope
));
StepScopes
scopes
=
CreateStepScopes
(
scope
,
seq_len
);
StepScopes
scopes
=
CreateStepScopes
(
scope
,
seq_len
);
auto
reverse
=
Attr
<
bool
>
(
kReverse
);
auto
reverse
=
Attr
<
bool
>
(
kReverse
);
...
...
paddle/fluid/operators/reorder_lod_tensor_by_rank_op.cc
浏览文件 @
5046869e
...
@@ -75,8 +75,10 @@ class ReorderLoDTensorByRankTableBase : public framework::OperatorBase {
...
@@ -75,8 +75,10 @@ class ReorderLoDTensorByRankTableBase : public framework::OperatorBase {
const
framework
::
VariableNameMap
&
outputs
,
const
framework
::
VariableNameMap
&
outputs
,
const
framework
::
AttributeMap
&
attrs
)
const
framework
::
AttributeMap
&
attrs
)
:
OperatorBase
(
type
,
inputs
,
outputs
,
attrs
)
{}
:
OperatorBase
(
type
,
inputs
,
outputs
,
attrs
)
{}
void
Run
(
const
framework
::
Scope
&
scope
,
const
platform
::
Place
&
place
)
const
override
{
private:
void
RunImpl
(
const
framework
::
Scope
&
scope
,
const
platform
::
Place
&
place
)
const
override
{
auto
&
x
=
auto
&
x
=
detail
::
Ref
(
scope
.
FindVar
(
Input
(
"X"
)),
detail
::
Ref
(
scope
.
FindVar
(
Input
(
"X"
)),
"Cannot find input lod tensor variable %s"
,
Input
(
"X"
))
"Cannot find input lod tensor variable %s"
,
Input
(
"X"
))
...
...
paddle/fluid/operators/rnn_memory_helper_op.cc
浏览文件 @
5046869e
...
@@ -24,8 +24,10 @@ class RNNMemoryHelperOp : public framework::OperatorBase {
...
@@ -24,8 +24,10 @@ class RNNMemoryHelperOp : public framework::OperatorBase {
const
framework
::
VariableNameMap
&
outputs
,
const
framework
::
VariableNameMap
&
outputs
,
const
framework
::
AttributeMap
&
attrs
)
const
framework
::
AttributeMap
&
attrs
)
:
OperatorBase
(
type
,
inputs
,
outputs
,
attrs
)
{}
:
OperatorBase
(
type
,
inputs
,
outputs
,
attrs
)
{}
void
Run
(
const
framework
::
Scope
&
scope
,
const
platform
::
Place
&
dev_place
)
const
override
{
private:
void
RunImpl
(
const
framework
::
Scope
&
scope
,
const
platform
::
Place
&
dev_place
)
const
override
{
auto
mem_var_name
=
Input
(
"X"
);
auto
mem_var_name
=
Input
(
"X"
);
auto
*
mem_var
=
scope
.
FindVar
(
mem_var_name
);
auto
*
mem_var
=
scope
.
FindVar
(
mem_var_name
);
PADDLE_ENFORCE
(
mem_var
!=
nullptr
,
PADDLE_ENFORCE
(
mem_var
!=
nullptr
,
...
@@ -76,8 +78,10 @@ class RNNMemoryHelperGradOp : public framework::OperatorBase {
...
@@ -76,8 +78,10 @@ class RNNMemoryHelperGradOp : public framework::OperatorBase {
const
framework
::
VariableNameMap
&
outputs
,
const
framework
::
VariableNameMap
&
outputs
,
const
framework
::
AttributeMap
&
attrs
)
const
framework
::
AttributeMap
&
attrs
)
:
OperatorBase
(
type
,
inputs
,
outputs
,
attrs
)
{}
:
OperatorBase
(
type
,
inputs
,
outputs
,
attrs
)
{}
void
Run
(
const
framework
::
Scope
&
scope
,
const
platform
::
Place
&
dev_place
)
const
override
{
private:
void
RunImpl
(
const
framework
::
Scope
&
scope
,
const
platform
::
Place
&
dev_place
)
const
override
{
auto
out_grad_var_name
=
Input
(
framework
::
GradVarName
(
"Out"
));
auto
out_grad_var_name
=
Input
(
framework
::
GradVarName
(
"Out"
));
auto
*
out_grad_var
=
scope
.
FindVar
(
out_grad_var_name
);
auto
*
out_grad_var
=
scope
.
FindVar
(
out_grad_var_name
);
...
...
paddle/fluid/operators/save_combine_op.cc
浏览文件 @
5046869e
...
@@ -63,8 +63,10 @@ class SaveCombineOp : public framework::OperatorBase {
...
@@ -63,8 +63,10 @@ class SaveCombineOp : public framework::OperatorBase {
const
framework
::
VariableNameMap
&
outputs
,
const
framework
::
VariableNameMap
&
outputs
,
const
framework
::
AttributeMap
&
attrs
)
const
framework
::
AttributeMap
&
attrs
)
:
OperatorBase
(
type
,
inputs
,
outputs
,
attrs
)
{}
:
OperatorBase
(
type
,
inputs
,
outputs
,
attrs
)
{}
void
Run
(
const
framework
::
Scope
&
scope
,
const
platform
::
Place
&
place
)
const
override
{
private:
void
RunImpl
(
const
framework
::
Scope
&
scope
,
const
platform
::
Place
&
place
)
const
override
{
auto
filename
=
Attr
<
std
::
string
>
(
"file_path"
);
auto
filename
=
Attr
<
std
::
string
>
(
"file_path"
);
auto
overwrite
=
Attr
<
bool
>
(
"overwrite"
);
auto
overwrite
=
Attr
<
bool
>
(
"overwrite"
);
...
...
paddle/fluid/operators/save_op.cc
浏览文件 @
5046869e
...
@@ -62,8 +62,10 @@ class SaveOp : public framework::OperatorBase {
...
@@ -62,8 +62,10 @@ class SaveOp : public framework::OperatorBase {
const
framework
::
VariableNameMap
&
outputs
,
const
framework
::
VariableNameMap
&
outputs
,
const
framework
::
AttributeMap
&
attrs
)
const
framework
::
AttributeMap
&
attrs
)
:
OperatorBase
(
type
,
inputs
,
outputs
,
attrs
)
{}
:
OperatorBase
(
type
,
inputs
,
outputs
,
attrs
)
{}
void
Run
(
const
framework
::
Scope
&
scope
,
const
platform
::
Place
&
place
)
const
override
{
private:
void
RunImpl
(
const
framework
::
Scope
&
scope
,
const
platform
::
Place
&
place
)
const
override
{
auto
filename
=
Attr
<
std
::
string
>
(
"file_path"
);
auto
filename
=
Attr
<
std
::
string
>
(
"file_path"
);
auto
overwrite
=
Attr
<
bool
>
(
"overwrite"
);
auto
overwrite
=
Attr
<
bool
>
(
"overwrite"
);
...
...
paddle/fluid/operators/shrink_rnn_memory_op.cc
浏览文件 @
5046869e
...
@@ -27,8 +27,9 @@ class ShrinkRNNMemoryOp : public ArrayOp {
...
@@ -27,8 +27,9 @@ class ShrinkRNNMemoryOp : public ArrayOp {
const
framework
::
AttributeMap
&
attrs
)
const
framework
::
AttributeMap
&
attrs
)
:
ArrayOp
(
type
,
inputs
,
outputs
,
attrs
)
{}
:
ArrayOp
(
type
,
inputs
,
outputs
,
attrs
)
{}
void
Run
(
const
framework
::
Scope
&
scope
,
private:
const
platform
::
Place
&
place
)
const
override
{
void
RunImpl
(
const
framework
::
Scope
&
scope
,
const
platform
::
Place
&
place
)
const
override
{
auto
*
x_var
=
scope
.
FindVar
(
Input
(
"X"
));
auto
*
x_var
=
scope
.
FindVar
(
Input
(
"X"
));
PADDLE_ENFORCE
(
x_var
!=
nullptr
,
"Input X must be set"
);
PADDLE_ENFORCE
(
x_var
!=
nullptr
,
"Input X must be set"
);
auto
&
x_tensor
=
x_var
->
Get
<
framework
::
LoDTensor
>
();
auto
&
x_tensor
=
x_var
->
Get
<
framework
::
LoDTensor
>
();
...
@@ -108,8 +109,9 @@ class ShrinkRNNMemoryGradOp : public ArrayOp {
...
@@ -108,8 +109,9 @@ class ShrinkRNNMemoryGradOp : public ArrayOp {
const
framework
::
AttributeMap
&
attrs
)
const
framework
::
AttributeMap
&
attrs
)
:
ArrayOp
(
type
,
inputs
,
outputs
,
attrs
)
{}
:
ArrayOp
(
type
,
inputs
,
outputs
,
attrs
)
{}
void
Run
(
const
framework
::
Scope
&
scope
,
private:
const
platform
::
Place
&
place
)
const
override
{
void
RunImpl
(
const
framework
::
Scope
&
scope
,
const
platform
::
Place
&
place
)
const
override
{
auto
*
dout_var
=
scope
.
FindVar
(
Input
(
framework
::
GradVarName
(
"Out"
)));
auto
*
dout_var
=
scope
.
FindVar
(
Input
(
framework
::
GradVarName
(
"Out"
)));
auto
*
dx_var
=
scope
.
FindVar
(
Output
(
framework
::
GradVarName
(
"X"
)));
auto
*
dx_var
=
scope
.
FindVar
(
Output
(
framework
::
GradVarName
(
"X"
)));
PADDLE_ENFORCE
(
dx_var
!=
nullptr
,
"Input Gradient should not be nullptr"
);
PADDLE_ENFORCE
(
dx_var
!=
nullptr
,
"Input Gradient should not be nullptr"
);
...
...
paddle/fluid/operators/split_lod_tensor_op.cc
浏览文件 @
5046869e
...
@@ -33,8 +33,10 @@ class SplitLoDTensorOp : public framework::OperatorBase {
...
@@ -33,8 +33,10 @@ class SplitLoDTensorOp : public framework::OperatorBase {
const
framework
::
VariableNameMap
&
outputs
,
const
framework
::
VariableNameMap
&
outputs
,
const
framework
::
AttributeMap
&
attrs
)
const
framework
::
AttributeMap
&
attrs
)
:
OperatorBase
(
type
,
inputs
,
outputs
,
attrs
)
{}
:
OperatorBase
(
type
,
inputs
,
outputs
,
attrs
)
{}
void
Run
(
const
framework
::
Scope
&
scope
,
const
platform
::
Place
&
dev_place
)
const
override
{
private:
void
RunImpl
(
const
framework
::
Scope
&
scope
,
const
platform
::
Place
&
dev_place
)
const
override
{
auto
&
x
=
scope
.
FindVar
(
Input
(
"X"
))
->
Get
<
framework
::
LoDTensor
>
();
auto
&
x
=
scope
.
FindVar
(
Input
(
"X"
))
->
Get
<
framework
::
LoDTensor
>
();
auto
&
mask
=
scope
.
FindVar
(
Input
(
"Mask"
))
->
Get
<
framework
::
LoDTensor
>
();
auto
&
mask
=
scope
.
FindVar
(
Input
(
"Mask"
))
->
Get
<
framework
::
LoDTensor
>
();
auto
*
out_true
=
auto
*
out_true
=
...
...
paddle/fluid/operators/tensor_array_read_write_op.cc
浏览文件 @
5046869e
...
@@ -24,8 +24,9 @@ class WriteToArrayOp : public ArrayOp {
...
@@ -24,8 +24,9 @@ class WriteToArrayOp : public ArrayOp {
const
framework
::
AttributeMap
&
attrs
)
const
framework
::
AttributeMap
&
attrs
)
:
ArrayOp
(
type
,
inputs
,
outputs
,
attrs
)
{}
:
ArrayOp
(
type
,
inputs
,
outputs
,
attrs
)
{}
void
Run
(
const
framework
::
Scope
&
scope
,
private:
const
platform
::
Place
&
place
)
const
override
{
void
RunImpl
(
const
framework
::
Scope
&
scope
,
const
platform
::
Place
&
place
)
const
override
{
auto
*
x
=
scope
.
FindVar
(
Input
(
"X"
));
auto
*
x
=
scope
.
FindVar
(
Input
(
"X"
));
if
(
x
==
nullptr
)
return
;
if
(
x
==
nullptr
)
return
;
auto
&
x_tensor
=
x
->
Get
<
framework
::
LoDTensor
>
();
auto
&
x_tensor
=
x
->
Get
<
framework
::
LoDTensor
>
();
...
@@ -122,8 +123,10 @@ class ReadFromArrayOp : public ArrayOp {
...
@@ -122,8 +123,10 @@ class ReadFromArrayOp : public ArrayOp {
const
framework
::
VariableNameMap
&
outputs
,
const
framework
::
VariableNameMap
&
outputs
,
const
framework
::
AttributeMap
&
attrs
)
const
framework
::
AttributeMap
&
attrs
)
:
ArrayOp
(
type
,
inputs
,
outputs
,
attrs
)
{}
:
ArrayOp
(
type
,
inputs
,
outputs
,
attrs
)
{}
void
Run
(
const
framework
::
Scope
&
scope
,
const
platform
::
Place
&
place
)
const
override
{
private:
void
RunImpl
(
const
framework
::
Scope
&
scope
,
const
platform
::
Place
&
place
)
const
override
{
auto
*
x
=
scope
.
FindVar
(
Input
(
"X"
));
auto
*
x
=
scope
.
FindVar
(
Input
(
"X"
));
PADDLE_ENFORCE
(
x
!=
nullptr
,
"X must be set"
);
PADDLE_ENFORCE
(
x
!=
nullptr
,
"X must be set"
);
auto
&
x_array
=
x
->
Get
<
framework
::
LoDTensorArray
>
();
auto
&
x_array
=
x
->
Get
<
framework
::
LoDTensorArray
>
();
...
...
paddle/fluid/operators/while_op.cc
浏览文件 @
5046869e
...
@@ -39,8 +39,9 @@ class WhileOp : public framework::OperatorBase {
...
@@ -39,8 +39,9 @@ class WhileOp : public framework::OperatorBase {
const
framework
::
AttributeMap
&
attrs
)
const
framework
::
AttributeMap
&
attrs
)
:
framework
::
OperatorBase
(
type
,
inputs
,
outputs
,
attrs
)
{}
:
framework
::
OperatorBase
(
type
,
inputs
,
outputs
,
attrs
)
{}
void
Run
(
const
framework
::
Scope
&
scope
,
private:
const
platform
::
Place
&
dev_place
)
const
override
{
void
RunImpl
(
const
framework
::
Scope
&
scope
,
const
platform
::
Place
&
dev_place
)
const
override
{
PADDLE_ENFORCE_NOT_NULL
(
scope
.
FindVar
(
Input
(
kCondition
)));
PADDLE_ENFORCE_NOT_NULL
(
scope
.
FindVar
(
Input
(
kCondition
)));
auto
&
cond
=
scope
.
FindVar
(
Input
(
kCondition
))
->
Get
<
LoDTensor
>
();
auto
&
cond
=
scope
.
FindVar
(
Input
(
kCondition
))
->
Get
<
LoDTensor
>
();
PADDLE_ENFORCE_EQ
(
cond
.
dims
(),
paddle
::
framework
::
make_ddim
({
1
}));
PADDLE_ENFORCE_EQ
(
cond
.
dims
(),
paddle
::
framework
::
make_ddim
({
1
}));
...
@@ -99,8 +100,9 @@ class WhileGradOp : public framework::OperatorBase {
...
@@ -99,8 +100,9 @@ class WhileGradOp : public framework::OperatorBase {
const
framework
::
AttributeMap
&
attrs
)
const
framework
::
AttributeMap
&
attrs
)
:
framework
::
OperatorBase
(
type
,
inputs
,
outputs
,
attrs
)
{}
:
framework
::
OperatorBase
(
type
,
inputs
,
outputs
,
attrs
)
{}
void
Run
(
const
framework
::
Scope
&
scope
,
private:
const
platform
::
Place
&
dev_place
)
const
override
{
void
RunImpl
(
const
framework
::
Scope
&
scope
,
const
platform
::
Place
&
dev_place
)
const
override
{
// get device context from pool
// get device context from pool
platform
::
DeviceContextPool
&
pool
=
platform
::
DeviceContextPool
::
Instance
();
platform
::
DeviceContextPool
&
pool
=
platform
::
DeviceContextPool
::
Instance
();
auto
&
dev_ctx
=
*
pool
.
Get
(
dev_place
);
auto
&
dev_ctx
=
*
pool
.
Get
(
dev_place
);
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录