diff --git a/paddle/fluid/operators/lookup_table_v2_op_mlu.cc b/paddle/fluid/operators/lookup_table_v2_op_mlu.cc
new file mode 100644
index 0000000000000000000000000000000000000000..c8ab269c023a5b87bf8314b2f66e3e00bb876f86
--- /dev/null
+++ b/paddle/fluid/operators/lookup_table_v2_op_mlu.cc
@@ -0,0 +1,129 @@
+/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "paddle/fluid/framework/op_registry.h"
+#include "paddle/fluid/operators/mlu/mlu_baseop.h"
+
+namespace paddle {
+namespace operators {
+
+using Tensor = framework::Tensor;
+constexpr int64_t kNoPadding = -1;
+
+template <typename T>
+class LookupTableV2MLUKernel : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext &ctx) const override {
+    auto *ids_t = ctx.Input<framework::LoDTensor>("Ids");      // int tensor
+    auto *output_t = ctx.Output<framework::LoDTensor>("Out");  // float tensor
+    auto *table_t = ctx.Input<framework::LoDTensor>("W");
+
+    auto *table_var = ctx.InputVar("W");
+    PADDLE_ENFORCE_EQ(
+        table_var->IsType<framework::LoDTensor>(), true,
+        platform::errors::InvalidArgument("mlu only accept LoDTensor"));
+    output_t->mutable_data<T>(ctx.GetPlace());
+
+    MLUCnnlTensorDesc ids_desc(*ids_t);
+    MLUCnnlTensorDesc table_desc(*table_t);
+    MLUCnnlTensorDesc output_desc(*output_t);
+
+    int64_t padding_idx = ctx.Attr<int64_t>("padding_idx");
+    if (padding_idx == kNoPadding) {
+      MLUCnnl::GatherFunctor(ctx, /*axis=*/0, /*batch_dims=*/0,
+                             table_desc.get(), GetBasePtr(table_t),
+                             ids_desc.get(), GetBasePtr(ids_t),
+                             output_desc.get(), GetBasePtr(output_t));
+    } else {
+      Tensor tmp_table_t(table_t->type());
+      tmp_table_t.mutable_data<T>(table_t->dims(), ctx.GetPlace());
+
+      Tensor index;
+      index.mutable_data<int32_t>({1, 1}, ctx.GetPlace());
+      auto idx_value = static_cast<int32_t>(padding_idx);
+      MLUCnnlTensorDesc index_desc(index);
+      MLUCnnl::Fill(ctx, CNNL_POINTER_MODE_HOST, &idx_value, index_desc.get(),
+                    GetBasePtr(&index));
+
+      auto update_dim = phi::make_ddim({1, table_t->dims()[1]});
+      Tensor update;
+      update.mutable_data<T>(update_dim, ctx.GetPlace());
+
+      auto update_value = static_cast<T>(0);
+      MLUCnnlTensorDesc update_desc(update);
+      MLUCnnl::Fill(ctx, CNNL_POINTER_MODE_HOST, &update_value,
+                    update_desc.get(), GetBasePtr(&update));
+
+      MLUCnnlTensorDesc tmp_table_desc(tmp_table_t);
+      MLUCnnl::ScatterNd(
+          ctx, CNNL_SCATTERND_UPDATE, index_desc.get(), GetBasePtr(&index),
+          update_desc.get(), GetBasePtr(&update), table_desc.get(),
+          GetBasePtr(table_t), tmp_table_desc.get(), GetBasePtr(&tmp_table_t));
+
+      MLUCnnl::GatherFunctor(ctx, /*axis=*/0, /*batch_dims=*/0,
+                             tmp_table_desc.get(), GetBasePtr(&tmp_table_t),
+                             ids_desc.get(), GetBasePtr(ids_t),
+                             output_desc.get(), GetBasePtr(output_t));
+    }
+  }
+};
+
+template <typename T>
+class LookupTableV2GradMLUKernel : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext &ctx) const override {
+    auto *ids_t = ctx.Input<framework::LoDTensor>("Ids");
+    auto *output_grad_t =
+        ctx.Input<framework::LoDTensor>(framework::GradVarName("Out"));
+    auto *table_grad_t =
+        ctx.Output<framework::LoDTensor>(framework::GradVarName("W"));
+    table_grad_t->mutable_data<T>(ctx.GetPlace());
+
+    int padding_idx = static_cast<int>(ctx.Attr<int64_t>("padding_idx"));
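+    // Ids are cast to int32 below before calling MLUCnnl::EmbeddingBackward;
+    // int32 Ids are passed through unchanged.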
+
+    Tensor ids_int32(ids_t->dtype());
+    if (ids_t->dtype() != DataType::INT32) {
+      ids_int32.mutable_data<int>(ids_t->dims(), ctx.GetPlace());
+      MLUCnnlTensorDesc ids_desc(*ids_t);
+      MLUCnnlTensorDesc ids_int32_desc(ids_int32);
+      auto cast_type = GetCastDataType(ids_t->dtype(), DataType::INT32);
+      MLUCnnl::Cast(ctx, cast_type, ids_desc.get(), GetBasePtr(ids_t),
+                    ids_int32_desc.get(), GetBasePtr(&ids_int32));
+    } else {
+      ids_int32 = *ids_t;
+    }
+
+    MLUCnnlTensorDesc ids_int32_desc(ids_int32);
+    MLUCnnlTensorDesc output_grad_desc(*output_grad_t);
+    MLUCnnlTensorDesc table_grad_desc(*table_grad_t);
+
+    MLUCnnl::EmbeddingBackward(ctx, padding_idx, false, ids_int32_desc.get(),
+                               GetBasePtr(&ids_int32), output_grad_desc.get(),
+                               GetBasePtr(output_grad_t), table_grad_desc.get(),
+                               GetBasePtr(table_grad_t));
+  }
+};
+}  // namespace operators
+}  // namespace paddle
+
+namespace ops = paddle::operators;
+namespace plat = paddle::platform;
+
+REGISTER_OP_MLU_KERNEL(lookup_table_v2, ops::LookupTableV2MLUKernel<float>,
+                       ops::LookupTableV2MLUKernel<int>,
+                       ops::LookupTableV2MLUKernel<plat::float16>);
+
+REGISTER_OP_MLU_KERNEL(lookup_table_v2_grad,
+                       ops::LookupTableV2GradMLUKernel<float>,
+                       ops::LookupTableV2GradMLUKernel<int>,
+                       ops::LookupTableV2GradMLUKernel<plat::float16>);
diff --git a/paddle/fluid/operators/mlu/mlu_baseop.cc b/paddle/fluid/operators/mlu/mlu_baseop.cc
index 867c5f212ba6c17514b65f83d40fe356b5d04146..9d3b8e2407fbfb8aa4cd5eeb640ab06ae961eb1c 100644
--- a/paddle/fluid/operators/mlu/mlu_baseop.cc
+++ b/paddle/fluid/operators/mlu/mlu_baseop.cc
@@ -34,6 +34,12 @@ cnnlCastDataType_t GetCastDataType(const VT::Type& src_type,
   return cast_type;
 }
 
+cnnlCastDataType_t GetCastDataType(const DataType& src_type,
+                                   const DataType& dst_type) {
+  return GetCastDataType(framework::TransToProtoVarType(src_type),
+                         framework::TransToProtoVarType(dst_type));
+}
+
 bool MLUSupportsCast(const VT::Type& src_type, const VT::Type& dst_type) {
   for (auto it = MLU_SUPPORTED_CAST_TYPE.begin();
        it != MLU_SUPPORTED_CAST_TYPE.end(); ++it) {
@@ -2713,17 +2719,16 @@ MLUCnnlTrigonDesc::~MLUCnnlTrigonDesc() {
                                          output_desc, output));
 }
 
-/* static */ void MLUCnnl::ScatterNd(const ExecutionContext& ctx,
-                                     const cnnlTensorDescriptor_t indices_desc,
-                                     const void* indices,
-                                     const cnnlTensorDescriptor_t updates_desc,
-                                     const void* updates,
-                                     const cnnlTensorDescriptor_t output_desc,
-                                     void* output) {
+/* static */ void MLUCnnl::ScatterNd(
+    const ExecutionContext& ctx, cnnlScatterNdMode_t mode,
+    const cnnlTensorDescriptor_t indices_desc, const void* indices,
+    const cnnlTensorDescriptor_t updates_desc, const void* updates,
+    const cnnlTensorDescriptor_t input_desc, const void* input,
+    const cnnlTensorDescriptor_t output_desc, void* output) {
   cnnlHandle_t handle = GetHandleFromCTX(ctx);
-  PADDLE_ENFORCE_MLU_SUCCESS(cnnlScatterNd(handle, indices_desc, indices,
-                                           updates_desc, updates, output_desc,
-                                           output));
+  PADDLE_ENFORCE_MLU_SUCCESS(
+      cnnlScatterNd_v2(handle, mode, indices_desc, indices, updates_desc,
+                       updates, input_desc, input, output_desc, output));
 }
 
 /* static */ void MLUCnnl::BitWise(
@@ -2777,5 +2782,26 @@ MLUCnnlTrigonDesc::~MLUCnnlTrigonDesc() {
       cnnlReciprocal(handle, input_desc, input, output_desc, output));
 }
 
+/* static */ void MLUCnnl::EmbeddingBackward(
+    const ExecutionContext& ctx, int padding_idx, bool scale_grad_by_freq,
+    const cnnlTensorDescriptor_t indices_desc, const void* indices,
+    const cnnlTensorDescriptor_t diff_desc, const void* diff,
+    const cnnlTensorDescriptor_t output_desc, void* output) {
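+  // Query the CNNL scratch-space requirement first, back it with a temporary
+  // tensor, then run cnnlEmbeddingBackward with that workspace.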
+  cnnlHandle_t handle = GetHandleFromCTX(ctx);
+
+  size_t workspace_size;
+  PADDLE_ENFORCE_MLU_SUCCESS(cnnlGetEmbeddingBackwardWorkspaceSize(
+      handle, diff_desc, output_desc, scale_grad_by_freq, &workspace_size));
+
+  auto& dev_ctx = GetDevCtxFromCTX(ctx);
+  Tensor workspace = ctx.AllocateTmpTensor<int8_t, MLUDeviceContext>(
+      {static_cast<int64_t>(workspace_size)}, dev_ctx);
+  void* workspace_ptr = workspace.mutable_data(ctx.GetPlace());
+
+  PADDLE_ENFORCE_MLU_SUCCESS(cnnlEmbeddingBackward(
+      handle, padding_idx, scale_grad_by_freq, indices_desc, indices, diff_desc,
+      diff, workspace_ptr, workspace_size, output_desc, output));
+}
+
 }  // namespace operators
 }  // namespace paddle
diff --git a/paddle/fluid/operators/mlu/mlu_baseop.h b/paddle/fluid/operators/mlu/mlu_baseop.h
index 24db6c760d78abb4c317b42715daf20575994aee..f048ac7c5c3be08e034c7b2a3b163888f9e9e982 100644
--- a/paddle/fluid/operators/mlu/mlu_baseop.h
+++ b/paddle/fluid/operators/mlu/mlu_baseop.h
@@ -175,6 +175,10 @@ const std::map<std::pair<VT::Type, VT::Type>, cnnlCastDataType_t>
 
 cnnlCastDataType_t GetCastDataType(const VT::Type& src_type,
                                    const VT::Type& dst_type);
+
+cnnlCastDataType_t GetCastDataType(const DataType& src_type,
+                                   const DataType& dst_type);
+
 bool MLUSupportsCast(const VT::Type& src_type, const VT::Type& dst_type);
 
 cnnlDeviceType_t GetCnnlDev(int dev_ordinal);
@@ -1202,11 +1206,13 @@ class MLUCnnl {
                     const void* k, const int k_int,
                     const cnnlTensorDescriptor_t output_desc, void* output);
 
-  static void ScatterNd(const ExecutionContext& ctx,
+  static void ScatterNd(const ExecutionContext& ctx, cnnlScatterNdMode_t mode,
                         const cnnlTensorDescriptor_t indices_desc,
                         const void* indices,
                         const cnnlTensorDescriptor_t updates_desc,
                         const void* updates,
+                        const cnnlTensorDescriptor_t input_desc,
+                        const void* input,
                         const cnnlTensorDescriptor_t output_desc, void* output);
 
   static void BitWise(const ExecutionContext& ctx,
@@ -1227,6 +1233,12 @@ class MLUCnnl {
                          const void* input,
                          const cnnlTensorDescriptor_t output_desc,
                          void* output);
+
+  static void EmbeddingBackward(
+      const ExecutionContext& ctx, int padding_idx, bool scale_grad_by_freq,
+      const cnnlTensorDescriptor_t indices_desc, const void* indices,
+      const cnnlTensorDescriptor_t diff_desc, const void* diff,
+      const cnnlTensorDescriptor_t output_desc, void* output);
 };
 
 template <typename T>
diff --git a/paddle/fluid/operators/unstack_op_mlu.cc b/paddle/fluid/operators/unstack_op_mlu.cc
new file mode 100644
index 0000000000000000000000000000000000000000..9c4dd256a94efea0949f1b81afe232227ee96be0
--- /dev/null
+++ b/paddle/fluid/operators/unstack_op_mlu.cc
@@ -0,0 +1,95 @@
+/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
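+
+// UnStack forward maps to MLUCnnl::Split along `axis`; the grad kernel glues
+// the output gradients back together with MLUCnnl::Concat.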
+
+#include "paddle/fluid/framework/op_registry.h"
+#include "paddle/fluid/operators/mlu/mlu_baseop.h"
+
+namespace paddle {
+namespace operators {
+
+template <typename T>
+class UnStackMLUKernel : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext &ctx) const override {
+    auto *x = ctx.Input<framework::Tensor>("X");
+    auto out = ctx.MultiOutput<framework::Tensor>("Y");
+    int axis = ctx.Attr<int>("axis");
+    if (axis < 0) axis += x->dims().size();
+    int num = x->dims()[axis];
+
+    std::vector<MLUCnnlTensorDesc> out_descs;
+    std::vector<cnnlTensorDescriptor_t> out_raw_descs;
+    std::vector<void *> out_ptrs;
+    std::vector<int64_t> new_dims = phi::vectorize(x->dims());
+    new_dims[axis] = 1;
+    for (int i = 0; i < num; i++) {
+      out[i]->mutable_data<T>(ctx.GetPlace());
+      out_descs.emplace_back(MLUCnnlTensorDesc(new_dims.size(), new_dims.data(),
+                                               ToCnnlDataType<T>()));
+      out_raw_descs.push_back(out_descs.back().get());
+      out_ptrs.push_back(GetBasePtr(out[i]));
+    }
+
+    MLUCnnlTensorDesc x_desc(*x);
+    MLUCnnl::Split(ctx, num, axis, x_desc.get(), GetBasePtr(x),
+                   out_raw_descs.data(), out_ptrs.data());
+  }
+};
+
+template <typename T>
+class UnStackGradMLUKernel : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext &ctx) const override {
+    auto x = ctx.MultiInput<framework::Tensor>(framework::GradVarName("Y"));
+    auto *y = ctx.Output<framework::Tensor>(framework::GradVarName("X"));
+    int axis = ctx.Attr<int>("axis");
+    if (axis < 0) axis += (x[0]->dims().size() + 1);
+    int num = static_cast<int>(x.size());
+
+    std::vector<MLUCnnlTensorDesc> x_descs;
+    std::vector<cnnlTensorDescriptor_t> x_raw_descs;
+    std::vector<const void *> x_ptrs;
+    for (int i = 0; i < num; i++) {
+      if (x[i]->dims().size() != 0) {
+        std::vector<int64_t> in_dims = phi::vectorize(x[i]->dims());
+        in_dims.insert(in_dims.begin() + axis, 1);
+        x_descs.emplace_back(MLUCnnlTensorDesc(in_dims.size(), in_dims.data(),
+                                               ToCnnlDataType<T>()));
+      } else {
+        int input_dims = 1;
+        x_descs.emplace_back(
+            MLUCnnlTensorDesc(1, &input_dims, ToCnnlDataType<T>()));
+      }
+      x_raw_descs.push_back(x_descs.back().get());
+      x_ptrs.push_back(GetBasePtr(x[i]));
+    }
+    y->mutable_data<T>(ctx.GetPlace());
+
+    MLUCnnlTensorDesc y_desc(*y);
+    MLUCnnl::Concat(ctx, num, axis, x_raw_descs.data(), x_ptrs.data(),
+                    y_desc.get(), GetBasePtr(y));
+  }
+};
+
+}  // namespace operators
+}  // namespace paddle
+
+namespace plat = paddle::platform;
+namespace ops = paddle::operators;
+
+REGISTER_OP_MLU_KERNEL(unstack, ops::UnStackMLUKernel<float>,
+                       ops::UnStackMLUKernel<plat::float16>);
+
+REGISTER_OP_MLU_KERNEL(unstack_grad, ops::UnStackGradMLUKernel<float>,
+                       ops::UnStackGradMLUKernel<plat::float16>);
diff --git a/python/paddle/fluid/tests/unittests/mlu/test_lookup_table_v2_op_mlu.py b/python/paddle/fluid/tests/unittests/mlu/test_lookup_table_v2_op_mlu.py
new file mode 100644
index 0000000000000000000000000000000000000000..f9a08ba4c9b146534fbd27361fdc2fc2a68f87d9
--- /dev/null
+++ b/python/paddle/fluid/tests/unittests/mlu/test_lookup_table_v2_op_mlu.py
@@ -0,0 +1,142 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
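+#
+# These tests exercise the MLU lookup_table_v2 forward and backward kernels
+# through OpTest: fp32/fp16 tables, int32/int64 ids, and padding_idx on/off.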
+
+from __future__ import print_function
+
+import numpy as np
+import unittest
+import sys
+sys.path.append("..")
+from op_test import OpTest
+import paddle
+import paddle.fluid as fluid
+
+paddle.enable_static()
+SEED = 2022
+
+
+class TestLookupTableV2(OpTest):
+    def setUp(self):
+        self.set_mlu()
+        self.op_type = "lookup_table_v2"
+
+        self.init_dtype()
+        self.init_dims()
+        self.init_padding_idx()
+        np.random.seed(SEED)
+        w = np.random.random([self.vocab, self.dim]).astype(self.dtype)
+        x = np.random.randint(
+            0, self.vocab, size=(self.bsz, self.seqlen)).astype(self.ids_dtype)
+        out = w[x]
+        if self.padding_idx != -1:
+            out[np.squeeze(x == self.padding_idx)] = np.zeros(self.dim)
+
+        self.inputs = {
+            'W': OpTest.np_dtype_to_fluid_dtype(w),
+            'Ids': OpTest.np_dtype_to_fluid_dtype(x)
+        }
+        self.attrs = {
+            'is_sparse': False,
+            'is_distributed': False,
+            'remote_prefetch': False,
+            'padding_idx': self.padding_idx
+        }
+        self.outputs = {'Out': out}
+
+    def set_mlu(self):
+        self.__class__.use_mlu = True
+        self.place = paddle.device.MLUPlace(0)
+
+    def init_dtype(self):
+        self.dtype = np.float32
+        self.ids_dtype = np.int32
+
+    def init_dims(self):
+        self.bsz = 6
+        self.seqlen = 8
+        self.vocab = 10
+        # embedding_dim is not multiple of 32
+        self.dim = 20
+
+    def init_padding_idx(self):
+        self.padding_idx = -1
+
+    def test_check_output(self):
+        self.check_output_with_place(self.place)
+
+    def test_check_grad(self):
+        if self.dtype == np.float16:
+            self.check_grad_with_place(
+                self.place, ['W'], 'Out', max_relative_error=0.01)
+        else:
+            self.check_grad_with_place(self.place, ['W'], 'Out')
+
+
+class TestLookupTableV2FP16(TestLookupTableV2):
+    no_need_check_grad = True
+
+    def init_dtype(self):
+        self.dtype = np.float16
+        self.ids_dtype = np.int32
+
+    def set_mlu(self):
+        self.__class__.use_mlu = True
+        self.place = paddle.device.MLUPlace(0)
+        self.__class__.no_need_check_grad = True
+
+
+class TestLookupTableV2Dim32(TestLookupTableV2):
+    def init_dims(self):
+        self.bsz = 6
+        self.seqlen = 8
+        self.vocab = 10
+        # embedding_dim is multiple of 32
+        self.dim = 64
+
+
+class TestLookupTableV2Dim32FP16(TestLookupTableV2):
+    no_need_check_grad = True
+
+    def init_dtype(self):
+        self.dtype = np.float16
+        self.ids_dtype = np.int64
+
+    def init_dims(self):
+        self.bsz = 6
+        self.seqlen = 8
+        self.vocab = 10
+        self.dim = 64
+
+    def set_mlu(self):
+        self.__class__.use_mlu = True
+        self.place = paddle.device.MLUPlace(0)
+        self.__class__.no_need_check_grad = True
+
+
+class TestLookupTableV2WithPadding(TestLookupTableV2):
+    def init_padding_idx(self):
+        self.padding_idx = np.random.randint(0, self.vocab)
+
+
+class TestLookupTableV2WithPadding1(TestLookupTableV2):
+    def init_padding_idx(self):
+        self.padding_idx = np.random.randint(0, self.vocab)
+
+    def init_dtype(self):
+        self.dtype = np.float32
+        self.ids_dtype = np.int64
+
+
+if __name__ == '__main__':
+    unittest.main()
diff --git a/python/paddle/fluid/tests/unittests/mlu/test_unstack_op_mlu.py b/python/paddle/fluid/tests/unittests/mlu/test_unstack_op_mlu.py
new file mode 100644
index 0000000000000000000000000000000000000000..a75a6aa1dfcb92b516e968d845b56cc90624cdd5
--- /dev/null
+++ b/python/paddle/fluid/tests/unittests/mlu/test_unstack_op_mlu.py
@@ -0,0 +1,97 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import print_function
+
+import numpy as np
+import sys
+sys.path.append("..")
+from op_test import OpTest
+import unittest
+import paddle
+
+paddle.enable_static()
+
+
+class TestUnStackOpBase(OpTest):
+    def initDefaultParameters(self):
+        self.input_dim = (5, 6, 7)
+        self.axis = 0
+
+    def initParameters(self):
+        pass
+
+    def get_y_names(self):
+        y_names = []
+        for i in range(self.input_dim[self.axis]):
+            y_names.append('y{}'.format(i))
+        return y_names
+
+    def setUp(self):
+        self.initDefaultParameters()
+        self.initParameters()
+        self.op_type = 'unstack'
+        self.set_mlu()
+        self.init_dtype()
+
+        self.x = np.random.random(size=self.input_dim).astype(self.dtype)
+
+        outs = np.split(self.x, self.input_dim[self.axis], self.axis)
+        new_shape = list(self.input_dim)
+        del new_shape[self.axis]
+        y_names = self.get_y_names()
+        tmp = []
+        for i in range(self.input_dim[self.axis]):
+            tmp.append((y_names[i], np.reshape(outs[i], new_shape)))
+
+        self.inputs = {'X': self.x}
+        self.outputs = {'Y': tmp}
+        self.attrs = {'axis': self.axis, 'num': self.input_dim[self.axis]}
+
+    def set_mlu(self):
+        self.__class__.use_mlu = True
+        self.place = paddle.device.MLUPlace(0)
+
+    def init_dtype(self):
+        self.dtype = np.float32
+
+    def test_check_output(self):
+        self.check_output_with_place(self.place)
+
+    def test_check_grad(self):
+        self.check_grad_with_place(self.place, ['X'], self.get_y_names())
+
+
+class TestStackOp3(TestUnStackOpBase):
+    def initParameters(self):
+        self.axis = -1
+
+
+class TestStackOp4(TestUnStackOpBase):
+    def initParameters(self):
+        self.axis = -3
+
+
+class TestStackOp5(TestUnStackOpBase):
+    def initParameters(self):
+        self.axis = 1
+
+
+class TestStackOp6(TestUnStackOpBase):
+    def initParameters(self):
+        self.axis = 2
+
+
+if __name__ == '__main__':
+    unittest.main()
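
Usage sketch (not part of the patch): a minimal example of how the two kernels would be reached from user code on an MLU build. The device string, shapes, and padding_idx value are illustrative assumptions, and eager mode is assumed to dispatch embedding/unstack to the fluid ops registered above.

    import numpy as np
    import paddle

    # Assumption: this PaddlePaddle build was compiled with MLU support.
    paddle.set_device('mlu:0')

    # nn.Embedding should dispatch to the lookup_table_v2 /
    # lookup_table_v2_grad MLU kernels added in this patch.
    emb = paddle.nn.Embedding(10, 20, padding_idx=3)
    ids = paddle.to_tensor(np.random.randint(0, 10, (6, 8)).astype('int64'))
    out = emb(ids)           # forward: row gather from the 10x20 table
    out.sum().backward()     # backward: cnnlEmbeddingBackward path

    # paddle.unstack maps to the unstack kernel: splits `out` along axis 0
    # into 6 tensors of shape [8, 20].
    pieces = paddle.unstack(out, axis=0)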