Commit 825d9740 authored by zhangdengcheng

Fixed the bug that the MeanAggregator arguments could not be passed to the base class, and added an attention head for GAT.
Parent fb4b16a5
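Background for the fix: before this commit, MeanAggregator passed hard-coded defaults to _BaseAggregator.__init__, so caller-supplied values such as activation never reached the base class. A minimal sketch of the before/after behaviour in plain Python (the class names _Base, MeanAggregatorBuggy and MeanAggregatorFixed are illustrative, not the actual MindSpore cells):

class _Base:
    def __init__(self, feature_in_dim, feature_out_dim, activation=None):
        self.activation = activation

class MeanAggregatorBuggy(_Base):
    def __init__(self, feature_in_dim, feature_out_dim, activation=None):
        # bug: a hard-coded default is forwarded, the caller's value is dropped
        super().__init__(feature_in_dim, feature_out_dim, activation=None)

class MeanAggregatorFixed(_Base):
    def __init__(self, feature_in_dim, feature_out_dim, activation=None):
        # fix: forward whatever the caller supplied
        super().__init__(feature_in_dim, feature_out_dim, activation)

assert MeanAggregatorBuggy(32, 64, activation='relu').activation is None
assert MeanAggregatorFixed(32, 64, activation='relu').activation == 'relu'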
@@ -64,7 +64,7 @@ class GNNFeatureTransform(nn.Cell):
[[ 2.5246444 2.2738023 0.5711005 -3.9399147 ]
[ 1.0739875 4.0155234 0.94188046 -5.459526 ]]
"""
-    @cell_attr_register(attrs=['has_bias', 'activation'])
+    @cell_attr_register
def __init__(self,
in_channels,
out_channels,
@@ -125,7 +125,7 @@ class _BaseAggregator(nn.Cell):
same as input x. The values of str refer to the function `initializer`. Default: 'zeros'.
has_bias (bool): Specifies whether the layer uses a bias vector. Default: True.
        dropout_ratio (float): The keep rate of the dropout layer, greater than 0 and less than or equal to 1. Default: None.
        activation (str): Activation function applied to the output of the layer, e.g. 'relu'. Default: None.
Examples:
>>> class MyAggregator(_BaseAggregator):
@@ -203,12 +203,12 @@ class MeanAggregator(_BaseAggregator):
super(MeanAggregator, self).__init__(
feature_in_dim,
feature_out_dim,
-            use_fc=True,
-            weight_init="normal",
-            bias_init="zeros",
-            has_bias=True,
-            dropout_ratio=None,
-            activation=None)
+            use_fc,
+            weight_init,
+            bias_init,
+            has_bias,
+            dropout_ratio,
+            activation)
self.reduce_mean = P.ReduceMean(keep_dims=False)
def construct(self, input_feature):
@@ -220,3 +220,157 @@ class MeanAggregator(_BaseAggregator):
input_feature = self.activation(input_feature)
output_feature = self.reduce_mean(input_feature, 1)
return output_feature
class AttentionHead(nn.Cell):
"""
Attention Head for Graph Attention Networks.
Args:
        in_channel (int): The number of input channels, i.e. the input feature dimension.
        out_channel (int): The number of output channels, i.e. the output feature dimension.
        in_drop_ratio (float): Input feature dropout ratio. Default: 0.0.
        coef_drop_ratio (float): Attention coefficient dropout ratio. Default: 0.0.
        residual (bool): Whether to use a residual connection. Default: False.
        coef_activation (Cell): Activation function applied to the attention coefficients.
            Default: nn.LeakyReLU().
        activation (Cell): Activation function applied to the output. Default: nn.ELU().
Inputs:
        - **input_feature** (Tensor) - Tensor of shape (batch_size, num_nodes, feature_dim).
        - **bias_mat** (Tensor) - Tensor of shape (batch_size, num_nodes, num_nodes).
Examples:
>>> head = AttentionHead(1433,
8,
in_drop_ratio=0.6,
coef_drop_ratio=0.6,
residual=False)
        >>> input_data = Tensor(np.array(np.random.rand(1, 2708, 1433), dtype=np.float32))
        >>> bias_mat = Tensor(np.array(np.random.rand(1, 2708, 2708), dtype=np.float32))
        >>> output = head(input_data, bias_mat)
"""
def __init__(self,
in_channel,
out_channel,
in_drop_ratio=0.0,
coef_drop_ratio=0.0,
residual=False,
coef_activation=nn.LeakyReLU(),
activation=nn.ELU()):
super(AttentionHead, self).__init__()
self.in_channel = check_int_positive(in_channel)
self.out_channel = check_int_positive(out_channel)
self.in_drop_ratio = in_drop_ratio
self.in_drop = nn.Dropout(keep_prob=1 - in_drop_ratio)
self.in_drop_2 = nn.Dropout(keep_prob=1 - in_drop_ratio)
self.feature_transform = GNNFeatureTransform(
in_channels=self.in_channel,
out_channels=self.out_channel,
has_bias=False)
self.f_1_transform = GNNFeatureTransform(
in_channels=self.out_channel,
out_channels=1)
self.f_2_transform = GNNFeatureTransform(
in_channels=self.out_channel,
out_channels=1)
self.softmax = nn.Softmax()
self.coef_drop = nn.Dropout(keep_prob=1 - coef_drop_ratio)
self.batch_matmul = P.BatchMatMul()
self.bias_add = P.BiasAdd()
self.bias = Parameter(initializer('zeros', self.out_channel), name='bias')
self.residual = check_bool(residual)
        if self.residual:
            if in_channel != out_channel:
                self.residual_transform_flag = True
                self.residual_transform = GNNFeatureTransform(
                    in_channels=self.in_channel,
                    out_channels=self.out_channel)
            else:
                # when the channel counts match, the input is added directly
                self.residual_transform_flag = False
                self.residual_transform = None
self.coef_activation = coef_activation
self.activation = activation
def construct(self, input_feature, bias_mat):
input_feature = self.in_drop(input_feature)
feature = self.feature_transform(input_feature)
        # self-attention coefficients, following the original author's implementation
f_1 = self.f_1_transform(feature)
f_2 = self.f_2_transform(feature)
logits = f_1 + P.Transpose()(f_2, (0, 2, 1))
logits = self.coef_activation(logits) + bias_mat
coefs = self.softmax(logits)
coefs = self.coef_drop(coefs)
feature = self.in_drop_2(feature)
ret = self.batch_matmul(coefs, feature)
ret = P.Squeeze(0)(ret)
ret = self.bias_add(ret, self.bias)
ret = P.ExpandDims()(ret, 0)
# residual connection
if self.residual:
if self.residual_transform_flag:
res = self.residual_transform(input_feature)
ret = ret + res
else:
ret = ret + input_feature
# activation
ret = self.activation(ret)
return ret
class AttentionAggregator(nn.Cell):
"""
    Attention aggregator for Graph Attention Networks. It applies `num_heads`
    attention heads in parallel and concatenates their outputs, so it can be
    regarded as one GAT layer.
Args:
        in_channels (int): The number of input channels, i.e. the input feature dimension.
        out_channels (int): The number of output channels per attention head.
        num_heads (int): Number of attention heads for this layer. Default: 1.
        in_drop (float): Input feature dropout ratio. Default: 0.0.
        coef_drop (float): Attention coefficient dropout ratio. Default: 0.0.
        activation (Cell): Activation function applied to the output. Default: nn.ELU().
        residual (bool): Whether to use a residual connection. Default: False.
Inputs:
        - **input_feature** (Tensor) - Tensor of shape (batch_size, num_nodes, feature_dim).
        - **bias_mat** (Tensor) - Tensor of shape (batch_size, num_nodes, num_nodes).
Examples:
>>> input_data = Tensor(np.array(np.random.rand(1, 2708, 1433), dtype=np.float32))
>>> biases = Tensor(np.array(np.random.rand(1, 2708, 2708), dtype=np.float32))
>>> net = AttentionAggregator(1433,
8,
8)
>>> net(input_data, biases)
"""
def __init__(self,
in_channels,
out_channels,
num_heads=1,
in_drop=0.0,
coef_drop=0.0,
activation=nn.ELU(),
residual=False):
super(AttentionAggregator, self).__init__()
self.num_heads = num_heads
self.attns = []
for _ in range(num_heads):
self.attns.append(AttentionHead(in_channels,
out_channels,
in_drop_ratio=in_drop,
coef_drop_ratio=coef_drop,
activation=activation,
residual=residual))
self.attns = nn.layer.CellList(self.attns)
def construct(self, input_data, bias_mat):
res = ()
for i in range(self.num_heads):
res += (self.attns[i](input_data, bias_mat),)
return P.Concat(-1)(res)
@@ -20,7 +20,7 @@ import mindspore.context as context
from mindspore import Tensor
from mindspore.common.api import _executor
import mindspore.ops.composite as C
-from aggregator import MeanAggregator
+from aggregator import MeanAggregator, AttentionHead, AttentionAggregator
context.set_context(mode=context.GRAPH_MODE)
@@ -51,3 +51,22 @@ def test_MeanAggregator_grad():
sens = Tensor(np.ones([32, 64]).astype(np.float32))
grad_op = MeanAggregatorGrad(aggregator)
_executor.compile(grad_op, input_data, sens)
def test_AttentionHead():
"""Compile AttentionHead forward graph"""
head = AttentionHead(1433,
8,
in_drop_ratio=0.6,
coef_drop_ratio=0.6,
residual=False)
input_data = Tensor(np.array(np.random.rand(1, 2708, 1433), dtype=np.float32))
biases = Tensor(np.array(np.random.rand(1, 2708, 2708), dtype=np.float32))
_executor.compile(head, input_data, biases)
def test_AttentionAggregator():
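    """Compile AttentionAggregator forward graph"""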
input_data = Tensor(np.array(np.random.rand(1, 2708, 1433), dtype=np.float32))
biases = Tensor(np.array(np.random.rand(1, 2708, 2708), dtype=np.float32))
net = AttentionAggregator(1433, 8, 8)
_executor.compile(net, input_data, biases)
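For reference, a hedged usage sketch of the new aggregator (not part of this commit; it assumes a working MindSpore install and that aggregator.py is importable, as in the tests above). Because the 8 head outputs of 8 channels each are concatenated on the last axis, the result is expected to have shape (1, 2708, 64):

import numpy as np
import mindspore.context as context
from mindspore import Tensor
from aggregator import AttentionAggregator

context.set_context(mode=context.GRAPH_MODE)

# random node features and bias (adjacency mask) matrix, Cora-sized
features = Tensor(np.random.rand(1, 2708, 1433).astype(np.float32))
biases = Tensor(np.random.rand(1, 2708, 2708).astype(np.float32))

net = AttentionAggregator(1433, 8, 8)
output = net(features, biases)  # expected shape: (1, 2708, 64)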