Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
Crayon鑫
Paddle
提交
98c94373
P
Paddle
项目概览
Crayon鑫
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
98c94373
编写于
2月 09, 2018
作者:
Y
Yang Yang
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
use op run as wrapper of run_impl; make run_impl as private virtual function
上级
f605d00f
变更
41
显示空白变更内容
内联
并排
Showing
41 changed file
with
214 addition
and
114 deletion
+214
-114
paddle/framework/op_registry_test.cc
paddle/framework/op_registry_test.cc
+8
-2
paddle/framework/operator.cc
paddle/framework/operator.cc
+14
-2
paddle/framework/operator.h
paddle/framework/operator.h
+10
-5
paddle/framework/operator_test.cc
paddle/framework/operator_test.cc
+8
-3
paddle/operators/array_to_lod_tensor_op.cc
paddle/operators/array_to_lod_tensor_op.cc
+4
-2
paddle/operators/assign_op.cc
paddle/operators/assign_op.cc
+4
-2
paddle/operators/beam_search_decode_op.cc
paddle/operators/beam_search_decode_op.cc
+4
-2
paddle/operators/beam_search_op.h
paddle/operators/beam_search_op.h
+3
-2
paddle/operators/cond_op.cc
paddle/operators/cond_op.cc
+1
-1
paddle/operators/cond_op.h
paddle/operators/cond_op.h
+3
-2
paddle/operators/conditional_block_op.cc
paddle/operators/conditional_block_op.cc
+8
-4
paddle/operators/create_reader_op.cc
paddle/operators/create_reader_op.cc
+12
-6
paddle/operators/feed_op.cc
paddle/operators/feed_op.cc
+4
-2
paddle/operators/fetch_op.cc
paddle/operators/fetch_op.cc
+3
-2
paddle/operators/fill_constant_op.cc
paddle/operators/fill_constant_op.cc
+4
-2
paddle/operators/fill_op.cc
paddle/operators/fill_op.cc
+4
-2
paddle/operators/get_places_op.cc
paddle/operators/get_places_op.cc
+4
-2
paddle/operators/increment_op.cc
paddle/operators/increment_op.cc
+3
-2
paddle/operators/is_empty_op.cc
paddle/operators/is_empty_op.cc
+3
-2
paddle/operators/load_combine_op.cc
paddle/operators/load_combine_op.cc
+4
-2
paddle/operators/load_op.cc
paddle/operators/load_op.cc
+4
-2
paddle/operators/lod_array_length_op.cc
paddle/operators/lod_array_length_op.cc
+4
-2
paddle/operators/lod_rank_table_op.cc
paddle/operators/lod_rank_table_op.cc
+4
-2
paddle/operators/lod_tensor_to_array_op.cc
paddle/operators/lod_tensor_to_array_op.cc
+4
-2
paddle/operators/max_sequence_len_op.cc
paddle/operators/max_sequence_len_op.cc
+3
-2
paddle/operators/merge_lod_tensor_op.cc
paddle/operators/merge_lod_tensor_op.cc
+4
-2
paddle/operators/nccl_op.cc
paddle/operators/nccl_op.cc
+3
-2
paddle/operators/net_op.h
paddle/operators/net_op.h
+14
-14
paddle/operators/net_op_test.cc
paddle/operators/net_op_test.cc
+4
-1
paddle/operators/parallel_do_op.cc
paddle/operators/parallel_do_op.cc
+6
-4
paddle/operators/print_op.cc
paddle/operators/print_op.cc
+3
-2
paddle/operators/read_op.cc
paddle/operators/read_op.cc
+4
-2
paddle/operators/recurrent_op.cc
paddle/operators/recurrent_op.cc
+6
-4
paddle/operators/reorder_lod_tensor_by_rank_op.cc
paddle/operators/reorder_lod_tensor_by_rank_op.cc
+4
-2
paddle/operators/rnn_memory_helper_op.cc
paddle/operators/rnn_memory_helper_op.cc
+8
-4
paddle/operators/save_combine_op.cc
paddle/operators/save_combine_op.cc
+4
-2
paddle/operators/save_op.cc
paddle/operators/save_op.cc
+4
-2
paddle/operators/shrink_rnn_memory_op.cc
paddle/operators/shrink_rnn_memory_op.cc
+6
-4
paddle/operators/split_lod_tensor_op.cc
paddle/operators/split_lod_tensor_op.cc
+4
-2
paddle/operators/tensor_array_read_write_op.cc
paddle/operators/tensor_array_read_write_op.cc
+7
-4
paddle/operators/while_op.cc
paddle/operators/while_op.cc
+6
-4
未找到文件。
paddle/framework/op_registry_test.cc
浏览文件 @
98c94373
...
@@ -25,7 +25,10 @@ namespace framework {
...
@@ -25,7 +25,10 @@ namespace framework {
class
CosineOp
:
public
OperatorBase
{
class
CosineOp
:
public
OperatorBase
{
public:
public:
using
OperatorBase
::
OperatorBase
;
using
OperatorBase
::
OperatorBase
;
void
Run
(
const
Scope
&
scope
,
const
platform
::
Place
&
place
)
const
override
{}
private:
void
RunImpl
(
const
Scope
&
scope
,
const
platform
::
Place
&
place
)
const
override
{}
};
};
class
CosineOpProtoAndCheckerMaker
:
public
OpProtoAndCheckerMaker
{
class
CosineOpProtoAndCheckerMaker
:
public
OpProtoAndCheckerMaker
{
...
@@ -44,7 +47,10 @@ class CosineOpProtoAndCheckerMaker : public OpProtoAndCheckerMaker {
...
@@ -44,7 +47,10 @@ class CosineOpProtoAndCheckerMaker : public OpProtoAndCheckerMaker {
class
MyTestOp
:
public
OperatorBase
{
class
MyTestOp
:
public
OperatorBase
{
public:
public:
using
OperatorBase
::
OperatorBase
;
using
OperatorBase
::
OperatorBase
;
void
Run
(
const
Scope
&
scope
,
const
platform
::
Place
&
place
)
const
override
{}
private:
void
RunImpl
(
const
Scope
&
scope
,
const
platform
::
Place
&
place
)
const
override
{}
};
};
class
MyTestOpProtoAndCheckerMaker
:
public
OpProtoAndCheckerMaker
{
class
MyTestOpProtoAndCheckerMaker
:
public
OpProtoAndCheckerMaker
{
...
...
paddle/framework/operator.cc
浏览文件 @
98c94373
...
@@ -64,6 +64,18 @@ static LoD GetLoD(const Scope& scope, const std::string& name) {
...
@@ -64,6 +64,18 @@ static LoD GetLoD(const Scope& scope, const std::string& name) {
}
}
}
}
void
OperatorBase
::
Run
(
const
Scope
&
scope
,
const
platform
::
Place
&
place
)
{
if
(
platform
::
is_gpu_place
(
place
))
{
#ifndef PADDLE_WITH_CUDA
PADDLE_THROW
(
"Cannot run operator on place %s"
,
place
);
#else
auto
dev_id
=
boost
::
get
<
platform
::
CUDAPlace
>
(
place
).
device
;
platform
::
SetDeviceId
(
dev_id
);
#endif
}
RunImpl
(
scope
,
place
);
}
std
::
string
OperatorBase
::
Input
(
const
std
::
string
&
name
)
const
{
std
::
string
OperatorBase
::
Input
(
const
std
::
string
&
name
)
const
{
auto
&
ins
=
Inputs
(
name
);
auto
&
ins
=
Inputs
(
name
);
PADDLE_ENFORCE_LE
(
ins
.
size
(),
1UL
,
PADDLE_ENFORCE_LE
(
ins
.
size
(),
1UL
,
...
@@ -475,7 +487,7 @@ class RuntimeInferShapeContext : public InferShapeContext {
...
@@ -475,7 +487,7 @@ class RuntimeInferShapeContext : public InferShapeContext {
const
Scope
&
scope_
;
const
Scope
&
scope_
;
};
};
void
OperatorWithKernel
::
Run
(
const
Scope
&
scope
,
void
OperatorWithKernel
::
Run
Impl
(
const
Scope
&
scope
,
const
platform
::
Place
&
place
)
const
{
const
platform
::
Place
&
place
)
const
{
RuntimeInferShapeContext
infer_shape_ctx
(
*
this
,
scope
);
RuntimeInferShapeContext
infer_shape_ctx
(
*
this
,
scope
);
this
->
InferShape
(
&
infer_shape_ctx
);
this
->
InferShape
(
&
infer_shape_ctx
);
...
...
paddle/framework/operator.h
浏览文件 @
98c94373
...
@@ -89,8 +89,9 @@ class OperatorBase {
...
@@ -89,8 +89,9 @@ class OperatorBase {
std
::
string
DebugString
()
const
{
return
DebugStringEx
(
nullptr
);
}
std
::
string
DebugString
()
const
{
return
DebugStringEx
(
nullptr
);
}
/// Net will call this function to Run an op.
/// Net will call this interface function to Run an op.
virtual
void
Run
(
const
Scope
&
scope
,
const
platform
::
Place
&
place
)
const
=
0
;
// The implementation should be written at RunImpl
void
Run
(
const
Scope
&
scope
,
const
platform
::
Place
&
place
);
// FIXME(typhoonzero): this is only used for recv_op to stop event_loop.
// FIXME(typhoonzero): this is only used for recv_op to stop event_loop.
virtual
void
Stop
()
{}
virtual
void
Stop
()
{}
...
@@ -144,6 +145,8 @@ class OperatorBase {
...
@@ -144,6 +145,8 @@ class OperatorBase {
private:
private:
void
GenerateTemporaryNames
();
void
GenerateTemporaryNames
();
void
CheckAllInputOutputSet
()
const
;
void
CheckAllInputOutputSet
()
const
;
virtual
void
RunImpl
(
const
Scope
&
scope
,
const
platform
::
Place
&
place
)
const
=
0
;
};
};
// Macro for define a clone method.
// Macro for define a clone method.
...
@@ -168,10 +171,13 @@ class OperatorBase {
...
@@ -168,10 +171,13 @@ class OperatorBase {
class
NOP
:
public
OperatorBase
{
class
NOP
:
public
OperatorBase
{
public:
public:
using
OperatorBase
::
OperatorBase
;
using
OperatorBase
::
OperatorBase
;
void
Run
(
const
Scope
&
scope
,
const
platform
::
Place
&
place
)
const
override
{}
std
::
unique_ptr
<
OperatorBase
>
Clone
()
const
override
{
std
::
unique_ptr
<
OperatorBase
>
Clone
()
const
override
{
return
std
::
unique_ptr
<
OperatorBase
>
(
new
NOP
(
*
this
));
return
std
::
unique_ptr
<
OperatorBase
>
(
new
NOP
(
*
this
));
}
}
private:
void
RunImpl
(
const
Scope
&
scope
,
const
platform
::
Place
&
place
)
const
override
{}
};
};
class
ExecutionContext
{
class
ExecutionContext
{
...
@@ -363,8 +369,6 @@ class OperatorWithKernel : public OperatorBase {
...
@@ -363,8 +369,6 @@ class OperatorWithKernel : public OperatorBase {
const
VariableNameMap
&
outputs
,
const
AttributeMap
&
attrs
)
const
VariableNameMap
&
outputs
,
const
AttributeMap
&
attrs
)
:
OperatorBase
(
type
,
inputs
,
outputs
,
attrs
)
{}
:
OperatorBase
(
type
,
inputs
,
outputs
,
attrs
)
{}
void
Run
(
const
Scope
&
scope
,
const
platform
::
Place
&
place
)
const
final
;
static
std
::
unordered_map
<
std
::
string
/* op_type */
,
OpKernelMap
>&
static
std
::
unordered_map
<
std
::
string
/* op_type */
,
OpKernelMap
>&
AllOpKernels
()
{
AllOpKernels
()
{
static
std
::
unordered_map
<
std
::
string
,
OpKernelMap
>
g_all_op_kernels
;
static
std
::
unordered_map
<
std
::
string
,
OpKernelMap
>
g_all_op_kernels
;
...
@@ -393,6 +397,7 @@ class OperatorWithKernel : public OperatorBase {
...
@@ -393,6 +397,7 @@ class OperatorWithKernel : public OperatorBase {
// indicate kernel DataType by input data. Defaultly all input data must be
// indicate kernel DataType by input data. Defaultly all input data must be
// same.
// same.
proto
::
DataType
IndicateDataType
(
const
ExecutionContext
&
ctx
)
const
;
proto
::
DataType
IndicateDataType
(
const
ExecutionContext
&
ctx
)
const
;
void
RunImpl
(
const
Scope
&
scope
,
const
platform
::
Place
&
place
)
const
final
;
};
};
extern
bool
OpSupportGPU
(
const
std
::
string
&
op_type
);
extern
bool
OpSupportGPU
(
const
std
::
string
&
op_type
);
...
...
paddle/framework/operator_test.cc
浏览文件 @
98c94373
...
@@ -28,7 +28,10 @@ class OpWithoutKernelTest : public OperatorBase {
...
@@ -28,7 +28,10 @@ class OpWithoutKernelTest : public OperatorBase {
OpWithoutKernelTest
(
const
std
::
string
&
type
,
const
VariableNameMap
&
inputs
,
OpWithoutKernelTest
(
const
std
::
string
&
type
,
const
VariableNameMap
&
inputs
,
const
VariableNameMap
&
outputs
,
const
AttributeMap
&
attrs
)
const
VariableNameMap
&
outputs
,
const
AttributeMap
&
attrs
)
:
OperatorBase
(
type
,
inputs
,
outputs
,
attrs
),
x
(
1
)
{}
:
OperatorBase
(
type
,
inputs
,
outputs
,
attrs
),
x
(
1
)
{}
void
Run
(
const
Scope
&
scope
,
const
platform
::
Place
&
place
)
const
override
{
private:
void
RunImpl
(
const
Scope
&
scope
,
const
platform
::
Place
&
place
)
const
override
{
++
op_run_num
;
++
op_run_num
;
ASSERT_EQ
(
static_cast
<
int
>
(
inputs_
.
size
()),
1
);
ASSERT_EQ
(
static_cast
<
int
>
(
inputs_
.
size
()),
1
);
ASSERT_EQ
(
static_cast
<
int
>
(
outputs_
.
size
()),
1
);
ASSERT_EQ
(
static_cast
<
int
>
(
outputs_
.
size
()),
1
);
...
@@ -259,7 +262,9 @@ class OperatorClone : public paddle::framework::OperatorBase {
...
@@ -259,7 +262,9 @@ class OperatorClone : public paddle::framework::OperatorBase {
const
paddle
::
framework
::
VariableNameMap
&
outputs
,
const
paddle
::
framework
::
VariableNameMap
&
outputs
,
const
paddle
::
framework
::
AttributeMap
&
attrs
)
const
paddle
::
framework
::
AttributeMap
&
attrs
)
:
OperatorBase
(
type
,
inputs
,
outputs
,
attrs
)
{}
:
OperatorBase
(
type
,
inputs
,
outputs
,
attrs
)
{}
void
Run
(
const
paddle
::
framework
::
Scope
&
scope
,
private:
void
RunImpl
(
const
paddle
::
framework
::
Scope
&
scope
,
const
paddle
::
platform
::
Place
&
place
)
const
override
{}
const
paddle
::
platform
::
Place
&
place
)
const
override
{}
};
};
...
...
paddle/operators/array_to_lod_tensor_op.cc
浏览文件 @
98c94373
...
@@ -31,7 +31,9 @@ class ArrayToLoDTensorOp : public framework::OperatorBase {
...
@@ -31,7 +31,9 @@ class ArrayToLoDTensorOp : public framework::OperatorBase {
const
framework
::
VariableNameMap
&
outputs
,
const
framework
::
VariableNameMap
&
outputs
,
const
framework
::
AttributeMap
&
attrs
)
const
framework
::
AttributeMap
&
attrs
)
:
OperatorBase
(
type
,
inputs
,
outputs
,
attrs
)
{}
:
OperatorBase
(
type
,
inputs
,
outputs
,
attrs
)
{}
void
Run
(
const
framework
::
Scope
&
scope
,
private:
void
RunImpl
(
const
framework
::
Scope
&
scope
,
const
platform
::
Place
&
dev_place
)
const
override
{
const
platform
::
Place
&
dev_place
)
const
override
{
auto
&
x
=
scope
.
FindVar
(
Input
(
"X"
))
->
Get
<
framework
::
LoDTensorArray
>
();
auto
&
x
=
scope
.
FindVar
(
Input
(
"X"
))
->
Get
<
framework
::
LoDTensorArray
>
();
auto
&
rank_table
=
auto
&
rank_table
=
...
...
paddle/operators/assign_op.cc
浏览文件 @
98c94373
...
@@ -71,7 +71,9 @@ class AssignOp : public framework::OperatorBase {
...
@@ -71,7 +71,9 @@ class AssignOp : public framework::OperatorBase {
const
framework
::
VariableNameMap
&
outputs
,
const
framework
::
VariableNameMap
&
outputs
,
const
framework
::
AttributeMap
&
attrs
)
const
framework
::
AttributeMap
&
attrs
)
:
OperatorBase
(
type
,
inputs
,
outputs
,
attrs
)
{}
:
OperatorBase
(
type
,
inputs
,
outputs
,
attrs
)
{}
void
Run
(
const
framework
::
Scope
&
scope
,
private:
void
RunImpl
(
const
framework
::
Scope
&
scope
,
const
platform
::
Place
&
place
)
const
override
{
const
platform
::
Place
&
place
)
const
override
{
auto
*
x
=
scope
.
FindVar
(
Input
(
"X"
));
auto
*
x
=
scope
.
FindVar
(
Input
(
"X"
));
if
(
x
==
nullptr
)
{
if
(
x
==
nullptr
)
{
...
...
paddle/operators/beam_search_decode_op.cc
浏览文件 @
98c94373
...
@@ -55,7 +55,9 @@ class BeamSearchDecodeOp : public framework::OperatorBase {
...
@@ -55,7 +55,9 @@ class BeamSearchDecodeOp : public framework::OperatorBase {
const
framework
::
VariableNameMap
&
outputs
,
const
framework
::
VariableNameMap
&
outputs
,
const
framework
::
AttributeMap
&
attrs
)
const
framework
::
AttributeMap
&
attrs
)
:
OperatorBase
(
type
,
inputs
,
outputs
,
attrs
)
{}
:
OperatorBase
(
type
,
inputs
,
outputs
,
attrs
)
{}
void
Run
(
const
framework
::
Scope
&
scope
,
private:
void
RunImpl
(
const
framework
::
Scope
&
scope
,
const
platform
::
Place
&
dev_place
)
const
override
{
const
platform
::
Place
&
dev_place
)
const
override
{
platform
::
DeviceContextPool
&
pool
=
platform
::
DeviceContextPool
::
Instance
();
platform
::
DeviceContextPool
&
pool
=
platform
::
DeviceContextPool
::
Instance
();
auto
&
dev_ctx
=
*
pool
.
Get
(
dev_place
);
auto
&
dev_ctx
=
*
pool
.
Get
(
dev_place
);
...
...
paddle/operators/beam_search_op.h
浏览文件 @
98c94373
...
@@ -204,7 +204,8 @@ class BeamSearchOp : public framework::OperatorBase {
...
@@ -204,7 +204,8 @@ class BeamSearchOp : public framework::OperatorBase {
PADDLE_THROW
(
"Not Implemented"
);
PADDLE_THROW
(
"Not Implemented"
);
}
}
void
Run
(
const
framework
::
Scope
&
scope
,
private:
void
RunImpl
(
const
framework
::
Scope
&
scope
,
const
platform
::
Place
&
dev_place
)
const
override
{
const
platform
::
Place
&
dev_place
)
const
override
{
auto
ids_var
=
scope
.
FindVar
(
Input
(
"ids"
));
auto
ids_var
=
scope
.
FindVar
(
Input
(
"ids"
));
auto
scores_var
=
scope
.
FindVar
(
Input
(
"scores"
));
auto
scores_var
=
scope
.
FindVar
(
Input
(
"scores"
));
...
...
paddle/operators/cond_op.cc
浏览文件 @
98c94373
...
@@ -193,7 +193,7 @@ void CondOp::MergeDataFromSubnet(const framework::Scope& scope,
...
@@ -193,7 +193,7 @@ void CondOp::MergeDataFromSubnet(const framework::Scope& scope,
}
}
}
}
void
CondOp
::
Run
(
const
Scope
&
scope
,
const
platform
::
Place
&
place
)
const
{
void
CondOp
::
Run
Impl
(
const
Scope
&
scope
,
const
platform
::
Place
&
place
)
const
{
// get device context from pool
// get device context from pool
platform
::
DeviceContextPool
&
pool
=
platform
::
DeviceContextPool
::
Instance
();
platform
::
DeviceContextPool
&
pool
=
platform
::
DeviceContextPool
::
Instance
();
auto
&
dev_ctx
=
*
pool
.
Get
(
place
);
auto
&
dev_ctx
=
*
pool
.
Get
(
place
);
...
...
paddle/operators/cond_op.h
浏览文件 @
98c94373
...
@@ -77,7 +77,8 @@ class CondOp : public framework::OperatorBase {
...
@@ -77,7 +77,8 @@ class CondOp : public framework::OperatorBase {
sub_net_op_
[
FALSE_BRANCH
]
=
std
::
move
(
net
);
sub_net_op_
[
FALSE_BRANCH
]
=
std
::
move
(
net
);
}
}
void
Run
(
const
framework
::
Scope
&
scope
,
private:
void
RunImpl
(
const
framework
::
Scope
&
scope
,
const
platform
::
Place
&
place
)
const
override
;
const
platform
::
Place
&
place
)
const
override
;
private:
private:
...
...
paddle/operators/conditional_block_op.cc
浏览文件 @
98c94373
...
@@ -65,7 +65,9 @@ class ConditionalBlockOp : public ConditionalOp {
...
@@ -65,7 +65,9 @@ class ConditionalBlockOp : public ConditionalOp {
const
framework
::
VariableNameMap
&
outputs
,
const
framework
::
VariableNameMap
&
outputs
,
const
framework
::
AttributeMap
&
attrs
)
const
framework
::
AttributeMap
&
attrs
)
:
ConditionalOp
(
type
,
inputs
,
outputs
,
attrs
)
{}
:
ConditionalOp
(
type
,
inputs
,
outputs
,
attrs
)
{}
void
Run
(
const
framework
::
Scope
&
scope
,
private:
void
RunImpl
(
const
framework
::
Scope
&
scope
,
const
platform
::
Place
&
dev_place
)
const
override
{
const
platform
::
Place
&
dev_place
)
const
override
{
auto
xs
=
InputTensors
(
scope
);
auto
xs
=
InputTensors
(
scope
);
...
@@ -128,7 +130,9 @@ class ConditionalBlockGradOp : public ConditionalOp {
...
@@ -128,7 +130,9 @@ class ConditionalBlockGradOp : public ConditionalOp {
const
framework
::
VariableNameMap
&
outputs
,
const
framework
::
VariableNameMap
&
outputs
,
const
framework
::
AttributeMap
&
attrs
)
const
framework
::
AttributeMap
&
attrs
)
:
ConditionalOp
(
type
,
inputs
,
outputs
,
attrs
)
{}
:
ConditionalOp
(
type
,
inputs
,
outputs
,
attrs
)
{}
void
Run
(
const
framework
::
Scope
&
scope
,
private:
void
RunImpl
(
const
framework
::
Scope
&
scope
,
const
platform
::
Place
&
dev_place
)
const
override
{
const
platform
::
Place
&
dev_place
)
const
override
{
auto
xs
=
this
->
InputTensors
(
scope
);
auto
xs
=
this
->
InputTensors
(
scope
);
...
...
paddle/operators/create_reader_op.cc
浏览文件 @
98c94373
...
@@ -72,7 +72,9 @@ template <typename T>
...
@@ -72,7 +72,9 @@ template <typename T>
class
CreateRandomDataGeneratorOp
:
public
framework
::
OperatorBase
{
class
CreateRandomDataGeneratorOp
:
public
framework
::
OperatorBase
{
public:
public:
using
framework
::
OperatorBase
::
OperatorBase
;
using
framework
::
OperatorBase
::
OperatorBase
;
void
Run
(
const
framework
::
Scope
&
scope
,
private:
void
RunImpl
(
const
framework
::
Scope
&
scope
,
const
platform
::
Place
&
dev_place
)
const
override
{
const
platform
::
Place
&
dev_place
)
const
override
{
const
auto
&
shape_concat
=
Attr
<
std
::
vector
<
int
>>
(
"shape_concat"
);
const
auto
&
shape_concat
=
Attr
<
std
::
vector
<
int
>>
(
"shape_concat"
);
const
auto
&
ranks
=
Attr
<
std
::
vector
<
int
>>
(
"ranks"
);
const
auto
&
ranks
=
Attr
<
std
::
vector
<
int
>>
(
"ranks"
);
...
@@ -120,7 +122,9 @@ class CreateRandomDataGeneratorOpMaker
...
@@ -120,7 +122,9 @@ class CreateRandomDataGeneratorOpMaker
class
CreateShuffleReaderOp
:
public
framework
::
OperatorBase
{
class
CreateShuffleReaderOp
:
public
framework
::
OperatorBase
{
public:
public:
using
framework
::
OperatorBase
::
OperatorBase
;
using
framework
::
OperatorBase
::
OperatorBase
;
void
Run
(
const
framework
::
Scope
&
scope
,
private:
void
RunImpl
(
const
framework
::
Scope
&
scope
,
const
platform
::
Place
&
dev_place
)
const
override
{
const
platform
::
Place
&
dev_place
)
const
override
{
const
auto
&
underlying_reader
=
scope
.
FindVar
(
Input
(
"UnderlyingReader"
))
const
auto
&
underlying_reader
=
scope
.
FindVar
(
Input
(
"UnderlyingReader"
))
->
Get
<
framework
::
ReaderHolder
>
();
->
Get
<
framework
::
ReaderHolder
>
();
...
@@ -152,7 +156,9 @@ class CreateShuffleReaderOpMaker : public framework::OpProtoAndCheckerMaker {
...
@@ -152,7 +156,9 @@ class CreateShuffleReaderOpMaker : public framework::OpProtoAndCheckerMaker {
class
CreateBatchReaderOp
:
public
framework
::
OperatorBase
{
class
CreateBatchReaderOp
:
public
framework
::
OperatorBase
{
public:
public:
using
framework
::
OperatorBase
::
OperatorBase
;
using
framework
::
OperatorBase
::
OperatorBase
;
void
Run
(
const
framework
::
Scope
&
scope
,
private:
void
RunImpl
(
const
framework
::
Scope
&
scope
,
const
platform
::
Place
&
dev_place
)
const
override
{
const
platform
::
Place
&
dev_place
)
const
override
{
const
auto
&
underlying_reader
=
scope
.
FindVar
(
Input
(
"UnderlyingReader"
))
const
auto
&
underlying_reader
=
scope
.
FindVar
(
Input
(
"UnderlyingReader"
))
->
Get
<
framework
::
ReaderHolder
>
();
->
Get
<
framework
::
ReaderHolder
>
();
...
...
paddle/operators/feed_op.cc
浏览文件 @
98c94373
...
@@ -24,7 +24,9 @@ class FeedOp : public framework::OperatorBase {
...
@@ -24,7 +24,9 @@ class FeedOp : public framework::OperatorBase {
const
framework
::
VariableNameMap
&
outputs
,
const
framework
::
VariableNameMap
&
outputs
,
const
framework
::
AttributeMap
&
attrs
)
const
framework
::
AttributeMap
&
attrs
)
:
OperatorBase
(
type
,
inputs
,
outputs
,
attrs
)
{}
:
OperatorBase
(
type
,
inputs
,
outputs
,
attrs
)
{}
void
Run
(
const
framework
::
Scope
&
scope
,
private:
void
RunImpl
(
const
framework
::
Scope
&
scope
,
const
platform
::
Place
&
place
)
const
override
{
const
platform
::
Place
&
place
)
const
override
{
auto
feed_var_name
=
Input
(
"X"
);
auto
feed_var_name
=
Input
(
"X"
);
auto
*
feed_var
=
scope
.
FindVar
(
feed_var_name
);
auto
*
feed_var
=
scope
.
FindVar
(
feed_var_name
);
...
...
paddle/operators/fetch_op.cc
浏览文件 @
98c94373
...
@@ -26,7 +26,8 @@ class FetchOp : public framework::OperatorBase {
...
@@ -26,7 +26,8 @@ class FetchOp : public framework::OperatorBase {
const
framework
::
AttributeMap
&
attrs
)
const
framework
::
AttributeMap
&
attrs
)
:
OperatorBase
(
type
,
inputs
,
outputs
,
attrs
)
{}
:
OperatorBase
(
type
,
inputs
,
outputs
,
attrs
)
{}
void
Run
(
const
framework
::
Scope
&
scope
,
private:
void
RunImpl
(
const
framework
::
Scope
&
scope
,
const
platform
::
Place
&
place
)
const
override
{
const
platform
::
Place
&
place
)
const
override
{
auto
fetch_var_name
=
Input
(
"X"
);
auto
fetch_var_name
=
Input
(
"X"
);
auto
*
fetch_var
=
scope
.
FindVar
(
fetch_var_name
);
auto
*
fetch_var
=
scope
.
FindVar
(
fetch_var_name
);
...
...
paddle/operators/fill_constant_op.cc
浏览文件 @
98c94373
...
@@ -33,7 +33,9 @@ class FillConstantInferShape : public framework::InferShapeBase {
...
@@ -33,7 +33,9 @@ class FillConstantInferShape : public framework::InferShapeBase {
class
FillConstantOp
:
public
framework
::
OperatorBase
{
class
FillConstantOp
:
public
framework
::
OperatorBase
{
public:
public:
using
framework
::
OperatorBase
::
OperatorBase
;
using
framework
::
OperatorBase
::
OperatorBase
;
void
Run
(
const
framework
::
Scope
&
scope
,
private:
void
RunImpl
(
const
framework
::
Scope
&
scope
,
const
platform
::
Place
&
dev_place
)
const
override
{
const
platform
::
Place
&
dev_place
)
const
override
{
auto
data_type
=
auto
data_type
=
static_cast
<
framework
::
proto
::
DataType
>
(
Attr
<
int
>
(
"dtype"
));
static_cast
<
framework
::
proto
::
DataType
>
(
Attr
<
int
>
(
"dtype"
));
...
...
paddle/operators/fill_op.cc
浏览文件 @
98c94373
...
@@ -42,7 +42,9 @@ class FillOp : public framework::OperatorBase {
...
@@ -42,7 +42,9 @@ class FillOp : public framework::OperatorBase {
const
framework
::
VariableNameMap
&
outputs
,
const
framework
::
VariableNameMap
&
outputs
,
const
framework
::
AttributeMap
&
attrs
)
const
framework
::
AttributeMap
&
attrs
)
:
OperatorBase
(
type
,
inputs
,
outputs
,
attrs
)
{}
:
OperatorBase
(
type
,
inputs
,
outputs
,
attrs
)
{}
void
Run
(
const
framework
::
Scope
&
scope
,
private:
void
RunImpl
(
const
framework
::
Scope
&
scope
,
const
platform
::
Place
&
place
)
const
override
{
const
platform
::
Place
&
place
)
const
override
{
auto
&
out
=
auto
&
out
=
detail
::
Ref
(
detail
::
Ref
(
scope
.
FindVar
(
Output
(
"Out"
)),
detail
::
Ref
(
detail
::
Ref
(
scope
.
FindVar
(
Output
(
"Out"
)),
...
...
paddle/operators/get_places_op.cc
浏览文件 @
98c94373
...
@@ -37,7 +37,9 @@ class GetPlacesOp : public framework::OperatorBase {
...
@@ -37,7 +37,9 @@ class GetPlacesOp : public framework::OperatorBase {
const
framework
::
VariableNameMap
&
outputs
,
const
framework
::
VariableNameMap
&
outputs
,
const
framework
::
AttributeMap
&
attrs
)
const
framework
::
AttributeMap
&
attrs
)
:
OperatorBase
(
type
,
inputs
,
outputs
,
attrs
)
{}
:
OperatorBase
(
type
,
inputs
,
outputs
,
attrs
)
{}
void
Run
(
const
framework
::
Scope
&
scope
,
private:
void
RunImpl
(
const
framework
::
Scope
&
scope
,
const
platform
::
Place
&
place
)
const
override
{
const
platform
::
Place
&
place
)
const
override
{
bool
is_gpu
;
bool
is_gpu
;
if
(
Attr
<
std
::
string
>
(
"device_type"
)
==
"AUTO"
)
{
if
(
Attr
<
std
::
string
>
(
"device_type"
)
==
"AUTO"
)
{
...
...
paddle/operators/increment_op.cc
浏览文件 @
98c94373
...
@@ -51,7 +51,8 @@ class IncrementOp : public framework::OperatorBase {
...
@@ -51,7 +51,8 @@ class IncrementOp : public framework::OperatorBase {
const
framework
::
AttributeMap
&
attrs
)
const
framework
::
AttributeMap
&
attrs
)
:
OperatorBase
(
type
,
inputs
,
outputs
,
attrs
)
{}
:
OperatorBase
(
type
,
inputs
,
outputs
,
attrs
)
{}
void
Run
(
const
framework
::
Scope
&
scope
,
private:
void
RunImpl
(
const
framework
::
Scope
&
scope
,
const
platform
::
Place
&
place
)
const
override
{
const
platform
::
Place
&
place
)
const
override
{
auto
&
x
=
scope
.
FindVar
(
Input
(
"X"
))
->
Get
<
framework
::
LoDTensor
>
();
auto
&
x
=
scope
.
FindVar
(
Input
(
"X"
))
->
Get
<
framework
::
LoDTensor
>
();
auto
&
out
=
auto
&
out
=
...
...
paddle/operators/is_empty_op.cc
浏览文件 @
98c94373
...
@@ -28,7 +28,8 @@ class IsEmptyOp : public framework::OperatorBase {
...
@@ -28,7 +28,8 @@ class IsEmptyOp : public framework::OperatorBase {
const
framework
::
AttributeMap
&
attrs
)
const
framework
::
AttributeMap
&
attrs
)
:
OperatorBase
(
type
,
inputs
,
outputs
,
attrs
)
{}
:
OperatorBase
(
type
,
inputs
,
outputs
,
attrs
)
{}
void
Run
(
const
framework
::
Scope
&
scope
,
private:
void
RunImpl
(
const
framework
::
Scope
&
scope
,
const
platform
::
Place
&
place
)
const
override
{
const
platform
::
Place
&
place
)
const
override
{
// get input
// get input
auto
*
var
=
scope
.
FindVar
(
Input
(
kInput
));
auto
*
var
=
scope
.
FindVar
(
Input
(
kInput
));
...
...
paddle/operators/load_combine_op.cc
浏览文件 @
98c94373
...
@@ -26,7 +26,9 @@ class LoadCombineOp : public framework::OperatorBase {
...
@@ -26,7 +26,9 @@ class LoadCombineOp : public framework::OperatorBase {
const
framework
::
VariableNameMap
&
outputs
,
const
framework
::
VariableNameMap
&
outputs
,
const
framework
::
AttributeMap
&
attrs
)
const
framework
::
AttributeMap
&
attrs
)
:
OperatorBase
(
type
,
inputs
,
outputs
,
attrs
)
{}
:
OperatorBase
(
type
,
inputs
,
outputs
,
attrs
)
{}
void
Run
(
const
framework
::
Scope
&
scope
,
private:
void
RunImpl
(
const
framework
::
Scope
&
scope
,
const
platform
::
Place
&
place
)
const
override
{
const
platform
::
Place
&
place
)
const
override
{
auto
filename
=
Attr
<
std
::
string
>
(
"file_path"
);
auto
filename
=
Attr
<
std
::
string
>
(
"file_path"
);
...
...
paddle/operators/load_op.cc
浏览文件 @
98c94373
...
@@ -25,7 +25,9 @@ class LoadOp : public framework::OperatorBase {
...
@@ -25,7 +25,9 @@ class LoadOp : public framework::OperatorBase {
const
framework
::
VariableNameMap
&
outputs
,
const
framework
::
VariableNameMap
&
outputs
,
const
framework
::
AttributeMap
&
attrs
)
const
framework
::
AttributeMap
&
attrs
)
:
OperatorBase
(
type
,
inputs
,
outputs
,
attrs
)
{}
:
OperatorBase
(
type
,
inputs
,
outputs
,
attrs
)
{}
void
Run
(
const
framework
::
Scope
&
scope
,
private:
void
RunImpl
(
const
framework
::
Scope
&
scope
,
const
platform
::
Place
&
place
)
const
override
{
const
platform
::
Place
&
place
)
const
override
{
auto
filename
=
Attr
<
std
::
string
>
(
"file_path"
);
auto
filename
=
Attr
<
std
::
string
>
(
"file_path"
);
std
::
ifstream
fin
(
filename
);
std
::
ifstream
fin
(
filename
);
...
...
paddle/operators/lod_array_length_op.cc
浏览文件 @
98c94373
...
@@ -25,7 +25,9 @@ class LoDArrayLengthOp : public framework::OperatorBase {
...
@@ -25,7 +25,9 @@ class LoDArrayLengthOp : public framework::OperatorBase {
const
framework
::
VariableNameMap
&
outputs
,
const
framework
::
VariableNameMap
&
outputs
,
const
framework
::
AttributeMap
&
attrs
)
const
framework
::
AttributeMap
&
attrs
)
:
OperatorBase
(
type
,
inputs
,
outputs
,
attrs
)
{}
:
OperatorBase
(
type
,
inputs
,
outputs
,
attrs
)
{}
void
Run
(
const
framework
::
Scope
&
scope
,
private:
void
RunImpl
(
const
framework
::
Scope
&
scope
,
const
platform
::
Place
&
place
)
const
override
{
const
platform
::
Place
&
place
)
const
override
{
auto
&
x
=
scope
.
FindVar
(
Input
(
"X"
))
->
Get
<
framework
::
LoDTensorArray
>
();
auto
&
x
=
scope
.
FindVar
(
Input
(
"X"
))
->
Get
<
framework
::
LoDTensorArray
>
();
auto
&
out
=
auto
&
out
=
...
...
paddle/operators/lod_rank_table_op.cc
浏览文件 @
98c94373
...
@@ -23,7 +23,9 @@ class LoDRankTableOp : public framework::OperatorBase {
...
@@ -23,7 +23,9 @@ class LoDRankTableOp : public framework::OperatorBase {
const
framework
::
VariableNameMap
&
outputs
,
const
framework
::
VariableNameMap
&
outputs
,
const
framework
::
AttributeMap
&
attrs
)
const
framework
::
AttributeMap
&
attrs
)
:
OperatorBase
(
type
,
inputs
,
outputs
,
attrs
)
{}
:
OperatorBase
(
type
,
inputs
,
outputs
,
attrs
)
{}
void
Run
(
const
framework
::
Scope
&
scope
,
private:
void
RunImpl
(
const
framework
::
Scope
&
scope
,
const
platform
::
Place
&
dev_place
)
const
override
{
const
platform
::
Place
&
dev_place
)
const
override
{
auto
x
=
scope
.
FindVar
(
Input
(
"X"
))
->
Get
<
framework
::
LoDTensor
>
();
auto
x
=
scope
.
FindVar
(
Input
(
"X"
))
->
Get
<
framework
::
LoDTensor
>
();
auto
*
out
=
auto
*
out
=
...
...
paddle/operators/lod_tensor_to_array_op.cc
浏览文件 @
98c94373
...
@@ -32,7 +32,9 @@ class LoDTensorToArrayOp : public framework::OperatorBase {
...
@@ -32,7 +32,9 @@ class LoDTensorToArrayOp : public framework::OperatorBase {
const
framework
::
VariableNameMap
&
outputs
,
const
framework
::
VariableNameMap
&
outputs
,
const
framework
::
AttributeMap
&
attrs
)
const
framework
::
AttributeMap
&
attrs
)
:
OperatorBase
(
type
,
inputs
,
outputs
,
attrs
)
{}
:
OperatorBase
(
type
,
inputs
,
outputs
,
attrs
)
{}
void
Run
(
const
framework
::
Scope
&
scope
,
private:
void
RunImpl
(
const
framework
::
Scope
&
scope
,
const
platform
::
Place
&
place
)
const
override
{
const
platform
::
Place
&
place
)
const
override
{
auto
&
x
=
detail
::
Ref
(
scope
.
FindVar
(
Input
(
"X"
)),
"Cannot find input %s"
,
auto
&
x
=
detail
::
Ref
(
scope
.
FindVar
(
Input
(
"X"
)),
"Cannot find input %s"
,
Input
(
"X"
))
Input
(
"X"
))
...
...
paddle/operators/max_sequence_len_op.cc
浏览文件 @
98c94373
...
@@ -27,7 +27,8 @@ class MaxSeqenceLenOp : public framework::OperatorBase {
...
@@ -27,7 +27,8 @@ class MaxSeqenceLenOp : public framework::OperatorBase {
const
framework
::
AttributeMap
&
attrs
)
const
framework
::
AttributeMap
&
attrs
)
:
OperatorBase
(
type
,
inputs
,
outputs
,
attrs
)
{}
:
OperatorBase
(
type
,
inputs
,
outputs
,
attrs
)
{}
void
Run
(
const
framework
::
Scope
&
scope
,
private:
void
RunImpl
(
const
framework
::
Scope
&
scope
,
const
platform
::
Place
&
dev_place
)
const
override
{
const
platform
::
Place
&
dev_place
)
const
override
{
auto
&
rank_table
=
auto
&
rank_table
=
scope
.
FindVar
(
Input
(
"RankTable"
))
->
Get
<
framework
::
LoDRankTable
>
();
scope
.
FindVar
(
Input
(
"RankTable"
))
->
Get
<
framework
::
LoDRankTable
>
();
...
...
paddle/operators/merge_lod_tensor_op.cc
浏览文件 @
98c94373
...
@@ -27,7 +27,9 @@ class MergeLoDTensorOp : public framework::OperatorBase {
...
@@ -27,7 +27,9 @@ class MergeLoDTensorOp : public framework::OperatorBase {
const
framework
::
VariableNameMap
&
outputs
,
const
framework
::
VariableNameMap
&
outputs
,
const
framework
::
AttributeMap
&
attrs
)
const
framework
::
AttributeMap
&
attrs
)
:
OperatorBase
(
type
,
inputs
,
outputs
,
attrs
)
{}
:
OperatorBase
(
type
,
inputs
,
outputs
,
attrs
)
{}
void
Run
(
const
framework
::
Scope
&
scope
,
private:
void
RunImpl
(
const
framework
::
Scope
&
scope
,
const
platform
::
Place
&
dev_place
)
const
override
{
const
platform
::
Place
&
dev_place
)
const
override
{
// get device context from pool
// get device context from pool
platform
::
DeviceContextPool
&
pool
=
platform
::
DeviceContextPool
::
Instance
();
platform
::
DeviceContextPool
&
pool
=
platform
::
DeviceContextPool
::
Instance
();
...
...
paddle/operators/nccl_op.cc
浏览文件 @
98c94373
...
@@ -26,7 +26,8 @@ class NCCLInitOp : public framework::OperatorBase {
...
@@ -26,7 +26,8 @@ class NCCLInitOp : public framework::OperatorBase {
const
framework
::
AttributeMap
&
attrs
)
const
framework
::
AttributeMap
&
attrs
)
:
OperatorBase
(
type
,
inputs
,
outputs
,
attrs
)
{}
:
OperatorBase
(
type
,
inputs
,
outputs
,
attrs
)
{}
void
Run
(
const
framework
::
Scope
&
scope
,
private:
void
RunImpl
(
const
framework
::
Scope
&
scope
,
const
platform
::
Place
&
place
)
const
override
{
const
platform
::
Place
&
place
)
const
override
{
const
auto
&
name
=
Output
(
"Communicator"
);
const
auto
&
name
=
Output
(
"Communicator"
);
PADDLE_ENFORCE_NOT_NULL
(
scope
.
FindVar
(
name
),
PADDLE_ENFORCE_NOT_NULL
(
scope
.
FindVar
(
name
),
...
...
paddle/operators/net_op.h
浏览文件 @
98c94373
...
@@ -57,20 +57,6 @@ class NetOp : public framework::OperatorBase {
...
@@ -57,20 +57,6 @@ class NetOp : public framework::OperatorBase {
this
->
CompleteAddOp
();
this
->
CompleteAddOp
();
}
}
/**
* @brief Run the network.
*
* Run all the operators with the `scope`, if no scope is provided, default
* scope will be used instead. If no OpContext is provicded, default context
* will be used.
*/
void
Run
(
const
framework
::
Scope
&
scope
,
const
platform
::
Place
&
place
)
const
override
{
for
(
auto
&
op
:
ops_
)
{
op
->
Run
(
scope
,
place
);
}
}
bool
SupportGPU
()
const
override
{
bool
SupportGPU
()
const
override
{
for
(
auto
&
op
:
ops_
)
{
for
(
auto
&
op
:
ops_
)
{
if
(
!
op
->
SupportGPU
())
{
if
(
!
op
->
SupportGPU
())
{
...
@@ -117,6 +103,20 @@ class NetOp : public framework::OperatorBase {
...
@@ -117,6 +103,20 @@ class NetOp : public framework::OperatorBase {
std
::
vector
<
std
::
unique_ptr
<
framework
::
OperatorBase
>>
ops_
;
std
::
vector
<
std
::
unique_ptr
<
framework
::
OperatorBase
>>
ops_
;
private:
private:
/**
* @brief Run the network.
*
* Run all the operators with the `scope`, if no scope is provided, default
* scope will be used instead. If no OpContext is provicded, default context
* will be used.
*/
void
RunImpl
(
const
framework
::
Scope
&
scope
,
const
platform
::
Place
&
place
)
const
override
{
for
(
auto
&
op
:
ops_
)
{
op
->
Run
(
scope
,
place
);
}
}
bool
add_op_done_
{
false
};
bool
add_op_done_
{
false
};
std
::
set
<
std
::
string
>
intermediate_outputs_
;
std
::
set
<
std
::
string
>
intermediate_outputs_
;
...
...
paddle/operators/net_op_test.cc
浏览文件 @
98c94373
...
@@ -26,7 +26,10 @@ class TestOp : public framework::OperatorBase {
...
@@ -26,7 +26,10 @@ class TestOp : public framework::OperatorBase {
public:
public:
using
framework
::
OperatorBase
::
OperatorBase
;
using
framework
::
OperatorBase
::
OperatorBase
;
DEFINE_OP_CLONE_METHOD
(
TestOp
);
DEFINE_OP_CLONE_METHOD
(
TestOp
);
void
Run
(
const
Scope
&
scope
,
const
platform
::
Place
&
place
)
const
override
{
private:
void
RunImpl
(
const
Scope
&
scope
,
const
platform
::
Place
&
place
)
const
override
{
++
run_cnt
;
++
run_cnt
;
}
}
};
};
...
...
paddle/operators/parallel_do_op.cc
浏览文件 @
98c94373
...
@@ -124,7 +124,8 @@ class ParallelDoOp : public framework::OperatorBase {
...
@@ -124,7 +124,8 @@ class ParallelDoOp : public framework::OperatorBase {
const
framework
::
AttributeMap
&
attrs
)
const
framework
::
AttributeMap
&
attrs
)
:
framework
::
OperatorBase
(
type
,
inputs
,
outputs
,
attrs
)
{}
:
framework
::
OperatorBase
(
type
,
inputs
,
outputs
,
attrs
)
{}
void
Run
(
const
framework
::
Scope
&
scope
,
private:
void
RunImpl
(
const
framework
::
Scope
&
scope
,
const
platform
::
Place
&
place
)
const
override
{
const
platform
::
Place
&
place
)
const
override
{
// get device context from pool
// get device context from pool
platform
::
DeviceContextPool
&
pool
=
platform
::
DeviceContextPool
::
Instance
();
platform
::
DeviceContextPool
&
pool
=
platform
::
DeviceContextPool
::
Instance
();
...
@@ -216,7 +217,8 @@ class ParallelDoGradOp : public framework::OperatorBase {
...
@@ -216,7 +217,8 @@ class ParallelDoGradOp : public framework::OperatorBase {
const
framework
::
AttributeMap
&
attrs
)
const
framework
::
AttributeMap
&
attrs
)
:
framework
::
OperatorBase
(
type
,
inputs
,
outputs
,
attrs
)
{}
:
framework
::
OperatorBase
(
type
,
inputs
,
outputs
,
attrs
)
{}
void
Run
(
const
framework
::
Scope
&
scope
,
private:
void
RunImpl
(
const
framework
::
Scope
&
scope
,
const
platform
::
Place
&
place
)
const
override
{
const
platform
::
Place
&
place
)
const
override
{
auto
*
block
=
Attr
<
framework
::
BlockDesc
*>
(
kParallelBlock
);
auto
*
block
=
Attr
<
framework
::
BlockDesc
*>
(
kParallelBlock
);
auto
*
program
=
block
->
Program
();
auto
*
program
=
block
->
Program
();
...
...
paddle/operators/print_op.cc
浏览文件 @
98c94373
...
@@ -130,7 +130,8 @@ class TensorPrintOp : public framework::OperatorBase {
...
@@ -130,7 +130,8 @@ class TensorPrintOp : public framework::OperatorBase {
PADDLE_THROW
(
"Not implemented."
);
PADDLE_THROW
(
"Not implemented."
);
}
}
void
Run
(
const
framework
::
Scope
&
scope
,
private:
void
RunImpl
(
const
framework
::
Scope
&
scope
,
const
platform
::
Place
&
place
)
const
override
{
const
platform
::
Place
&
place
)
const
override
{
const
framework
::
Variable
*
in_var_ptr
=
nullptr
;
const
framework
::
Variable
*
in_var_ptr
=
nullptr
;
std
::
string
phase
=
kForward
;
std
::
string
phase
=
kForward
;
...
...
paddle/operators/read_op.cc
浏览文件 @
98c94373
...
@@ -54,7 +54,9 @@ class ReadInferVarType : public framework::VarTypeInference {
...
@@ -54,7 +54,9 @@ class ReadInferVarType : public framework::VarTypeInference {
class
ReadOp
:
public
framework
::
OperatorBase
{
class
ReadOp
:
public
framework
::
OperatorBase
{
public:
public:
using
framework
::
OperatorBase
::
OperatorBase
;
using
framework
::
OperatorBase
::
OperatorBase
;
void
Run
(
const
framework
::
Scope
&
scope
,
private:
void
RunImpl
(
const
framework
::
Scope
&
scope
,
const
platform
::
Place
&
dev_place
)
const
override
{
const
platform
::
Place
&
dev_place
)
const
override
{
framework
::
ReaderHolder
*
reader
=
framework
::
ReaderHolder
*
reader
=
scope
.
FindVar
(
Input
(
"Reader"
))
->
GetMutable
<
framework
::
ReaderHolder
>
();
scope
.
FindVar
(
Input
(
"Reader"
))
->
GetMutable
<
framework
::
ReaderHolder
>
();
...
...
paddle/operators/recurrent_op.cc
浏览文件 @
98c94373
...
@@ -226,7 +226,8 @@ class RecurrentOp : public RecurrentBase {
...
@@ -226,7 +226,8 @@ class RecurrentOp : public RecurrentBase {
const
framework
::
AttributeMap
&
attrs
)
const
framework
::
AttributeMap
&
attrs
)
:
RecurrentBase
(
type
,
inputs
,
outputs
,
attrs
)
{}
:
RecurrentBase
(
type
,
inputs
,
outputs
,
attrs
)
{}
void
Run
(
const
framework
::
Scope
&
scope
,
private:
void
RunImpl
(
const
framework
::
Scope
&
scope
,
const
platform
::
Place
&
place
)
const
override
{
const
platform
::
Place
&
place
)
const
override
{
auto
seq_len
=
static_cast
<
size_t
>
(
this
->
GetSequenceLength
(
scope
));
auto
seq_len
=
static_cast
<
size_t
>
(
this
->
GetSequenceLength
(
scope
));
VLOG
(
3
)
<<
"Static RNN input sequence length = "
<<
seq_len
;
VLOG
(
3
)
<<
"Static RNN input sequence length = "
<<
seq_len
;
...
@@ -315,7 +316,8 @@ class RecurrentGradOp : public RecurrentBase {
...
@@ -315,7 +316,8 @@ class RecurrentGradOp : public RecurrentBase {
const
framework
::
AttributeMap
&
attrs
)
const
framework
::
AttributeMap
&
attrs
)
:
RecurrentBase
(
type
,
inputs
,
outputs
,
attrs
)
{}
:
RecurrentBase
(
type
,
inputs
,
outputs
,
attrs
)
{}
void
Run
(
const
framework
::
Scope
&
scope
,
private:
void
RunImpl
(
const
framework
::
Scope
&
scope
,
const
platform
::
Place
&
place
)
const
override
{
const
platform
::
Place
&
place
)
const
override
{
auto
seq_len
=
static_cast
<
size_t
>
(
GetSequenceLength
(
scope
));
auto
seq_len
=
static_cast
<
size_t
>
(
GetSequenceLength
(
scope
));
StepScopes
scopes
=
CreateStepScopes
(
scope
,
seq_len
);
StepScopes
scopes
=
CreateStepScopes
(
scope
,
seq_len
);
...
...
paddle/operators/reorder_lod_tensor_by_rank_op.cc
浏览文件 @
98c94373
...
@@ -75,7 +75,9 @@ class ReorderLoDTensorByRankTableBase : public framework::OperatorBase {
...
@@ -75,7 +75,9 @@ class ReorderLoDTensorByRankTableBase : public framework::OperatorBase {
const
framework
::
VariableNameMap
&
outputs
,
const
framework
::
VariableNameMap
&
outputs
,
const
framework
::
AttributeMap
&
attrs
)
const
framework
::
AttributeMap
&
attrs
)
:
OperatorBase
(
type
,
inputs
,
outputs
,
attrs
)
{}
:
OperatorBase
(
type
,
inputs
,
outputs
,
attrs
)
{}
void
Run
(
const
framework
::
Scope
&
scope
,
private:
void
RunImpl
(
const
framework
::
Scope
&
scope
,
const
platform
::
Place
&
place
)
const
override
{
const
platform
::
Place
&
place
)
const
override
{
auto
&
x
=
auto
&
x
=
detail
::
Ref
(
scope
.
FindVar
(
Input
(
"X"
)),
detail
::
Ref
(
scope
.
FindVar
(
Input
(
"X"
)),
...
...
paddle/operators/rnn_memory_helper_op.cc
浏览文件 @
98c94373
...
@@ -24,7 +24,9 @@ class RNNMemoryHelperOp : public framework::OperatorBase {
...
@@ -24,7 +24,9 @@ class RNNMemoryHelperOp : public framework::OperatorBase {
const
framework
::
VariableNameMap
&
outputs
,
const
framework
::
VariableNameMap
&
outputs
,
const
framework
::
AttributeMap
&
attrs
)
const
framework
::
AttributeMap
&
attrs
)
:
OperatorBase
(
type
,
inputs
,
outputs
,
attrs
)
{}
:
OperatorBase
(
type
,
inputs
,
outputs
,
attrs
)
{}
void
Run
(
const
framework
::
Scope
&
scope
,
private:
void
RunImpl
(
const
framework
::
Scope
&
scope
,
const
platform
::
Place
&
dev_place
)
const
override
{
const
platform
::
Place
&
dev_place
)
const
override
{
auto
mem_var_name
=
Input
(
"X"
);
auto
mem_var_name
=
Input
(
"X"
);
auto
*
mem_var
=
scope
.
FindVar
(
mem_var_name
);
auto
*
mem_var
=
scope
.
FindVar
(
mem_var_name
);
...
@@ -76,7 +78,9 @@ class RNNMemoryHelperGradOp : public framework::OperatorBase {
...
@@ -76,7 +78,9 @@ class RNNMemoryHelperGradOp : public framework::OperatorBase {
const
framework
::
VariableNameMap
&
outputs
,
const
framework
::
VariableNameMap
&
outputs
,
const
framework
::
AttributeMap
&
attrs
)
const
framework
::
AttributeMap
&
attrs
)
:
OperatorBase
(
type
,
inputs
,
outputs
,
attrs
)
{}
:
OperatorBase
(
type
,
inputs
,
outputs
,
attrs
)
{}
void
Run
(
const
framework
::
Scope
&
scope
,
private:
void
RunImpl
(
const
framework
::
Scope
&
scope
,
const
platform
::
Place
&
dev_place
)
const
override
{
const
platform
::
Place
&
dev_place
)
const
override
{
auto
out_grad_var_name
=
Input
(
framework
::
GradVarName
(
"Out"
));
auto
out_grad_var_name
=
Input
(
framework
::
GradVarName
(
"Out"
));
auto
*
out_grad_var
=
scope
.
FindVar
(
out_grad_var_name
);
auto
*
out_grad_var
=
scope
.
FindVar
(
out_grad_var_name
);
...
...
paddle/operators/save_combine_op.cc
浏览文件 @
98c94373
...
@@ -63,7 +63,9 @@ class SaveCombineOp : public framework::OperatorBase {
...
@@ -63,7 +63,9 @@ class SaveCombineOp : public framework::OperatorBase {
const
framework
::
VariableNameMap
&
outputs
,
const
framework
::
VariableNameMap
&
outputs
,
const
framework
::
AttributeMap
&
attrs
)
const
framework
::
AttributeMap
&
attrs
)
:
OperatorBase
(
type
,
inputs
,
outputs
,
attrs
)
{}
:
OperatorBase
(
type
,
inputs
,
outputs
,
attrs
)
{}
void
Run
(
const
framework
::
Scope
&
scope
,
private:
void
RunImpl
(
const
framework
::
Scope
&
scope
,
const
platform
::
Place
&
place
)
const
override
{
const
platform
::
Place
&
place
)
const
override
{
auto
filename
=
Attr
<
std
::
string
>
(
"file_path"
);
auto
filename
=
Attr
<
std
::
string
>
(
"file_path"
);
auto
overwrite
=
Attr
<
bool
>
(
"overwrite"
);
auto
overwrite
=
Attr
<
bool
>
(
"overwrite"
);
...
...
paddle/operators/save_op.cc
浏览文件 @
98c94373
...
@@ -62,7 +62,9 @@ class SaveOp : public framework::OperatorBase {
...
@@ -62,7 +62,9 @@ class SaveOp : public framework::OperatorBase {
const
framework
::
VariableNameMap
&
outputs
,
const
framework
::
VariableNameMap
&
outputs
,
const
framework
::
AttributeMap
&
attrs
)
const
framework
::
AttributeMap
&
attrs
)
:
OperatorBase
(
type
,
inputs
,
outputs
,
attrs
)
{}
:
OperatorBase
(
type
,
inputs
,
outputs
,
attrs
)
{}
void
Run
(
const
framework
::
Scope
&
scope
,
private:
void
RunImpl
(
const
framework
::
Scope
&
scope
,
const
platform
::
Place
&
place
)
const
override
{
const
platform
::
Place
&
place
)
const
override
{
auto
filename
=
Attr
<
std
::
string
>
(
"file_path"
);
auto
filename
=
Attr
<
std
::
string
>
(
"file_path"
);
auto
overwrite
=
Attr
<
bool
>
(
"overwrite"
);
auto
overwrite
=
Attr
<
bool
>
(
"overwrite"
);
...
...
paddle/operators/shrink_rnn_memory_op.cc
浏览文件 @
98c94373
...
@@ -27,7 +27,8 @@ class ShrinkRNNMemoryOp : public ArrayOp {
...
@@ -27,7 +27,8 @@ class ShrinkRNNMemoryOp : public ArrayOp {
const
framework
::
AttributeMap
&
attrs
)
const
framework
::
AttributeMap
&
attrs
)
:
ArrayOp
(
type
,
inputs
,
outputs
,
attrs
)
{}
:
ArrayOp
(
type
,
inputs
,
outputs
,
attrs
)
{}
void
Run
(
const
framework
::
Scope
&
scope
,
private:
void
RunImpl
(
const
framework
::
Scope
&
scope
,
const
platform
::
Place
&
place
)
const
override
{
const
platform
::
Place
&
place
)
const
override
{
auto
*
x_var
=
scope
.
FindVar
(
Input
(
"X"
));
auto
*
x_var
=
scope
.
FindVar
(
Input
(
"X"
));
PADDLE_ENFORCE
(
x_var
!=
nullptr
,
"Input X must be set"
);
PADDLE_ENFORCE
(
x_var
!=
nullptr
,
"Input X must be set"
);
...
@@ -108,7 +109,8 @@ class ShrinkRNNMemoryGradOp : public ArrayOp {
...
@@ -108,7 +109,8 @@ class ShrinkRNNMemoryGradOp : public ArrayOp {
const
framework
::
AttributeMap
&
attrs
)
const
framework
::
AttributeMap
&
attrs
)
:
ArrayOp
(
type
,
inputs
,
outputs
,
attrs
)
{}
:
ArrayOp
(
type
,
inputs
,
outputs
,
attrs
)
{}
void
Run
(
const
framework
::
Scope
&
scope
,
private:
void
RunImpl
(
const
framework
::
Scope
&
scope
,
const
platform
::
Place
&
place
)
const
override
{
const
platform
::
Place
&
place
)
const
override
{
auto
*
dout_var
=
scope
.
FindVar
(
Input
(
framework
::
GradVarName
(
"Out"
)));
auto
*
dout_var
=
scope
.
FindVar
(
Input
(
framework
::
GradVarName
(
"Out"
)));
auto
*
dx_var
=
scope
.
FindVar
(
Output
(
framework
::
GradVarName
(
"X"
)));
auto
*
dx_var
=
scope
.
FindVar
(
Output
(
framework
::
GradVarName
(
"X"
)));
...
...
paddle/operators/split_lod_tensor_op.cc
浏览文件 @
98c94373
...
@@ -33,7 +33,9 @@ class SplitLoDTensorOp : public framework::OperatorBase {
...
@@ -33,7 +33,9 @@ class SplitLoDTensorOp : public framework::OperatorBase {
const
framework
::
VariableNameMap
&
outputs
,
const
framework
::
VariableNameMap
&
outputs
,
const
framework
::
AttributeMap
&
attrs
)
const
framework
::
AttributeMap
&
attrs
)
:
OperatorBase
(
type
,
inputs
,
outputs
,
attrs
)
{}
:
OperatorBase
(
type
,
inputs
,
outputs
,
attrs
)
{}
void
Run
(
const
framework
::
Scope
&
scope
,
private:
void
RunImpl
(
const
framework
::
Scope
&
scope
,
const
platform
::
Place
&
dev_place
)
const
override
{
const
platform
::
Place
&
dev_place
)
const
override
{
auto
&
x
=
scope
.
FindVar
(
Input
(
"X"
))
->
Get
<
framework
::
LoDTensor
>
();
auto
&
x
=
scope
.
FindVar
(
Input
(
"X"
))
->
Get
<
framework
::
LoDTensor
>
();
auto
&
mask
=
scope
.
FindVar
(
Input
(
"Mask"
))
->
Get
<
framework
::
LoDTensor
>
();
auto
&
mask
=
scope
.
FindVar
(
Input
(
"Mask"
))
->
Get
<
framework
::
LoDTensor
>
();
...
...
paddle/operators/tensor_array_read_write_op.cc
浏览文件 @
98c94373
...
@@ -24,7 +24,8 @@ class WriteToArrayOp : public ArrayOp {
...
@@ -24,7 +24,8 @@ class WriteToArrayOp : public ArrayOp {
const
framework
::
AttributeMap
&
attrs
)
const
framework
::
AttributeMap
&
attrs
)
:
ArrayOp
(
type
,
inputs
,
outputs
,
attrs
)
{}
:
ArrayOp
(
type
,
inputs
,
outputs
,
attrs
)
{}
void
Run
(
const
framework
::
Scope
&
scope
,
private:
void
RunImpl
(
const
framework
::
Scope
&
scope
,
const
platform
::
Place
&
place
)
const
override
{
const
platform
::
Place
&
place
)
const
override
{
auto
*
x
=
scope
.
FindVar
(
Input
(
"X"
));
auto
*
x
=
scope
.
FindVar
(
Input
(
"X"
));
if
(
x
==
nullptr
)
return
;
if
(
x
==
nullptr
)
return
;
...
@@ -122,7 +123,9 @@ class ReadFromArrayOp : public ArrayOp {
...
@@ -122,7 +123,9 @@ class ReadFromArrayOp : public ArrayOp {
const
framework
::
VariableNameMap
&
outputs
,
const
framework
::
VariableNameMap
&
outputs
,
const
framework
::
AttributeMap
&
attrs
)
const
framework
::
AttributeMap
&
attrs
)
:
ArrayOp
(
type
,
inputs
,
outputs
,
attrs
)
{}
:
ArrayOp
(
type
,
inputs
,
outputs
,
attrs
)
{}
void
Run
(
const
framework
::
Scope
&
scope
,
private:
void
RunImpl
(
const
framework
::
Scope
&
scope
,
const
platform
::
Place
&
place
)
const
override
{
const
platform
::
Place
&
place
)
const
override
{
auto
*
x
=
scope
.
FindVar
(
Input
(
"X"
));
auto
*
x
=
scope
.
FindVar
(
Input
(
"X"
));
PADDLE_ENFORCE
(
x
!=
nullptr
,
"X must be set"
);
PADDLE_ENFORCE
(
x
!=
nullptr
,
"X must be set"
);
...
...
paddle/operators/while_op.cc
浏览文件 @
98c94373
...
@@ -39,7 +39,8 @@ class WhileOp : public framework::OperatorBase {
...
@@ -39,7 +39,8 @@ class WhileOp : public framework::OperatorBase {
const
framework
::
AttributeMap
&
attrs
)
const
framework
::
AttributeMap
&
attrs
)
:
framework
::
OperatorBase
(
type
,
inputs
,
outputs
,
attrs
)
{}
:
framework
::
OperatorBase
(
type
,
inputs
,
outputs
,
attrs
)
{}
void
Run
(
const
framework
::
Scope
&
scope
,
private:
void
RunImpl
(
const
framework
::
Scope
&
scope
,
const
platform
::
Place
&
dev_place
)
const
override
{
const
platform
::
Place
&
dev_place
)
const
override
{
PADDLE_ENFORCE_NOT_NULL
(
scope
.
FindVar
(
Input
(
kCondition
)));
PADDLE_ENFORCE_NOT_NULL
(
scope
.
FindVar
(
Input
(
kCondition
)));
auto
&
cond
=
scope
.
FindVar
(
Input
(
kCondition
))
->
Get
<
LoDTensor
>
();
auto
&
cond
=
scope
.
FindVar
(
Input
(
kCondition
))
->
Get
<
LoDTensor
>
();
...
@@ -99,7 +100,8 @@ class WhileGradOp : public framework::OperatorBase {
...
@@ -99,7 +100,8 @@ class WhileGradOp : public framework::OperatorBase {
const
framework
::
AttributeMap
&
attrs
)
const
framework
::
AttributeMap
&
attrs
)
:
framework
::
OperatorBase
(
type
,
inputs
,
outputs
,
attrs
)
{}
:
framework
::
OperatorBase
(
type
,
inputs
,
outputs
,
attrs
)
{}
void
Run
(
const
framework
::
Scope
&
scope
,
private:
void
RunImpl
(
const
framework
::
Scope
&
scope
,
const
platform
::
Place
&
dev_place
)
const
override
{
const
platform
::
Place
&
dev_place
)
const
override
{
// get device context from pool
// get device context from pool
platform
::
DeviceContextPool
&
pool
=
platform
::
DeviceContextPool
::
Instance
();
platform
::
DeviceContextPool
&
pool
=
platform
::
DeviceContextPool
::
Instance
();
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录