Commit 905deeaa authored by A. Unique TensorFlower, committed by TensorFlower Gardener

Automated rollback of commit 3abfe2cd

PiperOrigin-RevId: 216633097
Parent 9153b897
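This rollback restores XRTCompile to a single output: the scalar int64 handle into the compilation cache. The second `program_shape` output added by 3abfe2cd is removed, so callers pass the op's lone output directly to XRTExecute (`c_handle` rather than `c_handle.handle`). A minimal usage sketch under the restored API, adapted from the raw_api_test.cc changes below; `CompileAndRunNoArgs` is a hypothetical helper for illustration, not part of this commit, and it assumes the same includes and `namespace tensorflow` context as the test file:

    // Sketch only: drives the single-output XRTCompile -> XRTExecute flow.
    // `computation_proto` holds a serialized xrt::XLAComputation and
    // `execution_config` a serialized xrt::XRTExecutionConfig, both built
    // beforehand as in the tests below.
    Status CompileAndRunNoArgs(const string& computation_proto,
                               const string& execution_config,
                               string* result_literal) {
      Scope root = Scope::NewRootScope().WithDevice("/device:XLA_CPU:0");
      auto computation =
          ops::Const(root.WithDevice("/device:CPU:0"), computation_proto);
      // After this rollback, c_handle is the op's only output: a scalar
      // int64 handle (there is no separate program_shape output).
      auto c_handle = ops::XRTCompile(root, computation);
      auto e_config =
          ops::Const(root.WithDevice("/device:CPU:0"), execution_config);
      auto result = ops::XRTExecute(root, c_handle, e_config,
                                    std::initializer_list<Input>({}));
      auto read_back = ops::XRTReadLiteralAndRelease(root, result);
      TF_RETURN_IF_ERROR(root.status());

      ClientSession session(root);
      std::vector<Tensor> outputs;
      TF_RETURN_IF_ERROR(session.Run({read_back}, &outputs));
      *result_literal = outputs[0].scalar<string>()();  // xla::LiteralProto bytes
      return Status::OK();
    }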
--- a/tensorflow/compiler/xla/service/compile_only_service.cc
+++ b/tensorflow/compiler/xla/service/compile_only_service.cc
@@ -89,8 +89,6 @@ CompileOnlyService::CompileAheadOfTime(
     const auto& program_shape = instance.computation.program_shape();
     ExecutionOptions execution_options;
     *execution_options.mutable_debug_options() = debug_options;
-    *execution_options.mutable_shape_with_output_layout() =
-        *instance.result_layout;
     TF_ASSIGN_OR_RETURN(
         std::unique_ptr<HloModuleConfig> module_config,
         CreateModuleConfig(program_shape, instance.argument_layouts,
--- a/tensorflow/compiler/xrt/kernels/xrt_compile_ops.cc
+++ b/tensorflow/compiler/xrt/kernels/xrt_compile_ops.cc
@@ -166,21 +166,10 @@ void XRTCompileOp::Compute(OpKernelContext* ctx) {
                         VLOG(1) << "Compiling XLA executable";
                         return Compile(ctx, computation_proto, program);
                       }));
-  std::unique_ptr<XRTCompilationCacheEntryRef> entry;
-  OP_REQUIRES_OK(ctx, cache->Lookup(uid, &entry));
-
-  Tensor handle_output(DT_INT64, TensorShape({}));
-  handle_output.scalar<int64>()() = uid;
-  ctx->set_output(0, handle_output);
-
-  xla::LocalExecutable* executable = entry->get().get_executable();
-  xla::ProgramShape program_shape = executable->executable()
-                                        ->module()
-                                        .entry_computation()
-                                        ->ComputeProgramShape();
-  Tensor program_shape_output(DT_STRING, TensorShape({1}));
-  program_shape_output.vec<string>()(0) = program_shape.SerializeAsString();
-  ctx->set_output(1, program_shape_output);
+
+  Tensor output(DT_INT64, TensorShape({}));
+  output.scalar<int64>()() = uid;
+  ctx->set_output(0, output);
 }

 XRTCompileOp::~XRTCompileOp() = default;
--- a/tensorflow/compiler/xrt/kernels/xrt_execute_op.cc
+++ b/tensorflow/compiler/xrt/kernels/xrt_execute_op.cc
@@ -64,6 +64,14 @@ uint32 GetXLARandomSeed() {
   return counter.fetch_add(2);
 }

+// Looks up the input `key` in the compilation cache.
+Status GetComputationCacheEntry(
+    XRTCompilationCache* cache, int64 key,
+    std::unique_ptr<XRTCompilationCacheEntryRef>* entry) {
+  TF_RETURN_IF_ERROR(cache->Lookup(key, entry));
+  return Status::OK();
+}
+
 // Populates `inputs` with the input tensors to the computation.
 Status GetComputationInputs(OpKernelContext* context, ResourceMgr* rm,
                             bool release_inputs,
--- a/tensorflow/compiler/xrt/ops/xrt_compile_ops.cc
+++ b/tensorflow/compiler/xrt/ops/xrt_compile_ops.cc
@@ -23,12 +23,7 @@ namespace tensorflow {
 REGISTER_OP("XRTCompile")
     .Input("computation: string")
     .Output("handle: int64")
-    .Output("program_shape: string")
-    .SetShapeFn([](shape_inference::InferenceContext* c) {
-      c->set_output(0, c->Scalar());
-      c->set_output(1, c->UnknownShapeOfRank(1));
-      return Status::OK();
-    })
+    .SetShapeFn(tensorflow::shape_inference::ScalarShape)
     .Doc(
         R"(
Reads a computation proto, compiles it, and places it in the global compilation
--- a/tensorflow/compiler/xrt/tests/BUILD
+++ b/tensorflow/compiler/xrt/tests/BUILD
@@ -29,11 +29,8 @@ cc_library(
         "//tensorflow/compiler/xla:literal_util",
         "//tensorflow/compiler/xla:shape_util",
         "//tensorflow/compiler/xla:xla_data_proto",
-        "//tensorflow/compiler/xla/client:client_library",
-        "//tensorflow/compiler/xla/client:local_client",
         "//tensorflow/compiler/xla/client:xla_builder",
         "//tensorflow/compiler/xla/client:xla_computation",
-        "//tensorflow/compiler/xla/service:platform_util",
         "//tensorflow/compiler/xrt:xrt_proto",
        "//tensorflow/compiler/xrt:xrt_server",
        "//tensorflow/compiler/xrt/cc:xrt_ops",
@@ -52,10 +49,7 @@ tf_cc_test(
     name = "raw_api_test_cpu",
     size = "medium",
     srcs = [],
-    args = [
-        "--xla_test_device=XLA_CPU",
-        "--xla_platform=CPU",
-    ],
+    args = ["--xla_test_device=XLA_CPU"],
     deps = [
         ":raw_api_test_lib",
         "//tensorflow/compiler/jit:xla_cpu_device",
@@ -66,10 +60,7 @@ tf_cuda_cc_test(
     name = "raw_api_test_gpu",
     size = "medium",
     srcs = [],
-    args = [
-        "--xla_test_device=XLA_GPU",
-        "--xla_platform=GPU",
-    ],
+    args = ["--xla_test_device=XLA_GPU"],
     tags = tf_cuda_tests_tags(),
     deps = [
         ":raw_api_test_lib",
--- a/tensorflow/compiler/xrt/tests/raw_api_test.cc
+++ b/tensorflow/compiler/xrt/tests/raw_api_test.cc
@@ -22,13 +22,10 @@ limitations under the License.
 #include "tensorflow/cc/framework/ops.h"
 #include "tensorflow/cc/framework/scope.h"
 #include "tensorflow/cc/ops/standard_ops.h"
-#include "tensorflow/compiler/xla/client/client_library.h"
-#include "tensorflow/compiler/xla/client/local_client.h"
 #include "tensorflow/compiler/xla/client/xla_builder.h"
 #include "tensorflow/compiler/xla/client/xla_computation.h"
 #include "tensorflow/compiler/xla/literal.h"
 #include "tensorflow/compiler/xla/literal_util.h"
-#include "tensorflow/compiler/xla/service/platform_util.h"
 #include "tensorflow/compiler/xla/shape_util.h"
 #include "tensorflow/compiler/xla/xla_data.pb.h"
 #include "tensorflow/compiler/xrt/cc/ops/xrt_compile_ops.h"
@@ -46,7 +43,6 @@ namespace tensorflow {
 namespace {

 string* xla_test_device_ptr;  // initial value set in main()
-string* xla_platform_ptr;     // initial value set in main()

 string DeviceFromFlag() {
   string xla_test_device = *xla_test_device_ptr;
@@ -149,28 +145,6 @@ void StoreComputationSnapshot(const xla::XlaComputation& computation,
   *dst = *snapshot;
 }

-xla::ProgramShape XlaCompiledProgramShape(
-    const xla::XlaComputation& computation,
-    const xla::ProgramShape& input_program_shape) {
-  se::Platform* platform =
-      xla::PlatformUtil::GetPlatform(*xla_platform_ptr).ValueOrDie();
-  xla::LocalClient* client =
-      xla::ClientLibrary::GetOrCreateLocalClient(platform).ValueOrDie();
-  xla::ExecutableBuildOptions exec_options;
-  exec_options.set_result_layout(input_program_shape.result());
-  std::vector<const xla::Shape*> parameters_shapes;
-  for (int64 i = 0; i < input_program_shape.parameters_size(); ++i) {
-    parameters_shapes.push_back(&input_program_shape.parameters(i));
-  }
-  auto local_executable =
-      client->Compile(computation, parameters_shapes, exec_options)
-          .ValueOrDie();
-  return local_executable->executable()
-      ->module()
-      .entry_computation()
-      ->ComputeProgramShape();
-}
-
 TEST(RawApiTest, ReadAndWriteState) {
   xrt::XLAAllocation alloc;
   alloc.set_device_ordinal(0);
@@ -364,87 +338,20 @@ TEST(RawApiTest, CompileAndExecute) {
   auto p1_value =
       ops::Const(root.WithDevice("/device:CPU:0"), p1.SerializeAsString());
   auto p1_handle = ops::XRTAllocate(root, p1_value);
-  auto result = ops::XRTExecute(root, c_handle.handle, e_config,
+  auto result = ops::XRTExecute(root, c_handle, e_config,
                                 {Output(p0_handle), Output(p1_handle)});
   auto read_back = ops::XRTReadLiteralAndRelease(root, result);
   TF_ASSERT_OK(root.status());

   ClientSession session(root);
   std::vector<Tensor> outputs;
-  TF_EXPECT_OK(session.Run({read_back, c_handle.program_shape}, &outputs));
+  TF_EXPECT_OK(session.Run({read_back}, &outputs));

   xla::LiteralProto response;
   EXPECT_TRUE(response.ParseFromString(outputs[0].scalar<string>()()));

   auto expected = xla::LiteralUtil::CreateR1<float>({27.0f, 21.0f});
   EXPECT_TRUE(CompareLiteralToLiteralProto(expected, response));
-
-  xla::ProgramShape program_shape;
-  EXPECT_TRUE(program_shape.ParseFromString(outputs[1].vec<string>()(0)));
-  EXPECT_EQ(program_shape.parameters_size(), 2);
 }
-
-TEST(RawApiTest, CompileWithXlaReturnShapes) {
-  xla::XlaBuilder builder("XrtXlaShapes");
-  auto input_shape = xla::ShapeUtil::MakeShape(xla::BF16, {32, 3, 128, 128});
-  auto kernel_shape = xla::ShapeUtil::MakeShape(xla::BF16, {3, 3, 5, 5});
-  // Clear layouts to signal XLA we are ready to get whatever are coming out of
-  // the compilation process.
-  xla::LayoutUtil::ClearLayout(&input_shape);
-  xla::LayoutUtil::ClearLayout(&kernel_shape);
-  auto param_shape =
-      xla::ShapeUtil::MakeTupleShape({input_shape, kernel_shape});
-  auto param = xla::Parameter(&builder, 0, param_shape, "param");
-  auto input = xla::GetTupleElement(param, 0);
-  auto kernel = xla::GetTupleElement(param, 1);
-  xla::Conv(input, kernel, {1, 1}, xla::Padding::kSame);
-  TF_ASSERT_OK_AND_ASSIGN(xla::XlaComputation xla_computation, builder.Build());
-
-  auto result_shape = xla_computation.GetProgramShape().ValueOrDie().result();
-  // Clear the result shape layout to tell XLA we are accepting whatever are
-  // coming out of the compilation process.
-  xla::LayoutUtil::ClearLayout(&result_shape);
-
-  xrt::XLAComputation c;
-  auto config = c.mutable_config();
-  auto shapes = config->mutable_program_shape();
-  *shapes->add_parameters() = param_shape;
-  *shapes->mutable_result() = result_shape;
-  StoreComputationSnapshot(xla_computation, c.mutable_hlo_snapshot());
-
-  Scope root = Scope::NewRootScope().WithDevice(DeviceFromFlag());
-  auto computation =
-      ops::Const(root.WithDevice("/device:CPU:0"), c.SerializeAsString());
-  auto c_handle = ops::XRTCompile(root, computation);
-  auto release = ops::XRTReleaseCompilationHandle(root, c_handle.handle);
-  TF_ASSERT_OK(root.status());
-
-  ClientSession session(root);
-  std::vector<Tensor> outputs;
-  TF_EXPECT_OK(session.Run(tensorflow::ClientSession::FeedType(),
-                           {c_handle.program_shape}, {release}, &outputs));
-
-  xla::ProgramShape program_shape;
-  EXPECT_TRUE(program_shape.ParseFromString(outputs[0].vec<string>()(0)));
-  EXPECT_EQ(program_shape.parameters_size(), 1);
-
-  VLOG(2) << "Param: "
-          << xla::ShapeUtil::HumanStringWithLayout(program_shape.parameters(0));
-  VLOG(2) << "Result: "
-          << xla::ShapeUtil::HumanStringWithLayout(program_shape.result());
-
-  xla::ProgramShape xla_program_shape =
-      XlaCompiledProgramShape(xla_computation, *shapes);
-  EXPECT_TRUE(xla::LayoutUtil::Equal(
-      xla::ShapeUtil::GetSubshape(program_shape.parameters(0), {0}).layout(),
-      xla::ShapeUtil::GetSubshape(xla_program_shape.parameters(0), {0})
-          .layout()));
-  EXPECT_TRUE(xla::LayoutUtil::Equal(
-      xla::ShapeUtil::GetSubshape(program_shape.parameters(0), {1}).layout(),
-      xla::ShapeUtil::GetSubshape(xla_program_shape.parameters(0), {1})
-          .layout()));
-  EXPECT_TRUE(xla::LayoutUtil::Equal(program_shape.result().layout(),
-                                     xla_program_shape.result().layout()));
-}

 TEST(RawApiTest, CompileAndExecuteZeroArg) {
@@ -464,7 +371,7 @@ TEST(RawApiTest, CompileAndExecuteZeroArg) {
   auto computation =
       ops::Const(root.WithDevice("/device:CPU:0"), c.SerializeAsString());
   auto c_handle = ops::XRTCompile(root, computation);
-  auto result = ops::XRTExecute(root, c_handle.handle, e_config,
+  auto result = ops::XRTExecute(root, c_handle, e_config,
                                 std::initializer_list<Input>({}));
   auto read_back = ops::XRTReadLiteralAndRelease(root, result);
   TF_ASSERT_OK(root.status());
@@ -513,7 +420,7 @@ TEST(RawApiTest, CompileAndExecuteReturnTuple) {
   auto p1_value =
       ops::Const(root.WithDevice("/device:CPU:0"), p1.SerializeAsString());
   auto p1_handle = ops::XRTAllocate(root, p1_value);
-  auto result = ops::XRTExecute(root, c_handle.handle, e_config,
+  auto result = ops::XRTExecute(root, c_handle, e_config,
                                 {Output(p0_handle), Output(p1_handle)});
   auto read_back = ops::XRTReadLiteralAndRelease(root, result);
   TF_ASSERT_OK(root.status());
@@ -548,7 +455,7 @@ TEST(RawApiTest, LeakCompilationReference) {

   ClientSession session(root);
   std::vector<Tensor> outputs;
-  TF_EXPECT_OK(session.Run({c_handle.handle}, &outputs));
+  TF_EXPECT_OK(session.Run({c_handle}, &outputs));
 }

 }  // namespace
@@ -557,12 +464,9 @@ TEST(RawApiTest, LeakCompilationReference) {

 int main(int argc, char** argv) {
   tensorflow::xla_test_device_ptr = new tensorflow::string("XLA_CPU");
-  tensorflow::xla_platform_ptr = new tensorflow::string("CPU");
   std::vector<tensorflow::Flag> flag_list = {
       tensorflow::Flag("xla_test_device", tensorflow::xla_test_device_ptr,
                        "Tensorflow device type to use for test, e.g., XLA_CPU"),
-      tensorflow::Flag("xla_platform", tensorflow::xla_platform_ptr,
-                       "The XLA platform to select for the device"),
   };
   tensorflow::string usage = tensorflow::Flags::Usage(argv[0], flag_list);
   const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list);