# 使用自定义 C++ 类扩展 TorchScript > 原文: 本教程是[自定义运算符](torch_script_custom_ops.html)教程的后续教程,并介绍了我们为将 C++ 类同时绑定到 TorchScript 和 Python 而构建的 API。 该 API 与[`pybind11`](https://github.com/pybind/pybind11)非常相似,如果您熟悉该系统,则大多数概念都将转移过来。 ## 用 C++ 实现和绑定类 在本教程中,我们将定义一个简单的 C++ 类,该类在成员变量中保持持久状态。 ```py // This header is all you need to do the C++ portions of this // tutorial #include // This header is what defines the custom class registration // behavior specifically. script.h already includes this, but // we include it here so you know it exists in case you want // to look at the API or implementation. #include #include #include template struct MyStackClass : torch::CustomClassHolder { std::vector stack_; MyStackClass(std::vector init) : stack_(init.begin(), init.end()) {} void push(T x) { stack_.push_back(x); } T pop() { auto val = stack_.back(); stack_.pop_back(); return val; } c10::intrusive_ptr clone() const { return c10::make_intrusive(stack_); } void merge(const c10::intrusive_ptr& c) { for (auto& elem : c->stack_) { push(elem); } } }; ``` 有几件事要注意: * `torch/custom_class.h`是您需要使用自定义类扩展 TorchScript 的标头。 * 注意,无论何时使用自定义类的实例,我们都通过`c10::intrusive_ptr<>`的实例来实现。 可以将`intrusive_ptr`视为类似于`std::shared_ptr`的智能指针,但是引用计数直接存储在对象中,而不是单独的元数据块(如`std::shared_ptr`中所做的。`torch::Tensor`内部使用相同的指针类型 ;和自定义类也必须使用此指针类型,以便我们可以一致地管理不同的对象类型。 * 注意的第二件事是用户定义的类必须继承`torch::CustomClassHolder`。 这样可以确保自定义类具有存储引用计数的空间。 现在让我们看一下如何使该类对 TorchScript 可见,该过程称为*绑定*该类: ```py // Notice a few things: // - We pass the class to be registered as a template parameter to // `torch::class_`. In this instance, we've passed the // specialization of the MyStackClass class ``MyStackClass``. // In general, you cannot register a non-specialized template // class. For non-templated classes, you can just pass the // class name directly as the template parameter. // - The arguments passed to the constructor make up the "qualified name" // of the class. In this case, the registered class will appear in // Python and C++ as `torch.classes.my_classes.MyStackClass`. We call // the first argument the "namespace" and the second argument the // actual class name. TORCH_LIBRARY(my_classes, m) { m.class_>("MyStackClass") // The following line registers the contructor of our MyStackClass // class that takes a single `std::vector` argument, // i.e. it exposes the C++ method `MyStackClass(std::vector init)`. // Currently, we do not support registering overloaded // constructors, so for now you can only `def()` one instance of // `torch::init`. .def(torch::init>()) // The next line registers a stateless (i.e. no captures) C++ lambda // function as a method. Note that a lambda function must take a // `c10::intrusive_ptr` (or some const/ref version of that) // as the first argument. Other arguments can be whatever you want. .def("top", [](const c10::intrusive_ptr>& self) { return self->stack_.back(); }) // The following four lines expose methods of the MyStackClass // class as-is. `torch::class_` will automatically examine the // argument and return types of the passed-in method pointers and // expose these to Python and TorchScript accordingly. Finally, notice // that we must take the *address* of the fully-qualified method name, // i.e. use the unary `&` operator, due to C++ typing rules. .def("push", &MyStackClass::push) .def("pop", &MyStackClass::pop) .def("clone", &MyStackClass::clone) .def("merge", &MyStackClass::merge) ; } ``` ## 使用 CMake 将示例构建为 C++ 项目 现在,我们将使用 [CMake](https://cmake.org) 构建系统来构建上述 C++ 代码。 首先,将到目前为止介绍的所有 C++ 代码放入`class.cpp`文件中。 然后,编写一个简单的`CMakeLists.txt`文件并将其放在同一目录中。 `CMakeLists.txt`应该是这样的: ```py cmake_minimum_required(VERSION 3.1 FATAL_ERROR) project(custom_class) find_package(Torch REQUIRED) # Define our library target add_library(custom_class SHARED class.cpp) set(CMAKE_CXX_STANDARD 14) # Link against LibTorch target_link_libraries(custom_class "${TORCH_LIBRARIES}") ``` 另外,创建一个`build`目录。 您的文件树应如下所示: ```py custom_class_project/ class.cpp CMakeLists.txt build/ ``` 我们假设您已经按照[上一教程](torch_script_custom_ops.html)中所述的相同方式设置了环境。 继续并调用`cmake`,然后进行构建项目: ```py $ cd build $ cmake -DCMAKE_PREFIX_PATH="$(python -c 'import torch.utils; print(torch.utils.cmake_prefix_path)')" .. -- The C compiler identification is GNU 7.3.1 -- The CXX compiler identification is GNU 7.3.1 -- Check for working C compiler: /opt/rh/devtoolset-7/root/usr/bin/cc -- Check for working C compiler: /opt/rh/devtoolset-7/root/usr/bin/cc -- works -- Detecting C compiler ABI info -- Detecting C compiler ABI info - done -- Detecting C compile features -- Detecting C compile features - done -- Check for working CXX compiler: /opt/rh/devtoolset-7/root/usr/bin/c++ -- Check for working CXX compiler: /opt/rh/devtoolset-7/root/usr/bin/c++ -- works -- Detecting CXX compiler ABI info -- Detecting CXX compiler ABI info - done -- Detecting CXX compile features -- Detecting CXX compile features - done -- Looking for pthread.h -- Looking for pthread.h - found -- Looking for pthread_create -- Looking for pthread_create - not found -- Looking for pthread_create in pthreads -- Looking for pthread_create in pthreads - not found -- Looking for pthread_create in pthread -- Looking for pthread_create in pthread - found -- Found Threads: TRUE -- Found torch: /torchbind_tutorial/libtorch/lib/libtorch.so -- Configuring done -- Generating done -- Build files have been written to: /torchbind_tutorial/build $ make -j Scanning dependencies of target custom_class [ 50%] Building CXX object CMakeFiles/custom_class.dir/class.cpp.o [100%] Linking CXX shared library libcustom_class.so [100%] Built target custom_class ``` 您会发现,构建目录中现在有一个动态库文件。 在 Linux 上,它可能名为`libcustom_class.so`。 因此,文件树应如下所示: ```py custom_class_project/ class.cpp CMakeLists.txt build/ libcustom_class.so ``` ## 从 Python 和 TorchScript 使用 C++ 类 现在我们已经将我们的类及其注册编译为`.so`文件,我们可以将`.so`加载到 Python 中并进行尝试。 这是一个演示脚本的脚本: ```py import torch # `torch.classes.load_library()` allows you to pass the path to your .so file # to load it in and make the custom C++ classes available to both Python and # TorchScript torch.classes.load_library("build/libcustom_class.so") # You can query the loaded libraries like this: print(torch.classes.loaded_libraries) # prints {'/custom_class_project/build/libcustom_class.so'} # We can find and instantiate our custom C++ class in python by using the # `torch.classes` namespace: # # This instantiation will invoke the MyStackClass(std::vector init) # constructor we registered earlier s = torch.classes.my_classes.MyStackClass(["foo", "bar"]) # We can call methods in Python s.push("pushed") assert s.pop() == "pushed" # Returning and passing instances of custom classes works as you'd expect s2 = s.clone() s.merge(s2) for expected in ["bar", "foo", "bar", "foo"]: assert s.pop() == expected # We can also use the class in TorchScript # For now, we need to assign the class's type to a local in order to # annotate the type on the TorchScript function. This may change # in the future. MyStackClass = torch.classes.my_classes.MyStackClass @torch.jit.script def do_stacks(s: MyStackClass): # We can pass a custom class instance # We can instantiate the class s2 = torch.classes.my_classes.MyStackClass(["hi", "mom"]) s2.merge(s) # We can call a method on the class # We can also return instances of the class # from TorchScript function/methods return s2.clone(), s2.top() stack, top = do_stacks(torch.classes.my_classes.MyStackClass(["wow"])) assert top == "wow" for expected in ["wow", "mom", "hi"]: assert stack.pop() == expected ``` ## 使用自定义类保存,加载和运行 TorchScript 代码 我们还可以在使用 libtorch 的 C++ 进程中使用自定义注册的 C++ 类。 举例来说,让我们定义一个简单的`nn.Module`,它实例化并调用`MyStackClass`类上的方法: ```py import torch torch.classes.load_library('build/libcustom_class.so') class Foo(torch.nn.Module): def __init__(self): super().__init__() def forward(self, s: str) -> str: stack = torch.classes.my_classes.MyStackClass(["hi", "mom"]) return stack.pop() + s scripted_foo = torch.jit.script(Foo()) print(scripted_foo.graph) scripted_foo.save('foo.pt') ``` 我们文件系统中的`foo.pt`现在包含我们刚刚定义的序列化 TorchScript 程序。 现在,我们将定义一个新的 CMake 项目,以展示如何加载此模型及其所需的`.so`文件。 有关如何执行此操作的完整说明,请查看[在 C++ 中加载 TorchScript 模型](https://pytorch.org/tutorials/advanced/cpp_export.html)的教程。 与之前类似,让我们创建一个包含以下内容的文件结构: ```py cpp_inference_example/ infer.cpp CMakeLists.txt foo.pt build/ custom_class_project/ class.cpp CMakeLists.txt build/ ``` 请注意,我们已经复制了序列化的`foo.pt`文件以及上面`custom_class_project`的源代码树。 我们将把`custom_class_project`作为依赖项添加到此 C++ 项目中,以便可以将自定义类构建到二进制文件中。 让我们用以下内容填充`infer.cpp`: ```py #include #include #include int main(int argc, const char* argv[]) { torch::jit::Module module; try { // Deserialize the ScriptModule from a file using torch::jit::load(). module = torch::jit::load("foo.pt"); } catch (const c10::Error& e) { std::cerr << "error loading the model\n"; return -1; } std::vector inputs = {"foobarbaz"}; auto output = module.forward(inputs).toString(); std::cout << output->string() << std::endl; } ``` 同样,让我们​​定义`CMakeLists.txt`文件: ```py cmake_minimum_required(VERSION 3.1 FATAL_ERROR) project(infer) find_package(Torch REQUIRED) add_subdirectory(custom_class_project) # Define our library target add_executable(infer infer.cpp) set(CMAKE_CXX_STANDARD 14) # Link against LibTorch target_link_libraries(infer "${TORCH_LIBRARIES}") # This is where we link in our libcustom_class code, making our # custom class available in our binary. target_link_libraries(infer -Wl,--no-as-needed custom_class) ``` 您知道练习:`cd build`,`cmake`和`make`: ```py $ cd build $ cmake -DCMAKE_PREFIX_PATH="$(python -c 'import torch.utils; print(torch.utils.cmake_prefix_path)')" .. -- The C compiler identification is GNU 7.3.1 -- The CXX compiler identification is GNU 7.3.1 -- Check for working C compiler: /opt/rh/devtoolset-7/root/usr/bin/cc -- Check for working C compiler: /opt/rh/devtoolset-7/root/usr/bin/cc -- works -- Detecting C compiler ABI info -- Detecting C compiler ABI info - done -- Detecting C compile features -- Detecting C compile features - done -- Check for working CXX compiler: /opt/rh/devtoolset-7/root/usr/bin/c++ -- Check for working CXX compiler: /opt/rh/devtoolset-7/root/usr/bin/c++ -- works -- Detecting CXX compiler ABI info -- Detecting CXX compiler ABI info - done -- Detecting CXX compile features -- Detecting CXX compile features - done -- Looking for pthread.h -- Looking for pthread.h - found -- Looking for pthread_create -- Looking for pthread_create - not found -- Looking for pthread_create in pthreads -- Looking for pthread_create in pthreads - not found -- Looking for pthread_create in pthread -- Looking for pthread_create in pthread - found -- Found Threads: TRUE -- Found torch: /local/miniconda3/lib/python3.7/site-packages/torch/lib/libtorch.so -- Configuring done -- Generating done -- Build files have been written to: /cpp_inference_example/build $ make -j Scanning dependencies of target custom_class [ 25%] Building CXX object custom_class_project/CMakeFiles/custom_class.dir/class.cpp.o [ 50%] Linking CXX shared library libcustom_class.so [ 50%] Built target custom_class Scanning dependencies of target infer [ 75%] Building CXX object CMakeFiles/infer.dir/infer.cpp.o [100%] Linking CXX executable infer [100%] Built target infer ``` 现在我们可以运行令人兴奋的 C++ 二进制文件: ```py $ ./infer momfoobarbaz ``` 难以置信! ## 将自定义类移入或移出`IValue` 也可能需要将自定义类从自定义 C++ 类实例移入或移出`IValue`, such as when you take or return IValues from TorchScript methods or you want to instantiate a custom class attribute in C++. For creating an IValue: * `torch::make_custom_class()`提供类似于`c10::intrusive_ptr`的 API,因为它将采用您提供给它的任何参数集,调用与该参数集匹配的`T`的构造器,并包装该实例,然后返回。 但是,它不仅返回指向自定义类对象的指针,还返回包装对象的`IValue`。 然后,您可以将此`IValue`直接传递给 TorchScript。 * 如果您已经有一个指向类的`intrusive_ptr`,则可以使用构造器`IValue(intrusive_ptr)`直接从其构造`IValue`。 要将`IValue`转换回自定义类: * `IValue::toCustomClass()`将返回一个`intrusive_ptr`,指向`IValue`包含的自定义类。 在内部,此函数正在检查`T`是否已注册为自定义类,并且`IValue`实际上确实包含一个自定义类。 您可以通过调用`isCustomClass()`来手动检查`IValue`是否包含自定义类。 ## 为自定义 C++ 类定义序列化/反序列化方法 如果您尝试将具有自定义绑定 C++ 类的`ScriptModule`保存为属性,则会出现以下错误: ```py # export_attr.py import torch torch.classes.load_library('build/libcustom_class.so') class Foo(torch.nn.Module): def __init__(self): super().__init__() self.stack = torch.classes.my_classes.MyStackClass(["just", "testing"]) def forward(self, s: str) -> str: return self.stack.pop() + s scripted_foo = torch.jit.script(Foo()) scripted_foo.save('foo.pt') loaded = torch.jit.load('foo.pt') print(loaded.stack.pop()) ``` ```py $ python export_attr.py RuntimeError: Cannot serialize custom bound C++ class __torch__.torch.classes.my_classes.MyStackClass. Please define serialization methods via def_pickle for this class. (pushIValueImpl at ../torch/csrc/jit/pickler.cpp:128) ``` 这是因为 TorchScript 无法自动找出 C++ 类中保存的信息。 您必须手动指定。 这样做的方法是使用`class_`上的特殊`def_pickle`方法在类上定义`__getstate__`和`__setstate__`方法。 注意 TorchScript 中`__getstate__`和`__setstate__`的语义与 Python `pickle`模块的语义相同。 您可以[阅读更多](https://github.com/pytorch/pytorch/blob/master/torch/csrc/jit/docs/serialization.md#getstate-and-setstate)有关如何使用这些方法的信息。 这是`def_pickle`调用的示例,我们可以将其添加到`MyStackClass`的注册中以包括序列化方法: ```py // class_<>::def_pickle allows you to define the serialization // and deserialization methods for your C++ class. // Currently, we only support passing stateless lambda functions // as arguments to def_pickle .def_pickle( // __getstate__ // This function defines what data structure should be produced // when we serialize an instance of this class. The function // must take a single `self` argument, which is an intrusive_ptr // to the instance of the object. The function can return // any type that is supported as a return value of the TorchScript // custom operator API. In this instance, we've chosen to return // a std::vector as the salient data to preserve // from the class. [](const c10::intrusive_ptr>& self) -> std::vector { return self->stack_; }, // __setstate__ // This function defines how to create a new instance of the C++ // class when we are deserializing. The function must take a // single argument of the same type as the return value of // `__getstate__`. The function must return an intrusive_ptr // to a new instance of the C++ class, initialized however // you would like given the serialized state. [](std::vector state) -> c10::intrusive_ptr> { // A convenient way to instantiate an object and get an // intrusive_ptr to it is via `make_intrusive`. We use // that here to allocate an instance of MyStackClass // and call the single-argument std::vector // constructor with the serialized state. return c10::make_intrusive>(std::move(state)); }); ``` 注意 我们在 Pickle API 中采用与`pybind11`不同的方法。`pybind11`作为传递给`class_::def()`的特殊函数`pybind11::pickle()`,为此我们有一个单独的方法`def_pickle`。 这是因为`torch::jit::pickle`这个名称已经被使用了,我们不想引起混淆。 以这种方式定义(反)序列化行为后,脚本现在可以成功运行: ```py $ python ../export_attr.py testing ``` ## 定义接受或返回绑定 C++ 类的自定义运算符 定义自定义 C++ 类后,您还可以将该类用作自变量或从自定义运算符返回(即自由函数)。 假设您具有以下自由函数: ```py c10::intrusive_ptr> manipulate_instance(const c10::intrusive_ptr>& instance) { instance->pop(); return instance; } ``` 您可以在`TORCH_LIBRARY`块中运行以下代码来注册它: ```py m.def( "foo::manipulate_instance(__torch__.torch.classes.my_classes.MyStackClass x) -> __torch__.torch.classes.my_classes.MyStackClass Y", manipulate_instance ); ``` 有关注册 API 的更多详细信息,请参见[自定义操作教程](https://pytorch.org/tutorials/advanced/torch_script_custom_ops.html)。 完成此操作后,您可以像以下示例一样使用操作: ```py class TryCustomOp(torch.nn.Module): def __init__(self): super(TryCustomOp, self).__init__() self.f = torch.classes.my_classes.MyStackClass(["foo", "bar"]) def forward(self): return torch.ops.foo.manipulate_instance(self.f) ``` 注意 注册使用 C++ 类作为参数的运算符时,要求已注册自定义类。 您可以通过确保自定义类注册和您的自由函数定义在同一`TORCH_LIBRARY`块中,并确保自定义类注册位于第一位来强制实现此操作。 将来,我们可能会放宽此要求,以便可以按任何顺序进行注册。 ## 总结 本教程向您介绍了如何向 TorchScript(以及扩展为 Python)公开 C++ 类,如何注册其方法,如何从 Python 和 TorchScript 使用该类以及如何使用该类保存和加载代码以及运行该代码。 在独立的 C++ 过程中。 现在,您可以使用与第三方 C++ 库连接的 C++ 类扩展 TorchScript 模型,或实现需要 Python,TorchScript 和 C++ 之间的界线平滑融合的任何其他用例。 与往常一样,如果您遇到任何问题或疑问,可以使用我们的[论坛](https://discuss.pytorch.org/)或 [GitHub ISSUE](https://github.com/pytorch/pytorch/issues) 进行联系。 另外,我们的[常见问题解答(FAQ)页面](https://pytorch.org/cppdocs/notes/faq.html)可能包含有用的信息。