Commit b99dffb4 authored by Hye Soo Yang, committed by TensorFlower Gardener

Open source kernel for `GetMinibatchesInCsrWithPhysicalReplicaOp` for SparseCore.

Open source sparse_core_ops_utils*

PiperOrigin-RevId: 565156105
Parent 67c0625e
@@ -134,6 +134,28 @@ cc_library(
hdrs = ["tpu_compile_op_options.h"],
)
cc_library(
name = "sparse_core_preprocess_ops",
srcs = ["sparse_core_preprocess_ops.cc"],
hdrs = ["sparse_core_preprocess_ops.h"],
deps = [
":sparse_core_ops_utils",
"//tensorflow/core:framework",
"//tensorflow/core:lib_proto_parsing",
"//tensorflow/core:portable_gif_internal",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core/platform:errors",
"@com_google_absl//absl/algorithm:container",
"@com_google_absl//absl/log",
"@com_google_absl//absl/status",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/types:span",
"@local_xla//xla:util",
"@local_xla//xla/stream_executor/tpu:tpu_api",
"@local_xla//xla/stream_executor/tpu:tpu_ops_c_api_hdrs",
],
)
tf_kernel_library(
name = "tpu_configuration_ops",
srcs = ["tpu_configuration_ops.cc"],
@@ -1314,3 +1336,26 @@ tf_cc_test(
"@local_tsl//tsl/protobuf:error_codes_proto_impl_cc",
],
)
cc_library(
name = "sparse_core_ops_utils",
srcs = ["sparse_core_ops_utils.cc"],
hdrs = ["sparse_core_ops_utils.h"],
deps = [
"//tensorflow/core:lib",
"@com_google_absl//absl/algorithm:container",
"@com_google_absl//absl/base:core_headers",
"@com_google_absl//absl/numeric:bits",
],
)
tf_cc_test(
name = "sparse_core_ops_utils_test",
srcs = ["sparse_core_ops_utils_test.cc"],
deps = [
":sparse_core_ops_utils",
"//tensorflow/core:portable_gif_internal",
"@com_google_absl//absl/strings",
"@com_google_googletest//:gtest_main",
],
)
/* Copyright 2023 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/tpu/kernels/sparse_core_ops_utils.h"
#include <vector>
#include "absl/algorithm/container.h"
#include "absl/base/attributes.h"
#include "absl/numeric/bits.h"
#include "tensorflow/core/platform/types.h"
namespace tensorflow {
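// Decodes a binary split bitmask into sorted bucket boundary positions.
// Bit i of `split` corresponds to node i (heap indexing, root at 0) of an
// implicit binary tree over the division range [0, 2^max_division_level);
// a set bit adds a boundary at the midpoint of that node's subrange.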
std::vector<int> ConvertBinarySplitsToBucketSplits(int64 split,
int max_division_level) {
std::vector<int> bucket_splits;
uint32 current_index = 0;
while (split > 0) {
if (split % 2 == 1) {
int split_level = absl::bit_width(current_index + 1) - 1;
int split_offset = current_index - (1 << split_level) + 1;
int split_size = 1 << (max_division_level - 1 - split_level);
bucket_splits.push_back(split_size + split_offset * split_size * 2);
}
split >>= 1;
current_index += 1;
}
absl::c_sort(bucket_splits);
return bucket_splits;
}
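// Inverse of ConvertBinarySplitsToBucketSplits: encodes bucket boundary
// positions back into the binary split bitmask for the given
// max_division_level.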
int64 ConvertBucketSplitsToBinarySplits(std::vector<int> bucket_splits,
int max_division_level) {
int64 binary_splits = 0;
for (auto& bucket_split : bucket_splits) {
int split_level = max_division_level - 1;
while (bucket_split > 0 && bucket_split % 2 == 0) {
--split_level;
bucket_split = bucket_split >> 1;
}
binary_splits |= (1LL << ((1 << split_level) - 1 + bucket_split / 2));
}
return binary_splits;
}
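// Defined as a weak symbol so that an alternative definition can override the
// default maximum division level at link time.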
ABSL_ATTRIBUTE_WEAK int GetMinibatchMaxDivisionLevel() {
return kMinibatchMaxDivisionLevel;
}
} // namespace tensorflow
/* Copyright 2023 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_CORE_TPU_KERNELS_SPARSE_CORE_OPS_UTILS_H_
#define TENSORFLOW_CORE_TPU_KERNELS_SPARSE_CORE_OPS_UTILS_H_
#include <cstdint>
#include <limits>
#include <vector>
#include "tensorflow/core/platform/types.h"
namespace tensorflow {
constexpr int kMinibatchMaxDivisionLevel = 6;
// Pad value used for SparseCore mini batching logic.
const int32_t kXlaPadValue = std::numeric_limits<int32_t>::max();
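// Converts a binary split bitmask into sorted bucket split positions within
// [0, 2^max_division_level).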
std::vector<int> ConvertBinarySplitsToBucketSplits(int64 split,
int max_division_level);
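// Converts bucket split positions back into the binary split bitmask; the
// inverse of ConvertBinarySplitsToBucketSplits.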
int64 ConvertBucketSplitsToBinarySplits(std::vector<int> bucket_splits,
int max_division_level);
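// Returns the maximum division level used for minibatch splitting. The default
// (weak) implementation returns kMinibatchMaxDivisionLevel.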
int GetMinibatchMaxDivisionLevel();
} // namespace tensorflow
#endif // TENSORFLOW_CORE_TPU_KERNELS_SPARSE_CORE_OPS_UTILS_H_
/* Copyright 2023 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/tpu/kernels/sparse_core_ops_utils.h"
#include <vector>
#include <gtest/gtest.h>
#include "tensorflow/core/platform/types.h"
namespace tensorflow {
namespace {
TEST(ConvertSplitsAndBackTest, Split0) {
const int max_division_level = 6;
int64 original_split = 0;
std::vector<int> actual_buckets =
ConvertBinarySplitsToBucketSplits(original_split, max_division_level);
  std::vector<int> expected_buckets = {};
  EXPECT_EQ(actual_buckets, expected_buckets);
int64 re_split =
ConvertBucketSplitsToBinarySplits(expected_buckets, max_division_level);
ASSERT_EQ(re_split, original_split);
}
TEST(ConvertSplitsAndBackTest, Split2) {
const int max_division_level = 6;
int64 original_split = 2;
std::vector<int> actual_buckets =
ConvertBinarySplitsToBucketSplits(original_split, max_division_level);
  std::vector<int> expected_buckets = {16};
  EXPECT_EQ(actual_buckets, expected_buckets);
int64 re_split =
ConvertBucketSplitsToBinarySplits(expected_buckets, max_division_level);
ASSERT_EQ(re_split, original_split);
}
TEST(ConvertSplitsAndBackTest, Split3) {
const int max_division_level = 6;
int64 original_split = 3;
std::vector<int> actual_buckets =
ConvertBinarySplitsToBucketSplits(original_split, max_division_level);
  std::vector<int> expected_buckets = {16, 32};
  EXPECT_EQ(actual_buckets, expected_buckets);
int64 re_split =
ConvertBucketSplitsToBinarySplits(expected_buckets, max_division_level);
ASSERT_EQ(re_split, original_split);
}
} // namespace
} // namespace tensorflow
/* Copyright 2023 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/tpu/kernels/sparse_core_preprocess_ops.h"
#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <cstring>
#include <string>
#include <vector>
#include "absl/log/log.h"
#include "absl/status/status.h"
#include "absl/strings/str_cat.h"
#include "xla/stream_executor/tpu/tpu_api.h"
#include "xla/stream_executor/tpu/tpu_ops_c_api.h"
#include "xla/util.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/op_requires.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/tstring.h"
#include "tensorflow/core/tpu/kernels/sparse_core_ops_utils.h"
namespace tensorflow {
GetMinibatchesInCsrWithPhysicalReplicaOp::
GetMinibatchesInCsrWithPhysicalReplicaOp(OpKernelConstruction* ctx)
: OpKernel(ctx) {
OP_REQUIRES_OK(ctx, ctx->GetAttr("table_name", &table_name_));
OP_REQUIRES_OK(ctx, ctx->GetAttr("num_replica", &num_replica_));
OP_REQUIRES_OK(ctx, ctx->GetAttr("sample_count", &sample_count_));
OP_REQUIRES_OK(
ctx, ctx->GetAttr("max_minibatches_per_sc", &max_minibatches_per_sc_));
OP_REQUIRES_OK(ctx, ctx->GetAttr("max_ids_per_chip_per_sample",
&max_ids_per_chip_per_sample_));
OP_REQUIRES_OK(ctx, ctx->GetAttr("table_vocab_size", &table_vocab_size_));
OP_REQUIRES_OK(ctx, ctx->GetAttr("feature_width", &feature_width_));
OP_REQUIRES_OK(ctx, ctx->GetAttr("num_sc_per_chip", &num_sc_per_chip_));
device_name_ = ctx->device()->name();
OP_REQUIRES(ctx, sample_count_ % num_sc_per_chip_ == 0,
absl::InvalidArgumentError(absl::StrCat(
"sample_count ", sample_count_,
" is not divisible by the number of sparsecores per chip ",
num_sc_per_chip_)));
}
void GetMinibatchesInCsrWithPhysicalReplicaOp::Compute(OpKernelContext* ctx) {
VLOG(1) << "Compute GetMinibatchesInCsrWithPhysicalReplicaOp";
const Tensor* row_ids;
OP_REQUIRES_OK(ctx, ctx->input("row_ids", &row_ids));
const Tensor* col_ids;
OP_REQUIRES_OK(ctx, ctx->input("col_ids", &col_ids));
const Tensor* gains;
OP_REQUIRES_OK(ctx, ctx->input("gains", &gains));
const Tensor* splits;
OP_REQUIRES_OK(ctx, ctx->input("splits", &splits));
const Tensor* id_counts;
OP_REQUIRES_OK(ctx, ctx->input("id_counts", &id_counts));
// TODO(patn): Allow clients to provide the max_ids and max_uniques directly
// making program_key optional. This would be useful if there's a need to
// use this op without the bridge.
const Tensor* program_key_t;
OP_REQUIRES_OK(ctx, ctx->input("program_key", &program_key_t));
tstring program_key = program_key_t->vec<tstring>()(0);
int64_t per_sparse_core_batch_size = sample_count_ / num_sc_per_chip_;
int64_t max_ids_per_partition = -1;
int64_t max_unique_ids_per_partition = -1;
GetMaxIdsAndUniques(ctx, program_key, table_name_, per_sparse_core_batch_size,
feature_width_, &max_ids_per_partition,
&max_unique_ids_per_partition);
const int32* row_ids_tensor_ptr = row_ids->flat<int32>().data();
const int32* col_ids_tensor_ptr = col_ids->flat<int32>().data();
const float* gains_tensor_ptr = gains->flat<float>().data();
const int64* splits_tensor_ptr = splits->flat<int64>().data();
const int32* id_counts_tensor_ptr = id_counts->flat<int32>().data();
const int num_physical_replica = num_replica_ * num_sc_per_chip_;
size_t xla_pad_size = stream_executor::tpu::OpsApiFn()
->TpuUtil_GetXlaPadSizeFromTpuTopologyFn();
  OP_REQUIRES(ctx, sample_count_ % num_sc_per_chip_ == 0,
              absl::InvalidArgumentError(absl::StrCat(
                  "sample_count has to be a multiple of num_sc_per_chip ",
                  num_sc_per_chip_, ", but got ", sample_count_, " samples.")));
const int max_division_level = GetMinibatchMaxDivisionLevel();
const int32 kMaxDivisions = 1 << max_division_level;
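  // OR together all split bitmasks from the `splits` input so that a single
  // set of bucket boundaries is used for every SparseCore.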
int64 binary_splits = 0;
for (int i = 0; i < splits->NumElements(); ++i) {
binary_splits |= *(splits_tensor_ptr + i);
}
std::vector<int> bucket_splits =
ConvertBinarySplitsToBucketSplits(binary_splits, max_division_level);
const int32 num_minibatch_per_sc = bucket_splits.size() + 1;
OP_REQUIRES(
ctx, num_minibatch_per_sc <= max_minibatches_per_sc_,
absl::InvalidArgumentError(absl::StrCat(
"The number of minibatches per sparse core is ", num_minibatch_per_sc,
". But the max minibatches per sparse core is set to be ",
max_minibatches_per_sc_, " which is smaller.")));
VLOG(2) << "GetMinibatchesInCsrWithPhysicalReplicaOp: "
<< "program_key ='" << program_key << "'"
<< ", table_name = " << table_name_
<< ", max_ids = " << max_ids_per_partition
<< ", max_uniques = " << max_unique_ids_per_partition
<< ", num_minibatch_per_sc = " << num_minibatch_per_sc;
const int total_num_minibatch = num_minibatch_per_sc * num_sc_per_chip_;
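  // Add the outer boundaries (0 and kMaxDivisions) so that consecutive entries
  // in bucket_splits delimit one minibatch each.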
bucket_splits.insert(bucket_splits.begin(), 0);
bucket_splits.push_back(kMaxDivisions);
const int32 max_ids_per_chip = max_ids_per_chip_per_sample_ * sample_count_;
const int32 padded_row_pointers_size_per_sc =
xla::RoundUpTo<int32>(num_physical_replica, xla_pad_size);
Tensor* row_pointers_tensor;
OP_REQUIRES_OK(ctx,
ctx->allocate_output(
"row_pointers",
TensorShape({max_minibatches_per_sc_ * num_sc_per_chip_ *
padded_row_pointers_size_per_sc}),
&row_pointers_tensor));
Tensor* sorted_sample_ids_tensor;
OP_REQUIRES_OK(ctx, ctx->allocate_output("sorted_sample_ids",
TensorShape({max_ids_per_chip}),
&sorted_sample_ids_tensor));
Tensor* sorted_token_ids_tensor;
OP_REQUIRES_OK(ctx, ctx->allocate_output("sorted_token_ids",
TensorShape({max_ids_per_chip}),
&sorted_token_ids_tensor));
Tensor* sorted_gains_tensor;
OP_REQUIRES_OK(
ctx, ctx->allocate_output("sorted_gains", TensorShape({max_ids_per_chip}),
&sorted_gains_tensor));
int32* row_pointers_tensor_ptr = row_pointers_tensor->flat<int32>().data();
int32* sorted_sample_ids_tensor_ptr =
sorted_sample_ids_tensor->flat<int32>().data();
int32* sorted_token_ids_tensor_ptr =
sorted_token_ids_tensor->flat<int32>().data();
float* sorted_gains_tensor_ptr = sorted_gains_tensor->flat<float>().data();
int32 global_index = 0;
int32 row_pointers_index = 0;
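  // For each SparseCore and each minibatch, gather the ids that fall into each
  // physical replica's division range, then pad every replica segment up to
  // the XLA alignment size.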
for (int sc_id = 0; sc_id < num_sc_per_chip_; ++sc_id) {
for (int i = 1; i < bucket_splits.size(); ++i) {
for (int replica_id = 0; replica_id < num_physical_replica;
++replica_id) {
const int global_division_id =
sc_id * num_physical_replica + replica_id;
const int start_division_pos =
global_division_id * kMaxDivisions + bucket_splits[i - 1];
const int end_division_pos =
global_division_id * kMaxDivisions + bucket_splits[i];
const int token_id_count = *(id_counts_tensor_ptr + end_division_pos) -
*(id_counts_tensor_ptr + start_division_pos);
const int token_id_start_pos =
*(id_counts_tensor_ptr + start_division_pos);
std::copy_n(col_ids_tensor_ptr + token_id_start_pos, token_id_count,
sorted_token_ids_tensor_ptr + global_index);
std::copy_n(row_ids_tensor_ptr + token_id_start_pos, token_id_count,
sorted_sample_ids_tensor_ptr + global_index);
std::copy_n(gains_tensor_ptr + token_id_start_pos, token_id_count,
sorted_gains_tensor_ptr + global_index);
global_index += token_id_count;
*(row_pointers_tensor_ptr + row_pointers_index) = global_index;
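        // Pad this replica's segment so the next segment starts at an
        // XLA-aligned offset.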
int32 num_ids_to_pad_per_replica =
xla::RoundUpTo<int32>(global_index, xla_pad_size) - global_index;
std::fill_n(sorted_token_ids_tensor_ptr + global_index,
num_ids_to_pad_per_replica, kXlaPadValue);
std::fill_n(sorted_sample_ids_tensor_ptr + global_index,
num_ids_to_pad_per_replica, kXlaPadValue);
std::fill_n(sorted_gains_tensor_ptr + global_index,
num_ids_to_pad_per_replica, kXlaPadValue);
global_index += num_ids_to_pad_per_replica;
++row_pointers_index;
}
// Pad the row_pointers to be memory aligned.
int32 num_row_pointers_to_pad =
xla::RoundUpTo<int32>(row_pointers_index, xla_pad_size) -
row_pointers_index;
std::fill_n(row_pointers_tensor_ptr + row_pointers_index,
num_row_pointers_to_pad, global_index);
row_pointers_index += num_row_pointers_to_pad;
}
}
int32 ids_unpadded_size = global_index;
OP_REQUIRES(ctx, ids_unpadded_size <= max_ids_per_chip,
absl::InvalidArgumentError(absl::StrCat(
"Got ", ids_unpadded_size,
" ids after padding but the max_ids_per_chip is set to be ",
max_ids_per_chip, " which is smaller.")));
int32 row_pointers_unpadded_size =
total_num_minibatch * padded_row_pointers_size_per_sc;
Tensor* num_minibatches_per_physical_sparse_core_tensor;
OP_REQUIRES_OK(
ctx, ctx->allocate_output(
"num_minibatches_per_physical_sparse_core", TensorShape({}),
&num_minibatches_per_physical_sparse_core_tensor));
Tensor* row_pointers_unpadded_size_tensor;
OP_REQUIRES_OK(
ctx, ctx->allocate_output("row_pointers_unpadded_size", TensorShape({}),
&row_pointers_unpadded_size_tensor));
Tensor* ids_unpadded_size_tensor;
OP_REQUIRES_OK(ctx, ctx->allocate_output("ids_unpadded_size", TensorShape({}),
&ids_unpadded_size_tensor));
num_minibatches_per_physical_sparse_core_tensor->flat<int32>()(0) =
num_minibatch_per_sc;
row_pointers_unpadded_size_tensor->flat<int32>()(0) =
row_pointers_unpadded_size;
ids_unpadded_size_tensor->flat<int32>()(0) = ids_unpadded_size;
VLOG(1) << "Compute GetMinibatchesInCsrWithPhysicalReplicaOp done";
}
#ifdef LIBTPU_ON_GCE
REGISTER_KERNEL_BUILDER(
Name("GetMinibatchesInCsrWithPhysicalReplica").Device(DEVICE_CPU),
GetMinibatchesInCsrWithPhysicalReplicaOp)
#endif
} // namespace tensorflow
/* Copyright 2023 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_CORE_TPU_KERNELS_SPARSE_CORE_PREPROCESS_OPS_H_
#define TENSORFLOW_CORE_TPU_KERNELS_SPARSE_CORE_PREPROCESS_OPS_H_
#include <string>
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor.h"
namespace tensorflow {
// Struct describing a single embedding lookup input.
struct EmbeddingLookupInput {
  // Which replica it belongs to.
int32 replica_id;
// Token id.
int32 token_id;
// Sample id.
int32 sample_id;
// Gain.
float gain;
EmbeddingLookupInput(int32 replica_id, int32 token_id, int32 sample_id,
float gain)
: replica_id(replica_id),
token_id(token_id),
sample_id(sample_id),
gain(gain) {}
};
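// OpKernel that converts COO-format embedding lookup inputs (row ids, col ids,
// gains) into minibatched CSR-style buffers, grouped per physical SparseCore
// replica and padded to the XLA alignment required by the TPU topology.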
class GetMinibatchesInCsrWithPhysicalReplicaOp : public OpKernel {
public:
explicit GetMinibatchesInCsrWithPhysicalReplicaOp(OpKernelConstruction* ctx);
~GetMinibatchesInCsrWithPhysicalReplicaOp() override = default;
GetMinibatchesInCsrWithPhysicalReplicaOp(
const GetMinibatchesInCsrWithPhysicalReplicaOp&) = delete;
GetMinibatchesInCsrWithPhysicalReplicaOp& operator=(
const GetMinibatchesInCsrWithPhysicalReplicaOp&) = delete;
void Compute(OpKernelContext* ctx) override;
protected:
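  // Queries the max_ids / max_unique_ids limits for `table_name` from the
  // compiled program identified by `program_key`. The base implementation is a
  // no-op that leaves the output parameters unchanged; subclasses may override.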
virtual void GetMaxIdsAndUniques(OpKernelContext* ctx,
const std::string& program_key,
const std::string& table_name,
int64_t num_samples_per_sparse_core,
int64_t feature_width,
int64_t* max_ids_per_partition,
int64_t* max_unique_ids_per_partition) {}
int sample_count_ = 1;
int feature_width_ = 1;
int64_t num_sc_per_chip_;
std::string table_name_;
private:
int num_replica_ = 1;
int max_minibatches_per_sc_ = 1;
int max_ids_per_chip_per_sample_ = 1;
int table_vocab_size_ = 1;
std::string device_name_;
};
} // namespace tensorflow
#endif // TENSORFLOW_CORE_TPU_KERNELS_SPARSE_CORE_PREPROCESS_OPS_H_
@@ -84,6 +84,7 @@ tsl::Status SetTpuOpsStructFns(void* library_handle) { // TENSORFLOW_STATUS_OK
TFTPU_SET_FN(ops_api_fn, TpuCompile_DestroyCompilationCacheKey);
TFTPU_SET_FN(ops_api_fn, TpuCompile_CreateGuaranteedConstFingerprint);
TFTPU_SET_FN(ops_api_fn, TpuUtil_GetTopologyPtr);
TFTPU_SET_FN(ops_api_fn, TpuUtil_GetXlaPadSizeFromTpuTopology);
TFTPU_SET_FN(ops_api_fn, TfTpu_InitializeTpuModelServer);
......
@@ -483,6 +483,9 @@ TFTPU_CAPI_EXPORT uint64_t TpuCompile_CreateGuaranteedConstFingerprint(
// Returns a pointer to the TPU topology struct.
TFTPU_CAPI_EXPORT SE_TpuTopology* TpuUtil_GetTopologyPtr();
// Returns XLA pad size from TPU topology.
TFTPU_CAPI_EXPORT size_t TpuUtil_GetXlaPadSizeFromTpuTopology();
XLA_TpuNodeContext* TpuNodeContext_Create(int device_ordinal,
TF_Status* status);
void TpuNodeContext_Free(XLA_TpuNodeContext* node_context);
@@ -786,6 +789,7 @@ struct TfTpu_OpsApiFn {
TFTPU_ADD_FN_IN_STRUCT(TpuCompile_DestroyCompilationCacheKey);
TFTPU_ADD_FN_IN_STRUCT(TpuCompile_CreateGuaranteedConstFingerprint);
TFTPU_ADD_FN_IN_STRUCT(TpuUtil_GetTopologyPtr);
TFTPU_ADD_FN_IN_STRUCT(TpuUtil_GetXlaPadSizeFromTpuTopology);
TFTPU_ADD_FN_IN_STRUCT(TpuNodeContext_Create);
TFTPU_ADD_FN_IN_STRUCT(TpuNodeContext_Free);
......