Skip to content

  • 体验新版
    • 正在加载...
  • 登录
  • PaddlePaddle
  • Paddle
  • 合并请求
  • !23543

P
Paddle
  • 项目概览

PaddlePaddle / Paddle
大约 2 年 前同步成功

通知 2325
Star 20933
Fork 5424
  • 代码
    • 文件
    • 提交
    • 分支
    • Tags
    • 贡献者
    • 分支图
    • Diff
  • Issue 1423
    • 列表
    • 看板
    • 标记
    • 里程碑
  • 合并请求 543
  • Wiki 0
    • Wiki
  • 分析
    • 仓库
    • DevOps
  • 项目成员
  • Pages
P
Paddle
  • 项目概览
    • 项目概览
    • 详情
    • 发布
  • 仓库
    • 仓库
    • 文件
    • 提交
    • 分支
    • 标签
    • 贡献者
    • 分支图
    • 比较
  • Issue 1,423
    • Issue 1,423
    • 列表
    • 看板
    • 标记
    • 里程碑
  • 合并请求 543
    • 合并请求 543
  • Pages
  • 分析
    • 分析
    • 仓库分析
    • DevOps
  • Wiki 0
    • Wiki
  • 成员
    • 成员
  • 收起侧边栏
  • 动态
  • 分支图
  • 创建新Issue
  • 提交
  • Issue看板

Add cholesky_op !23543

  • Report abuse
!23543 已合并 4月 07, 2020 由 saxon_zh@saxon_zh 创建
#<User:0x00007f0ef9097388>
  • 概览 5
  • 提交 19
  • 变更 20

Created by: guoshengCS

Add cholesky_op

预览暂时放在了paddle.fluid.layers下 image

由于无法使用OpTest来进行梯度检查,换用gradient_checker.grad_check进行了梯度检查。另外进行了对比测试,测试脚本如下

import numpy as np

use_gpu = True

np.random.seed(1000)
a = np.random.rand(2, 1, 3, 3).astype("float64")
rank = len(a.shape)
a_t = a.transpose(list(range(rank-2)) + [rank-1, rank-2])
a = np.matmul(a, a_t) + 1e-03

inp = a

import paddle
import paddle.fluid as fluid
from paddle.fluid.dygraph.base import to_variable
place = fluid.CUDAPlace(0) if use_gpu else fluid.CPUPlace()
fluid.enable_dygraph(place)
x = to_variable(inp)
x.stop_gradient = False
l = paddle.cholesky(x, upper=False)
loss = fluid.layers.reduce_sum(l, dim=list(range(rank)))
loss.backward()
pd_values = (l.numpy(), x._grad_ivar().numpy())

import tensorflow as tf
gpus = tf.config.experimental.list_physical_devices('GPU')
if gpus:
  try:
    # Currently, memory growth needs to be the same across GPUs
    for gpu in gpus:
      tf.config.experimental.set_memory_growth(gpu, True)
  except RuntimeError as e:
    # Memory growth must be set before GPUs have been initialized
    print(e)
with tf.device('/GPU:0' if use_gpu else '/CPU:0'):
    with tf.GradientTape() as tape:
        x = tf.Variable(inp)
        l = tf.linalg.cholesky(x)
        loss = tf.math.reduce_sum(l)
        grad = tape.gradient(loss, x)
tf_values = (l.numpy(), grad.numpy())
assert np.allclose(pd_values[0], tf_values[0], rtol=1e-07, atol=1e-08)
assert np.allclose(pd_values[1], tf_values[1], rtol=1e-07, atol=1e-08)

import torch
x = torch.Tensor(inp).cuda() if use_gpu else torch.Tensor(inp)
x.requires_grad=True
l = torch.cholesky(x, upper=False)
loss = l.sum()
loss.backward()
torch_values = (l.cpu().detach().numpy(), x.grad.cpu().numpy())
assert np.allclose(pd_values[0], torch_values[0], rtol=1e-07, atol=1e-08)
assert np.allclose(pd_values[1], torch_values[1], rtol=1e-05, atol=1e-08)

print(torch_values[1])
print(pd_values[1])
指派人
分配到
审核者
Request review from
无
里程碑
无
分配里程碑
工时统计
标识: paddlepaddle/Paddle!23543
Source branch: github/fork/guoshengCS/add-cholesky
渝ICP备2023009037号

京公网安备11010502055752号

网络110报警服务 Powered by GitLab CE v13.7
开源知识
Git 入门 Pro Git 电子书 在线学 Git
Markdown 基础入门 IT 技术知识开源图谱
帮助
使用手册 反馈建议 博客
《GitCode 隐私声明》 《GitCode 服务条款》 关于GitCode
Powered by GitLab CE v13.7