Commit c90b66a0
Repository: magicwindyyd/mindspore (forked from MindSpore/mindspore)
Authored Apr 27, 2020 by mindspore-ci-bot; committed via Gitee on Apr 27, 2020

!777 fix bugs and dock ops

Merge pull request !777 from zhangbuxue/fix_bug_and_dock_ops

Parents: 118d434a, 381acf61

Showing 30 changed files with 291 additions and 98 deletions (+291 -98).
mindspore/ccsrc/debug/info.h                                  +1   -1
mindspore/ccsrc/debug/trace.cc                                +1   -1
mindspore/ccsrc/debug/trace_info.h                            +1   -1
mindspore/ccsrc/ir/dtype/type.cc                              +6   -0
mindspore/ccsrc/operator/composite/composite.cc               +8   -5
mindspore/ccsrc/optimizer/irpass.cc                           +0   -1
mindspore/ccsrc/optimizer/irpass.h                            +0   -1
mindspore/ccsrc/optimizer/irpass/grad_var_prepare.cc          +0   -1
mindspore/ccsrc/optimizer/irpass/grad_var_prepare.h           +0   -1
mindspore/ccsrc/pipeline/base.h                               +0   -2
mindspore/ccsrc/pipeline/pipeline.cc                          +1   -1
mindspore/ccsrc/pipeline/pipeline_ge.cc                       +21  -22
mindspore/ccsrc/pipeline/pipeline_ge.h                        +0   -2
mindspore/ccsrc/pipeline/static_analysis/abstract_function.h  +2   -2
mindspore/ccsrc/pipeline/static_analysis/prim.cc              +0   -1
mindspore/ccsrc/transform/util.cc                             +8   -11
mindspore/nn/optim/momentum.py                                +1   -1
mindspore/ops/_op_impl/tbe/__init__.py                        +5   -1
mindspore/ops/_op_impl/tbe/fill.py                            +2   -2
mindspore/ops/_op_impl/tbe/floor_mod.py                       +38  -0
mindspore/ops/_op_impl/tbe/greater_equal.py                   +45  -0
mindspore/ops/_op_impl/tbe/not_equal.py                       +45  -0
mindspore/ops/_op_impl/tbe/scatter_nd.py                      +1   -1
mindspore/ops/_op_impl/tbe/scatter_nd_update.py               +42  -0
mindspore/ops/composite/multitype_ops/getitem_impl.py         +15  -0
mindspore/ops/operations/array_ops.py                         +25  -24
mindspore/ops/operations/math_ops.py                          +2   -1
mindspore/ops/operations/nn_ops.py                            +14  -9
tests/ut/python/ops/test_ops.py                               +3   -3
tests/ut/python/ops/test_tensor_slice.py                      +4   -3
mindspore/ccsrc/debug/info.h
@@ -134,7 +134,7 @@ class DebugInfo : public Base {
   explicit DebugInfo(const LocationPtr &loc);
-  virtual ~DebugInfo() = default;
+  ~DebugInfo() override = default;
   MS_DECLARE_PARENT(DebugInfo, Base);
   int64_t debug_id();
   int64_t unique_id() const { return unique_id_; }

mindspore/ccsrc/debug/trace.cc
@@ -231,10 +231,10 @@ std::string AnalyzedFuncGraphExporter::GetNodeType(const AnfNodePtr &node) {
   auto engine = node_cfg_->engine();
   auto cfg = engine->MakeConfig(node, ctx);
   auto abs = engine->cache().GetValue(cfg);
   if (abs == nullptr) {
     return "Undefined";
   }
   auto dtype = abs->BuildType();
   auto shape = abs->BuildShape();
   std::ostringstream oss;

mindspore/ccsrc/debug/trace_info.h
@@ -321,7 +321,7 @@ class TraceTransform : public TraceInfo {
   std::string full_name() override { return full_name_ + transform_name_; }
   MS_DECLARE_PARENT(TraceTransform, TraceInfo);
-  virtual std::string symbol() {
+  std::string symbol() override {
     if (transform_name_.empty()) {
       return "";
     }

mindspore/ccsrc/ir/dtype/type.cc
@@ -87,6 +87,12 @@ const char *MetaIdLabel(const TypeId &v) {
       return "kMetaTypeExternal";
     case kMetaTypeNone:
       return "kMetaTypeNone";
+    case kMetaTypeNull:
+      return "kMetaTypeNull";
+    case kMetaTypeEllipsis:
+      return "kMetaTypeEllipsis";
+    case kMetaTypeEnd:
+      return "kMetaTypeEnd";
     default:
       return "[Unknown Type Id]";
   }

mindspore/ccsrc/operator/composite/composite.cc
@@ -1084,6 +1084,7 @@ int GenerateStridedSliceParametersFromTuple(const AbstractTuplePtr &slice_tuple,
   std::vector<unsigned int> shrink;
   auto slice_tuple_eles = slice_tuple->elements();
+  size_t ellipsis_num = 0;
   for (size_t index = 0; index < slice_tuple_size; index++) {
     if (slice_tuple_eles[index]->isa<AbstractSlice>()) {
       AbstractSlicePtr slice = dyn_cast<AbstractSlice>(slice_tuple_eles[index]);
@@ -1118,12 +1119,13 @@ int GenerateStridedSliceParametersFromTuple(const AbstractTuplePtr &slice_tuple,
                       << slice_tuple_eles[index]->ToString();
   }
-  for (size_t index = slice_tuple_size; index < shape_size; index++) {
-    begin->push_back(0);
-    end->push_back(shape[index]);
-    strides->push_back(1);
+  if (ellipsis_num == 0) {
+    for (size_t index = slice_tuple_size; index < shape_size; index++) {
+      begin->push_back(0);
+      end->push_back(shape[index]);
+      strides->push_back(1);
+    }
   }
   return ConvertBinaryToDecimal(shrink);
 }
@@ -1199,6 +1201,7 @@ FuncGraphPtr TensorSlice::GenerateFuncGraph(const AbstractBasePtrList &args_spec
     if (scalar_ptr->BuildValue()->cast<BoolImmPtr>()->value()) {
       return ExpandADim(ret_graph, tensor_node);
     }
+    MS_LOG(EXCEPTION) << "TensorSlice not support the index is False.";
   }
   shrink_axis_mask = GenerateStridedSliceParametersFromNumber(scalar_ptr, shape, &begin, &end, &strides);
 } else if (args_spec_list[1]->isa<AbstractEllipsis>()) {

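Here `shrink` collects a 0/1 flag per tuple element, and `ConvertBinaryToDecimal` packs those flags into the integer `shrink_axis_mask` consumed by StridedSlice. A minimal Python sketch of that packing, assuming bit i corresponds to dimension i (the function name and exact bit order are assumptions, not taken from the C++ source):

    def convert_binary_to_decimal(shrink):
        # Pack per-dimension 0/1 shrink flags into a bitmask: bit i <=> dim i shrunk.
        mask = 0
        for i, flag in enumerate(shrink):
            if flag:
                mask |= 1 << i
        return mask

    assert convert_binary_to_decimal([1, 0, 1]) == 0b101  # dims 0 and 2 are shrunk
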
mindspore/ccsrc/optimizer/irpass.cc
@@ -133,7 +133,6 @@ ResolveIRPassLib::ResolveIRPassLib() {
 InferenceOptPrepareLib::InferenceOptPrepareLib() {
   grad_var_prepare_ = MakeSubstitution(GradVarPrepare(), "grad_var_prepare", IsCNode);
 }
 }  // namespace irpass
 }  // namespace opt
 }  // namespace mindspore

mindspore/ccsrc/optimizer/irpass.h
@@ -159,7 +159,6 @@ inline bool IsCNodeDup(const AnfNodePtr &node) {
   }
   return false;
 }
 }  // namespace irpass
 }  // namespace opt
 }  // namespace mindspore

mindspore/ccsrc/optimizer/irpass/grad_var_prepare.cc
@@ -31,7 +31,6 @@
 namespace mindspore {
 namespace opt {
 namespace irpass {
 static AnfNodePtr GenerateUnpackGraphNode(std::vector<AnfNodePtr> inputs_y, FuncGraphPtr func_graph,
                                           AnfNodePtr func_node, bool is_unpack, bool sens_param) {
   MS_EXCEPTION_IF_NULL(func_graph);

mindspore/ccsrc/optimizer/irpass/grad_var_prepare.h
@@ -33,7 +33,6 @@
 namespace mindspore {
 namespace opt {
 namespace irpass {
 // {{GradOperation, g, w}, Ys}
 // {UnPackCall, {GradOperation, g, w}, Ys}
 class GradVarPrepare : public AnfVisitor {

mindspore/ccsrc/pipeline/base.h
@@ -28,13 +28,11 @@
 namespace mindspore {
 namespace pipeline {
 struct ExecutorInfo {
   FuncGraphPtr func_graph;
   ResourcePtr resource;
   std::size_t arg_list_size;
 };
 using ExecutorInfoPtr = std::shared_ptr<ExecutorInfo>;

 inline std::string GetPhasePrefix(const std::string &phase) {

mindspore/ccsrc/pipeline/pipeline.cc
@@ -101,7 +101,7 @@ py::tuple GenerateKey(const std::string &name, const std::unordered_map<std::str
     MS_LOG(INFO) << "Start new args and compile key:" << key;
     g_args_cache[args_spec] = key++;
   }
-  py::tuple argSpec = py::tuple(2);
+  auto argSpec = py::tuple(2);
   argSpec[0] = name;
   argSpec[1] = g_args_cache[args_spec];
   return argSpec;

mindspore/ccsrc/pipeline/pipeline_ge.cc
@@ -52,11 +52,11 @@ void DoExecNonInputGraph(const std::string &phase) {
   transform::RunOptions run_options;
   run_options.name = phase;
   auto graph_runner = DfGraphManager::GetInstance().GetGraphRunner();
   if (graph_runner == nullptr) {
     MS_LOG(ERROR) << "Can not found GraphRunner";
     return;
   }
   {
     // Release GIL before calling into (potentially long-running) C++ code
     py::gil_scoped_release release;
@@ -181,7 +181,6 @@ bool AddDFGraph(const std::map<std::string, ExecutorInfoPtr> &info, const py::di
   size_t pos = phase.find('.');
   std::string net_id = ((pos == std::string::npos || pos == phase.size() - 1) ? phase : phase.substr(pos + 1));
   std::string phase_prefix = phase.substr(0, pos);
   if (phase_prefix == "export") {
     MS_LOG(INFO) << "Set DfGraphConvertor training : false";
     convertor.set_training(false);
@@ -319,19 +318,24 @@ void RunGEInitGraph(const py::dict &init_params, const std::string &phase) {
 py::object ExtractGeneralCnodeRet(const AbstractBasePtr &cnode_data, const py::tuple &data, size_t *count) {
   MS_EXCEPTION_IF_NULL(cnode_data);
-  if (*count >= data.size()) {
-    MS_LOG(EXCEPTION) << "The number of elements in the outputs : " << data.size()
-                      << " less than the number of elements required. ";
-  }
   if (cnode_data->isa<AbstractTensor>()) {
+    if (*count >= data.size()) {
+      MS_LOG(EXCEPTION) << "The number of elements in the outputs : " << data.size()
+                        << " less than the number of elements required. ";
+    }
     BaseShapePtr shape = cnode_data->BuildShape();
-    auto shape_act = shape->cast<abstract::ShapePtr>()->shape();
-    Tensor tensor_exp = py::cast<Tensor>(data[*count]);
-    if (shape_act != tensor_exp.shape()) {
-      MS_LOG(EXCEPTION) << "The shape of the tensor returned from GE is not the same as "
-                           "the shape of the tensor derived from ME.";
-    }
+    if (!shape->isa<abstract::Shape>()) {
+      MS_LOG(EXCEPTION) << "The shape of the tensor derived is not Shape, is " << shape->ToString();
+    }
+    auto shape_me = shape->cast<abstract::ShapePtr>()->shape();
+    auto shape_ge = py::cast<Tensor>(data[*count]).shape();
+    if (shape_ge != shape_me) {
+      MS_LOG(EXCEPTION) << "The shape of the " << *count << "th tensor returned: " << shape_ge
+                        << " is not the same as the shape of the tensor derived: " << shape_me;
+    }
     return data[(*count)++];
   }
@@ -343,7 +347,7 @@ py::object ExtractGeneralCnodeRet(const AbstractBasePtr &cnode_data, const py::t
   auto data_tp = cnode_data->cast<AbstractTuplePtr>();
   auto elements = data_tp->elements();
   size_t size = data_tp->size();
-  py::tuple tp = py::tuple(size);
+  auto tp = py::tuple(size);
   for (size_t i = 0; i < size; i++) {
     tp[i] = ExtractGeneralCnodeRet(elements[i], data, count);
   }
@@ -357,11 +361,11 @@ py::object StructureOutput(const AnfNodePtr &output_node, const py::tuple &data,
     return ValuePtrToPyData(GetValueNode(output_node));
   }
-  if (*count >= data.size()) {
-    MS_LOG(EXCEPTION) << "The number of elements in the outputs : " << data.size()
-                      << " less than the number of elements required. ";
-  }
   if (output_node->isa<Parameter>()) {
+    if (*count >= data.size()) {
+      MS_LOG(EXCEPTION) << "The number of elements in the outputs : " << data.size()
+                        << " less than the number of elements required. ";
+    }
     return data[(*count)++];
   }
@@ -374,7 +378,7 @@ py::object StructureOutput(const AnfNodePtr &output_node, const py::tuple &data,
   if (output_c->IsApply(prim::kPrimMakeTuple)) {
     auto input_list = output_c->inputs();
     size_t size = input_list.size();
-    py::tuple tp = py::tuple(size - 1);
+    auto tp = py::tuple(size - 1);
     for (size_t i = 1; i < size; i++) {
       tp[i - 1] = StructureOutput(input_list[i], data, count);
     }
@@ -396,11 +400,8 @@ std::shared_ptr<py::object> DoExecGraph(const FuncGraphPtr &graph, const std::ve
   std::vector<GeTensorPtr> ge_outputs;
   transform::RunOptions run_options;
   run_options.name = phase;
   auto graph_runner = DfGraphManager::GetInstance().GetGraphRunner();
   if (graph_runner == nullptr) {
     MS_LOG(EXCEPTION) << "Can not found GraphRunner.";
   }
@@ -473,7 +474,6 @@ void ProcessGeArg(const std::map<std::string, ExecutorInfoPtr> &info, const py::
 py::object ExecDFGraph(const std::map<std::string, ExecutorInfoPtr> &info, const py::tuple &args,
                        const std::string &phase) {
   std::string phase_prefix = GetPhasePrefix(phase);
   if (phase_prefix == "save") {
     DoExecNonInputGraph(phase);
     ConfigManager::GetInstance().ResetConfig();
@@ -483,7 +483,6 @@ py::object ExecDFGraph(const std::map<std::string, ExecutorInfoPtr> &info, const
   if (info.count(phase) == 0) {
     MS_LOG(EXCEPTION) << "There is no phase:" << phase;
   }
   FuncGraphPtr anf_graph = info.at(phase)->func_graph;
 #ifdef ENABLE_INFER

mindspore/ccsrc/pipeline/pipeline_ge.h
@@ -31,7 +31,6 @@
 namespace mindspore {
 namespace pipeline {
 namespace py = pybind11;

 void SetGeOption(const std::map<std::string, std::string> &options);
@@ -50,7 +49,6 @@ bool InitExecDatasetGe(const std::string &queue_name, int64_t size, int64_t batc
                        const std::vector<int64_t> &input_indexes, const std::string &phase);
 void ExportDFGraph(const std::string &file_name, const std::string &phase);
 }  // namespace pipeline
 }  // namespace mindspore

mindspore/ccsrc/pipeline/static_analysis/abstract_function.h
@@ -41,7 +41,7 @@ class AbstractFuncAtom : public AbstractFunction {
   AbstractFunctionPtr Join(const AbstractFunctionPtr &other) final;
   void Visit(std::function<void(const AbstractFuncAtomPtr &)>) const final;
-  bool operator==(const AbstractFunction &other) const;
+  bool operator==(const AbstractFunction &other) const override;

   std::size_t hash() const override { return tid(); }
 };
@@ -270,7 +270,7 @@ class TypedPrimitiveAbstractClosure : public AbstractFuncAtom {
 class DummyAbstractClosure : public AbstractFuncAtom {
  public:
   DummyAbstractClosure() = default;
-  ~DummyAbstractClosure() = default;
+  ~DummyAbstractClosure() override = default;
   MS_DECLARE_PARENT(DummyAbstractClosure, AbstractFuncAtom)

   EvaluatorPtr GetEvaluator(AnalysisEnginePtr) override { MS_LOG(EXCEPTION) << "A dummy function cannot eval."; }

mindspore/ccsrc/pipeline/static_analysis/prim.cc
@@ -295,7 +295,6 @@ py::dict ConvertAbstractToPython(const AbstractBasePtr &abs_base) {
     dic["shape"] = shape;
     dic["dtype"] = arg_slice->BuildType();
     dic["value"] = BuildValue(arg_slice->BuildValue());
   } else if (abs_base->isa<AbstractTuple>()) {
     auto arg_tuple = dyn_cast<AbstractTuple>(abs_base);
     size_t len = arg_tuple->size();

mindspore/ccsrc/transform/util.cc
@@ -171,20 +171,17 @@ GeTensorPtr TransformUtil::ConvertTensor(const MeTensorPtr &tensor, const std::s
     MS_LOG(ERROR) << "The Me Tensor data type size is wrong, type size is: " << type_size;
     return nullptr;
   }
-  // get tensor buff size
-  size_t data_buff_size = 0;
   size_t elements_num = IntToSize(tensor->ElementsNum());
-  if (elements_num > 0 && type_size > 0 && UINT_MAX / type_size >= elements_num) {
-    data_buff_size = elements_num * type_size;
-  }
-  if (data_buff_size == 0) {
-    if (elements_num > 0 && type_size > 0 && UINT_MAX / type_size < elements_num) {
-      MS_LOG(ERROR) << "The required Me Tensor data buff size " << elements_num << " x " << type_size
-                    << " overflowed UINT_MAX: " << UINT_MAX << ".";
-    } else {
-      MS_LOG(ERROR) << "The Me Tensor data buff size is 0.";
-    }
-    return nullptr;
+  if (UINT_MAX / type_size < elements_num) {
+    MS_LOG(ERROR) << "The required Me Tensor data buff size " << elements_num << " x " << type_size
+                  << " overflowed UINT_MAX: " << UINT_MAX << ".";
+    return nullptr;
+  }
+  // get tensor buff size
+  size_t data_buff_size = elements_num * type_size;
+  if (data_buff_size == 0) {
+    MS_LOG(INFO) << "The Me Tensor data buff size is 0.";
   }
   // create ge tensor
   auto desc = GetGeTensorDesc(tensor->shape_c(), tensor->data_type(), format);

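The reworked check divides before multiplying, so elements_num * type_size can never wrap around an unsigned 32-bit value, and a zero-sized buffer is now only logged instead of rejected. A small Python sketch of the same overflow guard (UINT_MAX and the sample sizes here are stand-ins for the C++ values):

    UINT_MAX = 2**32 - 1

    def safe_buff_size(elements_num, type_size):
        # Divide first: if elements_num exceeds UINT_MAX / type_size,
        # the product elements_num * type_size would overflow a uint32.
        if UINT_MAX // type_size < elements_num:
            raise OverflowError(f"{elements_num} x {type_size} overflows UINT_MAX")
        return elements_num * type_size

    print(safe_buff_size(1024, 4))   # 4096
    # safe_buff_size(2**30, 8) would raise: 2**33 exceeds UINT_MAX
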
mindspore/nn/optim/momentum.py
@@ -56,7 +56,7 @@ class Momentum(Optimizer):
         - **gradients** (tuple[Tensor]) - The gradients of `params`, the shape is the same as `params`.

     Outputs:
-        Tensor[bool], the value is True.
+        tuple[bool], all elements are True.

     Raises:
         ValueError: If the momentum is less than 0.0.

mindspore/ops/_op_impl/tbe/__init__.py
@@ -142,8 +142,12 @@ from .smooth_l1_loss_grad import _smooth_l1_loss_grad_tbe
 from .fused_mul_add import _fused_mul_add_tbe
 from .fused_mul_add_n import _fused_mul_add_n_tbe
 from .fused_mul_apply_momentum import _fused_mul_apply_momentum_tbe
-from .fill_d import _fill_d_op_tbe
+from .fill import _fill_op_tbe
 from .erf import _erf_op_tbe
 from .depthwise_conv2d import _depthwise_conv2d_tbe
 from .depthwise_conv2d_backprop_filter import _depthwise_conv2d_backprop_filter_tbe
 from .depthwise_conv2d_backprop_input import _depthwise_conv2d_backprop_input_tbe
+from .greater_equal import _greater_equal_tbe
+from .not_equal import _not_equal_tbe
+from .floor_mod import _floor_mod_tbe
+from .scatter_nd_update import _scatter_nd_update_tbe

mindspore/ops/_op_impl/tbe/fill_d.py → mindspore/ops/_op_impl/tbe/fill.py
@@ -16,7 +16,7 @@
 """FillD op"""
 from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType

-fill_d_op_info = TBERegOp("FillD") \
+fill_d_op_info = TBERegOp("Fill") \
     .fusion_type("ELEMWISE") \
     .async_flag(False) \
     .binfile_name("fill_d.so") \
@@ -50,6 +50,6 @@ fill_d_op_info = TBERegOp("FillD") \
 @op_info_register(fill_d_op_info)
-def _fill_d_op_tbe():
+def _fill_op_tbe():
     """FillD TBE register"""
     return

mindspore/ops/_op_impl/tbe/floor_mod.py
0 → 100644
浏览文件 @
c90b66a0
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""FloorMod op"""
from
mindspore.ops.op_info_register
import
op_info_register
,
TBERegOp
,
DataType
floor_mod_op_info
=
TBERegOp
(
"FloorMod"
)
\
.
fusion_type
(
"ELEMWISE"
)
\
.
async_flag
(
False
)
\
.
binfile_name
(
"floor_mod.so"
)
\
.
compute_cost
(
10
)
\
.
kernel_name
(
"floor_mod"
)
\
.
partial_flag
(
True
)
\
.
input
(
0
,
"x1"
,
False
,
"required"
,
"all"
)
\
.
input
(
1
,
"x2"
,
False
,
"required"
,
"all"
)
\
.
output
(
0
,
"y"
,
False
,
"required"
,
"all"
)
\
.
dtype_format
(
DataType
.
F16_Default
,
DataType
.
F16_Default
,
DataType
.
F16_Default
)
\
.
dtype_format
(
DataType
.
F32_Default
,
DataType
.
F32_Default
,
DataType
.
F32_Default
)
\
.
dtype_format
(
DataType
.
I32_Default
,
DataType
.
I32_Default
,
DataType
.
I32_Default
)
\
.
get_op_info
()
@
op_info_register
(
floor_mod_op_info
)
def
_floor_mod_tbe
():
"""FloorMod TBE register"""
return
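With the registration imported, the TBE kernel backs the FloorMod primitive on Ascend. A hedged usage sketch (assuming the primitive is exposed as P.FloorMod in the ops API of this period, and a MindSpore install with this op available):

    import numpy as np
    from mindspore import Tensor
    from mindspore.ops import operations as P

    floor_mod = P.FloorMod()
    x1 = Tensor(np.array([7, -7, 7], np.int32))
    x2 = Tensor(np.array([3, 3, -3], np.int32))
    # Floor mod takes the sign of the divisor: [1, 2, -2]
    print(floor_mod(x1, x2))
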
mindspore/ops/_op_impl/tbe/greater_equal.py
0 → 100644
浏览文件 @
c90b66a0
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""GreaterEqual op"""
from
mindspore.ops.op_info_register
import
op_info_register
,
TBERegOp
,
DataType
greater_equal_op_info
=
TBERegOp
(
"GreaterEqual"
)
\
.
fusion_type
(
"OPAQUE"
)
\
.
async_flag
(
False
)
\
.
binfile_name
(
"greater_equal.so"
)
\
.
compute_cost
(
10
)
\
.
kernel_name
(
"greater_equal"
)
\
.
partial_flag
(
True
)
\
.
input
(
0
,
"x1"
,
False
,
"required"
,
"all"
)
\
.
input
(
1
,
"x2"
,
False
,
"required"
,
"all"
)
\
.
output
(
0
,
"y"
,
False
,
"required"
,
"all"
)
\
.
dtype_format
(
DataType
.
I8_Default
,
DataType
.
I8_Default
,
DataType
.
BOOL_Default
)
\
.
dtype_format
(
DataType
.
I8_5HD
,
DataType
.
I8_5HD
,
DataType
.
BOOL_5HD
)
\
.
dtype_format
(
DataType
.
U8_Default
,
DataType
.
U8_Default
,
DataType
.
BOOL_Default
)
\
.
dtype_format
(
DataType
.
U8_5HD
,
DataType
.
U8_5HD
,
DataType
.
BOOL_5HD
)
\
.
dtype_format
(
DataType
.
I32_Default
,
DataType
.
I32_Default
,
DataType
.
BOOL_Default
)
\
.
dtype_format
(
DataType
.
I32_5HD
,
DataType
.
I32_5HD
,
DataType
.
BOOL_5HD
)
\
.
dtype_format
(
DataType
.
F16_Default
,
DataType
.
F16_Default
,
DataType
.
BOOL_Default
)
\
.
dtype_format
(
DataType
.
F16_5HD
,
DataType
.
F16_5HD
,
DataType
.
BOOL_5HD
)
\
.
dtype_format
(
DataType
.
F32_Default
,
DataType
.
F32_Default
,
DataType
.
BOOL_Default
)
\
.
dtype_format
(
DataType
.
F32_5HD
,
DataType
.
F32_5HD
,
DataType
.
BOOL_5HD
)
\
.
get_op_info
()
@
op_info_register
(
greater_equal_op_info
)
def
_greater_equal_tbe
():
"""Greater TBE register"""
return
mindspore/ops/_op_impl/tbe/not_equal.py
0 → 100644
浏览文件 @
c90b66a0
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""NotEqual op"""
from
mindspore.ops.op_info_register
import
op_info_register
,
TBERegOp
,
DataType
not_equal_op_info
=
TBERegOp
(
"NotEqual"
)
\
.
fusion_type
(
"ELEMWISE"
)
\
.
async_flag
(
False
)
\
.
binfile_name
(
"not_equal.so"
)
\
.
compute_cost
(
10
)
\
.
kernel_name
(
"not_equal"
)
\
.
partial_flag
(
True
)
\
.
input
(
0
,
"x1"
,
False
,
"required"
,
"all"
)
\
.
input
(
1
,
"x2"
,
False
,
"required"
,
"all"
)
\
.
output
(
0
,
"y"
,
False
,
"required"
,
"all"
)
\
.
dtype_format
(
DataType
.
I8_Default
,
DataType
.
I8_Default
,
DataType
.
BOOL_Default
)
\
.
dtype_format
(
DataType
.
I8_5HD
,
DataType
.
I8_5HD
,
DataType
.
BOOL_5HD
)
\
.
dtype_format
(
DataType
.
U8_Default
,
DataType
.
U8_Default
,
DataType
.
BOOL_Default
)
\
.
dtype_format
(
DataType
.
U8_5HD
,
DataType
.
U8_5HD
,
DataType
.
BOOL_5HD
)
\
.
dtype_format
(
DataType
.
I32_Default
,
DataType
.
I32_Default
,
DataType
.
BOOL_Default
)
\
.
dtype_format
(
DataType
.
I32_5HD
,
DataType
.
I32_5HD
,
DataType
.
BOOL_5HD
)
\
.
dtype_format
(
DataType
.
F16_Default
,
DataType
.
F16_Default
,
DataType
.
BOOL_Default
)
\
.
dtype_format
(
DataType
.
F16_5HD
,
DataType
.
F16_5HD
,
DataType
.
BOOL_5HD
)
\
.
dtype_format
(
DataType
.
F32_Default
,
DataType
.
F32_Default
,
DataType
.
BOOL_Default
)
\
.
dtype_format
(
DataType
.
F32_5HD
,
DataType
.
F32_5HD
,
DataType
.
BOOL_5HD
)
\
.
get_op_info
()
@
op_info_register
(
not_equal_op_info
)
def
_not_equal_tbe
():
"""Equal TBE register"""
return
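Both registrations emit boolean tensors regardless of input dtype, matching the dtype_format rows above. A hedged sketch of the corresponding primitives (assuming P.GreaterEqual and P.NotEqual, as in the public ops API):

    import numpy as np
    from mindspore import Tensor
    from mindspore.ops import operations as P

    x1 = Tensor(np.array([1, 2, 3], np.float32))
    x2 = Tensor(np.array([3, 2, 1], np.float32))
    print(P.GreaterEqual()(x1, x2))  # [False, True, True]
    print(P.NotEqual()(x1, x2))      # [True, False, True]
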
mindspore/ops/_op_impl/tbe/scatter_nd.py
@@ -37,5 +37,5 @@ scatter_nd_op_info = TBERegOp("ScatterNd") \
 @op_info_register(scatter_nd_op_info)
 def _scatter_nd_tbe():
-    """Conv2D TBE register"""
+    """ScatterNd TBE register"""
     return

mindspore/ops/_op_impl/tbe/scatter_nd_update.py
0 → 100644
浏览文件 @
c90b66a0
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""ScatterNdUpdate op"""
from
mindspore.ops.op_info_register
import
op_info_register
,
TBERegOp
,
DataType
scatter_nd_update_op_info
=
TBERegOp
(
"ScatterNdUpdate"
)
\
.
fusion_type
(
"ELEMWISE"
)
\
.
async_flag
(
False
)
\
.
binfile_name
(
"scatter_nd_update.so"
)
\
.
compute_cost
(
10
)
\
.
kernel_name
(
"scatter_nd_update"
)
\
.
partial_flag
(
True
)
\
.
attr
(
"use_locking"
,
"optional"
,
"bool"
,
"all"
)
\
.
input
(
0
,
"var"
,
False
,
"required"
,
"all"
)
\
.
input
(
1
,
"indices"
,
False
,
"required"
,
"all"
)
\
.
input
(
1
,
"updates"
,
False
,
"required"
,
"all"
)
\
.
output
(
0
,
"var"
,
False
,
"required"
,
"all"
)
\
.
dtype_format
(
DataType
.
F16_Default
,
DataType
.
I32_Default
,
DataType
.
F16_Default
,
DataType
.
F16_Default
)
\
.
dtype_format
(
DataType
.
F32_Default
,
DataType
.
I32_Default
,
DataType
.
F32_Default
,
DataType
.
F32_Default
)
\
.
dtype_format
(
DataType
.
I8_Default
,
DataType
.
I32_Default
,
DataType
.
I8_Default
,
DataType
.
I8_Default
)
\
.
dtype_format
(
DataType
.
U8_Default
,
DataType
.
I32_Default
,
DataType
.
U8_Default
,
DataType
.
U8_Default
,)
\
.
dtype_format
(
DataType
.
BOOL_Default
,
DataType
.
I32_Default
,
DataType
.
BOOL_Default
,
DataType
.
BOOL_Default
)
\
.
get_op_info
()
@
op_info_register
(
scatter_nd_update_op_info
)
def
_scatter_nd_update_tbe
():
"""ScatterNdUpdate TBE register"""
return
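The four dtype_format columns line up with the operands: var, int32 indices, updates, and the updated var output. A hedged usage sketch (assuming the P.ScatterNdUpdate primitive applied to a Parameter, as in the public API):

    import numpy as np
    from mindspore import Tensor, Parameter
    from mindspore.ops import operations as P

    var = Parameter(Tensor(np.zeros((3, 2), np.float32)), name="var")
    indices = Tensor(np.array([[0], [2]], np.int32))  # rows 0 and 2
    updates = Tensor(np.ones((2, 2), np.float32))
    scatter = P.ScatterNdUpdate()
    # Writes `updates` into `var` at the given row indices; rows 0 and 2 become ones.
    print(scatter(var, indices, updates))
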
mindspore/ops/composite/multitype_ops/getitem_impl.py
@@ -147,6 +147,21 @@ def _tensor_getitem_by_number(data, number_index):
     return _tensor_slice(data, number_index)


+@getitem.register("Tensor", "None")
+def _tensor_getitem_by_none(data, index):
+    """
+    Getting item of tensor by None.
+
+    Inputs:
+        data (Tensor): A tensor.
+        index (None): None.
+
+    Outputs:
+        Tensor, element type is the same as the element type of data.
+    """
+    return _tensor_slice(data, index)
+
+
 @getitem.register("Tensor", "Slice")
 def _tensor_getitem_by_slice(data, slice_index):
     """

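The new handler routes tensor[None] through the same _tensor_slice meta-op as the other index kinds. This mirrors NumPy, where indexing with None inserts a new axis of size 1; a NumPy sketch of the expected semantics (NumPy used only for illustration, the MindSpore behavior is defined by TensorSlice above):

    import numpy as np

    x = np.ones((6, 7, 8, 9), np.int32)
    print(x[None].shape)  # (1, 6, 7, 8, 9) -- None prepends a size-1 axis
    print(x[True].shape)  # (1, 6, 7, 8, 9) -- a scalar True index behaves the same way
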
mindspore/ops/operations/array_ops.py
@@ -633,15 +633,15 @@ class TruncatedNormal(PrimitiveWithInfer):
         dtype (:class:`mindspore.dtype`): Data type. Default: mindspore.float32.

     Inputs:
-        - **shape** (Tensor) - Shape of output tensor. The shape is a 1-D tensor, and type is int.
+        - **shape** (tuple[int]) - Shape of output tensor, is a tuple of positive int.

     Outputs:
         Tensor, type of output tensor is same as attribute `dtype`.

     Examples:
-        >>> input_shape = Tensor(np.array([1, 2, 3]))
+        >>> shape = (1, 2, 3)
         >>> truncated_normal = P.TruncatedNormal()
-        >>> output = truncated_normal(input_shape)
+        >>> output = truncated_normal(shape)
     """

     @prim_attr_register
@@ -651,16 +651,12 @@ class TruncatedNormal(PrimitiveWithInfer):
         validator.check_typename('dtype', dtype, mstype.number_type)

     def __infer__(self, shape):
-        shape_t = shape['value']
-        validator.check_subclass("shape", shape['dtype'], mstype.tensor)
-        shape_n = shape_t.asnumpy()
-        if shape_n.ndim != 1:
-            raise ValueError('The rank of input shape must be 1.')
-        if shape_n.dtype not in (np.int32, np.int64):
-            raise TypeError('The type of input shape must be int32 or int64.')
-        for i, item in enumerate(shape_n):
-            validator.check_integer(f"shape[{i}]", item.item(), 0, Rel.GT)
-        out = {'shape': tuple(shape_n),
+        shape_value = shape['value']
+        validator.check_const_input("shape", shape_value)
+        validator.check_type("shape", shape_value, [tuple])
+        for i, value in enumerate(shape_value):
+            validator.check_integer(f'{i}th value of shape', value, 0, Rel.GT)
+        out = {'shape': shape_value,
                'dtype': mstype.tensor_type(self.dtype),
                'value': None}
         return out
@@ -1648,20 +1644,20 @@ class StridedSlice(PrimitiveWithInfer):
         validator.check_type('shrink_axis_mask', shrink_axis_mask, [int])

     def __infer__(self, x, begin, end, strides):
         begin_shape, end_shape, strides_shape = begin['shape'], end['shape'], strides['shape']
         if begin_shape != strides_shape or end_shape != strides_shape:
             raise ValueError("The shape of begin, end and strides in 'StridedSlice' must be equal.")
-        validator.check_const_input("begin", begin['value'])
-        validator.check_const_input("end", end['value'])
-        validator.check_const_input("strides", strides['value'])
-        validator.check_type("begin", begin['value'], [tuple])
-        validator.check_type("end", end['value'], [tuple])
-        validator.check_type("strides", strides['value'], [tuple])
+        begin_v, end_v, strides_v = begin['value'], end['value'], strides['value']
+        validator.check_const_input("begin", begin_v)
+        validator.check_const_input("end", end_v)
+        validator.check_const_input("strides", strides_v)
+        validator.check_type("begin", begin_v, [tuple])
+        validator.check_type("end", end_v, [tuple])
+        validator.check_type("strides", strides_v, [tuple])
         x_shape = x['shape']
         x_shp_len = len(x_shape)
-        begin_v, end_v, strides_v = begin['value'], end['value'], strides['value']
         if len(begin_v) != x_shp_len or len(end_v) != x_shp_len or len(strides_v) != x_shp_len:
             raise ValueError(f"The length of begin index {begin_v}, end index {end_v} and strides {strides_v} "
                              f"must be equal to the dims({x_shp_len}) of input.")
         ret_shape = []
         append_dimensions = []
         shrink_pos = bin(self.shrink_axis_mask)[::-1]
@@ -1914,6 +1910,11 @@ class ResizeNearestNeighbor(PrimitiveWithInfer):
     @prim_attr_register
     def __init__(self, size, align_corners=False):
         """Init ResizeNearestNeighbor"""
+        validator.check_type("size", size, [tuple, list])
+        validator.check_type("align_corners", align_corners, [bool])
+        validator.check_integer("length of size", len(size), 2, Rel.EQ)
+        for i, value in enumerate(size):
+            validator.check_integer(f'{i}th value of size', value, 0, Rel.GE)
         self.init_prim_io_names(inputs=['image_in'], outputs=['image_out'])

     def infer_shape(self, x):

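After this change the shape argument to TruncatedNormal must be a constant Python tuple rather than a Tensor, which the new check_const_input/check_type calls enforce at infer time. A hedged sketch of the updated call pattern:

    from mindspore.ops import operations as P

    truncated_normal = P.TruncatedNormal()  # dtype defaults to mindspore.float32
    output = truncated_normal((1, 2, 3))    # constant tuple, not a Tensor
    # output is a float32 Tensor of shape (1, 2, 3), drawn from a truncated normal.
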
mindspore/ops/operations/math_ops.py
@@ -1251,7 +1251,8 @@ class Acosh(PrimitiveWithInfer):
     Compute inverse hyperbolic cosine of x element-wise.

     Inputs:
-        - **input_x** (Tensor) - The shape of tensor is :math:`(x_1, x_2, ..., x_R)`.
+        - **input_x** (Tensor) - The shape of tensor is :math:`(x_1, x_2, ..., x_R)`,
+          its data type is number, and each element of `input_x` should be greater than or equal to 1.

     Outputs:
         Tensor, has the same shape as `input_x`.

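The domain note follows from cosh(y) >= 1 for real y, so acosh is only defined for inputs >= 1. A quick NumPy check of the boundary (NumPy used for illustration; the primitive itself is P.Acosh):

    import numpy as np

    x = np.array([1.0, 2.0, 10.0])
    print(np.arccosh(x))    # [0.0, 1.3169579, 2.9932228]
    print(np.arccosh(0.5))  # nan -- outside the domain, hence the docstring constraint
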
mindspore/ops/operations/nn_ops.py
@@ -753,8 +753,15 @@ class DepthwiseConv2dNative(PrimitiveWithInfer):
         self.init_prim_io_names(inputs=['x', 'w'], outputs=['output'])
         self.kernel_size = _check_positive_int_or_tuple('kernel_size', kernel_size, self.name)
         self.stride = _check_positive_int_or_tuple('stride', stride, self.name)
+        if self.stride[0] != self.stride[1]:
+            raise ValueError("The height and width of stride should be equal,"
+                             f"but got height:{self.stride[0]}, width:{self.stride[1]}")
+        self.add_prim_attr('stride', (1, 1, self.stride[0], self.stride[1]))
         self.dilation = _check_positive_int_or_tuple('dilation', dilation, self.name)
+        if self.dilation[0] != self.dilation[1]:
+            raise ValueError("The height and width of dilation should be equal,"
+                             f"but got height:{self.dilation[0]}, width:{self.dilation[1]}")
+        self.add_prim_attr('dilation', (1, 1, self.dilation[0], self.dilation[1]))
         validator.check_value_type('pad', pad, (int,), self.name)
         self.pad_mode = validator.check_string('pad_mode', pad_mode, ['valid', 'same', 'pad'], self.name)
@@ -771,13 +778,11 @@ class DepthwiseConv2dNative(PrimitiveWithInfer):
         validator.check("x_shape[1]", x_shape[1], "w_shape[1]", w_shape[1], Rel.EQ, self.name)
         validator.check('kernel_size', self.kernel_size, 'w_shape[2:4]', tuple(w_shape[2:4]), Rel.EQ, self.name)
-        kernel_size_h = w_shape[2]
-        kernel_size_w = w_shape[3]
-        stride_h = self.stride[2]
-        stride_w = self.stride[3]
-        dilation_h = self.dilation[2]
-        dilation_w = self.dilation[3]
+        kernel_size_n, _, kernel_size_h, kernel_size_w = w_shape
+        _, _, stride_h, stride_w = self.stride
+        _, _, dilation_h, dilation_w = self.dilation
+        if kernel_size_n != 1:
+            raise ValueError(f"The batch of input weight should be 1, but got {kernel_size_n}")
         if self.pad_mode == "valid":
             h_out = math.ceil((x_shape[2] - dilation_h * (kernel_size_h - 1)) / stride_h)
             w_out = math.ceil((x_shape[3] - dilation_w * (kernel_size_w - 1)) / stride_w)
@@ -1198,8 +1203,8 @@ class TopK(PrimitiveWithInfer):
         >>> input_x = Tensor([1, 2, 3, 4, 5], mindspore.float16)
         >>> k = 3
         >>> values, indices = topk(input_x, k)
-        >>> assert values == Tensor(np.array([5, 4, 3]))
-        >>> assert indices == Tensor(np.array([4, 3, 2]))
+        >>> assert values == Tensor(np.array([5, 4, 3]), mstype.float16)
+        >>> assert indices == Tensor(np.array([4, 3, 2]), mstype.int32)
     """

     @prim_attr_register

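The tuple unpacking also lets the infer function verify that the weight's batch dimension is 1, which is what makes the convolution depthwise. The valid-mode output size computed above reduces to a one-liner; a sketch of the same arithmetic (the helper name is hypothetical):

    import math

    def valid_out_size(in_size, kernel, stride, dilation):
        # 'valid' mode keeps only positions where the dilated kernel,
        # spanning dilation * (kernel - 1) + 1 pixels, fits entirely.
        return math.ceil((in_size - dilation * (kernel - 1)) / stride)

    # 32x32 input, 3x3 kernel, stride 2, dilation 1 -> 15x15 feature map
    print(valid_out_size(32, 3, 2, 1))  # 15
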
tests/ut/python/ops/test_ops.py
@@ -372,7 +372,7 @@ test_case_math_ops = [
         'desc_bprop': [[3]]}),
     ('TruncatedNormal', {
         'block': P.TruncatedNormal(),
-        'desc_const': [Tensor(np.array([1, 2, 3]))],
+        'desc_const': [(1, 2, 3)],
         'desc_inputs': [],
         'skip': ['backward'],
         'add_fake_input': True}),
@@ -793,8 +793,8 @@ test_case_nn_ops = [
         'desc_bprop': [[5, 5]]}),
     ('DepthwiseConv2dNative_1', {
         'block': P.DepthwiseConv2dNative(3, (3, 3), pad_mode="pad", pad=1, stride=2),
-        'desc_inputs': [[10, 32, 32, 32], [3, 32, 3, 3]],
-        'desc_bprop': [[10, 30, 16, 16]]}),
+        'desc_inputs': [[10, 32, 32, 32], [1, 32, 3, 3]],
+        'desc_bprop': [[10, 32, 16, 16]]}),
     ('DepthwiseConv2dNative_2', {
         'block': P.DepthwiseConv2dNative(1, (3, 3), pad_mode="same", pad=0, stride=1),
         'desc_inputs': [[2592, 2048, 4, 4], [1, 2048, 3, 3]],

tests/ut/python/ops/test_tensor_slice.py
@@ -52,8 +52,9 @@ class NetWorkSliceEllipsis(Cell):
     def construct(self, tensor):
         ret0 = tensor[0:4:2, ..., 1] + self.tensor_ret0
         ret1 = tensor[...] + self.tensor_ret1
-        ret2 = tensor[True] + self.tensor_ret2
-        return ret0, ret1, ret2
+        ret2 = tensor[None] + self.tensor_ret2
+        ret3 = tensor[True] + self.tensor_ret2
+        return ret0, ret1, ret2, ret3


 class NetWorkReduceDimension(Cell):
@@ -305,7 +306,7 @@ test_cases = [
         'block': NetWorkReduceToScalar(),
         'desc_inputs': [Tensor(np.ones([6, 8, 10], np.int32))],
     }),
-    ('NetWorkSliceEllipsis', {
+    ('TensorSliceEllipsis', {
         'block': NetWorkSliceEllipsis(),
         'desc_inputs': [Tensor(np.ones([6, 7, 8, 9], np.int32))],
     }),
