From bc106fad5ae433c0a5e878bdee729be13da2707f Mon Sep 17 00:00:00 2001 From: wuyefeilin <30919197+wuyefeilin@users.noreply.github.com> Date: Wed, 3 Aug 2022 11:01:07 +0800 Subject: [PATCH] [PHI] Move uniform random inplace op to PHI. (#44700) --- .../operators/uniform_random_inplace_op.cc | 103 +++--------------- .../operators/uniform_random_inplace_op.cu | 57 ---------- paddle/phi/api/yaml/legacy_api.yaml | 11 ++ paddle/phi/api/yaml/legacy_backward.yaml | 10 ++ paddle/phi/infermeta/backward.cc | 18 +++ paddle/phi/infermeta/backward.h | 9 ++ paddle/phi/infermeta/unary.cc | 37 +++++++ paddle/phi/infermeta/unary.h | 9 ++ .../cpu/uniform_random_inplace_grad_kernel.cc | 44 ++++++++ .../cpu/uniform_random_inplace_kernel.cc | 54 +++++++++ .../gpu/uniform_random_inplace_grad_kernel.cu | 44 ++++++++ .../gpu/uniform_random_inplace_kernel.cu | 88 +++++++++++++++ .../uniform_random_inplace_grad_kernel.h | 32 ++++++ .../kernels/uniform_random_inplace_kernel.h | 32 ++++++ .../ops/compat/uniform_random_inplace_sig.cc | 42 +++++++ .../test_uniform_random_inplace_op.py | 30 +++++ python/paddle/tensor/random.py | 8 +- 17 files changed, 484 insertions(+), 144 deletions(-) delete mode 100644 paddle/fluid/operators/uniform_random_inplace_op.cu create mode 100644 paddle/phi/kernels/cpu/uniform_random_inplace_grad_kernel.cc create mode 100644 paddle/phi/kernels/cpu/uniform_random_inplace_kernel.cc create mode 100644 paddle/phi/kernels/gpu/uniform_random_inplace_grad_kernel.cu create mode 100644 paddle/phi/kernels/gpu/uniform_random_inplace_kernel.cu create mode 100644 paddle/phi/kernels/uniform_random_inplace_grad_kernel.h create mode 100644 paddle/phi/kernels/uniform_random_inplace_kernel.h create mode 100644 paddle/phi/ops/compat/uniform_random_inplace_sig.cc diff --git a/paddle/fluid/operators/uniform_random_inplace_op.cc b/paddle/fluid/operators/uniform_random_inplace_op.cc index e4283bae06e..09870c8401e 100644 --- a/paddle/fluid/operators/uniform_random_inplace_op.cc +++ b/paddle/fluid/operators/uniform_random_inplace_op.cc @@ -12,9 +12,11 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#include "paddle/fluid/framework/generator.h" +#include "paddle/fluid/framework/infershape_utils.h" #include "paddle/fluid/framework/op_registry.h" -#include "paddle/fluid/framework/operator.h" +#include "paddle/phi/core/infermeta_utils.h" +#include "paddle/phi/infermeta/backward.h" +#include "paddle/phi/infermeta/unary.h" namespace paddle { namespace operators { @@ -54,34 +56,6 @@ class UniformRandomInplaceOp : public framework::OperatorWithKernel { public: using framework::OperatorWithKernel::OperatorWithKernel; - void InferShape(framework::InferShapeContext *ctx) const override { - OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "UniformRandomInplaceOp"); - OP_INOUT_CHECK( - ctx->HasOutput("Out"), "Output", "Out", "UniformRandomInplaceOp"); - PADDLE_ENFORCE_LT( - ctx->Attrs().Get("min"), - ctx->Attrs().Get("max"), - platform::errors::InvalidArgument( - "The uniform_random's min must less then max. But received min = " - "%f great than or equal max = %f.", - ctx->Attrs().Get("min"), - ctx->Attrs().Get("max"))); - PADDLE_ENFORCE_GE(ctx->Attrs().Get("diag_num"), - 0, - platform::errors::InvalidArgument( - "The uniform_random's diag_num must greater than or " - "equal 0. But recevied diag_num (%d) < 0.", - ctx->Attrs().Get("diag_num"))); - PADDLE_ENFORCE_GE(ctx->Attrs().Get("diag_step"), - 0, - platform::errors::InvalidArgument( - "The uniform_random's diag_step must greater than or " - "equal 0. But recevied diag_step (%d) < 0.", - ctx->Attrs().Get("diag_step"))); - auto xdim = ctx->GetInputDim("X"); - ctx->SetOutputDim("Out", xdim); - } - protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { @@ -90,23 +64,9 @@ class UniformRandomInplaceOp : public framework::OperatorWithKernel { } }; -template -class CPUUniformRandomInplaceKernel : public framework::OpKernel { +class UniformRandomInplaceGradOp : public framework::OperatorWithKernel { public: - void Compute(const framework::ExecutionContext &ctx) const override { - auto out_var = ctx.OutputVar("Out"); - auto *tensor = out_var->GetMutable(); - T *data = tensor->mutable_data(ctx.GetPlace()); - int64_t size = tensor->numel(); - std::uniform_real_distribution dist( - static_cast(ctx.Attr("min")), - static_cast(ctx.Attr("max"))); - auto engine = paddle::framework::GetCPURandomEngine( - static_cast(ctx.Attr("seed"))); - for (int64_t i = 0; i < size; ++i) { - data[i] = dist(*engine); - } - } + using framework::OperatorWithKernel::OperatorWithKernel; }; class UniformRandomInplaceOpVarTypeInference @@ -115,23 +75,6 @@ class UniformRandomInplaceOpVarTypeInference void operator()(framework::InferVarTypeContext *ctx) const override {} }; -class UniformRandomInplaceGradOp : public framework::OperatorWithKernel { - public: - using framework::OperatorWithKernel::OperatorWithKernel; - - void InferShape(framework::InferShapeContext *ctx) const override { - OP_INOUT_CHECK(ctx->HasInput(framework::GradVarName("Out")), - "Input", - "Out_Grad", - "UniformRandomInplaceGradOp"); - auto x_dims = ctx->GetInputDim(framework::GradVarName("Out")); - auto x_grad_name = framework::GradVarName("X"); - if (ctx->HasOutput(x_grad_name)) { - ctx->SetOutputDim(x_grad_name, x_dims); - } - } -}; - template class UniformRandomInplaceGradOpMaker : public framework::SingleGradOpMaker { public: @@ -146,18 +89,6 @@ class UniformRandomInplaceGradOpMaker : public framework::SingleGradOpMaker { } }; -template -class CPUUniformRandomInplaceGradKernel : public framework::OpKernel { - public: - void Compute(const paddle::framework::ExecutionContext &ctx) const override { - auto *dx = ctx.Output(framework::GradVarName("X")); - if (dx) { - auto *data = dx->mutable_data(ctx.GetPlace()); - std::fill(data, data + dx->numel(), T(0)); - } - } -}; - } // namespace operators } // namespace paddle DECLARE_INPLACE_OP_INFERER(UniformRandomInplaceInferer, {"X", "Out"}); @@ -165,6 +96,14 @@ DECLARE_INPLACE_OP_INFERER(UniformRandomInplaceGradInplaceInferer, {paddle::framework::GradVarName("Out"), paddle::framework::GradVarName("X")}); +DECLARE_INFER_SHAPE_FUNCTOR(uniform_random_inplace, + UniformRandomInplaceInferShapeFunctor, + PD_INFER_META(phi::UniformRandomInplaceInferMeta)); +DECLARE_INFER_SHAPE_FUNCTOR( + uniform_random_inplace_grad, + UniformRandomInplaceGradInferShapeFunctor, + PD_INFER_META(phi::UniformRandomInplaceGradInferMeta)); + REGISTER_OPERATOR(uniform_random_inplace, paddle::operators::UniformRandomInplaceOp, paddle::operators::UniformRandomInplaceOpMaker, @@ -173,15 +112,9 @@ REGISTER_OPERATOR(uniform_random_inplace, paddle::operators::UniformRandomInplaceGradOpMaker< paddle::imperative::OpBase>, paddle::operators::UniformRandomInplaceOpVarTypeInference, - UniformRandomInplaceInferer); + UniformRandomInplaceInferer, + UniformRandomInplaceInferShapeFunctor); REGISTER_OPERATOR(uniform_random_inplace_grad, paddle::operators::UniformRandomInplaceGradOp, - UniformRandomInplaceGradInplaceInferer); -REGISTER_OP_CPU_KERNEL( - uniform_random_inplace, - paddle::operators::CPUUniformRandomInplaceKernel, - paddle::operators::CPUUniformRandomInplaceKernel); -REGISTER_OP_CPU_KERNEL( - uniform_random_inplace_grad, - paddle::operators::CPUUniformRandomInplaceGradKernel, - paddle::operators::CPUUniformRandomInplaceGradKernel); + UniformRandomInplaceGradInplaceInferer, + UniformRandomInplaceGradInferShapeFunctor); diff --git a/paddle/fluid/operators/uniform_random_inplace_op.cu b/paddle/fluid/operators/uniform_random_inplace_op.cu deleted file mode 100644 index a3490937410..00000000000 --- a/paddle/fluid/operators/uniform_random_inplace_op.cu +++ /dev/null @@ -1,57 +0,0 @@ -/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#include "paddle/fluid/operators/uniform_random_op.h" -#include "paddle/phi/kernels/full_kernel.h" - -namespace paddle { -namespace operators { -template -class GPUUniformRandomInplaceKernel : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext& context) const override { - auto* tensor = context.Output("Out"); - UniformRandom(context, tensor); - } -}; - -template -class GPUUniformRandomInplaceGradKernel : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext& ctx) const override { - auto* dx = ctx.Output(framework::GradVarName("X")); - auto dims = vectorize(dx->dims()); - const auto& dev_cxt = ctx.template device_context(); - float value = static_cast(0.0f); - phi::FullKernel( - static_cast::TYPE&>(dev_cxt), - dims, - value, - phi::DataType::UNDEFINED, - dx); - } -}; - -} // namespace operators -} // namespace paddle - -REGISTER_OP_CUDA_KERNEL( - uniform_random_inplace, - paddle::operators::GPUUniformRandomInplaceKernel, - paddle::operators::GPUUniformRandomInplaceKernel); -REGISTER_OP_CUDA_KERNEL( - uniform_random_inplace_grad, - paddle::operators::GPUUniformRandomInplaceGradKernel, - paddle::operators::GPUUniformRandomInplaceGradKernel); diff --git a/paddle/phi/api/yaml/legacy_api.yaml b/paddle/phi/api/yaml/legacy_api.yaml index fa212ea8f12..fdd86857ed7 100755 --- a/paddle/phi/api/yaml/legacy_api.yaml +++ b/paddle/phi/api/yaml/legacy_api.yaml @@ -2726,3 +2726,14 @@ kernel: func: overlap_add backward: overlap_add_grad + +- api: uniform_random_inplace + args: (Tensor x, float min, float max, int seed, int diag_num, int diag_step, float diag_val) + output: Tensor(out) + infer_meta: + func: UniformRandomInplaceInferMeta + kernel: + func: uniform_random_inplace + data_type: x + inplace: (x -> out) + backward: uniform_random_inplace_grad diff --git a/paddle/phi/api/yaml/legacy_backward.yaml b/paddle/phi/api/yaml/legacy_backward.yaml index c7fa3c13e60..64b68ba6b38 100755 --- a/paddle/phi/api/yaml/legacy_backward.yaml +++ b/paddle/phi/api/yaml/legacy_backward.yaml @@ -2512,6 +2512,16 @@ func : unfold_grad no_need_buffer : x +- backward_api : uniform_random_inplace_grad + forward : uniform_random_inplace(Tensor x, float min, float max, int seed, int diag_num, int diag_step, float diag_val) -> Tensor(out) + args : (Tensor out_grad, float min, float max, int seed, int diag_num, int diag_step, float diag_val) + output : Tensor(x_grad) + infer_meta : + func : UniformRandomInplaceGradInferMeta + kernel : + func : uniform_random_inplace_grad + inplace : (out_grad -> x_grad) + - backward_api : unsqueeze_double_grad forward : unsqueeze_grad(Tensor xshape, Tensor grad_out, IntArray axes) -> Tensor(grad_x) args : (Tensor grad_x_grad, IntArray axes) diff --git a/paddle/phi/infermeta/backward.cc b/paddle/phi/infermeta/backward.cc index a8555827c05..e375999bfba 100644 --- a/paddle/phi/infermeta/backward.cc +++ b/paddle/phi/infermeta/backward.cc @@ -798,6 +798,24 @@ void StackGradInferMeta(const MetaTensor& out_grad, } } +void UniformRandomInplaceGradInferMeta(const MetaTensor& out_grad, + float min, + float max, + int seed, + int diag_num, + int diag_step, + float diag_val, + MetaTensor* x_grad) { + PADDLE_ENFORCE_NE( + x_grad, + nullptr, + phi::errors::InvalidArgument( + "The X@GRAD in UniformRandomInplaceGradInferMeta can't be nullptr.")); + auto dims = out_grad.dims(); + x_grad->set_dims(dims); + x_grad->set_dtype(out_grad.dtype()); +} + void UnStackGradInferMeta(const std::vector& out_grad, int axis, MetaTensor* x_grad) { diff --git a/paddle/phi/infermeta/backward.h b/paddle/phi/infermeta/backward.h index d9208b7c524..2d31860c17b 100755 --- a/paddle/phi/infermeta/backward.h +++ b/paddle/phi/infermeta/backward.h @@ -330,6 +330,15 @@ void StackGradInferMeta(const MetaTensor& out_grad, int axis, std::vector x_grad); +void UniformRandomInplaceGradInferMeta(const MetaTensor& out_grad, + float min, + float max, + int seed, + int diag_num, + int diag_step, + float diag_val, + MetaTensor* x_grad); + void UnStackGradInferMeta(const std::vector& out_grad, int axis, MetaTensor* x_grad); diff --git a/paddle/phi/infermeta/unary.cc b/paddle/phi/infermeta/unary.cc index 0aa2035257a..8389615f386 100644 --- a/paddle/phi/infermeta/unary.cc +++ b/paddle/phi/infermeta/unary.cc @@ -3623,6 +3623,43 @@ void UnfoldInferMeta(const MetaTensor& x, out->set_dims(phi::make_ddim(out_dims)); } +void UniformRandomInplaceInferMeta(const MetaTensor& x, + float min, + float max, + int seed, + int diag_num, + int diag_step, + float diag_val, + MetaTensor* out) { + PADDLE_ENFORCE_LT( + min, + max, + errors::InvalidArgument( + "The uniform_random's min must less then max. But received min = " + "%f great than or equal max = %f.", + min, + max)); + PADDLE_ENFORCE_GE(diag_num, + 0, + errors::InvalidArgument( + "The uniform_random's diag_num must greater than or " + "equal 0. But recevied diag_num (%d) < 0.", + diag_num)); + PADDLE_ENFORCE_GE(diag_step, + 0, + errors::InvalidArgument( + "The uniform_random's diag_step must greater than or " + "equal 0. But recevied diag_step (%d) < 0.", + diag_step)); + PADDLE_ENFORCE_NE(out, + nullptr, + phi::errors::InvalidArgument( + "uniform_random should have output tensor out.")); + auto xdim = x.dims(); + out->set_dims(xdim); + out->set_dtype(x.dtype()); +} + void UniqueConsecutiveInferMeta(const MetaTensor& x, bool return_inverse, bool return_counts, diff --git a/paddle/phi/infermeta/unary.h b/paddle/phi/infermeta/unary.h index a37492cf7ec..72e6be818a2 100644 --- a/paddle/phi/infermeta/unary.h +++ b/paddle/phi/infermeta/unary.h @@ -492,6 +492,15 @@ void UnfoldInferMeta(const MetaTensor& x, MetaTensor* out, MetaConfig config = MetaConfig()); +void UniformRandomInplaceInferMeta(const MetaTensor& x, + float min, + float max, + int seed, + int diag_num, + int diag_step, + float diag_val, + MetaTensor* out); + void UniqueConsecutiveInferMeta(const MetaTensor& x, bool return_inverse, bool return_counts, diff --git a/paddle/phi/kernels/cpu/uniform_random_inplace_grad_kernel.cc b/paddle/phi/kernels/cpu/uniform_random_inplace_grad_kernel.cc new file mode 100644 index 00000000000..d448312949e --- /dev/null +++ b/paddle/phi/kernels/cpu/uniform_random_inplace_grad_kernel.cc @@ -0,0 +1,44 @@ +/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/phi/kernels/uniform_random_inplace_grad_kernel.h" + +#include "paddle/phi/core/kernel_registry.h" + +namespace phi { + +template +void UniformRandomInplaceGradKernel(const Context& ctx, + const DenseTensor& out_grad, + float min, + float max, + int seed, + int diag_num, + int diag_step, + float diag_val, + DenseTensor* x_grad) { + if (x_grad) { + auto* data = ctx.template Alloc(x_grad); + std::fill(data, data + x_grad->numel(), T(0)); + } +} + +} // namespace phi + +PD_REGISTER_KERNEL(uniform_random_inplace_grad, + CPU, + ALL_LAYOUT, + phi::UniformRandomInplaceGradKernel, + float, + double) {} diff --git a/paddle/phi/kernels/cpu/uniform_random_inplace_kernel.cc b/paddle/phi/kernels/cpu/uniform_random_inplace_kernel.cc new file mode 100644 index 00000000000..6e687fbf543 --- /dev/null +++ b/paddle/phi/kernels/cpu/uniform_random_inplace_kernel.cc @@ -0,0 +1,54 @@ +/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/phi/kernels/uniform_random_inplace_kernel.h" + +#include "paddle/phi/core/kernel_registry.h" + +namespace phi { + +template +void UniformRandomInplaceKernel(const Context& ctx, + const DenseTensor& x, + float min, + float max, + int seed, + int diag_num, + int diag_step, + float diag_val, + DenseTensor* out) { + T* data = ctx.template Alloc(out); + int64_t size = out->numel(); + std::uniform_real_distribution dist(static_cast(min), + static_cast(max)); + std::shared_ptr engine; + if (seed) { + engine = std::make_shared(); + engine->seed(seed); + } else { + engine = ctx.GetGenerator()->GetCPUEngine(); + } + for (int64_t i = 0; i < size; ++i) { + data[i] = dist(*engine); + } +} + +} // namespace phi + +PD_REGISTER_KERNEL(uniform_random_inplace, + CPU, + ALL_LAYOUT, + phi::UniformRandomInplaceKernel, + float, + double) {} diff --git a/paddle/phi/kernels/gpu/uniform_random_inplace_grad_kernel.cu b/paddle/phi/kernels/gpu/uniform_random_inplace_grad_kernel.cu new file mode 100644 index 00000000000..6c6f525a8d9 --- /dev/null +++ b/paddle/phi/kernels/gpu/uniform_random_inplace_grad_kernel.cu @@ -0,0 +1,44 @@ +/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/phi/kernels/uniform_random_inplace_grad_kernel.h" + +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/full_kernel.h" + +namespace phi { + +template +void UniformRandomInplaceGradKernel(const Context& ctx, + const DenseTensor& out_grad, + float min, + float max, + int seed, + int diag_num, + int diag_step, + float diag_val, + DenseTensor* x_grad) { + auto dims = vectorize(x_grad->dims()); + float value = static_cast(0.0f); + phi::FullKernel(ctx, dims, value, phi::DataType::UNDEFINED, x_grad); +} + +} // namespace phi + +PD_REGISTER_KERNEL(uniform_random_inplace_grad, + GPU, + ALL_LAYOUT, + phi::UniformRandomInplaceGradKernel, + float, + double) {} diff --git a/paddle/phi/kernels/gpu/uniform_random_inplace_kernel.cu b/paddle/phi/kernels/gpu/uniform_random_inplace_kernel.cu new file mode 100644 index 00000000000..d96f582b191 --- /dev/null +++ b/paddle/phi/kernels/gpu/uniform_random_inplace_kernel.cu @@ -0,0 +1,88 @@ +/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/phi/kernels/uniform_random_inplace_kernel.h" + +#include + +#include "gflags/gflags.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/funcs/distribution_helper.h" +#include "paddle/phi/kernels/funcs/index_impl.cu.h" + +namespace phi { + +template +struct UniformGenerator { + T min_, max_; + unsigned int seed_; + T diag_val_; + unsigned int diag_num_; + unsigned int diag_step_; + __host__ __device__ UniformGenerator( + T min, T max, int seed, int diag_num, int diag_step, T diag_val) + : min_(min), + max_(max), + seed_(seed), + diag_num_(diag_num), + diag_step_(diag_step), + diag_val_(diag_val) {} + + __host__ __device__ T operator()(const unsigned int n) const { + thrust::minstd_rand rng; + rng.seed(seed_); + thrust::uniform_real_distribution dist(min_, max_); + rng.discard(n); + T out = dist(rng); + unsigned int remainder = n % (diag_step_ + 1); + if (remainder == 0 && diag_num_ > n / (diag_step_ + 1)) { + out = diag_val_; + } + return out; + } +}; + +template +void UniformRandomInplaceKernel(const Context& ctx, + const DenseTensor& x, + float min, + float max, + int seed, + int diag_num, + int diag_step, + float diag_val, + DenseTensor* out) { + ctx.template Alloc(out); + if (seed == 0) { + // Use global Generator seed + using MT = typename kps::details::MPTypeTrait::Type; + funcs::uniform_distribution dist; + funcs::uniform_real_transform trans(min, max); + funcs::distribution_and_transform(ctx, out, dist, trans); + } else { + // Use OP seed + auto func = + UniformGenerator(min, max, seed, diag_num, diag_step, diag_val); + IndexKernel>(ctx, out, func); + } +} + +} // namespace phi + +PD_REGISTER_KERNEL(uniform_random_inplace, + GPU, + ALL_LAYOUT, + phi::UniformRandomInplaceKernel, + float, + double) {} diff --git a/paddle/phi/kernels/uniform_random_inplace_grad_kernel.h b/paddle/phi/kernels/uniform_random_inplace_grad_kernel.h new file mode 100644 index 00000000000..ae74fbe2fd7 --- /dev/null +++ b/paddle/phi/kernels/uniform_random_inplace_grad_kernel.h @@ -0,0 +1,32 @@ +/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include "paddle/phi/core/dense_tensor.h" + +namespace phi { + +template +void UniformRandomInplaceGradKernel(const Context& ctx, + const DenseTensor& out_grad, + float min, + float max, + int seed, + int diag_num, + int diag_step, + float diag_val, + DenseTensor* x_grad); + +} // namespace phi diff --git a/paddle/phi/kernels/uniform_random_inplace_kernel.h b/paddle/phi/kernels/uniform_random_inplace_kernel.h new file mode 100644 index 00000000000..97a79375aff --- /dev/null +++ b/paddle/phi/kernels/uniform_random_inplace_kernel.h @@ -0,0 +1,32 @@ +/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include "paddle/phi/core/dense_tensor.h" + +namespace phi { + +template +void UniformRandomInplaceKernel(const Context& ctx, + const DenseTensor& x, + float min, + float max, + int seed, + int diag_num, + int diag_step, + float diag_val, + DenseTensor* out); + +} // namespace phi diff --git a/paddle/phi/ops/compat/uniform_random_inplace_sig.cc b/paddle/phi/ops/compat/uniform_random_inplace_sig.cc new file mode 100644 index 00000000000..afdc0d5f3b3 --- /dev/null +++ b/paddle/phi/ops/compat/uniform_random_inplace_sig.cc @@ -0,0 +1,42 @@ +/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/phi/core/compat/op_utils.h" + +namespace phi { +KernelSignature UniformRandomInplaceOpArgumentMapping( + const ArgumentMappingContext& ctx) { + return KernelSignature( + "uniform_random_inplace", + {"X"}, + {"min", "max", "seed", "diag_num", "diag_step", "diag_val"}, + {"Out"}); +} + +KernelSignature UniformRandomInplaceGradOpArgumentMapping( + const ArgumentMappingContext& ctx) { + return KernelSignature( + "uniform_random_inplace_grad", + {"Out@GRAD"}, + {"min", "max", "seed", "diag_num", "diag_step", "diag_val"}, + {"X@GRAD"}); +} + +} // namespace phi + +PD_REGISTER_ARG_MAPPING_FN(uniform_random_inplace, + phi::UniformRandomInplaceOpArgumentMapping); + +PD_REGISTER_ARG_MAPPING_FN(uniform_random_inplace_grad, + phi::UniformRandomInplaceGradOpArgumentMapping); diff --git a/python/paddle/fluid/tests/unittests/test_uniform_random_inplace_op.py b/python/paddle/fluid/tests/unittests/test_uniform_random_inplace_op.py index 2e0196d4b16..ae772818cbc 100644 --- a/python/paddle/fluid/tests/unittests/test_uniform_random_inplace_op.py +++ b/python/paddle/fluid/tests/unittests/test_uniform_random_inplace_op.py @@ -16,6 +16,7 @@ import unittest import paddle import paddle.fluid as fluid import numpy as np +from paddle.fluid.framework import _enable_legacy_dygraph, _disable_legacy_dygraph class TestUniformRandomInplaceOpDtype(unittest.TestCase): @@ -191,5 +192,34 @@ class TestUniformRandomInplaceGrad(unittest.TestCase): fluid.set_flags({"FLAGS_retain_grad_for_all_tensor": False}) +class TestUniformRandomInplaceGradOldDygraph(unittest.TestCase): + + def setUp(self): + self.shape = (1000, 784) + + def test_uniform_random_inplace_grad(self): + _enable_legacy_dygraph() + fluid.set_flags({"FLAGS_retain_grad_for_all_tensor": True}) + + def test_grad(): + tensor_a = paddle.ones(self.shape) + tensor_a.stop_gradient = False + tensor_b = tensor_a * 0.5 + tensor_b.uniform_(min=-2, max=2) + loss = tensor_b.sum() + loss.backward() + uniform_grad = tensor_b.grad.numpy() + self.assertTrue((uniform_grad == 0).all()) + + places = ['cpu'] + if fluid.core.is_compiled_with_cuda(): + places.append('gpu') + for place in places: + paddle.set_device(place) + test_grad() + fluid.set_flags({"FLAGS_retain_grad_for_all_tensor": False}) + _disable_legacy_dygraph() + + if __name__ == '__main__': unittest.main() diff --git a/python/paddle/tensor/random.py b/python/paddle/tensor/random.py index 59f40f7f39c..663c2ccb918 100644 --- a/python/paddle/tensor/random.py +++ b/python/paddle/tensor/random.py @@ -620,8 +620,12 @@ def uniform_(x, min=-1.0, max=1.0, seed=0, name=None): # [-0.34646994, -0.45116323, -0.09902662, -0.11397249], # random # [ 0.433519, 0.39483607, -0.8660099, 0.83664286]] # random """ - return _C_ops.uniform_random_inplace_(x, 'min', min, 'max', max, 'seed', - seed) + if in_dygraph_mode(): + return _C_ops.final_state_uniform_random_inplace_( + x, min, max, seed, 0, 0, 1.0) + else: + return _C_ops.uniform_random_inplace_(x, 'min', min, 'max', max, 'seed', + seed) def randint(low=0, high=None, shape=[1], dtype=None, name=None): -- GitLab