提交 777a5cca 编写于 作者: H Helin Wang

Client test: concurrently init param. Concurrently send grad and get param

上级 11660eab
...@@ -3,11 +3,13 @@ package client_test ...@@ -3,11 +3,13 @@ package client_test
import ( import (
"context" "context"
"io/ioutil" "io/ioutil"
"math/rand"
"net" "net"
"net/http" "net/http"
"net/rpc" "net/rpc"
"strconv" "strconv"
"strings" "strings"
"sync"
"testing" "testing"
"time" "time"
...@@ -111,16 +113,23 @@ func testClient(t *testing.T, c *client.Client) { ...@@ -111,16 +113,23 @@ func testClient(t *testing.T, c *client.Client) {
if err != nil { if err != nil {
t.Fatalf("read optimizer proto failed") t.Fatalf("read optimizer proto failed")
} }
var wg sync.WaitGroup
for i := 0; i < numParameter; i++ { for i := 0; i < numParameter; i++ {
var p pserver.Parameter wg.Add(1)
p.Name = "p_" + strconv.Itoa(i) go func(i int) {
p.ElementType = pserver.Float32 var p pserver.Parameter
p.Content = make([]byte, (i+1)*100) p.Name = "p_" + strconv.Itoa(i)
err := c.InitParam(pserver.ParameterWithConfig{Param: p, Config: config}) p.ElementType = pserver.Float32
if err != nil { p.Content = make([]byte, (i+1)*100)
t.Fatal(err) err := c.InitParam(pserver.ParameterWithConfig{Param: p, Config: config})
} if err != nil {
t.Fatal(err)
}
wg.Done()
}(i)
} }
wg.Wait()
err = c.FinishInitParams() err = c.FinishInitParams()
if err != nil { if err != nil {
...@@ -136,9 +145,31 @@ func testClient(t *testing.T, c *client.Client) { ...@@ -136,9 +145,31 @@ func testClient(t *testing.T, c *client.Client) {
grads = append(grads, g) grads = append(grads, g)
} }
err = c.SendGrads(grads) const paramPerGroup = 10
if err != nil { const numGroups = numParameter / paramPerGroup
t.Fatal(err)
// shuffle send grads order
for i := range grads {
j := rand.Intn(i + 1)
grads[i], grads[j] = grads[j], grads[i]
}
for i := 0; i < numGroups; i++ {
var gs []pserver.Gradient
if i == numGroups-1 {
gs = grads[i*paramPerGroup:]
} else {
gs = grads[i*paramPerGroup : (i+1)*paramPerGroup]
}
wg.Add(1)
go func(gs []pserver.Gradient) {
err = c.SendGrads(gs)
if err != nil {
t.Fatal(err)
}
wg.Done()
}(gs)
} }
names := make([]string, numParameter) names := make([]string, numParameter)
...@@ -146,20 +177,35 @@ func testClient(t *testing.T, c *client.Client) { ...@@ -146,20 +177,35 @@ func testClient(t *testing.T, c *client.Client) {
names[i] = "p_" + strconv.Itoa(i) names[i] = "p_" + strconv.Itoa(i)
} }
params, err := c.GetParams(names) for i := 0; i < numGroups; i++ {
if err != nil { var ns []string
t.Fatal(err) if i == numGroups-1 {
} ns = names[i*paramPerGroup:]
} else {
ns = names[i*paramPerGroup : (i+1)*paramPerGroup]
}
if len(names) != len(params) { wg.Add(1)
t.Fatalf("parameter size not match, need: %d, have: %d", len(names), len(params)) go func(ns []string) {
} params, err := c.GetParams(ns)
if err != nil {
t.Fatal(err)
}
for i := range params { if len(ns) != len(params) {
if names[i] != params[i].Name { t.Fatalf("parameter size not match, need: %d, have: %d", len(names), len(params))
t.Fatalf("order of returned parameter does not required: parameter name: %s, required name: %s", names[i], params[i].Name) }
}
for i := range params {
if ns[i] != params[i].Name {
t.Fatalf("order of returned parameter does not required: parameter name: %s, required name: %s", ns[i], params[i].Name)
}
}
wg.Done()
}(ns)
} }
wg.Wait()
} }
func TestNativeClient(t *testing.T) { func TestNativeClient(t *testing.T) {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册