diff --git a/paddle/fluid/distributed/collective/ProcessGroupBKCL.cc b/paddle/fluid/distributed/collective/ProcessGroupBKCL.cc
index ff39196b92b3a046eb49507be94155cb0bdeca8c..135621ff5f66ef88f94269a493cf47504c32c5f2 100644
--- a/paddle/fluid/distributed/collective/ProcessGroupBKCL.cc
+++ b/paddle/fluid/distributed/collective/ProcessGroupBKCL.cc
@@ -154,7 +154,8 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupBKCL::Collective(
   const auto& place = in_tensor.place();
   const auto& key = GetKeyFromPlace(place);
 
-  if (!calc_event_) {
+  if (!calc_event_ ||
+      (place_to_comm_ctx_.find(key) == place_to_comm_ctx_.end())) {
     CreateBKCLEnvCache(place, key);
   }
 
@@ -170,6 +171,8 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupBKCL::Collective(
   fn(out_tensor, in_tensor, comm_ctx->bkcl_context(), bkcl_stream);
 
   if (!use_calc_stream) {
+    PADDLE_ENFORCE_NOT_NULL(
+        comm_ctx.get(), platform::errors::Fatal("comm context is nullptr."));
     task->comm_event_->Record(*comm_ctx.get());
   }
 
@@ -369,6 +372,10 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupBKCL::AllReduce(
       1,
       platform::errors::InvalidArgument(
           "BKCL only support single tensor collective communication."));
+  PADDLE_ENFORCE_EQ(
+      CheckTensorsInXPUPlace(in_tensors),
+      true,
+      platform::errors::InvalidArgument("All inputs should be in XPUPlace."));
   return Collective(
       &out_tensors[0],
       in_tensors[0],
@@ -406,6 +413,10 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupBKCL::AllReduce(
       1,
       platform::errors::InvalidArgument(
           "BKCL only support single tensor collective communication."));
+  PADDLE_ENFORCE_EQ(
+      CheckTensorsInXPUPlace(in_tensors),
+      true,
+      platform::errors::InvalidArgument("All inputs should be in XPUPlace."));
   return Collective(
       &out_tensors[0],
       in_tensors[0],
@@ -442,6 +453,10 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupBKCL::Broadcast(
       1,
       platform::errors::InvalidArgument(
           "BKCL only support single tensor collective communication."));
+  PADDLE_ENFORCE_EQ(
+      CheckTensorsInXPUPlace(in_tensors),
+      true,
+      platform::errors::InvalidArgument("All inputs should be in XPUPlace."));
 
   return Collective(
       &out_tensors[0],
@@ -481,6 +496,10 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupBKCL::Broadcast(
       1,
       platform::errors::InvalidArgument(
           "BKCL only support single tensor collective communication."));
+  PADDLE_ENFORCE_EQ(
+      CheckTensorsInXPUPlace(in_tensors),
+      true,
+      platform::errors::InvalidArgument("All inputs should be in XPUPlace."));
 
   return Collective(
       &out_tensors[0],
@@ -518,6 +537,10 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupBKCL::AllGather(
       1,
       platform::errors::InvalidArgument(
           "BKCL only support single tensor collective communication."));
+  PADDLE_ENFORCE_EQ(
+      CheckTensorsInXPUPlace(in_tensors),
+      true,
+      platform::errors::InvalidArgument("All inputs should be in XPUPlace."));
   PADDLE_ENFORCE_EQ(
       CheckTensorsInXPUPlace(out_tensors),
       true,
diff --git a/paddle/phi/backends/xpu/xpu2_op_list.cc b/paddle/phi/backends/xpu/xpu2_op_list.cc
index bd63aedd3a25a3d7f069b0d8f03df3947a9b596d..05b4991b858779f795492e2742665f1f204e54a3 100644
--- a/paddle/phi/backends/xpu/xpu2_op_list.cc
+++ b/paddle/phi/backends/xpu/xpu2_op_list.cc
@@ -25,6 +25,7 @@ XPUOpMap& get_kl2_ops() {
       {"abs", XPUKernelSet({phi::DataType::FLOAT32})},
       {"abs_grad",
        XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})},
+      {"accuracy", XPUKernelSet({phi::DataType::FLOAT32})},
       {"adadelta", XPUKernelSet({phi::DataType::FLOAT32})},
       {"adamw", XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})},
       {"adam", XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})},
@@ -402,6 +403,13 @@ XPUOpMap& get_kl2_ops() {
                      phi::DataType::INT32,
                      phi::DataType::BOOL,
                      phi::DataType::FLOAT32})},
{"reshape_with_xshape", + XPUKernelSet({phi::DataType::FLOAT64, + phi::DataType::FLOAT16, + phi::DataType::INT64, + phi::DataType::INT32, + phi::DataType::BOOL, + phi::DataType::FLOAT32})}, {"resnet_unit", XPUKernelSet({phi::DataType::FLOAT16, phi::DataType::FLOAT32})}, {"resnet_unit_grad", @@ -485,6 +493,14 @@ XPUOpMap& get_kl2_ops() { phi::DataType::INT8, phi::DataType::UINT8, phi::DataType::FLOAT32})}, + {"squeeze_with_xshape", + XPUKernelSet({phi::DataType::FLOAT64, + phi::DataType::INT64, + phi::DataType::INT32, + phi::DataType::BOOL, + phi::DataType::INT8, + phi::DataType::UINT8, + phi::DataType::FLOAT32})}, {"squeeze_grad", XPUKernelSet({phi::DataType::FLOAT64, phi::DataType::INT64, diff --git a/paddle/phi/kernels/elementwise_kernel.cc b/paddle/phi/kernels/elementwise_kernel.cc index c6031b34af249c6e054a27d22d0f726ea0ea91cb..50dc0b53443a650043241e50fb9b03b755064b3a 100644 --- a/paddle/phi/kernels/elementwise_kernel.cc +++ b/paddle/phi/kernels/elementwise_kernel.cc @@ -378,7 +378,9 @@ PD_REGISTER_KERNEL(multiply, ALL_LAYOUT, phi::MultiplyKernel, phi::dtype::float16, - float) {} + float, + int, + int64_t) {} PD_REGISTER_KERNEL(subtract, XPU, ALL_LAYOUT, diff --git a/paddle/phi/kernels/xpu/elementwise_multiply_kernel.cc b/paddle/phi/kernels/xpu/elementwise_multiply_kernel.cc index 811e09bf333349ae2cdc36c489496ac7c27c2f8d..e3b62d539486f8ed6bfdbc2822304ea1615f615e 100644 --- a/paddle/phi/kernels/xpu/elementwise_multiply_kernel.cc +++ b/paddle/phi/kernels/xpu/elementwise_multiply_kernel.cc @@ -42,4 +42,6 @@ PD_REGISTER_KERNEL(multiply_raw, ALL_LAYOUT, phi::MultiplyRawKernel, phi::dtype::float16, - float) {} + float, + int, + int64_t) {}