未验证 提交 5efb3d3d 编写于 作者: Y Yibing Liu 提交者: GitHub

Merge pull request #1122 from kuke/text_cls_fix

Fix paddle import in text_cls
...@@ -4,8 +4,8 @@ import unittest ...@@ -4,8 +4,8 @@ import unittest
import contextlib import contextlib
import numpy as np import numpy as np
import paddle
import paddle.fluid as fluid import paddle.fluid as fluid
import paddle.v2 as paddle
import utils import utils
......
...@@ -4,8 +4,8 @@ import time ...@@ -4,8 +4,8 @@ import time
import unittest import unittest
import contextlib import contextlib
import paddle
import paddle.fluid as fluid import paddle.fluid as fluid
import paddle.v2 as paddle
import utils import utils
from nets import bow_net from nets import bow_net
...@@ -55,7 +55,7 @@ def train(train_reader, ...@@ -55,7 +55,7 @@ def train(train_reader,
feeder = fluid.DataFeeder(feed_list=[data, label], place=place) feeder = fluid.DataFeeder(feed_list=[data, label], place=place)
# For internal continuous evaluation # For internal continuous evaluation
if 'CE_MODE_X' in os.environ: if "CE_MODE_X" in os.environ:
fluid.default_startup_program().random_seed = 110 fluid.default_startup_program().random_seed = 110
exe.run(fluid.default_startup_program()) exe.run(fluid.default_startup_program())
for pass_id in xrange(pass_num): for pass_id in xrange(pass_num):
...@@ -80,7 +80,7 @@ def train(train_reader, ...@@ -80,7 +80,7 @@ def train(train_reader,
pass_end = time.time() pass_end = time.time()
# For internal continuous evaluation # For internal continuous evaluation
if 'CE_MODE_X' in os.environ: if "CE_MODE_X" in os.environ:
print("kpis train_acc %f" % avg_acc) print("kpis train_acc %f" % avg_acc)
print("kpis train_cost %f" % avg_cost) print("kpis train_cost %f" % avg_cost)
print("kpis train_duration %f" % (pass_end - pass_start)) print("kpis train_duration %f" % (pass_end - pass_start))
......
...@@ -65,7 +65,7 @@ def prepare_data(data_type="imdb", ...@@ -65,7 +65,7 @@ def prepare_data(data_type="imdb",
raise RuntimeError("No such dataset") raise RuntimeError("No such dataset")
if data_type == "imdb": if data_type == "imdb":
if 'CE_MODE_X' in os.environ: if "CE_MODE_X" in os.environ:
train_reader = paddle.batch( train_reader = paddle.batch(
paddle.dataset.imdb.train(word_dict), batch_size=batch_size) paddle.dataset.imdb.train(word_dict), batch_size=batch_size)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册