未验证 提交 305f32d1 编写于 作者: R Roc 提交者: GitHub

[MoE]Assign pos op (#40580)

* # This is a combination of 10 commits.
# The first commit's message is:
add expert count op

add ut for expert_count

# This is the 2nd commit message:

update UT only for cuda

# This is the 3rd commit message:

fix for rocm

# This is the 4th commit message:

update ut

# This is the 5th commit message:

add moe module

# This is the 6th commit message:

add expert count op

add ut for expert_count

# This is the 7th commit message:

update UT only for cuda

# This is the 8th commit message:

update ut

# This is the 9th commit message:

add moe module

# This is the 10th commit message:

make expert count private

* add assign pos op

* fix upper num name

* add api _assign pos

* add ut for assign pos op

* update date

* fix for win

* update for test (timeout)

* fix ut

* update

* fix ut for number count
Co-authored-by: Nhlygit66666 <2570058140@qq.com>
上级 9d8cfc1b
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/assign_pos_op.h"
namespace paddle {
namespace operators {
class AssignPosOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
OP_INOUT_CHECK(ctx->HasInput("cum_count"), "Input", "cum_count",
"AssignPos");
OP_INOUT_CHECK(ctx->HasInput("eff_num_len"), "Input", "eff_num_len",
"AssignPos");
OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "AssignPos");
OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "AssignPos");
}
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
auto cum_count_dtype =
OperatorWithKernel::IndicateVarDataType(ctx, "cum_count");
auto X_dtype = OperatorWithKernel::IndicateVarDataType(ctx, "X");
PADDLE_ENFORCE_EQ(cum_count_dtype, X_dtype,
platform::errors::InvalidArgument(
"The dtype of the cum_count and X should be same"));
PADDLE_ENFORCE_EQ(cum_count_dtype, framework::proto::VarType::INT64,
platform::errors::InvalidArgument(
"The dtype of the cum_count_dtype, eff_num_len and "
"X should be same as int64"));
return framework::OpKernelType(cum_count_dtype, ctx.device_context());
}
};
class AssignPosOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("X", "numbers to scatter.");
AddInput("cum_count", "The cumulative sum count of numbers.");
AddInput("eff_num_len",
"The effective numbers of numbers should be scattered.");
AddOutput("Out", "Assemble numbers in the order of counters.");
AddComment(R"DOC(
assign_pos_op Operator.
Assign pos decides which tokens should be fetched belong to
specially counter orderingly.
)DOC");
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
namespace plat = paddle::platform;
REGISTER_OP_WITHOUT_GRADIENT(assign_pos, ops::AssignPosOp,
ops::AssignPosOpMaker);
REGISTER_OP_CPU_KERNEL(assign_pos, ops::AssignPosOpCPUKernel<int>,
ops::AssignPosOpCPUKernel<int64_t>);
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/assign_pos_op.h"
#include "paddle/fluid/platform/device/gpu/gpu_primitives.h"
#include "paddle/fluid/platform/float16.h"
DECLARE_bool(avoid_op_randomness);
namespace paddle {
namespace operators {
static constexpr int kNumCUDAThreads = 512;
static constexpr int kNumMaxinumNumBlocks = 4096;
static inline int NumBlocks(const int N) {
return std::min((N + kNumCUDAThreads - 1) / kNumCUDAThreads,
kNumMaxinumNumBlocks);
}
template <typename T>
__global__ void AssignPos(T* cum_count, const T* numbers, T* out,
int64_t limit) {
CUDA_KERNEL_LOOP(i, limit) {
int number_idx = numbers[i];
if (number_idx > -1) {
int p = platform::CudaAtomicAdd(cum_count + number_idx, -1);
out[p - 1] = i;
}
}
}
template <typename T>
class AssignPosCUDAKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
// assign pos decides which tokens should be fetched belong to specially
// counter orderingly.
auto cum_count = context.Input<LoDTensor>(
"cum_count"); // (counter number) int32 | int64
auto numbers =
context.Input<LoDTensor>("X"); // (batch_size * seq_len, topk) int32
auto eff_num_len =
context.Input<LoDTensor>("eff_num_len"); // (sum(cum_count))
auto out = context.Output<LoDTensor>("Out"); // (cum_count) value ranges
// from 0 to batch_size *
// seq_len * topk
auto place = context.GetPlace();
auto numel = numbers->numel();
T* cum_data = const_cast<T*>(cum_count->data<T>());
auto cum_size = cum_count->numel();
framework::Tensor cpu_eff_num_len;
int64_t cpu_eff_num_len_data = 0;
if (platform::is_cpu_place(eff_num_len->place())) {
cpu_eff_num_len_data = eff_num_len->data<T>()[0];
} else {
framework::TensorCopySync(*eff_num_len, platform::CPUPlace(),
&cpu_eff_num_len);
cpu_eff_num_len_data = cpu_eff_num_len.data<T>()[0];
}
const auto& dev_ctx =
context.template device_context<platform::CUDADeviceContext>();
framework::DDim out_dims = phi::make_ddim({cpu_eff_num_len_data});
auto out_data = out->mutable_data<T>(out_dims, place);
const T* num_data = numbers->data<T>();
int blocks = NumBlocks(numel);
int threads = kNumCUDAThreads;
AssignPos<T><<<blocks, threads, 0, dev_ctx.stream()>>>(cum_data, num_data,
out_data, numel);
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
namespace plat = paddle::platform;
REGISTER_OP_CUDA_KERNEL(assign_pos, ops::AssignPosCUDAKernel<int64_t>);
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/op_registry.h"
namespace paddle {
namespace operators {
using LoDTensor = framework::LoDTensor;
template <typename T>
class AssignPosOpCPUKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
PADDLE_THROW(platform::errors::Unavailable(
"Do not support assign pos op for cpu kernel now."));
}
};
} // namespace operators
} // namespace paddle
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
......@@ -22,8 +22,7 @@ class NumberCountOp : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
OP_INOUT_CHECK(ctx->HasInput("gate_idx"), "Input", "gate_idx",
"NumberCount");
OP_INOUT_CHECK(ctx->HasInput("numbers"), "Input", "numbers", "NumberCount");
OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "number_count",
"NumberCount");
}
......@@ -31,25 +30,24 @@ class NumberCountOp : public framework::OperatorWithKernel {
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
// the dtype of the gate_idx should be same as int64
auto gate_idx_dtype =
OperatorWithKernel::IndicateVarDataType(ctx, "gate_idx");
// the dtype of the numbers should be same as int64
auto number_dtype = OperatorWithKernel::IndicateVarDataType(ctx, "numbers");
PADDLE_ENFORCE_EQ(gate_idx_dtype, framework::proto::VarType::INT64,
PADDLE_ENFORCE_EQ(number_dtype, framework::proto::VarType::INT64,
platform::errors::InvalidArgument(
"The dtype of the gate_idx_dtype should be int64"));
return framework::OpKernelType(gate_idx_dtype, ctx.GetPlace());
"The dtype of the number_dtype should be int64"));
return framework::OpKernelType(number_dtype, ctx.GetPlace());
}
};
class NumberCountOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("gate_idx", "(Tensor) The input gate index tensor.");
AddOutput("Out", "(Tensor) The output expert count tensor.");
AddAttr<int>("upper_range", "(int), The number of experts.");
AddInput("numbers", "(Tensor) The input gate index tensor.");
AddOutput("Out", "(Tensor) The output number count tensor.");
AddAttr<int>("upper_range", "(int), The number of different numbers.");
AddComment(R"DOC(number_count Operator.count gate indices.)DOC");
AddComment(R"DOC(number_count Operator.count numbers.)DOC");
}
};
......
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
......@@ -38,7 +38,7 @@ __global__ void initialize_zero_kernel(T* data, const int length) {
}
template <typename T>
__global__ void NumberCount(const T* gate_idx, T* number_count,
__global__ void NumberCount(const T* numbers, T* number_count,
int64_t batch_size, int upper_range) {
int res_tmp[PERTHREAD_EXPERTS] = {0};
int expert_min = blockIdx.x * PERTHREAD_EXPERTS;
......@@ -47,7 +47,7 @@ __global__ void NumberCount(const T* gate_idx, T* number_count,
expert_max = upper_range;
}
for (int i = threadIdx.x; i < batch_size; i += blockDim.x) {
T idx = gate_idx[i];
T idx = numbers[i];
if (idx == -1) {
continue;
}
......@@ -76,18 +76,18 @@ template <typename T>
class NumberCountOpCUDAKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto gate_idx = context.Input<LoDTensor>("gate_idx");
auto numbers = context.Input<LoDTensor>("numbers");
auto upper_range = context.Attr<int>("upper_range");
auto number_count = context.Output<LoDTensor>("Out");
int64_t batch_size = gate_idx->numel();
int64_t batch_size = numbers->numel();
auto place = context.GetPlace();
const auto& dev_ctx =
context.template device_context<platform::CUDADeviceContext>();
framework::DDim out_dims = phi::make_ddim({upper_range});
auto out_data = number_count->mutable_data<T>(out_dims, place);
const T* gate_data = gate_idx->data<T>();
const T* gate_data = numbers->data<T>();
initialize_zero_kernel<
T><<<GET_BLOCKS(upper_range), CUDA_NUM_THREADS, 0, dev_ctx.stream()>>>(
......
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
......
......@@ -17,11 +17,11 @@ from paddle.fluid.layer_helper import LayerHelper
from paddle.fluid.framework import in_dygraph_mode
def _number_count(gate_idx, upper_range):
def _number_count(numbers, upper_range):
"""
calculate the expert count according to the gate index.
Args:
gate_idx (Tensor): Tensor. The input gate index whose data type should be int32 or int64.
numbers (Tensor): Tensor. The input gate index whose data type should be int32 or int64.
upper_range (int): The number of the experts.
Returns:
out (Tensor): The output expert count.
......@@ -30,26 +30,75 @@ def _number_count(gate_idx, upper_range):
# required: distributed
import paddle
gate_idx = [
numbers = [
[0, 2],
[0, 2]
]
upper_range = 6
gate_idx = paddle.to_tensor(gate_idx, dtype="int32")
number_count = paddle.distributed.utils.number_count(gate_idx, upper_range)
numbers = paddle.to_tensor(numbers, dtype="int32")
number_count = paddle.distributed.utils.number_count(numbers, upper_range)
print(number_count) # the result: [2, 0, 2, 0, 0, 0]
"""
if in_dygraph_mode():
return core.ops.number_count(gate_idx, 'upper_range', upper_range)
return core.ops.number_count(numbers, 'upper_range', upper_range)
else:
op_type = 'number_count'
helper = LayerHelper(op_type, **locals())
out = helper.create_variable_for_type_inference(dtype=gate_idx.dtype)
out = helper.create_variable_for_type_inference(dtype=numbers.dtype)
helper.append_op(
type=op_type,
inputs={'gate_idx': gate_idx},
inputs={'numbers': numbers},
outputs={'Out': out},
attrs={'upper_range': upper_range})
return out
def _assign_pos(x, cum_count):
"""
Assign pos decides which tokens should be fetched belong to
specially expert orderingly.
Args:
x (Tensor): Tensor. Every element in the list must be a Tensor whose data type
should be float16, float32, float64, int32 or int64.
cum_count (Tensor): The cumulative sum tokens of counters. Every element in the list must be a Tensor whose
data type should be int64.
Returns:
out (Tensor): Assemble numbers in the order of counters.
Examples:
.. code-block:: python
# required: distributed
import paddle
number_count = [2, 0, 2, 0]
numbers = [
[0, 2],
[0, 2]
]
number_count = paddle.to_tensor(number_count)
numbers = paddle.to_tensor(numbers, dtype="int32")
num_cum = paddle.cumsum(number_count)
pos = paddle.distributed.utils.assign_pos(x=numbers, cum_count=num_cum)
print(pos) # the result: (2, 0, 3, 1)
"""
if in_dygraph_mode():
return core.ops.assign_pos(x, cum_count, cum_count[-1])
else:
op_type = 'assign_pos'
helper = LayerHelper(op_type, **locals())
out = helper.create_variable_for_type_inference(dtype=cum_count.dtype)
helper.append_op(
type=op_type,
inputs={
'X': [x],
'cum_count': [cum_count],
"eff_num_len": [cum_count[-1]]
},
outputs={'Out': [out]})
return out
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import op_test
import numpy as np
import unittest
import paddle
import paddle.fluid.core as core
from paddle.fluid.op import Operator
import paddle.fluid as fluid
from paddle.fluid import compiler, Program, program_guard
from paddle.fluid.backward import append_backward
from paddle.distributed.models.moe import utils
def assign_pos(x, _cum_count):
cum_count = np.copy(_cum_count)
x = x.reshape(-1)
res = np.zeros((cum_count[-1], ), dtype=np.int64)
for i, idx in enumerate(x):
p = cum_count[idx]
cum_count[idx] -= 1
if p >= 1:
res[p - 1] = i
return res
def count(x, upper_num):
res = np.zeros((upper_num, )).astype(int)
for i in x.reshape(-1):
if i >= 0 and i < len(res):
res[i] += 1
return res
# why defining the assert function specially?
# Becasue assign_pos_op is multithread-op, which can make the order of numbers
# in each counter(bin) is random. But the numbers set is certain in each counter(bin).
np_allclose = np.allclose
def assert_allclose(res, out, cum_count):
c0 = 0
for c in cum_count:
if c == c0:
continue
data1 = np.copy(res[c0:c])
data2 = np.copy(out[c0:c])
data1.sort()
data2.sort()
assert np_allclose(data2, data1)
c0 = c
return True
def get_redefined_allclose(cum_count):
def redefined_allclose(x, y, *args, **kwargs):
return assert_allclose(x, y, cum_count)
return redefined_allclose
@unittest.skipIf(not core.is_compiled_with_cuda(),
"core is not compiled with CUDA")
class TestAssignPosOpInt64(op_test.OpTest):
def setUp(self):
x = np.random.randint(0, 16, size=(100, 2)).astype("int64")
y = count(x, 16)
cum_count = np.cumsum(y).astype(x.dtype)
self.op_type = "assign_pos"
self.inputs = {
'X': x,
"cum_count": cum_count,
"eff_num_len": np.array([cum_count[-1]])
}
self.outputs = {'Out': assign_pos(x, cum_count)}
self.cum_count = cum_count
def test_forward(self):
np.allclose = get_redefined_allclose(self.cum_count)
self.check_output_with_place(paddle.CUDAPlace(0))
@unittest.skipIf(not core.is_compiled_with_cuda(),
"core is not compiled with CUDA")
class TestAssignPosAPI(unittest.TestCase):
def setUp(self):
self.x = np.random.randint(0, 16, size=(100, 2)).astype("int64")
y = count(self.x, 16)
self.cum_count = np.cumsum(y).astype(self.x.dtype)
self.out = assign_pos(self.x, self.cum_count)
self.place = paddle.CUDAPlace(0)
def test_api_static(self):
paddle.enable_static()
with paddle.static.program_guard(paddle.static.Program()):
x = paddle.fluid.data('x', self.x.shape, dtype="int64")
cum_count = paddle.fluid.data(
'cum_count', self.cum_count.shape, dtype="int64")
out = utils._assign_pos(x, cum_count)
exe = paddle.static.Executor(self.place)
res = exe.run(feed={'x': self.x,
"cum_count": self.cum_count},
fetch_list=[out])
assert_allclose(res[0], self.out, self.cum_count)
def test_api_dygraph(self):
paddle.disable_static()
x = paddle.to_tensor(self.x)
cum_count = paddle.to_tensor(self.cum_count).astype(x.dtype)
out = utils._assign_pos(x, cum_count)
assert_allclose(out.numpy(), self.out, self.cum_count)
if __name__ == '__main__':
paddle.enable_static()
unittest.main()
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......@@ -26,8 +26,8 @@ from paddle.fluid.backward import append_backward
from paddle.distributed.models.moe import utils
def count(x, upper_range):
res = np.zeros((upper_range, )).astype(int)
def count(x, upper_num):
res = np.zeros((upper_num, )).astype(int)
for i in x.reshape(-1):
if i >= 0 and i < len(res):
res[i] += 1
......@@ -36,14 +36,14 @@ def count(x, upper_range):
@unittest.skipIf(not core.is_compiled_with_cuda(),
"core is not compiled with CUDA")
class TestExpertCountOpInt64(op_test.OpTest):
class TestNumberCountOpInt64(op_test.OpTest):
def setUp(self):
expert_num = 16
upper_num = 16
self.op_type = "number_count"
x = np.random.randint(-1, expert_num, size=(1000, 2)).astype('int64')
self.inputs = {'gate_idx': x}
self.outputs = {'Out': count(x, expert_num)}
self.attrs = {"upper_range": expert_num}
x = np.random.randint(-1, upper_num, size=(1000, 2)).astype('int64')
self.inputs = {'numbers': x}
self.outputs = {'Out': count(x, upper_num)}
self.attrs = {"upper_range": upper_num}
def test_forward(self):
self.check_output_with_place(paddle.CUDAPlace(0))
......@@ -51,19 +51,19 @@ class TestExpertCountOpInt64(op_test.OpTest):
@unittest.skipIf(not core.is_compiled_with_cuda(),
"core is not compiled with CUDA")
class TestExpertCountAPI(unittest.TestCase):
class TestNumberCountAPI(unittest.TestCase):
def setUp(self):
self.upper_range = 320
self.upper_num = 320
self.x = np.random.randint(
-1, self.upper_range, size=(6000, 200)).astype('int64')
self.out = count(self.x, self.upper_range)
-1, self.upper_num, size=(6000, 200)).astype('int64')
self.out = count(self.x, self.upper_num)
self.place = paddle.CUDAPlace(0)
def test_api_static(self):
paddle.enable_static()
with paddle.static.program_guard(paddle.static.Program()):
x = paddle.fluid.data('x', self.x.shape, dtype="int64")
out = utils._number_count(x, self.upper_range)
out = utils._number_count(x, self.upper_num)
exe = paddle.static.Executor(self.place)
res = exe.run(feed={'x': self.x}, fetch_list=[out])
assert np.allclose(res, self.out)
......@@ -71,7 +71,7 @@ class TestExpertCountAPI(unittest.TestCase):
def test_api_dygraph(self):
paddle.disable_static()
x = paddle.to_tensor(self.x)
out = utils._number_count(x, self.upper_range)
out = utils._number_count(x, self.upper_num)
assert np.allclose(out.numpy(), self.out)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册