提交 6e0661cf 编写于 作者: L liaogang

Change ContextKernel to ExecutionKernel

上级 426f0999
......@@ -19,13 +19,12 @@ namespace operators {
class MeanOp : public OperatorWithKernel {
protected:
void InferShape(const std::vector<const Tensor *> &inputs,
const std::vector<Tensor *> &outputs) const override {
PADDLE_ENFORCE(inputs.size() == 1, "Input size of AddOp must be one");
PADDLE_ENFORCE(outputs.size() == 1, "Output size of AddOp must be one");
PADDLE_ENFORCE(inputs[0] != nullptr && outputs[0] != nullptr,
void InferShape(const InferShapeContext &ctx) const override {
PADDLE_ENFORCE(ctx.InputSize() == 1, "Input size of AddOp must be one");
PADDLE_ENFORCE(ctx.OutputSize() == 1, "Output size of AddOp must be one");
PADDLE_ENFORCE(ctx.InputVar(0) != nullptr && ctx.OutputVar(0) != nullptr,
"Input/Output of MeanOp must be initialized.");
outputs[0]->Resize(framework::make_ddim({1}));
ctx.Output<Tensor>(0)->Resize(framework::make_ddim({1}));
}
};
......
......@@ -21,14 +21,14 @@ namespace operators {
template <typename Place, typename T>
class MeanKernel : public OpKernel {
public:
void Compute(const KernelContext& context) const override {
auto input = context.Input(0)->Get<Tensor>();
auto output = context.Output(0)->GetMutable<Tensor>();
void Compute(const ExecutionContext& context) const override {
auto input = context.Input<Tensor>(0);
auto output = context.Output<Tensor>(0);
output->mutable_data<T>(context.GetPlace());
EigenScalar<T>::From(*output).device(*(context.GetEigenDevice<Place>())) =
EigenVector<T>::Flatten(input).mean();
EigenVector<T>::Flatten(*input).mean();
}
};
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册