提交 6414e23f 编写于 作者: V VectorSL

gpu codex fix

上级 553432c9
......@@ -24,27 +24,27 @@ MS_REG_GPU_KERNEL_ONE(MaxPoolGrad,
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddOutputAttr(kNumberTypeFloat32),
PoolingGradGpuFwdKernel, float)
PoolingGradGpuKernel, float)
MS_REG_GPU_KERNEL_ONE(MaxPoolGrad,
KernelAttr()
.AddInputAttr(kNumberTypeFloat16)
.AddInputAttr(kNumberTypeFloat16)
.AddInputAttr(kNumberTypeFloat16)
.AddOutputAttr(kNumberTypeFloat16),
PoolingGradGpuFwdKernel, half)
PoolingGradGpuKernel, half)
MS_REG_GPU_KERNEL_ONE(AvgPoolGradGpu,
KernelAttr()
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddOutputAttr(kNumberTypeFloat32),
PoolingGradGpuFwdKernel, float)
PoolingGradGpuKernel, float)
MS_REG_GPU_KERNEL_ONE(AvgPoolGradGpu,
KernelAttr()
.AddInputAttr(kNumberTypeFloat16)
.AddInputAttr(kNumberTypeFloat16)
.AddInputAttr(kNumberTypeFloat16)
.AddOutputAttr(kNumberTypeFloat16),
PoolingGradGpuFwdKernel, half)
PoolingGradGpuKernel, half)
} // namespace kernel
} // namespace mindspore
......@@ -28,9 +28,9 @@
namespace mindspore {
namespace kernel {
template <typename T>
class PoolingGradGpuFwdKernel : public GpuKernel {
class PoolingGradGpuKernel : public GpuKernel {
public:
PoolingGradGpuFwdKernel()
PoolingGradGpuKernel()
: cudnn_handle_(nullptr),
pooling_descriptor_(nullptr),
y_descriptor_(nullptr),
......@@ -55,7 +55,7 @@ class PoolingGradGpuFwdKernel : public GpuKernel {
padded_size_(0),
workspace_size_(0),
use_pad_(true) {}
~PoolingGradGpuFwdKernel() override { DestroyResource(); }
~PoolingGradGpuKernel() override { DestroyResource(); }
const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; }
const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; }
......@@ -108,7 +108,7 @@ class PoolingGradGpuFwdKernel : public GpuKernel {
auto input_mask = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 1);
is_null_input_ = CHECK_NULL_INPUT(input_shape) || CHECK_NULL_INPUT(input_mask);
if (is_null_input_) {
MS_LOG(WARNING) << "PoolingGradGpuFwdKernel input is null.";
MS_LOG(WARNING) << "PoolingGradGpuKernel input is null.";
InitSizeLists();
return true;
}
......@@ -196,7 +196,7 @@ class PoolingGradGpuFwdKernel : public GpuKernel {
bool CheckParam(const CNodePtr &kernel_node) {
size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
if (input_num != 3) {
MS_LOG(ERROR) << "Input number is " << input_num << ", but PoolingGradGpuFwdKernel needs 3 inputs.";
MS_LOG(ERROR) << "Input number is " << input_num << ", but PoolingGradGpuKernel needs 3 inputs.";
return false;
}
return true;
......
......@@ -21,10 +21,10 @@ namespace kernel {
MS_REG_GPU_KERNEL_ONE(
ReluGrad,
KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
ReluGradGpuFwdKernel, float)
ReluGradGpuKernel, float)
MS_REG_GPU_KERNEL_ONE(
ReluGrad,
KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
ReluGradGpuFwdKernel, half)
ReluGradGpuKernel, half)
} // namespace kernel
} // namespace mindspore
......@@ -25,9 +25,9 @@
namespace mindspore {
namespace kernel {
template <typename T>
class ReluGradGpuFwdKernel : public GpuKernel {
class ReluGradGpuKernel : public GpuKernel {
public:
ReluGradGpuFwdKernel()
ReluGradGpuKernel()
: cudnn_handle_(nullptr),
activation_desc_(nullptr),
mode_(CUDNN_ACTIVATION_RELU),
......@@ -35,7 +35,7 @@ class ReluGradGpuFwdKernel : public GpuKernel {
is_null_input_(false),
cudnn_data_type_(CUDNN_DATA_FLOAT),
input_size_(0) {}
~ReluGradGpuFwdKernel() override { DestroyResource(); }
~ReluGradGpuKernel() override { DestroyResource(); }
const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; }
const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; }
const std::vector<size_t> &GetWorkspaceSizeList() const override { return workspace_size_list_; }
......@@ -63,14 +63,14 @@ class ReluGradGpuFwdKernel : public GpuKernel {
cudnn_data_type_ = GetCudnnDataType(TypeIdLabel(AnfAlgo::GetInputDeviceDataType(kernel_node, 0)));
size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
if (input_num != 2) {
MS_LOG(ERROR) << "Argument number is " << input_num << ", but ReluGradGpuFwdKernel needs 2.";
MS_LOG(ERROR) << "Argument number is " << input_num << ", but ReluGradGpuKernel needs 2.";
return false;
}
auto input_shape = AnfAlgo::GetOutputInferShape(kernel_node, 0);
mode_ = CUDNN_ACTIVATION_RELU;
is_null_input_ = CHECK_NULL_INPUT(input_shape);
if (is_null_input_) {
MS_LOG(WARNING) << "ReluGradGpuFwdKernel input is null.";
MS_LOG(WARNING) << "ReluGradGpuKernel input is null.";
InitSizeLists();
return true;
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册