gather_nd函数index tensor梯度回传
Created by: hxk11111
-
版本、环境信息：1) PaddlePaddle 版本：1.6.2；2) Python 版本：3.6.3
-
训练信息：1) 单机单卡
-
问题描述
def get_indexing_tensor(spatial_features, batch_size, k=10):
    """Build ``gather_nd`` indices that select each entry's k nearest neighbours.

    Args:
        spatial_features: rank-3 tensor (the shape is unpacked into exactly
            three values); its second dimension is the number of entries per
            sample. Assumed (batch_size, n_max_entries, dims) -- confirm
            against callers.
        batch_size: number of samples; used to build the batch-index column.
        k: number of neighbours requested from ``nearest_neighbor_matrix``.

    Returns:
        ``(indexing_tensor, distance_matrix)``.

        ``indexing_tensor`` is an int32 tensor of shape
        (batch_size, n_max_entries, k, 2). Its last axis holds
        ``[batch_index, neighbour_index]`` pairs for ``fluid.layers.gather_nd``.

        ``distance_matrix`` is returned unchanged from
        ``nearest_neighbor_matrix``.
    """
    _, n_entries, _ = spatial_features.shape

    # neighbours: (batch_size, n_entries, k) according to the helper's contract.
    neighbours, distances = nearest_neighbor_matrix(spatial_features, k)

    # Batch-index column: [0, ..., batch_size - 1] -> (batch_size, 1, 1, 1).
    batch_ids = fluid.layers.range(0, batch_size, 1, dtype="int32")
    batch_ids = fluid.layers.reshape(batch_ids, [batch_size])
    batch_ids = fluid.layers.unsqueeze(batch_ids, axes=[1, 2, 3])
    # Pure index data; no gradient is wanted here.
    batch_ids.stop_gradient = True
    # Tile to (batch_size, n_entries, k, 1) so it lines up with the neighbours.
    batch_ids = fluid.layers.expand(batch_ids, [1, n_entries, k, 1])

    # Neighbour-index column: (batch_size, n_entries, k, 1).
    # Cast to int32 so it can be concatenated with batch_ids.
    neighbour_ids = fluid.layers.cast(
        fluid.layers.unsqueeze(neighbours, axes=[3]), dtype="int32")

    # (batch_size, n_entries, k, 2); last axis = [batch_index, neighbour_index].
    indexing_tensor = fluid.layers.concat([batch_ids, neighbour_ids], axis=3)
    return indexing_tensor, distances
def edge_conv_layer(input_feature,
                    batch_size,
                    num_neighbors=30,
                    dense_layers=(64, 64, 64),
                    aggregation_function=fluid.layers.reduce_max,
                    edge_activation=None):
    """Aggregate edge features over each node's nearest neighbours.

    For every node i and each of its ``num_neighbors`` neighbours j (as
    selected by ``get_indexing_tensor``), the edge feature
    ``[x_i, x_i - x_j]`` goes through a stack of ReLU ``fc`` layers. The
    neighbour axis (axis 2) is then reduced by ``aggregation_function``.

    Args:
        input_feature: node features. Assumed (batch, n_nodes, feature_size)
            per the accompanying description -- confirm against callers.
        batch_size: forwarded to ``get_indexing_tensor`` for batch indices.
        num_neighbors: number of neighbours gathered per node.
        dense_layers: output width of each successive ``fc`` layer.
        aggregation_function: called as ``aggregation_function(edge, dim=2)``.
        edge_activation: optional callable applied to the edge tensor.

    Returns:
        The result of ``aggregation_function`` on the final edge tensor.

    Note:
        The index tensor built by ``get_indexing_tensor`` is int32, so no
        gradient can flow back into the neighbour selection. Only the gathered
        values of ``input_feature`` take part in back-propagation.

        The distance matrix from ``get_indexing_tensor`` is discarded here.
        NOTE(review): to train the neighbourhood itself, a common approach is
        to weight neighbour features by a differentiable function of those
        distances. Confirm whether ``nearest_neighbor_matrix`` returns
        differentiable distances.
    """
    # Gather indices: (batch_size, n_max_entries, num_neighbors, 2).
    indexing, _ = get_indexing_tensor(input_feature, batch_size, num_neighbors)

    # Neighbour features: (batch_size, n_max_entries, num_neighbors, F).
    neighbours = fluid.layers.gather_nd(input_feature, indexing, name="neighbour_space")
    # NOTE(review): debugging Print op; it runs on every step. Remove once the
    # issue is diagnosed.
    neighbours = fluid.layers.Print(neighbours)

    # Repeat each node's own feature once per neighbour:
    # (batch, n, 1, F) -> (batch, n, num_neighbors, F).
    centre = fluid.layers.expand(
        fluid.layers.unsqueeze(input_feature, axes=[2]), [1, 1, num_neighbors, 1])

    # Edge feature [x_i, x_i - x_j], concatenated along the feature axis.
    edges = fluid.layers.concat([centre, centre - neighbours], axis=-1)

    # num_flatten_dims=3 keeps the three leading axes and maps the last one.
    for width in dense_layers:
        edges = fluid.layers.fc(edges, width, num_flatten_dims=3, act="relu")

    # NOTE(review): the pasted source lost its indentation. Confirm whether
    # edge_activation should run once here or after every fc layer.
    if edge_activation is not None:
        edges = edge_activation(edges)

    return aggregation_function(edges, dim=2)
在 edge_conv_layer 函数中，input_feature 的形状为 (batch, node_nums, feature_size)。构造 neighbour_space 时，需要先计算 input_feature 中各 node 之间的距离，再取距离最近的 top-k 个 node 的 feature 共同组成当前 node 的 feature，因此使用了 gather_nd 函数，其中 indexing 即为最邻近 k 个 node 的索引。目前需要梯度能够经由 indexing 回传，但 gather_nd 似乎不支持对 index 分支回传梯度，请问有什么好的建议？