未验证 提交 b0e7439f 编写于 作者: L Leo Chen 提交者: GitHub

rename inplace/no_need_buffer inferer, part2, test=develop (#24733)

上级 a6fbba65
...@@ -822,10 +822,10 @@ class SquareDoubleGradMaker : public ::paddle::framework::SingleGradOpMaker<T> { ...@@ -822,10 +822,10 @@ class SquareDoubleGradMaker : public ::paddle::framework::SingleGradOpMaker<T> {
} }
}; };
DECLARE_INPLACE_OP_INFERER(ActivationGradOpInplaceInference, DECLARE_INPLACE_OP_INFERER(ActivationGradOpInplaceInferer,
{framework::GradVarName("Out"), {framework::GradVarName("Out"),
framework::GradVarName("X")}); framework::GradVarName("X")});
DECLARE_INPLACE_OP_INFERER(ActivationDoubleGradOpInplaceInference, DECLARE_INPLACE_OP_INFERER(ActivationDoubleGradOpInplaceInferer,
{"DDX", "DDOut"}); {"DDX", "DDOut"});
template <typename T> template <typename T>
...@@ -913,7 +913,7 @@ namespace plat = paddle::platform; ...@@ -913,7 +913,7 @@ namespace plat = paddle::platform;
std::conditional<ops::CanInplaceAct<ops::grad_functor<float>>(), \ std::conditional<ops::CanInplaceAct<ops::grad_functor<float>>(), \
ops::ActFwdInplaceInferer, void>::type); \ ops::ActFwdInplaceInferer, void>::type); \
REGISTER_OPERATOR(KERNEL_TYPE##_grad, ops::ActivationOpGrad, \ REGISTER_OPERATOR(KERNEL_TYPE##_grad, ops::ActivationOpGrad, \
ops::ActivationGradOpInplaceInference); ops::ActivationGradOpInplaceInferer);
#define REGISTER_ACTIVATION_CPU_KERNEL(act_type, op_name, functor, \ #define REGISTER_ACTIVATION_CPU_KERNEL(act_type, op_name, functor, \
grad_functor) \ grad_functor) \
...@@ -941,13 +941,13 @@ REGISTER_OPERATOR( ...@@ -941,13 +941,13 @@ REGISTER_OPERATOR(
paddle::imperative::OpBase>, paddle::imperative::OpBase>,
ops::ActFwdInplaceInferer); ops::ActFwdInplaceInferer);
REGISTER_OPERATOR(relu_grad, ops::ActivationOpGrad, REGISTER_OPERATOR(relu_grad, ops::ActivationOpGrad,
ops::ActivationGradOpInplaceInference, ops::ActivationGradOpInplaceInferer,
ops::ReluDoubleGradMaker<paddle::framework::OpDesc>, ops::ReluDoubleGradMaker<paddle::framework::OpDesc>,
ops::ReluDoubleGradMaker<paddle::imperative::OpBase>); ops::ReluDoubleGradMaker<paddle::imperative::OpBase>);
REGISTER_OPERATOR( REGISTER_OPERATOR(
relu_grad_grad, relu_grad_grad,
ops::ActivationOpDoubleGrad2<ops::ReluGradFunctor<float>::FwdDeps()>, ops::ActivationOpDoubleGrad2<ops::ReluGradFunctor<float>::FwdDeps()>,
ops::ActivationDoubleGradOpInplaceInference); ops::ActivationDoubleGradOpInplaceInferer);
REGISTER_ACTIVATION_CPU_KERNEL(relu, Relu, ReluFunctor, ReluGradFunctor); REGISTER_ACTIVATION_CPU_KERNEL(relu, Relu, ReluFunctor, ReluGradFunctor);
...@@ -971,13 +971,13 @@ REGISTER_OPERATOR( ...@@ -971,13 +971,13 @@ REGISTER_OPERATOR(
paddle::imperative::OpBase>, paddle::imperative::OpBase>,
ops::ActFwdInplaceInferer); ops::ActFwdInplaceInferer);
REGISTER_OPERATOR(leaky_relu_grad, ops::ActivationOpGrad, REGISTER_OPERATOR(leaky_relu_grad, ops::ActivationOpGrad,
ops::ActivationGradOpInplaceInference, ops::ActivationGradOpInplaceInferer,
ops::LeakyReluDoubleGradMaker<paddle::framework::OpDesc>, ops::LeakyReluDoubleGradMaker<paddle::framework::OpDesc>,
ops::LeakyReluDoubleGradMaker<paddle::imperative::OpBase>); ops::LeakyReluDoubleGradMaker<paddle::imperative::OpBase>);
REGISTER_OPERATOR( REGISTER_OPERATOR(
leaky_relu_grad_grad, leaky_relu_grad_grad,
ops::ActivationOpDoubleGrad2<ops::LeakyReluGradFunctor<float>::FwdDeps()>, ops::ActivationOpDoubleGrad2<ops::LeakyReluGradFunctor<float>::FwdDeps()>,
ops::ActivationDoubleGradOpInplaceInference); ops::ActivationDoubleGradOpInplaceInferer);
REGISTER_ACTIVATION_CPU_KERNEL(leaky_relu, LeakyRelu, LeakyReluFunctor, REGISTER_ACTIVATION_CPU_KERNEL(leaky_relu, LeakyRelu, LeakyReluFunctor,
LeakyReluGradFunctor); LeakyReluGradFunctor);
...@@ -1000,13 +1000,13 @@ REGISTER_OPERATOR( ...@@ -1000,13 +1000,13 @@ REGISTER_OPERATOR(
paddle::imperative::OpBase>, paddle::imperative::OpBase>,
ops::ActFwdInplaceInferer); ops::ActFwdInplaceInferer);
REGISTER_OPERATOR(elu_grad, ops::ActivationOpGrad, REGISTER_OPERATOR(elu_grad, ops::ActivationOpGrad,
ops::ActivationGradOpInplaceInference, ops::ActivationGradOpInplaceInferer,
ops::ELUDoubleGradMaker<paddle::framework::OpDesc>, ops::ELUDoubleGradMaker<paddle::framework::OpDesc>,
ops::ELUDoubleGradMaker<paddle::imperative::OpBase>); ops::ELUDoubleGradMaker<paddle::imperative::OpBase>);
REGISTER_OPERATOR( REGISTER_OPERATOR(
elu_grad_grad, elu_grad_grad,
ops::ActivationOpDoubleGrad<ops::ELUGradFunctor<float>::FwdDeps()>, ops::ActivationOpDoubleGrad<ops::ELUGradFunctor<float>::FwdDeps()>,
ops::ActivationDoubleGradOpInplaceInference); ops::ActivationDoubleGradOpInplaceInferer);
REGISTER_ACTIVATION_CPU_KERNEL(elu, ELU, ELUFunctor, ELUGradFunctor); REGISTER_ACTIVATION_CPU_KERNEL(elu, ELU, ELUFunctor, ELUGradFunctor);
REGISTER_OP_CPU_KERNEL( REGISTER_OP_CPU_KERNEL(
...@@ -1028,13 +1028,13 @@ REGISTER_OPERATOR( ...@@ -1028,13 +1028,13 @@ REGISTER_OPERATOR(
paddle::imperative::OpBase>, paddle::imperative::OpBase>,
ops::ActFwdInplaceInferer); ops::ActFwdInplaceInferer);
REGISTER_OPERATOR(sqrt_grad, ops::ActivationOpGrad, REGISTER_OPERATOR(sqrt_grad, ops::ActivationOpGrad,
ops::ActivationGradOpInplaceInference, ops::ActivationGradOpInplaceInferer,
ops::SqrtDoubleGradMaker<paddle::framework::OpDesc>, ops::SqrtDoubleGradMaker<paddle::framework::OpDesc>,
ops::SqrtDoubleGradMaker<paddle::imperative::OpBase>); ops::SqrtDoubleGradMaker<paddle::imperative::OpBase>);
REGISTER_OPERATOR( REGISTER_OPERATOR(
sqrt_grad_grad, sqrt_grad_grad,
ops::ActivationOpDoubleGrad<ops::SqrtGradGradFunctor<float>::FwdDeps()>, ops::ActivationOpDoubleGrad<ops::SqrtGradGradFunctor<float>::FwdDeps()>,
ops::ActivationDoubleGradOpInplaceInference); ops::ActivationDoubleGradOpInplaceInferer);
REGISTER_ACTIVATION_CPU_KERNEL(sqrt, Sqrt, SqrtFunctor, SqrtGradFunctor); REGISTER_ACTIVATION_CPU_KERNEL(sqrt, Sqrt, SqrtFunctor, SqrtGradFunctor);
REGISTER_OP_CPU_KERNEL( REGISTER_OP_CPU_KERNEL(
...@@ -1056,13 +1056,13 @@ REGISTER_OPERATOR( ...@@ -1056,13 +1056,13 @@ REGISTER_OPERATOR(
paddle::imperative::OpBase>, paddle::imperative::OpBase>,
ops::ActFwdInplaceInferer); ops::ActFwdInplaceInferer);
REGISTER_OPERATOR(square_grad, ops::ActivationOpGrad, REGISTER_OPERATOR(square_grad, ops::ActivationOpGrad,
ops::ActivationGradOpInplaceInference, ops::ActivationGradOpInplaceInferer,
ops::SquareDoubleGradMaker<paddle::framework::OpDesc>, ops::SquareDoubleGradMaker<paddle::framework::OpDesc>,
ops::SquareDoubleGradMaker<paddle::imperative::OpBase>); ops::SquareDoubleGradMaker<paddle::imperative::OpBase>);
REGISTER_OPERATOR( REGISTER_OPERATOR(
square_grad_grad, square_grad_grad,
ops::ActivationOpDoubleGrad<ops::SquareGradGradFunctor<float>::FwdDeps()>, ops::ActivationOpDoubleGrad<ops::SquareGradGradFunctor<float>::FwdDeps()>,
ops::ActivationDoubleGradOpInplaceInference); ops::ActivationDoubleGradOpInplaceInferer);
REGISTER_OP_CPU_KERNEL(square, REGISTER_OP_CPU_KERNEL(square,
ops::ActivationKernel<paddle::platform::CPUDeviceContext, ops::ActivationKernel<paddle::platform::CPUDeviceContext,
...@@ -1106,7 +1106,7 @@ REGISTER_OPERATOR( ...@@ -1106,7 +1106,7 @@ REGISTER_OPERATOR(
std::conditional<ops::CanInplaceAct<ops::PowGradFunctor<float>>(), std::conditional<ops::CanInplaceAct<ops::PowGradFunctor<float>>(),
ops::ActFwdInplaceInferer, void>::type); ops::ActFwdInplaceInferer, void>::type);
REGISTER_OPERATOR(pow_grad, ops::PowOpGrad, REGISTER_OPERATOR(pow_grad, ops::PowOpGrad,
ops::ActivationGradOpInplaceInference); ops::ActivationGradOpInplaceInferer);
REGISTER_OP_CPU_KERNEL( REGISTER_OP_CPU_KERNEL(
pow, ops::PowKernel<plat::CPUDeviceContext, ops::PowFunctor<float>>, pow, ops::PowKernel<plat::CPUDeviceContext, ops::PowFunctor<float>>,
...@@ -1131,7 +1131,7 @@ REGISTER_OPERATOR( ...@@ -1131,7 +1131,7 @@ REGISTER_OPERATOR(
std::conditional<ops::CanInplaceAct<ops::ExpGradFunctor<float>>(), std::conditional<ops::CanInplaceAct<ops::ExpGradFunctor<float>>(),
ops::ActFwdInplaceInferer, void>::type); ops::ActFwdInplaceInferer, void>::type);
REGISTER_OPERATOR(exp_grad, ops::ActivationOpGrad, REGISTER_OPERATOR(exp_grad, ops::ActivationOpGrad,
ops::ActivationGradOpInplaceInference); ops::ActivationGradOpInplaceInferer);
REGISTER_OP_CPU_KERNEL(exp, REGISTER_OP_CPU_KERNEL(exp,
ops::ActivationKernel<paddle::platform::CPUDeviceContext, ops::ActivationKernel<paddle::platform::CPUDeviceContext,
...@@ -1163,7 +1163,7 @@ REGISTER_OPERATOR( ...@@ -1163,7 +1163,7 @@ REGISTER_OPERATOR(
std::conditional<ops::CanInplaceAct<ops::AbsGradFunctor<float>>(), std::conditional<ops::CanInplaceAct<ops::AbsGradFunctor<float>>(),
ops::ActFwdInplaceInferer, void>::type); ops::ActFwdInplaceInferer, void>::type);
REGISTER_OPERATOR(abs_grad, ops::ActivationOpGrad, REGISTER_OPERATOR(abs_grad, ops::ActivationOpGrad,
ops::ActivationGradOpInplaceInference); ops::ActivationGradOpInplaceInferer);
REGISTER_OP_CPU_KERNEL(abs, REGISTER_OP_CPU_KERNEL(abs,
ops::ActivationKernel<paddle::platform::CPUDeviceContext, ops::ActivationKernel<paddle::platform::CPUDeviceContext,
......
...@@ -116,7 +116,7 @@ class ArgsortGradOpMaker : public framework::SingleGradOpMaker<T> { ...@@ -116,7 +116,7 @@ class ArgsortGradOpMaker : public framework::SingleGradOpMaker<T> {
} }
}; };
DECLARE_NO_NEED_BUFFER_VARS_INFERER(ArgsortGradNoNeedBufferVarInference, "X"); DECLARE_NO_NEED_BUFFER_VARS_INFERER(ArgsortGradNoNeedBufferVarsInferer, "X");
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle
...@@ -126,7 +126,7 @@ REGISTER_OPERATOR(argsort, ops::ArgsortOp, ops::ArgsortOpMaker, ...@@ -126,7 +126,7 @@ REGISTER_OPERATOR(argsort, ops::ArgsortOp, ops::ArgsortOpMaker,
ops::ArgsortGradOpMaker<paddle::framework::OpDesc>, ops::ArgsortGradOpMaker<paddle::framework::OpDesc>,
ops::ArgsortGradOpMaker<paddle::imperative::OpBase>); ops::ArgsortGradOpMaker<paddle::imperative::OpBase>);
REGISTER_OPERATOR(argsort_grad, ops::ArgsortGradOp, REGISTER_OPERATOR(argsort_grad, ops::ArgsortGradOp,
ops::ArgsortGradNoNeedBufferVarInference); ops::ArgsortGradNoNeedBufferVarsInferer);
REGISTER_OP_CPU_KERNEL(argsort, REGISTER_OP_CPU_KERNEL(argsort,
ops::ArgsortKernel<paddle::platform::CPUPlace, float>, ops::ArgsortKernel<paddle::platform::CPUPlace, float>,
ops::ArgsortKernel<paddle::platform::CPUPlace, double>, ops::ArgsortKernel<paddle::platform::CPUPlace, double>,
......
...@@ -136,7 +136,7 @@ class BatchFCGradOpMaker : public framework::SingleGradOpMaker<T> { ...@@ -136,7 +136,7 @@ class BatchFCGradOpMaker : public framework::SingleGradOpMaker<T> {
op->SetAttrMap(this->Attrs()); op->SetAttrMap(this->Attrs());
} }
}; };
DECLARE_NO_NEED_BUFFER_VARS_INFERER(BatchFCGradOpNoNeedBufferVarsInference, DECLARE_NO_NEED_BUFFER_VARS_INFERER(BatchFCGradOpNoNeedBufferVarsInferer,
"Bias"); "Bias");
} // namespace operators } // namespace operators
...@@ -148,7 +148,7 @@ REGISTER_OPERATOR(batch_fc, ops::BatchFCOp, ops::BatchFCOpMaker, ...@@ -148,7 +148,7 @@ REGISTER_OPERATOR(batch_fc, ops::BatchFCOp, ops::BatchFCOpMaker,
ops::BatchFCGradOpMaker<paddle::imperative::OpBase>); ops::BatchFCGradOpMaker<paddle::imperative::OpBase>);
REGISTER_OPERATOR(batch_fc_grad, ops::BatchFCGradOp, REGISTER_OPERATOR(batch_fc_grad, ops::BatchFCGradOp,
ops::BatchFCGradOpNoNeedBufferVarsInference); ops::BatchFCGradOpNoNeedBufferVarsInferer);
REGISTER_OP_CPU_KERNEL( REGISTER_OP_CPU_KERNEL(
batch_fc, ops::BatchFCKernel<paddle::platform::CPUDeviceContext, float>, batch_fc, ops::BatchFCKernel<paddle::platform::CPUDeviceContext, float>,
......
...@@ -74,7 +74,7 @@ class BatchSizeLikeOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -74,7 +74,7 @@ class BatchSizeLikeOpMaker : public framework::OpProtoAndCheckerMaker {
virtual void Apply() = 0; virtual void Apply() = 0;
}; };
DECLARE_NO_NEED_BUFFER_VARS_INFERER(BatchSizeLikeNoNeedBufferVarsInference, DECLARE_NO_NEED_BUFFER_VARS_INFERER(BatchSizeLikeNoNeedBufferVarsInferer,
"Input"); "Input");
} // namespace operators } // namespace operators
......
...@@ -97,15 +97,15 @@ REGISTER_ELEMWISE_EXPLICIT_OP_WITHOUT_GRAD(elementwise_add, Add); ...@@ -97,15 +97,15 @@ REGISTER_ELEMWISE_EXPLICIT_OP_WITHOUT_GRAD(elementwise_add, Add);
namespace ops = paddle::operators; namespace ops = paddle::operators;
REGISTER_OPERATOR( REGISTER_OPERATOR(
elementwise_add_grad, ops::ElementwiseOpGrad, ops::ElementwiseGradOpInplace, elementwise_add_grad, ops::ElementwiseOpGrad,
ops::ElementwiseGradNoBufVarsInference, ops::ElementwiseGradOpInplaceInferer, ops::ElementwiseGradNoBufVarsInferer,
ops::ElementwiseAddDoubleGradMaker<paddle::framework::OpDesc>, ops::ElementwiseAddDoubleGradMaker<paddle::framework::OpDesc>,
ops::ElementwiseAddDoubleGradMaker<paddle::imperative::OpBase>); ops::ElementwiseAddDoubleGradMaker<paddle::imperative::OpBase>);
REGISTER_OPERATOR(elementwise_add_grad_grad, REGISTER_OPERATOR(elementwise_add_grad_grad,
ops::ElementwiseOpDoubleGradWithoutDXDY, ops::ElementwiseOpDoubleGradWithoutDXDY,
ops::ElementwiseDoubleGradOpInplace, ops::ElementwiseDoubleGradOpInplaceInferer,
ops::ElementwiseDoubleGradNoBufVarsInference); ops::ElementwiseDoubleGradNoBufVarsInferer);
REGISTER_OP_CPU_KERNEL( REGISTER_OP_CPU_KERNEL(
elementwise_add, elementwise_add,
......
...@@ -123,7 +123,7 @@ REGISTER_OPERATOR( ...@@ -123,7 +123,7 @@ REGISTER_OPERATOR(
ops::ElementwiseDivDoubleGradMaker<paddle::imperative::OpBase>); ops::ElementwiseDivDoubleGradMaker<paddle::imperative::OpBase>);
REGISTER_OPERATOR(elementwise_div_grad_grad, ops::ElementwiseDivOpDoubleGrad, REGISTER_OPERATOR(elementwise_div_grad_grad, ops::ElementwiseDivOpDoubleGrad,
ops::ElementwiseDoubleGradOpInplace); ops::ElementwiseDoubleGradOpInplaceInferer);
REGISTER_OP_CPU_KERNEL( REGISTER_OP_CPU_KERNEL(
elementwise_div, elementwise_div,
......
...@@ -123,7 +123,7 @@ REGISTER_OPERATOR( ...@@ -123,7 +123,7 @@ REGISTER_OPERATOR(
ops::ElementwiseMulDoubleGradMaker<paddle::imperative::OpBase>); ops::ElementwiseMulDoubleGradMaker<paddle::imperative::OpBase>);
REGISTER_OPERATOR(elementwise_mul_grad_grad, ops::ElementwiseOpDoubleGrad, REGISTER_OPERATOR(elementwise_mul_grad_grad, ops::ElementwiseOpDoubleGrad,
ops::ElementwiseDoubleGradOpInplace); ops::ElementwiseDoubleGradOpInplaceInferer);
REGISTER_OP_CPU_KERNEL( REGISTER_OP_CPU_KERNEL(
elementwise_mul, elementwise_mul,
......
...@@ -348,16 +348,16 @@ class ElemwiseGradKernel : public framework::OpKernel<T> { ...@@ -348,16 +348,16 @@ class ElemwiseGradKernel : public framework::OpKernel<T> {
} }
}; };
DECLARE_INPLACE_OP_INFERER(ElementwiseOpInplace, {"X", "Out"}); DECLARE_INPLACE_OP_INFERER(ElementwiseOpInplaceInferer, {"X", "Out"});
DECLARE_INPLACE_OP_INFERER(ElementwiseGradOpInplace, DECLARE_INPLACE_OP_INFERER(ElementwiseGradOpInplaceInferer,
{framework::GradVarName("Out"), {framework::GradVarName("Out"),
framework::GradVarName("X")}); framework::GradVarName("X")});
DECLARE_INPLACE_OP_INFERER(ElementwiseDoubleGradOpInplace, {"DDX", "DDOut"}); DECLARE_INPLACE_OP_INFERER(ElementwiseDoubleGradOpInplaceInferer,
{"DDX", "DDOut"});
DECLARE_NO_NEED_BUFFER_VARS_INFERER(ElementwiseGradNoBufVarsInference, "X", DECLARE_NO_NEED_BUFFER_VARS_INFERER(ElementwiseGradNoBufVarsInferer, "X", "Y");
"Y"); DECLARE_NO_NEED_BUFFER_VARS_INFERER(ElementwiseDoubleGradNoBufVarsInferer, "Y",
DECLARE_NO_NEED_BUFFER_VARS_INFERER(ElementwiseDoubleGradNoBufVarsInference, "DOut");
"Y", "DOut");
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle
...@@ -389,4 +389,4 @@ DECLARE_NO_NEED_BUFFER_VARS_INFERER(ElementwiseDoubleGradNoBufVarsInference, ...@@ -389,4 +389,4 @@ DECLARE_NO_NEED_BUFFER_VARS_INFERER(ElementwiseDoubleGradNoBufVarsInference,
::paddle::operators::ElementwiseOpInferVarType, \ ::paddle::operators::ElementwiseOpInferVarType, \
op_type##GradMaker<::paddle::framework::OpDesc>, \ op_type##GradMaker<::paddle::framework::OpDesc>, \
op_type##GradMaker<::paddle::imperative::OpBase>, \ op_type##GradMaker<::paddle::imperative::OpBase>, \
::paddle::operators::ElementwiseOpInplace); ::paddle::operators::ElementwiseOpInplaceInferer);
...@@ -97,14 +97,14 @@ REGISTER_ELEMWISE_EXPLICIT_OP_WITHOUT_GRAD(elementwise_sub, Sub); ...@@ -97,14 +97,14 @@ REGISTER_ELEMWISE_EXPLICIT_OP_WITHOUT_GRAD(elementwise_sub, Sub);
namespace ops = paddle::operators; namespace ops = paddle::operators;
REGISTER_OPERATOR( REGISTER_OPERATOR(
elementwise_sub_grad, ops::ElementwiseOpGrad, ops::ElementwiseGradOpInplace, elementwise_sub_grad, ops::ElementwiseOpGrad,
ops::ElementwiseGradNoBufVarsInference, ops::ElementwiseGradOpInplaceInferer, ops::ElementwiseGradNoBufVarsInferer,
ops::ElementwiseSubDoubleGradMaker<paddle::framework::OpDesc>, ops::ElementwiseSubDoubleGradMaker<paddle::framework::OpDesc>,
ops::ElementwiseSubDoubleGradMaker<paddle::imperative::OpBase>); ops::ElementwiseSubDoubleGradMaker<paddle::imperative::OpBase>);
REGISTER_OPERATOR(elementwise_sub_grad_grad, REGISTER_OPERATOR(elementwise_sub_grad_grad,
ops::ElementwiseOpDoubleGradWithoutDXDY, ops::ElementwiseOpDoubleGradWithoutDXDY,
ops::ElementwiseDoubleGradOpInplace, ops::ElementwiseDoubleGradOpInplaceInferer,
ops::ElementwiseDoubleGradNoBufVarsInference); ops::ElementwiseDoubleGradNoBufVarsInferer);
REGISTER_OP_CPU_KERNEL( REGISTER_OP_CPU_KERNEL(
elementwise_sub, elementwise_sub,
......
...@@ -63,7 +63,7 @@ REGISTER_OPERATOR( ...@@ -63,7 +63,7 @@ REGISTER_OPERATOR(
paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>, paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>, paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>,
ops::FillConstantBatchSizeLikeOpMaker, ops::FillConstantBatchSizeLikeOpMaker,
ops::BatchSizeLikeNoNeedBufferVarsInference); ops::BatchSizeLikeNoNeedBufferVarsInferer);
REGISTER_OP_CPU_KERNEL( REGISTER_OP_CPU_KERNEL(
fill_constant_batch_size_like, fill_constant_batch_size_like,
ops::FillConstantBatchSizeLikeOpKernel<paddle::platform::CPUDeviceContext, ops::FillConstantBatchSizeLikeOpKernel<paddle::platform::CPUDeviceContext,
......
...@@ -71,7 +71,7 @@ class FillZerosLikeOp2Maker : public FillZerosLikeOpMaker { ...@@ -71,7 +71,7 @@ class FillZerosLikeOp2Maker : public FillZerosLikeOpMaker {
} }
}; };
DECLARE_NO_NEED_BUFFER_VARS_INFERER(FillZerosLikeOp2NoNeedBufferVarsInference, DECLARE_NO_NEED_BUFFER_VARS_INFERER(FillZerosLikeOp2NoNeedBufferVarsInferer,
"X"); "X");
} // namespace operators } // namespace operators
...@@ -83,7 +83,7 @@ REGISTER_OP_WITHOUT_GRADIENT(fill_zeros_like, ops::FillZerosLikeOp, ...@@ -83,7 +83,7 @@ REGISTER_OP_WITHOUT_GRADIENT(fill_zeros_like, ops::FillZerosLikeOp,
REGISTER_OPERATOR( REGISTER_OPERATOR(
fill_zeros_like2, ops::FillZerosLikeOp2, ops::FillZerosLikeOp2Maker, fill_zeros_like2, ops::FillZerosLikeOp2, ops::FillZerosLikeOp2Maker,
ops::FillZerosLikeOp2NoNeedBufferVarsInference, ops::FillZerosLikeOp2NoNeedBufferVarsInferer,
paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>, paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>); paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>);
......
...@@ -241,11 +241,11 @@ class Flatten2GradOp : public framework::OperatorWithKernel { ...@@ -241,11 +241,11 @@ class Flatten2GradOp : public framework::OperatorWithKernel {
} }
}; };
DECLARE_INPLACE_OP_INFERER(FlattenOpInplaceInToOut, {"X", "Out"}); DECLARE_INPLACE_OP_INFERER(FlattenOpInplaceInferer, {"X", "Out"});
DECLARE_INPLACE_OP_INFERER(FlattenGradInplaceinToOut, DECLARE_INPLACE_OP_INFERER(FlattenGradInplaceInferer,
{framework::GradVarName("Out"), {framework::GradVarName("Out"),
framework::GradVarName("X")}); framework::GradVarName("X")});
DECLARE_NO_NEED_BUFFER_VARS_INFERER(FlattenGradNoNeedBufferVarsInference, "X"); DECLARE_NO_NEED_BUFFER_VARS_INFERER(FlattenGradNoNeedBufferVarsInferer, "X");
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle
...@@ -254,17 +254,17 @@ namespace ops = paddle::operators; ...@@ -254,17 +254,17 @@ namespace ops = paddle::operators;
REGISTER_OPERATOR(flatten, ops::FlattenOp, ops::FlattenOpMaker, REGISTER_OPERATOR(flatten, ops::FlattenOp, ops::FlattenOpMaker,
ops::FlattenGradOpMaker<paddle::framework::OpDesc>, ops::FlattenGradOpMaker<paddle::framework::OpDesc>,
ops::FlattenGradOpMaker<paddle::imperative::OpBase>, ops::FlattenGradOpMaker<paddle::imperative::OpBase>,
ops::FlattenOpInplaceInToOut); ops::FlattenOpInplaceInferer);
REGISTER_OPERATOR(flatten_grad, ops::FlattenGradOp, REGISTER_OPERATOR(flatten_grad, ops::FlattenGradOp,
ops::FlattenGradInplaceinToOut, ops::FlattenGradInplaceInferer,
ops::FlattenGradNoNeedBufferVarsInference); ops::FlattenGradNoNeedBufferVarsInferer);
REGISTER_OPERATOR(flatten2, ops::Flatten2Op, ops::Flatten2OpMaker, REGISTER_OPERATOR(flatten2, ops::Flatten2Op, ops::Flatten2OpMaker,
ops::Flatten2GradOpMaker<paddle::framework::OpDesc>, ops::Flatten2GradOpMaker<paddle::framework::OpDesc>,
ops::Flatten2GradOpMaker<paddle::imperative::OpBase>, ops::Flatten2GradOpMaker<paddle::imperative::OpBase>,
ops::FlattenOpInplaceInToOut); ops::FlattenOpInplaceInferer);
REGISTER_OPERATOR(flatten2_grad, ops::Flatten2GradOp, REGISTER_OPERATOR(flatten2_grad, ops::Flatten2GradOp,
ops::FlattenGradInplaceinToOut); ops::FlattenGradInplaceInferer);
REGISTER_OP_CPU_KERNEL( REGISTER_OP_CPU_KERNEL(
flatten, ops::FlattenKernel<paddle::platform::CPUDeviceContext, float>, flatten, ops::FlattenKernel<paddle::platform::CPUDeviceContext, float>,
......
...@@ -166,7 +166,7 @@ class GatherNdGradOpMaker : public framework::SingleGradOpMaker<T> { ...@@ -166,7 +166,7 @@ class GatherNdGradOpMaker : public framework::SingleGradOpMaker<T> {
} }
}; };
DECLARE_NO_NEED_BUFFER_VARS_INFERER(GatherNdGradNoNeedBufferVarInference, "X"); DECLARE_NO_NEED_BUFFER_VARS_INFERER(GatherNdGradNoNeedBufferVarInferer, "X");
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle
...@@ -178,7 +178,7 @@ REGISTER_OPERATOR(gather_nd, ops::GatherNdOp, ops::GatherNdOpMaker, ...@@ -178,7 +178,7 @@ REGISTER_OPERATOR(gather_nd, ops::GatherNdOp, ops::GatherNdOpMaker,
ops::GatherNdGradOpMaker<paddle::imperative::OpBase>); ops::GatherNdGradOpMaker<paddle::imperative::OpBase>);
REGISTER_OPERATOR(gather_nd_grad, ops::GatherNdGradOp, REGISTER_OPERATOR(gather_nd_grad, ops::GatherNdGradOp,
ops::GatherNdGradNoNeedBufferVarInference); ops::GatherNdGradNoNeedBufferVarInferer);
REGISTER_OP_CPU_KERNEL(gather_nd, ops::GatherNdOpKernel<float>, REGISTER_OP_CPU_KERNEL(gather_nd, ops::GatherNdOpKernel<float>,
ops::GatherNdOpKernel<double>, ops::GatherNdOpKernel<double>,
......
...@@ -127,7 +127,7 @@ class GatherGradOpMaker : public framework::SingleGradOpMaker<T> { ...@@ -127,7 +127,7 @@ class GatherGradOpMaker : public framework::SingleGradOpMaker<T> {
} }
}; };
DECLARE_NO_NEED_BUFFER_VARS_INFERER(GatherGradNoNeedBufferVarInference, "X"); DECLARE_NO_NEED_BUFFER_VARS_INFERER(GatherGradNoNeedBufferVarInferer, "X");
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle
...@@ -137,7 +137,7 @@ REGISTER_OPERATOR(gather, ops::GatherOp, ops::GatherOpMaker, ...@@ -137,7 +137,7 @@ REGISTER_OPERATOR(gather, ops::GatherOp, ops::GatherOpMaker,
ops::GatherGradOpMaker<paddle::framework::OpDesc>, ops::GatherGradOpMaker<paddle::framework::OpDesc>,
ops::GatherGradOpMaker<paddle::imperative::OpBase>); ops::GatherGradOpMaker<paddle::imperative::OpBase>);
REGISTER_OPERATOR(gather_grad, ops::GatherGradOp, REGISTER_OPERATOR(gather_grad, ops::GatherGradOp,
ops::GatherGradNoNeedBufferVarInference); ops::GatherGradNoNeedBufferVarInferer);
REGISTER_OP_CPU_KERNEL(gather, ops::GatherOpKernel<float>, REGISTER_OP_CPU_KERNEL(gather, ops::GatherOpKernel<float>,
ops::GatherOpKernel<double>, ops::GatherOpKernel<int>, ops::GatherOpKernel<double>, ops::GatherOpKernel<int>,
ops::GatherOpKernel<uint8_t>, ops::GatherOpKernel<uint8_t>,
......
...@@ -74,6 +74,6 @@ REGISTER_OPERATOR( ...@@ -74,6 +74,6 @@ REGISTER_OPERATOR(
paddle::operators::GaussianRandomBatchSizeLikeOpMaker, paddle::operators::GaussianRandomBatchSizeLikeOpMaker,
paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>, paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>, paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>,
paddle::operators::BatchSizeLikeNoNeedBufferVarsInference); paddle::operators::BatchSizeLikeNoNeedBufferVarsInferer);
// Kernels are registered in gaussian_random_op.cc and gaussian_random_op.cu // Kernels are registered in gaussian_random_op.cc and gaussian_random_op.cu
...@@ -216,8 +216,8 @@ class GroupNormGradMaker : public framework::SingleGradOpMaker<T> { ...@@ -216,8 +216,8 @@ class GroupNormGradMaker : public framework::SingleGradOpMaker<T> {
} }
}; };
DECLARE_INPLACE_OP_INFERER(GroupNormInplaceInToOut, {"X", "Y"}); DECLARE_INPLACE_OP_INFERER(GroupNormInplaceInferer, {"X", "Y"});
DECLARE_INPLACE_OP_INFERER(GroupNormGradInplaceInToOut, DECLARE_INPLACE_OP_INFERER(GroupNormGradInplaceInferer,
{framework::GradVarName("Y"), {framework::GradVarName("Y"),
framework::GradVarName("X")}); framework::GradVarName("X")});
...@@ -239,9 +239,9 @@ REGISTER_OPERATOR(group_norm, ops::GroupNormOp, ops::GroupNormOpMaker, ...@@ -239,9 +239,9 @@ REGISTER_OPERATOR(group_norm, ops::GroupNormOp, ops::GroupNormOpMaker,
ops::GroupNormOpInferVarType, ops::GroupNormOpInferVarType,
ops::GroupNormGradMaker<paddle::framework::OpDesc>, ops::GroupNormGradMaker<paddle::framework::OpDesc>,
ops::GroupNormGradMaker<paddle::imperative::OpBase>, ops::GroupNormGradMaker<paddle::imperative::OpBase>,
ops::GroupNormInplaceInToOut); ops::GroupNormInplaceInferer);
REGISTER_OPERATOR(group_norm_grad, ops::GroupNormGradOp, REGISTER_OPERATOR(group_norm_grad, ops::GroupNormGradOp,
ops::GroupNormGradInplaceInToOut); ops::GroupNormGradInplaceInferer);
REGISTER_OP_CPU_KERNEL( REGISTER_OP_CPU_KERNEL(
group_norm, ops::GroupNormKernel<paddle::platform::CPUDeviceContext, float>, group_norm, ops::GroupNormKernel<paddle::platform::CPUDeviceContext, float>,
ops::GroupNormKernel<paddle::platform::CPUDeviceContext, double>); ops::GroupNormKernel<paddle::platform::CPUDeviceContext, double>);
......
...@@ -456,7 +456,7 @@ class GRUGradOpMaker : public framework::SingleGradOpMaker<T> { ...@@ -456,7 +456,7 @@ class GRUGradOpMaker : public framework::SingleGradOpMaker<T> {
} }
}; };
DECLARE_NO_NEED_BUFFER_VARS_INFERER(GRUGradOpNoNeedBufferVarInference, "Input", DECLARE_NO_NEED_BUFFER_VARS_INFERER(GRUGradOpNoNeedBufferVarInferer, "Input",
"Bias"); "Bias");
} // namespace operators } // namespace operators
...@@ -467,7 +467,7 @@ REGISTER_OPERATOR(gru, ops::GRUOp, ops::GRUOpMaker, ...@@ -467,7 +467,7 @@ REGISTER_OPERATOR(gru, ops::GRUOp, ops::GRUOpMaker,
ops::GRUGradOpMaker<paddle::framework::OpDesc>, ops::GRUGradOpMaker<paddle::framework::OpDesc>,
ops::GRUGradOpMaker<paddle::imperative::OpBase>); ops::GRUGradOpMaker<paddle::imperative::OpBase>);
REGISTER_OPERATOR(gru_grad, ops::GRUGradOp, REGISTER_OPERATOR(gru_grad, ops::GRUGradOp,
ops::GRUGradOpNoNeedBufferVarInference); ops::GRUGradOpNoNeedBufferVarInferer);
REGISTER_OP_CPU_KERNEL(gru, ops::GRUCPUKernel<float>, REGISTER_OP_CPU_KERNEL(gru, ops::GRUCPUKernel<float>,
ops::GRUCPUKernel<double>); ops::GRUCPUKernel<double>);
REGISTER_OP_CPU_KERNEL( REGISTER_OP_CPU_KERNEL(
......
...@@ -234,7 +234,7 @@ class GRUUnitGradOpMaker : public framework::SingleGradOpMaker<T> { ...@@ -234,7 +234,7 @@ class GRUUnitGradOpMaker : public framework::SingleGradOpMaker<T> {
} }
}; };
DECLARE_NO_NEED_BUFFER_VARS_INFERER(GRUUnitGradOpNoNeedBufferVarInference, DECLARE_NO_NEED_BUFFER_VARS_INFERER(GRUUnitGradOpNoNeedBufferVarInferer,
"Bias"); "Bias");
} // namespace operators } // namespace operators
...@@ -246,7 +246,7 @@ REGISTER_OPERATOR(gru_unit, ops::GRUUnitOp, ops::GRUUnitOpMaker, ...@@ -246,7 +246,7 @@ REGISTER_OPERATOR(gru_unit, ops::GRUUnitOp, ops::GRUUnitOpMaker,
ops::GRUUnitGradOpMaker<paddle::framework::OpDesc>, ops::GRUUnitGradOpMaker<paddle::framework::OpDesc>,
ops::GRUUnitGradOpMaker<paddle::imperative::OpBase>); ops::GRUUnitGradOpMaker<paddle::imperative::OpBase>);
REGISTER_OPERATOR(gru_unit_grad, ops::GRUUnitGradOp, REGISTER_OPERATOR(gru_unit_grad, ops::GRUUnitGradOp,
ops::GRUUnitGradOpNoNeedBufferVarInference); ops::GRUUnitGradOpNoNeedBufferVarInferer);
REGISTER_OP_CPU_KERNEL( REGISTER_OP_CPU_KERNEL(
gru_unit, ops::GRUUnitKernel<paddle::platform::CPUDeviceContext, float>, gru_unit, ops::GRUUnitKernel<paddle::platform::CPUDeviceContext, float>,
......
...@@ -78,5 +78,5 @@ REGISTER_OPERATOR( ...@@ -78,5 +78,5 @@ REGISTER_OPERATOR(
paddle::operators::UniformRandomBatchSizeLikeOpMaker, paddle::operators::UniformRandomBatchSizeLikeOpMaker,
paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>, paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>, paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>,
paddle::operators::BatchSizeLikeNoNeedBufferVarsInference); paddle::operators::BatchSizeLikeNoNeedBufferVarsInferer);
// Kernels are registered in uniform_random_op.cc and uniform_random_op.cu // Kernels are registered in uniform_random_op.cc and uniform_random_op.cu
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册