diff --git a/go/cmd/pserver/pserver.go b/go/cmd/pserver/pserver.go
index fe1fe5f6f03b315cf30d96e171dca53e32efa040..6c85b1804bb9c5f3a8bc46bb3f54cc62c56cca70 100644
--- a/go/cmd/pserver/pserver.go
+++ b/go/cmd/pserver/pserver.go
@@ -18,6 +18,7 @@ func main() {
 	etcdEndpoint := flag.String("etcd-endpoint", "http://127.0.0.1:2379",
 		"comma separated endpoint string for pserver to connect to etcd")
 	etcdTimeout := flag.Int("etcd-timeout", 5, "timeout for etcd calls")
+	numPservers := flag.Int("num-pservers", 1, "total pserver count in a training job")
 	logLevel := flag.String("log-level", "info",
 		"log level, possible values: debug, info, warning, error, fatal, panic")
 	flag.Parse()
@@ -29,7 +30,7 @@ func main() {
 	log.SetLevel(level)
 
 	timeout := time.Second * time.Duration((*etcdTimeout))
-	s, err := pserver.NewService(*etcdEndpoint, timeout)
+	s, err := pserver.NewService(*etcdEndpoint, *numPservers, timeout)
 	if err != nil {
 		panic(err)
 	}
diff --git a/go/pserver/service.go b/go/pserver/service.go
index 7e2b841dd8e8efda8b3033645b4b1b6a596e53c4..f966595fdccbf23e23f94a857503ce05815164ef 100644
--- a/go/pserver/service.go
+++ b/go/pserver/service.go
@@ -73,7 +73,7 @@ type Service struct {
 
 // NewService creates a new service, will bypass etcd registration if no
 // endpoints specified.
-func NewService(endpoints string, timeout time.Duration) (*Service, error) {
+func NewService(endpoints string, numPservers int, timeout time.Duration) (*Service, error) {
 	s := &Service{opt: newOptimizer(sgd, 0.005)}
 	s.paramMap = make(map[string]Parameter)
 	s.initialized = make(chan struct{})
@@ -103,6 +103,22 @@ func NewService(endpoints string, timeout time.Duration) (*Service, error) {
 			log.Debugf("inited client to %s", s.etcdEndpoints)
 			break
 		}
+		// init /ps_desired using transaction, for multiple pservers may want to write
+		// it at the same time.
+		for {
+			ctx, cancel := context.WithTimeout(context.Background(), time.Second)
+			_, err := s.initDesiredPsercers(ctx, numPservers)
+			cancel()
+			if err != nil {
+				log.Warn(err)
+				time.Sleep(s.etcdTimeout)
+				continue
+			}
+			break
+		}
+		// TODO: when implementing extending or reducing pservers, /ps_desired is
+		// changed, then we need to watch /ps_desired node for events. For now, just
+		// write once when init and read from it.
 		// wait and set s.desired init value
 		for {
 			ctx, cancel := context.WithTimeout(context.Background(), time.Second)
@@ -141,6 +157,16 @@ func NewService(endpoints string, timeout time.Duration) (*Service, error) {
 	return s, nil
 }
 
+func (s *Service) initDesiredPsercers(ctx context.Context, numPservers int) (*clientv3.TxnResponse, error) {
+	return concurrency.NewSTM(s.etcdClient, func(c concurrency.STM) error {
+		dsStr := c.Get(PsDesired)
+		if dsStr == "" {
+			c.Put(PsDesired, strconv.Itoa(numPservers))
+		}
+		return nil
+	}, concurrency.WithAbortContext(ctx), concurrency.WithIsolation(concurrency.RepeatableReads))
+}
+
 // registerPserverEtcd registers pserver node on etcd using transaction.
 func (s *Service) registerPserverEtcd(ctx context.Context) (*clientv3.TxnResponse, error) {
 	return concurrency.NewSTM(s.etcdClient, func(c concurrency.STM) error {