未验证 提交 de8b4f42 编写于 作者: L Leo Chen 提交者: GitHub

rename inplace/no_need_buffer inferer, part 1, test=develop (#24711)

上级 33fc690f
......@@ -257,7 +257,7 @@ class HierarchicalSigmoidGradOpGradVarTypeInference
};
DECLARE_NO_NEED_BUFFER_VARS_INFERER(
HierarchicalSigmoidGradOpNoNeedBufferVarInference, "Bias");
HierarchicalSigmoidGradOpNoNeedBufferVarInferer, "Bias");
} // namespace operators
} // namespace paddle
......@@ -270,7 +270,7 @@ REGISTER_OPERATOR(
ops::HierarchicalSigmoidGradMaker<paddle::imperative::OpBase>);
REGISTER_OPERATOR(hierarchical_sigmoid_grad, ops::HierarchicalSigmoidGradOp,
ops::HierarchicalSigmoidGradOpGradVarTypeInference,
ops::HierarchicalSigmoidGradOpNoNeedBufferVarInference);
ops::HierarchicalSigmoidGradOpNoNeedBufferVarInferer);
REGISTER_OP_CPU_KERNEL(
hierarchical_sigmoid,
ops::HierarchicalSigmoidOpKernel<paddle::platform::CPUDeviceContext, float>,
......
......@@ -138,7 +138,7 @@ class IndexSelectGradMaker : public framework::SingleGradOpMaker<T> {
}
};
DECLARE_NO_NEED_BUFFER_VARS_INFERER(IndexSelectGradNoNeedBufferVarsInference,
DECLARE_NO_NEED_BUFFER_VARS_INFERER(IndexSelectGradNoNeedBufferVarsInferer,
"X");
} // namespace operators
} // namespace paddle
......@@ -148,7 +148,7 @@ REGISTER_OPERATOR(index_select, ops::IndexSelectOp, ops::IndexSelectOpMaker,
ops::IndexSelectGradMaker<paddle::framework::OpDesc>,
ops::IndexSelectGradMaker<paddle::imperative::OpBase>);
REGISTER_OPERATOR(index_select_grad, ops::IndexSelectGradOp,
ops::IndexSelectGradNoNeedBufferVarsInference);
ops::IndexSelectGradNoNeedBufferVarsInferer);
REGISTER_OP_CPU_KERNEL(
index_select,
ops::IndexSelectKernel<paddle::platform::CPUDeviceContext, float>,
......
......@@ -603,7 +603,7 @@ class InstanceNormDoubleGradKernel<platform::CPUDeviceContext, T>
}
};
DECLARE_INPLACE_OP_INFERER(InstanceNormDoubleGradOpInplaceInference,
DECLARE_INPLACE_OP_INFERER(InstanceNormDoubleGradOpInplaceInferer,
{"DY", "DDY"});
} // namespace operators
......@@ -618,7 +618,7 @@ REGISTER_OPERATOR(instance_norm_grad, ops::InstanceNormGradOp,
ops::InstanceNormDoubleGradMaker<paddle::framework::OpDesc>,
ops::InstanceNormDoubleGradMaker<paddle::imperative::OpBase>);
REGISTER_OPERATOR(instance_norm_grad_grad, ops::InstanceNormDoubleGradOp,
ops::InstanceNormDoubleGradOpInplaceInference);
ops::InstanceNormDoubleGradOpInplaceInferer);
REGISTER_OP_CPU_KERNEL(
instance_norm,
......
......@@ -585,7 +585,7 @@ class InterpolateGradMaker : public framework::SingleGradOpMaker<T> {
}
};
DECLARE_NO_NEED_BUFFER_VARS_INFERER(InterpolateGradNoNeedBufferVarsInference,
DECLARE_NO_NEED_BUFFER_VARS_INFERER(InterpolateGradNoNeedBufferVarsInferer,
"X");
} // namespace operators
......@@ -596,22 +596,22 @@ REGISTER_OPERATOR(bilinear_interp, ops::InterpolateOp, ops::InterpolateOpMaker,
ops::InterpolateGradMaker<paddle::framework::OpDesc>,
ops::InterpolateGradMaker<paddle::imperative::OpBase>);
REGISTER_OPERATOR(bilinear_interp_grad, ops::InterpolateOpGrad,
ops::InterpolateGradNoNeedBufferVarsInference);
ops::InterpolateGradNoNeedBufferVarsInferer);
REGISTER_OPERATOR(nearest_interp, ops::InterpolateOp, ops::InterpolateOpMaker,
ops::InterpolateGradMaker<paddle::framework::OpDesc>,
ops::InterpolateGradMaker<paddle::imperative::OpBase>);
REGISTER_OPERATOR(nearest_interp_grad, ops::InterpolateOpGrad,
ops::InterpolateGradNoNeedBufferVarsInference);
ops::InterpolateGradNoNeedBufferVarsInferer);
REGISTER_OPERATOR(trilinear_interp, ops::InterpolateOp, ops::InterpolateOpMaker,
ops::InterpolateGradMaker<paddle::framework::OpDesc>,
ops::InterpolateGradMaker<paddle::imperative::OpBase>);
REGISTER_OPERATOR(trilinear_interp_grad, ops::InterpolateOpGrad,
ops::InterpolateGradNoNeedBufferVarsInference);
ops::InterpolateGradNoNeedBufferVarsInferer);
REGISTER_OPERATOR(bicubic_interp, ops::InterpolateOp, ops::InterpolateOpMaker,
ops::InterpolateGradMaker<paddle::framework::OpDesc>,
ops::InterpolateGradMaker<paddle::imperative::OpBase>);
REGISTER_OPERATOR(bicubic_interp_grad, ops::InterpolateOpGrad,
ops::InterpolateGradNoNeedBufferVarsInference);
ops::InterpolateGradNoNeedBufferVarsInferer);
REGISTER_OP_CPU_KERNEL(bilinear_interp, ops::InterpolateKernel<float>,
ops::InterpolateKernel<double>,
ops::InterpolateKernel<uint8_t>);
......@@ -631,7 +631,7 @@ REGISTER_OPERATOR(linear_interp, ops::InterpolateOp, ops::InterpolateOpMaker,
ops::InterpolateGradMaker<paddle::framework::OpDesc>,
ops::InterpolateGradMaker<paddle::imperative::OpBase>);
REGISTER_OPERATOR(linear_interp_grad, ops::InterpolateOpGrad,
ops::InterpolateGradNoNeedBufferVarsInference);
ops::InterpolateGradNoNeedBufferVarsInferer);
REGISTER_OP_CPU_KERNEL(linear_interp, ops::InterpolateKernel<float>,
ops::InterpolateKernel<double>,
ops::InterpolateKernel<uint8_t>);
......
......@@ -166,7 +166,7 @@ class KLDivLossOpGradMaker : public framework::SingleGradOpMaker<T> {
}
};
DECLARE_NO_NEED_BUFFER_VARS_INFERER(KLDivLossGradNoNeedBufferVarInference, "X");
DECLARE_NO_NEED_BUFFER_VARS_INFERER(KLDivLossGradNoNeedBufferVarInferer, "X");
} // namespace operators
} // namespace paddle
......@@ -176,7 +176,7 @@ REGISTER_OPERATOR(kldiv_loss, ops::KLDivLossOp, ops::KLDivLossOpMaker,
ops::KLDivLossOpGradMaker<paddle::framework::OpDesc>,
ops::KLDivLossOpGradMaker<paddle::imperative::OpBase>);
REGISTER_OPERATOR(kldiv_loss_grad, ops::KLDivLossOpGrad,
ops::KLDivLossGradNoNeedBufferVarInference);
ops::KLDivLossGradNoNeedBufferVarInferer);
REGISTER_OP_CPU_KERNEL(
kldiv_loss, ops::KLDivLossKernel<paddle::platform::CPUDeviceContext, float>,
ops::KLDivLossKernel<paddle::platform::CPUDeviceContext, double>);
......
......@@ -220,7 +220,7 @@ class LayerNormGradOpMaker : public framework::SingleGradOpMaker<T> {
}
};
DECLARE_NO_NEED_BUFFER_VARS_INFERER(LayerNormGradNoNeedBufferVarInference,
DECLARE_NO_NEED_BUFFER_VARS_INFERER(LayerNormGradNoNeedBufferVarInferer,
"Bias");
} // namespace operators
......@@ -231,7 +231,7 @@ REGISTER_OPERATOR(layer_norm, ops::LayerNormOp, ops::LayerNormOpMaker,
ops::LayerNormGradOpMaker<paddle::framework::OpDesc>,
ops::LayerNormGradOpMaker<paddle::imperative::OpBase>);
REGISTER_OPERATOR(layer_norm_grad, ops::LayerNormGradOp,
ops::LayerNormGradNoNeedBufferVarInference);
ops::LayerNormGradNoNeedBufferVarInferer);
REGISTER_OP_CPU_KERNEL(
layer_norm, ops::LayerNormKernel<paddle::platform::CPUDeviceContext, float>,
ops::LayerNormKernel<paddle::platform::CPUDeviceContext, double>);
......
......@@ -345,7 +345,7 @@ class LinearChainCRFGradMaker : public framework::SingleGradOpMaker<T> {
}
};
DECLARE_NO_NEED_BUFFER_VARS_INFERER(LinearChainCRFGradNoNeedBufferVarsInference,
DECLARE_NO_NEED_BUFFER_VARS_INFERER(LinearChainCRFGradNoNeedBufferVarsInferer,
"Transition", "Emission");
} // namespace operators
......@@ -357,7 +357,7 @@ REGISTER_OPERATOR(linear_chain_crf, ops::LinearChainCRFOp,
ops::LinearChainCRFGradMaker<paddle::framework::OpDesc>,
ops::LinearChainCRFGradMaker<paddle::imperative::OpBase>);
REGISTER_OPERATOR(linear_chain_crf_grad, ops::LinearChainCRFGradOp,
ops::LinearChainCRFGradNoNeedBufferVarsInference);
ops::LinearChainCRFGradNoNeedBufferVarsInferer);
REGISTER_OP_CPU_KERNEL(
linear_chain_crf,
ops::LinearChainCRFOpKernel<paddle::platform::CPUDeviceContext, float>,
......
......@@ -223,7 +223,7 @@ DECLARE_INPLACE_OP_INFERER(LoDResetGradInplaceInferer,
{framework::GradVarName("Out"),
framework::GradVarName("X")});
DECLARE_NO_NEED_BUFFER_VARS_INFERER(LoDResetGradNoNeedBufferVarInference, "X");
DECLARE_NO_NEED_BUFFER_VARS_INFERER(LoDResetGradNoNeedBufferVarInferer, "X");
} // namespace operators
} // namespace paddle
......@@ -234,7 +234,7 @@ REGISTER_OPERATOR(lod_reset, ops::LoDResetOp, ops::LoDResetOpMaker,
ops::LoDResetGradMaker<paddle::imperative::OpBase>,
ops::LoDResetOpVarTypeInference, ops::LoDResetInplaceInferer);
REGISTER_OPERATOR(lod_reset_grad, ops::LoDResetGradOp,
ops::LoDResetGradNoNeedBufferVarInference,
ops::LoDResetGradNoNeedBufferVarInferer,
ops::LoDResetGradInplaceInferer);
REGISTER_OP_CPU_KERNEL(
......
......@@ -130,7 +130,7 @@ or not. And the output only shares the LoD information with input Ids.
}
};
DECLARE_NO_NEED_BUFFER_VARS_INFERER(LookupTableGradOpNoBuffer, "W");
DECLARE_NO_NEED_BUFFER_VARS_INFERER(LookupTableGradOpNoBufferVarsInferer, "W");
template <typename T>
class LookupTableGradOpMaker : public framework::SingleGradOpMaker<T> {
......@@ -198,7 +198,7 @@ REGISTER_OPERATOR(lookup_table, ops::LookupTableOp, ops::LookupTableOpMaker,
ops::LookupTableGradOpMaker<paddle::imperative::OpBase>);
REGISTER_OPERATOR(lookup_table_grad, ops::LookupTableOpGrad,
ops::LookupTableGradOpNoBuffer,
ops::LookupTableGradOpNoBufferVarsInferer,
ops::LookupTableOpGradVarTypeInference);
REGISTER_OP_CPU_KERNEL(lookup_table, ops::LookupTableKernel<float>,
......
......@@ -118,7 +118,8 @@ or not. And the output only shares the LoD information with input Ids.
}
};
DECLARE_NO_NEED_BUFFER_VARS_INFERER(LookupTableV2GradOpNoBuffer, "W");
DECLARE_NO_NEED_BUFFER_VARS_INFERER(LookupTableV2GradOpNoBufferVarsInferer,
"W");
template <typename T>
class LookupTableV2GradOpMaker : public framework::SingleGradOpMaker<T> {
......@@ -187,7 +188,7 @@ REGISTER_OPERATOR(lookup_table_v2, ops::LookupTableV2Op,
ops::LookupTableV2GradOpMaker<paddle::imperative::OpBase>);
REGISTER_OPERATOR(lookup_table_v2_grad, ops::LookupTableV2OpGrad,
ops::LookupTableV2GradOpNoBuffer,
ops::LookupTableV2GradOpNoBufferVarsInferer,
ops::LookupTableV2OpGradVarTypeInference);
REGISTER_OP_CPU_KERNEL(lookup_table_v2, ops::LookupTableV2Kernel<float>,
......
......@@ -83,7 +83,7 @@ class MeanGradMaker : public framework::SingleGradOpMaker<T> {
}
};
DECLARE_NO_NEED_BUFFER_VARS_INFERER(MeanGradNoNeedBufferVarsInference, "X");
DECLARE_NO_NEED_BUFFER_VARS_INFERER(MeanGradNoNeedBufferVarsInferer, "X");
} // namespace operators
} // namespace paddle
......@@ -93,7 +93,7 @@ REGISTER_OPERATOR(mean, ops::MeanOp, ops::MeanOpMaker, ops::MeanOpInferVarType,
ops::MeanGradMaker<paddle::framework::OpDesc>,
ops::MeanGradMaker<paddle::imperative::OpBase>);
REGISTER_OPERATOR(mean_grad, ops::MeanGradOp,
ops::MeanGradNoNeedBufferVarsInference);
ops::MeanGradNoNeedBufferVarsInferer);
REGISTER_OP_CPU_KERNEL(
mean, ops::MeanKernel<paddle::platform::CPUDeviceContext, float>,
ops::MeanKernel<paddle::platform::CPUDeviceContext, double>);
......
......@@ -307,7 +307,7 @@ class NCEOpGradVarTypeInference : public framework::VarTypeInference {
}
};
DECLARE_NO_NEED_BUFFER_VARS_INFERER(NCEGradOpNoNeedBufferVarInference, "Bias");
DECLARE_NO_NEED_BUFFER_VARS_INFERER(NCEGradOpNoNeedBufferVarInferer, "Bias");
} // namespace operators
} // namespace paddle
......@@ -317,7 +317,7 @@ REGISTER_OPERATOR(nce, ops::NCEOp, ops::NCEOpMaker,
ops::NCEGradOpMaker<paddle::framework::OpDesc>,
ops::NCEGradOpMaker<paddle::imperative::OpBase>);
REGISTER_OPERATOR(nce_grad, ops::NCEOpGrad, ops::NCEOpGradVarTypeInference,
ops::NCEGradOpNoNeedBufferVarInference);
ops::NCEGradOpNoNeedBufferVarInferer);
REGISTER_OP_CPU_KERNEL(nce, ops::NCEKernel<paddle::platform::CPUPlace, float>,
ops::NCEKernel<paddle::platform::CPUPlace, double>);
REGISTER_OP_CPU_KERNEL(nce_grad,
......
......@@ -656,7 +656,7 @@ class Pad2dOpGradMaker : public framework::SingleGradOpMaker<T> {
};
// TODO(zjl): Paddings can also be skipped!
DECLARE_NO_NEED_BUFFER_VARS_INFERER(Pad2dOpGradNoNeedBufferVarsInference, "X");
DECLARE_NO_NEED_BUFFER_VARS_INFERER(Pad2dOpGradNoNeedBufferVarsInferer, "X");
} // namespace operators
} // namespace paddle
......@@ -667,7 +667,7 @@ REGISTER_OPERATOR(pad2d, ops::Pad2dOp, ops::Pad2dOpMaker,
ops::Pad2dOpGradMaker<paddle::framework::OpDesc>,
ops::Pad2dOpGradMaker<paddle::imperative::OpBase>);
REGISTER_OPERATOR(pad2d_grad, ops::Pad2dOpGrad,
ops::Pad2dOpGradNoNeedBufferVarsInference);
ops::Pad2dOpGradNoNeedBufferVarsInferer);
REGISTER_OP_CPU_KERNEL(pad2d, ops::Pad2dCPUKernel<float>,
ops::Pad2dCPUKernel<double>, ops::Pad2dCPUKernel<int>,
ops::Pad2dCPUKernel<int64_t>);
......
......@@ -316,7 +316,7 @@ class MaxPoolWithIndexGradOpMaker : public framework::SingleGradOpMaker<T> {
};
DECLARE_NO_NEED_BUFFER_VARS_INFERER(
MaxPoolWithIndexOpGradNoNeedBufferVarsInference, "X");
MaxPoolWithIndexOpGradNoNeedBufferVarsInferer, "X");
} // namespace operators
} // namespace paddle
......@@ -328,7 +328,7 @@ REGISTER_OPERATOR(max_pool2d_with_index, ops::MaxPoolWithIndexOp,
ops::MaxPoolWithIndexGradOpMaker<paddle::framework::OpDesc>,
ops::MaxPoolWithIndexGradOpMaker<paddle::imperative::OpBase>);
REGISTER_OPERATOR(max_pool2d_with_index_grad, ops::MaxPoolWithIndexOpGrad,
ops::MaxPoolWithIndexOpGradNoNeedBufferVarsInference);
ops::MaxPoolWithIndexOpGradNoNeedBufferVarsInferer);
REGISTER_OP_CPU_KERNEL(
max_pool2d_with_index,
......@@ -347,7 +347,7 @@ REGISTER_OPERATOR(max_pool3d_with_index, ops::MaxPoolWithIndexOp,
ops::MaxPoolWithIndexGradOpMaker<paddle::framework::OpDesc>,
ops::MaxPoolWithIndexGradOpMaker<paddle::imperative::OpBase>);
REGISTER_OPERATOR(max_pool3d_with_index_grad, ops::MaxPoolWithIndexOpGrad,
ops::MaxPoolWithIndexOpGradNoNeedBufferVarsInference);
ops::MaxPoolWithIndexOpGradNoNeedBufferVarsInferer);
REGISTER_OP_CPU_KERNEL(
max_pool3d_with_index,
......
......@@ -56,7 +56,7 @@ The input gradients is all dense gradient tensors in a table.
}
};
DECLARE_NO_NEED_BUFFER_VARS_INFERER(PushDenseNoNeedBufferVarsInference, "Ids");
DECLARE_NO_NEED_BUFFER_VARS_INFERER(PushDenseNoNeedBufferVarsInferer, "Ids");
} // namespace operators
} // namespace paddle
......@@ -66,5 +66,5 @@ REGISTER_OPERATOR(
push_dense, ops::PushDenseOp, ops::PushDenseOpMaker,
paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>,
ops::PushDenseNoNeedBufferVarsInference);
ops::PushDenseNoNeedBufferVarsInferer);
REGISTER_OP_CPU_KERNEL(push_dense, ops::PushDenseCPUKernel<float>)
......@@ -545,12 +545,12 @@ class Reshape2DoubleGradOp : public framework::OperatorWithKernel {
}
};
DECLARE_INPLACE_OP_INFERER(ReshapeOpInplaceInToOut, {"X", "Out"});
DECLARE_INPLACE_OP_INFERER(ReshapeGradInplaceInToOut,
DECLARE_INPLACE_OP_INFERER(ReshapeOpInplaceInferer, {"X", "Out"});
DECLARE_INPLACE_OP_INFERER(ReshapeGradInplaceInferer,
{framework::GradVarName("Out"),
framework::GradVarName("X")});
DECLARE_INPLACE_OP_INFERER(ReshapeDoubleGradInplaceInToOut, {"DDX", "DDOut"});
DECLARE_NO_NEED_BUFFER_VARS_INFERER(ReshapeDoubleGradOpNoNeedBufferVarInference,
DECLARE_INPLACE_OP_INFERER(ReshapeDoubleGradInplaceInferer, {"DDX", "DDOut"});
DECLARE_NO_NEED_BUFFER_VARS_INFERER(ReshapeDoubleGradOpNoNeedBufferVarInferer,
"DOut");
} // namespace operators
......@@ -562,9 +562,9 @@ REGISTER_OPERATOR(
reshape, ops::ReshapeOp, ops::ReshapeOpMaker,
paddle::framework::DefaultGradOpMaker<paddle::framework::OpDesc, true>,
paddle::framework::DefaultGradOpMaker<paddle::imperative::OpBase, true>,
ops::ReshapeOpInplaceInToOut);
ops::ReshapeOpInplaceInferer);
REGISTER_OPERATOR(reshape_grad, ops::ReshapeGradOp,
ops::ReshapeGradInplaceInToOut);
ops::ReshapeGradInplaceInferer);
REGISTER_OP_CPU_KERNEL_FUNCTOR(reshape, float, ops::ReshapeKernel, double,
ops::ReshapeKernel, int, ops::ReshapeKernel,
......@@ -576,14 +576,14 @@ REGISTER_OP_CPU_KERNEL_FUNCTOR(reshape_grad, float, ops::ReshapeGradKernel,
REGISTER_OPERATOR(reshape2, ops::Reshape2Op, ops::Reshape2OpMaker,
ops::Reshape2GradMaker<paddle::framework::OpDesc>,
ops::Reshape2GradMaker<paddle::imperative::OpBase>,
ops::ReshapeOpInplaceInToOut);
ops::ReshapeOpInplaceInferer);
REGISTER_OPERATOR(reshape2_grad, ops::Reshape2GradOp,
ops::Reshape2DoubleGradMaker<paddle::framework::OpDesc>,
ops::Reshape2DoubleGradMaker<paddle::imperative::OpBase>,
ops::ReshapeGradInplaceInToOut);
ops::ReshapeGradInplaceInferer);
REGISTER_OPERATOR(reshape2_grad_grad, ops::Reshape2DoubleGradOp,
ops::ReshapeDoubleGradInplaceInToOut,
ops::ReshapeDoubleGradOpNoNeedBufferVarInference);
ops::ReshapeDoubleGradInplaceInferer,
ops::ReshapeDoubleGradOpNoNeedBufferVarInferer);
REGISTER_OP_CPU_KERNEL_FUNCTOR(reshape2, float, ops::ReshapeKernel, double,
ops::ReshapeKernel, int8_t, ops::ReshapeKernel,
......
......@@ -104,7 +104,7 @@ class ScaleGradMaker : public framework::SingleGradOpMaker<T> {
}
};
DECLARE_INPLACE_OP_INFERER(ScaleOpInplace, {"X", "Out"});
DECLARE_INPLACE_OP_INFERER(ScaleOpInplaceInferer, {"X", "Out"});
} // namespace operators
} // namespace paddle
......@@ -113,7 +113,7 @@ namespace ops = paddle::operators;
REGISTER_OPERATOR(scale, ops::ScaleOp, ops::ScaleOpMaker,
ops::ScaleGradMaker<paddle::framework::OpDesc>,
ops::ScaleGradMaker<paddle::imperative::OpBase>,
ops::ScaleOpVarTypeInference, ops::ScaleOpInplace);
ops::ScaleOpVarTypeInference, ops::ScaleOpInplaceInferer);
REGISTER_OP_CPU_KERNEL(
scale, ops::ScaleKernel<paddle::platform::CPUDeviceContext, float>,
ops::ScaleKernel<paddle::platform::CPUDeviceContext, double>,
......
......@@ -287,10 +287,10 @@ class SoftmaxGradMaker : public framework::SingleGradOpMaker<T> {
}
};
DECLARE_INPLACE_OP_INFERER(SoftmaxWithCrossEntropyInplaceInference,
DECLARE_INPLACE_OP_INFERER(SoftmaxWithCrossEntropyInplaceInferer,
{"Logits", "Softmax"});
DECLARE_INPLACE_OP_INFERER(SoftmaxWithCrossEntropyGradInplaceInference,
DECLARE_INPLACE_OP_INFERER(SoftmaxWithCrossEntropyGradInplaceInferer,
{"Softmax", framework::GradVarName("Logits")});
} // namespace operators
......@@ -302,10 +302,10 @@ REGISTER_OPERATOR(softmax_with_cross_entropy, ops::SoftmaxWithCrossEntropyOp,
ops::SoftmaxWithCrossEntropyOpMaker,
ops::SoftmaxGradMaker<paddle::framework::OpDesc>,
ops::SoftmaxGradMaker<paddle::imperative::OpBase>,
ops::SoftmaxWithCrossEntropyInplaceInference);
ops::SoftmaxWithCrossEntropyInplaceInferer);
REGISTER_OPERATOR(softmax_with_cross_entropy_grad,
ops::SoftmaxWithCrossEntropyOpGrad,
ops::SoftmaxWithCrossEntropyGradInplaceInference);
ops::SoftmaxWithCrossEntropyGradInplaceInferer);
REGISTER_OP_CPU_KERNEL(softmax_with_cross_entropy,
ops::SoftmaxWithCrossEntropyKernel<float>,
ops::SoftmaxWithCrossEntropyKernel<double>);
......
......@@ -299,7 +299,7 @@ class SumGradOpBaseMaker : public imperative::GradOpBaseMakerBase {
}
};
DECLARE_INPLACE_OP_INFERER(SumInplace, {"X", "Out"});
DECLARE_INPLACE_OP_INFERER(SumInplaceInferer, {"X", "Out"});
} // namespace operators
} // namespace paddle
......@@ -308,7 +308,7 @@ namespace ops = paddle::operators;
REGISTER_OPERATOR(sum, ops::SumOp, ops::SumOpMaker, ops::SumGradDescMaker,
ops::SumGradOpBaseMaker, ops::SumOpVarTypeInference,
ops::SumInplace);
ops::SumInplaceInferer);
REGISTER_OP_CPU_KERNEL(
sum, ops::SumKernel<paddle::platform::CPUDeviceContext, float>,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册