Commit 112e2b3f authored by T.J. Alumbaugh, committed by TensorFlower Gardener

Add `cacheable` flag to Ruy Matrix so that the caller "opts in" to cache behavior on a per-call basis

PiperOrigin-RevId: 286241942
Change-Id: Ie1320c17f6a50468a03dad2664a1c8645e09f3ce
Parent ddb75d9d
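For context, here is a minimal caller-side sketch of the new opt-in flag, closely mirroring the `TestCacheOnCacheable` test added at the bottom of this commit. The include path and the wrapping function name are assumptions; the `Context`, `Matrix`, `MakeSimpleLayout`, `BasicSpec`, and `Mul` usages all appear verbatim in the diff below.

```cpp
// Sketch only: mirrors PrepackedCacheTest.TestCacheOnCacheable below.
// The include path is an assumption for the Ruy checkout of this era.
#include "tensorflow/lite/experimental/ruy/ruy.h"

void GemvWithCachedLhs() {  // hypothetical helper name
  ruy::Context context;
  context.cache_policy = ruy::kCacheLHSOnGemV;

  const float lhs_data[] = {1, 2, 3, 4};
  const float rhs_data[] = {1, 2};
  float dst_data[2];

  ruy::Matrix<float> lhs;
  ruy::MakeSimpleLayout(2, 2, ruy::Order::kRowMajor, &lhs.layout);
  lhs.data = lhs_data;
  lhs.cacheable = true;  // opt in: the packed LHS may now be cached

  ruy::Matrix<float> rhs;
  ruy::MakeSimpleLayout(2, 1, ruy::Order::kColMajor, &rhs.layout);
  rhs.data = rhs_data;

  ruy::Matrix<float> dst;
  ruy::MakeSimpleLayout(2, 1, ruy::Order::kColMajor, &dst.layout);
  dst.data = dst_data;

  ruy::BasicSpec<float, float> spec;
  // Without lhs.cacheable = true, this call behaves identically but
  // HandlePrepackedCaching skips the prepacked-LHS cache.
  ruy::Mul<ruy::kAllPaths>(lhs, rhs, spec, &context, &dst);
}
```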
@@ -382,13 +382,15 @@ struct CompileTimeEnabledReferenceMul</*ReferenceMulIsEnabled=*/false> {
}
};
inline void HandlePrepackedCaching(TrMulParams* params, Context* context) {
inline void HandlePrepackedCaching(TrMulParams* params,
const SidePair<bool>& cacheable,
Context* context) {
if (context->cache_policy == CachePolicy::kNoCache) {
return;
}
if (context->cache_policy == CachePolicy::kCacheLHSOnGemV) {
if (params->dst.layout.cols != 1) {
if (!cacheable[Side::kLhs] || params->dst.layout.cols != 1) {
return;
}
PrepackedCache* prepacked_cache = context->GetPrepackedCache();
@@ -465,7 +467,8 @@ void DispatchMul(const Matrix<LhsScalar>& lhs, const Matrix<RhsScalar>& rhs,
TrMulParams params;
CreateTrMulParams<TrMulCompiledPaths>(transposed_lhs, rhs, spec, context, dst,
the_path, &params);
HandlePrepackedCaching(&params, context);
SidePair<bool> cacheable(lhs.cacheable, rhs.cacheable);
HandlePrepackedCaching(&params, cacheable, context);
TrMul(&params, context);
}
@@ -108,6 +108,7 @@ template <typename Scalar>
struct Matrix final {
Matrix& operator=(const Matrix& other) {
data = other.data;
cacheable = other.cacheable;
layout = other.layout;
zero_point = other.zero_point;
return *this;
@@ -120,6 +121,10 @@ struct Matrix final {
// The zero_point, i.e. which Scalar value is to be interpreted as zero.
// When Scalar is floating-point, this must be 0.
Scalar zero_point = 0;
// Clients of Ruy must set this flag to enable any caching behavior. Doesn't
// impact numerical results, but caching can impact observable metrics like
// latency, memory usage, power, etc.
bool cacheable = false;
};
inline void MakeSimpleLayout(int rows, int cols, Order order, Layout* layout) {
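A tiny hypothetical snippet to illustrate the semantics added in the hunk above: caching is strictly opt-in (the flag defaults to false), and the updated copy assignment now carries the flag along. The function name and include path are assumptions.

```cpp
#include "tensorflow/lite/experimental/ruy/ruy.h"  // include path assumed

void CopyPreservesCacheable() {  // hypothetical helper name
  ruy::Matrix<float> a;    // a.cacheable starts out false: caching is opt-in
  a.cacheable = true;      // the caller opts in for this matrix only
  ruy::Matrix<float> b;
  b = a;                   // operator= (updated above) now copies `cacheable`
  // b.cacheable is true as well; numerical results are unaffected either way.
}
```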
@@ -25,7 +25,6 @@ namespace ruy {
namespace {
TEST(PrepackedCacheTest, TestCacheEjection) {
ruy::Context* context = new ruy::Context();
// Create the cache.
PrepackedCache prepacked_cache(32);
// Allocate the prepacked matrix.
@@ -54,11 +53,9 @@ TEST(PrepackedCacheTest, TestCacheEjection) {
// The cache size was exceeded by inserting mat2. Ensure that mat1 was
// ejected.
EXPECT_EQ(prepacked_cache.FindAndUpdate(cache_key1), prepacked_cache.cend());
delete context;
}
TEST(PrepackedCacheTest, TestCacheBasic) {
ruy::Context* context = new ruy::Context();
// Create the cache.
PrepackedCache prepacked_cache(48);
// Allocate the prepacked matrix.
@@ -83,11 +80,9 @@ TEST(PrepackedCacheTest, TestCacheBasic) {
// The cache size was not exceeded by inserting mat2. Ensure that mat1 was not
// ejected.
EXPECT_NE(prepacked_cache.FindAndUpdate(cache_key1), prepacked_cache.cend());
delete context;
}
TEST(PrepackedCacheTest, TestCacheEjection2) {
ruy::Context* context = new ruy::Context();
// Create the cache.
PrepackedCache prepacked_cache(73);
// Allocate the prepacked matrix 1.
@@ -137,7 +132,39 @@ TEST(PrepackedCacheTest, TestCacheEjection2) {
EXPECT_NE(prepacked_cache.FindAndUpdate(cache_key3), prepacked_cache.cend());
EXPECT_NE(prepacked_cache.FindAndUpdate(cache_key1), prepacked_cache.cend());
EXPECT_NE(prepacked_cache.FindAndUpdate(cache_key4), prepacked_cache.cend());
delete context;
}
TEST(PrepackedCacheTest, TestCacheOnCacheable) {
// Create a context and set the cache policy.
ruy::Context context;
context.cache_policy = ruy::kCacheLHSOnGemV;
PrepackedCache* cache = context.GetPrepackedCache();
EXPECT_EQ(cache->TotalSize(), 0);
const float lhs_data[] = {1, 2, 3, 4};
const float rhs_data[] = {1, 2};
float dst_data[4];
ruy::Matrix<float> lhs;
ruy::MakeSimpleLayout(2, 2, ruy::Order::kRowMajor, &lhs.layout);
lhs.data = lhs_data;
ruy::Matrix<float> rhs;
ruy::MakeSimpleLayout(2, 1, ruy::Order::kColMajor, &rhs.layout);
rhs.data = rhs_data;
ruy::Matrix<float> dst;
ruy::MakeSimpleLayout(2, 1, ruy::Order::kColMajor, &dst.layout);
dst.data = dst_data;
ruy::BasicSpec<float, float> spec;
// Perform the multiplication and confirm that no caching occurred.
ruy::Mul<ruy::kAllPaths>(lhs, rhs, spec, &context, &dst);
EXPECT_EQ(cache->TotalSize(), 0);
// Set cacheable for the LHS, repeat the multiplication, and see
// that caching did occur.
lhs.cacheable = true;
ruy::Mul<ruy::kAllPaths>(lhs, rhs, spec, &context, &dst);
EXPECT_NE(cache->TotalSize(), 0);
}
} // namespace