Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
magicwindyyd
mindspore
提交
04763b8b
M
mindspore
项目概览
magicwindyyd
/
mindspore
与 Fork 源项目一致
Fork自
MindSpore / mindspore
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
mindspore
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
04763b8b
编写于
5月 20, 2020
作者:
L
leopz
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
move signature to primitivepy and bprop_func to utils
上级
183144e1
变更
11
显示空白变更内容
内联
并排
Showing
11 changed file
with
328 addition
and
196 deletion
+328
-196
mindspore/ccsrc/ir/primitive.cc
mindspore/ccsrc/ir/primitive.cc
+4
-86
mindspore/ccsrc/ir/primitive.h
mindspore/ccsrc/ir/primitive.h
+10
-105
mindspore/ccsrc/ir/primitive_base.cc
mindspore/ccsrc/ir/primitive_base.cc
+71
-0
mindspore/ccsrc/ir/primitive_base.h
mindspore/ccsrc/ir/primitive_base.h
+128
-0
mindspore/ccsrc/ir/primitive_base_extends.cc
mindspore/ccsrc/ir/primitive_base_extends.cc
+25
-0
mindspore/ccsrc/operator/composite/do_signature.cc
mindspore/ccsrc/operator/composite/do_signature.cc
+2
-2
mindspore/ccsrc/optimizer/ad/kprim.cc
mindspore/ccsrc/optimizer/ad/kprim.cc
+3
-1
mindspore/ccsrc/utils/primitive_utils.cc
mindspore/ccsrc/utils/primitive_utils.cc
+49
-0
mindspore/ccsrc/utils/primitive_utils.h
mindspore/ccsrc/utils/primitive_utils.h
+33
-0
mindspore/ccsrc/vm/vmimpl.cc
mindspore/ccsrc/vm/vmimpl.cc
+2
-1
tests/ut/cpp/operator/ops_test.cc
tests/ut/cpp/operator/ops_test.cc
+1
-1
未找到文件。
mindspore/ccsrc/ir/primitive.cc
浏览文件 @
04763b8b
...
...
@@ -24,75 +24,13 @@
#include "pipeline/parse/data_converter.h"
#include "pybind11/pytypes.h"
#include "utils/convert_utils.h"
#include "utils/primitive_utils.h"
#include "pybind_api/api_register.h"
#include "pybind_api/export_flags.h"
namespace
mindspore
{
using
mindspore
::
abstract
::
AbstractFunction
;
abstract
::
AbstractBasePtr
Primitive
::
ToPrimAbstract
(
const
AnfNodePtr
&
anf_node
)
{
auto
prim_func
=
std
::
make_shared
<
abstract
::
PrimitiveAbstractClosure
>
(
shared_from_base
<
Primitive
>
(),
anf_node
);
return
prim_func
;
}
static
py
::
function
GetBpropFunctionByObj
(
py
::
object
obj
)
{
static
const
std
::
string
get_bprop_fn
=
"get_bprop_fn"
;
static
const
std
::
string
ad_module
=
"mindspore.ops._grad"
;
py
::
function
fn
=
parse
::
python_adapter
::
GetPyFn
(
ad_module
,
get_bprop_fn
)(
obj
);
return
fn
;
}
py
::
function
Primitive
::
GetBpropFunction
()
{
auto
fn
=
GetBpropFunctionByObj
(
py
::
str
(
name
()));
if
(
fn
.
is_none
())
{
MS_LOG
(
WARNING
)
<<
"Can't find bprop function for "
<<
name
();
}
return
fn
;
}
py
::
function
Primitive
::
GetComputeFunction
()
{
static
const
std
::
string
module
=
"mindspore._extends.builtin_operations"
;
py
::
module
mod
=
py
::
module
::
import
(
common
::
SafeCStr
(
module
));
if
(
!
py
::
hasattr
(
mod
,
common
::
SafeCStr
(
name
())))
{
PyErr_SetString
(
PyExc_NotImplementedError
,
common
::
SafeCStr
(
name
()));
// If raise AttributeError, user can't understand. This case need raise NotImplementedError.
throw
py
::
error_already_set
();
}
py
::
object
fn
=
mod
.
attr
(
common
::
SafeCStr
(
name
()));
return
fn
;
}
bool
Primitive
::
operator
==
(
const
Value
&
other
)
const
{
if
(
other
.
isa
<
Primitive
>
())
{
auto
other_prim
=
static_cast
<
const
Primitive
&>
(
other
);
return
*
this
==
other_prim
;
}
else
{
return
false
;
}
}
bool
Primitive
::
operator
==
(
const
Primitive
&
other
)
const
{
if
(
name
()
!=
other
.
name
())
{
return
false
;
}
if
(
attrs_
.
size
()
!=
other
.
attrs_
.
size
())
{
return
false
;
}
auto
all
=
std
::
all_of
(
attrs_
.
begin
(),
attrs_
.
end
(),
[
&
other
](
const
std
::
pair
<
std
::
string
,
ValuePtr
>
&
item
)
->
bool
{
if
(
item
.
second
==
nullptr
)
{
return
false
;
}
auto
iter
=
other
.
attrs_
.
find
(
item
.
first
);
if
(
iter
==
other
.
attrs_
.
end
())
{
return
false
;
}
return
*
item
.
second
==
*
iter
->
second
;
});
return
all
;
}
void
Primitive
::
set_signatures
(
void
PrimitivePy
::
set_signatures
(
std
::
vector
<
std
::
tuple
<
std
::
string
,
SignatureEnumRW
,
SignatureEnumKind
,
py
::
object
,
SignatureEnumDType
>>
signatures
)
{
signatures_
.
clear
();
for
(
auto
&
signature
:
signatures
)
{
...
...
@@ -104,27 +42,7 @@ void Primitive::set_signatures(
std
::
tie
(
name
,
rw
,
kind
,
default_value
,
dtype
)
=
signature
;
signatures_
.
emplace_back
(
Signature
(
name
,
rw
,
kind
,
default_value
,
dtype
));
}
}
std
::
string
Primitive
::
GetAttrsText
()
const
{
if
(
attrs_
.
empty
())
{
return
""
;
}
std
::
ostringstream
oss
;
oss
<<
"["
;
bool
is_first
=
true
;
for
(
auto
&
attr
:
attrs_
)
{
if
(
is_first
)
{
is_first
=
false
;
}
else
{
oss
<<
", "
;
}
oss
<<
attr
.
first
<<
"="
<<
attr
.
second
->
DumpText
();
}
oss
<<
"]"
;
return
oss
.
str
();
set_has_signature
(
true
);
}
py
::
function
PrimitivePy
::
GetBpropFunction
()
{
...
...
@@ -158,7 +76,7 @@ py::function PrimitivePy::GetComputeFunction() {
if
(
py
::
isinstance
<
py
::
none
>
(
vm_fn
))
{
MS_LOG
(
DEBUG
)
<<
"Cannot find "
<<
python_obj_
.
attr
(
"__class__"
).
attr
(
"__name__"
).
cast
<
std
::
string
>
();
vm_fn
=
Primitive
::
GetComputeFunction
(
);
vm_fn
=
mindspore
::
GetComputeFunction
(
Primitive
::
name
()
);
}
return
vm_fn
;
}
...
...
mindspore/ccsrc/ir/primitive.h
浏览文件 @
04763b8b
...
...
@@ -22,59 +22,26 @@
#include <memory>
#include <string>
#include <tuple>
#include "pybind11/pybind11.h"
#include "pybind11/pybind11.h"
#include "pipeline/static_analysis/abstract_value.h"
#include "utils/misc.h"
#include "utils/log_adapter.h"
#include "ir/primitive_base.h"
#include "ir/signature.h"
#include "parallel/ops_info/operator_info.h"
namespace
py
=
pybind11
;
namespace
mindspore
{
using
abstract
::
AbstractBasePtr
;
using
abstract
::
AbstractBasePtrList
;
// Supported meta type
enum
PrimType
{
kPrimTypeUnknown
=
0
,
kPrimTypeBegin
=
kTypeUnknown
,
kPrimTypeBuiltIn
,
// Built-in primitive operator
kPrimTypePyInferShape
,
// Primitive operator defined by custom
kPrimTypePyInferTensor
,
// Primitive operator defined by custom
kPrimTypeUserCustom
};
class
Primitive
:
public
Named
{
class
PrimitivePy
:
public
Primitive
{
public:
explicit
Primitive
(
const
std
::
string
&
name
,
const
PrimType
prim_type
=
kPrimTypeBuiltIn
)
:
Named
(
name
),
signatures_
(),
prim_type_
(
prim_type
)
{}
Primitive
(
const
Primitive
&
prim
)
:
Named
(
prim
),
attrs_
(
prim
.
attrs_
),
signatures_
(
prim
.
signatures_
),
instance_name_
(
prim
.
instance_name_
),
prim_type_
(
prim
.
prim_type_
)
{}
MS_DECLARE_PARENT
(
Primitive
,
Named
);
abstract
::
AbstractBasePtr
ToPrimAbstract
(
const
AnfNodePtr
&
anf_node
);
std
::
string
ToString
()
const
override
{
return
name
();
}
virtual
py
::
function
GetBpropFunction
();
virtual
py
::
function
GetComputeFunction
();
Primitive
&
AddAttr
(
const
std
::
string
&
name
,
const
ValuePtr
&
attr
)
{
attrs_
[
name
]
=
attr
;
return
*
this
;
}
Primitive
&
SetAttrs
(
const
std
::
unordered_map
<
std
::
string
,
ValuePtr
>
&
attrs
)
{
for
(
auto
&
attr
:
attrs
)
{
attrs_
[
attr
.
first
]
=
attr
.
second
;
}
return
*
this
;
}
PrimitivePy
(
const
py
::
str
&
name
,
const
py
::
object
&
python_obj
)
:
Primitive
(
name
,
false
),
python_obj_
(
python_obj
),
signatures_
()
{}
~
PrimitivePy
()
override
=
default
;
MS_DECLARE_PARENT
(
PrimitivePy
,
Primitive
);
py
::
function
GetBpropFunction
();
py
::
function
GetComputeFunction
();
void
set_signatures
(
std
::
vector
<
std
::
tuple
<
std
::
string
,
SignatureEnumRW
,
SignatureEnumKind
,
py
::
object
,
SignatureEnumDType
>>
...
...
@@ -82,52 +49,6 @@ class Primitive : public Named {
const
std
::
vector
<
Signature
>
&
signatures
()
const
{
return
signatures_
;
}
void
set_attr
(
const
std
::
string
&
attrName
,
const
ValuePtr
&
attr
)
{
attrs_
[
attrName
]
=
attr
;
}
void
EraseAttr
(
const
std
::
string
&
attrName
)
{
(
void
)
attrs_
.
erase
(
attrName
);
}
ValuePtr
GetAttr
(
const
std
::
string
&
attrName
)
const
{
auto
iter
=
attrs_
.
find
(
attrName
);
return
iter
==
attrs_
.
cend
()
?
nullptr
:
iter
->
second
;
}
const
std
::
unordered_map
<
std
::
string
,
ValuePtr
>
&
attrs
()
const
{
return
attrs_
;
}
// if Primitive has any attribute, for Primitives like scalar_add, return, etc, don't have any attribute.
bool
HasAttr
()
const
{
return
!
attrs_
.
empty
();
}
bool
HasAttr
(
const
std
::
string
&
attrName
)
const
{
auto
iter
=
attrs_
.
find
(
attrName
);
return
!
(
iter
==
attrs_
.
cend
());
}
void
set_prim_type
(
const
PrimType
t
)
{
prim_type_
=
t
;
}
void
set_instance_name
(
const
std
::
string
s
)
{
instance_name_
=
s
;
}
bool
HasPyEvaluator
()
const
{
return
prim_type_
==
kPrimTypePyInferShape
||
prim_type_
==
kPrimTypeUserCustom
;
}
bool
HasPyInferTensor
()
const
{
return
prim_type_
==
kPrimTypePyInferTensor
;
}
bool
IsCustomPrim
()
const
{
return
prim_type_
==
kPrimTypeUserCustom
;
}
PrimType
prim_type
()
const
{
return
prim_type_
;
}
std
::
string
instance_name
()
const
{
return
instance_name_
;
}
std
::
string
GetAttrsText
()
const
;
bool
operator
==
(
const
Value
&
other
)
const
override
;
bool
operator
==
(
const
Primitive
&
other
)
const
;
~
Primitive
()
override
=
default
;
protected:
std
::
unordered_map
<
std
::
string
,
ValuePtr
>
attrs_
;
private:
std
::
vector
<
Signature
>
signatures_
;
std
::
string
instance_name_
;
PrimType
prim_type_
;
};
class
PrimitivePy
:
public
Primitive
{
public:
PrimitivePy
(
const
py
::
str
&
name
,
const
py
::
object
&
python_obj
)
:
Primitive
(
name
),
python_obj_
(
python_obj
)
{}
~
PrimitivePy
()
override
=
default
;
MS_DECLARE_PARENT
(
PrimitivePy
,
Primitive
);
py
::
function
GetBpropFunction
()
override
;
py
::
function
GetComputeFunction
()
override
;
void
AddPyAttr
(
const
py
::
str
&
name
,
const
py
::
object
&
obj
);
py
::
dict
GetAttrDict
();
...
...
@@ -138,25 +59,9 @@ class PrimitivePy : public Primitive {
private:
py
::
object
python_obj_
;
std
::
vector
<
Signature
>
signatures_
;
};
using
PrimitivePyPtr
=
std
::
shared_ptr
<
PrimitivePy
>
;
inline
std
::
ostream
&
operator
<<
(
std
::
ostream
&
os
,
const
PrimitivePtr
&
p
)
{
os
<<
*
p
;
return
os
;
}
struct
PrimitiveEqual
{
bool
operator
()(
PrimitivePtr
const
&
t1
,
PrimitivePtr
const
&
t2
)
const
{
MS_EXCEPTION_IF_NULL
(
t1
);
MS_EXCEPTION_IF_NULL
(
t2
);
return
t1
->
name
()
==
t2
->
name
();
}
};
struct
PrimitiveHasher
{
std
::
size_t
operator
()(
PrimitivePtr
const
&
prim
)
const
{
return
prim
->
Hash
();
}
};
}
// namespace mindspore
#endif // MINDSPORE_CCSRC_IR_PRIMITIVE_H_
mindspore/ccsrc/ir/primitive_base.cc
0 → 100644
浏览文件 @
04763b8b
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "ir/primitive_base.h"
#include <utility>
namespace
mindspore
{
bool
Primitive
::
operator
==
(
const
Value
&
other
)
const
{
if
(
other
.
isa
<
Primitive
>
())
{
auto
other_prim
=
static_cast
<
const
Primitive
&>
(
other
);
return
*
this
==
other_prim
;
}
else
{
return
false
;
}
}
bool
Primitive
::
operator
==
(
const
Primitive
&
other
)
const
{
if
(
name
()
!=
other
.
name
())
{
return
false
;
}
if
(
attrs_
.
size
()
!=
other
.
attrs_
.
size
())
{
return
false
;
}
auto
all
=
std
::
all_of
(
attrs_
.
begin
(),
attrs_
.
end
(),
[
&
other
](
const
std
::
pair
<
std
::
string
,
ValuePtr
>
&
item
)
->
bool
{
if
(
item
.
second
==
nullptr
)
{
return
false
;
}
auto
iter
=
other
.
attrs_
.
find
(
item
.
first
);
if
(
iter
==
other
.
attrs_
.
end
())
{
return
false
;
}
return
*
item
.
second
==
*
iter
->
second
;
});
return
all
;
}
std
::
string
Primitive
::
GetAttrsText
()
const
{
if
(
attrs_
.
empty
())
{
return
""
;
}
std
::
ostringstream
oss
;
oss
<<
"["
;
bool
is_first
=
true
;
for
(
auto
&
attr
:
attrs_
)
{
if
(
is_first
)
{
is_first
=
false
;
}
else
{
oss
<<
", "
;
}
oss
<<
attr
.
first
<<
"="
<<
attr
.
second
->
DumpText
();
}
oss
<<
"]"
;
return
oss
.
str
();
}
}
// namespace mindspore
mindspore/ccsrc/ir/primitive_base.h
0 → 100644
浏览文件 @
04763b8b
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_IR_PRIMITIVE_BASE_H_
#define MINDSPORE_CCSRC_IR_PRIMITIVE_BASE_H_
#include <unordered_map>
#include <vector>
#include <memory>
#include <string>
#include <tuple>
#include "ir/dtype/type.h"
namespace
mindspore
{
// Supported meta type
enum
PrimType
{
kPrimTypeUnknown
=
0
,
kPrimTypeBegin
=
kTypeUnknown
,
kPrimTypeBuiltIn
,
// Built-in primitive operator
kPrimTypePyInferShape
,
// Primitive operator defined by custom
kPrimTypePyInferTensor
,
// Primitive operator defined by custom
kPrimTypeUserCustom
};
class
Primitive
:
public
Named
{
public:
explicit
Primitive
(
const
std
::
string
&
name
,
const
bool
is_base
=
true
,
const
PrimType
prim_type
=
kPrimTypeBuiltIn
)
:
Named
(
name
),
is_base_
(
is_base
),
has_signature_
(
false
),
prim_type_
(
prim_type
)
{}
Primitive
(
const
Primitive
&
prim
)
:
Named
(
prim
),
attrs_
(
prim
.
attrs_
),
instance_name_
(
prim
.
instance_name_
),
is_base_
(
prim
.
is_base_
),
has_signature_
(
prim
.
has_signature_
),
prim_type_
(
prim
.
prim_type_
)
{}
MS_DECLARE_PARENT
(
Primitive
,
Named
);
abstract
::
AbstractBasePtr
ToPrimAbstract
(
const
AnfNodePtr
&
anf_node
);
std
::
string
ToString
()
const
override
{
return
name
();
}
Primitive
&
AddAttr
(
const
std
::
string
&
name
,
const
ValuePtr
&
attr
)
{
attrs_
[
name
]
=
attr
;
return
*
this
;
}
Primitive
&
SetAttrs
(
const
std
::
unordered_map
<
std
::
string
,
ValuePtr
>
&
attrs
)
{
for
(
auto
&
attr
:
attrs
)
{
attrs_
[
attr
.
first
]
=
attr
.
second
;
}
return
*
this
;
}
void
set_attr
(
const
std
::
string
&
attrName
,
const
ValuePtr
&
attr
)
{
attrs_
[
attrName
]
=
attr
;
}
void
EraseAttr
(
const
std
::
string
&
attrName
)
{
(
void
)
attrs_
.
erase
(
attrName
);
}
ValuePtr
GetAttr
(
const
std
::
string
&
attrName
)
const
{
auto
iter
=
attrs_
.
find
(
attrName
);
return
iter
==
attrs_
.
cend
()
?
nullptr
:
iter
->
second
;
}
const
std
::
unordered_map
<
std
::
string
,
ValuePtr
>
&
attrs
()
const
{
return
attrs_
;
}
// if Primitive has any attribute, for Primitives like scalar_add, return, etc, don't have any attribute.
bool
HasAttr
()
const
{
return
!
attrs_
.
empty
();
}
bool
HasAttr
(
const
std
::
string
&
attrName
)
const
{
auto
iter
=
attrs_
.
find
(
attrName
);
return
!
(
iter
==
attrs_
.
cend
());
}
void
set_prim_type
(
const
PrimType
t
)
{
prim_type_
=
t
;
}
void
set_instance_name
(
const
std
::
string
s
)
{
instance_name_
=
s
;
}
bool
HasPyEvaluator
()
const
{
return
prim_type_
==
kPrimTypePyInferShape
||
prim_type_
==
kPrimTypeUserCustom
;
}
bool
HasPyInferTensor
()
const
{
return
prim_type_
==
kPrimTypePyInferTensor
;
}
bool
IsCustomPrim
()
const
{
return
prim_type_
==
kPrimTypeUserCustom
;
}
PrimType
prim_type
()
const
{
return
prim_type_
;
}
std
::
string
instance_name
()
const
{
return
instance_name_
;
}
std
::
string
GetAttrsText
()
const
;
bool
operator
==
(
const
Value
&
other
)
const
override
;
bool
operator
==
(
const
Primitive
&
other
)
const
;
~
Primitive
()
override
=
default
;
void
set_has_signature
(
bool
has_signature
)
{
has_signature_
=
has_signature
;
}
bool
has_signature
()
const
{
return
has_signature_
;
}
bool
is_base
()
const
{
return
is_base_
;
}
protected:
std
::
unordered_map
<
std
::
string
,
ValuePtr
>
attrs_
;
private:
std
::
string
instance_name_
;
bool
is_base_
;
bool
has_signature_
;
PrimType
prim_type_
;
};
inline
std
::
ostream
&
operator
<<
(
std
::
ostream
&
os
,
const
PrimitivePtr
&
p
)
{
os
<<
*
p
;
return
os
;
}
struct
PrimitiveEqual
{
bool
operator
()(
PrimitivePtr
const
&
t1
,
PrimitivePtr
const
&
t2
)
const
{
MS_EXCEPTION_IF_NULL
(
t1
);
MS_EXCEPTION_IF_NULL
(
t2
);
return
t1
->
name
()
==
t2
->
name
();
}
};
struct
PrimitiveHasher
{
std
::
size_t
operator
()(
PrimitivePtr
const
&
prim
)
const
{
return
prim
->
Hash
();
}
};
}
// namespace mindspore
#endif // MINDSPORE_CCSRC_IR_PRIMITIVE_BASE_H_
mindspore/ccsrc/ir/primitive_base_extends.cc
0 → 100644
浏览文件 @
04763b8b
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "ir/primitive_base.h"
#include "pipeline/static_analysis/abstract_function.h"
namespace
mindspore
{
abstract
::
AbstractBasePtr
Primitive
::
ToPrimAbstract
(
const
AnfNodePtr
&
anf_node
)
{
auto
prim_func
=
std
::
make_shared
<
abstract
::
PrimitiveAbstractClosure
>
(
shared_from_base
<
Primitive
>
(),
anf_node
);
return
prim_func
;
}
}
// namespace mindspore
mindspore/ccsrc/operator/composite/do_signature.cc
浏览文件 @
04763b8b
...
...
@@ -36,8 +36,8 @@ using PatternListType = std::initializer_list<BaseRef>;
const
std
::
vector
<
Signature
>
&
GetSignature
(
const
ValuePtr
&
function
)
{
static
const
auto
empty
=
std
::
vector
<
Signature
>
();
if
(
function
->
isa
<
Primitive
>
())
{
return
function
->
cast
<
PrimitivePtr
>
()
->
signatures
();
if
(
function
->
isa
<
Primitive
>
()
&&
function
->
cast
<
PrimitivePtr
>
()
->
has_signature
()
)
{
return
function
->
cast
<
PrimitiveP
yP
tr
>
()
->
signatures
();
}
else
if
(
function
->
isa
<
MetaFuncGraph
>
())
{
return
function
->
cast
<
MetaFuncGraphPtr
>
()
->
signatures
();
}
...
...
mindspore/ccsrc/optimizer/ad/kprim.cc
浏览文件 @
04763b8b
...
...
@@ -20,6 +20,7 @@
#include <string>
#include <utility>
#include "ir/anf.h"
#include "ir/primitive.h"
#include "ir/meta_func_graph.h"
#include "ir/func_graph_cloner.h"
#include "ir/manager.h"
...
...
@@ -30,6 +31,7 @@
#include "operator/ops.h"
#include "operator/composite/composite.h"
#include "utils/symbolic.h"
#include "utils/primitive_utils.h"
#include "debug/info.h"
#include "debug/trace.h"
...
...
@@ -49,7 +51,7 @@ FuncGraphPtr KPrim::GetBprop(const PrimitivePtr &prim) {
auto
scope
=
std
::
make_shared
<
Scope
>
(
gradients_scope
+
ScopeManager
::
GetInstance
().
GetCurrentScope
()
->
name
()
+
grad_op_child_scope_prefix
+
prim
->
name
());
ScopeGuard
scope_guard
(
scope
);
py
::
function
fn
=
prim
->
GetBpropFunction
();
py
::
function
fn
=
prim
->
is_base
()
?
GetBpropFunction
(
prim
->
name
())
:
prim
->
cast
<
PrimitivePyPtr
>
()
->
GetBpropFunction
();
if
(
fn
==
nullptr
||
py
::
isinstance
<
py
::
none
>
(
fn
))
{
MS_LOG
(
DEBUG
)
<<
"Fail to find bprop function for "
<<
prim
->
name
()
<<
"."
;
return
nullptr
;
...
...
mindspore/ccsrc/utils/primitive_utils.cc
0 → 100644
浏览文件 @
04763b8b
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "utils/primitive_utils.h"
#include "pipeline/parse/python_adapter.h"
#include "utils/log_adapter.h"
#include "common/utils.h"
namespace
mindspore
{
py
::
function
GetBpropFunctionByObj
(
py
::
object
obj
)
{
static
const
std
::
string
get_bprop_fn
=
"get_bprop_fn"
;
static
const
std
::
string
ad_module
=
"mindspore.ops._grad"
;
py
::
function
fn
=
parse
::
python_adapter
::
GetPyFn
(
ad_module
,
get_bprop_fn
)(
obj
);
return
fn
;
}
py
::
function
GetBpropFunction
(
std
::
string
name
)
{
auto
fn
=
GetBpropFunctionByObj
(
py
::
str
(
name
));
if
(
fn
.
is_none
())
{
MS_LOG
(
WARNING
)
<<
"Can't find bprop function for "
<<
name
;
}
return
fn
;
}
py
::
function
GetComputeFunction
(
std
::
string
name
)
{
static
const
std
::
string
module
=
"mindspore._extends.builtin_operations"
;
py
::
module
mod
=
py
::
module
::
import
(
common
::
SafeCStr
(
module
));
if
(
!
py
::
hasattr
(
mod
,
common
::
SafeCStr
(
name
)))
{
PyErr_SetString
(
PyExc_NotImplementedError
,
common
::
SafeCStr
(
name
));
// If raise AttributeError, user can't understand. This case need raise NotImplementedError.
throw
py
::
error_already_set
();
}
py
::
object
fn
=
mod
.
attr
(
common
::
SafeCStr
(
name
));
return
fn
;
}
}
// namespace mindspore
mindspore/ccsrc/utils/primitive_utils.h
0 → 100644
浏览文件 @
04763b8b
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_UTILS_PRIMITIVE_UTILS_H_
#define MINDSPORE_CCSRC_UTILS_PRIMITIVE_UTILS_H_
#include <string>
#include "pybind11/pybind11.h"
namespace
py
=
pybind11
;
namespace
mindspore
{
py
::
function
GetBpropFunctionByObj
(
py
::
object
obj
);
py
::
function
GetBpropFunction
(
std
::
string
name
);
py
::
function
GetComputeFunction
(
std
::
string
name
);
}
// namespace mindspore
#endif // MINDSPORE_CCSRC_UTILS_PRIMITIVE_UTILS_H_
mindspore/ccsrc/vm/vmimpl.cc
浏览文件 @
04763b8b
...
...
@@ -31,6 +31,7 @@
#include "ir/manager.h"
#include "ir/func_graph_cloner.h"
#include "utils/convert_utils.h"
#include "utils/primitive_utils.h"
#include "debug/draw.h"
namespace
mindspore
{
...
...
@@ -443,7 +444,7 @@ BaseRef RunOperation(const PrimitivePtr &prim, const VectorRef &args) {
PrimitivePyPtr
operation
=
dyn_cast
<
PrimitivePy
>
(
prim
);
MS_LOG
(
DEBUG
)
<<
"operation start "
<<
prim
->
name
();
auto
func
=
operation
!=
nullptr
?
operation
->
GetComputeFunction
()
:
prim
->
GetComputeFunction
(
);
auto
func
=
operation
!=
nullptr
?
operation
->
GetComputeFunction
()
:
GetComputeFunction
(
prim
->
name
()
);
if
(
py
::
isinstance
<
py
::
none
>
(
func
))
{
MS_LOG
(
EXCEPTION
)
<<
prim
->
name
()
<<
" 's compute function is not implemented"
;
}
...
...
tests/ut/cpp/operator/ops_test.cc
浏览文件 @
04763b8b
...
...
@@ -390,7 +390,7 @@ TEST_F(TestOps, Conv2dAttrTest) {
}
TEST_F
(
TestOps
,
CustomOpAttrTest
)
{
Primitive
prim
(
"CustomOp"
,
kPrimTypePyInferShape
);
Primitive
prim
(
"CustomOp"
,
true
,
kPrimTypePyInferShape
);
prim
.
SetAttrs
({
{
"attr1"
,
MakeValue
(
3
)},
{
"attr2"
,
MakeValue
(
1
)},
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录