Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PaddleDetection
提交
98c94373
P
PaddleDetection
项目概览
PaddlePaddle
/
PaddleDetection
大约 1 年 前同步成功
通知
695
Star
11112
Fork
2696
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
184
列表
看板
标记
里程碑
合并请求
40
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleDetection
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
184
Issue
184
列表
看板
标记
里程碑
合并请求
40
合并请求
40
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
98c94373
编写于
2月 09, 2018
作者:
Y
Yang Yang
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
use op run as wrapper of run_impl; make run_impl as private virtual function
上级
f605d00f
变更
41
隐藏空白更改
内联
并排
Showing
41 changed file
with
214 addition
and
114 deletion
+214
-114
paddle/framework/op_registry_test.cc
paddle/framework/op_registry_test.cc
+8
-2
paddle/framework/operator.cc
paddle/framework/operator.cc
+14
-2
paddle/framework/operator.h
paddle/framework/operator.h
+10
-5
paddle/framework/operator_test.cc
paddle/framework/operator_test.cc
+8
-3
paddle/operators/array_to_lod_tensor_op.cc
paddle/operators/array_to_lod_tensor_op.cc
+4
-2
paddle/operators/assign_op.cc
paddle/operators/assign_op.cc
+4
-2
paddle/operators/beam_search_decode_op.cc
paddle/operators/beam_search_decode_op.cc
+4
-2
paddle/operators/beam_search_op.h
paddle/operators/beam_search_op.h
+3
-2
paddle/operators/cond_op.cc
paddle/operators/cond_op.cc
+1
-1
paddle/operators/cond_op.h
paddle/operators/cond_op.h
+3
-2
paddle/operators/conditional_block_op.cc
paddle/operators/conditional_block_op.cc
+8
-4
paddle/operators/create_reader_op.cc
paddle/operators/create_reader_op.cc
+12
-6
paddle/operators/feed_op.cc
paddle/operators/feed_op.cc
+4
-2
paddle/operators/fetch_op.cc
paddle/operators/fetch_op.cc
+3
-2
paddle/operators/fill_constant_op.cc
paddle/operators/fill_constant_op.cc
+4
-2
paddle/operators/fill_op.cc
paddle/operators/fill_op.cc
+4
-2
paddle/operators/get_places_op.cc
paddle/operators/get_places_op.cc
+4
-2
paddle/operators/increment_op.cc
paddle/operators/increment_op.cc
+3
-2
paddle/operators/is_empty_op.cc
paddle/operators/is_empty_op.cc
+3
-2
paddle/operators/load_combine_op.cc
paddle/operators/load_combine_op.cc
+4
-2
paddle/operators/load_op.cc
paddle/operators/load_op.cc
+4
-2
paddle/operators/lod_array_length_op.cc
paddle/operators/lod_array_length_op.cc
+4
-2
paddle/operators/lod_rank_table_op.cc
paddle/operators/lod_rank_table_op.cc
+4
-2
paddle/operators/lod_tensor_to_array_op.cc
paddle/operators/lod_tensor_to_array_op.cc
+4
-2
paddle/operators/max_sequence_len_op.cc
paddle/operators/max_sequence_len_op.cc
+3
-2
paddle/operators/merge_lod_tensor_op.cc
paddle/operators/merge_lod_tensor_op.cc
+4
-2
paddle/operators/nccl_op.cc
paddle/operators/nccl_op.cc
+3
-2
paddle/operators/net_op.h
paddle/operators/net_op.h
+14
-14
paddle/operators/net_op_test.cc
paddle/operators/net_op_test.cc
+4
-1
paddle/operators/parallel_do_op.cc
paddle/operators/parallel_do_op.cc
+6
-4
paddle/operators/print_op.cc
paddle/operators/print_op.cc
+3
-2
paddle/operators/read_op.cc
paddle/operators/read_op.cc
+4
-2
paddle/operators/recurrent_op.cc
paddle/operators/recurrent_op.cc
+6
-4
paddle/operators/reorder_lod_tensor_by_rank_op.cc
paddle/operators/reorder_lod_tensor_by_rank_op.cc
+4
-2
paddle/operators/rnn_memory_helper_op.cc
paddle/operators/rnn_memory_helper_op.cc
+8
-4
paddle/operators/save_combine_op.cc
paddle/operators/save_combine_op.cc
+4
-2
paddle/operators/save_op.cc
paddle/operators/save_op.cc
+4
-2
paddle/operators/shrink_rnn_memory_op.cc
paddle/operators/shrink_rnn_memory_op.cc
+6
-4
paddle/operators/split_lod_tensor_op.cc
paddle/operators/split_lod_tensor_op.cc
+4
-2
paddle/operators/tensor_array_read_write_op.cc
paddle/operators/tensor_array_read_write_op.cc
+7
-4
paddle/operators/while_op.cc
paddle/operators/while_op.cc
+6
-4
未找到文件。
paddle/framework/op_registry_test.cc
浏览文件 @
98c94373
...
@@ -25,7 +25,10 @@ namespace framework {
...
@@ -25,7 +25,10 @@ namespace framework {
class
CosineOp
:
public
OperatorBase
{
class
CosineOp
:
public
OperatorBase
{
public:
public:
using
OperatorBase
::
OperatorBase
;
using
OperatorBase
::
OperatorBase
;
void
Run
(
const
Scope
&
scope
,
const
platform
::
Place
&
place
)
const
override
{}
private:
void
RunImpl
(
const
Scope
&
scope
,
const
platform
::
Place
&
place
)
const
override
{}
};
};
class
CosineOpProtoAndCheckerMaker
:
public
OpProtoAndCheckerMaker
{
class
CosineOpProtoAndCheckerMaker
:
public
OpProtoAndCheckerMaker
{
...
@@ -44,7 +47,10 @@ class CosineOpProtoAndCheckerMaker : public OpProtoAndCheckerMaker {
...
@@ -44,7 +47,10 @@ class CosineOpProtoAndCheckerMaker : public OpProtoAndCheckerMaker {
class
MyTestOp
:
public
OperatorBase
{
class
MyTestOp
:
public
OperatorBase
{
public:
public:
using
OperatorBase
::
OperatorBase
;
using
OperatorBase
::
OperatorBase
;
void
Run
(
const
Scope
&
scope
,
const
platform
::
Place
&
place
)
const
override
{}
private:
void
RunImpl
(
const
Scope
&
scope
,
const
platform
::
Place
&
place
)
const
override
{}
};
};
class
MyTestOpProtoAndCheckerMaker
:
public
OpProtoAndCheckerMaker
{
class
MyTestOpProtoAndCheckerMaker
:
public
OpProtoAndCheckerMaker
{
...
...
paddle/framework/operator.cc
浏览文件 @
98c94373
...
@@ -64,6 +64,18 @@ static LoD GetLoD(const Scope& scope, const std::string& name) {
...
@@ -64,6 +64,18 @@ static LoD GetLoD(const Scope& scope, const std::string& name) {
}
}
}
}
void
OperatorBase
::
Run
(
const
Scope
&
scope
,
const
platform
::
Place
&
place
)
{
if
(
platform
::
is_gpu_place
(
place
))
{
#ifndef PADDLE_WITH_CUDA
PADDLE_THROW
(
"Cannot run operator on place %s"
,
place
);
#else
auto
dev_id
=
boost
::
get
<
platform
::
CUDAPlace
>
(
place
).
device
;
platform
::
SetDeviceId
(
dev_id
);
#endif
}
RunImpl
(
scope
,
place
);
}
std
::
string
OperatorBase
::
Input
(
const
std
::
string
&
name
)
const
{
std
::
string
OperatorBase
::
Input
(
const
std
::
string
&
name
)
const
{
auto
&
ins
=
Inputs
(
name
);
auto
&
ins
=
Inputs
(
name
);
PADDLE_ENFORCE_LE
(
ins
.
size
(),
1UL
,
PADDLE_ENFORCE_LE
(
ins
.
size
(),
1UL
,
...
@@ -475,8 +487,8 @@ class RuntimeInferShapeContext : public InferShapeContext {
...
@@ -475,8 +487,8 @@ class RuntimeInferShapeContext : public InferShapeContext {
const
Scope
&
scope_
;
const
Scope
&
scope_
;
};
};
void
OperatorWithKernel
::
Run
(
const
Scope
&
scope
,
void
OperatorWithKernel
::
Run
Impl
(
const
Scope
&
scope
,
const
platform
::
Place
&
place
)
const
{
const
platform
::
Place
&
place
)
const
{
RuntimeInferShapeContext
infer_shape_ctx
(
*
this
,
scope
);
RuntimeInferShapeContext
infer_shape_ctx
(
*
this
,
scope
);
this
->
InferShape
(
&
infer_shape_ctx
);
this
->
InferShape
(
&
infer_shape_ctx
);
platform
::
DeviceContextPool
&
pool
=
platform
::
DeviceContextPool
::
Instance
();
platform
::
DeviceContextPool
&
pool
=
platform
::
DeviceContextPool
::
Instance
();
...
...
paddle/framework/operator.h
浏览文件 @
98c94373
...
@@ -89,8 +89,9 @@ class OperatorBase {
...
@@ -89,8 +89,9 @@ class OperatorBase {
std
::
string
DebugString
()
const
{
return
DebugStringEx
(
nullptr
);
}
std
::
string
DebugString
()
const
{
return
DebugStringEx
(
nullptr
);
}
/// Net will call this function to Run an op.
/// Net will call this interface function to Run an op.
virtual
void
Run
(
const
Scope
&
scope
,
const
platform
::
Place
&
place
)
const
=
0
;
// The implementation should be written at RunImpl
void
Run
(
const
Scope
&
scope
,
const
platform
::
Place
&
place
);
// FIXME(typhoonzero): this is only used for recv_op to stop event_loop.
// FIXME(typhoonzero): this is only used for recv_op to stop event_loop.
virtual
void
Stop
()
{}
virtual
void
Stop
()
{}
...
@@ -144,6 +145,8 @@ class OperatorBase {
...
@@ -144,6 +145,8 @@ class OperatorBase {
private:
private:
void
GenerateTemporaryNames
();
void
GenerateTemporaryNames
();
void
CheckAllInputOutputSet
()
const
;
void
CheckAllInputOutputSet
()
const
;
virtual
void
RunImpl
(
const
Scope
&
scope
,
const
platform
::
Place
&
place
)
const
=
0
;
};
};
// Macro for define a clone method.
// Macro for define a clone method.
...
@@ -168,10 +171,13 @@ class OperatorBase {
...
@@ -168,10 +171,13 @@ class OperatorBase {
class
NOP
:
public
OperatorBase
{
class
NOP
:
public
OperatorBase
{
public:
public:
using
OperatorBase
::
OperatorBase
;
using
OperatorBase
::
OperatorBase
;
void
Run
(
const
Scope
&
scope
,
const
platform
::
Place
&
place
)
const
override
{}
std
::
unique_ptr
<
OperatorBase
>
Clone
()
const
override
{
std
::
unique_ptr
<
OperatorBase
>
Clone
()
const
override
{
return
std
::
unique_ptr
<
OperatorBase
>
(
new
NOP
(
*
this
));
return
std
::
unique_ptr
<
OperatorBase
>
(
new
NOP
(
*
this
));
}
}
private:
void
RunImpl
(
const
Scope
&
scope
,
const
platform
::
Place
&
place
)
const
override
{}
};
};
class
ExecutionContext
{
class
ExecutionContext
{
...
@@ -363,8 +369,6 @@ class OperatorWithKernel : public OperatorBase {
...
@@ -363,8 +369,6 @@ class OperatorWithKernel : public OperatorBase {
const
VariableNameMap
&
outputs
,
const
AttributeMap
&
attrs
)
const
VariableNameMap
&
outputs
,
const
AttributeMap
&
attrs
)
:
OperatorBase
(
type
,
inputs
,
outputs
,
attrs
)
{}
:
OperatorBase
(
type
,
inputs
,
outputs
,
attrs
)
{}
void
Run
(
const
Scope
&
scope
,
const
platform
::
Place
&
place
)
const
final
;
static
std
::
unordered_map
<
std
::
string
/* op_type */
,
OpKernelMap
>&
static
std
::
unordered_map
<
std
::
string
/* op_type */
,
OpKernelMap
>&
AllOpKernels
()
{
AllOpKernels
()
{
static
std
::
unordered_map
<
std
::
string
,
OpKernelMap
>
g_all_op_kernels
;
static
std
::
unordered_map
<
std
::
string
,
OpKernelMap
>
g_all_op_kernels
;
...
@@ -393,6 +397,7 @@ class OperatorWithKernel : public OperatorBase {
...
@@ -393,6 +397,7 @@ class OperatorWithKernel : public OperatorBase {
// indicate kernel DataType by input data. Defaultly all input data must be
// indicate kernel DataType by input data. Defaultly all input data must be
// same.
// same.
proto
::
DataType
IndicateDataType
(
const
ExecutionContext
&
ctx
)
const
;
proto
::
DataType
IndicateDataType
(
const
ExecutionContext
&
ctx
)
const
;
void
RunImpl
(
const
Scope
&
scope
,
const
platform
::
Place
&
place
)
const
final
;
};
};
extern
bool
OpSupportGPU
(
const
std
::
string
&
op_type
);
extern
bool
OpSupportGPU
(
const
std
::
string
&
op_type
);
...
...
paddle/framework/operator_test.cc
浏览文件 @
98c94373
...
@@ -28,7 +28,10 @@ class OpWithoutKernelTest : public OperatorBase {
...
@@ -28,7 +28,10 @@ class OpWithoutKernelTest : public OperatorBase {
OpWithoutKernelTest
(
const
std
::
string
&
type
,
const
VariableNameMap
&
inputs
,
OpWithoutKernelTest
(
const
std
::
string
&
type
,
const
VariableNameMap
&
inputs
,
const
VariableNameMap
&
outputs
,
const
AttributeMap
&
attrs
)
const
VariableNameMap
&
outputs
,
const
AttributeMap
&
attrs
)
:
OperatorBase
(
type
,
inputs
,
outputs
,
attrs
),
x
(
1
)
{}
:
OperatorBase
(
type
,
inputs
,
outputs
,
attrs
),
x
(
1
)
{}
void
Run
(
const
Scope
&
scope
,
const
platform
::
Place
&
place
)
const
override
{
private:
void
RunImpl
(
const
Scope
&
scope
,
const
platform
::
Place
&
place
)
const
override
{
++
op_run_num
;
++
op_run_num
;
ASSERT_EQ
(
static_cast
<
int
>
(
inputs_
.
size
()),
1
);
ASSERT_EQ
(
static_cast
<
int
>
(
inputs_
.
size
()),
1
);
ASSERT_EQ
(
static_cast
<
int
>
(
outputs_
.
size
()),
1
);
ASSERT_EQ
(
static_cast
<
int
>
(
outputs_
.
size
()),
1
);
...
@@ -259,8 +262,10 @@ class OperatorClone : public paddle::framework::OperatorBase {
...
@@ -259,8 +262,10 @@ class OperatorClone : public paddle::framework::OperatorBase {
const
paddle
::
framework
::
VariableNameMap
&
outputs
,
const
paddle
::
framework
::
VariableNameMap
&
outputs
,
const
paddle
::
framework
::
AttributeMap
&
attrs
)
const
paddle
::
framework
::
AttributeMap
&
attrs
)
:
OperatorBase
(
type
,
inputs
,
outputs
,
attrs
)
{}
:
OperatorBase
(
type
,
inputs
,
outputs
,
attrs
)
{}
void
Run
(
const
paddle
::
framework
::
Scope
&
scope
,
const
paddle
::
platform
::
Place
&
place
)
const
override
{}
private:
void
RunImpl
(
const
paddle
::
framework
::
Scope
&
scope
,
const
paddle
::
platform
::
Place
&
place
)
const
override
{}
};
};
TEST
(
Operator
,
Clone
)
{
TEST
(
Operator
,
Clone
)
{
...
...
paddle/operators/array_to_lod_tensor_op.cc
浏览文件 @
98c94373
...
@@ -31,8 +31,10 @@ class ArrayToLoDTensorOp : public framework::OperatorBase {
...
@@ -31,8 +31,10 @@ class ArrayToLoDTensorOp : public framework::OperatorBase {
const
framework
::
VariableNameMap
&
outputs
,
const
framework
::
VariableNameMap
&
outputs
,
const
framework
::
AttributeMap
&
attrs
)
const
framework
::
AttributeMap
&
attrs
)
:
OperatorBase
(
type
,
inputs
,
outputs
,
attrs
)
{}
:
OperatorBase
(
type
,
inputs
,
outputs
,
attrs
)
{}
void
Run
(
const
framework
::
Scope
&
scope
,
const
platform
::
Place
&
dev_place
)
const
override
{
private:
void
RunImpl
(
const
framework
::
Scope
&
scope
,
const
platform
::
Place
&
dev_place
)
const
override
{
auto
&
x
=
scope
.
FindVar
(
Input
(
"X"
))
->
Get
<
framework
::
LoDTensorArray
>
();
auto
&
x
=
scope
.
FindVar
(
Input
(
"X"
))
->
Get
<
framework
::
LoDTensorArray
>
();
auto
&
rank_table
=
auto
&
rank_table
=
scope
.
FindVar
(
Input
(
"RankTable"
))
->
Get
<
framework
::
LoDRankTable
>
();
scope
.
FindVar
(
Input
(
"RankTable"
))
->
Get
<
framework
::
LoDRankTable
>
();
...
...
paddle/operators/assign_op.cc
浏览文件 @
98c94373
...
@@ -71,8 +71,10 @@ class AssignOp : public framework::OperatorBase {
...
@@ -71,8 +71,10 @@ class AssignOp : public framework::OperatorBase {
const
framework
::
VariableNameMap
&
outputs
,
const
framework
::
VariableNameMap
&
outputs
,
const
framework
::
AttributeMap
&
attrs
)
const
framework
::
AttributeMap
&
attrs
)
:
OperatorBase
(
type
,
inputs
,
outputs
,
attrs
)
{}
:
OperatorBase
(
type
,
inputs
,
outputs
,
attrs
)
{}
void
Run
(
const
framework
::
Scope
&
scope
,
const
platform
::
Place
&
place
)
const
override
{
private:
void
RunImpl
(
const
framework
::
Scope
&
scope
,
const
platform
::
Place
&
place
)
const
override
{
auto
*
x
=
scope
.
FindVar
(
Input
(
"X"
));
auto
*
x
=
scope
.
FindVar
(
Input
(
"X"
));
if
(
x
==
nullptr
)
{
if
(
x
==
nullptr
)
{
return
;
return
;
...
...
paddle/operators/beam_search_decode_op.cc
浏览文件 @
98c94373
...
@@ -55,8 +55,10 @@ class BeamSearchDecodeOp : public framework::OperatorBase {
...
@@ -55,8 +55,10 @@ class BeamSearchDecodeOp : public framework::OperatorBase {
const
framework
::
VariableNameMap
&
outputs
,
const
framework
::
VariableNameMap
&
outputs
,
const
framework
::
AttributeMap
&
attrs
)
const
framework
::
AttributeMap
&
attrs
)
:
OperatorBase
(
type
,
inputs
,
outputs
,
attrs
)
{}
:
OperatorBase
(
type
,
inputs
,
outputs
,
attrs
)
{}
void
Run
(
const
framework
::
Scope
&
scope
,
const
platform
::
Place
&
dev_place
)
const
override
{
private:
void
RunImpl
(
const
framework
::
Scope
&
scope
,
const
platform
::
Place
&
dev_place
)
const
override
{
platform
::
DeviceContextPool
&
pool
=
platform
::
DeviceContextPool
::
Instance
();
platform
::
DeviceContextPool
&
pool
=
platform
::
DeviceContextPool
::
Instance
();
auto
&
dev_ctx
=
*
pool
.
Get
(
dev_place
);
auto
&
dev_ctx
=
*
pool
.
Get
(
dev_place
);
...
...
paddle/operators/beam_search_op.h
浏览文件 @
98c94373
...
@@ -204,8 +204,9 @@ class BeamSearchOp : public framework::OperatorBase {
...
@@ -204,8 +204,9 @@ class BeamSearchOp : public framework::OperatorBase {
PADDLE_THROW
(
"Not Implemented"
);
PADDLE_THROW
(
"Not Implemented"
);
}
}
void
Run
(
const
framework
::
Scope
&
scope
,
private:
const
platform
::
Place
&
dev_place
)
const
override
{
void
RunImpl
(
const
framework
::
Scope
&
scope
,
const
platform
::
Place
&
dev_place
)
const
override
{
auto
ids_var
=
scope
.
FindVar
(
Input
(
"ids"
));
auto
ids_var
=
scope
.
FindVar
(
Input
(
"ids"
));
auto
scores_var
=
scope
.
FindVar
(
Input
(
"scores"
));
auto
scores_var
=
scope
.
FindVar
(
Input
(
"scores"
));
auto
pre_ids_var
=
scope
.
FindVar
(
Input
(
"pre_ids"
));
auto
pre_ids_var
=
scope
.
FindVar
(
Input
(
"pre_ids"
));
...
...
paddle/operators/cond_op.cc
浏览文件 @
98c94373
...
@@ -193,7 +193,7 @@ void CondOp::MergeDataFromSubnet(const framework::Scope& scope,
...
@@ -193,7 +193,7 @@ void CondOp::MergeDataFromSubnet(const framework::Scope& scope,
}
}
}
}
void
CondOp
::
Run
(
const
Scope
&
scope
,
const
platform
::
Place
&
place
)
const
{
void
CondOp
::
Run
Impl
(
const
Scope
&
scope
,
const
platform
::
Place
&
place
)
const
{
// get device context from pool
// get device context from pool
platform
::
DeviceContextPool
&
pool
=
platform
::
DeviceContextPool
::
Instance
();
platform
::
DeviceContextPool
&
pool
=
platform
::
DeviceContextPool
::
Instance
();
auto
&
dev_ctx
=
*
pool
.
Get
(
place
);
auto
&
dev_ctx
=
*
pool
.
Get
(
place
);
...
...
paddle/operators/cond_op.h
浏览文件 @
98c94373
...
@@ -77,8 +77,9 @@ class CondOp : public framework::OperatorBase {
...
@@ -77,8 +77,9 @@ class CondOp : public framework::OperatorBase {
sub_net_op_
[
FALSE_BRANCH
]
=
std
::
move
(
net
);
sub_net_op_
[
FALSE_BRANCH
]
=
std
::
move
(
net
);
}
}
void
Run
(
const
framework
::
Scope
&
scope
,
private:
const
platform
::
Place
&
place
)
const
override
;
void
RunImpl
(
const
framework
::
Scope
&
scope
,
const
platform
::
Place
&
place
)
const
override
;
private:
private:
const
int
TRUE_BRANCH
=
0
;
const
int
TRUE_BRANCH
=
0
;
...
...
paddle/operators/conditional_block_op.cc
浏览文件 @
98c94373
...
@@ -65,8 +65,10 @@ class ConditionalBlockOp : public ConditionalOp {
...
@@ -65,8 +65,10 @@ class ConditionalBlockOp : public ConditionalOp {
const
framework
::
VariableNameMap
&
outputs
,
const
framework
::
VariableNameMap
&
outputs
,
const
framework
::
AttributeMap
&
attrs
)
const
framework
::
AttributeMap
&
attrs
)
:
ConditionalOp
(
type
,
inputs
,
outputs
,
attrs
)
{}
:
ConditionalOp
(
type
,
inputs
,
outputs
,
attrs
)
{}
void
Run
(
const
framework
::
Scope
&
scope
,
const
platform
::
Place
&
dev_place
)
const
override
{
private:
void
RunImpl
(
const
framework
::
Scope
&
scope
,
const
platform
::
Place
&
dev_place
)
const
override
{
auto
xs
=
InputTensors
(
scope
);
auto
xs
=
InputTensors
(
scope
);
bool
need_run
;
bool
need_run
;
...
@@ -128,8 +130,10 @@ class ConditionalBlockGradOp : public ConditionalOp {
...
@@ -128,8 +130,10 @@ class ConditionalBlockGradOp : public ConditionalOp {
const
framework
::
VariableNameMap
&
outputs
,
const
framework
::
VariableNameMap
&
outputs
,
const
framework
::
AttributeMap
&
attrs
)
const
framework
::
AttributeMap
&
attrs
)
:
ConditionalOp
(
type
,
inputs
,
outputs
,
attrs
)
{}
:
ConditionalOp
(
type
,
inputs
,
outputs
,
attrs
)
{}
void
Run
(
const
framework
::
Scope
&
scope
,
const
platform
::
Place
&
dev_place
)
const
override
{
private:
void
RunImpl
(
const
framework
::
Scope
&
scope
,
const
platform
::
Place
&
dev_place
)
const
override
{
auto
xs
=
this
->
InputTensors
(
scope
);
auto
xs
=
this
->
InputTensors
(
scope
);
bool
need_run
;
bool
need_run
;
...
...
paddle/operators/create_reader_op.cc
浏览文件 @
98c94373
...
@@ -72,8 +72,10 @@ template <typename T>
...
@@ -72,8 +72,10 @@ template <typename T>
class
CreateRandomDataGeneratorOp
:
public
framework
::
OperatorBase
{
class
CreateRandomDataGeneratorOp
:
public
framework
::
OperatorBase
{
public:
public:
using
framework
::
OperatorBase
::
OperatorBase
;
using
framework
::
OperatorBase
::
OperatorBase
;
void
Run
(
const
framework
::
Scope
&
scope
,
const
platform
::
Place
&
dev_place
)
const
override
{
private:
void
RunImpl
(
const
framework
::
Scope
&
scope
,
const
platform
::
Place
&
dev_place
)
const
override
{
const
auto
&
shape_concat
=
Attr
<
std
::
vector
<
int
>>
(
"shape_concat"
);
const
auto
&
shape_concat
=
Attr
<
std
::
vector
<
int
>>
(
"shape_concat"
);
const
auto
&
ranks
=
Attr
<
std
::
vector
<
int
>>
(
"ranks"
);
const
auto
&
ranks
=
Attr
<
std
::
vector
<
int
>>
(
"ranks"
);
PADDLE_ENFORCE
(
!
shape_concat
.
empty
()
&&
!
ranks
.
empty
());
PADDLE_ENFORCE
(
!
shape_concat
.
empty
()
&&
!
ranks
.
empty
());
...
@@ -120,8 +122,10 @@ class CreateRandomDataGeneratorOpMaker
...
@@ -120,8 +122,10 @@ class CreateRandomDataGeneratorOpMaker
class
CreateShuffleReaderOp
:
public
framework
::
OperatorBase
{
class
CreateShuffleReaderOp
:
public
framework
::
OperatorBase
{
public:
public:
using
framework
::
OperatorBase
::
OperatorBase
;
using
framework
::
OperatorBase
::
OperatorBase
;
void
Run
(
const
framework
::
Scope
&
scope
,
const
platform
::
Place
&
dev_place
)
const
override
{
private:
void
RunImpl
(
const
framework
::
Scope
&
scope
,
const
platform
::
Place
&
dev_place
)
const
override
{
const
auto
&
underlying_reader
=
scope
.
FindVar
(
Input
(
"UnderlyingReader"
))
const
auto
&
underlying_reader
=
scope
.
FindVar
(
Input
(
"UnderlyingReader"
))
->
Get
<
framework
::
ReaderHolder
>
();
->
Get
<
framework
::
ReaderHolder
>
();
auto
*
out
=
scope
.
FindVar
(
Output
(
"Out"
))
auto
*
out
=
scope
.
FindVar
(
Output
(
"Out"
))
...
@@ -152,8 +156,10 @@ class CreateShuffleReaderOpMaker : public framework::OpProtoAndCheckerMaker {
...
@@ -152,8 +156,10 @@ class CreateShuffleReaderOpMaker : public framework::OpProtoAndCheckerMaker {
class
CreateBatchReaderOp
:
public
framework
::
OperatorBase
{
class
CreateBatchReaderOp
:
public
framework
::
OperatorBase
{
public:
public:
using
framework
::
OperatorBase
::
OperatorBase
;
using
framework
::
OperatorBase
::
OperatorBase
;
void
Run
(
const
framework
::
Scope
&
scope
,
const
platform
::
Place
&
dev_place
)
const
override
{
private:
void
RunImpl
(
const
framework
::
Scope
&
scope
,
const
platform
::
Place
&
dev_place
)
const
override
{
const
auto
&
underlying_reader
=
scope
.
FindVar
(
Input
(
"UnderlyingReader"
))
const
auto
&
underlying_reader
=
scope
.
FindVar
(
Input
(
"UnderlyingReader"
))
->
Get
<
framework
::
ReaderHolder
>
();
->
Get
<
framework
::
ReaderHolder
>
();
auto
*
out
=
scope
.
FindVar
(
Output
(
"Out"
))
auto
*
out
=
scope
.
FindVar
(
Output
(
"Out"
))
...
...
paddle/operators/feed_op.cc
浏览文件 @
98c94373
...
@@ -24,8 +24,10 @@ class FeedOp : public framework::OperatorBase {
...
@@ -24,8 +24,10 @@ class FeedOp : public framework::OperatorBase {
const
framework
::
VariableNameMap
&
outputs
,
const
framework
::
VariableNameMap
&
outputs
,
const
framework
::
AttributeMap
&
attrs
)
const
framework
::
AttributeMap
&
attrs
)
:
OperatorBase
(
type
,
inputs
,
outputs
,
attrs
)
{}
:
OperatorBase
(
type
,
inputs
,
outputs
,
attrs
)
{}
void
Run
(
const
framework
::
Scope
&
scope
,
const
platform
::
Place
&
place
)
const
override
{
private:
void
RunImpl
(
const
framework
::
Scope
&
scope
,
const
platform
::
Place
&
place
)
const
override
{
auto
feed_var_name
=
Input
(
"X"
);
auto
feed_var_name
=
Input
(
"X"
);
auto
*
feed_var
=
scope
.
FindVar
(
feed_var_name
);
auto
*
feed_var
=
scope
.
FindVar
(
feed_var_name
);
...
...
paddle/operators/fetch_op.cc
浏览文件 @
98c94373
...
@@ -26,8 +26,9 @@ class FetchOp : public framework::OperatorBase {
...
@@ -26,8 +26,9 @@ class FetchOp : public framework::OperatorBase {
const
framework
::
AttributeMap
&
attrs
)
const
framework
::
AttributeMap
&
attrs
)
:
OperatorBase
(
type
,
inputs
,
outputs
,
attrs
)
{}
:
OperatorBase
(
type
,
inputs
,
outputs
,
attrs
)
{}
void
Run
(
const
framework
::
Scope
&
scope
,
private:
const
platform
::
Place
&
place
)
const
override
{
void
RunImpl
(
const
framework
::
Scope
&
scope
,
const
platform
::
Place
&
place
)
const
override
{
auto
fetch_var_name
=
Input
(
"X"
);
auto
fetch_var_name
=
Input
(
"X"
);
auto
*
fetch_var
=
scope
.
FindVar
(
fetch_var_name
);
auto
*
fetch_var
=
scope
.
FindVar
(
fetch_var_name
);
PADDLE_ENFORCE
(
fetch_var
!=
nullptr
,
PADDLE_ENFORCE
(
fetch_var
!=
nullptr
,
...
...
paddle/operators/fill_constant_op.cc
浏览文件 @
98c94373
...
@@ -33,8 +33,10 @@ class FillConstantInferShape : public framework::InferShapeBase {
...
@@ -33,8 +33,10 @@ class FillConstantInferShape : public framework::InferShapeBase {
class
FillConstantOp
:
public
framework
::
OperatorBase
{
class
FillConstantOp
:
public
framework
::
OperatorBase
{
public:
public:
using
framework
::
OperatorBase
::
OperatorBase
;
using
framework
::
OperatorBase
::
OperatorBase
;
void
Run
(
const
framework
::
Scope
&
scope
,
const
platform
::
Place
&
dev_place
)
const
override
{
private:
void
RunImpl
(
const
framework
::
Scope
&
scope
,
const
platform
::
Place
&
dev_place
)
const
override
{
auto
data_type
=
auto
data_type
=
static_cast
<
framework
::
proto
::
DataType
>
(
Attr
<
int
>
(
"dtype"
));
static_cast
<
framework
::
proto
::
DataType
>
(
Attr
<
int
>
(
"dtype"
));
auto
value
=
Attr
<
float
>
(
"value"
);
auto
value
=
Attr
<
float
>
(
"value"
);
...
...
paddle/operators/fill_op.cc
浏览文件 @
98c94373
...
@@ -42,8 +42,10 @@ class FillOp : public framework::OperatorBase {
...
@@ -42,8 +42,10 @@ class FillOp : public framework::OperatorBase {
const
framework
::
VariableNameMap
&
outputs
,
const
framework
::
VariableNameMap
&
outputs
,
const
framework
::
AttributeMap
&
attrs
)
const
framework
::
AttributeMap
&
attrs
)
:
OperatorBase
(
type
,
inputs
,
outputs
,
attrs
)
{}
:
OperatorBase
(
type
,
inputs
,
outputs
,
attrs
)
{}
void
Run
(
const
framework
::
Scope
&
scope
,
const
platform
::
Place
&
place
)
const
override
{
private:
void
RunImpl
(
const
framework
::
Scope
&
scope
,
const
platform
::
Place
&
place
)
const
override
{
auto
&
out
=
auto
&
out
=
detail
::
Ref
(
detail
::
Ref
(
scope
.
FindVar
(
Output
(
"Out"
)),
detail
::
Ref
(
detail
::
Ref
(
scope
.
FindVar
(
Output
(
"Out"
)),
"Cannot find variable %s"
,
Output
(
"Out"
))
"Cannot find variable %s"
,
Output
(
"Out"
))
...
...
paddle/operators/get_places_op.cc
浏览文件 @
98c94373
...
@@ -37,8 +37,10 @@ class GetPlacesOp : public framework::OperatorBase {
...
@@ -37,8 +37,10 @@ class GetPlacesOp : public framework::OperatorBase {
const
framework
::
VariableNameMap
&
outputs
,
const
framework
::
VariableNameMap
&
outputs
,
const
framework
::
AttributeMap
&
attrs
)
const
framework
::
AttributeMap
&
attrs
)
:
OperatorBase
(
type
,
inputs
,
outputs
,
attrs
)
{}
:
OperatorBase
(
type
,
inputs
,
outputs
,
attrs
)
{}
void
Run
(
const
framework
::
Scope
&
scope
,
const
platform
::
Place
&
place
)
const
override
{
private:
void
RunImpl
(
const
framework
::
Scope
&
scope
,
const
platform
::
Place
&
place
)
const
override
{
bool
is_gpu
;
bool
is_gpu
;
if
(
Attr
<
std
::
string
>
(
"device_type"
)
==
"AUTO"
)
{
if
(
Attr
<
std
::
string
>
(
"device_type"
)
==
"AUTO"
)
{
is_gpu
=
platform
::
is_gpu_place
(
place
);
is_gpu
=
platform
::
is_gpu_place
(
place
);
...
...
paddle/operators/increment_op.cc
浏览文件 @
98c94373
...
@@ -51,8 +51,9 @@ class IncrementOp : public framework::OperatorBase {
...
@@ -51,8 +51,9 @@ class IncrementOp : public framework::OperatorBase {
const
framework
::
AttributeMap
&
attrs
)
const
framework
::
AttributeMap
&
attrs
)
:
OperatorBase
(
type
,
inputs
,
outputs
,
attrs
)
{}
:
OperatorBase
(
type
,
inputs
,
outputs
,
attrs
)
{}
void
Run
(
const
framework
::
Scope
&
scope
,
private:
const
platform
::
Place
&
place
)
const
override
{
void
RunImpl
(
const
framework
::
Scope
&
scope
,
const
platform
::
Place
&
place
)
const
override
{
auto
&
x
=
scope
.
FindVar
(
Input
(
"X"
))
->
Get
<
framework
::
LoDTensor
>
();
auto
&
x
=
scope
.
FindVar
(
Input
(
"X"
))
->
Get
<
framework
::
LoDTensor
>
();
auto
&
out
=
auto
&
out
=
*
scope
.
FindVar
(
Output
(
"Out"
))
->
GetMutable
<
framework
::
LoDTensor
>
();
*
scope
.
FindVar
(
Output
(
"Out"
))
->
GetMutable
<
framework
::
LoDTensor
>
();
...
...
paddle/operators/is_empty_op.cc
浏览文件 @
98c94373
...
@@ -28,8 +28,9 @@ class IsEmptyOp : public framework::OperatorBase {
...
@@ -28,8 +28,9 @@ class IsEmptyOp : public framework::OperatorBase {
const
framework
::
AttributeMap
&
attrs
)
const
framework
::
AttributeMap
&
attrs
)
:
OperatorBase
(
type
,
inputs
,
outputs
,
attrs
)
{}
:
OperatorBase
(
type
,
inputs
,
outputs
,
attrs
)
{}
void
Run
(
const
framework
::
Scope
&
scope
,
private:
const
platform
::
Place
&
place
)
const
override
{
void
RunImpl
(
const
framework
::
Scope
&
scope
,
const
platform
::
Place
&
place
)
const
override
{
// get input
// get input
auto
*
var
=
scope
.
FindVar
(
Input
(
kInput
));
auto
*
var
=
scope
.
FindVar
(
Input
(
kInput
));
PADDLE_ENFORCE_NOT_NULL
(
var
);
PADDLE_ENFORCE_NOT_NULL
(
var
);
...
...
paddle/operators/load_combine_op.cc
浏览文件 @
98c94373
...
@@ -26,8 +26,10 @@ class LoadCombineOp : public framework::OperatorBase {
...
@@ -26,8 +26,10 @@ class LoadCombineOp : public framework::OperatorBase {
const
framework
::
VariableNameMap
&
outputs
,
const
framework
::
VariableNameMap
&
outputs
,
const
framework
::
AttributeMap
&
attrs
)
const
framework
::
AttributeMap
&
attrs
)
:
OperatorBase
(
type
,
inputs
,
outputs
,
attrs
)
{}
:
OperatorBase
(
type
,
inputs
,
outputs
,
attrs
)
{}
void
Run
(
const
framework
::
Scope
&
scope
,
const
platform
::
Place
&
place
)
const
override
{
private:
void
RunImpl
(
const
framework
::
Scope
&
scope
,
const
platform
::
Place
&
place
)
const
override
{
auto
filename
=
Attr
<
std
::
string
>
(
"file_path"
);
auto
filename
=
Attr
<
std
::
string
>
(
"file_path"
);
std
::
ifstream
fin
(
filename
);
std
::
ifstream
fin
(
filename
);
...
...
paddle/operators/load_op.cc
浏览文件 @
98c94373
...
@@ -25,8 +25,10 @@ class LoadOp : public framework::OperatorBase {
...
@@ -25,8 +25,10 @@ class LoadOp : public framework::OperatorBase {
const
framework
::
VariableNameMap
&
outputs
,
const
framework
::
VariableNameMap
&
outputs
,
const
framework
::
AttributeMap
&
attrs
)
const
framework
::
AttributeMap
&
attrs
)
:
OperatorBase
(
type
,
inputs
,
outputs
,
attrs
)
{}
:
OperatorBase
(
type
,
inputs
,
outputs
,
attrs
)
{}
void
Run
(
const
framework
::
Scope
&
scope
,
const
platform
::
Place
&
place
)
const
override
{
private:
void
RunImpl
(
const
framework
::
Scope
&
scope
,
const
platform
::
Place
&
place
)
const
override
{
auto
filename
=
Attr
<
std
::
string
>
(
"file_path"
);
auto
filename
=
Attr
<
std
::
string
>
(
"file_path"
);
std
::
ifstream
fin
(
filename
);
std
::
ifstream
fin
(
filename
);
PADDLE_ENFORCE
(
static_cast
<
bool
>
(
fin
),
"Cannot open file %s for load op"
,
PADDLE_ENFORCE
(
static_cast
<
bool
>
(
fin
),
"Cannot open file %s for load op"
,
...
...
paddle/operators/lod_array_length_op.cc
浏览文件 @
98c94373
...
@@ -25,8 +25,10 @@ class LoDArrayLengthOp : public framework::OperatorBase {
...
@@ -25,8 +25,10 @@ class LoDArrayLengthOp : public framework::OperatorBase {
const
framework
::
VariableNameMap
&
outputs
,
const
framework
::
VariableNameMap
&
outputs
,
const
framework
::
AttributeMap
&
attrs
)
const
framework
::
AttributeMap
&
attrs
)
:
OperatorBase
(
type
,
inputs
,
outputs
,
attrs
)
{}
:
OperatorBase
(
type
,
inputs
,
outputs
,
attrs
)
{}
void
Run
(
const
framework
::
Scope
&
scope
,
const
platform
::
Place
&
place
)
const
override
{
private:
void
RunImpl
(
const
framework
::
Scope
&
scope
,
const
platform
::
Place
&
place
)
const
override
{
auto
&
x
=
scope
.
FindVar
(
Input
(
"X"
))
->
Get
<
framework
::
LoDTensorArray
>
();
auto
&
x
=
scope
.
FindVar
(
Input
(
"X"
))
->
Get
<
framework
::
LoDTensorArray
>
();
auto
&
out
=
auto
&
out
=
*
scope
.
FindVar
(
Output
(
"Out"
))
->
GetMutable
<
framework
::
LoDTensor
>
();
*
scope
.
FindVar
(
Output
(
"Out"
))
->
GetMutable
<
framework
::
LoDTensor
>
();
...
...
paddle/operators/lod_rank_table_op.cc
浏览文件 @
98c94373
...
@@ -23,8 +23,10 @@ class LoDRankTableOp : public framework::OperatorBase {
...
@@ -23,8 +23,10 @@ class LoDRankTableOp : public framework::OperatorBase {
const
framework
::
VariableNameMap
&
outputs
,
const
framework
::
VariableNameMap
&
outputs
,
const
framework
::
AttributeMap
&
attrs
)
const
framework
::
AttributeMap
&
attrs
)
:
OperatorBase
(
type
,
inputs
,
outputs
,
attrs
)
{}
:
OperatorBase
(
type
,
inputs
,
outputs
,
attrs
)
{}
void
Run
(
const
framework
::
Scope
&
scope
,
const
platform
::
Place
&
dev_place
)
const
override
{
private:
void
RunImpl
(
const
framework
::
Scope
&
scope
,
const
platform
::
Place
&
dev_place
)
const
override
{
auto
x
=
scope
.
FindVar
(
Input
(
"X"
))
->
Get
<
framework
::
LoDTensor
>
();
auto
x
=
scope
.
FindVar
(
Input
(
"X"
))
->
Get
<
framework
::
LoDTensor
>
();
auto
*
out
=
auto
*
out
=
scope
.
FindVar
(
Output
(
"Out"
))
->
GetMutable
<
framework
::
LoDRankTable
>
();
scope
.
FindVar
(
Output
(
"Out"
))
->
GetMutable
<
framework
::
LoDRankTable
>
();
...
...
paddle/operators/lod_tensor_to_array_op.cc
浏览文件 @
98c94373
...
@@ -32,8 +32,10 @@ class LoDTensorToArrayOp : public framework::OperatorBase {
...
@@ -32,8 +32,10 @@ class LoDTensorToArrayOp : public framework::OperatorBase {
const
framework
::
VariableNameMap
&
outputs
,
const
framework
::
VariableNameMap
&
outputs
,
const
framework
::
AttributeMap
&
attrs
)
const
framework
::
AttributeMap
&
attrs
)
:
OperatorBase
(
type
,
inputs
,
outputs
,
attrs
)
{}
:
OperatorBase
(
type
,
inputs
,
outputs
,
attrs
)
{}
void
Run
(
const
framework
::
Scope
&
scope
,
const
platform
::
Place
&
place
)
const
override
{
private:
void
RunImpl
(
const
framework
::
Scope
&
scope
,
const
platform
::
Place
&
place
)
const
override
{
auto
&
x
=
detail
::
Ref
(
scope
.
FindVar
(
Input
(
"X"
)),
"Cannot find input %s"
,
auto
&
x
=
detail
::
Ref
(
scope
.
FindVar
(
Input
(
"X"
)),
"Cannot find input %s"
,
Input
(
"X"
))
Input
(
"X"
))
.
Get
<
framework
::
LoDTensor
>
();
.
Get
<
framework
::
LoDTensor
>
();
...
...
paddle/operators/max_sequence_len_op.cc
浏览文件 @
98c94373
...
@@ -27,8 +27,9 @@ class MaxSeqenceLenOp : public framework::OperatorBase {
...
@@ -27,8 +27,9 @@ class MaxSeqenceLenOp : public framework::OperatorBase {
const
framework
::
AttributeMap
&
attrs
)
const
framework
::
AttributeMap
&
attrs
)
:
OperatorBase
(
type
,
inputs
,
outputs
,
attrs
)
{}
:
OperatorBase
(
type
,
inputs
,
outputs
,
attrs
)
{}
void
Run
(
const
framework
::
Scope
&
scope
,
private:
const
platform
::
Place
&
dev_place
)
const
override
{
void
RunImpl
(
const
framework
::
Scope
&
scope
,
const
platform
::
Place
&
dev_place
)
const
override
{
auto
&
rank_table
=
auto
&
rank_table
=
scope
.
FindVar
(
Input
(
"RankTable"
))
->
Get
<
framework
::
LoDRankTable
>
();
scope
.
FindVar
(
Input
(
"RankTable"
))
->
Get
<
framework
::
LoDRankTable
>
();
auto
*
out
=
auto
*
out
=
...
...
paddle/operators/merge_lod_tensor_op.cc
浏览文件 @
98c94373
...
@@ -27,8 +27,10 @@ class MergeLoDTensorOp : public framework::OperatorBase {
...
@@ -27,8 +27,10 @@ class MergeLoDTensorOp : public framework::OperatorBase {
const
framework
::
VariableNameMap
&
outputs
,
const
framework
::
VariableNameMap
&
outputs
,
const
framework
::
AttributeMap
&
attrs
)
const
framework
::
AttributeMap
&
attrs
)
:
OperatorBase
(
type
,
inputs
,
outputs
,
attrs
)
{}
:
OperatorBase
(
type
,
inputs
,
outputs
,
attrs
)
{}
void
Run
(
const
framework
::
Scope
&
scope
,
const
platform
::
Place
&
dev_place
)
const
override
{
private:
void
RunImpl
(
const
framework
::
Scope
&
scope
,
const
platform
::
Place
&
dev_place
)
const
override
{
// get device context from pool
// get device context from pool
platform
::
DeviceContextPool
&
pool
=
platform
::
DeviceContextPool
::
Instance
();
platform
::
DeviceContextPool
&
pool
=
platform
::
DeviceContextPool
::
Instance
();
auto
&
dev_ctx
=
*
pool
.
Get
(
dev_place
);
auto
&
dev_ctx
=
*
pool
.
Get
(
dev_place
);
...
...
paddle/operators/nccl_op.cc
浏览文件 @
98c94373
...
@@ -26,8 +26,9 @@ class NCCLInitOp : public framework::OperatorBase {
...
@@ -26,8 +26,9 @@ class NCCLInitOp : public framework::OperatorBase {
const
framework
::
AttributeMap
&
attrs
)
const
framework
::
AttributeMap
&
attrs
)
:
OperatorBase
(
type
,
inputs
,
outputs
,
attrs
)
{}
:
OperatorBase
(
type
,
inputs
,
outputs
,
attrs
)
{}
void
Run
(
const
framework
::
Scope
&
scope
,
private:
const
platform
::
Place
&
place
)
const
override
{
void
RunImpl
(
const
framework
::
Scope
&
scope
,
const
platform
::
Place
&
place
)
const
override
{
const
auto
&
name
=
Output
(
"Communicator"
);
const
auto
&
name
=
Output
(
"Communicator"
);
PADDLE_ENFORCE_NOT_NULL
(
scope
.
FindVar
(
name
),
PADDLE_ENFORCE_NOT_NULL
(
scope
.
FindVar
(
name
),
"Can not find variable '%s' in the scope."
,
name
);
"Can not find variable '%s' in the scope."
,
name
);
...
...
paddle/operators/net_op.h
浏览文件 @
98c94373
...
@@ -57,20 +57,6 @@ class NetOp : public framework::OperatorBase {
...
@@ -57,20 +57,6 @@ class NetOp : public framework::OperatorBase {
this
->
CompleteAddOp
();
this
->
CompleteAddOp
();
}
}
/**
* @brief Run the network.
*
* Run all the operators with the `scope`, if no scope is provided, default
* scope will be used instead. If no OpContext is provicded, default context
* will be used.
*/
void
Run
(
const
framework
::
Scope
&
scope
,
const
platform
::
Place
&
place
)
const
override
{
for
(
auto
&
op
:
ops_
)
{
op
->
Run
(
scope
,
place
);
}
}
bool
SupportGPU
()
const
override
{
bool
SupportGPU
()
const
override
{
for
(
auto
&
op
:
ops_
)
{
for
(
auto
&
op
:
ops_
)
{
if
(
!
op
->
SupportGPU
())
{
if
(
!
op
->
SupportGPU
())
{
...
@@ -117,6 +103,20 @@ class NetOp : public framework::OperatorBase {
...
@@ -117,6 +103,20 @@ class NetOp : public framework::OperatorBase {
std
::
vector
<
std
::
unique_ptr
<
framework
::
OperatorBase
>>
ops_
;
std
::
vector
<
std
::
unique_ptr
<
framework
::
OperatorBase
>>
ops_
;
private:
private:
/**
* @brief Run the network.
*
* Run all the operators with the `scope`, if no scope is provided, default
* scope will be used instead. If no OpContext is provicded, default context
* will be used.
*/
void
RunImpl
(
const
framework
::
Scope
&
scope
,
const
platform
::
Place
&
place
)
const
override
{
for
(
auto
&
op
:
ops_
)
{
op
->
Run
(
scope
,
place
);
}
}
bool
add_op_done_
{
false
};
bool
add_op_done_
{
false
};
std
::
set
<
std
::
string
>
intermediate_outputs_
;
std
::
set
<
std
::
string
>
intermediate_outputs_
;
...
...
paddle/operators/net_op_test.cc
浏览文件 @
98c94373
...
@@ -26,7 +26,10 @@ class TestOp : public framework::OperatorBase {
...
@@ -26,7 +26,10 @@ class TestOp : public framework::OperatorBase {
public:
public:
using
framework
::
OperatorBase
::
OperatorBase
;
using
framework
::
OperatorBase
::
OperatorBase
;
DEFINE_OP_CLONE_METHOD
(
TestOp
);
DEFINE_OP_CLONE_METHOD
(
TestOp
);
void
Run
(
const
Scope
&
scope
,
const
platform
::
Place
&
place
)
const
override
{
private:
void
RunImpl
(
const
Scope
&
scope
,
const
platform
::
Place
&
place
)
const
override
{
++
run_cnt
;
++
run_cnt
;
}
}
};
};
...
...
paddle/operators/parallel_do_op.cc
浏览文件 @
98c94373
...
@@ -124,8 +124,9 @@ class ParallelDoOp : public framework::OperatorBase {
...
@@ -124,8 +124,9 @@ class ParallelDoOp : public framework::OperatorBase {
const
framework
::
AttributeMap
&
attrs
)
const
framework
::
AttributeMap
&
attrs
)
:
framework
::
OperatorBase
(
type
,
inputs
,
outputs
,
attrs
)
{}
:
framework
::
OperatorBase
(
type
,
inputs
,
outputs
,
attrs
)
{}
void
Run
(
const
framework
::
Scope
&
scope
,
private:
const
platform
::
Place
&
place
)
const
override
{
void
RunImpl
(
const
framework
::
Scope
&
scope
,
const
platform
::
Place
&
place
)
const
override
{
// get device context from pool
// get device context from pool
platform
::
DeviceContextPool
&
pool
=
platform
::
DeviceContextPool
::
Instance
();
platform
::
DeviceContextPool
&
pool
=
platform
::
DeviceContextPool
::
Instance
();
auto
&
dev_ctx
=
*
pool
.
Get
(
place
);
auto
&
dev_ctx
=
*
pool
.
Get
(
place
);
...
@@ -216,8 +217,9 @@ class ParallelDoGradOp : public framework::OperatorBase {
...
@@ -216,8 +217,9 @@ class ParallelDoGradOp : public framework::OperatorBase {
const
framework
::
AttributeMap
&
attrs
)
const
framework
::
AttributeMap
&
attrs
)
:
framework
::
OperatorBase
(
type
,
inputs
,
outputs
,
attrs
)
{}
:
framework
::
OperatorBase
(
type
,
inputs
,
outputs
,
attrs
)
{}
void
Run
(
const
framework
::
Scope
&
scope
,
private:
const
platform
::
Place
&
place
)
const
override
{
void
RunImpl
(
const
framework
::
Scope
&
scope
,
const
platform
::
Place
&
place
)
const
override
{
auto
*
block
=
Attr
<
framework
::
BlockDesc
*>
(
kParallelBlock
);
auto
*
block
=
Attr
<
framework
::
BlockDesc
*>
(
kParallelBlock
);
auto
*
program
=
block
->
Program
();
auto
*
program
=
block
->
Program
();
...
...
paddle/operators/print_op.cc
浏览文件 @
98c94373
...
@@ -130,8 +130,9 @@ class TensorPrintOp : public framework::OperatorBase {
...
@@ -130,8 +130,9 @@ class TensorPrintOp : public framework::OperatorBase {
PADDLE_THROW
(
"Not implemented."
);
PADDLE_THROW
(
"Not implemented."
);
}
}
void
Run
(
const
framework
::
Scope
&
scope
,
private:
const
platform
::
Place
&
place
)
const
override
{
void
RunImpl
(
const
framework
::
Scope
&
scope
,
const
platform
::
Place
&
place
)
const
override
{
const
framework
::
Variable
*
in_var_ptr
=
nullptr
;
const
framework
::
Variable
*
in_var_ptr
=
nullptr
;
std
::
string
phase
=
kForward
;
std
::
string
phase
=
kForward
;
std
::
string
printed_var_name
=
""
;
std
::
string
printed_var_name
=
""
;
...
...
paddle/operators/read_op.cc
浏览文件 @
98c94373
...
@@ -54,8 +54,10 @@ class ReadInferVarType : public framework::VarTypeInference {
...
@@ -54,8 +54,10 @@ class ReadInferVarType : public framework::VarTypeInference {
class
ReadOp
:
public
framework
::
OperatorBase
{
class
ReadOp
:
public
framework
::
OperatorBase
{
public:
public:
using
framework
::
OperatorBase
::
OperatorBase
;
using
framework
::
OperatorBase
::
OperatorBase
;
void
Run
(
const
framework
::
Scope
&
scope
,
const
platform
::
Place
&
dev_place
)
const
override
{
private:
void
RunImpl
(
const
framework
::
Scope
&
scope
,
const
platform
::
Place
&
dev_place
)
const
override
{
framework
::
ReaderHolder
*
reader
=
framework
::
ReaderHolder
*
reader
=
scope
.
FindVar
(
Input
(
"Reader"
))
->
GetMutable
<
framework
::
ReaderHolder
>
();
scope
.
FindVar
(
Input
(
"Reader"
))
->
GetMutable
<
framework
::
ReaderHolder
>
();
if
(
!
reader
->
HasNext
())
{
if
(
!
reader
->
HasNext
())
{
...
...
paddle/operators/recurrent_op.cc
浏览文件 @
98c94373
...
@@ -226,8 +226,9 @@ class RecurrentOp : public RecurrentBase {
...
@@ -226,8 +226,9 @@ class RecurrentOp : public RecurrentBase {
const
framework
::
AttributeMap
&
attrs
)
const
framework
::
AttributeMap
&
attrs
)
:
RecurrentBase
(
type
,
inputs
,
outputs
,
attrs
)
{}
:
RecurrentBase
(
type
,
inputs
,
outputs
,
attrs
)
{}
void
Run
(
const
framework
::
Scope
&
scope
,
private:
const
platform
::
Place
&
place
)
const
override
{
void
RunImpl
(
const
framework
::
Scope
&
scope
,
const
platform
::
Place
&
place
)
const
override
{
auto
seq_len
=
static_cast
<
size_t
>
(
this
->
GetSequenceLength
(
scope
));
auto
seq_len
=
static_cast
<
size_t
>
(
this
->
GetSequenceLength
(
scope
));
VLOG
(
3
)
<<
"Static RNN input sequence length = "
<<
seq_len
;
VLOG
(
3
)
<<
"Static RNN input sequence length = "
<<
seq_len
;
StepScopes
scopes
=
CreateStepScopes
(
scope
,
seq_len
);
StepScopes
scopes
=
CreateStepScopes
(
scope
,
seq_len
);
...
@@ -315,8 +316,9 @@ class RecurrentGradOp : public RecurrentBase {
...
@@ -315,8 +316,9 @@ class RecurrentGradOp : public RecurrentBase {
const
framework
::
AttributeMap
&
attrs
)
const
framework
::
AttributeMap
&
attrs
)
:
RecurrentBase
(
type
,
inputs
,
outputs
,
attrs
)
{}
:
RecurrentBase
(
type
,
inputs
,
outputs
,
attrs
)
{}
void
Run
(
const
framework
::
Scope
&
scope
,
private:
const
platform
::
Place
&
place
)
const
override
{
void
RunImpl
(
const
framework
::
Scope
&
scope
,
const
platform
::
Place
&
place
)
const
override
{
auto
seq_len
=
static_cast
<
size_t
>
(
GetSequenceLength
(
scope
));
auto
seq_len
=
static_cast
<
size_t
>
(
GetSequenceLength
(
scope
));
StepScopes
scopes
=
CreateStepScopes
(
scope
,
seq_len
);
StepScopes
scopes
=
CreateStepScopes
(
scope
,
seq_len
);
auto
reverse
=
Attr
<
bool
>
(
kReverse
);
auto
reverse
=
Attr
<
bool
>
(
kReverse
);
...
...
paddle/operators/reorder_lod_tensor_by_rank_op.cc
浏览文件 @
98c94373
...
@@ -75,8 +75,10 @@ class ReorderLoDTensorByRankTableBase : public framework::OperatorBase {
...
@@ -75,8 +75,10 @@ class ReorderLoDTensorByRankTableBase : public framework::OperatorBase {
const
framework
::
VariableNameMap
&
outputs
,
const
framework
::
VariableNameMap
&
outputs
,
const
framework
::
AttributeMap
&
attrs
)
const
framework
::
AttributeMap
&
attrs
)
:
OperatorBase
(
type
,
inputs
,
outputs
,
attrs
)
{}
:
OperatorBase
(
type
,
inputs
,
outputs
,
attrs
)
{}
void
Run
(
const
framework
::
Scope
&
scope
,
const
platform
::
Place
&
place
)
const
override
{
private:
void
RunImpl
(
const
framework
::
Scope
&
scope
,
const
platform
::
Place
&
place
)
const
override
{
auto
&
x
=
auto
&
x
=
detail
::
Ref
(
scope
.
FindVar
(
Input
(
"X"
)),
detail
::
Ref
(
scope
.
FindVar
(
Input
(
"X"
)),
"Cannot find input lod tensor variable %s"
,
Input
(
"X"
))
"Cannot find input lod tensor variable %s"
,
Input
(
"X"
))
...
...
paddle/operators/rnn_memory_helper_op.cc
浏览文件 @
98c94373
...
@@ -24,8 +24,10 @@ class RNNMemoryHelperOp : public framework::OperatorBase {
...
@@ -24,8 +24,10 @@ class RNNMemoryHelperOp : public framework::OperatorBase {
const
framework
::
VariableNameMap
&
outputs
,
const
framework
::
VariableNameMap
&
outputs
,
const
framework
::
AttributeMap
&
attrs
)
const
framework
::
AttributeMap
&
attrs
)
:
OperatorBase
(
type
,
inputs
,
outputs
,
attrs
)
{}
:
OperatorBase
(
type
,
inputs
,
outputs
,
attrs
)
{}
void
Run
(
const
framework
::
Scope
&
scope
,
const
platform
::
Place
&
dev_place
)
const
override
{
private:
void
RunImpl
(
const
framework
::
Scope
&
scope
,
const
platform
::
Place
&
dev_place
)
const
override
{
auto
mem_var_name
=
Input
(
"X"
);
auto
mem_var_name
=
Input
(
"X"
);
auto
*
mem_var
=
scope
.
FindVar
(
mem_var_name
);
auto
*
mem_var
=
scope
.
FindVar
(
mem_var_name
);
PADDLE_ENFORCE
(
mem_var
!=
nullptr
,
PADDLE_ENFORCE
(
mem_var
!=
nullptr
,
...
@@ -76,8 +78,10 @@ class RNNMemoryHelperGradOp : public framework::OperatorBase {
...
@@ -76,8 +78,10 @@ class RNNMemoryHelperGradOp : public framework::OperatorBase {
const
framework
::
VariableNameMap
&
outputs
,
const
framework
::
VariableNameMap
&
outputs
,
const
framework
::
AttributeMap
&
attrs
)
const
framework
::
AttributeMap
&
attrs
)
:
OperatorBase
(
type
,
inputs
,
outputs
,
attrs
)
{}
:
OperatorBase
(
type
,
inputs
,
outputs
,
attrs
)
{}
void
Run
(
const
framework
::
Scope
&
scope
,
const
platform
::
Place
&
dev_place
)
const
override
{
private:
void
RunImpl
(
const
framework
::
Scope
&
scope
,
const
platform
::
Place
&
dev_place
)
const
override
{
auto
out_grad_var_name
=
Input
(
framework
::
GradVarName
(
"Out"
));
auto
out_grad_var_name
=
Input
(
framework
::
GradVarName
(
"Out"
));
auto
*
out_grad_var
=
scope
.
FindVar
(
out_grad_var_name
);
auto
*
out_grad_var
=
scope
.
FindVar
(
out_grad_var_name
);
...
...
paddle/operators/save_combine_op.cc
浏览文件 @
98c94373
...
@@ -63,8 +63,10 @@ class SaveCombineOp : public framework::OperatorBase {
...
@@ -63,8 +63,10 @@ class SaveCombineOp : public framework::OperatorBase {
const
framework
::
VariableNameMap
&
outputs
,
const
framework
::
VariableNameMap
&
outputs
,
const
framework
::
AttributeMap
&
attrs
)
const
framework
::
AttributeMap
&
attrs
)
:
OperatorBase
(
type
,
inputs
,
outputs
,
attrs
)
{}
:
OperatorBase
(
type
,
inputs
,
outputs
,
attrs
)
{}
void
Run
(
const
framework
::
Scope
&
scope
,
const
platform
::
Place
&
place
)
const
override
{
private:
void
RunImpl
(
const
framework
::
Scope
&
scope
,
const
platform
::
Place
&
place
)
const
override
{
auto
filename
=
Attr
<
std
::
string
>
(
"file_path"
);
auto
filename
=
Attr
<
std
::
string
>
(
"file_path"
);
auto
overwrite
=
Attr
<
bool
>
(
"overwrite"
);
auto
overwrite
=
Attr
<
bool
>
(
"overwrite"
);
...
...
paddle/operators/save_op.cc
浏览文件 @
98c94373
...
@@ -62,8 +62,10 @@ class SaveOp : public framework::OperatorBase {
...
@@ -62,8 +62,10 @@ class SaveOp : public framework::OperatorBase {
const
framework
::
VariableNameMap
&
outputs
,
const
framework
::
VariableNameMap
&
outputs
,
const
framework
::
AttributeMap
&
attrs
)
const
framework
::
AttributeMap
&
attrs
)
:
OperatorBase
(
type
,
inputs
,
outputs
,
attrs
)
{}
:
OperatorBase
(
type
,
inputs
,
outputs
,
attrs
)
{}
void
Run
(
const
framework
::
Scope
&
scope
,
const
platform
::
Place
&
place
)
const
override
{
private:
void
RunImpl
(
const
framework
::
Scope
&
scope
,
const
platform
::
Place
&
place
)
const
override
{
auto
filename
=
Attr
<
std
::
string
>
(
"file_path"
);
auto
filename
=
Attr
<
std
::
string
>
(
"file_path"
);
auto
overwrite
=
Attr
<
bool
>
(
"overwrite"
);
auto
overwrite
=
Attr
<
bool
>
(
"overwrite"
);
...
...
paddle/operators/shrink_rnn_memory_op.cc
浏览文件 @
98c94373
...
@@ -27,8 +27,9 @@ class ShrinkRNNMemoryOp : public ArrayOp {
...
@@ -27,8 +27,9 @@ class ShrinkRNNMemoryOp : public ArrayOp {
const
framework
::
AttributeMap
&
attrs
)
const
framework
::
AttributeMap
&
attrs
)
:
ArrayOp
(
type
,
inputs
,
outputs
,
attrs
)
{}
:
ArrayOp
(
type
,
inputs
,
outputs
,
attrs
)
{}
void
Run
(
const
framework
::
Scope
&
scope
,
private:
const
platform
::
Place
&
place
)
const
override
{
void
RunImpl
(
const
framework
::
Scope
&
scope
,
const
platform
::
Place
&
place
)
const
override
{
auto
*
x_var
=
scope
.
FindVar
(
Input
(
"X"
));
auto
*
x_var
=
scope
.
FindVar
(
Input
(
"X"
));
PADDLE_ENFORCE
(
x_var
!=
nullptr
,
"Input X must be set"
);
PADDLE_ENFORCE
(
x_var
!=
nullptr
,
"Input X must be set"
);
auto
&
x_tensor
=
x_var
->
Get
<
framework
::
LoDTensor
>
();
auto
&
x_tensor
=
x_var
->
Get
<
framework
::
LoDTensor
>
();
...
@@ -108,8 +109,9 @@ class ShrinkRNNMemoryGradOp : public ArrayOp {
...
@@ -108,8 +109,9 @@ class ShrinkRNNMemoryGradOp : public ArrayOp {
const
framework
::
AttributeMap
&
attrs
)
const
framework
::
AttributeMap
&
attrs
)
:
ArrayOp
(
type
,
inputs
,
outputs
,
attrs
)
{}
:
ArrayOp
(
type
,
inputs
,
outputs
,
attrs
)
{}
void
Run
(
const
framework
::
Scope
&
scope
,
private:
const
platform
::
Place
&
place
)
const
override
{
void
RunImpl
(
const
framework
::
Scope
&
scope
,
const
platform
::
Place
&
place
)
const
override
{
auto
*
dout_var
=
scope
.
FindVar
(
Input
(
framework
::
GradVarName
(
"Out"
)));
auto
*
dout_var
=
scope
.
FindVar
(
Input
(
framework
::
GradVarName
(
"Out"
)));
auto
*
dx_var
=
scope
.
FindVar
(
Output
(
framework
::
GradVarName
(
"X"
)));
auto
*
dx_var
=
scope
.
FindVar
(
Output
(
framework
::
GradVarName
(
"X"
)));
PADDLE_ENFORCE
(
dx_var
!=
nullptr
,
"Input Gradient should not be nullptr"
);
PADDLE_ENFORCE
(
dx_var
!=
nullptr
,
"Input Gradient should not be nullptr"
);
...
...
paddle/operators/split_lod_tensor_op.cc
浏览文件 @
98c94373
...
@@ -33,8 +33,10 @@ class SplitLoDTensorOp : public framework::OperatorBase {
...
@@ -33,8 +33,10 @@ class SplitLoDTensorOp : public framework::OperatorBase {
const
framework
::
VariableNameMap
&
outputs
,
const
framework
::
VariableNameMap
&
outputs
,
const
framework
::
AttributeMap
&
attrs
)
const
framework
::
AttributeMap
&
attrs
)
:
OperatorBase
(
type
,
inputs
,
outputs
,
attrs
)
{}
:
OperatorBase
(
type
,
inputs
,
outputs
,
attrs
)
{}
void
Run
(
const
framework
::
Scope
&
scope
,
const
platform
::
Place
&
dev_place
)
const
override
{
private:
void
RunImpl
(
const
framework
::
Scope
&
scope
,
const
platform
::
Place
&
dev_place
)
const
override
{
auto
&
x
=
scope
.
FindVar
(
Input
(
"X"
))
->
Get
<
framework
::
LoDTensor
>
();
auto
&
x
=
scope
.
FindVar
(
Input
(
"X"
))
->
Get
<
framework
::
LoDTensor
>
();
auto
&
mask
=
scope
.
FindVar
(
Input
(
"Mask"
))
->
Get
<
framework
::
LoDTensor
>
();
auto
&
mask
=
scope
.
FindVar
(
Input
(
"Mask"
))
->
Get
<
framework
::
LoDTensor
>
();
auto
*
out_true
=
auto
*
out_true
=
...
...
paddle/operators/tensor_array_read_write_op.cc
浏览文件 @
98c94373
...
@@ -24,8 +24,9 @@ class WriteToArrayOp : public ArrayOp {
...
@@ -24,8 +24,9 @@ class WriteToArrayOp : public ArrayOp {
const
framework
::
AttributeMap
&
attrs
)
const
framework
::
AttributeMap
&
attrs
)
:
ArrayOp
(
type
,
inputs
,
outputs
,
attrs
)
{}
:
ArrayOp
(
type
,
inputs
,
outputs
,
attrs
)
{}
void
Run
(
const
framework
::
Scope
&
scope
,
private:
const
platform
::
Place
&
place
)
const
override
{
void
RunImpl
(
const
framework
::
Scope
&
scope
,
const
platform
::
Place
&
place
)
const
override
{
auto
*
x
=
scope
.
FindVar
(
Input
(
"X"
));
auto
*
x
=
scope
.
FindVar
(
Input
(
"X"
));
if
(
x
==
nullptr
)
return
;
if
(
x
==
nullptr
)
return
;
auto
&
x_tensor
=
x
->
Get
<
framework
::
LoDTensor
>
();
auto
&
x_tensor
=
x
->
Get
<
framework
::
LoDTensor
>
();
...
@@ -122,8 +123,10 @@ class ReadFromArrayOp : public ArrayOp {
...
@@ -122,8 +123,10 @@ class ReadFromArrayOp : public ArrayOp {
const
framework
::
VariableNameMap
&
outputs
,
const
framework
::
VariableNameMap
&
outputs
,
const
framework
::
AttributeMap
&
attrs
)
const
framework
::
AttributeMap
&
attrs
)
:
ArrayOp
(
type
,
inputs
,
outputs
,
attrs
)
{}
:
ArrayOp
(
type
,
inputs
,
outputs
,
attrs
)
{}
void
Run
(
const
framework
::
Scope
&
scope
,
const
platform
::
Place
&
place
)
const
override
{
private:
void
RunImpl
(
const
framework
::
Scope
&
scope
,
const
platform
::
Place
&
place
)
const
override
{
auto
*
x
=
scope
.
FindVar
(
Input
(
"X"
));
auto
*
x
=
scope
.
FindVar
(
Input
(
"X"
));
PADDLE_ENFORCE
(
x
!=
nullptr
,
"X must be set"
);
PADDLE_ENFORCE
(
x
!=
nullptr
,
"X must be set"
);
auto
&
x_array
=
x
->
Get
<
framework
::
LoDTensorArray
>
();
auto
&
x_array
=
x
->
Get
<
framework
::
LoDTensorArray
>
();
...
...
paddle/operators/while_op.cc
浏览文件 @
98c94373
...
@@ -39,8 +39,9 @@ class WhileOp : public framework::OperatorBase {
...
@@ -39,8 +39,9 @@ class WhileOp : public framework::OperatorBase {
const
framework
::
AttributeMap
&
attrs
)
const
framework
::
AttributeMap
&
attrs
)
:
framework
::
OperatorBase
(
type
,
inputs
,
outputs
,
attrs
)
{}
:
framework
::
OperatorBase
(
type
,
inputs
,
outputs
,
attrs
)
{}
void
Run
(
const
framework
::
Scope
&
scope
,
private:
const
platform
::
Place
&
dev_place
)
const
override
{
void
RunImpl
(
const
framework
::
Scope
&
scope
,
const
platform
::
Place
&
dev_place
)
const
override
{
PADDLE_ENFORCE_NOT_NULL
(
scope
.
FindVar
(
Input
(
kCondition
)));
PADDLE_ENFORCE_NOT_NULL
(
scope
.
FindVar
(
Input
(
kCondition
)));
auto
&
cond
=
scope
.
FindVar
(
Input
(
kCondition
))
->
Get
<
LoDTensor
>
();
auto
&
cond
=
scope
.
FindVar
(
Input
(
kCondition
))
->
Get
<
LoDTensor
>
();
PADDLE_ENFORCE_EQ
(
cond
.
dims
(),
paddle
::
framework
::
make_ddim
({
1
}));
PADDLE_ENFORCE_EQ
(
cond
.
dims
(),
paddle
::
framework
::
make_ddim
({
1
}));
...
@@ -99,8 +100,9 @@ class WhileGradOp : public framework::OperatorBase {
...
@@ -99,8 +100,9 @@ class WhileGradOp : public framework::OperatorBase {
const
framework
::
AttributeMap
&
attrs
)
const
framework
::
AttributeMap
&
attrs
)
:
framework
::
OperatorBase
(
type
,
inputs
,
outputs
,
attrs
)
{}
:
framework
::
OperatorBase
(
type
,
inputs
,
outputs
,
attrs
)
{}
void
Run
(
const
framework
::
Scope
&
scope
,
private:
const
platform
::
Place
&
dev_place
)
const
override
{
void
RunImpl
(
const
framework
::
Scope
&
scope
,
const
platform
::
Place
&
dev_place
)
const
override
{
// get device context from pool
// get device context from pool
platform
::
DeviceContextPool
&
pool
=
platform
::
DeviceContextPool
::
Instance
();
platform
::
DeviceContextPool
&
pool
=
platform
::
DeviceContextPool
::
Instance
();
auto
&
dev_ctx
=
*
pool
.
Get
(
dev_place
);
auto
&
dev_ctx
=
*
pool
.
Get
(
dev_place
);
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录