Unverified commit 581b2c64, authored by From00, committed by GitHub

Move GumbelSoftmax OP to phi (#39873)

* Move GumbelSoftmax OP to phi

* platform::errors -> phi::errors; GumbelSoftmaxGradInferMeta -> backend.h/cc

* Use axis util in kernel impl

* Remove namespace platform::errors

* Use GetCPUEngine in Device Context
Parent caea126c
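The second bullet above ("platform::errors -> phi::errors") is a mechanical rename visible throughout the diff below; a minimal before/after sketch of the pattern, using the temperature check that appears in both the fluid and phi kernels:

// fluid (before):
PADDLE_ENFORCE_GT(temperature, 0,
                  platform::errors::InvalidArgument(
                      "The temperature must be greater than 0."));
// phi (after):
PADDLE_ENFORCE_GT(temperature, 0,
                  phi::errors::InvalidArgument(
                      "The temperature must be greater than 0."));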
@@ -12,10 +12,11 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/gumbel_softmax_op.h"
#include <string>
#include <unordered_map>
#include "paddle/fluid/operators/common_infer_shape_functions.h"
#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/phi/core/infermeta_utils.h"
#include "paddle/phi/infermeta/backward.h"
#include "paddle/phi/infermeta/unary.h"
namespace paddle {
namespace operators {
@@ -23,10 +24,6 @@ class GumbelSoftmaxOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
return UnaryOpUnchangedInferShapeCheckAxis(ctx);
}
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
@@ -71,20 +68,6 @@ Samples from the Gumbel-Softmax distribution and optionally discretizes.
class GumbelSoftmaxGradOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
OP_INOUT_CHECK(ctx->HasInput("Out"), "Input", "Out", "gumbel_softmax_grad");
OP_INOUT_CHECK(ctx->HasInput(framework::GradVarName("Out")), "Input",
"Out@GRAD", "gumbel_softmax_grad");
PADDLE_ENFORCE_EQ(
ctx->GetInputDim("Out"),
ctx->GetInputDim(framework::GradVarName("Out")),
platform::errors::InvalidArgument("Input(Out) and its gradients "
"should have the same shape."));
ctx->SetOutputDim(framework::GradVarName("X"),
ctx->GetInputDim(framework::GradVarName("Out")));
}
};
template <typename T>
@@ -107,17 +90,16 @@ class GumbelSoftmaxGradOpMaker : public framework::SingleGradOpMaker<T> {
namespace ops = paddle::operators;
DELCARE_INFER_SHAPE_FUNCTOR(gumbel_softmax, GumbelSoftmaxInferShapeFunctor,
PT_INFER_META(phi::GumbelSoftmaxInferMeta));
DELCARE_INFER_SHAPE_FUNCTOR(gumbel_softmax_grad,
GumbelSoftmaxGradInferShapeFunctor,
PT_INFER_META(phi::GumbelSoftmaxGradInferMeta));
REGISTER_OPERATOR(gumbel_softmax, ops::GumbelSoftmaxOp,
ops::GumbelSoftmaxOpMaker,
ops::GumbelSoftmaxGradOpMaker<paddle::framework::OpDesc>,
ops::GumbelSoftmaxGradOpMaker<paddle::imperative::OpBase>);
REGISTER_OPERATOR(gumbel_softmax_grad, ops::GumbelSoftmaxGradOp);
REGISTER_OP_CPU_KERNEL(
gumbel_softmax,
ops::GumbelSoftmaxKernel<paddle::platform::CPUDeviceContext, float>,
ops::GumbelSoftmaxKernel<paddle::platform::CPUDeviceContext, double>);
REGISTER_OP_CPU_KERNEL(
gumbel_softmax_grad,
ops::GumbelSoftmaxGradKernel<paddle::platform::CPUDeviceContext, float>,
ops::GumbelSoftmaxGradKernel<paddle::platform::CPUDeviceContext, double>);
ops::GumbelSoftmaxGradOpMaker<paddle::imperative::OpBase>,
GumbelSoftmaxInferShapeFunctor);
REGISTER_OPERATOR(gumbel_softmax_grad, ops::GumbelSoftmaxGradOp,
GumbelSoftmaxGradInferShapeFunctor);
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/fluid/framework/generator.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/math/softmax.h"
#include "paddle/phi/kernels/funcs/math_function.h"
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
using DDim = framework::DDim;
template <typename T, int MajorType = Eigen::RowMajor,
typename IndexType = Eigen::DenseIndex>
using EigenMatrix = framework::EigenMatrix<T, MajorType, IndexType>;
template <typename T, size_t D, int MajorType = Eigen::RowMajor,
typename IndexType = Eigen::DenseIndex>
using EigenTensor = framework::EigenTensor<T, D, MajorType, IndexType>;
static inline int CanonicalAxis(const int axis, const int rank) {
if (axis < 0) {
return axis + rank;
}
return axis;
}
static inline int SizeToAxis(const int axis, DDim dims) {
int size = 1;
for (int i = 0; i < axis; i++) {
size *= dims[i];
}
return size;
}
static inline int SizeFromAxis(const int axis, DDim dims) {
int size = 1;
for (int i = axis; i < dims.size(); i++) {
size *= dims[i];
}
return size;
}
static inline int SizeOutAxis(const int axis, DDim dims) {
int size = 1;
for (int i = axis + 1; i < dims.size(); i++) {
size *= dims[i];
}
return size;
}
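A quick worked example of these four helpers (the shape is chosen purely for illustration): for dims = {2, 3, 4} and axis = -1, CanonicalAxis returns 2 and the products split as follows:

// Hedged sketch; the shape {2, 3, 4} is hypothetical.
auto dims = framework::make_ddim({2, 3, 4});
int axis = CanonicalAxis(-1, dims.size());  // -1 + 3 == 2
int to = SizeToAxis(axis, dims);            // 2 * 3 == 6
int from = SizeFromAxis(axis, dims);        // 4
int out = SizeOutAxis(axis, dims);          // empty product == 1

The kernels below use size_to_axis x size_from_axis as a 2-D view of the input, with size_out_axis covering the dimensions after the softmax axis.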
template <typename DeviceContext, typename T, int64_t Rank>
struct ArgMaxFunctor {
void operator()(const DeviceContext& ctx, const Tensor& in,
Tensor* index_tensor, const int64_t& axis) {
auto in_eigen = EigenTensor<T, Rank>::From(in, in.dims());
auto index_eigen = EigenTensor<int, Rank - 1>::From(*index_tensor);
index_eigen = in_eigen.argmax(axis).template cast<int>();
}
};
template <typename DeviceContext, typename T>
struct GumbleNoiseGenerator;
template <typename DeviceContext, typename T>
struct OneHotGenerator;
template <typename T>
struct GumbleNoiseGenerator<platform::CPUDeviceContext, T> {
static void Transform(const platform::CPUDeviceContext& context,
const T* input_data, T* output_data, int size_to_axis,
int size_from_axis, const float temperature) {
// generate uniform random number
const int size = size_to_axis * size_from_axis;
std::uniform_real_distribution<T> dist(0.00001, 1);
auto engine = paddle::framework::GetCPURandomEngine(0);
Tensor random_tensor;
auto* random_data =
random_tensor.mutable_data<T>({size}, platform::CPUPlace());
for (int64_t i = 0; i < size; ++i) {
random_data[i] = dist(*engine);
}
// generate gumbel noise
framework::DDim dim_2d{size_to_axis, size_from_axis};
auto gumbel_noise_eigen = EigenMatrix<T>::From(random_tensor, dim_2d);
gumbel_noise_eigen = -(((-(gumbel_noise_eigen.log())).log()));
// add noise
for (int64_t i = 0; i < size_to_axis * size_from_axis; i++) {
output_data[i] = (input_data[i] + random_data[i]) / temperature;
}
}
};
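The Eigen expression above is the inverse-CDF sampler for Gumbel(0, 1): if u ~ Uniform(0, 1), then g = -log(-log(u)) follows a Gumbel distribution, which is why the uniform draw is bounded away from 0 (lower bound 0.00001) so neither logarithm is applied to 0. A hedged scalar sketch of the per-element computation (GumbelPerturb is a hypothetical helper name):

#include <cmath>
#include <random>

// Scalar version of the transform above: sample Gumbel noise and
// apply the temperature.
template <typename T>
T GumbelPerturb(T logit, T temperature, std::mt19937_64* engine) {
  std::uniform_real_distribution<T> dist(0.00001, 1);
  T g = -std::log(-std::log(dist(*engine)));  // Gumbel(0, 1) sample
  return (logit + g) / temperature;
}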
template <typename T>
struct OneHotGenerator<platform::CPUDeviceContext, T> {
static void Transform(const platform::CPUDeviceContext& context,
const Tensor& X, Tensor* Out, int axis) {
Tensor index;
std::vector<int> index_dim;
const auto rank = X.dims().size();
const int size_to_axis = SizeToAxis(axis, X.dims());
const int size_from_axis = SizeFromAxis(axis, X.dims());
const int size_out_axis = SizeOutAxis(axis, X.dims());
for (int i = 0; i < X.dims().size(); i++) {
if (i != axis) index_dim.push_back(X.dims().Get()[i]);
}
DDim index_ddim(index_dim.data(), rank - 1);
index.Resize(index_ddim);
auto* index_data = index.mutable_data<int>(context.GetPlace());
#define CALL_ARG_MINMAX_FUNCTOR(rank) \
ArgMaxFunctor<platform::CPUDeviceContext, T, rank> functor##rank; \
functor##rank(context, *Out, &index, axis);
switch (Out->dims().size()) {
case 1:
CALL_ARG_MINMAX_FUNCTOR(1);
break;
case 2:
CALL_ARG_MINMAX_FUNCTOR(2);
break;
case 3:
CALL_ARG_MINMAX_FUNCTOR(3);
break;
case 4:
CALL_ARG_MINMAX_FUNCTOR(4);
break;
case 5:
CALL_ARG_MINMAX_FUNCTOR(5);
break;
case 6:
CALL_ARG_MINMAX_FUNCTOR(6);
break;
default:
PADDLE_ENFORCE_LE(Out->dims().size(), 6,
platform::errors::InvalidArgument(
"gumbel_softmax operator doesn't supports "
"tensors whose ranks are greater "
"than 6 in CPU mode."));
break;
#undef CALL_ARG_MINMAX_FUNCTOR
}
phi::funcs::set_constant(context, Out, 0.0);
for (int i = 0; i < size_to_axis; i++) {
for (int j = 0; j < size_out_axis; j++) {
*(Out->data<T>() + i * size_from_axis + j +
index_data[i * size_out_axis + j] * size_out_axis) = 1.0;
}
}
}
};
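The write in the final double loop relies on the flattened layout [size_to_axis, axis_dim, size_out_axis]: element (i, k, j) lives at offset i * size_from_axis + k * size_out_axis + j, so setting k to the stored argmax index produces a one-hot slice along the softmax axis. A hedged numeric check:

// Hypothetical shape {2, 3, 4} with axis == 1:
// size_to_axis == 2, size_from_axis == 12, size_out_axis == 4.
// If index_data[i * size_out_axis + j] == 2 for (i, j) == (1, 3),
// the 1.0 lands at offset 1 * 12 + 3 + 2 * 4 == 23, i.e. element (1, 2, 3).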
template <typename DeviceContext, typename T>
class GumbelSoftmaxKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto* X = context.Input<Tensor>("X");
auto* Out = context.Output<Tensor>("Out");
const int rank = X->dims().size();
const int axis = CanonicalAxis(context.Attr<int>("axis"), rank);
int axis_dim = X->dims()[axis];
const bool is_hard = context.Attr<bool>("hard");
const float temperature = context.Attr<float>("temperature");
PADDLE_ENFORCE_GT(temperature, 0,
platform::errors::InvalidArgument(
"The temperature must be greater than 0. But "
"received temperature = %f",
temperature));
// allocate memory on device.
Out->mutable_data<T>(context.GetPlace());
if (Out->numel() == 0) {
return;
}
const int size_to_axis = SizeToAxis(axis, X->dims());
const int size_from_axis = SizeFromAxis(axis, X->dims());
Tensor X_noise_2d, Out_2d;
X_noise_2d.Resize({size_to_axis, size_from_axis});
Out_2d.ShareDataWith(*Out).Resize({size_to_axis, size_from_axis});
// generate gumbel noise and add it to X
auto* x_noise_data = X_noise_2d.mutable_data<T>(context.GetPlace());
GumbleNoiseGenerator<DeviceContext, T>::Transform(
context.template device_context<DeviceContext>(), X->data<T>(),
x_noise_data, size_to_axis, size_from_axis, temperature);
#ifdef PADDLE_ON_INFERENCE
math::SoftmaxFunctor<DeviceContext, T, true>()(
context.template device_context<DeviceContext>(), axis_dim, &X_noise_2d,
&Out_2d);
#else
math::SoftmaxFunctor<DeviceContext, T, false>()(
context.template device_context<DeviceContext>(), axis_dim, &X_noise_2d,
&Out_2d);
#endif
if (is_hard) {
OneHotGenerator<DeviceContext, T>::Transform(
context.template device_context<DeviceContext>(), *X, Out, axis);
}
}
};
template <typename DeviceContext, typename T>
class GumbelSoftmaxGradKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto* Out = context.Input<Tensor>("Out");
auto* dOut = context.Input<Tensor>(framework::GradVarName("Out"));
auto* dX = context.Output<Tensor>(framework::GradVarName("X"));
const int rank = dX->dims().size();
const int axis = CanonicalAxis(context.Attr<int>("axis"), rank);
int axis_dim = dX->dims()[axis];
// allocate memory on device.
dX->mutable_data<T>(context.GetPlace());
if (dX->numel() == 0) {
return;
}
const int size_to_axis = SizeToAxis(axis, dX->dims());
const int size_from_axis = SizeFromAxis(axis, dX->dims());
Tensor dX_2d, Out_2d, dOut_2d;
dX_2d.ShareDataWith(*dX).Resize({size_to_axis, size_from_axis});
Out_2d.ShareDataWith(*Out).Resize({size_to_axis, size_from_axis});
dOut_2d.ShareDataWith(*dOut).Resize({size_to_axis, size_from_axis});
math::SoftmaxGradFunctor<DeviceContext, T>()(
context.template device_context<DeviceContext>(), axis_dim, &Out_2d,
&dOut_2d, &dX_2d);
}
};
} // namespace operators
} // namespace paddle
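For reference, the SoftmaxGradFunctor call in the grad kernel applies the standard softmax Jacobian-vector product to each softmax group: dX = Out * (dOut - sum(dOut * Out)). A hedged scalar sketch of that identity for one group of length n (not the optimized functor):

// Reference implementation of the identity the functor computes.
template <typename T>
void SoftmaxGradGroup(const T* out, const T* dout, T* dx, int n) {
  T dot = 0;
  for (int i = 0; i < n; ++i) dot += dout[i] * out[i];
  for (int i = 0; i < n; ++i) dx[i] = out[i] * (dout[i] - dot);
}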
@@ -76,4 +76,16 @@ void GeneralBinaryGradInferMeta(const MetaTensor& x,
}
}
void GumbelSoftmaxGradInferMeta(const MetaTensor& out,
const MetaTensor& dout,
int axis,
MetaTensor* dx) {
PADDLE_ENFORCE_EQ(
out.dims(),
dout.dims(),
errors::InvalidArgument(
"Input(Out) and its gradients should have the same shape."));
dx->share_meta(dout);
}
} // namespace phi
@@ -34,4 +34,8 @@ void GeneralBinaryGradInferMeta(const MetaTensor& x,
MetaTensor* dx,
MetaTensor* dy);
void GumbelSoftmaxGradInferMeta(const MetaTensor& out,
const MetaTensor& dout,
int axis,
MetaTensor* dx);
} // namespace phi
@@ -27,6 +27,30 @@ void UnchangedInferMeta(const MetaTensor& x, MetaTensor* out) {
out->share_meta(x);
}
// meta x -> out without change, check if axis in range [-Rank(x), Rank(x)-1]
void UnchangedInferMetaCheckAxis(const MetaTensor& x,
int axis,
MetaTensor* out) {
auto rank = x.dims().size();
PADDLE_ENFORCE_GE(
axis,
-rank,
errors::InvalidArgument(
"Attr(axis) value should be in range [-R, R-1], "
"R is the rank of Input(X). But received axis: %d, R: %d.",
axis,
rank));
PADDLE_ENFORCE_LT(
axis,
rank,
phi::errors::InvalidArgument(
"Attr(axis) value should be in range [-R, R-1], "
"R is the rank of Input(X). But received axis: %d, R: %d.",
axis,
rank));
out->share_meta(x);
}
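Together the two checks accept any axis in [-R, R-1] for a rank-R input; hedged examples for a rank-3 MetaTensor x3d:

// UnchangedInferMetaCheckAxis(x3d, -3, &out);  // OK
// UnchangedInferMetaCheckAxis(x3d, 2, &out);   // OK
// UnchangedInferMetaCheckAxis(x3d, 3, &out);   // raises InvalidArgument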
void FlattenInferMeta(const MetaTensor& x,
int start_axis,
int stop_axis,
@@ -75,6 +99,14 @@ void FlattenInferMeta(const MetaTensor& x,
}
}
void GumbelSoftmaxInferMeta(const MetaTensor& x,
float temperature,
bool hard,
int axis,
MetaTensor* out) {
UnchangedInferMetaCheckAxis(x, axis, out);
}
void CastInferMeta(const MetaTensor& x, DataType out_dtype, MetaTensor* out) {
out->set_dims(x.dims());
out->set_dtype(out_dtype);
......
@@ -34,11 +34,22 @@ class MetaConfig;
void UnchangedInferMeta(const MetaTensor& x, MetaTensor* out);
// meta x -> out without change, check if axis in range [-Rank(x), Rank(x)-1]
void UnchangedInferMetaCheckAxis(const MetaTensor& x,
int axis,
MetaTensor* out);
void FlattenInferMeta(const MetaTensor& x,
int start_axis,
int stop_axis,
MetaTensor* out);
void GumbelSoftmaxInferMeta(const MetaTensor& x,
float temperature,
bool hard,
int axis,
MetaTensor* out);
void CastInferMeta(const MetaTensor& x, DataType out_dtype, MetaTensor* out);
void CholeskyInferMeta(const MetaTensor& x, bool upper, MetaTensor* out);
......
@@ -10,7 +10,7 @@ add_subdirectory(funcs)
set_property(GLOBAL PROPERTY PHI_KERNELS "")
set(COMMON_KERNEL_DEPS dense_tensor sparse_coo_tensor sparse_csr_tensor kernel_context kernel_factory arg_map_context convert_utils lod_utils)
set(COMMON_KERNEL_DEPS ${COMMON_KERNEL_DEPS} eigen_function blas math_function im2col concat_and_split_functor)
set(COMMON_KERNEL_DEPS ${COMMON_KERNEL_DEPS} eigen_function blas math_function im2col concat_and_split_functor softmax)
# remove this dep after removing fluid deps on tensor creation
set(COMMON_KERNEL_DEPS ${COMMON_KERNEL_DEPS} phi_api_utils)
set(COMMON_KERNEL_DEPS ${COMMON_KERNEL_DEPS} infermeta)
......
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/gumbel_softmax_grad_kernel.h"
#include "paddle/phi/kernels/impl/gumbel_softmax_grad_kernel_impl.h"
#include "paddle/phi/core/kernel_registry.h"
PD_REGISTER_KERNEL(gumbel_softmax_grad,
CPU,
ALL_LAYOUT,
phi::GumbelSoftmaxGradKernel,
float,
double) {}
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/gumbel_softmax_kernel.h"
#include "paddle/phi/kernels/impl/gumbel_softmax_kernel_impl.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/axis_utils.h"
#include "paddle/phi/kernels/funcs/math_function.h"
namespace phi {
template <typename T>
struct GumbleNoiseGenerator<CPUContext, T> {
static void Transform(const CPUContext& ctx,
const T* input_data,
T* output_data,
int size_to_axis,
int size_from_axis,
const float temperature) {
// generate uniform random number
const int size = size_to_axis * size_from_axis;
std::uniform_real_distribution<T> dist(0.00001, 1);
auto engine = ctx.GetGenerator()->GetCPUEngine();
DenseTensor random_tensor;
random_tensor.Resize(make_ddim({size}));
auto* random_data = ctx.template Alloc<T>(&random_tensor);
for (int64_t i = 0; i < size; ++i) {
random_data[i] = dist(*engine);
}
// generate gumbel noise
DDim dim_2d{size_to_axis, size_from_axis};
auto gumbel_noise_eigen = EigenMatrix<T>::From(random_tensor, dim_2d);
gumbel_noise_eigen = -(((-(gumbel_noise_eigen.log())).log()));
// add noise
for (int64_t i = 0; i < size_to_axis * size_from_axis; i++) {
output_data[i] = (input_data[i] + random_data[i]) / temperature;
}
}
};
template <typename T>
struct OneHotGenerator<CPUContext, T> {
static void Transform(const CPUContext& ctx,
const DenseTensor& x,
DenseTensor* out,
int axis) {
DenseTensor index;
std::vector<int> index_dim;
const auto rank = x.dims().size();
const int size_to_axis = funcs::SizeToAxis(axis, x.dims());
const int size_from_axis = funcs::SizeFromAxis(axis, x.dims());
const int size_out_axis = funcs::SizeOutAxis(axis, x.dims());
for (int i = 0; i < x.dims().size(); i++) {
if (i != axis) index_dim.push_back(x.dims().Get()[i]);
}
DDim index_ddim(index_dim.data(), rank - 1);
index.Resize(index_ddim);
auto* index_data = ctx.template Alloc<int>(&index);
#define CALL_ARG_MINMAX_FUNCTOR(rank) \
ArgMaxFunctor<CPUContext, T, rank> functor##rank; \
functor##rank(ctx, *out, &index, axis);
switch (out->dims().size()) {
case 1:
CALL_ARG_MINMAX_FUNCTOR(1);
break;
case 2:
CALL_ARG_MINMAX_FUNCTOR(2);
break;
case 3:
CALL_ARG_MINMAX_FUNCTOR(3);
break;
case 4:
CALL_ARG_MINMAX_FUNCTOR(4);
break;
case 5:
CALL_ARG_MINMAX_FUNCTOR(5);
break;
case 6:
CALL_ARG_MINMAX_FUNCTOR(6);
break;
default:
PADDLE_ENFORCE_LE(
out->dims().size(),
6,
errors::InvalidArgument("gumbel_softmax operator doesn't supports "
"tensors whose ranks are greater "
"than 6 in CPU mode."));
break;
#undef CALL_ARG_MINMAX_FUNCTOR
}
funcs::set_constant(ctx, out, 0.0);
for (int i = 0; i < size_to_axis; i++) {
for (int j = 0; j < size_out_axis; j++) {
*(out->data<T>() + i * size_from_axis + j +
index_data[i * size_out_axis + j] * size_out_axis) = 1.0;
}
}
}
};
} // namespace phi
PD_REGISTER_KERNEL(
gumbel_softmax, CPU, ALL_LAYOUT, phi::GumbelSoftmaxKernel, float, double) {}
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/gumbel_softmax_grad_kernel.h"
#include "paddle/phi/kernels/impl/gumbel_softmax_grad_kernel_impl.h"
#include "paddle/phi/core/kernel_registry.h"
PD_REGISTER_KERNEL(gumbel_softmax_grad,
GPU,
ALL_LAYOUT,
phi::GumbelSoftmaxGradKernel,
float,
double) {}
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/operators/gumbel_softmax_op.h"
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/gumbel_softmax_kernel.h"
#include "paddle/phi/kernels/impl/gumbel_softmax_kernel_impl.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/axis_utils.h"
#if defined(__NVCC__) || defined(__HIPCC__)
#ifdef __NVCC__
@@ -31,11 +32,10 @@ namespace cub = hipcub;
#include <thrust/random.h>
#include <thrust/transform.h>
#include "paddle/fluid/framework/generator.h"
#include "paddle/fluid/memory/memcpy.h"
namespace paddle {
namespace operators {
#include "paddle/fluid/framework/tensor_util.h"
#include "paddle/phi/kernels/funcs/math_function.h"
namespace phi {
template <typename K, typename V>
using KeyValuePair = cub::KeyValuePair<K, V>;
@@ -46,7 +46,9 @@ struct UniformCUDAGenerator {
unsigned int offset_ = 0;
HOSTDEVICE UniformCUDAGenerator(T min, T max, unsigned int seed)
: min_(min), max_(max), seed_(seed) {}
HOSTDEVICE UniformCUDAGenerator(T min, T max, unsigned int seed,
HOSTDEVICE UniformCUDAGenerator(T min,
T max,
unsigned int seed,
unsigned int offset)
: min_(min), max_(max), seed_(seed), offset_(offset) {}
@@ -60,9 +62,12 @@
};
template <typename T, size_t BlockDim>
__global__ void OneHotCUDAKernel(const int64_t height, const int64_t width,
const int64_t size_out_axis, const T init,
const T* in, T* out) {
__global__ void OneHotCUDAKernel(const int64_t height,
const int64_t width,
const int64_t size_out_axis,
const T init,
const T* in,
T* out) {
typedef cub::BlockReduce<KeyValuePair<int, T>, BlockDim> BlockReduce;
__shared__ typename BlockReduce::TempStorage temp_storage;
@@ -85,32 +90,40 @@ __global__ void OneHotCUDAKernel(const int64_t height, const int64_t width,
}
template <typename T>
struct OneHotGenerator<platform::CUDADeviceContext, T> {
static void Transform(const platform::CUDADeviceContext& context,
const Tensor& X, Tensor* Out, int axis) {
const int size_to_axis = SizeToAxis(axis, X.dims());
const int size_from_axis = SizeFromAxis(axis, X.dims());
const int size_out_axis = SizeOutAxis(axis, X.dims());
struct OneHotGenerator<GPUContext, T> {
static void Transform(const GPUContext& ctx,
const DenseTensor& X,
DenseTensor* out,
int axis) {
const int size_to_axis = funcs::SizeToAxis(axis, X.dims());
const int size_from_axis = funcs::SizeFromAxis(axis, X.dims());
const int size_out_axis = funcs::SizeOutAxis(axis, X.dims());
constexpr int thread_size = 512;
int64_t max_grid_dimx = context.GetCUDAMaxGridDimSize()[0];
int64_t max_grid_dimx = ctx.GetCUDAMaxGridDimSize()[0];
int64_t height = size_to_axis * size_out_axis;
int block_size = height < max_grid_dimx ? height : max_grid_dimx;
Tensor input_tensor;
input_tensor.mutable_data<T>(Out->dims(), platform::CUDAPlace());
paddle::framework::TensorCopy(*Out, context.GetPlace(), &input_tensor);
phi::funcs::set_constant(context, Out, 0.0);
OneHotCUDAKernel<
T, thread_size><<<block_size, thread_size, 0, context.stream()>>>(
height, size_from_axis / size_out_axis, size_out_axis,
std::numeric_limits<T>::lowest(), input_tensor.data<T>(),
Out->data<T>());
DenseTensor input_tensor;
input_tensor.Resize(out->dims());
ctx.template Alloc<T>(&input_tensor);
paddle::framework::TensorCopy(*out, ctx.GetPlace(), &input_tensor);
funcs::set_constant(ctx, out, 0.0);
OneHotCUDAKernel<T,
thread_size><<<block_size, thread_size, 0, ctx.stream()>>>(
height,
size_from_axis / size_out_axis,
size_out_axis,
std::numeric_limits<T>::lowest(),
input_tensor.data<T>(),
out->data<T>());
}
};
template <typename T>
__global__ void AddGumbelNoiseCUDAKernel(const T* input_data, T* output_data,
T* noise, const float temperature,
__global__ void AddGumbelNoiseCUDAKernel(const T* input_data,
T* output_data,
T* noise,
const float temperature,
int64_t n) {
int index = threadIdx.x + blockIdx.x * blockDim.x;
int step = blockDim.x * gridDim.x;
@@ -121,29 +134,34 @@ __global__ void AddGumbelNoiseCUDAKernel(const T* input_data, T* output_data,
}
template <typename T>
struct GumbleNoiseGenerator<platform::CUDADeviceContext, T> {
static void Transform(const platform::CUDADeviceContext& context,
const T* input_data, T* output_data, int size_to_axis,
int size_from_axis, const float temperature) {
Tensor random_tensor;
struct GumbleNoiseGenerator<GPUContext, T> {
static void Transform(const GPUContext& ctx,
const T* input_data,
T* output_data,
int size_to_axis,
int size_from_axis,
const float temperature) {
DenseTensor random_tensor;
int64_t size = size_to_axis * size_from_axis;
T* random_data =
random_tensor.mutable_data<T>({size}, platform::CUDAPlace());
random_tensor.Resize(make_ddim({size}));
auto* random_data = ctx.template Alloc<T>(&random_tensor);
thrust::counting_iterator<int64_t> index_sequence_begin(0);
// generate gumbel noise
int device_id = context.GetPlace().GetDeviceId();
auto gen_cuda = framework::GetDefaultCUDAGenerator(device_id);
int device_id = ctx.GetPlace().GetDeviceId();
auto gen_cuda = paddle::framework::GetDefaultCUDAGenerator(device_id);
if (gen_cuda->GetIsInitPy()) {
auto seed_offset = gen_cuda->IncrementOffset(1);
int64_t gen_offset = size * seed_offset.second;
thrust::transform(
index_sequence_begin, index_sequence_begin + size,
index_sequence_begin,
index_sequence_begin + size,
thrust::device_ptr<T>(random_data),
UniformCUDAGenerator<T>(0.00001, 1, seed_offset.first, gen_offset));
} else {
const unsigned int seed = std::random_device()();
thrust::transform(index_sequence_begin, index_sequence_begin + size,
thrust::transform(index_sequence_begin,
index_sequence_begin + size,
thrust::device_ptr<T>(random_data),
UniformCUDAGenerator<T>(0.00001, 1, seed));
}
@@ -151,22 +169,13 @@ struct GumbleNoiseGenerator<platform::CUDADeviceContext, T> {
// add gumbel noise to X
const int thread_size = 512;
int64_t block_size = (size + thread_size) / thread_size;
AddGumbelNoiseCUDAKernel<
T><<<block_size, thread_size, 0, context.stream()>>>(
AddGumbelNoiseCUDAKernel<T><<<block_size, thread_size, 0, ctx.stream()>>>(
input_data, output_data, random_data, temperature, size);
}
};
} // namespace phi
#endif
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
namespace plat = paddle::platform;
REGISTER_OP_CUDA_KERNEL(
gumbel_softmax, ops::GumbelSoftmaxKernel<plat::CUDADeviceContext, float>,
ops::GumbelSoftmaxKernel<plat::CUDADeviceContext, double>);
REGISTER_OP_CUDA_KERNEL(
gumbel_softmax_grad,
ops::GumbelSoftmaxGradKernel<plat::CUDADeviceContext, float>,
ops::GumbelSoftmaxGradKernel<plat::CUDADeviceContext, double>);
PD_REGISTER_KERNEL(
gumbel_softmax, GPU, ALL_LAYOUT, phi::GumbelSoftmaxKernel, float, double) {}
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/phi/core/dense_tensor.h"
namespace phi {
template <typename T, typename Context>
void GumbelSoftmaxGradKernel(const Context& dev_ctx,
const DenseTensor& out,
const DenseTensor& dout,
int axis,
DenseTensor* dx);
} // namespace phi
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/phi/core/dense_tensor.h"
namespace phi {
template <typename T, typename Context>
void GumbelSoftmaxKernel(const Context& dev_ctx,
const DenseTensor& x,
float temperature,
bool hard,
int axis,
DenseTensor* out);
} // namespace phi
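Given this declaration, a hedged sketch of calling the phi kernel directly (assumes an existing phi::CPUContext ctx and an initialized DenseTensor x; out must be resized before the kernel allocates it):

phi::DenseTensor out;
out.Resize(x.dims());  // normally done by GumbelSoftmaxInferMeta
phi::GumbelSoftmaxKernel<float, phi::CPUContext>(
    ctx, x, /*temperature=*/1.0f, /*hard=*/false, /*axis=*/-1, &out);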
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/fluid/operators/math/softmax.h"
#include "paddle/fluid/operators/math/softmax_impl.h"
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/kernels/funcs/axis_utils.h"
namespace phi {
template <typename T, typename Context>
void GumbelSoftmaxGradKernel(const Context& ctx,
const DenseTensor& out,
const DenseTensor& dout,
int axis,
DenseTensor* dx) {
const int rank = dx->dims().size();
axis = funcs::CanonicalAxis(axis, rank);
int axis_dim = dx->dims()[axis];
// allocate memory on device.
ctx.template Alloc<T>(dx);
if (dx->numel() == 0) {
return;
}
const int size_to_axis = funcs::SizeToAxis(axis, dx->dims());
const int size_from_axis = funcs::SizeFromAxis(axis, dx->dims());
DenseTensor dx_2d(*dx), out_2d(out), dout_2d(dout);
dx_2d.Resize({size_to_axis, size_from_axis});
out_2d.Resize({size_to_axis, size_from_axis});
dout_2d.Resize({size_to_axis, size_from_axis});
paddle::operators::math::SoftmaxGradFunctor<Context, T>()(
ctx, axis_dim, &out_2d, &dout_2d, &dx_2d);
}
} // namespace phi
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <random>
#include "paddle/fluid/operators/math/softmax.h"
#include "paddle/fluid/operators/math/softmax_impl.h"
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/kernels/funcs/axis_utils.h"
#include "paddle/phi/kernels/funcs/eigen/common.h"
namespace phi {
template <typename Context, typename T, int64_t Rank>
struct ArgMaxFunctor {
void operator()(const Context& ctx,
const DenseTensor& in,
DenseTensor* index_tensor,
const int64_t& axis) {
auto in_eigen = EigenTensor<T, Rank>::From(in, in.dims());
auto index_eigen = EigenTensor<int, Rank - 1>::From(*index_tensor);
index_eigen = in_eigen.argmax(axis).template cast<int>();
}
};
template <typename Context, typename T>
struct GumbleNoiseGenerator;
template <typename Context, typename T>
struct OneHotGenerator;
template <typename T, typename Context>
void GumbelSoftmaxKernel(const Context& ctx,
const DenseTensor& x,
float temperature,
bool hard,
int axis,
DenseTensor* out) {
const int rank = x.dims().size();
axis = funcs::CanonicalAxis(axis, rank);
int axis_dim = x.dims()[axis];
PADDLE_ENFORCE_GT(temperature,
0,
phi::errors::InvalidArgument(
"The temperature must be greater than 0. But "
"received temperature = %f",
temperature));
// allocate memory on device.
ctx.template Alloc<T>(out);
if (out->numel() == 0) {
return;
}
const int size_to_axis = funcs::SizeToAxis(axis, x.dims());
const int size_from_axis = funcs::SizeFromAxis(axis, x.dims());
DenseTensor x_noise_2d, out_2d(*out);
x_noise_2d.Resize({size_to_axis, size_from_axis});
out_2d.Resize({size_to_axis, size_from_axis});
// generate gumbel noise and add it to X
auto* x_noise_data = ctx.template Alloc<T>(&x_noise_2d);
GumbleNoiseGenerator<Context, T>::Transform(ctx,
x.data<T>(),
x_noise_data,
size_to_axis,
size_from_axis,
temperature);
#ifdef PADDLE_ON_INFERENCE
paddle::operators::math::SoftmaxFunctor<Context, T, true>()(
ctx, axis_dim, &x_noise_2d, &out_2d);
#else
paddle::operators::math::SoftmaxFunctor<Context, T, false>()(
ctx, axis_dim, &x_noise_2d, &out_2d);
#endif
if (hard) {
OneHotGenerator<Context, T>::Transform(ctx, x, out, axis);
}
}
} // namespace phi
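End to end, this kernel computes out = softmax((x + g) / temperature) along axis, with g ~ Gumbel(0, 1); when hard is true, OneHotGenerator then snaps each softmax group to a one-hot vector at its argmax. A hedged single-group reference (no batching; the helper name is hypothetical):

#include <algorithm>
#include <cmath>
#include <limits>
#include <random>

template <typename T>
void GumbelSoftmaxGroup(const T* x, T* y, int n, T temperature, bool hard,
                        std::mt19937_64* engine) {
  std::uniform_real_distribution<T> dist(0.00001, 1);
  T max_v = -std::numeric_limits<T>::infinity();
  for (int i = 0; i < n; ++i) {
    T g = -std::log(-std::log(dist(*engine)));  // Gumbel(0, 1) noise
    y[i] = (x[i] + g) / temperature;
    max_v = std::max(max_v, y[i]);
  }
  T sum = 0;
  for (int i = 0; i < n; ++i) {  // numerically stable softmax
    y[i] = std::exp(y[i] - max_v);
    sum += y[i];
  }
  for (int i = 0; i < n; ++i) y[i] /= sum;
  if (hard) {  // snap to one-hot at the argmax
    int k = static_cast<int>(std::max_element(y, y + n) - y);
    for (int i = 0; i < n; ++i) y[i] = (i == k) ? T(1) : T(0);
  }
}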
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/phi/core/compat/op_utils.h"
namespace phi {
KernelSignature GumbelSoftmaxGradOpArgumentMapping(
const ArgumentMappingContext& ctx) {
return KernelSignature("gumbel_softmax_grad",
{"Out", GradVarName("Out")},
{"axis"},
{GradVarName("X")});
}
} // namespace phi
PD_REGISTER_ARG_MAPPING_FN(gumbel_softmax_grad,
phi::GumbelSoftmaxGradOpArgumentMapping);
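Each positional list in the signature lines up with the phi kernel declared earlier: the inputs {"Out", GradVarName("Out")} bind the out and dout tensors, {"axis"} binds the int attribute, and {GradVarName("X")} binds the dx output. In Paddle's naming scheme GradVarName("Out") expands to "Out@GRAD", so the correspondence is:

// Hedged mapping from fluid variable names to kernel parameters:
//   {"Out", "Out@GRAD"} -> (const DenseTensor& out, const DenseTensor& dout)
//   {"axis"}            -> (int axis)
//   {"X@GRAD"}          -> (DenseTensor* dx)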