Unverified commit de8b4f42, authored by Leo Chen, committed by GitHub

rename inplace/no_need_buffer inferer, part 1, test=develop (#24711)

Parent 33fc690f
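Every file in this diff applies the same mechanical rename: the helper classes generated by DECLARE_NO_NEED_BUFFER_VARS_INFERER and DECLARE_INPLACE_OP_INFERER now end in "Inferer" (matching the macro names) instead of "Inference" or "InToOut", and the corresponding REGISTER_OPERATOR calls are updated to the new identifiers. A minimal sketch of the convention follows; FooGradOp, foo_grad, and the argument names are hypothetical placeholders, not operators from this commit.

// Old spelling (before this commit):
//   DECLARE_NO_NEED_BUFFER_VARS_INFERER(FooGradNoNeedBufferVarsInference, "X");
//   DECLARE_INPLACE_OP_INFERER(FooGradInplaceInToOut, {"DOut", "DX"});
// New spelling (after this commit): the generated class names end in "Inferer".
DECLARE_NO_NEED_BUFFER_VARS_INFERER(FooGradNoNeedBufferVarsInferer, "X");
DECLARE_INPLACE_OP_INFERER(FooGradInplaceInferer, {"DOut", "DX"});

// The renamed classes are passed to REGISTER_OPERATOR exactly as before;
// only the identifiers change, not the inferred behavior.
REGISTER_OPERATOR(foo_grad, ops::FooGradOp,
                  ops::FooGradNoNeedBufferVarsInferer,
                  ops::FooGradInplaceInferer);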
@@ -257,7 +257,7 @@ class HierarchicalSigmoidGradOpGradVarTypeInference
 };
 DECLARE_NO_NEED_BUFFER_VARS_INFERER(
-    HierarchicalSigmoidGradOpNoNeedBufferVarInference, "Bias");
+    HierarchicalSigmoidGradOpNoNeedBufferVarInferer, "Bias");
 }  // namespace operators
 }  // namespace paddle
@@ -270,7 +270,7 @@ REGISTER_OPERATOR(
     ops::HierarchicalSigmoidGradMaker<paddle::imperative::OpBase>);
 REGISTER_OPERATOR(hierarchical_sigmoid_grad, ops::HierarchicalSigmoidGradOp,
                   ops::HierarchicalSigmoidGradOpGradVarTypeInference,
-                  ops::HierarchicalSigmoidGradOpNoNeedBufferVarInference);
+                  ops::HierarchicalSigmoidGradOpNoNeedBufferVarInferer);
 REGISTER_OP_CPU_KERNEL(
     hierarchical_sigmoid,
     ops::HierarchicalSigmoidOpKernel<paddle::platform::CPUDeviceContext, float>,
...
@@ -138,7 +138,7 @@ class IndexSelectGradMaker : public framework::SingleGradOpMaker<T> {
   }
 };
-DECLARE_NO_NEED_BUFFER_VARS_INFERER(IndexSelectGradNoNeedBufferVarsInference,
+DECLARE_NO_NEED_BUFFER_VARS_INFERER(IndexSelectGradNoNeedBufferVarsInferer,
                                     "X");
 }  // namespace operators
 }  // namespace paddle
@@ -148,7 +148,7 @@ REGISTER_OPERATOR(index_select, ops::IndexSelectOp, ops::IndexSelectOpMaker,
                   ops::IndexSelectGradMaker<paddle::framework::OpDesc>,
                   ops::IndexSelectGradMaker<paddle::imperative::OpBase>);
 REGISTER_OPERATOR(index_select_grad, ops::IndexSelectGradOp,
-                  ops::IndexSelectGradNoNeedBufferVarsInference);
+                  ops::IndexSelectGradNoNeedBufferVarsInferer);
 REGISTER_OP_CPU_KERNEL(
     index_select,
     ops::IndexSelectKernel<paddle::platform::CPUDeviceContext, float>,
...
@@ -603,7 +603,7 @@ class InstanceNormDoubleGradKernel<platform::CPUDeviceContext, T>
   }
 };
-DECLARE_INPLACE_OP_INFERER(InstanceNormDoubleGradOpInplaceInference,
+DECLARE_INPLACE_OP_INFERER(InstanceNormDoubleGradOpInplaceInferer,
                            {"DY", "DDY"});
 }  // namespace operators
@@ -618,7 +618,7 @@ REGISTER_OPERATOR(instance_norm_grad, ops::InstanceNormGradOp,
                   ops::InstanceNormDoubleGradMaker<paddle::framework::OpDesc>,
                   ops::InstanceNormDoubleGradMaker<paddle::imperative::OpBase>);
 REGISTER_OPERATOR(instance_norm_grad_grad, ops::InstanceNormDoubleGradOp,
-                  ops::InstanceNormDoubleGradOpInplaceInference);
+                  ops::InstanceNormDoubleGradOpInplaceInferer);
 REGISTER_OP_CPU_KERNEL(
     instance_norm,
...
@@ -585,7 +585,7 @@ class InterpolateGradMaker : public framework::SingleGradOpMaker<T> {
   }
 };
-DECLARE_NO_NEED_BUFFER_VARS_INFERER(InterpolateGradNoNeedBufferVarsInference,
+DECLARE_NO_NEED_BUFFER_VARS_INFERER(InterpolateGradNoNeedBufferVarsInferer,
                                     "X");
 }  // namespace operators
@@ -596,22 +596,22 @@ REGISTER_OPERATOR(bilinear_interp, ops::InterpolateOp, ops::InterpolateOpMaker,
                   ops::InterpolateGradMaker<paddle::framework::OpDesc>,
                   ops::InterpolateGradMaker<paddle::imperative::OpBase>);
 REGISTER_OPERATOR(bilinear_interp_grad, ops::InterpolateOpGrad,
-                  ops::InterpolateGradNoNeedBufferVarsInference);
+                  ops::InterpolateGradNoNeedBufferVarsInferer);
 REGISTER_OPERATOR(nearest_interp, ops::InterpolateOp, ops::InterpolateOpMaker,
                   ops::InterpolateGradMaker<paddle::framework::OpDesc>,
                   ops::InterpolateGradMaker<paddle::imperative::OpBase>);
 REGISTER_OPERATOR(nearest_interp_grad, ops::InterpolateOpGrad,
-                  ops::InterpolateGradNoNeedBufferVarsInference);
+                  ops::InterpolateGradNoNeedBufferVarsInferer);
 REGISTER_OPERATOR(trilinear_interp, ops::InterpolateOp, ops::InterpolateOpMaker,
                   ops::InterpolateGradMaker<paddle::framework::OpDesc>,
                   ops::InterpolateGradMaker<paddle::imperative::OpBase>);
 REGISTER_OPERATOR(trilinear_interp_grad, ops::InterpolateOpGrad,
-                  ops::InterpolateGradNoNeedBufferVarsInference);
+                  ops::InterpolateGradNoNeedBufferVarsInferer);
 REGISTER_OPERATOR(bicubic_interp, ops::InterpolateOp, ops::InterpolateOpMaker,
                   ops::InterpolateGradMaker<paddle::framework::OpDesc>,
                   ops::InterpolateGradMaker<paddle::imperative::OpBase>);
 REGISTER_OPERATOR(bicubic_interp_grad, ops::InterpolateOpGrad,
-                  ops::InterpolateGradNoNeedBufferVarsInference);
+                  ops::InterpolateGradNoNeedBufferVarsInferer);
 REGISTER_OP_CPU_KERNEL(bilinear_interp, ops::InterpolateKernel<float>,
                        ops::InterpolateKernel<double>,
                        ops::InterpolateKernel<uint8_t>);
@@ -631,7 +631,7 @@ REGISTER_OPERATOR(linear_interp, ops::InterpolateOp, ops::InterpolateOpMaker,
                   ops::InterpolateGradMaker<paddle::framework::OpDesc>,
                   ops::InterpolateGradMaker<paddle::imperative::OpBase>);
 REGISTER_OPERATOR(linear_interp_grad, ops::InterpolateOpGrad,
-                  ops::InterpolateGradNoNeedBufferVarsInference);
+                  ops::InterpolateGradNoNeedBufferVarsInferer);
 REGISTER_OP_CPU_KERNEL(linear_interp, ops::InterpolateKernel<float>,
                        ops::InterpolateKernel<double>,
                        ops::InterpolateKernel<uint8_t>);
...
@@ -166,7 +166,7 @@ class KLDivLossOpGradMaker : public framework::SingleGradOpMaker<T> {
   }
 };
-DECLARE_NO_NEED_BUFFER_VARS_INFERER(KLDivLossGradNoNeedBufferVarInference, "X");
+DECLARE_NO_NEED_BUFFER_VARS_INFERER(KLDivLossGradNoNeedBufferVarInferer, "X");
 }  // namespace operators
 }  // namespace paddle
@@ -176,7 +176,7 @@ REGISTER_OPERATOR(kldiv_loss, ops::KLDivLossOp, ops::KLDivLossOpMaker,
                   ops::KLDivLossOpGradMaker<paddle::framework::OpDesc>,
                   ops::KLDivLossOpGradMaker<paddle::imperative::OpBase>);
 REGISTER_OPERATOR(kldiv_loss_grad, ops::KLDivLossOpGrad,
-                  ops::KLDivLossGradNoNeedBufferVarInference);
+                  ops::KLDivLossGradNoNeedBufferVarInferer);
 REGISTER_OP_CPU_KERNEL(
     kldiv_loss, ops::KLDivLossKernel<paddle::platform::CPUDeviceContext, float>,
     ops::KLDivLossKernel<paddle::platform::CPUDeviceContext, double>);
...
@@ -220,7 +220,7 @@ class LayerNormGradOpMaker : public framework::SingleGradOpMaker<T> {
   }
 };
-DECLARE_NO_NEED_BUFFER_VARS_INFERER(LayerNormGradNoNeedBufferVarInference,
+DECLARE_NO_NEED_BUFFER_VARS_INFERER(LayerNormGradNoNeedBufferVarInferer,
                                     "Bias");
 }  // namespace operators
@@ -231,7 +231,7 @@ REGISTER_OPERATOR(layer_norm, ops::LayerNormOp, ops::LayerNormOpMaker,
                   ops::LayerNormGradOpMaker<paddle::framework::OpDesc>,
                   ops::LayerNormGradOpMaker<paddle::imperative::OpBase>);
 REGISTER_OPERATOR(layer_norm_grad, ops::LayerNormGradOp,
-                  ops::LayerNormGradNoNeedBufferVarInference);
+                  ops::LayerNormGradNoNeedBufferVarInferer);
 REGISTER_OP_CPU_KERNEL(
     layer_norm, ops::LayerNormKernel<paddle::platform::CPUDeviceContext, float>,
     ops::LayerNormKernel<paddle::platform::CPUDeviceContext, double>);
...
@@ -345,7 +345,7 @@ class LinearChainCRFGradMaker : public framework::SingleGradOpMaker<T> {
   }
 };
-DECLARE_NO_NEED_BUFFER_VARS_INFERER(LinearChainCRFGradNoNeedBufferVarsInference,
+DECLARE_NO_NEED_BUFFER_VARS_INFERER(LinearChainCRFGradNoNeedBufferVarsInferer,
                                     "Transition", "Emission");
 }  // namespace operators
@@ -357,7 +357,7 @@ REGISTER_OPERATOR(linear_chain_crf, ops::LinearChainCRFOp,
                   ops::LinearChainCRFGradMaker<paddle::framework::OpDesc>,
                   ops::LinearChainCRFGradMaker<paddle::imperative::OpBase>);
 REGISTER_OPERATOR(linear_chain_crf_grad, ops::LinearChainCRFGradOp,
-                  ops::LinearChainCRFGradNoNeedBufferVarsInference);
+                  ops::LinearChainCRFGradNoNeedBufferVarsInferer);
 REGISTER_OP_CPU_KERNEL(
     linear_chain_crf,
     ops::LinearChainCRFOpKernel<paddle::platform::CPUDeviceContext, float>,
...
@@ -223,7 +223,7 @@ DECLARE_INPLACE_OP_INFERER(LoDResetGradInplaceInferer,
                            {framework::GradVarName("Out"),
                             framework::GradVarName("X")});
-DECLARE_NO_NEED_BUFFER_VARS_INFERER(LoDResetGradNoNeedBufferVarInference, "X");
+DECLARE_NO_NEED_BUFFER_VARS_INFERER(LoDResetGradNoNeedBufferVarInferer, "X");
 }  // namespace operators
 }  // namespace paddle
@@ -234,7 +234,7 @@ REGISTER_OPERATOR(lod_reset, ops::LoDResetOp, ops::LoDResetOpMaker,
                   ops::LoDResetGradMaker<paddle::imperative::OpBase>,
                   ops::LoDResetOpVarTypeInference, ops::LoDResetInplaceInferer);
 REGISTER_OPERATOR(lod_reset_grad, ops::LoDResetGradOp,
-                  ops::LoDResetGradNoNeedBufferVarInference,
+                  ops::LoDResetGradNoNeedBufferVarInferer,
                   ops::LoDResetGradInplaceInferer);
 REGISTER_OP_CPU_KERNEL(
...
@@ -130,7 +130,7 @@ or not. And the output only shares the LoD information with input Ids.
   }
 };
-DECLARE_NO_NEED_BUFFER_VARS_INFERER(LookupTableGradOpNoBuffer, "W");
+DECLARE_NO_NEED_BUFFER_VARS_INFERER(LookupTableGradOpNoBufferVarsInferer, "W");
 template <typename T>
 class LookupTableGradOpMaker : public framework::SingleGradOpMaker<T> {
@@ -198,7 +198,7 @@ REGISTER_OPERATOR(lookup_table, ops::LookupTableOp, ops::LookupTableOpMaker,
                   ops::LookupTableGradOpMaker<paddle::imperative::OpBase>);
 REGISTER_OPERATOR(lookup_table_grad, ops::LookupTableOpGrad,
-                  ops::LookupTableGradOpNoBuffer,
+                  ops::LookupTableGradOpNoBufferVarsInferer,
                   ops::LookupTableOpGradVarTypeInference);
 REGISTER_OP_CPU_KERNEL(lookup_table, ops::LookupTableKernel<float>,
...
@@ -118,7 +118,8 @@ or not. And the output only shares the LoD information with input Ids.
   }
 };
-DECLARE_NO_NEED_BUFFER_VARS_INFERER(LookupTableV2GradOpNoBuffer, "W");
+DECLARE_NO_NEED_BUFFER_VARS_INFERER(LookupTableV2GradOpNoBufferVarsInferer,
+                                    "W");
 template <typename T>
 class LookupTableV2GradOpMaker : public framework::SingleGradOpMaker<T> {
@@ -187,7 +188,7 @@ REGISTER_OPERATOR(lookup_table_v2, ops::LookupTableV2Op,
                   ops::LookupTableV2GradOpMaker<paddle::imperative::OpBase>);
 REGISTER_OPERATOR(lookup_table_v2_grad, ops::LookupTableV2OpGrad,
-                  ops::LookupTableV2GradOpNoBuffer,
+                  ops::LookupTableV2GradOpNoBufferVarsInferer,
                   ops::LookupTableV2OpGradVarTypeInference);
 REGISTER_OP_CPU_KERNEL(lookup_table_v2, ops::LookupTableV2Kernel<float>,
...
@@ -83,7 +83,7 @@ class MeanGradMaker : public framework::SingleGradOpMaker<T> {
   }
 };
-DECLARE_NO_NEED_BUFFER_VARS_INFERER(MeanGradNoNeedBufferVarsInference, "X");
+DECLARE_NO_NEED_BUFFER_VARS_INFERER(MeanGradNoNeedBufferVarsInferer, "X");
 }  // namespace operators
 }  // namespace paddle
@@ -93,7 +93,7 @@ REGISTER_OPERATOR(mean, ops::MeanOp, ops::MeanOpMaker, ops::MeanOpInferVarType,
                   ops::MeanGradMaker<paddle::framework::OpDesc>,
                   ops::MeanGradMaker<paddle::imperative::OpBase>);
 REGISTER_OPERATOR(mean_grad, ops::MeanGradOp,
-                  ops::MeanGradNoNeedBufferVarsInference);
+                  ops::MeanGradNoNeedBufferVarsInferer);
 REGISTER_OP_CPU_KERNEL(
     mean, ops::MeanKernel<paddle::platform::CPUDeviceContext, float>,
     ops::MeanKernel<paddle::platform::CPUDeviceContext, double>);
...
@@ -307,7 +307,7 @@ class NCEOpGradVarTypeInference : public framework::VarTypeInference {
   }
 };
-DECLARE_NO_NEED_BUFFER_VARS_INFERER(NCEGradOpNoNeedBufferVarInference, "Bias");
+DECLARE_NO_NEED_BUFFER_VARS_INFERER(NCEGradOpNoNeedBufferVarInferer, "Bias");
 }  // namespace operators
 }  // namespace paddle
@@ -317,7 +317,7 @@ REGISTER_OPERATOR(nce, ops::NCEOp, ops::NCEOpMaker,
                   ops::NCEGradOpMaker<paddle::framework::OpDesc>,
                   ops::NCEGradOpMaker<paddle::imperative::OpBase>);
 REGISTER_OPERATOR(nce_grad, ops::NCEOpGrad, ops::NCEOpGradVarTypeInference,
-                  ops::NCEGradOpNoNeedBufferVarInference);
+                  ops::NCEGradOpNoNeedBufferVarInferer);
 REGISTER_OP_CPU_KERNEL(nce, ops::NCEKernel<paddle::platform::CPUPlace, float>,
                        ops::NCEKernel<paddle::platform::CPUPlace, double>);
 REGISTER_OP_CPU_KERNEL(nce_grad,
...
@@ -656,7 +656,7 @@ class Pad2dOpGradMaker : public framework::SingleGradOpMaker<T> {
 };
 // TODO(zjl): Paddings can also be skipped!
-DECLARE_NO_NEED_BUFFER_VARS_INFERER(Pad2dOpGradNoNeedBufferVarsInference, "X");
+DECLARE_NO_NEED_BUFFER_VARS_INFERER(Pad2dOpGradNoNeedBufferVarsInferer, "X");
 }  // namespace operators
 }  // namespace paddle
@@ -667,7 +667,7 @@ REGISTER_OPERATOR(pad2d, ops::Pad2dOp, ops::Pad2dOpMaker,
                   ops::Pad2dOpGradMaker<paddle::framework::OpDesc>,
                   ops::Pad2dOpGradMaker<paddle::imperative::OpBase>);
 REGISTER_OPERATOR(pad2d_grad, ops::Pad2dOpGrad,
-                  ops::Pad2dOpGradNoNeedBufferVarsInference);
+                  ops::Pad2dOpGradNoNeedBufferVarsInferer);
 REGISTER_OP_CPU_KERNEL(pad2d, ops::Pad2dCPUKernel<float>,
                        ops::Pad2dCPUKernel<double>, ops::Pad2dCPUKernel<int>,
                        ops::Pad2dCPUKernel<int64_t>);
...
@@ -316,7 +316,7 @@ class MaxPoolWithIndexGradOpMaker : public framework::SingleGradOpMaker<T> {
 };
 DECLARE_NO_NEED_BUFFER_VARS_INFERER(
-    MaxPoolWithIndexOpGradNoNeedBufferVarsInference, "X");
+    MaxPoolWithIndexOpGradNoNeedBufferVarsInferer, "X");
 }  // namespace operators
 }  // namespace paddle
@@ -328,7 +328,7 @@ REGISTER_OPERATOR(max_pool2d_with_index, ops::MaxPoolWithIndexOp,
                   ops::MaxPoolWithIndexGradOpMaker<paddle::framework::OpDesc>,
                   ops::MaxPoolWithIndexGradOpMaker<paddle::imperative::OpBase>);
 REGISTER_OPERATOR(max_pool2d_with_index_grad, ops::MaxPoolWithIndexOpGrad,
-                  ops::MaxPoolWithIndexOpGradNoNeedBufferVarsInference);
+                  ops::MaxPoolWithIndexOpGradNoNeedBufferVarsInferer);
 REGISTER_OP_CPU_KERNEL(
     max_pool2d_with_index,
@@ -347,7 +347,7 @@ REGISTER_OPERATOR(max_pool3d_with_index, ops::MaxPoolWithIndexOp,
                   ops::MaxPoolWithIndexGradOpMaker<paddle::framework::OpDesc>,
                   ops::MaxPoolWithIndexGradOpMaker<paddle::imperative::OpBase>);
 REGISTER_OPERATOR(max_pool3d_with_index_grad, ops::MaxPoolWithIndexOpGrad,
-                  ops::MaxPoolWithIndexOpGradNoNeedBufferVarsInference);
+                  ops::MaxPoolWithIndexOpGradNoNeedBufferVarsInferer);
 REGISTER_OP_CPU_KERNEL(
     max_pool3d_with_index,
...
@@ -56,7 +56,7 @@ The input gradients is all dense gradient tensors in a table.
   }
 };
-DECLARE_NO_NEED_BUFFER_VARS_INFERER(PushDenseNoNeedBufferVarsInference, "Ids");
+DECLARE_NO_NEED_BUFFER_VARS_INFERER(PushDenseNoNeedBufferVarsInferer, "Ids");
 }  // namespace operators
 }  // namespace paddle
@@ -66,5 +66,5 @@ REGISTER_OPERATOR(
     push_dense, ops::PushDenseOp, ops::PushDenseOpMaker,
     paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
     paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>,
-    ops::PushDenseNoNeedBufferVarsInference);
+    ops::PushDenseNoNeedBufferVarsInferer);
 REGISTER_OP_CPU_KERNEL(push_dense, ops::PushDenseCPUKernel<float>)
@@ -545,12 +545,12 @@ class Reshape2DoubleGradOp : public framework::OperatorWithKernel {
   }
 };
-DECLARE_INPLACE_OP_INFERER(ReshapeOpInplaceInToOut, {"X", "Out"});
-DECLARE_INPLACE_OP_INFERER(ReshapeGradInplaceInToOut,
+DECLARE_INPLACE_OP_INFERER(ReshapeOpInplaceInferer, {"X", "Out"});
+DECLARE_INPLACE_OP_INFERER(ReshapeGradInplaceInferer,
                            {framework::GradVarName("Out"),
                             framework::GradVarName("X")});
-DECLARE_INPLACE_OP_INFERER(ReshapeDoubleGradInplaceInToOut, {"DDX", "DDOut"});
-DECLARE_NO_NEED_BUFFER_VARS_INFERER(ReshapeDoubleGradOpNoNeedBufferVarInference,
+DECLARE_INPLACE_OP_INFERER(ReshapeDoubleGradInplaceInferer, {"DDX", "DDOut"});
+DECLARE_NO_NEED_BUFFER_VARS_INFERER(ReshapeDoubleGradOpNoNeedBufferVarInferer,
                                     "DOut");
 }  // namespace operators
@@ -562,9 +562,9 @@ REGISTER_OPERATOR(
     reshape, ops::ReshapeOp, ops::ReshapeOpMaker,
     paddle::framework::DefaultGradOpMaker<paddle::framework::OpDesc, true>,
     paddle::framework::DefaultGradOpMaker<paddle::imperative::OpBase, true>,
-    ops::ReshapeOpInplaceInToOut);
+    ops::ReshapeOpInplaceInferer);
 REGISTER_OPERATOR(reshape_grad, ops::ReshapeGradOp,
-                  ops::ReshapeGradInplaceInToOut);
+                  ops::ReshapeGradInplaceInferer);
 REGISTER_OP_CPU_KERNEL_FUNCTOR(reshape, float, ops::ReshapeKernel, double,
                                ops::ReshapeKernel, int, ops::ReshapeKernel,
@@ -576,14 +576,14 @@ REGISTER_OP_CPU_KERNEL_FUNCTOR(reshape_grad, float, ops::ReshapeGradKernel,
 REGISTER_OPERATOR(reshape2, ops::Reshape2Op, ops::Reshape2OpMaker,
                   ops::Reshape2GradMaker<paddle::framework::OpDesc>,
                   ops::Reshape2GradMaker<paddle::imperative::OpBase>,
-                  ops::ReshapeOpInplaceInToOut);
+                  ops::ReshapeOpInplaceInferer);
 REGISTER_OPERATOR(reshape2_grad, ops::Reshape2GradOp,
                   ops::Reshape2DoubleGradMaker<paddle::framework::OpDesc>,
                   ops::Reshape2DoubleGradMaker<paddle::imperative::OpBase>,
-                  ops::ReshapeGradInplaceInToOut);
+                  ops::ReshapeGradInplaceInferer);
 REGISTER_OPERATOR(reshape2_grad_grad, ops::Reshape2DoubleGradOp,
-                  ops::ReshapeDoubleGradInplaceInToOut,
-                  ops::ReshapeDoubleGradOpNoNeedBufferVarInference);
+                  ops::ReshapeDoubleGradInplaceInferer,
+                  ops::ReshapeDoubleGradOpNoNeedBufferVarInferer);
 REGISTER_OP_CPU_KERNEL_FUNCTOR(reshape2, float, ops::ReshapeKernel, double,
                                ops::ReshapeKernel, int8_t, ops::ReshapeKernel,
...
@@ -104,7 +104,7 @@ class ScaleGradMaker : public framework::SingleGradOpMaker<T> {
   }
 };
-DECLARE_INPLACE_OP_INFERER(ScaleOpInplace, {"X", "Out"});
+DECLARE_INPLACE_OP_INFERER(ScaleOpInplaceInferer, {"X", "Out"});
 }  // namespace operators
 }  // namespace paddle
@@ -113,7 +113,7 @@ namespace ops = paddle::operators;
 REGISTER_OPERATOR(scale, ops::ScaleOp, ops::ScaleOpMaker,
                   ops::ScaleGradMaker<paddle::framework::OpDesc>,
                   ops::ScaleGradMaker<paddle::imperative::OpBase>,
-                  ops::ScaleOpVarTypeInference, ops::ScaleOpInplace);
+                  ops::ScaleOpVarTypeInference, ops::ScaleOpInplaceInferer);
 REGISTER_OP_CPU_KERNEL(
     scale, ops::ScaleKernel<paddle::platform::CPUDeviceContext, float>,
     ops::ScaleKernel<paddle::platform::CPUDeviceContext, double>,
...
@@ -287,10 +287,10 @@ class SoftmaxGradMaker : public framework::SingleGradOpMaker<T> {
   }
 };
-DECLARE_INPLACE_OP_INFERER(SoftmaxWithCrossEntropyInplaceInference,
+DECLARE_INPLACE_OP_INFERER(SoftmaxWithCrossEntropyInplaceInferer,
                            {"Logits", "Softmax"});
-DECLARE_INPLACE_OP_INFERER(SoftmaxWithCrossEntropyGradInplaceInference,
+DECLARE_INPLACE_OP_INFERER(SoftmaxWithCrossEntropyGradInplaceInferer,
                            {"Softmax", framework::GradVarName("Logits")});
 }  // namespace operators
@@ -302,10 +302,10 @@ REGISTER_OPERATOR(softmax_with_cross_entropy, ops::SoftmaxWithCrossEntropyOp,
                   ops::SoftmaxWithCrossEntropyOpMaker,
                   ops::SoftmaxGradMaker<paddle::framework::OpDesc>,
                   ops::SoftmaxGradMaker<paddle::imperative::OpBase>,
-                  ops::SoftmaxWithCrossEntropyInplaceInference);
+                  ops::SoftmaxWithCrossEntropyInplaceInferer);
 REGISTER_OPERATOR(softmax_with_cross_entropy_grad,
                   ops::SoftmaxWithCrossEntropyOpGrad,
-                  ops::SoftmaxWithCrossEntropyGradInplaceInference);
+                  ops::SoftmaxWithCrossEntropyGradInplaceInferer);
 REGISTER_OP_CPU_KERNEL(softmax_with_cross_entropy,
                        ops::SoftmaxWithCrossEntropyKernel<float>,
                        ops::SoftmaxWithCrossEntropyKernel<double>);
...
@@ -299,7 +299,7 @@ class SumGradOpBaseMaker : public imperative::GradOpBaseMakerBase {
   }
 };
-DECLARE_INPLACE_OP_INFERER(SumInplace, {"X", "Out"});
+DECLARE_INPLACE_OP_INFERER(SumInplaceInferer, {"X", "Out"});
 }  // namespace operators
 }  // namespace paddle
@@ -308,7 +308,7 @@ namespace ops = paddle::operators;
 REGISTER_OPERATOR(sum, ops::SumOp, ops::SumOpMaker, ops::SumGradDescMaker,
                   ops::SumGradOpBaseMaker, ops::SumOpVarTypeInference,
-                  ops::SumInplace);
+                  ops::SumInplaceInferer);
 REGISTER_OP_CPU_KERNEL(
     sum, ops::SumKernel<paddle::platform::CPUDeviceContext, float>,
...