PaddlePaddle / Paddle-Lite
Commit d6825bf3

Merge pull request #218 from allonli/develop

update code style to Google

Authored on May 20, 2018 by 朔-望; committed by GitHub on May 20, 2018
Parents: 7b384fb2, 7cc864ea

Showing 98 changed files with 2436 additions and 2451 deletions (+2436, -2451)
.clang-format  +1  -1
src/common/log.h  +27  -21
src/common/type_define.h  +5  -4
src/common/types.h  +3  -2
src/common/variant.h  +23  -11
src/framework/attribute.cpp  +1  -1
src/framework/attribute.h  +65  -60
src/framework/block_desc.cpp  +2  -2
src/framework/block_desc.h  +7  -6
src/framework/data_layout.h  +11  -11
src/framework/data_transform.cpp  +2  -2
src/framework/data_transform.h  +2  -2
src/framework/data_type.h  +1  -1
src/framework/ddim.cc  +66  -51
src/framework/ddim.h  +10  -6
src/framework/dim.h  +53  -29
src/framework/executor.cpp  +2  -2
src/framework/executor.h  +6  -5
src/framework/framework.pb.cpp  +1292  -1332
src/framework/framework.pb.h  +298  -315
src/framework/lod_tensor.cc  +16  -26
src/framework/lod_tensor.h  +8  -9
src/framework/op_desc.cpp  +2  -2
src/framework/op_desc.h  +4  -4
src/framework/op_info.h  +12  -8
src/framework/op_kernel_type.h  +2  -2
src/framework/op_proto_maker.h  +2  -2
src/framework/operator.cpp  +6  -3
src/framework/operator.h  +11  -9
src/framework/paddle_mobile_object.h  +4  -4
src/framework/program-optimize/node.cpp  +2  -2
src/framework/program-optimize/node.h  +4  -4
src/framework/program-optimize/program_optimize.cpp  +4  -4
src/framework/program-optimize/program_optimize.h  +6  -6
src/framework/program.cpp  +1  -1
src/framework/program.h  +4  -4
src/framework/program_desc.cpp  +2  -2
src/framework/program_desc.h  +4  -4
src/framework/scope.cc  +2  -2
src/framework/scope.h  +7  -7
src/framework/selected_rows.h  +4  -4
src/framework/tensor.h  +20  -12
src/framework/tensor_util.cc  +15  -11
src/framework/tensor_util.h  +3  -3
src/framework/var_desc.cpp  +2  -2
src/framework/var_desc.h  +21  -21
src/framework/var_type.h  +2  -2
src/framework/variable.h  +13  -9
src/io.cpp  +101  -167
src/io.h  +3  -3
src/memory/t_malloc.cc  +2  -2
src/memory/t_malloc.h  +8  -6
src/operators/batchnorm_op.cpp  +2  -2
src/operators/batchnorm_op.h  +4  -4
src/operators/concat_op.cpp  +2  -2
src/operators/concat_op.h  +4  -4
src/operators/conv_op.cpp  +2  -2
src/operators/conv_op.h  +4  -4
src/operators/elementwise_add_op.cpp  +2  -2
src/operators/elementwise_add_op.h  +4  -4
src/operators/kernel/arm/batchnorm_kernel.cpp  +2  -2
src/operators/kernel/arm/concat_kernel.cpp  +5  -4
src/operators/kernel/arm/conv_kernel.cpp  +4  -3
src/operators/kernel/arm/elementwise_add_kernel.cpp  +4  -3
src/operators/kernel/arm/lrn_kernel.cpp  +4  -3
src/operators/kernel/arm/mul_kernel.cpp  +4  -3
src/operators/kernel/arm/pool_kernel.cpp  +5  -4
src/operators/kernel/batchnorm_kernel.h  +3  -3
src/operators/kernel/concat_kernel.h  +3  -3
src/operators/kernel/conv_kernel.h  +3  -3
src/operators/kernel/elementwise_add_kernel.h  +3  -3
src/operators/kernel/fpga/conv_kernel.cpp  +1  -1
src/operators/kernel/lrn_kernel.h  +5  -4
src/operators/kernel/mul_kernel.h  +3  -3
src/operators/kernel/pool_kernel.h  +3  -3
src/operators/lrn_op.cpp  +4  -3
src/operators/lrn_op.h  +4  -4
src/operators/math/elementwise_op_function.h  +18  -14
src/operators/math/im2col.cc  +15  -11
src/operators/math/im2col.h  +5  -5
src/operators/math/math_function.cc  +3  -3
src/operators/math/math_function.h  +4  -4
src/operators/math/pool3x3.h  +1  -1
src/operators/math/pool_2x2.h  +1  -1
src/operators/math/pooling.cpp  +4  -5
src/operators/math/pooling.h  +9  -7
src/operators/math/transform.h  +3  -3
src/operators/math/vol2col.cc  +9  -7
src/operators/math/vol2col.h  +9  -7
src/operators/mul_op.cpp  +4  -3
src/operators/mul_op.h  +4  -4
src/operators/op_param.cpp  +2  -2
src/operators/op_param.h  +18  -18
src/operators/pool_op.cpp  +2  -2
src/operators/pool_op.h  +4  -4
src/platform/data_type.h  +56  -56
src/platform/macros.h  +5  -5
test/operators/test_cov_op.cpp  +2  -2
.clang-format

 ---
 Language: Cpp
-BasedOnStyle: LLVM
+BasedOnStyle: Google
 Standard: Cpp11
 ...
src/common/log.h

@@ -51,12 +51,13 @@ struct Print;
struct Print {
  friend struct ToLog;
  template <typename T> Print &operator<<(T const &value) {
    buffer_ << value;
    return *this;
  }

