提交 d0ce6a90 编写于 作者: Y Yan Chunwei 提交者: root

fix anakin converter registry (#15993)

上级 a5124ee0
......@@ -64,7 +64,9 @@ option(WITH_DISTRIBUTE "Compile with distributed support" OFF)
option(WITH_PSLIB "Compile with pslib support" OFF)
option(WITH_CONTRIB "Compile the third-party contributation" OFF)
option(REPLACE_ENFORCE_GLOG "Replace PADDLE_ENFORCE with glog/CHECK for better debug." OFF)
# TODO(Superjomn) Remove WITH_ANAKIN option if not needed latter.
option(WITH_ANAKIN "Compile with Anakin library" OFF)
option(WITH_ANAKIN_SUBGRAPH "Compile with Anakin subgraph library" OFF)
option(ANAKIN_BUILD_FAT_BIN "Build anakin cuda fat-bin lib for all device plantform, ignored when WITH_ANAKIN=OFF" OFF)
option(ANAKIN_BUILD_CROSS_PLANTFORM "Build anakin lib for any nvidia device plantform. ignored when WITH_ANAKIN=OFF" ON)
option(WITH_GRPC "Use grpc as the default rpc framework" ${WITH_DISTRIBUTE})
......
......@@ -16,7 +16,10 @@ add_subdirectory(utils)
if (TENSORRT_FOUND)
add_subdirectory(tensorrt)
endif()
# add_subdirectory(anakin)
if (WITH_ANAKIN_SUBGRAPH)
add_subdirectory(anakin)
endif()
get_property(fluid_modules GLOBAL PROPERTY FLUID_MODULES)
get_property(cuda_modules GLOBAL PROPERTY CUDA_MODULES)
......
......@@ -71,3 +71,5 @@ void FcOpConverter::operator()(const framework::proto::OpDesc &op,
} // namespace anakin
} // namespace inference
} // namespace paddle
REGISTER_ANAKIN_OP_CONVERTER(fc, FcOpConverter);
......@@ -28,11 +28,8 @@ class FcOpConverter : public AnakinOpConverter {
const framework::Scope &scope,
bool test_mode) override;
virtual ~FcOpConverter() {}
private:
};
static Registrar<FcOpConverter> register_fc_op_converter("fc");
} // namespace anakin
} // namespace inference
} // namespace paddle
......@@ -46,19 +46,18 @@ class AnakinOpConverter {
bool test_mode = false) {
framework::OpDesc op_desc(op, nullptr);
std::string op_type = op_desc.Type();
std::shared_ptr<AnakinOpConverter> it{nullptr};
AnakinOpConverter *it = nullptr;
if (op_type == "mul") {
PADDLE_ENFORCE_EQ(op_desc.Input("Y").size(), 1UL);
std::string Y = op_desc.Input("Y")[0];
std::cout << Y << parameters.count(Y) << std::endl;
if (parameters.count(Y)) {
it = OpRegister::instance()->Get("fc");
it = Registry<AnakinOpConverter>::Global().Lookup("fc");
}
}
if (!it) {
it = OpRegister::instance()->Get(op_type);
it = Registry<AnakinOpConverter>::Global().Lookup(op_type);
}
PADDLE_ENFORCE_NOT_NULL(it, "no OpConverter for optype [%s]", op_type);
it->SetEngine(engine);
......@@ -91,22 +90,23 @@ class AnakinOpConverter {
} // namespace inference
} // namespace paddle
#define REGISTER_ANAKIN_OP_CONVERTER(op_type__, Converter__) \
struct anakin_##op_type__##_converter \
: public ::paddle::framework::Registrar { \
anakin_##op_type__##_converter() { \
::paddle::inference:: \
Registry<paddle::inference::anakin::AnakinOpConverter>::Register< \
::paddle::inference::anakin::Converter__>(#op_type__); \
} \
}; \
anakin_##op_type__##_converter anakin_##op_type__##_converter__; \
int TouchConverterRegister_anakin_##op_type__() { \
anakin_##op_type__##_converter__.Touch(); \
return 0; \
#define REGISTER_ANAKIN_OP_CONVERTER(op_type__, Converter__) \
struct anakin_##op_type__##_converter \
: public ::paddle::framework::Registrar { \
anakin_##op_type__##_converter() { \
LOG(INFO) << "register convert " << #op_type__; \
::paddle::inference::Registry< \
::paddle::inference::anakin::AnakinOpConverter>::Global() \
.Register<::paddle::inference::anakin::Converter__>(#op_type__); \
} \
}; \
anakin_##op_type__##_converter anakin_##op_type__##_converter__; \
int TouchConverterRegister_anakin_##op_type__() { \
anakin_##op_type__##_converter__.Touch(); \
return 0; \
}
#define USE_ANAKIN_CONVERTER(op_type__) \
extern int TouchConverterRegister_anakin_##op_type__(); \
static int use_op_converter_anakin_##op_type__ __attribute__((unused)) = \
#define USE_ANAKIN_CONVERTER(op_type__) \
extern int TouchConverterRegister_anakin_##op_type__(); \
int use_op_converter_anakin_##op_type__ __attribute__((unused)) = \
TouchConverterRegister_anakin_##op_type__();
......@@ -13,7 +13,6 @@ See the License for the specific language governing permissions and
limitations under the License. */
#include <gtest/gtest.h>
#include "paddle/fluid/inference/anakin/convert/fc.h"
#include "paddle/fluid/inference/anakin/convert/op_converter.h"
#include "paddle/fluid/inference/anakin/convert/ut_helper.h"
......@@ -22,10 +21,8 @@ namespace inference {
namespace anakin {
TEST(fc_op, test) {
auto fc_converter = OpRegister::instance()->Get("fc");
ASSERT_TRUE(fc_converter != nullptr);
// Registrar<FcOpConverter> register_fc("fc");
// auto fc = std::make_shared<FcOpConverter>();
auto* fc_converter = Registry<AnakinOpConverter>::Global().Lookup("fc");
ASSERT_TRUE(fc_converter);
std::unordered_set<std::string> parameters({"mul_y"});
framework::Scope scope;
......@@ -52,3 +49,4 @@ TEST(fc_op, test) {
} // namespace paddle
USE_OP(mul);
USE_ANAKIN_CONVERTER(fc);
......@@ -24,6 +24,7 @@ limitations under the License. */
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/tensor_util.h"
#include "paddle/fluid/inference/anakin/convert/op_converter.h"
#include "paddle/fluid/inference/anakin/engine.h"
#include "paddle/fluid/inference/analysis/helper.h"
#include "paddle/fluid/inference/utils/singleton.h"
......@@ -82,7 +83,7 @@ class AnakinConvertValidation {
AnakinConvertValidation() = delete;
AnakinConvertValidation(const std::unordered_set<std::string>& parameters,
const framework::Scope& scope)
framework::Scope& scope)
: parameters_(parameters), scope_(scope), place_(0) {
PADDLE_ENFORCE_EQ(cudaStreamCreate(&stream_), 0);
engine_.reset(new AnakinEngine<NV, Precision::FP32>(true));
......
......@@ -198,9 +198,9 @@ class OpConverter {
#define REGISTER_TRT_OP_CONVERTER(op_type__, Converter__) \
struct trt_##op_type__##_converter : public ::paddle::framework::Registrar { \
trt_##op_type__##_converter() { \
::paddle::inference:: \
Registry<paddle::inference::tensorrt::OpConverter>::Register< \
::paddle::inference::tensorrt::Converter__>(#op_type__); \
::paddle::inference::Registry< \
paddle::inference::tensorrt::OpConverter>::Global() \
.Register<::paddle::inference::tensorrt::Converter__>(#op_type__); \
} \
}; \
trt_##op_type__##_converter trt_##op_type__##_converter__; \
......
......@@ -45,13 +45,13 @@ struct Registry {
}
template <typename ItemChild>
static void Register(const std::string& name) {
void Register(const std::string& name) {
PADDLE_ENFORCE_EQ(items_.count(name), 0);
items_[name] = new ItemChild;
}
static ItemParent* Lookup(const std::string& name,
const std::string& default_name = "") {
ItemParent* Lookup(const std::string& name,
const std::string& default_name = "") {
auto it = items_.find(name);
if (it == items_.end()) {
if (default_name == "")
......@@ -70,11 +70,8 @@ struct Registry {
private:
Registry() = default;
static std::unordered_map<std::string, ItemParent*> items_;
std::unordered_map<std::string, ItemParent*> items_;
};
template <typename ItemParent>
std::unordered_map<std::string, ItemParent*> Registry<ItemParent>::items_;
} // namespace inference
} // namespace paddle
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册