提交 0cb2a73b 编写于 作者: Y Yunxing Dai 提交者: TensorFlower Gardener

[XLA] [DynamicPadder] Support sort op.

PiperOrigin-RevId: 262951106
上级 40897309
......@@ -2173,13 +2173,14 @@ cc_library(
hdrs = ["dynamic_dimension_inference.h"],
deps = [
":hlo",
":hlo_casting_utils",
":while_util",
"//tensorflow/compiler/xla:literal_util",
"//tensorflow/compiler/xla:status",
"//tensorflow/compiler/xla:statusor",
"//tensorflow/compiler/xla:types",
"//tensorflow/compiler/xla:window_util",
"//tensorflow/core:lib",
"//tensorflow/core/platform:macros",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/types:span",
],
......
......@@ -17,8 +17,10 @@ limitations under the License.
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h"
#include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_instructions.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/while_util.h"
#include "tensorflow/compiler/xla/window_util.h"
......@@ -53,6 +55,8 @@ class DynamicDimensionInferenceVisitor : public DfsHloVisitorWithDefault {
Status HandleReshape(HloInstruction* hlo) override;
Status HandleSort(HloInstruction* hlo) override;
Status HandlePad(HloInstruction* hlo) override;
Status HandleBroadcast(HloInstruction* hlo) override;
......@@ -161,6 +165,22 @@ Status DynamicDimensionInferenceVisitor::HandleBroadcast(HloInstruction* hlo) {
});
}
Status DynamicDimensionInferenceVisitor::HandleSort(HloInstruction* hlo) {
return ForEachOperandDynamicDimension(
hlo, [&](HloInstruction* operand, ShapeIndex index,
int64 dynamic_dimension, int64 operand_index,
HloInstruction* dynamic_size, DimensionConstraint constraint) {
int64 sort_dimension = Cast<HloSortInstruction>(hlo)->sort_dimension();
if (sort_dimension == dynamic_dimension) {
return Unimplemented(
"Dynamic dimension on sorting dimension is not supported");
}
parent_->SetDynamicSize(hlo, {}, dynamic_dimension, dynamic_size,
constraint);
return Status::OK();
});
}
Status DynamicDimensionInferenceVisitor::HandlePad(HloInstruction* hlo) {
return ForEachOperandDynamicDimension(
hlo, [&](HloInstruction* operand, ShapeIndex index, int64 dimension,
......
......@@ -912,6 +912,38 @@ TEST_F(DynamicDimensionInferenceTest, DynamicSliceTest) {
EXPECT_EQ(inference_->GetDynamicSize(slice, {}, 0), size_param);
}
TEST_F(DynamicDimensionInferenceTest, SortTest) {
auto builder = HloComputation::Builder(TestName());
auto data_param = builder.AddInstruction(HloInstruction::CreateParameter(
0, ShapeUtil::MakeShape(F32, {5, 7}), "data_param"));
auto size_param = builder.AddInstruction(
HloInstruction::CreateParameter(1, scalar_shape_, "size_param"));
auto compare_builder = HloComputation::Builder("condition");
compare_builder.AddInstruction(HloInstruction::CreateParameter(
0, ShapeUtil::MakeShape(F32, {}), "param1"));
compare_builder.AddInstruction(HloInstruction::CreateParameter(
1, ShapeUtil::MakeShape(F32, {}), "param2"));
compare_builder.AddInstruction(
HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(false)));
HloComputation* compare =
module_->AddEmbeddedComputation(compare_builder.Build());
auto* sort = builder.AddInstruction(HloInstruction::CreateSort(
ShapeUtil::MakeShape(F32, {5, 7}), 1, {data_param}, compare,
/*is_stable=*/false));
module_->AddEntryComputation(builder.Build());
// Set up dynamic parameter binding.
TF_CHECK_OK(module_->dynamic_parameter_binding().Bind(
DynamicParameterBinding::DynamicParameter{1, {}},
DynamicParameterBinding::DynamicDimension{0, {}, 0}));
TF_ASSERT_OK(RunInference());
EXPECT_EQ(inference_->GetDynamicSize(sort, {}, 0), size_param);
}
TEST_F(DynamicDimensionInferenceTest, DynamicSliceSingleElementTest) {
// Slicing out a single element from a dynamic dimension terminates the
// dynamic dimension.
......
......@@ -90,6 +90,7 @@ StatusOr<HloInstruction*> ChooseIdentityValue(HloInstruction* inst,
case HloOpcode::kAllReduce:
case HloOpcode::kBroadcast:
case HloOpcode::kTranspose:
case HloOpcode::kSort:
case HloOpcode::kSlice:
return nullptr;
default:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册