提交 d0ce6a90 编写于 作者: Y Yan Chunwei 提交者: root

fix anakin converter registry (#15993)

上级 a5124ee0
...@@ -64,7 +64,9 @@ option(WITH_DISTRIBUTE "Compile with distributed support" OFF) ...@@ -64,7 +64,9 @@ option(WITH_DISTRIBUTE "Compile with distributed support" OFF)
option(WITH_PSLIB "Compile with pslib support" OFF) option(WITH_PSLIB "Compile with pslib support" OFF)
option(WITH_CONTRIB "Compile the third-party contributation" OFF) option(WITH_CONTRIB "Compile the third-party contributation" OFF)
option(REPLACE_ENFORCE_GLOG "Replace PADDLE_ENFORCE with glog/CHECK for better debug." OFF) option(REPLACE_ENFORCE_GLOG "Replace PADDLE_ENFORCE with glog/CHECK for better debug." OFF)
# TODO(Superjomn) Remove WITH_ANAKIN option if not needed later.
option(WITH_ANAKIN "Compile with Anakin library" OFF) option(WITH_ANAKIN "Compile with Anakin library" OFF)
option(WITH_ANAKIN_SUBGRAPH "Compile with Anakin subgraph library" OFF)
option(ANAKIN_BUILD_FAT_BIN "Build anakin cuda fat-bin lib for all device plantform, ignored when WITH_ANAKIN=OFF" OFF) option(ANAKIN_BUILD_FAT_BIN "Build anakin cuda fat-bin lib for all device plantform, ignored when WITH_ANAKIN=OFF" OFF)
option(ANAKIN_BUILD_CROSS_PLANTFORM "Build anakin lib for any nvidia device plantform. ignored when WITH_ANAKIN=OFF" ON) option(ANAKIN_BUILD_CROSS_PLANTFORM "Build anakin lib for any nvidia device plantform. ignored when WITH_ANAKIN=OFF" ON)
option(WITH_GRPC "Use grpc as the default rpc framework" ${WITH_DISTRIBUTE}) option(WITH_GRPC "Use grpc as the default rpc framework" ${WITH_DISTRIBUTE})
......
...@@ -16,7 +16,10 @@ add_subdirectory(utils) ...@@ -16,7 +16,10 @@ add_subdirectory(utils)
if (TENSORRT_FOUND) if (TENSORRT_FOUND)
add_subdirectory(tensorrt) add_subdirectory(tensorrt)
endif() endif()
# add_subdirectory(anakin)
if (WITH_ANAKIN_SUBGRAPH)
add_subdirectory(anakin)
endif()
get_property(fluid_modules GLOBAL PROPERTY FLUID_MODULES) get_property(fluid_modules GLOBAL PROPERTY FLUID_MODULES)
get_property(cuda_modules GLOBAL PROPERTY CUDA_MODULES) get_property(cuda_modules GLOBAL PROPERTY CUDA_MODULES)
......
...@@ -71,3 +71,5 @@ void FcOpConverter::operator()(const framework::proto::OpDesc &op, ...@@ -71,3 +71,5 @@ void FcOpConverter::operator()(const framework::proto::OpDesc &op,
} // namespace anakin } // namespace anakin
} // namespace inference } // namespace inference
} // namespace paddle } // namespace paddle
REGISTER_ANAKIN_OP_CONVERTER(fc, FcOpConverter);
...@@ -28,11 +28,8 @@ class FcOpConverter : public AnakinOpConverter { ...@@ -28,11 +28,8 @@ class FcOpConverter : public AnakinOpConverter {
const framework::Scope &scope, const framework::Scope &scope,
bool test_mode) override; bool test_mode) override;
virtual ~FcOpConverter() {} virtual ~FcOpConverter() {}
private:
}; };
static Registrar<FcOpConverter> register_fc_op_converter("fc");
} // namespace anakin } // namespace anakin
} // namespace inference } // namespace inference
} // namespace paddle } // namespace paddle
...@@ -46,19 +46,18 @@ class AnakinOpConverter { ...@@ -46,19 +46,18 @@ class AnakinOpConverter {
bool test_mode = false) { bool test_mode = false) {
framework::OpDesc op_desc(op, nullptr); framework::OpDesc op_desc(op, nullptr);
std::string op_type = op_desc.Type(); std::string op_type = op_desc.Type();
std::shared_ptr<AnakinOpConverter> it{nullptr}; AnakinOpConverter *it = nullptr;
if (op_type == "mul") { if (op_type == "mul") {
PADDLE_ENFORCE_EQ(op_desc.Input("Y").size(), 1UL); PADDLE_ENFORCE_EQ(op_desc.Input("Y").size(), 1UL);
std::string Y = op_desc.Input("Y")[0]; std::string Y = op_desc.Input("Y")[0];
std::cout << Y << parameters.count(Y) << std::endl;
if (parameters.count(Y)) { if (parameters.count(Y)) {
it = OpRegister::instance()->Get("fc"); it = Registry<AnakinOpConverter>::Global().Lookup("fc");
} }
} }
if (!it) { if (!it) {
it = OpRegister::instance()->Get(op_type); it = Registry<AnakinOpConverter>::Global().Lookup(op_type);
} }
PADDLE_ENFORCE_NOT_NULL(it, "no OpConverter for optype [%s]", op_type); PADDLE_ENFORCE_NOT_NULL(it, "no OpConverter for optype [%s]", op_type);
it->SetEngine(engine); it->SetEngine(engine);
...@@ -95,9 +94,10 @@ class AnakinOpConverter { ...@@ -95,9 +94,10 @@ class AnakinOpConverter {
struct anakin_##op_type__##_converter \ struct anakin_##op_type__##_converter \
: public ::paddle::framework::Registrar { \ : public ::paddle::framework::Registrar { \
anakin_##op_type__##_converter() { \ anakin_##op_type__##_converter() { \
::paddle::inference:: \ LOG(INFO) << "register convert " << #op_type__; \
Registry<paddle::inference::anakin::AnakinOpConverter>::Register< \ ::paddle::inference::Registry< \
::paddle::inference::anakin::Converter__>(#op_type__); \ ::paddle::inference::anakin::AnakinOpConverter>::Global() \
.Register<::paddle::inference::anakin::Converter__>(#op_type__); \
} \ } \
}; \ }; \
anakin_##op_type__##_converter anakin_##op_type__##_converter__; \ anakin_##op_type__##_converter anakin_##op_type__##_converter__; \
...@@ -108,5 +108,5 @@ class AnakinOpConverter { ...@@ -108,5 +108,5 @@ class AnakinOpConverter {
#define USE_ANAKIN_CONVERTER(op_type__) \ #define USE_ANAKIN_CONVERTER(op_type__) \
extern int TouchConverterRegister_anakin_##op_type__(); \ extern int TouchConverterRegister_anakin_##op_type__(); \
static int use_op_converter_anakin_##op_type__ __attribute__((unused)) = \ int use_op_converter_anakin_##op_type__ __attribute__((unused)) = \
TouchConverterRegister_anakin_##op_type__(); TouchConverterRegister_anakin_##op_type__();
...@@ -13,7 +13,6 @@ See the License for the specific language governing permissions and ...@@ -13,7 +13,6 @@ See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include <gtest/gtest.h> #include <gtest/gtest.h>
#include "paddle/fluid/inference/anakin/convert/fc.h"
#include "paddle/fluid/inference/anakin/convert/op_converter.h" #include "paddle/fluid/inference/anakin/convert/op_converter.h"
#include "paddle/fluid/inference/anakin/convert/ut_helper.h" #include "paddle/fluid/inference/anakin/convert/ut_helper.h"
...@@ -22,10 +21,8 @@ namespace inference { ...@@ -22,10 +21,8 @@ namespace inference {
namespace anakin { namespace anakin {
TEST(fc_op, test) { TEST(fc_op, test) {
auto fc_converter = OpRegister::instance()->Get("fc"); auto* fc_converter = Registry<AnakinOpConverter>::Global().Lookup("fc");
ASSERT_TRUE(fc_converter != nullptr); ASSERT_TRUE(fc_converter);
// Registrar<FcOpConverter> register_fc("fc");
// auto fc = std::make_shared<FcOpConverter>();
std::unordered_set<std::string> parameters({"mul_y"}); std::unordered_set<std::string> parameters({"mul_y"});
framework::Scope scope; framework::Scope scope;
...@@ -52,3 +49,4 @@ TEST(fc_op, test) { ...@@ -52,3 +49,4 @@ TEST(fc_op, test) {
} // namespace paddle } // namespace paddle
USE_OP(mul); USE_OP(mul);
USE_ANAKIN_CONVERTER(fc);
...@@ -24,6 +24,7 @@ limitations under the License. */ ...@@ -24,6 +24,7 @@ limitations under the License. */
#include "paddle/fluid/framework/lod_tensor.h" #include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/tensor_util.h" #include "paddle/fluid/framework/tensor_util.h"
#include "paddle/fluid/inference/anakin/convert/op_converter.h"
#include "paddle/fluid/inference/anakin/engine.h" #include "paddle/fluid/inference/anakin/engine.h"
#include "paddle/fluid/inference/analysis/helper.h" #include "paddle/fluid/inference/analysis/helper.h"
#include "paddle/fluid/inference/utils/singleton.h" #include "paddle/fluid/inference/utils/singleton.h"
...@@ -82,7 +83,7 @@ class AnakinConvertValidation { ...@@ -82,7 +83,7 @@ class AnakinConvertValidation {
AnakinConvertValidation() = delete; AnakinConvertValidation() = delete;
AnakinConvertValidation(const std::unordered_set<std::string>& parameters, AnakinConvertValidation(const std::unordered_set<std::string>& parameters,
const framework::Scope& scope) framework::Scope& scope)
: parameters_(parameters), scope_(scope), place_(0) { : parameters_(parameters), scope_(scope), place_(0) {
PADDLE_ENFORCE_EQ(cudaStreamCreate(&stream_), 0); PADDLE_ENFORCE_EQ(cudaStreamCreate(&stream_), 0);
engine_.reset(new AnakinEngine<NV, Precision::FP32>(true)); engine_.reset(new AnakinEngine<NV, Precision::FP32>(true));
......
...@@ -198,9 +198,9 @@ class OpConverter { ...@@ -198,9 +198,9 @@ class OpConverter {
#define REGISTER_TRT_OP_CONVERTER(op_type__, Converter__) \ #define REGISTER_TRT_OP_CONVERTER(op_type__, Converter__) \
struct trt_##op_type__##_converter : public ::paddle::framework::Registrar { \ struct trt_##op_type__##_converter : public ::paddle::framework::Registrar { \
trt_##op_type__##_converter() { \ trt_##op_type__##_converter() { \
::paddle::inference:: \ ::paddle::inference::Registry< \
Registry<paddle::inference::tensorrt::OpConverter>::Register< \ paddle::inference::tensorrt::OpConverter>::Global() \
::paddle::inference::tensorrt::Converter__>(#op_type__); \ .Register<::paddle::inference::tensorrt::Converter__>(#op_type__); \
} \ } \
}; \ }; \
trt_##op_type__##_converter trt_##op_type__##_converter__; \ trt_##op_type__##_converter trt_##op_type__##_converter__; \
......
...@@ -45,12 +45,12 @@ struct Registry { ...@@ -45,12 +45,12 @@ struct Registry {
} }
template <typename ItemChild> template <typename ItemChild>
static void Register(const std::string& name) { void Register(const std::string& name) {
PADDLE_ENFORCE_EQ(items_.count(name), 0); PADDLE_ENFORCE_EQ(items_.count(name), 0);
items_[name] = new ItemChild; items_[name] = new ItemChild;
} }
static ItemParent* Lookup(const std::string& name, ItemParent* Lookup(const std::string& name,
const std::string& default_name = "") { const std::string& default_name = "") {
auto it = items_.find(name); auto it = items_.find(name);
if (it == items_.end()) { if (it == items_.end()) {
...@@ -70,11 +70,8 @@ struct Registry { ...@@ -70,11 +70,8 @@ struct Registry {
private: private:
Registry() = default; Registry() = default;
static std::unordered_map<std::string, ItemParent*> items_; std::unordered_map<std::string, ItemParent*> items_;
}; };
template <typename ItemParent>
std::unordered_map<std::string, ItemParent*> Registry<ItemParent>::items_;
} // namespace inference } // namespace inference
} // namespace paddle } // namespace paddle
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册