未验证 提交 22e75d92 编写于 作者: X xiongkun 提交者: GitHub

[ CherryPick ] Cherry pick for einsum optimization. (#43468)

* [EinsumOp] Polish forward logic and backward logic for optimize (#42603)

* change logic for optimize

* modify

* merge

* change einsum_v2 as default and add new flags: FLAG_einsum_opt=1|0 (#43010)

* [EinsumOp] Make EinsumOp support bfloat16. (#43085)

* change einsum_v2 as default and add new flags: FLAG_einsum_opt=1|0

* make EinsumOp support bf16

* add unittest for BF16

* add condition for test_BF16

* fix bugs

* fix

* change the backward api to fit einsum op
上级 afd0c1db
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/eager/nan_inf_utils.h"
#include "paddle/fluid/framework/details/nan_inf_utils_detail.h"
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/selected_rows.h"
namespace egr {
void CheckTensorHasNanOrInf(const std::string& api_name, const Tensor& tensor) {
if (tensor.initialized()) {
auto& tensor_name = tensor.name();
const phi::DenseTensor* dense_tensor{nullptr};
if (tensor.is_dense_tensor()) {
dense_tensor = static_cast<const phi::DenseTensor*>(tensor.impl().get());
} else if (tensor.is_selected_rows()) {
dense_tensor = &(
static_cast<const phi::SelectedRows*>(tensor.impl().get())->value());
} else {
VLOG(10) << "Only DenseTensor or SelectedRows need to check, "
<< tensor_name << " is no need.";
return;
}
auto& place = dense_tensor->place();
if (paddle::platform::is_gpu_place(place)) {
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
paddle::framework::details::tensor_check<
paddle::platform::CUDADeviceContext>(api_name, tensor_name,
*dense_tensor, place);
#else
PADDLE_THROW(paddle::platform::errors::PreconditionNotMet(
"Tensor[%s] use gpu place. PaddlePaddle must compile with GPU.",
tensor_name));
#endif
return;
}
paddle::framework::details::tensor_check<
paddle::platform::CPUDeviceContext>(api_name, tensor_name,
*dense_tensor, place);
}
}
void CheckTensorHasNanOrInf(const std::string& api_name,
const TupleOfTwoTensors& tensors) {
CheckTensorHasNanOrInf(api_name, std::get<0>(tensors));
CheckTensorHasNanOrInf(api_name, std::get<1>(tensors));
}
void CheckTensorHasNanOrInf(const std::string& api_name,
const TupleOfThreeTensors& tensors) {
CheckTensorHasNanOrInf(api_name, std::get<0>(tensors));
CheckTensorHasNanOrInf(api_name, std::get<1>(tensors));
CheckTensorHasNanOrInf(api_name, std::get<2>(tensors));
}
void CheckTensorHasNanOrInf(const std::string& api_name,
const TupleOfFourTensors& tensors) {
CheckTensorHasNanOrInf(api_name, std::get<0>(tensors));
CheckTensorHasNanOrInf(api_name, std::get<1>(tensors));
CheckTensorHasNanOrInf(api_name, std::get<2>(tensors));
CheckTensorHasNanOrInf(api_name, std::get<3>(tensors));
}
void CheckTensorHasNanOrInf(const std::string& api_name,
const TupleOfFiveTensors& tensors) {
CheckTensorHasNanOrInf(api_name, std::get<0>(tensors));
CheckTensorHasNanOrInf(api_name, std::get<1>(tensors));
CheckTensorHasNanOrInf(api_name, std::get<2>(tensors));
CheckTensorHasNanOrInf(api_name, std::get<3>(tensors));
CheckTensorHasNanOrInf(api_name, std::get<4>(tensors));
}
void CheckTensorHasNanOrInf(const std::string& api_name,
const TupleOfSixTensors& tensors) {
CheckTensorHasNanOrInf(api_name, std::get<0>(tensors));
CheckTensorHasNanOrInf(api_name, std::get<1>(tensors));
CheckTensorHasNanOrInf(api_name, std::get<2>(tensors));
CheckTensorHasNanOrInf(api_name, std::get<3>(tensors));
CheckTensorHasNanOrInf(api_name, std::get<4>(tensors));
CheckTensorHasNanOrInf(api_name, std::get<5>(tensors));
}
void CheckTensorHasNanOrInf(const std::string& api_name,
const std::vector<Tensor>& tensors) {
for (auto& tensor : tensors) {
CheckTensorHasNanOrInf(api_name, tensor);
}
}
// Checks every tensor held in a small_vector of tensor vectors: each inner
// vector is handed, in order, to the std::vector<Tensor> overload.
void CheckTensorHasNanOrInf(
    const std::string& api_name,
    const paddle::small_vector<std::vector<paddle::experimental::Tensor>,
                               egr::kSlotSmallVectorSize>& tensors) {
  for (auto slot = tensors.begin(); slot != tensors.end(); ++slot) {
    CheckTensorHasNanOrInf(api_name, *slot);
  }
}
void CheckTensorHasNanOrInf(const std::string& api_name,
const TupleOfTensorAndVector& tensors) {
CheckTensorHasNanOrInf(api_name, std::get<0>(tensors));
CheckTensorHasNanOrInf(api_name, std::get<1>(tensors));
}
} // namespace egr
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <string>
#include <tuple>
#include <vector>
#include "paddle/fluid/eager/type_defs.h"
#include "paddle/phi/api/include/tensor.h"
#include "paddle/utils/small_vector.h"
namespace egr {
// Short name for the eager-mode tensor type used throughout this header.
using paddle::experimental::Tensor;
// Tuple aliases for APIs that return a fixed number of tensors; each alias has
// a matching CheckTensorHasNanOrInf overload below.
using TupleOfTwoTensors = std::tuple<Tensor, Tensor>;
using TupleOfThreeTensors = std::tuple<Tensor, Tensor, Tensor>;
using TupleOfFourTensors = std::tuple<Tensor, Tensor, Tensor, Tensor>;
using TupleOfFiveTensors = std::tuple<Tensor, Tensor, Tensor, Tensor, Tensor>;
using TupleOfSixTensors =
std::tuple<Tensor, Tensor, Tensor, Tensor, Tensor, Tensor>;
// Tuple of one tensor followed by a list of tensors.
using TupleOfTensorAndVector = std::tuple<Tensor, std::vector<Tensor>>;
// Runs the NaN/Inf check on a single tensor. `api_name` is passed through to
// the underlying check together with the tensor's own name.
void CheckTensorHasNanOrInf(const std::string& api_name, const Tensor& tensor);
// Tuple overloads: check each element of the tuple in order.
void CheckTensorHasNanOrInf(const std::string& api_name,
const TupleOfTwoTensors& tensors);
void CheckTensorHasNanOrInf(const std::string& api_name,
const TupleOfThreeTensors& tensors);
void CheckTensorHasNanOrInf(const std::string& api_name,
const TupleOfFourTensors& tensors);
void CheckTensorHasNanOrInf(const std::string& api_name,
const TupleOfFiveTensors& tensors);
void CheckTensorHasNanOrInf(const std::string& api_name,
const TupleOfSixTensors& tensors);
// Checks every tensor in the vector, in order.
void CheckTensorHasNanOrInf(const std::string& api_name,
const std::vector<Tensor>& tensors);
// Checks the single tensor first, then every tensor in the vector.
void CheckTensorHasNanOrInf(const std::string& api_name,
const TupleOfTensorAndVector& tensors);
// Checks every tensor of every inner vector held in the small_vector.
void CheckTensorHasNanOrInf(
const std::string& api_name,
const paddle::small_vector<std::vector<paddle::experimental::Tensor>,
egr::kSlotSmallVectorSize>& tensors);
} // namespace egr
...@@ -33,6 +33,13 @@ class EinsumOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -33,6 +33,13 @@ class EinsumOpMaker : public framework::OpProtoAndCheckerMaker {
AddInput("Operands", "(TensorList), The input tensor of einsum op.") AddInput("Operands", "(TensorList), The input tensor of einsum op.")
.AsDuplicable(); .AsDuplicable();
AddOutput("Out", "(Tensor), The output tensor of einsum op."); AddOutput("Out", "(Tensor), The output tensor of einsum op.");
AddOutput(
"InnerCache",
"(Tensor), The cache of the forward transpose tensors: tA and tB.")
.AsDuplicable()
.AsExtra()
.AsIntermediate();
AddAttr<std::string>("equation", AddAttr<std::string>("equation",
"(string) A einsum equation. such as `ij,jk->ik`" "(string) A einsum equation. such as `ij,jk->ik`"
"There must have `->` and the number of operands in " "There must have `->` and the number of operands in "
...@@ -72,6 +79,7 @@ class EinsumGradMaker : public framework::SingleGradOpMaker<T> { ...@@ -72,6 +79,7 @@ class EinsumGradMaker : public framework::SingleGradOpMaker<T> {
void Apply(GradOpPtr<T> retv) const override { void Apply(GradOpPtr<T> retv) const override {
retv->SetType("einsum_grad"); retv->SetType("einsum_grad");
retv->SetInput("Operands", this->Input("Operands")); retv->SetInput("Operands", this->Input("Operands"));
retv->SetInput("InnerCache", this->Output("InnerCache"));
retv->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out")); retv->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
retv->SetAttrMap(this->Attrs()); retv->SetAttrMap(this->Attrs());
retv->SetOutput(framework::GradVarName("Operands"), retv->SetOutput(framework::GradVarName("Operands"),
......
...@@ -808,3 +808,16 @@ PADDLE_DEFINE_EXPORTED_bool(nccl_blocking_wait, false, "nccl blocking wait"); ...@@ -808,3 +808,16 @@ PADDLE_DEFINE_EXPORTED_bool(nccl_blocking_wait, false, "nccl blocking wait");
* Example: * Example:
*/ */
PADDLE_DEFINE_EXPORTED_bool(use_autotune, false, "Whether enable autotune."); PADDLE_DEFINE_EXPORTED_bool(use_autotune, false, "Whether enable autotune.");
/**
* Performance related FLAG
* Name: einsum_opt
* Since Version: 2.3.0
* Value Range: bool, default=false
* Example:
* Note: If True, EinsumOp will be optimized by innercache reuse, which
* uses more gpu memory.
*/
PADDLE_DEFINE_EXPORTED_bool(
einsum_opt, false,
"EinsumOp backward will be speedup at the expense of more gpu memory.");
...@@ -401,7 +401,8 @@ void EighInferMeta(const MetaTensor& x, ...@@ -401,7 +401,8 @@ void EighInferMeta(const MetaTensor& x,
void EinsumInferMeta(const std::vector<const MetaTensor*>& inputs, void EinsumInferMeta(const std::vector<const MetaTensor*>& inputs,
const std::string& equation, const std::string& equation,
MetaTensor* out) { MetaTensor* out,
std::vector<MetaTensor*> inner_cache) {
// collect the following informations to prepare einsum. // collect the following informations to prepare einsum.
LabelMap labelshape(0); LabelMap labelshape(0);
LabelMap labeltype(LabelType::Reduction); LabelMap labeltype(LabelType::Reduction);
......
...@@ -82,7 +82,8 @@ void EighInferMeta(const MetaTensor& x, ...@@ -82,7 +82,8 @@ void EighInferMeta(const MetaTensor& x,
void EinsumInferMeta(const std::vector<const MetaTensor*>& inputs, void EinsumInferMeta(const std::vector<const MetaTensor*>& inputs,
const std::string& equation, const std::string& equation,
MetaTensor* out); MetaTensor* out,
std::vector<MetaTensor*> inner_cache);
void ExpandInferMeta(const MetaTensor& x, void ExpandInferMeta(const MetaTensor& x,
const IntArray& shape, const IntArray& shape,
......
...@@ -17,4 +17,5 @@ ...@@ -17,4 +17,5 @@
#include "paddle/phi/core/kernel_registry.h" #include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/impl/einsum_impl.h" #include "paddle/phi/kernels/impl/einsum_impl.h"
PD_REGISTER_KERNEL(einsum, CPU, ALL_LAYOUT, phi::EinsumKernel, float, double) {} PD_REGISTER_KERNEL(
einsum, CPU, ALL_LAYOUT, phi::EinsumKernelRaw, float, double) {}
...@@ -21,6 +21,7 @@ namespace phi { ...@@ -21,6 +21,7 @@ namespace phi {
template <typename T, typename Context> template <typename T, typename Context>
void EinsumGradKernel(const Context& dev_ctx, void EinsumGradKernel(const Context& dev_ctx,
const std::vector<const DenseTensor*>& x, const std::vector<const DenseTensor*>& x,
const std::vector<const DenseTensor*>& inner_cache,
const DenseTensor& out_grad, const DenseTensor& out_grad,
const std::string& equation, const std::string& equation,
std::vector<DenseTensor*> x_grad); std::vector<DenseTensor*> x_grad);
......
...@@ -24,4 +24,11 @@ void EinsumKernel(const Context& dev_ctx, ...@@ -24,4 +24,11 @@ void EinsumKernel(const Context& dev_ctx,
const std::string& equation, const std::string& equation,
DenseTensor* out); DenseTensor* out);
template <typename T, typename Context>
void EinsumKernelRaw(const Context& dev_ctx,
const std::vector<const DenseTensor*>& inputs,
const std::string& equation,
DenseTensor* out,
std::vector<DenseTensor*> cache);
} // namespace phi } // namespace phi
...@@ -11,6 +11,7 @@ distributed under the License is distributed on an "AS IS" BASIS, ...@@ -11,6 +11,7 @@ distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include "paddle/phi/common/bfloat16.h"
#include "paddle/phi/common/float16.h" #include "paddle/phi/common/float16.h"
#include "paddle/phi/kernels/funcs/eigen/eigen_function.h" #include "paddle/phi/kernels/funcs/eigen/eigen_function.h"
...@@ -73,6 +74,7 @@ struct EigenBroadcastGrad<Eigen::DefaultDevice, T, Rank> { ...@@ -73,6 +74,7 @@ struct EigenBroadcastGrad<Eigen::DefaultDevice, T, Rank> {
template struct FUNCTOR<Eigen::DefaultDevice, T, 6> template struct FUNCTOR<Eigen::DefaultDevice, T, 6>
INSTANTIATION(EigenBroadcast, bool); INSTANTIATION(EigenBroadcast, bool);
INSTANTIATION(EigenBroadcast, dtype::float16); INSTANTIATION(EigenBroadcast, dtype::float16);
INSTANTIATION(EigenBroadcast, dtype::bfloat16);
INSTANTIATION(EigenBroadcast, float); INSTANTIATION(EigenBroadcast, float);
INSTANTIATION(EigenBroadcast, double); INSTANTIATION(EigenBroadcast, double);
INSTANTIATION(EigenBroadcast, int); INSTANTIATION(EigenBroadcast, int);
......
...@@ -11,6 +11,7 @@ distributed under the License is distributed on an "AS IS" BASIS, ...@@ -11,6 +11,7 @@ distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include "paddle/phi/common/bfloat16.h"
#include "paddle/phi/common/float16.h" #include "paddle/phi/common/float16.h"
#include "paddle/phi/kernels/funcs/eigen/eigen_function.h" #include "paddle/phi/kernels/funcs/eigen/eigen_function.h"
...@@ -73,6 +74,7 @@ struct EigenBroadcastGrad<Eigen::GpuDevice, T, Rank> { ...@@ -73,6 +74,7 @@ struct EigenBroadcastGrad<Eigen::GpuDevice, T, Rank> {
template struct FUNCTOR<Eigen::GpuDevice, T, 6> template struct FUNCTOR<Eigen::GpuDevice, T, 6>
INSTANTIATION(EigenBroadcast, bool); INSTANTIATION(EigenBroadcast, bool);
INSTANTIATION(EigenBroadcast, dtype::float16); INSTANTIATION(EigenBroadcast, dtype::float16);
INSTANTIATION(EigenBroadcast, dtype::bfloat16);
INSTANTIATION(EigenBroadcast, float); INSTANTIATION(EigenBroadcast, float);
INSTANTIATION(EigenBroadcast, double); INSTANTIATION(EigenBroadcast, double);
INSTANTIATION(EigenBroadcast, int); INSTANTIATION(EigenBroadcast, int);
......
...@@ -18,5 +18,11 @@ ...@@ -18,5 +18,11 @@
#include "paddle/phi/core/kernel_registry.h" #include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/impl/einsum_grad_impl.h" #include "paddle/phi/kernels/impl/einsum_grad_impl.h"
PD_REGISTER_KERNEL( PD_REGISTER_KERNEL(einsum_grad,
einsum_grad, GPU, ALL_LAYOUT, phi::EinsumGradKernel, float, double) {} GPU,
ALL_LAYOUT,
phi::EinsumGradKernel,
float,
double,
phi::dtype::float16,
phi::dtype::bfloat16) {}
...@@ -18,4 +18,11 @@ ...@@ -18,4 +18,11 @@
#include "paddle/phi/core/kernel_registry.h" #include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/impl/einsum_impl.h" #include "paddle/phi/kernels/impl/einsum_impl.h"
PD_REGISTER_KERNEL(einsum, GPU, ALL_LAYOUT, phi::EinsumKernel, float, double) {} PD_REGISTER_KERNEL(einsum,
GPU,
ALL_LAYOUT,
phi::EinsumKernelRaw,
float,
double,
phi::dtype::float16,
phi::dtype::bfloat16) {}
...@@ -27,4 +27,5 @@ PD_REGISTER_KERNEL(tile, ...@@ -27,4 +27,5 @@ PD_REGISTER_KERNEL(tile,
double, double,
int, int,
int64_t, int64_t,
phi::dtype::float16) {} phi::dtype::float16,
phi::dtype::bfloat16) {}
...@@ -13,6 +13,7 @@ ...@@ -13,6 +13,7 @@
// limitations under the License. // limitations under the License.
#pragma once #pragma once
#include "paddle/fluid/platform/profiler.h"
#include "paddle/phi/core/dense_tensor.h" #include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/kernels/impl/einsum_impl.h" #include "paddle/phi/kernels/impl/einsum_impl.h"
#include "paddle/phi/kernels/tile_kernel.h" #include "paddle/phi/kernels/tile_kernel.h"
...@@ -55,7 +56,13 @@ DenseTensor PerformTileAndReduction(const Context& dev_ctx, ...@@ -55,7 +56,13 @@ DenseTensor PerformTileAndReduction(const Context& dev_ctx,
} }
t.Resize(make_ddim(resize_dims)); t.Resize(make_ddim(resize_dims));
DenseTensor after_tile; DenseTensor after_tile;
if (std::all_of(repeat_times.begin(), repeat_times.end(), [](int x) {
return x == 1;
})) {
after_tile = t;
} else {
TileKernel<T, Context>(dev_ctx, t, repeat_times, &after_tile); TileKernel<T, Context>(dev_ctx, t, repeat_times, &after_tile);
}
size_t n_ellipsis_idx = op_label.find(".", 0); size_t n_ellipsis_idx = op_label.find(".", 0);
if (n_ellipsis_idx != std::string::npos) { if (n_ellipsis_idx != std::string::npos) {
// may be we need reduce. broadcast_dims is not equal to ellipsis dims. // may be we need reduce. broadcast_dims is not equal to ellipsis dims.
...@@ -91,10 +98,11 @@ DenseTensor PerformTileAndReduction(const Context& dev_ctx, ...@@ -91,10 +98,11 @@ DenseTensor PerformTileAndReduction(const Context& dev_ctx,
template <typename T, typename Context> template <typename T, typename Context>
void EinsumGradKernel(const Context& dev_ctx, void EinsumGradKernel(const Context& dev_ctx,
const std::vector<const DenseTensor*>& x, const std::vector<const DenseTensor*>& x,
const std::vector<const DenseTensor*>& inner_cache,
const DenseTensor& out_grad, const DenseTensor& out_grad,
const std::string& equation, const std::string& equation,
std::vector<DenseTensor*> x_grad) { std::vector<DenseTensor*> x_grad) {
VLOG(5) << "Start EisumGradKernel:"; VLOG(5) << "Start EinsumGradKernel:";
LabelMap labelshape(0); LabelMap labelshape(0);
LabelMap labeltype(LabelType::Reduction); LabelMap labeltype(LabelType::Reduction);
std::vector<LabelMap> label2perms(x.size(), LabelMap(-1)); std::vector<LabelMap> label2perms(x.size(), LabelMap(-1));
...@@ -148,20 +156,48 @@ void EinsumGradKernel(const Context& dev_ctx, ...@@ -148,20 +156,48 @@ void EinsumGradKernel(const Context& dev_ctx,
right = splits[1].substr(1); right = splits[1].substr(1);
auto equation_for_A = auto equation_for_A =
right + "," + ops[1] + "->" + gather_labels_except_reduction(ops[0]); ops[1] + "," + right + "->" + gather_labels_except_reduction(ops[0]);
auto equation_for_B = auto equation_for_B =
right + "," + ops[0] + "->" + gather_labels_except_reduction(ops[1]); right + "," + ops[0] + "->" + gather_labels_except_reduction(ops[1]);
auto operands_for_A = std::vector<const DenseTensor*>(); auto operands_for_A = std::vector<const DenseTensor*>();
auto operands_for_B = std::vector<const DenseTensor*>(); auto operands_for_B = std::vector<const DenseTensor*>();
DenseTensor dA, dB; DenseTensor dA, dB;
operands_for_A.push_back(&out_grad); // dA = einsum(B, dC)
operands_for_A.push_back(x[1]); operands_for_A.push_back(x[1]);
operands_for_A.push_back(&out_grad);
// dB = einsum(dC, A)
operands_for_B.push_back(&out_grad); operands_for_B.push_back(&out_grad);
operands_for_B.push_back(x[0]); operands_for_B.push_back(x[0]);
DenseTensor before_tile; DenseTensor before_tile;
EinsumKernel<T, Context>(dev_ctx, operands_for_A, equation_for_A, &dA);
EinsumKernel<T, Context>(dev_ctx, operands_for_B, equation_for_B, &dB); std::vector<DenseTensor> cache(3); // set empty; TA, TB, TdC
if (inner_cache.size() >
0) { // for compatibility, we can load and run v2.3 EinsumOp.
cache[0].ShareBufferWith(*(inner_cache[0]));
cache[1].ShareBufferWith(*(inner_cache[1]));
}
EinsumKernelImpl<T, Context>(dev_ctx,
all_labels,
operands_for_A,
equation_for_A,
&dA,
{&cache[1], &cache[2]},
false);
EinsumKernelImpl<T, Context>(dev_ctx,
all_labels,
operands_for_B,
equation_for_B,
&dB,
{&cache[2], &cache[0]},
false);
// release the cache tensor dTC to save memory right now. they are useless
// now.
cache.clear();
if (x_grad[0]) {
*(x_grad[0]) = PerformTileAndReduction<T, Context>(dev_ctx, *(x_grad[0]) = PerformTileAndReduction<T, Context>(dev_ctx,
labeltype, labeltype,
labelshape, labelshape,
...@@ -169,6 +205,8 @@ void EinsumGradKernel(const Context& dev_ctx, ...@@ -169,6 +205,8 @@ void EinsumGradKernel(const Context& dev_ctx,
ellipsis_dims[0], ellipsis_dims[0],
ops[0], ops[0],
dA); dA);
}
if (x_grad[1]) {
*(x_grad[1]) = PerformTileAndReduction<T, Context>(dev_ctx, *(x_grad[1]) = PerformTileAndReduction<T, Context>(dev_ctx,
labeltype, labeltype,
labelshape, labelshape,
...@@ -177,5 +215,6 @@ void EinsumGradKernel(const Context& dev_ctx, ...@@ -177,5 +215,6 @@ void EinsumGradKernel(const Context& dev_ctx,
ops[1], ops[1],
dB); dB);
} }
}
} }
} // namespace phi } // namespace phi
...@@ -13,12 +13,15 @@ ...@@ -13,12 +13,15 @@
// limitations under the License. // limitations under the License.
#pragma once #pragma once
#include <set>
#include "paddle/phi/core/dense_tensor.h" #include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/kernels/matmul_kernel.h" #include "paddle/phi/kernels/matmul_kernel.h"
#include "paddle/phi/kernels/reduce_sum_kernel.h" #include "paddle/phi/kernels/reduce_sum_kernel.h"
#include "paddle/phi/kernels/transpose_kernel.h" #include "paddle/phi/kernels/transpose_kernel.h"
#include "paddle/utils/string/string_helper.h" #include "paddle/utils/string/string_helper.h"
DECLARE_bool(einsum_opt);
namespace phi { namespace phi {
// check the validation of the Einsum equation. // check the validation of the Einsum equation.
...@@ -55,7 +58,8 @@ inline static void ValidationCheck(const std::string& equation) { ...@@ -55,7 +58,8 @@ inline static void ValidationCheck(const std::string& equation) {
enum LabelType { enum LabelType {
ALL_TYPE = 0, ALL_TYPE = 0,
Batch = 1, // ABO Batch = 1, // ABO
Free, // AO, BO AO, // AO -- free label
BO, // BO -- free label
Contraction, // AB Contraction, // AB
Reduction, // A, B Reduction, // A, B
}; };
...@@ -125,18 +129,31 @@ inline std::vector<char> union_labels(const std::vector<char>& a, ...@@ -125,18 +129,31 @@ inline std::vector<char> union_labels(const std::vector<char>& a,
return res; return res;
} }
// Apply transforms to all_labels and get another all_labels
inline std::vector<char> TransformLabelsOrder(
const std::vector<char>& all_labels,
const LabelMap& type,
std::vector<LabelType> new_order) {
std::vector<char> ret;
for (auto cnt_type : new_order) {
std::vector<char> tmp;
for (int c : all_labels) {
if (type[c] == cnt_type) tmp.push_back(c);
}
ret.insert(ret.end(), tmp.begin(), tmp.end());
}
return ret;
}
inline static void GlobalInfo(const std::vector<std::string>& op_labels, inline static void GlobalInfo(const std::vector<std::string>& op_labels,
const std::string& right, const std::string& right,
LabelMap* label2type, LabelMap* label2type,
std::vector<char>* sorted_labels) { std::vector<char>* sorted_labels) {
// sorted_labels: ['.', <right>, <left only label>]
VLOG(5) << "GlobalInfo: "
<< paddle::string::join_strings(*sorted_labels, ",");
std::vector<char> all; std::vector<char> all;
LabelMap counter(0); LabelMap counter(0);
for (auto& ch : right) { // char for (auto& ch : right) { // char
int c = ch; int c = ch;
(*label2type)[c] = LabelType::Free; (*label2type)[c] = LabelType::BO;
} }
for (auto& op : op_labels) { for (auto& op : op_labels) {
...@@ -146,39 +163,45 @@ inline static void GlobalInfo(const std::vector<std::string>& op_labels, ...@@ -146,39 +163,45 @@ inline static void GlobalInfo(const std::vector<std::string>& op_labels,
all.push_back(ch); all.push_back(ch);
} }
counter[c] += 1; counter[c] += 1;
if ((*label2type)[c] != LabelType::Free && counter[c] == 2) if ((*label2type)[c] != LabelType::BO && counter[c] == 2)
(*label2type)[c] = LabelType::Contraction; (*label2type)[c] = LabelType::Contraction;
else if (counter[c] == 2) else if (counter[c] == 2)
(*label2type)[c] = LabelType::Batch; (*label2type)[c] = LabelType::Batch;
} }
} }
// Every free label was provisionally tagged BO above; the ones that appear in
// the first operand are actually AO.
for (int c : op_labels[0]) {
if ((*label2type)[c] == LabelType::BO) (*label2type)[c] = LabelType::AO;
}
(*label2type)['.'] = LabelType::Batch; (*label2type)['.'] = LabelType::Batch;
std::for_each(all.begin(), all.end(), [sorted_labels, label2type](int c) {
if ((*label2type)[c] == LabelType::Batch) if (sorted_labels->size()) {
sorted_labels->push_back(static_cast<char>(c)); std::set<char> exist(all.begin(), all.end());
}); all.clear();
std::for_each(all.begin(), all.end(), [sorted_labels, label2type](int c) { std::for_each(
if ((*label2type)[c] == LabelType::Free) sorted_labels->begin(), sorted_labels->end(), [&exist, &all](char c) {
sorted_labels->push_back(static_cast<char>(c)); if (exist.count(c)) all.push_back(c);
});
std::for_each(all.begin(), all.end(), [sorted_labels, label2type](int c) {
if ((*label2type)[c] == LabelType::Contraction)
sorted_labels->push_back(static_cast<char>(c));
});
std::for_each(all.begin(), all.end(), [&sorted_labels, label2type](int c) {
if ((*label2type)[c] == LabelType::Reduction)
sorted_labels->push_back(static_cast<char>(c));
}); });
VLOG(5) << "GlobalInfo: sorted_labels before: " }
<< paddle::string::join_strings(*sorted_labels, ",");
*sorted_labels = TransformLabelsOrder(all,
*label2type,
{LabelType::Batch,
LabelType::AO,
LabelType::BO,
LabelType::Contraction,
LabelType::Reduction});
if (counter[static_cast<int>('.')] > 0) { if (counter[static_cast<int>('.')] > 0) {
std::vector<char> tmp; std::vector<char> tmp;
tmp.push_back('.'); tmp.push_back('.');
// push '.' in the front // push '.' in the front
*sorted_labels = union_labels(tmp, *sorted_labels); *sorted_labels = union_labels(tmp, *sorted_labels);
}
VLOG(5) << "GlobalInfo: sorted_labels after: " VLOG(5) << "GlobalInfo: sorted_labels after: "
<< paddle::string::join_strings(*sorted_labels, ","); << paddle::string::join_strings(*sorted_labels, ",");
}
} }
inline static void InferLabelShape(const std::vector<std::string>& op_labels, inline static void InferLabelShape(const std::vector<std::string>& op_labels,
...@@ -289,17 +312,20 @@ inline static void ParseEinsumEquation( ...@@ -289,17 +312,20 @@ inline static void ParseEinsumEquation(
*right = results[1].substr(1); *right = results[1].substr(1);
ReplaceEllipsis(*right); ReplaceEllipsis(*right);
auto op_labels = paddle::string::split_string(left, ","); auto op_labels = paddle::string::split_string(left, ",");
// split_string("i,") -> ["i"], we expect 2 op_labels.
if (left[left.size() - 1] == ',') op_labels.push_back("");
std::for_each(op_labels.begin(), op_labels.end(), ReplaceEllipsis); std::for_each(op_labels.begin(), op_labels.end(), ReplaceEllipsis);
GlobalInfo(op_labels, *right, labeltype, all_labels); GlobalInfo(op_labels, *right, labeltype, all_labels);
InferLabelShape(op_labels, inputs, labelshape, ellipsis_dims, broadcast_dims); InferLabelShape(op_labels, inputs, labelshape, ellipsis_dims, broadcast_dims);
VLOG(5) << "Einsum Infershape: right:" << right; VLOG(5) << "Einsum Infershape: right:" << *right;
VLOG(5) << "Einsum Infershape: op_labels:" VLOG(5) << "Einsum Infershape: left :"
<< paddle::string::join_strings(op_labels, "\n"); << paddle::string::join_strings(op_labels, '\n');
InferOutputDims(*right, *broadcast_dims, *labelshape, output_dims); InferOutputDims(*right, *broadcast_dims, *labelshape, output_dims);
for (size_t i = 0; i < inputs.size(); ++i) { for (size_t i = 0; i < inputs.size(); ++i) {
InferLabelPerm( InferLabelPerm(
op_labels[i], ellipsis_dims->at(i).size(), &((*label2perms)[i])); op_labels[i], ellipsis_dims->at(i).size(), &((*label2perms)[i]));
} }
VLOG(5) << "Einsum Infershape: end";
} }
template <typename T> template <typename T>
...@@ -327,10 +353,12 @@ std::vector<T> GetShapeByType(const std::vector<char>& all_labels, ...@@ -327,10 +353,12 @@ std::vector<T> GetShapeByType(const std::vector<char>& all_labels,
const LabelMap& perm, const LabelMap& perm,
const LabelMap& label2shape, const LabelMap& label2shape,
const std::vector<int>& ellipsis, const std::vector<int>& ellipsis,
LabelType filter) { std::set<LabelType> filter) {
std::vector<T> res; std::vector<T> res;
for (T c : all_labels) { for (T c : all_labels) {
if ((filter == LabelType::ALL_TYPE || type[c] == filter) && perm[c] != -1) { if ((filter.count(LabelType::ALL_TYPE) ||
filter.count(LabelType(type[c]))) &&
perm[c] != -1) {
if (c == '.') if (c == '.')
res.insert(res.end(), ellipsis.begin(), ellipsis.end()); res.insert(res.end(), ellipsis.begin(), ellipsis.end());
else else
...@@ -390,7 +418,9 @@ DenseTensor PerformContraction( ...@@ -390,7 +418,9 @@ DenseTensor PerformContraction(
const LabelMap& label2type, const LabelMap& label2type,
const LabelMap& label2shape, const LabelMap& label2shape,
const std::vector<std::vector<int>>& ellipsis_dims, const std::vector<std::vector<int>>& ellipsis_dims,
const std::vector<int>& broadcast_dims) { const std::vector<int>& broadcast_dims,
std::vector<DenseTensor*> cache,
bool use_cache) {
// Get All the Batches, so perm is // Get All the Batches, so perm is
auto all_valid = LabelMap(1); auto all_valid = LabelMap(1);
auto recover_dim = GetShapeByType<int>(all_labels, auto recover_dim = GetShapeByType<int>(all_labels,
...@@ -398,36 +428,77 @@ DenseTensor PerformContraction( ...@@ -398,36 +428,77 @@ DenseTensor PerformContraction(
all_valid, all_valid,
label2shape, label2shape,
broadcast_dims, broadcast_dims,
LabelType::Batch); {LabelType::Batch});
auto preprocess = [&](const DenseTensor& t, auto preprocess = [&](const DenseTensor& t,
const LabelMap& perm, const LabelMap& perm,
const std::vector<int>& ellipsis) -> DenseTensor { const std::vector<int>& ellipsis,
auto frees = GetShapeByType<int>( int operand_idx) -> DenseTensor {
all_labels, label2type, perm, label2shape, ellipsis, LabelType::Free); // reshape
auto frees = GetShapeByType<int>(all_labels,
label2type,
perm,
label2shape,
ellipsis,
{LabelType::AO, LabelType::BO});
auto conts = GetShapeByType<int>(all_labels, auto conts = GetShapeByType<int>(all_labels,
label2type, label2type,
perm, perm,
label2shape, label2shape,
ellipsis, ellipsis,
LabelType::Contraction); {LabelType::Contraction});
auto trans_t = PerformTranspose<T, Context>( std::vector<char> reordered_all_labels = all_labels;
if (operand_idx == 1) {
reordered_all_labels = TransformLabelsOrder(all_labels,
label2type,
{LabelType::Batch,
LabelType::Contraction,
LabelType::AO,
LabelType::BO,
LabelType::Reduction});
}
// reduction
DenseTensor trans_t;
if (FLAGS_einsum_opt && use_cache && cache[operand_idx] != nullptr &&
cache[operand_idx]->IsInitialized()) {
trans_t.ShareBufferWith(*(cache[operand_idx]));
VLOG(5) << "Cache Used!";
} else {
auto reduct_t = PerformReduction<T, Context>(
dev_ctx, t, perm, all_labels, ellipsis, label2type); dev_ctx, t, perm, all_labels, ellipsis, label2type);
auto mul_dims = GetShapeByType<int>( trans_t = PerformTranspose<T, Context>(
all_labels, label2type, perm, label2shape, ellipsis, LabelType::Batch); dev_ctx, reduct_t, perm, reordered_all_labels, ellipsis, label2type);
if (FLAGS_einsum_opt && cache[operand_idx] != nullptr)
cache[operand_idx]->ShareBufferWith(trans_t);
}
auto mul_dims = GetShapeByType<int>(all_labels,
label2type,
perm,
label2shape,
ellipsis,
{LabelType::Batch});
recover_dim.insert(recover_dim.end(), frees.begin(), frees.end()); recover_dim.insert(recover_dim.end(), frees.begin(), frees.end());
mul_dims.push_back( if (operand_idx == 0) {
std::accumulate(frees.begin(), frees.end(), 1, std::multiplies<int>())); mul_dims.push_back(std::accumulate(
mul_dims.push_back( frees.begin(), frees.end(), 1, std::multiplies<int>()));
std::accumulate(conts.begin(), conts.end(), 1, std::multiplies<int>())); mul_dims.push_back(std::accumulate(
conts.begin(), conts.end(), 1, std::multiplies<int>()));
} else {
mul_dims.push_back(std::accumulate(
conts.begin(), conts.end(), 1, std::multiplies<int>()));
mul_dims.push_back(std::accumulate(
frees.begin(), frees.end(), 1, std::multiplies<int>()));
}
VLOG(5) << "PerformContraction: mul_dims: " VLOG(5) << "PerformContraction: mul_dims: "
<< paddle::string::join_strings(mul_dims, ","); << paddle::string::join_strings(mul_dims, ",");
trans_t.Resize(make_ddim(mul_dims)); trans_t.Resize(make_ddim(mul_dims));
return trans_t; return trans_t;
}; };
auto trans_a = preprocess(A, label2perm[0], ellipsis_dims[0]);
auto trans_b = preprocess(B, label2perm[1], ellipsis_dims[1]); // Reduction, Reshape and Matmul
auto trans_a = preprocess(A, label2perm[0], ellipsis_dims[0], 0);
auto trans_b = preprocess(B, label2perm[1], ellipsis_dims[1], 1);
auto after_contraction = auto after_contraction =
Matmul<T, Context>(dev_ctx, trans_a, trans_b, false, true); Matmul<T, Context>(dev_ctx, trans_a, trans_b, false, false);
VLOG(5) << "PerformContraction: recover_dim: " VLOG(5) << "PerformContraction: recover_dim: "
<< paddle::string::join_strings(recover_dim, ","); << paddle::string::join_strings(recover_dim, ",");
after_contraction.Resize(make_ddim(recover_dim)); after_contraction.Resize(make_ddim(recover_dim));
...@@ -458,17 +529,23 @@ void TransposeToOutput(const Context& dev_ctx, ...@@ -458,17 +529,23 @@ void TransposeToOutput(const Context& dev_ctx,
axis.push_back(it - all_labels.begin() + offset); axis.push_back(it - all_labels.begin() + offset);
} }
} }
if (is_no_need_transpose(axis)) return output->ShareBufferWith(to_trans); if (is_no_need_transpose(axis)) {
output->ShareBufferWith(to_trans);
return;
}
VLOG(5) << "call TransposeToOutput: with axis: " VLOG(5) << "call TransposeToOutput: with axis: "
<< paddle::string::join_strings(axis, ","); << paddle::string::join_strings(axis, ",");
return TransposeKernel<T, Context>(dev_ctx, to_trans, axis, output); TransposeKernel<T, Context>(dev_ctx, to_trans, axis, output);
} }
template <typename T, typename Context> template <typename T, typename Context>
void EinsumKernel(const Context& dev_ctx, void EinsumKernelImpl(const Context& dev_ctx,
const std::vector<char>& forward_all_labels,
const std::vector<const DenseTensor*>& inputs, const std::vector<const DenseTensor*>& inputs,
const std::string& equation, const std::string& equation,
DenseTensor* out) { DenseTensor* out,
std::vector<DenseTensor*> cache,
bool is_forward = true) {
ValidationCheck(equation); ValidationCheck(equation);
// collect the following information to prepare einsum. // collect the following information to prepare einsum.
LabelMap labelshape(0); LabelMap labelshape(0);
...@@ -484,6 +561,9 @@ void EinsumKernel(const Context& dev_ctx, ...@@ -484,6 +561,9 @@ void EinsumKernel(const Context& dev_ctx,
input_dims.push_back(i->dims()); input_dims.push_back(i->dims());
} }
std::string right; std::string right;
if (!is_forward) {
all_labels = forward_all_labels;
}
ParseEinsumEquation(equation, ParseEinsumEquation(equation,
input_dims, input_dims,
&labelshape, &labelshape,
...@@ -498,22 +578,18 @@ void EinsumKernel(const Context& dev_ctx, ...@@ -498,22 +578,18 @@ void EinsumKernel(const Context& dev_ctx,
if (inputs.size() == 2) { if (inputs.size() == 2) {
auto& A = inputs[0]; auto& A = inputs[0];
auto& B = inputs[1]; auto& B = inputs[1];
// Reduce Procedure // Reduction and Contract Procedure
auto reduce_A = PerformReduction<T, Context>(
dev_ctx, *A, label2perms[0], all_labels, ellipsis_dims[0], labeltype);
auto reduce_B = PerformReduction<T, Context>(
dev_ctx, *B, label2perms[1], all_labels, ellipsis_dims[1], labeltype);
// Contract Procedure
dev_ctx.template Alloc<T>(out);
auto after_contraction = PerformContraction<T, Context>(dev_ctx, auto after_contraction = PerformContraction<T, Context>(dev_ctx,
reduce_A, *A,
reduce_B, *B,
label2perms, label2perms,
all_labels, all_labels,
labeltype, labeltype,
labelshape, labelshape,
ellipsis_dims, ellipsis_dims,
broadcast_dims); broadcast_dims,
cache,
!is_forward);
TransposeToOutput<T, Context>(dev_ctx, TransposeToOutput<T, Context>(dev_ctx,
after_contraction, after_contraction,
right, right,
...@@ -545,4 +621,37 @@ void EinsumKernel(const Context& dev_ctx, ...@@ -545,4 +621,37 @@ void EinsumKernel(const Context& dev_ctx,
} }
} }
// Forward einsum entry point that additionally receives the per-operand cache
// tensors. It normalizes the cache list and delegates to EinsumKernelImpl in
// forward mode.
template <typename T, typename Context>
void EinsumKernelRaw(const Context& dev_ctx,
                     const std::vector<const DenseTensor*>& inputs,
                     const std::string& equation,
                     DenseTensor* out,
                     std::vector<DenseTensor*> cache) {
  // Compatibility: a v2.3 EinsumOp may be loaded and run through this kernel.
  // In that case the output may contain nullptr and cache.size() may differ
  // from inputs.size() (refer to BuildPhiKernelContext for details). Pad the
  // cache with nullptr until it has one entry per input.
  if (cache.size() < inputs.size()) {
    cache.resize(inputs.size(), nullptr);
  }
  // forward_all_labels is only consulted when is_forward is false, so an empty
  // vector is passed here.
  const std::vector<char> no_forward_labels;
  EinsumKernelImpl<T, Context>(dev_ctx,
                               no_forward_labels,
                               inputs,
                               equation,
                               out,
                               cache,
                               /*is_forward=*/true);
}
// Forward einsum entry point without caching: runs EinsumKernelImpl in forward
// mode with a null cache entry for every operand.
template <typename T, typename Context>
void EinsumKernel(const Context& dev_ctx,
                  const std::vector<const DenseTensor*>& inputs,
                  const std::string& equation,
                  DenseTensor* out) {
  // forward_all_labels is only consulted when is_forward is false, so an empty
  // vector is passed here.
  const std::vector<char> no_forward_labels;
  // One slot per input, all empty (TA, TB, TdC): nothing is cached.
  std::vector<DenseTensor*> empty_cache(inputs.size(), nullptr);
  EinsumKernelImpl<T, Context>(dev_ctx,
                               no_forward_labels,
                               inputs,
                               equation,
                               out,
                               empty_cache,
                               /*is_forward=*/true);
}
} // namespace phi } // namespace phi
...@@ -17,14 +17,15 @@ limitations under the License. */ ...@@ -17,14 +17,15 @@ limitations under the License. */
namespace phi { namespace phi {
KernelSignature EinsumOpArgumentMapping(const ArgumentMappingContext& ctx) { KernelSignature EinsumOpArgumentMapping(const ArgumentMappingContext& ctx) {
return KernelSignature("einsum", {"Operands"}, {"equation"}, {"Out"}); return KernelSignature(
"einsum", {"Operands"}, {"equation"}, {"Out", "InnerCache"});
} }
KernelSignature EinsumGradOpArgumentMapping(const ArgumentMappingContext& ctx) { KernelSignature EinsumGradOpArgumentMapping(const ArgumentMappingContext& ctx) {
return KernelSignature("einsum_grad", return KernelSignature("einsum_grad",
{"Operands", {"Out@GRAD"}}, {"Operands", "InnerCache", "Out@GRAD"},
{"equation"}, {"equation"},
{{"Operands@GRAD"}}); {"Operands@GRAD"});
} }
} // namespace phi } // namespace phi
......
...@@ -18,6 +18,9 @@ import unittest ...@@ -18,6 +18,9 @@ import unittest
import paddle import paddle
from paddle.fluid import core from paddle.fluid import core
import os
os.environ['FLAGS_new_einsum'] = "0"
class TestErrors(unittest.TestCase): class TestErrors(unittest.TestCase):
def setUp(self): def setUp(self):
......
...@@ -34,7 +34,11 @@ class TestEinsumBinary(OpTest): ...@@ -34,7 +34,11 @@ class TestEinsumBinary(OpTest):
self.operands.append(("x" + str(idx), inp)) self.operands.append(("x" + str(idx), inp))
self.inputs = {"Operands": self.operands} self.inputs = {"Operands": self.operands}
self.attrs = {"equation": self.equation} self.attrs = {"equation": self.equation}
self.outputs = {'Out': out} self.outputs = {
'Out': out,
"InnerCache": [('cache_' + str(i), np.array([1.0]))
for i in range(len(self.operands))]
}
def init_input(self): def init_input(self):
self.inputs = [] self.inputs = []
...@@ -49,7 +53,7 @@ class TestEinsumBinary(OpTest): ...@@ -49,7 +53,7 @@ class TestEinsumBinary(OpTest):
def test_check_output(self): def test_check_output(self):
if not self.disable: if not self.disable:
self.check_output() self.check_output(no_check_set=["InnerCache"])
def test_grad(self): def test_grad(self):
if not self.disable: if not self.disable:
......
...@@ -464,5 +464,37 @@ class TestNumpyTests(unittest.TestCase): ...@@ -464,5 +464,37 @@ class TestNumpyTests(unittest.TestCase):
self.check_output_equal(a, e) self.check_output_equal(a, e)
class TestStaticGraphShape(unittest.TestCase):
    """Checks the output shape reported by einsum in static-graph mode."""

    def setUp(self):
        # Each test runs in static mode; tearDown switches it back off.
        paddle.enable_static()

    def tearDown(self):
        paddle.disable_static()

    def test_shape(self):
        # 'i,d->id' with operands declared as [-1] and [384] is expected to
        # report the shape (-1, 384).
        lhs = paddle.static.data(name='x', shape=[-1])
        rhs = paddle.static.data(name='y', shape=[384])
        product = paddle.einsum('i,d->id', lhs, rhs)
        self.assertEqual(product.shape, (-1, 384))
class TestBF16(unittest.TestCase):
    """
    EinsumOp supports the bfloat16 type; this unittest checks its correctness.
    """

    def test_shape(self):
        # The check only runs on CUDA builds whose major version is >= 11.
        # (The original note attributes this to MatmulKernel's bfloat16
        # support.) Otherwise the test is reported as skipped instead of
        # silently passing without asserting anything.
        if not paddle.is_compiled_with_cuda():
            self.skipTest("bfloat16 einsum requires a CUDA build.")
        cuda_major = paddle.version.cuda().split('.')[0].strip()
        if int(cuda_major) < 11:
            self.skipTest("bfloat16 einsum requires CUDA major version >= 11.")

        A = paddle.to_tensor(np.array([1.0, 2.0])).astype(paddle.bfloat16)
        A = A.cuda()
        B = paddle.to_tensor(np.array([2.0, 3.0])).astype(paddle.bfloat16)
        B = B.cuda()
        # Inner product: 1 * 2 + 2 * 3 = 8.
        C = paddle.einsum('i,i->', A, B)
        self.assertEqual(C.item(), 8.0)
if __name__ == "__main__": if __name__ == "__main__":
u unittest.main()
...@@ -35,4 +35,5 @@ no_check_set_white_list = [ ...@@ -35,4 +35,5 @@ no_check_set_white_list = [
'eigh', 'eigh',
'eigvalsh', 'eigvalsh',
'class_center_sample', 'class_center_sample',
'einsum',
] ]
...@@ -798,11 +798,12 @@ def gen_einsum_op(equation, *operands): ...@@ -798,11 +798,12 @@ def gen_einsum_op(equation, *operands):
""" """
assert len(operands) <= 2, "Only support two operands in EinsumOp." assert len(operands) <= 2, "Only support two operands in EinsumOp."
if in_dygraph_mode(): if in_dygraph_mode():
return _C_ops.final_state_einsum(operands, equation) return _C_ops.final_state_einsum(operands, equation)[0]
if _in_legacy_dygraph(): if _in_legacy_dygraph():
# dygraph # dygraph
return _C_ops.einsum(operands, 'equation', equation) return _C_ops.einsum(operands, len(operands), 'equation', equation)[0]
# static graph # static graph
for inp in operands: for inp in operands:
check_variable_and_dtype(inp, 'dtype', ['float32', 'float64'], 'einsum') check_variable_and_dtype(inp, 'dtype', ['float32', 'float64'], 'einsum')
...@@ -811,11 +812,16 @@ def gen_einsum_op(equation, *operands): ...@@ -811,11 +812,16 @@ def gen_einsum_op(equation, *operands):
out = helper.create_variable_for_type_inference(dtype=operands[0].dtype) out = helper.create_variable_for_type_inference(dtype=operands[0].dtype)
attrs = dict() attrs = dict()
attrs['equation'] = equation attrs['equation'] = equation
caches = [
helper.create_variable_for_type_inference(dtype=operands[0].dtype)
for i in range(len(operands))
]
helper.append_op( helper.append_op(
type='einsum', type='einsum',
inputs={'Operands': operands}, inputs={'Operands': operands},
outputs={'Out': out}, outputs={'Out': out,
attrs=attrs, ) "InnerCache": caches},
attrs=attrs)
return out return out
...@@ -977,7 +983,7 @@ def einsum(equation, *operands): ...@@ -977,7 +983,7 @@ def einsum(equation, *operands):
# [0.51476848, 0.23367381, 0.39229113]]]) # [0.51476848, 0.23367381, 0.39229113]]])
""" """
import os import os
if int(os.environ.get('FLAGS_new_einsum', "0")): if int(os.environ.get('FLAGS_new_einsum', "1")):
return einsum_v2(equation, *operands) return einsum_v2(equation, *operands)
nop = len(operands) nop = len(operands)
......
...@@ -547,7 +547,7 @@ ...@@ -547,7 +547,7 @@
- api : einsum - api : einsum
args : (Tensor[] x, str equation) args : (Tensor[] x, str equation)
output : Tensor output : Tensor, Tensor[]{x.size()}
infer_meta : infer_meta :
func : EinsumInferMeta func : EinsumInferMeta
param : [x, equation] param : [x, equation]
......
...@@ -205,17 +205,19 @@ class BaseAPI(object): ...@@ -205,17 +205,19 @@ class BaseAPI(object):
if len(temp_list) == 1: if len(temp_list) == 1:
out_type, out_name, size_expr = parse_output_item(temp_list[0]) out_type, out_name, size_expr = parse_output_item(temp_list[0])
return [out_type], [out_name], size_expr, self.get_return_type( return [out_type], [out_name], [size_expr], self.get_return_type(
[out_type]) [out_type])
else: else:
out_type_list = [] out_type_list = []
out_name_list = [] out_name_list = []
out_size_expr_list = []
for output_item in temp_list: for output_item in temp_list:
out_type, out_name, size_expr = parse_output_item(output_item) out_type, out_name, size_expr = parse_output_item(output_item)
out_type_list.append(out_type) out_type_list.append(out_type)
out_name_list.append(out_name) out_name_list.append(out_name)
out_size_expr_list.append(size_expr)
return out_type_list, out_name_list, size_expr, self.get_return_type( return out_type_list, out_name_list, out_size_expr_list, self.get_return_type(
out_type_list) out_type_list)
def parse_infer_meta(self, infer_meta_config): def parse_infer_meta(self, infer_meta_config):
......
...@@ -94,10 +94,10 @@ class ForwardAPI(BaseAPI): ...@@ -94,10 +94,10 @@ class ForwardAPI(BaseAPI):
{code_indent} {self.outputs['return_type']} api_output{inplace_assign};""" {code_indent} {self.outputs['return_type']} api_output{inplace_assign};"""
if self.outputs['return_type'] == 'std::vector<Tensor>': if self.outputs['return_type'] == 'std::vector<Tensor>':
assert self.outputs['out_size_expr'] is not None, \ assert self.outputs['out_size_expr'][0] is not None, \
f"{api_name}: The out size expr : '{{expr}}' should be set when output has Tensor[]. You can refer 'split' api." f"{api_name}: The out size expr : '{{expr}}' should be set when output has Tensor[]. You can refer 'split' api."
output_create = output_create + f""" output_create = output_create + f"""
{code_indent} auto kernel_out = {set_out_func}({self.outputs['out_size_expr']}, kernel_backend, &api_output);""" {code_indent} auto kernel_out = {set_out_func}({self.outputs['out_size_expr'][0]}, kernel_backend, &api_output);"""
else: else:
output_create = output_create + f""" output_create = output_create + f"""
......
...@@ -447,8 +447,8 @@ ...@@ -447,8 +447,8 @@
skip_transform : out_w, out_w_grad skip_transform : out_w, out_w_grad
- backward_api : einsum_grad - backward_api : einsum_grad
forward : einsum (Tensor[] x, str equation) -> Tensor(out) forward : einsum (Tensor[] x, str equation) -> Tensor(out), Tensor[](inner_cache)
args : (Tensor[] x, Tensor out_grad, str equation) args : (Tensor[] x, Tensor[] inner_cache, Tensor out_grad, str equation)
output : Tensor[](x_grad){x.size()} output : Tensor[](x_grad){x.size()}
infer_meta : infer_meta :
func : UnchangedMultiInferMeta func : UnchangedMultiInferMeta
......
...@@ -113,10 +113,10 @@ class BackwardAPI(BaseAPI): ...@@ -113,10 +113,10 @@ class BackwardAPI(BaseAPI):
{code_indent} {self.outputs['return_type']} api_output{inplace_assign};""" {code_indent} {self.outputs['return_type']} api_output{inplace_assign};"""
if output_type_list[0] == 'std::vector<Tensor>': if output_type_list[0] == 'std::vector<Tensor>':
assert self.outputs['out_size_expr'] is not None, \ assert self.outputs['out_size_expr'][0] is not None, \
f"{api_name}: The out size expr : '{{expr}}' should be set when output has Tensor[]. You can refer 'split' api." f"{api_name}: The out size expr : '{{expr}}' should be set when output has Tensor[]. You can refer 'split' api."
output_create = output_create + f""" output_create = output_create + f"""
{code_indent} auto kernel_out = {set_out_func}({self.outputs['out_size_expr']}, kernel_backend, &api_output);""" {code_indent} auto kernel_out = {set_out_func}({self.outputs['out_size_expr'][0]}, kernel_backend, &api_output);"""
else: else:
output_create = output_create + f""" output_create = output_create + f"""
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册