From a49f4a66b7b0dda654a60b962249e1ba41b101dd Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Tue, 22 Dec 2020 13:33:24 +0800 Subject: [PATCH] feat(dnn): add indexing_one_hot and indexing_set_one_hot opr GitOrigin-RevId: c5406c71ffa91864c1fb0828278e51ef3af45c97 --- dnn/{test => src}/common/opr_trait.h | 8 +++++--- dnn/test/common/opr_algo_proxy.h | 2 +- dnn/test/common/opr_proxy.h | 3 ++- dnn/test/common/powc.h | 5 ++--- 4 files changed, 10 insertions(+), 8 deletions(-) rename dnn/{test => src}/common/opr_trait.h (97%) diff --git a/dnn/test/common/opr_trait.h b/dnn/src/common/opr_trait.h similarity index 97% rename from dnn/test/common/opr_trait.h rename to dnn/src/common/opr_trait.h index 08929d3d..4088c051 100644 --- a/dnn/test/common/opr_trait.h +++ b/dnn/src/common/opr_trait.h @@ -1,5 +1,5 @@ /** - * \file dnn/test/common/opr_trait.h + * \file dnn/src/common/opr_trait.h * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") * * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. @@ -14,7 +14,6 @@ #include namespace megdnn { -namespace test { template struct OprTrait {}; @@ -114,7 +113,10 @@ DEF(FakeQuantForward, 4, true, true); DEF(FakeQuantBackward, 5, true, false); DEF(TQTForward, 3, true, true); DEF(TQTBackward, 5, true, false); -} // namespace test +DEF(PowC, 2, false, true); +DEF(UniformRNG, 1, true, true); +DEF(GaussianRNG, 1, true, true); + } // namespace megdnn // vim: syntax=cpp.doxygen diff --git a/dnn/test/common/opr_algo_proxy.h b/dnn/test/common/opr_algo_proxy.h index 5ad9fbcd..79361f07 100644 --- a/dnn/test/common/opr_algo_proxy.h +++ b/dnn/test/common/opr_algo_proxy.h @@ -12,7 +12,7 @@ #pragma once #include "megdnn/basic_types.h" -#include "test/common/opr_trait.h" +#include "src/common/opr_trait.h" #include "test/common/utils.h" namespace megdnn { diff --git a/dnn/test/common/opr_proxy.h b/dnn/test/common/opr_proxy.h index 8ca2f68d..0f5a89a0 100644 --- a/dnn/test/common/opr_proxy.h +++ b/dnn/test/common/opr_proxy.h @@ -11,12 +11,13 @@ */ #pragma once +#include "src/common/opr_trait.h" + #include "test/common/deduce_layout_proxy.h" #include "test/common/exec_proxy.h" #include "test/common/fast_run_cache.h" #include "test/common/inspect_type.h" #include "test/common/opr_algo_proxy.h" -#include "test/common/opr_trait.h" #include "test/common/timer.h" #include "test/common/workspace_wrapper.h" diff --git a/dnn/test/common/powc.h b/dnn/test/common/powc.h index 241697f5..44cf091b 100644 --- a/dnn/test/common/powc.h +++ b/dnn/test/common/powc.h @@ -12,13 +12,12 @@ #include "megdnn/handle.h" #include "megdnn/oprs/general.h" -#include "test/common/opr_proxy.h" + +#include "src/common/opr_trait.h" namespace megdnn { namespace test { -DEF(PowC, 2, false, true); - void run_powc_test(Handle* handle, DType dtype); } // namespace test -- GitLab