Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
Crayon鑫
Paddle
提交
5f6fd26f
P
Paddle
项目概览
Crayon鑫
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
5f6fd26f
编写于
5月 15, 2018
作者:
Y
Yu Yang
提交者:
GitHub
5月 15, 2018
浏览文件
操作
浏览文件
下载
差异文件
Merge pull request #10572 from reyoung/feature/polish_visit_data_type
Polish data_type.h
上级
30c350b9
c70ddb0a
变更
6
隐藏空白更改
内联
并排
Showing
6 changed file
with
114 addition
and
108 deletion
+114
-108
paddle/fluid/framework/CMakeLists.txt
paddle/fluid/framework/CMakeLists.txt
+3
-3
paddle/fluid/framework/data_type.cc
paddle/fluid/framework/data_type.cc
+101
-0
paddle/fluid/framework/data_type.h
paddle/fluid/framework/data_type.h
+5
-62
paddle/fluid/framework/framework.proto
paddle/fluid/framework/framework.proto
+2
-0
paddle/fluid/framework/op_kernel_type_test.cc
paddle/fluid/framework/op_kernel_type_test.cc
+1
-1
paddle/fluid/framework/tensor_impl.h
paddle/fluid/framework/tensor_impl.h
+2
-42
未找到文件。
paddle/fluid/framework/CMakeLists.txt
浏览文件 @
5f6fd26f
...
...
@@ -5,11 +5,11 @@ proto_library(framework_proto SRCS framework.proto)
cc_library
(
ddim SRCS ddim.cc DEPS eigen3 boost
)
cc_test
(
ddim_test SRCS ddim_test.cc DEPS ddim
)
nv_test
(
dim_test SRCS dim_test.cu DEPS ddim
)
cc_library
(
data_type SRCS data_type.cc DEPS framework_proto ddim device_context
)
if
(
WITH_GPU
)
nv_library
(
tensor SRCS tensor.cc tensor_util.cu DEPS
ddim place memory device_context framework_proto
)
nv_library
(
tensor SRCS tensor.cc tensor_util.cu DEPS
place memory data_type
)
else
()
cc_library
(
tensor SRCS tensor.cc tensor_util.cc DEPS
ddim place memory device_context framework_proto
)
cc_library
(
tensor SRCS tensor.cc tensor_util.cc DEPS
place memory data_type
)
endif
()
cc_test
(
tensor_test SRCS tensor_test.cc DEPS tensor
)
...
...
paddle/fluid/framework/data_type.cc
0 → 100644
浏览文件 @
5f6fd26f
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/data_type.h"
#include <stdint.h>
#include <string>
#include <unordered_map>
namespace
paddle
{
namespace
framework
{
struct
DataTypeMap
{
std
::
unordered_map
<
std
::
type_index
,
proto
::
VarType
::
Type
>
cpp_to_proto_
;
std
::
unordered_map
<
int
,
std
::
type_index
>
proto_to_cpp_
;
std
::
unordered_map
<
int
,
std
::
string
>
proto_to_str_
;
std
::
unordered_map
<
std
::
type_index
,
size_t
>
cpp_to_size_
;
};
static
DataTypeMap
*
InitDataTypeMap
();
static
DataTypeMap
&
gDataTypeMap
()
{
static
DataTypeMap
*
g_data_type_map_
=
InitDataTypeMap
();
return
*
g_data_type_map_
;
}
template
<
typename
T
>
static
inline
void
RegisterType
(
DataTypeMap
*
map
,
proto
::
VarType
::
Type
proto_type
,
const
std
::
string
&
name
)
{
map
->
proto_to_cpp_
.
emplace
(
static_cast
<
int
>
(
proto_type
),
typeid
(
T
));
map
->
cpp_to_proto_
.
emplace
(
typeid
(
T
),
proto_type
);
map
->
proto_to_str_
.
emplace
(
static_cast
<
int
>
(
proto_type
),
name
);
map
->
cpp_to_size_
.
emplace
(
typeid
(
T
),
sizeof
(
T
));
}
static
DataTypeMap
*
InitDataTypeMap
()
{
auto
retv
=
new
DataTypeMap
();
#define RegType(cc_type, proto_type) \
RegisterType<cc_type>(retv, proto_type, #cc_type)
// NOTE: Add your customize type here.
RegType
(
platform
::
float16
,
proto
::
VarType
::
FP16
);
RegType
(
float
,
proto
::
VarType
::
FP32
);
RegType
(
double
,
proto
::
VarType
::
FP64
);
RegType
(
int
,
proto
::
VarType
::
INT32
);
RegType
(
int64_t
,
proto
::
VarType
::
INT64
);
RegType
(
bool
,
proto
::
VarType
::
BOOL
);
RegType
(
size_t
,
proto
::
VarType
::
SIZE_T
);
RegType
(
int16_t
,
proto
::
VarType
::
INT16
);
#undef RegType
return
retv
;
}
proto
::
VarType
::
Type
ToDataType
(
std
::
type_index
type
)
{
auto
it
=
gDataTypeMap
().
cpp_to_proto_
.
find
(
type
);
if
(
it
!=
gDataTypeMap
().
cpp_to_proto_
.
end
())
{
return
it
->
second
;
}
PADDLE_THROW
(
"Not support %s as tensor type"
,
type
.
name
());
}
std
::
type_index
ToTypeIndex
(
proto
::
VarType
::
Type
type
)
{
auto
it
=
gDataTypeMap
().
proto_to_cpp_
.
find
(
static_cast
<
int
>
(
type
));
if
(
it
!=
gDataTypeMap
().
proto_to_cpp_
.
end
())
{
return
it
->
second
;
}
PADDLE_THROW
(
"Not support proto::VarType::Type(%d) as tensor type"
,
static_cast
<
int
>
(
type
));
}
std
::
string
DataTypeToString
(
const
proto
::
VarType
::
Type
type
)
{
auto
it
=
gDataTypeMap
().
proto_to_str_
.
find
(
static_cast
<
int
>
(
type
));
if
(
it
!=
gDataTypeMap
().
proto_to_str_
.
end
())
{
return
it
->
second
;
}
PADDLE_THROW
(
"Not support proto::VarType::Type(%d) as tensor type"
,
static_cast
<
int
>
(
type
));
}
size_t
SizeOfType
(
std
::
type_index
type
)
{
auto
it
=
gDataTypeMap
().
cpp_to_size_
.
find
(
type
);
if
(
it
!=
gDataTypeMap
().
cpp_to_size_
.
end
())
{
return
it
->
second
;
}
PADDLE_THROW
(
"Not support %s as tensor type"
,
type
.
name
());
}
}
// namespace framework
}
// namespace paddle
paddle/fluid/framework/data_type.h
浏览文件 @
5f6fd26f
...
...
@@ -17,51 +17,14 @@ limitations under the License. */
#include <typeindex>
#include "paddle/fluid/framework/framework.pb.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/float16.h"
namespace
paddle
{
namespace
framework
{
inline
proto
::
VarType
::
Type
ToDataType
(
std
::
type_index
type
)
{
if
(
typeid
(
platform
::
float16
).
hash_code
()
==
type
.
hash_code
())
{
return
proto
::
VarType
::
FP16
;
}
else
if
(
typeid
(
const
float
).
hash_code
()
==
type
.
hash_code
())
{
// CPPLint complains Using C-style cast. Use static_cast<float>() instead
// One fix to this is to replace float with const float because
// typeid(T) == typeid(const T)
// http://en.cppreference.com/w/cpp/language/typeid
return
proto
::
VarType
::
FP32
;
}
else
if
(
typeid
(
const
double
).
hash_code
()
==
type
.
hash_code
())
{
return
proto
::
VarType
::
FP64
;
}
else
if
(
typeid
(
const
int
).
hash_code
()
==
type
.
hash_code
())
{
return
proto
::
VarType
::
INT32
;
}
else
if
(
typeid
(
const
int64_t
).
hash_code
()
==
type
.
hash_code
())
{
return
proto
::
VarType
::
INT64
;
}
else
if
(
typeid
(
const
bool
).
hash_code
()
==
type
.
hash_code
())
{
return
proto
::
VarType
::
BOOL
;
}
else
{
PADDLE_THROW
(
"Not supported"
);
}
}
inline
std
::
type_index
ToTypeIndex
(
proto
::
VarType
::
Type
type
)
{
switch
(
type
)
{
case
proto
::
VarType
::
FP16
:
return
typeid
(
platform
::
float16
);
case
proto
::
VarType
::
FP32
:
return
typeid
(
float
);
case
proto
::
VarType
::
FP64
:
return
typeid
(
double
);
case
proto
::
VarType
::
INT32
:
return
typeid
(
int
);
case
proto
::
VarType
::
INT64
:
return
typeid
(
int64_t
);
case
proto
::
VarType
::
BOOL
:
return
typeid
(
bool
);
default:
PADDLE_THROW
(
"Not support type %d"
,
type
);
}
}
extern
proto
::
VarType
::
Type
ToDataType
(
std
::
type_index
type
);
extern
std
::
type_index
ToTypeIndex
(
proto
::
VarType
::
Type
type
);
template
<
typename
Visitor
>
inline
void
VisitDataType
(
proto
::
VarType
::
Type
type
,
Visitor
visitor
)
{
...
...
@@ -89,32 +52,12 @@ inline void VisitDataType(proto::VarType::Type type, Visitor visitor) {
}
}
inline
std
::
string
DataTypeToString
(
const
proto
::
VarType
::
Type
type
)
{
switch
(
type
)
{
case
proto
::
VarType
::
FP16
:
return
"float16"
;
case
proto
::
VarType
::
FP32
:
return
"float32"
;
case
proto
::
VarType
::
FP64
:
return
"float64"
;
case
proto
::
VarType
::
INT16
:
return
"int16"
;
case
proto
::
VarType
::
INT32
:
return
"int32"
;
case
proto
::
VarType
::
INT64
:
return
"int64"
;
case
proto
::
VarType
::
BOOL
:
return
"bool"
;
default:
PADDLE_THROW
(
"Not support type %d"
,
type
);
}
}
extern
std
::
string
DataTypeToString
(
const
proto
::
VarType
::
Type
type
);
extern
size_t
SizeOfType
(
std
::
type_index
type
);
inline
std
::
ostream
&
operator
<<
(
std
::
ostream
&
out
,
const
proto
::
VarType
::
Type
&
type
)
{
out
<<
DataTypeToString
(
type
);
return
out
;
}
}
// namespace framework
}
// namespace paddle
paddle/fluid/framework/framework.proto
浏览文件 @
5f6fd26f
...
...
@@ -101,6 +101,8 @@ message VarType {
FP16
=
4
;
FP32
=
5
;
FP64
=
6
;
// Tensor<size_t> is used in C++.
SIZE_T
=
19
;
// Other types that may need additional descriptions
LOD_TENSOR
=
7
;
...
...
paddle/fluid/framework/op_kernel_type_test.cc
浏览文件 @
5f6fd26f
...
...
@@ -27,7 +27,7 @@ TEST(OpKernelType, ToString) {
LibraryType
::
kCUDNN
);
ASSERT_EQ
(
paddle
::
framework
::
KernelTypeToString
(
op_kernel_type
),
"data_type[float
32
]:data_layout[NCHW]:place[CPUPlace]:library_type["
"data_type[float]:data_layout[NCHW]:place[CPUPlace]:library_type["
"CUDNN]"
);
}
...
...
paddle/fluid/framework/tensor_impl.h
浏览文件 @
5f6fd26f
...
...
@@ -13,54 +13,14 @@ See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/memory/memcpy.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/float16.h"
namespace
paddle
{
namespace
framework
{
template
<
typename
...
T
>
struct
SizeOfTypeFunctor
;
template
<
typename
T
>
struct
SizeOfTypeFunctor
<
T
>
{
size_t
operator
()(
std
::
type_index
type
)
const
{
if
(
typeid
(
T
).
hash_code
()
==
type
.
hash_code
())
{
return
sizeof
(
T
);
}
else
{
return
0UL
;
}
}
};
template
<
>
struct
SizeOfTypeFunctor
<>
{
size_t
operator
()(
std
::
type_index
type
)
const
{
return
0UL
;
}
};
template
<
typename
HEAD
,
typename
...
TAIL
>
struct
SizeOfTypeFunctor
<
HEAD
,
TAIL
...
>
{
size_t
operator
()(
std
::
type_index
type
)
const
{
SizeOfTypeFunctor
<
HEAD
>
head
;
size_t
head_size
=
head
(
type
);
if
(
head_size
!=
0
)
{
return
head_size
;
}
SizeOfTypeFunctor
<
TAIL
...
>
tail
;
return
tail
(
type
);
}
};
static
inline
size_t
SizeOfType
(
std
::
type_index
type
)
{
SizeOfTypeFunctor
<
int
,
float
,
double
,
int16_t
,
int64_t
,
bool
,
size_t
,
platform
::
float16
>
functor
;
size_t
size
=
functor
(
type
);
PADDLE_ENFORCE
(
size
!=
0UL
,
"Cannot get size of type %s"
,
type
.
name
());
return
size
;
}
extern
size_t
SizeOfType
(
std
::
type_index
type
);
inline
void
Tensor
::
check_memory_size
()
const
{
PADDLE_ENFORCE_NOT_NULL
(
holder_
,
"Tensor holds no memory. Call Tensor::mutable_data first."
);
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录