Unverified commit dffb0b22, authored by zhouweiwei2014, committed by GitHub

fix set_grad_ivar bug of Tensor.backward (#34819)

Parent 6326c3ef
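For orientation, a minimal dygraph sketch of the behaviour this commit targets, inferred from the commit title and the updated test further below; the snippet is illustrative and not part of the diff:

import paddle

linear = paddle.nn.Linear(4, 3)

# Preset the weight's gradient through the private API the test exercises.
linear.weight._set_grad_ivar(paddle.ones([4, 3]))

loss = paddle.sum(linear(paddle.rand([2, 4])))
loss.backward()

# With the fix, the preset gradient is treated as initialized, so backward()
# accumulates onto it instead of discarding it.
print(linear.weight.grad.numpy())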
......@@ -184,6 +184,12 @@ void TensorAdd(const framework::Variable& src, framework::Variable* dst) {
auto data_type = src_tensor.type();
auto place = src_tensor.place();
PADDLE_ENFORCE_EQ(dst_tensor->type(), data_type,
platform::errors::PreconditionNotMet(
"The data type of source tensor and destination tensor "
"should be equal, Otherwise, the calculation results "
"will be incorrect."));
#define PADDLE_TENSOR_ADD(cpp_type) \
if (data_type == framework::DataTypeTrait<cpp_type>::DataType()) { \
TensorAddFunctor<cpp_type> func( \
......@@ -422,9 +428,9 @@ void GradientAccumulator::AccumulateGrad() {
auto* src = inner_var_->MutableVar();
auto* dst = var_->MutableVar();
if (!var_->IsEmpty()) {
VLOG(6) << "Leaf Gradient Var(" << var_->Name()
<< ") has been calculated by previous graph, will accumulate on "
"previous graph.";
VLOG(6) << "Leaf Var(" << var_->Name()
<< ")'s Gradient has been initizlized, will accumulate on "
"previous gradient.";
if (dst->IsType<framework::LoDTensor>()) {
if (src->IsType<framework::LoDTensor>()) {
TensorAdd(*src, dst);
......@@ -444,8 +450,9 @@ void GradientAccumulator::AccumulateGrad() {
"Only support LoDTensor and SelectedRows for gradient var"));
}
} else {
VLOG(6) << "Leaf Gradient Var(" << var_->Name()
<< ") has not been initialized, not accumulate. Just move";
VLOG(6)
<< "Leaf Var(" << var_->Name()
<< ")'s Gradient has not been initialized, not accumulate. Just move";
*(dst) = std::move(*src);
var_->SetType(inner_var_->Type());
var_->SetDataType(inner_var_->DataType());
......
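The AccumulateGrad hunk above distinguishes two cases for a leaf variable's gradient: if it is already initialized, the incoming gradient is added onto it (TensorAdd); otherwise it is simply moved in. A user-level illustration with arbitrary values, not taken from the commit:

import paddle

x = paddle.to_tensor([2.0, 3.0], stop_gradient=False)

(x * x).sum().backward()   # leaf grad empty -> gradient is moved in
print(x.grad.numpy())      # 2 * x = [4., 6.]

(3 * x).sum().backward()   # leaf grad initialized -> TensorAdd accumulates
print(x.grad.numpy())      # [4., 6.] + [3., 3.] = [7., 9.]

x.clear_gradient()         # resets the accumulated gradient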
......@@ -277,32 +277,73 @@ std::shared_ptr<VarBase> VarBase::NewVarBase(const platform::Place& dst_place,
}
void VarBase::CopyFrom(const VarBase& src, const bool blocking) {
if (SharedVar()->IsEmpty()) {
VLOG(3) << "deep copy Variable from " << src.Name() << " to " << Name();
SetPersistable(src.Persistable());
if (src.SharedVar()->IsEmpty()) {
return;
}
VLOG(3) << "Deep copy Tensor from " << src.Name() << " to " << Name();
if (Var().IsInitialized()) {
PADDLE_ENFORCE_EQ(DataType(), src.DataType(),
platform::errors::PreconditionNotMet(
"Tensor %s has different data type with Tensor %s, "
"Tensor Copy cannot be performed!",
Name(), src.Name()));
PADDLE_ENFORCE_EQ(Type(), src.Type(),
platform::errors::PreconditionNotMet(
"Tensor %s has different type with Tensor %s, Tensor "
"Copy cannot be performed!",
Name(), src.Name()));
} else {
SetDataType(src.DataType());
SetType(src.Type());
SetOverridedStopGradient(src.OverridedStopGradient());
if (!src.SharedVar()->IsEmpty()) {
const platform::Place& place = src.Place();
if (src.Var().IsType<framework::LoDTensor>()) {
auto& src_tensor = src.Var().Get<framework::LoDTensor>();
auto* dst_tensor = MutableVar()->GetMutable<framework::LoDTensor>();
dst_tensor->set_lod(src_tensor.lod());
framework::TensorCopy(src_tensor, place, dst_tensor);
} else if (src.Var().IsType<framework::SelectedRows>()) {
auto& src_selected_rows = src.Var().Get<framework::SelectedRows>();
auto* dst_selected_rows =
MutableVar()->GetMutable<framework::SelectedRows>();
dst_selected_rows->set_height(src_selected_rows.height());
dst_selected_rows->set_rows(src_selected_rows.rows());
framework::TensorCopy(src_selected_rows.value(), place,
dst_selected_rows->mutable_value());
}
if (blocking) {
platform::DeviceContextPool::Instance().Get(place)->Wait();
}
SetPersistable(src.Persistable());
InnerSetOverridedStopGradient(src.OverridedStopGradient());
}
platform::Place place = src.Place();
if (src.Var().IsType<framework::LoDTensor>()) {
auto& src_tensor = src.Var().Get<framework::LoDTensor>();
auto* dst_tensor = MutableVar()->GetMutable<framework::LoDTensor>();
if (dst_tensor && dst_tensor->IsInitialized()) {
PADDLE_ENFORCE_EQ(dst_tensor->dims(), src_tensor.dims(),
platform::errors::PreconditionNotMet(
"Tensor %s has different dims with Tensor %s, "
"Tensor Copy cannot be performed!",
Name(), src.Name()));
PADDLE_ENFORCE_EQ(dst_tensor->lod(), src_tensor.lod(),
platform::errors::PreconditionNotMet(
"Tensor %s has different dims with Tensor %s, "
"Tensor Copy cannot be performed!",
Name(), src.Name()));
place = Place();
} else {
dst_tensor->set_lod(src_tensor.lod());
dst_tensor->Resize(src_tensor.dims());
}
framework::TensorCopy(src_tensor, place, dst_tensor);
} else if (src.Var().IsType<framework::SelectedRows>()) {
auto& src_selected_rows = src.Var().Get<framework::SelectedRows>();
auto* dst_selected_rows =
MutableVar()->GetMutable<framework::SelectedRows>();
dst_selected_rows->set_height(src_selected_rows.height());
dst_selected_rows->set_rows(src_selected_rows.rows());
auto& src_tensor = src_selected_rows.value();
auto* dst_tensor = dst_selected_rows->mutable_value();
if (dst_tensor && dst_tensor->IsInitialized()) {
PADDLE_ENFORCE_EQ(dst_tensor->dims(), src_tensor.dims(),
platform::errors::PreconditionNotMet(
"Tensor %s has different dims with Tensor %s, "
"Tensor Copy cannot be performed!",
Name(), src.Name()));
place = Place();
} else {
dst_tensor->Resize(src_tensor.dims());
}
framework::TensorCopy(src_tensor, place, dst_tensor);
}
if (blocking) {
platform::DeviceContextPool::Instance().Get(place)->Wait();
}
}
......
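A consequence of the checks added to VarBase::CopyFrom above: when the destination tensor is already initialized, the source must match its data type, dims, and LoD, and the copy lands on the destination's place. A hedged sketch of how this might surface through the private `_set_grad_ivar` path used in the test, assuming that API goes through SetGradVarBase/CopyFrom as the header change suggests and that the enforce is raised as a Python exception:

import paddle

linear = paddle.nn.Linear(4, 3)
paddle.sum(linear(paddle.rand([2, 4]))).backward()   # gradient now initialized

# Same shape as the existing gradient: copied into the initialized buffer.
linear.weight._set_grad_ivar(paddle.ones([4, 3]))

# Different shape: the dims enforce in CopyFrom should reject the copy.
try:
    linear.weight._set_grad_ivar(paddle.ones([3, 3]))
except Exception as exc:
    print("rejected:", exc)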
......@@ -110,6 +110,7 @@ class VarBase {
void SetGradVarBase(const VarBase& grad_var) {
MutableGradVarBase()->CopyFrom(grad_var, true);
MutableGradVarBase()->SharedVar()->SetIsEmpty(false);
}
const std::shared_ptr<VarBase>& MutableGradVarBase() {
......@@ -142,6 +143,8 @@ class VarBase {
return grad_var_->MutableVar();
}
bool IsLeaf() const { return var_->IsLeaf(); }
void SetOverridedStopGradient(bool stop_gradient) {
var_->SetOverridedStopGradient(stop_gradient);
if (grad_var_) {
......@@ -151,10 +154,8 @@ class VarBase {
bool OverridedStopGradient() const { return var_->OverridedStopGradient(); }
bool IsLeaf() const { return var_->IsLeaf(); }
void InnerSetOverridedStopGradient(bool stop_gradient) {
if (var_->InnerOverridedStopGradient() == -1) {
if (InnerOverridedStopGradient() == -1) {
var_->InnerSetOverridedStopGradient(stop_gradient);
if (grad_var_) {
grad_var_->InnerSetOverridedStopGradient(stop_gradient);
......@@ -162,6 +163,10 @@ class VarBase {
}
}
int InnerOverridedStopGradient() const {
return var_->InnerOverridedStopGradient();
}
void SetPersistable(bool persistable) { var_->SetPersistable(persistable); }
bool Persistable() const { return var_->Persistable(); }
......
......@@ -41,7 +41,6 @@ class MyLayer(fluid.Layer):
class MLP(fluid.Layer):
def __init__(self, input_size):
super(MLP, self).__init__()
self._linear1 = None
self._linear1 = Linear(
input_size,
3,
......@@ -607,12 +606,21 @@ class TestImperative(unittest.TestCase):
mlp2.clear_gradients()
self.assertTrue(np.array_equal(clear_loss.grad.numpy(), [1]))
if ((batch_id + 1) % 10) == 0:
if ((batch_id + 1) % 10) % 2 == 0:
mlp1.clear_gradients()
expected_weight1_grad = 0.
expected_bias1_grad = 0.
expected_weight2_grad = 0.
expected_bias2_grad = 0.
elif ((batch_id + 1) % 10) % 2 == 1:
mlp1.clear_gradients()
mlp1._linear1.weight._set_grad_ivar(
paddle.ones([input_size, 3]))
mlp1._linear2.weight._set_grad_ivar(paddle.ones([3, 4]))
expected_weight1_grad = 1.
expected_bias1_grad = 0.
expected_weight2_grad = 1.
expected_bias2_grad = 0.
with fluid.dygraph.guard():
test_single_api(False)
......