未验证 提交 a6fbba65 编写于 作者: L Leo Chen 提交者: GitHub

rename inplace/no_need_buffer inferer, part3, test=develop (#24734)

上级 4aa90990
...@@ -82,8 +82,7 @@ class ReduceMeanDoubleGradOpBaseMaker : public imperative::GradOpBaseMakerBase { ...@@ -82,8 +82,7 @@ class ReduceMeanDoubleGradOpBaseMaker : public imperative::GradOpBaseMakerBase {
} }
} }
}; };
DECLARE_NO_NEED_BUFFER_VARS_INFERER(ReduceMeanGradNoNeedBufferVarInference, DECLARE_NO_NEED_BUFFER_VARS_INFERER(ReduceMeanGradNoNeedBufferVarInferer, "X");
"X");
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle
...@@ -99,7 +98,7 @@ REGISTER_OPERATOR(reduce_mean, ops::ReduceOp, __reduce_meanMaker__, ...@@ -99,7 +98,7 @@ REGISTER_OPERATOR(reduce_mean, ops::ReduceOp, __reduce_meanMaker__,
REGISTER_OPERATOR(reduce_mean_grad, ops::ReduceGradOp, REGISTER_OPERATOR(reduce_mean_grad, ops::ReduceGradOp,
ops::ReduceMeanDoubleGradDescMaker, ops::ReduceMeanDoubleGradDescMaker,
ops::ReduceMeanDoubleGradOpBaseMaker, ops::ReduceMeanDoubleGradOpBaseMaker,
ops::ReduceMeanGradNoNeedBufferVarInference); ops::ReduceMeanGradNoNeedBufferVarInferer);
REGISTER_OP_CPU_KERNEL(reduce_mean, REGISTER_OP_CPU_KERNEL(reduce_mean,
ops::ReduceKernel<paddle::platform::CPUDeviceContext, ops::ReduceKernel<paddle::platform::CPUDeviceContext,
float, ops::MeanFunctor>, float, ops::MeanFunctor>,
......
...@@ -51,7 +51,7 @@ class ReduceSumOpGradMaker : public framework::SingleGradOpMaker<T> { ...@@ -51,7 +51,7 @@ class ReduceSumOpGradMaker : public framework::SingleGradOpMaker<T> {
} }
}; };
DECLARE_NO_NEED_BUFFER_VARS_INFERER(ReduceSumGradNoNeedBufferVarInference, "X"); DECLARE_NO_NEED_BUFFER_VARS_INFERER(ReduceSumGradNoNeedBufferVarInferer, "X");
class ReduceSumVarTypeInference : public paddle::framework::VarTypeInference { class ReduceSumVarTypeInference : public paddle::framework::VarTypeInference {
public: public:
void operator()(paddle::framework::InferVarTypeContext* ctx) const override { void operator()(paddle::framework::InferVarTypeContext* ctx) const override {
...@@ -77,7 +77,7 @@ REGISTER_OPERATOR(reduce_sum, ops::ReduceOp, ReduceSumOpMaker, ...@@ -77,7 +77,7 @@ REGISTER_OPERATOR(reduce_sum, ops::ReduceOp, ReduceSumOpMaker,
ops::ReduceSumOpGradMaker<paddle::framework::OpDesc>, ops::ReduceSumOpGradMaker<paddle::framework::OpDesc>,
ops::ReduceSumOpGradMaker<paddle::imperative::OpBase>); ops::ReduceSumOpGradMaker<paddle::imperative::OpBase>);
REGISTER_OPERATOR(reduce_sum_grad, ops::ReduceGradOp, REGISTER_OPERATOR(reduce_sum_grad, ops::ReduceGradOp,
ops::ReduceSumGradNoNeedBufferVarInference); ops::ReduceSumGradNoNeedBufferVarInferer);
REGISTER_OP_CPU_KERNEL( REGISTER_OP_CPU_KERNEL(
reduce_sum, ops::ReduceKernel<paddle::platform::CPUDeviceContext, float, reduce_sum, ops::ReduceKernel<paddle::platform::CPUDeviceContext, float,
......
...@@ -123,8 +123,7 @@ class SeqConcatGradOp : public framework::OperatorWithKernel { ...@@ -123,8 +123,7 @@ class SeqConcatGradOp : public framework::OperatorWithKernel {
} }
}; };
DECLARE_NO_NEED_BUFFER_VARS_INFERER(SeqConcatGradNoNeedBufferVarsInference, DECLARE_NO_NEED_BUFFER_VARS_INFERER(SeqConcatGradNoNeedBufferVarsInferer, "X");
"X");
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle
...@@ -140,7 +139,7 @@ REGISTER_OP_CPU_KERNEL(sequence_concat, Kernel<float>, Kernel<double>, ...@@ -140,7 +139,7 @@ REGISTER_OP_CPU_KERNEL(sequence_concat, Kernel<float>, Kernel<double>,
Kernel<int>, Kernel<int64_t>); Kernel<int>, Kernel<int64_t>);
REGISTER_OPERATOR(sequence_concat_grad, op::SeqConcatGradOp, REGISTER_OPERATOR(sequence_concat_grad, op::SeqConcatGradOp,
op::SeqConcatGradNoNeedBufferVarsInference); op::SeqConcatGradNoNeedBufferVarsInferer);
template <typename T> template <typename T>
using GradKernel = using GradKernel =
op::SeqConcatGradKernel<paddle::platform::CPUDeviceContext, T>; op::SeqConcatGradKernel<paddle::platform::CPUDeviceContext, T>;
......
...@@ -181,10 +181,10 @@ class SequenceExpandAsOpGradOpMaker : public framework::SingleGradOpMaker<T> { ...@@ -181,10 +181,10 @@ class SequenceExpandAsOpGradOpMaker : public framework::SingleGradOpMaker<T> {
} }
}; };
DECLARE_NO_NEED_BUFFER_VARS_INFERER(SequenceExpandAsOpNoNeedBufferVarsInference, DECLARE_NO_NEED_BUFFER_VARS_INFERER(SequenceExpandAsOpNoNeedBufferVarsInferer,
"Y"); "Y");
DECLARE_NO_NEED_BUFFER_VARS_INFERER( DECLARE_NO_NEED_BUFFER_VARS_INFERER(
SequenceExpandAsGradOpNoNeedBufferVarsInference, "X", "Y"); SequenceExpandAsGradOpNoNeedBufferVarsInferer, "X", "Y");
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle
...@@ -194,9 +194,9 @@ REGISTER_OPERATOR( ...@@ -194,9 +194,9 @@ REGISTER_OPERATOR(
sequence_expand_as, ops::SequenceExpandAsOp, ops::SequenceExpandAsOpMaker, sequence_expand_as, ops::SequenceExpandAsOp, ops::SequenceExpandAsOpMaker,
ops::SequenceExpandAsOpGradOpMaker<paddle::framework::OpDesc>, ops::SequenceExpandAsOpGradOpMaker<paddle::framework::OpDesc>,
ops::SequenceExpandAsOpGradOpMaker<paddle::imperative::OpBase>, ops::SequenceExpandAsOpGradOpMaker<paddle::imperative::OpBase>,
ops::SequenceExpandAsOpNoNeedBufferVarsInference); ops::SequenceExpandAsOpNoNeedBufferVarsInferer);
REGISTER_OPERATOR(sequence_expand_as_grad, ops::SequenceExpandAsOpGrad, REGISTER_OPERATOR(sequence_expand_as_grad, ops::SequenceExpandAsOpGrad,
ops::SequenceExpandAsGradOpNoNeedBufferVarsInference); ops::SequenceExpandAsGradOpNoNeedBufferVarsInferer);
REGISTER_OP_CPU_KERNEL( REGISTER_OP_CPU_KERNEL(
sequence_expand_as, sequence_expand_as,
ops::SequenceExpandAsKernel<paddle::platform::CPUDeviceContext, float>, ops::SequenceExpandAsKernel<paddle::platform::CPUDeviceContext, float>,
......
...@@ -247,10 +247,10 @@ class SequenceExpandOpGradMaker : public framework::SingleGradOpMaker<T> { ...@@ -247,10 +247,10 @@ class SequenceExpandOpGradMaker : public framework::SingleGradOpMaker<T> {
} }
}; };
DECLARE_NO_NEED_BUFFER_VARS_INFERER(SequenceExpandOpNoNeedBufferVarsInference, DECLARE_NO_NEED_BUFFER_VARS_INFERER(SequenceExpandOpNoNeedBufferVarsInferer,
"Y"); "Y");
DECLARE_NO_NEED_BUFFER_VARS_INFERER( DECLARE_NO_NEED_BUFFER_VARS_INFERER(SequenceExpandGradOpNoNeedBufferVarsInferer,
SequenceExpandGradOpNoNeedBufferVarsInference, "X", "Y"); "X", "Y");
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle
...@@ -260,9 +260,9 @@ REGISTER_OPERATOR(sequence_expand, ops::SequenceExpandOp, ...@@ -260,9 +260,9 @@ REGISTER_OPERATOR(sequence_expand, ops::SequenceExpandOp,
ops::SequenceExpandOpMaker, ops::SequenceExpandOpMaker,
ops::SequenceExpandOpGradMaker<paddle::framework::OpDesc>, ops::SequenceExpandOpGradMaker<paddle::framework::OpDesc>,
ops::SequenceExpandOpGradMaker<paddle::imperative::OpBase>, ops::SequenceExpandOpGradMaker<paddle::imperative::OpBase>,
ops::SequenceExpandOpNoNeedBufferVarsInference); ops::SequenceExpandOpNoNeedBufferVarsInferer);
REGISTER_OPERATOR(sequence_expand_grad, ops::SequenceExpandOpGrad, REGISTER_OPERATOR(sequence_expand_grad, ops::SequenceExpandOpGrad,
ops::SequenceExpandGradOpNoNeedBufferVarsInference); ops::SequenceExpandGradOpNoNeedBufferVarsInferer);
REGISTER_OP_CPU_KERNEL( REGISTER_OP_CPU_KERNEL(
sequence_expand, sequence_expand,
ops::SequenceExpandKernel<paddle::platform::CPUDeviceContext, float>, ops::SequenceExpandKernel<paddle::platform::CPUDeviceContext, float>,
......
...@@ -251,7 +251,7 @@ class SequencePadGradOpMaker : public framework::SingleGradOpMaker<T> { ...@@ -251,7 +251,7 @@ class SequencePadGradOpMaker : public framework::SingleGradOpMaker<T> {
} }
}; };
DECLARE_NO_NEED_BUFFER_VARS_INFERER(SequencePadGradOpNoNeedBufferVarsInference, DECLARE_NO_NEED_BUFFER_VARS_INFERER(SequencePadGradOpNoNeedBufferVarsInferer,
"X"); "X");
} // namespace operators } // namespace operators
...@@ -262,7 +262,7 @@ REGISTER_OPERATOR(sequence_pad, ops::SequencePadOp, ops::SequencePadOpMaker, ...@@ -262,7 +262,7 @@ REGISTER_OPERATOR(sequence_pad, ops::SequencePadOp, ops::SequencePadOpMaker,
ops::SequencePadGradOpMaker<paddle::framework::OpDesc>, ops::SequencePadGradOpMaker<paddle::framework::OpDesc>,
ops::SequencePadGradOpMaker<paddle::imperative::OpBase>); ops::SequencePadGradOpMaker<paddle::imperative::OpBase>);
REGISTER_OPERATOR(sequence_pad_grad, ops::SequencePadGradOp, REGISTER_OPERATOR(sequence_pad_grad, ops::SequencePadGradOp,
ops::SequencePadGradOpNoNeedBufferVarsInference); ops::SequencePadGradOpNoNeedBufferVarsInferer);
REGISTER_OP_CPU_KERNEL( REGISTER_OP_CPU_KERNEL(
sequence_pad, sequence_pad,
ops::SequencePadOpKernel<paddle::platform::CPUDeviceContext, float>, ops::SequencePadOpKernel<paddle::platform::CPUDeviceContext, float>,
......
...@@ -166,7 +166,7 @@ class SequencePoolGradOpMaker : public framework::SingleGradOpMaker<T> { ...@@ -166,7 +166,7 @@ class SequencePoolGradOpMaker : public framework::SingleGradOpMaker<T> {
} }
}; };
DECLARE_NO_NEED_BUFFER_VARS_INFERER(SequencePoolGradOpNoNeedBufferVarsInference, DECLARE_NO_NEED_BUFFER_VARS_INFERER(SequencePoolGradOpNoNeedBufferVarsInferer,
"X"); "X");
} // namespace operators } // namespace operators
...@@ -177,7 +177,7 @@ REGISTER_OPERATOR(sequence_pool, ops::SequencePoolOp, ops::SequencePoolOpMaker, ...@@ -177,7 +177,7 @@ REGISTER_OPERATOR(sequence_pool, ops::SequencePoolOp, ops::SequencePoolOpMaker,
ops::SequencePoolGradOpMaker<paddle::framework::OpDesc>, ops::SequencePoolGradOpMaker<paddle::framework::OpDesc>,
ops::SequencePoolGradOpMaker<paddle::imperative::OpBase>); ops::SequencePoolGradOpMaker<paddle::imperative::OpBase>);
REGISTER_OPERATOR(sequence_pool_grad, ops::SequencePoolGradOp, REGISTER_OPERATOR(sequence_pool_grad, ops::SequencePoolGradOp,
ops::SequencePoolGradOpNoNeedBufferVarsInference); ops::SequencePoolGradOpNoNeedBufferVarsInferer);
REGISTER_OP_CPU_KERNEL( REGISTER_OP_CPU_KERNEL(
sequence_pool, sequence_pool,
ops::SequencePoolKernel<paddle::platform::CPUDeviceContext, float>); ops::SequencePoolKernel<paddle::platform::CPUDeviceContext, float>);
......
...@@ -168,8 +168,8 @@ class SequenceScatterGradMaker : public framework::SingleGradOpMaker<T> { ...@@ -168,8 +168,8 @@ class SequenceScatterGradMaker : public framework::SingleGradOpMaker<T> {
} }
}; };
DECLARE_NO_NEED_BUFFER_VARS_INFERER( DECLARE_NO_NEED_BUFFER_VARS_INFERER(SequenceScatterGradNoNeedBufferVarsInferer,
SequenceScatterGradNoNeedBufferVarsInference, "Updates"); "Updates");
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle
...@@ -180,7 +180,7 @@ REGISTER_OPERATOR(sequence_scatter, ops::SequenceScatterOp, ...@@ -180,7 +180,7 @@ REGISTER_OPERATOR(sequence_scatter, ops::SequenceScatterOp,
ops::SequenceScatterGradMaker<paddle::framework::OpDesc>, ops::SequenceScatterGradMaker<paddle::framework::OpDesc>,
ops::SequenceScatterGradMaker<paddle::imperative::OpBase>); ops::SequenceScatterGradMaker<paddle::imperative::OpBase>);
REGISTER_OPERATOR(sequence_scatter_grad, ops::SequenceScatterGradOp, REGISTER_OPERATOR(sequence_scatter_grad, ops::SequenceScatterGradOp,
ops::SequenceScatterGradNoNeedBufferVarsInference); ops::SequenceScatterGradNoNeedBufferVarsInferer);
REGISTER_OP_CPU_KERNEL(sequence_scatter, ops::SequenceScatterOpKernel<float>, REGISTER_OP_CPU_KERNEL(sequence_scatter, ops::SequenceScatterOpKernel<float>,
ops::SequenceScatterOpKernel<double>, ops::SequenceScatterOpKernel<double>,
ops::SequenceScatterOpKernel<int>, ops::SequenceScatterOpKernel<int>,
......
...@@ -137,7 +137,7 @@ class SequenceSliceGradOpMaker : public framework::SingleGradOpMaker<T> { ...@@ -137,7 +137,7 @@ class SequenceSliceGradOpMaker : public framework::SingleGradOpMaker<T> {
} }
}; };
DECLARE_NO_NEED_BUFFER_VARS_INFERER(SequenceSliceGradNoNeedBufferVarsInference, DECLARE_NO_NEED_BUFFER_VARS_INFERER(SequenceSliceGradNoNeedBufferVarsInferer,
"X"); "X");
} // namespace operators } // namespace operators
...@@ -149,7 +149,7 @@ REGISTER_OPERATOR(sequence_slice, ops::SequenceSliceOp, ...@@ -149,7 +149,7 @@ REGISTER_OPERATOR(sequence_slice, ops::SequenceSliceOp,
ops::SequenceSliceGradOpMaker<paddle::framework::OpDesc>, ops::SequenceSliceGradOpMaker<paddle::framework::OpDesc>,
ops::SequenceSliceGradOpMaker<paddle::imperative::OpBase>); ops::SequenceSliceGradOpMaker<paddle::imperative::OpBase>);
REGISTER_OPERATOR(sequence_slice_grad, ops::SequenceSliceGradOp, REGISTER_OPERATOR(sequence_slice_grad, ops::SequenceSliceGradOp,
ops::SequenceSliceGradNoNeedBufferVarsInference); ops::SequenceSliceGradNoNeedBufferVarsInferer);
REGISTER_OP_CPU_KERNEL( REGISTER_OP_CPU_KERNEL(
sequence_slice, sequence_slice,
ops::SequenceSliceOpKernel<paddle::platform::CPUDeviceContext, float>, ops::SequenceSliceOpKernel<paddle::platform::CPUDeviceContext, float>,
......
...@@ -169,8 +169,8 @@ class SequenceUnpadGradOpMaker : public framework::SingleGradOpMaker<T> { ...@@ -169,8 +169,8 @@ class SequenceUnpadGradOpMaker : public framework::SingleGradOpMaker<T> {
} }
}; };
DECLARE_NO_NEED_BUFFER_VARS_INFERER( DECLARE_NO_NEED_BUFFER_VARS_INFERER(SequenceUnpadGradOpNoNeedBufferVarsInferer,
SequenceUnpadGradOpNoNeedBufferVarsInference, "X"); "X");
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle
...@@ -181,7 +181,7 @@ REGISTER_OPERATOR(sequence_unpad, ops::SequenceUnpadOp, ...@@ -181,7 +181,7 @@ REGISTER_OPERATOR(sequence_unpad, ops::SequenceUnpadOp,
ops::SequenceUnpadGradOpMaker<paddle::framework::OpDesc>, ops::SequenceUnpadGradOpMaker<paddle::framework::OpDesc>,
ops::SequenceUnpadGradOpMaker<paddle::imperative::OpBase>); ops::SequenceUnpadGradOpMaker<paddle::imperative::OpBase>);
REGISTER_OPERATOR(sequence_unpad_grad, ops::SequenceUnpadGradOp, REGISTER_OPERATOR(sequence_unpad_grad, ops::SequenceUnpadGradOp,
ops::SequenceUnpadGradOpNoNeedBufferVarsInference); ops::SequenceUnpadGradOpNoNeedBufferVarsInferer);
REGISTER_OP_CPU_KERNEL( REGISTER_OP_CPU_KERNEL(
sequence_unpad, sequence_unpad,
ops::SequenceUnpadOpKernel<paddle::platform::CPUDeviceContext, float>, ops::SequenceUnpadOpKernel<paddle::platform::CPUDeviceContext, float>,
......
...@@ -131,7 +131,7 @@ class SpaceToDepthOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -131,7 +131,7 @@ class SpaceToDepthOpMaker : public framework::OpProtoAndCheckerMaker {
} }
}; };
DECLARE_NO_NEED_BUFFER_VARS_INFERER(SpaceToDepthGradOpNoBuffer, "X"); DECLARE_NO_NEED_BUFFER_VARS_INFERER(SpaceToDepthGradOpNoBufferInferer, "X");
template <typename T> template <typename T>
class SpaceToDepthGradOpMaker : public framework::SingleGradOpMaker<T> { class SpaceToDepthGradOpMaker : public framework::SingleGradOpMaker<T> {
...@@ -179,7 +179,7 @@ REGISTER_OPERATOR(space_to_depth, ops::SpaceToDepthOp, ops::SpaceToDepthOpMaker, ...@@ -179,7 +179,7 @@ REGISTER_OPERATOR(space_to_depth, ops::SpaceToDepthOp, ops::SpaceToDepthOpMaker,
ops::SpaceToDepthGradOpMaker<paddle::framework::OpDesc>, ops::SpaceToDepthGradOpMaker<paddle::framework::OpDesc>,
ops::SpaceToDepthGradOpMaker<paddle::imperative::OpBase>); ops::SpaceToDepthGradOpMaker<paddle::imperative::OpBase>);
REGISTER_OPERATOR(space_to_depth_grad, ops::SpaceToDepthGradOp, REGISTER_OPERATOR(space_to_depth_grad, ops::SpaceToDepthGradOp,
ops::SpaceToDepthGradOpNoBuffer); ops::SpaceToDepthGradOpNoBufferInferer);
REGISTER_OP_CPU_KERNEL( REGISTER_OP_CPU_KERNEL(
space_to_depth, space_to_depth,
ops::SpaceToDepthKernel<paddle::platform::CPUDeviceContext, float>, ops::SpaceToDepthKernel<paddle::platform::CPUDeviceContext, float>,
......
...@@ -88,7 +88,8 @@ class SquaredL2DistanceOp : public framework::OperatorWithKernel { ...@@ -88,7 +88,8 @@ class SquaredL2DistanceOp : public framework::OperatorWithKernel {
} }
}; };
DECLARE_NO_NEED_BUFFER_VARS_INFERER(SquaredL2DistanceGradOpNoBuffer, "X", "Y"); DECLARE_NO_NEED_BUFFER_VARS_INFERER(SquaredL2DistanceGradOpNoBufferInferer, "X",
"Y");
template <typename T> template <typename T>
class SquaredL2DistanceGradOpMaker : public framework::SingleGradOpMaker<T> { class SquaredL2DistanceGradOpMaker : public framework::SingleGradOpMaker<T> {
...@@ -192,7 +193,7 @@ REGISTER_OPERATOR( ...@@ -192,7 +193,7 @@ REGISTER_OPERATOR(
ops::SquaredL2DistanceGradOpMaker<paddle::framework::OpDesc>, ops::SquaredL2DistanceGradOpMaker<paddle::framework::OpDesc>,
ops::SquaredL2DistanceGradOpMaker<paddle::imperative::OpBase>); ops::SquaredL2DistanceGradOpMaker<paddle::imperative::OpBase>);
REGISTER_OPERATOR(squared_l2_distance_grad, ops::SquaredL2DistanceGradOp, REGISTER_OPERATOR(squared_l2_distance_grad, ops::SquaredL2DistanceGradOp,
ops::SquaredL2DistanceGradOpNoBuffer); ops::SquaredL2DistanceGradOpNoBufferInferer);
REGISTER_OP_CPU_KERNEL( REGISTER_OP_CPU_KERNEL(
squared_l2_distance, squared_l2_distance,
ops::SquaredL2DistanceKernel<paddle::platform::CPUDeviceContext, float>); ops::SquaredL2DistanceKernel<paddle::platform::CPUDeviceContext, float>);
......
...@@ -275,7 +275,7 @@ DECLARE_INPLACE_OP_INFERER(SequeezeInplaceInferer, {"X", "Out"}); ...@@ -275,7 +275,7 @@ DECLARE_INPLACE_OP_INFERER(SequeezeInplaceInferer, {"X", "Out"});
DECLARE_INPLACE_OP_INFERER(SequeezeGradInplaceInferer, DECLARE_INPLACE_OP_INFERER(SequeezeGradInplaceInferer,
{framework::GradVarName("Out"), {framework::GradVarName("Out"),
framework::GradVarName("X")}); framework::GradVarName("X")});
DECLARE_NO_NEED_BUFFER_VARS_INFERER(SqueezeGradNoNeedBufferVarsInference, "X"); DECLARE_NO_NEED_BUFFER_VARS_INFERER(SqueezeGradNoNeedBufferVarsInferer, "X");
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle
...@@ -284,7 +284,7 @@ REGISTER_OPERATOR(squeeze, ops::SqueezeOp, ops::SqueezeOpMaker, ...@@ -284,7 +284,7 @@ REGISTER_OPERATOR(squeeze, ops::SqueezeOp, ops::SqueezeOpMaker,
ops::SqueezeGradOpMaker<paddle::framework::OpDesc>, ops::SqueezeGradOpMaker<paddle::framework::OpDesc>,
ops::SqueezeGradOpMaker<paddle::imperative::OpBase>); ops::SqueezeGradOpMaker<paddle::imperative::OpBase>);
REGISTER_OPERATOR(squeeze_grad, ops::SqueezeGradOp, REGISTER_OPERATOR(squeeze_grad, ops::SqueezeGradOp,
ops::SqueezeGradNoNeedBufferVarsInference); ops::SqueezeGradNoNeedBufferVarsInferer);
REGISTER_OPERATOR(squeeze2, ops::Squeeze2Op, ops::Squeeze2OpMaker, REGISTER_OPERATOR(squeeze2, ops::Squeeze2Op, ops::Squeeze2OpMaker,
ops::Squeeze2GradOpMaker<paddle::framework::OpDesc>, ops::Squeeze2GradOpMaker<paddle::framework::OpDesc>,
......
...@@ -304,7 +304,7 @@ class StridedSliceOpGradMaker : public framework::SingleGradOpMaker<T> { ...@@ -304,7 +304,7 @@ class StridedSliceOpGradMaker : public framework::SingleGradOpMaker<T> {
} }
}; };
DECLARE_NO_NEED_BUFFER_VARS_INFERER(StridedSliceOpGradNoNeedBufferVarsInference, DECLARE_NO_NEED_BUFFER_VARS_INFERER(StridedSliceOpGradNoNeedBufferVarsInferer,
"Input"); "Input");
} // namespace operators } // namespace operators
...@@ -315,7 +315,7 @@ REGISTER_OPERATOR(strided_slice, ops::StridedSliceOp, ops::StridedSliceOpMaker, ...@@ -315,7 +315,7 @@ REGISTER_OPERATOR(strided_slice, ops::StridedSliceOp, ops::StridedSliceOpMaker,
ops::StridedSliceOpGradMaker<paddle::framework::OpDesc>, ops::StridedSliceOpGradMaker<paddle::framework::OpDesc>,
ops::StridedSliceOpGradMaker<paddle::imperative::OpBase>); ops::StridedSliceOpGradMaker<paddle::imperative::OpBase>);
REGISTER_OPERATOR(strided_slice_grad, ops::StridedSliceOpGrad, REGISTER_OPERATOR(strided_slice_grad, ops::StridedSliceOpGrad,
ops::StridedSliceOpGradNoNeedBufferVarsInference); ops::StridedSliceOpGradNoNeedBufferVarsInferer);
REGISTER_OP_CPU_KERNEL( REGISTER_OP_CPU_KERNEL(
strided_slice, strided_slice,
......
...@@ -147,8 +147,7 @@ class TraceGradOpMaker : public framework::SingleGradOpMaker<T> { ...@@ -147,8 +147,7 @@ class TraceGradOpMaker : public framework::SingleGradOpMaker<T> {
} }
}; };
DECLARE_NO_NEED_BUFFER_VARS_INFERER(TraceGradNoNeedBufferVarsInference, DECLARE_NO_NEED_BUFFER_VARS_INFERER(TraceGradNoNeedBufferVarsInferer, "Input");
"Input");
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle
...@@ -159,7 +158,7 @@ REGISTER_OPERATOR(trace, ops::TraceOp, ops::TraceOpMaker, ...@@ -159,7 +158,7 @@ REGISTER_OPERATOR(trace, ops::TraceOp, ops::TraceOpMaker,
ops::TraceGradOpMaker<paddle::imperative::OpBase>); ops::TraceGradOpMaker<paddle::imperative::OpBase>);
REGISTER_OPERATOR(trace_grad, ops::TraceOpGrad, REGISTER_OPERATOR(trace_grad, ops::TraceOpGrad,
ops::TraceGradNoNeedBufferVarsInference); ops::TraceGradNoNeedBufferVarsInferer);
REGISTER_OP_CPU_KERNEL( REGISTER_OP_CPU_KERNEL(
trace, ops::TraceKernel<paddle::platform::CPUDeviceContext, int>, trace, ops::TraceKernel<paddle::platform::CPUDeviceContext, int>,
ops::TraceKernel<paddle::platform::CPUDeviceContext, float>, ops::TraceKernel<paddle::platform::CPUDeviceContext, float>,
......
...@@ -174,7 +174,7 @@ class UnfoldGradMaker : public framework::SingleGradOpMaker<T> { ...@@ -174,7 +174,7 @@ class UnfoldGradMaker : public framework::SingleGradOpMaker<T> {
} }
}; };
DECLARE_NO_NEED_BUFFER_VARS_INFERER(UnfoldGradOpNoNeedBufferVarsInference, "X"); DECLARE_NO_NEED_BUFFER_VARS_INFERER(UnfoldGradOpNoNeedBufferVarsInferer, "X");
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle
...@@ -184,7 +184,7 @@ REGISTER_OPERATOR(unfold, ops::UnfoldOp, ops::UnfoldOpMaker, ...@@ -184,7 +184,7 @@ REGISTER_OPERATOR(unfold, ops::UnfoldOp, ops::UnfoldOpMaker,
ops::UnfoldGradMaker<paddle::framework::OpDesc>, ops::UnfoldGradMaker<paddle::framework::OpDesc>,
ops::UnfoldGradMaker<paddle::imperative::OpBase>); ops::UnfoldGradMaker<paddle::imperative::OpBase>);
REGISTER_OPERATOR(unfold_grad, ops::UnfoldGradOp, REGISTER_OPERATOR(unfold_grad, ops::UnfoldGradOp,
ops::UnfoldGradOpNoNeedBufferVarsInference); ops::UnfoldGradOpNoNeedBufferVarsInferer);
REGISTER_OP_CPU_KERNEL( REGISTER_OP_CPU_KERNEL(
unfold, ops::UnfoldOpKernel<paddle::platform::CPUDeviceContext, float>, unfold, ops::UnfoldOpKernel<paddle::platform::CPUDeviceContext, float>,
......
...@@ -306,8 +306,7 @@ DECLARE_INPLACE_OP_INFERER(UnsqueezeInplaceInferer, {"X", "Out"}); ...@@ -306,8 +306,7 @@ DECLARE_INPLACE_OP_INFERER(UnsqueezeInplaceInferer, {"X", "Out"});
DECLARE_INPLACE_OP_INFERER(UnsqueezeGradInplaceInferer, DECLARE_INPLACE_OP_INFERER(UnsqueezeGradInplaceInferer,
{framework::GradVarName("Out"), {framework::GradVarName("Out"),
framework::GradVarName("X")}); framework::GradVarName("X")});
DECLARE_NO_NEED_BUFFER_VARS_INFERER(UnsqueezeGradOpNoNeedBufferVarInference, DECLARE_NO_NEED_BUFFER_VARS_INFERER(UnsqueezeGradOpNoNeedBufferVarInferer, "X");
"X");
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle
...@@ -316,7 +315,7 @@ REGISTER_OPERATOR(unsqueeze, ops::UnsqueezeOp, ops::UnsqueezeOpMaker, ...@@ -316,7 +315,7 @@ REGISTER_OPERATOR(unsqueeze, ops::UnsqueezeOp, ops::UnsqueezeOpMaker,
ops::UnsqueezeGradOpMaker<paddle::framework::OpDesc>, ops::UnsqueezeGradOpMaker<paddle::framework::OpDesc>,
ops::UnsqueezeGradOpMaker<paddle::imperative::OpBase>); ops::UnsqueezeGradOpMaker<paddle::imperative::OpBase>);
REGISTER_OPERATOR(unsqueeze_grad, ops::UnsqueezeGradOp, REGISTER_OPERATOR(unsqueeze_grad, ops::UnsqueezeGradOp,
ops::UnsqueezeGradOpNoNeedBufferVarInference); ops::UnsqueezeGradOpNoNeedBufferVarInferer);
REGISTER_OPERATOR(unsqueeze2, ops::Unsqueeze2Op, ops::Unsqueeze2OpMaker, REGISTER_OPERATOR(unsqueeze2, ops::Unsqueeze2Op, ops::Unsqueeze2OpMaker,
ops::Unsqueeze2GradOpMaker<paddle::framework::OpDesc>, ops::Unsqueeze2GradOpMaker<paddle::framework::OpDesc>,
......
...@@ -184,7 +184,7 @@ class WarpCTCGradOp : public framework::OperatorWithKernel { ...@@ -184,7 +184,7 @@ class WarpCTCGradOp : public framework::OperatorWithKernel {
} }
}; };
DECLARE_NO_NEED_BUFFER_VARS_INFERER(WarpCTCGradOpNoNeedBufferVarInference, DECLARE_NO_NEED_BUFFER_VARS_INFERER(WarpCTCGradOpNoNeedBufferVarInferer,
"Logits"); "Logits");
} // namespace operators } // namespace operators
...@@ -195,7 +195,7 @@ REGISTER_OPERATOR(warpctc, ops::WarpCTCOp, ops::WarpCTCOpMaker, ...@@ -195,7 +195,7 @@ REGISTER_OPERATOR(warpctc, ops::WarpCTCOp, ops::WarpCTCOpMaker,
ops::WarpCTCGradOpMaker<paddle::framework::OpDesc>, ops::WarpCTCGradOpMaker<paddle::framework::OpDesc>,
ops::WarpCTCGradOpMaker<paddle::imperative::OpBase>); ops::WarpCTCGradOpMaker<paddle::imperative::OpBase>);
REGISTER_OPERATOR(warpctc_grad, ops::WarpCTCGradOp, REGISTER_OPERATOR(warpctc_grad, ops::WarpCTCGradOp,
ops::WarpCTCGradOpNoNeedBufferVarInference); ops::WarpCTCGradOpNoNeedBufferVarInferer);
REGISTER_OP_CPU_KERNEL( REGISTER_OP_CPU_KERNEL(
warpctc, ops::WarpCTCKernel<paddle::platform::CPUDeviceContext, float>); warpctc, ops::WarpCTCKernel<paddle::platform::CPUDeviceContext, float>);
REGISTER_OP_CPU_KERNEL( REGISTER_OP_CPU_KERNEL(
......
...@@ -135,8 +135,7 @@ class WhereOpGradMaker : public framework::SingleGradOpMaker<T> { ...@@ -135,8 +135,7 @@ class WhereOpGradMaker : public framework::SingleGradOpMaker<T> {
} }
}; };
DECLARE_NO_NEED_BUFFER_VARS_INFERER(WhereGradNoNeedBufferVarsInference, "X", DECLARE_NO_NEED_BUFFER_VARS_INFERER(WhereGradNoNeedBufferVarsInferer, "X", "Y");
"Y");
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle
...@@ -146,7 +145,7 @@ REGISTER_OPERATOR(where, ops::WhereOp, ops::WhereOpMaker, ...@@ -146,7 +145,7 @@ REGISTER_OPERATOR(where, ops::WhereOp, ops::WhereOpMaker,
ops::WhereOpGradMaker<paddle::imperative::OpBase>); ops::WhereOpGradMaker<paddle::imperative::OpBase>);
REGISTER_OPERATOR(where_grad, ops::WhereGradOp, REGISTER_OPERATOR(where_grad, ops::WhereGradOp,
ops::WhereGradNoNeedBufferVarsInference); ops::WhereGradNoNeedBufferVarsInferer);
REGISTER_OP_CPU_KERNEL( REGISTER_OP_CPU_KERNEL(
where, ops::WhereKernel<paddle::platform::CPUDeviceContext, float>, where, ops::WhereKernel<paddle::platform::CPUDeviceContext, float>,
ops::WhereKernel<paddle::platform::CPUDeviceContext, double>, ops::WhereKernel<paddle::platform::CPUDeviceContext, double>,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册