提交 60ac36f5 编写于 作者: H Hongmin Fan 提交者: TensorFlower Gardener

Fix a batch task creation bug in the TFRT batch fallback kernel.

Without the fix, the TFRT batch fallback kernel can crash under high QPS load.
When splitting a task of the derived class FallbackBatchTask (used only in
TFRT), the buggy code created an object of the base class BatchTask and put it
into a batch with other FallbackBatchTask objects. When such a batch of mixed
task types is processed, it crashes.

PiperOrigin-RevId: 339927837
Change-Id: Ie52bd11c61c9ddbe6ab803cd90208419d4b2dba6
上级 a8772528
......@@ -83,6 +83,26 @@ const string& GetModelName(OpKernelContext* ctx) {
} // namespace
// Creates a partial ("split") task derived from this task.
//
// The new object is allocated through the virtual CreateDerivedTask() hook, so
// a subclass that overrides the hook gets split tasks of its own type rather
// than plain BatchTask objects.
//
// Args:
//   split_index:   stored in the new task's `split_index` field.
//   done_callback: stored in the new task's `done_callback` field.
//
// Returns a task that shares this task's guid, captured inputs, context,
// output, status and start time, and that is marked `is_partial`. The new
// task's `inputs` vector only has capacity reserved; no input tensors are
// copied, so the caller must fill in the inputs itself.
std::unique_ptr<BatchResourceBase::BatchTask>
BatchResourceBase::BatchTask::CreateSplitTask(
    int split_index, AsyncOpKernel::DoneCallback done_callback) {
  std::unique_ptr<BatchTask> task = CreateDerivedTask();

  // State carried over unchanged from the task being split.
  task->guid = this->guid;
  task->captured_inputs = this->captured_inputs;
  task->context = this->context;
  task->output = this->output;
  task->status = this->status;
  task->start_time = this->start_time;

  // State specific to the split task. The propagated context is freshly
  // constructed here rather than copied from this task.
  task->propagated_context = Context(ContextKind::kThread);
  task->inputs.reserve(this->inputs.size());
  task->split_index = split_index;
  task->is_partial = true;
  // `done_callback` is received by value, so hand it over instead of making a
  // second copy of the callback.
  task->done_callback = std::move(done_callback);

  return task;
}
using ::tensorflow::concat_split_util::Concat;
using ::tensorflow::concat_split_util::Split;
using TensorMatrix = std::vector<std::vector<Tensor>>;
......@@ -317,20 +337,7 @@ Status BatchResourceBase::ConcatInputTensors(
output_tasks->reserve(output_task_num);
for (int i = 0; i < output_task_num; i++) {
auto task = absl::make_unique<BatchTask>();
task->guid = input_task.guid;
task->propagated_context = Context(ContextKind::kThread);
task->captured_inputs = input_task.captured_inputs;
task->context = input_task.context;
task->done_callback = barrier.Inc();
task->start_time = input_task.start_time;
task->split_index = i;
task->inputs.reserve(input_task.inputs.size());
task->is_partial = true;
task->status = input_task.status;
task->output = input_task.output;
output_tasks->push_back(std::move(task));
output_tasks->push_back(input_task.CreateSplitTask(i, barrier.Inc()));
}
const int num_input_tensors = input_task.inputs.size();
......
......@@ -87,9 +87,19 @@ class BatchResourceBase : public ResourceBase {
bool is_partial = false;
uint64 start_time;
// Size of this task: the extent of dimension 0 of the first input tensor
// (presumably the batch dimension -- confirm against the batching code).
// Indexes inputs[0] directly, so `inputs` must not be empty when called.
size_t size() const override {
  const auto& first_input = inputs[0];
  return first_input.shape().dim_size(0);
}
uint64 start_time;
// Creates a split task from this one.
//
// The returned task is allocated via CreateDerivedTask() and is marked
// `is_partial`. It takes its guid, captured inputs, context, output, status
// and start time from this task, and stores `split_index` and `done_callback`
// as given.
//
// Only capacity is reserved for the new task's `inputs`; the caller needs to
// set up the inputs of the new task.
std::unique_ptr<BatchTask> CreateSplitTask(
int split_index, AsyncOpKernel::DoneCallback done_callback);
protected:
// Factory hook called by CreateSplitTask() to allocate the object for a new
// split task. This base implementation returns a default-constructed
// BatchTask. Being virtual, it can be overridden by a subclass to return an
// instance of its own task type.
virtual std::unique_ptr<BatchTask> CreateDerivedTask() {
  auto base_task = std::make_unique<BatchTask>();
  return base_task;
}
};
// Appending a T suffix to make the type alias different to those in
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册