Unverified commit d53e567a, authored by S ShenLiang, committed by GitHub

fix bug of recompute in hybridparallel (#35588)

Parent 652da1f4
@@ -15,6 +15,7 @@ limitations under the License. */
#include "paddle/fluid/operators/flatten_op.h"
namespace ops = paddle::operators;
+namespace plat = paddle::platform;
REGISTER_OP_CUDA_KERNEL(
    flatten, ops::FlattenKernel<paddle::platform::CUDADeviceContext, float>,
@@ -50,6 +51,8 @@ REGISTER_OP_CUDA_KERNEL(
    flatten_contiguous_range,
    ops::FlattenContiguousRangeKernel<paddle::platform::CUDADeviceContext,
                                      float>,
+    ops::FlattenContiguousRangeKernel<paddle::platform::CUDADeviceContext,
+                                      plat::float16>,
    ops::FlattenContiguousRangeKernel<paddle::platform::CUDADeviceContext,
                                      double>,
    ops::FlattenContiguousRangeKernel<paddle::platform::CUDADeviceContext,
@@ -63,6 +66,8 @@ REGISTER_OP_CUDA_KERNEL(
    flatten_contiguous_range_grad,
    ops::FlattenContiguousRangeGradKernel<paddle::platform::CUDADeviceContext,
                                          float>,
+    ops::FlattenContiguousRangeGradKernel<paddle::platform::CUDADeviceContext,
+                                          plat::float16>,
    ops::FlattenContiguousRangeGradKernel<paddle::platform::CUDADeviceContext,
                                          double>,
    ops::FlattenContiguousRangeGradKernel<paddle::platform::CUDADeviceContext,
......
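
The two hunks above extend the CUDA kernel registration of flatten_contiguous_range and flatten_contiguous_range_grad with plat::float16, so the op pair can run directly on half-precision GPU tensors rather than requiring a cast to fp32. A minimal sketch of what this enables from the Python side (hedged: the lowering of paddle.flatten to flatten_contiguous_range and the presence of a CUDA device are assumptions, not shown in this diff):

    import paddle

    # assumes a CUDA build/device; the float16 kernels registered above are GPU kernels
    x = paddle.randn([2, 3, 4]).astype('float16')
    y = paddle.flatten(x, start_axis=1, stop_axis=2)  # expected to dispatch to flatten_contiguous_range
    print(y.shape, y.dtype)  # [2, 12] paddle.float16
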
@@ -133,6 +133,7 @@ def _split_activation(tensor):
    # use inplace operation to save memory
    data = tensor.flatten_()
    part_size = tensor_numel // mp_degree
    start = part_size * mp_rank
    end = start + part_size
......
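
For context, the surrounding _split_activation helper flattens the activation in place and keeps only this rank's contiguous slice, so each model-parallel rank appears to hold 1/mp_degree of the recomputed buffer. Below is a NumPy sketch of just the index arithmetic visible above (mp_degree and mp_rank stand for the model-parallel world size and rank; the real helper operates on Paddle tensors and is only partially shown in this hunk):

    import numpy as np

    def split_activation(flat, mp_degree, mp_rank):
        # each rank keeps a contiguous 1/mp_degree slice of the flattened activation
        part_size = flat.size // mp_degree
        start = part_size * mp_rank
        end = start + part_size
        return flat[start:end]

    x = np.arange(8, dtype=np.float32)
    print(split_activation(x, mp_degree=4, mp_rank=1))  # [2. 3.]
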
@@ -94,6 +94,7 @@ black_list = {
    'softmax',
    'softmax_with_cross_entropy',
    'sigmoid_cross_entropy_with_logits',
+    'c_softmax_with_cross_entropy',
    'cross_entropy',
    'cross_entropy2',
    # fp16 is slower than fp32, though fp16 is supported.
......
@@ -45,6 +45,7 @@ BLACK_LIST = {
    'softmax',
    'softmax_with_cross_entropy',
    'sigmoid_cross_entropy_with_logits',
+    'c_softmax_with_cross_entropy',
    'cross_entropy',
    'cross_entropy2',
    # default fp32 can avoid return inf when the sum value large than 65504
......
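
Both list hunks add c_softmax_with_cross_entropy to the AMP black lists, so auto mixed precision keeps that op in fp32 for numerical stability (as the neighboring comment notes, fp32 avoids returning inf once the summed value exceeds the fp16 limit of 65504). A hedged sketch of how a black-listed op behaves under paddle.amp.auto_cast; the custom_black_list entry matmul_v2 is only illustrative and not part of this diff:

    import paddle

    linear = paddle.nn.Linear(4, 4)
    x = paddle.randn([2, 4])
    # black-listed ops (built-in or user supplied) run in fp32 even inside an auto_cast region
    with paddle.amp.auto_cast(custom_black_list={'matmul_v2'}):
        y = linear(x)
    print(y.dtype)  # paddle.float32, because the matmul was kept out of fp16
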