未验证 提交 9bdee437 编写于 作者: R Roc 提交者: GitHub

add number count op (#39224)

* add expert count op

add ut for expert_count

* update UT only for cuda

* fix for rocm

* update ut

* add moe module

* add expert count op

add ut for expert_count

* update UT only for cuda

* update ut

* add moe module

* make expert count private

* rename expert count op
Co-authored-by: Nhlygit66666 <2570058140@qq.com>
上级 4d886f75
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/operators/number_count_op.h"
namespace paddle {
namespace operators {
class NumberCountOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
OP_INOUT_CHECK(ctx->HasInput("gate_idx"), "Input", "gate_idx",
"NumberCount");
OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "number_count",
"NumberCount");
}
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
// the dtype of the gate_idx should be same as int64
auto gate_idx_dtype =
OperatorWithKernel::IndicateVarDataType(ctx, "gate_idx");
PADDLE_ENFORCE_EQ(gate_idx_dtype, framework::proto::VarType::INT64,
platform::errors::InvalidArgument(
"The dtype of the gate_idx_dtype should be int64"));
return framework::OpKernelType(gate_idx_dtype, ctx.GetPlace());
}
};
class NumberCountOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("gate_idx", "(Tensor) The input gate index tensor.");
AddOutput("Out", "(Tensor) The output expert count tensor.");
AddAttr<int>("upper_range", "(int), The number of experts.");
AddComment(R"DOC(number_count Operator.count gate indices.)DOC");
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
namespace plat = paddle::platform;
REGISTER_OP_CPU_KERNEL(number_count, ops::NumberCountOpCPUKernel<int>,
ops::NumberCountOpCPUKernel<int64_t>);
REGISTER_OP_WITHOUT_GRADIENT(number_count, ops::NumberCountOp,
ops::NumberCountOpMaker);
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/number_count_op.h"
#include "paddle/fluid/platform/device/gpu/gpu_primitives.h"
#include "paddle/fluid/platform/float16.h"
namespace paddle {
namespace operators {
#define CEIL(_x_, _y_) (((_x_)-1) / (_y_) + 1)
#define PERTHREAD_EXPERTS 256
#define WARP_SIZE 32
const int CUDA_NUM_THREADS = 512;
static inline int GET_BLOCKS(const int N) {
return (N + CUDA_NUM_THREADS - 1) / CUDA_NUM_THREADS;
}
using LoDTensor = framework::LoDTensor;
using Tensor = framework::Tensor;
template <typename T>
__global__ void initialize_zero_kernel(T* data, const int length) {
CUDA_KERNEL_LOOP(idx, length) { data[idx] = static_cast<T>(0); }
}
template <typename T>
__global__ void NumberCount(const T* gate_idx, T* number_count,
int64_t batch_size, int upper_range) {
int res_tmp[PERTHREAD_EXPERTS] = {0};
int expert_min = blockIdx.x * PERTHREAD_EXPERTS;
int expert_max = expert_min + PERTHREAD_EXPERTS;
if (expert_max > upper_range) {
expert_max = upper_range;
}
for (int i = threadIdx.x; i < batch_size; i += blockDim.x) {
T idx = gate_idx[i];
if (idx == -1) {
continue;
}
if (idx < expert_min || idx >= expert_max) {
continue;
}
res_tmp[idx - expert_min] += 1;
}
for (int i = expert_min; i < expert_max; ++i) {
int x = res_tmp[i - expert_min];
#pragma unroll
for (int j = 1; j < WARP_SIZE; j <<= 1) {
#ifdef __HIPCC__
x = x + __shfl_down(x, j);
#else
x = x + __shfl_down_sync(-1u, x, j);
#endif
}
if (threadIdx.x % WARP_SIZE == 0) {
platform::CudaAtomicAdd(number_count + i, x);
}
}
}
template <typename T>
class NumberCountOpCUDAKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto gate_idx = context.Input<LoDTensor>("gate_idx");
auto upper_range = context.Attr<int>("upper_range");
auto number_count = context.Output<LoDTensor>("Out");
int64_t batch_size = gate_idx->numel();
auto place = context.GetPlace();
const auto& dev_ctx =
context.template device_context<platform::CUDADeviceContext>();
framework::DDim out_dims = phi::make_ddim({upper_range});
auto out_data = number_count->mutable_data<T>(out_dims, place);
const T* gate_data = gate_idx->data<T>();
initialize_zero_kernel<
T><<<GET_BLOCKS(upper_range), CUDA_NUM_THREADS, 0, dev_ctx.stream()>>>(
out_data, upper_range);
NumberCount<
T><<<CEIL(upper_range, PERTHREAD_EXPERTS), 256, 0, dev_ctx.stream()>>>(
gate_data, out_data, batch_size, upper_range);
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
namespace plat = paddle::platform;
REGISTER_OP_CUDA_KERNEL(number_count, ops::NumberCountOpCUDAKernel<int64_t>);
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/op_registry.h"
#if defined(PADDLE_WITH_GLOO)
#include "paddle/fluid/framework/fleet/gloo_wrapper.h"
#endif
namespace paddle {
namespace operators {
template <typename T>
class NumberCountOpCPUKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
PADDLE_THROW(platform::errors::Unavailable(
"Do not support expert count op for cpu kernel now."));
}
};
} // namespace operators
} // namespace paddle
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from paddle.fluid import core
from paddle.fluid.layer_helper import LayerHelper
from paddle.fluid.framework import in_dygraph_mode
def _number_count(gate_idx, upper_range):
"""
calculate the expert count according to the gate index.
Args:
gate_idx (Tensor): Tensor. The input gate index whose data type should be int32 or int64.
upper_range (int): The number of the experts.
Returns:
out (Tensor): The output expert count.
Examples:
.. code-block:: python
# required: distributed
import paddle
gate_idx = [
[0, 2],
[0, 2]
]
upper_range = 6
gate_idx = paddle.to_tensor(gate_idx, dtype="int32")
number_count = paddle.distributed.utils.number_count(gate_idx, upper_range)
print(number_count) # the result: [2, 0, 2, 0, 0, 0]
"""
if in_dygraph_mode():
return core.ops.number_count(gate_idx, 'upper_range', upper_range)
else:
op_type = 'number_count'
helper = LayerHelper(op_type, **locals())
out = helper.create_variable_for_type_inference(dtype=gate_idx.dtype)
helper.append_op(
type=op_type,
inputs={'gate_idx': gate_idx},
outputs={'Out': out},
attrs={'upper_range': upper_range})
return out
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import op_test
import numpy as np
import unittest
import paddle
import paddle.fluid.core as core
from paddle.fluid.op import Operator
import paddle.fluid as fluid
from paddle.fluid import compiler, Program, program_guard
from paddle.fluid.backward import append_backward
from paddle.distributed.models.moe import utils
def count(x, upper_range):
res = np.zeros((upper_range, )).astype(int)
for i in x.reshape(-1):
if i >= 0 and i < len(res):
res[i] += 1
return res
@unittest.skipIf(not core.is_compiled_with_cuda(),
"core is not compiled with CUDA")
class TestExpertCountOpInt64(op_test.OpTest):
def setUp(self):
expert_num = 16
self.op_type = "number_count"
x = np.random.randint(-1, expert_num, size=(1000, 2)).astype('int64')
self.inputs = {'gate_idx': x}
self.outputs = {'Out': count(x, expert_num)}
self.attrs = {"upper_range": expert_num}
def test_forward(self):
self.check_output_with_place(paddle.CUDAPlace(0))
@unittest.skipIf(not core.is_compiled_with_cuda(),
"core is not compiled with CUDA")
class TestExpertCountAPI(unittest.TestCase):
def setUp(self):
self.upper_range = 320
self.x = np.random.randint(
-1, self.upper_range, size=(6000, 200)).astype('int64')
self.out = count(self.x, self.upper_range)
self.place = paddle.CUDAPlace(0)
def test_api_static(self):
paddle.enable_static()
with paddle.static.program_guard(paddle.static.Program()):
x = paddle.fluid.data('x', self.x.shape, dtype="int64")
out = utils._number_count(x, self.upper_range)
exe = paddle.static.Executor(self.place)
res = exe.run(feed={'x': self.x}, fetch_list=[out])
assert np.allclose(res, self.out)
def test_api_dygraph(self):
paddle.disable_static()
x = paddle.to_tensor(self.x)
out = utils._number_count(x, self.upper_range)
assert np.allclose(out.numpy(), self.out)
if __name__ == '__main__':
paddle.enable_static()
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册