diff --git a/demo/conv_mnli/data b/demo/conv_mnli/data new file mode 120000 index 0000000000000000000000000000000000000000..030c2afe99651a7423754d2cd18d29ad2938aed8 --- /dev/null +++ b/demo/conv_mnli/data @@ -0,0 +1 @@ +../bert/data \ No newline at end of file diff --git a/demo/conv_mnli/model/__init__.py b/demo/conv_mnli/model/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/demo/conv_mnli/model/bert.py b/demo/conv_mnli/model/bert.py new file mode 100644 index 0000000000000000000000000000000000000000..6ce908c2fe3c2dce0916ab2da65a0260cf003be8 --- /dev/null +++ b/demo/conv_mnli/model/bert.py @@ -0,0 +1,116 @@ +# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"dygraph transformer layers" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import six +import json +import numpy as np +import paddle +import paddle.fluid as fluid +from paddle.fluid.dygraph import Embedding, LayerNorm, Linear, to_variable, Layer, guard +from .transformer_encoder import ConvBNEncoderLayer +from .mobilenet_v1 import MobileNetV1 + + +class BertModelLayer(Layer): + def __init__(self, + emb_size=128, + hidden_size=768, + voc_size=30522, + max_position_seq_len=512, + sent_types=2, + return_pooled_out=True, + initializer_range=1.0, + conv_type="mobilenet", + use_fp16=False): + super(BertModelLayer, self).__init__() + + self._emb_size = emb_size + self._hidden_size = hidden_size + self._voc_size = voc_size + self._max_position_seq_len = max_position_seq_len + self._sent_types = sent_types + self.return_pooled_out = return_pooled_out + + self._word_emb_name = "s_word_embedding" + self._pos_emb_name = "s_pos_embedding" + self._sent_emb_name = "s_sent_embedding" + self._dtype = "float16" if use_fp16 else "float32" + + self._conv_type = conv_type + self._param_initializer = fluid.initializer.TruncatedNormal( + scale=initializer_range) + + self._src_emb = Embedding( + size=[self._voc_size, self._emb_size], + param_attr=fluid.ParamAttr( + name=self._word_emb_name, initializer=self._param_initializer), + dtype=self._dtype) + + self._pos_emb = Embedding( + size=[self._max_position_seq_len, self._emb_size], + param_attr=fluid.ParamAttr( + name=self._pos_emb_name, initializer=self._param_initializer), + dtype=self._dtype) + + self._sent_emb = Embedding( + size=[self._sent_types, self._emb_size], + param_attr=fluid.ParamAttr( + name=self._sent_emb_name, initializer=self._param_initializer), + dtype=self._dtype) + + self._emb_fac = Linear( + input_dim=self._emb_size, + output_dim=self._hidden_size, + param_attr=fluid.ParamAttr(name="s_emb_factorization")) + + # self.pooled_fc = Linear( + # input_dim=self._hidden_size, + # output_dim=self._hidden_size, + # param_attr=fluid.ParamAttr( + # name="s_pooled_fc.w_0", initializer=self._param_initializer), + # bias_attr="s_pooled_fc.b_0", + # act="tanh") + if self._conv_type == "conv_bn": + self._encoder = ConvBNEncoderLayer( + hidden_size=self._hidden_size, name="encoder") + elif self._conv_type == "mobilenet": + self._encoder = MobileNetV1(1, 3) + + def forward(self, src_ids, position_ids, sentence_ids): + src_emb = self._src_emb(src_ids) + pos_emb = self._pos_emb(position_ids) + sent_emb = self._sent_emb(sentence_ids) + + emb_out = src_emb + pos_emb + emb_out = emb_out + sent_emb + + emb_out = self._emb_fac(emb_out) + emb_out = fluid.layers.reshape( + emb_out, shape=[-1, 1, emb_out.shape[1], emb_out.shape[2]]) + enc_output = self._encoder(emb_out) + + # if not self.return_pooled_out: + # return enc_output + # next_sent_feat = fluid.layers.slice( + # input=enc_output, axes=[2], starts=[0], ends=[1]) + # next_sent_feat = fluid.layers.reshape( + # next_sent_feat, shape=[-1, self._hidden_size]) + # next_sent_feat = self.pooled_fc(next_sent_feat) + + return enc_output diff --git a/demo/conv_mnli/model/cls.py b/demo/conv_mnli/model/cls.py new file mode 100644 index 0000000000000000000000000000000000000000..730335d19cf58e548c35898f0214a5afe32eb64a --- /dev/null +++ b/demo/conv_mnli/model/cls.py @@ -0,0 +1,87 @@ +# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"dygraph transformer layers" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import six +import json +import numpy as np + +import paddle +import paddle.fluid as fluid +from paddle.fluid.dygraph import Linear, Layer +from paddle.fluid.dygraph import to_variable + +from .bert import BertModelLayer + + +class ClsModelLayer(Layer): + """ + classify model + """ + + def __init__(self, num_labels, return_pooled_out=True): + super(ClsModelLayer, self).__init__() + self._hiden_size = 768 + + self.bert_layer = BertModelLayer( + emb_size=128, + hidden_size=self._hiden_size, + voc_size=30522, + max_position_seq_len=512, + sent_types=2, + return_pooled_out=True, + initializer_range=1.0, + conv_type="mobilenet", + use_fp16=False) + +# self.cls_fc = Linear( +# input_dim=self._hiden_size, +# output_dim=num_labels, +# param_attr=fluid.ParamAttr( +# name="cls_out_w", +# initializer=fluid.initializer.TruncatedNormal(scale=0.02)), +# bias_attr=fluid.ParamAttr( +# name="cls_out_b", +# initializer=fluid.initializer.Constant(0.))) + + def forward(self, data_ids): + """ + forward + """ + src_ids = to_variable(data_ids[0]) + position_ids = to_variable(data_ids[1]) + sentence_ids = to_variable(data_ids[2]) + input_mask = data_ids[3] + labels = to_variable(data_ids[4]) + + logits = self.bert_layer(src_ids, position_ids, sentence_ids) + # cls_feat = fluid.layers.dropout( + # x=next_sent_feat, + # dropout_prob=0.1, + # dropout_implementation="upscale_in_train") + # logit = self.cls_fc(cls_feat) + + ce_loss, probs = fluid.layers.softmax_with_cross_entropy( + logits=logits, label=labels, return_softmax=True) + loss = fluid.layers.mean(x=ce_loss) + + num_seqs = fluid.layers.create_tensor(dtype='int64') + accuracy = fluid.layers.accuracy( + input=probs, label=labels, total=num_seqs) + + return loss, accuracy, num_seqs diff --git a/demo/conv_mnli/model/mobilenet_v1.py b/demo/conv_mnli/model/mobilenet_v1.py new file mode 100644 index 0000000000000000000000000000000000000000..38df7e39748e4ff185fee8ddf8d02c590dabcd2c --- /dev/null +++ b/demo/conv_mnli/model/mobilenet_v1.py @@ -0,0 +1,236 @@ +# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +#order: standard library, third party, local library +import os +import time +import sys +import math +import numpy as np +import argparse +import paddle +import paddle.fluid as fluid +from paddle.fluid.initializer import MSRA +from paddle.fluid.param_attr import ParamAttr +from paddle.fluid.layer_helper import LayerHelper +from paddle.fluid.dygraph.nn import Conv2D, Pool2D, BatchNorm, Linear +from paddle.fluid.dygraph.base import to_variable +from paddle.fluid import framework + + +class ConvBNLayer(fluid.dygraph.Layer): + def __init__(self, + num_channels, + filter_size, + num_filters, + stride, + padding, + channels=None, + num_groups=1, + act='relu', + use_cudnn=True, + name=None): + super(ConvBNLayer, self).__init__() + + self._conv = Conv2D( + num_channels=num_channels, + num_filters=num_filters, + filter_size=filter_size, + stride=stride, + padding=padding, + groups=num_groups, + act=None, + use_cudnn=use_cudnn, + param_attr=ParamAttr( + initializer=MSRA(), name=self.full_name() + "_weights"), + bias_attr=False) + + self._batch_norm = BatchNorm( + num_filters, + act=act, + param_attr=ParamAttr(name=self.full_name() + "_bn" + "_scale"), + bias_attr=ParamAttr(name=self.full_name() + "_bn" + "_offset"), + moving_mean_name=self.full_name() + "_bn" + '_mean', + moving_variance_name=self.full_name() + "_bn" + '_variance') + + def forward(self, inputs): + y = self._conv(inputs) + y = self._batch_norm(y) + return y + + +class DepthwiseSeparable(fluid.dygraph.Layer): + def __init__(self, + num_channels, + num_filters1, + num_filters2, + num_groups, + stride, + scale, + name=None): + super(DepthwiseSeparable, self).__init__() + + self._depthwise_conv = ConvBNLayer( + num_channels=num_channels, + num_filters=int(num_filters1 * scale), + filter_size=3, + stride=stride, + padding=1, + num_groups=int(num_groups * scale), + use_cudnn=False) + + self._pointwise_conv = ConvBNLayer( + num_channels=int(num_filters1 * scale), + filter_size=1, + num_filters=int(num_filters2 * scale), + stride=1, + padding=0) + + def forward(self, inputs): + y = self._depthwise_conv(inputs) + y = self._pointwise_conv(y) + return y + + +class MobileNetV1(fluid.dygraph.Layer): + def __init__(self, scale=1.0, class_dim=1000): + super(MobileNetV1, self).__init__() + self.scale = scale + self.dwsl = [] + + self.conv1 = ConvBNLayer( + num_channels=1, + filter_size=3, + channels=3, + num_filters=int(32 * scale), + stride=2, + padding=1) + + dws21 = self.add_sublayer( + sublayer=DepthwiseSeparable( + num_channels=int(32 * scale), + num_filters1=32, + num_filters2=64, + num_groups=32, + stride=1, + scale=scale), + name="conv2_1") + self.dwsl.append(dws21) + + dws22 = self.add_sublayer( + sublayer=DepthwiseSeparable( + num_channels=int(64 * scale), + num_filters1=64, + num_filters2=128, + num_groups=64, + stride=2, + scale=scale), + name="conv2_2") + self.dwsl.append(dws22) + + dws31 = self.add_sublayer( + sublayer=DepthwiseSeparable( + num_channels=int(128 * scale), + num_filters1=128, + num_filters2=128, + num_groups=128, + stride=1, + scale=scale), + name="conv3_1") + self.dwsl.append(dws31) + + dws32 = self.add_sublayer( + sublayer=DepthwiseSeparable( + num_channels=int(128 * scale), + num_filters1=128, + num_filters2=256, + num_groups=128, + stride=2, + scale=scale), + name="conv3_2") + self.dwsl.append(dws32) + + dws41 = self.add_sublayer( + sublayer=DepthwiseSeparable( + num_channels=int(256 * scale), + num_filters1=256, + num_filters2=256, + num_groups=256, + stride=1, + scale=scale), + name="conv4_1") + self.dwsl.append(dws41) + + dws42 = self.add_sublayer( + sublayer=DepthwiseSeparable( + num_channels=int(256 * scale), + num_filters1=256, + num_filters2=512, + num_groups=256, + stride=2, + scale=scale), + name="conv4_2") + self.dwsl.append(dws42) + + for i in range(5): + tmp = self.add_sublayer( + sublayer=DepthwiseSeparable( + num_channels=int(512 * scale), + num_filters1=512, + num_filters2=512, + num_groups=512, + stride=1, + scale=scale), + name="conv5_" + str(i + 1)) + self.dwsl.append(tmp) + + dws56 = self.add_sublayer( + sublayer=DepthwiseSeparable( + num_channels=int(512 * scale), + num_filters1=512, + num_filters2=1024, + num_groups=512, + stride=2, + scale=scale), + name="conv5_6") + self.dwsl.append(dws56) + + dws6 = self.add_sublayer( + sublayer=DepthwiseSeparable( + num_channels=int(1024 * scale), + num_filters1=1024, + num_filters2=1024, + num_groups=1024, + stride=1, + scale=scale), + name="conv6") + self.dwsl.append(dws6) + + self.pool2d_avg = Pool2D(pool_type='avg', global_pooling=True) + + self.out = Linear( + int(1024 * scale), + class_dim, + param_attr=ParamAttr( + initializer=MSRA(), name=self.full_name() + "fc7_weights"), + bias_attr=ParamAttr(name="fc7_offset")) + + def forward(self, inputs): + y = self.conv1(inputs) + for dws in self.dwsl: + y = dws(y) + y = self.pool2d_avg(y) + y = fluid.layers.reshape(y, shape=[-1, 1024]) + y = self.out(y) + return y diff --git a/demo/conv_mnli/model/transformer_encoder.py b/demo/conv_mnli/model/transformer_encoder.py new file mode 100644 index 0000000000000000000000000000000000000000..d1e6de5fbae2fc18b898c5044d5008fe6c528775 --- /dev/null +++ b/demo/conv_mnli/model/transformer_encoder.py @@ -0,0 +1,186 @@ +# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"dygraph transformer layers" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np +from collections import Iterable + +import paddle +import paddle.fluid as fluid +from paddle.fluid.dygraph import Embedding, LayerNorm, Linear, Layer, Conv2D, BatchNorm, Pool2D, to_variable +from paddle.fluid.initializer import NormalInitializer + + +class ConvBNRelu(fluid.dygraph.Layer): + def __init__(self, + in_c=768, + out_c=768, + filter_size=[3, 1], + dilation=1, + is_test=False, + use_cudnn=True, + name=None): + super(ConvBNRelu, self).__init__() + self._name = name + conv_std = (2.0 / + (filter_size[0] * filter_size[1] * out_c * in_c))**0.5 + conv_param = fluid.ParamAttr( + name=name if name is None else (name + "_conv.weights"), + initializer=fluid.initializer.Normal(0.0, conv_std)) + + self.conv = Conv2D( + in_c, + out_c, + filter_size, + dilation=[dilation, 1], + padding=[(filter_size[0] - 1) * dilation // 2, 0], + param_attr=conv_param, + act=None, + bias_attr=False, + use_cudnn=use_cudnn) + self.bn = BatchNorm(out_c, act="relu", is_test=False) + + def forward(self, inputs): + conv = self.conv(inputs) + bn = self.bn(conv) + return conv + + +class GateConv(fluid.dygraph.Layer): + def __init__(self, + in_c=768, + out_c=768, + filter_size=[3, 1], + dilation=1, + is_test=False, + use_cudnn=True, + name=None): + super(GateConv, self).__init__() + conv_std = (2.0 / + (filter_size[0] * filter_size[1] * out_c * in_c))**0.5 + conv_param = fluid.ParamAttr( + name=name if name is None else (name + "_conv.weights"), + initializer=fluid.initializer.Normal(0.0, conv_std)) + + gate_param = fluid.ParamAttr( + name=name if name is None else (name + "_conv_gate.weights"), + initializer=fluid.initializer.Normal(0.0, conv_std)) + + self.conv = Conv2D( + in_c, + out_c, + filter_size, + dilation=[dilation, 1], + padding=[(filter_size[0] - 1) * dilation // 2, 0], + param_attr=conv_param, + act=None, + use_cudnn=use_cudnn) + + self.gate = Conv2D( + in_c, + out_c, + filter_size, + dilation=[dilation, 1], + padding=[(filter_size[0] - 1) * dilation // 2, 0], + param_attr=gate_param, + act="sigmoid", + use_cudnn=use_cudnn) + + def forward(self, inputs): + conv = self.conv(inputs) + gate = self.gate(inputs) + return conv * gate + + +class ConvBNBlock(Layer): + def __init__(self, + in_c, + out_c, + filter_size, + dillation=1, + block_size=4, + name=None): + super(ConvBNBlock, self).__init__() + self._in_c = in_c + self._out_c = out_c + self._filter_size = filter_size + self._dillation = dillation + self._block_sie = block_size + self.convs = [] + for n in range(block_size): + name = None if name is None else name + "_" + str(n) + conv = ConvBNRelu( + in_c=self._in_c, + out_c=self._out_c, + filter_size=self._filter_size, + dilation=self._dillation, + is_test=False, + use_cudnn=True, + name=name) + self.convs.append(conv) + + def forward(self, input): + tmp = input + for conv in self.convs: + tmp = conv(input) + return tmp + + +class ConvBNEncoderLayer(Layer): + def __init__(self, + filters=[3, 5, 7], + dillations=[1, 1, 1], + hidden_size=768, + name="encoder"): + super(ConvBNEncoderLayer, self).__init__() + cells = [] + self._hidden_size = hidden_size + + self.conv0 = ConvBNRelu( + in_c=1, + out_c=self._hidden_size, + filter_size=[3, self._hidden_size], + dilation=1, + is_test=False, + use_cudnn=True, + name="sten") + + self.blocks = [] + n = 0 + for filter_n, dillation_n in zip(filters, dillations): + name = None if name is None else name + "_block" + str(n) + block = ConvBNBlock( + self._hidden_size, + self._hidden_size, + filter_size=[filter_n, 1], + dillation=dillation_n, + block_size=4, + name=name) + self.blocks.append(block) + n += 1 + + def forward(self, enc_input): + tmp = fluid.layers.reshape( + enc_input, [-1, 1, enc_input.shape[1], + self._hidden_size]) #(bs, 1, seq_len, hidden_size) + + tmp = self.conv0(tmp) # (bs, hidden_size, seq_len, 1) + for block in self.blocks: + tmp = block(tmp) + + return tmp diff --git a/demo/conv_mnli/run.sh b/demo/conv_mnli/run.sh new file mode 100644 index 0000000000000000000000000000000000000000..5d515ee2a473af3c237bdda99775cc2af4a6265c --- /dev/null +++ b/demo/conv_mnli/run.sh @@ -0,0 +1,2 @@ +export CUDA_VISIBLE_DEVICES=0 +python train.py diff --git a/demo/conv_mnli/train.py b/demo/conv_mnli/train.py new file mode 100644 index 0000000000000000000000000000000000000000..d742cb90992a9b0219f9468776ab67c099100539 --- /dev/null +++ b/demo/conv_mnli/train.py @@ -0,0 +1,85 @@ +import time +import paddle.fluid as fluid +from paddleslim.teachers.bert.reader.cls import * +from model.cls import ClsModelLayer + + +def main(): + place = fluid.CUDAPlace(0) + + BERT_BASE_PATH = "./data/pretrained_models/uncased_L-12_H-768_A-12/" + bert_config_path = BERT_BASE_PATH + "/bert_config.json" + vocab_path = BERT_BASE_PATH + "/vocab.txt" + data_dir = "./data/glue_data/MNLI/" + max_seq_len = 128 + do_lower_case = True + batch_size = 128 + epoch = 600 + + processor = MnliProcessor( + data_dir=data_dir, + vocab_path=vocab_path, + max_seq_len=max_seq_len, + do_lower_case=do_lower_case, + in_tokens=False) + + train_reader = processor.data_generator( + batch_size=batch_size, + phase='train', + epoch=1, + dev_count=1, + shuffle=True) + + val_reader = processor.data_generator( + batch_size=batch_size, + phase='dev', + epoch=1, + dev_count=1, + shuffle=False) + + with fluid.dygraph.guard(place): + model = ClsModelLayer(3) + optimizer = fluid.optimizer.MomentumOptimizer( + 0.001, + 0.9, + regularization=fluid.regularizer.L2DecayRegularizer(3e-4), + parameter_list=model.parameters()) + + for i in range(epoch): + model.train() + losses = [] + accs = [] + start = time.time() + for step_id, data in enumerate(train_reader()): + loss, acc, num = model(data) + loss.backward() + optimizer.minimize(loss) + model.clear_gradients() + losses.append(loss.numpy()) + accs.append(acc.numpy()) + if step_id % 50 == 0: + time_cost = time.time() - start + start = time.time() + speed = time_cost / 50.0 + print( + "Train iter-[{}]-[{}] - loss: {:.4f}; acc: {:.4f}; speed: {:.3f}s/step; time: {}". + format(i, step_id, + np.mean(losses), + np.mean(accs), speed, + time.asctime(time.localtime()))) + losses = [] + accs = [] + + model.eval() + losses = [] + accs = [] + for step_id, data in enumerate(val_reader()): + loss, acc, num = model(data) + losses.append(loss.numpy()) + accs.append(acc.numpy()) + print("Eval epoch [{}]- loss: {:.4f}; acc: {:.4f}".format( + i, np.mean(losses), np.mean(accs))) + + +if __name__ == '__main__': + main() diff --git a/paddleslim/teachers/bert/reader/cls.py b/paddleslim/teachers/bert/reader/cls.py index e05f02a3a99dc9aeaae88c8f7f6d6986c9b0121c..835257bd1dc5fc9e5a331d876f34400bc5f9b1e8 100644 --- a/paddleslim/teachers/bert/reader/cls.py +++ b/paddleslim/teachers/bert/reader/cls.py @@ -151,7 +151,7 @@ class DataProcessor(object): elif phase == 'search_valid': examples = self.get_train_examples(self.data_dir) self.num_examples['search_valid'] = len(examples) / 2 - examples = examples[self.num_examples['search_train']:] + examples = examples[self.num_examples['search_valid']:] else: raise ValueError( "Unknown phase, which should be in ['train', 'dev', 'test'].")