未验证 提交 8a859554 编写于 作者: W WangZhen 提交者: GitHub

Fix default GetExpectedKernelType for ops supported tensor attrs (#49414)

* Fix default GetExpectedKernelType for ops supported tensor attrs
上级 3ffcd693
...@@ -2715,7 +2715,23 @@ proto::VarType::Type OperatorWithKernel::IndicateDataType( ...@@ -2715,7 +2715,23 @@ proto::VarType::Type OperatorWithKernel::IndicateDataType(
proto::VarType::Type dafault_data_type = proto::VarType::Type dafault_data_type =
static_cast<proto::VarType::Type>(-1); static_cast<proto::VarType::Type>(-1);
proto::VarType::Type data_type = dafault_data_type; proto::VarType::Type data_type = dafault_data_type;
for (auto* name : ctx.InNameList()) {
auto in_name_list = ctx.InNameList();
if (Info().HasOpProtoAndChecker()) {
for (auto& attr : Info().Proto().attrs()) {
auto it =
std::find_if(in_name_list.begin(),
in_name_list.end(),
[&attr](const std::string* name) {
return attr.support_tensor() && *name == attr.name();
});
if (it != in_name_list.end()) {
in_name_list.erase(it);
}
}
}
for (auto* name : in_name_list) {
if (ctx.InputSize(*name) == 1UL) { if (ctx.InputSize(*name) == 1UL) {
ParseInputDataType(ctx.InputVar(*name), *name, &data_type); ParseInputDataType(ctx.InputVar(*name), *name, &data_type);
} else { } else {
......
...@@ -178,6 +178,22 @@ class TestPaddingValueTensor2(TestPaddingValueTensor): ...@@ -178,6 +178,22 @@ class TestPaddingValueTensor2(TestPaddingValueTensor):
return out return out
class TestPaddingValueTensor3(unittest.TestCase):
def test_static(self):
np_x = np.random.random((16, 16)).astype('float32')
main_prog = Program()
starup_prog = Program()
with program_guard(main_prog, starup_prog):
x = paddle.assign(np_x).astype('float32')
pad_value = paddle.assign([0.0]).astype('float64')
y = paddle.nn.functional.pad(x, [0, 1, 2, 3], value=pad_value)
exe = paddle.static.Executor(paddle.CPUPlace())
[pd_out] = exe.run(main_prog, fetch_list=[y])
np_out = np.pad(np_x, [(0, 1), (2, 3)], constant_values=0.0)
np.testing.assert_allclose(pd_out, np_out)
if __name__ == '__main__': if __name__ == '__main__':
paddle.enable_static() paddle.enable_static()
unittest.main() unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册