Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
MegEngine 天元
MegEngine
提交
e32929df
MegEngine
项目概览
MegEngine 天元
/
MegEngine
1 年多 前同步成功
通知
404
Star
4705
Fork
582
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
MegEngine
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
提交
e32929df
编写于
1月 14, 2022
作者:
M
Megvii Engine Team
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
refactor(dispatch): implement scalar
GitOrigin-RevId: b244c2ca1ad5cb28ffcf0e320cd0440f298bea51
上级
59084fa8
变更
2
显示空白变更内容
内联
并排
Showing
2 changed file
with
464 addition
and
0 deletion
+464
-0
imperative/src/impl/transformations/scalar.cpp
imperative/src/impl/transformations/scalar.cpp
+404
-0
imperative/src/include/megbrain/imperative/transformations/scalar.h
.../src/include/megbrain/imperative/transformations/scalar.h
+60
-0
未找到文件。
imperative/src/impl/transformations/scalar.cpp
0 → 100644
浏览文件 @
e32929df
/**
* \file imperative/src/impl/transformations/trace.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#include "megbrain/imperative/transformations/scalar.h"
#include "megbrain/imperative/ops/autogen.h"
namespace
mgb
{
namespace
imperative
{
namespace
{
using
ScalarRule
=
std
::
function
<
std
::
vector
<
ValueRef
>
(
const
OpDef
&
,
Span
<
ValueRef
>
)
>
;
static
std
::
unordered_map
<
Typeinfo
*
,
std
::
function
<
std
::
vector
<
ValueRef
>
(
const
OpDef
&
,
Span
<
ValueRef
>
)
>>
scalar_rules
;
ValueRef
unwrap_input
(
ValueRef
input
)
{
if
(
auto
scalar_input
=
input
.
as_ref
<
ScalarValue
>
())
{
return
scalar_input
->
value
();
}
else
{
return
input
;
}
}
std
::
vector
<
ValueRef
>
unwrap_inputs
(
Span
<
ValueRef
>
inputs
)
{
std
::
vector
<
ValueRef
>
unwrapped_inputs
;
for
(
auto
&&
input
:
inputs
)
{
unwrapped_inputs
.
push_back
(
unwrap_input
(
input
));
}
return
unwrapped_inputs
;
}
ValueRef
make_scalar_shape
(
CompNode
device
)
{
HostTensorND
scalar_shape
(
device
,
{
1
},
dtype
::
Int32
());
scalar_shape
.
ptr
<
dt_int32
>
()[
0
]
=
1
;
return
imperative
::
apply
(
CreateTensor
(
CreateTensor
::
Const
,
device
,
scalar_shape
.
layout
()),
HostStorage
::
make
(
scalar_shape
.
storage
()))[
0
];
}
bool
is_scalar_shape
(
ValueRef
shape
)
{
if
(
shape
.
is
<
ScalarValue
>
())
{
return
false
;
}
auto
shape_of_shape
=
shape
.
shape
();
if
(
!
shape_of_shape
)
{
// assume not scalar
return
false
;
}
return
*
shape_of_shape
==
ValueShape
{
0
};
}
template
<
typename
T
>
void
register_scalar_rule
(
std
::
vector
<
ValueRef
>
(
*
rule
)(
const
T
&
,
Span
<
ValueRef
>
))
{
scalar_rules
[
T
::
typeinfo
()]
=
[
rule
](
const
OpDef
&
def
,
Span
<
ValueRef
>
inputs
)
{
return
(
*
rule
)(
def
.
cast_final_safe
<
T
>
(),
inputs
);
};
}
std
::
vector
<
ValueRef
>
elemwise_rule
(
const
Elemwise
&
elem
,
Span
<
ValueRef
>
inputs
)
{
bool
all_scalar
=
true
;
for
(
auto
&&
input
:
inputs
)
{
if
(
!
input
.
is
<
ScalarValue
>
())
{
all_scalar
=
false
;
break
;
}
}
auto
output
=
imperative
::
apply
(
elem
,
unwrap_inputs
(
inputs
))[
0
];
if
(
all_scalar
)
{
return
{
ScalarValue
::
make
(
output
)};
}
else
{
return
{
output
};
}
}
std
::
vector
<
ValueRef
>
remove_axis_rule
(
const
RemoveAxis
&
remove_axis
,
Span
<
ValueRef
>
inputs
)
{
mgb_assert
(
inputs
.
size
()
==
1
);
mgb_assert
(
!
inputs
[
0
].
is
<
ScalarValue
>
());
auto
output
=
imperative
::
apply
(
remove_axis
,
inputs
)[
0
];
bool
is_scalar
=
inputs
[
0
].
shape
()
->
ndim
==
remove_axis
.
axis
.
size
();
if
(
is_scalar
)
{
return
{
ScalarValue
::
make
(
output
)};
}
else
{
return
{
output
};
}
}
std
::
vector
<
ValueRef
>
reduce_rule
(
const
Reduce
&
reduce
,
Span
<
ValueRef
>
inputs
)
{
if
(
inputs
.
size
()
==
1
)
{
return
imperative
::
apply
(
reduce
,
unwrap_inputs
(
inputs
));
}
mgb_assert
(
inputs
.
size
()
==
2
);
bool
is_scalar
=
is_scalar_shape
(
inputs
[
1
]);
if
(
is_scalar
)
{
auto
unwrapped_input
=
unwrap_input
(
inputs
[
0
]);
CompNode
device
=
*
unwrapped_input
.
device
();
return
{
ScalarValue
::
make
(
imperative
::
apply
(
reduce
,
unwrapped_input
,
make_scalar_shape
(
device
))[
0
])};
}
auto
output
=
imperative
::
apply
(
reduce
,
unwrap_inputs
(
inputs
))[
0
];
if
(
is_scalar
)
{
return
{
ScalarValue
::
make
(
output
)};
}
else
{
return
{
output
};
}
}
std
::
vector
<
ValueRef
>
typecvt_rule
(
const
TypeCvt
&
typecvt
,
Span
<
ValueRef
>
inputs
)
{
mgb_assert
(
inputs
.
size
()
==
1
);
if
(
auto
scalar_input
=
inputs
[
0
].
as_ref
<
ScalarValue
>
())
{
return
{
ScalarValue
::
make
(
imperative
::
apply
(
typecvt
,
scalar_input
->
value
())[
0
])};
}
else
{
return
imperative
::
apply
(
typecvt
,
inputs
);
}
}
std
::
vector
<
ValueRef
>
collective_comm_rule
(
const
CollectiveComm
&
collective_comm
,
Span
<
ValueRef
>
inputs
)
{
mgb_assert
(
inputs
.
size
()
==
1
);
static
std
::
unordered_set
<
CollectiveComm
::
Mode
>
modes
=
{
CollectiveComm
::
Mode
::
ALL_REDUCE_MAX
,
CollectiveComm
::
Mode
::
ALL_REDUCE_MIN
,
CollectiveComm
::
Mode
::
ALL_REDUCE_SUM
,
CollectiveComm
::
Mode
::
BROADCAST
,
CollectiveComm
::
Mode
::
REDUCE_SUM
,
};
if
(
modes
.
count
(
collective_comm
.
mode
)
==
0
)
{
return
imperative
::
apply
(
collective_comm
,
inputs
);
}
if
(
auto
scalar_input
=
inputs
[
0
].
as_ref
<
ScalarValue
>
())
{
return
{
ScalarValue
::
make
(
imperative
::
apply
(
collective_comm
,
scalar_input
->
value
())[
0
])};
}
else
{
return
imperative
::
apply
(
collective_comm
,
inputs
);
}
}
std
::
vector
<
ValueRef
>
param_pack_split_rule
(
const
ParamPackSplit
&
param_pack_split
,
Span
<
ValueRef
>
inputs
)
{
auto
outputs
=
imperative
::
apply
(
param_pack_split
,
unwrap_inputs
(
inputs
));
size_t
nr_outputs
=
outputs
.
size
();
mgb_assert
(
nr_outputs
==
param_pack_split
.
shapes
.
size
());
for
(
size_t
i
=
0
;
i
<
nr_outputs
;
++
i
)
{
if
(
param_pack_split
.
shapes
[
i
].
empty
())
{
outputs
[
i
]
=
ScalarValue
::
make
(
outputs
[
i
]);
}
}
return
outputs
;
}
std
::
vector
<
ValueRef
>
dot_rule
(
const
Dot
&
dot
,
Span
<
ValueRef
>
inputs
)
{
return
{
ScalarValue
::
make
(
imperative
::
apply
(
dot
,
unwrap_inputs
(
inputs
))[
0
])};
}
std
::
vector
<
ValueRef
>
add_axis_rule
(
const
AddAxis
&
add_axis
,
Span
<
ValueRef
>
inputs
)
{
mgb_assert
(
inputs
.
size
()
==
1
);
if
(
auto
scalar_input
=
inputs
[
0
].
as_ref
<
ScalarValue
>
())
{
mgb_assert
(
add_axis
.
axis
[
0
]
==
0
);
if
(
add_axis
.
axis
.
size
()
==
1
)
{
return
{
scalar_input
->
value
()};
}
else
{
std
::
vector
<
int32_t
>
axis
(
add_axis
.
axis
.
begin
()
+
1
,
add_axis
.
axis
.
end
());
return
imperative
::
apply
(
ApplyOp
(
*
AddAxis
::
make
(
axis
,
add_axis
.
scope
())),
scalar_input
->
value
());
}
}
else
{
return
imperative
::
apply
(
add_axis
,
inputs
);
}
}
std
::
vector
<
ValueRef
>
remote_recv_rule
(
const
RemoteRecv
&
remote_recv
,
Span
<
ValueRef
>
inputs
)
{
if
(
remote_recv
.
shape
.
empty
())
{
std
::
vector
<
int32_t
>
shape
=
{
1
};
auto
remote_recv_no_scalar
=
RemoteRecv
::
make
(
remote_recv
.
key
,
remote_recv
.
addr
,
remote_recv
.
port
,
remote_recv
.
rank_from
,
remote_recv
.
cn
,
shape
,
remote_recv
.
dtype
,
remote_recv
.
backend
);
remote_recv_no_scalar
->
set_scope
(
remote_recv
.
scope
());
return
imperative
::
apply
(
ApplyOp
(
*
remote_recv_no_scalar
),
unwrap_inputs
(
inputs
));
}
else
{
return
imperative
::
apply
(
remote_recv
,
unwrap_inputs
(
inputs
));
}
}
std
::
vector
<
ValueRef
>
check_no_finite_rule
(
const
CheckNonFinite
&
check_no_finite
,
Span
<
ValueRef
>
inputs
)
{
auto
outputs
=
imperative
::
apply
(
check_no_finite
,
unwrap_inputs
(
inputs
));
mgb_assert
(
outputs
.
size
()
==
inputs
.
size
()
+
1
,
"output size mismatch"
);
outputs
.
back
()
=
ScalarValue
::
make
(
outputs
.
back
());
for
(
size_t
i
=
0
;
i
<
inputs
.
size
();
++
i
)
{
if
(
inputs
[
i
].
is
<
ScalarValue
>
())
{
outputs
[
i
]
=
ScalarValue
::
make
(
outputs
[
i
]);
}
}
return
outputs
;
}
std
::
vector
<
ValueRef
>
subtensor_rule
(
const
Subtensor
&
subtensor
,
Span
<
ValueRef
>
inputs
)
{
mgb_assert
(
inputs
.
size
()
>=
1
);
auto
input
=
inputs
[
0
];
size_t
ndim
=
input
.
is
<
ScalarValue
>
()
?
0
:
input
.
shape
()
->
ndim
;
for
(
auto
&&
[
axis
,
begin
,
end
,
step
,
idx
]
:
subtensor
.
items
)
{
if
(
idx
)
{
ndim
--
;
}
}
auto
output
=
imperative
::
apply
(
subtensor
,
unwrap_inputs
(
inputs
))[
0
];
if
(
!
ndim
)
{
return
{
ScalarValue
::
make
(
output
)};
}
else
{
return
{
output
};
}
}
std
::
vector
<
ValueRef
>
get_var_shape_rule
(
const
GetVarShape
&
get_var_shape
,
Span
<
ValueRef
>
inputs
)
{
bool
all_scalar
=
true
;
mgb_assert
(
inputs
.
size
()
>=
1
);
for
(
auto
&&
input
:
inputs
)
{
if
(
!
input
.
is
<
ScalarValue
>
())
{
all_scalar
=
false
;
}
}
if
(
all_scalar
)
{
auto
device
=
inputs
[
0
].
cast
<
ScalarValue
>
().
value
().
device
();
auto
storage
=
HostStorage
::
make
(
*
device
);
// storage->ensure_size(1);
return
imperative
::
apply
(
CreateTensor
(
CreateTensor
::
Const
,
*
device
,
dtype
::
Int32
(),
ValueShape
{
0
}),
storage
);
}
else
{
return
imperative
::
apply
(
get_var_shape
,
unwrap_inputs
(
inputs
));
}
}
std
::
vector
<
ValueRef
>
fastpath_copy_rule
(
const
FastpathCopy
&
fastpath_copy
,
Span
<
ValueRef
>
inputs
)
{
mgb_assert
(
inputs
.
size
()
==
1
);
bool
is_scalar
=
inputs
[
0
].
is
<
ScalarValue
>
();
auto
output
=
imperative
::
apply
(
fastpath_copy
,
unwrap_inputs
(
inputs
))[
0
];
if
(
is_scalar
)
{
return
{
ScalarValue
::
make
(
output
)};
}
else
{
return
{
output
};
}
}
std
::
vector
<
ValueRef
>
reshape_rule
(
const
Reshape
&
reshape
,
Span
<
ValueRef
>
inputs
)
{
mgb_assert
(
inputs
.
size
()
==
2
);
bool
is_scalar
=
(
!
inputs
[
1
].
is
<
ScalarValue
>
())
&&
*
inputs
[
1
].
shape
()
==
ValueShape
{
0
};
auto
unwrapped_input
=
inputs
[
0
].
is
<
ScalarValue
>
()
?
inputs
[
0
].
cast
<
ScalarValue
>
().
value
()
:
inputs
[
0
];
if
(
is_scalar
)
{
return
{
ScalarValue
::
make
(
imperative
::
apply
(
reshape
,
unwrapped_input
,
make_scalar_shape
(
*
unwrapped_input
.
device
()))[
0
])};
}
else
{
return
imperative
::
apply
(
reshape
,
unwrap_inputs
(
inputs
));
}
}
std
::
vector
<
ValueRef
>
broadcast_rule
(
const
Broadcast
&
broadcast
,
Span
<
ValueRef
>
inputs
)
{
mgb_assert
(
inputs
.
size
()
==
2
);
bool
is_scalar
=
is_scalar_shape
(
inputs
[
1
]);
auto
unwrapped_input
=
inputs
[
0
].
is
<
ScalarValue
>
()
?
inputs
[
0
].
cast
<
ScalarValue
>
().
value
()
:
inputs
[
0
];
if
(
is_scalar
)
{
return
{
ScalarValue
::
make
(
imperative
::
apply
(
broadcast
,
unwrapped_input
,
make_scalar_shape
(
*
unwrapped_input
.
device
()))[
0
])};
}
else
{
return
imperative
::
apply
(
broadcast
,
unwrap_inputs
(
inputs
));
}
}
std
::
vector
<
ValueRef
>
copy_rule
(
const
Copy
&
copy
,
Span
<
ValueRef
>
inputs
)
{
mgb_assert
(
inputs
.
size
()
==
1
);
bool
is_scalar
=
inputs
[
0
].
is
<
ScalarValue
>
();
if
(
is_scalar
)
{
return
{
ScalarValue
::
make
(
imperative
::
apply
(
copy
,
unwrap_inputs
(
inputs
))[
0
])};
}
else
{
return
imperative
::
apply
(
copy
,
unwrap_inputs
(
inputs
));
}
}
std
::
vector
<
ValueRef
>
inplace_add_rule
(
const
InplaceAdd
&
inplace_add
,
Span
<
ValueRef
>
inputs
)
{
mgb_assert
(
inputs
.
size
()
==
4
);
bool
is_scalar
=
inputs
[
0
].
is
<
ScalarValue
>
();
if
(
is_scalar
)
{
return
{
ScalarValue
::
make
(
imperative
::
apply
(
inplace_add
,
unwrap_inputs
(
inputs
))[
0
])};
}
else
{
return
imperative
::
apply
(
inplace_add
,
unwrap_inputs
(
inputs
));
}
}
struct
ScalarRuleRegistry
{
ScalarRuleRegistry
()
{
register_scalar_rule
(
elemwise_rule
);
register_scalar_rule
(
remove_axis_rule
);
register_scalar_rule
(
reduce_rule
);
register_scalar_rule
(
typecvt_rule
);
register_scalar_rule
(
collective_comm_rule
);
register_scalar_rule
(
param_pack_split_rule
);
register_scalar_rule
(
dot_rule
);
register_scalar_rule
(
add_axis_rule
);
register_scalar_rule
(
remote_recv_rule
);
register_scalar_rule
(
check_no_finite_rule
);
register_scalar_rule
(
subtensor_rule
);
register_scalar_rule
(
get_var_shape_rule
);
register_scalar_rule
(
fastpath_copy_rule
);
register_scalar_rule
(
reshape_rule
);
register_scalar_rule
(
broadcast_rule
);
register_scalar_rule
(
copy_rule
);
register_scalar_rule
(
inplace_add_rule
);
}
}
_
;
}
// namespace
std
::
vector
<
ValueRef
>
ScalarTransformation
::
apply_transformation
(
const
Operator
&
op
,
Span
<
ValueRef
>
inputs
)
{
if
(
auto
apply_op
=
op
.
as
<
ApplyOp
>
())
{
auto
iter
=
scalar_rules
.
find
(
apply_op
->
op
().
dyn_typeinfo
());
if
(
iter
!=
scalar_rules
.
end
())
{
return
iter
->
second
(
apply_op
->
op
(),
inputs
);
}
else
{
// TODO: repeat op
return
imperative
::
apply
(
op
,
unwrap_inputs
(
inputs
));
}
}
else
if
(
auto
*
create_tensor
=
op
.
as
<
CreateTensor
>
())
{
if
(
create_tensor
->
shape
().
is_scalar
())
{
ValueShape
scalar_shape
=
{
1
};
CreateTensor
scalar_op
(
create_tensor
->
kind
(),
create_tensor
->
device
(),
create_tensor
->
dtype
(),
scalar_shape
);
return
{
ScalarValue
::
make
(
imperative
::
apply
(
scalar_op
,
inputs
)[
0
])};
}
else
{
return
imperative
::
apply
(
op
,
inputs
);
}
}
else
if
(
auto
*
get_attr
=
op
.
as
<
GetAttr
>
())
{
bool
is_scalar
=
inputs
.
as_array
<
1
>
()[
0
].
is
<
ScalarValue
>
();
auto
output
=
imperative
::
apply
(
op
,
unwrap_inputs
(
inputs
))[
0
];
if
(
!
is_scalar
)
{
return
{
output
};
}
switch
(
get_attr
->
attr
())
{
case
GetAttr
::
Shape
:
{
// Scalar Shape
return
{
ShapeValue
::
make
()};
}
case
GetAttr
::
Value
:
{
auto
&
hv
=
output
.
cast
<
HostValue
>
();
mgb_assert
(
hv
.
shape
()
==
ValueShape
({
1
}),
"underlying value should has shape {1}, got %s"
,
hv
.
shape
().
to_string
().
c_str
());
return
{
HostValue
::
make
(
hv
.
dtype
(),
ValueShape
(),
hv
.
storage
())};
}
case
GetAttr
::
Data
:
{
auto
&
dv
=
output
.
cast
<
DeviceValue
>
();
mgb_assert
(
dv
.
shape
()
==
ValueShape
({
1
}),
"underlying value should has shape {1}, got %s"
,
dv
.
shape
().
to_string
().
c_str
());
return
{
DeviceValue
::
make
(
dv
.
dtype
(),
ValueShape
(),
dv
.
storage
())};
}
default:
return
{
output
};
}
}
else
if
(
op
.
as
<
IsScalar
>
())
{
return
{
BoolValue
::
make
(
inputs
.
as_array
<
1
>
()[
0
].
is
<
ScalarValue
>
())};
}
else
if
(
op
.
is
<
Operator
::
IdentityLike
>
())
{
bool
is_scalar
=
inputs
.
as_array
<
1
>
()[
0
].
is
<
ScalarValue
>
();
if
(
is_scalar
)
{
return
{
ScalarValue
::
make
(
imperative
::
apply
(
op
,
unwrap_inputs
(
inputs
))[
0
])};
}
else
{
return
imperative
::
apply
(
op
,
inputs
);
}
}
else
{
return
imperative
::
apply
(
op
,
unwrap_inputs
(
inputs
));
}
};
}
// namespace imperative
}
// namespace mgb
imperative/src/include/megbrain/imperative/transformations/scalar.h
0 → 100644
浏览文件 @
e32929df
/**
* \file imperative/src/include/megbrain/imperative/scalar.h
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#pragma once
#include "megbrain/imperative/dispatch.h"
#include "megbrain/imperative/ops/autogen.h"
namespace
mgb
::
imperative
{
class
ScalarValue
final
:
public
ValueImpl
<
ScalarValue
>
{
private:
ValueRef
m_value
;
public:
ScalarValue
(
ValueRef
value
)
:
m_value
(
value
)
{}
std
::
string
to_string
()
const
override
{
return
ssprintf
(
"ScalarValue{value=%s}"
,
m_value
.
to_string
().
c_str
());
}
ValueRef
value
()
const
{
return
m_value
;
}
void
clear
()
override
{
m_value
=
{};
}
void
on_watch
()
override
{
m_value
.
watch
();
}
void
on_unwatch
()
override
{
m_value
.
unwatch
();
}
};
/**
* \brief simulates scalar because megbrain graph system don't support scalar
*
* Assume that we has 'a = ScalarValue(b)', thus 'a.shape == []', 'b.shape == [1]'.
* This transformation simulates scalars with a flag. If a value is ScalarValue, it is
* scalar, vice versa. So there is not scalar down this layer.
*/
class
ScalarTransformation
final
:
public
Transformation
{
private:
public:
std
::
vector
<
ValueRef
>
apply_transformation
(
const
Operator
&
op
,
Span
<
ValueRef
>
inputs
)
override
;
ValueRef
unwrap
(
ValueRef
value
)
override
{
mgb_assert
(
!
value
.
is
<
ScalarValue
>
());
return
value
;
}
std
::
string
name
()
const
override
{
return
"ScalarTransformation"
;
}
};
}
// namespace mgb::imperative
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录