From 6e0661cf9cc2a198f34273caab02355380c771c3 Mon Sep 17 00:00:00 2001 From: liaogang Date: Tue, 1 Aug 2017 23:20:56 +0800 Subject: [PATCH] Change ContextKernel to ExecutionKernel --- paddle/operators/mean_op.cc | 11 +++++------ paddle/operators/mean_op.h | 8 ++++---- 2 files changed, 9 insertions(+), 10 deletions(-) diff --git a/paddle/operators/mean_op.cc b/paddle/operators/mean_op.cc index fc486a74355..fe34d6ad401 100644 --- a/paddle/operators/mean_op.cc +++ b/paddle/operators/mean_op.cc @@ -19,13 +19,12 @@ namespace operators { class MeanOp : public OperatorWithKernel { protected: - void InferShape(const std::vector &inputs, - const std::vector &outputs) const override { - PADDLE_ENFORCE(inputs.size() == 1, "Input size of AddOp must be one"); - PADDLE_ENFORCE(outputs.size() == 1, "Output size of AddOp must be one"); - PADDLE_ENFORCE(inputs[0] != nullptr && outputs[0] != nullptr, + void InferShape(const InferShapeContext &ctx) const override { + PADDLE_ENFORCE(ctx.InputSize() == 1, "Input size of AddOp must be one"); + PADDLE_ENFORCE(ctx.OutputSize() == 1, "Output size of AddOp must be one"); + PADDLE_ENFORCE(ctx.InputVar(0) != nullptr && ctx.OutputVar(0) != nullptr, "Input/Output of MeanOp must be initialized."); - outputs[0]->Resize(framework::make_ddim({1})); + ctx.Output(0)->Resize(framework::make_ddim({1})); } }; diff --git a/paddle/operators/mean_op.h b/paddle/operators/mean_op.h index 483b3eb6015..5f7d443751d 100644 --- a/paddle/operators/mean_op.h +++ b/paddle/operators/mean_op.h @@ -21,14 +21,14 @@ namespace operators { template class MeanKernel : public OpKernel { public: - void Compute(const KernelContext& context) const override { - auto input = context.Input(0)->Get(); - auto output = context.Output(0)->GetMutable(); + void Compute(const ExecutionContext& context) const override { + auto input = context.Input(0); + auto output = context.Output(0); output->mutable_data(context.GetPlace()); EigenScalar::From(*output).device(*(context.GetEigenDevice())) = - EigenVector::Flatten(input).mean(); + EigenVector::Flatten(*input).mean(); } }; -- GitLab