diff --git a/paddle/fluid/operators/shard_index_op_npu.cc b/paddle/fluid/operators/shard_index_op_npu.cc new file mode 100644 index 0000000000000000000000000000000000000000..83b5d12330d673ba6271d29442adae4bfd930cb4 --- /dev/null +++ b/paddle/fluid/operators/shard_index_op_npu.cc @@ -0,0 +1,118 @@ +// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/operators/shard_index_op.h" +#include "paddle/fluid/operators/npu_op_runner.h" + +namespace paddle { +namespace operators { + +using LoDTensor = framework::LoDTensor; +using Tensor = framework::Tensor; +template +class ShardIndexNPUKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& context) const override { + VLOG(4) << "start kernel"; + auto* in = context.Input("X"); + auto* out = context.Output("Out"); + int index_num = context.Attr("index_num"); + int nshards = context.Attr("nshards"); + int shard_id = context.Attr("shard_id"); + int ignore_value = context.Attr("ignore_value"); + + PADDLE_ENFORCE_GT( + index_num, 0, + platform::errors::InvalidArgument( + "The value 'index_num' for Op(shard_index) must be greater than 0, " + "but the value given is %d.", + index_num)); + PADDLE_ENFORCE_GT(nshards, 0, + platform::errors::InvalidArgument( + "The value 'nshard' for Op(shard_index) must be " + "greater than 0, but the value given is %d.", + nshards)); + PADDLE_ENFORCE_GE( + shard_id, 0, + platform::errors::InvalidArgument( + "The value 'shard_id' for Op(shard_index) must be greater or " + "equal to 0, but the value given is %d.", + shard_id)); + PADDLE_ENFORCE_LT( + shard_id, nshards, + platform::errors::InvalidArgument( + "The value 'shard_id' for Op(shard_index) must be less than " + "nshards (%d), but the value given is %d.", + nshards, shard_id)); + + int shard_size = (index_num + nshards - 1) / nshards; + + auto place = context.GetPlace(); + out->Resize(in->dims()); + out->set_lod(in->lod()); + out->mutable_data(place); + + Tensor tmp(in->type()); + tmp.mutable_data(framework::DDim({1}), place); + FillNpuTensorWithConstant(&tmp, shard_size); + + Tensor condition(framework::proto::VarType::BOOL); + condition.mutable_data(in->dims(), place); + + Tensor tmp2(in->type()); + tmp2.mutable_data(in->dims(), place); + + Tensor tmp3(in->type()); + tmp3.mutable_data(in->dims(), place); + + auto stream = + context.template device_context() + .stream(); + + NpuOpRunner runner; + runner.AddInputs({*in, tmp}); + runner.AddOutputs({tmp2}); + runner.SetType("Mod"); + runner.Run(stream); + + NpuOpRunner runner1; + runner1.AddInputs({*in, tmp}); + runner1.AddOutputs({tmp3}); + runner1.SetType("FloorDiv"); + runner1.Run(stream); + + FillNpuTensorWithConstant(&tmp, shard_id); + NpuOpRunner runner2; + runner2.AddInputs({tmp3, tmp}); + runner2.AddOutputs({condition}); + runner2.SetType("Equal"); + runner2.Run(stream); + + Tensor tmp4(in->type()); + tmp4.mutable_data(in->dims(), place); + FillNpuTensorWithConstant(&tmp4, ignore_value); + tmp4.Resize(in->dims()); + + NpuOpRunner runner3; + runner3.AddInputs({condition, tmp2, tmp4}); + runner3.AddOutputs({*out}); + runner3.SetType("Select"); + runner3.Run(stream); + } +}; +} // namespace operators +} // namespace paddle +namespace ops = paddle::operators; +REGISTER_OP_NPU_KERNEL(shard_index, ops::ShardIndexNPUKernel, + ops::ShardIndexNPUKernel); diff --git a/python/paddle/fluid/tests/unittests/npu/test_shard_index_op.py b/python/paddle/fluid/tests/unittests/npu/test_shard_index_op.py new file mode 100644 index 0000000000000000000000000000000000000000..ce7e962624a4659ce25b0fea78a6015bbafcd052 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/npu/test_shard_index_op.py @@ -0,0 +1,84 @@ +# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +import unittest +import numpy as np +import math +import sys +sys.path.append("..") +from op_test import OpTest +import paddle.fluid as fluid +import paddle.fluid.core as core +import paddle.fluid.framework as framework +from paddle.fluid.framework import Program, program_guard +import paddle +paddle.enable_static() +SEED = 2021 + + +def common_setup(self, index_num, nshards, shard_id, ignore_value): + self.__class__.use_npu = True + self.__class__.op_type = "shard_index" + + self.op_type = 'shard_index' + x_lod = [[i for i in range(10)]] + N = sum(x_lod[0]) + x = [np.random.randint(0, index_num - 1) for i in range(N)] + x = np.array(x).astype('int32').reshape([N, 1]) + + shard_size = (index_num + nshards - 1) // nshards + out = np.zeros(shape=x.shape).astype('int32') + for i in range(N): + if x[i] // shard_size == shard_id: + out[i] = x[i] % shard_size + else: + out[i] = ignore_value + + self.inputs = {'X': (x, x_lod)} + self.attrs = { + 'index_num': index_num, + 'nshards': nshards, + 'shard_id': shard_id, + 'ignore_value': ignore_value + } + self.outputs = {'Out': (out, x_lod)} + + +class TestShardIndexShardId0Op(OpTest): + def setUp(self): + common_setup(self, 20, 2, 0, -1) + + def test_check_output(self): + return self.check_output_with_place(place=paddle.NPUPlace(0)) + + +class TestShardIndexShardId1Op(TestShardIndexShardId0Op): + def setUp(self): + common_setup(self, 20, 2, 1, -1) + + +class TestShardIndexIgnoreValueOp(TestShardIndexShardId0Op): + def setUp(self): + common_setup(self, 20, 2, 0, -2) + + +class TestShardIndexNotEvenlyDividedOp(TestShardIndexShardId0Op): + def setUp(self): + common_setup(self, 15, 2, 1, -1) + + +if __name__ == '__main__': + unittest.main()