未验证 提交 dffb0b22 编写于 作者: zhouweiwei2014's avatar zhouweiwei2014 提交者: GitHub

fix set_grad_ivar bug of Tensor.backward (#34819)

上级 6326c3ef
...@@ -184,6 +184,12 @@ void TensorAdd(const framework::Variable& src, framework::Variable* dst) { ...@@ -184,6 +184,12 @@ void TensorAdd(const framework::Variable& src, framework::Variable* dst) {
auto data_type = src_tensor.type(); auto data_type = src_tensor.type();
auto place = src_tensor.place(); auto place = src_tensor.place();
PADDLE_ENFORCE_EQ(dst_tensor->type(), data_type,
platform::errors::PreconditionNotMet(
"The data type of source tensor and destination tensor "
"should be equal, Otherwise, the calculation results "
"will be incorrect."));
#define PADDLE_TENSOR_ADD(cpp_type) \ #define PADDLE_TENSOR_ADD(cpp_type) \
if (data_type == framework::DataTypeTrait<cpp_type>::DataType()) { \ if (data_type == framework::DataTypeTrait<cpp_type>::DataType()) { \
TensorAddFunctor<cpp_type> func( \ TensorAddFunctor<cpp_type> func( \
...@@ -422,9 +428,9 @@ void GradientAccumulator::AccumulateGrad() { ...@@ -422,9 +428,9 @@ void GradientAccumulator::AccumulateGrad() {
auto* src = inner_var_->MutableVar(); auto* src = inner_var_->MutableVar();
auto* dst = var_->MutableVar(); auto* dst = var_->MutableVar();
if (!var_->IsEmpty()) { if (!var_->IsEmpty()) {
VLOG(6) << "Leaf Gradient Var(" << var_->Name() VLOG(6) << "Leaf Var(" << var_->Name()
<< ") has been calculated by previous graph, will accumulate on " << ")'s Gradient has been initizlized, will accumulate on "
"previous graph."; "previous gradient.";
if (dst->IsType<framework::LoDTensor>()) { if (dst->IsType<framework::LoDTensor>()) {
if (src->IsType<framework::LoDTensor>()) { if (src->IsType<framework::LoDTensor>()) {
TensorAdd(*src, dst); TensorAdd(*src, dst);
...@@ -444,8 +450,9 @@ void GradientAccumulator::AccumulateGrad() { ...@@ -444,8 +450,9 @@ void GradientAccumulator::AccumulateGrad() {
"Only support LoDTensor and SelectedRows for gradient var")); "Only support LoDTensor and SelectedRows for gradient var"));
} }
} else { } else {
VLOG(6) << "Leaf Gradient Var(" << var_->Name() VLOG(6)
<< ") has not been initialized, not accumulate. Just move"; << "Leaf Var(" << var_->Name()
<< ")'s Gradient has not been initialized, not accumulate. Just move";
*(dst) = std::move(*src); *(dst) = std::move(*src);
var_->SetType(inner_var_->Type()); var_->SetType(inner_var_->Type());
var_->SetDataType(inner_var_->DataType()); var_->SetDataType(inner_var_->DataType());
......
...@@ -277,32 +277,73 @@ std::shared_ptr<VarBase> VarBase::NewVarBase(const platform::Place& dst_place, ...@@ -277,32 +277,73 @@ std::shared_ptr<VarBase> VarBase::NewVarBase(const platform::Place& dst_place,
} }
void VarBase::CopyFrom(const VarBase& src, const bool blocking) { void VarBase::CopyFrom(const VarBase& src, const bool blocking) {
if (SharedVar()->IsEmpty()) { if (src.SharedVar()->IsEmpty()) {
VLOG(3) << "deep copy Variable from " << src.Name() << " to " << Name(); return;
SetPersistable(src.Persistable()); }
VLOG(3) << "Deep copy Tensor from " << src.Name() << " to " << Name();
if (Var().IsInitialized()) {
PADDLE_ENFORCE_EQ(DataType(), src.DataType(),
platform::errors::PreconditionNotMet(
"Tensor %s has different data type with Tensor %s, "
"Tensor Copy cannot be performed!",
Name(), src.Name()));
PADDLE_ENFORCE_EQ(Type(), src.Type(),
platform::errors::PreconditionNotMet(
"Tensor %s has different type with Tensor %s, Tensor "
"Copy cannot be performed!",
Name(), src.Name()));
} else {
SetDataType(src.DataType()); SetDataType(src.DataType());
SetType(src.Type()); SetType(src.Type());
SetOverridedStopGradient(src.OverridedStopGradient()); SetPersistable(src.Persistable());
if (!src.SharedVar()->IsEmpty()) { InnerSetOverridedStopGradient(src.OverridedStopGradient());
const platform::Place& place = src.Place(); }
if (src.Var().IsType<framework::LoDTensor>()) {
auto& src_tensor = src.Var().Get<framework::LoDTensor>(); platform::Place place = src.Place();
auto* dst_tensor = MutableVar()->GetMutable<framework::LoDTensor>(); if (src.Var().IsType<framework::LoDTensor>()) {
dst_tensor->set_lod(src_tensor.lod()); auto& src_tensor = src.Var().Get<framework::LoDTensor>();
framework::TensorCopy(src_tensor, place, dst_tensor); auto* dst_tensor = MutableVar()->GetMutable<framework::LoDTensor>();
} else if (src.Var().IsType<framework::SelectedRows>()) { if (dst_tensor && dst_tensor->IsInitialized()) {
auto& src_selected_rows = src.Var().Get<framework::SelectedRows>(); PADDLE_ENFORCE_EQ(dst_tensor->dims(), src_tensor.dims(),
auto* dst_selected_rows = platform::errors::PreconditionNotMet(
MutableVar()->GetMutable<framework::SelectedRows>(); "Tensor %s has different dims with Tensor %s, "
dst_selected_rows->set_height(src_selected_rows.height()); "Tensor Copy cannot be performed!",
dst_selected_rows->set_rows(src_selected_rows.rows()); Name(), src.Name()));
framework::TensorCopy(src_selected_rows.value(), place, PADDLE_ENFORCE_EQ(dst_tensor->lod(), src_tensor.lod(),
dst_selected_rows->mutable_value()); platform::errors::PreconditionNotMet(
} "Tensor %s has different dims with Tensor %s, "
if (blocking) { "Tensor Copy cannot be performed!",
platform::DeviceContextPool::Instance().Get(place)->Wait(); Name(), src.Name()));
} place = Place();
} else {
dst_tensor->set_lod(src_tensor.lod());
dst_tensor->Resize(src_tensor.dims());
}
framework::TensorCopy(src_tensor, place, dst_tensor);
} else if (src.Var().IsType<framework::SelectedRows>()) {
auto& src_selected_rows = src.Var().Get<framework::SelectedRows>();
auto* dst_selected_rows =
MutableVar()->GetMutable<framework::SelectedRows>();
dst_selected_rows->set_height(src_selected_rows.height());
dst_selected_rows->set_rows(src_selected_rows.rows());
auto& src_tensor = src_selected_rows.value();
auto* dst_tensor = dst_selected_rows->mutable_value();
if (dst_tensor && dst_tensor->IsInitialized()) {
PADDLE_ENFORCE_EQ(dst_tensor->dims(), src_tensor.dims(),
platform::errors::PreconditionNotMet(
"Tensor %s has different dims with Tensor %s, "
"Tensor Copy cannot be performed!",
Name(), src.Name()));
place = Place();
} else {
dst_tensor->Resize(src_tensor.dims());
} }
framework::TensorCopy(src_tensor, place, dst_tensor);
}
if (blocking) {
platform::DeviceContextPool::Instance().Get(place)->Wait();
} }
} }
......
...@@ -110,6 +110,7 @@ class VarBase { ...@@ -110,6 +110,7 @@ class VarBase {
void SetGradVarBase(const VarBase& grad_var) { void SetGradVarBase(const VarBase& grad_var) {
MutableGradVarBase()->CopyFrom(grad_var, true); MutableGradVarBase()->CopyFrom(grad_var, true);
MutableGradVarBase()->SharedVar()->SetIsEmpty(false);
} }
const std::shared_ptr<VarBase>& MutableGradVarBase() { const std::shared_ptr<VarBase>& MutableGradVarBase() {
...@@ -142,6 +143,8 @@ class VarBase { ...@@ -142,6 +143,8 @@ class VarBase {
return grad_var_->MutableVar(); return grad_var_->MutableVar();
} }
bool IsLeaf() const { return var_->IsLeaf(); }
void SetOverridedStopGradient(bool stop_gradient) { void SetOverridedStopGradient(bool stop_gradient) {
var_->SetOverridedStopGradient(stop_gradient); var_->SetOverridedStopGradient(stop_gradient);
if (grad_var_) { if (grad_var_) {
...@@ -151,10 +154,8 @@ class VarBase { ...@@ -151,10 +154,8 @@ class VarBase {
bool OverridedStopGradient() const { return var_->OverridedStopGradient(); } bool OverridedStopGradient() const { return var_->OverridedStopGradient(); }
bool IsLeaf() const { return var_->IsLeaf(); }
void InnerSetOverridedStopGradient(bool stop_gradient) { void InnerSetOverridedStopGradient(bool stop_gradient) {
if (var_->InnerOverridedStopGradient() == -1) { if (InnerOverridedStopGradient() == -1) {
var_->InnerSetOverridedStopGradient(stop_gradient); var_->InnerSetOverridedStopGradient(stop_gradient);
if (grad_var_) { if (grad_var_) {
grad_var_->InnerSetOverridedStopGradient(stop_gradient); grad_var_->InnerSetOverridedStopGradient(stop_gradient);
...@@ -162,6 +163,10 @@ class VarBase { ...@@ -162,6 +163,10 @@ class VarBase {
} }
} }
int InnerOverridedStopGradient() const {
return var_->InnerOverridedStopGradient();
}
void SetPersistable(bool persistable) { var_->SetPersistable(persistable); } void SetPersistable(bool persistable) { var_->SetPersistable(persistable); }
bool Persistable() const { return var_->Persistable(); } bool Persistable() const { return var_->Persistable(); }
......
...@@ -41,7 +41,6 @@ class MyLayer(fluid.Layer): ...@@ -41,7 +41,6 @@ class MyLayer(fluid.Layer):
class MLP(fluid.Layer): class MLP(fluid.Layer):
def __init__(self, input_size): def __init__(self, input_size):
super(MLP, self).__init__() super(MLP, self).__init__()
self._linear1 = None
self._linear1 = Linear( self._linear1 = Linear(
input_size, input_size,
3, 3,
...@@ -607,12 +606,21 @@ class TestImperative(unittest.TestCase): ...@@ -607,12 +606,21 @@ class TestImperative(unittest.TestCase):
mlp2.clear_gradients() mlp2.clear_gradients()
self.assertTrue(np.array_equal(clear_loss.grad.numpy(), [1])) self.assertTrue(np.array_equal(clear_loss.grad.numpy(), [1]))
if ((batch_id + 1) % 10) == 0: if ((batch_id + 1) % 10) % 2 == 0:
mlp1.clear_gradients() mlp1.clear_gradients()
expected_weight1_grad = 0. expected_weight1_grad = 0.
expected_bias1_grad = 0. expected_bias1_grad = 0.
expected_weight2_grad = 0. expected_weight2_grad = 0.
expected_bias2_grad = 0. expected_bias2_grad = 0.
elif ((batch_id + 1) % 10) % 2 == 1:
mlp1.clear_gradients()
mlp1._linear1.weight._set_grad_ivar(
paddle.ones([input_size, 3]))
mlp1._linear2.weight._set_grad_ivar(paddle.ones([3, 4]))
expected_weight1_grad = 1.
expected_bias1_grad = 0.
expected_weight2_grad = 1.
expected_bias2_grad = 0.
with fluid.dygraph.guard(): with fluid.dygraph.guard():
test_single_api(False) test_single_api(False)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册