diff --git a/paddle/fluid/operators/gaussian_random_op.cc b/paddle/fluid/operators/gaussian_random_op.cc index 6b559885c569d001233525c3d964fff2175950e3..66eecc13d04d1aa7d4532b69f7a2fbe8c62b8e6f 100644 --- a/paddle/fluid/operators/gaussian_random_op.cc +++ b/paddle/fluid/operators/gaussian_random_op.cc @@ -15,12 +15,14 @@ limitations under the License. */ #include #include "paddle/fluid/framework/generator.h" +#include "paddle/fluid/framework/infershape_utils.h" #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_version_registry.h" #include "paddle/fluid/operators/fill_constant_op.h" #ifdef PADDLE_WITH_MKLDNN #include "paddle/fluid/platform/mkldnn_helper.h" #endif +#include "paddle/phi/infermeta/nullary.h" namespace paddle { namespace operators { @@ -54,38 +56,6 @@ class GaussianRandomOp : public framework::OperatorWithKernel { public: using framework::OperatorWithKernel::OperatorWithKernel; - void InferShape(framework::InferShapeContext* ctx) const override { - OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "GaussianRandom"); - - auto shape = ctx->Attrs().Get>("shape"); - std::vector temp; - temp.reserve(shape.size()); - for (auto dim : shape) { - temp.push_back(static_cast(dim)); - } - if (shape.empty() && ctx->HasInput("ShapeTensor")) { - auto shape_dims = ctx->GetInputDim("ShapeTensor"); - int num_ele = 1; - for (int i = 0; i < shape_dims.size(); ++i) { - num_ele *= shape_dims[i]; - } - auto vec_dims = std::vector(num_ele, -1); - ctx->SetOutputDim("Out", phi::make_ddim(vec_dims)); - - return; - } - if (!ctx->HasInput("ShapeTensor") && !ctx->HasInputs("ShapeTensorList")) { - PADDLE_ENFORCE_GT( - shape.size(), 0UL, - platform::errors::InvalidArgument( - "Attribute(shape) of GaussianRandomOp must be set " - "and shape.size() > 0, but reveived shape.size() is %d", - shape.size())); - } - - ctx->SetOutputDim("Out", phi::make_ddim(temp)); - } - protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { @@ -171,11 +141,20 @@ Used to initialize tensors with gaussian random generator. } // namespace paddle namespace ops = paddle::operators; -REGISTER_OP_WITHOUT_GRADIENT(gaussian_random, ops::GaussianRandomOp, - ops::GaussianRandomOpMaker); + +DECLARE_INFER_SHAPE_FUNCTOR(gaussian_random, GaussianRandomInferShapeFunctor, + PD_INFER_META(phi::GaussianRandomInferMeta)); + +REGISTER_OPERATOR( + gaussian_random, ops::GaussianRandomOp, ops::GaussianRandomOpMaker, + paddle::framework::EmptyGradOpMaker, + paddle::framework::EmptyGradOpMaker, + GaussianRandomInferShapeFunctor); + REGISTER_OP_CPU_KERNEL(gaussian_random_batch_size_like, ops::CPUGaussianRandomBatchSizeLikeKernel, ops::CPUGaussianRandomBatchSizeLikeKernel); + REGISTER_OP_VERSION(gaussian_random) .AddCheckpoint( R"ROC( diff --git a/paddle/fluid/operators/truncated_gaussian_random_op.cc b/paddle/fluid/operators/truncated_gaussian_random_op.cc index 6eb7f922dfdbec41aa1c47d11e1decc259d08689..dc5a66dce16d698f9cfac01e3bdc776d08c2af19 100644 --- a/paddle/fluid/operators/truncated_gaussian_random_op.cc +++ b/paddle/fluid/operators/truncated_gaussian_random_op.cc @@ -17,8 +17,10 @@ limitations under the License. */ #include #include "paddle/fluid/framework/generator.h" +#include "paddle/fluid/framework/infershape_utils.h" #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/operators/truncated_gaussian_random_op.h" +#include "paddle/phi/infermeta/nullary.h" namespace paddle { namespace operators { @@ -27,26 +29,6 @@ class TruncatedGaussianRandomOp : public framework::OperatorWithKernel { public: using framework::OperatorWithKernel::OperatorWithKernel; - void InferShape(framework::InferShapeContext* ctx) const override { - PADDLE_ENFORCE_EQ( - ctx->HasOutput("Out"), true, - platform::errors::NotFound( - "Output(Out) of TruncatedGaussianRandomOp should not be null.")); - auto shape = ctx->Attrs().Get>("shape"); - std::vector out_dim; - out_dim.reserve(shape.size()); - for (auto dim : shape) { - out_dim.push_back(static_cast(dim)); - } - PADDLE_ENFORCE_GT( - shape.size(), 0UL, - platform::errors::InvalidArgument( - "the input shape of TruncatedGaussianRandomOp must be set, " - "But the rank of shape we received is %d", - shape.size())); - ctx->SetOutputDim("Out", phi::make_ddim(out_dim)); - } - protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { @@ -99,6 +81,14 @@ Used to initialize tensors with truncated gaussian random generator. } // namespace paddle namespace ops = paddle::operators; -REGISTER_OP_WITHOUT_GRADIENT(truncated_gaussian_random, - ops::TruncatedGaussianRandomOp, - ops::TruncatedGaussianRandomOpMaker); + +DECLARE_INFER_SHAPE_FUNCTOR( + truncated_gaussian_random, TruncatedGaussianRandomInferShapeFunctor, + PD_INFER_META(phi::TruncatedGaussianRandomInferMeta)); + +REGISTER_OPERATOR( + truncated_gaussian_random, ops::TruncatedGaussianRandomOp, + ops::TruncatedGaussianRandomOpMaker, + paddle::framework::EmptyGradOpMaker, + paddle::framework::EmptyGradOpMaker, + TruncatedGaussianRandomInferShapeFunctor); diff --git a/paddle/phi/infermeta/nullary.cc b/paddle/phi/infermeta/nullary.cc index 0c48c9d0c7eae5de0fdb3d2c4c7bb0a9765e7b9f..506d3fd14ea3fd568ce2f77d7ce30408062279e9 100644 --- a/paddle/phi/infermeta/nullary.cc +++ b/paddle/phi/infermeta/nullary.cc @@ -40,4 +40,29 @@ void EyeInferMeta(int64_t num_rows, out->set_dims({num_rows, num_columns}); out->set_dtype(dtype); } + +void TruncatedGaussianRandomInferMeta(const std::vector& shape, + float mean, + float std, + int seed, + DataType dtype, + MetaTensor* out) { + auto out_dims = phi::make_ddim(shape); + out->set_dims(out_dims); + out->set_dtype(dtype); + out->set_layout(DataLayout::NCHW); +} + +void GaussianRandomInferMeta(const ScalarArray& shape, + float mean, + float std, + int seed, + DataType dtype, + MetaTensor* out) { + auto out_dims = phi::make_ddim(shape.GetData()); + out->set_dims(out_dims); + out->set_dtype(dtype); + out->set_layout(DataLayout::NCHW); +} + } // namespace phi diff --git a/paddle/phi/infermeta/nullary.h b/paddle/phi/infermeta/nullary.h index 40d6ea595c0c95f5d01cf8fd31fa1fdce89d5037..bd0567486e4d62a9f6fe9adfa02727bfe79937e1 100644 --- a/paddle/phi/infermeta/nullary.h +++ b/paddle/phi/infermeta/nullary.h @@ -40,4 +40,18 @@ void EyeInferMeta(int64_t num_rows, DataType dtype, MetaTensor* out); +void TruncatedGaussianRandomInferMeta(const std::vector& shape, + float mean, + float std, + int seed, + DataType dtype, + MetaTensor* out); + +void GaussianRandomInferMeta(const ScalarArray& shape, + float mean, + float std, + int seed, + DataType dtype, + MetaTensor* out); + } // namespace phi diff --git a/paddle/phi/kernels/cpu/truncated_gaussian_random_kernel.cc b/paddle/phi/kernels/cpu/truncated_gaussian_random_kernel.cc index ebc032ef54538188d8e287673c0d31fae9ad197b..4247e597acef4aac14f93066a3ea6232734e0c8c 100644 --- a/paddle/phi/kernels/cpu/truncated_gaussian_random_kernel.cc +++ b/paddle/phi/kernels/cpu/truncated_gaussian_random_kernel.cc @@ -27,7 +27,7 @@ namespace phi { template void TruncatedGaussianRandomKernel(const Context& dev_ctx, - const ScalarArray& shape, + const std::vector& shape, float mean, float std, int seed, diff --git a/paddle/phi/kernels/gpu/truncated_gaussian_random_kernel.cu b/paddle/phi/kernels/gpu/truncated_gaussian_random_kernel.cu index 12c1bf791e1691bb6eee81750b337adea713b794..f27b32ca7b8319440b62f0d03d21129133c8470c 100644 --- a/paddle/phi/kernels/gpu/truncated_gaussian_random_kernel.cu +++ b/paddle/phi/kernels/gpu/truncated_gaussian_random_kernel.cu @@ -25,7 +25,6 @@ #include "paddle/phi/core/kernel_registry.h" #include "paddle/fluid/framework/generator.h" -// #include "paddle/phi/core/generator.h" namespace phi { @@ -87,7 +86,7 @@ struct TruncatedNormalOffset { template void TruncatedGaussianRandomKernel(const Context& dev_ctx, - const ScalarArray& shape, + const std::vector& shape, float mean, float std, int seed, diff --git a/paddle/phi/kernels/truncated_gaussian_random_kernel.h b/paddle/phi/kernels/truncated_gaussian_random_kernel.h index 0370cc431fef9cab69861b7f707f65c897e20fa2..f8547ced41934a9810dc6874c090ab5aefd43497 100644 --- a/paddle/phi/kernels/truncated_gaussian_random_kernel.h +++ b/paddle/phi/kernels/truncated_gaussian_random_kernel.h @@ -20,6 +20,7 @@ #include "paddle/phi/common/scalar_array.h" #include "paddle/phi/core/dense_tensor.h" #include "paddle/phi/core/device_context.h" +#include "paddle/phi/infermeta/nullary.h" namespace phi { @@ -157,8 +158,8 @@ struct TruncatedNormal { }; template -void TruncatedGaussianRandomKernel(const Context& ctx, - const ScalarArray& shape, +void TruncatedGaussianRandomKernel(const Context& dev_ctx, + const std::vector& shape, float mean, float std, int seed,