提交 b857db71 编写于 作者: Z zhangxuefei

Set FLAGS_eager_delete_tensor_gb=0.0 automatically

上级 510f5407
......@@ -12,9 +12,11 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import six
os.environ["FLAGS_eager_delete_tensor_gb"] = "0.0"
if six.PY2:
import sys
reload(sys) # noqa
......
......@@ -580,10 +580,22 @@ class Module(object):
logger.info(
"Set maximum sequence length of input tensor to {}".format(
max_seq_len))
for tensor_name in [
if self.name.startswith("ernie_v2"):
feed_list = [
"input_ids", "position_ids", "segment_ids", "input_mask",
"task_ids"
]:
]
logger.warning(
"%s will exploite task_id, the arguement use_taskid of Reader class must be True."
% self.name)
else:
feed_list = [
"input_ids", "position_ids", "segment_ids", "input_mask"
]
logger.warning(
"%s has no task_id, the arguement use_taskid of Reader class must be False."
% self.name)
for tensor_name in feed_list:
seq_tensor_shape = [-1, max_seq_len, 1]
logger.info("The shape of input tensor[{}] set to {}".format(
tensor_name, seq_tensor_shape))
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册