未验证 提交 d131e679 编写于 作者: R ronnywang 提交者: GitHub

[CustomDevice] add c_identity op (#52982) (#53013)

* [CustomDevice] add c_identity op

* fix use calc stream
上级 585f9d65
......@@ -187,7 +187,9 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupCustom::Collective(
std::vector<phi::DenseTensor>& inputs,
std::vector<phi::DenseTensor>& outputs,
Fn fn,
CommType op_type) {
CommType op_type,
bool sync_op,
bool use_calc_stream) {
const auto places = GetPlaceList(inputs);
const auto key = GetKeyFromPlaces(places);
......@@ -199,20 +201,28 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupCustom::Collective(
}
auto& ccl_comms = places_to_customcomm_[key];
SyncDefaultStream(places, places_to_events_[key], places_to_ctx_[key]);
if (!use_calc_stream) {
SyncDefaultStream(places, places_to_events_[key], places_to_ctx_[key]);
}
auto task = CreateTask(places, rank_, op_type, inputs);
task->SetOutputs(outputs);
for (size_t i = 0; i < inputs.size(); ++i) {
phi::DeviceGuard guard(places[i]);
const auto& ccl_stream = places_to_ctx_[key][i]->stream();
const auto& ccl_stream =
use_calc_stream ? reinterpret_cast<phi::CustomContext*>(
phi::DeviceContextPool::Instance().Get(places[i]))
->stream()
: places_to_ctx_[key][i]->stream();
phi::stream::Stream stream(places[i], ccl_stream);
fn(inputs[i], outputs[i], ccl_comms[i]->GetCustomCCLComm(), stream);
}
for (size_t i = 0; i < inputs.size(); ++i) {
phi::DeviceGuard guard(places[i]);
task->control_events_[i].Record(*places_to_ctx_[key][i]);
if (!use_calc_stream) {
for (size_t i = 0; i < inputs.size(); ++i) {
phi::DeviceGuard guard(places[i]);
task->control_events_[i].Record(*places_to_ctx_[key][i]);
}
}
return task;
}
......@@ -280,7 +290,9 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupCustom::AllGather(
comm,
stream);
},
CommType::ALLGATHER);
CommType::ALLGATHER,
sync_op,
use_calc_stream);
}
std::shared_ptr<ProcessGroup::Task> ProcessGroupCustom::AllGather(
......@@ -322,7 +334,9 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupCustom::AllGather(
comm,
stream);
},
CommType::ALLGATHER);
CommType::ALLGATHER,
false,
false);
}
std::shared_ptr<ProcessGroup::Task> ProcessGroupCustom::AllReduce(
......@@ -333,7 +347,36 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupCustom::AllReduce(
bool use_calc_stream) {
std::vector<phi::DenseTensor> in_wrapper{in_tensor};
std::vector<phi::DenseTensor> out_wrapper{*out_tensor};
return AllReduce(in_wrapper, out_wrapper, opts);
PADDLE_ENFORCE_EQ(
CheckTensorsInCustomPlace(in_wrapper, device_type_),
true,
platform::errors::InvalidArgument(
"All inputs should be in CustomPlace(%s).", device_type_));
PADDLE_ENFORCE_EQ(
CheckTensorsInCustomPlace(out_wrapper, device_type_),
true,
platform::errors::InvalidArgument(
"All outputs should be in CustomPlace(%s).", device_type_));
return Collective(
in_wrapper,
out_wrapper,
[&](phi::DenseTensor& input,
phi::DenseTensor& output,
phi::ccl::CCLComm comm,
const phi::stream::Stream& stream) {
return phi::DeviceManager::CCLAllReduce(
device_type_,
input.data(),
output.data(),
input.numel(),
phi::ccl::ToCCLDataType(input.dtype()),
ToCustomCCLRedType(opts.reduce_op),
comm,
stream);
},
CommType::ALLREDUCE,
sync_op,
use_calc_stream);
}
std::shared_ptr<ProcessGroup::Task> ProcessGroupCustom::AllReduce(
......@@ -342,9 +385,7 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupCustom::AllReduce(
const AllreduceOptions& opts,
bool sync_op // for compatibility, no use now
) {
std::vector<phi::DenseTensor> in_wrapper{in_tensor};
std::vector<phi::DenseTensor> out_wrapper{*out_tensor};
return AllReduce(in_wrapper, out_wrapper, opts);
return AllReduce(out_tensor, in_tensor, opts, sync_op, false);
}
std::shared_ptr<ProcessGroup::Task> ProcessGroupCustom::AllReduce(
......@@ -378,7 +419,9 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupCustom::AllReduce(
comm,
stream);
},
CommType::ALLREDUCE);
CommType::ALLREDUCE,
false,
false);
}
std::shared_ptr<ProcessGroup::Task> ProcessGroupCustom::Broadcast(
......@@ -389,7 +432,47 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupCustom::Broadcast(
bool use_calc_stream) {
std::vector<phi::DenseTensor> in_wrapper{in_tensor};
std::vector<phi::DenseTensor> out_wrapper{*out_tensor};
return Broadcast(in_wrapper, out_wrapper, opts);
PADDLE_ENFORCE_EQ(
CheckTensorsInCustomPlace(in_wrapper, device_type_),
true,
platform::errors::InvalidArgument(
"All inputs should be in CustomPlace(%s).", device_type_));
PADDLE_ENFORCE_EQ(
CheckTensorsInCustomPlace(out_wrapper, device_type_),
true,
platform::errors::InvalidArgument(
"All outputs should be in CustomPlace(%s).", device_type_));
return Collective(
in_wrapper,
out_wrapper,
[&](phi::DenseTensor& input,
phi::DenseTensor& output,
phi::ccl::CCLComm comm,
const phi::stream::Stream& stream) {
int root = opts.source_rank * in_wrapper.size() + opts.source_root;
if (rank_ == root) {
return phi::DeviceManager::CCLBroadcast(
device_type_,
input.data(),
input.numel(),
phi::ccl::ToCCLDataType(input.dtype()),
root,
comm,
stream);
} else {
return phi::DeviceManager::CCLBroadcast(
device_type_,
output.data(),
output.numel(),
phi::ccl::ToCCLDataType(output.dtype()),
root,
comm,
stream);
}
},
CommType::BROADCAST,
sync_op,
use_calc_stream);
}
std::shared_ptr<ProcessGroup::Task> ProcessGroupCustom::Broadcast(
......@@ -397,9 +480,7 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupCustom::Broadcast(
const phi::DenseTensor& in_tensor,
const BroadcastOptions& opts,
bool sync_op) {
std::vector<phi::DenseTensor> in_wrapper{in_tensor};
std::vector<phi::DenseTensor> out_wrapper{*out_tensor};
return Broadcast(in_wrapper, out_wrapper, opts);
return Broadcast(out_tensor, in_tensor, opts, sync_op, false);
}
std::shared_ptr<ProcessGroup::Task> ProcessGroupCustom::Barrier(
......@@ -489,7 +570,9 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupCustom::Broadcast(
stream);
}
},
CommType::BROADCAST);
CommType::BROADCAST,
false,
false);
}
std::shared_ptr<ProcessGroupCustom>
......
......@@ -176,7 +176,9 @@ class ProcessGroupCustom : public ProcessGroupWithStream {
std::vector<phi::DenseTensor>& inputs, // NOLINT
std::vector<phi::DenseTensor>& outputs, // NOLINT
Fn fn,
CommType op_type);
CommType op_type,
bool sync_op,
bool use_calc_stream);
void CreateCustomManagerCache(const std::string& places_key,
const std::vector<Place>& places);
......
......@@ -15,6 +15,7 @@ limitations under the License. */
#include "paddle/fluid/operators/custom_device_common_op_registry.h"
#include "paddle/fluid/distributed/collective/process_group.h"
#include "paddle/fluid/operators/collective/c_concat_op.h"
#include "paddle/fluid/operators/collective/c_identity_op.h"
#include "paddle/fluid/operators/load_combine_op.h"
#include "paddle/fluid/operators/run_program_op.h"
#include "paddle/fluid/operators/save_combine_op.h"
......@@ -589,6 +590,21 @@ void RegisterCustomDeviceCommonKernel(const std::string& dev_type) {
paddle::platform::CustomDeviceContext,
paddle::platform::float16>) {}
REGISTER_OP_CUSTOM_DEVICE_KERNEL(
c_identity,
device_type,
paddle::operators::
CIdentityOpKernel<float, paddle::platform::CustomDeviceContext>,
paddle::operators::
CIdentityOpKernel<double, paddle::platform::CustomDeviceContext>,
paddle::operators::
CIdentityOpKernel<int, paddle::platform::CustomDeviceContext>,
paddle::operators::
CIdentityOpKernel<int64_t, paddle::platform::CustomDeviceContext>,
paddle::operators::CIdentityOpKernel<
paddle::platform::float16,
paddle::platform::CustomDeviceContext>) {}
#endif
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册