提交 f95fe92a 编写于 作者: V VectorSL

slice support nhwc

上级 4daabe1d
......@@ -18,6 +18,7 @@
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_SLICE_GPU_KERNEL_H
#include <vector>
#include <utility>
#include "backend/kernel_compiler/gpu/gpu_kernel.h"
#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h"
#include "backend/kernel_compiler/gpu/cuda_impl/slice_impl.cuh"
......@@ -27,8 +28,7 @@ namespace kernel {
template <typename T>
class SliceGpuFwdKernel : public GpuKernel {
public:
SliceGpuFwdKernel()
: is_strided_slice_(false), is_null_input_(false), input_size_(0), output_size_(0), workspace_size_(0) {}
SliceGpuFwdKernel() : is_null_input_(false), input_size_(0), output_size_(0), workspace_size_(0) {}
~SliceGpuFwdKernel() override = default;
const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; }
const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; }
......@@ -50,51 +50,31 @@ class SliceGpuFwdKernel : public GpuKernel {
if (!CheckParam(kernel_node)) {
return false;
}
auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
auto data_format = AnfAlgo::GetInputFormat(kernel_node, 0);
auto input_shape = AnfAlgo::GetInputDeviceShape(kernel_node, 0);
ShapeNdTo4d(input_shape, &input_shape_);
auto strides = AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("strides");
if (strides) {
strides_ = GetAttr<std::vector<int>>(kernel_node, "strides");
for (auto i = strides_.size(); i < 4; i++) {
(void)strides_.insert(strides_.begin(), 1);
}
size_ = GetAttr<std::vector<int>>(kernel_node, "end");
is_strided_slice_ = true;
} else {
size_ = GetAttr<std::vector<int>>(kernel_node, "size");
}
for (auto i = begin_.size(); i < 4; i++) {
(void)begin_.insert(begin_.begin(), 0);
}
for (size_t i = size_.size(); i < 4; i++) {
(void)size_.insert(size_.begin(), 1);
}
for (size_t i = 0; i < begin_.size(); i++) {
if (begin_[i] < 0) {
begin_[i] = begin_[i] + input_shape_[i];
}
}
for (size_t i = 0; i < size_.size(); i++) {
if (size_[i] < 0) {
size_[i] = (size_[i] + input_shape_[i]) > 0 ? (size_[i] + input_shape_[i]) : 0;
}
if (begin_[i] == size_[i] && is_strided_slice_) {
MS_LOG(WARNING) << "Output is null.";
is_null_input_ = true;
}
if (size_[i] == 0 && strides_[i] > 0) {
size_[i] = begin_[i] + 1;
}
}
input_size_ = IntToSize(input_shape_[0] * input_shape_[1] * input_shape_[2] * input_shape_[3]) * sizeof(T);
auto out_shape = AnfAlgo::GetOutputInferShape(kernel_node, 0);
auto out_shape = AnfAlgo::GetOutputDeviceShape(kernel_node, 0);
output_size_ = sizeof(T);
for (size_t x : out_shape) {
output_size_ = output_size_ * x;
}
// transpose begin and size for NHWC data
if (data_format == "NHWC") {
std::swap(begin_[1], begin_[3]);
std::swap(begin_[1], begin_[2]);
std::swap(size_[1], size_[3]);
std::swap(size_[1], size_[2]);
}
InitSizeLists();
return true;
}
......@@ -126,26 +106,24 @@ class SliceGpuFwdKernel : public GpuKernel {
MS_LOG(ERROR) << "Input dims is " << input_shape.size() << ", scalar is not supported.";
return false;
}
size_ = GetAttr<std::vector<int>>(kernel_node, "size");
begin_ = GetAttr<std::vector<int>>(kernel_node, "begin");
for (size_t i = 0; i < input_shape.size(); i++) {
if ((begin_[i] > 0 && (begin_[i] > SizeToInt(input_shape[i]))) ||
(begin_[i] < 0 && (std::abs(begin_[i]) > SizeToInt(input_shape[i])))) {
MS_LOG(INFO) << "Input out of bounds " << input_shape[i] << " in axis " << i << ".";
begin_[i] = 0;
if (input_shape[i] <= 0 || size_[i] <= 0) {
MS_LOG(WARNING) << "Slice output is null.";
is_null_input_ = true;
}
}
return true;
}
std::vector<int> begin_;
std::vector<int> size_;
std::vector<int> strides_;
std::vector<int> input_shape_;
std::vector<size_t> input_size_list_;
std::vector<size_t> output_size_list_;
std::vector<size_t> workspace_size_list_;
bool is_strided_slice_;
bool is_null_input_;
size_t input_size_;
size_t output_size_;
......
......@@ -38,6 +38,7 @@ static std::map<std::string, std::pair<std::vector<size_t>, std::vector<size_t>>
{prim::kPrimReluGrad->name(), {{0, 1}, {0}}},
{prim::kPrimMaxPool->name(), {{0}, {0}}},
{prim::kPrimMaxPoolGrad->name(), {{0, 1, 2}, {0}}},
{kSliceOpName, {{0}, {0}}},
{kAvgPoolOpName, {{0}, {0}}},
{kAvgPoolGradGpuOpName, {{0, 1, 2}, {0}}},
{kTensorAddOpName, {{0, 1}, {0}}},
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册