Commit a3d0ddb4 authored by mindspore-ci-bot, committed by Gitee

!5779 tensoradd profiling

Merge pull request !5779 from chenweifeng/broadcast-refactor
@@ -36,17 +36,21 @@ enum BroadcastOpType {
BROADCAST_TYPE_INVALID = 0xffffffff,
};
template <typename T, typename S>
void Broadcast(const std::vector<int> &lhs_shape, const std::vector<int> &rhs_shape,
const std::vector<int> &output_shape, enum BroadcastOpType op, const T *input0, const T *input1,
S *output, cudaStream_t stream);
template <typename T>
void ElewiseCmp(const int &nums, enum BroadcastOpType op, const T *x0, const T *x1, bool *y, cudaStream_t stream);
template <typename T>
void ElewiseArith(const int &nums, enum BroadcastOpType op, const T *x0, const T *x1, T *y, cudaStream_t stream);
template <typename T, typename S>
void NoBroadcast(const int &size, enum BroadcastOpType op, const T *input0, const T *input1, S *output,
cudaStream_t stream);
template <typename T>
void BroadcastCmp(const std::vector<int> &x0_dims, const std::vector<int> &x1_dims, const std::vector<int> &y_dims,
enum BroadcastOpType op, const T *x0, const T *x1, bool *y, cudaStream_t stream);
template <typename T>
void BroadcastArith(const std::vector<int> &x0_dims, const std::vector<int> &x1_dims, const std::vector<int> &y_dims,
enum BroadcastOpType op, const T *x0, const T *x1, T *y, cudaStream_t stream);
template <typename T>
void BroadcastTo(const int &i0, const int &i1, const int &i2, const int &i3, const int &o0, const int &o1,
const int &o2, const int &o3, const T *input_addr, T *output_addr, cudaStream_t stream);
#endif // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_BROADCAST_H_
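
The header now exposes separate element-wise and broadcast entry points for comparison ops (bool output) and arithmetic ops (T output). The standalone CUDA sketch below illustrates the underlying pattern only — a flat kernel when both inputs already match the output shape, and index remapping when a size-1 dimension must be broadcast. It is not the MindSpore kernels; all names in it are hypothetical.

```cuda
// Standalone illustration (not the MindSpore kernels; names are hypothetical).
// Element-wise path: both inputs match the output shape, one linear index serves all three.
// Broadcast path: an input dimension of size 1 is reused across the output dimension,
// NumPy-style, by reducing the output index modulo that input dimension.
#include <cuda_runtime.h>

template <typename T>
__global__ void ElewiseAddSketch(int nums, const T *x0, const T *x1, T *y) {
  for (int pos = blockIdx.x * blockDim.x + threadIdx.x; pos < nums; pos += gridDim.x * blockDim.x) {
    y[pos] = x0[pos] + x1[pos];
  }
}

// 2-D only, for brevity; l*/r* are the (possibly size-1) input dims, o* the output dims.
template <typename T>
__global__ void BroadcastAddSketch(int l0, int l1, int r0, int r1, int o0, int o1,
                                   const T *x0, const T *x1, T *y) {
  for (int pos = blockIdx.x * blockDim.x + threadIdx.x; pos < o0 * o1; pos += gridDim.x * blockDim.x) {
    int i = pos / o1;
    int j = pos % o1;
    y[pos] = x0[(i % l0) * l1 + (j % l1)] + x1[(i % r0) * r1 + (j % r1)];
  }
}

int main() {
  const int o0 = 2, o1 = 3;
  float h_x0[o0 * o1] = {1, 2, 3, 4, 5, 6};  // shape (2, 3)
  float h_x1[o1] = {10, 20, 30};             // shape (1, 3): broadcast along dim 0
  float h_y[o0 * o1] = {0};
  float *d_x0, *d_x1, *d_y;
  cudaMalloc(&d_x0, sizeof(h_x0));
  cudaMalloc(&d_x1, sizeof(h_x1));
  cudaMalloc(&d_y, sizeof(h_y));
  cudaMemcpy(d_x0, h_x0, sizeof(h_x0), cudaMemcpyHostToDevice);
  cudaMemcpy(d_x1, h_x1, sizeof(h_x1), cudaMemcpyHostToDevice);
  BroadcastAddSketch<float><<<1, 256>>>(2, 3, 1, 3, o0, o1, d_x0, d_x1, d_y);
  cudaMemcpy(h_y, d_y, sizeof(h_y), cudaMemcpyDeviceToHost);
  // h_y == {11, 22, 33, 14, 25, 36}
  cudaFree(d_x0);
  cudaFree(d_x1);
  cudaFree(d_y);
  return 0;
}
```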
@@ -58,8 +58,8 @@ class AddNGpuFwdKernel : public GpuKernel {
for (size_t i = 0; i < IntToSize(num_input_); i++) {
T *input_addr = GetDeviceAddress<T>(inputs, i);
if (cudnn_data_type_ == CUDNN_DATA_INT32) {
NoBroadcast(outputs[0]->size / sizeof(T), BROADCAST_TYPE_ADD, input_addr, output_addr, output_addr,
reinterpret_cast<cudaStream_t>(stream_ptr));
ElewiseArith(outputs[0]->size / sizeof(T), BROADCAST_TYPE_ADD, input_addr, output_addr, output_addr,
reinterpret_cast<cudaStream_t>(stream_ptr));
} else {
CHECK_CUDNN_RET_WITH_EXCEPT(cudnnAddTensor(cudnn_handle_, &alpha, input_descriptor_, input_addr,
&(i > 0 ? alpha : beta), input_descriptor_, output_addr),
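
In the INT32 path above, cuDNN is bypassed and each input is folded into the output with an element-wise add whose destination doubles as the second operand. A host-side sketch of that accumulation pattern follows (illustrative only; AddNSketch is a hypothetical name, and the zero-initialized output is an assumption).

```cpp
// Illustration of the accumulation pattern only, not the GPU kernel: the output buffer
// is both the second operand and the destination, so each pass plays the role of
// ElewiseArith(n, BROADCAST_TYPE_ADD, input, output, output, stream).
#include <algorithm>
#include <vector>

void AddNSketch(const std::vector<std::vector<int>> &inputs, std::vector<int> *output) {
  std::fill(output->begin(), output->end(), 0);  // assumption: output starts zeroed
  for (const auto &input : inputs) {
    for (std::size_t i = 0; i < output->size(); ++i) {
      (*output)[i] += input[i];  // output = input + output, in place
    }
  }
}
```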
@@ -19,119 +19,119 @@
namespace mindspore {
namespace kernel {
// fp32
MS_REG_GPU_KERNEL_TWO(
MS_REG_GPU_KERNEL_ONE(
Greater,
KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeBool),
BroadcastOpGpuKernel, float, bool)
MS_REG_GPU_KERNEL_TWO(
BroadcastOpGpuKernel, float)
MS_REG_GPU_KERNEL_ONE(
Less, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeBool),
BroadcastOpGpuKernel, float, bool)
MS_REG_GPU_KERNEL_TWO(
BroadcastOpGpuKernel, float)
MS_REG_GPU_KERNEL_ONE(
Maximum,
KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
BroadcastOpGpuKernel, float, float)
MS_REG_GPU_KERNEL_TWO(
BroadcastOpGpuKernel, float)
MS_REG_GPU_KERNEL_ONE(
Minimum,
KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
BroadcastOpGpuKernel, float, float)
MS_REG_GPU_KERNEL_TWO(
BroadcastOpGpuKernel, float)
MS_REG_GPU_KERNEL_ONE(
Pow, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
BroadcastOpGpuKernel, float, float)
MS_REG_GPU_KERNEL_TWO(
BroadcastOpGpuKernel, float)
MS_REG_GPU_KERNEL_ONE(
RealDiv,
KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
BroadcastOpGpuKernel, float, float)
MS_REG_GPU_KERNEL_TWO(
BroadcastOpGpuKernel, float)
MS_REG_GPU_KERNEL_ONE(
Mul, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
BroadcastOpGpuKernel, float, float)
MS_REG_GPU_KERNEL_TWO(
BroadcastOpGpuKernel, float)
MS_REG_GPU_KERNEL_ONE(
Sub, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
BroadcastOpGpuKernel, float, float)
MS_REG_GPU_KERNEL_TWO(
BroadcastOpGpuKernel, float)
MS_REG_GPU_KERNEL_ONE(
TensorAdd,
KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
BroadcastOpGpuKernel, float, float)
MS_REG_GPU_KERNEL_TWO(
BroadcastOpGpuKernel, float)
MS_REG_GPU_KERNEL_ONE(
FloorDiv,
KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
BroadcastOpGpuKernel, float, float)
MS_REG_GPU_KERNEL_TWO(
BroadcastOpGpuKernel, float)
MS_REG_GPU_KERNEL_ONE(
AbsGrad,
KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
BroadcastOpGpuKernel, float, float)
MS_REG_GPU_KERNEL_TWO(
BroadcastOpGpuKernel, float)
MS_REG_GPU_KERNEL_ONE(
Div, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
BroadcastOpGpuKernel, float, float)
BroadcastOpGpuKernel, float)
// fp16
MS_REG_GPU_KERNEL_TWO(
MS_REG_GPU_KERNEL_ONE(
Greater,
KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeBool),
BroadcastOpGpuKernel, half, bool)
MS_REG_GPU_KERNEL_TWO(
BroadcastOpGpuKernel, half)
MS_REG_GPU_KERNEL_ONE(
Less, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeBool),
BroadcastOpGpuKernel, half, bool)
MS_REG_GPU_KERNEL_TWO(
BroadcastOpGpuKernel, half)
MS_REG_GPU_KERNEL_ONE(
Maximum,
KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
BroadcastOpGpuKernel, half, half)
MS_REG_GPU_KERNEL_TWO(
BroadcastOpGpuKernel, half)
MS_REG_GPU_KERNEL_ONE(
Minimum,
KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
BroadcastOpGpuKernel, half, half)
MS_REG_GPU_KERNEL_TWO(
BroadcastOpGpuKernel, half)
MS_REG_GPU_KERNEL_ONE(
Pow, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
BroadcastOpGpuKernel, half, half)
MS_REG_GPU_KERNEL_TWO(
BroadcastOpGpuKernel, half)
MS_REG_GPU_KERNEL_ONE(
RealDiv,
KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
BroadcastOpGpuKernel, half, half)
MS_REG_GPU_KERNEL_TWO(
BroadcastOpGpuKernel, half)
MS_REG_GPU_KERNEL_ONE(
Mul, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
BroadcastOpGpuKernel, half, half)
MS_REG_GPU_KERNEL_TWO(
BroadcastOpGpuKernel, half)
MS_REG_GPU_KERNEL_ONE(
Sub, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
BroadcastOpGpuKernel, half, half)
MS_REG_GPU_KERNEL_TWO(
BroadcastOpGpuKernel, half)
MS_REG_GPU_KERNEL_ONE(
TensorAdd,
KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
BroadcastOpGpuKernel, half, half)
MS_REG_GPU_KERNEL_TWO(
BroadcastOpGpuKernel, half)
MS_REG_GPU_KERNEL_ONE(
FloorDiv,
KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
BroadcastOpGpuKernel, half, half)
MS_REG_GPU_KERNEL_TWO(
BroadcastOpGpuKernel, half)
MS_REG_GPU_KERNEL_ONE(
AbsGrad,
KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
BroadcastOpGpuKernel, half, half)
MS_REG_GPU_KERNEL_TWO(
BroadcastOpGpuKernel, half)
MS_REG_GPU_KERNEL_ONE(
Div, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
BroadcastOpGpuKernel, half, half)
BroadcastOpGpuKernel, half)
// int32
MS_REG_GPU_KERNEL_TWO(
MS_REG_GPU_KERNEL_ONE(
Less, KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeBool),
BroadcastOpGpuKernel, int, bool)
MS_REG_GPU_KERNEL_TWO(
BroadcastOpGpuKernel, int)
MS_REG_GPU_KERNEL_ONE(
TensorAdd, KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
BroadcastOpGpuKernel, int, int)
MS_REG_GPU_KERNEL_TWO(
BroadcastOpGpuKernel, int)
MS_REG_GPU_KERNEL_ONE(
Minimum, KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
BroadcastOpGpuKernel, int, int)
MS_REG_GPU_KERNEL_TWO(
BroadcastOpGpuKernel, int)
MS_REG_GPU_KERNEL_ONE(
Maximum, KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
BroadcastOpGpuKernel, int, int)
MS_REG_GPU_KERNEL_TWO(
BroadcastOpGpuKernel, int)
MS_REG_GPU_KERNEL_ONE(
Mul, KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
BroadcastOpGpuKernel, int, int)
MS_REG_GPU_KERNEL_TWO(
BroadcastOpGpuKernel, int)
MS_REG_GPU_KERNEL_ONE(
FloorDiv, KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
BroadcastOpGpuKernel, int, int)
MS_REG_GPU_KERNEL_TWO(
BroadcastOpGpuKernel, int)
MS_REG_GPU_KERNEL_ONE(
AbsGrad, KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
BroadcastOpGpuKernel, int, int)
MS_REG_GPU_KERNEL_TWO(
BroadcastOpGpuKernel, int)
MS_REG_GPU_KERNEL_ONE(
Div, KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
BroadcastOpGpuKernel, int, int)
BroadcastOpGpuKernel, int)
} // namespace kernel
} // namespace mindspore
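
All of these registrations move from MS_REG_GPU_KERNEL_TWO to MS_REG_GPU_KERNEL_ONE: the second template argument (the output element type) is gone because comparison kernels now pick their bool output at runtime via is_comp_op_. The toy registry below sketches that single-type registration idea; it does not reproduce the real macros, and every name in it is hypothetical.

```cpp
// Toy sketch of single-type registration (hypothetical names, not MindSpore API).
// One template parameter suffices because the output buffer type is decided at runtime.
// Requires C++14 for std::make_unique.
#include <functional>
#include <map>
#include <memory>
#include <string>

struct ToyKernelBase {
  virtual ~ToyKernelBase() = default;
};

template <typename T>
struct ToyBroadcastKernel : ToyKernelBase {
  explicit ToyBroadcastKernel(bool is_comp_op) : is_comp_op_(is_comp_op) {}
  bool is_comp_op_;  // true for Greater/Less: output buffer holds bool, not T
};

using ToyCreator = std::function<std::unique_ptr<ToyKernelBase>(bool)>;

std::map<std::string, ToyCreator> &ToyRegistry() {
  static std::map<std::string, ToyCreator> registry;
  return registry;
}

// One type argument per registration, mirroring the spirit of MS_REG_GPU_KERNEL_ONE.
template <typename T>
void RegisterToyOp(const std::string &key) {
  ToyRegistry()[key] = [](bool is_comp_op) -> std::unique_ptr<ToyKernelBase> {
    return std::make_unique<ToyBroadcastKernel<T>>(is_comp_op);
  };
}

int main() {
  RegisterToyOp<float>("Greater_Float32");
  RegisterToyOp<float>("TensorAdd_Float32");
  auto greater = ToyRegistry()["Greater_Float32"](/*is_comp_op=*/true);   // bool output
  auto add = ToyRegistry()["TensorAdd_Float32"](/*is_comp_op=*/false);    // float output
  return (greater && add) ? 0 : 1;
}
```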
@@ -28,11 +28,16 @@
namespace mindspore {
namespace kernel {
constexpr int MAX_DIMS = 7;
template <typename T, typename S>
template <typename T>
class BroadcastOpGpuKernel : public GpuKernel {
public:
BroadcastOpGpuKernel()
: op_type_(BROADCAST_TYPE_INVALID), need_broadcast_(false), input1_num_(1), input2_num_(1), output_num_(1) {}
: op_type_(BROADCAST_TYPE_INVALID),
need_broadcast_(false),
is_comp_op_(false),
input1_num_(1),
input2_num_(1),
output_num_(1) {}
~BroadcastOpGpuKernel() override = default;
const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; }
@@ -43,13 +48,23 @@ class BroadcastOpGpuKernel : public GpuKernel {
const std::vector<AddressPtr> &outputs, void *stream_ptr) override {
T *lhs = GetDeviceAddress<T>(inputs, 0);
T *rhs = GetDeviceAddress<T>(inputs, 1);
S *output = GetDeviceAddress<S>(outputs, 0);
if (need_broadcast_) {
Broadcast(lhs_shape_, rhs_shape_, output_shape_, op_type_, lhs, rhs, output,
reinterpret_cast<cudaStream_t>(stream_ptr));
if (is_comp_op_) {
bool *output = GetDeviceAddress<bool>(outputs, 0);
if (need_broadcast_) {
BroadcastCmp(lhs_shape_, rhs_shape_, output_shape_, op_type_, lhs, rhs, output,
reinterpret_cast<cudaStream_t>(stream_ptr));
} else {
ElewiseCmp(output_num_, op_type_, lhs, rhs, output, reinterpret_cast<cudaStream_t>(stream_ptr));
}
} else {
NoBroadcast(output_num_, op_type_, lhs, rhs, output, reinterpret_cast<cudaStream_t>(stream_ptr));
T *output = GetDeviceAddress<T>(outputs, 0);
if (need_broadcast_) {
BroadcastArith(lhs_shape_, rhs_shape_, output_shape_, op_type_, lhs, rhs, output,
reinterpret_cast<cudaStream_t>(stream_ptr));
} else {
ElewiseArith(output_num_, op_type_, lhs, rhs, output, reinterpret_cast<cudaStream_t>(stream_ptr));
}
}
return true;
@@ -91,26 +106,42 @@ class BroadcastOpGpuKernel : public GpuKernel {
void InitSizeLists() override {
input_size_list_.push_back(input1_num_ * sizeof(T));
input_size_list_.push_back(input2_num_ * sizeof(T));
output_size_list_.push_back(output_num_ * sizeof(S));
auto unit_size = is_comp_op_ ? sizeof(bool) : sizeof(T);
output_size_list_.push_back(output_num_ * unit_size);
}
private:
void GetOpType(const CNodePtr &kernel_node) {
std::string kernel_name = AnfAlgo::GetCNodeName(kernel_node);
static std::map<std::string, BroadcastOpType> kBroadcastTypeMap = {
{"Greater", BROADCAST_TYPE_GREATER}, {"Less", BROADCAST_TYPE_LESS}, {"Maximum", BROADCAST_TYPE_MAXIMUM},
{"Minimum", BROADCAST_TYPE_MINIMUM}, {"Pow", BROADCAST_TYPE_POWER}, {"RealDiv", BROADCAST_TYPE_REALDIV},
{"Mul", BROADCAST_TYPE_MUL}, {"Sub", BROADCAST_TYPE_SUB}, {"TensorAdd", BROADCAST_TYPE_ADD},
{"FloorDiv", BROADCAST_TYPE_FLOORDIV}, {"AbsGrad", BROADCAST_TYPE_ABSGRAD}, {"Div", BROADCAST_TYPE_DIV},
static std::map<std::string, BroadcastOpType> kBroadcastCmpTypeMap = {
{"Greater", BROADCAST_TYPE_GREATER},
{"Less", BROADCAST_TYPE_LESS},
};
auto iter = kBroadcastTypeMap.find(kernel_name);
if (iter == kBroadcastTypeMap.end()) {
MS_LOG(EXCEPTION) << "operation " << kernel_name << " is not supported.";
} else {
auto iter = kBroadcastCmpTypeMap.find(kernel_name);
if (iter != kBroadcastCmpTypeMap.end()) {
op_type_ = iter->second;
is_comp_op_ = true;
return;
}
static std::map<std::string, BroadcastOpType> kBroadcastArithmetricTypeMap = {
{"Maximum", BROADCAST_TYPE_MAXIMUM}, {"Minimum", BROADCAST_TYPE_MINIMUM}, {"Pow", BROADCAST_TYPE_POWER},
{"RealDiv", BROADCAST_TYPE_REALDIV}, {"Mul", BROADCAST_TYPE_MUL}, {"Sub", BROADCAST_TYPE_SUB},
{"TensorAdd", BROADCAST_TYPE_ADD}, {"FloorDiv", BROADCAST_TYPE_FLOORDIV}, {"AbsGrad", BROADCAST_TYPE_ABSGRAD},
{"Div", BROADCAST_TYPE_DIV},
};
iter = kBroadcastArithmetricTypeMap.find(kernel_name);
if (iter != kBroadcastArithmetricTypeMap.end()) {
op_type_ = iter->second;
is_comp_op_ = false;
return;
}
MS_LOG(EXCEPTION) << "operation " << kernel_name << " is not supported.";
}
bool IsBroadcast(const std::vector<size_t> &lhs, const std::vector<size_t> &rhs) {
@@ -127,6 +158,7 @@ class BroadcastOpGpuKernel : public GpuKernel {
BroadcastOpType op_type_;
bool need_broadcast_;
bool is_comp_op_;
int input1_num_;
int input2_num_;
int output_num_;
@@ -137,7 +169,7 @@ class BroadcastOpGpuKernel : public GpuKernel {
std::vector<size_t> input_size_list_;
std::vector<size_t> output_size_list_;
std::vector<size_t> workspace_size_list_;
};
}; // namespace kernel
} // namespace kernel
} // namespace mindspore
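
The IsBroadcast helper is collapsed in the hunk above. A plausible standalone version (an assumption about its behavior, not necessarily the exact implementation) simply reports whether the two input shapes differ anywhere, which is what decides between the Elewise* and Broadcast* launch paths.

```cpp
// Assumed behavior of the collapsed IsBroadcast check: broadcasting is needed whenever
// the two input shapes differ in rank or in any dimension.
#include <cstddef>
#include <vector>

bool IsBroadcastSketch(const std::vector<std::size_t> &lhs, const std::vector<std::size_t> &rhs) {
  if (lhs.size() != rhs.size()) {
    return true;
  }
  for (std::size_t i = 0; i < lhs.size(); ++i) {
    if (lhs[i] != rhs[i]) {
      return true;
    }
  }
  return false;
}

// Example: {3, 1, 5, 1} vs {1, 4, 1, 6} -> true  (broadcast path)
//          {2, 3}       vs {2, 3}       -> false (element-wise path)
```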
@@ -160,3 +160,45 @@ def test_broadcast_diff_dims():
output_ms = P.Sub()(Tensor(x1_np), Tensor(x2_np))
output_np = x1_np - x2_np
assert np.allclose(output_ms.asnumpy(), output_np)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_broadcast_fp16():
context.set_context(mode=context.GRAPH_MODE, device_target='GPU')
x1_np = np.random.rand(3, 1, 5, 1).astype(np.float16)
x2_np = np.random.rand(1, 4, 1, 6).astype(np.float16)
output_ms = P.Minimum()(Tensor(x1_np), Tensor(x2_np))
output_np = np.minimum(x1_np, x2_np)
assert np.allclose(output_ms.asnumpy(), output_np)
output_ms = P.Maximum()(Tensor(x1_np), Tensor(x2_np))
output_np = np.maximum(x1_np, x2_np)
assert np.allclose(output_ms.asnumpy(), output_np)
output_ms = P.Greater()(Tensor(x1_np), Tensor(x2_np))
output_np = x1_np > x2_np
assert np.allclose(output_ms.asnumpy(), output_np)
output_ms = P.Less()(Tensor(x1_np), Tensor(x2_np))
output_np = x1_np < x2_np
assert np.allclose(output_ms.asnumpy(), output_np)
output_ms = P.Pow()(Tensor(x1_np), Tensor(x2_np))
output_np = np.power(x1_np, x2_np)
assert np.allclose(output_ms.asnumpy(), output_np)
output_ms = P.RealDiv()(Tensor(x1_np), Tensor(x2_np))
output_np = x1_np / x2_np
assert np.allclose(output_ms.asnumpy(), output_np)
output_ms = P.Mul()(Tensor(x1_np), Tensor(x2_np))
output_np = x1_np * x2_np
assert np.allclose(output_ms.asnumpy(), output_np)
output_ms = P.Sub()(Tensor(x1_np), Tensor(x2_np))
output_np = x1_np - x2_np
assert np.allclose(output_ms.asnumpy(), output_np)