Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
79149c8e
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
79149c8e
编写于
9月 14, 2020
作者:
C
Chen Weihang
提交者:
GitHub
9月 14, 2020
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
polish framework error message part 8 (#27269)
上级
8d531727
变更
8
显示空白变更内容
内联
并排
Showing
8 changed file
with
88 addition
and
62 deletion
+88
-62
paddle/fluid/framework/c/c_api.cc
paddle/fluid/framework/c/c_api.cc
+2
-1
paddle/fluid/framework/fleet/nccl_wrapper.cc
paddle/fluid/framework/fleet/nccl_wrapper.cc
+6
-5
paddle/fluid/framework/threadpool.cc
paddle/fluid/framework/threadpool.cc
+4
-2
paddle/fluid/framework/threadpool.h
paddle/fluid/framework/threadpool.h
+2
-1
paddle/fluid/framework/var_desc.cc
paddle/fluid/framework/var_desc.cc
+51
-37
paddle/fluid/framework/var_type.h
paddle/fluid/framework/var_type.h
+4
-2
paddle/fluid/framework/var_type_traits.cc
paddle/fluid/framework/var_type_traits.cc
+14
-10
paddle/fluid/framework/variable_helper.cc
paddle/fluid/framework/variable_helper.cc
+5
-4
未找到文件。
paddle/fluid/framework/c/c_api.cc
浏览文件 @
79149c8e
...
@@ -49,7 +49,8 @@ std::vector<std::string> PD_GetGradOpDescStrs(
...
@@ -49,7 +49,8 @@ std::vector<std::string> PD_GetGradOpDescStrs(
for
(
size_t
i
=
0
;
i
<
op_num
;
++
i
)
{
for
(
size_t
i
=
0
;
i
<
op_num
;
++
i
)
{
PADDLE_ENFORCE_EQ
(
PADDLE_ENFORCE_EQ
(
grad_op_descs
[
i
]
->
Proto
()
->
SerializePartialToString
(
&
ret
[
i
]),
true
,
grad_op_descs
[
i
]
->
Proto
()
->
SerializePartialToString
(
&
ret
[
i
]),
true
,
"Cannot serialize message."
);
paddle
::
platform
::
errors
::
Unavailable
(
"Cannot serialize operator desc message."
));
}
}
}
}
return
ret
;
return
ret
;
...
...
paddle/fluid/framework/fleet/nccl_wrapper.cc
浏览文件 @
79149c8e
...
@@ -25,7 +25,7 @@ bool NCCLWrapper::is_initialized_ = false;
...
@@ -25,7 +25,7 @@ bool NCCLWrapper::is_initialized_ = false;
void
NCCLWrapper
::
InitNCCL
()
{
void
NCCLWrapper
::
InitNCCL
()
{
#if defined(PADDLE_WITH_NCCL)
#if defined(PADDLE_WITH_NCCL)
PADDLE_ENFORCE
(
platform
::
dynload
::
ncclCommInitRank
(
PADDLE_ENFORCE
_CUDA_SUCCESS
(
platform
::
dynload
::
ncclCommInitRank
(
&
(
nccl_info_
.
comm_
),
nccl_info_
.
global_ranks_
,
nccl_info_
.
nccl_id_
,
&
(
nccl_info_
.
comm_
),
nccl_info_
.
global_ranks_
,
nccl_info_
.
nccl_id_
,
nccl_info_
.
my_global_rank_
));
nccl_info_
.
my_global_rank_
));
#endif
#endif
...
@@ -41,7 +41,8 @@ void NCCLWrapper::SetNCCLId(const NCCLInfo& nccl_info) {
...
@@ -41,7 +41,8 @@ void NCCLWrapper::SetNCCLId(const NCCLInfo& nccl_info) {
NCCLInfo
NCCLWrapper
::
GetNCCLId
()
{
NCCLInfo
NCCLWrapper
::
GetNCCLId
()
{
#if defined(PADDLE_WITH_NCCL)
#if defined(PADDLE_WITH_NCCL)
PADDLE_ENFORCE
(
platform
::
dynload
::
ncclGetUniqueId
(
&
(
nccl_info_
.
nccl_id_
)));
PADDLE_ENFORCE_CUDA_SUCCESS
(
platform
::
dynload
::
ncclGetUniqueId
(
&
(
nccl_info_
.
nccl_id_
)));
#endif
#endif
return
nccl_info_
;
return
nccl_info_
;
}
}
...
@@ -52,8 +53,8 @@ void NCCLWrapper::SetRankInfo(const int local_rank, const int global_rank,
...
@@ -52,8 +53,8 @@ void NCCLWrapper::SetRankInfo(const int local_rank, const int global_rank,
nccl_info_
.
local_rank_
=
local_rank
;
nccl_info_
.
local_rank_
=
local_rank
;
nccl_info_
.
my_global_rank_
=
global_rank
;
nccl_info_
.
my_global_rank_
=
global_rank
;
nccl_info_
.
global_ranks_
=
ranks
;
nccl_info_
.
global_ranks_
=
ranks
;
PADDLE_ENFORCE
(
cudaSetDevice
(
local_rank
));
PADDLE_ENFORCE
_CUDA_SUCCESS
(
cudaSetDevice
(
local_rank
));
PADDLE_ENFORCE
(
cudaStreamCreate
(
&
(
nccl_info_
.
stream_
)));
PADDLE_ENFORCE
_CUDA_SUCCESS
(
cudaStreamCreate
(
&
(
nccl_info_
.
stream_
)));
#endif
#endif
return
;
return
;
}
}
...
@@ -65,7 +66,7 @@ void NCCLWrapper::SyncVar(const int root_rank, const Scope& scope,
...
@@ -65,7 +66,7 @@ void NCCLWrapper::SyncVar(const int root_rank, const Scope& scope,
auto
var
=
scope
.
FindVar
(
name
);
auto
var
=
scope
.
FindVar
(
name
);
LoDTensor
*
tensor
=
var
->
GetMutable
<
LoDTensor
>
();
LoDTensor
*
tensor
=
var
->
GetMutable
<
LoDTensor
>
();
int32_t
total_size
=
tensor
->
numel
();
int32_t
total_size
=
tensor
->
numel
();
PADDLE_ENFORCE
(
platform
::
dynload
::
ncclBcast
(
PADDLE_ENFORCE
_CUDA_SUCCESS
(
platform
::
dynload
::
ncclBcast
(
reinterpret_cast
<
void
*>
(
tensor
->
data
<
float
>
()),
total_size
,
ncclFloat
,
reinterpret_cast
<
void
*>
(
tensor
->
data
<
float
>
()),
total_size
,
ncclFloat
,
root_rank
,
nccl_info_
.
comm_
,
nccl_info_
.
stream_
));
root_rank
,
nccl_info_
.
comm_
,
nccl_info_
.
stream_
));
cudaStreamSynchronize
(
nccl_info_
.
stream_
);
cudaStreamSynchronize
(
nccl_info_
.
stream_
);
...
...
paddle/fluid/framework/threadpool.cc
浏览文件 @
79149c8e
...
@@ -42,7 +42,8 @@ void ThreadPool::Init() {
...
@@ -42,7 +42,8 @@ void ThreadPool::Init() {
num_threads
=
FLAGS_dist_threadpool_size
;
num_threads
=
FLAGS_dist_threadpool_size
;
VLOG
(
1
)
<<
"set dist_threadpool_size to "
<<
num_threads
;
VLOG
(
1
)
<<
"set dist_threadpool_size to "
<<
num_threads
;
}
}
PADDLE_ENFORCE_GT
(
num_threads
,
0
);
PADDLE_ENFORCE_GT
(
num_threads
,
0
,
platform
::
errors
::
InvalidArgument
(
"The number of threads is 0."
));
threadpool_
.
reset
(
new
ThreadPool
(
num_threads
));
threadpool_
.
reset
(
new
ThreadPool
(
num_threads
));
}
}
}
}
...
@@ -83,7 +84,8 @@ void ThreadPool::TaskLoop() {
...
@@ -83,7 +84,8 @@ void ThreadPool::TaskLoop() {
}
}
if
(
tasks_
.
empty
())
{
if
(
tasks_
.
empty
())
{
PADDLE_THROW
(
"This thread has no task to Run"
);
PADDLE_THROW
(
platform
::
errors
::
Unavailable
(
"Current thread has no task to Run."
));
}
}
// pop a task from the task queue
// pop a task from the task queue
...
...
paddle/fluid/framework/threadpool.h
浏览文件 @
79149c8e
...
@@ -91,7 +91,8 @@ class ThreadPool {
...
@@ -91,7 +91,8 @@ class ThreadPool {
{
{
std
::
unique_lock
<
std
::
mutex
>
lock
(
mutex_
);
std
::
unique_lock
<
std
::
mutex
>
lock
(
mutex_
);
if
(
!
running_
)
{
if
(
!
running_
)
{
PADDLE_THROW
(
"enqueue on stopped ThreadPool"
);
PADDLE_THROW
(
platform
::
errors
::
Unavailable
(
"Task is enqueued into stopped ThreadPool."
));
}
}
tasks_
.
push
(
std
::
move
(
task
));
tasks_
.
push
(
std
::
move
(
task
));
}
}
...
...
paddle/fluid/framework/var_desc.cc
浏览文件 @
79149c8e
...
@@ -43,8 +43,9 @@ void VarDesc::SetTensorDescNum(size_t num) {
...
@@ -43,8 +43,9 @@ void VarDesc::SetTensorDescNum(size_t num) {
}
break
;
}
break
;
default:
default:
PADDLE_THROW
(
PADDLE_THROW
(
"Setting 'sub_tensor_number' is not supported by the type of var %s."
,
platform
::
errors
::
Unavailable
(
"Setting 'sub_tensor_number' is not "
this
->
Name
());
"supported by the %s type variable."
,
this
->
Name
()));
}
}
}
}
...
@@ -55,8 +56,9 @@ size_t VarDesc::GetTensorDescNum() const {
...
@@ -55,8 +56,9 @@ size_t VarDesc::GetTensorDescNum() const {
break
;
break
;
default:
default:
PADDLE_THROW
(
PADDLE_THROW
(
"Getting 'sub_tensor_number' is not supported by the type of var %s."
,
platform
::
errors
::
Unavailable
(
"Getting 'sub_tensor_number' is not "
this
->
Name
());
"supported by the %s type variable."
,
this
->
Name
()));
}
}
}
}
...
@@ -133,9 +135,9 @@ void VarDesc::SetLoDLevel(int32_t lod_level) {
...
@@ -133,9 +135,9 @@ void VarDesc::SetLoDLevel(int32_t lod_level) {
desc_
.
mutable_type
()
->
mutable_tensor_array
()
->
set_lod_level
(
lod_level
);
desc_
.
mutable_type
()
->
mutable_tensor_array
()
->
set_lod_level
(
lod_level
);
break
;
break
;
default:
default:
PADDLE_THROW
(
PADDLE_THROW
(
platform
::
errors
::
Unavailable
(
"Setting 'lod_level' is not supported by the
type of var %s
."
,
"Setting 'lod_level' is not supported by the
%s type variable
."
,
this
->
Name
());
this
->
Name
())
)
;
}
}
}
}
...
@@ -157,9 +159,9 @@ void VarDesc::SetLoDLevels(const std::vector<int32_t> &multiple_lod_level) {
...
@@ -157,9 +159,9 @@ void VarDesc::SetLoDLevels(const std::vector<int32_t> &multiple_lod_level) {
}
}
}
break
;
}
break
;
default:
default:
PADDLE_THROW
(
PADDLE_THROW
(
platform
::
errors
::
Unavailable
(
"Setting 'lod_levels' is not supported by the
type of var %s.
"
,
"Setting 'lod_levels' is not supported by the
%s type variable
"
,
this
->
Name
());
this
->
Name
())
)
;
}
}
}
}
...
@@ -170,9 +172,9 @@ int32_t VarDesc::GetLoDLevel() const {
...
@@ -170,9 +172,9 @@ int32_t VarDesc::GetLoDLevel() const {
case
proto
::
VarType
::
LOD_TENSOR_ARRAY
:
case
proto
::
VarType
::
LOD_TENSOR_ARRAY
:
return
desc_
.
type
().
tensor_array
().
lod_level
();
return
desc_
.
type
().
tensor_array
().
lod_level
();
default:
default:
PADDLE_THROW
(
PADDLE_THROW
(
platform
::
errors
::
Unavailable
(
"Getting 'lod_level' is not supported by the
type of var %s
."
,
"Getting 'lod_level' is not supported by the
%s type variable
."
,
this
->
Name
());
this
->
Name
())
)
;
}
}
}
}
...
@@ -187,15 +189,19 @@ std::vector<int32_t> VarDesc::GetLoDLevels() const {
...
@@ -187,15 +189,19 @@ std::vector<int32_t> VarDesc::GetLoDLevels() const {
return
res
;
return
res
;
break
;
break
;
default:
default:
PADDLE_THROW
(
PADDLE_THROW
(
platform
::
errors
::
Unavailable
(
"Getting 'lod_levels' is not supported by the
type of var %s
."
,
"Getting 'lod_levels' is not supported by the
%s type variable
."
,
this
->
Name
());
this
->
Name
())
)
;
}
}
}
}
const
proto
::
VarType
::
TensorDesc
&
VarDesc
::
tensor_desc
()
const
{
const
proto
::
VarType
::
TensorDesc
&
VarDesc
::
tensor_desc
()
const
{
PADDLE_ENFORCE
(
desc_
.
has_type
(),
"The var's type hasn't been set."
);
PADDLE_ENFORCE_EQ
(
PADDLE_ENFORCE
(
desc_
.
type
().
has_type
(),
"The var type hasn't been set."
);
desc_
.
has_type
(),
true
,
platform
::
errors
::
NotFound
(
"The variable's type was not be set."
));
PADDLE_ENFORCE_EQ
(
desc_
.
type
().
has_type
(),
true
,
platform
::
errors
::
NotFound
(
"The variable's type was not be set."
));
switch
(
desc_
.
type
().
type
())
{
switch
(
desc_
.
type
().
type
())
{
case
proto
::
VarType
::
SELECTED_ROWS
:
case
proto
::
VarType
::
SELECTED_ROWS
:
return
desc_
.
type
().
selected_rows
();
return
desc_
.
type
().
selected_rows
();
...
@@ -204,14 +210,16 @@ const proto::VarType::TensorDesc &VarDesc::tensor_desc() const {
...
@@ -204,14 +210,16 @@ const proto::VarType::TensorDesc &VarDesc::tensor_desc() const {
case
proto
::
VarType
::
LOD_TENSOR_ARRAY
:
case
proto
::
VarType
::
LOD_TENSOR_ARRAY
:
return
desc_
.
type
().
tensor_array
().
tensor
();
return
desc_
.
type
().
tensor_array
().
tensor
();
default:
default:
PADDLE_THROW
(
PADDLE_THROW
(
platform
::
errors
::
Unavailable
(
"Getting 'tensor_desc' is not supported by the
type of var %s
."
,
"Getting 'tensor_desc' is not supported by the
%s type variable
."
,
this
->
Name
());
this
->
Name
())
)
;
}
}
}
}
std
::
vector
<
proto
::
VarType
::
TensorDesc
>
VarDesc
::
tensor_descs
()
const
{
std
::
vector
<
proto
::
VarType
::
TensorDesc
>
VarDesc
::
tensor_descs
()
const
{
PADDLE_ENFORCE
(
desc_
.
has_type
(),
"The var type hasn't been set."
);
PADDLE_ENFORCE_EQ
(
desc_
.
has_type
(),
true
,
platform
::
errors
::
NotFound
(
"The variable's type was not be set."
));
std
::
vector
<
proto
::
VarType
::
TensorDesc
>
res
;
std
::
vector
<
proto
::
VarType
::
TensorDesc
>
res
;
res
.
reserve
(
GetTensorDescNum
());
res
.
reserve
(
GetTensorDescNum
());
switch
(
desc_
.
type
().
type
())
{
switch
(
desc_
.
type
().
type
())
{
...
@@ -221,16 +229,19 @@ std::vector<proto::VarType::TensorDesc> VarDesc::tensor_descs() const {
...
@@ -221,16 +229,19 @@ std::vector<proto::VarType::TensorDesc> VarDesc::tensor_descs() const {
}
}
return
res
;
return
res
;
default:
default:
PADDLE_THROW
(
PADDLE_THROW
(
platform
::
errors
::
Unavailable
(
"Getting 'tensor_descs' is not supported by the type of var "
"Getting 'tensor_descs' is not supported by the %s type variable."
,
"%s."
,
this
->
Name
()));
this
->
Name
());
}
}
}
}
proto
::
VarType
::
TensorDesc
*
VarDesc
::
mutable_tensor_desc
()
{
proto
::
VarType
::
TensorDesc
*
VarDesc
::
mutable_tensor_desc
()
{
PADDLE_ENFORCE
(
desc_
.
has_type
(),
"The var type hasn't been set."
);
PADDLE_ENFORCE_EQ
(
PADDLE_ENFORCE
(
desc_
.
type
().
has_type
(),
"The var type hasn't been set."
);
desc_
.
has_type
(),
true
,
platform
::
errors
::
NotFound
(
"The variable's type was not be set."
));
PADDLE_ENFORCE_EQ
(
desc_
.
type
().
has_type
(),
true
,
platform
::
errors
::
NotFound
(
"The variable's type was not be set."
));
switch
(
desc_
.
type
().
type
())
{
switch
(
desc_
.
type
().
type
())
{
case
proto
::
VarType
::
SELECTED_ROWS
:
case
proto
::
VarType
::
SELECTED_ROWS
:
return
desc_
.
mutable_type
()
->
mutable_selected_rows
();
return
desc_
.
mutable_type
()
->
mutable_selected_rows
();
...
@@ -240,15 +251,19 @@ proto::VarType::TensorDesc *VarDesc::mutable_tensor_desc() {
...
@@ -240,15 +251,19 @@ proto::VarType::TensorDesc *VarDesc::mutable_tensor_desc() {
return
desc_
.
mutable_type
()
->
mutable_tensor_array
()
->
mutable_tensor
();
return
desc_
.
mutable_type
()
->
mutable_tensor_array
()
->
mutable_tensor
();
default:
default:
PADDLE_THROW
(
PADDLE_THROW
(
"Getting 'mutable_tensor_desc' is not supported by the type of var
"
platform
::
errors
::
Unavailable
(
"Getting 'mutable_tensor_desc' is not
"
"%s
."
,
"supported by the %s type variable
."
,
this
->
Name
(
));
this
->
Name
()
));
}
}
}
}
std
::
vector
<
proto
::
VarType
::
TensorDesc
*>
VarDesc
::
mutable_tensor_descs
()
{
std
::
vector
<
proto
::
VarType
::
TensorDesc
*>
VarDesc
::
mutable_tensor_descs
()
{
PADDLE_ENFORCE
(
desc_
.
has_type
(),
"The var type hasn't been set."
);
PADDLE_ENFORCE_EQ
(
PADDLE_ENFORCE
(
desc_
.
type
().
has_type
(),
"The var type hasn't been set."
);
desc_
.
has_type
(),
true
,
platform
::
errors
::
NotFound
(
"The variable's type was not be set."
));
PADDLE_ENFORCE_EQ
(
desc_
.
type
().
has_type
(),
true
,
platform
::
errors
::
NotFound
(
"The variable's type was not be set."
));
std
::
vector
<
proto
::
VarType
::
TensorDesc
*>
res
;
std
::
vector
<
proto
::
VarType
::
TensorDesc
*>
res
;
res
.
reserve
(
GetTensorDescNum
());
res
.
reserve
(
GetTensorDescNum
());
switch
(
desc_
.
type
().
type
())
{
switch
(
desc_
.
type
().
type
())
{
...
@@ -259,10 +274,9 @@ std::vector<proto::VarType::TensorDesc *> VarDesc::mutable_tensor_descs() {
...
@@ -259,10 +274,9 @@ std::vector<proto::VarType::TensorDesc *> VarDesc::mutable_tensor_descs() {
}
}
return
res
;
return
res
;
default:
default:
PADDLE_THROW
(
PADDLE_THROW
(
platform
::
errors
::
Unavailable
(
"Getting 'tensor_descs' is not supported by the type of var "
"Getting 'tensor_descs' is not supported by the %s type variable."
,
"%s."
,
this
->
Name
()));
this
->
Name
());
}
}
}
}
...
...
paddle/fluid/framework/var_type.h
浏览文件 @
79149c8e
...
@@ -40,7 +40,8 @@ inline proto::VarType::Type ToVarType(int type) {
...
@@ -40,7 +40,8 @@ inline proto::VarType::Type ToVarType(int type) {
case
proto
::
VarType
::
READER
:
case
proto
::
VarType
::
READER
:
return
static_cast
<
proto
::
VarType
::
Type
>
(
type
);
return
static_cast
<
proto
::
VarType
::
Type
>
(
type
);
default:
default:
PADDLE_THROW
(
"ToVarType:Unsupported type %d"
,
type
);
PADDLE_THROW
(
platform
::
errors
::
Unavailable
(
"ToVarType method Unsupported type %d."
,
type
));
}
}
}
}
...
@@ -66,7 +67,8 @@ inline void VisitVarType(const framework::Variable& var, Visitor visitor) {
...
@@ -66,7 +67,8 @@ inline void VisitVarType(const framework::Variable& var, Visitor visitor) {
visitor
(
var
.
Get
<
FetchList
>
());
visitor
(
var
.
Get
<
FetchList
>
());
return
;
return
;
default:
default:
PADDLE_THROW
(
"Not supported visit type, %s"
,
ToTypeName
(
var
.
Type
()));
PADDLE_THROW
(
platform
::
errors
::
Unavailable
(
"Not supported visit type %s."
,
ToTypeName
(
var
.
Type
())));
}
}
}
}
...
...
paddle/fluid/framework/var_type_traits.cc
浏览文件 @
79149c8e
...
@@ -46,12 +46,14 @@ struct VarIdToTypeIndexMapInitializerImpl {
...
@@ -46,12 +46,14 @@ struct VarIdToTypeIndexMapInitializerImpl {
static_assert
(
!
std
::
is_same
<
Type
,
void
>::
value
,
"Type cannot be void"
);
static_assert
(
!
std
::
is_same
<
Type
,
void
>::
value
,
"Type cannot be void"
);
constexpr
int
kId
=
VarTypeTrait
<
Type
>::
kId
;
constexpr
int
kId
=
VarTypeTrait
<
Type
>::
kId
;
auto
type
=
std
::
type_index
(
typeid
(
Type
));
auto
type
=
std
::
type_index
(
typeid
(
Type
));
PADDLE_ENFORCE
(
id_to_type
->
count
(
kId
)
==
0
,
PADDLE_ENFORCE_EQ
(
"Registered duplicate type id %d for type %s"
,
kId
,
id_to_type
->
count
(
kId
),
0
,
type
.
name
());
platform
::
errors
::
AlreadyExists
(
PADDLE_ENFORCE
(
type_to_id
->
count
(
type
)
==
0
,
"Registered duplicate type id %d for type %s."
,
kId
,
type
.
name
()));
"Registered duplicate type_index %s for id %d"
,
type
.
name
(),
PADDLE_ENFORCE_EQ
(
kId
);
type_to_id
->
count
(
type
),
0
,
platform
::
errors
::
AlreadyExists
(
"Registered duplicate type index %s for id %d."
,
type
.
name
(),
kId
));
id_to_type
->
emplace
(
kId
,
type
);
id_to_type
->
emplace
(
kId
,
type
);
type_to_id
->
emplace
(
type
,
kId
);
type_to_id
->
emplace
(
type
,
kId
);
VarIdToTypeIndexMapInitializerImpl
<
kStart
+
1
,
kEnd
,
VarIdToTypeIndexMapInitializerImpl
<
kStart
+
1
,
kEnd
,
...
@@ -79,15 +81,17 @@ struct VarIdToTypeIndexMapHolder {
...
@@ -79,15 +81,17 @@ struct VarIdToTypeIndexMapHolder {
public:
public:
static
const
std
::
type_index
&
ToTypeIndex
(
int
var_id
)
{
static
const
std
::
type_index
&
ToTypeIndex
(
int
var_id
)
{
auto
it
=
Instance
().
id_to_type_map_
.
find
(
var_id
);
auto
it
=
Instance
().
id_to_type_map_
.
find
(
var_id
);
PADDLE_ENFORCE
(
it
!=
Instance
().
id_to_type_map_
.
end
(),
PADDLE_ENFORCE_NE
(
it
,
Instance
().
id_to_type_map_
.
end
(),
"VarId %d is not registered."
,
var_id
);
platform
::
errors
::
NotFound
(
"Variable Id %d is not registered."
,
var_id
));
return
it
->
second
;
return
it
->
second
;
}
}
static
int
ToTypeId
(
const
std
::
type_index
&
type
)
{
static
int
ToTypeId
(
const
std
::
type_index
&
type
)
{
auto
it
=
Instance
().
type_to_id_map_
.
find
(
type
);
auto
it
=
Instance
().
type_to_id_map_
.
find
(
type
);
PADDLE_ENFORCE
(
it
!=
Instance
().
type_to_id_map_
.
end
(),
PADDLE_ENFORCE_NE
(
it
,
Instance
().
type_to_id_map_
.
end
(),
"VarType %s is not registered."
,
type
.
name
());
platform
::
errors
::
NotFound
(
"Variable Type %s is not registered."
,
type
.
name
()));
return
it
->
second
;
return
it
->
second
;
}
}
...
...
paddle/fluid/framework/variable_helper.cc
浏览文件 @
79149c8e
...
@@ -50,11 +50,11 @@ void InitializeVariable(Variable *var, proto::VarType::Type var_type) {
...
@@ -50,11 +50,11 @@ void InitializeVariable(Variable *var, proto::VarType::Type var_type) {
}
else
if
(
var_type
==
proto
::
VarType
::
RAW
)
{
}
else
if
(
var_type
==
proto
::
VarType
::
RAW
)
{
// GetMutable will be called in operator
// GetMutable will be called in operator
}
else
{
}
else
{
PADDLE_THROW
(
PADDLE_THROW
(
platform
::
errors
::
Unavailable
(
"Variable type %d is not in "
"Variable type %d is not in "
"[LOD_TENSOR, SELECTED_ROWS, FEED_MINIBATCH, FETCH_LIST, "
"[LOD_TENSOR, SELECTED_ROWS, FEED_MINIBATCH, FETCH_LIST, "
"LOD_RANK_TABLE, PLACE_LIST, READER, RAW]"
,
"LOD_RANK_TABLE, PLACE_LIST, READER, RAW]
.
"
,
var_type
);
var_type
)
)
;
}
}
}
}
...
@@ -76,7 +76,8 @@ void CopyVariable(const Variable &src_var, Variable *dst_var) {
...
@@ -76,7 +76,8 @@ void CopyVariable(const Variable &src_var, Variable *dst_var) {
auto
*
dst_t
=
tmp_grad_slr
->
mutable_value
();
auto
*
dst_t
=
tmp_grad_slr
->
mutable_value
();
framework
::
TensorCopy
(
src_t
,
cpu_place
,
dst_t
);
framework
::
TensorCopy
(
src_t
,
cpu_place
,
dst_t
);
}
else
{
}
else
{
PADDLE_THROW
(
"unknown var type to copy"
);
PADDLE_THROW
(
platform
::
errors
::
Unavailable
(
"Unknown variable type to copy."
));
}
}
}
}
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录