From e508373ac84eda7e0d1e5970f382dcdfacd8fc4f Mon Sep 17 00:00:00 2001 From: Jiaqi Liu Date: Fri, 7 Jan 2022 11:29:43 +0800 Subject: [PATCH] Support string input for ofa export (#964) (#967) * support string input for ofa export --- paddleslim/nas/ofa/ofa.py | 17 +++++++++++++++++ paddleslim/quant/quanter.py | 5 +++++ 2 files changed, 22 insertions(+) diff --git a/paddleslim/nas/ofa/ofa.py b/paddleslim/nas/ofa/ofa.py index 3f078390..ac788209 100644 --- a/paddleslim/nas/ofa/ofa.py +++ b/paddleslim/nas/ofa/ofa.py @@ -79,6 +79,21 @@ DistillConfig = namedtuple( DistillConfig.__new__.__defaults__ = (None, ) * len(DistillConfig._fields) +def to_tensor(string_values, name="text"): + """ + Create the tensor that the value holds the list of string. + NOTICE: The value will be holded in the cpu place. + + Parameters: + string_values(list[string]): The value will be setted to the tensor. + name(string): The name of the tensor. + """ + tensor = paddle.Tensor(core.VarDesc.VarType.STRING, [], name, + core.VarDesc.VarType.STRINGS, False) + tensor.value().set_string_list(string_values) + return tensor + + class OFABase(Layer): def __init__(self, model): super(OFABase, self).__init__() @@ -531,6 +546,8 @@ class OFA(OFABase): dtype = dtypes[0] else: dtype = dtypes + if dtype == core.VarDesc.VarType.STRINGS: + return to_tensor([""]) return paddle.cast(paddle.rand(list(input_size)), dtype) if isinstance(input_size, dict): inputs = {} diff --git a/paddleslim/quant/quanter.py b/paddleslim/quant/quanter.py index 2522fed7..c6c43f57 100755 --- a/paddleslim/quant/quanter.py +++ b/paddleslim/quant/quanter.py @@ -306,6 +306,7 @@ def quant_post_static( quantize_model_path, batch_generator=None, sample_generator=None, + data_loader=None, model_filename=None, params_filename=None, save_model_filename='__model__', @@ -345,6 +346,9 @@ def quant_post_static( can be set. Beisdes, batch_generator supports lod tensor. sample_generator(Python Generator): The sample generator provides calibrate data for DataLoader, and it only returns a sample every time. + data_loader(Python Generator, Paddle.io.DataLoader, optional): The + Generator or Dataloader provides calibrate data, and it could + return a batch every time. model_filename(str, optional): The name of model file. If parameters are saved in separate files, set it as 'None'. Default: 'None'. params_filename(str, optional): The name of params file. @@ -398,6 +402,7 @@ def quant_post_static( executor=executor, sample_generator=sample_generator, batch_generator=batch_generator, + data_loader=data_loader, model_dir=model_dir, model_filename=model_filename, params_filename=params_filename, -- GitLab