Commit 0318f47e authored by Yu Yang, committed by QI JUN

Enhance in backward (#5262)

Set the gradient's data type based on its forward variable
Parent 1363ddb6
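In brief: the backward pass used to hard-code every newly created gradient variable to FP32 (the removed FIXME in the diff below). With this change, CreateGradVarInBlock records which gradient variables it creates, and after InferVarType it derives each one's forward-variable name by stripping the @GRAD suffix (via the new FwdName helper) and copies that variable's data type onto the gradient, warning and falling back to FP32 when no forward variable exists.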
@@ -18,6 +18,7 @@
 #include <deque>
 #include <list>
 #include <memory>
+#include <unordered_set>

 #include "paddle/framework/block_desc.h"
 #include "paddle/framework/op_registry.h"
@@ -285,6 +286,15 @@ static bool AllGradInSet(const std::vector<std::string>& names,
   return true;
 }

+static std::string FwdName(const std::string& grad_name) {
+  auto pos = grad_name.find("@GRAD");
+  if (pos == std::string::npos) {
+    return "";
+  } else {
+    return grad_name.substr(0, pos);
+  }
+}
+
 static void CreateGradVarInBlock(
     size_t grad_op_start_index,
     const std::unordered_map<std::string, std::string>& param_name_map,
@@ -294,6 +304,7 @@ static void CreateGradVarInBlock(
   for (size_t op_index = grad_op_start_index; op_index < ops.size();
        ++op_index) {
     bool need_infer_shape = false;
+    std::unordered_set<std::string> new_vars;
     ForEachVarName(ops[op_index]->Outputs(),
                    [&](const std::string& grad_var_name) {
                      if (block_desc->HasVar(grad_var_name)) {
@@ -301,8 +312,7 @@ static void CreateGradVarInBlock(
                      }
                      need_infer_shape = true;
                      auto var = block_desc->Var(grad_var_name);
-                     // FIXME(qiao) infer the datatype
-                     var->SetDataType(framework::DataType::FP32);
+                     new_vars.insert(var->Name());
                      auto it = param_name_map.find(grad_var_name);
                      if (it == param_name_map.end()) {
                        return false;
@@ -316,6 +326,21 @@ static void CreateGradVarInBlock(
                    });
     if (need_infer_shape) {
       ops[op_index]->InferVarType(block_desc);
+      for (auto& arg : ops[op_index]->OutputArgumentNames()) {
+        if (new_vars.find(arg) == new_vars.end()) {
+          continue;
+        }
+        auto pname = FwdName(arg);
+        auto* param = block_desc->FindVar(pname);
+        auto* grad = block_desc->FindVar(arg);
+        if (param == nullptr) {
+          LOG(WARNING) << "Cannot find forward variable of " << arg
+                       << ". Set its gradient to FP32";
+          grad->SetDataType(DataType::FP32);
+        } else {
+          grad->SetDataType(param->GetDataType());
+        }
+      }
       ops[op_index]->InferShape(*block_desc);
     }
   }
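For readers skimming the patch, the core idea is that a gradient variable named with the @GRAD suffix (e.g. W@GRAD) inherits the data type of its forward counterpart (W), falling back to FP32 with a warning when no forward variable is found. The following standalone sketch mirrors that logic outside the Paddle codebase; DataType, VarTable, and GradDataType here are illustrative stand-ins, not Paddle APIs, and only FwdName is copied from the patch itself.

#include <iostream>
#include <string>
#include <unordered_map>

// Illustrative stand-ins (not Paddle APIs): a data-type tag and a table
// mapping a variable name to its data type, playing the role of the
// block's variable descriptions.
enum class DataType { FP32, FP64 };
using VarTable = std::unordered_map<std::string, DataType>;

// Mirrors FwdName() from the patch: strip the "@GRAD" suffix to recover
// the forward variable's name; return "" if the suffix is absent.
std::string FwdName(const std::string& grad_name) {
  auto pos = grad_name.find("@GRAD");
  return pos == std::string::npos ? "" : grad_name.substr(0, pos);
}

// Mirrors the new loop body: a gradient inherits the forward variable's
// data type, and falls back to FP32 (with a LOG(WARNING) in the real
// code) when no forward variable exists.
DataType GradDataType(const VarTable& vars, const std::string& grad_name) {
  auto it = vars.find(FwdName(grad_name));
  return it == vars.end() ? DataType::FP32 : it->second;
}

int main() {
  VarTable vars{{"W", DataType::FP64}};
  // "W@GRAD" inherits FP64 from "W"; "b@GRAD" has no forward var, so FP32.
  std::cout << (GradDataType(vars, "W@GRAD") == DataType::FP64) << "\n";  // 1
  std::cout << (GradDataType(vars, "b@GRAD") == DataType::FP32) << "\n";  // 1
}

Note that the patch applies this fix-up only to names recorded in new_vars, i.e. gradient variables freshly created for the current op, so variables that already existed in the block keep whatever data type they had.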