diff --git a/paddle/operators/mean_op.cc b/paddle/operators/mean_op.cc index fc486a743554a11078a646d6af9c3ede642cae70..fe34d6ad4015620cac520146850e10563d4c50e0 100644 --- a/paddle/operators/mean_op.cc +++ b/paddle/operators/mean_op.cc @@ -19,13 +19,12 @@ namespace operators { class MeanOp : public OperatorWithKernel { protected: - void InferShape(const std::vector &inputs, - const std::vector &outputs) const override { - PADDLE_ENFORCE(inputs.size() == 1, "Input size of AddOp must be one"); - PADDLE_ENFORCE(outputs.size() == 1, "Output size of AddOp must be one"); - PADDLE_ENFORCE(inputs[0] != nullptr && outputs[0] != nullptr, + void InferShape(const InferShapeContext &ctx) const override { + PADDLE_ENFORCE(ctx.InputSize() == 1, "Input size of AddOp must be one"); + PADDLE_ENFORCE(ctx.OutputSize() == 1, "Output size of AddOp must be one"); + PADDLE_ENFORCE(ctx.InputVar(0) != nullptr && ctx.OutputVar(0) != nullptr, "Input/Output of MeanOp must be initialized."); - outputs[0]->Resize(framework::make_ddim({1})); + ctx.Output(0)->Resize(framework::make_ddim({1})); } }; diff --git a/paddle/operators/mean_op.h b/paddle/operators/mean_op.h index 483b3eb60157eaff9b4c3001461a1f2e3e36f846..5f7d443751d1cdd7de3b67b0de2758ba1d566fb3 100644 --- a/paddle/operators/mean_op.h +++ b/paddle/operators/mean_op.h @@ -21,14 +21,14 @@ namespace operators { template class MeanKernel : public OpKernel { public: - void Compute(const KernelContext& context) const override { - auto input = context.Input(0)->Get(); - auto output = context.Output(0)->GetMutable(); + void Compute(const ExecutionContext& context) const override { + auto input = context.Input(0); + auto output = context.Output(0); output->mutable_data(context.GetPlace()); EigenScalar::From(*output).device(*(context.GetEigenDevice())) = - EigenVector::Flatten(input).mean(); + EigenVector::Flatten(*input).mean(); } };