未验证 提交 af4f018a 编写于 作者: J Jiabin Yang 提交者: GitHub

【Bug fix】Fix dygraph double grad dtype error (#36125)

* fix dygraph double grad dtype error when calling for high differential senario

* reinvoke ci

* add test for partial_engine.cc
上级 0e07f20e
......@@ -1589,14 +1589,15 @@ void OperatorWithKernel::ParseInputDataType(
"not initialized.",
Type(), name, ctx.InputNames(name).at(i)));
proto::VarType::Type tmp = t->type();
PADDLE_ENFORCE(
tmp == *data_type || *data_type == default_data_type,
platform::errors::InvalidArgument(
"The DataType of %s Op's duplicable Variable %s must be "
"consistent. The current variable type is (%s), but the "
"previous variable type is (%s).",
Type(), name, DataTypeToString(tmp),
DataTypeToString(*data_type)));
PADDLE_ENFORCE(tmp == *data_type || *data_type == default_data_type,
platform::errors::InvalidArgument(
"The DataType of %s Op's duplicable or different "
"slot Variable %s must be "
"consistent or reigster GetExpectedKernelType. The "
"current variable type is (%s), but the "
"previous variable type is (%s).",
Type(), name, DataTypeToString(tmp),
DataTypeToString(*data_type)));
*data_type = tmp;
}
}
......
......@@ -307,7 +307,15 @@ static void FillConstantLike(const VariableWrapper &ref_var,
auto *dst_tensor = dst_var->MutableVar()->GetMutable<framework::LoDTensor>();
auto *dev_ctx = platform::DeviceContextPool::Instance().Get(place);
dst_tensor->Resize(ref_tensor.dims());
dst_tensor->mutable_data(place, ref_var.DataType());
// TOOD(jiabin): Ugly fix here we have fwd_data_type_ and data_type, since in
// grad mission
// we can't get data_type_ directly. We need to check if we can only use
// default data_type for now.
if (ref_var.ForwardDataType() != -1) {
dst_tensor->mutable_data(place, ref_var.ForwardDataType());
} else {
dst_tensor->mutable_data(place, ref_var.DataType());
}
operators::math::set_constant(*dev_ctx, dst_tensor, value);
}
......
......@@ -162,6 +162,7 @@ class VariableWrapper {
return tensor->type();
} else {
VLOG(6) << "The tensor of variable " << name_ << " is not initialized";
return data_type_;
}
}
......
......@@ -215,10 +215,6 @@ class TestJacobianFloat64(TestJacobian):
self.x = paddle.rand(shape=self.shape, dtype=self.dtype)
self.y = paddle.rand(shape=self.shape, dtype=self.dtype)
# NOTE(levi): skip this test case temporaryly.
def test_create_graph_true(self):
pass
if __name__ == "__main__":
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册