未验证 提交 798a4eac 编写于 作者: X Xiaoxu Chen 提交者: GitHub

migrate dirichlet kernel to phi (#44434)

* migrate dirichlet op kernel to phi

* fix dirichlet sample memory leak
上级 2781740b
......@@ -11,83 +11,14 @@
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/operators/dirichlet_op.h"
#include "paddle/fluid/framework/generator.h"
#include "paddle/fluid/operators/elementwise/elementwise_op_function.h"
#include "paddle/fluid/operators/reduce_ops/reduce_op.h"
#include "paddle/fluid/operators/reduce_ops/reduce_sum_op.h"
#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/phi/core/infermeta_utils.h"
#include "paddle/phi/infermeta/unary.h"
namespace paddle {
namespace operators {
template <typename T, typename UniformSamplerT, typename NormalSamplerT>
struct GammaCPUFunctor {
GammaCPUFunctor(const T* alpha,
T* gamma,
BaseSampler<T, UniformSamplerT> uniform,
BaseSampler<T, NormalSamplerT> normal)
: alpha_(alpha), gamma_(gamma), uniform_(uniform), normal_(normal) {}
HOST void operator()(int64_t index) {
auto sample = sample_gamma<T, T, UniformSamplerT, NormalSamplerT>(
alpha_[index], uniform_, normal_);
gamma_[index] = std::max(std::numeric_limits<T>::min(), sample);
}
const T* alpha_;
T* gamma_;
BaseSampler<T, UniformSamplerT> uniform_;
BaseSampler<T, NormalSamplerT> normal_;
};
template <typename T>
struct DirichletSampler<phi::CPUContext, T> {
void operator()(const framework::ExecutionContext& ctx,
const Tensor* alpha,
Tensor* out) {
auto& dev_ctx = ctx.device_context<phi::CPUContext>();
auto p_gen = framework::DefaultCPUGenerator();
auto generator = p_gen->GetCPUEngine();
auto uniform = [&generator]() -> T {
std::uniform_real_distribution<T> u(0.0, 1.0);
return u(*generator);
};
BaseSampler<T, decltype(uniform)> standard_uniform(uniform);
auto normal = [&generator]() {
std::normal_distribution<T> n(0.0, 1.0);
return n(*generator);
};
BaseSampler<T, decltype(normal)> standard_normal(normal);
// sample from K gamma distributions, where K=alpha.numel()
framework::Tensor gamma_samples;
gamma_samples.mutable_data<T>(alpha->dims(), dev_ctx.GetPlace());
GammaCPUFunctor<T, decltype(uniform), decltype(normal)> gamma_functor(
alpha->data<T>(),
gamma_samples.data<T>(),
standard_uniform,
standard_normal);
platform::ForRange<phi::CPUContext> for_range(dev_ctx, alpha->numel());
for_range(gamma_functor);
// normalize them into a simplex, along the last axis
framework::Tensor gamma_sum;
auto new_shape = gamma_samples.dims();
new_shape[new_shape.size() - 1] = 1;
gamma_sum.mutable_data<T>(new_shape, dev_ctx.GetPlace());
ReduceKernelFunctor<phi::CPUContext, T, SumFunctor>(
&gamma_samples, &gamma_sum, {new_shape.size() - 1}, true, false, ctx)
.template apply<T>();
ElementwiseComputeEx<DivFunctor<T>, phi::CPUContext, T, T>(
ctx, &gamma_samples, &gamma_sum, -1, DivFunctor<T>(), out);
}
};
class DirichletOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
......@@ -100,29 +31,16 @@ class DirichletOpMaker : public framework::OpProtoAndCheckerMaker {
class DirichletOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
OP_INOUT_CHECK(ctx->HasInput("Alpha"), "Input", "Alpha", "dirichlet");
OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "dirichlet");
const auto alpha_dim = ctx->GetInputDim("Alpha");
PADDLE_ENFORCE_GE(alpha_dim.size(),
1,
platform::errors::InvalidArgument(
"ShapeError: The number of dimensions of 'Alpha' "
"must be greater than or euqal to 1. "
"But received Alpha's dimensions = %d,",
alpha_dim.size()));
ctx->ShareDim("Alpha", /*->*/ "Out");
}
};
} // namespace operators
} // namespace paddle
DECLARE_INFER_SHAPE_FUNCTOR(dirichlet,
DirichletInferShapeFunctor,
PD_INFER_META(phi::DirichletInferMeta));
REGISTER_OP_WITHOUT_GRADIENT(dirichlet,
paddle::operators::DirichletOp,
paddle::operators::DirichletOpMaker);
REGISTER_OP_CPU_KERNEL(
dirichlet,
paddle::operators::DirichletKernel<phi::CPUContext, float>,
paddle::operators::DirichletKernel<phi::CPUContext, double>);
paddle::operators::DirichletOpMaker,
DirichletInferShapeFunctor);
......@@ -2531,6 +2531,15 @@
kernel:
func: broadcast_tensors
backward: broadcast_tensors_grad
# dirichlet
- api: dirichlet
args: (Tensor alpha)
output: Tensor
infer_meta:
func: DirichletInferMeta
kernel:
func: dirichlet
# eig
- api: eig
......
......@@ -518,6 +518,19 @@ void DiagonalInferMeta(const MetaTensor& input,
out->set_dims(phi::make_ddim(out_dims));
}
void DirichletInferMeta(const MetaTensor& alpha, MetaTensor* out) {
const auto alpha_dim = alpha.dims();
PADDLE_ENFORCE_GE(alpha_dim.size(),
1,
phi::errors::InvalidArgument(
"ShapeError: The number of dimensions of 'Alpha' "
"must be greater than or euqal to 1. "
"But received Alpha's dimensions = %d,",
alpha_dim.size()));
out->set_dims(alpha_dim);
out->set_dtype(alpha.dtype());
}
void EigInferMeta(const MetaTensor& x, MetaTensor* out_w, MetaTensor* out_v) {
auto x_dims = x.dims();
int rank = x_dims.size();
......
......@@ -90,6 +90,8 @@ void DiagInferMeta(const MetaTensor& x,
void DiagonalInferMeta(
const MetaTensor& input, int offset, int axis1, int axis2, MetaTensor* out);
void DirichletInferMeta(const MetaTensor& alpha, MetaTensor* out);
void EigInferMeta(const MetaTensor& x, MetaTensor* out_w, MetaTensor* out_v);
void EighInferMeta(const MetaTensor& x,
......@@ -534,5 +536,4 @@ void ChannelShuffleInferMeta(const MetaTensor& x,
MetaTensor* out);
void IdentityLossInferMeta(const MetaTensor& x, int reduction, MetaTensor* out);
} // namespace phi
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/cpu/elementwise.h"
#include "paddle/phi/kernels/cpu/reduce.h"
#include "paddle/phi/kernels/funcs/elementwise_functor.h"
#include "paddle/phi/kernels/funcs/for_range.h"
#include "paddle/phi/kernels/funcs/reduce_functor.h"
#include "paddle/phi/kernels/impl/dirichlet_kernel_impl.h"
namespace phi {
template <typename T, typename UniformSamplerT, typename NormalSamplerT>
struct GammaCPUFunctor {
GammaCPUFunctor(const T* alpha,
T* gamma,
BaseSampler<T, UniformSamplerT> uniform,
BaseSampler<T, NormalSamplerT> normal)
: alpha_(alpha), gamma_(gamma), uniform_(uniform), normal_(normal) {}
HOST void operator()(int64_t index) {
auto sample = sample_gamma<T, T, UniformSamplerT, NormalSamplerT>(
alpha_[index], uniform_, normal_);
gamma_[index] = std::max(std::numeric_limits<T>::min(), sample);
}
const T* alpha_;
T* gamma_;
BaseSampler<T, UniformSamplerT> uniform_;
BaseSampler<T, NormalSamplerT> normal_;
};
template <typename T>
struct DirichletSampler<CPUContext, T> {
void operator()(const CPUContext& dev_ctx,
const DenseTensor& alpha,
DenseTensor* out) {
auto generator = dev_ctx.GetGenerator()->GetCPUEngine();
auto uniform = [&generator]() -> T {
std::uniform_real_distribution<T> u(0.0, 1.0);
return u(*generator);
};
BaseSampler<T, decltype(uniform)> standard_uniform(uniform);
auto normal = [&generator]() {
std::normal_distribution<T> n(0.0, 1.0);
return n(*generator);
};
BaseSampler<T, decltype(normal)> standard_normal(normal);
// sample from K gamma distributions, where K=alpha.numel()
DenseTensor gamma_samples;
gamma_samples.Resize(alpha.dims());
dev_ctx.template Alloc<T>(&gamma_samples);
GammaCPUFunctor<T, decltype(uniform), decltype(normal)> gamma_functor(
alpha.data<T>(),
gamma_samples.data<T>(),
standard_uniform,
standard_normal);
funcs::ForRange<CPUContext> for_range(dev_ctx, alpha.numel());
for_range(gamma_functor);
// normalize them into a simplex, along the last axis
DenseTensor gamma_sum;
auto new_shape = gamma_samples.dims();
new_shape[new_shape.size() - 1] = 1;
gamma_sum.Resize(new_shape);
dev_ctx.template Alloc<T>(&gamma_sum);
ReduceKernelImpl<CPUContext, T, T, funcs::SumFunctor>(
dev_ctx,
gamma_samples,
&gamma_sum,
{new_shape.size() - 1},
true,
false);
funcs::ElementwiseCompute<funcs::DivideFunctor<T>, T, T>(
dev_ctx, gamma_samples, gamma_sum, -1, funcs::DivideFunctor<T>(), out);
}
};
} // namespace phi
PD_REGISTER_KERNEL(
dirichlet, CPU, ALL_LAYOUT, phi::Dirichletkernel, float, double) {}
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/phi/core/dense_tensor.h"
namespace phi {
template <typename T, typename Context>
void Dirichletkernel(const Context& dev_ctx,
const DenseTensor& alpha,
DenseTensor* out);
} // namespace phi
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
......@@ -12,12 +14,14 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/operators/dirichlet_op.h"
#include "paddle/fluid/framework/generator.h"
#include "paddle/fluid/operators/elementwise/elementwise_op_function.h"
#include "paddle/fluid/operators/reduce_ops/reduce_op.h"
#include "paddle/fluid/operators/reduce_ops/reduce_sum_op.h"
#include "paddle/fluid/platform/for_range.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/cpu/reduce.h"
#include "paddle/phi/kernels/funcs/broadcast_function.h"
#include "paddle/phi/kernels/funcs/elementwise_functor.h"
#include "paddle/phi/kernels/funcs/for_range.h"
#include "paddle/phi/kernels/funcs/reduce_functor.h"
#include "paddle/phi/kernels/impl/dirichlet_kernel_impl.h"
#ifdef PADDLE_WITH_CUDA
#include <curand_kernel.h>
......@@ -38,8 +42,7 @@ using COMPAT_RANDSTATEPHILOX4_32_10_T = hiprandStatePhilox4_32_10_t;
#define COMPAT_RAND_NORMAL hiprand_normal
#endif
namespace paddle {
namespace operators {
namespace phi {
template <typename T>
struct GammaCUDAFunctor {
GammaCUDAFunctor(const T* alpha, T* gamma, uint64_t seed, uint64_t offset)
......@@ -70,47 +73,44 @@ struct GammaCUDAFunctor {
};
template <typename T>
struct DirichletSampler<platform::CUDADeviceContext, T> {
void operator()(const framework::ExecutionContext& ctx,
const framework::Tensor* alpha,
framework::Tensor* out) {
auto& dev_ctx = ctx.device_context<platform::CUDADeviceContext>();
// init state, seed & offset for all threads
int device_id = ctx.GetPlace().GetDeviceId();
auto p_gen = framework::DefaultCUDAGenerator(device_id);
struct DirichletSampler<GPUContext, T> {
void operator()(const GPUContext& dev_ctx,
const DenseTensor& alpha,
DenseTensor* out) {
auto p_gen = dev_ctx.GetGenerator();
auto seed_and_offset = p_gen->IncrementOffset(10); // hard-coded offset
auto seed = seed_and_offset.first;
auto offset = seed_and_offset.second;
// sample from K gamma distributions, where K=alpha.numel()
framework::Tensor gamma_samples;
gamma_samples.mutable_data<T>(alpha->dims(), dev_ctx.GetPlace());
DenseTensor gamma_samples;
gamma_samples.Resize(alpha.dims());
dev_ctx.template Alloc<T>(&gamma_samples);
GammaCUDAFunctor<T> gamma_functor(
alpha->data<T>(), gamma_samples.data<T>(), seed, offset);
platform::ForRange<platform::CUDADeviceContext> for_range(dev_ctx,
out->numel());
alpha.data<T>(), gamma_samples.data<T>(), seed, offset);
funcs::ForRange<GPUContext> for_range(dev_ctx, out->numel());
for_range(gamma_functor);
// normalize them into a simplex, along the last axis
framework::Tensor gamma_sum;
DenseTensor gamma_sum;
auto new_shape = gamma_samples.dims();
new_shape[new_shape.size() - 1] = 1;
gamma_sum.mutable_data<T>(new_shape, dev_ctx.GetPlace());
gamma_sum.Resize(new_shape);
dev_ctx.template Alloc<T>(&gamma_sum);
ReduceKernelFunctor<platform::CUDADeviceContext, T, SumFunctor>(
&gamma_samples, &gamma_sum, {new_shape.size() - 1}, true, false, ctx)
.template apply<T>();
ElementwiseComputeEx<DivFunctor<T>, platform::CUDADeviceContext, T, T>(
ctx, &gamma_samples, &gamma_sum, -1, DivFunctor<T>(), out);
ReduceKernelImpl<GPUContext, T, T, funcs::SumFunctor>(
dev_ctx,
gamma_samples,
&gamma_sum,
{new_shape.size() - 1},
true,
false);
funcs::ElementwiseCompute<funcs::DivideFunctor<T>, T, T>(
dev_ctx, gamma_samples, gamma_sum, -1, funcs::DivideFunctor<T>(), out);
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
} // namespace phi
REGISTER_OP_CUDA_KERNEL(
dirichlet,
ops::DirichletKernel<paddle::platform::CUDADeviceContext, float>,
ops::DirichletKernel<paddle::platform::CUDADeviceContext, double>);
PD_REGISTER_KERNEL(
dirichlet, GPU, ALL_LAYOUT, phi::Dirichletkernel, float, double) {}
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
......@@ -13,11 +13,10 @@
// limitations under the License.
#pragma once
#include <cmath>
#include <random>
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/platform/for_range.h"
#include "paddle/phi/kernels/dirichlet_kernel.h"
// ROCM hcc doesn't work well with using std:: in kernel functions
#if defined(PADDLE_WITH_CUDA)
......@@ -42,10 +41,7 @@
#define COMPAT_LOG1P std::log1p
#endif
namespace paddle {
namespace operators {
template <typename DeviceContext, typename T>
struct DirichletSampler;
namespace phi {
template <typename ScalarT, typename SamplerT>
struct BaseSampler {
......@@ -117,17 +113,19 @@ sample_gamma(ScalarT alpha,
}
}
template <typename DeviceContext, typename T>
class DirichletKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
const auto* alpha = ctx.Input<framework::Tensor>("Alpha");
auto* out = ctx.Output<framework::Tensor>("Out");
out->mutable_data<T>(ctx.GetPlace());
DirichletSampler<DeviceContext, T> sampler;
sampler(ctx, alpha, out);
}
template <typename Context, typename T>
struct DirichletSampler {
void operator()(const Context& dev_ctx,
const DenseTensor& alpha,
DenseTensor* out);
};
} // namespace operators
} // namespace paddle
template <typename T, typename Context>
void Dirichletkernel(const Context& dev_ctx,
const DenseTensor& alpha,
DenseTensor* out) {
dev_ctx.template Alloc<T>(out);
DirichletSampler<Context, T> sampler;
sampler(dev_ctx, alpha, out);
}
} // namespace phi
......@@ -15,7 +15,7 @@
import paddle
from paddle.distribution import exponential_family
from paddle.fluid.data_feeder import check_variable_and_dtype
from paddle.fluid.framework import _non_static_mode, in_dygraph_mode
from paddle.fluid.framework import in_dygraph_mode, _in_legacy_dygraph
from paddle.fluid.layer_helper import LayerHelper
......@@ -157,9 +157,10 @@ def _dirichlet(concentration, name=None):
check_variable_and_dtype(concentration, 'concentration',
['float32', 'float64'], op_type)
if _non_static_mode():
if in_dygraph_mode():
return paddle._C_ops.final_state_dirichlet(concentration)
elif _in_legacy_dygraph():
return paddle._C_ops.dirichlet(concentration)
else:
helper = LayerHelper(op_type, **locals())
out = helper.create_variable_for_type_inference(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册