提交 55d280e7 编写于 作者: M mindspore-ci-bot 提交者: Gitee

!3608 Add gpu support for RandomChoiceWithMask

Merge pull request !3608 from 34bunny/GPU-RandomChoiceWithMask
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "backend/kernel_compiler/gpu/cuda_impl/random_choice_with_mask_impl.cuh"
#include <algorithm>
int RcwmRoundUpPower2(int v) {
v--;
v |= v >> 1;
v |= v >> 2;
v |= v >> 4;
v |= v >> 8;
v |= v >> 16;
v++;
return v;
}
template <typename T>
__inline__ __device__ void Swap(T *lhs, T *rhs) {
T tmp = lhs[0];
lhs[0] = rhs[0];
rhs[0] = tmp;
}
template <typename T, typename S>
__global__ void InitArray(const int input_size, const int ceil_power2, const T *input, S *mask_buff, S *rank_buff) {
for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < ceil_power2; pos += blockDim.x * gridDim.x) {
mask_buff[pos] = (pos < input_size) ? static_cast<S>(input[pos]) : 0;
rank_buff[pos] = (pos < input_size && input[pos] != false) ? pos : (ceil_power2 + 1);
}
}
template <size_t blockSize, typename T>
__device__ void WarpReduce(volatile T *sdata, size_t tid) {
if (blockSize >= 64) sdata[tid] += sdata[tid + 32];
if (blockSize >= 32) sdata[tid] += sdata[tid + 16];
if (blockSize >= 16) sdata[tid] += sdata[tid + 8];
if (blockSize >= 8) sdata[tid] += sdata[tid + 4];
if (blockSize >= 4) sdata[tid] += sdata[tid + 2];
if (blockSize >= 2) sdata[tid] += sdata[tid + 1];
}
template <size_t blockSize, typename T>
__global__ void ReductionSum(T *g_idata, T *g_odata, size_t n) {
__shared__ T sdata[blockSize];
size_t tid = threadIdx.x;
size_t i = blockIdx.x * (blockSize) + tid;
size_t gridSize = blockSize * gridDim.x;
sdata[tid] = 0;
while (i < n) {
sdata[tid] += g_idata[i];
i += gridSize;
}
__syncthreads();
if (blockSize >= 1024) {
if (tid < 512) {
sdata[tid] += sdata[tid + 512];
}
__syncthreads();
}
if (blockSize >= 512) {
if (tid < 256) {
sdata[tid] += sdata[tid + 256];
}
__syncthreads();
}
if (blockSize >= 256) {
if (tid < 128) {
sdata[tid] += sdata[tid + 128];
}
__syncthreads();
}
if (blockSize >= 128) {
if (tid < 64) {
sdata[tid] += sdata[tid + 64];
}
__syncthreads();
}
if (tid < 32) WarpReduce<blockSize>(sdata, tid);
if (tid == 0) g_odata[blockIdx.x] = sdata[0];
}
template <typename T, typename S>
__global__ void Reshape2Index(const int input_size, const int input_shape_size, const int d1, const int d2,
const int d3, const int d4, const int d5, const T *input, S *output_index) {
int pos_array[MAX_DIMENSION];
int index_pos;
for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < input_size; pos += blockDim.x * gridDim.x) {
pos_array[0] = pos / (d2 * d3 * d4 * d5) % d1;
pos_array[1] = pos / (d3 * d4 * d5) % d2;
pos_array[2] = pos / (d4 * d5) % d3;
pos_array[3] = pos / (d5) % d4;
pos_array[4] = pos % d5;
index_pos = pos * input_shape_size;
if (input[pos] == false) {
for (int i = 0; i < input_shape_size; i++) {
output_index[index_pos++] = 0;
}
} else {
for (int i = MAX_DIMENSION - input_shape_size; i < MAX_DIMENSION; i++) {
output_index[index_pos++] = pos_array[i];
}
}
}
}
template <typename T>
__global__ void Copy(const T *src, T *dst, const int n) {
for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < n; pos += blockDim.x * gridDim.x) {
dst[pos] = src[pos];
}
}
template <typename T>
__global__ void Sort(const int ceil_power2, T *rank_buff) {
for (size_t i = 2; i <= ceil_power2; i <<= 1) {
for (size_t j = (i >> 1); j > 0; j >>= 1) {
for (size_t tid = blockIdx.x * blockDim.x + threadIdx.x; tid < ceil_power2; tid += blockDim.x * gridDim.x) {
size_t tid_comp = tid ^ j;
if (tid_comp > tid) {
if ((tid & i) == 0) {
if (rank_buff[tid] > rank_buff[tid_comp]) {
Swap(&rank_buff[tid], &rank_buff[tid_comp]);
}
} else {
if (rank_buff[tid] < rank_buff[tid_comp]) {
Swap(&rank_buff[tid], &rank_buff[tid_comp]);
}
}
}
}
__syncthreads();
}
}
}
__global__ void SrandInit(const int ceil_power2, curandState *globalState, const int seedc) {
for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < ceil_power2; i += blockDim.x * gridDim.x) {
curand_init(seedc, i, 0, &globalState[i]);
}
}
template <typename T>
__global__ void Shuffle(const int ceil_power2, curandState *globalState, T *rank_buff) {
int limit = ceil_power2 + 1;
int value;
for (size_t i = 2; i <= ceil_power2; i <<= 1) {
for (size_t j = (i >> 1); j > 0; j >>= 1) {
for (size_t tid = blockIdx.x * blockDim.x + threadIdx.x; tid < ceil_power2; tid += blockDim.x * gridDim.x) {
size_t tid_comp = tid ^ j;
if (tid_comp > tid) {
value = static_cast<int>(curand(&globalState[tid]));
if (value & 1) {
if (rank_buff[tid] != limit && rank_buff[tid_comp] != limit) {
Swap(&rank_buff[tid], &rank_buff[tid_comp]);
}
}
}
}
__syncthreads();
}
}
}
template <typename T, typename S>
__global__ void MoveToOutput(const int input_shape_size, const int count, const T *input, S *output_index,
T *output_mask, S *index_buff, S *rank_buff, S *Tnum_buff) {
int Tnum = static_cast<int>(Tnum_buff[0]);
int idx = 0;
int pos;
if (count <= Tnum) {
for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < count; i += blockDim.x * gridDim.x) {
idx = rank_buff[i];
pos = i;
output_mask[pos] = input[idx];
pos *= input_shape_size;
idx *= input_shape_size;
for (size_t j = 0; j < input_shape_size; j++) {
output_index[pos] = index_buff[idx];
pos++;
idx++;
}
}
} else {
for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < count; i += blockDim.x * gridDim.x) {
if (i < Tnum) {
idx = rank_buff[i];
pos = i;
output_mask[pos] = input[idx];
pos *= input_shape_size;
idx *= input_shape_size;
for (size_t j = 0; j < input_shape_size; j++) {
output_index[pos] = index_buff[idx];
pos++;
idx++;
}
} else {
pos = i;
output_mask[pos] = static_cast<T>(0);
pos *= input_shape_size;
for (size_t j = 0; j < input_shape_size; j++) {
output_index[pos] = static_cast<S>(0);
pos++;
}
}
}
}
}
template <typename T, typename S>
void CalRandomChoiceWithMask(const int &input_size, const int &input_shape_size, const int &d1, const int &d2,
const int &d3, const int &d4, const int &d5, const int &seedc, const int &count,
const T *input, S *output_index, T *output_mask, S *index_buff, S *mask_buff, S *rank_buff,
S *Tnum_buff, S *tmp_buff, curandState *globalState, cudaStream_t stream) {
int ceil_power2 = RcwmRoundUpPower2(input_size);
InitArray<<<GET_BLOCKS(input_size), GET_THREADS, 0, stream>>>(input_size, ceil_power2, input, mask_buff, rank_buff);
size_t BLOCKNUM;
size_t n = ceil_power2;
Copy<<<GET_BLOCKS(input_size), GET_THREADS, 0, stream>>>(mask_buff, tmp_buff, ceil_power2);
do {
BLOCKNUM = std::ceil(static_cast<float>(n) / BLOCKSIZE);
ReductionSum<BLOCKSIZE, S><<<BLOCKNUM, BLOCKSIZE, 0, stream>>>(tmp_buff, Tnum_buff, n);
Copy<<<GET_BLOCKS(input_size), GET_THREADS, 0, stream>>>(Tnum_buff, tmp_buff, BLOCKNUM);
n = BLOCKNUM;
} while (n > BLOCKSIZE);
if (n > 1) ReductionSum<BLOCKSIZE, S><<<1, BLOCKSIZE, 0, stream>>>(Tnum_buff, Tnum_buff, n);
Reshape2Index<<<GET_BLOCKS(input_size), GET_THREADS, 0, stream>>>(input_size, input_shape_size, d1, d2, d3, d4, d5,
input, index_buff);
Sort<<<GET_BLOCKS(ceil_power2), GET_THREADS, 0, stream>>>(ceil_power2, rank_buff);
SrandInit<<<GET_BLOCKS(ceil_power2), GET_THREADS, 0, stream>>>(ceil_power2, globalState, seedc);
Shuffle<<<GET_BLOCKS(ceil_power2), GET_THREADS, 0, stream>>>(ceil_power2, globalState, rank_buff);
MoveToOutput<<<GET_BLOCKS(count), GET_THREADS, 0, stream>>>(input_shape_size, count, input, output_index, output_mask,
index_buff, rank_buff, Tnum_buff);
}
template void CalRandomChoiceWithMask(const int &input_size, const int &input_shape_size, const int &d1, const int &d2,
const int &d3, const int &d4, const int &d5, const int &seedc, const int &count,
const bool *input, int *output_index, bool *output_mask, int *index_buff,
int *mask_buff, int *rank_buff, int *Tnum_buff, int *tmp_buff,
curandState *globalState, cudaStream_t stream);
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_CUDA_IMPL_RANDOM_CHOICE_WITH_MASK_IMPL_CUH_
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_CUDA_IMPL_RANDOM_CHOICE_WITH_MASK_IMPL_CUH_
#include <cuda_runtime.h>
#include <curand_kernel.h>
#include "runtime/device/gpu/cuda_common.h"
#define BLOCKSIZE 256
#define MAX_DIMENSION 5
template <typename T, typename S>
void CalRandomChoiceWithMask(const int &input_size, const int &input_shape_size, const int &d1, const int &d2,
const int &d3, const int &d4, const int &d5, const int &seedc, const int &count,
const T *input, S *output_index, T *output_mask, S *index_buff, S *mask_buff, S *rank_buff,
S *Tnum_buff, S *tmp_buff, curandState *globalState, cudaStream_t stream);
int RcwmRoundUpPower2(int v);
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_CUDA_IMPL_RANDOM_CHOICE_WITH_MASK_IMPL_CUH_
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "backend/kernel_compiler/gpu/random/random_choice_with_mask_gpu_kernel.h"
namespace mindspore {
namespace kernel {
MS_REG_GPU_KERNEL_TWO(
RandomChoiceWithMask,
KernelAttr().AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeBool),
RandomChoiceWithMaskGpuKernel, bool, int)
}
} // namespace mindspore
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_RANDOM_RANDOM_CHOICE_WITH_MASK_GPU_KERNEL_H_
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_RANDOM_RANDOM_CHOICE_WITH_MASK_GPU_KERNEL_H_
#include <vector>
#include "backend/kernel_compiler/gpu/gpu_kernel.h"
#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h"
#include "backend/kernel_compiler/gpu/cuda_impl/random_choice_with_mask_impl.cuh"
namespace mindspore {
namespace kernel {
template <typename T, typename S>
class RandomChoiceWithMaskGpuKernel : public GpuKernel {
public:
RandomChoiceWithMaskGpuKernel() : input_shape_size_(0), seedc_(0), input_size_(1), count_(0), ceil_power2_(0) {}
~RandomChoiceWithMaskGpuKernel() override = default;
const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; }
const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; }
const std::vector<size_t> &GetWorkspaceSizeList() const override { return workspace_size_list_; }
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspaces,
const std::vector<AddressPtr> &outputs, void *stream_ptr) override {
T *input = GetDeviceAddress<T>(inputs, 0);
S *output_index = GetDeviceAddress<S>(outputs, 0);
T *output_mask = GetDeviceAddress<T>(outputs, 1);
S *index_buff = GetDeviceAddress<S>(workspaces, 0);
S *mask_buff = GetDeviceAddress<S>(workspaces, 1);
S *rank_buff = GetDeviceAddress<S>(workspaces, 2);
S *Tnum_buff = GetDeviceAddress<S>(workspaces, 3);
S *tmp_buff = GetDeviceAddress<S>(workspaces, 4);
void *States = GetDeviceAddress<void *>(workspaces, 5);
curandState *devStates = reinterpret_cast<curandState *>(States);
CalRandomChoiceWithMask(input_size_, input_shape_size_, input_shape_5D_[0], input_shape_5D_[1], input_shape_5D_[2],
input_shape_5D_[3], input_shape_5D_[4], seedc_, count_, input, output_index, output_mask,
index_buff, mask_buff, rank_buff, Tnum_buff, tmp_buff, devStates,
reinterpret_cast<cudaStream_t>(stream_ptr));
return true;
}
bool Init(const CNodePtr &kernel_node) override {
size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
if (input_num != 1) {
MS_LOG(ERROR) << "Input number is " << input_num << ", but RandomChoiceWithMask needs 1 input.";
return false;
}
size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node);
if (output_num != 2) {
MS_LOG(ERROR) << "Output number is " << output_num << ", but RandomChoiceWithMask has 2 outputs.";
return false;
}
auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
input_shape_size_ = input_shape.size();
if (input_shape_size_ < 1 || input_shape_size_ > MAX_DIMENSION) {
MS_LOG(ERROR) << "Input is " << input_shape_size_
<< "-D, but RandomChoiceWithMask supports only 1-D to 5-D inputs.";
return false;
}
// convert size_t to int
for (auto i = 0; i < input_shape_size_; i++) {
input_shape_5D_.push_back(input_shape[i]);
}
// convert shape to 5D
while (input_shape_5D_.size() != MAX_DIMENSION) {
input_shape_5D_.insert(input_shape_5D_.begin(), 1);
}
// init seedc_
int seed = GetAttr<int>(kernel_node, "seed");
int seed2 = GetAttr<int>(kernel_node, "seed2");
if (seed2 != 0)
seedc_ = seed2;
else if (seed != 0)
seedc_ = seed;
else
seedc_ = time(NULL);
// init memory
for (size_t i = 0; i < input_shape.size(); i++) {
input_size_ *= input_shape[i];
}
count_ = GetAttr<int>(kernel_node, "count");
// upper ceiling for input for ceil_power2
ceil_power2_ = RcwmRoundUpPower2(input_size_);
InitSizeLists();
return true;
}
protected:
void InitSizeLists() override {
input_size_list_.push_back(input_size_ * sizeof(T));
output_size_list_.push_back(count_ * input_shape_size_ * sizeof(S));
output_size_list_.push_back(count_ * sizeof(T));
workspace_size_list_.push_back(input_size_ * input_shape_size_ * sizeof(S));
workspace_size_list_.push_back(ceil_power2_ * sizeof(S));
workspace_size_list_.push_back(ceil_power2_ * sizeof(S));
int blocknum = std::ceil(static_cast<float>(ceil_power2_) / BLOCKSIZE);
workspace_size_list_.push_back(blocknum * sizeof(S));
workspace_size_list_.push_back(ceil_power2_ * sizeof(S));
workspace_size_list_.push_back(ceil_power2_ * sizeof(curandState));
}
private:
int input_shape_size_;
int seedc_;
int input_size_;
int count_;
int ceil_power2_;
std::vector<int> input_shape_5D_;
std::vector<size_t> input_size_list_;
std::vector<size_t> output_size_list_;
std::vector<size_t> workspace_size_list_;
};
} // namespace kernel
} // namespace mindspore
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_RANDOM_RANDOM_CHOICE_WITH_MASK_GPU_KERNEL_H_
......@@ -348,13 +348,13 @@ class RandomChoiceWithMask(PrimitiveWithInfer):
seed2 (int): Random seed2. Default: 0.
Inputs:
- **input_x** (Tensor[bool]) - The input tensor.
- **input_x** (Tensor[bool]) - The input tensor. The input tensor rank should be >= 1 and <= 5.
Outputs:
Two tensors, the first one is the index tensor and the other one is the mask tensor.
- **index** (Tensor) - The output has shape between 2-D and 5-D.
- **mask** (Tensor) - The output has shape 1-D.
- **index** (Tensor) - The output shape is 2-D.
- **mask** (Tensor) - The output shape is 1-D.
Examples:
>>> rnd_choice_mask = P.RandomChoiceWithMask()
......@@ -372,6 +372,7 @@ class RandomChoiceWithMask(PrimitiveWithInfer):
def infer_shape(self, x_shape):
validator.check_integer("input_x rank", len(x_shape), 1, Rel.GE, self.name)
validator.check_integer("input_x rank", len(x_shape), 5, Rel.LE, self.name)
return ([self.count, len(x_shape)], [self.count])
def infer_dtype(self, x_dtype):
......
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np
import pytest
import mindspore.context as context
import mindspore.nn as nn
from mindspore import Tensor
from mindspore.ops import operations as P
class RCWM_count_in(nn.Cell):
def __init__(self):
super(RCWM_count_in, self).__init__()
self.RCWM_count_in = P.RandomChoiceWithMask(count=4, seed=1)
def construct(self, x):
return self.RCWM_count_in(x)
class RCWM_count_out(nn.Cell):
def __init__(self):
super(RCWM_count_out, self).__init__()
self.RCWM_count_out = P.RandomChoiceWithMask(count=10, seed=1)
def construct(self, x):
return self.RCWM_count_out(x)
class RCWM_3D(nn.Cell):
def __init__(self):
super(RCWM_3D, self).__init__()
self.RCWM_3D = P.RandomChoiceWithMask(count=10, seed=1)
def construct(self, x):
return self.RCWM_3D(x)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_RCWM_3D():
context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")
input_tensor = Tensor(np.ones([3, 4, 5]).astype(np.bool))
expect1 = [[0, 1, 1], [0, 2, 1], [0, 2, 2], [1, 0, 1], [0, 1, 3], [0, 3, 0], [1, 3, 2], \
[0, 0, 0], [1, 1, 2], [1, 3, 4]]
expect2 = [True, True, True, True, True, True, True, True, True, True]
rcwm = RCWM_3D()
output1, output2 = rcwm(input_tensor)
assert np.all(output1.asnumpy() == np.array(expect1)), "output: {}, expect: {}".format(output1, expect1)
assert np.all(output2.asnumpy() == np.array(expect2)), "output: {}, expect: {}".format(output2, expect2)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_RCWM_count_out():
context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")
input_tensor = Tensor(np.array([[1, 0, 1, 0], [0, 0, 0, 1], [1, 1, 1, 1], [0, 0, 0, 1]]).astype(np.bool))
expect1 = [[0, 2], [2, 2], [2, 1], [2, 0], [0, 0], [3, 3], [2, 3], [1, 3], [0, 0], [0, 0]]
expect2 = [True, True, True, True, True, True, True, True, False, False]
rcwm = RCWM_count_out()
output1, output2 = rcwm(input_tensor)
assert np.all(output1.asnumpy() == np.array(expect1)), "output: {}, expect: {}".format(output1, expect1)
assert np.all(output2.asnumpy() == np.array(expect2)), "output: {}, expect: {}".format(output2, expect2)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_RCWM_count_in():
context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")
input_tensor = Tensor(np.array([[1, 0, 1, 0], [0, 0, 0, 1], [1, 1, 1, 1], [0, 0, 0, 1]]).astype(np.bool))
expect1 = [[0, 2], [2, 2], [2, 1], [2, 0]]
expect2 = [True, True, True, True]
rcwm = RCWM_count_in()
output1, output2 = rcwm(input_tensor)
assert np.all(output1.asnumpy() == np.array(expect1)), "output: {}, expect: {}".format(output1, expect1)
assert np.all(output2.asnumpy() == np.array(expect2)), "output: {}, expect: {}".format(output2, expect2)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册