Unverified · Commit d131e679 authored by ronnywang, committed by GitHub

[CustomDevice] add c_identity op (#52982) (#53013)

* [CustomDevice] add c_identity op

* fix use calc stream
Parent 585f9d65
@@ -187,7 +187,9 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupCustom::Collective(
     std::vector<phi::DenseTensor>& inputs,
     std::vector<phi::DenseTensor>& outputs,
     Fn fn,
-    CommType op_type) {
+    CommType op_type,
+    bool sync_op,
+    bool use_calc_stream) {
   const auto places = GetPlaceList(inputs);
   const auto key = GetKeyFromPlaces(places);
@@ -199,20 +201,28 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupCustom::Collective(
   }
   auto& ccl_comms = places_to_customcomm_[key];
-  SyncDefaultStream(places, places_to_events_[key], places_to_ctx_[key]);
+  if (!use_calc_stream) {
+    SyncDefaultStream(places, places_to_events_[key], places_to_ctx_[key]);
+  }
   auto task = CreateTask(places, rank_, op_type, inputs);
   task->SetOutputs(outputs);
   for (size_t i = 0; i < inputs.size(); ++i) {
     phi::DeviceGuard guard(places[i]);
-    const auto& ccl_stream = places_to_ctx_[key][i]->stream();
+    const auto& ccl_stream =
+        use_calc_stream ? reinterpret_cast<phi::CustomContext*>(
+                              phi::DeviceContextPool::Instance().Get(places[i]))
+                              ->stream()
+                        : places_to_ctx_[key][i]->stream();
     phi::stream::Stream stream(places[i], ccl_stream);
     fn(inputs[i], outputs[i], ccl_comms[i]->GetCustomCCLComm(), stream);
   }
-  for (size_t i = 0; i < inputs.size(); ++i) {
-    phi::DeviceGuard guard(places[i]);
-    task->control_events_[i].Record(*places_to_ctx_[key][i]);
+  if (!use_calc_stream) {
+    for (size_t i = 0; i < inputs.size(); ++i) {
+      phi::DeviceGuard guard(places[i]);
+      task->control_events_[i].Record(*places_to_ctx_[key][i]);
+    }
   }
   return task;
 }
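The use_calc_stream branching above reduces to a simple pattern: a collective launched on the calculation stream is already ordered with the surrounding kernels, so neither the SyncDefaultStream call nor the control-event recording is needed; only the dedicated communication-stream path keeps them. A minimal self-contained sketch of that pattern, with stand-in types and names rather than the actual Paddle code:

#include <iostream>
#include <string>

struct Stream { std::string name; };
struct Event  { bool recorded = false; };

// Stand-in for SyncDefaultStream: make the comm stream wait for pending
// work on the default (calc) stream before the collective is issued on it.
void SyncWithCalcStream(Stream& comm) {
  std::cout << "comm stream '" << comm.name << "' waits for calc stream\n";
}

// The branching added to ProcessGroupCustom::Collective, in isolation.
void LaunchCollective(Stream& calc, Stream& comm, Event& done,
                      bool use_calc_stream) {
  if (!use_calc_stream) {
    SyncWithCalcStream(comm);          // fence only the comm-stream path
  }
  Stream& s = use_calc_stream ? calc : comm;  // pick the launch stream
  std::cout << "collective issued on stream '" << s.name << "'\n";
  if (!use_calc_stream) {
    done.recorded = true;              // later compute can wait on this event
  }
}

int main() {
  Stream calc{"calc"}, comm{"comm"};
  Event done;
  LaunchCollective(calc, comm, done, /*use_calc_stream=*/true);
  LaunchCollective(calc, comm, done, /*use_calc_stream=*/false);
  std::cout << "event recorded: " << std::boolalpha << done.recorded << "\n";
}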
@@ -280,7 +290,9 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupCustom::AllGather(
             comm,
             stream);
       },
-      CommType::ALLGATHER);
+      CommType::ALLGATHER,
+      sync_op,
+      use_calc_stream);
 }
 
 std::shared_ptr<ProcessGroup::Task> ProcessGroupCustom::AllGather(
@@ -322,7 +334,9 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupCustom::AllGather(
             comm,
             stream);
       },
-      CommType::ALLGATHER);
+      CommType::ALLGATHER,
+      false,
+      false);
 }
 
 std::shared_ptr<ProcessGroup::Task> ProcessGroupCustom::AllReduce(
@@ -333,7 +347,36 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupCustom::AllReduce(
     bool use_calc_stream) {
   std::vector<phi::DenseTensor> in_wrapper{in_tensor};
   std::vector<phi::DenseTensor> out_wrapper{*out_tensor};
-  return AllReduce(in_wrapper, out_wrapper, opts);
+  PADDLE_ENFORCE_EQ(
+      CheckTensorsInCustomPlace(in_wrapper, device_type_),
+      true,
+      platform::errors::InvalidArgument(
+          "All inputs should be in CustomPlace(%s).", device_type_));
+  PADDLE_ENFORCE_EQ(
+      CheckTensorsInCustomPlace(out_wrapper, device_type_),
+      true,
+      platform::errors::InvalidArgument(
+          "All outputs should be in CustomPlace(%s).", device_type_));
+  return Collective(
+      in_wrapper,
+      out_wrapper,
+      [&](phi::DenseTensor& input,
+          phi::DenseTensor& output,
+          phi::ccl::CCLComm comm,
+          const phi::stream::Stream& stream) {
+        return phi::DeviceManager::CCLAllReduce(
+            device_type_,
+            input.data(),
+            output.data(),
+            input.numel(),
+            phi::ccl::ToCCLDataType(input.dtype()),
+            ToCustomCCLRedType(opts.reduce_op),
+            comm,
+            stream);
+      },
+      CommType::ALLREDUCE,
+      sync_op,
+      use_calc_stream);
 }
 
 std::shared_ptr<ProcessGroup::Task> ProcessGroupCustom::AllReduce(
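The lambda passed to Collective above receives one (input, output, comm, stream) tuple and is invoked once per place by the loop shown in the first hunk. The dispatch shape, stripped of the Paddle types (plain ints stand in for tensors; the comm and stream arguments are omitted), is roughly:

#include <cstddef>
#include <iostream>
#include <vector>

// Sketch of the Collective(...) dispatch pattern: the functor is run once
// per input/output pair, mirroring the per-place loop in the diff above.
template <typename Fn>
void CollectiveSketch(std::vector<int>& inputs,
                      std::vector<int>& outputs,
                      Fn fn) {
  for (std::size_t i = 0; i < inputs.size(); ++i) {
    fn(inputs[i], outputs[i]);  // per-place launch
  }
}

int main() {
  std::vector<int> in{1, 2, 3};
  std::vector<int> out(in.size());
  // Stand-in for CCLAllReduce with a world size of 2 and sum reduction.
  CollectiveSketch(in, out, [](int& x, int& y) { y = 2 * x; });
  for (int v : out) std::cout << v << ' ';  // prints: 2 4 6
}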
@@ -342,9 +385,7 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupCustom::AllReduce(
     const AllreduceOptions& opts,
     bool sync_op  // for compatibility, no use now
 ) {
-  std::vector<phi::DenseTensor> in_wrapper{in_tensor};
-  std::vector<phi::DenseTensor> out_wrapper{*out_tensor};
-  return AllReduce(in_wrapper, out_wrapper, opts);
+  return AllReduce(out_tensor, in_tensor, opts, sync_op, false);
 }
 
 std::shared_ptr<ProcessGroup::Task> ProcessGroupCustom::AllReduce(
@@ -378,7 +419,9 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupCustom::AllReduce(
             comm,
             stream);
       },
-      CommType::ALLREDUCE);
+      CommType::ALLREDUCE,
+      false,
+      false);
 }
 
 std::shared_ptr<ProcessGroup::Task> ProcessGroupCustom::Broadcast(
@@ -389,7 +432,47 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupCustom::Broadcast(
     bool use_calc_stream) {
   std::vector<phi::DenseTensor> in_wrapper{in_tensor};
   std::vector<phi::DenseTensor> out_wrapper{*out_tensor};
-  return Broadcast(in_wrapper, out_wrapper, opts);
+  PADDLE_ENFORCE_EQ(
+      CheckTensorsInCustomPlace(in_wrapper, device_type_),
+      true,
+      platform::errors::InvalidArgument(
+          "All inputs should be in CustomPlace(%s).", device_type_));
+  PADDLE_ENFORCE_EQ(
+      CheckTensorsInCustomPlace(out_wrapper, device_type_),
+      true,
+      platform::errors::InvalidArgument(
+          "All outputs should be in CustomPlace(%s).", device_type_));
+  return Collective(
+      in_wrapper,
+      out_wrapper,
+      [&](phi::DenseTensor& input,
+          phi::DenseTensor& output,
+          phi::ccl::CCLComm comm,
+          const phi::stream::Stream& stream) {
+        int root = opts.source_rank * in_wrapper.size() + opts.source_root;
+        if (rank_ == root) {
+          return phi::DeviceManager::CCLBroadcast(
+              device_type_,
+              input.data(),
+              input.numel(),
+              phi::ccl::ToCCLDataType(input.dtype()),
+              root,
+              comm,
+              stream);
+        } else {
+          return phi::DeviceManager::CCLBroadcast(
+              device_type_,
+              output.data(),
+              output.numel(),
+              phi::ccl::ToCCLDataType(output.dtype()),
+              root,
+              comm,
+              stream);
+        }
+      },
+      CommType::BROADCAST,
+      sync_op,
+      use_calc_stream);
 }
 
 std::shared_ptr<ProcessGroup::Task> ProcessGroupCustom::Broadcast(
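The root inside the Broadcast lambda is opts.source_rank * in_wrapper.size() + opts.source_root; since in_wrapper holds a single tensor in this path, it collapses to source_rank + source_root. A small worked example of that arithmetic (BroadcastRoot is a hypothetical helper, not a Paddle function):

#include <cassert>

// root = source_rank * tensors_per_rank + source_root: with one tensor per
// rank, source_rank 2 and source_root 0, the broadcast root is global rank 2;
// only that rank reads from `input`, every other rank writes into `output`.
int BroadcastRoot(int source_rank, int tensors_per_rank, int source_root) {
  return source_rank * tensors_per_rank + source_root;
}

int main() {
  assert(BroadcastRoot(2, 1, 0) == 2);
  assert(BroadcastRoot(1, 1, 0) == 1);
  return 0;
}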
@@ -397,9 +480,7 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupCustom::Broadcast(
     const phi::DenseTensor& in_tensor,
     const BroadcastOptions& opts,
     bool sync_op) {
-  std::vector<phi::DenseTensor> in_wrapper{in_tensor};
-  std::vector<phi::DenseTensor> out_wrapper{*out_tensor};
-  return Broadcast(in_wrapper, out_wrapper, opts);
+  return Broadcast(out_tensor, in_tensor, opts, sync_op, false);
 }
 
 std::shared_ptr<ProcessGroup::Task> ProcessGroupCustom::Barrier(
@@ -489,7 +570,9 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupCustom::Broadcast(
               stream);
         }
       },
-      CommType::BROADCAST);
+      CommType::BROADCAST,
+      false,
+      false);
 }
 
 std::shared_ptr<ProcessGroupCustom>
......
@@ -176,7 +176,9 @@ class ProcessGroupCustom : public ProcessGroupWithStream {
       std::vector<phi::DenseTensor>& inputs,   // NOLINT
       std::vector<phi::DenseTensor>& outputs,  // NOLINT
       Fn fn,
-      CommType op_type);
+      CommType op_type,
+      bool sync_op,
+      bool use_calc_stream);
 
   void CreateCustomManagerCache(const std::string& places_key,
                                 const std::vector<Place>& places);
......
@@ -15,6 +15,7 @@ limitations under the License. */
 #include "paddle/fluid/operators/custom_device_common_op_registry.h"
 #include "paddle/fluid/distributed/collective/process_group.h"
 #include "paddle/fluid/operators/collective/c_concat_op.h"
+#include "paddle/fluid/operators/collective/c_identity_op.h"
 #include "paddle/fluid/operators/load_combine_op.h"
 #include "paddle/fluid/operators/run_program_op.h"
 #include "paddle/fluid/operators/save_combine_op.h"
@@ -589,6 +590,21 @@ void RegisterCustomDeviceCommonKernel(const std::string& dev_type) {
           paddle::platform::CustomDeviceContext,
           paddle::platform::float16>) {}
 
+  REGISTER_OP_CUSTOM_DEVICE_KERNEL(
+      c_identity,
+      device_type,
+      paddle::operators::
+          CIdentityOpKernel<float, paddle::platform::CustomDeviceContext>,
+      paddle::operators::
+          CIdentityOpKernel<double, paddle::platform::CustomDeviceContext>,
+      paddle::operators::
+          CIdentityOpKernel<int, paddle::platform::CustomDeviceContext>,
+      paddle::operators::
+          CIdentityOpKernel<int64_t, paddle::platform::CustomDeviceContext>,
+      paddle::operators::CIdentityOpKernel<
+          paddle::platform::float16,
+          paddle::platform::CustomDeviceContext>) {}
+
 #endif
 }
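For context, c_identity is a pass-through op: its forward output is a copy of the input, which is why one CIdentityOpKernel instantiation is registered per dtype (float, double, int, int64_t, float16) above. The sketch below captures only that forward semantics, with std::vector standing in for DenseTensor; it is not the actual Paddle kernel, and whatever collective behavior model-parallel graphs attach to the op's backward pass is out of scope here.

#include <vector>

// Forward semantics of an identity op: the output is an element-wise copy of
// the input; each registered dtype corresponds to one instantiation of a
// template like this one.
template <typename T>
std::vector<T> IdentityForward(const std::vector<T>& x) {
  return x;  // plain copy, no computation
}

int main() {
  std::vector<float> x{1.f, 2.f, 3.f};
  std::vector<float> y = IdentityForward(x);
  return y.size() == x.size() ? 0 : 1;
}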