提交 a7d6b1af 编写于 作者: W wanghaoshuang

Fix some issues

上级 fbbf6c04
...@@ -11,8 +11,6 @@ ...@@ -11,8 +11,6 @@
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. #WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and #See the License for the specific language governing permissions and
#limitations under the License. #limitations under the License.
from __future__ import print_function
import sys import sys
import paddle.v2 as paddle import paddle.v2 as paddle
...@@ -42,7 +40,7 @@ def ocr_conv(input, num, with_bn): ...@@ -42,7 +40,7 @@ def ocr_conv(input, num, with_bn):
num_classes = 9054 num_classes = 9054
data_shape = [3, 32, 32] data_shape = [1, 512, 512]
images = fluid.layers.data(name='pixel', shape=data_shape, dtype='float32') images = fluid.layers.data(name='pixel', shape=data_shape, dtype='float32')
label = fluid.layers.data(name='label', shape=[1], dtype='int64') label = fluid.layers.data(name='label', shape=[1], dtype='int64')
...@@ -57,11 +55,10 @@ sliced_feature = fluid.layers.im2sequence( ...@@ -57,11 +55,10 @@ sliced_feature = fluid.layers.im2sequence(
block_x=1, block_x=1,
block_y=3, ) block_y=3, )
gru_forward = fluid.layers.gru(input=sliced_feature, size=200, act="relu") # TODO(wanghaoshuang): repaced by GRU
gru_backward = fluid.layers.gru(input=sliced_feature, gru_forward = fluid.layers.lstm(input=sliced_feature, size=200, act="relu")
size=200, gru_backward = fluid.layers.lstm(
reverse=True, input=sliced_feature, size=200, reverse=True, act="relu")
act="relu")
out = fluid.layers.fc(input=[gru_forward, gru_backward], size=num_classes + 1) out = fluid.layers.fc(input=[gru_forward, gru_backward], size=num_classes + 1)
cost = fluid.layers.warpctc( cost = fluid.layers.warpctc(
...@@ -70,17 +67,20 @@ cost = fluid.layers.warpctc( ...@@ -70,17 +67,20 @@ cost = fluid.layers.warpctc(
size=num_classes + 1, size=num_classes + 1,
blank=num_classes, blank=num_classes,
norm_by_times=True) norm_by_times=True)
avg_cost = fluid.layers.mean(x=cost)
# TODO(wanghaoshuang): set clipping
optimizer = fluid.optimizer.Momentum( optimizer = fluid.optimizer.Momentum(
learning_rate=((1.0e-3) / 16), momentum=0.9) learning_rate=((1.0e-3) / 16), momentum=0.9)
opts = optimizer.minimize(cost) opts = optimizer.minimize(cost)
decoded_out = fluid.layers.ctc_greedy_decoder(input=output, blank=class_num) decoded_out = fluid.layers.ctc_greedy_decoder(input=output, blank=class_num)
error = fluid.evaluator.EditDistance(input=decoded_out, label=label) error_evaluator = fluid.evaluator.EditDistance(input=decoded_out, label=label)
BATCH_SIZE = 16 BATCH_SIZE = 16
PASS_NUM = 1 PASS_NUM = 1
# TODO(wanghaoshuang): replaced by correct data reader
train_reader = paddle.batch( train_reader = paddle.batch(
paddle.reader.shuffle( paddle.reader.shuffle(
paddle.dataset.cifar.train10(), buf_size=128 * 10), paddle.dataset.cifar.train10(), buf_size=128 * 10),
...@@ -92,14 +92,11 @@ feeder = fluid.DataFeeder(place=place, feed_list=[images, label]) ...@@ -92,14 +92,11 @@ feeder = fluid.DataFeeder(place=place, feed_list=[images, label])
exe.run(fluid.default_startup_program()) exe.run(fluid.default_startup_program())
for pass_id in range(PASS_NUM): for pass_id in range(PASS_NUM):
accuracy.reset(exe) error_evaluator.reset(exe)
for data in train_reader(): for data in train_reader():
loss, acc = exe.run(fluid.default_main_program(), loss, error = exe.run(fluid.default_main_program(),
feed=feeder.feed(data), feed=feeder.feed(data),
fetch_list=[avg_cost] + accuracy.metrics) fetch_list=[avg_cost] + error.metrics)
pass_acc = accuracy.eval(exe) pass_error = error_evaluator.eval(exe)
print("loss:" + str(loss) + " acc:" + str(acc) + " pass_acc:" + str( print "loss: %s; distance error: %s; pass_dis_error: %s;" % (
pass_acc)) str(loss), str(error), str(pass_error))
# this model is slow, so if we can train two mini batch, we think it works properly.
exit(0)
exit(1)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册