Unverified commit 84b72671 authored by Leo Chen, committed by GitHub

dygraph_grad_maker supports varbase without grad_var (#21524)

* dygraph_grad_maker supports varbase without grad_var, test=develop

* fix compile, test=develop

* fix test_tracer, test=develop

* follow comments, test=develop
Parent 967eddb5
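In short, the change lets the dynamic-graph grad maker handle a VarBase that was created without a grad VarBase: instead of failing a not-null check, it creates the grad VarBase on demand (keeping its stop_gradient setting in sync with the forward variable), as the hunks below show. The following is a standalone, simplified C++ sketch of that lazy-creation pattern; the VarBase class here is a hypothetical stand-in, not Paddle's real imperative::VarBase.

#include <iostream>
#include <memory>
#include <string>
#include <utility>

// Simplified stand-in for an autograd variable: it may or may not own a
// gradient variable named "<name>@GRAD".
class VarBase {
 public:
  VarBase(bool has_grad, std::string name) : name_(std::move(name)) {
    if (has_grad) {
      grad_var_ = std::make_shared<VarBase>(false, name_ + "@GRAD");
    }
  }

  bool HasGradVar() const { return grad_var_ != nullptr; }

  // The pattern introduced by this commit: if the grad var is missing,
  // create it lazily instead of treating it as an error.
  const std::shared_ptr<VarBase>& MutableGradVarBase() {
    if (grad_var_ == nullptr) {
      grad_var_ = std::make_shared<VarBase>(false, name_ + "@GRAD");
      // In the real change, the new grad var also inherits the forward
      // variable's overridden stop_gradient setting.
    }
    return grad_var_;
  }

  const std::string& Name() const { return name_; }

 private:
  std::string name_;
  std::shared_ptr<VarBase> grad_var_;
};

int main() {
  VarBase x(/*has_grad=*/false, "x_in");                     // no grad var yet
  std::cout << x.HasGradVar() << std::endl;                   // prints 0
  std::cout << x.MutableGradVarBase()->Name() << std::endl;   // prints x_in@GRAD
  std::cout << x.HasGradVar() << std::endl;                   // prints 1
  return 0;
}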
......@@ -13,7 +13,7 @@ function(op_library TARGET)
set(CUDNN_FILE)
set(mkldnn_cc_srcs)
set(MKLDNN_FILE)
set(op_common_deps operator op_registry math_function)
set(op_common_deps operator op_registry math_function layer)
set(options "")
set(oneValueArgs "")
set(multiValueArgs SRCS DEPS)
......
......@@ -127,9 +127,11 @@ class GradOpBaseMakerBase {
for (auto& var_base_temp : iterator->second) {
if (is_grad) {
PADDLE_ENFORCE_NOT_NULL(var_base_temp->GradVarBase(),
"VarBase grad of OP [%s] should not be null",
fw_op_base_->Type());
if (!var_base_temp->HasGradVar()) {
VLOG(6) << "GradVarBase of var " << var_base_temp->Name()
<< " in OP " << fw_op_base_->Type() << " is null";
var_base_temp->MutableGradVarBase();
}
auto grad_var_base_tmp = var_base_temp->GradVarBase();
if (!is_input) {
auto* tensor = grad_var_base_tmp->MutableVar()
......
......@@ -51,6 +51,12 @@ class Engine {
grad_ops_[op] = std::move(op_shared);
}
const std::unordered_set<VarBase*>& GradVars() const { return grad_vars_; }
const std::unordered_map<OpBase*, std::shared_ptr<OpBase>>& GradOps() const {
return grad_ops_;
}
void InsertGradVar(VarBase* grad) { grad_vars_.emplace(grad); }
bool IsGrad(VarBase* var) { return grad_vars_.count(var) > 0; }
......
......@@ -87,6 +87,18 @@ class VarBase {
const std::shared_ptr<VarBase>& GradVarBase() const { return grad_var_; }
void ClearGradVarBase() { grad_var_ = nullptr; }
const std::shared_ptr<VarBase>& MutableGradVarBase() {
if (grad_var_ == nullptr) {
grad_var_ = std::make_shared<VarBase>(false, GradVarName());
// NOTE(zhiqiu): we should keep grad_var_'s stop_gradient property same as
// fwd varbase
grad_var_->SetOverridedStopGradient(overrided_stop_gradient_);
}
return grad_var_;
}
const framework::Variable& GradVar() const {
PADDLE_ENFORCE_NOT_NULL(grad_var_, "Gradient of %s does not exist", name_);
return grad_var_->var_;
......@@ -151,6 +163,7 @@ class VarBase {
}
return rlt;
}
void ClearGradOps() { grad_ops_.clear(); }
const std::string& Name() const { return name_; }
......
......@@ -119,7 +119,7 @@ TEST(test_tracer, test_track_backward_output) {
std::shared_ptr<imperative::VarBase> x_in(
new imperative::VarBase(true, "x_in"));
std::shared_ptr<imperative::VarBase> y_in(
new imperative::VarBase(false, "y_in"));
new imperative::VarBase(true, "y_in"));
x_in->SetOverridedStopGradient(false);
std::shared_ptr<imperative::VarBase> vout(
new imperative::VarBase(true, "vout"));
......@@ -146,7 +146,10 @@ TEST(test_tracer, test_track_backward_output) {
imperative::NameVarBaseMap outs = {out_pair};
framework::AttributeMap mul_attr_map;
mul_attr_map["use_mkldnn"] = false;
ASSERT_ANY_THROW(tracer.TraceOp("mul", ins, outs, mul_attr_map, place, true));
tracer.TraceOp("mul", ins, outs, mul_attr_map, place, true);
auto* engine = tracer.GetDefaultEngine();
ASSERT_NE(engine->GradVars().size(), 0UL);
ASSERT_NE(engine->GradOps().size(), 0UL); // trace_backward already ran.
}
TEST(test_tracer, test_track_backward_input) {
......@@ -157,7 +160,7 @@ TEST(test_tracer, test_track_backward_input) {
std::shared_ptr<imperative::VarBase> y_in(
new imperative::VarBase(true, "y_in"));
std::shared_ptr<imperative::VarBase> vout(
new imperative::VarBase(false, "vout"));
new imperative::VarBase(true, "vout"));
platform::CPUPlace place;
x_in->SetOverridedStopGradient(false);
std::vector<float> src_data(10, 2.0);
......@@ -182,7 +185,10 @@ TEST(test_tracer, test_track_backward_input) {
imperative::NameVarBaseMap outs = {out_pair};
framework::AttributeMap mul_attr_map;
mul_attr_map["use_mkldnn"] = false;
ASSERT_ANY_THROW(tracer.TraceOp("mul", ins, outs, mul_attr_map, place, true));
tracer.TraceOp("mul", ins, outs, mul_attr_map, place, true);
auto* engine = tracer.GetDefaultEngine();
ASSERT_NE(engine->GradVars().size(), 0UL);
ASSERT_NE(engine->GradOps().size(), 0UL); // trace_backward already ran.
}
#if defined(PADDLE_WITH_CUDA)
TEST(test_tracer, test_trace_op_with_multi_device_inputs) {
......@@ -296,6 +302,73 @@ TEST(test_tracer, test_expected_place) {
ASSERT_EQ(platform::is_gpu_place(tracer.ExpectedPlace()), true);
}
TEST(test_tracer, test_var_without_grad_var) {
// Doing a mul
imperative::Tracer tracer;
std::shared_ptr<imperative::VarBase> x_in(
new imperative::VarBase(true, "x_in"));
x_in->ClearGradVarBase();
std::shared_ptr<imperative::VarBase> y_in(
new imperative::VarBase(true, "y_in"));
std::shared_ptr<imperative::VarBase> vout(
new imperative::VarBase(true, "vout"));
x_in->SetOverridedStopGradient(false);
y_in->SetOverridedStopGradient(false);
platform::CPUPlace place;
std::vector<float> src_data(10, 2.0);
std::vector<int64_t> dims1 = {2, 5};
std::vector<int64_t> dims2 = {5, 2};
auto* x_in_tensor = x_in->MutableVar()->GetMutable<framework::LoDTensor>();
auto* y_in_tensor = y_in->MutableVar()->GetMutable<framework::LoDTensor>();
x_in_tensor->Resize(framework::make_ddim(dims1));
auto* mutable_x = x_in_tensor->mutable_data<float>(place);
paddle::memory::Copy(place, mutable_x, place, src_data.data(),
sizeof(float) * src_data.size());
y_in_tensor->Resize(framework::make_ddim(dims2));
auto* mutable_y = y_in_tensor->mutable_data<float>(place);
paddle::memory::Copy(place, mutable_y, place, src_data.data(),
sizeof(float) * src_data.size());
var_pair x_pair = var_pair("X", vb_vector(1, x_in));
var_pair y_pair = var_pair("Y", vb_vector(1, y_in));
var_pair out_pair = var_pair("Out", vb_vector(1, vout));
imperative::NameVarBaseMap ins = {x_pair, y_pair};
imperative::NameVarBaseMap outs = {out_pair};
framework::AttributeMap mul_attr_map;
mul_attr_map["use_mkldnn"] = false;
tracer.TraceOp("mul", ins, outs, mul_attr_map, place, true);
const auto& out_tensor = vout->Var().Get<framework::LoDTensor>();
for (int i = 0; i < vout->Var().Get<framework::LoDTensor>().numel(); i++) {
ASSERT_EQ(out_tensor.data<float>()[i], 20.0);
}
detail::BackwardStrategy back_st;
imperative::Engine* engine = tracer.GetDefaultEngine();
ASSERT_NE(engine->GradVars().size(), 0UL);
ASSERT_NE(engine->GradOps().size(), 0UL); // trace_backward already ran.
engine->Init(vout.get(), back_st);
engine->Execute();
// check the grad
framework::LoDTensor x_grad;
framework::TensorCopySync(x_in->GradVar().Get<framework::LoDTensor>(), place,
&x_grad);
for (int i = 0; i < x_grad.numel(); ++i) {
ASSERT_EQ(x_grad.data<float>()[i], 4.0);
}
framework::LoDTensor y_grad;
framework::TensorCopySync(y_in->GradVar().Get<framework::LoDTensor>(), place,
&y_grad);
for (int i = 0; i < y_grad.numel(); ++i) {
ASSERT_EQ(y_grad.data<float>()[i], 4.0);
}
}
} // namespace imperative
} // namespace paddle
......