提交 777a5cca 编写于 作者: H Helin Wang

Client test: concurrently init param. Concurrently send grad and get param

上级 11660eab
...@@ -3,11 +3,13 @@ package client_test ...@@ -3,11 +3,13 @@ package client_test
import ( import (
"context" "context"
"io/ioutil" "io/ioutil"
"math/rand"
"net" "net"
"net/http" "net/http"
"net/rpc" "net/rpc"
"strconv" "strconv"
"strings" "strings"
"sync"
"testing" "testing"
"time" "time"
...@@ -111,7 +113,11 @@ func testClient(t *testing.T, c *client.Client) { ...@@ -111,7 +113,11 @@ func testClient(t *testing.T, c *client.Client) {
if err != nil { if err != nil {
t.Fatalf("read optimizer proto failed") t.Fatalf("read optimizer proto failed")
} }
var wg sync.WaitGroup
for i := 0; i < numParameter; i++ { for i := 0; i < numParameter; i++ {
wg.Add(1)
go func(i int) {
var p pserver.Parameter var p pserver.Parameter
p.Name = "p_" + strconv.Itoa(i) p.Name = "p_" + strconv.Itoa(i)
p.ElementType = pserver.Float32 p.ElementType = pserver.Float32
...@@ -120,7 +126,10 @@ func testClient(t *testing.T, c *client.Client) { ...@@ -120,7 +126,10 @@ func testClient(t *testing.T, c *client.Client) {
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
wg.Done()
}(i)
} }
wg.Wait()
err = c.FinishInitParams() err = c.FinishInitParams()
if err != nil { if err != nil {
...@@ -136,30 +145,67 @@ func testClient(t *testing.T, c *client.Client) { ...@@ -136,30 +145,67 @@ func testClient(t *testing.T, c *client.Client) {
grads = append(grads, g) grads = append(grads, g)
} }
err = c.SendGrads(grads) const paramPerGroup = 10
const numGroups = numParameter / paramPerGroup
// shuffle send grads order
for i := range grads {
j := rand.Intn(i + 1)
grads[i], grads[j] = grads[j], grads[i]
}
for i := 0; i < numGroups; i++ {
var gs []pserver.Gradient
if i == numGroups-1 {
gs = grads[i*paramPerGroup:]
} else {
gs = grads[i*paramPerGroup : (i+1)*paramPerGroup]
}
wg.Add(1)
go func(gs []pserver.Gradient) {
err = c.SendGrads(gs)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
wg.Done()
}(gs)
}
names := make([]string, numParameter) names := make([]string, numParameter)
for i := 0; i < numParameter; i++ { for i := 0; i < numParameter; i++ {
names[i] = "p_" + strconv.Itoa(i) names[i] = "p_" + strconv.Itoa(i)
} }
params, err := c.GetParams(names) for i := 0; i < numGroups; i++ {
var ns []string
if i == numGroups-1 {
ns = names[i*paramPerGroup:]
} else {
ns = names[i*paramPerGroup : (i+1)*paramPerGroup]
}
wg.Add(1)
go func(ns []string) {
params, err := c.GetParams(ns)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
if len(names) != len(params) { if len(ns) != len(params) {
t.Fatalf("parameter size not match, need: %d, have: %d", len(names), len(params)) t.Fatalf("parameter size not match, need: %d, have: %d", len(names), len(params))
} }
for i := range params { for i := range params {
if names[i] != params[i].Name { if ns[i] != params[i].Name {
t.Fatalf("order of returned parameter does not required: parameter name: %s, required name: %s", names[i], params[i].Name) t.Fatalf("order of returned parameter does not required: parameter name: %s, required name: %s", ns[i], params[i].Name)
} }
} }
wg.Done()
}(ns)
}
wg.Wait()
} }
func TestNativeClient(t *testing.T) { func TestNativeClient(t *testing.T) {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册