Unverified · Commit c0a7830f · Author: Weilong Wu · Committed by: GitHub

[Phi] Migrate solve kernel to phi (#44363)

* draft version

* draft version

* draft version

* migrate solve kernel to phi

* polish

* polish

* remove useless header file, fix a bug in grad_kernel_impl

* add header files as needed
Parent 6f7550e4
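For context, the forward kernel solves the batched linear system

    Out = X^{-1} Y,

and the migrated backward kernel follows the standard adjoint formulas (a sketch of the math the code in this diff implements):

    \frac{\partial L}{\partial Y} = X^{-T}\,\frac{\partial L}{\partial Out}, \qquad
    \frac{\partial L}{\partial X} = -\,\frac{\partial L}{\partial Y}\,Out^{T}

That is, the Y gradient comes from a second linear solve against X^T, the X gradient from a GEMM scaled by -1, and both are then sum-reduced over any broadcast batch dimensions.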
@@ -12,13 +12,13 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/solve_op.h"
#include <memory>
#include <string>
#include <unordered_map>
#include <vector>
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/phi/core/ddim.h"
namespace paddle {
@@ -220,10 +220,3 @@ REGISTER_OPERATOR(solve,
ops::SolveOpGradMaker<paddle::imperative::OpBase>);
REGISTER_OPERATOR(solve_grad, ops::SolveGradOp);
REGISTER_OP_CPU_KERNEL(solve,
ops::SolveKernel<phi::CPUContext, float>,
ops::SolveKernel<phi::CPUContext, double>);
REGISTER_OP_CPU_KERNEL(solve_grad,
ops::SolveGradKernel<phi::CPUContext, float>,
ops::SolveGradKernel<phi::CPUContext, double>);
This diff is collapsed.
@@ -3211,15 +3211,18 @@ void UnsqueezeInferMeta(const MetaTensor& x,
}
out->set_dtype(x.dtype());
}
// set xshape dims.
std::vector<int64_t> xshape_dims(x_dims.size() + 1);
xshape_dims[0] = 0;
for (int i = 0; i < x_dims.size(); ++i) {
xshape_dims[i + 1] = x_dims[i];
if (xshape) {
// set xshape dims.
std::vector<int64_t> xshape_dims(x_dims.size() + 1);
xshape_dims[0] = 0;
for (int i = 0; i < x_dims.size(); ++i) {
xshape_dims[i + 1] = x_dims[i];
}
xshape->set_dims(phi::make_ddim(xshape_dims));
xshape->share_lod(x);
xshape->set_dtype(x.dtype());
}
xshape->set_dims(phi::make_ddim(xshape_dims));
xshape->share_lod(x);
xshape->set_dtype(x.dtype());
}
void UnStackInferMeta(const MetaTensor& x,
......
@@ -62,6 +62,7 @@ set(COMMON_KERNEL_DEPS
pooling
maxouting
matrix_inverse
matrix_solve
phi_dynload_warpctc
sequence_padding
sequence_scale)
......
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/phi/kernels/solve_grad_kernel.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/impl/solve_grad_kernel_impl.h"
PD_REGISTER_KERNEL(
solve_grad, CPU, ALL_LAYOUT, phi::SolveGradKernel, float, double) {}
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/phi/kernels/solve_kernel.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/impl/solve_kernel_impl.h"
PD_REGISTER_KERNEL(solve, CPU, ALL_LAYOUT, phi::SolveKernel, float, double) {}
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/phi/kernels/solve_grad_kernel.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/impl/solve_grad_kernel_impl.h"
PD_REGISTER_KERNEL(
solve_grad, GPU, ALL_LAYOUT, phi::SolveGradKernel, float, double) {}
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/phi/kernels/solve_kernel.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/impl/solve_kernel_impl.h"
PD_REGISTER_KERNEL(solve, GPU, ALL_LAYOUT, phi::SolveKernel, float, double) {}
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/kernels/cpu/reduce.h"
#include "paddle/phi/kernels/expand_as_kernel.h"
#include "paddle/phi/kernels/funcs/blas/blas.h"
#include "paddle/phi/kernels/funcs/math_function.h"
#include "paddle/phi/kernels/funcs/matrix_solve.h"
#include "paddle/phi/kernels/funcs/reduce_functor.h"
#include "paddle/phi/kernels/impl/solve_kernel_impl.h"
#include "paddle/phi/kernels/squeeze_kernel.h"
#include "paddle/phi/kernels/unsqueeze_kernel.h"
#if defined(__NVCC__) || defined(__HIPCC__)
#include "paddle/phi/kernels/gpu/reduce.h"
#endif
namespace phi {
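// ReduceSumForSolvelGrad sums a gradient tensor over the broadcast batch
// dimensions (keep_dims controls whether the reduced axes are retained);
// it is specialized below for the CPU and GPU backends.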
template <typename Context, typename T>
struct ReduceSumForSolvelGrad {
void operator()(const Context& dev_ctx,
const DenseTensor& input,
DenseTensor* output,
const std::vector<int>& reduce_dims,
bool keep_dims);
};
template <typename T>
struct ReduceSumForSolvelGrad<CPUContext, T> {
void operator()(const CPUContext& dev_ctx,
const DenseTensor& input,
DenseTensor* output,
const std::vector<int>& reduce_dims,
bool keep_dims) {
std::vector<int64_t> reduce_dims_tmp(reduce_dims.begin(),
reduce_dims.end());
phi::ReduceKernelImpl<CPUContext, T, T, phi::funcs::SumFunctor>(
dev_ctx, input, output, reduce_dims_tmp, keep_dims, false);
}
};
#if defined(__NVCC__) || defined(__HIPCC__)
template <typename T>
struct ReduceSumForSolvelGrad<GPUContext, T> {
void operator()(const GPUContext& dev_ctx,
const DenseTensor& input,
DenseTensor* output,
const std::vector<int>& reduce_dims,
bool keep_dims) {
phi::funcs::ReduceKernel<T, T, kps::AddFunctor, kps::IdentityFunctor<T>>(
dev_ctx, input, output, kps::IdentityFunctor<T>(), reduce_dims);
}
};
#endif
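// SolveGradKernel: given Out = X^{-1} Y and dOut, the kernel
//   1. unsqueezes a vector RHS Y to a column matrix and derives the broadcast
//      batch shapes of X and Y,
//   2. computes dY = solve(X^T, dOut) via linalg_solve on the transposed X,
//   3. computes dX = -dY * Out^T with a batched GEMM,
//   4. reduce-sums dY / dX over broadcast batch dimensions back to the
//      original input shapes.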
template <typename T, typename Context>
void SolveGradKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& y,
const DenseTensor& dout,
const DenseTensor& out,
DenseTensor* dx,
DenseTensor* dy) {
bool is_vector = false;
is_vector = is_vector_rhs(x, y);
DenseTensor tmp_y;
if (is_vector) {
dev_ctx.Alloc(&tmp_y, y.dtype());
phi::Unsqueeze<T, Context>(dev_ctx, y, {-1}, &tmp_y, nullptr);
} else {
tmp_y.Resize(y.dims());
dev_ctx.Alloc(&tmp_y, y.dtype());
phi::Copy(dev_ctx, y, dev_ctx.GetPlace(), false, &tmp_y);
}
DenseTensor tmp_x;
tmp_x.Resize(x.dims());
dev_ctx.Alloc(&tmp_x, x.dtype());
phi::Copy(dev_ctx, x, dev_ctx.GetPlace(), false, &tmp_x);
std::vector<int64_t> x_broadcast_dims;
std::vector<int64_t> y_broadcast_dims;
std::tie(x_broadcast_dims, y_broadcast_dims) =
get_broadcast_dims(tmp_x, tmp_y);
// tmp_dx
DenseTensor tmp_dx;
tmp_dx.Resize(phi::make_ddim(x_broadcast_dims));
dev_ctx.template Alloc<T>(&tmp_dx);
// tmp_dy
DenseTensor tmp_dy;
tmp_dy.Resize(phi::make_ddim(y_broadcast_dims));
dev_ctx.template Alloc<T>(&tmp_dy);
DenseTensor tmp_input(x.dtype());
const auto& new_dims_vec = phi::funcs::getNewDimsVec(x.dims());
tmp_input.Resize(phi::make_ddim(new_dims_vec));
dev_ctx.template Alloc<T>(&tmp_input);
phi::funcs::TransposeNormal<Context, T> trans;
std::vector<int> new_axis = phi::funcs::getNewAxis(x.dims().size());
trans(dev_ctx, x, &tmp_input, new_axis);
if (dy) {
dev_ctx.template Alloc<T>(dy);
linalg_solve<Context, T>(dev_ctx, tmp_input, dout, &tmp_dy);
}
if (dx) {
dev_ctx.template Alloc<T>(dx);
// to get dx
auto blas = phi::funcs::GetBlas<Context, T>(dev_ctx);
if (x.dims().size() == 2 && y.dims().size() == 2) {
auto mat_dim_a1 =
phi::funcs::CreateMatrixDescriptor(tmp_dy.dims(), 0, false);
auto mat_dim_b1 = phi::funcs::CreateMatrixDescriptor(out.dims(), 0, true);
blas.MatMul(tmp_dy, mat_dim_a1, out, mat_dim_b1, T(-1), &tmp_dx, T(0));
} else if (is_vector_rhs(x, y)) {
DenseTensor tmp_dy_;
dev_ctx.Alloc(&tmp_dy_, y.dtype());
phi::Unsqueeze<T, Context>(dev_ctx,
tmp_dy,
paddle::experimental::IntArray({-1}),
&tmp_dy_,
nullptr);
DenseTensor tmp_out_;
dev_ctx.Alloc(&tmp_out_, out.dtype());
phi::Unsqueeze<T, Context>(dev_ctx,
out,
paddle::experimental::IntArray({-1}),
&tmp_out_,
nullptr);
auto mat_dim_a1 =
phi::funcs::CreateMatrixDescriptor(tmp_dy_.dims(), 0, false);
auto mat_dim_b1 =
phi::funcs::CreateMatrixDescriptor(tmp_out_.dims(), 0, true);
blas.MatMul(
tmp_dy_, mat_dim_a1, tmp_out_, mat_dim_b1, T(-1), &tmp_dx, T(0));
} else {
auto mat_dim_a1 =
phi::funcs::CreateMatrixDescriptor(tmp_dy.dims(), 0, false);
auto mat_dim_b1 = phi::funcs::CreateMatrixDescriptor(out.dims(), 0, true);
blas.MatMul(tmp_dy, mat_dim_a1, out, mat_dim_b1, T(-1), &tmp_dx, T(0));
}
}
if (y.dims() != tmp_dy.dims()) {
DenseTensor dy_help;
dy_help.Resize(tmp_dy.dims());
dev_ctx.Alloc(&dy_help, tmp_dy.dtype());
phi::Copy(dev_ctx, tmp_dy, dev_ctx.GetPlace(), false, &dy_help);
// get dims
std::vector<std::int64_t> x_dims = vectorize(x.dims());
std::vector<std::int64_t> y_dims = vectorize(y.dims());
std::vector<std::int64_t> dout_dims = vectorize(dout.dims());
if (is_vector_rhs(x, y)) {
dout_dims.push_back(1);
}
int y_ndim = y_dims.size();
int ndim = dout_dims.size();
const std::vector<std::int64_t> dy_help_dims = vectorize(dy_help.dims());
std::vector<std::int64_t> dy_broadcast_dims(ndim);
std::fill(
dy_broadcast_dims.data(), dy_broadcast_dims.data() + ndim - y_ndim, 1);
std::copy(y_dims.data(),
y_dims.data() + y_ndim,
dy_broadcast_dims.data() + ndim - y_ndim);
std::vector<int> dy_reduce_dims;
for (int idx = 0; idx <= ndim - 3; idx++) {
if (dy_help_dims[idx] != 1 && dy_broadcast_dims[idx] == 1) {
dy_reduce_dims.push_back(idx);
}
}
// reduce sum to get grad by ReduceSum
if (dy) {
if (dy_reduce_dims.empty()) {
*dy = std::move(dy_help);
} else {
bool keep_dim = true;
if (dy_help.dims().size() != dy->dims().size()) {
keep_dim = false;
}
ReduceSumForSolvelGrad<Context, T>()(
dev_ctx, dy_help, dy, dy_reduce_dims, keep_dim);
}
dy->Resize(y.dims());
}
} else {
phi::Copy(dev_ctx, tmp_dy, dev_ctx.GetPlace(), false, dy);
}
if (x.dims() != tmp_dx.dims()) {
DenseTensor dx_help;
dx_help.Resize(tmp_dx.dims());
dev_ctx.Alloc(&dx_help, tmp_dx.dtype());
phi::Copy(dev_ctx, tmp_dx, dev_ctx.GetPlace(), false, &dx_help);
// get dims
std::vector<std::int64_t> x_dims = vectorize(x.dims());
std::vector<std::int64_t> y_dims = vectorize(y.dims());
int x_ndim = x_dims.size();
int ndim = x_broadcast_dims.size();
const std::vector<std::int64_t> dx_help_dims = vectorize(dx_help.dims());
std::vector<std::int64_t> dx_broadcast_dims(ndim);
std::fill(
dx_broadcast_dims.data(), dx_broadcast_dims.data() + ndim - x_ndim, 1);
std::copy(x_dims.data(),
x_dims.data() + x_ndim,
dx_broadcast_dims.data() + ndim - x_ndim);
std::vector<int> dx_reduce_dims;
for (int idx = 0; idx <= ndim - 3; idx++) {
if (dx_help_dims[idx] != 1 && dx_broadcast_dims[idx] == 1) {
dx_reduce_dims.push_back(idx);
}
}
// reduce sum to get grad by ReduceSum
if (dx) {
dev_ctx.template Alloc<T>(dx);
if (dx_reduce_dims.empty()) {
*dx = std::move(dx_help);
} else {
bool keep_dim = true;
if (dx_help.dims().size() != dx->dims().size()) {
keep_dim = false;
}
ReduceSumForSolvelGrad<Context, T>()(
dev_ctx, dx_help, dx, dx_reduce_dims, keep_dim);
}
dx->Resize(x.dims());
}
} else {
phi::Copy(dev_ctx, tmp_dx, dev_ctx.GetPlace(), false, dx);
}
}
} // namespace phi
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/phi/core/tensor_utils.h"
#include "paddle/phi/kernels/expand_as_kernel.h"
#include "paddle/phi/kernels/funcs/matrix_solve.h"
#include "paddle/phi/kernels/funcs/reduce_functor.h"
#include "paddle/phi/kernels/squeeze_kernel.h"
#include "paddle/phi/kernels/unsqueeze_kernel.h"
namespace phi {
using Tensor = DenseTensor;
// Check whether the right-hand side `other` is a vector (vector case) or a matrix.
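// e.g. X: [2, 3, 3] with Y: [2, 3] (or any 1-D Y) is a vector RHS, while
// Y: [2, 3, 1] is already a matrix RHS.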
static inline bool is_vector_rhs(const DenseTensor& input,
const DenseTensor& other) {
auto x_dim = input.dims();
auto y_dim = other.dims();
auto x_dim_size = x_dim.size();
auto y_dim_size = y_dim.size();
std::vector<int64_t> x_dims_vec = phi::vectorize(x_dim);
std::vector<int64_t> y_dims_vec = phi::vectorize(y_dim);
std::vector<int64_t>::const_iterator f = x_dims_vec.begin();
std::vector<int64_t>::const_iterator l = x_dims_vec.end() - 1;
std::vector<int64_t> x_dims_vec_cut(f, l); // input.shape[:-1]
std::vector<int64_t> expected_batched_rhs_shape(x_dims_vec_cut);
bool vector_case =
y_dim_size == 1 || (x_dim_size - 1 == y_dim_size &&
y_dims_vec == (expected_batched_rhs_shape));
return vector_case;
}
// Compute the broadcast batch portion (leading batch dims) shared by x and y.
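// e.g. batch portions [3, 1] and [1, 2] broadcast to [3, 2].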
static std::vector<int64_t> get_broadcast_batch_portion(
std::vector<int64_t> x, std::vector<int64_t> y) {
size_t size_x = x.size();
size_t size_y = y.size();
size_t size = std::max(size_x, size_y);
std::vector<int64_t> batchPortion(size);
ptrdiff_t i = (ptrdiff_t)size - 1;
for (; i >= 0; --i) {
ptrdiff_t offset = size - i - 1;
ptrdiff_t dim_x = size_x - offset - 1;
ptrdiff_t dim_y = size_y - offset - 1;
int64_t x_size = (dim_x >= 0) ? x[dim_x] : 1;
int64_t y_size = (dim_y >= 0) ? y[dim_y] : 1;
PADDLE_ENFORCE_EQ(
(x_size == y_size || x_size == 1 || y_size == 1),
true,
phi::errors::PreconditionNotMet(
"The size of tensor x (%d) must match the size of tensor y "
"(%d) at non-singleton dimension %d.",
x_size,
y_size,
i));
batchPortion[i] = x_size != 1 ? x_size : y_size;
}
return batchPortion;
}
static inline std::vector<int> convert_to_int_vec(std::vector<int64_t> a) {
std::vector<int> ret;
for (size_t i = 0; i < a.size(); i++) {
ret.emplace_back(int(a[i]));
}
return ret;
}
// broadcast the batch dimensions of tensor x and tensor y.
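// e.g. x: [3, 1, 4, 4] and y: [1, 2, 4, 6] expand to
// x: [3, 2, 4, 4] and y: [3, 2, 4, 6].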
static inline std::tuple<std::vector<int64_t>, std::vector<int64_t>>
get_broadcast_dims(const Tensor& x, const Tensor& y) {
std::vector<int64_t> x_dims_vec = phi::vectorize(x.dims());
std::vector<int64_t> y_dims_vec = phi::vectorize(y.dims());
std::vector<int64_t>::const_iterator f1 = x_dims_vec.begin();
std::vector<int64_t>::const_iterator l1 = x_dims_vec.end() - 2;
std::vector<int64_t> x_dims_vec_cut(f1, l1);
std::vector<int64_t>::const_iterator f2 = y_dims_vec.begin();
std::vector<int64_t>::const_iterator l2 = y_dims_vec.end() - 2;
std::vector<int64_t> y_dims_vec_cut(f2, l2);
std::vector<int64_t> expand_batch_portion =
get_broadcast_batch_portion(x_dims_vec_cut, y_dims_vec_cut);
std::vector<int64_t> x_expand_size({expand_batch_portion});
x_expand_size.insert(x_expand_size.end(),
{x_dims_vec[static_cast<int>(x_dims_vec.size()) - 2],
x_dims_vec[static_cast<int>(x_dims_vec.size()) - 1]});
std::vector<int64_t> y_expand_size({expand_batch_portion});
y_expand_size.insert(y_expand_size.end(),
{y_dims_vec[static_cast<int>(y_dims_vec.size()) - 2],
y_dims_vec[static_cast<int>(y_dims_vec.size()) - 1]});
return std::make_tuple(x_expand_size, y_expand_size);
}
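// Shared forward path: unsqueeze a vector RHS to a column matrix, expand both
// operands to their broadcast batch shape, run MatrixSolveFunctor on the
// expanded tensors, and squeeze the trailing dimension back off in the vector case.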
template <typename Context, typename T>
static void linalg_solve(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& y,
DenseTensor* out) {
dev_ctx.template Alloc<T>(out);
phi::funcs::MatrixSolveFunctor<Context, T> mat_solve;
// input y can be vector or matrix
// but need to be unsqueezed if y is a vector
bool is_vector = false;
is_vector = is_vector_rhs(x, y);
Tensor tmp_y;
if (is_vector) {
dev_ctx.Alloc(&tmp_y, y.dtype());
phi::Unsqueeze<T, Context>(dev_ctx, y, {-1}, &tmp_y, nullptr);
} else {
tmp_y.Resize(y.dims());
dev_ctx.Alloc(&tmp_y, y.dtype());
phi::Copy(dev_ctx, y, dev_ctx.GetPlace(), false, &tmp_y);
}
Tensor tmp_x;
tmp_x.Resize(x.dims());
dev_ctx.Alloc(&tmp_x, x.dtype());
phi::Copy(dev_ctx, x, dev_ctx.GetPlace(), false, &tmp_x);
std::vector<int64_t> x_broadcast_dims;
std::vector<int64_t> y_broadcast_dims;
std::tie(x_broadcast_dims, y_broadcast_dims) =
get_broadcast_dims(tmp_x, tmp_y);
Tensor tmp_x_bc;
phi::ExpandAsKernel<T, Context>(
dev_ctx, tmp_x, nullptr, convert_to_int_vec(x_broadcast_dims), &tmp_x_bc);
Tensor tmp_y_bc;
phi::ExpandAsKernel<T, Context>(
dev_ctx, tmp_y, nullptr, convert_to_int_vec(y_broadcast_dims), &tmp_y_bc);
auto x_dim = x.dims();
auto y_dim = y.dims();
auto x_dim_size = x_dim.size();
auto y_dim_size = y_dim.size();
if (is_vector) { // vector case
out->Resize(tmp_y_bc.dims()); // out.unsqueeze(-1)
mat_solve(dev_ctx, tmp_x_bc, tmp_y_bc, out);
Tensor out_tmp;
out_tmp.Resize(out->dims());
out_tmp = *out;
phi::SqueezeKernel<T, Context>(dev_ctx, out_tmp, {-1}, out, nullptr);
} else {
PADDLE_ENFORCE_EQ(
x_dim[x_dim_size - 1],
y_dim[y_dim_size - 2],
phi::errors::InvalidArgument(
"Matrix X1 with dimension greater than 2 and any matrix Y1,"
"the matrix X1's width must be equal with matrix Y1's "
"height. But received X's shape = [%s], X1's shape = [%s], X1's "
"width = %s; Y's shape = [%s], Y1's shape = [%s], Y1's height = "
"%s.",
x_dim,
x_dim,
x_dim[x_dim_size - 1],
y_dim,
y_dim,
y_dim[y_dim_size - 2]));
mat_solve(dev_ctx, tmp_x_bc, tmp_y_bc, out);
}
}
template <typename T, typename Context>
void SolveKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& y,
DenseTensor* out) {
linalg_solve<Context, T>(dev_ctx, x, y, out);
}
} // namespace phi
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
@@ -12,14 +12,19 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/solve_op.h"
#pragma once
namespace ops = paddle::operators;
namespace plat = paddle::platform;
REGISTER_OP_CUDA_KERNEL(solve,
ops::SolveKernel<plat::CUDADeviceContext, float>,
ops::SolveKernel<plat::CUDADeviceContext, double>);
#include "paddle/phi/core/dense_tensor.h"
REGISTER_OP_CUDA_KERNEL(solve_grad,
ops::SolveGradKernel<plat::CUDADeviceContext, float>,
ops::SolveGradKernel<plat::CUDADeviceContext, double>);
namespace phi {
template <typename T, typename Context>
void SolveGradKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& y,
const DenseTensor& dout,
const DenseTensor& out,
DenseTensor* dx,
DenseTensor* dy);
} // namespace phi
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/phi/core/dense_tensor.h"
namespace phi {
template <typename T, typename Context>
void SolveKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& y,
DenseTensor* out);
} // namespace phi
@@ -17,6 +17,7 @@
#include "paddle/phi/common/int_array.h"
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/infermeta/unary.h"
namespace phi {
@@ -26,4 +27,16 @@ void UnsqueezeKernel(const Context& dev_ctx,
const IntArray& axes,
DenseTensor* out,
DenseTensor* xshape);
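// Convenience wrapper used by the solve kernels: infers out's meta and runs
// UnsqueezeKernel without producing xshape (which is why UnsqueezeInferMeta
// above now guards the xshape branch with a null check).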
template <typename T, typename Context>
void Unsqueeze(const Context& dev_ctx,
const DenseTensor& x,
const IntArray& axes,
DenseTensor* out,
DenseTensor* xshape) {
MetaTensor meta_out(out);
UnsqueezeInferMeta(x, axes, &meta_out, nullptr, MetaConfig());
UnsqueezeKernel<T, Context>(dev_ctx, x, axes, out, nullptr);
}
} // namespace phi
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/phi/core/compat/op_utils.h"
namespace phi {
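// Maps the legacy fluid solve_grad op onto the phi kernel signature:
// inputs {X, Y, Out@GRAD, Out} -> outputs {X@GRAD, Y@GRAD}, no attributes.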
KernelSignature SolveGradOpArgumentMapping(const ArgumentMappingContext& ctx) {
return KernelSignature(
"solve_grad", {"X", "Y", "Out@GRAD", "Out"}, {}, {"X@GRAD", "Y@GRAD"});
}
} // namespace phi
PD_REGISTER_ARG_MAPPING_FN(solve_grad, phi::SolveGradOpArgumentMapping);