提交 9331ee54 编写于 作者: H Hye Soo Yang 提交者: TensorFlower Gardener

Adding sparse_core_ops_stats_handler for recording metrics

PiperOrigin-RevId: 565158613
上级 b090b29c
......@@ -139,6 +139,7 @@ cc_library(
srcs = ["sparse_core_preprocess_ops.cc"],
hdrs = ["sparse_core_preprocess_ops.h"],
deps = [
":sparse_core_ops_stats_handler",
":sparse_core_ops_utils",
"//tensorflow/core:framework",
"//tensorflow/core:lib_proto_parsing",
......@@ -156,6 +157,11 @@ cc_library(
],
)
# Header-only target (no srcs) exposing sparse_core_ops_stats_handler.h, which
# declares the SparseCoreOpsStatsHandler class used for recording metrics.
cc_library(
name = "sparse_core_ops_stats_handler",
hdrs = ["sparse_core_ops_stats_handler.h"],
)
tf_kernel_library(
name = "tpu_configuration_ops",
srcs = ["tpu_configuration_ops.cc"],
......
/* Copyright 2023 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_CORE_TPU_KERNELS_SPARSE_CORE_OPS_STATS_HANDLER_H_
#define TENSORFLOW_CORE_TPU_KERNELS_SPARSE_CORE_OPS_STATS_HANDLER_H_
#include <cstdint>
#include <string>
// Identifies which statistic a value passed to
// SparseCoreOpsStatsHandler::Record() belongs to.
enum class StatsType {
  // Number of minibatches per SparseCore (SC).
  NUM_MINIBATCHES_PER_SC = 0,
};
// Hook through which SparseCore ops report statistics.
//
// This base class discards every value handed to it. Record() is virtual so
// that a derived handler can override it to actually collect the metrics.
class SparseCoreOpsStatsHandler {
 public:
  // Virtual so that a derived handler is destroyed correctly when it is
  // deleted through a pointer to this base class.
  virtual ~SparseCoreOpsStatsHandler() = default;

  // Reports `value` for the statistic identified by `type`, labelled with
  // `device_name` and `table_name`. The base implementation intentionally
  // does nothing.
  virtual void Record(StatsType type, int64_t value, std::string device_name,
                      std::string table_name) {}
};
#endif // TENSORFLOW_CORE_TPU_KERNELS_SPARSE_CORE_OPS_STATS_HANDLER_H_
......@@ -18,6 +18,7 @@ limitations under the License.
#include <cstddef>
#include <cstdint>
#include <cstring>
#include <memory>
#include <string>
#include <vector>
......@@ -33,6 +34,7 @@ limitations under the License.
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/tstring.h"
#include "tensorflow/core/tpu/kernels/sparse_core_ops_stats_handler.h"
#include "tensorflow/core/tpu/kernels/sparse_core_ops_utils.h"
namespace tensorflow {
......@@ -58,6 +60,10 @@ GetMinibatchesInCsrWithPhysicalReplicaOp::
"sample_count ", sample_count_,
" is not divisible by the number of sparsecores per chip ",
num_sc_per_chip_)));
// Create default instance of stats handler. May get overwritten by subclass.
sprase_core_ops_stats_handler_ =
std::make_unique<SparseCoreOpsStatsHandler>();
}
void GetMinibatchesInCsrWithPhysicalReplicaOp::Compute(OpKernelContext* ctx) {
......@@ -119,6 +125,9 @@ void GetMinibatchesInCsrWithPhysicalReplicaOp::Compute(OpKernelContext* ctx) {
ConvertBinarySplitsToBucketSplits(binary_splits, max_division_level);
const int32 num_minibatch_per_sc = bucket_splits.size() + 1;
sprase_core_ops_stats_handler_->Record(StatsType::NUM_MINIBATCHES_PER_SC,
num_minibatch_per_sc, device_name_,
table_name_);
OP_REQUIRES(
ctx, num_minibatch_per_sc <= max_minibatches_per_sc_,
......
......@@ -19,6 +19,7 @@ limitations under the License.
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/tpu/kernels/sparse_core_ops_stats_handler.h"
namespace tensorflow {
......@@ -65,6 +66,7 @@ class GetMinibatchesInCsrWithPhysicalReplicaOp : public OpKernel {
int feature_width_ = 1;
int64_t num_sc_per_chip_;
std::string table_name_;
// Owned handler that this op reports metrics to via Record().
// NOTE(review): "sprase" looks like a typo for "sparse"; a rename has to
// update every use of this member at the same time.
std::unique_ptr<SparseCoreOpsStatsHandler> sprase_core_ops_stats_handler_;
private:
int num_replica_ = 1;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册