提交 4c620cc8 编写于 作者: A A. Unique TensorFlower 提交者: TensorFlower Gardener

[XLA] Add LiteralUtil::Replicate.

Change: 144352405
上级 a2f9e996
......@@ -136,6 +136,12 @@ class LiteralUtil {
const Literal& literal, tensorflow::gtl::ArraySlice<int64> start_indices,
tensorflow::gtl::ArraySlice<int64> limit_indices);
// Creates a literal with a prepended dimension with bound "times"; e.g. a
// f32[3x2] with times=4 will produce a f32[4x3x2] with the 3x2 from the input
// literal replicated four times.
template <typename NativeT>
static std::unique_ptr<Literal> Replicate(const Literal& input, int64 times);
// Create a literal by converting each element in an original literal to a new
// type.
template <typename NativeSrcT, typename NativeDestT>
......@@ -999,6 +1005,30 @@ LiteralUtil::CreateFullWithMonotonicDim0MajorLayout(
return literal;
}
template <typename NativeT>
/* static */ std::unique_ptr<Literal> LiteralUtil::Replicate(
const Literal& input, int64 times) {
std::vector<int64> bounds = {times};
bounds.insert(bounds.end(), input.shape().dimensions().begin(),
input.shape().dimensions().end());
auto literal = MakeUnique<Literal>();
*literal->mutable_shape() =
ShapeUtil::MakeShape(input.shape().element_type(), bounds);
Reserve(ShapeUtil::ElementsIn(literal->shape()), literal.get());
for (int64 index = 0; index < ShapeUtil::ElementsIn(input.shape()); ++index) {
const std::vector<int64> element_indices =
IndexUtil::LinearIndexToMultidimensionalIndex(input.shape(), index);
const auto element = Get<NativeT>(input, element_indices);
for (int64 sample = 0; sample < times; ++sample) {
std::vector<int64> output_indices = {sample};
output_indices.insert(output_indices.end(), element_indices.begin(),
element_indices.end());
Set<NativeT>(literal.get(), output_indices, element);
}
}
return literal;
}
} // namespace xla
#endif // TENSORFLOW_COMPILER_XLA_LITERAL_UTIL_H_
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册