magicwindyyd / mindspore (forked from MindSpore / mindspore)

Commit 8b007f24
Authored on Sep 08, 2020 by mindspore-ci-bot; committed via Gitee on Sep 08, 2020

!5862 [MSLITE] grad ops added

Merge pull request !5862 from wangchangkai/master

Parents: e4c992aa, fc1d615f

Showing 18 changed files with 562 additions and 6 deletions (+562 -6)
mindspore/lite/schema/model.fbs               +1    -0
mindspore/lite/schema/ops.fbs                 +4    -1
mindspore/lite/src/ops/activation_grad.cc     +30   -0
mindspore/lite/src/ops/activation_grad.h      +2    -0
mindspore/lite/src/ops/bias_grad.cc           +28   -1
mindspore/lite/src/ops/bias_grad.h            +1    -1
mindspore/lite/src/ops/bn_grad.h              +2    -2
mindspore/lite/src/ops/bn_grad_input.cc       +75   -0
mindspore/lite/src/ops/bn_grad_input.h        +47   -0
mindspore/lite/src/ops/conv2d_grad_filter.cc  +126  -0
mindspore/lite/src/ops/conv2d_grad_filter.h   +6    -0
mindspore/lite/src/ops/conv2d_grad_input.cc   +126  -0
mindspore/lite/src/ops/conv2d_grad_input.h    +6    -0
mindspore/lite/src/ops/pooling_grad.cc        +57   -0
mindspore/lite/src/ops/pooling_grad.h         +2    -0
mindspore/lite/src/ops/power_grad.cc          +30   -1
mindspore/lite/src/ops/power_grad.h           +1    -0
mindspore/lite/src/ops/primitive_c.cc         +18   -0
mindspore/lite/schema/model.fbs
@@ -179,6 +179,7 @@ union PrimitiveType {
     Conv2DGradInput,
     PoolingGrad,
     BNGrad,
+    BNGradInput,
     ApplyMomentum,
     BiasGrad,
     SoftmaxCrossEntropy,
mindspore/lite/schema/ops.fbs
@@ -398,7 +398,10 @@ table BNGrad {
     eps : float;
     momentum: float;
 }
+table BNGradInput {
+    eps : float;
+    momentum: float;
+}
 table Scale {
     axis: int;
 }
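For readers unfamiliar with the FlatBuffers object API these wrappers build on: a table such as BNGradInput is compiled by flatc into a plain mutable struct (conventionally named BNGradInputT) that the PRIMITIVE_WRITEABLE code path fills in. The stand-in below is only an illustrative sketch of that shape under that assumption, not the generated code; the field names follow the schema above.

#include <iostream>

// Illustrative stand-in for the struct the FlatBuffers object API would generate
// from the BNGradInput table (assumed shape; the real one lives in generated headers).
struct BNGradInputT {
  float eps = 0.0f;
  float momentum = 0.0f;
};

int main() {
  BNGradInputT attr;   // what an UnPackAttr-style routine would allocate and fill
  attr.eps = 1e-5f;
  attr.momentum = 0.9f;
  std::cout << "eps=" << attr.eps << " momentum=" << attr.momentum << "\n";
  return 0;
}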
mindspore/lite/src/ops/activation_grad.cc
@@ -25,6 +25,36 @@ void ActivationGrad::SetType(int type) {
   this->primitive_->value.AsActivationGrad()->type = (schema::ActivationType)type;
 }
 void ActivationGrad::SetAlpha(float alpha) { this->primitive_->value.AsActivationGrad()->alpha = alpha; }
+int ActivationGrad::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) {
+  if (this->primitive_ == nullptr) {
+    this->primitive_ = new (std::nothrow) schema::PrimitiveT;
+    if (this->primitive_ == nullptr) {
+      MS_LOG(ERROR) << "new primitiveT failed";
+      return RET_ERROR;
+    }
+    this->primitive_->value.type = schema::PrimitiveType_ActivationGrad;
+  }
+  if (this->primitive_->value.type != schema::PrimitiveType_ActivationGrad) {
+    MS_LOG(ERROR) << "Primitive type is error :" << this->primitive_->value.type;
+    return RET_ERROR;
+  }
+  auto attr = std::make_unique<schema::ActivationGradT>();
+  if (prim.name() == "ReLU") {
+    attr->type = schema::ActivationType_RELU;
+  } else if (prim.name() == "Sigmoid") {
+    attr->type = schema::ActivationType_SIGMOID;
+  } else if (prim.name() == "ReLU6") {
+    attr->type = schema::ActivationType_RELU6;
+  }
+  auto alpha = GetValue<float>(prim.GetAttr("alpha"));
+  attr->alpha = alpha;
+  this->primitive_->value.value = attr.release();
+  if (this->primitive_->value.value == nullptr) {
+    MS_LOG(ERROR) << "new primitiveT value failed";
+    return RET_ERROR;
+  }
+  return RET_OK;
+}
 #else
 int ActivationGrad::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) {
   MS_ASSERT(nullptr != primitive);
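The heart of the new ActivationGrad::UnPackAttr is a mapping from the front-end primitive name to a schema activation type. Below is a self-contained sketch of that mapping in isolation; the enum and function names are placeholders rather than the real schema:: identifiers.

#include <iostream>
#include <string>

// Placeholder enum standing in for schema::ActivationType.
enum class ActivationType { NO_ACTIVATION = 0, RELU, SIGMOID, RELU6 };

// Mirrors the name-to-type branching in ActivationGrad::UnPackAttr above;
// the real code simply leaves attr->type at its default for other names.
ActivationType MapActivation(const std::string &prim_name) {
  if (prim_name == "ReLU") return ActivationType::RELU;
  if (prim_name == "Sigmoid") return ActivationType::SIGMOID;
  if (prim_name == "ReLU6") return ActivationType::RELU6;
  return ActivationType::NO_ACTIVATION;
}

int main() {
  std::cout << static_cast<int>(MapActivation("ReLU6")) << "\n";  // prints 3
  return 0;
}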
mindspore/lite/src/ops/activation_grad.h
@@ -20,6 +20,7 @@
 #include <vector>
 #include <set>
 #include <cmath>
+#include <memory>
 #include "ir/dtype/type_id.h"
 #include "src/ops/primitive_c.h"
@@ -33,6 +34,7 @@ class ActivationGrad : public PrimitiveC {
   explicit ActivationGrad(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {}
   void SetType(int type);
   void SetAlpha(float alpha);
+  int UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) override;
 #else
   ActivationGrad() = default;
mindspore/lite/src/ops/bias_grad.cc
@@ -22,7 +22,34 @@ namespace lite {
 std::vector<int> BiasGrad::GetAxis() const { return this->primitive_->value.AsBiasGrad()->axis; }
 void BiasGrad::SetAxis(const std::vector<int> &axis) { this->primitive_->value.AsBiasGrad()->axis = axis; }
+int BiasGrad::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) {
+  if (this->primitive_ == nullptr) {
+    this->primitive_ = new (std::nothrow) schema::PrimitiveT;
+    if (this->primitive_ == nullptr) {
+      MS_LOG(ERROR) << "new primitiveT failed";
+      return RET_ERROR;
+    }
+    this->primitive_->value.type = schema::PrimitiveType_BiasGrad;
+  }
+  if (this->primitive_->value.type != schema::PrimitiveType_BiasGrad) {
+    MS_LOG(ERROR) << "Primitive type is error :" << this->primitive_->value.type;
+    return RET_ERROR;
+  }
+  if (this->primitive_->value.value == nullptr) {
+    auto attr = new (std::nothrow) schema::BiasGradT();
+    if (attr == nullptr) {
+      MS_LOG(ERROR) << "new primitiveT value failed";
+      return RET_ERROR;
+    }
+    attr->axis = GetValue<std::vector<int>>(prim.GetAttr("axis"));
+    this->primitive_->value.value = attr;
+    if (this->primitive_->value.value == nullptr) {
+      MS_LOG(ERROR) << "primitive value is nullptr";
+      return RET_ERROR;
+    }
+  }
+  return RET_OK;
+}
 #else
 int BiasGrad::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) {
   MS_ASSERT(nullptr != primitive);
mindspore/lite/src/ops/bias_grad.h
@@ -33,7 +33,7 @@ class BiasGrad : public PrimitiveC {
   BiasGrad() = default;
   explicit BiasGrad(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {}
   void SetAxis(const std::vector<int> &axis);
+  int UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) override;
 #else
   BiasGrad() = default;
mindspore/lite/src/ops/bn_grad.h
@@ -14,8 +14,8 @@
  * limitations under the License.
  */
-#ifndef LITE_MINDSPORE_LITE_C_OPS_B_N_GRAD_INPUT_H_
-#define LITE_MINDSPORE_LITE_C_OPS_B_N_GRAD_INPUT_H_
+#ifndef LITE_MINDSPORE_LITE_C_OPS_B_N_GRAD_H_
+#define LITE_MINDSPORE_LITE_C_OPS_B_N_GRAD_H_
 #include <vector>
 #include <set>
mindspore/lite/src/ops/bn_grad_input.cc (new file, mode 100644)

/**
 * Copyright 2019-2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "src/ops/bn_grad_input.h"

namespace mindspore {
namespace lite {
#ifdef PRIMITIVE_WRITEABLE
float BNGradInput::GetEps() const { return this->primitive_->value.AsBNGradInput()->eps; }
float BNGradInput::GetMomentum() const { return this->primitive_->value.AsBNGradInput()->momentum; }

void BNGradInput::SetEps(float eps) { this->primitive_->value.AsBNGradInput()->eps = eps; }
void BNGradInput::SetMomentum(float momentum) { this->primitive_->value.AsBNGradInput()->momentum = momentum; }
int BNGradInput::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) {
  if (this->primitive_ == nullptr) {
    this->primitive_ = new (std::nothrow) schema::PrimitiveT;
    if (this->primitive_ == nullptr) {
      MS_LOG(ERROR) << "new primitiveT failed";
      return RET_ERROR;
    }
    this->primitive_->value.type = schema::PrimitiveType_BNGradInput;
  }
  if (this->primitive_->value.type != schema::PrimitiveType_BNGradInput) {
    MS_LOG(ERROR) << "Primitive type is error :" << this->primitive_->value.type;
    return RET_ERROR;
  }
  if (this->primitive_->value.value == nullptr) {
    auto attr = new (std::nothrow) schema::BNGradInputT();
    if (attr == nullptr) {
      MS_LOG(ERROR) << "new primitiveT value failed";
      return RET_ERROR;
    }
    attr->eps = GetValue<float>(prim.GetAttr("eps"));
    attr->momentum = GetValue<float>(prim.GetAttr("momentum"));
    this->primitive_->value.value = attr;
    if (this->primitive_->value.value == nullptr) {
      MS_LOG(ERROR) << "primitive value is nullptr";
      return RET_ERROR;
    }
  }
  return RET_OK;
}
#else
int BNGradInput::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) {
  MS_ASSERT(nullptr != primitive);
  MS_ASSERT(nullptr != fbb);
  auto attr = primitive->value_as_BNGradInput();
  if (attr == nullptr) {
    MS_LOG(ERROR) << "value_as_BNGradInputInput return nullptr";
    return RET_ERROR;
  }
  auto val_offset = schema::CreateBNGradInput(*fbb, attr->eps(), attr->momentum());
  auto prim_offset = schema::CreatePrimitive(*fbb, schema::PrimitiveType_BNGradInput, val_offset.o);
  fbb->Finish(prim_offset);
  return RET_OK;
}
float BNGradInput::GetEps() const { return this->primitive_->value_as_BNGradInput()->eps(); }
float BNGradInput::GetMomentum() const { return this->primitive_->value_as_BNGradInput()->momentum(); }
#endif
}  // namespace lite
}  // namespace mindspore
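As elsewhere in src/ops, the new BNGradInput wrapper compiles two backends behind #ifdef PRIMITIVE_WRITEABLE: on the converter side the getters read a mutable object-API struct that UnPackAttr fills, on the runtime side they read the FlatBuffer accessor. A minimal sketch of that pattern, with all types reduced to stand-ins, is below.

#include <iostream>

// Stand-ins only: the real code uses schema::BNGradInputT and the generated buffer view.
struct MutableBNGradInput { float eps = 1e-5f; float momentum = 0.9f; };
struct BufferBNGradInput  { float eps() const { return 1e-5f; } float momentum() const { return 0.9f; } };

#define PRIMITIVE_WRITEABLE_SKETCH 1  // flip to 0 to exercise the "runtime" path

float GetEps() {
#if PRIMITIVE_WRITEABLE_SKETCH
  MutableBNGradInput attr;  // converter side: fields can be assigned by an UnPackAttr-style routine
  return attr.eps;
#else
  BufferBNGradInput attr;   // runtime side: values come from read-only accessors
  return attr.eps();
#endif
}

int main() {
  std::cout << "eps=" << GetEps() << "\n";
  return 0;
}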
mindspore/lite/src/ops/bn_grad_input.h (new file, mode 100644)

/**
 * Copyright 2019-2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#ifndef LITE_MINDSPORE_LITE_C_OPS_B_N_GRAD_INPUT_H_
#define LITE_MINDSPORE_LITE_C_OPS_B_N_GRAD_INPUT_H_

#include <vector>
#include <set>
#include <cmath>
#include "ir/dtype/type_id.h"
#include "src/ops/primitive_c.h"

namespace mindspore {
namespace lite {
class BNGradInput : public PrimitiveC {
 public:
#ifdef PRIMITIVE_WRITEABLE
  MS_DECLARE_PARENT(BNGradInput, PrimitiveC);
  BNGradInput() = default;
  explicit BNGradInput(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {}
  void SetEps(float eps);
  void SetMomentum(float momentum);
  int UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) override;
#else
  BNGradInput() = default;
  int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override;
#endif
  float GetEps() const;
  float GetMomentum() const;
};
}  // namespace lite
}  // namespace mindspore
#endif  // LITE_MINDSPORE_LITE_C_OPS_B_N_GRAD_INPUT_H_
mindspore/lite/src/ops/conv2d_grad_filter.cc
@@ -66,7 +66,133 @@ void Conv2DGradFilter::SetHasBias(bool has_bias) { this->primitive_->value.AsCon
 void Conv2DGradFilter::SetActivationType(int activation_type) {
   this->primitive_->value.AsConv2DGradFilter()->activationType = (schema::ActivationType)activation_type;
 }
+void Conv2DGradFilter::PopulaterConv2DMultiGroup(const Primitive &prim, schema::PrimitiveT *primitive,
+                                                 const int &group, const std::vector<AnfNodePtr> &inputs) {
+  auto attr = std::make_unique<schema::DepthwiseConv2DT>();
+  auto format = GetValue<std::string>(prim.GetAttr("data_format"));
+  if (format == "NCHW") {
+    attr->format = schema::Format_NCHW;
+  } else if (format == "NHWC") {
+    attr->format = schema::Format_NHWC;
+  } else {
+    attr->format = schema::Format_NUM_OF_FORMAT;
+  }
+  auto pad_list = GetValue<std::vector<int>>(prim.GetAttr("pad_list"));
+  attr->padUp = pad_list[0];
+  attr->padDown = pad_list[1];
+  attr->padLeft = pad_list[2];
+  attr->padRight = pad_list[3];
+  auto dilation = GetValue<std::vector<int>>(prim.GetAttr("dilation"));
+  attr->dilateH = dilation[0];
+  attr->dilateW = dilation[1];
+  auto kernel_size = GetValue<std::vector<int>>(prim.GetAttr("kernel_size"));
+  attr->kernelH = kernel_size[0];
+  attr->kernelW = kernel_size[1];
+  auto stride = GetValue<std::vector<int>>(prim.GetAttr("stride"));
+  attr->strideH = stride[2];
+  attr->strideW = stride[3];
+  auto pad_mode = GetValue<std::string>(prim.GetAttr("pad_mode"));
+  if (pad_mode == "valid") {
+    attr->padMode = schema::PadMode_VALID;
+  } else if (pad_mode == "same") {
+    attr->padMode = schema::PadMode_SAME;
+  } else {
+    attr->padMode = schema::PadMode_NOTSET;
+  }
+  if (prim.GetAttr("activation_name") != nullptr) {
+    std::string activate_name = GetValue<std::string>(prim.GetAttr("activation_name"));
+    attr->activationType = kActivationTypeMap[activate_name];
+  } else {
+    attr->activationType = schema::ActivationType_NO_ACTIVATION;
+  }
+  int channel_mutiplier = 1;
+  if (prim.GetAttr("channel_mutiplier") != nullptr) {
+    channel_mutiplier = GetValue<int>(prim.GetAttr("channel_multiplier"));
+  }
+  attr->channelMultiplier = channel_mutiplier;
+  primitive->value.type = schema::PrimitiveType_DepthwiseConv2D;
+  primitive->value.value = attr.release();
+}
+void Conv2DGradFilter::PopulaterConv2DSingleGroup(const Primitive &prim, schema::PrimitiveT *primitive,
+                                                  const int &group) {
+  auto attr = std::make_unique<schema::Conv2DT>();
+  attr->group = group;
+  auto format = GetValue<std::string>(prim.GetAttr("data_format"));
+  if (format == "NCHW") {
+    attr->format = schema::Format_NCHW;
+  } else if (format == "NHWC") {
+    attr->format = schema::Format_NHWC;
+  } else {
+    attr->format = schema::Format_NUM_OF_FORMAT;
+  }
+  auto pad_list = GetValue<std::vector<int>>(prim.GetAttr("pad_list"));
+  attr->padUp = pad_list[0];
+  attr->padDown = pad_list[1];
+  attr->padLeft = pad_list[2];
+  attr->padRight = pad_list[3];
+  auto dilation = GetValue<std::vector<int>>(prim.GetAttr("dilation"));
+  attr->dilateH = dilation[0];
+  attr->dilateW = dilation[1];
+  auto kernel_size = GetValue<std::vector<int>>(prim.GetAttr("kernel_size"));
+  attr->kernelH = kernel_size[0];
+  attr->kernelW = kernel_size[1];
+  auto stride = GetValue<std::vector<int>>(prim.GetAttr("stride"));
+  attr->strideH = stride[2];
+  attr->strideW = stride[3];
+  attr->channelOut = GetValue<int>(prim.GetAttr("out_channel"));
+  auto pad_mode = GetValue<std::string>(prim.GetAttr("pad_mode"));
+  if (pad_mode == "valid") {
+    attr->padMode = schema::PadMode_VALID;
+  } else if (pad_mode == "same") {
+    attr->padMode = schema::PadMode_SAME;
+  } else {
+    attr->padMode = schema::PadMode_NOTSET;
+  }
+  if (prim.GetAttr("activation_name") != nullptr) {
+    std::string activate_name = GetValue<std::string>(prim.GetAttr("activation_name"));
+    attr->activationType = kActivationTypeMap[activate_name];
+  } else {
+    attr->activationType = schema::ActivationType_NO_ACTIVATION;
+  }
+  primitive->value.type = schema::PrimitiveType_Conv2D;
+  primitive->value.value = attr.release();
+}
+int Conv2DGradFilter::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) {
+  if (this->primitive_ == nullptr) {
+    this->primitive_ = new (std::nothrow) schema::PrimitiveT;
+    if (this->primitive_ == nullptr) {
+      MS_LOG(ERROR) << "new primitiveT failed";
+      return RET_ERROR;
+    }
+    this->primitive_->value.type = schema::PrimitiveType_Conv2DGradFilter;
+  }
+  if (this->primitive_->value.type != schema::PrimitiveType_Conv2DGradFilter) {
+    MS_LOG(ERROR) << "primitive_ type is error:" << this->primitive_->value.type;
+    return RET_ERROR;
+  }
+  int group = GetValue<int>(prim.GetAttr("group"));
+  if (group > 1) {
+    PopulaterConv2DMultiGroup(prim, this->primitive_, group, inputs);
+  } else {
+    PopulaterConv2DSingleGroup(prim, this->primitive_, group);
+  }
+  return RET_OK;
+}
 #else
 int Conv2DGradFilter::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) {
   MS_ASSERT(nullptr != primitive);
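The control flow added to Conv2DGradFilter::UnPackAttr is a two-way dispatch on the convolution group count: group > 1 is populated as a DepthwiseConv2D attribute, otherwise as a plain Conv2D attribute. The trivial sketch below only restates that rule in isolation; the function and return strings are illustrative, not part of the real API.

#include <iostream>
#include <string>

// Restates the dispatch rule from Conv2DGradFilter::UnPackAttr: grouped convolutions
// go through the depthwise populater, single-group ones through the plain Conv2D populater.
std::string SelectPopulater(int group) {
  return group > 1 ? "PopulaterConv2DMultiGroup (DepthwiseConv2DT)"
                   : "PopulaterConv2DSingleGroup (Conv2DT)";
}

int main() {
  std::cout << SelectPopulater(1) << "\n";
  std::cout << SelectPopulater(32) << "\n";
  return 0;
}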
mindspore/lite/src/ops/conv2d_grad_filter.h
@@ -20,6 +20,8 @@
 #include <vector>
 #include <set>
 #include <cmath>
+#include <memory>
+#include <string>
 #include "ir/dtype/type_id.h"
 #include "src/ops/primitive_c.h"
@@ -48,6 +50,10 @@ class Conv2DGradFilter : public PrimitiveC {
   void SetDilateH(int dilate_h);
   void SetHasBias(bool has_bias);
   void SetActivationType(int activation_type);
+  int UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) override;
+  void PopulaterConv2DMultiGroup(const Primitive &prim, schema::PrimitiveT *primitive, const int &group,
+                                 const std::vector<AnfNodePtr> &inputs);
+  void PopulaterConv2DSingleGroup(const Primitive &prim, schema::PrimitiveT *primitive, const int &group);
 #else
   Conv2DGradFilter() = default;
mindspore/lite/src/ops/conv2d_grad_input.cc
@@ -64,7 +64,133 @@ void Conv2DGradInput::SetHasBias(bool has_bias) { this->primitive_->value.AsConv
 void Conv2DGradInput::SetActivationType(int activation_type) {
   this->primitive_->value.AsConv2DGradInput()->activationType = (schema::ActivationType)activation_type;
 }
+void Conv2DGradInput::PopulaterConv2DMultiGroup(const Primitive &prim, schema::PrimitiveT *primitive,
+                                                const int &group, const std::vector<AnfNodePtr> &inputs) {
+  auto attr = std::make_unique<schema::DepthwiseConv2DT>();
+  auto format = GetValue<std::string>(prim.GetAttr("data_format"));
+  if (format == "NCHW") {
+    attr->format = schema::Format_NCHW;
+  } else if (format == "NHWC") {
+    attr->format = schema::Format_NHWC;
+  } else {
+    attr->format = schema::Format_NUM_OF_FORMAT;
+  }
+  auto pad_list = GetValue<std::vector<int>>(prim.GetAttr("pad_list"));
+  attr->padUp = pad_list[0];
+  attr->padDown = pad_list[1];
+  attr->padLeft = pad_list[2];
+  attr->padRight = pad_list[3];
+  auto dilation = GetValue<std::vector<int>>(prim.GetAttr("dilation"));
+  attr->dilateH = dilation[0];
+  attr->dilateW = dilation[1];
+  auto kernel_size = GetValue<std::vector<int>>(prim.GetAttr("kernel_size"));
+  attr->kernelH = kernel_size[0];
+  attr->kernelW = kernel_size[1];
+  auto stride = GetValue<std::vector<int>>(prim.GetAttr("stride"));
+  attr->strideH = stride[2];
+  attr->strideW = stride[3];
+  auto pad_mode = GetValue<std::string>(prim.GetAttr("pad_mode"));
+  if (pad_mode == "valid") {
+    attr->padMode = schema::PadMode_VALID;
+  } else if (pad_mode == "same") {
+    attr->padMode = schema::PadMode_SAME;
+  } else {
+    attr->padMode = schema::PadMode_NOTSET;
+  }
+  if (prim.GetAttr("activation_name") != nullptr) {
+    std::string activate_name = GetValue<std::string>(prim.GetAttr("activation_name"));
+    attr->activationType = kActivationTypeMap[activate_name];
+  } else {
+    attr->activationType = schema::ActivationType_NO_ACTIVATION;
+  }
+  int channel_mutiplier = 1;
+  if (prim.GetAttr("channel_mutiplier") != nullptr) {
+    channel_mutiplier = GetValue<int>(prim.GetAttr("channel_multiplier"));
+  }
+  attr->channelMultiplier = channel_mutiplier;
+  primitive->value.type = schema::PrimitiveType_DepthwiseConv2D;
+  primitive->value.value = attr.release();
+}
+void Conv2DGradInput::PopulaterConv2DSingleGroup(const Primitive &prim, schema::PrimitiveT *primitive,
+                                                 const int &group) {
+  auto attr = std::make_unique<schema::Conv2DT>();
+  attr->group = group;
+  auto format = GetValue<std::string>(prim.GetAttr("data_format"));
+  if (format == "NCHW") {
+    attr->format = schema::Format_NCHW;
+  } else if (format == "NHWC") {
+    attr->format = schema::Format_NHWC;
+  } else {
+    attr->format = schema::Format_NUM_OF_FORMAT;
+  }
+  auto pad_list = GetValue<std::vector<int>>(prim.GetAttr("pad_list"));
+  attr->padUp = pad_list[0];
+  attr->padDown = pad_list[1];
+  attr->padLeft = pad_list[2];
+  attr->padRight = pad_list[3];
+  auto dilation = GetValue<std::vector<int>>(prim.GetAttr("dilation"));
+  attr->dilateH = dilation[0];
+  attr->dilateW = dilation[1];
+  auto kernel_size = GetValue<std::vector<int>>(prim.GetAttr("kernel_size"));
+  attr->kernelH = kernel_size[0];
+  attr->kernelW = kernel_size[1];
+  auto stride = GetValue<std::vector<int>>(prim.GetAttr("stride"));
+  attr->strideH = stride[2];
+  attr->strideW = stride[3];
+  attr->channelOut = GetValue<int>(prim.GetAttr("out_channel"));
+  auto pad_mode = GetValue<std::string>(prim.GetAttr("pad_mode"));
+  if (pad_mode == "valid") {
+    attr->padMode = schema::PadMode_VALID;
+  } else if (pad_mode == "same") {
+    attr->padMode = schema::PadMode_SAME;
+  } else {
+    attr->padMode = schema::PadMode_NOTSET;
+  }
+  if (prim.GetAttr("activation_name") != nullptr) {
+    std::string activate_name = GetValue<std::string>(prim.GetAttr("activation_name"));
+    attr->activationType = kActivationTypeMap[activate_name];
+  } else {
+    attr->activationType = schema::ActivationType_NO_ACTIVATION;
+  }
+  primitive->value.type = schema::PrimitiveType_Conv2D;
+  primitive->value.value = attr.release();
+}
+int Conv2DGradInput::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) {
+  if (this->primitive_ == nullptr) {
+    this->primitive_ = new (std::nothrow) schema::PrimitiveT;
+    if (this->primitive_ == nullptr) {
+      MS_LOG(ERROR) << "new primitiveT failed";
+      return RET_ERROR;
+    }
+    this->primitive_->value.type = schema::PrimitiveType_Conv2DGradInput;
+  }
+  if (this->primitive_->value.type != schema::PrimitiveType_Conv2DGradInput) {
+    MS_LOG(ERROR) << "primitive_ type is error:" << this->primitive_->value.type;
+    return RET_ERROR;
+  }
+  int group = GetValue<int>(prim.GetAttr("group"));
+  if (group > 1) {
+    PopulaterConv2DMultiGroup(prim, this->primitive_, group, inputs);
+  } else {
+    PopulaterConv2DSingleGroup(prim, this->primitive_, group);
+  }
+  return RET_OK;
+}
 #else
 int Conv2DGradInput::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) {
   MS_ASSERT(nullptr != primitive);
mindspore/lite/src/ops/conv2d_grad_input.h
@@ -20,6 +20,8 @@
 #include <vector>
 #include <set>
 #include <cmath>
+#include <memory>
+#include <string>
 #include "ir/dtype/type_id.h"
 #include "src/ops/primitive_c.h"
@@ -48,6 +50,10 @@ class Conv2DGradInput : public PrimitiveC {
   void SetDilateH(int dilate_h);
   void SetHasBias(bool has_bias);
   void SetActivationType(int activation_type);
+  int UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) override;
+  void PopulaterConv2DMultiGroup(const Primitive &prim, schema::PrimitiveT *primitive, const int &group,
+                                 const std::vector<AnfNodePtr> &inputs);
+  void PopulaterConv2DSingleGroup(const Primitive &prim, schema::PrimitiveT *primitive, const int &group);
 #else
   Conv2DGradInput() = default;
mindspore/lite/src/ops/pooling_grad.cc
@@ -52,7 +52,64 @@ void PoolingGrad::SetPadRight(int pad_right) { this->primitive_->value.AsPooling
 void PoolingGrad::SetRoundMode(int round_mode) {
   this->primitive_->value.AsPoolingGrad()->roundMode = (schema::RoundMode)round_mode;
 }
+int PoolingGrad::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) {
+  if (this->primitive_ == nullptr) {
+    this->primitive_ = new (std::nothrow) schema::PrimitiveT;
+    if (this->primitive_ == nullptr) {
+      MS_LOG(ERROR) << "new primitiveT failed";
+      return RET_ERROR;
+    }
+    this->primitive_->value.type = schema::PrimitiveType_PoolingGrad;
+  }
+  if (this->primitive_->value.type != schema::PrimitiveType_PoolingGrad) {
+    MS_LOG(ERROR) << "Primitive type is error :" << this->primitive_->value.type;
+    return RET_ERROR;
+  }
+  if (this->primitive_->value.value == nullptr) {
+    auto attr = new (std::nothrow) schema::PoolingGradT();
+    if (attr == nullptr) {
+      MS_LOG(ERROR) << "new primitiveT value failed";
+      return RET_ERROR;
+    }
+    auto format = GetValue<std::string>(prim.GetAttr("data_format"));
+    if (format == "NCHW") {
+      attr->format = schema::Format_NCHW;
+    } else if (format == "NHWC") {
+      attr->format = schema::Format_NHWC;
+    } else {
+      attr->format = schema::Format_NUM_OF_FORMAT;
+    }
+    if (prim.instance_name() == "MaxPool") {
+      attr->poolingMode = schema::PoolMode_MAX_POOLING;
+    } else if (prim.instance_name() == "MeanPool") {
+      attr->poolingMode = schema::PoolMode_MEAN_POOLING;
+    }
+    auto pad_mode = GetValue<std::string>(prim.GetAttr("padding"));
+    if (pad_mode == "VALID") {
+      attr->padMode = schema::PadMode_VALID;
+    } else if (pad_mode == "SAME") {
+      attr->padMode = schema::PadMode_SAME;
+    } else {
+      attr->padMode = schema::PadMode_NOTSET;
+    }
+    auto kernel_size = GetValue<std::vector<int>>(prim.GetAttr("ksize"));
+    attr->windowH = kernel_size[2];
+    attr->windowW = kernel_size[3];
+    auto stride = GetValue<std::vector<int>>(prim.GetAttr("strides"));
+    attr->strideH = stride[2];
+    attr->strideW = stride[3];
+    this->primitive_->value.value = attr;
+    if (this->primitive_->value.value == nullptr) {
+      MS_LOG(ERROR) << "primitive value is nullptr";
+      return RET_ERROR;
+    }
+  }
+  return RET_OK;
+}
 #else
 int PoolingGrad::GetFormat() const { return this->primitive_->value_as_PoolingGrad()->format(); }
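PoolingGrad::UnPackAttr assumes the "ksize" and "strides" attributes arrive as 4-element NCHW-ordered vectors, so the spatial window and stride sit at indices 2 and 3 (note the Conv2D grads above read kernel_size from indices 0/1 but stride from 2/3). The self-contained snippet below just demonstrates that indexing assumption.

#include <iostream>
#include <vector>

int main() {
  // NCHW-ordered pooling attributes, as PoolingGrad::UnPackAttr expects them.
  std::vector<int> ksize = {1, 1, 2, 2};
  std::vector<int> strides = {1, 1, 2, 2};
  std::cout << "windowH=" << ksize[2] << " windowW=" << ksize[3]
            << " strideH=" << strides[2] << " strideW=" << strides[3] << "\n";
  return 0;
}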
mindspore/lite/src/ops/pooling_grad.h
@@ -20,6 +20,7 @@
 #include <vector>
 #include <set>
 #include <cmath>
+#include <string>
 #include "ir/dtype/type_id.h"
 #include "src/ops/primitive_c.h"
@@ -44,6 +45,7 @@ class PoolingGrad : public PrimitiveC {
   void SetPadLeft(int pad_left);
   void SetPadRight(int pad_right);
   void SetRoundMode(int round_mode);
+  int UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) override;
 #else
   PoolingGrad() = default;
mindspore/lite/src/ops/power_grad.cc
@@ -26,7 +26,36 @@ float PowerGrad::GetShift() const { return this->primitive_->value.AsPowerGrad()
 void PowerGrad::SetPower(float power) { this->primitive_->value.AsPowerGrad()->power = power; }
 void PowerGrad::SetScale(float scale) { this->primitive_->value.AsPowerGrad()->scale = scale; }
 void PowerGrad::SetShift(float shift) { this->primitive_->value.AsPowerGrad()->shift = shift; }
+int PowerGrad::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) {
+  if (this->primitive_ == nullptr) {
+    this->primitive_ = new (std::nothrow) schema::PrimitiveT;
+    if (this->primitive_ == nullptr) {
+      MS_LOG(ERROR) << "new primitiveT failed";
+      return RET_ERROR;
+    }
+    this->primitive_->value.type = schema::PrimitiveType_PowerGrad;
+  }
+  if (this->primitive_->value.type != schema::PrimitiveType_PowerGrad) {
+    MS_LOG(ERROR) << "Primitive type is error :" << this->primitive_->value.type;
+    return RET_ERROR;
+  }
+  if (this->primitive_->value.value == nullptr) {
+    auto attr = new (std::nothrow) schema::PowerGradT();
+    if (attr == nullptr) {
+      MS_LOG(ERROR) << "new primitiveT value failed";
+      return RET_ERROR;
+    }
+    attr->power = GetValue<float>(prim.GetAttr("power"));
+    attr->scale = GetValue<float>(prim.GetAttr("scale"));
+    attr->shift = GetValue<float>(prim.GetAttr("shift"));
+    this->primitive_->value.value = attr;
+    if (this->primitive_->value.value == nullptr) {
+      MS_LOG(ERROR) << "primitive value is nullptr";
+      return RET_ERROR;
+    }
+  }
+  return RET_OK;
+}
 #else
 float PowerGrad::GetPower() const { return this->primitive_->value_as_PowerGrad()->power(); }
mindspore/lite/src/ops/power_grad.h
@@ -34,6 +34,7 @@ class PowerGrad : public PrimitiveC {
   void SetPower(float power);
   void SetScale(float scale);
   void SetShift(float shift);
+  int UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) override;
 #else
   PowerGrad() = default;
mindspore/lite/src/ops/primitive_c.cc
@@ -383,6 +383,20 @@ std::shared_ptr<PrimitiveC> PrimitiveC::UnPackFromPrimitive(const Primitive &pri
     return NewPrimitiveC<ApplyMomentum>(prim, inputs, quantType);
   } else if (op_type == "BatchNormGrad") {
     return NewPrimitiveC<BNGrad>(prim, inputs, quantType);
+  } else if (op_type == "Conv2DGradInput") {
+    return NewPrimitiveC<Conv2DGradInput>(prim, inputs, quantType);
+  } else if (op_type == "Conv2DGradFilter") {
+    return NewPrimitiveC<Conv2DGradFilter>(prim, inputs, quantType);
+  } else if (op_type == "BiasGrad") {
+    return NewPrimitiveC<BiasGrad>(prim, inputs, quantType);
+  } else if (op_type == "ActivationGrad") {
+    return NewPrimitiveC<ActivationGrad>(prim, inputs, quantType);
+  } else if (op_type == "PoolingGrad") {
+    return NewPrimitiveC<PoolingGrad>(prim, inputs, quantType);
+  } else if (op_type == "BNGradInput") {
+    return NewPrimitiveC<BNGradInput>(prim, inputs, quantType);
+  } else if (op_type == "PowerGrad") {
+    return NewPrimitiveC<PowerGrad>(prim, inputs, quantType);
 #endif
   } else {
     MS_LOG(ERROR) << "Unsupported primitive type in UnPackFromPrimitive : " << op_type;
@@ -620,6 +634,10 @@ PrimitiveC *PrimitiveC::UnPackFromSchemaPrimitiveT(mindspore::schema::PrimitiveT
       return new ArithmeticGrad(primitive);
     case schema::PrimitiveType_DivGrad:
       return new ArithmeticGrad(primitive);
+    case schema::PrimitiveType_PowerGrad:
+      return new PowerGrad(primitive);
+    case schema::PrimitiveType_BNGradInput:
+      return new BNGradInput(primitive);
 #endif
     default:
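The primitive_c.cc change simply extends two dispatch points with the new grad ops: the string-keyed if/else-if chain in UnPackFromPrimitive (used when importing from the front-end IR) and the switch in UnPackFromSchemaPrimitiveT (used when reading a serialized model). The sketch below captures the same idea with a factory map; the class names are trimmed stand-ins, and the real code uses NewPrimitiveC<T> and an else-if chain rather than a map.

#include <functional>
#include <iostream>
#include <map>
#include <memory>
#include <string>

// Stand-in base class; the real hierarchy is mindspore::lite::PrimitiveC and its subclasses.
struct PrimitiveC {
  virtual ~PrimitiveC() = default;
  virtual std::string Name() const = 0;
};
struct PoolingGrad : PrimitiveC { std::string Name() const override { return "PoolingGrad"; } };
struct BNGradInput : PrimitiveC { std::string Name() const override { return "BNGradInput"; } };

// Equivalent of the extended dispatch: map an exported op name to a factory for its wrapper.
std::unique_ptr<PrimitiveC> NewPrimitive(const std::string &op_type) {
  static const std::map<std::string, std::function<std::unique_ptr<PrimitiveC>()>> factories = {
      {"PoolingGrad", [] { return std::unique_ptr<PrimitiveC>(new PoolingGrad()); }},
      {"BNGradInput", [] { return std::unique_ptr<PrimitiveC>(new BNGradInput()); }},
  };
  auto it = factories.find(op_type);
  return it == factories.end() ? nullptr : it->second();
}

int main() {
  auto p = NewPrimitive("BNGradInput");
  std::cout << (p ? p->Name() : "Unsupported primitive type") << "\n";
  return 0;
}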