Unverified · Commit 775ddb5a, authored by Chen Weihang, committed by GitHub

fix double grad var judging (#41072)

Parent 83efeeae
@@ -62,11 +62,6 @@ static T* DynLoad(void* handle, std::string name) {
   return func;
 }
 
-inline static bool IsGradVar(const std::string& var_name) {
-  std::string suffix = kGradVarSuffix;
-  return var_name.rfind(suffix) != std::string::npos;
-}
-
 inline static bool IsDuplicableVar(const std::string& var_name) {
   std::string suffix = kTensorVectorSuffix;
   return var_name.rfind(suffix) != std::string::npos;
@@ -77,6 +72,17 @@ inline static std::string NoGrad(const std::string& var_name) {
   return var_name.substr(0, var_name.size() - kGradVarSuffixSize);
 }
 
+inline static bool IsGradVar(const std::string& var_name, bool is_double_grad) {
+  std::string suffix = kGradVarSuffix;
+  if (!is_double_grad) {
+    return var_name.rfind(suffix) != std::string::npos;
+  } else {
+    // For double-grad ops, X@GRAD is not a grad var but X@GRAD@GRAD is,
+    // so strip one @GRAD suffix before checking.
+    return NoGrad(var_name).rfind(suffix) != std::string::npos;
+  }
+}
+
 inline static bool IsMemberOf(const std::vector<std::string>& vec,
                               const std::string& name) {
   return std::find(vec.cbegin(), vec.cend(), name) != vec.cend();
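
The behavior change is easiest to see on concrete variable names: before this commit, a double-grad op's input X@GRAD was also classified as a grad var. Below is a minimal, self-contained sketch of the new check (kGradVarSuffix is "@GRAD" in Paddle; the NoGrad helper mirrors the one in this file) together with the cases that motivated the fix:

#include <cassert>
#include <string>

// Standalone sketch of the suffix logic in the hunk above.
static const std::string kGradVarSuffix = "@GRAD";
static const size_t kGradVarSuffixSize = kGradVarSuffix.size();

static std::string NoGrad(const std::string& var_name) {
  // substr clamps the count, so names shorter than "@GRAD" pass through.
  return var_name.substr(0, var_name.size() - kGradVarSuffixSize);
}

static bool IsGradVar(const std::string& var_name, bool is_double_grad) {
  std::string suffix = kGradVarSuffix;
  if (!is_double_grad) {
    return var_name.rfind(suffix) != std::string::npos;
  }
  // In a double-grad op, X@GRAD is a forward-style input; only
  // X@GRAD@GRAD is a grad var, so strip one @GRAD before testing.
  return NoGrad(var_name).rfind(suffix) != std::string::npos;
}

int main() {
  assert(IsGradVar("X@GRAD", false));       // grad op: X@GRAD is a grad var
  assert(!IsGradVar("X@GRAD", true));       // double-grad op: it is not
  assert(IsGradVar("X@GRAD@GRAD", true));   // double-grad var
  assert(!IsGradVar("X", false));           // plain forward var
  return 0;
}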
@@ -493,11 +499,12 @@ class CustomGradOpMaker<OpDesc> : public SingleGradOpMaker<OpDesc> {
       std::unordered_map<std::string, std::string>* grad_to_var,
       const std::vector<BlockDesc*>& grad_block, const std::string& name,
       const std::vector<std::string>& inputs,
-      const std::vector<std::string>& outputs)
+      const std::vector<std::string>& outputs, bool is_double_grad)
       : SingleGradOpMaker<OpDesc>(fwd_op, no_grad_set, grad_to_var, grad_block),
         name_(name),
         inputs_(inputs),
-        outputs_(outputs) {}
+        outputs_(outputs),
+        is_double_grad_(is_double_grad) {}
 
  protected:
   void Apply(GradOpPtr<OpDesc> grad_op) const override {
@@ -508,7 +515,7 @@ class CustomGradOpMaker<OpDesc> : public SingleGradOpMaker<OpDesc> {
 
     for (auto& in_name : inputs_) {
       VLOG(3) << "Custom Operator: GradOpDescMaker - input: " << in_name;
-      if (!detail::IsGradVar(in_name)) {
+      if (!detail::IsGradVar(in_name, is_double_grad_)) {
         if (detail::IsMemberOf(fwd_op_inputs, in_name)) {
           grad_op->SetInput(in_name, this->Input(in_name));
         } else if (detail::IsMemberOf(fwd_op_outputs, in_name)) {
@@ -540,6 +547,7 @@ class CustomGradOpMaker<OpDesc> : public SingleGradOpMaker<OpDesc> {
   std::string name_;
   std::vector<std::string> inputs_;
   std::vector<std::string> outputs_;
+  bool is_double_grad_{false};
 };
 
 template <>
@@ -553,12 +561,13 @@ class CustomGradOpMaker<imperative::OpBase>
       const AttributeMap& attrs,
       const std::map<std::string, std::string>& inplace_map,
       const std::string& name, const std::vector<std::string>& inputs,
-      const std::vector<std::string>& outputs)
+      const std::vector<std::string>& outputs, bool is_double_grad)
       : SingleGradOpMaker<imperative::OpBase>(
             type, var_base_map_in, var_base_map_out, attrs, inplace_map),
         name_(name),
         inputs_(inputs),
-        outputs_(outputs) {}
+        outputs_(outputs),
+        is_double_grad_(is_double_grad) {}
 
  protected:
   // TODO(chenweihang): The code is duplicated with the previous one, because
@@ -574,7 +583,7 @@ class CustomGradOpMaker<imperative::OpBase>
 
     for (auto& in_name : inputs_) {
       VLOG(3) << "Custom Operator: GradOpBaseMaker - input: " << in_name;
-      if (!detail::IsGradVar(in_name)) {
+      if (!detail::IsGradVar(in_name, is_double_grad_)) {
         if (detail::IsMemberOf(fwd_op_inputs, in_name)) {
           grad_op->SetInput(in_name, this->Input(in_name));
         } else if (detail::IsMemberOf(fwd_op_outputs, in_name)) {
@@ -600,6 +609,7 @@ class CustomGradOpMaker<imperative::OpBase>
   std::string name_;
   std::vector<std::string> inputs_;
   std::vector<std::string> outputs_;
+  bool is_double_grad_{false};
 };
 
 //////////// Operator and Kernel Register //////////////
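
Note on the plumbing above: both CustomGradOpMaker specializations (static-graph OpDesc and dygraph imperative::OpBase) take the same extra constructor argument and store it as is_double_grad_{false}, so the in-class default keeps any construction path that does not pass the flag on the original single-grad behavior, and both Apply overrides forward the member to detail::IsGradVar when classifying inputs.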
@@ -832,21 +842,24 @@ void RegisterOperatorWithMetaInfo(const std::vector<OpMetaInfo>& op_meta_infos,
     VLOG(3) << "Custom Operator: backward, op outputs: "
             << string::join_strings(grad_op_outputs, ',');
 
+    bool is_double_grad = (i == 2);
+
     // GradOpDescMaker
-    info.grad_op_maker_ = [grad_op_name, grad_op_inputs, grad_op_outputs](
+    info.grad_op_maker_ = [grad_op_name, grad_op_inputs, grad_op_outputs,
+                           is_double_grad](
         const OpDesc& fwd_op,
         const std::unordered_set<std::string>& no_grad_set,
         std::unordered_map<std::string, std::string>* grad_to_var,
         const std::vector<BlockDesc*>& grad_block) {
       CustomGradOpMaker<paddle::framework::OpDesc> maker(
           fwd_op, no_grad_set, grad_to_var, grad_block, grad_op_name,
-          grad_op_inputs, grad_op_outputs);
+          grad_op_inputs, grad_op_outputs, is_double_grad);
       return maker();
     };
 
     // GradOpBaseMaker
     info.dygraph_grad_op_maker_ = [grad_op_name, grad_op_inputs,
-                                   grad_op_outputs](
+                                   grad_op_outputs, is_double_grad](
         const std::string& type,
         const imperative::NameVarBaseMap& var_base_map_in,
         const imperative::NameVarBaseMap& var_base_map_out,
@@ -855,7 +868,7 @@ void RegisterOperatorWithMetaInfo(const std::vector<OpMetaInfo>& op_meta_infos,
         const std::map<std::string, std::string>& inplace_map) {
       CustomGradOpMaker<paddle::imperative::OpBase> maker(
           type, var_base_map_in, var_base_map_out, attrs, inplace_map,
-          grad_op_name, grad_op_inputs, grad_op_outputs);
+          grad_op_name, grad_op_inputs, grad_op_outputs, is_double_grad);
       maker.SetDygraphDefaultAttrsMap(default_attrs);
       return maker();
     };
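
For orientation, these two hunks sit in a loop over the op's registered meta-info entries, which are ordered [forward, grad, double-grad], so i == 2 marks the double-grad op. A condensed, hypothetical view of that loop (names taken from the diff; the body is heavily abbreviated, not the full Paddle source):

// Condensed sketch: entry 0 is the forward op, entry 1 its grad op,
// entry 2 the double-grad op.
for (size_t i = 1; i < op_meta_infos.size(); ++i) {
  const auto& cur_grad_op = op_meta_infos[i];
  bool is_double_grad = (i == 2);  // only the second backward op is double grad
  // ... derive grad_op_name / grad_op_inputs / grad_op_outputs from
  // cur_grad_op, then install grad_op_maker_ and dygraph_grad_op_maker_
  // with is_double_grad captured, as shown in the hunks above.
}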