add_op.h 699 字节
Newer Older
Y
Yu Yang 已提交
1
#pragma once
Q
qijun 已提交
2 3 4
#include "glog/logging.h"
#include "paddle/framework/operator.h"
//#include "paddle/operators/add_op_functor.h"
Y
Yu Yang 已提交
5 6

namespace paddle {
Y
Yu Yang 已提交
7
namespace operators {
Y
Yu Yang 已提交
8

Q
qijun 已提交
9 10
// Place can be CPUPlace or GPUPlace
template <typename Place, typename DataType>
Y
Yu Yang 已提交
11 12
class AddKernel : public framework::OpKernel {
public:
Q
qijun 已提交
13 14 15 16 17 18 19 20 21
  void Compute(const KernelContext& context) const override {
    auto* input0 = context.Input(0);
    auto* input1 = context.Input(1);

    auto* output = context.Output(0);
    output->mutable_data<DataType>(Place());

    output->flat<T>().device(*(context.get_eigen_device<Place>())) =
        input0->flat<T>() + input1->flat<T>();
Y
Yu Yang 已提交
22 23 24
  }
};

Q
qijun 已提交
25
}  // namespace operators
Y
Yu Yang 已提交
26
}  // namespace paddle