Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle-Lite
提交
7701ac0d
P
Paddle-Lite
项目概览
PaddlePaddle
/
Paddle-Lite
通知
331
Star
4
Fork
1
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
271
列表
看板
标记
里程碑
合并请求
78
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle-Lite
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
271
Issue
271
列表
看板
标记
里程碑
合并请求
78
合并请求
78
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
7701ac0d
编写于
3月 29, 2019
作者:
H
hjchen2
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Use enum type rather than string since comparing between enum is more efficent than string
上级
2852680a
变更
17
显示空白变更内容
内联
并排
Showing
17 changed file
with
175 addition
and
105 deletion
+175
-105
src/common/type_define.h
src/common/type_define.h
+79
-10
src/common/variant.h
src/common/variant.h
+11
-11
src/framework/attribute.h
src/framework/attribute.h
+12
-13
src/framework/data_type.cpp
src/framework/data_type.cpp
+10
-10
src/framework/data_type.h
src/framework/data_type.h
+4
-4
src/framework/ddim.h
src/framework/ddim.h
+10
-10
src/framework/load_ops.h
src/framework/load_ops.h
+3
-0
src/framework/mixed_vector.h
src/framework/mixed_vector.h
+1
-1
src/framework/tensor.h
src/framework/tensor.h
+19
-19
src/framework/tensor_base.h
src/framework/tensor_base.h
+10
-11
src/framework/variable.h
src/framework/variable.h
+8
-7
src/operators/kernel/arm/compare_kernel.cpp
src/operators/kernel/arm/compare_kernel.cpp
+2
-2
src/operators/kernel/arm/concat_kernel.cpp
src/operators/kernel/arm/concat_kernel.cpp
+1
-1
src/operators/kernel/arm/convolution/conv_common.cpp
src/operators/kernel/arm/convolution/conv_common.cpp
+1
-1
src/operators/kernel/arm/transpose2_kernel.cpp
src/operators/kernel/arm/transpose2_kernel.cpp
+2
-2
src/operators/kernel/central-arm-func/mul_arm_func.h
src/operators/kernel/central-arm-func/mul_arm_func.h
+1
-1
src/operators/kernel/central-arm-func/sum_arm_func.h
src/operators/kernel/central-arm-func/sum_arm_func.h
+1
-2
未找到文件。
src/common/type_define.h
浏览文件 @
7701ac0d
...
@@ -14,28 +14,85 @@ limitations under the License. */
...
@@ -14,28 +14,85 @@ limitations under the License. */
#pragma once
#pragma once
#include <functional>
#include <string>
#include <string>
#include <vector>
#include <vector>
namespace
paddle_mobile
{
namespace
paddle_mobile
{
typedef
enum
{
_void
=
0
,
_float
,
_int
,
_double
,
_int64_t
,
_size_t
,
_int16_t
,
_int8_t
,
_uint8_t
,
_bool
,
_string
,
_floats
=
100
,
_ints
,
_int64_ts
,
_size_ts
,
_bools
,
_strings
,
_const_float
=
200
,
_const_int
,
_block
=
300
,
_tensor
,
_lod_tensor
,
_blocks
,
_tensors
,
_lod_tensors
,
_p_block
=
400
,
_p_tensor
,
_p_lod_tensor
,
_p_blocks
,
_p_tensors
,
_p_lod_tensors
,
_scopes
=
500
,
_selected_rows
,
_dim0
=
600
,
_dim1
,
_dim2
,
_dim3
,
_dim4
,
_dim5
,
_dim6
,
_dim7
,
_dim8
,
_dim9
,
}
kTypeId_t
;
template
<
typename
T
>
template
<
typename
T
>
struct
TypeIdWrapper
{
struct
TypeIdWrapper
{
std
::
string
type
();
inline
std
::
string
name
();
inline
kTypeId_t
hash_code
();
};
};
template
<
typename
T
>
template
<
typename
T
>
struct
type_id
{
struct
type_id
{
const
std
::
string
type_
=
TypeIdWrapper
<
T
>
().
type
();
const
kTypeId_t
hash_code
()
const
{
return
TypeIdWrapper
<
T
>
().
hash_code
();
}
const
std
::
string
name
()
const
{
return
TypeIdWrapper
<
T
>
().
name
();
}
template
<
typename
OtherType
>
template
<
typename
OtherType
>
bool
operator
==
(
const
type_id
<
OtherType
>
&
operand
)
{
bool
operator
==
(
const
type_id
<
OtherType
>
&
operand
)
const
{
return
this
->
name
()
==
operand
.
nam
e
();
return
this
->
hash_code
()
==
operand
.
hash_cod
e
();
}
}
const
std
::
string
name
()
{
return
type_
;
}
};
};
template
<
typename
T
>
inline
bool
operator
==
(
const
kTypeId_t
&
t0
,
const
type_id
<
T
>
&
t1
)
{
return
t0
==
t1
.
hash_code
();
}
template
<
typename
T
>
inline
bool
operator
==
(
const
type_id
<
T
>
&
t0
,
const
kTypeId_t
&
t1
)
{
return
t1
==
t0
.
hash_code
();
}
namespace
framework
{
namespace
framework
{
class
BlockDesc
;
class
BlockDesc
;
class
Tensor
;
class
Tensor
;
...
@@ -50,7 +107,8 @@ struct Dim;
...
@@ -50,7 +107,8 @@ struct Dim;
#define REGISTER_TYPE_ID(Type, TypeName) \
#define REGISTER_TYPE_ID(Type, TypeName) \
template <> \
template <> \
struct TypeIdWrapper<Type> { \
struct TypeIdWrapper<Type> { \
std::string type() { return std::string(#TypeName); } \
inline std::string name() { return std::string(#TypeName); } \
inline kTypeId_t hash_code() { return kTypeId_t::TypeName; } \
};
};
REGISTER_TYPE_ID
(
void
,
_void
)
REGISTER_TYPE_ID
(
void
,
_void
)
...
@@ -102,3 +160,14 @@ REGISTER_TYPE_ID(framework::Dim<8>, _dim8)
...
@@ -102,3 +160,14 @@ REGISTER_TYPE_ID(framework::Dim<8>, _dim8)
REGISTER_TYPE_ID
(
framework
::
Dim
<
9
>
,
_dim9
)
REGISTER_TYPE_ID
(
framework
::
Dim
<
9
>
,
_dim9
)
}
// namespace paddle_mobile
}
// namespace paddle_mobile
namespace
std
{
template
<
>
struct
hash
<
paddle_mobile
::
kTypeId_t
>
{
size_t
operator
()(
const
paddle_mobile
::
kTypeId_t
&
t
)
const
{
return
std
::
hash
<
int
>
{}(
static_cast
<
int
>
(
t
));
}
};
}
// namespace std
src/common/variant.h
浏览文件 @
7701ac0d
...
@@ -34,8 +34,8 @@ struct VariantHelper {
...
@@ -34,8 +34,8 @@ struct VariantHelper {
?
sizeof
(
F
)
?
sizeof
(
F
)
:
VariantHelper
<
Ts
...
>::
size
;
:
VariantHelper
<
Ts
...
>::
size
;
inline
static
void
Destroy
(
std
::
string
type
,
void
*
data
)
{
inline
static
void
Destroy
(
kTypeId_t
type
,
void
*
data
)
{
if
(
type
==
type_id
<
F
>
()
.
name
()
)
{
if
(
type
==
type_id
<
F
>
())
{
reinterpret_cast
<
F
*>
(
data
)
->~
F
();
reinterpret_cast
<
F
*>
(
data
)
->~
F
();
}
else
{
}
else
{
VariantHelper
<
Ts
...
>::
Destroy
(
type
,
data
);
VariantHelper
<
Ts
...
>::
Destroy
(
type
,
data
);
...
@@ -46,8 +46,8 @@ struct VariantHelper {
...
@@ -46,8 +46,8 @@ struct VariantHelper {
template
<
typename
F
>
template
<
typename
F
>
struct
VariantHelper
<
F
>
{
struct
VariantHelper
<
F
>
{
static
const
size_t
size
=
sizeof
(
F
);
static
const
size_t
size
=
sizeof
(
F
);
inline
static
void
Destroy
(
std
::
string
type
,
void
*
data
)
{
inline
static
void
Destroy
(
kTypeId_t
type
,
void
*
data
)
{
if
(
type
==
type_id
<
F
>
()
.
name
()
)
{
if
(
type
==
type_id
<
F
>
())
{
// reinterpret_cast<F*>(data)->~F();
// reinterpret_cast<F*>(data)->~F();
}
else
{
}
else
{
// std::cout << "未匹配到 " << std::endl;
// std::cout << "未匹配到 " << std::endl;
...
@@ -85,17 +85,17 @@ struct Variant {
...
@@ -85,17 +85,17 @@ struct Variant {
void
Set
(
Args
&&
...
args
)
{
void
Set
(
Args
&&
...
args
)
{
helper
::
Destroy
(
type_
,
data_
.
data
);
helper
::
Destroy
(
type_
,
data_
.
data
);
new
(
data_
.
data
)
T
(
std
::
forward
<
Args
>
(
args
)...);
new
(
data_
.
data
)
T
(
std
::
forward
<
Args
>
(
args
)...);
type_
=
type_id
<
T
>
().
nam
e
();
type_
=
type_id
<
T
>
().
hash_cod
e
();
}
}
void
SetString
(
const
std
::
string
&
string
)
{
void
SetString
(
const
std
::
string
&
string
)
{
helper
::
Destroy
(
type_
,
data_
.
data
);
helper
::
Destroy
(
type_
,
data_
.
data
);
type_
=
type_id
<
std
::
string
>
().
nam
e
();
type_
=
type_id
<
std
::
string
>
().
hash_cod
e
();
strcpy
(
data_
.
data
,
string
.
c_str
());
// NOLINT
strcpy
(
data_
.
data
,
string
.
c_str
());
// NOLINT
}
}
std
::
string
GetString
()
const
{
std
::
string
GetString
()
const
{
if
(
type_
==
type_id
<
std
::
string
>
()
.
name
()
)
{
if
(
type_
==
type_id
<
std
::
string
>
())
{
return
std
::
string
(
data_
.
data
);
return
std
::
string
(
data_
.
data
);
}
else
{
}
else
{
PADDLE_MOBILE_THROW_EXCEPTION
(
PADDLE_MOBILE_THROW_EXCEPTION
(
...
@@ -106,7 +106,7 @@ struct Variant {
...
@@ -106,7 +106,7 @@ struct Variant {
template
<
typename
T
>
template
<
typename
T
>
T
&
Get
()
const
{
T
&
Get
()
const
{
if
(
type_
==
type_id
<
std
::
string
>
()
.
name
()
)
{
if
(
type_
==
type_id
<
std
::
string
>
())
{
PADDLE_MOBILE_THROW_EXCEPTION
(
PADDLE_MOBILE_THROW_EXCEPTION
(
"Please use getString to get an string (to avoid of an issue with "
"Please use getString to get an string (to avoid of an issue with "
"gcc "
"gcc "
...
@@ -117,12 +117,12 @@ struct Variant {
...
@@ -117,12 +117,12 @@ struct Variant {
}
}
}
}
std
::
string
TypeId
()
const
{
return
type_
;
}
kTypeId_t
TypeId
()
const
{
return
type_
;
}
private:
private:
static
inline
std
::
string
invalid_type
()
{
return
type_id
<
void
>
().
nam
e
();
}
static
inline
kTypeId_t
invalid_type
()
{
return
type_id
<
void
>
().
hash_cod
e
();
}
typedef
VariantHelper
<
Ts
...
>
helper
;
typedef
VariantHelper
<
Ts
...
>
helper
;
std
::
string
type_
=
type_id
<
void
>
().
nam
e
();
kTypeId_t
type_
=
type_id
<
void
>
().
hash_cod
e
();
// todo use an anto size to suite this.
// todo use an anto size to suite this.
RawData
<
64
>
data_
;
RawData
<
64
>
data_
;
};
};
...
...
src/framework/attribute.h
浏览文件 @
7701ac0d
...
@@ -128,31 +128,30 @@ class Attribute {
...
@@ -128,31 +128,30 @@ class Attribute {
template
<
typename
Vistor
>
template
<
typename
Vistor
>
static
typename
Vistor
::
type_t
ApplyVistor
(
Vistor
vistor
,
Attribute
attr
)
{
static
typename
Vistor
::
type_t
ApplyVistor
(
Vistor
vistor
,
Attribute
attr
)
{
if
(
attr
.
variant_
.
TypeId
()
==
type_id
<
int
>
()
.
name
()
)
{
// NOLINT
if
(
attr
.
variant_
.
TypeId
()
==
type_id
<
int
>
())
{
// NOLINT
return
vistor
(
attr
.
variant_
.
Get
<
int
>
());
return
vistor
(
attr
.
variant_
.
Get
<
int
>
());
}
else
if
(
attr
.
variant_
.
TypeId
()
==
type_id
<
float
>
()
.
name
()
)
{
// NOLINT
}
else
if
(
attr
.
variant_
.
TypeId
()
==
type_id
<
float
>
())
{
// NOLINT
return
vistor
(
attr
.
variant_
.
Get
<
float
>
());
return
vistor
(
attr
.
variant_
.
Get
<
float
>
());
}
else
if
(
attr
.
variant_
.
TypeId
()
==
type_id
<
string
>
()
.
name
()
)
{
}
else
if
(
attr
.
variant_
.
TypeId
()
==
type_id
<
string
>
())
{
return
vistor
(
attr
.
variant_
.
GetString
());
return
vistor
(
attr
.
variant_
.
GetString
());
}
else
if
(
attr
.
variant_
.
TypeId
()
==
type_id
<
vector
<
int
>>
()
.
name
()
)
{
}
else
if
(
attr
.
variant_
.
TypeId
()
==
type_id
<
vector
<
int
>>
())
{
return
vistor
(
attr
.
variant_
.
Get
<
vector
<
int
>>
());
return
vistor
(
attr
.
variant_
.
Get
<
vector
<
int
>>
());
}
else
if
(
attr
.
variant_
.
TypeId
()
==
type_id
<
vector
<
float
>>
()
.
name
()
)
{
}
else
if
(
attr
.
variant_
.
TypeId
()
==
type_id
<
vector
<
float
>>
())
{
return
vistor
(
attr
.
variant_
.
Get
<
vector
<
float
>>
());
return
vistor
(
attr
.
variant_
.
Get
<
vector
<
float
>>
());
}
else
if
(
attr
.
variant_
.
TypeId
()
==
type_id
<
vector
<
string
>>
()
.
name
()
)
{
}
else
if
(
attr
.
variant_
.
TypeId
()
==
type_id
<
vector
<
string
>>
())
{
return
vistor
(
attr
.
variant_
.
Get
<
vector
<
string
>>
());
return
vistor
(
attr
.
variant_
.
Get
<
vector
<
string
>>
());
}
else
if
(
attr
.
variant_
.
TypeId
()
==
type_id
<
bool
>
()
.
name
()
)
{
// NOLINT
}
else
if
(
attr
.
variant_
.
TypeId
()
==
type_id
<
bool
>
())
{
// NOLINT
return
vistor
(
attr
.
variant_
.
Get
<
bool
>
());
return
vistor
(
attr
.
variant_
.
Get
<
bool
>
());
}
else
if
(
attr
.
variant_
.
TypeId
()
==
type_id
<
vector
<
bool
>>
()
.
name
()
)
{
}
else
if
(
attr
.
variant_
.
TypeId
()
==
type_id
<
vector
<
bool
>>
())
{
return
vistor
(
attr
.
variant_
.
Get
<
vector
<
bool
>>
());
return
vistor
(
attr
.
variant_
.
Get
<
vector
<
bool
>>
());
}
else
if
(
attr
.
variant_
.
TypeId
()
==
type_id
<
int64_t
>
()
.
name
()
)
{
}
else
if
(
attr
.
variant_
.
TypeId
()
==
type_id
<
int64_t
>
())
{
return
vistor
(
attr
.
variant_
.
Get
<
int64_t
>
());
return
vistor
(
attr
.
variant_
.
Get
<
int64_t
>
());
}
else
if
(
attr
.
variant_
.
TypeId
()
==
}
else
if
(
attr
.
variant_
.
TypeId
()
==
type_id
<
framework
::
BlockDesc
*>
())
{
type_id
<
framework
::
BlockDesc
*>
().
name
())
{
return
vistor
(
attr
.
variant_
.
Get
<
framework
::
BlockDesc
*>
());
return
vistor
(
attr
.
variant_
.
Get
<
framework
::
BlockDesc
*>
());
}
else
if
(
attr
.
variant_
.
TypeId
()
==
}
else
if
(
attr
.
variant_
.
TypeId
()
==
type_id
<
vector
<
framework
::
BlockDesc
*>>
()
.
name
()
)
{
type_id
<
vector
<
framework
::
BlockDesc
*>>
())
{
return
vistor
(
attr
.
variant_
.
Get
<
vector
<
framework
::
BlockDesc
*>>
());
return
vistor
(
attr
.
variant_
.
Get
<
vector
<
framework
::
BlockDesc
*>>
());
}
else
if
(
attr
.
variant_
.
TypeId
()
==
type_id
<
vector
<
int64_t
>>
()
.
name
()
)
{
}
else
if
(
attr
.
variant_
.
TypeId
()
==
type_id
<
vector
<
int64_t
>>
())
{
return
vistor
(
attr
.
variant_
.
Get
<
vector
<
int64_t
>>
());
return
vistor
(
attr
.
variant_
.
Get
<
vector
<
int64_t
>>
());
}
else
{
}
else
{
PADDLE_MOBILE_THROW_EXCEPTION
(
"type not support"
);
PADDLE_MOBILE_THROW_EXCEPTION
(
"type not support"
);
...
...
src/framework/data_type.cpp
浏览文件 @
7701ac0d
...
@@ -22,12 +22,11 @@ namespace paddle_mobile {
...
@@ -22,12 +22,11 @@ namespace paddle_mobile {
namespace
framework
{
namespace
framework
{
struct
DataTypeMap
{
struct
DataTypeMap
{
std
::
unordered_map
<
std
::
string
,
std
::
unordered_map
<
kTypeId_t
,
_PaddleMobile__Framework__Proto__VarType__Type
>
_PaddleMobile__Framework__Proto__VarType__Type
>
cpp_to_proto_
;
cpp_to_proto_
;
std
::
unordered_map
<
int
,
std
::
string
>
proto_to_cpp_
;
std
::
unordered_map
<
int
,
kTypeId_t
>
proto_to_cpp_
;
std
::
unordered_map
<
int
,
std
::
string
>
proto_to_str_
;
std
::
unordered_map
<
int
,
std
::
string
>
proto_to_str_
;
std
::
unordered_map
<
std
::
string
,
size_t
>
cpp_to_size_
;
std
::
unordered_map
<
kTypeId_t
,
size_t
>
cpp_to_size_
;
};
};
static
DataTypeMap
*
InitDataTypeMap
();
static
DataTypeMap
*
InitDataTypeMap
();
...
@@ -43,10 +42,11 @@ template <typename T>
...
@@ -43,10 +42,11 @@ template <typename T>
static
inline
void
RegisterType
(
static
inline
void
RegisterType
(
DataTypeMap
*
map
,
_PaddleMobile__Framework__Proto__VarType__Type
proto_type
,
DataTypeMap
*
map
,
_PaddleMobile__Framework__Proto__VarType__Type
proto_type
,
const
std
::
string
&
name
)
{
const
std
::
string
&
name
)
{
map
->
proto_to_cpp_
.
emplace
(
static_cast
<
int
>
(
proto_type
),
type_id
<
T
>
().
name
());
map
->
proto_to_cpp_
.
emplace
(
static_cast
<
int
>
(
proto_type
),
map
->
cpp_to_proto_
.
emplace
(
type_id
<
T
>
().
name
(),
proto_type
);
type_id
<
T
>
().
hash_code
());
map
->
cpp_to_proto_
.
emplace
(
type_id
<
T
>
().
hash_code
(),
proto_type
);
map
->
proto_to_str_
.
emplace
(
static_cast
<
int
>
(
proto_type
),
name
);
map
->
proto_to_str_
.
emplace
(
static_cast
<
int
>
(
proto_type
),
name
);
map
->
cpp_to_size_
.
emplace
(
type_id
<
T
>
().
nam
e
(),
sizeof
(
T
));
map
->
cpp_to_size_
.
emplace
(
type_id
<
T
>
().
hash_cod
e
(),
sizeof
(
T
));
}
}
static
DataTypeMap
*
InitDataTypeMap
()
{
static
DataTypeMap
*
InitDataTypeMap
()
{
...
@@ -71,15 +71,15 @@ static DataTypeMap* InitDataTypeMap() {
...
@@ -71,15 +71,15 @@ static DataTypeMap* InitDataTypeMap() {
return
retv
;
return
retv
;
}
}
_PaddleMobile__Framework__Proto__VarType__Type
ToDataType
(
std
::
string
type
)
{
_PaddleMobile__Framework__Proto__VarType__Type
ToDataType
(
kTypeId_t
type
)
{
auto
it
=
gDataTypeMap
().
cpp_to_proto_
.
find
(
type
);
auto
it
=
gDataTypeMap
().
cpp_to_proto_
.
find
(
type
);
if
(
it
!=
gDataTypeMap
().
cpp_to_proto_
.
end
())
{
if
(
it
!=
gDataTypeMap
().
cpp_to_proto_
.
end
())
{
return
it
->
second
;
return
it
->
second
;
}
}
PADDLE_MOBILE_THROW_EXCEPTION
(
"Not support %
s as tensor type"
,
type
.
c_str
()
);
PADDLE_MOBILE_THROW_EXCEPTION
(
"Not support %
d as tensor type"
,
type
);
}
}
std
::
string
ToTypeIndex
(
_PaddleMobile__Framework__Proto__VarType__Type
type
)
{
kTypeId_t
ToTypeIndex
(
_PaddleMobile__Framework__Proto__VarType__Type
type
)
{
auto
it
=
gDataTypeMap
().
proto_to_cpp_
.
find
(
static_cast
<
int
>
(
type
));
auto
it
=
gDataTypeMap
().
proto_to_cpp_
.
find
(
static_cast
<
int
>
(
type
));
if
(
it
!=
gDataTypeMap
().
proto_to_cpp_
.
end
())
{
if
(
it
!=
gDataTypeMap
().
proto_to_cpp_
.
end
())
{
return
it
->
second
;
return
it
->
second
;
...
...
src/framework/data_type.h
浏览文件 @
7701ac0d
...
@@ -16,16 +16,16 @@ limitations under the License. */
...
@@ -16,16 +16,16 @@ limitations under the License. */
#include <string>
#include <string>
#include "common/enforce.h"
#include "common/enforce.h"
#include "common/type_define.h"
#include "framework/framework.pb-c.h"
#include "framework/framework.pb-c.h"
namespace
paddle_mobile
{
namespace
paddle_mobile
{
namespace
framework
{
namespace
framework
{
extern
_PaddleMobile__Framework__Proto__VarType__Type
ToDataType
(
_PaddleMobile__Framework__Proto__VarType__Type
ToDataType
(
kTypeId_t
type
);
std
::
string
type
);
extern
std
::
string
ToTypeIndex
(
kTypeId_t
ToTypeIndex
(
_PaddleMobile__Framework__Proto__VarType__Type
type
);
_PaddleMobile__Framework__Proto__VarType__Type
type
);
inline
_PaddleMobile__Framework__Proto__VarType__Type
ToDataType
(
int
type
)
{
inline
_PaddleMobile__Framework__Proto__VarType__Type
ToDataType
(
int
type
)
{
return
static_cast
<
_PaddleMobile__Framework__Proto__VarType__Type
>
(
type
);
return
static_cast
<
_PaddleMobile__Framework__Proto__VarType__Type
>
(
type
);
...
...
src/framework/ddim.h
浏览文件 @
7701ac0d
...
@@ -40,25 +40,25 @@ struct DDim {
...
@@ -40,25 +40,25 @@ struct DDim {
template
<
typename
Vistor
>
template
<
typename
Vistor
>
static
typename
Vistor
::
type_t
ApplyVistor
(
Vistor
vistor
,
const
DDim
&
d
)
{
static
typename
Vistor
::
type_t
ApplyVistor
(
Vistor
vistor
,
const
DDim
&
d
)
{
if
(
d
.
var
.
TypeId
()
==
type_id
<
Dim
<
0
>>
()
.
name
()
)
{
if
(
d
.
var
.
TypeId
()
==
type_id
<
Dim
<
0
>>
())
{
return
vistor
(
d
.
var
.
Get
<
Dim
<
0
>>
());
return
vistor
(
d
.
var
.
Get
<
Dim
<
0
>>
());
}
else
if
(
d
.
var
.
TypeId
()
==
type_id
<
Dim
<
1
>>
()
.
name
()
)
{
}
else
if
(
d
.
var
.
TypeId
()
==
type_id
<
Dim
<
1
>>
())
{
return
vistor
(
d
.
var
.
Get
<
Dim
<
1
>>
());
return
vistor
(
d
.
var
.
Get
<
Dim
<
1
>>
());
}
else
if
(
d
.
var
.
TypeId
()
==
type_id
<
Dim
<
2
>>
()
.
name
()
)
{
}
else
if
(
d
.
var
.
TypeId
()
==
type_id
<
Dim
<
2
>>
())
{
return
vistor
(
d
.
var
.
Get
<
Dim
<
2
>>
());
return
vistor
(
d
.
var
.
Get
<
Dim
<
2
>>
());
}
else
if
(
d
.
var
.
TypeId
()
==
type_id
<
Dim
<
3
>>
()
.
name
()
)
{
}
else
if
(
d
.
var
.
TypeId
()
==
type_id
<
Dim
<
3
>>
())
{
return
vistor
(
d
.
var
.
Get
<
Dim
<
3
>>
());
return
vistor
(
d
.
var
.
Get
<
Dim
<
3
>>
());
}
else
if
(
d
.
var
.
TypeId
()
==
type_id
<
Dim
<
4
>>
()
.
name
()
)
{
}
else
if
(
d
.
var
.
TypeId
()
==
type_id
<
Dim
<
4
>>
())
{
return
vistor
(
d
.
var
.
Get
<
Dim
<
4
>>
());
return
vistor
(
d
.
var
.
Get
<
Dim
<
4
>>
());
}
else
if
(
d
.
var
.
TypeId
()
==
type_id
<
Dim
<
5
>>
()
.
name
()
)
{
}
else
if
(
d
.
var
.
TypeId
()
==
type_id
<
Dim
<
5
>>
())
{
return
vistor
(
d
.
var
.
Get
<
Dim
<
5
>>
());
return
vistor
(
d
.
var
.
Get
<
Dim
<
5
>>
());
}
else
if
(
d
.
var
.
TypeId
()
==
type_id
<
Dim
<
6
>>
()
.
name
()
)
{
}
else
if
(
d
.
var
.
TypeId
()
==
type_id
<
Dim
<
6
>>
())
{
return
vistor
(
d
.
var
.
Get
<
Dim
<
6
>>
());
return
vistor
(
d
.
var
.
Get
<
Dim
<
6
>>
());
}
else
if
(
d
.
var
.
TypeId
()
==
type_id
<
Dim
<
7
>>
()
.
name
()
)
{
}
else
if
(
d
.
var
.
TypeId
()
==
type_id
<
Dim
<
7
>>
())
{
return
vistor
(
d
.
var
.
Get
<
Dim
<
7
>>
());
return
vistor
(
d
.
var
.
Get
<
Dim
<
7
>>
());
}
else
if
(
d
.
var
.
TypeId
()
==
type_id
<
Dim
<
8
>>
()
.
name
()
)
{
}
else
if
(
d
.
var
.
TypeId
()
==
type_id
<
Dim
<
8
>>
())
{
return
vistor
(
d
.
var
.
Get
<
Dim
<
8
>>
());
return
vistor
(
d
.
var
.
Get
<
Dim
<
8
>>
());
}
else
if
(
d
.
var
.
TypeId
()
==
type_id
<
Dim
<
9
>>
()
.
name
()
)
{
}
else
if
(
d
.
var
.
TypeId
()
==
type_id
<
Dim
<
9
>>
())
{
return
vistor
(
d
.
var
.
Get
<
Dim
<
9
>>
());
return
vistor
(
d
.
var
.
Get
<
Dim
<
9
>>
());
}
else
{
}
else
{
PADDLE_MOBILE_ENFORCE
(
false
,
" dim not support"
);
PADDLE_MOBILE_ENFORCE
(
false
,
" dim not support"
);
...
...
src/framework/load_ops.h
浏览文件 @
7701ac0d
...
@@ -268,6 +268,9 @@ LOAD_OP1(sequence_expand, CPU);
...
@@ -268,6 +268,9 @@ LOAD_OP1(sequence_expand, CPU);
#ifdef SEQUENCE_POOL_OP
#ifdef SEQUENCE_POOL_OP
LOAD_OP1
(
sequence_pool
,
CPU
);
LOAD_OP1
(
sequence_pool
,
CPU
);
#endif
#endif
#ifdef SEQUENCE_SOFTMAX_OP
LOAD_OP1
(
sequence_softmax
,
CPU
);
#endif
#ifdef LOG_OP
#ifdef LOG_OP
LOAD_OP1
(
log
,
CPU
);
LOAD_OP1
(
log
,
CPU
);
#endif
#endif
...
...
src/framework/mixed_vector.h
浏览文件 @
7701ac0d
...
@@ -197,7 +197,7 @@ class Vector {
...
@@ -197,7 +197,7 @@ class Vector {
}
}
size_t
capacity
()
const
{
size_t
capacity
()
const
{
return
cpu_vec_
.
memory_size
()
/
SizeOfType
(
type_id
<
T
>
().
nam
e
());
return
cpu_vec_
.
memory_size
()
/
SizeOfType
(
type_id
<
T
>
().
hash_cod
e
());
}
}
// reserve data
// reserve data
...
...
src/framework/tensor.h
浏览文件 @
7701ac0d
...
@@ -81,7 +81,7 @@ class Tensor : public TensorBase {
...
@@ -81,7 +81,7 @@ class Tensor : public TensorBase {
return
*
this
;
return
*
this
;
}
}
inline
void
*
mutable_data
(
const
std
::
string
type
)
{
inline
void
*
mutable_data
(
const
kTypeId_t
type
)
{
if
(
holder_
!=
nullptr
)
{
if
(
holder_
!=
nullptr
)
{
holder_
->
set_type
(
type
);
holder_
->
set_type
(
type
);
}
}
...
@@ -106,7 +106,7 @@ class Tensor : public TensorBase {
...
@@ -106,7 +106,7 @@ class Tensor : public TensorBase {
template
<
typename
T
>
template
<
typename
T
>
inline
T
*
mutable_data
()
{
inline
T
*
mutable_data
()
{
static_assert
(
std
::
is_pod
<
T
>::
value
,
"T must be POD"
);
static_assert
(
std
::
is_pod
<
T
>::
value
,
"T must be POD"
);
return
reinterpret_cast
<
T
*>
(
mutable_data
(
type_id
<
T
>
().
nam
e
()));
return
reinterpret_cast
<
T
*>
(
mutable_data
(
type_id
<
T
>
().
hash_cod
e
()));
}
}
/**
/**
...
@@ -163,9 +163,9 @@ class Tensor : public TensorBase {
...
@@ -163,9 +163,9 @@ class Tensor : public TensorBase {
check_memory_size
();
check_memory_size
();
PADDLE_MOBILE_ENFORCE
(
PADDLE_MOBILE_ENFORCE
(
(
std
::
is_same
<
T
,
void
>::
value
||
(
std
::
is_same
<
T
,
void
>::
value
||
holder_
->
type
()
==
type_id
<
T
>
().
nam
e
()),
holder_
->
type
()
==
type_id
<
T
>
().
hash_cod
e
()),
"Tensor holds the wrong type, it holds %
s, requested %s
"
,
"Tensor holds the wrong type, it holds %
d, requested %d
"
,
this
->
holder_
->
type
()
.
c_str
(),
type_id
<
T
>
().
name
().
c_str
());
this
->
holder_
->
type
()
,
type_id
<
T
>
().
hash_code
());
return
reinterpret_cast
<
T
*>
(
reinterpret_cast
<
uintptr_t
>
(
holder_
->
ptr
())
+
return
reinterpret_cast
<
T
*>
(
reinterpret_cast
<
uintptr_t
>
(
holder_
->
ptr
())
+
offset_
);
offset_
);
...
@@ -177,9 +177,9 @@ class Tensor : public TensorBase {
...
@@ -177,9 +177,9 @@ class Tensor : public TensorBase {
check_memory_size
();
check_memory_size
();
PADDLE_MOBILE_ENFORCE
(
PADDLE_MOBILE_ENFORCE
(
(
std
::
is_same
<
T
,
void
>::
value
||
(
std
::
is_same
<
T
,
void
>::
value
||
holder_
->
type
()
==
type_id
<
T
>
().
nam
e
()),
holder_
->
type
()
==
type_id
<
T
>
().
hash_cod
e
()),
"Tensor holds the wrong type, it holds %
s, requested %s
"
,
"Tensor holds the wrong type, it holds %
d, requested %d
"
,
this
->
holder_
->
type
()
.
c_str
(),
type_id
<
T
>
().
name
().
c_str
());
this
->
holder_
->
type
()
,
type_id
<
T
>
().
hash_code
());
return
reinterpret_cast
<
const
T
*>
(
return
reinterpret_cast
<
const
T
*>
(
reinterpret_cast
<
uintptr_t
>
(
holder_
->
ptr
())
+
offset_
);
reinterpret_cast
<
uintptr_t
>
(
holder_
->
ptr
())
+
offset_
);
...
@@ -187,7 +187,7 @@ class Tensor : public TensorBase {
...
@@ -187,7 +187,7 @@ class Tensor : public TensorBase {
private:
private:
struct
PlaceholderImpl
:
public
Placeholder
{
struct
PlaceholderImpl
:
public
Placeholder
{
PlaceholderImpl
(
size_t
size
,
const
std
::
string
type
)
PlaceholderImpl
(
size_t
size
,
const
kTypeId_t
type
)
:
ptr_
(
static_cast
<
uint8_t
*>
(
memory
::
Alloc
(
size
)),
:
ptr_
(
static_cast
<
uint8_t
*>
(
memory
::
Alloc
(
size
)),
memory
::
PODDeleter
<
uint8_t
>
()),
memory
::
PODDeleter
<
uint8_t
>
()),
size_
(
size
),
size_
(
size
),
...
@@ -201,9 +201,9 @@ class Tensor : public TensorBase {
...
@@ -201,9 +201,9 @@ class Tensor : public TensorBase {
virtual
void
*
ptr
()
const
{
return
static_cast
<
void
*>
(
ptr_
.
get
());
}
virtual
void
*
ptr
()
const
{
return
static_cast
<
void
*>
(
ptr_
.
get
());
}
virtual
std
::
string
type
()
const
{
return
type_
;
}
virtual
kTypeId_t
type
()
const
{
return
type_
;
}
virtual
void
set_type
(
const
std
::
string
type
)
{
type_
=
type
;
}
virtual
void
set_type
(
const
kTypeId_t
type
)
{
type_
=
type
;
}
virtual
void
resize
(
size_t
size
)
{
virtual
void
resize
(
size_t
size
)
{
if
(
size
>
capatity_
)
{
if
(
size
>
capatity_
)
{
...
@@ -221,7 +221,7 @@ class Tensor : public TensorBase {
...
@@ -221,7 +221,7 @@ class Tensor : public TensorBase {
size_t
capatity_
;
size_t
capatity_
;
/* the current type of memory */
/* the current type of memory */
std
::
string
type_
;
kTypeId_t
type_
;
};
};
#ifdef PADDLE_MOBILE_FPGA
#ifdef PADDLE_MOBILE_FPGA
...
@@ -229,13 +229,13 @@ class Tensor : public TensorBase {
...
@@ -229,13 +229,13 @@ class Tensor : public TensorBase {
inline
void
reset_data_ptr
(
void
*
p
)
{
inline
void
reset_data_ptr
(
void
*
p
)
{
((
PlaceholderImpl
*
)(
holder_
.
get
()))
->
ptr_
.
reset
((
uint8_t
*
)
p
);
// NOLINT
((
PlaceholderImpl
*
)(
holder_
.
get
()))
->
ptr_
.
reset
((
uint8_t
*
)
p
);
// NOLINT
}
}
inline
void
set_type
(
const
std
::
string
type
)
{
holder_
->
set_type
(
type
);
}
inline
void
set_type
(
const
kTypeId_t
type
)
{
holder_
->
set_type
(
type
);
}
inline
void
*
get_data
()
{
inline
void
*
get_data
()
{
return
(
return
(
void
*
)(((
PlaceholderImpl
*
)(
holder_
.
get
()))
->
ptr_
.
get
());
// NOLINT
void
*
)(((
PlaceholderImpl
*
)(
holder_
.
get
()))
->
ptr_
.
get
());
// NOLINT
}
}
inline
void
*
init
(
const
std
::
string
type
)
{
inline
void
*
init
(
const
kTypeId_t
type
)
{
if
(
holder_
!=
nullptr
)
{
if
(
holder_
!=
nullptr
)
{
holder_
->
set_type
(
type
);
holder_
->
set_type
(
type
);
}
}
...
@@ -263,15 +263,15 @@ inline Print &operator<<(Print &printer, const Tensor &tensor) {
...
@@ -263,15 +263,15 @@ inline Print &operator<<(Print &printer, const Tensor &tensor) {
stride
=
stride
>
0
?
stride
:
1
;
stride
=
stride
>
0
?
stride
:
1
;
#ifndef PADDLE_MOBILE_FPGA
#ifndef PADDLE_MOBILE_FPGA
for
(
int
i
=
0
;
i
<
tensor
.
numel
();
i
+=
stride
)
{
for
(
int
i
=
0
;
i
<
tensor
.
numel
();
i
+=
stride
)
{
if
(
tensor
.
type
()
==
type_id
<
float
>
()
.
name
()
)
{
if
(
tensor
.
type
()
==
type_id
<
float
>
())
{
printer
<<
tensor
.
data
<
float
>
()[
i
]
<<
" "
;
printer
<<
tensor
.
data
<
float
>
()[
i
]
<<
" "
;
}
else
if
(
tensor
.
type
()
==
type_id
<
int32_t
>
()
.
name
()
)
{
}
else
if
(
tensor
.
type
()
==
type_id
<
int32_t
>
())
{
printer
<<
tensor
.
data
<
int32_t
>
()[
i
]
<<
" "
;
printer
<<
tensor
.
data
<
int32_t
>
()[
i
]
<<
" "
;
}
else
if
(
tensor
.
type
()
==
type_id
<
int64_t
>
()
.
name
()
)
{
}
else
if
(
tensor
.
type
()
==
type_id
<
int64_t
>
())
{
printer
<<
tensor
.
data
<
int64_t
>
()[
i
]
<<
" "
;
printer
<<
tensor
.
data
<
int64_t
>
()[
i
]
<<
" "
;
}
else
if
(
tensor
.
type
()
==
type_id
<
int8_t
>
()
.
name
()
)
{
}
else
if
(
tensor
.
type
()
==
type_id
<
int8_t
>
())
{
printer
<<
static_cast
<
int
>
(
tensor
.
data
<
int8_t
>
()[
i
])
<<
" "
;
printer
<<
static_cast
<
int
>
(
tensor
.
data
<
int8_t
>
()[
i
])
<<
" "
;
}
else
if
(
tensor
.
type
()
==
type_id
<
int32_t
>
()
.
name
()
)
{
}
else
if
(
tensor
.
type
()
==
type_id
<
int32_t
>
())
{
printer
<<
tensor
.
data
<
int32_t
>
()[
i
]
<<
" "
;
printer
<<
tensor
.
data
<
int32_t
>
()[
i
]
<<
" "
;
}
}
}
}
...
...
src/framework/tensor_base.h
浏览文件 @
7701ac0d
...
@@ -14,8 +14,8 @@ limitations under the License. */
...
@@ -14,8 +14,8 @@ limitations under the License. */
#pragma once
#pragma once
#include <string>
#include "common/enforce.h"
#include "common/enforce.h"
#include "common/type_define.h"
#include "common/types.h"
#include "common/types.h"
#include "framework/ddim.h"
#include "framework/ddim.h"
...
@@ -27,8 +27,8 @@ struct SizeOfTypeFunctor;
...
@@ -27,8 +27,8 @@ struct SizeOfTypeFunctor;
template
<
typename
T
>
template
<
typename
T
>
struct
SizeOfTypeFunctor
<
T
>
{
struct
SizeOfTypeFunctor
<
T
>
{
size_t
operator
()(
const
std
::
string
type
)
const
{
size_t
operator
()(
const
kTypeId_t
type
)
const
{
if
(
type_id
<
T
>
().
nam
e
()
==
type
)
{
if
(
type_id
<
T
>
().
hash_cod
e
()
==
type
)
{
return
sizeof
(
T
);
return
sizeof
(
T
);
}
else
{
}
else
{
return
0UL
;
return
0UL
;
...
@@ -38,12 +38,12 @@ struct SizeOfTypeFunctor<T> {
...
@@ -38,12 +38,12 @@ struct SizeOfTypeFunctor<T> {
template
<
>
template
<
>
struct
SizeOfTypeFunctor
<>
{
struct
SizeOfTypeFunctor
<>
{
size_t
operator
()(
const
std
::
string
type
)
const
{
return
0UL
;
}
size_t
operator
()(
const
kTypeId_t
type
)
const
{
return
0UL
;
}
};
};
template
<
typename
HEAD
,
typename
...
TAIL
>
template
<
typename
HEAD
,
typename
...
TAIL
>
struct
SizeOfTypeFunctor
<
HEAD
,
TAIL
...
>
{
struct
SizeOfTypeFunctor
<
HEAD
,
TAIL
...
>
{
size_t
operator
()(
const
std
::
string
type
)
const
{
size_t
operator
()(
const
kTypeId_t
type
)
const
{
SizeOfTypeFunctor
<
HEAD
>
head
;
SizeOfTypeFunctor
<
HEAD
>
head
;
size_t
head_size
=
head
(
type
);
size_t
head_size
=
head
(
type
);
if
(
head_size
!=
0
)
{
if
(
head_size
!=
0
)
{
...
@@ -54,14 +54,13 @@ struct SizeOfTypeFunctor<HEAD, TAIL...> {
...
@@ -54,14 +54,13 @@ struct SizeOfTypeFunctor<HEAD, TAIL...> {
}
}
};
};
static
inline
size_t
SizeOfType
(
std
::
string
type
)
{
static
inline
size_t
SizeOfType
(
const
kTypeId_t
type
)
{
SizeOfTypeFunctor
<
int8_t
,
int
,
half
,
float
,
double
,
int16_t
,
int64_t
,
bool
,
SizeOfTypeFunctor
<
int8_t
,
int
,
half
,
float
,
double
,
int16_t
,
int64_t
,
bool
,
size_t
>
size_t
>
functor
;
functor
;
size_t
size
=
functor
(
type
);
size_t
size
=
functor
(
type
);
PADDLE_MOBILE_ENFORCE
(
size
!=
0UL
,
"Cannot get size of type %s"
,
PADDLE_MOBILE_ENFORCE
(
size
!=
0UL
,
"Cannot get size of type %d"
,
type
);
type
.
c_str
());
return
size
;
return
size
;
}
}
...
@@ -77,7 +76,7 @@ class TensorBase {
...
@@ -77,7 +76,7 @@ class TensorBase {
/*! Return the numel of the memory block. */
/*! Return the numel of the memory block. */
inline
int64_t
numel
()
const
{
return
product
(
dims_
);
}
inline
int64_t
numel
()
const
{
return
product
(
dims_
);
}
std
::
string
type
()
const
{
kTypeId_t
type
()
const
{
PADDLE_MOBILE_ENFORCE
(
PADDLE_MOBILE_ENFORCE
(
holder_
!=
nullptr
,
holder_
!=
nullptr
,
"Tensor not initialized yet when Tensor::type() is called."
)
"Tensor not initialized yet when Tensor::type() is called."
)
...
@@ -113,9 +112,9 @@ class TensorBase {
...
@@ -113,9 +112,9 @@ class TensorBase {
virtual
size_t
size
()
const
=
0
;
virtual
size_t
size
()
const
=
0
;
virtual
std
::
string
type
()
const
=
0
;
virtual
kTypeId_t
type
()
const
=
0
;
virtual
void
set_type
(
std
::
string
type
)
=
0
;
virtual
void
set_type
(
kTypeId_t
type
)
=
0
;
virtual
void
resize
(
size_t
size
)
=
0
;
virtual
void
resize
(
size_t
size
)
=
0
;
};
};
...
...
src/framework/variable.h
浏览文件 @
7701ac0d
...
@@ -30,7 +30,7 @@ class Variable {
...
@@ -30,7 +30,7 @@ class Variable {
template
<
typename
T
>
template
<
typename
T
>
const
T
GetValue
()
const
{
const
T
GetValue
()
const
{
if
(
type_id
<
T
>
().
name
()
==
type_id
<
std
::
string
>
().
nam
e
())
{
if
(
type_id
<
T
>
().
hash_code
()
==
type_id
<
std
::
string
>
().
hash_cod
e
())
{
PADDLE_MOBILE_THROW_EXCEPTION
(
PADDLE_MOBILE_THROW_EXCEPTION
(
"Please use getString to get an string (to avoid of an issue with "
"Please use getString to get an string (to avoid of an issue with "
"gcc "
"gcc "
...
@@ -57,31 +57,32 @@ class Variable {
...
@@ -57,31 +57,32 @@ class Variable {
template
<
typename
T
>
template
<
typename
T
>
bool
IsType
()
const
{
bool
IsType
()
const
{
return
holder_
!=
nullptr
&&
holder_
->
Type
()
==
type_id
<
T
>
().
nam
e
();
return
holder_
!=
nullptr
&&
holder_
->
Type
()
==
type_id
<
T
>
().
hash_cod
e
();
}
}
void
Clear
()
{
holder_
.
reset
();
}
void
Clear
()
{
holder_
.
reset
();
}
std
::
string
Type
()
const
{
return
holder_
->
Type
();
}
kTypeId_t
Type
()
const
{
return
holder_
->
Type
();
}
private:
private:
struct
Placeholder
{
struct
Placeholder
{
Placeholder
()
=
default
;
Placeholder
()
=
default
;
virtual
~
Placeholder
()
=
default
;
virtual
~
Placeholder
()
=
default
;
virtual
std
::
string
Type
()
const
=
0
;
virtual
kTypeId_t
Type
()
const
=
0
;
virtual
void
*
Ptr
()
const
=
0
;
virtual
void
*
Ptr
()
const
=
0
;
};
};
template
<
typename
T
>
template
<
typename
T
>
struct
PlaceholderImp
:
public
Placeholder
{
struct
PlaceholderImp
:
public
Placeholder
{
explicit
PlaceholderImp
(
T
*
ptr
)
:
ptr_
(
ptr
),
type_
(
type_id
<
T
>
().
name
())
{}
explicit
PlaceholderImp
(
T
*
ptr
)
:
ptr_
(
ptr
),
type_
(
type_id
<
T
>
().
hash_code
())
{}
std
::
string
Type
()
const
override
{
return
type_
;
}
kTypeId_t
Type
()
const
override
{
return
type_
;
}
void
*
Ptr
()
const
override
{
return
static_cast
<
void
*>
(
ptr_
.
get
());
}
void
*
Ptr
()
const
override
{
return
static_cast
<
void
*>
(
ptr_
.
get
());
}
std
::
unique_ptr
<
T
>
ptr_
;
std
::
unique_ptr
<
T
>
ptr_
;
std
::
string
type_
;
kTypeId_t
type_
;
};
};
friend
class
Scope
;
friend
class
Scope
;
...
...
src/operators/kernel/arm/compare_kernel.cpp
浏览文件 @
7701ac0d
...
@@ -192,10 +192,10 @@ bool LessThanKernel<CPU, float>::Init(CompareParam<CPU> *param) {
...
@@ -192,10 +192,10 @@ bool LessThanKernel<CPU, float>::Init(CompareParam<CPU> *param) {
template
<
>
template
<
>
void
LessThanKernel
<
CPU
,
float
>::
Compute
(
const
CompareParam
<
CPU
>
&
param
)
{
void
LessThanKernel
<
CPU
,
float
>::
Compute
(
const
CompareParam
<
CPU
>
&
param
)
{
if
(
param
.
input_x_
->
type
()
==
type_id
<
int64_t
>
().
nam
e
())
{
if
(
param
.
input_x_
->
type
()
==
type_id
<
int64_t
>
().
hash_cod
e
())
{
CompareCompute
<
int64_t
,
LESS_THAN
>
()(
param
.
input_x_
,
param
.
input_y_
,
CompareCompute
<
int64_t
,
LESS_THAN
>
()(
param
.
input_x_
,
param
.
input_y_
,
param
.
axis_
,
param
.
output_
);
param
.
axis_
,
param
.
output_
);
}
else
if
(
param
.
input_x_
->
type
()
==
type_id
<
float
>
().
nam
e
())
{
}
else
if
(
param
.
input_x_
->
type
()
==
type_id
<
float
>
().
hash_cod
e
())
{
CompareCompute
<
float
,
LESS_THAN
>
()(
param
.
input_x_
,
param
.
input_y_
,
CompareCompute
<
float
,
LESS_THAN
>
()(
param
.
input_x_
,
param
.
input_y_
,
param
.
axis_
,
param
.
output_
);
param
.
axis_
,
param
.
output_
);
}
else
{
}
else
{
...
...
src/operators/kernel/arm/concat_kernel.cpp
浏览文件 @
7701ac0d
...
@@ -27,7 +27,7 @@ bool ConcatKernel<CPU, float>::Init(ConcatParam<CPU> *param) {
...
@@ -27,7 +27,7 @@ bool ConcatKernel<CPU, float>::Init(ConcatParam<CPU> *param) {
template
<
>
template
<
>
void
ConcatKernel
<
CPU
,
float
>::
Compute
(
const
ConcatParam
<
CPU
>
&
param
)
{
void
ConcatKernel
<
CPU
,
float
>::
Compute
(
const
ConcatParam
<
CPU
>
&
param
)
{
if
(
param
.
Inputs
()[
0
]
->
type
()
==
type_id
<
int8_t
>
().
nam
e
())
{
if
(
param
.
Inputs
()[
0
]
->
type
()
==
type_id
<
int8_t
>
().
hash_cod
e
())
{
ConcatCompute
<
int8_t
>
(
param
);
ConcatCompute
<
int8_t
>
(
param
);
}
else
{
}
else
{
ConcatCompute
<
float
>
(
param
);
ConcatCompute
<
float
>
(
param
);
...
...
src/operators/kernel/arm/convolution/conv_common.cpp
浏览文件 @
7701ac0d
...
@@ -28,7 +28,7 @@ void InitBaseConvKernel(ConvParam<CPU> *param) {
...
@@ -28,7 +28,7 @@ void InitBaseConvKernel(ConvParam<CPU> *param) {
bool
depth5x5
=
conv5x5
&&
param
->
Groups
()
==
param
->
Input
()
->
dims
()[
1
]
&&
bool
depth5x5
=
conv5x5
&&
param
->
Groups
()
==
param
->
Input
()
->
dims
()[
1
]
&&
param
->
Input
()
->
dims
()[
1
]
==
param
->
Output
()
->
dims
()[
1
];
param
->
Input
()
->
dims
()[
1
]
==
param
->
Output
()
->
dims
()[
1
];
if
(
param
->
Filter
()
->
type
()
==
type_id
<
int8_t
>
().
nam
e
())
{
if
(
param
->
Filter
()
->
type
()
==
type_id
<
int8_t
>
().
hash_cod
e
())
{
#ifndef __aarch64__
#ifndef __aarch64__
if
(
depth3x3
&&
param
->
Strides
()[
0
]
<
3
&&
if
(
depth3x3
&&
param
->
Strides
()[
0
]
<
3
&&
param
->
Strides
()[
0
]
==
param
->
Strides
()[
1
])
{
param
->
Strides
()[
0
]
==
param
->
Strides
()[
1
])
{
...
...
src/operators/kernel/arm/transpose2_kernel.cpp
浏览文件 @
7701ac0d
...
@@ -126,13 +126,13 @@ void Transpose2Kernel<CPU, float>::Compute(const Transpose2Param<CPU> ¶m) {
...
@@ -126,13 +126,13 @@ void Transpose2Kernel<CPU, float>::Compute(const Transpose2Param<CPU> ¶m) {
const
std
::
vector
<
int
>
&
axis
=
param
.
Axis
();
const
std
::
vector
<
int
>
&
axis
=
param
.
Axis
();
bool
shuffle_channel
=
IsShuffleChannel
(
axis
);
bool
shuffle_channel
=
IsShuffleChannel
(
axis
);
if
(
shuffle_channel
)
{
if
(
shuffle_channel
)
{
if
(
param
.
InputX
()
->
type
()
==
type_id
<
int8_t
>
().
nam
e
())
{
if
(
param
.
InputX
()
->
type
()
==
type_id
<
int8_t
>
().
hash_cod
e
())
{
ShuffleChannelCompute
<
int8_t
>
(
param
);
ShuffleChannelCompute
<
int8_t
>
(
param
);
}
else
{
}
else
{
ShuffleChannelCompute
<
float
>
(
param
);
ShuffleChannelCompute
<
float
>
(
param
);
}
}
}
else
{
}
else
{
if
(
param
.
InputX
()
->
type
()
==
type_id
<
int8_t
>
().
nam
e
())
{
if
(
param
.
InputX
()
->
type
()
==
type_id
<
int8_t
>
().
hash_cod
e
())
{
Transpose2Compute
<
int8_t
>
(
param
);
Transpose2Compute
<
int8_t
>
(
param
);
}
else
{
}
else
{
Transpose2Compute
<
float
>
(
param
);
Transpose2Compute
<
float
>
(
param
);
...
...
src/operators/kernel/central-arm-func/mul_arm_func.h
浏览文件 @
7701ac0d
...
@@ -37,7 +37,7 @@ void MulCompute(const MulParam<CPU> ¶m) {
...
@@ -37,7 +37,7 @@ void MulCompute(const MulParam<CPU> ¶m) {
if
(
out_dim
.
size
()
!=
2
)
{
if
(
out_dim
.
size
()
!=
2
)
{
out
->
Resize
({
x_matrix
.
dims
()[
0
],
y_matrix
.
dims
()[
1
]});
out
->
Resize
({
x_matrix
.
dims
()[
0
],
y_matrix
.
dims
()[
1
]});
}
}
if
(
param
.
InputX
()
->
type
()
==
type_id
<
int8_t
>
().
nam
e
())
{
if
(
param
.
InputX
()
->
type
()
==
type_id
<
int8_t
>
().
hash_cod
e
())
{
out
->
mutable_data
<
int32_t
>
();
out
->
mutable_data
<
int32_t
>
();
math
::
MatMul
<
int8_t
,
int32_t
>
(
x_matrix
,
false
,
y_matrix
,
false
,
math
::
MatMul
<
int8_t
,
int32_t
>
(
x_matrix
,
false
,
y_matrix
,
false
,
static_cast
<
float
>
(
1
),
out
,
static_cast
<
float
>
(
1
),
out
,
...
...
src/operators/kernel/central-arm-func/sum_arm_func.h
浏览文件 @
7701ac0d
...
@@ -144,8 +144,7 @@ void SumCompute(const SumParam<CPU> ¶m) {
...
@@ -144,8 +144,7 @@ void SumCompute(const SumParam<CPU> ¶m) {
}
}
}
else
{
}
else
{
PADDLE_MOBILE_THROW_EXCEPTION
(
PADDLE_MOBILE_THROW_EXCEPTION
(
"Unexpected branch, output variable type is %s"
,
"Unexpected branch, output variable type is %d"
,
outvar
->
Type
());
outvar
->
Type
().
c_str
());
}
}
}
}
}
// namespace operators
}
// namespace operators
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录