add_op.h 729 字节
Newer Older
Y
Yu Yang 已提交
1
#pragma once
Q
qijun 已提交
2 3 4
#include "glog/logging.h"
#include "paddle/framework/operator.h"
//#include "paddle/operators/add_op_functor.h"
Y
Yu Yang 已提交
5 6

namespace paddle {
Y
Yu Yang 已提交
7
namespace operators {
Y
Yu Yang 已提交
8

Q
qijun 已提交
9
template <typename Place, typename T>
Y
Yu Yang 已提交
10 11
class AddKernel : public framework::OpKernel {
public:
Q
qijun 已提交
12
  void Compute(const KernelContext& context) const override {
Q
qijun 已提交
13 14 15
    auto input0 = context.Input(0)->Get<framework::Tensor>();
    auto input1 = context.Input(1)->Get<framework::Tensor>();
    auto* output = context.Output(0)->GetMutable<framework::Tensor>();
Q
qijun 已提交
16

Q
qijun 已提交
17
    output->mutable_data<T>(Place());
Q
qijun 已提交
18 19

    output->flat<T>().device(*(context.get_eigen_device<Place>())) =
Q
qijun 已提交
20
        input0.flat<T>() + input1.flat<T>();
Y
Yu Yang 已提交
21 22 23
  }
};

Q
qijun 已提交
24
}  // namespace operators
Y
Yu Yang 已提交
25
}  // namespace paddle