提交 413e3eb5 编写于 作者: M mindspore-ci-bot 提交者: Gitee

!5622 fix nccl broadcast

Merge pull request !5622 from baihuawei/0901
......@@ -86,6 +86,9 @@ void LstmCPUKernel::CheckParam(const CNodePtr &kernel_node) {
num_directions_ = 2;
}
const int gate_size = 4 * hidden_size_;
if (num_layers_ <= 0) {
MS_LOG(EXCEPTION) << "layers must be greater than zero!";
}
for (int i = 0; i < num_layers_; ++i) {
weight_size_ += gate_size * (i == 0 ? input_size_ : hidden_size_ * num_directions_);
weight_h_size_ += gate_size * hidden_size_;
......@@ -95,9 +98,6 @@ void LstmCPUKernel::CheckParam(const CNodePtr &kernel_node) {
if (num_directions_ * num_layers_ != SizeToInt(src_h_shape[0])) {
MS_LOG(EXCEPTION) << "error iteration shape!";
}
if (num_layers_ <= 0) {
MS_LOG(EXCEPTION) << "layers must be greater than zero!";
}
if (src_shape.size() != 3 || src_h_shape.size() != 3 || src_c_shape.size() != 3) {
MS_LOG(EXCEPTION) << "lstm only support 3-D input!";
}
......
......@@ -104,6 +104,9 @@ void LSTMGradCPUKernel::CheckParam(const CNodePtr &kernel_node) {
num_directions_ = 2;
}
const int gate_size = 4 * hidden_size_;
if (num_layers_ <= 0) {
MS_LOG(EXCEPTION) << "layers must be greater than zero!";
}
for (int i = 0; i < num_layers_; ++i) {
weight_size_ += gate_size * (i == 0 ? input_size_ : hidden_size_ * num_directions_);
weight_h_size_ += gate_size * hidden_size_;
......@@ -113,9 +116,6 @@ void LSTMGradCPUKernel::CheckParam(const CNodePtr &kernel_node) {
if (num_directions_ * num_layers_ != SizeToInt(src_h_shape[0])) {
MS_LOG(EXCEPTION) << "error iteration shape!";
}
if (num_layers_ <= 0) {
MS_LOG(EXCEPTION) << "layers must be greater than zero!";
}
if (src_shape.size() != 3 || src_h_shape.size() != 3 || src_c_shape.size() != 3) {
MS_LOG(EXCEPTION) << "lstm only support 3-D input!";
}
......
......@@ -109,9 +109,13 @@ class NcclGpuKernel : public GpuKernel {
auto broadcast_funcptr =
reinterpret_cast<Broadcast>(dlsym(const_cast<void *>(collective_handle_), "Broadcast"));
MS_EXCEPTION_IF_NULL(broadcast_funcptr);
CHECK_NCCL_RET_WITH_EXCEPT((*broadcast_funcptr)(input_addr, output_addr, output_size_ / sizeof(T),
nccl_data_type_, root_, stream, group_name_),
"ncclBroadcast failed");
for (int i = 0; i < SizeToInt(input_size_list_.size()); ++i) {
input_addr = GetDeviceAddress<T>(inputs, i);
output_addr = GetDeviceAddress<T>(outputs, i);
CHECK_NCCL_RET_WITH_EXCEPT((*broadcast_funcptr)(input_addr, output_addr, output_size_ / sizeof(T),
nccl_data_type_, root_, stream, group_name_),
"ncclBroadcast failed");
}
break;
}
default: {
......
......@@ -28,7 +28,7 @@ class Categorical(Distribution):
probs (Tensor, list, numpy.ndarray, Parameter, float): event probabilities.
logits (Tensor, list, numpy.ndarray, Parameter, float): event log-odds.
seed (int): seed to use in sampling. Default: 0.
dtype (mindspore.dtype): type of the distribution. Default: mstype.int32.
dtype (mstype.int32): type of the distribution. Default: mstype.int32.
name (str): name of the distribution. Default: Categorical.
Note:
......@@ -49,7 +49,7 @@ class Categorical(Distribution):
>>>
>>> # Similar calls can be made to logits
>>> ans = self.ca.probs
>>> # value should be Tensor
>>> # value should be Tensor(mstype.float32, bool, mstype.int32)
>>> ans = self.ca.log_prob(value)
>>>
>>> # Usage of enumerate_support
......@@ -210,6 +210,8 @@ class Categorical(Distribution):
def enumerate_support(self, expand=True):
r"""
Enumerate categories.
Args:
expand (Bool): whether to expand.
"""
num_events = self._num_events
values = nn.Range(0., num_events, 1)()
......
......@@ -439,7 +439,7 @@ class Multinomial(PrimitiveWithInfer):
Args:
seed (int): Seed data is used as entropy source for Random number engines generating pseudo-random numbers.
Must be non-negative. Default: 0.
replacement(bool) - whether to draw with replacement or not.
replacement(bool): Whether to draw with replacement or not.
Inputs:
- **input** (Tensor[float32]) - the input tensor containing the cumsum of probabilities, must be 1 or 2 dims.
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册