未验证 提交 4ac2729c 编写于 作者: Y Yibing Liu 提交者: GitHub

Fix the get of attr pad_value under dtype float16 in pad2d op (#22909)

test=release/1.7
上级 c3a87e3d
...@@ -345,7 +345,7 @@ class Pad2dCPUKernel : public framework::OpKernel<T> { ...@@ -345,7 +345,7 @@ class Pad2dCPUKernel : public framework::OpKernel<T> {
GetPaddings(pads, context); GetPaddings(pads, context);
auto mode = context.Attr<std::string>("mode"); auto mode = context.Attr<std::string>("mode");
auto data_format = context.Attr<std::string>("data_format"); auto data_format = context.Attr<std::string>("data_format");
T value = context.Attr<T>("pad_value"); T value = static_cast<T>(context.Attr<float>("pad_value"));
auto* x = context.Input<Tensor>("X"); auto* x = context.Input<Tensor>("X");
auto in_dims = x->dims(); auto in_dims = x->dims();
......
...@@ -314,7 +314,7 @@ class Pad2dCUDAKernel : public framework::OpKernel<T> { ...@@ -314,7 +314,7 @@ class Pad2dCUDAKernel : public framework::OpKernel<T> {
GetPaddings(pads, context); GetPaddings(pads, context);
auto mode = context.Attr<std::string>("mode"); auto mode = context.Attr<std::string>("mode");
auto data_format = context.Attr<std::string>("data_format"); auto data_format = context.Attr<std::string>("data_format");
T value = context.Attr<T>("pad_value"); T value = static_cast<T>(context.Attr<float>("pad_value"));
auto* x = context.Input<Tensor>("X"); auto* x = context.Input<Tensor>("X");
auto in_dims = x->dims(); auto in_dims = x->dims();
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册