magicwindyyd / mindspore
Forked from MindSpore / mindspore (in sync with the upstream project)
Commit 875ebc29
Authored on Apr 26, 2020 by mindspore-ci-bot; committed by Gitee on Apr 26, 2020
!662 fix reviewboot and example of TruncatedNormal
Merge pull request !662 from zhangbuxue/fix-reviewboot
Parents: 5633f6c6, aff6777e
Showing 16 changed files with 26 additions and 32 deletions (+26, -32)
mindspore/ccsrc/debug/info.h                                    +1  -1
mindspore/ccsrc/debug/trace.cc                                  +1  -1
mindspore/ccsrc/debug/trace_info.h                              +1  -1
mindspore/ccsrc/ir/dtype/type.cc                                +6  -0
mindspore/ccsrc/optimizer/irpass.cc                             +0  -1
mindspore/ccsrc/optimizer/irpass.h                              +0  -1
mindspore/ccsrc/optimizer/irpass/grad_var_prepare.cc            +0  -1
mindspore/ccsrc/optimizer/irpass/grad_var_prepare.h             +0  -1
mindspore/ccsrc/pipeline/base.h                                 +0  -2
mindspore/ccsrc/pipeline/pipeline.cc                            +1  -1
mindspore/ccsrc/pipeline/pipeline_ge.cc                         +3  -9
mindspore/ccsrc/pipeline/pipeline_ge.h                          +0  -2
mindspore/ccsrc/pipeline/static_analysis/abstract_function.h    +2  -2
mindspore/ccsrc/pipeline/static_analysis/prim.cc                +0  -1
mindspore/ops/operations/array_ops.py                           +10 -7
tests/ut/python/ops/test_ops.py                                 +1  -1
mindspore/ccsrc/debug/info.h
@@ -134,7 +134,7 @@ class DebugInfo : public Base {
   explicit DebugInfo(const LocationPtr &loc);
-  virtual ~DebugInfo() = default;
+  ~DebugInfo() override = default;
   MS_DECLARE_PARENT(DebugInfo, Base);
   int64_t debug_id();
   int64_t unique_id() const { return unique_id_; }
mindspore/ccsrc/debug/trace.cc
@@ -231,10 +231,10 @@ std::string AnalyzedFuncGraphExporter::GetNodeType(const AnfNodePtr &node) {
   auto engine = node_cfg_->engine();
   auto cfg = engine->MakeConfig(node, ctx);
   auto abs = engine->cache().GetValue(cfg);
   if (abs == nullptr) {
     return "Undefined";
   }
   auto dtype = abs->BuildType();
   auto shape = abs->BuildShape();
   std::ostringstream oss;
mindspore/ccsrc/debug/trace_info.h
@@ -321,7 +321,7 @@ class TraceTransform : public TraceInfo {
   std::string full_name() override { return full_name_ + transform_name_; }
   MS_DECLARE_PARENT(TraceTransform, TraceInfo);
-  virtual std::string symbol() {
+  std::string symbol() override {
     if (transform_name_.empty()) {
       return "";
     }
mindspore/ccsrc/ir/dtype/type.cc
@@ -87,6 +87,12 @@ const char *MetaIdLabel(const TypeId &v) {
       return "kMetaTypeExternal";
     case kMetaTypeNone:
       return "kMetaTypeNone";
+    case kMetaTypeNull:
+      return "kMetaTypeNull";
+    case kMetaTypeEllipsis:
+      return "kMetaTypeEllipsis";
+    case kMetaTypeEnd:
+      return "kMetaTypeEnd";
     default:
       return "[Unknown Type Id]";
   }
mindspore/ccsrc/optimizer/irpass.cc
@@ -133,7 +133,6 @@ ResolveIRPassLib::ResolveIRPassLib() {
 InferenceOptPrepareLib::InferenceOptPrepareLib() {
   grad_var_prepare_ = MakeSubstitution(GradVarPrepare(), "grad_var_prepare", IsCNode);
 }
 }  // namespace irpass
 }  // namespace opt
 }  // namespace mindspore
mindspore/ccsrc/optimizer/irpass.h
@@ -159,7 +159,6 @@ inline bool IsCNodeDup(const AnfNodePtr &node) {
   }
   return false;
 }
 }  // namespace irpass
 }  // namespace opt
 }  // namespace mindspore
mindspore/ccsrc/optimizer/irpass/grad_var_prepare.cc
@@ -31,7 +31,6 @@
 namespace mindspore {
 namespace opt {
 namespace irpass {
 static AnfNodePtr GenerateUnpackGraphNode(std::vector<AnfNodePtr> inputs_y, FuncGraphPtr func_graph,
                                           AnfNodePtr func_node, bool is_unpack, bool sens_param) {
   MS_EXCEPTION_IF_NULL(func_graph);
mindspore/ccsrc/optimizer/irpass/grad_var_prepare.h
@@ -33,7 +33,6 @@
 namespace mindspore {
 namespace opt {
 namespace irpass {
 // {{GradOperation, g, w}, Ys}
 // {UnPackCall, {GradOperation, g, w}, Ys}
 class GradVarPrepare : public AnfVisitor {
mindspore/ccsrc/pipeline/base.h
@@ -28,13 +28,11 @@
 namespace mindspore {
 namespace pipeline {
 struct ExecutorInfo {
   FuncGraphPtr func_graph;
   ResourcePtr resource;
   std::size_t arg_list_size;
 };
 using ExecutorInfoPtr = std::shared_ptr<ExecutorInfo>;
 inline std::string GetPhasePrefix(const std::string &phase) {
mindspore/ccsrc/pipeline/pipeline.cc
@@ -101,7 +101,7 @@ py::tuple GenerateKey(const std::string &name, const std::unordered_map<std::str
     MS_LOG(INFO) << "Start new args and compile key:" << key;
     g_args_cache[args_spec] = key++;
   }
-  py::tuple argSpec = py::tuple(2);
+  auto argSpec = py::tuple(2);
   argSpec[0] = name;
   argSpec[1] = g_args_cache[args_spec];
   return argSpec;
mindspore/ccsrc/pipeline/pipeline_ge.cc
@@ -52,11 +52,11 @@ void DoExecNonInputGraph(const std::string &phase) {
   transform::RunOptions run_options;
   run_options.name = phase;
   auto graph_runner = DfGraphManager::GetInstance().GetGraphRunner();
   if (graph_runner == nullptr) {
     MS_LOG(ERROR) << "Can not found GraphRunner";
     return;
   }
   {
     // Release GIL before calling into (potentially long-running) C++ code
     py::gil_scoped_release release;

@@ -181,7 +181,6 @@ bool AddDFGraph(const std::map<std::string, ExecutorInfoPtr> &info, const py::di
   size_t pos = phase.find('.');
   std::string net_id = ((pos == std::string::npos || pos == phase.size() - 1) ? phase : phase.substr(pos + 1));
   std::string phase_prefix = phase.substr(0, pos);
   if (phase_prefix == "export") {
     MS_LOG(INFO) << "Set DfGraphConvertor training : false";
     convertor.set_training(false);

@@ -348,7 +347,7 @@ py::object ExtractGeneralCnodeRet(const AbstractBasePtr &cnode_data, const py::t
   auto data_tp = cnode_data->cast<AbstractTuplePtr>();
   auto elements = data_tp->elements();
   size_t size = data_tp->size();
-  py::tuple tp = py::tuple(size);
+  auto tp = py::tuple(size);
   for (size_t i = 0; i < size; i++) {
     tp[i] = ExtractGeneralCnodeRet(elements[i], data, count);
   }

@@ -379,7 +378,7 @@ py::object StructureOutput(const AnfNodePtr &output_node, const py::tuple &data,
   if (output_c->IsApply(prim::kPrimMakeTuple)) {
     auto input_list = output_c->inputs();
     size_t size = input_list.size();
-    py::tuple tp = py::tuple(size - 1);
+    auto tp = py::tuple(size - 1);
     for (size_t i = 1; i < size; i++) {
       tp[i - 1] = StructureOutput(input_list[i], data, count);
     }

@@ -401,11 +400,8 @@ std::shared_ptr<py::object> DoExecGraph(const FuncGraphPtr &graph, const std::ve
   std::vector<GeTensorPtr> ge_outputs;
   transform::RunOptions run_options;
   run_options.name = phase;
   auto graph_runner = DfGraphManager::GetInstance().GetGraphRunner();
   if (graph_runner == nullptr) {
     MS_LOG(EXCEPTION) << "Can not found GraphRunner.";
   }

@@ -478,7 +474,6 @@ void ProcessGeArg(const std::map<std::string, ExecutorInfoPtr> &info, const py::
 py::object ExecDFGraph(const std::map<std::string, ExecutorInfoPtr> &info, const py::tuple &args,
                        const std::string &phase) {
   std::string phase_prefix = GetPhasePrefix(phase);
   if (phase_prefix == "save") {
     DoExecNonInputGraph(phase);
     ConfigManager::GetInstance().ResetConfig();

@@ -488,7 +483,6 @@ py::object ExecDFGraph(const std::map<std::string, ExecutorInfoPtr> &info, const
   if (info.count(phase) == 0) {
     MS_LOG(EXCEPTION) << "There is no phase:" << phase;
   }
   FuncGraphPtr anf_graph = info.at(phase)->func_graph;
 #ifdef ENABLE_INFER
mindspore/ccsrc/pipeline/pipeline_ge.h
@@ -31,7 +31,6 @@
 namespace mindspore {
 namespace pipeline {
 namespace py = pybind11;
 void SetGeOption(const std::map<std::string, std::string> &options);

@@ -50,7 +49,6 @@ bool InitExecDatasetGe(const std::string &queue_name, int64_t size, int64_t batc
                        const std::vector<int64_t> &input_indexes,
                        const std::string &phase);
 void ExportDFGraph(const std::string &file_name, const std::string &phase);
 }  // namespace pipeline
 }  // namespace mindspore
mindspore/ccsrc/pipeline/static_analysis/abstract_function.h
@@ -41,7 +41,7 @@ class AbstractFuncAtom : public AbstractFunction {
   AbstractFunctionPtr Join(const AbstractFunctionPtr &other) final;
   void Visit(std::function<void(const AbstractFuncAtomPtr &)>) const final;
-  bool operator==(const AbstractFunction &other) const;
+  bool operator==(const AbstractFunction &other) const override;
   std::size_t hash() const override { return tid(); }
 };

@@ -270,7 +270,7 @@ class TypedPrimitiveAbstractClosure : public AbstractFuncAtom {
 class DummyAbstractClosure : public AbstractFuncAtom {
  public:
   DummyAbstractClosure() = default;
-  ~DummyAbstractClosure() = default;
+  ~DummyAbstractClosure() override = default;
   MS_DECLARE_PARENT(DummyAbstractClosure, AbstractFuncAtom)
   EvaluatorPtr GetEvaluator(AnalysisEnginePtr) override {
     MS_LOG(EXCEPTION) << "A dummy function cannot eval.";
   }
mindspore/ccsrc/pipeline/static_analysis/prim.cc
@@ -295,7 +295,6 @@ py::dict ConvertAbstractToPython(const AbstractBasePtr &abs_base) {
     dic["shape"] = shape;
     dic["dtype"] = arg_slice->BuildType();
     dic["value"] = BuildValue(arg_slice->BuildValue());
   } else if (abs_base->isa<AbstractTuple>()) {
     auto arg_tuple = dyn_cast<AbstractTuple>(abs_base);
     size_t len = arg_tuple->size();
mindspore/ops/operations/array_ops.py
@@ -639,9 +639,9 @@ class TruncatedNormal(PrimitiveWithInfer):
         Tensor, type of output tensor is same as attribute `dtype`.

     Examples:
-        >>> input_shape = Tensor(np.array([1, 2, 3]))
+        >>> shape = (1, 2, 3)
         >>> truncated_normal = P.TruncatedNormal()
-        >>> output = truncated_normal(input_shape)
+        >>> output = truncated_normal(shape)
     """

     @prim_attr_register
@@ -652,6 +652,8 @@ class TruncatedNormal(PrimitiveWithInfer):
     def __infer__(self, shape):
         shape_value = shape['value']
+        validator.check_const_input("shape", shape_value)
+        validator.check_type("shape", shape_value, [tuple])
         for i, value in enumerate(shape_value):
             validator.check_integer(f'{i}th value of shape', value, 0, Rel.GT)
         out = {'shape': shape_value,
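Taken together, the two TruncatedNormal hunks above make the documented call pass the shape as a plain Python tuple and add explicit checks in __infer__ that the shape is a constant tuple of positive integers. A minimal usage sketch of the corrected example follows; it is not part of the patch and assumes a MindSpore build of this era, where the primitive operators are imported as `from mindspore.ops import operations as P`:

    # Sketch of the corrected docstring example (assumes MindSpore is installed;
    # the alias P mirrors the convention used in the docstring).
    from mindspore.ops import operations as P

    shape = (1, 2, 3)                       # must be a constant tuple of positive ints
    truncated_normal = P.TruncatedNormal()  # output dtype follows the op's `dtype` attribute
    output = truncated_normal(shape)        # produces a tensor of shape (1, 2, 3)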
@@ -1642,15 +1644,16 @@ class StridedSlice(PrimitiveWithInfer):
         validator.check_type('shrink_axis_mask', shrink_axis_mask, [int])

     def __infer__(self, x, begin, end, strides):
-        x_shape = x['shape']
-        x_shp_len = len(x_shape)
         begin_v, end_v, strides_v = begin['value'], end['value'], strides['value']
         validator.check_const_input("begin", begin_v)
         validator.check_const_input("end", end_v)
         validator.check_const_input("strides", strides_v)
-        validator.check_type("begin", begin['value'], [tuple])
-        validator.check_type("end", end['value'], [tuple])
-        validator.check_type("strides", strides['value'], [tuple])
+        validator.check_type("begin", begin_v, [tuple])
+        validator.check_type("end", end_v, [tuple])
+        validator.check_type("strides", strides_v, [tuple])
+        x_shape = x['shape']
+        x_shp_len = len(x_shape)
         if len(begin_v) != x_shp_len or len(end_v) != x_shp_len or len(strides_v) != x_shp_len:
             raise ValueError(f"The length of begin index {begin_v}, end index {end_v} and strides {strides_v} "
                              f"must be equal to the dims({x_shp_len}) of input.")
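The StridedSlice hunk above reorders the input validation: begin, end and strides are first checked as constant tuples (using the already-extracted begin_v, end_v and strides_v), and only then is the input rank read and compared against their lengths. A standalone sketch of that consistency rule, with a hypothetical helper name (it mirrors the logic only and is not the MindSpore validator API):

    # Hypothetical, self-contained mirror of the reordered StridedSlice checks;
    # the function name and error text are illustrative only.
    def check_strided_slice_args(x_shape, begin, end, strides):
        # type checks first: each index argument must be a constant tuple
        for name, value in (("begin", begin), ("end", end), ("strides", strides)):
            if not isinstance(value, tuple):
                raise TypeError(f"{name} must be a constant tuple, got {type(value).__name__}")
        # then compare the tuple lengths against the rank of the input
        x_shp_len = len(x_shape)
        if len(begin) != x_shp_len or len(end) != x_shp_len or len(strides) != x_shp_len:
            raise ValueError(f"The length of begin {begin}, end {end} and strides {strides} "
                             f"must be equal to the dims({x_shp_len}) of input.")

    # Example: a rank-3 input requires rank-3 begin/end/strides tuples.
    check_strided_slice_args((3, 4, 5), begin=(0, 0, 0), end=(3, 4, 5), strides=(1, 1, 1))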
tests/ut/python/ops/test_ops.py
@@ -372,7 +372,7 @@ test_case_math_ops = [
         'desc_bprop': [[3]]}),
     ('TruncatedNormal', {
         'block': P.TruncatedNormal(),
-        'desc_const': [[1, 2, 3]],
+        'desc_const': [(1, 2, 3)],
         'desc_inputs': [],
         'skip': ['backward'],
         'add_fake_input': True}),