提交 649ffd9e 编写于 作者: W wanghaoshuang

Add conv bert search space based DARTS.

上级 b4d29614
......@@ -8,7 +8,7 @@ with fluid.dygraph.guard(place):
bert = BERTClassifier(3)
bert.fit("./data/glue_data/MNLI/",
5,
batch_size=16,
batch_size=32,
use_data_parallel=True,
learning_rate=0.00005,
save_steps=1000)
......@@ -21,7 +21,72 @@ import numpy as np
import paddle
import paddle.fluid as fluid
from paddle.fluid.dygraph import Embedding, LayerNorm, Linear, Layer, Conv2D, BatchNorm
from paddle.fluid.dygraph import Embedding, LayerNorm, Linear, Layer, Conv2D, BatchNorm, Pool2D, to_variable
from paddle.fluid.initializer import NormalInitializer
PRIMITIVES = [
'std_conv_3', 'std_conv_5', 'std_conv_7', 'dil_conv_3', 'dil_conv_5',
'dil_conv_7', 'avg_pool_3', 'max_pool_3', 'none', 'skip_connect'
]
OPS = {
'std_conv_3': lambda : ConvBN(1, 1, filter_size=3, dilation=1),
'std_conv_5': lambda : ConvBN(1, 1, filter_size=5, dilation=1),
'std_conv_7': lambda : ConvBN(1, 1, filter_size=7, dilation=1),
'dil_conv_3': lambda : ConvBN(1, 1, filter_size=3, dilation=2),
'dil_conv_5': lambda : ConvBN(1, 1, filter_size=5, dilation=2),
'dil_conv_7': lambda : ConvBN(1, 1, filter_size=7, dilation=2),
'avg_pool_3': lambda : Pool2D(pool_size=(3, 1), pool_type='avg'),
'max_pool_3': lambda : Pool2D(pool_size=(3, 1), pool_type='max'),
'none': lambda : Zero(),
'skip_connect': lambda : Identity(),
}
class MixedOp(fluid.dygraph.Layer):
def __init__(self):
super(MixedOp, self).__init__()
ops = [OPS[primitive]() for primitive in PRIMITIVES]
self._ops = fluid.dygraph.LayerList(ops)
def forward(self, x, weights):
for i in range(len(self._ops)):
if weights[i] != 0:
return self._ops[i](x) * weights[i]
class Zero(fluid.dygraph.Layer):
def __init__(self):
super(Zero, self).__init__()
def forward(self, x):
x = fluid.layers.zeros_like(x)
return x
class Identity(fluid.dygraph.Layer):
def __init__(self):
super(Identity, self).__init__()
def forward(self, x):
return x
def gumbel_softmax(logits, temperature=0.1, hard=True, eps=1e-20):
U = np.random.uniform(0, 1, logits.shape)
logits = logits - to_variable(
np.log(-np.log(U + eps) + eps).astype("float32"))
logits = logits / temperature
logits = fluid.layers.softmax(logits)
if hard:
maxes = fluid.layers.reduce_max(logits, dim=1, keep_dim=True)
hard = fluid.layers.cast((logits == maxes), logits.dtype)
tmp = hard - logits
tmp.stop_gradient = True
out = tmp + logits
return out
class ConvBN(fluid.dygraph.Layer):
......@@ -55,30 +120,31 @@ class ConvBN(fluid.dygraph.Layer):
return bn
class EncoderSubLayer(Layer):
"""
EncoderSubLayer
"""
class Cell(fluid.dygraph.Layer):
def __init__(self, steps):
super(Cell, self).__init__()
self._steps = steps
def __init__(self, name=""):
ops = []
for i in range(self._steps):
for j in range(2 + i):
op = MixedOp()
ops.append(op)
self._ops = fluid.dygraph.LayerList(ops)
super(EncoderSubLayer, self).__init__()
self.name = name
self.conv0 = ConvBN(1, 1, filter_size=5)
self.conv1 = ConvBN(1, 1, filter_size=5)
self.conv2 = ConvBN(1, 1, filter_size=5)
def forward(self, s0, s1, weights, weights2=None):
def forward(self, enc_input):
"""
forward
:param enc_input:
:param attn_bias:
:return:
"""
tmp = self.conv0(enc_input)
tmp = self.conv1(tmp)
tmp = self.conv2(tmp)
return tmp
states = [s0, s1]
offset = 0
for i in range(self._steps):
s = fluid.layers.sums([
self._ops[offset + j](h, weights[offset + j])
for j, h in enumerate(states)
])
offset += len(states)
states.append(s)
out = fluid.layers.sum(states[-self._steps:])
return out
class EncoderLayer(Layer):
......@@ -89,15 +155,23 @@ class EncoderLayer(Layer):
def __init__(self, n_layer, d_model=128, name=""):
super(EncoderLayer, self).__init__()
self._encoder_sublayers = list()
cells = []
self._n_layer = n_layer
self._d_model = d_model
self._steps = 3
cells = []
for i in range(n_layer):
self._encoder_sublayers.append(
self.add_sublayer(
'esl_%d' % i,
EncoderSubLayer(name=name + '_layer_' + str(i))))
cells.append(Cell(steps=self._steps))
self._cells = fluid.dygraph.LayerList(cells)
k = sum(1 for i in range(self._steps) for n in range(2 + i))
num_ops = len(PRIMITIVES)
self.alphas = fluid.layers.create_parameter(
shape=[k, num_ops],
dtype="float32",
default_initializer=NormalInitializer(
loc=0.0, scale=1e-3))
def forward(self, enc_input):
"""
......@@ -108,10 +182,14 @@ class EncoderLayer(Layer):
"""
tmp = fluid.layers.reshape(enc_input,
[-1, 1, enc_input.shape[1], self._d_model])
alphas = gumbel_softmax(self.alphas)
outputs = []
for i in range(self._n_layer):
tmp = self._encoder_sublayers[i](tmp)
s0 = s1 = tmp
for i, cell in enumerate(self._cells):
s0, s1 = s1, cell(s0, s1, alphas)
enc_output = fluid.layers.reshape(
tmp, [-1, enc_input.shape[1], self._d_model])
s1, [-1, enc_input.shape[1], self._d_model])
outputs.append(enc_output)
return outputs
......@@ -204,7 +204,7 @@ class BERTClassifier(Layer):
data_ids)
optimizer.optimization(
losses[-1],
total_loss,
use_data_parallel=use_data_parallel,
model=self.cls_model)
self.cls_model.clear_gradients()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册