未验证 提交 a3f28a31 编写于 作者: zhouweiwei2014's avatar zhouweiwei2014 提交者: GitHub

【Phi】Migrate triangular_solve op into phi (#40093)

* Migrate triangular_solve op into phi

* fix CI

* move MatrixReduceSum to phi funcs

* move MatrixReduceSum to phi funcs

* fix comment

* fic CI
上级 e7afa391
...@@ -13,7 +13,10 @@ See the License for the specific language governing permissions and ...@@ -13,7 +13,10 @@ See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include "paddle/fluid/operators/triangular_solve_op.h" #include "paddle/fluid/operators/triangular_solve_op.h"
#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/solve_op.h" #include "paddle/fluid/operators/solve_op.h"
#include "paddle/phi/infermeta/binary.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
...@@ -22,58 +25,6 @@ class TriangularSolveOp : public framework::OperatorWithKernel { ...@@ -22,58 +25,6 @@ class TriangularSolveOp : public framework::OperatorWithKernel {
public: public:
using framework::OperatorWithKernel::OperatorWithKernel; using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "TriangularSolve");
OP_INOUT_CHECK(ctx->HasInput("Y"), "Input", "Y", "TriangularSolve");
OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "TriangularSolve");
auto x_dims = ctx->GetInputDim("X");
auto y_dims = ctx->GetInputDim("Y");
auto x_dims_n = x_dims.size();
auto y_dims_n = y_dims.size();
PADDLE_ENFORCE_GE(
x_dims_n, 2, platform::errors::InvalidArgument(
"The input tensor X's dimensions of TriangularSolveOp "
"should be >= 2. But received X's "
"dimensions = %d, X's shape = [%s]",
x_dims.size(), x_dims));
PADDLE_ENFORCE_GE(
y_dims_n, 2, platform::errors::InvalidArgument(
"The input tensor Y's dimensions of TriangularSolveOp "
"should be >=2. But received Y's "
"dimensions = %d, Y's shape = [%s]",
y_dims.size(), y_dims));
PADDLE_ENFORCE_EQ(x_dims[x_dims_n - 2], x_dims[x_dims_n - 1],
platform::errors::InvalidArgument(
"The inner-most 2 dimensions of Input(X) all should "
"be square matrices "
"But received X's shape[-2] = %d and shape[-1] = %d.",
x_dims[x_dims_n - 2], x_dims[x_dims_n - 1]));
std::vector<int64_t> x_dims_vec = phi::vectorize(x_dims);
std::vector<int64_t> y_dims_vec = phi::vectorize(y_dims);
std::vector<int64_t> x_dims_vec_cut(x_dims_vec.begin(),
x_dims_vec.end() - 2);
std::vector<int64_t> y_dims_vec_cut(y_dims_vec.begin(),
y_dims_vec.end() - 2);
std::vector<int64_t> expand_batch_portion =
get_broadcast_batch_portion(x_dims_vec_cut, y_dims_vec_cut);
std::vector<int64_t> y_broadcast_dims({expand_batch_portion});
y_broadcast_dims.insert(y_broadcast_dims.end(), {y_dims_vec[y_dims_n - 2],
y_dims_vec[y_dims_n - 1]});
// dim of 'Out' is the same with 'Y' after broadcast
ctx->SetOutputDim("Out", phi::make_ddim(y_broadcast_dims));
ctx->ShareLoD("X", /*->*/ "Out");
}
framework::OpKernelType GetExpectedKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const { const framework::ExecutionContext& ctx) const {
return framework::OpKernelType( return framework::OpKernelType(
...@@ -168,20 +119,15 @@ class TriangularSolveOpGradMaker : public framework::SingleGradOpMaker<T> { ...@@ -168,20 +119,15 @@ class TriangularSolveOpGradMaker : public framework::SingleGradOpMaker<T> {
} // namespace paddle } // namespace paddle
namespace ops = paddle::operators; namespace ops = paddle::operators;
DELCARE_INFER_SHAPE_FUNCTOR(triangular_solve, TriangularSolveInferShapeFunctor,
PT_INFER_META(phi::TriangularSolveInferMeta));
REGISTER_OPERATOR(triangular_solve, ops::TriangularSolveOp, REGISTER_OPERATOR(triangular_solve, ops::TriangularSolveOp,
ops::TriangularSolveOpMaker, ops::TriangularSolveOpMaker,
ops::TriangularSolveOpInferVarType, ops::TriangularSolveOpInferVarType,
ops::TriangularSolveOpGradMaker<paddle::framework::OpDesc>, ops::TriangularSolveOpGradMaker<paddle::framework::OpDesc>,
ops::TriangularSolveOpGradMaker<paddle::imperative::OpBase>); ops::TriangularSolveOpGradMaker<paddle::imperative::OpBase>,
TriangularSolveInferShapeFunctor);
REGISTER_OPERATOR(triangular_solve_grad, ops::TriangularSolveGradOp); REGISTER_OPERATOR(triangular_solve_grad, ops::TriangularSolveGradOp);
REGISTER_OP_CPU_KERNEL(
triangular_solve,
ops::TriangularSolveKernel<paddle::platform::CPUDeviceContext, float>,
ops::TriangularSolveKernel<paddle::platform::CPUDeviceContext, double>);
REGISTER_OP_CPU_KERNEL(
triangular_solve_grad,
ops::TriangularSolveGradKernel<paddle::platform::CPUDeviceContext, float>,
ops::TriangularSolveGradKernel<paddle::platform::CPUDeviceContext, double>);
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/reduce_ops/reduce_op.h"
#include "paddle/fluid/operators/triangular_solve_op.h"
namespace paddle {
namespace operators {
template <typename T>
class MatrixReduceSumFunctor<platform::CUDADeviceContext, T> {
public:
void operator()(const Tensor& in, Tensor* out,
const framework::ExecutionContext& ctx) {
// For example: in's dim = [5, 3, 2, 7, 3] ; out's dim = [3, 1, 7, 3]
// out_reduce_dim should be [0, 2]
const std::vector<std::int64_t> in_dims = phi::vectorize(in.dims());
auto in_size = in_dims.size();
const std::vector<std::int64_t> out_dims = phi::vectorize(out->dims());
auto out_size = out_dims.size();
std::vector<std::int64_t> out_bst_dims(in_size);
std::fill(out_bst_dims.data(), out_bst_dims.data() + in_size - out_size, 1);
std::copy(out_dims.data(), out_dims.data() + out_size,
out_bst_dims.data() + in_size - out_size);
std::vector<int> out_reduce_dims;
for (size_t idx = 0; idx <= in_size - 3; idx++) {
if (in_dims[idx] != 1 && out_bst_dims[idx] == 1) {
out_reduce_dims.push_back(idx);
}
}
gpuStream_t stream = ctx.cuda_device_context().stream();
TensorReduceImpl<T, T, kps::AddFunctor, kps::IdentityFunctor<T>>(
ctx.cuda_device_context(), in, out, kps::IdentityFunctor<T>(),
out_reduce_dims, stream);
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP_CUDA_KERNEL(
triangular_solve,
ops::TriangularSolveKernel<paddle::platform::CUDADeviceContext, float>,
ops::TriangularSolveKernel<paddle::platform::CUDADeviceContext, double>);
REGISTER_OP_CUDA_KERNEL(
triangular_solve_grad,
ops::TriangularSolveGradKernel<paddle::platform::CUDADeviceContext, float>,
ops::TriangularSolveGradKernel<paddle::platform::CUDADeviceContext,
double>);
...@@ -21,7 +21,6 @@ limitations under the License. */ ...@@ -21,7 +21,6 @@ limitations under the License. */
#include "paddle/fluid/operators/solve_op.h" #include "paddle/fluid/operators/solve_op.h"
#include "paddle/fluid/operators/tril_triu_op.h" #include "paddle/fluid/operators/tril_triu_op.h"
#include "paddle/phi/core/ddim.h" #include "paddle/phi/core/ddim.h"
#include "paddle/phi/kernels/funcs/blas/blas.h"
#include "paddle/phi/kernels/funcs/complex_functors.h" #include "paddle/phi/kernels/funcs/complex_functors.h"
namespace paddle { namespace paddle {
...@@ -30,10 +29,10 @@ namespace operators { ...@@ -30,10 +29,10 @@ namespace operators {
using Tensor = framework::Tensor; using Tensor = framework::Tensor;
template <typename DeviceContext, typename T> template <typename DeviceContext, typename T>
static void triangular_solve(const DeviceContext& context, const Tensor& x, static void triangular_solve(const DeviceContext &context, const Tensor &x,
const Tensor& y, Tensor* out, bool upper, const Tensor &y, Tensor *out, bool upper,
bool transpose, bool unitriangular) { bool transpose, bool unitriangular) {
// Tensor broadcast use eigen // Tensor broadcast use eigen library
std::vector<int64_t> x_bst_dims_vec; std::vector<int64_t> x_bst_dims_vec;
std::vector<int64_t> y_bst_dims_vec; std::vector<int64_t> y_bst_dims_vec;
std::tie(x_bst_dims_vec, y_bst_dims_vec) = get_broadcast_dims(x, y); std::tie(x_bst_dims_vec, y_bst_dims_vec) = get_broadcast_dims(x, y);
...@@ -64,15 +63,15 @@ static void triangular_solve(const DeviceContext& context, const Tensor& x, ...@@ -64,15 +63,15 @@ static void triangular_solve(const DeviceContext& context, const Tensor& x,
template <typename DeviceContext, typename T> template <typename DeviceContext, typename T>
class MatrixReduceSumFunctor { class MatrixReduceSumFunctor {
public: public:
void operator()(const Tensor& input, Tensor* output, void operator()(const Tensor &input, Tensor *output,
const framework::ExecutionContext& ctx); const framework::ExecutionContext &ctx);
}; };
template <typename T> template <typename T>
class MatrixReduceSumFunctor<platform::CPUDeviceContext, T> { class MatrixReduceSumFunctor<platform::CPUDeviceContext, T> {
public: public:
void operator()(const Tensor& in, Tensor* out, void operator()(const Tensor &in, Tensor *out,
const framework::ExecutionContext& ctx) { const framework::ExecutionContext &ctx) {
// For example: in's dim = [5, 3, 2, 7, 3] ; out's dim = [3, 1, 7, 3] // For example: in's dim = [5, 3, 2, 7, 3] ; out's dim = [3, 1, 7, 3]
// out_reduce_dim should be [0, 2] // out_reduce_dim should be [0, 2]
const std::vector<std::int64_t> in_dims = phi::vectorize(in.dims()); const std::vector<std::int64_t> in_dims = phi::vectorize(in.dims());
...@@ -101,129 +100,5 @@ class MatrixReduceSumFunctor<platform::CPUDeviceContext, T> { ...@@ -101,129 +100,5 @@ class MatrixReduceSumFunctor<platform::CPUDeviceContext, T> {
} }
}; };
template <typename DeviceContext, typename T>
class TriangularSolveKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
const auto* x = ctx.Input<framework::Tensor>("X");
const auto* y = ctx.Input<framework::Tensor>("Y");
auto* out = ctx.Output<framework::Tensor>("Out");
bool upper = ctx.template Attr<bool>("upper");
bool transpose = ctx.template Attr<bool>("transpose");
bool unitriangular = ctx.template Attr<bool>("unitriangular");
const auto& dev_ctx = ctx.template device_context<DeviceContext>();
triangular_solve<DeviceContext, T>(dev_ctx, *x, *y, out, upper, transpose,
unitriangular);
}
};
template <typename DeviceContext, typename T>
class TriangularSolveGradKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
const auto* x = ctx.Input<framework::Tensor>("X");
const auto* y = ctx.Input<framework::Tensor>("Y");
const auto* out = ctx.Input<framework::Tensor>("Out");
const auto* dout =
ctx.Input<framework::Tensor>(framework::GradVarName("Out"));
auto* dx = ctx.Output<framework::Tensor>(framework::GradVarName("X"));
auto* dy = ctx.Output<framework::Tensor>(framework::GradVarName("Y"));
bool upper = ctx.template Attr<bool>("upper");
bool transpose = ctx.template Attr<bool>("transpose");
bool unitriangular = ctx.template Attr<bool>("unitriangular");
auto& dev_ctx = ctx.template device_context<DeviceContext>();
std::vector<int64_t> x_bst_dims_vec;
std::vector<int64_t> y_bst_dims_vec;
std::tie(x_bst_dims_vec, y_bst_dims_vec) = get_broadcast_dims(*x, *y);
Tensor dy_bst(y->type());
if (dy) {
dy->mutable_data<T>(y->dims(), dev_ctx.GetPlace());
dy_bst.Resize(phi::make_ddim(y_bst_dims_vec));
dy_bst.mutable_data<T>(dev_ctx.GetPlace());
// calculate x's conjugate for complex
Tensor x_conj(x->type());
platform::ForRange<DeviceContext> x_for_range(dev_ctx, x->numel());
phi::funcs::ConjFunctor<T> x_functor(
x->data<T>(), x->numel(),
x_conj.mutable_data<T>(x->dims(), dev_ctx.GetPlace()));
x_for_range(x_functor);
// reuse forward to get dy_bst, and the result has been broadcated.
triangular_solve<DeviceContext, T>(dev_ctx, x_conj, *dout, &dy_bst, upper,
!transpose, unitriangular);
if (dy_bst.dims() == dy->dims()) {
framework::TensorCopy(dy_bst, dev_ctx.GetPlace(), dev_ctx, dy);
} else {
MatrixReduceSumFunctor<DeviceContext, T> functor;
functor(dy_bst, dy, ctx);
dy->Resize(y->dims());
}
}
Tensor dx_bst(x->type());
if (dx) {
dx->mutable_data<T>(x->dims(), dev_ctx.GetPlace());
dx_bst.Resize(phi::make_ddim(x_bst_dims_vec));
dx_bst.mutable_data<T>(dev_ctx.GetPlace());
// calculate out's conjugate for complex
Tensor out_conj(out->type());
platform::ForRange<DeviceContext> out_for_range(dev_ctx, out->numel());
phi::funcs::ConjFunctor<T> out_functor(
out->data<T>(), out->numel(),
out_conj.mutable_data<T>(out->dims(), dev_ctx.GetPlace()));
out_for_range(out_functor);
auto blas = phi::funcs::GetBlas<DeviceContext, T>(ctx);
if (transpose) {
auto mat_dim_a =
phi::funcs::CreateMatrixDescriptor(out_conj.dims(), 0, false);
auto mat_dim_b =
phi::funcs::CreateMatrixDescriptor(dy_bst.dims(), 0, true);
blas.MatMul(out_conj, mat_dim_a, dy_bst, mat_dim_b, static_cast<T>(-1),
&dx_bst, static_cast<T>(0));
} else {
auto mat_dim_a =
phi::funcs::CreateMatrixDescriptor(dy_bst.dims(), 0, false);
auto mat_dim_b =
phi::funcs::CreateMatrixDescriptor(out_conj.dims(), 0, true);
blas.MatMul(dy_bst, mat_dim_a, out_conj, mat_dim_b, static_cast<T>(-1),
&dx_bst, static_cast<T>(0));
}
Tensor dx_bst_upper(x->type());
// get upper or lower triangular
dx_bst_upper.Resize(dx_bst.dims());
dx_bst_upper.mutable_data<T>(dev_ctx.GetPlace());
const auto& dims = dx_bst.dims();
const auto H = dims[dims.size() - 2];
const auto W = dims[dims.size() - 1];
platform::ForRange<DeviceContext> x_for_range(dev_ctx, dx_bst.numel());
TrilTriuCompute<T> tril_triu_computer(dx_bst.data<T>(), unitriangular,
!upper, H, W,
dx_bst_upper.data<T>());
x_for_range(tril_triu_computer);
if (dx_bst_upper.dims() == dx->dims()) {
framework::TensorCopy(dx_bst_upper, dev_ctx.GetPlace(), dev_ctx, dx);
} else {
MatrixReduceSumFunctor<DeviceContext, T> functor;
functor(dx_bst_upper, dx, ctx);
dx->Resize(x->dims());
}
}
}
};
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle
...@@ -274,6 +274,65 @@ void HuberLossInferMeta(const MetaTensor& input, ...@@ -274,6 +274,65 @@ void HuberLossInferMeta(const MetaTensor& input,
out->share_lod(input); out->share_lod(input);
} }
void TriangularSolveInferMeta(const MetaTensor& x,
const MetaTensor& y,
bool upper,
bool transpose,
bool unitriangular,
MetaTensor* out) {
auto x_dims = x.dims();
auto y_dims = y.dims();
auto x_dims_n = x_dims.size();
auto y_dims_n = y_dims.size();
PADDLE_ENFORCE_GE(x_dims_n,
2,
phi::errors::InvalidArgument(
"The input tensor X's dimensions of TriangularSolveOp "
"should be >= 2. But received X's "
"dimensions = %d, X's shape = [%s]",
x_dims.size(),
x_dims));
PADDLE_ENFORCE_GE(y_dims_n,
2,
phi::errors::InvalidArgument(
"The input tensor Y's dimensions of TriangularSolveOp "
"should be >=2. But received Y's "
"dimensions = %d, Y's shape = [%s]",
y_dims.size(),
y_dims));
PADDLE_ENFORCE_EQ(x_dims[x_dims_n - 2],
x_dims[x_dims_n - 1],
phi::errors::InvalidArgument(
"The inner-most 2 dimensions of Input(X) all should "
"be square matrices "
"But received X's shape[-2] = %d and shape[-1] = %d.",
x_dims[x_dims_n - 2],
x_dims[x_dims_n - 1]));
std::vector<int64_t> x_dims_vec = phi::vectorize(x_dims);
std::vector<int64_t> y_dims_vec = phi::vectorize(y_dims);
std::vector<int64_t> x_dims_vec_cut(x_dims_vec.begin(), x_dims_vec.end() - 2);
std::vector<int64_t> y_dims_vec_cut(y_dims_vec.begin(), y_dims_vec.end() - 2);
std::vector<int64_t> expand_batch_portion =
funcs::MatrixGetBroadcastBatchPortion(x_dims_vec_cut, y_dims_vec_cut);
std::vector<int64_t> y_broadcast_dims({expand_batch_portion});
y_broadcast_dims.insert(y_broadcast_dims.end(),
{y_dims_vec[y_dims_n - 2], y_dims_vec[y_dims_n - 1]});
// dim of 'out' is the same with 'Y' after broadcast
out->set_dims(phi::make_ddim(y_broadcast_dims));
out->set_dtype(y.dtype());
out->set_layout(y.layout());
out->share_lod(y);
}
void IndexSampleInferMeta(const MetaTensor& x, void IndexSampleInferMeta(const MetaTensor& x,
const MetaTensor& y, const MetaTensor& y,
MetaTensor* out, MetaTensor* out,
......
...@@ -62,6 +62,13 @@ void HuberLossInferMeta(const MetaTensor& input_meta, ...@@ -62,6 +62,13 @@ void HuberLossInferMeta(const MetaTensor& input_meta,
MetaTensor* residual, MetaTensor* residual,
MetaConfig config = MetaConfig()); MetaConfig config = MetaConfig());
void TriangularSolveInferMeta(const MetaTensor& x,
const MetaTensor& y,
bool upper,
bool transpose,
bool unitriangular,
MetaTensor* out);
void IndexSampleInferMeta(const MetaTensor& x, void IndexSampleInferMeta(const MetaTensor& x,
const MetaTensor& y, const MetaTensor& y,
MetaTensor* out, MetaTensor* out,
......
...@@ -18,10 +18,11 @@ set(COMMON_KERNEL_DEPS ${COMMON_KERNEL_DEPS} infermeta) ...@@ -18,10 +18,11 @@ set(COMMON_KERNEL_DEPS ${COMMON_KERNEL_DEPS} infermeta)
# NOTE: Some kernels depend on some targets that are not commonly used. # NOTE: Some kernels depend on some targets that are not commonly used.
# These targets are not suitable for common dependencies. # These targets are not suitable for common dependencies.
# In this case, you need to manually generate them here. # In this case, you need to manually generate them here.
set(MANUAL_BUILD_KERNELS math_kernel softmax_kernel softmax_grad_kernel) set(MANUAL_BUILD_KERNELS math_kernel softmax_kernel softmax_grad_kernel triangular_solve_grad_kernel)
kernel_library(math_kernel DEPS ${COMMON_KERNEL_DEPS} cast_kernel copy_kernel) kernel_library(math_kernel DEPS ${COMMON_KERNEL_DEPS} cast_kernel copy_kernel)
kernel_library(softmax_kernel DEPS ${COMMON_KERNEL_DEPS} softmax) kernel_library(softmax_kernel DEPS ${COMMON_KERNEL_DEPS} softmax)
kernel_library(softmax_grad_kernel DEPS ${COMMON_KERNEL_DEPS} softmax) kernel_library(softmax_grad_kernel DEPS ${COMMON_KERNEL_DEPS} softmax)
kernel_library(triangular_solve_grad_kernel DEPS ${COMMON_KERNEL_DEPS} matrix_reduce)
# auto parse and build kernel targets by cmake # auto parse and build kernel targets by cmake
register_kernels(EXCLUDES ${MANUAL_BUILD_KERNELS} DEPS ${COMMON_KERNEL_DEPS}) register_kernels(EXCLUDES ${MANUAL_BUILD_KERNELS} DEPS ${COMMON_KERNEL_DEPS})
......
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/impl/triangular_solve_grad_kernel_impl.h"
PD_REGISTER_KERNEL(triangular_solve_grad,
CPU,
ALL_LAYOUT,
phi::TriangularSolveGradKernel,
float,
double) {}
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/triangular_solve_kernel.h"
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/core/ddim.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/empty_kernel.h"
#include "paddle/phi/kernels/expand_kernel.h"
#include "paddle/phi/kernels/funcs/blas/blas.h"
#include "paddle/phi/kernels/funcs/common_shape.h"
namespace phi {
template <typename T, typename Context>
void TriangularSolveKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& y,
bool upper,
bool transpose,
bool unitriangular,
DenseTensor* out) {
// get broadcast dim
std::vector<int64_t> x_bst_dims_vec;
std::vector<int64_t> y_bst_dims_vec;
std::tie(x_bst_dims_vec, y_bst_dims_vec) =
funcs::MatrixGetBroadcastDims(x, y);
int x_bst_ndim = x_bst_dims_vec.size();
int y_bst_ndim = y_bst_dims_vec.size();
// Tensor broadcast to 'out' and temp 'x_bst'
ScalarArray x_bst_dims(x_bst_dims_vec);
DenseTensor x_bst = phi::Empty<T, Context>(dev_ctx, x_bst_dims);
const T* x_bst_data = x_bst.data<T>();
ExpandKernel<T, Context>(dev_ctx, x, x_bst_dims, &x_bst);
out->Resize(phi::make_ddim(y_bst_dims_vec));
T* out_data = dev_ctx.template Alloc<T>(out);
ScalarArray y_bst_dims(y_bst_dims_vec);
ExpandKernel<T, Context>(dev_ctx, y, y_bst_dims, out);
// Calculate use blas library
int M = static_cast<int>(y_bst_dims_vec[y_bst_ndim - 2]);
int N = static_cast<int>(y_bst_dims_vec[y_bst_ndim - 1]);
int batch_size = 1;
for (int i = 0; i < x_bst_ndim - 2; i++) {
batch_size *= x_bst_dims_vec[i];
}
auto blas = phi::funcs::GetBlas<CPUContext, T>(dev_ctx);
for (int i = 0; i < batch_size; i++) {
blas.TRSM(CblasLeft,
upper ? CblasUpper : CblasLower,
transpose ? CblasTrans : CblasNoTrans,
unitriangular ? CblasUnit : CblasNonUnit,
M,
N,
T(1),
x_bst_data + i * M * M,
std::max(1, M),
out_data + i * N * M,
std::max(1, N));
}
}
} // namespace phi
PD_REGISTER_KERNEL(triangular_solve,
CPU,
ALL_LAYOUT,
phi::TriangularSolveKernel,
float,
double) {}
...@@ -8,3 +8,4 @@ math_library(sequence2batch) ...@@ -8,3 +8,4 @@ math_library(sequence2batch)
math_library(gru_compute DEPS activation_functions math_function) math_library(gru_compute DEPS activation_functions math_function)
math_library(lstm_compute DEPS activation_functions) math_library(lstm_compute DEPS activation_functions)
math_library(concat_and_split_functor DEPS dense_tensor) math_library(concat_and_split_functor DEPS dense_tensor)
math_library(matrix_reduce DEPS dense_tensor)
...@@ -140,6 +140,72 @@ inline bool CheckDims(const DDim &dims_x, const DDim &dims_y) { ...@@ -140,6 +140,72 @@ inline bool CheckDims(const DDim &dims_x, const DDim &dims_y) {
return true; return true;
} }
// Just For Matrix OP, for example:
// x's dim = [5, 3, 2, M, M] ; y's dim = [3, 1, M, N]
// out [5, 3, 2], which is batch_size of matrix
static inline std::vector<int64_t> MatrixGetBroadcastBatchPortion(
std::vector<int64_t> x, std::vector<int64_t> y) {
size_t size_x = x.size();
size_t size_y = y.size();
size_t size = std::max(size_x, size_y);
std::vector<int64_t> batchPortion(size);
ptrdiff_t i = (ptrdiff_t)size - 1;
for (; i >= 0; --i) {
ptrdiff_t offset = size - i - 1;
ptrdiff_t dim_x = size_x - offset - 1;
ptrdiff_t dim_y = size_y - offset - 1;
int64_t x_size = (dim_x >= 0) ? x[dim_x] : 1;
int64_t y_size = (dim_y >= 0) ? y[dim_y] : 1;
PADDLE_ENFORCE_EQ(
(x_size == y_size || x_size == 1 || y_size == 1),
true,
phi::errors::PreconditionNotMet(
"The size of tensor x (%d) must match the size of tensor y "
"(%d) at non-singleton dimension %d.",
x_size,
y_size,
i));
batchPortion[i] = x_size != 1 ? x_size : y_size;
}
return batchPortion;
}
// Just For Matrix OP, for example:
// x's dim = [5, 3, 2, M, M] ; y's dim = [3, 1, M, N]
// out shoule be [5, 3, 2, M, M] + [5, 3, 2, M, N], and [5, 3, 2] is
// batch_size of matrix
static inline std::tuple<std::vector<int64_t>, std::vector<int64_t>>
MatrixGetBroadcastDims(const DenseTensor &x, const DenseTensor &y) {
std::vector<int64_t> x_dims_vec = phi::vectorize(x.dims());
std::vector<int64_t> y_dims_vec = phi::vectorize(y.dims());
std::vector<int64_t>::const_iterator f1 = x_dims_vec.begin();
std::vector<int64_t>::const_iterator l1 = x_dims_vec.end() - 2;
std::vector<int64_t> x_dims_vec_cut(f1, l1);
std::vector<int64_t>::const_iterator f2 = y_dims_vec.begin();
std::vector<int64_t>::const_iterator l2 = y_dims_vec.end() - 2;
std::vector<int64_t> y_dims_vec_cut(f2, l2);
std::vector<int64_t> expand_batch_portion =
MatrixGetBroadcastBatchPortion(x_dims_vec_cut, y_dims_vec_cut);
std::vector<int64_t> x_expand_size({expand_batch_portion});
x_expand_size.insert(x_expand_size.end(),
{x_dims_vec[static_cast<int>(x_dims_vec.size()) - 2],
x_dims_vec[static_cast<int>(x_dims_vec.size()) - 1]});
std::vector<int64_t> y_expand_size({expand_batch_portion});
y_expand_size.insert(y_expand_size.end(),
{y_dims_vec[static_cast<int>(y_dims_vec.size()) - 2],
y_dims_vec[static_cast<int>(y_dims_vec.size()) - 1]});
return std::make_tuple(x_expand_size, y_expand_size);
}
inline DDim GetOutputDims(const DDim &s_dims, const DDim &l_dims) { inline DDim GetOutputDims(const DDim &s_dims, const DDim &l_dims) {
if (s_dims.size() > l_dims.size()) { if (s_dims.size() > l_dims.size()) {
return GetOutputDims(l_dims, s_dims); return GetOutputDims(l_dims, s_dims);
......
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/funcs/matrix_reduce.h"
#include "paddle/phi/kernels/cpu/reduce.h"
#include "paddle/phi/kernels/funcs/reduce_functor.h"
namespace phi {
namespace funcs {
template <typename T>
class MatrixReduceSumFunctor<T, CPUContext> {
public:
void operator()(const CPUContext& dev_ctx,
const DenseTensor& in,
DenseTensor* out) {
// For example: in's dim = [5, 3, 2, 7, 3] ; out's dim = [3, 1, 7, 3]
// out_reduce_dim should be [0, 2]
const std::vector<int64_t> in_dims = phi::vectorize<int64_t>(in.dims());
auto in_size = in_dims.size();
const std::vector<int64_t> out_dims = phi::vectorize<int64_t>(out->dims());
auto out_size = out_dims.size();
std::vector<int64_t> out_bst_dims(in_size);
std::fill(out_bst_dims.data(), out_bst_dims.data() + in_size - out_size, 1);
std::copy(out_dims.data(),
out_dims.data() + out_size,
out_bst_dims.data() + in_size - out_size);
out->Resize(phi::make_ddim(out_bst_dims));
std::vector<int64_t> out_reduce_dims;
for (size_t idx = 0; idx <= in_size - 3; idx++) {
if (in_dims[idx] != 1 && out_bst_dims[idx] == 1) {
out_reduce_dims.push_back(idx);
}
}
phi::ReduceKernelImpl<CPUContext, T, T, phi::funcs::SumFunctor>(
dev_ctx, in, out, out_reduce_dims, true, false);
}
};
template class MatrixReduceSumFunctor<float, CPUContext>;
template class MatrixReduceSumFunctor<double, CPUContext>;
} // namespace funcs
} // namespace phi
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/funcs/matrix_reduce.h"
#include "paddle/phi/kernels/funcs/reduce_function.h"
namespace phi {
namespace funcs {
template <typename T>
class MatrixReduceSumFunctor<T, GPUContext> {
public:
void operator()(const GPUContext& dev_ctx,
const DenseTensor& in,
DenseTensor* out) {
// For example: in's dim = [5, 3, 2, 7, 3] ; out's dim = [3, 1, 7, 3]
// out_reduce_dim should be [0, 2]
const std::vector<int> in_dims = phi::vectorize<int>(in.dims());
auto in_size = in_dims.size();
const std::vector<int> out_dims = phi::vectorize<int>(out->dims());
auto out_size = out_dims.size();
std::vector<int> out_bst_dims(in_size);
std::fill(out_bst_dims.data(), out_bst_dims.data() + in_size - out_size, 1);
std::copy(out_dims.data(),
out_dims.data() + out_size,
out_bst_dims.data() + in_size - out_size);
out->Resize(phi::make_ddim(out_bst_dims));
std::vector<int> out_reduce_dims;
for (size_t idx = 0; idx <= in_size - 3; idx++) {
if (in_dims[idx] != 1 && out_bst_dims[idx] == 1) {
out_reduce_dims.push_back(idx);
}
}
TensorReduceImpl<T, T, kps::AddFunctor, kps::IdentityFunctor<T>>(
dev_ctx,
in,
out,
kps::IdentityFunctor<T>(),
out_reduce_dims,
dev_ctx.stream());
}
};
template class MatrixReduceSumFunctor<float, GPUContext>;
template class MatrixReduceSumFunctor<double, GPUContext>;
} // namespace funcs
} // namespace phi
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/phi/core/dense_tensor.h"
namespace phi {
namespace funcs {
// Use For Matrix OP, reduce_sum 'in' according to out's dim
// for example: in's dim = [5, 3, 2, M, N] ; out's dim = [3, 1, M, N]
// axis [0, 2] of DenseTensor 'in' will be reduced
template <typename T, typename Context>
class MatrixReduceSumFunctor {
public:
void operator()(const Context& dev_ctx,
const DenseTensor& in,
DenseTensor* out);
};
} // namespace funcs
} // namespace phi
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/impl/triangular_solve_grad_kernel_impl.h"
PD_REGISTER_KERNEL(triangular_solve_grad,
GPU,
ALL_LAYOUT,
phi::TriangularSolveGradKernel,
float,
double) {}
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/triangular_solve_kernel.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/core/ddim.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/empty_kernel.h"
#include "paddle/phi/kernels/expand_kernel.h"
#include "paddle/phi/kernels/funcs/blas/blas.h"
#include "paddle/phi/kernels/funcs/common_shape.h"
// See Note [ Why still include the fluid headers? ]
#include "paddle/fluid/memory/allocation/allocator.h"
#include "paddle/fluid/memory/memory.h"
namespace phi {
template <typename T, typename Context>
void TriangularSolveKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& y,
bool upper,
bool transpose,
bool unitriangular,
DenseTensor* out) {
// get broadcast dim
std::vector<int64_t> x_bst_dims_vec;
std::vector<int64_t> y_bst_dims_vec;
std::tie(x_bst_dims_vec, y_bst_dims_vec) =
funcs::MatrixGetBroadcastDims(x, y);
int x_bst_ndim = x_bst_dims_vec.size();
int y_bst_ndim = y_bst_dims_vec.size();
// Tensor broadcast to 'out' and temp 'x_bst'
ScalarArray x_bst_dims(x_bst_dims_vec);
DenseTensor x_bst = phi::Empty<T, Context>(dev_ctx, x_bst_dims);
const T* x_bst_data = x_bst.data<T>();
ExpandKernel<T, Context>(dev_ctx, x, x_bst_dims, &x_bst);
out->Resize(phi::make_ddim(y_bst_dims_vec));
T* out_data = dev_ctx.template Alloc<T>(out);
ScalarArray y_bst_dims(y_bst_dims_vec);
ExpandKernel<T, Context>(dev_ctx, y, y_bst_dims, out);
// calculate use cublas library
CBLAS_UPLO uplo = upper ? CblasUpper : CblasLower;
CBLAS_TRANSPOSE transA = transpose ? CblasTrans : CblasNoTrans;
CBLAS_DIAG diag = unitriangular ? CblasUnit : CblasNonUnit;
int M = static_cast<int>(y_bst_dims_vec[y_bst_ndim - 2]);
int N = static_cast<int>(y_bst_dims_vec[y_bst_ndim - 1]);
auto lda = std::max(1, M);
auto ldb = std::max(1, N);
int batch_size = 1;
for (int i = 0; i < x_bst_ndim - 2; i++) {
batch_size *= x_bst_dims_vec[i];
}
auto blas = phi::funcs::GetBlas<GPUContext, T>(dev_ctx);
if (batch_size <= 8 && M >= 64) {
for (auto i = 0; i < batch_size; i++) {
blas.TRSM(CblasLeft,
uplo,
transA,
diag,
M,
N,
T(1),
x_bst_data + i * M * M,
lda,
out_data + i * N * M,
ldb);
}
} else {
std::vector<const T*> cpu_ptrs(batch_size * 2);
for (int i = 0; i < batch_size; ++i) {
cpu_ptrs[i] = x_bst_data + i * M * M;
cpu_ptrs[i + batch_size] = out_data + i * M * N;
}
// Copy the addresses of A and tmp_b from host to device.
paddle::memory::allocation::AllocationPtr tmp_gpu_ptrs_data =
paddle::memory::Alloc(dev_ctx, cpu_ptrs.size() * sizeof(T*));
paddle::memory::Copy(dev_ctx.GetPlace(),
tmp_gpu_ptrs_data->ptr(),
paddle::platform::CPUPlace(),
static_cast<void*>(cpu_ptrs.data()),
cpu_ptrs.size() * sizeof(T*),
dev_ctx.stream());
const T** gpu_a_ptrs =
reinterpret_cast<const T**>(tmp_gpu_ptrs_data->ptr());
T** gpu_b_ptrs =
reinterpret_cast<T**>(tmp_gpu_ptrs_data->ptr()) + batch_size;
blas.BatchedTRSM(CblasLeft,
uplo,
transA,
diag,
M,
N,
static_cast<T>(1.0),
gpu_a_ptrs,
lda,
gpu_b_ptrs,
ldb,
batch_size);
}
}
} // namespace phi
PD_REGISTER_KERNEL(triangular_solve,
GPU,
ALL_LAYOUT,
phi::TriangularSolveKernel,
float,
double) {}
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/phi/kernels/triangular_solve_grad_kernel.h"
#include "paddle/phi/kernels/copy_kernel.h"
#include "paddle/phi/kernels/empty_kernel.h"
#include "paddle/phi/kernels/funcs/blas/blas.h"
#include "paddle/phi/kernels/funcs/common_shape.h"
#include "paddle/phi/kernels/funcs/complex_functors.h"
#include "paddle/phi/kernels/funcs/matrix_reduce.h"
#include "paddle/phi/kernels/triangular_solve_kernel.h"
// See Note [ Why still include the fluid headers? ]
#include "paddle/fluid/operators/tril_triu_op.h"
namespace phi {
template <typename T, typename Context>
void TriangularSolveGradKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& y,
const DenseTensor& out,
const DenseTensor& dout,
bool upper,
bool transpose,
bool unitriangular,
DenseTensor* dx,
DenseTensor* dy) {
std::vector<int64_t> x_bst_dims_vec;
std::vector<int64_t> y_bst_dims_vec;
std::tie(x_bst_dims_vec, y_bst_dims_vec) =
funcs::MatrixGetBroadcastDims(x, y);
ScalarArray y_bst_dims_array(y_bst_dims_vec);
DenseTensor dy_bst = phi::Empty<T, Context>(dev_ctx, y_bst_dims_array);
if (dy) {
// calculate x's conjugate for complex
DenseTensor x_conj = phi::Empty<T, Context>(dev_ctx);
x_conj.Resize(x.dims());
phi::funcs::ForRange<Context> x_for_range(dev_ctx, x.numel());
phi::funcs::ConjFunctor<T> x_functor(
x.data<T>(), x.numel(), dev_ctx.template Alloc<T>(&x_conj));
x_for_range(x_functor);
// reuse forward to get dy_bst, and the result has been broadcated already.
TriangularSolveKernel<T, Context>(
dev_ctx, x_conj, dout, upper, !transpose, unitriangular, &dy_bst);
dy->Resize(y.dims());
dev_ctx.template Alloc<T>(dy);
if (dy_bst.dims() == y.dims()) {
Copy<Context>(dev_ctx, dy_bst, dev_ctx.GetPlace(), false, dy);
} else {
funcs::MatrixReduceSumFunctor<T, Context> functor;
functor(dev_ctx, dy_bst, dy);
dy->Resize(y.dims());
}
}
ScalarArray x_bst_dims_array(x_bst_dims_vec);
DenseTensor dx_bst = phi::Empty<T, Context>(dev_ctx, x_bst_dims_array);
if (dx) {
// calculate x's conjugate for complex
DenseTensor out_conj = phi::Empty<T, Context>(dev_ctx);
out_conj.Resize(out.dims());
phi::funcs::ForRange<Context> out_for_range(dev_ctx, out.numel());
phi::funcs::ConjFunctor<T> out_functor(
out.data<T>(), out.numel(), dev_ctx.template Alloc<T>(&out_conj));
out_for_range(out_functor);
auto blas = phi::funcs::GetBlas<Context, T>(dev_ctx);
if (transpose) {
auto mat_dim_a =
phi::funcs::CreateMatrixDescriptor(out_conj.dims(), 0, false);
auto mat_dim_b =
phi::funcs::CreateMatrixDescriptor(dy_bst.dims(), 0, true);
blas.MatMul(out_conj,
mat_dim_a,
dy_bst,
mat_dim_b,
static_cast<T>(-1),
&dx_bst,
static_cast<T>(0));
} else {
auto mat_dim_a =
phi::funcs::CreateMatrixDescriptor(dy_bst.dims(), 0, false);
auto mat_dim_b =
phi::funcs::CreateMatrixDescriptor(out_conj.dims(), 0, true);
blas.MatMul(dy_bst,
mat_dim_a,
out_conj,
mat_dim_b,
static_cast<T>(-1),
&dx_bst,
static_cast<T>(0));
}
// get upper or lower triangular
DenseTensor dx_bst_upper =
phi::Empty<T, Context>(dev_ctx, x_bst_dims_array);
const auto& dims = dx_bst.dims();
const auto H = dims[dims.size() - 2];
const auto W = dims[dims.size() - 1];
phi::funcs::ForRange<Context> x_for_range(dev_ctx, dx_bst.numel());
paddle::operators::TrilTriuCompute<T> tril_triu_functor(
dx_bst.data<T>(), unitriangular, !upper, H, W, dx_bst_upper.data<T>());
x_for_range(tril_triu_functor);
dx->Resize(x.dims());
dev_ctx.template Alloc<T>(dx);
if (dx_bst.dims() == x.dims()) {
Copy<Context>(dev_ctx, dx_bst_upper, dev_ctx.GetPlace(), false, dx);
} else {
funcs::MatrixReduceSumFunctor<T, Context> functor;
functor(dev_ctx, dx_bst_upper, dx);
dx->Resize(x.dims());
}
}
}
} // namespace phi
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/core/ddim.h"
#include "paddle/phi/core/dense_tensor.h"
namespace phi {
template <typename T, typename Context>
void TriangularSolveGradKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& y,
const DenseTensor& out,
const DenseTensor& dout,
bool upper,
bool transpose,
bool unitriangular,
DenseTensor* dx,
DenseTensor* dy);
} // namespace phi
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/phi/core/dense_tensor.h"
namespace phi {
template <typename T, typename Context>
void TriangularSolveKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& y,
bool upper,
bool transpose,
bool unitriangular,
DenseTensor* out);
} // namespace phi
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/core/compat/op_utils.h"
namespace phi {
KernelSignature TriangularSolveGradOpArgumentMapping(
const ArgumentMappingContext& ctx) {
return KernelSignature("triangular_solve_grad",
{"X", "Y", "Out", GradVarName("Out")},
{"upper", "transpose", "unitriangular"},
{GradVarName("X"), GradVarName("Y")});
}
} // namespace phi
PD_REGISTER_ARG_MAPPING_FN(triangular_solve_grad,
phi::TriangularSolveGradOpArgumentMapping);
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册