Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
17d6d932
P
Paddle
项目概览
PaddlePaddle
/
Paddle
大约 1 年 前同步成功
通知
2298
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
17d6d932
编写于
6月 02, 2023
作者:
W
wz1qqx
提交者:
GitHub
6月 02, 2023
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
[XPU]fuse small ops of idg models (#54245)
上级
a087b9cb
变更
8
隐藏空白更改
内联
并排
Showing
8 changed file
with
420 addition
and
6 deletion
+420
-6
paddle/fluid/framework/ir/CMakeLists.txt
paddle/fluid/framework/ir/CMakeLists.txt
+6
-0
paddle/fluid/framework/ir/pass_tester_helper.h
paddle/fluid/framework/ir/pass_tester_helper.h
+17
-6
paddle/fluid/framework/ir/xpu/fold_interp_outsize_fuse_pass.cc
...e/fluid/framework/ir/xpu/fold_interp_outsize_fuse_pass.cc
+229
-0
paddle/fluid/framework/ir/xpu/fold_interp_outsize_fuse_pass.h
...le/fluid/framework/ir/xpu/fold_interp_outsize_fuse_pass.h
+74
-0
paddle/fluid/framework/ir/xpu/fold_interp_outsize_fuse_pass_test.cc
...id/framework/ir/xpu/fold_interp_outsize_fuse_pass_test.cc
+58
-0
paddle/fluid/framework/ir/xpu/quant_utils.cc
paddle/fluid/framework/ir/xpu/quant_utils.cc
+33
-0
paddle/fluid/framework/ir/xpu/quant_utils.h
paddle/fluid/framework/ir/xpu/quant_utils.h
+2
-0
paddle/fluid/inference/api/paddle_pass_builder.cc
paddle/fluid/inference/api/paddle_pass_builder.cc
+1
-0
未找到文件。
paddle/fluid/framework/ir/CMakeLists.txt
浏览文件 @
17d6d932
...
...
@@ -252,6 +252,8 @@ if(WITH_XPU)
xpu DEPS
${
XPU_PASS_DEPS
}
)
pass_library
(
add_activation_xpu_fuse_pass inference DIR xpu DEPS
${
XPU_PASS_DEPS
}
)
pass_library
(
fold_interp_outsize_fuse_pass inference DIR xpu DEPS
${
XPU_PASS_DEPS
}
)
endif
()
cc_library
(
...
...
@@ -536,4 +538,8 @@ if(WITH_XPU)
test_multi_encoder_xpu_adaptive_seqlen_fuse_pass
SRCS xpu/multi_encoder_xpu_adaptive_seqlen_fuse_pass_test.cc
DEPS multi_encoder_xpu_adaptive_seqlen_fuse_pass
)
cc_test
(
test_fold_interp_outsize_fuse_pass
SRCS xpu/fold_interp_outsize_fuse_pass_test.cc
DEPS fold_interp_outsize_fuse_pass
)
endif
()
paddle/fluid/framework/ir/pass_tester_helper.h
浏览文件 @
17d6d932
...
...
@@ -361,20 +361,31 @@ struct Layers {
return
outs
;
}
std
::
vector
<
VarDesc
*>
split
(
VarDesc
*
x
,
int
num_or_section
,
int
axis
=
0
)
{
std
::
vector
<
VarDesc
*>
outs
(
num_or_section
);
for
(
int
i
=
0
;
i
<
num_or_section
;
i
++
)
{
std
::
vector
<
VarDesc
*>
split
(
VarDesc
*
x
,
int
num_or_section
=
0
,
int
axis
=
0
,
std
::
vector
<
int
>
sections
=
{
-
1
})
{
int
out_num
=
num_or_section
;
if
(
num_or_section
==
0
)
{
out_num
=
sections
.
size
();
}
std
::
vector
<
VarDesc
*>
outs
(
out_num
);
for
(
int
i
=
0
;
i
<
out_num
;
i
++
)
{
outs
[
i
]
=
lod_tensor
(
unique_name
());
}
std
::
vector
<
std
::
string
>
out_names
(
num_or_section
);
for
(
int
i
=
0
;
i
<
num_or_section
;
i
++
)
{
std
::
vector
<
std
::
string
>
out_names
(
out_num
);
for
(
int
i
=
0
;
i
<
out_num
;
i
++
)
{
out_names
[
i
]
=
outs
[
i
]
->
Name
();
}
OpDesc
*
op
=
program_
.
MutableBlock
(
0
)
->
AppendOp
();
op
->
SetType
(
"split"
);
op
->
SetInput
(
"X"
,
{
x
->
Name
()});
op
->
SetOutput
(
"Out"
,
out_names
);
op
->
SetAttr
(
"num_or_section"
,
num_or_section
);
if
(
num_or_section
==
0
)
{
op
->
SetAttr
(
"sections"
,
sections
);
}
else
{
op
->
SetAttr
(
"num_or_section"
,
num_or_section
);
}
op
->
SetAttr
(
"axis"
,
axis
);
op
->
SetAttr
(
OpProtoAndCheckerMaker
::
OpRoleAttrName
(),
static_cast
<
int
>
(
OpRole
::
kForward
));
...
...
paddle/fluid/framework/ir/xpu/fold_interp_outsize_fuse_pass.cc
0 → 100644
浏览文件 @
17d6d932
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/ir/xpu/fold_interp_outsize_fuse_pass.h"
#include <string>
#include "glog/logging.h"
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
#include "paddle/fluid/framework/ir/xpu/pass_utils.h"
#include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/fluid/platform/enforce.h"
namespace
phi
{
class
DenseTensor
;
}
// namespace phi
namespace
paddle
{
namespace
framework
{
class
Scope
;
}
// namespace framework
}
// namespace paddle
namespace
paddle
{
namespace
framework
{
namespace
ir
{
namespace
patterns
{
struct
DetectorFusePattern
:
public
PatternBase
{
DetectorFusePattern
(
PDPattern
*
pattern
,
const
std
::
string
&
name_scope
);
// declare operator node's name
PATTERN_DECL_NODE
(
shape
);
PATTERN_DECL_NODE
(
cast1
);
PATTERN_DECL_NODE
(
slice
);
PATTERN_DECL_NODE
(
concat
);
PATTERN_DECL_NODE
(
split
);
PATTERN_DECL_NODE
(
cast2
);
PATTERN_DECL_NODE
(
bilinear_interp
);
// declare variable node's name
PATTERN_DECL_NODE
(
x
);
PATTERN_DECL_NODE
(
shape_out
);
PATTERN_DECL_NODE
(
cast1_out
);
PATTERN_DECL_NODE
(
slice_out
);
PATTERN_DECL_NODE
(
concat_y
);
PATTERN_DECL_NODE
(
concat_out
);
PATTERN_DECL_NODE
(
split_out_0
);
PATTERN_DECL_NODE
(
split_out_1
);
PATTERN_DECL_NODE
(
cast2_out
);
};
DetectorFusePattern
::
DetectorFusePattern
(
PDPattern
*
pattern
,
const
std
::
string
&
name_scope
)
:
PatternBase
(
pattern
,
name_scope
,
name_scope
)
{
auto
*
x
=
pattern
->
NewNode
(
x_repr
())
->
assert_is_op_input
(
"shape"
,
"Input"
)
->
assert_is_op_input
(
"bilinear_interp_v2"
,
"X"
);
auto
*
shape
=
pattern
->
NewNode
(
shape_repr
())
->
assert_is_op
(
"shape"
);
auto
*
shape_out
=
pattern
->
NewNode
(
shape_out_repr
())
->
assert_is_op_output
(
"shape"
,
"Out"
)
->
assert_is_op_input
(
"cast"
,
"X"
);
shape
->
LinksFrom
({
x
}).
LinksTo
({
shape_out
});
auto
*
cast1
=
pattern
->
NewNode
(
cast1_repr
())
->
assert_is_op
(
"cast"
)
->
assert_more
([
&
](
Node
*
node
)
{
auto
*
op_desc
=
node
->
Op
();
return
op_desc
->
GetAttrIfExists
<
int
>
(
"in_dtype"
)
==
2
&&
op_desc
->
GetAttrIfExists
<
int
>
(
"out_dtype"
)
==
3
;
});
auto
*
cast1_out
=
pattern
->
NewNode
(
cast1_out_repr
())
->
assert_is_op_output
(
"cast"
,
"Out"
)
->
assert_is_op_input
(
"slice"
,
"Input"
);
cast1
->
LinksFrom
({
shape_out
}).
LinksTo
({
cast1_out
});
auto
*
slice
=
pattern
->
NewNode
(
slice_repr
())
->
assert_is_op
(
"slice"
)
->
assert_more
([
&
](
Node
*
node
)
{
auto
*
op_desc
=
node
->
Op
();
return
op_desc
->
GetAttrIfExists
<
std
::
vector
<
int
>>
(
"axes"
)
==
std
::
vector
<
int
>
{
0
}
&&
op_desc
->
GetAttrIfExists
<
std
::
vector
<
int
>>
(
"starts"
)
==
std
::
vector
<
int
>
{
0
}
&&
op_desc
->
GetAttrIfExists
<
std
::
vector
<
int
>>
(
"ends"
)
==
std
::
vector
<
int
>
{
2
};
});
auto
*
slice_out
=
pattern
->
NewNode
(
slice_out_repr
())
->
assert_is_op_output
(
"slice"
,
"Out"
)
->
assert_is_op_nth_input
(
"concat"
,
"X"
,
0
);
slice
->
LinksFrom
({
cast1_out
}).
LinksTo
({
slice_out
});
auto
*
concat
=
pattern
->
NewNode
(
concat_repr
())
->
assert_is_op
(
"concat"
)
->
assert_more
([
&
](
Node
*
node
)
{
auto
*
op_desc
=
node
->
Op
();
return
op_desc
->
GetAttrIfExists
<
int
>
(
"axis"
)
==
0
;
});
auto
*
concat_y
=
pattern
->
NewNode
(
concat_y_repr
())
->
assert_is_op_nth_input
(
"concat"
,
"X"
,
1
)
->
assert_is_persistable_var
();
auto
*
concat_out
=
pattern
->
NewNode
(
concat_out_repr
())
->
assert_is_op_output
(
"concat"
,
"Out"
)
->
assert_is_op_input
(
"split"
,
"X"
);
concat
->
LinksFrom
({
slice_out
,
concat_y
}).
LinksTo
({
concat_out
});
auto
*
split
=
pattern
->
NewNode
(
split_repr
())
->
assert_is_op
(
"split"
)
->
assert_more
([
&
](
Node
*
node
)
{
auto
*
op_desc
=
node
->
Op
();
return
op_desc
->
GetAttrIfExists
<
int
>
(
"axis"
)
==
0
&&
(
op_desc
->
GetAttrIfExists
<
std
::
vector
<
int
>>
(
"sections"
)
==
std
::
vector
<
int
>
{
2
,
2
}
||
op_desc
->
GetAttrIfExists
<
int
>
(
"num"
)
==
2
);
});
auto
*
split_out_0
=
pattern
->
NewNode
(
split_out_0_repr
())
->
assert_is_op_nth_output
(
"split"
,
"Out"
,
0
);
auto
*
split_out_1
=
pattern
->
NewNode
(
split_out_1_repr
())
->
assert_is_op_nth_output
(
"split"
,
"Out"
,
1
)
->
assert_is_op_input
(
"cast"
,
"X"
);
split
->
LinksFrom
({
concat_out
}).
LinksTo
({
split_out_0
,
split_out_1
});
auto
*
cast2
=
pattern
->
NewNode
(
cast2_repr
())
->
assert_is_op
(
"cast"
)
->
assert_more
([
&
](
Node
*
node
)
{
auto
*
op_desc
=
node
->
Op
();
return
op_desc
->
GetAttrIfExists
<
int
>
(
"in_dtype"
)
==
3
&&
op_desc
->
GetAttrIfExists
<
int
>
(
"out_dtype"
)
==
2
;
});
auto
*
cast2_out
=
pattern
->
NewNode
(
cast2_out_repr
())
->
assert_is_op_output
(
"cast"
,
"Out"
)
->
assert_is_op_input
(
"bilinear_interp_v2"
,
"OutSize"
);
cast2
->
LinksFrom
({
split_out_1
}).
LinksTo
({
cast2_out
});
auto
*
bilinear_interp
=
pattern
->
NewNode
(
bilinear_interp_repr
())
->
assert_is_op
(
"bilinear_interp_v2"
);
bilinear_interp
->
LinksFrom
({
x
,
cast2_out
});
}
}
// namespace patterns
void
FoldInterpOutsizeFusePass
::
DetectorFuse
(
ir
::
Graph
*
graph
)
const
{
GraphPatternDetector
gpd
;
patterns
::
DetectorFusePattern
pattern
(
gpd
.
mutable_pattern
(),
name_scope_
);
int
found_subgraph_count
=
0
;
auto
handler
=
[
&
](
const
GraphPatternDetector
::
subgraph_t
&
subgraph
,
Graph
*
graph
)
{
VLOG
(
4
)
<<
"handle DetectorFuse"
;
/* declare operator node's name */
GET_IR_NODE
(
shape
);
GET_IR_NODE
(
cast1
);
GET_IR_NODE
(
slice
);
GET_IR_NODE
(
concat
);
GET_IR_NODE
(
split
);
GET_IR_NODE
(
cast2
);
GET_IR_NODE
(
bilinear_interp
);
/* declare variable node's name*/
GET_IR_NODE
(
x
);
GET_IR_NODE
(
shape_out
);
GET_IR_NODE
(
cast1_out
);
GET_IR_NODE
(
slice_out
);
GET_IR_NODE
(
concat_y
);
GET_IR_NODE
(
concat_out
);
GET_IR_NODE
(
split_out_0
);
GET_IR_NODE
(
split_out_1
);
GET_IR_NODE
(
cast2_out
);
auto
*
scope
=
param_scope
();
PADDLE_ENFORCE_NOT_NULL
(
scope
,
platform
::
errors
::
InvalidArgument
(
"Scope cannot be nullptr."
));
auto
*
concat_y_t
=
scope
->
GetVar
(
concat_y
->
Name
())
->
GetMutable
<
phi
::
DenseTensor
>
();
// concat_y int64 --> int32
auto
tensor_type
=
concat_y_t
->
dtype
();
if
(
tensor_type
==
phi
::
DataType
::
INT64
)
{
CastToInt32
(
concat_y_t
,
nullptr
);
}
bilinear_interp
->
Op
()
->
RenameInput
(
cast2_out
->
Name
(),
concat_y
->
Name
());
IR_NODE_UNLINK
(
x
,
shape
);
IR_NODE_UNLINK
(
cast2_out
,
bilinear_interp
);
IR_NODE_LINK_TO
(
concat_y
,
bilinear_interp
);
// delete useless node
std
::
unordered_set
<
const
Node
*>
delete_nodes
=
{
shape
,
cast1
,
slice
,
concat
,
split
,
cast2
,
shape_out
,
cast1_out
,
slice_out
,
concat_out
,
split_out_0
,
split_out_1
,
cast2_out
};
GraphSafeRemoveNodes
(
graph
,
delete_nodes
);
found_subgraph_count
++
;
};
gpd
(
graph
,
handler
);
AddStatis
(
found_subgraph_count
);
}
void
FoldInterpOutsizeFusePass
::
ApplyImpl
(
ir
::
Graph
*
graph
)
const
{
PADDLE_ENFORCE_NOT_NULL
(
graph
,
platform
::
errors
::
PreconditionNotMet
(
"graph should not be null."
));
Init
(
name_scope_
,
graph
);
DetectorFuse
(
graph
);
}
}
// namespace ir
}
// namespace framework
}
// namespace paddle
REGISTER_PASS
(
fold_interp_outsize_fuse_pass
,
paddle
::
framework
::
ir
::
FoldInterpOutsizeFusePass
);
REGISTER_PASS_CAPABILITY
(
fold_interp_outsize_fuse_pass
)
.
AddCombination
(
paddle
::
framework
::
compatible
::
OpVersionComparatorCombination
().
EQ
(
"shape"
,
0
));
paddle/fluid/framework/ir/xpu/fold_interp_outsize_fuse_pass.h
0 → 100644
浏览文件 @
17d6d932
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <string>
#include "paddle/fluid/framework/ir/fuse_pass_base.h"
#include "paddle/fluid/framework/ir/pass.h"
namespace
phi
{
class
DenseTensor
;
}
// namespace phi
namespace
paddle
{
namespace
framework
{
class
Scope
;
}
// namespace framework
}
// namespace paddle
namespace
paddle
{
namespace
framework
{
namespace
ir
{
class
FoldInterpOutsizeFusePass
:
public
FusePassBase
{
protected:
void
ApplyImpl
(
ir
::
Graph
*
graph
)
const
override
;
private:
/*
Origin subgraph:
x
/ \
| shape
| |
| cast
| |
| slice
| |
| concat
| |
| split
| | \
| | \
| outvar_1 outvar_0
| |
| cast
| /
\ /
bilinear_interp_v2
Fused subgraph:
x
| concat_y
| /
bilinear_interp_v2
*/
void
DetectorFuse
(
ir
::
Graph
*
graph
)
const
;
const
std
::
string
name_scope_
{
"fold_interp_outsize_fuse_pass"
};
};
}
// namespace ir
}
// namespace framework
}
// namespace paddle
paddle/fluid/framework/ir/xpu/fold_interp_outsize_fuse_pass_test.cc
0 → 100644
浏览文件 @
17d6d932
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <gtest/gtest.h>
#include "paddle/fluid/framework/ir/pass.h"
#include "paddle/fluid/framework/ir/pass_tester_helper.h"
namespace
paddle
{
namespace
framework
{
namespace
ir
{
TEST
(
DetectorFuse
,
basic
)
{
Layers
layers
;
auto
*
block
=
layers
.
Block
();
auto
*
shape_x
=
layers
.
data
(
"shape_x"
,
{
1
,
18
,
288
,
288
});
auto
*
concat_y
=
layers
.
data
(
"concat_y"
,
{
576
,
576
},
true
,
proto
::
VarType
::
INT64
);
auto
*
shape_out
=
layers
.
shape
(
shape_x
);
auto
*
cast1_out
=
layers
.
cast
(
shape_out
,
2
,
3
);
auto
*
slice_out
=
layers
.
slice
(
cast1_out
,
{
0
},
{
0
},
{
2
});
auto
*
concat_out
=
layers
.
concat
({
slice_out
,
concat_y
},
0
);
auto
split_outs
=
layers
.
split
(
concat_out
,
0
,
0
,
{
2
,
2
});
auto
*
split_out_1
=
split_outs
[
1
];
auto
*
cast2_out
=
layers
.
cast
(
split_out_1
,
3
,
2
);
OpDesc
*
bilinear_interp_v2_op
=
block
->
AppendOp
();
bilinear_interp_v2_op
->
SetType
(
"bilinear_interp_v2"
);
bilinear_interp_v2_op
->
SetInput
(
"X"
,
{
shape_x
->
Name
()});
bilinear_interp_v2_op
->
SetInput
(
"OutSize"
,
{
cast2_out
->
Name
()});
std
::
unique_ptr
<
ir
::
Graph
>
graph
(
new
ir
::
Graph
(
layers
.
main_program
()));
auto
pass
=
PassRegistry
::
Instance
().
Get
(
"fold_interp_outsize_fuse_pass"
);
pass
->
Apply
(
graph
.
get
());
auto
ops_num
=
GetNumOpNodes
(
graph
);
PADDLE_ENFORCE_EQ
(
ops_num
,
1
,
platform
::
errors
::
PreconditionNotMet
(
"graph should only have 2 op nodes, but received %d."
,
ops_num
));
}
}
// namespace ir
}
// namespace framework
}
// namespace paddle
USE_PASS
(
fold_interp_outsize_fuse_pass
);
paddle/fluid/framework/ir/xpu/quant_utils.cc
浏览文件 @
17d6d932
...
...
@@ -70,6 +70,39 @@ void Transpose2D(phi::DenseTensor* in, phi::DenseTensor* out) {
}
}
void
CastToInt32
(
phi
::
DenseTensor
*
in
,
phi
::
DenseTensor
*
out
)
{
auto
*
cpu_ctx
=
static_cast
<
phi
::
CPUContext
*>
(
platform
::
DeviceContextPool
::
Instance
().
Get
(
phi
::
CPUPlace
()));
phi
::
DenseTensor
int32_tensor
;
phi
::
DenseTensor
*
out_ptr
=
out
==
nullptr
?
&
int32_tensor
:
out
;
out_ptr
->
Resize
(
in
->
dims
());
out_ptr
->
set_type
(
phi
::
DataType
::
INT32
);
out_ptr
->
set_layout
(
in
->
layout
());
switch
(
in
->
dtype
())
{
case
phi
::
DataType
::
INT64
:
phi
::
CastKernel
<
int64_t
>
(
*
cpu_ctx
,
*
in
,
phi
::
DataType
::
INT32
,
out_ptr
);
break
;
case
phi
::
DataType
::
INT32
:
if
(
out
==
nullptr
)
{
return
;
}
else
{
phi
::
AssignKernel
(
*
cpu_ctx
,
*
in
,
out_ptr
);
}
break
;
default:
PADDLE_THROW
(
platform
::
errors
::
InvalidArgument
(
"Only support int64 and int32, but received dtype is %s."
,
phi
::
DataTypeToString
(
in
->
dtype
())));
break
;
}
if
(
out
==
nullptr
)
{
Assign
(
*
out_ptr
,
in
);
}
}
void
CastToFp32
(
phi
::
DenseTensor
*
in
,
phi
::
DenseTensor
*
out
)
{
auto
*
cpu_ctx
=
static_cast
<
phi
::
CPUContext
*>
(
platform
::
DeviceContextPool
::
Instance
().
Get
(
phi
::
CPUPlace
()));
...
...
paddle/fluid/framework/ir/xpu/quant_utils.h
浏览文件 @
17d6d932
...
...
@@ -25,6 +25,8 @@ void Transpose2D(phi::DenseTensor* in, phi::DenseTensor* out = nullptr);
void
CastToFp32
(
phi
::
DenseTensor
*
in
,
phi
::
DenseTensor
*
out
=
nullptr
);
void
CastToInt32
(
phi
::
DenseTensor
*
in
,
phi
::
DenseTensor
*
out
=
nullptr
);
// 1. Quant weight from fp32 to int16/int31
// 2. Weight data is in-place update.
// 3. Generate weight max tensor
...
...
paddle/fluid/inference/api/paddle_pass_builder.cc
浏览文件 @
17d6d932
...
...
@@ -522,6 +522,7 @@ XpuPassStrategy::XpuPassStrategy() : PassStrategy({}) {
"multi_encoder_xpu_slice_fuse_pass"
,
"fused_multi_transformer_cachekv_layout_trans_pass"
,
"one_beam_size_fuse_pass"
,
"fold_interp_outsize_fuse_pass"
,
"delete_cast_op_pass"
,
"stack_fuse_pass"
,
"fused_multi_transformer_xpu_pass"
,
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录