未验证 提交 775ddb5a 编写于 作者: C Chen Weihang 提交者: GitHub

fix double grad var judging (#41072)

上级 83efeeae
...@@ -62,11 +62,6 @@ static T* DynLoad(void* handle, std::string name) { ...@@ -62,11 +62,6 @@ static T* DynLoad(void* handle, std::string name) {
return func; return func;
} }
inline static bool IsGradVar(const std::string& var_name) {
std::string suffix = kGradVarSuffix;
return var_name.rfind(suffix) != std::string::npos;
}
inline static bool IsDuplicableVar(const std::string& var_name) { inline static bool IsDuplicableVar(const std::string& var_name) {
std::string suffix = kTensorVectorSuffix; std::string suffix = kTensorVectorSuffix;
return var_name.rfind(suffix) != std::string::npos; return var_name.rfind(suffix) != std::string::npos;
...@@ -77,6 +72,17 @@ inline static std::string NoGrad(const std::string& var_name) { ...@@ -77,6 +72,17 @@ inline static std::string NoGrad(const std::string& var_name) {
return var_name.substr(0, var_name.size() - kGradVarSuffixSize); return var_name.substr(0, var_name.size() - kGradVarSuffixSize);
} }
inline static bool IsGradVar(const std::string& var_name, bool is_double_grad) {
std::string suffix = kGradVarSuffix;
if (!is_double_grad) {
return var_name.rfind(suffix) != std::string::npos;
} else {
// for double grad cases, the X@GRAD is not a grad var, X@GRAD@GRAD is a
// grad var, here we remove a @GRAD suffix
return NoGrad(var_name).rfind(suffix) != std::string::npos;
}
}
inline static bool IsMemberOf(const std::vector<std::string>& vec, inline static bool IsMemberOf(const std::vector<std::string>& vec,
const std::string& name) { const std::string& name) {
return std::find(vec.cbegin(), vec.cend(), name) != vec.cend(); return std::find(vec.cbegin(), vec.cend(), name) != vec.cend();
...@@ -493,11 +499,12 @@ class CustomGradOpMaker<OpDesc> : public SingleGradOpMaker<OpDesc> { ...@@ -493,11 +499,12 @@ class CustomGradOpMaker<OpDesc> : public SingleGradOpMaker<OpDesc> {
std::unordered_map<std::string, std::string>* grad_to_var, std::unordered_map<std::string, std::string>* grad_to_var,
const std::vector<BlockDesc*>& grad_block, const std::string& name, const std::vector<BlockDesc*>& grad_block, const std::string& name,
const std::vector<std::string>& inputs, const std::vector<std::string>& inputs,
const std::vector<std::string>& outputs) const std::vector<std::string>& outputs, bool is_double_grad)
: SingleGradOpMaker<OpDesc>(fwd_op, no_grad_set, grad_to_var, grad_block), : SingleGradOpMaker<OpDesc>(fwd_op, no_grad_set, grad_to_var, grad_block),
name_(name), name_(name),
inputs_(inputs), inputs_(inputs),
outputs_(outputs) {} outputs_(outputs),
is_double_grad_(is_double_grad) {}
protected: protected:
void Apply(GradOpPtr<OpDesc> grad_op) const override { void Apply(GradOpPtr<OpDesc> grad_op) const override {
...@@ -508,7 +515,7 @@ class CustomGradOpMaker<OpDesc> : public SingleGradOpMaker<OpDesc> { ...@@ -508,7 +515,7 @@ class CustomGradOpMaker<OpDesc> : public SingleGradOpMaker<OpDesc> {
for (auto& in_name : inputs_) { for (auto& in_name : inputs_) {
VLOG(3) << "Custom Operator: GradOpDescMaker - input: " << in_name; VLOG(3) << "Custom Operator: GradOpDescMaker - input: " << in_name;
if (!detail::IsGradVar(in_name)) { if (!detail::IsGradVar(in_name, is_double_grad_)) {
if (detail::IsMemberOf(fwd_op_inputs, in_name)) { if (detail::IsMemberOf(fwd_op_inputs, in_name)) {
grad_op->SetInput(in_name, this->Input(in_name)); grad_op->SetInput(in_name, this->Input(in_name));
} else if (detail::IsMemberOf(fwd_op_outputs, in_name)) { } else if (detail::IsMemberOf(fwd_op_outputs, in_name)) {
...@@ -540,6 +547,7 @@ class CustomGradOpMaker<OpDesc> : public SingleGradOpMaker<OpDesc> { ...@@ -540,6 +547,7 @@ class CustomGradOpMaker<OpDesc> : public SingleGradOpMaker<OpDesc> {
std::string name_; std::string name_;
std::vector<std::string> inputs_; std::vector<std::string> inputs_;
std::vector<std::string> outputs_; std::vector<std::string> outputs_;
bool is_double_grad_{false};
}; };
template <> template <>
...@@ -553,12 +561,13 @@ class CustomGradOpMaker<imperative::OpBase> ...@@ -553,12 +561,13 @@ class CustomGradOpMaker<imperative::OpBase>
const AttributeMap& attrs, const AttributeMap& attrs,
const std::map<std::string, std::string>& inplace_map, const std::map<std::string, std::string>& inplace_map,
const std::string& name, const std::vector<std::string>& inputs, const std::string& name, const std::vector<std::string>& inputs,
const std::vector<std::string>& outputs) const std::vector<std::string>& outputs, bool is_double_grad)
: SingleGradOpMaker<imperative::OpBase>( : SingleGradOpMaker<imperative::OpBase>(
type, var_base_map_in, var_base_map_out, attrs, inplace_map), type, var_base_map_in, var_base_map_out, attrs, inplace_map),
name_(name), name_(name),
inputs_(inputs), inputs_(inputs),
outputs_(outputs) {} outputs_(outputs),
is_double_grad_(is_double_grad) {}
protected: protected:
// TODO(chenweihang): The code is duplicated with the previous one, because // TODO(chenweihang): The code is duplicated with the previous one, because
...@@ -574,7 +583,7 @@ class CustomGradOpMaker<imperative::OpBase> ...@@ -574,7 +583,7 @@ class CustomGradOpMaker<imperative::OpBase>
for (auto& in_name : inputs_) { for (auto& in_name : inputs_) {
VLOG(3) << "Custom Operator: GradOpBaseMaker - input: " << in_name; VLOG(3) << "Custom Operator: GradOpBaseMaker - input: " << in_name;
if (!detail::IsGradVar(in_name)) { if (!detail::IsGradVar(in_name, is_double_grad_)) {
if (detail::IsMemberOf(fwd_op_inputs, in_name)) { if (detail::IsMemberOf(fwd_op_inputs, in_name)) {
grad_op->SetInput(in_name, this->Input(in_name)); grad_op->SetInput(in_name, this->Input(in_name));
} else if (detail::IsMemberOf(fwd_op_outputs, in_name)) { } else if (detail::IsMemberOf(fwd_op_outputs, in_name)) {
...@@ -600,6 +609,7 @@ class CustomGradOpMaker<imperative::OpBase> ...@@ -600,6 +609,7 @@ class CustomGradOpMaker<imperative::OpBase>
std::string name_; std::string name_;
std::vector<std::string> inputs_; std::vector<std::string> inputs_;
std::vector<std::string> outputs_; std::vector<std::string> outputs_;
bool is_double_grad_{false};
}; };
//////////// Operator and Kernel Register ////////////// //////////// Operator and Kernel Register //////////////
...@@ -832,21 +842,24 @@ void RegisterOperatorWithMetaInfo(const std::vector<OpMetaInfo>& op_meta_infos, ...@@ -832,21 +842,24 @@ void RegisterOperatorWithMetaInfo(const std::vector<OpMetaInfo>& op_meta_infos,
VLOG(3) << "Custom Operator: backward, op outputs: " VLOG(3) << "Custom Operator: backward, op outputs: "
<< string::join_strings(grad_op_outputs, ','); << string::join_strings(grad_op_outputs, ',');
bool is_double_grad = (i == 2);
// GradOpDescMaker // GradOpDescMaker
info.grad_op_maker_ = [grad_op_name, grad_op_inputs, grad_op_outputs]( info.grad_op_maker_ = [grad_op_name, grad_op_inputs, grad_op_outputs,
is_double_grad](
const OpDesc& fwd_op, const OpDesc& fwd_op,
const std::unordered_set<std::string>& no_grad_set, const std::unordered_set<std::string>& no_grad_set,
std::unordered_map<std::string, std::string>* grad_to_var, std::unordered_map<std::string, std::string>* grad_to_var,
const std::vector<BlockDesc*>& grad_block) { const std::vector<BlockDesc*>& grad_block) {
CustomGradOpMaker<paddle::framework::OpDesc> maker( CustomGradOpMaker<paddle::framework::OpDesc> maker(
fwd_op, no_grad_set, grad_to_var, grad_block, grad_op_name, fwd_op, no_grad_set, grad_to_var, grad_block, grad_op_name,
grad_op_inputs, grad_op_outputs); grad_op_inputs, grad_op_outputs, is_double_grad);
return maker(); return maker();
}; };
// GradOpBaseMaker // GradOpBaseMaker
info.dygraph_grad_op_maker_ = [grad_op_name, grad_op_inputs, info.dygraph_grad_op_maker_ = [grad_op_name, grad_op_inputs,
grad_op_outputs]( grad_op_outputs, is_double_grad](
const std::string& type, const std::string& type,
const imperative::NameVarBaseMap& var_base_map_in, const imperative::NameVarBaseMap& var_base_map_in,
const imperative::NameVarBaseMap& var_base_map_out, const imperative::NameVarBaseMap& var_base_map_out,
...@@ -855,7 +868,7 @@ void RegisterOperatorWithMetaInfo(const std::vector<OpMetaInfo>& op_meta_infos, ...@@ -855,7 +868,7 @@ void RegisterOperatorWithMetaInfo(const std::vector<OpMetaInfo>& op_meta_infos,
const std::map<std::string, std::string>& inplace_map) { const std::map<std::string, std::string>& inplace_map) {
CustomGradOpMaker<paddle::imperative::OpBase> maker( CustomGradOpMaker<paddle::imperative::OpBase> maker(
type, var_base_map_in, var_base_map_out, attrs, inplace_map, type, var_base_map_in, var_base_map_out, attrs, inplace_map,
grad_op_name, grad_op_inputs, grad_op_outputs); grad_op_name, grad_op_inputs, grad_op_outputs, is_double_grad);
maker.SetDygraphDefaultAttrsMap(default_attrs); maker.SetDygraphDefaultAttrsMap(default_attrs);
return maker(); return maker();
}; };
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册