未验证 提交 b0e7439f 编写于 作者: L Leo Chen 提交者: GitHub

rename inplace/no_need_buffer inferer, part2, test=develop (#24733)

上级 a6fbba65
......@@ -822,10 +822,10 @@ class SquareDoubleGradMaker : public ::paddle::framework::SingleGradOpMaker<T> {
}
};
DECLARE_INPLACE_OP_INFERER(ActivationGradOpInplaceInference,
DECLARE_INPLACE_OP_INFERER(ActivationGradOpInplaceInferer,
{framework::GradVarName("Out"),
framework::GradVarName("X")});
DECLARE_INPLACE_OP_INFERER(ActivationDoubleGradOpInplaceInference,
DECLARE_INPLACE_OP_INFERER(ActivationDoubleGradOpInplaceInferer,
{"DDX", "DDOut"});
template <typename T>
......@@ -913,7 +913,7 @@ namespace plat = paddle::platform;
std::conditional<ops::CanInplaceAct<ops::grad_functor<float>>(), \
ops::ActFwdInplaceInferer, void>::type); \
REGISTER_OPERATOR(KERNEL_TYPE##_grad, ops::ActivationOpGrad, \
ops::ActivationGradOpInplaceInference);
ops::ActivationGradOpInplaceInferer);
#define REGISTER_ACTIVATION_CPU_KERNEL(act_type, op_name, functor, \
grad_functor) \
......@@ -941,13 +941,13 @@ REGISTER_OPERATOR(
paddle::imperative::OpBase>,
ops::ActFwdInplaceInferer);
REGISTER_OPERATOR(relu_grad, ops::ActivationOpGrad,
ops::ActivationGradOpInplaceInference,
ops::ActivationGradOpInplaceInferer,
ops::ReluDoubleGradMaker<paddle::framework::OpDesc>,
ops::ReluDoubleGradMaker<paddle::imperative::OpBase>);
REGISTER_OPERATOR(
relu_grad_grad,
ops::ActivationOpDoubleGrad2<ops::ReluGradFunctor<float>::FwdDeps()>,
ops::ActivationDoubleGradOpInplaceInference);
ops::ActivationDoubleGradOpInplaceInferer);
REGISTER_ACTIVATION_CPU_KERNEL(relu, Relu, ReluFunctor, ReluGradFunctor);
......@@ -971,13 +971,13 @@ REGISTER_OPERATOR(
paddle::imperative::OpBase>,
ops::ActFwdInplaceInferer);
REGISTER_OPERATOR(leaky_relu_grad, ops::ActivationOpGrad,
ops::ActivationGradOpInplaceInference,
ops::ActivationGradOpInplaceInferer,
ops::LeakyReluDoubleGradMaker<paddle::framework::OpDesc>,
ops::LeakyReluDoubleGradMaker<paddle::imperative::OpBase>);
REGISTER_OPERATOR(
leaky_relu_grad_grad,
ops::ActivationOpDoubleGrad2<ops::LeakyReluGradFunctor<float>::FwdDeps()>,
ops::ActivationDoubleGradOpInplaceInference);
ops::ActivationDoubleGradOpInplaceInferer);
REGISTER_ACTIVATION_CPU_KERNEL(leaky_relu, LeakyRelu, LeakyReluFunctor,
LeakyReluGradFunctor);
......@@ -1000,13 +1000,13 @@ REGISTER_OPERATOR(
paddle::imperative::OpBase>,
ops::ActFwdInplaceInferer);
REGISTER_OPERATOR(elu_grad, ops::ActivationOpGrad,
ops::ActivationGradOpInplaceInference,
ops::ActivationGradOpInplaceInferer,
ops::ELUDoubleGradMaker<paddle::framework::OpDesc>,
ops::ELUDoubleGradMaker<paddle::imperative::OpBase>);
REGISTER_OPERATOR(
elu_grad_grad,
ops::ActivationOpDoubleGrad<ops::ELUGradFunctor<float>::FwdDeps()>,
ops::ActivationDoubleGradOpInplaceInference);
ops::ActivationDoubleGradOpInplaceInferer);
REGISTER_ACTIVATION_CPU_KERNEL(elu, ELU, ELUFunctor, ELUGradFunctor);
REGISTER_OP_CPU_KERNEL(
......@@ -1028,13 +1028,13 @@ REGISTER_OPERATOR(
paddle::imperative::OpBase>,
ops::ActFwdInplaceInferer);
REGISTER_OPERATOR(sqrt_grad, ops::ActivationOpGrad,
ops::ActivationGradOpInplaceInference,
ops::ActivationGradOpInplaceInferer,
ops::SqrtDoubleGradMaker<paddle::framework::OpDesc>,
ops::SqrtDoubleGradMaker<paddle::imperative::OpBase>);
REGISTER_OPERATOR(
sqrt_grad_grad,
ops::ActivationOpDoubleGrad<ops::SqrtGradGradFunctor<float>::FwdDeps()>,
ops::ActivationDoubleGradOpInplaceInference);
ops::ActivationDoubleGradOpInplaceInferer);
REGISTER_ACTIVATION_CPU_KERNEL(sqrt, Sqrt, SqrtFunctor, SqrtGradFunctor);
REGISTER_OP_CPU_KERNEL(
......@@ -1056,13 +1056,13 @@ REGISTER_OPERATOR(
paddle::imperative::OpBase>,
ops::ActFwdInplaceInferer);
REGISTER_OPERATOR(square_grad, ops::ActivationOpGrad,
ops::ActivationGradOpInplaceInference,
ops::ActivationGradOpInplaceInferer,
ops::SquareDoubleGradMaker<paddle::framework::OpDesc>,
ops::SquareDoubleGradMaker<paddle::imperative::OpBase>);
REGISTER_OPERATOR(
square_grad_grad,
ops::ActivationOpDoubleGrad<ops::SquareGradGradFunctor<float>::FwdDeps()>,
ops::ActivationDoubleGradOpInplaceInference);
ops::ActivationDoubleGradOpInplaceInferer);
REGISTER_OP_CPU_KERNEL(square,
ops::ActivationKernel<paddle::platform::CPUDeviceContext,
......@@ -1106,7 +1106,7 @@ REGISTER_OPERATOR(
std::conditional<ops::CanInplaceAct<ops::PowGradFunctor<float>>(),
ops::ActFwdInplaceInferer, void>::type);
REGISTER_OPERATOR(pow_grad, ops::PowOpGrad,
ops::ActivationGradOpInplaceInference);
ops::ActivationGradOpInplaceInferer);
REGISTER_OP_CPU_KERNEL(
pow, ops::PowKernel<plat::CPUDeviceContext, ops::PowFunctor<float>>,
......@@ -1131,7 +1131,7 @@ REGISTER_OPERATOR(
std::conditional<ops::CanInplaceAct<ops::ExpGradFunctor<float>>(),
ops::ActFwdInplaceInferer, void>::type);
REGISTER_OPERATOR(exp_grad, ops::ActivationOpGrad,
ops::ActivationGradOpInplaceInference);
ops::ActivationGradOpInplaceInferer);
REGISTER_OP_CPU_KERNEL(exp,
ops::ActivationKernel<paddle::platform::CPUDeviceContext,
......@@ -1163,7 +1163,7 @@ REGISTER_OPERATOR(
std::conditional<ops::CanInplaceAct<ops::AbsGradFunctor<float>>(),
ops::ActFwdInplaceInferer, void>::type);
REGISTER_OPERATOR(abs_grad, ops::ActivationOpGrad,
ops::ActivationGradOpInplaceInference);
ops::ActivationGradOpInplaceInferer);
REGISTER_OP_CPU_KERNEL(abs,
ops::ActivationKernel<paddle::platform::CPUDeviceContext,
......
......@@ -116,7 +116,7 @@ class ArgsortGradOpMaker : public framework::SingleGradOpMaker<T> {
}
};
DECLARE_NO_NEED_BUFFER_VARS_INFERER(ArgsortGradNoNeedBufferVarInference, "X");
DECLARE_NO_NEED_BUFFER_VARS_INFERER(ArgsortGradNoNeedBufferVarsInferer, "X");
} // namespace operators
} // namespace paddle
......@@ -126,7 +126,7 @@ REGISTER_OPERATOR(argsort, ops::ArgsortOp, ops::ArgsortOpMaker,
ops::ArgsortGradOpMaker<paddle::framework::OpDesc>,
ops::ArgsortGradOpMaker<paddle::imperative::OpBase>);
REGISTER_OPERATOR(argsort_grad, ops::ArgsortGradOp,
ops::ArgsortGradNoNeedBufferVarInference);
ops::ArgsortGradNoNeedBufferVarsInferer);
REGISTER_OP_CPU_KERNEL(argsort,
ops::ArgsortKernel<paddle::platform::CPUPlace, float>,
ops::ArgsortKernel<paddle::platform::CPUPlace, double>,
......
......@@ -136,7 +136,7 @@ class BatchFCGradOpMaker : public framework::SingleGradOpMaker<T> {
op->SetAttrMap(this->Attrs());
}
};
DECLARE_NO_NEED_BUFFER_VARS_INFERER(BatchFCGradOpNoNeedBufferVarsInference,
DECLARE_NO_NEED_BUFFER_VARS_INFERER(BatchFCGradOpNoNeedBufferVarsInferer,
"Bias");
} // namespace operators
......@@ -148,7 +148,7 @@ REGISTER_OPERATOR(batch_fc, ops::BatchFCOp, ops::BatchFCOpMaker,
ops::BatchFCGradOpMaker<paddle::imperative::OpBase>);
REGISTER_OPERATOR(batch_fc_grad, ops::BatchFCGradOp,
ops::BatchFCGradOpNoNeedBufferVarsInference);
ops::BatchFCGradOpNoNeedBufferVarsInferer);
REGISTER_OP_CPU_KERNEL(
batch_fc, ops::BatchFCKernel<paddle::platform::CPUDeviceContext, float>,
......
......@@ -74,7 +74,7 @@ class BatchSizeLikeOpMaker : public framework::OpProtoAndCheckerMaker {
virtual void Apply() = 0;
};
DECLARE_NO_NEED_BUFFER_VARS_INFERER(BatchSizeLikeNoNeedBufferVarsInference,
DECLARE_NO_NEED_BUFFER_VARS_INFERER(BatchSizeLikeNoNeedBufferVarsInferer,
"Input");
} // namespace operators
......
......@@ -97,15 +97,15 @@ REGISTER_ELEMWISE_EXPLICIT_OP_WITHOUT_GRAD(elementwise_add, Add);
namespace ops = paddle::operators;
REGISTER_OPERATOR(
elementwise_add_grad, ops::ElementwiseOpGrad, ops::ElementwiseGradOpInplace,
ops::ElementwiseGradNoBufVarsInference,
elementwise_add_grad, ops::ElementwiseOpGrad,
ops::ElementwiseGradOpInplaceInferer, ops::ElementwiseGradNoBufVarsInferer,
ops::ElementwiseAddDoubleGradMaker<paddle::framework::OpDesc>,
ops::ElementwiseAddDoubleGradMaker<paddle::imperative::OpBase>);
REGISTER_OPERATOR(elementwise_add_grad_grad,
ops::ElementwiseOpDoubleGradWithoutDXDY,
ops::ElementwiseDoubleGradOpInplace,
ops::ElementwiseDoubleGradNoBufVarsInference);
ops::ElementwiseDoubleGradOpInplaceInferer,
ops::ElementwiseDoubleGradNoBufVarsInferer);
REGISTER_OP_CPU_KERNEL(
elementwise_add,
......
......@@ -123,7 +123,7 @@ REGISTER_OPERATOR(
ops::ElementwiseDivDoubleGradMaker<paddle::imperative::OpBase>);
REGISTER_OPERATOR(elementwise_div_grad_grad, ops::ElementwiseDivOpDoubleGrad,
ops::ElementwiseDoubleGradOpInplace);
ops::ElementwiseDoubleGradOpInplaceInferer);
REGISTER_OP_CPU_KERNEL(
elementwise_div,
......
......@@ -123,7 +123,7 @@ REGISTER_OPERATOR(
ops::ElementwiseMulDoubleGradMaker<paddle::imperative::OpBase>);
REGISTER_OPERATOR(elementwise_mul_grad_grad, ops::ElementwiseOpDoubleGrad,
ops::ElementwiseDoubleGradOpInplace);
ops::ElementwiseDoubleGradOpInplaceInferer);
REGISTER_OP_CPU_KERNEL(
elementwise_mul,
......
......@@ -348,16 +348,16 @@ class ElemwiseGradKernel : public framework::OpKernel<T> {
}
};
DECLARE_INPLACE_OP_INFERER(ElementwiseOpInplace, {"X", "Out"});
DECLARE_INPLACE_OP_INFERER(ElementwiseGradOpInplace,
DECLARE_INPLACE_OP_INFERER(ElementwiseOpInplaceInferer, {"X", "Out"});
DECLARE_INPLACE_OP_INFERER(ElementwiseGradOpInplaceInferer,
{framework::GradVarName("Out"),
framework::GradVarName("X")});
DECLARE_INPLACE_OP_INFERER(ElementwiseDoubleGradOpInplace, {"DDX", "DDOut"});
DECLARE_INPLACE_OP_INFERER(ElementwiseDoubleGradOpInplaceInferer,
{"DDX", "DDOut"});
DECLARE_NO_NEED_BUFFER_VARS_INFERER(ElementwiseGradNoBufVarsInference, "X",
"Y");
DECLARE_NO_NEED_BUFFER_VARS_INFERER(ElementwiseDoubleGradNoBufVarsInference,
"Y", "DOut");
DECLARE_NO_NEED_BUFFER_VARS_INFERER(ElementwiseGradNoBufVarsInferer, "X", "Y");
DECLARE_NO_NEED_BUFFER_VARS_INFERER(ElementwiseDoubleGradNoBufVarsInferer, "Y",
"DOut");
} // namespace operators
} // namespace paddle
......@@ -389,4 +389,4 @@ DECLARE_NO_NEED_BUFFER_VARS_INFERER(ElementwiseDoubleGradNoBufVarsInference,
::paddle::operators::ElementwiseOpInferVarType, \
op_type##GradMaker<::paddle::framework::OpDesc>, \
op_type##GradMaker<::paddle::imperative::OpBase>, \
::paddle::operators::ElementwiseOpInplace);
::paddle::operators::ElementwiseOpInplaceInferer);
......@@ -97,14 +97,14 @@ REGISTER_ELEMWISE_EXPLICIT_OP_WITHOUT_GRAD(elementwise_sub, Sub);
namespace ops = paddle::operators;
REGISTER_OPERATOR(
elementwise_sub_grad, ops::ElementwiseOpGrad, ops::ElementwiseGradOpInplace,
ops::ElementwiseGradNoBufVarsInference,
elementwise_sub_grad, ops::ElementwiseOpGrad,
ops::ElementwiseGradOpInplaceInferer, ops::ElementwiseGradNoBufVarsInferer,
ops::ElementwiseSubDoubleGradMaker<paddle::framework::OpDesc>,
ops::ElementwiseSubDoubleGradMaker<paddle::imperative::OpBase>);
REGISTER_OPERATOR(elementwise_sub_grad_grad,
ops::ElementwiseOpDoubleGradWithoutDXDY,
ops::ElementwiseDoubleGradOpInplace,
ops::ElementwiseDoubleGradNoBufVarsInference);
ops::ElementwiseDoubleGradOpInplaceInferer,
ops::ElementwiseDoubleGradNoBufVarsInferer);
REGISTER_OP_CPU_KERNEL(
elementwise_sub,
......
......@@ -63,7 +63,7 @@ REGISTER_OPERATOR(
paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>,
ops::FillConstantBatchSizeLikeOpMaker,
ops::BatchSizeLikeNoNeedBufferVarsInference);
ops::BatchSizeLikeNoNeedBufferVarsInferer);
REGISTER_OP_CPU_KERNEL(
fill_constant_batch_size_like,
ops::FillConstantBatchSizeLikeOpKernel<paddle::platform::CPUDeviceContext,
......
......@@ -71,7 +71,7 @@ class FillZerosLikeOp2Maker : public FillZerosLikeOpMaker {
}
};
DECLARE_NO_NEED_BUFFER_VARS_INFERER(FillZerosLikeOp2NoNeedBufferVarsInference,
DECLARE_NO_NEED_BUFFER_VARS_INFERER(FillZerosLikeOp2NoNeedBufferVarsInferer,
"X");
} // namespace operators
......@@ -83,7 +83,7 @@ REGISTER_OP_WITHOUT_GRADIENT(fill_zeros_like, ops::FillZerosLikeOp,
REGISTER_OPERATOR(
fill_zeros_like2, ops::FillZerosLikeOp2, ops::FillZerosLikeOp2Maker,
ops::FillZerosLikeOp2NoNeedBufferVarsInference,
ops::FillZerosLikeOp2NoNeedBufferVarsInferer,
paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>);
......
......@@ -241,11 +241,11 @@ class Flatten2GradOp : public framework::OperatorWithKernel {
}
};
DECLARE_INPLACE_OP_INFERER(FlattenOpInplaceInToOut, {"X", "Out"});
DECLARE_INPLACE_OP_INFERER(FlattenGradInplaceinToOut,
DECLARE_INPLACE_OP_INFERER(FlattenOpInplaceInferer, {"X", "Out"});
DECLARE_INPLACE_OP_INFERER(FlattenGradInplaceInferer,
{framework::GradVarName("Out"),
framework::GradVarName("X")});
DECLARE_NO_NEED_BUFFER_VARS_INFERER(FlattenGradNoNeedBufferVarsInference, "X");
DECLARE_NO_NEED_BUFFER_VARS_INFERER(FlattenGradNoNeedBufferVarsInferer, "X");
} // namespace operators
} // namespace paddle
......@@ -254,17 +254,17 @@ namespace ops = paddle::operators;
REGISTER_OPERATOR(flatten, ops::FlattenOp, ops::FlattenOpMaker,
ops::FlattenGradOpMaker<paddle::framework::OpDesc>,
ops::FlattenGradOpMaker<paddle::imperative::OpBase>,
ops::FlattenOpInplaceInToOut);
ops::FlattenOpInplaceInferer);
REGISTER_OPERATOR(flatten_grad, ops::FlattenGradOp,
ops::FlattenGradInplaceinToOut,
ops::FlattenGradNoNeedBufferVarsInference);
ops::FlattenGradInplaceInferer,
ops::FlattenGradNoNeedBufferVarsInferer);
REGISTER_OPERATOR(flatten2, ops::Flatten2Op, ops::Flatten2OpMaker,
ops::Flatten2GradOpMaker<paddle::framework::OpDesc>,
ops::Flatten2GradOpMaker<paddle::imperative::OpBase>,
ops::FlattenOpInplaceInToOut);
ops::FlattenOpInplaceInferer);
REGISTER_OPERATOR(flatten2_grad, ops::Flatten2GradOp,
ops::FlattenGradInplaceinToOut);
ops::FlattenGradInplaceInferer);
REGISTER_OP_CPU_KERNEL(
flatten, ops::FlattenKernel<paddle::platform::CPUDeviceContext, float>,
......
......@@ -166,7 +166,7 @@ class GatherNdGradOpMaker : public framework::SingleGradOpMaker<T> {
}
};
DECLARE_NO_NEED_BUFFER_VARS_INFERER(GatherNdGradNoNeedBufferVarInference, "X");
DECLARE_NO_NEED_BUFFER_VARS_INFERER(GatherNdGradNoNeedBufferVarInferer, "X");
} // namespace operators
} // namespace paddle
......@@ -178,7 +178,7 @@ REGISTER_OPERATOR(gather_nd, ops::GatherNdOp, ops::GatherNdOpMaker,
ops::GatherNdGradOpMaker<paddle::imperative::OpBase>);
REGISTER_OPERATOR(gather_nd_grad, ops::GatherNdGradOp,
ops::GatherNdGradNoNeedBufferVarInference);
ops::GatherNdGradNoNeedBufferVarInferer);
REGISTER_OP_CPU_KERNEL(gather_nd, ops::GatherNdOpKernel<float>,
ops::GatherNdOpKernel<double>,
......
......@@ -127,7 +127,7 @@ class GatherGradOpMaker : public framework::SingleGradOpMaker<T> {
}
};
DECLARE_NO_NEED_BUFFER_VARS_INFERER(GatherGradNoNeedBufferVarInference, "X");
DECLARE_NO_NEED_BUFFER_VARS_INFERER(GatherGradNoNeedBufferVarInferer, "X");
} // namespace operators
} // namespace paddle
......@@ -137,7 +137,7 @@ REGISTER_OPERATOR(gather, ops::GatherOp, ops::GatherOpMaker,
ops::GatherGradOpMaker<paddle::framework::OpDesc>,
ops::GatherGradOpMaker<paddle::imperative::OpBase>);
REGISTER_OPERATOR(gather_grad, ops::GatherGradOp,
ops::GatherGradNoNeedBufferVarInference);
ops::GatherGradNoNeedBufferVarInferer);
REGISTER_OP_CPU_KERNEL(gather, ops::GatherOpKernel<float>,
ops::GatherOpKernel<double>, ops::GatherOpKernel<int>,
ops::GatherOpKernel<uint8_t>,
......
......@@ -74,6 +74,6 @@ REGISTER_OPERATOR(
paddle::operators::GaussianRandomBatchSizeLikeOpMaker,
paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>,
paddle::operators::BatchSizeLikeNoNeedBufferVarsInference);
paddle::operators::BatchSizeLikeNoNeedBufferVarsInferer);
// Kernels are registered in gaussian_random_op.cc and gaussian_random_op.cu
......@@ -216,8 +216,8 @@ class GroupNormGradMaker : public framework::SingleGradOpMaker<T> {
}
};
DECLARE_INPLACE_OP_INFERER(GroupNormInplaceInToOut, {"X", "Y"});
DECLARE_INPLACE_OP_INFERER(GroupNormGradInplaceInToOut,
DECLARE_INPLACE_OP_INFERER(GroupNormInplaceInferer, {"X", "Y"});
DECLARE_INPLACE_OP_INFERER(GroupNormGradInplaceInferer,
{framework::GradVarName("Y"),
framework::GradVarName("X")});
......@@ -239,9 +239,9 @@ REGISTER_OPERATOR(group_norm, ops::GroupNormOp, ops::GroupNormOpMaker,
ops::GroupNormOpInferVarType,
ops::GroupNormGradMaker<paddle::framework::OpDesc>,
ops::GroupNormGradMaker<paddle::imperative::OpBase>,
ops::GroupNormInplaceInToOut);
ops::GroupNormInplaceInferer);
REGISTER_OPERATOR(group_norm_grad, ops::GroupNormGradOp,
ops::GroupNormGradInplaceInToOut);
ops::GroupNormGradInplaceInferer);
REGISTER_OP_CPU_KERNEL(
group_norm, ops::GroupNormKernel<paddle::platform::CPUDeviceContext, float>,
ops::GroupNormKernel<paddle::platform::CPUDeviceContext, double>);
......
......@@ -456,7 +456,7 @@ class GRUGradOpMaker : public framework::SingleGradOpMaker<T> {
}
};
DECLARE_NO_NEED_BUFFER_VARS_INFERER(GRUGradOpNoNeedBufferVarInference, "Input",
DECLARE_NO_NEED_BUFFER_VARS_INFERER(GRUGradOpNoNeedBufferVarInferer, "Input",
"Bias");
} // namespace operators
......@@ -467,7 +467,7 @@ REGISTER_OPERATOR(gru, ops::GRUOp, ops::GRUOpMaker,
ops::GRUGradOpMaker<paddle::framework::OpDesc>,
ops::GRUGradOpMaker<paddle::imperative::OpBase>);
REGISTER_OPERATOR(gru_grad, ops::GRUGradOp,
ops::GRUGradOpNoNeedBufferVarInference);
ops::GRUGradOpNoNeedBufferVarInferer);
REGISTER_OP_CPU_KERNEL(gru, ops::GRUCPUKernel<float>,
ops::GRUCPUKernel<double>);
REGISTER_OP_CPU_KERNEL(
......
......@@ -234,7 +234,7 @@ class GRUUnitGradOpMaker : public framework::SingleGradOpMaker<T> {
}
};
DECLARE_NO_NEED_BUFFER_VARS_INFERER(GRUUnitGradOpNoNeedBufferVarInference,
DECLARE_NO_NEED_BUFFER_VARS_INFERER(GRUUnitGradOpNoNeedBufferVarInferer,
"Bias");
} // namespace operators
......@@ -246,7 +246,7 @@ REGISTER_OPERATOR(gru_unit, ops::GRUUnitOp, ops::GRUUnitOpMaker,
ops::GRUUnitGradOpMaker<paddle::framework::OpDesc>,
ops::GRUUnitGradOpMaker<paddle::imperative::OpBase>);
REGISTER_OPERATOR(gru_unit_grad, ops::GRUUnitGradOp,
ops::GRUUnitGradOpNoNeedBufferVarInference);
ops::GRUUnitGradOpNoNeedBufferVarInferer);
REGISTER_OP_CPU_KERNEL(
gru_unit, ops::GRUUnitKernel<paddle::platform::CPUDeviceContext, float>,
......
......@@ -78,5 +78,5 @@ REGISTER_OPERATOR(
paddle::operators::UniformRandomBatchSizeLikeOpMaker,
paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>,
paddle::operators::BatchSizeLikeNoNeedBufferVarsInference);
paddle::operators::BatchSizeLikeNoNeedBufferVarsInferer);
// Kernels are registered in uniform_random_op.cc and uniform_random_op.cu
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册