From b857db717e90b2220501155f784b2788949224f7 Mon Sep 17 00:00:00 2001 From: zhangxuefei Date: Fri, 2 Aug 2019 16:28:05 +0800 Subject: [PATCH] Set FLAGS_eager_delete_tensor_gb=0.0 automatically --- paddlehub/__init__.py | 4 +++- paddlehub/module/module.py | 16 ++++++++++++++-- 2 files changed, 17 insertions(+), 3 deletions(-) diff --git a/paddlehub/__init__.py b/paddlehub/__init__.py index c5e9d2e9..9824616f 100644 --- a/paddlehub/__init__.py +++ b/paddlehub/__init__.py @@ -12,9 +12,11 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - +import os import six +os.environ["FLAGS_eager_delete_tensor_gb"] = "0.0" + if six.PY2: import sys reload(sys) # noqa diff --git a/paddlehub/module/module.py b/paddlehub/module/module.py index fce447b1..0190c492 100644 --- a/paddlehub/module/module.py +++ b/paddlehub/module/module.py @@ -580,10 +580,22 @@ class Module(object): logger.info( "Set maximum sequence length of input tensor to {}".format( max_seq_len)) - for tensor_name in [ + if self.name.startswith("ernie_v2"): + feed_list = [ "input_ids", "position_ids", "segment_ids", "input_mask", "task_ids" - ]: + ] + logger.warning( + "%s will exploite task_id, the arguement use_taskid of Reader class must be True." + % self.name) + else: + feed_list = [ + "input_ids", "position_ids", "segment_ids", "input_mask" + ] + logger.warning( + "%s has no task_id, the arguement use_taskid of Reader class must be False." + % self.name) + for tensor_name in feed_list: seq_tensor_shape = [-1, max_seq_len, 1] logger.info("The shape of input tensor[{}] set to {}".format( tensor_name, seq_tensor_shape)) -- GitLab