From e86206388ead0657a21a044a2758f86b0d42b646 Mon Sep 17 00:00:00 2001 From: SunAhong1993 <48579383+SunAhong1993@users.noreply.github.com> Date: Thu, 17 Jun 2021 19:58:31 +0800 Subject: [PATCH] Add docs and fix the assign op (#623) * fix the code * fix the visit_tuple * Update stargan.md * Update ultra_light_fast_generic_face_detector.md * fix the docs * remove static * fix * fix * fix * fix the docs * add docs * fix the doc Co-authored-by: channingss --- docs/introduction/x2paddle_model_zoo.md | 4 +-- .../API_docs/ops/README.md | 4 ++- .../API_docs/ops/torch.addmv.md | 4 ++- .../API_docs/ops/torch.addr.md | 26 +++++++++++++++++++ .../API_docs/ops/torch.baddbmm.md | 23 ++++++++++++++++ docs/pytorch_project_convertor/demo/craft.md | 23 ++++++++++++++++ .../pytorch/torch2paddle/nn_init.py | 14 +++++----- 7 files changed, 87 insertions(+), 11 deletions(-) create mode 100644 docs/pytorch_project_convertor/API_docs/ops/torch.addr.md create mode 100644 docs/pytorch_project_convertor/API_docs/ops/torch.baddbmm.md create mode 100644 docs/pytorch_project_convertor/demo/craft.md diff --git a/docs/introduction/x2paddle_model_zoo.md b/docs/introduction/x2paddle_model_zoo.md index e7ebc8c..4af0523 100644 --- a/docs/introduction/x2paddle_model_zoo.md +++ b/docs/introduction/x2paddle_model_zoo.md @@ -42,8 +42,8 @@ | ResNet50 | [code](https://github.com/soeaver/caffe-model/blob/master/cls/resnet/deploy_resnet50.prototxt) | | Unet | [code](https://github.com/jolibrain/deepdetect/blob/master/templates/caffe/unet/deploy.prototxt) | | VGGNet | [code](https://gist.github.com/ksimonyan/211839e770f7b538e2d8#file-vgg_ilsvrc_16_layers_deploy-prototxt) | -| FaceDetection | [code](https://github.com/ShiqiYu/libfacedetection/blob/master/models/caffe/yufacedetectnet-open-v1.prototxt) | - +| FaceDetection | - | +【备注】-代表源模型已无法获取。 diff --git a/docs/pytorch_project_convertor/API_docs/ops/README.md b/docs/pytorch_project_convertor/API_docs/ops/README.md index 2a1142f..6366598 100644 --- a/docs/pytorch_project_convertor/API_docs/ops/README.md +++ b/docs/pytorch_project_convertor/API_docs/ops/README.md @@ -143,6 +143,8 @@ | 138 | [torch.inverse](https://pytorch.org/docs/stable/generated/torch.inverse.html?highlight=inverse#torch.inverse) | [paddle.inverse](https://www.paddlepaddle.org.cn/documentation/docs/zh/api/paddle/inverse_cn.html) | 功能一致,参数名不一致,PaddlePaddle未定义`out`参数代表输出Tensor | | 139 | [torch.trace](https://pytorch.org/docs/stable/generated/torch.trace.html?highlight=trace#torch.trace) | [paddle.trace](https://www.paddlepaddle.org.cn/documentation/docs/zh/api/paddle/trace_cn.html) | [差异对比](https://github.com/PaddlePaddle/X2Paddle/tree/develop/docs/pytorch_project_convertor/API_docs/ops/torch.trace.md) | | 140 | [torch.addmv](https://pytorch.org/docs/stable/generated/torch.addmv.html?highlight=addmv#torch.addmv) | 无对应实现 | [组合实现](https://github.com/PaddlePaddle/X2Paddle/tree/develop/docs/pytorch_project_convertor/API_docs/ops/torch.addmv.md) | - +| 141 | [torch.addr](https://pytorch.org/docs/stable/generated/torch.addr.html?highlight=addr#torch.addr) | 无对应实现 | [组合实现](https://github.com/PaddlePaddle/X2Paddle/tree/develop/docs/pytorch_project_convertor/API_docs/ops/torch.addr.md) | +| 142 | [torch.baddbmm](https://pytorch.org/docs/stable/generated/torch.baddbmm.html?highlight=baddbmm#torch.baddbmm) | 无对应实现 | [组合实现](https://github.com/PaddlePaddle/X2Paddle/tree/develop/docs/pytorch_project_convertor/API_docs/ops/torch.baddbmm.md) | +| 143 | [torch.bmm](https://pytorch.org/docs/stable/generated/torch.bmm.html?highlight=bmm#torch.bmm) | [paddle.bmm](https://www.paddlepaddle.org.cn/documentation/docs/zh/api/paddle/bmm_cn.html) | 功能一致,参数名不一致,PaddlePaddle未定义`out`参数代表输出Tensor | ***持续更新...*** diff --git a/docs/pytorch_project_convertor/API_docs/ops/torch.addmv.md b/docs/pytorch_project_convertor/API_docs/ops/torch.addmv.md index 3eb3259..edfd02a 100644 --- a/docs/pytorch_project_convertor/API_docs/ops/torch.addmv.md +++ b/docs/pytorch_project_convertor/API_docs/ops/torch.addmv.md @@ -5,7 +5,9 @@ torch.addmv(input, mat, vec, beta=1, alpha=1, out=None) ``` ### 功能介绍 -用于实现矩阵(`mat`)与向量(`vec`)相乘,再加上输入(`input`),PaddlePaddle目前无对应API,可使用如下代码组合实现该API。 +用于实现矩阵(`mat`)与向量(`vec`)相乘,再加上输入(`input`),公式为: +$ out = β * input + α * (mat @ vec) $ +PaddlePaddle目前无对应API,可使用如下代码组合实现该API。 ```python import paddle diff --git a/docs/pytorch_project_convertor/API_docs/ops/torch.addr.md b/docs/pytorch_project_convertor/API_docs/ops/torch.addr.md new file mode 100644 index 0000000..9b24485 --- /dev/null +++ b/docs/pytorch_project_convertor/API_docs/ops/torch.addr.md @@ -0,0 +1,26 @@ +## torch.addr +### [torch.addr](https://pytorch.org/docs/stable/generated/torch.addr.html?highlight=addr#torch.addr) +```python +torch.addr(input, vec1, vec2, beta=1, alpha=1, out=None) +``` + +### 功能介绍 +用于实现矩阵(`vec`)与向量(`vec`)相乘,再加上输入(`input`),公式为: +$out = β * input + α * (vec1 ⊗ vec2)$ +PaddlePaddle目前无对应API,可使用如下代码组合实现该API。 + +```python +import paddle + +def addr(input, vec1, vec2, beta=1, alpha=1, out=None): + row = vec1.shape[0] + column = vec2.shape[0] + vec1 = paddle.unsqueeze(vec1, 0) + vec1 = paddle.transpose(vec1, [1, 0]) + vec1 = paddle.expand(vec1, [row, column]) + new_vec2 = paddle.zeros([column, column], dtype=vec2.dtype) + new_vec2[0, :] = vec2 + out = alpha * paddle.matmul(vec1, new_vec2) + out = beta * input + out + return out +``` \ No newline at end of file diff --git a/docs/pytorch_project_convertor/API_docs/ops/torch.baddbmm.md b/docs/pytorch_project_convertor/API_docs/ops/torch.baddbmm.md new file mode 100644 index 0000000..186789c --- /dev/null +++ b/docs/pytorch_project_convertor/API_docs/ops/torch.baddbmm.md @@ -0,0 +1,23 @@ +## torch.baddbmm +### [torch.baddbmm](https://pytorch.org/docs/stable/generated/torch.baddbmm.html?highlight=baddbmm#torch.baddbmm) +```python +torch.baddbmm(input, batch1, batch2, beta=1, alpha=1, out=None) +``` +### 功能介绍 +用于实现Tensor(大小为$b×n×m$)与用于实现Tensor(大小为$b×m×p$)相乘,再加上输入(`input`),公式为: +$out_i = β * input_i + α * (batch1_i @ batch2_i)$ +PaddlePaddle目前无对应API,可使用如下代码组合实现该API。 + +```python +def addr(input, vec1, vec2, beta=1, alpha=1, out=None): + row = vec1.shape[0] + column = vec2.shape[0] + vec1 = paddle.unsqueeze(vec1, 0) + vec1 = paddle.transpose(vec1, [1, 0]) + vec1 = paddle.expand(vec1, [row, column]) + new_vec2 = paddle.zeros([column, column], dtype=vec2.dtype) + new_vec2[0, :] = vec2 + out = alpha * paddle.matmul(vec1, new_vec2) + out = beta * input + out + return out +``` \ No newline at end of file diff --git a/docs/pytorch_project_convertor/demo/craft.md b/docs/pytorch_project_convertor/demo/craft.md new file mode 100644 index 0000000..0b6f61d --- /dev/null +++ b/docs/pytorch_project_convertor/demo/craft.md @@ -0,0 +1,23 @@ +## [CRAFT-pytorch](https://github.com/clovaai/CRAFT-pytorch) +### 准备工作 +``` shell +# 下载项目 +git clone https://github.com/clovaai/CRAFT-pytorch.git +cd CRAFT-pytorch +git checkout e332dd8b718e291f51b66ff8f9ef2c98ee4474c8 +``` +模型与数据可根据[原repo](https://github.com/clovaai/CRAFT-pytorch#test-instruction-using-pretrained-model)相关信息进行下载,可将模型存放于新建文件夹`./weight`,将数据存放于新建文件夹`./data`。 + +### 转换 +``` shell +cd ../ +x2paddle --convert_torch_project --project_dir=CRAFT-pytorch --save_dir=paddle_project --pretrain_model=CRAFT-pytorch/weights/ +``` + +### 运行训练代码 +``` shell +cd paddle_project +python test.py --trained_model=weights/craft_mlt_25k.pdiparams --test_folder=data +``` + +***转换后的代码可在[这里](https://github.com/SunAhong1993/CRAFT-pytorch/tree/paddle)进行查看。*** diff --git a/x2paddle/project_convertor/pytorch/torch2paddle/nn_init.py b/x2paddle/project_convertor/pytorch/torch2paddle/nn_init.py index 0d5e1af..289ce19 100644 --- a/x2paddle/project_convertor/pytorch/torch2paddle/nn_init.py +++ b/x2paddle/project_convertor/pytorch/torch2paddle/nn_init.py @@ -149,7 +149,7 @@ def kaiming_normal_(param, a=0, mode='fan_in', nonlinearity='leaky_relu'): dtype=param.dtype, default_initializer=KaimingNormal( a=a, mode=mode, nonlinearity=nonlinearity)) - paddle.assign(param, replaced_param) + paddle.assign(replaced_param, param) class XavierNormal(XavierInitializer): @@ -213,7 +213,7 @@ def xavier_normal_(param, gain=1.0): shape=param.shape, dtype=param.dtype, default_initializer=XavierNormal(gain=gain)) - paddle.assign(param, replaced_param) + paddle.assign(replaced_param, param) class XavierUniform(XavierInitializer): @@ -278,7 +278,7 @@ def xavier_uniform_(param, gain=1.0): shape=param.shape, dtype=param.dtype, default_initializer=XavierUniform(gain=gain)) - paddle.assign(param, replaced_param) + paddle.assign(replaced_param, param) def constant_init_(param, val): @@ -287,7 +287,7 @@ def constant_init_(param, val): dtype=param.dtype, default_initializer=paddle.nn.initializer.Assign( paddle.full(param.shape, val, param.dtype))) - paddle.assign(param, replaced_param) + paddle.assign(replaced_param, param) def normal_init_(param, mean=0.0, std=1.0): @@ -297,7 +297,7 @@ def normal_init_(param, mean=0.0, std=1.0): default_initializer=paddle.nn.initializer.Assign( paddle.normal( mean=mean, std=std, shape=param.shape))) - paddle.assign(param, replaced_param) + paddle.assign(replaced_param, param) def ones_init_(param): @@ -306,7 +306,7 @@ def ones_init_(param): dtype=param.dtype, default_initializer=paddle.nn.initializer.Assign( paddle.ones(param.shape, param.dtype))) - paddle.assign(param, replaced_param) + paddle.assign(replaced_param, param) def zeros_init_(param): @@ -315,4 +315,4 @@ def zeros_init_(param): dtype=param.dtype, default_initializer=paddle.nn.initializer.Assign( paddle.zeros(param.shape, param.dtype))) - paddle.assign(param, replaced_param) + paddle.assign(replaced_param, param) -- GitLab