Unverified commit af4f018a, authored by Jiabin Yang, committed by GitHub

[Bug fix] Fix dygraph double grad dtype error (#36125)

* fix dygraph double grad dtype error when called in a higher-order differentiation scenario

* re-trigger CI

* add test for partial_engine.cc
Parent 0e07f20e
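
As a point of reference, here is a minimal sketch of the higher-order differentiation scenario the commit message refers to, using only the public dygraph API. The op (tanh), the shapes, and the variable names are illustrative assumptions, not code from this PR.

import paddle

x = paddle.rand(shape=[2, 3], dtype='float64')
x.stop_gradient = False
y = paddle.tanh(x)

# First-order grad; create_graph=True keeps the result differentiable so a
# second (double) grad pass can run on top of it.
dx = paddle.grad([y], [x], create_graph=True)[0]

# Second-order grad. Before this fix, tensors the grad engine filled in during
# this pass could carry a default dtype that mismatched the forward float64.
ddx = paddle.grad([dx], [x])[0]
print(ddx.dtype)  # expected: paddle.float64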
@@ -1589,14 +1589,15 @@ void OperatorWithKernel::ParseInputDataType(
               "not initialized.",
               Type(), name, ctx.InputNames(name).at(i)));
       proto::VarType::Type tmp = t->type();
-      PADDLE_ENFORCE(
-          tmp == *data_type || *data_type == default_data_type,
-          platform::errors::InvalidArgument(
-              "The DataType of %s Op's duplicable Variable %s must be "
-              "consistent. The current variable type is (%s), but the "
-              "previous variable type is (%s).",
-              Type(), name, DataTypeToString(tmp),
-              DataTypeToString(*data_type)));
+      PADDLE_ENFORCE(tmp == *data_type || *data_type == default_data_type,
+                     platform::errors::InvalidArgument(
+                         "The DataType of %s Op's duplicable or different "
+                         "slot Variable %s must be "
+                         "consistent or register GetExpectedKernelType. The "
+                         "current variable type is (%s), but the "
+                         "previous variable type is (%s).",
+                         Type(), name, DataTypeToString(tmp),
+                         DataTypeToString(*data_type)));
       *data_type = tmp;
     }
   }
......
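
The hunk above keeps the rule that all tensors feeding one duplicable input slot must share a dtype, and only rewords the error to point at GetExpectedKernelType as the escape hatch for ops that intentionally resolve a dtype mix themselves. A hedged Python-level illustration of the symptom when the check fires; the choice of add_n and the exact error text are assumptions for illustration, not part of this PR.

import paddle

a = paddle.ones([2, 2], dtype='float32')
b = paddle.ones([2, 2], dtype='float64')
try:
    paddle.add_n([a, b])  # two dtypes in one duplicable input slot
except Exception as err:
    # Expected to surface an InvalidArgument error like the one built above;
    # the exact wording differs across Paddle versions.
    print(type(err).__name__)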
@@ -307,7 +307,15 @@ static void FillConstantLike(const VariableWrapper &ref_var,
   auto *dst_tensor = dst_var->MutableVar()->GetMutable<framework::LoDTensor>();
   auto *dev_ctx = platform::DeviceContextPool::Instance().Get(place);
   dst_tensor->Resize(ref_tensor.dims());
-  dst_tensor->mutable_data(place, ref_var.DataType());
+  // TODO(jiabin): Ugly fix. We have both fwd_data_type_ and data_type_, since
+  // in the grad pass we can't get data_type_ directly. We need to check
+  // whether we can use only the default data_type here for now.
+  if (ref_var.ForwardDataType() != -1) {
+    dst_tensor->mutable_data(place, ref_var.ForwardDataType());
+  } else {
+    dst_tensor->mutable_data(place, ref_var.DataType());
+  }
   operators::math::set_constant(*dev_ctx, dst_tensor, value);
 }
......
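
The FillConstantLike change above picks the dtype for gradients the partial-grad engine has to materialize as constant-filled tensors: the recorded forward dtype when one is set, otherwise the variable's own data type. Below is a quick dtype check across both precisions, in the spirit of (but not identical to) the test this PR adds; the ops and shapes are illustrative assumptions.

import paddle

def second_order_dtype(dtype):
    x = paddle.rand(shape=[3], dtype=dtype)
    x.stop_gradient = False
    y = paddle.tanh(x)
    dx = paddle.grad([y], [x], create_graph=True)[0]
    ddx = paddle.grad([dx], [x])[0]
    return ddx.dtype

for dtype in ('float32', 'float64'):
    # With the fix, constants filled inside the double-grad pass inherit the
    # forward dtype, so the second-order grad keeps the input's precision.
    print(dtype, second_order_dtype(dtype))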
@@ -162,6 +162,7 @@ class VariableWrapper {
       return tensor->type();
     } else {
       VLOG(6) << "The tensor of variable " << name_ << " is not initialized";
+
       return data_type_;
     }
   }
......
@@ -215,10 +215,6 @@ class TestJacobianFloat64(TestJacobian):
         self.x = paddle.rand(shape=self.shape, dtype=self.dtype)
         self.y = paddle.rand(shape=self.shape, dtype=self.dtype)
 
-    # NOTE(levi): skip this test case temporarily.
-    def test_create_graph_true(self):
-        pass
-
 
 if __name__ == "__main__":
     unittest.main()
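
The removed override un-skips test_create_graph_true for the float64 Jacobian case, so the base class's real test runs again. A rough sketch of what such a case exercises, written directly with paddle.grad rather than the test's own Jacobian helper (whose module path is not shown in this diff); the ops, shapes, and names are illustrative assumptions.

import paddle

x = paddle.rand(shape=[4, 4], dtype='float64')
y = paddle.rand(shape=[4, 4], dtype='float64')
x.stop_gradient = False
y.stop_gradient = False
out = (x * y).sum()

# create_graph=True keeps the first-order grads differentiable, which is what
# the create_graph=True test variants rely on; with this fix the float64 case
# no longer trips over a mismatched dtype in the double-grad pass.
dx, dy = paddle.grad([out], [x, y], create_graph=True)
ddy = paddle.grad([dx.sum()], [y])[0]  # dx equals y here, so this is ones_like(y)
print(dx.dtype, ddy.dtype)  # both expected to be paddle.float64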