Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle-Lite
提交
1ad8f821
P
Paddle-Lite
项目概览
PaddlePaddle
/
Paddle-Lite
通知
337
Star
4
Fork
1
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
271
列表
看板
标记
里程碑
合并请求
78
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle-Lite
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
271
Issue
271
列表
看板
标记
里程碑
合并请求
78
合并请求
78
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
1ad8f821
编写于
5月 20, 2018
作者:
朔-望
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
modify to 2 spaces indent & format code & rm build folder
上级
e35ef6fe
变更
102
展开全部
隐藏空白更改
内联
并排
Showing
102 changed file
with
13543 addition
and
14086 deletion
+13543
-14086
.clang-format
.clang-format
+0
-1
cmake-build-release/compile_commands.json
cmake-build-release/compile_commands.json
+0
-312
src/common/log.h
src/common/log.h
+86
-88
src/common/types.h
src/common/types.h
+24
-24
src/common/variant.h
src/common/variant.h
+50
-50
src/framework/attribute.h
src/framework/attribute.h
+85
-85
src/framework/block_desc.cpp
src/framework/block_desc.cpp
+16
-16
src/framework/block_desc.h
src/framework/block_desc.h
+23
-24
src/framework/data_layout.h
src/framework/data_layout.h
+29
-29
src/framework/data_transform.cpp
src/framework/data_transform.cpp
+52
-52
src/framework/ddim.cc
src/framework/ddim.cc
+174
-176
src/framework/ddim.h
src/framework/ddim.h
+57
-57
src/framework/dim.h
src/framework/dim.h
+139
-141
src/framework/executor.cpp
src/framework/executor.cpp
+48
-48
src/framework/executor.h
src/framework/executor.h
+12
-12
src/framework/framework.pb.cpp
src/framework/framework.pb.cpp
+5392
-5493
src/framework/framework.pb.h
src/framework/framework.pb.h
+3436
-3483
src/framework/lod_tensor.cc
src/framework/lod_tensor.cc
+230
-232
src/framework/lod_tensor.h
src/framework/lod_tensor.h
+58
-58
src/framework/op_desc.cpp
src/framework/op_desc.cpp
+30
-30
src/framework/op_desc.h
src/framework/op_desc.h
+14
-14
src/framework/op_info.h
src/framework/op_info.h
+50
-50
src/framework/op_kernel_type.h
src/framework/op_kernel_type.h
+22
-23
src/framework/operator.cpp
src/framework/operator.cpp
+1
-1
src/framework/operator.h
src/framework/operator.h
+32
-32
src/framework/paddle_mobile_object.h
src/framework/paddle_mobile_object.h
+7
-7
src/framework/program-optimize/node.cpp
src/framework/program-optimize/node.cpp
+42
-42
src/framework/program-optimize/node.h
src/framework/program-optimize/node.h
+15
-15
src/framework/program-optimize/program_optimize.cpp
src/framework/program-optimize/program_optimize.cpp
+37
-38
src/framework/program-optimize/program_optimize.h
src/framework/program-optimize/program_optimize.h
+9
-9
src/framework/program.h
src/framework/program.h
+5
-5
src/framework/program_desc.cpp
src/framework/program_desc.cpp
+5
-5
src/framework/program_desc.h
src/framework/program_desc.h
+7
-7
src/framework/scope.cc
src/framework/scope.cc
+58
-58
src/framework/scope.h
src/framework/scope.h
+34
-34
src/framework/selected_rows.h
src/framework/selected_rows.h
+48
-48
src/framework/tensor.h
src/framework/tensor.h
+263
-264
src/framework/tensor_util.cc
src/framework/tensor_util.cc
+131
-132
src/framework/tensor_util.h
src/framework/tensor_util.h
+10
-10
src/framework/var_desc.h
src/framework/var_desc.h
+45
-45
src/framework/var_type.h
src/framework/var_type.h
+8
-8
src/framework/variable.h
src/framework/variable.h
+51
-51
src/io.cpp
src/io.cpp
+329
-333
src/io.h
src/io.h
+4
-4
src/memory/t_malloc.cc
src/memory/t_malloc.cc
+13
-13
src/memory/t_malloc.h
src/memory/t_malloc.h
+7
-7
src/operators/batchnorm_op.cpp
src/operators/batchnorm_op.cpp
+2
-2
src/operators/batchnorm_op.h
src/operators/batchnorm_op.h
+19
-19
src/operators/concat_op.cpp
src/operators/concat_op.cpp
+27
-27
src/operators/concat_op.h
src/operators/concat_op.h
+15
-16
src/operators/conv_op.cpp
src/operators/conv_op.cpp
+24
-24
src/operators/conv_op.h
src/operators/conv_op.h
+19
-19
src/operators/elementwise_add_op.cpp
src/operators/elementwise_add_op.cpp
+2
-2
src/operators/elementwise_add_op.h
src/operators/elementwise_add_op.h
+16
-16
src/operators/kernel/arm/batchnorm_kernel.cpp
src/operators/kernel/arm/batchnorm_kernel.cpp
+60
-61
src/operators/kernel/arm/concat_kernel.cpp
src/operators/kernel/arm/concat_kernel.cpp
+70
-71
src/operators/kernel/arm/conv_kernel.cpp
src/operators/kernel/arm/conv_kernel.cpp
+114
-116
src/operators/kernel/arm/elementwise_add_kernel.cpp
src/operators/kernel/arm/elementwise_add_kernel.cpp
+8
-8
src/operators/kernel/arm/lrn_kernel.cpp
src/operators/kernel/arm/lrn_kernel.cpp
+15
-15
src/operators/kernel/arm/mul_kernel.cpp
src/operators/kernel/arm/mul_kernel.cpp
+21
-21
src/operators/kernel/arm/pool_kernel.cpp
src/operators/kernel/arm/pool_kernel.cpp
+37
-37
src/operators/kernel/batchnorm_kernel.h
src/operators/kernel/batchnorm_kernel.h
+2
-2
src/operators/kernel/concat_kernel.h
src/operators/kernel/concat_kernel.h
+2
-2
src/operators/kernel/conv_kernel.h
src/operators/kernel/conv_kernel.h
+2
-2
src/operators/kernel/elementwise_add_kernel.h
src/operators/kernel/elementwise_add_kernel.h
+2
-2
src/operators/kernel/lrn_kernel.h
src/operators/kernel/lrn_kernel.h
+32
-35
src/operators/kernel/mul_kernel.h
src/operators/kernel/mul_kernel.h
+2
-2
src/operators/kernel/pool_kernel.h
src/operators/kernel/pool_kernel.h
+2
-2
src/operators/lrn_op.cpp
src/operators/lrn_op.cpp
+2
-2
src/operators/lrn_op.h
src/operators/lrn_op.h
+18
-18
src/operators/math/elementwise_op_function.h
src/operators/math/elementwise_op_function.h
+137
-138
src/operators/math/im2col.cc
src/operators/math/im2col.cc
+215
-232
src/operators/math/im2col.h
src/operators/math/im2col.h
+9
-10
src/operators/math/math_function.cc
src/operators/math/math_function.cc
+62
-62
src/operators/math/pool3x3.h
src/operators/math/pool3x3.h
+2
-2
src/operators/math/pool_2x2.h
src/operators/math/pool_2x2.h
+2
-2
src/operators/math/pooling.cpp
src/operators/math/pooling.cpp
+47
-49
src/operators/math/pooling.h
src/operators/math/pooling.h
+13
-14
src/operators/math/transform.h
src/operators/math/transform.h
+12
-12
src/operators/math/vol2col.cc
src/operators/math/vol2col.cc
+149
-162
src/operators/math/vol2col.h
src/operators/math/vol2col.h
+8
-8
src/operators/mul_op.cpp
src/operators/mul_op.cpp
+21
-21
src/operators/mul_op.h
src/operators/mul_op.h
+18
-18
src/operators/op_param.cpp
src/operators/op_param.cpp
+18
-19
src/operators/op_param.h
src/operators/op_param.h
+277
-279
src/operators/pool_op.cpp
src/operators/pool_op.cpp
+25
-25
src/operators/pool_op.h
src/operators/pool_op.h
+17
-17
src/platform/data_type.h
src/platform/data_type.h
+83
-83
src/platform/macros.h
src/platform/macros.h
+5
-5
test/common/test_log.cpp
test/common/test_log.cpp
+11
-11
test/framework/executor_for_test.cpp
test/framework/executor_for_test.cpp
+30
-30
test/framework/executor_for_test.h
test/framework/executor_for_test.h
+4
-4
test/framework/test_load.cpp
test/framework/test_load.cpp
+5
-5
test/framework/test_optimize.cpp
test/framework/test_optimize.cpp
+9
-9
test/operators/test_batchnorm_op.cpp
test/operators/test_batchnorm_op.cpp
+134
-138
test/operators/test_concat_op.cpp
test/operators/test_concat_op.cpp
+137
-141
test/operators/test_cov_op.cpp
test/operators/test_cov_op.cpp
+18
-18
test/operators/test_elementwise_add_op.cpp
test/operators/test_elementwise_add_op.cpp
+110
-113
test/operators/test_lrn_op.cpp
test/operators/test_lrn_op.cpp
+110
-115
test/operators/test_mul_op.cpp
test/operators/test_mul_op.cpp
+129
-131
test/operators/test_pool_op.cpp
test/operators/test_pool_op.cpp
+18
-18
test/test_helper.h
test/test_helper.h
+7
-8
未找到文件。
.clang-format
浏览文件 @
1ad8f821
...
@@ -2,5 +2,4 @@
...
@@ -2,5 +2,4 @@
Language: Cpp
Language: Cpp
BasedOnStyle: LLVM
BasedOnStyle: LLVM
Standard: Cpp11
Standard: Cpp11
IndentWidth: 4
...
...
cmake-build-release/compile_commands.json
已删除
100644 → 0
浏览文件 @
e35ef6fe
此差异已折叠。
点击以展开。
src/common/log.h
浏览文件 @
1ad8f821
...
@@ -28,15 +28,15 @@ SOFTWARE.
...
@@ -28,15 +28,15 @@ SOFTWARE.
namespace
paddle_mobile
{
namespace
paddle_mobile
{
enum
LogLevel
{
enum
LogLevel
{
kNO_LOG
,
kNO_LOG
,
kLOG_ERROR
,
kLOG_ERROR
,
kLOG_WARNING
,
kLOG_WARNING
,
kLOG_INFO
,
kLOG_INFO
,
kLOG_DEBUG
,
kLOG_DEBUG
,
kLOG_DEBUG1
,
kLOG_DEBUG1
,
kLOG_DEBUG2
,
kLOG_DEBUG2
,
kLOG_DEBUG3
,
kLOG_DEBUG3
,
kLOG_DEBUG4
kLOG_DEBUG4
};
};
// log level
// log level
...
@@ -49,119 +49,117 @@ struct ToLog;
...
@@ -49,119 +49,117 @@ struct ToLog;
struct
Print
;
struct
Print
;
struct
Print
{
struct
Print
{
friend
struct
ToLog
;
friend
struct
ToLog
;
template
<
typename
T
>
Print
&
operator
<<
(
T
const
&
value
)
{
template
<
typename
T
>
Print
&
operator
<<
(
T
const
&
value
)
{
buffer_
<<
value
;
buffer_
<<
value
;
return
*
this
;
return
*
this
;
}
}
private:
private:
void
print
(
LogLevel
level
)
{
void
print
(
LogLevel
level
)
{
buffer_
<<
std
::
endl
;
buffer_
<<
std
::
endl
;
if
(
level
==
kLOG_ERROR
)
{
if
(
level
==
kLOG_ERROR
)
{
std
::
cerr
<<
buffer_
.
str
();
std
::
cerr
<<
buffer_
.
str
();
}
else
{
}
else
{
std
::
cout
<<
buffer_
.
str
();
std
::
cout
<<
buffer_
.
str
();
}
}
}
std
::
ostringstream
buffer_
;
}
std
::
ostringstream
buffer_
;
};
};
struct
ToLog
{
struct
ToLog
{
ToLog
(
LogLevel
level
=
kLOG_DEBUG
,
const
std
::
string
&
info
=
""
)
ToLog
(
LogLevel
level
=
kLOG_DEBUG
,
const
std
::
string
&
info
=
""
)
:
level_
(
level
)
{
:
level_
(
level
)
{
unsigned
blanks
=
unsigned
blanks
=
(
unsigned
)(
level
>
kLOG_DEBUG
?
(
level
-
kLOG_DEBUG
)
*
4
:
1
);
(
unsigned
)(
level
>
kLOG_DEBUG
?
(
level
-
kLOG_DEBUG
)
*
4
:
1
);
printer_
<<
logs
[
level
]
<<
" "
<<
info
<<
":"
printer_
<<
logs
[
level
]
<<
" "
<<
info
<<
":"
<<
std
::
string
(
blanks
,
' '
);
<<
std
::
string
(
blanks
,
' '
);
}
}
template
<
typename
T
>
ToLog
&
operator
<<
(
T
const
&
value
)
{
template
<
typename
T
>
ToLog
&
operator
<<
(
T
const
&
value
)
{
printer_
<<
value
;
printer_
<<
value
;
return
*
this
;
return
*
this
;
}
}
~
ToLog
()
{
printer_
.
print
(
level_
);
}
~
ToLog
()
{
printer_
.
print
(
level_
);
}
private:
private:
LogLevel
level_
;
LogLevel
level_
;
Print
printer_
;
Print
printer_
;
};
};
#define LOG(level) \
#define LOG(level) \
if (level > paddle_mobile::log_level) {
\
if (level > paddle_mobile::log_level) {
\
} else
\
} else
\
paddle_mobile::ToLog(
\
paddle_mobile::ToLog(
\
level, (std::stringstream()
\
level,
\
<< "[file: "
\
(std::stringstream()
\
<< (strrchr(__FILE__, '/') ? (strrchr(__FILE__, '/') + 1)
\
<< "[file: "
\
: __FILE__)
\
<< (strrchr(__FILE__, '/') ? (strrchr(__FILE__, '/') + 1) : __FILE__)
\
<< "] [line: " << __LINE__ << "] ")
\
<< "] [line: " << __LINE__ << "] ")
\
.str())
.str())
#define DLOG \
#define DLOG \
if (paddle_mobile::kLOG_DEBUG > paddle_mobile::log_level) { \
if (paddle_mobile::kLOG_DEBUG > paddle_mobile::log_level) { \
} else \
} else \
paddle_mobile::ToLog( \
paddle_mobile::ToLog( \
paddle_mobile::kLOG_DEBUG, \
paddle_mobile::kLOG_DEBUG, \
(std::stringstream() \
(std::stringstream() \
<< "[file: " \
<< "[file: " \
<< (strrchr(__FILE__, '/') ? (strrchr(__FILE__, '/') + 1) \
<< (strrchr(__FILE__, '/') ? (strrchr(__FILE__, '/') + 1) : __FILE__) \
: __FILE__) \
<< "] [line: " << __LINE__ << "] ") \
<< "] [line: " << __LINE__ << "] ") \
.str())
.str())
}
// namespace paddle_mobile
}
// namespace paddle_mobile
#define LOGF(level, format, ...) \
#define LOGF(level, format, ...) \
if (level > paddle_mobile::log_level) {
\
if (level > paddle_mobile::log_level) {
\
} else
\
} else
\
printf(format, ##__VA_ARGS__)
printf(format, ##__VA_ARGS__)
#define DLOGF(format, ...) \
#define DLOGF(format, ...) \
if (paddle_mobile::kLOG_DEBUG > paddle_mobile::log_level) {
\
if (paddle_mobile::kLOG_DEBUG > paddle_mobile::log_level) {
\
} else
\
} else
\
printf(format, ##__VA_ARGS__)
printf(format, ##__VA_ARGS__)
#else
#else
namespace
paddle_mobile
{
namespace
paddle_mobile
{
enum
LogLevel
{
enum
LogLevel
{
kNO_LOG
,
kNO_LOG
,
kLOG_ERROR
,
kLOG_ERROR
,
kLOG_WARNING
,
kLOG_WARNING
,
kLOG_INFO
,
kLOG_INFO
,
kLOG_DEBUG
,
kLOG_DEBUG
,
kLOG_DEBUG1
,
kLOG_DEBUG1
,
kLOG_DEBUG2
,
kLOG_DEBUG2
,
kLOG_DEBUG3
,
kLOG_DEBUG3
,
kLOG_DEBUG4
kLOG_DEBUG4
};
};
struct
ToLog
;
struct
ToLog
;
struct
Print
{
struct
Print
{
friend
struct
ToLog
;
friend
struct
ToLog
;
template
<
typename
T
>
Print
&
operator
<<
(
T
const
&
value
)
{}
template
<
typename
T
>
Print
&
operator
<<
(
T
const
&
value
)
{}
private:
private:
};
};
struct
ToLog
{
struct
ToLog
{
ToLog
(
LogLevel
level
)
{}
ToLog
(
LogLevel
level
)
{}
template
<
typename
T
>
ToLog
&
operator
<<
(
T
const
&
value
)
{
return
*
this
;
}
template
<
typename
T
>
ToLog
&
operator
<<
(
T
const
&
value
)
{
return
*
this
;
}
};
};
#define LOG(level) \
#define LOG(level) \
if (true) {
\
if (true) {
\
} else
\
} else
\
paddle_mobile::ToLog(level)
paddle_mobile::ToLog(level)
#define DLOG \
#define DLOG \
if (true) {
\
if (true) {
\
} else
\
} else
\
paddle_mobile::ToLog(paddle_mobile::kLOG_DEBUG)
paddle_mobile::ToLog(paddle_mobile::kLOG_DEBUG)
#define LOGF(level, format, ...)
#define LOGF(level, format, ...)
...
...
src/common/types.h
浏览文件 @
1ad8f821
...
@@ -32,32 +32,32 @@ typedef DeviceType<kGPU_MALI> GPU_MALI;
...
@@ -32,32 +32,32 @@ typedef DeviceType<kGPU_MALI> GPU_MALI;
//! data type
//! data type
enum
DataType
{
enum
DataType
{
PM_INVALID
=
-
1
,
PM_INVALID
=
-
1
,
PM_HALF
=
0
,
PM_HALF
=
0
,
PM_FLOAT
=
1
,
PM_FLOAT
=
1
,
PM_DOUBLE
=
2
,
PM_DOUBLE
=
2
,
PM_INT8
=
3
,
PM_INT8
=
3
,
PM_INT16
=
4
,
PM_INT16
=
4
,
PM_INT32
=
5
,
PM_INT32
=
5
,
PM_INT64
=
6
,
PM_INT64
=
6
,
PM_UINT8
=
7
,
PM_UINT8
=
7
,
PM_UINT16
=
8
,
PM_UINT16
=
8
,
PM_UINT32
=
9
,
PM_UINT32
=
9
,
PM_STRING
=
10
,
PM_STRING
=
10
,
PM_BOOL
=
11
,
PM_BOOL
=
11
,
PM_SHAPE
=
12
,
PM_SHAPE
=
12
,
PM_TENSOR
=
13
PM_TENSOR
=
13
};
};
//!
//!
enum
PMStatus
{
enum
PMStatus
{
PMSuccess
=
0xFF
,
/*!< No errors */
PMSuccess
=
0xFF
,
/*!< No errors */
PMNotInitialized
=
0x01
,
/*!< Data not initialized. */
PMNotInitialized
=
0x01
,
/*!< Data not initialized. */
PMInvalidValue
=
0x02
,
/*!< Incorrect variable value. */
PMInvalidValue
=
0x02
,
/*!< Incorrect variable value. */
PMMemAllocFailed
=
0x03
,
/*!< Memory allocation error. */
PMMemAllocFailed
=
0x03
,
/*!< Memory allocation error. */
PMUnKownError
=
0x04
,
/*!< Unknown error. */
PMUnKownError
=
0x04
,
/*!< Unknown error. */
PMOutOfAuthority
=
0x05
,
/*!< Try to modified data not your own*/
PMOutOfAuthority
=
0x05
,
/*!< Try to modified data not your own*/
PMOutOfMem
=
0x06
,
/*!< OOM error*/
PMOutOfMem
=
0x06
,
/*!< OOM error*/
PMUnImplError
=
0x07
,
/*!< Unimplement error. */
PMUnImplError
=
0x07
,
/*!< Unimplement error. */
PMWrongDevice
=
0x08
/*!< un-correct device. */
PMWrongDevice
=
0x08
/*!< un-correct device. */
};
};
}
// namespace paddle_mobile
}
// namespace paddle_mobile
src/common/variant.h
浏览文件 @
1ad8f821
...
@@ -24,74 +24,74 @@ namespace paddle_mobile {
...
@@ -24,74 +24,74 @@ namespace paddle_mobile {
template
<
int
ID
,
typename
Type
>
struct
IDToType
{
typedef
Type
type_t
;
};
template
<
int
ID
,
typename
Type
>
struct
IDToType
{
typedef
Type
type_t
;
};
template
<
typename
F
,
typename
...
Ts
>
struct
VariantHelper
{
template
<
typename
F
,
typename
...
Ts
>
struct
VariantHelper
{
static
const
size_t
size
=
sizeof
(
F
)
>
VariantHelper
<
Ts
...
>::
size
static
const
size_t
size
=
sizeof
(
F
)
>
VariantHelper
<
Ts
...
>::
size
?
sizeof
(
F
)
?
sizeof
(
F
)
:
VariantHelper
<
Ts
...
>::
size
;
:
VariantHelper
<
Ts
...
>::
size
;
inline
static
void
Destroy
(
size_t
id
,
void
*
data
)
{
inline
static
void
Destroy
(
size_t
id
,
void
*
data
)
{
if
(
id
==
typeid
(
F
).
hash_code
())
{
if
(
id
==
typeid
(
F
).
hash_code
())
{
reinterpret_cast
<
F
*>
(
data
)
->~
F
();
reinterpret_cast
<
F
*>
(
data
)
->~
F
();
}
else
{
}
else
{
VariantHelper
<
Ts
...
>::
Destroy
(
id
,
data
);
VariantHelper
<
Ts
...
>::
Destroy
(
id
,
data
);
}
}
}
}
};
};
template
<
typename
F
>
struct
VariantHelper
<
F
>
{
template
<
typename
F
>
struct
VariantHelper
<
F
>
{
static
const
size_t
size
=
sizeof
(
F
);
static
const
size_t
size
=
sizeof
(
F
);
inline
static
void
Destroy
(
size_t
id
,
void
*
data
)
{
inline
static
void
Destroy
(
size_t
id
,
void
*
data
)
{
if
(
id
==
typeid
(
F
).
hash_code
())
{
if
(
id
==
typeid
(
F
).
hash_code
())
{
// reinterpret_cast<F*>(data)->~F();
// reinterpret_cast<F*>(data)->~F();
}
else
{
}
else
{
// std::cout << "未匹配到 " << std::endl;
// std::cout << "未匹配到 " << std::endl;
}
}
}
}
};
};
template
<
size_t
size
>
class
RawData
{
template
<
size_t
size
>
class
RawData
{
public:
public:
char
data
[
size
];
char
data
[
size
];
RawData
()
{}
RawData
()
{}
RawData
(
const
RawData
&
raw_data
)
{
strcpy
(
data
,
raw_data
.
data
);
}
RawData
(
const
RawData
&
raw_data
)
{
strcpy
(
data
,
raw_data
.
data
);
}
// void operator=(const RawData &raw_data){
// void operator=(const RawData &raw_data){
// strcpy(data, raw_data.data);
// strcpy(data, raw_data.data);
// }
// }
};
};
template
<
typename
...
Ts
>
struct
Variant
{
template
<
typename
...
Ts
>
struct
Variant
{
Variant
(
const
Variant
&
variant
)
{
Variant
(
const
Variant
&
variant
)
{
// std::cout << " 赋值构造函数 " << std::endl;
// std::cout << " 赋值构造函数 " << std::endl;
type_id
=
variant
.
type_id
;
type_id
=
variant
.
type_id
;
data
=
variant
.
data
;
data
=
variant
.
data
;
}
}
Variant
()
:
type_id
(
invalid_type
())
{}
Variant
()
:
type_id
(
invalid_type
())
{}
~
Variant
()
{
~
Variant
()
{
// helper::Destroy(type_id, &data);
// helper::Destroy(type_id, &data);
}
}
template
<
typename
T
,
typename
...
Args
>
void
Set
(
Args
&&
...
args
)
{
template
<
typename
T
,
typename
...
Args
>
void
Set
(
Args
&&
...
args
)
{
helper
::
Destroy
(
type_id
,
&
data
);
helper
::
Destroy
(
type_id
,
&
data
);
new
(
&
data
)
T
(
std
::
forward
<
Args
>
(
args
)...);
new
(
&
data
)
T
(
std
::
forward
<
Args
>
(
args
)...);
type_id
=
typeid
(
T
).
hash_code
();
type_id
=
typeid
(
T
).
hash_code
();
}
}
template
<
typename
T
>
T
&
Get
()
const
{
template
<
typename
T
>
T
&
Get
()
const
{
if
(
type_id
==
typeid
(
T
).
hash_code
())
{
if
(
type_id
==
typeid
(
T
).
hash_code
())
{
return
*
const_cast
<
T
*>
(
reinterpret_cast
<
const
T
*>
(
&
data
));
return
*
const_cast
<
T
*>
(
reinterpret_cast
<
const
T
*>
(
&
data
));
}
else
{
}
else
{
// std::cout << " bad cast in variant " << std::endl;
// std::cout << " bad cast in variant " << std::endl;
throw
std
::
bad_cast
();
throw
std
::
bad_cast
();
}
}
}
}
size_t
TypeId
()
const
{
return
type_id
;
}
size_t
TypeId
()
const
{
return
type_id
;
}
private:
private:
static
inline
size_t
invalid_type
()
{
return
typeid
(
void
).
hash_code
();
}
static
inline
size_t
invalid_type
()
{
return
typeid
(
void
).
hash_code
();
}
typedef
VariantHelper
<
Ts
...
>
helper
;
typedef
VariantHelper
<
Ts
...
>
helper
;
size_t
type_id
;
size_t
type_id
;
RawData
<
helper
::
size
>
data
;
RawData
<
helper
::
size
>
data
;
};
};
template
<
typename
T
>
struct
Vistor
{
typedef
T
type_t
;
};
template
<
typename
T
>
struct
Vistor
{
typedef
T
type_t
;
};
...
...
src/framework/attribute.h
浏览文件 @
1ad8f821
...
@@ -27,102 +27,102 @@ namespace framework {
...
@@ -27,102 +27,102 @@ namespace framework {
class
BlockDesc
;
class
BlockDesc
;
class
Attribute
{
class
Attribute
{
public:
public:
static
Attribute
GetAttrValue
(
const
proto
::
OpDesc
::
Attr
&
attr_desc
)
{
static
Attribute
GetAttrValue
(
const
proto
::
OpDesc
::
Attr
&
attr_desc
)
{
// std::cout << "begin get attr value" << std::endl;
// std::cout << "begin get attr value" << std::endl;
Attribute
attr
;
Attribute
attr
;
switch
(
attr_desc
.
type
())
{
switch
(
attr_desc
.
type
())
{
case
proto
::
AttrType
::
BOOLEAN
:
{
case
proto
::
AttrType
::
BOOLEAN
:
{
attr
.
Set
<
bool
>
(
attr_desc
.
b
());
attr
.
Set
<
bool
>
(
attr_desc
.
b
());
break
;
break
;
}
case
proto
::
AttrType
::
INT
:
{
attr
.
Set
<
int
>
(
attr_desc
.
i
());
break
;
}
case
proto
::
AttrType
::
FLOAT
:
{
attr
.
Set
<
float
>
(
attr_desc
.
f
());
break
;
}
case
proto
::
AttrType
::
STRING
:
{
attr
.
Set
<
std
::
string
>
(
attr_desc
.
s
());
break
;
}
case
proto
::
AttrType
::
BOOLEANS
:
{
std
::
vector
<
bool
>
val
(
attr_desc
.
bools_size
());
for
(
int
i
=
0
;
i
<
attr_desc
.
bools_size
();
++
i
)
{
val
[
i
]
=
attr_desc
.
bools
(
i
);
}
attr
.
Set
<
std
::
vector
<
bool
>>
(
val
);
break
;
}
case
proto
::
AttrType
::
INTS
:
{
std
::
vector
<
int
>
val
(
attr_desc
.
ints_size
());
for
(
int
i
=
0
;
i
<
attr_desc
.
ints_size
();
++
i
)
{
val
[
i
]
=
attr_desc
.
ints
(
i
);
}
attr
.
Set
<
std
::
vector
<
int
>>
(
val
);
break
;
}
case
proto
::
AttrType
::
FLOATS
:
{
std
::
vector
<
float
>
val
(
attr_desc
.
floats_size
());
for
(
int
i
=
0
;
i
<
attr_desc
.
floats_size
();
++
i
)
{
val
[
i
]
=
attr_desc
.
floats
(
i
);
}
attr
.
Set
<
std
::
vector
<
float
>>
(
val
);
break
;
}
case
proto
::
AttrType
::
STRINGS
:
{
std
::
vector
<
std
::
string
>
val
(
attr_desc
.
strings_size
());
for
(
int
i
=
0
;
i
<
attr_desc
.
strings_size
();
++
i
)
{
val
[
i
]
=
attr_desc
.
strings
(
i
);
}
attr
.
Set
<
std
::
vector
<
std
::
string
>>
(
val
);
break
;
}
case
proto
::
AttrType
::
LONG
:
{
attr
.
Set
<
int64_t
>
(
attr_desc
.
l
());
break
;
}
default:
// std::cout << " not support " << std::endl;
break
;
}
// std::cout << "end get attr value" << std::endl;
return
attr
;
}
}
case
proto
::
AttrType
::
INT
:
{
Attribute
()
{}
attr
.
Set
<
int
>
(
attr_desc
.
i
());
template
<
typename
T
,
typename
...
Args
>
Attribute
&
Set
(
Args
&&
...
args
)
{
break
;
variant_
.
Set
<
T
>
(
args
...);
}
return
*
this
;
case
proto
::
AttrType
::
FLOAT
:
{
attr
.
Set
<
float
>
(
attr_desc
.
f
());
break
;
}
case
proto
::
AttrType
::
STRING
:
{
attr
.
Set
<
std
::
string
>
(
attr_desc
.
s
());
break
;
}
case
proto
::
AttrType
::
BOOLEANS
:
{
std
::
vector
<
bool
>
val
(
attr_desc
.
bools_size
());
for
(
int
i
=
0
;
i
<
attr_desc
.
bools_size
();
++
i
)
{
val
[
i
]
=
attr_desc
.
bools
(
i
);
}
attr
.
Set
<
std
::
vector
<
bool
>>
(
val
);
break
;
}
case
proto
::
AttrType
::
INTS
:
{
std
::
vector
<
int
>
val
(
attr_desc
.
ints_size
());
for
(
int
i
=
0
;
i
<
attr_desc
.
ints_size
();
++
i
)
{
val
[
i
]
=
attr_desc
.
ints
(
i
);
}
attr
.
Set
<
std
::
vector
<
int
>>
(
val
);
break
;
}
case
proto
::
AttrType
::
FLOATS
:
{
std
::
vector
<
float
>
val
(
attr_desc
.
floats_size
());
for
(
int
i
=
0
;
i
<
attr_desc
.
floats_size
();
++
i
)
{
val
[
i
]
=
attr_desc
.
floats
(
i
);
}
attr
.
Set
<
std
::
vector
<
float
>>
(
val
);
break
;
}
}
case
proto
::
AttrType
::
STRINGS
:
{
std
::
vector
<
std
::
string
>
val
(
attr_desc
.
strings_size
());
for
(
int
i
=
0
;
i
<
attr_desc
.
strings_size
();
++
i
)
{
val
[
i
]
=
attr_desc
.
strings
(
i
);
}
attr
.
Set
<
std
::
vector
<
std
::
string
>>
(
val
);
break
;
}
case
proto
::
AttrType
::
LONG
:
{
attr
.
Set
<
int64_t
>
(
attr_desc
.
l
());
break
;
}
default:
// std::cout << " not support " << std::endl;
break
;
}
// std::cout << "end get attr value" << std::endl;
return
attr
;
}
template
<
typename
T
>
T
&
Get
()
const
{
return
variant_
.
Get
<
T
>
();
}
Attribute
()
{}
template
<
typename
T
,
typename
...
Args
>
Attribute
&
Set
(
Args
&&
...
args
)
{
variant_
.
Set
<
T
>
(
args
...);
return
*
this
;
}
private:
template
<
typename
T
>
T
&
Get
()
const
{
return
variant_
.
Get
<
T
>
();
}
Variant
<
int
,
float
,
std
::
string
,
std
::
vector
<
int
>
,
std
::
vector
<
float
>
,
std
::
vector
<
std
::
string
>
,
bool
,
std
::
vector
<
bool
>
,
BlockDesc
*
,
private:
int64_t
>
Variant
<
int
,
float
,
std
::
string
,
std
::
vector
<
int
>
,
std
::
vector
<
float
>
,
variant_
;
std
::
vector
<
std
::
string
>
,
bool
,
std
::
vector
<
bool
>
,
BlockDesc
*
,
int64_t
>
variant_
;
};
};
using
AttributeMap
=
std
::
unordered_map
<
std
::
string
,
Attribute
>
;
using
AttributeMap
=
std
::
unordered_map
<
std
::
string
,
Attribute
>
;
class
AttrReader
{
class
AttrReader
{
public:
public:
explicit
AttrReader
(
const
AttributeMap
&
attrs
)
:
attrs_
(
attrs
)
{}
explicit
AttrReader
(
const
AttributeMap
&
attrs
)
:
attrs_
(
attrs
)
{}
template
<
typename
T
>
inline
T
Get
(
const
std
::
string
&
name
)
const
{
template
<
typename
T
>
inline
T
Get
(
const
std
::
string
&
name
)
const
{
// PADDLE_ENFORCE(attrs_.count(name) != 0, "%s should
// PADDLE_ENFORCE(attrs_.count(name) != 0, "%s should
// be in
// be in
// AttributeMap",
// AttributeMap",
// name);
// name);
return
((
Attribute
)
attrs_
.
at
(
name
)).
Get
<
T
>
();
return
((
Attribute
)
attrs_
.
at
(
name
)).
Get
<
T
>
();
}
}
private:
private:
const
AttributeMap
&
attrs_
;
const
AttributeMap
&
attrs_
;
};
};
}
// namespace framework
}
// namespace framework
...
...
src/framework/block_desc.cpp
浏览文件 @
1ad8f821
...
@@ -22,28 +22,28 @@ namespace paddle_mobile {
...
@@ -22,28 +22,28 @@ namespace paddle_mobile {
namespace
framework
{
namespace
framework
{
std
::
vector
<
std
::
shared_ptr
<
VarDesc
>>
BlockDesc
::
Vars
()
const
{
std
::
vector
<
std
::
shared_ptr
<
VarDesc
>>
BlockDesc
::
Vars
()
const
{
std
::
vector
<
std
::
shared_ptr
<
VarDesc
>>
res
;
std
::
vector
<
std
::
shared_ptr
<
VarDesc
>>
res
;
for
(
const
auto
&
p
:
vars_
)
{
for
(
const
auto
&
p
:
vars_
)
{
res
.
push_back
(
p
.
second
);
res
.
push_back
(
p
.
second
);
}
}
return
res
;
return
res
;
}
}
std
::
vector
<
std
::
shared_ptr
<
OpDesc
>>
BlockDesc
::
Ops
()
const
{
std
::
vector
<
std
::
shared_ptr
<
OpDesc
>>
BlockDesc
::
Ops
()
const
{
std
::
vector
<
std
::
shared_ptr
<
OpDesc
>>
res
;
std
::
vector
<
std
::
shared_ptr
<
OpDesc
>>
res
;
for
(
const
auto
&
op
:
ops_
)
{
for
(
const
auto
&
op
:
ops_
)
{
res
.
push_back
(
op
);
res
.
push_back
(
op
);
}
}
return
res
;
return
res
;
}
}
BlockDesc
::
BlockDesc
(
const
proto
::
BlockDesc
&
desc
)
:
desc_
(
desc
)
{
BlockDesc
::
BlockDesc
(
const
proto
::
BlockDesc
&
desc
)
:
desc_
(
desc
)
{
for
(
const
proto
::
VarDesc
&
var_desc
:
desc_
.
vars
())
{
for
(
const
proto
::
VarDesc
&
var_desc
:
desc_
.
vars
())
{
vars_
[
var_desc
.
name
()].
reset
(
new
VarDesc
(
var_desc
));
vars_
[
var_desc
.
name
()].
reset
(
new
VarDesc
(
var_desc
));
}
}
for
(
const
proto
::
OpDesc
&
op_desc
:
desc_
.
ops
())
{
for
(
const
proto
::
OpDesc
&
op_desc
:
desc_
.
ops
())
{
ops_
.
emplace_back
(
new
framework
::
OpDesc
(
op_desc
));
ops_
.
emplace_back
(
new
framework
::
OpDesc
(
op_desc
));
}
}
}
}
}
// namespace framework
}
// namespace framework
...
...
src/framework/block_desc.h
浏览文件 @
1ad8f821
...
@@ -27,29 +27,28 @@ namespace paddle_mobile {
...
@@ -27,29 +27,28 @@ namespace paddle_mobile {
namespace
framework
{
namespace
framework
{
class
BlockDesc
:
PaddleMobileObject
{
class
BlockDesc
:
PaddleMobileObject
{
public:
public:
BlockDesc
(
const
proto
::
BlockDesc
&
desc
);
BlockDesc
(
const
proto
::
BlockDesc
&
desc
);
const
int
&
ID
()
const
{
return
desc_
.
idx
();
}
const
int
&
ID
()
const
{
return
desc_
.
idx
();
}
const
int
&
Parent
()
const
{
return
desc_
.
parent_idx
();
}
const
int
&
Parent
()
const
{
return
desc_
.
parent_idx
();
}
bool
operator
==
(
const
paddle_mobile
::
framework
::
BlockDesc
&
in_block
)
const
{
bool
operator
==
(
const
paddle_mobile
::
framework
::
BlockDesc
&
in_block
)
const
{
return
this
->
ID
()
==
in_block
.
ID
()
&&
return
this
->
ID
()
==
in_block
.
ID
()
&&
this
->
Parent
()
==
in_block
.
Parent
();
this
->
Parent
()
==
in_block
.
Parent
();
}
}
bool
operator
<
(
const
paddle_mobile
::
framework
::
BlockDesc
&
in_block
)
const
{
bool
operator
<
(
const
paddle_mobile
::
framework
::
BlockDesc
&
in_block
)
const
{
return
this
->
ID
()
<
in_block
.
ID
()
&&
this
->
Parent
()
<
in_block
.
Parent
();
return
this
->
ID
()
<
in_block
.
ID
()
&&
this
->
Parent
()
<
in_block
.
Parent
();
}
}
std
::
vector
<
std
::
shared_ptr
<
VarDesc
>>
Vars
()
const
;
std
::
vector
<
std
::
shared_ptr
<
VarDesc
>>
Vars
()
const
;
std
::
vector
<
std
::
shared_ptr
<
OpDesc
>>
Ops
()
const
;
std
::
vector
<
std
::
shared_ptr
<
OpDesc
>>
Ops
()
const
;
private:
private:
proto
::
BlockDesc
desc_
;
proto
::
BlockDesc
desc_
;
std
::
vector
<
std
::
shared_ptr
<
OpDesc
>>
ops_
;
std
::
vector
<
std
::
shared_ptr
<
OpDesc
>>
ops_
;
std
::
unordered_map
<
std
::
string
,
std
::
shared_ptr
<
VarDesc
>>
vars_
;
std
::
unordered_map
<
std
::
string
,
std
::
shared_ptr
<
VarDesc
>>
vars_
;
};
};
}
// namespace framework
}
// namespace framework
...
@@ -58,13 +57,13 @@ class BlockDesc : PaddleMobileObject {
...
@@ -58,13 +57,13 @@ class BlockDesc : PaddleMobileObject {
namespace
std
{
namespace
std
{
template
<
>
struct
hash
<
paddle_mobile
::
framework
::
BlockDesc
>
{
template
<
>
struct
hash
<
paddle_mobile
::
framework
::
BlockDesc
>
{
typedef
paddle_mobile
::
framework
::
BlockDesc
argument_type
;
typedef
paddle_mobile
::
framework
::
BlockDesc
argument_type
;
typedef
std
::
size_t
result_type
;
typedef
std
::
size_t
result_type
;
result_type
operator
()(
argument_type
const
&
s
)
const
noexcept
{
result_type
operator
()(
argument_type
const
&
s
)
const
noexcept
{
result_type
const
h1
(
std
::
hash
<
int
>
{}(
s
.
ID
()));
result_type
const
h1
(
std
::
hash
<
int
>
{}(
s
.
ID
()));
result_type
const
h2
(
std
::
hash
<
int
>
{}(
s
.
ID
()));
result_type
const
h2
(
std
::
hash
<
int
>
{}(
s
.
ID
()));
return
h1
^
(
h2
<<
1
);
return
h1
^
(
h2
<<
1
);
}
}
};
};
}
// namespace std
}
// namespace std
src/framework/data_layout.h
浏览文件 @
1ad8f821
...
@@ -22,45 +22,45 @@ namespace paddle_mobile {
...
@@ -22,45 +22,45 @@ namespace paddle_mobile {
namespace
framework
{
namespace
framework
{
enum
class
DataLayout
{
enum
class
DataLayout
{
kNHWC
=
0
,
kNHWC
=
0
,
kNCHW
=
1
,
kNCHW
=
1
,
kAnyLayout
=
2
,
kAnyLayout
=
2
,
};
};
inline
DataLayout
StringToDataLayout
(
const
std
::
string
&
str
)
{
inline
DataLayout
StringToDataLayout
(
const
std
::
string
&
str
)
{
std
::
string
s
(
str
);
std
::
string
s
(
str
);
for
(
size_t
i
=
0
;
i
<
s
.
size
();
++
i
)
{
for
(
size_t
i
=
0
;
i
<
s
.
size
();
++
i
)
{
s
[
i
]
=
toupper
(
s
[
i
]);
s
[
i
]
=
toupper
(
s
[
i
]);
}
}
if
(
s
==
"NHWC"
)
{
if
(
s
==
"NHWC"
)
{
return
DataLayout
::
kNHWC
;
return
DataLayout
::
kNHWC
;
}
else
if
(
s
==
"NCHW"
)
{
}
else
if
(
s
==
"NCHW"
)
{
return
DataLayout
::
kNCHW
;
return
DataLayout
::
kNCHW
;
}
else
if
(
s
==
"ANYLAYOUT"
)
{
}
else
if
(
s
==
"ANYLAYOUT"
)
{
return
DataLayout
::
kAnyLayout
;
return
DataLayout
::
kAnyLayout
;
}
else
{
}
else
{
// std::cout << "Unknown storage order string: %s", s;
// std::cout << "Unknown storage order string: %s", s;
}
}
}
}
inline
std
::
string
DataLayoutToString
(
const
DataLayout
&
data_layout
)
{
inline
std
::
string
DataLayoutToString
(
const
DataLayout
&
data_layout
)
{
switch
(
data_layout
)
{
switch
(
data_layout
)
{
case
DataLayout
::
kNHWC
:
case
DataLayout
::
kNHWC
:
return
"NHWC"
;
return
"NHWC"
;
case
DataLayout
::
kNCHW
:
case
DataLayout
::
kNCHW
:
return
"NCHW"
;
return
"NCHW"
;
case
DataLayout
::
kAnyLayout
:
case
DataLayout
::
kAnyLayout
:
return
"ANY_LAYOUT"
;
return
"ANY_LAYOUT"
;
default:
default:
break
;
break
;
// std::cout << "unknown DataLayou %d", data_layout;
// std::cout << "unknown DataLayou %d", data_layout;
}
}
}
}
inline
std
::
ostream
&
operator
<<
(
std
::
ostream
&
out
,
const
DataLayout
&
l
)
{
inline
std
::
ostream
&
operator
<<
(
std
::
ostream
&
out
,
const
DataLayout
&
l
)
{
out
<<
DataLayoutToString
(
l
);
out
<<
DataLayoutToString
(
l
);
return
out
;
return
out
;
}
}
}
// namespace framework
}
// namespace framework
...
...
src/framework/data_transform.cpp
浏览文件 @
1ad8f821
...
@@ -24,68 +24,68 @@ namespace paddle_mobile {
...
@@ -24,68 +24,68 @@ namespace paddle_mobile {
namespace
framework
{
namespace
framework
{
static
void
PassTensorData
(
Tensor
*
from
,
Tensor
*
to
)
{
static
void
PassTensorData
(
Tensor
*
from
,
Tensor
*
to
)
{
to
->
ShareDataWith
(
*
from
);
to
->
ShareDataWith
(
*
from
);
*
from
=
Tensor
();
*
from
=
Tensor
();
}
}
void
DataTransform
(
const
OpKernelType
&
expected_kernel_type
,
void
DataTransform
(
const
OpKernelType
&
expected_kernel_type
,
const
OpKernelType
&
kernel_type_for_var
,
const
OpKernelType
&
kernel_type_for_var
,
const
Tensor
&
input_tensor
,
Tensor
*
output_tensor
)
{
const
Tensor
&
input_tensor
,
Tensor
*
output_tensor
)
{
bool
transformed
=
false
;
bool
transformed
=
false
;
Tensor
in
;
Tensor
in
;
in
.
ShareDataWith
(
input_tensor
);
in
.
ShareDataWith
(
input_tensor
);
Tensor
out
;
Tensor
out
;
// // do layout transform
// // do layout transform
// if (NeedTransformLayout(expected_kernel_type.data_layout_,
// if (NeedTransformLayout(expected_kernel_type.data_layout_,
// kernel_type_for_var.data_layout_)) {
// kernel_type_for_var.data_layout_)) {
// TransDataLayout(kernel_type_for_var, expected_kernel_type, in,
// TransDataLayout(kernel_type_for_var, expected_kernel_type, in,
// &out);
// &out);
// transformed = true;
// transformed = true;
// PassTensorData(&out, &in);
// PassTensorData(&out, &in);
// }
// }
//
//
// // do data type transform
// // do data type transform
// if (expected_kernel_type.data_type_ !=
// if (expected_kernel_type.data_type_ !=
// kernel_type_for_var.data_type_) {
// kernel_type_for_var.data_type_) {
// TransDataType(kernel_type_for_var, expected_kernel_type, in,
// TransDataType(kernel_type_for_var, expected_kernel_type, in,
// &out);
// &out);
// transformed = true;
// transformed = true;
// PassTensorData(&out, &in);
// PassTensorData(&out, &in);
// }
// }
//
//
// // do device transform
// // do device transform
// if (!platform::is_same_place(kernel_type_for_var.place_,
// if (!platform::is_same_place(kernel_type_for_var.place_,
// expected_kernel_type.place_)) {
// expected_kernel_type.place_)) {
// TransDataDevice(in, expected_kernel_type.place_, &out);
// TransDataDevice(in, expected_kernel_type.place_, &out);
// transformed = true;
// transformed = true;
// PassTensorData(&out, &in);
// PassTensorData(&out, &in);
// }
// }
//
//
// PADDLE_ENFORCE(transformed, "No transform is applied, please
// PADDLE_ENFORCE(transformed, "No transform is applied, please
// check!");
// check!");
// get output data
// get output data
output_tensor
->
ShareDataWith
(
in
);
output_tensor
->
ShareDataWith
(
in
);
}
}
void
CopyVariableWithTensor
(
const
Variable
&
in_var
,
const
Tensor
&
tensor
,
void
CopyVariableWithTensor
(
const
Variable
&
in_var
,
const
Tensor
&
tensor
,
Variable
&
out_var
)
{
Variable
&
out_var
)
{
// if (in_var.IsType<LoDTensor>()) {
// if (in_var.IsType<LoDTensor>()) {
// auto& in_lod_tensor = in_var.Get<LoDTensor>();
// auto& in_lod_tensor = in_var.Get<LoDTensor>();
// auto* tran_lod_tensor = out_var.GetMutable<LoDTensor>();
// auto* tran_lod_tensor = out_var.GetMutable<LoDTensor>();
// tran_lod_tensor->set_lod(in_lod_tensor.lod());
// tran_lod_tensor->set_lod(in_lod_tensor.lod());
// tran_lod_tensor->set_layout(in_lod_tensor.layout());
// tran_lod_tensor->set_layout(in_lod_tensor.layout());
// tran_lod_tensor->ShareDataWith(tensor);
// tran_lod_tensor->ShareDataWith(tensor);
// } else if (in_var.IsType<SelectedRows>()) {
// } else if (in_var.IsType<SelectedRows>()) {
// auto& in_selected_rows = in_var.Get<SelectedRows>();
// auto& in_selected_rows = in_var.Get<SelectedRows>();
// auto* trans_selected_rows =
// auto* trans_selected_rows =
// out_var.GetMutable<SelectedRows>();
// out_var.GetMutable<SelectedRows>();
// trans_selected_rows->set_height(in_selected_rows.height());
// trans_selected_rows->set_height(in_selected_rows.height());
// trans_selected_rows->set_rows(in_selected_rows.rows());
// trans_selected_rows->set_rows(in_selected_rows.rows());
// trans_selected_rows->mutable_value()->ShareDataWith(tensor);
// trans_selected_rows->mutable_value()->ShareDataWith(tensor);
// } else {
// } else {
// PADDLE_THROW("unknown var type");
// PADDLE_THROW("unknown var type");
// }
// }
}
}
}
// namespace framework
}
// namespace framework
...
...
src/framework/ddim.cc
浏览文件 @
1ad8f821
...
@@ -20,158 +20,156 @@ namespace framework {
...
@@ -20,158 +20,156 @@ namespace framework {
/// @cond HIDDEN
/// @cond HIDDEN
template
<
int
i
>
Dim
<
i
>
make_dim
(
const
int64_t
*
d
)
{
template
<
int
i
>
Dim
<
i
>
make_dim
(
const
int64_t
*
d
)
{
return
Dim
<
i
>
(
*
d
,
make_dim
<
i
-
1
>
(
d
+
1
));
return
Dim
<
i
>
(
*
d
,
make_dim
<
i
-
1
>
(
d
+
1
));
}
}
template
<
>
Dim
<
0
>
make_dim
<
0
>
(
const
int64_t
*
d
)
{
return
Dim
<
0
>
(
*
d
);
}
template
<
>
Dim
<
0
>
make_dim
<
0
>
(
const
int64_t
*
d
)
{
return
Dim
<
0
>
(
*
d
);
}
void
make_ddim
(
DDim
&
ddim
,
const
int64_t
*
dims
,
int
n
)
{
void
make_ddim
(
DDim
&
ddim
,
const
int64_t
*
dims
,
int
n
)
{
switch
(
n
)
{
switch
(
n
)
{
case
0
:
case
0
:
ddim
=
make_dim
<
0
>
(
dims
);
ddim
=
make_dim
<
0
>
(
dims
);
break
;
break
;
case
1
:
case
1
:
ddim
=
make_dim
<
1
>
(
dims
);
ddim
=
make_dim
<
1
>
(
dims
);
break
;
break
;
case
2
:
case
2
:
ddim
=
make_dim
<
2
>
(
dims
);
ddim
=
make_dim
<
2
>
(
dims
);
break
;
break
;
case
3
:
case
3
:
ddim
=
make_dim
<
3
>
(
dims
);
ddim
=
make_dim
<
3
>
(
dims
);
break
;
break
;
case
4
:
case
4
:
ddim
=
make_dim
<
4
>
(
dims
);
ddim
=
make_dim
<
4
>
(
dims
);
break
;
break
;
case
5
:
case
5
:
ddim
=
make_dim
<
5
>
(
dims
);
ddim
=
make_dim
<
5
>
(
dims
);
break
;
break
;
case
6
:
case
6
:
ddim
=
make_dim
<
6
>
(
dims
);
ddim
=
make_dim
<
6
>
(
dims
);
break
;
break
;
case
7
:
case
7
:
ddim
=
make_dim
<
7
>
(
dims
);
ddim
=
make_dim
<
7
>
(
dims
);
break
;
break
;
case
8
:
case
8
:
ddim
=
make_dim
<
8
>
(
dims
);
ddim
=
make_dim
<
8
>
(
dims
);
break
;
break
;
case
9
:
case
9
:
ddim
=
make_dim
<
9
>
(
dims
);
ddim
=
make_dim
<
9
>
(
dims
);
break
;
break
;
default:
default:
// std::cout << "Dynamic dimensions must have between [1,
// std::cout << "Dynamic dimensions must have between [1,
// 9]
// 9]
// dimensions.";
// dimensions.";
break
;
break
;
}
}
}
}
/// @endcond
/// @endcond
DDim
make_ddim
(
std
::
initializer_list
<
int64_t
>
dims
)
{
DDim
make_ddim
(
std
::
initializer_list
<
int64_t
>
dims
)
{
DDim
result
(
make_dim
(
0
));
DDim
result
(
make_dim
(
0
));
make_ddim
(
result
,
dims
.
begin
(),
dims
.
size
());
make_ddim
(
result
,
dims
.
begin
(),
dims
.
size
());
return
result
;
return
result
;
}
}
DDim
make_ddim
(
const
std
::
vector
<
int64_t
>
&
dims
)
{
DDim
make_ddim
(
const
std
::
vector
<
int64_t
>
&
dims
)
{
DDim
result
(
make_dim
(
0
));
DDim
result
(
make_dim
(
0
));
make_ddim
(
result
,
&
dims
[
0
],
dims
.
size
());
make_ddim
(
result
,
&
dims
[
0
],
dims
.
size
());
return
result
;
return
result
;
}
}
DDim
make_ddim
(
const
std
::
vector
<
int
>
&
dims
)
{
DDim
make_ddim
(
const
std
::
vector
<
int
>
&
dims
)
{
std
::
vector
<
int64_t
>
res
(
dims
.
size
());
std
::
vector
<
int64_t
>
res
(
dims
.
size
());
std
::
transform
(
dims
.
begin
(),
dims
.
end
(),
res
.
begin
(),
std
::
transform
(
dims
.
begin
(),
dims
.
end
(),
res
.
begin
(),
[](
int
d
)
{
return
static_cast
<
int64_t
>
(
d
);
});
[](
int
d
)
{
return
static_cast
<
int64_t
>
(
d
);
});
return
make_ddim
(
res
);
return
make_ddim
(
res
);
}
}
/// @cond HIDDEN
/// @cond HIDDEN
// XXX For some reason, putting this in an anonymous namespace causes
// XXX For some reason, putting this in an anonymous namespace causes
// errors
// errors
struct
DynamicMutableIndexer
:
Vistor
<
int64_t
&>
{
struct
DynamicMutableIndexer
:
Vistor
<
int64_t
&>
{
public:
public:
explicit
DynamicMutableIndexer
(
int
idx
)
:
idx_
(
idx
)
{}
explicit
DynamicMutableIndexer
(
int
idx
)
:
idx_
(
idx
)
{}
template
<
int
D
>
int64_t
&
operator
()(
Dim
<
D
>
&
dim
)
const
{
template
<
int
D
>
int64_t
&
operator
()(
Dim
<
D
>
&
dim
)
const
{
return
dim
[
idx_
];
}
return
dim
[
idx_
];
}
private:
private:
int
idx_
;
int
idx_
;
};
};
struct
DynamicConstIndexer
:
public
Vistor
<
int64_t
>
{
struct
DynamicConstIndexer
:
public
Vistor
<
int64_t
>
{
public:
public:
explicit
DynamicConstIndexer
(
int
idx
)
:
idx_
(
idx
)
{}
explicit
DynamicConstIndexer
(
int
idx
)
:
idx_
(
idx
)
{}
template
<
int
D
>
int64_t
operator
()(
const
Dim
<
D
>
&
dim
)
const
{
template
<
int
D
>
int64_t
operator
()(
const
Dim
<
D
>
&
dim
)
const
{
return
dim
[
idx_
];
return
dim
[
idx_
];
}
}
private:
private:
int
idx_
;
int
idx_
;
};
};
/// @endcond
/// @endcond
int64_t
&
DDim
::
operator
[](
int
idx
)
{
int64_t
&
DDim
::
operator
[](
int
idx
)
{
return
DDim
::
ApplyVistor
(
DynamicMutableIndexer
(
idx
),
*
this
);
return
DDim
::
ApplyVistor
(
DynamicMutableIndexer
(
idx
),
*
this
);
}
}
int64_t
DDim
::
operator
[](
int
idx
)
const
{
int64_t
DDim
::
operator
[](
int
idx
)
const
{
return
DDim
::
ApplyVistor
(
DynamicConstIndexer
(
idx
),
*
this
);
return
DDim
::
ApplyVistor
(
DynamicConstIndexer
(
idx
),
*
this
);
}
}
int
DDim
::
size
()
const
{
return
arity
(
*
this
);
}
int
DDim
::
size
()
const
{
return
arity
(
*
this
);
}
bool
DDim
::
operator
==
(
DDim
d
)
const
{
bool
DDim
::
operator
==
(
DDim
d
)
const
{
// if (var.which() != d.getVar().which()) {
// if (var.which() != d.getVar().which()) {
// return false;
// return false;
// } else {
// } else {
std
::
vector
<
int64_t
>
v1
=
vectorize
(
*
this
);
std
::
vector
<
int64_t
>
v1
=
vectorize
(
*
this
);
std
::
vector
<
int64_t
>
v2
=
vectorize
(
d
);
std
::
vector
<
int64_t
>
v2
=
vectorize
(
d
);
for
(
unsigned
int
i
=
0
;
i
<
v1
.
size
();
i
++
)
{
for
(
unsigned
int
i
=
0
;
i
<
v1
.
size
();
i
++
)
{
if
(
v1
[
i
]
!=
v2
[
i
])
{
if
(
v1
[
i
]
!=
v2
[
i
])
{
return
false
;
return
false
;
}
}
}
}
return
true
;
return
true
;
// }
// }
}
}
bool
DDim
::
operator
!=
(
DDim
d
)
const
{
return
!
(
*
this
==
d
);
}
bool
DDim
::
operator
!=
(
DDim
d
)
const
{
return
!
(
*
this
==
d
);
}
DDim
DDim
::
operator
+
(
DDim
d
)
const
{
DDim
DDim
::
operator
+
(
DDim
d
)
const
{
std
::
vector
<
int64_t
>
v1
=
vectorize
(
*
this
);
std
::
vector
<
int64_t
>
v1
=
vectorize
(
*
this
);
std
::
vector
<
int64_t
>
v2
=
vectorize
(
d
);
std
::
vector
<
int64_t
>
v2
=
vectorize
(
d
);
std
::
vector
<
int64_t
>
v3
;
std
::
vector
<
int64_t
>
v3
;
assert
(
v1
.
size
()
==
v2
.
size
());
assert
(
v1
.
size
()
==
v2
.
size
());
for
(
unsigned
int
i
=
0
;
i
<
v1
.
size
();
i
++
)
{
for
(
unsigned
int
i
=
0
;
i
<
v1
.
size
();
i
++
)
{
v3
.
push_back
(
v1
[
i
]
+
v2
[
i
]);
v3
.
push_back
(
v1
[
i
]
+
v2
[
i
]);
}
}
return
make_ddim
(
v3
);
return
make_ddim
(
v3
);
}
}
DDim
DDim
::
operator
*
(
DDim
d
)
const
{
DDim
DDim
::
operator
*
(
DDim
d
)
const
{
std
::
vector
<
int64_t
>
v1
=
vectorize
(
*
this
);
std
::
vector
<
int64_t
>
v1
=
vectorize
(
*
this
);
std
::
vector
<
int64_t
>
v2
=
vectorize
(
d
);
std
::
vector
<
int64_t
>
v2
=
vectorize
(
d
);
std
::
vector
<
int64_t
>
v3
;
std
::
vector
<
int64_t
>
v3
;
assert
(
v1
.
size
()
==
v2
.
size
());
assert
(
v1
.
size
()
==
v2
.
size
());
for
(
unsigned
int
i
=
0
;
i
<
v1
.
size
();
i
++
)
{
for
(
unsigned
int
i
=
0
;
i
<
v1
.
size
();
i
++
)
{
v3
.
push_back
(
v1
[
i
]
*
v2
[
i
]);
v3
.
push_back
(
v1
[
i
]
*
v2
[
i
]);
}
}
return
make_ddim
(
v3
);
return
make_ddim
(
v3
);
}
}
int64_t
get
(
const
DDim
&
ddim
,
int
idx
)
{
return
ddim
[
idx
];
}
int64_t
get
(
const
DDim
&
ddim
,
int
idx
)
{
return
ddim
[
idx
];
}
...
@@ -180,152 +178,152 @@ void set(DDim &ddim, int idx, int value) { ddim[idx] = value; }
...
@@ -180,152 +178,152 @@ void set(DDim &ddim, int idx, int value) { ddim[idx] = value; }
/// @cond HIDDEN
/// @cond HIDDEN
struct
VectorizeVisitor
:
Vistor
<
void
>
{
struct
VectorizeVisitor
:
Vistor
<
void
>
{
std
::
vector
<
int64_t
>
&
vector
;
std
::
vector
<
int64_t
>
&
vector
;
explicit
VectorizeVisitor
(
std
::
vector
<
int64_t
>
&
v
)
:
vector
(
v
)
{}
explicit
VectorizeVisitor
(
std
::
vector
<
int64_t
>
&
v
)
:
vector
(
v
)
{}
template
<
typename
T
>
void
operator
()(
const
T
&
t
)
{
template
<
typename
T
>
void
operator
()(
const
T
&
t
)
{
vector
.
push_back
(
t
.
head
);
vector
.
push_back
(
t
.
head
);
this
->
operator
()(
t
.
tail
);
this
->
operator
()(
t
.
tail
);
}
}
void
operator
()(
const
Dim
<
0
>
&
t
)
{}
void
operator
()(
const
Dim
<
0
>
&
t
)
{}
};
};
/// @endcond
/// @endcond
std
::
vector
<
int64_t
>
vectorize
(
const
DDim
&
ddim
)
{
std
::
vector
<
int64_t
>
vectorize
(
const
DDim
&
ddim
)
{
std
::
vector
<
int64_t
>
result
;
std
::
vector
<
int64_t
>
result
;
VectorizeVisitor
visitor
(
result
);
VectorizeVisitor
visitor
(
result
);
DDim
::
ApplyVistor
(
visitor
,
ddim
);
DDim
::
ApplyVistor
(
visitor
,
ddim
);
return
result
;
return
result
;
}
}
// NOTE: framework::vectorize converts to type int64_t
// NOTE: framework::vectorize converts to type int64_t
// which does not fit cudnn inputs.
// which does not fit cudnn inputs.
std
::
vector
<
int
>
vectorize2int
(
const
DDim
&
ddim
)
{
std
::
vector
<
int
>
vectorize2int
(
const
DDim
&
ddim
)
{
std
::
vector
<
int64_t
>
temp
=
vectorize
(
ddim
);
std
::
vector
<
int64_t
>
temp
=
vectorize
(
ddim
);
std
::
vector
<
int
>
result
(
temp
.
begin
(),
temp
.
end
());
std
::
vector
<
int
>
result
(
temp
.
begin
(),
temp
.
end
());
return
result
;
return
result
;
}
}
struct
ProductVisitor
:
Vistor
<
int64_t
>
{
struct
ProductVisitor
:
Vistor
<
int64_t
>
{
template
<
int
D
>
int64_t
operator
()(
const
Dim
<
D
>
&
dim
)
{
template
<
int
D
>
int64_t
operator
()(
const
Dim
<
D
>
&
dim
)
{
return
product
(
dim
);
return
product
(
dim
);
}
}
};
};
int64_t
product
(
const
DDim
&
ddim
)
{
int64_t
product
(
const
DDim
&
ddim
)
{
ProductVisitor
visitor
;
ProductVisitor
visitor
;
return
DDim
::
ApplyVistor
(
visitor
,
ddim
);
return
DDim
::
ApplyVistor
(
visitor
,
ddim
);
}
}
struct
SliceVectorizeVisitor
:
Vistor
<
void
>
{
struct
SliceVectorizeVisitor
:
Vistor
<
void
>
{
std
::
vector
<
int64_t
>
&
vector
;
std
::
vector
<
int64_t
>
&
vector
;
int
begin
;
int
begin
;
int
end
;
int
end
;
SliceVectorizeVisitor
(
std
::
vector
<
int64_t
>
&
v
,
int
b
,
int
e
)
SliceVectorizeVisitor
(
std
::
vector
<
int64_t
>
&
v
,
int
b
,
int
e
)
:
vector
(
v
),
begin
(
b
),
end
(
e
)
{
:
vector
(
v
),
begin
(
b
),
end
(
e
)
{
// PADDLE_ENFORCE(begin < end,
// PADDLE_ENFORCE(begin < end,
// "Begin index must be less than end index in
// "Begin index must be less than end index in
// ddim
// ddim
// slice.");
// slice.");
// PADDLE_ENFORCE(begin >= 0,
// PADDLE_ENFORCE(begin >= 0,
// "Begin index can't be less than zero in
// "Begin index can't be less than zero in
// ddim slice.");
// ddim slice.");
}
template
<
int
S
>
void
operator
()(
const
Dim
<
S
>
&
dim
)
{
if
(
begin
==
0
)
{
vector
.
push_back
(
dim
.
head
);
}
else
{
--
begin
;
}
}
--
end
;
template
<
int
S
>
void
operator
()(
const
Dim
<
S
>
&
dim
)
{
if
(
end
>
0
)
{
if
(
begin
==
0
)
{
this
->
operator
()(
dim
.
tail
);
vector
.
push_back
(
dim
.
head
);
}
else
{
--
begin
;
}
--
end
;
if
(
end
>
0
)
{
this
->
operator
()(
dim
.
tail
);
}
}
}
}
void
operator
()(
const
Dim
<
0
>
&
dim
)
{
void
operator
()(
const
Dim
<
0
>
&
dim
)
{
// PADDLE_ENFORCE(end == 0, "End index in ddim slice is out
// PADDLE_ENFORCE(end == 0, "End index in ddim slice is out
// of bound.");
// of bound.");
}
}
};
};
DDim
slice_ddim
(
const
DDim
&
ddim
,
int
begin
,
int
end
)
{
DDim
slice_ddim
(
const
DDim
&
ddim
,
int
begin
,
int
end
)
{
std
::
vector
<
int64_t
>
vec
;
std
::
vector
<
int64_t
>
vec
;
vec
.
reserve
(
end
-
begin
);
vec
.
reserve
(
end
-
begin
);
SliceVectorizeVisitor
visitor
(
vec
,
begin
,
end
);
SliceVectorizeVisitor
visitor
(
vec
,
begin
,
end
);
// boost::apply_visitor(visitor, dim);
// boost::apply_visitor(visitor, dim);
DDim
::
ApplyVistor
(
visitor
,
ddim
);
DDim
::
ApplyVistor
(
visitor
,
ddim
);
// visitor(ddim.var.Get<Dim<4>>());
// visitor(ddim.var.Get<Dim<4>>());
return
make_ddim
(
vec
);
return
make_ddim
(
vec
);
}
}
/// \cond HIDDEN
/// \cond HIDDEN
struct
ArityVisitor
:
Vistor
<
int
>
{
struct
ArityVisitor
:
Vistor
<
int
>
{
template
<
int
D
>
int
operator
()(
Dim
<
D
>
)
const
{
return
D
;
}
template
<
int
D
>
int
operator
()(
Dim
<
D
>
)
const
{
return
D
;
}
};
};
/// \endcond
/// \endcond
int
arity
(
const
DDim
&
d
)
{
int
arity
(
const
DDim
&
d
)
{
ArityVisitor
arityVisitor
=
ArityVisitor
();
ArityVisitor
arityVisitor
=
ArityVisitor
();
return
DDim
::
ApplyVistor
(
arityVisitor
,
d
);
return
DDim
::
ApplyVistor
(
arityVisitor
,
d
);
// return arityVisitor(d.var.Get<Dim<4>>());
// return arityVisitor(d.var.Get<Dim<4>>());
// return boost::apply_visitor(ArityVisitor(), d); }
// return boost::apply_visitor(ArityVisitor(), d); }
}
}
/// \cond HIDDEN
/// \cond HIDDEN
/// \endcond
/// \endcond
struct
OSVistor
:
Vistor
<
std
::
ostream
&>
{
struct
OSVistor
:
Vistor
<
std
::
ostream
&>
{
OSVistor
(
std
::
ostream
&
os
)
:
os_
(
os
)
{}
OSVistor
(
std
::
ostream
&
os
)
:
os_
(
os
)
{}
template
<
int
D
>
std
::
ostream
&
operator
()(
Dim
<
D
>
dim
)
const
{
template
<
int
D
>
std
::
ostream
&
operator
()(
Dim
<
D
>
dim
)
const
{
return
os_
<<
dim
;
return
os_
<<
dim
;
}
}
private:
private:
std
::
ostream
&
os_
;
std
::
ostream
&
os_
;
};
};
std
::
ostream
&
operator
<<
(
std
::
ostream
&
os
,
const
DDim
&
ddim
)
{
std
::
ostream
&
operator
<<
(
std
::
ostream
&
os
,
const
DDim
&
ddim
)
{
auto
vistor
=
OSVistor
(
os
);
auto
vistor
=
OSVistor
(
os
);
DDim
::
ApplyVistor
(
vistor
,
ddim
);
DDim
::
ApplyVistor
(
vistor
,
ddim
);
return
os
;
return
os
;
}
}
DDim
::
DDim
(
std
::
initializer_list
<
int64_t
>
init_list
)
{
DDim
::
DDim
(
std
::
initializer_list
<
int64_t
>
init_list
)
{
*
this
=
make_ddim
(
init_list
);
*
this
=
make_ddim
(
init_list
);
}
}
DDim
flatten_to_2d
(
const
DDim
&
src
,
int
num_col_dims
)
{
DDim
flatten_to_2d
(
const
DDim
&
src
,
int
num_col_dims
)
{
int
rank
=
src
.
size
();
int
rank
=
src
.
size
();
return
make_ddim
({
product
(
slice_ddim
(
src
,
0
,
num_col_dims
)),
return
make_ddim
({
product
(
slice_ddim
(
src
,
0
,
num_col_dims
)),
product
(
slice_ddim
(
src
,
num_col_dims
,
rank
))});
product
(
slice_ddim
(
src
,
num_col_dims
,
rank
))});
}
}
DDim
flatten_to_1d
(
const
DDim
&
src
)
{
return
make_ddim
({
product
(
src
)});
}
DDim
flatten_to_1d
(
const
DDim
&
src
)
{
return
make_ddim
({
product
(
src
)});
}
DDim
stride
(
const
DDim
&
ddim
)
{
DDim
stride
(
const
DDim
&
ddim
)
{
std
::
vector
<
int64_t
>
strides
(
ddim
.
size
());
std
::
vector
<
int64_t
>
strides
(
ddim
.
size
());
strides
[
ddim
.
size
()
-
1
]
=
1
;
strides
[
ddim
.
size
()
-
1
]
=
1
;
for
(
int
i
=
ddim
.
size
()
-
2
;
i
>=
0
;
--
i
)
{
for
(
int
i
=
ddim
.
size
()
-
2
;
i
>=
0
;
--
i
)
{
strides
[
i
]
=
strides
[
i
+
1
]
*
ddim
[
i
+
1
];
strides
[
i
]
=
strides
[
i
+
1
]
*
ddim
[
i
+
1
];
}
}
return
framework
::
make_ddim
(
strides
);
return
framework
::
make_ddim
(
strides
);
}
}
DDim
stride_numel
(
const
framework
::
DDim
&
ddim
)
{
DDim
stride_numel
(
const
framework
::
DDim
&
ddim
)
{
std
::
vector
<
int64_t
>
strides
(
ddim
.
size
());
std
::
vector
<
int64_t
>
strides
(
ddim
.
size
());
strides
[
ddim
.
size
()
-
1
]
=
ddim
[
ddim
.
size
()
-
1
];
strides
[
ddim
.
size
()
-
1
]
=
ddim
[
ddim
.
size
()
-
1
];
for
(
int
i
=
ddim
.
size
()
-
2
;
i
>=
0
;
--
i
)
{
for
(
int
i
=
ddim
.
size
()
-
2
;
i
>=
0
;
--
i
)
{
strides
[
i
]
=
strides
[
i
+
1
]
*
ddim
[
i
];
strides
[
i
]
=
strides
[
i
+
1
]
*
ddim
[
i
];
}
}
return
framework
::
make_ddim
(
strides
);
return
framework
::
make_ddim
(
strides
);
}
}
}
// namespace framework
}
// namespace framework
...
...
src/framework/ddim.h
浏览文件 @
1ad8f821
...
@@ -30,77 +30,77 @@ namespace framework {
...
@@ -30,77 +30,77 @@ namespace framework {
* The number of dimensions must be between [1, 9].
* The number of dimensions must be between [1, 9].
*/
*/
struct
DDim
{
struct
DDim
{
typedef
Variant
<
Dim
<
0
>
,
Dim
<
1
>
,
Dim
<
2
>
,
Dim
<
3
>
,
Dim
<
4
>
,
Dim
<
5
>
,
Dim
<
6
>
,
typedef
Variant
<
Dim
<
0
>
,
Dim
<
1
>
,
Dim
<
2
>
,
Dim
<
3
>
,
Dim
<
4
>
,
Dim
<
5
>
,
Dim
<
6
>
,
Dim
<
7
>
,
Dim
<
8
>
,
Dim
<
9
>>
Dim
<
7
>
,
Dim
<
8
>
,
Dim
<
9
>>
DDimVar
;
DDimVar
;
DDimVar
var
;
DDimVar
var
;
template
<
typename
Vistor
>
template
<
typename
Vistor
>
static
typename
Vistor
::
type_t
ApplyVistor
(
Vistor
vistor
,
const
DDim
&
d
)
{
static
typename
Vistor
::
type_t
ApplyVistor
(
Vistor
vistor
,
const
DDim
&
d
)
{
if
(
d
.
var
.
TypeId
()
==
typeid
(
Dim
<
0
>
).
hash_code
())
{
if
(
d
.
var
.
TypeId
()
==
typeid
(
Dim
<
0
>
).
hash_code
())
{
return
vistor
(
d
.
var
.
Get
<
Dim
<
0
>>
());
return
vistor
(
d
.
var
.
Get
<
Dim
<
0
>>
());
}
else
if
(
d
.
var
.
TypeId
()
==
typeid
(
Dim
<
1
>
).
hash_code
())
{
}
else
if
(
d
.
var
.
TypeId
()
==
typeid
(
Dim
<
1
>
).
hash_code
())
{
return
vistor
(
d
.
var
.
Get
<
Dim
<
1
>>
());
return
vistor
(
d
.
var
.
Get
<
Dim
<
1
>>
());
}
else
if
(
d
.
var
.
TypeId
()
==
typeid
(
Dim
<
2
>
).
hash_code
())
{
}
else
if
(
d
.
var
.
TypeId
()
==
typeid
(
Dim
<
2
>
).
hash_code
())
{
return
vistor
(
d
.
var
.
Get
<
Dim
<
2
>>
());
return
vistor
(
d
.
var
.
Get
<
Dim
<
2
>>
());
}
else
if
(
d
.
var
.
TypeId
()
==
typeid
(
Dim
<
3
>
).
hash_code
())
{
}
else
if
(
d
.
var
.
TypeId
()
==
typeid
(
Dim
<
3
>
).
hash_code
())
{
return
vistor
(
d
.
var
.
Get
<
Dim
<
3
>>
());
return
vistor
(
d
.
var
.
Get
<
Dim
<
3
>>
());
}
else
if
(
d
.
var
.
TypeId
()
==
typeid
(
Dim
<
4
>
).
hash_code
())
{
}
else
if
(
d
.
var
.
TypeId
()
==
typeid
(
Dim
<
4
>
).
hash_code
())
{
return
vistor
(
d
.
var
.
Get
<
Dim
<
4
>>
());
return
vistor
(
d
.
var
.
Get
<
Dim
<
4
>>
());
}
else
if
(
d
.
var
.
TypeId
()
==
typeid
(
Dim
<
5
>
).
hash_code
())
{
}
else
if
(
d
.
var
.
TypeId
()
==
typeid
(
Dim
<
5
>
).
hash_code
())
{
return
vistor
(
d
.
var
.
Get
<
Dim
<
5
>>
());
return
vistor
(
d
.
var
.
Get
<
Dim
<
5
>>
());
}
else
if
(
d
.
var
.
TypeId
()
==
typeid
(
Dim
<
6
>
).
hash_code
())
{
}
else
if
(
d
.
var
.
TypeId
()
==
typeid
(
Dim
<
6
>
).
hash_code
())
{
return
vistor
(
d
.
var
.
Get
<
Dim
<
6
>>
());
return
vistor
(
d
.
var
.
Get
<
Dim
<
6
>>
());
}
else
if
(
d
.
var
.
TypeId
()
==
typeid
(
Dim
<
7
>
).
hash_code
())
{
}
else
if
(
d
.
var
.
TypeId
()
==
typeid
(
Dim
<
7
>
).
hash_code
())
{
return
vistor
(
d
.
var
.
Get
<
Dim
<
7
>>
());
return
vistor
(
d
.
var
.
Get
<
Dim
<
7
>>
());
}
else
if
(
d
.
var
.
TypeId
()
==
typeid
(
Dim
<
8
>
).
hash_code
())
{
}
else
if
(
d
.
var
.
TypeId
()
==
typeid
(
Dim
<
8
>
).
hash_code
())
{
return
vistor
(
d
.
var
.
Get
<
Dim
<
8
>>
());
return
vistor
(
d
.
var
.
Get
<
Dim
<
8
>>
());
}
else
if
(
d
.
var
.
TypeId
()
==
typeid
(
Dim
<
9
>
).
hash_code
())
{
}
else
if
(
d
.
var
.
TypeId
()
==
typeid
(
Dim
<
9
>
).
hash_code
())
{
return
vistor
(
d
.
var
.
Get
<
Dim
<
9
>>
());
return
vistor
(
d
.
var
.
Get
<
Dim
<
9
>>
());
}
else
{
}
else
{
printf
(
" dim not support
\n
"
);
printf
(
" dim not support
\n
"
);
throw
std
::
bad_exception
();
throw
std
::
bad_exception
();
// return typename Vistor::type_t();
// return typename Vistor::type_t();
}
}
}
}
DDim
()
{
var
.
Set
<
Dim
<
1
>>
(
Dim
<
1
>
());
}
DDim
()
{
var
.
Set
<
Dim
<
1
>>
(
Dim
<
1
>
());
}
template
<
int
D
>
explicit
DDim
(
const
Dim
<
D
>
&
in
)
{
var
.
Set
<
Dim
<
D
>>
(
in
);
}
template
<
int
D
>
explicit
DDim
(
const
Dim
<
D
>
&
in
)
{
var
.
Set
<
Dim
<
D
>>
(
in
);
}
/*implicit*/
DDim
(
std
::
initializer_list
<
int64_t
>
init_list
);
/*implicit*/
DDim
(
std
::
initializer_list
<
int64_t
>
init_list
);
template
<
int
D
>
DDim
&
operator
=
(
const
Dim
<
D
>
&
in
)
{
template
<
int
D
>
DDim
&
operator
=
(
const
Dim
<
D
>
&
in
)
{
var
.
Set
<
Dim
<
D
>>
(
in
);
var
.
Set
<
Dim
<
D
>>
(
in
);
return
*
this
;
return
*
this
;
}
}
int64_t
&
operator
[](
int
idx
);
int64_t
&
operator
[](
int
idx
);
int64_t
operator
[](
int
idx
)
const
;
int64_t
operator
[](
int
idx
)
const
;
// template <typename Visitor>
// template <typename Visitor>
// typename Visitor::result_type apply_visitor(Visitor& visitor) {
// typename Visitor::result_type apply_visitor(Visitor& visitor) {
// return var.apply_visitor(visitor);
// return var.apply_visitor(visitor);
// }
// }
//
//
// template <typename Visitor>
// template <typename Visitor>
// typename Visitor::result_type apply_visitor(Visitor& visitor)
// typename Visitor::result_type apply_visitor(Visitor& visitor)
// const {
// const {
// return var.apply_visitor(visitor);
// return var.apply_visitor(visitor);
// }
// }
DDimVar
getVar
()
{
return
var
;
}
DDimVar
getVar
()
{
return
var
;
}
bool
operator
==
(
DDim
d
)
const
;
bool
operator
==
(
DDim
d
)
const
;
bool
operator
!=
(
DDim
d
)
const
;
bool
operator
!=
(
DDim
d
)
const
;
DDim
operator
+
(
DDim
d
)
const
;
DDim
operator
+
(
DDim
d
)
const
;
DDim
operator
*
(
DDim
d
)
const
;
DDim
operator
*
(
DDim
d
)
const
;
int
size
()
const
;
int
size
()
const
;
};
};
/**
/**
...
...
src/framework/dim.h
浏览文件 @
1ad8f821
...
@@ -25,199 +25,197 @@ namespace framework {
...
@@ -25,199 +25,197 @@ namespace framework {
// Statically sized, statically indexed dimension
// Statically sized, statically indexed dimension
template
<
int
i
>
struct
Dim
{
template
<
int
i
>
struct
Dim
{
static
constexpr
int
dimensions
=
i
;
static
constexpr
int
dimensions
=
i
;
template
<
typename
...
Args
>
template
<
typename
...
Args
>
HOSTDEVICE
Dim
(
int64_t
_head
,
Args
...
_tail
)
:
head
(
_head
),
tail
(
_tail
...)
{
HOSTDEVICE
Dim
(
int64_t
_head
,
Args
...
_tail
)
:
head
(
_head
),
tail
(
_tail
...)
{
static_assert
(
sizeof
...(
_tail
)
==
i
-
1
,
static_assert
(
sizeof
...(
_tail
)
==
i
-
1
,
"Dim initialized with the wrong number of parameters"
);
"Dim initialized with the wrong number of parameters"
);
}
}
HOSTDEVICE
HOSTDEVICE
Dim
(
int64_t
_head
,
const
Dim
<
i
-
1
>
&
_tail
)
:
head
(
_head
),
tail
(
_tail
)
{}
Dim
(
int64_t
_head
,
const
Dim
<
i
-
1
>
&
_tail
)
:
head
(
_head
),
tail
(
_tail
)
{}
HOSTDEVICE
HOSTDEVICE
Dim
()
:
head
(
0
),
tail
()
{}
Dim
()
:
head
(
0
),
tail
()
{}
/** Construct a Dim from a linear index and size. Uses Fortran
/** Construct a Dim from a linear index and size. Uses Fortran
* order
* order
* indexing. */
* indexing. */
HOSTDEVICE
HOSTDEVICE
Dim
(
int64_t
idx
,
const
Dim
<
i
>
&
size
)
Dim
(
int64_t
idx
,
const
Dim
<
i
>
&
size
)
:
head
(
idx
%
size
.
head
),
tail
(
idx
/
size
.
head
,
size
.
tail
)
{}
:
head
(
idx
%
size
.
head
),
tail
(
idx
/
size
.
head
,
size
.
tail
)
{}
/** Construct a Dim with each dimension set to the given index */
/** Construct a Dim with each dimension set to the given index */
HOSTDEVICE
HOSTDEVICE
Dim
(
int64_t
idx
)
:
head
(
idx
),
tail
(
idx
)
{}
Dim
(
int64_t
idx
)
:
head
(
idx
),
tail
(
idx
)
{}
HOSTDEVICE
HOSTDEVICE
bool
operator
==
(
const
Dim
<
i
>
&
o
)
const
{
bool
operator
==
(
const
Dim
<
i
>
&
o
)
const
{
return
(
head
==
o
.
head
)
&&
(
tail
==
o
.
tail
);
return
(
head
==
o
.
head
)
&&
(
tail
==
o
.
tail
);
}
}
HOSTDEVICE
HOSTDEVICE
bool
operator
!=
(
const
Dim
<
i
>
&
o
)
const
{
return
!
(
*
this
==
o
);
}
bool
operator
!=
(
const
Dim
<
i
>
&
o
)
const
{
return
!
(
*
this
==
o
);
}
HOSTDEVICE
HOSTDEVICE
int64_t
&
operator
[](
int
idx
);
int64_t
&
operator
[](
int
idx
);
HOSTDEVICE
HOSTDEVICE
int64_t
operator
[](
int
idx
)
const
;
int64_t
operator
[](
int
idx
)
const
;
HOST
std
::
string
to_string
()
const
;
HOST
std
::
string
to_string
()
const
;
int64_t
head
;
int64_t
head
;
Dim
<
i
-
1
>
tail
;
Dim
<
i
-
1
>
tail
;
};
};
// Base case specialization
// Base case specialization
template
<
>
struct
Dim
<
0
>
{
template
<
>
struct
Dim
<
0
>
{
static
constexpr
int
dimensions
=
0
;
static
constexpr
int
dimensions
=
0
;
HOSTDEVICE
HOSTDEVICE
Dim
(
int64_t
_head
)
{}
Dim
(
int64_t
_head
)
{}
HOSTDEVICE
HOSTDEVICE
Dim
()
{}
Dim
()
{}
HOSTDEVICE
HOSTDEVICE
Dim
(
int
idx
,
const
Dim
<
0
>
&
size
)
{
Dim
(
int
idx
,
const
Dim
<
0
>
&
size
)
{
#ifndef __CUDA_ARCH__
#ifndef __CUDA_ARCH__
if
(
idx
>
0
)
{
if
(
idx
>
0
)
{
throw
std
::
invalid_argument
(
"Index out of range."
);
throw
std
::
invalid_argument
(
"Index out of range."
);
}
}
#else
#else
PADDLE_ASSERT
(
idx
==
0
);
PADDLE_ASSERT
(
idx
==
0
);
#endif
#endif
}
}
HOSTDEVICE
HOSTDEVICE
bool
operator
==
(
const
Dim
<
0
>
&
o
)
const
{
return
true
;
}
bool
operator
==
(
const
Dim
<
0
>
&
o
)
const
{
return
true
;
}
HOSTDEVICE
HOSTDEVICE
bool
operator
!=
(
const
Dim
<
0
>
&
o
)
const
{
return
false
;
}
bool
operator
!=
(
const
Dim
<
0
>
&
o
)
const
{
return
false
;
}
HOSTDEVICE
HOSTDEVICE
int64_t
&
operator
[](
int
idx
);
int64_t
&
operator
[](
int
idx
);
HOSTDEVICE
HOSTDEVICE
int64_t
operator
[](
int
idx
)
const
;
int64_t
operator
[](
int
idx
)
const
;
};
};
namespace
{
namespace
{
// Helper for accessing Dim classes
// Helper for accessing Dim classes
template
<
int
i
>
struct
DimGetter
{
template
<
int
i
>
struct
DimGetter
{
// Return a copy if Dim is const
// Return a copy if Dim is const
template
<
typename
D
>
HOSTDEVICE
static
int64_t
impl
(
const
D
&
d
)
{
template
<
typename
D
>
HOSTDEVICE
static
int64_t
impl
(
const
D
&
d
)
{
return
DimGetter
<
i
-
1
>::
impl
(
d
.
tail
);
return
DimGetter
<
i
-
1
>::
impl
(
d
.
tail
);
}
}
// Return a reference if Dim is mutable
// Return a reference if Dim is mutable
template
<
typename
D
>
HOSTDEVICE
static
int64_t
&
impl
(
D
&
d
)
{
template
<
typename
D
>
HOSTDEVICE
static
int64_t
&
impl
(
D
&
d
)
{
return
DimGetter
<
i
-
1
>::
impl
(
d
.
tail
);
return
DimGetter
<
i
-
1
>::
impl
(
d
.
tail
);
}
}
};
};
// Eureka! We found the element!
// Eureka! We found the element!
template
<
>
struct
DimGetter
<
0
>
{
template
<
>
struct
DimGetter
<
0
>
{
// Return a copy if Dim is const
// Return a copy if Dim is const
template
<
typename
D
>
HOSTDEVICE
static
int64_t
impl
(
const
D
&
d
)
{
template
<
typename
D
>
HOSTDEVICE
static
int64_t
impl
(
const
D
&
d
)
{
return
d
.
head
;
return
d
.
head
;
}
}
// Return a reference if Dim is mutable
// Return a reference if Dim is mutable
template
<
typename
D
>
HOSTDEVICE
static
int64_t
&
impl
(
D
&
d
)
{
template
<
typename
D
>
HOSTDEVICE
static
int64_t
&
impl
(
D
&
d
)
{
return
d
.
head
;
}
return
d
.
head
;
}
};
};
template
<
int
D
>
HOSTDEVICE
int64_t
&
indexer
(
Dim
<
D
>
&
dim
,
int
idx
)
{
template
<
int
D
>
HOSTDEVICE
int64_t
&
indexer
(
Dim
<
D
>
&
dim
,
int
idx
)
{
#ifndef __CUDA_ARCH__
#ifndef __CUDA_ARCH__
if
(
idx
<
0
)
{
if
(
idx
<
0
)
{
throw
std
::
invalid_argument
(
"Tried to access a negative dimension"
);
throw
std
::
invalid_argument
(
"Tried to access a negative dimension"
);
}
}
#else
#else
PADDLE_ASSERT
(
idx
>=
0
);
PADDLE_ASSERT
(
idx
>=
0
);
#endif
#endif
if
(
idx
==
0
)
{
if
(
idx
==
0
)
{
return
dim
.
head
;
return
dim
.
head
;
}
}
return
indexer
(
dim
.
tail
,
idx
-
1
);
return
indexer
(
dim
.
tail
,
idx
-
1
);
}
}
template
<
>
HOSTDEVICE
int64_t
&
indexer
<
0
>
(
Dim
<
0
>
&
dim
,
int
idx
)
{
template
<
>
HOSTDEVICE
int64_t
&
indexer
<
0
>
(
Dim
<
0
>
&
dim
,
int
idx
)
{
#ifndef __CUDA_ARCH__
#ifndef __CUDA_ARCH__
throw
std
::
invalid_argument
(
"Invalid index"
);
throw
std
::
invalid_argument
(
"Invalid index"
);
#else
#else
PADDLE_ASSERT
(
false
);
PADDLE_ASSERT
(
false
);
#if CUDA_VERSION < 8000
#if CUDA_VERSION < 8000
// On CUDA versions previous to 8.0, only __shared__ variables
// On CUDA versions previous to 8.0, only __shared__ variables
// could be declared as static in the device code.
// could be declared as static in the device code.
int64_t
head
=
0
;
int64_t
head
=
0
;
#else
#else
static
int64_t
head
=
0
;
static
int64_t
head
=
0
;
#endif
#endif
return
head
;
return
head
;
#endif
#endif
}
}
template
<
int
D
>
HOSTDEVICE
int64_t
indexer
(
const
Dim
<
D
>
&
dim
,
int
idx
)
{
template
<
int
D
>
HOSTDEVICE
int64_t
indexer
(
const
Dim
<
D
>
&
dim
,
int
idx
)
{
#ifndef __CUDA_ARCH__
#ifndef __CUDA_ARCH__
if
(
idx
<
0
)
{
if
(
idx
<
0
)
{
throw
std
::
invalid_argument
(
"Tried to access a negative dimension"
);
throw
std
::
invalid_argument
(
"Tried to access a negative dimension"
);
}
}
#else
#else
PADDLE_ASSERT
(
idx
>=
0
);
PADDLE_ASSERT
(
idx
>=
0
);
#endif
#endif
if
(
idx
==
0
)
{
if
(
idx
==
0
)
{
return
dim
.
head
;
return
dim
.
head
;
}
}
return
indexer
(
dim
.
tail
,
idx
-
1
);
return
indexer
(
dim
.
tail
,
idx
-
1
);
}
}
template
<
>
HOSTDEVICE
int64_t
indexer
<
0
>
(
const
Dim
<
0
>
&
dim
,
int
idx
)
{
template
<
>
HOSTDEVICE
int64_t
indexer
<
0
>
(
const
Dim
<
0
>
&
dim
,
int
idx
)
{
#ifndef __CUDA_ARCH__
#ifndef __CUDA_ARCH__
throw
std
::
invalid_argument
(
"Invalid index"
);
throw
std
::
invalid_argument
(
"Invalid index"
);
#else
#else
PADDLE_ASSERT
(
false
);
PADDLE_ASSERT
(
false
);
#if CUDA_VERSION < 8000
#if CUDA_VERSION < 8000
// On CUDA versions previous to 8.0, only __shared__ variables
// On CUDA versions previous to 8.0, only __shared__ variables
// could be declared as static in the device code.
// could be declared as static in the device code.
int64_t
head
=
0
;
int64_t
head
=
0
;
#else
#else
static
int64_t
head
=
0
;
static
int64_t
head
=
0
;
#endif
#endif
return
head
;
return
head
;
#endif
#endif
}
}
}
// namespace
}
// namespace
// Static access to constant Dim
// Static access to constant Dim
template
<
int
i
,
int
l
>
HOSTDEVICE
int64_t
get
(
const
Dim
<
l
>
&
d
)
{
template
<
int
i
,
int
l
>
HOSTDEVICE
int64_t
get
(
const
Dim
<
l
>
&
d
)
{
return
DimGetter
<
i
>::
impl
(
d
);
return
DimGetter
<
i
>::
impl
(
d
);
}
}
// Static access to mutable Dim
// Static access to mutable Dim
template
<
int
i
,
int
l
>
HOSTDEVICE
int64_t
&
get
(
Dim
<
l
>
&
d
)
{
template
<
int
i
,
int
l
>
HOSTDEVICE
int64_t
&
get
(
Dim
<
l
>
&
d
)
{
return
DimGetter
<
i
>::
impl
(
d
);
return
DimGetter
<
i
>::
impl
(
d
);
}
}
// Dynamic access to constant Dim
// Dynamic access to constant Dim
template
<
int
l
>
HOSTDEVICE
int64_t
Dim
<
l
>::
operator
[](
int
i
)
const
{
template
<
int
l
>
HOSTDEVICE
int64_t
Dim
<
l
>::
operator
[](
int
i
)
const
{
// std::cout << "l: " << l << std::endl;
// std::cout << "l: " << l << std::endl;
return
indexer
(
*
this
,
i
);
return
indexer
(
*
this
,
i
);
}
}
// Dynamic access to mutable Dim
// Dynamic access to mutable Dim
template
<
int
l
>
HOSTDEVICE
int64_t
&
Dim
<
l
>::
operator
[](
int
i
)
{
template
<
int
l
>
HOSTDEVICE
int64_t
&
Dim
<
l
>::
operator
[](
int
i
)
{
return
indexer
(
*
this
,
i
);
return
indexer
(
*
this
,
i
);
}
}
// Dynamic access to constant Dim
// Dynamic access to constant Dim
inline
HOSTDEVICE
int64_t
Dim
<
0
>::
operator
[](
int
i
)
const
{
inline
HOSTDEVICE
int64_t
Dim
<
0
>::
operator
[](
int
i
)
const
{
return
indexer
(
*
this
,
i
);
return
indexer
(
*
this
,
i
);
}
}
// Dynamic access to mutable Dim
// Dynamic access to mutable Dim
inline
HOSTDEVICE
int64_t
&
Dim
<
0
>::
operator
[](
int
i
)
{
inline
HOSTDEVICE
int64_t
&
Dim
<
0
>::
operator
[](
int
i
)
{
return
indexer
(
*
this
,
i
);
return
indexer
(
*
this
,
i
);
}
}
// Dynamic access to constant Dim
// Dynamic access to constant Dim
...
@@ -225,52 +223,52 @@ inline HOSTDEVICE int64_t &Dim<0>::operator[](int i) {
...
@@ -225,52 +223,52 @@ inline HOSTDEVICE int64_t &Dim<0>::operator[](int i) {
template
<
int
l
>
template
<
int
l
>
HOSTDEVICE
typename
std
::
enable_if
<
(
l
>
0
),
int64_t
>::
type
get
(
const
Dim
<
l
>
&
d
,
HOSTDEVICE
typename
std
::
enable_if
<
(
l
>
0
),
int64_t
>::
type
get
(
const
Dim
<
l
>
&
d
,
int
i
)
{
int
i
)
{
return
d
[
i
];
return
d
[
i
];
}
}
// Dynamic access to mutable Dim
// Dynamic access to mutable Dim
template
<
int
l
>
template
<
int
l
>
HOSTDEVICE
typename
std
::
enable_if
<
(
l
>
0
),
int64_t
&>::
type
get
(
Dim
<
l
>
&
d
,
HOSTDEVICE
typename
std
::
enable_if
<
(
l
>
0
),
int64_t
&>::
type
get
(
Dim
<
l
>
&
d
,
int
i
)
{
int
i
)
{
return
d
[
i
];
return
d
[
i
];
}
}
// Dot product of two dims
// Dot product of two dims
template
<
int
i
>
template
<
int
i
>
HOSTDEVICE
int64_t
linearize
(
const
Dim
<
i
>
&
a
,
const
Dim
<
i
>
&
b
)
{
HOSTDEVICE
int64_t
linearize
(
const
Dim
<
i
>
&
a
,
const
Dim
<
i
>
&
b
)
{
return
a
.
head
*
b
.
head
+
linearize
(
a
.
tail
,
b
.
tail
);
return
a
.
head
*
b
.
head
+
linearize
(
a
.
tail
,
b
.
tail
);
}
}
// Base case dot product of two Dims
// Base case dot product of two Dims
// Notice it is inline because it is no longer a template
// Notice it is inline because it is no longer a template
template
<
>
template
<
>
HOSTDEVICE
inline
int64_t
linearize
(
const
Dim
<
0
>
&
a
,
const
Dim
<
0
>
&
b
)
{
HOSTDEVICE
inline
int64_t
linearize
(
const
Dim
<
0
>
&
a
,
const
Dim
<
0
>
&
b
)
{
return
0
;
return
0
;
}
}
// Product of a Dim
// Product of a Dim
template
<
int
i
>
HOSTDEVICE
int64_t
product
(
const
Dim
<
i
>
&
a
,
int
prod
=
1
)
{
template
<
int
i
>
HOSTDEVICE
int64_t
product
(
const
Dim
<
i
>
&
a
,
int
prod
=
1
)
{
return
prod
*
a
.
head
*
product
(
a
.
tail
);
return
prod
*
a
.
head
*
product
(
a
.
tail
);
}
}
// Base case product of a Dim
// Base case product of a Dim
// Notice it is inline because it is no longer a template
// Notice it is inline because it is no longer a template
template
<
>
HOSTDEVICE
inline
int64_t
product
(
const
Dim
<
0
>
&
a
,
int
prod
)
{
template
<
>
HOSTDEVICE
inline
int64_t
product
(
const
Dim
<
0
>
&
a
,
int
prod
)
{
return
prod
;
return
prod
;
}
}
// Is 0 <= idx_i < size_i for all i?
// Is 0 <= idx_i < size_i for all i?
template
<
int
i
>
template
<
int
i
>
HOSTDEVICE
bool
contained
(
const
Dim
<
i
>
&
idx
,
const
Dim
<
i
>
&
size
)
{
HOSTDEVICE
bool
contained
(
const
Dim
<
i
>
&
idx
,
const
Dim
<
i
>
&
size
)
{
return
((
0
<=
idx
.
head
)
&&
(
idx
.
head
<
size
.
head
)
&&
return
((
0
<=
idx
.
head
)
&&
(
idx
.
head
<
size
.
head
)
&&
contained
(
idx
.
tail
,
size
.
tail
));
contained
(
idx
.
tail
,
size
.
tail
));
}
}
// Base case of is 0 <= idx_i < size_i ?
// Base case of is 0 <= idx_i < size_i ?
// Notice it is inline because it is no longer a template
// Notice it is inline because it is no longer a template
template
<
>
template
<
>
HOSTDEVICE
inline
bool
contained
(
const
Dim
<
0
>
&
idx
,
const
Dim
<
0
>
&
size
)
{
HOSTDEVICE
inline
bool
contained
(
const
Dim
<
0
>
&
idx
,
const
Dim
<
0
>
&
size
)
{
return
true
;
return
true
;
}
}
/**
/**
...
@@ -278,14 +276,14 @@ HOSTDEVICE inline bool contained(const Dim<0> &idx, const Dim<0> &size) {
...
@@ -278,14 +276,14 @@ HOSTDEVICE inline bool contained(const Dim<0> &idx, const Dim<0> &size) {
*/
*/
template
<
int
i
>
template
<
int
i
>
HOSTDEVICE
Dim
<
i
>
ex_prefix_mul
(
const
Dim
<
i
>
&
src
,
int
mul
=
1
)
{
HOSTDEVICE
Dim
<
i
>
ex_prefix_mul
(
const
Dim
<
i
>
&
src
,
int
mul
=
1
)
{
return
Dim
<
i
>
(
mul
,
ex_prefix_mul
(
src
.
tail
,
mul
*
src
.
head
));
return
Dim
<
i
>
(
mul
,
ex_prefix_mul
(
src
.
tail
,
mul
*
src
.
head
));
}
}
///\cond HIDDEN
///\cond HIDDEN
// Base case of ex_prefix_mul
// Base case of ex_prefix_mul
// Notice it is inline because it is no longer a template
// Notice it is inline because it is no longer a template
template
<
>
HOSTDEVICE
inline
Dim
<
0
>
ex_prefix_mul
(
const
Dim
<
0
>
&
src
,
int
mul
)
{
template
<
>
HOSTDEVICE
inline
Dim
<
0
>
ex_prefix_mul
(
const
Dim
<
0
>
&
src
,
int
mul
)
{
return
Dim
<
0
>
();
return
Dim
<
0
>
();
}
}
///\endcond
///\endcond
...
@@ -293,36 +291,36 @@ template <> HOSTDEVICE inline Dim<0> ex_prefix_mul(const Dim<0> &src, int mul) {
...
@@ -293,36 +291,36 @@ template <> HOSTDEVICE inline Dim<0> ex_prefix_mul(const Dim<0> &src, int mul) {
* Add two dimensions together
* Add two dimensions together
*/
*/
template
<
int
i
>
HOSTDEVICE
Dim
<
i
>
dim_plus
(
const
Dim
<
i
>
&
a
,
const
Dim
<
i
>
&
b
)
{
template
<
int
i
>
HOSTDEVICE
Dim
<
i
>
dim_plus
(
const
Dim
<
i
>
&
a
,
const
Dim
<
i
>
&
b
)
{
return
Dim
<
i
>
(
a
.
head
+
b
.
head
,
dim_plus
(
a
.
tail
,
b
.
tail
));
return
Dim
<
i
>
(
a
.
head
+
b
.
head
,
dim_plus
(
a
.
tail
,
b
.
tail
));
}
}
// Base case
// Base case
template
<
>
template
<
>
HOSTDEVICE
inline
Dim
<
0
>
dim_plus
(
const
Dim
<
0
>
&
a
,
const
Dim
<
0
>
&
b
)
{
HOSTDEVICE
inline
Dim
<
0
>
dim_plus
(
const
Dim
<
0
>
&
a
,
const
Dim
<
0
>
&
b
)
{
return
Dim
<
0
>
();
return
Dim
<
0
>
();
}
}
template
<
int
i
>
template
<
int
i
>
HOSTDEVICE
Dim
<
i
>
operator
+
(
const
Dim
<
i
>
&
lhs
,
const
Dim
<
i
>
&
rhs
)
{
HOSTDEVICE
Dim
<
i
>
operator
+
(
const
Dim
<
i
>
&
lhs
,
const
Dim
<
i
>
&
rhs
)
{
return
dim_plus
(
lhs
,
rhs
);
return
dim_plus
(
lhs
,
rhs
);
}
}
/**
/**
* Multiply two dimensions together
* Multiply two dimensions together
*/
*/
template
<
int
i
>
HOSTDEVICE
Dim
<
i
>
dim_mult
(
const
Dim
<
i
>
&
a
,
const
Dim
<
i
>
&
b
)
{
template
<
int
i
>
HOSTDEVICE
Dim
<
i
>
dim_mult
(
const
Dim
<
i
>
&
a
,
const
Dim
<
i
>
&
b
)
{
return
Dim
<
i
>
(
a
.
head
*
b
.
head
,
dim_mult
(
a
.
tail
,
b
.
tail
));
return
Dim
<
i
>
(
a
.
head
*
b
.
head
,
dim_mult
(
a
.
tail
,
b
.
tail
));
}
}
// Base case
// Base case
template
<
>
template
<
>
HOSTDEVICE
inline
Dim
<
0
>
dim_mult
(
const
Dim
<
0
>
&
a
,
const
Dim
<
0
>
&
b
)
{
HOSTDEVICE
inline
Dim
<
0
>
dim_mult
(
const
Dim
<
0
>
&
a
,
const
Dim
<
0
>
&
b
)
{
return
Dim
<
0
>
();
return
Dim
<
0
>
();
}
}
template
<
int
i
>
template
<
int
i
>
HOSTDEVICE
Dim
<
i
>
operator
*
(
const
Dim
<
i
>
&
lhs
,
const
Dim
<
i
>
&
rhs
)
{
HOSTDEVICE
Dim
<
i
>
operator
*
(
const
Dim
<
i
>
&
lhs
,
const
Dim
<
i
>
&
rhs
)
{
return
dim_mult
(
lhs
,
rhs
);
return
dim_mult
(
lhs
,
rhs
);
}
}
/**
/**
...
@@ -337,8 +335,8 @@ HOSTDEVICE Dim<i> operator*(const Dim<i> &lhs, const Dim<i> &rhs) {
...
@@ -337,8 +335,8 @@ HOSTDEVICE Dim<i> operator*(const Dim<i> &lhs, const Dim<i> &rhs) {
template
<
int
i
>
template
<
int
i
>
HOSTDEVICE
Dim
<
i
>
normalize_strides
(
const
Dim
<
i
>
&
size
,
const
Dim
<
i
>
&
stride
)
{
HOSTDEVICE
Dim
<
i
>
normalize_strides
(
const
Dim
<
i
>
&
size
,
const
Dim
<
i
>
&
stride
)
{
int
norm_stride
=
size
.
head
==
1
?
0
:
stride
.
head
;
int
norm_stride
=
size
.
head
==
1
?
0
:
stride
.
head
;
return
Dim
<
i
>
(
norm_stride
,
normalize_strides
(
size
.
tail
,
stride
.
tail
));
return
Dim
<
i
>
(
norm_stride
,
normalize_strides
(
size
.
tail
,
stride
.
tail
));
}
}
///\cond HIDDEN
///\cond HIDDEN
...
@@ -346,7 +344,7 @@ HOSTDEVICE Dim<i> normalize_strides(const Dim<i> &size, const Dim<i> &stride) {
...
@@ -346,7 +344,7 @@ HOSTDEVICE Dim<i> normalize_strides(const Dim<i> &size, const Dim<i> &stride) {
template
<
>
template
<
>
HOSTDEVICE
inline
Dim
<
0
>
normalize_strides
(
const
Dim
<
0
>
&
size
,
HOSTDEVICE
inline
Dim
<
0
>
normalize_strides
(
const
Dim
<
0
>
&
size
,
const
Dim
<
0
>
&
stride
)
{
const
Dim
<
0
>
&
stride
)
{
return
Dim
<
0
>
();
return
Dim
<
0
>
();
}
}
///\endcond
///\endcond
...
@@ -361,7 +359,7 @@ HOSTDEVICE inline Dim<0> normalize_strides(const Dim<0> &size,
...
@@ -361,7 +359,7 @@ HOSTDEVICE inline Dim<0> normalize_strides(const Dim<0> &size,
template
<
typename
...
Args
>
template
<
typename
...
Args
>
HOSTDEVICE
Dim
<
sizeof
...(
Args
)
>
make_dim
(
Args
...
idxes
)
{
HOSTDEVICE
Dim
<
sizeof
...(
Args
)
>
make_dim
(
Args
...
idxes
)
{
return
Dim
<
sizeof
...(
Args
)
>
(
idxes
...);
return
Dim
<
sizeof
...(
Args
)
>
(
idxes
...);
}
}
// Allows us to output a Dim
// Allows us to output a Dim
...
@@ -369,8 +367,8 @@ HOSTDEVICE Dim<sizeof...(Args)> make_dim(Args... idxes) {
...
@@ -369,8 +367,8 @@ HOSTDEVICE Dim<sizeof...(Args)> make_dim(Args... idxes) {
template
<
int
i
>
template
<
int
i
>
typename
std
::
enable_if
<
(
i
>
1
),
std
::
ostream
&>::
type
typename
std
::
enable_if
<
(
i
>
1
),
std
::
ostream
&>::
type
operator
<<
(
std
::
ostream
&
os
,
const
Dim
<
i
>
&
d
)
{
operator
<<
(
std
::
ostream
&
os
,
const
Dim
<
i
>
&
d
)
{
os
<<
d
.
head
<<
", "
<<
d
.
tail
;
os
<<
d
.
head
<<
", "
<<
d
.
tail
;
return
os
;
return
os
;
}
}
// Base case that allows us to output a Dim
// Base case that allows us to output a Dim
...
@@ -378,34 +376,34 @@ operator<<(std::ostream &os, const Dim<i> &d) {
...
@@ -378,34 +376,34 @@ operator<<(std::ostream &os, const Dim<i> &d) {
template
<
int
i
>
template
<
int
i
>
typename
std
::
enable_if
<
(
i
==
1
),
std
::
ostream
&>::
type
typename
std
::
enable_if
<
(
i
==
1
),
std
::
ostream
&>::
type
operator
<<
(
std
::
ostream
&
os
,
const
Dim
<
i
>
&
d
)
{
operator
<<
(
std
::
ostream
&
os
,
const
Dim
<
i
>
&
d
)
{
os
<<
d
.
head
;
os
<<
d
.
head
;
return
os
;
return
os
;
}
}
inline
std
::
ostream
&
operator
<<
(
std
::
ostream
&
os
,
const
Dim
<
0
>
&
d
)
{
inline
std
::
ostream
&
operator
<<
(
std
::
ostream
&
os
,
const
Dim
<
0
>
&
d
)
{
return
os
;
return
os
;
}
}
template
<
int
i
>
HOST
std
::
string
Dim
<
i
>::
to_string
()
const
{
template
<
int
i
>
HOST
std
::
string
Dim
<
i
>::
to_string
()
const
{
std
::
stringstream
stream
;
std
::
stringstream
stream
;
stream
<<
*
this
;
stream
<<
*
this
;
return
stream
.
str
();
return
stream
.
str
();
}
}
template
<
int
D
>
template
<
int
D
>
HOSTDEVICE
Dim
<
D
>
linear_to_dimension
(
int
linear_index
,
Dim
<
D
>
extents
)
{
HOSTDEVICE
Dim
<
D
>
linear_to_dimension
(
int
linear_index
,
Dim
<
D
>
extents
)
{
Dim
<
D
>
result
;
Dim
<
D
>
result
;
for
(
int
i
=
0
;
i
<
D
-
1
;
++
i
)
{
for
(
int
i
=
0
;
i
<
D
-
1
;
++
i
)
{
result
[
i
]
=
linear_index
%
extents
[
i
];
result
[
i
]
=
linear_index
%
extents
[
i
];
linear_index
/=
extents
[
i
];
linear_index
/=
extents
[
i
];
}
}
result
[
D
-
1
]
=
linear_index
;
result
[
D
-
1
]
=
linear_index
;
return
result
;
return
result
;
}
}
}
// namespace framework
}
// namespace framework
...
...
src/framework/executor.cpp
浏览文件 @
1ad8f821
...
@@ -26,66 +26,66 @@ namespace framework {
...
@@ -26,66 +26,66 @@ namespace framework {
template
<
typename
Dtype
>
template
<
typename
Dtype
>
Executor
<
Dtype
>::
Executor
(
const
Program
<
Dtype
>
p
)
:
program_
(
p
)
{
Executor
<
Dtype
>::
Executor
(
const
Program
<
Dtype
>
p
)
:
program_
(
p
)
{
if
(
use_optimize_
)
{
if
(
use_optimize_
)
{
to_predict_program_
=
program_
.
optimizeProgram
;
to_predict_program_
=
program_
.
optimizeProgram
;
}
else
{
}
else
{
to_predict_program_
=
program_
.
originProgram
;
to_predict_program_
=
program_
.
originProgram
;
}
}
// const std::vector<std::shared_ptr<BlockDesc>> blocks =
// const std::vector<std::shared_ptr<BlockDesc>> blocks =
to_predict_program_
->
Blocks
();
to_predict_program_
->
Blocks
();
// for (int i = 0; i < blocks.size(); ++i) {
// for (int i = 0; i < blocks.size(); ++i) {
// std::shared_ptr<BlockDesc> block_desc = blocks[i];
// std::shared_ptr<BlockDesc> block_desc = blocks[i];
// std::vector<std::shared_ptr<OpDesc>> ops = block_desc->Ops();
// std::vector<std::shared_ptr<OpDesc>> ops = block_desc->Ops();
// for (int j = 0; j < ops.size(); ++j) {
// for (int j = 0; j < ops.size(); ++j) {
// std::shared_ptr<OpDesc> op = ops[j];
// std::shared_ptr<OpDesc> op = ops[j];
// if (op->Type() == "conv2d" && op->Input("Input")[0] ==
// if (op->Type() == "conv2d" && op->Input("Input")[0] ==
// "pixel") {
// "pixel") {
// Attribute strides_attr = op->GetAttrMap().at("strides");
// Attribute strides_attr = op->GetAttrMap().at("strides");
// std::vector<int> stride =
// std::vector<int> stride =
// strides_attr.Get<std::vector<int>>(); for (int k = 0; k <
// strides_attr.Get<std::vector<int>>(); for (int k = 0; k <
// stride.size(); ++k) {
// stride.size(); ++k) {
// }
// }
// std::shared_ptr<operators::ConvOp<Dtype, float>> conv =
// std::shared_ptr<operators::ConvOp<Dtype, float>> conv =
// std::make_shared<operators::ConvOp<Dtype, float>>(
// std::make_shared<operators::ConvOp<Dtype, float>>(
// op->Type(), op->GetInputs(), op->GetOutputs(),
// op->Type(), op->GetInputs(), op->GetOutputs(),
// op->GetAttrMap(), program_.scope);
// op->GetAttrMap(), program_.scope);
// ops_of_block_[*block_desc.get()].push_back(conv);
// ops_of_block_[*block_desc.get()].push_back(conv);
// }
// }
// }
// }
// }
// }
}
}
template
<
typename
Dtype
>
template
<
typename
Dtype
>
std
::
shared_ptr
<
Tensor
>
Executor
<
Dtype
>::
predict
(
Tensor
&
t
)
{
std
::
shared_ptr
<
Tensor
>
Executor
<
Dtype
>::
predict
(
Tensor
&
t
)
{
// feed
// feed
auto
scope
=
program_
.
scope
;
auto
scope
=
program_
.
scope
;
Variable
*
g_feed_value
=
scope
->
Var
(
"pixel"
);
Variable
*
g_feed_value
=
scope
->
Var
(
"pixel"
);
auto
tensor
=
g_feed_value
->
GetMutable
<
Tensor
>
();
auto
tensor
=
g_feed_value
->
GetMutable
<
Tensor
>
();
tensor
->
ShareDataWith
(
t
);
tensor
->
ShareDataWith
(
t
);
Variable
*
con_output
=
scope
->
Var
(
"conv2d_0.tmp_0"
);
Variable
*
con_output
=
scope
->
Var
(
"conv2d_0.tmp_0"
);
Tensor
*
output_tensor
=
con_output
->
GetMutable
<
Tensor
>
();
Tensor
*
output_tensor
=
con_output
->
GetMutable
<
Tensor
>
();
output_tensor
->
mutable_data
<
float
>
({
1
,
16
,
32
,
32
});
output_tensor
->
mutable_data
<
float
>
({
1
,
16
,
32
,
32
});
// std::cout << typeid(output_tensor).name() << std::endl;
// std::cout << typeid(output_tensor).name() << std::endl;
// std::cout << "output_tensor dims: " << output_tensor->dims() <<
// std::cout << "output_tensor dims: " << output_tensor->dims() <<
// std::endl;
// std::endl;
std
::
shared_ptr
<
Tensor
>
out_tensor
=
std
::
make_shared
<
LoDTensor
>
();
std
::
shared_ptr
<
Tensor
>
out_tensor
=
std
::
make_shared
<
LoDTensor
>
();
out_tensor
.
reset
(
output_tensor
);
out_tensor
.
reset
(
output_tensor
);
predict
(
t
,
0
);
predict
(
t
,
0
);
return
out_tensor
;
return
out_tensor
;
}
}
template
<
typename
Dtype
>
template
<
typename
Dtype
>
void
Executor
<
Dtype
>::
predict
(
const
Tensor
&
t
,
int
block_id
)
{
void
Executor
<
Dtype
>::
predict
(
const
Tensor
&
t
,
int
block_id
)
{
std
::
shared_ptr
<
BlockDesc
>
to_predict_block
=
std
::
shared_ptr
<
BlockDesc
>
to_predict_block
=
to_predict_program_
->
Block
(
block_id
);
to_predict_program_
->
Block
(
block_id
);
for
(
int
j
=
0
;
j
<
ops_of_block_
[
*
to_predict_block
.
get
()].
size
();
++
j
)
{
for
(
int
j
=
0
;
j
<
ops_of_block_
[
*
to_predict_block
.
get
()].
size
();
++
j
)
{
auto
op
=
ops_of_block_
[
*
to_predict_block
.
get
()][
j
];
auto
op
=
ops_of_block_
[
*
to_predict_block
.
get
()][
j
];
op
->
Run
();
op
->
Run
();
}
}
}
}
template
class
Executor
<
CPU
>;
template
class
Executor
<
CPU
>;
...
...
src/framework/executor.h
浏览文件 @
1ad8f821
...
@@ -35,23 +35,23 @@ namespace paddle_mobile {
...
@@ -35,23 +35,23 @@ namespace paddle_mobile {
namespace
framework
{
namespace
framework
{
template
<
typename
Dtype
>
class
Executor
{
template
<
typename
Dtype
>
class
Executor
{
public:
public:
Executor
();
Executor
();
Executor
(
const
Program
<
Dtype
>
p
);
Executor
(
const
Program
<
Dtype
>
p
);
std
::
shared_ptr
<
Tensor
>
predict
(
Tensor
&
t
);
std
::
shared_ptr
<
Tensor
>
predict
(
Tensor
&
t
);
public:
public:
const
framework
::
Program
<
Dtype
>
program_
;
const
framework
::
Program
<
Dtype
>
program_
;
std
::
shared_ptr
<
ProgramDesc
>
to_predict_program_
;
std
::
shared_ptr
<
ProgramDesc
>
to_predict_program_
;
void
predict
(
const
Tensor
&
t
,
int
block_id
);
void
predict
(
const
Tensor
&
t
,
int
block_id
);
std
::
map
<
framework
::
BlockDesc
,
std
::
map
<
framework
::
BlockDesc
,
std
::
vector
<
std
::
shared_ptr
<
OperatorBase
<
Dtype
>>>>
std
::
vector
<
std
::
shared_ptr
<
OperatorBase
<
Dtype
>>>>
ops_of_block_
;
ops_of_block_
;
bool
use_optimize_
=
false
;
bool
use_optimize_
=
false
;
};
};
}
// namespace framework
}
// namespace framework
...
...
src/framework/framework.pb.cpp
浏览文件 @
1ad8f821
此差异已折叠。
点击以展开。
src/framework/framework.pb.h
浏览文件 @
1ad8f821
此差异已折叠。
点击以展开。
src/framework/lod_tensor.cc
浏览文件 @
1ad8f821
...
@@ -22,291 +22,289 @@ namespace paddle_mobile {
...
@@ -22,291 +22,289 @@ namespace paddle_mobile {
namespace
framework
{
namespace
framework
{
std
::
ostream
&
operator
<<
(
std
::
ostream
&
os
,
const
LoD
&
lod
)
{
std
::
ostream
&
operator
<<
(
std
::
ostream
&
os
,
const
LoD
&
lod
)
{
os
<<
"{"
;
for
(
auto
&
v
:
lod
)
{
os
<<
"{"
;
os
<<
"{"
;
for
(
auto
&
v
:
lod
)
{
bool
is_first
=
true
;
os
<<
"{"
;
for
(
auto
&
i
:
v
)
{
bool
is_first
=
true
;
if
(
is_first
)
{
for
(
auto
&
i
:
v
)
{
os
<<
i
;
if
(
is_first
)
{
is_first
=
false
;
os
<<
i
;
}
else
{
is_first
=
false
;
os
<<
", "
<<
i
;
}
else
{
}
os
<<
", "
<<
i
;
}
}
os
<<
"}"
;
}
}
os
<<
"}"
;
os
<<
"}"
;
}
os
<<
"}"
;
return
os
;
return
os
;
}
}
std
::
ostream
&
operator
<<
(
std
::
ostream
&
os
,
const
LoDTensor
&
t
)
{
std
::
ostream
&
operator
<<
(
std
::
ostream
&
os
,
const
LoDTensor
&
t
)
{
// PADDLE_ENFORCE(t.type().hash_code() ==
// PADDLE_ENFORCE(t.type().hash_code() ==
// typeid(float).hash_code());
// typeid(float).hash_code());
// if (!platform::is_cpu_place(t.place())) {
// if (!platform::is_cpu_place(t.place())) {
// LoDTensor tt;
// LoDTensor tt;
// framework::TensorCopy(t, platform::CPUPlace(), &tt);
// framework::TensorCopy(t, platform::CPUPlace(), &tt);
// platform::DeviceContextPool &pool =
// platform::DeviceContextPool &pool =
// platform::DeviceContextPool::Instance(); auto &dev_ctx =
// platform::DeviceContextPool::Instance(); auto &dev_ctx =
// *pool.Get(t.place()); dev_ctx.Wait();
// *pool.Get(t.place()); dev_ctx.Wait();
//
//
// os << tt;
// os << tt;
// return os;
// return os;
// }
// }
os
<<
"dim: "
<<
t
.
dims
()
<<
"
\n
"
;
os
<<
"dim: "
<<
t
.
dims
()
<<
"
\n
"
;
os
<<
"lod: "
<<
t
.
lod
()
<<
"
\n
"
;
os
<<
"lod: "
<<
t
.
lod
()
<<
"
\n
"
;
// only print first ten elements
// only print first ten elements
int64_t
size
=
t
.
numel
()
<
10
?
t
.
numel
()
:
10
;
int64_t
size
=
t
.
numel
()
<
10
?
t
.
numel
()
:
10
;
for
(
int64_t
i
=
0
;
i
<
size
;
++
i
)
{
for
(
int64_t
i
=
0
;
i
<
size
;
++
i
)
{
os
<<
t
.
data
<
float
>
()[
i
]
<<
" "
;
os
<<
t
.
data
<
float
>
()[
i
]
<<
" "
;
}
}
return
os
;
return
os
;
}
}
std
::
string
LoDToString
(
const
LoD
&
lod
)
{
std
::
string
LoDToString
(
const
LoD
&
lod
)
{
std
::
ostringstream
stream
;
std
::
ostringstream
stream
;
stream
<<
lod
;
stream
<<
lod
;
return
stream
.
str
();
return
stream
.
str
();
}
}
LoD
SliceInLevel
(
const
LoD
&
in
,
size_t
level
,
size_t
elem_begin
,
LoD
SliceInLevel
(
const
LoD
&
in
,
size_t
level
,
size_t
elem_begin
,
size_t
elem_end
)
{
size_t
elem_end
)
{
// PADDLE_ENFORCE_LT(level, in.size());
// PADDLE_ENFORCE_LT(level, in.size());
// PADDLE_ENFORCE_LT(elem_end, in[level].size());
// PADDLE_ENFORCE_LT(elem_end, in[level].size());
LoD
res
;
LoD
res
;
res
.
resize
(
in
.
size
()
-
level
);
res
.
resize
(
in
.
size
()
-
level
);
// copy the first level
// copy the first level
res
[
0
].
assign
(
in
[
level
].
begin
()
+
elem_begin
,
res
[
0
].
assign
(
in
[
level
].
begin
()
+
elem_begin
,
in
[
level
].
begin
()
+
elem_end
+
1
);
in
[
level
].
begin
()
+
elem_end
+
1
);
for
(
size_t
lvl
=
1
;
lvl
<
res
.
size
();
lvl
++
)
{
for
(
size_t
lvl
=
1
;
lvl
<
res
.
size
();
lvl
++
)
{
const
auto
&
in_level
=
in
[
level
+
lvl
];
const
auto
&
in_level
=
in
[
level
+
lvl
];
const
auto
&
above_level
=
res
[
lvl
-
1
];
const
auto
&
above_level
=
res
[
lvl
-
1
];
auto
&
out_level
=
res
[
lvl
];
auto
&
out_level
=
res
[
lvl
];
out_level
.
assign
(
in_level
.
begin
()
+
above_level
.
front
(),
out_level
.
assign
(
in_level
.
begin
()
+
above_level
.
front
(),
in_level
.
begin
()
+
above_level
.
back
()
+
1
);
in_level
.
begin
()
+
above_level
.
back
()
+
1
);
}
}
for
(
size_t
lvl
=
0
;
lvl
<
res
.
size
();
lvl
++
)
{
for
(
size_t
lvl
=
0
;
lvl
<
res
.
size
();
lvl
++
)
{
// to make the first offset equals 0, all the elements minus the
// to make the first offset equals 0, all the elements minus the
// first
// first
// element
// element
size_t
front
=
res
[
lvl
].
front
();
size_t
front
=
res
[
lvl
].
front
();
for
(
auto
&
ele
:
res
[
lvl
])
{
for
(
auto
&
ele
:
res
[
lvl
])
{
ele
-=
front
;
ele
-=
front
;
}
}
}
return
res
;
}
return
res
;
}
}
LoD
ToAbsOffset
(
const
LoD
&
in
)
{
LoD
ToAbsOffset
(
const
LoD
&
in
)
{
// the lowest level stores relative offsets
// the lowest level stores relative offsets
if
(
in
.
empty
()
||
in
.
size
()
==
1
)
if
(
in
.
empty
()
||
in
.
size
()
==
1
)
return
in
;
return
in
;
LoD
result
=
in
;
LoD
result
=
in
;
for
(
auto
level
=
static_cast
<
int
>
(
in
.
size
()
-
2
);
level
>=
0
;
level
--
)
{
for
(
auto
level
=
static_cast
<
int
>
(
in
.
size
()
-
2
);
level
>=
0
;
level
--
)
{
for
(
size_t
i
=
0
;
i
<
in
[
level
].
size
();
++
i
)
{
for
(
size_t
i
=
0
;
i
<
in
[
level
].
size
();
++
i
)
{
size_t
index
=
in
[
level
][
i
];
size_t
index
=
in
[
level
][
i
];
result
[
level
][
i
]
=
result
[
level
+
1
][
index
];
result
[
level
][
i
]
=
result
[
level
+
1
][
index
];
}
}
}
return
result
;
}
return
result
;
}
}
bool
operator
==
(
const
LoD
&
a
,
const
LoD
&
b
)
{
bool
operator
==
(
const
LoD
&
a
,
const
LoD
&
b
)
{
if
(
a
.
size
()
!=
b
.
size
())
{
if
(
a
.
size
()
!=
b
.
size
())
{
return
false
;
return
false
;
}
for
(
size_t
i
=
0
;
i
<
a
.
size
();
i
++
)
{
const
auto
&
a_level
=
a
[
i
];
const
auto
&
b_level
=
b
[
i
];
if
(
a_level
.
size
()
!=
b_level
.
size
())
{
return
false
;
}
}
for
(
size_t
j
=
0
;
j
<
a_level
.
size
();
j
++
)
{
for
(
size_t
i
=
0
;
i
<
a
.
size
();
i
++
)
{
if
(
a_level
[
j
]
!=
b_level
[
j
])
{
const
auto
&
a_level
=
a
[
i
];
return
false
;
const
auto
&
b_level
=
b
[
i
];
}
if
(
a_level
.
size
()
!=
b_level
.
size
())
{
return
false
;
}
for
(
size_t
j
=
0
;
j
<
a_level
.
size
();
j
++
)
{
if
(
a_level
[
j
]
!=
b_level
[
j
])
{
return
false
;
}
}
}
}
return
true
;
}
return
true
;
}
}
bool
CheckLoD
(
const
LoD
&
in
,
int
tensor_height
)
{
bool
CheckLoD
(
const
LoD
&
in
,
int
tensor_height
)
{
if
(
in
.
empty
())
if
(
in
.
empty
())
return
true
;
for
(
const
auto
&
level
:
in
)
{
// check: there should be more than 2 offsets existing in each
// level.
if
(
level
.
size
()
<
2
)
return
false
;
// check: the first offset(the begin offset) of each level
// should be 0.
if
(
level
.
front
()
!=
0
)
return
false
;
// check: all the offsets in a level should be ascending(no same
// items
// allows).
if
(
!
std
::
is_sorted
(
level
.
begin
(),
level
.
begin
(),
[](
size_t
a
,
size_t
b
)
{
if
(
a
<
b
)
return
true
;
return
false
;
}))
{
std
::
cout
<<
"ascending error"
;
return
false
;
}
}
// check: the lowest level's last offset should equals
// `tensor_height` if
// tensor_height>0.
if
(
tensor_height
>
0
&&
(
size_t
)
tensor_height
!=
in
.
back
().
back
())
return
false
;
// check: the higher level's last offset should equals the lower
// level's
// size-1.
// NOTE LoD store the levels from top to bottom, so the higher level
// goes
// first.
for
(
size_t
level
=
0
;
level
<
in
.
size
()
-
1
;
level
++
)
{
if
(
in
[
level
].
back
()
!=
in
[
level
+
1
].
size
()
-
1
)
return
false
;
}
return
true
;
return
true
;
for
(
const
auto
&
level
:
in
)
{
// check: there should be more than 2 offsets existing in each
// level.
if
(
level
.
size
()
<
2
)
return
false
;
// check: the first offset(the begin offset) of each level
// should be 0.
if
(
level
.
front
()
!=
0
)
return
false
;
// check: all the offsets in a level should be ascending(no same
// items
// allows).
if
(
!
std
::
is_sorted
(
level
.
begin
(),
level
.
begin
(),
[](
size_t
a
,
size_t
b
)
{
if
(
a
<
b
)
return
true
;
return
false
;
}))
{
std
::
cout
<<
"ascending error"
;
return
false
;
}
}
// check: the lowest level's last offset should equals
// `tensor_height` if
// tensor_height>0.
if
(
tensor_height
>
0
&&
(
size_t
)
tensor_height
!=
in
.
back
().
back
())
return
false
;
// check: the higher level's last offset should equals the lower
// level's
// size-1.
// NOTE LoD store the levels from top to bottom, so the higher level
// goes
// first.
for
(
size_t
level
=
0
;
level
<
in
.
size
()
-
1
;
level
++
)
{
if
(
in
[
level
].
back
()
!=
in
[
level
+
1
].
size
()
-
1
)
return
false
;
}
return
true
;
}
}
bool
CheckAbsLoD
(
const
LoD
&
in
,
int
tensor_height
)
{
bool
CheckAbsLoD
(
const
LoD
&
in
,
int
tensor_height
)
{
if
(
in
.
empty
())
if
(
in
.
empty
())
return
true
;
for
(
const
auto
&
level
:
in
)
{
// check: all the offsets in a level should be ascending(no same
// items
// allows).
if
(
!
std
::
is_sorted
(
level
.
begin
(),
level
.
begin
(),
[](
size_t
a
,
size_t
b
)
{
if
(
a
<
b
)
return
true
;
return
false
;
}))
{
return
false
;
}
// check: there should be more than 2 offsets existing in each
// level.
if
(
level
.
size
()
<
2
)
return
false
;
// check: the first offset of each level should be 0, and the
// last should be
// the same(the height of underlying tensor).
if
(
level
.
front
()
!=
0
)
return
false
;
if
(
tensor_height
<
0
)
{
tensor_height
=
level
.
back
();
}
else
if
((
size_t
)
tensor_height
!=
level
.
back
())
{
return
false
;
}
}
return
true
;
return
true
;
for
(
const
auto
&
level
:
in
)
{
// check: all the offsets in a level should be ascending(no same
// items
// allows).
if
(
!
std
::
is_sorted
(
level
.
begin
(),
level
.
begin
(),
[](
size_t
a
,
size_t
b
)
{
if
(
a
<
b
)
return
true
;
return
false
;
}))
{
return
false
;
}
// check: there should be more than 2 offsets existing in each
// level.
if
(
level
.
size
()
<
2
)
return
false
;
// check: the first offset of each level should be 0, and the
// last should be
// the same(the height of underlying tensor).
if
(
level
.
front
()
!=
0
)
return
false
;
if
(
tensor_height
<
0
)
{
tensor_height
=
level
.
back
();
}
else
if
((
size_t
)
tensor_height
!=
level
.
back
())
{
return
false
;
}
}
return
true
;
}
}
using
LoDAndOffset
=
std
::
pair
<
LoD
,
std
::
pair
<
size_t
,
size_t
>>
;
using
LoDAndOffset
=
std
::
pair
<
LoD
,
std
::
pair
<
size_t
,
size_t
>>
;
LoDAndOffset
GetSubLoDAndAbsoluteOffset
(
const
LoD
&
lod
,
size_t
start_idx
,
LoDAndOffset
GetSubLoDAndAbsoluteOffset
(
const
LoD
&
lod
,
size_t
start_idx
,
size_t
end_idx
,
size_t
start_level
)
{
size_t
end_idx
,
size_t
start_level
)
{
LoD
sub_lod
;
LoD
sub_lod
;
for
(
size_t
level_idx
=
start_level
;
level_idx
<
lod
.
size
();
++
level_idx
)
{
for
(
size_t
level_idx
=
start_level
;
level_idx
<
lod
.
size
();
++
level_idx
)
{
// PADDLE_ENFORCE_LE(start_idx, end_idx);
// PADDLE_ENFORCE_LE(start_idx, end_idx);
// PADDLE_ENFORCE_LT(end_idx, lod[level_idx].size());
// PADDLE_ENFORCE_LT(end_idx, lod[level_idx].size());
std
::
vector
<
size_t
>
level_lens
;
std
::
vector
<
size_t
>
level_lens
;
for
(
size_t
i
=
start_idx
;
i
<
end_idx
;
++
i
)
{
for
(
size_t
i
=
start_idx
;
i
<
end_idx
;
++
i
)
{
level_lens
.
push_back
(
lod
[
level_idx
][
i
+
1
]
-
lod
[
level_idx
][
i
]);
level_lens
.
push_back
(
lod
[
level_idx
][
i
+
1
]
-
lod
[
level_idx
][
i
]);
}
sub_lod
.
emplace_back
(
level_lens
);
start_idx
=
lod
[
level_idx
][
start_idx
];
end_idx
=
lod
[
level_idx
][
end_idx
];
}
}
sub_lod
.
emplace_back
(
level_lens
);
start_idx
=
lod
[
level_idx
][
start_idx
];
end_idx
=
lod
[
level_idx
][
end_idx
];
}
return
LoDAndOffset
{
sub_lod
,
{
start_idx
,
end_idx
}};
return
LoDAndOffset
{
sub_lod
,
{
start_idx
,
end_idx
}};
}
}
void
AppendLoD
(
LoD
*
lod
,
const
LoD
&
lod_length
)
{
void
AppendLoD
(
LoD
*
lod
,
const
LoD
&
lod_length
)
{
// PADDLE_ENFORCE(
// PADDLE_ENFORCE(
// lod->empty() || lod->size() == lod_length.size(),
// lod->empty() || lod->size() == lod_length.size(),
// "The lod_length should has the same size with the appended
// "The lod_length should has the same size with the appended
// lod.");
// lod.");
if
(
lod
->
empty
())
{
if
(
lod
->
empty
())
{
for
(
size_t
i
=
0
;
i
<
lod_length
.
size
();
++
i
)
{
for
(
size_t
i
=
0
;
i
<
lod_length
.
size
();
++
i
)
{
lod
->
emplace_back
(
1
,
0
);
// size = 1, value = 0;
lod
->
emplace_back
(
1
,
0
);
// size = 1, value = 0;
}
*
lod
=
LoD
(
lod_length
.
size
(),
std
::
vector
<
size_t
>
({
0
}));
}
}
for
(
size_t
i
=
0
;
i
<
lod
->
size
();
++
i
)
{
*
lod
=
LoD
(
lod_length
.
size
(),
std
::
vector
<
size_t
>
({
0
}));
auto
&
level
=
(
*
lod
)[
i
];
}
for
(
size_t
len
:
lod_length
[
i
])
{
for
(
size_t
i
=
0
;
i
<
lod
->
size
();
++
i
)
{
level
.
push_back
(
level
.
back
()
+
len
);
auto
&
level
=
(
*
lod
)[
i
];
}
for
(
size_t
len
:
lod_length
[
i
])
{
level
.
push_back
(
level
.
back
()
+
len
);
}
}
}
}
}
void
SerializeToStream
(
std
::
ostream
&
os
,
const
LoDTensor
&
tensor
)
{
void
SerializeToStream
(
std
::
ostream
&
os
,
const
LoDTensor
&
tensor
)
{
{
// the 1st field, uint32_t version for LoDTensor
{
// the 1st field, uint32_t version for LoDTensor
constexpr
uint32_t
version
=
0
;
constexpr
uint32_t
version
=
0
;
os
.
write
(
reinterpret_cast
<
const
char
*>
(
&
version
),
sizeof
(
version
));
os
.
write
(
reinterpret_cast
<
const
char
*>
(
&
version
),
sizeof
(
version
));
}
{
// the 2st field, LoD information
// uint64_t lod_level
// uint64_t lod_level_1 size in byte.
// int* lod_level_1 data
// ...
auto
lod
=
tensor
.
lod
();
uint64_t
size
=
lod
.
size
();
os
.
write
(
reinterpret_cast
<
const
char
*>
(
&
size
),
sizeof
(
size
));
for
(
auto
&
each
:
lod
)
{
size
=
each
.
size
()
*
sizeof
(
framework
::
LoD
::
value_type
::
value_type
);
os
.
write
(
reinterpret_cast
<
const
char
*>
(
&
size
),
sizeof
(
size
));
os
.
write
(
reinterpret_cast
<
const
char
*>
(
each
.
data
()),
static_cast
<
std
::
streamsize
>
(
size
));
}
}
{
}
// the 2st field, LoD information
// the 3st field, Tensor
// uint64_t lod_level
TensorToStream
(
os
,
static_cast
<
Tensor
>
(
tensor
));
// uint64_t lod_level_1 size in byte.
// int* lod_level_1 data
// ...
auto
lod
=
tensor
.
lod
();
uint64_t
size
=
lod
.
size
();
os
.
write
(
reinterpret_cast
<
const
char
*>
(
&
size
),
sizeof
(
size
));
for
(
auto
&
each
:
lod
)
{
size
=
each
.
size
()
*
sizeof
(
framework
::
LoD
::
value_type
::
value_type
);
os
.
write
(
reinterpret_cast
<
const
char
*>
(
&
size
),
sizeof
(
size
));
os
.
write
(
reinterpret_cast
<
const
char
*>
(
each
.
data
()),
static_cast
<
std
::
streamsize
>
(
size
));
}
}
// the 3st field, Tensor
TensorToStream
(
os
,
static_cast
<
Tensor
>
(
tensor
));
}
}
void
DeserializeFromStream
(
std
::
istream
&
is
,
LoDTensor
*
tensor
)
{
void
DeserializeFromStream
(
std
::
istream
&
is
,
LoDTensor
*
tensor
)
{
{
{
// the 1st field, unit32_t version for LoDTensor
// the 1st field, unit32_t version for LoDTensor
uint32_t
version
;
uint32_t
version
;
is
.
read
(
reinterpret_cast
<
char
*>
(
&
version
),
sizeof
(
version
));
is
.
read
(
reinterpret_cast
<
char
*>
(
&
version
),
sizeof
(
version
));
// PADDLE_ENFORCE_EQ(version, 0U, "Only version 0 is
// PADDLE_ENFORCE_EQ(version, 0U, "Only version 0 is
// supported");
// supported");
}
}
{
{
// the 2st field, LoD information
// the 2st field, LoD information
uint64_t
lod_level
;
uint64_t
lod_level
;
is
.
read
(
reinterpret_cast
<
char
*>
(
&
lod_level
),
sizeof
(
lod_level
));
is
.
read
(
reinterpret_cast
<
char
*>
(
&
lod_level
),
sizeof
(
lod_level
));
auto
&
lod
=
*
tensor
->
mutable_lod
();
auto
&
lod
=
*
tensor
->
mutable_lod
();
lod
.
resize
(
lod_level
);
lod
.
resize
(
lod_level
);
for
(
uint64_t
i
=
0
;
i
<
lod_level
;
++
i
)
{
for
(
uint64_t
i
=
0
;
i
<
lod_level
;
++
i
)
{
uint64_t
size
;
uint64_t
size
;
is
.
read
(
reinterpret_cast
<
char
*>
(
&
size
),
sizeof
(
size
));
is
.
read
(
reinterpret_cast
<
char
*>
(
&
size
),
sizeof
(
size
));
std
::
vector
<
size_t
>
tmp
(
size
/
sizeof
(
size_t
));
std
::
vector
<
size_t
>
tmp
(
size
/
sizeof
(
size_t
));
is
.
read
(
reinterpret_cast
<
char
*>
(
tmp
.
data
()),
is
.
read
(
reinterpret_cast
<
char
*>
(
tmp
.
data
()),
static_cast
<
std
::
streamsize
>
(
size
));
static_cast
<
std
::
streamsize
>
(
size
));
lod
[
i
]
=
tmp
;
lod
[
i
]
=
tmp
;
}
}
}
// the 3st filed, Tensor
}
TensorFromStream
(
is
,
static_cast
<
Tensor
*>
(
tensor
));
// the 3st filed, Tensor
TensorFromStream
(
is
,
static_cast
<
Tensor
*>
(
tensor
));
}
}
}
// namespace framework
}
// namespace framework
...
...
src/framework/lod_tensor.h
浏览文件 @
1ad8f821
...
@@ -102,45 +102,45 @@ bool CheckAbsLoD(const LoD &in, int tensor_height = -1);
...
@@ -102,45 +102,45 @@ bool CheckAbsLoD(const LoD &in, int tensor_height = -1);
* see https://en.wikipedia.org/wiki/Level_of_details for reference.
* see https://en.wikipedia.org/wiki/Level_of_details for reference.
*/
*/
class
LoDTensor
:
public
Tensor
{
class
LoDTensor
:
public
Tensor
{
public:
public:
LoDTensor
()
:
Tensor
()
{}
LoDTensor
()
:
Tensor
()
{}
explicit
LoDTensor
(
const
LoD
&
lod
)
:
lod_
(
lod
)
{}
explicit
LoDTensor
(
const
LoD
&
lod
)
:
lod_
(
lod
)
{}
void
set_lod
(
const
LoD
&
lod
)
{
lod_
=
lod
;
}
void
set_lod
(
const
LoD
&
lod
)
{
lod_
=
lod
;
}
const
LoD
&
lod
()
const
{
return
lod_
;
}
const
LoD
&
lod
()
const
{
return
lod_
;
}
LoD
*
mutable_lod
()
{
return
&
lod_
;
}
LoD
*
mutable_lod
()
{
return
&
lod_
;
}
/*
/*
* Get the start offset and end offset of an element from LoD.
* Get the start offset and end offset of an element from LoD.
*/
*/
std
::
pair
<
size_t
,
size_t
>
lod_element
(
size_t
level
,
size_t
elem
)
const
{
std
::
pair
<
size_t
,
size_t
>
lod_element
(
size_t
level
,
size_t
elem
)
const
{
// PADDLE_ENFORCE_LT(level, NumLevels());
// PADDLE_ENFORCE_LT(level, NumLevels());
// PADDLE_ENFORCE_LT(elem, NumElements(level));
// PADDLE_ENFORCE_LT(elem, NumElements(level));
return
std
::
make_pair
((
lod_
)[
level
][
elem
],
(
lod_
)[
level
][
elem
+
1
]);
return
std
::
make_pair
((
lod_
)[
level
][
elem
],
(
lod_
)[
level
][
elem
+
1
]);
}
}
/*
/*
* Number of LoDTensor's levels, each level has units of data, for
* Number of LoDTensor's levels, each level has units of data, for
* example,
* example,
* in the sentence's view, article, paragraph, sentence are 3
* in the sentence's view, article, paragraph, sentence are 3
* levels.
* levels.
*/
*/
size_t
NumLevels
()
const
{
return
lod_
.
size
();
}
size_t
NumLevels
()
const
{
return
lod_
.
size
();
}
/*
/*
* Number of elements in a level.
* Number of elements in a level.
*/
*/
size_t
NumElements
(
size_t
level
=
0
)
const
{
size_t
NumElements
(
size_t
level
=
0
)
const
{
// PADDLE_ENFORCE_LT(level, NumLevels());
// PADDLE_ENFORCE_LT(level, NumLevels());
// the last offset is the end of last element
// the last offset is the end of last element
return
(
lod_
)[
level
].
size
()
-
1
;
return
(
lod_
)[
level
].
size
()
-
1
;
}
}
private:
private:
LoD
lod_
;
LoD
lod_
;
};
};
/*
/*
...
@@ -155,26 +155,26 @@ class LoDTensor : public Tensor {
...
@@ -155,26 +155,26 @@ class LoDTensor : public Tensor {
*/
*/
template
<
typename
T
>
template
<
typename
T
>
LoDTensor
LodExpand
(
const
LoDTensor
&
source
,
const
LoD
&
lod
,
size_t
level
)
{
LoDTensor
LodExpand
(
const
LoDTensor
&
source
,
const
LoD
&
lod
,
size_t
level
)
{
LoD
abs_lod
=
ToAbsOffset
(
lod
);
LoD
abs_lod
=
ToAbsOffset
(
lod
);
const
auto
&
lod_level
=
lod
[
level
];
const
auto
&
lod_level
=
lod
[
level
];
size_t
num_instances
=
source
.
dims
()[
0
];
size_t
num_instances
=
source
.
dims
()[
0
];
// new tensor
// new tensor
LoDTensor
tensor
;
LoDTensor
tensor
;
tensor
.
set_lod
(
lod
);
tensor
.
set_lod
(
lod
);
auto
dims
=
source
.
dims
();
auto
dims
=
source
.
dims
();
dims
[
0
]
=
lod_level
.
back
();
dims
[
0
]
=
lod_level
.
back
();
tensor
.
Resize
(
dims
);
tensor
.
Resize
(
dims
);
tensor
.
mutable_data
<
T
>
();
tensor
.
mutable_data
<
T
>
();
// PADDLE_ENFORCE_EQ(num_instances, lod_level.size() - 1);
// PADDLE_ENFORCE_EQ(num_instances, lod_level.size() - 1);
for
(
size_t
ins
=
0
;
ins
<
num_instances
;
ins
++
)
{
for
(
size_t
ins
=
0
;
ins
<
num_instances
;
ins
++
)
{
for
(
size_t
elem
=
lod_level
[
ins
];
elem
<
lod_level
[
ins
+
1
];
elem
++
)
{
for
(
size_t
elem
=
lod_level
[
ins
];
elem
<
lod_level
[
ins
+
1
];
elem
++
)
{
auto
slice
=
tensor
.
Slice
(
elem
,
elem
+
1
);
auto
slice
=
tensor
.
Slice
(
elem
,
elem
+
1
);
TensorCopy
(
source
.
Slice
(
ins
,
ins
+
1
),
&
slice
);
TensorCopy
(
source
.
Slice
(
ins
,
ins
+
1
),
&
slice
);
}
}
}
return
tensor
;
}
return
tensor
;
}
}
// Get the absolute offset of a lod[start_level][start_idx:end_idx] and
// Get the absolute offset of a lod[start_level][start_idx:end_idx] and
...
...
src/framework/op_desc.cpp
浏览文件 @
1ad8f821
...
@@ -8,51 +8,51 @@ namespace paddle_mobile {
...
@@ -8,51 +8,51 @@ namespace paddle_mobile {
namespace
framework
{
namespace
framework
{
OpDesc
::
OpDesc
(
const
proto
::
OpDesc
&
desc
)
:
desc_
(
desc
)
{
OpDesc
::
OpDesc
(
const
proto
::
OpDesc
&
desc
)
:
desc_
(
desc
)
{
for
(
int
i
=
0
;
i
<
desc_
.
inputs_size
();
++
i
)
{
for
(
int
i
=
0
;
i
<
desc_
.
inputs_size
();
++
i
)
{
const
proto
::
OpDesc
::
Var
&
var
=
desc_
.
inputs
(
i
);
const
proto
::
OpDesc
::
Var
&
var
=
desc_
.
inputs
(
i
);
std
::
vector
<
std
::
string
>
&
args
=
inputs_
[
var
.
parameter
()];
std
::
vector
<
std
::
string
>
&
args
=
inputs_
[
var
.
parameter
()];
int
arg_size
=
var
.
arguments_size
();
int
arg_size
=
var
.
arguments_size
();
for
(
int
j
=
0
;
j
<
arg_size
;
++
j
)
{
for
(
int
j
=
0
;
j
<
arg_size
;
++
j
)
{
args
.
push_back
(
var
.
arguments
(
j
));
args
.
push_back
(
var
.
arguments
(
j
));
}
}
}
}
for
(
int
i
=
0
;
i
<
desc_
.
outputs_size
();
++
i
)
{
const
proto
::
OpDesc
::
Var
&
var
=
desc_
.
outputs
(
i
);
for
(
int
i
=
0
;
i
<
desc_
.
outputs_size
();
++
i
)
{
std
::
vector
<
std
::
string
>
&
args
=
outputs_
[
var
.
parameter
()]
;
const
proto
::
OpDesc
::
Var
&
var
=
desc_
.
outputs
(
i
)
;
int
arg_size
=
var
.
arguments_size
()
;
std
::
vector
<
std
::
string
>
&
args
=
outputs_
[
var
.
parameter
()]
;
for
(
int
j
=
0
;
j
<
arg_size
;
++
j
)
{
int
arg_size
=
var
.
arguments_size
();
args
.
push_back
(
var
.
arguments
(
j
));
for
(
int
j
=
0
;
j
<
arg_size
;
++
j
)
{
}
args
.
push_back
(
var
.
arguments
(
j
));
}
}
}
for
(
const
proto
::
OpDesc
::
Attr
&
attr
:
desc_
.
attrs
())
{
std
::
string
attr_name
=
attr
.
name
();
for
(
const
proto
::
OpDesc
::
Attr
&
attr
:
desc_
.
attrs
())
{
if
(
attr
.
type
()
!=
proto
::
AttrType
::
BLOCK
)
{
std
::
string
attr_name
=
attr
.
name
();
attrs_
[
attr_name
]
=
Attribute
::
GetAttrValue
(
attr
);
if
(
attr
.
type
()
!=
proto
::
AttrType
::
BLOCK
)
{
// if (attr.type() == proto::AttrType::INT){
attrs_
[
attr_name
]
=
Attribute
::
GetAttrValue
(
attr
);
// std::cout << " attrName " << attr_name << " " <<
// if (attr.type() == proto::AttrType::INT){
// attrs_[attr_name].Get<int>() << std::endl;
// std::cout << " attrName " << attr_name << " " <<
// }
// attrs_[attr_name].Get<int>() << std::endl;
}
//
}
}
}
}
}
}
const
std
::
vector
<
std
::
string
>
&
OpDesc
::
Input
(
const
std
::
string
&
name
)
const
{
const
std
::
vector
<
std
::
string
>
&
OpDesc
::
Input
(
const
std
::
string
&
name
)
const
{
return
inputs_
.
find
(
name
)
->
second
;
return
inputs_
.
find
(
name
)
->
second
;
}
}
const
std
::
vector
<
std
::
string
>
&
OpDesc
::
Output
(
const
std
::
string
&
name
)
const
{
const
std
::
vector
<
std
::
string
>
&
OpDesc
::
Output
(
const
std
::
string
&
name
)
const
{
return
outputs_
.
find
(
name
)
->
second
;
return
outputs_
.
find
(
name
)
->
second
;
}
}
Attribute
OpDesc
::
GetAttr
(
const
std
::
string
&
name
)
const
{
Attribute
OpDesc
::
GetAttr
(
const
std
::
string
&
name
)
const
{
auto
it
=
attrs_
.
find
(
name
);
auto
it
=
attrs_
.
find
(
name
);
return
it
->
second
;
return
it
->
second
;
}
}
const
std
::
unordered_map
<
std
::
string
,
Attribute
>
&
OpDesc
::
GetAttrMap
()
const
{
const
std
::
unordered_map
<
std
::
string
,
Attribute
>
&
OpDesc
::
GetAttrMap
()
const
{
return
attrs_
;
return
attrs_
;
}
}
}
// namespace framework
}
// namespace framework
...
...
src/framework/op_desc.h
浏览文件 @
1ad8f821
...
@@ -26,25 +26,25 @@ namespace paddle_mobile {
...
@@ -26,25 +26,25 @@ namespace paddle_mobile {
namespace
framework
{
namespace
framework
{
class
OpDesc
:
PaddleMobileObject
{
class
OpDesc
:
PaddleMobileObject
{
public:
public:
OpDesc
(
const
proto
::
OpDesc
&
desc
);
OpDesc
(
const
proto
::
OpDesc
&
desc
);
const
std
::
vector
<
std
::
string
>
&
Input
(
const
std
::
string
&
name
)
const
;
const
std
::
vector
<
std
::
string
>
&
Input
(
const
std
::
string
&
name
)
const
;
const
std
::
vector
<
std
::
string
>
&
Output
(
const
std
::
string
&
name
)
const
;
const
std
::
vector
<
std
::
string
>
&
Output
(
const
std
::
string
&
name
)
const
;
Attribute
GetAttr
(
const
std
::
string
&
name
)
const
;
Attribute
GetAttr
(
const
std
::
string
&
name
)
const
;
const
VariableNameMap
&
GetInputs
()
{
return
inputs_
;
}
const
VariableNameMap
&
GetInputs
()
{
return
inputs_
;
}
const
VariableNameMap
&
GetOutputs
()
{
return
outputs_
;
}
const
VariableNameMap
&
GetOutputs
()
{
return
outputs_
;
}
const
AttributeMap
&
GetAttrMap
()
const
;
const
AttributeMap
&
GetAttrMap
()
const
;
const
std
::
string
&
Type
()
{
return
desc_
.
type
();
};
const
std
::
string
&
Type
()
{
return
desc_
.
type
();
};
private:
private:
proto
::
OpDesc
desc_
;
proto
::
OpDesc
desc_
;
VariableNameMap
inputs_
;
VariableNameMap
inputs_
;
VariableNameMap
outputs_
;
VariableNameMap
outputs_
;
AttributeMap
attrs_
;
AttributeMap
attrs_
;
};
};
}
// namespace framework
}
// namespace framework
...
...
src/framework/op_info.h
浏览文件 @
1ad8f821
...
@@ -25,13 +25,13 @@ namespace paddle_mobile {
...
@@ -25,13 +25,13 @@ namespace paddle_mobile {
namespace
framework
{
namespace
framework
{
template
<
typename
Dtype
>
struct
OpInfo
{
template
<
typename
Dtype
>
struct
OpInfo
{
OpCreator
<
Dtype
>
creator_
;
OpCreator
<
Dtype
>
creator_
;
const
OpCreator
<
Dtype
>
&
Creator
()
const
{
const
OpCreator
<
Dtype
>
&
Creator
()
const
{
// PADDLE_ENFORCE_NOT_NULL(creator_,
// PADDLE_ENFORCE_NOT_NULL(creator_,
// "Operator Creator has not been
// "Operator Creator has not been
// registered");
// registered");
return
creator_
;
return
creator_
;
}
}
};
};
template
<
typename
Dtype
>
class
OpInfoMap
;
template
<
typename
Dtype
>
class
OpInfoMap
;
...
@@ -39,55 +39,55 @@ template <typename Dtype> class OpInfoMap;
...
@@ -39,55 +39,55 @@ template <typename Dtype> class OpInfoMap;
template
<
typename
Dtype
>
static
OpInfoMap
<
Dtype
>
*
g_op_info_map
=
nullptr
;
template
<
typename
Dtype
>
static
OpInfoMap
<
Dtype
>
*
g_op_info_map
=
nullptr
;
template
<
typename
Dtype
>
class
OpInfoMap
{
template
<
typename
Dtype
>
class
OpInfoMap
{
public:
public:
static
OpInfoMap
&
Instance
()
{
static
OpInfoMap
&
Instance
()
{
if
(
g_op_info_map
<
Dtype
>
==
nullptr
)
{
if
(
g_op_info_map
<
Dtype
>
==
nullptr
)
{
g_op_info_map
<
Dtype
>
=
new
OpInfoMap
();
g_op_info_map
<
Dtype
>
=
new
OpInfoMap
();
}
return
*
g_op_info_map
<
Dtype
>
;
};
bool
Has
(
const
std
::
string
&
op_type
)
const
{
return
map_
.
find
(
op_type
)
!=
map_
.
end
();
}
void
Insert
(
const
std
::
string
&
type
,
const
OpInfo
<
Dtype
>
&
info
)
{
// PADDLE_ENFORCE(!Has(type), "Operator %s has been
// registered", type);
map_
.
insert
({
type
,
info
});
}
const
OpInfo
<
Dtype
>
&
Get
(
const
std
::
string
&
type
)
const
{
auto
op_info_ptr
=
GetNullable
(
type
);
// PADDLE_ENFORCE_NOT_NULL(op_info_ptr, "Operator %s has not
// been
// registered",
// type);
return
*
op_info_ptr
;
}
}
return
*
g_op_info_map
<
Dtype
>
;
const
OpInfo
<
Dtype
>
*
GetNullable
(
const
std
::
string
&
type
)
const
{
};
auto
it
=
map_
.
find
(
type
);
if
(
it
==
map_
.
end
())
{
bool
Has
(
const
std
::
string
&
op_type
)
const
{
return
nullptr
;
return
map_
.
find
(
op_type
)
!=
map_
.
end
();
}
else
{
}
return
&
it
->
second
;
}
void
Insert
(
const
std
::
string
&
type
,
const
OpInfo
<
Dtype
>
&
info
)
{
// PADDLE_ENFORCE(!Has(type), "Operator %s has been
// registered", type);
map_
.
insert
({
type
,
info
});
}
const
OpInfo
<
Dtype
>
&
Get
(
const
std
::
string
&
type
)
const
{
auto
op_info_ptr
=
GetNullable
(
type
);
// PADDLE_ENFORCE_NOT_NULL(op_info_ptr, "Operator %s has not
// been
// registered",
// type);
return
*
op_info_ptr
;
}
const
OpInfo
<
Dtype
>
*
GetNullable
(
const
std
::
string
&
type
)
const
{
auto
it
=
map_
.
find
(
type
);
if
(
it
==
map_
.
end
())
{
return
nullptr
;
}
else
{
return
&
it
->
second
;
}
}
}
const
std
::
unordered_map
<
std
::
string
,
OpInfo
<
Dtype
>>
&
map
()
const
{
const
std
::
unordered_map
<
std
::
string
,
OpInfo
<
Dtype
>>
&
map
()
const
{
return
map_
;
return
map_
;
}
}
std
::
unordered_map
<
std
::
string
,
OpInfo
<
Dtype
>>
*
mutable_map
()
{
std
::
unordered_map
<
std
::
string
,
OpInfo
<
Dtype
>>
*
mutable_map
()
{
return
&
map_
;
return
&
map_
;
}
}
private:
private:
OpInfoMap
()
=
default
;
OpInfoMap
()
=
default
;
std
::
unordered_map
<
std
::
string
,
OpInfo
<
Dtype
>>
map_
;
std
::
unordered_map
<
std
::
string
,
OpInfo
<
Dtype
>>
map_
;
// DISABLE_COPY_AND_ASSIGN(OpInfoMap);
// DISABLE_COPY_AND_ASSIGN(OpInfoMap);
};
};
}
// namespace framework
}
// namespace framework
...
...
src/framework/op_kernel_type.h
浏览文件 @
1ad8f821
...
@@ -24,41 +24,40 @@ SOFTWARE.
...
@@ -24,41 +24,40 @@ SOFTWARE.
namespace
paddle_mobile
{
namespace
paddle_mobile
{
namespace
framework
{
namespace
framework
{
struct
OpKernelType
{
struct
OpKernelType
{
struct
Hash
{
struct
Hash
{
size_t
operator
()(
const
OpKernelType
&
key
)
const
{
size_t
operator
()(
const
OpKernelType
&
key
)
const
{
int
data_type
=
static_cast
<
int
>
(
key
.
data_type_
)
<<
LEFT_SHIFT
;
int
data_type
=
static_cast
<
int
>
(
key
.
data_type_
)
<<
LEFT_SHIFT
;
int
data_layout
=
static_cast
<
int
>
(
key
.
data_layout_
)
int
data_layout
=
static_cast
<
int
>
(
key
.
data_layout_
)
<<
(
LEFT_SHIFT
*
2
);
<<
(
LEFT_SHIFT
*
2
);
std
::
hash
<
int
>
hasher
;
std
::
hash
<
int
>
hasher
;
return
hasher
(
data_type
+
data_layout
);
return
hasher
(
data_type
+
data_layout
);
}
}
};
};
// place, data_type, library_type kinds less than 2^8
// place, data_type, library_type kinds less than 2^8
constexpr
static
int
LEFT_SHIFT
=
8
;
constexpr
static
int
LEFT_SHIFT
=
8
;
proto
::
VarType
::
Type
data_type_
;
proto
::
VarType
::
Type
data_type_
;
DataLayout
data_layout_
;
DataLayout
data_layout_
;
OpKernelType
(
proto
::
VarType
::
Type
data_type
,
OpKernelType
(
proto
::
VarType
::
Type
data_type
,
DataLayout
data_layout
=
DataLayout
::
kAnyLayout
)
DataLayout
data_layout
=
DataLayout
::
kAnyLayout
)
:
data_type_
(
data_type
),
data_layout_
(
data_layout
)
{}
:
data_type_
(
data_type
),
data_layout_
(
data_layout
)
{}
bool
operator
==
(
const
OpKernelType
&
o
)
const
{
bool
operator
==
(
const
OpKernelType
&
o
)
const
{
return
data_type_
==
o
.
data_type_
&&
data_layout_
==
o
.
data_layout_
;
return
data_type_
==
o
.
data_type_
&&
data_layout_
==
o
.
data_layout_
;
}
}
bool
operator
!=
(
const
OpKernelType
&
o
)
const
{
return
!
(
*
this
==
o
);
}
bool
operator
!=
(
const
OpKernelType
&
o
)
const
{
return
!
(
*
this
==
o
);
}
};
};
inline
bool
NeedTransformLayout
(
const
DataLayout
&
l
,
const
DataLayout
&
r
)
{
inline
bool
NeedTransformLayout
(
const
DataLayout
&
l
,
const
DataLayout
&
r
)
{
return
l
!=
DataLayout
::
kAnyLayout
&&
r
!=
DataLayout
::
kAnyLayout
&&
l
!=
r
;
return
l
!=
DataLayout
::
kAnyLayout
&&
r
!=
DataLayout
::
kAnyLayout
&&
l
!=
r
;
}
}
inline
bool
TransFromNeeded
(
const
OpKernelType
&
l
,
const
OpKernelType
&
r
)
{
inline
bool
TransFromNeeded
(
const
OpKernelType
&
l
,
const
OpKernelType
&
r
)
{
return
(
l
.
data_type_
!=
r
.
data_type_
)
||
return
(
l
.
data_type_
!=
r
.
data_type_
)
||
NeedTransformLayout
(
l
.
data_layout_
,
r
.
data_layout_
);
NeedTransformLayout
(
l
.
data_layout_
,
r
.
data_layout_
);
}
}
}
// namespace framework
}
// namespace framework
...
...
src/framework/operator.cpp
浏览文件 @
1ad8f821
...
@@ -30,7 +30,7 @@ OperatorBase<Dtype>::OperatorBase(const std::string &type,
...
@@ -30,7 +30,7 @@ OperatorBase<Dtype>::OperatorBase(const std::string &type,
std
::
shared_ptr
<
Scope
>
scope
)
std
::
shared_ptr
<
Scope
>
scope
)
:
type_
(
type
),
inputs_
(
inputs
),
outputs_
(
outputs
),
attrs_
(
attrs
),
:
type_
(
type
),
inputs_
(
inputs
),
outputs_
(
outputs
),
attrs_
(
attrs
),
scope_
(
scope
)
{
scope_
(
scope
)
{
CheckAllInputOutputSet
();
CheckAllInputOutputSet
();
}
}
template
<
typename
Dtype
>
template
<
typename
Dtype
>
void
OperatorBase
<
Dtype
>::
CheckAllInputOutputSet
()
const
{}
void
OperatorBase
<
Dtype
>::
CheckAllInputOutputSet
()
const
{}
...
...
src/framework/operator.h
浏览文件 @
1ad8f821
...
@@ -49,50 +49,50 @@ static std::unordered_map<
...
@@ -49,50 +49,50 @@ static std::unordered_map<
{
"fetch"
,
{{
"X"
},
{
"Out"
}}}};
{
"fetch"
,
{{
"X"
},
{
"Out"
}}}};
template
<
typename
Dtype
>
class
OperatorBase
:
PaddleMobileObject
{
template
<
typename
Dtype
>
class
OperatorBase
:
PaddleMobileObject
{
public:
public:
OperatorBase
(
const
std
::
string
&
type
,
const
VariableNameMap
&
inputs
,
OperatorBase
(
const
std
::
string
&
type
,
const
VariableNameMap
&
inputs
,
const
VariableNameMap
&
outputs
,
const
AttributeMap
&
attrs
,
const
VariableNameMap
&
outputs
,
const
AttributeMap
&
attrs
,
std
::
shared_ptr
<
Scope
>
scope
);
std
::
shared_ptr
<
Scope
>
scope
);
virtual
~
OperatorBase
()
{}
virtual
~
OperatorBase
()
{}
virtual
void
Run
()
const
=
0
;
virtual
void
Run
()
const
=
0
;
const
VariableNameMap
&
Inputs
()
const
{
return
inputs_
;
}
const
VariableNameMap
&
Inputs
()
const
{
return
inputs_
;
}
const
VariableNameMap
&
Outputs
()
const
{
return
outputs_
;
}
const
VariableNameMap
&
Outputs
()
const
{
return
outputs_
;
}
const
std
::
string
&
Type
()
const
{
return
type_
;
}
const
std
::
string
&
Type
()
const
{
return
type_
;
}
const
AttributeMap
&
Attrs
()
const
{
return
attrs_
;
}
const
AttributeMap
&
Attrs
()
const
{
return
attrs_
;
}
void
ClearVariables
(
const
std
::
vector
<
std
::
string
>
&
var_names
)
const
{
void
ClearVariables
(
const
std
::
vector
<
std
::
string
>
&
var_names
)
const
{
if
(
this
->
scope_
)
{
if
(
this
->
scope_
)
{
this
->
scope_
->
EraseVars
(
var_names
);
this
->
scope_
->
EraseVars
(
var_names
);
}
}
}
}
protected:
protected:
std
::
shared_ptr
<
Scope
>
scope_
;
std
::
shared_ptr
<
Scope
>
scope_
;
std
::
string
type_
;
std
::
string
type_
;
VariableNameMap
inputs_
;
VariableNameMap
inputs_
;
VariableNameMap
outputs_
;
VariableNameMap
outputs_
;
AttributeMap
attrs_
;
AttributeMap
attrs_
;
private:
private:
void
CheckAllInputOutputSet
()
const
;
void
CheckAllInputOutputSet
()
const
;
};
};
template
<
typename
Dtype
>
template
<
typename
Dtype
>
class
OperatorWithKernel
:
public
OperatorBase
<
Dtype
>
{
class
OperatorWithKernel
:
public
OperatorBase
<
Dtype
>
{
public:
public:
OperatorWithKernel
(
const
std
::
string
&
type
,
const
VariableNameMap
&
inputs
,
OperatorWithKernel
(
const
std
::
string
&
type
,
const
VariableNameMap
&
inputs
,
const
VariableNameMap
&
output
s
,
const
VariableNameMap
&
outputs
,
const
AttributeMap
&
attr
s
,
const
AttributeMap
&
attrs
,
std
::
shared_ptr
<
Scope
>
scope
)
std
::
shared_ptr
<
Scope
>
scope
)
:
OperatorBase
<
Dtype
>
(
type
,
inputs
,
outputs
,
attrs
,
scope
)
{}
:
OperatorBase
<
Dtype
>
(
type
,
inputs
,
outputs
,
attrs
,
scope
)
{}
virtual
void
InferShape
()
const
=
0
;
virtual
void
InferShape
()
const
=
0
;
virtual
void
Run
()
const
=
0
;
virtual
void
Run
()
const
=
0
;
};
};
template
<
typename
Dtype
,
typename
P
>
class
OpKernelBase
:
PaddleMobileObject
{
template
<
typename
Dtype
,
typename
P
>
class
OpKernelBase
:
PaddleMobileObject
{
public:
public:
virtual
void
Compute
(
const
P
&
para
)
const
=
0
;
virtual
void
Compute
(
const
P
&
para
)
const
=
0
;
virtual
~
OpKernelBase
()
=
default
;
virtual
~
OpKernelBase
()
=
default
;
};
};
}
// namespace framework
}
// namespace framework
...
...
src/framework/paddle_mobile_object.h
浏览文件 @
1ad8f821
...
@@ -24,13 +24,13 @@ SOFTWARE.
...
@@ -24,13 +24,13 @@ SOFTWARE.
namespace
paddle_mobile
{
namespace
paddle_mobile
{
class
PaddleMobileObject
{
class
PaddleMobileObject
{
public:
public:
virtual
std
::
string
ToString
()
{
virtual
std
::
string
ToString
()
{
char
address
[
128
]
=
{
0
};
char
address
[
128
]
=
{
0
};
sprintf
(
address
,
"%p"
,
this
);
sprintf
(
address
,
"%p"
,
this
);
return
std
::
string
(
address
);
return
std
::
string
(
address
);
}
}
private:
private:
};
};
}
// namespace paddle_mobile
}
// namespace paddle_mobile
src/framework/program-optimize/node.cpp
浏览文件 @
1ad8f821
...
@@ -25,71 +25,71 @@ namespace paddle_mobile {
...
@@ -25,71 +25,71 @@ namespace paddle_mobile {
namespace
framework
{
namespace
framework
{
Node
&
Node
::
operator
>
(
std
::
shared_ptr
<
Node
>
node
)
{
Node
&
Node
::
operator
>
(
std
::
shared_ptr
<
Node
>
node
)
{
outputs_
.
push_back
(
node
);
outputs_
.
push_back
(
node
);
std
::
shared_ptr
<
Node
>
this_node
;
std
::
shared_ptr
<
Node
>
this_node
;
node
->
inputs_
.
push_back
(
this
);
node
->
inputs_
.
push_back
(
this
);
return
*
node
;
return
*
node
;
}
}
bool
Node
::
operator
==
(
const
Node
&
in
)
{
bool
Node
::
operator
==
(
const
Node
&
in
)
{
if
(
in
.
type_
==
this
->
type_
)
{
if
(
in
.
type_
==
this
->
type_
)
{
if
(
this
->
outputs_
.
size
()
==
in
.
outputs_
.
size
())
{
if
(
this
->
outputs_
.
size
()
==
in
.
outputs_
.
size
())
{
for
(
int
i
=
0
;
i
<
outputs_
.
size
();
++
i
)
{
for
(
int
i
=
0
;
i
<
outputs_
.
size
();
++
i
)
{
if
(
!
(
*
outputs_
[
i
]
==
*
in
.
outputs_
[
i
]))
{
if
(
!
(
*
outputs_
[
i
]
==
*
in
.
outputs_
[
i
]))
{
return
false
;
return
false
;
}
}
}
else
{
return
false
;
}
}
}
}
else
{
}
else
{
return
false
;
return
false
;
}
}
return
true
;
}
else
{
return
false
;
}
return
true
;
}
}
std
::
string
Node
::
ToString
(
std
::
string
blank
,
const
Node
*
node
)
const
{
std
::
string
Node
::
ToString
(
std
::
string
blank
,
const
Node
*
node
)
const
{
std
::
stringstream
ss
;
std
::
stringstream
ss
;
ss
<<
type_
<<
"->
\n
"
;
ss
<<
type_
<<
"->
\n
"
;
if
(
inputs_
.
size
()
>
1
&&
node
!=
inputs_
.
back
())
{
if
(
inputs_
.
size
()
>
1
&&
node
!=
inputs_
.
back
())
{
return
ss
.
str
();
}
else
if
(
inputs_
.
size
()
>
1
&&
node
==
inputs_
.
back
())
{
ss
<<
"
\n
"
<<
blank
<<
type_
<<
"
\n
"
;
}
for
(
int
i
=
0
;
i
<
outputs_
.
size
();
++
i
)
{
ss
<<
blank
<<
outputs_
[
i
]
->
ToString
(
blank
+
" "
,
this
)
<<
""
;
}
return
ss
.
str
();
return
ss
.
str
();
}
else
if
(
inputs_
.
size
()
>
1
&&
node
==
inputs_
.
back
())
{
ss
<<
"
\n
"
<<
blank
<<
type_
<<
"
\n
"
;
}
for
(
int
i
=
0
;
i
<
outputs_
.
size
();
++
i
)
{
ss
<<
blank
<<
outputs_
[
i
]
->
ToString
(
blank
+
" "
,
this
)
<<
""
;
}
return
ss
.
str
();
}
}
std
::
string
Node
::
ToString
()
const
{
return
this
->
ToString
(
" "
,
this
);
}
std
::
string
Node
::
ToString
()
const
{
return
this
->
ToString
(
" "
,
this
);
}
Node
&
Node
::
To
(
int
index
)
{
Node
&
Node
::
To
(
int
index
)
{
if
(
index
==
0
)
{
if
(
index
==
0
)
{
this
->
outputs_
.
clear
();
this
->
outputs_
.
clear
();
}
}
for
(
int
j
=
0
;
j
<
this
->
outputs_
.
size
();
++
j
)
{
for
(
int
j
=
0
;
j
<
this
->
outputs_
.
size
();
++
j
)
{
outputs_
[
j
]
->
To
(
index
-
1
);
outputs_
[
j
]
->
To
(
index
-
1
);
}
}
return
*
this
;
return
*
this
;
}
}
uint
Node
::
depth
(
uint
begin
)
{
uint
Node
::
depth
(
uint
begin
)
{
uint
depth
=
0
;
uint
depth
=
0
;
begin
++
;
begin
++
;
for
(
int
i
=
0
;
i
<
outputs_
.
size
();
++
i
)
{
for
(
int
i
=
0
;
i
<
outputs_
.
size
();
++
i
)
{
uint
output_depth
=
outputs_
[
i
]
->
depth
(
begin
);
uint
output_depth
=
outputs_
[
i
]
->
depth
(
begin
);
depth
=
output_depth
>
depth
?
output_depth
:
depth
;
depth
=
output_depth
>
depth
?
output_depth
:
depth
;
}
}
return
begin
>
depth
?
begin
:
depth
;
return
begin
>
depth
?
begin
:
depth
;
}
}
Print
&
operator
<<
(
Print
&
printer
,
const
Node
&
node
)
{
Print
&
operator
<<
(
Print
&
printer
,
const
Node
&
node
)
{
printer
<<
node
.
ToString
();
printer
<<
node
.
ToString
();
return
printer
;
return
printer
;
}
}
}
// namespace framework
}
// namespace framework
...
...
src/framework/program-optimize/node.h
浏览文件 @
1ad8f821
...
@@ -29,22 +29,22 @@ namespace paddle_mobile {
...
@@ -29,22 +29,22 @@ namespace paddle_mobile {
namespace
framework
{
namespace
framework
{
class
Node
:
PaddleMobileObject
{
class
Node
:
PaddleMobileObject
{
public:
public:
Node
(
const
std
::
string
&
type
)
:
type_
(
type
)
{}
Node
(
const
std
::
string
&
type
)
:
type_
(
type
)
{}
Node
(
std
::
shared_ptr
<
OpDesc
>
op_desc
)
Node
(
std
::
shared_ptr
<
OpDesc
>
op_desc
)
:
op_desc_
(
op_desc
),
type_
(
op_desc
->
Type
()){};
:
op_desc_
(
op_desc
),
type_
(
op_desc
->
Type
()){};
Node
&
operator
>
(
std
::
shared_ptr
<
Node
>
node
);
Node
&
operator
>
(
std
::
shared_ptr
<
Node
>
node
);
bool
operator
==
(
const
Node
&
in
);
bool
operator
==
(
const
Node
&
in
);
std
::
string
ToString
()
const
;
std
::
string
ToString
()
const
;
Node
&
To
(
int
index
);
Node
&
To
(
int
index
);
uint
depth
(
uint
begin
=
0
);
uint
depth
(
uint
begin
=
0
);
private:
private:
std
::
shared_ptr
<
OpDesc
>
op_desc_
;
std
::
shared_ptr
<
OpDesc
>
op_desc_
;
std
::
string
ToString
(
std
::
string
blank
,
const
Node
*
node
)
const
;
std
::
string
ToString
(
std
::
string
blank
,
const
Node
*
node
)
const
;
std
::
vector
<
std
::
shared_ptr
<
Node
>>
outputs_
;
std
::
vector
<
std
::
shared_ptr
<
Node
>>
outputs_
;
std
::
vector
<
Node
*>
inputs_
;
std
::
vector
<
Node
*>
inputs_
;
std
::
string
type_
;
std
::
string
type_
;
};
};
Print
&
operator
<<
(
Print
&
printer
,
const
Node
&
node
);
Print
&
operator
<<
(
Print
&
printer
,
const
Node
&
node
);
...
...
src/framework/program-optimize/program_optimize.cpp
浏览文件 @
1ad8f821
...
@@ -26,49 +26,48 @@ std::shared_ptr<ProgramDesc> ProgramOptimize::Optimize() {}
...
@@ -26,49 +26,48 @@ std::shared_ptr<ProgramDesc> ProgramOptimize::Optimize() {}
std
::
shared_ptr
<
ProgramDesc
>
std
::
shared_ptr
<
ProgramDesc
>
ProgramOptimize
::
FushionOptimize
(
std
::
shared_ptr
<
ProgramDesc
>
ori_des
)
{
ProgramOptimize
::
FushionOptimize
(
std
::
shared_ptr
<
ProgramDesc
>
ori_des
)
{
for
(
int
i
=
0
;
i
<
ori_des
->
Blocks
().
size
();
++
i
)
{
for
(
int
i
=
0
;
i
<
ori_des
->
Blocks
().
size
();
++
i
)
{
std
::
unordered_map
<
std
::
string
,
std
::
shared_ptr
<
Node
>>
output_nodes
;
std
::
unordered_map
<
std
::
string
,
std
::
shared_ptr
<
Node
>>
output_nodes
;
std
::
shared_ptr
<
Node
>
begin_node
;
std
::
shared_ptr
<
Node
>
begin_node
;
auto
block
=
ori_des
->
Block
(
i
);
auto
block
=
ori_des
->
Block
(
i
);
// DLOG << " ops size: " << block->Ops().size();
// DLOG << " ops size: " << block->Ops().size();
for
(
int
j
=
0
;
j
<
block
->
Ops
().
size
();
++
j
)
{
for
(
int
j
=
0
;
j
<
block
->
Ops
().
size
();
++
j
)
{
auto
op
=
block
->
Ops
()[
j
];
auto
op
=
block
->
Ops
()[
j
];
auto
op_type
=
op
->
Type
();
auto
op_type
=
op
->
Type
();
// DLOG << "op type: " << op_type << " index: " << j;
// DLOG << "op type: " << op_type << " index: " << j;
if
(
op_input_output_key
.
find
(
op
->
Type
())
==
if
(
op_input_output_key
.
find
(
op
->
Type
())
==
op_input_output_key
.
end
())
{
op_input_output_key
.
end
())
{
return
NULL
;
return
NULL
;
}
}
std
::
shared_ptr
<
Node
>
node
=
std
::
make_shared
<
Node
>
(
op
);
std
::
shared_ptr
<
Node
>
node
=
std
::
make_shared
<
Node
>
(
op
);
if
(
j
==
0
)
{
if
(
j
==
0
)
{
begin_node
=
node
;
begin_node
=
node
;
}
}
auto
input_keys
=
op_input_output_key
.
at
(
op
->
Type
()).
first
;
auto
input_keys
=
op_input_output_key
.
at
(
op
->
Type
()).
first
;
for
(
auto
input_key
:
input_keys
)
{
for
(
auto
input_key
:
input_keys
)
{
auto
op_inputs
=
op
->
Input
(
input_key
);
auto
op_inputs
=
op
->
Input
(
input_key
);
for
(
int
l
=
0
;
l
<
op_inputs
.
size
();
++
l
)
{
for
(
int
l
=
0
;
l
<
op_inputs
.
size
();
++
l
)
{
std
::
string
input_key
=
op_inputs
[
l
];
std
::
string
input_key
=
op_inputs
[
l
];
if
(
output_nodes
.
find
(
input_key
)
!=
output_nodes
.
end
())
{
if
(
output_nodes
.
find
(
input_key
)
!=
output_nodes
.
end
())
{
auto
input_node
=
output_nodes
[
input_key
];
auto
input_node
=
output_nodes
[
input_key
];
*
input_node
>
node
;
*
input_node
>
node
;
}
}
}
}
auto
output_keys
=
op_input_output_key
.
at
(
op_type
).
second
;
for
(
auto
output_key
:
output_keys
)
{
auto
op_outputs
=
op
->
Output
(
output_key
);
for
(
int
k
=
0
;
k
<
op_outputs
.
size
();
++
k
)
{
output_nodes
[
op_outputs
[
k
]]
=
node
;
}
}
}
}
}
DLOG
<<
"node:
\n
"
<<
*
begin_node
;
auto
output_keys
=
op_input_output_key
.
at
(
op_type
).
second
;
for
(
auto
output_key
:
output_keys
)
{
auto
op_outputs
=
op
->
Output
(
output_key
);
for
(
int
k
=
0
;
k
<
op_outputs
.
size
();
++
k
)
{
output_nodes
[
op_outputs
[
k
]]
=
node
;
}
}
}
}
return
ori_des
;
DLOG
<<
"node:
\n
"
<<
*
begin_node
;
}
return
ori_des
;
}
}
}
// namespace framework
}
// namespace framework
}
// namespace paddle_mobile
}
// namespace paddle_mobile
src/framework/program-optimize/program_optimize.h
浏览文件 @
1ad8f821
...
@@ -26,16 +26,16 @@ namespace paddle_mobile {
...
@@ -26,16 +26,16 @@ namespace paddle_mobile {
namespace
framework
{
namespace
framework
{
class
ProgramOptimize
{
class
ProgramOptimize
{
public:
public:
ProgramOptimize
()
{}
ProgramOptimize
()
{}
std
::
shared_ptr
<
ProgramDesc
>
Optimize
();
std
::
shared_ptr
<
ProgramDesc
>
Optimize
();
std
::
shared_ptr
<
ProgramDesc
>
std
::
shared_ptr
<
ProgramDesc
>
FushionOptimize
(
std
::
shared_ptr
<
ProgramDesc
>
ori_des
);
FushionOptimize
(
std
::
shared_ptr
<
ProgramDesc
>
ori_des
);
private:
private:
// std::shared_ptr<ProgramDesc> ori_desc_;
// std::shared_ptr<ProgramDesc> ori_desc_;
std
::
vector
<
std
::
unordered_map
<
std
::
string
,
std
::
shared_ptr
<
Node
>>>
std
::
vector
<
std
::
unordered_map
<
std
::
string
,
std
::
shared_ptr
<
Node
>>>
outputs_nodes_
;
outputs_nodes_
;
};
};
}
// namespace framework
}
// namespace framework
}
// namespace paddle_mobile
}
// namespace paddle_mobile
src/framework/program.h
浏览文件 @
1ad8f821
...
@@ -28,12 +28,12 @@ namespace framework {
...
@@ -28,12 +28,12 @@ namespace framework {
template
<
typename
Dtype
,
Precision
P
=
Precision
::
FP32
>
template
<
typename
Dtype
,
Precision
P
=
Precision
::
FP32
>
class
Program
:
PaddleMobileObject
{
class
Program
:
PaddleMobileObject
{
public:
public:
std
::
shared_ptr
<
ProgramDesc
>
originProgram
;
std
::
shared_ptr
<
ProgramDesc
>
originProgram
;
std
::
shared_ptr
<
ProgramDesc
>
optimizeProgram
;
std
::
shared_ptr
<
ProgramDesc
>
optimizeProgram
;
std
::
shared_ptr
<
Scope
>
scope
;
std
::
shared_ptr
<
Scope
>
scope
;
private:
private:
};
};
}
// namespace framework
}
// namespace framework
...
...
src/framework/program_desc.cpp
浏览文件 @
1ad8f821
...
@@ -8,14 +8,14 @@ namespace paddle_mobile {
...
@@ -8,14 +8,14 @@ namespace paddle_mobile {
namespace
framework
{
namespace
framework
{
ProgramDesc
::
ProgramDesc
(
const
proto
::
ProgramDesc
&
desc
)
:
desc_
(
desc
)
{
ProgramDesc
::
ProgramDesc
(
const
proto
::
ProgramDesc
&
desc
)
:
desc_
(
desc
)
{
for
(
auto
&
block_desc
:
*
desc_
.
mutable_blocks
())
{
for
(
auto
&
block_desc
:
*
desc_
.
mutable_blocks
())
{
// new framework::BlockDesc(block_desc)
// new framework::BlockDesc(block_desc)
blocks_
.
emplace_back
(
std
::
make_shared
<
BlockDesc
>
(
block_desc
));
blocks_
.
emplace_back
(
std
::
make_shared
<
BlockDesc
>
(
block_desc
));
}
}
}
}
std
::
shared_ptr
<
BlockDesc
>
ProgramDesc
::
Block
(
size_t
idx
)
{
std
::
shared_ptr
<
BlockDesc
>
ProgramDesc
::
Block
(
size_t
idx
)
{
return
blocks_
[
idx
];
return
blocks_
[
idx
];
}
}
}
// namespace framework
}
// namespace framework
...
...
src/framework/program_desc.h
浏览文件 @
1ad8f821
...
@@ -28,14 +28,14 @@ namespace paddle_mobile {
...
@@ -28,14 +28,14 @@ namespace paddle_mobile {
namespace
framework
{
namespace
framework
{
class
ProgramDesc
:
PaddleMobileObject
{
class
ProgramDesc
:
PaddleMobileObject
{
public:
public:
ProgramDesc
(
const
proto
::
ProgramDesc
&
desc
);
ProgramDesc
(
const
proto
::
ProgramDesc
&
desc
);
std
::
shared_ptr
<
BlockDesc
>
Block
(
size_t
idx
);
std
::
shared_ptr
<
BlockDesc
>
Block
(
size_t
idx
);
const
std
::
vector
<
std
::
shared_ptr
<
BlockDesc
>>
&
Blocks
()
{
return
blocks_
;
};
const
std
::
vector
<
std
::
shared_ptr
<
BlockDesc
>>
&
Blocks
()
{
return
blocks_
;
};
private:
private:
std
::
vector
<
std
::
shared_ptr
<
BlockDesc
>>
blocks_
;
std
::
vector
<
std
::
shared_ptr
<
BlockDesc
>>
blocks_
;
proto
::
ProgramDesc
desc_
;
proto
::
ProgramDesc
desc_
;
};
};
}
// namespace framework
}
// namespace framework
...
...
src/framework/scope.cc
浏览文件 @
1ad8f821
...
@@ -7,20 +7,20 @@ namespace paddle_mobile {
...
@@ -7,20 +7,20 @@ namespace paddle_mobile {
namespace
framework
{
namespace
framework
{
Scope
&
Scope
::
NewScope
()
const
{
Scope
&
Scope
::
NewScope
()
const
{
std
::
unique_lock
<
std
::
mutex
>
lock
(
mutex_
);
std
::
unique_lock
<
std
::
mutex
>
lock
(
mutex_
);
kids_
.
push_back
(
new
Scope
(
this
));
kids_
.
push_back
(
new
Scope
(
this
));
return
*
kids_
.
back
();
return
*
kids_
.
back
();
}
}
Variable
*
Scope
::
Var
(
const
std
::
string
&
name
)
{
Variable
*
Scope
::
Var
(
const
std
::
string
&
name
)
{
auto
*
pvar
=
FindVarLocally
(
name
);
auto
*
pvar
=
FindVarLocally
(
name
);
if
(
pvar
!=
nullptr
)
{
if
(
pvar
!=
nullptr
)
{
return
pvar
;
};
pvar
=
new
Variable
;
vars_
[
name
]
=
pvar
;
pvar
->
name_
=
&
(
vars_
.
find
(
name
)
->
first
);
return
pvar
;
return
pvar
;
};
pvar
=
new
Variable
;
vars_
[
name
]
=
pvar
;
pvar
->
name_
=
&
(
vars_
.
find
(
name
)
->
first
);
return
pvar
;
}
}
// Variable* Scope::Var(std::string* name) {
// Variable* Scope::Var(std::string* name) {
...
@@ -33,70 +33,70 @@ Variable *Scope::Var(const std::string &name) {
...
@@ -33,70 +33,70 @@ Variable *Scope::Var(const std::string &name) {
// }
// }
Variable
*
Scope
::
FindVar
(
const
std
::
string
&
name
)
const
{
Variable
*
Scope
::
FindVar
(
const
std
::
string
&
name
)
const
{
auto
*
pvar
=
FindVarLocally
(
name
);
auto
*
pvar
=
FindVarLocally
(
name
);
if
(
pvar
!=
nullptr
)
{
if
(
pvar
!=
nullptr
)
{
return
pvar
;
return
pvar
;
}
}
return
(
parent_
==
nullptr
)
?
nullptr
:
parent_
->
FindVar
(
name
);
return
(
parent_
==
nullptr
)
?
nullptr
:
parent_
->
FindVar
(
name
);
}
}
const
Scope
*
Scope
::
FindScope
(
const
Variable
*
var
)
const
{
const
Scope
*
Scope
::
FindScope
(
const
Variable
*
var
)
const
{
for
(
auto
&
name_var
:
vars_
)
{
for
(
auto
&
name_var
:
vars_
)
{
if
(
name_var
.
second
==
var
)
{
if
(
name_var
.
second
==
var
)
{
return
this
;
return
this
;
}
}
}
return
(
parent_
==
nullptr
)
?
nullptr
:
parent_
->
FindScope
(
var
);
}
return
(
parent_
==
nullptr
)
?
nullptr
:
parent_
->
FindScope
(
var
);
}
}
void
Scope
::
DropKids
()
{
void
Scope
::
DropKids
()
{
for
(
Scope
*
s
:
kids_
)
{
for
(
Scope
*
s
:
kids_
)
{
delete
s
;
delete
s
;
}
}
kids_
.
clear
();
kids_
.
clear
();
}
}
std
::
vector
<
std
::
string
>
Scope
::
LocalVarNames
()
const
{
std
::
vector
<
std
::
string
>
Scope
::
LocalVarNames
()
const
{
std
::
vector
<
std
::
string
>
known_vars
;
std
::
vector
<
std
::
string
>
known_vars
;
known_vars
.
reserve
(
vars_
.
size
());
known_vars
.
reserve
(
vars_
.
size
());
for
(
auto
&
name_var
:
vars_
)
{
for
(
auto
&
name_var
:
vars_
)
{
known_vars
.
emplace_back
(
name_var
.
first
);
known_vars
.
emplace_back
(
name_var
.
first
);
}
}
return
known_vars
;
return
known_vars
;
}
}
void
Scope
::
DeleteScope
(
Scope
*
scope
)
const
{
void
Scope
::
DeleteScope
(
Scope
*
scope
)
const
{
std
::
unique_lock
<
std
::
mutex
>
lock
(
mutex_
);
std
::
unique_lock
<
std
::
mutex
>
lock
(
mutex_
);
auto
it
=
std
::
find
(
kids_
.
begin
(),
kids_
.
end
(),
scope
);
auto
it
=
std
::
find
(
kids_
.
begin
(),
kids_
.
end
(),
scope
);
kids_
.
erase
(
it
);
kids_
.
erase
(
it
);
delete
scope
;
delete
scope
;
// deferent
// deferent
}
}
void
Scope
::
EraseVars
(
const
std
::
vector
<
std
::
string
>
&
var_names
)
{
void
Scope
::
EraseVars
(
const
std
::
vector
<
std
::
string
>
&
var_names
)
{
std
::
set
<
std
::
string
>
var_set
(
var_names
.
begin
(),
var_names
.
end
());
std
::
set
<
std
::
string
>
var_set
(
var_names
.
begin
(),
var_names
.
end
());
for
(
auto
it
=
vars_
.
begin
();
it
!=
vars_
.
end
();)
{
for
(
auto
it
=
vars_
.
begin
();
it
!=
vars_
.
end
();)
{
if
(
var_set
.
find
(
it
->
first
)
!=
var_set
.
end
())
{
if
(
var_set
.
find
(
it
->
first
)
!=
var_set
.
end
())
{
delete
it
->
second
;
delete
it
->
second
;
it
=
vars_
.
erase
(
it
);
it
=
vars_
.
erase
(
it
);
}
else
{
}
else
{
++
it
;
++
it
;
}
}
}
}
}
}
void
Scope
::
Rename
(
const
std
::
string
&
origin_name
,
void
Scope
::
Rename
(
const
std
::
string
&
origin_name
,
const
std
::
string
&
new_name
)
const
{
const
std
::
string
&
new_name
)
const
{
auto
origin_it
=
vars_
.
find
(
origin_name
);
auto
origin_it
=
vars_
.
find
(
origin_name
);
if
(
origin_it
==
vars_
.
end
())
{
if
(
origin_it
==
vars_
.
end
())
{
return
;
return
;
}
}
auto
new_it
=
vars_
.
find
(
new_name
);
auto
new_it
=
vars_
.
find
(
new_name
);
if
(
new_it
!=
vars_
.
end
())
{
if
(
new_it
!=
vars_
.
end
())
{
return
;
return
;
}
}
vars_
[
new_name
]
=
origin_it
->
second
;
vars_
[
new_name
]
=
origin_it
->
second
;
vars_
.
erase
(
origin_it
);
vars_
.
erase
(
origin_it
);
}
}
//
//
// std::string Scope::Rename(const std::string& origin_name)
// std::string Scope::Rename(const std::string& origin_name)
...
@@ -108,11 +108,11 @@ void Scope::Rename(const std::string &origin_name,
...
@@ -108,11 +108,11 @@ void Scope::Rename(const std::string &origin_name,
// }
// }
Variable
*
Scope
::
FindVarLocally
(
const
std
::
string
&
name
)
const
{
Variable
*
Scope
::
FindVarLocally
(
const
std
::
string
&
name
)
const
{
auto
it
=
vars_
.
find
(
name
);
auto
it
=
vars_
.
find
(
name
);
if
(
it
!=
vars_
.
end
())
{
if
(
it
!=
vars_
.
end
())
{
return
it
->
second
;
return
it
->
second
;
}
}
return
nullptr
;
return
nullptr
;
}
}
}
// namespace framework
}
// namespace framework
...
...
src/framework/scope.h
浏览文件 @
1ad8f821
...
@@ -26,56 +26,56 @@ SOFTWARE.
...
@@ -26,56 +26,56 @@ SOFTWARE.
namespace
paddle_mobile
{
namespace
paddle_mobile
{
namespace
framework
{
namespace
framework
{
class
Scope
{
class
Scope
{
public:
public:
Scope
()
{}
Scope
()
{}
~
Scope
()
{}
~
Scope
()
{}
Scope
&
NewScope
()
const
;
Scope
&
NewScope
()
const
;
/// Create a variable with given name if it doesn't exist.
/// Create a variable with given name if it doesn't exist.
Variable
*
Var
(
const
std
::
string
&
name
);
Variable
*
Var
(
const
std
::
string
&
name
);
/// Create a variable with a scope-unique name.
/// Create a variable with a scope-unique name.
Variable
*
Var
(
std
::
string
*
name
=
nullptr
);
Variable
*
Var
(
std
::
string
*
name
=
nullptr
);
void
EraseVars
(
const
std
::
vector
<
std
::
string
>
&
var_names
);
void
EraseVars
(
const
std
::
vector
<
std
::
string
>
&
var_names
);
/// Find a variable in the scope or any of its ancestors. Returns
/// Find a variable in the scope or any of its ancestors. Returns
/// nullptr if cannot find.
/// nullptr if cannot find.
Variable
*
FindVar
(
const
std
::
string
&
name
)
const
;
Variable
*
FindVar
(
const
std
::
string
&
name
)
const
;
const
Scope
*
parent
()
const
{
return
parent_
;
}
const
Scope
*
parent
()
const
{
return
parent_
;
}
/// Find the scope or an ancestor scope that contains the given
/// Find the scope or an ancestor scope that contains the given
/// variable.
/// variable.
const
Scope
*
FindScope
(
const
Variable
*
var
)
const
;
const
Scope
*
FindScope
(
const
Variable
*
var
)
const
;
void
DeleteScope
(
Scope
*
scope
)
const
;
void
DeleteScope
(
Scope
*
scope
)
const
;
/// Drop all kids scopes belonged to this scope.
/// Drop all kids scopes belonged to this scope.
void
DropKids
();
void
DropKids
();
// enumerate all the variables current contains.
// enumerate all the variables current contains.
std
::
vector
<
std
::
string
>
LocalVarNames
()
const
;
std
::
vector
<
std
::
string
>
LocalVarNames
()
const
;
// Rename variable to a new name
// Rename variable to a new name
void
Rename
(
const
std
::
string
&
origin_name
,
void
Rename
(
const
std
::
string
&
origin_name
,
const
std
::
string
&
new_name
)
const
;
const
std
::
string
&
new_name
)
const
;
// Rename variable to a new name and return the new name
// Rename variable to a new name and return the new name
std
::
string
Rename
(
const
std
::
string
&
origin_name
)
const
;
std
::
string
Rename
(
const
std
::
string
&
origin_name
)
const
;
Variable
*
FindVarLocally
(
const
std
::
string
&
name
)
const
;
Variable
*
FindVarLocally
(
const
std
::
string
&
name
)
const
;
private:
private:
// Call Scope::NewScope for a sub-scope.
// Call Scope::NewScope for a sub-scope.
explicit
Scope
(
Scope
const
*
parent
)
:
parent_
(
parent
)
{}
explicit
Scope
(
Scope
const
*
parent
)
:
parent_
(
parent
)
{}
mutable
std
::
unordered_map
<
std
::
string
,
Variable
*>
vars_
;
mutable
std
::
unordered_map
<
std
::
string
,
Variable
*>
vars_
;
mutable
std
::
list
<
Scope
*>
kids_
;
mutable
std
::
list
<
Scope
*>
kids_
;
Scope
const
*
parent_
{
nullptr
};
Scope
const
*
parent_
{
nullptr
};
mutable
std
::
mutex
mutex_
;
mutable
std
::
mutex
mutex_
;
};
};
}
// namespace framework
}
// namespace framework
}
// namespace paddle_mobile
}
// namespace paddle_mobile
src/framework/selected_rows.h
浏览文件 @
1ad8f821
...
@@ -27,54 +27,54 @@ namespace paddle_mobile {
...
@@ -27,54 +27,54 @@ namespace paddle_mobile {
namespace
framework
{
namespace
framework
{
class
SelectedRows
{
class
SelectedRows
{
public:
public:
SelectedRows
(
const
std
::
vector
<
int64_t
>
&
rows
,
const
int64_t
&
height
)
SelectedRows
(
const
std
::
vector
<
int64_t
>
&
rows
,
const
int64_t
&
height
)
:
rows_
(
rows
),
height_
(
height
)
{
:
rows_
(
rows
),
height_
(
height
)
{
value_
.
reset
(
new
Tensor
());
value_
.
reset
(
new
Tensor
());
}
}
SelectedRows
()
{
SelectedRows
()
{
height_
=
0
;
height_
=
0
;
value_
.
reset
(
new
Tensor
());
value_
.
reset
(
new
Tensor
());
}
}
const
Tensor
&
value
()
const
{
return
*
value_
;
}
const
Tensor
&
value
()
const
{
return
*
value_
;
}
Tensor
*
mutable_value
()
{
return
value_
.
get
();
}
Tensor
*
mutable_value
()
{
return
value_
.
get
();
}
int64_t
height
()
const
{
return
height_
;
}
int64_t
height
()
const
{
return
height_
;
}
void
set_height
(
int64_t
height
)
{
height_
=
height
;
}
void
set_height
(
int64_t
height
)
{
height_
=
height
;
}
const
std
::
vector
<
int64_t
>
&
rows
()
const
{
return
rows_
;
}
const
std
::
vector
<
int64_t
>
&
rows
()
const
{
return
rows_
;
}
std
::
vector
<
int64_t
>
*
mutable_rows
()
{
return
&
rows_
;
}
std
::
vector
<
int64_t
>
*
mutable_rows
()
{
return
&
rows_
;
}
void
set_rows
(
const
std
::
vector
<
int64_t
>
&
rows
)
{
rows_
=
rows
;
}
void
set_rows
(
const
std
::
vector
<
int64_t
>
&
rows
)
{
rows_
=
rows
;
}
/**
/**
* get the index of id in rows
* get the index of id in rows
*/
*/
int64_t
index
(
int64_t
id
)
const
{
int64_t
index
(
int64_t
id
)
const
{
auto
it
=
std
::
find
(
rows_
.
begin
(),
rows_
.
end
(),
id
);
auto
it
=
std
::
find
(
rows_
.
begin
(),
rows_
.
end
(),
id
);
// PADDLE_ENFORCE(it != rows_.end(), "id should be in rows");
// PADDLE_ENFORCE(it != rows_.end(), "id should be in rows");
return
static_cast
<
int64_t
>
(
std
::
distance
(
rows_
.
begin
(),
it
));
return
static_cast
<
int64_t
>
(
std
::
distance
(
rows_
.
begin
(),
it
));
}
}
DDim
GetCompleteDims
()
const
{
DDim
GetCompleteDims
()
const
{
std
::
vector
<
int64_t
>
dims
=
vectorize
(
value_
->
dims
());
std
::
vector
<
int64_t
>
dims
=
vectorize
(
value_
->
dims
());
dims
[
0
]
=
height_
;
dims
[
0
]
=
height_
;
return
make_ddim
(
dims
);
return
make_ddim
(
dims
);
}
}
private:
private:
// Notice: rows can be duplicate. We can have {0, 4, 7, 0, 5, 7, 9}
// Notice: rows can be duplicate. We can have {0, 4, 7, 0, 5, 7, 9}
// here.
// here.
// SelectedRows are simply concated when adding together. Until a
// SelectedRows are simply concated when adding together. Until a
// SelectedRows add a Tensor, will the duplicate rows be handled.
// SelectedRows add a Tensor, will the duplicate rows be handled.
std
::
vector
<
int64_t
>
rows_
;
std
::
vector
<
int64_t
>
rows_
;
std
::
unique_ptr
<
Tensor
>
value_
{
nullptr
};
std
::
unique_ptr
<
Tensor
>
value_
{
nullptr
};
int64_t
height_
;
int64_t
height_
;
};
};
}
// namespace framework
}
// namespace framework
...
...
src/framework/tensor.h
浏览文件 @
1ad8f821
...
@@ -29,305 +29,304 @@ namespace framework {
...
@@ -29,305 +29,304 @@ namespace framework {
template
<
typename
...
T
>
struct
SizeOfTypeFunctor
;
template
<
typename
...
T
>
struct
SizeOfTypeFunctor
;
template
<
typename
T
>
struct
SizeOfTypeFunctor
<
T
>
{
template
<
typename
T
>
struct
SizeOfTypeFunctor
<
T
>
{
size_t
operator
()(
std
::
type_index
type
)
const
{
size_t
operator
()(
std
::
type_index
type
)
const
{
if
(
typeid
(
T
).
hash_code
()
==
type
.
hash_code
())
{
if
(
typeid
(
T
).
hash_code
()
==
type
.
hash_code
())
{
return
sizeof
(
T
);
return
sizeof
(
T
);
}
else
{
}
else
{
return
0UL
;
return
0UL
;
}
}
}
}
};
};
template
<
>
struct
SizeOfTypeFunctor
<>
{
template
<
>
struct
SizeOfTypeFunctor
<>
{
size_t
operator
()(
std
::
type_index
type
)
const
{
return
0UL
;
}
size_t
operator
()(
std
::
type_index
type
)
const
{
return
0UL
;
}
};
};
template
<
typename
HEAD
,
typename
...
TAIL
>
template
<
typename
HEAD
,
typename
...
TAIL
>
struct
SizeOfTypeFunctor
<
HEAD
,
TAIL
...
>
{
struct
SizeOfTypeFunctor
<
HEAD
,
TAIL
...
>
{
size_t
operator
()(
std
::
type_index
type
)
const
{
size_t
operator
()(
std
::
type_index
type
)
const
{
SizeOfTypeFunctor
<
HEAD
>
head
;
SizeOfTypeFunctor
<
HEAD
>
head
;
size_t
head_size
=
head
(
type
);
size_t
head_size
=
head
(
type
);
if
(
head_size
!=
0
)
{
if
(
head_size
!=
0
)
{
return
head_size
;
return
head_size
;
}
SizeOfTypeFunctor
<
TAIL
...
>
tail
;
return
tail
(
type
);
}
}
SizeOfTypeFunctor
<
TAIL
...
>
tail
;
return
tail
(
type
);
}
};
};
static
inline
size_t
SizeOfType
(
std
::
type_index
type
)
{
static
inline
size_t
SizeOfType
(
std
::
type_index
type
)
{
SizeOfTypeFunctor
<
int
,
float
,
double
,
int16_t
,
int64_t
,
bool
,
size_t
>
SizeOfTypeFunctor
<
int
,
float
,
double
,
int16_t
,
int64_t
,
bool
,
size_t
>
functor
;
functor
;
size_t
size
=
functor
(
type
);
size_t
size
=
functor
(
type
);
// PADDLE_ENFORCE(size != 0UL, "Cannot get size of type %s",
// PADDLE_ENFORCE(size != 0UL, "Cannot get size of type %s",
// type.name());
// type.name());
return
size
;
return
size
;
}
}
class
LoDTensor
;
class
LoDTensor
;
class
Tensor
{
class
Tensor
{
public:
public:
Tensor
()
:
offset_
(
0
)
{}
Tensor
()
:
offset_
(
0
)
{}
/*! Return a pointer to mutable memory block. */
/*! Return a pointer to mutable memory block. */
template
<
typename
T
>
inline
T
*
data
()
{
template
<
typename
T
>
inline
T
*
data
()
{
check_memory_size
();
check_memory_size
();
// PADDLE_ENFORCE(std::is_same<T, void>::value ||
// PADDLE_ENFORCE(std::is_same<T, void>::value ||
// holder_->type().hash_code() ==
// holder_->type().hash_code() ==
// typeid(T).hash_code(),
// typeid(T).hash_code(),
// "Tensor holds the wrong type, it holds %s",
// "Tensor holds the wrong type, it holds %s",
// this->holder_->type().name());
// this->holder_->type().name());
return
reinterpret_cast
<
T
*>
(
return
reinterpret_cast
<
T
*>
(
reinterpret_cast
<
uintptr_t
>
(
holder_
->
ptr
())
+
reinterpret_cast
<
uintptr_t
>
(
holder_
->
ptr
())
+
offset_
);
offset_
);
}
/*! Return a pointer to constant memory block. */
template
<
typename
T
>
inline
const
T
*
data
()
const
{
check_memory_size
();
// PADDLE_ENFORCE(std::is_same<T, void>::value ||
// holder_->type().hash_code() ==
// typeid(T).hash_code(),
// "Tensor holds the wrong type, it holds %s",
// this->holder_->type().name());
return
reinterpret_cast
<
const
T
*>
(
reinterpret_cast
<
uintptr_t
>
(
holder_
->
ptr
())
+
offset_
);
}
inline
bool
IsInitialized
()
const
{
return
holder_
!=
nullptr
;
}
/**
* @brief Return a pointer to mutable memory block.
* @note If not exist, then allocation.
*/
template
<
typename
T
>
inline
T
*
mutable_data
()
{
static_assert
(
std
::
is_pod
<
T
>::
value
,
"T must be POD"
);
return
reinterpret_cast
<
T
*>
(
mutable_data
(
typeid
(
T
)));
}
inline
void
*
mutable_data
(
std
::
type_index
type
)
{
if
(
holder_
!=
nullptr
)
{
holder_
->
set_type
(
type
);
}
}
// PADDLE_ENFORCE_GE(numel(), 0,
/
*! Return a pointer to constant memory block. */
/
/ "When calling this method, the Tensor's
template
<
typename
T
>
inline
const
T
*
data
()
const
{
// numel must be
check_memory_size
();
// " "equal or larger than zero. " "Please
// PADDLE_ENFORCE(std::is_same<T, void>::value ||
// check
// holder_->type().hash_code() ==
// Tensor::Resize has been called first.");
// typeid(T).hash_code(),
int64_t
size
=
numel
()
*
SizeOfType
(
type
);
// "Tensor holds the wrong type, it holds %s",
/* some versions of boost::variant don't have operator!= */
// this->holder_->type().name());
if
(
holder_
==
nullptr
||
holder_
->
size
()
<
size
+
offset_
)
{
holder_
.
reset
(
new
PlaceholderImpl
(
size
,
type
));
return
reinterpret_cast
<
const
T
*>
(
reinterpret_cast
<
uintptr_t
>
(
holder_
->
ptr
())
+
offset_
)
;
offset_
=
0
;
}
}
return
reinterpret_cast
<
void
*>
(
inline
bool
IsInitialized
()
const
{
return
holder_
!=
nullptr
;
}
reinterpret_cast
<
uintptr_t
>
(
holder_
->
ptr
())
+
offset_
);
}
/**
* @brief Return a pointer to mutable memory block.
inline
void
*
mutable_data
()
{
* @note If not exist, then allocation.
// PADDLE_ENFORCE(this->holder_ != nullptr,
*/
// "Cannot invoke mutable data if current hold
template
<
typename
T
>
inline
T
*
mutable_data
()
{
// nothing.");
static_assert
(
std
::
is_pod
<
T
>::
value
,
"T must be POD"
);
return
mutable_data
(
holder_
->
type
());
return
reinterpret_cast
<
T
*>
(
mutable_data
(
typeid
(
T
)));
}
/**
* @brief Return a pointer to mutable memory block.
*
* @param[in] dims The dimensions of the memory block.
* @param[in] place The place of the memory block.
*
* @note If not exist, then allocation.
*/
template
<
typename
T
>
inline
T
*
mutable_data
(
DDim
dims
)
{
static_assert
(
std
::
is_pod
<
T
>::
value
,
"T must be POD"
);
Resize
(
dims
);
return
mutable_data
<
T
>
();
}
/*! Return the dimensions of the memory block. */
inline
const
DDim
&
dims
()
const
{
return
dims_
;
}
/*! Return the numel of the memory block. */
inline
int64_t
numel
()
const
{
return
product
(
dims_
);
}
/*! Resize the dimensions of the memory block. */
inline
Tensor
&
Resize
(
const
DDim
&
dims
)
{
dims_
=
dims
;
return
*
this
;
}
/*! The internal of two tensors share the same memory block. */
inline
Tensor
&
ShareDataWith
(
const
Tensor
&
src
)
{
src
.
check_memory_size
();
*
this
=
src
;
return
*
this
;
}
/**
* @brief Return a sub-tensor of the given tensor.
*
* @param[in] begin_idx The index of the start row(inclusive) to
* slice.
* The index number begins from 0.
* @param[in] end_idx The index of the end row(exclusive) to
* slice.
* The index number begins from 0.
*/
inline
Tensor
Slice
(
int
begin_idx
,
int
end_idx
)
const
{
check_memory_size
();
// PADDLE_ENFORCE_GE(begin_idx, 0,
// "The start row index must be greater than
// 0.");
// PADDLE_ENFORCE_LE(end_idx, dims_[0], "The end row index is
// out of
// bound."); PADDLE_ENFORCE_LT(
// begin_idx, end_idx,
// "The start row index must be lesser than the end row
// index.");
if
(
dims_
[
0
]
==
1
)
{
return
*
this
;
}
else
{
size_t
base
=
numel
()
/
dims_
[
0
];
Tensor
dst
;
dst
.
holder_
=
holder_
;
dst
.
set_layout
(
layout_
);
DDim
dst_dims
=
dims_
;
dst_dims
[
0
]
=
end_idx
-
begin_idx
;
dst
.
Resize
(
dst_dims
);
dst
.
offset_
=
offset_
+
begin_idx
*
base
*
SizeOfType
(
type
());
return
dst
;
}
}
}
inline
void
*
mutable_data
(
std
::
type_index
type
)
{
if
(
holder_
!=
nullptr
)
{
std
::
type_index
type
()
const
{
holder_
->
set_type
(
type
);
// PADDLE_ENFORCE_NOT_NULL(
}
// holder_, "Tensor not initialized yet
// PADDLE_ENFORCE_GE(numel(), 0,
// when
// "When calling this method, the Tensor's
// Tensor::type() is called.");
// numel must be
return
holder_
->
type
();
// " "equal or larger than zero. " "Please
}
// check
// Tensor::Resize has been called first.");
// memory size returns the holding memory size in byte.
int64_t
size
=
numel
()
*
SizeOfType
(
type
);
size_t
memory_size
()
const
{
/* some versions of boost::variant don't have operator!= */
return
holder_
==
nullptr
?
0UL
:
holder_
->
size
()
-
offset_
;
if
(
holder_
==
nullptr
||
holder_
->
size
()
<
size
+
offset_
)
{
}
holder_
.
reset
(
new
PlaceholderImpl
(
size
,
type
));
inline
void
check_memory_size
()
const
{
offset_
=
0
;
// PADDLE_ENFORCE_NOT_NULL(
}
// holder_, "Tensor holds no memory. Call
return
reinterpret_cast
<
void
*>
(
// Tensor::mutable_data
reinterpret_cast
<
uintptr_t
>
(
holder_
->
ptr
())
+
offset_
);
// first.");
// PADDLE_ENFORCE_LE(
// numel() * SizeOfType(type()), memory_size(),
// "Tensor's dims_ is out of bound. Call
// Tensor::mutable_data "
// "first to re-allocate memory.\n"
// "or maybe the required data-type mismatches the data
// already
// stored.");
}
inline
DataLayout
layout
()
const
{
return
layout_
;
}
inline
void
set_layout
(
const
DataLayout
layout
)
{
layout_
=
layout
;
}
private:
/**
* @note Placeholder hides type T, so it doesn't appear as a
* template
* parameter of Variable.
*/
struct
Placeholder
{
virtual
~
Placeholder
()
=
default
;
virtual
void
*
ptr
()
const
=
0
;
virtual
size_t
size
()
const
=
0
;
virtual
std
::
type_index
type
()
const
=
0
;
virtual
void
set_type
(
std
::
type_index
type
)
=
0
;
};
struct
PlaceholderImpl
:
public
Placeholder
{
PlaceholderImpl
(
size_t
size
,
std
::
type_index
type
)
:
ptr_
(
static_cast
<
uint8_t
*>
(
memory
::
Alloc
(
size
)),
memory
::
PODDeleter
<
uint8_t
>
()),
size_
(
size
),
type_
(
type
)
{
// PADDLE_ENFORCE_NOT_NULL(ptr_,
// "Insufficient %s
// memory to allocation.",
// (is_cpu_place(place_)
// ?
// "CPU" :
// "GPU"));
}
}
inline
void
*
mutable_data
()
{
virtual
size_t
size
()
const
{
return
size_
;
}
// PADDLE_ENFORCE(this->holder_ != nullptr,
// "Cannot invoke mutable data if current hold
// nothing.");
return
mutable_data
(
holder_
->
type
());
}
/**
virtual
void
*
ptr
()
const
{
return
static_cast
<
void
*>
(
ptr_
.
get
());
}
* @brief Return a pointer to mutable memory block.
*
* @param[in] dims The dimensions of the memory block.
* @param[in] place The place of the memory block.
*
* @note If not exist, then allocation.
*/
template
<
typename
T
>
inline
T
*
mutable_data
(
DDim
dims
)
{
static_assert
(
std
::
is_pod
<
T
>::
value
,
"T must be POD"
);
Resize
(
dims
);
return
mutable_data
<
T
>
();
}
/*! Return the dimensions of the memory block. */
virtual
std
::
type_index
type
()
const
{
return
type_
;
}
inline
const
DDim
&
dims
()
const
{
return
dims_
;
}
/*! Return the numel of the memory block. */
virtual
void
set_type
(
std
::
type_index
type
)
{
type_
=
type
;
}
inline
int64_t
numel
()
const
{
return
product
(
dims_
);
}
/*! Resize the dimensions of the memory block. */
/*! the pointer of memory block. */
inline
Tensor
&
Resize
(
const
DDim
&
dims
)
{
std
::
unique_ptr
<
uint8_t
,
memory
::
PODDeleter
<
uint8_t
>>
ptr_
;
dims_
=
dims
;
return
*
this
;
}
/*! The internal of two tensors share the same memory block. */
/*! the size of memory block. */
inline
Tensor
&
ShareDataWith
(
const
Tensor
&
src
)
{
size_t
size_
;
src
.
check_memory_size
();
*
this
=
src
;
return
*
this
;
}
/**
/* the current type of memory */
* @brief Return a sub-tensor of the given tensor.
std
::
type_index
type_
;
*
};
* @param[in] begin_idx The index of the start row(inclusive) to
* slice.
* The index number begins from 0.
* @param[in] end_idx The index of the end row(exclusive) to
* slice.
* The index number begins from 0.
*/
inline
Tensor
Slice
(
int
begin_idx
,
int
end_idx
)
const
{
check_memory_size
();
// PADDLE_ENFORCE_GE(begin_idx, 0,
// "The start row index must be greater than
// 0.");
// PADDLE_ENFORCE_LE(end_idx, dims_[0], "The end row index is
// out of
// bound."); PADDLE_ENFORCE_LT(
// begin_idx, end_idx,
// "The start row index must be lesser than the end row
// index.");
if
(
dims_
[
0
]
==
1
)
{
return
*
this
;
}
else
{
size_t
base
=
numel
()
/
dims_
[
0
];
Tensor
dst
;
dst
.
holder_
=
holder_
;
dst
.
set_layout
(
layout_
);
DDim
dst_dims
=
dims_
;
dst_dims
[
0
]
=
end_idx
-
begin_idx
;
dst
.
Resize
(
dst_dims
);
dst
.
offset_
=
offset_
+
begin_idx
*
base
*
SizeOfType
(
type
());
return
dst
;
}
}
std
::
type_index
type
()
const
{
/*! holds the memory block if allocated. */
// PADDLE_ENFORCE_NOT_NULL(
std
::
shared_ptr
<
Placeholder
>
holder_
;
// holder_, "Tensor not initialized yet
// when
// Tensor::type() is called.");
return
holder_
->
type
();
}
// memory size returns the holding memory size in byte.
/**
size_t
memory_size
()
const
{
* @brief points to elements dimensions.
return
holder_
==
nullptr
?
0UL
:
holder_
->
size
()
-
offset_
;
*
}
* @note dims_ do not indicate the memory block size.
*/
inline
void
check_memory_size
()
const
{
DDim
dims_
;
// PADDLE_ENFORCE_NOT_NULL(
// holder_, "Tensor holds no memory. Call
/**
// Tensor::mutable_data
* @brief the layout of memory block, default is NHWC.
// first.");
*
// PADDLE_ENFORCE_LE(
* @note the memory allocation order, describe how weight/data is
// numel() * SizeOfType(type()), memory_size(),
* stored
// "Tensor's dims_ is out of bound. Call
* For example, in 4-D Tensor(rank=4), there are three
// Tensor::mutable_data "
* commonly
// "first to re-allocate memory.\n"
* used layout. They are
// "or maybe the required data-type mismatches the data
* NCHW, NHWC, CHWN.
// already
* N,C,H,W for respectively the batch size, the number of
// stored.");
* feature maps, the height, the width.
}
*/
DataLayout
layout_
=
DataLayout
::
kNHWC
;
inline
DataLayout
layout
()
const
{
return
layout_
;
}
/**
* @brief A PlaceHolder may be shared by more than one tensor.
inline
void
set_layout
(
const
DataLayout
layout
)
{
layout_
=
layout
;
}
*
* @note Some of them may be slices of the others. So the offset_
private:
* is introduced here to indicate the byte offset between
/**
* PlaceHolder::ptr_ and where the tensor data really
* @note Placeholder hides type T, so it doesn't appear as a
* begins.
* template
*/
* parameter of Variable.
size_t
offset_
;
*/
struct
Placeholder
{
virtual
~
Placeholder
()
=
default
;
virtual
void
*
ptr
()
const
=
0
;
virtual
size_t
size
()
const
=
0
;
virtual
std
::
type_index
type
()
const
=
0
;
virtual
void
set_type
(
std
::
type_index
type
)
=
0
;
};
struct
PlaceholderImpl
:
public
Placeholder
{
PlaceholderImpl
(
size_t
size
,
std
::
type_index
type
)
:
ptr_
(
static_cast
<
uint8_t
*>
(
memory
::
Alloc
(
size
)),
memory
::
PODDeleter
<
uint8_t
>
()),
size_
(
size
),
type_
(
type
)
{
// PADDLE_ENFORCE_NOT_NULL(ptr_,
// "Insufficient %s
// memory to allocation.",
// (is_cpu_place(place_)
// ?
// "CPU" :
// "GPU"));
}
virtual
size_t
size
()
const
{
return
size_
;
}
virtual
void
*
ptr
()
const
{
return
static_cast
<
void
*>
(
ptr_
.
get
());
}
virtual
std
::
type_index
type
()
const
{
return
type_
;
}
virtual
void
set_type
(
std
::
type_index
type
)
{
type_
=
type
;
}
/*! the pointer of memory block. */
std
::
unique_ptr
<
uint8_t
,
memory
::
PODDeleter
<
uint8_t
>>
ptr_
;
/*! the size of memory block. */
size_t
size_
;
/* the current type of memory */
std
::
type_index
type_
;
};
/*! holds the memory block if allocated. */
std
::
shared_ptr
<
Placeholder
>
holder_
;
/**
* @brief points to elements dimensions.
*
* @note dims_ do not indicate the memory block size.
*/
DDim
dims_
;
/**
* @brief the layout of memory block, default is NHWC.
*
* @note the memory allocation order, describe how weight/data is
* stored
* For example, in 4-D Tensor(rank=4), there are three
* commonly
* used layout. They are
* NCHW, NHWC, CHWN.
* N,C,H,W for respectively the batch size, the number of
* feature maps, the height, the width.
*/
DataLayout
layout_
=
DataLayout
::
kNHWC
;
/**
* @brief A PlaceHolder may be shared by more than one tensor.
*
* @note Some of them may be slices of the others. So the offset_
* is introduced here to indicate the byte offset between
* PlaceHolder::ptr_ and where the tensor data really
* begins.
*/
size_t
offset_
;
};
};
inline
Tensor
ReshapeToMatrix
(
const
Tensor
&
src
,
int
num_col_dims
)
{
inline
Tensor
ReshapeToMatrix
(
const
Tensor
&
src
,
int
num_col_dims
)
{
Tensor
res
;
Tensor
res
;
res
.
ShareDataWith
(
src
);
res
.
ShareDataWith
(
src
);
res
.
Resize
(
flatten_to_2d
(
src
.
dims
(),
num_col_dims
));
res
.
Resize
(
flatten_to_2d
(
src
.
dims
(),
num_col_dims
));
return
res
;
return
res
;
}
}
}
// namespace framework
}
// namespace framework
...
...
src/framework/tensor_util.cc
浏览文件 @
1ad8f821
此差异已折叠。
点击以展开。
src/framework/tensor_util.h
浏览文件 @
1ad8f821
...
@@ -43,23 +43,23 @@ void TensorFromStream(std::istream &is, Tensor *tensor);
...
@@ -43,23 +43,23 @@ void TensorFromStream(std::istream &is, Tensor *tensor);
template
<
typename
T
>
template
<
typename
T
>
void
TensorFromVector
(
const
std
::
vector
<
T
>
&
src
,
Tensor
*
dst
)
{
void
TensorFromVector
(
const
std
::
vector
<
T
>
&
src
,
Tensor
*
dst
)
{
auto
src_ptr
=
static_cast
<
const
void
*>
(
src
.
data
());
auto
src_ptr
=
static_cast
<
const
void
*>
(
src
.
data
());
dst
->
Resize
({
static_cast
<
int64_t
>
(
src
.
size
())});
dst
->
Resize
({
static_cast
<
int64_t
>
(
src
.
size
())});
auto
dst_ptr
=
static_cast
<
void
*>
(
dst
->
mutable_data
<
T
>
());
auto
dst_ptr
=
static_cast
<
void
*>
(
dst
->
mutable_data
<
T
>
());
auto
size
=
src
.
size
()
*
sizeof
(
T
);
auto
size
=
src
.
size
()
*
sizeof
(
T
);
memory
::
Copy
(
dst_ptr
,
src_ptr
,
size
);
memory
::
Copy
(
dst_ptr
,
src_ptr
,
size
);
}
}
template
<
typename
T
>
template
<
typename
T
>
void
TensorToVector
(
const
Tensor
&
src
,
std
::
vector
<
T
>
*
dst
)
{
void
TensorToVector
(
const
Tensor
&
src
,
std
::
vector
<
T
>
*
dst
)
{
auto
src_ptr
=
static_cast
<
const
void
*>
(
src
.
data
<
T
>
());
auto
src_ptr
=
static_cast
<
const
void
*>
(
src
.
data
<
T
>
());
auto
size
=
src
.
numel
()
*
sizeof
(
T
);
auto
size
=
src
.
numel
()
*
sizeof
(
T
);
dst
->
resize
(
src
.
numel
());
dst
->
resize
(
src
.
numel
());
auto
dst_ptr
=
static_cast
<
void
*>
(
dst
->
data
());
auto
dst_ptr
=
static_cast
<
void
*>
(
dst
->
data
());
memory
::
Copy
(
dst_ptr
,
src_ptr
,
size
);
memory
::
Copy
(
dst_ptr
,
src_ptr
,
size
);
}
}
}
// namespace framework
}
// namespace framework
...
...
src/framework/var_desc.h
浏览文件 @
1ad8f821
...
@@ -25,63 +25,63 @@ namespace paddle_mobile {
...
@@ -25,63 +25,63 @@ namespace paddle_mobile {
namespace
framework
{
namespace
framework
{
class
VarDesc
{
class
VarDesc
{
public:
public:
VarDesc
(
const
proto
::
VarDesc
&
desc
);
VarDesc
(
const
proto
::
VarDesc
&
desc
);
std
::
string
Name
()
const
{
return
desc_
.
name
();
}
std
::
string
Name
()
const
{
return
desc_
.
name
();
}
proto
::
VarType
::
Type
GetType
()
const
{
return
desc_
.
type
().
type
();
}
proto
::
VarType
::
Type
GetType
()
const
{
return
desc_
.
type
().
type
();
}
bool
Persistable
()
const
{
return
desc_
.
persistable
();
}
bool
Persistable
()
const
{
return
desc_
.
persistable
();
}
const
proto
::
VarType
::
ChannelDesc
&
channel_desc
()
const
{
const
proto
::
VarType
::
ChannelDesc
&
channel_desc
()
const
{
switch
(
desc_
.
type
().
type
())
{
switch
(
desc_
.
type
().
type
())
{
case
proto
::
VarType
::
CHANNEL
:
case
proto
::
VarType
::
CHANNEL
:
return
desc_
.
type
().
channel
();
return
desc_
.
type
().
channel
();
default:
default:
break
;
break
;
}
}
}
}
const
proto
::
VarType
::
TensorDesc
&
tensor_desc
()
const
{
const
proto
::
VarType
::
TensorDesc
&
tensor_desc
()
const
{
switch
(
desc_
.
type
().
type
())
{
switch
(
desc_
.
type
().
type
())
{
case
proto
::
VarType
::
SELECTED_ROWS
:
case
proto
::
VarType
::
SELECTED_ROWS
:
return
desc_
.
type
().
selected_rows
();
return
desc_
.
type
().
selected_rows
();
case
proto
::
VarType
::
LOD_TENSOR
:
case
proto
::
VarType
::
LOD_TENSOR
:
return
desc_
.
type
().
lod_tensor
().
tensor
();
return
desc_
.
type
().
lod_tensor
().
tensor
();
case
proto
::
VarType
::
LOD_TENSOR_ARRAY
:
case
proto
::
VarType
::
LOD_TENSOR_ARRAY
:
return
desc_
.
type
().
tensor_array
().
tensor
();
return
desc_
.
type
().
tensor_array
().
tensor
();
default:
default:
break
;
break
;
}
}
}
}
proto
::
VarType
::
Type
GetDataType
()
const
{
proto
::
VarType
::
Type
GetDataType
()
const
{
switch
(
desc_
.
type
().
type
())
{
switch
(
desc_
.
type
().
type
())
{
case
proto
::
VarType
::
CHANNEL
:
case
proto
::
VarType
::
CHANNEL
:
return
channel_desc
().
data_type
();
return
channel_desc
().
data_type
();
break
;
break
;
default:
default:
return
tensor_desc
().
data_type
();
return
tensor_desc
().
data_type
();
}
}
}
}
template
<
typename
T
>
template
<
typename
T
>
std
::
vector
<
T
>
RepeatedToVector
(
std
::
vector
<
T
>
RepeatedToVector
(
const
google
::
protobuf
::
RepeatedField
<
T
>
&
repeated_field
)
const
{
const
google
::
protobuf
::
RepeatedField
<
T
>
&
repeated_field
)
const
{
std
::
vector
<
T
>
ret
;
std
::
vector
<
T
>
ret
;
ret
.
reserve
(
repeated_field
.
size
());
ret
.
reserve
(
repeated_field
.
size
());
std
::
copy
(
repeated_field
.
begin
(),
repeated_field
.
end
(),
std
::
copy
(
repeated_field
.
begin
(),
repeated_field
.
end
(),
std
::
back_inserter
(
ret
));
std
::
back_inserter
(
ret
));
return
ret
;
return
ret
;
}
}
std
::
vector
<
int64_t
>
GetShape
()
const
{
std
::
vector
<
int64_t
>
GetShape
()
const
{
return
this
->
RepeatedToVector
(
tensor_desc
().
dims
());
return
this
->
RepeatedToVector
(
tensor_desc
().
dims
());
}
}
private:
private:
proto
::
VarDesc
desc_
;
proto
::
VarDesc
desc_
;
};
};
}
// namespace framework
}
// namespace framework
...
...
src/framework/var_type.h
浏览文件 @
1ad8f821
...
@@ -25,14 +25,14 @@ SOFTWARE.
...
@@ -25,14 +25,14 @@ SOFTWARE.
namespace
paddle_mobile
{
namespace
paddle_mobile
{
namespace
framework
{
namespace
framework
{
inline
proto
::
VarType
::
Type
ToVarType
(
std
::
type_index
type
)
{
inline
proto
::
VarType
::
Type
ToVarType
(
std
::
type_index
type
)
{
if
(
type
.
hash_code
()
==
typeid
(
LoDTensor
).
hash_code
())
{
if
(
type
.
hash_code
()
==
typeid
(
LoDTensor
).
hash_code
())
{
return
proto
::
VarType_Type_LOD_TENSOR
;
return
proto
::
VarType_Type_LOD_TENSOR
;
}
else
if
(
type
.
hash_code
()
==
typeid
(
SelectedRows
).
hash_code
())
{
}
else
if
(
type
.
hash_code
()
==
typeid
(
SelectedRows
).
hash_code
())
{
return
proto
::
VarType_Type_SELECTED_ROWS
;
return
proto
::
VarType_Type_SELECTED_ROWS
;
}
else
{
}
else
{
// PADDLE_THROW("ToVarType:Unsupported type %s",
// PADDLE_THROW("ToVarType:Unsupported type %s",
// type.name());
// type.name());
}
}
}
}
}
// namespace framework
}
// namespace framework
...
...
src/framework/variable.h
浏览文件 @
1ad8f821
此差异已折叠。
点击以展开。
src/io.cpp
浏览文件 @
1ad8f821
此差异已折叠。
点击以展开。
src/io.h
浏览文件 @
1ad8f821
...
@@ -29,11 +29,11 @@ namespace paddle_mobile {
...
@@ -29,11 +29,11 @@ namespace paddle_mobile {
template
<
typename
Dtype
,
Precision
P
=
Precision
::
FP32
>
template
<
typename
Dtype
,
Precision
P
=
Precision
::
FP32
>
class
Loader
:
PaddleMobileObject
{
class
Loader
:
PaddleMobileObject
{
public:
public:
const
framework
::
Program
<
Dtype
,
P
>
Load
(
const
std
::
string
&
dirname
);
const
framework
::
Program
<
Dtype
,
P
>
Load
(
const
std
::
string
&
dirname
);
private:
private:
void
LoadVar
(
framework
::
LoDTensor
*
tensor
,
const
std
::
string
&
file_path
);
void
LoadVar
(
framework
::
LoDTensor
*
tensor
,
const
std
::
string
&
file_path
);
};
};
}
// namespace paddle_mobile
}
// namespace paddle_mobile
src/memory/t_malloc.cc
浏览文件 @
1ad8f821
...
@@ -26,25 +26,25 @@ namespace memory {
...
@@ -26,25 +26,25 @@ namespace memory {
const
int
MALLOC_ALIGN
=
16
;
const
int
MALLOC_ALIGN
=
16
;
void
Copy
(
void
*
dst
,
const
void
*
src
,
size_t
num
)
{
void
Copy
(
void
*
dst
,
const
void
*
src
,
size_t
num
)
{
std
::
memcpy
(
dst
,
src
,
num
);
std
::
memcpy
(
dst
,
src
,
num
);
};
};
void
*
Alloc
(
size_t
size
)
{
void
*
Alloc
(
size_t
size
)
{
size_t
offset
=
sizeof
(
void
*
)
+
MALLOC_ALIGN
-
1
;
size_t
offset
=
sizeof
(
void
*
)
+
MALLOC_ALIGN
-
1
;
char
*
p
=
static_cast
<
char
*>
(
malloc
(
offset
+
size
));
char
*
p
=
static_cast
<
char
*>
(
malloc
(
offset
+
size
));
if
(
!
p
)
{
if
(
!
p
)
{
return
nullptr
;
return
nullptr
;
}
}
void
*
r
=
reinterpret_cast
<
void
*>
(
reinterpret_cast
<
size_t
>
(
p
+
offset
)
&
void
*
r
=
reinterpret_cast
<
void
*>
(
reinterpret_cast
<
size_t
>
(
p
+
offset
)
&
(
~
(
MALLOC_ALIGN
-
1
)));
(
~
(
MALLOC_ALIGN
-
1
)));
static_cast
<
void
**>
(
r
)[
-
1
]
=
p
;
static_cast
<
void
**>
(
r
)[
-
1
]
=
p
;
return
r
;
return
r
;
}
}
void
Free
(
void
*
ptr
)
{
void
Free
(
void
*
ptr
)
{
if
(
ptr
)
{
if
(
ptr
)
{
free
(
static_cast
<
void
**>
(
ptr
)[
-
1
]);
free
(
static_cast
<
void
**>
(
ptr
)[
-
1
]);
}
}
}
}
}
// namespace memory
}
// namespace memory
...
...
src/memory/t_malloc.h
浏览文件 @
1ad8f821
此差异已折叠。
点击以展开。
src/operators/batchnorm_op.cpp
浏览文件 @
1ad8f821
此差异已折叠。
点击以展开。
src/operators/batchnorm_op.h
浏览文件 @
1ad8f821
此差异已折叠。
点击以展开。
src/operators/concat_op.cpp
浏览文件 @
1ad8f821
此差异已折叠。
点击以展开。
src/operators/concat_op.h
浏览文件 @
1ad8f821
此差异已折叠。
点击以展开。
src/operators/conv_op.cpp
浏览文件 @
1ad8f821
此差异已折叠。
点击以展开。
src/operators/conv_op.h
浏览文件 @
1ad8f821
此差异已折叠。
点击以展开。
src/operators/elementwise_add_op.cpp
浏览文件 @
1ad8f821
此差异已折叠。
点击以展开。
src/operators/elementwise_add_op.h
浏览文件 @
1ad8f821
此差异已折叠。
点击以展开。
src/operators/kernel/arm/batchnorm_kernel.cpp
浏览文件 @
1ad8f821
此差异已折叠。
点击以展开。
src/operators/kernel/arm/concat_kernel.cpp
浏览文件 @
1ad8f821
此差异已折叠。
点击以展开。
src/operators/kernel/arm/conv_kernel.cpp
浏览文件 @
1ad8f821
此差异已折叠。
点击以展开。
src/operators/kernel/arm/elementwise_add_kernel.cpp
浏览文件 @
1ad8f821
此差异已折叠。
点击以展开。
src/operators/kernel/arm/lrn_kernel.cpp
浏览文件 @
1ad8f821
此差异已折叠。
点击以展开。
src/operators/kernel/arm/mul_kernel.cpp
浏览文件 @
1ad8f821
此差异已折叠。
点击以展开。
src/operators/kernel/arm/pool_kernel.cpp
浏览文件 @
1ad8f821
此差异已折叠。
点击以展开。
src/operators/kernel/batchnorm_kernel.h
浏览文件 @
1ad8f821
此差异已折叠。
点击以展开。
src/operators/kernel/concat_kernel.h
浏览文件 @
1ad8f821
此差异已折叠。
点击以展开。
src/operators/kernel/conv_kernel.h
浏览文件 @
1ad8f821
此差异已折叠。
点击以展开。
src/operators/kernel/elementwise_add_kernel.h
浏览文件 @
1ad8f821
此差异已折叠。
点击以展开。
src/operators/kernel/lrn_kernel.h
浏览文件 @
1ad8f821
此差异已折叠。
点击以展开。
src/operators/kernel/mul_kernel.h
浏览文件 @
1ad8f821
此差异已折叠。
点击以展开。
src/operators/kernel/pool_kernel.h
浏览文件 @
1ad8f821
此差异已折叠。
点击以展开。
src/operators/lrn_op.cpp
浏览文件 @
1ad8f821
此差异已折叠。
点击以展开。
src/operators/lrn_op.h
浏览文件 @
1ad8f821
此差异已折叠。
点击以展开。
src/operators/math/elementwise_op_function.h
浏览文件 @
1ad8f821
此差异已折叠。
点击以展开。
src/operators/math/im2col.cc
浏览文件 @
1ad8f821
此差异已折叠。
点击以展开。
src/operators/math/im2col.h
浏览文件 @
1ad8f821
此差异已折叠。
点击以展开。
src/operators/math/math_function.cc
浏览文件 @
1ad8f821
此差异已折叠。
点击以展开。
src/operators/math/pool3x3.h
浏览文件 @
1ad8f821
此差异已折叠。
点击以展开。
src/operators/math/pool_2x2.h
浏览文件 @
1ad8f821
此差异已折叠。
点击以展开。
src/operators/math/pooling.cpp
浏览文件 @
1ad8f821
此差异已折叠。
点击以展开。
src/operators/math/pooling.h
浏览文件 @
1ad8f821
此差异已折叠。
点击以展开。
src/operators/math/transform.h
浏览文件 @
1ad8f821
此差异已折叠。
点击以展开。
src/operators/math/vol2col.cc
浏览文件 @
1ad8f821
此差异已折叠。
点击以展开。
src/operators/math/vol2col.h
浏览文件 @
1ad8f821
此差异已折叠。
点击以展开。
src/operators/mul_op.cpp
浏览文件 @
1ad8f821
此差异已折叠。
点击以展开。
src/operators/mul_op.h
浏览文件 @
1ad8f821
此差异已折叠。
点击以展开。
src/operators/op_param.cpp
浏览文件 @
1ad8f821
此差异已折叠。
点击以展开。
src/operators/op_param.h
浏览文件 @
1ad8f821
此差异已折叠。
点击以展开。
src/operators/pool_op.cpp
浏览文件 @
1ad8f821
此差异已折叠。
点击以展开。
src/operators/pool_op.h
浏览文件 @
1ad8f821
此差异已折叠。
点击以展开。
src/platform/data_type.h
浏览文件 @
1ad8f821
此差异已折叠。
点击以展开。
src/platform/macros.h
浏览文件 @
1ad8f821
此差异已折叠。
点击以展开。
test/common/test_log.cpp
浏览文件 @
1ad8f821
此差异已折叠。
点击以展开。
test/framework/executor_for_test.cpp
浏览文件 @
1ad8f821
此差异已折叠。
点击以展开。
test/framework/executor_for_test.h
浏览文件 @
1ad8f821
此差异已折叠。
点击以展开。
test/framework/test_load.cpp
浏览文件 @
1ad8f821
此差异已折叠。
点击以展开。
test/framework/test_optimize.cpp
浏览文件 @
1ad8f821
此差异已折叠。
点击以展开。
test/operators/test_batchnorm_op.cpp
浏览文件 @
1ad8f821
此差异已折叠。
点击以展开。
test/operators/test_concat_op.cpp
浏览文件 @
1ad8f821
此差异已折叠。
点击以展开。
test/operators/test_cov_op.cpp
浏览文件 @
1ad8f821
此差异已折叠。
点击以展开。
test/operators/test_elementwise_add_op.cpp
浏览文件 @
1ad8f821
此差异已折叠。
点击以展开。
test/operators/test_lrn_op.cpp
浏览文件 @
1ad8f821
此差异已折叠。
点击以展开。
test/operators/test_mul_op.cpp
浏览文件 @
1ad8f821
此差异已折叠。
点击以展开。
test/operators/test_pool_op.cpp
浏览文件 @
1ad8f821
此差异已折叠。
点击以展开。
test/test_helper.h
浏览文件 @
1ad8f821
此差异已折叠。
点击以展开。
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录