Commit 39a9abaa (unverified), authored on Mar 07, 2023 by zhupengyang, committed via GitHub on Mar 07, 2023.
[XPU] support shared weight; delete isolated node (#51108)
Parent commit: 50ad760c

Showing 14 changed files with 863 additions and 223 deletions (+863 −223):
paddle/fluid/framework/ir/CMakeLists.txt                                          +10  −1
paddle/fluid/framework/ir/delete_op_device_pass.cc                                +4   −4
paddle/fluid/framework/ir/delete_op_device_pass_test.cc                           +1   −2
paddle/fluid/framework/ir/pass.cc                                                 +1   −0
paddle/fluid/framework/ir/xpu/delete_isolated_node_pass.cc                        +203 −0
paddle/fluid/framework/ir/xpu/delete_isolated_node_pass_test.cc                   +181 −0
paddle/fluid/framework/ir/xpu/fc_xpu_fuse_pass.cc                                 +30  −45
paddle/fluid/framework/ir/xpu/multi_encoder_xpu_fuse_pass.cc                      +231 −160
paddle/fluid/framework/ir/xpu/pass_utils.cc                                       +167 −0
paddle/fluid/framework/ir/xpu/pass_utils.h                                        +21  −0
paddle/fluid/framework/ir/xpu/quant_utils.cc                                      +6   −6
paddle/fluid/framework/ir/xpu/quant_utils.h                                       +3   −3
paddle/fluid/inference/api/paddle_pass_builder.cc                                 +1   −0
python/paddle/fluid/tests/unittests/ir/inference/test_xpu_link_xpu_op_max_pass.py +4   −2
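For orientation before the per-file diffs: fc_xpu_fuse_pass now converts each FC weight into shared, content-hashed int16/weight-max nodes, and the new delete_isolated_node_pass removes the original fp32 weight nodes left without consumers and patches the inputs of control-flow ops ("while", "conditional_block") accordingly. A minimal sketch of how the two passes interact, based on the new unit test below (the program and scope construction are elided):

// Sketch only: assumes a main graph with one sub graph and a populated
// "__param_scope__", as built in delete_isolated_node_pass_test.cc.
#include "paddle/fluid/framework/ir/pass.h"

void RunXpuWeightCleanup(paddle::framework::ir::Graph* graph) {
  using paddle::framework::ir::PassRegistry;
  // Quantize FC weights; identical weights collapse onto one shared node.
  auto fuse = PassRegistry::Instance().Get("fc_xpu_fuse_pass");
  fuse->Apply(graph);                  // main block
  fuse->Apply(graph->GetSubGraph(1));  // sub block of the while op
  // Drop the now-isolated fp32 weights and fix up control-flow inputs.
  PassRegistry::Instance().Get("delete_isolated_node_pass")->Apply(graph);
}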
paddle/fluid/framework/ir/CMakeLists.txt

@@ -220,7 +220,7 @@ if(WITH_XPU)
   cc_library(
     xpu_pass_utils
     SRCS xpu/pass_utils.cc
-    DEPS pass)
+    DEPS pass xpu_quant_utils)
   set(XPU_PASS_DEPS xpu_quant_utils xpu_pass_utils)
   pass_library(embedding_with_eltwise_add_xpu_fuse_pass inference DIR xpu DEPS
                ${XPU_PASS_DEPS})
@@ -232,6 +232,8 @@ if(WITH_XPU)
   pass_library(generate_sequence_xpu_fuse_pass inference DIR xpu DEPS
                ${XPU_PASS_DEPS})
   pass_library(link_xpu_op_max_pass inference DIR xpu DEPS ${XPU_PASS_DEPS})
+  pass_library(delete_isolated_node_pass inference DIR xpu DEPS
+               ${XPU_PASS_DEPS})
 endif()
 cc_library(
@@ -484,3 +486,10 @@ if(WITH_MKLDNN)
     SRCS mkldnn/cpu_bfloat16_pass_tester.cc
     DEPS cpu_bfloat16_pass)
 endif()
+if(WITH_XPU)
+  cc_test(
+    test_delete_isolated_node_pass
+    SRCS xpu/delete_isolated_node_pass_test.cc
+    DEPS delete_isolated_node_pass)
+endif()
paddle/fluid/framework/ir/delete_op_device_pass.cc

@@ -39,14 +39,14 @@ class DeleteOpDevicePass : public Pass {
 void DeleteOpDevicePass::ApplyImpl(ir::Graph* graph) const {
   PADDLE_ENFORCE_NOT_NULL(
       graph, platform::errors::PreconditionNotMet("graph should not be null."));
-  int found_subgraph_count = 0;
+  int delete_counts = 0;
   for (auto* node : graph->Nodes()) {
     if (!node->IsOp() || !node->Op()->HasAttr("op_device")) continue;
     node->Op()->RemoveAttr("op_device");
-    found_subgraph_count++;
+    delete_counts++;
   }
-  if (found_subgraph_count > 0) {
-    LOG(INFO) << "--- detected " << found_subgraph_count << " subgraphs";
+  if (delete_counts > 0) {
+    LOG(INFO) << "--- delete " << delete_counts << " op_device attr";
   }
 }
paddle/fluid/framework/ir/delete_op_device_pass_test.cc

@@ -13,8 +13,7 @@
 // limitations under the License.

 #include <gtest/gtest.h>
-#include "paddle/fluid/framework/ir/delete_dropout_op_pass.h"
+#include "paddle/fluid/framework/ir/pass.h"
 #include "paddle/fluid/framework/ir/pass_tester_helper.h"

 namespace paddle {
paddle/fluid/framework/ir/pass.cc

@@ -49,6 +49,7 @@ static const std::vector<std::string> support_subgraph_passes = {
     "fuse_multi_transformer_layer_pass",
     "delete_quant_dequant_linear_op_pass",
     "delete_weight_dequant_linear_op_pass",
+    "fc_xpu_fuse_pass",
     "delete_op_device_pass"};

 Graph* Pass::Apply(Graph* graph) const {
paddle/fluid/framework/ir/xpu/delete_isolated_node_pass.cc (new file, 0 → 100644)

// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include <string>

#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
#include "paddle/fluid/framework/ir/pass.h"
#include "paddle/fluid/framework/ir/xpu/pass_utils.h"
#include "paddle/fluid/framework/scope.h"

namespace phi {
class DenseTensor;
}  // namespace phi

namespace paddle {
namespace framework {
class Scope;
}  // namespace framework
}  // namespace paddle

namespace paddle {
namespace framework {
namespace ir {

class DeleteIsolatedNodePass : public Pass {
 protected:
  void ApplyImpl(Graph* graph) const override;

 private:
  void CollectReservedPersistableNodeNames(
      Graph* graph,
      std::unordered_set<std::string>* reserved_persistable_node_names) const;

  int RemoveIsolatedNodes(
      Graph* graph,
      const std::unordered_set<std::string>& reserved_persistable_node_names,
      std::unordered_set<std::string>* delete_node_names) const;

  int UpdateControlFlowOp(
      Graph* graph,
      const std::map<int, Graph*>& block_id_graph_map,
      const std::unordered_set<std::string>& delete_node_names) const;

  const std::map<std::string, std::string> control_flow_op_input_map_{
      {"while", "X"},
      {"conditional_block", "Input"},
  };
};

void DeleteIsolatedNodePass::ApplyImpl(Graph* graph) const {
  PADDLE_ENFORCE_NOT_NULL(
      graph, platform::errors::PreconditionNotMet("graph should not be null."));
  PADDLE_ENFORCE(graph->IsMainGraph(),
                 platform::errors::PreconditionNotMet(
                     "Pass(apply in main graph) will delete isolated nodes in "
                     "all subgraphs. Do not apply pass in subgraph."));

  std::unordered_set<std::string> reserved_persistable_node_names;
  for (size_t i = 0; i < graph->SubGraphsSize(); i++) {
    CollectReservedPersistableNodeNames(graph->GetSubGraph(i),
                                        &reserved_persistable_node_names);
  }

  int delete_counts = 0;
  std::unordered_set<std::string> delete_node_names;
  for (size_t i = 0; i < graph->SubGraphsSize(); i++) {
    delete_counts += RemoveIsolatedNodes(graph->GetSubGraph(i),
                                         reserved_persistable_node_names,
                                         &delete_node_names);
  }
  if (delete_counts > 0) {
    LOG(INFO) << "--- delete " << delete_counts << " isolated nodes";
  }

  std::map<int, Graph*> block_id_graph_map;
  for (size_t i = 0; i < graph->SubGraphsSize(); i++) {
    auto* sub_graph = graph->GetSubGraph(i);
    for (auto* node : sub_graph->Nodes()) {
      if (node->IsVar()) {
        block_id_graph_map[node->GetVarNodeBlockId()] = sub_graph;
        break;
      }
    }
  }
  int update_counts = 0;
  for (size_t i = 0; i < graph->SubGraphsSize(); i++) {
    update_counts += UpdateControlFlowOp(
        graph->GetSubGraph(i), block_id_graph_map, delete_node_names);
  }
  if (update_counts > 0) {
    LOG(INFO) << "--- update " << update_counts << " control flow ops";
  }
}

void DeleteIsolatedNodePass::CollectReservedPersistableNodeNames(
    Graph* graph,
    std::unordered_set<std::string>* reserved_persistable_node_names) const {
  for (auto* node : graph->Nodes()) {
    if (!node->IsVar() || !node->Var()->Persistable()) continue;
    for (auto* out_node : node->outputs) {
      auto op_type = out_node->Op()->Type();
      if (control_flow_op_input_map_.count(op_type) == 0) {
        reserved_persistable_node_names->insert(node->Var()->Name());
        break;
      }
    }
  }
}

int DeleteIsolatedNodePass::RemoveIsolatedNodes(
    Graph* graph,
    const std::unordered_set<std::string>& reserved_persistable_node_names,
    std::unordered_set<std::string>* delete_node_names) const {
  BlockDesc* block = nullptr;
  for (auto* node : graph->Nodes()) {
    if (node->IsOp()) {
      block = node->Op()->Block();
    }
  }
  Scope& scope = graph->Get<framework::Scope>("__param_scope__");

  // If graph has nodes to delete:
  // 1. Clear var_desc in block
  // 2. Clear tensor in variable
  // 3. Clear variable in scope
  int delete_node_counts = 0;
  std::unordered_set<const Node*> delete_nodes;
  const std::unordered_set<ir::Node*> nodes = graph->Nodes();
  for (auto* node : nodes) {
    if (!node->IsVar() || !node->Var()->Persistable()) continue;
    auto name = node->Var()->Name();
    if (reserved_persistable_node_names.count(name) > 0) continue;
    delete_nodes.insert(node);
    delete_node_names->insert(node->Name());
    block->RemoveVar(name);
    auto* var = scope.FindVar(name);
    if (var != nullptr) {
      var->Clear();
      scope.EraseVars({name});
    }
    delete_node_counts++;
  }

  GraphSafeRemoveNodes(graph, delete_nodes);
  return delete_node_counts;
}

int DeleteIsolatedNodePass::UpdateControlFlowOp(
    Graph* graph,
    const std::map<int, Graph*>& block_id_graph_map,
    const std::unordered_set<std::string>& delete_node_names) const {
  int update_counts = 0;
  for (auto* node : graph->Nodes()) {
    if (!node->IsOp()) continue;
    auto op_type = node->Op()->Type();
    if (control_flow_op_input_map_.count(op_type) == 0) continue;

    auto in_arg_name = control_flow_op_input_map_.at(op_type);
    auto in_name = node->Op()->Input(in_arg_name);
    std::unordered_set<std::string> in_names_set(in_name.begin(),
                                                 in_name.end());
    for (auto delete_node_name : delete_node_names) {
      if (in_names_set.count(delete_node_name) > 0) {
        in_names_set.erase(delete_node_name);
      }
    }

    auto* sub_block = PADDLE_GET_CONST(framework::BlockDesc*,
                                       node->Op()->GetAttr("sub_block"));
    auto* sub_graph = block_id_graph_map.at(sub_block->ID());
    std::unordered_set<std::string> sub_persistable_node_names;
    CollectReservedPersistableNodeNames(sub_graph, &sub_persistable_node_names);
    for (auto sub_name : sub_persistable_node_names) {
      if (in_names_set.count(sub_name) > 0) continue;
      auto* in_node = FindNodeWithName(graph, sub_name);
      if (in_node == nullptr) continue;
      in_names_set.insert(sub_name);
      IR_NODE_LINK_TO(in_node, node);
    }

    std::vector<std::string> new_in_names(in_names_set.begin(),
                                          in_names_set.end());
    node->Op()->SetInput(in_arg_name, new_in_names);
    update_counts++;
  }
  return update_counts;
}

}  // namespace ir
}  // namespace framework
}  // namespace paddle

REGISTER_PASS(delete_isolated_node_pass,
              paddle::framework::ir::DeleteIsolatedNodePass);
paddle/fluid/framework/ir/xpu/delete_isolated_node_pass_test.cc (new file, 0 → 100644)

// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include <gtest/gtest.h>

#include "paddle/fluid/framework/ir/pass.h"
#include "paddle/fluid/framework/ir/pass_tester_helper.h"

namespace paddle {
namespace framework {
namespace ir {

VarDesc* Data(paddle::framework::BlockDesc* block,
              std::string name,
              std::vector<int64_t> shape = {},
              bool is_persistable = false,
              proto::VarType::Type data_type = proto::VarType::FP32) {
  auto* var = block->Var(name);
  var->SetType(proto::VarType::LOD_TENSOR);
  var->SetDataType(data_type);
  var->SetShape(shape);
  var->SetPersistable(is_persistable);
  return var;
}

void AddVarToScope(Scope* param_scope,
                   const std::string& name,
                   const DDim& dims) {
  auto* tensor = param_scope->Var(name)->GetMutable<phi::DenseTensor>();
  tensor->Resize(dims);
  auto* cpu_ctx = static_cast<phi::CPUContext*>(
      platform::DeviceContextPool::Instance().Get(phi::CPUPlace()));
  auto* data = cpu_ctx->Alloc<float>(tensor);
  int64_t numel = tensor->numel();
  for (int64_t i = 0; i < numel; ++i) {
    data[i] = 1;
  }
}

Scope* CreateParamScope() {
  auto param_scope = new Scope();
  AddVarToScope(param_scope, "matmul0_w", {128, 128});
  return param_scope;
}

int WeightNodeNum(ir::Graph* graph) {
  int num = 0;
  for (auto node : graph->Nodes()) {
    if (node->IsVar() && node->Var()->Persistable()) {
      num++;
    }
  }
  return num;
}

int WeightTensorNum(Scope* scope) {
  int num = 0;
  auto vars = scope->LocalVars();
  for (auto* var : vars) {
    if (var->Get<phi::DenseTensor>().numel() > 0) {
      num++;
    }
  }
  return num;
}

TEST(delete_isolated_node_pass, basic) {
  paddle::framework::ProgramDesc program;
  auto* block0 = program.MutableBlock(0);
  auto* block1 = program.AppendBlock(*block0);

  auto* matmul0_x = Data(block0, "matmul0_x", {1, 128});
  auto* matmul0_w = Data(block0, "matmul0_w", {128, 128}, true);
  auto* matmul0_out = Data(block0, "matmul0_out", {1, 128});
  OpDesc* matmul_op = block0->AppendOp();
  matmul_op->SetType("matmul_v2");
  matmul_op->SetInput("X", {matmul0_x->Name()});
  matmul_op->SetInput("Y", {matmul0_w->Name()});
  matmul_op->SetAttr("trans_x", false);
  matmul_op->SetAttr("trans_y", false);
  matmul_op->SetOutput("Out", {matmul0_out->Name()});

  auto* while_out = Data(block0, "while_out", {1, 128});
  auto* while_step_scopes = Data(block0, "while_step_scopes");
  auto* while_cond = Data(block0, "while_cond");
  OpDesc* while_op = block0->AppendOp();
  while_op->SetType("while");
  while_op->SetInput("X", {matmul0_w->Name(), matmul0_out->Name()});
  while_op->SetInput("Condition", {while_cond->Name()});
  while_op->SetOutput("Out", {while_out->Name()});
  while_op->SetOutput("StepScopes", {while_step_scopes->Name()});
  while_op->SetAttr("sub_block", {block1});
  while_op->SetAttr("is_test", true);

  auto* matmul1_x = Data(block1, matmul0_out->Name(), matmul0_out->GetShape());
  auto* matmul1_w =
      Data(block1, matmul0_w->Name(), matmul0_w->GetShape(), true);
  auto* matmul1_out = Data(block1, "matmul1_out", {1, 128});
  OpDesc* matmul1_op = block1->AppendOp();
  matmul1_op->SetType("matmul_v2");
  matmul1_op->SetInput("X", {matmul1_x->Name()});
  matmul1_op->SetInput("Y", {matmul1_w->Name()});
  matmul1_op->SetAttr("trans_x", false);
  matmul1_op->SetAttr("trans_y", false);
  matmul1_op->SetOutput("Out", {matmul1_out->Name()});

  std::unique_ptr<ir::Graph> graph(new ir::Graph(program));
  auto* scope = CreateParamScope();
  graph->Set("__param_scope__", scope);

  auto pass0 = PassRegistry::Instance().Get("fc_xpu_fuse_pass");
  pass0->Apply(graph.get());
  pass0->Apply(graph->GetSubGraph(1));
  int weight_node_num =
      WeightNodeNum(graph.get()) + WeightNodeNum(graph->GetSubGraph(1));
  PADDLE_ENFORCE_EQ(weight_node_num,
                    6,
                    platform::errors::PreconditionNotMet(
                        "Graph should have 6 weight node after "
                        "fc_xpu_fuse_pass, but actually has %d.",
                        weight_node_num));

  auto pass1 = PassRegistry::Instance().Get("delete_isolated_node_pass");
  pass1->Apply(graph.get());
  weight_node_num =
      WeightNodeNum(graph.get()) + WeightNodeNum(graph->GetSubGraph(1));
  PADDLE_ENFORCE_EQ(weight_node_num,
                    4,
                    platform::errors::PreconditionNotMet(
                        "Graph should have 4 weight node after "
                        "delete_isolated_node_pass, but actually has %d.",
                        weight_node_num));
  int weight_tensor_num = WeightTensorNum(scope);
  PADDLE_ENFORCE_EQ(weight_tensor_num,
                    2,
                    platform::errors::PreconditionNotMet(
                        "Scope should have 2 weight tensor after "
                        "delete_isolated_node_pass, but actually has %d.",
                        weight_tensor_num));

  for (auto* node : graph->Nodes()) {
    if (node->IsOp() && node->Op()->Type() == "while") {
      auto while_in_names = node->Op()->Inputs().at("X");
      PADDLE_ENFORCE_EQ(while_in_names.size(),
                        3,
                        platform::errors::PreconditionNotMet(
                            "While op should have 3 input after "
                            "delete_isolated_node_pass, but actually has %d.",
                            while_in_names.size()));
    }
  }

  Scope& scope0 = graph->Get<framework::Scope>("__param_scope__");
  Scope& scope1 =
      graph->GetSubGraph(1)->Get<framework::Scope>("__param_scope__");
  std::vector<std::string> shared_weight_names{matmul0_w->Name() + "_int16",
                                               matmul0_w->Name() + "_max"};
  for (auto name : shared_weight_names) {
    auto* var0 = scope0.FindVar(name);
    auto* var1 = scope1.FindVar(name);
    PADDLE_ENFORCE(var0 == var1,
                   platform::errors::PreconditionNotMet(
                       "Variables with the same name in two scopes is "
                       "different."));
  }
}

}  // namespace ir
}  // namespace framework
}  // namespace paddle

USE_PASS(delete_isolated_node_pass);
paddle/fluid/framework/ir/xpu/fc_xpu_fuse_pass.cc

@@ -76,7 +76,6 @@ FcXPUPattern::FcXPUPattern(PDPattern* pattern,
   auto* mul_w = pattern->NewNode(mul_w_repr())
                     ->assert_is_op_input(mul_type_, "Y")
                     ->assert_is_persistable_var()
-                    ->assert_has_n_outputs(1)
                     ->assert_more([](Node* node) {
                       return node->Var()->GetShape().size() == 2;
                     });
@@ -169,10 +168,10 @@ class FcXPUFusePass : public FusePassBase {
   void ApplyImpl(ir::Graph* graph) const override;

  private:
-  void ApplyImpl(ir::Graph* graph,
-                 const std::string& mul_type,
-                 bool with_bias,
-                 const std::string& act_type) const;
+  int ApplyImpl(ir::Graph* graph,
+                const std::string& mul_type,
+                bool with_bias,
+                const std::string& act_type) const;

   const std::string name_scope_{"fc_xpu_fuse_pass"};
 };
@@ -181,6 +180,8 @@ void FcXPUFusePass::ApplyImpl(ir::Graph* graph) const {
   PADDLE_ENFORCE_NOT_NULL(
       graph, platform::errors::PreconditionNotMet("graph should not be null."));
   Init(name_scope_, graph);
+  int found_subgraph_count = 0;
   for (auto mul_type : {"mul", "matmul", "matmul_v2"}) {
     for (auto with_bias : {true, false}) {
       for (auto act_type : {
@@ -189,16 +190,17 @@ void FcXPUFusePass::ApplyImpl(ir::Graph* graph) const {
                "tanh",
                "",
            }) {
-        ApplyImpl(graph, mul_type, with_bias, act_type);
+        found_subgraph_count += ApplyImpl(graph, mul_type, with_bias, act_type);
       }
     }
   }
+  AddStatis(found_subgraph_count);
 }

-void FcXPUFusePass::ApplyImpl(ir::Graph* graph,
-                              const std::string& mul_type,
-                              bool with_bias,
-                              const std::string& act_type) const {
+int FcXPUFusePass::ApplyImpl(ir::Graph* graph,
+                             const std::string& mul_type,
+                             bool with_bias,
+                             const std::string& act_type) const {
   GraphPatternDetector gpd;
   patterns::FcXPUPattern pattern(
       gpd.mutable_pattern(), name_scope_, mul_type, with_bias, act_type);
@@ -219,37 +221,20 @@ void FcXPUFusePass::ApplyImpl(ir::Graph* graph,
     auto* block = mul->Op()->Block();
     auto* scope = param_scope();
-    auto mul_w_name = mul_w->Name();
-    auto mul_w_tensor =
-        scope->FindVar(mul_w_name)->GetMutable<phi::DenseTensor>();
-    // 1. Transform weight to int16/int31
-    // 2. Avoid transform repeatly, because weight may be shared with other ops.
-    // TODO(zhupengyang): support int31
-    std::string mul_w_max_name = mul_w_name + "_max";
-    Node* mul_w_max = nullptr;
-    if (mul_w_tensor->dtype() != phi::DataType::INT16) {
-      // Create weight_max node
-      VarDesc mul_w_max_desc(mul_w_max_name);
-      mul_w_max_desc.SetPersistable(true);
-      mul_w_max = graph->CreateVarNode(&mul_w_max_desc);
-      // Create weight_max var/tensor
-      auto mul_w_max_var = block->Var(mul_w_max_name);
-      mul_w_max_var->SetPersistable(true);
-      auto mul_w_max_tensor =
-          scope->Var(mul_w_max_name)->GetMutable<phi::DenseTensor>();
-      bool transpose_w = false;
-      if (mul_type == "matmul") {
-        transpose_w = PADDLE_GET_CONST(bool, mul->Op()->GetAttr("transpose_Y"));
-      } else if (mul_type == "matmul_v2") {
-        transpose_w = PADDLE_GET_CONST(bool, mul->Op()->GetAttr("trans_y"));
-      }
-      QuantWeight<int16_t>(mul_w_tensor, mul_w_max_tensor, !transpose_w);
-    }
+    bool transpose_w = false;
+    if (mul_type == "matmul") {
+      transpose_w = PADDLE_GET_CONST(bool, mul->Op()->GetAttr("transpose_Y"));
+    } else if (mul_type == "matmul_v2") {
+      transpose_w = PADDLE_GET_CONST(bool, mul->Op()->GetAttr("trans_y"));
+    }
+    Node* mul_w_int16 = nullptr;
+    Node* mul_w_max = nullptr;
+    PrepareWeight<int16_t>(
+        graph, scope, block, mul_w, &mul_w_int16, &mul_w_max, !transpose_w);
+    Node* bias_fp32 = nullptr;
     if (bias != nullptr) {
       auto* bias_tensor =
           scope->Var(bias->Name())->GetMutable<phi::DenseTensor>();
       CastToFp32(bias_tensor);
+      PrepareBias(graph, scope, block, bias, &bias_fp32);
     }
     std::string fc_out_name;
@@ -268,10 +253,10 @@ void FcXPUFusePass::ApplyImpl(ir::Graph* graph,
     framework::OpDesc fc_xpu_op_desc(block);
     fc_xpu_op_desc.SetType("fc_xpu");
     fc_xpu_op_desc.SetInput("x", {mul_x->Name()});
-    fc_xpu_op_desc.SetInput("w", {mul_w->Name()});
-    fc_xpu_op_desc.SetInput("w_max", {mul_w_max_name});
-    if (bias) {
-      fc_xpu_op_desc.SetInput("bias", {bias->Name()});
+    fc_xpu_op_desc.SetInput("w", {mul_w_int16->Name()});
+    fc_xpu_op_desc.SetInput("w_max", {mul_w_max->Name()});
+    if (bias_fp32) {
+      fc_xpu_op_desc.SetInput("bias", {bias_fp32->Name()});
     }
     fc_xpu_op_desc.SetAttr("in_num_col_dims",
@@ -306,9 +291,9 @@ void FcXPUFusePass::ApplyImpl(ir::Graph* graph,
     fc_xpu_op_desc.SetOutput("out_max", {fc_out_max_name});
     auto* fc_xpu = graph->CreateOpNode(&fc_xpu_op_desc);
     IR_NODE_LINK_TO(mul_x, fc_xpu);
-    IR_NODE_LINK_TO(mul_w, fc_xpu);
+    IR_NODE_LINK_TO(mul_w_int16, fc_xpu);
     IR_NODE_LINK_TO(mul_w_max, fc_xpu);
-    SAFE_IR_NODE_LINK_TO(bias, fc_xpu);
+    SAFE_IR_NODE_LINK_TO(bias_fp32, fc_xpu);
     if (act_out) {
       IR_NODE_LINK_TO(fc_xpu, act_out);
     } else if (add_out) {
@@ -334,7 +319,7 @@ void FcXPUFusePass::ApplyImpl(ir::Graph* graph,
   };

   gpd(graph, handler);
-  AddStatis(found_subgraph_count);
+  return found_subgraph_count;
 }

 }  // namespace ir
paddle/fluid/framework/ir/xpu/multi_encoder_xpu_fuse_pass.cc

@@ -625,16 +625,24 @@ class MultiEncoderXPUFusePass : public FusePassBase {
   // 2. Concat q_w, k_w, v_w
   // 3. Generate qkv_w_max tensor
   // 4. Quant qkv_w to int16
-  void PrepareQKVWeight(const phi::DenseTensor& q_w,
-                        const phi::DenseTensor& k_w,
-                        const phi::DenseTensor& v_w,
-                        phi::DenseTensor* qkv_w,
-                        phi::DenseTensor* qkv_w_max) const;
-
-  void ConcatQKVBias(const phi::DenseTensor& q_bias,
-                     const phi::DenseTensor& k_bias,
-                     const phi::DenseTensor& v_bias,
-                     phi::DenseTensor* qkv_bias) const;
+  void PrepareQKVWeight(Graph* graph,
+                        Scope* scope,
+                        BlockDesc* block,
+                        Node* q_w,
+                        Node* k_w,
+                        Node* v_w,
+                        Node** qkv_w,
+                        Node** qkv_w_max) const;
+
+  // 1. Cast bias to fp32
+  // 2. Concat q/k/v bias
+  void PrepareQKVBias(Graph* graph,
+                      Scope* scope,
+                      BlockDesc* block,
+                      Node* q_bias,
+                      Node* k_bias,
+                      Node* v_bias,
+                      Node** qkv_bias) const;

   const std::string name_scope_{"multi_encoder_xpu_fuse_pass"};
 };
@@ -685,55 +693,160 @@ void MultiEncoderXPUFusePass::ApplyImpl(ir::Graph* graph) const {
   AddStatis(cast_mask_counts);
 }

-void MultiEncoderXPUFusePass::PrepareQKVWeight(const phi::DenseTensor& q_w,
-                                               const phi::DenseTensor& k_w,
-                                               const phi::DenseTensor& v_w,
-                                               phi::DenseTensor* qkv_w,
-                                               phi::DenseTensor* qkv_w_max) const {
-  // Transpose
-  phi::DenseTensor q_w_t;
-  phi::DenseTensor k_w_t;
-  phi::DenseTensor v_w_t;
-  Assign(q_w, &q_w_t);
-  Assign(k_w, &k_w_t);
-  Assign(v_w, &v_w_t);
-  Transpose2D(&q_w_t);
-  Transpose2D(&k_w_t);
-  Transpose2D(&v_w_t);
-  // Concat
-  qkv_w->Resize(DDim({q_w_t.dims()[0] + k_w_t.dims()[0] + v_w_t.dims()[0],
-                      q_w_t.dims()[1]}));
-  qkv_w->set_type(q_w.type());
-  auto* dev_ctx = static_cast<phi::CPUContext*>(
-      platform::DeviceContextPool::Instance().Get(phi::CPUPlace()));
-  std::vector<const phi::DenseTensor*> in_tensors{&q_w_t, &k_w_t, &v_w_t};
-  if (q_w.type() == phi::DataType::FLOAT16) {
-    phi::ConcatKernel<float16>(*dev_ctx, in_tensors, 0, qkv_w);
-  } else {
-    phi::ConcatKernel<float>(*dev_ctx, in_tensors, 0, qkv_w);
-  }
-  // Quant to int16
-  QuantWeight<int16_t>(qkv_w, qkv_w_max, false);
-}
+void MultiEncoderXPUFusePass::PrepareQKVWeight(Graph* graph,
+                                               Scope* scope,
+                                               BlockDesc* block,
+                                               Node* q_w,
+                                               Node* k_w,
+                                               Node* v_w,
+                                               Node** qkv_w_int16,
+                                               Node** qkv_w_max) const {
+  phi::DenseTensor q_w_fp32_t;
+  phi::DenseTensor k_w_fp32_t;
+  phi::DenseTensor v_w_fp32_t;
+  Assign(scope->Var(q_w->Name())->Get<phi::DenseTensor>(), &q_w_fp32_t);
+  Assign(scope->Var(k_w->Name())->Get<phi::DenseTensor>(), &k_w_fp32_t);
+  Assign(scope->Var(v_w->Name())->Get<phi::DenseTensor>(), &v_w_fp32_t);
+  CastToFp32(&q_w_fp32_t);
+  CastToFp32(&k_w_fp32_t);
+  CastToFp32(&v_w_fp32_t);
+  Transpose2D(&q_w_fp32_t);
+  Transpose2D(&k_w_fp32_t);
+  Transpose2D(&v_w_fp32_t);
+
+  phi::DenseTensor qkv_w_int16_t;
+  phi::DenseTensor qkv_w_max_t;
+  qkv_w_int16_t.Resize(
+      DDim({q_w_fp32_t.dims()[0] + k_w_fp32_t.dims()[0] + v_w_fp32_t.dims()[0],
+            q_w_fp32_t.dims()[1]}));
+  qkv_w_int16_t.set_type(q_w_fp32_t.type());
+  auto* cpu_ctx = static_cast<phi::CPUContext*>(
+      platform::DeviceContextPool::Instance().Get(phi::CPUPlace()));
+  std::vector<const phi::DenseTensor*> in_tensors{
+      &q_w_fp32_t, &k_w_fp32_t, &v_w_fp32_t};
+  phi::ConcatKernel<float>(*cpu_ctx, in_tensors, 0, &qkv_w_int16_t);
+  PrepareWeight<int16_t>(&qkv_w_int16_t, &qkv_w_max_t, false);
+
+  size_t qkv_w_int16_hash = HashTensor<int16_t>(qkv_w_int16_t);
+  size_t qkv_w_max_hash = HashTensor<float>(qkv_w_max_t);
+  std::string qkv_w_int16_name = std::to_string(qkv_w_int16_hash);
+  std::string qkv_w_max_name = std::to_string(qkv_w_max_hash);
+  *qkv_w_int16 = FindNodeWithName(graph, qkv_w_int16_name);
+  if (*qkv_w_int16 == nullptr) {
+    // Create qkv_w_int16 node
+    // Update qkv_w_int16 var_desc in block
+    VarDesc qkv_w_int16_desc(qkv_w_int16_name);
+    qkv_w_int16_desc.SetPersistable(true);
+    qkv_w_int16_desc.SetShape(vectorize(qkv_w_int16_t.dims()));
+    qkv_w_int16_desc.SetDataType(
+        framework::TransToProtoVarType(qkv_w_int16_t.dtype()));
+    *qkv_w_int16 = graph->CreateVarNode(&qkv_w_int16_desc);
+    auto* block_qkv_w_int16_desc = block->Var(qkv_w_int16_name);
+    block_qkv_w_int16_desc->SetPersistable(qkv_w_int16_desc.Persistable());
+    block_qkv_w_int16_desc->SetShape(qkv_w_int16_desc.GetShape());
+    block_qkv_w_int16_desc->SetDataType(qkv_w_int16_desc.GetDataType());
+    // Create qkv_w_max node
+    // Update qkv_w_max var_desc in block
+    VarDesc qkv_w_max_desc(qkv_w_max_name);
+    qkv_w_max_desc.SetPersistable(true);
+    qkv_w_max_desc.SetShape(vectorize(qkv_w_max_t.dims()));
+    qkv_w_max_desc.SetDataType(proto::VarType::Type::VarType_Type_FP32);
+    *qkv_w_max = graph->CreateVarNode(&qkv_w_max_desc);
+    auto* block_qkv_w_max_desc = block->Var(qkv_w_max_name);
+    block_qkv_w_max_desc->SetPersistable(qkv_w_max_desc.Persistable());
+    block_qkv_w_max_desc->SetShape(qkv_w_max_desc.GetShape());
+    block_qkv_w_max_desc->SetDataType(qkv_w_max_desc.GetDataType());
+    // Find qkv_w_int16/qkv_w_max variable in scope
+    auto* qkv_w_int16_var = scope->FindVar(qkv_w_int16_name);
+    if (qkv_w_int16_var == nullptr) {
+      // Create qkv_w_int16/qkv_w_max variable/tensor
+      Assign(qkv_w_int16_t,
+             scope->Var(qkv_w_int16_name)->GetMutable<phi::DenseTensor>());
+      Assign(qkv_w_max_t,
+             scope->Var(qkv_w_max_name)->GetMutable<phi::DenseTensor>());
+    } else {
+      // Share the same variable
+      PADDLE_ENFORCE_NOT_NULL(
+          scope->FindVar(qkv_w_max_name),
+          platform::errors::Fatal(
+              "qkv_w_max(%s) variable should not be nullptr if "
+              "qkv_w_int16(%s) variable is exist.",
+              qkv_w_max_name,
+              qkv_w_int16_name));
+    }
+  } else {
+    *qkv_w_max = FindNodeWithName(graph, qkv_w_max_name);
+    PADDLE_ENFORCE_NOT_NULL(
+        *qkv_w_max,
+        platform::errors::Fatal(
+            "qkv_w_max(%s) variable should not be nullptr if qkv_w_int16(%s) "
+            "variable is exist.",
+            qkv_w_max_name,
+            qkv_w_int16_name));
+  }
+}

-void MultiEncoderXPUFusePass::ConcatQKVBias(const phi::DenseTensor& q_bias,
-                                            const phi::DenseTensor& k_bias,
-                                            const phi::DenseTensor& v_bias,
-                                            phi::DenseTensor* qkv_bias) const {
-  int q_bias_size = q_bias.numel();
-  qkv_bias->Resize(DDim({q_bias_size * 3}));
-  qkv_bias->set_type(q_bias.type());
-  auto* dev_ctx = static_cast<phi::CPUContext*>(
-      platform::DeviceContextPool::Instance().Get(phi::CPUPlace()));
-  auto* qkv_bias_data = dev_ctx->Alloc<float>(qkv_bias);
-  memcpy(qkv_bias_data, q_bias.data(), q_bias_size * sizeof(float));
-  qkv_bias_data += q_bias_size;
-  memcpy(qkv_bias_data, k_bias.data(), q_bias_size * sizeof(float));
-  qkv_bias_data += q_bias_size;
-  memcpy(qkv_bias_data, v_bias.data(), q_bias_size * sizeof(float));
-}
+void MultiEncoderXPUFusePass::PrepareQKVBias(Graph* graph,
+                                             Scope* scope,
+                                             BlockDesc* block,
+                                             Node* q_bias,
+                                             Node* k_bias,
+                                             Node* v_bias,
+                                             Node** qkv_bias) const {
+  auto* q_bias_tensor =
+      scope->Var(q_bias->Name())->GetMutable<phi::DenseTensor>();
+  auto* k_bias_tensor =
+      scope->Var(k_bias->Name())->GetMutable<phi::DenseTensor>();
+  auto* v_bias_tensor =
+      scope->Var(v_bias->Name())->GetMutable<phi::DenseTensor>();
+  phi::DenseTensor q_bias_fp32_tensor;
+  phi::DenseTensor k_bias_fp32_tensor;
+  phi::DenseTensor v_bias_fp32_tensor;
+  CastToFp32(q_bias_tensor, &q_bias_fp32_tensor);
+  CastToFp32(k_bias_tensor, &k_bias_fp32_tensor);
+  CastToFp32(v_bias_tensor, &v_bias_fp32_tensor);
+
+  phi::DenseTensor qkv_bias_tensor;
+  int q_bias_fp32_size = q_bias_fp32_tensor.numel();
+  qkv_bias_tensor.Resize(DDim({q_bias_fp32_size * 3}));
+  qkv_bias_tensor.set_type(phi::DataType::FLOAT32);
+  auto* cpu_ctx = static_cast<phi::CPUContext*>(
+      platform::DeviceContextPool::Instance().Get(phi::CPUPlace()));
+  auto* qkv_bias_data = cpu_ctx->Alloc<float>(&qkv_bias_tensor);
+  memcpy(qkv_bias_data,
+         q_bias_fp32_tensor.data(),
+         q_bias_fp32_size * sizeof(float));
+  qkv_bias_data += q_bias_fp32_size;
+  memcpy(qkv_bias_data,
+         k_bias_fp32_tensor.data(),
+         q_bias_fp32_size * sizeof(float));
+  qkv_bias_data += q_bias_fp32_size;
+  memcpy(qkv_bias_data,
+         v_bias_fp32_tensor.data(),
+         q_bias_fp32_size * sizeof(float));
+
+  size_t qkv_bias_hash = HashTensor<float>(qkv_bias_tensor);
+  std::string qkv_bias_name = std::to_string(qkv_bias_hash);
+  *qkv_bias = FindNodeWithName(graph, qkv_bias_name);
+  if (*qkv_bias == nullptr) {
+    // Create qkv_bias node
+    // Update qkv_bias var_desc in block
+    VarDesc qkv_bias_desc(qkv_bias_name);
+    qkv_bias_desc.SetPersistable(true);
+    qkv_bias_desc.SetShape(vectorize(qkv_bias_tensor.dims()));
+    qkv_bias_desc.SetDataType(
+        framework::TransToProtoVarType(qkv_bias_tensor.dtype()));
+    *qkv_bias = graph->CreateVarNode(&qkv_bias_desc);
+    auto* block_qkv_bias_desc = block->Var(qkv_bias_name);
+    block_qkv_bias_desc->SetPersistable(qkv_bias_desc.Persistable());
+    block_qkv_bias_desc->SetShape(qkv_bias_desc.GetShape());
+    block_qkv_bias_desc->SetDataType(qkv_bias_desc.GetDataType());
+    Assign(qkv_bias_tensor,
+           scope->Var(qkv_bias_name)->GetMutable<phi::DenseTensor>());
+  }
+}

 int MultiEncoderXPUFusePass::ApplySingleEncoderXPUFuse(
@@ -856,109 +969,67 @@ int MultiEncoderXPUFusePass::ApplySingleEncoderXPUFuse(
       scope->FindVar(q_matmul_w->Name())->Get<phi::DenseTensor>().dtype() ==
       phi::DataType::FLOAT16;

   // Prepare q,k,v weight
-  std::string q_w_name = q_matmul_w->Name();
-  std::string k_w_name = k_matmul_w->Name();
-  std::string v_w_name = v_matmul_w->Name();
-  std::string qkv_w_name = q_w_name + "_" + k_w_name + "_" + v_w_name;
-  VarDesc qkv_w_desc(qkv_w_name);
-  qkv_w_desc.SetPersistable(true);
-  auto* qkv_w = graph->CreateVarNode(&qkv_w_desc);
-  auto* qkv_w_var = block->Var(qkv_w_name);
-  qkv_w_var->SetPersistable(true);
-  std::string qkv_w_max_name = qkv_w_name + "_max";
-  VarDesc qkv_w_max_desc(qkv_w_max_name);
-  qkv_w_max_desc.SetPersistable(true);
-  auto* qkv_w_max = graph->CreateVarNode(&qkv_w_max_desc);
-  auto* qkv_w_max_var = block->Var(qkv_w_max_name);
-  qkv_w_max_var->SetPersistable(true);
-  PrepareQKVWeight(scope->FindVar(q_w_name)->Get<phi::DenseTensor>(),
-                   scope->FindVar(k_w_name)->Get<phi::DenseTensor>(),
-                   scope->FindVar(v_w_name)->Get<phi::DenseTensor>(),
-                   scope->Var(qkv_w_name)->GetMutable<phi::DenseTensor>(),
-                   scope->Var(qkv_w_max_name)->GetMutable<phi::DenseTensor>());
+  Node* qkv_w_int16 = nullptr;
+  Node* qkv_w_max = nullptr;
+  PrepareQKVWeight(graph,
+                   scope,
+                   block,
+                   q_matmul_w,
+                   k_matmul_w,
+                   v_matmul_w,
+                   &qkv_w_int16,
+                   &qkv_w_max);

   // Prepare qkv_matmul_1_w, qkv_matmul_2_w, qkv_matmul_3_w
-#define PREPARE_QKV_MATMUL_W(idx_)                                            \
-  std::string qkv_matmul_##idx_##_w_name = qkv_matmul_##idx_##_w->Name();     \
-  std::string qkv_matmul_##idx_##_w_max_name =                                \
-      qkv_matmul_##idx_##_w_name + "_max";                                    \
-  VarDesc qkv_matmul_##idx_##_w_max_desc(qkv_matmul_##idx_##_w_max_name);     \
-  qkv_matmul_##idx_##_w_max_desc.SetPersistable(true);                        \
-  auto qkv_matmul_##idx_##_w_max =                                            \
-      graph->CreateVarNode(&qkv_matmul_##idx_##_w_max_desc);                  \
-  auto qkv_matmul_##idx_##_w_max_var =                                        \
-      block->Var(qkv_matmul_##idx_##_w_max_name);                             \
-  qkv_matmul_##idx_##_w_max_var->SetPersistable(true);                        \
-  auto qkv_matmul_##idx_##_w_max_tensor =                                     \
-      scope->Var(qkv_matmul_##idx_##_w_max_name)                              \
-          ->GetMutable<phi::DenseTensor>();                                   \
-  auto qkv_matmul_##idx_##_w_tensor =                                         \
-      scope->Var(qkv_matmul_##idx_##_w_name)->GetMutable<phi::DenseTensor>(); \
-  QuantWeight<int16_t>(                                                       \
-      qkv_matmul_##idx_##_w_tensor, qkv_matmul_##idx_##_w_max_tensor, true);
+#define PREPARE_QKV_MATMUL_W(idx_)                     \
+  Node* qkv_matmul_##idx_##_w_int16 = nullptr;         \
+  Node* qkv_matmul_##idx_##_w_max = nullptr;           \
+  PrepareWeight<int16_t>(graph,                        \
+                         scope,                        \
+                         block,                        \
+                         qkv_matmul_##idx_##_w,        \
+                         &qkv_matmul_##idx_##_w_int16, \
+                         &qkv_matmul_##idx_##_w_max,   \
+                         true);
   PREPARE_QKV_MATMUL_W(1);
   PREPARE_QKV_MATMUL_W(2);
   PREPARE_QKV_MATMUL_W(3);
 #undef PREPARE_QKV_MATMUL_W

   // Concat q_add_bias, k_add_bias, v_add_bias
-  std::string q_add_bias_name = q_add_bias->Name();
-  std::string k_add_bias_name = k_add_bias->Name();
-  std::string v_add_bias_name = v_add_bias->Name();
-  std::string qkv_add_bias_name =
-      q_add_bias_name + "_" + k_add_bias_name + "_" + v_add_bias_name;
-  VarDesc qkv_add_bias_desc(qkv_add_bias_name);
-  qkv_add_bias_desc.SetPersistable(true);
-  auto* qkv_add_bias = graph->CreateVarNode(&qkv_add_bias_desc);
-  auto* qkv_add_bias_var = block->Var(qkv_add_bias_name);
-  qkv_add_bias_var->SetPersistable(true);
-  auto* q_add_bias_tensor =
-      scope->FindVar(q_add_bias_name)->GetMutable<phi::DenseTensor>();
-  auto* k_add_bias_tensor =
-      scope->FindVar(k_add_bias_name)->GetMutable<phi::DenseTensor>();
-  auto* v_add_bias_tensor =
-      scope->FindVar(v_add_bias_name)->GetMutable<phi::DenseTensor>();
-  CastToFp32(q_add_bias_tensor);
-  CastToFp32(k_add_bias_tensor);
-  CastToFp32(v_add_bias_tensor);
-  ConcatQKVBias(*q_add_bias_tensor,
-                *k_add_bias_tensor,
-                *v_add_bias_tensor,
-                scope->Var(qkv_add_bias_name)->GetMutable<phi::DenseTensor>());
+  Node* qkv_add_bias_fp32 = nullptr;
+  PrepareQKVBias(graph,
+                 scope,
+                 block,
+                 q_add_bias,
+                 k_add_bias,
+                 v_add_bias,
+                 &qkv_add_bias_fp32);

   // Prepare qkv_add_0_bias, qkv_add_2_bias, qkv_add_3_bias
-  auto qkv_add_0_bias_name = qkv_add_0_bias->Name();
-  CastToFp32(
-      scope->FindVar(qkv_add_0_bias_name)->GetMutable<phi::DenseTensor>());
-  auto qkv_add_2_bias_name = qkv_add_2_bias->Name();
-  CastToFp32(
-      scope->FindVar(qkv_add_2_bias_name)->GetMutable<phi::DenseTensor>());
-  auto qkv_add_3_bias_name = qkv_add_3_bias->Name();
-  CastToFp32(
-      scope->FindVar(qkv_add_3_bias_name)->GetMutable<phi::DenseTensor>());
+  Node* qkv_add_0_bias_fp32 = nullptr;
+  Node* qkv_add_2_bias_fp32 = nullptr;
+  Node* qkv_add_3_bias_fp32 = nullptr;
+  PrepareBias(graph, scope, block, qkv_add_0_bias, &qkv_add_0_bias_fp32);
+  PrepareBias(graph, scope, block, qkv_add_2_bias, &qkv_add_2_bias_fp32);
+  PrepareBias(graph, scope, block, qkv_add_3_bias, &qkv_add_3_bias_fp32);

   // Generate single_encoder_xpu op
   framework::OpDesc op_desc(block);
   op_desc.SetType("single_encoder_xpu");
   op_desc.SetInput("x", {ln_0_x->Name()});
-  op_desc.SetInput("fc_weight",
-                   {qkv_w_name,
-                    qkv_matmul_1_w_name,
-                    qkv_matmul_2_w_name,
-                    qkv_matmul_3_w_name});
+  op_desc.SetInput("fc_weight",
+                   {qkv_w_int16->Name(),
+                    qkv_matmul_1_w_int16->Name(),
+                    qkv_matmul_2_w_int16->Name(),
+                    qkv_matmul_3_w_int16->Name()});
-  op_desc.SetInput("fc_weight_max",
-                   {qkv_w_max_name,
-                    qkv_matmul_1_w_max_name,
-                    qkv_matmul_2_w_max_name,
-                    qkv_matmul_3_w_max_name});
+  op_desc.SetInput("fc_weight_max",
+                   {qkv_w_max->Name(),
+                    qkv_matmul_1_w_max->Name(),
+                    qkv_matmul_2_w_max->Name(),
+                    qkv_matmul_3_w_max->Name()});
-  op_desc.SetInput("fc_bias",
-                   {qkv_add_bias_name,
-                    qkv_add_0_bias_name,
-                    qkv_add_2_bias_name,
-                    qkv_add_3_bias_name});
+  op_desc.SetInput("fc_bias",
+                   {qkv_add_bias_fp32->Name(),
+                    qkv_add_0_bias_fp32->Name(),
+                    qkv_add_2_bias_fp32->Name(),
+                    qkv_add_3_bias_fp32->Name()});
   if (norm_before) {
     op_desc.SetInput("ln_scale", {ln_0_scale->Name(), ln_1_scale->Name()});
     op_desc.SetInput("ln_bias", {ln_0_bias->Name(), ln_1_bias->Name()});
@@ -990,30 +1061,30 @@ int MultiEncoderXPUFusePass::ApplySingleEncoderXPUFuse(
   }
   auto* single_encoder_xpu = graph->CreateOpNode(&op_desc);
   // Link nodes
-  SAFE_IR_NODE_LINK_TO(ln_0_x, single_encoder_xpu);
-  SAFE_IR_NODE_LINK_TO(qkv_w, single_encoder_xpu);
-  SAFE_IR_NODE_LINK_TO(qkv_w_max, single_encoder_xpu);
-  SAFE_IR_NODE_LINK_TO(qkv_matmul_1_w, single_encoder_xpu);
-  SAFE_IR_NODE_LINK_TO(qkv_matmul_1_w_max, single_encoder_xpu);
-  SAFE_IR_NODE_LINK_TO(qkv_matmul_2_w, single_encoder_xpu);
-  SAFE_IR_NODE_LINK_TO(qkv_matmul_2_w_max, single_encoder_xpu);
-  SAFE_IR_NODE_LINK_TO(qkv_matmul_3_w, single_encoder_xpu);
-  SAFE_IR_NODE_LINK_TO(qkv_matmul_3_w_max, single_encoder_xpu);
-  SAFE_IR_NODE_LINK_TO(qkv_add_bias, single_encoder_xpu);
-  SAFE_IR_NODE_LINK_TO(qkv_add_0_bias, single_encoder_xpu);
-  SAFE_IR_NODE_LINK_TO(qkv_add_2_bias, single_encoder_xpu);
-  SAFE_IR_NODE_LINK_TO(qkv_add_3_bias, single_encoder_xpu);
+  IR_NODE_LINK_TO(ln_0_x, single_encoder_xpu);
+  IR_NODE_LINK_TO(qkv_w_int16, single_encoder_xpu);
+  IR_NODE_LINK_TO(qkv_w_max, single_encoder_xpu);
+  IR_NODE_LINK_TO(qkv_matmul_1_w_int16, single_encoder_xpu);
+  IR_NODE_LINK_TO(qkv_matmul_1_w_max, single_encoder_xpu);
+  IR_NODE_LINK_TO(qkv_matmul_2_w_int16, single_encoder_xpu);
+  IR_NODE_LINK_TO(qkv_matmul_2_w_max, single_encoder_xpu);
+  IR_NODE_LINK_TO(qkv_matmul_3_w_int16, single_encoder_xpu);
+  IR_NODE_LINK_TO(qkv_matmul_3_w_max, single_encoder_xpu);
+  IR_NODE_LINK_TO(qkv_add_bias_fp32, single_encoder_xpu);
+  IR_NODE_LINK_TO(qkv_add_0_bias_fp32, single_encoder_xpu);
+  IR_NODE_LINK_TO(qkv_add_2_bias_fp32, single_encoder_xpu);
+  IR_NODE_LINK_TO(qkv_add_3_bias_fp32, single_encoder_xpu);
   SAFE_IR_NODE_LINK_TO(ln_0_scale, single_encoder_xpu);
   SAFE_IR_NODE_LINK_TO(ln_0_bias, single_encoder_xpu);
-  SAFE_IR_NODE_LINK_TO(ln_1_scale, single_encoder_xpu);
-  SAFE_IR_NODE_LINK_TO(ln_1_bias, single_encoder_xpu);
+  IR_NODE_LINK_TO(ln_1_scale, single_encoder_xpu);
+  IR_NODE_LINK_TO(ln_1_bias, single_encoder_xpu);
   SAFE_IR_NODE_LINK_TO(ln_2_scale, single_encoder_xpu);
   SAFE_IR_NODE_LINK_TO(ln_2_bias, single_encoder_xpu);
   SAFE_IR_NODE_LINK_TO(qk_add_mask, single_encoder_xpu);
   if (norm_before) {
-    SAFE_IR_NODE_LINK_TO(single_encoder_xpu, qkv_add_4_out);
+    IR_NODE_LINK_TO(single_encoder_xpu, qkv_add_4_out);
   } else {
-    SAFE_IR_NODE_LINK_TO(single_encoder_xpu, ln_2_out);
+    IR_NODE_LINK_TO(single_encoder_xpu, ln_2_out);
   }
   // Delete nodes
paddle/fluid/framework/ir/xpu/pass_utils.cc

@@ -20,6 +20,18 @@ namespace paddle {
 namespace framework {
 namespace ir {

+static void HashCombine(std::size_t* seed) {}
+
+// combine hash value
+// https://stackoverflow.com/questions/2590677/how-do-i-combine-hash-values-in-c0x
+template <typename T, typename... Rest>
+static void HashCombine(std::size_t* seed, const T& v, Rest... rest) {
+  std::hash<T> hasher;
+  *seed ^= hasher(v) + 0x9e3779b9 + (*seed << 6) + (*seed >> 2);
+  *seed *= 0x00000100000001B3;
+  HashCombine(seed, rest...);
+}
+
 int ConvertActivationType(std::string act_type) {
   if (act_type == "") {
     return static_cast<int>(xpu::Activation_t::LINEAR);
@@ -50,6 +62,161 @@ int ConvertActivationType(std::string act_type) {
   return -1;
 }

+Node* FindNodeWithName(Graph* graph, std::string name) {
+  for (auto* node : graph->Nodes()) {
+    if (node->IsVar() && node->Var()->Name() == name) {
+      return node;
+    }
+  }
+  return nullptr;
+}
+
+template <typename T>
+std::string IntTypeToString() {
+  LOG(FATAL) << "Not support type.";
+  return "";
+}
+
+template <>
+std::string IntTypeToString<int16_t>() {
+  return "int16";
+}
+
+template <typename T>
+size_t HashTensor(const phi::DenseTensor& in) {
+  size_t ret = 0;
+  auto in_dims = in.dims();
+  HashCombine(&ret,
+              phi::DataTypeToString(in.dtype()),
+              phi::DataLayoutToString(in.layout()),
+              in_dims.size());
+  for (int i = 0; i < in_dims.size(); i++) {
+    HashCombine(&ret, in_dims[i]);
+  }
+  auto* data = in.data<T>();
+  int64_t size = in.numel();
+  for (int64_t i = 0; i < size; i++) {
+    HashCombine(&ret, data[i]);
+  }
+  return ret;
+}
+
+template size_t HashTensor<int16_t>(const phi::DenseTensor& in);
+template size_t HashTensor<float>(const phi::DenseTensor& in);
+
+template <typename T>
+void PrepareWeight(Graph* graph,
+                   Scope* scope,
+                   BlockDesc* block,
+                   Node* src,
+                   Node** dst,
+                   Node** dst_max,
+                   bool transpose) {
+  auto src_name = src->Name();
+  auto* src_tensor = scope->Var(src_name)->GetMutable<phi::DenseTensor>();
+  phi::DenseTensor dst_tensor;
+  Assign(*src_tensor, &dst_tensor);
+  phi::DenseTensor dst_max_tensor;
+  PrepareWeight<T>(&dst_tensor, &dst_max_tensor, transpose);
+
+  size_t dst_hash = HashTensor<T>(dst_tensor);
+  size_t dst_max_hash = HashTensor<float>(dst_max_tensor);
+  std::string dst_name = src_name + "_" + std::to_string(dst_hash);
+  std::string dst_max_name = src_name + "_max_" + std::to_string(dst_max_hash);
+  *dst = FindNodeWithName(graph, dst_name);
+  if (*dst == nullptr) {
+    // Create dst node
+    // Update dst var_desc in block
+    VarDesc dst_desc(dst_name);
+    dst_desc.SetPersistable(true);
+    dst_desc.SetShape(vectorize(dst_tensor.dims()));
+    dst_desc.SetDataType(framework::TransToProtoVarType(dst_tensor.dtype()));
+    *dst = graph->CreateVarNode(&dst_desc);
+    auto* block_dst_desc = block->Var(dst_name);
+    block_dst_desc->SetPersistable(dst_desc.Persistable());
+    block_dst_desc->SetShape(dst_desc.GetShape());
+    block_dst_desc->SetDataType(dst_desc.GetDataType());
+    // Create dst_max node
+    // Update dst_max var_desc in block
+    VarDesc dst_max_desc(dst_max_name);
+    dst_max_desc.SetPersistable(true);
+    dst_max_desc.SetShape(vectorize(dst_max_tensor.dims()));
+    dst_max_desc.SetDataType(proto::VarType::Type::VarType_Type_FP32);
+    *dst_max = graph->CreateVarNode(&dst_max_desc);
+    auto* block_dst_max_desc = block->Var(dst_max_name);
+    block_dst_max_desc->SetPersistable(dst_max_desc.Persistable());
+    block_dst_max_desc->SetShape(dst_max_desc.GetShape());
+    block_dst_max_desc->SetDataType(dst_max_desc.GetDataType());
+    // Find dst/dst_max variable in scope
+    auto* dst_var = scope->FindVar(dst_name);
+    if (dst_var == nullptr) {
+      // Create dst/dst_max variable/tensor
+      Assign(dst_tensor, scope->Var(dst_name)->GetMutable<phi::DenseTensor>());
+      Assign(dst_max_tensor,
+             scope->Var(dst_max_name)->GetMutable<phi::DenseTensor>());
+    } else {
+      // Share the same variable
+      PADDLE_ENFORCE_NOT_NULL(
+          scope->FindVar(dst_max_name),
+          platform::errors::Fatal(
+              "dst_max(%s) variable should not be nullptr if dst(%s) "
+              "variable is exist. (src_name is %s)",
+              dst_max_name,
+              dst_name,
+              src_name));
+    }
+  } else {
+    *dst_max = FindNodeWithName(graph, dst_max_name);
+    PADDLE_ENFORCE_NOT_NULL(
+        *dst_max,
+        platform::errors::Fatal(
+            "dst_max(%s) variable should not be nullptr if dst(%s) "
+            "variable is exist. (src_name is %s)",
+            dst_max_name,
+            dst_name,
+            src_name));
+  }
+}
+
+template void PrepareWeight<int16_t>(Graph* graph,
+                                     Scope* scope,
+                                     BlockDesc* block,
+                                     Node* src,
+                                     Node** dst,
+                                     Node** dst_max,
+                                     bool transpose);
+
+void PrepareBias(
+    Graph* graph, Scope* scope, BlockDesc* block, Node* src, Node** dst) {
+  auto src_name = src->Name();
+  auto* src_tensor = scope->Var(src_name)->GetMutable<phi::DenseTensor>();
+  if (src_tensor->dtype() == phi::DataType::FLOAT32) {
+    *dst = src;
+  }
+
+  phi::DenseTensor dst_tensor;
+  CastToFp32(src_tensor, &dst_tensor);
+  size_t dst_hash = HashTensor<float>(dst_tensor);
+  std::string dst_name = src_name + "_" + std::to_string(dst_hash);
+  *dst = FindNodeWithName(graph, dst_name);
+  if (*dst == nullptr) {
+    // Create dst node
+    // Update dst var_desc in block
+    VarDesc dst_desc(dst_name);
+    dst_desc.SetPersistable(true);
+    dst_desc.SetShape(vectorize(dst_tensor.dims()));
+    dst_desc.SetDataType(framework::TransToProtoVarType(dst_tensor.dtype()));
+    *dst = graph->CreateVarNode(&dst_desc);
+    auto* block_dst_desc = block->Var(dst_name);
+    block_dst_desc->SetPersistable(dst_desc.Persistable());
+    block_dst_desc->SetShape(dst_desc.GetShape());
+    block_dst_desc->SetDataType(dst_desc.GetDataType());
+    Assign(dst_tensor, scope->Var(dst_name)->GetMutable<phi::DenseTensor>());
+  }
+}
+
 }  // namespace ir
 }  // namespace framework
 }  // namespace paddle
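The weight sharing above hinges on naming the converted weights by content hash: two sources with identical dtype, layout, dims and data fold to the same digest, so the second PrepareWeight call finds the node created by the first one instead of duplicating it. A self-contained sketch of the combining step (standalone program, not Paddle code; the sample values are arbitrary):

#include <cstddef>
#include <cstdio>
#include <functional>

// Same combiner as pass_utils.cc: fold each value into the running seed.
template <typename T>
void HashCombine(std::size_t* seed, const T& v) {
  std::hash<T> hasher;
  *seed ^= hasher(v) + 0x9e3779b9 + (*seed << 6) + (*seed >> 2);
  *seed *= 0x00000100000001B3;
}

int main() {
  std::size_t a = 0, b = 0;
  for (float x : {1.0f, 1.0f, 1.0f}) HashCombine(&a, x);  // "tensor" 1
  for (float x : {1.0f, 1.0f, 1.0f}) HashCombine(&b, x);  // identical "tensor" 2
  std::printf("%d\n", a == b);  // prints 1: identical content, identical name suffix
  return 0;
}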
paddle/fluid/framework/ir/xpu/pass_utils.h

@@ -14,6 +14,10 @@
 #pragma once

 #include <string>
+#include "paddle/fluid/framework/convert_utils.h"
 #include "paddle/fluid/framework/ir/pass.h"
+#include "paddle/fluid/framework/ir/xpu/quant_utils.h"
+#include "paddle/fluid/framework/scope.h"

 namespace paddle {
 namespace framework {
@@ -45,6 +49,23 @@ namespace ir {
 int ConvertActivationType(std::string act_type);

+Node* FindNodeWithName(Graph* graph, std::string name);
+
+template <typename T>
+size_t HashTensor(const phi::DenseTensor& in);
+
+template <typename T>
+void PrepareWeight(Graph* graph,
+                   Scope* scope,
+                   BlockDesc* block,
+                   Node* src,
+                   Node** dst,
+                   Node** dst_max,
+                   bool transpose);
+
+void PrepareBias(
+    Graph* graph, Scope* scope, BlockDesc* block, Node* src, Node** dst);
+
 }  // namespace ir
 }  // namespace framework
 }  // namespace paddle
paddle/fluid/framework/ir/xpu/quant_utils.cc

@@ -207,9 +207,9 @@ void QuantFP32ToIntX<int16_t>(const float* src_ptr,
 }

 template <typename T>
-void QuantWeight(phi::DenseTensor* weight,
-                 phi::DenseTensor* weight_max,
-                 bool transpose) {
+void PrepareWeight(phi::DenseTensor* weight,
+                   phi::DenseTensor* weight_max,
+                   bool transpose) {
   // Convert fp16 to fp32
   phi::DenseTensor weight_fp32;
   CastToFp32(weight, &weight_fp32);
@@ -249,9 +249,9 @@ void QuantWeight(phi::DenseTensor* weight,
   QuantFP32ToIntX(weight_data, cpu_ctx->Alloc<T>(weight), max_val, size);
 }

-template void QuantWeight<int16_t>(phi::DenseTensor* weight,
-                                   phi::DenseTensor* weight_max,
-                                   bool transpose);
+template void PrepareWeight<int16_t>(phi::DenseTensor* weight,
+                                     phi::DenseTensor* weight_max,
+                                     bool transpose);

 }  // namespace ir
 }  // namespace framework
paddle/fluid/framework/ir/xpu/quant_utils.h

@@ -29,9 +29,9 @@ void CastToFp32(phi::DenseTensor* in, phi::DenseTensor* out = nullptr);
 // 2. Weight data is in-place update.
 // 3. Generate weight max tensor
 template <typename T>
-void QuantWeight(phi::DenseTensor* weight,
-                 phi::DenseTensor* weight_max,
-                 bool transpose);
+void PrepareWeight(phi::DenseTensor* weight,
+                   phi::DenseTensor* weight_max,
+                   bool transpose);

 }  // namespace ir
 }  // namespace framework
paddle/fluid/inference/api/paddle_pass_builder.cc

@@ -524,6 +524,7 @@ XpuPassStrategy::XpuPassStrategy() : PassStrategy({}) {
       "fc_xpu_fuse_pass",
       "link_xpu_op_max_pass",
       "delete_op_device_pass",
+      "delete_isolated_node_pass",
   });
   use_xpu_ = true;
 }
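With delete_isolated_node_pass appended to XpuPassStrategy, it runs by default during XPU inference graph optimization. A hedged sketch of how a user could inspect or disable it through the C++ inference Config (the model paths are placeholders, and the exact Config/pass_builder calls are assumed from the public paddle_infer interface, not part of this commit):

#include "paddle/fluid/inference/api/paddle_inference_api.h"

int main() {
  // Placeholder model files; replace with a real inference program.
  paddle_infer::Config config("model.pdmodel", "model.pdiparams");
  config.EnableXpu();  // selects XpuPassStrategy, which now ends with delete_isolated_node_pass
  // Optional: drop the pass, e.g. when debugging a graph that keeps raw weights on purpose.
  config.pass_builder()->DeletePass("delete_isolated_node_pass");
  auto predictor = paddle_infer::CreatePredictor(config);
  return 0;
}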
python/paddle/fluid/tests/unittests/ir/inference/test_xpu_link_xpu_op_max_pass.py

@@ -15,6 +15,7 @@
 import unittest

 import hypothesis.strategies as st
+import numpy as np
 from auto_scan_test import PassAutoScanTest
 from program_config import OpConfig, ProgramConfig, TensorConfig
@@ -33,7 +34,7 @@ class TestFcXPUFusePass(PassAutoScanTest):
         )
         matmul0_y_shape = draw(
             st.lists(
-                st.integers(min_value=1, max_value=8), min_size=2, max_size=2
+                st.integers(min_value=2, max_value=8), min_size=2, max_size=2
             )
         )
         matmul0_y_shape[0] = matmul0_x_shape[-1]
@@ -42,7 +43,7 @@ class TestFcXPUFusePass(PassAutoScanTest):
         # 3. matmul1
         matmul1_y_shape = draw(
             st.lists(
-                st.integers(min_value=1, max_value=8), min_size=2, max_size=2
+                st.integers(min_value=2, max_value=8), min_size=2, max_size=2
            )
        )
        matmul1_y_shape[0] = matmul0_y_shape[-1]
@@ -101,4 +102,5 @@ class TestFcXPUFusePass(PassAutoScanTest):
 if __name__ == "__main__":
+    np.random.seed(200)
     unittest.main()