Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
DeepSpeech
提交
7e8f9f53
D
DeepSpeech
项目概览
PaddlePaddle
/
DeepSpeech
大约 1 年 前同步成功
通知
207
Star
8425
Fork
1598
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
245
列表
看板
标记
里程碑
合并请求
3
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
D
DeepSpeech
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
245
Issue
245
列表
看板
标记
里程碑
合并请求
3
合并请求
3
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
体验新版 GitCode,发现更多精彩内容 >>
提交
7e8f9f53
编写于
4月 20, 2022
作者:
Q
qingen
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
[vec][layer] add GRL to domain adaptation, test=doc fix #1724
上级
9382ad8a
变更
1
隐藏空白更改
内联
并排
Showing
1 changed file
with
76 addition
and
0 deletion
+76
-0
paddlespeech/vector/modules/layer.py
paddlespeech/vector/modules/layer.py
+76
-0
未找到文件。
paddlespeech/vector/modules/layer.py
0 → 100644
浏览文件 @
7e8f9f53
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
paddle
import
paddle.nn
as
nn
from
paddle.autograd
import
PyLayer
class
GradientReversalFunction
(
PyLayer
):
"""Gradient Reversal Layer from:
Unsupervised Domain Adaptation by Backpropagation (Ganin & Lempitsky, 2015)
Forward pass is the identity function. In the backward pass,
the upstream gradients are multiplied by -lambda (i.e. gradient is reversed)
"""
@
staticmethod
def
forward
(
ctx
,
x
,
lambda_
=
1
):
"""Forward in networks
"""
ctx
.
save_for_backward
(
lambda_
)
return
x
.
clone
()
@
staticmethod
def
backward
(
ctx
,
grads
):
"""Backward in networks
"""
lambda_
,
=
ctx
.
saved_tensor
()
dx
=
-
lambda_
*
grads
return
dx
class
GradientReversalLayer
(
nn
.
Layer
):
"""Gradient Reversal Layer from:
Unsupervised Domain Adaptation by Backpropagation (Ganin & Lempitsky, 2015)
Forward pass is the identity function. In the backward pass,
the upstream gradients are multiplied by -lambda (i.e. gradient is reversed)
"""
def
__init__
(
self
,
lambda_
=
1
):
super
(
GradientReversalLayer
,
self
).
__init__
()
self
.
lambda_
=
lambda_
def
forward
(
self
,
x
):
"""Forward in networks
"""
return
GradientReversalFunction
.
apply
(
x
,
self
.
lambda_
)
if
__name__
==
"__main__"
:
paddle
.
set_device
(
"cpu"
)
data
=
paddle
.
randn
([
2
,
3
],
dtype
=
"float64"
)
data
.
stop_gradient
=
False
grl
=
GradientReversalLayer
(
1
)
out
=
grl
(
data
)
out
.
mean
().
backward
()
print
(
data
.
grad
)
data
=
paddle
.
randn
([
2
,
3
],
dtype
=
"float64"
)
data
.
stop_gradient
=
False
grl
=
GradientReversalLayer
(
-
1
)
out
=
grl
(
data
)
out
.
mean
().
backward
()
print
(
data
.
grad
)
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录