diff --git a/paddle/fluid/eager/nan_inf_utils.cc b/paddle/fluid/eager/nan_inf_utils.cc new file mode 100644 index 0000000000000000000000000000000000000000..d1c5983a3702f39a679983638a32e9588e16ff4a --- /dev/null +++ b/paddle/fluid/eager/nan_inf_utils.cc @@ -0,0 +1,119 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/eager/nan_inf_utils.h" + +#include "paddle/fluid/framework/details/nan_inf_utils_detail.h" +#include "paddle/phi/core/dense_tensor.h" +#include "paddle/phi/core/selected_rows.h" + +namespace egr { + +void CheckTensorHasNanOrInf(const std::string& api_name, const Tensor& tensor) { + if (tensor.initialized()) { + auto& tensor_name = tensor.name(); + const phi::DenseTensor* dense_tensor{nullptr}; + if (tensor.is_dense_tensor()) { + dense_tensor = static_cast(tensor.impl().get()); + } else if (tensor.is_selected_rows()) { + dense_tensor = &( + static_cast(tensor.impl().get())->value()); + } else { + VLOG(10) << "Only DenseTensor or SelectedRows need to check, " + << tensor_name << " is no need."; + return; + } + + auto& place = dense_tensor->place(); + if (paddle::platform::is_gpu_place(place)) { +#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) + paddle::framework::details::tensor_check< + paddle::platform::CUDADeviceContext>(api_name, tensor_name, + *dense_tensor, place); +#else + PADDLE_THROW(paddle::platform::errors::PreconditionNotMet( + 
"Tensor[%s] use gpu place. PaddlePaddle must compile with GPU.", + tensor_name)); +#endif + return; + } + paddle::framework::details::tensor_check< + paddle::platform::CPUDeviceContext>(api_name, tensor_name, + *dense_tensor, place); + } +} + +void CheckTensorHasNanOrInf(const std::string& api_name, + const TupleOfTwoTensors& tensors) { + CheckTensorHasNanOrInf(api_name, std::get<0>(tensors)); + CheckTensorHasNanOrInf(api_name, std::get<1>(tensors)); +} + +void CheckTensorHasNanOrInf(const std::string& api_name, + const TupleOfThreeTensors& tensors) { + CheckTensorHasNanOrInf(api_name, std::get<0>(tensors)); + CheckTensorHasNanOrInf(api_name, std::get<1>(tensors)); + CheckTensorHasNanOrInf(api_name, std::get<2>(tensors)); +} + +void CheckTensorHasNanOrInf(const std::string& api_name, + const TupleOfFourTensors& tensors) { + CheckTensorHasNanOrInf(api_name, std::get<0>(tensors)); + CheckTensorHasNanOrInf(api_name, std::get<1>(tensors)); + CheckTensorHasNanOrInf(api_name, std::get<2>(tensors)); + CheckTensorHasNanOrInf(api_name, std::get<3>(tensors)); +} + +void CheckTensorHasNanOrInf(const std::string& api_name, + const TupleOfFiveTensors& tensors) { + CheckTensorHasNanOrInf(api_name, std::get<0>(tensors)); + CheckTensorHasNanOrInf(api_name, std::get<1>(tensors)); + CheckTensorHasNanOrInf(api_name, std::get<2>(tensors)); + CheckTensorHasNanOrInf(api_name, std::get<3>(tensors)); + CheckTensorHasNanOrInf(api_name, std::get<4>(tensors)); +} + +void CheckTensorHasNanOrInf(const std::string& api_name, + const TupleOfSixTensors& tensors) { + CheckTensorHasNanOrInf(api_name, std::get<0>(tensors)); + CheckTensorHasNanOrInf(api_name, std::get<1>(tensors)); + CheckTensorHasNanOrInf(api_name, std::get<2>(tensors)); + CheckTensorHasNanOrInf(api_name, std::get<3>(tensors)); + CheckTensorHasNanOrInf(api_name, std::get<4>(tensors)); + CheckTensorHasNanOrInf(api_name, std::get<5>(tensors)); +} + +void CheckTensorHasNanOrInf(const std::string& api_name, + const std::vector& tensors) 
{ + for (auto& tensor : tensors) { + CheckTensorHasNanOrInf(api_name, tensor); + } +} + +void CheckTensorHasNanOrInf( + const std::string& api_name, + const paddle::small_vector, + egr::kSlotSmallVectorSize>& tensors) { + for (auto& tensor_vector : tensors) { + CheckTensorHasNanOrInf(api_name, tensor_vector); + } +} + +void CheckTensorHasNanOrInf(const std::string& api_name, + const TupleOfTensorAndVector& tensors) { + CheckTensorHasNanOrInf(api_name, std::get<0>(tensors)); + CheckTensorHasNanOrInf(api_name, std::get<1>(tensors)); +} + +} // namespace egr diff --git a/paddle/fluid/eager/nan_inf_utils.h b/paddle/fluid/eager/nan_inf_utils.h new file mode 100644 index 0000000000000000000000000000000000000000..a411504fa4900d0a0f047e3d2c13a047fdd03888 --- /dev/null +++ b/paddle/fluid/eager/nan_inf_utils.h @@ -0,0 +1,64 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#pragma once + +#include +#include +#include + +#include "paddle/fluid/eager/type_defs.h" +#include "paddle/phi/api/include/tensor.h" +#include "paddle/utils/small_vector.h" + +namespace egr { + +using paddle::experimental::Tensor; +using TupleOfTwoTensors = std::tuple; +using TupleOfThreeTensors = std::tuple; +using TupleOfFourTensors = std::tuple; +using TupleOfFiveTensors = std::tuple; +using TupleOfSixTensors = + std::tuple; +using TupleOfTensorAndVector = std::tuple>; + +void CheckTensorHasNanOrInf(const std::string& api_name, const Tensor& tensor); + +void CheckTensorHasNanOrInf(const std::string& api_name, + const TupleOfTwoTensors& tensors); + +void CheckTensorHasNanOrInf(const std::string& api_name, + const TupleOfThreeTensors& tensors); + +void CheckTensorHasNanOrInf(const std::string& api_name, + const TupleOfFourTensors& tensors); + +void CheckTensorHasNanOrInf(const std::string& api_name, + const TupleOfFiveTensors& tensors); + +void CheckTensorHasNanOrInf(const std::string& api_name, + const TupleOfSixTensors& tensors); + +void CheckTensorHasNanOrInf(const std::string& api_name, + const std::vector& tensors); + +void CheckTensorHasNanOrInf(const std::string& api_name, + const TupleOfTensorAndVector& tensors); + +void CheckTensorHasNanOrInf( + const std::string& api_name, + const paddle::small_vector, + egr::kSlotSmallVectorSize>& tensors); + +} // namespace egr diff --git a/paddle/fluid/operators/einsum_op.cc b/paddle/fluid/operators/einsum_op.cc index 8fdde1ccdc058be3ada3736a15f7ec249e8b868b..6da0045443cccdbad2965e9c9b320a41c2015d4d 100644 --- a/paddle/fluid/operators/einsum_op.cc +++ b/paddle/fluid/operators/einsum_op.cc @@ -33,6 +33,13 @@ class EinsumOpMaker : public framework::OpProtoAndCheckerMaker { AddInput("Operands", "(TensorList), The input tensor of einsum op.") .AsDuplicable(); AddOutput("Out", "(Tensor), The output tensor of einsum op."); + AddOutput( + "InnerCache", + "(Tensor), The cache of the forward transpose tensors: tA and tB.") 
+ .AsDuplicable() + .AsExtra() + .AsIntermediate(); + AddAttr("equation", "(string) A einsum equation. such as `ij,jk->ik`" "There must have `->` and the number of operands in " @@ -72,6 +79,7 @@ class EinsumGradMaker : public framework::SingleGradOpMaker { void Apply(GradOpPtr retv) const override { retv->SetType("einsum_grad"); retv->SetInput("Operands", this->Input("Operands")); + retv->SetInput("InnerCache", this->Output("InnerCache")); retv->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out")); retv->SetAttrMap(this->Attrs()); retv->SetOutput(framework::GradVarName("Operands"), diff --git a/paddle/fluid/platform/flags.cc b/paddle/fluid/platform/flags.cc index c70452c5016e88e0b41299cbc662ee3cc49d04b3..aa9b68289a555134a72fac94601cee4a2cd84d2c 100644 --- a/paddle/fluid/platform/flags.cc +++ b/paddle/fluid/platform/flags.cc @@ -808,3 +808,16 @@ PADDLE_DEFINE_EXPORTED_bool(nccl_blocking_wait, false, "nccl blocking wait"); * Example: */ PADDLE_DEFINE_EXPORTED_bool(use_autotune, false, "Whether enable autotune."); + +/** + * Performance related FLAG + * Name: einsum_opt + * Since Version: 2.3.0 + * Value Range: bool, default=false + * Example: + * Note: If True, EinsumOp will be optimized by innercache reuse, which + * uses more gpu memory. + */ +PADDLE_DEFINE_EXPORTED_bool( + einsum_opt, false, + "EinsumOp backward will be speedup at the expense of more gpu memory."); diff --git a/paddle/phi/infermeta/unary.cc b/paddle/phi/infermeta/unary.cc index b7959d2809e4f64ef06f719818222c049daf5b41..980b4219c51dd01dd3af3a48287d65b90083c728 100644 --- a/paddle/phi/infermeta/unary.cc +++ b/paddle/phi/infermeta/unary.cc @@ -401,7 +401,8 @@ void EighInferMeta(const MetaTensor& x, void EinsumInferMeta(const std::vector& inputs, const std::string& equation, - MetaTensor* out) { + MetaTensor* out, + std::vector inner_cache) { // collect the following informations to prepare einsum. 
LabelMap labelshape(0); LabelMap labeltype(LabelType::Reduction); diff --git a/paddle/phi/infermeta/unary.h b/paddle/phi/infermeta/unary.h index 647b9c578c8e30a1a8259004c5b5ec40ff7c8362..e141acb2ea293c9ad11826ca9865f5af89c1c614 100644 --- a/paddle/phi/infermeta/unary.h +++ b/paddle/phi/infermeta/unary.h @@ -82,7 +82,8 @@ void EighInferMeta(const MetaTensor& x, void EinsumInferMeta(const std::vector& inputs, const std::string& equation, - MetaTensor* out); + MetaTensor* out, + std::vector inner_cache); void ExpandInferMeta(const MetaTensor& x, const IntArray& shape, diff --git a/paddle/phi/kernels/cpu/einsum_kernel.cc b/paddle/phi/kernels/cpu/einsum_kernel.cc index 3e25a65526d8965c0954c2ba13f6a2c28380d3f2..8968542b3e0b898f0a058c4588932c8bac2c97bf 100644 --- a/paddle/phi/kernels/cpu/einsum_kernel.cc +++ b/paddle/phi/kernels/cpu/einsum_kernel.cc @@ -17,4 +17,5 @@ #include "paddle/phi/core/kernel_registry.h" #include "paddle/phi/kernels/impl/einsum_impl.h" -PD_REGISTER_KERNEL(einsum, CPU, ALL_LAYOUT, phi::EinsumKernel, float, double) {} +PD_REGISTER_KERNEL( + einsum, CPU, ALL_LAYOUT, phi::EinsumKernelRaw, float, double) {} diff --git a/paddle/phi/kernels/einsum_grad_kernel.h b/paddle/phi/kernels/einsum_grad_kernel.h index 5c1970e775825492412239f76a7bf00e22a12d5c..06785c8532e70d560a61fbf9e50b8f891fcdd317 100644 --- a/paddle/phi/kernels/einsum_grad_kernel.h +++ b/paddle/phi/kernels/einsum_grad_kernel.h @@ -21,6 +21,7 @@ namespace phi { template void EinsumGradKernel(const Context& dev_ctx, const std::vector& x, + const std::vector& inner_cache, const DenseTensor& out_grad, const std::string& equation, std::vector x_grad); diff --git a/paddle/phi/kernels/einsum_kernel.h b/paddle/phi/kernels/einsum_kernel.h index 3d9e8feda748de24822ee95304016a5cd8a5676e..87df2b1c64a4a993b054b42926fa0508ed5e8a96 100644 --- a/paddle/phi/kernels/einsum_kernel.h +++ b/paddle/phi/kernels/einsum_kernel.h @@ -24,4 +24,11 @@ void EinsumKernel(const Context& dev_ctx, const std::string& equation, 
DenseTensor* out); +template +void EinsumKernelRaw(const Context& dev_ctx, + const std::vector& inputs, + const std::string& equation, + DenseTensor* out, + std::vector cache); + } // namespace phi diff --git a/paddle/phi/kernels/funcs/eigen/broadcast.cc b/paddle/phi/kernels/funcs/eigen/broadcast.cc index 3459d7acd6baf0c4192719159ff02db721886332..008c51249f2497241af3bd9765aeadc4e4e425ba 100644 --- a/paddle/phi/kernels/funcs/eigen/broadcast.cc +++ b/paddle/phi/kernels/funcs/eigen/broadcast.cc @@ -11,6 +11,7 @@ distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ +#include "paddle/phi/common/bfloat16.h" #include "paddle/phi/common/float16.h" #include "paddle/phi/kernels/funcs/eigen/eigen_function.h" @@ -73,6 +74,7 @@ struct EigenBroadcastGrad { template struct FUNCTOR INSTANTIATION(EigenBroadcast, bool); INSTANTIATION(EigenBroadcast, dtype::float16); +INSTANTIATION(EigenBroadcast, dtype::bfloat16); INSTANTIATION(EigenBroadcast, float); INSTANTIATION(EigenBroadcast, double); INSTANTIATION(EigenBroadcast, int); diff --git a/paddle/phi/kernels/funcs/eigen/broadcast.cu b/paddle/phi/kernels/funcs/eigen/broadcast.cu index d9de69ec55e8b5852766cbfc48244aa6347438fc..742081a30c1a0aa249012be3003573e0330e1266 100644 --- a/paddle/phi/kernels/funcs/eigen/broadcast.cu +++ b/paddle/phi/kernels/funcs/eigen/broadcast.cu @@ -11,6 +11,7 @@ distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ +#include "paddle/phi/common/bfloat16.h" #include "paddle/phi/common/float16.h" #include "paddle/phi/kernels/funcs/eigen/eigen_function.h" @@ -73,6 +74,7 @@ struct EigenBroadcastGrad { template struct FUNCTOR INSTANTIATION(EigenBroadcast, bool); INSTANTIATION(EigenBroadcast, dtype::float16); +INSTANTIATION(EigenBroadcast, dtype::bfloat16); INSTANTIATION(EigenBroadcast, float); INSTANTIATION(EigenBroadcast, double); INSTANTIATION(EigenBroadcast, int); diff --git a/paddle/phi/kernels/gpu/einsum_grad_kernel.cu b/paddle/phi/kernels/gpu/einsum_grad_kernel.cu index c8a8745f34522ddc902dd8af9bf098714c398a20..950f811475c99f508654d36b3fbc5c3131f22e41 100644 --- a/paddle/phi/kernels/gpu/einsum_grad_kernel.cu +++ b/paddle/phi/kernels/gpu/einsum_grad_kernel.cu @@ -18,5 +18,11 @@ #include "paddle/phi/core/kernel_registry.h" #include "paddle/phi/kernels/impl/einsum_grad_impl.h" -PD_REGISTER_KERNEL( - einsum_grad, GPU, ALL_LAYOUT, phi::EinsumGradKernel, float, double) {} +PD_REGISTER_KERNEL(einsum_grad, + GPU, + ALL_LAYOUT, + phi::EinsumGradKernel, + float, + double, + phi::dtype::float16, + phi::dtype::bfloat16) {} diff --git a/paddle/phi/kernels/gpu/einsum_kernel.cu b/paddle/phi/kernels/gpu/einsum_kernel.cu index d73e154eb40f7934f99e2b3a718e0dbb64c8d6c5..d1f4c6590387a81464a3fdceec0442934e8b2940 100644 --- a/paddle/phi/kernels/gpu/einsum_kernel.cu +++ b/paddle/phi/kernels/gpu/einsum_kernel.cu @@ -18,4 +18,11 @@ #include "paddle/phi/core/kernel_registry.h" #include "paddle/phi/kernels/impl/einsum_impl.h" -PD_REGISTER_KERNEL(einsum, GPU, ALL_LAYOUT, phi::EinsumKernel, float, double) {} +PD_REGISTER_KERNEL(einsum, + GPU, + ALL_LAYOUT, + phi::EinsumKernelRaw, + float, + double, + phi::dtype::float16, + phi::dtype::bfloat16) {} diff --git a/paddle/phi/kernels/gpu/tile_kernel.cu b/paddle/phi/kernels/gpu/tile_kernel.cu index 0c3c29e82c42aefae33a4a9be9e9a7d9ec0c1e99..990877a8445cbf98d5b487a969a24059dbc24c84 100644 --- a/paddle/phi/kernels/gpu/tile_kernel.cu +++ 
b/paddle/phi/kernels/gpu/tile_kernel.cu @@ -27,4 +27,5 @@ PD_REGISTER_KERNEL(tile, double, int, int64_t, - phi::dtype::float16) {} + phi::dtype::float16, + phi::dtype::bfloat16) {} diff --git a/paddle/phi/kernels/impl/einsum_grad_impl.h b/paddle/phi/kernels/impl/einsum_grad_impl.h index bd0143379ce1510e9417cba0209ecb3d251b5f71..a72db326807f8eea865b197e6723924413e29a9b 100644 --- a/paddle/phi/kernels/impl/einsum_grad_impl.h +++ b/paddle/phi/kernels/impl/einsum_grad_impl.h @@ -13,6 +13,7 @@ // limitations under the License. #pragma once +#include "paddle/fluid/platform/profiler.h" #include "paddle/phi/core/dense_tensor.h" #include "paddle/phi/kernels/impl/einsum_impl.h" #include "paddle/phi/kernels/tile_kernel.h" @@ -55,7 +56,13 @@ DenseTensor PerformTileAndReduction(const Context& dev_ctx, } t.Resize(make_ddim(resize_dims)); DenseTensor after_tile; - TileKernel(dev_ctx, t, repeat_times, &after_tile); + if (std::all_of(repeat_times.begin(), repeat_times.end(), [](int x) { + return x == 1; + })) { + after_tile = t; + } else { + TileKernel(dev_ctx, t, repeat_times, &after_tile); + } size_t n_ellipsis_idx = op_label.find(".", 0); if (n_ellipsis_idx != std::string::npos) { // may be we need reduce. broadcast_dims is not equal to ellipsis dims. 
@@ -91,10 +98,11 @@ DenseTensor PerformTileAndReduction(const Context& dev_ctx, template void EinsumGradKernel(const Context& dev_ctx, const std::vector& x, + const std::vector& inner_cache, const DenseTensor& out_grad, const std::string& equation, std::vector x_grad) { - VLOG(5) << "Start EisumGradKernel:"; + VLOG(5) << "Start EinsumGradKernel:"; LabelMap labelshape(0); LabelMap labeltype(LabelType::Reduction); std::vector label2perms(x.size(), LabelMap(-1)); @@ -148,34 +156,65 @@ void EinsumGradKernel(const Context& dev_ctx, right = splits[1].substr(1); auto equation_for_A = - right + "," + ops[1] + "->" + gather_labels_except_reduction(ops[0]); + ops[1] + "," + right + "->" + gather_labels_except_reduction(ops[0]); auto equation_for_B = right + "," + ops[0] + "->" + gather_labels_except_reduction(ops[1]); auto operands_for_A = std::vector(); auto operands_for_B = std::vector(); DenseTensor dA, dB; - operands_for_A.push_back(&out_grad); + // dA = einsum(B, dC) operands_for_A.push_back(x[1]); + operands_for_A.push_back(&out_grad); + // dB = einsum(dC, A) operands_for_B.push_back(&out_grad); operands_for_B.push_back(x[0]); DenseTensor before_tile; - EinsumKernel(dev_ctx, operands_for_A, equation_for_A, &dA); - EinsumKernel(dev_ctx, operands_for_B, equation_for_B, &dB); - *(x_grad[0]) = PerformTileAndReduction(dev_ctx, - labeltype, - labelshape, - broadcast_dims, - ellipsis_dims[0], - ops[0], - dA); - *(x_grad[1]) = PerformTileAndReduction(dev_ctx, - labeltype, - labelshape, - broadcast_dims, - ellipsis_dims[1], - ops[1], - dB); + + std::vector cache(3); // set empty; TA, TB, TdC + if (inner_cache.size() > + 0) { // for compatibility, we can load and run v2.3 EinsumOp. 
+ cache[0].ShareBufferWith(*(inner_cache[0])); + cache[1].ShareBufferWith(*(inner_cache[1])); + } + + EinsumKernelImpl(dev_ctx, + all_labels, + operands_for_A, + equation_for_A, + &dA, + {&cache[1], &cache[2]}, + false); + + EinsumKernelImpl(dev_ctx, + all_labels, + operands_for_B, + equation_for_B, + &dB, + {&cache[2], &cache[0]}, + false); + + // release the cache tensor dTC to save memory right now. they are useless + // now. + cache.clear(); + if (x_grad[0]) { + *(x_grad[0]) = PerformTileAndReduction(dev_ctx, + labeltype, + labelshape, + broadcast_dims, + ellipsis_dims[0], + ops[0], + dA); + } + if (x_grad[1]) { + *(x_grad[1]) = PerformTileAndReduction(dev_ctx, + labeltype, + labelshape, + broadcast_dims, + ellipsis_dims[1], + ops[1], + dB); + } } } } // namespace phi diff --git a/paddle/phi/kernels/impl/einsum_impl.h b/paddle/phi/kernels/impl/einsum_impl.h index 73940a45cbde2b5b5f301b6a1f7d7c328c1b53c1..bfbd6e0c51cfc7b6592b78335810c75e449fb042 100644 --- a/paddle/phi/kernels/impl/einsum_impl.h +++ b/paddle/phi/kernels/impl/einsum_impl.h @@ -13,12 +13,15 @@ // limitations under the License. #pragma once +#include #include "paddle/phi/core/dense_tensor.h" #include "paddle/phi/kernels/matmul_kernel.h" #include "paddle/phi/kernels/reduce_sum_kernel.h" #include "paddle/phi/kernels/transpose_kernel.h" #include "paddle/utils/string/string_helper.h" +DECLARE_bool(einsum_opt); + namespace phi { // check the validation of the Einsum equation. 
@@ -55,7 +58,8 @@ inline static void ValidationCheck(const std::string& equation) { enum LabelType { ALL_TYPE = 0, Batch = 1, // ABO - Free, // AO, BO + AO, // AO -- free label + BO, // BO -- free label Contraction, // AB Reduction, // A, B }; @@ -125,18 +129,31 @@ inline std::vector union_labels(const std::vector& a, return res; } +// Regroup all_labels by label type, following the type order given in new_order and keeping the relative order within each type. +inline std::vector TransformLabelsOrder( + const std::vector& all_labels, + const LabelMap& type, + std::vector new_order) { + std::vector ret; + for (auto cnt_type : new_order) { + std::vector tmp; + for (int c : all_labels) { + if (type[c] == cnt_type) tmp.push_back(c); + } + ret.insert(ret.end(), tmp.begin(), tmp.end()); + } + return ret; +} + inline static void GlobalInfo(const std::vector& op_labels, const std::string& right, LabelMap* label2type, std::vector* sorted_labels) { - // sorted_labels: ['.', , ] - VLOG(5) << "GlobalInfo: " - << paddle::string::join_strings(*sorted_labels, ","); std::vector all; LabelMap counter(0); for (auto& ch : right) { // char int c = ch; - (*label2type)[c] = LabelType::Free; + (*label2type)[c] = LabelType::BO; } for (auto& op : op_labels) { @@ -146,39 +163,45 @@ inline static void GlobalInfo(const std::vector& op_labels, all.push_back(ch); } counter[c] += 1; - if ((*label2type)[c] != LabelType::Free && counter[c] == 2) + if ((*label2type)[c] != LabelType::BO && counter[c] == 2) (*label2type)[c] = LabelType::Contraction; else if (counter[c] == 2) (*label2type)[c] = LabelType::Batch; } } + + // So far BO marks every free label; relabel those found in the first operand as AO. 
+ for (int c : op_labels[0]) { + if ((*label2type)[c] == LabelType::BO) (*label2type)[c] = LabelType::AO; + } + (*label2type)['.'] = LabelType::Batch; - std::for_each(all.begin(), all.end(), [sorted_labels, label2type](int c) { - if ((*label2type)[c] == LabelType::Batch) - sorted_labels->push_back(static_cast(c)); - }); - std::for_each(all.begin(), all.end(), [sorted_labels, label2type](int c) { - if ((*label2type)[c] == LabelType::Free) - sorted_labels->push_back(static_cast(c)); - }); - std::for_each(all.begin(), all.end(), [sorted_labels, label2type](int c) { - if ((*label2type)[c] == LabelType::Contraction) - sorted_labels->push_back(static_cast(c)); - }); - std::for_each(all.begin(), all.end(), [&sorted_labels, label2type](int c) { - if ((*label2type)[c] == LabelType::Reduction) - sorted_labels->push_back(static_cast(c)); - }); - VLOG(5) << "GlobalInfo: sorted_labels before: " - << paddle::string::join_strings(*sorted_labels, ","); + + if (sorted_labels->size()) { + std::set exist(all.begin(), all.end()); + all.clear(); + std::for_each( + sorted_labels->begin(), sorted_labels->end(), [&exist, &all](char c) { + if (exist.count(c)) all.push_back(c); + }); + } + + *sorted_labels = TransformLabelsOrder(all, + *label2type, + {LabelType::Batch, + LabelType::AO, + LabelType::BO, + LabelType::Contraction, + LabelType::Reduction}); + if (counter[static_cast('.')] > 0) { std::vector tmp; tmp.push_back('.'); // push '.' 
in the front *sorted_labels = union_labels(tmp, *sorted_labels); - VLOG(5) << "GlobalInfo: sorted_labels after: " - << paddle::string::join_strings(*sorted_labels, ","); } + VLOG(5) << "GlobalInfo: sorted_labels after: " + << paddle::string::join_strings(*sorted_labels, ","); } inline static void InferLabelShape(const std::vector& op_labels, @@ -289,17 +312,20 @@ inline static void ParseEinsumEquation( *right = results[1].substr(1); ReplaceEllipsis(*right); auto op_labels = paddle::string::split_string(left, ","); + // split_string("i,") -> ["i"], we expect 2 op_labels. + if (left[left.size() - 1] == ',') op_labels.push_back(""); std::for_each(op_labels.begin(), op_labels.end(), ReplaceEllipsis); GlobalInfo(op_labels, *right, labeltype, all_labels); InferLabelShape(op_labels, inputs, labelshape, ellipsis_dims, broadcast_dims); - VLOG(5) << "Einsum Infershape: right:" << right; - VLOG(5) << "Einsum Infershape: op_labels:" - << paddle::string::join_strings(op_labels, "\n"); + VLOG(5) << "Einsum Infershape: right:" << *right; + VLOG(5) << "Einsum Infershape: left :" + << paddle::string::join_strings(op_labels, '\n'); InferOutputDims(*right, *broadcast_dims, *labelshape, output_dims); for (size_t i = 0; i < inputs.size(); ++i) { InferLabelPerm( op_labels[i], ellipsis_dims->at(i).size(), &((*label2perms)[i])); } + VLOG(5) << "Einsum Infershape: end"; } template @@ -327,10 +353,12 @@ std::vector GetShapeByType(const std::vector& all_labels, const LabelMap& perm, const LabelMap& label2shape, const std::vector& ellipsis, - LabelType filter) { + std::set filter) { std::vector res; for (T c : all_labels) { - if ((filter == LabelType::ALL_TYPE || type[c] == filter) && perm[c] != -1) { + if ((filter.count(LabelType::ALL_TYPE) || + filter.count(LabelType(type[c]))) && + perm[c] != -1) { if (c == '.') res.insert(res.end(), ellipsis.begin(), ellipsis.end()); else @@ -390,7 +418,9 @@ DenseTensor PerformContraction( const LabelMap& label2type, const LabelMap& label2shape, const 
std::vector>& ellipsis_dims, - const std::vector& broadcast_dims) { + const std::vector& broadcast_dims, + std::vector cache, + bool use_cache) { // Get All the Batches, so perm is auto all_valid = LabelMap(1); auto recover_dim = GetShapeByType(all_labels, @@ -398,36 +428,77 @@ DenseTensor PerformContraction( all_valid, label2shape, broadcast_dims, - LabelType::Batch); + {LabelType::Batch}); auto preprocess = [&](const DenseTensor& t, const LabelMap& perm, - const std::vector& ellipsis) -> DenseTensor { - auto frees = GetShapeByType( - all_labels, label2type, perm, label2shape, ellipsis, LabelType::Free); + const std::vector& ellipsis, + int operand_idx) -> DenseTensor { + // reshape + auto frees = GetShapeByType(all_labels, + label2type, + perm, + label2shape, + ellipsis, + {LabelType::AO, LabelType::BO}); auto conts = GetShapeByType(all_labels, label2type, perm, label2shape, ellipsis, - LabelType::Contraction); - auto trans_t = PerformTranspose( - dev_ctx, t, perm, all_labels, ellipsis, label2type); - auto mul_dims = GetShapeByType( - all_labels, label2type, perm, label2shape, ellipsis, LabelType::Batch); + {LabelType::Contraction}); + std::vector reordered_all_labels = all_labels; + if (operand_idx == 1) { + reordered_all_labels = TransformLabelsOrder(all_labels, + label2type, + {LabelType::Batch, + LabelType::Contraction, + LabelType::AO, + LabelType::BO, + LabelType::Reduction}); + } + // reduction + DenseTensor trans_t; + if (FLAGS_einsum_opt && use_cache && cache[operand_idx] != nullptr && + cache[operand_idx]->IsInitialized()) { + trans_t.ShareBufferWith(*(cache[operand_idx])); + VLOG(5) << "Cache Used!"; + } else { + auto reduct_t = PerformReduction( + dev_ctx, t, perm, all_labels, ellipsis, label2type); + trans_t = PerformTranspose( + dev_ctx, reduct_t, perm, reordered_all_labels, ellipsis, label2type); + if (FLAGS_einsum_opt && cache[operand_idx] != nullptr) + cache[operand_idx]->ShareBufferWith(trans_t); + } + auto mul_dims = GetShapeByType(all_labels, 
+ label2type, + perm, + label2shape, + ellipsis, + {LabelType::Batch}); recover_dim.insert(recover_dim.end(), frees.begin(), frees.end()); - mul_dims.push_back( - std::accumulate(frees.begin(), frees.end(), 1, std::multiplies())); - mul_dims.push_back( - std::accumulate(conts.begin(), conts.end(), 1, std::multiplies())); + if (operand_idx == 0) { + mul_dims.push_back(std::accumulate( + frees.begin(), frees.end(), 1, std::multiplies())); + mul_dims.push_back(std::accumulate( + conts.begin(), conts.end(), 1, std::multiplies())); + } else { + mul_dims.push_back(std::accumulate( + conts.begin(), conts.end(), 1, std::multiplies())); + mul_dims.push_back(std::accumulate( + frees.begin(), frees.end(), 1, std::multiplies())); + } VLOG(5) << "PerformContraction: mul_dims: " << paddle::string::join_strings(mul_dims, ","); trans_t.Resize(make_ddim(mul_dims)); return trans_t; }; - auto trans_a = preprocess(A, label2perm[0], ellipsis_dims[0]); - auto trans_b = preprocess(B, label2perm[1], ellipsis_dims[1]); + + // Reduction, Reshape and Matmul + auto trans_a = preprocess(A, label2perm[0], ellipsis_dims[0], 0); + auto trans_b = preprocess(B, label2perm[1], ellipsis_dims[1], 1); auto after_contraction = - Matmul(dev_ctx, trans_a, trans_b, false, true); + Matmul(dev_ctx, trans_a, trans_b, false, false); VLOG(5) << "PerformContraction: recover_dim: " << paddle::string::join_strings(recover_dim, ","); after_contraction.Resize(make_ddim(recover_dim)); @@ -458,17 +529,23 @@ void TransposeToOutput(const Context& dev_ctx, axis.push_back(it - all_labels.begin() + offset); } } - if (is_no_need_transpose(axis)) return output->ShareBufferWith(to_trans); + if (is_no_need_transpose(axis)) { + output->ShareBufferWith(to_trans); + return; + } VLOG(5) << "call TransposeToOutput: with axis: " << paddle::string::join_strings(axis, ","); - return TransposeKernel(dev_ctx, to_trans, axis, output); + TransposeKernel(dev_ctx, to_trans, axis, output); } template -void EinsumKernel(const Context& 
dev_ctx, - const std::vector& inputs, - const std::string& equation, - DenseTensor* out) { +void EinsumKernelImpl(const Context& dev_ctx, + const std::vector& forward_all_labels, + const std::vector& inputs, + const std::string& equation, + DenseTensor* out, + std::vector cache, + bool is_forward = true) { ValidationCheck(equation); // collect the following informations to prepare einsum. LabelMap labelshape(0); @@ -484,6 +561,9 @@ void EinsumKernel(const Context& dev_ctx, input_dims.push_back(i->dims()); } std::string right; + if (!is_forward) { + all_labels = forward_all_labels; + } ParseEinsumEquation(equation, input_dims, &labelshape, @@ -498,22 +578,18 @@ void EinsumKernel(const Context& dev_ctx, if (inputs.size() == 2) { auto& A = inputs[0]; auto& B = inputs[1]; - // Reduce Procedure - auto reduce_A = PerformReduction( - dev_ctx, *A, label2perms[0], all_labels, ellipsis_dims[0], labeltype); - auto reduce_B = PerformReduction( - dev_ctx, *B, label2perms[1], all_labels, ellipsis_dims[1], labeltype); - // Contract Procedure - dev_ctx.template Alloc(out); + // Reduction and Contract Procedure auto after_contraction = PerformContraction(dev_ctx, - reduce_A, - reduce_B, + *A, + *B, label2perms, all_labels, labeltype, labelshape, ellipsis_dims, - broadcast_dims); + broadcast_dims, + cache, + !is_forward); TransposeToOutput(dev_ctx, after_contraction, right, @@ -545,4 +621,37 @@ void EinsumKernel(const Context& dev_ctx, } } +template +void EinsumKernelRaw(const Context& dev_ctx, + const std::vector& inputs, + const std::string& equation, + DenseTensor* out, + std::vector cache) { + std::vector tmp; + // for the sake of compatibility, we may load and run v2.3 EinsumOp. Output + // may have nullptr and the cache.size() is not equal to inputs.size(). refer + // to BuildPhiKernelContext for details. 
+ int diff = inputs.size() - cache.size(); + for (int i = 0; i < diff; ++i) { + cache.push_back(nullptr); + } + EinsumKernelImpl( + dev_ctx, tmp, inputs, equation, out, cache, /*forward=*/true); +} + +template +void EinsumKernel(const Context& dev_ctx, + const std::vector& inputs, + const std::string& equation, + DenseTensor* out) { + std::vector place_holder; + std::vector cache_tensor( + inputs.size()); // set empty; TA, TB, TdC + for (size_t i = 0; i < inputs.size(); ++i) { + cache_tensor[i] = nullptr; + } + EinsumKernelImpl( + dev_ctx, place_holder, inputs, equation, out, cache_tensor, true); +} + } // namespace phi diff --git a/paddle/phi/ops/compat/einsum_sig.cc b/paddle/phi/ops/compat/einsum_sig.cc index 0b3cc3425df453799f18c13595eff188d69480da..5e45bcf97ce0e5d79e0fe17be9c8122daa5b60bd 100644 --- a/paddle/phi/ops/compat/einsum_sig.cc +++ b/paddle/phi/ops/compat/einsum_sig.cc @@ -17,14 +17,15 @@ limitations under the License. */ namespace phi { KernelSignature EinsumOpArgumentMapping(const ArgumentMappingContext& ctx) { - return KernelSignature("einsum", {"Operands"}, {"equation"}, {"Out"}); + return KernelSignature( + "einsum", {"Operands"}, {"equation"}, {"Out", "InnerCache"}); } KernelSignature EinsumGradOpArgumentMapping(const ArgumentMappingContext& ctx) { return KernelSignature("einsum_grad", - {"Operands", {"Out@GRAD"}}, + {"Operands", "InnerCache", "Out@GRAD"}, {"equation"}, - {{"Operands@GRAD"}}); + {"Operands@GRAD"}); } } // namespace phi diff --git a/python/paddle/fluid/tests/unittests/test_einsum.py b/python/paddle/fluid/tests/unittests/test_einsum.py index 43b5ce96a390150db7e29588e4107271b240b23f..26aaf0f44f1d2ad6d1239bb6b827feb94b8864d3 100644 --- a/python/paddle/fluid/tests/unittests/test_einsum.py +++ b/python/paddle/fluid/tests/unittests/test_einsum.py @@ -18,6 +18,9 @@ import unittest import paddle from paddle.fluid import core +import os +os.environ['FLAGS_new_einsum'] = "0" + class TestErrors(unittest.TestCase): def setUp(self): diff --git 
a/python/paddle/fluid/tests/unittests/test_einsum_op.py b/python/paddle/fluid/tests/unittests/test_einsum_op.py index 565e43214ea3237a592ba4afa9b793091344fade..1a4ae54afefe242c28c4f8198605ed097aeab7e4 100644 --- a/python/paddle/fluid/tests/unittests/test_einsum_op.py +++ b/python/paddle/fluid/tests/unittests/test_einsum_op.py @@ -34,7 +34,11 @@ class TestEinsumBinary(OpTest): self.operands.append(("x" + str(idx), inp)) self.inputs = {"Operands": self.operands} self.attrs = {"equation": self.equation} - self.outputs = {'Out': out} + self.outputs = { + 'Out': out, + "InnerCache": [('cache_' + str(i), np.array([1.0])) + for i in range(len(self.operands))] + } def init_input(self): self.inputs = [] @@ -49,7 +53,7 @@ class TestEinsumBinary(OpTest): def test_check_output(self): if not self.disable: - self.check_output() + self.check_output(no_check_set=["InnerCache"]) def test_grad(self): if not self.disable: diff --git a/python/paddle/fluid/tests/unittests/test_einsum_v2.py b/python/paddle/fluid/tests/unittests/test_einsum_v2.py index 63acaf63969139324f2f7a60784707285b9cb3a3..b33a943c9f27e20047703fa56d1b6d9a0cea73f7 100644 --- a/python/paddle/fluid/tests/unittests/test_einsum_v2.py +++ b/python/paddle/fluid/tests/unittests/test_einsum_v2.py @@ -464,5 +464,37 @@ class TestNumpyTests(unittest.TestCase): self.check_output_equal(a, e) +class TestStaticGraphShape(unittest.TestCase): + def setUp(self): + paddle.enable_static() + + def tearDown(self): + paddle.disable_static() + + def test_shape(self): + A = paddle.static.data(name='x', shape=[-1]) + B = paddle.static.data(name='y', shape=[384]) + C = paddle.einsum('i,d->id', A, B) + self.assertEqual(C.shape, (-1, 384)) + + +class TestBF16(unittest.TestCase): + """ + EinsumOp support bfloat16 type, add unittest here for the correctness. 
+ """ + + def test_shape(self): + cuda_major = paddle.version.cuda().split('.')[0].strip() + if paddle.is_compiled_with_cuda() and int(cuda_major) >= 11: + """ MatmulKernel support bfloat16 only if cuda_major > 11.0. + """ + A = paddle.to_tensor(np.array([1.0, 2.0])).astype(paddle.bfloat16) + A = A.cuda() + B = paddle.to_tensor(np.array([2.0, 3.0])).astype(paddle.bfloat16) + B = B.cuda() + C = paddle.einsum('i,i->', A, B) + self.assertEqual(C.item(), 8.0) + + if __name__ == "__main__": - u + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/white_list/no_check_set_white_list.py b/python/paddle/fluid/tests/unittests/white_list/no_check_set_white_list.py index 23bbc377cae2742eb4c57c77ad01afbc91679607..ea3264ba0dbb7ac63d2b06912ef9ea0baa499595 100644 --- a/python/paddle/fluid/tests/unittests/white_list/no_check_set_white_list.py +++ b/python/paddle/fluid/tests/unittests/white_list/no_check_set_white_list.py @@ -35,4 +35,5 @@ no_check_set_white_list = [ 'eigh', 'eigvalsh', 'class_center_sample', + 'einsum', ] diff --git a/python/paddle/tensor/einsum.py b/python/paddle/tensor/einsum.py index 713a611f9f39a11146d04f1e8385991221c0e51d..49cc426a00fd998c2ed24f94fb0002e4466065af 100644 --- a/python/paddle/tensor/einsum.py +++ b/python/paddle/tensor/einsum.py @@ -798,11 +798,12 @@ def gen_einsum_op(equation, *operands): """ assert len(operands) <= 2, "Only support two operands in EinsumOp." 
if in_dygraph_mode(): - return _C_ops.final_state_einsum(operands, equation) + return _C_ops.final_state_einsum(operands, equation)[0] if _in_legacy_dygraph(): # dygraph - return _C_ops.einsum(operands, 'equation', equation) + return _C_ops.einsum(operands, len(operands), 'equation', equation)[0] + # static graph for inp in operands: check_variable_and_dtype(inp, 'dtype', ['float32', 'float64'], 'einsum') @@ -811,11 +812,16 @@ def gen_einsum_op(equation, *operands): out = helper.create_variable_for_type_inference(dtype=operands[0].dtype) attrs = dict() attrs['equation'] = equation + caches = [ + helper.create_variable_for_type_inference(dtype=operands[0].dtype) + for i in range(len(operands)) + ] helper.append_op( type='einsum', inputs={'Operands': operands}, - outputs={'Out': out}, - attrs=attrs, ) + outputs={'Out': out, + "InnerCache": caches}, + attrs=attrs) return out @@ -977,7 +983,7 @@ def einsum(equation, *operands): # [0.51476848, 0.23367381, 0.39229113]]]) """ import os - if int(os.environ.get('FLAGS_new_einsum', "0")): + if int(os.environ.get('FLAGS_new_einsum', "1")): return einsum_v2(equation, *operands) nop = len(operands) diff --git a/python/paddle/utils/code_gen/api.yaml b/python/paddle/utils/code_gen/api.yaml index c0e2c39f8189be04dfd2e33b2a2ca598cc663119..845f6b6ba2fe0fec13a0b94faf53fead2a40a270 100644 --- a/python/paddle/utils/code_gen/api.yaml +++ b/python/paddle/utils/code_gen/api.yaml @@ -547,7 +547,7 @@ - api : einsum args : (Tensor[] x, str equation) - output : Tensor + output : Tensor, Tensor[]{x.size()} infer_meta : func : EinsumInferMeta param : [x, equation] diff --git a/python/paddle/utils/code_gen/api_base.py b/python/paddle/utils/code_gen/api_base.py index de4e72bbe410b5a97ac41fad64aa0ede79383085..b618e6994db0619c0d20cde8c65c79a457a0b2cf 100644 --- a/python/paddle/utils/code_gen/api_base.py +++ b/python/paddle/utils/code_gen/api_base.py @@ -205,17 +205,19 @@ class BaseAPI(object): if len(temp_list) == 1: out_type, out_name, size_expr = 
parse_output_item(temp_list[0]) - return [out_type], [out_name], size_expr, self.get_return_type( + return [out_type], [out_name], [size_expr], self.get_return_type( [out_type]) else: out_type_list = [] out_name_list = [] + out_size_expr_list = [] for output_item in temp_list: out_type, out_name, size_expr = parse_output_item(output_item) out_type_list.append(out_type) out_name_list.append(out_name) + out_size_expr_list.append(size_expr) - return out_type_list, out_name_list, size_expr, self.get_return_type( + return out_type_list, out_name_list, out_size_expr_list, self.get_return_type( out_type_list) def parse_infer_meta(self, infer_meta_config): diff --git a/python/paddle/utils/code_gen/api_gen.py b/python/paddle/utils/code_gen/api_gen.py index 538958c2361bc74b466af6c96b4bddcdcf6e9001..08c9edb57647ef16ac85287b91e893338dc6b521 100644 --- a/python/paddle/utils/code_gen/api_gen.py +++ b/python/paddle/utils/code_gen/api_gen.py @@ -94,10 +94,10 @@ class ForwardAPI(BaseAPI): {code_indent} {self.outputs['return_type']} api_output{inplace_assign};""" if self.outputs['return_type'] == 'std::vector': - assert self.outputs['out_size_expr'] is not None, \ + assert self.outputs['out_size_expr'][0] is not None, \ f"{api_name}: The out size expr : '{{expr}}' should be set when output has Tensor[]. You can refer 'split' api." 
output_create = output_create + f""" -{code_indent} auto kernel_out = {set_out_func}({self.outputs['out_size_expr']}, kernel_backend, &api_output);""" +{code_indent} auto kernel_out = {set_out_func}({self.outputs['out_size_expr'][0]}, kernel_backend, &api_output);""" else: output_create = output_create + f""" diff --git a/python/paddle/utils/code_gen/backward.yaml b/python/paddle/utils/code_gen/backward.yaml index f6e5f61e00a63f6c2067f83dff46999795dc29be..e16c57d4d898cec376c4bf274f6d3c80271a006a 100644 --- a/python/paddle/utils/code_gen/backward.yaml +++ b/python/paddle/utils/code_gen/backward.yaml @@ -447,8 +447,8 @@ skip_transform : out_w, out_w_grad - backward_api : einsum_grad - forward : einsum (Tensor[] x, str equation) -> Tensor(out) - args : (Tensor[] x, Tensor out_grad, str equation) + forward : einsum (Tensor[] x, str equation) -> Tensor(out), Tensor[](inner_cache) + args : (Tensor[] x, Tensor[] inner_cache, Tensor out_grad, str equation) output : Tensor[](x_grad){x.size()} infer_meta : func : UnchangedMultiInferMeta diff --git a/python/paddle/utils/code_gen/backward_api_gen.py b/python/paddle/utils/code_gen/backward_api_gen.py index a88339c607c555d35e953b051f153e0796c2f913..d3ed928592d2c0bbd95375ef11f8876a518a3b96 100644 --- a/python/paddle/utils/code_gen/backward_api_gen.py +++ b/python/paddle/utils/code_gen/backward_api_gen.py @@ -113,10 +113,10 @@ class BackwardAPI(BaseAPI): {code_indent} {self.outputs['return_type']} api_output{inplace_assign};""" if output_type_list[0] == 'std::vector': - assert self.outputs['out_size_expr'] is not None, \ + assert self.outputs['out_size_expr'][0] is not None, \ f"{api_name}: The out size expr : '{{expr}}' should be set when output has Tensor[]. You can refer 'split' api." 
output_create = output_create + f""" -{code_indent} auto kernel_out = {set_out_func}({self.outputs['out_size_expr']}, kernel_backend, &api_output);""" +{code_indent} auto kernel_out = {set_out_func}({self.outputs['out_size_expr'][0]}, kernel_backend, &api_output);""" else: output_create = output_create + f"""