Unverified commit 95c4bb41, authored by RichardWooSJTU, committed by GitHub

fix dynamic to static when export LLM inference model (#56390)

Parent 36153898
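
This commit unblocks the dynamic-to-static (`paddle.jit.to_static`) export path for LLM inference models in three ways: `FusedBiasActInferMeta` now tolerates the `-1` token dimension that exists at compile time, several distributed dtype checks now accept `'uint16'` (Paddle's dtype name for bfloat16), and `fused_layer_norm` no longer feeds `None` values for `norm_weight`/`norm_bias` into the op's inputs dict. A minimal sketch of the export flow this targets (the layer and sizes are illustrative, not taken from the commit):

```python
import paddle
from paddle.static import InputSpec


# Hypothetical stand-in for an LLM block; only the export mechanics
# matter here, not the layer's contents.
class Block(paddle.nn.Layer):
    def __init__(self):
        super().__init__()
        self.proj = paddle.nn.Linear(4096, 4096)

    def forward(self, x):
        return self.proj(x)


# The first (token_num) axis is dynamic, so at compile time InferMeta
# only sees -1 for it; strict "> 0" shape checks reject such programs.
net = paddle.jit.to_static(
    Block(), input_spec=[InputSpec(shape=[None, 4096], dtype='float32')]
)
paddle.jit.save(net, '/tmp/llm_block')
```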
@@ -1446,7 +1446,8 @@ void FusedBiasActInferMeta(const MetaTensor& x,
                            int quant_round_type,
                            float quant_max_bound,
                            float quant_min_bound,
-                           MetaTensor* out) {
+                           MetaTensor* out,
+                           MetaConfig config) {
   auto x_dims = x.dims();
   PADDLE_ENFORCE_EQ(x_dims.size(),
                     2,
@@ -1455,15 +1456,17 @@ void FusedBiasActInferMeta(const MetaTensor& x,
   auto token_num = x_dims[0];
   auto dim = x_dims[1];
-  PADDLE_ENFORCE_GT(
-      x_dims[0],
-      0,
-      phi::errors::InvalidArgument("The size of Attr(rows) must > 0"));
-  PADDLE_ENFORCE_GT(
-      x_dims[1],
-      0,
-      phi::errors::InvalidArgument("The size of Attr(cols) must > 0"));
+  if (config.is_runtime) {
+    PADDLE_ENFORCE_GT(
+        x_dims[0],
+        0,
+        phi::errors::InvalidArgument("The size of Attr(rows) must > 0"));
+    PADDLE_ENFORCE_GT(
+        x_dims[1],
+        0,
+        phi::errors::InvalidArgument("The size of Attr(cols) must > 0"));
+  }
   if (act_method == "geglu" || act_method == "swiglu") {
     PADDLE_ENFORCE_EQ(
......
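
The `config.is_runtime` guard above is the core of the fix: while the static program is being built, `is_runtime` is false and a dynamic dimension is only known as `-1`, so the `> 0` checks must be deferred until real shapes are available. A quick way to see the compile-time `-1` (a sketch, not code from the commit):

```python
import paddle

# A dynamic axis declared as None is stored as -1 in the compile-time
# shape; this is the value InferMeta sees during to_static export.
spec = paddle.static.InputSpec(shape=[None, 4096], dtype='float16')
print(spec.shape)  # the dynamic token_num axis shows up as -1
```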
@@ -318,7 +318,8 @@ void FusedBiasActInferMeta(const MetaTensor& x,
                            int quant_round_type,
                            float quant_max_bound,
                            float quant_min_bound,
-                           MetaTensor* out);
+                           MetaTensor* out,
+                           MetaConfig config = MetaConfig());
 
 void FusedLayerNormInferMeta(const MetaTensor& x,
                              const MetaTensor& bias,
......
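
Defaulting the new parameter to `MetaConfig()` keeps every existing caller of `FusedBiasActInferMeta` compiling unchanged; only the compile-time (static-graph) path needs to pass a config whose `is_runtime` flag is false.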
@@ -47,6 +47,7 @@ def _all_reduce_in_static_mode(tensor, op, group, sync_op, use_calc_stream):
             'int8',
             'uint8',
             'bool',
+            'uint16',
         ],
         'all_reduce',
     )
......
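
`'uint16'` here stands for bfloat16: NumPy has no native bfloat16 type, so Paddle reuses the `uint16` dtype name for `VarType.BF16` variables in static graphs, and that is the name `check_variable_and_dtype` compares against. The same whitelisting is applied to the `_c_identity`/`_c_concat`/`_c_split`/`_mp_allreduce` checks below. A small sketch of the mapping (assumed behavior of the public static-graph API, not code from the commit):

```python
import paddle

paddle.enable_static()

# In static mode, a variable declared as 'uint16' is Paddle's stand-in
# for bfloat16 (VarType.BF16); without 'uint16' in the whitelist, bf16
# models failed these dtype checks during export.
x = paddle.static.data(name='x', shape=[-1, 8], dtype='uint16')
print(x.dtype)  # reported as BF16, not as an integer type
```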
@@ -109,7 +109,7 @@ def _c_identity(tensor, group=None):
     check_variable_and_dtype(
         tensor,
         'tensor',
-        ['float16', 'float32', 'float64', 'int32', 'int64'],
+        ['float16', 'float32', 'float64', 'int32', 'int64', 'uint16'],
         '_c_identity',
     )
@@ -169,7 +169,7 @@ def _c_concat(tensor, group=None):
     check_variable_and_dtype(
         tensor,
         'tensor',
-        ['float16', 'float32', 'float64', 'int32', 'int64'],
+        ['float16', 'float32', 'float64', 'int32', 'int64', 'uint16'],
         '_c_concat',
     )
@@ -235,7 +235,7 @@ def _c_split(tensor, group=None):
     check_variable_and_dtype(
         tensor,
         'tensor',
-        ['float16', 'float32', 'float64', 'int32', 'int64'],
+        ['float16', 'float32', 'float64', 'int32', 'int64', 'uint16'],
         '_c_split',
     )
@@ -322,7 +322,7 @@ def _mp_allreduce(
     check_variable_and_dtype(
         tensor,
         'tensor',
-        ['float16', 'float32', 'float64', 'int32', 'int64'],
+        ['float16', 'float32', 'float64', 'int32', 'int64', 'uint16'],
         op_type,
     )
......
@@ -102,7 +102,11 @@ def fused_layer_norm(
         residual_out = helper.create_variable_for_type_inference(dtype=x.dtype)
         outputs_dict['residual_out'] = residual_out
 
-    inputs = {'x': x, 'norm_weight': norm_weight, 'norm_bias': norm_bias}
+    inputs = {'x': x}
+    if norm_weight is not None:
+        inputs['norm_weight'] = norm_weight
+    if norm_bias is not None:
+        inputs['norm_bias'] = norm_bias
     if residual is not None:
         inputs['residual'] = residual
     if bias is not None:
......
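
With the inputs dict built conditionally, `norm_weight` and `norm_bias` may legitimately be `None`, and the static-graph op no longer receives `None` entries that break export. A hedged usage sketch (keyword names beyond those visible in the diff, such as `epsilon` and `begin_norm_axis`, are assumed from Paddle's incubate API and may differ by release; the fused kernel typically requires a GPU build):

```python
import paddle
from paddle.incubate.nn.functional import fused_layer_norm

x = paddle.randn([4, 256], dtype='float32')
residual = paddle.randn([4, 256], dtype='float32')

# norm_weight / norm_bias passed as None: after this fix they are simply
# omitted from the op's inputs instead of being inserted as None values.
out = fused_layer_norm(
    x, None, None, epsilon=1e-6, begin_norm_axis=1, residual=residual
)
```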