Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
39a9abaa
P
Paddle
项目概览
PaddlePaddle
/
Paddle
大约 1 年 前同步成功
通知
2299
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
39a9abaa
编写于
3月 07, 2023
作者:
Z
zhupengyang
提交者:
GitHub
3月 07, 2023
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
[XPU] support shared weight; delete isolated node (#51108)
上级
50ad760c
变更
14
显示空白变更内容
内联
并排
Showing
14 changed file
with
863 addition
and
223 deletion
+863
-223
paddle/fluid/framework/ir/CMakeLists.txt
paddle/fluid/framework/ir/CMakeLists.txt
+10
-1
paddle/fluid/framework/ir/delete_op_device_pass.cc
paddle/fluid/framework/ir/delete_op_device_pass.cc
+4
-4
paddle/fluid/framework/ir/delete_op_device_pass_test.cc
paddle/fluid/framework/ir/delete_op_device_pass_test.cc
+1
-2
paddle/fluid/framework/ir/pass.cc
paddle/fluid/framework/ir/pass.cc
+1
-0
paddle/fluid/framework/ir/xpu/delete_isolated_node_pass.cc
paddle/fluid/framework/ir/xpu/delete_isolated_node_pass.cc
+203
-0
paddle/fluid/framework/ir/xpu/delete_isolated_node_pass_test.cc
.../fluid/framework/ir/xpu/delete_isolated_node_pass_test.cc
+181
-0
paddle/fluid/framework/ir/xpu/fc_xpu_fuse_pass.cc
paddle/fluid/framework/ir/xpu/fc_xpu_fuse_pass.cc
+30
-45
paddle/fluid/framework/ir/xpu/multi_encoder_xpu_fuse_pass.cc
paddle/fluid/framework/ir/xpu/multi_encoder_xpu_fuse_pass.cc
+231
-160
paddle/fluid/framework/ir/xpu/pass_utils.cc
paddle/fluid/framework/ir/xpu/pass_utils.cc
+167
-0
paddle/fluid/framework/ir/xpu/pass_utils.h
paddle/fluid/framework/ir/xpu/pass_utils.h
+21
-0
paddle/fluid/framework/ir/xpu/quant_utils.cc
paddle/fluid/framework/ir/xpu/quant_utils.cc
+6
-6
paddle/fluid/framework/ir/xpu/quant_utils.h
paddle/fluid/framework/ir/xpu/quant_utils.h
+3
-3
paddle/fluid/inference/api/paddle_pass_builder.cc
paddle/fluid/inference/api/paddle_pass_builder.cc
+1
-0
python/paddle/fluid/tests/unittests/ir/inference/test_xpu_link_xpu_op_max_pass.py
...s/unittests/ir/inference/test_xpu_link_xpu_op_max_pass.py
+4
-2
未找到文件。
paddle/fluid/framework/ir/CMakeLists.txt
浏览文件 @
39a9abaa
...
@@ -220,7 +220,7 @@ if(WITH_XPU)
...
@@ -220,7 +220,7 @@ if(WITH_XPU)
cc_library
(
cc_library
(
xpu_pass_utils
xpu_pass_utils
SRCS xpu/pass_utils.cc
SRCS xpu/pass_utils.cc
DEPS pass
)
DEPS pass
xpu_quant_utils
)
set
(
XPU_PASS_DEPS xpu_quant_utils xpu_pass_utils
)
set
(
XPU_PASS_DEPS xpu_quant_utils xpu_pass_utils
)
pass_library
(
embedding_with_eltwise_add_xpu_fuse_pass inference DIR xpu DEPS
pass_library
(
embedding_with_eltwise_add_xpu_fuse_pass inference DIR xpu DEPS
${
XPU_PASS_DEPS
}
)
${
XPU_PASS_DEPS
}
)
...
@@ -232,6 +232,8 @@ if(WITH_XPU)
...
@@ -232,6 +232,8 @@ if(WITH_XPU)
pass_library
(
generate_sequence_xpu_fuse_pass inference DIR xpu DEPS
pass_library
(
generate_sequence_xpu_fuse_pass inference DIR xpu DEPS
${
XPU_PASS_DEPS
}
)
${
XPU_PASS_DEPS
}
)
pass_library
(
link_xpu_op_max_pass inference DIR xpu DEPS
${
XPU_PASS_DEPS
}
)
pass_library
(
link_xpu_op_max_pass inference DIR xpu DEPS
${
XPU_PASS_DEPS
}
)
pass_library
(
delete_isolated_node_pass inference DIR xpu DEPS
${
XPU_PASS_DEPS
}
)
endif
()
endif
()
cc_library
(
cc_library
(
...
@@ -484,3 +486,10 @@ if(WITH_MKLDNN)
...
@@ -484,3 +486,10 @@ if(WITH_MKLDNN)
SRCS mkldnn/cpu_bfloat16_pass_tester.cc
SRCS mkldnn/cpu_bfloat16_pass_tester.cc
DEPS cpu_bfloat16_pass
)
DEPS cpu_bfloat16_pass
)
endif
()
endif
()
if
(
WITH_XPU
)
cc_test
(
test_delete_isolated_node_pass
SRCS xpu/delete_isolated_node_pass_test.cc
DEPS delete_isolated_node_pass
)
endif
()
paddle/fluid/framework/ir/delete_op_device_pass.cc
浏览文件 @
39a9abaa
...
@@ -39,14 +39,14 @@ class DeleteOpDevicePass : public Pass {
...
@@ -39,14 +39,14 @@ class DeleteOpDevicePass : public Pass {
void
DeleteOpDevicePass
::
ApplyImpl
(
ir
::
Graph
*
graph
)
const
{
void
DeleteOpDevicePass
::
ApplyImpl
(
ir
::
Graph
*
graph
)
const
{
PADDLE_ENFORCE_NOT_NULL
(
PADDLE_ENFORCE_NOT_NULL
(
graph
,
platform
::
errors
::
PreconditionNotMet
(
"graph should not be null."
));
graph
,
platform
::
errors
::
PreconditionNotMet
(
"graph should not be null."
));
int
found_subgraph_count
=
0
;
int
delete_counts
=
0
;
for
(
auto
*
node
:
graph
->
Nodes
())
{
for
(
auto
*
node
:
graph
->
Nodes
())
{
if
(
!
node
->
IsOp
()
||
!
node
->
Op
()
->
HasAttr
(
"op_device"
))
continue
;
if
(
!
node
->
IsOp
()
||
!
node
->
Op
()
->
HasAttr
(
"op_device"
))
continue
;
node
->
Op
()
->
RemoveAttr
(
"op_device"
);
node
->
Op
()
->
RemoveAttr
(
"op_device"
);
found_subgraph_count
++
;
delete_counts
++
;
}
}
if
(
found_subgraph_count
>
0
)
{
if
(
delete_counts
>
0
)
{
LOG
(
INFO
)
<<
"--- de
tected "
<<
found_subgraph_count
<<
" subgraphs
"
;
LOG
(
INFO
)
<<
"--- de
lete "
<<
delete_counts
<<
" op_device attr
"
;
}
}
}
}
...
...
paddle/fluid/framework/ir/delete_op_device_pass_test.cc
浏览文件 @
39a9abaa
...
@@ -13,8 +13,7 @@
...
@@ -13,8 +13,7 @@
// limitations under the License.
// limitations under the License.
#include <gtest/gtest.h>
#include <gtest/gtest.h>
#include "paddle/fluid/framework/ir/pass.h"
#include "paddle/fluid/framework/ir/delete_dropout_op_pass.h"
#include "paddle/fluid/framework/ir/pass_tester_helper.h"
#include "paddle/fluid/framework/ir/pass_tester_helper.h"
namespace
paddle
{
namespace
paddle
{
...
...
paddle/fluid/framework/ir/pass.cc
浏览文件 @
39a9abaa
...
@@ -49,6 +49,7 @@ static const std::vector<std::string> support_subgraph_passes = {
...
@@ -49,6 +49,7 @@ static const std::vector<std::string> support_subgraph_passes = {
"fuse_multi_transformer_layer_pass"
,
"fuse_multi_transformer_layer_pass"
,
"delete_quant_dequant_linear_op_pass"
,
"delete_quant_dequant_linear_op_pass"
,
"delete_weight_dequant_linear_op_pass"
,
"delete_weight_dequant_linear_op_pass"
,
"fc_xpu_fuse_pass"
,
"delete_op_device_pass"
};
"delete_op_device_pass"
};
Graph
*
Pass
::
Apply
(
Graph
*
graph
)
const
{
Graph
*
Pass
::
Apply
(
Graph
*
graph
)
const
{
...
...
paddle/fluid/framework/ir/xpu/delete_isolated_node_pass.cc
0 → 100644
浏览文件 @
39a9abaa
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <string>
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
#include "paddle/fluid/framework/ir/pass.h"
#include "paddle/fluid/framework/ir/xpu/pass_utils.h"
#include "paddle/fluid/framework/scope.h"
namespace
phi
{
class
DenseTensor
;
}
// namespace phi
namespace
paddle
{
namespace
framework
{
class
Scope
;
}
// namespace framework
}
// namespace paddle
namespace
paddle
{
namespace
framework
{
namespace
ir
{
class
DeleteIsolatedNodePass
:
public
Pass
{
protected:
void
ApplyImpl
(
Graph
*
graph
)
const
override
;
private:
void
CollectReservedPersistableNodeNames
(
Graph
*
graph
,
std
::
unordered_set
<
std
::
string
>*
reserved_persistable_node_names
)
const
;
int
RemoveIsolatedNodes
(
Graph
*
graph
,
const
std
::
unordered_set
<
std
::
string
>&
reserved_persistable_node_names
,
std
::
unordered_set
<
std
::
string
>*
delete_node_names
)
const
;
int
UpdateControlFlowOp
(
Graph
*
graph
,
const
std
::
map
<
int
,
Graph
*>&
block_id_graph_map
,
const
std
::
unordered_set
<
std
::
string
>&
delete_node_names
)
const
;
const
std
::
map
<
std
::
string
,
std
::
string
>
control_flow_op_input_map_
{
{
"while"
,
"X"
},
{
"conditional_block"
,
"Input"
},
};
};
void
DeleteIsolatedNodePass
::
ApplyImpl
(
Graph
*
graph
)
const
{
PADDLE_ENFORCE_NOT_NULL
(
graph
,
platform
::
errors
::
PreconditionNotMet
(
"graph should not be null."
));
PADDLE_ENFORCE
(
graph
->
IsMainGraph
(),
platform
::
errors
::
PreconditionNotMet
(
"Pass(apply in main graph) will delete isolated nodes in "
"all subgraphs. Do not apply pass in subgraph."
));
std
::
unordered_set
<
std
::
string
>
reserved_persistable_node_names
;
for
(
size_t
i
=
0
;
i
<
graph
->
SubGraphsSize
();
i
++
)
{
CollectReservedPersistableNodeNames
(
graph
->
GetSubGraph
(
i
),
&
reserved_persistable_node_names
);
}
int
delete_counts
=
0
;
std
::
unordered_set
<
std
::
string
>
delete_node_names
;
for
(
size_t
i
=
0
;
i
<
graph
->
SubGraphsSize
();
i
++
)
{
delete_counts
+=
RemoveIsolatedNodes
(
graph
->
GetSubGraph
(
i
),
reserved_persistable_node_names
,
&
delete_node_names
);
}
if
(
delete_counts
>
0
)
{
LOG
(
INFO
)
<<
"--- delete "
<<
delete_counts
<<
" isolated nodes"
;
}
std
::
map
<
int
,
Graph
*>
block_id_graph_map
;
for
(
size_t
i
=
0
;
i
<
graph
->
SubGraphsSize
();
i
++
)
{
auto
*
sub_graph
=
graph
->
GetSubGraph
(
i
);
for
(
auto
*
node
:
sub_graph
->
Nodes
())
{
if
(
node
->
IsVar
())
{
block_id_graph_map
[
node
->
GetVarNodeBlockId
()]
=
sub_graph
;
break
;
}
}
}
int
update_counts
=
0
;
for
(
size_t
i
=
0
;
i
<
graph
->
SubGraphsSize
();
i
++
)
{
update_counts
+=
UpdateControlFlowOp
(
graph
->
GetSubGraph
(
i
),
block_id_graph_map
,
delete_node_names
);
}
if
(
update_counts
>
0
)
{
LOG
(
INFO
)
<<
"--- update "
<<
update_counts
<<
" control flow ops"
;
}
}
void
DeleteIsolatedNodePass
::
CollectReservedPersistableNodeNames
(
Graph
*
graph
,
std
::
unordered_set
<
std
::
string
>*
reserved_persistable_node_names
)
const
{
for
(
auto
*
node
:
graph
->
Nodes
())
{
if
(
!
node
->
IsVar
()
||
!
node
->
Var
()
->
Persistable
())
continue
;
for
(
auto
*
out_node
:
node
->
outputs
)
{
auto
op_type
=
out_node
->
Op
()
->
Type
();
if
(
control_flow_op_input_map_
.
count
(
op_type
)
==
0
)
{
reserved_persistable_node_names
->
insert
(
node
->
Var
()
->
Name
());
break
;
}
}
}
}
int
DeleteIsolatedNodePass
::
RemoveIsolatedNodes
(
Graph
*
graph
,
const
std
::
unordered_set
<
std
::
string
>&
reserved_persistable_node_names
,
std
::
unordered_set
<
std
::
string
>*
delete_node_names
)
const
{
BlockDesc
*
block
=
nullptr
;
for
(
auto
*
node
:
graph
->
Nodes
())
{
if
(
node
->
IsOp
())
{
block
=
node
->
Op
()
->
Block
();
}
}
Scope
&
scope
=
graph
->
Get
<
framework
::
Scope
>
(
"__param_scope__"
);
// If graph has nodes to delete:
// 1. Clear var_desc in block
// 2. Clear tensor in variable
// 3. Clear variable in scope
int
delete_node_counts
=
0
;
std
::
unordered_set
<
const
Node
*>
delete_nodes
;
const
std
::
unordered_set
<
ir
::
Node
*>
nodes
=
graph
->
Nodes
();
for
(
auto
*
node
:
nodes
)
{
if
(
!
node
->
IsVar
()
||
!
node
->
Var
()
->
Persistable
())
continue
;
auto
name
=
node
->
Var
()
->
Name
();
if
(
reserved_persistable_node_names
.
count
(
name
)
>
0
)
continue
;
delete_nodes
.
insert
(
node
);
delete_node_names
->
insert
(
node
->
Name
());
block
->
RemoveVar
(
name
);
auto
*
var
=
scope
.
FindVar
(
name
);
if
(
var
!=
nullptr
)
{
var
->
Clear
();
scope
.
EraseVars
({
name
});
}
delete_node_counts
++
;
}
GraphSafeRemoveNodes
(
graph
,
delete_nodes
);
return
delete_node_counts
;
}
int
DeleteIsolatedNodePass
::
UpdateControlFlowOp
(
Graph
*
graph
,
const
std
::
map
<
int
,
Graph
*>&
block_id_graph_map
,
const
std
::
unordered_set
<
std
::
string
>&
delete_node_names
)
const
{
int
update_counts
=
0
;
for
(
auto
*
node
:
graph
->
Nodes
())
{
if
(
!
node
->
IsOp
())
continue
;
auto
op_type
=
node
->
Op
()
->
Type
();
if
(
control_flow_op_input_map_
.
count
(
op_type
)
==
0
)
continue
;
auto
in_arg_name
=
control_flow_op_input_map_
.
at
(
op_type
);
auto
in_name
=
node
->
Op
()
->
Input
(
in_arg_name
);
std
::
unordered_set
<
std
::
string
>
in_names_set
(
in_name
.
begin
(),
in_name
.
end
());
for
(
auto
delete_node_name
:
delete_node_names
)
{
if
(
in_names_set
.
count
(
delete_node_name
)
>
0
)
{
in_names_set
.
erase
(
delete_node_name
);
}
}
auto
*
sub_block
=
PADDLE_GET_CONST
(
framework
::
BlockDesc
*
,
node
->
Op
()
->
GetAttr
(
"sub_block"
));
auto
*
sub_graph
=
block_id_graph_map
.
at
(
sub_block
->
ID
());
std
::
unordered_set
<
std
::
string
>
sub_persistable_node_names
;
CollectReservedPersistableNodeNames
(
sub_graph
,
&
sub_persistable_node_names
);
for
(
auto
sub_name
:
sub_persistable_node_names
)
{
if
(
in_names_set
.
count
(
sub_name
)
>
0
)
continue
;
auto
*
in_node
=
FindNodeWithName
(
graph
,
sub_name
);
if
(
in_node
==
nullptr
)
continue
;
in_names_set
.
insert
(
sub_name
);
IR_NODE_LINK_TO
(
in_node
,
node
);
}
std
::
vector
<
std
::
string
>
new_in_names
(
in_names_set
.
begin
(),
in_names_set
.
end
());
node
->
Op
()
->
SetInput
(
in_arg_name
,
new_in_names
);
update_counts
++
;
}
return
update_counts
;
}
}
// namespace ir
}
// namespace framework
}
// namespace paddle
REGISTER_PASS
(
delete_isolated_node_pass
,
paddle
::
framework
::
ir
::
DeleteIsolatedNodePass
);
paddle/fluid/framework/ir/xpu/delete_isolated_node_pass_test.cc
0 → 100644
浏览文件 @
39a9abaa
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <gtest/gtest.h>
#include "paddle/fluid/framework/ir/pass.h"
#include "paddle/fluid/framework/ir/pass_tester_helper.h"
namespace
paddle
{
namespace
framework
{
namespace
ir
{
VarDesc
*
Data
(
paddle
::
framework
::
BlockDesc
*
block
,
std
::
string
name
,
std
::
vector
<
int64_t
>
shape
=
{},
bool
is_persistable
=
false
,
proto
::
VarType
::
Type
data_type
=
proto
::
VarType
::
FP32
)
{
auto
*
var
=
block
->
Var
(
name
);
var
->
SetType
(
proto
::
VarType
::
LOD_TENSOR
);
var
->
SetDataType
(
data_type
);
var
->
SetShape
(
shape
);
var
->
SetPersistable
(
is_persistable
);
return
var
;
}
void
AddVarToScope
(
Scope
*
param_scope
,
const
std
::
string
&
name
,
const
DDim
&
dims
)
{
auto
*
tensor
=
param_scope
->
Var
(
name
)
->
GetMutable
<
phi
::
DenseTensor
>
();
tensor
->
Resize
(
dims
);
auto
*
cpu_ctx
=
static_cast
<
phi
::
CPUContext
*>
(
platform
::
DeviceContextPool
::
Instance
().
Get
(
phi
::
CPUPlace
()));
auto
*
data
=
cpu_ctx
->
Alloc
<
float
>
(
tensor
);
int64_t
numel
=
tensor
->
numel
();
for
(
int64_t
i
=
0
;
i
<
numel
;
++
i
)
{
data
[
i
]
=
1
;
}
}
Scope
*
CreateParamScope
()
{
auto
param_scope
=
new
Scope
();
AddVarToScope
(
param_scope
,
"matmul0_w"
,
{
128
,
128
});
return
param_scope
;
}
int
WeightNodeNum
(
ir
::
Graph
*
graph
)
{
int
num
=
0
;
for
(
auto
node
:
graph
->
Nodes
())
{
if
(
node
->
IsVar
()
&&
node
->
Var
()
->
Persistable
())
{
num
++
;
}
}
return
num
;
}
int
WeightTensorNum
(
Scope
*
scope
)
{
int
num
=
0
;
auto
vars
=
scope
->
LocalVars
();
for
(
auto
*
var
:
vars
)
{
if
(
var
->
Get
<
phi
::
DenseTensor
>
().
numel
()
>
0
)
{
num
++
;
}
}
return
num
;
}
TEST
(
delete_isolated_node_pass
,
basic
)
{
paddle
::
framework
::
ProgramDesc
program
;
auto
*
block0
=
program
.
MutableBlock
(
0
);
auto
*
block1
=
program
.
AppendBlock
(
*
block0
);
auto
*
matmul0_x
=
Data
(
block0
,
"matmul0_x"
,
{
1
,
128
});
auto
*
matmul0_w
=
Data
(
block0
,
"matmul0_w"
,
{
128
,
128
},
true
);
auto
*
matmul0_out
=
Data
(
block0
,
"matmul0_out"
,
{
1
,
128
});
OpDesc
*
matmul_op
=
block0
->
AppendOp
();
matmul_op
->
SetType
(
"matmul_v2"
);
matmul_op
->
SetInput
(
"X"
,
{
matmul0_x
->
Name
()});
matmul_op
->
SetInput
(
"Y"
,
{
matmul0_w
->
Name
()});
matmul_op
->
SetAttr
(
"trans_x"
,
false
);
matmul_op
->
SetAttr
(
"trans_y"
,
false
);
matmul_op
->
SetOutput
(
"Out"
,
{
matmul0_out
->
Name
()});
auto
*
while_out
=
Data
(
block0
,
"while_out"
,
{
1
,
128
});
auto
*
while_step_scopes
=
Data
(
block0
,
"while_step_scopes"
);
auto
*
while_cond
=
Data
(
block0
,
"while_cond"
);
OpDesc
*
while_op
=
block0
->
AppendOp
();
while_op
->
SetType
(
"while"
);
while_op
->
SetInput
(
"X"
,
{
matmul0_w
->
Name
(),
matmul0_out
->
Name
()});
while_op
->
SetInput
(
"Condition"
,
{
while_cond
->
Name
()});
while_op
->
SetOutput
(
"Out"
,
{
while_out
->
Name
()});
while_op
->
SetOutput
(
"StepScopes"
,
{
while_step_scopes
->
Name
()});
while_op
->
SetAttr
(
"sub_block"
,
{
block1
});
while_op
->
SetAttr
(
"is_test"
,
true
);
auto
*
matmul1_x
=
Data
(
block1
,
matmul0_out
->
Name
(),
matmul0_out
->
GetShape
());
auto
*
matmul1_w
=
Data
(
block1
,
matmul0_w
->
Name
(),
matmul0_w
->
GetShape
(),
true
);
auto
*
matmul1_out
=
Data
(
block1
,
"matmul1_out"
,
{
1
,
128
});
OpDesc
*
matmul1_op
=
block1
->
AppendOp
();
matmul1_op
->
SetType
(
"matmul_v2"
);
matmul1_op
->
SetInput
(
"X"
,
{
matmul1_x
->
Name
()});
matmul1_op
->
SetInput
(
"Y"
,
{
matmul1_w
->
Name
()});
matmul1_op
->
SetAttr
(
"trans_x"
,
false
);
matmul1_op
->
SetAttr
(
"trans_y"
,
false
);
matmul1_op
->
SetOutput
(
"Out"
,
{
matmul1_out
->
Name
()});
std
::
unique_ptr
<
ir
::
Graph
>
graph
(
new
ir
::
Graph
(
program
));
auto
*
scope
=
CreateParamScope
();
graph
->
Set
(
"__param_scope__"
,
scope
);
auto
pass0
=
PassRegistry
::
Instance
().
Get
(
"fc_xpu_fuse_pass"
);
pass0
->
Apply
(
graph
.
get
());
pass0
->
Apply
(
graph
->
GetSubGraph
(
1
));
int
weight_node_num
=
WeightNodeNum
(
graph
.
get
())
+
WeightNodeNum
(
graph
->
GetSubGraph
(
1
));
PADDLE_ENFORCE_EQ
(
weight_node_num
,
6
,
platform
::
errors
::
PreconditionNotMet
(
"Graph should have 6 weight node after "
"fc_xpu_fuse_pass, but actually has %d."
,
weight_node_num
));
auto
pass1
=
PassRegistry
::
Instance
().
Get
(
"delete_isolated_node_pass"
);
pass1
->
Apply
(
graph
.
get
());
weight_node_num
=
WeightNodeNum
(
graph
.
get
())
+
WeightNodeNum
(
graph
->
GetSubGraph
(
1
));
PADDLE_ENFORCE_EQ
(
weight_node_num
,
4
,
platform
::
errors
::
PreconditionNotMet
(
"Graph should have 4 weight node after "
"delete_isolated_node_pass, but actually has %d."
,
weight_node_num
));
int
weight_tensor_num
=
WeightTensorNum
(
scope
);
PADDLE_ENFORCE_EQ
(
weight_tensor_num
,
2
,
platform
::
errors
::
PreconditionNotMet
(
"Scope should have 2 weight tensor after "
"delete_isolated_node_pass, but actually has %d."
,
weight_tensor_num
));
for
(
auto
*
node
:
graph
->
Nodes
())
{
if
(
node
->
IsOp
()
&&
node
->
Op
()
->
Type
()
==
"while"
)
{
auto
while_in_names
=
node
->
Op
()
->
Inputs
().
at
(
"X"
);
PADDLE_ENFORCE_EQ
(
while_in_names
.
size
(),
3
,
platform
::
errors
::
PreconditionNotMet
(
"While op should have 3 input after "
"delete_isolated_node_pass, but actually has %d."
,
while_in_names
.
size
()));
}
}
Scope
&
scope0
=
graph
->
Get
<
framework
::
Scope
>
(
"__param_scope__"
);
Scope
&
scope1
=
graph
->
GetSubGraph
(
1
)
->
Get
<
framework
::
Scope
>
(
"__param_scope__"
);
std
::
vector
<
std
::
string
>
shared_weight_names
{
matmul0_w
->
Name
()
+
"_int16"
,
matmul0_w
->
Name
()
+
"_max"
};
for
(
auto
name
:
shared_weight_names
)
{
auto
*
var0
=
scope0
.
FindVar
(
name
);
auto
*
var1
=
scope1
.
FindVar
(
name
);
PADDLE_ENFORCE
(
var0
==
var1
,
platform
::
errors
::
PreconditionNotMet
(
"Variables with the same name in two scopes is different."
));
}
}
}
// namespace ir
}
// namespace framework
}
// namespace paddle
USE_PASS
(
delete_isolated_node_pass
);
paddle/fluid/framework/ir/xpu/fc_xpu_fuse_pass.cc
浏览文件 @
39a9abaa
...
@@ -76,7 +76,6 @@ FcXPUPattern::FcXPUPattern(PDPattern* pattern,
...
@@ -76,7 +76,6 @@ FcXPUPattern::FcXPUPattern(PDPattern* pattern,
auto
*
mul_w
=
pattern
->
NewNode
(
mul_w_repr
())
auto
*
mul_w
=
pattern
->
NewNode
(
mul_w_repr
())
->
assert_is_op_input
(
mul_type_
,
"Y"
)
->
assert_is_op_input
(
mul_type_
,
"Y"
)
->
assert_is_persistable_var
()
->
assert_is_persistable_var
()
->
assert_has_n_outputs
(
1
)
->
assert_more
([](
Node
*
node
)
{
->
assert_more
([](
Node
*
node
)
{
return
node
->
Var
()
->
GetShape
().
size
()
==
2
;
return
node
->
Var
()
->
GetShape
().
size
()
==
2
;
});
});
...
@@ -169,7 +168,7 @@ class FcXPUFusePass : public FusePassBase {
...
@@ -169,7 +168,7 @@ class FcXPUFusePass : public FusePassBase {
void
ApplyImpl
(
ir
::
Graph
*
graph
)
const
override
;
void
ApplyImpl
(
ir
::
Graph
*
graph
)
const
override
;
private:
private:
void
ApplyImpl
(
ir
::
Graph
*
graph
,
int
ApplyImpl
(
ir
::
Graph
*
graph
,
const
std
::
string
&
mul_type
,
const
std
::
string
&
mul_type
,
bool
with_bias
,
bool
with_bias
,
const
std
::
string
&
act_type
)
const
;
const
std
::
string
&
act_type
)
const
;
...
@@ -181,6 +180,8 @@ void FcXPUFusePass::ApplyImpl(ir::Graph* graph) const {
...
@@ -181,6 +180,8 @@ void FcXPUFusePass::ApplyImpl(ir::Graph* graph) const {
PADDLE_ENFORCE_NOT_NULL
(
PADDLE_ENFORCE_NOT_NULL
(
graph
,
platform
::
errors
::
PreconditionNotMet
(
"graph should not be null."
));
graph
,
platform
::
errors
::
PreconditionNotMet
(
"graph should not be null."
));
Init
(
name_scope_
,
graph
);
Init
(
name_scope_
,
graph
);
int
found_subgraph_count
=
0
;
for
(
auto
mul_type
:
{
"mul"
,
"matmul"
,
"matmul_v2"
})
{
for
(
auto
mul_type
:
{
"mul"
,
"matmul"
,
"matmul_v2"
})
{
for
(
auto
with_bias
:
{
true
,
false
})
{
for
(
auto
with_bias
:
{
true
,
false
})
{
for
(
auto
act_type
:
{
for
(
auto
act_type
:
{
...
@@ -189,13 +190,14 @@ void FcXPUFusePass::ApplyImpl(ir::Graph* graph) const {
...
@@ -189,13 +190,14 @@ void FcXPUFusePass::ApplyImpl(ir::Graph* graph) const {
"tanh"
,
"tanh"
,
""
,
""
,
})
{
})
{
ApplyImpl
(
graph
,
mul_type
,
with_bias
,
act_type
);
found_subgraph_count
+=
ApplyImpl
(
graph
,
mul_type
,
with_bias
,
act_type
);
}
}
}
}
}
}
AddStatis
(
found_subgraph_count
);
}
}
void
FcXPUFusePass
::
ApplyImpl
(
ir
::
Graph
*
graph
,
int
FcXPUFusePass
::
ApplyImpl
(
ir
::
Graph
*
graph
,
const
std
::
string
&
mul_type
,
const
std
::
string
&
mul_type
,
bool
with_bias
,
bool
with_bias
,
const
std
::
string
&
act_type
)
const
{
const
std
::
string
&
act_type
)
const
{
...
@@ -219,37 +221,20 @@ void FcXPUFusePass::ApplyImpl(ir::Graph* graph,
...
@@ -219,37 +221,20 @@ void FcXPUFusePass::ApplyImpl(ir::Graph* graph,
auto
*
block
=
mul
->
Op
()
->
Block
();
auto
*
block
=
mul
->
Op
()
->
Block
();
auto
*
scope
=
param_scope
();
auto
*
scope
=
param_scope
();
auto
mul_w_name
=
mul_w
->
Name
();
auto
mul_w_tensor
=
scope
->
FindVar
(
mul_w_name
)
->
GetMutable
<
phi
::
DenseTensor
>
();
// 1. Transform weight to int16/int31
// 2. Avoid transform repeatly, because weight may be shared with other ops.
// TODO(zhupengyang): support int31
std
::
string
mul_w_max_name
=
mul_w_name
+
"_max"
;
Node
*
mul_w_max
=
nullptr
;
if
(
mul_w_tensor
->
dtype
()
!=
phi
::
DataType
::
INT16
)
{
// Create weight_max node
VarDesc
mul_w_max_desc
(
mul_w_max_name
);
mul_w_max_desc
.
SetPersistable
(
true
);
mul_w_max
=
graph
->
CreateVarNode
(
&
mul_w_max_desc
);
// Create weight_max var/tensor
auto
mul_w_max_var
=
block
->
Var
(
mul_w_max_name
);
mul_w_max_var
->
SetPersistable
(
true
);
auto
mul_w_max_tensor
=
scope
->
Var
(
mul_w_max_name
)
->
GetMutable
<
phi
::
DenseTensor
>
();
bool
transpose_w
=
false
;
bool
transpose_w
=
false
;
if
(
mul_type
==
"matmul"
)
{
if
(
mul_type
==
"matmul"
)
{
transpose_w
=
PADDLE_GET_CONST
(
bool
,
mul
->
Op
()
->
GetAttr
(
"transpose_Y"
));
transpose_w
=
PADDLE_GET_CONST
(
bool
,
mul
->
Op
()
->
GetAttr
(
"transpose_Y"
));
}
else
if
(
mul_type
==
"matmul_v2"
)
{
}
else
if
(
mul_type
==
"matmul_v2"
)
{
transpose_w
=
PADDLE_GET_CONST
(
bool
,
mul
->
Op
()
->
GetAttr
(
"trans_y"
));
transpose_w
=
PADDLE_GET_CONST
(
bool
,
mul
->
Op
()
->
GetAttr
(
"trans_y"
));
}
}
QuantWeight
<
int16_t
>
(
mul_w_tensor
,
mul_w_max_tensor
,
!
transpose_w
);
Node
*
mul_w_int16
=
nullptr
;
}
Node
*
mul_w_max
=
nullptr
;
PrepareWeight
<
int16_t
>
(
graph
,
scope
,
block
,
mul_w
,
&
mul_w_int16
,
&
mul_w_max
,
!
transpose_w
);
Node
*
bias_fp32
=
nullptr
;
if
(
bias
!=
nullptr
)
{
if
(
bias
!=
nullptr
)
{
auto
*
bias_tensor
=
PrepareBias
(
graph
,
scope
,
block
,
bias
,
&
bias_fp32
);
scope
->
Var
(
bias
->
Name
())
->
GetMutable
<
phi
::
DenseTensor
>
();
CastToFp32
(
bias_tensor
);
}
}
std
::
string
fc_out_name
;
std
::
string
fc_out_name
;
...
@@ -268,10 +253,10 @@ void FcXPUFusePass::ApplyImpl(ir::Graph* graph,
...
@@ -268,10 +253,10 @@ void FcXPUFusePass::ApplyImpl(ir::Graph* graph,
framework
::
OpDesc
fc_xpu_op_desc
(
block
);
framework
::
OpDesc
fc_xpu_op_desc
(
block
);
fc_xpu_op_desc
.
SetType
(
"fc_xpu"
);
fc_xpu_op_desc
.
SetType
(
"fc_xpu"
);
fc_xpu_op_desc
.
SetInput
(
"x"
,
{
mul_x
->
Name
()});
fc_xpu_op_desc
.
SetInput
(
"x"
,
{
mul_x
->
Name
()});
fc_xpu_op_desc
.
SetInput
(
"w"
,
{
mul_w
->
Name
()});
fc_xpu_op_desc
.
SetInput
(
"w"
,
{
mul_w
_int16
->
Name
()});
fc_xpu_op_desc
.
SetInput
(
"w_max"
,
{
mul_w_max
_name
});
fc_xpu_op_desc
.
SetInput
(
"w_max"
,
{
mul_w_max
->
Name
()
});
if
(
bias
)
{
if
(
bias
_fp32
)
{
fc_xpu_op_desc
.
SetInput
(
"bias"
,
{
bias
->
Name
()});
fc_xpu_op_desc
.
SetInput
(
"bias"
,
{
bias
_fp32
->
Name
()});
}
}
fc_xpu_op_desc
.
SetAttr
(
fc_xpu_op_desc
.
SetAttr
(
"in_num_col_dims"
,
"in_num_col_dims"
,
...
@@ -306,9 +291,9 @@ void FcXPUFusePass::ApplyImpl(ir::Graph* graph,
...
@@ -306,9 +291,9 @@ void FcXPUFusePass::ApplyImpl(ir::Graph* graph,
fc_xpu_op_desc
.
SetOutput
(
"out_max"
,
{
fc_out_max_name
});
fc_xpu_op_desc
.
SetOutput
(
"out_max"
,
{
fc_out_max_name
});
auto
*
fc_xpu
=
graph
->
CreateOpNode
(
&
fc_xpu_op_desc
);
auto
*
fc_xpu
=
graph
->
CreateOpNode
(
&
fc_xpu_op_desc
);
IR_NODE_LINK_TO
(
mul_x
,
fc_xpu
);
IR_NODE_LINK_TO
(
mul_x
,
fc_xpu
);
IR_NODE_LINK_TO
(
mul_w
,
fc_xpu
);
IR_NODE_LINK_TO
(
mul_w
_int16
,
fc_xpu
);
IR_NODE_LINK_TO
(
mul_w_max
,
fc_xpu
);
IR_NODE_LINK_TO
(
mul_w_max
,
fc_xpu
);
SAFE_IR_NODE_LINK_TO
(
bias
,
fc_xpu
);
SAFE_IR_NODE_LINK_TO
(
bias
_fp32
,
fc_xpu
);
if
(
act_out
)
{
if
(
act_out
)
{
IR_NODE_LINK_TO
(
fc_xpu
,
act_out
);
IR_NODE_LINK_TO
(
fc_xpu
,
act_out
);
}
else
if
(
add_out
)
{
}
else
if
(
add_out
)
{
...
@@ -334,7 +319,7 @@ void FcXPUFusePass::ApplyImpl(ir::Graph* graph,
...
@@ -334,7 +319,7 @@ void FcXPUFusePass::ApplyImpl(ir::Graph* graph,
};
};
gpd
(
graph
,
handler
);
gpd
(
graph
,
handler
);
AddStatis
(
found_subgraph_count
)
;
return
found_subgraph_count
;
}
}
}
// namespace ir
}
// namespace ir
...
...
paddle/fluid/framework/ir/xpu/multi_encoder_xpu_fuse_pass.cc
浏览文件 @
39a9abaa
...
@@ -625,16 +625,24 @@ class MultiEncoderXPUFusePass : public FusePassBase {
...
@@ -625,16 +625,24 @@ class MultiEncoderXPUFusePass : public FusePassBase {
// 2. Concat q_w, k_w, v_w
// 2. Concat q_w, k_w, v_w
// 3. Generate qkv_w_max tensor
// 3. Generate qkv_w_max tensor
// 4. Quant qkv_w to int16
// 4. Quant qkv_w to int16
void
PrepareQKVWeight
(
const
phi
::
DenseTensor
&
q_w
,
void
PrepareQKVWeight
(
Graph
*
graph
,
const
phi
::
DenseTensor
&
k_w
,
Scope
*
scope
,
const
phi
::
DenseTensor
&
v_w
,
BlockDesc
*
block
,
phi
::
DenseTensor
*
qkv_w
,
Node
*
q_w
,
phi
::
DenseTensor
*
qkv_w_max
)
const
;
Node
*
k_w
,
Node
*
v_w
,
void
ConcatQKVBias
(
const
phi
::
DenseTensor
&
q_bias
,
Node
**
qkv_w
,
const
phi
::
DenseTensor
&
k_bias
,
Node
**
qkv_w_max
)
const
;
const
phi
::
DenseTensor
&
v_bias
,
phi
::
DenseTensor
*
qkv_bias
)
const
;
// 1. Cast bias to fp32
// 2. Concat q/k/v bias
void
PrepareQKVBias
(
Graph
*
graph
,
Scope
*
scope
,
BlockDesc
*
block
,
Node
*
q_bias
,
Node
*
k_bias
,
Node
*
v_bias
,
Node
**
qkv_bias
)
const
;
const
std
::
string
name_scope_
{
"multi_encoder_xpu_fuse_pass"
};
const
std
::
string
name_scope_
{
"multi_encoder_xpu_fuse_pass"
};
};
};
...
@@ -685,55 +693,160 @@ void MultiEncoderXPUFusePass::ApplyImpl(ir::Graph* graph) const {
...
@@ -685,55 +693,160 @@ void MultiEncoderXPUFusePass::ApplyImpl(ir::Graph* graph) const {
AddStatis
(
cast_mask_counts
);
AddStatis
(
cast_mask_counts
);
}
}
void
MultiEncoderXPUFusePass
::
PrepareQKVWeight
(
void
MultiEncoderXPUFusePass
::
PrepareQKVWeight
(
Graph
*
graph
,
const
phi
::
DenseTensor
&
q_w
,
Scope
*
scope
,
const
phi
::
DenseTensor
&
k_w
,
BlockDesc
*
block
,
const
phi
::
DenseTensor
&
v_w
,
Node
*
q_w
,
phi
::
DenseTensor
*
qkv_w
,
Node
*
k_w
,
phi
::
DenseTensor
*
qkv_w_max
)
const
{
Node
*
v_w
,
// Transpose
Node
**
qkv_w_int16
,
phi
::
DenseTensor
q_w_t
;
Node
**
qkv_w_max
)
const
{
phi
::
DenseTensor
k_w_t
;
phi
::
DenseTensor
q_w_fp32_t
;
phi
::
DenseTensor
v_w_t
;
phi
::
DenseTensor
k_w_fp32_t
;
Assign
(
q_w
,
&
q_w_t
);
phi
::
DenseTensor
v_w_fp32_t
;
Assign
(
k_w
,
&
k_w_t
);
Assign
(
scope
->
Var
(
q_w
->
Name
())
->
Get
<
phi
::
DenseTensor
>
(),
&
q_w_fp32_t
);
Assign
(
v_w
,
&
v_w_t
);
Assign
(
scope
->
Var
(
k_w
->
Name
())
->
Get
<
phi
::
DenseTensor
>
(),
&
k_w_fp32_t
);
Transpose2D
(
&
q_w_t
);
Assign
(
scope
->
Var
(
v_w
->
Name
())
->
Get
<
phi
::
DenseTensor
>
(),
&
v_w_fp32_t
);
Transpose2D
(
&
k_w_t
);
Transpose2D
(
&
v_w_t
);
CastToFp32
(
&
q_w_fp32_t
);
CastToFp32
(
&
k_w_fp32_t
);
// Concat
CastToFp32
(
&
v_w_fp32_t
);
qkv_w
->
Resize
(
DDim
(
{
q_w_t
.
dims
()[
0
]
+
k_w_t
.
dims
()[
0
]
+
v_w_t
.
dims
()[
0
],
q_w_t
.
dims
()[
1
]}));
Transpose2D
(
&
q_w_fp32_t
);
qkv_w
->
set_type
(
q_w
.
type
());
Transpose2D
(
&
k_w_fp32_t
);
auto
*
dev_ctx
=
static_cast
<
phi
::
CPUContext
*>
(
Transpose2D
(
&
v_w_fp32_t
);
phi
::
DenseTensor
qkv_w_int16_t
;
phi
::
DenseTensor
qkv_w_max_t
;
qkv_w_int16_t
.
Resize
(
DDim
({
q_w_fp32_t
.
dims
()[
0
]
+
k_w_fp32_t
.
dims
()[
0
]
+
v_w_fp32_t
.
dims
()[
0
],
q_w_fp32_t
.
dims
()[
1
]}));
qkv_w_int16_t
.
set_type
(
q_w_fp32_t
.
type
());
auto
*
cpu_ctx
=
static_cast
<
phi
::
CPUContext
*>
(
platform
::
DeviceContextPool
::
Instance
().
Get
(
phi
::
CPUPlace
()));
platform
::
DeviceContextPool
::
Instance
().
Get
(
phi
::
CPUPlace
()));
std
::
vector
<
const
phi
::
DenseTensor
*>
in_tensors
{
&
q_w_t
,
&
k_w_t
,
&
v_w_t
};
std
::
vector
<
const
phi
::
DenseTensor
*>
in_tensors
{
if
(
q_w
.
type
()
==
phi
::
DataType
::
FLOAT16
)
{
&
q_w_fp32_t
,
&
k_w_fp32_t
,
&
v_w_fp32_t
};
phi
::
ConcatKernel
<
float16
>
(
*
dev_ctx
,
in_tensors
,
0
,
qkv_w
);
phi
::
ConcatKernel
<
float
>
(
*
cpu_ctx
,
in_tensors
,
0
,
&
qkv_w_int16_t
);
PrepareWeight
<
int16_t
>
(
&
qkv_w_int16_t
,
&
qkv_w_max_t
,
false
);
size_t
qkv_w_int16_hash
=
HashTensor
<
int16_t
>
(
qkv_w_int16_t
);
size_t
qkv_w_max_hash
=
HashTensor
<
float
>
(
qkv_w_max_t
);
std
::
string
qkv_w_int16_name
=
std
::
to_string
(
qkv_w_int16_hash
);
std
::
string
qkv_w_max_name
=
std
::
to_string
(
qkv_w_max_hash
);
*
qkv_w_int16
=
FindNodeWithName
(
graph
,
qkv_w_int16_name
);
if
(
*
qkv_w_int16
==
nullptr
)
{
// Create qkv_w_int16 node
// Update qkv_w_int16 var_desc in block
VarDesc
qkv_w_int16_desc
(
qkv_w_int16_name
);
qkv_w_int16_desc
.
SetPersistable
(
true
);
qkv_w_int16_desc
.
SetShape
(
vectorize
(
qkv_w_int16_t
.
dims
()));
qkv_w_int16_desc
.
SetDataType
(
framework
::
TransToProtoVarType
(
qkv_w_int16_t
.
dtype
()));
*
qkv_w_int16
=
graph
->
CreateVarNode
(
&
qkv_w_int16_desc
);
auto
*
block_qkv_w_int16_desc
=
block
->
Var
(
qkv_w_int16_name
);
block_qkv_w_int16_desc
->
SetPersistable
(
qkv_w_int16_desc
.
Persistable
());
block_qkv_w_int16_desc
->
SetShape
(
qkv_w_int16_desc
.
GetShape
());
block_qkv_w_int16_desc
->
SetDataType
(
qkv_w_int16_desc
.
GetDataType
());
// Create qkv_w_max node
// Update qkv_w_max var_desc in block
VarDesc
qkv_w_max_desc
(
qkv_w_max_name
);
qkv_w_max_desc
.
SetPersistable
(
true
);
qkv_w_max_desc
.
SetShape
(
vectorize
(
qkv_w_max_t
.
dims
()));
qkv_w_max_desc
.
SetDataType
(
proto
::
VarType
::
Type
::
VarType_Type_FP32
);
*
qkv_w_max
=
graph
->
CreateVarNode
(
&
qkv_w_max_desc
);
auto
*
block_qkv_w_max_desc
=
block
->
Var
(
qkv_w_max_name
);
block_qkv_w_max_desc
->
SetPersistable
(
qkv_w_max_desc
.
Persistable
());
block_qkv_w_max_desc
->
SetShape
(
qkv_w_max_desc
.
GetShape
());
block_qkv_w_max_desc
->
SetDataType
(
qkv_w_max_desc
.
GetDataType
());
// Find qkv_w_int16/qkv_w_max variable in scope
auto
*
qkv_w_int16_var
=
scope
->
FindVar
(
qkv_w_int16_name
);
if
(
qkv_w_int16_var
==
nullptr
)
{
// Create qkv_w_int16/qkv_w_max variable/tensor
Assign
(
qkv_w_int16_t
,
scope
->
Var
(
qkv_w_int16_name
)
->
GetMutable
<
phi
::
DenseTensor
>
());
Assign
(
qkv_w_max_t
,
scope
->
Var
(
qkv_w_max_name
)
->
GetMutable
<
phi
::
DenseTensor
>
());
}
else
{
}
else
{
phi
::
ConcatKernel
<
float
>
(
*
dev_ctx
,
in_tensors
,
0
,
qkv_w
);
// Share the same variable
PADDLE_ENFORCE_NOT_NULL
(
scope
->
FindVar
(
qkv_w_max_name
),
platform
::
errors
::
Fatal
(
"qkv_w_max(%s) variable should not be nullptr if qkv_w_int16(%s) "
"variable is exist."
,
qkv_w_max_name
,
qkv_w_int16_name
));
}
}
else
{
*
qkv_w_max
=
FindNodeWithName
(
graph
,
qkv_w_max_name
);
PADDLE_ENFORCE_NOT_NULL
(
*
qkv_w_max
,
platform
::
errors
::
Fatal
(
"qkv_w_max(%s) variable should not be nullptr if qkv_w_int16(%s) "
"variable is exist."
,
qkv_w_max_name
,
qkv_w_int16_name
));
}
}
// Quant to int16
QuantWeight
<
int16_t
>
(
qkv_w
,
qkv_w_max
,
false
);
}
}
void
MultiEncoderXPUFusePass
::
ConcatQKVBias
(
const
phi
::
DenseTensor
&
q_bias
,
void
MultiEncoderXPUFusePass
::
PrepareQKVBias
(
Graph
*
graph
,
const
phi
::
DenseTensor
&
k_bias
,
Scope
*
scope
,
const
phi
::
DenseTensor
&
v_bias
,
BlockDesc
*
block
,
phi
::
DenseTensor
*
qkv_bias
)
const
{
Node
*
q_bias
,
int
q_bias_size
=
q_bias
.
numel
();
Node
*
k_bias
,
qkv_bias
->
Resize
(
DDim
({
q_bias_size
*
3
}));
Node
*
v_bias
,
qkv_bias
->
set_type
(
q_bias
.
type
());
Node
**
qkv_bias
)
const
{
auto
*
dev_ctx
=
static_cast
<
phi
::
CPUContext
*>
(
auto
*
q_bias_tensor
=
scope
->
Var
(
q_bias
->
Name
())
->
GetMutable
<
phi
::
DenseTensor
>
();
auto
*
k_bias_tensor
=
scope
->
Var
(
k_bias
->
Name
())
->
GetMutable
<
phi
::
DenseTensor
>
();
auto
*
v_bias_tensor
=
scope
->
Var
(
v_bias
->
Name
())
->
GetMutable
<
phi
::
DenseTensor
>
();
phi
::
DenseTensor
q_bias_fp32_tensor
;
phi
::
DenseTensor
k_bias_fp32_tensor
;
phi
::
DenseTensor
v_bias_fp32_tensor
;
CastToFp32
(
q_bias_tensor
,
&
q_bias_fp32_tensor
);
CastToFp32
(
k_bias_tensor
,
&
k_bias_fp32_tensor
);
CastToFp32
(
v_bias_tensor
,
&
v_bias_fp32_tensor
);
phi
::
DenseTensor
qkv_bias_tensor
;
int
q_bias_fp32_size
=
q_bias_fp32_tensor
.
numel
();
qkv_bias_tensor
.
Resize
(
DDim
({
q_bias_fp32_size
*
3
}));
qkv_bias_tensor
.
set_type
(
phi
::
DataType
::
FLOAT32
);
auto
*
cpu_ctx
=
static_cast
<
phi
::
CPUContext
*>
(
platform
::
DeviceContextPool
::
Instance
().
Get
(
phi
::
CPUPlace
()));
platform
::
DeviceContextPool
::
Instance
().
Get
(
phi
::
CPUPlace
()));
auto
*
qkv_bias_data
=
dev_ctx
->
Alloc
<
float
>
(
qkv_bias
);
auto
*
qkv_bias_data
=
cpu_ctx
->
Alloc
<
float
>
(
&
qkv_bias_tensor
);
memcpy
(
qkv_bias_data
,
q_bias
.
data
(),
q_bias_size
*
sizeof
(
float
));
memcpy
(
qkv_bias_data
,
qkv_bias_data
+=
q_bias_size
;
q_bias_fp32_tensor
.
data
(),
memcpy
(
qkv_bias_data
,
k_bias
.
data
(),
q_bias_size
*
sizeof
(
float
));
q_bias_fp32_size
*
sizeof
(
float
));
qkv_bias_data
+=
q_bias_size
;
qkv_bias_data
+=
q_bias_fp32_size
;
memcpy
(
qkv_bias_data
,
v_bias
.
data
(),
q_bias_size
*
sizeof
(
float
));
memcpy
(
qkv_bias_data
,
k_bias_fp32_tensor
.
data
(),
q_bias_fp32_size
*
sizeof
(
float
));
qkv_bias_data
+=
q_bias_fp32_size
;
memcpy
(
qkv_bias_data
,
v_bias_fp32_tensor
.
data
(),
q_bias_fp32_size
*
sizeof
(
float
));
size_t
qkv_bias_hash
=
HashTensor
<
float
>
(
qkv_bias_tensor
);
std
::
string
qkv_bias_name
=
std
::
to_string
(
qkv_bias_hash
);
*
qkv_bias
=
FindNodeWithName
(
graph
,
qkv_bias_name
);
if
(
*
qkv_bias
==
nullptr
)
{
// Create qkv_bias node
// Update qkv_bias var_desc in block
VarDesc
qkv_bias_desc
(
qkv_bias_name
);
qkv_bias_desc
.
SetPersistable
(
true
);
qkv_bias_desc
.
SetShape
(
vectorize
(
qkv_bias_tensor
.
dims
()));
qkv_bias_desc
.
SetDataType
(
framework
::
TransToProtoVarType
(
qkv_bias_tensor
.
dtype
()));
*
qkv_bias
=
graph
->
CreateVarNode
(
&
qkv_bias_desc
);
auto
*
block_qkv_bias_desc
=
block
->
Var
(
qkv_bias_name
);
block_qkv_bias_desc
->
SetPersistable
(
qkv_bias_desc
.
Persistable
());
block_qkv_bias_desc
->
SetShape
(
qkv_bias_desc
.
GetShape
());
block_qkv_bias_desc
->
SetDataType
(
qkv_bias_desc
.
GetDataType
());
Assign
(
qkv_bias_tensor
,
scope
->
Var
(
qkv_bias_name
)
->
GetMutable
<
phi
::
DenseTensor
>
());
}
}
}
int
MultiEncoderXPUFusePass
::
ApplySingleEncoderXPUFuse
(
int
MultiEncoderXPUFusePass
::
ApplySingleEncoderXPUFuse
(
...
@@ -856,109 +969,67 @@ int MultiEncoderXPUFusePass::ApplySingleEncoderXPUFuse(
...
@@ -856,109 +969,67 @@ int MultiEncoderXPUFusePass::ApplySingleEncoderXPUFuse(
scope
->
FindVar
(
q_matmul_w
->
Name
())
->
Get
<
phi
::
DenseTensor
>
().
dtype
()
==
scope
->
FindVar
(
q_matmul_w
->
Name
())
->
Get
<
phi
::
DenseTensor
>
().
dtype
()
==
phi
::
DataType
::
FLOAT16
;
phi
::
DataType
::
FLOAT16
;
// Prepare q,k,v weight
Node
*
qkv_w_int16
=
nullptr
;
std
::
string
q_w_name
=
q_matmul_w
->
Name
();
Node
*
qkv_w_max
=
nullptr
;
std
::
string
k_w_name
=
k_matmul_w
->
Name
();
PrepareQKVWeight
(
graph
,
std
::
string
v_w_name
=
v_matmul_w
->
Name
();
scope
,
std
::
string
qkv_w_name
=
q_w_name
+
"_"
+
k_w_name
+
"_"
+
v_w_name
;
block
,
VarDesc
qkv_w_desc
(
qkv_w_name
);
q_matmul_w
,
qkv_w_desc
.
SetPersistable
(
true
);
k_matmul_w
,
auto
*
qkv_w
=
graph
->
CreateVarNode
(
&
qkv_w_desc
);
v_matmul_w
,
auto
*
qkv_w_var
=
block
->
Var
(
qkv_w_name
);
&
qkv_w_int16
,
qkv_w_var
->
SetPersistable
(
true
);
&
qkv_w_max
);
std
::
string
qkv_w_max_name
=
qkv_w_name
+
"_max"
;
VarDesc
qkv_w_max_desc
(
qkv_w_max_name
);
qkv_w_max_desc
.
SetPersistable
(
true
);
auto
*
qkv_w_max
=
graph
->
CreateVarNode
(
&
qkv_w_max_desc
);
auto
*
qkv_w_max_var
=
block
->
Var
(
qkv_w_max_name
);
qkv_w_max_var
->
SetPersistable
(
true
);
PrepareQKVWeight
(
scope
->
FindVar
(
q_w_name
)
->
Get
<
phi
::
DenseTensor
>
(),
scope
->
FindVar
(
k_w_name
)
->
Get
<
phi
::
DenseTensor
>
(),
scope
->
FindVar
(
v_w_name
)
->
Get
<
phi
::
DenseTensor
>
(),
scope
->
Var
(
qkv_w_name
)
->
GetMutable
<
phi
::
DenseTensor
>
(),
scope
->
Var
(
qkv_w_max_name
)
->
GetMutable
<
phi
::
DenseTensor
>
());
// Prepare qkv_matmul_1_w, qkv_matmul_2_w, qkv_matmul_3_w
#define PREPARE_QKV_MATMUL_W(idx_) \
#define PREPARE_QKV_MATMUL_W(idx_) \
std::string qkv_matmul_##idx_##_w_name = qkv_matmul_##idx_##_w->Name(); \
Node* qkv_matmul_##idx_##_w_int16 = nullptr; \
std::string qkv_matmul_##idx_##_w_max_name = \
Node* qkv_matmul_##idx_##_w_max = nullptr; \
qkv_matmul_##idx_##_w_name + "_max"; \
PrepareWeight<int16_t>(graph, \
VarDesc qkv_matmul_##idx_##_w_max_desc(qkv_matmul_##idx_##_w_max_name); \
scope, \
qkv_matmul_##idx_##_w_max_desc.SetPersistable(true); \
block, \
auto qkv_matmul_##idx_##_w_max = \
qkv_matmul_##idx_##_w, \
graph->CreateVarNode(&qkv_matmul_##idx_##_w_max_desc); \
&qkv_matmul_##idx_##_w_int16, \
auto qkv_matmul_##idx_##_w_max_var = \
&qkv_matmul_##idx_##_w_max, \
block->Var(qkv_matmul_##idx_##_w_max_name); \
true);
qkv_matmul_##idx_##_w_max_var->SetPersistable(true); \
auto qkv_matmul_##idx_##_w_max_tensor = \
scope->Var(qkv_matmul_##idx_##_w_max_name) \
->GetMutable<phi::DenseTensor>(); \
auto qkv_matmul_##idx_##_w_tensor = \
scope->Var(qkv_matmul_##idx_##_w_name)->GetMutable<phi::DenseTensor>(); \
QuantWeight<int16_t>( \
qkv_matmul_##idx_##_w_tensor, qkv_matmul_##idx_##_w_max_tensor, true);
PREPARE_QKV_MATMUL_W
(
1
);
PREPARE_QKV_MATMUL_W
(
1
);
PREPARE_QKV_MATMUL_W
(
2
);
PREPARE_QKV_MATMUL_W
(
2
);
PREPARE_QKV_MATMUL_W
(
3
);
PREPARE_QKV_MATMUL_W
(
3
);
#undef PREPARE_QKV_MATMUL_W
#undef PREPARE_QKV_MATMUL_W
// Concat q_add_bias, k_add_bias, v_add_bias
Node
*
qkv_add_bias_fp32
=
nullptr
;
std
::
string
q_add_bias_name
=
q_add_bias
->
Name
();
PrepareQKVBias
(
graph
,
std
::
string
k_add_bias_name
=
k_add_bias
->
Name
();
scope
,
std
::
string
v_add_bias_name
=
v_add_bias
->
Name
();
block
,
std
::
string
qkv_add_bias_name
=
q_add_bias
,
q_add_bias_name
+
"_"
+
k_add_bias_name
+
"_"
+
v_add_bias_name
;
k_add_bias
,
VarDesc
qkv_add_bias_desc
(
qkv_add_bias_name
);
v_add_bias
,
qkv_add_bias_desc
.
SetPersistable
(
true
);
&
qkv_add_bias_fp32
);
auto
*
qkv_add_bias
=
graph
->
CreateVarNode
(
&
qkv_add_bias_desc
);
auto
*
qkv_add_bias_var
=
block
->
Var
(
qkv_add_bias_name
);
Node
*
qkv_add_0_bias_fp32
=
nullptr
;
qkv_add_bias_var
->
SetPersistable
(
true
);
Node
*
qkv_add_2_bias_fp32
=
nullptr
;
auto
*
q_add_bias_tensor
=
Node
*
qkv_add_3_bias_fp32
=
nullptr
;
scope
->
FindVar
(
q_add_bias_name
)
->
GetMutable
<
phi
::
DenseTensor
>
();
PrepareBias
(
graph
,
scope
,
block
,
qkv_add_0_bias
,
&
qkv_add_0_bias_fp32
);
auto
*
k_add_bias_tensor
=
PrepareBias
(
graph
,
scope
,
block
,
qkv_add_2_bias
,
&
qkv_add_2_bias_fp32
);
scope
->
FindVar
(
k_add_bias_name
)
->
GetMutable
<
phi
::
DenseTensor
>
();
PrepareBias
(
graph
,
scope
,
block
,
qkv_add_3_bias
,
&
qkv_add_3_bias_fp32
);
auto
*
v_add_bias_tensor
=
scope
->
FindVar
(
v_add_bias_name
)
->
GetMutable
<
phi
::
DenseTensor
>
();
CastToFp32
(
q_add_bias_tensor
);
CastToFp32
(
k_add_bias_tensor
);
CastToFp32
(
v_add_bias_tensor
);
ConcatQKVBias
(
*
q_add_bias_tensor
,
*
k_add_bias_tensor
,
*
v_add_bias_tensor
,
scope
->
Var
(
qkv_add_bias_name
)
->
GetMutable
<
phi
::
DenseTensor
>
());
// Prepare qkv_add_0_bias, qkv_add_2_bias, qkv_add_3_bias
auto
qkv_add_0_bias_name
=
qkv_add_0_bias
->
Name
();
CastToFp32
(
scope
->
FindVar
(
qkv_add_0_bias_name
)
->
GetMutable
<
phi
::
DenseTensor
>
());
auto
qkv_add_2_bias_name
=
qkv_add_2_bias
->
Name
();
CastToFp32
(
scope
->
FindVar
(
qkv_add_2_bias_name
)
->
GetMutable
<
phi
::
DenseTensor
>
());
auto
qkv_add_3_bias_name
=
qkv_add_3_bias
->
Name
();
CastToFp32
(
scope
->
FindVar
(
qkv_add_3_bias_name
)
->
GetMutable
<
phi
::
DenseTensor
>
());
// Generate single_encoder_xpu op
// Generate single_encoder_xpu op
framework
::
OpDesc
op_desc
(
block
);
framework
::
OpDesc
op_desc
(
block
);
op_desc
.
SetType
(
"single_encoder_xpu"
);
op_desc
.
SetType
(
"single_encoder_xpu"
);
op_desc
.
SetInput
(
"x"
,
{
ln_0_x
->
Name
()});
op_desc
.
SetInput
(
"x"
,
{
ln_0_x
->
Name
()});
op_desc
.
SetInput
(
"fc_weight"
,
op_desc
.
SetInput
(
"fc_weight"
,
{
qkv_w_
name
,
{
qkv_w_
int16
->
Name
()
,
qkv_matmul_1_w_
name
,
qkv_matmul_1_w_
int16
->
Name
()
,
qkv_matmul_2_w_
name
,
qkv_matmul_2_w_
int16
->
Name
()
,
qkv_matmul_3_w_
name
});
qkv_matmul_3_w_
int16
->
Name
()
});
op_desc
.
SetInput
(
"fc_weight_max"
,
op_desc
.
SetInput
(
"fc_weight_max"
,
{
qkv_w_max
_name
,
{
qkv_w_max
->
Name
()
,
qkv_matmul_1_w_max
_name
,
qkv_matmul_1_w_max
->
Name
()
,
qkv_matmul_2_w_max
_name
,
qkv_matmul_2_w_max
->
Name
()
,
qkv_matmul_3_w_max
_name
});
qkv_matmul_3_w_max
->
Name
()
});
op_desc
.
SetInput
(
"fc_bias"
,
op_desc
.
SetInput
(
"fc_bias"
,
{
qkv_add_bias_
name
,
{
qkv_add_bias_
fp32
->
Name
()
,
qkv_add_0_bias_
name
,
qkv_add_0_bias_
fp32
->
Name
()
,
qkv_add_2_bias_
name
,
qkv_add_2_bias_
fp32
->
Name
()
,
qkv_add_3_bias_
name
});
qkv_add_3_bias_
fp32
->
Name
()
});
if
(
norm_before
)
{
if
(
norm_before
)
{
op_desc
.
SetInput
(
"ln_scale"
,
{
ln_0_scale
->
Name
(),
ln_1_scale
->
Name
()});
op_desc
.
SetInput
(
"ln_scale"
,
{
ln_0_scale
->
Name
(),
ln_1_scale
->
Name
()});
op_desc
.
SetInput
(
"ln_bias"
,
{
ln_0_bias
->
Name
(),
ln_1_bias
->
Name
()});
op_desc
.
SetInput
(
"ln_bias"
,
{
ln_0_bias
->
Name
(),
ln_1_bias
->
Name
()});
...
@@ -990,30 +1061,30 @@ int MultiEncoderXPUFusePass::ApplySingleEncoderXPUFuse(
...
@@ -990,30 +1061,30 @@ int MultiEncoderXPUFusePass::ApplySingleEncoderXPUFuse(
}
}
auto
*
single_encoder_xpu
=
graph
->
CreateOpNode
(
&
op_desc
);
auto
*
single_encoder_xpu
=
graph
->
CreateOpNode
(
&
op_desc
);
// Link nodes
// Link nodes
SAFE_
IR_NODE_LINK_TO
(
ln_0_x
,
single_encoder_xpu
);
IR_NODE_LINK_TO
(
ln_0_x
,
single_encoder_xpu
);
SAFE_IR_NODE_LINK_TO
(
qkv_w
,
single_encoder_xpu
);
IR_NODE_LINK_TO
(
qkv_w_int16
,
single_encoder_xpu
);
SAFE_
IR_NODE_LINK_TO
(
qkv_w_max
,
single_encoder_xpu
);
IR_NODE_LINK_TO
(
qkv_w_max
,
single_encoder_xpu
);
SAFE_IR_NODE_LINK_TO
(
qkv_matmul_1_w
,
single_encoder_xpu
);
IR_NODE_LINK_TO
(
qkv_matmul_1_w_int16
,
single_encoder_xpu
);
SAFE_
IR_NODE_LINK_TO
(
qkv_matmul_1_w_max
,
single_encoder_xpu
);
IR_NODE_LINK_TO
(
qkv_matmul_1_w_max
,
single_encoder_xpu
);
SAFE_IR_NODE_LINK_TO
(
qkv_matmul_2_w
,
single_encoder_xpu
);
IR_NODE_LINK_TO
(
qkv_matmul_2_w_int16
,
single_encoder_xpu
);
SAFE_
IR_NODE_LINK_TO
(
qkv_matmul_2_w_max
,
single_encoder_xpu
);
IR_NODE_LINK_TO
(
qkv_matmul_2_w_max
,
single_encoder_xpu
);
SAFE_IR_NODE_LINK_TO
(
qkv_matmul_3_w
,
single_encoder_xpu
);
IR_NODE_LINK_TO
(
qkv_matmul_3_w_int16
,
single_encoder_xpu
);
SAFE_
IR_NODE_LINK_TO
(
qkv_matmul_3_w_max
,
single_encoder_xpu
);
IR_NODE_LINK_TO
(
qkv_matmul_3_w_max
,
single_encoder_xpu
);
SAFE_IR_NODE_LINK_TO
(
qkv_add_bias
,
single_encoder_xpu
);
IR_NODE_LINK_TO
(
qkv_add_bias_fp32
,
single_encoder_xpu
);
SAFE_IR_NODE_LINK_TO
(
qkv_add_0_bias
,
single_encoder_xpu
);
IR_NODE_LINK_TO
(
qkv_add_0_bias_fp32
,
single_encoder_xpu
);
SAFE_IR_NODE_LINK_TO
(
qkv_add_2_bias
,
single_encoder_xpu
);
IR_NODE_LINK_TO
(
qkv_add_2_bias_fp32
,
single_encoder_xpu
);
SAFE_IR_NODE_LINK_TO
(
qkv_add_3_bias
,
single_encoder_xpu
);
IR_NODE_LINK_TO
(
qkv_add_3_bias_fp32
,
single_encoder_xpu
);
SAFE_IR_NODE_LINK_TO
(
ln_0_scale
,
single_encoder_xpu
);
SAFE_IR_NODE_LINK_TO
(
ln_0_scale
,
single_encoder_xpu
);
SAFE_IR_NODE_LINK_TO
(
ln_0_bias
,
single_encoder_xpu
);
SAFE_IR_NODE_LINK_TO
(
ln_0_bias
,
single_encoder_xpu
);
SAFE_
IR_NODE_LINK_TO
(
ln_1_scale
,
single_encoder_xpu
);
IR_NODE_LINK_TO
(
ln_1_scale
,
single_encoder_xpu
);
SAFE_
IR_NODE_LINK_TO
(
ln_1_bias
,
single_encoder_xpu
);
IR_NODE_LINK_TO
(
ln_1_bias
,
single_encoder_xpu
);
SAFE_IR_NODE_LINK_TO
(
ln_2_scale
,
single_encoder_xpu
);
SAFE_IR_NODE_LINK_TO
(
ln_2_scale
,
single_encoder_xpu
);
SAFE_IR_NODE_LINK_TO
(
ln_2_bias
,
single_encoder_xpu
);
SAFE_IR_NODE_LINK_TO
(
ln_2_bias
,
single_encoder_xpu
);
SAFE_IR_NODE_LINK_TO
(
qk_add_mask
,
single_encoder_xpu
);
SAFE_IR_NODE_LINK_TO
(
qk_add_mask
,
single_encoder_xpu
);
if
(
norm_before
)
{
if
(
norm_before
)
{
SAFE_
IR_NODE_LINK_TO
(
single_encoder_xpu
,
qkv_add_4_out
);
IR_NODE_LINK_TO
(
single_encoder_xpu
,
qkv_add_4_out
);
}
else
{
}
else
{
SAFE_
IR_NODE_LINK_TO
(
single_encoder_xpu
,
ln_2_out
);
IR_NODE_LINK_TO
(
single_encoder_xpu
,
ln_2_out
);
}
}
// Delete nodes
// Delete nodes
...
...
paddle/fluid/framework/ir/xpu/pass_utils.cc
浏览文件 @
39a9abaa
...
@@ -20,6 +20,18 @@ namespace paddle {
...
@@ -20,6 +20,18 @@ namespace paddle {
namespace
framework
{
namespace
framework
{
namespace
ir
{
namespace
ir
{
static
void
HashCombine
(
std
::
size_t
*
seed
)
{}
// combine hash value
// https://stackoverflow.com/questions/2590677/how-do-i-combine-hash-values-in-c0x
template
<
typename
T
,
typename
...
Rest
>
static
void
HashCombine
(
std
::
size_t
*
seed
,
const
T
&
v
,
Rest
...
rest
)
{
std
::
hash
<
T
>
hasher
;
*
seed
^=
hasher
(
v
)
+
0x9e3779b9
+
(
*
seed
<<
6
)
+
(
*
seed
>>
2
);
*
seed
*=
0x00000100000001B3
;
HashCombine
(
seed
,
rest
...);
}
int
ConvertActivationType
(
std
::
string
act_type
)
{
int
ConvertActivationType
(
std
::
string
act_type
)
{
if
(
act_type
==
""
)
{
if
(
act_type
==
""
)
{
return
static_cast
<
int
>
(
xpu
::
Activation_t
::
LINEAR
);
return
static_cast
<
int
>
(
xpu
::
Activation_t
::
LINEAR
);
...
@@ -50,6 +62,161 @@ int ConvertActivationType(std::string act_type) {
...
@@ -50,6 +62,161 @@ int ConvertActivationType(std::string act_type) {
return
-
1
;
return
-
1
;
}
}
Node
*
FindNodeWithName
(
Graph
*
graph
,
std
::
string
name
)
{
for
(
auto
*
node
:
graph
->
Nodes
())
{
if
(
node
->
IsVar
()
&&
node
->
Var
()
->
Name
()
==
name
)
{
return
node
;
}
}
return
nullptr
;
}
template
<
typename
T
>
std
::
string
IntTypeToString
()
{
LOG
(
FATAL
)
<<
"Not support type."
;
return
""
;
}
template
<
>
std
::
string
IntTypeToString
<
int16_t
>
()
{
return
"int16"
;
}
template
<
typename
T
>
size_t
HashTensor
(
const
phi
::
DenseTensor
&
in
)
{
size_t
ret
=
0
;
auto
in_dims
=
in
.
dims
();
HashCombine
(
&
ret
,
phi
::
DataTypeToString
(
in
.
dtype
()),
phi
::
DataLayoutToString
(
in
.
layout
()),
in_dims
.
size
());
for
(
int
i
=
0
;
i
<
in_dims
.
size
();
i
++
)
{
HashCombine
(
&
ret
,
in_dims
[
i
]);
}
auto
*
data
=
in
.
data
<
T
>
();
int64_t
size
=
in
.
numel
();
for
(
int64_t
i
=
0
;
i
<
size
;
i
++
)
{
HashCombine
(
&
ret
,
data
[
i
]);
}
return
ret
;
}
template
size_t
HashTensor
<
int16_t
>(
const
phi
::
DenseTensor
&
in
);
template
size_t
HashTensor
<
float
>(
const
phi
::
DenseTensor
&
in
);
template
<
typename
T
>
void
PrepareWeight
(
Graph
*
graph
,
Scope
*
scope
,
BlockDesc
*
block
,
Node
*
src
,
Node
**
dst
,
Node
**
dst_max
,
bool
transpose
)
{
auto
src_name
=
src
->
Name
();
auto
*
src_tensor
=
scope
->
Var
(
src_name
)
->
GetMutable
<
phi
::
DenseTensor
>
();
phi
::
DenseTensor
dst_tensor
;
Assign
(
*
src_tensor
,
&
dst_tensor
);
phi
::
DenseTensor
dst_max_tensor
;
PrepareWeight
<
T
>
(
&
dst_tensor
,
&
dst_max_tensor
,
transpose
);
size_t
dst_hash
=
HashTensor
<
T
>
(
dst_tensor
);
size_t
dst_max_hash
=
HashTensor
<
float
>
(
dst_max_tensor
);
std
::
string
dst_name
=
src_name
+
"_"
+
std
::
to_string
(
dst_hash
);
std
::
string
dst_max_name
=
src_name
+
"_max_"
+
std
::
to_string
(
dst_max_hash
);
*
dst
=
FindNodeWithName
(
graph
,
dst_name
);
if
(
*
dst
==
nullptr
)
{
// Create dst node
// Update dst var_desc in block
VarDesc
dst_desc
(
dst_name
);
dst_desc
.
SetPersistable
(
true
);
dst_desc
.
SetShape
(
vectorize
(
dst_tensor
.
dims
()));
dst_desc
.
SetDataType
(
framework
::
TransToProtoVarType
(
dst_tensor
.
dtype
()));
*
dst
=
graph
->
CreateVarNode
(
&
dst_desc
);
auto
*
block_dst_desc
=
block
->
Var
(
dst_name
);
block_dst_desc
->
SetPersistable
(
dst_desc
.
Persistable
());
block_dst_desc
->
SetShape
(
dst_desc
.
GetShape
());
block_dst_desc
->
SetDataType
(
dst_desc
.
GetDataType
());
// Create dst_max node
// Update dst_max var_desc in block
VarDesc
dst_max_desc
(
dst_max_name
);
dst_max_desc
.
SetPersistable
(
true
);
dst_max_desc
.
SetShape
(
vectorize
(
dst_max_tensor
.
dims
()));
dst_max_desc
.
SetDataType
(
proto
::
VarType
::
Type
::
VarType_Type_FP32
);
*
dst_max
=
graph
->
CreateVarNode
(
&
dst_max_desc
);
auto
*
block_dst_max_desc
=
block
->
Var
(
dst_max_name
);
block_dst_max_desc
->
SetPersistable
(
dst_max_desc
.
Persistable
());
block_dst_max_desc
->
SetShape
(
dst_max_desc
.
GetShape
());
block_dst_max_desc
->
SetDataType
(
dst_max_desc
.
GetDataType
());
// Find dst/dst_max variable in scope
auto
*
dst_var
=
scope
->
FindVar
(
dst_name
);
if
(
dst_var
==
nullptr
)
{
// Create dst/dst_max variable/tensor
Assign
(
dst_tensor
,
scope
->
Var
(
dst_name
)
->
GetMutable
<
phi
::
DenseTensor
>
());
Assign
(
dst_max_tensor
,
scope
->
Var
(
dst_max_name
)
->
GetMutable
<
phi
::
DenseTensor
>
());
}
else
{
// Share the same variable
PADDLE_ENFORCE_NOT_NULL
(
scope
->
FindVar
(
dst_max_name
),
platform
::
errors
::
Fatal
(
"dst_max(%s) variable should not be nullptr if dst(%s) "
"variable is exist. (src_name is %s)"
,
dst_max_name
,
dst_name
,
src_name
));
}
}
else
{
*
dst_max
=
FindNodeWithName
(
graph
,
dst_max_name
);
PADDLE_ENFORCE_NOT_NULL
(
*
dst_max
,
platform
::
errors
::
Fatal
(
"dst_max(%s) variable should not be nullptr if dst(%s) "
"variable is exist. (src_name is %s)"
,
dst_max_name
,
dst_name
,
src_name
));
}
}
template
void
PrepareWeight
<
int16_t
>(
Graph
*
graph
,
Scope
*
scope
,
BlockDesc
*
block
,
Node
*
src
,
Node
**
dst
,
Node
**
dst_max
,
bool
transpose
);
void
PrepareBias
(
Graph
*
graph
,
Scope
*
scope
,
BlockDesc
*
block
,
Node
*
src
,
Node
**
dst
)
{
auto
src_name
=
src
->
Name
();
auto
*
src_tensor
=
scope
->
Var
(
src_name
)
->
GetMutable
<
phi
::
DenseTensor
>
();
if
(
src_tensor
->
dtype
()
==
phi
::
DataType
::
FLOAT32
)
{
*
dst
=
src
;
}
phi
::
DenseTensor
dst_tensor
;
CastToFp32
(
src_tensor
,
&
dst_tensor
);
size_t
dst_hash
=
HashTensor
<
float
>
(
dst_tensor
);
std
::
string
dst_name
=
src_name
+
"_"
+
std
::
to_string
(
dst_hash
);
*
dst
=
FindNodeWithName
(
graph
,
dst_name
);
if
(
*
dst
==
nullptr
)
{
// Create dst node
// Update dst var_desc in block
VarDesc
dst_desc
(
dst_name
);
dst_desc
.
SetPersistable
(
true
);
dst_desc
.
SetShape
(
vectorize
(
dst_tensor
.
dims
()));
dst_desc
.
SetDataType
(
framework
::
TransToProtoVarType
(
dst_tensor
.
dtype
()));
*
dst
=
graph
->
CreateVarNode
(
&
dst_desc
);
auto
*
block_dst_desc
=
block
->
Var
(
dst_name
);
block_dst_desc
->
SetPersistable
(
dst_desc
.
Persistable
());
block_dst_desc
->
SetShape
(
dst_desc
.
GetShape
());
block_dst_desc
->
SetDataType
(
dst_desc
.
GetDataType
());
Assign
(
dst_tensor
,
scope
->
Var
(
dst_name
)
->
GetMutable
<
phi
::
DenseTensor
>
());
}
}
}
// namespace ir
}
// namespace ir
}
// namespace framework
}
// namespace framework
}
// namespace paddle
}
// namespace paddle
paddle/fluid/framework/ir/xpu/pass_utils.h
浏览文件 @
39a9abaa
...
@@ -14,6 +14,10 @@
...
@@ -14,6 +14,10 @@
#pragma once
#pragma once
#include <string>
#include <string>
#include "paddle/fluid/framework/convert_utils.h"
#include "paddle/fluid/framework/ir/pass.h"
#include "paddle/fluid/framework/ir/xpu/quant_utils.h"
#include "paddle/fluid/framework/scope.h"
namespace
paddle
{
namespace
paddle
{
namespace
framework
{
namespace
framework
{
...
@@ -45,6 +49,23 @@ namespace ir {
...
@@ -45,6 +49,23 @@ namespace ir {
int
ConvertActivationType
(
std
::
string
act_type
);
int
ConvertActivationType
(
std
::
string
act_type
);
Node
*
FindNodeWithName
(
Graph
*
graph
,
std
::
string
name
);
template
<
typename
T
>
size_t
HashTensor
(
const
phi
::
DenseTensor
&
in
);
template
<
typename
T
>
void
PrepareWeight
(
Graph
*
graph
,
Scope
*
scope
,
BlockDesc
*
block
,
Node
*
src
,
Node
**
dst
,
Node
**
dst_max
,
bool
transpose
);
void
PrepareBias
(
Graph
*
graph
,
Scope
*
scope
,
BlockDesc
*
block
,
Node
*
src
,
Node
**
dst
);
}
// namespace ir
}
// namespace ir
}
// namespace framework
}
// namespace framework
}
// namespace paddle
}
// namespace paddle
paddle/fluid/framework/ir/xpu/quant_utils.cc
浏览文件 @
39a9abaa
...
@@ -207,7 +207,7 @@ void QuantFP32ToIntX<int16_t>(const float* src_ptr,
...
@@ -207,7 +207,7 @@ void QuantFP32ToIntX<int16_t>(const float* src_ptr,
}
}
template
<
typename
T
>
template
<
typename
T
>
void
Quant
Weight
(
phi
::
DenseTensor
*
weight
,
void
Prepare
Weight
(
phi
::
DenseTensor
*
weight
,
phi
::
DenseTensor
*
weight_max
,
phi
::
DenseTensor
*
weight_max
,
bool
transpose
)
{
bool
transpose
)
{
// Convert fp16 to fp32
// Convert fp16 to fp32
...
@@ -249,7 +249,7 @@ void QuantWeight(phi::DenseTensor* weight,
...
@@ -249,7 +249,7 @@ void QuantWeight(phi::DenseTensor* weight,
QuantFP32ToIntX
(
weight_data
,
cpu_ctx
->
Alloc
<
T
>
(
weight
),
max_val
,
size
);
QuantFP32ToIntX
(
weight_data
,
cpu_ctx
->
Alloc
<
T
>
(
weight
),
max_val
,
size
);
}
}
template
void
Quant
Weight
<
int16_t
>(
phi
::
DenseTensor
*
weight
,
template
void
Prepare
Weight
<
int16_t
>(
phi
::
DenseTensor
*
weight
,
phi
::
DenseTensor
*
weight_max
,
phi
::
DenseTensor
*
weight_max
,
bool
transpose
);
bool
transpose
);
...
...
paddle/fluid/framework/ir/xpu/quant_utils.h
浏览文件 @
39a9abaa
...
@@ -29,7 +29,7 @@ void CastToFp32(phi::DenseTensor* in, phi::DenseTensor* out = nullptr);
...
@@ -29,7 +29,7 @@ void CastToFp32(phi::DenseTensor* in, phi::DenseTensor* out = nullptr);
// 2. Weight data is in-place update.
// 2. Weight data is in-place update.
// 3. Generate weight max tensor
// 3. Generate weight max tensor
template
<
typename
T
>
template
<
typename
T
>
void
Quant
Weight
(
phi
::
DenseTensor
*
weight
,
void
Prepare
Weight
(
phi
::
DenseTensor
*
weight
,
phi
::
DenseTensor
*
weight_max
,
phi
::
DenseTensor
*
weight_max
,
bool
transpose
);
bool
transpose
);
...
...
paddle/fluid/inference/api/paddle_pass_builder.cc
浏览文件 @
39a9abaa
...
@@ -524,6 +524,7 @@ XpuPassStrategy::XpuPassStrategy() : PassStrategy({}) {
...
@@ -524,6 +524,7 @@ XpuPassStrategy::XpuPassStrategy() : PassStrategy({}) {
"fc_xpu_fuse_pass"
,
"fc_xpu_fuse_pass"
,
"link_xpu_op_max_pass"
,
"link_xpu_op_max_pass"
,
"delete_op_device_pass"
,
"delete_op_device_pass"
,
"delete_isolated_node_pass"
,
});
});
use_xpu_
=
true
;
use_xpu_
=
true
;
}
}
...
...
python/paddle/fluid/tests/unittests/ir/inference/test_xpu_link_xpu_op_max_pass.py
浏览文件 @
39a9abaa
...
@@ -15,6 +15,7 @@
...
@@ -15,6 +15,7 @@
import
unittest
import
unittest
import
hypothesis.strategies
as
st
import
hypothesis.strategies
as
st
import
numpy
as
np
from
auto_scan_test
import
PassAutoScanTest
from
auto_scan_test
import
PassAutoScanTest
from
program_config
import
OpConfig
,
ProgramConfig
,
TensorConfig
from
program_config
import
OpConfig
,
ProgramConfig
,
TensorConfig
...
@@ -33,7 +34,7 @@ class TestFcXPUFusePass(PassAutoScanTest):
...
@@ -33,7 +34,7 @@ class TestFcXPUFusePass(PassAutoScanTest):
)
)
matmul0_y_shape
=
draw
(
matmul0_y_shape
=
draw
(
st
.
lists
(
st
.
lists
(
st
.
integers
(
min_value
=
1
,
max_value
=
8
),
min_size
=
2
,
max_size
=
2
st
.
integers
(
min_value
=
2
,
max_value
=
8
),
min_size
=
2
,
max_size
=
2
)
)
)
)
matmul0_y_shape
[
0
]
=
matmul0_x_shape
[
-
1
]
matmul0_y_shape
[
0
]
=
matmul0_x_shape
[
-
1
]
...
@@ -42,7 +43,7 @@ class TestFcXPUFusePass(PassAutoScanTest):
...
@@ -42,7 +43,7 @@ class TestFcXPUFusePass(PassAutoScanTest):
# 3. matmul1
# 3. matmul1
matmul1_y_shape
=
draw
(
matmul1_y_shape
=
draw
(
st
.
lists
(
st
.
lists
(
st
.
integers
(
min_value
=
1
,
max_value
=
8
),
min_size
=
2
,
max_size
=
2
st
.
integers
(
min_value
=
2
,
max_value
=
8
),
min_size
=
2
,
max_size
=
2
)
)
)
)
matmul1_y_shape
[
0
]
=
matmul0_y_shape
[
-
1
]
matmul1_y_shape
[
0
]
=
matmul0_y_shape
[
-
1
]
...
@@ -101,4 +102,5 @@ class TestFcXPUFusePass(PassAutoScanTest):
...
@@ -101,4 +102,5 @@ class TestFcXPUFusePass(PassAutoScanTest):
if
__name__
==
"__main__"
:
if
__name__
==
"__main__"
:
np
.
random
.
seed
(
200
)
unittest
.
main
()
unittest
.
main
()
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录