提交 d22a5976 编写于 作者: V VectorSL

GPU: fix AddN INT32 accumulation bug and fix the supported-type-list string construction

上级 180b3029
...@@ -88,10 +88,11 @@ std::string SupportedTypeList(const CNodePtr &kernel_node) { ...@@ -88,10 +88,11 @@ std::string SupportedTypeList(const CNodePtr &kernel_node) {
supported_akg_type_list = supported_akg_type_list + mindspore::kernel::TypeId2String(type); supported_akg_type_list = supported_akg_type_list + mindspore::kernel::TypeId2String(type);
} }
supported_type_lists = supported_type_lists + supported_akg_type_list + "], out["; supported_type_lists = supported_type_lists + supported_akg_type_list + "], out[";
supported_akg_type_list.clear();
for (auto type : supported_akg_type_out) { for (auto type : supported_akg_type_out) {
supported_akg_type_list = supported_akg_type_list + mindspore::kernel::TypeId2String(type); supported_akg_type_list = supported_akg_type_list + mindspore::kernel::TypeId2String(type);
} }
supported_type_lists += "]; "; supported_type_lists = supported_type_lists + supported_akg_type_list + "]; ";
} }
return supported_type_lists; return supported_type_lists;
} }
......
...@@ -21,6 +21,8 @@ ...@@ -21,6 +21,8 @@
#include <vector> #include <vector>
#include "kernel/gpu/gpu_kernel.h" #include "kernel/gpu/gpu_kernel.h"
#include "kernel/gpu/gpu_kernel_factory.h" #include "kernel/gpu/gpu_kernel_factory.h"
#include "kernel/gpu/math/broadcast_gpu_kernel.h"
#include "kernel/gpu/cuda_impl/slice_impl.cuh"
#include "kernel/gpu/kernel_constants.h" #include "kernel/gpu/kernel_constants.h"
namespace mindspore { namespace mindspore {
...@@ -43,18 +45,26 @@ class AddNGpuFwdKernel : public GpuKernel { ...@@ -43,18 +45,26 @@ class AddNGpuFwdKernel : public GpuKernel {
const std::vector<size_t> &GetWorkspaceSizeList() const override { return workspace_size_list_; } const std::vector<size_t> &GetWorkspaceSizeList() const override { return workspace_size_list_; }
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &, bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
const std::vector<AddressPtr> &outputs, void *) override { const std::vector<AddressPtr> &outputs, void *stream_ptr) override {
if (is_null_input_) { if (is_null_input_) {
return true; return true;
} }
T *output_addr = GetDeviceAddress<T>(outputs, 0); T *output_addr = GetDeviceAddress<T>(outputs, 0);
if (cudnn_data_type_ == CUDNN_DATA_INT32) {
FillDeviceArray(outputs[0]->size / sizeof(T), output_addr, 0.0f, reinterpret_cast<cudaStream_t>(stream_ptr));
}
const float alpha = 1; const float alpha = 1;
const float beta = 0; const float beta = 0;
for (size_t i = 0; i < IntToSize(num_input_); i++) { for (size_t i = 0; i < IntToSize(num_input_); i++) {
T *input_addr = GetDeviceAddress<T>(inputs, i); T *input_addr = GetDeviceAddress<T>(inputs, i);
CHECK_CUDNN_RET_WITH_EXCEPT(cudnnAddTensor(cudnn_handle_, &alpha, input_descriptor_, input_addr, if (cudnn_data_type_ == CUDNN_DATA_INT32) {
&(i > 0 ? alpha : beta), input_descriptor_, output_addr), NoBroadcast(outputs[0]->size / sizeof(T), BROADCAST_TYPE_ADD, input_addr, output_addr, output_addr,
"cudnnAddTensor failed"); reinterpret_cast<cudaStream_t>(stream_ptr));
} else {
CHECK_CUDNN_RET_WITH_EXCEPT(cudnnAddTensor(cudnn_handle_, &alpha, input_descriptor_, input_addr,
&(i > 0 ? alpha : beta), input_descriptor_, output_addr),
"cudnnAddTensor failed");
}
} }
return true; return true;
} }
...@@ -100,9 +110,8 @@ class AddNGpuFwdKernel : public GpuKernel { ...@@ -100,9 +110,8 @@ class AddNGpuFwdKernel : public GpuKernel {
} }
void InitSizeLists() override { void InitSizeLists() override {
if (!is_null_input_) { if (!is_null_input_) {
CHECK_CUDNN_RET_WITH_EXCEPT( CHECK_CUDNN_RET_WITH_EXCEPT(cudnnGetTensorSizeInBytes(input_descriptor_, &input_size_),
cudnnGetTensorSizeInBytes(input_descriptor_, reinterpret_cast<size_t *>(&input_size_)), "cudnnGetTensorSizeInBytes failed");
"cudnnGetTensorSizeInBytes failed");
} }
for (int i = 0; i < num_input_; i++) { for (int i = 0; i < num_input_; i++) {
input_size_list_.push_back(input_size_); input_size_list_.push_back(input_size_);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册