Commit 22f7b9ab authored by qiaolongfei

Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into rnn

......@@ -20,26 +20,42 @@ def main():
adam_optimizer = paddle.optimizer.Adam(learning_rate=0.01)
trainer = paddle.trainer.SGD(cost=cost,
parameters=parameters,
update_equation=adam_optimizer)
def event_handler(event):
if isinstance(event, paddle.event.EndIteration):
if event.batch_id % 100 == 0:
print "Pass %d, Batch %d, Cost %f, %s" % (
event.pass_id, event.batch_id, event.cost, event.metrics)
if event.batch_id % 1000 == 0:
result = trainer.test(reader=paddle.reader.batched(
paddle.dataset.mnist.test(), batch_size=256))
print "Pass %d, Batch %d, Cost %f, %s, Testing metrics %s" % (
event.pass_id, event.batch_id, event.cost, event.metrics,
result.metrics)
else:
pass
trainer = paddle.trainer.SGD(update_equation=adam_optimizer)
trainer.train(
reader=paddle.reader.batched(
paddle.reader.shuffle(
paddle.dataset.mnist.train(), buf_size=8192),
batch_size=32),
cost=cost,
event_handler=event_handler)
# output is a softmax layer. It returns probabilities.
# Shape should be (100, 10)
probs = paddle.infer(
output=inference,
parameters=parameters,
event_handler=event_handler,
reader_dict={images.name: 0,
label.name: 1})
reader=paddle.reader.batched(
paddle.reader.firstn(
paddle.reader.map_readers(lambda item: (item[0], ),
paddle.dataset.mnist.test()),
n=100),
batch_size=32))
print probs.shape
if __name__ == '__main__':
......
......@@ -10,6 +10,7 @@
usage/cmd_parameter/index_cn.rst
usage/concepts/use_concepts_cn.rst
usage/cluster/cluster_train_cn.md
usage/k8s/k8s_basis_cn.md
usage/k8s/k8s_cn.md
usage/k8s/k8s_distributed_cn.md
......
......@@ -6,7 +6,7 @@
In this article, we explain how to run distributed Paddle training jobs on a cluster. We will take the [recommendation system](https://github.com/baidu/Paddle/tree/develop/demo/recommendation) demo as an example and create a distributed version of its single-process training.
The [scripts](https://github.com/baidu/Paddle/tree/develop/paddle/scripts/cluster_train) used in this article launch distributed jobs via SSH. They can also serve as a reference for users running more sophisticated cluster management systems such as MPI and [Kubernetes](https://github.com/PaddlePaddle/Paddle/tree/develop/doc/howto/usage/cluster/k8s).
The [scripts](https://github.com/baidu/Paddle/tree/develop/paddle/scripts/cluster_train) used in this article launch distributed jobs via SSH. They can also serve as a reference for users running more sophisticated cluster management systems such as MPI and [Kubernetes](https://github.com/PaddlePaddle/Paddle/tree/develop/doc/howto/usage/k8s).
## Prerequisites
......
......@@ -2,7 +2,7 @@
In this article, we explain how to run distributed Paddle training jobs on clusters. We will create the distributed version of the single-process training example, [recommendation](https://github.com/baidu/Paddle/tree/develop/demo/recommendation).
[Scripts](https://github.com/baidu/Paddle/tree/develop/paddle/scripts/cluster_train) used in this article launch distributed jobs via SSH. They also work as a reference for users running more sophisticated cluster management systems like MPI and [Kubernetes](https://github.com/PaddlePaddle/Paddle/tree/develop/doc/howto/usage/cluster/k8s).
[Scripts](https://github.com/baidu/Paddle/tree/develop/paddle/scripts/cluster_train) used in this article launch distributed jobs via SSH. They also work as a reference for users running more sophisticated cluster management systems like MPI and [Kubernetes](https://github.com/PaddlePaddle/Paddle/tree/develop/doc/howto/usage/k8s).
## Prerequisite
......
# Introduction to Kubernetes
[*Kubernetes*](http://kubernetes.io/) is Google's open-source container cluster management system. It provides mechanisms for application deployment, maintenance, and scaling, and makes it easy to manage containerized applications running across machines. Kubernetes can run on physical or virtual machines, and supports deployment to public clouds such as [AWS](http://kubernetes.io/docs/getting-started-guides/aws), [Azure](http://kubernetes.io/docs/getting-started-guides/azure/), and [GCE](http://kubernetes.io/docs/getting-started-guides/gce). Before covering distributed training, you need a basic understanding of [Kubernetes](http://kubernetes.io/), so the Kubernetes concepts used in this article are briefly introduced below.
- [*Node*](http://kubernetes.io/docs/admin/node/) represents a worker node in a Kubernetes cluster. A node can be a physical or a virtual machine; a Kubernetes cluster is made up of worker nodes and master nodes.
- [*Pod*](http://kubernetes.io/docs/user-guide/pods/) is a group of one or more containers and the smallest scheduling unit in Kubernetes; all containers in a pod are scheduled onto the same node. Containers in a pod share the NET, PID, IPC, UTS, and other Linux namespaces. Because they share the NET namespace, they use the same IP address and can talk to each other through *localhost*; different pods reach each other by IP address.
- [*Job*](http://kubernetes.io/docs/user-guide/jobs/) describes a batch workload running on Kubernetes; one run is called a job, and each job usually consists of one or more pods. After the job starts, it creates these pods and runs a program in each one; the job exits successfully once the program finishes and returns 0, and different retry policies can be configured for failures. (A minimal Job example is sketched right after this list.)
- [*Volume*](http://kubernetes.io/docs/user-guide/volumes/) is a storage volume: a shared directory that every container in a pod can access, and also the way containers share files with the node. Files inside a container are ephemeral; when a container is destroyed for any reason, its files disappear with it. A volume makes such files persistent. Kubernetes supports many volume types, for example hostPath (a directory on the host), gcePersistentDisk, and awsElasticBlockStore.
- [*Namespaces*](https://kubernetes.io/docs/user-guide/namespaces/): every resource object created in Kubernetes (such as the pods and jobs above) belongs to a namespace. Within a namespace resource names are unique, while resources in different namespaces may share a name; namespaces mainly group objects logically to ease management. This article only uses the default namespace.
- [*PersistentVolume*](https://kubernetes.io/docs/user-guide/persistent-volumes/): combined with a [*PersistentVolumeClaim*](https://kubernetes.io/docs/user-guide/persistent-volumes/#persistentvolumeclaims), it describes an external storage service as a uniform resource inside Kubernetes, which simplifies storage management and pod references.
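To make the Job concept concrete, here is a minimal, hypothetical example (independent of the PaddlePaddle setup discussed later); the job name `hello-job` and the busybox image are placeholders used only for illustration:
```bash
# Create a throw-away Job whose single pod prints a message and exits with 0.
cat <<EOF | kubectl create -f -
apiVersion: batch/v1
kind: Job
metadata:
  name: hello-job            # hypothetical name, for illustration only
spec:
  template:
    spec:
      containers:
      - name: hello
        image: busybox
        command: ["sh", "-c", "echo hello from a Kubernetes Job"]
      restartPolicy: Never   # Job pods are re-created by the controller, not restarted in place
EOF

# The Job controller labels the pods it creates with job-name=<job name>.
kubectl get pods -l job-name=hello-job
kubectl delete job hello-job
```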
# Deploying a Kubernetes Cluster
Kubernetes offers many cluster deployment options, which this document does not repeat. Several common deployment methods are listed here:
- [*minikube*](https://kubernetes.io/docs/getting-started-guides/minikube/): quickly starts a single-machine Kubernetes server locally, convenient for local verification and testing.
- [*kubeadm*](http://kubernetes.io/docs/getting-started-guides/kubeadm/): quickly deploys a cluster on different operating systems and hosts (Bare-Metal, AWS, GCE).
- [*AWS EC2*](https://kubernetes.io/docs/getting-started-guides/aws/): quickly deploys a cluster on AWS.
- [*Bare-Metal*](https://kubernetes.io/docs/getting-started-guides/centos/centos_manual_config/): deploys manually on physical machines.
See [this table](https://kubernetes.io/docs/getting-started-guides/#table-of-solutions) to pick the solution that fits your scenario. A quick local trial with minikube is sketched below.
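For example, a quick local trial with minikube (assuming minikube and kubectl are already installed) could look like the following sketch:
```bash
# Start a local single-node Kubernetes cluster for testing.
minikube start

# kubectl is pointed at the minikube cluster automatically; verify it.
kubectl cluster-info
kubectl get nodes

# Tear the local cluster down when finished.
minikube stop
```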
# Choosing a Storage Solution
Containers do not keep the data generated while they run: data produced by a job or an application inside a container disappears when the container is destroyed. To run a distributed machine-learning training task, an external storage service is therefore needed to hold the training data and the training output.
Common storage services to choose from include:
- [*NFS*](https://github.com/kubernetes/kubernetes/tree/master/examples/volumes/nfs): shares a directory on a disk with other machines on the network. It is simple to deploy and configure and is fine for validating with small amounts of data, but it does not provide distributed storage, high availability, or redundancy. See [here](http://www.tecmint.com/how-to-setup-nfs-server-in-linux/) for how to deploy NFS; a minimal export/mount sketch follows this list.
- [*GlusterFS*](http://gluster.readthedocs.io/en/latest/Quick-Start-Guide/Quickstart/): a networked distributed file system; it can be used in Kubernetes following [this](https://github.com/kubernetes/kubernetes/tree/master/examples/volumes/glusterfs) example.
- [*Ceph*](http://docs.ceph.com/docs/master/): a distributed file system that supports rbd, a POSIX API (CephFS), and an object storage API; see [here](https://kubernetes.io/docs/user-guide/volumes/#rbd).
- [*MooseFS*](https://moosefs.com/documentation.html): a distributed storage system. It must first be mounted on the server node and is then mounted into containers through a Kubernetes hostPath volume.
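As an illustration of the simplest option, a minimal NFS export/mount might look like the sketch below; the export path `/export/paddle` and the server address `192.168.1.10` are hypothetical, and package installation differs per distribution:
```bash
# On the NFS server: export a directory to the cluster subnet (assumed to be 192.168.1.0/24).
mkdir -p /export/paddle
echo "/export/paddle 192.168.1.0/24(rw,sync,no_subtree_check)" >> /etc/exports
exportfs -ra

# On each Kubernetes node: mount the export so containers can reach it, e.g. through a hostPath volume.
mkdir -p /mnt/paddle
mount -t nfs 192.168.1.10:/export/paddle /mnt/paddle
```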
# Configuring kubectl
## Installing kubectl
```
# OS X
curl -LO https://storage.googleapis.com/kubernetes-release/release/$(curl -s https://storage.googleapis.com/kubernetes-release/release/stable.txt)/bin/darwin/amd64/kubectl
# Linux
curl -LO https://storage.googleapis.com/kubernetes-release/release/$(curl -s https://storage.googleapis.com/kubernetes-release/release/stable.txt)/bin/linux/amd64/kubectl
# Windows
curl -LO https://storage.googleapis.com/kubernetes-release/release/$(curl -s https://storage.googleapis.com/kubernetes-release/release/stable.txt)/bin/windows/amd64/kubectl.exe
```
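After downloading, the binary still needs to be made executable and put on the PATH; on Linux or OS X this typically looks like:
```bash
chmod +x ./kubectl
sudo mv ./kubectl /usr/local/bin/kubectl

# Confirm the client works; no cluster is needed for this check.
kubectl version --client
```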
## Configuring kubectl to Access Your Kubernetes Cluster
Edit the `~/.kube/config` configuration file and change the `Master-IP` address. If SSL authentication is used, also configure `certificate-authority` and the user certificates under `users`. If the cluster is accessed without SSL (for example through port 8080), these certificate settings can be removed.
```
apiVersion: v1
clusters:
- cluster:
certificate-authority: /path/to/ca.crt
server: https://[Master-IP]:443
name: minikube
contexts:
- context:
cluster: minikube
user: minikube
name: minikube
current-context: minikube
kind: Config
preferences: {}
users:
- name: minikube
user:
client-certificate: /path/to/apiserver.crt
client-key: /path/to/apiserver.key
```
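Once the config file is in place, you can check that kubectl can actually reach the cluster, for example:
```bash
# Both commands talk to the apiserver configured in ~/.kube/config.
kubectl cluster-info
kubectl get nodes
```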
......@@ -2,168 +2,50 @@
The previous article described how to launch a single-node PaddlePaddle training job (Job) on a Kubernetes cluster. In this article, we describe how to run a distributed PaddlePaddle training job on a Kubernetes cluster. For PaddlePaddle distributed training, the article [Cluster Training](https://github.com/baidu/Paddle/blob/develop/doc/cluster/opensource/cluster_train.md) introduces an approach that distributes tasks to remote machines over SSH; in contrast, this article describes how to quickly build a PaddlePaddle container cluster on the Kubernetes container management platform and run distributed training there.
## Basic Kubernetes Concepts
[*Kubernetes*](http://kubernetes.io/) is Google's open-source container cluster management system. It provides mechanisms for application deployment, maintenance, and scaling, and makes it easy to manage containerized applications running across machines. Kubernetes can run on physical or virtual machines, and supports deployment to public clouds such as [AWS](http://kubernetes.io/docs/getting-started-guides/aws), [Azure](http://kubernetes.io/docs/getting-started-guides/azure/), and [GCE](http://kubernetes.io/docs/getting-started-guides/gce). Before covering distributed training, you need a basic understanding of [Kubernetes](http://kubernetes.io/), so the Kubernetes concepts used in this article are briefly introduced below.
- [*Node*](http://kubernetes.io/docs/admin/node/) represents a worker node in a Kubernetes cluster. A node can be a physical or a virtual machine; a Kubernetes cluster is made up of worker nodes and master nodes.
- [*Pod*](http://kubernetes.io/docs/user-guide/pods/) is a group of one or more containers and the smallest scheduling unit in Kubernetes; all containers in a pod are scheduled onto the same node. Containers in a pod share the NET, PID, IPC, UTS, and other Linux namespaces. Because they share the NET namespace, they use the same IP address and can talk to each other through *localhost*; different pods reach each other by IP address.
- [*Job*](http://kubernetes.io/docs/user-guide/jobs/) is a batch workload running on Kubernetes; one run is called a job, and each job usually consists of one or more pods.
- [*Volume*](http://kubernetes.io/docs/user-guide/volumes/) is a storage volume: a shared directory that every container in a pod can access, and also the way containers share files with the node. Files inside a container are ephemeral; when a container is destroyed for any reason, its files disappear with it. A volume makes such files persistent. Kubernetes supports many volume types, for example hostPath (a directory on the host), gcePersistentDisk, and awsElasticBlockStore.
- [*Namespaces*](https://kubernetes.io/docs/user-guide/namespaces/): every resource object created in Kubernetes (such as the pods and jobs above) belongs to a namespace. Within a namespace resource names are unique, while resources in different namespaces may share a name; namespaces mainly group objects logically to ease management. This article only uses the default namespace.
For Kubernetes concepts and for how to set up and configure a Kubernetes cluster, see [k8s_basis](./k8s_basis_cn.md).
## Overall Design
### Deploying a Kubernetes Cluster
First, we need a Kubernetes cluster in which all nodes and pods can communicate with one another. For setting up a Kubernetes cluster, see the [official documentation](http://kubernetes.io/docs/getting-started-guides/kubeadm/); a later article will also cover setting one up on AWS. This article assumes you can find a few physical machines and deploy Kubernetes on them following the official documentation. In the environment used here, every node of the Kubernetes cluster mounts an [MFS](http://moosefs.org/) (MooseFS, a distributed file system) shared directory, which we use to store the training files and the final output model. For installing and deploying MFS, see the [MooseFS documentation](https://moosefs.com/documentation.html). Before training, the user splits the configuration and training data and places them in the MFS directory; during training, the program copies files from this directory into the container, trains, and saves the results back to this directory. The overall structure is shown below:
Before training, the user splits the configuration and training data and places them in a pre-allocated directory of the distributed file system (different distributed file systems must be mounted in their own prescribed way before the data is imported); during training, the program copies files from this directory into the container, trains, and saves the results back to this directory. The overall structure is shown below:
![PaddlePaddle on Kubernetes architecture](src/k8s-paddle-arch.png)
The figure above shows a three-node distributed training scenario. Every node of the Kubernetes cluster mounts an MFS directory, which can be mounted into containers as a volume. Kubernetes creates three pods for this training run and schedules them onto three nodes; each pod contains one PaddlePaddle container. After the containers are created, they start the pserver and trainer processes and read the data in the volume to carry out this distributed training.
### Using a Job
We use the Kubernetes job concept to represent one distributed training run. A job is a one-off workload: after it completes, Kubernetes destroys the containers created by the job and releases the related resources.
In Kubernetes, a job is described by writing a YAML file. This file mainly contains configuration information such as the number of PaddlePaddle nodes, the number and values of the ports `paddle pserver` opens, and the network interface to use; these values are passed to the program inside the container as environment variables.
The figure above shows a three-node distributed training scenario. Each pod mounts a directory of the distributed file system as a volume to hold the training data and the output. Kubernetes creates three pods for this training run and schedules them onto three nodes; each pod contains one PaddlePaddle container. After the containers are created, they start the pserver and trainer processes and read the data in the volume to carry out this distributed training.
In a distributed training run, the user decides how many PaddlePaddle nodes this run needs and uploads the split training data and configuration files to the MFS shared directory. The user then writes the job's YAML file and submits it to the Kubernetes cluster to create and start the job.
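For example, assuming the job is named `paddle-cluster-job` as in the rest of this article and its description is saved in a file called `job.yaml` (the file name is illustrative), submitting and monitoring it from the shell could look like:
```bash
# Submit the job described in the YAML file to the cluster.
kubectl create -f job.yaml

# Kubernetes labels every pod created for a job with job-name=<job name>,
# so the pods of this training run can be listed and inspected like this:
kubectl get pods -l job-name=paddle-cluster-job
kubectl describe job paddle-cluster-job
```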
As described above, to run distributed PaddlePaddle training on an existing Kubernetes cluster, follow the steps below:
### Creating PaddlePaddle Nodes
When the Kubernetes master receives the request and finishes parsing the YAML file, it creates a number of pods (equal to the number of PaddlePaddle nodes) and schedules them onto the cluster's nodes. A pod represents one PaddlePaddle node; once a pod has been assigned to a physical/virtual machine, Kubernetes starts the container inside it, and the container launches the `paddle pserver` and `paddle train` processes according to the environment variables in the YAML file.
### Starting Training
After the container starts, a script launches this distributed training run. When the `paddle train` process starts it must know the IP addresses of the other nodes as well as the trainer_id of its own node, but PaddlePaddle itself provides no service-discovery-like facility, so in this article's startup script every node queries the Kubernetes apiserver, by job name, for the information of all pods belonging to this job (by default, Kubernetes writes the apiserver's address into every container's environment variables).
From this pod information, every pod can be assigned a unique trainer_id in some way. In this article we sort all pod IP addresses and use the resulting order as the trainer_id of each PaddlePaddle node. The startup script roughly works as follows (a sketch of the apiserver query in step 1 appears right after this list):
1. Query the Kubernetes apiserver for the pod information and assign trainer_id based on the IPs
1. Copy the training files from the MFS shared directory into the container
1. Parse the launch arguments of `paddle pserver` and `paddle train` from the environment variables and start the processes
1. During training, PaddlePaddle automatically saves its results on the node whose trainer_id is 0; set the output path to the MFS directory so the output files are preserved
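As a hedged sketch of step 1, a container can query the apiserver for its job's pods using only the default service-account credentials and the environment variables Kubernetes injects; `JOB_NAME` is assumed to be set through the job's YAML file, as elsewhere in this article, and the default namespace is used:
```bash
# Service-account token and CA certificate are mounted into every pod by default.
TOKEN=$(cat /var/run/secrets/kubernetes.io/serviceaccount/token)
CACERT=/var/run/secrets/kubernetes.io/serviceaccount/ca.crt

# KUBERNETES_SERVICE_HOST/PORT are injected by Kubernetes; JOB_NAME comes from the job YAML.
curl -s --cacert "$CACERT" -H "Authorization: Bearer $TOKEN" \
  "https://$KUBERNETES_SERVICE_HOST:$KUBERNETES_SERVICE_PORT/api/v1/namespaces/default/pods?labelSelector=job-name=$JOB_NAME"
```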
## Setup Walkthrough
As described above, running distributed PaddlePaddle training on an existing Kubernetes cluster consists of the following main steps:
1. Build the PaddlePaddle image
1. Upload the training files and the split data to shared storage
1. Write the YAML file for this training run and create a Kubernetes job
1. Inspect the output after training finishes
1. [Build the PaddlePaddle image](#制作镜像)
1. [Upload the training files and the split data to shared storage](#上传训练文件)
1. [Write the YAML file for this training run and create a Kubernetes job](#创建Job)
1. [Inspect the output after training finishes](#查看输出)
Each step is described below.
### Building the Image
The PaddlePaddle image must provide the runtime environment for the `paddle pserver` and `paddle train` processes. A container created from this image needs the following two capabilities:
- Copy the training files into the container
- Generate the launch arguments for the `paddle pserver` and `paddle train` processes and start training
The official image `paddledev/paddle:cpu-latest` already contains the PaddlePaddle executables but not the capabilities above, so we can build a new image on top of it that adds the startup scripts to do this work. The image's *Dockerfile* is as follows:
```Dockerfile
FROM paddledev/paddle:cpu-latest
MAINTAINER zjsxzong89@gmail.com
COPY start.sh /root/
COPY start_paddle.py /root/
CMD ["bash"," -c","/root/start.sh"]
```
The [start.sh](https://github.com/PaddlePaddle/Paddle/blob/develop/doc/howto/usage/cluster/k8s/start.sh) file copies the training files into the container and then runs the [start_paddle.py](https://github.com/PaddlePaddle/Paddle/blob/develop/doc/howto/usage/cluster/k8s/start_paddle.py) script to start training; obtaining the other nodes' IP addresses and assigning the `trainer_id`, as mentioned earlier, are all done inside the `start_paddle.py` script.
At the beginning of the `start_paddle.py` script, the arguments are initialized and parsed.
```python
parser = argparse.ArgumentParser(prog="start_paddle.py",
description='simple tool for k8s')
args, train_args_list = parser.parse_known_args()
train_args = refine_unknown_args(train_args_list)
train_args_dict = dict(zip(train_args[:-1:2], train_args[1::2]))
podlist = getPodList()
```
Then the function `getPodList()` calls the Kubernetes API to query the information of all pods belonging to this job. Once all pods are in the running state (i.e., all containers are running), the function `getIdMap(podlist)` is used to obtain the trainer_id.
```python
podlist = getPodList()
# need to wait until all pods are running
while not isPodAllRunning(podlist):
time.sleep(10)
podlist = getPodList()
idMap = getIdMap(podlist)
```
Inside `getIdMap(podlist)`, we read the IP address of every pod in `podlist` and use each IP's index in the sorted order as its trainer_id.
```python
def getIdMap(podlist):
'''
generate trainer_id by ip
'''
ips = []
for pod in podlist["items"]:
ips.append(pod["status"]["podIP"])
ips.sort()
idMap = {}
for i in range(len(ips)):
idMap[ips[i]] = i
return idMap
```
After `idMap` is obtained, the function `startPaddle(idMap, train_args_dict)` builds the launch arguments for `paddle pserver` and `paddle train` and starts the processes.
The main work inside `startPaddle` is to derive the launch arguments for `paddle pserver` and `paddle train`. Taking `paddle train` as an example, the environment variables are parsed to obtain `PADDLE_NIC`, `PADDLE_PORT`, `PADDLE_PORTS_NUM`, and other parameters, and the node's own IP address is then used to look up its `trainerId` in `idMap`.
```python
program = 'paddle train'
args = " --nics=" + PADDLE_NIC
args += " --port=" + str(PADDLE_PORT)
args += " --ports_num=" + str(PADDLE_PORTS_NUM)
args += " --comment=" + "paddle_process_by_paddle"
ip_string = ""
for ip in idMap.keys():
ip_string += (ip + ",")
ip_string = ip_string.rstrip(",")
args += " --pservers=" + ip_string
args_ext = ""
for key, value in train_args_dict.items():
args_ext += (' --' + key + '=' + value)
localIP = socket.gethostbyname(socket.gethostname())
trainerId = idMap[localIP]
args += " " + args_ext + " --trainer_id=" + \
str(trainerId) + " --save_dir=" + JOB_PATH_OUTPUT
```
Build the image with `docker build`:
The official image `paddledev/paddle:cpu-latest` already contains the PaddlePaddle executables but not the capabilities above, so we can build a new image on top of it that adds the startup scripts to do this work. See the reference image's [*Dockerfile*](https://github.com/PaddlePaddle/Paddle/blob/develop/doc/howto/usage/cluster/k8s/src/k8s_train/Dockerfile).
```bash
docker build -t your_repo/paddle:mypaddle .
$ cd doc/howto/usage/k8s/src/k8s_train
$ docker build -t [YOUR_REPO]/paddle:mypaddle .
```
Then push the successfully built image to an image registry.
```bash
docker push your_repo/paddle:mypaddle
docker push [YOUR_REPO]/paddle:mypaddle
```
Note that `your_repo` in the commands above stands for the Docker registry you use; replace it with your own registry address. Below, `your_repo/paddle:mypaddle` refers to the image built in this step.
Note that `[YOUR_REPO]` in the commands above stands for the Docker registry you use; replace it with your own registry address. Below, `[YOUR_REPO]/paddle:mypaddle` refers to the image built in this step.
### Uploading the Training Files
This article uses PaddlePaddle's official [recommendation demo](http://www.paddlepaddle.org/doc/demo/index.html#recommendation) as the training content. We put the training files and data in a directory named after the job name and upload it to the MFS shared storage. Afterwards, the files on MFS look roughly like this:
This article uses PaddlePaddle's official [recommendation demo](http://www.paddlepaddle.org/doc/demo/index.html#recommendation) as the training content. We put the training files and data in a directory named after the job name and upload it to the shared storage backing the volume (different distributed storage systems are mounted differently; mount the directory first and then copy the data in). Afterwards, the files in the volume look roughly like this:
```bash
[root@paddle-kubernetes-node0 mfs]# tree -d
......@@ -205,7 +87,7 @@ spec:
path: /home/work/mfs
containers:
- name: trainer
image: your_repo/paddle:mypaddle
image: [YOUR_REPO]/paddle:mypaddle
command: ["bin/bash", "-c", "/root/start.sh"]
env:
- name: JOB_NAME
......@@ -289,8 +171,8 @@ I1116 09:10:17.123121 50 Util.cpp:155] commandline:
--ports_num=2 --comment=paddle_process_by_paddle
--pservers=192.168.129.66,192.168.223.143,192.168.129.71
--ports_num_for_sparse=2 --config=./trainer_config.py
--trainer_count=4 --num_passes=10 --use_gpu=0
--log_period=50 --dot_period=10 --saving_period=1
--trainer_count=4 --num_passes=10 --use_gpu=0
--log_period=50 --dot_period=10 --saving_period=1
--local=0 --trainer_id=0
--save_dir=/home/jobpath/paddle-cluster-job/output
I1116 09:10:17.123440 50 Util.cpp:130] Calling runInitFunctions
......@@ -310,3 +192,90 @@ I1116 09:10:18.019492 50 ParameterClient2.cpp:122] pserver 3 192.168.223.143:
I1116 09:10:18.019716 50 ParameterClient2.cpp:122] pserver 4 192.168.129.71:7164
I1116 09:10:18.019836 50 ParameterClient2.cpp:122] pserver 5 192.168.129.71:7165
```
## Additional Details
### Using Environment Variables
A Kubernetes Job that runs a training task in containers usually passes the Job's configuration in through environment variables. `start_paddle.py` provides a startup script that converts these environment variables into paddle command-line arguments:
```
API = "/api/v1/namespaces/"
JOBSELECTOR = "labelSelector=job-name="
JOB_PATH = os.getenv("JOB_PATH") + "/" + os.getenv("JOB_NAME")
JOB_PATH_OUTPUT = JOB_PATH + "/output"
JOBNAME = os.getenv("JOB_NAME")
NAMESPACE = os.getenv("JOB_NAMESPACE")
PADDLE_NIC = os.getenv("CONF_PADDLE_NIC")
PADDLE_PORT = os.getenv("CONF_PADDLE_PORT")
PADDLE_PORTS_NUM = os.getenv("CONF_PADDLE_PORTS_NUM")
PADDLE_PORTS_NUM_SPARSE = os.getenv("CONF_PADDLE_PORTS_NUM_SPARSE")
PADDLE_SERVER_NUM = os.getenv("CONF_PADDLE_GRADIENT_NUM")
```
### Communication Between Pods
At the beginning of the `start_paddle.py` script, the arguments are initialized and parsed.
```python
parser = argparse.ArgumentParser(prog="start_paddle.py",
description='simple tool for k8s')
args, train_args_list = parser.parse_known_args()
train_args = refine_unknown_args(train_args_list)
train_args_dict = dict(zip(train_args[:-1:2], train_args[1::2]))
podlist = getPodList()
```
Then the function `getPodList()` calls the Kubernetes API to query the information of all pods belonging to this job. Once all pods are in the running state (i.e., all containers are running), the function `getIdMap(podlist)` is used to obtain the trainer_id.
```python
podlist = getPodList()
# need to wait until all pods are running
while not isPodAllRunning(podlist):
time.sleep(10)
podlist = getPodList()
idMap = getIdMap(podlist)
```
* *Note*: `getPodList()` fetches all pods in the current namespace, so if other pods are already running it may produce errors. This way of managing cluster nodes will be replaced by [StatefulSets](https://kubernetes.io/docs/concepts/abstractions/controllers/statefulsets/) in the future.
Inside `getIdMap(podlist)`, we read the IP address of every pod in `podlist` and use each IP's index in the sorted order as its trainer_id.
```python
def getIdMap(podlist):
'''
generate trainer_id by ip
'''
ips = []
for pod in podlist["items"]:
ips.append(pod["status"]["podIP"])
ips.sort()
idMap = {}
for i in range(len(ips)):
idMap[ips[i]] = i
return idMap
```
After `idMap` is obtained, the function `startPaddle(idMap, train_args_dict)` builds the launch arguments for `paddle pserver` and `paddle train` and starts the processes.
### Launching the Task
The main work inside `startPaddle` is to derive the launch arguments for `paddle pserver` and `paddle train`. Taking `paddle train` as an example, the environment variables are parsed to obtain `PADDLE_NIC`, `PADDLE_PORT`, `PADDLE_PORTS_NUM`, and other parameters, and the node's own IP address is then used to look up its `trainerId` in `idMap`.
```python
program = 'paddle train'
args = " --nics=" + PADDLE_NIC
args += " --port=" + str(PADDLE_PORT)
args += " --ports_num=" + str(PADDLE_PORTS_NUM)
args += " --comment=" + "paddle_process_by_paddle"
ip_string = ""
for ip in idMap.keys():
ip_string += (ip + ",")
ip_string = ip_string.rstrip(",")
args += " --pservers=" + ip_string
args_ext = ""
for key, value in train_args_dict.items():
args_ext += (' --' + key + '=' + value)
localIP = socket.gethostbyname(socket.gethostname())
trainerId = idMap[localIP]
args += " " + args_ext + " --trainer_id=" + \
str(trainerId) + " --save_dir=" + JOB_PATH_OUTPUT
```
......@@ -47,6 +47,9 @@ void setUseGpu(bool useGpu);
/// Return true if this py_paddle is compiled in GPU Version
bool isGpuVersion();
/// Return FLAGS_trainer_count
int getTrainerCount();
/// The Error of IO Operation. Such as file not found, etc.
class IOError {};
......
......@@ -54,5 +54,7 @@ bool isGpuVersion() {
#endif
}
int getTrainerCount() { return FLAGS_trainer_count; }
static_assert(NUM_PARAMETER_TYPES == paddle::NUM_PARAMETER_TYPES,
"The Parameter Type should be same in core/api and core/common");
......@@ -26,6 +26,15 @@ class IScanner(object):
if not isinstance(self.input_type, dp2.InputType):
raise ValueError("input type should be dataprovider2.InputType")
self.pos = pos
# data_in_gpu is used to indicate whether to create argument on GPU
# or not in GPU mode. Now if using one thread (trainer_count=1),
# trainer uses NeuralNetwork which needs to create argument on GPU
# before calling forward function. So, set data_in_gpu to True.
# Otherwise, trainer uses MultiGradientMachine which will transfer
# data from CPU to GPU in the forward function, set data_in_gpu to
# False in this case.
self.data_in_gpu = swig_paddle.isUsingGpu(
) and swig_paddle.getTrainerCount() == 1
def scan(self, dat):
pass
......@@ -53,7 +62,8 @@ class DenseScanner(IScanner):
assert isinstance(argument, swig_paddle.Arguments)
if self.__mat__.dtype != numpy.float32:
self.__mat__ = self.__mat__.astype(numpy.float32)
m = swig_paddle.Matrix.createDenseFromNumpy(self.__mat__, True, False)
m = swig_paddle.Matrix.createDenseFromNumpy(self.__mat__, True,
self.data_in_gpu)
argument.setSlotValue(self.pos, m)
......@@ -75,10 +85,13 @@ class SparseBinaryScanner(IScanner):
def finish_scan(self, argument):
assert isinstance(argument, swig_paddle.Arguments)
m = swig_paddle.Matrix.createSparse(self.__height__,
self.input_type.dim,
len(self.__cols__),
len(self.__value__) == 0)
m = swig_paddle.Matrix.createSparse(
self.__height__,
self.input_type.dim,
len(self.__cols__),
len(self.__value__) == 0,
False, # trans
False) # TODO support GPU
assert isinstance(m, swig_paddle.Matrix)
m.sparseCopyFrom(self.__rows__, self.__cols__, self.__value__)
argument.setSlotValue(self.pos, m)
......@@ -102,7 +115,7 @@ class IndexScanner(IScanner):
self.__ids__.append(dat)
def finish_scan(self, argument):
ids = swig_paddle.IVector.create(self.__ids__)
ids = swig_paddle.IVector.create(self.__ids__, self.data_in_gpu)
assert isinstance(argument, swig_paddle.Arguments)
argument.setSlotIds(self.pos, ids)
......
......@@ -5,38 +5,50 @@ ARG DEBIAN_FRONTEND=noninteractive
ARG UBUNTU_MIRROR
RUN /bin/bash -c 'if [[ -n ${UBUNTU_MIRROR} ]]; then sed -i 's#http://archive.ubuntu.com#${UBUNTU_MIRROR}#g' /etc/apt/sources.list; fi'
# ENV variables
ARG BUILD_WOBOQ
ARG BUILD_AND_INSTALL
ARG WITH_AVX
ARG WITH_DOC
ARG WITH_STYLE_CHECK
ENV BUILD_WOBOQ=${BUILD_WOBOQ:-OFF}
ENV BUILD_AND_INSTALL=${BUILD_AND_INSTALL:-OFF}
ENV WITH_GPU=OFF
ENV WITH_AVX=${WITH_AVX:-ON}
ENV WITH_DOC=${WITH_DOC:-OFF}
ENV WITH_STYLE_CHECK=${WITH_STYLE_CHECK:-OFF}
ENV HOME /root
# Add bash enhancements
COPY ./paddle/scripts/docker/root/ /root/
RUN apt-get update && \
apt-get install -y git python-pip python-dev openssh-server bison && \
apt-get install -y wget unzip tar xz-utils bzip2 gzip coreutils && \
apt-get install -y curl sed grep graphviz libjpeg-dev zlib1g-dev && \
apt-get install -y python-numpy python-matplotlib gcc g++ gfortran && \
apt-get install -y automake && \
apt-get install -y automake locales clang-format-3.8 && \
apt-get clean -y
# git credential to skip password typing
RUN git config --global credential.helper store
# Fix locales to en_US.UTF-8
RUN localedef -i en_US -f UTF-8 en_US.UTF-8
RUN pip install --upgrade pip && \
pip install -U "protobuf==3.1.0" && \
pip install -U 'protobuf==3.1.0' && \
pip install -U wheel pillow BeautifulSoup && \
pip install -U docopt PyYAML sphinx && \
pip install -U sphinx_rtd_theme recommonmark jupyter
pip install -U sphinx_rtd_theme recommonmark && \
pip install -U pre-commit 'requests==2.9.2' jupyter
RUN curl -sSL https://cmake.org/files/v3.4/cmake-3.4.1.tar.gz | tar -xz && \
cd cmake-3.4.1 && ./bootstrap && make -j `nproc` && make install && \
cd .. && rm -rf cmake-3.4.1
ARG BUILD_WOBOQ
ARG BUILD_AND_INSTALL
ARG WITH_AVX
ARG WITH_DOC
ARG WITH_STYLE_CHECK
ENV BUILD_WOBOQ=${BUILD_WOBOQ:-OFF}
ENV BUILD_AND_INSTALL=${BUILD_AND_INSTALL:-OFF}
ENV WITH_GPU=OFF
ENV WITH_AVX=${WITH_AVX:-ON}
ENV WITH_DOC=${WITH_DOC:-OFF}
ENV WITH_STYLE_CHECK=${WITH_STYLE_CHECK:-OFF}
RUN mkdir /paddle
COPY . /paddle/
RUN /paddle/paddle/scripts/docker/build.sh
VOLUME ["/usr/share/nginx/html/data", "/usr/share/nginx/html/paddle"]
......@@ -53,7 +65,6 @@ RUN mkdir /notes/
WORKDIR "/notes"
EXPOSE 8888
RUN mkdir -p /opt/bin
COPY ./paddle/scripts/docker/entrypoint /opt/bin/
CMD ["/opt/bin/entrypoint"]
......@@ -5,38 +5,50 @@ ARG DEBIAN_FRONTEND=noninteractive
ARG UBUNTU_MIRROR
RUN /bin/bash -c 'if [[ -n ${UBUNTU_MIRROR} ]]; then sed -i 's#http://archive.ubuntu.com#${UBUNTU_MIRROR}#g' /etc/apt/sources.list; fi'
# ENV variables
ARG BUILD_WOBOQ
ARG BUILD_AND_INSTALL
ARG WITH_AVX
ARG WITH_DOC
ARG WITH_STYLE_CHECK
ENV BUILD_WOBOQ=${BUILD_WOBOQ:-OFF}
ENV BUILD_AND_INSTALL=${BUILD_AND_INSTALL:-OFF}
ENV WITH_GPU=ON
ENV WITH_AVX=${WITH_AVX:-ON}
ENV WITH_DOC=${WITH_DOC:-OFF}
ENV WITH_STYLE_CHECK=${WITH_STYLE_CHECK:-OFF}
ENV HOME /root
# Add bash enhancements
COPY ./paddle/scripts/docker/root/ /root/
RUN apt-get update && \
apt-get install -y git python-pip python-dev openssh-server bison && \
apt-get install -y wget unzip tar xz-utils bzip2 gzip coreutils && \
apt-get install -y curl sed grep graphviz libjpeg-dev zlib1g-dev && \
apt-get install -y python-numpy python-matplotlib gcc g++ gfortran && \
apt-get install -y automake && \
apt-get install -y automake locales clang-format-3.8 && \
apt-get clean -y
# git credential to skip password typing
RUN git config --global credential.helper store
# Fix locales to en_US.UTF-8
RUN localedef -i en_US -f UTF-8 en_US.UTF-8
RUN pip install --upgrade pip && \
pip install -U "protobuf==3.1.0" && \
pip install -U 'protobuf==3.1.0' && \
pip install -U wheel pillow BeautifulSoup && \
pip install -U docopt PyYAML sphinx && \
pip install -U sphinx_rtd_theme recommonmark jupyter
pip install -U sphinx_rtd_theme recommonmark && \
pip install -U pre-commit 'requests==2.9.2' jupyter
RUN curl -sSL https://cmake.org/files/v3.4/cmake-3.4.1.tar.gz | tar -xz && \
cd cmake-3.4.1 && ./bootstrap && make -j `nproc` && make install && \
cd .. && rm -rf cmake-3.4.1
ARG BUILD_WOBOQ
ARG BUILD_AND_INSTALL
ARG WITH_AVX
ARG WITH_DOC
ARG WITH_STYLE_CHECK
ENV BUILD_WOBOQ=${BUILD_WOBOQ:-OFF}
ENV BUILD_AND_INSTALL=${BUILD_AND_INSTALL:-OFF}
ENV WITH_GPU=ON
ENV WITH_AVX=${WITH_AVX:-ON}
ENV WITH_DOC=${WITH_DOC:-OFF}
ENV WITH_STYLE_CHECK=${WITH_STYLE_CHECK:-OFF}
RUN mkdir /paddle
COPY . /paddle/
RUN /paddle/paddle/scripts/docker/build.sh
VOLUME ["/usr/share/nginx/html/data", "/usr/share/nginx/html/paddle"]
......@@ -53,7 +65,6 @@ RUN mkdir /notes/
WORKDIR "/notes"
EXPOSE 8888
RUN mkdir -p /opt/bin
COPY ./paddle/scripts/docker/entrypoint /opt/bin/
CMD ["/opt/bin/entrypoint"]
# Locales
export LC_ALL=en_US.UTF-8
export LANG=en_US.UTF-8
export LANGUAGE=en_US.UTF-8
# Aliases
alias rm='rm -i'
alias cp='cp -i'
alias mv='mv -i'
alias ls='ls -hFG'
alias l='ls -lF'
alias ll='ls -alF'
alias lt='ls -ltrF'
alias ll='ls -alF'
alias lls='ls -alSrF'
alias llt='ls -altrF'
# Colorize directory listing
alias ls="ls -ph --color=auto"
# Colorize grep
if echo hello|grep --color=auto l >/dev/null 2>&1; then
export GREP_OPTIONS="--color=auto" GREP_COLOR="1;31"
fi
# Shell
export CLICOLOR="1"
YELLOW="\[\033[1;33m\]"
NO_COLOUR="\[\033[0m\]"
GREEN="\[\033[1;32m\]"
WHITE="\[\033[1;37m\]"
source ~/.scripts/git-prompt.sh
export PS1="\[\033[1;33m\]λ $WHITE\h $GREEN\w$YELLOW\$(__git_ps1 \" \[\033[35m\]{\[\033[36m\]%s\[\033[35m\]}\")$NO_COLOUR "
# Git
source ~/.scripts/git-completion.sh
[user]
name =
email =
[alias]
st = status --branch --short
ci = commit
br = branch
co = checkout
df = diff
l = log --pretty=format:\"%h %ad | %s%d [%an]\" --graph --date=short
ll = log --stat
[merge]
tool = vimdiff
[core]
excludesfile = ~/.gitignore
editor = vim
[color]
branch = auto
diff = auto
status = auto
[color "branch"]
current = yellow reverse
local = yellow
remote = green
[color "diff"]
meta = yellow bold
frag = magenta bold
old = red bold
new = green bold
[color "status"]
added = yellow
changed = green
untracked = cyan
[push]
default = matching
\ No newline at end of file
#!bash
#
# bash/zsh completion support for core Git.
#
# Copyright (C) 2006,2007 Shawn O. Pearce <spearce@spearce.org>
# Conceptually based on gitcompletion (http://gitweb.hawaga.org.uk/).
# Distributed under the GNU General Public License, version 2.0.
#
# The contained completion routines provide support for completing:
#
# *) local and remote branch names
# *) local and remote tag names
# *) .git/remotes file names
# *) git 'subcommands'
# *) tree paths within 'ref:path/to/file' expressions
# *) file paths within current working directory and index
# *) common --long-options
#
# To use these routines:
#
# 1) Copy this file to somewhere (e.g. ~/.git-completion.sh).
# 2) Add the following line to your .bashrc/.zshrc:
# source ~/.git-completion.sh
# 3) Consider changing your PS1 to also show the current branch,
# see git-prompt.sh for details.
case "$COMP_WORDBREAKS" in
*:*) : great ;;
*) COMP_WORDBREAKS="$COMP_WORDBREAKS:"
esac
# __gitdir accepts 0 or 1 arguments (i.e., location)
# returns location of .git repo
__gitdir ()
{
if [ -z "${1-}" ]; then
if [ -n "${__git_dir-}" ]; then
echo "$__git_dir"
elif [ -n "${GIT_DIR-}" ]; then
test -d "${GIT_DIR-}" || return 1
echo "$GIT_DIR"
elif [ -d .git ]; then
echo .git
else
git rev-parse --git-dir 2>/dev/null
fi
elif [ -d "$1/.git" ]; then
echo "$1/.git"
else
echo "$1"
fi
}
# The following function is based on code from:
#
# bash_completion - programmable completion functions for bash 3.2+
#
# Copyright © 2006-2008, Ian Macdonald <ian@caliban.org>
# © 2009-2010, Bash Completion Maintainers
# <bash-completion-devel@lists.alioth.debian.org>
#
# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation; either version 2, or (at your option)
# any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program; if not, write to the Free Software Foundation,
# Inc., 59 Temple Place - Suite 330, Boston, MA 02111-1307, USA.
#
# The latest version of this software can be obtained here:
#
# http://bash-completion.alioth.debian.org/
#
# RELEASE: 2.x
# This function can be used to access a tokenized list of words
# on the command line:
#
# __git_reassemble_comp_words_by_ref '=:'
# if test "${words_[cword_-1]}" = -w
# then
# ...
# fi
#
# The argument should be a collection of characters from the list of
# word completion separators (COMP_WORDBREAKS) to treat as ordinary
# characters.
#
# This is roughly equivalent to going back in time and setting
# COMP_WORDBREAKS to exclude those characters. The intent is to
# make option types like --date=<type> and <rev>:<path> easy to
# recognize by treating each shell word as a single token.
#
# It is best not to set COMP_WORDBREAKS directly because the value is
# shared with other completion scripts. By the time the completion
# function gets called, COMP_WORDS has already been populated so local
# changes to COMP_WORDBREAKS have no effect.
#
# Output: words_, cword_, cur_.
__git_reassemble_comp_words_by_ref()
{
local exclude i j first
# Which word separators to exclude?
exclude="${1//[^$COMP_WORDBREAKS]}"
cword_=$COMP_CWORD
if [ -z "$exclude" ]; then
words_=("${COMP_WORDS[@]}")
return
fi
# List of word completion separators has shrunk;
# re-assemble words to complete.
for ((i=0, j=0; i < ${#COMP_WORDS[@]}; i++, j++)); do
# Append each nonempty word consisting of just
# word separator characters to the current word.
first=t
while
[ $i -gt 0 ] &&
[ -n "${COMP_WORDS[$i]}" ] &&
# word consists of excluded word separators
[ "${COMP_WORDS[$i]//[^$exclude]}" = "${COMP_WORDS[$i]}" ]
do
# Attach to the previous token,
# unless the previous token is the command name.
if [ $j -ge 2 ] && [ -n "$first" ]; then
((j--))
fi
first=
words_[$j]=${words_[j]}${COMP_WORDS[i]}
if [ $i = $COMP_CWORD ]; then
cword_=$j
fi
if (($i < ${#COMP_WORDS[@]} - 1)); then
((i++))
else
# Done.
return
fi
done
words_[$j]=${words_[j]}${COMP_WORDS[i]}
if [ $i = $COMP_CWORD ]; then
cword_=$j
fi
done
}
if ! type _get_comp_words_by_ref >/dev/null 2>&1; then
_get_comp_words_by_ref ()
{
local exclude cur_ words_ cword_
if [ "$1" = "-n" ]; then
exclude=$2
shift 2
fi
__git_reassemble_comp_words_by_ref "$exclude"
cur_=${words_[cword_]}
while [ $# -gt 0 ]; do
case "$1" in
cur)
cur=$cur_
;;
prev)
prev=${words_[$cword_-1]}
;;
words)
words=("${words_[@]}")
;;
cword)
cword=$cword_
;;
esac
shift
done
}
fi
__gitcompadd ()
{
local i=0
for x in $1; do
if [[ "$x" == "$3"* ]]; then
COMPREPLY[i++]="$2$x$4"
fi
done
}
# Generates completion reply, appending a space to possible completion words,
# if necessary.
# It accepts 1 to 4 arguments:
# 1: List of possible completion words.
# 2: A prefix to be added to each possible completion word (optional).
# 3: Generate possible completion matches for this word (optional).
# 4: A suffix to be appended to each possible completion word (optional).
__gitcomp ()
{
local cur_="${3-$cur}"
case "$cur_" in
--*=)
;;
*)
local c i=0 IFS=$' \t\n'
for c in $1; do
c="$c${4-}"
if [[ $c == "$cur_"* ]]; then
case $c in
--*=*|*.) ;;
*) c="$c " ;;
esac
COMPREPLY[i++]="${2-}$c"
fi
done
;;
esac
}
# Generates completion reply from newline-separated possible completion words
# by appending a space to all of them.
# It accepts 1 to 4 arguments:
# 1: List of possible completion words, separated by a single newline.
# 2: A prefix to be added to each possible completion word (optional).
# 3: Generate possible completion matches for this word (optional).
# 4: A suffix to be appended to each possible completion word instead of
# the default space (optional). If specified but empty, nothing is
# appended.
__gitcomp_nl ()
{
local IFS=$'\n'
__gitcompadd "$1" "${2-}" "${3-$cur}" "${4- }"
}
# Generates completion reply with compgen from newline-separated possible
# completion filenames.
# It accepts 1 to 3 arguments:
# 1: List of possible completion filenames, separated by a single newline.
# 2: A directory prefix to be added to each possible completion filename
# (optional).
# 3: Generate possible completion matches for this word (optional).
__gitcomp_file ()
{
local IFS=$'\n'
# XXX does not work when the directory prefix contains a tilde,
# since tilde expansion is not applied.
# This means that COMPREPLY will be empty and Bash default
# completion will be used.
__gitcompadd "$1" "${2-}" "${3-$cur}" ""
# use a hack to enable file mode in bash < 4
compopt -o filenames +o nospace 2>/dev/null ||
compgen -f /non-existing-dir/ > /dev/null
}
# Execute 'git ls-files', unless the --committable option is specified, in
# which case it runs 'git diff-index' to find out the files that can be
# committed. It returns paths relative to the directory specified in the first
# argument, and using the options specified in the second argument.
__git_ls_files_helper ()
{
(
test -n "${CDPATH+set}" && unset CDPATH
cd "$1"
if [ "$2" == "--committable" ]; then
git diff-index --name-only --relative HEAD
else
# NOTE: $2 is not quoted in order to support multiple options
git ls-files --exclude-standard $2
fi
) 2>/dev/null
}
# __git_index_files accepts 1 or 2 arguments:
# 1: Options to pass to ls-files (required).
# 2: A directory path (optional).
# If provided, only files within the specified directory are listed.
# Sub directories are never recursed. Path must have a trailing
# slash.
__git_index_files ()
{
local dir="$(__gitdir)" root="${2-.}" file
if [ -d "$dir" ]; then
__git_ls_files_helper "$root" "$1" |
while read -r file; do
case "$file" in
?*/*) echo "${file%%/*}" ;;
*) echo "$file" ;;
esac
done | sort | uniq
fi
}
__git_heads ()
{
local dir="$(__gitdir)"
if [ -d "$dir" ]; then
git --git-dir="$dir" for-each-ref --format='%(refname:short)' \
refs/heads
return
fi
}
__git_tags ()
{
local dir="$(__gitdir)"
if [ -d "$dir" ]; then
git --git-dir="$dir" for-each-ref --format='%(refname:short)' \
refs/tags
return
fi
}
# __git_refs accepts 0, 1 (to pass to __gitdir), or 2 arguments
# presence of 2nd argument means use the guess heuristic employed
# by checkout for tracking branches
__git_refs ()
{
local i hash dir="$(__gitdir "${1-}")" track="${2-}"
local format refs
if [ -d "$dir" ]; then
case "$cur" in
refs|refs/*)
format="refname"
refs="${cur%/*}"
track=""
;;
*)
for i in HEAD FETCH_HEAD ORIG_HEAD MERGE_HEAD; do
if [ -e "$dir/$i" ]; then echo $i; fi
done
format="refname:short"
refs="refs/tags refs/heads refs/remotes"
;;
esac
git --git-dir="$dir" for-each-ref --format="%($format)" \
$refs
if [ -n "$track" ]; then
# employ the heuristic used by git checkout
# Try to find a remote branch that matches the completion word
# but only output if the branch name is unique
local ref entry
git --git-dir="$dir" for-each-ref --shell --format="ref=%(refname:short)" \
"refs/remotes/" | \
while read -r entry; do
eval "$entry"
ref="${ref#*/}"
if [[ "$ref" == "$cur"* ]]; then
echo "$ref"
fi
done | sort | uniq -u
fi
return
fi
case "$cur" in
refs|refs/*)
git ls-remote "$dir" "$cur*" 2>/dev/null | \
while read -r hash i; do
case "$i" in
*^{}) ;;
*) echo "$i" ;;
esac
done
;;
*)
echo "HEAD"
git for-each-ref --format="%(refname:short)" -- "refs/remotes/$dir/" | sed -e "s#^$dir/##"
;;
esac
}
# __git_refs2 requires 1 argument (to pass to __git_refs)
__git_refs2 ()
{
local i
for i in $(__git_refs "$1"); do
echo "$i:$i"
done
}
# __git_refs_remotes requires 1 argument (to pass to ls-remote)
__git_refs_remotes ()
{
local i hash
git ls-remote "$1" 'refs/heads/*' 2>/dev/null | \
while read -r hash i; do
echo "$i:refs/remotes/$1/${i#refs/heads/}"
done
}
__git_remotes ()
{
local i IFS=$'\n' d="$(__gitdir)"
test -d "$d/remotes" && ls -1 "$d/remotes"
for i in $(git --git-dir="$d" config --get-regexp 'remote\..*\.url' 2>/dev/null); do
i="${i#remote.}"
echo "${i/.url*/}"
done
}
__git_list_merge_strategies ()
{
git merge -s help 2>&1 |
sed -n -e '/[Aa]vailable strategies are: /,/^$/{
s/\.$//
s/.*://
s/^[ ]*//
s/[ ]*$//
p
}'
}
__git_merge_strategies=
# 'git merge -s help' (and thus detection of the merge strategy
# list) fails, unfortunately, if run outside of any git working
# tree. __git_merge_strategies is set to the empty string in
# that case, and the detection will be repeated the next time it
# is needed.
__git_compute_merge_strategies ()
{
test -n "$__git_merge_strategies" ||
__git_merge_strategies=$(__git_list_merge_strategies)
}
__git_complete_revlist_file ()
{
local pfx ls ref cur_="$cur"
case "$cur_" in
*..?*:*)
return
;;
?*:*)
ref="${cur_%%:*}"
cur_="${cur_#*:}"
case "$cur_" in
?*/*)
pfx="${cur_%/*}"
cur_="${cur_##*/}"
ls="$ref:$pfx"
pfx="$pfx/"
;;
*)
ls="$ref"
;;
esac
case "$COMP_WORDBREAKS" in
*:*) : great ;;
*) pfx="$ref:$pfx" ;;
esac
__gitcomp_nl "$(git --git-dir="$(__gitdir)" ls-tree "$ls" 2>/dev/null \
| sed '/^100... blob /{
s,^.* ,,
s,$, ,
}
/^120000 blob /{
s,^.* ,,
s,$, ,
}
/^040000 tree /{
s,^.* ,,
s,$,/,
}
s/^.* //')" \
"$pfx" "$cur_" ""
;;
*...*)
pfx="${cur_%...*}..."
cur_="${cur_#*...}"
__gitcomp_nl "$(__git_refs)" "$pfx" "$cur_"
;;
*..*)
pfx="${cur_%..*}.."
cur_="${cur_#*..}"
__gitcomp_nl "$(__git_refs)" "$pfx" "$cur_"
;;
*)
__gitcomp_nl "$(__git_refs)"
;;
esac
}
# __git_complete_index_file requires 1 argument:
# 1: the options to pass to ls-file
#
# The exception is --committable, which finds the files appropriate for commit.
__git_complete_index_file ()
{
local pfx="" cur_="$cur"
case "$cur_" in
?*/*)
pfx="${cur_%/*}"
cur_="${cur_##*/}"
pfx="${pfx}/"
;;
esac
__gitcomp_file "$(__git_index_files "$1" "$pfx")" "$pfx" "$cur_"
}
__git_complete_file ()
{
__git_complete_revlist_file
}
__git_complete_revlist ()
{
__git_complete_revlist_file
}
__git_complete_remote_or_refspec ()
{
local cur_="$cur" cmd="${words[1]}"
local i c=2 remote="" pfx="" lhs=1 no_complete_refspec=0
if [ "$cmd" = "remote" ]; then
((c++))
fi
while [ $c -lt $cword ]; do
i="${words[c]}"
case "$i" in
--mirror) [ "$cmd" = "push" ] && no_complete_refspec=1 ;;
--all)
case "$cmd" in
push) no_complete_refspec=1 ;;
fetch)
return
;;
*) ;;
esac
;;
-*) ;;
*) remote="$i"; break ;;
esac
((c++))
done
if [ -z "$remote" ]; then
__gitcomp_nl "$(__git_remotes)"
return
fi
if [ $no_complete_refspec = 1 ]; then
return
fi
[ "$remote" = "." ] && remote=
case "$cur_" in
*:*)
case "$COMP_WORDBREAKS" in
*:*) : great ;;
*) pfx="${cur_%%:*}:" ;;
esac
cur_="${cur_#*:}"
lhs=0
;;
+*)
pfx="+"
cur_="${cur_#+}"
;;
esac
case "$cmd" in
fetch)
if [ $lhs = 1 ]; then
__gitcomp_nl "$(__git_refs2 "$remote")" "$pfx" "$cur_"
else
__gitcomp_nl "$(__git_refs)" "$pfx" "$cur_"
fi
;;
pull|remote)
if [ $lhs = 1 ]; then
__gitcomp_nl "$(__git_refs "$remote")" "$pfx" "$cur_"
else
__gitcomp_nl "$(__git_refs)" "$pfx" "$cur_"
fi
;;
push)
if [ $lhs = 1 ]; then
__gitcomp_nl "$(__git_refs)" "$pfx" "$cur_"
else
__gitcomp_nl "$(__git_refs "$remote")" "$pfx" "$cur_"
fi
;;
esac
}
__git_complete_strategy ()
{
__git_compute_merge_strategies
case "$prev" in
-s|--strategy)
__gitcomp "$__git_merge_strategies"
return 0
esac
case "$cur" in
--strategy=*)
__gitcomp "$__git_merge_strategies" "" "${cur##--strategy=}"
return 0
;;
esac
return 1
}
__git_commands () {
if test -n "${GIT_TESTING_COMMAND_COMPLETION:-}"
then
printf "%s" "${GIT_TESTING_COMMAND_COMPLETION}"
else
git help -a|egrep '^ [a-zA-Z0-9]'
fi
}
__git_list_all_commands ()
{
local i IFS=" "$'\n'
for i in $(__git_commands)
do
case $i in
*--*) : helper pattern;;
*) echo $i;;
esac
done
}
__git_all_commands=
__git_compute_all_commands ()
{
test -n "$__git_all_commands" ||
__git_all_commands=$(__git_list_all_commands)
}
__git_list_porcelain_commands ()
{
local i IFS=" "$'\n'
__git_compute_all_commands
for i in $__git_all_commands
do
case $i in
*--*) : helper pattern;;
applymbox) : ask gittus;;
applypatch) : ask gittus;;
archimport) : import;;
cat-file) : plumbing;;
check-attr) : plumbing;;
check-ignore) : plumbing;;
check-mailmap) : plumbing;;
check-ref-format) : plumbing;;
checkout-index) : plumbing;;
commit-tree) : plumbing;;
count-objects) : infrequent;;
credential-cache) : credentials helper;;
credential-store) : credentials helper;;
cvsexportcommit) : export;;
cvsimport) : import;;
cvsserver) : daemon;;
daemon) : daemon;;
diff-files) : plumbing;;
diff-index) : plumbing;;
diff-tree) : plumbing;;
fast-import) : import;;
fast-export) : export;;
fsck-objects) : plumbing;;
fetch-pack) : plumbing;;
fmt-merge-msg) : plumbing;;
for-each-ref) : plumbing;;
hash-object) : plumbing;;
http-*) : transport;;
index-pack) : plumbing;;
init-db) : deprecated;;
local-fetch) : plumbing;;
lost-found) : infrequent;;
ls-files) : plumbing;;
ls-remote) : plumbing;;
ls-tree) : plumbing;;
mailinfo) : plumbing;;
mailsplit) : plumbing;;
merge-*) : plumbing;;
mktree) : plumbing;;
mktag) : plumbing;;
pack-objects) : plumbing;;
pack-redundant) : plumbing;;
pack-refs) : plumbing;;
parse-remote) : plumbing;;
patch-id) : plumbing;;
peek-remote) : plumbing;;
prune) : plumbing;;
prune-packed) : plumbing;;
quiltimport) : import;;
read-tree) : plumbing;;
receive-pack) : plumbing;;
remote-*) : transport;;
repo-config) : deprecated;;
rerere) : plumbing;;
rev-list) : plumbing;;
rev-parse) : plumbing;;
runstatus) : plumbing;;
sh-setup) : internal;;
shell) : daemon;;
show-ref) : plumbing;;
send-pack) : plumbing;;
show-index) : plumbing;;
ssh-*) : transport;;
stripspace) : plumbing;;
symbolic-ref) : plumbing;;
tar-tree) : deprecated;;
unpack-file) : plumbing;;
unpack-objects) : plumbing;;
update-index) : plumbing;;
update-ref) : plumbing;;
update-server-info) : daemon;;
upload-archive) : plumbing;;
upload-pack) : plumbing;;
write-tree) : plumbing;;
var) : infrequent;;
verify-pack) : infrequent;;
verify-tag) : plumbing;;
*) echo $i;;
esac
done
}
__git_porcelain_commands=
__git_compute_porcelain_commands ()
{
__git_compute_all_commands
test -n "$__git_porcelain_commands" ||
__git_porcelain_commands=$(__git_list_porcelain_commands)
}
__git_pretty_aliases ()
{
local i IFS=$'\n'
for i in $(git --git-dir="$(__gitdir)" config --get-regexp "pretty\..*" 2>/dev/null); do
case "$i" in
pretty.*)
i="${i#pretty.}"
echo "${i/ */}"
;;
esac
done
}
__git_aliases ()
{
local i IFS=$'\n'
for i in $(git --git-dir="$(__gitdir)" config --get-regexp "alias\..*" 2>/dev/null); do
case "$i" in
alias.*)
i="${i#alias.}"
echo "${i/ */}"
;;
esac
done
}
# __git_aliased_command requires 1 argument
__git_aliased_command ()
{
local word cmdline=$(git --git-dir="$(__gitdir)" \
config --get "alias.$1")
for word in $cmdline; do
case "$word" in
\!gitk|gitk)
echo "gitk"
return
;;
\!*) : shell command alias ;;
-*) : option ;;
*=*) : setting env ;;
git) : git itself ;;
*)
echo "$word"
return
esac
done
}
# __git_find_on_cmdline requires 1 argument
__git_find_on_cmdline ()
{
local word subcommand c=1
while [ $c -lt $cword ]; do
word="${words[c]}"
for subcommand in $1; do
if [ "$subcommand" = "$word" ]; then
echo "$subcommand"
return
fi
done
((c++))
done
}
__git_has_doubledash ()
{
local c=1
while [ $c -lt $cword ]; do
if [ "--" = "${words[c]}" ]; then
return 0
fi
((c++))
done
return 1
}
# Try to count non option arguments passed on the command line for the
# specified git command.
# When options are used, it is necessary to use the special -- option to
# tell the implementation where non option arguments begin.
# XXX this can not be improved, since options can appear everywhere, as
# an example:
# git mv x -n y
#
# __git_count_arguments requires 1 argument: the git command executed.
__git_count_arguments ()
{
local word i c=0
# Skip "git" (first argument)
for ((i=1; i < ${#words[@]}; i++)); do
word="${words[i]}"
case "$word" in
--)
# Good; we can assume that the following are only non
# option arguments.
((c = 0))
;;
"$1")
# Skip the specified git command and discard git
# main options
((c = 0))
;;
?*)
((c++))
;;
esac
done
printf "%d" $c
}
__git_whitespacelist="nowarn warn error error-all fix"
_git_am ()
{
local dir="$(__gitdir)"
if [ -d "$dir"/rebase-apply ]; then
__gitcomp "--skip --continue --resolved --abort"
return
fi
case "$cur" in
--whitespace=*)
__gitcomp "$__git_whitespacelist" "" "${cur##--whitespace=}"
return
;;
--*)
__gitcomp "
--3way --committer-date-is-author-date --ignore-date
--ignore-whitespace --ignore-space-change
--interactive --keep --no-utf8 --signoff --utf8
--whitespace= --scissors
"
return
esac
}
_git_apply ()
{
case "$cur" in
--whitespace=*)
__gitcomp "$__git_whitespacelist" "" "${cur##--whitespace=}"
return
;;
--*)
__gitcomp "
--stat --numstat --summary --check --index
--cached --index-info --reverse --reject --unidiff-zero
--apply --no-add --exclude=
--ignore-whitespace --ignore-space-change
--whitespace= --inaccurate-eof --verbose
"
return
esac
}
_git_add ()
{
case "$cur" in
--*)
__gitcomp "
--interactive --refresh --patch --update --dry-run
--ignore-errors --intent-to-add
"
return
esac
# XXX should we check for --update and --all options ?
__git_complete_index_file "--others --modified"
}
_git_archive ()
{
case "$cur" in
--format=*)
__gitcomp "$(git archive --list)" "" "${cur##--format=}"
return
;;
--remote=*)
__gitcomp_nl "$(__git_remotes)" "" "${cur##--remote=}"
return
;;
--*)
__gitcomp "
--format= --list --verbose
--prefix= --remote= --exec=
"
return
;;
esac
__git_complete_file
}
_git_bisect ()
{
__git_has_doubledash && return
local subcommands="start bad good skip reset visualize replay log run"
local subcommand="$(__git_find_on_cmdline "$subcommands")"
if [ -z "$subcommand" ]; then
if [ -f "$(__gitdir)"/BISECT_START ]; then
__gitcomp "$subcommands"
else
__gitcomp "replay start"
fi
return
fi
case "$subcommand" in
bad|good|reset|skip|start)
__gitcomp_nl "$(__git_refs)"
;;
*)
;;
esac
}
_git_branch ()
{
local i c=1 only_local_ref="n" has_r="n"
while [ $c -lt $cword ]; do
i="${words[c]}"
case "$i" in
-d|-m) only_local_ref="y" ;;
-r) has_r="y" ;;
esac
((c++))
done
case "$cur" in
--set-upstream-to=*)
__gitcomp "$(__git_refs)" "" "${cur##--set-upstream-to=}"
;;
--*)
__gitcomp "
--color --no-color --verbose --abbrev= --no-abbrev
--track --no-track --contains --merged --no-merged
--set-upstream-to= --edit-description --list
--unset-upstream
"
;;
*)
if [ $only_local_ref = "y" -a $has_r = "n" ]; then
__gitcomp_nl "$(__git_heads)"
else
__gitcomp_nl "$(__git_refs)"
fi
;;
esac
}
_git_bundle ()
{
local cmd="${words[2]}"
case "$cword" in
2)
__gitcomp "create list-heads verify unbundle"
;;
3)
# looking for a file
;;
*)
case "$cmd" in
create)
__git_complete_revlist
;;
esac
;;
esac
}
_git_checkout ()
{
__git_has_doubledash && return
case "$cur" in
--conflict=*)
__gitcomp "diff3 merge" "" "${cur##--conflict=}"
;;
--*)
__gitcomp "
--quiet --ours --theirs --track --no-track --merge
--conflict= --orphan --patch
"
;;
*)
# check if --track, --no-track, or --no-guess was specified
# if so, disable DWIM mode
local flags="--track --no-track --no-guess" track=1
if [ -n "$(__git_find_on_cmdline "$flags")" ]; then
track=''
fi
__gitcomp_nl "$(__git_refs '' $track)"
;;
esac
}
_git_cherry ()
{
__gitcomp "$(__git_refs)"
}
_git_cherry_pick ()
{
local dir="$(__gitdir)"
if [ -f "$dir"/CHERRY_PICK_HEAD ]; then
__gitcomp "--continue --quit --abort"
return
fi
case "$cur" in
--*)
__gitcomp "--edit --no-commit --signoff --strategy= --mainline"
;;
*)
__gitcomp_nl "$(__git_refs)"
;;
esac
}
_git_clean ()
{
case "$cur" in
--*)
__gitcomp "--dry-run --quiet"
return
;;
esac
# XXX should we check for -x option ?
__git_complete_index_file "--others"
}
_git_clone ()
{
case "$cur" in
--*)
__gitcomp "
--local
--no-hardlinks
--shared
--reference
--quiet
--no-checkout
--bare
--mirror
--origin
--upload-pack
--template=
--depth
--single-branch
--branch
"
return
;;
esac
}
_git_commit ()
{
case "$prev" in
-c|-C)
__gitcomp_nl "$(__git_refs)" "" "${cur}"
return
;;
esac
case "$cur" in
--cleanup=*)
__gitcomp "default strip verbatim whitespace
" "" "${cur##--cleanup=}"
return
;;
--reuse-message=*|--reedit-message=*|\
--fixup=*|--squash=*)
__gitcomp_nl "$(__git_refs)" "" "${cur#*=}"
return
;;
--untracked-files=*)
__gitcomp "all no normal" "" "${cur##--untracked-files=}"
return
;;
--*)
__gitcomp "
--all --author= --signoff --verify --no-verify
--edit --no-edit
--amend --include --only --interactive
--dry-run --reuse-message= --reedit-message=
--reset-author --file= --message= --template=
--cleanup= --untracked-files --untracked-files=
--verbose --quiet --fixup= --squash=
"
return
esac
if git rev-parse --verify --quiet HEAD >/dev/null; then
__git_complete_index_file "--committable"
else
# This is the first commit
__git_complete_index_file "--cached"
fi
}
_git_describe ()
{
case "$cur" in
--*)
__gitcomp "
--all --tags --contains --abbrev= --candidates=
--exact-match --debug --long --match --always
"
return
esac
__gitcomp_nl "$(__git_refs)"
}
__git_diff_algorithms="myers minimal patience histogram"
__git_diff_common_options="--stat --numstat --shortstat --summary
--patch-with-stat --name-only --name-status --color
--no-color --color-words --no-renames --check
--full-index --binary --abbrev --diff-filter=
--find-copies-harder
--text --ignore-space-at-eol --ignore-space-change
--ignore-all-space --exit-code --quiet --ext-diff
--no-ext-diff
--no-prefix --src-prefix= --dst-prefix=
--inter-hunk-context=
--patience --histogram --minimal
--raw --word-diff
--dirstat --dirstat= --dirstat-by-file
--dirstat-by-file= --cumulative
--diff-algorithm=
"
_git_diff ()
{
__git_has_doubledash && return
case "$cur" in
--diff-algorithm=*)
__gitcomp "$__git_diff_algorithms" "" "${cur##--diff-algorithm=}"
return
;;
--*)
__gitcomp "--cached --staged --pickaxe-all --pickaxe-regex
--base --ours --theirs --no-index
$__git_diff_common_options
"
return
;;
esac
__git_complete_revlist_file
}
__git_mergetools_common="diffuse ecmerge emerge kdiff3 meld opendiff
tkdiff vimdiff gvimdiff xxdiff araxis p4merge bc3 codecompare
"
_git_difftool ()
{
__git_has_doubledash && return
case "$cur" in
--tool=*)
__gitcomp "$__git_mergetools_common kompare" "" "${cur##--tool=}"
return
;;
--*)
__gitcomp "--cached --staged --pickaxe-all --pickaxe-regex
--base --ours --theirs
--no-renames --diff-filter= --find-copies-harder
--relative --ignore-submodules
--tool="
return
;;
esac
__git_complete_revlist_file
}
__git_fetch_options="
--quiet --verbose --append --upload-pack --force --keep --depth=
--tags --no-tags --all --prune --dry-run
"
_git_fetch ()
{
case "$cur" in
--*)
__gitcomp "$__git_fetch_options"
return
;;
esac
__git_complete_remote_or_refspec
}
__git_format_patch_options="
--stdout --attach --no-attach --thread --thread= --no-thread
--numbered --start-number --numbered-files --keep-subject --signoff
--signature --no-signature --in-reply-to= --cc= --full-index --binary
--not --all --cover-letter --no-prefix --src-prefix= --dst-prefix=
--inline --suffix= --ignore-if-in-upstream --subject-prefix=
--output-directory --reroll-count --to= --quiet --notes
"
_git_format_patch ()
{
case "$cur" in
--thread=*)
__gitcomp "
deep shallow
" "" "${cur##--thread=}"
return
;;
--*)
__gitcomp "$__git_format_patch_options"
return
;;
esac
__git_complete_revlist
}
_git_fsck ()
{
case "$cur" in
--*)
__gitcomp "
--tags --root --unreachable --cache --no-reflogs --full
--strict --verbose --lost-found
"
return
;;
esac
}
_git_gc ()
{
case "$cur" in
--*)
__gitcomp "--prune --aggressive"
return
;;
esac
}
_git_gitk ()
{
_gitk
}
__git_match_ctag() {
awk "/^${1////\\/}/ { print \$1 }" "$2"
}
_git_grep ()
{
__git_has_doubledash && return
case "$cur" in
--*)
__gitcomp "
--cached
--text --ignore-case --word-regexp --invert-match
--full-name --line-number
--extended-regexp --basic-regexp --fixed-strings
--perl-regexp
--files-with-matches --name-only
--files-without-match
--max-depth
--count
--and --or --not --all-match
"
return
;;
esac
case "$cword,$prev" in
2,*|*,-*)
if test -r tags; then
__gitcomp_nl "$(__git_match_ctag "$cur" tags)"
return
fi
;;
esac
__gitcomp_nl "$(__git_refs)"
}
_git_help ()
{
case "$cur" in
--*)
__gitcomp "--all --info --man --web"
return
;;
esac
__git_compute_all_commands
__gitcomp "$__git_all_commands $(__git_aliases)
attributes cli core-tutorial cvs-migration
diffcore gitk glossary hooks ignore modules
namespaces repository-layout tutorial tutorial-2
workflows
"
}
_git_init ()
{
case "$cur" in
--shared=*)
__gitcomp "
false true umask group all world everybody
" "" "${cur##--shared=}"
return
;;
--*)
__gitcomp "--quiet --bare --template= --shared --shared="
return
;;
esac
}
_git_ls_files ()
{
case "$cur" in
--*)
__gitcomp "--cached --deleted --modified --others --ignored
--stage --directory --no-empty-directory --unmerged
--killed --exclude= --exclude-from=
--exclude-per-directory= --exclude-standard
--error-unmatch --with-tree= --full-name
--abbrev --ignored --exclude-per-directory
"
return
;;
esac
# XXX ignore options like --modified and always suggest all cached
# files.
__git_complete_index_file "--cached"
}
_git_ls_remote ()
{
__gitcomp_nl "$(__git_remotes)"
}
_git_ls_tree ()
{
__git_complete_file
}
# Options that go well for log, shortlog and gitk
__git_log_common_options="
--not --all
--branches --tags --remotes
--first-parent --merges --no-merges
--max-count=
--max-age= --since= --after=
--min-age= --until= --before=
--min-parents= --max-parents=
--no-min-parents --no-max-parents
"
# Options that go well for log and gitk (not shortlog)
__git_log_gitk_options="
--dense --sparse --full-history
--simplify-merges --simplify-by-decoration
--left-right --notes --no-notes
"
# Options that go well for log and shortlog (not gitk)
__git_log_shortlog_options="
--author= --committer= --grep=
--all-match
"
__git_log_pretty_formats="oneline short medium full fuller email raw format:"
__git_log_date_formats="relative iso8601 rfc2822 short local default raw"
_git_log ()
{
__git_has_doubledash && return
local g="$(git rev-parse --git-dir 2>/dev/null)"
local merge=""
if [ -f "$g/MERGE_HEAD" ]; then
merge="--merge"
fi
case "$cur" in
--pretty=*|--format=*)
__gitcomp "$__git_log_pretty_formats $(__git_pretty_aliases)
" "" "${cur#*=}"
return
;;
--date=*)
__gitcomp "$__git_log_date_formats" "" "${cur##--date=}"
return
;;
--decorate=*)
__gitcomp "long short" "" "${cur##--decorate=}"
return
;;
--*)
__gitcomp "
$__git_log_common_options
$__git_log_shortlog_options
$__git_log_gitk_options
--root --topo-order --date-order --reverse
--follow --full-diff
--abbrev-commit --abbrev=
--relative-date --date=
--pretty= --format= --oneline
--cherry-pick
--graph
--decorate --decorate=
--walk-reflogs
--parents --children
$merge
$__git_diff_common_options
--pickaxe-all --pickaxe-regex
"
return
;;
esac
__git_complete_revlist
}
__git_merge_options="
--no-commit --no-stat --log --no-log --squash --strategy
--commit --stat --no-squash --ff --no-ff --ff-only --edit --no-edit
"
_git_merge ()
{
__git_complete_strategy && return
case "$cur" in
--*)
__gitcomp "$__git_merge_options"
return
esac
__gitcomp_nl "$(__git_refs)"
}
_git_mergetool ()
{
case "$cur" in
--tool=*)
__gitcomp "$__git_mergetools_common tortoisemerge" "" "${cur##--tool=}"
return
;;
--*)
__gitcomp "--tool="
return
;;
esac
}
_git_merge_base ()
{
__gitcomp_nl "$(__git_refs)"
}
_git_mv ()
{
case "$cur" in
--*)
__gitcomp "--dry-run"
return
;;
esac
if [ $(__git_count_arguments "mv") -gt 0 ]; then
# We need to show both cached and untracked files (including
# empty directories) since this may not be the last argument.
__git_complete_index_file "--cached --others --directory"
else
__git_complete_index_file "--cached"
fi
}
_git_name_rev ()
{
__gitcomp "--tags --all --stdin"
}
_git_notes ()
{
local subcommands='add append copy edit list prune remove show'
local subcommand="$(__git_find_on_cmdline "$subcommands")"
case "$subcommand,$cur" in
,--*)
__gitcomp '--ref'
;;
,*)
case "$prev" in
--ref)
__gitcomp_nl "$(__git_refs)"
;;
*)
__gitcomp "$subcommands --ref"
;;
esac
;;
add,--reuse-message=*|append,--reuse-message=*|\
add,--reedit-message=*|append,--reedit-message=*)
__gitcomp_nl "$(__git_refs)" "" "${cur#*=}"
;;
add,--*|append,--*)
__gitcomp '--file= --message= --reedit-message=
--reuse-message='
;;
copy,--*)
__gitcomp '--stdin'
;;
prune,--*)
__gitcomp '--dry-run --verbose'
;;
prune,*)
;;
*)
case "$prev" in
-m|-F)
;;
*)
__gitcomp_nl "$(__git_refs)"
;;
esac
;;
esac
}
_git_pull ()
{
__git_complete_strategy && return
case "$cur" in
--*)
__gitcomp "
--rebase --no-rebase
$__git_merge_options
$__git_fetch_options
"
return
;;
esac
__git_complete_remote_or_refspec
}
_git_push ()
{
case "$prev" in
--repo)
__gitcomp_nl "$(__git_remotes)"
return
esac
case "$cur" in
--repo=*)
__gitcomp_nl "$(__git_remotes)" "" "${cur##--repo=}"
return
;;
--*)
__gitcomp "
--all --mirror --tags --dry-run --force --verbose
--receive-pack= --repo= --set-upstream
"
return
;;
esac
__git_complete_remote_or_refspec
}
_git_rebase ()
{
local dir="$(__gitdir)"
if [ -d "$dir"/rebase-apply ] || [ -d "$dir"/rebase-merge ]; then
__gitcomp "--continue --skip --abort"
return
fi
__git_complete_strategy && return
case "$cur" in
--whitespace=*)
__gitcomp "$__git_whitespacelist" "" "${cur##--whitespace=}"
return
;;
--*)
__gitcomp "
--onto --merge --strategy --interactive
--preserve-merges --stat --no-stat
--committer-date-is-author-date --ignore-date
--ignore-whitespace --whitespace=
--autosquash
"
return
esac
__gitcomp_nl "$(__git_refs)"
}
_git_reflog ()
{
local subcommands="show delete expire"
local subcommand="$(__git_find_on_cmdline "$subcommands")"
if [ -z "$subcommand" ]; then
__gitcomp "$subcommands"
else
__gitcomp_nl "$(__git_refs)"
fi
}
__git_send_email_confirm_options="always never auto cc compose"
__git_send_email_suppresscc_options="author self cc bodycc sob cccmd body all"
_git_send_email ()
{
case "$cur" in
--confirm=*)
__gitcomp "
$__git_send_email_confirm_options
" "" "${cur##--confirm=}"
return
;;
--suppress-cc=*)
__gitcomp "
$__git_send_email_suppresscc_options
" "" "${cur##--suppress-cc=}"
return
;;
--smtp-encryption=*)
__gitcomp "ssl tls" "" "${cur##--smtp-encryption=}"
return
;;
--thread=*)
__gitcomp "
deep shallow
" "" "${cur##--thread=}"
return
;;
--*)
__gitcomp "--annotate --bcc --cc --cc-cmd --chain-reply-to
--compose --confirm= --dry-run --envelope-sender
--from --identity
--in-reply-to --no-chain-reply-to --no-signed-off-by-cc
--no-suppress-from --no-thread --quiet
--signed-off-by-cc --smtp-pass --smtp-server
--smtp-server-port --smtp-encryption= --smtp-user
--subject --suppress-cc= --suppress-from --thread --to
--validate --no-validate
$__git_format_patch_options"
return
;;
esac
__git_complete_revlist
}
_git_stage ()
{
_git_add
}
__git_config_get_set_variables ()
{
local prevword word config_file= c=$cword
while [ $c -gt 1 ]; do
word="${words[c]}"
case "$word" in
--system|--global|--local|--file=*)
config_file="$word"
break
;;
-f|--file)
config_file="$word $prevword"
break
;;
esac
prevword=$word
c=$((--c))
done
git --git-dir="$(__gitdir)" config $config_file --list 2>/dev/null |
while read -r line
do
case "$line" in
*.*=*)
echo "${line/=*/}"
;;
esac
done
}
_git_config ()
{
case "$prev" in
branch.*.remote|branch.*.pushremote)
__gitcomp_nl "$(__git_remotes)"
return
;;
branch.*.merge)
__gitcomp_nl "$(__git_refs)"
return
;;
branch.*.rebase)
__gitcomp "false true"
return
;;
remote.pushdefault)
__gitcomp_nl "$(__git_remotes)"
return
;;
remote.*.fetch)
local remote="${prev#remote.}"
remote="${remote%.fetch}"
if [ -z "$cur" ]; then
__gitcomp_nl "refs/heads/" "" "" ""
return
fi
__gitcomp_nl "$(__git_refs_remotes "$remote")"
return
;;
remote.*.push)
local remote="${prev#remote.}"
remote="${remote%.push}"
__gitcomp_nl "$(git --git-dir="$(__gitdir)" \
for-each-ref --format='%(refname):%(refname)' \
refs/heads)"
return
;;
pull.twohead|pull.octopus)
__git_compute_merge_strategies
__gitcomp "$__git_merge_strategies"
return
;;
color.branch|color.diff|color.interactive|\
color.showbranch|color.status|color.ui)
__gitcomp "always never auto"
return
;;
color.pager)
__gitcomp "false true"
return
;;
color.*.*)
__gitcomp "
normal black red green yellow blue magenta cyan white
bold dim ul blink reverse
"
return
;;
diff.submodule)
__gitcomp "log short"
return
;;
help.format)
__gitcomp "man info web html"
return
;;
log.date)
__gitcomp "$__git_log_date_formats"
return
;;
sendemail.aliasesfiletype)
__gitcomp "mutt mailrc pine elm gnus"
return
;;
sendemail.confirm)
__gitcomp "$__git_send_email_confirm_options"
return
;;
sendemail.suppresscc)
__gitcomp "$__git_send_email_suppresscc_options"
return
;;
--get|--get-all|--unset|--unset-all)
__gitcomp_nl "$(__git_config_get_set_variables)"
return
;;
*.*)
return
;;
esac
case "$cur" in
--*)
__gitcomp "
--system --global --local --file=
--list --replace-all
--get --get-all --get-regexp
--add --unset --unset-all
--remove-section --rename-section
"
return
;;
branch.*.*)
local pfx="${cur%.*}." cur_="${cur##*.}"
__gitcomp "remote pushremote merge mergeoptions rebase" "$pfx" "$cur_"
return
;;
branch.*)
local pfx="${cur%.*}." cur_="${cur#*.}"
__gitcomp_nl "$(__git_heads)" "$pfx" "$cur_" "."
return
;;
guitool.*.*)
local pfx="${cur%.*}." cur_="${cur##*.}"
__gitcomp "
argprompt cmd confirm needsfile noconsole norescan
prompt revprompt revunmerged title
" "$pfx" "$cur_"
return
;;
difftool.*.*)
local pfx="${cur%.*}." cur_="${cur##*.}"
__gitcomp "cmd path" "$pfx" "$cur_"
return
;;
man.*.*)
local pfx="${cur%.*}." cur_="${cur##*.}"
__gitcomp "cmd path" "$pfx" "$cur_"
return
;;
mergetool.*.*)
local pfx="${cur%.*}." cur_="${cur##*.}"
__gitcomp "cmd path trustExitCode" "$pfx" "$cur_"
return
;;
pager.*)
local pfx="${cur%.*}." cur_="${cur#*.}"
__git_compute_all_commands
__gitcomp_nl "$__git_all_commands" "$pfx" "$cur_"
return
;;
remote.*.*)
local pfx="${cur%.*}." cur_="${cur##*.}"
__gitcomp "
url proxy fetch push mirror skipDefaultUpdate
receivepack uploadpack tagopt pushurl
" "$pfx" "$cur_"
return
;;
remote.*)
local pfx="${cur%.*}." cur_="${cur#*.}"
__gitcomp_nl "$(__git_remotes)" "$pfx" "$cur_" "."
return
;;
url.*.*)
local pfx="${cur%.*}." cur_="${cur##*.}"
__gitcomp "insteadOf pushInsteadOf" "$pfx" "$cur_"
return
;;
esac
__gitcomp "
add.ignoreErrors
advice.commitBeforeMerge
advice.detachedHead
advice.implicitIdentity
advice.pushNonFastForward
advice.resolveConflict
advice.statusHints
alias.
am.keepcr
apply.ignorewhitespace
apply.whitespace
branch.autosetupmerge
branch.autosetuprebase
browser.
clean.requireForce
color.branch
color.branch.current
color.branch.local
color.branch.plain
color.branch.remote
color.decorate.HEAD
color.decorate.branch
color.decorate.remoteBranch
color.decorate.stash
color.decorate.tag
color.diff
color.diff.commit
color.diff.frag
color.diff.func
color.diff.meta
color.diff.new
color.diff.old
color.diff.plain
color.diff.whitespace
color.grep
color.grep.context
color.grep.filename
color.grep.function
color.grep.linenumber
color.grep.match
color.grep.selected
color.grep.separator
color.interactive
color.interactive.error
color.interactive.header
color.interactive.help
color.interactive.prompt
color.pager
color.showbranch
color.status
color.status.added
color.status.changed
color.status.header
color.status.nobranch
color.status.untracked
color.status.updated
color.ui
commit.status
commit.template
core.abbrev
core.askpass
core.attributesfile
core.autocrlf
core.bare
core.bigFileThreshold
core.compression
core.createObject
core.deltaBaseCacheLimit
core.editor
core.eol
core.excludesfile
core.fileMode
core.fsyncobjectfiles
core.gitProxy
core.ignoreStat
core.ignorecase
core.logAllRefUpdates
core.loosecompression
core.notesRef
core.packedGitLimit
core.packedGitWindowSize
core.pager
core.preferSymlinkRefs
core.preloadindex
core.quotepath
core.repositoryFormatVersion
core.safecrlf
core.sharedRepository
core.sparseCheckout
core.symlinks
core.trustctime
core.warnAmbiguousRefs
core.whitespace
core.worktree
diff.autorefreshindex
diff.external
diff.ignoreSubmodules
diff.mnemonicprefix
diff.noprefix
diff.renameLimit
diff.renames
diff.statGraphWidth
diff.submodule
diff.suppressBlankEmpty
diff.tool
diff.wordRegex
diff.algorithm
difftool.
difftool.prompt
fetch.recurseSubmodules
fetch.unpackLimit
format.attach
format.cc
format.headers
format.numbered
format.pretty
format.signature
format.signoff
format.subjectprefix
format.suffix
format.thread
format.to
gc.
gc.aggressiveWindow
gc.auto
gc.autopacklimit
gc.packrefs
gc.pruneexpire
gc.reflogexpire
gc.reflogexpireunreachable
gc.rerereresolved
gc.rerereunresolved
gitcvs.allbinary
gitcvs.commitmsgannotation
gitcvs.dbTableNamePrefix
gitcvs.dbdriver
gitcvs.dbname
gitcvs.dbpass
gitcvs.dbuser
gitcvs.enabled
gitcvs.logfile
gitcvs.usecrlfattr
guitool.
gui.blamehistoryctx
gui.commitmsgwidth
gui.copyblamethreshold
gui.diffcontext
gui.encoding
gui.fastcopyblame
gui.matchtrackingbranch
gui.newbranchtemplate
gui.pruneduringfetch
gui.spellingdictionary
gui.trustmtime
help.autocorrect
help.browser
help.format
http.lowSpeedLimit
http.lowSpeedTime
http.maxRequests
http.minSessions
http.noEPSV
http.postBuffer
http.proxy
http.sslCAInfo
http.sslCAPath
http.sslCert
http.sslCertPasswordProtected
http.sslKey
http.sslVerify
http.useragent
i18n.commitEncoding
i18n.logOutputEncoding
imap.authMethod
imap.folder
imap.host
imap.pass
imap.port
imap.preformattedHTML
imap.sslverify
imap.tunnel
imap.user
init.templatedir
instaweb.browser
instaweb.httpd
instaweb.local
instaweb.modulepath
instaweb.port
interactive.singlekey
log.date
log.decorate
log.showroot
mailmap.file
man.
man.viewer
merge.
merge.conflictstyle
merge.log
merge.renameLimit
merge.renormalize
merge.stat
merge.tool
merge.verbosity
mergetool.
mergetool.keepBackup
mergetool.keepTemporaries
mergetool.prompt
notes.displayRef
notes.rewrite.
notes.rewrite.amend
notes.rewrite.rebase
notes.rewriteMode
notes.rewriteRef
pack.compression
pack.deltaCacheLimit
pack.deltaCacheSize
pack.depth
pack.indexVersion
pack.packSizeLimit
pack.threads
pack.window
pack.windowMemory
pager.
pretty.
pull.octopus
pull.twohead
push.default
rebase.autosquash
rebase.stat
receive.autogc
receive.denyCurrentBranch
receive.denyDeleteCurrent
receive.denyDeletes
receive.denyNonFastForwards
receive.fsckObjects
receive.unpackLimit
receive.updateserverinfo
remote.pushdefault
remotes.
repack.usedeltabaseoffset
rerere.autoupdate
rerere.enabled
sendemail.
sendemail.aliasesfile
sendemail.aliasfiletype
sendemail.bcc
sendemail.cc
sendemail.cccmd
sendemail.chainreplyto
sendemail.confirm
sendemail.envelopesender
sendemail.from
sendemail.identity
sendemail.multiedit
sendemail.signedoffbycc
sendemail.smtpdomain
sendemail.smtpencryption
sendemail.smtppass
sendemail.smtpserver
sendemail.smtpserveroption
sendemail.smtpserverport
sendemail.smtpuser
sendemail.suppresscc
sendemail.suppressfrom
sendemail.thread
sendemail.to
sendemail.validate
showbranch.default
status.relativePaths
status.showUntrackedFiles
status.submodulesummary
submodule.
tar.umask
transfer.unpackLimit
url.
user.email
user.name
user.signingkey
web.browser
branch. remote.
"
}
_git_remote ()
{
local subcommands="add rename remove set-head set-branches set-url show prune update"
local subcommand="$(__git_find_on_cmdline "$subcommands")"
if [ -z "$subcommand" ]; then
__gitcomp "$subcommands"
return
fi
case "$subcommand" in
rename|remove|set-url|show|prune)
__gitcomp_nl "$(__git_remotes)"
;;
set-head|set-branches)
__git_complete_remote_or_refspec
;;
update)
local i c='' IFS=$'\n'
for i in $(git --git-dir="$(__gitdir)" config --get-regexp "remotes\..*" 2>/dev/null); do
i="${i#remotes.}"
c="$c ${i/ */}"
done
__gitcomp "$c"
;;
*)
;;
esac
}
_git_replace ()
{
__gitcomp_nl "$(__git_refs)"
}
_git_reset ()
{
__git_has_doubledash && return
case "$cur" in
--*)
__gitcomp "--merge --mixed --hard --soft --patch"
return
;;
esac
__gitcomp_nl "$(__git_refs)"
}
_git_revert ()
{
case "$cur" in
--*)
__gitcomp "--edit --mainline --no-edit --no-commit --signoff"
return
;;
esac
__gitcomp_nl "$(__git_refs)"
}
_git_rm ()
{
case "$cur" in
--*)
__gitcomp "--cached --dry-run --ignore-unmatch --quiet"
return
;;
esac
__git_complete_index_file "--cached"
}
_git_shortlog ()
{
__git_has_doubledash && return
case "$cur" in
--*)
__gitcomp "
$__git_log_common_options
$__git_log_shortlog_options
--numbered --summary
"
return
;;
esac
__git_complete_revlist
}
_git_show ()
{
__git_has_doubledash && return
case "$cur" in
--pretty=*|--format=*)
__gitcomp "$__git_log_pretty_formats $(__git_pretty_aliases)
" "" "${cur#*=}"
return
;;
--diff-algorithm=*)
__gitcomp "$__git_diff_algorithms" "" "${cur##--diff-algorithm=}"
return
;;
--*)
__gitcomp "--pretty= --format= --abbrev-commit --oneline
$__git_diff_common_options
"
return
;;
esac
__git_complete_revlist_file
}
_git_show_branch ()
{
case "$cur" in
--*)
__gitcomp "
--all --remotes --topo-order --current --more=
--list --independent --merge-base --no-name
--color --no-color
--sha1-name --sparse --topics --reflog
"
return
;;
esac
__git_complete_revlist
}
_git_stash ()
{
local save_opts='--keep-index --no-keep-index --quiet --patch'
local subcommands='save list show apply clear drop pop create branch'
local subcommand="$(__git_find_on_cmdline "$subcommands")"
if [ -z "$subcommand" ]; then
case "$cur" in
--*)
__gitcomp "$save_opts"
;;
*)
if [ -z "$(__git_find_on_cmdline "$save_opts")" ]; then
__gitcomp "$subcommands"
fi
;;
esac
else
case "$subcommand,$cur" in
save,--*)
__gitcomp "$save_opts"
;;
apply,--*|pop,--*)
__gitcomp "--index --quiet"
;;
show,--*|drop,--*|branch,--*)
;;
show,*|apply,*|drop,*|pop,*|branch,*)
__gitcomp_nl "$(git --git-dir="$(__gitdir)" stash list \
| sed -n -e 's/:.*//p')"
;;
*)
;;
esac
fi
}
_git_submodule ()
{
__git_has_doubledash && return
local subcommands="add status init deinit update summary foreach sync"
if [ -z "$(__git_find_on_cmdline "$subcommands")" ]; then
case "$cur" in
--*)
__gitcomp "--quiet --cached"
;;
*)
__gitcomp "$subcommands"
;;
esac
return
fi
}
_git_svn ()
{
local subcommands="
init fetch clone rebase dcommit log find-rev
set-tree commit-diff info create-ignore propget
proplist show-ignore show-externals branch tag blame
migrate mkdirs reset gc
"
local subcommand="$(__git_find_on_cmdline "$subcommands")"
if [ -z "$subcommand" ]; then
__gitcomp "$subcommands"
else
local remote_opts="--username= --config-dir= --no-auth-cache"
local fc_opts="
--follow-parent --authors-file= --repack=
--no-metadata --use-svm-props --use-svnsync-props
--log-window-size= --no-checkout --quiet
--repack-flags --use-log-author --localtime
--ignore-paths= --include-paths= $remote_opts
"
local init_opts="
--template= --shared= --trunk= --tags=
--branches= --stdlayout --minimize-url
--no-metadata --use-svm-props --use-svnsync-props
--rewrite-root= --prefix= --use-log-author
--add-author-from $remote_opts
"
local cmt_opts="
--edit --rmdir --find-copies-harder --copy-similarity=
"
case "$subcommand,$cur" in
fetch,--*)
__gitcomp "--revision= --fetch-all $fc_opts"
;;
clone,--*)
__gitcomp "--revision= $fc_opts $init_opts"
;;
init,--*)
__gitcomp "$init_opts"
;;
dcommit,--*)
__gitcomp "
--merge --strategy= --verbose --dry-run
--fetch-all --no-rebase --commit-url
--revision --interactive $cmt_opts $fc_opts
"
;;
set-tree,--*)
__gitcomp "--stdin $cmt_opts $fc_opts"
;;
create-ignore,--*|propget,--*|proplist,--*|show-ignore,--*|\
show-externals,--*|mkdirs,--*)
__gitcomp "--revision="
;;
log,--*)
__gitcomp "
--limit= --revision= --verbose --incremental
--oneline --show-commit --non-recursive
--authors-file= --color
"
;;
rebase,--*)
__gitcomp "
--merge --verbose --strategy= --local
--fetch-all --dry-run $fc_opts
"
;;
commit-diff,--*)
__gitcomp "--message= --file= --revision= $cmt_opts"
;;
info,--*)
__gitcomp "--url"
;;
branch,--*)
__gitcomp "--dry-run --message --tag"
;;
tag,--*)
__gitcomp "--dry-run --message"
;;
blame,--*)
__gitcomp "--git-format"
;;
migrate,--*)
__gitcomp "
--config-dir= --ignore-paths= --minimize
--no-auth-cache --username=
"
;;
reset,--*)
__gitcomp "--revision= --parent"
;;
*)
;;
esac
fi
}
_git_tag ()
{
local i c=1 f=0
while [ $c -lt $cword ]; do
i="${words[c]}"
case "$i" in
-d|-v)
__gitcomp_nl "$(__git_tags)"
return
;;
-f)
f=1
;;
esac
((c++))
done
case "$prev" in
-m|-F)
;;
-*|tag)
if [ $f = 1 ]; then
__gitcomp_nl "$(__git_tags)"
fi
;;
*)
__gitcomp_nl "$(__git_refs)"
;;
esac
}
_git_whatchanged ()
{
_git_log
}
__git_main ()
{
local i c=1 command __git_dir
while [ $c -lt $cword ]; do
i="${words[c]}"
case "$i" in
--git-dir=*) __git_dir="${i#--git-dir=}" ;;
--git-dir) ((c++)) ; __git_dir="${words[c]}" ;;
--bare) __git_dir="." ;;
--help) command="help"; break ;;
-c|--work-tree|--namespace) ((c++)) ;;
-*) ;;
*) command="$i"; break ;;
esac
((c++))
done
if [ -z "$command" ]; then
case "$cur" in
--*) __gitcomp "
--paginate
--no-pager
--git-dir=
--bare
--version
--exec-path
--exec-path=
--html-path
--man-path
--info-path
--work-tree=
--namespace=
--no-replace-objects
--help
"
;;
*) __git_compute_porcelain_commands
__gitcomp "$__git_porcelain_commands $(__git_aliases)" ;;
esac
return
fi
local completion_func="_git_${command//-/_}"
declare -f $completion_func >/dev/null && $completion_func && return
local expansion=$(__git_aliased_command "$command")
if [ -n "$expansion" ]; then
completion_func="_git_${expansion//-/_}"
declare -f $completion_func >/dev/null && $completion_func
fi
}
__gitk_main ()
{
__git_has_doubledash && return
local g="$(__gitdir)"
local merge=""
if [ -f "$g/MERGE_HEAD" ]; then
merge="--merge"
fi
case "$cur" in
--*)
__gitcomp "
$__git_log_common_options
$__git_log_gitk_options
$merge
"
return
;;
esac
__git_complete_revlist
}
if [[ -n ${ZSH_VERSION-} ]]; then
echo "WARNING: this script is deprecated, please see git-completion.zsh" 1>&2
autoload -U +X compinit && compinit
__gitcomp ()
{
emulate -L zsh
local cur_="${3-$cur}"
case "$cur_" in
--*=)
;;
*)
local c IFS=$' \t\n'
local -a array
for c in ${=1}; do
c="$c${4-}"
case $c in
--*=*|*.) ;;
*) c="$c " ;;
esac
array[$#array+1]="$c"
done
compset -P '*[=:]'
compadd -Q -S '' -p "${2-}" -a -- array && _ret=0
;;
esac
}
__gitcomp_nl ()
{
emulate -L zsh
local IFS=$'\n'
compset -P '*[=:]'
compadd -Q -S "${4- }" -p "${2-}" -- ${=1} && _ret=0
}
__gitcomp_file ()
{
emulate -L zsh
local IFS=$'\n'
compset -P '*[=:]'
compadd -Q -p "${2-}" -f -- ${=1} && _ret=0
}
_git ()
{
local _ret=1 cur cword prev
cur=${words[CURRENT]}
prev=${words[CURRENT-1]}
let cword=CURRENT-1
emulate ksh -c __${service}_main
let _ret && _default && _ret=0
return _ret
}
compdef _git git gitk
return
fi
__git_func_wrap ()
{
local cur words cword prev
_get_comp_words_by_ref -n =: cur words cword prev
$1
}
# Setup completion for certain functions defined above by setting common
# variables and workarounds.
# This is NOT a public function; use at your own risk.
__git_complete ()
{
local wrapper="__git_wrap${2}"
eval "$wrapper () { __git_func_wrap $2 ; }"
complete -o bashdefault -o default -o nospace -F $wrapper $1 2>/dev/null \
|| complete -o default -o nospace -F $wrapper $1
}
# wrapper for backwards compatibility
_git ()
{
__git_wrap__git_main
}
# wrapper for backwards compatibility
_gitk ()
{
__git_wrap__gitk_main
}
__git_complete git __git_main
__git_complete gitk __gitk_main
# The following are necessary only for Cygwin, and are only needed
# when the user has tab-completed the executable name and consequently
# included the '.exe' suffix.
#
if [ Cygwin = "$(uname -o 2>/dev/null)" ]; then
__git_complete git.exe __git_main
fi
# bash/zsh git prompt support
#
# Copyright (C) 2006,2007 Shawn O. Pearce <spearce@spearce.org>
# Distributed under the GNU General Public License, version 2.0.
#
# This script allows you to see repository status in your prompt.
#
# To enable:
#
# 1) Copy this file to somewhere (e.g. ~/.git-prompt.sh).
# 2) Add the following line to your .bashrc/.zshrc:
# source ~/.git-prompt.sh
# 3a) Change your PS1 to call __git_ps1 as
# command-substitution:
# Bash: PS1='[\u@\h \W$(__git_ps1 " (%s)")]\$ '
# ZSH: setopt PROMPT_SUBST ; PS1='[%n@%m %c$(__git_ps1 " (%s)")]\$ '
# the optional argument will be used as format string.
# 3b) Alternatively, for a slightly faster prompt, __git_ps1 can
# be used for PROMPT_COMMAND in Bash or for precmd() in Zsh
# with two parameters, <pre> and <post>, which are strings
# you would put in $PS1 before and after the status string
# generated by the git-prompt machinery. e.g.
# Bash: PROMPT_COMMAND='__git_ps1 "\u@\h:\w" "\\\$ "'
# will show username, at-sign, host, colon, cwd, then
# various status string, followed by dollar and SP, as
# your prompt.
# ZSH: precmd () { __git_ps1 "%n" ":%~$ " "|%s" }
# will show username, pipe, then various status string,
# followed by colon, cwd, dollar and SP, as your prompt.
# Optionally, you can supply a third argument with a printf
# format string to fine-tune the output of the branch status.
#
# The repository status will be displayed only if you are currently in a
# git repository. The %s token is the placeholder for the shown status.
#
# The prompt status always includes the current branch name.
#
# In addition, if you set GIT_PS1_SHOWDIRTYSTATE to a nonempty value,
# unstaged (*) and staged (+) changes will be shown next to the branch
# name. You can configure this per-repository with the
# bash.showDirtyState variable, which defaults to true once
# GIT_PS1_SHOWDIRTYSTATE is enabled.
#
# You can also see whether something is currently stashed by setting
# GIT_PS1_SHOWSTASHSTATE to a nonempty value. If something is stashed,
# then a '$' will be shown next to the branch name.
#
# If you would like to see if there are untracked files, then you can set
# GIT_PS1_SHOWUNTRACKEDFILES to a nonempty value. If there are untracked
# files, then a '%' will be shown next to the branch name. You can
# configure this per-repository with the bash.showUntrackedFiles
# variable, which defaults to true once GIT_PS1_SHOWUNTRACKEDFILES is
# enabled.
#
# If you would like to see the difference between HEAD and its upstream,
# set GIT_PS1_SHOWUPSTREAM="auto". A "<" indicates you are behind, ">"
# indicates you are ahead, "<>" indicates you have diverged and "="
# indicates that there is no difference. You can further control
# behaviour by setting GIT_PS1_SHOWUPSTREAM to a space-separated list
# of values:
#
# verbose show number of commits ahead/behind (+/-) upstream
# legacy don't use the '--count' option available in recent
# versions of git-rev-list
# git always compare HEAD to @{upstream}
# svn always compare HEAD to your SVN upstream
#
# By default, __git_ps1 will compare HEAD to your SVN upstream if it can
# find one, or @{upstream} otherwise. Once you have set
# GIT_PS1_SHOWUPSTREAM, you can override it on a per-repository basis by
# setting the bash.showUpstream config variable.
#
# If you would like to see more information about the identity of
# commits checked out as a detached HEAD, set GIT_PS1_DESCRIBE_STYLE
# to one of these values:
#
# contains relative to newer annotated tag (v1.6.3.2~35)
# branch relative to newer tag or branch (master~4)
# describe relative to older annotated tag (v1.6.3.1-13-gdd42c2f)
# default exactly matching tag
#
# If you would like a colored hint about the current dirty state, set
# GIT_PS1_SHOWCOLORHINTS to a nonempty value. The colors are based on
# the colored output of "git status -sb" and are available only when
# using __git_ps1 for PROMPT_COMMAND or precmd.
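#
# As a quick reference, here is a minimal sketch of a .bashrc snippet that
# turns on the optional indicators documented above; the chosen values are
# only illustrative, pick whichever combination you want:
#
#   source ~/.git-prompt.sh
#   GIT_PS1_SHOWDIRTYSTATE=1
#   GIT_PS1_SHOWSTASHSTATE=1
#   GIT_PS1_SHOWUNTRACKEDFILES=1
#   GIT_PS1_SHOWUPSTREAM="auto"
#   GIT_PS1_DESCRIBE_STYLE=branch
#   PS1='[\u@\h \W$(__git_ps1 " (%s)")]\$ '
#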
# stores the divergence from upstream in $p
# used by GIT_PS1_SHOWUPSTREAM
__git_ps1_show_upstream ()
{
local key value
local svn_remote svn_url_pattern count n
local upstream=git legacy="" verbose=""
svn_remote=()
# get some config options from git-config
local output="$(git config -z --get-regexp '^(svn-remote\..*\.url|bash\.showupstream)$' 2>/dev/null | tr '\0\n' '\n ')"
while read -r key value; do
case "$key" in
bash.showupstream)
GIT_PS1_SHOWUPSTREAM="$value"
if [[ -z "${GIT_PS1_SHOWUPSTREAM}" ]]; then
p=""
return
fi
;;
svn-remote.*.url)
svn_remote[$((${#svn_remote[@]} + 1))]="$value"
svn_url_pattern+="\\|$value"
upstream=svn+git # default upstream is SVN if available, else git
;;
esac
done <<< "$output"
# parse configuration values
for option in ${GIT_PS1_SHOWUPSTREAM}; do
case "$option" in
git|svn) upstream="$option" ;;
verbose) verbose=1 ;;
legacy) legacy=1 ;;
esac
done
# Find our upstream
case "$upstream" in
git) upstream="@{upstream}" ;;
svn*)
# get the upstream from the "git-svn-id: ..." in a commit message
# (git-svn uses essentially the same procedure internally)
local -a svn_upstream
svn_upstream=($(git log --first-parent -1 \
--grep="^git-svn-id: \(${svn_url_pattern#??}\)" 2>/dev/null))
if [[ 0 -ne ${#svn_upstream[@]} ]]; then
svn_upstream=${svn_upstream[${#svn_upstream[@]} - 2]}
svn_upstream=${svn_upstream%@*}
local n_stop="${#svn_remote[@]}"
for ((n=1; n <= n_stop; n++)); do
svn_upstream=${svn_upstream#${svn_remote[$n]}}
done
if [[ -z "$svn_upstream" ]]; then
# default branch name for checkouts with no layout:
upstream=${GIT_SVN_ID:-git-svn}
else
upstream=${svn_upstream#/}
fi
elif [[ "svn+git" = "$upstream" ]]; then
upstream="@{upstream}"
fi
;;
esac
# Find how many commits we are ahead/behind our upstream
if [[ -z "$legacy" ]]; then
count="$(git rev-list --count --left-right \
"$upstream"...HEAD 2>/dev/null)"
else
# produce equivalent output to --count for older versions of git
local commits
if commits="$(git rev-list --left-right "$upstream"...HEAD 2>/dev/null)"
then
local commit behind=0 ahead=0
for commit in $commits
do
case "$commit" in
"<"*) ((behind++)) ;;
*) ((ahead++)) ;;
esac
done
count="$behind $ahead"
else
count=""
fi
fi
# calculate the result
if [[ -z "$verbose" ]]; then
case "$count" in
"") # no upstream
p="" ;;
"0 0") # equal to upstream
p="=" ;;
"0 "*) # ahead of upstream
p=">" ;;
*" 0") # behind upstream
p="<" ;;
*) # diverged from upstream
p="<>" ;;
esac
else
case "$count" in
"") # no upstream
p="" ;;
"0 0") # equal to upstream
p=" u=" ;;
"0 "*) # ahead of upstream
p=" u+${count#0 }" ;;
*" 0") # behind upstream
p=" u-${count% 0}" ;;
*) # diverged from upstream
p=" u+${count#* }-${count% *}" ;;
esac
fi
}
# Helper function that is meant to be called from __git_ps1. It
# injects color codes into the appropriate gitstring variables used
# to build a gitstring.
__git_ps1_colorize_gitstring ()
{
if [[ -n ${ZSH_VERSION-} ]]; then
local c_red='%F{red}'
local c_green='%F{green}'
local c_lblue='%F{blue}'
local c_clear='%f'
else
# Using \[ and \] around colors is necessary to prevent
# issues with command line editing/browsing/completion!
local c_red='\[\e[31m\]'
local c_green='\[\e[32m\]'
local c_lblue='\[\e[1;34m\]'
local c_clear='\[\e[0m\]'
fi
local bad_color=$c_red
local ok_color=$c_green
local flags_color="$c_lblue"
local branch_color=""
if [ $detached = no ]; then
branch_color="$ok_color"
else
branch_color="$bad_color"
fi
c="$branch_color$c"
z="$c_clear$z"
if [ "$w" = "*" ]; then
w="$bad_color$w"
fi
if [ -n "$i" ]; then
i="$ok_color$i"
fi
if [ -n "$s" ]; then
s="$flags_color$s"
fi
if [ -n "$u" ]; then
u="$bad_color$u"
fi
r="$c_clear$r"
}
# __git_ps1 accepts 0 or 1 arguments (i.e., format string)
# when called from PS1 using command substitution
# in this mode it prints text to add to bash PS1 prompt (includes branch name)
#
# __git_ps1 requires 2 or 3 arguments when called from PROMPT_COMMAND (pc)
# in that case it _sets_ PS1. The arguments are parts of a PS1 string.
# when two arguments are given, the first is prepended and the second appended
# to the state string when assigned to PS1.
# The optional third parameter will be used as printf format string to further
# customize the output of the git-status string.
# In this mode you can request colored hints using GIT_PS1_SHOWCOLORHINTS=true
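#
# For example, in PROMPT_COMMAND/precmd mode with an explicit printf format
# string (the prompt strings themselves are only illustrative):
#
#   Bash: PROMPT_COMMAND='__git_ps1 "\u@\h:\w" "\\\$ " " [%s]"'
#   ZSH:  precmd () { __git_ps1 "%n@%m:%~" "$ " " [%s]" }
#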
__git_ps1 ()
{
local pcmode=no
local detached=no
local ps1pc_start='\u@\h:\w '
local ps1pc_end='\$ '
local printf_format=' (%s)'
case "$#" in
2|3) pcmode=yes
ps1pc_start="$1"
ps1pc_end="$2"
printf_format="${3:-$printf_format}"
;;
0|1) printf_format="${1:-$printf_format}"
;;
*) return
;;
esac
local repo_info rev_parse_exit_code
repo_info="$(git rev-parse --git-dir --is-inside-git-dir \
--is-bare-repository --is-inside-work-tree \
--short HEAD 2>/dev/null)"
rev_parse_exit_code="$?"
if [ -z "$repo_info" ]; then
if [ $pcmode = yes ]; then
# In PC mode, PS1 always needs to be set
PS1="$ps1pc_start$ps1pc_end"
fi
return
fi
local short_sha
if [ "$rev_parse_exit_code" = "0" ]; then
short_sha="${repo_info##*$'\n'}"
repo_info="${repo_info%$'\n'*}"
fi
local inside_worktree="${repo_info##*$'\n'}"
repo_info="${repo_info%$'\n'*}"
local bare_repo="${repo_info##*$'\n'}"
repo_info="${repo_info%$'\n'*}"
local inside_gitdir="${repo_info##*$'\n'}"
local g="${repo_info%$'\n'*}"
local r=""
local b=""
local step=""
local total=""
if [ -d "$g/rebase-merge" ]; then
read b 2>/dev/null <"$g/rebase-merge/head-name"
read step 2>/dev/null <"$g/rebase-merge/msgnum"
read total 2>/dev/null <"$g/rebase-merge/end"
if [ -f "$g/rebase-merge/interactive" ]; then
r="|REBASE-i"
else
r="|REBASE-m"
fi
else
if [ -d "$g/rebase-apply" ]; then
read step 2>/dev/null <"$g/rebase-apply/next"
read total 2>/dev/null <"$g/rebase-apply/last"
if [ -f "$g/rebase-apply/rebasing" ]; then
read b 2>/dev/null <"$g/rebase-apply/head-name"
r="|REBASE"
elif [ -f "$g/rebase-apply/applying" ]; then
r="|AM"
else
r="|AM/REBASE"
fi
elif [ -f "$g/MERGE_HEAD" ]; then
r="|MERGING"
elif [ -f "$g/CHERRY_PICK_HEAD" ]; then
r="|CHERRY-PICKING"
elif [ -f "$g/REVERT_HEAD" ]; then
r="|REVERTING"
elif [ -f "$g/BISECT_LOG" ]; then
r="|BISECTING"
fi
if [ -n "$b" ]; then
:
elif [ -h "$g/HEAD" ]; then
# symlink symbolic ref
b="$(git symbolic-ref HEAD 2>/dev/null)"
else
local head=""
if ! read head 2>/dev/null <"$g/HEAD"; then
if [ $pcmode = yes ]; then
PS1="$ps1pc_start$ps1pc_end"
fi
return
fi
# is it a symbolic ref?
b="${head#ref: }"
if [ "$head" = "$b" ]; then
detached=yes
b="$(
case "${GIT_PS1_DESCRIBE_STYLE-}" in
(contains)
git describe --contains HEAD ;;
(branch)
git describe --contains --all HEAD ;;
(describe)
git describe HEAD ;;
(* | default)
git describe --tags --exact-match HEAD ;;
esac 2>/dev/null)" ||
b="$short_sha..."
b="($b)"
fi
fi
fi
if [ -n "$step" ] && [ -n "$total" ]; then
r="$r $step/$total"
fi
local w=""
local i=""
local s=""
local u=""
local c=""
local p=""
if [ "true" = "$inside_gitdir" ]; then
if [ "true" = "$bare_repo" ]; then
c="BARE:"
else
b="GIT_DIR!"
fi
elif [ "true" = "$inside_worktree" ]; then
if [ -n "${GIT_PS1_SHOWDIRTYSTATE-}" ] &&
[ "$(git config --bool bash.showDirtyState)" != "false" ]
then
git diff --no-ext-diff --quiet --exit-code || w="*"
if [ -n "$short_sha" ]; then
git diff-index --cached --quiet HEAD -- || i="+"
else
i="#"
fi
fi
if [ -n "${GIT_PS1_SHOWSTASHSTATE-}" ] &&
[ -r "$g/refs/stash" ]; then
s="$"
fi
if [ -n "${GIT_PS1_SHOWUNTRACKEDFILES-}" ] &&
[ "$(git config --bool bash.showUntrackedFiles)" != "false" ] &&
git ls-files --others --exclude-standard --error-unmatch -- '*' >/dev/null 2>/dev/null
then
u="%${ZSH_VERSION+%}"
fi
if [ -n "${GIT_PS1_SHOWUPSTREAM-}" ]; then
__git_ps1_show_upstream
fi
fi
local z="${GIT_PS1_STATESEPARATOR-" "}"
# NO color option unless in PROMPT_COMMAND mode
if [ $pcmode = yes ] && [ -n "${GIT_PS1_SHOWCOLORHINTS-}" ]; then
__git_ps1_colorize_gitstring
fi
local f="$w$i$s$u"
local gitstring="$c${b##refs/heads/}${f:+$z$f}$r$p"
if [ $pcmode = yes ]; then
if [[ -n ${ZSH_VERSION-} ]]; then
gitstring=$(printf -- "$printf_format" "$gitstring")
else
printf -v gitstring -- "$printf_format" "$gitstring"
fi
PS1="$ps1pc_start$gitstring$ps1pc_end"
else
printf -- "$printf_format" "$gitstring"
fi
}
......@@ -25,12 +25,14 @@ from . import dataset
from . import reader
import attr
import pooling
import inferencer
import networks
import py_paddle.swig_paddle as api
__all__ = [
'optimizer', 'layer', 'activation', 'parameters', 'init', 'trainer',
'event', 'data_type', 'attr', 'pooling', 'data_feeder', 'dataset', 'reader',
'topology', 'networks'
'topology', 'networks', 'inferencer', 'infer'
]
......@@ -40,3 +42,6 @@ def init(**kwargs):
args.append('--%s=%s' % (key, str(kwargs[key])))
api.initPaddle(*args)
infer = inferencer.infer
# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import collections
from paddle.trainer_config_helpers.default_decorators import wrap_name_default
import paddle.trainer_config_helpers as conf_helps
class Layer(object):
def __init__(self, name=None, parent_layers=None):
assert isinstance(parent_layers, dict)
self.name = name
self.__parent_layers__ = parent_layers
def to_proto(self, context):
"""
Convert this layer (and, recursively, its parent layers) into the v1 protobuf configuration.
"""
kwargs = dict()
for layer_name in self.__parent_layers__:
if not isinstance(self.__parent_layers__[layer_name],
collections.Sequence):
v1_layer = self.__parent_layers__[layer_name].to_proto(
context=context)
else:
v1_layer = map(lambda x: x.to_proto(context=context),
self.__parent_layers__[layer_name])
kwargs[layer_name] = v1_layer
if self.name is None:
return self.to_proto_impl(**kwargs)
elif self.name not in context:
context[self.name] = self.to_proto_impl(**kwargs)
return context[self.name]
def to_proto_impl(self, **kwargs):
raise NotImplementedError()
def __convert_to_v2__(method_name, parent_names, is_default_name=True):
if is_default_name:
wrapper = wrap_name_default(name_prefix=method_name)
else:
wrapper = None
class V2LayerImpl(Layer):
def __init__(self, **kwargs):
parent_layers = dict()
other_kwargs = dict()
for pname in parent_names:
if kwargs.has_key(pname):
parent_layers[pname] = kwargs[pname]
for key in kwargs.keys():
if key not in parent_names:
other_kwargs[key] = kwargs[key]
name = kwargs.get('name', None)
super(V2LayerImpl, self).__init__(name, parent_layers)
self.__other_kwargs__ = other_kwargs
if wrapper is not None:
__init__ = wrapper(__init__)
def to_proto_impl(self, **kwargs):
args = dict()
for each in kwargs:
args[each] = kwargs[each]
for each in self.__other_kwargs__:
args[each] = self.__other_kwargs__[each]
return getattr(conf_helps, method_name)(**args)
return V2LayerImpl
......@@ -32,3 +32,10 @@ def download(url, module_name, md5sum):
shutil.copyfileobj(r.raw, f)
return filename
def dict_add(a_dict, ele):
if ele in a_dict:
a_dict[ele] += 1
else:
a_dict[ele] = 1
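# Illustrative usage of dict_add (the names below are made up for the example):
#   word_freq = {}
#   dict_add(word_freq, 'cat')
#   dict_add(word_freq, 'cat')  # word_freq == {'cat': 2}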
#!/usr/bin/env python
# -*- coding:utf-8 -*-
# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
IMDB dataset: http://ai.stanford.edu/%7Eamaas/data/sentiment/aclImdb_v1.tar.gz
"""
import paddle.v2.dataset.common
import tarfile
import Queue
import re
import string
import threading
__all__ = ['build_dict', 'train', 'test']
URL = 'http://ai.stanford.edu/%7Eamaas/data/sentiment/aclImdb_v1.tar.gz'
MD5 = '7c2ac02c03563afcf9b574c7e56c153a'
# Read files that match pattern. Tokenize and yield each file.
def tokenize(pattern):
with tarfile.open(paddle.v2.dataset.common.download(URL, 'imdb',
MD5)) as tarf:
# Note that we should use tarfile.next(), which does
# sequential access of member files, rather than
# tarfile.extractfile, which does random access and can
# thrash hard disks.
tf = tarf.next()
while tf != None:
if bool(pattern.match(tf.name)):
# strip trailing newlines, remove punctuation, lowercase, then tokenize ad hoc.
yield tarf.extractfile(tf).read().rstrip("\n\r").translate(
None, string.punctuation).lower().split()
tf = tarf.next()
def build_dict(pattern, cutoff):
word_freq = {}
for doc in tokenize(pattern):
for word in doc:
paddle.v2.dataset.common.dict_add(word_freq, word)
# Not sure if we should prune less-frequent words here.
word_freq = filter(lambda x: x[1] > cutoff, word_freq.items())
dictionary = sorted(word_freq, key=lambda x: (-x[1], x[0]))
words, _ = list(zip(*dictionary))
word_idx = dict(zip(words, xrange(len(words))))
word_idx['<unk>'] = len(words)
return word_idx
def reader_creator(pos_pattern, neg_pattern, word_idx, buffer_size):
UNK = word_idx['<unk>']
qs = [Queue.Queue(maxsize=buffer_size), Queue.Queue(maxsize=buffer_size)]
def load(pattern, queue):
for doc in tokenize(pattern):
queue.put(doc)
queue.put(None)
def reader():
# Create two threads that load positive and negative samples
# into qs.
t0 = threading.Thread(
target=load, args=(
pos_pattern,
qs[0], ))
t0.daemon = True
t0.start()
t1 = threading.Thread(
target=load, args=(
neg_pattern,
qs[1], ))
t1.daemon = True
t1.start()
# Read alternately from qs[0] and qs[1].
i = 0
doc = qs[i].get()
while doc != None:
yield [word_idx.get(w, UNK) for w in doc], i % 2
i += 1
doc = qs[i % 2].get()
# Once one queue is exhausted, read the remaining docs from the other queue.
i += 1
doc = qs[i % 2].get()
while doc != None:
yield [word_idx.get(w, UNK) for w in doc], i % 2
doc = qs[i % 2].get()
return reader()
def train(word_idx):
return reader_creator(
re.compile("aclImdb/train/pos/.*\.txt$"),
re.compile("aclImdb/train/neg/.*\.txt$"), word_idx, 1000)
def test(word_idx):
return reader_creator(
re.compile("aclImdb/test/pos/.*\.txt$"),
re.compile("aclImdb/test/neg/.*\.txt$"), word_idx, 1000)
"""
imikolov's simple dataset: http://www.fit.vutbr.cz/~imikolov/rnnlm/
"""
import paddle.v2.dataset.common
import tarfile
__all__ = ['train', 'test']
URL = 'http://www.fit.vutbr.cz/~imikolov/rnnlm/simple-examples.tgz'
MD5 = '30177ea32e27c525793142b6bf2c8e2d'
def word_count(f, word_freq=None):
add = paddle.v2.dataset.common.dict_add
if word_freq == None:
word_freq = {}
for l in f:
for w in l.strip().split():
add(word_freq, w)
add(word_freq, '<s>')
add(word_freq, '<e>')
return word_freq
def build_dict(train_filename, test_filename):
with tarfile.open(
paddle.v2.dataset.common.download(
paddle.v2.dataset.imikolov.URL, 'imikolov',
paddle.v2.dataset.imikolov.MD5)) as tf:
trainf = tf.extractfile(train_filename)
testf = tf.extractfile(test_filename)
word_freq = word_count(testf, word_count(trainf))
TYPO_FREQ = 50
word_freq = filter(lambda x: x[1] > TYPO_FREQ, word_freq.items())
dictionary = sorted(word_freq, key=lambda x: (-x[1], x[0]))
words, _ = list(zip(*dictionary))
word_idx = dict(zip(words, xrange(len(words))))
word_idx['<unk>'] = len(words)
return word_idx
word_idx = {}
def reader_creator(filename, n):
global word_idx
if len(word_idx) == 0:
word_idx = build_dict('./simple-examples/data/ptb.train.txt',
'./simple-examples/data/ptb.valid.txt')
def reader():
with tarfile.open(
paddle.v2.dataset.common.download(
paddle.v2.dataset.imikolov.URL, 'imikolov',
paddle.v2.dataset.imikolov.MD5)) as tf:
f = tf.extractfile(filename)
UNK = word_idx['<unk>']
for l in f:
l = ['<s>'] + l.strip().split() + ['<e>']
if len(l) >= n:
l = [word_idx.get(w, UNK) for w in l]
for i in range(n, len(l) + 1):
yield tuple(l[i - n:i])
return reader
def train(n):
return reader_creator('./simple-examples/data/ptb.train.txt', n)
def test(n):
return reader_creator('./simple-examples/data/ptb.valid.txt', n)
......@@ -9,9 +9,9 @@ __all__ = ['train', 'test']
URL_PREFIX = 'http://yann.lecun.com/exdb/mnist/'
TEST_IMAGE_URL = URL_PREFIX + 't10k-images-idx3-ubyte.gz'
TEST_IMAGE_MD5 = '25e3cc63507ef6e98d5dc541e8672bb6'
TEST_IMAGE_MD5 = '9fb629c4189551a2d022fa330f9573f3'
TEST_LABEL_URL = URL_PREFIX + 't10k-labels-idx1-ubyte.gz'
TEST_LABEL_MD5 = '4e9511fe019b2189026bd0421ba7b688'
TEST_LABEL_MD5 = 'ec29112dd5afa0611ce80d1b7f02629c'
TRAIN_IMAGE_URL = URL_PREFIX + 'train-images-idx3-ubyte.gz'
TRAIN_IMAGE_MD5 = 'f68b3c2dcbeaaa9fbdd348bbdeb94873'
TRAIN_LABEL_URL = URL_PREFIX + 'train-labels-idx1-ubyte.gz'
......@@ -35,24 +35,25 @@ def reader_creator(image_filename, label_filename, buffer_size):
l = subprocess.Popen([zcat_cmd, label_filename], stdout=subprocess.PIPE)
l.stdout.read(8) # skip some magic bytes
while True:
labels = numpy.fromfile(
l.stdout, 'ubyte', count=buffer_size).astype("int")
try: # the reader may stop early (break), so always clean up in the finally block.
while True:
labels = numpy.fromfile(
l.stdout, 'ubyte', count=buffer_size).astype("int")
if labels.size != buffer_size:
break # numpy.fromfile returns empty slice after EOF.
if labels.size != buffer_size:
break # numpy.fromfile returns empty slice after EOF.
images = numpy.fromfile(
m.stdout, 'ubyte', count=buffer_size * 28 * 28).reshape(
(buffer_size, 28 * 28)).astype('float32')
images = numpy.fromfile(
m.stdout, 'ubyte', count=buffer_size * 28 * 28).reshape(
(buffer_size, 28 * 28)).astype('float32')
images = images / 255.0 * 2.0 - 1.0
images = images / 255.0 * 2.0 - 1.0
for i in xrange(buffer_size):
yield images[i, :], int(labels[i])
m.terminate()
l.terminate()
for i in xrange(buffer_size):
yield images[i, :], int(labels[i])
finally:
m.terminate()
l.terminate()
return reader
......
import paddle.v2.dataset.imdb
import unittest
import re
TRAIN_POS_PATTERN = re.compile("aclImdb/train/pos/.*\.txt$")
TRAIN_NEG_PATTERN = re.compile("aclImdb/train/neg/.*\.txt$")
TRAIN_PATTERN = re.compile("aclImdb/train/.*\.txt$")
TEST_POS_PATTERN = re.compile("aclImdb/test/pos/.*\.txt$")
TEST_NEG_PATTERN = re.compile("aclImdb/test/neg/.*\.txt$")
TEST_PATTERN = re.compile("aclImdb/test/.*\.txt$")
class TestIMDB(unittest.TestCase):
word_idx = None
def test_build_dict(self):
if self.word_idx == None:
self.word_idx = paddle.v2.dataset.imdb.build_dict(TRAIN_PATTERN,
150)
self.assertEqual(len(self.word_idx), 7036)
def check_dataset(self, dataset, expected_size):
if self.word_idx == None:
self.word_idx = paddle.v2.dataset.imdb.build_dict(TRAIN_PATTERN,
150)
sum = 0
for l in dataset(self.word_idx):
self.assertEqual(l[1], sum % 2)
sum += 1
self.assertEqual(sum, expected_size)
def test_train(self):
self.check_dataset(paddle.v2.dataset.imdb.train, 25000)
def test_test(self):
self.check_dataset(paddle.v2.dataset.imdb.test, 25000)
if __name__ == '__main__':
unittest.main()
import paddle.v2.dataset.imikolov
import unittest
class TestMikolov(unittest.TestCase):
def check_reader(self, reader, n):
for l in reader():
self.assertEqual(len(l), n)
def test_train(self):
n = 5
self.check_reader(paddle.v2.dataset.imikolov.train(n), n)
def test_test(self):
n = 5
self.check_reader(paddle.v2.dataset.imikolov.test(n), n)
if __name__ == '__main__':
unittest.main()
......@@ -11,7 +11,10 @@ There are:
TODO(yuyang18): Complete it!
"""
import py_paddle.swig_paddle as api
__all__ = ['EndIteration', 'BeginIteration', 'BeginPass', 'EndPass']
__all__ = [
'EndIteration', 'BeginIteration', 'BeginPass', 'EndPass', 'TestResult'
]
class WithMetric(object):
......@@ -30,6 +33,11 @@ class WithMetric(object):
return retv
class TestResult(WithMetric):
def __init__(self, evaluator):
super(TestResult, self).__init__(evaluator)
class BeginPass(object):
"""
Event On One Pass Training Start.
......
import py_paddle.swig_paddle as api
import topology
from data_feeder import DataFeeder
import itertools
import numpy
__all__ = ['Inference', 'infer']
class Inference(object):
def __init__(self, output, parameters):
topo = topology.Topology(output)
gm = api.GradientMachine.createFromConfigProto(
topo.proto(), api.CREATE_MODE_TESTING, [api.PARAMETER_VALUE])
for param in gm.getParameters():
val = param.getBuf(api.PARAMETER_VALUE)
name = param.getName()
assert isinstance(val, api.Vector)
val.copyFromNumpyArray(parameters.get(name).flatten())
self.__gradient_machine__ = gm
self.__data_types__ = topo.data_type()
def iter_infer(self, reader, reader_dict=None):
if reader_dict is None:
reader_dict = self.default_reader_dict()
feeder = DataFeeder(self.__data_types__, reader_dict)
self.__gradient_machine__.start()
for data_batch in reader():
yield self.__gradient_machine__.forwardTest(feeder(data_batch))
self.__gradient_machine__.finish()
def iter_infer_field(self, field, **kwargs):
for result in self.iter_infer(**kwargs):
yield [each_result[field] for each_result in result]
def infer(self, field='value', **kwargs):
retv = None
for result in self.iter_infer_field(field=field, **kwargs):
if retv is None:
retv = [[]] * len(result)
for i, item in enumerate(result):
retv[i].append(item)
retv = [numpy.concatenate(out) for out in retv]
if len(retv) == 1:
return retv[0]
else:
return retv
def default_reader_dict(self):
reader_dict = dict()
for i, tp in enumerate(self.__data_types__):
reader_dict[tp[0]] = i
return reader_dict
def infer(output, parameters, reader, reader_dict=None, field='value'):
inferer = Inference(output=output, parameters=parameters)
return inferer.infer(field=field, reader=reader, reader_dict=reader_dict)
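# A minimal sketch of how infer() can be called (the `prediction`,
# `parameters` and `test_reader` names are assumptions for the example):
#   probs = infer(output=prediction, parameters=parameters, reader=test_reader)
# The result concatenates the requested output field over all batches, so for
# a softmax output it is a numpy array of shape (num_samples, num_classes).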
......@@ -68,6 +68,7 @@ paddle.v2.parameters.create, no longer exposed to users.
import collections
import inspect
from config_base import Layer, __convert_to_v2__
import paddle.trainer_config_helpers as conf_helps
from paddle.trainer_config_helpers.config_parser_utils import \
parse_network_config as __parse__
......@@ -108,90 +109,6 @@ def parse_network(*outputs):
return __parse__(__real_func__)
class Layer(object):
def __init__(self, name=None, size=None, parent_layers=None):
assert isinstance(parent_layers, dict)
self.name = name
self.size = size
self.__parent_layers__ = parent_layers
def to_proto(self, context):
"""
function to set proto attribute
"""
kwargs = dict()
for layer_name in self.__parent_layers__:
if not isinstance(self.__parent_layers__[layer_name],
collections.Sequence):
v1_layer = self.__parent_layers__[layer_name].to_proto(
context=context)
else:
v1_layer = map(lambda x: x.to_proto(context=context),
self.__parent_layers__[layer_name])
kwargs[layer_name] = v1_layer
if self.context_name() is None:
return self.to_proto_impl(**kwargs)
elif self.context_name() not in context:
context[self.context_name()] = self.to_proto_impl(**kwargs)
if self.use_context_name():
return context[self.context_name()]
else:
return context[self.name]
def to_proto_impl(self, **kwargs):
raise NotImplementedError()
def context_name(self):
"""
Context name means the context which stores `to_proto_impl` result.
If multiple layers share the same context_name, their `to_proto_impl`
will be invoked only once.
"""
return self.name
def use_context_name(self):
return False
def __convert_to_v2__(method_name, parent_names, is_default_name=True):
if is_default_name:
wrapper = wrap_name_default(name_prefix=method_name)
else:
wrapper = None
class V2LayerImpl(Layer):
def __init__(self, **kwargs):
parent_layers = dict()
other_kwargs = dict()
for pname in parent_names:
if kwargs.has_key(pname):
parent_layers[pname] = kwargs[pname]
for key in kwargs.keys():
if key not in parent_names:
other_kwargs[key] = kwargs[key]
name = kwargs.get('name', None)
size = kwargs.get('size', None)
super(V2LayerImpl, self).__init__(name, size, parent_layers)
self.__other_kwargs__ = other_kwargs
if wrapper is not None:
__init__ = wrapper(__init__)
def to_proto_impl(self, **kwargs):
args = dict()
for each in kwargs:
args[each] = kwargs[each]
for each in self.__other_kwargs__:
args[each] = self.__other_kwargs__[each]
return getattr(conf_helps, method_name)(**args)
return V2LayerImpl
"""
Some layers may need special config and cannot use __convert_to_v2__ to convert.
So we also need to implement some special LayerV2.
......
......@@ -12,8 +12,34 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from layer import __convert_to_v2__
import paddle.trainer_config_helpers.networks as conf_nw
import inspect
from config_base import __convert_to_v2__
simple_gru = __convert_to_v2__('simple_gru', ['input'])
simple_attention = __convert_to_v2__(
'simple_attention', ['encoded_sequence', 'encoded_proj', 'decoder_state'])
__all__ = []
def __initialize__():
for each_subnetwork in conf_nw.__all__:
if each_subnetwork in ['inputs', 'outputs']:
continue
func = getattr(conf_nw, each_subnetwork)
if hasattr(func, 'argspec'):
argspec = func.argspec
else:
argspec = inspect.getargspec(func)
if each_subnetwork == 'simple_attention':
parents = ['encoded_sequence', 'encoded_proj', 'decoder_state']
else:
parents = filter(lambda x: x.startswith('input'), argspec.args)
assert len(parents) != 0, each_subnetwork
v2_subnet = __convert_to_v2__(
each_subnetwork,
parent_names=parents,
is_default_name='name' in argspec.args)
globals()[each_subnetwork] = v2_subnet
global __all__
__all__.append(each_subnetwork)
__initialize__()
......@@ -14,13 +14,13 @@
__all__ = [
'map_readers', 'buffered', 'compose', 'chain', 'shuffle',
'ComposeNotAligned', 'batched'
'ComposeNotAligned', 'batched', 'firstn'
]
from Queue import Queue
from threading import Thread
import itertools
import random
from Queue import Queue
from threading import Thread
def map_readers(func, *readers):
......@@ -213,3 +213,20 @@ def batched(reader, batch_size):
yield batch
return batched_reader
def firstn(reader, n):
"""
Limit the maximum number of samples that the reader can return.
"""
# TODO(yuyang18): Check whether simply dropping the reader cleans up the
# opened resources or not.
def firstn_reader():
for i, item in enumerate(reader()):
if i == n:
break
yield item
return firstn_reader
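# Illustrative usage (the wrapped reader name is an assumption): cap a reader
# at its first 100 samples, e.g.
#   small_reader = firstn(some_reader, n=100)
# where `some_reader` is any reader function, such as one returned by the
# dataset modules above.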
......@@ -235,4 +235,8 @@ class DataFeederTest(unittest.TestCase):
if __name__ == '__main__':
api.initPaddle("--use_gpu=0")
unittest.main()
suite = unittest.TestLoader().loadTestsFromTestCase(DataFeederTest)
unittest.TextTestRunner().run(suite)
if api.isGpuVersion():
api.setUseGpu(True)
unittest.main()
......@@ -18,6 +18,7 @@ import paddle.v2.attr as attr
import paddle.v2.data_type as data_type
import paddle.v2.layer as layer
import paddle.v2.pooling as pooling
import paddle.v2.networks as networks
pixel = layer.data(name='pixel', type=data_type.dense_vector(128))
label = layer.data(name='label', type=data_type.integer_value(10))
......@@ -251,5 +252,13 @@ class ProjOpTest(unittest.TestCase):
print layer.parse_network(conv1)
class NetworkTests(unittest.TestCase):
def test_vgg(self):
img = layer.data(name='pixel', type=data_type.dense_vector(784))
vgg_out = networks.small_vgg(
input_image=img, num_channels=1, num_classes=2)
print layer.parse_network(vgg_out)
if __name__ == '__main__':
unittest.main()
......@@ -21,6 +21,28 @@ import layer as v2_layer
__all__ = ['Topology']
def __flatten__(lis):
"""
Given a list, possibly nested to any level, return it flattened.
"""
new_lis = []
for item in lis:
if isinstance(item, collections.Sequence):
new_lis.extend(__flatten__(item))
else:
new_lis.append(item)
return new_lis
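# Illustrative behavior: for non-sequence items l1, l2, l3,
#   __flatten__([l1, [l2, [l3]]]) == [l1, l2, l3]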
def __bfs_travel__(callback, *layers):
layers = __flatten__(layers)
for each_layer in layers:
__break__ = callback(each_layer)
if __break__:
return
__bfs_travel__(callback, *each_layer.__parent_layers__.values())
class Topology(object):
"""
Topology is used to store the information about all layers
......@@ -46,21 +68,17 @@ class Topology(object):
:param name:
:return:
"""
result_layer = []
result_layer = [None]
def find_layer_by_name(layer, layer_name):
if len(result_layer) == 1:
return
elif layer.name == layer_name:
result_layer.append(layer)
else:
for parent_layer in layer.__parent_layers__.values():
find_layer_by_name(parent_layer, layer_name)
def __impl__(l):
if l.name == name:
result_layer[0] = l
return True # break
return False
for layer in self.layers:
find_layer_by_name(layer, name)
assert len(result_layer) == 1
__bfs_travel__(__impl__, *self.layers)
if result_layer[0] is None:
raise ValueError("No such layer %s" % name)
return result_layer[0]
def data_layers(self):
......@@ -68,17 +86,13 @@ class Topology(object):
get all data layers
:return:
"""
data_layers = set()
def find_data_layer(layer):
if isinstance(layer, v2_layer.DataLayerV2):
data_layers.add(layer)
for parent_layer in layer.__parent_layers__.values():
find_data_layer(parent_layer)
data_layers = dict()
for layer in self.layers:
find_data_layer(layer)
def __impl__(l):
if isinstance(l, v2_layer.DataLayerV2):
data_layers[l.name] = l
__bfs_travel__(__impl__, *self.layers)
return data_layers
def data_type(self):
......@@ -86,8 +100,9 @@ class Topology(object):
get data_type from proto, such as:
[('image', dense_vector(768)), ('label', integer_value(10))]
"""
return [(data_layer.name, data_layer.type)
for data_layer in self.data_layers()]
data_layers = self.data_layers()
return [(nm, data_layers[nm].type)
for nm in self.proto().input_layer_names]
def __check_layer_type__(layer):
......
......@@ -42,25 +42,35 @@ class ITrainer(object):
class SGD(ITrainer):
def __init__(self, update_equation):
def __init__(self, cost, parameters, update_equation):
"""
Simple SGD Trainer.
:param cost: The cost layer (network output) to optimize.
:param parameters: The parameters to be trained.
:type parameters: v2_parameters.Parameters
:param update_equation: The optimizer object.
:type update_equation: v2_optimizer.Optimizer
"""
if not isinstance(parameters, v2_parameters.Parameters):
raise TypeError('parameters should be a '
'paddle.v2.parameters.Parameters object')
if not isinstance(update_equation, v2_optimizer.Optimizer):
raise ValueError("update equation parameter must be "
"paddle.v2.optimizer.Optimizer")
raise TypeError("update equation parameter must be "
"paddle.v2.optimizer.Optimizer")
topology = Topology(cost)
self.__optimizer__ = update_equation
self.__topology__ = topology
self.__parameters__ = parameters
self.__topology_in_proto__ = topology.proto()
self.__data_types__ = topology.data_type()
gm = api.GradientMachine.createFromConfigProto(
self.__topology_in_proto__, api.CREATE_MODE_NORMAL,
self.__optimizer__.enable_types())
assert isinstance(gm, api.GradientMachine)
parameters.append_gradient_machine(gm)
self.__gradient_machine__ = gm
self.__gradient_machine__.randParameters()
def train(self,
reader,
cost,
parameters,
num_passes=1,
event_handler=None,
reader_dict=None):
def train(self, reader, num_passes=1, event_handler=None, reader_dict=None):
"""
Training method. Trains num_passes passes over the input data.
......@@ -76,27 +86,22 @@ class SGD(ITrainer):
if event_handler is None:
event_handler = default_event_handler
topology = Topology(cost)
if reader_dict is None:
reader_dict = self.default_reader_dict()
__check_train_args__(**locals())
gm = api.GradientMachine.createFromConfigProto(
topology.proto(), api.CREATE_MODE_NORMAL,
self.__optimizer__.enable_types())
assert isinstance(gm, api.GradientMachine)
parameters.append_gradient_machine(gm)
gm.randParameters()
updater = self.__optimizer__.create_local_updater()
updater.init(gm)
updater.init(self.__gradient_machine__)
gm.start()
batch_evaluator = gm.makeEvaluator()
self.__gradient_machine__.start()
batch_evaluator = self.__gradient_machine__.makeEvaluator()
assert isinstance(batch_evaluator, api.Evaluator)
pass_evaluator = gm.makeEvaluator()
pass_evaluator = self.__gradient_machine__.makeEvaluator()
assert isinstance(pass_evaluator, api.Evaluator)
out_args = api.Arguments.createArguments(0)
feeder = DataFeeder(topology.data_type(), reader_dict)
feeder = DataFeeder(self.__data_types__, reader_dict)
for pass_id in xrange(num_passes):
event_handler(v2_event.BeginPass(pass_id))
......@@ -104,16 +109,18 @@ class SGD(ITrainer):
updater.startPass()
for batch_id, data_batch in enumerate(reader()):
pass_type = updater.startBatch(len(data_batch))
gm.forwardBackward(feeder(data_batch), out_args, pass_type)
self.__gradient_machine__.forwardBackward(
feeder(data_batch), out_args, pass_type)
batch_evaluator.start()
event_handler(
v2_event.BeginIteration(
pass_id=pass_id, batch_id=batch_id))
pass_type = updater.startBatch(len(data_batch))
gm.forwardBackward(feeder(data_batch), out_args, pass_type)
gm.eval(pass_evaluator)
gm.eval(batch_evaluator)
for each_param in gm.getParameters():
self.__gradient_machine__.forwardBackward(
feeder(data_batch), out_args, pass_type)
self.__gradient_machine__.eval(pass_evaluator)
self.__gradient_machine__.eval(batch_evaluator)
for each_param in self.__gradient_machine__.getParameters():
updater.update(each_param)
# Get cost. We use numpy to calculate total cost for this batch.
cost_vec = out_args.getSlotValue(0)
......@@ -131,22 +138,37 @@ class SGD(ITrainer):
updater.finishPass()
pass_evaluator.finish()
event_handler(v2_event.EndPass(pass_id, evaluator=pass_evaluator))
gm.finish()
self.__gradient_machine__.finish()
def default_reader_dict(self):
reader_dict = dict()
for i, tp in enumerate(self.__data_types__):
reader_dict[tp[0]] = i
return reader_dict
def test(self, reader, reader_dict=None):
if reader_dict is None:
reader_dict = self.default_reader_dict()
feeder = DataFeeder(self.__data_types__, reader_dict)
evaluator = self.__gradient_machine__.makeEvaluator()
out_args = api.Arguments.createArguments(0)
evaluator.start()
for data_batch in reader():
self.__gradient_machine__.forward(
feeder(data_batch), out_args, api.PASS_TEST)
self.__gradient_machine__.eval(evaluator)
evaluator.finish()
return v2_event.TestResult(evaluator=evaluator)
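# Illustrative use (assuming `trainer` is an SGD instance and `test_reader`
# is a batched reader):
#   result = trainer.test(reader=test_reader)
# The returned TestResult wraps the evaluator, so the accumulated test
# metrics can then be inspected.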
def __check_train_args__(reader, topology, parameters, event_handler, **kwargs):
def __check_train_args__(reader, event_handler, **kwargs):
"""
Check train function's argument types
"""
if not callable(reader) or not isinstance(reader(), collections.Iterator):
raise TypeError('train_data_reader should be a function, '
'which can return a iterator')
if not isinstance(topology, Topology):
raise TypeError('topology should be a model config')
if not isinstance(parameters, v2_parameters.Parameters):
raise TypeError('parameters should be a parameter pool')
if not callable(event_handler):
raise TypeError('event handler should be a function')