Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
39a9abaa
P
Paddle
项目概览
PaddlePaddle
/
Paddle
大约 1 年 前同步成功
通知
2299
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
39a9abaa
编写于
3月 07, 2023
作者:
Z
zhupengyang
提交者:
GitHub
3月 07, 2023
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
[XPU] support shared weight; delete isolated node (#51108)
上级
50ad760c
变更
14
展开全部
隐藏空白更改
内联
并排
Showing
14 changed file
with
863 addition
and
223 deletion
+863
-223
paddle/fluid/framework/ir/CMakeLists.txt
paddle/fluid/framework/ir/CMakeLists.txt
+10
-1
paddle/fluid/framework/ir/delete_op_device_pass.cc
paddle/fluid/framework/ir/delete_op_device_pass.cc
+4
-4
paddle/fluid/framework/ir/delete_op_device_pass_test.cc
paddle/fluid/framework/ir/delete_op_device_pass_test.cc
+1
-2
paddle/fluid/framework/ir/pass.cc
paddle/fluid/framework/ir/pass.cc
+1
-0
paddle/fluid/framework/ir/xpu/delete_isolated_node_pass.cc
paddle/fluid/framework/ir/xpu/delete_isolated_node_pass.cc
+203
-0
paddle/fluid/framework/ir/xpu/delete_isolated_node_pass_test.cc
.../fluid/framework/ir/xpu/delete_isolated_node_pass_test.cc
+181
-0
paddle/fluid/framework/ir/xpu/fc_xpu_fuse_pass.cc
paddle/fluid/framework/ir/xpu/fc_xpu_fuse_pass.cc
+30
-45
paddle/fluid/framework/ir/xpu/multi_encoder_xpu_fuse_pass.cc
paddle/fluid/framework/ir/xpu/multi_encoder_xpu_fuse_pass.cc
+231
-160
paddle/fluid/framework/ir/xpu/pass_utils.cc
paddle/fluid/framework/ir/xpu/pass_utils.cc
+167
-0
paddle/fluid/framework/ir/xpu/pass_utils.h
paddle/fluid/framework/ir/xpu/pass_utils.h
+21
-0
paddle/fluid/framework/ir/xpu/quant_utils.cc
paddle/fluid/framework/ir/xpu/quant_utils.cc
+6
-6
paddle/fluid/framework/ir/xpu/quant_utils.h
paddle/fluid/framework/ir/xpu/quant_utils.h
+3
-3
paddle/fluid/inference/api/paddle_pass_builder.cc
paddle/fluid/inference/api/paddle_pass_builder.cc
+1
-0
python/paddle/fluid/tests/unittests/ir/inference/test_xpu_link_xpu_op_max_pass.py
...s/unittests/ir/inference/test_xpu_link_xpu_op_max_pass.py
+4
-2
未找到文件。
paddle/fluid/framework/ir/CMakeLists.txt
浏览文件 @
39a9abaa
...
@@ -220,7 +220,7 @@ if(WITH_XPU)
...
@@ -220,7 +220,7 @@ if(WITH_XPU)
cc_library
(
cc_library
(
xpu_pass_utils
xpu_pass_utils
SRCS xpu/pass_utils.cc
SRCS xpu/pass_utils.cc
DEPS pass
)
DEPS pass
xpu_quant_utils
)
set
(
XPU_PASS_DEPS xpu_quant_utils xpu_pass_utils
)
set
(
XPU_PASS_DEPS xpu_quant_utils xpu_pass_utils
)
pass_library
(
embedding_with_eltwise_add_xpu_fuse_pass inference DIR xpu DEPS
pass_library
(
embedding_with_eltwise_add_xpu_fuse_pass inference DIR xpu DEPS
${
XPU_PASS_DEPS
}
)
${
XPU_PASS_DEPS
}
)
...
@@ -232,6 +232,8 @@ if(WITH_XPU)
...
@@ -232,6 +232,8 @@ if(WITH_XPU)
pass_library
(
generate_sequence_xpu_fuse_pass inference DIR xpu DEPS
pass_library
(
generate_sequence_xpu_fuse_pass inference DIR xpu DEPS
${
XPU_PASS_DEPS
}
)
${
XPU_PASS_DEPS
}
)
pass_library
(
link_xpu_op_max_pass inference DIR xpu DEPS
${
XPU_PASS_DEPS
}
)
pass_library
(
link_xpu_op_max_pass inference DIR xpu DEPS
${
XPU_PASS_DEPS
}
)
pass_library
(
delete_isolated_node_pass inference DIR xpu DEPS
${
XPU_PASS_DEPS
}
)
endif
()
endif
()
cc_library
(
cc_library
(
...
@@ -484,3 +486,10 @@ if(WITH_MKLDNN)
...
@@ -484,3 +486,10 @@ if(WITH_MKLDNN)
SRCS mkldnn/cpu_bfloat16_pass_tester.cc
SRCS mkldnn/cpu_bfloat16_pass_tester.cc
DEPS cpu_bfloat16_pass
)
DEPS cpu_bfloat16_pass
)
endif
()
endif
()
if
(
WITH_XPU
)
cc_test
(
test_delete_isolated_node_pass
SRCS xpu/delete_isolated_node_pass_test.cc
DEPS delete_isolated_node_pass
)
endif
()
paddle/fluid/framework/ir/delete_op_device_pass.cc
浏览文件 @
39a9abaa
...
@@ -39,14 +39,14 @@ class DeleteOpDevicePass : public Pass {
...
@@ -39,14 +39,14 @@ class DeleteOpDevicePass : public Pass {
void
DeleteOpDevicePass
::
ApplyImpl
(
ir
::
Graph
*
graph
)
const
{
void
DeleteOpDevicePass
::
ApplyImpl
(
ir
::
Graph
*
graph
)
const
{
PADDLE_ENFORCE_NOT_NULL
(
PADDLE_ENFORCE_NOT_NULL
(
graph
,
platform
::
errors
::
PreconditionNotMet
(
"graph should not be null."
));
graph
,
platform
::
errors
::
PreconditionNotMet
(
"graph should not be null."
));
int
found_subgraph_count
=
0
;
int
delete_counts
=
0
;
for
(
auto
*
node
:
graph
->
Nodes
())
{
for
(
auto
*
node
:
graph
->
Nodes
())
{
if
(
!
node
->
IsOp
()
||
!
node
->
Op
()
->
HasAttr
(
"op_device"
))
continue
;
if
(
!
node
->
IsOp
()
||
!
node
->
Op
()
->
HasAttr
(
"op_device"
))
continue
;
node
->
Op
()
->
RemoveAttr
(
"op_device"
);
node
->
Op
()
->
RemoveAttr
(
"op_device"
);
found_subgraph_count
++
;
delete_counts
++
;
}
}
if
(
found_subgraph_count
>
0
)
{
if
(
delete_counts
>
0
)
{
LOG
(
INFO
)
<<
"--- de
tected "
<<
found_subgraph_count
<<
" subgraphs
"
;
LOG
(
INFO
)
<<
"--- de
lete "
<<
delete_counts
<<
" op_device attr
"
;
}
}
}
}
...
...
paddle/fluid/framework/ir/delete_op_device_pass_test.cc
浏览文件 @
39a9abaa
...
@@ -13,8 +13,7 @@
...
@@ -13,8 +13,7 @@
// limitations under the License.
// limitations under the License.
#include <gtest/gtest.h>
#include <gtest/gtest.h>
#include "paddle/fluid/framework/ir/pass.h"
#include "paddle/fluid/framework/ir/delete_dropout_op_pass.h"
#include "paddle/fluid/framework/ir/pass_tester_helper.h"
#include "paddle/fluid/framework/ir/pass_tester_helper.h"
namespace
paddle
{
namespace
paddle
{
...
...
paddle/fluid/framework/ir/pass.cc
浏览文件 @
39a9abaa
...
@@ -49,6 +49,7 @@ static const std::vector<std::string> support_subgraph_passes = {
...
@@ -49,6 +49,7 @@ static const std::vector<std::string> support_subgraph_passes = {
"fuse_multi_transformer_layer_pass"
,
"fuse_multi_transformer_layer_pass"
,
"delete_quant_dequant_linear_op_pass"
,
"delete_quant_dequant_linear_op_pass"
,
"delete_weight_dequant_linear_op_pass"
,
"delete_weight_dequant_linear_op_pass"
,
"fc_xpu_fuse_pass"
,
"delete_op_device_pass"
};
"delete_op_device_pass"
};
Graph
*
Pass
::
Apply
(
Graph
*
graph
)
const
{
Graph
*
Pass
::
Apply
(
Graph
*
graph
)
const
{
...
...
paddle/fluid/framework/ir/xpu/delete_isolated_node_pass.cc
0 → 100644
浏览文件 @
39a9abaa
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <string>
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
#include "paddle/fluid/framework/ir/pass.h"
#include "paddle/fluid/framework/ir/xpu/pass_utils.h"
#include "paddle/fluid/framework/scope.h"
namespace
phi
{
class
DenseTensor
;
}
// namespace phi
namespace
paddle
{
namespace
framework
{
class
Scope
;
}
// namespace framework
}
// namespace paddle
namespace
paddle
{
namespace
framework
{
namespace
ir
{
class
DeleteIsolatedNodePass
:
public
Pass
{
protected:
void
ApplyImpl
(
Graph
*
graph
)
const
override
;
private:
void
CollectReservedPersistableNodeNames
(
Graph
*
graph
,
std
::
unordered_set
<
std
::
string
>*
reserved_persistable_node_names
)
const
;
int
RemoveIsolatedNodes
(
Graph
*
graph
,
const
std
::
unordered_set
<
std
::
string
>&
reserved_persistable_node_names
,
std
::
unordered_set
<
std
::
string
>*
delete_node_names
)
const
;
int
UpdateControlFlowOp
(
Graph
*
graph
,
const
std
::
map
<
int
,
Graph
*>&
block_id_graph_map
,
const
std
::
unordered_set
<
std
::
string
>&
delete_node_names
)
const
;
const
std
::
map
<
std
::
string
,
std
::
string
>
control_flow_op_input_map_
{
{
"while"
,
"X"
},
{
"conditional_block"
,
"Input"
},
};
};
void
DeleteIsolatedNodePass
::
ApplyImpl
(
Graph
*
graph
)
const
{
PADDLE_ENFORCE_NOT_NULL
(
graph
,
platform
::
errors
::
PreconditionNotMet
(
"graph should not be null."
));
PADDLE_ENFORCE
(
graph
->
IsMainGraph
(),
platform
::
errors
::
PreconditionNotMet
(
"Pass(apply in main graph) will delete isolated nodes in "
"all subgraphs. Do not apply pass in subgraph."
));
std
::
unordered_set
<
std
::
string
>
reserved_persistable_node_names
;
for
(
size_t
i
=
0
;
i
<
graph
->
SubGraphsSize
();
i
++
)
{
CollectReservedPersistableNodeNames
(
graph
->
GetSubGraph
(
i
),
&
reserved_persistable_node_names
);
}
int
delete_counts
=
0
;
std
::
unordered_set
<
std
::
string
>
delete_node_names
;
for
(
size_t
i
=
0
;
i
<
graph
->
SubGraphsSize
();
i
++
)
{
delete_counts
+=
RemoveIsolatedNodes
(
graph
->
GetSubGraph
(
i
),
reserved_persistable_node_names
,
&
delete_node_names
);
}
if
(
delete_counts
>
0
)
{
LOG
(
INFO
)
<<
"--- delete "
<<
delete_counts
<<
" isolated nodes"
;
}
std
::
map
<
int
,
Graph
*>
block_id_graph_map
;
for
(
size_t
i
=
0
;
i
<
graph
->
SubGraphsSize
();
i
++
)
{
auto
*
sub_graph
=
graph
->
GetSubGraph
(
i
);
for
(
auto
*
node
:
sub_graph
->
Nodes
())
{
if
(
node
->
IsVar
())
{
block_id_graph_map
[
node
->
GetVarNodeBlockId
()]
=
sub_graph
;
break
;
}
}
}
int
update_counts
=
0
;
for
(
size_t
i
=
0
;
i
<
graph
->
SubGraphsSize
();
i
++
)
{
update_counts
+=
UpdateControlFlowOp
(
graph
->
GetSubGraph
(
i
),
block_id_graph_map
,
delete_node_names
);
}
if
(
update_counts
>
0
)
{
LOG
(
INFO
)
<<
"--- update "
<<
update_counts
<<
" control flow ops"
;
}
}
void
DeleteIsolatedNodePass
::
CollectReservedPersistableNodeNames
(
Graph
*
graph
,
std
::
unordered_set
<
std
::
string
>*
reserved_persistable_node_names
)
const
{
for
(
auto
*
node
:
graph
->
Nodes
())
{
if
(
!
node
->
IsVar
()
||
!
node
->
Var
()
->
Persistable
())
continue
;
for
(
auto
*
out_node
:
node
->
outputs
)
{
auto
op_type
=
out_node
->
Op
()
->
Type
();
if
(
control_flow_op_input_map_
.
count
(
op_type
)
==
0
)
{
reserved_persistable_node_names
->
insert
(
node
->
Var
()
->
Name
());
break
;
}
}
}
}
int
DeleteIsolatedNodePass
::
RemoveIsolatedNodes
(
Graph
*
graph
,
const
std
::
unordered_set
<
std
::
string
>&
reserved_persistable_node_names
,
std
::
unordered_set
<
std
::
string
>*
delete_node_names
)
const
{
BlockDesc
*
block
=
nullptr
;
for
(
auto
*
node
:
graph
->
Nodes
())
{
if
(
node
->
IsOp
())
{
block
=
node
->
Op
()
->
Block
();
}
}
Scope
&
scope
=
graph
->
Get
<
framework
::
Scope
>
(
"__param_scope__"
);
// If graph has nodes to delete:
// 1. Clear var_desc in block
// 2. Clear tensor in variable
// 3. Clear variable in scope
int
delete_node_counts
=
0
;
std
::
unordered_set
<
const
Node
*>
delete_nodes
;
const
std
::
unordered_set
<
ir
::
Node
*>
nodes
=
graph
->
Nodes
();
for
(
auto
*
node
:
nodes
)
{
if
(
!
node
->
IsVar
()
||
!
node
->
Var
()
->
Persistable
())
continue
;
auto
name
=
node
->
Var
()
->
Name
();
if
(
reserved_persistable_node_names
.
count
(
name
)
>
0
)
continue
;
delete_nodes
.
insert
(
node
);
delete_node_names
->
insert
(
node
->
Name
());
block
->
RemoveVar
(
name
);
auto
*
var
=
scope
.
FindVar
(
name
);
if
(
var
!=
nullptr
)
{
var
->
Clear
();
scope
.
EraseVars
({
name
});
}
delete_node_counts
++
;
}
GraphSafeRemoveNodes
(
graph
,
delete_nodes
);
return
delete_node_counts
;
}
int
DeleteIsolatedNodePass
::
UpdateControlFlowOp
(
Graph
*
graph
,
const
std
::
map
<
int
,
Graph
*>&
block_id_graph_map
,
const
std
::
unordered_set
<
std
::
string
>&
delete_node_names
)
const
{
int
update_counts
=
0
;
for
(
auto
*
node
:
graph
->
Nodes
())
{
if
(
!
node
->
IsOp
())
continue
;
auto
op_type
=
node
->
Op
()
->
Type
();
if
(
control_flow_op_input_map_
.
count
(
op_type
)
==
0
)
continue
;
auto
in_arg_name
=
control_flow_op_input_map_
.
at
(
op_type
);
auto
in_name
=
node
->
Op
()
->
Input
(
in_arg_name
);
std
::
unordered_set
<
std
::
string
>
in_names_set
(
in_name
.
begin
(),
in_name
.
end
());
for
(
auto
delete_node_name
:
delete_node_names
)
{
if
(
in_names_set
.
count
(
delete_node_name
)
>
0
)
{
in_names_set
.
erase
(
delete_node_name
);
}
}
auto
*
sub_block
=
PADDLE_GET_CONST
(
framework
::
BlockDesc
*
,
node
->
Op
()
->
GetAttr
(
"sub_block"
));
auto
*
sub_graph
=
block_id_graph_map
.
at
(
sub_block
->
ID
());
std
::
unordered_set
<
std
::
string
>
sub_persistable_node_names
;
CollectReservedPersistableNodeNames
(
sub_graph
,
&
sub_persistable_node_names
);
for
(
auto
sub_name
:
sub_persistable_node_names
)
{
if
(
in_names_set
.
count
(
sub_name
)
>
0
)
continue
;
auto
*
in_node
=
FindNodeWithName
(
graph
,
sub_name
);
if
(
in_node
==
nullptr
)
continue
;
in_names_set
.
insert
(
sub_name
);
IR_NODE_LINK_TO
(
in_node
,
node
);
}
std
::
vector
<
std
::
string
>
new_in_names
(
in_names_set
.
begin
(),
in_names_set
.
end
());
node
->
Op
()
->
SetInput
(
in_arg_name
,
new_in_names
);
update_counts
++
;
}
return
update_counts
;
}
}
// namespace ir
}
// namespace framework
}
// namespace paddle
REGISTER_PASS
(
delete_isolated_node_pass
,
paddle
::
framework
::
ir
::
DeleteIsolatedNodePass
);
paddle/fluid/framework/ir/xpu/delete_isolated_node_pass_test.cc
0 → 100644
浏览文件 @
39a9abaa
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <gtest/gtest.h>
#include "paddle/fluid/framework/ir/pass.h"
#include "paddle/fluid/framework/ir/pass_tester_helper.h"
namespace
paddle
{
namespace
framework
{
namespace
ir
{
VarDesc
*
Data
(
paddle
::
framework
::
BlockDesc
*
block
,
std
::
string
name
,
std
::
vector
<
int64_t
>
shape
=
{},
bool
is_persistable
=
false
,
proto
::
VarType
::
Type
data_type
=
proto
::
VarType
::
FP32
)
{
auto
*
var
=
block
->
Var
(
name
);
var
->
SetType
(
proto
::
VarType
::
LOD_TENSOR
);
var
->
SetDataType
(
data_type
);
var
->
SetShape
(
shape
);
var
->
SetPersistable
(
is_persistable
);
return
var
;
}
void
AddVarToScope
(
Scope
*
param_scope
,
const
std
::
string
&
name
,
const
DDim
&
dims
)
{
auto
*
tensor
=
param_scope
->
Var
(
name
)
->
GetMutable
<
phi
::
DenseTensor
>
();
tensor
->
Resize
(
dims
);
auto
*
cpu_ctx
=
static_cast
<
phi
::
CPUContext
*>
(
platform
::
DeviceContextPool
::
Instance
().
Get
(
phi
::
CPUPlace
()));
auto
*
data
=
cpu_ctx
->
Alloc
<
float
>
(
tensor
);
int64_t
numel
=
tensor
->
numel
();
for
(
int64_t
i
=
0
;
i
<
numel
;
++
i
)
{
data
[
i
]
=
1
;
}
}
Scope
*
CreateParamScope
()
{
auto
param_scope
=
new
Scope
();
AddVarToScope
(
param_scope
,
"matmul0_w"
,
{
128
,
128
});
return
param_scope
;
}
int
WeightNodeNum
(
ir
::
Graph
*
graph
)
{
int
num
=
0
;
for
(
auto
node
:
graph
->
Nodes
())
{
if
(
node
->
IsVar
()
&&
node
->
Var
()
->
Persistable
())
{
num
++
;
}
}
return
num
;
}
int
WeightTensorNum
(
Scope
*
scope
)
{
int
num
=
0
;
auto
vars
=
scope
->
LocalVars
();
for
(
auto
*
var
:
vars
)
{
if
(
var
->
Get
<
phi
::
DenseTensor
>
().
numel
()
>
0
)
{
num
++
;
}
}
return
num
;
}
TEST
(
delete_isolated_node_pass
,
basic
)
{
paddle
::
framework
::
ProgramDesc
program
;
auto
*
block0
=
program
.
MutableBlock
(
0
);
auto
*
block1
=
program
.
AppendBlock
(
*
block0
);
auto
*
matmul0_x
=
Data
(
block0
,
"matmul0_x"
,
{
1
,
128
});
auto
*
matmul0_w
=
Data
(
block0
,
"matmul0_w"
,
{
128
,
128
},
true
);
auto
*
matmul0_out
=
Data
(
block0
,
"matmul0_out"
,
{
1
,
128
});
OpDesc
*
matmul_op
=
block0
->
AppendOp
();
matmul_op
->
SetType
(
"matmul_v2"
);
matmul_op
->
SetInput
(
"X"
,
{
matmul0_x
->
Name
()});
matmul_op
->
SetInput
(
"Y"
,
{
matmul0_w
->
Name
()});
matmul_op
->
SetAttr
(
"trans_x"
,
false
);
matmul_op
->
SetAttr
(
"trans_y"
,
false
);
matmul_op
->
SetOutput
(
"Out"
,
{
matmul0_out
->
Name
()});
auto
*
while_out
=
Data
(
block0
,
"while_out"
,
{
1
,
128
});
auto
*
while_step_scopes
=
Data
(
block0
,
"while_step_scopes"
);
auto
*
while_cond
=
Data
(
block0
,
"while_cond"
);
OpDesc
*
while_op
=
block0
->
AppendOp
();
while_op
->
SetType
(
"while"
);
while_op
->
SetInput
(
"X"
,
{
matmul0_w
->
Name
(),
matmul0_out
->
Name
()});
while_op
->
SetInput
(
"Condition"
,
{
while_cond
->
Name
()});
while_op
->
SetOutput
(
"Out"
,
{
while_out
->
Name
()});
while_op
->
SetOutput
(
"StepScopes"
,
{
while_step_scopes
->
Name
()});
while_op
->
SetAttr
(
"sub_block"
,
{
block1
});
while_op
->
SetAttr
(
"is_test"
,
true
);
auto
*
matmul1_x
=
Data
(
block1
,
matmul0_out
->
Name
(),
matmul0_out
->
GetShape
());
auto
*
matmul1_w
=
Data
(
block1
,
matmul0_w
->
Name
(),
matmul0_w
->
GetShape
(),
true
);
auto
*
matmul1_out
=
Data
(
block1
,
"matmul1_out"
,
{
1
,
128
});
OpDesc
*
matmul1_op
=
block1
->
AppendOp
();
matmul1_op
->
SetType
(
"matmul_v2"
);
matmul1_op
->
SetInput
(
"X"
,
{
matmul1_x
->
Name
()});
matmul1_op
->
SetInput
(
"Y"
,
{
matmul1_w
->
Name
()});
matmul1_op
->
SetAttr
(
"trans_x"
,
false
);
matmul1_op
->
SetAttr
(
"trans_y"
,
false
);
matmul1_op
->
SetOutput
(
"Out"
,
{
matmul1_out
->
Name
()});
std
::
unique_ptr
<
ir
::
Graph
>
graph
(
new
ir
::
Graph
(
program
));
auto
*
scope
=
CreateParamScope
();
graph
->
Set
(
"__param_scope__"
,
scope
);
auto
pass0
=
PassRegistry
::
Instance
().
Get
(
"fc_xpu_fuse_pass"
);
pass0
->
Apply
(
graph
.
get
());
pass0
->
Apply
(
graph
->
GetSubGraph
(
1
));
int
weight_node_num
=
WeightNodeNum
(
graph
.
get
())
+
WeightNodeNum
(
graph
->
GetSubGraph
(
1
));
PADDLE_ENFORCE_EQ
(
weight_node_num
,
6
,
platform
::
errors
::
PreconditionNotMet
(
"Graph should have 6 weight node after "
"fc_xpu_fuse_pass, but actually has %d."
,
weight_node_num
));
auto
pass1
=
PassRegistry
::
Instance
().
Get
(
"delete_isolated_node_pass"
);
pass1
->
Apply
(
graph
.
get
());
weight_node_num
=
WeightNodeNum
(
graph
.
get
())
+
WeightNodeNum
(
graph
->
GetSubGraph
(
1
));
PADDLE_ENFORCE_EQ
(
weight_node_num
,
4
,
platform
::
errors
::
PreconditionNotMet
(
"Graph should have 4 weight node after "
"delete_isolated_node_pass, but actually has %d."
,
weight_node_num
));
int
weight_tensor_num
=
WeightTensorNum
(
scope
);
PADDLE_ENFORCE_EQ
(
weight_tensor_num
,
2
,
platform
::
errors
::
PreconditionNotMet
(
"Scope should have 2 weight tensor after "
"delete_isolated_node_pass, but actually has %d."
,
weight_tensor_num
));
for
(
auto
*
node
:
graph
->
Nodes
())
{
if
(
node
->
IsOp
()
&&
node
->
Op
()
->
Type
()
==
"while"
)
{
auto
while_in_names
=
node
->
Op
()
->
Inputs
().
at
(
"X"
);
PADDLE_ENFORCE_EQ
(
while_in_names
.
size
(),
3
,
platform
::
errors
::
PreconditionNotMet
(
"While op should have 3 input after "
"delete_isolated_node_pass, but actually has %d."
,
while_in_names
.
size
()));
}
}
Scope
&
scope0
=
graph
->
Get
<
framework
::
Scope
>
(
"__param_scope__"
);
Scope
&
scope1
=
graph
->
GetSubGraph
(
1
)
->
Get
<
framework
::
Scope
>
(
"__param_scope__"
);
std
::
vector
<
std
::
string
>
shared_weight_names
{
matmul0_w
->
Name
()
+
"_int16"
,
matmul0_w
->
Name
()
+
"_max"
};
for
(
auto
name
:
shared_weight_names
)
{
auto
*
var0
=
scope0
.
FindVar
(
name
);
auto
*
var1
=
scope1
.
FindVar
(
name
);
PADDLE_ENFORCE
(
var0
==
var1
,
platform
::
errors
::
PreconditionNotMet
(
"Variables with the same name in two scopes is different."
));
}
}
}
// namespace ir
}
// namespace framework
}
// namespace paddle
USE_PASS
(
delete_isolated_node_pass
);
paddle/fluid/framework/ir/xpu/fc_xpu_fuse_pass.cc
浏览文件 @
39a9abaa
...
@@ -76,7 +76,6 @@ FcXPUPattern::FcXPUPattern(PDPattern* pattern,
...
@@ -76,7 +76,6 @@ FcXPUPattern::FcXPUPattern(PDPattern* pattern,
auto
*
mul_w
=
pattern
->
NewNode
(
mul_w_repr
())
auto
*
mul_w
=
pattern
->
NewNode
(
mul_w_repr
())
->
assert_is_op_input
(
mul_type_
,
"Y"
)
->
assert_is_op_input
(
mul_type_
,
"Y"
)
->
assert_is_persistable_var
()
->
assert_is_persistable_var
()
->
assert_has_n_outputs
(
1
)
->
assert_more
([](
Node
*
node
)
{
->
assert_more
([](
Node
*
node
)
{
return
node
->
Var
()
->
GetShape
().
size
()
==
2
;
return
node
->
Var
()
->
GetShape
().
size
()
==
2
;
});
});
...
@@ -169,10 +168,10 @@ class FcXPUFusePass : public FusePassBase {
...
@@ -169,10 +168,10 @@ class FcXPUFusePass : public FusePassBase {
void
ApplyImpl
(
ir
::
Graph
*
graph
)
const
override
;
void
ApplyImpl
(
ir
::
Graph
*
graph
)
const
override
;
private:
private:
void
ApplyImpl
(
ir
::
Graph
*
graph
,
int
ApplyImpl
(
ir
::
Graph
*
graph
,
const
std
::
string
&
mul_type
,
const
std
::
string
&
mul_type
,
bool
with_bias
,
bool
with_bias
,
const
std
::
string
&
act_type
)
const
;
const
std
::
string
&
act_type
)
const
;
const
std
::
string
name_scope_
{
"fc_xpu_fuse_pass"
};
const
std
::
string
name_scope_
{
"fc_xpu_fuse_pass"
};
};
};
...
@@ -181,6 +180,8 @@ void FcXPUFusePass::ApplyImpl(ir::Graph* graph) const {
...
@@ -181,6 +180,8 @@ void FcXPUFusePass::ApplyImpl(ir::Graph* graph) const {
PADDLE_ENFORCE_NOT_NULL
(
PADDLE_ENFORCE_NOT_NULL
(
graph
,
platform
::
errors
::
PreconditionNotMet
(
"graph should not be null."
));
graph
,
platform
::
errors
::
PreconditionNotMet
(
"graph should not be null."
));
Init
(
name_scope_
,
graph
);
Init
(
name_scope_
,
graph
);
int
found_subgraph_count
=
0
;
for
(
auto
mul_type
:
{
"mul"
,
"matmul"
,
"matmul_v2"
})
{
for
(
auto
mul_type
:
{
"mul"
,
"matmul"
,
"matmul_v2"
})
{
for
(
auto
with_bias
:
{
true
,
false
})
{
for
(
auto
with_bias
:
{
true
,
false
})
{
for
(
auto
act_type
:
{
for
(
auto
act_type
:
{
...
@@ -189,16 +190,17 @@ void FcXPUFusePass::ApplyImpl(ir::Graph* graph) const {
...
@@ -189,16 +190,17 @@ void FcXPUFusePass::ApplyImpl(ir::Graph* graph) const {
"tanh"
,
"tanh"
,
""
,
""
,
})
{
})
{
ApplyImpl
(
graph
,
mul_type
,
with_bias
,
act_type
);
found_subgraph_count
+=
ApplyImpl
(
graph
,
mul_type
,
with_bias
,
act_type
);
}
}
}
}
}
}
AddStatis
(
found_subgraph_count
);
}
}
void
FcXPUFusePass
::
ApplyImpl
(
ir
::
Graph
*
graph
,
int
FcXPUFusePass
::
ApplyImpl
(
ir
::
Graph
*
graph
,
const
std
::
string
&
mul_type
,
const
std
::
string
&
mul_type
,
bool
with_bias
,
bool
with_bias
,
const
std
::
string
&
act_type
)
const
{
const
std
::
string
&
act_type
)
const
{
GraphPatternDetector
gpd
;
GraphPatternDetector
gpd
;
patterns
::
FcXPUPattern
pattern
(
patterns
::
FcXPUPattern
pattern
(
gpd
.
mutable_pattern
(),
name_scope_
,
mul_type
,
with_bias
,
act_type
);
gpd
.
mutable_pattern
(),
name_scope_
,
mul_type
,
with_bias
,
act_type
);
...
@@ -219,37 +221,20 @@ void FcXPUFusePass::ApplyImpl(ir::Graph* graph,
...
@@ -219,37 +221,20 @@ void FcXPUFusePass::ApplyImpl(ir::Graph* graph,
auto
*
block
=
mul
->
Op
()
->
Block
();
auto
*
block
=
mul
->
Op
()
->
Block
();
auto
*
scope
=
param_scope
();
auto
*
scope
=
param_scope
();
auto
mul_w_name
=
mul_w
->
Name
();
bool
transpose_w
=
false
;
auto
mul_w_tensor
=
if
(
mul_type
==
"matmul"
)
{
scope
->
FindVar
(
mul_w_name
)
->
GetMutable
<
phi
::
DenseTensor
>
();
transpose_w
=
PADDLE_GET_CONST
(
bool
,
mul
->
Op
()
->
GetAttr
(
"transpose_Y"
));
// 1. Transform weight to int16/int31
}
else
if
(
mul_type
==
"matmul_v2"
)
{
// 2. Avoid transform repeatly, because weight may be shared with other ops.
transpose_w
=
PADDLE_GET_CONST
(
bool
,
mul
->
Op
()
->
GetAttr
(
"trans_y"
));
// TODO(zhupengyang): support int31
std
::
string
mul_w_max_name
=
mul_w_name
+
"_max"
;
Node
*
mul_w_max
=
nullptr
;
if
(
mul_w_tensor
->
dtype
()
!=
phi
::
DataType
::
INT16
)
{
// Create weight_max node
VarDesc
mul_w_max_desc
(
mul_w_max_name
);
mul_w_max_desc
.
SetPersistable
(
true
);
mul_w_max
=
graph
->
CreateVarNode
(
&
mul_w_max_desc
);
// Create weight_max var/tensor
auto
mul_w_max_var
=
block
->
Var
(
mul_w_max_name
);
mul_w_max_var
->
SetPersistable
(
true
);
auto
mul_w_max_tensor
=
scope
->
Var
(
mul_w_max_name
)
->
GetMutable
<
phi
::
DenseTensor
>
();
bool
transpose_w
=
false
;
if
(
mul_type
==
"matmul"
)
{
transpose_w
=
PADDLE_GET_CONST
(
bool
,
mul
->
Op
()
->
GetAttr
(
"transpose_Y"
));
}
else
if
(
mul_type
==
"matmul_v2"
)
{
transpose_w
=
PADDLE_GET_CONST
(
bool
,
mul
->
Op
()
->
GetAttr
(
"trans_y"
));
}
QuantWeight
<
int16_t
>
(
mul_w_tensor
,
mul_w_max_tensor
,
!
transpose_w
);
}
}
Node
*
mul_w_int16
=
nullptr
;
Node
*
mul_w_max
=
nullptr
;
PrepareWeight
<
int16_t
>
(
graph
,
scope
,
block
,
mul_w
,
&
mul_w_int16
,
&
mul_w_max
,
!
transpose_w
);
Node
*
bias_fp32
=
nullptr
;
if
(
bias
!=
nullptr
)
{
if
(
bias
!=
nullptr
)
{
auto
*
bias_tensor
=
PrepareBias
(
graph
,
scope
,
block
,
bias
,
&
bias_fp32
);
scope
->
Var
(
bias
->
Name
())
->
GetMutable
<
phi
::
DenseTensor
>
();
CastToFp32
(
bias_tensor
);
}
}
std
::
string
fc_out_name
;
std
::
string
fc_out_name
;
...
@@ -268,10 +253,10 @@ void FcXPUFusePass::ApplyImpl(ir::Graph* graph,
...
@@ -268,10 +253,10 @@ void FcXPUFusePass::ApplyImpl(ir::Graph* graph,
framework
::
OpDesc
fc_xpu_op_desc
(
block
);
framework
::
OpDesc
fc_xpu_op_desc
(
block
);
fc_xpu_op_desc
.
SetType
(
"fc_xpu"
);
fc_xpu_op_desc
.
SetType
(
"fc_xpu"
);
fc_xpu_op_desc
.
SetInput
(
"x"
,
{
mul_x
->
Name
()});
fc_xpu_op_desc
.
SetInput
(
"x"
,
{
mul_x
->
Name
()});
fc_xpu_op_desc
.
SetInput
(
"w"
,
{
mul_w
->
Name
()});
fc_xpu_op_desc
.
SetInput
(
"w"
,
{
mul_w
_int16
->
Name
()});
fc_xpu_op_desc
.
SetInput
(
"w_max"
,
{
mul_w_max
_name
});
fc_xpu_op_desc
.
SetInput
(
"w_max"
,
{
mul_w_max
->
Name
()
});
if
(
bias
)
{
if
(
bias
_fp32
)
{
fc_xpu_op_desc
.
SetInput
(
"bias"
,
{
bias
->
Name
()});
fc_xpu_op_desc
.
SetInput
(
"bias"
,
{
bias
_fp32
->
Name
()});
}
}
fc_xpu_op_desc
.
SetAttr
(
fc_xpu_op_desc
.
SetAttr
(
"in_num_col_dims"
,
"in_num_col_dims"
,
...
@@ -306,9 +291,9 @@ void FcXPUFusePass::ApplyImpl(ir::Graph* graph,
...
@@ -306,9 +291,9 @@ void FcXPUFusePass::ApplyImpl(ir::Graph* graph,
fc_xpu_op_desc
.
SetOutput
(
"out_max"
,
{
fc_out_max_name
});
fc_xpu_op_desc
.
SetOutput
(
"out_max"
,
{
fc_out_max_name
});
auto
*
fc_xpu
=
graph
->
CreateOpNode
(
&
fc_xpu_op_desc
);
auto
*
fc_xpu
=
graph
->
CreateOpNode
(
&
fc_xpu_op_desc
);
IR_NODE_LINK_TO
(
mul_x
,
fc_xpu
);
IR_NODE_LINK_TO
(
mul_x
,
fc_xpu
);
IR_NODE_LINK_TO
(
mul_w
,
fc_xpu
);
IR_NODE_LINK_TO
(
mul_w
_int16
,
fc_xpu
);
IR_NODE_LINK_TO
(
mul_w_max
,
fc_xpu
);
IR_NODE_LINK_TO
(
mul_w_max
,
fc_xpu
);
SAFE_IR_NODE_LINK_TO
(
bias
,
fc_xpu
);
SAFE_IR_NODE_LINK_TO
(
bias
_fp32
,
fc_xpu
);
if
(
act_out
)
{
if
(
act_out
)
{
IR_NODE_LINK_TO
(
fc_xpu
,
act_out
);
IR_NODE_LINK_TO
(
fc_xpu
,
act_out
);
}
else
if
(
add_out
)
{
}
else
if
(
add_out
)
{
...
@@ -334,7 +319,7 @@ void FcXPUFusePass::ApplyImpl(ir::Graph* graph,
...
@@ -334,7 +319,7 @@ void FcXPUFusePass::ApplyImpl(ir::Graph* graph,
};
};
gpd
(
graph
,
handler
);
gpd
(
graph
,
handler
);
AddStatis
(
found_subgraph_count
)
;
return
found_subgraph_count
;
}
}
}
// namespace ir
}
// namespace ir
...
...
paddle/fluid/framework/ir/xpu/multi_encoder_xpu_fuse_pass.cc
浏览文件 @
39a9abaa
此差异已折叠。
点击以展开。
paddle/fluid/framework/ir/xpu/pass_utils.cc
浏览文件 @
39a9abaa
...
@@ -20,6 +20,18 @@ namespace paddle {
...
@@ -20,6 +20,18 @@ namespace paddle {
namespace
framework
{
namespace
framework
{
namespace
ir
{
namespace
ir
{
static
void
HashCombine
(
std
::
size_t
*
seed
)
{}
// combine hash value
// https://stackoverflow.com/questions/2590677/how-do-i-combine-hash-values-in-c0x
template
<
typename
T
,
typename
...
Rest
>
static
void
HashCombine
(
std
::
size_t
*
seed
,
const
T
&
v
,
Rest
...
rest
)
{
std
::
hash
<
T
>
hasher
;
*
seed
^=
hasher
(
v
)
+
0x9e3779b9
+
(
*
seed
<<
6
)
+
(
*
seed
>>
2
);
*
seed
*=
0x00000100000001B3
;
HashCombine
(
seed
,
rest
...);
}
int
ConvertActivationType
(
std
::
string
act_type
)
{
int
ConvertActivationType
(
std
::
string
act_type
)
{
if
(
act_type
==
""
)
{
if
(
act_type
==
""
)
{
return
static_cast
<
int
>
(
xpu
::
Activation_t
::
LINEAR
);
return
static_cast
<
int
>
(
xpu
::
Activation_t
::
LINEAR
);
...
@@ -50,6 +62,161 @@ int ConvertActivationType(std::string act_type) {
...
@@ -50,6 +62,161 @@ int ConvertActivationType(std::string act_type) {
return
-
1
;
return
-
1
;
}
}
Node
*
FindNodeWithName
(
Graph
*
graph
,
std
::
string
name
)
{
for
(
auto
*
node
:
graph
->
Nodes
())
{
if
(
node
->
IsVar
()
&&
node
->
Var
()
->
Name
()
==
name
)
{
return
node
;
}
}
return
nullptr
;
}
template
<
typename
T
>
std
::
string
IntTypeToString
()
{
LOG
(
FATAL
)
<<
"Not support type."
;
return
""
;
}
template
<
>
std
::
string
IntTypeToString
<
int16_t
>
()
{
return
"int16"
;
}
template
<
typename
T
>
size_t
HashTensor
(
const
phi
::
DenseTensor
&
in
)
{
size_t
ret
=
0
;
auto
in_dims
=
in
.
dims
();
HashCombine
(
&
ret
,
phi
::
DataTypeToString
(
in
.
dtype
()),
phi
::
DataLayoutToString
(
in
.
layout
()),
in_dims
.
size
());
for
(
int
i
=
0
;
i
<
in_dims
.
size
();
i
++
)
{
HashCombine
(
&
ret
,
in_dims
[
i
]);
}
auto
*
data
=
in
.
data
<
T
>
();
int64_t
size
=
in
.
numel
();
for
(
int64_t
i
=
0
;
i
<
size
;
i
++
)
{
HashCombine
(
&
ret
,
data
[
i
]);
}
return
ret
;
}
template
size_t
HashTensor
<
int16_t
>(
const
phi
::
DenseTensor
&
in
);
template
size_t
HashTensor
<
float
>(
const
phi
::
DenseTensor
&
in
);
template
<
typename
T
>
void
PrepareWeight
(
Graph
*
graph
,
Scope
*
scope
,
BlockDesc
*
block
,
Node
*
src
,
Node
**
dst
,
Node
**
dst_max
,
bool
transpose
)
{
auto
src_name
=
src
->
Name
();
auto
*
src_tensor
=
scope
->
Var
(
src_name
)
->
GetMutable
<
phi
::
DenseTensor
>
();
phi
::
DenseTensor
dst_tensor
;
Assign
(
*
src_tensor
,
&
dst_tensor
);
phi
::
DenseTensor
dst_max_tensor
;
PrepareWeight
<
T
>
(
&
dst_tensor
,
&
dst_max_tensor
,
transpose
);
size_t
dst_hash
=
HashTensor
<
T
>
(
dst_tensor
);
size_t
dst_max_hash
=
HashTensor
<
float
>
(
dst_max_tensor
);
std
::
string
dst_name
=
src_name
+
"_"
+
std
::
to_string
(
dst_hash
);
std
::
string
dst_max_name
=
src_name
+
"_max_"
+
std
::
to_string
(
dst_max_hash
);
*
dst
=
FindNodeWithName
(
graph
,
dst_name
);
if
(
*
dst
==
nullptr
)
{
// Create dst node
// Update dst var_desc in block
VarDesc
dst_desc
(
dst_name
);
dst_desc
.
SetPersistable
(
true
);
dst_desc
.
SetShape
(
vectorize
(
dst_tensor
.
dims
()));
dst_desc
.
SetDataType
(
framework
::
TransToProtoVarType
(
dst_tensor
.
dtype
()));
*
dst
=
graph
->
CreateVarNode
(
&
dst_desc
);
auto
*
block_dst_desc
=
block
->
Var
(
dst_name
);
block_dst_desc
->
SetPersistable
(
dst_desc
.
Persistable
());
block_dst_desc
->
SetShape
(
dst_desc
.
GetShape
());
block_dst_desc
->
SetDataType
(
dst_desc
.
GetDataType
());
// Create dst_max node
// Update dst_max var_desc in block
VarDesc
dst_max_desc
(
dst_max_name
);
dst_max_desc
.
SetPersistable
(
true
);
dst_max_desc
.
SetShape
(
vectorize
(
dst_max_tensor
.
dims
()));
dst_max_desc
.
SetDataType
(
proto
::
VarType
::
Type
::
VarType_Type_FP32
);
*
dst_max
=
graph
->
CreateVarNode
(
&
dst_max_desc
);
auto
*
block_dst_max_desc
=
block
->
Var
(
dst_max_name
);
block_dst_max_desc
->
SetPersistable
(
dst_max_desc
.
Persistable
());
block_dst_max_desc
->
SetShape
(
dst_max_desc
.
GetShape
());
block_dst_max_desc
->
SetDataType
(
dst_max_desc
.
GetDataType
());
// Find dst/dst_max variable in scope
auto
*
dst_var
=
scope
->
FindVar
(
dst_name
);
if
(
dst_var
==
nullptr
)
{
// Create dst/dst_max variable/tensor
Assign
(
dst_tensor
,
scope
->
Var
(
dst_name
)
->
GetMutable
<
phi
::
DenseTensor
>
());
Assign
(
dst_max_tensor
,
scope
->
Var
(
dst_max_name
)
->
GetMutable
<
phi
::
DenseTensor
>
());
}
else
{
// Share the same variable
PADDLE_ENFORCE_NOT_NULL
(
scope
->
FindVar
(
dst_max_name
),
platform
::
errors
::
Fatal
(
"dst_max(%s) variable should not be nullptr if dst(%s) "
"variable is exist. (src_name is %s)"
,
dst_max_name
,
dst_name
,
src_name
));
}
}
else
{
*
dst_max
=
FindNodeWithName
(
graph
,
dst_max_name
);
PADDLE_ENFORCE_NOT_NULL
(
*
dst_max
,
platform
::
errors
::
Fatal
(
"dst_max(%s) variable should not be nullptr if dst(%s) "
"variable is exist. (src_name is %s)"
,
dst_max_name
,
dst_name
,
src_name
));
}
}
template
void
PrepareWeight
<
int16_t
>(
Graph
*
graph
,
Scope
*
scope
,
BlockDesc
*
block
,
Node
*
src
,
Node
**
dst
,
Node
**
dst_max
,
bool
transpose
);
void
PrepareBias
(
Graph
*
graph
,
Scope
*
scope
,
BlockDesc
*
block
,
Node
*
src
,
Node
**
dst
)
{
auto
src_name
=
src
->
Name
();
auto
*
src_tensor
=
scope
->
Var
(
src_name
)
->
GetMutable
<
phi
::
DenseTensor
>
();
if
(
src_tensor
->
dtype
()
==
phi
::
DataType
::
FLOAT32
)
{
*
dst
=
src
;
}
phi
::
DenseTensor
dst_tensor
;
CastToFp32
(
src_tensor
,
&
dst_tensor
);
size_t
dst_hash
=
HashTensor
<
float
>
(
dst_tensor
);
std
::
string
dst_name
=
src_name
+
"_"
+
std
::
to_string
(
dst_hash
);
*
dst
=
FindNodeWithName
(
graph
,
dst_name
);
if
(
*
dst
==
nullptr
)
{
// Create dst node
// Update dst var_desc in block
VarDesc
dst_desc
(
dst_name
);
dst_desc
.
SetPersistable
(
true
);
dst_desc
.
SetShape
(
vectorize
(
dst_tensor
.
dims
()));
dst_desc
.
SetDataType
(
framework
::
TransToProtoVarType
(
dst_tensor
.
dtype
()));
*
dst
=
graph
->
CreateVarNode
(
&
dst_desc
);
auto
*
block_dst_desc
=
block
->
Var
(
dst_name
);
block_dst_desc
->
SetPersistable
(
dst_desc
.
Persistable
());
block_dst_desc
->
SetShape
(
dst_desc
.
GetShape
());
block_dst_desc
->
SetDataType
(
dst_desc
.
GetDataType
());
Assign
(
dst_tensor
,
scope
->
Var
(
dst_name
)
->
GetMutable
<
phi
::
DenseTensor
>
());
}
}
}
// namespace ir
}
// namespace ir
}
// namespace framework
}
// namespace framework
}
// namespace paddle
}
// namespace paddle
paddle/fluid/framework/ir/xpu/pass_utils.h
浏览文件 @
39a9abaa
...
@@ -14,6 +14,10 @@
...
@@ -14,6 +14,10 @@
#pragma once
#pragma once
#include <string>
#include <string>
#include "paddle/fluid/framework/convert_utils.h"
#include "paddle/fluid/framework/ir/pass.h"
#include "paddle/fluid/framework/ir/xpu/quant_utils.h"
#include "paddle/fluid/framework/scope.h"
namespace
paddle
{
namespace
paddle
{
namespace
framework
{
namespace
framework
{
...
@@ -45,6 +49,23 @@ namespace ir {
...
@@ -45,6 +49,23 @@ namespace ir {
int
ConvertActivationType
(
std
::
string
act_type
);
int
ConvertActivationType
(
std
::
string
act_type
);
Node
*
FindNodeWithName
(
Graph
*
graph
,
std
::
string
name
);
template
<
typename
T
>
size_t
HashTensor
(
const
phi
::
DenseTensor
&
in
);
template
<
typename
T
>
void
PrepareWeight
(
Graph
*
graph
,
Scope
*
scope
,
BlockDesc
*
block
,
Node
*
src
,
Node
**
dst
,
Node
**
dst_max
,
bool
transpose
);
void
PrepareBias
(
Graph
*
graph
,
Scope
*
scope
,
BlockDesc
*
block
,
Node
*
src
,
Node
**
dst
);
}
// namespace ir
}
// namespace ir
}
// namespace framework
}
// namespace framework
}
// namespace paddle
}
// namespace paddle
paddle/fluid/framework/ir/xpu/quant_utils.cc
浏览文件 @
39a9abaa
...
@@ -207,9 +207,9 @@ void QuantFP32ToIntX<int16_t>(const float* src_ptr,
...
@@ -207,9 +207,9 @@ void QuantFP32ToIntX<int16_t>(const float* src_ptr,
}
}
template
<
typename
T
>
template
<
typename
T
>
void
Quant
Weight
(
phi
::
DenseTensor
*
weight
,
void
Prepare
Weight
(
phi
::
DenseTensor
*
weight
,
phi
::
DenseTensor
*
weight_max
,
phi
::
DenseTensor
*
weight_max
,
bool
transpose
)
{
bool
transpose
)
{
// Convert fp16 to fp32
// Convert fp16 to fp32
phi
::
DenseTensor
weight_fp32
;
phi
::
DenseTensor
weight_fp32
;
CastToFp32
(
weight
,
&
weight_fp32
);
CastToFp32
(
weight
,
&
weight_fp32
);
...
@@ -249,9 +249,9 @@ void QuantWeight(phi::DenseTensor* weight,
...
@@ -249,9 +249,9 @@ void QuantWeight(phi::DenseTensor* weight,
QuantFP32ToIntX
(
weight_data
,
cpu_ctx
->
Alloc
<
T
>
(
weight
),
max_val
,
size
);
QuantFP32ToIntX
(
weight_data
,
cpu_ctx
->
Alloc
<
T
>
(
weight
),
max_val
,
size
);
}
}
template
void
Quant
Weight
<
int16_t
>(
phi
::
DenseTensor
*
weight
,
template
void
Prepare
Weight
<
int16_t
>(
phi
::
DenseTensor
*
weight
,
phi
::
DenseTensor
*
weight_max
,
phi
::
DenseTensor
*
weight_max
,
bool
transpose
);
bool
transpose
);
}
// namespace ir
}
// namespace ir
}
// namespace framework
}
// namespace framework
...
...
paddle/fluid/framework/ir/xpu/quant_utils.h
浏览文件 @
39a9abaa
...
@@ -29,9 +29,9 @@ void CastToFp32(phi::DenseTensor* in, phi::DenseTensor* out = nullptr);
...
@@ -29,9 +29,9 @@ void CastToFp32(phi::DenseTensor* in, phi::DenseTensor* out = nullptr);
// 2. Weight data is in-place update.
// 2. Weight data is in-place update.
// 3. Generate weight max tensor
// 3. Generate weight max tensor
template
<
typename
T
>
template
<
typename
T
>
void
Quant
Weight
(
phi
::
DenseTensor
*
weight
,
void
Prepare
Weight
(
phi
::
DenseTensor
*
weight
,
phi
::
DenseTensor
*
weight_max
,
phi
::
DenseTensor
*
weight_max
,
bool
transpose
);
bool
transpose
);
}
// namespace ir
}
// namespace ir
}
// namespace framework
}
// namespace framework
...
...
paddle/fluid/inference/api/paddle_pass_builder.cc
浏览文件 @
39a9abaa
...
@@ -524,6 +524,7 @@ XpuPassStrategy::XpuPassStrategy() : PassStrategy({}) {
...
@@ -524,6 +524,7 @@ XpuPassStrategy::XpuPassStrategy() : PassStrategy({}) {
"fc_xpu_fuse_pass"
,
"fc_xpu_fuse_pass"
,
"link_xpu_op_max_pass"
,
"link_xpu_op_max_pass"
,
"delete_op_device_pass"
,
"delete_op_device_pass"
,
"delete_isolated_node_pass"
,
});
});
use_xpu_
=
true
;
use_xpu_
=
true
;
}
}
...
...
python/paddle/fluid/tests/unittests/ir/inference/test_xpu_link_xpu_op_max_pass.py
浏览文件 @
39a9abaa
...
@@ -15,6 +15,7 @@
...
@@ -15,6 +15,7 @@
import
unittest
import
unittest
import
hypothesis.strategies
as
st
import
hypothesis.strategies
as
st
import
numpy
as
np
from
auto_scan_test
import
PassAutoScanTest
from
auto_scan_test
import
PassAutoScanTest
from
program_config
import
OpConfig
,
ProgramConfig
,
TensorConfig
from
program_config
import
OpConfig
,
ProgramConfig
,
TensorConfig
...
@@ -33,7 +34,7 @@ class TestFcXPUFusePass(PassAutoScanTest):
...
@@ -33,7 +34,7 @@ class TestFcXPUFusePass(PassAutoScanTest):
)
)
matmul0_y_shape
=
draw
(
matmul0_y_shape
=
draw
(
st
.
lists
(
st
.
lists
(
st
.
integers
(
min_value
=
1
,
max_value
=
8
),
min_size
=
2
,
max_size
=
2
st
.
integers
(
min_value
=
2
,
max_value
=
8
),
min_size
=
2
,
max_size
=
2
)
)
)
)
matmul0_y_shape
[
0
]
=
matmul0_x_shape
[
-
1
]
matmul0_y_shape
[
0
]
=
matmul0_x_shape
[
-
1
]
...
@@ -42,7 +43,7 @@ class TestFcXPUFusePass(PassAutoScanTest):
...
@@ -42,7 +43,7 @@ class TestFcXPUFusePass(PassAutoScanTest):
# 3. matmul1
# 3. matmul1
matmul1_y_shape
=
draw
(
matmul1_y_shape
=
draw
(
st
.
lists
(
st
.
lists
(
st
.
integers
(
min_value
=
1
,
max_value
=
8
),
min_size
=
2
,
max_size
=
2
st
.
integers
(
min_value
=
2
,
max_value
=
8
),
min_size
=
2
,
max_size
=
2
)
)
)
)
matmul1_y_shape
[
0
]
=
matmul0_y_shape
[
-
1
]
matmul1_y_shape
[
0
]
=
matmul0_y_shape
[
-
1
]
...
@@ -101,4 +102,5 @@ class TestFcXPUFusePass(PassAutoScanTest):
...
@@ -101,4 +102,5 @@ class TestFcXPUFusePass(PassAutoScanTest):
if
__name__
==
"__main__"
:
if
__name__
==
"__main__"
:
np
.
random
.
seed
(
200
)
unittest
.
main
()
unittest
.
main
()
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录