Unverified commit 5c8f210c, authored by Zeng Jinle, committed by GitHub

refine inplace inference registry, test=develop (#19032)

Parent commit: b6d1d890
@@ -53,5 +53,15 @@ class SingleOpInplaceInToOut : public InplaceOpInference {
   }
 };
 
+#define DECLARE_INPLACE_OP_INFERER(class_name, ...)                         \
+  class class_name final : public ::paddle::framework::InplaceOpInference { \
+   public:                                                                  \
+    std::unordered_map<std::string, std::string> operator()(                \
+        const ::paddle::framework::OpDesc& op_desc,                         \
+        bool use_cuda) const final {                                        \
+      return {__VA_ARGS__};                                                 \
+    }                                                                       \
+  }
+
 }  // namespace framework
 }  // namespace paddle
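For context, DECLARE_INPLACE_OP_INFERER expands into a final subclass of InplaceOpInference whose operator() simply returns the variadic arguments as an std::unordered_map<std::string, std::string>, i.e. pairs naming which output variable may reuse the memory of which input. The sketch below only illustrates that pattern; the my_op operator and the surrounding class names are hypothetical and are not part of this change.

// Hypothetical sketch (assumed to sit inside namespace paddle::operators,
// as in the operator files below): output "Out" of my_op may reuse the
// buffer of input "X", and the gradient of "X" may reuse the gradient of "Out".
DECLARE_INPLACE_OP_INFERER(MyOpInplace, {"X", "Out"});
DECLARE_INPLACE_OP_INFERER(MyOpGradInplace,
                           {framework::GradVarName("Out"),
                            framework::GradVarName("X")});

// The generated classes would then be attached at operator registration time,
// roughly like this (registration arguments are illustrative only):
// REGISTER_OPERATOR(my_op, ops::MyOp, ops::MyOpMaker, ops::MyOpGradMaker,
//                   ops::MyOpInplace);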
@@ -774,13 +774,9 @@ class SquareDoubleGradMaker
   }
 };
 
-class ActivationGradOpInplaceInference : public framework::InplaceOpInference {
- public:
-  std::unordered_map<std::string, std::string> operator()(
-      const framework::OpDesc& op_desc, bool use_cuda) const override {
-    return {{framework::GradVarName("Out"), framework::GradVarName("X")}};
-  }
-};
+DECLARE_INPLACE_OP_INFERER(ActivationGradOpInplaceInference,
+                           {framework::GradVarName("Out"),
+                            framework::GradVarName("X")});
 
 }  // namespace operators
 }  // namespace paddle
...
@@ -598,36 +598,13 @@ std::unique_ptr<framework::OpDesc> BatchNormGradMaker::Apply() const {
   return std::unique_ptr<framework::OpDesc>(op);
 }
 
-class BatchNormInplaceInToOut : public framework::InplaceOpInference {
- public:
-  std::unordered_map<std::string, std::string> operator()(
-      const framework::OpDesc &op_desc, bool use_cuda) const override {
-    return {{"Mean", "MeanOut"}, {"Variance", "VarianceOut"}, {"X", "Y"}};
-  }
-};
-
-class BatchNormGradInplaceInToOut : public framework::InplaceOpInference {
- public:
-  std::unordered_map<std::string, std::string> operator()(
-      const framework::OpDesc &op_desc, bool use_cuda) const override {
-    // Scale, Bias, SavedMean, SavedVariance shape is [batch_size, C]
-    return {
-        {framework::GradVarName("Y"), framework::GradVarName("X")},
-        {"SavedMean", framework::GradVarName("Scale")},
-        {"SavedVariance", framework::GradVarName("Bias")},
-    };
-  }
-};
-
 }  // namespace operators
 }  // namespace paddle
 
 namespace ops = paddle::operators;
 REGISTER_OPERATOR(batch_norm, ops::BatchNormOp, ops::BatchNormOpMaker,
-                  ops::BatchNormOpInferVarType, ops::BatchNormGradMaker)
-// ops::BatchNormInplaceInToOut);
-REGISTER_OPERATOR(batch_norm_grad, ops::BatchNormGradOp)
-// ops::BatchNormGradInplaceInToOut);
+                  ops::BatchNormOpInferVarType, ops::BatchNormGradMaker);
+REGISTER_OPERATOR(batch_norm_grad, ops::BatchNormGradOp);
 
 REGISTER_OP_CPU_KERNEL(
     batch_norm, ops::BatchNormKernel<paddle::platform::CPUDeviceContext, float>,
...
@@ -317,21 +317,10 @@ class ElemwiseGradKernel : public framework::OpKernel<T> {
   }
 };
 
-class ElementwiseOpInplace : public framework::InplaceOpInference {
- public:
-  std::unordered_map<std::string, std::string> operator()(
-      const framework::OpDesc &op_desc, bool use_cuda) const override {
-    return {{"X", "Out"}};
-  }
-};
-
-class ElementwiseGradOpInplace : public framework::InplaceOpInference {
- public:
-  std::unordered_map<std::string, std::string> operator()(
-      const framework::OpDesc &op_desc, bool use_cuda) const override {
-    return {{framework::GradVarName("Out"), framework::GradVarName("X")}};
-  }
-};
+DECLARE_INPLACE_OP_INFERER(ElementwiseOpInplace, {"X", "Out"});
+DECLARE_INPLACE_OP_INFERER(ElementwiseGradOpInplace,
+                           {framework::GradVarName("Out"),
+                            framework::GradVarName("X")});
 
 DECLARE_NO_NEED_BUFFER_VARS_INFERENCE(ElementwiseGradNoBufVarsInference, "Y");
...
@@ -268,21 +268,10 @@ class Flatten2GradOp : public framework::OperatorBase {
   }
 };
 
-class FlattenOpInplaceInToOut : public framework::InplaceOpInference {
- public:
-  std::unordered_map<std::string, std::string> operator()(
-      const framework::OpDesc &op_desc, bool use_cuda) const override {
-    return {{"X", "Out"}};
-  }
-};
-
-class FlattenGradInplaceinToOut : public framework::InplaceOpInference {
- public:
-  std::unordered_map<std::string, std::string> operator()(
-      const framework::OpDesc &op_desc, bool use_cuda) const override {
-    return {{framework::GradVarName("Out"), framework::GradVarName("X")}};
-  }
-};
+DECLARE_INPLACE_OP_INFERER(FlattenOpInplaceInToOut, {"X", "Out"});
+DECLARE_INPLACE_OP_INFERER(FlattenGradInplaceinToOut,
+                           {framework::GradVarName("Out"),
+                            framework::GradVarName("X")});
 
 }  // namespace operators
 }  // namespace paddle
...
@@ -170,21 +170,10 @@ class GroupNormGradMaker : public framework::SingleGradOpDescMaker {
   }
 };
 
-class GroupNormInplaceInToOut : public framework::InplaceOpInference {
- public:
-  std::unordered_map<std::string, std::string> operator()(
-      const framework::OpDesc &op_desc, bool use_cuda) const override {
-    return {{"X", "Y"}};
-  }
-};
-
-class GroupNormGradInplaceInToOut : public framework::InplaceOpInference {
- public:
-  std::unordered_map<std::string, std::string> operator()(
-      const framework::OpDesc &op_desc, bool use_cuda) const override {
-    return {{framework::GradVarName("Y"), framework::GradVarName("X")}};
-  }
-};
+DECLARE_INPLACE_OP_INFERER(GroupNormInplaceInToOut, {"X", "Y"});
+DECLARE_INPLACE_OP_INFERER(GroupNormGradInplaceInToOut,
+                           {framework::GradVarName("Y"),
+                            framework::GradVarName("X")});
 
 class GroupNormOpInferVarType
     : public framework::PassInDtypeAndVarTypeToOutput {
...
@@ -393,21 +393,10 @@ class Reshape2GradOp : public framework::OperatorWithKernel {
   }
 };
 
-class ReshapeOpInplaceInToOut : public framework::InplaceOpInference {
- public:
-  std::unordered_map<std::string, std::string> operator()(
-      const framework::OpDesc &op_desc, bool use_cuda) const override {
-    return {{"X", "Out"}};
-  }
-};
-
-class ReshapeGradInplaceInToOut : public framework::InplaceOpInference {
- public:
-  std::unordered_map<std::string, std::string> operator()(
-      const framework::OpDesc &op_desc, bool use_cuda) const override {
-    return {{framework::GradVarName("Out"), framework::GradVarName("X")}};
-  }
-};
+DECLARE_INPLACE_OP_INFERER(ReshapeOpInplaceInToOut, {"X", "Out"});
+DECLARE_INPLACE_OP_INFERER(ReshapeGradInplaceInToOut,
+                           {framework::GradVarName("Out"),
+                            framework::GradVarName("X")});
 
 }  // namespace operators
 }  // namespace paddle
...
@@ -255,23 +255,11 @@ class SoftmaxGradMaker : public framework::SingleGradOpDescMaker {
   }
 };
 
-class SoftmaxWithCrossEntropyInplaceInference
-    : public framework::InplaceOpInference {
- public:
-  std::unordered_map<std::string, std::string> operator()(
-      const framework::OpDesc& op_desc, bool use_cuda) const {
-    return {{"Logits", "Softmax"}};
-  }
-};
+DECLARE_INPLACE_OP_INFERER(SoftmaxWithCrossEntropyInplaceInference,
+                           {"Logits", "Softmax"});
 
-class SoftmaxWithCrossEntropyGradInplaceInference
-    : public framework::InplaceOpInference {
- public:
-  std::unordered_map<std::string, std::string> operator()(
-      const framework::OpDesc& op_desc, bool use_cuda) const {
-    return {{"Softmax", framework::GradVarName("Logits")}};
-  }
-};
+DECLARE_INPLACE_OP_INFERER(SoftmaxWithCrossEntropyGradInplaceInference,
+                           {"Softmax", framework::GradVarName("Logits")});
 
 }  // namespace operators
 }  // namespace paddle
...
@@ -238,13 +238,7 @@ class SumGradMaker : public framework::GradOpDescMakerBase {
   }
 };
 
-class SumInplace : public framework::InplaceOpInference {
- public:
-  std::unordered_map<std::string, std::string> operator()(
-      const framework::OpDesc& op_desc, bool use_cuda) const override {
-    return {{"X", "Out"}};
-  }
-};
+DECLARE_INPLACE_OP_INFERER(SumInplace, {"X", "Out"});
 
 }  // namespace operators
 }  // namespace paddle
...