Unverified commit 5c8f210c, authored by Zeng Jinle, committed via GitHub

refine inplace inference registry, test=develop (#19032)

Parent: b6d1d890
......@@ -53,5 +53,15 @@ class SingleOpInplaceInToOut : public InplaceOpInference {
}
};
// Declares a final ::paddle::framework::InplaceOpInference subclass named
// `class_name` whose operator() simply returns the given
// {input_name, output_name} pairs as the in-place variable mapping.
// (Comments cannot appear inside the macro body: a `//` comment would
// swallow the trailing line-continuation backslash.)
// Usage:
//   DECLARE_INPLACE_OP_INFERER(MyOpInplace, {"X", "Out"});
#define DECLARE_INPLACE_OP_INFERER(class_name, ...) \
class class_name final : public ::paddle::framework::InplaceOpInference { \
public: \
std::unordered_map<std::string, std::string> operator()( \
const ::paddle::framework::OpDesc& op_desc, \
bool use_cuda) const final { \
return {__VA_ARGS__}; \
} \
}
} // namespace framework
} // namespace paddle
......@@ -774,13 +774,9 @@ class SquareDoubleGradMaker
}
};
class ActivationGradOpInplaceInference : public framework::InplaceOpInference {
public:
std::unordered_map<std::string, std::string> operator()(
const framework::OpDesc& op_desc, bool use_cuda) const override {
return {{framework::GradVarName("Out"), framework::GradVarName("X")}};
}
};
DECLARE_INPLACE_OP_INFERER(ActivationGradOpInplaceInference,
{framework::GradVarName("Out"),
framework::GradVarName("X")});
} // namespace operators
} // namespace paddle
......
......@@ -598,36 +598,13 @@ std::unique_ptr<framework::OpDesc> BatchNormGradMaker::Apply() const {
return std::unique_ptr<framework::OpDesc>(op);
}
class BatchNormInplaceInToOut : public framework::InplaceOpInference {
public:
std::unordered_map<std::string, std::string> operator()(
const framework::OpDesc &op_desc, bool use_cuda) const override {
return {{"Mean", "MeanOut"}, {"Variance", "VarianceOut"}, {"X", "Y"}};
}
};
class BatchNormGradInplaceInToOut : public framework::InplaceOpInference {
public:
std::unordered_map<std::string, std::string> operator()(
const framework::OpDesc &op_desc, bool use_cuda) const override {
// Scale, Bias, SavedMean, SavedVariance shape is [batch_size, C]
return {
{framework::GradVarName("Y"), framework::GradVarName("X")},
{"SavedMean", framework::GradVarName("Scale")},
{"SavedVariance", framework::GradVarName("Bias")},
};
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
// NOTE(review): the original span interleaved the pre- and post-refactor
// registrations, registering batch_norm and batch_norm_grad twice (and the
// old form's terminating ';' lived inside a comment). Keep exactly one
// registration per op; the inplace inferers are deliberately not passed
// (they were already commented out before this change).
REGISTER_OPERATOR(batch_norm, ops::BatchNormOp, ops::BatchNormOpMaker,
                  ops::BatchNormOpInferVarType, ops::BatchNormGradMaker);
REGISTER_OPERATOR(batch_norm_grad, ops::BatchNormGradOp);
REGISTER_OP_CPU_KERNEL(
batch_norm, ops::BatchNormKernel<paddle::platform::CPUDeviceContext, float>,
......
......@@ -317,21 +317,10 @@ class ElemwiseGradKernel : public framework::OpKernel<T> {
}
};
// In-place pairs for elementwise ops: forward Out may reuse X's buffer;
// backward grad(X) may reuse grad(Out)'s buffer. The span defined both
// classes twice (hand-written subclass + macro) — redefinitions; keep only
// the macro declarations.
DECLARE_INPLACE_OP_INFERER(ElementwiseOpInplace, {"X", "Out"});
DECLARE_INPLACE_OP_INFERER(ElementwiseGradOpInplace,
                           {framework::GradVarName("Out"),
                            framework::GradVarName("X")});
// NOTE(review): presumably Y's data (not just its metadata) is unused by the
// grad kernels, so its buffer need not be retained — confirm against the
// elementwise grad kernels.
DECLARE_NO_NEED_BUFFER_VARS_INFERENCE(ElementwiseGradNoBufVarsInference, "Y");
......
......@@ -268,21 +268,10 @@ class Flatten2GradOp : public framework::OperatorBase {
}
};
class FlattenOpInplaceInToOut : public framework::InplaceOpInference {
public:
std::unordered_map<std::string, std::string> operator()(
const framework::OpDesc &op_desc, bool use_cuda) const override {
return {{"X", "Out"}};
}
};
class FlattenGradInplaceinToOut : public framework::InplaceOpInference {
public:
std::unordered_map<std::string, std::string> operator()(
const framework::OpDesc &op_desc, bool use_cuda) const override {
return {{framework::GradVarName("Out"), framework::GradVarName("X")}};
}
};
DECLARE_INPLACE_OP_INFERER(FlattenOpInplaceInToOut, {"X", "Out"});
DECLARE_INPLACE_OP_INFERER(FlattenGradInplaceinToOut,
{framework::GradVarName("Out"),
framework::GradVarName("X")});
} // namespace operators
} // namespace paddle
......
......@@ -170,21 +170,10 @@ class GroupNormGradMaker : public framework::SingleGradOpDescMaker {
}
};
class GroupNormInplaceInToOut : public framework::InplaceOpInference {
public:
std::unordered_map<std::string, std::string> operator()(
const framework::OpDesc &op_desc, bool use_cuda) const override {
return {{"X", "Y"}};
}
};
class GroupNormGradInplaceInToOut : public framework::InplaceOpInference {
public:
std::unordered_map<std::string, std::string> operator()(
const framework::OpDesc &op_desc, bool use_cuda) const override {
return {{framework::GradVarName("Y"), framework::GradVarName("X")}};
}
};
DECLARE_INPLACE_OP_INFERER(GroupNormInplaceInToOut, {"X", "Y"});
DECLARE_INPLACE_OP_INFERER(GroupNormGradInplaceInToOut,
{framework::GradVarName("Y"),
framework::GradVarName("X")});
class GroupNormOpInferVarType
: public framework::PassInDtypeAndVarTypeToOutput {
......
......@@ -393,21 +393,10 @@ class Reshape2GradOp : public framework::OperatorWithKernel {
}
};
class ReshapeOpInplaceInToOut : public framework::InplaceOpInference {
public:
std::unordered_map<std::string, std::string> operator()(
const framework::OpDesc &op_desc, bool use_cuda) const override {
return {{"X", "Out"}};
}
};
class ReshapeGradInplaceInToOut : public framework::InplaceOpInference {
public:
std::unordered_map<std::string, std::string> operator()(
const framework::OpDesc &op_desc, bool use_cuda) const override {
return {{framework::GradVarName("Out"), framework::GradVarName("X")}};
}
};
DECLARE_INPLACE_OP_INFERER(ReshapeOpInplaceInToOut, {"X", "Out"});
DECLARE_INPLACE_OP_INFERER(ReshapeGradInplaceInToOut,
{framework::GradVarName("Out"),
framework::GradVarName("X")});
} // namespace operators
} // namespace paddle
......
......@@ -255,23 +255,11 @@ class SoftmaxGradMaker : public framework::SingleGradOpDescMaker {
}
};
class SoftmaxWithCrossEntropyInplaceInference
: public framework::InplaceOpInference {
public:
std::unordered_map<std::string, std::string> operator()(
const framework::OpDesc& op_desc, bool use_cuda) const {
return {{"Logits", "Softmax"}};
}
};
DECLARE_INPLACE_OP_INFERER(SoftmaxWithCrossEntropyInplaceInference,
{"Logits", "Softmax"});
class SoftmaxWithCrossEntropyGradInplaceInference
: public framework::InplaceOpInference {
public:
std::unordered_map<std::string, std::string> operator()(
const framework::OpDesc& op_desc, bool use_cuda) const {
return {{"Softmax", framework::GradVarName("Logits")}};
}
};
DECLARE_INPLACE_OP_INFERER(SoftmaxWithCrossEntropyGradInplaceInference,
{"Softmax", framework::GradVarName("Logits")});
} // namespace operators
} // namespace paddle
......
......@@ -238,13 +238,7 @@ class SumGradMaker : public framework::GradOpDescMakerBase {
}
};
class SumInplace : public framework::InplaceOpInference {
public:
std::unordered_map<std::string, std::string> operator()(
const framework::OpDesc& op_desc, bool use_cuda) const override {
return {{"X", "Out"}};
}
};
DECLARE_INPLACE_OP_INFERER(SumInplace, {"X", "Out"});
} // namespace operators
} // namespace paddle
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
To comment, please register