diff --git a/paddle/fluid/imperative/reducer.cc b/paddle/fluid/imperative/reducer.cc
index 068de4f0435bbec8fb83aa9ee8b0cefdd71be06b..9014871229b39bd33771c1a31823923f7dfe49c6 100644
--- a/paddle/fluid/imperative/reducer.cc
+++ b/paddle/fluid/imperative/reducer.cc
@@ -211,70 +211,6 @@ void SplitTensorsWithType<platform::XPUDeviceContext>(
 }
 #endif
 
-// NOTE(liubo48): Only implement operators::math::SplitFunctor for npu now.
-// If later the operators::StridedMemcpyWithAxis0 is supported,
-// then this specific SplitTensorsForAllReduce can be removed.
-#ifdef PADDLE_WITH_ASCEND_CL
-template <>
-void SplitTensorsForAllReduce<platform::NPUDeviceContext, float>(
-    const platform::NPUDeviceContext &context,
-    framework::Variable *p_dense_contents,
-    std::vector<framework::Tensor> *p_dense_tensors) {
-  auto *in = p_dense_contents->GetMutable<framework::LoDTensor>();
-  std::vector<framework::Tensor *> outs;
-  std::vector<const framework::Tensor *> shape_refer;
-
-  outs.reserve(p_dense_tensors->size());
-  shape_refer.reserve(p_dense_tensors->size());
-
-  for (auto &tensor : *p_dense_tensors) {
-    outs.emplace_back(&tensor);
-    shape_refer.emplace_back(&tensor);
-  }
-  operators::math::SplitFunctor<platform::NPUDeviceContext, float>
-      split_functor_;
-  split_functor_(context, *in, shape_refer, 0, &outs);
-}
-
-template <>
-void ConcatTensorsWithType<platform::NPUDeviceContext>(
-    const platform::NPUDeviceContext &context,
-    const std::vector<framework::Tensor> &dense_tensors_,
-    framework::Variable *p_dense_contents,
-    framework::proto::VarType::Type type) {
-  switch (type) {
-    case framework::proto::VarType::FP32:
-      ConcatTensorsForAllReduce<platform::NPUDeviceContext, float>(
-          context, dense_tensors_, p_dense_contents);
-      break;
-    default:
-      PADDLE_THROW(platform::errors::Unimplemented(
-          "Data type (%s) is not supported when it concats tensors for "
-          "allreduce.",
-          framework::DataTypeToString(type)));
-  }
-}
-
-template <>
-void SplitTensorsWithType<platform::NPUDeviceContext>(
-    const platform::NPUDeviceContext &context,
-    framework::Variable *p_dense_contents,
-    std::vector<framework::Tensor> *p_dense_tensors,
-    framework::proto::VarType::Type type) {
-  switch (type) {
-    case framework::proto::VarType::FP32:
-      SplitTensorsForAllReduce<platform::NPUDeviceContext, float>(
-          context, p_dense_contents, p_dense_tensors);
-      break;
-    default:
-      PADDLE_THROW(platform::errors::Unimplemented(
-          "Data type (%s) is not supported when it splits tensors for "
-          "allreduce.",
-          framework::DataTypeToString(type)));
-  }
-}
-#endif
-
 void Group::ConcatTensors(const platform::DeviceContext &context) {
   auto place = context.GetPlace();
   if (platform::is_gpu_place(place)) {
diff --git a/python/paddle/fluid/dygraph/varbase_patch_methods.py b/python/paddle/fluid/dygraph/varbase_patch_methods.py
index 996cec29e185bd55db3d9869ca16819b729f99c9..f308af04e5e580579dc381cbed55b798ca88da0a 100644
--- a/python/paddle/fluid/dygraph/varbase_patch_methods.py
+++ b/python/paddle/fluid/dygraph/varbase_patch_methods.py
@@ -238,7 +238,7 @@ def monkey_patch_varbase():
                     "Tensor shape not match, Tensor of grad_tensor [ {} ] with shape {} mismatch Tensor [ {} ] with shape {}".format(
                     grad_tensor.name, grad_tensor.shape, self.name, self.shape)
 
-            if paddle.is_compiled_with_xpu():
+            if paddle.is_compiled_with_xpu() or paddle.is_compiled_with_npu():
                 # TODO(liuyuhui): Currently only for xpu. Will be removed in the future.
                 scaled_loss = scale_loss(self)
                 core.dygraph_run_backward([scaled_loss], [grad_tensor],
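
Note on the Python-side hunk: it makes an NPU-compiled build take the same scaled-loss
backward path that XPU already uses, instead of the default
core.dygraph_run_backward([self], ...) branch. Below is a minimal sketch of the resulting
user-visible behavior, assuming a PaddlePaddle build where paddle.is_compiled_with_npu()
returns True and an Ascend device is reachable; the device string "npu:0" is an
assumption for illustration, not taken from this diff:

    import paddle

    # Assumption: running on a build compiled with Ascend NPU support,
    # so paddle.is_compiled_with_npu() returns True.
    paddle.set_device("npu:0")

    x = paddle.to_tensor([1.0, 2.0, 3.0], stop_gradient=False)
    loss = (x * x).sum()

    # With this patch, backward() on an NPU build rescales the loss via
    # scale_loss() before calling core.dygraph_run_backward(), mirroring
    # the pre-existing XPU branch patched above.
    loss.backward()
    print(x.grad)  # gradient of sum(x*x) w.r.t. x, i.e. 2*x

In single-card runs scale_loss() is a no-op (it only divides by the trainer count under
data parallelism), so gradients are unchanged; the branch matters for multi-device
training on XPU/NPU, where allreduce does not average gradients automatically.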