提交 33138a42 编写于 作者: S sneaxiy

remove match check

test=develop
上级 814a7590
...@@ -241,9 +241,6 @@ static void PrepareSafeEagerDeletionOnWhileOpAndWhileGradOpImpl( ...@@ -241,9 +241,6 @@ static void PrepareSafeEagerDeletionOnWhileOpAndWhileGradOpImpl(
ModifyWhileOpAndWhileGradOpAttr(*matched_fwd_op, bwd_op); ModifyWhileOpAndWhileGradOpAttr(*matched_fwd_op, bwd_op);
while_op_set.erase(*matched_fwd_op); while_op_set.erase(*matched_fwd_op);
} }
PADDLE_ENFORCE(while_op_set.empty(),
"There are not matched while_grad op in graph.");
} }
void PrepareSafeEagerDeletionOnWhileOpAndWhileGradOp( void PrepareSafeEagerDeletionOnWhileOpAndWhileGradOp(
......
...@@ -47,7 +47,7 @@ class TestEagerDeletionWhileOpBase(unittest.TestCase): ...@@ -47,7 +47,7 @@ class TestEagerDeletionWhileOpBase(unittest.TestCase):
self.with_data_parallel = with_data_parallel self.with_data_parallel = with_data_parallel
if not core.is_compiled_with_cuda() and isinstance(self.place, if not core.is_compiled_with_cuda() and isinstance(self.place,
core.CUDPlace): core.CUDAPlace):
return return
if isinstance(self.place, core.CUDAPlace): if isinstance(self.place, core.CUDAPlace):
...@@ -55,8 +55,8 @@ class TestEagerDeletionWhileOpBase(unittest.TestCase): ...@@ -55,8 +55,8 @@ class TestEagerDeletionWhileOpBase(unittest.TestCase):
) if self.with_data_parallel else 1 ) if self.with_data_parallel else 1
else: else:
device_cnt = int( device_cnt = int(
os.environ['CPU_NUM'], os.environ.get('CPU_NUM', multiprocessing.cpu_count(
multiprocessing.cpu_count()) if self.with_data_parallel else 1 ))) if self.with_data_parallel else 1
d0 = layers.data( d0 = layers.data(
"d0", shape=[10], append_batch_size=False, dtype='float32') "d0", shape=[10], append_batch_size=False, dtype='float32')
......
...@@ -18,7 +18,7 @@ os.environ['FLAGS_eager_delete_tensor_gb'] = "0.0" ...@@ -18,7 +18,7 @@ os.environ['FLAGS_eager_delete_tensor_gb'] = "0.0"
os.environ['FLAGS_memory_fraction_of_eager_deletion'] = "0.55" os.environ['FLAGS_memory_fraction_of_eager_deletion'] = "0.55"
os.environ[ os.environ[
'RECORDIO_FILENAME'] = '/tmp/eager_deletion_transformer.wmt16.recordio' 'RECORDIO_FILENAME'] = '/tmp/partial_eager_deletion_transformer.wmt16.recordio'
from test_parallel_executor_transformer import TestTransformer from test_parallel_executor_transformer import TestTransformer
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册