From eaed16810343000f05131c4b4446237bb8463f28 Mon Sep 17 00:00:00 2001
From: gouzil <66515297+gouzil@users.noreply.github.com>
Date: Tue, 9 May 2023 19:26:40 +0800
Subject: [PATCH] [static op generation] coalesce_tensor (#53570)

* [phi][api] add autogen code coalesce_tensor
* [phi][api]fix args
* [phi][api] supplement attrs
---
 paddle/fluid/operators/coalesce_tensor_op.cc | 519 -------------------
 paddle/phi/api/yaml/legacy_ops.yaml          |   9 -
 paddle/phi/api/yaml/op_compat.yaml           |   8 +
 paddle/phi/api/yaml/op_version.yaml          |  18 +
 paddle/phi/api/yaml/ops.yaml                 |   9 +
 paddle/phi/kernels/coalesce_tensor_kernel.cc |   1 +
 paddle/phi/ops/compat/coalesce_tensor_sig.cc |  38 --
 7 files changed, 36 insertions(+), 566 deletions(-)
 delete mode 100644 paddle/fluid/operators/coalesce_tensor_op.cc
 delete mode 100644 paddle/phi/ops/compat/coalesce_tensor_sig.cc

diff --git a/paddle/fluid/operators/coalesce_tensor_op.cc b/paddle/fluid/operators/coalesce_tensor_op.cc
deleted file mode 100644
index 2739cdd76ed..00000000000
--- a/paddle/fluid/operators/coalesce_tensor_op.cc
+++ /dev/null
@@ -1,519 +0,0 @@
-// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-//     http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#include <sstream>
-#include <vector>
-
-#include "paddle/fluid/framework/op_registry.h"
-#include "paddle/fluid/framework/op_version_registry.h"
-#include "paddle/fluid/framework/operator.h"
-#include "paddle/fluid/framework/var_type.h"
-#include "paddle/phi/backends/device_memory_aligment.h"
-#include "paddle/phi/kernels/funcs/math_function.h"
-
-#include "paddle/fluid/framework/convert_utils.h"
-#include "paddle/fluid/framework/infershape_utils.h"
-#include "paddle/phi/infermeta/multiary.h"
-
-namespace paddle {
-namespace operators {
-
-template <typename DeviceContext>
-struct FillConstantVisitor {
-  FillConstantVisitor(const DeviceContext &dev_ctx,
-                      phi::DenseTensor *tensor,
-                      const float value,
-                      framework::proto::VarType::Type dtype,
-                      const framework::ExecutionContext &context)
-      : dev_ctx_(dev_ctx),
-        tensor_(tensor),
-        value_(value),
-        dtype_(dtype),
-        context_(context) {}
-
-  template <typename T>
-  void apply(typename std::enable_if<std::is_same<T, int8_t>::value ||
-                                     std::is_same<T, int16_t>::value>::type * =
-                 nullptr) const {
-    PADDLE_THROW(platform::errors::InvalidArgument(
-        "Not support data type for set_constant attr"));
-  }
-
-  template <typename T>
-  void apply(typename std::enable_if<!(std::is_same<T, int8_t>::value ||
-                                       std::is_same<T, int16_t>::value)>::type
-                 * = nullptr) const {
-    phi::funcs::SetConstant<DeviceContext, T> set_constant;
-    set_constant(dev_ctx_, tensor_, static_cast<T>(value_));
-  }
-
-  const DeviceContext &dev_ctx_;
-  phi::DenseTensor *tensor_;
-  float value_;
-  framework::proto::VarType::Type dtype_;
-  const framework::ExecutionContext &context_;
-};
-
-template <typename DeviceContext, typename T>
-class CoalesceTensorOpKernel : public framework::OpKernel<T> {
- public:
-  void Compute(const framework::ExecutionContext &context) const override {
-    auto in_var_names = context.InputNames("Input");
-    auto out_var_names = context.OutputNames("Output");
-    const auto &in_tensors = context.MultiInput<phi::DenseTensor>("Input");
-    auto out_tensors = context.MultiOutput<phi::DenseTensor>("Output");
-
-    PADDLE_ENFORCE_GT(in_var_names.size(),
-                      static_cast<size_t>(0),
-                      platform::errors::InvalidArgument(
-                          "The CoalesceTensor operator has no input."));
-    PADDLE_ENFORCE_EQ(in_var_names.size(),
-                      out_var_names.size(),
-                      platform::errors::InvalidArgument(
-                          "The number of CoalesceTensor operator's input and "
-                          "output is not match, "
-                          "input number is %u, output number is %u.",
-                          in_var_names.size(),
-                          out_var_names.size()));
-
-    // Input & Output check: only support phi::DenseTensor
-    bool has_not_init_in_vars = false;
-    for (size_t i = 0; i < in_tensors.size(); ++i) {
-      PADDLE_ENFORCE_NOT_NULL(
-          in_tensors[i],
-          platform::errors::InvalidArgument(
-              "The %d-th input tensor cannot be nullptr.", i));
-      PADDLE_ENFORCE_NOT_NULL(
-          out_tensors[i],
-          platform::errors::InvalidArgument(
-              "The %d-th output tensor cannot be nullptr.", i));
-      if (!in_tensors[i]->IsInitialized()) {
-        has_not_init_in_vars = true;
-      }
-    }
-
-    if (has_not_init_in_vars) {
-      const auto &concated_shapes =
-          context.Attr<std::vector<int64_t>>("concated_shapes");
-      const auto &concated_ranks =
-          context.Attr<std::vector<int64_t>>("concated_ranks");
-      PADDLE_ENFORCE_EQ(concated_ranks.size(),
-                        out_tensors.size(),
-                        platform::errors::InvalidArgument(
-                            "The attribute(concated_ranks) length must be "
-                            "equal to the output tensor number."));
-      int64_t accumulated_ranks = 0;
-      for (size_t i = 0; i < in_tensors.size(); ++i) {
-        framework::DDim dims(concated_shapes.data() + accumulated_ranks,
-                             concated_ranks[i]);
-        if (!in_tensors[i]->IsInitialized()) {
-          PADDLE_ENFORCE_EQ(
-              in_tensors[i],
-              out_tensors[i],
-              platform::errors::InvalidArgument(
-                  "The %d-th output tensor and %d-th input tensor when the "
-                  "%d-th input tensor is not initialized.",
-                  i,
-                  i,
-                  i));
-          out_tensors[i]->Resize(dims);
-        } else {
-          PADDLE_ENFORCE_EQ(
-              in_tensors[i]->dims(),
-              dims,
-              platform::errors::InvalidArgument(
-                  "The %d-th input tensor shape does not match the "
-                  "attribute(concated_shapes) and "
-                  "attribute(concated_ranks).",
-                  i));
-        }
-        accumulated_ranks += concated_ranks[i];
-        PADDLE_ENFORCE_LE(accumulated_ranks,
-                          concated_shapes.size(),
-                          platform::errors::InvalidArgument(
-                              "The attribute(concated_shapes) and "
-                              "attribute(concated_ranks) do not match."));
-      }
-      PADDLE_ENFORCE_EQ(accumulated_ranks,
-                        concated_shapes.size(),
-                        platform::errors::InvalidArgument(
-                            "The attribute(concated_shapes) and "
-                            "attribute(concated_ranks) do not match."));
-    }
-
-    bool use_align = context.Attr<bool>("use_align");
-    auto align_size = context.Attr<int>("align_size");
-    auto size_of_dtype = context.Attr<int>("user_defined_size_of_dtype");
-
-    if (context.Attr<bool>("check_name")) {
-      for (size_t i = 0; i < in_var_names.size(); ++i) {
-        PADDLE_ENFORCE_EQ(
-            in_var_names[i],
-            out_var_names[i],
-            platform::errors::InvalidArgument(
-                "The input and output variable of CoalesceTensor operator is "
-                "different, %dth input is %s, %dth output is %s.",
-                i,
-                in_var_names[i],
-                i,
-                out_var_names[i]));
-      }
-    } else {
-      // Init the output as input
-      for (size_t i = 0; i < in_tensors.size(); ++i) {
-        out_tensors[i]->Resize(in_tensors[i]->dims());
-      }
-    }
-
-    auto &dev_ctx = context.template device_context<DeviceContext>();
-
-    // Get numel and dtype
-    size_t numel = 0;
-    auto dtype = static_cast<framework::proto::VarType::Type>(
-        context.Attr<int>("dtype"));
-    if (size_of_dtype == -1) {
-      size_of_dtype = framework::SizeOfType(dtype);
-    }
-    GetMemSizeAndDtype(in_tensors,
-                       in_var_names,
-                       &numel,
-                       size_of_dtype,
-                       context.GetPlace(),
-                       use_align,
-                       align_size);
-
-    // Alloc the continuous space
-    auto fused_tensor = context.Output<phi::DenseTensor>("FusedOutput");
-    void *fused_tensor_ptr =
-        fused_tensor->Resize(phi::make_ddim({static_cast<int64_t>(numel)}))
-            .mutable_data(context.GetPlace(),
-                          framework::TransToPhiDataType(dtype));
-    VLOG(10) << "Fused tensor addr " << fused_tensor_ptr;
-
-    // Init the continuous space
-    size_t offset = 0;
-    if (context.Attr<bool>("copy_data")) {
-      for (size_t i = 0; i < in_var_names.size(); ++i) {
-        size_t len = static_cast<size_t>(in_tensors[i]->numel());
-        auto sub_tensor = fused_tensor->Slice(
-            static_cast<int64_t>(offset), static_cast<int64_t>(offset + len));
-        framework::TensorCopy(
-            *in_tensors[i], context.GetPlace(), dev_ctx, &sub_tensor);
-
-        offset += use_align ? phi::Alignment(len * size_of_dtype,
-                                             context.GetPlace(),
-                                             align_size) /
-                                  size_of_dtype
-                            : len;
-      }
-    } else if (context.Attr<bool>("set_constant")) {
-      framework::VisitDataType(
-          dtype,
-          FillConstantVisitor<DeviceContext>(dev_ctx,
-                                             fused_tensor,
-                                             context.Attr<float>("constant"),
-                                             dtype,
-                                             context));
-    } else if (context.Attr<bool>("persist_output")) {
-      for (size_t i = 0; i < out_var_names.size(); ++i) {
-        size_t len = static_cast<size_t>(out_tensors[i]->numel());
-        auto sub_tensor = fused_tensor->Slice(
-            static_cast<int64_t>(offset), static_cast<int64_t>(offset + len));
-        // some var may not persistable, or persistable var may not init
-        if (out_tensors[i]->IsInitialized()) {
-          framework::TensorCopy(
-              *out_tensors[i], context.GetPlace(), dev_ctx, &sub_tensor);
-        }
-        offset += use_align ? phi::Alignment(len * size_of_dtype,
-                                             context.GetPlace(),
-                                             align_size) /
-                                  size_of_dtype
-                            : len;
-      }
-    }
-
-    // Make the outputs point to the continuous space.
-    offset = 0;
-    std::stringstream ss;
-    ss << "alloc_space_for_vars: ";
-
-    for (size_t i = 0; i < out_tensors.size(); ++i) {
-      size_t len = static_cast<size_t>(out_tensors[i]->numel());
-      auto dim = out_tensors[i]->dims();
-      VLOG(4) << len << " " << dim << " " << offset;
-      out_tensors[i]
-          ->ShareDataWith(fused_tensor->Slice(
-              static_cast<int64_t>(offset), static_cast<int64_t>(offset + len)))
-          .Resize(dim);
-      len = use_align
-                ? phi::Alignment(
-                      len * size_of_dtype, context.GetPlace(), align_size) /
-                      size_of_dtype
-                : len;
-      ss << "output(" << out_var_names[i] << ") dim:(" << dim << ")"
-         << " address: " << out_tensors[i]->data() << " len: " << len << ", ";
-      offset += len;
-    }
-    PADDLE_ENFORCE_EQ(
-        (int64_t)offset,
-        fused_tensor->numel(),
-        platform::errors::InvalidArgument(
-            "The alloc_space_for_vars's offset: %s is unequal with "
-            "fused_tensor's numel: %s.",
-            offset,
-            fused_tensor->numel()));
-    VLOG(10) << ss.str();
-  }
-
- private:
-  void GetMemSizeAndDtype(
-      const std::vector<const phi::DenseTensor *> &lod_tensors,
-      const std::vector<std::string> var_names,
-      size_t *numel,
-      const size_t &size_of_dtype,
-      const platform::Place &place,
-      const bool use_align = true,
-      const int align_size = -1) const {
-    PADDLE_ENFORCE_EQ(
-        lod_tensors.size(),
-        var_names.size(),
-        platform::errors::InvalidArgument(
-            "The number of input tensor and variable does not match, the "
-            "number of input tensor is %u, the number of input variable is %u.",
-            lod_tensors.size(),
-            var_names.size()));
-    *numel = 0;
-    std::stringstream ss;
-    ss << "alloc_space_for_vars: ";
-    for (size_t i = 0; i < var_names.size(); ++i) {
-      auto size = lod_tensors[i]->numel();
-      PADDLE_ENFORCE_GT(
-          size,
-          0,
-          platform::errors::InvalidArgument(
-              "The number of tensor `%s`'s elements is 0.", var_names[i]));
-      auto len = use_align
-                     ? phi::Alignment(static_cast<size_t>(size) * size_of_dtype,
-                                      place,
-                                      align_size) /
-                           size_of_dtype
-                     : static_cast<size_t>(size);
-      const void *ptr =
-          lod_tensors[i]->IsInitialized() ? lod_tensors[i]->data() : nullptr;
-      VLOG(4) << size << " " << len;
-      ss << "input(" << var_names[i] << ") dim:(" << lod_tensors[i]->dims()
-         << ") "
-         << " addres:" << ptr << " len: " << len << ", ";
-      *numel += len;
-    }
-    VLOG(10) << ss.str();
-  }
-};
-
-class CoalesceTensorOp : public framework::OperatorWithKernel {
- public:
-  using framework::OperatorWithKernel::OperatorWithKernel;
-
-  void InferShape(framework::InferShapeContext *ctx) const override {
-    auto use_align = ctx->Attrs().Get<bool>("use_align");
-    auto align_size = ctx->Attrs().Get<int>("align_size");
-    auto size_of_dtype = ctx->Attrs().Get<int>("user_defined_size_of_dtype");
-
-    auto dtype = static_cast<framework::proto::VarType::Type>(
-        ctx->Attrs().Get<int>("dtype"));
-    if (size_of_dtype == -1) {
-      size_of_dtype = framework::SizeOfType(dtype);
-    }
-    if (ctx->IsRuntime()) {
-#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
-      int64_t numel = 0;
-      auto dims = ctx->GetInputsDim("Input");
-      for (const auto &dim : dims) {
-        auto size = phi::product(dim);
-        auto len = use_align ? phi::Alignment(
-                                   static_cast<size_t>(size) * size_of_dtype,
-                                   phi::GPUPlace(),
-                                   align_size) /
-                                   size_of_dtype
-                             : static_cast<size_t>(size);
-        numel += len;
-      }
-      ctx->SetOutputDim("FusedOutput", phi::make_ddim({numel}));
-      VLOG(4) << "FusedOutput size:" << phi::make_ddim({numel});
-#else
-      return;
-#endif
-    } else {
-      auto alignment = [](size_t size, size_t align_size) {
-        size_t remaining = size % align_size;
-        auto aligned_size =
-            remaining == 0 ? size : size + (align_size - remaining);
-        VLOG(4) << remaining << " " << size << " " << align_size << " "
-                << aligned_size;
-        return aligned_size;
-      };
-      VLOG(4) << "align_size: " << align_size;
-      if (use_align && align_size > 0) {
-        int64_t numel = 0;
-        auto dims = ctx->GetInputsDim("Input");
-        for (const auto &dim : dims) {
-          auto size = phi::product(dim);
-          auto len = use_align
-                         ? alignment(static_cast<size_t>(size) * size_of_dtype,
-                                     align_size) /
-                               size_of_dtype
-                         : static_cast<size_t>(size);
-          numel += len;
-        }
-        ctx->SetOutputDim("FusedOutput", phi::make_ddim({numel}));
-        VLOG(4) << "FusedOutput size:" << phi::make_ddim({numel});
-      }
-    }
-  }
-
- protected:
-  phi::KernelKey GetExpectedKernelType(
-      const framework::ExecutionContext &context) const override {
-    auto dtype = static_cast<framework::proto::VarType::Type>(
-        context.Attr<int>("dtype"));
-    return phi::KernelKey(dtype, context.GetPlace());
-  }
-
-  phi::KernelKey GetKernelTypeForVar(
-      const std::string &var_name,
-      const phi::DenseTensor &tensor,
-      const phi::KernelKey &expected_kernel_type) const override {
-    return phi::KernelKey(phi::Backend::ALL_BACKEND,
-                          tensor.layout(),
-                          expected_kernel_type.dtype());
-  }
-};
-
-class CoalesceTensorOpMaker : public framework::OpProtoAndCheckerMaker {
- public:
-  void Make() override {
-    AddInput("Input",
-             "(vector<phi::DenseTensor>) The input tensors of"
-             " coalesce_tensor operator.")
-        .AsDuplicable();
-    AddOutput("Output",
-              "(vector<phi::DenseTensor>) The output "
-              "tensors of coalesce_tensor operator. And the address "
-              "of output tensors are continuous, they are sliced from the "
-              "tensor of FusedOutput.")
-        .AsDuplicable();
-    AddOutput("FusedOutput",
-              "(phi::DenseTensor) The output tensor "
And the tensors of" - " Output is sliced from the tensor of FusedOutput."); - AddAttr("dtype", "The output data type."); - AddAttr("copy_data", "Whether to copy the Input value to Output.") - .SetDefault(false); - AddAttr("set_constant", - "Whether to set the Output with a constant value.") - .SetDefault(false); - AddAttr("persist_output", - "Whether to persist the original Output value.") - .SetDefault(false); - AddAttr("constant", - "If set_constant is true, the constant value will be used " - "to set the Output.") - .SetDefault(0.0); - AddAttr("check_name", - "Whether to check the name of Input and Output to ensure " - "they are the same separately.") - .SetDefault(false); - AddAttr("use_align", - "Whether to consider memory chunk and take alignment into " - "account for inputs and outputs.") - .SetDefault(true); - AddAttr("align_size", "The alignment size when use_align is True") - .SetDefault(-1); - AddAttr("user_defined_size_of_dtype", - "The user defined size of dtype. This is used to coalesce " - "grad vars and merged_grad vars at the same time. For some " - "strategy, the dtype of fused_grad_vars and the dtype of " - "fused_grad_merged_vars are not identical, which will cause " - "the shape of these two coalesced vars are different. To " - "make sure the shape of these two vars are identical with " - "each other, this attr is added.") - .SetDefault(-1); - AddAttr>( - "concated_shapes", - "The concated shapes of each shape of the input tensors. " - "If any of the input tensors are not inited, this is used to " - "init the output tensor shape, together with " - "attribute(concated_ranks).") - .SetDefault({}); - AddAttr>( - "concated_ranks", - "The concated ranks of each rank of the input tensors. " - "If any of the input tensors are not inited, this is used to " - "init the output tensor shape, together with " - "attribute(concated_shapes).") - .SetDefault({}); - AddComment(R"DOC( -CoalesceTensor Operator. - -coalesce_tensor is used to make the address of Output -continuous according to the Input. This Op will alloc a big tensor -according to the tensors of Input, the dtype is the same with those input tensors, -the size is the sum of those input tensors' numel, and the dim of the big -tensor is {sum(numel)}. And the big tensor is stored in FusedOutput. -The tensors of Output are sliced from the tensor of FusedOutput. -Note that, the dtype of Input should be the same, and the dim of Input -and Output should equal. -The tensors of Input and Output could be the same or different. And -coalesce_tensor allows copying the value of Input to Output, or -setting the Output with a constant value, or persist the original Output -value. - -)DOC"); - } -}; - -} // namespace operators -} // namespace paddle - -DECLARE_INFER_SHAPE_FUNCTOR(coalesce_tensor, - CoalesceTensorInferShapeFunctor, - PD_INFER_META(phi::CoalesceTensorInferMeta)); - -REGISTER_OPERATOR(coalesce_tensor, - paddle::operators::CoalesceTensorOp, - paddle::operators::CoalesceTensorOpMaker, - CoalesceTensorInferShapeFunctor); -namespace ops = paddle::operators; -namespace plat = paddle::platform; - -REGISTER_OP_VERSION(coalesce_tensor) - .AddCheckpoint( - R"ROC( - Upgrade coalesce_tensor: add a new attribute [use_align].)ROC", - paddle::framework::compatible::OpVersionDesc().NewAttr( - "use_align", - "In order to optionally take memory alignment into account when " - "coalescing tensors. 
-            "coalescing tensors. The default value is true to be compatible "
-            "with before.",
-            true))
-    .AddCheckpoint(
-        R"ROC(
-              Upgrade coalesce_tensor: add a new attribute [align_size].)ROC",
-        paddle::framework::compatible::OpVersionDesc().NewAttr(
-            "align_size",
-            "In order to optionally take memory alignment into account when "
-            "coalescing tensors. The default value is -1 and use the default "
-            "align_size "
-            "of each place to be compatible with before.",
-            -1));
diff --git a/paddle/phi/api/yaml/legacy_ops.yaml b/paddle/phi/api/yaml/legacy_ops.yaml
index 266e981254f..e7b0a5ca4f3 100755
--- a/paddle/phi/api/yaml/legacy_ops.yaml
+++ b/paddle/phi/api/yaml/legacy_ops.yaml
@@ -171,15 +171,6 @@
     data_type : x
   inplace : (x -> out), (input_found_infinite -> output_found_infinite)
 
-- op : coalesce_tensor
-  args : (Tensor[] input, DataType dtype, bool copy_data = false, bool set_constant = false, bool persist_output = false, float constant = 0.0, bool use_align = true, int align_size = -1, int size_of_dtype = -1, int64_t[] concated_shapes = {}, int64_t[] concated_ranks = {})
-  output : Tensor[](output){input.size()}, Tensor(fused_output)
-  infer_meta :
-    func : CoalesceTensorInferMeta
-  kernel :
-    func : coalesce_tensor
-    data_type : dtype
-
 - op : concat
   args : (Tensor[] x, Scalar(int64_t) axis)
   output : Tensor
diff --git a/paddle/phi/api/yaml/op_compat.yaml b/paddle/phi/api/yaml/op_compat.yaml
index 97cca3ced80..928797a3c39 100755
--- a/paddle/phi/api/yaml/op_compat.yaml
+++ b/paddle/phi/api/yaml/op_compat.yaml
@@ -419,6 +419,14 @@
   outputs :
     out : Out
 
+- op : coalesce_tensor
+  inputs :
+    {input : Input}
+  outputs :
+    {output : Output, fused_output : FusedOutput}
+  attrs :
+    {size_of_dtype : user_defined_size_of_dtype}
+
 - op : complex
   backward : complex_grad
   inputs :
diff --git a/paddle/phi/api/yaml/op_version.yaml b/paddle/phi/api/yaml/op_version.yaml
index 35ee674cd14..b13b432c460 100644
--- a/paddle/phi/api/yaml/op_version.yaml
+++ b/paddle/phi/api/yaml/op_version.yaml
@@ -72,6 +72,24 @@
         - add_input : Max
           comment : Pass the mix, min value as input, not attribute. Max is dispensable.
 
+- op : coalesce_tensor
+  version :
+    - checkpoint : "Upgrade coalesce_tensor: add a new attribute [use_align]."
+      action :
+        - add_attr : use_align
+          comment : In order to optionally take memory alignment into account when
+            coalescing tensors. The default value is true to be compatible
+            with before.
+          default : "true"
+    - checkpoint : "Upgrade coalesce_tensor: add a new attribute [align_size]."
+      action :
+        - add_attr : align_size
+          comment : In order to optionally take memory alignment into account when
+            coalescing tensors. The default value is -1 and use the default
+            align_size
+            of each place to be compatible with before.
+          default : -1
+
 - op : embedding
   version :
     - checkpoint : Upgrade flip, add new attr [axis] and delete attr [dims]
diff --git a/paddle/phi/api/yaml/ops.yaml b/paddle/phi/api/yaml/ops.yaml
index 2f0960e9d30..eaaf9f61fa7 100644
--- a/paddle/phi/api/yaml/ops.yaml
+++ b/paddle/phi/api/yaml/ops.yaml
@@ -416,6 +416,15 @@
     func : clip_by_norm {dense -> dense}
            clip_by_norm_sr {selected_rows -> selected_rows}
 
+- op : coalesce_tensor
+  args : (Tensor[] input, DataType dtype, bool copy_data = false, bool set_constant = false, bool persist_output = false, float constant = 0.0, bool use_align = true, int align_size = -1, int size_of_dtype = -1, int64_t[] concated_shapes = {}, int64_t[] concated_ranks = {})
+  output : Tensor[](output){input.size()}, Tensor(fused_output)
+  infer_meta :
+    func : CoalesceTensorInferMeta
+  kernel :
+    func : coalesce_tensor
+    data_type : dtype
+
 - op : complex
   args : (Tensor real, Tensor imag)
   output : Tensor
diff --git a/paddle/phi/kernels/coalesce_tensor_kernel.cc b/paddle/phi/kernels/coalesce_tensor_kernel.cc
index 559c88eebbc..7b91b328acd 100644
--- a/paddle/phi/kernels/coalesce_tensor_kernel.cc
+++ b/paddle/phi/kernels/coalesce_tensor_kernel.cc
@@ -273,6 +273,7 @@ PD_REGISTER_KERNEL(coalesce_tensor,
                    int,
                    float,
                    double) {
+  kernel->InputAt(0).SetBackend(phi::Backend::ALL_BACKEND);
   kernel->OutputAt(1).SetDataType(phi::DataType::UNDEFINED);
 }
 
diff --git a/paddle/phi/ops/compat/coalesce_tensor_sig.cc b/paddle/phi/ops/compat/coalesce_tensor_sig.cc
deleted file mode 100644
index a2219850ea6..00000000000
--- a/paddle/phi/ops/compat/coalesce_tensor_sig.cc
+++ /dev/null
@@ -1,38 +0,0 @@
-// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-//     http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#include "paddle/phi/core/compat/op_utils.h"
-
-namespace phi {
-KernelSignature CoalesceTensorOpArgumentMapping(
-    const ArgumentMappingContext& ctx) {
-  return KernelSignature("coalesce_tensor",
-                         {"Input"},
-                         {"dtype",
-                          "copy_data",
-                          "set_constant",
-                          "persist_output",
-                          "constant",
-                          "use_align",
-                          "align_size",
-                          "user_defined_size_of_dtype",
-                          "concated_shapes",
-                          "concated_ranks"},
-                         {"Output", "FusedOutput"});
-}
-
-}  // namespace phi
-
-PD_REGISTER_ARG_MAPPING_FN(coalesce_tensor,
-                           phi::CoalesceTensorOpArgumentMapping);
-- 
GitLab