Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
正统之独孤求败
mindspore
提交
96bcd27a
M
mindspore
项目概览
正统之独孤求败
/
mindspore
与 Fork 源项目一致
Fork自
MindSpore / mindspore
通知
1
Star
0
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
mindspore
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
体验新版 GitCode,发现更多精彩内容 >>
提交
96bcd27a
编写于
8月 14, 2020
作者:
M
mindspore-ci-bot
提交者:
Gitee
8月 14, 2020
浏览文件
操作
浏览文件
下载
差异文件
!4394 Fix fc op's bug
Merge pull request !4394 from zhanyuan/dev
上级
eea10fac
ba4dec43
变更
5
隐藏空白更改
内联
并排
Showing
5 changed files
with
67 additions
and
20 deletions
+67
-20
mindspore/lite/schema/ops.fbs
mindspore/lite/schema/ops.fbs
+1
-0
mindspore/lite/src/ops/fullconnection.cc
mindspore/lite/src/ops/fullconnection.cc
+29
-11
mindspore/lite/src/populate_parameter.cc
mindspore/lite/src/populate_parameter.cc
+8
-1
mindspore/lite/test/ut/src/runtime/kernel/arm/fp32/power_fp32_tests.cc
...e/test/ut/src/runtime/kernel/arm/fp32/power_fp32_tests.cc
+22
-8
mindspore/lite/tools/converter/parser/tflite/tflite_fullyconnected_parser.cc
...s/converter/parser/tflite/tflite_fullyconnected_parser.cc
+7
-0
未找到文件。
mindspore/lite/schema/ops.fbs
浏览文件 @
96bcd27a
...
...
@@ -352,6 +352,7 @@ table FullConnection {
hasBias: bool;
axis: int;
useAxis: bool;
activationType: ActivationType = 0;
}
// Mean(input_tensor, axis, keep_dims)
...
...
mindspore/lite/src/ops/fullconnection.cc
浏览文件 @
96bcd27a
...
...
@@ -24,7 +24,7 @@ int FullConnection::InferShape(std::vector<tensor::Tensor *> inputs_, std::vecto
MS_ASSERT
(
this
->
primitive
!=
nullptr
);
auto
input0
=
inputs_
.
front
();
MS_ASSERT
(
input0
!=
nullptr
);
auto
input1
=
inputs_
.
at
(
1
)
;
auto
input1
=
inputs_
[
1
]
;
MS_ASSERT
(
input1
!=
nullptr
);
auto
output
=
outputs_
.
front
();
MS_ASSERT
(
output
!=
nullptr
);
...
...
@@ -33,27 +33,45 @@ int FullConnection::InferShape(std::vector<tensor::Tensor *> inputs_, std::vecto
MS_LOG
(
ERROR
)
<<
"Input tensors num error"
;
return
RET_INPUT_TENSOR_ERROR
;
}
if
(
fc_prim
->
axis
()
<
1
||
fc_prim
->
axis
()
>
input0
->
shape
().
size
())
{
auto
axis
=
fc_prim
->
axis
();
auto
use_axis
=
fc_prim
->
useAxis
();
if
(
use_axis
&&
(
axis
<
1
||
axis
>=
input0
->
shape
().
size
()))
{
MS_LOG
(
ERROR
)
<<
"FullConnection axis invalid"
;
return
RET_INPUT_TENSOR_ERROR
;
}
int
new_k
=
1
;
for
(
size_t
i
=
fc_prim
->
axis
();
i
<
input0
->
shape
().
size
();
++
i
)
{
new_k
*=
input0
->
shape
().
at
(
i
);
}
if
(
new_k
!=
input1
->
shape
().
at
(
1
))
{
MS_LOG
(
ERROR
)
<<
"Input1 size invalid"
;
return
RET_PARAM_INVALID
;
if
(
use_axis
)
{
for
(
int
i
=
axis
;
i
<
input0
->
shape
().
size
();
++
i
)
{
new_k
*=
input0
->
shape
()[
i
];
}
if
(
new_k
!=
input1
->
shape
()[
1
])
{
MS_LOG
(
ERROR
)
<<
"Input1 size invalid"
;
return
RET_PARAM_INVALID
;
}
}
else
{
new_k
=
input1
->
shape
()[
1
];
}
if
(
fc_prim
->
hasBias
())
{
if
(
inputs_
.
at
(
2
)
->
shape
()[
0
]
!=
input1
->
shape
()[
0
])
{
if
(
inputs_
[
2
]
->
shape
()[
0
]
!=
input1
->
shape
()[
0
])
{
MS_LOG
(
ERROR
)
<<
"bias size invalid"
;
return
RET_PARAM_INVALID
;
}
}
std
::
vector
<
int
>
out_shape
{
inputs_
[
0
]
->
shape
()};
out_shape
.
resize
(
fc_prim
->
axis
()
+
1
);
out_shape
[
fc_prim
->
axis
()]
=
input1
->
shape
()[
0
];
if
(
use_axis
)
{
out_shape
.
resize
(
fc_prim
->
axis
()
+
1
);
out_shape
[
fc_prim
->
axis
()]
=
input1
->
shape
()[
0
];
}
else
{
int
total
=
1
;
for
(
int
i
=
0
;
i
<
input0
->
shape
().
size
();
++
i
)
{
total
*=
input0
->
shape
()[
i
];
}
out_shape
.
resize
(
2
);
auto
batch_size
=
total
/
new_k
;
out_shape
[
0
]
=
batch_size
;
out_shape
[
1
]
=
input1
->
shape
()[
0
];
}
output
->
set_shape
(
out_shape
);
output
->
set_data_type
(
input0
->
data_type
());
output
->
SetFormat
(
input0
->
GetFormat
());
...
...
mindspore/lite/src/populate_parameter.cc
浏览文件 @
96bcd27a
...
...
@@ -226,7 +226,14 @@ OpParameter *PopulateFullconnectionParameter(const lite::Primitive *primitive) {
matmul_param
->
b_transpose_
=
true
;
matmul_param
->
a_transpose_
=
false
;
matmul_param
->
has_bias_
=
param
->
hasBias
();
matmul_param
->
act_type_
=
ActType_No
;
if
(
param
->
activationType
()
==
schema
::
ActivationType_RELU
)
{
matmul_param
->
act_type_
=
ActType_Relu
;
}
else
if
(
param
->
activationType
()
==
schema
::
ActivationType_RELU6
)
{
matmul_param
->
act_type_
=
ActType_Relu6
;
}
else
{
matmul_param
->
act_type_
=
ActType_No
;
}
return
reinterpret_cast
<
OpParameter
*>
(
matmul_param
);
}
...
...
mindspore/lite/test/ut/src/runtime/kernel/arm/fp32/power_fp32_tests.cc
浏览文件 @
96bcd27a
...
...
@@ -48,6 +48,22 @@ int PowerTestInit(std::vector<lite::tensor::Tensor *> *inputs_, std::vector<lite
return
out_t
->
ElementsNum
();
}
int
PowerTestInit2
(
std
::
vector
<
lite
::
tensor
::
Tensor
*>
*
inputs_
,
std
::
vector
<
lite
::
tensor
::
Tensor
*>
*
outputs_
,
float
*
a_ptr
,
std
::
vector
<
int
>
a_shape
,
std
::
vector
<
int
>
c_shape
)
{
auto
in_t
=
new
lite
::
tensor
::
Tensor
(
kNumberTypeFloat
,
a_shape
,
schema
::
Format_NHWC
,
static_cast
<
schema
::
NodeType
>
(
1
));
in_t
->
MallocData
();
memcpy
(
in_t
->
Data
(),
a_ptr
,
sizeof
(
float
)
*
in_t
->
ElementsNum
());
inputs_
->
push_back
(
in_t
);
auto
out_t
=
new
lite
::
tensor
::
Tensor
(
kNumberTypeFloat
,
c_shape
,
schema
::
Format_NHWC
,
static_cast
<
schema
::
NodeType
>
(
1
));
out_t
->
MallocData
();
outputs_
->
push_back
(
out_t
);
return
out_t
->
ElementsNum
();
}
TEST_F
(
TestPowerFp32
,
Simple
)
{
std
::
vector
<
lite
::
tensor
::
Tensor
*>
inputs_
;
std
::
vector
<
lite
::
tensor
::
Tensor
*>
outputs_
;
...
...
@@ -62,13 +78,12 @@ TEST_F(TestPowerFp32, Simple) {
int
total_size
=
PowerTestInit
(
&
inputs_
,
&
outputs_
,
a
,
b
,
a_shape
,
b_shape
,
c_shape
);
auto
ctx
=
new
lite
::
Context
;
ctx
->
thread_num_
=
1
;
kernel
::
PowerCPUKernel
*
op
=
new
kernel
::
PowerCPUKernel
(
reinterpret_cast
<
OpParameter
*>
(
param
),
inputs_
,
outputs_
,
ctx
,
nullptr
);
kernel
::
PowerCPUKernel
*
op
=
new
kernel
::
PowerCPUKernel
(
reinterpret_cast
<
OpParameter
*>
(
param
),
inputs_
,
outputs_
,
ctx
,
nullptr
);
op
->
Init
();
op
->
Run
();
float
correct
[]
=
{
1
,
64
,
2187
,
65536
};
float
*
output
=
reinterpret_cast
<
float
*>
(
outputs_
[
0
]
->
Data
());
for
(
int
i
=
0
;
i
<
4
;
++
i
)
printf
(
"%f "
,
output
[
i
]);
CompareOutputData
(
reinterpret_cast
<
float
*>
(
outputs_
[
0
]
->
Data
()),
correct
,
total_size
,
0.0001
);
delete
op
;
for
(
auto
t
:
inputs_
)
delete
t
;
...
...
@@ -79,18 +94,17 @@ TEST_F(TestPowerFp32, Broadcast) {
std
::
vector
<
lite
::
tensor
::
Tensor
*>
inputs_
;
std
::
vector
<
lite
::
tensor
::
Tensor
*>
outputs_
;
auto
param
=
new
PowerParameter
();
param
->
power_
=
2
;
param
->
scale_
=
1
;
param
->
shift_
=
0
;
float
a
[]
=
{
1
,
2
,
3
,
4
};
float
b
[]
=
{
2
};
std
::
vector
<
int
>
a_shape
=
{
2
,
2
};
std
::
vector
<
int
>
b_shape
=
{
1
};
std
::
vector
<
int
>
c_shape
=
{
2
,
2
};
int
total_size
=
PowerTestInit
(
&
inputs_
,
&
outputs_
,
a
,
b
,
a_shape
,
b
_shape
,
c_shape
);
int
total_size
=
PowerTestInit
2
(
&
inputs_
,
&
outputs_
,
a
,
a
_shape
,
c_shape
);
auto
ctx
=
new
lite
::
Context
;
ctx
->
thread_num_
=
2
;
kernel
::
PowerCPUKernel
*
op
=
new
kernel
::
PowerCPUKernel
(
reinterpret_cast
<
OpParameter
*>
(
param
),
inputs_
,
outputs_
,
ctx
,
nullptr
);
kernel
::
PowerCPUKernel
*
op
=
new
kernel
::
PowerCPUKernel
(
reinterpret_cast
<
OpParameter
*>
(
param
),
inputs_
,
outputs_
,
ctx
,
nullptr
);
op
->
Init
();
op
->
Run
();
float
correct
[]
=
{
1
,
4
,
9
,
16
};
...
...
mindspore/lite/tools/converter/parser/tflite/tflite_fullyconnected_parser.cc
浏览文件 @
96bcd27a
...
...
@@ -38,6 +38,13 @@ STATUS TfliteFullyConnectedParser::Parse(const std::unique_ptr<tflite::OperatorT
MS_LOG
(
DEBUG
)
<<
"parse TfliteFullyConnectedParser"
;
std
::
unique_ptr
<
schema
::
FullConnectionT
>
attr
(
new
schema
::
FullConnectionT
());
const
auto
&
tflite_attr
=
tfliteOp
->
builtin_options
.
AsFullyConnectedOptions
();
if
(
tflite_attr
==
nullptr
)
{
MS_LOG
(
ERROR
)
<<
"get op: "
<<
op
->
name
<<
" attr failed"
;
return
RET_NULL_PTR
;
}
attr
->
activationType
=
GetActivationFunctionType
(
tflite_attr
->
fused_activation_function
);
auto
weight_index
=
tfliteOp
->
inputs
[
1
];
const
auto
&
weight_tensor
=
tfliteTensors
[
weight_index
];
if
(
weight_tensor
==
nullptr
)
{
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录