提交 b857db71 编写于 作者: Z zhangxuefei

Set FLAGS_eager_delete_tensor_gb=0.0 automatically

上级 510f5407
...@@ -12,9 +12,11 @@ ...@@ -12,9 +12,11 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import os
import six import six
os.environ["FLAGS_eager_delete_tensor_gb"] = "0.0"
if six.PY2: if six.PY2:
import sys import sys
reload(sys) # noqa reload(sys) # noqa
......
...@@ -580,10 +580,22 @@ class Module(object): ...@@ -580,10 +580,22 @@ class Module(object):
logger.info( logger.info(
"Set maximum sequence length of input tensor to {}".format( "Set maximum sequence length of input tensor to {}".format(
max_seq_len)) max_seq_len))
for tensor_name in [ if self.name.startswith("ernie_v2"):
feed_list = [
"input_ids", "position_ids", "segment_ids", "input_mask", "input_ids", "position_ids", "segment_ids", "input_mask",
"task_ids" "task_ids"
]: ]
logger.warning(
"%s will exploite task_id, the arguement use_taskid of Reader class must be True."
% self.name)
else:
feed_list = [
"input_ids", "position_ids", "segment_ids", "input_mask"
]
logger.warning(
"%s has no task_id, the arguement use_taskid of Reader class must be False."
% self.name)
for tensor_name in feed_list:
seq_tensor_shape = [-1, max_seq_len, 1] seq_tensor_shape = [-1, max_seq_len, 1]
logger.info("The shape of input tensor[{}] set to {}".format( logger.info("The shape of input tensor[{}] set to {}".format(
tensor_name, seq_tensor_shape)) tensor_name, seq_tensor_shape))
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册