未验证 提交 81d4142b 编写于 作者: F furnace 提交者: GitHub

[Phi] move InferShape for truncated_gaussian_random and gaussian_random (#40191)

* [Phi] move InferShape for truncated_gaussian_random and gaussian_random

* [Phi] delete useless codes
上级 0c33c47e
...@@ -15,12 +15,14 @@ limitations under the License. */ ...@@ -15,12 +15,14 @@ limitations under the License. */
#include <random> #include <random>
#include "paddle/fluid/framework/generator.h" #include "paddle/fluid/framework/generator.h"
#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/op_version_registry.h" #include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/fluid/operators/fill_constant_op.h" #include "paddle/fluid/operators/fill_constant_op.h"
#ifdef PADDLE_WITH_MKLDNN #ifdef PADDLE_WITH_MKLDNN
#include "paddle/fluid/platform/mkldnn_helper.h" #include "paddle/fluid/platform/mkldnn_helper.h"
#endif #endif
#include "paddle/phi/infermeta/nullary.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
...@@ -54,38 +56,6 @@ class GaussianRandomOp : public framework::OperatorWithKernel { ...@@ -54,38 +56,6 @@ class GaussianRandomOp : public framework::OperatorWithKernel {
public: public:
using framework::OperatorWithKernel::OperatorWithKernel; using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
  // Infer the static shape of Out from the `shape` attribute, or mark it
  // fully dynamic when the shape is only known at runtime via ShapeTensor.
  OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "GaussianRandom");
  auto shape = ctx->Attrs().Get<std::vector<int64_t>>("shape");

  if (shape.empty() && ctx->HasInput("ShapeTensor")) {
    // The target shape lives in a runtime tensor; only its element count is
    // known at compile time, so every output dim is marked unknown (-1).
    auto shape_dims = ctx->GetInputDim("ShapeTensor");
    int64_t num_ele = 1;  // int64_t: dims are int64_t, int could overflow
    for (int i = 0; i < shape_dims.size(); ++i) {
      num_ele *= shape_dims[i];
    }
    ctx->SetOutputDim("Out",
                      phi::make_ddim(std::vector<int64_t>(num_ele, -1)));
    return;
  }

  if (!ctx->HasInput("ShapeTensor") && !ctx->HasInputs("ShapeTensorList")) {
    // No runtime shape input at all: the attribute is the only source of
    // the output shape and must therefore be non-empty.
    PADDLE_ENFORCE_GT(
        shape.size(), 0UL,
        platform::errors::InvalidArgument(
            "Attribute(shape) of GaussianRandomOp must be set "
            "and shape.size() > 0, but received shape.size() is %d",
            shape.size()));
  }
  // `shape` is already std::vector<int64_t>; no element-wise copy needed.
  ctx->SetOutputDim("Out", phi::make_ddim(shape));
}
protected: protected:
framework::OpKernelType GetExpectedKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override { const framework::ExecutionContext& ctx) const override {
...@@ -171,11 +141,20 @@ Used to initialize tensors with gaussian random generator. ...@@ -171,11 +141,20 @@ Used to initialize tensors with gaussian random generator.
} // namespace paddle } // namespace paddle
namespace ops = paddle::operators; namespace ops = paddle::operators;
REGISTER_OP_WITHOUT_GRADIENT(gaussian_random, ops::GaussianRandomOp,
ops::GaussianRandomOpMaker); DECLARE_INFER_SHAPE_FUNCTOR(gaussian_random, GaussianRandomInferShapeFunctor,
PD_INFER_META(phi::GaussianRandomInferMeta));
REGISTER_OPERATOR(
gaussian_random, ops::GaussianRandomOp, ops::GaussianRandomOpMaker,
paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>,
GaussianRandomInferShapeFunctor);
REGISTER_OP_CPU_KERNEL(gaussian_random_batch_size_like, REGISTER_OP_CPU_KERNEL(gaussian_random_batch_size_like,
ops::CPUGaussianRandomBatchSizeLikeKernel<float>, ops::CPUGaussianRandomBatchSizeLikeKernel<float>,
ops::CPUGaussianRandomBatchSizeLikeKernel<double>); ops::CPUGaussianRandomBatchSizeLikeKernel<double>);
REGISTER_OP_VERSION(gaussian_random) REGISTER_OP_VERSION(gaussian_random)
.AddCheckpoint( .AddCheckpoint(
R"ROC( R"ROC(
......
...@@ -17,8 +17,10 @@ limitations under the License. */ ...@@ -17,8 +17,10 @@ limitations under the License. */
#include <vector> #include <vector>
#include "paddle/fluid/framework/generator.h" #include "paddle/fluid/framework/generator.h"
#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/truncated_gaussian_random_op.h" #include "paddle/fluid/operators/truncated_gaussian_random_op.h"
#include "paddle/phi/infermeta/nullary.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
...@@ -27,26 +29,6 @@ class TruncatedGaussianRandomOp : public framework::OperatorWithKernel { ...@@ -27,26 +29,6 @@ class TruncatedGaussianRandomOp : public framework::OperatorWithKernel {
public: public:
using framework::OperatorWithKernel::OperatorWithKernel; using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
  // Out takes its static shape directly from the `shape` attribute.
  PADDLE_ENFORCE_EQ(
      ctx->HasOutput("Out"), true,
      platform::errors::NotFound(
          "Output(Out) of TruncatedGaussianRandomOp should not be null."));
  auto shape = ctx->Attrs().Get<std::vector<int>>("shape");
  // An empty shape attribute is a configuration error for this op.
  PADDLE_ENFORCE_GT(
      shape.size(), 0UL,
      platform::errors::InvalidArgument(
          "the input shape of TruncatedGaussianRandomOp must be set, "
          "But the rank of shape we received is %d",
          shape.size()));
  // Widen the int dims to int64_t, the element type that DDim stores.
  std::vector<int64_t> dims(shape.begin(), shape.end());
  ctx->SetOutputDim("Out", phi::make_ddim(dims));
}
protected: protected:
framework::OpKernelType GetExpectedKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override { const framework::ExecutionContext& ctx) const override {
...@@ -99,6 +81,14 @@ Used to initialize tensors with truncated gaussian random generator. ...@@ -99,6 +81,14 @@ Used to initialize tensors with truncated gaussian random generator.
} // namespace paddle } // namespace paddle
namespace ops = paddle::operators; namespace ops = paddle::operators;
REGISTER_OP_WITHOUT_GRADIENT(truncated_gaussian_random,
ops::TruncatedGaussianRandomOp, DECLARE_INFER_SHAPE_FUNCTOR(
ops::TruncatedGaussianRandomOpMaker); truncated_gaussian_random, TruncatedGaussianRandomInferShapeFunctor,
PD_INFER_META(phi::TruncatedGaussianRandomInferMeta));
REGISTER_OPERATOR(
truncated_gaussian_random, ops::TruncatedGaussianRandomOp,
ops::TruncatedGaussianRandomOpMaker,
paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>,
TruncatedGaussianRandomInferShapeFunctor);
...@@ -40,4 +40,29 @@ void EyeInferMeta(int64_t num_rows, ...@@ -40,4 +40,29 @@ void EyeInferMeta(int64_t num_rows,
out->set_dims({num_rows, num_columns}); out->set_dims({num_rows, num_columns});
out->set_dtype(dtype); out->set_dtype(dtype);
} }
void TruncatedGaussianRandomInferMeta(const std::vector<int>& shape,
                                      float mean,
                                      float std,
                                      int seed,
                                      DataType dtype,
                                      MetaTensor* out) {
  // Only `shape` and `dtype` determine the output metadata; the
  // distribution parameters (mean/std/seed) affect values, not meta.
  out->set_dims(phi::make_ddim(shape));
  out->set_dtype(dtype);
  out->set_layout(DataLayout::NCHW);
}
void GaussianRandomInferMeta(const ScalarArray& shape,
                             float mean,
                             float std,
                             int seed,
                             DataType dtype,
                             MetaTensor* out) {
  // The ScalarArray carries the concrete dims; mean/std/seed shape the
  // random values at runtime and do not influence the metadata.
  const auto& dims = shape.GetData();
  out->set_dims(phi::make_ddim(dims));
  out->set_dtype(dtype);
  out->set_layout(DataLayout::NCHW);
}
} // namespace phi } // namespace phi
...@@ -40,4 +40,18 @@ void EyeInferMeta(int64_t num_rows, ...@@ -40,4 +40,18 @@ void EyeInferMeta(int64_t num_rows,
DataType dtype, DataType dtype,
MetaTensor* out); MetaTensor* out);
void TruncatedGaussianRandomInferMeta(const std::vector<int>& shape,
float mean,
float std,
int seed,
DataType dtype,
MetaTensor* out);
void GaussianRandomInferMeta(const ScalarArray& shape,
float mean,
float std,
int seed,
DataType dtype,
MetaTensor* out);
} // namespace phi } // namespace phi
...@@ -27,7 +27,7 @@ namespace phi { ...@@ -27,7 +27,7 @@ namespace phi {
template <typename T, typename Context> template <typename T, typename Context>
void TruncatedGaussianRandomKernel(const Context& dev_ctx, void TruncatedGaussianRandomKernel(const Context& dev_ctx,
const ScalarArray& shape, const std::vector<int>& shape,
float mean, float mean,
float std, float std,
int seed, int seed,
......
...@@ -25,7 +25,6 @@ ...@@ -25,7 +25,6 @@
#include "paddle/phi/core/kernel_registry.h" #include "paddle/phi/core/kernel_registry.h"
#include "paddle/fluid/framework/generator.h" #include "paddle/fluid/framework/generator.h"
// #include "paddle/phi/core/generator.h"
namespace phi { namespace phi {
...@@ -87,7 +86,7 @@ struct TruncatedNormalOffset { ...@@ -87,7 +86,7 @@ struct TruncatedNormalOffset {
template <typename T, typename Context> template <typename T, typename Context>
void TruncatedGaussianRandomKernel(const Context& dev_ctx, void TruncatedGaussianRandomKernel(const Context& dev_ctx,
const ScalarArray& shape, const std::vector<int>& shape,
float mean, float mean,
float std, float std,
int seed, int seed,
......
...@@ -20,6 +20,7 @@ ...@@ -20,6 +20,7 @@
#include "paddle/phi/common/scalar_array.h" #include "paddle/phi/common/scalar_array.h"
#include "paddle/phi/core/dense_tensor.h" #include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/device_context.h" #include "paddle/phi/core/device_context.h"
#include "paddle/phi/infermeta/nullary.h"
namespace phi { namespace phi {
...@@ -157,8 +158,8 @@ struct TruncatedNormal { ...@@ -157,8 +158,8 @@ struct TruncatedNormal {
}; };
template <typename T, typename Context> template <typename T, typename Context>
void TruncatedGaussianRandomKernel(const Context& ctx, void TruncatedGaussianRandomKernel(const Context& dev_ctx,
const ScalarArray& shape, const std::vector<int>& shape,
float mean, float mean,
float std, float std,
int seed, int seed,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册