未验证 提交 5453a912 编写于 作者: M mapingshuo 提交者: GitHub

add fp64 support in sequence_pool, test=develop (#25662)

add fp64 support in sequence_pool, test=develop
上级 8dea7bed
......@@ -180,7 +180,10 @@ REGISTER_OPERATOR(sequence_pool_grad, ops::SequencePoolGradOp,
ops::SequencePoolGradOpNoNeedBufferVarsInferer);
REGISTER_OP_CPU_KERNEL(
sequence_pool,
ops::SequencePoolKernel<paddle::platform::CPUDeviceContext, float>);
ops::SequencePoolKernel<paddle::platform::CPUDeviceContext, float>,
ops::SequencePoolKernel<paddle::platform::CPUDeviceContext, double>);
REGISTER_OP_CPU_KERNEL(
sequence_pool_grad,
ops::SequencePoolGradKernel<paddle::platform::CPUDeviceContext, float>);
ops::SequencePoolGradKernel<paddle::platform::CPUDeviceContext, float>,
ops::SequencePoolGradKernel<paddle::platform::CPUDeviceContext, double>);
......@@ -346,7 +346,8 @@ def sequence_pool(input, pool_type, is_test=False, pad_value=0.0):
"""
assert not in_dygraph_mode(), (
"sequence layer is not supported in dygraph mode yet.")
check_variable_and_dtype(input, 'input', ['float32'], 'sequence_pool')
check_variable_and_dtype(input, 'input', ['float32', 'float64'],
'sequence_pool')
helper = LayerHelper('sequence_pool', **locals())
dtype = helper.input_dtype()
pool_out = helper.create_variable_for_type_inference(dtype)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册