提交 049d3796 编写于 作者: M mindspore-ci-bot 提交者: Gitee

!1091 gpu support Cast/RealDiv/Mul/Sub/Softmax kernels enforcement

Merge pull request !1091 from chenweifeng/cast
......@@ -69,9 +69,8 @@ class UnsortedSegmentSumGpuKernel : public GpuKernel {
protected:
void InitSizeLists() override {
input_size_list_.push_back(input_dim0_ * input_dim1_ * sizeof(T));
input_size_list_.push_back(output_dim0_ * sizeof(S));
input_size_list_.push_back(output_dim0_ * sizeof(int));
output_size_list_.push_back(output_dim0_ * output_dim1_ * sizeof(S));
input_size_list_.push_back(input_dim0_ * sizeof(S));
output_size_list_.push_back(output_dim0_ * output_dim1_ * sizeof(T));
}
private:
......
......@@ -49,6 +49,21 @@ struct PowerFunc<half, half> {
}
};
template <typename T, typename S>
struct RealDivFunc {
__device__ __forceinline__ S operator()(const T &lhs, const T &rhs) { return (lhs / rhs); }
};
template <typename T, typename S>
struct MulFunc {
__device__ __forceinline__ S operator()(const T &lhs, const T &rhs) { return (lhs * rhs); }
};
template <typename T, typename S>
struct SubFunc {
__device__ __forceinline__ S operator()(const T &lhs, const T &rhs) { return (lhs - rhs); }
};
template <>
struct PowerFunc<half, bool> {
// invalid branch
......@@ -94,6 +109,15 @@ __global__ void BroadcastKernel(const int l0, const int l1, const int l2, const
case BROADCAST_TYPE_POWER:
return BroadcastOperator<T, S, PowerFunc<T, S>>(l0, l1, l2, l3, r0, r1, r2, r3, d0, d1, d2, d3, input0, input1,
output);
case BROADCAST_TYPE_REALDIV:
return BroadcastOperator<T, S, RealDivFunc<T, S>>(l0, l1, l2, l3, r0, r1, r2, r3, d0, d1, d2, d3, input0, input1,
output);
case BROADCAST_TYPE_MUL:
return BroadcastOperator<T, S, MulFunc<T, S>>(l0, l1, l2, l3, r0, r1, r2, r3, d0, d1, d2, d3, input0, input1,
output);
case BROADCAST_TYPE_SUB:
return BroadcastOperator<T, S, SubFunc<T, S>>(l0, l1, l2, l3, r0, r1, r2, r3, d0, d1, d2, d3, input0, input1,
output);
}
}
......@@ -127,6 +151,12 @@ __global__ void NoBroadcastKernel(const int nums, enum BroadcastOpType op, const
return NoBroadcastOperator<T, S, MaximumFunc<T, S>>(nums, input0, input1, output);
case BROADCAST_TYPE_POWER:
return NoBroadcastOperator<T, S, PowerFunc<T, S>>(nums, input0, input1, output);
case BROADCAST_TYPE_REALDIV:
return NoBroadcastOperator<T, S, RealDivFunc<T, S>>(nums, input0, input1, output);
case BROADCAST_TYPE_MUL:
return NoBroadcastOperator<T, S, MulFunc<T, S>>(nums, input0, input1, output);
case BROADCAST_TYPE_SUB:
return NoBroadcastOperator<T, S, SubFunc<T, S>>(nums, input0, input1, output);
}
}
......
......@@ -25,6 +25,9 @@ enum BroadcastOpType {
BROADCAST_TYPE_MAXIMUM = 2,
BROADCAST_TYPE_MINIMUM = 3,
BROADCAST_TYPE_POWER = 4,
BROADCAST_TYPE_REALDIV = 5,
BROADCAST_TYPE_MUL = 6,
BROADCAST_TYPE_SUB = 7,
BROADCAST_TYPE_INVALID = 0xffffffff,
};
......
/**
* Copyright 2019 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "kernel/gpu/math/binary_op_gpu_kernel.h"
namespace mindspore {
namespace kernel {
MS_REG_GPU_KERNEL_ONE(
RealDiv,
KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
BinaryOpGpuKernel, float)
MS_REG_GPU_KERNEL_ONE(
RealDiv,
KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
BinaryOpGpuKernel, half)
MS_REG_GPU_KERNEL_ONE(
Mul, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
BinaryOpGpuKernel, float)
MS_REG_GPU_KERNEL_ONE(
Mul, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
BinaryOpGpuKernel, half)
MS_REG_GPU_KERNEL_ONE(
Sub, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
BinaryOpGpuKernel, float)
MS_REG_GPU_KERNEL_ONE(
Sub, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
BinaryOpGpuKernel, half)
} // namespace kernel
} // namespace mindspore
/**
* Copyright 2019 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_KERNEL_GPU_BINARYOP_GPU_KERNEL_H_
#define MINDSPORE_CCSRC_KERNEL_GPU_BINARYOP_GPU_KERNEL_H_
#include <cuda_runtime_api.h>
#include <vector>
#include <string>
#include <map>
#include "kernel/gpu/gpu_kernel.h"
#include "kernel/gpu/gpu_kernel_factory.h"
#include "kernel/gpu/cuda_impl/unary_op_impl.cuh"
#include "kernel/gpu/kernel_constants.h"
namespace mindspore {
namespace kernel {
enum BinaryOpType { BINARY_OP_ADD = 0, BINARY_OP_SUB, BINARY_OP_MUL, BINARY_OP_DIV, BINARY_OP_INVALID_TYPE = 255 };
static const std::map<std::string, BinaryOpType> kBinaryOpTypeMap = {
{"Sub", BINARY_OP_SUB}, {"Mul", BINARY_OP_MUL}, {"RealDiv", BINARY_OP_DIV}};
template <typename T>
class BinaryOpGpuKernel : public GpuKernel {
public:
BinaryOpGpuKernel()
: cudnn_handle_(nullptr),
binary_op_type_(BINARY_OP_INVALID_TYPE),
tensor_op_(CUDNN_OP_TENSOR_MUL),
inputA_descriptor_(nullptr),
inputB_descriptor_(nullptr),
opTensor_descriptor_(nullptr),
cudnn_data_type_(CUDNN_DATA_FLOAT),
is_null_input_(false),
input_size_(0),
output_size_(0),
workspace_size_(0) {}
~BinaryOpGpuKernel() override { DestroyResource(); }
const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; }
const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; }
const std::vector<size_t> &GetWorkspaceSizeList() const override { return workspace_size_list_; }
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs, uintptr_t stream_ptr) override {
if (is_null_input_) {
return true;
}
T *input_addr = GetDeviceAddress<T>(inputs, 0);
T *input_addr2 = GetDeviceAddress<T>(inputs, 1);
T *output_addr = GetDeviceAddress<T>(outputs, 0);
const float alpha = 1;
const float beta = 0;
T *inputB_addr = nullptr;
switch (binary_op_type_) {
case BINARY_OP_SUB: {
T *workspace_addr = GetDeviceAddress<T>(workspace, 0);
Negative(input_addr2, workspace_addr, inputs[1]->size / sizeof(T), reinterpret_cast<cudaStream_t>(stream_ptr));
inputB_addr = workspace_addr;
break;
}
case BINARY_OP_MUL: {
inputB_addr = input_addr2;
break;
}
case BINARY_OP_DIV: {
T *workspace_addr = GetDeviceAddress<T>(workspace, 0);
Reciprocal(input_addr2, workspace_addr, inputs[1]->size / sizeof(T),
reinterpret_cast<cudaStream_t>(stream_ptr));
inputB_addr = workspace_addr;
break;
}
default: {
MS_LOG(EXCEPTION) << "Binary operation " << binary_op_type_ << " is not supported.";
}
}
if (inputs[0]->size > inputs[1]->size) {
CHECK_CUDNN_RET_WITH_EXCEPT(
cudnnOpTensor(cudnn_handle_, opTensor_descriptor_, &alpha, inputA_descriptor_, input_addr, &alpha,
inputB_descriptor_, inputB_addr, &beta, inputA_descriptor_, output_addr),
"cudnnOpTensor failed");
} else {
CHECK_CUDNN_RET_WITH_EXCEPT(
cudnnOpTensor(cudnn_handle_, opTensor_descriptor_, &alpha, inputB_descriptor_, inputB_addr, &alpha,
inputA_descriptor_, input_addr, &beta, inputB_descriptor_, output_addr),
"cudnnOpTensor failed");
}
return true;
}
bool Init(const CNodePtr &kernel_node) override {
InitResource();
cudnn_data_type_ = kCudnnDtypeMap[TypeIdLabel(AnfAlgo::GetInputDeviceDataType(kernel_node, 0))];
size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
if (input_num != 2) {
MS_LOG(ERROR) << "Input number is " << input_num << ", but binary operation needs 2 inputs.";
return false;
}
size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node);
if (output_num != 1) {
MS_LOG(ERROR) << "Output number is " << output_num << ", but binary operation needs 1 output.";
return false;
}
InferBinaryType(kernel_node);
auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
auto input_shapeB = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 1);
auto output_shape = AnfAlgo::GetOutputInferShape(kernel_node, 0);
if (input_shape != output_shape && input_shapeB != output_shape) {
MS_LOG(ERROR) << "Double-sided broadcast was not supported in cudnn of cudnnOpTensor:\n"
"InputA must match the corresponding dimension of the destination tensor outC, and each "
"dimension of the inputB "
"must match the corresponding dimension of outC or must be equal to 1.";
return false;
}
is_null_input_ = CHECK_NULL_INPUT(input_shape) || CHECK_NULL_INPUT(input_shapeB);
if (is_null_input_) {
MS_LOG(WARNING) << "BinaryOpGpuKernel input is null";
InitSizeLists();
return true;
}
int shape_n = input_shape.size() < 4 ? 1 : SizeToInt(input_shape[input_shape.size() - 4]);
int shape_c = input_shape.size() < 3 ? 1 : SizeToInt(input_shape[input_shape.size() - 3]);
int shape_h = input_shape.size() < 2 ? 1 : SizeToInt(input_shape[input_shape.size() - 2]);
int shape_w = input_shape.size() == 0 ? 1 : SizeToInt(input_shape[input_shape.size() - 1]);
CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSetTensor4dDescriptor(inputA_descriptor_, CUDNN_TENSOR_NCHW, cudnn_data_type_,
shape_n, shape_c, shape_h, shape_w),
"cudnnSetTensor4dDescriptor failed");
int shapeB_n = input_shapeB.size() < 4 ? 1 : SizeToInt(input_shapeB[input_shapeB.size() - 4]);
int shapeB_c = input_shapeB.size() < 3 ? 1 : SizeToInt(input_shapeB[input_shapeB.size() - 3]);
int shapeB_h = input_shapeB.size() < 2 ? 1 : SizeToInt(input_shapeB[input_shapeB.size() - 2]);
int shapeB_w = input_shapeB.size() == 0 ? 1 : SizeToInt(input_shapeB[input_shapeB.size() - 1]);
CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSetTensor4dDescriptor(inputB_descriptor_, CUDNN_TENSOR_NCHW, cudnn_data_type_,
shapeB_n, shapeB_c, shapeB_h, shapeB_w),
"cudnnSetTensor4dDescriptor failed");
InitSizeLists();
return true;
}
protected:
void InitResource() override {
cudnn_handle_ = device::gpu::GPUDeviceManager::GetInstance().GetCudnnHandle();
CHECK_CUDNN_RET_WITH_EXCEPT(cudnnCreateTensorDescriptor(&inputA_descriptor_),
"cudnnCreateTensorDescriptor failed.");
CHECK_CUDNN_RET_WITH_EXCEPT(cudnnCreateTensorDescriptor(&inputB_descriptor_),
"cudnnCreateTensorDescriptor failed.");
CHECK_CUDNN_RET_WITH_EXCEPT(cudnnCreateOpTensorDescriptor(&opTensor_descriptor_),
"cudnnCreateOpTensorDescriptor failed.");
}
void InitSizeLists() override {
if (!is_null_input_) {
CHECK_CUDNN_RET_WITH_EXCEPT(cudnnGetTensorSizeInBytes(inputA_descriptor_, &input_size_),
"cudnnGetTensorSizeInBytes failed.");
input_size_list_.push_back(input_size_);
CHECK_CUDNN_RET_WITH_EXCEPT(cudnnGetTensorSizeInBytes(inputB_descriptor_, &output_size_),
"cudnnGetTensorSizeInBytes failed.");
}
input_size_list_.push_back(output_size_);
if (binary_op_type_ == BINARY_OP_DIV || binary_op_type_ == BINARY_OP_SUB) {
workspace_size_ = output_size_;
}
workspace_size_list_.push_back(workspace_size_);
if (output_size_ > input_size_) {
output_size_list_.push_back(output_size_);
} else {
output_size_list_.push_back(input_size_);
}
return;
}
private:
void DestroyResource() noexcept {
CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyTensorDescriptor(inputA_descriptor_),
"cudnnDestroyTensorDescriptor failed.");
CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyTensorDescriptor(inputB_descriptor_),
"cudnnDestroyTensorDescriptor failed.");
CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyOpTensorDescriptor(opTensor_descriptor_),
"cudnnDestroyOpTensorDescriptor failed.");
}
void InferBinaryType(const CNodePtr &kernel_node) {
std::string kernel_name = AnfAlgo::GetCNodeName(kernel_node);
auto iter = kBinaryOpTypeMap.find(kernel_name);
if (iter == kBinaryOpTypeMap.end()) {
MS_LOG(EXCEPTION) << "Binary operation " << kernel_name << " is not supported.";
} else {
binary_op_type_ = iter->second;
}
switch (binary_op_type_) {
case BINARY_OP_DIV:
case BINARY_OP_MUL: {
tensor_op_ = CUDNN_OP_TENSOR_MUL;
break;
}
case BINARY_OP_SUB: {
tensor_op_ = CUDNN_OP_TENSOR_ADD;
break;
}
default: {
MS_LOG(EXCEPTION) << "Binary operation " << binary_op_type_ << " is not supported.";
}
}
CHECK_CUDNN_RET_WITH_EXCEPT(
cudnnSetOpTensorDescriptor(opTensor_descriptor_, tensor_op_, CUDNN_DATA_FLOAT, CUDNN_NOT_PROPAGATE_NAN),
"cudnnSetOpTensorDescriptor failed");
return;
}
cudnnHandle_t cudnn_handle_;
BinaryOpType binary_op_type_;
cudnnOpTensorOp_t tensor_op_;
cudnnTensorDescriptor_t inputA_descriptor_;
cudnnTensorDescriptor_t inputB_descriptor_;
cudnnOpTensorDescriptor_t opTensor_descriptor_;
cudnnDataType_t cudnn_data_type_;
bool is_null_input_;
size_t input_size_;
size_t output_size_;
size_t workspace_size_;
std::vector<size_t> input_size_list_;
std::vector<size_t> output_size_list_;
std::vector<size_t> workspace_size_list_;
};
} // namespace kernel
} // namespace mindspore
#endif // MINDSPORE_CCSRC_KERNEL_GPU_BINARYOP_GPU_KERNEL_H_
......@@ -37,6 +37,16 @@ MS_REG_GPU_KERNEL_TWO(
MS_REG_GPU_KERNEL_TWO(
Pow, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
BroadcastOpGpuKernel, float, float)
MS_REG_GPU_KERNEL_TWO(
RealDiv,
KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
BroadcastOpGpuKernel, float, float)
MS_REG_GPU_KERNEL_TWO(
Mul, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
BroadcastOpGpuKernel, float, float)
MS_REG_GPU_KERNEL_TWO(
Sub, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
BroadcastOpGpuKernel, float, float)
// fp16
MS_REG_GPU_KERNEL_TWO(
......@@ -57,5 +67,15 @@ MS_REG_GPU_KERNEL_TWO(
MS_REG_GPU_KERNEL_TWO(
Pow, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
BroadcastOpGpuKernel, half, half)
MS_REG_GPU_KERNEL_TWO(
RealDiv,
KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
BroadcastOpGpuKernel, half, half)
MS_REG_GPU_KERNEL_TWO(
Mul, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
BroadcastOpGpuKernel, half, half)
MS_REG_GPU_KERNEL_TWO(
Sub, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
BroadcastOpGpuKernel, half, half)
} // namespace kernel
} // namespace mindspore
......@@ -98,7 +98,8 @@ class BroadcastOpGpuKernel : public GpuKernel {
static std::map<std::string, BroadcastOpType> kBroadcastTypeMap = {
{"Greater", BROADCAST_TYPE_GREATER}, {"Less", BROADCAST_TYPE_LESS}, {"Maximum", BROADCAST_TYPE_MAXIMUM},
{"Minimum", BROADCAST_TYPE_MINIMUM}, {"Pow", BROADCAST_TYPE_POWER},
{"Minimum", BROADCAST_TYPE_MINIMUM}, {"Pow", BROADCAST_TYPE_POWER}, {"RealDiv", BROADCAST_TYPE_REALDIV},
{"Mul", BROADCAST_TYPE_MUL}, {"Sub", BROADCAST_TYPE_SUB},
};
auto iter = kBroadcastTypeMap.find(kernel_name);
......
......@@ -58,11 +58,6 @@ class SoftmaxGpuKernel : public GpuKernel {
}
T *input_addr = GetDeviceAddress<T>(inputs, 0);
T *output_addr = GetDeviceAddress<T>(outputs, 0);
T *transpose_input_addr = GetDeviceAddress<T>(workspace, 0);
T *transpose_output_addr = GetDeviceAddress<T>(workspace, 1);
int *input_shape = GetDeviceAddress<int>(workspace, 2);
int *transpose_shape = GetDeviceAddress<int>(workspace, 3);
int *transpose_axis = GetDeviceAddress<int>(workspace, 4);
const float alpha = 1;
const float beta = 0;
......@@ -71,6 +66,11 @@ class SoftmaxGpuKernel : public GpuKernel {
input_addr, &beta, output_descriptor_, output_addr),
"cudnnSoftmaxForward failed");
} else {
T *transpose_input_addr = GetDeviceAddress<T>(workspace, 0);
T *transpose_output_addr = GetDeviceAddress<T>(workspace, 1);
int *input_shape = GetDeviceAddress<int>(workspace, 2);
int *transpose_shape = GetDeviceAddress<int>(workspace, 3);
int *transpose_axis = GetDeviceAddress<int>(workspace, 4);
CHECK_CUDA_RET_WITH_EXCEPT(cudaMemcpyAsync(input_shape, &input_shape_[0], workspace_size_, cudaMemcpyHostToDevice,
reinterpret_cast<cudaStream_t>(stream_ptr)),
"cudaMemcpyAsync input_shape failed");
......@@ -114,9 +114,6 @@ class SoftmaxGpuKernel : public GpuKernel {
return true;
}
shape_size_ = SizeToInt(input_shape.size());
if (shape_size_ != 2) {
MS_LOG(EXCEPTION) << "Input is " << shape_size_ << "-D, but softmax only supports 2-D inputs.";
}
auto kernel_name = AnfAlgo::GetCNodeName(kernel_node);
if (kernel_name == "LogSoftmax") {
algo_ = CUDNN_SOFTMAX_LOG;
......@@ -163,7 +160,15 @@ class SoftmaxGpuKernel : public GpuKernel {
CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyTensorDescriptor(input_descriptor_), "destroy input_descriptor failed");
}
void InitSizeByAxis(const std::vector<size_t> input_shape, const int axis) {
void InitSizeByAxis(const std::vector<size_t> &input_shape, const int &axis) {
if (input_shape.size() == 2) {
InitSizeByAxis2D(input_shape, axis);
} else {
InitSizeByAxisLastDim(input_shape, axis);
}
}
void InitSizeByAxis2D(const std::vector<size_t> &input_shape, const int &axis) {
axis_ = axis;
if (axis_ < 0) {
axis_ += shape_size_;
......@@ -191,6 +196,31 @@ class SoftmaxGpuKernel : public GpuKernel {
workspace_size_ = IntToSize(shape_size_) * sizeof(int);
}
void InitSizeByAxisLastDim(const std::vector<size_t> &input_shape, const int &axis) {
int axis_pos = axis;
if (axis_pos < 0) {
axis_pos += input_shape.size();
}
// axis should be -1 with ND
if (axis_pos != SizeToInt(input_shape.size() - 1)) {
MS_LOG(EXCEPTION) << "Input is " << shape_size_ << "-D, but axis(" << axis << ") is invalid.";
}
// squeeze to 2d, then invoke cudnn
size_t n = 1;
for (size_t i = 0; i < input_shape.size() - 1; i++) {
n *= input_shape[i];
}
axis_ = 1;
batch_size_ = n;
channel_size_ = input_shape[axis_pos];
height_ = 1;
width_ = 1;
input_size_ = sizeof(T) * batch_size_ * channel_size_ * height_ * width_;
output_size_ = input_size_;
input_shape_.push_back(batch_size_);
input_shape_.push_back(channel_size_);
}
cudnnHandle_t cudnn_handle_;
cudnnTensorDescriptor_t input_descriptor_;
cudnnTensorDescriptor_t output_descriptor_;
......
......@@ -22,6 +22,8 @@ cast_op_info = AkgRegOp("Cast") \
.attr("dst_type", "required", "str") \
.dtype_format(DataType.F16_Default, DataType.F32_Default) \
.dtype_format(DataType.F32_Default, DataType.F16_Default) \
.dtype_format(DataType.I32_Default, DataType.F32_Default) \
.dtype_format(DataType.BOOL_Default, DataType.F32_Default) \
.get_op_info()
......
......@@ -50,6 +50,19 @@ def test_nobroadcast():
output_np = np.power(x1_np, x2_np)
assert np.allclose(output_ms.asnumpy(), output_np)
output_ms = P.RealDiv()(Tensor(x1_np), Tensor(x2_np))
output_np = x1_np / x2_np
assert np.allclose(output_ms.asnumpy(), output_np)
output_ms = P.Mul()(Tensor(x1_np), Tensor(x2_np))
output_np = x1_np * x2_np
assert np.allclose(output_ms.asnumpy(), output_np)
output_ms = P.Sub()(Tensor(x1_np), Tensor(x2_np))
output_np = x1_np - x2_np
assert np.allclose(output_ms.asnumpy(), output_np)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
......@@ -80,6 +93,17 @@ def test_broadcast():
output_np = np.power(x1_np, x2_np)
assert np.allclose(output_ms.asnumpy(), output_np)
output_ms = P.RealDiv()(Tensor(x1_np), Tensor(x2_np))
output_np = x1_np / x2_np
assert np.allclose(output_ms.asnumpy(), output_np)
output_ms = P.Mul()(Tensor(x1_np), Tensor(x2_np))
output_np = x1_np * x2_np
assert np.allclose(output_ms.asnumpy(), output_np)
output_ms = P.Sub()(Tensor(x1_np), Tensor(x2_np))
output_np = x1_np - x2_np
assert np.allclose(output_ms.asnumpy(), output_np)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
......@@ -109,3 +133,15 @@ def test_broadcast_diff_dims():
output_ms = P.Pow()(Tensor(x1_np), Tensor(x2_np))
output_np = np.power(x1_np, x2_np)
assert np.allclose(output_ms.asnumpy(), output_np)
output_ms = P.RealDiv()(Tensor(x1_np), Tensor(x2_np))
output_np = x1_np / x2_np
assert np.allclose(output_ms.asnumpy(), output_np)
output_ms = P.Mul()(Tensor(x1_np), Tensor(x2_np))
output_np = x1_np * x2_np
assert np.allclose(output_ms.asnumpy(), output_np)
output_ms = P.Sub()(Tensor(x1_np), Tensor(x2_np))
output_np = x1_np - x2_np
assert np.allclose(output_ms.asnumpy(), output_np)
\ No newline at end of file
......@@ -49,3 +49,21 @@ def test_cast():
assert (type0 == 'float16')
type1 = output[1].asnumpy().dtype
assert (type1 == 'float32')
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_cast1():
x0 = Tensor(np.arange(24).reshape((4, 3, 2)).astype(np.int32))
t0 = mstype.float32
x1 = Tensor(np.arange(24).reshape((4, 3, 2)).astype(np.bool))
t1 = mstype.float32
context.set_context(mode=context.GRAPH_MODE, device_target='GPU')
net = Net()
output = net(x0, t0, x1, t1)
type0 = output[0].asnumpy().dtype
assert (type0 == 'float32')
type1 = output[1].asnumpy().dtype
assert (type1 == 'float32')
# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import pytest
import numpy as np
from mindspore import Tensor
from mindspore.ops import operations as P
import mindspore.nn as nn
import mindspore.context as context
class NetSoftmax(nn.Cell):
def __init__(self):
super(NetSoftmax, self).__init__()
axis = -2
self.softmax1 = P.Softmax()
self.softmax2 = P.Softmax(axis)
def construct(self, x):
return self.softmax1(x), self.softmax2(x)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_softmax():
x = Tensor(np.array([[0.1, 0.3, 0.6, -0.3],
[0.2, -0.6, 0.8, 0.6],
[0.6, -1.2, 0.4, 0.6]]).astype(np.float32))
expect1 = np.ones(3)
expect2 = np.ones(4)
error1 = expect1 * 1.0e-6
error2 = expect2 * 1.0e-6
context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")
Softmax = NetSoftmax()
output = Softmax(x)
outputSum1 = output[0].asnumpy().sum(axis=1)
outputSum2 = output[1].asnumpy().sum(axis=0)
diff1 = np.abs(outputSum1 - expect1)
diff2 = np.abs(outputSum2 - expect2)
assert np.all(diff1 < error1)
assert np.all(diff2 < error2)
context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
Softmax = NetSoftmax()
output = Softmax(x)
outputSum1 = output[0].asnumpy().sum(axis=1)
outputSum2 = output[1].asnumpy().sum(axis=0)
diff1 = np.abs(outputSum1 - expect1)
diff2 = np.abs(outputSum2 - expect2)
assert np.all(diff1 < error1)
assert np.all(diff2 < error2)
# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import pytest
import numpy as np
from mindspore import Tensor
from mindspore.ops import operations as P
from mindspore.ops import composite as C
import mindspore.nn as nn
import mindspore.context as context
class NetSoftmax(nn.Cell):
def __init__(self):
super(NetSoftmax, self).__init__()
axis = -2
self.softmax1 = P.Softmax()
self.softmax2 = P.Softmax(axis)
def construct(self, x):
return self.softmax1(x), self.softmax2(x)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_softmax():
x = Tensor(np.array([[0.1, 0.3, 0.6, -0.3],
[0.2, -0.6, 0.8, 0.6],
[0.6, -1.2, 0.4, 0.6]]).astype(np.float32))
expect1 = np.ones(3)
expect2 = np.ones(4)
error1 = expect1 * 1.0e-6
error2 = expect2 * 1.0e-6
context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")
Softmax = NetSoftmax()
output = Softmax(x)
outputSum1 = output[0].asnumpy().sum(axis=1)
outputSum2 = output[1].asnumpy().sum(axis=0)
diff1 = np.abs(outputSum1 - expect1)
diff2 = np.abs(outputSum2 - expect2)
assert np.all(diff1 < error1)
assert np.all(diff2 < error2)
context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
Softmax = NetSoftmax()
output = Softmax(x)
outputSum1 = output[0].asnumpy().sum(axis=1)
outputSum2 = output[1].asnumpy().sum(axis=0)
diff1 = np.abs(outputSum1 - expect1)
diff2 = np.abs(outputSum2 - expect2)
assert np.all(diff1 < error1)
assert np.all(diff2 < error2)
class Net(nn.Cell):
def __init__(self):
super(Net, self).__init__()
self.softmax1 = P.Softmax()
def construct(self, x):
return self.softmax1(x)
class Grad(nn.Cell):
def __init__(self, network):
super(Grad, self).__init__()
self.grad = C.GradOperation(name="get_all", get_all=True, sens_param=True)
self.network = network
def construct(self, input_data, sens):
gout = self.grad(self.network)(input_data, sens)
return gout
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_softmax_4d():
context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
x = np.array([[[[ 2.7866030e-01, 8.5578346e-01, -2.7546784e-01, -8.5833269e-01, 1.5753637e-01],
[-4.5145524e-01, 1.5590921e-01, -6.1947298e-01, -6.3499230e-01, -1.0625143e+00],
[-6.8716180e-01, -3.5565588e-01, 9.9680430e-01, -3.5519487e-01, 5.2122700e-01],
[-9.8125875e-01, 9.0505141e-01, 6.5961617e-01, 6.5950197e-01, 1.0319239e+00]],
[[-7.6588345e-01, -1.6929083e-01, 9.4459933e-01, -8.3931917e-01, 1.4916732e+00],
[ 8.1874236e-02, -1.9288104e-02, 7.3255712e-01, -1.4598954e-01, 1.1225560e+00],
[ 2.7356184e-01, 1.2557162e-01, 1.3796539e+00, 1.0073920e-01, 7.9203087e-01],
[-3.6947381e-01, 4.7919992e-01, 2.2421131e+00, -8.3911163e-01, 1.0814662e+00]],
[[-2.5838584e-01, 2.0765430e-01, -1.9366746e-01, 6.7511219e-01, -3.7492469e-01],
[ 4.4170797e-01, -9.9537361e-01, -3.5100895e-01, -7.8317386e-01, 1.1672008e-02],
[ 1.6037937e+00, -1.7059358e+00, -9.3724984e-01, -1.5016698e+00, -2.7605603e-02],
[ 1.6392696e-01, 1.0074581e+00, -2.7704465e+00, 8.1361882e-02, 7.9730105e-01]]],
[[[ 2.9516423e-01, 4.6354745e-02, 1.7318316e-01, 1.5894413e+00, -1.2769363e+00],
[ 2.8939021e-01, -3.8801813e-01, -1.3376296e+00, -4.9808905e-01, -3.2318991e-02],
[-1.1740140e+00, -1.1140432e+00, -1.4198960e-01, 5.8953021e-02, -3.6763316e-01],
[ 1.8660797e+00, -5.8705074e-01, 6.8757606e-01, -4.0573463e-01, -7.1130061e-01]],
[[ 2.6170531e-01, 5.4814044e-02, 1.3891056e-01, 3.4492522e-02, -1.0920379e-01],
[ 1.1420644e-01, 1.6939731e-01, -1.0413316e+00, -1.4040415e-01, -3.3280477e-01],
[-3.0776244e-01, 1.0526397e+00, 2.9497927e-01, 1.1266683e+00, 8.4419928e-02],
[-2.1593940e+00, -1.0187222e+00, 1.7475771e+00, -3.5802367e-01, -1.2900480e+00]],
[[ 3.2892069e-01, -1.6604670e+00, -5.7856506e-01, 5.8143520e-01, 5.9596705e-01],
[-1.5992336e-01, -5.9647644e-01, 1.2957820e+00, -1.0650631e-01, 7.0879894e-01],
[ 4.1372257e-01, 3.6408889e-01, -6.3091749e-01, 1.0573713e+00, 1.0981073e+00],
[-1.9162457e-01, 3.6392561e-05, -1.8338780e-01, 1.7549801e+00, -9.3534666e-01]]]]).astype(np.float32)
dy = np.array([[[[ 2.98213929e-01, 3.10518718e+00, -1.64306939e-01, -7.33681679e-01, 5.23136854e-02],
[-3.47142726e-01, -1.52662742e+00, 5.26977003e-01, 5.29672280e-02, -4.34386432e-01],
[ 1.34674394e+00, 1.69386661e+00, 3.17139983e-01, 5.77129781e-01, 1.25290680e+00],
[-1.71099675e+00, -1.62872851e+00, -7.89083183e-01, 8.64615321e-01, -1.74364686e+00]],
[[ 1.11915946e+00, -7.06878662e-01, -6.71557069e-01, -4.50884640e-01, 2.95763493e-01],
[-7.64747679e-01, 1.62951392e-03, -2.84069944e-02, 7.55402744e-01, -1.02387452e+00],
[-5.92088878e-01, 4.47980821e-01, 4.50127304e-01, -3.99038166e-01, -5.24561822e-01],
[ 1.92535609e-01, 2.44671494e-01, -8.70469391e-01, -8.30129832e-02, -4.04477213e-03]],
[[-1.94159836e-01, -8.50215256e-01, -1.01224804e+00, 2.64235616e-01, 5.34391068e-02],
[-6.71353936e-01, 3.73690695e-01, 4.48037744e-01, -2.84973383e-01, -2.80129910e+00],
[ 6.69475198e-01, 2.08404279e+00, 4.49459851e-01, 2.50908136e+00, 9.80683088e-01],
[ 1.18290365e+00, -1.28790128e+00, -1.70202863e+00, -1.37078688e-01, 9.53227460e-01]]],
[[[-6.44128084e-01, 1.37707603e+00, -8.60912442e-01, -3.83467346e-01, 6.68365955e-01],
[-3.32795471e-01, 3.05202007e-01, 2.20850635e+00, 6.93960607e-01, -1.94968760e-01],
[-3.35764170e-01, 1.10562348e+00, -1.13264215e+00, -1.08296621e+00, -6.53923571e-01],
[-4.64974046e-01, 8.83257568e-01, -1.70353889e+00, -4.48120385e-01, -1.76938546e+00]],
[[-3.80976290e-01, -1.49393475e+00, -8.51393223e-01, -1.49780405e+00, -1.24160886e-01],
[-7.18508661e-02, 2.44543999e-01, 3.29225749e-01, 7.09274471e-01, -9.26648498e-01],
[ 6.67312503e-01, -1.08737612e+00, -9.63039994e-01, -3.22715081e-02, -4.03802067e-01],
[-5.97982287e-01, -1.40739769e-01, 2.80631828e+00, 5.72278857e-01, 2.05998325e+00]],
[[ 3.46207246e-02, 7.34213948e-01, 1.45563519e+00, 1.02045703e+00, 1.40984225e+00],
[ 4.14457440e-01, -8.74118507e-01, -4.21902031e-01, 7.87168801e-01, -1.48280108e+00],
[ 1.42688036e+00, -2.02695489e+00, 9.26816165e-01, 9.37691629e-01, 7.85577714e-01],
[-6.59893751e-01, 1.14681525e-02, -5.79456389e-01, -1.65206456e+00, 4.37116653e-01]]]]).astype(np.float32)
expect_x = np.array([[[[0.21919312, 0.3903627, 0.12594244, 0.07031325, 0.19418849],
[0.19778392, 0.36304963, 0.16719443, 0.1646197, 0.10735231],
[0.07986113, 0.11125171, 0.43020225, 0.11130301, 0.26738194],
[0.03936873, 0.25963634, 0.20313013, 0.20310691, 0.29475793]],
[[0.05308856, 0.09640461, 0.29366633, 0.04932966, 0.50751084],
[0.13426398, 0.12134594, 0.2573638, 0.10690536, 0.38012096],
[0.13503104, 0.11645612, 0.40813455, 0.11359984, 0.22677852],
[0.04576753, 0.10693795, 0.6233836, 0.02861518, 0.19529575]],
[[0.14096586, 0.2246532, 0.15039064, 0.35853124, 0.12545899],
[0.37957698, 0.09019516, 0.17180163, 0.11151683, 0.2469094 ],
[0.7375885, 0.0269412, 0.05811028, 0.03304673, 0.14431332],
[0.16174863, 0.37599453, 0.00859921, 0.1489303, 0.3047274 ]]],
[[[0.15335402, 0.11957449, 0.13574363, 0.55949026, 0.03183762],
[0.34669915, 0.17609946, 0.06813136, 0.15774474, 0.2513253 ],
[0.09487908, 0.10074313, 0.26630113, 0.32556766, 0.21250896],
[0.6357843, 0.05469263, 0.19565557, 0.0655652, 0.0483023 ]],
[[0.23898226, 0.19431841, 0.21136671, 0.19040942, 0.16492325],
[0.2641041, 0.27909, 0.08316323, 0.20473833, 0.16890427],
[0.08062991, 0.3142761, 0.14732064, 0.33842432, 0.11934903],
[0.01604616, 0.05020634, 0.79826504, 0.09720672, 0.03827571]],
[[0.24191543, 0.03308899, 0.09762195, 0.31140763, 0.31596598],
[0.10669514, 0.06895282, 0.45745608, 0.11254943, 0.25434658],
[0.16156755, 0.15374413, 0.05684244, 0.3075298, 0.32031605],
[0.09346025, 0.11320464, 0.09423324, 0.65467626, 0.04442552]]]]).astype(np.float32)
expect_dx = np.array([[[[-0.20103945, 0.737705 , -0.17376284, -0.1370458 , -0.22585672],
[ 0.04461281, -0.34632078, 0.18386088, 0.10299816, 0.01484894],
[ 0.04113413, 0.09592049, -0.22135337, -0.02833145, 0.11263024],
[-0.0284293 , -0.1661311 , 0.04058228, 0.37645525, -0.22247711]],
[[ 0.06355994, -0.06061868, -0.17428297, -0.01839012, 0.1897318 ],
[-0.04652473, 0.05094835, 0.10032654, 0.12546772, -0.23021786],
[-0.07882182, 0.05314343, 0.18712361, -0.04438123, -0.11706398],
[ 0.03219109, 0.08079126, -0.22419631, 0.01224192, 0.09897206]],
[[ 0.01057316, -0.1305348 , -0.11175273, 0.19124077, 0.04047358],
[ 0.07448982, 0.11195826, 0.2260284 , 0.06497248, -0.47744888],
[-0.09664576, 0.03458005, -0.02039931, 0.05646288, 0.02600216],
[ 0.1973966 , -0.47014874, -0.01431374, -0.01483214, 0.30189803]]],
[[[-0.06132338, 0.19386888, -0.08370841, -0.07789247, 0.02905542],
[-0.16714299, 0.0274538 , 0.14029635, 0.08591694, -0.08652411],
[ 0.03585254, 0.18327834, -0.11158065, -0.12024056, 0.01269035],
[ 0.14654502, 0.0863447 , -0.19723451, 0.01621746, -0.05187264]],
[[ 0.11614501, -0.12182987, 0.00329342, -0.12011584, 0.12250728],
[-0.03623635, 0.05001016, 0.02194443, 0.13183522, -0.16755345],
[ 0.09322704, -0.18807998, -0.06984743, 0.15454148, 0.01015892],
[-0.04743218, -0.12545264, 0.35787603, -0.1735842 , -0.01140684]],
[[-0.21854429, -0.00674347, 0.05053139, 0.02567403, 0.14908233],
[ 0.09731252, -0.02596174, 0.03463032, 0.14460044, -0.2505815 ],
[ 0.1478814 , -0.3902862 , 0.02360253, 0.13103928, 0.087763 ],
[ 0.04834083, 0.13455458, 0.05632052, -0.3109298 , 0.07171366]]]]).astype(np.float32)
y = Net()(Tensor(x))
assert np.allclose(y.asnumpy(), expect_x)
dx = Grad(Net())(Tensor(x), Tensor(dy))
assert np.allclose(dx[0].asnumpy(), expect_dx)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册