From 777a5cca91dcc9617e85be4be037534040f3fbc7 Mon Sep 17 00:00:00 2001 From: Helin Wang Date: Thu, 13 Jul 2017 20:07:26 +0000 Subject: [PATCH] Client test: concurrently init param. Concurrently send grad and get param --- go/pserver/client/client_test.go | 90 ++++++++++++++++++++++++-------- 1 file changed, 68 insertions(+), 22 deletions(-) diff --git a/go/pserver/client/client_test.go b/go/pserver/client/client_test.go index 21dac92417f..27f4ff2380b 100644 --- a/go/pserver/client/client_test.go +++ b/go/pserver/client/client_test.go @@ -3,11 +3,13 @@ package client_test import ( "context" "io/ioutil" + "math/rand" "net" "net/http" "net/rpc" "strconv" "strings" + "sync" "testing" "time" @@ -111,16 +113,23 @@ func testClient(t *testing.T, c *client.Client) { if err != nil { t.Fatalf("read optimizer proto failed") } + + var wg sync.WaitGroup for i := 0; i < numParameter; i++ { - var p pserver.Parameter - p.Name = "p_" + strconv.Itoa(i) - p.ElementType = pserver.Float32 - p.Content = make([]byte, (i+1)*100) - err := c.InitParam(pserver.ParameterWithConfig{Param: p, Config: config}) - if err != nil { - t.Fatal(err) - } + wg.Add(1) + go func(i int) { + var p pserver.Parameter + p.Name = "p_" + strconv.Itoa(i) + p.ElementType = pserver.Float32 + p.Content = make([]byte, (i+1)*100) + err := c.InitParam(pserver.ParameterWithConfig{Param: p, Config: config}) + if err != nil { + t.Fatal(err) + } + wg.Done() + }(i) } + wg.Wait() err = c.FinishInitParams() if err != nil { @@ -136,9 +145,31 @@ func testClient(t *testing.T, c *client.Client) { grads = append(grads, g) } - err = c.SendGrads(grads) - if err != nil { - t.Fatal(err) + const paramPerGroup = 10 + const numGroups = numParameter / paramPerGroup + + // shuffle send grads order + for i := range grads { + j := rand.Intn(i + 1) + grads[i], grads[j] = grads[j], grads[i] + } + + for i := 0; i < numGroups; i++ { + var gs []pserver.Gradient + if i == numGroups-1 { + gs = grads[i*paramPerGroup:] + } else { + gs = grads[i*paramPerGroup : (i+1)*paramPerGroup] + } + + wg.Add(1) + go func(gs []pserver.Gradient) { + err = c.SendGrads(gs) + if err != nil { + t.Fatal(err) + } + wg.Done() + }(gs) } names := make([]string, numParameter) @@ -146,20 +177,35 @@ func testClient(t *testing.T, c *client.Client) { names[i] = "p_" + strconv.Itoa(i) } - params, err := c.GetParams(names) - if err != nil { - t.Fatal(err) - } + for i := 0; i < numGroups; i++ { + var ns []string + if i == numGroups-1 { + ns = names[i*paramPerGroup:] + } else { + ns = names[i*paramPerGroup : (i+1)*paramPerGroup] + } - if len(names) != len(params) { - t.Fatalf("parameter size not match, need: %d, have: %d", len(names), len(params)) - } + wg.Add(1) + go func(ns []string) { + params, err := c.GetParams(ns) + if err != nil { + t.Fatal(err) + } - for i := range params { - if names[i] != params[i].Name { - t.Fatalf("order of returned parameter does not required: parameter name: %s, required name: %s", names[i], params[i].Name) - } + if len(ns) != len(params) { + t.Fatalf("parameter size not match, need: %d, have: %d", len(names), len(params)) + } + + for i := range params { + if ns[i] != params[i].Name { + t.Fatalf("order of returned parameter does not required: parameter name: %s, required name: %s", ns[i], params[i].Name) + } + } + wg.Done() + }(ns) } + + wg.Wait() } func TestNativeClient(t *testing.T) { -- GitLab