magicwindyyd / mindspore (forked from MindSpore / mindspore)
Commit 92f1855a
Authored Sep 05, 2020 by baihuawei

    fix categorical in GraphMode

Parent: 1a4d3e35

Showing 7 changed files with 64 additions and 101 deletions (+64 -101)
Changed files:

mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/lstm_grad_cpu_kernel.cc   +5  -5
mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/lstm_grad_cpu_kernel.h    +1  -1
mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/multinomial_impl.cu    +20 -14
mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/multinomial_impl.cuh   +3  -0
mindspore/ccsrc/backend/kernel_compiler/gpu/math/multinomial_gpu_kernel.h    +8  -6
mindspore/nn/probability/distribution/_utils/utils.py                        +4  -18
mindspore/nn/probability/distribution/categorical.py                         +23 -57
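The commit targets mindspore.nn.probability.distribution.Categorical under graph compilation: GRAPH_MODE compiles construct() ahead of time, so Python-level constructs the old code relied on (isinstance checks, list-based shape loops) fail there. Below is a minimal sketch of the kind of usage this fix is meant to support, assuming the 2020-era MindSpore API; the cell wiring is illustrative, not part of the commit:

```python
import numpy as np
import mindspore.nn as nn
import mindspore.nn.probability.distribution as msd
from mindspore import Tensor, context
from mindspore.common import dtype as mstype

# GRAPH_MODE compiles construct() into a static graph ahead of time.
context.set_context(mode=context.GRAPH_MODE)

class CategoricalNet(nn.Cell):
    """Evaluates Categorical.log_prob inside a compiled graph."""
    def __init__(self):
        super(CategoricalNet, self).__init__()
        self.categorical = msd.Categorical(probs=[0.2, 0.3, 0.5], dtype=mstype.int32)

    def construct(self, value):
        return self.categorical.log_prob(value)

net = CategoricalNet()
# Expected: log of [0.5, 0.3] for categories [2, 1]
print(net(Tensor(np.array([2, 1], np.float32))))
```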
mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/lstm_grad_cpu_kernel.cc

```diff
@@ -148,7 +148,7 @@ void LSTMGradCPUKernel::SetArgumentHandleOp(const std::vector<kernel::AddressPtr
   SetArgumentHandle(DNNL_ARG_DIFF_DST_ITER_C, inputs[9]->addr);
 }
 
-void LSTMGradCPUKernel::Memset_op(const dnnl::memory &mem, string name) {
+void LSTMGradCPUKernel::ResetMemory(const dnnl::memory &mem, string name) {
   if (memset_s(mem.get_data_handle(), mem.get_desc().get_size(), 0, mem.get_desc().get_size())) {
     MS_LOG(EXCEPTION) << name << " memset error";
   }
@@ -186,10 +186,10 @@ bool LSTMGradCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inputs,
   auto user_diff_weights_h_memory =
     dnnl::memory(dnnl::memory::desc{{weights_h_dims_}, dt::f32, tag::ldgoi}, eng);
   user_diff_weights_memory.set_data_handle(outputs[3]->addr);
   user_diff_weights_h_memory.set_data_handle(reinterpret_cast<float *>(outputs[3]->addr) + weight_size_);
-  Memset_op(user_diff_weights_memory, "user weights grad");
-  Memset_op(user_diff_weights_h_memory, "user weights iter grad");
-  Memset_op(diff_weights_memory, "weights grad");
-  Memset_op(diff_weights_h_memory, "weights iter grad");
+  ResetMemory(user_diff_weights_memory, "user weights grad");
+  ResetMemory(user_diff_weights_h_memory, "user weights iter grad");
+  ResetMemory(diff_weights_memory, "weights grad");
+  ResetMemory(diff_weights_h_memory, "weights iter grad");
   if (has_bias_) {
     diff_bias_memory.set_data_handle(reinterpret_cast<float *>(outputs[3]->addr) + weight_size_ + weight_h_size_);
   }
```
mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/lstm_grad_cpu_kernel.h

```diff
@@ -42,7 +42,7 @@ class LSTMGradCPUKernel : public MKLCPUKernel {
                 const dnnl::memory &weights_h_memory, const dnnl::memory &bias_memory,
                 const dnnl::memory &diff_weights_memory, const dnnl::memory &diff_weights_h_memory,
                 const dnnl::memory &diff_bias_memory);
-  void Memset_op(const dnnl::memory &mem, string name);
+  void ResetMemory(const dnnl::memory &mem, string name);
   void CheckParam(const CNodePtr &kernel_node);
   int weight_size_ = 0;
   int weight_h_size_ = 0;
```
mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/multinomial_impl.cu

```diff
@@ -16,18 +16,6 @@
 #include "multinomial_impl.cuh"
 
-template <typename T>
-__global__ void NormInput(T *input, const size_t distributions, const size_t categories) {
-  size_t size = distributions * categories;
-  for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < (size); pos += blockDim.x * gridDim.x) {
-    if ((pos + 1) % categories != 0) {
-      int de_pos = (1 + pos / categories) * categories - 1;
-      input[pos] /= input[de_pos];
-    }
-  }
-  return;
-}
-
 template <typename T>
 __global__ void CheckZeroKernel(const size_t distributions, const size_t categories, const T *input, T *out) {
   out[0] = 0;
@@ -61,6 +49,24 @@ void CheckNonNeg(const size_t size, const T *input, T *output, cudaStream_t cuda
   CheckNonNegKernel<<<GET_BLOCKS(size), GET_THREADS, 0, cuda_stream>>>(size, input, output);
 }
 
+template <typename T>
+__global__ void NormInputKernel(T *input, const size_t distributions, const size_t categories) {
+  size_t size = distributions * categories;
+  for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < (size); pos += blockDim.x * gridDim.x) {
+    if ((pos + 1) % categories != 0) {
+      int de_pos = (1 + pos / categories) * categories - 1;
+      input[pos] /= input[de_pos];
+    }
+  }
+  return;
+}
+
+template <typename T>
+void NormInput(T *input, const size_t distributions, const size_t categories, cudaStream_t cuda_stream) {
+  int count1 = distributions * categories;
+  NormInputKernel<<<GET_BLOCKS(count1), GET_THREADS, 0, cuda_stream>>>(input, distributions, categories);
+}
+
 template <typename T>
 __device__ int BinarySearchForMultinomial(T *start_addr, int size, T rand) {
   int start = 0;
@@ -104,8 +110,6 @@ void Multinomial(int seed, T *input, int num_sample, curandState *globalState, i
     RNG_seed = time(NULL);
   }
   int count = distributions * num_sample;
-  int count1 = distributions * categories;
-  NormInput<<<GET_BLOCKS(count1), GET_THREADS, 0, cuda_stream>>>(input, distributions, categories);
   MultinomialKernel<<<GET_BLOCKS(count), GET_THREADS, 0, cuda_stream>>>(RNG_seed, input, num_sample, globalState,
                                                                         output, distributions, categories);
   return;
@@ -116,3 +120,5 @@ template void Multinomial<float>(int seed, float *input, int num_sample, curandS
 template void CheckNonNeg<float>(const size_t size, const float *input, float *output, cudaStream_t cuda_stream);
 template void CheckZero<float>(const size_t distributions, const size_t categories, const float *input,
                                float *output, cudaStream_t cuda_stream);
+template void NormInput<float>(float *input, const size_t distributions, const size_t categories,
+                               cudaStream_t cuda_stream);
```
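Net effect in this file: the old NormInput __global__ kernel becomes a NormInputKernel plus a NormInput host wrapper, so normalization can be launched explicitly by the GPU kernel (see multinomial_gpu_kernel.h below) instead of being hard-wired into Multinomial. The kernel divides every element of a row except the last by that row's last element, turning a row-wise running sum into a CDF. A NumPy sketch of the same arithmetic, for intuition only:

```python
import numpy as np

def norm_input(cum_sum, distributions, categories):
    """NumPy analogue of NormInputKernel: divide each row entry by the
    row's last entry (the running-sum total), skipping the last entry
    itself, which is what the (pos + 1) % categories != 0 guard does."""
    x = cum_sum.reshape(distributions, categories).astype(np.float32)
    totals = x[:, -1:].copy()   # de_pos = (1 + pos / categories) * categories - 1
    x[:, :-1] /= totals
    return x.reshape(-1)

probs = np.array([[0.2, 0.3, 0.5],
                  [0.1, 0.1, 0.8]], dtype=np.float32)
cum_sum = np.cumsum(probs, axis=1)          # row-wise running sums
print(norm_input(cum_sum.ravel(), 2, 3))    # [0.2 0.5 1.  0.1 0.2 1. ]
```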
mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/multinomial_impl.cuh

```diff
@@ -26,4 +26,7 @@ template <typename T>
 void CheckNonNeg(const size_t size, const T *input, T *output, cudaStream_t stream);
 template <typename T>
 void CheckZero(const size_t distributions, const size_t categories, const T *input, T *output, cudaStream_t stream);
+template <typename T>
+void NormInput(T *input, const size_t distributions, const size_t categories, cudaStream_t cuda_stream);
 #endif  // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_MULTINOMIAL_IMPL_CUH_
```
mindspore/ccsrc/backend/kernel_compiler/gpu/math/multinomial_gpu_kernel.h

```diff
@@ -47,22 +47,23 @@ class MultinomialGpuKernel : public GpuKernel {
   bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
               const std::vector<AddressPtr> &outputs, void *stream_ptr) override {
-    void *workspace_addr = GetDeviceAddress<void *>(workspace, 0);
+    void *workspace_addr = GetDeviceAddress<void *>(workspace, 1);
+    T *cum_sum_input = GetDeviceAddress<T>(workspace, 0);
     curandState *devStates = reinterpret_cast<curandState *>(workspace_addr);
     int *output_addr = GetDeviceAddress<int>(outputs, 0);
     T *input_addr = GetDeviceAddress<T>(inputs, 0);
     int categories = SizeToInt(inputs[0]->size / sizeof(T)) / distributions_;
-    int num_sample = SizeToInt(outputs[0]->size / sizeof(T)) / distributions_;
+    int num_sample = SizeToInt(outputs[0]->size / sizeof(int)) / distributions_;
     // check input
-    T *cum_sum_input = nullptr;
-    CHECK_CUDA_RET_WITH_EXCEPT(cudaMalloc(reinterpret_cast<void **>(&cum_sum_input), input_size_0_),
-                               "cudaMalloc failed.");
     CheckPeram(input_addr, cum_sum_input, categories, stream_ptr);
     if (replacement_) {
+      NormInput(cum_sum_input, IntToSize(distributions_), IntToSize(categories),
+                reinterpret_cast<cudaStream_t>(stream_ptr));
+      CHECK_CUDA_RET_WITH_EXCEPT(cudaStreamSynchronize(reinterpret_cast<cudaStream_t>(stream_ptr)),
+                                 "cudaStreamSynchronize failed.");
       Multinomial(seed_, cum_sum_input, num_sample, devStates, output_addr, IntToSize(distributions_),
                   IntToSize(categories), reinterpret_cast<cudaStream_t>(stream_ptr));
     }
-    CHECK_CUDA_RET_WITH_EXCEPT(cudaFree(cum_sum_input), "cudaFree failed.");
     return true;
   }
@@ -145,6 +146,7 @@ class MultinomialGpuKernel : public GpuKernel {
    input_size_list_.push_back(input_size_0_);
    input_size_list_.push_back(sizeof(int));
    output_size_list_.push_back(output_size_);
+   workspace_size_list_.push_back(input_size_0_);
    workspace_size_list_.push_back(workspace_size_);
  }
```
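Two things change in Launch: cum_sum_input is no longer cudaMalloc'd and cudaFree'd on every call but comes from workspace slot 0 (registered via workspace_size_list_.push_back(input_size_0_), with the curand states moving to slot 1), and NormInput is now invoked explicitly before Multinomial. After NormInput, each row of cum_sum_input is a normalized CDF that MultinomialKernel samples via binary search (BinarySearchForMultinomial in the .cu file). A NumPy sketch of that sampling scheme, for intuition only:

```python
import numpy as np

def multinomial_via_cdf(cdf_rows, num_sample, rng):
    """Host-side analogue of MultinomialKernel: draw uniforms and
    binary-search each distribution's normalized CDF, as
    BinarySearchForMultinomial does on the device."""
    distributions, _ = cdf_rows.shape
    out = np.empty((distributions, num_sample), dtype=np.int32)
    for d in range(distributions):
        rand = rng.random(num_sample).astype(np.float32)
        # np.searchsorted is the host-side analogue of the device binary search
        out[d] = np.searchsorted(cdf_rows[d], rand, side='left')
    return out

rng = np.random.default_rng(0)
probs = np.array([[0.2, 0.3, 0.5]], dtype=np.float32)
cdf = np.cumsum(probs, axis=1) / probs.sum(axis=1, keepdims=True)
print(multinomial_via_cdf(cdf, 8, rng))   # samples concentrated on category 2
```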
mindspore/nn/probability/distribution/_utils/utils.py

```diff
@@ -271,24 +271,6 @@ def probs_to_logits(probs, is_binary=False):
     return P.Log()(ps_clamped)
 
-def check_tensor_type(name, inputs, valid_type):
-    """
-    Check if inputs is proper.
-
-    Args:
-        name: inputs name
-        inputs: Tensor to be checked.
-
-    Raises:
-        ValueError: if inputs is not a proper Tensor.
-    """
-    if not isinstance(inputs, Tensor):
-        raise TypeError(f"{name} should be a Tensor")
-    input_type = P.DType()(inputs)
-    if input_type not in valid_type:
-        raise TypeError(f"{name} dtype is invalid")
-
 def check_type(data_type, value_type, name):
     if not data_type in value_type:
         raise TypeError(
@@ -304,6 +286,10 @@ def raise_none_error(name):
 def raise_probs_logits_error():
     raise TypeError("Either 'probs' or 'logits' must be specified, but not both.")
 
+@constexpr
+def raise_broadcast_error(shape_a, shape_b):
+    raise ValueError(f"Shape {shape_a} and {shape_b} is not broadcastable.")
+
 @constexpr
 def raise_not_impl_error(name):
     raise ValueError(
```
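The deleted check_tensor_type used isinstance and raised at Python run time, which does not survive graph compilation; the added raise_broadcast_error is decorated with @constexpr, so it executes while the graph is being built. A sketch of why @constexpr helpers suit GraphMode, assuming the constexpr decorator exported from mindspore.ops; broadcast_shape_or_raise is a hypothetical helper for illustration, not part of the commit:

```python
from mindspore.ops import constexpr

@constexpr
def broadcast_shape_or_raise(shape_a, shape_b):
    """Runs during graph compilation on constant shape tuples, so plain
    Python control flow and exceptions are allowed here even though
    construct() itself is compiled."""
    # Left-pad the shorter shape with 1s, then apply broadcasting rules.
    pad_a = (1,) * (len(shape_b) - len(shape_a)) + tuple(shape_a)
    pad_b = (1,) * (len(shape_a) - len(shape_b)) + tuple(shape_b)
    out = []
    for a, b in zip(pad_a, pad_b):
        if a != b and a != 1 and b != 1:
            raise ValueError(f"Shape {shape_a} and {shape_b} is not broadcastable.")
        out.append(max(a, b))
    return tuple(out)
```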
mindspore/nn/probability/distribution/categorical.py

```diff
@@ -17,7 +17,8 @@ from mindspore.ops import operations as P
 import mindspore.nn as nn
 from mindspore.common import dtype as mstype
 from .distribution import Distribution
-from ._utils.utils import logits_to_probs, probs_to_logits, check_type, check_tensor_type, cast_to_tensor, raise_probs_logits_error
+from ._utils.utils import logits_to_probs, probs_to_logits, check_type, cast_to_tensor, \
+    raise_probs_logits_error
 
 class Categorical(Distribution):
@@ -25,7 +26,7 @@ class Categorical(Distribution):
     Creates a categorical distribution parameterized by either probs or logits (but not both).
 
     Args:
-        probs (Tensor, list, numpy.ndarray, Parameter, float): event probabilities.
+        probs (Tensor, list, numpy.ndarray, Parameter): event probabilities.
         logits (Tensor, list, numpy.ndarray, Parameter, float): event log-odds.
         seed (int): seed to use in sampling. Default: 0.
         dtype (mindspore.dtype): type of the distribution. Default: mstype.int32.
@@ -77,6 +78,7 @@ class Categorical(Distribution):
         if (probs is None) == (logits is None):
             raise_probs_logits_error()
         self.reduce_sum = P.ReduceSum(keep_dims=True)
+        self.reduce_sum1 = P.ReduceSum(keep_dims=False)
         self.log = P.Log()
         self.exp = P.Exp()
         self.shape = P.Shape()
@@ -88,6 +90,7 @@ class Categorical(Distribution):
         self.expandim = P.ExpandDims()
         self.gather = P.GatherNd()
         self.concat = P.Concat(-1)
+        self.transpose = P.Transpose()
         if probs is not None:
             self._probs = cast_to_tensor(probs, mstype.float32)
             input_sum = self.reduce_sum(self._probs, -1)
@@ -102,8 +105,8 @@ class Categorical(Distribution):
             self._param = self._logits
         self._num_events = self.shape(self._param)[-1]
         self._param2d = self.reshape(self._param, (-1, self._num_events))
-        self._batch_shape = self.shape(self._param2d)[0]
+        self._batch_shape = self.shape(self._param)[:-1]
+        self._batch_shape_n = (1,) * len(self._batch_shape)
 
     @property
     def logits(self):
@@ -130,72 +133,35 @@ class Categorical(Distribution):
             Tensor, shape is shape(probs)[:-1] + sample_shape
         """
         self.checktuple(sample_shape, 'shape')
         if sample_shape == ():
             sample_shape = (1,)
         num_sample = 1
         for i in sample_shape:
             num_sample *= i
         probs_2d = self.reshape(self._probs, (-1, self._num_events))
         samples = self.mutinomial(probs_2d, num_sample)
+        samples = self.transpose(samples, (1, 0))
         extend_shape = sample_shape
         if len(self.shape(self._probs)) > 1:
             extend_shape = sample_shape + self.shape(self._probs)[:-1]
         return self.cast(self.reshape(samples, extend_shape), self.dtype)
 
-    def _broad_cast_shape(self, a, b):
-        """
-        Broadcast Tensor shape.
-
-        Args:
-            a (Tensor): A Tensor need to Broadcast.
-            b (Tensor): Another Tensor need to Broadcast.
-
-        Returns:
-            Tuple, Broadcast shape.
-        """
-        shape_a = self.shape(a)
-        shape_b = self.shape(b)
-        size_a = len(shape_a)
-        size_b = len(shape_b)
-        if size_a > size_b:
-            size = size_a
-            shape_out = list(shape_a)
-            shape_short = list(shape_b)
-            diff_size = size_a - size_b
-        else:
-            size = size_b
-            shape_out = list(shape_b)
-            shape_short = list(shape_a)
-            diff_size = size_b - size_a
-        for i in range(diff_size, size):
-            if shape_out[i] == shape_short[i - diff_size]:
-                continue
-            if shape_out[i] == 1 or shape_short[i - diff_size] == 1:
-                shape_out[i] = shape_out[i] * shape_short[i - diff_size]
-            else:
-                raise ValueError(f"Shape {shape_a} and {shape_b} is not broadcastable.")
-        return tuple(shape_out)
-
     def _log_prob(self, value):
         r"""
         Evaluate log probability.
 
         Args:
-            value (Tensor): value to be evaluated.
-            The dtype could be mstype.float32, bool, mstype.int32.
+            value (Tensor): value to be evaluated.
         """
-        if value is not None:
-            check_tensor_type("value", value, [mstype.float32, bool, mstype.int32])
-            value = self.expandim(self.cast(value, mstype.float32), -1)
-            broad_shape = self._broad_cast_shape(value, self._logits)
-            broad = P.BroadcastTo(broad_shape)
-            logits_pmf = self.reshape(broad(self._logits), (-1, broad_shape[-1]))
-            value = self.reshape(broad(value)[..., :1], (-1, 1))
-            index = nn.Range(0., self.shape(value)[0], 1)()
-            index = self.reshape(index, (-1, 1))
-            value = self.concat((index, value))
-            value = self.cast(value, mstype.int32)
-            return self.reshape(self.gather(logits_pmf, value), broad_shape[:-1])
-        return None
+        value = self._check_value(value, 'value')
+        value = self.expandim(self.cast(value, mstype.float32), -1)
+        broad_shape = self.shape(value + self._logits)
+        broad = P.BroadcastTo(broad_shape)
+        logits_pmf = self.reshape(broad(self._logits), (-1, broad_shape[-1]))
+        value = self.reshape(broad(value)[..., :1], (-1, 1))
+        index = nn.Range(0., self.shape(value)[0], 1)()
+        index = self.reshape(index, (-1, 1))
+        value = self.concat((index, value))
+        value = self.cast(value, mstype.int32)
+        return self.reshape(self.gather(logits_pmf, value), broad_shape[:-1])
 
     def _entropy(self):
         r"""
@@ -205,7 +171,7 @@ class Categorical(Distribution):
         H(X) = -\sum(logits * probs)
         """
         p_log_p = self._logits * self._probs
-        return self.reduce_sum(-p_log_p, -1)
+        return self.reduce_sum1(-p_log_p, -1)
 
     def enumerate_support(self, expand=True):
         r"""
@@ -213,8 +179,8 @@ class Categorical(Distribution):
         """
         num_events = self._num_events
         values = nn.Range(0., num_events, 1)()
-        values = self.reshape(values, (num_events, 1))
+        values = self.reshape(values, (num_events,) + self._batch_shape_n)
         if expand:
-            values = P.BroadcastTo((num_events, self._batch_shape))(values)
+            values = P.BroadcastTo((num_events,) + self._batch_shape)(values)
         values = self.cast(values, mstype.int32)
         return values
```
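The load-bearing change for GraphMode is in _log_prob: broad_shape = self.shape(value + self._logits) lets the framework's own broadcasting produce the joint shape during shape inference, replacing the deleted Python-level _broad_cast_shape loop (whose error path now lives in utils.py as raise_broadcast_error). The same idea in NumPy terms, for intuition only:

```python
import numpy as np

value = np.zeros((4, 1, 1), dtype=np.float32)   # value after ExpandDims(..., -1)
logits = np.zeros((3, 5), dtype=np.float32)     # batch of 3 distributions, 5 events

# Adding two broadcastable arrays yields the broadcast shape directly,
# with no Python-level shape loop -- this is what
# self.shape(value + self._logits) relies on inside the compiled graph.
broad_shape = (value + logits).shape
print(broad_shape)                               # (4, 3, 5)
```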