Commit 00d0347c authored by Brennan Saeta, committed by TensorFlower Gardener

[TF:XLA] Add debug metadata to HLO ops.

In order to support end-to-end debugging and performance profiling tooling for
the TensorFlow::XLA toolchain, this change adds an OpMetadata proto to the
HloInstruction class and pipes it through the tf2xla stack.
Change: 149703349
Parent d3147337
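For orientation, here is a minimal client-side sketch of the API this change introduces (the function name and scalar shapes are illustrative, mirroring the hlo_metadata_test.cc added below):

#include "tensorflow/compiler/xla/client/computation_builder.h"
#include "tensorflow/compiler/xla/shape_util.h"

// Illustrative sketch: every op built between SetOpMetadata and
// ClearOpMetadata is tagged with the same debug metadata.
void BuildTaggedAdd(xla::Client* client) {
  xla::ComputationBuilder builder(client, "tagged_add");
  xla::OpMetadata metadata;
  metadata.set_op_type("add");        // framework-level op type
  metadata.set_op_name("my_sum_op");  // user-specified op name
  builder.SetOpMetadata(metadata);
  auto x = builder.Parameter(0, xla::ShapeUtil::MakeShape(xla::F32, {}), "x");
  auto y = builder.Parameter(1, xla::ShapeUtil::MakeShape(xla::F32, {}), "y");
  builder.Add(x, y);  // this instruction carries the metadata
  builder.ClearOpMetadata();
}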
......@@ -18,6 +18,7 @@ limitations under the License.
#include <functional>
#include <memory>
#include "tensorflow/compiler/tf2xla/xla_context.h"
#include "tensorflow/core/common_runtime/local_device.h"
#include "tensorflow/core/framework/device_base.h"
#include "tensorflow/core/platform/mem.h"
......@@ -85,6 +86,20 @@ Allocator* XlaCompilationDevice::GetAllocator(AllocatorAttributes attr) {
return allocator_.get();
}
void XlaCompilationDevice::Compute(OpKernel* op_kernel,
OpKernelContext* context) {
VLOG(1) << "XlaCompilationDevice::Compute "
<< SummarizeNodeDef(op_kernel->def());
auto* b = XlaContext::Get(context).builder();
xla::OpMetadata metadata;
metadata.set_op_type(op_kernel->type_string());
metadata.set_op_name(op_kernel->name());
b->SetOpMetadata(metadata);
op_kernel->Compute(context);
b->ClearOpMetadata();
VLOG(2) << "Done";
}
Status XlaCompilationDevice::Sync() { return Status::OK(); }
Status XlaCompilationDevice::MakeTensorFromProto(
......
......@@ -52,6 +52,8 @@ class XlaCompilationDevice : public LocalDevice {
Allocator* GetAllocator(AllocatorAttributes attr) override;
void Compute(OpKernel* op_kernel, OpKernelContext* context) override;
Status Sync() override;
Status MakeTensorFromProto(const TensorProto& tensor_proto,
......
......@@ -171,6 +171,7 @@ ComputationDataHandle ComputationBuilder::ConstantOp(
OpRequest op_request;
*op_request.mutable_constant_request() = request;
*op_request.mutable_computation() = computation_.handle();
AddOpMetadata(&op_request);
OpResponse response;
VLOG(2) << "making constant request";
......@@ -198,6 +199,7 @@ ComputationDataHandle ComputationBuilder::Parameter(int64 parameter_number,
OpRequest op_request;
*op_request.mutable_computation() = computation_.handle();
*op_request.mutable_parameter_request() = request;
AddOpMetadata(&op_request);
OpResponse response;
VLOG(2) << "making parameter request";
......@@ -269,6 +271,7 @@ ComputationDataHandle ComputationBuilder::Slice(
OpRequest op_request;
*op_request.mutable_computation() = computation_.handle();
*op_request.mutable_slice_request() = request;
AddOpMetadata(&op_request);
OpResponse response;
VLOG(2) << "making slice request";
......@@ -293,6 +296,7 @@ ComputationDataHandle ComputationBuilder::DynamicSlice(
OpRequest op_request;
*op_request.mutable_computation() = computation_.handle();
*op_request.mutable_dynamic_slice_request() = request;
AddOpMetadata(&op_request);
OpResponse response;
VLOG(2) << "making dynamic slice request";
......@@ -314,6 +318,7 @@ ComputationDataHandle ComputationBuilder::DynamicUpdateSlice(
OpRequest op_request;
*op_request.mutable_computation() = computation_.handle();
*op_request.mutable_dynamic_update_slice_request() = request;
AddOpMetadata(&op_request);
OpResponse response;
VLOG(2) << "making dynamic update slice request";
......@@ -336,6 +341,7 @@ ComputationDataHandle ComputationBuilder::ConcatInDim(
OpRequest op_request;
*op_request.mutable_computation() = computation_.handle();
*op_request.mutable_concatenate_request() = request;
AddOpMetadata(&op_request);
OpResponse response;
VLOG(2) << "making concatenate request";
......@@ -358,6 +364,7 @@ ComputationDataHandle ComputationBuilder::Broadcast(
OpRequest op_request;
*op_request.mutable_computation() = computation_.handle();
*op_request.mutable_broadcast_request() = request;
AddOpMetadata(&op_request);
OpResponse response;
VLOG(2) << "making broadcast request";
......@@ -380,6 +387,7 @@ ComputationDataHandle ComputationBuilder::Pad(
OpRequest op_request;
*op_request.mutable_computation() = computation_.handle();
*op_request.mutable_pad_request() = request;
AddOpMetadata(&op_request);
OpResponse response;
VLOG(2) << "making pad request";
......@@ -406,6 +414,7 @@ ComputationDataHandle ComputationBuilder::Reshape(
OpRequest op_request;
*op_request.mutable_computation() = computation_.handle();
*op_request.mutable_reshape_request() = request;
AddOpMetadata(&op_request);
OpResponse response;
VLOG(2) << "making reshape request";
......@@ -482,6 +491,7 @@ void ComputationBuilder::Trace(const string& tag,
OpRequest op_request;
*op_request.mutable_computation() = computation_.handle();
*op_request.mutable_trace_request() = request;
AddOpMetadata(&op_request);
OpResponse response;
VLOG(2) << "making trace request";
......@@ -513,6 +523,7 @@ ComputationDataHandle ComputationBuilder::Tuple(
OpRequest op_request;
*op_request.mutable_computation() = computation_.handle();
*op_request.mutable_variadic_op_request() = request;
AddOpMetadata(&op_request);
OpResponse response;
VLOG(2) << "making variadic op request";
......@@ -532,6 +543,7 @@ ComputationDataHandle ComputationBuilder::GetTupleElement(
OpRequest op_request;
*op_request.mutable_computation() = computation_.handle();
*op_request.mutable_get_tuple_element_request() = request;
AddOpMetadata(&op_request);
OpResponse response;
VLOG(2) << "making get tuple element op request";
......@@ -758,6 +770,7 @@ ComputationDataHandle ComputationBuilder::ConvGeneralDilated(
OpRequest op_request;
*op_request.mutable_computation() = computation_.handle();
*op_request.mutable_convolve_request() = request;
AddOpMetadata(&op_request);
OpResponse response;
VLOG(2) << "making convolve request";
......@@ -777,6 +790,7 @@ ComputationDataHandle ComputationBuilder::Infeed(const Shape& shape,
OpRequest op_request;
*op_request.mutable_computation() = computation_.handle();
*op_request.mutable_infeed_request() = request;
AddOpMetadata(&op_request);
OpResponse response;
VLOG(2) << "making infeed op request";
......@@ -799,6 +813,7 @@ void ComputationBuilder::Outfeed(const ComputationDataHandle& operand,
OpRequest op_request;
*op_request.mutable_outfeed_request() = request;
*op_request.mutable_computation() = computation_.handle();
AddOpMetadata(&op_request);
OpResponse response;
VLOG(2) << "making outfeed op request";
......@@ -825,6 +840,7 @@ ComputationDataHandle ComputationBuilder::Call(
OpRequest op_request;
*op_request.mutable_computation() = computation_.handle();
*op_request.mutable_call_request() = request;
AddOpMetadata(&op_request);
OpResponse response;
VLOG(2) << "making call op request";
......@@ -850,6 +866,7 @@ ComputationDataHandle ComputationBuilder::CustomCall(
OpRequest op_request;
*op_request.mutable_computation() = computation_.handle();
*op_request.mutable_custom_call_request() = request;
AddOpMetadata(&op_request);
OpResponse response;
VLOG(2) << "making custom call op request";
......@@ -990,6 +1007,7 @@ ComputationDataHandle ComputationBuilder::Rev(
OpRequest op_request;
*op_request.mutable_computation() = computation_.handle();
*op_request.mutable_reverse_request() = request;
AddOpMetadata(&op_request);
OpResponse response;
VLOG(2) << "making reverse op request";
......@@ -1035,6 +1053,7 @@ ComputationDataHandle ComputationBuilder::ConvertElementType(
OpRequest op_request;
*op_request.mutable_computation() = computation_.handle();
*op_request.mutable_convert_request() = request;
AddOpMetadata(&op_request);
OpResponse response;
VLOG(2) << "making convert request";
......@@ -1078,6 +1097,7 @@ ComputationDataHandle ComputationBuilder::UnaryOp(
OpRequest op_request;
*op_request.mutable_computation() = computation_.handle();
*op_request.mutable_unary_op_request() = request;
AddOpMetadata(&op_request);
OpResponse response;
VLOG(2) << "making unop request";
......@@ -1104,6 +1124,7 @@ ComputationDataHandle ComputationBuilder::BinaryOp(
OpRequest op_request;
*op_request.mutable_computation() = computation_.handle();
*op_request.mutable_binary_op_request() = request;
AddOpMetadata(&op_request);
OpResponse response;
VLOG(2) << "making binop request";
......@@ -1129,6 +1150,7 @@ ComputationDataHandle ComputationBuilder::RngOp(
OpRequest op_request;
*op_request.mutable_computation() = computation_.handle();
*op_request.mutable_rng_request() = request;
AddOpMetadata(&op_request);
OpResponse response;
VLOG(2) << "making rngop request";
......@@ -1152,6 +1174,7 @@ ComputationDataHandle ComputationBuilder::TernaryOp(
OpRequest op_request;
*op_request.mutable_computation() = computation_.handle();
*op_request.mutable_ternary_op_request() = request;
AddOpMetadata(&op_request);
OpResponse response;
VLOG(2) << "making triop request";
......@@ -1253,6 +1276,7 @@ ComputationDataHandle ComputationBuilder::Map(
OpRequest op_request;
*op_request.mutable_computation() = computation_.handle();
*op_request.mutable_map_request() = request;
AddOpMetadata(&op_request);
OpResponse response;
VLOG(2) << "making Map request";
......@@ -1291,6 +1315,7 @@ ComputationDataHandle ComputationBuilder::While(
OpRequest op_request;
*op_request.mutable_computation() = computation_.handle();
*op_request.mutable_while_request() = request;
AddOpMetadata(&op_request);
OpResponse response;
VLOG(2) << "making while request";
......@@ -1316,6 +1341,7 @@ ComputationDataHandle ComputationBuilder::Reduce(
OpRequest op_request;
*op_request.mutable_computation() = computation_.handle();
*op_request.mutable_reduce_request() = request;
AddOpMetadata(&op_request);
OpResponse response;
VLOG(2) << "making reduce request";
......@@ -1368,6 +1394,7 @@ ComputationDataHandle ComputationBuilder::ReduceWindowWithGeneralPadding(
OpRequest op_request;
*op_request.mutable_computation() = computation_.handle();
*op_request.mutable_reduce_window_request() = request;
AddOpMetadata(&op_request);
OpResponse response;
VLOG(2) << "making reduce-window request";
......@@ -1386,6 +1413,7 @@ ComputationDataHandle ComputationBuilder::CrossReplicaSum(
OpRequest op_request;
*op_request.mutable_cross_replica_sum_request() = request;
*op_request.mutable_computation() = computation_.handle();
AddOpMetadata(&op_request);
OpResponse response;
VLOG(2) << "making cross-replica-sum request";
......@@ -1442,6 +1470,7 @@ ComputationDataHandle ComputationBuilder::SelectAndScatterWithGeneralPadding(
OpRequest op_request;
*op_request.mutable_computation() = computation_.handle();
*op_request.mutable_select_and_scatter_request() = request;
AddOpMetadata(&op_request);
OpResponse response;
VLOG(2) << "making select-and-scatter request";
......@@ -1461,6 +1490,7 @@ void ComputationBuilder::Send(const ComputationDataHandle& operand,
OpRequest op_request;
*op_request.mutable_send_request() = request;
*op_request.mutable_computation() = computation_.handle();
AddOpMetadata(&op_request);
OpResponse response;
VLOG(2) << "making send request";
......@@ -1485,6 +1515,7 @@ ComputationDataHandle ComputationBuilder::Recv(const Shape& shape,
OpRequest op_request;
*op_request.mutable_recv_request() = request;
*op_request.mutable_computation() = computation_.handle();
AddOpMetadata(&op_request);
OpResponse response;
VLOG(2) << "making recv request";
......@@ -1520,6 +1551,11 @@ StatusOr<Computation> ComputationBuilder::Build() {
return {std::move(computation_)};
}
void ComputationBuilder::AddOpMetadata(OpRequest* request) const {
tensorflow::mutex_lock lock(mutex_);
*request->mutable_metadata() = metadata_;
}
/* static */ ConvolutionDimensionNumbers
ComputationBuilder::CreateDefaultConvDimensionNumbers(int num_spatial_dims) {
ConvolutionDimensionNumbers dimension_numbers;
......
......@@ -37,6 +37,7 @@ limitations under the License.
#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/stacktrace.h"
#include "tensorflow/core/platform/types.h"
......@@ -61,6 +62,23 @@ class ComputationBuilder {
// Returns the computation name.
const string& name() { return name_; }
// Sets OpMetadata that will be added to all instructions until cleared.
//
// OpMetadata is often applied to a series of XLA HLO instructions. As a
// result, OpMetadata is set on the ComputationBuilder. All subsequent
// instructions generated via this ComputationBuilder will have the same
// OpMetadata attached until a call to ClearOpMetadata.
void SetOpMetadata(const OpMetadata& metadata) {
tensorflow::mutex_lock lock(mutex_);
metadata_ = metadata;
}
// Clears the OpMetadata state.
void ClearOpMetadata() {
tensorflow::mutex_lock lock(mutex_);
metadata_.Clear();
}
// Sets the builder to a mode where it will die immediately when an error is
// encountered, rather than producing it in a deferred fashion when Build() is
// called (which is the default).
......@@ -717,6 +735,8 @@ class ComputationBuilder {
// * dying if die_immediately_on_error_ is true
void NoteError(const Status& error);
void AddOpMetadata(OpRequest* request) const;
string name_; // Name to use for the built computation.
// The first error encountered while building the computation.
......@@ -735,6 +755,14 @@ class ComputationBuilder {
// Mode bit that indicates whether to die when a first error is encountered.
bool die_immediately_on_error_{false};
// Mutex to guard against concurrent access to metadata_.
mutable tensorflow::mutex mutex_;
// The metadata to attach to each op. This is structured as a "modal"-like
// operation, in order to simplify client code (and not sprinkle this metadata
// throughout the TensorFlow op kernel implementations).
OpMetadata metadata_ GUARDED_BY(mutex_);
TF_DISALLOW_COPY_AND_ASSIGN(ComputationBuilder);
};
......
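Because the metadata API above is modal, every SetOpMetadata call needs a matching ClearOpMetadata. A hypothetical RAII helper (not part of this commit) would make the pairing explicit:

// Hypothetical convenience wrapper (not in this commit): scopes an
// OpMetadata to a block of builder calls and clears it on scope exit.
class ScopedOpMetadata {
 public:
  ScopedOpMetadata(xla::ComputationBuilder* builder,
                   const xla::OpMetadata& metadata)
      : builder_(builder) {
    builder_->SetOpMetadata(metadata);
  }
  ~ScopedOpMetadata() { builder_->ClearOpMetadata(); }

 private:
  xla::ComputationBuilder* builder_;  // not owned
};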
......@@ -179,6 +179,18 @@ string InstructionSequenceGraph(
WindowToString(instruction->window());
}
name += "\\n" + instruction->name();
if (!instruction->metadata().op_type().empty()) {
StrAppend(&name, "\\n", instruction->metadata().op_type());
}
if (!instruction->metadata().op_name().empty()) {
StrAppend(&name, "\\n", instruction->metadata().op_name());
}
if (!instruction->metadata().source_file().empty() &&
instruction->metadata().source_line() != 0) {
StrAppend(&name, "\\n", instruction->metadata().source_file(), ":",
instruction->metadata().source_line());
}
std::vector<HloComputation*> called_computations;
// Pick different colors or shapes for instructions which are particularly
......
......@@ -1443,6 +1443,11 @@ string HloInstruction::ToString(bool compact_operands) const {
if (opcode() == HloOpcode::kGetTupleElement) {
tensorflow::strings::StrAppend(&extra, ", index=", tuple_index());
}
if (!metadata_.op_type().empty() || !metadata_.op_name().empty() ||
!metadata_.source_file().empty()) {
tensorflow::strings::StrAppend(
&extra, " # metadata=", metadata_.ShortDebugString());
}
return tensorflow::strings::Printf(
"%s = %s %s(%s)%s", name().c_str(),
ShapeUtil::HumanStringWithLayout(shape()).c_str(),
......
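For illustration (assuming the metadata from hlo_metadata_test.cc below), the root instruction would now print roughly as:

add = f32[] add(x, y) # metadata=op_type: "add" op_name: "my_sum_op"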
......@@ -683,6 +683,10 @@ class HloInstruction {
// Sets the string identifier for this instruction.
void set_name(const string& name) { name_ = name; }
// Sets the debug metadata for this instruction.
void set_metadata(const OpMetadata& metadata) { metadata_ = metadata; }
const OpMetadata& metadata() const { return metadata_; }
// Set/get the computation containing this instruction. set_parent should only
// be called by HloComputation methods which add/remove instructions to
// computations.
......@@ -857,6 +861,9 @@ class HloInstruction {
// The computation in which this instruction is contained.
HloComputation* parent_ = nullptr;
// Metadata for debugging.
OpMetadata metadata_;
TF_DISALLOW_COPY_AND_ASSIGN(HloInstruction);
};
......
......@@ -1246,57 +1246,63 @@ tensorflow::Status Service::AddInstruction(
tensorflow::Status Service::Op(const OpRequest* arg, OpResponse* result) {
TF_ASSIGN_OR_RETURN(UserComputation * computation,
computation_tracker_.Resolve(arg->computation()));
StatusOr<ComputationDataHandle> handle;
StatusOr<ComputationDataHandle> handle_status;
switch (arg->op_case()) {
case OpRequest::kBinaryOpRequest:
handle = computation->AddBinaryInstruction(arg->binary_op_request());
handle_status =
computation->AddBinaryInstruction(arg->binary_op_request());
break;
case OpRequest::kBroadcastRequest:
handle = computation->AddBroadcastInstruction(arg->broadcast_request());
handle_status =
computation->AddBroadcastInstruction(arg->broadcast_request());
break;
case OpRequest::kCallRequest: {
TF_ASSIGN_OR_RETURN(
UserComputation * to_apply,
computation_tracker_.Resolve(arg->call_request().to_apply()));
handle = computation->AddCallInstruction(arg->call_request(), *to_apply);
handle_status =
computation->AddCallInstruction(arg->call_request(), *to_apply);
break;
}
case OpRequest::kConcatenateRequest:
handle =
handle_status =
computation->AddConcatenateInstruction(arg->concatenate_request());
break;
case OpRequest::kConstantRequest:
handle = computation->AddConstantInstruction(arg->constant_request());
handle_status =
computation->AddConstantInstruction(arg->constant_request());
break;
case OpRequest::kConvertRequest:
handle = computation->AddConvertInstruction(arg->convert_request());
handle_status =
computation->AddConvertInstruction(arg->convert_request());
break;
case OpRequest::kConvolveRequest:
handle = computation->AddConvolveInstruction(arg->convolve_request());
handle_status =
computation->AddConvolveInstruction(arg->convolve_request());
break;
case OpRequest::kCrossReplicaSumRequest:
handle = computation->AddCrossReplicaSumInstruction(
handle_status = computation->AddCrossReplicaSumInstruction(
arg->cross_replica_sum_request());
break;
case OpRequest::kCustomCallRequest:
handle =
handle_status =
computation->AddCustomCallInstruction(arg->custom_call_request());
break;
case OpRequest::kDynamicSliceRequest:
handle =
handle_status =
computation->AddDynamicSliceInstruction(arg->dynamic_slice_request());
break;
case OpRequest::kDynamicUpdateSliceRequest:
handle = computation->AddDynamicUpdateSliceInstruction(
handle_status = computation->AddDynamicUpdateSliceInstruction(
arg->dynamic_update_slice_request());
break;
case OpRequest::kGetTupleElementRequest:
handle = computation->AddGetTupleElementInstruction(
handle_status = computation->AddGetTupleElementInstruction(
arg->get_tuple_element_request());
break;
case OpRequest::kInfeedRequest:
handle = computation->AddInfeedInstruction(arg->infeed_request());
handle_status = computation->AddInfeedInstruction(arg->infeed_request());
break;
case OpRequest::kOutfeedRequest:
TF_RETURN_IF_ERROR(
......@@ -1306,20 +1312,22 @@ tensorflow::Status Service::Op(const OpRequest* arg, OpResponse* result) {
TF_ASSIGN_OR_RETURN(
UserComputation * to_apply,
computation_tracker_.Resolve(arg->map_request().to_apply()));
handle = computation->AddMapInstruction(arg->map_request(), *to_apply);
handle_status =
computation->AddMapInstruction(arg->map_request(), *to_apply);
break;
}
case OpRequest::kPadRequest:
handle = computation->AddPadInstruction(arg->pad_request());
handle_status = computation->AddPadInstruction(arg->pad_request());
break;
case OpRequest::kParameterRequest:
handle = computation->AddParameterInstruction(arg->parameter_request());
handle_status =
computation->AddParameterInstruction(arg->parameter_request());
break;
case OpRequest::kReduceRequest: {
TF_ASSIGN_OR_RETURN(
UserComputation * to_apply,
computation_tracker_.Resolve(arg->reduce_request().to_apply()));
handle =
handle_status =
computation->AddReduceInstruction(arg->reduce_request(), *to_apply);
break;
}
......@@ -1327,18 +1335,20 @@ tensorflow::Status Service::Op(const OpRequest* arg, OpResponse* result) {
TF_ASSIGN_OR_RETURN(UserComputation * to_apply,
computation_tracker_.Resolve(
arg->reduce_window_request().to_apply()));
handle = computation->AddReduceWindowInstruction(
handle_status = computation->AddReduceWindowInstruction(
arg->reduce_window_request(), *to_apply);
break;
}
case OpRequest::kReshapeRequest:
handle = computation->AddReshapeInstruction(arg->reshape_request());
handle_status =
computation->AddReshapeInstruction(arg->reshape_request());
break;
case OpRequest::kReverseRequest:
handle = computation->AddReverseInstruction(arg->reverse_request());
handle_status =
computation->AddReverseInstruction(arg->reverse_request());
break;
case OpRequest::kRngRequest:
handle = computation->AddRngInstruction(arg->rng_request());
handle_status = computation->AddRngInstruction(arg->rng_request());
break;
case OpRequest::kSelectAndScatterRequest: {
TF_ASSIGN_OR_RETURN(UserComputation * select,
......@@ -1347,23 +1357,25 @@ tensorflow::Status Service::Op(const OpRequest* arg, OpResponse* result) {
TF_ASSIGN_OR_RETURN(UserComputation * scatter,
computation_tracker_.Resolve(
arg->select_and_scatter_request().scatter()));
handle = computation->AddSelectAndScatterInstruction(
handle_status = computation->AddSelectAndScatterInstruction(
arg->select_and_scatter_request(), *select, *scatter);
break;
}
case OpRequest::kSliceRequest:
handle = computation->AddSliceInstruction(arg->slice_request());
handle_status = computation->AddSliceInstruction(arg->slice_request());
break;
case OpRequest::kTernaryOpRequest:
handle = computation->AddTernaryInstruction(arg->ternary_op_request());
handle_status =
computation->AddTernaryInstruction(arg->ternary_op_request());
break;
case OpRequest::kTraceRequest:
return computation->AddTraceInstruction(arg->trace_request());
case OpRequest::kUnaryOpRequest:
handle = computation->AddUnaryInstruction(arg->unary_op_request());
handle_status = computation->AddUnaryInstruction(arg->unary_op_request());
break;
case OpRequest::kVariadicOpRequest:
handle = computation->AddVariadicInstruction(arg->variadic_op_request());
handle_status =
computation->AddVariadicInstruction(arg->variadic_op_request());
break;
case OpRequest::kWhileRequest: {
TF_ASSIGN_OR_RETURN(
......@@ -1372,8 +1384,8 @@ tensorflow::Status Service::Op(const OpRequest* arg, OpResponse* result) {
TF_ASSIGN_OR_RETURN(
UserComputation * body,
computation_tracker_.Resolve(arg->while_request().body()));
handle = computation->AddWhileInstruction(arg->while_request(),
*condition, *body);
handle_status = computation->AddWhileInstruction(arg->while_request(),
*condition, *body);
break;
}
case OpRequest::kSendRequest: {
......@@ -1385,13 +1397,19 @@ tensorflow::Status Service::Op(const OpRequest* arg, OpResponse* result) {
case OpRequest::kRecvRequest: {
TF_RETURN_IF_ERROR(
channel_tracker_.RegisterRecv(arg->recv_request().channel_handle()));
handle = computation->AddRecvInstruction(arg->recv_request());
handle_status = computation->AddRecvInstruction(arg->recv_request());
break;
}
default:
return InvalidArgument("Unsupported operation");
}
TF_ASSIGN_OR_RETURN(*result->mutable_output(), handle);
TF_ASSIGN_OR_RETURN(*result->mutable_output(), handle_status);
// We set the debug metadata here, because we slice off part of the OpRequest
// proto in the above switch statement.
TF_ASSIGN_OR_RETURN(ComputationDataHandle handle, handle_status);
TF_RETURN_IF_ERROR(computation->SetOpMetadata(handle, arg->metadata()));
return tensorflow::Status::OK();
}
......
......@@ -1091,6 +1091,22 @@ StatusOr<Shape> UserComputation::GetShape(const ComputationDataHandle& handle) {
return operand->output_shape();
}
Status UserComputation::SetOpMetadata(const ComputationDataHandle& handle,
const OpMetadata& metadata) {
tensorflow::mutex_lock lock(mutex_);
int64 handle_value = handle.handle();
if (session_computation_.requests().count(handle_value) == 0) {
return InvalidArgument("Invalid handle in SetOpMetadata (%lld)",
handle_value);
}
*session_computation_.mutable_requests()
->at(handle_value)
.mutable_request()
->mutable_metadata() = metadata;
return Status::OK();
}
Status UserComputation::SetReturnValue(const ComputationDataHandle& handle) {
tensorflow::mutex_lock lock(mutex_);
......@@ -2314,6 +2330,7 @@ HloInstruction* ComputationLowerer::Visit(
default:
LOG(FATAL) << "Unexpected request type: " << request.request().op_case();
}
hlo_instruction->set_metadata(request.request().metadata());
(*visited)[handle.handle()] = hlo_instruction;
return hlo_instruction;
}
......
......@@ -236,6 +236,10 @@ class UserComputation {
// Returns the output shape of the operation indicated by the given handle.
StatusOr<Shape> GetShape(const ComputationDataHandle& handle);
// Sets metadata on the HLO instruction referenced by the given handle.
Status SetOpMetadata(const ComputationDataHandle& handle,
const OpMetadata& metadata);
// Builds a HLO computation from the UserComputation. The parameter "resolver"
// is a function which returns a pointer to the HloComputation corresponding
// to the given ComputationHandle at the given version. The resolver is used
......
......@@ -1346,6 +1346,22 @@ cc_test(
],
)
cc_test(
name = "hlo_metadata_test",
srcs = [
"hlo_metadata_test.cc",
],
deps = [
":local_client_test_base",
"//tensorflow/compiler/xla:test_helpers",
"//tensorflow/compiler/xla/client:client_library",
"//tensorflow/compiler/xla/client:computation_builder",
"//tensorflow/compiler/xla/service:computation_tracker",
"//tensorflow/compiler/xla/service:local_service",
"//tensorflow/core:test_main",
],
)
xla_test(
name = "round_trip_transfer_test",
srcs = ["round_trip_transfer_test.cc"],
......
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/xla/client/computation_builder.h"
#include "tensorflow/compiler/xla/client/local_client.h"
#include "tensorflow/compiler/xla/service/computation_tracker.h"
#include "tensorflow/compiler/xla/service/local_service.h"
#include "tensorflow/compiler/xla/test_helpers.h"
#include "tensorflow/compiler/xla/tests/local_client_test_base.h"
namespace xla {
namespace {
class HloMetadataTest : public LocalClientTestBase {
protected:
HloMetadataTest() {
metadata_.set_op_type("add");
metadata_.set_op_name("my_sum_op");
}
void BuildAddComputation(ComputationBuilder* builder) {
auto x = builder->Parameter(0, ShapeUtil::MakeShape(F32, {}), "x");
auto y = builder->Parameter(1, ShapeUtil::MakeShape(F32, {}), "y");
builder->Add(x, y);
}
OpMetadata metadata_;
};
TEST_F(HloMetadataTest, MetadataPropagation) {
ComputationBuilder builder(local_client_, "add");
builder.SetOpMetadata(metadata_);
BuildAddComputation(&builder);
builder.ClearOpMetadata();
Shape argument_layout = ShapeUtil::MakeShape(F32, {});
TF_ASSIGN_OR_ASSERT_OK(
std::unique_ptr<LocalExecutable> executable,
local_client_->Compile(builder.Build().ValueOrDie(),
{&argument_layout, &argument_layout},
ExecutableBuildOptions()));
auto instruction = executable->executable()
->module()
.entry_computation()
->root_instruction();
EXPECT_EQ("add", instruction->metadata().op_type());
EXPECT_EQ("my_sum_op", instruction->metadata().op_name());
}
TEST_F(HloMetadataTest, MetadataClearing) {
ComputationBuilder builder(local_client_, "add");
builder.SetOpMetadata(metadata_);
// Some other pretend computation here.
builder.ClearOpMetadata();
BuildAddComputation(&builder);
Shape argument_layout = ShapeUtil::MakeShape(F32, {});
auto executable_status = local_client_->Compile(
builder.Build().ValueOrDie(), {&argument_layout, &argument_layout},
ExecutableBuildOptions());
ASSERT_IS_OK(executable_status);
std::unique_ptr<LocalExecutable> executable =
executable_status.ConsumeValueOrDie();
auto instruction = executable->executable()
->module()
.entry_computation()
->root_instruction();
// We expect these to be empty (no metadata set).
EXPECT_EQ("", instruction->metadata().op_type());
EXPECT_EQ("", instruction->metadata().op_name());
}
} // namespace
} // namespace xla
......@@ -178,6 +178,31 @@ message ComputationStats {
double transcendental_count = 2;
}
// Symbolization metadata for HLO Instructions.
//
// This metadata is used for debugging XLA code generation, as well as
// performance profiling of XLA-generated executables.
message OpMetadata {
// The framework op name that generated this XLA op.
//
// Frameworks that build on top of XLA should mirror the names of their ops
// back to users by specifying the op_type. In this way, even if the
// framework's "ops" are implemented as multiple XLA HLO Ops, they can be
// grouped appropriately. (e.g. if a SoftMax layer is emitted into XLA as
// multiple ops, then each of those ops should have its op_type set to "SoftMax".)
string op_type = 1;
// The user-specified name of the op.
//
// This name is often unique within a computation. Note: some frameworks
// add auto-generated names if the user does not provide one.
string op_name = 2;
// Indicates the file and line that this op is associated with in a user's program.
//
// e.g. it could be the file and line of user code that generated the op.
string source_file = 3;
int32 source_line = 4;
}
// Profile data from the execution of a computation.
message ExecutionProfile {
// Whether the executable was read from the compilation cache.
......@@ -690,6 +715,7 @@ message RecvRequest {
message OpRequest {
ComputationHandle computation = 1;
OpMetadata metadata = 33;
oneof op {
BinaryOpRequest binary_op_request = 2;
......@@ -723,7 +749,7 @@ message OpRequest {
SendRequest send_request = 30;
RecvRequest recv_request = 31;
OutfeedRequest outfeed_request = 32;
// Next: 33
// Next: 34
}
}
......
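As an illustrative example (hypothetical values), a framework could populate all four OpMetadata fields as sketched below; every HLO op lowered from the same framework op would then share this metadata:

// Hypothetical values: a SoftMax layer lowered into several HLO ops.
xla::OpMetadata metadata;
metadata.set_op_type("SoftMax");
metadata.set_op_name("layer1/softmax");
metadata.set_source_file("model.py");
metadata.set_source_line(42);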