diff --git a/paddle/fluid/operators/assign_pos_op.cc b/paddle/fluid/operators/assign_pos_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..69c0283c7bffb33cd49c1d0374f647828364dc67 --- /dev/null +++ b/paddle/fluid/operators/assign_pos_op.cc @@ -0,0 +1,80 @@ +/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/operators/assign_pos_op.h" + +namespace paddle { +namespace operators { + +class AssignPosOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + void InferShape(framework::InferShapeContext* ctx) const override { + OP_INOUT_CHECK(ctx->HasInput("cum_count"), "Input", "cum_count", + "AssignPos"); + OP_INOUT_CHECK(ctx->HasInput("eff_num_len"), "Input", "eff_num_len", + "AssignPos"); + OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "AssignPos"); + OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "AssignPos"); + } + + protected: + framework::OpKernelType GetExpectedKernelType( + const framework::ExecutionContext& ctx) const override { + auto cum_count_dtype = + OperatorWithKernel::IndicateVarDataType(ctx, "cum_count"); + auto X_dtype = OperatorWithKernel::IndicateVarDataType(ctx, "X"); + + PADDLE_ENFORCE_EQ(cum_count_dtype, X_dtype, + platform::errors::InvalidArgument( + "The dtype of the cum_count and X should be same")); + PADDLE_ENFORCE_EQ(cum_count_dtype, framework::proto::VarType::INT64, + platform::errors::InvalidArgument( + "The dtype of the cum_count_dtype, eff_num_len and " + "X should be same as int64")); + return framework::OpKernelType(cum_count_dtype, ctx.device_context()); + } +}; + +class AssignPosOpMaker : public framework::OpProtoAndCheckerMaker { + public: + void Make() override { + AddInput("X", "numbers to scatter."); + AddInput("cum_count", "The cumulative sum count of numbers."); + AddInput("eff_num_len", + "The effective numbers of numbers should be scattered."); + AddOutput("Out", "Assemble numbers in the order of counters."); + + AddComment(R"DOC( +assign_pos_op Operator. + +Assign pos decides which tokens should be fetched belong to +specially counter orderingly. + +)DOC"); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +namespace plat = paddle::platform; + +REGISTER_OP_WITHOUT_GRADIENT(assign_pos, ops::AssignPosOp, + ops::AssignPosOpMaker); + +REGISTER_OP_CPU_KERNEL(assign_pos, ops::AssignPosOpCPUKernel, + ops::AssignPosOpCPUKernel); diff --git a/paddle/fluid/operators/assign_pos_op.cu b/paddle/fluid/operators/assign_pos_op.cu new file mode 100644 index 0000000000000000000000000000000000000000..5fa159b94f9834e43db1cb0a419eefd2f60181b0 --- /dev/null +++ b/paddle/fluid/operators/assign_pos_op.cu @@ -0,0 +1,94 @@ +/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/operators/assign_pos_op.h" +#include "paddle/fluid/platform/device/gpu/gpu_primitives.h" +#include "paddle/fluid/platform/float16.h" + +DECLARE_bool(avoid_op_randomness); + +namespace paddle { +namespace operators { + +static constexpr int kNumCUDAThreads = 512; +static constexpr int kNumMaxinumNumBlocks = 4096; + +static inline int NumBlocks(const int N) { + return std::min((N + kNumCUDAThreads - 1) / kNumCUDAThreads, + kNumMaxinumNumBlocks); +} + +template +__global__ void AssignPos(T* cum_count, const T* numbers, T* out, + int64_t limit) { + CUDA_KERNEL_LOOP(i, limit) { + int number_idx = numbers[i]; + if (number_idx > -1) { + int p = platform::CudaAtomicAdd(cum_count + number_idx, -1); + out[p - 1] = i; + } + } +} + +template +class AssignPosCUDAKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& context) const override { + // assign pos decides which tokens should be fetched belong to specially + // counter orderingly. + auto cum_count = context.Input( + "cum_count"); // (counter number) int32 | int64 + auto numbers = + context.Input("X"); // (batch_size * seq_len, topk) int32 + auto eff_num_len = + context.Input("eff_num_len"); // (sum(cum_count)) + auto out = context.Output("Out"); // (cum_count) value ranges + // from 0 to batch_size * + // seq_len * topk + auto place = context.GetPlace(); + auto numel = numbers->numel(); + T* cum_data = const_cast(cum_count->data()); + auto cum_size = cum_count->numel(); + + framework::Tensor cpu_eff_num_len; + int64_t cpu_eff_num_len_data = 0; + if (platform::is_cpu_place(eff_num_len->place())) { + cpu_eff_num_len_data = eff_num_len->data()[0]; + } else { + framework::TensorCopySync(*eff_num_len, platform::CPUPlace(), + &cpu_eff_num_len); + cpu_eff_num_len_data = cpu_eff_num_len.data()[0]; + } + const auto& dev_ctx = + context.template device_context(); + framework::DDim out_dims = phi::make_ddim({cpu_eff_num_len_data}); + auto out_data = out->mutable_data(out_dims, place); + + const T* num_data = numbers->data(); + + int blocks = NumBlocks(numel); + int threads = kNumCUDAThreads; + + AssignPos<<>>(cum_data, num_data, + out_data, numel); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +namespace plat = paddle::platform; +REGISTER_OP_CUDA_KERNEL(assign_pos, ops::AssignPosCUDAKernel); diff --git a/paddle/fluid/operators/assign_pos_op.h b/paddle/fluid/operators/assign_pos_op.h new file mode 100644 index 0000000000000000000000000000000000000000..1a017415778dd058378536284d1a264944c60927 --- /dev/null +++ b/paddle/fluid/operators/assign_pos_op.h @@ -0,0 +1,35 @@ +/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once +#include "paddle/fluid/framework/data_type.h" +#include "paddle/fluid/framework/lod_tensor.h" +#include "paddle/fluid/framework/op_registry.h" + +namespace paddle { +namespace operators { + +using LoDTensor = framework::LoDTensor; + +template +class AssignPosOpCPUKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + PADDLE_THROW(platform::errors::Unavailable( + "Do not support assign pos op for cpu kernel now.")); + } +}; + +} // namespace operators +} // namespace paddle diff --git a/paddle/fluid/operators/number_count_op.cc b/paddle/fluid/operators/number_count_op.cc index 8f7a3b82acf19fa79cbf5c632977e6ae533ae12b..3b7406c997aba2885564f82cae4a21fcc59dcbdc 100644 --- a/paddle/fluid/operators/number_count_op.cc +++ b/paddle/fluid/operators/number_count_op.cc @@ -1,4 +1,4 @@ -// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -22,8 +22,7 @@ class NumberCountOp : public framework::OperatorWithKernel { using framework::OperatorWithKernel::OperatorWithKernel; void InferShape(framework::InferShapeContext* ctx) const override { - OP_INOUT_CHECK(ctx->HasInput("gate_idx"), "Input", "gate_idx", - "NumberCount"); + OP_INOUT_CHECK(ctx->HasInput("numbers"), "Input", "numbers", "NumberCount"); OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "number_count", "NumberCount"); } @@ -31,25 +30,24 @@ class NumberCountOp : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - // the dtype of the gate_idx should be same as int64 - auto gate_idx_dtype = - OperatorWithKernel::IndicateVarDataType(ctx, "gate_idx"); + // the dtype of the numbers should be same as int64 + auto number_dtype = OperatorWithKernel::IndicateVarDataType(ctx, "numbers"); - PADDLE_ENFORCE_EQ(gate_idx_dtype, framework::proto::VarType::INT64, + PADDLE_ENFORCE_EQ(number_dtype, framework::proto::VarType::INT64, platform::errors::InvalidArgument( - "The dtype of the gate_idx_dtype should be int64")); - return framework::OpKernelType(gate_idx_dtype, ctx.GetPlace()); + "The dtype of the number_dtype should be int64")); + return framework::OpKernelType(number_dtype, ctx.GetPlace()); } }; class NumberCountOpMaker : public framework::OpProtoAndCheckerMaker { public: void Make() override { - AddInput("gate_idx", "(Tensor) The input gate index tensor."); - AddOutput("Out", "(Tensor) The output expert count tensor."); - AddAttr("upper_range", "(int), The number of experts."); + AddInput("numbers", "(Tensor) The input gate index tensor."); + AddOutput("Out", "(Tensor) The output number count tensor."); + AddAttr("upper_range", "(int), The number of different numbers."); - AddComment(R"DOC(number_count Operator.count gate indices.)DOC"); + AddComment(R"DOC(number_count Operator.count numbers.)DOC"); } }; diff --git a/paddle/fluid/operators/number_count_op.cu b/paddle/fluid/operators/number_count_op.cu index 97e4b4f2845ae132c28d3bb71dcc8e73f02e193a..0106c70d8eb53888801e942fc6c7c9ca57644062 100644 --- a/paddle/fluid/operators/number_count_op.cu +++ b/paddle/fluid/operators/number_count_op.cu @@ -1,4 +1,4 @@ -// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -38,7 +38,7 @@ __global__ void initialize_zero_kernel(T* data, const int length) { } template -__global__ void NumberCount(const T* gate_idx, T* number_count, +__global__ void NumberCount(const T* numbers, T* number_count, int64_t batch_size, int upper_range) { int res_tmp[PERTHREAD_EXPERTS] = {0}; int expert_min = blockIdx.x * PERTHREAD_EXPERTS; @@ -47,7 +47,7 @@ __global__ void NumberCount(const T* gate_idx, T* number_count, expert_max = upper_range; } for (int i = threadIdx.x; i < batch_size; i += blockDim.x) { - T idx = gate_idx[i]; + T idx = numbers[i]; if (idx == -1) { continue; } @@ -76,18 +76,18 @@ template class NumberCountOpCUDAKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& context) const override { - auto gate_idx = context.Input("gate_idx"); + auto numbers = context.Input("numbers"); auto upper_range = context.Attr("upper_range"); auto number_count = context.Output("Out"); - int64_t batch_size = gate_idx->numel(); + int64_t batch_size = numbers->numel(); auto place = context.GetPlace(); const auto& dev_ctx = context.template device_context(); framework::DDim out_dims = phi::make_ddim({upper_range}); auto out_data = number_count->mutable_data(out_dims, place); - const T* gate_data = gate_idx->data(); + const T* gate_data = numbers->data(); initialize_zero_kernel< T><<>>( diff --git a/paddle/fluid/operators/number_count_op.h b/paddle/fluid/operators/number_count_op.h index 95e64946fb8a2156fdb4cbae880ccf2c143447ed..ded7ea6eec54f7ce08ae610274febdbb4f82d292 100644 --- a/paddle/fluid/operators/number_count_op.h +++ b/paddle/fluid/operators/number_count_op.h @@ -1,4 +1,4 @@ -// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. diff --git a/python/paddle/distributed/models/moe/utils.py b/python/paddle/distributed/models/moe/utils.py index fd98c64318c60e2e67af320c51b24e39a3132c43..28cbfb4f4c74a2080fd2700533cf26a988d3fda7 100644 --- a/python/paddle/distributed/models/moe/utils.py +++ b/python/paddle/distributed/models/moe/utils.py @@ -17,11 +17,11 @@ from paddle.fluid.layer_helper import LayerHelper from paddle.fluid.framework import in_dygraph_mode -def _number_count(gate_idx, upper_range): +def _number_count(numbers, upper_range): """ calculate the expert count according to the gate index. Args: - gate_idx (Tensor): Tensor. The input gate index whose data type should be int32 or int64. + numbers (Tensor): Tensor. The input gate index whose data type should be int32 or int64. upper_range (int): The number of the experts. Returns: out (Tensor): The output expert count. @@ -30,26 +30,75 @@ def _number_count(gate_idx, upper_range): # required: distributed import paddle - gate_idx = [ + numbers = [ [0, 2], [0, 2] ] upper_range = 6 - gate_idx = paddle.to_tensor(gate_idx, dtype="int32") - number_count = paddle.distributed.utils.number_count(gate_idx, upper_range) + numbers = paddle.to_tensor(numbers, dtype="int32") + number_count = paddle.distributed.utils.number_count(numbers, upper_range) print(number_count) # the result: [2, 0, 2, 0, 0, 0] """ if in_dygraph_mode(): - return core.ops.number_count(gate_idx, 'upper_range', upper_range) + return core.ops.number_count(numbers, 'upper_range', upper_range) else: op_type = 'number_count' helper = LayerHelper(op_type, **locals()) - out = helper.create_variable_for_type_inference(dtype=gate_idx.dtype) + out = helper.create_variable_for_type_inference(dtype=numbers.dtype) helper.append_op( type=op_type, - inputs={'gate_idx': gate_idx}, + inputs={'numbers': numbers}, outputs={'Out': out}, attrs={'upper_range': upper_range}) return out + + +def _assign_pos(x, cum_count): + """ + Assign pos decides which tokens should be fetched belong to + specially expert orderingly. + + Args: + x (Tensor): Tensor. Every element in the list must be a Tensor whose data type + should be float16, float32, float64, int32 or int64. + cum_count (Tensor): The cumulative sum tokens of counters. Every element in the list must be a Tensor whose + data type should be int64. + + Returns: + out (Tensor): Assemble numbers in the order of counters. + + Examples: + .. code-block:: python + + # required: distributed + import paddle + number_count = [2, 0, 2, 0] + numbers = [ + [0, 2], + [0, 2] + ] + number_count = paddle.to_tensor(number_count) + numbers = paddle.to_tensor(numbers, dtype="int32") + num_cum = paddle.cumsum(number_count) + pos = paddle.distributed.utils.assign_pos(x=numbers, cum_count=num_cum) + print(pos) # the result: (2, 0, 3, 1) + """ + if in_dygraph_mode(): + return core.ops.assign_pos(x, cum_count, cum_count[-1]) + else: + op_type = 'assign_pos' + + helper = LayerHelper(op_type, **locals()) + out = helper.create_variable_for_type_inference(dtype=cum_count.dtype) + + helper.append_op( + type=op_type, + inputs={ + 'X': [x], + 'cum_count': [cum_count], + "eff_num_len": [cum_count[-1]] + }, + outputs={'Out': [out]}) + return out diff --git a/python/paddle/fluid/tests/unittests/test_assign_pos_op.py b/python/paddle/fluid/tests/unittests/test_assign_pos_op.py new file mode 100644 index 0000000000000000000000000000000000000000..72924f242d211d063b1d547050de79f87f2d8dac --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_assign_pos_op.py @@ -0,0 +1,131 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +import op_test +import numpy as np +import unittest +import paddle +import paddle.fluid.core as core +from paddle.fluid.op import Operator +import paddle.fluid as fluid +from paddle.fluid import compiler, Program, program_guard +from paddle.fluid.backward import append_backward +from paddle.distributed.models.moe import utils + + +def assign_pos(x, _cum_count): + cum_count = np.copy(_cum_count) + x = x.reshape(-1) + res = np.zeros((cum_count[-1], ), dtype=np.int64) + for i, idx in enumerate(x): + p = cum_count[idx] + cum_count[idx] -= 1 + if p >= 1: + res[p - 1] = i + return res + + +def count(x, upper_num): + res = np.zeros((upper_num, )).astype(int) + for i in x.reshape(-1): + if i >= 0 and i < len(res): + res[i] += 1 + return res + + +# why defining the assert function specially? +# Becasue assign_pos_op is multithread-op, which can make the order of numbers +# in each counter(bin) is random. But the numbers set is certain in each counter(bin). +np_allclose = np.allclose + + +def assert_allclose(res, out, cum_count): + c0 = 0 + for c in cum_count: + if c == c0: + continue + data1 = np.copy(res[c0:c]) + data2 = np.copy(out[c0:c]) + data1.sort() + data2.sort() + assert np_allclose(data2, data1) + c0 = c + return True + + +def get_redefined_allclose(cum_count): + def redefined_allclose(x, y, *args, **kwargs): + return assert_allclose(x, y, cum_count) + + return redefined_allclose + + +@unittest.skipIf(not core.is_compiled_with_cuda(), + "core is not compiled with CUDA") +class TestAssignPosOpInt64(op_test.OpTest): + def setUp(self): + x = np.random.randint(0, 16, size=(100, 2)).astype("int64") + y = count(x, 16) + cum_count = np.cumsum(y).astype(x.dtype) + self.op_type = "assign_pos" + self.inputs = { + 'X': x, + "cum_count": cum_count, + "eff_num_len": np.array([cum_count[-1]]) + } + self.outputs = {'Out': assign_pos(x, cum_count)} + self.cum_count = cum_count + + def test_forward(self): + np.allclose = get_redefined_allclose(self.cum_count) + self.check_output_with_place(paddle.CUDAPlace(0)) + + +@unittest.skipIf(not core.is_compiled_with_cuda(), + "core is not compiled with CUDA") +class TestAssignPosAPI(unittest.TestCase): + def setUp(self): + self.x = np.random.randint(0, 16, size=(100, 2)).astype("int64") + y = count(self.x, 16) + self.cum_count = np.cumsum(y).astype(self.x.dtype) + self.out = assign_pos(self.x, self.cum_count) + self.place = paddle.CUDAPlace(0) + + def test_api_static(self): + paddle.enable_static() + with paddle.static.program_guard(paddle.static.Program()): + x = paddle.fluid.data('x', self.x.shape, dtype="int64") + cum_count = paddle.fluid.data( + 'cum_count', self.cum_count.shape, dtype="int64") + out = utils._assign_pos(x, cum_count) + exe = paddle.static.Executor(self.place) + res = exe.run(feed={'x': self.x, + "cum_count": self.cum_count}, + fetch_list=[out]) + assert_allclose(res[0], self.out, self.cum_count) + + def test_api_dygraph(self): + paddle.disable_static() + x = paddle.to_tensor(self.x) + cum_count = paddle.to_tensor(self.cum_count).astype(x.dtype) + + out = utils._assign_pos(x, cum_count) + assert_allclose(out.numpy(), self.out, self.cum_count) + + +if __name__ == '__main__': + paddle.enable_static() + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_number_count_op.py b/python/paddle/fluid/tests/unittests/test_number_count_op.py index 0df9d2a3a41b44c18b7e008a271c10544ec4dfa0..9eb89dfeb0e8d9e4538f3a7004da777eafbb2f34 100644 --- a/python/paddle/fluid/tests/unittests/test_number_count_op.py +++ b/python/paddle/fluid/tests/unittests/test_number_count_op.py @@ -1,4 +1,4 @@ -# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -26,8 +26,8 @@ from paddle.fluid.backward import append_backward from paddle.distributed.models.moe import utils -def count(x, upper_range): - res = np.zeros((upper_range, )).astype(int) +def count(x, upper_num): + res = np.zeros((upper_num, )).astype(int) for i in x.reshape(-1): if i >= 0 and i < len(res): res[i] += 1 @@ -36,14 +36,14 @@ def count(x, upper_range): @unittest.skipIf(not core.is_compiled_with_cuda(), "core is not compiled with CUDA") -class TestExpertCountOpInt64(op_test.OpTest): +class TestNumberCountOpInt64(op_test.OpTest): def setUp(self): - expert_num = 16 + upper_num = 16 self.op_type = "number_count" - x = np.random.randint(-1, expert_num, size=(1000, 2)).astype('int64') - self.inputs = {'gate_idx': x} - self.outputs = {'Out': count(x, expert_num)} - self.attrs = {"upper_range": expert_num} + x = np.random.randint(-1, upper_num, size=(1000, 2)).astype('int64') + self.inputs = {'numbers': x} + self.outputs = {'Out': count(x, upper_num)} + self.attrs = {"upper_range": upper_num} def test_forward(self): self.check_output_with_place(paddle.CUDAPlace(0)) @@ -51,19 +51,19 @@ class TestExpertCountOpInt64(op_test.OpTest): @unittest.skipIf(not core.is_compiled_with_cuda(), "core is not compiled with CUDA") -class TestExpertCountAPI(unittest.TestCase): +class TestNumberCountAPI(unittest.TestCase): def setUp(self): - self.upper_range = 320 + self.upper_num = 320 self.x = np.random.randint( - -1, self.upper_range, size=(6000, 200)).astype('int64') - self.out = count(self.x, self.upper_range) + -1, self.upper_num, size=(6000, 200)).astype('int64') + self.out = count(self.x, self.upper_num) self.place = paddle.CUDAPlace(0) def test_api_static(self): paddle.enable_static() with paddle.static.program_guard(paddle.static.Program()): x = paddle.fluid.data('x', self.x.shape, dtype="int64") - out = utils._number_count(x, self.upper_range) + out = utils._number_count(x, self.upper_num) exe = paddle.static.Executor(self.place) res = exe.run(feed={'x': self.x}, fetch_list=[out]) assert np.allclose(res, self.out) @@ -71,7 +71,7 @@ class TestExpertCountAPI(unittest.TestCase): def test_api_dygraph(self): paddle.disable_static() x = paddle.to_tensor(self.x) - out = utils._number_count(x, self.upper_range) + out = utils._number_count(x, self.upper_num) assert np.allclose(out.numpy(), self.out)