diff --git a/paddle/fluid/operators/coalesce_tensor_op.cc b/paddle/fluid/operators/coalesce_tensor_op.cc index 4a11e6d5723bd9e0e20dfc59d67d12ab17930934..bdc0f4799bcd861f7a7b7acdb3d8ba234cd0c780 100644 --- a/paddle/fluid/operators/coalesce_tensor_op.cc +++ b/paddle/fluid/operators/coalesce_tensor_op.cc @@ -28,6 +28,8 @@ #ifdef PADDLE_WITH_MLU #include "paddle/fluid/operators/mlu/mlu_baseop.h" #endif +#include "paddle/fluid/framework/infershape_utils.h" +#include "paddle/phi/infermeta/multiary.h" namespace paddle { namespace operators { @@ -506,24 +508,16 @@ value. } // namespace operators } // namespace paddle +DECLARE_INFER_SHAPE_FUNCTOR(coalesce_tensor, + CoalesceTensorInferShapeFunctor, + PD_INFER_META(phi::CoalesceTensorInferMeta)); + REGISTER_OPERATOR(coalesce_tensor, paddle::operators::CoalesceTensorOp, - paddle::operators::CoalesceTensorOpMaker); + paddle::operators::CoalesceTensorOpMaker, + CoalesceTensorInferShapeFunctor); namespace ops = paddle::operators; namespace plat = paddle::platform; -REGISTER_OP_CPU_KERNEL(coalesce_tensor, - ops::CoalesceTensorOpKernel, - ops::CoalesceTensorOpKernel, - ops::CoalesceTensorOpKernel); - -#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) -REGISTER_OP_CUDA_KERNEL( - coalesce_tensor, - ops::CoalesceTensorOpKernel, - ops::CoalesceTensorOpKernel, - ops::CoalesceTensorOpKernel, - ops::CoalesceTensorOpKernel); -#endif #if defined(PADDLE_WITH_ASCEND_CL) REGISTER_OP_CUDA_KERNEL( diff --git a/paddle/phi/api/yaml/legacy_api.yaml b/paddle/phi/api/yaml/legacy_api.yaml index 44096f0ef1e110586f7a5ab6c1887cbf43cf90a1..6fc34172d0434ea5357da51b69d5b8619ec3dc8c 100755 --- a/paddle/phi/api/yaml/legacy_api.yaml +++ b/paddle/phi/api/yaml/legacy_api.yaml @@ -493,6 +493,15 @@ kernel : func : clip_by_norm +- api : coalesce_tensor + args : (Tensor[] input, DataType dtype, bool copy_data = false, bool set_constant = false, bool persist_output = false, float constant = 0.0, bool use_align = true, int align_size = -1, int size_of_dtype = -1, int64_t[] concated_shapes = {}, int64_t[] concated_ranks = {}) + output : Tensor[](output){input.size()}, Tensor(fused_output) + infer_meta : + func : CoalesceTensorInferMeta + kernel : + func : coalesce_tensor + data_type : dtype + - api : complex args : (Tensor x, Tensor y) output : Tensor diff --git a/paddle/phi/infermeta/multiary.cc b/paddle/phi/infermeta/multiary.cc index 12c5ba109ab061648aa9a60cb4e1832c565f148e..f43660bfdbe8ce14da2785e38d9ece7255e9a8f5 100644 --- a/paddle/phi/infermeta/multiary.cc +++ b/paddle/phi/infermeta/multiary.cc @@ -785,6 +785,56 @@ void CheckFiniteAndUnscaleInferMeta(const std::vector& xs, found_infinite->set_dtype(DataType::BOOL); } +void CoalesceTensorInferMeta(const std::vector& input, + DataType dtype, + bool copy_data, + bool set_constant, + bool persist_output, + float constant, + bool use_align, + int align_size, + int size_of_dtype, + const std::vector& concated_shapes, + const std::vector& concated_ranks, + std::vector output, + MetaTensor* fused_output, + MetaConfig config) { + if (config.is_runtime) { + return; + } + if (size_of_dtype == -1) { + size_of_dtype = paddle::experimental::SizeOf(dtype); + } + + auto alignment = [](size_t size, size_t align_size) { + size_t remaining = size % align_size; + auto aligned_size = remaining == 0 ? size : size + (align_size - remaining); + VLOG(4) << remaining << " " << size << " " << align_size << " " + << aligned_size; + return aligned_size; + }; + VLOG(4) << "align_size: " << align_size; + if (use_align && align_size > 0) { + int64_t numel = 0; + + for (size_t i = 0; i < input.size(); ++i) { + const auto& dim = input[i]->dims(); + auto size = phi::product(dim); + auto len = use_align + ? alignment(static_cast(size) * size_of_dtype, + align_size) / + size_of_dtype + : static_cast(size); + numel += len; + } + if (fused_output) { + fused_output->set_dims(phi::make_ddim({numel})); + fused_output->set_dtype(dtype); + VLOG(4) << "fused_output size:" << phi::make_ddim({numel}); + } + } +} + void ConcatInferMeta(const std::vector& x, const Scalar& axis_scalar, MetaTensor* out, diff --git a/paddle/phi/infermeta/multiary.h b/paddle/phi/infermeta/multiary.h index b01a23b9ee3827fa4d36233d96e366c4843894eb..0296509e43750b97fc467d7ed1e03e066e6b816e 100644 --- a/paddle/phi/infermeta/multiary.h +++ b/paddle/phi/infermeta/multiary.h @@ -201,6 +201,21 @@ void CheckFiniteAndUnscaleInferMeta(const std::vector& xs, std::vector outs, MetaTensor* found_infinite); +void CoalesceTensorInferMeta(const std::vector& input, + DataType dtype, + bool copy_data, + bool set_constant, + bool persist_output, + float constant, + bool use_align, + int align_size, + int size_of_dtype, + const std::vector& concated_shapes, + const std::vector& concated_ranks, + std::vector output, + MetaTensor* fused_output, + MetaConfig config = MetaConfig()); + void ConcatInferMeta(const std::vector& x, const Scalar& axis_scalar, MetaTensor* out, diff --git a/paddle/phi/kernels/CMakeLists.txt b/paddle/phi/kernels/CMakeLists.txt index 47ae390fb6f6c81fb204dec00b7ee4f29a4d96f0..10d3e730cc1f3e7c75e585eb425c6b18855b9841 100644 --- a/paddle/phi/kernels/CMakeLists.txt +++ b/paddle/phi/kernels/CMakeLists.txt @@ -83,7 +83,8 @@ set(COMMON_KERNEL_DEPS custom_kernel string_infermeta gpc - utf8proc) + utf8proc + device_memory_aligment) set(COMMON_KERNEL_DEPS ${COMMON_KERNEL_DEPS} processgroup) if(WITH_NCCL OR WITH_RCCL) diff --git a/paddle/phi/kernels/coalesce_tensor_kernel.cc b/paddle/phi/kernels/coalesce_tensor_kernel.cc new file mode 100644 index 0000000000000000000000000000000000000000..41548fa03a9bd8c1c60db6c7a73fbf5fa2177acc --- /dev/null +++ b/paddle/phi/kernels/coalesce_tensor_kernel.cc @@ -0,0 +1,282 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/kernels/coalesce_tensor_kernel.h" + +#include +#include + +#include "paddle/fluid/platform/device_memory_aligment.h" +#include "paddle/phi/backends/cpu/cpu_context.h" +#include "paddle/phi/backends/gpu/gpu_context.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/funcs/math_function.h" + +namespace phi { + +template +struct FillConstantVisitor { + FillConstantVisitor(const Context &dev_ctx, + DenseTensor *tensor, + const float value) + : dev_ctx_(dev_ctx), tensor_(tensor), value_(value) {} + + template + void apply(typename std::enable_if::value || + std::is_same::value>::type * = + nullptr) const { + PADDLE_THROW( + errors::InvalidArgument("Not support data type for set_constant attr")); + } + + template + void apply(typename std::enable_if::value || + std::is_same::value)>::type + * = nullptr) const { + phi::funcs::SetConstant set_constant; + set_constant(dev_ctx_, tensor_, static_cast(value_)); + } + + const Context &dev_ctx_; + DenseTensor *tensor_; + float value_; +}; + +void GetMemSizeAndDtype(const std::vector &lod_tensors, + size_t *numel, + const size_t &size_of_dtype, + const phi::Place &place, + const bool use_align = true, + const int align_size = -1) { + *numel = 0; + std::stringstream ss; + ss << "alloc_space_for_vars: "; + for (size_t i = 0; i < lod_tensors.size(); ++i) { + auto size = lod_tensors[i]->numel(); + PADDLE_ENFORCE_GT(size, + 0, + errors::InvalidArgument( + "The number of `%d`-th tensor's elements is 0.", i)); + auto len = use_align ? paddle::platform::Alignment( + static_cast(size) * size_of_dtype, + place, + align_size) / + size_of_dtype + : static_cast(size); + const void *ptr = + lod_tensors[i]->IsInitialized() ? lod_tensors[i]->data() : nullptr; + VLOG(4) << size << " " << len; + ss << "input(" << i << "-th tensor) dim:(" << lod_tensors[i]->dims() << ") " + << " addres:" << ptr << " len: " << len << ", "; + *numel += len; + } + VLOG(10) << ss.str(); +} + +template +void CoalesceTensorKernel(const Context &dev_ctx, + const std::vector &input, + DataType dtype, + bool copy_data, + bool set_constant, + bool persist_output, + float constant, + bool use_align, + int align_size, + int size_of_dtype, + const std::vector &concated_shapes, + const std::vector &concated_ranks, + std::vector output, + DenseTensor *fused_output) { + PADDLE_ENFORCE_GT( + input.size(), + static_cast(0), + errors::InvalidArgument("The CoalesceTensor operator has no input.")); + PADDLE_ENFORCE_EQ(input.size(), + output.size(), + errors::InvalidArgument( + "The number of CoalesceTensor operator's input and " + "output is not match, " + "input number is %u, output number is %u.", + input.size(), + output.size())); + + // Input & Output check: only support LoDTensor + bool has_not_init_in_vars = false; + for (size_t i = 0; i < input.size(); ++i) { + PADDLE_ENFORCE_NOT_NULL( + input[i], + errors::InvalidArgument("The %d-th input tensor cannot be nullptr.", + i)); + PADDLE_ENFORCE_NOT_NULL( + output[i], + errors::InvalidArgument("The %d-th output tensor cannot be nullptr.", + i)); + if (!input[i]->IsInitialized()) { + has_not_init_in_vars = true; + } + } + + if (has_not_init_in_vars) { + PADDLE_ENFORCE_EQ( + concated_ranks.size(), + output.size(), + errors::InvalidArgument("The attribute(concated_ranks) length must be " + "equal to the output tensor number.")); + int64_t accumulated_ranks = 0; + for (size_t i = 0; i < input.size(); ++i) { + phi::DDim dims(concated_shapes.data() + accumulated_ranks, + concated_ranks[i]); + if (!input[i]->IsInitialized()) { + PADDLE_ENFORCE_EQ( + input[i], + output[i], + errors::InvalidArgument( + "The %d-th output tensor and %d-th input tensor when the " + "%d-th input tensor is not initialized.", + i, + i, + i)); + output[i]->Resize(dims); + } else { + PADDLE_ENFORCE_EQ(input[i]->dims(), + dims, + errors::InvalidArgument( + "The %d-th input tensor shape does not match the " + "attribute(concated_shapes) and " + "attribute(concated_ranks).", + i)); + } + accumulated_ranks += concated_ranks[i]; + PADDLE_ENFORCE_LE( + accumulated_ranks, + concated_shapes.size(), + errors::InvalidArgument("The attribute(concated_shapes) and " + "attribute(concated_ranks) do not match.")); + } + PADDLE_ENFORCE_EQ( + accumulated_ranks, + concated_shapes.size(), + errors::InvalidArgument("The attribute(concated_shapes) and " + "attribute(concated_ranks) do not match.")); + } + + // Init the output as input + for (size_t i = 0; i < input.size(); ++i) { + output[i]->Resize(input[i]->dims()); + } + + // Get numel and dtype + size_t numel = 0; + + if (size_of_dtype == -1) { + size_of_dtype = paddle::experimental::SizeOf(dtype); + } + GetMemSizeAndDtype( + input, &numel, size_of_dtype, dev_ctx.GetPlace(), use_align, align_size); + + // Alloc the continuous space + void *fused_tensor_ptr = dev_ctx.Alloc( + &fused_output->Resize(phi::make_ddim({static_cast(numel)})), + dtype); + VLOG(10) << "Fused tensor addr " << fused_tensor_ptr; + + // Init the continuous space + size_t offset = 0; + if (copy_data) { + for (size_t i = 0; i < input.size(); ++i) { + size_t len = static_cast(input[i]->numel()); + auto sub_tensor = fused_output->Slice(static_cast(offset), + static_cast(offset + len)); + phi::Copy(dev_ctx, *input[i], dev_ctx.GetPlace(), false, &sub_tensor); + + offset += use_align + ? paddle::platform::Alignment( + len * size_of_dtype, dev_ctx.GetPlace(), align_size) / + size_of_dtype + : len; + } + } else if (set_constant) { + phi::VisitDataType( + dtype, FillConstantVisitor(dev_ctx, fused_output, constant)); + } else if (persist_output) { + for (size_t i = 0; i < output.size(); ++i) { + size_t len = static_cast(output[i]->numel()); + auto sub_tensor = fused_output->Slice(static_cast(offset), + static_cast(offset + len)); + // some var may not persistable, or persistable var may not init + if (output[i]->IsInitialized()) { + phi::Copy(dev_ctx, *output[i], dev_ctx.GetPlace(), false, &sub_tensor); + } + offset += use_align + ? paddle::platform::Alignment( + len * size_of_dtype, dev_ctx.GetPlace(), align_size) / + size_of_dtype + : len; + } + } + + // Make the outputs point to the continuous space. + offset = 0; + std::stringstream ss; + ss << "alloc_space_for_vars: "; + + for (size_t i = 0; i < output.size(); ++i) { + size_t len = static_cast(output[i]->numel()); + auto dim = output[i]->dims(); + VLOG(4) << len << " " << dim << " " << offset; + output[i] + ->ShareDataWith(fused_output->Slice(static_cast(offset), + static_cast(offset + len))) + .Resize(dim); + len = use_align ? paddle::platform::Alignment( + len * size_of_dtype, dev_ctx.GetPlace(), align_size) / + size_of_dtype + : len; + ss << "output(" << i << "-th tensor) dim:(" << dim << ")" + << " address: " << output[i]->data() << " len: " << len << ", "; + offset += len; + } + PADDLE_ENFORCE_EQ((int64_t)offset, + fused_output->numel(), + errors::InvalidArgument( + "The alloc_space_for_vars's offset: %s is unequal with " + "fused_output's numel: %s.", + offset, + fused_output->numel())); + VLOG(10) << ss.str(); +} + +} // namespace phi + +PD_REGISTER_KERNEL(coalesce_tensor, + CPU, + ALL_LAYOUT, + phi::CoalesceTensorKernel, + int, + float, + double) {} + +#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) +PD_REGISTER_KERNEL(coalesce_tensor, + GPU, + ALL_LAYOUT, + phi::CoalesceTensorKernel, + phi::dtype::float16, + int, + float, + double) { + kernel->InputAt(0).SetBackend(phi::Backend::ALL_BACKEND); +} +#endif diff --git a/paddle/phi/kernels/coalesce_tensor_kernel.h b/paddle/phi/kernels/coalesce_tensor_kernel.h new file mode 100644 index 0000000000000000000000000000000000000000..cbfe2c18bb0238bf241938b98a876096c7d10b55 --- /dev/null +++ b/paddle/phi/kernels/coalesce_tensor_kernel.h @@ -0,0 +1,37 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/phi/core/dense_tensor.h" + +namespace phi { + +template +void CoalesceTensorKernel(const Context& dev_ctx, + const std::vector& input, + DataType dtype, + bool copy_data, + bool set_constant, + bool persist_output, + float constant, + bool use_align, + int align_size, + int size_of_dtype, + const std::vector& concated_shapes, + const std::vector& concated_ranks, + std::vector output, + DenseTensor* fused_output); + +} // namespace phi diff --git a/paddle/phi/ops/compat/coalesce_tensor_sig.cc b/paddle/phi/ops/compat/coalesce_tensor_sig.cc new file mode 100644 index 0000000000000000000000000000000000000000..a2219850ea6d54eea085a4a594ff2f931edb6037 --- /dev/null +++ b/paddle/phi/ops/compat/coalesce_tensor_sig.cc @@ -0,0 +1,38 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/core/compat/op_utils.h" + +namespace phi { +KernelSignature CoalesceTensorOpArgumentMapping( + const ArgumentMappingContext& ctx) { + return KernelSignature("coalesce_tensor", + {"Input"}, + {"dtype", + "copy_data", + "set_constant", + "persist_output", + "constant", + "use_align", + "align_size", + "user_defined_size_of_dtype", + "concated_shapes", + "concated_ranks"}, + {"Output", "FusedOutput"}); +} + +} // namespace phi + +PD_REGISTER_ARG_MAPPING_FN(coalesce_tensor, + phi::CoalesceTensorOpArgumentMapping); diff --git a/python/paddle/fluid/tests/unittests/test_coalesce_tensor_op.py b/python/paddle/fluid/tests/unittests/test_coalesce_tensor_op.py index 495d405c46badc748b5e9a979f086279ed6fd9ae..e398e65b0999919deb8c7af50a0a04b870a4b0a7 100644 --- a/python/paddle/fluid/tests/unittests/test_coalesce_tensor_op.py +++ b/python/paddle/fluid/tests/unittests/test_coalesce_tensor_op.py @@ -18,8 +18,28 @@ import unittest import numpy as np from op_test import OpTest from paddle.fluid import core - -alignment = 256 +import paddle.fluid as fluid +import paddle + + +def coalesce_tensor_eager_api(Input, + datatype=core.VarDesc.VarType.FP32, + copy_data=False, + set_constant=False, + persist_output=False, + constant=0.0, + use_align=True, + align_size=-1, + user_defined_size_of_dtype=-1, + concated_shapes=[], + concated_ranks=[]): + if datatype == int(core.VarDesc.VarType.FP32): + datatype = core.VarDesc.VarType.FP32 + return paddle._C_ops.coalesce_tensor(Input, datatype, copy_data, + set_constant, persist_output, constant, + use_align, align_size, + user_defined_size_of_dtype, + concated_shapes, concated_ranks) @unittest.skipIf(not core.is_compiled_with_cuda(), @@ -27,17 +47,14 @@ alignment = 256 class TestAllocContinuousSpace(OpTest): def setUp(self): + self.python_api = coalesce_tensor_eager_api self.op_type = "coalesce_tensor" self.dtype, self.fluid_dtype = self.init_dtype() - attrs = self.init_attr() - self.copy_data = attrs["copy_data"] - self.constant = attrs["constant"] - self.set_constant = attrs["set_constant"] + self.attrs = self.init_attr() self.Inputs = self.init_input() self.Outputs, self.FusedOutput = self.init_output( - self.Inputs, self.set_constant, self.constant) + self.Inputs, self.attrs["set_constant"], self.attrs["constant"]) self.inputs = {'Input': self.Inputs} - self.attrs = attrs self.outputs = {'Output': self.Outputs, 'FusedOutput': self.FusedOutput} def init_dtype(self): @@ -64,10 +81,14 @@ class TestAllocContinuousSpace(OpTest): def init_output(self, input_list, set_constant, constant): inputs = [] outputs = input_list + # GpuMinChunkSize=256 bytes, FP32=4 bytes + alignment = 256 / 4 + if 'user_defined_size_of_dtype' in self.attrs: + alignment = 256 / self.attrs['user_defined_size_of_dtype'] for input in input_list: length = len(input[1].flatten()) - aligned_len = (length + alignment) / alignment * alignment + aligned_len = (length + alignment) // alignment * alignment out = np.zeros(int(aligned_len)) out[0:length] = input[1].flatten() inputs.append(out) @@ -80,10 +101,45 @@ class TestAllocContinuousSpace(OpTest): for out in outputs] return outputs, coalesce_tensor_var + def verify_output(self, place): + with fluid.dygraph.base.guard(place=place): + tensor_input = [ + fluid.dygraph.base.to_variable(value=data[1]) + for data in self.inputs["Input"] + ] + eager_outputs, eager_fused_output = coalesce_tensor_eager_api( + tensor_input, + datatype=self.attrs["dtype"], + copy_data=self.attrs["copy_data"] + if "copy_data" in self.attrs else False, + set_constant=self.attrs["set_constant"] + if "set_constant" in self.attrs else False, + persist_output=False, + constant=self.attrs["constant"] + if "constant" in self.attrs else 0.0, + use_align=True, + align_size=-1, + user_defined_size_of_dtype=self. + attrs["user_defined_size_of_dtype"] + if "user_defined_size_of_dtype" in self.attrs else -1, + concated_shapes=[], + concated_ranks=[]) + for idx, (expected, eager_output) in enumerate( + zip(self.outputs['Output'], eager_outputs)): + np.testing.assert_allclose(expected[1], + eager_output, + atol=1e-5, + err_msg=f'not equal {idx}') + np.testing.assert_allclose(self.outputs['FusedOutput'], + eager_fused_output, + atol=1e-5, + err_msg=f'not equal fusedoutput') + def test_check_output(self): self.check_output_with_place(place=core.CUDAPlace(0), no_check_set=["FusedOutput"], atol=1e-5) + self.verify_output(core.CUDAPlace(0)) @unittest.skipIf(not core.is_compiled_with_cuda(), @@ -103,6 +159,7 @@ class TestAllocContinuousSpace2(TestAllocContinuousSpace): self.check_output_with_place(place=core.CUDAPlace(0), no_check_set=["FusedOutput"], atol=1e-5) + self.verify_output(core.CUDAPlace(0)) if __name__ == '__main__': diff --git a/python/paddle/nn/layer/rnn.py b/python/paddle/nn/layer/rnn.py index 72b7e374c517cf1f376a46283b1facfbe01bc437..fc85dfe1a1a67297fc5da465b9defd6ffdb780e6 100644 --- a/python/paddle/nn/layer/rnn.py +++ b/python/paddle/nn/layer/rnn.py @@ -34,6 +34,7 @@ from paddle.fluid.layers.utils import map_structure, flatten, pack_sequence_as from paddle.fluid.data_feeder import convert_dtype from paddle import _C_ops, _legacy_C_ops from paddle import in_dynamic_mode +from paddle.fluid.framework import in_dygraph_mode from paddle.framework import core from paddle.static import default_startup_program from paddle.static import program_guard