From fcd6f64b98aafdb13d29395eaa3573f69632382a Mon Sep 17 00:00:00 2001 From: dongzhihong Date: Mon, 7 Aug 2017 17:28:38 +0800 Subject: [PATCH] "redefine random op" --- paddle/operators/gaussian_random_op.cc | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/paddle/operators/gaussian_random_op.cc b/paddle/operators/gaussian_random_op.cc index f5fd902c5f7..d7ced6b526b 100644 --- a/paddle/operators/gaussian_random_op.cc +++ b/paddle/operators/gaussian_random_op.cc @@ -23,7 +23,7 @@ template class GaussianRandomOpKernel : public framework::OpKernel { public: - void Compute(const framework::KernelContext& context) const override { + void Compute(const framework::ExecutionContext& context) const override { auto mean = context.op_.GetAttr("mean"); auto std = context.op_.GetAttr("std"); auto* output = context.Output(0)->GetMutable(); @@ -41,15 +41,14 @@ class GaussianRandomOpKernel class GaussianRandomOp : public framework::OperatorWithKernel { protected: - void InferShape( - const std::vector& inputs, - const std::vector& outputs) const override { + void InferShape(const framework::InferShapeContext& ctx) const override { PADDLE_ENFORCE(inputs.size() == 0, "Input size of RandomOp must be zero."); PADDLE_ENFORCE(outputs.size() == 1, "Output size of RandomOp must be one."); PADDLE_ENFORCE(outputs[0] != nullptr, "Outputs of RandomOp must all be set."); - outputs[0]->Resize( - framework::make_ddim(this->GetAttr>("shape"))); + auto* tensor = ctx.Output(0); + auto dims = GetAttr(std::vector("shape")); + tensor->Resize(framework::make_ddim(dims)); } }; -- GitLab