未验证 提交 81d4142b 编写于 作者: F furnace 提交者: GitHub

[Phi] move InferShape for truncated_gaussian_random and gaussian_random (#40191)

* [Phi] move InferShape for truncated_gaussian_random and gaussian_random

* [Phi] delete useless codes
上级 0c33c47e
......@@ -15,12 +15,14 @@ limitations under the License. */
#include <random>
#include "paddle/fluid/framework/generator.h"
#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/fluid/operators/fill_constant_op.h"
#ifdef PADDLE_WITH_MKLDNN
#include "paddle/fluid/platform/mkldnn_helper.h"
#endif
#include "paddle/phi/infermeta/nullary.h"
namespace paddle {
namespace operators {
......@@ -54,38 +56,6 @@ class GaussianRandomOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "GaussianRandom");
auto shape = ctx->Attrs().Get<std::vector<int64_t>>("shape");
std::vector<int64_t> temp;
temp.reserve(shape.size());
for (auto dim : shape) {
temp.push_back(static_cast<int64_t>(dim));
}
if (shape.empty() && ctx->HasInput("ShapeTensor")) {
auto shape_dims = ctx->GetInputDim("ShapeTensor");
int num_ele = 1;
for (int i = 0; i < shape_dims.size(); ++i) {
num_ele *= shape_dims[i];
}
auto vec_dims = std::vector<int>(num_ele, -1);
ctx->SetOutputDim("Out", phi::make_ddim(vec_dims));
return;
}
if (!ctx->HasInput("ShapeTensor") && !ctx->HasInputs("ShapeTensorList")) {
PADDLE_ENFORCE_GT(
shape.size(), 0UL,
platform::errors::InvalidArgument(
"Attribute(shape) of GaussianRandomOp must be set "
"and shape.size() > 0, but reveived shape.size() is %d",
shape.size()));
}
ctx->SetOutputDim("Out", phi::make_ddim(temp));
}
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
......@@ -171,11 +141,20 @@ Used to initialize tensors with gaussian random generator.
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP_WITHOUT_GRADIENT(gaussian_random, ops::GaussianRandomOp,
ops::GaussianRandomOpMaker);
DECLARE_INFER_SHAPE_FUNCTOR(gaussian_random, GaussianRandomInferShapeFunctor,
PD_INFER_META(phi::GaussianRandomInferMeta));
REGISTER_OPERATOR(
gaussian_random, ops::GaussianRandomOp, ops::GaussianRandomOpMaker,
paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>,
GaussianRandomInferShapeFunctor);
REGISTER_OP_CPU_KERNEL(gaussian_random_batch_size_like,
ops::CPUGaussianRandomBatchSizeLikeKernel<float>,
ops::CPUGaussianRandomBatchSizeLikeKernel<double>);
REGISTER_OP_VERSION(gaussian_random)
.AddCheckpoint(
R"ROC(
......
......@@ -17,8 +17,10 @@ limitations under the License. */
#include <vector>
#include "paddle/fluid/framework/generator.h"
#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/truncated_gaussian_random_op.h"
#include "paddle/phi/infermeta/nullary.h"
namespace paddle {
namespace operators {
......@@ -27,26 +29,6 @@ class TruncatedGaussianRandomOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE_EQ(
ctx->HasOutput("Out"), true,
platform::errors::NotFound(
"Output(Out) of TruncatedGaussianRandomOp should not be null."));
auto shape = ctx->Attrs().Get<std::vector<int>>("shape");
std::vector<int64_t> out_dim;
out_dim.reserve(shape.size());
for (auto dim : shape) {
out_dim.push_back(static_cast<int64_t>(dim));
}
PADDLE_ENFORCE_GT(
shape.size(), 0UL,
platform::errors::InvalidArgument(
"the input shape of TruncatedGaussianRandomOp must be set, "
"But the rank of shape we received is %d",
shape.size()));
ctx->SetOutputDim("Out", phi::make_ddim(out_dim));
}
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
......@@ -99,6 +81,14 @@ Used to initialize tensors with truncated gaussian random generator.
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP_WITHOUT_GRADIENT(truncated_gaussian_random,
ops::TruncatedGaussianRandomOp,
ops::TruncatedGaussianRandomOpMaker);
DECLARE_INFER_SHAPE_FUNCTOR(
truncated_gaussian_random, TruncatedGaussianRandomInferShapeFunctor,
PD_INFER_META(phi::TruncatedGaussianRandomInferMeta));
REGISTER_OPERATOR(
truncated_gaussian_random, ops::TruncatedGaussianRandomOp,
ops::TruncatedGaussianRandomOpMaker,
paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>,
TruncatedGaussianRandomInferShapeFunctor);
......@@ -40,4 +40,29 @@ void EyeInferMeta(int64_t num_rows,
out->set_dims({num_rows, num_columns});
out->set_dtype(dtype);
}
void TruncatedGaussianRandomInferMeta(const std::vector<int>& shape,
float mean,
float std,
int seed,
DataType dtype,
MetaTensor* out) {
auto out_dims = phi::make_ddim(shape);
out->set_dims(out_dims);
out->set_dtype(dtype);
out->set_layout(DataLayout::NCHW);
}
void GaussianRandomInferMeta(const ScalarArray& shape,
float mean,
float std,
int seed,
DataType dtype,
MetaTensor* out) {
auto out_dims = phi::make_ddim(shape.GetData());
out->set_dims(out_dims);
out->set_dtype(dtype);
out->set_layout(DataLayout::NCHW);
}
} // namespace phi
......@@ -40,4 +40,18 @@ void EyeInferMeta(int64_t num_rows,
DataType dtype,
MetaTensor* out);
void TruncatedGaussianRandomInferMeta(const std::vector<int>& shape,
float mean,
float std,
int seed,
DataType dtype,
MetaTensor* out);
void GaussianRandomInferMeta(const ScalarArray& shape,
float mean,
float std,
int seed,
DataType dtype,
MetaTensor* out);
} // namespace phi
......@@ -27,7 +27,7 @@ namespace phi {
template <typename T, typename Context>
void TruncatedGaussianRandomKernel(const Context& dev_ctx,
const ScalarArray& shape,
const std::vector<int>& shape,
float mean,
float std,
int seed,
......
......@@ -25,7 +25,6 @@
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/fluid/framework/generator.h"
// #include "paddle/phi/core/generator.h"
namespace phi {
......@@ -87,7 +86,7 @@ struct TruncatedNormalOffset {
template <typename T, typename Context>
void TruncatedGaussianRandomKernel(const Context& dev_ctx,
const ScalarArray& shape,
const std::vector<int>& shape,
float mean,
float std,
int seed,
......
......@@ -20,6 +20,7 @@
#include "paddle/phi/common/scalar_array.h"
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/device_context.h"
#include "paddle/phi/infermeta/nullary.h"
namespace phi {
......@@ -157,8 +158,8 @@ struct TruncatedNormal {
};
template <typename T, typename Context>
void TruncatedGaussianRandomKernel(const Context& ctx,
const ScalarArray& shape,
void TruncatedGaussianRandomKernel(const Context& dev_ctx,
const std::vector<int>& shape,
float mean,
float std,
int seed,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册