From 22e75d92cb04ede350e84be9a774b1b528f650b1 Mon Sep 17 00:00:00 2001 From: xiongkun Date: Tue, 14 Jun 2022 20:03:49 +0800 Subject: [PATCH] [ CherryPick ] Cherry pick for einsum optimization. (#43468) * [EinsumOp] Polish forward logic and backward logic for optimize (#42603) * change logic for optimize * modifty * merge * change einsum_v2 as default and add new flags: FLAG_einsum_opt=1|0 (#43010) * [EinsumOp] Make EinsumOp support bfloat16. (#43085) * change einsum_v2 as default and add new flags: FLAG_einsum_opt=1|0 * make EInsumOP support bf16 * add unittest for BF16 * add condition for test_BF16 * fix bugs * fix * change the backward api to fit einsum op --- paddle/fluid/eager/nan_inf_utils.cc | 119 +++++++++ paddle/fluid/eager/nan_inf_utils.h | 64 +++++ paddle/fluid/operators/einsum_op.cc | 8 + paddle/fluid/platform/flags.cc | 13 + paddle/phi/infermeta/unary.cc | 3 +- paddle/phi/infermeta/unary.h | 3 +- paddle/phi/kernels/cpu/einsum_kernel.cc | 3 +- paddle/phi/kernels/einsum_grad_kernel.h | 1 + paddle/phi/kernels/einsum_kernel.h | 7 + paddle/phi/kernels/funcs/eigen/broadcast.cc | 2 + paddle/phi/kernels/funcs/eigen/broadcast.cu | 2 + paddle/phi/kernels/gpu/einsum_grad_kernel.cu | 10 +- paddle/phi/kernels/gpu/einsum_kernel.cu | 9 +- paddle/phi/kernels/gpu/tile_kernel.cu | 3 +- paddle/phi/kernels/impl/einsum_grad_impl.h | 79 ++++-- paddle/phi/kernels/impl/einsum_impl.h | 237 +++++++++++++----- paddle/phi/ops/compat/einsum_sig.cc | 7 +- .../fluid/tests/unittests/test_einsum.py | 3 + .../fluid/tests/unittests/test_einsum_op.py | 8 +- .../fluid/tests/unittests/test_einsum_v2.py | 34 ++- .../white_list/no_check_set_white_list.py | 1 + python/paddle/tensor/einsum.py | 16 +- python/paddle/utils/code_gen/api.yaml | 2 +- python/paddle/utils/code_gen/api_base.py | 6 +- python/paddle/utils/code_gen/api_gen.py | 4 +- python/paddle/utils/code_gen/backward.yaml | 4 +- .../paddle/utils/code_gen/backward_api_gen.py | 4 +- 27 files changed, 541 insertions(+), 111 deletions(-) create mode 100644 paddle/fluid/eager/nan_inf_utils.cc create mode 100644 paddle/fluid/eager/nan_inf_utils.h diff --git a/paddle/fluid/eager/nan_inf_utils.cc b/paddle/fluid/eager/nan_inf_utils.cc new file mode 100644 index 00000000000..d1c5983a370 --- /dev/null +++ b/paddle/fluid/eager/nan_inf_utils.cc @@ -0,0 +1,119 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/eager/nan_inf_utils.h" + +#include "paddle/fluid/framework/details/nan_inf_utils_detail.h" +#include "paddle/phi/core/dense_tensor.h" +#include "paddle/phi/core/selected_rows.h" + +namespace egr { + +void CheckTensorHasNanOrInf(const std::string& api_name, const Tensor& tensor) { + if (tensor.initialized()) { + auto& tensor_name = tensor.name(); + const phi::DenseTensor* dense_tensor{nullptr}; + if (tensor.is_dense_tensor()) { + dense_tensor = static_cast(tensor.impl().get()); + } else if (tensor.is_selected_rows()) { + dense_tensor = &( + static_cast(tensor.impl().get())->value()); + } else { + VLOG(10) << "Only DenseTensor or SelectedRows need to check, " + << tensor_name << " is no need."; + return; + } + + auto& place = dense_tensor->place(); + if (paddle::platform::is_gpu_place(place)) { +#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) + paddle::framework::details::tensor_check< + paddle::platform::CUDADeviceContext>(api_name, tensor_name, + *dense_tensor, place); +#else + PADDLE_THROW(paddle::platform::errors::PreconditionNotMet( + "Tensor[%s] use gpu place. PaddlePaddle must compile with GPU.", + tensor_name)); +#endif + return; + } + paddle::framework::details::tensor_check< + paddle::platform::CPUDeviceContext>(api_name, tensor_name, + *dense_tensor, place); + } +} + +void CheckTensorHasNanOrInf(const std::string& api_name, + const TupleOfTwoTensors& tensors) { + CheckTensorHasNanOrInf(api_name, std::get<0>(tensors)); + CheckTensorHasNanOrInf(api_name, std::get<1>(tensors)); +} + +void CheckTensorHasNanOrInf(const std::string& api_name, + const TupleOfThreeTensors& tensors) { + CheckTensorHasNanOrInf(api_name, std::get<0>(tensors)); + CheckTensorHasNanOrInf(api_name, std::get<1>(tensors)); + CheckTensorHasNanOrInf(api_name, std::get<2>(tensors)); +} + +void CheckTensorHasNanOrInf(const std::string& api_name, + const TupleOfFourTensors& tensors) { + CheckTensorHasNanOrInf(api_name, std::get<0>(tensors)); + CheckTensorHasNanOrInf(api_name, std::get<1>(tensors)); + CheckTensorHasNanOrInf(api_name, std::get<2>(tensors)); + CheckTensorHasNanOrInf(api_name, std::get<3>(tensors)); +} + +void CheckTensorHasNanOrInf(const std::string& api_name, + const TupleOfFiveTensors& tensors) { + CheckTensorHasNanOrInf(api_name, std::get<0>(tensors)); + CheckTensorHasNanOrInf(api_name, std::get<1>(tensors)); + CheckTensorHasNanOrInf(api_name, std::get<2>(tensors)); + CheckTensorHasNanOrInf(api_name, std::get<3>(tensors)); + CheckTensorHasNanOrInf(api_name, std::get<4>(tensors)); +} + +void CheckTensorHasNanOrInf(const std::string& api_name, + const TupleOfSixTensors& tensors) { + CheckTensorHasNanOrInf(api_name, std::get<0>(tensors)); + CheckTensorHasNanOrInf(api_name, std::get<1>(tensors)); + CheckTensorHasNanOrInf(api_name, std::get<2>(tensors)); + CheckTensorHasNanOrInf(api_name, std::get<3>(tensors)); + CheckTensorHasNanOrInf(api_name, std::get<4>(tensors)); + CheckTensorHasNanOrInf(api_name, std::get<5>(tensors)); +} + +void CheckTensorHasNanOrInf(const std::string& api_name, + const std::vector& tensors) { + for (auto& tensor : tensors) { + CheckTensorHasNanOrInf(api_name, tensor); + } +} + +void CheckTensorHasNanOrInf( + const std::string& api_name, + const paddle::small_vector, + egr::kSlotSmallVectorSize>& tensors) { + for (auto& tensor_vector : tensors) { + CheckTensorHasNanOrInf(api_name, tensor_vector); + } +} + +void CheckTensorHasNanOrInf(const std::string& api_name, + const TupleOfTensorAndVector& tensors) { + CheckTensorHasNanOrInf(api_name, std::get<0>(tensors)); + CheckTensorHasNanOrInf(api_name, std::get<1>(tensors)); +} + +} // namespace egr diff --git a/paddle/fluid/eager/nan_inf_utils.h b/paddle/fluid/eager/nan_inf_utils.h new file mode 100644 index 00000000000..a411504fa49 --- /dev/null +++ b/paddle/fluid/eager/nan_inf_utils.h @@ -0,0 +1,64 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include +#include + +#include "paddle/fluid/eager/type_defs.h" +#include "paddle/phi/api/include/tensor.h" +#include "paddle/utils/small_vector.h" + +namespace egr { + +using paddle::experimental::Tensor; +using TupleOfTwoTensors = std::tuple; +using TupleOfThreeTensors = std::tuple; +using TupleOfFourTensors = std::tuple; +using TupleOfFiveTensors = std::tuple; +using TupleOfSixTensors = + std::tuple; +using TupleOfTensorAndVector = std::tuple>; + +void CheckTensorHasNanOrInf(const std::string& api_name, const Tensor& tensor); + +void CheckTensorHasNanOrInf(const std::string& api_name, + const TupleOfTwoTensors& tensors); + +void CheckTensorHasNanOrInf(const std::string& api_name, + const TupleOfThreeTensors& tensors); + +void CheckTensorHasNanOrInf(const std::string& api_name, + const TupleOfFourTensors& tensors); + +void CheckTensorHasNanOrInf(const std::string& api_name, + const TupleOfFiveTensors& tensors); + +void CheckTensorHasNanOrInf(const std::string& api_name, + const TupleOfSixTensors& tensors); + +void CheckTensorHasNanOrInf(const std::string& api_name, + const std::vector& tensors); + +void CheckTensorHasNanOrInf(const std::string& api_name, + const TupleOfTensorAndVector& tensors); + +void CheckTensorHasNanOrInf( + const std::string& api_name, + const paddle::small_vector, + egr::kSlotSmallVectorSize>& tensors); + +} // namespace egr diff --git a/paddle/fluid/operators/einsum_op.cc b/paddle/fluid/operators/einsum_op.cc index 8fdde1ccdc0..6da0045443c 100644 --- a/paddle/fluid/operators/einsum_op.cc +++ b/paddle/fluid/operators/einsum_op.cc @@ -33,6 +33,13 @@ class EinsumOpMaker : public framework::OpProtoAndCheckerMaker { AddInput("Operands", "(TensorList), The input tensor of einsum op.") .AsDuplicable(); AddOutput("Out", "(Tensor), The output tensor of einsum op."); + AddOutput( + "InnerCache", + "(Tensor), The cache of the forward transpose tensors: tA and tB.") + .AsDuplicable() + .AsExtra() + .AsIntermediate(); + AddAttr("equation", "(string) A einsum equation. such as `ij,jk->ik`" "There must have `->` and the number of operands in " @@ -72,6 +79,7 @@ class EinsumGradMaker : public framework::SingleGradOpMaker { void Apply(GradOpPtr retv) const override { retv->SetType("einsum_grad"); retv->SetInput("Operands", this->Input("Operands")); + retv->SetInput("InnerCache", this->Output("InnerCache")); retv->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out")); retv->SetAttrMap(this->Attrs()); retv->SetOutput(framework::GradVarName("Operands"), diff --git a/paddle/fluid/platform/flags.cc b/paddle/fluid/platform/flags.cc index c70452c5016..aa9b68289a5 100644 --- a/paddle/fluid/platform/flags.cc +++ b/paddle/fluid/platform/flags.cc @@ -808,3 +808,16 @@ PADDLE_DEFINE_EXPORTED_bool(nccl_blocking_wait, false, "nccl blocking wait"); * Example: */ PADDLE_DEFINE_EXPORTED_bool(use_autotune, false, "Whether enable autotune."); + +/** + * Preformance related FLAG + * Name: einsum_opt + * Since Version: 2.3.0 + * Value Range: bool, default=false + * Example: + * Note: If True, EinsumOp will be optimimzed by innercache reuse, which + * uses more gpu memory. + */ +PADDLE_DEFINE_EXPORTED_bool( + einsum_opt, false, + "EinsumOp backward will be speedup at the expense of more gpu memory."); diff --git a/paddle/phi/infermeta/unary.cc b/paddle/phi/infermeta/unary.cc index b7959d2809e..980b4219c51 100644 --- a/paddle/phi/infermeta/unary.cc +++ b/paddle/phi/infermeta/unary.cc @@ -401,7 +401,8 @@ void EighInferMeta(const MetaTensor& x, void EinsumInferMeta(const std::vector& inputs, const std::string& equation, - MetaTensor* out) { + MetaTensor* out, + std::vector inner_cache) { // collect the following informations to prepare einsum. LabelMap labelshape(0); LabelMap labeltype(LabelType::Reduction); diff --git a/paddle/phi/infermeta/unary.h b/paddle/phi/infermeta/unary.h index 647b9c578c8..e141acb2ea2 100644 --- a/paddle/phi/infermeta/unary.h +++ b/paddle/phi/infermeta/unary.h @@ -82,7 +82,8 @@ void EighInferMeta(const MetaTensor& x, void EinsumInferMeta(const std::vector& inputs, const std::string& equation, - MetaTensor* out); + MetaTensor* out, + std::vector inner_cache); void ExpandInferMeta(const MetaTensor& x, const IntArray& shape, diff --git a/paddle/phi/kernels/cpu/einsum_kernel.cc b/paddle/phi/kernels/cpu/einsum_kernel.cc index 3e25a65526d..8968542b3e0 100644 --- a/paddle/phi/kernels/cpu/einsum_kernel.cc +++ b/paddle/phi/kernels/cpu/einsum_kernel.cc @@ -17,4 +17,5 @@ #include "paddle/phi/core/kernel_registry.h" #include "paddle/phi/kernels/impl/einsum_impl.h" -PD_REGISTER_KERNEL(einsum, CPU, ALL_LAYOUT, phi::EinsumKernel, float, double) {} +PD_REGISTER_KERNEL( + einsum, CPU, ALL_LAYOUT, phi::EinsumKernelRaw, float, double) {} diff --git a/paddle/phi/kernels/einsum_grad_kernel.h b/paddle/phi/kernels/einsum_grad_kernel.h index 5c1970e7758..06785c8532e 100644 --- a/paddle/phi/kernels/einsum_grad_kernel.h +++ b/paddle/phi/kernels/einsum_grad_kernel.h @@ -21,6 +21,7 @@ namespace phi { template void EinsumGradKernel(const Context& dev_ctx, const std::vector& x, + const std::vector& inner_cache, const DenseTensor& out_grad, const std::string& equation, std::vector x_grad); diff --git a/paddle/phi/kernels/einsum_kernel.h b/paddle/phi/kernels/einsum_kernel.h index 3d9e8feda74..87df2b1c64a 100644 --- a/paddle/phi/kernels/einsum_kernel.h +++ b/paddle/phi/kernels/einsum_kernel.h @@ -24,4 +24,11 @@ void EinsumKernel(const Context& dev_ctx, const std::string& equation, DenseTensor* out); +template +void EinsumKernelRaw(const Context& dev_ctx, + const std::vector& inputs, + const std::string& equation, + DenseTensor* out, + std::vector cache); + } // namespace phi diff --git a/paddle/phi/kernels/funcs/eigen/broadcast.cc b/paddle/phi/kernels/funcs/eigen/broadcast.cc index 3459d7acd6b..008c51249f2 100644 --- a/paddle/phi/kernels/funcs/eigen/broadcast.cc +++ b/paddle/phi/kernels/funcs/eigen/broadcast.cc @@ -11,6 +11,7 @@ distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ +#include "paddle/phi/common/bfloat16.h" #include "paddle/phi/common/float16.h" #include "paddle/phi/kernels/funcs/eigen/eigen_function.h" @@ -73,6 +74,7 @@ struct EigenBroadcastGrad { template struct FUNCTOR INSTANTIATION(EigenBroadcast, bool); INSTANTIATION(EigenBroadcast, dtype::float16); +INSTANTIATION(EigenBroadcast, dtype::bfloat16); INSTANTIATION(EigenBroadcast, float); INSTANTIATION(EigenBroadcast, double); INSTANTIATION(EigenBroadcast, int); diff --git a/paddle/phi/kernels/funcs/eigen/broadcast.cu b/paddle/phi/kernels/funcs/eigen/broadcast.cu index d9de69ec55e..742081a30c1 100644 --- a/paddle/phi/kernels/funcs/eigen/broadcast.cu +++ b/paddle/phi/kernels/funcs/eigen/broadcast.cu @@ -11,6 +11,7 @@ distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ +#include "paddle/phi/common/bfloat16.h" #include "paddle/phi/common/float16.h" #include "paddle/phi/kernels/funcs/eigen/eigen_function.h" @@ -73,6 +74,7 @@ struct EigenBroadcastGrad { template struct FUNCTOR INSTANTIATION(EigenBroadcast, bool); INSTANTIATION(EigenBroadcast, dtype::float16); +INSTANTIATION(EigenBroadcast, dtype::bfloat16); INSTANTIATION(EigenBroadcast, float); INSTANTIATION(EigenBroadcast, double); INSTANTIATION(EigenBroadcast, int); diff --git a/paddle/phi/kernels/gpu/einsum_grad_kernel.cu b/paddle/phi/kernels/gpu/einsum_grad_kernel.cu index c8a8745f345..950f811475c 100644 --- a/paddle/phi/kernels/gpu/einsum_grad_kernel.cu +++ b/paddle/phi/kernels/gpu/einsum_grad_kernel.cu @@ -18,5 +18,11 @@ #include "paddle/phi/core/kernel_registry.h" #include "paddle/phi/kernels/impl/einsum_grad_impl.h" -PD_REGISTER_KERNEL( - einsum_grad, GPU, ALL_LAYOUT, phi::EinsumGradKernel, float, double) {} +PD_REGISTER_KERNEL(einsum_grad, + GPU, + ALL_LAYOUT, + phi::EinsumGradKernel, + float, + double, + phi::dtype::float16, + phi::dtype::bfloat16) {} diff --git a/paddle/phi/kernels/gpu/einsum_kernel.cu b/paddle/phi/kernels/gpu/einsum_kernel.cu index d73e154eb40..d1f4c659038 100644 --- a/paddle/phi/kernels/gpu/einsum_kernel.cu +++ b/paddle/phi/kernels/gpu/einsum_kernel.cu @@ -18,4 +18,11 @@ #include "paddle/phi/core/kernel_registry.h" #include "paddle/phi/kernels/impl/einsum_impl.h" -PD_REGISTER_KERNEL(einsum, GPU, ALL_LAYOUT, phi::EinsumKernel, float, double) {} +PD_REGISTER_KERNEL(einsum, + GPU, + ALL_LAYOUT, + phi::EinsumKernelRaw, + float, + double, + phi::dtype::float16, + phi::dtype::bfloat16) {} diff --git a/paddle/phi/kernels/gpu/tile_kernel.cu b/paddle/phi/kernels/gpu/tile_kernel.cu index 0c3c29e82c4..990877a8445 100644 --- a/paddle/phi/kernels/gpu/tile_kernel.cu +++ b/paddle/phi/kernels/gpu/tile_kernel.cu @@ -27,4 +27,5 @@ PD_REGISTER_KERNEL(tile, double, int, int64_t, - phi::dtype::float16) {} + phi::dtype::float16, + phi::dtype::bfloat16) {} diff --git a/paddle/phi/kernels/impl/einsum_grad_impl.h b/paddle/phi/kernels/impl/einsum_grad_impl.h index bd0143379ce..a72db326807 100644 --- a/paddle/phi/kernels/impl/einsum_grad_impl.h +++ b/paddle/phi/kernels/impl/einsum_grad_impl.h @@ -13,6 +13,7 @@ // limitations under the License. #pragma once +#include "paddle/fluid/platform/profiler.h" #include "paddle/phi/core/dense_tensor.h" #include "paddle/phi/kernels/impl/einsum_impl.h" #include "paddle/phi/kernels/tile_kernel.h" @@ -55,7 +56,13 @@ DenseTensor PerformTileAndReduction(const Context& dev_ctx, } t.Resize(make_ddim(resize_dims)); DenseTensor after_tile; - TileKernel(dev_ctx, t, repeat_times, &after_tile); + if (std::all_of(repeat_times.begin(), repeat_times.end(), [](int x) { + return x == 1; + })) { + after_tile = t; + } else { + TileKernel(dev_ctx, t, repeat_times, &after_tile); + } size_t n_ellipsis_idx = op_label.find(".", 0); if (n_ellipsis_idx != std::string::npos) { // may be we need reduce. broadcast_dims is not equal to ellipsis dims. @@ -91,10 +98,11 @@ DenseTensor PerformTileAndReduction(const Context& dev_ctx, template void EinsumGradKernel(const Context& dev_ctx, const std::vector& x, + const std::vector& inner_cache, const DenseTensor& out_grad, const std::string& equation, std::vector x_grad) { - VLOG(5) << "Start EisumGradKernel:"; + VLOG(5) << "Start EinsumGradKernel:"; LabelMap labelshape(0); LabelMap labeltype(LabelType::Reduction); std::vector label2perms(x.size(), LabelMap(-1)); @@ -148,34 +156,65 @@ void EinsumGradKernel(const Context& dev_ctx, right = splits[1].substr(1); auto equation_for_A = - right + "," + ops[1] + "->" + gather_labels_except_reduction(ops[0]); + ops[1] + "," + right + "->" + gather_labels_except_reduction(ops[0]); auto equation_for_B = right + "," + ops[0] + "->" + gather_labels_except_reduction(ops[1]); auto operands_for_A = std::vector(); auto operands_for_B = std::vector(); DenseTensor dA, dB; - operands_for_A.push_back(&out_grad); + // dA = einsum(B, dC) operands_for_A.push_back(x[1]); + operands_for_A.push_back(&out_grad); + // dB = einsum(dC, A) operands_for_B.push_back(&out_grad); operands_for_B.push_back(x[0]); DenseTensor before_tile; - EinsumKernel(dev_ctx, operands_for_A, equation_for_A, &dA); - EinsumKernel(dev_ctx, operands_for_B, equation_for_B, &dB); - *(x_grad[0]) = PerformTileAndReduction(dev_ctx, - labeltype, - labelshape, - broadcast_dims, - ellipsis_dims[0], - ops[0], - dA); - *(x_grad[1]) = PerformTileAndReduction(dev_ctx, - labeltype, - labelshape, - broadcast_dims, - ellipsis_dims[1], - ops[1], - dB); + + std::vector cache(3); // set empty; TA, TB, TdC + if (inner_cache.size() > + 0) { // for compatibility, we can load and run v2.3 EinsumOp. + cache[0].ShareBufferWith(*(inner_cache[0])); + cache[1].ShareBufferWith(*(inner_cache[1])); + } + + EinsumKernelImpl(dev_ctx, + all_labels, + operands_for_A, + equation_for_A, + &dA, + {&cache[1], &cache[2]}, + false); + + EinsumKernelImpl(dev_ctx, + all_labels, + operands_for_B, + equation_for_B, + &dB, + {&cache[2], &cache[0]}, + false); + + // release the cache tensor dTC to save memory right now. they are useless + // now. + cache.clear(); + if (x_grad[0]) { + *(x_grad[0]) = PerformTileAndReduction(dev_ctx, + labeltype, + labelshape, + broadcast_dims, + ellipsis_dims[0], + ops[0], + dA); + } + if (x_grad[1]) { + *(x_grad[1]) = PerformTileAndReduction(dev_ctx, + labeltype, + labelshape, + broadcast_dims, + ellipsis_dims[1], + ops[1], + dB); + } } } } // namespace phi diff --git a/paddle/phi/kernels/impl/einsum_impl.h b/paddle/phi/kernels/impl/einsum_impl.h index 73940a45cbd..bfbd6e0c51c 100644 --- a/paddle/phi/kernels/impl/einsum_impl.h +++ b/paddle/phi/kernels/impl/einsum_impl.h @@ -13,12 +13,15 @@ // limitations under the License. #pragma once +#include #include "paddle/phi/core/dense_tensor.h" #include "paddle/phi/kernels/matmul_kernel.h" #include "paddle/phi/kernels/reduce_sum_kernel.h" #include "paddle/phi/kernels/transpose_kernel.h" #include "paddle/utils/string/string_helper.h" +DECLARE_bool(einsum_opt); + namespace phi { // check the validation of the Einsum equation. @@ -55,7 +58,8 @@ inline static void ValidationCheck(const std::string& equation) { enum LabelType { ALL_TYPE = 0, Batch = 1, // ABO - Free, // AO, BO + AO, // AO -- free label + BO, // BO -- free label Contraction, // AB Reduction, // A, B }; @@ -125,18 +129,31 @@ inline std::vector union_labels(const std::vector& a, return res; } +// Apply transforms to all_labels and get another all_labels +inline std::vector TransformLabelsOrder( + const std::vector& all_labels, + const LabelMap& type, + std::vector new_order) { + std::vector ret; + for (auto cnt_type : new_order) { + std::vector tmp; + for (int c : all_labels) { + if (type[c] == cnt_type) tmp.push_back(c); + } + ret.insert(ret.end(), tmp.begin(), tmp.end()); + } + return ret; +} + inline static void GlobalInfo(const std::vector& op_labels, const std::string& right, LabelMap* label2type, std::vector* sorted_labels) { - // sorted_labels: ['.', , ] - VLOG(5) << "GlobalInfo: " - << paddle::string::join_strings(*sorted_labels, ","); std::vector all; LabelMap counter(0); for (auto& ch : right) { // char int c = ch; - (*label2type)[c] = LabelType::Free; + (*label2type)[c] = LabelType::BO; } for (auto& op : op_labels) { @@ -146,39 +163,45 @@ inline static void GlobalInfo(const std::vector& op_labels, all.push_back(ch); } counter[c] += 1; - if ((*label2type)[c] != LabelType::Free && counter[c] == 2) + if ((*label2type)[c] != LabelType::BO && counter[c] == 2) (*label2type)[c] = LabelType::Contraction; else if (counter[c] == 2) (*label2type)[c] = LabelType::Batch; } } + + // BO is represent Free, so we need find the AO. + for (int c : op_labels[0]) { + if ((*label2type)[c] == LabelType::BO) (*label2type)[c] = LabelType::AO; + } + (*label2type)['.'] = LabelType::Batch; - std::for_each(all.begin(), all.end(), [sorted_labels, label2type](int c) { - if ((*label2type)[c] == LabelType::Batch) - sorted_labels->push_back(static_cast(c)); - }); - std::for_each(all.begin(), all.end(), [sorted_labels, label2type](int c) { - if ((*label2type)[c] == LabelType::Free) - sorted_labels->push_back(static_cast(c)); - }); - std::for_each(all.begin(), all.end(), [sorted_labels, label2type](int c) { - if ((*label2type)[c] == LabelType::Contraction) - sorted_labels->push_back(static_cast(c)); - }); - std::for_each(all.begin(), all.end(), [&sorted_labels, label2type](int c) { - if ((*label2type)[c] == LabelType::Reduction) - sorted_labels->push_back(static_cast(c)); - }); - VLOG(5) << "GlobalInfo: sorted_labels before: " - << paddle::string::join_strings(*sorted_labels, ","); + + if (sorted_labels->size()) { + std::set exist(all.begin(), all.end()); + all.clear(); + std::for_each( + sorted_labels->begin(), sorted_labels->end(), [&exist, &all](char c) { + if (exist.count(c)) all.push_back(c); + }); + } + + *sorted_labels = TransformLabelsOrder(all, + *label2type, + {LabelType::Batch, + LabelType::AO, + LabelType::BO, + LabelType::Contraction, + LabelType::Reduction}); + if (counter[static_cast('.')] > 0) { std::vector tmp; tmp.push_back('.'); // push '.' in the front *sorted_labels = union_labels(tmp, *sorted_labels); - VLOG(5) << "GlobalInfo: sorted_labels after: " - << paddle::string::join_strings(*sorted_labels, ","); } + VLOG(5) << "GlobalInfo: sorted_labels after: " + << paddle::string::join_strings(*sorted_labels, ","); } inline static void InferLabelShape(const std::vector& op_labels, @@ -289,17 +312,20 @@ inline static void ParseEinsumEquation( *right = results[1].substr(1); ReplaceEllipsis(*right); auto op_labels = paddle::string::split_string(left, ","); + // split_string("i,") -> ["i"], we expect 2 op_labels. + if (left[left.size() - 1] == ',') op_labels.push_back(""); std::for_each(op_labels.begin(), op_labels.end(), ReplaceEllipsis); GlobalInfo(op_labels, *right, labeltype, all_labels); InferLabelShape(op_labels, inputs, labelshape, ellipsis_dims, broadcast_dims); - VLOG(5) << "Einsum Infershape: right:" << right; - VLOG(5) << "Einsum Infershape: op_labels:" - << paddle::string::join_strings(op_labels, "\n"); + VLOG(5) << "Einsum Infershape: right:" << *right; + VLOG(5) << "Einsum Infershape: left :" + << paddle::string::join_strings(op_labels, '\n'); InferOutputDims(*right, *broadcast_dims, *labelshape, output_dims); for (size_t i = 0; i < inputs.size(); ++i) { InferLabelPerm( op_labels[i], ellipsis_dims->at(i).size(), &((*label2perms)[i])); } + VLOG(5) << "Einsum Infershape: end"; } template @@ -327,10 +353,12 @@ std::vector GetShapeByType(const std::vector& all_labels, const LabelMap& perm, const LabelMap& label2shape, const std::vector& ellipsis, - LabelType filter) { + std::set filter) { std::vector res; for (T c : all_labels) { - if ((filter == LabelType::ALL_TYPE || type[c] == filter) && perm[c] != -1) { + if ((filter.count(LabelType::ALL_TYPE) || + filter.count(LabelType(type[c]))) && + perm[c] != -1) { if (c == '.') res.insert(res.end(), ellipsis.begin(), ellipsis.end()); else @@ -390,7 +418,9 @@ DenseTensor PerformContraction( const LabelMap& label2type, const LabelMap& label2shape, const std::vector>& ellipsis_dims, - const std::vector& broadcast_dims) { + const std::vector& broadcast_dims, + std::vector cache, + bool use_cache) { // Get All the Batches, so perm is auto all_valid = LabelMap(1); auto recover_dim = GetShapeByType(all_labels, @@ -398,36 +428,77 @@ DenseTensor PerformContraction( all_valid, label2shape, broadcast_dims, - LabelType::Batch); + {LabelType::Batch}); auto preprocess = [&](const DenseTensor& t, const LabelMap& perm, - const std::vector& ellipsis) -> DenseTensor { - auto frees = GetShapeByType( - all_labels, label2type, perm, label2shape, ellipsis, LabelType::Free); + const std::vector& ellipsis, + int operand_idx) -> DenseTensor { + // reshape + auto frees = GetShapeByType(all_labels, + label2type, + perm, + label2shape, + ellipsis, + {LabelType::AO, LabelType::BO}); auto conts = GetShapeByType(all_labels, label2type, perm, label2shape, ellipsis, - LabelType::Contraction); - auto trans_t = PerformTranspose( - dev_ctx, t, perm, all_labels, ellipsis, label2type); - auto mul_dims = GetShapeByType( - all_labels, label2type, perm, label2shape, ellipsis, LabelType::Batch); + {LabelType::Contraction}); + std::vector reordered_all_labels = all_labels; + if (operand_idx == 1) { + reordered_all_labels = TransformLabelsOrder(all_labels, + label2type, + {LabelType::Batch, + LabelType::Contraction, + LabelType::AO, + LabelType::BO, + LabelType::Reduction}); + } + // reduction + DenseTensor trans_t; + if (FLAGS_einsum_opt && use_cache && cache[operand_idx] != nullptr && + cache[operand_idx]->IsInitialized()) { + trans_t.ShareBufferWith(*(cache[operand_idx])); + VLOG(5) << "Cache Used!"; + } else { + auto reduct_t = PerformReduction( + dev_ctx, t, perm, all_labels, ellipsis, label2type); + trans_t = PerformTranspose( + dev_ctx, reduct_t, perm, reordered_all_labels, ellipsis, label2type); + if (FLAGS_einsum_opt && cache[operand_idx] != nullptr) + cache[operand_idx]->ShareBufferWith(trans_t); + } + auto mul_dims = GetShapeByType(all_labels, + label2type, + perm, + label2shape, + ellipsis, + {LabelType::Batch}); recover_dim.insert(recover_dim.end(), frees.begin(), frees.end()); - mul_dims.push_back( - std::accumulate(frees.begin(), frees.end(), 1, std::multiplies())); - mul_dims.push_back( - std::accumulate(conts.begin(), conts.end(), 1, std::multiplies())); + if (operand_idx == 0) { + mul_dims.push_back(std::accumulate( + frees.begin(), frees.end(), 1, std::multiplies())); + mul_dims.push_back(std::accumulate( + conts.begin(), conts.end(), 1, std::multiplies())); + } else { + mul_dims.push_back(std::accumulate( + conts.begin(), conts.end(), 1, std::multiplies())); + mul_dims.push_back(std::accumulate( + frees.begin(), frees.end(), 1, std::multiplies())); + } VLOG(5) << "PerformContraction: mul_dims: " << paddle::string::join_strings(mul_dims, ","); trans_t.Resize(make_ddim(mul_dims)); return trans_t; }; - auto trans_a = preprocess(A, label2perm[0], ellipsis_dims[0]); - auto trans_b = preprocess(B, label2perm[1], ellipsis_dims[1]); + + // Reduction, Reshape and Matmul + auto trans_a = preprocess(A, label2perm[0], ellipsis_dims[0], 0); + auto trans_b = preprocess(B, label2perm[1], ellipsis_dims[1], 1); auto after_contraction = - Matmul(dev_ctx, trans_a, trans_b, false, true); + Matmul(dev_ctx, trans_a, trans_b, false, false); VLOG(5) << "PerformContraction: recover_dim: " << paddle::string::join_strings(recover_dim, ","); after_contraction.Resize(make_ddim(recover_dim)); @@ -458,17 +529,23 @@ void TransposeToOutput(const Context& dev_ctx, axis.push_back(it - all_labels.begin() + offset); } } - if (is_no_need_transpose(axis)) return output->ShareBufferWith(to_trans); + if (is_no_need_transpose(axis)) { + output->ShareBufferWith(to_trans); + return; + } VLOG(5) << "call TransposeToOutput: with axis: " << paddle::string::join_strings(axis, ","); - return TransposeKernel(dev_ctx, to_trans, axis, output); + TransposeKernel(dev_ctx, to_trans, axis, output); } template -void EinsumKernel(const Context& dev_ctx, - const std::vector& inputs, - const std::string& equation, - DenseTensor* out) { +void EinsumKernelImpl(const Context& dev_ctx, + const std::vector& forward_all_labels, + const std::vector& inputs, + const std::string& equation, + DenseTensor* out, + std::vector cache, + bool is_forward = true) { ValidationCheck(equation); // collect the following informations to prepare einsum. LabelMap labelshape(0); @@ -484,6 +561,9 @@ void EinsumKernel(const Context& dev_ctx, input_dims.push_back(i->dims()); } std::string right; + if (!is_forward) { + all_labels = forward_all_labels; + } ParseEinsumEquation(equation, input_dims, &labelshape, @@ -498,22 +578,18 @@ void EinsumKernel(const Context& dev_ctx, if (inputs.size() == 2) { auto& A = inputs[0]; auto& B = inputs[1]; - // Reduce Procedure - auto reduce_A = PerformReduction( - dev_ctx, *A, label2perms[0], all_labels, ellipsis_dims[0], labeltype); - auto reduce_B = PerformReduction( - dev_ctx, *B, label2perms[1], all_labels, ellipsis_dims[1], labeltype); - // Contract Procedure - dev_ctx.template Alloc(out); + // Reduction and Contract Procedure auto after_contraction = PerformContraction(dev_ctx, - reduce_A, - reduce_B, + *A, + *B, label2perms, all_labels, labeltype, labelshape, ellipsis_dims, - broadcast_dims); + broadcast_dims, + cache, + !is_forward); TransposeToOutput(dev_ctx, after_contraction, right, @@ -545,4 +621,37 @@ void EinsumKernel(const Context& dev_ctx, } } +template +void EinsumKernelRaw(const Context& dev_ctx, + const std::vector& inputs, + const std::string& equation, + DenseTensor* out, + std::vector cache) { + std::vector tmp; + // for the sake of compatibility, we may load and run v2.3 EinsumOp. Output + // may have nullptr and the cache.size() is not equal to inputs.size(). refer + // to BuildPhiKernelContext for details. + int diff = inputs.size() - cache.size(); + for (int i = 0; i < diff; ++i) { + cache.push_back(nullptr); + } + EinsumKernelImpl( + dev_ctx, tmp, inputs, equation, out, cache, /*forward=*/true); +} + +template +void EinsumKernel(const Context& dev_ctx, + const std::vector& inputs, + const std::string& equation, + DenseTensor* out) { + std::vector place_holder; + std::vector cache_tensor( + inputs.size()); // set empty; TA, TB, TdC + for (size_t i = 0; i < inputs.size(); ++i) { + cache_tensor[i] = nullptr; + } + EinsumKernelImpl( + dev_ctx, place_holder, inputs, equation, out, cache_tensor, true); +} + } // namespace phi diff --git a/paddle/phi/ops/compat/einsum_sig.cc b/paddle/phi/ops/compat/einsum_sig.cc index 0b3cc3425df..5e45bcf97ce 100644 --- a/paddle/phi/ops/compat/einsum_sig.cc +++ b/paddle/phi/ops/compat/einsum_sig.cc @@ -17,14 +17,15 @@ limitations under the License. */ namespace phi { KernelSignature EinsumOpArgumentMapping(const ArgumentMappingContext& ctx) { - return KernelSignature("einsum", {"Operands"}, {"equation"}, {"Out"}); + return KernelSignature( + "einsum", {"Operands"}, {"equation"}, {"Out", "InnerCache"}); } KernelSignature EinsumGradOpArgumentMapping(const ArgumentMappingContext& ctx) { return KernelSignature("einsum_grad", - {"Operands", {"Out@GRAD"}}, + {"Operands", "InnerCache", "Out@GRAD"}, {"equation"}, - {{"Operands@GRAD"}}); + {"Operands@GRAD"}); } } // namespace phi diff --git a/python/paddle/fluid/tests/unittests/test_einsum.py b/python/paddle/fluid/tests/unittests/test_einsum.py index 43b5ce96a39..26aaf0f44f1 100644 --- a/python/paddle/fluid/tests/unittests/test_einsum.py +++ b/python/paddle/fluid/tests/unittests/test_einsum.py @@ -18,6 +18,9 @@ import unittest import paddle from paddle.fluid import core +import os +os.environ['FLAGS_new_einsum'] = "0" + class TestErrors(unittest.TestCase): def setUp(self): diff --git a/python/paddle/fluid/tests/unittests/test_einsum_op.py b/python/paddle/fluid/tests/unittests/test_einsum_op.py index 565e43214ea..1a4ae54afef 100644 --- a/python/paddle/fluid/tests/unittests/test_einsum_op.py +++ b/python/paddle/fluid/tests/unittests/test_einsum_op.py @@ -34,7 +34,11 @@ class TestEinsumBinary(OpTest): self.operands.append(("x" + str(idx), inp)) self.inputs = {"Operands": self.operands} self.attrs = {"equation": self.equation} - self.outputs = {'Out': out} + self.outputs = { + 'Out': out, + "InnerCache": [('cache_' + str(i), np.array([1.0])) + for i in range(len(self.operands))] + } def init_input(self): self.inputs = [] @@ -49,7 +53,7 @@ class TestEinsumBinary(OpTest): def test_check_output(self): if not self.disable: - self.check_output() + self.check_output(no_check_set=["InnerCache"]) def test_grad(self): if not self.disable: diff --git a/python/paddle/fluid/tests/unittests/test_einsum_v2.py b/python/paddle/fluid/tests/unittests/test_einsum_v2.py index 63acaf63969..b33a943c9f2 100644 --- a/python/paddle/fluid/tests/unittests/test_einsum_v2.py +++ b/python/paddle/fluid/tests/unittests/test_einsum_v2.py @@ -464,5 +464,37 @@ class TestNumpyTests(unittest.TestCase): self.check_output_equal(a, e) +class TestStaticGraphShape(unittest.TestCase): + def setUp(self): + paddle.enable_static() + + def tearDown(self): + paddle.disable_static() + + def test_shape(self): + A = paddle.static.data(name='x', shape=[-1]) + B = paddle.static.data(name='y', shape=[384]) + C = paddle.einsum('i,d->id', A, B) + self.assertEqual(C.shape, (-1, 384)) + + +class TestBF16(unittest.TestCase): + """ + EinsumOp support bfloat16 type, add unittest here for the correctness. + """ + + def test_shape(self): + cuda_major = paddle.version.cuda().split('.')[0].strip() + if paddle.is_compiled_with_cuda() and int(cuda_major) >= 11: + """ MatmulKernel support bfloat16 only if cuda_major > 11.0. + """ + A = paddle.to_tensor(np.array([1.0, 2.0])).astype(paddle.bfloat16) + A = A.cuda() + B = paddle.to_tensor(np.array([2.0, 3.0])).astype(paddle.bfloat16) + B = B.cuda() + C = paddle.einsum('i,i->', A, B) + self.assertEqual(C.item(), 8.0) + + if __name__ == "__main__": - u + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/white_list/no_check_set_white_list.py b/python/paddle/fluid/tests/unittests/white_list/no_check_set_white_list.py index 23bbc377cae..ea3264ba0db 100644 --- a/python/paddle/fluid/tests/unittests/white_list/no_check_set_white_list.py +++ b/python/paddle/fluid/tests/unittests/white_list/no_check_set_white_list.py @@ -35,4 +35,5 @@ no_check_set_white_list = [ 'eigh', 'eigvalsh', 'class_center_sample', + 'einsum', ] diff --git a/python/paddle/tensor/einsum.py b/python/paddle/tensor/einsum.py index 713a611f9f3..49cc426a00f 100644 --- a/python/paddle/tensor/einsum.py +++ b/python/paddle/tensor/einsum.py @@ -798,11 +798,12 @@ def gen_einsum_op(equation, *operands): """ assert len(operands) <= 2, "Only support two operands in EinsumOp." if in_dygraph_mode(): - return _C_ops.final_state_einsum(operands, equation) + return _C_ops.final_state_einsum(operands, equation)[0] if _in_legacy_dygraph(): # dygraph - return _C_ops.einsum(operands, 'equation', equation) + return _C_ops.einsum(operands, len(operands), 'equation', equation)[0] + # static graph for inp in operands: check_variable_and_dtype(inp, 'dtype', ['float32', 'float64'], 'einsum') @@ -811,11 +812,16 @@ def gen_einsum_op(equation, *operands): out = helper.create_variable_for_type_inference(dtype=operands[0].dtype) attrs = dict() attrs['equation'] = equation + caches = [ + helper.create_variable_for_type_inference(dtype=operands[0].dtype) + for i in range(len(operands)) + ] helper.append_op( type='einsum', inputs={'Operands': operands}, - outputs={'Out': out}, - attrs=attrs, ) + outputs={'Out': out, + "InnerCache": caches}, + attrs=attrs) return out @@ -977,7 +983,7 @@ def einsum(equation, *operands): # [0.51476848, 0.23367381, 0.39229113]]]) """ import os - if int(os.environ.get('FLAGS_new_einsum', "0")): + if int(os.environ.get('FLAGS_new_einsum', "1")): return einsum_v2(equation, *operands) nop = len(operands) diff --git a/python/paddle/utils/code_gen/api.yaml b/python/paddle/utils/code_gen/api.yaml index c0e2c39f818..845f6b6ba2f 100644 --- a/python/paddle/utils/code_gen/api.yaml +++ b/python/paddle/utils/code_gen/api.yaml @@ -547,7 +547,7 @@ - api : einsum args : (Tensor[] x, str equation) - output : Tensor + output : Tensor, Tensor[]{x.size()} infer_meta : func : EinsumInferMeta param : [x, equation] diff --git a/python/paddle/utils/code_gen/api_base.py b/python/paddle/utils/code_gen/api_base.py index de4e72bbe41..b618e6994db 100644 --- a/python/paddle/utils/code_gen/api_base.py +++ b/python/paddle/utils/code_gen/api_base.py @@ -205,17 +205,19 @@ class BaseAPI(object): if len(temp_list) == 1: out_type, out_name, size_expr = parse_output_item(temp_list[0]) - return [out_type], [out_name], size_expr, self.get_return_type( + return [out_type], [out_name], [size_expr], self.get_return_type( [out_type]) else: out_type_list = [] out_name_list = [] + out_size_expr_list = [] for output_item in temp_list: out_type, out_name, size_expr = parse_output_item(output_item) out_type_list.append(out_type) out_name_list.append(out_name) + out_size_expr_list.append(size_expr) - return out_type_list, out_name_list, size_expr, self.get_return_type( + return out_type_list, out_name_list, out_size_expr_list, self.get_return_type( out_type_list) def parse_infer_meta(self, infer_meta_config): diff --git a/python/paddle/utils/code_gen/api_gen.py b/python/paddle/utils/code_gen/api_gen.py index 538958c2361..08c9edb5764 100644 --- a/python/paddle/utils/code_gen/api_gen.py +++ b/python/paddle/utils/code_gen/api_gen.py @@ -94,10 +94,10 @@ class ForwardAPI(BaseAPI): {code_indent} {self.outputs['return_type']} api_output{inplace_assign};""" if self.outputs['return_type'] == 'std::vector': - assert self.outputs['out_size_expr'] is not None, \ + assert self.outputs['out_size_expr'][0] is not None, \ f"{api_name}: The out size expr : '{{expr}}' should be set when output has Tensor[]. You can refer 'split' api." output_create = output_create + f""" -{code_indent} auto kernel_out = {set_out_func}({self.outputs['out_size_expr']}, kernel_backend, &api_output);""" +{code_indent} auto kernel_out = {set_out_func}({self.outputs['out_size_expr'][0]}, kernel_backend, &api_output);""" else: output_create = output_create + f""" diff --git a/python/paddle/utils/code_gen/backward.yaml b/python/paddle/utils/code_gen/backward.yaml index f6e5f61e00a..e16c57d4d89 100644 --- a/python/paddle/utils/code_gen/backward.yaml +++ b/python/paddle/utils/code_gen/backward.yaml @@ -447,8 +447,8 @@ skip_transform : out_w, out_w_grad - backward_api : einsum_grad - forward : einsum (Tensor[] x, str equation) -> Tensor(out) - args : (Tensor[] x, Tensor out_grad, str equation) + forward : einsum (Tensor[] x, str equation) -> Tensor(out), Tensor[](inner_cache) + args : (Tensor[] x, Tensor[] inner_cache, Tensor out_grad, str equation) output : Tensor[](x_grad){x.size()} infer_meta : func : UnchangedMultiInferMeta diff --git a/python/paddle/utils/code_gen/backward_api_gen.py b/python/paddle/utils/code_gen/backward_api_gen.py index a88339c607c..d3ed928592d 100644 --- a/python/paddle/utils/code_gen/backward_api_gen.py +++ b/python/paddle/utils/code_gen/backward_api_gen.py @@ -113,10 +113,10 @@ class BackwardAPI(BaseAPI): {code_indent} {self.outputs['return_type']} api_output{inplace_assign};""" if output_type_list[0] == 'std::vector': - assert self.outputs['out_size_expr'] is not None, \ + assert self.outputs['out_size_expr'][0] is not None, \ f"{api_name}: The out size expr : '{{expr}}' should be set when output has Tensor[]. You can refer 'split' api." output_create = output_create + f""" -{code_indent} auto kernel_out = {set_out_func}({self.outputs['out_size_expr']}, kernel_backend, &api_output);""" +{code_indent} auto kernel_out = {set_out_func}({self.outputs['out_size_expr'][0]}, kernel_backend, &api_output);""" else: output_create = output_create + f""" -- GitLab