Unverified commit 161f5814, authored by chenjian, committed by GitHub

remove fluid api in paddlehub compat code (#2118)

* remove fluid api in paddlehub compat code

* fix

* fix

* fix

* fix
Parent 3d33cadc
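Every hunk below follows the same pattern: a legacy paddle.fluid call is swapped for its Paddle 2.x equivalent while the code stays in static-graph mode. As a rough orientation, here is a minimal standalone sketch of the post-migration style (illustrative only, not part of this diff; the layer sizes and variable names are made up):

    import paddle
    import paddle.nn as nn
    from paddle import ParamAttr

    paddle.enable_static()  # the compat code still builds static-graph programs

    main_prog = paddle.static.Program()
    startup_prog = paddle.static.Program()
    with paddle.static.program_guard(main_prog, startup_prog):
        # paddle.static.data replaces fluid.layers.data
        x = paddle.static.data(name='x', shape=[None, 8], dtype='float32')
        # paddle.static.nn.fc with weight_attr replaces fluid.layers.fc with param_attr
        h = paddle.static.nn.fc(x, 16,
                                weight_attr=ParamAttr(initializer=nn.initializer.Uniform(-0.1, 0.1)))
        # nn.functional.softmax replaces fluid.layers.softmax
        y = nn.functional.softmax(h)

    exe = paddle.static.Executor(paddle.CPUPlace())
    exe.run(startup_prog)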
@@ -17,7 +17,7 @@ import os
 from typing import List
 from typing import Tuple
 
-import paddle
+import paddle.io
 import paddle2onnx
 from easydict import EasyDict
@@ -179,15 +179,18 @@ class ModuleV1(object):
                 for item in zip(*process_data):
                     yield item
 
+            nonlocal feed_dict
             process_data = []
             feed_name_list = []
+            feed_list = []
             for key in data_format:
                 process_data.append([value['processed'] for value in data[key]])
                 feed_name_list.append(data_format[key]['feed_key'])
-            feeder = paddle.fluid.DataFeeder(feed_list=feed_name_list, place=place)
-            return functools.partial(_reader, process_data=process_data), feeder
+                feed_list.append(feed_dict[key])
+            loader = paddle.io.DataLoader.from_generator(feed_list=feed_list, capacity=1)
+            return functools.partial(_reader, process_data=process_data), loader
 
-        _, fetch_dict, program = self.context(signature=sign_name, for_test=True)
+        feed_dict, fetch_dict, program = self.context(signature=sign_name, for_test=True)
         fetch_list = list([value for key, value in fetch_dict.items()])
         with paddle.static.program_guard(program):
             result = []
@@ -197,10 +200,11 @@ class ModuleV1(object):
             exe = paddle.static.Executor(place=place)
             data = self.processor.preprocess(sign_name=sign_name, data_dict=data)
             data_format = self.processor.data_format(sign_name=sign_name)
-            reader, feeder = _get_reader_and_feeder(data_format, data, place)
+            reader, loader = _get_reader_and_feeder(data_format, data, place)
             reader = paddle.batch(reader, batch_size=batch_size)
-            for batch in reader():
-                data_out = exe.run(feed=feeder.feed(batch), fetch_list=fetch_list, return_numpy=False)
+            loader.set_sample_list_generator(reader, places=place)
+            for batch in loader():
+                data_out = exe.run(feed=batch, fetch_list=fetch_list, return_numpy=False)
                 sub_data = {key: value[index:index + len(batch)] for key, value in data.items()}
                 result += self.processor.postprocess(sign_name, data_out, sub_data, **kwargs)
                 index += len(batch)
......
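For the hunks above: in Paddle 2.x static mode, paddle.io.DataLoader.from_generator is bound to a feed_list of static variables and fed through set_sample_list_generator, which is what replaces fluid.DataFeeder here. A minimal standalone sketch of that pattern, assuming a toy network and random data (not code from this repository):

    import numpy as np
    import paddle

    paddle.enable_static()
    place = paddle.CPUPlace()

    main_prog = paddle.static.Program()
    with paddle.static.program_guard(main_prog):
        image = paddle.static.data(name='image', shape=[None, 4], dtype='float32')
        out = paddle.static.nn.fc(image, 2)
        # The loader is bound to the feed variables, as in the patched _get_reader_and_feeder.
        loader = paddle.io.DataLoader.from_generator(feed_list=[image], capacity=1)

    def sample_reader():
        # One tuple per sample, in the same order as feed_list.
        for _ in range(8):
            yield (np.random.rand(4).astype('float32'),)

    exe = paddle.static.Executor(place)
    exe.run(paddle.static.default_startup_program())

    loader.set_sample_list_generator(paddle.batch(sample_reader, batch_size=4), places=place)
    for batch in loader():
        # Each batch is already in exe.run feed format, so no DataFeeder is needed.
        result = exe.run(main_prog, feed=batch, fetch_list=[out])

With this in place the DataLoader, rather than a DataFeeder, converts each sample list into the feed consumed by exe.run.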
@@ -12,12 +12,11 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import copy
 import time
 from typing import Any
 
-import paddle
+import paddle.utils.unique_name
 
 
 class RunState(object):
@@ -63,7 +62,7 @@ class RunEnv(object):
         self.labels = None
         self.metrics = None
         self.is_inititalized = False
-        self.UNG = copy.deepcopy(paddle.fluid.unique_name.generator)
+        self.UNG = paddle.utils.unique_name.generate
 
     def __setattr__(self, key: str, value: Any):
         self.__dict__[key] = value
......
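For the RunEnv change just above: paddle.utils.unique_name.generate is the Paddle 2.x entry point for prefix-based unique names, which is what the deep-copied fluid generator previously provided. A quick illustrative sketch (not from the repository):

    import paddle

    # Each call returns a fresh name for the given prefix: 'fc_0', 'fc_1', ...
    print(paddle.utils.unique_name.generate('fc'))
    print(paddle.utils.unique_name.generate('fc'))

    # A guard scopes the counters, so name generation can be isolated if needed.
    with paddle.utils.unique_name.guard():
        print(paddle.utils.unique_name.generate('fc'))  # counting restarts inside the guard

Note that the new code stores the generate function itself in self.UNG rather than a copied generator object.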
@@ -12,69 +12,72 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import time
 from collections import OrderedDict
 
 import numpy as np
-import paddle.fluid as fluid
-from paddle.fluid import ParamAttr
-from paddle.fluid.layers import RNNCell, LSTMCell, rnn, BeamSearchDecoder, dynamic_decode
+import paddle
+import paddle.nn as nn
+from paddle import ParamAttr
+from paddle.nn import BeamSearchDecoder
+from paddle.nn import dynamic_decode
+from paddle.nn import LSTMCell
+from paddle.nn import RNN
+from paddle.nn import RNNCellBase
 
-from paddlehub.compat.task.metrics import compute_bleu
 from paddlehub.compat.task.base_task import BaseTask
+from paddlehub.compat.task.metrics import compute_bleu
 
 
-class AttentionDecoderCell(RNNCell):
-    def __init__(self, num_layers, hidden_size, dropout_prob=0., init_scale=0.1):
+class AttentionDecoderCell(RNNCellBase):
+    def __init__(self, num_layers, input_size, hidden_size, dropout_prob=0., init_scale=0.1):
+        super(AttentionDecoderCell, self).__init__()
         self.num_layers = num_layers
         self.hidden_size = hidden_size
         self.dropout_prob = dropout_prob
         self.lstm_cells = []
         self.init_scale = init_scale
-        param_attr = ParamAttr(initializer=fluid.initializer.UniformInitializer(low=-init_scale, high=init_scale))
-        bias_attr = ParamAttr(initializer=fluid.initializer.Constant(0.0))
         for i in range(num_layers):
-            self.lstm_cells.append(LSTMCell(hidden_size, param_attr, bias_attr))
+            self.lstm_cells.append(
+                LSTMCell(input_size=input_size + hidden_size if i == 0 else hidden_size, hidden_size=hidden_size))
 
     def attention(self, query, enc_output, mask=None):
-        query = fluid.layers.unsqueeze(query, [1])
-        memory = fluid.layers.fc(
-            enc_output,
-            self.hidden_size,
-            num_flatten_dims=2,
-            param_attr=ParamAttr(
-                name='dec_memory_w',
-                initializer=fluid.initializer.UniformInitializer(low=-self.init_scale, high=self.init_scale)))
-        attn = fluid.layers.matmul(query, memory, transpose_y=True)
+        query = paddle.unsqueeze(query, [1])
+        memory = paddle.static.nn.fc(enc_output,
+                                     self.hidden_size,
+                                     num_flatten_dims=2,
+                                     weight_attr=ParamAttr(name='dec_memory_w',
+                                                           initializer=nn.initializer.Uniform(low=-self.init_scale,
+                                                                                              high=self.init_scale)))
+        attn = paddle.matmul(query, memory, transpose_y=True)
         if mask:
-            attn = fluid.layers.transpose(attn, [1, 0, 2])
-            attn = fluid.layers.elementwise_add(attn, mask * 1000000000, -1)
-            attn = fluid.layers.transpose(attn, [1, 0, 2])
-        weight = fluid.layers.softmax(attn)
-        weight_memory = fluid.layers.matmul(weight, memory)
+            attn = paddle.transpose(attn, [1, 0, 2])
+            attn = attn + (mask * 1000000000)
+            attn = paddle.transpose(attn, [1, 0, 2])
+        weight = nn.functional.softmax(attn)
+        weight_memory = paddle.matmul(weight, memory)
         return weight_memory
 
-    def call(self, step_input, states, enc_output, enc_padding_mask=None):
+    def forward(self, step_input, states, enc_output, enc_padding_mask=None):
         lstm_states, input_feed = states
         new_lstm_states = []
-        step_input = fluid.layers.concat([step_input, input_feed], 1)
+        step_input = paddle.concat([step_input, input_feed], 1)
         for i in range(self.num_layers):
             out, new_lstm_state = self.lstm_cells[i](step_input, lstm_states[i])
-            step_input = fluid.layers.dropout(
-                out, self.dropout_prob, dropout_implementation='upscale_in_train') if self.dropout_prob > 0 else out
+            step_input = nn.functional.dropout(out, self.dropout_prob,
+                                               mode='upscale_in_train') if self.dropout_prob > 0 else out
             new_lstm_states.append(new_lstm_state)
         dec_att = self.attention(step_input, enc_output, enc_padding_mask)
-        dec_att = fluid.layers.squeeze(dec_att, [1])
-        concat_att_out = fluid.layers.concat([dec_att, step_input], 1)
-        out = fluid.layers.fc(
-            concat_att_out,
-            self.hidden_size,
-            param_attr=ParamAttr(
-                name='dec_out_w',
-                initializer=fluid.initializer.UniformInitializer(low=-self.init_scale, high=self.init_scale)))
+        dec_att = paddle.squeeze(dec_att, [1])
+        concat_att_out = paddle.concat([dec_att, step_input], 1)
+        out = paddle.static.nn.fc(concat_att_out,
+                                  self.hidden_size,
+                                  weight_attr=ParamAttr(name='dec_out_w',
+                                                        initializer=nn.initializer.Uniform(low=-self.init_scale,
+                                                                                           high=self.init_scale)))
         return out, [new_lstm_states, out]
@@ -101,32 +104,31 @@ class TextGenerationTask(BaseTask):
     '''
 
     def __init__(
             self,
             feature,
             token_feature,
             max_seq_len,
             num_classes,
             dataset=None,
             num_layers=1,
             hidden_size=512,
             dropout=0.,
             beam_size=10,
             beam_max_step_num=30,
             start_token='<s>',
             end_token='</s>',
             startup_program=None,
             config=None,
             metrics_choices='default',
     ):
         if metrics_choices == 'default':
             metrics_choices = ['bleu']
 
         main_program = feature.block.program
-        super(TextGenerationTask, self).__init__(
-            dataset=dataset,
-            main_program=main_program,
-            startup_program=startup_program,
-            config=config,
-            metrics_choices=metrics_choices)
+        super(TextGenerationTask, self).__init__(dataset=dataset,
+                                                 main_program=main_program,
+                                                 startup_program=startup_program,
+                                                 config=config,
+                                                 metrics_choices=metrics_choices)
 
         self.num_layers = num_layers
         self.hidden_size = hidden_size
@@ -141,77 +143,73 @@ class TextGenerationTask(BaseTask):
         self.end_token = end_token
 
     def _add_label(self):
-        label = fluid.layers.data(name='label', shape=[self.max_seq_len, 1], dtype='int64')
+        label = paddle.static.data(name='label', shape=[self.max_seq_len, 1], dtype='int64')
         return [label]
 
     def _build_net(self):
-        self.seq_len = fluid.layers.data(name='seq_len', shape=[1], dtype='int64', lod_level=0)
-        self.seq_len_used = fluid.layers.squeeze(self.seq_len, axes=[1])
-        src_mask = fluid.layers.sequence_mask(self.seq_len_used, maxlen=self.max_seq_len, dtype='float32')
+        self.seq_len = paddle.static.data(name='seq_len', shape=[1], dtype='int64', lod_level=0)
+        self.seq_len_used = paddle.squeeze(self.seq_len)
+        src_mask = nn.functional.sequence_mask(self.seq_len_used, maxlen=self.max_seq_len, dtype='float32')
         enc_padding_mask = (src_mask - 1.0)
 
         # Define decoder and initialize it.
-        dec_cell = AttentionDecoderCell(self.num_layers, self.hidden_size, self.dropout)
-        dec_init_hidden = fluid.layers.fc(
-            input=self.feature,
-            size=self.hidden_size,
-            num_flatten_dims=1,
-            param_attr=fluid.ParamAttr(
-                name='dec_init_hidden_w', initializer=fluid.initializer.TruncatedNormal(scale=0.02)),
-            bias_attr=fluid.ParamAttr(name='dec_init_hidden_b', initializer=fluid.initializer.Constant(0.)))
+        dec_cell = AttentionDecoderCell(self.num_layers, self.feature.shape[-1], self.hidden_size, self.dropout)
+        dec_init_hidden = paddle.static.nn.fc(self.feature,
+                                              size=self.hidden_size,
+                                              num_flatten_dims=1,
+                                              weight_attr=ParamAttr(
+                                                  name='dec_init_hidden_w',
+                                                  initializer=nn.initializer.TruncatedNormal(std=0.02)),
+                                              bias_attr=ParamAttr(name='dec_init_hidden_b',
+                                                                  initializer=nn.initializer.Constant(0.)))
         dec_initial_states = [
             [[dec_init_hidden,
               dec_cell.get_initial_states(batch_ref=self.feature, shape=[self.hidden_size])]] * self.num_layers,
             dec_cell.get_initial_states(batch_ref=self.feature, shape=[self.hidden_size])
         ]
         tar_vocab_size = len(self._label_list)
-        tar_embeder = lambda x: fluid.embedding(
+        tar_embeder = lambda x: paddle.static.nn.embedding(
             input=x,
             size=[tar_vocab_size, self.hidden_size],
            dtype='float32',
             is_sparse=False,
-            param_attr=fluid.ParamAttr(
-                name='target_embedding', initializer=fluid.initializer.UniformInitializer(low=-0.1, high=0.1)))
+            param_attr=ParamAttr(name='target_embedding', initializer=nn.initializer.Uniform(low=-0.1, high=0.1)))
         start_token_id = self._label_list.index(self.start_token)
         end_token_id = self._label_list.index(self.end_token)
 
         if not self.is_predict_phase:
-            self.dec_input = fluid.layers.data(name='dec_input', shape=[self.max_seq_len], dtype='int64')
+            self.dec_input = paddle.static.data(name='dec_input', shape=[self.max_seq_len], dtype='int64')
             tar_emb = tar_embeder(self.dec_input)
-            dec_output, _ = rnn(
-                cell=dec_cell,
-                inputs=tar_emb,
-                initial_states=dec_initial_states,
-                sequence_length=None,
-                enc_output=self.token_feature,
-                enc_padding_mask=enc_padding_mask)
-            self.logits = fluid.layers.fc(
-                dec_output,
-                size=tar_vocab_size,
-                num_flatten_dims=len(dec_output.shape) - 1,
-                param_attr=fluid.ParamAttr(
-                    name='output_w', initializer=fluid.initializer.UniformInitializer(low=-0.1, high=0.1)))
-            self.ret_infers = fluid.layers.reshape(x=fluid.layers.argmax(self.logits, axis=2), shape=[-1, 1])
+            rnn = nn.RNN(dec_cell, is_reverse=False, time_major=False)
+            dec_output, _ = rnn(inputs=tar_emb,
+                                initial_states=dec_initial_states,
+                                enc_output=self.token_feature,
+                                enc_padding_mask=enc_padding_mask)
+            self.logits = paddle.static.nn.fc(dec_output,
+                                              size=tar_vocab_size,
+                                              num_flatten_dims=len(dec_output.shape) - 1,
+                                              weight_attr=ParamAttr(name='output_w',
+                                                                    initializer=nn.initializer.Uniform(low=-0.1,
+                                                                                                       high=0.1)))
+            self.ret_infers = paddle.reshape(x=paddle.argmax(self.logits, axis=2), shape=[-1, 1])
             logits = self.logits
-            logits = fluid.layers.softmax(logits)
+            logits = nn.functional.softmax(logits)
             return [logits]
         else:
-            output_layer = lambda x: fluid.layers.fc(
-                x, size=tar_vocab_size, num_flatten_dims=len(x.shape) - 1, param_attr=fluid.ParamAttr(name='output_w'))
-            beam_search_decoder = BeamSearchDecoder(
-                dec_cell,
-                start_token_id,
-                end_token_id,
-                self.beam_size,
-                embedding_fn=tar_embeder,
-                output_fn=output_layer)
+            output_layer = lambda x: paddle.static.nn.fc(
+                x, size=tar_vocab_size, num_flatten_dims=len(x.shape) - 1, weight_attr=ParamAttr(name='output_w'))
+            beam_search_decoder = BeamSearchDecoder(dec_cell,
+                                                    start_token_id,
+                                                    end_token_id,
+                                                    self.beam_size,
+                                                    embedding_fn=tar_embeder,
+                                                    output_fn=output_layer)
             enc_output = beam_search_decoder.tile_beam_merge_with_batch(self.token_feature, self.beam_size)
             enc_padding_mask = beam_search_decoder.tile_beam_merge_with_batch(enc_padding_mask, self.beam_size)
-            self.ret_infers, _ = dynamic_decode(
-                beam_search_decoder,
-                inits=dec_initial_states,
-                max_step_num=self.beam_max_step_num,
-                enc_output=enc_output,
-                enc_padding_mask=enc_padding_mask)
+            self.ret_infers, _ = dynamic_decode(beam_search_decoder,
+                                                inits=dec_initial_states,
+                                                max_step_num=self.beam_max_step_num,
+                                                enc_output=enc_output,
+                                                enc_padding_mask=enc_padding_mask)
             return self.ret_infers
 
     def _postprocessing(self, run_states):
@@ -229,18 +227,18 @@ class TextGenerationTask(BaseTask):
         return results
 
     def _add_metrics(self):
-        self.ret_labels = fluid.layers.reshape(x=self.labels[0], shape=[-1, 1])
+        self.ret_labels = paddle.reshape(x=self.labels[0], shape=[-1, 1])
         return [self.ret_labels, self.ret_infers, self.seq_len_used]
 
     def _add_loss(self):
-        loss = fluid.layers.cross_entropy(input=self.outputs[0], label=self.labels[0], soft_label=False)
-        loss = fluid.layers.unsqueeze(loss, axes=[2])
-        max_tar_seq_len = fluid.layers.shape(self.dec_input)[1]
-        tar_sequence_length = fluid.layers.elementwise_sub(self.seq_len_used, fluid.layers.ones_like(self.seq_len_used))
-        tar_mask = fluid.layers.sequence_mask(tar_sequence_length, maxlen=max_tar_seq_len, dtype='float32')
+        loss = nn.functional.cross_entropy(input=self.outputs[0], label=self.labels[0], soft_label=False)
+        loss = paddle.unsqueeze(loss, axis=[2])
+        max_tar_seq_len = paddle.shape(self.dec_input)[1]
+        tar_sequence_length = self.seq_len_used - paddle.ones_like(self.seq_len_used)
+        tar_mask = nn.functional.sequence_mask(tar_sequence_length, maxlen=max_tar_seq_len, dtype='float32')
         loss = loss * tar_mask
-        loss = fluid.layers.reduce_mean(loss, dim=[0])
-        loss = fluid.layers.reduce_sum(loss)
+        loss = paddle.mean(loss, axis=[0])
+        loss = paddle.sum(loss)
         return loss
 
     @property
......
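As background for the decoder rewrite above: in Paddle 2.x a custom cell subclasses paddle.nn.RNNCellBase, implements forward(step_input, states, ...), and is unrolled over the sequence by paddle.nn.RNN, instead of the old fluid RNNCell/rnn pair. A minimal self-contained sketch of that contract in dynamic-graph mode (illustrative only; the toy sizes are made up and the attention logic of AttentionDecoderCell is omitted):

    import paddle
    import paddle.nn as nn

    # One LSTMCell handles a single time step; nn.RNN loops it over the time dimension,
    # which is how the rewritten AttentionDecoderCell is driven inside _build_net.
    cell = nn.LSTMCell(input_size=8, hidden_size=16)
    rnn = nn.RNN(cell, is_reverse=False, time_major=False)

    x = paddle.randn([2, 5, 8])          # [batch, time, input_size]
    prev_h = paddle.zeros([2, 16])
    prev_c = paddle.zeros([2, 16])

    # Step-wise call: returns the step output and the new (hidden, cell) state pair.
    y_t, (h, c) = cell(x[:, 0, :], (prev_h, prev_c))

    # Sequence call: unrolls the cell over all 5 steps.
    outputs, final_states = rnn(x, initial_states=(prev_h, prev_c))
    print(outputs.shape)                 # [2, 5, 16]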