Commit cf198dc9 (unverified)

Authored May 10, 2022 by xiongkun; committed by GitHub on May 10, 2022.
Parent: 02e5c4be

[EinsumOp] Polish forward logic and backward logic for optimize (#42603)

* change logic for optimize
* modify
Showing 3 changed files with 148 additions and 64 deletions (+148 -64):

- paddle/phi/kernels/impl/einsum_grad_impl.h (+4 -2)
- paddle/phi/kernels/impl/einsum_impl.h (+129 -61)
- python/paddle/fluid/tests/unittests/test_einsum_v2.py (+15 -1)
paddle/phi/kernels/impl/einsum_grad_impl.h

@@ -148,14 +148,16 @@ void EinsumGradKernel(const Context& dev_ctx,
     right = splits[1].substr(1);
-    auto equation_for_A = right + "," + ops[1] + "->" + gather_labels_except_reduction(ops[0]);
+    auto equation_for_A = ops[1] + "," + right + "->" + gather_labels_except_reduction(ops[0]);
     auto equation_for_B = right + "," + ops[0] + "->" + gather_labels_except_reduction(ops[1]);
     auto operands_for_A = std::vector<const DenseTensor*>();
     auto operands_for_B = std::vector<const DenseTensor*>();
     DenseTensor dA, dB;
-    operands_for_A.push_back(&out_grad);
+    // dA = einsum(B, dC)
     operands_for_A.push_back(x[1]);
+    operands_for_A.push_back(&out_grad);
+    // dB = einsum(dC, A)
     operands_for_B.push_back(&out_grad);
     operands_for_B.push_back(x[0]);
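The two inline comments make the new operand order explicit: with equation_for_A now built as ops[1] + "," + right + "->" + ..., the gradient of A is computed as einsum(B, dC), and the gradient of B as einsum(dC, A). A minimal NumPy sketch of the identity being relied on, using the illustrative equation "ij,jk->ik" (the equation and variable names are not from the patch):

    import numpy as np

    # Forward: C = einsum("ij,jk->ik", A, B)
    A, B = np.random.rand(2, 3), np.random.rand(3, 4)
    dC = np.random.rand(2, 4)  # incoming gradient w.r.t. C

    # equation_for_A = ops[1] + "," + right + "->" + labels_of_A  ->  "jk,ik->ij"
    # equation_for_B = right + "," + ops[0] + "->" + labels_of_B  ->  "ik,ij->jk"
    dA = np.einsum("jk,ik->ij", B, dC)  # dA = einsum(B, dC)
    dB = np.einsum("ik,ij->jk", dC, A)  # dB = einsum(dC, A)

    # Cross-check against the closed-form matmul gradients.
    assert np.allclose(dA, dC @ B.T)
    assert np.allclose(dB, A.T @ dC)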
paddle/phi/kernels/impl/einsum_impl.h

@@ -13,6 +13,7 @@
 // limitations under the License.
 #pragma once
+#include <set>
 #include "paddle/phi/core/dense_tensor.h"
 #include "paddle/phi/kernels/matmul_kernel.h"
 #include "paddle/phi/kernels/reduce_sum_kernel.h"

@@ -55,7 +56,8 @@ inline static void ValidationCheck(const std::string& equation) {
 enum LabelType {
   ALL_TYPE = 0,
   Batch = 1,    // ABO
-  Free,         // AO, BO
+  AO,           // AO -- free label
+  BO,           // BO -- free label
   Contraction,  // AB
   Reduction,    // A, B
 };
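The patch splits the old Free class in two: AO marks free labels that appear only in operand A and the output, BO those that appear only in operand B and the output. A rough Python classifier for the two-operand case, matching the comments above (an illustrative sketch, not part of the patch):

    def classify_labels(op_a, op_b, output):
        """Sketch of the LabelType taxonomy for a two-operand einsum."""
        types = {}
        for c in set(op_a) | set(op_b):
            in_a, in_b, in_out = c in op_a, c in op_b, c in output
            if in_a and in_b:
                types[c] = "Batch" if in_out else "Contraction"  # ABO vs AB
            elif in_out:
                types[c] = "AO" if in_a else "BO"  # free label
            else:
                types[c] = "Reduction"  # appears in A or B only
        return types

    # 'bij,bjk->bik' -> {'b': 'Batch', 'i': 'AO', 'k': 'BO', 'j': 'Contraction'}
    print(classify_labels("bij", "bjk", "bik"))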
@@ -125,18 +127,32 @@ inline std::vector<char> union_labels(const std::vector<char>& a,
   return res;
 }
 
+// Apply transforms to all_labels and get another all_labels
+inline std::vector<char> TransformLabelsOrder(
+    const std::vector<char>& all_labels,
+    const LabelMap& type,
+    std::vector<LabelType> new_order) {
+  std::vector<char> ret;
+  for (auto cnt_type : new_order) {
+    std::vector<char> tmp;
+    for (int c : all_labels) {
+      if (type[c] == cnt_type) tmp.push_back(c);
+      std::sort(tmp.begin(), tmp.end());
+    }
+    ret.insert(ret.end(), tmp.begin(), tmp.end());
+  }
+  return ret;
+}
+
 inline static void GlobalInfo(const std::vector<std::string>& op_labels,
                               const std::string& right,
                               LabelMap* label2type,
                               std::vector<char>* sorted_labels) {
-  // sorted_labels: ['.', <right>, <left only label>]
+  VLOG(5) << "GlobalInfo: "
+          << paddle::string::join_strings(*sorted_labels, ",");
   std::vector<char> all;
   LabelMap counter(0);
   for (auto& ch : right) {  // char
     int c = ch;
-    (*label2type)[c] = LabelType::Free;
+    (*label2type)[c] = LabelType::BO;
   }
 
   for (auto& op : op_labels) {
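TransformLabelsOrder regroups all_labels by label type, in the order given by new_order, sorting the labels within each group. The same ordering in Python (illustrative sketch):

    def transform_labels_order(all_labels, label2type, new_order):
        """Emit labels grouped by type in new_order, sorted inside each group."""
        ret = []
        for t in new_order:
            ret.extend(sorted(c for c in all_labels if label2type[c] == t))
        return ret

    # With the 'bij,bjk->bik' classification from above:
    types = {"b": "Batch", "i": "AO", "k": "BO", "j": "Contraction"}
    order = ["Batch", "AO", "BO", "Contraction", "Reduction"]
    print(transform_labels_order("bijk", types, order))  # ['b', 'i', 'k', 'j']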
@@ -146,39 +162,36 @@ inline static void GlobalInfo(const std::vector<std::string>& op_labels,
         all.push_back(ch);
       }
       counter[c] += 1;
-      if ((*label2type)[c] != LabelType::Free && counter[c] == 2)
+      if ((*label2type)[c] != LabelType::BO && counter[c] == 2)
         (*label2type)[c] = LabelType::Contraction;
       else if (counter[c] == 2)
         (*label2type)[c] = LabelType::Batch;
     }
   }
+  // BO is represent Free, so we need find the AO.
+  for (int c : op_labels[0]) {
+    if ((*label2type)[c] == LabelType::BO) (*label2type)[c] = LabelType::AO;
+  }
   (*label2type)['.'] = LabelType::Batch;
-  std::for_each(all.begin(), all.end(), [sorted_labels, label2type](int c) {
-    if ((*label2type)[c] == LabelType::Batch)
-      sorted_labels->push_back(static_cast<char>(c));
-  });
-  std::for_each(all.begin(), all.end(), [sorted_labels, label2type](int c) {
-    if ((*label2type)[c] == LabelType::Free)
-      sorted_labels->push_back(static_cast<char>(c));
-  });
-  std::for_each(all.begin(), all.end(), [sorted_labels, label2type](int c) {
-    if ((*label2type)[c] == LabelType::Contraction)
-      sorted_labels->push_back(static_cast<char>(c));
-  });
-  std::for_each(all.begin(), all.end(), [&sorted_labels, label2type](int c) {
-    if ((*label2type)[c] == LabelType::Reduction)
-      sorted_labels->push_back(static_cast<char>(c));
-  });
+  VLOG(5) << "GlobalInfo: sorted_labels before: "
+          << paddle::string::join_strings(*sorted_labels, ",");
+  *sorted_labels = TransformLabelsOrder(all,
+                                        *label2type,
+                                        {LabelType::Batch,
+                                         LabelType::AO,
+                                         LabelType::BO,
+                                         LabelType::Contraction,
+                                         LabelType::Reduction});
   if (counter[static_cast<int>('.')] > 0) {
     std::vector<char> tmp;
     tmp.push_back('.');
+    // push '.' in the front
     *sorted_labels = union_labels(tmp, *sorted_labels);
-    VLOG(5) << "GlobalInfo: sorted_labels after: "
-            << paddle::string::join_strings(*sorted_labels, ",");
   }
+  VLOG(5) << "GlobalInfo: sorted_labels after: "
+          << paddle::string::join_strings(*sorted_labels, ",");
 }
 
 inline static void InferLabelShape(const std::vector<std::string>& op_labels,
@@ -289,17 +302,20 @@ inline static void ParseEinsumEquation(
   *right = results[1].substr(1);
   ReplaceEllipsis(*right);
   auto op_labels = paddle::string::split_string(left, ",");
+  // split_string("i,") -> ["i"], we expect 2 op_labels.
+  if (left[left.size() - 1] == ',') op_labels.push_back("");
   std::for_each(op_labels.begin(), op_labels.end(), ReplaceEllipsis);
   GlobalInfo(op_labels, *right, labeltype, all_labels);
   InferLabelShape(op_labels, inputs, labelshape, ellipsis_dims, broadcast_dims);
-  VLOG(5) << "Einsum Infershape: right:" << right;
-  VLOG(5) << "Einsum Infershape: op_labels:"
-          << paddle::string::join_strings(op_labels, "\n");
+  VLOG(5) << "Einsum Infershape: right:" << *right;
+  VLOG(5) << "Einsum Infershape: left:"
+          << paddle::string::join_strings(op_labels, '\n');
   InferOutputDims(*right, *broadcast_dims, *labelshape, output_dims);
   for (size_t i = 0; i < inputs.size(); ++i) {
     InferLabelPerm(
         op_labels[i], ellipsis_dims->at(i).size(), &((*label2perms)[i]));
   }
   VLOG(5) << "Einsum Infershape: end";
 }
 
 template <typename T>
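The new guard works around split_string dropping a trailing empty field: for an equation such as "i,->i" (second operand 0-dimensional), the left side "i," must parse into two operand label strings, the second one empty. Python's str.split shows the intended result (sketch):

    # paddle::string::split_string("i,", ",") yields only ["i"];
    # the patch appends "" so the second operand still gets a slot.
    left = "i,"
    print(left.split(","))  # ['i', ''] -- the behaviour the C++ code now emulates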
@@ -327,10 +343,12 @@ std::vector<T> GetShapeByType(const std::vector<char>& all_labels,
                               const LabelMap& perm,
                               const LabelMap& label2shape,
                               const std::vector<int>& ellipsis,
-                              LabelType filter) {
+                              std::set<LabelType> filter) {
   std::vector<T> res;
   for (T c : all_labels) {
-    if ((filter == LabelType::ALL_TYPE || type[c] == filter) && perm[c] != -1) {
+    if ((filter.count(LabelType::ALL_TYPE) ||
+         filter.count(LabelType(type[c]))) &&
+        perm[c] != -1) {
       if (c == '.')
         res.insert(res.end(), ellipsis.begin(), ellipsis.end());
       else
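With filter now a std::set<LabelType>, a single call can collect several label classes at once; below, the free dimensions of an operand are gathered with {LabelType::AO, LabelType::BO}. The same idea in Python (sketch; the perm[c] != -1 check is omitted):

    def shape_by_type(all_labels, label2type, label2shape, filters):
        """Collect the shapes of all labels whose type is in `filters`."""
        return [label2shape[c] for c in all_labels if label2type[c] in filters]

    types = {"b": "Batch", "i": "AO", "k": "BO", "j": "Contraction"}
    shapes = {"b": 5, "i": 2, "k": 4, "j": 3}
    print(shape_by_type("bikj", types, shapes, {"AO", "BO"}))  # [2, 4]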
@@ -390,7 +408,8 @@ DenseTensor PerformContraction(
     const LabelMap& label2type,
     const LabelMap& label2shape,
     const std::vector<std::vector<int>>& ellipsis_dims,
-    const std::vector<int>& broadcast_dims) {
+    const std::vector<int>& broadcast_dims,
+    std::vector<DenseTensor*> cache) {
   // Get All the Batches, so perm is
   auto all_valid = LabelMap(1);
   auto recover_dim = GetShapeByType<int>(all_labels,
@@ -398,36 +417,74 @@ DenseTensor PerformContraction(
                                          all_valid,
                                          label2shape,
                                          broadcast_dims,
-                                         LabelType::Batch);
+                                         {LabelType::Batch});
   auto preprocess = [&](const DenseTensor& t,
                         const LabelMap& perm,
-                        const std::vector<int>& ellipsis) -> DenseTensor {
-    auto frees = GetShapeByType<int>(all_labels, label2type, perm, label2shape,
-                                     ellipsis, LabelType::Free);
+                        const std::vector<int>& ellipsis,
+                        int operand_idx) -> DenseTensor {
+    // reshape
+    auto frees = GetShapeByType<int>(all_labels, label2type, perm, label2shape,
+                                     ellipsis, {LabelType::AO, LabelType::BO});
     auto conts = GetShapeByType<int>(all_labels,
                                      label2type,
                                      perm,
                                      label2shape,
                                      ellipsis,
-                                     LabelType::Contraction);
-    auto trans_t = PerformTranspose<T, Context>(
-        dev_ctx, t, perm, all_labels, ellipsis, label2type);
-    auto mul_dims = GetShapeByType<int>(all_labels, label2type, perm,
-                                        label2shape, ellipsis, LabelType::Batch);
+                                     {LabelType::Contraction});
+    std::vector<char> reordered_all_labels = all_labels;
+    if (operand_idx == 1) {
+      reordered_all_labels = TransformLabelsOrder(all_labels,
+                                                  label2type,
+                                                  {LabelType::Batch,
+                                                   LabelType::Contraction,
+                                                   LabelType::AO,
+                                                   LabelType::BO,
+                                                   LabelType::Reduction});
+    }
+    // reduction
+    DenseTensor trans_t;
+    if (cache[operand_idx]->IsInitialized()) {
+      trans_t.ShareBufferWith(*(cache[operand_idx]));
+    } else {
+      auto reduct_t = PerformReduction<T, Context>(
+          dev_ctx, t, perm, all_labels, ellipsis, label2type);
+      trans_t = PerformTranspose<T, Context>(
+          dev_ctx, reduct_t, perm, reordered_all_labels, ellipsis, label2type);
+      cache[operand_idx]->ShareBufferWith(trans_t);
+    }
+    auto mul_dims = GetShapeByType<int>(all_labels, label2type, perm,
+                                        label2shape, ellipsis,
+                                        {LabelType::Batch});
     recover_dim.insert(recover_dim.end(), frees.begin(), frees.end());
-    mul_dims.push_back(
-        std::accumulate(frees.begin(), frees.end(), 1, std::multiplies<int>()));
-    mul_dims.push_back(
-        std::accumulate(conts.begin(), conts.end(), 1, std::multiplies<int>()));
+    if (operand_idx == 0) {
+      mul_dims.push_back(std::accumulate(
+          frees.begin(), frees.end(), 1, std::multiplies<int>()));
+      mul_dims.push_back(std::accumulate(
+          conts.begin(), conts.end(), 1, std::multiplies<int>()));
+    } else {
+      mul_dims.push_back(std::accumulate(
+          conts.begin(), conts.end(), 1, std::multiplies<int>()));
+      mul_dims.push_back(std::accumulate(
+          frees.begin(), frees.end(), 1, std::multiplies<int>()));
+    }
+    VLOG(5) << "PerformContraction: mul_dims: "
+            << paddle::string::join_strings(mul_dims, ",");
     trans_t.Resize(make_ddim(mul_dims));
     return trans_t;
   };
-  auto trans_a = preprocess(A, label2perm[0], ellipsis_dims[0]);
-  auto trans_b = preprocess(B, label2perm[1], ellipsis_dims[1]);
+  // Reduction, Reshape and Matmul
+  auto trans_a = preprocess(A, label2perm[0], ellipsis_dims[0], 0);
+  auto trans_b = preprocess(B, label2perm[1], ellipsis_dims[1], 1);
   auto after_contraction =
-      Matmul<T, Context>(dev_ctx, trans_a, trans_b, false, true);
+      Matmul<T, Context>(dev_ctx, trans_a, trans_b, false, false);
+  VLOG(5) << "PerformContraction: recover_dim: "
+          << paddle::string::join_strings(recover_dim, ",");
   after_contraction.Resize(make_ddim(recover_dim));
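After preprocess, operand A is flattened to [batch..., prod(free dims of A), prod(contraction dims)], while operand B, whose labels were reordered to put Contraction ahead of its free labels, becomes [batch..., prod(contraction dims), prod(free dims of B)]. That is why the final Matmul can now pass trans_y = false instead of true. The resulting shapes for the illustrative equation "bij,bjk->bik" in NumPy (sketch):

    import numpy as np

    b, i, j, k = 5, 2, 3, 4
    A = np.random.rand(b, i, j)  # [Batch, AO free, Contraction]
    B = np.random.rand(b, j, k)  # [Batch, Contraction, BO free]

    # preprocess(A, ..., 0) -> (b, i, j); preprocess(B, ..., 1) -> (b, j, k)
    out = np.matmul(A, B)        # (b, i, k): no transpose of B required

    assert np.allclose(out, np.einsum("bij,bjk->bik", A, B))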
@@ -465,10 +522,11 @@ void TransposeToOutput(const Context& dev_ctx,
 }
 
 template <typename T, typename Context>
-void EinsumKernel(const Context& dev_ctx,
-                  const std::vector<const DenseTensor*>& inputs,
-                  const std::string& equation,
-                  DenseTensor* out) {
+void EinsumKernelImpl(const Context& dev_ctx,
+                      const std::vector<const DenseTensor*>& inputs,
+                      const std::string& equation,
+                      DenseTensor* out,
+                      std::vector<DenseTensor*> cache) {
   ValidationCheck(equation);
   // collect the following informations to prepare einsum.
   LabelMap labelshape(0);
@@ -498,22 +556,18 @@ void EinsumKernel(const Context& dev_ctx,
   if (inputs.size() == 2) {
     auto& A = inputs[0];
     auto& B = inputs[1];
-    // Reduce Procedure
-    auto reduce_A = PerformReduction<T, Context>(
-        dev_ctx, *A, label2perms[0], all_labels, ellipsis_dims[0], labeltype);
-    auto reduce_B = PerformReduction<T, Context>(
-        dev_ctx, *B, label2perms[1], all_labels, ellipsis_dims[1], labeltype);
-    // Contract Procedure
+    // Reduction and Contract Procedure
     dev_ctx.template Alloc<T>(out);
     auto after_contraction = PerformContraction<T, Context>(dev_ctx,
-                                                            reduce_A,
-                                                            reduce_B,
+                                                            *A,
+                                                            *B,
                                                             label2perms,
                                                             all_labels,
                                                             labeltype,
                                                             labelshape,
                                                             ellipsis_dims,
-                                                            broadcast_dims);
+                                                            broadcast_dims,
+                                                            cache);
     TransposeToOutput<T, Context>(dev_ctx, after_contraction, right,
@@ -545,4 +599,18 @@ void EinsumKernel(const Context& dev_ctx,
   }
 }
 
+template <typename T, typename Context>
+void EinsumKernel(const Context& dev_ctx,
+                  const std::vector<const DenseTensor*>& inputs,
+                  const std::string& equation,
+                  DenseTensor* out) {
+  std::vector<DenseTensor> cache(inputs.size());  // set empty; TA, TB, TdC
+  std::vector<DenseTensor*> cache_tensor(
+      inputs.size());  // set empty; TA, TB, TdC
+  for (size_t i = 0; i < inputs.size(); ++i) {
+    cache_tensor[i] = &cache[i];
+  }
+  EinsumKernelImpl<T, Context>(dev_ctx, inputs, equation, out, cache_tensor);
+}
+
 }  // namespace phi
python/paddle/fluid/tests/unittests/test_einsum_v2.py

@@ -464,5 +464,19 @@ class TestNumpyTests(unittest.TestCase):
         self.check_output_equal(a, e)
 
 
+class TestStaticGraphShape(unittest.TestCase):
+    def setUp(self):
+        paddle.enable_static()
+
+    def tearDown(self):
+        paddle.disable_static()
+
+    def test_shape(self):
+        A = paddle.static.data(name='x', shape=[-1])
+        B = paddle.static.data(name='y', shape=[384])
+        C = paddle.einsum('i,d->id', A, B)
+        self.assertEqual(C.shape, (-1, 384))
+
+
 if __name__ == "__main__":
     unittest.main()
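The new test pins down static-graph shape inference when the leading dimension is unknown (-1). For comparison, the dynamic-graph call that exercises the same two-operand kernel path (a sketch; requires a Paddle build containing this patch):

    import paddle

    x = paddle.rand([8])
    y = paddle.rand([384])
    z = paddle.einsum('i,d->id', x, y)  # outer product via the 2-operand path
    print(z.shape)  # [8, 384]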