提交 25782419 编写于 作者: D dzhwinter

fix default value. test=develop

上级 c6bd434f
......@@ -13,21 +13,47 @@
# limitations under the License.
import os
import sys
import unittest
from timeit import default_timer as timer
import paddle
import paddle.fluid as fluid
import paddle.fluid.core as core
import paddle.dataset.wmt16 as wmt16
os.environ['FLAGS_eager_delete_tensor_gb'] = "0.0"
os.environ[
'RECORDIO_FILENAME'] = '/tmp/ir_memory_optimize_transformer.wmt16.recordio'
from test_parallel_executor_transformer import TestTransformer
from test_parallel_executor_transformer import transformer
from test_parallel_executor_transformer import transformer, ModelHyperParams, transformer_model, transformer, prepare_batch_input
from parallel_executor_test_base import TestParallelExecutorBase
# disable temporarily because of timeout.
sys.exit(0)
# NOTE(dzhwinter): test diferent strategy colisions.
# open the eager delete tensor strategy by default.
class TestTransformerWithIR(TestTransformer):
class TestTransformerWithIR(TestParallelExecutorBase):
@classmethod
def setUpClass(cls):
os.environ['CPU_NUM'] = str(4)
reader = paddle.batch(
wmt16.train(ModelHyperParams.src_vocab_size,
ModelHyperParams.trg_vocab_size),
batch_size=transformer_model.batch_size)
with fluid.recordio_writer.create_recordio_writer(
os.environ.get("RECORDIO_FILENAME")) as writer:
for batch in reader():
for tensor in prepare_batch_input(
batch, ModelHyperParams.src_pad_idx,
ModelHyperParams.trg_pad_idx, ModelHyperParams.n_head):
t = fluid.LoDTensor()
t.set(tensor, fluid.CPUPlace())
writer.append_tensor(t)
writer.complete_append_tensor()
def test_main(self):
if core.is_compiled_with_cuda():
# check python transpiler
......@@ -35,13 +61,15 @@ class TestTransformerWithIR(TestTransformer):
transformer,
use_cuda=True,
memory_opt=True,
use_ir_memory_optimize=False)
use_ir_memory_optimize=False,
iter=2)
# check IR memory optimize
self.check_network_convergence(
transformer,
use_cuda=True,
memory_opt=False,
use_ir_memory_optimize=True)
use_ir_memory_optimize=True,
iter=2)
if __name__ == '__main__':
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册