Unverified commit dce87e3d authored by caozhou, committed by GitHub

[Phi] Migrate multiplex, qr, tril_triu op kernel to phi (#40007)

* migrate multiplex op kernel

* migrate qr cpu kernel

* migrate tril_triu op kernel

* fix multiplex kernel

* add kernel sig

* fix dependency and bug

* fix multiplex error

* fix npu include error

* fix conflict

* fix conflict and delete tril_triu

* fix date and multiplex input

* adapt header file order

* fix header file include

* fix conflict

* delete cholesky_solve_op.h

* delete triangular_solve_op.h
Parent 517b1a7c
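For reference, the diff below follows the standard phi migration pattern: each fluid framework::OpKernel becomes a free function template in the phi namespace that takes a device context and DenseTensor arguments, allocates its outputs through that context, and is registered with PD_REGISTER_KERNEL, while a KernelSignature mapping connects the old operator's inputs and attributes to the new kernel. A minimal sketch of that pattern follows; the op name example_scale and its body are hypothetical, and only the structure mirrors the multiplex/qr/tril_triu kernels in this commit.

// Sketch only: "example_scale" is a hypothetical op used to illustrate the
// phi kernel structure; the real kernels added by this commit are
// phi::MultiplexKernel, phi::QrKernel and phi::TrilTriuKernel below.
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/kernel_registry.h"

namespace phi {

template <typename T, typename Context>
void ExampleScaleKernel(const Context& ctx,
                        const DenseTensor& x,
                        float scale,
                        DenseTensor* out) {
  const T* x_data = x.data<T>();
  // phi kernels allocate outputs via the device context instead of
  // calling out->mutable_data<T>(place) as the old fluid kernels do.
  T* out_data = ctx.template Alloc<T>(out);
  for (int64_t i = 0; i < x.numel(); ++i) {
    out_data[i] = static_cast<T>(scale) * x_data[i];
  }
}

}  // namespace phi

PD_REGISTER_KERNEL(
    example_scale, CPU, ALL_LAYOUT, phi::ExampleScaleKernel, float, double) {}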
@@ -18,8 +18,8 @@ limitations under the License. */
#include "paddle/fluid/framework/phi_utils.h"
#include "paddle/fluid/operators/set_value_op.h"
#include "paddle/fluid/operators/svd_helper.h"
#include "paddle/fluid/operators/tril_triu_op.h"
#include "paddle/phi/kernels/funcs/lapack/lapack_function.h"
#include "paddle/phi/kernels/funcs/tril_triu_compute.h"
#include "paddle/phi/kernels/math_kernel.h"
#include "paddle/phi/kernels/triangular_solve_kernel.h"
@@ -404,11 +404,12 @@ void LU_Unpack(const DeviceContext& dev_ctx, const framework::Tensor* LU,
const auto W = udims[udims.size() - 1];
auto L_dataptr = L->mutable_data<T>(dev_ctx.GetPlace());
platform::ForRange<DeviceContext> x_for_range(dev_ctx, LU->numel());
TrilTriuCompute<T> tril_computer(LU->data<T>(), -1, true, H, W, L_dataptr);
phi::funcs::TrilTriuCompute<T> tril_computer(LU->data<T>(), -1, true, H, W,
L_dataptr);
x_for_range(tril_computer);
TrilTriuCompute<T> triu_computer(LU->data<T>(), 0, false, H, W,
U->mutable_data<T>(dev_ctx.GetPlace()));
phi::funcs::TrilTriuCompute<T> triu_computer(
LU->data<T>(), 0, false, H, W, U->mutable_data<T>(dev_ctx.GetPlace()));
x_for_range(triu_computer);
// set L's diagonal 1
@@ -532,15 +533,15 @@ class LUGradKernel : public framework::OpKernel<T> {
auto phil_rank = LmHdims.size();
auto phiu_rank = UmHdims.size();
platform::ForRange<DeviceContext> l_for_range(dev_ctx, phi_L.numel());
TrilTriuCompute<T> tril_computer(phi_L.data<T>(), -1, true,
LmHdims[phil_rank - 2],
LmHdims[phil_rank - 1], phi_L.data<T>());
phi::funcs::TrilTriuCompute<T> tril_computer(
phi_L.data<T>(), -1, true, LmHdims[phil_rank - 2],
LmHdims[phil_rank - 1], phi_L.data<T>());
l_for_range(tril_computer);
platform::ForRange<DeviceContext> u_for_range(dev_ctx, phi_U.numel());
TrilTriuCompute<T> triu_computer(phi_U.data<T>(), 0, false,
UmHdims[phiu_rank - 2],
UmHdims[phiu_rank - 1], phi_U.data<T>());
phi::funcs::TrilTriuCompute<T> triu_computer(
phi_U.data<T>(), 0, false, UmHdims[phiu_rank - 2],
UmHdims[phiu_rank - 1], phi_U.data<T>());
u_for_range(triu_computer);
Tensor_Add<DeviceContext, T>(dev_ctx, phi_L, phi_U, &phi);
@@ -591,8 +592,9 @@ class LUGradKernel : public framework::OpKernel<T> {
const auto W = phidims[phidims.size() - 1];
platform::ForRange<DeviceContext> x_for_range(dev_ctx,
phi_complement.numel());
TrilTriuCompute<T> tril_computer(phi_complement.data<T>(), -1, true, H,
W, phi_complement_l.data<T>());
phi::funcs::TrilTriuCompute<T> tril_computer(
phi_complement.data<T>(), -1, true, H, W,
phi_complement_l.data<T>());
x_for_range(tril_computer);
Tensor_Sub<DeviceContext, T>(dev_ctx, phi, phi_complement_l, &phi);
@@ -664,8 +666,8 @@ class LUGradKernel : public framework::OpKernel<T> {
const auto W = phidims[phidims.size() - 1];
platform::ForRange<DeviceContext> x_for_range(dev_ctx,
phi_complement.numel());
TrilTriuCompute<T> triu_computer(phi_complement.data<T>(), 0, false, H, W,
phi_complement_u.data<T>());
phi::funcs::TrilTriuCompute<T> triu_computer(
phi_complement.data<T>(), 0, false, H, W, phi_complement_u.data<T>());
x_for_range(triu_computer);
Tensor_Sub<DeviceContext, T>(dev_ctx, phi, phi_complement_u, &phi);
......
@@ -16,7 +16,8 @@ limitations under the License. */
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/lu_op.h"
#include "paddle/fluid/operators/tril_triu_op.h"
#include "paddle/fluid/platform/for_range.h"
#include "paddle/phi/kernels/funcs/tril_triu_compute.h"
namespace paddle {
namespace operators {
@@ -87,7 +88,8 @@ class LU_UnpackGradKernel : public framework::OpKernel<T> {
auto W = ldims[ldims.size() - 1];
auto L_dataptr = dl_tril.mutable_data<T>(dev_ctx.GetPlace());
platform::ForRange<DeviceContext> l_for_range(dev_ctx, dl->numel());
TrilTriuCompute<T> tril_computer(dl->data<T>(), -1, true, H, W, L_dataptr);
phi::funcs::TrilTriuCompute<T> tril_computer(dl->data<T>(), -1, true, H, W,
L_dataptr);
l_for_range(tril_computer);
const auto udims = du->dims();
@@ -96,7 +98,8 @@ class LU_UnpackGradKernel : public framework::OpKernel<T> {
W = udims[udims.size() - 1];
auto U_dataptr = du_triu.mutable_data<T>(dev_ctx.GetPlace());
platform::ForRange<DeviceContext> u_for_range(dev_ctx, du->numel());
TrilTriuCompute<T> triu_computer(du->data<T>(), 0, false, H, W, U_dataptr);
phi::funcs::TrilTriuCompute<T> triu_computer(du->data<T>(), 0, false, H, W,
U_dataptr);
u_for_range(triu_computer);
auto xdims = dx->dims();
......
@@ -12,9 +12,9 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/multiplex_op.h"
#include <memory>
#include <vector>
#include "paddle/fluid/framework/op_registry.h"
namespace paddle {
namespace operators {
@@ -169,15 +169,3 @@ REGISTER_OPERATOR(multiplex, ops::MultiplexOp, ops::MultiplexOpMaker,
ops::MultiplexGradMaker<paddle::framework::OpDesc>,
ops::MultiplexGradMaker<paddle::imperative::OpBase>);
REGISTER_OPERATOR(multiplex_grad, ops::MultiplexGradOp);
REGISTER_OP_CPU_KERNEL(
multiplex,
ops::MultiplexCPUKernel<paddle::platform::CPUDeviceContext, float>,
ops::MultiplexCPUKernel<paddle::platform::CPUDeviceContext, double>,
ops::MultiplexCPUKernel<paddle::platform::CPUDeviceContext, int>,
ops::MultiplexCPUKernel<paddle::platform::CPUDeviceContext, int64_t>);
REGISTER_OP_CPU_KERNEL(
multiplex_grad,
ops::MultiplexGradCPUKernel<paddle::platform::CPUDeviceContext, float>,
ops::MultiplexGradCPUKernel<paddle::platform::CPUDeviceContext, double>,
ops::MultiplexGradCPUKernel<paddle::platform::CPUDeviceContext, int>,
ops::MultiplexGradCPUKernel<paddle::platform::CPUDeviceContext, int64_t>);
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/multiplex_op.h"
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
template <typename Place, typename T>
class MultiplexGPUKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const {
auto ins = ctx.MultiInput<Tensor>("X");
auto* ids = ctx.Input<Tensor>("Ids");
auto* out = ctx.Output<Tensor>("Out");
out->mutable_data<T>(ctx.GetPlace());
for (size_t i = 0; i < ins.size(); ++i) {
PADDLE_ENFORCE_GT(
ins[i]->numel(), 0,
platform::errors::OutOfRange(
"indexing will be out of bounds with size 0 for the %d-th input.",
i));
}
auto rows = ins[0]->dims()[0];
auto cols = ins[0]->numel() / rows;
// copy index to cpu
Tensor index_t_cpu;
paddle::framework::TensorCopySync(*ids, platform::CPUPlace(), &index_t_cpu);
auto* index = index_t_cpu.data<int32_t>();
auto stream = ctx.cuda_device_context().stream();
platform::CUDAPlace place = ctx.GetPlace();
for (auto i = 0; i < rows; i++) {
int32_t k = index[i];
PADDLE_ENFORCE_GE(k, 0, platform::errors::PreconditionNotMet(
"index must be nonnegative."));
PADDLE_ENFORCE_LT(static_cast<size_t>(k), ins.size(),
platform::errors::PreconditionNotMet(
"index exceeds the number of candidate tensors."));
memory::Copy(place, out->data<T>() + i * cols, place,
ins[k]->data<T>() + i * cols, cols * sizeof(T), stream);
}
}
};
template <typename Place, typename T>
class MultiplexGradGPUKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const {
auto* d_out = ctx.Input<Tensor>(framework::GradVarName("Out"));
auto* ids = ctx.Input<Tensor>("Ids");
auto d_ins = ctx.MultiOutput<Tensor>(framework::GradVarName("X"));
size_t idx = -1UL;
for (size_t i = 0; i < d_ins.size(); i++) {
if (d_ins[i]) {
d_ins[i]->mutable_data<T>(ctx.GetPlace());
auto t = framework::EigenVector<T>::Flatten(*d_ins[i]);
t.device(*ctx.template device_context<Place>().eigen_device()) =
t.constant(static_cast<T>(0));
idx = i;
}
}
if (idx == -1UL) return;
auto rows = d_ins[idx]->dims()[0];
auto cols = d_ins[idx]->numel() / rows;
// copy index to cpu
Tensor index_t_cpu;
paddle::framework::TensorCopySync(*ids, platform::CPUPlace(), &index_t_cpu);
auto* index = index_t_cpu.data<int32_t>();
auto stream = ctx.cuda_device_context().stream();
platform::CUDAPlace place = ctx.GetPlace();
for (auto i = 0; i < rows; i++) {
size_t k = static_cast<size_t>(index[i]);
if (d_ins[k]) {
memory::Copy(place, d_ins[k]->data<T>() + i * cols, place,
d_out->data<T>() + i * cols, cols * sizeof(T), stream);
}
}
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP_CUDA_KERNEL(
multiplex,
ops::MultiplexGPUKernel<paddle::platform::CUDADeviceContext, float>,
ops::MultiplexGPUKernel<paddle::platform::CUDADeviceContext, double>,
ops::MultiplexGPUKernel<paddle::platform::CUDADeviceContext, int>,
ops::MultiplexGPUKernel<paddle::platform::CUDADeviceContext, int64_t>);
REGISTER_OP_CUDA_KERNEL(
multiplex_grad,
ops::MultiplexGradGPUKernel<paddle::platform::CUDADeviceContext, float>,
ops::MultiplexGradGPUKernel<paddle::platform::CUDADeviceContext, double>,
ops::MultiplexGradGPUKernel<paddle::platform::CUDADeviceContext, int>,
ops::MultiplexGradGPUKernel<paddle::platform::CUDADeviceContext, int64_t>);
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/memory/memcpy.h"
namespace paddle {
namespace operators {
template <typename DeviceContext, typename T>
class MultiplexCPUKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const {
auto ins = ctx.MultiInput<framework::Tensor>("X");
auto ids = ctx.Input<framework::Tensor>("Ids");
auto* out = ctx.Output<framework::Tensor>("Out");
out->mutable_data<T>(ctx.GetPlace());
for (size_t i = 0; i < ins.size(); ++i) {
PADDLE_ENFORCE_GT(
ins[i]->numel(), 0,
platform::errors::OutOfRange(
"indexing will be out of bounds with size 0 for the %d-th input.",
i));
}
auto rows = ins[0]->dims()[0];
auto cols = ins[0]->numel() / rows;
auto index = ids->data<int32_t>();
platform::CPUPlace place = ctx.GetPlace();
for (auto i = 0; i < rows; i++) {
int32_t k = index[i];
PADDLE_ENFORCE_GE(k, 0, platform::errors::PreconditionNotMet(
"index must be nonnegative."));
PADDLE_ENFORCE_LT(static_cast<size_t>(k), ins.size(),
platform::errors::PreconditionNotMet(
"index exceeds the number of candidate tensors."));
memory::Copy(place, out->data<T>() + i * cols, place,
ins[k]->data<T>() + i * cols, cols * sizeof(T));
}
}
};
template <typename DeviceContext, typename T>
class MultiplexGradCPUKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const {
auto* d_out = ctx.Input<framework::Tensor>(framework::GradVarName("Out"));
auto* ids = ctx.Input<framework::Tensor>("Ids");
auto d_ins =
ctx.MultiOutput<framework::Tensor>(framework::GradVarName("X"));
size_t idx = -1UL;
for (size_t i = 0; i < d_ins.size(); i++) {
if (d_ins[i]) {
d_ins[i]->mutable_data<T>(ctx.GetPlace());
auto t = framework::EigenVector<T>::Flatten(*d_ins[i]);
t.device(*ctx.template device_context<DeviceContext>().eigen_device()) =
t.constant(static_cast<T>(0));
idx = i;
}
}
if (idx == -1UL) return;
auto rows = d_ins[idx]->dims()[0];
auto cols = d_ins[idx]->numel() / rows;
auto* index = ids->data<int32_t>();
platform::CPUPlace place = ctx.GetPlace();
for (auto i = 0; i < rows; i++) {
size_t k = static_cast<size_t>(index[i]);
if (d_ins[k]) {
memory::Copy(place, d_ins[k]->data<T>() + i * cols, place,
d_out->data<T>() + i * cols, cols * sizeof(T));
}
}
}
};
} // namespace operators
} // namespace paddle
@@ -145,8 +145,6 @@ REGISTER_OPERATOR(qr, ops::QrOp, ops::QrOpMaker,
REGISTER_OPERATOR(qr_grad, ops::QrGradOp);
REGISTER_OP_CPU_KERNEL(qr, ops::QrCPUKernel<float>, ops::QrCPUKernel<double>);
REGISTER_OP_CPU_KERNEL(
qr_grad, ops::QrGradKernel<paddle::platform::CPUDeviceContext, float>,
ops::QrGradKernel<paddle::platform::CPUDeviceContext, double>);
@@ -48,85 +48,6 @@ static inline std::tuple<bool, bool> _parse_qr_mode(std::string mode) {
return std::make_tuple(compute_q, reduced);
}
template <typename T>
class QrCPUKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
bool compute_q;
bool reduced_mode;
const Tensor& x = *context.Input<Tensor>("X");
Tensor& q = *context.Output<Tensor>("Q");
Tensor& r = *context.Output<Tensor>("R");
std::string mode = context.Attr<std::string>("mode");
std::tie(compute_q, reduced_mode) = _parse_qr_mode(mode);
auto numel = x.numel();
PADDLE_ENFORCE_GT(numel, 0, platform::errors::PreconditionNotMet(
"The input of QR is empty."));
auto x_dims = x.dims();
int x_rank = x_dims.size();
int m = x_dims[x_rank - 2];
int n = x_dims[x_rank - 1];
int min_mn = std::min(m, n);
int k = reduced_mode ? min_mn : m;
int batch_size = numel / (m * n);
int x_stride = m * n;
int q_stride = m * k;
int r_stride = k * n;
auto* x_data = x.data<phi::dtype::Real<T>>();
T* q_data = nullptr;
if (compute_q) {
q_data = q.mutable_data<phi::dtype::Real<T>>(
context.GetPlace(),
size_t(batch_size * m * k * sizeof(phi::dtype::Real<T>)));
memset(q_data, 0,
size_t(batch_size * m * k * sizeof(phi::dtype::Real<T>)));
}
auto* r_data = r.mutable_data<phi::dtype::Real<T>>(
context.GetPlace(),
size_t(batch_size * k * n * sizeof(phi::dtype::Real<T>)));
memset(r_data, 0, size_t(batch_size * k * n * sizeof(phi::dtype::Real<T>)));
// Implement QR by calling Eigen
for (int i = 0; i < batch_size; ++i) {
const T* x_matrix_ptr = x_data + i * x_stride;
T* r_matrix_ptr = r_data + i * r_stride;
using EigenDynamicMatrix =
Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>;
auto x_matrix = Eigen::Map<const EigenDynamicMatrix>(x_matrix_ptr, m, n);
Eigen::HouseholderQR<EigenDynamicMatrix> qr(x_matrix);
if (reduced_mode) {
auto qr_top_matrix = qr.matrixQR().block(0, 0, min_mn, n);
auto r_matrix_view =
qr_top_matrix.template triangularView<Eigen::Upper>();
auto r_matrix = EigenDynamicMatrix(r_matrix_view);
memcpy(r_matrix_ptr, r_matrix.data(), r_matrix.size() * sizeof(T));
} else {
auto r_matrix_view =
qr.matrixQR().template triangularView<Eigen::Upper>();
auto r_matrix = EigenDynamicMatrix(r_matrix_view);
memcpy(r_matrix_ptr, r_matrix.data(), r_matrix.size() * sizeof(T));
}
if (compute_q) {
T* q_matrix_ptr = q_data + i * q_stride;
if (reduced_mode) {
auto q_matrix =
qr.householderQ() * EigenDynamicMatrix::Identity(m, min_mn);
q_matrix.transposeInPlace();
memcpy(q_matrix_ptr, q_matrix.data(), q_matrix.size() * sizeof(T));
} else {
auto q_matrix =
qr.householderQ() * EigenDynamicMatrix::Identity(m, m);
q_matrix.transposeInPlace();
memcpy(q_matrix_ptr, q_matrix.data(), q_matrix.size() * sizeof(T));
}
}
}
}
};
template <typename DeviceContext, typename T>
class QrGradKernel : public framework::OpKernel<T> {
public:
......
@@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/tril_triu_op.h"
#include <memory>
#include "paddle/fluid/framework/op_registry.h"
namespace paddle {
namespace operators {
@@ -104,19 +104,3 @@ REGISTER_OPERATOR(tril_triu, ops::TrilTriuOp, ops::TrilTriuOpMaker,
ops::TrilTriuGradOpMaker<paddle::framework::OpDesc>,
ops::TrilTriuGradOpMaker<paddle::imperative::OpBase>);
REGISTER_OPERATOR(tril_triu_grad, ops::TrilTriuGradOp);
REGISTER_OP_CPU_KERNEL(
tril_triu, ops::TrilTriuOpKernel<paddle::platform::CPUDeviceContext, bool>,
ops::TrilTriuOpKernel<paddle::platform::CPUDeviceContext, float>,
ops::TrilTriuOpKernel<paddle::platform::CPUDeviceContext, double>,
ops::TrilTriuOpKernel<paddle::platform::CPUDeviceContext, int>,
ops::TrilTriuOpKernel<paddle::platform::CPUDeviceContext, int64_t>,
ops::TrilTriuOpKernel<paddle::platform::CPUDeviceContext, plat::float16>);
REGISTER_OP_CPU_KERNEL(
tril_triu_grad,
ops::TrilTriuGradOpKernel<paddle::platform::CPUDeviceContext, bool>,
ops::TrilTriuGradOpKernel<paddle::platform::CPUDeviceContext, float>,
ops::TrilTriuGradOpKernel<paddle::platform::CPUDeviceContext, double>,
ops::TrilTriuGradOpKernel<paddle::platform::CPUDeviceContext, int>,
ops::TrilTriuGradOpKernel<paddle::platform::CPUDeviceContext, int64_t>,
ops::TrilTriuGradOpKernel<paddle::platform::CPUDeviceContext,
plat::float16>);
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/platform/float16.h"
#include "paddle/fluid/platform/for_range.h"
namespace paddle {
namespace operators {
template <typename T>
class TrilTriuCompute {
public:
HOSTDEVICE TrilTriuCompute(const T* in, const int diagonal, const bool lower,
const int64_t H, const int64_t W, T* out)
: in_(in), diagonal_(diagonal), lower_(lower), H_(H), W_(W), out_(out) {}
HOSTDEVICE void operator()(int64_t idx) {
const int64_t row = (idx / W_) % H_;
const int64_t col = idx % W_;
const bool mask =
lower_ ? (col - row > diagonal_) : (col - row < diagonal_);
out_[idx] = mask ? static_cast<T>(0) : in_[idx];
}
private:
const T* in_;
const int diagonal_;
const bool lower_;
const int64_t H_;
const int64_t W_;
T* out_;
};
template <typename DeviceContext, typename T>
class TrilTriuOpKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
const auto* x = context.Input<framework::Tensor>("X");
const auto* x_data = x->data<T>();
auto* out = context.Output<framework::Tensor>("Out");
auto* out_data = out->mutable_data<T>(context.GetPlace());
const int diagonal = context.Attr<int>("diagonal");
const bool lower = context.Attr<bool>("lower");
const auto& dims = x->dims();
const auto H = dims[dims.size() - 2];
const auto W = dims[dims.size() - 1];
platform::ForRange<DeviceContext> for_range(
context.template device_context<DeviceContext>(),
static_cast<size_t>(x->numel()));
paddle::operators::TrilTriuCompute<T> tril_triu_computer(
x_data, diagonal, lower, H, W, out_data);
for_range(tril_triu_computer);
}
};
template <typename DeviceContext, typename T>
class TrilTriuGradOpKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
const auto* d_out =
context.Input<framework::Tensor>(framework::GradVarName("Out"));
const auto* dout_data = d_out->data<T>();
auto* d_x = context.Output<framework::Tensor>(framework::GradVarName("X"));
auto* dx_data = d_x->mutable_data<T>(context.GetPlace());
const int diagonal = context.Attr<int>("diagonal");
const bool lower = context.Attr<bool>("lower");
const auto& dims = d_out->dims();
const auto H = dims[dims.size() - 2];
const auto W = dims[dims.size() - 1];
platform::ForRange<DeviceContext> for_range(
context.template device_context<DeviceContext>(),
static_cast<size_t>(d_out->numel()));
paddle::operators::TrilTriuCompute<T> tril_triu_grad_computer(
dout_data, diagonal, lower, H, W, dx_data);
for_range(tril_triu_grad_computer);
}
};
} // namespace operators
} // namespace paddle
@@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/tril_triu_op.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/platform/device/npu/npu_op_runner.h"
namespace paddle {
......
@@ -11,7 +11,7 @@ limitations under the License. */
#ifdef PADDLE_WITH_XPU
#include "paddle/fluid/operators/tril_triu_op.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/platform/device/device_wrapper.h"
namespace paddle {
......
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/multiplex_grad_kernel.h"
#include "paddle/fluid/memory/memcpy.h"
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/eigen/common.h"
namespace phi {
template <typename T, typename Context>
void MultiplexGradKernel(const Context& ctx,
const DenseTensor& ids,
const DenseTensor& out_grad,
std::vector<DenseTensor*> ins_grad) {
size_t idx = -1UL;
for (size_t i = 0; i < ins_grad.size(); i++) {
if (ins_grad[i]) {
ctx.template Alloc<T>(ins_grad[i]);
auto t = phi::EigenVector<T>::Flatten(*ins_grad[i]);
t.device(*ctx.eigen_device()) = t.constant(static_cast<T>(0));
idx = i;
}
}
if (idx == -1UL) return;
auto rows = ins_grad[idx]->dims()[0];
auto cols = ins_grad[idx]->numel() / rows;
auto* index = ids.data<int32_t>();
for (auto i = 0; i < rows; i++) {
size_t k = static_cast<size_t>(index[i]);
if (ins_grad[k]) {
paddle::memory::Copy(ctx.GetPlace(),
ins_grad[k]->data<T>() + i * cols,
ctx.GetPlace(),
out_grad.data<T>() + i * cols,
cols * sizeof(T));
}
}
}
} // namespace phi
PD_REGISTER_KERNEL(multiplex_grad,
CPU,
ALL_LAYOUT,
phi::MultiplexGradKernel,
float,
double,
int,
int64_t) {}
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/multiplex_kernel.h"
#include "paddle/fluid/memory/memcpy.h"
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
namespace phi {
template <typename T, typename Context>
void MultiplexKernel(const Context& ctx,
const std::vector<const DenseTensor*>& ins,
const DenseTensor& ids,
DenseTensor* out) {
ctx.template Alloc<T>(out);
for (size_t i = 0; i < ins.size(); ++i) {
PADDLE_ENFORCE_GT(
ins[i]->numel(),
0,
errors::OutOfRange(
"indexing will be out of bounds with size 0 for the %d-th input.",
i));
}
auto rows = ins[0]->dims()[0];
auto cols = ins[0]->numel() / rows;
auto index = ids.data<int32_t>();
for (auto i = 0; i < rows; i++) {
int32_t k = index[i];
PADDLE_ENFORCE_GE(
k, 0, errors::PreconditionNotMet("index must be nonnegative."));
PADDLE_ENFORCE_LT(static_cast<size_t>(k),
ins.size(),
errors::PreconditionNotMet(
"index exceeds the number of candidate tensors."));
paddle::memory::Copy(ctx.GetPlace(),
out->data<T>() + i * cols,
ctx.GetPlace(),
ins[k]->data<T>() + i * cols,
cols * sizeof(T));
}
}
} // namespace phi
PD_REGISTER_KERNEL(multiplex,
CPU,
ALL_LAYOUT,
phi::MultiplexKernel,
float,
double,
int,
int64_t) {}
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <Eigen/Dense>
#include "paddle/phi/kernels/qr_kernel.h"
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/complex_functors.h"
namespace phi {
static inline std::tuple<bool, bool> ParseQrMode(const std::string& mode) {
bool compute_q;
bool reduced;
if (mode == "reduced") {
compute_q = true;
reduced = true;
} else if (mode == "complete") {
compute_q = true;
reduced = false;
} else if (mode == "r") {
compute_q = false;
reduced = true;
} else {
PADDLE_THROW(errors::InvalidArgument(
"QR received unrecognized mode '%s'"
" but expected one of 'reduced' (default), 'r', or 'complete'",
mode));
}
return std::make_tuple(compute_q, reduced);
}
template <typename T, typename Context>
void QrKernel(const Context& ctx,
const DenseTensor& x,
const std::string& mode,
DenseTensor* q,
DenseTensor* r) {
bool compute_q;
bool reduced_mode;
std::tie(compute_q, reduced_mode) = ParseQrMode(mode);
auto numel = x.numel();
PADDLE_ENFORCE_GT(
numel, 0, errors::PreconditionNotMet("The input of QR is empty."));
auto x_dims = x.dims();
int x_rank = x_dims.size();
int m = x_dims[x_rank - 2];
int n = x_dims[x_rank - 1];
int min_mn = std::min(m, n);
int k = reduced_mode ? min_mn : m;
int batch_size = numel / (m * n);
int x_stride = m * n;
int q_stride = m * k;
int r_stride = k * n;
auto* x_data = x.data<phi::dtype::Real<T>>();
T* q_data = nullptr;
if (compute_q) {
q_data = ctx.template Alloc<phi::dtype::Real<T>>(
q, batch_size * m * k * sizeof(phi::dtype::Real<T>));
}
auto* r_data = ctx.template Alloc<phi::dtype::Real<T>>(
r, batch_size * k * n * sizeof(phi::dtype::Real<T>));
// Implement QR by calling Eigen
for (int i = 0; i < batch_size; ++i) {
const T* x_matrix_ptr = x_data + i * x_stride;
T* r_matrix_ptr = r_data + i * r_stride;
using EigenDynamicMatrix =
Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>;
auto x_matrix = Eigen::Map<const EigenDynamicMatrix>(x_matrix_ptr, m, n);
Eigen::HouseholderQR<EigenDynamicMatrix> qr(x_matrix);
if (reduced_mode) {
auto qr_top_matrix = qr.matrixQR().block(0, 0, min_mn, n);
auto r_matrix_view =
qr_top_matrix.template triangularView<Eigen::Upper>();
auto r_matrix = EigenDynamicMatrix(r_matrix_view);
memcpy(r_matrix_ptr, r_matrix.data(), r_matrix.size() * sizeof(T));
} else {
auto r_matrix_view =
qr.matrixQR().template triangularView<Eigen::Upper>();
auto r_matrix = EigenDynamicMatrix(r_matrix_view);
memcpy(r_matrix_ptr, r_matrix.data(), r_matrix.size() * sizeof(T));
}
if (compute_q) {
T* q_matrix_ptr = q_data + i * q_stride;
if (reduced_mode) {
auto q_matrix =
qr.householderQ() * EigenDynamicMatrix::Identity(m, min_mn);
q_matrix.transposeInPlace();
memcpy(q_matrix_ptr, q_matrix.data(), q_matrix.size() * sizeof(T));
} else {
auto q_matrix = qr.householderQ() * EigenDynamicMatrix::Identity(m, m);
q_matrix.transposeInPlace();
memcpy(q_matrix_ptr, q_matrix.data(), q_matrix.size() * sizeof(T));
}
}
}
}
} // namespace phi
PD_REGISTER_KERNEL(qr, CPU, ALL_LAYOUT, phi::QrKernel, float, double) {}
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/impl/tril_triu_grad_kernel_impl.h"
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
PD_REGISTER_KERNEL(tril_triu_grad,
CPU,
ALL_LAYOUT,
phi::TrilTriuGradKernel,
bool,
float,
double,
int,
int64_t,
phi::dtype::float16) {}
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/impl/tril_triu_kernel_impl.h"
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
PD_REGISTER_KERNEL(tril_triu,
CPU,
ALL_LAYOUT,
phi::TrilTriuKernel,
bool,
float,
double,
int,
int64_t,
phi::dtype::float16) {}
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
namespace phi {
namespace funcs {
template <typename T>
class TrilTriuCompute {
public:
HOSTDEVICE TrilTriuCompute(const T* in,
const int diagonal,
const bool lower,
const int64_t H,
const int64_t W,
T* out)
: in_(in), diagonal_(diagonal), lower_(lower), H_(H), W_(W), out_(out) {}
HOSTDEVICE void operator()(int64_t idx) {
const int64_t row = (idx / W_) % H_;
const int64_t col = idx % W_;
const bool mask =
lower_ ? (col - row > diagonal_) : (col - row < diagonal_);
out_[idx] = mask ? static_cast<T>(0) : in_[idx];
}
private:
const T* in_;
const int diagonal_;
const bool lower_;
const int64_t H_;
const int64_t W_;
T* out_;
};
} // namespace funcs
} // namespace phi
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/multiplex_grad_kernel.h"
#include "paddle/phi/api/lib/utils/tensor_utils.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/eigen/common.h"
namespace phi {
template <typename T, typename Context>
void MultiplexGradKernel(const Context& ctx,
const DenseTensor& ids,
const DenseTensor& out_grad,
std::vector<DenseTensor*> ins_grad) {
size_t idx = -1UL;
for (size_t i = 0; i < ins_grad.size(); i++) {
if (ins_grad[i]) {
ctx.template Alloc<T>(ins_grad[i]);
auto t = phi::EigenVector<T>::Flatten(*ins_grad[i]);
t.device(*ctx.eigen_device()) = t.constant(static_cast<T>(0));
idx = i;
}
}
if (idx == -1UL) return;
auto rows = ins_grad[idx]->dims()[0];
auto cols = ins_grad[idx]->numel() / rows;
DenseTensor index_t_cpu;
paddle::framework::TensorCopySync(ids, phi::CPUPlace(), &index_t_cpu);
auto* index = index_t_cpu.data<int32_t>();
auto stream = ctx.stream();
for (auto i = 0; i < rows; i++) {
size_t k = static_cast<size_t>(index[i]);
if (ins_grad[k]) {
paddle::memory::Copy(ctx.GetPlace(),
ins_grad[k]->data<T>() + i * cols,
ctx.GetPlace(),
out_grad.data<T>() + i * cols,
cols * sizeof(T),
stream);
}
}
}
} // namespace phi
PD_REGISTER_KERNEL(multiplex_grad,
GPU,
ALL_LAYOUT,
phi::MultiplexGradKernel,
float,
double,
int,
int64_t) {}
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/multiplex_kernel.h"
#include "paddle/phi/api/lib/utils/tensor_utils.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
namespace phi {
template <typename T, typename Context>
void MultiplexKernel(const Context& ctx,
const std::vector<const DenseTensor*>& ins,
const DenseTensor& ids,
DenseTensor* out) {
ctx.template Alloc<T>(out);
for (size_t i = 0; i < ins.size(); ++i) {
PADDLE_ENFORCE_GT(
ins[i]->numel(),
0,
errors::OutOfRange(
"indexing will be out of bounds with size 0 for the %d-th input.",
i));
}
auto rows = ins[0]->dims()[0];
auto cols = ins[0]->numel() / rows;
DenseTensor index_t_cpu;
paddle::framework::TensorCopySync(ids, phi::CPUPlace(), &index_t_cpu);
auto* index = index_t_cpu.data<int32_t>();
auto stream = ctx.stream();
for (auto i = 0; i < rows; i++) {
int32_t k = index[i];
PADDLE_ENFORCE_GE(
k, 0, errors::PreconditionNotMet("index must be nonnegative."));
PADDLE_ENFORCE_LT(static_cast<size_t>(k),
ins.size(),
errors::PreconditionNotMet(
"index exceeds the number of candidate tensors."));
paddle::memory::Copy(ctx.GetPlace(),
out->data<T>() + i * cols,
ctx.GetPlace(),
ins[k]->data<T>() + i * cols,
cols * sizeof(T),
stream);
}
}
} // namespace phi
PD_REGISTER_KERNEL(multiplex,
GPU,
ALL_LAYOUT,
phi::MultiplexKernel,
float,
double,
int,
int64_t) {}
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/impl/tril_triu_grad_kernel_impl.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
PD_REGISTER_KERNEL(tril_triu_grad,
GPU,
ALL_LAYOUT,
phi::TrilTriuGradKernel,
bool,
float,
double,
int,
int64_t,
phi::dtype::float16) {}
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/impl/tril_triu_kernel_impl.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
PD_REGISTER_KERNEL(tril_triu,
GPU,
ALL_LAYOUT,
phi::TrilTriuKernel,
bool,
float,
double,
int,
int64_t,
phi::dtype::float16) {}
@@ -24,13 +24,12 @@
#include "paddle/phi/kernels/funcs/blas/blas.h"
#include "paddle/phi/kernels/funcs/common_shape.h"
#include "paddle/phi/kernels/funcs/complex_functors.h"
#include "paddle/phi/kernels/funcs/for_range.h"
#include "paddle/phi/kernels/funcs/matrix_reduce.h"
#include "paddle/phi/kernels/funcs/tril_triu_compute.h"
#include "paddle/phi/kernels/math_kernel.h"
#include "paddle/phi/kernels/transpose_kernel.h"
// See Note [ Why still include the fluid headers? ]
#include "paddle/fluid/operators/tril_triu_op.h"
namespace phi {
template <typename T, typename Context>
@@ -115,7 +114,7 @@ void CholeskySolveGradKernel(const Context& dev_ctx,
const auto H = y_bst_dims_vec[y_bst_ndim - 2];
const auto W = y_bst_dims_vec[y_bst_ndim - 1];
phi::funcs::ForRange<Context> y_for_range(dev_ctx, dy_bst.numel());
paddle::operators::TrilTriuCompute<T> tril_triu_functor(
phi::funcs::TrilTriuCompute<T> tril_triu_functor(
dy_bst.data<T>(), 0, !upper, H, W, dy_bst_upper.data<T>());
y_for_range(tril_triu_functor);
......
@@ -21,12 +21,11 @@
#include "paddle/phi/kernels/funcs/blas/blas.h"
#include "paddle/phi/kernels/funcs/common_shape.h"
#include "paddle/phi/kernels/funcs/complex_functors.h"
#include "paddle/phi/kernels/funcs/for_range.h"
#include "paddle/phi/kernels/funcs/matrix_reduce.h"
#include "paddle/phi/kernels/funcs/tril_triu_compute.h"
#include "paddle/phi/kernels/triangular_solve_kernel.h"
// See Note [ Why still include the fluid headers? ]
#include "paddle/fluid/operators/tril_triu_op.h"
namespace phi {
template <typename T, typename Context>
@@ -119,7 +118,7 @@ void TriangularSolveGradKernel(const Context& dev_ctx,
const auto H = dims[dims.size() - 2];
const auto W = dims[dims.size() - 1];
phi::funcs::ForRange<Context> x_for_range(dev_ctx, dx_bst.numel());
paddle::operators::TrilTriuCompute<T> tril_triu_functor(
phi::funcs::TrilTriuCompute<T> tril_triu_functor(
dx_bst.data<T>(), unitriangular, !upper, H, W, dx_bst_upper.data<T>());
x_for_range(tril_triu_functor);
......
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/phi/kernels/tril_triu_grad_kernel.h"
#include "paddle/phi/kernels/funcs/for_range.h"
#include "paddle/phi/kernels/funcs/tril_triu_compute.h"
namespace phi {
template <typename T, typename Context>
void TrilTriuGradKernel(const Context& ctx,
const DenseTensor& out_grad,
int diagonal,
bool lower,
DenseTensor* x_grad) {
const auto* dout_data = out_grad.data<T>();
auto* dx_data = ctx.template Alloc<T>(x_grad);
const auto& dims = out_grad.dims();
const auto H = dims[dims.size() - 2];
const auto W = dims[dims.size() - 1];
phi::funcs::ForRange<Context> for_range(
ctx, static_cast<size_t>(out_grad.numel()));
phi::funcs::TrilTriuCompute<T> tril_triu_grad_computer(
dout_data, diagonal, lower, H, W, dx_data);
for_range(tril_triu_grad_computer);
}
} // namespace phi
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/phi/kernels/tril_triu_kernel.h"
#include "paddle/phi/kernels/funcs/for_range.h"
#include "paddle/phi/kernels/funcs/tril_triu_compute.h"
namespace phi {
template <typename T, typename Context>
void TrilTriuKernel(const Context& ctx,
const DenseTensor& x,
int diagonal,
bool lower,
DenseTensor* out) {
const auto* x_data = x.data<T>();
auto* out_data = ctx.template Alloc<T>(out);
const auto& dims = x.dims();
const auto H = dims[dims.size() - 2];
const auto W = dims[dims.size() - 1];
phi::funcs::ForRange<Context> for_range(ctx, static_cast<size_t>(x.numel()));
phi::funcs::TrilTriuCompute<T> tril_triu_computer(
x_data, diagonal, lower, H, W, out_data);
for_range(tril_triu_computer);
}
} // namespace phi
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/phi/core/dense_tensor.h"
namespace phi {
template <typename T, typename Context>
void MultiplexGradKernel(const Context& ctx,
const DenseTensor& ids,
const DenseTensor& out_grad,
std::vector<DenseTensor*> ins_grad);
} // namespace phi
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/phi/core/dense_tensor.h"
namespace phi {
template <typename T, typename Context>
void MultiplexKernel(const Context& ctx,
const std::vector<const DenseTensor*>& ins,
const DenseTensor& ids,
DenseTensor* out);
} // namespace phi
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/phi/core/dense_tensor.h"
namespace phi {
template <typename T, typename Context>
void QrKernel(const Context& ctx,
const DenseTensor& x,
const std::string& mode,
DenseTensor* q,
DenseTensor* r);
} // namespace phi
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/phi/core/dense_tensor.h"
namespace phi {
template <typename T, typename Context>
void TrilTriuGradKernel(const Context& ctx,
const DenseTensor& out_grad,
int diagonal,
bool lower,
DenseTensor* x_grad);
} // namespace phi
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/phi/core/dense_tensor.h"
namespace phi {
template <typename T, typename Context>
void TrilTriuKernel(const Context& ctx,
const DenseTensor& x,
int diagonal,
bool lower,
DenseTensor* out);
} // namespace phi
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
@@ -12,24 +12,21 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/tril_triu_op.h"
namespace ops = paddle::operators;
namespace plat = paddle::platform;
REGISTER_OP_CUDA_KERNEL(
tril_triu, ops::TrilTriuOpKernel<paddle::platform::CUDADeviceContext, bool>,
ops::TrilTriuOpKernel<paddle::platform::CUDADeviceContext, float>,
ops::TrilTriuOpKernel<paddle::platform::CUDADeviceContext, double>,
ops::TrilTriuOpKernel<paddle::platform::CUDADeviceContext, int>,
ops::TrilTriuOpKernel<paddle::platform::CUDADeviceContext, int64_t>,
ops::TrilTriuOpKernel<paddle::platform::CUDADeviceContext, plat::float16>);
REGISTER_OP_CUDA_KERNEL(
tril_triu_grad,
ops::TrilTriuGradOpKernel<paddle::platform::CUDADeviceContext, bool>,
ops::TrilTriuGradOpKernel<paddle::platform::CUDADeviceContext, float>,
ops::TrilTriuGradOpKernel<paddle::platform::CUDADeviceContext, double>,
ops::TrilTriuGradOpKernel<paddle::platform::CUDADeviceContext, int>,
ops::TrilTriuGradOpKernel<paddle::platform::CUDADeviceContext, int64_t>,
ops::TrilTriuGradOpKernel<paddle::platform::CUDADeviceContext,
plat::float16>);
#include "paddle/phi/core/compat/op_utils.h"
namespace phi {
KernelSignature MultiplexOpArgumentMapping(const ArgumentMappingContext& ctx) {
return KernelSignature("multiplex", {"X", "Ids"}, {}, {"Out"});
}
KernelSignature MultiplexGradOpArgumentMapping(
const ArgumentMappingContext& ctx) {
return KernelSignature(
"multiplex_grad", {"Ids", GradVarName("Out")}, {}, {GradVarName("X")});
}
} // namespace phi
PD_REGISTER_ARG_MAPPING_FN(multiplex, phi::MultiplexOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(multiplex_grad, phi::MultiplexGradOpArgumentMapping);
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/phi/core/compat/op_utils.h"
namespace phi {
KernelSignature QrOpArgumentMapping(const ArgumentMappingContext& ctx) {
return KernelSignature("qr", {"X"}, {"mode"}, {"Q", "R"});
}
} // namespace phi
PD_REGISTER_ARG_MAPPING_FN(qr, phi::QrOpArgumentMapping);
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/phi/core/compat/op_utils.h"
namespace phi {
KernelSignature TrilTriuOpArgumentMapping(const ArgumentMappingContext& ctx) {
return KernelSignature("tril_triu", {"X"}, {"diagonal", "lower"}, {"Out"});
}
KernelSignature TrilTriuGradOpArgumentMapping(
const ArgumentMappingContext& ctx) {
return KernelSignature("tril_triu_grad",
{GradVarName("Out")},
{"diagonal", "lower"},
{GradVarName("X")});
}
} // namespace phi
PD_REGISTER_ARG_MAPPING_FN(tril_triu, phi::TrilTriuOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(tril_triu_grad, phi::TrilTriuGradOpArgumentMapping);