提交 2f006929 编写于 作者: A Andy Ly 提交者: TensorFlower Gardener

Unify keys in tf.XlaSendToHost, tf.XlaRecvFromHost, and tf._XlaHostComputeMlir legalizations.

This keeps the logic for suffixes appended to keys in a centralized location instead of having passes handle it when creating such ops.

PiperOrigin-RevId: 327867882
Change-Id: I1f6f30486fbf29d3c0028d5996d2009f69bae24a
上级 0d10d5d0
......@@ -169,7 +169,7 @@ func @send_to_host(%arg0: tensor<i32>) {
// CHECK: "mhlo.send"([[ARG0]], [[INIT_TOKEN]])
// CHECK-SAME: channel_id = {handle = 1 : i64, type = 2 : i64}
// CHECK-SAME: is_host_transfer = true
// CHECK-SAME: mhlo.frontend_attributes = {_xla_host_transfer_original_type = "s32", _xla_host_transfer_rendezvous = "send_key"}
// CHECK-SAME: mhlo.frontend_attributes = {_xla_host_transfer_original_type = "s32", _xla_host_transfer_rendezvous = "send_key_dtoh_0"}
// CHECK-SAME: (tensor<i32>, !mhlo.token) -> !mhlo.token
"tf.XlaSendToHost"(%arg0) {key = "send_key"} : (tensor<i32>) -> ()
return
......@@ -186,7 +186,7 @@ func @recv_from_host() -> tensor<i32> {
// CHECK: [[RECV_TUPLE:%.*]] = "mhlo.recv"([[INIT_TOKEN]])
// CHECK-SAME: channel_id = {handle = 1 : i64, type = 3 : i64}
// CHECK-SAME: is_host_transfer = true
// CHECK-SAME: mhlo.frontend_attributes = {_xla_host_transfer_original_type = "s32", _xla_host_transfer_rendezvous = "recv_key"}
// CHECK-SAME: mhlo.frontend_attributes = {_xla_host_transfer_original_type = "s32", _xla_host_transfer_rendezvous = "recv_key_htod_0"}
// CHECK-SAME: (!mhlo.token) -> tuple<tensor<i32>, !mhlo.token>
......
......@@ -215,11 +215,17 @@ void SetOpSharding(Operation* op, int64_t tpu_core) {
}
// Assigns frontend attributes holding information about data type and
// TensorFlow rendezvous channel name.
void SetFrontendAttributes(Operation* op, StringRef key, Type type) {
// TensorFlow rendezvous channel name. The TensorFlow rendezvous channel name is
// handled differently as individual names are used per data send and receive.
void SetFrontendAttributes(Operation* op, int32_t index, StringRef key,
Type type, bool device_to_host) {
MLIRContext* context = op->getContext();
auto rendezvous_name = StringAttr::get(key, context);
std::string formatted_key =
device_to_host ? llvm::formatv("{0}_dtoh_{1}", key, index).str()
: llvm::formatv("{0}_htod_{1}", key, index).str();
auto rendezvous_name = StringAttr::get(formatted_key, context);
auto rendezvous_name_attr = NamedAttribute(
Identifier::get(kXlaHostTransferRendezvousNameAttr, context),
rendezvous_name);
......@@ -239,24 +245,10 @@ void SetFrontendAttributes(Operation* op, StringRef key, Type type) {
op->setAttr(kFrontendAttributesAttr, frontend_attributes);
}
// Assigns frontend attributes holding information about data type and
// TensorFlow rendezvous channel name specific to `tf._XlaHostComputeMlir`.
// TensorFlow rendezvous channel name is handled differently as individual names
// are used per data send and receive.
void SetFrontendAttributes(Operation* op, int32_t index, StringRef key,
Type type, bool device_to_host) {
std::string formatted_key =
device_to_host ? llvm::formatv("{0}_dtoh_{1}", key, index).str()
: llvm::formatv("{0}_htod_{1}", key, index).str();
return SetFrontendAttributes(op, formatted_key, type);
}
// Creates a `mhlo.send` op for sending value `operand`. If `index` is set,
// `key` will be rewritten with a suffix and index. If `tpu_core` is set, op
// sharding for the respective device will be set.
// Creates a `mhlo.send` op for sending value `operand`. If `tpu_core` is set,
// op sharding for the respective device will be set.
Value CreateSendOp(OpBuilder& builder, int64_t& channel_id, Location loc,
Value operand, StringRef key, const Optional<size_t>& index,
Value operand, StringRef key, size_t index,
const Optional<int64_t>& tpu_core, Value token) {
// type 2 == DEVICE_TO_HOST
auto channel_handle = ChannelHandle::get(
......@@ -266,23 +258,18 @@ Value CreateSendOp(OpBuilder& builder, int64_t& channel_id, Location loc,
loc, token.getType(), operand, token, channel_handle,
/*is_host_transfer=*/builder.getBoolAttr(true));
if (index) {
SetFrontendAttributes(send, *index, key, operand.getType(),
/*device_to_host=*/true);
} else {
SetFrontendAttributes(send, key, operand.getType());
}
SetFrontendAttributes(send, index, key, operand.getType(),
/*device_to_host=*/true);
if (tpu_core) SetOpSharding(send, *tpu_core);
return send.getResult();
}
// Creates a `mhlo.recv` op for receiving a value. If `index` is set, `key` will
// be rewritten with a suffix and index. If `tpu_core` is set, op sharding for
// the respective device will be set.
// Creates a `mhlo.recv` op for receiving a value. If `tpu_core` is set, op
// sharding for the respective device will be set.
Value CreateRecvOp(OpBuilder& builder, int64_t& channel_id, Location loc,
Value result, StringRef key, const Optional<size_t>& index,
Value result, StringRef key, size_t index,
const Optional<int64_t>& tpu_core, Value token) {
// type 3 == HOST_TO_DEVICE
auto channel_handle = ChannelHandle::get(
......@@ -294,12 +281,10 @@ Value CreateRecvOp(OpBuilder& builder, int64_t& channel_id, Location loc,
auto recv =
builder.create<RecvOp>(loc, recv_result_type, token, channel_handle,
/*is_host_transfer=*/builder.getBoolAttr(true));
if (index) {
SetFrontendAttributes(recv, *index, key, result_type,
/*device_to_host=*/false);
} else {
SetFrontendAttributes(recv, key, result.getType());
}
SetFrontendAttributes(recv, index, key, result_type,
/*device_to_host=*/false);
if (tpu_core) SetOpSharding(recv, *tpu_core);
auto get_tuple_element =
......@@ -369,7 +354,7 @@ Value RewriteSendToHostOp(OpBuilder& builder, int64_t& channel_id,
builder.setInsertionPoint(send_to_host);
token = CreateSendOp(builder, channel_id, send_to_host.getLoc(),
send_to_host.input(), send_to_host.key(),
/*index=*/llvm::None, /*tpu_core=*/llvm::None, token);
/*index=*/0, /*tpu_core=*/llvm::None, token);
send_to_host.erase();
return token;
......@@ -381,7 +366,7 @@ Value RewriteRecvFromHostOp(OpBuilder& builder, int64_t& channel_id,
builder.setInsertionPoint(recv_from_host);
token = CreateRecvOp(builder, channel_id, recv_from_host.getLoc(),
recv_from_host.output(), recv_from_host.key(),
/*index=*/llvm::None, /*tpu_core=*/llvm::None, token);
/*index=*/0, /*tpu_core=*/llvm::None, token);
recv_from_host.erase();
return token;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册