提交 78787931 编写于 作者: S sunsuodong

optimize performance

上级 1971e3f9
...@@ -66,5 +66,15 @@ size_t CPUKernelUtils::GetElementNumOnAxis(const std::vector<size_t> &shape, int ...@@ -66,5 +66,15 @@ size_t CPUKernelUtils::GetElementNumOnAxis(const std::vector<size_t> &shape, int
} }
return result; return result;
} }
void CPUKernelUtils::GetElementNumEveryDim(const std::vector<size_t> &shape, std::vector<size_t> *element_num) {
size_t accumulation = 1;
element_num->emplace_back(1);
for (size_t i = shape.size() - 1; i > 0; --i) {
accumulation *= shape[i];
element_num->emplace_back(accumulation);
}
std::reverse(element_num->begin(), element_num->end());
}
} // namespace kernel } // namespace kernel
} // namespace mindspore } // namespace mindspore
...@@ -78,6 +78,7 @@ class CPUKernelUtils { ...@@ -78,6 +78,7 @@ class CPUKernelUtils {
static void ExpandDimsTo4(std::vector<size_t> *shape); static void ExpandDimsTo4(std::vector<size_t> *shape);
static size_t CalcOffset(const std::vector<size_t> &shape, size_t dim0, size_t dim1, size_t dim2, size_t dim3); static size_t CalcOffset(const std::vector<size_t> &shape, size_t dim0, size_t dim1, size_t dim2, size_t dim3);
static size_t GetElementNumOnAxis(const std::vector<size_t> &shape, int axis); static size_t GetElementNumOnAxis(const std::vector<size_t> &shape, int axis);
static void GetElementNumEveryDim(const std::vector<size_t> &shape, std::vector<size_t> *element_num);
}; };
} // namespace kernel } // namespace kernel
} // namespace mindspore } // namespace mindspore
......
...@@ -70,6 +70,8 @@ void SliceCPUKernel::InitKernel(const CNodePtr &kernel_node) { ...@@ -70,6 +70,8 @@ void SliceCPUKernel::InitKernel(const CNodePtr &kernel_node) {
end_.insert(end_.begin(), 1); end_.insert(end_.begin(), 1);
} }
} }
CPUKernelUtils::GetElementNumEveryDim(input_shape_, &input_element_num_);
CPUKernelUtils::GetElementNumEveryDim(output_shape_, &output_element_num_);
} }
bool SliceCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inputs, bool SliceCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inputs,
...@@ -78,12 +80,40 @@ bool SliceCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inputs, ...@@ -78,12 +80,40 @@ bool SliceCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inputs,
auto input_addr = reinterpret_cast<float *>(inputs[0]->addr); auto input_addr = reinterpret_cast<float *>(inputs[0]->addr);
auto output_addr = reinterpret_cast<float *>(outputs[0]->addr); auto output_addr = reinterpret_cast<float *>(outputs[0]->addr);
for (int i = begin_[0]; i < end_[0]; i += strides_[0]) { bool can_copy_memory[3] = {CanCopyMemoryOnAxis(0), CanCopyMemoryOnAxis(1), CanCopyMemoryOnAxis(2)};
for (int j = begin_[1]; j < end_[1]; j += strides_[1]) { size_t in_start_offset[3] = {begin_[0] * input_element_num_[0], begin_[1] * input_element_num_[1],
for (int k = begin_[2]; k < end_[2]; k += strides_[2]) { begin_[2] * input_element_num_[2]};
size_t in_step_size[3] = {strides_[0] * input_element_num_[0], strides_[1] * input_element_num_[1],
strides_[2] * input_element_num_[2]};
auto in_n_offset = in_start_offset[0];
auto out_n_offset = 0;
for (int i = begin_[0]; i < end_[0];
i += strides_[0], in_n_offset += in_step_size[0], out_n_offset += output_element_num_[0]) {
if (can_copy_memory[0]) {
CopyDataToOutput(inputs, in_n_offset, outputs, out_n_offset, input_element_num_[0]);
continue;
}
auto in_c_offset = in_start_offset[1];
auto out_c_offset = 0;
for (int j = begin_[1]; j < end_[1];
j += strides_[1], in_c_offset += in_step_size[1], out_c_offset += output_element_num_[1]) {
if (can_copy_memory[1]) {
CopyDataToOutput(inputs, in_n_offset + in_c_offset, outputs, out_n_offset + out_c_offset,
input_element_num_[1]);
continue;
}
auto in_h_offset = in_start_offset[2];
auto out_h_offset = 0;
for (int k = begin_[2]; k < end_[2];
k += strides_[2], in_h_offset += in_step_size[2], out_h_offset += output_element_num_[2]) {
if (can_copy_memory[2]) {
CopyDataToOutput(inputs, in_n_offset + in_c_offset + in_h_offset, outputs,
out_n_offset + out_c_offset + out_h_offset, input_element_num_[2]);
continue;
}
for (int m = begin_[3]; m < end_[3]; m += strides_[3]) { for (int m = begin_[3]; m < end_[3]; m += strides_[3]) {
auto offset = CPUKernelUtils::CalcOffset(input_shape_, i, j, k, m); *output_addr++ = input_addr[in_n_offset + in_c_offset + in_h_offset + m];
*output_addr++ = input_addr[offset];
} }
} }
} }
...@@ -92,7 +122,38 @@ bool SliceCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inputs, ...@@ -92,7 +122,38 @@ bool SliceCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inputs,
return true; return true;
} }
void SliceCPUKernel::CheckParam(const CNodePtr &kernel_node) { bool SliceCPUKernel::CanCopyMemoryOnAxis(size_t dim) const {
for (size_t i = dim + 1; i < 4; ++i) {
if (begin_[i] != 0 || end_[i] != SizeToInt(input_shape_[i]) || strides_[i] != 1) {
return false;
}
}
return true;
}
void SliceCPUKernel::CopyDataToOutput(const std::vector<kernel::AddressPtr> &inputs, size_t in_offset,
const std::vector<kernel::AddressPtr> &outputs, size_t out_offset,
size_t copy_num) const {
auto input_addr = reinterpret_cast<float *>(inputs[0]->addr);
auto in_buff_size = inputs[0]->size;
auto output_addr = reinterpret_cast<float *>(outputs[0]->addr);
auto out_buff_size = outputs[0]->size;
if ((in_offset + copy_num) * sizeof(float) > in_buff_size) {
MS_LOG(EXCEPTION) << "input memory out of bounds.";
}
if ((out_offset + copy_num) * sizeof(float) > out_buff_size) {
MS_LOG(EXCEPTION) << "output memory out of bounds.";
}
auto ret = memcpy_s(output_addr + out_offset, out_buff_size - out_offset * sizeof(float), input_addr + in_offset,
copy_num * sizeof(float));
if (ret != EOK) {
MS_LOG(EXCEPTION) << "memcpy failed. ret:" << ret;
}
}
void SliceCPUKernel::CheckParam(const CNodePtr &kernel_node) const {
size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
if (input_num != 1) { if (input_num != 1) {
MS_LOG(EXCEPTION) << "Input number is " << input_num << ", but SliceCPUKernel needs 1 inputs."; MS_LOG(EXCEPTION) << "Input number is " << input_num << ", but SliceCPUKernel needs 1 inputs.";
......
...@@ -33,12 +33,17 @@ class SliceCPUKernel : public CPUKernel { ...@@ -33,12 +33,17 @@ class SliceCPUKernel : public CPUKernel {
const std::vector<AddressPtr> &outputs) override; const std::vector<AddressPtr> &outputs) override;
private: private:
void CheckParam(const CNodePtr &kernel_node); bool CanCopyMemoryOnAxis(size_t dim) const;
void CopyDataToOutput(const std::vector<kernel::AddressPtr> &inputs, size_t in_offset,
const std::vector<kernel::AddressPtr> &outputs, size_t out_offset, size_t copy_num) const;
void CheckParam(const CNodePtr &kernel_node) const;
std::vector<int> begin_; std::vector<int> begin_;
std::vector<int> end_; std::vector<int> end_;
std::vector<int> strides_; std::vector<int> strides_;
std::vector<size_t> input_shape_; std::vector<size_t> input_shape_;
std::vector<size_t> input_element_num_;
std::vector<size_t> output_shape_; std::vector<size_t> output_shape_;
std::vector<size_t> output_element_num_;
}; };
MS_REG_CPU_KERNEL(Slice, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), MS_REG_CPU_KERNEL(Slice, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
......
...@@ -21,13 +21,14 @@ namespace mindspore { ...@@ -21,13 +21,14 @@ namespace mindspore {
namespace kernel { namespace kernel {
void SliceGradCPUKernel::InitKernel(const CNodePtr &kernel_node) { void SliceGradCPUKernel::InitKernel(const CNodePtr &kernel_node) {
CheckParam(kernel_node); CheckParam(kernel_node);
output_dx_shape_ = AnfAlgo::GetOutputInferShape(kernel_node, 0); output_shape_ = AnfAlgo::GetOutputInferShape(kernel_node, 0);
input_dy_shape_ = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); input_shape_ = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
CPUKernelUtils::ExpandDimsTo4(&input_shape_);
begin_ = AnfAlgo::GetNodeAttr<std::vector<int>>(kernel_node, BEGIN); begin_ = AnfAlgo::GetNodeAttr<std::vector<int>>(kernel_node, BEGIN);
for (size_t i = 0; i < begin_.size(); i++) { for (size_t i = 0; i < begin_.size(); i++) {
if (begin_[i] < 0) { if (begin_[i] < 0) {
begin_[i] = begin_[i] + output_dx_shape_[i]; begin_[i] = begin_[i] + output_shape_[i];
} }
} }
...@@ -37,61 +38,90 @@ void SliceGradCPUKernel::InitKernel(const CNodePtr &kernel_node) { ...@@ -37,61 +38,90 @@ void SliceGradCPUKernel::InitKernel(const CNodePtr &kernel_node) {
if (strides != nullptr) { if (strides != nullptr) {
strides_ = AnfAlgo::GetNodeAttr<std::vector<int>>(kernel_node, STRIDES); strides_ = AnfAlgo::GetNodeAttr<std::vector<int>>(kernel_node, STRIDES);
end_ = AnfAlgo::GetNodeAttr<std::vector<int>>(kernel_node, END); end_ = AnfAlgo::GetNodeAttr<std::vector<int>>(kernel_node, END);
if (strides_.size() != end_.size() || strides_.size() != output_dx_shape_.size()) { if (strides_.size() != end_.size() || strides_.size() != output_shape_.size()) {
MS_LOG(EXCEPTION) << "stride|end|input size must be equal"; MS_LOG(EXCEPTION) << "stride|end|input size must be equal";
} }
for (size_t i = 0; i < strides_.size(); ++i) { for (size_t i = 0; i < strides_.size(); ++i) {
if (strides_[i] < 0) { if (strides_[i] < 0) {
strides_[i] = (strides_[i] + output_dx_shape_[i]) > 0 ? (strides_[i] + output_dx_shape_[i]) : 0; strides_[i] = (strides_[i] + output_shape_[i]) > 0 ? (strides_[i] + output_shape_[i]) : 0;
} }
if (end_[i] < 0) { if (end_[i] < 0) {
end_[i] = (end_[i] + output_dx_shape_[i]) > 0 ? (end_[i] + output_dx_shape_[i]) : 0; end_[i] = (end_[i] + output_shape_[i]) > 0 ? (end_[i] + output_shape_[i]) : 0;
} }
} }
} else { } else {
auto sizes = AnfAlgo::GetNodeAttr<std::vector<int>>(kernel_node, SIZE); auto sizes = AnfAlgo::GetNodeAttr<std::vector<int>>(kernel_node, SIZE);
if (sizes.size() != output_dx_shape_.size() || begin_.size() != output_dx_shape_.size()) { if (sizes.size() != output_shape_.size() || begin_.size() != output_shape_.size()) {
MS_LOG(EXCEPTION) << "begin|size|input size must be equal"; MS_LOG(EXCEPTION) << "begin|size|input size must be equal";
} }
for (size_t i = 0; i < sizes.size(); ++i) { for (size_t i = 0; i < sizes.size(); ++i) {
if (sizes[i] < 0) { if (sizes[i] < 0) {
sizes[i] = (sizes[i] + output_dx_shape_[i]) > 0 ? (sizes[i] + output_dx_shape_[i]) : 0; sizes[i] = (sizes[i] + output_shape_[i]) > 0 ? (sizes[i] + output_shape_[i]) : 0;
} }
strides_.emplace_back(1); strides_.emplace_back(1);
end_.emplace_back(begin_[i] + sizes[i]); end_.emplace_back(begin_[i] + sizes[i]);
} }
} }
auto output_len = output_dx_shape_.size(); auto output_len = output_shape_.size();
if (output_len < 4) { if (output_len < 4) {
for (size_t i = 0; i < 4 - output_len; ++i) { for (size_t i = 0; i < 4 - output_len; ++i) {
output_dx_shape_.insert(output_dx_shape_.begin(), 1); output_shape_.insert(output_shape_.begin(), 1);
begin_.insert(begin_.begin(), 0); begin_.insert(begin_.begin(), 0);
strides_.insert(strides_.begin(), 1); strides_.insert(strides_.begin(), 1);
end_.insert(end_.begin(), 1); end_.insert(end_.begin(), 1);
} }
} }
CPUKernelUtils::GetElementNumEveryDim(input_shape_, &input_element_num_);
CPUKernelUtils::GetElementNumEveryDim(output_shape_, &output_element_num_);
} }
bool SliceGradCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inputs, bool SliceGradCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inputs,
const std::vector<kernel::AddressPtr> & /*workspace*/, const std::vector<kernel::AddressPtr> & /*workspace*/,
const std::vector<kernel::AddressPtr> &outputs) { const std::vector<kernel::AddressPtr> &outputs) {
auto input_dy_addr = reinterpret_cast<float *>(inputs[0]->addr); auto input_addr = reinterpret_cast<float *>(inputs[0]->addr);
auto output_dx_addr = reinterpret_cast<float *>(outputs[0]->addr); auto output_addr = reinterpret_cast<float *>(outputs[0]->addr);
auto out_size = sizeof(float) * output_dx_shape_[0] * output_dx_shape_[1] * output_dx_shape_[2] * output_dx_shape_[3]; auto ret = memset_s(output_addr, outputs[0]->size, 0, outputs[0]->size);
auto ret = memset_s(output_dx_addr, out_size, 0, out_size);
if (ret != EOK) { if (ret != EOK) {
MS_LOG(ERROR) << "output buff memset fail."; MS_LOG(ERROR) << "output buff memset fail. ret:" << ret;
return false; return false;
} }
for (int i = begin_[0]; i < end_[0]; i += strides_[0]) { bool can_copy_memory[3] = {CanCopyMemoryOnAxis(0), CanCopyMemoryOnAxis(1), CanCopyMemoryOnAxis(2)};
for (int j = begin_[1]; j < end_[1]; j += strides_[1]) { size_t out_start_offset[3] = {begin_[0] * output_element_num_[0], begin_[1] * output_element_num_[1],
for (int k = begin_[2]; k < end_[2]; k += strides_[2]) { begin_[2] * output_element_num_[2]};
size_t out_step_size[3] = {strides_[0] * output_element_num_[0], strides_[1] * output_element_num_[1],
strides_[2] * output_element_num_[2]};
auto in_n_offset = 0;
auto out_n_offset = out_start_offset[0];
for (int i = begin_[0]; i < end_[0];
i += strides_[0], in_n_offset += input_element_num_[0], out_n_offset += out_step_size[0]) {
if (can_copy_memory[0]) {
CopyDataToOutput(inputs, in_n_offset, outputs, out_n_offset, input_element_num_[0]);
continue;
}
auto in_c_offset = 0;
auto out_c_offset = out_start_offset[1];
for (int j = begin_[1]; j < end_[1];
j += strides_[1], in_c_offset += input_element_num_[1], out_c_offset += out_step_size[1]) {
if (can_copy_memory[1]) {
CopyDataToOutput(inputs, in_n_offset + in_c_offset, outputs, out_n_offset + out_c_offset,
input_element_num_[1]);
continue;
}
auto in_h_offset = 0;
auto out_h_offset = out_start_offset[2];
for (int k = begin_[2]; k < end_[2];
k += strides_[2], in_h_offset += input_element_num_[2], out_h_offset += out_step_size[2]) {
if (can_copy_memory[2]) {
CopyDataToOutput(inputs, in_n_offset + in_c_offset + in_h_offset, outputs,
out_n_offset + out_c_offset + out_h_offset, input_element_num_[2]);
continue;
}
for (int m = begin_[3]; m < end_[3]; m += strides_[3]) { for (int m = begin_[3]; m < end_[3]; m += strides_[3]) {
auto offset = CPUKernelUtils::CalcOffset(output_dx_shape_, i, j, k, m); output_addr[out_n_offset + out_c_offset + out_h_offset + m] = *input_addr++;
output_dx_addr[offset] = *input_dy_addr++;
} }
} }
} }
...@@ -99,7 +129,38 @@ bool SliceGradCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inputs, ...@@ -99,7 +129,38 @@ bool SliceGradCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inputs,
return true; return true;
} }
void SliceGradCPUKernel::CheckParam(const CNodePtr &kernel_node) { bool SliceGradCPUKernel::CanCopyMemoryOnAxis(size_t dim) const {
for (size_t i = dim + 1; i < 4; ++i) {
if (begin_[i] != 0 || end_[i] != SizeToInt(output_shape_[i]) || strides_[i] != 1) {
return false;
}
}
return true;
}
void SliceGradCPUKernel::CopyDataToOutput(const std::vector<kernel::AddressPtr> &inputs, size_t in_offset,
const std::vector<kernel::AddressPtr> &outputs, size_t out_offset,
size_t copy_num) const {
auto input_addr = reinterpret_cast<float *>(inputs[0]->addr);
auto in_buff_size = inputs[0]->size;
auto output_addr = reinterpret_cast<float *>(outputs[0]->addr);
auto out_buff_size = outputs[0]->size;
if ((in_offset + copy_num) * sizeof(float) > in_buff_size) {
MS_LOG(EXCEPTION) << "input memory out of bounds.";
}
if ((out_offset + copy_num) * sizeof(float) > out_buff_size) {
MS_LOG(EXCEPTION) << "output memory out of bounds.";
}
auto ret = memcpy_s(output_addr + out_offset, out_buff_size - out_offset * sizeof(float), input_addr + in_offset,
copy_num * sizeof(float));
if (ret != EOK) {
MS_LOG(EXCEPTION) << "memcpy failed. ret:" << ret;
}
}
void SliceGradCPUKernel::CheckParam(const CNodePtr &kernel_node) const {
size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node); size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node);
if (output_num != 1) { if (output_num != 1) {
MS_LOG(EXCEPTION) << "Output number is " << output_num << ", but SliceGradGpuKernel needs 1 output."; MS_LOG(EXCEPTION) << "Output number is " << output_num << ", but SliceGradGpuKernel needs 1 output.";
......
...@@ -33,12 +33,17 @@ class SliceGradCPUKernel : public CPUKernel { ...@@ -33,12 +33,17 @@ class SliceGradCPUKernel : public CPUKernel {
const std::vector<AddressPtr> &outputs) override; const std::vector<AddressPtr> &outputs) override;
private: private:
void CheckParam(const CNodePtr &kernel_node); bool CanCopyMemoryOnAxis(size_t dim) const;
void CopyDataToOutput(const std::vector<kernel::AddressPtr> &inputs, size_t in_offset,
const std::vector<kernel::AddressPtr> &outputs, size_t out_offset, size_t copy_num) const;
void CheckParam(const CNodePtr &kernel_node) const;
std::vector<int> begin_; std::vector<int> begin_;
std::vector<int> end_; std::vector<int> end_;
std::vector<int> strides_; std::vector<int> strides_;
std::vector<size_t> input_dy_shape_; std::vector<size_t> input_shape_;
std::vector<size_t> output_dx_shape_; std::vector<size_t> input_element_num_;
std::vector<size_t> output_shape_;
std::vector<size_t> output_element_num_;
}; };
MS_REG_CPU_KERNEL( MS_REG_CPU_KERNEL(
......
...@@ -40,7 +40,7 @@ class SliceGrad(nn.Cell): ...@@ -40,7 +40,7 @@ class SliceGrad(nn.Cell):
@pytest.mark.level0 @pytest.mark.level0
@pytest.mark.platform_x86_cpu @pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard @pytest.mark.env_onecard
def test_slice(): def test_slice_grad():
x = Tensor(np.array([[[1, 1, 1], [2, 2, 2]], [[3, 3, 3], [4, 4, 4]], [[5, 5, 5], [6, 6, 6]]]), mstype.float32) x = Tensor(np.array([[[1, 1, 1], [2, 2, 2]], [[3, 3, 3], [4, 4, 4]], [[5, 5, 5], [6, 6, 6]]]), mstype.float32)
dy = Tensor(np.array([[[3., 1., 2.]], [[4., 1., 4.]]]), mstype.float32) dy = Tensor(np.array([[[3., 1., 2.]], [[4., 1., 4.]]]), mstype.float32)
slicegrad = SliceGrad() slicegrad = SliceGrad()
...@@ -54,6 +54,27 @@ def test_slice(): ...@@ -54,6 +54,27 @@ def test_slice():
print("output:\n", output) print("output:\n", output)
assert (output.asnumpy() == expect).all() assert (output.asnumpy() == expect).all()
class SliceGrad2(nn.Cell):
def __init__(self):
super(SliceGrad2, self).__init__()
self.slicegrad = G.SliceGrad()
def construct(self, dy, x):
return self.slicegrad(dy, x, (0, 1, 0), (2, 2, 2))
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_slice_grad2():
dy = Tensor(np.array([[[2., 3.], [4., 5.]], [[8., 9.], [10., 11.]]]), mstype.float32)
x = Tensor(np.arange(2 * 3 * 2).reshape(2, 3, 2), mstype.float32)
grad = SliceGrad2()
output = grad(dy, x)
print("output:\n", output)
expect = [[[0., 0.], [2., 3.], [4., 5.]],
[[0., 0.], [8., 9.], [10., 11.]]]
assert (output.asnumpy() == expect).all()
if __name__ == '__main__': if __name__ == '__main__':
test_slice() test_slice_grad()
test_slice_grad2()
...@@ -21,6 +21,7 @@ import mindspore.nn as nn ...@@ -21,6 +21,7 @@ import mindspore.nn as nn
from mindspore import Tensor from mindspore import Tensor
from mindspore.common import dtype as mstype from mindspore.common import dtype as mstype
from mindspore.ops import operations as P from mindspore.ops import operations as P
from mindspore.ops.operations import _grad_ops as G
context.set_context(mode=context.GRAPH_MODE, device_target='CPU') context.set_context(mode=context.GRAPH_MODE, device_target='CPU')
...@@ -46,6 +47,27 @@ def test_slice(): ...@@ -46,6 +47,27 @@ def test_slice():
print("output:\n", output) print("output:\n", output)
assert (output.asnumpy() == expect).all() assert (output.asnumpy() == expect).all()
class Slice2(nn.Cell):
def __init__(self):
super(Slice2, self).__init__()
self.slice = P.Slice()
def construct(self, x):
return self.slice(x, (1, 0, 0), (1, 2, 3))
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_slice2():
x = Tensor(np.arange(3 * 2 * 3).reshape(3, 2, 3), mstype.float32)
expect = [[[6., 7., 8.],
[9., 10., 11.]]]
slice_op = Slice2()
output = slice_op(x)
print("output:\n", output)
assert (output.asnumpy() == expect).all()
if __name__ == '__main__': if __name__ == '__main__':
test_slice() test_slice()
test_slice2()
...@@ -43,3 +43,6 @@ def test_slice(): ...@@ -43,3 +43,6 @@ def test_slice():
expect = [[[5., 5., 5.], expect = [[[5., 5., 5.],
[6., 7., 8.]]] [6., 7., 8.]]]
assert (output.asnumpy() == expect).all() assert (output.asnumpy() == expect).all()
if __name__ == '__main__':
test_slice()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册