提交 86835b86 编写于 作者: L liujiaxiang

fix the comment and delete useless code

上级 a9d30244
#!/usr/bin/env python
# -*- coding: utf-8 -*-
########################################################################
# #
# Copyright (c) 2019 Baidu.com, Inc. All Rights Reserved #
# #
########################################################################
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle.fluid as fluid
import paddle.fluid.layers as L
import paddle.fluid.layers as layers
from pgl.utils import paddle_helper
import pgl
def masked_select(input, mask):
"""masked_select
......@@ -28,7 +37,6 @@ def masked_select(input, mask):
return L.gather(input, index, overwrite=False)
class BigBirdWrapper(pgl.graph_wrapper.BaseGraphWrapper):
"""Implement of Edge Drop """
def __init__(self, input_mask):
......@@ -51,6 +59,7 @@ class BigBirdWrapper(pgl.graph_wrapper.BaseGraphWrapper):
self._edge_uniq_dst_count = L.concat([uniq_count, last])
self._edge_uniq_dst_count.stop_gradient=True
def select_edges(src, dst, input_mask, num_nodes, max_seqlen):
src = fluid.layers.elementwise_max(src, num_nodes * 0)
dst = fluid.layers.elementwise_max(dst, num_nodes * 0)
......@@ -74,6 +83,7 @@ def select_edges(src, dst, input_mask, num_nodes, max_seqlen):
src = masked_select(src, mask)
return src, dst
def uniq_edges(src, dst, num_nodes):
sorted_dst = L.cast(dst, dtype="int64")
sorted_src = L.cast(src, dtype="int64")
......@@ -88,38 +98,6 @@ def uniq_edges(src, dst, num_nodes):
return sorted_src, sorted_dst
#def build_edges(num_nodes, input_mask, max_seqlen):
# edges = L.range(start=0, end=num_nodes, step=1, dtype="int32")
# all_edges = []
# # Window
# filter_func = lambda x, y: select_edges(x, y, input_mask, num_nodes, max_seqlen)
#
# all_edges.append(filter_func(edges - 1, edges)) # win-1
# all_edges.append(filter_func(edges + 1, edges)) # win-2
# all_edges.append(filter_func(edges, edges)) #self-loop
#
# # Global Assume [CLS] is the first token.
# cls_position = edges / max_seqlen * max_seqlen
# all_edges.append(filter_func(cls_position, edges))
# all_edges.append(filter_func(edges, cls_position))
#
# # Random
# for i in range(2):
# rand_edge = L.floor(L.uniform_random(min=0, max=1, shape=[num_nodes]) * L.cast(max_seqlen, dtype="float32"))
# rand_edge = L.cast(rand_edge, dtype="int32") + cls_position
# all_edges.append(filter_func(rand_edge, edges))
#
# if len(all_edges) > 1:
# src = L.concat([ s for s, d in all_edges], 0)
# dst = L.concat([ d for s, d in all_edges], 0)
# else:
# src = all_edges[0][0]
# dst = all_edges[0][1]
#
# # sort edges
# sorted_src, sorted_dst = uniq_edges(src, dst, num_nodes)
# return sorted_src, sorted_dst
def build_edges(num_nodes, input_mask, max_seqlen):
edges = L.range(start=0, end=num_nodes, step=1, dtype="int32")
all_edges = []
......@@ -127,35 +105,19 @@ def build_edges(num_nodes, input_mask, max_seqlen):
filter_func = lambda x, y: select_edges(x, y, input_mask, num_nodes, max_seqlen)
all_edges.append(filter_func(edges - 1, edges)) # win-1
#all_edges.append(filter_func(edges - 2, edges)) # win-1
#all_edges.append(filter_func(edges - 3, edges)) # win-1
all_edges.append(filter_func(edges + 1, edges)) # win-2
#all_edges.append(filter_func(edges + 2, edges)) # win-2
#all_edges.append(filter_func(edges + 3, edges)) # win-2
all_edges.append(filter_func(edges, edges)) #self-loop
# Global Assume [CLS] is the first token.
# vertical cls-window attention
cls_position = edges / max_seqlen * max_seqlen
#all_edges.append(filter_func(cls_position + 1, edges))
all_edges.append(filter_func(cls_position, edges))
# vertical sliding attention
#all_edges.append(filter_func(cls_position + 6, edges))
#all_edges.append(filter_func(cls_position + max_seqlen - 6, edges))
# horizontal cls attention
all_edges.append(filter_func(edges, cls_position))
#all_edges.append(filter_func(edges, cls_position))
# horizontal sliding attention
#all_edges.append(filter_func(edges, cls_position + 6)
#all_edges.append(filter_func(edges, cls_position + max_seq_len - 6)
# Random
#for i in range(2):
for i in range(2):
rand_edge = L.floor(L.uniform_random(min=0, max=1, shape=[num_nodes]) * L.cast(max_seqlen, dtype="float32"))
rand_edge = L.cast(rand_edge, dtype="int32") + cls_position
......@@ -173,8 +135,6 @@ def build_edges(num_nodes, input_mask, max_seqlen):
return sorted_src, sorted_dst
def sparse_scaled_dot_product_attention(q, k, v, input_mask, dropout_rate, n_head, d_key, d_value):
def send_q_k_spmm(src_feat, dst_feat, edge_feat):
# q [ num_edges, n_head * dim]
......@@ -221,4 +181,3 @@ def sparse_scaled_dot_product_attention(q, k, v, input_mask, dropout_rate, n_hea
return out, out
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册