Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
magicwindyyd
mindspore
提交
9682d08d
M
mindspore
项目概览
magicwindyyd
/
mindspore
与 Fork 源项目一致
Fork自
MindSpore / mindspore
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
mindspore
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
9682d08d
编写于
7月 09, 2020
作者:
W
WilliamLian
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
refactor primitive hook function
上级
0a3bf64b
变更
28
隐藏空白更改
内联
并排
Showing
28 changed file
with
429 addition
and
421 deletion
+429
-421
mindspore/ccsrc/ir/anf.cc
mindspore/ccsrc/ir/anf.cc
+1
-1
mindspore/ccsrc/ir/primitive.cc
mindspore/ccsrc/ir/primitive.cc
+37
-88
mindspore/ccsrc/ir/primitive.h
mindspore/ccsrc/ir/primitive.h
+112
-28
mindspore/ccsrc/ir/primitive_base.cc
mindspore/ccsrc/ir/primitive_base.cc
+0
-71
mindspore/ccsrc/ir/primitive_base.h
mindspore/ccsrc/ir/primitive_base.h
+0
-150
mindspore/ccsrc/ir/primitive_extends.cc
mindspore/ccsrc/ir/primitive_extends.cc
+1
-1
mindspore/ccsrc/ir/primitive_py.cc
mindspore/ccsrc/ir/primitive_py.cc
+195
-0
mindspore/ccsrc/ir/primitive_py.h
mindspore/ccsrc/ir/primitive_py.h
+72
-0
mindspore/ccsrc/kernel/cpu/addn_cpu_kernel.cc
mindspore/ccsrc/kernel/cpu/addn_cpu_kernel.cc
+0
-1
mindspore/ccsrc/kernel/cpu/allgather_cpu_kernel.cc
mindspore/ccsrc/kernel/cpu/allgather_cpu_kernel.cc
+0
-1
mindspore/ccsrc/kernel/cpu/concat_cpu_kernel.cc
mindspore/ccsrc/kernel/cpu/concat_cpu_kernel.cc
+0
-1
mindspore/ccsrc/kernel/cpu/embedding_look_up_comm_grad_cpu_kernel.cc
...csrc/kernel/cpu/embedding_look_up_comm_grad_cpu_kernel.cc
+0
-1
mindspore/ccsrc/kernel/cpu/gather_cpu_kernel.cc
mindspore/ccsrc/kernel/cpu/gather_cpu_kernel.cc
+0
-1
mindspore/ccsrc/kernel/cpu/slice_cpu_kernel.cc
mindspore/ccsrc/kernel/cpu/slice_cpu_kernel.cc
+0
-1
mindspore/ccsrc/operator/ops.h
mindspore/ccsrc/operator/ops.h
+1
-1
mindspore/ccsrc/optimizer/ad/kprim.cc
mindspore/ccsrc/optimizer/ad/kprim.cc
+2
-5
mindspore/ccsrc/optimizer/py_pass_manager.h
mindspore/ccsrc/optimizer/py_pass_manager.h
+1
-1
mindspore/ccsrc/pipeline/static_analysis/static_analysis.h
mindspore/ccsrc/pipeline/static_analysis/static_analysis.h
+1
-1
mindspore/ccsrc/pipeline/static_analysis/utils.h
mindspore/ccsrc/pipeline/static_analysis/utils.h
+0
-1
mindspore/ccsrc/pre_activate/ascend/format_type/insert_cast.cc
...pore/ccsrc/pre_activate/ascend/format_type/insert_cast.cc
+0
-9
mindspore/ccsrc/pre_activate/ascend/ir_fusion/adam_apply_one_fusion.cc
...rc/pre_activate/ascend/ir_fusion/adam_apply_one_fusion.cc
+0
-1
mindspore/ccsrc/pynative/base.h
mindspore/ccsrc/pynative/base.h
+1
-1
mindspore/ccsrc/transform/op_adapter_base.h
mindspore/ccsrc/transform/op_adapter_base.h
+0
-1
mindspore/ccsrc/utils/graph_utils.h
mindspore/ccsrc/utils/graph_utils.h
+1
-1
mindspore/ccsrc/vm/vm.cc
mindspore/ccsrc/vm/vm.cc
+2
-51
mindspore/ccsrc/vm/vm.h
mindspore/ccsrc/vm/vm.h
+0
-1
mindspore/ccsrc/vm/vmimpl.cc
mindspore/ccsrc/vm/vmimpl.cc
+1
-1
tests/ut/cpp/operator/ops_test.cc
tests/ut/cpp/operator/ops_test.cc
+1
-1
未找到文件。
mindspore/ccsrc/ir/anf.cc
浏览文件 @
9682d08d
...
...
@@ -24,7 +24,7 @@
#include <unordered_map>
#include "ir/func_graph.h"
#include "ir/primitive
_base
.h"
#include "ir/primitive.h"
#include "utils/context/ms_context.h"
#include "operator/ops.h"
...
...
mindspore/ccsrc/ir/primitive.cc
浏览文件 @
9682d08d
...
...
@@ -15,108 +15,57 @@
*/
#include "ir/primitive.h"
#include <mutex>
#include <utility>
#include "ir/signature.h"
#include "operator/ops.h"
#include "./common.h"
#include "pipeline/parse/python_adapter.h"
#include "pipeline/parse/data_converter.h"
#include "pybind11/pytypes.h"
#include "utils/convert_utils_base.h"
#include "utils/primitive_utils.h"
#include "pybind_api/api_register.h"
#include "pybind_api/export_flags.h"
#include <utility>
namespace
mindspore
{
static
ValuePtr
PyArgToValue
(
const
py
::
object
&
arg
)
{
if
(
py
::
isinstance
<
SignatureEnumKind
>
(
arg
)
&&
py
::
cast
<
SignatureEnumKind
>
(
arg
)
==
SignatureEnumKind
::
kKindEmptyDefaultValue
)
{
return
nullptr
;
}
return
parse
::
data_converter
::
PyDataToValue
(
arg
);
}
void
PrimitivePy
::
set_signatures
(
std
::
vector
<
std
::
tuple
<
std
::
string
,
SignatureEnumRW
,
SignatureEnumKind
,
py
::
object
,
SignatureEnumDType
>>
signatures
)
{
signatures_
.
clear
();
for
(
auto
&
signature
:
signatures
)
{
auto
[
name
,
rw
,
kind
,
arg_default
,
dtype
]
=
signature
;
auto
default_value
=
PyArgToValue
(
arg_default
);
signatures_
.
emplace_back
(
name
,
rw
,
kind
,
default_value
,
dtype
);
}
set_has_signature
(
true
);
}
py
::
function
PrimitivePy
::
GetBpropFunction
()
{
static
const
char
*
const
get_bprop_func_name
=
"get_bprop"
;
if
(
py
::
hasattr
(
python_obj_
,
get_bprop_func_name
))
{
py
::
function
fn
=
python_obj_
.
attr
(
get_bprop_func_name
)().
cast
<
py
::
function
>
();
return
fn
;
bool
Primitive
::
operator
==
(
const
Value
&
other
)
const
{
if
(
other
.
isa
<
Primitive
>
())
{
auto
other_prim
=
static_cast
<
const
Primitive
&>
(
other
);
return
*
this
==
other_prim
;
}
else
{
auto
fn
=
GetBpropFunctionByObj
(
python_obj_
);
return
fn
;
return
false
;
}
}
py
::
function
PrimitivePy
::
GetComputeFunction
()
{
static
const
char
*
const
compute_func_name
=
"vm_impl"
;
if
(
py
::
hasattr
(
python_obj_
,
compute_func_name
))
{
MS_LOG
(
INFO
)
<<
name
()
<<
" compute_func_name"
;
py
::
function
fn
=
python_obj_
.
attr
(
compute_func_name
).
cast
<
py
::
function
>
();
return
fn
;
bool
Primitive
::
operator
==
(
const
Primitive
&
other
)
const
{
if
(
name
()
!=
other
.
name
())
{
return
false
;
}
static
const
std
::
string
vm_module
=
"mindspore.ops.vm_impl_registry"
;
static
const
std
::
string
get_vm_impl_fn
=
"get_vm_impl_fn"
;
MS_LOG
(
INFO
)
<<
name
()
<<
": get_vm_impl_fn"
;
py
::
function
get_fn
=
parse
::
python_adapter
::
GetPyFn
(
vm_module
,
get_vm_impl_fn
);
py
::
function
vm_fn
=
get_fn
(
python_obj_
);
if
(
py
::
isinstance
<
py
::
none
>
(
vm_fn
))
{
MS_LOG
(
WARNING
)
<<
"Cannot find "
<<
python_obj_
.
attr
(
"__class__"
).
attr
(
"__name__"
).
cast
<
std
::
string
>
();
vm_fn
=
mindspore
::
GetComputeFunction
(
Primitive
::
name
());
if
(
attrs_
.
size
()
!=
other
.
attrs_
.
size
())
{
return
false
;
}
return
vm_fn
;
auto
all
=
std
::
all_of
(
attrs_
.
begin
(),
attrs_
.
end
(),
[
&
other
](
const
std
::
pair
<
std
::
string
,
ValuePtr
>
&
item
)
->
bool
{
if
(
item
.
second
==
nullptr
)
{
return
false
;
}
auto
iter
=
other
.
attrs_
.
find
(
item
.
first
);
if
(
iter
==
other
.
attrs_
.
end
())
{
return
false
;
}
return
*
item
.
second
==
*
iter
->
second
;
});
return
all
;
}
void
PrimitivePy
::
AddPyAttr
(
const
py
::
str
&
name
,
const
py
::
object
&
obj
)
{
std
::
string
attr_name
=
name
;
ValuePtr
converted_ret
=
nullptr
;
if
(
py
::
isinstance
<
py
::
module
>
(
obj
))
{
MS_LOG
(
EXCEPTION
)
<<
"AddPyAttr failed, obj should not be py::module"
;
}
bool
converted
=
parse
::
ConvertData
(
obj
,
&
converted_ret
);
if
(
!
converted
)
{
MS_LOG
(
EXCEPTION
)
<<
"Attribute convert error with type: "
<<
std
::
string
(
py
::
str
(
obj
));
std
::
string
Primitive
::
GetAttrsText
()
const
{
if
(
attrs_
.
empty
())
{
return
""
;
}
(
void
)
this
->
AddAttr
(
attr_name
,
converted_ret
);
}
py
::
dict
PrimitivePy
::
GetAttrDict
()
{
py
::
dict
attr_dict
;
std
::
ostringstream
oss
;
oss
<<
"["
;
bool
is_first
=
true
;
for
(
auto
&
attr
:
attrs_
)
{
attr_dict
[
py
::
str
(
attr
.
first
)]
=
ValuePtrToPyData
(
attr
.
second
);
if
(
is_first
)
{
is_first
=
false
;
}
else
{
oss
<<
", "
;
}
oss
<<
attr
.
first
<<
"="
<<
attr
.
second
->
DumpText
();
}
return
attr_dict
;
}
oss
<<
"]"
;
REGISTER_PYBIND_DEFINE
(
Primitive_
,
([](
const
py
::
module
*
m
)
{
(
void
)
py
::
enum_
<
PrimType
>
(
*
m
,
"prim_type"
,
py
::
arithmetic
())
.
value
(
"unknown"
,
PrimType
::
kPrimTypeUnknown
)
.
value
(
"builtin"
,
PrimType
::
kPrimTypeBuiltIn
)
.
value
(
"py_infer_shape"
,
PrimType
::
kPrimTypePyInferShape
)
.
value
(
"user_custom"
,
PrimType
::
kPrimTypeUserCustom
);
(
void
)
py
::
class_
<
PrimitivePy
,
std
::
shared_ptr
<
PrimitivePy
>>
(
*
m
,
"Primitive_"
)
.
def_readonly
(
PYTHON_PRIMITIVE_FLAG
,
&
PrimitivePy
::
parse_info_
)
.
def
(
py
::
init
<
py
::
str
&
,
py
::
object
>
())
.
def
(
"add_attr"
,
&
PrimitivePy
::
AddPyAttr
,
"add primitive attr"
)
.
def
(
"get_attr_dict"
,
&
PrimitivePy
::
GetAttrDict
,
"get primitive attr"
)
.
def
(
"set_prim_type"
,
&
PrimitivePy
::
set_prim_type
,
"Set primitive type."
)
.
def
(
"set_signatures"
,
&
PrimitivePy
::
set_signatures
,
"Set primitive inputs signature."
)
.
def
(
"register_hook"
,
&
PrimitivePy
::
set_hook
,
"Set primitive hook function."
)
.
def
(
"set_instance_name"
,
&
PrimitivePy
::
set_instance_name
,
"Set primitive instance name."
);
}));
return
oss
.
str
();
}
}
// namespace mindspore
mindspore/ccsrc/ir/primitive.h
浏览文件 @
9682d08d
/**
* Copyright 2019 Huawei Technologies Co., Ltd
* Copyright 2019
-2020
Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
...
...
@@ -23,45 +23,129 @@
#include <string>
#include <tuple>
#include "ir/dtype/type.h"
#include "pipeline/static_analysis/abstract_value.h"
#include "utils/misc.h"
#include "utils/log_adapter.h"
#include "ir/primitive_base.h"
#include "ir/signature.h"
#include "parallel/ops_info/operator_info.h"
#include "utils/base_ref_extends.h"
namespace
mindspore
{
class
PrimitivePy
:
public
Primitive
{
// Supported meta type
enum
PrimType
{
kPrimTypeUnknown
=
0
,
kPrimTypeBegin
=
kTypeUnknown
,
kPrimTypeBuiltIn
,
// Built-in primitive operator
kPrimTypePyInferShape
,
// Primitive operator defined by custom
kPrimTypePyInferTensor
,
// Primitive operator defined by custom
kPrimTypeUserCustom
};
class
Primitive
:
public
Named
{
public:
PrimitivePy
(
const
py
::
str
&
name
,
const
py
::
object
&
python_obj
)
:
Primitive
(
name
,
false
),
python_obj_
(
python_obj
),
signatures_
()
{}
~
PrimitivePy
()
override
=
default
;
MS_DECLARE_PARENT
(
PrimitivePy
,
Primitive
);
py
::
function
GetBpropFunction
();
py
::
function
GetComputeFunction
();
explicit
Primitive
(
const
std
::
string
&
name
,
const
bool
is_base
=
true
,
const
PrimType
prim_type
=
kPrimTypeBuiltIn
)
:
Named
(
name
),
is_base_
(
is_base
),
has_signature_
(
false
),
prim_type_
(
prim_type
),
record_evaluate_add_attr_
(
false
)
{}
Primitive
(
const
Primitive
&
prim
)
:
Named
(
prim
),
attrs_
(
prim
.
attrs_
),
instance_name_
(
prim
.
instance_name_
),
is_base_
(
prim
.
is_base_
),
has_signature_
(
prim
.
has_signature_
),
prim_type_
(
prim
.
prim_type_
),
record_evaluate_add_attr_
(
false
)
{}
MS_DECLARE_PARENT
(
Primitive
,
Named
);
abstract
::
AbstractBasePtr
ToPrimAbstract
(
const
AnfNodePtr
&
anf_node
);
std
::
string
ToString
()
const
override
{
return
name
();
}
void
BeginRecordAddAttr
()
{
evaluate_added_attrs_
.
clear
();
record_evaluate_add_attr_
=
true
;
}
void
EndRecordAddAttr
()
{
record_evaluate_add_attr_
=
false
;
}
Primitive
&
AddAttr
(
const
std
::
string
&
name
,
const
ValuePtr
&
attr
)
{
attrs_
[
name
]
=
attr
;
if
(
record_evaluate_add_attr_
)
{
evaluate_added_attrs_
[
name
]
=
attr
;
}
return
*
this
;
}
Primitive
&
SetAttrs
(
const
std
::
unordered_map
<
std
::
string
,
ValuePtr
>
&
attrs
)
{
for
(
auto
&
attr
:
attrs
)
{
attrs_
[
attr
.
first
]
=
attr
.
second
;
}
return
*
this
;
}
void
set_signatures
(
std
::
vector
<
std
::
tuple
<
std
::
string
,
SignatureEnumRW
,
SignatureEnumKind
,
py
::
object
,
SignatureEnumDType
>>
signatures
);
void
set_attr
(
const
std
::
string
&
attrName
,
const
ValuePtr
&
attr
)
{
attrs_
[
attrName
]
=
attr
;
}
void
EraseAttr
(
const
std
::
string
&
attrName
)
{
(
void
)
attrs_
.
erase
(
attrName
);
}
const
std
::
vector
<
Signature
>
&
signatures
()
const
{
return
signatures_
;
}
ValuePtr
GetAttr
(
const
std
::
string
&
attrName
)
const
{
auto
iter
=
attrs_
.
find
(
attrName
);
return
iter
==
attrs_
.
cend
()
?
nullptr
:
iter
->
second
;
}
void
AddPyAttr
(
const
py
::
str
&
name
,
const
py
::
object
&
obj
);
const
std
::
unordered_map
<
std
::
string
,
ValuePtr
>
&
attrs
()
const
{
return
attrs_
;
}
const
std
::
unordered_map
<
std
::
string
,
ValuePtr
>
&
evaluate_added_attrs
()
const
{
return
evaluate_added_attrs_
;
}
py
::
dict
GetAttrDict
();
void
set_hook
(
const
py
::
function
&
hook
)
{
hook_
=
hook
;
}
py
::
function
hook
()
const
{
return
hook_
;
}
// if Primitive has any attribute, for Primitives like scalar_add, return, etc, don't have any attribute.
bool
HasAttr
()
const
{
return
!
attrs_
.
empty
();
}
bool
HasAttr
(
const
std
::
string
&
attrName
)
const
{
auto
iter
=
attrs_
.
find
(
attrName
);
return
!
(
iter
==
attrs_
.
cend
());
}
void
set_prim_type
(
const
PrimType
t
)
{
prim_type_
=
t
;
}
void
set_instance_name
(
const
std
::
string
s
)
{
instance_name_
=
s
;
}
bool
HasPyEvaluator
()
const
{
return
prim_type_
==
kPrimTypePyInferShape
||
prim_type_
==
kPrimTypeUserCustom
;
}
bool
HasPyInferTensor
()
const
{
return
prim_type_
==
kPrimTypePyInferTensor
;
}
bool
IsCustomPrim
()
const
{
return
prim_type_
==
kPrimTypeUserCustom
;
}
const
bool
parse_info_
=
true
;
const
py
::
object
&
GetPyObj
()
const
{
return
python_obj_
;
}
bool
is_tuple_input_
=
false
;
PrimType
prim_type
()
const
{
return
prim_type_
;
}
std
::
string
instance_name
()
const
{
return
instance_name_
;
}
std
::
string
GetAttrsText
()
const
;
bool
operator
==
(
const
Value
&
other
)
const
override
;
bool
operator
==
(
const
Primitive
&
other
)
const
;
~
Primitive
()
override
=
default
;
void
set_has_signature
(
bool
has_signature
)
{
has_signature_
=
has_signature
;
}
bool
has_signature
()
const
{
return
has_signature_
;
}
bool
is_base
()
const
{
return
is_base_
;
}
virtual
BaseRef
RunHookFunction
(
const
VectorRef
&
args
)
const
{
MS_LOG
(
EXCEPTION
)
<<
"call a empty function!"
;
}
virtual
void
CopyHookFunction
(
const
PrimitivePtr
&
primitive
)
{
MS_LOG
(
EXCEPTION
)
<<
"call a empty function!"
;
}
protected:
std
::
unordered_map
<
std
::
string
,
ValuePtr
>
attrs_
;
std
::
unordered_map
<
std
::
string
,
ValuePtr
>
evaluate_added_attrs_
;
private:
py
::
object
python_obj_
;
py
::
function
hook_
;
std
::
vector
<
Signature
>
signatures_
;
std
::
string
instance_name_
;
bool
is_base_
;
bool
has_signature_
;
PrimType
prim_type_
;
bool
record_evaluate_add_attr_
;
};
inline
std
::
ostream
&
operator
<<
(
std
::
ostream
&
os
,
const
PrimitivePtr
&
p
)
{
os
<<
*
p
;
return
os
;
}
struct
PrimitiveEqual
{
bool
operator
()(
PrimitivePtr
const
&
t1
,
PrimitivePtr
const
&
t2
)
const
{
MS_EXCEPTION_IF_NULL
(
t1
);
MS_EXCEPTION_IF_NULL
(
t2
);
return
t1
->
name
()
==
t2
->
name
();
}
};
using
PrimitivePyPtr
=
std
::
shared_ptr
<
PrimitivePy
>
;
struct
PrimitiveHasher
{
std
::
size_t
operator
()(
PrimitivePtr
const
&
prim
)
const
{
MS_EXCEPTION_IF_NULL
(
prim
);
return
prim
->
Hash
();
}
};
}
// namespace mindspore
#endif // MINDSPORE_CCSRC_IR_PRIMITIVE_H_
mindspore/ccsrc/ir/primitive_base.cc
已删除
100644 → 0
浏览文件 @
0a3bf64b
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "ir/primitive_base.h"
#include <utility>
namespace
mindspore
{
bool
Primitive
::
operator
==
(
const
Value
&
other
)
const
{
if
(
other
.
isa
<
Primitive
>
())
{
auto
other_prim
=
static_cast
<
const
Primitive
&>
(
other
);
return
*
this
==
other_prim
;
}
else
{
return
false
;
}
}
bool
Primitive
::
operator
==
(
const
Primitive
&
other
)
const
{
if
(
name
()
!=
other
.
name
())
{
return
false
;
}
if
(
attrs_
.
size
()
!=
other
.
attrs_
.
size
())
{
return
false
;
}
auto
all
=
std
::
all_of
(
attrs_
.
begin
(),
attrs_
.
end
(),
[
&
other
](
const
std
::
pair
<
std
::
string
,
ValuePtr
>
&
item
)
->
bool
{
if
(
item
.
second
==
nullptr
)
{
return
false
;
}
auto
iter
=
other
.
attrs_
.
find
(
item
.
first
);
if
(
iter
==
other
.
attrs_
.
end
())
{
return
false
;
}
return
*
item
.
second
==
*
iter
->
second
;
});
return
all
;
}
std
::
string
Primitive
::
GetAttrsText
()
const
{
if
(
attrs_
.
empty
())
{
return
""
;
}
std
::
ostringstream
oss
;
oss
<<
"["
;
bool
is_first
=
true
;
for
(
auto
&
attr
:
attrs_
)
{
if
(
is_first
)
{
is_first
=
false
;
}
else
{
oss
<<
", "
;
}
oss
<<
attr
.
first
<<
"="
<<
attr
.
second
->
DumpText
();
}
oss
<<
"]"
;
return
oss
.
str
();
}
}
// namespace mindspore
mindspore/ccsrc/ir/primitive_base.h
已删除
100644 → 0
浏览文件 @
0a3bf64b
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_IR_PRIMITIVE_BASE_H_
#define MINDSPORE_CCSRC_IR_PRIMITIVE_BASE_H_
#include <unordered_map>
#include <vector>
#include <memory>
#include <string>
#include <tuple>
#include "ir/dtype/type.h"
#include "pybind11/pybind11.h"
namespace
py
=
pybind11
;
namespace
mindspore
{
// Supported meta type
enum
PrimType
{
kPrimTypeUnknown
=
0
,
kPrimTypeBegin
=
kTypeUnknown
,
kPrimTypeBuiltIn
,
// Built-in primitive operator
kPrimTypePyInferShape
,
// Primitive operator defined by custom
kPrimTypePyInferTensor
,
// Primitive operator defined by custom
kPrimTypeUserCustom
};
class
Primitive
:
public
Named
{
public:
explicit
Primitive
(
const
std
::
string
&
name
,
const
bool
is_base
=
true
,
const
PrimType
prim_type
=
kPrimTypeBuiltIn
)
:
Named
(
name
),
is_base_
(
is_base
),
has_signature_
(
false
),
prim_type_
(
prim_type
),
record_evaluate_add_attr_
(
false
)
{}
Primitive
(
const
Primitive
&
prim
)
:
Named
(
prim
),
attrs_
(
prim
.
attrs_
),
instance_name_
(
prim
.
instance_name_
),
is_base_
(
prim
.
is_base_
),
has_signature_
(
prim
.
has_signature_
),
prim_type_
(
prim
.
prim_type_
),
record_evaluate_add_attr_
(
false
)
{}
MS_DECLARE_PARENT
(
Primitive
,
Named
);
abstract
::
AbstractBasePtr
ToPrimAbstract
(
const
AnfNodePtr
&
anf_node
);
std
::
string
ToString
()
const
override
{
return
name
();
}
void
BeginRecordAddAttr
()
{
evaluate_added_attrs_
.
clear
();
record_evaluate_add_attr_
=
true
;
}
void
EndRecordAddAttr
()
{
record_evaluate_add_attr_
=
false
;
}
Primitive
&
AddAttr
(
const
std
::
string
&
name
,
const
ValuePtr
&
attr
)
{
attrs_
[
name
]
=
attr
;
if
(
record_evaluate_add_attr_
)
{
evaluate_added_attrs_
[
name
]
=
attr
;
}
return
*
this
;
}
Primitive
&
SetAttrs
(
const
std
::
unordered_map
<
std
::
string
,
ValuePtr
>
&
attrs
)
{
for
(
auto
&
attr
:
attrs
)
{
attrs_
[
attr
.
first
]
=
attr
.
second
;
}
return
*
this
;
}
void
set_attr
(
const
std
::
string
&
attrName
,
const
ValuePtr
&
attr
)
{
attrs_
[
attrName
]
=
attr
;
}
void
EraseAttr
(
const
std
::
string
&
attrName
)
{
(
void
)
attrs_
.
erase
(
attrName
);
}
ValuePtr
GetAttr
(
const
std
::
string
&
attrName
)
const
{
auto
iter
=
attrs_
.
find
(
attrName
);
return
iter
==
attrs_
.
cend
()
?
nullptr
:
iter
->
second
;
}
const
std
::
unordered_map
<
std
::
string
,
ValuePtr
>
&
attrs
()
const
{
return
attrs_
;
}
const
std
::
unordered_map
<
std
::
string
,
ValuePtr
>
&
evaluate_added_attrs
()
const
{
return
evaluate_added_attrs_
;
}
// if Primitive has any attribute, for Primitives like scalar_add, return, etc, don't have any attribute.
bool
HasAttr
()
const
{
return
!
attrs_
.
empty
();
}
bool
HasAttr
(
const
std
::
string
&
attrName
)
const
{
auto
iter
=
attrs_
.
find
(
attrName
);
return
!
(
iter
==
attrs_
.
cend
());
}
void
set_prim_type
(
const
PrimType
t
)
{
prim_type_
=
t
;
}
void
set_instance_name
(
const
std
::
string
s
)
{
instance_name_
=
s
;
}
bool
HasPyEvaluator
()
const
{
return
prim_type_
==
kPrimTypePyInferShape
||
prim_type_
==
kPrimTypeUserCustom
;
}
bool
HasPyInferTensor
()
const
{
return
prim_type_
==
kPrimTypePyInferTensor
;
}
bool
IsCustomPrim
()
const
{
return
prim_type_
==
kPrimTypeUserCustom
;
}
PrimType
prim_type
()
const
{
return
prim_type_
;
}
std
::
string
instance_name
()
const
{
return
instance_name_
;
}
std
::
string
GetAttrsText
()
const
;
bool
operator
==
(
const
Value
&
other
)
const
override
;
bool
operator
==
(
const
Primitive
&
other
)
const
;
~
Primitive
()
override
=
default
;
void
set_has_signature
(
bool
has_signature
)
{
has_signature_
=
has_signature
;
}
bool
has_signature
()
const
{
return
has_signature_
;
}
bool
is_base
()
const
{
return
is_base_
;
}
protected:
std
::
unordered_map
<
std
::
string
,
ValuePtr
>
attrs_
;
std
::
unordered_map
<
std
::
string
,
ValuePtr
>
evaluate_added_attrs_
;
private:
std
::
string
instance_name_
;
bool
is_base_
;
bool
has_signature_
;
PrimType
prim_type_
;
bool
record_evaluate_add_attr_
;
};
inline
std
::
ostream
&
operator
<<
(
std
::
ostream
&
os
,
const
PrimitivePtr
&
p
)
{
os
<<
*
p
;
return
os
;
}
struct
PrimitiveEqual
{
bool
operator
()(
PrimitivePtr
const
&
t1
,
PrimitivePtr
const
&
t2
)
const
{
MS_EXCEPTION_IF_NULL
(
t1
);
MS_EXCEPTION_IF_NULL
(
t2
);
return
t1
->
name
()
==
t2
->
name
();
}
};
struct
PrimitiveHasher
{
std
::
size_t
operator
()(
PrimitivePtr
const
&
prim
)
const
{
MS_EXCEPTION_IF_NULL
(
prim
);
return
prim
->
Hash
();
}
};
}
// namespace mindspore
#endif // MINDSPORE_CCSRC_IR_PRIMITIVE_BASE_H_
mindspore/ccsrc/ir/primitive_
base_
extends.cc
→
mindspore/ccsrc/ir/primitive_extends.cc
浏览文件 @
9682d08d
...
...
@@ -14,7 +14,7 @@
* limitations under the License.
*/
#include "ir/primitive
_base
.h"
#include "ir/primitive.h"
#include "pipeline/static_analysis/abstract_function.h"
namespace
mindspore
{
...
...
mindspore/ccsrc/ir/primitive_py.cc
0 → 100644
浏览文件 @
9682d08d
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "ir/primitive_py.h"
#include <mutex>
#include <utility>
#include "ir/signature.h"
#include "operator/ops.h"
#include "./common.h"
#include "pipeline/parse/python_adapter.h"
#include "pipeline/parse/data_converter.h"
#include "pybind11/pytypes.h"
#include "utils/convert_utils_base.h"
#include "utils/primitive_utils.h"
#include "utils/base_ref_py.h"
#include "pybind_api/api_register.h"
#include "pybind_api/export_flags.h"
namespace
mindspore
{
namespace
{
constexpr
auto
kBpropAttrName
=
"bprop"
;
constexpr
auto
kCellHookAttrName
=
"cell_hook"
;
constexpr
auto
kCellIDAttrName
=
"cell_id"
;
void
SyncData
(
const
py
::
object
&
arg
)
{
if
(
py
::
isinstance
<
py
::
tuple
>
(
arg
))
{
py
::
tuple
arg_list
=
py
::
cast
<
py
::
tuple
>
(
arg
);
for
(
size_t
i
=
0
;
i
<
arg_list
.
size
();
i
++
)
{
SyncData
(
arg_list
[
i
]);
}
}
if
(
py
::
isinstance
<
tensor
::
Tensor
>
(
arg
))
{
auto
tensor
=
py
::
cast
<
tensor
::
TensorPtr
>
(
arg
);
(
void
)
tensor
->
data_sync
();
}
}
}
// namespace
std
::
map
<
std
::
string
,
py
::
object
>
PrimitivePy
::
hook_grad_
;
static
ValuePtr
PyArgToValue
(
const
py
::
object
&
arg
)
{
if
(
py
::
isinstance
<
SignatureEnumKind
>
(
arg
)
&&
py
::
cast
<
SignatureEnumKind
>
(
arg
)
==
SignatureEnumKind
::
kKindEmptyDefaultValue
)
{
return
nullptr
;
}
return
parse
::
data_converter
::
PyDataToValue
(
arg
);
}
void
PrimitivePy
::
set_signatures
(
std
::
vector
<
std
::
tuple
<
std
::
string
,
SignatureEnumRW
,
SignatureEnumKind
,
py
::
object
,
SignatureEnumDType
>>
signatures
)
{
signatures_
.
clear
();
for
(
auto
&
signature
:
signatures
)
{
auto
[
name
,
rw
,
kind
,
arg_default
,
dtype
]
=
signature
;
auto
default_value
=
PyArgToValue
(
arg_default
);
signatures_
.
emplace_back
(
name
,
rw
,
kind
,
default_value
,
dtype
);
}
set_has_signature
(
true
);
}
py
::
function
PrimitivePy
::
GetBpropFunction
()
{
static
const
char
*
const
get_bprop_func_name
=
"get_bprop"
;
if
(
py
::
hasattr
(
python_obj_
,
get_bprop_func_name
))
{
py
::
function
fn
=
python_obj_
.
attr
(
get_bprop_func_name
)().
cast
<
py
::
function
>
();
return
fn
;
}
else
{
auto
fn
=
GetBpropFunctionByObj
(
python_obj_
);
return
fn
;
}
}
BaseRef
PrimitivePy
::
RunHookFunction
(
const
VectorRef
&
args
)
const
{
auto
py_args
=
py
::
tuple
(
args
.
size
());
size_t
i
=
0
;
for
(
auto
&
arg
:
args
)
{
py_args
[
i
]
=
BaseRefToPyData
(
arg
);
MS_LOG
(
DEBUG
)
<<
"arg:"
<<
i
<<
":"
;
i
++
;
}
py
::
object
obj
;
bool
is_bprop
=
this
->
HasAttr
(
kBpropAttrName
);
if
(
is_bprop
)
{
SyncData
(
py_args
);
obj
=
hook_
(
*
py_args
);
return
std
::
make_shared
<
PyObjectRef
>
(
obj
);
}
SyncData
(
py_args
[
2
]);
bool
is_cell
=
this
->
HasAttr
(
kCellHookAttrName
);
if
(
is_cell
)
{
auto
cell_id
=
GetValue
<
std
::
string
>
(
this
->
GetAttr
(
kCellIDAttrName
));
auto
iter
=
hook_grad_
.
find
(
cell_id
);
if
(
iter
!=
hook_grad_
.
end
())
{
auto
hook_args
=
py
::
tuple
(
3
);
hook_args
[
0
]
=
cell_id
;
hook_args
[
1
]
=
py
::
make_tuple
(
iter
->
second
);
hook_args
[
2
]
=
py
::
make_tuple
(
py_args
[
2
]);
obj
=
hook_
(
*
hook_args
);
if
(
py
::
isinstance
<
py
::
none
>
(
obj
))
{
obj
=
py_args
[
2
];
}
hook_grad_
.
erase
(
cell_id
);
}
else
{
hook_grad_
[
cell_id
]
=
py_args
[
2
];
obj
=
py_args
[
2
];
}
}
else
{
// Hook operator for execute variable hook function
obj
=
hook_
(
py
::
make_tuple
(
py_args
[
2
]));
if
(
py
::
isinstance
<
py
::
none
>
(
obj
))
{
obj
=
py_args
[
2
];
}
}
obj
=
py
::
make_tuple
(
obj
);
return
std
::
make_shared
<
PyObjectRef
>
(
obj
);
}
py
::
function
PrimitivePy
::
GetComputeFunction
()
{
static
const
char
*
const
compute_func_name
=
"vm_impl"
;
if
(
py
::
hasattr
(
python_obj_
,
compute_func_name
))
{
MS_LOG
(
INFO
)
<<
name
()
<<
" compute_func_name"
;
py
::
function
fn
=
python_obj_
.
attr
(
compute_func_name
).
cast
<
py
::
function
>
();
return
fn
;
}
static
const
std
::
string
vm_module
=
"mindspore.ops.vm_impl_registry"
;
static
const
std
::
string
get_vm_impl_fn
=
"get_vm_impl_fn"
;
MS_LOG
(
INFO
)
<<
name
()
<<
": get_vm_impl_fn"
;
py
::
function
get_fn
=
parse
::
python_adapter
::
GetPyFn
(
vm_module
,
get_vm_impl_fn
);
py
::
function
vm_fn
=
get_fn
(
python_obj_
);
if
(
py
::
isinstance
<
py
::
none
>
(
vm_fn
))
{
MS_LOG
(
WARNING
)
<<
"Cannot find "
<<
python_obj_
.
attr
(
"__class__"
).
attr
(
"__name__"
).
cast
<
std
::
string
>
();
vm_fn
=
mindspore
::
GetComputeFunction
(
Primitive
::
name
());
}
return
vm_fn
;
}
void
PrimitivePy
::
AddPyAttr
(
const
py
::
str
&
name
,
const
py
::
object
&
obj
)
{
std
::
string
attr_name
=
name
;
ValuePtr
converted_ret
=
nullptr
;
if
(
py
::
isinstance
<
py
::
module
>
(
obj
))
{
MS_LOG
(
EXCEPTION
)
<<
"AddPyAttr failed, obj should not be py::module"
;
}
bool
converted
=
parse
::
ConvertData
(
obj
,
&
converted_ret
);
if
(
!
converted
)
{
MS_LOG
(
EXCEPTION
)
<<
"Attribute convert error with type: "
<<
std
::
string
(
py
::
str
(
obj
));
}
(
void
)
this
->
AddAttr
(
attr_name
,
converted_ret
);
}
py
::
dict
PrimitivePy
::
GetAttrDict
()
{
py
::
dict
attr_dict
;
for
(
auto
&
attr
:
attrs_
)
{
attr_dict
[
py
::
str
(
attr
.
first
)]
=
ValuePtrToPyData
(
attr
.
second
);
}
return
attr_dict
;
}
void
PrimitivePy
::
CopyHookFunction
(
const
PrimitivePtr
&
primitive
)
{
MS_EXCEPTION_IF_NULL
(
primitive
);
if
(
!
primitive
->
isa
<
PrimitivePy
>
())
{
MS_LOG
(
EXCEPTION
)
<<
"Cannot copy a primtive which is not python primitive hook function to python primitive!"
;
}
auto
primitive_py
=
primitive
->
cast
<
PrimitivePyPtr
>
();
MS_EXCEPTION_IF_NULL
(
primitive_py
);
this
->
set_hook
(
primitive_py
->
hook
());
}
REGISTER_PYBIND_DEFINE
(
Primitive_
,
([](
const
py
::
module
*
m
)
{
(
void
)
py
::
enum_
<
PrimType
>
(
*
m
,
"prim_type"
,
py
::
arithmetic
())
.
value
(
"unknown"
,
PrimType
::
kPrimTypeUnknown
)
.
value
(
"builtin"
,
PrimType
::
kPrimTypeBuiltIn
)
.
value
(
"py_infer_shape"
,
PrimType
::
kPrimTypePyInferShape
)
.
value
(
"user_custom"
,
PrimType
::
kPrimTypeUserCustom
);
(
void
)
py
::
class_
<
PrimitivePy
,
std
::
shared_ptr
<
PrimitivePy
>>
(
*
m
,
"Primitive_"
)
.
def_readonly
(
PYTHON_PRIMITIVE_FLAG
,
&
PrimitivePy
::
parse_info_
)
.
def
(
py
::
init
<
py
::
str
&
,
py
::
object
>
())
.
def
(
"add_attr"
,
&
PrimitivePy
::
AddPyAttr
,
"add primitive attr"
)
.
def
(
"get_attr_dict"
,
&
PrimitivePy
::
GetAttrDict
,
"get primitive attr"
)
.
def
(
"set_prim_type"
,
&
PrimitivePy
::
set_prim_type
,
"Set primitive type."
)
.
def
(
"set_signatures"
,
&
PrimitivePy
::
set_signatures
,
"Set primitive inputs signature."
)
.
def
(
"register_hook"
,
&
PrimitivePy
::
set_hook
,
"Set primitive hook function."
)
.
def
(
"set_instance_name"
,
&
PrimitivePy
::
set_instance_name
,
"Set primitive instance name."
);
}));
}
// namespace mindspore
mindspore/ccsrc/ir/primitive_py.h
0 → 100644
浏览文件 @
9682d08d
/**
* Copyright 2019 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_IR_PRIMITIVE_PY_H_
#define MINDSPORE_CCSRC_IR_PRIMITIVE_PY_H_
#include <unordered_map>
#include <vector>
#include <memory>
#include <string>
#include <tuple>
#include <map>
#include "pipeline/static_analysis/abstract_value.h"
#include "utils/misc.h"
#include "pybind11/pybind11.h"
#include "utils/log_adapter.h"
#include "ir/primitive.h"
#include "ir/signature.h"
#include "parallel/ops_info/operator_info.h"
namespace
py
=
pybind11
;
namespace
mindspore
{
class
PrimitivePy
:
public
Primitive
{
public:
PrimitivePy
(
const
py
::
str
&
name
,
const
py
::
object
&
python_obj
)
:
Primitive
(
name
,
false
),
python_obj_
(
python_obj
),
signatures_
()
{}
~
PrimitivePy
()
override
=
default
;
MS_DECLARE_PARENT
(
PrimitivePy
,
Primitive
);
py
::
function
GetBpropFunction
();
py
::
function
GetComputeFunction
();
void
set_signatures
(
std
::
vector
<
std
::
tuple
<
std
::
string
,
SignatureEnumRW
,
SignatureEnumKind
,
py
::
object
,
SignatureEnumDType
>>
signatures
);
const
std
::
vector
<
Signature
>
&
signatures
()
const
{
return
signatures_
;
}
void
CopyHookFunction
(
const
PrimitivePtr
&
primitive
)
override
;
void
AddPyAttr
(
const
py
::
str
&
name
,
const
py
::
object
&
obj
);
py
::
dict
GetAttrDict
();
void
set_hook
(
const
py
::
function
&
hook
)
{
hook_
=
hook
;
}
py
::
function
hook
()
const
{
return
hook_
;
}
BaseRef
RunHookFunction
(
const
VectorRef
&
args
)
const
override
;
const
bool
parse_info_
=
true
;
const
py
::
object
&
GetPyObj
()
const
{
return
python_obj_
;
}
bool
is_tuple_input_
=
false
;
private:
py
::
object
python_obj_
;
py
::
function
hook_
;
std
::
vector
<
Signature
>
signatures_
;
static
std
::
map
<
std
::
string
,
py
::
object
>
hook_grad_
;
};
using
PrimitivePyPtr
=
std
::
shared_ptr
<
PrimitivePy
>
;
}
// namespace mindspore
#endif // MINDSPORE_CCSRC_IR_PRIMITIVE_PY_H_
mindspore/ccsrc/kernel/cpu/addn_cpu_kernel.cc
浏览文件 @
9682d08d
...
...
@@ -16,7 +16,6 @@
#include "kernel/cpu/addn_cpu_kernel.h"
#include "device/cpu/cpu_device_address.h"
#include "ir/primitive.h"
namespace
mindspore
{
namespace
kernel
{
...
...
mindspore/ccsrc/kernel/cpu/allgather_cpu_kernel.cc
浏览文件 @
9682d08d
...
...
@@ -16,7 +16,6 @@
#include "kernel/cpu/allgather_cpu_kernel.h"
#include "device/cpu/cpu_device_address.h"
#include "device/cpu/mpi/mpi_adapter.h"
#include "ir/primitive.h"
#include "utils/log_adapter.h"
namespace
mindspore
{
...
...
mindspore/ccsrc/kernel/cpu/concat_cpu_kernel.cc
浏览文件 @
9682d08d
...
...
@@ -16,7 +16,6 @@
#include "kernel/cpu/concat_cpu_kernel.h"
#include "device/cpu/cpu_device_address.h"
#include "ir/primitive.h"
namespace
mindspore
{
namespace
kernel
{
...
...
mindspore/ccsrc/kernel/cpu/embedding_look_up_comm_grad_cpu_kernel.cc
浏览文件 @
9682d08d
...
...
@@ -17,7 +17,6 @@
#include "kernel/cpu/embedding_look_up_comm_grad_cpu_kernel.h"
#include "device/cpu/cpu_device_address.h"
#include "device/cpu/mpi/mpi_adapter.h"
#include "ir/primitive.h"
namespace
mindspore
{
namespace
kernel
{
...
...
mindspore/ccsrc/kernel/cpu/gather_cpu_kernel.cc
浏览文件 @
9682d08d
...
...
@@ -15,7 +15,6 @@
*/
#include "kernel/cpu/gather_cpu_kernel.h"
#include "device/cpu/cpu_device_address.h"
#include "ir/primitive.h"
namespace
mindspore
{
namespace
kernel
{
...
...
mindspore/ccsrc/kernel/cpu/slice_cpu_kernel.cc
浏览文件 @
9682d08d
...
...
@@ -15,7 +15,6 @@
*/
#include "kernel/cpu/slice_cpu_kernel.h"
#include "device/cpu/cpu_device_address.h"
#include "ir/primitive.h"
namespace
mindspore
{
namespace
kernel
{
...
...
mindspore/ccsrc/operator/ops.h
浏览文件 @
9682d08d
...
...
@@ -21,7 +21,7 @@
#include <string>
#include <memory>
#include "ir/anf.h"
#include "ir/primitive
_base
.h"
#include "ir/primitive.h"
namespace
mindspore
{
// namespace to support primitive operators
...
...
mindspore/ccsrc/optimizer/ad/kprim.cc
浏览文件 @
9682d08d
...
...
@@ -20,7 +20,7 @@
#include <string>
#include <utility>
#include "ir/anf.h"
#include "ir/primitive.h"
#include "ir/primitive
_py
.h"
#include "ir/meta_func_graph.h"
#include "ir/func_graph_cloner.h"
#include "ir/manager.h"
...
...
@@ -232,10 +232,7 @@ FuncGraphPtr KPrim::BpropCut(const ValueNodePtr &value_node, const pipeline::Res
std
::
vector
<
AnfNodePtr
>
outputs
;
auto
bprop_cut
=
std
::
make_shared
<
PrimitivePy
>
(
"bprop_cut"
,
py
::
object
());
if
(
!
prim
->
is_base
())
{
PrimitivePyPtr
prim_py
=
dyn_cast
<
PrimitivePy
>
(
prim
);
bprop_cut
->
set_hook
(
prim_py
->
hook
());
}
bprop_cut
->
CopyHookFunction
(
prim
);
auto
cell_id
=
GetValue
<
std
::
string
>
(
prim
->
GetAttr
(
"cell_id"
));
if
(
cell_id
!=
""
)
{
...
...
mindspore/ccsrc/optimizer/py_pass_manager.h
浏览文件 @
9682d08d
...
...
@@ -23,7 +23,7 @@
#include "ir/anf.h"
#include "ir/func_graph.h"
#include "ir/primitive.h"
#include "ir/primitive
_py
.h"
#include "utils/graph_utils.h"
#include "common/utils.h"
...
...
mindspore/ccsrc/pipeline/static_analysis/static_analysis.h
浏览文件 @
9682d08d
...
...
@@ -33,7 +33,7 @@
#include "utils/log_adapter.h"
#include "ir/anf.h"
#include "ir/primitive.h"
#include "ir/primitive
_py
.h"
#include "pipeline/static_analysis/analysis_context.h"
#include "pipeline/static_analysis/abstract_function.h"
#include "pipeline/parse/parse.h"
...
...
mindspore/ccsrc/pipeline/static_analysis/utils.h
浏览文件 @
9682d08d
...
...
@@ -27,7 +27,6 @@
#include "utils/any.h"
#include "utils/misc.h"
#include "utils/convert_utils.h"
#include "ir/primitive.h"
namespace
mindspore
{
namespace
abstract
{
...
...
mindspore/ccsrc/pre_activate/ascend/format_type/insert_cast.cc
浏览文件 @
9682d08d
...
...
@@ -181,15 +181,6 @@ const AnfNodePtr InsertCast::Process(const FuncGraphPtr &func_graph, const AnfNo
if
(
AnfAlgo
::
IsGraphKernel
(
node
))
{
return
ProcessGraphKernelOp
(
func_graph
,
node
);
}
else
{
// insert cast for single op.
AnfAlgo
::
SetNodeAttr
(
kAttrVisited
,
MakeValue
(
true
),
node
);
// process input
CNodePtr
cnode
=
node
->
cast
<
CNodePtr
>
();
MS_EXCEPTION_IF_NULL
(
cnode
);
auto
new_node
=
InsertCastForInput
(
func_graph
,
cnode
);
// process output
return
InsertCastForOutput
(
func_graph
,
new_node
,
std
::
vector
<
bool
>
(
AnfAlgo
::
GetOutputTensorNum
(
new_node
),
true
));
}
// insert cast for single op.
AnfAlgo
::
SetNodeAttr
(
kAttrVisited
,
MakeValue
(
true
),
node
);
...
...
mindspore/ccsrc/pre_activate/ascend/ir_fusion/adam_apply_one_fusion.cc
浏览文件 @
9682d08d
...
...
@@ -15,7 +15,6 @@
*/
#include "pre_activate/ascend/ir_fusion/adam_apply_one_fusion.h"
#include "pre_activate/common/helper.h"
namespace
mindspore
{
namespace
opt
{
AnfNodePtr
AdamApplyOneFusion
::
CreateAdamApplyOneNode
(
const
FuncGraphPtr
&
func_graph
,
const
EquivPtr
&
equiv
)
const
{
...
...
mindspore/ccsrc/pynative/base.h
浏览文件 @
9682d08d
...
...
@@ -26,7 +26,7 @@
#include <unordered_set>
#include "pybind11/pybind11.h"
#include "ir/primitive.h"
#include "ir/primitive
_py
.h"
#include "pipeline/static_analysis/abstract_value.h"
namespace
mindspore
{
...
...
mindspore/ccsrc/transform/op_adapter_base.h
浏览文件 @
9682d08d
...
...
@@ -29,7 +29,6 @@
#include "ir/primitive.h"
#include "ir/value.h"
#include "transform/types.h"
#ifdef ENABLE_GE
#ifdef OPEN_SOURCE
#include "graph/types.h"
...
...
mindspore/ccsrc/utils/graph_utils.h
浏览文件 @
9682d08d
...
...
@@ -29,7 +29,7 @@
#include <string>
#include "ir/anf.h"
#include "ir/primitive
_base
.h"
#include "ir/primitive.h"
#include "ir/scalar.h"
#include "ir/tensor.h"
#include "debug/label.h"
...
...
mindspore/ccsrc/vm/vm.cc
浏览文件 @
9682d08d
...
...
@@ -648,57 +648,8 @@ void FinalVM::SyncData(const py::object &arg) {
BaseRef
FinalVM
::
RunHook
(
const
PrimitivePtr
&
prim
,
const
VectorRef
&
args
)
{
MS_LOG
(
DEBUG
)
<<
"input for operation:"
;
auto
prim_py
=
dyn_cast
<
PrimitivePy
>
(
prim
);
std
::
size_t
args_size
=
args
.
size
();
auto
py_args
=
py
::
tuple
(
args_size
);
size_t
i
=
0
;
for
(
auto
&
arg
:
args
)
{
py_args
[
i
]
=
BaseRefToPyData
(
arg
);
MS_LOG
(
DEBUG
)
<<
"arg: "
<<
i
<<
":"
;
i
++
;
}
// Hook operator for execute cell custom bprop function
py
::
object
obj
;
bool
is_bprop
=
prim
->
HasAttr
(
"bprop"
);
if
(
is_bprop
)
{
SyncData
(
py_args
);
py
::
function
fn_bprop
=
prim_py
->
hook
();
obj
=
fn_bprop
(
*
py_args
);
return
obj
;
}
// Sync gradient data from device to host
SyncData
(
py_args
[
2
]);
bool
is_cell
=
prim
->
HasAttr
(
"cell_hook"
);
if
(
is_cell
)
{
// Hook operator for execute cell hook function
std
::
string
cell_id
=
GetValue
<
std
::
string
>
(
prim
->
GetAttr
(
"cell_id"
));
if
(
_hook_grad
.
find
(
cell_id
)
!=
_hook_grad
.
end
())
{
std
::
size_t
hook_args_size
=
3
;
auto
hook_args
=
py
::
tuple
(
hook_args_size
);
hook_args
[
0
]
=
cell_id
;
hook_args
[
1
]
=
py
::
make_tuple
(
_hook_grad
[
cell_id
]);
hook_args
[
2
]
=
py
::
make_tuple
(
py_args
[
2
]);
py
::
function
fn_hook
=
prim_py
->
hook
();
obj
=
fn_hook
(
*
hook_args
);
if
(
py
::
isinstance
<
py
::
none
>
(
obj
))
{
obj
=
py_args
[
2
];
}
_hook_grad
.
erase
(
cell_id
);
}
else
{
_hook_grad
[
cell_id
]
=
py_args
[
2
];
obj
=
py_args
[
2
];
}
}
else
{
// Hook operator for execute variable hook function
py
::
function
fn_hook
=
prim_py
->
hook
();
obj
=
fn_hook
(
py
::
make_tuple
(
py_args
[
2
]));
if
(
py
::
isinstance
<
py
::
none
>
(
obj
))
{
obj
=
py_args
[
2
];
}
}
obj
=
py
::
make_tuple
(
obj
);
return
obj
;
MS_EXCEPTION_IF_NULL
(
prim
);
return
prim
->
RunHookFunction
(
args
);
}
}
// namespace compile
}
// namespace mindspore
mindspore/ccsrc/vm/vm.h
浏览文件 @
9682d08d
...
...
@@ -161,7 +161,6 @@ class FinalVM {
{
Instruction
::
kPrim
,
[
this
](
const
VectorRef
&
args
)
{
InstPushPrim
(
args
);
}},
{
Instruction
::
kSwitchReturn
,
[
this
](
const
VectorRef
&
args
)
{
InstSwitchReturn
(
args
);
}},
{
Instruction
::
kSwitchLayer
,
[
this
](
const
VectorRef
&
args
)
{
InstSwitchLayer
(
args
);
}}};
std
::
map
<
std
::
string
,
py
::
object
>
_hook_grad
;
};
using
FinalVMPtr
=
std
::
shared_ptr
<
FinalVM
>
;
...
...
mindspore/ccsrc/vm/vmimpl.cc
浏览文件 @
9682d08d
...
...
@@ -30,7 +30,7 @@
#include "operator/ops.h"
#include "ir/manager.h"
#include "ir/func_graph_cloner.h"
#include "ir/primitive.h"
#include "ir/primitive
_py
.h"
#include "utils/convert_utils.h"
#include "utils/primitive_utils.h"
#include "debug/draw.h"
...
...
tests/ut/cpp/operator/ops_test.cc
浏览文件 @
9682d08d
...
...
@@ -19,7 +19,7 @@
#include "common/common_test.h"
#include "ir/value.h"
#include "ir/primitive.h"
#include "ir/primitive
_py
.h"
#include "operator/ops.h"
#include "./common.h"
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录