diff --git a/paddle/fluid/operators/number_count_op.cc b/paddle/fluid/operators/number_count_op.cc
new file mode 100644
index 0000000000000000000000000000000000000000..8f7a3b82acf19fa79cbf5c632977e6ae533ae12b
--- /dev/null
+++ b/paddle/fluid/operators/number_count_op.cc
@@ -0,0 +1,66 @@
+// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/fluid/operators/number_count_op.h"
+
+namespace paddle {
+namespace operators {
+
+class NumberCountOp : public framework::OperatorWithKernel {
+ public:
+  using framework::OperatorWithKernel::OperatorWithKernel;
+
+  void InferShape(framework::InferShapeContext* ctx) const override {
+    OP_INOUT_CHECK(ctx->HasInput("gate_idx"), "Input", "gate_idx",
+                   "NumberCount");
+    OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "NumberCount");
+  }
+
+ protected:
+  framework::OpKernelType GetExpectedKernelType(
+      const framework::ExecutionContext& ctx) const override {
+    // the dtype of gate_idx must be int64
+    auto gate_idx_dtype =
+        OperatorWithKernel::IndicateVarDataType(ctx, "gate_idx");
+
+    PADDLE_ENFORCE_EQ(gate_idx_dtype, framework::proto::VarType::INT64,
+                      platform::errors::InvalidArgument(
+                          "The dtype of gate_idx should be int64"));
+    return framework::OpKernelType(gate_idx_dtype, ctx.GetPlace());
+  }
+};
+
+class NumberCountOpMaker : public framework::OpProtoAndCheckerMaker {
+ public:
+  void Make() override {
+    AddInput("gate_idx", "(Tensor) The input gate index tensor.");
+    AddOutput("Out", "(Tensor) The output expert count tensor.");
+    AddAttr<int>("upper_range", "(int) The number of experts.");
+
+    AddComment(
+        R"DOC(number_count Operator. Count the occurrences of each gate index.)DOC");
+  }
+};
+
+}  // namespace operators
+}  // namespace paddle
+
+namespace ops = paddle::operators;
+namespace plat = paddle::platform;
+
+REGISTER_OP_CPU_KERNEL(number_count, ops::NumberCountOpCPUKernel<int>,
+                       ops::NumberCountOpCPUKernel<int64_t>);
+
+REGISTER_OP_WITHOUT_GRADIENT(number_count, ops::NumberCountOp,
+                             ops::NumberCountOpMaker);
diff --git a/paddle/fluid/operators/number_count_op.cu b/paddle/fluid/operators/number_count_op.cu
new file mode 100644
index 0000000000000000000000000000000000000000..97e4b4f2845ae132c28d3bb71dcc8e73f02e193a
--- /dev/null
+++ b/paddle/fluid/operators/number_count_op.cu
@@ -0,0 +1,108 @@
+// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
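+
+// Counting strategy used below: each thread block owns a contiguous slice of
+// PERTHREAD_EXPERTS expert ids. Every thread strides over the whole batch and
+// accumulates private counts for that slice; per-thread counts are then
+// combined with a warp shuffle reduction, and each warp leader publishes its
+// partial sum with a single atomic add per expert.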
+
+#include "paddle/fluid/framework/op_registry.h"
+#include "paddle/fluid/operators/number_count_op.h"
+#include "paddle/fluid/platform/device/gpu/gpu_primitives.h"
+#include "paddle/fluid/platform/float16.h"
+
+namespace paddle {
+namespace operators {
+
+#define CEIL(_x_, _y_) (((_x_)-1) / (_y_) + 1)
+#define PERTHREAD_EXPERTS 256
+#define WARP_SIZE 32
+
+const int CUDA_NUM_THREADS = 512;
+static inline int GET_BLOCKS(const int N) {
+  return (N + CUDA_NUM_THREADS - 1) / CUDA_NUM_THREADS;
+}
+
+using LoDTensor = framework::LoDTensor;
+using Tensor = framework::Tensor;
+
+template <typename T>
+__global__ void initialize_zero_kernel(T* data, const int length) {
+  CUDA_KERNEL_LOOP(idx, length) { data[idx] = static_cast<T>(0); }
+}
+
+template <typename T>
+__global__ void NumberCount(const T* gate_idx, T* number_count,
+                            int64_t batch_size, int upper_range) {
+  int res_tmp[PERTHREAD_EXPERTS] = {0};
+  int expert_min = blockIdx.x * PERTHREAD_EXPERTS;
+  int expert_max = expert_min + PERTHREAD_EXPERTS;
+  if (expert_max > upper_range) {
+    expert_max = upper_range;
+  }
+  for (int i = threadIdx.x; i < batch_size; i += blockDim.x) {
+    T idx = gate_idx[i];
+    if (idx == -1) {
+      continue;
+    }
+    if (idx < expert_min || idx >= expert_max) {
+      continue;
+    }
+    res_tmp[idx - expert_min] += 1;
+  }
+  for (int i = expert_min; i < expert_max; ++i) {
+    int x = res_tmp[i - expert_min];
+#pragma unroll
+    for (int j = 1; j < WARP_SIZE; j <<= 1) {
+#ifdef __HIPCC__
+      x = x + __shfl_down(x, j);
+#else
+      x = x + __shfl_down_sync(-1u, x, j);
+#endif
+    }
+    if (threadIdx.x % WARP_SIZE == 0) {
+      platform::CudaAtomicAdd(number_count + i, x);
+    }
+  }
+}
+
+template <typename T>
+class NumberCountOpCUDAKernel : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext& context) const override {
+    auto gate_idx = context.Input<LoDTensor>("gate_idx");
+    auto upper_range = context.Attr<int>("upper_range");
+    auto number_count = context.Output<LoDTensor>("Out");
+
+    int64_t batch_size = gate_idx->numel();
+    auto place = context.GetPlace();
+    const auto& dev_ctx =
+        context.template device_context<platform::CUDADeviceContext>();
+
+    framework::DDim out_dims = phi::make_ddim({upper_range});
+    auto out_data = number_count->mutable_data<T>(out_dims, place);
+    const T* gate_data = gate_idx->data<T>();
+
+    initialize_zero_kernel<
+        T><<<GET_BLOCKS(upper_range), CUDA_NUM_THREADS, 0, dev_ctx.stream()>>>(
+        out_data, upper_range);
+
+    NumberCount<
+        T><<<CEIL(upper_range, PERTHREAD_EXPERTS), 256, 0, dev_ctx.stream()>>>(
+        gate_data, out_data, batch_size, upper_range);
+  }
+};
+
+}  // namespace operators
+}  // namespace paddle
+
+namespace ops = paddle::operators;
+namespace plat = paddle::platform;
+
+REGISTER_OP_CUDA_KERNEL(number_count, ops::NumberCountOpCUDAKernel<int64_t>);
diff --git a/paddle/fluid/operators/number_count_op.h b/paddle/fluid/operators/number_count_op.h
new file mode 100644
index 0000000000000000000000000000000000000000..95e64946fb8a2156fdb4cbae880ccf2c143447ed
--- /dev/null
+++ b/paddle/fluid/operators/number_count_op.h
@@ -0,0 +1,37 @@
+// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
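+
+// NOTE: number_count only registers a CUDA kernel; the CPU kernel declared
+// below is a placeholder that raises errors::Unavailable when invoked.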
+
+#pragma once
+#include "paddle/fluid/framework/data_type.h"
+#include "paddle/fluid/framework/lod_tensor.h"
+#include "paddle/fluid/framework/op_registry.h"
+
+#if defined(PADDLE_WITH_GLOO)
+#include "paddle/fluid/framework/fleet/gloo_wrapper.h"
+#endif
+
+namespace paddle {
+namespace operators {
+
+template <typename T>
+class NumberCountOpCPUKernel : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext& ctx) const override {
+    PADDLE_THROW(platform::errors::Unavailable(
+        "The number_count op is not supported on CPU now."));
+  }
+};
+
+}  // namespace operators
+}  // namespace paddle
diff --git a/python/paddle/distributed/models/__init__.py b/python/paddle/distributed/models/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e1663029ef1f844676ce9484f724dc253d625386
--- /dev/null
+++ b/python/paddle/distributed/models/__init__.py
@@ -0,0 +1,13 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
diff --git a/python/paddle/distributed/models/moe/__init__.py b/python/paddle/distributed/models/moe/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e1663029ef1f844676ce9484f724dc253d625386
--- /dev/null
+++ b/python/paddle/distributed/models/moe/__init__.py
@@ -0,0 +1,13 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
diff --git a/python/paddle/distributed/models/moe/utils.py b/python/paddle/distributed/models/moe/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..fd98c64318c60e2e67af320c51b24e39a3132c43
--- /dev/null
+++ b/python/paddle/distributed/models/moe/utils.py
@@ -0,0 +1,55 @@
+# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
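+
+# Reference semantics of _number_count (a NumPy sketch for illustration only,
+# not part of the public API; `gate_idx` and `upper_range` as documented below):
+#
+#     counts = np.zeros(upper_range, dtype='int64')
+#     for i in gate_idx.reshape(-1):
+#         if 0 <= i < upper_range:  # ids outside [0, upper_range) are skipped
+#             counts[i] += 1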
+
+from paddle.fluid import core
+from paddle.fluid.layer_helper import LayerHelper
+from paddle.fluid.framework import in_dygraph_mode
+
+
+def _number_count(gate_idx, upper_range):
+    """
+    Calculate the expert count according to the gate index.
+
+    Args:
+        gate_idx (Tensor): The input gate index whose data type should be
+            int64.
+        upper_range (int): The number of experts.
+
+    Returns:
+        out (Tensor): The output expert count.
+
+    Examples:
+        .. code-block:: python
+
+            # required: distributed
+            import paddle
+            from paddle.distributed.models.moe import utils
+
+            gate_idx = [
+                [0, 2],
+                [0, 2]
+            ]
+            upper_range = 6
+            gate_idx = paddle.to_tensor(gate_idx, dtype="int64")
+            number_count = utils._number_count(gate_idx, upper_range)
+            print(number_count)  # the result: [2, 0, 2, 0, 0, 0]
+    """
+    if in_dygraph_mode():
+        return core.ops.number_count(gate_idx, 'upper_range', upper_range)
+    else:
+        op_type = 'number_count'
+
+        helper = LayerHelper(op_type, **locals())
+        out = helper.create_variable_for_type_inference(dtype=gate_idx.dtype)
+
+        helper.append_op(
+            type=op_type,
+            inputs={'gate_idx': gate_idx},
+            outputs={'Out': out},
+            attrs={'upper_range': upper_range})
+        return out
diff --git a/python/paddle/fluid/tests/unittests/test_number_count_op.py b/python/paddle/fluid/tests/unittests/test_number_count_op.py
new file mode 100644
index 0000000000000000000000000000000000000000..0df9d2a3a41b44c18b7e008a271c10544ec4dfa0
--- /dev/null
+++ b/python/paddle/fluid/tests/unittests/test_number_count_op.py
@@ -0,0 +1,80 @@
+# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
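+
+# These tests validate the CUDA number_count kernel and its Python wrapper
+# against the pure-NumPy reference implementation `count` defined below.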
+
+from __future__ import print_function
+
+import unittest
+
+import numpy as np
+import op_test
+import paddle
+import paddle.fluid.core as core
+from paddle.distributed.models.moe import utils
+
+
+def count(x, upper_range):
+    res = np.zeros((upper_range, )).astype('int64')
+    for i in x.reshape(-1):
+        if i >= 0 and i < len(res):
+            res[i] += 1
+    return res
+
+
+@unittest.skipIf(not core.is_compiled_with_cuda(),
+                 "core is not compiled with CUDA")
+class TestNumberCountOpInt64(op_test.OpTest):
+    def setUp(self):
+        expert_num = 16
+        self.op_type = "number_count"
+        x = np.random.randint(-1, expert_num, size=(1000, 2)).astype('int64')
+        self.inputs = {'gate_idx': x}
+        self.outputs = {'Out': count(x, expert_num)}
+        self.attrs = {"upper_range": expert_num}
+
+    def test_forward(self):
+        self.check_output_with_place(paddle.CUDAPlace(0))
+
+
+@unittest.skipIf(not core.is_compiled_with_cuda(),
+                 "core is not compiled with CUDA")
+class TestNumberCountAPI(unittest.TestCase):
+    def setUp(self):
+        self.upper_range = 320
+        self.x = np.random.randint(
+            -1, self.upper_range, size=(6000, 200)).astype('int64')
+        self.out = count(self.x, self.upper_range)
+        self.place = paddle.CUDAPlace(0)
+
+    def test_api_static(self):
+        paddle.enable_static()
+        with paddle.static.program_guard(paddle.static.Program()):
+            x = paddle.fluid.data('x', self.x.shape, dtype="int64")
+            out = utils._number_count(x, self.upper_range)
+            exe = paddle.static.Executor(self.place)
+            res = exe.run(feed={'x': self.x}, fetch_list=[out])
+            assert np.allclose(res, self.out)
+
+    def test_api_dygraph(self):
+        paddle.disable_static()
+        x = paddle.to_tensor(self.x)
+        out = utils._number_count(x, self.upper_range)
+        assert np.allclose(out.numpy(), self.out)
+
+
+if __name__ == '__main__':
+    paddle.enable_static()
+    unittest.main()