 private:
  void print(LogLevel level) {
    buffer_ << std::endl;
    if (level == kLOG_ERROR) {
...
@@ -76,14 +77,15 @@ struct ToLog {
    printer_ << logs[level] << " " << info << ":" << std::string(blanks, ' ');
  }

  template <typename T> ToLog &operator<<(T const &value) {
    printer_ << value;
    return *this;
  }

  ~ToLog() { printer_.print(level_); }

 private:
  LogLevel level_;
  Print printer_;
};
...
@@ -140,15 +142,19 @@ enum LogLevel {
struct ToLog;
struct Print {
  friend struct ToLog;
  template <typename T> Print &operator<<(T const &value) {}

 private:
};

struct ToLog {
  ToLog(LogLevel level) {}
  template <typename T> ToLog &operator<<(T const &value) { return *this; }
};

#define LOG(level) \
...
src/common/type_define.h

@@ -17,14 +17,15 @@ SOFTWARE.
==============================================================================*/

#pragma once;

#include <map>
#include <string>

#include "framework/attribute.h"

namespace paddle_mobile {
namespace framework {
template <typename Dtype> class OperatorBase;
class OpDesc;
class BlockDesc;
class InferShapeContext;
...
src/common/types.h

@@ -24,7 +24,8 @@ enum class Precision : int { FP32 = 0 };
//! device type
enum DeviceTypeEnum { kINVALID = -1, kCPU = 0, kFPGA = 1, kGPU_MALI = 2 };

template <DeviceTypeEnum T> struct DeviceType {};

typedef DeviceType<kCPU> CPU;
typedef DeviceType<kFPGA> FPGA;
...
src/common/variant.h

@@ -21,9 +21,13 @@ SOFTWARE.
#pragma once

namespace paddle_mobile {
template <int ID, typename Type> struct IDToType { typedef Type type_t; };

template <typename F, typename... Ts> struct VariantHelper {
  static const size_t size = sizeof(F) > VariantHelper<Ts...>::size
                                 ? sizeof(F)
                                 : VariantHelper<Ts...>::size;
...
@@ -37,7 +41,8 @@ template <typename F, typename... Ts> struct VariantHelper {
  }
};

template <typename F> struct VariantHelper<F> {
  static const size_t size = sizeof(F);
  inline static void Destroy(size_t id, void *data) {
    if (id == typeid(F).hash_code()) {
...
@@ -48,8 +53,9 @@ template <typename F> struct VariantHelper<F> {
  }
};

template <size_t size> class RawData {
 public:
  char data[size];
  RawData() {}
  RawData(const RawData &raw_data) { strcpy(data, raw_data.data); }
...
@@ -58,7 +64,8 @@ public:
  //   }
};

template <typename... Ts> struct Variant {
  Variant(const Variant &variant) {
    //    std::cout << " copy constructor " << std::endl;
    type_id = variant.type_id;
...
@@ -70,13 +77,15 @@ template <typename... Ts> struct Variant {
    //    helper::Destroy(type_id, &data);
  }

  template <typename T, typename... Args> void Set(Args &&... args) {
    helper::Destroy(type_id, &data);
    new (&data) T(std::forward<Args>(args)...);
    type_id = typeid(T).hash_code();
  }

  template <typename T> T &Get() const {
    if (type_id == typeid(T).hash_code()) {
      return *const_cast<T *>(reinterpret_cast<const T *>(&data));
    } else {
...
@@ -87,13 +96,16 @@ template <typename... Ts> struct Variant {
  size_t TypeId() const { return type_id; }

 private:
  static inline size_t invalid_type() { return typeid(void).hash_code(); }
  typedef VariantHelper<Ts...> helper;
  size_t type_id;
  RawData<helper::size> data;
};

template <typename T> struct Vistor { typedef T type_t; };

}  // namespace paddle_mobile
src/framework/attribute.cpp (+1, -1; diff not expanded)
src/framework/attribute.h

@@ -27,7 +27,7 @@ namespace framework {
class BlockDesc;

class Attribute {
 public:
  static Attribute GetAttrValue(const proto::OpDesc::Attr &attr_desc) {
    //    std::cout << "begin get attr value" << std::endl;
    Attribute attr;
...
@@ -93,14 +93,18 @@ public:
  }

  Attribute() {}
  template <typename T, typename... Args> Attribute &Set(Args &&... args) {
    variant_.Set<T>(args...);
    return *this;
  }

  template <typename T> T &Get() const { return variant_.Get<T>(); }

 private:
  Variant<int, float, std::string, std::vector<int>, std::vector<float>,
          std::vector<std::string>, bool, std::vector<bool>, BlockDesc *,
          int64_t>
...
@@ -110,10 +114,11 @@ private:
using AttributeMap = std::unordered_map<std::string, Attribute>;

class AttrReader {
 public:
  explicit AttrReader(const AttributeMap &attrs) : attrs_(attrs) {}

  template <typename T> inline T Get(const std::string &name) const {
    //    PADDLE_ENFORCE(attrs_.count(name) != 0, "%s should be in
    //    AttributeMap",
...
@@ -121,7 +126,7 @@ public:
    return ((Attribute)attrs_.at(name)).Get<T>();
  }

 private:
  const AttributeMap &attrs_;
};
...
src/framework/block_desc.cpp (+2, -2; diff not expanded)
src/framework/block_desc.h

@@ -27,7 +27,7 @@ namespace paddle_mobile {
namespace framework {

class BlockDesc : PaddleMobileObject {
 public:
  BlockDesc(const proto::BlockDesc &desc);
  const int &ID() const { return desc_.idx(); }
...
@@ -45,7 +45,7 @@ public:
  std::vector<std::shared_ptr<VarDesc>> Vars() const;
  std::vector<std::shared_ptr<OpDesc>> Ops() const;

 private:
  proto::BlockDesc desc_;
  std::vector<std::shared_ptr<OpDesc>> ops_;
  std::unordered_map<std::string, std::shared_ptr<VarDesc>> vars_;
...
@@ -56,7 +56,8 @@ private:
namespace std {
template <> struct hash<paddle_mobile::framework::BlockDesc> {
  typedef paddle_mobile::framework::BlockDesc argument_type;
  typedef std::size_t result_type;
  result_type operator()(argument_type const &s) const noexcept {
...
src/framework/data_layout.h (+11, -11; diff not expanded)
src/framework/data_transform.cpp (+2, -2; diff not expanded)
src/framework/data_transform.h (+2, -2; diff not expanded)
src/framework/data_type.h (+1, -1; diff not expanded)
src/framework/ddim.cc

@@ -19,11 +19,15 @@ namespace framework {
/// @cond HIDDEN

template <int i> Dim<i> make_dim(const int64_t *d) {
  return Dim<i>(*d, make_dim<i - 1>(d + 1));
}

template <> Dim<0> make_dim<0>(const int64_t *d) { return Dim<0>(*d); }

void make_ddim(DDim &ddim, const int64_t *dims, int n) {
  switch (n) {
...
@@ -90,24 +94,28 @@ DDim make_ddim(const std::vector<int> &dims) {
// XXX For some reason, putting this in an anonymous namespace causes errors
struct DynamicMutableIndexer : Vistor<int64_t &> {
 public:
  explicit DynamicMutableIndexer(int idx) : idx_(idx) {}

  template <int D> int64_t &operator()(Dim<D> &dim) const { return dim[idx_]; }

 private:
  int idx_;
};

struct DynamicConstIndexer : public Vistor<int64_t> {
 public:
  explicit DynamicConstIndexer(int idx) : idx_(idx) {}

  template <int D> int64_t operator()(const Dim<D> &dim) const {
    return dim[idx_];
  }

 private:
  int idx_;
};
...
@@ -182,7 +190,8 @@ struct VectorizeVisitor : Vistor<void> {
  explicit VectorizeVisitor(std::vector<int64_t> &v) : vector(v) {}

  template <typename T> void operator()(const T &t) {
    vector.push_back(t.head);
    this->operator()(t.tail);
  }
...
@@ -207,7 +216,8 @@ std::vector<int> vectorize2int(const DDim &ddim) {
}

struct ProductVisitor : Vistor<int64_t> {
  template <int D> int64_t operator()(const Dim<D> &dim) {
    return product(dim);
  }
};
...
@@ -233,7 +243,8 @@ struct SliceVectorizeVisitor : Vistor<void> {
  //    ddim slice.");
  }

  template <int S> void operator()(const Dim<S> &dim) {
    if (begin == 0) {
      vector.push_back(dim.head);
    } else {
...
@@ -264,7 +275,10 @@ DDim slice_ddim(const DDim &ddim, int begin, int end) {
/// \cond HIDDEN

struct ArityVisitor : Vistor<int> {
  template <int D> int operator()(Dim<D>) const { return D; }
};

/// \endcond
...
@@ -282,11 +296,12 @@ int arity(const DDim &d) {
struct OSVistor : Vistor<std::ostream &> {
  OSVistor(std::ostream &os) : os_(os) {}

  template <int D> std::ostream &operator()(Dim<D> dim) const {
    return os_ << dim;
  }

 private:
  std::ostream &os_;
};
...
src/framework/ddim.h

@@ -14,12 +14,12 @@ limitations under the License. */
#pragma once

#include <assert.h>
#include <initializer_list>
#include <stdexcept>
#include <vector>

#include "common/variant.h"
#include "dim.h"

namespace paddle_mobile {
namespace framework {
...
@@ -66,11 +66,15 @@ struct DDim {
  DDim() { var.Set<Dim<1>>(Dim<1>()); }

  template <int D> explicit DDim(const Dim<D> &in) { var.Set<Dim<D>>(in); }

  /*implicit*/ DDim(std::initializer_list<int64_t> init_list);

  template <int D> DDim &operator=(const Dim<D> &in) {
    var.Set<Dim<D>>(in);
    return *this;
  }
...
src/framework/dim.h
浏览文件 @
d6825bf3
...
@@ -24,7 +24,8 @@ namespace paddle_mobile {
...
@@ -24,7 +24,8 @@ namespace paddle_mobile {
namespace
framework
{
namespace
framework
{
// Statically sized, statically indexed dimension
// Statically sized, statically indexed dimension
template
<
int
i
>
struct
Dim
{
template
<
int
i
>
struct
Dim
{
static
constexpr
int
dimensions
=
i
;
static
constexpr
int
dimensions
=
i
;
template
<
typename
...
Args
>
template
<
typename
...
Args
>
...
@@ -70,7 +71,8 @@ template <int i> struct Dim {
...
@@ -70,7 +71,8 @@ template <int i> struct Dim {
};
};
// Base case specialization
// Base case specialization
template
<
>
struct
Dim
<
0
>
{
template
<
>
struct
Dim
<
0
>
{
static
constexpr
int
dimensions
=
0
;
static
constexpr
int
dimensions
=
0
;
HOSTDEVICE
HOSTDEVICE
...
@@ -105,28 +107,37 @@ template <> struct Dim<0> {
...
@@ -105,28 +107,37 @@ template <> struct Dim<0> {
namespace
{
namespace
{
// Helper for accessing Dim classes
// Helper for accessing Dim classes
template
<
int
i
>
struct
DimGetter
{
template
<
int
i
>
struct
DimGetter
{
// Return a copy if Dim is const
// Return a copy if Dim is const
template
<
typename
D
>
HOSTDEVICE
static
int64_t
impl
(
const
D
&
d
)
{
template
<
typename
D
>
HOSTDEVICE
static
int64_t
impl
(
const
D
&
d
)
{
return
DimGetter
<
i
-
1
>::
impl
(
d
.
tail
);
return
DimGetter
<
i
-
1
>::
impl
(
d
.
tail
);
}
}
// Return a reference if Dim is mutable
// Return a reference if Dim is mutable
template
<
typename
D
>
HOSTDEVICE
static
int64_t
&
impl
(
D
&
d
)
{
template
<
typename
D
>
HOSTDEVICE
static
int64_t
&
impl
(
D
&
d
)
{
return
DimGetter
<
i
-
1
>::
impl
(
d
.
tail
);
return
DimGetter
<
i
-
1
>::
impl
(
d
.
tail
);
}
}
};
};
// Eureka! We found the element!
// Eureka! We found the element!
template
<
>
struct
DimGetter
<
0
>
{
template
<
>
struct
DimGetter
<
0
>
{
// Return a copy if Dim is const
// Return a copy if Dim is const
template
<
typename
D
>
HOSTDEVICE
static
int64_t
impl
(
const
D
&
d
)
{
template
<
typename
D
>
HOSTDEVICE
static
int64_t
impl
(
const
D
&
d
)
{
return
d
.
head
;
return
d
.
head
;
}
}
// Return a reference if Dim is mutable
// Return a reference if Dim is mutable
template
<
typename
D
>
HOSTDEVICE
static
int64_t
&
impl
(
D
&
d
)
{
return
d
.
head
;
}
template
<
typename
D
>
HOSTDEVICE
static
int64_t
&
impl
(
D
&
d
)
{
return
d
.
head
;
}
};
};
template
<
int
D
>
HOSTDEVICE
int64_t
&
indexer
(
Dim
<
D
>
&
dim
,
int
idx
)
{
template
<
int
D
>
HOSTDEVICE
int64_t
&
indexer
(
Dim
<
D
>
&
dim
,
int
idx
)
{
#ifndef __CUDA_ARCH__
#ifndef __CUDA_ARCH__
if
(
idx
<
0
)
{
if
(
idx
<
0
)
{
throw
std
::
invalid_argument
(
"Tried to access a negative dimension"
);
throw
std
::
invalid_argument
(
"Tried to access a negative dimension"
);
...
@@ -140,7 +151,8 @@ template <int D> HOSTDEVICE int64_t &indexer(Dim<D> &dim, int idx) {
...
@@ -140,7 +151,8 @@ template <int D> HOSTDEVICE int64_t &indexer(Dim<D> &dim, int idx) {
return
indexer
(
dim
.
tail
,
idx
-
1
);
return
indexer
(
dim
.
tail
,
idx
-
1
);
}
}
template
<
>
HOSTDEVICE
int64_t
&
indexer
<
0
>
(
Dim
<
0
>
&
dim
,
int
idx
)
{
template
<
>
HOSTDEVICE
int64_t
&
indexer
<
0
>
(
Dim
<
0
>
&
dim
,
int
idx
)
{
#ifndef __CUDA_ARCH__
#ifndef __CUDA_ARCH__
throw
std
::
invalid_argument
(
"Invalid index"
);
throw
std
::
invalid_argument
(
"Invalid index"
);
#else
#else
...
@@ -156,7 +168,8 @@ template <> HOSTDEVICE int64_t &indexer<0>(Dim<0> &dim, int idx) {
...
@@ -156,7 +168,8 @@ template <> HOSTDEVICE int64_t &indexer<0>(Dim<0> &dim, int idx) {
#endif
#endif
}
}
template
<
int
D
>
HOSTDEVICE
int64_t
indexer
(
const
Dim
<
D
>
&
dim
,
int
idx
)
{
template
<
int
D
>
HOSTDEVICE
int64_t
indexer
(
const
Dim
<
D
>
&
dim
,
int
idx
)
{
#ifndef __CUDA_ARCH__
#ifndef __CUDA_ARCH__
if
(
idx
<
0
)
{
if
(
idx
<
0
)
{
throw
std
::
invalid_argument
(
"Tried to access a negative dimension"
);
throw
std
::
invalid_argument
(
"Tried to access a negative dimension"
);
...
@@ -170,7 +183,8 @@ template <int D> HOSTDEVICE int64_t indexer(const Dim<D> &dim, int idx) {
...
@@ -170,7 +183,8 @@ template <int D> HOSTDEVICE int64_t indexer(const Dim<D> &dim, int idx) {
return
indexer
(
dim
.
tail
,
idx
-
1
);
return
indexer
(
dim
.
tail
,
idx
-
1
);
}
}
template
<
>
HOSTDEVICE
int64_t
indexer
<
0
>
(
const
Dim
<
0
>
&
dim
,
int
idx
)
{
template
<
>
HOSTDEVICE
int64_t
indexer
<
0
>
(
const
Dim
<
0
>
&
dim
,
int
idx
)
{
#ifndef __CUDA_ARCH__
#ifndef __CUDA_ARCH__
throw
std
::
invalid_argument
(
"Invalid index"
);
throw
std
::
invalid_argument
(
"Invalid index"
);
#else
#else
...
@@ -188,23 +202,27 @@ template <> HOSTDEVICE int64_t indexer<0>(const Dim<0> &dim, int idx) {
...
@@ -188,23 +202,27 @@ template <> HOSTDEVICE int64_t indexer<0>(const Dim<0> &dim, int idx) {
}
// namespace
}
// namespace
// Static access to constant Dim
// Static access to constant Dim
template
<
int
i
,
int
l
>
HOSTDEVICE
int64_t
get
(
const
Dim
<
l
>
&
d
)
{
template
<
int
i
,
int
l
>
HOSTDEVICE
int64_t
get
(
const
Dim
<
l
>
&
d
)
{
return
DimGetter
<
i
>::
impl
(
d
);
return
DimGetter
<
i
>::
impl
(
d
);
}
}
// Static access to mutable Dim
// Static access to mutable Dim
template
<
int
i
,
int
l
>
HOSTDEVICE
int64_t
&
get
(
Dim
<
l
>
&
d
)
{
template
<
int
i
,
int
l
>
HOSTDEVICE
int64_t
&
get
(
Dim
<
l
>
&
d
)
{
return
DimGetter
<
i
>::
impl
(
d
);
return
DimGetter
<
i
>::
impl
(
d
);
}
}
// Dynamic access to constant Dim
// Dynamic access to constant Dim
template
<
int
l
>
HOSTDEVICE
int64_t
Dim
<
l
>::
operator
[](
int
i
)
const
{
template
<
int
l
>
HOSTDEVICE
int64_t
Dim
<
l
>::
operator
[](
int
i
)
const
{
// std::cout << "l: " << l << std::endl;
// std::cout << "l: " << l << std::endl;
return
indexer
(
*
this
,
i
);
return
indexer
(
*
this
,
i
);
}
}
// Dynamic access to mutable Dim
// Dynamic access to mutable Dim
template
<
int
l
>
HOSTDEVICE
int64_t
&
Dim
<
l
>::
operator
[](
int
i
)
{
template
<
int
l
>
HOSTDEVICE
int64_t
&
Dim
<
l
>::
operator
[](
int
i
)
{
return
indexer
(
*
this
,
i
);
return
indexer
(
*
this
,
i
);
}
}
...
@@ -247,13 +265,15 @@ HOSTDEVICE inline int64_t linearize(const Dim<0> &a, const Dim<0> &b) {
...
@@ -247,13 +265,15 @@ HOSTDEVICE inline int64_t linearize(const Dim<0> &a, const Dim<0> &b) {
}
}
// Product of a Dim
// Product of a Dim
template
<
int
i
>
HOSTDEVICE
int64_t
product
(
const
Dim
<
i
>
&
a
,
int
prod
=
1
)
{
template
<
int
i
>
HOSTDEVICE
int64_t
product
(
const
Dim
<
i
>
&
a
,
int
prod
=
1
)
{
return
prod
*
a
.
head
*
product
(
a
.
tail
);
return
prod
*
a
.
head
*
product
(
a
.
tail
);
}
}
// Base case product of a Dim
// Base case product of a Dim
// Notice it is inline because it is no longer a template
// Notice it is inline because it is no longer a template
template
<
>
HOSTDEVICE
inline
int64_t
product
(
const
Dim
<
0
>
&
a
,
int
prod
)
{
template
<
>
HOSTDEVICE
inline
int64_t
product
(
const
Dim
<
0
>
&
a
,
int
prod
)
{
return
prod
;
return
prod
;
}
}
...
@@ -282,7 +302,8 @@ HOSTDEVICE Dim<i> ex_prefix_mul(const Dim<i> &src, int mul = 1) {
///\cond HIDDEN
// Base case of ex_prefix_mul
// Notice it is inline because it is no longer a template
template <>
HOSTDEVICE inline Dim<0> ex_prefix_mul(const Dim<0> &src, int mul) {
  return Dim<0>();
}
///\endcond
...
@@ -290,7 +311,8 @@ template <> HOSTDEVICE inline Dim<0> ex_prefix_mul(const Dim<0> &src, int mul) {
/**
 * Add two dimensions together
 */
template <int i>
HOSTDEVICE Dim<i> dim_plus(const Dim<i> &a, const Dim<i> &b) {
  return Dim<i>(a.head + b.head, dim_plus(a.tail, b.tail));
}
...
@@ -308,7 +330,8 @@ HOSTDEVICE Dim<i> operator+(const Dim<i> &lhs, const Dim<i> &rhs) {
/**
 * Multiply two dimensions together
 */
template <int i>
HOSTDEVICE Dim<i> dim_mult(const Dim<i> &a, const Dim<i> &b) {
  return Dim<i>(a.head * b.head, dim_mult(a.tail, b.tail));
}
...
@@ -365,8 +388,8 @@ HOSTDEVICE Dim<sizeof...(Args)> make_dim(Args... idxes) {
// Allows us to output a Dim
// XXX For some reason, overloading fails to resolve this correctly
template <int i>
typename std::enable_if<(i > 1), std::ostream &>::type
operator<<(std::ostream &os, const Dim<i> &d) {
  os << d.head << ", " << d.tail;
  return os;
}
...
@@ -374,8 +397,8 @@ operator<<(std::ostream &os, const Dim<i> &d) {
// Base case that allows us to output a Dim
// XXX I wish this could be an overload instead of a template
template <int i>
typename std::enable_if<(i == 1), std::ostream &>::type
operator<<(std::ostream &os, const Dim<i> &d) {
  os << d.head;
  return os;
}
...
@@ -384,7 +407,8 @@ inline std::ostream &operator<<(std::ostream &os, const Dim<0> &d) {
  return os;
}

template <int i>
HOST std::string Dim<i>::to_string() const {
  std::stringstream stream;
  stream << *this;
...
src/framework/executor.cpp
View file @ d6825bf3
src/framework/executor.h
View file @ d6825bf3
...
@@ -34,15 +34,16 @@ SOFTWARE.
namespace paddle_mobile {
namespace framework {

template <typename Dtype>
class Executor {
 public:
  Executor();
  Executor(const Program<Dtype> p);
  std::shared_ptr<Tensor> predict(Tensor &t);

 public:
  const framework::Program<Dtype> program_;
  std::shared_ptr<ProgramDesc> to_predict_program_;
...
src/framework/framework.pb.cpp
View file @ d6825bf3
This diff is collapsed. Click to expand it.
src/framework/framework.pb.h
View file @ d6825bf3
This diff is collapsed. Click to expand it.
src/framework/lod_tensor.cc
View file @ d6825bf3
...
@@ -13,10 +13,10 @@ See the License for the specific language governing permissions and
limitations under the License. */
#include "lod_tensor.h"

#include <stdint.h>
#include <string.h>
#include <algorithm>
#include <iterator>

namespace paddle_mobile {
namespace framework {
...
@@ -105,8 +105,7 @@ LoD SliceInLevel(const LoD &in, size_t level, size_t elem_begin,
LoD ToAbsOffset(const LoD &in) {
  // the lowest level stores relative offsets
  if (in.empty() || in.size() == 1)
    return in;
  LoD result = in;
  for (auto level = static_cast<int>(in.size() - 2); level >= 0; level--) {
    for (size_t i = 0; i < in[level].size(); ++i) {
...
@@ -138,23 +137,19 @@ bool operator==(const LoD &a, const LoD &b) {
}

bool CheckLoD(const LoD &in, int tensor_height) {
  if (in.empty())
    return true;
  for (const auto &level : in) {
    // check: there should be more than 2 offsets existing in each
    // level.
    if (level.size() < 2)
      return false;
    // check: the first offset(the begin offset) of each level
    // should be 0.
    if (level.front() != 0)
      return false;
    // check: all the offsets in a level should be ascending(no same
    // items
    // allows).
    if (!std::is_sorted(level.begin(), level.begin(),
                        [](size_t a, size_t b) {
                          if (a < b)
                            return true;
                          return false;
                        })) {
      std::cout << "ascending error";
...
@@ -174,22 +169,19 @@ bool CheckLoD(const LoD &in, int tensor_height) {
  // goes
  // first.
  for (size_t level = 0; level < in.size() - 1; level++) {
    if (in[level].back() != in[level + 1].size() - 1)
      return false;
  }
  return true;
}
bool CheckAbsLoD(const LoD &in, int tensor_height) {
  if (in.empty())
    return true;
  for (const auto &level : in) {
    // check: all the offsets in a level should be ascending(no same
    // items
    // allows).
    if (!std::is_sorted(level.begin(), level.begin(),
                        [](size_t a, size_t b) {
                          if (a < b)
                            return true;
                          return false;
                        })) {
      return false;
...
@@ -197,14 +189,12 @@ bool CheckAbsLoD(const LoD &in, int tensor_height) {
    // check: there should be more than 2 offsets existing in each
    // level.
    if (level.size() < 2)
      return false;
    // check: the first offset of each level should be 0, and the
    // last should be
    // the same(the height of underlying tensor).
    if (level.front() != 0)
      return false;
    if (tensor_height < 0) {
      tensor_height = level.back();
    } else if ((size_t)tensor_height != level.back()) {
...
src/framework/lod_tensor.h
View file @ d6825bf3
...
@@ -14,12 +14,12 @@ limitations under the License. */
#pragma once

#include <memory>
#include <string>
#include <utility>
#include <vector>
#include "tensor.h"
#include "tensor_util.h"

namespace paddle_mobile {
...
@@ -102,7 +102,7 @@ bool CheckAbsLoD(const LoD &in, int tensor_height = -1);
 * see https://en.wikipedia.org/wiki/Level_of_details for reference.
 */
class LoDTensor : public Tensor {
 public:
  LoDTensor() : Tensor() {}
  explicit LoDTensor(const LoD &lod) : lod_(lod) {}
...
@@ -139,7 +139,7 @@ public:
    return (lod_)[level].size() - 1;
  }

 private:
  LoD lod_;
};
...
@@ -189,9 +189,8 @@ LoDTensor LodExpand(const LoDTensor &source, const LoD &lod, size_t level) {
// Returns:
//  LoD = [[1, 4], [2, 4, 2, 3, 2]]
//  pair<size_t, size_t> = {11, 24}
std::pair<LoD, std::pair<size_t, size_t>>
GetSubLoDAndAbsoluteOffset(const LoD &lod, size_t start_idx, size_t end_idx,
                           size_t start_level);

void AppendLoD(LoD *lod, const LoD &lod_length);
...
src/framework/op_desc.cpp
View file @ d6825bf3
src/framework/op_desc.h
View file @ d6825bf3
...
@@ -26,7 +26,7 @@ namespace paddle_mobile {
namespace framework {

class OpDesc : PaddleMobileObject {
 public:
  OpDesc(const proto::OpDesc &desc);
  const std::vector<std::string> &Input(const std::string &name) const;
  const std::vector<std::string> &Output(const std::string &name) const;
...
@@ -40,7 +40,7 @@ public:
  const std::string &Type() { return desc_.type(); };

 private:
  proto::OpDesc desc_;
  VariableNameMap inputs_;
  VariableNameMap outputs_;
...
src/framework/op_info.h
View file @ d6825bf3
...
@@ -24,7 +24,8 @@ SOFTWARE.
namespace paddle_mobile {
namespace framework {

template <typename Dtype>
struct OpInfo {
  OpCreator<Dtype> creator_;
  const OpCreator<Dtype> &Creator() const {
    //            PADDLE_ENFORCE_NOT_NULL(creator_,
...
@@ -34,12 +35,15 @@ template <typename Dtype> struct OpInfo {
  }
};

template <typename Dtype>
class OpInfoMap;

template <typename Dtype>
static OpInfoMap<Dtype> *g_op_info_map = nullptr;

template <typename Dtype>
class OpInfoMap {
 public:
  static OpInfoMap &Instance() {
    if (g_op_info_map<Dtype> == nullptr) {
      g_op_info_map<Dtype> = new OpInfoMap();
...
@@ -83,7 +87,7 @@ public:
    return &map_;
  }

 private:
  OpInfoMap() = default;
  std::unordered_map<std::string, OpInfo<Dtype>> map_;
...
src/framework/op_kernel_type.h
View file @ d6825bf3
src/framework/op_proto_maker.h
View file @ d6825bf3
src/framework/operator.cpp
View file @ d6825bf3
...
@@ -28,7 +28,10 @@ OperatorBase<Dtype>::OperatorBase(const std::string &type,
                                  const VariableNameMap &outputs,
                                  const AttributeMap &attrs,
                                  std::shared_ptr<Scope> scope)
    : type_(type), inputs_(inputs), outputs_(outputs), attrs_(attrs),
      scope_(scope) {
  CheckAllInputOutputSet();
}
...
src/framework/operator.h
View file @ d6825bf3
...
@@ -48,8 +48,9 @@ static std::unordered_map<
        {"feed", {{"X"}, {"Out"}}},
        {"fetch", {{"X"}, {"Out"}}}};

template <typename Dtype>
class OperatorBase : PaddleMobileObject {
 public:
  OperatorBase(const std::string &type, const VariableNameMap &inputs,
               const VariableNameMap &outputs, const AttributeMap &attrs,
               std::shared_ptr<Scope> scope);
...
@@ -66,20 +67,20 @@ public:
    }
  }

 protected:
  std::shared_ptr<Scope> scope_;
  std::string type_;
  VariableNameMap inputs_;
  VariableNameMap outputs_;
  AttributeMap attrs_;

 private:
  void CheckAllInputOutputSet() const;
};

template <typename Dtype>
class OperatorWithKernel : public OperatorBase<Dtype> {
 public:
  OperatorWithKernel(const std::string &type, const VariableNameMap &inputs,
                     const VariableNameMap &outputs, const AttributeMap &attrs,
                     std::shared_ptr<Scope> scope)
...
@@ -88,8 +89,9 @@ public:
  virtual void Run() const = 0;
};

template <typename Dtype, typename P>
class OpKernelBase : PaddleMobileObject {
 public:
  virtual void Compute(const P &para) const = 0;
  virtual ~OpKernelBase() = default;
...
src/framework/paddle_mobile_object.h
View file @ d6825bf3
...
@@ -18,19 +18,19 @@ SOFTWARE.
#pragma once

#include <string>
#include "stdio.h"

namespace paddle_mobile {
class PaddleMobileObject {
 public:
  virtual std::string ToString() {
    char address[128] = {0};
    sprintf(address, "%p", this);
    return std::string(address);
  }

 private:
};
}  // namespace paddle_mobile
src/framework/program-optimize/node.cpp
View file @ d6825bf3
src/framework/program-optimize/node.h
View file @ d6825bf3
...
@@ -29,7 +29,7 @@ namespace paddle_mobile {
namespace framework {

class Node : PaddleMobileObject {
 public:
  Node(const std::string &type) : type_(type) {}
  Node(std::shared_ptr<OpDesc> op_desc)
      : op_desc_(op_desc), type_(op_desc->Type()){};
...
@@ -39,7 +39,7 @@ public:
  Node &To(int index);
  uint depth(uint begin = 0);

 private:
  std::shared_ptr<OpDesc> op_desc_;
  std::string ToString(std::string blank, const Node *node) const;
  std::vector<std::shared_ptr<Node>> outputs_;
...
src/framework/program-optimize/program_optimize.cpp
View file @ d6825bf3
...
@@ -24,8 +24,8 @@ namespace framework {
std::shared_ptr<ProgramDesc> ProgramOptimize::Optimize() {}

std::shared_ptr<ProgramDesc>
ProgramOptimize::FushionOptimize(std::shared_ptr<ProgramDesc> ori_des) {
  for (int i = 0; i < ori_des->Blocks().size(); ++i) {
    std::unordered_map<std::string, std::shared_ptr<Node>> output_nodes;
    std::shared_ptr<Node> begin_node;
...
src/framework/program-optimize/program_optimize.h
View file @ d6825bf3
...
@@ -26,13 +26,13 @@ namespace paddle_mobile {
namespace framework {

class ProgramOptimize {
 public:
  ProgramOptimize() {}
  std::shared_ptr<ProgramDesc> Optimize();
  std::shared_ptr<ProgramDesc>
  FushionOptimize(std::shared_ptr<ProgramDesc> ori_des);

 private:
  //  std::shared_ptr<ProgramDesc> ori_desc_;
  std::vector<std::unordered_map<std::string, std::shared_ptr<Node>>>
      outputs_nodes_;
...
src/framework/program.cpp
View file @ d6825bf3
src/framework/program.h
View file @ d6825bf3
...
@@ -28,12 +28,12 @@ namespace framework {
template <typename Dtype, Precision P = Precision::FP32>
class Program : PaddleMobileObject {
 public:
  std::shared_ptr<ProgramDesc> originProgram;
  std::shared_ptr<ProgramDesc> optimizeProgram;
  std::shared_ptr<Scope> scope;

 private:
};
}  // namespace framework
...
src/framework/program_desc.cpp
View file @ d6825bf3
src/framework/program_desc.h
View file @ d6825bf3
...
@@ -28,12 +28,12 @@ namespace paddle_mobile {
namespace framework {

class ProgramDesc : PaddleMobileObject {
 public:
  ProgramDesc(const proto::ProgramDesc &desc);
  std::shared_ptr<BlockDesc> Block(size_t idx);
  const std::vector<std::shared_ptr<BlockDesc>> &Blocks() { return blocks_; };

 private:
  std::vector<std::shared_ptr<BlockDesc>> blocks_;
  proto::ProgramDesc desc_;
};
...
src/framework/scope.cc
View file @ d6825bf3
src/framework/scope.h
View file @ d6825bf3
...
@@ -18,15 +18,15 @@ SOFTWARE.
==============================================================================*/
#pragma once

#include <list>           //std::list
#include <mutex>          //std::mutex
#include <unordered_map>  //std::unordered_map
#include "variable.h"

namespace paddle_mobile {
namespace framework {
class Scope {
 public:
  Scope() {}
  ~Scope() {}
...
@@ -67,7 +67,7 @@ public:
  Variable *FindVarLocally(const std::string &name) const;

 private:
  // Call Scope::NewScope for a sub-scope.
  explicit Scope(Scope const *parent) : parent_(parent) {}
...
src/framework/selected_rows.h
View file @ d6825bf3
...
@@ -27,7 +27,7 @@ namespace paddle_mobile {
namespace framework {

class SelectedRows {
 public:
  SelectedRows(const std::vector<int64_t> &rows, const int64_t &height)
      : rows_(rows), height_(height) {
    value_.reset(new Tensor());
...
@@ -67,7 +67,7 @@ public:
    return make_ddim(dims);
  }

 private:
  // Notice: rows can be duplicate. We can have {0, 4, 7, 0, 5, 7, 9}
  // here.
  // SelectedRows are simply concated when adding together. Until a
...
src/framework/tensor.h
View file @ d6825bf3
...
@@ -26,9 +26,11 @@ limitations under the License. */
namespace paddle_mobile {
namespace framework {

template <typename... T>
struct SizeOfTypeFunctor;

template <typename T>
struct SizeOfTypeFunctor<T> {
  size_t operator()(std::type_index type) const {
    if (typeid(T).hash_code() == type.hash_code()) {
      return sizeof(T);
...
@@ -38,7 +40,8 @@ template <typename T> struct SizeOfTypeFunctor<T> {
  }
};

template <>
struct SizeOfTypeFunctor<> {
  size_t operator()(std::type_index type) const { return 0UL; }
};
...
@@ -66,11 +69,12 @@ static inline size_t SizeOfType(std::type_index type) {
class LoDTensor;

class Tensor {
 public:
  Tensor() : offset_(0) {}

  /*! Return a pointer to mutable memory block. */
  template <typename T>
  inline T *data() {
    check_memory_size();
    //    PADDLE_ENFORCE(std::is_same<T, void>::value ||
    //                       holder_->type().hash_code() ==
...
@@ -82,7 +86,8 @@ public:
  }

  /*! Return a pointer to constant memory block. */
  template <typename T>
  inline const T *data() const {
    check_memory_size();
    //    PADDLE_ENFORCE(std::is_same<T, void>::value ||
    //                       holder_->type().hash_code() ==
...
@@ -100,7 +105,8 @@ public:
   * @brief   Return a pointer to mutable memory block.
   * @note    If not exist, then allocation.
   */
  template <typename T>
  inline T *mutable_data() {
    static_assert(std::is_pod<T>::value, "T must be POD");
    return reinterpret_cast<T *>(mutable_data(typeid(T)));
  }
...
@@ -141,7 +147,8 @@ public:
   *
   * @note    If not exist, then allocation.
   */
  template <typename T>
  inline T *mutable_data(DDim dims) {
    static_assert(std::is_pod<T>::value, "T must be POD");
    Resize(dims);
    return mutable_data<T>();
...
@@ -235,7 +242,7 @@ public:
  inline void set_layout(const DataLayout layout) { layout_ = layout; }

 private:
  /**
   * @note    Placeholder hides type T, so it doesn't appear as a
   * template
...
@@ -257,7 +264,8 @@ private:
    PlaceholderImpl(size_t size, std::type_index type)
        : ptr_(static_cast<uint8_t *>(memory::Alloc(size)),
               memory::PODDeleter<uint8_t>()),
          size_(size), type_(type) {
      //      PADDLE_ENFORCE_NOT_NULL(ptr_,
      //                              "Insufficient %s
      //                              memory to allocation.",
...
src/framework/tensor_util.cc
View file @ d6825bf3
...
@@ -51,7 +51,8 @@ void TensorCopySync(const Tensor &src, Tensor *dst) {
  memory::Copy(dst_ptr, src_ptr, size);
}

template <typename Predicate>
struct AnyDTypeVisitor {
  Predicate predicate_;
  const Tensor &tensor_;
  Tensor *out_;
...
@@ -59,7 +60,8 @@ template <typename Predicate> struct AnyDTypeVisitor {
  AnyDTypeVisitor(Predicate predicate, const Tensor &tensor, Tensor *out)
      : predicate_(predicate), tensor_(tensor), out_(out) {}

  template <typename T>
  void operator()() const {
    //    auto t = EigenVector<T>::Flatten(tensor_);
    //    auto o = EigenScalar<bool>::From(*out_);
    //    return any of predicate_(t) is true.
...
@@ -74,7 +76,8 @@ inline void AnyImpl(Predicate predicate, const Tensor &tensor,
          AnyDTypeVisitor<Predicate>(predicate, tensor, out));
}

template <typename Predicate>
struct AnyVisitor {
  const framework::Tensor &tensor_;
  Predicate predicate_;
...
@@ -164,7 +167,8 @@ struct DeserializedDataFunctor {
  DeserializedDataFunctor(void **buf, Tensor *tensor)
      : buf_(buf), tensor_(tensor) {}

  template <typename T>
  void operator()() {
    *buf_ = tensor_->mutable_data<T>();
  }
...
src/framework/tensor_util.h
View file @ d6825bf3
...
@@ -13,11 +13,11 @@ See the License for the specific language governing permissions and
limitations under the License. */
#pragma once

#include <vector>
#include "framework.pb.h"
#include "memory/t_malloc.h"
#include "platform/data_type.h"
#include "tensor.h"

namespace paddle_mobile {
namespace framework {
...
src/framework/var_desc.cpp
View file @ d6825bf3
src/framework/var_desc.h
View file @ d6825bf3
...
@@ -25,7 +25,7 @@ namespace paddle_mobile {
namespace framework {

class VarDesc {
 public:
  VarDesc(const proto::VarDesc &desc);
  std::string Name() const { return desc_.name(); }
...
@@ -80,7 +80,7 @@ public:
    return this->RepeatedToVector(tensor_desc().dims());
  }

 private:
  proto::VarDesc desc_;
};
...
src/framework/var_type.h
src/framework/variable.h
@@ -18,18 +18,19 @@ SOFTWARE.
 ==============================================================================*/

 #pragma once
-#include "paddle_mobile_object.h"
 #include <iostream>
 #include <memory>
 #include <string>
 #include <typeindex>
 #include <typeinfo>
+#include "paddle_mobile_object.h"

 namespace paddle_mobile {
 namespace framework {

 class Variable : public PaddleMobileObject {
 public:
  template <typename T> const T *Get() const {
    return static_cast<const T *>(holder_->Ptr());
  }
...
@@ -37,7 +38,8 @@ public:
  const std::string *Name() { return name_; }

  template <typename T> T *GetMutable() {
    if (!IsType<T>()) {
      if (*Name() == "pixel") {
        // std::cout << " reset " << *Name() <<
...
@@ -48,7 +50,8 @@ public:
    return static_cast<T *>(holder_->Ptr());
  }

  template <typename T> bool IsType() const {
    if (holder_) {
      // printf("not null \n");
      printf(" holder type : %s, this type %s \n", holder_->Type().name(),
...
@@ -67,7 +70,7 @@ public:
  void SetName(const std::string *name) { name_ = name; }

 private:
  struct Placeholder {
    Placeholder() = default;
    virtual ~Placeholder() = default;
...
@@ -76,7 +79,8 @@ private:
    virtual void *Ptr() const = 0;
  };

  template <typename T> struct PlaceholderImp : public Placeholder {
    explicit PlaceholderImp(T *ptr) : ptr_(ptr), type_(typeid(T)) {}

    virtual const std::type_info &Type() const { return type_; }
...
src/io.cpp
@@ -18,13 +18,13 @@ SOFTWARE.

 #include <fstream>
-#include "../src/io.h"
 #include "common/log.h"
 #include "framework/framework.pb.h"
 #include "framework/lod_tensor.h"
 #include "framework/program_desc.h"
 #include "framework/scope.h"
 #include "framework/tensor.h"
+#include "io.h"

 namespace paddle_mobile {
...
@@ -41,26 +41,20 @@ void ReadBinaryFile(const std::string &filename, std::string *contents) {

 template <typename Dtype, Precision P>
 void Loader<Dtype, P>::LoadVar(framework::LoDTensor *tensor,
                                const std::string &file_path) {
-  // LOG(kLOG_DEBUG) << " to load " << file_path;
-  // Log(kLOG_DEBUG) << "123";
   std::ifstream is(file_path);
-  std::streampos pos = is.tellg(); // save current position
+  std::fpos<mbstate_t> pos;
+  pos = is.tellg(); // save current position
   is.seekg(0, std::ios::end);
-  // LOG(kLOG_DEBUG) << " file length = " << is.tellg();
   is.seekg(pos); // restore saved position

   // 1. version
   uint32_t version;
   is.read(reinterpret_cast<char *>(&version), sizeof(version));
-  // LOG(kLOG_INFO) << " version: " << version;

   // 2 Lod information
   uint64_t lod_level;
   is.read(reinterpret_cast<char *>(&lod_level), sizeof(lod_level));
-  // LOG(kLOG_DEBUG) << " load level: " << lod_level;
-  // LOG(kLOG_DEBUG) << " lod info: ";
   auto &lod = *tensor->mutable_lod();
   lod.resize(lod_level);
   for (uint64_t i = 0; i < lod_level; ++i) {
...
@@ -69,8 +63,8 @@ void Loader<Dtype, P>::LoadVar(framework::LoDTensor *tensor,
     std::vector<size_t> tmp(size / sizeof(size_t));
     is.read(reinterpret_cast<char *>(tmp.data()),
             static_cast<std::streamsize>(size));
-    for (int j = 0; j < tmp.size(); ++j) {
-      LOG(kLOG_DEBUG1) << " lod - " << tmp[j];
+    for (auto j : tmp) {
+      LOG(kLOG_DEBUG1) << " lod - " << j;
     }
     lod[i] = tmp;
   }
...
@@ -78,26 +72,19 @@ void Loader<Dtype, P>::LoadVar(framework::LoDTensor *tensor,
   // 3. tensor version
   uint32_t tensor_version;
   is.read(reinterpret_cast<char *>(&tensor_version), sizeof(tensor_version));
-  // std::cout << " tensor_version: " << tensor_version << std::endl;

   // 4. tensor desc
   int32_t size;
   is.read(reinterpret_cast<char *>(&size), sizeof(size));
-  // std::cout << " tensor desc size: " << size << std::endl;
   std::unique_ptr<char[]> buf(new char[size]);
   is.read(reinterpret_cast<char *>(buf.get()), size);
   framework::proto::VarType::TensorDesc desc;
   desc.ParseFromArray(buf.get(), size);
-  // std::cout << " desc dims size " << desc.dims().size() << std::endl;
   int memory_size = 1;
-  for (int l = 0; l < desc.dims().size(); ++l) {
-    // std::cout << " dim " << l << " value: " << desc.dims()[l]
-    // << std::endl;
-    memory_size *= desc.dims()[l];
+  for (auto l : desc.dims()) {
+    memory_size *= l;
   }
   std::vector<int64_t> dims;
...
@@ -105,50 +92,39 @@ void Loader<Dtype, P>::LoadVar(framework::LoDTensor *tensor,
   std::copy(desc.dims().begin(), desc.dims().end(), std::back_inserter(dims));
   tensor->Resize(framework::make_ddim(dims));

-  void *memory;
+  void *memory = tensor;
   int type_size = 0;
-  // std::cout << " desc pre type: ";
   switch (desc.data_type()) {
   case framework::proto::VarType::FP16:
-    // std::cout << "FP16" << std::endl;
     type_size = 2;
     break;
   case framework::proto::VarType::FP32:
     type_size = 4;
     memory = tensor->mutable_data<float>();
-    // std::cout << "FP32" << std::endl;
     break;
   case framework::proto::VarType::FP64:
     type_size = 8;
-    // std::cout << "FP64" << std::endl;
     break;
   case framework::proto::VarType::INT32:
     type_size = 4;
-    // std::cout << "INT32" << std::endl;
     break;
   case framework::proto::VarType::INT64:
     type_size = 8;
-    // std::cout << "INT64" << std::endl;
     break;
   case framework::proto::VarType::BOOL:
     type_size = 1;
-    // std::cout << "BOOL" << std::endl;
     break;
   default:
     break;
-    // std::cout << " not support" << std::endl;
   }
-  // std::cout << " malloc size: " << memory_size * type_size << std::endl;
   is.read(static_cast<char *>(memory), memory_size * type_size);
-  // std::cout << " memory: " << memory << std::endl;
   is.close();
-};
+}

 template <typename Dtype, Precision P>
 const framework::Program<Dtype, P>
 Loader<Dtype, P>::Load(const std::string &dirname) {
   std::string model_filename = dirname + "/__model__";
   std::string program_desc_str;
   ReadBinaryFile(model_filename, &program_desc_str);
...
@@ -165,10 +141,9 @@ Loader<Dtype, P>::Load(const std::string &dirname) {
       std::make_shared<framework::Scope>();
   program.scope = scope;

-  auto block = originProgramDesc->Block(0);
+  originProgramDesc->Block(0);

-  for (auto block : originProgramDesc->Blocks()) {
-    // std::cout << "for block" << std::endl;
+  for (const auto &block : originProgramDesc->Blocks()) {
     for (int i = 0; i < block->Vars().size(); ++i) {
       std::shared_ptr<framework::VarDesc> var_desc = block->Vars()[i];
       auto var = scope->Var(var_desc->Name());
...
@@ -176,20 +151,18 @@ Loader<Dtype, P>::Load(const std::string &dirname) {
       if (var_desc->Persistable() &&
           var_desc->GetType() != framework::proto::VarType::FEED_MINIBATCH &&
           var_desc->GetType() != framework::proto::VarType::FETCH_LIST) {
-        framework::LoDTensor *tensor =
-            var->GetMutable<framework::LoDTensor>();
+        auto tensor = var->GetMutable<framework::LoDTensor>();
         // to load
         LoadVar(tensor, dirname + "/" + var_desc->Name());
       } else {
-        // std::cout << "非 lod" << std::endl;
+        // TODO by someone
       }
     }
   }

 #ifdef PADDLE_MOBILE_DEBUG
-  for (int i = 0; i < program_desc_proto.blocks().size(); ++i) {
-    framework::proto::BlockDesc block = program_desc_proto.blocks()[i];
+  for (const auto &block : program_desc_proto.blocks()) {
     LOG(kLOG_DEBUG) << "block: " << block.idx();
     for (int j = 0; j < block.ops().size(); ++j) {
       if (j == 2) {
...
@@ -200,21 +173,20 @@ Loader<Dtype, P>::Load(const std::string &dirname) {
       for (int m = 0; m < op.inputs_size(); ++m) {
         const framework::proto::OpDesc::Var &var = op.inputs(m);
         LOG(kLOG_DEBUG2) << "input parameter: " << var.parameter();
-        for (int n = 0; n < var.arguments().size(); ++n) {
-          LOG(kLOG_DEBUG3) << "argument - " << var.arguments()[n];
+        for (const auto &n : var.arguments()) {
+          LOG(kLOG_DEBUG3) << "argument - " << n;
         }
       }
       for (int y = 0; y < op.outputs_size(); ++y) {
         const framework::proto::OpDesc::Var &var = op.outputs(y);
         LOG(kLOG_DEBUG2) << "out parameter: " << var.parameter();
-        for (int z = 0; z < var.arguments().size(); ++z) {
-          LOG(kLOG_DEBUG3) << "argument - " << var.arguments()[z];
+        for (const auto &z : var.arguments()) {
+          LOG(kLOG_DEBUG3) << "argument - " << z;
         }
       }
-      for (int x = 0; x < op.attrs().size(); ++x) {
-        const framework::proto::OpDesc_Attr attr = op.attrs()[x];
+      for (const auto &attr : op.attrs()) {
         LOG(kLOG_DEBUG2) << "attr name: " << attr.name();
         switch (attr.type()) {
...
@@ -246,19 +218,19 @@ Loader<Dtype, P>::Load(const std::string &dirname) {
           for (int y = 0; y < attr.strings_size(); ++y) {
             LOG(kLOG_DEBUG3) << "strings: " << attr.strings(y);
           }
+        case framework::proto::BLOCK:
+          break;
         }
       }
     }
-    for (int k = 0; k < block.vars().size(); ++k) {
-      framework::proto::VarDesc var = block.vars()[k];
+    for (const auto &var : block.vars()) {
       if (var.type().type() == framework::proto::VarType::LOD_TENSOR) {
         LOG(kLOG_DEBUG1) << "var name: " << var.name();
         const framework::proto::VarType::TensorDesc &tensor_desc =
             var.type().lod_tensor().tensor();
         LOG(kLOG_DEBUG2) << "in var tensor desc dims size: "
                          << tensor_desc.dims().size();
-        int memory_size = 1;
         for (int l = 0; l < tensor_desc.dims().size(); ++l) {
           LOG(kLOG_DEBUG3) << "var tensor desc dim " << l
                            << " value: " << tensor_desc.dims()[l];
...
@@ -268,28 +240,20 @@ Loader<Dtype, P>::Load(const std::string &dirname) {
         if (var.persistable() &&
             var.type().type() != framework::proto::VarType::FEED_MINIBATCH &&
             var.type().type() != framework::proto::VarType::FETCH_LIST) {
-          // std::cout << " to load " << var.name() << std::endl;
           std::string file_path = dirname + "/" + var.name();
           std::ifstream is(file_path);
-          std::streampos pos = is.tellg(); // save current position
+          std::fpos<mbstate_t> pos;
+          pos = is.tellg(); // save current position
           is.seekg(0, std::ios::end);
-          // std::cout << " file length = " << is.tellg() << std::endl;
           is.seekg(pos); // restore saved position

           // 1. version
           uint32_t version;
           is.read(reinterpret_cast<char *>(&version), sizeof(version));
-          // std::cout << " version: " << version << std::endl;

           // 2 Lod information
           uint64_t lod_level;
           is.read(reinterpret_cast<char *>(&lod_level), sizeof(lod_level));
-          // std::cout << " load level: " << lod_level << std::endl;
-          // std::cout << " lod info: " << std::endl;
           for (uint64_t i = 0; i < lod_level; ++i) {
             uint64_t size;
             is.read(reinterpret_cast<char *>(&size), sizeof(size));
...
@@ -297,83 +261,53 @@ Loader<Dtype, P>::Load(const std::string &dirname) {
             is.read(reinterpret_cast<char *>(tmp.data()),
                     static_cast<std::streamsize>(size));
             for (int j = 0; j < tmp.size(); ++j) {
-              // std::cout << " lod - " << tmp[j] << std::endl;
             }
           }

-          uint32_t tensor_version;
           is.read(reinterpret_cast<char *>(&version), sizeof(version));
-          // std::cout << " tensor_version: " << tensor_version << std::endl;
           int32_t size;
           is.read(reinterpret_cast<char *>(&size), sizeof(size));
-          // std::cout << " tensor desc size: " << size << std::endl;
           std::unique_ptr<char[]> buf(new char[size]);
           is.read(reinterpret_cast<char *>(buf.get()), size);
           framework::proto::VarType::TensorDesc desc;
           desc.ParseFromArray(buf.get(), size);
-          // std::cout << " desc dims size " << desc.dims().size() << std::endl;
           int memory_size = 1;
-          for (int l = 0; l < desc.dims().size(); ++l) {
-            // std::cout << " dim " << l << " value: " << desc.dims()[l]
-            // << std::endl;
-            memory_size *= desc.dims()[l];
+          for (long long l : desc.dims()) {
+            memory_size *= l;
           }
           int type_size = 0;
-          // std::cout << " desc pre type: ";
           switch (desc.data_type()) {
           case framework::proto::VarType::FP16:
-            // std::cout << "FP16" << std::endl;
            type_size = 2;
             break;
           case framework::proto::VarType::FP32:
             type_size = 4;
-            // std::cout << "FP32" << std::endl;
             break;
           case framework::proto::VarType::FP64:
             type_size = 8;
-            // std::cout << "FP64" << std::endl;
             break;
           case framework::proto::VarType::INT32:
             type_size = 4;
-            // std::cout << "INT32" << std::endl;
             break;
           case framework::proto::VarType::INT64:
             type_size = 8;
-            // std::cout << "INT64" << std::endl;
             break;
           case framework::proto::VarType::BOOL:
             type_size = 1;
-            // std::cout << "BOOL" << std::endl;
             break;
           default:
             break;
-            // std::cout << " not support" << std::endl;
           }
-          // std::cout << " malloc size: " << memory_size * type_size << std::endl;
           void *memory = malloc(memory_size * type_size);
           is.read(static_cast<char *>(memory), memory_size * type_size);
-          // std::cout << " memory: " << memory << std::endl;
           is.close();
         } else {
-          // std::cout << " *not load "
-          //           << " var : " << var.name() << std::endl;
+          // TODO
         }
       }
     }
   }
...
src/io.h
@@ -29,10 +29,10 @@ namespace paddle_mobile {

 template <typename Dtype, Precision P = Precision::FP32>
 class Loader : PaddleMobileObject {
 public:
  const framework::Program<Dtype, P> Load(const std::string &dirname);

 private:
  void LoadVar(framework::LoDTensor *tensor, const std::string &file_path);
};
...
src/memory/t_malloc.cc
src/memory/t_malloc.h
@@ -37,10 +37,11 @@ void Free(void *ptr);
 * std::unique_ptr<T> in tensor.h.
 * static_cast
 */
 template <typename T> class PODDeleter {
  static_assert(std::is_pod<T>::value, "T must be POD");

 public:
  explicit PODDeleter(){};
  void operator()(T *ptr) { Free(static_cast<void *>(ptr)); }
...
@@ -54,8 +55,9 @@ public:
 * std::unique_ptr<T> in tensor.h.
 * reinterpret_cast
 */
 template <typename T> class PlainDeleter {
 public:
  explicit PlainDeleter(){};
  void operator()(T *ptr) { Free(reinterpret_cast<void *>(ptr)); }
...
src/operators/batchnorm_op.cpp
src/operators/batchnorm_op.h
@@ -27,7 +27,7 @@ using namespace framework;

 template <typename DeviceType, typename T>
 class BatchNormOp : public framework::OperatorWithKernel<DeviceType> {
 public:
  BatchNormOp(const std::string &type, const VariableNameMap &inputs,
              const VariableNameMap &outputs,
              const framework::AttributeMap attrs,
...
@@ -44,7 +44,7 @@ public:
  using framework::OperatorWithKernel<DeviceType>::OperatorWithKernel;

  void InferShape() const override;

 protected:
  BatchNormParam param_;
};
...
src/operators/concat_op.cpp
src/operators/concat_op.h
@@ -26,7 +26,7 @@ using namespace framework;

 template <typename DeviceType, typename T>
 class ConcatOp : public framework::OperatorWithKernel<DeviceType> {
 public:
  ConcatOp(const std::string &type, const VariableNameMap &inputs,
           const VariableNameMap &outputs, const framework::AttributeMap attrs,
           std::shared_ptr<framework::Scope> scope)
...
@@ -42,7 +42,7 @@ public:
  using framework::OperatorWithKernel<DeviceType>::OperatorWithKernel;

  void InferShape() const override;

 protected:
  ConcatParam param_;
};
...
src/operators/conv_op.cpp
src/operators/conv_op.h
@@ -28,7 +28,7 @@ using namespace framework;

 template <typename DeviceType, typename T>
 class ConvOp : public framework::OperatorWithKernel<DeviceType> {
 public:
  ConvOp(const std::string &type, const VariableNameMap &inputs,
         const VariableNameMap &outputs, const framework::AttributeMap &attrs,
         std::shared_ptr<framework::Scope> scope)
...
@@ -45,7 +45,7 @@ public:
    this->ClearVariables({"Filter", "Input"});
  }

 private:
  ConvParam param_;
};
...
src/operators/elementwise_add_op.cpp
src/operators/elementwise_add_op.h
@@ -27,7 +27,7 @@ using namespace framework;

 template <typename DeviceType, typename T>
 class ElementwiseAddOp : public framework::OperatorWithKernel<DeviceType> {
 public:
  ElementwiseAddOp(const std::string &type, const VariableNameMap &inputs,
                   const VariableNameMap &outputs,
                   const framework::AttributeMap attrs,
...
@@ -44,7 +44,7 @@ public:
  using framework::OperatorWithKernel<DeviceType>::OperatorWithKernel;

  void InferShape() const override;

 protected:
  ElementwiseAddParam param_;
};
} // namespace operators
...
src/operators/kernel/arm/batchnorm_kernel.cpp
src/operators/kernel/arm/concat_kernel.cpp
@@ -18,8 +18,9 @@ limitations under the License. */

 namespace paddle_mobile {
 namespace operators {

 template <typename T> class ConcatFunctor {
 public:
  void operator()(const std::vector<framework::Tensor> &input, const int axis,
                  framework::Tensor *output) {
    size_t num = input.size();
...
src/operators/kernel/arm/conv_kernel.cpp
@@ -34,7 +34,8 @@ bool IsExpand(const std::vector<int64_t> &filter_dim,
  return !(filter_1 && strides_1 && padding_0 && dilation_1);
}

template <> void ConvKernel<CPU, float>::Compute(const ConvParam &param) const {
  LOG(kLOG_DEBUG) << param;
  const Tensor *input = param.Input();
...
src/operators/kernel/arm/elementwise_add_kernel.cpp
@@ -19,7 +19,8 @@ limitations under the License. */

 namespace paddle_mobile {
 namespace operators {

 template <typename T> struct AddFunctor {
  inline T operator()(T a, T b) const { return a + b; }
};
...
src/operators/kernel/arm/lrn_kernel.cpp
@@ -23,7 +23,8 @@ SOFTWARE.

 namespace paddle_mobile {
 namespace operators {

 template <> void LrnKernel<CPU, float>::Compute(const LrnParam &param) const {
  const Tensor *input_x = param.InputX();
  auto x_dims = input_x->dims();
  /// data_format = NCHW
...
src/operators/kernel/arm/mul_kernel.cpp
@@ -23,7 +23,8 @@ SOFTWARE.

 namespace paddle_mobile {
 namespace operators {

 template <> void MulKernel<CPU, float>::Compute(const MulParam &param) const {
  const Tensor *input_x = param.InputX();
  const Tensor *input_y = param.InputY();
  Tensor *out = param.Out();
...
The following changed files are collapsed in this view:
src/operators/kernel/arm/pool_kernel.cpp
src/operators/kernel/batchnorm_kernel.h
src/operators/kernel/concat_kernel.h
src/operators/kernel/conv_kernel.h
src/operators/kernel/elementwise_add_kernel.h
src/operators/kernel/fpga/conv_kernel.cpp
src/operators/kernel/lrn_kernel.h
src/operators/kernel/mul_kernel.h
src/operators/kernel/pool_kernel.h
src/operators/lrn_op.cpp
src/operators/lrn_op.h
src/operators/math/elementwise_op_function.h
src/operators/math/im2col.cc
src/operators/math/im2col.h
src/operators/math/math_function.cc
src/operators/math/math_function.h
src/operators/math/pool3x3.h
src/operators/math/pool_2x2.h
src/operators/math/pooling.cpp
src/operators/math/pooling.h
src/operators/math/transform.h
src/operators/math/vol2col.cc
src/operators/math/vol2col.h
src/operators/mul_op.cpp
src/operators/mul_op.h
src/operators/op_param.cpp
src/operators/op_param.h
src/operators/pool_op.cpp
src/operators/pool_op.h
src/platform/data_type.h
src/platform/macros.h
test/operators/test_cov_op.cpp