未验证 提交 a6fbba65 编写于 作者: L Leo Chen 提交者: GitHub

rename inplace/no_need_buffer inferer, part3, test=develop (#24734)

上级 4aa90990
......@@ -82,8 +82,7 @@ class ReduceMeanDoubleGradOpBaseMaker : public imperative::GradOpBaseMakerBase {
}
}
};
DECLARE_NO_NEED_BUFFER_VARS_INFERER(ReduceMeanGradNoNeedBufferVarInference,
"X");
DECLARE_NO_NEED_BUFFER_VARS_INFERER(ReduceMeanGradNoNeedBufferVarInferer, "X");
} // namespace operators
} // namespace paddle
......@@ -99,7 +98,7 @@ REGISTER_OPERATOR(reduce_mean, ops::ReduceOp, __reduce_meanMaker__,
REGISTER_OPERATOR(reduce_mean_grad, ops::ReduceGradOp,
ops::ReduceMeanDoubleGradDescMaker,
ops::ReduceMeanDoubleGradOpBaseMaker,
ops::ReduceMeanGradNoNeedBufferVarInference);
ops::ReduceMeanGradNoNeedBufferVarInferer);
REGISTER_OP_CPU_KERNEL(reduce_mean,
ops::ReduceKernel<paddle::platform::CPUDeviceContext,
float, ops::MeanFunctor>,
......
......@@ -51,7 +51,7 @@ class ReduceSumOpGradMaker : public framework::SingleGradOpMaker<T> {
}
};
DECLARE_NO_NEED_BUFFER_VARS_INFERER(ReduceSumGradNoNeedBufferVarInference, "X");
DECLARE_NO_NEED_BUFFER_VARS_INFERER(ReduceSumGradNoNeedBufferVarInferer, "X");
class ReduceSumVarTypeInference : public paddle::framework::VarTypeInference {
public:
void operator()(paddle::framework::InferVarTypeContext* ctx) const override {
......@@ -77,7 +77,7 @@ REGISTER_OPERATOR(reduce_sum, ops::ReduceOp, ReduceSumOpMaker,
ops::ReduceSumOpGradMaker<paddle::framework::OpDesc>,
ops::ReduceSumOpGradMaker<paddle::imperative::OpBase>);
REGISTER_OPERATOR(reduce_sum_grad, ops::ReduceGradOp,
ops::ReduceSumGradNoNeedBufferVarInference);
ops::ReduceSumGradNoNeedBufferVarInferer);
REGISTER_OP_CPU_KERNEL(
reduce_sum, ops::ReduceKernel<paddle::platform::CPUDeviceContext, float,
......
......@@ -123,8 +123,7 @@ class SeqConcatGradOp : public framework::OperatorWithKernel {
}
};
DECLARE_NO_NEED_BUFFER_VARS_INFERER(SeqConcatGradNoNeedBufferVarsInference,
"X");
DECLARE_NO_NEED_BUFFER_VARS_INFERER(SeqConcatGradNoNeedBufferVarsInferer, "X");
} // namespace operators
} // namespace paddle
......@@ -140,7 +139,7 @@ REGISTER_OP_CPU_KERNEL(sequence_concat, Kernel<float>, Kernel<double>,
Kernel<int>, Kernel<int64_t>);
REGISTER_OPERATOR(sequence_concat_grad, op::SeqConcatGradOp,
op::SeqConcatGradNoNeedBufferVarsInference);
op::SeqConcatGradNoNeedBufferVarsInferer);
template <typename T>
using GradKernel =
op::SeqConcatGradKernel<paddle::platform::CPUDeviceContext, T>;
......
......@@ -181,10 +181,10 @@ class SequenceExpandAsOpGradOpMaker : public framework::SingleGradOpMaker<T> {
}
};
DECLARE_NO_NEED_BUFFER_VARS_INFERER(SequenceExpandAsOpNoNeedBufferVarsInference,
DECLARE_NO_NEED_BUFFER_VARS_INFERER(SequenceExpandAsOpNoNeedBufferVarsInferer,
"Y");
DECLARE_NO_NEED_BUFFER_VARS_INFERER(
SequenceExpandAsGradOpNoNeedBufferVarsInference, "X", "Y");
SequenceExpandAsGradOpNoNeedBufferVarsInferer, "X", "Y");
} // namespace operators
} // namespace paddle
......@@ -194,9 +194,9 @@ REGISTER_OPERATOR(
sequence_expand_as, ops::SequenceExpandAsOp, ops::SequenceExpandAsOpMaker,
ops::SequenceExpandAsOpGradOpMaker<paddle::framework::OpDesc>,
ops::SequenceExpandAsOpGradOpMaker<paddle::imperative::OpBase>,
ops::SequenceExpandAsOpNoNeedBufferVarsInference);
ops::SequenceExpandAsOpNoNeedBufferVarsInferer);
REGISTER_OPERATOR(sequence_expand_as_grad, ops::SequenceExpandAsOpGrad,
ops::SequenceExpandAsGradOpNoNeedBufferVarsInference);
ops::SequenceExpandAsGradOpNoNeedBufferVarsInferer);
REGISTER_OP_CPU_KERNEL(
sequence_expand_as,
ops::SequenceExpandAsKernel<paddle::platform::CPUDeviceContext, float>,
......
......@@ -247,10 +247,10 @@ class SequenceExpandOpGradMaker : public framework::SingleGradOpMaker<T> {
}
};
DECLARE_NO_NEED_BUFFER_VARS_INFERER(SequenceExpandOpNoNeedBufferVarsInference,
DECLARE_NO_NEED_BUFFER_VARS_INFERER(SequenceExpandOpNoNeedBufferVarsInferer,
"Y");
DECLARE_NO_NEED_BUFFER_VARS_INFERER(
SequenceExpandGradOpNoNeedBufferVarsInference, "X", "Y");
DECLARE_NO_NEED_BUFFER_VARS_INFERER(SequenceExpandGradOpNoNeedBufferVarsInferer,
"X", "Y");
} // namespace operators
} // namespace paddle
......@@ -260,9 +260,9 @@ REGISTER_OPERATOR(sequence_expand, ops::SequenceExpandOp,
ops::SequenceExpandOpMaker,
ops::SequenceExpandOpGradMaker<paddle::framework::OpDesc>,
ops::SequenceExpandOpGradMaker<paddle::imperative::OpBase>,
ops::SequenceExpandOpNoNeedBufferVarsInference);
ops::SequenceExpandOpNoNeedBufferVarsInferer);
REGISTER_OPERATOR(sequence_expand_grad, ops::SequenceExpandOpGrad,
ops::SequenceExpandGradOpNoNeedBufferVarsInference);
ops::SequenceExpandGradOpNoNeedBufferVarsInferer);
REGISTER_OP_CPU_KERNEL(
sequence_expand,
ops::SequenceExpandKernel<paddle::platform::CPUDeviceContext, float>,
......
......@@ -251,7 +251,7 @@ class SequencePadGradOpMaker : public framework::SingleGradOpMaker<T> {
}
};
DECLARE_NO_NEED_BUFFER_VARS_INFERER(SequencePadGradOpNoNeedBufferVarsInference,
DECLARE_NO_NEED_BUFFER_VARS_INFERER(SequencePadGradOpNoNeedBufferVarsInferer,
"X");
} // namespace operators
......@@ -262,7 +262,7 @@ REGISTER_OPERATOR(sequence_pad, ops::SequencePadOp, ops::SequencePadOpMaker,
ops::SequencePadGradOpMaker<paddle::framework::OpDesc>,
ops::SequencePadGradOpMaker<paddle::imperative::OpBase>);
REGISTER_OPERATOR(sequence_pad_grad, ops::SequencePadGradOp,
ops::SequencePadGradOpNoNeedBufferVarsInference);
ops::SequencePadGradOpNoNeedBufferVarsInferer);
REGISTER_OP_CPU_KERNEL(
sequence_pad,
ops::SequencePadOpKernel<paddle::platform::CPUDeviceContext, float>,
......
......@@ -166,7 +166,7 @@ class SequencePoolGradOpMaker : public framework::SingleGradOpMaker<T> {
}
};
DECLARE_NO_NEED_BUFFER_VARS_INFERER(SequencePoolGradOpNoNeedBufferVarsInference,
DECLARE_NO_NEED_BUFFER_VARS_INFERER(SequencePoolGradOpNoNeedBufferVarsInferer,
"X");
} // namespace operators
......@@ -177,7 +177,7 @@ REGISTER_OPERATOR(sequence_pool, ops::SequencePoolOp, ops::SequencePoolOpMaker,
ops::SequencePoolGradOpMaker<paddle::framework::OpDesc>,
ops::SequencePoolGradOpMaker<paddle::imperative::OpBase>);
REGISTER_OPERATOR(sequence_pool_grad, ops::SequencePoolGradOp,
ops::SequencePoolGradOpNoNeedBufferVarsInference);
ops::SequencePoolGradOpNoNeedBufferVarsInferer);
REGISTER_OP_CPU_KERNEL(
sequence_pool,
ops::SequencePoolKernel<paddle::platform::CPUDeviceContext, float>);
......
......@@ -168,8 +168,8 @@ class SequenceScatterGradMaker : public framework::SingleGradOpMaker<T> {
}
};
DECLARE_NO_NEED_BUFFER_VARS_INFERER(
SequenceScatterGradNoNeedBufferVarsInference, "Updates");
DECLARE_NO_NEED_BUFFER_VARS_INFERER(SequenceScatterGradNoNeedBufferVarsInferer,
"Updates");
} // namespace operators
} // namespace paddle
......@@ -180,7 +180,7 @@ REGISTER_OPERATOR(sequence_scatter, ops::SequenceScatterOp,
ops::SequenceScatterGradMaker<paddle::framework::OpDesc>,
ops::SequenceScatterGradMaker<paddle::imperative::OpBase>);
REGISTER_OPERATOR(sequence_scatter_grad, ops::SequenceScatterGradOp,
ops::SequenceScatterGradNoNeedBufferVarsInference);
ops::SequenceScatterGradNoNeedBufferVarsInferer);
REGISTER_OP_CPU_KERNEL(sequence_scatter, ops::SequenceScatterOpKernel<float>,
ops::SequenceScatterOpKernel<double>,
ops::SequenceScatterOpKernel<int>,
......
......@@ -137,7 +137,7 @@ class SequenceSliceGradOpMaker : public framework::SingleGradOpMaker<T> {
}
};
DECLARE_NO_NEED_BUFFER_VARS_INFERER(SequenceSliceGradNoNeedBufferVarsInference,
DECLARE_NO_NEED_BUFFER_VARS_INFERER(SequenceSliceGradNoNeedBufferVarsInferer,
"X");
} // namespace operators
......@@ -149,7 +149,7 @@ REGISTER_OPERATOR(sequence_slice, ops::SequenceSliceOp,
ops::SequenceSliceGradOpMaker<paddle::framework::OpDesc>,
ops::SequenceSliceGradOpMaker<paddle::imperative::OpBase>);
REGISTER_OPERATOR(sequence_slice_grad, ops::SequenceSliceGradOp,
ops::SequenceSliceGradNoNeedBufferVarsInference);
ops::SequenceSliceGradNoNeedBufferVarsInferer);
REGISTER_OP_CPU_KERNEL(
sequence_slice,
ops::SequenceSliceOpKernel<paddle::platform::CPUDeviceContext, float>,
......
......@@ -169,8 +169,8 @@ class SequenceUnpadGradOpMaker : public framework::SingleGradOpMaker<T> {
}
};
DECLARE_NO_NEED_BUFFER_VARS_INFERER(
SequenceUnpadGradOpNoNeedBufferVarsInference, "X");
DECLARE_NO_NEED_BUFFER_VARS_INFERER(SequenceUnpadGradOpNoNeedBufferVarsInferer,
"X");
} // namespace operators
} // namespace paddle
......@@ -181,7 +181,7 @@ REGISTER_OPERATOR(sequence_unpad, ops::SequenceUnpadOp,
ops::SequenceUnpadGradOpMaker<paddle::framework::OpDesc>,
ops::SequenceUnpadGradOpMaker<paddle::imperative::OpBase>);
REGISTER_OPERATOR(sequence_unpad_grad, ops::SequenceUnpadGradOp,
ops::SequenceUnpadGradOpNoNeedBufferVarsInference);
ops::SequenceUnpadGradOpNoNeedBufferVarsInferer);
REGISTER_OP_CPU_KERNEL(
sequence_unpad,
ops::SequenceUnpadOpKernel<paddle::platform::CPUDeviceContext, float>,
......
......@@ -131,7 +131,7 @@ class SpaceToDepthOpMaker : public framework::OpProtoAndCheckerMaker {
}
};
DECLARE_NO_NEED_BUFFER_VARS_INFERER(SpaceToDepthGradOpNoBuffer, "X");
DECLARE_NO_NEED_BUFFER_VARS_INFERER(SpaceToDepthGradOpNoBufferInferer, "X");
template <typename T>
class SpaceToDepthGradOpMaker : public framework::SingleGradOpMaker<T> {
......@@ -179,7 +179,7 @@ REGISTER_OPERATOR(space_to_depth, ops::SpaceToDepthOp, ops::SpaceToDepthOpMaker,
ops::SpaceToDepthGradOpMaker<paddle::framework::OpDesc>,
ops::SpaceToDepthGradOpMaker<paddle::imperative::OpBase>);
REGISTER_OPERATOR(space_to_depth_grad, ops::SpaceToDepthGradOp,
ops::SpaceToDepthGradOpNoBuffer);
ops::SpaceToDepthGradOpNoBufferInferer);
REGISTER_OP_CPU_KERNEL(
space_to_depth,
ops::SpaceToDepthKernel<paddle::platform::CPUDeviceContext, float>,
......
......@@ -88,7 +88,8 @@ class SquaredL2DistanceOp : public framework::OperatorWithKernel {
}
};
DECLARE_NO_NEED_BUFFER_VARS_INFERER(SquaredL2DistanceGradOpNoBuffer, "X", "Y");
DECLARE_NO_NEED_BUFFER_VARS_INFERER(SquaredL2DistanceGradOpNoBufferInferer, "X",
"Y");
template <typename T>
class SquaredL2DistanceGradOpMaker : public framework::SingleGradOpMaker<T> {
......@@ -192,7 +193,7 @@ REGISTER_OPERATOR(
ops::SquaredL2DistanceGradOpMaker<paddle::framework::OpDesc>,
ops::SquaredL2DistanceGradOpMaker<paddle::imperative::OpBase>);
REGISTER_OPERATOR(squared_l2_distance_grad, ops::SquaredL2DistanceGradOp,
ops::SquaredL2DistanceGradOpNoBuffer);
ops::SquaredL2DistanceGradOpNoBufferInferer);
REGISTER_OP_CPU_KERNEL(
squared_l2_distance,
ops::SquaredL2DistanceKernel<paddle::platform::CPUDeviceContext, float>);
......
......@@ -275,7 +275,7 @@ DECLARE_INPLACE_OP_INFERER(SequeezeInplaceInferer, {"X", "Out"});
DECLARE_INPLACE_OP_INFERER(SequeezeGradInplaceInferer,
{framework::GradVarName("Out"),
framework::GradVarName("X")});
DECLARE_NO_NEED_BUFFER_VARS_INFERER(SqueezeGradNoNeedBufferVarsInference, "X");
DECLARE_NO_NEED_BUFFER_VARS_INFERER(SqueezeGradNoNeedBufferVarsInferer, "X");
} // namespace operators
} // namespace paddle
......@@ -284,7 +284,7 @@ REGISTER_OPERATOR(squeeze, ops::SqueezeOp, ops::SqueezeOpMaker,
ops::SqueezeGradOpMaker<paddle::framework::OpDesc>,
ops::SqueezeGradOpMaker<paddle::imperative::OpBase>);
REGISTER_OPERATOR(squeeze_grad, ops::SqueezeGradOp,
ops::SqueezeGradNoNeedBufferVarsInference);
ops::SqueezeGradNoNeedBufferVarsInferer);
REGISTER_OPERATOR(squeeze2, ops::Squeeze2Op, ops::Squeeze2OpMaker,
ops::Squeeze2GradOpMaker<paddle::framework::OpDesc>,
......
......@@ -304,7 +304,7 @@ class StridedSliceOpGradMaker : public framework::SingleGradOpMaker<T> {
}
};
DECLARE_NO_NEED_BUFFER_VARS_INFERER(StridedSliceOpGradNoNeedBufferVarsInference,
DECLARE_NO_NEED_BUFFER_VARS_INFERER(StridedSliceOpGradNoNeedBufferVarsInferer,
"Input");
} // namespace operators
......@@ -315,7 +315,7 @@ REGISTER_OPERATOR(strided_slice, ops::StridedSliceOp, ops::StridedSliceOpMaker,
ops::StridedSliceOpGradMaker<paddle::framework::OpDesc>,
ops::StridedSliceOpGradMaker<paddle::imperative::OpBase>);
REGISTER_OPERATOR(strided_slice_grad, ops::StridedSliceOpGrad,
ops::StridedSliceOpGradNoNeedBufferVarsInference);
ops::StridedSliceOpGradNoNeedBufferVarsInferer);
REGISTER_OP_CPU_KERNEL(
strided_slice,
......
......@@ -147,8 +147,7 @@ class TraceGradOpMaker : public framework::SingleGradOpMaker<T> {
}
};
DECLARE_NO_NEED_BUFFER_VARS_INFERER(TraceGradNoNeedBufferVarsInference,
"Input");
DECLARE_NO_NEED_BUFFER_VARS_INFERER(TraceGradNoNeedBufferVarsInferer, "Input");
} // namespace operators
} // namespace paddle
......@@ -159,7 +158,7 @@ REGISTER_OPERATOR(trace, ops::TraceOp, ops::TraceOpMaker,
ops::TraceGradOpMaker<paddle::imperative::OpBase>);
REGISTER_OPERATOR(trace_grad, ops::TraceOpGrad,
ops::TraceGradNoNeedBufferVarsInference);
ops::TraceGradNoNeedBufferVarsInferer);
REGISTER_OP_CPU_KERNEL(
trace, ops::TraceKernel<paddle::platform::CPUDeviceContext, int>,
ops::TraceKernel<paddle::platform::CPUDeviceContext, float>,
......
......@@ -174,7 +174,7 @@ class UnfoldGradMaker : public framework::SingleGradOpMaker<T> {
}
};
DECLARE_NO_NEED_BUFFER_VARS_INFERER(UnfoldGradOpNoNeedBufferVarsInference, "X");
DECLARE_NO_NEED_BUFFER_VARS_INFERER(UnfoldGradOpNoNeedBufferVarsInferer, "X");
} // namespace operators
} // namespace paddle
......@@ -184,7 +184,7 @@ REGISTER_OPERATOR(unfold, ops::UnfoldOp, ops::UnfoldOpMaker,
ops::UnfoldGradMaker<paddle::framework::OpDesc>,
ops::UnfoldGradMaker<paddle::imperative::OpBase>);
REGISTER_OPERATOR(unfold_grad, ops::UnfoldGradOp,
ops::UnfoldGradOpNoNeedBufferVarsInference);
ops::UnfoldGradOpNoNeedBufferVarsInferer);
REGISTER_OP_CPU_KERNEL(
unfold, ops::UnfoldOpKernel<paddle::platform::CPUDeviceContext, float>,
......
......@@ -306,8 +306,7 @@ DECLARE_INPLACE_OP_INFERER(UnsqueezeInplaceInferer, {"X", "Out"});
DECLARE_INPLACE_OP_INFERER(UnsqueezeGradInplaceInferer,
{framework::GradVarName("Out"),
framework::GradVarName("X")});
DECLARE_NO_NEED_BUFFER_VARS_INFERER(UnsqueezeGradOpNoNeedBufferVarInference,
"X");
DECLARE_NO_NEED_BUFFER_VARS_INFERER(UnsqueezeGradOpNoNeedBufferVarInferer, "X");
} // namespace operators
} // namespace paddle
......@@ -316,7 +315,7 @@ REGISTER_OPERATOR(unsqueeze, ops::UnsqueezeOp, ops::UnsqueezeOpMaker,
ops::UnsqueezeGradOpMaker<paddle::framework::OpDesc>,
ops::UnsqueezeGradOpMaker<paddle::imperative::OpBase>);
REGISTER_OPERATOR(unsqueeze_grad, ops::UnsqueezeGradOp,
ops::UnsqueezeGradOpNoNeedBufferVarInference);
ops::UnsqueezeGradOpNoNeedBufferVarInferer);
REGISTER_OPERATOR(unsqueeze2, ops::Unsqueeze2Op, ops::Unsqueeze2OpMaker,
ops::Unsqueeze2GradOpMaker<paddle::framework::OpDesc>,
......
......@@ -184,7 +184,7 @@ class WarpCTCGradOp : public framework::OperatorWithKernel {
}
};
DECLARE_NO_NEED_BUFFER_VARS_INFERER(WarpCTCGradOpNoNeedBufferVarInference,
DECLARE_NO_NEED_BUFFER_VARS_INFERER(WarpCTCGradOpNoNeedBufferVarInferer,
"Logits");
} // namespace operators
......@@ -195,7 +195,7 @@ REGISTER_OPERATOR(warpctc, ops::WarpCTCOp, ops::WarpCTCOpMaker,
ops::WarpCTCGradOpMaker<paddle::framework::OpDesc>,
ops::WarpCTCGradOpMaker<paddle::imperative::OpBase>);
REGISTER_OPERATOR(warpctc_grad, ops::WarpCTCGradOp,
ops::WarpCTCGradOpNoNeedBufferVarInference);
ops::WarpCTCGradOpNoNeedBufferVarInferer);
REGISTER_OP_CPU_KERNEL(
warpctc, ops::WarpCTCKernel<paddle::platform::CPUDeviceContext, float>);
REGISTER_OP_CPU_KERNEL(
......
......@@ -135,8 +135,7 @@ class WhereOpGradMaker : public framework::SingleGradOpMaker<T> {
}
};
DECLARE_NO_NEED_BUFFER_VARS_INFERER(WhereGradNoNeedBufferVarsInference, "X",
"Y");
DECLARE_NO_NEED_BUFFER_VARS_INFERER(WhereGradNoNeedBufferVarsInferer, "X", "Y");
} // namespace operators
} // namespace paddle
......@@ -146,7 +145,7 @@ REGISTER_OPERATOR(where, ops::WhereOp, ops::WhereOpMaker,
ops::WhereOpGradMaker<paddle::imperative::OpBase>);
REGISTER_OPERATOR(where_grad, ops::WhereGradOp,
ops::WhereGradNoNeedBufferVarsInference);
ops::WhereGradNoNeedBufferVarsInferer);
REGISTER_OP_CPU_KERNEL(
where, ops::WhereKernel<paddle::platform::CPUDeviceContext, float>,
ops::WhereKernel<paddle::platform::CPUDeviceContext, double>,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册