diff --git a/paddle/operators/gaussian_random_op.cc b/paddle/operators/gaussian_random_op.cc index f5fd902c5f72c18955e19539a9de00af11612829..d7ced6b526b51ea453dfe630894aa4675ef7087a 100644 --- a/paddle/operators/gaussian_random_op.cc +++ b/paddle/operators/gaussian_random_op.cc @@ -23,7 +23,7 @@ template class GaussianRandomOpKernel : public framework::OpKernel { public: - void Compute(const framework::KernelContext& context) const override { + void Compute(const framework::ExecutionContext& context) const override { auto mean = context.op_.GetAttr("mean"); auto std = context.op_.GetAttr("std"); auto* output = context.Output(0)->GetMutable(); @@ -41,15 +41,14 @@ class GaussianRandomOpKernel class GaussianRandomOp : public framework::OperatorWithKernel { protected: - void InferShape( - const std::vector& inputs, - const std::vector& outputs) const override { + void InferShape(const framework::InferShapeContext& ctx) const override { PADDLE_ENFORCE(inputs.size() == 0, "Input size of RandomOp must be zero."); PADDLE_ENFORCE(outputs.size() == 1, "Output size of RandomOp must be one."); PADDLE_ENFORCE(outputs[0] != nullptr, "Outputs of RandomOp must all be set."); - outputs[0]->Resize( - framework::make_ddim(this->GetAttr>("shape"))); + auto* tensor = ctx.Output(0); + auto dims = GetAttr(std::vector("shape")); + tensor->Resize(framework::make_ddim(dims)); } };