未验证 提交 edd66f2e 编写于 作者: S Sing_chan 提交者: GitHub

make full_like support double_max in dygraph (#45385)

* make full_like support double_max in dygraph

* fix bug
上级 5df464fe
...@@ -1289,7 +1289,7 @@ paddle::experimental::Scalar CastPyArg2Scalar(PyObject* obj, ...@@ -1289,7 +1289,7 @@ paddle::experimental::Scalar CastPyArg2Scalar(PyObject* obj,
int64_t value = CastPyArg2Long(obj, op_type, arg_pos); int64_t value = CastPyArg2Long(obj, op_type, arg_pos);
return paddle::experimental::Scalar(value); return paddle::experimental::Scalar(value);
} else if (PyFloat_Check(obj)) { } else if (PyFloat_Check(obj)) {
float value = CastPyArg2Float(obj, op_type, arg_pos); double value = CastPyArg2Double(obj, op_type, arg_pos);
return paddle::experimental::Scalar(value); return paddle::experimental::Scalar(value);
} else if (IsEagerTensor(obj)) { } else if (IsEagerTensor(obj)) {
paddle::experimental::Tensor& value = GetTensorFromPyObject( paddle::experimental::Tensor& value = GetTensorFromPyObject(
......
...@@ -44,7 +44,7 @@ void FullLikeKernel(const Context& dev_ctx, ...@@ -44,7 +44,7 @@ void FullLikeKernel(const Context& dev_ctx,
const Scalar& val, const Scalar& val,
DataType dtype, DataType dtype,
DenseTensor* out) { DenseTensor* out) {
auto value = val.to<float>(); auto value = val.to<double>();
using CommonType = typename std::common_type< using CommonType = typename std::common_type<
float, float,
typename std::conditional<std::is_same<T, phi::dtype::float16>::value, typename std::conditional<std::is_same<T, phi::dtype::float16>::value,
......
...@@ -60,7 +60,7 @@ void FullLikeKernel(const Context& dev_ctx, ...@@ -60,7 +60,7 @@ void FullLikeKernel(const Context& dev_ctx,
const Scalar& val, const Scalar& val,
DataType dtype, DataType dtype,
DenseTensor* out) { DenseTensor* out) {
auto value = val.to<float>(); auto value = val.to<double>();
using CommonType = typename std::common_type< using CommonType = typename std::common_type<
float, float,
typename std::conditional< typename std::conditional<
......
...@@ -70,7 +70,7 @@ void FullLikeKernel(const Context& dev_ctx, ...@@ -70,7 +70,7 @@ void FullLikeKernel(const Context& dev_ctx,
DataType dtype, DataType dtype,
DenseTensor* out) { DenseTensor* out) {
dev_ctx.template Alloc<T>(out); dev_ctx.template Alloc<T>(out);
auto value = val.to<float>(); auto value = val.to<double>();
using XPUInTDType = typename XPUTypeTrait<T>::Type; using XPUInTDType = typename XPUTypeTrait<T>::Type;
using CommonType = typename std::common_type< using CommonType = typename std::common_type<
float, float,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册