Commit 47e5978f — add clang-format clang-tidy hook
Authored on May 18, 2018 by 朔-望
Parent commit: ab5a35c2

Showing 76 changed files with 17667 additions and 18947 deletions (+17667 / -18947). Apart from the new hook configuration and hook scripts, the changes are formatting-only; each C++ file below is shown as it reads after clang-format is applied.
Changed files:

.pre-commit-config.yaml  +10  -1
src/common/log.h  +106  -107
src/common/type_define.h  +11  -12
src/common/types.h  +26  -26
src/common/variant.h  +56  -56
src/framework/attribute.h  +85  -87
src/framework/block_desc.cpp  +16  -16
src/framework/block_desc.h  +25  -28
src/framework/data_layout.h  +30  -27
src/framework/data_transform.cpp  +54  -54
src/framework/data_transform.h  +7  -7
src/framework/data_type.h  +18  -18
src/framework/ddim.cc  +310  -312
src/framework/ddim.h  +132  -135
src/framework/dim.h  +350  -368
src/framework/executor.cpp  +56  -59
src/framework/executor.h  +15  -15
src/framework/framework.pb.cpp  +7986  -8584
src/framework/framework.pb.h  +4928  -5426
src/framework/lod_tensor.cc  +275  -284
src/framework/lod_tensor.h  +181  -185
src/framework/op_desc.cpp  +45  -48
src/framework/op_desc.h  +18  -20
src/framework/op_info.h  +69  -70
src/framework/op_kernel_type.h  +31  -38
src/framework/op_proto_maker.h  +4  -4
src/framework/operator.cpp  +16  -16
src/framework/operator.h  +51  -55
src/framework/paddle_mobile_object.h  +9  -9
src/framework/program.cpp  +1  -1
src/framework/program.h  +10  -10
src/framework/program_desc.cpp  +11  -11
src/framework/program_desc.h  +11  -13
src/framework/scope.cc  +98  -98
src/framework/scope.h  +38  -38
src/framework/selected_rows.h  +41  -42
src/framework/tensor.h  +303  -309
src/framework/tensor_util.cc  +183  -185
src/framework/tensor_util.h  +31  -31
src/framework/var_desc.cpp  +3  -3
src/framework/var_desc.h  +52  -53
src/framework/var_type.h  +12  -12
src/framework/variable.h  +66  -66
src/io.cpp  +349  -361
src/io.h  +7  -8
src/memory/t_malloc.cc  +22  -22
src/memory/t_malloc.h  +32  -32
src/operators/conv_op.cpp  +27  -27
src/operators/conv_op.h  +22  -23
src/operators/elementwise_add_op.cpp  +5  -5
src/operators/elementwise_add_op.h  +21  -24
src/operators/kernel/arm/conv_kernel.cpp  +119  -128
src/operators/kernel/arm/elementwise_add_kernel.cpp  +11  -11
src/operators/kernel/arm/mul_kernel.cpp  +24  -26
src/operators/kernel/conv_kernel.h  +6  -7
src/operators/kernel/elementwise_add_kernel.h  +5  -5
src/operators/kernel/fpga/conv_kernel.cpp  +1  -1
src/operators/kernel/mul_kernel.h  +5  -5
src/operators/math/elementwise_op_function.h  +144  -148
src/operators/math/im2col.cc  +237  -270
src/operators/math/im2col.h  +13  -15
src/operators/math/math_function.cc  +85  -100
src/operators/math/math_function.h  +13  -15
src/operators/math/transform.h  +15  -17
src/operators/math/vol2col.cc  +165  -177
src/operators/math/vol2col.h  +11  -15
src/operators/mul_op.cpp  +24  -25
src/operators/mul_op.h  +20  -21
src/operators/op_param.cpp  +19  -19
src/operators/op_param.h  +169  -179
test/common/test_log.cpp  +3  -2
test/elementwise_add_op_test.h  +136  -141
test/mul_op_test.h  +161  -173
test/test_include.h  +0  -1
tools/pre-commit.hooks/.clang-format.hook  +15  -0
tools/pre-commit.hooks/.clang-tidy.hook  +1  -5
.pre-commit-config.yaml
@@ -20,11 +20,20 @@ repos:
     -   id: trailing-whitespace
         files: (src).*\.(md|py|mm|swift|java|c|cc|cxx|cpp|cu|h|hpp|hxx)$
+-   repo: local
+    hooks:
+    -   id: clang-format
+        name: clang-format
+        description: Format files with ClangFormat.
+        entry: bash ./tools/pre-commit.hooks/.clang-format.hook -i
+        language: system
+        files: \.(c|cc|cxx|cpp|h|hpp|hxx)$
 -   repo: local
     hooks:
     -   id: clang-tidy
         name: clang-tidy
-        description: Format files with tidy.
+        description: Check C++ code style using clang-tidy.
         entry: bash ./tools/pre-commit.hooks/.clang-tidy.hook -i
         language: system
         files: (src).*\.(c|cc|cxx|cpp|h|hpp|hxx)$
src/common/log.h
@@ -27,146 +27,145 @@ SOFTWARE.

namespace paddle_mobile {

enum LogLevel {
  kNO_LOG,
  kLOG_ERROR,
  kLOG_WARNING,
  kLOG_INFO,
  kLOG_DEBUG,
  kLOG_DEBUG1,
  kLOG_DEBUG2,
  kLOG_DEBUG3,
  kLOG_DEBUG4
};

// log level
static LogLevel log_level = kLOG_DEBUG4;

static std::vector<std::string> logs{"NO",      "ERROR ",  "WARNING",
                                     "INFO ",   "DEBUG ",  "DEBUG1 ",
                                     "DEBUG2 ", "DEBUG3 ", "DEBUG4 "};

struct ToLog;
struct Print;

struct Print {
  friend struct ToLog;

  template <typename T> Print &operator<<(T const &value) {
    buffer_ << value;
    return *this;
  }

private:
  void print(LogLevel level) {
    buffer_ << std::endl;
    if (level == kLOG_ERROR) {
      std::cerr << buffer_.str();
    } else {
      std::cout << buffer_.str();
    }
  }
  std::ostringstream buffer_;
};

struct ToLog {
  ToLog(LogLevel level = kLOG_DEBUG, const std::string &info = "")
      : level_(level) {
    unsigned blanks =
        (unsigned)(level > kLOG_DEBUG ? (level - kLOG_DEBUG) * 4 : 1);
    printer_ << logs[level] << " " << info << ":" << std::string(blanks, ' ');
  }

  template <typename T> ToLog &operator<<(T const &value) {
    printer_ << value;
    return *this;
  }

  ~ToLog() { printer_.print(level_); }

private:
  LogLevel level_;
  Print printer_;
};

#define LOG(level) \
  if (level > paddle_mobile::log_level) { \
  } else \
    paddle_mobile::ToLog( \
        level, \
        (std::stringstream() \
         << "[file: " \
         << (strrchr(__FILE__, '/') ? (strrchr(__FILE__, '/') + 1) : __FILE__) \
         << "] [line: " << __LINE__ << "] ") \
            .str())

#define DLOG \
  if (paddle_mobile::kLOG_DEBUG > paddle_mobile::log_level) { \
  } else \
    paddle_mobile::ToLog( \
        paddle_mobile::kLOG_DEBUG, \
        (std::stringstream() \
         << "[file: " \
         << (strrchr(__FILE__, '/') ? (strrchr(__FILE__, '/') + 1) : __FILE__) \
         << "] [line: " << __LINE__ << "] ") \
            .str())

}  // namespace paddle_mobile

#define LOGF(level, format, ...) \
  if (level > paddle_mobile::log_level) { \
  } else \
    printf(format, ##__VA_ARGS__)

#define DLOGF(format, ...) \
  if (paddle_mobile::kLOG_DEBUG > paddle_mobile::log_level) { \
  } else \
    printf(format, ##__VA_ARGS__)

#else

namespace paddle_mobile {

enum LogLevel {
  kNO_LOG,
  kLOG_ERROR,
  kLOG_WARNING,
  kLOG_INFO,
  kLOG_DEBUG,
  kLOG_DEBUG1,
  kLOG_DEBUG2,
  kLOG_DEBUG3,
  kLOG_DEBUG4
};

struct ToLog;

struct Print {
  friend struct ToLog;
  template <typename T> Print &operator<<(T const &value) {}

private:
};

struct ToLog {
  ToLog(LogLevel level) {}

  template <typename T> ToLog &operator<<(T const &value) { return *this; }
};

#define LOG(level) \
  if (true) { \
  } else \
    paddle_mobile::ToLog(level)

#define DLOG \
  if (true) { \
  } else \
    paddle_mobile::ToLog(paddle_mobile::kLOG_DEBUG)

#define LOGF(level, format, ...)

#define DLOGF(format, ...)

}  // namespace paddle_mobile

#endif
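A minimal usage sketch of the LOG/DLOG macros defined above; the call sites and the include path are illustrative, not part of this commit.

// Hypothetical call sites; assumes log.h is included with logging enabled.
#include "common/log.h"

void log_example() {
  LOG(paddle_mobile::kLOG_INFO) << "loading model, size: " << 1024;
  LOG(paddle_mobile::kLOG_ERROR) << "file not found";  // routed to std::cerr
  DLOG << "debug-only message";                        // kLOG_DEBUG level
  DLOGF("formatted debug: %d\n", 42);                  // printf-style variant
}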
src/common/type_define.h
@@ -24,30 +24,29 @@ SOFTWARE.

namespace paddle_mobile {
namespace framework {
template <typename Dtype> class OperatorBase;
class OpDesc;
class BlockDesc;
class InferShapeContext;
}  // namespace framework

using VariableNameMap = std::map<std::string, std::vector<std::string>>;

template <typename Dtype>
using OpCreator = std::function<framework::OperatorBase<Dtype> *(
    const std::string & /*type*/, const VariableNameMap & /*inputs*/,
    const VariableNameMap & /*outputs*/,
    const framework::AttributeMap & /*attrs*/)>;

using GradOpMakerFN =
    std::function<std::vector<std::unique_ptr<framework::OpDesc>>(
        const framework::OpDesc &,
        const std::unordered_set<std::string> & /*no_grad_set*/,
        std::unordered_map<std::string, std::string> * /*grad_to_var*/,
        const std::vector<framework::BlockDesc *> &grad_block)>;

using InferVarTypeFN = std::function<void(const framework::OpDesc & /*op_desc*/,
                                          framework::BlockDesc * /*block*/)>;

using InferShapeFN = std::function<void(framework::InferShapeContext *)>;
};  // namespace paddle_mobile
src/common/types.h
@@ -24,7 +24,7 @@ enum class Precision : int { FP32 = 0 };

//! device type
enum DeviceTypeEnum { kINVALID = -1, kCPU = 0, kFPGA = 1, kGPU_MALI = 2 };

template <DeviceTypeEnum T> struct DeviceType {};

typedef DeviceType<kCPU> CPU;
typedef DeviceType<kFPGA> FPGA;

@@ -32,32 +32,32 @@ typedef DeviceType<kGPU_MALI> GPU_MALI;

//! data type
enum DataType {
  PM_INVALID = -1,
  PM_HALF = 0,
  PM_FLOAT = 1,
  PM_DOUBLE = 2,
  PM_INT8 = 3,
  PM_INT16 = 4,
  PM_INT32 = 5,
  PM_INT64 = 6,
  PM_UINT8 = 7,
  PM_UINT16 = 8,
  PM_UINT32 = 9,
  PM_STRING = 10,
  PM_BOOL = 11,
  PM_SHAPE = 12,
  PM_TENSOR = 13
};
//!
enum PMStatus {
  PMSuccess = 0xFF,        /*!< No errors */
  PMNotInitialized = 0x01, /*!< Data not initialized. */
  PMInvalidValue = 0x02,   /*!< Incorrect variable value. */
  PMMemAllocFailed = 0x03, /*!< Memory allocation error. */
  PMUnKownError = 0x04,    /*!< Unknown error. */
  PMOutOfAuthority = 0x05, /*!< Try to modified data not your own*/
  PMOutOfMem = 0x06,       /*!< OOM error*/
  PMUnImplError = 0x07,    /*!< Unimplement error. */
  PMWrongDevice = 0x08     /*!< un-correct device. */
};
}  // namespace paddle_mobile
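As a small illustration of the status codes above, a caller would typically compare a returned PMStatus against PMSuccess. The function and include path below are hypothetical, only the enum values come from the file.

#include <string>
#include "common/types.h"  // assumed include path

paddle_mobile::PMStatus LoadModel(const std::string &path) {
  if (path.empty()) {
    return paddle_mobile::PMInvalidValue;  // incorrect variable value
  }
  return paddle_mobile::PMSuccess;
}

// if (LoadModel("model.dir") != paddle_mobile::PMSuccess) { /* handle error */ }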
src/common/variant.h
@@ -21,79 +21,79 @@ SOFTWARE.

#pragma once

namespace paddle_mobile {

template <int ID, typename Type> struct IDToType { typedef Type type_t; };

template <typename F, typename... Ts> struct VariantHelper {
  static const size_t size = sizeof(F) > VariantHelper<Ts...>::size
                                 ? sizeof(F)
                                 : VariantHelper<Ts...>::size;

  inline static void Destroy(size_t id, void *data) {
    if (id == typeid(F).hash_code()) {
      reinterpret_cast<F *>(data)->~F();
    } else {
      VariantHelper<Ts...>::Destroy(id, data);
    }
  }
};

template <typename F> struct VariantHelper<F> {
  static const size_t size = sizeof(F);
  inline static void Destroy(size_t id, void *data) {
    if (id == typeid(F).hash_code()) {
      //      reinterpret_cast<F*>(data)->~F();
    } else {
      //      std::cout << "未匹配到 " << std::endl;
    }
  }
};

template <size_t size> class RawData {
public:
  char data[size];
  RawData() {}
  RawData(const RawData &raw_data) { strcpy(data, raw_data.data); }
  //  void operator=(const RawData &raw_data){
  //    strcpy(data, raw_data.data);
  //  }
};

template <typename... Ts> struct Variant {
  Variant(const Variant &variant) {
    //    std::cout << " 赋值构造函数 " << std::endl;
    type_id = variant.type_id;
    data = variant.data;
  }

  Variant() : type_id(invalid_type()) {}
  ~Variant() {
    //    helper::Destroy(type_id, &data);
  }

  template <typename T, typename... Args> void Set(Args &&... args) {
    helper::Destroy(type_id, &data);
    new (&data) T(std::forward<Args>(args)...);
    type_id = typeid(T).hash_code();
  }

  template <typename T> T &Get() const {
    if (type_id == typeid(T).hash_code()) {
      return *const_cast<T *>(reinterpret_cast<const T *>(&data));
    } else {
      //      std::cout << " bad cast in variant " << std::endl;
      throw std::bad_cast();
    }
  }

  size_t TypeId() const { return type_id; }

private:
  static inline size_t invalid_type() { return typeid(void).hash_code(); }
  typedef VariantHelper<Ts...> helper;
  size_t type_id;
  RawData<helper::size> data;
};

template <typename T> struct Vistor { typedef T type_t; };

}  // namespace paddle_mobile
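A short sketch of how the Variant template above is used; the values and the include path are illustrative (Attribute in src/framework/attribute.h wraps Variant in the same way).

#include <string>
#include "common/variant.h"  // assumed include path

void variant_example() {
  paddle_mobile::Variant<int, float, std::string> v;
  v.Set<int>(7);                  // placement-new into the raw buffer
  int i = v.Get<int>();           // OK: stored type id matches typeid(int)
  v.Set<std::string>("conv2d");   // destroys the previous value first
  std::string s = v.Get<std::string>();
  // v.Get<float>() would throw std::bad_cast here: float was never set.
  (void)i; (void)s;
}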
src/framework/attribute.h
@@ -27,104 +27,102 @@ namespace framework {

class BlockDesc;

class Attribute {
public:
  static Attribute GetAttrValue(const proto::OpDesc::Attr &attr_desc) {
    //    std::cout << "begin get attr value" << std::endl;
    Attribute attr;
    switch (attr_desc.type()) {
    case proto::AttrType::BOOLEAN: {
      attr.Set<bool>(attr_desc.b());
      break;
    }
    case proto::AttrType::INT: {
      attr.Set<int>(attr_desc.i());
      break;
    }
    case proto::AttrType::FLOAT: {
      attr.Set<float>(attr_desc.f());
      break;
    }
    case proto::AttrType::STRING: {
      attr.Set<std::string>(attr_desc.s());
      break;
    }
    case proto::AttrType::BOOLEANS: {
      std::vector<bool> val(attr_desc.bools_size());
      for (int i = 0; i < attr_desc.bools_size(); ++i) {
        val[i] = attr_desc.bools(i);
      }
      attr.Set<std::vector<bool>>(val);
      break;
    }
    case proto::AttrType::INTS: {
      std::vector<int> val(attr_desc.ints_size());
      for (int i = 0; i < attr_desc.ints_size(); ++i) {
        val[i] = attr_desc.ints(i);
      }
      attr.Set<std::vector<int>>(val);
      break;
    }
    case proto::AttrType::FLOATS: {
      std::vector<float> val(attr_desc.floats_size());
      for (int i = 0; i < attr_desc.floats_size(); ++i) {
        val[i] = attr_desc.floats(i);
      }
      attr.Set<std::vector<float>>(val);
      break;
    }
    case proto::AttrType::STRINGS: {
      std::vector<std::string> val(attr_desc.strings_size());
      for (int i = 0; i < attr_desc.strings_size(); ++i) {
        val[i] = attr_desc.strings(i);
      }
      attr.Set<std::vector<std::string>>(val);
      break;
    }
    case proto::AttrType::LONG: {
      attr.Set<int64_t>(attr_desc.l());
      break;
    }
    default:
      //      std::cout << " not support " << std::endl;
      break;
    }
    //    std::cout << "end get attr value" << std::endl;
    return attr;
  }

  Attribute() {}
  template <typename T, typename... Args> Attribute &Set(Args &&... args) {
    variant_.Set<T>(args...);
    return *this;
  }

  template <typename T> T &Get() const { return variant_.Get<T>(); }

private:
  Variant<int, float, std::string, std::vector<int>, std::vector<float>,
          std::vector<std::string>, bool, std::vector<bool>, BlockDesc *,
          int64_t>
      variant_;
};

using AttributeMap = std::unordered_map<std::string, Attribute>;

class AttrReader {
public:
  explicit AttrReader(const AttributeMap &attrs) : attrs_(attrs) {}

  template <typename T> inline T Get(const std::string &name) const {
    //    PADDLE_ENFORCE(attrs_.count(name) != 0, "%s should be in
    //    AttributeMap", name);
    return ((Attribute)attrs_.at(name)).Get<T>();
  }

private:
  const AttributeMap &attrs_;
};

}  // namespace framework
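Illustrative use of Attribute and AttrReader from the header above; the attribute names and the include path are hypothetical.

#include <vector>
#include "framework/attribute.h"  // assumed include path

void attr_example() {
  using namespace paddle_mobile::framework;

  AttributeMap attrs;
  attrs["axis"].Set<int>(1);                                  // scalar attribute
  attrs["groups"].Set<std::vector<int>>(std::vector<int>{1, 2});

  AttrReader reader(attrs);
  int axis = reader.Get<int>("axis");                         // 1
  std::vector<int> groups = reader.Get<std::vector<int>>("groups");  // {1, 2}
  (void)axis; (void)groups;
}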
src/framework/block_desc.cpp
@@ -22,28 +22,28 @@ namespace paddle_mobile {

namespace framework {

std::vector<std::shared_ptr<VarDesc>> BlockDesc::Vars() const {
  std::vector<std::shared_ptr<VarDesc>> res;
  for (const auto &p : vars_) {
    res.push_back(p.second);
  }
  return res;
}

std::vector<std::shared_ptr<OpDesc>> BlockDesc::Ops() const {
  std::vector<std::shared_ptr<OpDesc>> res;
  for (const auto &op : ops_) {
    res.push_back(op);
  }
  return res;
}

BlockDesc::BlockDesc(const proto::BlockDesc &desc) : desc_(desc) {
  for (const proto::VarDesc &var_desc : desc_.vars()) {
    vars_[var_desc.name()].reset(new VarDesc(var_desc));
  }
  for (const proto::OpDesc &op_desc : desc_.ops()) {
    ops_.emplace_back(new framework::OpDesc(op_desc));
  }
}

}  // namespace framework
src/framework/block_desc.h
@@ -27,32 +27,29 @@ namespace paddle_mobile {

namespace framework {

class BlockDesc : PaddleMobileObject {
public:
  BlockDesc(const proto::BlockDesc &desc);

  const int &ID() const { return desc_.idx(); }

  const int &Parent() const { return desc_.parent_idx(); }

  bool operator==(const paddle_mobile::framework::BlockDesc &in_block) const {
    return this->ID() == in_block.ID() && this->Parent() == in_block.Parent();
  }

  bool operator<(const paddle_mobile::framework::BlockDesc &in_block) const {
    return this->ID() < in_block.ID() && this->Parent() < in_block.Parent();
  }

  std::vector<std::shared_ptr<VarDesc>> Vars() const;
  std::vector<std::shared_ptr<OpDesc>> Ops() const;

private:
  proto::BlockDesc desc_;
  std::vector<std::shared_ptr<OpDesc>> ops_;
  std::unordered_map<std::string, std::shared_ptr<VarDesc>> vars_;
};

}  // namespace framework

@@ -60,14 +57,14 @@ private:

namespace std {

template <> struct hash<paddle_mobile::framework::BlockDesc> {
  typedef paddle_mobile::framework::BlockDesc argument_type;
  typedef std::size_t result_type;
  result_type operator()(argument_type const &s) const noexcept {
    result_type const h1(std::hash<int>{}(s.ID()));
    result_type const h2(std::hash<int>{}(s.ID()));
    return h1 ^ (h2 << 1);
  }
};

}  // namespace std
src/framework/data_layout.h
@@ -22,42 +22,45 @@ namespace paddle_mobile {

namespace framework {

enum class DataLayout {
  kNHWC = 0,
  kNCHW = 1,
  kAnyLayout = 2,
};

inline DataLayout StringToDataLayout(const std::string &str) {
  std::string s(str);
  for (size_t i = 0; i < s.size(); ++i) {
    s[i] = toupper(s[i]);
  }

  if (s == "NHWC") {
    return DataLayout::kNHWC;
  } else if (s == "NCHW") {
    return DataLayout::kNCHW;
  } else if (s == "ANYLAYOUT") {
    return DataLayout::kAnyLayout;
  } else {
    //    std::cout << "Unknown storage order string: %s", s;
  }
}

inline std::string DataLayoutToString(const DataLayout &data_layout) {
  switch (data_layout) {
  case DataLayout::kNHWC:
    return "NHWC";
  case DataLayout::kNCHW:
    return "NCHW";
  case DataLayout::kAnyLayout:
    return "ANY_LAYOUT";
  default:
    break;
    //    std::cout << "unknown DataLayou %d", data_layout;
  }
}

inline std::ostream &operator<<(std::ostream &out, const DataLayout &l) {
  out << DataLayoutToString(l);
  return out;
}

}  // namespace framework
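For example (call site and include path are illustrative), round-tripping a layout string through the helpers above:

#include <iostream>
#include "framework/data_layout.h"  // assumed include path

void layout_example() {
  using paddle_mobile::framework::DataLayout;
  using paddle_mobile::framework::DataLayoutToString;
  using paddle_mobile::framework::StringToDataLayout;

  DataLayout layout = StringToDataLayout("nchw");   // case-insensitive -> kNCHW
  std::cout << DataLayoutToString(layout) << "\n";  // prints "NCHW"
  std::cout << layout << "\n";                      // operator<< prints "NCHW" too
}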
src/framework/data_transform.cpp
@@ -24,68 +24,68 @@ namespace paddle_mobile {

namespace framework {

static void PassTensorData(Tensor *from, Tensor *to) {
  to->ShareDataWith(*from);
  *from = Tensor();
}

void DataTransform(const OpKernelType &expected_kernel_type,
                   const OpKernelType &kernel_type_for_var,
                   const Tensor &input_tensor, Tensor *output_tensor) {
  bool transformed = false;
  Tensor in;
  in.ShareDataWith(input_tensor);
  Tensor out;

  //  // do layout transform
  //  if (NeedTransformLayout(expected_kernel_type.data_layout_,
  //                          kernel_type_for_var.data_layout_)) {
  //    TransDataLayout(kernel_type_for_var, expected_kernel_type, in, &out);
  //    transformed = true;
  //    PassTensorData(&out, &in);
  //  }
  //
  //  // do data type transform
  //  if (expected_kernel_type.data_type_ != kernel_type_for_var.data_type_) {
  //    TransDataType(kernel_type_for_var, expected_kernel_type, in, &out);
  //    transformed = true;
  //    PassTensorData(&out, &in);
  //  }
  //
  //  // do device transform
  //  if (!platform::is_same_place(kernel_type_for_var.place_,
  //                               expected_kernel_type.place_)) {
  //    TransDataDevice(in, expected_kernel_type.place_, &out);
  //    transformed = true;
  //    PassTensorData(&out, &in);
  //  }
  //
  //  PADDLE_ENFORCE(transformed, "No transform is applied, please check!");

  // get output data
  output_tensor->ShareDataWith(in);
}

void CopyVariableWithTensor(const Variable &in_var, const Tensor &tensor,
                            Variable &out_var) {
  //  if (in_var.IsType<LoDTensor>()) {
  //    auto& in_lod_tensor = in_var.Get<LoDTensor>();
  //    auto* tran_lod_tensor = out_var.GetMutable<LoDTensor>();
  //    tran_lod_tensor->set_lod(in_lod_tensor.lod());
  //    tran_lod_tensor->set_layout(in_lod_tensor.layout());
  //    tran_lod_tensor->ShareDataWith(tensor);
  //  } else if (in_var.IsType<SelectedRows>()) {
  //    auto& in_selected_rows = in_var.Get<SelectedRows>();
  //    auto* trans_selected_rows = out_var.GetMutable<SelectedRows>();
  //    trans_selected_rows->set_height(in_selected_rows.height());
  //    trans_selected_rows->set_rows(in_selected_rows.rows());
  //    trans_selected_rows->mutable_value()->ShareDataWith(tensor);
  //  } else {
  //    PADDLE_THROW("unknown var type");
  //  }
}

}  // namespace framework
src/framework/data_transform.h
@@ -28,14 +28,14 @@ SOFTWARE.

#include "variable.h"

namespace paddle_mobile {
namespace framework {

void DataTransform(const OpKernelType &expected_kernel_type,
                   const OpKernelType &kernel_type_for_var,
                   const Tensor &input_tensor, Tensor *out);

void CopyVariableWithTensor(const Variable &in_var, const Tensor &tensor,
                            Variable &out_var);

}  // namespace framework
}  // namespace paddle_mobile
src/framework/data_type.h
@@ -21,23 +21,23 @@ SOFTWARE.

#include "framework.pb.h"

namespace paddle_mobile {
namespace framework {

//  inline proto::VarType::Type ToDataType(std::type_index type) {
//    using namespace paddle_mobile::framework::proto;
//    if (typeid(float).hash_code() == type.hash_code()) {
//      return proto::VarType::FP32;
//    } else if (typeid(double).hash_code() == type.hash_code()) {
//      return proto::VarType::FP64;
//    } else if (typeid(int).hash_code() == type.hash_code()) {
//      return proto::VarType::INT32;
//    } else if (typeid(int64_t).hash_code() == type.hash_code()) {
//      return proto::VarType::INT64;
//    } else if (typeid(bool).hash_code() == type.hash_code()) {
//      return proto::VarType::BOOL;
//    } else {
//      //  PADDLE_THROW("Not supported");
//    }
//  }
}
}  // namespace paddle_mobile
src/framework/ddim.cc
@@ -15,320 +15,318 @@ limitations under the License. */

#include "ddim.h"

namespace paddle_mobile {
namespace framework {

/// @cond HIDDEN

template <int i> Dim<i> make_dim(const int64_t *d) {
  return Dim<i>(*d, make_dim<i - 1>(d + 1));
}

template <> Dim<0> make_dim<0>(const int64_t *d) { return Dim<0>(*d); }

void make_ddim(DDim &ddim, const int64_t *dims, int n) {
  switch (n) {
  case 0: ddim = make_dim<0>(dims); break;
  case 1: ddim = make_dim<1>(dims); break;
  case 2: ddim = make_dim<2>(dims); break;
  case 3: ddim = make_dim<3>(dims); break;
  case 4: ddim = make_dim<4>(dims); break;
  case 5: ddim = make_dim<5>(dims); break;
  case 6: ddim = make_dim<6>(dims); break;
  case 7: ddim = make_dim<7>(dims); break;
  case 8: ddim = make_dim<8>(dims); break;
  case 9: ddim = make_dim<9>(dims); break;
  default:
    // std::cout << "Dynamic dimensions must have between [1, 9] dimensions.";
    break;
  }
}

/// @endcond

DDim make_ddim(std::initializer_list<int64_t> dims) {
  DDim result(make_dim(0));
  make_ddim(result, dims.begin(), dims.size());
  return result;
}

DDim make_ddim(const std::vector<int64_t> &dims) {
  DDim result(make_dim(0));
  make_ddim(result, &dims[0], dims.size());
  return result;
}

DDim make_ddim(const std::vector<int> &dims) {
  std::vector<int64_t> res(dims.size());
  std::transform(dims.begin(), dims.end(), res.begin(),
                 [](int d) { return static_cast<int64_t>(d); });
  return make_ddim(res);
}

/// @cond HIDDEN
// XXX For some reason, putting this in an anonymous namespace causes errors
struct DynamicMutableIndexer : Vistor<int64_t &> {
public:
  explicit DynamicMutableIndexer(int idx) : idx_(idx) {}

  template <int D> int64_t &operator()(Dim<D> &dim) const { return dim[idx_]; }

private:
  int idx_;
};

struct DynamicConstIndexer : public Vistor<int64_t> {
public:
  explicit DynamicConstIndexer(int idx) : idx_(idx) {}

  template <int D> int64_t operator()(const Dim<D> &dim) const {
    return dim[idx_];
  }

private:
  int idx_;
};

/// @endcond

int64_t &DDim::operator[](int idx) {
  return DDim::ApplyVistor(DynamicMutableIndexer(idx), *this);
}

int64_t DDim::operator[](int idx) const {
  return DDim::ApplyVistor(DynamicConstIndexer(idx), *this);
}

int DDim::size() const { return arity(*this); }

bool DDim::operator==(DDim d) const {
  //  if (var.which() != d.getVar().which()) {
  //    return false;
  //  } else {
  std::vector<int64_t> v1 = vectorize(*this);
  std::vector<int64_t> v2 = vectorize(d);

  for (unsigned int i = 0; i < v1.size(); i++) {
    if (v1[i] != v2[i]) {
      return false;
    }
  }

  return true;
  //  }
}

bool DDim::operator!=(DDim d) const { return !(*this == d); }

DDim DDim::operator+(DDim d) const {
  std::vector<int64_t> v1 = vectorize(*this);
  std::vector<int64_t> v2 = vectorize(d);

  std::vector<int64_t> v3;

  assert(v1.size() == v2.size());

  for (unsigned int i = 0; i < v1.size(); i++) {
    v3.push_back(v1[i] + v2[i]);
  }

  return make_ddim(v3);
}

DDim DDim::operator*(DDim d) const {
  std::vector<int64_t> v1 = vectorize(*this);
  std::vector<int64_t> v2 = vectorize(d);

  std::vector<int64_t> v3;

  assert(v1.size() == v2.size());

  for (unsigned int i = 0; i < v1.size(); i++) {
    v3.push_back(v1[i] * v2[i]);
  }

  return make_ddim(v3);
}

int64_t get(const DDim &ddim, int idx) { return ddim[idx]; }

void set(DDim &ddim, int idx, int value) { ddim[idx] = value; }

/// @cond HIDDEN
struct VectorizeVisitor : Vistor<void> {
  std::vector<int64_t> &vector;

  explicit VectorizeVisitor(std::vector<int64_t> &v) : vector(v) {}

  template <typename T> void operator()(const T &t) {
    vector.push_back(t.head);
    this->operator()(t.tail);
  }

  void operator()(const Dim<0> &t) {}
};
/// @endcond

std::vector<int64_t> vectorize(const DDim &ddim) {
  std::vector<int64_t> result;
  VectorizeVisitor visitor(result);
  DDim::ApplyVistor(visitor, ddim);
  return result;
}

// NOTE: framework::vectorize converts to type int64_t
// which does not fit cudnn inputs.
std::vector<int> vectorize2int(const DDim &ddim) {
  std::vector<int64_t> temp = vectorize(ddim);
  std::vector<int> result(temp.begin(), temp.end());
  return result;
}

struct ProductVisitor : Vistor<int64_t> {
  template <int D> int64_t operator()(const Dim<D> &dim) {
    return product(dim);
  }
};

int64_t product(const DDim &ddim) {
  ProductVisitor visitor;
  return DDim::ApplyVistor(visitor, ddim);
}

struct SliceVectorizeVisitor : Vistor<void> {
  std::vector<int64_t> &vector;
  int begin;
  int end;

  SliceVectorizeVisitor(std::vector<int64_t> &v, int b, int e)
      : vector(v), begin(b), end(e) {
    // PADDLE_ENFORCE(begin < end,
    //                "Begin index must be less than end index in ddim slice.");
    // PADDLE_ENFORCE(begin >= 0,
    //                "Begin index can't be less than zero in ddim slice.");
  }

  template <int S> void operator()(const Dim<S> &dim) {
    if (begin == 0) {
      vector.push_back(dim.head);
    } else {
      --begin;
    }
    --end;
    if (end > 0) {
      this->operator()(dim.tail);
    }
  }

  void operator()(const Dim<0> &dim) {
    // PADDLE_ENFORCE(end == 0, "End index in ddim slice is out of bound.");
  }
};

DDim slice_ddim(const DDim &ddim, int begin, int end) {
  std::vector<int64_t> vec;
  vec.reserve(end - begin);
  SliceVectorizeVisitor visitor(vec, begin, end);
  //  boost::apply_visitor(visitor, dim);
  DDim::ApplyVistor(visitor, ddim);
  //  visitor(ddim.var.Get<Dim<4>>());
  return make_ddim(vec);
}

/// \cond HIDDEN

struct ArityVisitor : Vistor<int> {
  template <int D> int operator()(Dim<D>) const { return D; }
};

/// \endcond

int arity(const DDim &d) {
  ArityVisitor arityVisitor = ArityVisitor();
  return DDim::ApplyVistor(arityVisitor, d);
  //  return arityVisitor(d.var.Get<Dim<4>>());
  //  return boost::apply_visitor(ArityVisitor(), d); }
}

/// \cond HIDDEN

/// \endcond

struct OSVistor : Vistor<std::ostream &> {
  OSVistor(std::ostream &os) : os_(os) {}

  template <int D> std::ostream &operator()(Dim<D> dim) const {
    return os_ << dim;
  }

private:
  std::ostream &os_;
};

std::ostream &operator<<(std::ostream &os, const DDim &ddim) {
  auto vistor = OSVistor(os);
  DDim::ApplyVistor(vistor, ddim);
  return os;
}

DDim::DDim(std::initializer_list<int64_t> init_list) {
  *this = make_ddim(init_list);
}

DDim flatten_to_2d(const DDim &src, int num_col_dims) {
  int rank = src.size();
  return make_ddim({product(slice_ddim(src, 0, num_col_dims)),
                    product(slice_ddim(src, num_col_dims, rank))});
}

DDim flatten_to_1d(const DDim &src) { return make_ddim({product(src)}); }

DDim stride(const DDim &ddim) {
  std::vector<int64_t> strides(ddim.size());
  strides[ddim.size() - 1] = 1;
  for (int i = ddim.size() - 2; i >= 0; --i) {
    strides[i] = strides[i + 1] * ddim[i + 1];
  }
  return framework::make_ddim(strides);
}

DDim stride_numel(const framework::DDim &ddim) {
  std::vector<int64_t> strides(ddim.size());
  strides[ddim.size() - 1] = ddim[ddim.size() - 1];
  for (int i = ddim.size() - 2; i >= 0; --i) {
    strides[i] = strides[i + 1] * ddim[i];
  }
  return framework::make_ddim(strides);
}

}  // namespace framework
}  // namespace paddle_mobile
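A brief sketch of the DDim helpers defined above; the shape values and include path are made up for illustration.

#include "framework/ddim.h"  // assumed include path

void ddim_example() {
  using namespace paddle_mobile::framework;

  DDim shape = make_ddim({1, 3, 224, 224});  // NCHW-style 4-D shape
  int64_t channels = shape[1];               // 3
  int64_t numel = product(shape);            // 1 * 3 * 224 * 224
  DDim matrix = flatten_to_2d(shape, 1);     // {1, 3 * 224 * 224}
  DDim strides = stride(shape);              // {3*224*224, 224*224, 224, 1}
  (void)channels; (void)numel; (void)matrix; (void)strides;
}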

src/framework/ddim.h
@@ -22,145 +22,142 @@ limitations under the License. */

#include <vector>

namespace paddle_mobile {
namespace framework {

/**
 * \brief A dynamically sized dimension.
 *
 * The number of dimensions must be between [1, 9].
 */
struct DDim {
  typedef Variant<Dim<0>, Dim<1>, Dim<2>, Dim<3>, Dim<4>, Dim<5>, Dim<6>,
                  Dim<7>, Dim<8>, Dim<9>>
      DDimVar;
  DDimVar var;

  template <typename Vistor>
  static typename Vistor::type_t ApplyVistor(Vistor vistor, const DDim &d) {
    if (d.var.TypeId() == typeid(Dim<0>).hash_code()) {
      return vistor(d.var.Get<Dim<0>>());
    } else if (d.var.TypeId() == typeid(Dim<1>).hash_code()) {
      return vistor(d.var.Get<Dim<1>>());
    } else if (d.var.TypeId() == typeid(Dim<2>).hash_code()) {
      return vistor(d.var.Get<Dim<2>>());
    } else if (d.var.TypeId() == typeid(Dim<3>).hash_code()) {
      return vistor(d.var.Get<Dim<3>>());
    } else if (d.var.TypeId() == typeid(Dim<4>).hash_code()) {
      return vistor(d.var.Get<Dim<4>>());
    } else if (d.var.TypeId() == typeid(Dim<5>).hash_code()) {
      return vistor(d.var.Get<Dim<5>>());
    } else if (d.var.TypeId() == typeid(Dim<6>).hash_code()) {
      return vistor(d.var.Get<Dim<6>>());
    } else if (d.var.TypeId() == typeid(Dim<7>).hash_code()) {
      return vistor(d.var.Get<Dim<7>>());
    } else if (d.var.TypeId() == typeid(Dim<8>).hash_code()) {
      return vistor(d.var.Get<Dim<8>>());
    } else if (d.var.TypeId() == typeid(Dim<9>).hash_code()) {
      return vistor(d.var.Get<Dim<9>>());
    } else {
      printf(" dim not support \n");
      throw std::bad_exception();
      // return typename Vistor::type_t();
    }
  }

  DDim() { var.Set<Dim<1>>(Dim<1>()); }

  template <int D> explicit DDim(const Dim<D> &in) { var.Set<Dim<D>>(in); }

  /*implicit*/ DDim(std::initializer_list<int64_t> init_list);

  template <int D> DDim &operator=(const Dim<D> &in) {
    var.Set<Dim<D>>(in);
    return *this;
  }

  int64_t &operator[](int idx);
  int64_t operator[](int idx) const;

  // template <typename Visitor>
  // typename Visitor::result_type apply_visitor(Visitor& visitor) {
  //   return var.apply_visitor(visitor);
  // }
  //
  // template <typename Visitor>
  // typename Visitor::result_type apply_visitor(Visitor& visitor) const {
  //   return var.apply_visitor(visitor);
  // }

  DDimVar getVar() { return var; }

  bool operator==(DDim d) const;

  bool operator!=(DDim d) const;

  DDim operator+(DDim d) const;

  DDim operator*(DDim d) const;

  int size() const;
};

/**
 * \brief Make a DDim from std::vector<int64_t>
 *
 * \param dims A vector of ints. Must be sized between [1, 9]
 */
DDim make_ddim(const std::vector<int64_t> &dims);

DDim make_ddim(const std::vector<int> &dims);

/**
 * \brief Make a DDim from an initializer list
 *
 * \param dims An initializer list of ints. Must be sized between [1, 9]
 */
DDim make_ddim(std::initializer_list<int64_t> dims);

int64_t get(const DDim &dim, int idx);
void set(DDim &dim, int idx, int val);

std::vector<int64_t> vectorize(const DDim &ddim);
std::vector<int> vectorize2int(const DDim &ddim);

int64_t product(const DDim &ddim);

/**
 * \brief Slice a ddim
 *
 * Slice dim with [begin, end).
 * e.g.  DDim d = make_ddim({1,2,3,4,5});
 *       slice_ddim(d, 1, 3); ====> {2,3}
 */
DDim slice_ddim(const DDim &dim, int begin, int end);

/**
 * \brief What is the length of this dimension?
 *
 * \param Dynamic dimension to inspect
 */
int arity(const DDim &ddim);

std::ostream &operator<<(std::ostream &, const DDim &);

// Reshape a tensor to a matrix. The matrix's first dimension (column length)
// will be the product of the tensor's first `num_col_dims` dimensions.
DDim flatten_to_2d(const DDim &src, int num_col_dims);

DDim flatten_to_1d(const DDim &src);

DDim stride(const DDim &ddim);

DDim stride_numel(const DDim &ddim);

}  // namespace framework
}  // namespace paddle_mobile
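
A hedged sketch of the Vistor pattern that ApplyVistor dispatches on: it mirrors the ArityVisitor/OSVistor already defined in ddim.cc and assumes the Vistor<T> base (with its type_t typedef) from variant.h.

// Illustrative visitor: returns the static rank of whichever Dim<D> is stored.
struct RankVistor : Vistor<int> {
  template <int D> int operator()(Dim<D>) const { return D; }
};

inline int rank_of(const DDim &d) {
  return DDim::ApplyVistor(RankVistor(), d);  // e.g. 3 for make_ddim({2, 3, 4})
}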

src/framework/dim.h
(diff collapsed)

src/framework/executor.cpp
@@ -23,75 +23,72 @@ SOFTWARE.

#include "variable.h"

namespace paddle_mobile {
namespace framework {

template <typename Dtype>
Executor<Dtype>::Executor(const Program<Dtype> p) : program_(p) {
  if (use_optimize_) {
    to_predict_program_ = program_.optimizeProgram;
  } else {
    to_predict_program_ = program_.originProgram;
  }
  const std::vector<std::shared_ptr<BlockDesc>> blocks =
      to_predict_program_->Blocks();
  for (int i = 0; i < blocks.size(); ++i) {
    std::shared_ptr<BlockDesc> block_desc = blocks[i];
    std::vector<std::shared_ptr<OpDesc>> ops = block_desc->Ops();
    for (int j = 0; j < ops.size(); ++j) {
      std::shared_ptr<OpDesc> op = ops[j];
      if (op->Type() == "conv2d" && op->Input("Input")[0] == "pixel") {
        Attribute strides_attr = op->GetAttrMap().at("strides");
        std::vector<int> stride = strides_attr.Get<std::vector<int>>();
        for (int k = 0; k < stride.size(); ++k) {
        }
        std::shared_ptr<operators::ConvOp<Dtype, float>> conv =
            std::make_shared<operators::ConvOp<Dtype, float>>(
                op->Type(), op->GetInputs(), op->GetOutputs(),
                op->GetAttrMap(), program_.scope);
        ops_of_block_[*block_desc.get()].push_back(conv);
      }
    }
  }
}

template <typename Dtype>
std::shared_ptr<Tensor> Executor<Dtype>::predict(Tensor &t) {
  // feed
  auto scope = program_.scope;
  Variable *g_feed_value = scope->Var("pixel");
  auto tensor = g_feed_value->GetMutable<Tensor>();
  tensor->ShareDataWith(t);

  Variable *con_output = scope->Var("conv2d_0.tmp_0");
  Tensor *output_tensor = con_output->GetMutable<Tensor>();
  output_tensor->mutable_data<float>({1, 16, 32, 32});
  //  std::cout << typeid(output_tensor).name() << std::endl;
  //  std::cout << "output_tensor dims: " << output_tensor->dims() <<
  //  std::endl;

  std::shared_ptr<Tensor> out_tensor = std::make_shared<LoDTensor>();
  out_tensor.reset(output_tensor);

  predict(t, 0);
  return out_tensor;
}

template <typename Dtype>
void Executor<Dtype>::predict(const Tensor &t, int block_id) {
  std::shared_ptr<BlockDesc> to_predict_block =
      to_predict_program_->Block(block_id);
  for (int j = 0; j < ops_of_block_[*to_predict_block.get()].size(); ++j) {
    auto op = ops_of_block_[*to_predict_block.get()][j];
    // std::cout << "开始run" << std::endl;
    op->Run();
  }
}

template class Executor<CPU>;

}  // namespace framework
}  // namespace paddle_mobile

src/framework/executor.h
@@ -32,22 +32,22 @@ SOFTWARE.

#include "variable.h"

namespace paddle_mobile {
namespace framework {

template <typename Dtype> class Executor {
public:
  Executor(const Program<Dtype> p);

  std::shared_ptr<Tensor> predict(Tensor &t);

private:
  const framework::Program<Dtype> program_;
  std::shared_ptr<ProgramDesc> to_predict_program_;

  void predict(const Tensor &t, int block_id);

  std::map<framework::BlockDesc,
           std::vector<std::shared_ptr<OperatorBase<Dtype>>>>
      ops_of_block_;
  bool use_optimize_ = false;
};

}  // namespace framework
}  // namespace paddle_mobile
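
A minimal end-to-end sketch (not part of the commit): load a Program with the Loader declared in src/io.h, hand it to an Executor and push one feed tensor through. The model path and feed shape are placeholders, and CPU is assumed to be the device tag from common/types.h.

void run_once() {
  paddle_mobile::Loader<paddle_mobile::CPU> loader;
  auto program = loader.Load("../models/some_model");  // placeholder path

  paddle_mobile::framework::Executor<paddle_mobile::CPU> executor(program);

  paddle_mobile::framework::Tensor input;
  input.mutable_data<float>({1, 3, 32, 32});  // placeholder feed shape

  std::shared_ptr<paddle_mobile::framework::Tensor> output = executor.predict(input);
}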

src/framework/framework.pb.cpp
(diff collapsed)

src/framework/framework.pb.h
(diff collapsed)

src/framework/lod_tensor.cc
(diff collapsed)

src/framework/lod_tensor.h
@@ -23,190 +23,186 @@ limitations under the License. */

namespace paddle_mobile {
namespace framework {

/*
 * LoD is short for Level of Details.
 *
 * - in a level, each element indicates the relative offset of the lower
 *   level
 * - the first element should be 0, which indicates that this sequence
 *   starts from 0
 * - each sequence's begin and end (non-inclusive) is level[id, id+1]
 *
 * For example:
 *    3-level LoD stores
 *
 *    0 2 3
 *    0 2 4 7
 *    0 2 5 7 10 12 15 20
 */
using LoD = std::vector<std::vector<size_t>>;

std::ostream &operator<<(std::ostream &os, const LoD &lod);
std::ostream &operator<<(std::ostream &os, const LoDTensor &t);

std::string LoDToString(const LoD &lod);

LoD SliceInLevel(const LoD &in, size_t level, size_t elem_begin,
                 size_t elem_end);
/*
 * Transform an LoD from relative offsets to absolute offsets.
 */
LoD ToAbsOffset(const LoD &in);

bool operator==(const LoD &a, const LoD &b);

/*
 * Check whether this lod's format is valid.
 *
 * ATTENTION:
 *   - Empty lod is treated as valid.
 *
 * It will check several things:
 *
 *  1. all the offsets in a level should be ascending (no duplicate items
 *     allowed).
 *  2. there should be more than 2 offsets existing in each level.
 *  3. the higher level's last offset should equal the lower level's
 *     size-1.
 *  4. the first offset (the begin offset) of each level should be 0.
 *  5. the lowest level's last offset should equal `tensor_height` if
 *     tensor_height>0.
 */
bool CheckLoD(const LoD &in, int tensor_height = -1);

/*
 * Check whether this absolute lod's format is valid.
 *
 * ATTENTION:
 *   - Empty lod is treated as valid.
 *
 * It will check several things:
 *  1. all the offsets in a level should be ascending (no duplicate items
 *     allowed)
 *  2. there should be more than 2 offsets existing in each level.
 *  3. the first offset of each level should be 0, and the last should be
 *     the same (the height of the underlying tensor) or `tensor_height` if
 *     tensor_height>0.
 */
bool CheckAbsLoD(const LoD &in, int tensor_height = -1);

/*
 * LoDTensor (Level of details Tensor)
 * see https://en.wikipedia.org/wiki/Level_of_details for reference.
 */
class LoDTensor : public Tensor {
public:
  LoDTensor() : Tensor() {}

  explicit LoDTensor(const LoD &lod) : lod_(lod) {}

  void set_lod(const LoD &lod) { lod_ = lod; }

  const LoD &lod() const { return lod_; }

  LoD *mutable_lod() { return &lod_; }

  /*
   * Get the start offset and end offset of an element from LoD.
   */
  std::pair<size_t, size_t> lod_element(size_t level, size_t elem) const {
    // PADDLE_ENFORCE_LT(level, NumLevels());
    // PADDLE_ENFORCE_LT(elem, NumElements(level));
    return std::make_pair((lod_)[level][elem], (lod_)[level][elem + 1]);
  }

  /*
   * Number of LoDTensor's levels; each level has units of data. For example,
   * in the sentence's view, article, paragraph and sentence are 3 levels.
   */
  size_t NumLevels() const { return lod_.size(); }

  /*
   * Number of elements in a level.
   */
  size_t NumElements(size_t level = 0) const {
    // PADDLE_ENFORCE_LT(level, NumLevels());
    // the last offset is the end of last element
    return (lod_)[level].size() - 1;
  }

private:
  LoD lod_;
};

/*
 * Expand the `source` to fit the LoD of `lod`. For example, a `source`
 * LoDTensor is
 *  - LoD: [0, 2]
 *  - tensor: [a0, a1]
 * a `lod` is
 *  - LoD: [0 3 5]
 * returns a new LoDTensor
 *  - [a0 a0 a0 a1 a1]
 */
template <typename T>
LoDTensor LodExpand(const LoDTensor &source, const LoD &lod, size_t level) {
  LoD abs_lod = ToAbsOffset(lod);
  const auto &lod_level = lod[level];
  size_t num_instances = source.dims()[0];

  // new tensor
  LoDTensor tensor;
  tensor.set_lod(lod);
  auto dims = source.dims();
  dims[0] = lod_level.back();
  tensor.Resize(dims);
  tensor.mutable_data<T>();

  // PADDLE_ENFORCE_EQ(num_instances, lod_level.size() - 1);
  for (size_t ins = 0; ins < num_instances; ins++) {
    for (size_t elem = lod_level[ins]; elem < lod_level[ins + 1]; elem++) {
      auto slice = tensor.Slice(elem, elem + 1);
      TensorCopy(source.Slice(ins, ins + 1), &slice);
    }
  }
  return tensor;
}

// Get the absolute offset of a lod[start_level][start_idx:end_idx] and
// relative length of details for every level (i.e., [start_level: ]).
//
// For example,
//   lod = [[0, 3, 4, 8], [0, 9, 10, 11, 13, 17, 19, 22, 24]]
//   start_level = 0
//   start_idx = 1
//   end_idx = 3
//
// Returns:
//  LoD = [[1, 4], [2, 4, 2, 3, 2]]
//  pair<size_t, size_t> = {11, 24}
std::pair<LoD, std::pair<size_t, size_t>>
GetSubLoDAndAbsoluteOffset(const LoD &lod, size_t start_idx, size_t end_idx,
                           size_t start_level);

void AppendLoD(LoD *lod, const LoD &lod_length);

/*
 * Serialize/Deserialize LoDTensor to std::ostream
 * You can pass ofstream or ostringstream to serialize to a file
 * or to an in-memory string. GPU tensor will be copied to CPU.
 */
void SerializeToStream(std::ostream &os, const LoDTensor &tensor);

void DeserializeFromStream(std::istream &is, LoDTensor *tensor);

}  // namespace framework
}  // namespace paddle_mobile
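
A worked sketch of the LodExpand example from its comment (not part of the commit; it assumes lod_tensor.h and tensor_util.h are included): a two-row source expanded to the LoD [0 3 5].

void lod_expand_demo() {
  using namespace paddle_mobile::framework;

  LoDTensor source;
  source.Resize(make_ddim({2, 1}));
  float *src = source.mutable_data<float>();
  src[0] = 0.f;  // a0
  src[1] = 1.f;  // a1

  LoD lod = {{0, 3, 5}};  // a0 repeated three times, a1 twice
  LoDTensor expanded = LodExpand<float>(source, lod, 0);
  // expanded.dims() == {5, 1}, holding [a0 a0 a0 a1 a1]
}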

src/framework/op_desc.cpp
@@ -5,58 +5,55 @@

#include "op_desc.h"

namespace paddle_mobile {
namespace framework {

OpDesc::OpDesc(const proto::OpDesc &desc) : desc_(desc) {
  for (int i = 0; i < desc_.inputs_size(); ++i) {
    const proto::OpDesc::Var &var = desc_.inputs(i);
    std::vector<std::string> &args = inputs_[var.parameter()];
    int arg_size = var.arguments_size();
    for (int j = 0; j < arg_size; ++j) {
      args.push_back(var.arguments(j));
    }
  }

  for (int i = 0; i < desc_.outputs_size(); ++i) {
    const proto::OpDesc::Var &var = desc_.outputs(i);
    std::vector<std::string> &args = outputs_[var.parameter()];
    int arg_size = var.arguments_size();
    for (int j = 0; j < arg_size; ++j) {
      args.push_back(var.arguments(j));
    }
  }

  for (const proto::OpDesc::Attr &attr : desc_.attrs()) {
    std::string attr_name = attr.name();
    if (attr.type() != proto::AttrType::BLOCK) {
      attrs_[attr_name] = Attribute::GetAttrValue(attr);
      //      if (attr.type() == proto::AttrType::INT){
      //        std::cout << " attrName " << attr_name << " " <<
      //        attrs_[attr_name].Get<int>() << std::endl;
      //      }
    }
  }
}

const std::vector<std::string> &OpDesc::Input(const std::string &name) const {
  return inputs_.find(name)->second;
}

const std::vector<std::string> &OpDesc::Output(const std::string &name) const {
  return outputs_.find(name)->second;
}

Attribute OpDesc::GetAttr(const std::string &name) const {
  auto it = attrs_.find(name);
  return it->second;
}

const std::unordered_map<std::string, Attribute> &OpDesc::GetAttrMap() const {
  return attrs_;
}

}  // namespace framework
}  // namespace paddle_mobile

src/framework/op_desc.h
@@ -23,31 +23,29 @@ SOFTWARE.

#include "paddle_mobile_object.h"

namespace paddle_mobile {
namespace framework {

class OpDesc : PaddleMobileObject {
public:
  OpDesc(const proto::OpDesc &desc);

  const std::vector<std::string> &Input(const std::string &name) const;
  const std::vector<std::string> &Output(const std::string &name) const;
  Attribute GetAttr(const std::string &name) const;

  const VariableNameMap &GetInputs() { return inputs_; }

  const VariableNameMap &GetOutputs() { return outputs_; }

  const AttributeMap &GetAttrMap() const;

  const std::string &Type() { return desc_.type(); };

private:
  proto::OpDesc desc_;
  VariableNameMap inputs_;
  VariableNameMap outputs_;
  AttributeMap attrs_;
};

}  // namespace framework
}  // namespace paddle_mobile
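
A short sketch of how an OpDesc is queried, mirroring its use in executor.cpp; the `op` argument is assumed to come from BlockDesc::Ops().

void inspect_op(const std::shared_ptr<paddle_mobile::framework::OpDesc> &op) {
  using namespace paddle_mobile::framework;

  const std::vector<std::string> &inputs = op->Input("Input");  // argument names for "Input"
  Attribute strides_attr = op->GetAttrMap().at("strides");
  std::vector<int> strides = strides_attr.Get<std::vector<int>>();
  (void)inputs;
  (void)strides;
}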

src/framework/op_info.h
@@ -22,74 +22,73 @@ SOFTWARE.

#include "framework.pb.h"

namespace paddle_mobile {
namespace framework {

template <typename Dtype> struct OpInfo {
  OpCreator<Dtype> creator_;

  const OpCreator<Dtype> &Creator() const {
    //    PADDLE_ENFORCE_NOT_NULL(creator_,
    //                            "Operator Creator has not been
    //                            registered");
    return creator_;
  }
};

template <typename Dtype> class OpInfoMap;

template <typename Dtype> static OpInfoMap<Dtype> *g_op_info_map = nullptr;

template <typename Dtype> class OpInfoMap {
public:
  static OpInfoMap &Instance() {
    if (g_op_info_map<Dtype> == nullptr) {
      g_op_info_map<Dtype> = new OpInfoMap();
    }
    return *g_op_info_map<Dtype>;
  };

  bool Has(const std::string &op_type) const {
    return map_.find(op_type) != map_.end();
  }

  void Insert(const std::string &type, const OpInfo<Dtype> &info) {
    //    PADDLE_ENFORCE(!Has(type), "Operator %s has been
    //    registered", type);
    map_.insert({type, info});
  }

  const OpInfo<Dtype> &Get(const std::string &type) const {
    auto op_info_ptr = GetNullable(type);
    //    PADDLE_ENFORCE_NOT_NULL(op_info_ptr, "Operator %s has not
    //    been registered",
    //                            type);
    return *op_info_ptr;
  }

  const OpInfo<Dtype> *GetNullable(const std::string &type) const {
    auto it = map_.find(type);
    if (it == map_.end()) {
      return nullptr;
    } else {
      return &it->second;
    }
  }

  const std::unordered_map<std::string, OpInfo<Dtype>> &map() const {
    return map_;
  }

  std::unordered_map<std::string, OpInfo<Dtype>> *mutable_map() {
    return &map_;
  }

private:
  OpInfoMap() = default;
  std::unordered_map<std::string, OpInfo<Dtype>> map_;

  //  DISABLE_COPY_AND_ASSIGN(OpInfoMap);
};

}  // namespace framework
}  // namespace paddle_mobile
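
A hedged sketch of the intended use of the per-device singleton: register an OpInfo under an op type and look it up by name. CPU is assumed to be the device tag from common/types.h, and the OpInfo is assumed to arrive with its creator already set.

void register_and_lookup(const paddle_mobile::framework::OpInfo<paddle_mobile::CPU> &info) {
  using namespace paddle_mobile::framework;

  OpInfoMap<paddle_mobile::CPU>::Instance().Insert("conv2d", info);

  if (OpInfoMap<paddle_mobile::CPU>::Instance().Has("conv2d")) {
    const OpCreator<paddle_mobile::CPU> &creator =
        OpInfoMap<paddle_mobile::CPU>::Instance().Get("conv2d").Creator();
    (void)creator;
  }
}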

src/framework/op_kernel_type.h
@@ -22,51 +22,44 @@ SOFTWARE.

#include "framework.pb.h"

namespace paddle_mobile {
namespace framework {

struct OpKernelType {
  struct Hash {
    size_t operator()(const OpKernelType &key) const {
      int data_type = static_cast<int>(key.data_type_) << LEFT_SHIFT;
      int data_layout = static_cast<int>(key.data_layout_) << (LEFT_SHIFT * 2);

      std::hash<int> hasher;
      return hasher(data_type + data_layout);
    }
  };

  // place, data_type, library_type kinds less than 2^8
  constexpr static int LEFT_SHIFT = 8;

  proto::VarType::Type data_type_;
  DataLayout data_layout_;

  OpKernelType(proto::VarType::Type data_type,
               DataLayout data_layout = DataLayout::kAnyLayout)
      : data_type_(data_type), data_layout_(data_layout) {}

  bool operator==(const OpKernelType &o) const {
    return data_type_ == o.data_type_ && data_layout_ == o.data_layout_;
  }

  bool operator!=(const OpKernelType &o) const { return !(*this == o); }
};

inline bool NeedTransformLayout(const DataLayout &l, const DataLayout &r) {
  return l != DataLayout::kAnyLayout && r != DataLayout::kAnyLayout && l != r;
}

inline bool TransFromNeeded(const OpKernelType &l, const OpKernelType &r) {
  return (l.data_type_ != r.data_type_) ||
         NeedTransformLayout(l.data_layout_, r.data_layout_);
}

}  // namespace framework
}  // namespace paddle_mobile
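
A small sketch of how the key behaves (not from the commit; it assumes the FP32 enum value from framework.pb.h and the kNCHW layout from data_layout.h): with LEFT_SHIFT = 8 the data type lands in bits 8..15 and the layout in bits 16 and up before both are folded through std::hash<int>.

void kernel_type_demo() {
  using namespace paddle_mobile::framework;

  OpKernelType a(proto::VarType::FP32, DataLayout::kNCHW);
  OpKernelType b(proto::VarType::FP32);  // layout defaults to kAnyLayout

  bool transform = TransFromNeeded(a, b);  // false: same type, kAnyLayout matches anything
  size_t bucket = OpKernelType::Hash()(a);
  (void)transform;
  (void)bucket;
}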

src/framework/op_proto_maker.h
@@ -19,8 +19,8 @@ SOFTWARE.

#pragma once

namespace paddle_mobile {
namespace framework {

// This class not only makes the proto but also initializes attribute checkers.
class OpProtoAndCheckerMaker {};

}  // namespace framework
}  // namespace paddle_mobile

src/framework/operator.cpp
@@ -20,23 +20,23 @@ SOFTWARE.

#include "op_info.h"

namespace paddle_mobile {
namespace framework {

template <typename Dtype>
OperatorBase<Dtype>::OperatorBase(const std::string &type,
                                  const VariableNameMap &inputs,
                                  const VariableNameMap &outputs,
                                  const AttributeMap &attrs,
                                  std::shared_ptr<Scope> scope)
    : type_(type), inputs_(inputs), outputs_(outputs), attrs_(attrs),
      scope_(scope) {
  CheckAllInputOutputSet();
}

template <typename Dtype>
void OperatorBase<Dtype>::CheckAllInputOutputSet() const {}

template class OperatorBase<CPU>;
template class OperatorWithKernel<CPU>;

}  // namespace framework
}  // namespace paddle_mobile

src/framework/operator.h
@@ -33,68 +33,64 @@ SOFTWARE.

#include "variable.h"

namespace paddle_mobile {
namespace framework {

static std::unordered_map<std::string, std::vector<std::string>>
    op_input_output_key = {
        {"conv2d", {"Input", "Output"}},
        {"relu", {"X", "Out"}},
        {"softmax", {"X", "Out"}},
        {"mul", {"X", "Out"}},
        {"elementwise_add", {"X", "Out"}},
        {"pool2d", {"X", "Out"}},
        {"batch_norm", {"X", "Y"}},
        {"lrn", {"X", "Out"}},
        {"concat", {"X", "Out"}},
};

template <typename Dtype> class OperatorBase : PaddleMobileObject {
public:
  OperatorBase(const std::string &type, const VariableNameMap &inputs,
               const VariableNameMap &outputs, const AttributeMap &attrs,
               std::shared_ptr<Scope> scope);
  virtual ~OperatorBase() {}
  virtual void Run() const = 0;

  const VariableNameMap &Inputs() const { return inputs_; }
  const VariableNameMap &Outputs() const { return outputs_; }
  const std::string &Type() const { return type_; }
  const AttributeMap &Attrs() const { return attrs_; }

  void ClearVariables() const {
    if (this->scope_) {
      this->scope_->EraseVars(this->inputs_.at("Filter"));
      this->scope_->EraseVars(this->inputs_.at("Input"));
    }
  }

protected:
  std::shared_ptr<Scope> scope_;
  std::string type_;
  VariableNameMap inputs_;
  VariableNameMap outputs_;
  AttributeMap attrs_;

private:
  void CheckAllInputOutputSet() const;
};

template <typename Dtype>
class OperatorWithKernel : public OperatorBase<Dtype> {
public:
  OperatorWithKernel(const std::string &type, const VariableNameMap &inputs,
                     const VariableNameMap &outputs, const AttributeMap &attrs,
                     std::shared_ptr<Scope> scope)
      : OperatorBase<Dtype>(type, inputs, outputs, attrs, scope) {}

  virtual void InferShape() const = 0;
  virtual void Run() const = 0;
};

template <typename Dtype, typename P>
class OpKernelBase : PaddleMobileObject {
public:
  virtual void Compute(const P &para) const = 0;

  virtual ~OpKernelBase() = default;
};

}  // namespace framework
}  // namespace paddle_mobile
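
A small sketch of what op_input_output_key is for: looking up the parameter names under which a given op type stores its main input and output. The helper below is hypothetical, not part of the commit.

inline std::pair<std::string, std::string> io_keys_for(const std::string &op_type) {
  const std::vector<std::string> &keys =
      paddle_mobile::framework::op_input_output_key.at(op_type);
  // e.g. {"Input", "Output"} for "conv2d", {"X", "Y"} for "batch_norm"
  return {keys[0], keys[1]};
}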

src/framework/paddle_mobile_object.h
@@ -23,14 +23,14 @@ SOFTWARE.

namespace paddle_mobile {

class PaddleMobileObject {
public:
  virtual inline const std::string &ToString() {
    char address[128] = {0};
    sprintf(address, "%p", this);
    return std::string(address);
  }

private:
};

}  // namespace paddle_mobile

src/framework/program.cpp
@@ -17,5 +17,5 @@ SOFTWARE.

==============================================================================*/

namespace paddle_mobile {
namespace framework {}
}  // namespace paddle_mobile

src/framework/program.h
@@ -24,17 +24,17 @@ SOFTWARE.

#include "scope.h"

namespace paddle_mobile {
namespace framework {

template <typename Dtype, Precision P = Precision::FP32>
class Program : PaddleMobileObject {
public:
  std::shared_ptr<ProgramDesc> originProgram;
  std::shared_ptr<ProgramDesc> optimizeProgram;
  std::shared_ptr<Scope> scope;

private:
};

}  // namespace framework
}  // namespace paddle_mobile

src/framework/program_desc.cpp
@@ -5,18 +5,18 @@

#include "program_desc.h"

namespace paddle_mobile {
namespace framework {

ProgramDesc::ProgramDesc(const proto::ProgramDesc &desc) : desc_(desc) {
  for (auto &block_desc : *desc_.mutable_blocks()) {
    // new framework::BlockDesc(block_desc)
    blocks_.emplace_back(std::make_shared<BlockDesc>(block_desc));
  }
}

std::shared_ptr<BlockDesc> ProgramDesc::Block(size_t idx) {
  return blocks_[idx];
}

}  // namespace framework
}  // namespace paddle_mobile

src/framework/program_desc.h
@@ -25,20 +25,18 @@ SOFTWARE.

#include "paddle_mobile_object.h"

namespace paddle_mobile {
namespace framework {

class ProgramDesc : PaddleMobileObject {
public:
  ProgramDesc(const proto::ProgramDesc &desc);

  std::shared_ptr<BlockDesc> Block(size_t idx);

  const std::vector<std::shared_ptr<BlockDesc>> &Blocks() { return blocks_; };

private:
  std::vector<std::shared_ptr<BlockDesc>> blocks_;
  proto::ProgramDesc desc_;
};

}  // namespace framework
}  // namespace paddle_mobile

src/framework/scope.cc
@@ -4,116 +4,116 @@

#include <vector>

namespace paddle_mobile {
namespace framework {

Scope &Scope::NewScope() const {
  std::unique_lock<std::mutex> lock(mutex_);
  kids_.push_back(new Scope(this));
  return *kids_.back();
}

Variable *Scope::Var(const std::string &name) {
  auto *pvar = FindVarLocally(name);
  if (pvar != nullptr) {
    return pvar;
  };
  pvar = new Variable;
  vars_[name] = pvar;
  pvar->name_ = &(vars_.find(name)->first);
  return pvar;
}

// Variable* Scope::Var(std::string* name) {
//   auto var_name = string::Sprintf("%p.%d", this, vars_.size());
//   if (name != nullptr) {
//     *name = var_name;
//   }
//   return Var(var_name);
// }

Variable *Scope::FindVar(const std::string &name) const {
  auto *pvar = FindVarLocally(name);
  if (pvar != nullptr) {
    return pvar;
  }
  return (parent_ == nullptr) ? nullptr : parent_->FindVar(name);
}

const Scope *Scope::FindScope(const Variable *var) const {
  for (auto &name_var : vars_) {
    if (name_var.second == var) {
      return this;
    }
  }
  return (parent_ == nullptr) ? nullptr : parent_->FindScope(var);
}

void Scope::DropKids() {
  for (Scope *s : kids_) {
    delete s;
  }
  kids_.clear();
}

std::vector<std::string> Scope::LocalVarNames() const {
  std::vector<std::string> known_vars;
  known_vars.reserve(vars_.size());
  for (auto &name_var : vars_) {
    known_vars.emplace_back(name_var.first);
  }
  return known_vars;
}

void Scope::DeleteScope(Scope *scope) const {
  std::unique_lock<std::mutex> lock(mutex_);
  auto it = std::find(kids_.begin(), kids_.end(), scope);
  kids_.erase(it);
  delete scope;
  // deferent
}

void Scope::EraseVars(const std::vector<std::string> &var_names) {
  std::set<std::string> var_set(var_names.begin(), var_names.end());
  for (auto it = vars_.begin(); it != vars_.end();) {
    if (var_set.find(it->first) != var_set.end()) {
      delete it->second;
      it = vars_.erase(it);
    } else {
      ++it;
    }
  }
}

void Scope::Rename(const std::string &origin_name,
                   const std::string &new_name) const {
  auto origin_it = vars_.find(origin_name);
  if (origin_it == vars_.end()) {
    return;
  }
  auto new_it = vars_.find(new_name);
  if (new_it != vars_.end()) {
    return;
  }
  vars_[new_name] = origin_it->second;
  vars_.erase(origin_it);
}

//
// std::string Scope::Rename(const std::string& origin_name) const {
//   auto var_name = string::Sprintf("%p.%d", this, vars_.size());
//   Rename(origin_name, var_name);
//   return var_name;
// }

Variable *Scope::FindVarLocally(const std::string &name) const {
  auto it = vars_.find(name);
  if (it != vars_.end()) {
    return it->second;
  }
  return nullptr;
}

}  // namespace framework
}  // namespace paddle_mobile

src/framework/scope.h
@@ -24,58 +24,58 @@ SOFTWARE.

#include <unordered_map> //std::unordered_map

namespace paddle_mobile {
namespace framework {

class Scope {
public:
  Scope() {}
  ~Scope() {}

  Scope &NewScope() const;

  /// Create a variable with the given name if it doesn't exist.
  Variable *Var(const std::string &name);

  /// Create a variable with a scope-unique name.
  Variable *Var(std::string *name = nullptr);

  void EraseVars(const std::vector<std::string> &var_names);

  /// Find a variable in the scope or any of its ancestors. Returns
  /// nullptr if it cannot be found.
  Variable *FindVar(const std::string &name) const;

  const Scope *parent() const { return parent_; }

  /// Find the scope or an ancestor scope that contains the given
  /// variable.
  const Scope *FindScope(const Variable *var) const;

  void DeleteScope(Scope *scope) const;

  /// Drop all child scopes belonging to this scope.
  void DropKids();

  // Enumerate all the variables this scope currently contains.
  std::vector<std::string> LocalVarNames() const;

  // Rename variable to a new name
  void Rename(const std::string &origin_name,
              const std::string &new_name) const;

  // Rename variable to a new name and return the new name
  std::string Rename(const std::string &origin_name) const;

  Variable *FindVarLocally(const std::string &name) const;

private:
  // Call Scope::NewScope for a sub-scope.
  explicit Scope(Scope const *parent) : parent_(parent) {}

  mutable std::unordered_map<std::string, Variable *> vars_;
  mutable std::list<Scope *> kids_;
  Scope const *parent_{nullptr};

  mutable std::mutex mutex_;
};

}  // namespace framework
}  // namespace paddle_mobile
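
A minimal usage sketch of the Scope API above (not part of the commit): create a variable, fetch its Tensor through Variable::GetMutable as executor.cpp does, then resolve the same name from a child scope.

void scope_demo() {
  using namespace paddle_mobile::framework;

  Scope root;
  Variable *v = root.Var("pixel");
  Tensor *t = v->GetMutable<Tensor>();
  (void)t;

  Scope &child = root.NewScope();
  Variable *same = child.FindVar("pixel");  // falls back to the parent scope
  (void)same;
}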
src/framework/selected_rows.h
@@ -24,59 +24,58 @@ SOFTWARE.

#include "tensor.h"

namespace paddle_mobile {
namespace framework {

class SelectedRows {
 public:
  SelectedRows(const std::vector<int64_t> &rows, const int64_t &height)
      : rows_(rows), height_(height) {
    value_.reset(new Tensor());
  }

  SelectedRows() {
    height_ = 0;
    value_.reset(new Tensor());
  }

  const Tensor &value() const { return *value_; }

  Tensor *mutable_value() { return value_.get(); }

  int64_t height() const { return height_; }

  void set_height(int64_t height) { height_ = height; }

  const std::vector<int64_t> &rows() const { return rows_; }

  std::vector<int64_t> *mutable_rows() { return &rows_; }

  void set_rows(const std::vector<int64_t> &rows) { rows_ = rows; }

  /**
   * get the index of id in rows
   */
  int64_t index(int64_t id) const {
    auto it = std::find(rows_.begin(), rows_.end(), id);
    // PADDLE_ENFORCE(it != rows_.end(), "id should be in rows");
    return static_cast<int64_t>(std::distance(rows_.begin(), it));
  }

  DDim GetCompleteDims() const {
    std::vector<int64_t> dims = vectorize(value_->dims());
    dims[0] = height_;
    return make_ddim(dims);
  }

 private:
  // Notice: rows can contain duplicates, e.g. {0, 4, 7, 0, 5, 7, 9}.
  // SelectedRows are simply concatenated when added together; duplicate rows
  // are only resolved when a SelectedRows is added to a Tensor.
  std::vector<int64_t> rows_;
  std::unique_ptr<Tensor> value_{nullptr};
  int64_t height_;
};

}  // namespace framework
}  // namespace paddle_mobile
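As a point of reference, a short sketch of how this container is typically populated and queried. The Tensor::Resize() and mutable_data<T>() calls are assumptions borrowed from the mainline Paddle Tensor API (tensor.h is collapsed elsewhere in this commit), so treat them as illustrative rather than definitive.

// Minimal sketch, not part of the diff. Tensor::Resize / mutable_data<T> are
// assumed to behave as in the mainline Paddle framework.
#include <vector>
#include "framework/selected_rows.h"

void selected_rows_example() {
  using namespace paddle_mobile::framework;

  std::vector<int64_t> rows = {0, 4, 7};   // rows of the logical dense matrix that are present
  SelectedRows sr(rows, /*height=*/10);    // logical dense height is 10

  std::vector<int64_t> value_dims = {3, 16};   // one stored row per entry in rows
  Tensor *v = sr.mutable_value();
  v->Resize(make_ddim(value_dims));            // assumed Tensor API
  v->mutable_data<float>();                    // assumed Tensor API

  int64_t pos = sr.index(4);                   // 1: position of row id 4 inside rows()
  DDim full = sr.GetCompleteDims();            // {10, 16}: dims[0] replaced by height_
  (void)pos;
  (void)full;
}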
src/framework/tensor.h (diff collapsed)
src/framework/tensor_util.cc (diff collapsed)
src/framework/tensor_util.h (diff collapsed)
src/framework/var_desc.cpp
@@ -20,9 +20,9 @@ SOFTWARE.

namespace paddle_mobile {
namespace framework {

VarDesc::VarDesc(const proto::VarDesc &desc) : desc_(desc) {}

}  // namespace framework
}  // namespace paddle_mobile
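For context, VarDesc is only a thin wrapper around the generated protobuf message. A minimal sketch follows; the setter shown is the standard protobuf-generated API for framework.pb.h and is an assumption, not something this diff adds.

// Minimal sketch, not part of the diff. proto::VarDesc is generated from
// framework.proto, so set_name() is the usual protobuf setter for its name field.
#include "framework/framework.pb.h"
#include "framework/var_desc.h"

void var_desc_example() {
  paddle_mobile::framework::proto::VarDesc pd;
  pd.set_name("fc0.w_0");                         // hypothetical variable name
  paddle_mobile::framework::VarDesc var(pd);      // desc_ now holds a copy of pd
  (void)var;
}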
src/framework/var_desc.h (diff collapsed)
src/framework/var_type.h
@@ -23,17 +23,17 @@ SOFTWARE.

#include "variable.h"

namespace paddle_mobile {
namespace framework {

inline proto::VarType::Type ToVarType(std::type_index type) {
  if (type.hash_code() == typeid(LoDTensor).hash_code()) {
    return proto::VarType_Type_LOD_TENSOR;
  } else if (type.hash_code() == typeid(SelectedRows).hash_code()) {
    return proto::VarType_Type_SELECTED_ROWS;
  } else {
    // PADDLE_THROW("ToVarType:Unsupported type %s", type.name());
  }
}

}  // namespace framework
}  // namespace paddle_mobile
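Usage is a one-liner; the sketch below also notes that, with PADDLE_THROW commented out, the trailing else branch falls off the end of the function without returning, which is undefined behaviour for unsupported types.

// Minimal sketch, not part of the diff.
#include <typeindex>
#include "framework/var_type.h"

void var_type_example() {
  using namespace paddle_mobile::framework;

  proto::VarType::Type a = ToVarType(std::type_index(typeid(LoDTensor)));
  // a == proto::VarType_Type_LOD_TENSOR
  proto::VarType::Type b = ToVarType(std::type_index(typeid(SelectedRows)));
  // b == proto::VarType_Type_SELECTED_ROWS
  (void)a;
  (void)b;
  // Any other type currently returns nothing (undefined behaviour) because
  // the PADDLE_THROW in the else branch is commented out.
}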
src/framework/variable.h (diff collapsed)
src/io.cpp (diff collapsed)
src/io.h
@@ -27,14 +27,13 @@ SOFTWARE.

namespace paddle_mobile {

template <typename Dtype, Precision P = Precision::FP32>
class Loader : PaddleMobileObject {
 public:
  const framework::Program<Dtype, P> Load(const std::string &dirname);

 private:
  void LoadVar(framework::LoDTensor *tensor,
               const std::string &file_path);
};

}  // namespace paddle_mobile
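A minimal call-site sketch for the Loader declared above. The CPU device tag and the model directory are placeholders (the device tags are assumed to be defined in the project's common headers, which are not part of this hunk), and Precision::FP32 is simply the declared default template argument.

// Minimal sketch, not part of the diff. `CPU` is assumed to be one of the
// device tags defined elsewhere in the project; the path is a placeholder.
#include "io.h"

void loader_example() {
  paddle_mobile::Loader<paddle_mobile::CPU> loader;
  const auto program = loader.Load("./mobilenet_model/");
  // program is a framework::Program<CPU, Precision::FP32> describing the model.
  (void)program;
}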
src/memory/t_malloc.cc (diff collapsed)
src/memory/t_malloc.h (diff collapsed)
src/operators/conv_op.cpp (diff collapsed)
src/operators/conv_op.h (diff collapsed)
src/operators/elementwise_add_op.cpp (diff collapsed)
src/operators/elementwise_add_op.h (diff collapsed)
src/operators/kernel/arm/conv_kernel.cpp (diff collapsed)
src/operators/kernel/arm/elementwise_add_kernel.cpp (diff collapsed)
src/operators/kernel/arm/mul_kernel.cpp (diff collapsed)
src/operators/kernel/conv_kernel.h (diff collapsed)
src/operators/kernel/elementwise_add_kernel.h (diff collapsed)
src/operators/kernel/fpga/conv_kernel.cpp
@@ -25,4 +25,4 @@ namespace operators {

//
// template class ConvKernel<FPGA, float>;
}
}  // namespace paddle_mobile
src/operators/kernel/mul_kernel.h (diff collapsed)
src/operators/math/elementwise_op_function.h (diff collapsed)
src/operators/math/im2col.cc (diff collapsed)
src/operators/math/im2col.h (diff collapsed)
src/operators/math/math_function.cc (diff collapsed)
src/operators/math/math_function.h (diff collapsed)
src/operators/math/transform.h (diff collapsed)
src/operators/math/vol2col.cc (diff collapsed)
src/operators/math/vol2col.h (diff collapsed)
src/operators/mul_op.cpp (diff collapsed)
src/operators/mul_op.h (diff collapsed)
src/operators/op_param.cpp (diff collapsed)
src/operators/op_param.h (diff collapsed)
test/common/test_log.cpp (diff collapsed)
test/elementwise_add_op_test.h (diff collapsed)
test/mul_op_test.h (diff collapsed)
test/test_include.h (diff collapsed)
tools/pre-commit.hooks/.clang-format.hook (new file, mode 100644, diff collapsed)
tools/pre-commit.hooks/.clang-tidy.hook (diff collapsed)