未验证 提交 e8620638 编写于 作者: S SunAhong1993 提交者: GitHub

Add docs and fix the assign op (#623)

* fix the code

* fix the visit_tuple

* Update stargan.md

* Update ultra_light_fast_generic_face_detector.md

* fix the docs

* remove static

* fix

* fix

* fix

* fix the docs

* add docs

* fix the doc
Co-authored-by: Nchanningss <chen_lingchi@163.com>
上级 597dc18a
......@@ -42,8 +42,8 @@
| ResNet50 | [code](https://github.com/soeaver/caffe-model/blob/master/cls/resnet/deploy_resnet50.prototxt) |
| Unet | [code](https://github.com/jolibrain/deepdetect/blob/master/templates/caffe/unet/deploy.prototxt) |
| VGGNet | [code](https://gist.github.com/ksimonyan/211839e770f7b538e2d8#file-vgg_ilsvrc_16_layers_deploy-prototxt) |
| FaceDetection | [code](https://github.com/ShiqiYu/libfacedetection/blob/master/models/caffe/yufacedetectnet-open-v1.prototxt) |
| FaceDetection | - |
【备注】-代表源模型已无法获取。
......
......@@ -143,6 +143,8 @@
| 138 | [torch.inverse](https://pytorch.org/docs/stable/generated/torch.inverse.html?highlight=inverse#torch.inverse) | [paddle.inverse](https://www.paddlepaddle.org.cn/documentation/docs/zh/api/paddle/inverse_cn.html) | 功能一致,参数名不一致,PaddlePaddle未定义`out`参数代表输出Tensor |
| 139 | [torch.trace](https://pytorch.org/docs/stable/generated/torch.trace.html?highlight=trace#torch.trace) | [paddle.trace](https://www.paddlepaddle.org.cn/documentation/docs/zh/api/paddle/trace_cn.html) | [差异对比](https://github.com/PaddlePaddle/X2Paddle/tree/develop/docs/pytorch_project_convertor/API_docs/ops/torch.trace.md) |
| 140 | [torch.addmv](https://pytorch.org/docs/stable/generated/torch.addmv.html?highlight=addmv#torch.addmv) | 无对应实现 | [组合实现](https://github.com/PaddlePaddle/X2Paddle/tree/develop/docs/pytorch_project_convertor/API_docs/ops/torch.addmv.md) |
| 141 | [torch.addr](https://pytorch.org/docs/stable/generated/torch.addr.html?highlight=addr#torch.addr) | 无对应实现 | [组合实现](https://github.com/PaddlePaddle/X2Paddle/tree/develop/docs/pytorch_project_convertor/API_docs/ops/torch.addr.md) |
| 142 | [torch.baddbmm](https://pytorch.org/docs/stable/generated/torch.baddbmm.html?highlight=baddbmm#torch.baddbmm) | 无对应实现 | [组合实现](https://github.com/PaddlePaddle/X2Paddle/tree/develop/docs/pytorch_project_convertor/API_docs/ops/torch.baddbmm.md) |
| 143 | [torch.bmm](https://pytorch.org/docs/stable/generated/torch.bmm.html?highlight=bmm#torch.bmm) | [paddle.bmm](https://www.paddlepaddle.org.cn/documentation/docs/zh/api/paddle/bmm_cn.html) | 功能一致,参数名不一致,PaddlePaddle未定义`out`参数代表输出Tensor |
***持续更新...***
......@@ -5,7 +5,9 @@ torch.addmv(input, mat, vec, beta=1, alpha=1, out=None)
```
### 功能介绍
用于实现矩阵(`mat`)与向量(`vec`)相乘,再加上输入(`input`),PaddlePaddle目前无对应API,可使用如下代码组合实现该API。
用于实现矩阵(`mat`)与向量(`vec`)相乘,再加上输入(`input`),公式为:
$ out = β * input + α * (mat @ vec) $
PaddlePaddle目前无对应API,可使用如下代码组合实现该API。
```python
import paddle
......
## torch.addr
### [torch.addr](https://pytorch.org/docs/stable/generated/torch.addr.html?highlight=addr#torch.addr)
```python
torch.addr(input, vec1, vec2, beta=1, alpha=1, out=None)
```
### 功能介绍
用于实现矩阵(`vec`)与向量(`vec`)相乘,再加上输入(`input`),公式为:
$out = β * input + α * (vec1 ⊗ vec2)$
PaddlePaddle目前无对应API,可使用如下代码组合实现该API。
```python
import paddle
def addr(input, vec1, vec2, beta=1, alpha=1, out=None):
row = vec1.shape[0]
column = vec2.shape[0]
vec1 = paddle.unsqueeze(vec1, 0)
vec1 = paddle.transpose(vec1, [1, 0])
vec1 = paddle.expand(vec1, [row, column])
new_vec2 = paddle.zeros([column, column], dtype=vec2.dtype)
new_vec2[0, :] = vec2
out = alpha * paddle.matmul(vec1, new_vec2)
out = beta * input + out
return out
```
\ No newline at end of file
## torch.baddbmm
### [torch.baddbmm](https://pytorch.org/docs/stable/generated/torch.baddbmm.html?highlight=baddbmm#torch.baddbmm)
```python
torch.baddbmm(input, batch1, batch2, beta=1, alpha=1, out=None)
```
### 功能介绍
用于实现Tensor(大小为$b×n×m$)与用于实现Tensor(大小为$b×m×p$)相乘,再加上输入(`input`),公式为:
$out_i = β * input_i + α * (batch1_i @ batch2_i)$
PaddlePaddle目前无对应API,可使用如下代码组合实现该API。
```python
def addr(input, vec1, vec2, beta=1, alpha=1, out=None):
row = vec1.shape[0]
column = vec2.shape[0]
vec1 = paddle.unsqueeze(vec1, 0)
vec1 = paddle.transpose(vec1, [1, 0])
vec1 = paddle.expand(vec1, [row, column])
new_vec2 = paddle.zeros([column, column], dtype=vec2.dtype)
new_vec2[0, :] = vec2
out = alpha * paddle.matmul(vec1, new_vec2)
out = beta * input + out
return out
```
\ No newline at end of file
## [CRAFT-pytorch](https://github.com/clovaai/CRAFT-pytorch)
### 准备工作
``` shell
# 下载项目
git clone https://github.com/clovaai/CRAFT-pytorch.git
cd CRAFT-pytorch
git checkout e332dd8b718e291f51b66ff8f9ef2c98ee4474c8
```
模型与数据可根据[原repo](https://github.com/clovaai/CRAFT-pytorch#test-instruction-using-pretrained-model)相关信息进行下载,可将模型存放于新建文件夹`./weight`,将数据存放于新建文件夹`./data`
### 转换
``` shell
cd ../
x2paddle --convert_torch_project --project_dir=CRAFT-pytorch --save_dir=paddle_project --pretrain_model=CRAFT-pytorch/weights/
```
### 运行训练代码
``` shell
cd paddle_project
python test.py --trained_model=weights/craft_mlt_25k.pdiparams --test_folder=data
```
***转换后的代码可在[这里](https://github.com/SunAhong1993/CRAFT-pytorch/tree/paddle)进行查看。***
......@@ -149,7 +149,7 @@ def kaiming_normal_(param, a=0, mode='fan_in', nonlinearity='leaky_relu'):
dtype=param.dtype,
default_initializer=KaimingNormal(
a=a, mode=mode, nonlinearity=nonlinearity))
paddle.assign(param, replaced_param)
paddle.assign(replaced_param, param)
class XavierNormal(XavierInitializer):
......@@ -213,7 +213,7 @@ def xavier_normal_(param, gain=1.0):
shape=param.shape,
dtype=param.dtype,
default_initializer=XavierNormal(gain=gain))
paddle.assign(param, replaced_param)
paddle.assign(replaced_param, param)
class XavierUniform(XavierInitializer):
......@@ -278,7 +278,7 @@ def xavier_uniform_(param, gain=1.0):
shape=param.shape,
dtype=param.dtype,
default_initializer=XavierUniform(gain=gain))
paddle.assign(param, replaced_param)
paddle.assign(replaced_param, param)
def constant_init_(param, val):
......@@ -287,7 +287,7 @@ def constant_init_(param, val):
dtype=param.dtype,
default_initializer=paddle.nn.initializer.Assign(
paddle.full(param.shape, val, param.dtype)))
paddle.assign(param, replaced_param)
paddle.assign(replaced_param, param)
def normal_init_(param, mean=0.0, std=1.0):
......@@ -297,7 +297,7 @@ def normal_init_(param, mean=0.0, std=1.0):
default_initializer=paddle.nn.initializer.Assign(
paddle.normal(
mean=mean, std=std, shape=param.shape)))
paddle.assign(param, replaced_param)
paddle.assign(replaced_param, param)
def ones_init_(param):
......@@ -306,7 +306,7 @@ def ones_init_(param):
dtype=param.dtype,
default_initializer=paddle.nn.initializer.Assign(
paddle.ones(param.shape, param.dtype)))
paddle.assign(param, replaced_param)
paddle.assign(replaced_param, param)
def zeros_init_(param):
......@@ -315,4 +315,4 @@ def zeros_init_(param):
dtype=param.dtype,
default_initializer=paddle.nn.initializer.Assign(
paddle.zeros(param.shape, param.dtype)))
paddle.assign(param, replaced_param)
paddle.assign(replaced_param, param)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册