提交 c5fbe5e2 编写于 作者: Q qingqing01

Remove fluid in OCR

上级 b8cb839a
...@@ -15,8 +15,7 @@ import logging ...@@ -15,8 +15,7 @@ import logging
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
import paddle import paddle
from paddle import fluid from paddle.distributed import ParallelEnv
from paddle.fluid.dygraph.parallel import ParallelEnv
DATA_MD5 = "7256b1d5420d8c3e74815196e58cdad5" DATA_MD5 = "7256b1d5420d8c3e74815196e58cdad5"
DATA_URL = "http://paddle-ocr-data.bj.bcebos.com/data.tar.gz" DATA_URL = "http://paddle-ocr-data.bj.bcebos.com/data.tar.gz"
...@@ -97,7 +96,7 @@ class PadTarget(object): ...@@ -97,7 +96,7 @@ class PadTarget(object):
return samples return samples
class BatchSampler(fluid.io.BatchSampler): class BatchSampler(paddle.io.BatchSampler):
def __init__(self, def __init__(self,
dataset, dataset,
batch_size, batch_size,
......
...@@ -17,7 +17,6 @@ import argparse ...@@ -17,7 +17,6 @@ import argparse
import functools import functools
import paddle import paddle
import paddle.fluid as fluid
from paddle.static import InputSpec as Input from paddle.static import InputSpec as Input
from paddle.vision.transforms import BatchCompose from paddle.vision.transforms import BatchCompose
...@@ -47,7 +46,7 @@ add_arg('dynamic', bool, False, "Whether to use dygraph. ...@@ -47,7 +46,7 @@ add_arg('dynamic', bool, False, "Whether to use dygraph.
def main(FLAGS): def main(FLAGS):
device = paddle.set_device("gpu" if FLAGS.use_gpu else "cpu") device = paddle.set_device("gpu" if FLAGS.use_gpu else "cpu")
fluid.enable_dygraph(device) if FLAGS.dynamic else None paddle.disable_static(device) if FLAGS.dynamic else None
# yapf: disable # yapf: disable
inputs = [ inputs = [
...@@ -79,7 +78,7 @@ def main(FLAGS): ...@@ -79,7 +78,7 @@ def main(FLAGS):
batch_size=FLAGS.batch_size, batch_size=FLAGS.batch_size,
drop_last=False, drop_last=False,
shuffle=False) shuffle=False)
test_loader = fluid.io.DataLoader( test_loader = paddle.io.DataLoader(
test_dataset, test_dataset,
batch_sampler=test_sampler, batch_sampler=test_sampler,
places=device, places=device,
...@@ -94,7 +93,7 @@ def main(FLAGS): ...@@ -94,7 +93,7 @@ def main(FLAGS):
def beam_search(FLAGS): def beam_search(FLAGS):
device = set_device("gpu" if FLAGS.use_gpu else "cpu") device = set_device("gpu" if FLAGS.use_gpu else "cpu")
fluid.enable_dygraph(device) if FLAGS.dynamic else None paddle.disable_static(device) if FLAGS.dynamic else None
# yapf: disable # yapf: disable
inputs = [ inputs = [
...@@ -128,7 +127,7 @@ def beam_search(FLAGS): ...@@ -128,7 +127,7 @@ def beam_search(FLAGS):
batch_size=FLAGS.batch_size, batch_size=FLAGS.batch_size,
drop_last=False, drop_last=False,
shuffle=False) shuffle=False)
test_loader = fluid.io.DataLoader( test_loader = paddle.io.DataLoader(
test_dataset, test_dataset,
batch_sampler=test_sampler, batch_sampler=test_sampler,
places=device, places=device,
......
...@@ -23,7 +23,6 @@ import functools ...@@ -23,7 +23,6 @@ import functools
from PIL import Image from PIL import Image
import paddle import paddle
import paddle.fluid as fluid
from paddle.static import InputSpec as Input from paddle.static import InputSpec as Input
from paddle.vision.datasets.folder import ImageFolder from paddle.vision.datasets.folder import ImageFolder
...@@ -53,7 +52,7 @@ add_arg('dynamic', bool, False, "Whether to use dygraph.") ...@@ -53,7 +52,7 @@ add_arg('dynamic', bool, False, "Whether to use dygraph.")
def main(FLAGS): def main(FLAGS):
device = paddle.set_device("gpu" if FLAGS.use_gpu else "cpu") device = paddle.set_device("gpu" if FLAGS.use_gpu else "cpu")
fluid.enable_dygraph(device) if FLAGS.dynamic else None paddle.disable_static(device) if FLAGS.dynamic else None
inputs = [Input([None, 1, 48, 384], "float32", name="pixel"), ] inputs = [Input([None, 1, 48, 384], "float32", name="pixel"), ]
model = paddle.Model( model = paddle.Model(
...@@ -71,7 +70,7 @@ def main(FLAGS): ...@@ -71,7 +70,7 @@ def main(FLAGS):
fn = lambda p: Image.open(p).convert('L') fn = lambda p: Image.open(p).convert('L')
test_dataset = ImageFolder(FLAGS.image_path, loader=fn) test_dataset = ImageFolder(FLAGS.image_path, loader=fn)
test_collate_fn = BatchCompose([data.Resize(), data.Normalize()]) test_collate_fn = BatchCompose([data.Resize(), data.Normalize()])
test_loader = fluid.io.DataLoader( test_loader = paddle.io.DataLoader(
test_dataset, test_dataset,
places=device, places=device,
num_workers=0, num_workers=0,
......
...@@ -16,11 +16,10 @@ from __future__ import print_function ...@@ -16,11 +16,10 @@ from __future__ import print_function
import numpy as np import numpy as np
import paddle import paddle
import paddle.fluid as fluid import paddle.nn as nn
import paddle.fluid.layers as layers import paddle.nn.functional as F
from paddle.fluid.layers import BeamSearchDecoder #from paddle.text import RNNCell, RNN, DynamicDecode
from paddle.text import DynamicDecode, BeamSearchDecoder
from paddle.text import RNNCell, RNN, DynamicDecode
class ConvBNPool(paddle.nn.Layer): class ConvBNPool(paddle.nn.Layer):
...@@ -36,103 +35,99 @@ class ConvBNPool(paddle.nn.Layer): ...@@ -36,103 +35,99 @@ class ConvBNPool(paddle.nn.Layer):
filter_size = 3 filter_size = 3
std = (2.0 / (filter_size**2 * in_ch))**0.5 std = (2.0 / (filter_size**2 * in_ch))**0.5
param_0 = fluid.ParamAttr( param_0 = paddle.ParamAttr(
initializer=fluid.initializer.Normal(0.0, std)) initializer=paddle.nn.initializer.Normal(0.0, std))
std = (2.0 / (filter_size**2 * out_ch))**0.5 std = (2.0 / (filter_size**2 * out_ch))**0.5
param_1 = fluid.ParamAttr( param_1 = paddle.ParamAttr(
initializer=fluid.initializer.Normal(0.0, std)) initializer=paddle.nn.initializer.Normal(0.0, std))
self.conv0 = fluid.dygraph.Conv2D( net = [
in_ch, nn.Conv2d(
out_ch, in_ch,
3, out_ch,
padding=1, 3,
param_attr=param_0, padding=1,
bias_attr=False, weight_attr=param_0,
act=None, bias_attr=False),
use_cudnn=use_cudnn) nn.BatchNorm2d(out_ch),
self.bn0 = fluid.dygraph.BatchNorm(out_ch, act=act) ]
self.conv1 = fluid.dygraph.Conv2D( if act == 'relu':
out_ch, net += [nn.ReLU()]
out_ch,
filter_size=3, net += [
padding=1, nn.Conv2d(
param_attr=param_1, out_ch,
bias_attr=False, out_ch,
act=None, kernel_size=3,
use_cudnn=use_cudnn) padding=1,
self.bn1 = fluid.dygraph.BatchNorm(out_ch, act=act) weight_attr=param_1,
bias_attr=False),
nn.BatchNorm2d(out_ch),
]
if act == 'relu':
net += [nn.ReLU()]
if self.pool: if self.pool:
self.pool = fluid.dygraph.Pool2D( net += [nn.MaxPool2d(kernel_size=2, stride=2, ceil_mode=True)]
pool_size=2, self.net = nn.Sequential(*net)
pool_type='max',
pool_stride=2,
use_cudnn=use_cudnn,
ceil_mode=True)
def forward(self, inputs): def forward(self, inputs):
out = self.conv0(inputs) return self.net(inputs)
out = self.bn0(out)
out = self.conv1(out)
out = self.bn1(out)
if self.pool:
out = self.pool(out)
return out
class CNN(paddle.nn.Layer): class CNN(paddle.nn.Layer):
def __init__(self, in_ch=1, is_test=False): def __init__(self, in_ch=1, is_test=False):
super(CNN, self).__init__() super(CNN, self).__init__()
self.conv_bn1 = ConvBNPool(in_ch, 16) net = [
self.conv_bn2 = ConvBNPool(16, 32) ConvBNPool(in_ch, 16),
self.conv_bn3 = ConvBNPool(32, 64) ConvBNPool(16, 32),
self.conv_bn4 = ConvBNPool(64, 128, pool=False) ConvBNPool(32, 64),
ConvBNPool(
64, 128, pool=False),
]
self.net = nn.Sequential(*net)
def forward(self, inputs): def forward(self, inputs):
conv = self.conv_bn1(inputs) return self.net(inputs)
conv = self.conv_bn2(conv)
conv = self.conv_bn3(conv)
conv = self.conv_bn4(conv) #class GRUCell(RNNCell):
return conv # def __init__(self,
# input_size,
# hidden_size,
class GRUCell(RNNCell): # param_attr=None,
def __init__(self, # bias_attr=None,
input_size, # gate_activation='sigmoid',
hidden_size, # candidate_activation='tanh',
param_attr=None, # origin_mode=False):
bias_attr=None, # super(GRUCell, self).__init__()
gate_activation='sigmoid', # self.hidden_size = hidden_size
candidate_activation='tanh', # self.fc_layer = nn.Linear(
origin_mode=False): # input_size,
super(GRUCell, self).__init__() # hidden_size * 3,
self.hidden_size = hidden_size # weight_attr=param_attr,
self.fc_layer = fluid.dygraph.Linear( # bias_attr=False)
input_size, #
hidden_size * 3, # self.gru_unit = fluid.dygraph.GRUUnit(
param_attr=param_attr, # hidden_size * 3,
bias_attr=False) # param_attr=param_attr,
# bias_attr=bias_attr,
self.gru_unit = fluid.dygraph.GRUUnit( # activation=candidate_activation,
hidden_size * 3, # gate_activation=gate_activation,
param_attr=param_attr, # origin_mode=origin_mode)
bias_attr=bias_attr, #
activation=candidate_activation, # def forward(self, inputs, states):
gate_activation=gate_activation, # # step_outputs, new_states = cell(step_inputs, states)
origin_mode=origin_mode) # # for GRUCell, `step_outputs` and `new_states` both are hidden
# x = self.fc_layer(inputs)
def forward(self, inputs, states): # hidden, _, _ = self.gru_unit(x, states)
# step_outputs, new_states = cell(step_inputs, states) # return hidden, hidden
# for GRUCell, `step_outputs` and `new_states` both are hidden #
x = self.fc_layer(inputs) # @property
hidden, _, _ = self.gru_unit(x, states) # def state_shape(self):
return hidden, hidden # return [self.hidden_size]
#
@property
def state_shape(self):
return [self.hidden_size]
class Encoder(paddle.nn.Layer): class Encoder(paddle.nn.Layer):
...@@ -147,41 +142,41 @@ class Encoder(paddle.nn.Layer): ...@@ -147,41 +142,41 @@ class Encoder(paddle.nn.Layer):
self.backbone = CNN(in_ch=in_channel, is_test=is_test) self.backbone = CNN(in_ch=in_channel, is_test=is_test)
para_attr = fluid.ParamAttr( para_attr = paddle.ParamAttr(
initializer=fluid.initializer.Normal(0.0, 0.02)) initializer=paddle.nn.initializer.Normal(0.0, 0.02))
bias_attr = fluid.ParamAttr( bias_attr = paddle.ParamAttr(
initializer=fluid.initializer.Normal(0.0, 0.02), learning_rate=2.0) initializer=paddle.nn.initializer.Normal(0.0, 0.02),
self.gru_fwd = RNN(cell=GRUCell( learning_rate=2.0)
input_size=128 * 6, self.gru_fwd = nn.RNN(
hidden_size=rnn_hidden_size, cell=nn.GRUCell(
param_attr=para_attr, input_size=128 * 6, hidden_size=rnn_hidden_size),
bias_attr=bias_attr, # param_attr=para_attr,
candidate_activation='relu'), # bias_attr=bias_attr,
is_reverse=False, # candidate_activation='relu'),
time_major=False) is_reverse=False,
self.gru_bwd = RNN(cell=GRUCell( time_major=False)
input_size=128 * 6, self.gru_bwd = nn.RNN(
hidden_size=rnn_hidden_size, cell=nn.GRUCell(
param_attr=para_attr, input_size=128 * 6, hidden_size=rnn_hidden_size),
bias_attr=bias_attr, # param_attr=para_attr,
candidate_activation='relu'), # bias_attr=bias_attr,
is_reverse=True, # candidate_activation='relu'),
time_major=False) is_reverse=True,
self.encoded_proj_fc = fluid.dygraph.Linear( time_major=False)
self.encoded_proj_fc = nn.Linear(
rnn_hidden_size * 2, decoder_size, bias_attr=False) rnn_hidden_size * 2, decoder_size, bias_attr=False)
def forward(self, inputs): def forward(self, inputs):
conv_features = self.backbone(inputs) conv_features = self.backbone(inputs)
conv_features = fluid.layers.transpose( conv_features = paddle.transpose(conv_features, perm=[0, 3, 1, 2])
conv_features, perm=[0, 3, 1, 2])
n, w, c, h = conv_features.shape n, w, c, h = conv_features.shape
seq_feature = fluid.layers.reshape(conv_features, [0, -1, c * h]) seq_feature = paddle.reshape(conv_features, [0, -1, c * h])
gru_fwd, _ = self.gru_fwd(seq_feature) gru_fwd, _ = self.gru_fwd(seq_feature)
gru_bwd, _ = self.gru_bwd(seq_feature) gru_bwd, _ = self.gru_bwd(seq_feature)
encoded_vector = fluid.layers.concat(input=[gru_fwd, gru_bwd], axis=2) encoded_vector = paddle.concat([gru_fwd, gru_bwd], axis=2)
encoded_proj = self.encoded_proj_fc(encoded_vector) encoded_proj = self.encoded_proj_fc(encoded_vector)
return gru_bwd, encoded_vector, encoded_proj return gru_bwd, encoded_vector, encoded_proj
...@@ -194,39 +189,37 @@ class Attention(paddle.nn.Layer): ...@@ -194,39 +189,37 @@ class Attention(paddle.nn.Layer):
def __init__(self, decoder_size): def __init__(self, decoder_size):
super(Attention, self).__init__() super(Attention, self).__init__()
self.fc1 = fluid.dygraph.Linear( self.fc1 = nn.Linear(decoder_size, decoder_size, bias_attr=False)
decoder_size, decoder_size, bias_attr=False) self.fc2 = nn.Linear(decoder_size, 1, bias_attr=False)
self.fc2 = fluid.dygraph.Linear(decoder_size, 1, bias_attr=False)
def forward(self, encoder_vec, encoder_proj, decoder_state): def forward(self, encoder_vec, encoder_proj, decoder_state):
# alignment model, single-layer multilayer perceptron # alignment model, single-layer multilayer perceptron
decoder_state = self.fc1(decoder_state) decoder_state = self.fc1(decoder_state)
decoder_state = fluid.layers.unsqueeze(decoder_state, [1]) decoder_state = paddle.unsqueeze(decoder_state, [1])
e = fluid.layers.elementwise_add(encoder_proj, decoder_state) e = paddle.add(encoder_proj, decoder_state)
e = fluid.layers.tanh(e) e = paddle.tanh(e)
att_scores = self.fc2(e) att_scores = self.fc2(e)
att_scores = fluid.layers.squeeze(att_scores, [2]) att_scores = paddle.squeeze(att_scores, [2])
att_scores = fluid.layers.softmax(att_scores) att_scores = F.softmax(att_scores)
context = fluid.layers.elementwise_mul( context = paddle.multiply(encoder_vec, att_scores, axis=0)
x=encoder_vec, y=att_scores, axis=0) context = paddle.reduce_sum(context, dim=1)
context = fluid.layers.reduce_sum(context, dim=1)
return context return context
class DecoderCell(RNNCell): class DecoderCell(nn.RNNCellBase):
def __init__(self, encoder_size=200, decoder_size=128): def __init__(self, encoder_size=200, decoder_size=128):
super(DecoderCell, self).__init__() super(DecoderCell, self).__init__()
self.attention = Attention(decoder_size) self.attention = Attention(decoder_size)
self.gru_cell = GRUCell( self.gru_cell = nn.GRUCell(
input_size=encoder_size * 2 + decoder_size, input_size=encoder_size * 2 + decoder_size,
hidden_size=decoder_size) hidden_size=decoder_size)
def forward(self, current_word, states, encoder_vec, encoder_proj): def forward(self, current_word, states, encoder_vec, encoder_proj):
context = self.attention(encoder_vec, encoder_proj, states) context = self.attention(encoder_vec, encoder_proj, states)
decoder_inputs = fluid.layers.concat([current_word, context], axis=1) decoder_inputs = paddle.concat([current_word, context], axis=1)
hidden, _ = self.gru_cell(decoder_inputs, states) hidden, _ = self.gru_cell(decoder_inputs, states)
return hidden, hidden return hidden, hidden
...@@ -234,9 +227,9 @@ class DecoderCell(RNNCell): ...@@ -234,9 +227,9 @@ class DecoderCell(RNNCell):
class Decoder(paddle.nn.Layer): class Decoder(paddle.nn.Layer):
def __init__(self, num_classes, emb_dim, encoder_size, decoder_size): def __init__(self, num_classes, emb_dim, encoder_size, decoder_size):
super(Decoder, self).__init__() super(Decoder, self).__init__()
self.decoder_attention = RNN(DecoderCell(encoder_size, decoder_size)) self.decoder_attention = nn.RNN(
self.fc = fluid.dygraph.Linear( DecoderCell(encoder_size, decoder_size))
decoder_size, num_classes + 2, act='softmax') self.fc = nn.Linear(decoder_size, num_classes + 2)
def forward(self, target, initial_states, encoder_vec, encoder_proj): def forward(self, target, initial_states, encoder_vec, encoder_proj):
out, _ = self.decoder_attention( out, _ = self.decoder_attention(
...@@ -258,13 +251,10 @@ class Seq2SeqAttModel(paddle.nn.Layer): ...@@ -258,13 +251,10 @@ class Seq2SeqAttModel(paddle.nn.Layer):
num_classes=None, ): num_classes=None, ):
super(Seq2SeqAttModel, self).__init__() super(Seq2SeqAttModel, self).__init__()
self.encoder = Encoder(in_channle, encoder_size, decoder_size) self.encoder = Encoder(in_channle, encoder_size, decoder_size)
self.fc = fluid.dygraph.Linear( self.fc = nn.Sequential(
input_dim=encoder_size, nn.Linear(
output_dim=decoder_size, encoder_size, decoder_size, bias_attr=False), nn.ReLU())
bias_attr=False, self.embedding = nn.Embedding(num_classes + 2, emb_dim)
act='relu')
self.embedding = fluid.dygraph.Embedding(
[num_classes + 2, emb_dim], dtype='float32')
self.decoder = Decoder(num_classes, emb_dim, encoder_size, self.decoder = Decoder(num_classes, emb_dim, encoder_size,
decoder_size) decoder_size)
...@@ -326,7 +316,10 @@ class WeightCrossEntropy(paddle.nn.Layer): ...@@ -326,7 +316,10 @@ class WeightCrossEntropy(paddle.nn.Layer):
super(WeightCrossEntropy, self).__init__() super(WeightCrossEntropy, self).__init__()
def forward(self, predict, label, mask): def forward(self, predict, label, mask):
loss = layers.cross_entropy(predict, label=label) predict = paddle.flatten(predict, start_axis=0, stop_axis=1)
loss = layers.elementwise_mul(loss, mask, axis=0) label = paddle.reshape(label, shape=[-1, 1])
loss = layers.reduce_sum(loss) mask = paddle.reshape(mask, shape=[-1, 1])
loss = F.cross_entropy(predict, label=label)
loss = paddle.multiply(loss, mask, axis=0)
loss = paddle.sum(loss)
return loss return loss
...@@ -59,7 +59,7 @@ add_arg('dynamic', bool, False, "Whether to use dygraph.") ...@@ -59,7 +59,7 @@ add_arg('dynamic', bool, False, "Whether to use dygraph.")
def main(FLAGS): def main(FLAGS):
device = paddle.set_device("gpu" if FLAGS.use_gpu else "cpu") device = paddle.set_device("gpu" if FLAGS.use_gpu else "cpu")
fluid.enable_dygraph(device) if FLAGS.dynamic else None paddle.disable_static(device) if FLAGS.dynamic else None
# yapf: disable # yapf: disable
inputs = [ inputs = [
...@@ -100,7 +100,7 @@ def main(FLAGS): ...@@ -100,7 +100,7 @@ def main(FLAGS):
[data.Resize(), data.Normalize(), data.PadTarget()]) [data.Resize(), data.Normalize(), data.PadTarget()])
train_sampler = data.BatchSampler( train_sampler = data.BatchSampler(
train_dataset, batch_size=FLAGS.batch_size, shuffle=True) train_dataset, batch_size=FLAGS.batch_size, shuffle=True)
train_loader = fluid.io.DataLoader( train_loader = paddle.io.DataLoader(
train_dataset, train_dataset,
batch_sampler=train_sampler, batch_sampler=train_sampler,
places=device, places=device,
...@@ -115,7 +115,7 @@ def main(FLAGS): ...@@ -115,7 +115,7 @@ def main(FLAGS):
batch_size=FLAGS.batch_size, batch_size=FLAGS.batch_size,
drop_last=False, drop_last=False,
shuffle=False) shuffle=False)
test_loader = fluid.io.DataLoader( test_loader = paddle.io.DataLoader(
test_dataset, test_dataset,
batch_sampler=test_sampler, batch_sampler=test_sampler,
places=device, places=device,
......
...@@ -21,7 +21,6 @@ import numpy as np ...@@ -21,7 +21,6 @@ import numpy as np
import six import six
import paddle import paddle
import paddle.fluid as fluid
from paddle.metric import Metric from paddle.metric import Metric
...@@ -74,8 +73,8 @@ class SeqAccuracy(Metric): ...@@ -74,8 +73,8 @@ class SeqAccuracy(Metric):
self.reset() self.reset()
def compute(self, output, label, mask, *args, **kwargs): def compute(self, output, label, mask, *args, **kwargs):
pred = fluid.layers.flatten(output, axis=2) pred = paddle.flatten(output, start_axis=2)
score, topk = fluid.layers.topk(pred, 1) score, topk = paddle.topk(pred, 1)
return topk, label, mask return topk, label, mask
def update(self, topk, label, mask, *args, **kwargs): def update(self, topk, label, mask, *args, **kwargs):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册