diff --git a/paddle/fluid/imperative/gradient_accumulator.cc b/paddle/fluid/imperative/gradient_accumulator.cc
index 57657941ef83f3a3ea0e9e716d49a8b38d22eef8..9f08d0b73fc0870bc5cb215ef0a8633dda5c78ab 100644
--- a/paddle/fluid/imperative/gradient_accumulator.cc
+++ b/paddle/fluid/imperative/gradient_accumulator.cc
@@ -184,6 +184,12 @@ void TensorAdd(const framework::Variable& src, framework::Variable* dst) {
   auto data_type = src_tensor.type();
   auto place = src_tensor.place();
 
+  PADDLE_ENFORCE_EQ(dst_tensor->type(), data_type,
+                    platform::errors::PreconditionNotMet(
+                        "The data type of source tensor and destination tensor "
+                        "should be equal. Otherwise, the calculation results "
+                        "will be incorrect."));
+
 #define PADDLE_TENSOR_ADD(cpp_type)                                  \
   if (data_type == framework::DataTypeTrait<cpp_type>::DataType()) { \
     TensorAddFunctor<cpp_type> func(                                 \
@@ -422,9 +428,9 @@ void GradientAccumulator::AccumulateGrad() {
   auto* src = inner_var_->MutableVar();
   auto* dst = var_->MutableVar();
   if (!var_->IsEmpty()) {
-    VLOG(6) << "Leaf Gradient Var(" << var_->Name()
-            << ") has been calculated by previous graph, will accumulate on "
-               "previous graph.";
+    VLOG(6) << "Leaf Var(" << var_->Name()
+            << ")'s Gradient has been initialized, will accumulate on "
+               "previous gradient.";
     if (dst->IsType<framework::LoDTensor>()) {
       if (src->IsType<framework::LoDTensor>()) {
         TensorAdd(*src, dst);
@@ -444,8 +450,9 @@ void GradientAccumulator::AccumulateGrad() {
           "Only support LoDTensor and SelectedRows for gradient var"));
     }
   } else {
-    VLOG(6) << "Leaf Gradient Var(" << var_->Name()
-            << ") has not been initialized, not accumulate. Just move";
+    VLOG(6)
+        << "Leaf Var(" << var_->Name()
+        << ")'s Gradient has not been initialized, not accumulate. Just move";
     *(dst) = std::move(*src);
     var_->SetType(inner_var_->Type());
     var_->SetDataType(inner_var_->DataType());
diff --git a/paddle/fluid/imperative/layer.cc b/paddle/fluid/imperative/layer.cc
index 6e28ecd9971abcee51e4c3910896eadae7b01c0a..53ae5b8127fdba5dd68ddc6748dc35e9fe7ae8ec 100644
--- a/paddle/fluid/imperative/layer.cc
+++ b/paddle/fluid/imperative/layer.cc
@@ -277,32 +277,73 @@ std::shared_ptr<VarBase> VarBase::NewVarBase(const platform::Place& dst_place,
 }
 
 void VarBase::CopyFrom(const VarBase& src, const bool blocking) {
-  if (SharedVar()->IsEmpty()) {
-    VLOG(3) << "deep copy Variable from " << src.Name() << " to " << Name();
-    SetPersistable(src.Persistable());
+  if (src.SharedVar()->IsEmpty()) {
+    return;
+  }
+
+  VLOG(3) << "Deep copy Tensor from " << src.Name() << " to " << Name();
+  if (Var().IsInitialized()) {
+    PADDLE_ENFORCE_EQ(DataType(), src.DataType(),
+                      platform::errors::PreconditionNotMet(
+                          "Tensor %s has different data type with Tensor %s, "
+                          "Tensor Copy cannot be performed!",
+                          Name(), src.Name()));
+    PADDLE_ENFORCE_EQ(Type(), src.Type(),
+                      platform::errors::PreconditionNotMet(
+                          "Tensor %s has different type with Tensor %s, Tensor "
+                          "Copy cannot be performed!",
+                          Name(), src.Name()));
+  } else {
     SetDataType(src.DataType());
     SetType(src.Type());
-    SetOverridedStopGradient(src.OverridedStopGradient());
-    if (!src.SharedVar()->IsEmpty()) {
-      const platform::Place& place = src.Place();
-      if (src.Var().IsType<framework::LoDTensor>()) {
-        auto& src_tensor = src.Var().Get<framework::LoDTensor>();
-        auto* dst_tensor = MutableVar()->GetMutable<framework::LoDTensor>();
-        dst_tensor->set_lod(src_tensor.lod());
-        framework::TensorCopy(src_tensor, place, dst_tensor);
-      } else if (src.Var().IsType<framework::SelectedRows>()) {
-        auto& src_selected_rows = src.Var().Get<framework::SelectedRows>();
-        auto* dst_selected_rows =
-            MutableVar()->GetMutable<framework::SelectedRows>();
-        dst_selected_rows->set_height(src_selected_rows.height());
-        dst_selected_rows->set_rows(src_selected_rows.rows());
-        framework::TensorCopy(src_selected_rows.value(), place,
-                              dst_selected_rows->mutable_value());
-      }
-      if (blocking) {
-        platform::DeviceContextPool::Instance().Get(place)->Wait();
-      }
+    SetPersistable(src.Persistable());
+    InnerSetOverridedStopGradient(src.OverridedStopGradient());
+  }
+
+  platform::Place place = src.Place();
+  if (src.Var().IsType<framework::LoDTensor>()) {
+    auto& src_tensor = src.Var().Get<framework::LoDTensor>();
+    auto* dst_tensor = MutableVar()->GetMutable<framework::LoDTensor>();
+    if (dst_tensor && dst_tensor->IsInitialized()) {
+      PADDLE_ENFORCE_EQ(dst_tensor->dims(), src_tensor.dims(),
+                        platform::errors::PreconditionNotMet(
+                            "Tensor %s has different dims with Tensor %s, "
+                            "Tensor Copy cannot be performed!",
+                            Name(), src.Name()));
+      PADDLE_ENFORCE_EQ(dst_tensor->lod(), src_tensor.lod(),
+                        platform::errors::PreconditionNotMet(
+                            "Tensor %s has different lod with Tensor %s, "
+                            "Tensor Copy cannot be performed!",
+                            Name(), src.Name()));
+      place = Place();
+    } else {
+      dst_tensor->set_lod(src_tensor.lod());
+      dst_tensor->Resize(src_tensor.dims());
+    }
+    framework::TensorCopy(src_tensor, place, dst_tensor);
+  } else if (src.Var().IsType<framework::SelectedRows>()) {
+    auto& src_selected_rows = src.Var().Get<framework::SelectedRows>();
+    auto* dst_selected_rows =
+        MutableVar()->GetMutable<framework::SelectedRows>();
+    dst_selected_rows->set_height(src_selected_rows.height());
+    dst_selected_rows->set_rows(src_selected_rows.rows());
+
+    auto& src_tensor = src_selected_rows.value();
+    auto* dst_tensor = dst_selected_rows->mutable_value();
+    if (dst_tensor && dst_tensor->IsInitialized()) {
+      PADDLE_ENFORCE_EQ(dst_tensor->dims(), src_tensor.dims(),
+                        platform::errors::PreconditionNotMet(
+                            "Tensor %s has different dims with Tensor %s, "
+                            "Tensor Copy cannot be performed!",
+                            Name(), src.Name()));
+      place = Place();
+    } else {
+      dst_tensor->Resize(src_tensor.dims());
     }
+    framework::TensorCopy(src_tensor, place, dst_tensor);
+  }
+  if (blocking) {
+    platform::DeviceContextPool::Instance().Get(place)->Wait();
   }
 }
diff --git a/paddle/fluid/imperative/layer.h b/paddle/fluid/imperative/layer.h
index 56e16ba199707c37031b55b65057cd95ff5ed805..16580627ed1964c6cfc81a48b15f26d0b2459a78 100644
--- a/paddle/fluid/imperative/layer.h
+++ b/paddle/fluid/imperative/layer.h
@@ -110,6 +110,7 @@ class VarBase {
 
   void SetGradVarBase(const VarBase& grad_var) {
     MutableGradVarBase()->CopyFrom(grad_var, true);
+    MutableGradVarBase()->SharedVar()->SetIsEmpty(false);
   }
 
   const std::shared_ptr<VarBase>& MutableGradVarBase() {
@@ -142,6 +143,8 @@ class VarBase {
     return grad_var_->MutableVar();
   }
 
+  bool IsLeaf() const { return var_->IsLeaf(); }
+
   void SetOverridedStopGradient(bool stop_gradient) {
     var_->SetOverridedStopGradient(stop_gradient);
     if (grad_var_) {
@@ -151,10 +154,8 @@ class VarBase {
 
   bool OverridedStopGradient() const { return var_->OverridedStopGradient(); }
 
-  bool IsLeaf() const { return var_->IsLeaf(); }
-
   void InnerSetOverridedStopGradient(bool stop_gradient) {
-    if (var_->InnerOverridedStopGradient() == -1) {
+    if (InnerOverridedStopGradient() == -1) {
       var_->InnerSetOverridedStopGradient(stop_gradient);
       if (grad_var_) {
         grad_var_->InnerSetOverridedStopGradient(stop_gradient);
@@ -162,6 +163,10 @@ class VarBase {
     }
   }
 
+  int InnerOverridedStopGradient() const {
+    return var_->InnerOverridedStopGradient();
+  }
+
   void SetPersistable(bool persistable) { var_->SetPersistable(persistable); }
 
   bool Persistable() const { return var_->Persistable(); }
diff --git a/python/paddle/fluid/tests/unittests/test_imperative_basic.py b/python/paddle/fluid/tests/unittests/test_imperative_basic.py
index 1cdb57c540ac4dec982689c21137b945906666fe..3aed1af59795a6f6bd99483682feefb2393e4fe9 100644
--- a/python/paddle/fluid/tests/unittests/test_imperative_basic.py
+++ b/python/paddle/fluid/tests/unittests/test_imperative_basic.py
@@ -41,7 +41,6 @@ class MyLayer(fluid.Layer):
 class MLP(fluid.Layer):
     def __init__(self, input_size):
         super(MLP, self).__init__()
-        self._linear1 = None
         self._linear1 = Linear(
             input_size,
             3,
@@ -607,12 +606,21 @@ class TestImperative(unittest.TestCase):
                 mlp2.clear_gradients()
                 self.assertTrue(np.array_equal(clear_loss.grad.numpy(), [1]))
-                if ((batch_id + 1) % 10) == 0:
+                if ((batch_id + 1) % 10) % 2 == 0:
                     mlp1.clear_gradients()
                     expected_weight1_grad = 0.
                     expected_bias1_grad = 0.
                     expected_weight2_grad = 0.
                     expected_bias2_grad = 0.
+                elif ((batch_id + 1) % 10) % 2 == 1:
+                    mlp1.clear_gradients()
+                    mlp1._linear1.weight._set_grad_ivar(
+                        paddle.ones([input_size, 3]))
+                    mlp1._linear2.weight._set_grad_ivar(paddle.ones([3, 4]))
+                    expected_weight1_grad = 1.
+                    expected_bias1_grad = 0.
+                    expected_weight2_grad = 1.
+                    expected_bias2_grad = 0.
 
         with fluid.dygraph.guard():
             test_single_api(False)
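
For reference, a minimal, hypothetical usage sketch of the behavior the new test branch exercises: seed a leaf parameter's gradient through the internal `_set_grad_ivar` API and let a later `backward()` accumulate onto the seeded value instead of overwriting it. The layer size, input shape, and expected numbers below are illustrative assumptions, not taken from the patch.

# Sketch only (not part of the patch); assumes the 2.0-era dygraph API used
# by the unit test above.
import numpy as np
import paddle
import paddle.fluid as fluid
from paddle.fluid.dygraph import Linear

with fluid.dygraph.guard():
    linear = Linear(2, 3)  # weight shape: [2, 3]
    x = fluid.dygraph.to_variable(np.ones([4, 2], dtype='float32'))
    loss = fluid.layers.reduce_sum(linear(x))

    # Seed the leaf weight's gradient with ones, as the test does.
    linear.weight._set_grad_ivar(paddle.ones([2, 3]))

    loss.backward()
    # With an all-ones [4, 2] input, d(loss)/d(weight) is 4 for every entry,
    # so each accumulated entry should come out as 1. + 4. = 5.
    print(linear.weight.grad.numpy())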