提交 b9c86372 编写于 作者: Y Yu Yang

Fix compile

上级 fb6a48c6
......@@ -44,7 +44,7 @@ class ConcatKernel : public framework::OpKernel<T> {
};
template <typename Place, typename T>
class ConcatGradKernel : public framework::OpKernel {
class ConcatGradKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const {
auto* in = ctx.Input<framework::Tensor>(framework::GradVarName("Out"));
......
......@@ -87,7 +87,7 @@ struct MaxOrMinGradFunctor {
};
template <typename Place, typename T, typename Functor>
class ReduceKernel : public framework::OpKernel {
class ReduceKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
int rank = context.Input<Tensor>("X")->dims().size();
......@@ -141,7 +141,7 @@ class ReduceKernel : public framework::OpKernel {
};
template <typename Place, typename T, typename Functor>
class ReduceGradKernel : public framework::OpKernel {
class ReduceGradKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
int rank = context.Input<Tensor>("X")->dims().size();
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册