You need to sign in or sign up before continuing.
提交 6e0661cf 编写于 作者: L liaogang

Change ContextKernel to ExecutionKernel

上级 426f0999
...@@ -19,13 +19,12 @@ namespace operators { ...@@ -19,13 +19,12 @@ namespace operators {
class MeanOp : public OperatorWithKernel { class MeanOp : public OperatorWithKernel {
protected: protected:
void InferShape(const std::vector<const Tensor *> &inputs, void InferShape(const InferShapeContext &ctx) const override {
const std::vector<Tensor *> &outputs) const override { PADDLE_ENFORCE(ctx.InputSize() == 1, "Input size of AddOp must be one");
PADDLE_ENFORCE(inputs.size() == 1, "Input size of AddOp must be one"); PADDLE_ENFORCE(ctx.OutputSize() == 1, "Output size of AddOp must be one");
PADDLE_ENFORCE(outputs.size() == 1, "Output size of AddOp must be one"); PADDLE_ENFORCE(ctx.InputVar(0) != nullptr && ctx.OutputVar(0) != nullptr,
PADDLE_ENFORCE(inputs[0] != nullptr && outputs[0] != nullptr,
"Input/Output of MeanOp must be initialized."); "Input/Output of MeanOp must be initialized.");
outputs[0]->Resize(framework::make_ddim({1})); ctx.Output<Tensor>(0)->Resize(framework::make_ddim({1}));
} }
}; };
......
...@@ -21,14 +21,14 @@ namespace operators { ...@@ -21,14 +21,14 @@ namespace operators {
template <typename Place, typename T> template <typename Place, typename T>
class MeanKernel : public OpKernel { class MeanKernel : public OpKernel {
public: public:
void Compute(const KernelContext& context) const override { void Compute(const ExecutionContext& context) const override {
auto input = context.Input(0)->Get<Tensor>(); auto input = context.Input<Tensor>(0);
auto output = context.Output(0)->GetMutable<Tensor>(); auto output = context.Output<Tensor>(0);
output->mutable_data<T>(context.GetPlace()); output->mutable_data<T>(context.GetPlace());
EigenScalar<T>::From(*output).device(*(context.GetEigenDevice<Place>())) = EigenScalar<T>::From(*output).device(*(context.GetEigenDevice<Place>())) =
EigenVector<T>::Flatten(input).mean(); EigenVector<T>::Flatten(*input).mean();
} }
}; };
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册