47.md 20.8 KB
Newer Older
W
wizardforcel 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478 479 480 481 482 483 484 485 486 487 488 489 490 491 492 493 494 495 496 497 498 499 500 501 502 503 504 505 506 507 508 509 510 511 512 513 514 515 516 517 518 519 520 521 522 523 524 525 526 527 528 529
# 使用自定义 C++ 类扩展 TorchScript

> 原文:<https://pytorch.org/tutorials/advanced/torch_script_custom_classes.html>

本教程是[自定义运算符](torch_script_custom_ops.html)教程的后续教程,并介绍了我们为将 C++ 类同时绑定到 TorchScript 和 Python 而构建的 API。 该 API 与[`pybind11`](https://github.com/pybind/pybind11)非常相似,如果您熟悉该系统,则大多数概念都将转移过来。

## 用 C++ 实现和绑定类

在本教程中,我们将定义一个简单的 C++ 类,该类在成员变量中保持持久状态。

```py
// This header is all you need to do the C++ portions of this
// tutorial
#include <torch/script.h>
// This header is what defines the custom class registration
// behavior specifically. script.h already includes this, but
// we include it here so you know it exists in case you want
// to look at the API or implementation.
#include <torch/custom_class.h>

#include <string>
#include <vector>

template <class T>
struct MyStackClass : torch::CustomClassHolder {
  std::vector<T> stack_;
  MyStackClass(std::vector<T> init) : stack_(init.begin(), init.end()) {}

  void push(T x) {
    stack_.push_back(x);
  }
  T pop() {
    auto val = stack_.back();
    stack_.pop_back();
    return val;
  }

  c10::intrusive_ptr<MyStackClass> clone() const {
    return c10::make_intrusive<MyStackClass>(stack_);
  }

  void merge(const c10::intrusive_ptr<MyStackClass>& c) {
    for (auto& elem : c->stack_) {
      push(elem);
    }
  }
};

```

有几件事要注意:

*   `torch/custom_class.h`是您需要使用自定义类扩展 TorchScript 的标头。
*   注意,无论何时使用自定义类的实例,我们都通过`c10::intrusive_ptr<>`的实例来实现。 可以将`intrusive_ptr`视为类似于`std::shared_ptr`的智能指针,但是引用计数直接存储在对象中,而不是单独的元数据块(如`std::shared_ptr`中所做的。`torch::Tensor`内部使用相同的指针类型 ;和自定义类也必须使用此指针类型,以便我们可以一致地管理不同的对象类型。
*   注意的第二件事是用户定义的类必须继承`torch::CustomClassHolder`。 这样可以确保自定义类具有存储引用计数的空间。

现在让我们看一下如何使该类对 TorchScript 可见,该过程称为*绑定*该类:

```py
// Notice a few things:
// - We pass the class to be registered as a template parameter to
//   `torch::class_`. In this instance, we've passed the
//   specialization of the MyStackClass class ``MyStackClass<std::string>``.
//   In general, you cannot register a non-specialized template
//   class. For non-templated classes, you can just pass the
//   class name directly as the template parameter.
// - The arguments passed to the constructor make up the "qualified name"
//   of the class. In this case, the registered class will appear in
//   Python and C++ as `torch.classes.my_classes.MyStackClass`. We call
//   the first argument the "namespace" and the second argument the
//   actual class name.
TORCH_LIBRARY(my_classes, m) {
  m.class_<MyStackClass<std::string>>("MyStackClass")
    // The following line registers the contructor of our MyStackClass
    // class that takes a single `std::vector<std::string>` argument,
    // i.e. it exposes the C++ method `MyStackClass(std::vector<T> init)`.
    // Currently, we do not support registering overloaded
    // constructors, so for now you can only `def()` one instance of
    // `torch::init`.
    .def(torch::init<std::vector<std::string>>())
    // The next line registers a stateless (i.e. no captures) C++ lambda
    // function as a method. Note that a lambda function must take a
    // `c10::intrusive_ptr<YourClass>` (or some const/ref version of that)
    // as the first argument. Other arguments can be whatever you want.
    .def("top", [](const c10::intrusive_ptr<MyStackClass<std::string>>& self) {
      return self->stack_.back();
    })
    // The following four lines expose methods of the MyStackClass<std::string>
    // class as-is. `torch::class_` will automatically examine the
    // argument and return types of the passed-in method pointers and
    // expose these to Python and TorchScript accordingly. Finally, notice
    // that we must take the *address* of the fully-qualified method name,
    // i.e. use the unary `&` operator, due to C++ typing rules.
    .def("push", &MyStackClass<std::string>::push)
    .def("pop", &MyStackClass<std::string>::pop)
    .def("clone", &MyStackClass<std::string>::clone)
    .def("merge", &MyStackClass<std::string>::merge)
  ;
}

```

## 使用 CMake 将示例构建为 C++ 项目

现在,我们将使用 [CMake](https://cmake.org) 构建系统来构建上述 C++ 代码。 首先,将到目前为止介绍的所有 C++ 代码放入`class.cpp`文件中。 然后,编写一个简单的`CMakeLists.txt`文件并将其放在同一目录中。 `CMakeLists.txt`应该是这样的:

```py
cmake_minimum_required(VERSION 3.1 FATAL_ERROR)
project(custom_class)

find_package(Torch REQUIRED)

# Define our library target
add_library(custom_class SHARED class.cpp)
set(CMAKE_CXX_STANDARD 14)
# Link against LibTorch
target_link_libraries(custom_class "${TORCH_LIBRARIES}")

```

另外,创建一个`build`目录。 您的文件树应如下所示:

```py
custom_class_project/
  class.cpp
  CMakeLists.txt
  build/

```

我们假设您已经按照[上一教程](torch_script_custom_ops.html)中所述的相同方式设置了环境。 继续并调用`cmake`,然后进行构建项目:

```py
$ cd build
$ cmake -DCMAKE_PREFIX_PATH="$(python -c 'import torch.utils; print(torch.utils.cmake_prefix_path)')" ..
  -- The C compiler identification is GNU 7.3.1
  -- The CXX compiler identification is GNU 7.3.1
  -- Check for working C compiler: /opt/rh/devtoolset-7/root/usr/bin/cc
  -- Check for working C compiler: /opt/rh/devtoolset-7/root/usr/bin/cc -- works
  -- Detecting C compiler ABI info
  -- Detecting C compiler ABI info - done
  -- Detecting C compile features
  -- Detecting C compile features - done
  -- Check for working CXX compiler: /opt/rh/devtoolset-7/root/usr/bin/c++
  -- Check for working CXX compiler: /opt/rh/devtoolset-7/root/usr/bin/c++ -- works
  -- Detecting CXX compiler ABI info
  -- Detecting CXX compiler ABI info - done
  -- Detecting CXX compile features
  -- Detecting CXX compile features - done
  -- Looking for pthread.h
  -- Looking for pthread.h - found
  -- Looking for pthread_create
  -- Looking for pthread_create - not found
  -- Looking for pthread_create in pthreads
  -- Looking for pthread_create in pthreads - not found
  -- Looking for pthread_create in pthread
  -- Looking for pthread_create in pthread - found
  -- Found Threads: TRUE
  -- Found torch: /torchbind_tutorial/libtorch/lib/libtorch.so
  -- Configuring done
  -- Generating done
  -- Build files have been written to: /torchbind_tutorial/build
$ make -j
  Scanning dependencies of target custom_class
  [ 50%] Building CXX object CMakeFiles/custom_class.dir/class.cpp.o
  [100%] Linking CXX shared library libcustom_class.so
  [100%] Built target custom_class

```

您会发现,构建目录中现在有一个动态库文件。 在 Linux 上,它可能名为`libcustom_class.so`。 因此,文件树应如下所示:

```py
custom_class_project/
  class.cpp
  CMakeLists.txt
  build/
    libcustom_class.so

```

## 从 Python 和 TorchScript 使用 C++ 类

现在我们已经将我们的类及其注册编译为`.so`文件,我们可以将`.so`加载到 Python 中并进行尝试。 这是一个演示脚本的脚本:

```py
import torch

# `torch.classes.load_library()` allows you to pass the path to your .so file
# to load it in and make the custom C++ classes available to both Python and
# TorchScript
torch.classes.load_library("build/libcustom_class.so")
# You can query the loaded libraries like this:
print(torch.classes.loaded_libraries)
# prints {'/custom_class_project/build/libcustom_class.so'}

# We can find and instantiate our custom C++ class in python by using the
# `torch.classes` namespace:
#
# This instantiation will invoke the MyStackClass(std::vector<T> init)
# constructor we registered earlier
s = torch.classes.my_classes.MyStackClass(["foo", "bar"])

# We can call methods in Python
s.push("pushed")
assert s.pop() == "pushed"

# Returning and passing instances of custom classes works as you'd expect
s2 = s.clone()
s.merge(s2)
for expected in ["bar", "foo", "bar", "foo"]:
    assert s.pop() == expected

# We can also use the class in TorchScript
# For now, we need to assign the class's type to a local in order to
# annotate the type on the TorchScript function. This may change
# in the future.
MyStackClass = torch.classes.my_classes.MyStackClass

@torch.jit.script
def do_stacks(s: MyStackClass):  # We can pass a custom class instance
    # We can instantiate the class
    s2 = torch.classes.my_classes.MyStackClass(["hi", "mom"])
    s2.merge(s)  # We can call a method on the class
    # We can also return instances of the class
    # from TorchScript function/methods
    return s2.clone(), s2.top()

stack, top = do_stacks(torch.classes.my_classes.MyStackClass(["wow"]))
assert top == "wow"
for expected in ["wow", "mom", "hi"]:
    assert stack.pop() == expected

```

## 使用自定义类保存,加载和运行 TorchScript 代码

我们还可以在使用 libtorch 的 C++ 进程中使用自定义注册的 C++ 类。 举例来说,让我们定义一个简单的`nn.Module`,它实例化并调用`MyStackClass`类上的方法:

```py
import torch

torch.classes.load_library('build/libcustom_class.so')

class Foo(torch.nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, s: str) -> str:
        stack = torch.classes.my_classes.MyStackClass(["hi", "mom"])
        return stack.pop() + s

scripted_foo = torch.jit.script(Foo())
print(scripted_foo.graph)

scripted_foo.save('foo.pt')

```

我们文件系统中的`foo.pt`现在包含我们刚刚定义的序列化 TorchScript 程序。

现在,我们将定义一个新的 CMake 项目,以展示如何加载此模型及其所需的`.so`文件。 有关如何执行此操作的完整说明,请查看[在 C++ 中加载 TorchScript 模型](https://pytorch.org/tutorials/advanced/cpp_export.html)的教程。

与之前类似,让我们创建一个包含以下内容的文件结构:

```py
cpp_inference_example/
  infer.cpp
  CMakeLists.txt
  foo.pt
  build/
  custom_class_project/
    class.cpp
    CMakeLists.txt
    build/

```

请注意,我们已经复制了序列化的`foo.pt`文件以及上面`custom_class_project`的源代码树。 我们将把`custom_class_project`作为依赖项添加到此 C++ 项目中,以便可以将自定义类构建到二进制文件中。

让我们用以下内容填充`infer.cpp`

```py
#include <torch/script.h>

#include <iostream>
#include <memory>

int main(int argc, const char* argv[]) {
  torch::jit::Module module;
  try {
    // Deserialize the ScriptModule from a file using torch::jit::load().
    module = torch::jit::load("foo.pt");
  }
  catch (const c10::Error& e) {
    std::cerr << "error loading the model\n";
    return -1;
  }

  std::vector<c10::IValue> inputs = {"foobarbaz"};
  auto output = module.forward(inputs).toString();
  std::cout << output->string() << std::endl;
}

```

同样,让我们​​定义`CMakeLists.txt`文件:

```py
cmake_minimum_required(VERSION 3.1 FATAL_ERROR)
project(infer)

find_package(Torch REQUIRED)

add_subdirectory(custom_class_project)

# Define our library target
add_executable(infer infer.cpp)
set(CMAKE_CXX_STANDARD 14)
# Link against LibTorch
target_link_libraries(infer "${TORCH_LIBRARIES}")
# This is where we link in our libcustom_class code, making our
# custom class available in our binary.
target_link_libraries(infer -Wl,--no-as-needed custom_class)

```

您知道练习:`cd build``cmake``make`

```py
$ cd build
$ cmake -DCMAKE_PREFIX_PATH="$(python -c 'import torch.utils; print(torch.utils.cmake_prefix_path)')" ..
  -- The C compiler identification is GNU 7.3.1
  -- The CXX compiler identification is GNU 7.3.1
  -- Check for working C compiler: /opt/rh/devtoolset-7/root/usr/bin/cc
  -- Check for working C compiler: /opt/rh/devtoolset-7/root/usr/bin/cc -- works
  -- Detecting C compiler ABI info
  -- Detecting C compiler ABI info - done
  -- Detecting C compile features
  -- Detecting C compile features - done
  -- Check for working CXX compiler: /opt/rh/devtoolset-7/root/usr/bin/c++
  -- Check for working CXX compiler: /opt/rh/devtoolset-7/root/usr/bin/c++ -- works
  -- Detecting CXX compiler ABI info
  -- Detecting CXX compiler ABI info - done
  -- Detecting CXX compile features
  -- Detecting CXX compile features - done
  -- Looking for pthread.h
  -- Looking for pthread.h - found
  -- Looking for pthread_create
  -- Looking for pthread_create - not found
  -- Looking for pthread_create in pthreads
  -- Looking for pthread_create in pthreads - not found
  -- Looking for pthread_create in pthread
  -- Looking for pthread_create in pthread - found
  -- Found Threads: TRUE
  -- Found torch: /local/miniconda3/lib/python3.7/site-packages/torch/lib/libtorch.so
  -- Configuring done
  -- Generating done
  -- Build files have been written to: /cpp_inference_example/build
$ make -j
  Scanning dependencies of target custom_class
  [ 25%] Building CXX object custom_class_project/CMakeFiles/custom_class.dir/class.cpp.o
  [ 50%] Linking CXX shared library libcustom_class.so
  [ 50%] Built target custom_class
  Scanning dependencies of target infer
  [ 75%] Building CXX object CMakeFiles/infer.dir/infer.cpp.o
  [100%] Linking CXX executable infer
  [100%] Built target infer

```

现在我们可以运行令人兴奋的 C++ 二进制文件:

```py
$ ./infer
  momfoobarbaz

```

难以置信!

## 将自定义类移入或移出`IValue`

也可能需要将自定义类从自定义 C++ 类实例移入或移出`IValue`, such as when you take or return IValues from TorchScript methods or you want to instantiate a custom class attribute in C++. For creating an IValue:

*   `torch::make_custom_class<T>()`提供类似于`c10::intrusive_ptr<T>`的 API,因为它将采用您提供给它的任何参数集,调用与该参数集匹配的`T`的构造器,并包装该实例,然后返回。 但是,它不仅返回指向自定义类对象的指针,还返回包装对象的`IValue`。 然后,您可以将此`IValue`直接传递给 TorchScript。
*   如果您已经有一个指向类的`intrusive_ptr`,则可以使用构造器`IValue(intrusive_ptr<T>)`直接从其构造`IValue`

要将`IValue`转换回自定义类:

*   `IValue::toCustomClass<T>()`将返回一个`intrusive_ptr<T>`,指向`IValue`包含的自定义类。 在内部,此函数正在检查`T`是否已注册为自定义类,并且`IValue`实际上确实包含一个自定义类。 您可以通过调用`isCustomClass()`来手动检查`IValue`是否包含自定义类。

## 为自定义 C++ 类定义序列化/反序列化方法

如果您尝试将具有自定义绑定 C++ 类的`ScriptModule`保存为属性,则会出现以下错误:

```py
# export_attr.py
import torch

torch.classes.load_library('build/libcustom_class.so')

class Foo(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.stack = torch.classes.my_classes.MyStackClass(["just", "testing"])

    def forward(self, s: str) -> str:
        return self.stack.pop() + s

scripted_foo = torch.jit.script(Foo())

scripted_foo.save('foo.pt')
loaded = torch.jit.load('foo.pt')

print(loaded.stack.pop())

```

```py
$ python export_attr.py
RuntimeError: Cannot serialize custom bound C++ class __torch__.torch.classes.my_classes.MyStackClass. Please define serialization methods via def_pickle for this class. (pushIValueImpl at ../torch/csrc/jit/pickler.cpp:128)

```

这是因为 TorchScript 无法自动找出 C++ 类中保存的信息。 您必须手动指定。 这样做的方法是使用`class_`上的特殊`def_pickle`方法在类上定义`__getstate__``__setstate__`方法。

注意

TorchScript 中`__getstate__``__setstate__`的语义与 Python `pickle`模块的语义相同。 您可以[阅读更多](https://github.com/pytorch/pytorch/blob/master/torch/csrc/jit/docs/serialization.md#getstate-and-setstate)有关如何使用这些方法的信息。

这是`def_pickle`调用的示例,我们可以将其添加到`MyStackClass`的注册中以包括序列化方法:

```py
    // class_<>::def_pickle allows you to define the serialization
    // and deserialization methods for your C++ class.
    // Currently, we only support passing stateless lambda functions
    // as arguments to def_pickle
    .def_pickle(
          // __getstate__
          // This function defines what data structure should be produced
          // when we serialize an instance of this class. The function
          // must take a single `self` argument, which is an intrusive_ptr
          // to the instance of the object. The function can return
          // any type that is supported as a return value of the TorchScript
          // custom operator API. In this instance, we've chosen to return
          // a std::vector<std::string> as the salient data to preserve
          // from the class.
          [](const c10::intrusive_ptr<MyStackClass<std::string>>& self)
              -> std::vector<std::string> {
            return self->stack_;
          },
          // __setstate__
          // This function defines how to create a new instance of the C++
          // class when we are deserializing. The function must take a
          // single argument of the same type as the return value of
          // `__getstate__`. The function must return an intrusive_ptr
          // to a new instance of the C++ class, initialized however
          // you would like given the serialized state.
          [](std::vector<std::string> state)
              -> c10::intrusive_ptr<MyStackClass<std::string>> {
            // A convenient way to instantiate an object and get an
            // intrusive_ptr to it is via `make_intrusive`. We use
            // that here to allocate an instance of MyStackClass<std::string>
            // and call the single-argument std::vector<std::string>
            // constructor with the serialized state.
            return c10::make_intrusive<MyStackClass<std::string>>(std::move(state));
          });

```

注意

我们在 Pickle API 中采用与`pybind11`不同的方法。`pybind11`作为传递给`class_::def()`的特殊函数`pybind11::pickle()`,为此我们有一个单独的方法`def_pickle`。 这是因为`torch::jit::pickle`这个名称已经被使用了,我们不想引起混淆。

以这种方式定义(反)序列化行为后,脚本现在可以成功运行:

```py
$ python ../export_attr.py
testing

```

## 定义接受或返回绑定 C++ 类的自定义运算符

定义自定义 C++ 类后,您还可以将该类用作自变量或从自定义运算符返回(即自由函数)。 假设您具有以下自由函数:

```py
c10::intrusive_ptr<MyStackClass<std::string>> manipulate_instance(const c10::intrusive_ptr<MyStackClass<std::string>>& instance) {
  instance->pop();
  return instance;
}

```

您可以在`TORCH_LIBRARY`块中运行以下代码来注册它:

```py
    m.def(
      "foo::manipulate_instance(__torch__.torch.classes.my_classes.MyStackClass x) -> __torch__.torch.classes.my_classes.MyStackClass Y",
      manipulate_instance
    );

```

有关注册 API 的更多详细信息,请参见[自定义操作教程](https://pytorch.org/tutorials/advanced/torch_script_custom_ops.html)

完成此操作后,您可以像以下示例一样使用操作:

```py
class TryCustomOp(torch.nn.Module):
    def __init__(self):
        super(TryCustomOp, self).__init__()
        self.f = torch.classes.my_classes.MyStackClass(["foo", "bar"])

    def forward(self):
        return torch.ops.foo.manipulate_instance(self.f)

```

注意

注册使用 C++ 类作为参数的运算符时,要求已注册自定义类。 您可以通过确保自定义类注册和您的自由函数定义在同一`TORCH_LIBRARY`块中,并确保自定义类注册位于第一位来强制实现此操作。 将来,我们可能会放宽此要求,以便可以按任何顺序进行注册。

## 总结

本教程向您介绍了如何向 TorchScript(以及扩展为 Python)公开 C++ 类,如何注册其方法,如何从 Python 和 TorchScript 使用该类以及如何使用该类保存和加载代码以及运行该代码。 在独立的 C++ 过程中。 现在,您可以使用与第三方 C++ 库连接的 C++ 类扩展 TorchScript 模型,或实现需要 Python,TorchScript 和 C++ 之间的界线平滑融合的任何其他用例。

与往常一样,如果您遇到任何问题或疑问,可以使用我们的[论坛](https://discuss.pytorch.org/)[GitHub ISSUE](https://github.com/pytorch/pytorch/issues) 进行联系。 另外,我们的[常见问题解答(FAQ)页面](https://pytorch.org/cppdocs/notes/faq.html)可能包含有用的信息。