Commit 96642a76 (magicwindyyd/mindspore, forked from MindSpore/mindspore)
Authored on Aug 06, 2020 by mamba_ni
Parent: a375c50c

support cusolver AND OPS cholesky_solve

fix bug; clang-format fix
Showing 8 changed files with 339 additions and 3 deletions (+339, -3):

  mindspore/ccsrc/CMakeLists.txt                                                  +2    -1
  mindspore/ccsrc/backend/kernel_compiler/gpu/math/cholesky_solve_gpu_kernel.cc   +23   -0
  mindspore/ccsrc/backend/kernel_compiler/gpu/math/cholesky_solve_gpu_kernel.h    +254  -0
  mindspore/ccsrc/runtime/device/gpu/gpu_common.h                                 +16   -0
  mindspore/ccsrc/runtime/device/gpu/gpu_device_manager.cc                        +8    -1
  mindspore/ccsrc/runtime/device/gpu/gpu_device_manager.h                         +4    -0
  mindspore/ops/operations/__init__.py                                            +1    -1
  mindspore/ops/operations/_thor_ops.py                                           +31   -0
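Taken together, the eight files below register a cuSOLVER-backed Cholesky GPU kernel and expose it through the inner Cholesky primitive in _thor_ops.py. A minimal usage sketch of what the commit enables (my assumption of the intended call pattern, not part of the commit; it requires a GPU build of MindSpore from that era):

import numpy as np
import mindspore as ms
from mindspore import context
from mindspore.ops.operations._thor_ops import Cholesky

# Hypothetical example: the primitive is an inner API for the THOR GPU backend,
# so everything outside the import line is an assumption about surrounding usage.
context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")
a = np.eye(16, dtype=np.float32) * 4.0   # float32 symmetric positive-definite input (the only registered dtype)
cholesky_op = Cholesky(split_dim=0)      # split_dim=0: factor the matrix as a whole
y = cholesky_op(ms.Tensor(a))            # output has the same shape as the input in this case
print(y.asnumpy().shape)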
mindspore/ccsrc/CMakeLists.txt  (+2, -1)

@@ -271,7 +271,8 @@ if (ENABLE_GPU)
                       ${CUDA_PATH}/lib64/libcurand.so
                       ${CUDNN_PATH}/lib64/libcudnn.so
                       ${CUDA_PATH}/lib64/libcudart.so
-                      ${CUDA_PATH}/lib64/stubs/libcuda.so)
+                      ${CUDA_PATH}/lib64/stubs/libcuda.so
+                      ${CUDA_PATH}/lib64/libcusolver.so)
    if (ENABLE_MPI)
      set_target_properties(_ms_mpi PROPERTIES INSTALL_RPATH ${ORIGIN_PATH})
    endif()
mindspore/ccsrc/backend/kernel_compiler/gpu/math/cholesky_solve_gpu_kernel.cc  (new file, +23)

/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "backend/kernel_compiler/gpu/math/cholesky_solve_gpu_kernel.h"
namespace mindspore {
namespace kernel {
MS_REG_GPU_KERNEL_ONE(Cholesky, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
                      CholeskyGpuKernel, float)
}  // namespace kernel
}  // namespace mindspore

mindspore/ccsrc/backend/kernel_compiler/gpu/math/cholesky_solve_gpu_kernel.h  (new file, +254)

/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CHOLESKY_SOLVE_GPU_KERNEL_H
#define MINDSPORE_CHOLESKY_SOLVE_GPU_KERNEL_H
#include <cublas_v2.h>
#include <cuda_runtime_api.h>
#include <vector>
#include "backend/kernel_compiler/gpu/cuda_impl/identity_impl.cuh"
#include "backend/kernel_compiler/gpu/cuda_impl/matrix_split_impl.cuh"
#include "backend/kernel_compiler/gpu/gpu_kernel.h"
#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h"
#include "backend/kernel_compiler/gpu/kernel_constants.h"
#include "utils/convert_utils.h"
namespace mindspore {
namespace kernel {
template <typename T>
class CholeskyGpuKernel : public GpuKernel {
 public:
  CholeskyGpuKernel() : batch_(0), m_(0), lda_(0), is_null_input_(false), handle_(nullptr) {}
  ~CholeskyGpuKernel() = default;
  const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; }
  const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; }
  const std::vector<size_t> &GetWorkspaceSizeList() const override { return workspace_size_list_; }

  bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
              const std::vector<AddressPtr> &outputs, void *stream_ptr) override {
    if (is_null_input_) {
      return true;
    }
    if (!use_split_matrix) {
      auto input1_addr = GetDeviceAddress<T>(inputs, 0);
      auto output_addr = GetDeviceAddress<T>(outputs, 0);
      auto d_array_addr = GetDeviceAddress<T *>(workspace, 0);
      auto d_identity_addr = GetDeviceAddress<T *>(workspace, 1);
      auto d_info_array_addr = GetDeviceAddress<int>(workspace, 2);
      // Seed each output block with the identity matrix; it is the right-hand side of the triangular solve below.
      for (size_t i = 0; i < batch_; i++) {
        h_array[i] = input1_addr + i * lda_ * m_;
        h_identity[i] = output_addr + i * ldb_ * m_;
        CHECK_CUDA_RET_WITH_ERROR(
          cudaMemcpyAsync(output_addr + i * ldb_ * m_, h_identity_data.data(), sizeof(T) * ldb_ * m_,
                          cudaMemcpyHostToDevice, reinterpret_cast<cudaStream_t>(stream_ptr)),
          "cuda memcopy Fail");
      }
      CHECK_CUDA_RET_WITH_ERROR(cudaMemcpyAsync(d_array_addr, h_array.data(), sizeof(T *) * batch_,
                                                cudaMemcpyHostToDevice, reinterpret_cast<cudaStream_t>(stream_ptr)),
                                "cuda memcopy Fail");
      CHECK_CUDA_RET_WITH_ERROR(cudaMemcpyAsync(d_identity_addr, h_identity.data(), sizeof(T *) * batch_,
                                                cudaMemcpyHostToDevice, reinterpret_cast<cudaStream_t>(stream_ptr)),
                                "cuda memcopy Fail");
      // Batched in-place Cholesky factorization, then a batched triangular solve against the identity blocks.
      CHECK_CUSOLVER_RET_WITH_EXCEPT(
        cusolverDnSpotrfBatched(handle_, uplo, m_, d_array_addr, lda_, d_info_array_addr, batch_),
        "cusolver cholesky batched Fail");
      float alpha = 1;
      CHECK_CUBLAS_RET_WITH_EXCEPT(
        cublasStrsmBatched(blas_handle_, CUBLAS_SIDE_LEFT, uplo, CUBLAS_OP_N, CUBLAS_DIAG_NON_UNIT, m_, m_, &alpha,
                           d_array_addr, lda_, d_identity_addr, ldb_, batch_),
        "cublas trsm batched Fail");
    } else {
      auto input1_addr = GetDeviceAddress<T>(inputs, 0);
      auto output_addr = GetDeviceAddress<T>(outputs, 0);
      auto d_array_addr = GetDeviceAddress<T *>(workspace, 0);
      auto d_identity_addr = GetDeviceAddress<T *>(workspace, 1);
      auto d_info_array_addr = GetDeviceAddress<int>(workspace, 2);
      auto d_batch_input_addr = GetDeviceAddress<T>(workspace, 3);
      for (size_t i = 0; i < batch_; i++) {
        h_array[i] = d_batch_input_addr + i * lda_ * m_;
        h_identity[i] = output_addr + i * ldb_ * m_;
      }
      // Write identity blocks into the output and split the large input into split_dim x split_dim blocks on device.
      Identity(batch_ * split_dim * split_dim, split_dim, output_addr, reinterpret_cast<cudaStream_t>(stream_ptr));
      MatrixSplit(batch_ * split_dim * split_dim, split_dim, width, input1_addr, d_batch_input_addr,
                  reinterpret_cast<cudaStream_t>(stream_ptr));
      CHECK_CUDA_RET_WITH_ERROR(cudaMemcpyAsync(d_array_addr, h_array.data(), sizeof(T *) * batch_,
                                                cudaMemcpyHostToDevice, reinterpret_cast<cudaStream_t>(stream_ptr)),
                                "cuda memcopy Fail");
      CHECK_CUDA_RET_WITH_ERROR(cudaMemcpyAsync(d_identity_addr, h_identity.data(), sizeof(T *) * batch_,
                                                cudaMemcpyHostToDevice, reinterpret_cast<cudaStream_t>(stream_ptr)),
                                "cuda memcopy Fail");
      CHECK_CUSOLVER_RET_WITH_EXCEPT(
        cusolverDnSpotrfBatched(handle_, uplo, m_, d_array_addr, lda_, d_info_array_addr, batch_),
        "cusolver cholesky batched Fail");
      float alpha = 1;
      CHECK_CUBLAS_RET_WITH_EXCEPT(
        cublasStrsmBatched(blas_handle_, CUBLAS_SIDE_LEFT, uplo, CUBLAS_OP_N, CUBLAS_DIAG_NON_UNIT, m_, m_, &alpha,
                           d_array_addr, lda_, d_identity_addr, ldb_, batch_),
        "cublas trsm batched Fail");
    }
    return true;
  }

  bool Init(const CNodePtr &kernel_node) override {
    handle_ = device::gpu::GPUDeviceManager::GetInstance().GetCusolverDnHandle();
    blas_handle_ = device::gpu::GPUDeviceManager::GetInstance().GetCublasHandle();
    auto in_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
    split_dim = GetAttr<int>(kernel_node, "split_dim");
    if (split_dim == 0) {
      // split_dim == 0: treat the input as one square matrix (rank 2) or a batch of square matrices (rank 3).
      use_split_matrix = false;
      if (in_shape.size() == 2) {
        batch_ = 1;
        if (in_shape[0] != in_shape[1]) {
          MS_LOG(ERROR) << "Cholesky need square matrix as input.";
        }
      } else if (in_shape.size() == 3) {
        batch_ = SizeToInt(in_shape[0]);
        if (in_shape[1] != in_shape[2]) {
          MS_LOG(ERROR) << "Cholesky need square matrix as input.";
        }
      } else {
        MS_LOG(ERROR) << "Input Only support Rank 2 OR 3";
      }
      m_ = SizeToInt(in_shape[1]);
      lda_ = m_;
      ldb_ = m_;
      h_array.resize(batch_);
      h_identity.resize(batch_);
      h_identity_data.resize(m_ * m_);
      for (size_t i = 0; i < m_; i++) {
        for (size_t j = 0; j < m_; j++) {
          if (i == j) {
            h_identity_data[i * m_ + j] = 1;
          } else {
            h_identity_data[i * m_ + j] = 0;
          }
        }
      }
      InitSizeLists();
    } else {
      // split_dim > 0: a rank-2 input larger than split_dim is partitioned into split_dim x split_dim blocks.
      if (in_shape.size() != 2) {
        MS_LOG(ERROR) << "Cholesky Split Matrix Need Input Rank as 2.";
      }
      height = in_shape[0];
      width = in_shape[1];
      if (height != width) {
        MS_LOG(ERROR) << "Cholesky Split Matrix Need Square Matrix as Input.";
      }
      if (SizeToInt(height) <= split_dim) {
        use_split_matrix = false;
        batch_ = 1;
        m_ = SizeToInt(in_shape[1]);
        lda_ = m_;
        ldb_ = m_;
        h_array.resize(batch_);
        h_identity.resize(batch_);
        h_identity_data.resize(m_ * m_);
        for (size_t i = 0; i < m_; i++) {
          for (size_t j = 0; j < m_; j++) {
            if (i == j) {
              h_identity_data[i * m_ + j] = 1;
            } else {
              h_identity_data[i * m_ + j] = 0;
            }
          }
        }
        InitSizeLists();
      } else {
        use_split_matrix = true;
        int batch = SizeToInt(in_shape[1]) / split_dim;
        res_dim = in_shape[1] - batch * split_dim;
        if (res_dim == 0) {
          batch_ = batch;
        } else {
          batch_ = batch + 1;
        }
        m_ = split_dim;
        lda_ = m_;
        ldb_ = m_;
        h_array.resize(batch_);
        h_identity.resize(batch_);
        h_identity_data.resize(m_ * m_);
        for (size_t i = 0; i < m_; i++) {
          for (size_t j = 0; j < m_; j++) {
            if (i == j) {
              h_identity_data[i * m_ + j] = 1;
            } else {
              h_identity_data[i * m_ + j] = 0;
            }
          }
        }
        InitSizeLists();
      }
    }
    return true;
  }

 protected:
  void InitSizeLists() override {
    if (!use_split_matrix) {
      size_t unit_size = sizeof(T);
      size_t input_size = batch_ * m_ * lda_ * unit_size;
      input_size_list_.push_back(input_size);
      size_t output_size = batch_ * m_ * lda_ * unit_size;
      output_size_list_.push_back(output_size);
      size_t workspace_size = batch_ * sizeof(T *);
      workspace_size_list_.push_back(workspace_size);
      workspace_size = batch_ * sizeof(T *);
      workspace_size_list_.push_back(workspace_size);
      workspace_size = batch_ * sizeof(int);
      workspace_size_list_.push_back(workspace_size);
    } else {
      size_t unit_size = sizeof(T);
      size_t input_size = height * width * unit_size;
      input_size_list_.push_back(input_size);
      size_t output_size = batch_ * m_ * lda_ * unit_size;
      output_size_list_.push_back(output_size);
      size_t workspace_size = batch_ * sizeof(T *);
      workspace_size_list_.push_back(workspace_size);
      workspace_size = batch_ * sizeof(T *);
      workspace_size_list_.push_back(workspace_size);
      workspace_size = batch_ * sizeof(int);
      workspace_size_list_.push_back(workspace_size);
      workspace_size = batch_ * m_ * lda_ * unit_size;
      workspace_size_list_.push_back(workspace_size);
    }
  }

 private:
  size_t batch_;
  size_t m_;
  size_t lda_;
  size_t ldb_;
  int res_dim;
  int split_dim;
  bool is_null_input_;
  bool use_split_matrix;
  size_t height;
  size_t width;
  cusolverDnHandle_t handle_;
  cublasHandle_t blas_handle_;
  cublasFillMode_t uplo = CUBLAS_FILL_MODE_UPPER;
  std::vector<T *> h_array;
  std::vector<T *> h_identity;
  std::vector<T> h_identity_data;
  std::vector<size_t> input_size_list_;
  std::vector<size_t> output_size_list_;
  std::vector<size_t> workspace_size_list_;
};
}  // namespace kernel
}  // namespace mindspore
#endif
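
For orientation only: per block, the Launch path above seeds the output with an identity matrix, calls cusolverDnSpotrfBatched to factor the input in place, and then calls cublasStrsmBatched against those identity blocks, which amounts to inverting the triangular Cholesky factor. A rough NumPy sketch of that per-block math (an illustration only, with upper-triangular conventions assumed and the kernel's row-major versus column-major bookkeeping glossed over):

import numpy as np
from scipy.linalg import cholesky, solve_triangular

# Reference-only sketch of the per-block math; not part of the commit.
a = np.array([[4.0, 1.0],
              [1.0, 3.0]], dtype=np.float32)            # one SPD block
u = cholesky(a)                                         # potrf: a == u.T @ u, with u upper-triangular
x = solve_triangular(u, np.eye(2, dtype=np.float32))    # trsm with identity RHS: u @ x == I
np.testing.assert_allclose(u @ x, np.eye(2), atol=1e-5)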

mindspore/ccsrc/runtime/device/gpu/gpu_common.h  (+16, -0)

@@ -93,6 +93,22 @@ namespace gpu {
} \
}
#define CHECK_CUSOLVER_RET_WITH_EXCEPT(expression, message) \
{ \
cusolverStatus_t status = (expression); \
if (status != CUSOLVER_STATUS_SUCCESS) { \
MS_LOG(EXCEPTION) << "cusolver Error: " << message << " | Error Number: " << status; \
} \
}
#define CHECK_CUSOLVER_RET_WITH_ERROR(expression, message) \
{ \
cusolverStatus_t status = (expression); \
if (status != CUSOLVER_STATUS_SUCCESS) { \
MS_LOG(ERROR) << "cusolver Error: " << message << " | Error Number: " << status; \
} \
}
#define CHECK_NCCL_RET_WITH_EXCEPT(expression, message) \
{ \
int result = (expression); \

mindspore/ccsrc/runtime/device/gpu/gpu_device_manager.cc  (+8, -1)

@@ -32,6 +32,10 @@ void GPUDeviceManager::InitDevice() {
  CHECK_CUBLAS_RET_WITH_EXCEPT(cublasCreate(&cublas_handle_), "Failed to create cuBLAS handle.");
  CHECK_CUBLAS_RET_WITH_EXCEPT(cublasSetStream(cublas_handle_, reinterpret_cast<cudaStream_t>(default_stream())),
                               "Failed to set stream for cuBLAS handle.");
  CHECK_CUSOLVER_RET_WITH_EXCEPT(cusolverDnCreate(&cusolver_dn_handle_), "Failed to create cusolver dn handle.");
  CHECK_CUSOLVER_RET_WITH_EXCEPT(cusolverDnSetStream(cusolver_dn_handle_, reinterpret_cast<cudaStream_t>(default_stream())),
                                 "Failed to set stream for cusolver dn handle");
  CHECK_OP_RET_WITH_EXCEPT(GPUMemoryAllocator::GetInstance().Init(), "Failed to Init gpu memory allocator")
}

@@ -47,6 +51,9 @@ void GPUDeviceManager::ReleaseDevice() {
  if (cublas_handle_ != nullptr) {
    CHECK_CUBLAS_RET_WITH_ERROR(cublasDestroy(cublas_handle_), "Failed to destroy cuBLAS handle.");
  }
  if (cusolver_dn_handle_ != nullptr) {
    CHECK_CUSOLVER_RET_WITH_ERROR(cusolverDnDestroy(cusolver_dn_handle_), "Failed to destroy cusolver dn handle.");
  }
  CHECK_OP_RET_WITH_ERROR(GPUMemoryAllocator::GetInstance().Finalize(), "Failed to destroy gpu memory allocator");
}

@@ -79,7 +86,7 @@ bool GPUDeviceManager::is_device_id_init() const { return dev_id_init_; }
const cudnnHandle_t &GPUDeviceManager::GetCudnnHandle() const { return cudnn_handle_; }
const cublasHandle_t &GPUDeviceManager::GetCublasHandle() const { return cublas_handle_; }
const cusolverDnHandle_t &GPUDeviceManager::GetCusolverDnHandle() const { return cusolver_dn_handle_; }
bool GPUDeviceManager::SyncStream(const DeviceStream &stream) const { return CudaDriver::SyncStream(stream); }
bool GPUDeviceManager::CopyDeviceMemToHost(const HostMemPtr &dst, const DeviceMemPtr &src, size_t size) const {

mindspore/ccsrc/runtime/device/gpu/gpu_device_manager.h  (+4, -0)

@@ -19,6 +19,7 @@
#include <cudnn.h>
#include <cublas_v2.h>
#include <cusolverDn.h>
#include <vector>
#include <memory>
#include "runtime/device/gpu/cuda_driver.h"

@@ -43,6 +44,7 @@ class GPUDeviceManager {
  const cudnnHandle_t &GetCudnnHandle() const;
  const cublasHandle_t &GetCublasHandle() const;
  const cusolverDnHandle_t &GetCusolverDnHandle() const;
  bool CopyDeviceMemToHost(const HostMemPtr &dst, const DeviceMemPtr &src, size_t size) const;
  bool CopyHostMemToDevice(const DeviceMemPtr &dst, const void *src, size_t size) const;

@@ -73,6 +75,8 @@ class GPUDeviceManager {
  // handle used for cuBLAS kernels.
  cublasHandle_t cublas_handle_{nullptr};
  // handle used for cusolver dn kernels;
  cusolverDnHandle_t cusolver_dn_handle_{nullptr};
  bool dev_id_init_;
  uint32_t cur_dev_id_;
};

mindspore/ops/operations/__init__.py  (+1, -1)

@@ -86,7 +86,7 @@ from .other_ops import (Assign, IOU, BoundingBoxDecode, BoundingBoxEncode, Popul
from ._thor_ops import (CusBatchMatMul, CusCholeskyTrsm, CusFusedAbsMax1, CusImg2Col, CusMatMulCubeDenseLeft,
                        CusMatMulCubeFraczRightMul, CusMatMulCube, CusMatrixCombine, CusTranspose02314,
                        CusMatMulCubeDenseRight,
-                       CusMatMulCubeFraczLeftCast, Im2Col, UpdateThorGradient)
+                       CusMatMulCubeFraczLeftCast, Im2Col, UpdateThorGradient, Cholesky)
from .sparse_ops import SparseToDense
__all__ = [

mindspore/ops/operations/_thor_ops.py  (+31, -0)

@@ -607,3 +607,34 @@ class UpdateThorGradient(PrimitiveWithInfer):
        validator.check_tensor_type_same({'x1_dtype': x1_dtype, 'x2_dtype': x2_dtype, 'x3_dtype': x3_dtype},
                                         [mstype.float32], self.name)
        return x2_dtype


class Cholesky(PrimitiveWithInfer):
    """
    Inner API for resnet50 THOR GPU backend
    """

    @prim_attr_register
    def __init__(self, split_dim=0):
        self.init_prim_io_names(inputs=['x1'], outputs=['y'])
        self.split_dim = split_dim
        self.add_prim_attr('split_dim', self.split_dim)

    def infer_shape(self, x1_shape):
        if self.split_dim != 0:
            assert len(x1_shape) == 2
            height = x1_shape[0]
            width = x1_shape[1]
            assert height == width
            if height <= self.split_dim:
                out_shape = [1, height, width]
            else:
                batch = height // self.split_dim
                if height != batch * self.split_dim:
                    batch += 1
                out_shape = [batch, self.split_dim, self.split_dim]
        else:
            out_shape = x1_shape
        return out_shape

    def infer_dtype(self, x1_dtype):
        validator.check_tensor_type_same({'x1_dtype': x1_dtype}, [mstype.float32], self.name)
        return x1_dtype
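
The shape rule that infer_shape encodes (and that Init in the GPU kernel mirrors) reports the trailing partial block as a full split_dim x split_dim block, so the batch count is effectively a ceiling division. A small standalone restatement with hypothetical sizes, not taken from the commit:

def cholesky_out_shape(x1_shape, split_dim):
    """Restates the Cholesky infer_shape rule above for illustration."""
    if split_dim == 0:
        return list(x1_shape)                # no splitting: shape passes through
    height, width = x1_shape
    assert height == width
    if height <= split_dim:
        return [1, height, width]
    batch = height // split_dim
    if height != batch * split_dim:          # partial trailing block still counts as a full one
        batch += 1
    return [batch, split_dim, split_dim]

assert cholesky_out_shape((1000, 1000), 128) == [8, 128, 128]  # 7 full blocks plus 1 padded block
assert cholesky_out_shape((96, 96), 128) == [1, 96, 96]        # smaller than split_dim: no split
assert cholesky_out_shape((64, 64), 0) == [64, 64]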