未验证 提交 b695fd95 编写于 作者: A Aganlengzi 提交者: GitHub

[phi]migrate increment addmm multinomial cholesky kernels to phi (#39858)

* migrate increment addmm multinomial cholesky kernels to phi

* test pr39869

* test pr39869

* fix style and ci
上级 127440c3
......@@ -12,11 +12,11 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/addmm_op.h"
#include <memory>
#include <string>
#include <unordered_map>
#include <vector>
#include "paddle/fluid/framework/op_registry.h"
#ifdef PADDLE_WITH_MKLDNN
#include "paddle/fluid/platform/mkldnn_helper.h"
#endif
......@@ -24,6 +24,8 @@ limitations under the License. */
namespace paddle {
namespace operators {
constexpr int kMULMKLDNNINT8 = 1;
using framework::OpKernelType;
using framework::Tensor;
......@@ -227,11 +229,3 @@ REGISTER_OPERATOR(addmm, ops::AddMMOp, ops::AddMMOpMaker,
ops::AddMMOpGradMaker<paddle::imperative::OpBase>);
REGISTER_OPERATOR(addmm_grad, ops::AddMMGradOp);
REGISTER_OP_CPU_KERNEL(
addmm, ops::AddMMKernel<paddle::platform::CPUDeviceContext, float>,
ops::AddMMKernel<paddle::platform::CPUDeviceContext, double>);
REGISTER_OP_CPU_KERNEL(
addmm_grad, ops::AddMMGradKernel<paddle::platform::CPUDeviceContext, float>,
ops::AddMMGradKernel<paddle::platform::CPUDeviceContext, double>);
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <boost/preprocessor/repetition/repeat.hpp>
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/operators/eigen/eigen_function.h"
#include "paddle/phi/kernels/funcs/blas/blas.h"
#include "paddle/phi/kernels/funcs/math_function.h"
namespace ops = paddle::operators;
namespace plat = paddle::platform;
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
template <typename T, size_t D, int MajorType = Eigen::RowMajor,
typename IndexType = Eigen::DenseIndex>
using EigenTensor = framework::EigenTensor<T, D, MajorType, IndexType>;
using Array1 = Eigen::DSizes<Eigen::DenseIndex, 1>;
using Array2 = Eigen::DSizes<Eigen::DenseIndex, 2>;
using Tensor = framework::Tensor;
constexpr int kMULMKLDNNINT8 = 1;
template <typename DeviceContext, typename T>
class AddMMKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
const Tensor* input = context.Input<Tensor>("Input");
const Tensor* x = context.Input<Tensor>("X");
const Tensor* y = context.Input<Tensor>("Y");
auto input_dims = input->dims();
auto x_dims = x->dims();
auto y_dims = y->dims();
// broadcast mode check
if (x_dims[0] != input_dims[0]) {
PADDLE_ENFORCE_EQ(input_dims[0], 1,
platform::errors::InvalidArgument(
"When x_dims[0] is not equal with input_dims[0], "
"input_dims[0] must be 1 but got %s",
input_dims[0]));
PADDLE_ENFORCE_EQ(
y_dims[1] == input_dims[1] || input_dims[1] == 1, true,
platform::errors::InvalidArgument(
"The input tensor shape mismatch, input shape=[%s], "
"x shape=[%s], y shape=[%s]",
input_dims, x_dims, y_dims));
}
// broadcast mode check
if (y_dims[1] != input_dims[1]) {
PADDLE_ENFORCE_EQ(input_dims[1], 1,
platform::errors::InvalidArgument(
"When y_dims[1] is not equal with input_dims[0], "
"input_dims[0] must be 1 but got %s",
input_dims[1]));
PADDLE_ENFORCE_EQ(
x_dims[0] == input_dims[0] || input_dims[0] == 1, true,
platform::errors::InvalidArgument(
"The input tensor shape mismatch, input shape=[%s], "
"x shape=[%s], y shape=[%s]",
input_dims, x_dims, y_dims));
}
// broadcast mode check
PADDLE_ENFORCE_EQ(
x_dims[1], y_dims[0],
platform::errors::InvalidArgument(
"The input tensor X's width must be equal with matrix Y' height. "
"But received X's shape = [%s], Y's shape = [%s].",
x_dims[1], y_dims[0]));
auto* out = context.Output<Tensor>("Out");
out->mutable_data<T>({x_dims[0], y_dims[1]}, context.GetPlace());
float alpha = context.template Attr<float>("Alpha");
float beta = context.template Attr<float>("Beta");
auto blas = phi::funcs::GetBlas<DeviceContext, T>(context);
// calc broadcast dim
Array2 bcast_dims;
bcast_dims[0] = x_dims[0] / input_dims[0];
bcast_dims[1] = y_dims[1] / input_dims[1];
VLOG(3) << "bcast_dims=[" << bcast_dims[0] << "," << bcast_dims[1] << "]";
// broadcast using eigen
auto eigen_input = EigenTensor<T, 2>::From(*input);
auto eigen_out = EigenTensor<T, 2>::From(*out);
auto& place =
*context.template device_context<DeviceContext>().eigen_device();
EigenBroadcast<std::decay_t<decltype(place)>, T, 2>::Eval(
place, eigen_out, eigen_input, bcast_dims);
blas.GEMM(false, false, x_dims[0], y_dims[1], x_dims[1], alpha,
x->data<T>(), x_dims[1], y->data<T>(), y_dims[1], beta,
out->data<T>(), y_dims[1]);
}
};
template <typename DeviceContext, typename T>
class AddMMGradKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto* x = ctx.Input<framework::LoDTensor>("X");
auto* y = ctx.Input<framework::LoDTensor>("Y");
auto* dout = ctx.Input<framework::LoDTensor>(framework::GradVarName("Out"));
auto in_dims = ctx.Input<framework::LoDTensor>("Input")->dims();
auto* dinput =
ctx.Output<framework::LoDTensor>(framework::GradVarName("Input"));
auto* dx = ctx.Output<framework::LoDTensor>(framework::GradVarName("X"));
auto* dy = ctx.Output<framework::LoDTensor>(framework::GradVarName("Y"));
float alpha = ctx.Attr<float>("Alpha");
float beta = ctx.Attr<float>("Beta");
int total_elems = 0;
VLOG(3) << "alpha: " << alpha << " beta: " << beta;
if (dinput != nullptr) {
dinput->set_lod(dout->lod());
}
if (dx != nullptr) {
dx->set_lod(x->lod());
}
if (dy != nullptr) {
dy->set_lod(y->lod());
}
auto& dev_ctx = ctx.template device_context<DeviceContext>();
auto blas = phi::funcs::GetBlas<DeviceContext, T>(dev_ctx);
if (dinput) {
dinput->mutable_data<T>(ctx.GetPlace());
total_elems = in_dims[0] * in_dims[1];
auto& place =
*ctx.template device_context<DeviceContext>().eigen_device();
auto eigen_dout = EigenTensor<T, 2>::From(*dout);
auto eigen_dinput = EigenTensor<T, 2>::From(*dinput);
bool row_compress = in_dims[0] != dout->dims()[0];
bool col_compress = in_dims[1] != dout->dims()[1];
auto eigen_dinput_shape = Array2(dinput->dims()[0], dinput->dims()[1]);
if (row_compress && col_compress) {
eigen_dinput.device(place) =
eigen_dout.sum().eval().reshape(eigen_dinput_shape);
} else if (row_compress) {
eigen_dinput.device(place) =
eigen_dout.sum(Array1(0)).eval().reshape(eigen_dinput_shape);
} else if (col_compress) {
eigen_dinput.device(place) =
eigen_dout.sum(Array1(1)).eval().reshape(eigen_dinput_shape);
} else {
blas.VCOPY(total_elems, dout->data<T>(), dinput->data<T>());
}
blas.SCAL(total_elems, beta, dinput->data<T>());
}
if (dx) {
dx->mutable_data<T>(ctx.GetPlace());
total_elems = x->dims()[0] * x->dims()[1];
// dx = dout * y'. dx: M x K, dout : M x N, y : K x N
blas.MatMul(*dout, false, *y, true, dx);
blas.SCAL(total_elems, alpha, dx->data<T>());
}
if (dy) {
dy->mutable_data<T>(ctx.GetPlace());
total_elems = x->dims()[1] * y->dims()[1];
// dy = x' * dout. dy K x N, dout : M x N, x : M x K
blas.MatMul(*x, true, *dout, false, dy);
blas.SCAL(total_elems, alpha, dy->data<T>());
}
}
};
} // namespace operators
} // namespace paddle
......@@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/cholesky_op.h"
#include "paddle/fluid/framework/op_registry.h"
namespace paddle {
namespace operators {
......@@ -111,11 +111,3 @@ REGISTER_OPERATOR(cholesky, ops::CholeskyOp, ops::CholeskyOpMaker,
ops::CholeskyGradOpMaker<paddle::framework::OpDesc>,
ops::CholeskyGradOpMaker<paddle::imperative::OpBase>);
REGISTER_OPERATOR(cholesky_grad, ops::CholeskyGradOp);
REGISTER_OP_CPU_KERNEL(cholesky, ops::CholeskyCPUKernel<float>,
ops::CholeskyCPUKernel<double>);
REGISTER_OP_CPU_KERNEL(
cholesky_grad,
ops::CholeskyGradKernel<paddle::platform::CPUDeviceContext, float>,
ops::CholeskyGradKernel<paddle::platform::CPUDeviceContext, double>);
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifndef PADDLE_WITH_HIP
// HIP not support cusolver
#include <thrust/device_vector.h>
#include <algorithm>
#include <vector>
#include "paddle/fluid/memory/memory.h"
#include "paddle/fluid/operators/cholesky_op.h"
#include "paddle/fluid/platform/dynload/cusolver.h"
namespace paddle {
namespace operators {
template <typename T>
class CholeskyGPUKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto& dev_ctx =
context.template device_context<platform::CUDADeviceContext>();
const Tensor* x = context.Input<Tensor>("X");
Tensor* out = context.Output<Tensor>("Out");
bool upper = context.Attr<bool>("upper");
auto& dims = x->dims();
int batch_count = 1;
for (int i = 0; i < dims.size() - 2; i++) {
batch_count *= dims[i];
}
int m = dims[dims.size() - 1];
int tensor_size = batch_count * m * m;
const auto* x_data = x->data<T>();
auto* out_data = out->mutable_data<T>(context.GetPlace());
// matrices are assumed to be stored in column-major order in cusolver
cublasFillMode_t uplo =
upper ? CUBLAS_FILL_MODE_LOWER : CUBLAS_FILL_MODE_UPPER;
// portf is inplace, thus copy the triangular part of the input matrices to
// the output and set the other triangular part to 0 firstly
platform::ForRange<platform::CUDADeviceContext> for_range(dev_ctx,
tensor_size);
if (upper) {
MatrixBandPartFunctor<T> matrix_band_part_functor(
m, m, /* num_lower_diags */ 0, /* num_upper_diags */ m, x_data,
out_data);
for_range(matrix_band_part_functor);
} else {
MatrixBandPartFunctor<T> matrix_band_part_functor(
m, m, /* num_lower_diags */ m, /* num_upper_diags */ 0, x_data,
out_data);
for_range(matrix_band_part_functor);
}
auto info = memory::Alloc(dev_ctx, sizeof(int) * batch_count);
auto* info_ptr = reinterpret_cast<int*>(info->ptr());
#if CUDA_VERSION >= 9020 && !defined(_WIN32)
if (batch_count > 1) {
std::vector<T*> output_ptrs;
for (int i = 0; i < batch_count; i++) {
output_ptrs.emplace_back(out_data + i * m * m);
}
thrust::device_vector<T*> dev_output_ptrs(output_ptrs.begin(),
output_ptrs.end());
PotrfBatched(dev_ctx, uplo, m,
thrust::raw_pointer_cast(dev_output_ptrs.data()), m,
info_ptr, batch_count);
// TODO(guosheng): There seems to a bug in cusolver potrfBatched and need
// to clear the upper triangle of the output. Remove this workaround once
// the bug is fixed.
if (!upper) {
MatrixBandPartFunctor<T> matrix_band_part_functor(
m, m, /* num_lower_diags */ m, /* num_upper_diags */ 0, out_data,
out_data);
for_range(matrix_band_part_functor);
}
} else {
#endif
for (int i = 0; i < batch_count; i++) {
Potrf(dev_ctx, uplo, m, out_data + i * m * m, m, info_ptr + i);
}
#if CUDA_VERSION >= 9020 && !defined(_WIN32)
}
#endif
// check the info
std::vector<int> error_info; // only for checking positive matrix
error_info.resize(batch_count);
memory::Copy(platform::CPUPlace(), error_info.data(), dev_ctx.GetPlace(),
info_ptr, sizeof(int) * batch_count, dev_ctx.stream());
for (int i = 0; i < batch_count; ++i) {
PADDLE_ENFORCE_EQ(error_info[i], 0,
platform::errors::PreconditionNotMet(
"For batch [%d]: U(%d, %d) is zero, singular U.", i,
error_info[i], error_info[i]));
}
}
void Potrf(const platform::CUDADeviceContext& dev_ctx, cublasFillMode_t uplo,
int n, T* A, int lda, int* info) const;
void PotrfBatched(const platform::CUDADeviceContext& dev_ctx,
cublasFillMode_t uplo, int n, T* Aarray[], int lda,
int* info_array, int batch_size) const;
};
#define FUNC_WITH_TYPES(m) m(float, S) m(double, D)
#define POTRF_INSTANCE(T, C) \
template <> \
void CholeskyGPUKernel<T>::Potrf(const platform::CUDADeviceContext& dev_ctx, \
cublasFillMode_t uplo, int n, T* A, \
int lda, int* info) const { \
auto handle = dev_ctx.cusolver_dn_handle(); \
int workspace_size = 0; \
PADDLE_ENFORCE_GPU_SUCCESS( \
platform::dynload::cusolverDn##C##potrf_bufferSize( \
handle, uplo, n, A, lda, &workspace_size)); \
auto workspace = memory::Alloc(dev_ctx, workspace_size); \
T* workspace_ptr = reinterpret_cast<T*>(workspace->ptr()); \
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cusolverDn##C##potrf( \
handle, uplo, n, A, lda, workspace_ptr, workspace_size, info)); \
}
FUNC_WITH_TYPES(POTRF_INSTANCE);
#if CUDA_VERSION >= 9020 && !defined(_WIN32)
#define POTRF_BATCH_INSTANCE(T, C) \
template <> \
void CholeskyGPUKernel<T>::PotrfBatched( \
const platform::CUDADeviceContext& dev_ctx, cublasFillMode_t uplo, \
int n, T* Aarray[], int lda, int* info_array, int batch_size) const { \
auto handle = dev_ctx.cusolver_dn_handle(); \
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cusolverDn##C##potrfBatched( \
handle, uplo, n, Aarray, lda, info_array, batch_size)); \
}
FUNC_WITH_TYPES(POTRF_BATCH_INSTANCE);
#endif
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP_CUDA_KERNEL(cholesky, ops::CholeskyGPUKernel<float>,
ops::CholeskyGPUKernel<double>);
REGISTER_OP_CUDA_KERNEL(
cholesky_grad,
ops::CholeskyGradKernel<paddle::platform::CUDADeviceContext, float>,
ops::CholeskyGradKernel<paddle::platform::CUDADeviceContext, double>);
#endif // not PADDLE_WITH_HIP
......@@ -12,9 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/operators/increment_op.h"
#include <string>
#include "paddle/fluid/framework/op_registry.h"
namespace paddle {
namespace framework {
......@@ -101,14 +99,3 @@ namespace ops = paddle::operators;
REGISTER_OPERATOR(increment, ops::IncrementOp, ops::IncrementOpMaker,
ops::IncrementGradOpMaker<paddle::framework::OpDesc>,
ops::IncrementGradOpMaker<paddle::imperative::OpBase>);
REGISTER_OP_CPU_KERNEL(
increment, ops::IncrementKernel<paddle::platform::CPUDeviceContext, float>,
ops::IncrementKernel<paddle::platform::CPUDeviceContext, double>,
ops::IncrementKernel<paddle::platform::CPUDeviceContext, int>,
ops::IncrementKernel<paddle::platform::CPUDeviceContext, int64_t>);
REGISTER_OP_CUDA_KERNEL(
increment, ops::IncrementKernel<paddle::platform::CUDADeviceContext, float>,
ops::IncrementKernel<paddle::platform::CUDADeviceContext, double>,
ops::IncrementKernel<paddle::platform::CUDADeviceContext, int>,
ops::IncrementKernel<paddle::platform::CUDADeviceContext, int64_t>);
......@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/operators/increment_op.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/platform/device/npu/npu_op_runner.h"
namespace paddle {
......
......@@ -11,7 +11,6 @@ distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/multinomial_op.h"
#include <algorithm>
#include <string>
......@@ -80,29 +79,6 @@ class MultinomialOp : public framework::OperatorWithKernel {
}
};
template <typename T>
class MultinomialOpKernel<platform::CPUDeviceContext, T>
: public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext &ctx) const override {
const auto x = ctx.Input<framework::Tensor>("X");
auto out = ctx.Output<framework::Tensor>("Out");
const int64_t num_samples = ctx.Attr<int>("num_samples");
const bool replacement = ctx.Attr<bool>("replacement");
auto *in_data = x->data<T>();
int64_t *out_data = out->mutable_data<int64_t>(ctx.GetPlace());
auto in_dims = x->dims();
int64_t in_rank = in_dims.size();
const int64_t num_categories = in_dims[in_rank - 1];
const int64_t num_distributions = in_rank > 1 ? in_dims[in_rank - 2] : 1;
MultinomialFunctor<T>(out_data, in_data, num_samples, replacement,
num_categories, num_distributions);
}
};
} // namespace operators
} // namespace paddle
......@@ -112,7 +88,3 @@ REGISTER_OPERATOR(
multinomial, ops::MultinomialOp, ops::MultinomialOpMaker,
paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>);
REGISTER_OP_CPU_KERNEL(
multinomial, ops::MultinomialOpKernel<plat::CPUDeviceContext, float>,
ops::MultinomialOpKernel<plat::CPUDeviceContext, double>);
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/phi/core/dense_tensor.h"
namespace phi {
template <typename T, typename Context>
void AddmmGradKernel(const Context& dev_ctx,
const DenseTensor& input,
const DenseTensor& x,
const DenseTensor& y,
const DenseTensor& out_grad,
float alpha,
float beta,
DenseTensor* input_grad,
DenseTensor* x_grad,
DenseTensor* y_grad);
} // namespace phi
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
......@@ -12,13 +12,19 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/addmm_op.h"
#pragma once
namespace ops = paddle::operators;
namespace plat = paddle::platform;
#include "paddle/phi/core/dense_tensor.h"
REGISTER_OP_CUDA_KERNEL(addmm, ops::AddMMKernel<plat::CUDADeviceContext, float>,
ops::AddMMKernel<plat::CUDADeviceContext, double>);
REGISTER_OP_CUDA_KERNEL(addmm_grad,
ops::AddMMGradKernel<plat::CUDADeviceContext, float>,
ops::AddMMGradKernel<plat::CUDADeviceContext, double>);
namespace phi {
template <typename T, typename Context>
void AddmmKernel(const Context& dev_ctx,
const DenseTensor& input,
const DenseTensor& x,
const DenseTensor& y,
float alpha,
float beta,
DenseTensor* out);
} // namespace phi
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/phi/core/dense_tensor.h"
namespace phi {
template <typename T, typename Context>
void CholeskyGradKernel(const Context& dev_ctx,
const DenseTensor& out,
const DenseTensor& out_grad,
bool upper,
DenseTensor* x_grad);
} // namespace phi
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/phi/core/dense_tensor.h"
namespace phi {
template <typename T, typename Context>
void CholeskyKernel(const Context& dev_ctx,
const DenseTensor& x,
bool upper,
DenseTensor* out);
} // namespace phi
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/phi/kernels/addmm_grad_kernel.h"
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/impl/addmm_grad_kernel_impl.h"
PD_REGISTER_KERNEL(
addmm_grad, CPU, ALL_LAYOUT, phi::AddmmGradKernel, float, double) {}
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/phi/kernels/addmm_kernel.h"
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/impl/addmm_kernel_impl.h"
PD_REGISTER_KERNEL(addmm, CPU, ALL_LAYOUT, phi::AddmmKernel, float, double) {}
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
......@@ -12,30 +12,11 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/eigen/eigen_function.h"
#include "paddle/phi/kernels/cholesky_grad_kernel.h"
namespace paddle {
namespace operators {
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/impl/cholesky_grad_kernel_impl.h"
template <typename DeviceContext, typename T>
class IncrementKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto* x_tensor = context.Input<framework::Tensor>("X");
auto* out_tensor = context.Output<framework::Tensor>("Out");
float step = context.Attr<float>("step");
out_tensor->mutable_data<T>(context.GetPlace());
auto& dev =
*context.template device_context<DeviceContext>().eigen_device();
EigenAdd<std::decay_t<decltype(dev)>, T>::Eval(
dev, framework::EigenScalar<T>::From(*out_tensor),
framework::EigenScalar<T>::From(*x_tensor), static_cast<T>(step));
}
};
} // namespace operators
} // namespace paddle
PD_REGISTER_KERNEL(
cholesky_grad, CPU, ALL_LAYOUT, phi::CholeskyGradKernel, float, double) {}
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/phi/kernels/cholesky_kernel.h"
#include "Eigen/Cholesky"
#include "Eigen/Core"
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/eigen/common.h"
namespace phi {
template <typename T, typename Context>
void CholeskyKernel(const Context& dev_ctx,
const DenseTensor& x,
bool upper,
DenseTensor* out) {
using EigenMatrix =
Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>;
using InputMatrixMap = Eigen::Map<const EigenMatrix>;
using OutputMatrixMap = Eigen::Map<EigenMatrix>;
auto& dims = x.dims();
int batch_count = 1;
for (int i = 0; i < dims.size() - 2; i++) {
batch_count *= dims[i];
}
auto m = dims[dims.size() - 1];
const auto* x_data = x.data<T>();
auto* out_data = dev_ctx.template Alloc<T>(out);
// Cholesky decomposition for each matrix, maybe can use multi threads
for (int i = 0; i < batch_count; i++) {
auto input = InputMatrixMap(x_data + i * m * m, m, m);
auto output = OutputMatrixMap(out_data + i * m * m, m, m);
if (upper) {
Eigen::LLT<
Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>,
Eigen::UpLoType::Upper>
llt_decomposition(input);
PADDLE_ENFORCE_EQ(llt_decomposition.info(),
Eigen::Success,
errors::InvalidArgument(
"Cholesky decomposition was not successful. The "
"%d-th input matrice "
"might not be not be positive definite.",
i));
output = llt_decomposition.matrixU();
} else {
Eigen::LLT<
Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>,
Eigen::UpLoType::Lower>
llt_decomposition(input);
PADDLE_ENFORCE_EQ(llt_decomposition.info(),
Eigen::Success,
errors::InvalidArgument(
"Cholesky decomposition was not successful. The "
"%d-th input matrice "
"might not be not be positive definite.",
i));
output = llt_decomposition.matrixL();
}
}
}
} // namespace phi
PD_REGISTER_KERNEL(
cholesky, CPU, ALL_LAYOUT, phi::CholeskyKernel, float, double) {}
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/phi/kernels/increment_kernel.h"
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/impl/increment_kernel_impl.h"
PD_REGISTER_KERNEL(increment,
CPU,
ALL_LAYOUT,
phi::IncrementKernel,
float,
double,
int,
int64_t) {}
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/phi/kernels/multinomial_kernel.h"
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
namespace phi {
template <typename T, typename Context>
void MultinomialKernel(const Context& dev_ctx,
const DenseTensor& x,
int num_samples,
bool replacement,
DenseTensor* out) {
auto* in_data = x.data<T>();
int64_t* out_data = dev_ctx.template Alloc<int64_t>(out);
auto in_dims = x.dims();
int64_t in_rank = in_dims.size();
const int64_t num_categories = in_dims[in_rank - 1];
const int64_t num_distributions = in_rank > 1 ? in_dims[in_rank - 2] : 1;
MultinomialFunctor<T>(out_data,
in_data,
num_samples,
replacement,
num_categories,
num_distributions);
}
} // namespace phi
PD_REGISTER_KERNEL(
multinomial, CPU, ALL_LAYOUT, phi::MultinomialKernel, float, double) {}
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/phi/kernels/addmm_grad_kernel.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/impl/addmm_grad_kernel_impl.h"
PD_REGISTER_KERNEL(
addmm_grad, GPU, ALL_LAYOUT, phi::AddmmGradKernel, float, double) {}
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/phi/kernels/addmm_kernel.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/impl/addmm_kernel_impl.h"
PD_REGISTER_KERNEL(addmm, GPU, ALL_LAYOUT, phi::AddmmKernel, float, double) {}
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/cholesky_grad_kernel.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/impl/cholesky_grad_kernel_impl.h"
PD_REGISTER_KERNEL(
cholesky_grad, GPU, ALL_LAYOUT, phi::CholeskyGradKernel, float, double) {}
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifndef PADDLE_WITH_HIP
// HIP not support cusolver
#include "paddle/phi/kernels/cholesky_kernel.h"
#include <thrust/device_vector.h>
#include <algorithm>
#include <vector>
#include "paddle/fluid/memory/memory.h"
#include "paddle/fluid/platform/for_range.h"
#include "paddle/phi/backends/dynload/cusolver.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
namespace phi {
template <typename T>
struct MatrixBandPartFunctor {
/*! Set output as input value outside a central band and 0 inside that band.
* That is: output[i, j, ..., m, n] = in_band(m, n) * input[i, j, ..., m, n]
* where: in_band(m, n) = (num_lower < 0 || (m-n) <= num_lower)) && (num_upper
* < 0 || (n-m) <= num_upper)
*/
MatrixBandPartFunctor(const int m,
const int n,
const int num_lower_diags,
const int num_upper_diags,
const T* input,
T* output)
: m_(m),
n_(n),
num_lower_diags_(num_lower_diags),
num_upper_diags_(num_upper_diags),
input_(input),
output_(output) {}
HOSTDEVICE void operator()(size_t index) const {
const int col = index % n_;
const int row = (index / n_) % m_;
const int band_start = (num_lower_diags_ < 0 ? 0 : row - num_lower_diags_);
const int band_end =
(num_upper_diags_ < 0 ? n_ : row + num_upper_diags_ + 1);
if (col < band_start || col >= band_end) {
output_[index] = static_cast<T>(0);
} else {
output_[index] = input_[index];
}
}
const int m_, n_, num_lower_diags_, num_upper_diags_;
const T* input_;
T* output_;
};
#define FUNC_WITH_TYPES(m) m(float, S) m(double, D)
#define POTRF_INSTANCE(T, C) \
void Potrf(const GPUContext& dev_ctx, \
cublasFillMode_t uplo, \
int n, \
T* A, \
int lda, \
int* info) { \
auto handle = dev_ctx.cusolver_dn_handle(); \
int workspace_size = 0; \
PADDLE_ENFORCE_GPU_SUCCESS(dynload::cusolverDn##C##potrf_bufferSize( \
handle, uplo, n, A, lda, &workspace_size)); \
auto workspace = paddle::memory::Alloc(dev_ctx, workspace_size); \
T* workspace_ptr = reinterpret_cast<T*>(workspace->ptr()); \
PADDLE_ENFORCE_GPU_SUCCESS(dynload::cusolverDn##C##potrf( \
handle, uplo, n, A, lda, workspace_ptr, workspace_size, info)); \
}
FUNC_WITH_TYPES(POTRF_INSTANCE);
#if CUDA_VERSION >= 9020 && !defined(_WIN32)
#define POTRF_BATCH_INSTANCE(T, C) \
void PotrfBatched(const GPUContext& dev_ctx, \
cublasFillMode_t uplo, \
int n, \
T* Aarray[], \
int lda, \
int* info_array, \
int batch_size) { \
auto handle = dev_ctx.cusolver_dn_handle(); \
PADDLE_ENFORCE_GPU_SUCCESS(dynload::cusolverDn##C##potrfBatched( \
handle, uplo, n, Aarray, lda, info_array, batch_size)); \
}
FUNC_WITH_TYPES(POTRF_BATCH_INSTANCE);
#endif
template <typename T, typename Context>
void CholeskyKernel(const Context& dev_ctx,
const DenseTensor& x,
bool upper,
DenseTensor* out) {
auto& dims = x.dims();
int batch_count = 1;
for (int i = 0; i < dims.size() - 2; i++) {
batch_count *= dims[i];
}
int m = dims[dims.size() - 1];
int tensor_size = batch_count * m * m;
const auto* x_data = x.data<T>();
auto* out_data = dev_ctx.template Alloc<T>(out);
// matrices are assumed to be stored in column-major order in cusolver
cublasFillMode_t uplo =
upper ? CUBLAS_FILL_MODE_LOWER : CUBLAS_FILL_MODE_UPPER;
// portf is inplace, thus copy the triangular part of the input matrices to
// the output and set the other triangular part to 0 firstly
paddle::platform::ForRange<GPUContext> for_range(dev_ctx, tensor_size);
if (upper) {
MatrixBandPartFunctor<T> matrix_band_part_functor(m,
m,
/* num_lower_diags */ 0,
/* num_upper_diags */ m,
x_data,
out_data);
for_range(matrix_band_part_functor);
} else {
MatrixBandPartFunctor<T> matrix_band_part_functor(m,
m,
/* num_lower_diags */ m,
/* num_upper_diags */ 0,
x_data,
out_data);
for_range(matrix_band_part_functor);
}
auto info = paddle::memory::Alloc(dev_ctx, sizeof(int) * batch_count);
auto* info_ptr = reinterpret_cast<int*>(info->ptr());
#if CUDA_VERSION >= 9020 && !defined(_WIN32)
if (batch_count > 1) {
std::vector<T*> output_ptrs;
for (int i = 0; i < batch_count; i++) {
output_ptrs.emplace_back(out_data + i * m * m);
}
thrust::device_vector<T*> dev_output_ptrs(output_ptrs.begin(),
output_ptrs.end());
PotrfBatched(dev_ctx,
uplo,
m,
thrust::raw_pointer_cast(dev_output_ptrs.data()),
m,
info_ptr,
batch_count);
// TODO(guosheng): There seems to a bug in cusolver potrfBatched and need
// to clear the upper triangle of the output. Remove this workaround once
// the bug is fixed.
if (!upper) {
MatrixBandPartFunctor<T> matrix_band_part_functor(m,
m,
/* num_lower_diags */ m,
/* num_upper_diags */ 0,
out_data,
out_data);
for_range(matrix_band_part_functor);
}
} else {
#endif
for (int i = 0; i < batch_count; i++) {
Potrf(dev_ctx, uplo, m, out_data + i * m * m, m, info_ptr + i);
}
#if CUDA_VERSION >= 9020 && !defined(_WIN32)
}
#endif
// check the info
std::vector<int> error_info; // only for checking positive matrix
error_info.resize(batch_count);
paddle::memory::Copy(CPUPlace(),
error_info.data(),
dev_ctx.GetPlace(),
info_ptr,
sizeof(int) * batch_count,
dev_ctx.stream());
for (int i = 0; i < batch_count; ++i) {
PADDLE_ENFORCE_EQ(error_info[i],
0,
errors::PreconditionNotMet(
"For batch [%d]: U(%d, %d) is zero, singular U.",
i,
error_info[i],
error_info[i]));
}
}
} // namespace phi
PD_REGISTER_KERNEL(cholesky, // cuda_only
GPU,
ALL_LAYOUT,
phi::CholeskyKernel,
float,
double) {}
#endif // not PADDLE_WITH_HIP
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/phi/kernels/increment_kernel.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/impl/increment_kernel_impl.h"
PD_REGISTER_KERNEL(increment,
GPU,
ALL_LAYOUT,
phi::IncrementKernel,
float,
double,
int,
int64_t) {}
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
......@@ -16,24 +16,25 @@ limitations under the License. */
// To-do(qili93): fix this after issue resolved
// https://github.com/ROCmSoftwarePlatform/rocPRIM/issues/202
#include "paddle/phi/kernels/multinomial_kernel.h"
#include <thrust/execution_policy.h>
#include <thrust/random.h>
#include <thrust/scan.h>
#include <thrust/transform.h>
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/operators/multinomial_op.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/transform.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/eigen/common.h"
namespace paddle {
namespace operators {
namespace phi {
template <typename T>
__global__ void NormalizeProbability(T* norm_probs, const T* in_data,
T* sum_rows, int64_t num_distributions,
__global__ void NormalizeProbability(T* norm_probs,
const T* in_data,
T* sum_rows,
int64_t num_distributions,
int64_t num_categories) {
int id = threadIdx.x + blockIdx.x * blockDim.x +
blockIdx.y * gridDim.x * blockDim.x;
......@@ -57,7 +58,8 @@ __global__ void GetCumulativeProbs(T* norm_probs_data,
int64_t num_categories,
T* cumulative_probs) {
int id = blockIdx.x;
thrust::inclusive_scan(thrust::device, norm_probs_data + id * num_categories,
thrust::inclusive_scan(thrust::device,
norm_probs_data + id * num_categories,
norm_probs_data + (id + 1) * num_categories,
cumulative_probs + id * num_categories);
}
......@@ -77,8 +79,10 @@ struct RandomGeneratorCudaFunctor {
};
template <typename T>
__device__ int binarySearchFunctor(T* cumulative_probs, T* norm_probs_data,
int num_categories, T rng_number) {
__device__ int binarySearchFunctor(T* cumulative_probs,
T* norm_probs_data,
int num_categories,
T rng_number) {
int left = 0;
int right = num_categories;
......@@ -104,9 +108,13 @@ __device__ int binarySearchFunctor(T* cumulative_probs, T* norm_probs_data,
template <typename T>
__global__ void sampleMultinomialWithReplacement(
T* rng_data, const int64_t num_samples, int64_t* out_data,
const int64_t num_distributions, const int64_t num_categories,
T* cumulative_probs, T* norm_probs_data) {
T* rng_data,
const int64_t num_samples,
int64_t* out_data,
const int64_t num_distributions,
const int64_t num_categories,
T* cumulative_probs,
T* norm_probs_data) {
// use binary search to get the selected category sample id.
// let cumulative_probs[id-1] < rng_data < cumulative_probs[id].
......@@ -118,153 +126,163 @@ __global__ void sampleMultinomialWithReplacement(
T rng_number = rng_data[sample + dist * num_samples];
// Find the bucket that a uniform random number lies in
int selected_category = binarySearchFunctor<T>(
cumulative_probs + dist * num_categories,
norm_probs_data + dist * num_categories, num_categories, rng_number);
int selected_category =
binarySearchFunctor<T>(cumulative_probs + dist * num_categories,
norm_probs_data + dist * num_categories,
num_categories,
rng_number);
out_data[sample + dist * num_samples] = selected_category;
}
}
template <typename T>
class MultinomialOpKernel<platform::CUDADeviceContext, T>
: public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
const auto x = ctx.Input<framework::Tensor>("X");
auto out = ctx.Output<framework::Tensor>("Out");
const int64_t num_samples = ctx.Attr<int>("num_samples");
const bool replacement = ctx.Attr<bool>("replacement");
auto* in_data = x->data<T>();
int64_t* out_data = out->mutable_data<int64_t>(ctx.GetPlace());
auto in_dims = x->dims();
int64_t in_rank = in_dims.size();
const int64_t num_categories = in_dims[in_rank - 1];
const int64_t num_distributions = in_rank > 1 ? in_dims[in_rank - 2] : 1;
// If replacement is False, it's not a replaceable sample. Every category
// can
// be used only once. So after every sample, probability of the distribution
// will change. The implementation can't be parallelizable. Thus, call CPU
// implementation ``MultinomialFunctor`` to sample the distribution.
if (!replacement) {
int64_t in_data_numel = x->numel();
int64_t out_data_numel = out->numel();
T* cpu_in_data = new T[in_data_numel];
int64_t* cpu_out_data = new int64_t[out_data_numel];
template <typename T, typename Context>
void MultinomialKernel(const Context& dev_ctx,
const DenseTensor& x,
int num_samples,
bool replacement,
DenseTensor* out) {
auto* in_data = x.data<T>();
int64_t* out_data = dev_ctx.template Alloc<int64_t>(out);
auto in_dims = x.dims();
int64_t in_rank = in_dims.size();
const int64_t num_categories = in_dims[in_rank - 1];
const int64_t num_distributions = in_rank > 1 ? in_dims[in_rank - 2] : 1;
// If replacement is False, it's not a replaceable sample. Every category
// can
// be used only once. So after every sample, probability of the distribution
// will change. The implementation can't be parallelizable. Thus, call CPU
// implementation ``MultinomialFunctor`` to sample the distribution.
if (!replacement) {
int64_t in_data_numel = x.numel();
int64_t out_data_numel = out->numel();
T* cpu_in_data = new T[in_data_numel];
int64_t* cpu_out_data = new int64_t[out_data_numel];
#ifdef PADDLE_WITH_HIP
hipMemcpy(cpu_in_data, in_data, in_data_numel * sizeof(T),
hipMemcpyDeviceToHost);
hipMemcpy(
cpu_in_data, in_data, in_data_numel * sizeof(T), hipMemcpyDeviceToHost);
#else
cudaMemcpy(cpu_in_data, in_data, in_data_numel * sizeof(T),
cudaMemcpyDeviceToHost);
cudaMemcpy(cpu_in_data,
in_data,
in_data_numel * sizeof(T),
cudaMemcpyDeviceToHost);
#endif
MultinomialFunctor<T>(cpu_out_data, cpu_in_data, num_samples, replacement,
num_categories, num_distributions);
MultinomialFunctor<T>(cpu_out_data,
cpu_in_data,
num_samples,
replacement,
num_categories,
num_distributions);
#ifdef PADDLE_WITH_HIP
hipMemcpy(out_data, cpu_out_data, out_data_numel * sizeof(int64_t),
hipMemcpyHostToDevice);
hipMemcpy(out_data,
cpu_out_data,
out_data_numel * sizeof(int64_t),
hipMemcpyHostToDevice);
#else
cudaMemcpy(out_data, cpu_out_data, out_data_numel * sizeof(int64_t),
cudaMemcpyHostToDevice);
cudaMemcpy(out_data,
cpu_out_data,
out_data_numel * sizeof(int64_t),
cudaMemcpyHostToDevice);
#endif
delete[] cpu_in_data;
delete[] cpu_out_data;
return;
}
// Sum of input may not be 1. To get probability in range [0, 1], calculate
// sum of each row of input, and then use the sum to normalize the input.
// sum_row_data: sum of each row
framework::Tensor sum_rows_tensor;
auto* sum_rows_data =
sum_rows_tensor.mutable_data<T>({num_distributions}, ctx.GetPlace());
auto& place = *ctx.template device_context<platform::CUDADeviceContext>()
.eigen_device();
if (num_distributions == 1) {
auto eigen_input = framework::EigenVector<T>::Flatten(*x);
auto eigen_sum_rows = framework::EigenVector<T>::Flatten(sum_rows_tensor);
eigen_sum_rows.device(place) =
eigen_input.sum(Eigen::DSizes<int, 1>(1))
.eval()
.reshape(Eigen::DSizes<int, 1>(sum_rows_tensor.dims()[0]));
} else {
auto eigen_input = framework::EigenMatrix<T>::From(*x);
auto eigen_sum_rows = framework::EigenVector<T>::Flatten(sum_rows_tensor);
eigen_sum_rows.device(place) = eigen_input.sum(Eigen::DSizes<int, 1>(1));
}
delete[] cpu_in_data;
delete[] cpu_out_data;
return;
}
// Normalize row of each distribution to get the probability in range [0,
// 1].
// norm_probs_data: probability of the distribution
framework::Tensor norm_probs_tensor;
auto* norm_probs_data = norm_probs_tensor.mutable_data<T>(
{num_distributions, num_categories}, ctx.GetPlace());
// number of threads in a block is min(num_categories, 512)
dim3 block_norm(num_categories < 512 ? num_categories : 512);
dim3 grid_norm((num_distributions * num_categories - 1) / block_norm.x + 1);
NormalizeProbability<
T><<<grid_norm, block_norm, 0, ctx.cuda_device_context().stream()>>>(
norm_probs_data, in_data, sum_rows_data, num_distributions,
num_categories);
// Get cumulative probability of each distribution. It's the same function
// of
// ``cumsum`` op.
framework::Tensor cumulative_probs_tensor;
auto* cumulative_probs = cumulative_probs_tensor.mutable_data<T>(
{num_distributions, num_categories}, ctx.GetPlace());
dim3 block_cumsum(1);
dim3 grid_cumsum(num_distributions);
GetCumulativeProbs<T><<<grid_cumsum, block_cumsum, 0,
ctx.cuda_device_context().stream()>>>(
norm_probs_data, num_distributions, num_categories, cumulative_probs);
// Generate random number for each sample.
std::random_device rd;
auto seed = rd();
framework::Tensor rng_data_tensor;
auto* rng_data = rng_data_tensor.mutable_data<T>(
{num_distributions, num_samples}, ctx.GetPlace());
thrust::counting_iterator<int64_t> index_sequence_begin(0);
platform::Transform<platform::CUDADeviceContext> trans;
auto* context =
static_cast<const platform::CUDADeviceContext*>(&ctx.device_context());
trans(*context, index_sequence_begin,
index_sequence_begin + num_distributions * num_samples, rng_data,
RandomGeneratorCudaFunctor<T>(seed));
// Sample the multinomial distributions.
dim3 block_sample(128);
dim3 grid_sample((num_samples - 1) / block_sample.x + 1, num_distributions);
sampleMultinomialWithReplacement<T><<<grid_sample, block_sample, 0,
ctx.cuda_device_context().stream()>>>(
rng_data, num_samples, out_data, num_distributions, num_categories,
cumulative_probs, norm_probs_data);
// Sum of input may not be 1. To get probability in range [0, 1], calculate
// sum of each row of input, and then use the sum to normalize the input.
// sum_row_data: sum of each row
DenseTensor sum_rows_tensor;
sum_rows_tensor.Resize({num_distributions});
auto* sum_rows_data = dev_ctx.template Alloc<T>(&sum_rows_tensor);
auto& place = *dev_ctx.eigen_device();
if (num_distributions == 1) {
auto eigen_input = EigenVector<T>::Flatten(x);
auto eigen_sum_rows = EigenVector<T>::Flatten(sum_rows_tensor);
eigen_sum_rows.device(place) =
eigen_input.sum(Eigen::DSizes<int, 1>(1))
.eval()
.reshape(Eigen::DSizes<int, 1>(sum_rows_tensor.dims()[0]));
} else {
auto eigen_input = EigenMatrix<T>::From(x);
auto eigen_sum_rows = EigenVector<T>::Flatten(sum_rows_tensor);
eigen_sum_rows.device(place) = eigen_input.sum(Eigen::DSizes<int, 1>(1));
}
};
} // namespace operators
} // namespace paddle
// Normalize row of each distribution to get the probability in range [0,
// 1].
// norm_probs_data: probability of the distribution
DenseTensor norm_probs_tensor;
norm_probs_tensor.Resize({num_distributions, num_categories});
auto* norm_probs_data = dev_ctx.template Alloc<T>(&norm_probs_tensor);
// number of threads in a block is min(num_categories, 512)
dim3 block_norm(num_categories < 512 ? num_categories : 512);
dim3 grid_norm((num_distributions * num_categories - 1) / block_norm.x + 1);
NormalizeProbability<T><<<grid_norm, block_norm, 0, dev_ctx.stream()>>>(
norm_probs_data,
in_data,
sum_rows_data,
num_distributions,
num_categories);
// Get cumulative probability of each distribution. It's the same function
// of
// ``cumsum`` op.
DenseTensor cumulative_probs_tensor;
cumulative_probs_tensor.Resize({num_distributions, num_categories});
auto* cumulative_probs = dev_ctx.template Alloc<T>(&cumulative_probs_tensor);
dim3 block_cumsum(1);
dim3 grid_cumsum(num_distributions);
GetCumulativeProbs<T><<<grid_cumsum, block_cumsum, 0, dev_ctx.stream()>>>(
norm_probs_data, num_distributions, num_categories, cumulative_probs);
// Generate random number for each sample.
std::random_device rd;
auto seed = rd();
DenseTensor rng_data_tensor;
rng_data_tensor.Resize({num_distributions, num_samples});
auto* rng_data = dev_ctx.template Alloc<T>(&rng_data_tensor);
thrust::counting_iterator<int64_t> index_sequence_begin(0);
paddle::platform::Transform<GPUContext> trans;
trans(dev_ctx,
index_sequence_begin,
index_sequence_begin + num_distributions * num_samples,
rng_data,
RandomGeneratorCudaFunctor<T>(seed));
// Sample the multinomial distributions.
dim3 block_sample(128);
dim3 grid_sample((num_samples - 1) / block_sample.x + 1, num_distributions);
sampleMultinomialWithReplacement<
T><<<grid_sample, block_sample, 0, dev_ctx.stream()>>>(rng_data,
num_samples,
out_data,
num_distributions,
num_categories,
cumulative_probs,
norm_probs_data);
}
namespace ops = paddle::operators;
namespace plat = paddle::platform;
} // namespace phi
REGISTER_OP_CUDA_KERNEL(
multinomial, ops::MultinomialOpKernel<plat::CUDADeviceContext, double>,
ops::MultinomialOpKernel<plat::CUDADeviceContext, float>);
PD_REGISTER_KERNEL(multinomial, // cuda_only
GPU,
ALL_LAYOUT,
phi::MultinomialKernel,
float,
double) {}
#endif
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/phi/kernels/addmm_grad_kernel.h"
#include <type_traits>
#include "paddle/phi/kernels/funcs/blas/blas.h"
#include "paddle/phi/kernels/funcs/eigen/common.h"
#include "paddle/phi/kernels/funcs/eigen/eigen_function.h"
namespace phi {
template <typename T,
size_t D,
int MajorType = Eigen::RowMajor,
typename IndexType = Eigen::DenseIndex>
using PhiEigenTensor = EigenTensor<T, D, MajorType, IndexType>;
using Array1 = Eigen::DSizes<Eigen::DenseIndex, 1>;
using Array2 = Eigen::DSizes<Eigen::DenseIndex, 2>;
template <typename T, typename Context>
void AddmmGradKernel(const Context& dev_ctx,
const DenseTensor& input,
const DenseTensor& x,
const DenseTensor& y,
const DenseTensor& out_grad,
float alpha,
float beta,
DenseTensor* input_grad,
DenseTensor* x_grad,
DenseTensor* y_grad) {
auto in_dims = input.dims();
int total_elems = 0;
VLOG(3) << "alpha: " << alpha << " beta: " << beta;
if (input_grad != nullptr) {
input_grad->set_lod(out_grad.lod());
}
if (x_grad != nullptr) {
x_grad->set_lod(x.lod());
}
if (y_grad != nullptr) {
y_grad->set_lod(y.lod());
}
auto blas = funcs::GetBlas<Context, T>(dev_ctx);
if (input_grad) {
dev_ctx.template Alloc<T>(input_grad);
total_elems = in_dims[0] * in_dims[1];
auto& place = *dev_ctx.eigen_device();
auto eigen_dout = PhiEigenTensor<T, 2>::From(out_grad);
auto eigen_dinput = PhiEigenTensor<T, 2>::From(*input_grad);
bool row_compress = in_dims[0] != out_grad.dims()[0];
bool col_compress = in_dims[1] != out_grad.dims()[1];
auto eigen_dinput_shape =
Array2(input_grad->dims()[0], input_grad->dims()[1]);
if (row_compress && col_compress) {
eigen_dinput.device(place) =
eigen_dout.sum().eval().reshape(eigen_dinput_shape);
} else if (row_compress) {
eigen_dinput.device(place) =
eigen_dout.sum(Array1(0)).eval().reshape(eigen_dinput_shape);
} else if (col_compress) {
eigen_dinput.device(place) =
eigen_dout.sum(Array1(1)).eval().reshape(eigen_dinput_shape);
} else {
blas.VCOPY(total_elems, out_grad.data<T>(), input_grad->data<T>());
}
blas.SCAL(total_elems, beta, input_grad->data<T>());
}
if (x_grad) {
dev_ctx.template Alloc<T>(x_grad);
total_elems = x.dims()[0] * x.dims()[1];
// x_grad = out_grad * y'. x_grad: M x K, out_grad : M x N, y : K x N
blas.MatMul(out_grad, false, y, true, x_grad);
blas.SCAL(total_elems, alpha, x_grad->data<T>());
}
if (y_grad) {
dev_ctx.template Alloc<T>(y_grad);
total_elems = x.dims()[1] * y.dims()[1];
// y_grad = x' * out_grad. y_grad K x N, out_grad : M x N, x : M x K
blas.MatMul(x, true, out_grad, false, y_grad);
blas.SCAL(total_elems, alpha, y_grad->data<T>());
}
}
} // namespace phi
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/phi/kernels/addmm_kernel.h"
#include <type_traits>
#include "paddle/phi/kernels/funcs/blas/blas.h"
#include "paddle/phi/kernels/funcs/eigen/common.h"
#include "paddle/phi/kernels/funcs/eigen/eigen_function.h"
namespace phi {
template <typename T,
size_t D,
int MajorType = Eigen::RowMajor,
typename IndexType = Eigen::DenseIndex>
using PhiEigenTensor = EigenTensor<T, D, MajorType, IndexType>;
using Array1 = Eigen::DSizes<Eigen::DenseIndex, 1>;
using Array2 = Eigen::DSizes<Eigen::DenseIndex, 2>;
template <typename T, typename Context>
void AddmmKernel(const Context& dev_ctx,
const DenseTensor& input,
const DenseTensor& x,
const DenseTensor& y,
float alpha,
float beta,
DenseTensor* out) {
auto input_dims = input.dims();
auto x_dims = x.dims();
auto y_dims = y.dims();
// broadcast mode check
if (x_dims[0] != input_dims[0]) {
PADDLE_ENFORCE_EQ(input_dims[0],
1,
errors::InvalidArgument(
"When x_dims[0] is not equal with input_dims[0], "
"input_dims[0] must be 1 but got %s",
input_dims[0]));
PADDLE_ENFORCE_EQ(y_dims[1] == input_dims[1] || input_dims[1] == 1,
true,
errors::InvalidArgument(
"The input tensor shape mismatch, input shape=[%s], "
"x shape=[%s], y shape=[%s]",
input_dims,
x_dims,
y_dims));
}
// broadcast mode check
if (y_dims[1] != input_dims[1]) {
PADDLE_ENFORCE_EQ(input_dims[1],
1,
errors::InvalidArgument(
"When y_dims[1] is not equal with input_dims[0], "
"input_dims[0] must be 1 but got %s",
input_dims[1]));
PADDLE_ENFORCE_EQ(x_dims[0] == input_dims[0] || input_dims[0] == 1,
true,
errors::InvalidArgument(
"The input tensor shape mismatch, input shape=[%s], "
"x shape=[%s], y shape=[%s]",
input_dims,
x_dims,
y_dims));
}
// broadcast mode check
PADDLE_ENFORCE_EQ(
x_dims[1],
y_dims[0],
errors::InvalidArgument(
"The input tensor X's width must be equal with matrix Y' height. "
"But received X's shape = [%s], Y's shape = [%s].",
x_dims[1],
y_dims[0]));
dev_ctx.template Alloc<T>(out);
auto blas = funcs::GetBlas<Context, T>(dev_ctx);
// calc broadcast dim
Array2 bcast_dims;
bcast_dims[0] = x_dims[0] / input_dims[0];
bcast_dims[1] = y_dims[1] / input_dims[1];
VLOG(3) << "bcast_dims=[" << bcast_dims[0] << "," << bcast_dims[1] << "]";
// broadcast using eigen
auto eigen_input = PhiEigenTensor<T, 2>::From(input);
auto eigen_out = PhiEigenTensor<T, 2>::From(*out);
auto& place = *dev_ctx.eigen_device();
funcs::EigenBroadcast<std::decay_t<decltype(place)>, T, 2>::Eval(
place, eigen_out, eigen_input, bcast_dims);
blas.GEMM(false,
false,
x_dims[0],
y_dims[1],
x_dims[1],
alpha,
x.data<T>(),
x_dims[1],
y.data<T>(),
y_dims[1],
beta,
out->data<T>(),
y_dims[1]);
}
} // namespace phi
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
......@@ -14,74 +14,50 @@ limitations under the License. */
#pragma once
#include <numeric>
#include <vector>
#include "Eigen/Cholesky"
#include "Eigen/Core"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/transpose_op.h"
#include "paddle/phi/kernels/cholesky_grad_kernel.h"
#include "paddle/fluid/platform/for_range.h"
#include "paddle/phi/kernels/funcs/blas/blas.h"
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
template <typename T>
class CholeskyCPUKernel : public framework::OpKernel<T> {
public:
// different with EigenMatrix in framework/eigen.h
using EigenMatrix =
Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>;
using InputMatrixMap = Eigen::Map<const EigenMatrix>;
using OutputMatrixMap = Eigen::Map<EigenMatrix>;
void Compute(const framework::ExecutionContext& context) const override {
const Tensor* x = context.Input<Tensor>("X");
Tensor* out = context.Output<Tensor>("Out");
bool upper = context.Attr<bool>("upper");
auto& dims = x->dims();
int batch_count = 1;
for (int i = 0; i < dims.size() - 2; i++) {
batch_count *= dims[i];
}
auto m = dims[dims.size() - 1];
const auto* x_data = x->data<T>();
auto* out_data = out->mutable_data<T>(context.GetPlace());
// Cholesky decomposition for each matrix, maybe can use multi threads
for (int i = 0; i < batch_count; i++) {
auto input = InputMatrixMap(x_data + i * m * m, m, m);
auto output = OutputMatrixMap(out_data + i * m * m, m, m);
if (upper) {
Eigen::LLT<
Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>,
Eigen::UpLoType::Upper>
llt_decomposition(input);
PADDLE_ENFORCE_EQ(llt_decomposition.info(), Eigen::Success,
platform::errors::InvalidArgument(
"Cholesky decomposition was not successful. The "
"%d-th input matrice "
"might not be not be positive definite.",
i));
output = llt_decomposition.matrixU();
} else {
Eigen::LLT<
Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>,
Eigen::UpLoType::Lower>
llt_decomposition(input);
PADDLE_ENFORCE_EQ(llt_decomposition.info(), Eigen::Success,
platform::errors::InvalidArgument(
"Cholesky decomposition was not successful. The "
"%d-th input matrice "
"might not be not be positive definite.",
i));
output = llt_decomposition.matrixL();
}
}
namespace phi {
template <typename Context, typename T>
inline void TransCompute(const int dim,
const Context& dev_ctx,
const DenseTensor& in,
DenseTensor* out,
const std::vector<int>& axis) {
switch (dim) {
case 1:
funcs::Transpose<Context, T, 1> trans1;
trans1(dev_ctx, in, out, axis);
break;
case 2:
funcs::Transpose<Context, T, 2> trans2;
trans2(dev_ctx, in, out, axis);
break;
case 3:
funcs::Transpose<Context, T, 3> trans3;
trans3(dev_ctx, in, out, axis);
break;
case 4:
funcs::Transpose<Context, T, 4> trans4;
trans4(dev_ctx, in, out, axis);
break;
case 5:
funcs::Transpose<Context, T, 5> trans5;
trans5(dev_ctx, in, out, axis);
break;
case 6:
funcs::Transpose<Context, T, 6> trans6;
trans6(dev_ctx, in, out, axis);
break;
default:
// for dim >= 7 situation
funcs::TransposeNormal<Context, T> trans_normal;
trans_normal(dev_ctx, in, out, axis);
}
};
}
/*! Use these functors to implement tril, triu, diagonal and other operators */
template <typename T>
......@@ -101,40 +77,6 @@ struct EyeFunctor {
T* output_;
};
template <typename T>
struct MatrixBandPartFunctor {
/*! Set output as input value outside a central band and 0 inside that band.
* That is: output[i, j, ..., m, n] = in_band(m, n) * input[i, j, ..., m, n]
* where: in_band(m, n) = (num_lower < 0 || (m-n) <= num_lower)) && (num_upper
* < 0 || (n-m) <= num_upper)
*/
MatrixBandPartFunctor(const int m, const int n, const int num_lower_diags,
const int num_upper_diags, const T* input, T* output)
: m_(m),
n_(n),
num_lower_diags_(num_lower_diags),
num_upper_diags_(num_upper_diags),
input_(input),
output_(output) {}
HOSTDEVICE void operator()(size_t index) const {
const int col = index % n_;
const int row = (index / n_) % m_;
const int band_start = (num_lower_diags_ < 0 ? 0 : row - num_lower_diags_);
const int band_end =
(num_upper_diags_ < 0 ? n_ : row + num_upper_diags_ + 1);
if (col < band_start || col >= band_end) {
output_[index] = static_cast<T>(0);
} else {
output_[index] = input_[index];
}
}
const int m_, n_, num_lower_diags_, num_upper_diags_;
const T* input_;
T* output_;
};
template <typename T>
struct MatrixSetDiagFunctor {
/*! Overwrite specified diagonals of output by the values in diagonal.
......@@ -145,9 +87,13 @@ struct MatrixSetDiagFunctor {
* and the num_diags diagonals has a up to down layout. Otherwise it has a
* shape [i, j, ..., max_diag_len].
*/
MatrixSetDiagFunctor(const int m, const int n, const int num_diags,
const int max_diag_len, const int upper_diag_index,
const T* diag, T* output)
MatrixSetDiagFunctor(const int m,
const int n,
const int num_diags,
const int max_diag_len,
const int upper_diag_index,
const T* diag,
T* output)
: m_(m),
n_(n),
num_diags_(num_diags),
......@@ -189,9 +135,14 @@ struct MatrixDiagPartFunctor {
/*! Similar to MatrixSetDiagFunctor but return the diagonals. diag_index=0
* refers to the main diagonal, positive value means superdiagonal and
* negative value means subdiagonal */
MatrixDiagPartFunctor(const int m, const int n, const int num_diags,
const int max_diag_len, const int upper_diag_index,
const T padding, const T* input, T* output)
MatrixDiagPartFunctor(const int m,
const int n,
const int num_diags,
const int max_diag_len,
const int upper_diag_index,
const T padding,
const T* input,
T* output)
: m_(m),
n_(n),
num_diags_(num_diags),
......@@ -237,10 +188,13 @@ struct MatrixBandPartScaleEndFunctor {
* 2. middle = matrix_set_diag(middle, diag * scalar)
* 3. middle = matrix_band_part(middle, -1, 0)
*/
MatrixBandPartScaleEndFunctor(const int m, const int n,
MatrixBandPartScaleEndFunctor(const int m,
const int n,
const int num_lower_diags,
const int num_upper_diags, const T scale,
const T* input, T* output)
const int num_upper_diags,
const T scale,
const T* input,
T* output)
: m_(m),
n_(n),
num_lower_diags_(num_lower_diags),
......@@ -283,92 +237,100 @@ struct AddtoScaleFunctor {
T* output_;
};
template <typename DeviceContext, typename T>
class CholeskyGradKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto* out = context.Input<Tensor>("Out");
auto* out_grad = context.Input<Tensor>(framework::GradVarName("Out"));
auto* x_grad = context.Output<Tensor>(framework::GradVarName("X"));
auto* x_grad_data = x_grad->mutable_data<T>(context.GetPlace());
bool upper = context.Attr<bool>("upper");
auto& dims = out->dims();
int batch_count = 1;
for (int i = 0; i < dims.size() - 2; i++) {
batch_count *= dims[i];
}
auto m = dims[dims.size() - 1];
int tensor_size = batch_count * m * m;
auto& dev_ctx = context.template device_context<DeviceContext>();
std::vector<int> axis(dims.size() - 2);
std::iota(axis.begin(), axis.end(), 0);
axis.insert(axis.end(), {dims.size() - 1, dims.size() - 2});
Tensor l, l_grad;
if (upper) {
l.mutable_data<T>(dims, context.GetPlace());
l_grad.mutable_data<T>(dims, context.GetPlace());
TransCompute<DeviceContext, T>(dims.size(), dev_ctx, *out, &l, axis);
TransCompute<DeviceContext, T>(dims.size(), dev_ctx, *out_grad, &l_grad,
axis);
} else {
l = *out;
l_grad = *out_grad;
}
auto* l_data = l.data<T>();
/*! refer to Iain Murray (2016); arXiv 1602.07527 */
/*! phi = matmul(L.transpose(-1, -2), grad) */
Tensor middle;
auto* middle_data = middle.mutable_data<T>(dims, context.GetPlace());
auto trans_desc = phi::funcs::CreateMatrixDescriptor(dims, 0, true);
auto no_trans_desc = phi::funcs::CreateMatrixDescriptor(dims, 0, false);
auto blas = phi::funcs::GetBlas<DeviceContext, T>(context);
blas.MatMul(l, trans_desc, l_grad, no_trans_desc, T(1), &middle, T(0));
/*! phi.tril_().diagonal(0, -2, -1).mul_(0.5) */
platform::ForRange<DeviceContext> for_range(dev_ctx, tensor_size);
MatrixBandPartScaleEndFunctor<T> matrix_band_part_scale_end_functor(
m, m, /* num_lower_diags */ m, /* num_upper_diags */ 0,
/* scale */ 0.5, middle_data, middle_data);
for_range(matrix_band_part_scale_end_functor);
// Compute inverse by solving the triangular linear system AX = B, where B
// is the identity matrix. The matrix X would be overwritten on B
Tensor identity;
auto* identity_data = identity.mutable_data<T>(dims, context.GetPlace());
EyeFunctor<T> eye_functor(m, m, identity_data);
for_range(eye_functor);
// TODO(guosheng): use trsmBatched for GPU
for (int i = 0; i < batch_count; i++) {
blas.TRSM(/*side*/ CblasLeft, /*uplo*/ CblasLower,
/*trans*/ CblasNoTrans, /*diag*/ CblasNonUnit, /*m*/ m, /*n*/ m,
/*alpha*/ T(1), l_data + i * m * m, /*lda*/ m,
identity_data + i * m * m, /*ldb*/ m);
}
Tensor& l_inverse = identity;
/*! x_grad = matmul(matmul(L_inverse.transpose(-1, -2), phi), L_inverse) */
Tensor middle1;
middle1.mutable_data<T>(dims, context.GetPlace());
blas.MatMul(l_inverse, trans_desc, middle, no_trans_desc, T(1), &middle1,
T(0));
blas.MatMul(middle1, no_trans_desc, l_inverse, no_trans_desc, T(1), x_grad,
T(0));
/*! x_grad.add(x_grad.transpose(-1, -2)).mul_(0.5) */
Tensor x_grad_trans;
auto* x_grad_trans_data =
x_grad_trans.mutable_data<T>(dims, context.GetPlace());
TransCompute<DeviceContext, T>(dims.size(), dev_ctx, *x_grad, &x_grad_trans,
axis);
AddtoScaleFunctor<T> addto_scale_functor(0.5, x_grad_trans_data,
x_grad_data);
for_range(addto_scale_functor);
template <typename T, typename Context>
void CholeskyGradKernel(const Context& dev_ctx,
const DenseTensor& out,
const DenseTensor& out_grad,
bool upper,
DenseTensor* x_grad) {
auto* x_grad_data = dev_ctx.template Alloc<T>(x_grad);
auto& dims = out.dims();
int batch_count = 1;
for (int i = 0; i < dims.size() - 2; i++) {
batch_count *= dims[i];
}
};
} // namespace operators
} // namespace paddle
auto m = dims[dims.size() - 1];
int tensor_size = batch_count * m * m;
std::vector<int> axis(dims.size() - 2);
std::iota(axis.begin(), axis.end(), 0);
axis.insert(axis.end(), {dims.size() - 1, dims.size() - 2});
DenseTensor l, l_grad;
if (upper) {
l.Resize(dims);
dev_ctx.template Alloc<T>(&l);
l_grad.Resize(dims);
dev_ctx.template Alloc<T>(&l_grad);
TransCompute<Context, T>(dims.size(), dev_ctx, out, &l, axis);
TransCompute<Context, T>(dims.size(), dev_ctx, out_grad, &l_grad, axis);
} else {
l = out;
l_grad = out_grad;
}
auto* l_data = l.data<T>();
/*! refer to Iain Murray (2016); arXiv 1602.07527 */
/*! phi = matmul(L.transpose(-1, -2), grad) */
DenseTensor middle;
middle.Resize(dims);
auto* middle_data = dev_ctx.template Alloc<T>(&middle);
auto trans_desc = funcs::CreateMatrixDescriptor(dims, 0, true);
auto no_trans_desc = funcs::CreateMatrixDescriptor(dims, 0, false);
auto blas = funcs::GetBlas<Context, T>(dev_ctx);
blas.MatMul(l, trans_desc, l_grad, no_trans_desc, T(1), &middle, T(0));
/*! phi.tril_().diagonal(0, -2, -1).mul_(0.5) */
paddle::platform::ForRange<Context> for_range(dev_ctx, tensor_size);
MatrixBandPartScaleEndFunctor<T> matrix_band_part_scale_end_functor(
m,
m,
/* num_lower_diags */ m,
/* num_upper_diags */ 0,
/* scale */ 0.5,
middle_data,
middle_data);
for_range(matrix_band_part_scale_end_functor);
// Compute inverse by solving the triangular linear system AX = B, where B
// is the identity matrix. The matrix X would be overwritten on B
DenseTensor identity;
identity.Resize(dims);
auto* identity_data = dev_ctx.template Alloc<T>(&identity);
EyeFunctor<T> eye_functor(m, m, identity_data);
for_range(eye_functor);
// TODO(guosheng): use trsmBatched for GPU
for (int i = 0; i < batch_count; i++) {
blas.TRSM(/*side*/ CblasLeft,
/*uplo*/ CblasLower,
/*trans*/ CblasNoTrans,
/*diag*/ CblasNonUnit,
/*m*/ m,
/*n*/ m,
/*alpha*/ T(1),
l_data + i * m * m,
/*lda*/ m,
identity_data + i * m * m,
/*ldb*/ m);
}
DenseTensor& l_inverse = identity;
/*! x_grad = matmul(matmul(L_inverse.transpose(-1, -2), phi), L_inverse) */
DenseTensor middle1;
middle1.Resize(dims);
dev_ctx.template Alloc<T>(&middle1);
blas.MatMul(
l_inverse, trans_desc, middle, no_trans_desc, T(1), &middle1, T(0));
blas.MatMul(
middle1, no_trans_desc, l_inverse, no_trans_desc, T(1), x_grad, T(0));
/*! x_grad.add(x_grad.transpose(-1, -2)).mul_(0.5) */
DenseTensor x_grad_trans;
x_grad_trans.Resize(dims);
auto* x_grad_trans_data = dev_ctx.template Alloc<T>(&x_grad_trans);
TransCompute<Context, T>(dims.size(), dev_ctx, *x_grad, &x_grad_trans, axis);
AddtoScaleFunctor<T> addto_scale_functor(0.5, x_grad_trans_data, x_grad_data);
for_range(addto_scale_functor);
}
} // namespace phi
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/phi/kernels/funcs/eigen/common.h"
#include "paddle/phi/kernels/funcs/eigen/eigen_function.h"
#include "paddle/phi/kernels/increment_kernel.h"
namespace phi {
template <typename T, typename Context>
void IncrementKernel(const Context& dev_ctx,
const DenseTensor& x,
float value,
DenseTensor* out) {
dev_ctx.template Alloc<T>(out);
auto& dev = *dev_ctx.eigen_device();
funcs::EigenAdd<std::decay_t<decltype(dev)>, T>::Eval(
dev,
EigenScalar<T>::From(*out),
EigenScalar<T>::From(x),
static_cast<T>(value));
}
} // namespace phi
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/phi/core/dense_tensor.h"
namespace phi {
template <typename T, typename Context>
void IncrementKernel(const Context& ctx,
const DenseTensor& x,
float value,
DenseTensor* out);
} // namespace phi
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <vector>
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/fluid/framework/generator.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/phi/core/hostdevice.h"
namespace paddle {
namespace operators {
namespace phi {
/**
* Samples a multinomial distribution given a probability input
*/
template <typename T, typename Context>
void MultinomialKernel(const Context& dev_ctx,
const DenseTensor& x,
int num_samples,
bool replacement,
DenseTensor* out);
template <typename T>
void MultinomialFunctor(int64_t* out_data, const T* in_data,
......@@ -35,7 +34,7 @@ void MultinomialFunctor(int64_t* out_data, const T* in_data,
std::vector<T> cumulative_probs(num_categories);
std::uniform_real_distribution<T> dist(0, 1);
auto gen_ptr = framework::DefaultCPUGenerator();
auto gen_ptr = paddle::framework::DefaultCPUGenerator();
auto engine = gen_ptr->GetCPUEngine();
for (int64_t i = 0; i < num_distributions; i++) {
......@@ -45,7 +44,7 @@ void MultinomialFunctor(int64_t* out_data, const T* in_data,
for (int64_t j = 0; j < num_categories; j++) {
prob_value = in_data[i * num_categories + j];
PADDLE_ENFORCE_GE(prob_value, 0.0,
platform::errors::InvalidArgument(
errors::InvalidArgument(
"The input of multinomial distribution "
"should be >= 0, but got %f.",
prob_value));
......@@ -57,13 +56,13 @@ void MultinomialFunctor(int64_t* out_data, const T* in_data,
cumulative_probs[j] = probs_sum;
}
PADDLE_ENFORCE_GT(probs_sum, 0.0,
platform::errors::InvalidArgument(
errors::InvalidArgument(
"The sum of one multinomial distribution "
"probability should be > 0, but got %f.",
probs_sum));
PADDLE_ENFORCE_EQ(
(replacement || (num_categories - num_zeros >= num_samples)), true,
platform::errors::InvalidArgument(
errors::InvalidArgument(
"When replacement is False, number of "
"samples should be less than non-zero "
"categories."));
......@@ -121,8 +120,4 @@ void MultinomialFunctor(int64_t* out_data, const T* in_data,
}
}
template <typename DeviceContext, typename T>
class MultinomialOpKernel;
} // namespace operators
} // namespace paddle
} // namespace phi
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/core/compat/op_utils.h"
namespace phi {
KernelSignature AddmmOpArgumentMapping(const ArgumentMappingContext& ctx) {
return KernelSignature(
"addmm", {"Input", "X", "Y"}, {"Alpha", "Beta"}, {"Out"});
}
KernelSignature AddmmGradOpArgumentMapping(const ArgumentMappingContext& ctx) {
return KernelSignature(
"addmm_grad",
{"Input", "X", "Y", GradVarName("Out")},
{"Alpha", "Beta"},
{GradVarName("Input"), GradVarName("X"), GradVarName("Y")});
}
} // namespace phi
PD_REGISTER_ARG_MAPPING_FN(addmm, phi::AddmmOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(addmm_grad, phi::AddmmGradOpArgumentMapping);
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/core/compat/op_utils.h"
namespace phi {
KernelSignature CholeskyOpArgumentMapping(const ArgumentMappingContext& ctx) {
return KernelSignature("cholesky", {"X"}, {"upper"}, {"Out"});
}
KernelSignature CholeskyGradOpArgumentMapping(
const ArgumentMappingContext& ctx) {
return KernelSignature("cholesky_grad",
{"Out", GradVarName("Out")},
{"upper"},
{GradVarName("X")});
}
} // namespace phi
PD_REGISTER_ARG_MAPPING_FN(cholesky, phi::CholeskyOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(cholesky_grad, phi::CholeskyGradOpArgumentMapping);
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册