From 798a4eacdd9267da1c76536e990539601b69a5d4 Mon Sep 17 00:00:00 2001 From: Xiaoxu Chen Date: Thu, 28 Jul 2022 15:29:29 +0800 Subject: [PATCH] migrate dirichlet kernel to phi (#44434) * migrate dirichlet op kernel to phi * fix dirichlet sample memory leak --- paddle/fluid/operators/dirichlet_op.cc | 104 ++---------------- paddle/phi/api/yaml/legacy_api.yaml | 9 ++ paddle/phi/infermeta/unary.cc | 13 +++ paddle/phi/infermeta/unary.h | 3 +- paddle/phi/kernels/cpu/dirichlet_kernel.cc | 102 +++++++++++++++++ paddle/phi/kernels/dirichlet_kernel.h | 25 +++++ .../kernels/gpu/dirichlet_kernel.cu} | 74 ++++++------- .../kernels/impl/dirichlet_kernel_impl.h} | 40 ++++--- python/paddle/distribution/dirichlet.py | 7 +- 9 files changed, 222 insertions(+), 155 deletions(-) create mode 100644 paddle/phi/kernels/cpu/dirichlet_kernel.cc create mode 100644 paddle/phi/kernels/dirichlet_kernel.h rename paddle/{fluid/operators/dirichlet_op.cu => phi/kernels/gpu/dirichlet_kernel.cu} (59%) rename paddle/{fluid/operators/dirichlet_op.h => phi/kernels/impl/dirichlet_kernel_impl.h} (84%) diff --git a/paddle/fluid/operators/dirichlet_op.cc b/paddle/fluid/operators/dirichlet_op.cc index ccbe3b62b73..d9f5d367c88 100644 --- a/paddle/fluid/operators/dirichlet_op.cc +++ b/paddle/fluid/operators/dirichlet_op.cc @@ -11,83 +11,14 @@ // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. - -#include "paddle/fluid/operators/dirichlet_op.h" - -#include "paddle/fluid/framework/generator.h" -#include "paddle/fluid/operators/elementwise/elementwise_op_function.h" -#include "paddle/fluid/operators/reduce_ops/reduce_op.h" -#include "paddle/fluid/operators/reduce_ops/reduce_sum_op.h" +#include "paddle/fluid/framework/infershape_utils.h" +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/framework/operator.h" +#include "paddle/phi/core/infermeta_utils.h" +#include "paddle/phi/infermeta/unary.h" namespace paddle { namespace operators { -template -struct GammaCPUFunctor { - GammaCPUFunctor(const T* alpha, - T* gamma, - BaseSampler uniform, - BaseSampler normal) - : alpha_(alpha), gamma_(gamma), uniform_(uniform), normal_(normal) {} - - HOST void operator()(int64_t index) { - auto sample = sample_gamma( - alpha_[index], uniform_, normal_); - gamma_[index] = std::max(std::numeric_limits::min(), sample); - } - - const T* alpha_; - T* gamma_; - BaseSampler uniform_; - BaseSampler normal_; -}; - -template -struct DirichletSampler { - void operator()(const framework::ExecutionContext& ctx, - const Tensor* alpha, - Tensor* out) { - auto& dev_ctx = ctx.device_context(); - - auto p_gen = framework::DefaultCPUGenerator(); - auto generator = p_gen->GetCPUEngine(); - - auto uniform = [&generator]() -> T { - std::uniform_real_distribution u(0.0, 1.0); - return u(*generator); - }; - BaseSampler standard_uniform(uniform); - - auto normal = [&generator]() { - std::normal_distribution n(0.0, 1.0); - return n(*generator); - }; - BaseSampler standard_normal(normal); - - // sample from K gamma distributions, where K=alpha.numel() - framework::Tensor gamma_samples; - gamma_samples.mutable_data(alpha->dims(), dev_ctx.GetPlace()); - GammaCPUFunctor gamma_functor( - alpha->data(), - gamma_samples.data(), - standard_uniform, - standard_normal); - platform::ForRange for_range(dev_ctx, alpha->numel()); - for_range(gamma_functor); - - // normalize them into a simplex, along the last axis - framework::Tensor gamma_sum; - auto new_shape = gamma_samples.dims(); - new_shape[new_shape.size() - 1] = 1; - gamma_sum.mutable_data(new_shape, dev_ctx.GetPlace()); - - ReduceKernelFunctor( - &gamma_samples, &gamma_sum, {new_shape.size() - 1}, true, false, ctx) - .template apply(); - ElementwiseComputeEx, phi::CPUContext, T, T>( - ctx, &gamma_samples, &gamma_sum, -1, DivFunctor(), out); - } -}; - class DirichletOpMaker : public framework::OpProtoAndCheckerMaker { public: void Make() override { @@ -100,29 +31,16 @@ class DirichletOpMaker : public framework::OpProtoAndCheckerMaker { class DirichletOp : public framework::OperatorWithKernel { public: using framework::OperatorWithKernel::OperatorWithKernel; - - void InferShape(framework::InferShapeContext* ctx) const override { - OP_INOUT_CHECK(ctx->HasInput("Alpha"), "Input", "Alpha", "dirichlet"); - OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "dirichlet"); - const auto alpha_dim = ctx->GetInputDim("Alpha"); - PADDLE_ENFORCE_GE(alpha_dim.size(), - 1, - platform::errors::InvalidArgument( - "ShapeError: The number of dimensions of 'Alpha' " - "must be greater than or euqal to 1. " - "But received Alpha's dimensions = %d,", - alpha_dim.size())); - ctx->ShareDim("Alpha", /*->*/ "Out"); - } }; } // namespace operators } // namespace paddle +DECLARE_INFER_SHAPE_FUNCTOR(dirichlet, + DirichletInferShapeFunctor, + PD_INFER_META(phi::DirichletInferMeta)); + REGISTER_OP_WITHOUT_GRADIENT(dirichlet, paddle::operators::DirichletOp, - paddle::operators::DirichletOpMaker); -REGISTER_OP_CPU_KERNEL( - dirichlet, - paddle::operators::DirichletKernel, - paddle::operators::DirichletKernel); + paddle::operators::DirichletOpMaker, + DirichletInferShapeFunctor); diff --git a/paddle/phi/api/yaml/legacy_api.yaml b/paddle/phi/api/yaml/legacy_api.yaml index 6d5e87bd793..dd7ec0af6f1 100644 --- a/paddle/phi/api/yaml/legacy_api.yaml +++ b/paddle/phi/api/yaml/legacy_api.yaml @@ -2531,6 +2531,15 @@ kernel: func: broadcast_tensors backward: broadcast_tensors_grad + +# dirichlet +- api: dirichlet + args: (Tensor alpha) + output: Tensor + infer_meta: + func: DirichletInferMeta + kernel: + func: dirichlet # eig - api: eig diff --git a/paddle/phi/infermeta/unary.cc b/paddle/phi/infermeta/unary.cc index dbeba144b51..109459d0e11 100644 --- a/paddle/phi/infermeta/unary.cc +++ b/paddle/phi/infermeta/unary.cc @@ -518,6 +518,19 @@ void DiagonalInferMeta(const MetaTensor& input, out->set_dims(phi::make_ddim(out_dims)); } +void DirichletInferMeta(const MetaTensor& alpha, MetaTensor* out) { + const auto alpha_dim = alpha.dims(); + PADDLE_ENFORCE_GE(alpha_dim.size(), + 1, + phi::errors::InvalidArgument( + "ShapeError: The number of dimensions of 'Alpha' " + "must be greater than or euqal to 1. " + "But received Alpha's dimensions = %d,", + alpha_dim.size())); + out->set_dims(alpha_dim); + out->set_dtype(alpha.dtype()); +} + void EigInferMeta(const MetaTensor& x, MetaTensor* out_w, MetaTensor* out_v) { auto x_dims = x.dims(); int rank = x_dims.size(); diff --git a/paddle/phi/infermeta/unary.h b/paddle/phi/infermeta/unary.h index e7d04cb998b..77d155e16eb 100644 --- a/paddle/phi/infermeta/unary.h +++ b/paddle/phi/infermeta/unary.h @@ -90,6 +90,8 @@ void DiagInferMeta(const MetaTensor& x, void DiagonalInferMeta( const MetaTensor& input, int offset, int axis1, int axis2, MetaTensor* out); +void DirichletInferMeta(const MetaTensor& alpha, MetaTensor* out); + void EigInferMeta(const MetaTensor& x, MetaTensor* out_w, MetaTensor* out_v); void EighInferMeta(const MetaTensor& x, @@ -534,5 +536,4 @@ void ChannelShuffleInferMeta(const MetaTensor& x, MetaTensor* out); void IdentityLossInferMeta(const MetaTensor& x, int reduction, MetaTensor* out); - } // namespace phi diff --git a/paddle/phi/kernels/cpu/dirichlet_kernel.cc b/paddle/phi/kernels/cpu/dirichlet_kernel.cc new file mode 100644 index 00000000000..76ef2313441 --- /dev/null +++ b/paddle/phi/kernels/cpu/dirichlet_kernel.cc @@ -0,0 +1,102 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/backends/cpu/cpu_context.h" +#include "paddle/phi/core/dense_tensor.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/cpu/elementwise.h" +#include "paddle/phi/kernels/cpu/reduce.h" +#include "paddle/phi/kernels/funcs/elementwise_functor.h" +#include "paddle/phi/kernels/funcs/for_range.h" +#include "paddle/phi/kernels/funcs/reduce_functor.h" +#include "paddle/phi/kernels/impl/dirichlet_kernel_impl.h" + +namespace phi { + +template +struct GammaCPUFunctor { + GammaCPUFunctor(const T* alpha, + T* gamma, + BaseSampler uniform, + BaseSampler normal) + : alpha_(alpha), gamma_(gamma), uniform_(uniform), normal_(normal) {} + + HOST void operator()(int64_t index) { + auto sample = sample_gamma( + alpha_[index], uniform_, normal_); + gamma_[index] = std::max(std::numeric_limits::min(), sample); + } + + const T* alpha_; + T* gamma_; + BaseSampler uniform_; + BaseSampler normal_; +}; + +template +struct DirichletSampler { + void operator()(const CPUContext& dev_ctx, + const DenseTensor& alpha, + DenseTensor* out) { + auto generator = dev_ctx.GetGenerator()->GetCPUEngine(); + + auto uniform = [&generator]() -> T { + std::uniform_real_distribution u(0.0, 1.0); + return u(*generator); + }; + BaseSampler standard_uniform(uniform); + + auto normal = [&generator]() { + std::normal_distribution n(0.0, 1.0); + return n(*generator); + }; + BaseSampler standard_normal(normal); + + // sample from K gamma distributions, where K=alpha.numel() + DenseTensor gamma_samples; + gamma_samples.Resize(alpha.dims()); + dev_ctx.template Alloc(&gamma_samples); + + GammaCPUFunctor gamma_functor( + alpha.data(), + gamma_samples.data(), + standard_uniform, + standard_normal); + funcs::ForRange for_range(dev_ctx, alpha.numel()); + for_range(gamma_functor); + + // normalize them into a simplex, along the last axis + DenseTensor gamma_sum; + auto new_shape = gamma_samples.dims(); + new_shape[new_shape.size() - 1] = 1; + gamma_sum.Resize(new_shape); + dev_ctx.template Alloc(&gamma_sum); + + ReduceKernelImpl( + dev_ctx, + gamma_samples, + &gamma_sum, + {new_shape.size() - 1}, + true, + false); + + funcs::ElementwiseCompute, T, T>( + dev_ctx, gamma_samples, gamma_sum, -1, funcs::DivideFunctor(), out); + } +}; + +} // namespace phi + +PD_REGISTER_KERNEL( + dirichlet, CPU, ALL_LAYOUT, phi::Dirichletkernel, float, double) {} diff --git a/paddle/phi/kernels/dirichlet_kernel.h b/paddle/phi/kernels/dirichlet_kernel.h new file mode 100644 index 00000000000..a758eb8db02 --- /dev/null +++ b/paddle/phi/kernels/dirichlet_kernel.h @@ -0,0 +1,25 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/phi/core/dense_tensor.h" + +namespace phi { + +template +void Dirichletkernel(const Context& dev_ctx, + const DenseTensor& alpha, + DenseTensor* out); +} // namespace phi diff --git a/paddle/fluid/operators/dirichlet_op.cu b/paddle/phi/kernels/gpu/dirichlet_kernel.cu similarity index 59% rename from paddle/fluid/operators/dirichlet_op.cu rename to paddle/phi/kernels/gpu/dirichlet_kernel.cu index aa83f8de87b..eb34df90f08 100644 --- a/paddle/fluid/operators/dirichlet_op.cu +++ b/paddle/phi/kernels/gpu/dirichlet_kernel.cu @@ -1,3 +1,5 @@ + + // Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); @@ -12,12 +14,14 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "paddle/fluid/operators/dirichlet_op.h" -#include "paddle/fluid/framework/generator.h" -#include "paddle/fluid/operators/elementwise/elementwise_op_function.h" -#include "paddle/fluid/operators/reduce_ops/reduce_op.h" -#include "paddle/fluid/operators/reduce_ops/reduce_sum_op.h" -#include "paddle/fluid/platform/for_range.h" +#include "paddle/phi/backends/gpu/gpu_context.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/cpu/reduce.h" +#include "paddle/phi/kernels/funcs/broadcast_function.h" +#include "paddle/phi/kernels/funcs/elementwise_functor.h" +#include "paddle/phi/kernels/funcs/for_range.h" +#include "paddle/phi/kernels/funcs/reduce_functor.h" +#include "paddle/phi/kernels/impl/dirichlet_kernel_impl.h" #ifdef PADDLE_WITH_CUDA #include @@ -38,8 +42,7 @@ using COMPAT_RANDSTATEPHILOX4_32_10_T = hiprandStatePhilox4_32_10_t; #define COMPAT_RAND_NORMAL hiprand_normal #endif -namespace paddle { -namespace operators { +namespace phi { template struct GammaCUDAFunctor { GammaCUDAFunctor(const T* alpha, T* gamma, uint64_t seed, uint64_t offset) @@ -70,47 +73,44 @@ struct GammaCUDAFunctor { }; template -struct DirichletSampler { - void operator()(const framework::ExecutionContext& ctx, - const framework::Tensor* alpha, - framework::Tensor* out) { - auto& dev_ctx = ctx.device_context(); - - // init state, seed & offset for all threads - int device_id = ctx.GetPlace().GetDeviceId(); - auto p_gen = framework::DefaultCUDAGenerator(device_id); +struct DirichletSampler { + void operator()(const GPUContext& dev_ctx, + const DenseTensor& alpha, + DenseTensor* out) { + auto p_gen = dev_ctx.GetGenerator(); auto seed_and_offset = p_gen->IncrementOffset(10); // hard-coded offset auto seed = seed_and_offset.first; auto offset = seed_and_offset.second; // sample from K gamma distributions, where K=alpha.numel() - framework::Tensor gamma_samples; - gamma_samples.mutable_data(alpha->dims(), dev_ctx.GetPlace()); + DenseTensor gamma_samples; + gamma_samples.Resize(alpha.dims()); + dev_ctx.template Alloc(&gamma_samples); + GammaCUDAFunctor gamma_functor( - alpha->data(), gamma_samples.data(), seed, offset); - platform::ForRange for_range(dev_ctx, - out->numel()); + alpha.data(), gamma_samples.data(), seed, offset); + funcs::ForRange for_range(dev_ctx, out->numel()); for_range(gamma_functor); // normalize them into a simplex, along the last axis - framework::Tensor gamma_sum; + DenseTensor gamma_sum; auto new_shape = gamma_samples.dims(); new_shape[new_shape.size() - 1] = 1; - gamma_sum.mutable_data(new_shape, dev_ctx.GetPlace()); + gamma_sum.Resize(new_shape); + dev_ctx.template Alloc(&gamma_sum); - ReduceKernelFunctor( - &gamma_samples, &gamma_sum, {new_shape.size() - 1}, true, false, ctx) - .template apply(); - ElementwiseComputeEx, platform::CUDADeviceContext, T, T>( - ctx, &gamma_samples, &gamma_sum, -1, DivFunctor(), out); + ReduceKernelImpl( + dev_ctx, + gamma_samples, + &gamma_sum, + {new_shape.size() - 1}, + true, + false); + funcs::ElementwiseCompute, T, T>( + dev_ctx, gamma_samples, gamma_sum, -1, funcs::DivideFunctor(), out); } }; -} // namespace operators -} // namespace paddle - -namespace ops = paddle::operators; +} // namespace phi -REGISTER_OP_CUDA_KERNEL( - dirichlet, - ops::DirichletKernel, - ops::DirichletKernel); +PD_REGISTER_KERNEL( + dirichlet, GPU, ALL_LAYOUT, phi::Dirichletkernel, float, double) {} diff --git a/paddle/fluid/operators/dirichlet_op.h b/paddle/phi/kernels/impl/dirichlet_kernel_impl.h similarity index 84% rename from paddle/fluid/operators/dirichlet_op.h rename to paddle/phi/kernels/impl/dirichlet_kernel_impl.h index 75ee3580a0f..01858debc38 100644 --- a/paddle/fluid/operators/dirichlet_op.h +++ b/paddle/phi/kernels/impl/dirichlet_kernel_impl.h @@ -1,4 +1,4 @@ -// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -13,11 +13,10 @@ // limitations under the License. #pragma once + #include #include - -#include "paddle/fluid/framework/op_registry.h" -#include "paddle/fluid/platform/for_range.h" +#include "paddle/phi/kernels/dirichlet_kernel.h" // ROCM hcc doesn't work well with using std:: in kernel functions #if defined(PADDLE_WITH_CUDA) @@ -42,10 +41,7 @@ #define COMPAT_LOG1P std::log1p #endif -namespace paddle { -namespace operators { -template -struct DirichletSampler; +namespace phi { template struct BaseSampler { @@ -117,17 +113,19 @@ sample_gamma(ScalarT alpha, } } -template -class DirichletKernel : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext& ctx) const override { - const auto* alpha = ctx.Input("Alpha"); - auto* out = ctx.Output("Out"); - out->mutable_data(ctx.GetPlace()); - - DirichletSampler sampler; - sampler(ctx, alpha, out); - } +template +struct DirichletSampler { + void operator()(const Context& dev_ctx, + const DenseTensor& alpha, + DenseTensor* out); }; -} // namespace operators -} // namespace paddle + +template +void Dirichletkernel(const Context& dev_ctx, + const DenseTensor& alpha, + DenseTensor* out) { + dev_ctx.template Alloc(out); + DirichletSampler sampler; + sampler(dev_ctx, alpha, out); +} +} // namespace phi diff --git a/python/paddle/distribution/dirichlet.py b/python/paddle/distribution/dirichlet.py index 63466bda7c0..050af6069c5 100644 --- a/python/paddle/distribution/dirichlet.py +++ b/python/paddle/distribution/dirichlet.py @@ -15,7 +15,7 @@ import paddle from paddle.distribution import exponential_family from paddle.fluid.data_feeder import check_variable_and_dtype -from paddle.fluid.framework import _non_static_mode, in_dygraph_mode +from paddle.fluid.framework import in_dygraph_mode, _in_legacy_dygraph from paddle.fluid.layer_helper import LayerHelper @@ -157,9 +157,10 @@ def _dirichlet(concentration, name=None): check_variable_and_dtype(concentration, 'concentration', ['float32', 'float64'], op_type) - if _non_static_mode(): + if in_dygraph_mode(): + return paddle._C_ops.final_state_dirichlet(concentration) + elif _in_legacy_dygraph(): return paddle._C_ops.dirichlet(concentration) - else: helper = LayerHelper(op_type, **locals()) out = helper.create_variable_for_type_inference( -- GitLab