Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle-Lite
提交
7701ac0d
P
Paddle-Lite
项目概览
PaddlePaddle
/
Paddle-Lite
通知
331
Star
4
Fork
1
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
271
列表
看板
标记
里程碑
合并请求
78
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle-Lite
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
271
Issue
271
列表
看板
标记
里程碑
合并请求
78
合并请求
78
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
7701ac0d
编写于
3月 29, 2019
作者:
H
hjchen2
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Use enum type rather than string since comparing between enum is more efficent than string
上级
2852680a
变更
17
隐藏空白更改
内联
并排
Showing
17 changed file
with
175 addition
and
105 deletion
+175
-105
src/common/type_define.h
src/common/type_define.h
+79
-10
src/common/variant.h
src/common/variant.h
+11
-11
src/framework/attribute.h
src/framework/attribute.h
+12
-13
src/framework/data_type.cpp
src/framework/data_type.cpp
+10
-10
src/framework/data_type.h
src/framework/data_type.h
+4
-4
src/framework/ddim.h
src/framework/ddim.h
+10
-10
src/framework/load_ops.h
src/framework/load_ops.h
+3
-0
src/framework/mixed_vector.h
src/framework/mixed_vector.h
+1
-1
src/framework/tensor.h
src/framework/tensor.h
+19
-19
src/framework/tensor_base.h
src/framework/tensor_base.h
+10
-11
src/framework/variable.h
src/framework/variable.h
+8
-7
src/operators/kernel/arm/compare_kernel.cpp
src/operators/kernel/arm/compare_kernel.cpp
+2
-2
src/operators/kernel/arm/concat_kernel.cpp
src/operators/kernel/arm/concat_kernel.cpp
+1
-1
src/operators/kernel/arm/convolution/conv_common.cpp
src/operators/kernel/arm/convolution/conv_common.cpp
+1
-1
src/operators/kernel/arm/transpose2_kernel.cpp
src/operators/kernel/arm/transpose2_kernel.cpp
+2
-2
src/operators/kernel/central-arm-func/mul_arm_func.h
src/operators/kernel/central-arm-func/mul_arm_func.h
+1
-1
src/operators/kernel/central-arm-func/sum_arm_func.h
src/operators/kernel/central-arm-func/sum_arm_func.h
+1
-2
未找到文件。
src/common/type_define.h
浏览文件 @
7701ac0d
...
...
@@ -14,28 +14,85 @@ limitations under the License. */
#pragma once
#include <functional>
#include <string>
#include <vector>
namespace
paddle_mobile
{
typedef
enum
{
_void
=
0
,
_float
,
_int
,
_double
,
_int64_t
,
_size_t
,
_int16_t
,
_int8_t
,
_uint8_t
,
_bool
,
_string
,
_floats
=
100
,
_ints
,
_int64_ts
,
_size_ts
,
_bools
,
_strings
,
_const_float
=
200
,
_const_int
,
_block
=
300
,
_tensor
,
_lod_tensor
,
_blocks
,
_tensors
,
_lod_tensors
,
_p_block
=
400
,
_p_tensor
,
_p_lod_tensor
,
_p_blocks
,
_p_tensors
,
_p_lod_tensors
,
_scopes
=
500
,
_selected_rows
,
_dim0
=
600
,
_dim1
,
_dim2
,
_dim3
,
_dim4
,
_dim5
,
_dim6
,
_dim7
,
_dim8
,
_dim9
,
}
kTypeId_t
;
template
<
typename
T
>
struct
TypeIdWrapper
{
std
::
string
type
();
inline
std
::
string
name
();
inline
kTypeId_t
hash_code
();
};
template
<
typename
T
>
struct
type_id
{
const
std
::
string
type_
=
TypeIdWrapper
<
T
>
().
type
();
const
kTypeId_t
hash_code
()
const
{
return
TypeIdWrapper
<
T
>
().
hash_code
();
}
const
std
::
string
name
()
const
{
return
TypeIdWrapper
<
T
>
().
name
();
}
template
<
typename
OtherType
>
bool
operator
==
(
const
type_id
<
OtherType
>
&
operand
)
{
return
this
->
name
()
==
operand
.
nam
e
();
bool
operator
==
(
const
type_id
<
OtherType
>
&
operand
)
const
{
return
this
->
hash_code
()
==
operand
.
hash_cod
e
();
}
const
std
::
string
name
()
{
return
type_
;
}
};
template
<
typename
T
>
inline
bool
operator
==
(
const
kTypeId_t
&
t0
,
const
type_id
<
T
>
&
t1
)
{
return
t0
==
t1
.
hash_code
();
}
template
<
typename
T
>
inline
bool
operator
==
(
const
type_id
<
T
>
&
t0
,
const
kTypeId_t
&
t1
)
{
return
t1
==
t0
.
hash_code
();
}
namespace
framework
{
class
BlockDesc
;
class
Tensor
;
...
...
@@ -47,10 +104,11 @@ template <int>
struct
Dim
;
}
// namespace framework
#define REGISTER_TYPE_ID(Type, TypeName) \
template <> \
struct TypeIdWrapper<Type> { \
std::string type() { return std::string(#TypeName); } \
#define REGISTER_TYPE_ID(Type, TypeName) \
template <> \
struct TypeIdWrapper<Type> { \
inline std::string name() { return std::string(#TypeName); } \
inline kTypeId_t hash_code() { return kTypeId_t::TypeName; } \
};
REGISTER_TYPE_ID
(
void
,
_void
)
...
...
@@ -102,3 +160,14 @@ REGISTER_TYPE_ID(framework::Dim<8>, _dim8)
REGISTER_TYPE_ID
(
framework
::
Dim
<
9
>
,
_dim9
)
}
// namespace paddle_mobile
namespace
std
{
template
<
>
struct
hash
<
paddle_mobile
::
kTypeId_t
>
{
size_t
operator
()(
const
paddle_mobile
::
kTypeId_t
&
t
)
const
{
return
std
::
hash
<
int
>
{}(
static_cast
<
int
>
(
t
));
}
};
}
// namespace std
src/common/variant.h
浏览文件 @
7701ac0d
...
...
@@ -34,8 +34,8 @@ struct VariantHelper {
?
sizeof
(
F
)
:
VariantHelper
<
Ts
...
>::
size
;
inline
static
void
Destroy
(
std
::
string
type
,
void
*
data
)
{
if
(
type
==
type_id
<
F
>
()
.
name
()
)
{
inline
static
void
Destroy
(
kTypeId_t
type
,
void
*
data
)
{
if
(
type
==
type_id
<
F
>
())
{
reinterpret_cast
<
F
*>
(
data
)
->~
F
();
}
else
{
VariantHelper
<
Ts
...
>::
Destroy
(
type
,
data
);
...
...
@@ -46,8 +46,8 @@ struct VariantHelper {
template
<
typename
F
>
struct
VariantHelper
<
F
>
{
static
const
size_t
size
=
sizeof
(
F
);
inline
static
void
Destroy
(
std
::
string
type
,
void
*
data
)
{
if
(
type
==
type_id
<
F
>
()
.
name
()
)
{
inline
static
void
Destroy
(
kTypeId_t
type
,
void
*
data
)
{
if
(
type
==
type_id
<
F
>
())
{
// reinterpret_cast<F*>(data)->~F();
}
else
{
// std::cout << "未匹配到 " << std::endl;
...
...
@@ -85,17 +85,17 @@ struct Variant {
void
Set
(
Args
&&
...
args
)
{
helper
::
Destroy
(
type_
,
data_
.
data
);
new
(
data_
.
data
)
T
(
std
::
forward
<
Args
>
(
args
)...);
type_
=
type_id
<
T
>
().
nam
e
();
type_
=
type_id
<
T
>
().
hash_cod
e
();
}
void
SetString
(
const
std
::
string
&
string
)
{
helper
::
Destroy
(
type_
,
data_
.
data
);
type_
=
type_id
<
std
::
string
>
().
nam
e
();
type_
=
type_id
<
std
::
string
>
().
hash_cod
e
();
strcpy
(
data_
.
data
,
string
.
c_str
());
// NOLINT
}
std
::
string
GetString
()
const
{
if
(
type_
==
type_id
<
std
::
string
>
()
.
name
()
)
{
if
(
type_
==
type_id
<
std
::
string
>
())
{
return
std
::
string
(
data_
.
data
);
}
else
{
PADDLE_MOBILE_THROW_EXCEPTION
(
...
...
@@ -106,7 +106,7 @@ struct Variant {
template
<
typename
T
>
T
&
Get
()
const
{
if
(
type_
==
type_id
<
std
::
string
>
()
.
name
()
)
{
if
(
type_
==
type_id
<
std
::
string
>
())
{
PADDLE_MOBILE_THROW_EXCEPTION
(
"Please use getString to get an string (to avoid of an issue with "
"gcc "
...
...
@@ -117,12 +117,12 @@ struct Variant {
}
}
std
::
string
TypeId
()
const
{
return
type_
;
}
kTypeId_t
TypeId
()
const
{
return
type_
;
}
private:
static
inline
std
::
string
invalid_type
()
{
return
type_id
<
void
>
().
nam
e
();
}
static
inline
kTypeId_t
invalid_type
()
{
return
type_id
<
void
>
().
hash_cod
e
();
}
typedef
VariantHelper
<
Ts
...
>
helper
;
std
::
string
type_
=
type_id
<
void
>
().
nam
e
();
kTypeId_t
type_
=
type_id
<
void
>
().
hash_cod
e
();
// todo use an anto size to suite this.
RawData
<
64
>
data_
;
};
...
...
src/framework/attribute.h
浏览文件 @
7701ac0d
...
...
@@ -128,31 +128,30 @@ class Attribute {
template
<
typename
Vistor
>
static
typename
Vistor
::
type_t
ApplyVistor
(
Vistor
vistor
,
Attribute
attr
)
{
if
(
attr
.
variant_
.
TypeId
()
==
type_id
<
int
>
()
.
name
()
)
{
// NOLINT
if
(
attr
.
variant_
.
TypeId
()
==
type_id
<
int
>
())
{
// NOLINT
return
vistor
(
attr
.
variant_
.
Get
<
int
>
());
}
else
if
(
attr
.
variant_
.
TypeId
()
==
type_id
<
float
>
()
.
name
()
)
{
// NOLINT
}
else
if
(
attr
.
variant_
.
TypeId
()
==
type_id
<
float
>
())
{
// NOLINT
return
vistor
(
attr
.
variant_
.
Get
<
float
>
());
}
else
if
(
attr
.
variant_
.
TypeId
()
==
type_id
<
string
>
()
.
name
()
)
{
}
else
if
(
attr
.
variant_
.
TypeId
()
==
type_id
<
string
>
())
{
return
vistor
(
attr
.
variant_
.
GetString
());
}
else
if
(
attr
.
variant_
.
TypeId
()
==
type_id
<
vector
<
int
>>
()
.
name
()
)
{
}
else
if
(
attr
.
variant_
.
TypeId
()
==
type_id
<
vector
<
int
>>
())
{
return
vistor
(
attr
.
variant_
.
Get
<
vector
<
int
>>
());
}
else
if
(
attr
.
variant_
.
TypeId
()
==
type_id
<
vector
<
float
>>
()
.
name
()
)
{
}
else
if
(
attr
.
variant_
.
TypeId
()
==
type_id
<
vector
<
float
>>
())
{
return
vistor
(
attr
.
variant_
.
Get
<
vector
<
float
>>
());
}
else
if
(
attr
.
variant_
.
TypeId
()
==
type_id
<
vector
<
string
>>
()
.
name
()
)
{
}
else
if
(
attr
.
variant_
.
TypeId
()
==
type_id
<
vector
<
string
>>
())
{
return
vistor
(
attr
.
variant_
.
Get
<
vector
<
string
>>
());
}
else
if
(
attr
.
variant_
.
TypeId
()
==
type_id
<
bool
>
()
.
name
()
)
{
// NOLINT
}
else
if
(
attr
.
variant_
.
TypeId
()
==
type_id
<
bool
>
())
{
// NOLINT
return
vistor
(
attr
.
variant_
.
Get
<
bool
>
());
}
else
if
(
attr
.
variant_
.
TypeId
()
==
type_id
<
vector
<
bool
>>
()
.
name
()
)
{
}
else
if
(
attr
.
variant_
.
TypeId
()
==
type_id
<
vector
<
bool
>>
())
{
return
vistor
(
attr
.
variant_
.
Get
<
vector
<
bool
>>
());
}
else
if
(
attr
.
variant_
.
TypeId
()
==
type_id
<
int64_t
>
()
.
name
()
)
{
}
else
if
(
attr
.
variant_
.
TypeId
()
==
type_id
<
int64_t
>
())
{
return
vistor
(
attr
.
variant_
.
Get
<
int64_t
>
());
}
else
if
(
attr
.
variant_
.
TypeId
()
==
type_id
<
framework
::
BlockDesc
*>
().
name
())
{
}
else
if
(
attr
.
variant_
.
TypeId
()
==
type_id
<
framework
::
BlockDesc
*>
())
{
return
vistor
(
attr
.
variant_
.
Get
<
framework
::
BlockDesc
*>
());
}
else
if
(
attr
.
variant_
.
TypeId
()
==
type_id
<
vector
<
framework
::
BlockDesc
*>>
()
.
name
()
)
{
type_id
<
vector
<
framework
::
BlockDesc
*>>
())
{
return
vistor
(
attr
.
variant_
.
Get
<
vector
<
framework
::
BlockDesc
*>>
());
}
else
if
(
attr
.
variant_
.
TypeId
()
==
type_id
<
vector
<
int64_t
>>
()
.
name
()
)
{
}
else
if
(
attr
.
variant_
.
TypeId
()
==
type_id
<
vector
<
int64_t
>>
())
{
return
vistor
(
attr
.
variant_
.
Get
<
vector
<
int64_t
>>
());
}
else
{
PADDLE_MOBILE_THROW_EXCEPTION
(
"type not support"
);
...
...
src/framework/data_type.cpp
浏览文件 @
7701ac0d
...
...
@@ -22,12 +22,11 @@ namespace paddle_mobile {
namespace
framework
{
struct
DataTypeMap
{
std
::
unordered_map
<
std
::
string
,
_PaddleMobile__Framework__Proto__VarType__Type
>
std
::
unordered_map
<
kTypeId_t
,
_PaddleMobile__Framework__Proto__VarType__Type
>
cpp_to_proto_
;
std
::
unordered_map
<
int
,
std
::
string
>
proto_to_cpp_
;
std
::
unordered_map
<
int
,
kTypeId_t
>
proto_to_cpp_
;
std
::
unordered_map
<
int
,
std
::
string
>
proto_to_str_
;
std
::
unordered_map
<
std
::
string
,
size_t
>
cpp_to_size_
;
std
::
unordered_map
<
kTypeId_t
,
size_t
>
cpp_to_size_
;
};
static
DataTypeMap
*
InitDataTypeMap
();
...
...
@@ -43,10 +42,11 @@ template <typename T>
static
inline
void
RegisterType
(
DataTypeMap
*
map
,
_PaddleMobile__Framework__Proto__VarType__Type
proto_type
,
const
std
::
string
&
name
)
{
map
->
proto_to_cpp_
.
emplace
(
static_cast
<
int
>
(
proto_type
),
type_id
<
T
>
().
name
());
map
->
cpp_to_proto_
.
emplace
(
type_id
<
T
>
().
name
(),
proto_type
);
map
->
proto_to_cpp_
.
emplace
(
static_cast
<
int
>
(
proto_type
),
type_id
<
T
>
().
hash_code
());
map
->
cpp_to_proto_
.
emplace
(
type_id
<
T
>
().
hash_code
(),
proto_type
);
map
->
proto_to_str_
.
emplace
(
static_cast
<
int
>
(
proto_type
),
name
);
map
->
cpp_to_size_
.
emplace
(
type_id
<
T
>
().
nam
e
(),
sizeof
(
T
));
map
->
cpp_to_size_
.
emplace
(
type_id
<
T
>
().
hash_cod
e
(),
sizeof
(
T
));
}
static
DataTypeMap
*
InitDataTypeMap
()
{
...
...
@@ -71,15 +71,15 @@ static DataTypeMap* InitDataTypeMap() {
return
retv
;
}
_PaddleMobile__Framework__Proto__VarType__Type
ToDataType
(
std
::
string
type
)
{
_PaddleMobile__Framework__Proto__VarType__Type
ToDataType
(
kTypeId_t
type
)
{
auto
it
=
gDataTypeMap
().
cpp_to_proto_
.
find
(
type
);
if
(
it
!=
gDataTypeMap
().
cpp_to_proto_
.
end
())
{
return
it
->
second
;
}
PADDLE_MOBILE_THROW_EXCEPTION
(
"Not support %
s as tensor type"
,
type
.
c_str
()
);
PADDLE_MOBILE_THROW_EXCEPTION
(
"Not support %
d as tensor type"
,
type
);
}
std
::
string
ToTypeIndex
(
_PaddleMobile__Framework__Proto__VarType__Type
type
)
{
kTypeId_t
ToTypeIndex
(
_PaddleMobile__Framework__Proto__VarType__Type
type
)
{
auto
it
=
gDataTypeMap
().
proto_to_cpp_
.
find
(
static_cast
<
int
>
(
type
));
if
(
it
!=
gDataTypeMap
().
proto_to_cpp_
.
end
())
{
return
it
->
second
;
...
...
src/framework/data_type.h
浏览文件 @
7701ac0d
...
...
@@ -16,16 +16,16 @@ limitations under the License. */
#include <string>
#include "common/enforce.h"
#include "common/type_define.h"
#include "framework/framework.pb-c.h"
namespace
paddle_mobile
{
namespace
framework
{
extern
_PaddleMobile__Framework__Proto__VarType__Type
ToDataType
(
std
::
string
type
);
extern
std
::
string
ToTypeIndex
(
_PaddleMobile__Framework__Proto__VarType__Type
type
);
_PaddleMobile__Framework__Proto__VarType__Type
ToDataType
(
kTypeId_t
type
);
kTypeId_t
ToTypeIndex
(
_PaddleMobile__Framework__Proto__VarType__Type
type
);
inline
_PaddleMobile__Framework__Proto__VarType__Type
ToDataType
(
int
type
)
{
return
static_cast
<
_PaddleMobile__Framework__Proto__VarType__Type
>
(
type
);
...
...
src/framework/ddim.h
浏览文件 @
7701ac0d
...
...
@@ -40,25 +40,25 @@ struct DDim {
template
<
typename
Vistor
>
static
typename
Vistor
::
type_t
ApplyVistor
(
Vistor
vistor
,
const
DDim
&
d
)
{
if
(
d
.
var
.
TypeId
()
==
type_id
<
Dim
<
0
>>
()
.
name
()
)
{
if
(
d
.
var
.
TypeId
()
==
type_id
<
Dim
<
0
>>
())
{
return
vistor
(
d
.
var
.
Get
<
Dim
<
0
>>
());
}
else
if
(
d
.
var
.
TypeId
()
==
type_id
<
Dim
<
1
>>
()
.
name
()
)
{
}
else
if
(
d
.
var
.
TypeId
()
==
type_id
<
Dim
<
1
>>
())
{
return
vistor
(
d
.
var
.
Get
<
Dim
<
1
>>
());
}
else
if
(
d
.
var
.
TypeId
()
==
type_id
<
Dim
<
2
>>
()
.
name
()
)
{
}
else
if
(
d
.
var
.
TypeId
()
==
type_id
<
Dim
<
2
>>
())
{
return
vistor
(
d
.
var
.
Get
<
Dim
<
2
>>
());
}
else
if
(
d
.
var
.
TypeId
()
==
type_id
<
Dim
<
3
>>
()
.
name
()
)
{
}
else
if
(
d
.
var
.
TypeId
()
==
type_id
<
Dim
<
3
>>
())
{
return
vistor
(
d
.
var
.
Get
<
Dim
<
3
>>
());
}
else
if
(
d
.
var
.
TypeId
()
==
type_id
<
Dim
<
4
>>
()
.
name
()
)
{
}
else
if
(
d
.
var
.
TypeId
()
==
type_id
<
Dim
<
4
>>
())
{
return
vistor
(
d
.
var
.
Get
<
Dim
<
4
>>
());
}
else
if
(
d
.
var
.
TypeId
()
==
type_id
<
Dim
<
5
>>
()
.
name
()
)
{
}
else
if
(
d
.
var
.
TypeId
()
==
type_id
<
Dim
<
5
>>
())
{
return
vistor
(
d
.
var
.
Get
<
Dim
<
5
>>
());
}
else
if
(
d
.
var
.
TypeId
()
==
type_id
<
Dim
<
6
>>
()
.
name
()
)
{
}
else
if
(
d
.
var
.
TypeId
()
==
type_id
<
Dim
<
6
>>
())
{
return
vistor
(
d
.
var
.
Get
<
Dim
<
6
>>
());
}
else
if
(
d
.
var
.
TypeId
()
==
type_id
<
Dim
<
7
>>
()
.
name
()
)
{
}
else
if
(
d
.
var
.
TypeId
()
==
type_id
<
Dim
<
7
>>
())
{
return
vistor
(
d
.
var
.
Get
<
Dim
<
7
>>
());
}
else
if
(
d
.
var
.
TypeId
()
==
type_id
<
Dim
<
8
>>
()
.
name
()
)
{
}
else
if
(
d
.
var
.
TypeId
()
==
type_id
<
Dim
<
8
>>
())
{
return
vistor
(
d
.
var
.
Get
<
Dim
<
8
>>
());
}
else
if
(
d
.
var
.
TypeId
()
==
type_id
<
Dim
<
9
>>
()
.
name
()
)
{
}
else
if
(
d
.
var
.
TypeId
()
==
type_id
<
Dim
<
9
>>
())
{
return
vistor
(
d
.
var
.
Get
<
Dim
<
9
>>
());
}
else
{
PADDLE_MOBILE_ENFORCE
(
false
,
" dim not support"
);
...
...
src/framework/load_ops.h
浏览文件 @
7701ac0d
...
...
@@ -268,6 +268,9 @@ LOAD_OP1(sequence_expand, CPU);
#ifdef SEQUENCE_POOL_OP
LOAD_OP1
(
sequence_pool
,
CPU
);
#endif
#ifdef SEQUENCE_SOFTMAX_OP
LOAD_OP1
(
sequence_softmax
,
CPU
);
#endif
#ifdef LOG_OP
LOAD_OP1
(
log
,
CPU
);
#endif
...
...
src/framework/mixed_vector.h
浏览文件 @
7701ac0d
...
...
@@ -197,7 +197,7 @@ class Vector {
}
size_t
capacity
()
const
{
return
cpu_vec_
.
memory_size
()
/
SizeOfType
(
type_id
<
T
>
().
nam
e
());
return
cpu_vec_
.
memory_size
()
/
SizeOfType
(
type_id
<
T
>
().
hash_cod
e
());
}
// reserve data
...
...
src/framework/tensor.h
浏览文件 @
7701ac0d
...
...
@@ -81,7 +81,7 @@ class Tensor : public TensorBase {
return
*
this
;
}
inline
void
*
mutable_data
(
const
std
::
string
type
)
{
inline
void
*
mutable_data
(
const
kTypeId_t
type
)
{
if
(
holder_
!=
nullptr
)
{
holder_
->
set_type
(
type
);
}
...
...
@@ -106,7 +106,7 @@ class Tensor : public TensorBase {
template
<
typename
T
>
inline
T
*
mutable_data
()
{
static_assert
(
std
::
is_pod
<
T
>::
value
,
"T must be POD"
);
return
reinterpret_cast
<
T
*>
(
mutable_data
(
type_id
<
T
>
().
nam
e
()));
return
reinterpret_cast
<
T
*>
(
mutable_data
(
type_id
<
T
>
().
hash_cod
e
()));
}
/**
...
...
@@ -163,9 +163,9 @@ class Tensor : public TensorBase {
check_memory_size
();
PADDLE_MOBILE_ENFORCE
(
(
std
::
is_same
<
T
,
void
>::
value
||
holder_
->
type
()
==
type_id
<
T
>
().
nam
e
()),
"Tensor holds the wrong type, it holds %
s, requested %s
"
,
this
->
holder_
->
type
()
.
c_str
(),
type_id
<
T
>
().
name
().
c_str
());
holder_
->
type
()
==
type_id
<
T
>
().
hash_cod
e
()),
"Tensor holds the wrong type, it holds %
d, requested %d
"
,
this
->
holder_
->
type
()
,
type_id
<
T
>
().
hash_code
());
return
reinterpret_cast
<
T
*>
(
reinterpret_cast
<
uintptr_t
>
(
holder_
->
ptr
())
+
offset_
);
...
...
@@ -177,9 +177,9 @@ class Tensor : public TensorBase {
check_memory_size
();
PADDLE_MOBILE_ENFORCE
(
(
std
::
is_same
<
T
,
void
>::
value
||
holder_
->
type
()
==
type_id
<
T
>
().
nam
e
()),
"Tensor holds the wrong type, it holds %
s, requested %s
"
,
this
->
holder_
->
type
()
.
c_str
(),
type_id
<
T
>
().
name
().
c_str
());
holder_
->
type
()
==
type_id
<
T
>
().
hash_cod
e
()),
"Tensor holds the wrong type, it holds %
d, requested %d
"
,
this
->
holder_
->
type
()
,
type_id
<
T
>
().
hash_code
());
return
reinterpret_cast
<
const
T
*>
(
reinterpret_cast
<
uintptr_t
>
(
holder_
->
ptr
())
+
offset_
);
...
...
@@ -187,7 +187,7 @@ class Tensor : public TensorBase {
private:
struct
PlaceholderImpl
:
public
Placeholder
{
PlaceholderImpl
(
size_t
size
,
const
std
::
string
type
)
PlaceholderImpl
(
size_t
size
,
const
kTypeId_t
type
)
:
ptr_
(
static_cast
<
uint8_t
*>
(
memory
::
Alloc
(
size
)),
memory
::
PODDeleter
<
uint8_t
>
()),
size_
(
size
),
...
...
@@ -201,9 +201,9 @@ class Tensor : public TensorBase {
virtual
void
*
ptr
()
const
{
return
static_cast
<
void
*>
(
ptr_
.
get
());
}
virtual
std
::
string
type
()
const
{
return
type_
;
}
virtual
kTypeId_t
type
()
const
{
return
type_
;
}
virtual
void
set_type
(
const
std
::
string
type
)
{
type_
=
type
;
}
virtual
void
set_type
(
const
kTypeId_t
type
)
{
type_
=
type
;
}
virtual
void
resize
(
size_t
size
)
{
if
(
size
>
capatity_
)
{
...
...
@@ -221,7 +221,7 @@ class Tensor : public TensorBase {
size_t
capatity_
;
/* the current type of memory */
std
::
string
type_
;
kTypeId_t
type_
;
};
#ifdef PADDLE_MOBILE_FPGA
...
...
@@ -229,13 +229,13 @@ class Tensor : public TensorBase {
inline
void
reset_data_ptr
(
void
*
p
)
{
((
PlaceholderImpl
*
)(
holder_
.
get
()))
->
ptr_
.
reset
((
uint8_t
*
)
p
);
// NOLINT
}
inline
void
set_type
(
const
std
::
string
type
)
{
holder_
->
set_type
(
type
);
}
inline
void
set_type
(
const
kTypeId_t
type
)
{
holder_
->
set_type
(
type
);
}
inline
void
*
get_data
()
{
return
(
void
*
)(((
PlaceholderImpl
*
)(
holder_
.
get
()))
->
ptr_
.
get
());
// NOLINT
}
inline
void
*
init
(
const
std
::
string
type
)
{
inline
void
*
init
(
const
kTypeId_t
type
)
{
if
(
holder_
!=
nullptr
)
{
holder_
->
set_type
(
type
);
}
...
...
@@ -263,15 +263,15 @@ inline Print &operator<<(Print &printer, const Tensor &tensor) {
stride
=
stride
>
0
?
stride
:
1
;
#ifndef PADDLE_MOBILE_FPGA
for
(
int
i
=
0
;
i
<
tensor
.
numel
();
i
+=
stride
)
{
if
(
tensor
.
type
()
==
type_id
<
float
>
()
.
name
()
)
{
if
(
tensor
.
type
()
==
type_id
<
float
>
())
{
printer
<<
tensor
.
data
<
float
>
()[
i
]
<<
" "
;
}
else
if
(
tensor
.
type
()
==
type_id
<
int32_t
>
()
.
name
()
)
{
}
else
if
(
tensor
.
type
()
==
type_id
<
int32_t
>
())
{
printer
<<
tensor
.
data
<
int32_t
>
()[
i
]
<<
" "
;
}
else
if
(
tensor
.
type
()
==
type_id
<
int64_t
>
()
.
name
()
)
{
}
else
if
(
tensor
.
type
()
==
type_id
<
int64_t
>
())
{
printer
<<
tensor
.
data
<
int64_t
>
()[
i
]
<<
" "
;
}
else
if
(
tensor
.
type
()
==
type_id
<
int8_t
>
()
.
name
()
)
{
}
else
if
(
tensor
.
type
()
==
type_id
<
int8_t
>
())
{
printer
<<
static_cast
<
int
>
(
tensor
.
data
<
int8_t
>
()[
i
])
<<
" "
;
}
else
if
(
tensor
.
type
()
==
type_id
<
int32_t
>
()
.
name
()
)
{
}
else
if
(
tensor
.
type
()
==
type_id
<
int32_t
>
())
{
printer
<<
tensor
.
data
<
int32_t
>
()[
i
]
<<
" "
;
}
}
...
...
src/framework/tensor_base.h
浏览文件 @
7701ac0d
...
...
@@ -14,8 +14,8 @@ limitations under the License. */
#pragma once
#include <string>
#include "common/enforce.h"
#include "common/type_define.h"
#include "common/types.h"
#include "framework/ddim.h"
...
...
@@ -27,8 +27,8 @@ struct SizeOfTypeFunctor;
template
<
typename
T
>
struct
SizeOfTypeFunctor
<
T
>
{
size_t
operator
()(
const
std
::
string
type
)
const
{
if
(
type_id
<
T
>
().
nam
e
()
==
type
)
{
size_t
operator
()(
const
kTypeId_t
type
)
const
{
if
(
type_id
<
T
>
().
hash_cod
e
()
==
type
)
{
return
sizeof
(
T
);
}
else
{
return
0UL
;
...
...
@@ -38,12 +38,12 @@ struct SizeOfTypeFunctor<T> {
template
<
>
struct
SizeOfTypeFunctor
<>
{
size_t
operator
()(
const
std
::
string
type
)
const
{
return
0UL
;
}
size_t
operator
()(
const
kTypeId_t
type
)
const
{
return
0UL
;
}
};
template
<
typename
HEAD
,
typename
...
TAIL
>
struct
SizeOfTypeFunctor
<
HEAD
,
TAIL
...
>
{
size_t
operator
()(
const
std
::
string
type
)
const
{
size_t
operator
()(
const
kTypeId_t
type
)
const
{
SizeOfTypeFunctor
<
HEAD
>
head
;
size_t
head_size
=
head
(
type
);
if
(
head_size
!=
0
)
{
...
...
@@ -54,14 +54,13 @@ struct SizeOfTypeFunctor<HEAD, TAIL...> {
}
};
static
inline
size_t
SizeOfType
(
std
::
string
type
)
{
static
inline
size_t
SizeOfType
(
const
kTypeId_t
type
)
{
SizeOfTypeFunctor
<
int8_t
,
int
,
half
,
float
,
double
,
int16_t
,
int64_t
,
bool
,
size_t
>
functor
;
size_t
size
=
functor
(
type
);
PADDLE_MOBILE_ENFORCE
(
size
!=
0UL
,
"Cannot get size of type %s"
,
type
.
c_str
());
PADDLE_MOBILE_ENFORCE
(
size
!=
0UL
,
"Cannot get size of type %d"
,
type
);
return
size
;
}
...
...
@@ -77,7 +76,7 @@ class TensorBase {
/*! Return the numel of the memory block. */
inline
int64_t
numel
()
const
{
return
product
(
dims_
);
}
std
::
string
type
()
const
{
kTypeId_t
type
()
const
{
PADDLE_MOBILE_ENFORCE
(
holder_
!=
nullptr
,
"Tensor not initialized yet when Tensor::type() is called."
)
...
...
@@ -113,9 +112,9 @@ class TensorBase {
virtual
size_t
size
()
const
=
0
;
virtual
std
::
string
type
()
const
=
0
;
virtual
kTypeId_t
type
()
const
=
0
;
virtual
void
set_type
(
std
::
string
type
)
=
0
;
virtual
void
set_type
(
kTypeId_t
type
)
=
0
;
virtual
void
resize
(
size_t
size
)
=
0
;
};
...
...
src/framework/variable.h
浏览文件 @
7701ac0d
...
...
@@ -30,7 +30,7 @@ class Variable {
template
<
typename
T
>
const
T
GetValue
()
const
{
if
(
type_id
<
T
>
().
name
()
==
type_id
<
std
::
string
>
().
nam
e
())
{
if
(
type_id
<
T
>
().
hash_code
()
==
type_id
<
std
::
string
>
().
hash_cod
e
())
{
PADDLE_MOBILE_THROW_EXCEPTION
(
"Please use getString to get an string (to avoid of an issue with "
"gcc "
...
...
@@ -57,31 +57,32 @@ class Variable {
template
<
typename
T
>
bool
IsType
()
const
{
return
holder_
!=
nullptr
&&
holder_
->
Type
()
==
type_id
<
T
>
().
nam
e
();
return
holder_
!=
nullptr
&&
holder_
->
Type
()
==
type_id
<
T
>
().
hash_cod
e
();
}
void
Clear
()
{
holder_
.
reset
();
}
std
::
string
Type
()
const
{
return
holder_
->
Type
();
}
kTypeId_t
Type
()
const
{
return
holder_
->
Type
();
}
private:
struct
Placeholder
{
Placeholder
()
=
default
;
virtual
~
Placeholder
()
=
default
;
virtual
std
::
string
Type
()
const
=
0
;
virtual
kTypeId_t
Type
()
const
=
0
;
virtual
void
*
Ptr
()
const
=
0
;
};
template
<
typename
T
>
struct
PlaceholderImp
:
public
Placeholder
{
explicit
PlaceholderImp
(
T
*
ptr
)
:
ptr_
(
ptr
),
type_
(
type_id
<
T
>
().
name
())
{}
explicit
PlaceholderImp
(
T
*
ptr
)
:
ptr_
(
ptr
),
type_
(
type_id
<
T
>
().
hash_code
())
{}
std
::
string
Type
()
const
override
{
return
type_
;
}
kTypeId_t
Type
()
const
override
{
return
type_
;
}
void
*
Ptr
()
const
override
{
return
static_cast
<
void
*>
(
ptr_
.
get
());
}
std
::
unique_ptr
<
T
>
ptr_
;
std
::
string
type_
;
kTypeId_t
type_
;
};
friend
class
Scope
;
...
...
src/operators/kernel/arm/compare_kernel.cpp
浏览文件 @
7701ac0d
...
...
@@ -192,10 +192,10 @@ bool LessThanKernel<CPU, float>::Init(CompareParam<CPU> *param) {
template
<
>
void
LessThanKernel
<
CPU
,
float
>::
Compute
(
const
CompareParam
<
CPU
>
&
param
)
{
if
(
param
.
input_x_
->
type
()
==
type_id
<
int64_t
>
().
nam
e
())
{
if
(
param
.
input_x_
->
type
()
==
type_id
<
int64_t
>
().
hash_cod
e
())
{
CompareCompute
<
int64_t
,
LESS_THAN
>
()(
param
.
input_x_
,
param
.
input_y_
,
param
.
axis_
,
param
.
output_
);
}
else
if
(
param
.
input_x_
->
type
()
==
type_id
<
float
>
().
nam
e
())
{
}
else
if
(
param
.
input_x_
->
type
()
==
type_id
<
float
>
().
hash_cod
e
())
{
CompareCompute
<
float
,
LESS_THAN
>
()(
param
.
input_x_
,
param
.
input_y_
,
param
.
axis_
,
param
.
output_
);
}
else
{
...
...
src/operators/kernel/arm/concat_kernel.cpp
浏览文件 @
7701ac0d
...
...
@@ -27,7 +27,7 @@ bool ConcatKernel<CPU, float>::Init(ConcatParam<CPU> *param) {
template
<
>
void
ConcatKernel
<
CPU
,
float
>::
Compute
(
const
ConcatParam
<
CPU
>
&
param
)
{
if
(
param
.
Inputs
()[
0
]
->
type
()
==
type_id
<
int8_t
>
().
nam
e
())
{
if
(
param
.
Inputs
()[
0
]
->
type
()
==
type_id
<
int8_t
>
().
hash_cod
e
())
{
ConcatCompute
<
int8_t
>
(
param
);
}
else
{
ConcatCompute
<
float
>
(
param
);
...
...
src/operators/kernel/arm/convolution/conv_common.cpp
浏览文件 @
7701ac0d
...
...
@@ -28,7 +28,7 @@ void InitBaseConvKernel(ConvParam<CPU> *param) {
bool
depth5x5
=
conv5x5
&&
param
->
Groups
()
==
param
->
Input
()
->
dims
()[
1
]
&&
param
->
Input
()
->
dims
()[
1
]
==
param
->
Output
()
->
dims
()[
1
];
if
(
param
->
Filter
()
->
type
()
==
type_id
<
int8_t
>
().
nam
e
())
{
if
(
param
->
Filter
()
->
type
()
==
type_id
<
int8_t
>
().
hash_cod
e
())
{
#ifndef __aarch64__
if
(
depth3x3
&&
param
->
Strides
()[
0
]
<
3
&&
param
->
Strides
()[
0
]
==
param
->
Strides
()[
1
])
{
...
...
src/operators/kernel/arm/transpose2_kernel.cpp
浏览文件 @
7701ac0d
...
...
@@ -126,13 +126,13 @@ void Transpose2Kernel<CPU, float>::Compute(const Transpose2Param<CPU> ¶m) {
const
std
::
vector
<
int
>
&
axis
=
param
.
Axis
();
bool
shuffle_channel
=
IsShuffleChannel
(
axis
);
if
(
shuffle_channel
)
{
if
(
param
.
InputX
()
->
type
()
==
type_id
<
int8_t
>
().
nam
e
())
{
if
(
param
.
InputX
()
->
type
()
==
type_id
<
int8_t
>
().
hash_cod
e
())
{
ShuffleChannelCompute
<
int8_t
>
(
param
);
}
else
{
ShuffleChannelCompute
<
float
>
(
param
);
}
}
else
{
if
(
param
.
InputX
()
->
type
()
==
type_id
<
int8_t
>
().
nam
e
())
{
if
(
param
.
InputX
()
->
type
()
==
type_id
<
int8_t
>
().
hash_cod
e
())
{
Transpose2Compute
<
int8_t
>
(
param
);
}
else
{
Transpose2Compute
<
float
>
(
param
);
...
...
src/operators/kernel/central-arm-func/mul_arm_func.h
浏览文件 @
7701ac0d
...
...
@@ -37,7 +37,7 @@ void MulCompute(const MulParam<CPU> ¶m) {
if
(
out_dim
.
size
()
!=
2
)
{
out
->
Resize
({
x_matrix
.
dims
()[
0
],
y_matrix
.
dims
()[
1
]});
}
if
(
param
.
InputX
()
->
type
()
==
type_id
<
int8_t
>
().
nam
e
())
{
if
(
param
.
InputX
()
->
type
()
==
type_id
<
int8_t
>
().
hash_cod
e
())
{
out
->
mutable_data
<
int32_t
>
();
math
::
MatMul
<
int8_t
,
int32_t
>
(
x_matrix
,
false
,
y_matrix
,
false
,
static_cast
<
float
>
(
1
),
out
,
...
...
src/operators/kernel/central-arm-func/sum_arm_func.h
浏览文件 @
7701ac0d
...
...
@@ -144,8 +144,7 @@ void SumCompute(const SumParam<CPU> ¶m) {
}
}
else
{
PADDLE_MOBILE_THROW_EXCEPTION
(
"Unexpected branch, output variable type is %s"
,
outvar
->
Type
().
c_str
());
"Unexpected branch, output variable type is %d"
,
outvar
->
Type
());
}
}
}
// namespace operators
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录