diff --git a/go/.gitignore b/go/.gitignore index 000e1fd55b63b8e532308b787c2708a6c3e5ac87..398d70ca375ffceccdbfc82a4851a6830ca31264 100644 --- a/go/.gitignore +++ b/go/.gitignore @@ -1,2 +1,3 @@ vendor/ .glide/ +proto/*.go diff --git a/go/glide.lock b/go/glide.lock index ce654d36364f8078a493651d8d8b141532eea26d..d15fc934dbe511389cc92ce95cededa41ba32b4d 100644 --- a/go/glide.lock +++ b/go/glide.lock @@ -1,5 +1,5 @@ -hash: 51d9e2e46d7fd9173ff11ecada40f7b7728756be18d5e2f032535f66465e6e15 -updated: 2017-10-24T15:04:09.987751592-07:00 +hash: 107c058cf5c9163a75d40eef2273a793c36112683c25d72aa8288827fdde3a19 +updated: 2017-10-30T03:46:19.137696069Z imports: - name: github.com/alecthomas/gometalinter version: bae2f1293d092fd8167939d5108d1b025eaef9de diff --git a/go/glide.yaml b/go/glide.yaml index ba253f8bebef0ddab810a8303ab1fbe541defbdf..c5d66694acd0f45de5002391a7953b7491eaf2bc 100644 --- a/go/glide.yaml +++ b/go/glide.yaml @@ -30,3 +30,4 @@ import: version: v2.13 - package: github.com/go-stack/stack version: v1.6.0 +- package: github.com/golang/protobuf diff --git a/go/proto/.gitignore b/go/proto/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..5e7d2734cfc60289debf74293817c0a8f572ff32 --- /dev/null +++ b/go/proto/.gitignore @@ -0,0 +1,4 @@ +# Ignore everything in this directory +* +# Except this file +!.gitignore diff --git a/go/pserver/CMakeLists.txt b/go/pserver/CMakeLists.txt index 4fe0a8cb021e8dbf443c8f33bfb046e228a2fd8d..9ac05199e7ab76c21275838092c0afbdf2612b77 100644 --- a/go/pserver/CMakeLists.txt +++ b/go/pserver/CMakeLists.txt @@ -13,5 +13,5 @@ # limitations under the License. # if(WITH_TESTING) - go_test(pserver_test DEPS paddle_go_optimizer) + go_test(pserver_test DEPS paddle_go_optimizer gen_proto_go) endif() diff --git a/go/pserver/service.go b/go/pserver/service.go index f703d99a29ae9f5310ef36a7492b729c4c892937..7484ec90b1a3a9e67fa798741a9dfeb580c51f1a 100644 --- a/go/pserver/service.go +++ b/go/pserver/service.go @@ -17,6 +17,7 @@ package pserver import ( "bufio" "bytes" + "encoding/binary" "encoding/gob" "encoding/json" "errors" @@ -26,11 +27,15 @@ import ( "os" "path" "strconv" + "strings" "sync" "time" + "github.com/golang/protobuf/proto" uuid "github.com/satori/go.uuid" + pb "github.com/PaddlePaddle/Paddle/go/proto" + log "github.com/inconshreveable/log15" ) @@ -65,6 +70,46 @@ type Parameter struct { Content []byte } +func float32ToString(b []byte) string { + f := make([]float32, len(b)/4) + buf := bytes.NewReader(b) + err := binary.Read(buf, binary.LittleEndian, &f) + if err != nil { + return "" + } + return fmt.Sprintf("%v", f) +} + +func float32ByteToString(c []byte) string { + var a []byte + var b []byte + if len(c) <= 80 { + a = c + } else { + a = c[0:40] + b = c[len(c)-40:] + } + + var s string + s = float32ToString(a) + + if b == nil { + return s + } + + s = strings.Replace(s, "]", "", -1) + "..." + strings.Replace(float32ToString(b), "[", "", -1) + return s +} + +func (p Parameter) String() string { + if p.ElementType != Float32 { + return fmt.Sprintf("name:%v ElementType:%v", + p.Name, p.ElementType) + } + + return float32ByteToString(p.Content) +} + // ParameterWithConfig contains the parameter and the configuration. type ParameterWithConfig struct { Param Parameter @@ -189,7 +234,9 @@ func (s *Service) InitParam(paramWithConfigs ParameterWithConfig, _ *int) error default: } - // TODO(helin): parse parameter config + c := &pb.OptimizerConfig{} + proto.Unmarshal(paramWithConfigs.Config, c) + log.Debug(fmt.Sprintf("OptimizerConfig:%v", c)) s.mu.Lock() defer s.mu.Unlock() @@ -239,7 +286,8 @@ func (s *Service) SendGrad(g Gradient, _ *int) error { select { case <-s.initialized: default: - log.Warn("received gradient before initialization.", "name", g.Name, "size", len(g.Content), "type", g.ElementType) + log.Warn("received gradient before initialization.", + "name", g.Name, "size", len(g.Content), "type", g.ElementType) return errors.New(Uninitialized) } @@ -248,10 +296,14 @@ func (s *Service) SendGrad(g Gradient, _ *int) error { o, ok := s.optMap[g.Name] if !ok { + log.Warn("received gradient but can't find name.", + "name", g.Name, "size", len(g.Content), "type", g.ElementType) return fmt.Errorf("parameter: %s does not exist", g.Name) } - log.Info("received gradient from trainer, updating gradient.", "name", g.Name, "size", len(g.Content), "type", g.ElementType) + log.Debug(Parameter(g).String()) + log.Info("received gradient from trainer, updating gradient.", + "name", g.Name, "size", len(g.Content), "type", g.ElementType) return o.UpdateParameter(g) } @@ -277,7 +329,7 @@ func (s *Service) GetParam(name string, parameter *Parameter) error { parameter.Name = name parameter.ElementType = opt.elementType parameter.Content = opt.GetWeights() - + log.Debug(parameter.String()) log.Info("sending parameter to the trainer", "name", parameter.Name, "size", len(parameter.Content), "type", parameter.ElementType) return nil } diff --git a/go/pserver/service_test.go b/go/pserver/service_test.go index b6f4566eb78cf797e3738afa5f86f5c4e8090d85..58a743e1fadff9d629f682d660e661013c33ac8a 100644 --- a/go/pserver/service_test.go +++ b/go/pserver/service_test.go @@ -15,6 +15,7 @@ package pserver_test import ( + "fmt" "io/ioutil" "reflect" "sync" @@ -178,3 +179,33 @@ func TestBlockUntilInitialized(t *testing.T) { wg.Wait() } + +func TestGradientString(t *testing.T) { + g := pserver.Parameter{} + g.ElementType = pserver.Float32 + g.Content = []byte{0x18, 0x2d, 0x44, 0x54, 0xfb, 0x21, 0x09, 0x40, 0x18, 0x2d, 0x44, 0x54, 0xfb, 0x21, 0x09, 0x40} + if g.String() != "[3.3702806e+12 2.142699 3.3702806e+12 2.142699]" { + t.Fatal("get float data error!") + } + + g.Content = []byte{0x18, 0x2d, 0x44, 0x54, 0xfb, 0x21, 0x09, 0x40, + 0x18, 0x2d, 0x44, 0x54, 0xfb, 0x21, 0x09, 0x40, + 0x18, 0x2d, 0x44, 0x54, 0xfb, 0x21, 0x09, 0x40, + 0x18, 0x2d, 0x44, 0x54, 0xfb, 0x21, 0x09, 0x40, + 0x18, 0x2d, 0x44, 0x54, 0xfb, 0x21, 0x09, 0x40, + 0x18, 0x2d, 0x44, 0x54, 0xfb, 0x21, 0x09, 0x40, + 0x18, 0x2d, 0x44, 0x54, 0xfb, 0x21, 0x09, 0x40, + 0x18, 0x2d, 0x44, 0x54, 0xfb, 0x21, 0x09, 0x40, + 0x18, 0x2d, 0x44, 0x54, 0xfb, 0x21, 0x09, 0x40, + 0x18, 0x2d, 0x44, 0x54, 0xfb, 0x21, 0x09, 0x40, + 0x18, 0x2d, 0x44, 0x54, 0xfb, 0x21, 0x09, 0x40, + 0x18, 0x2d, 0x44, 0x54, 0xfb, 0x21, 0x09, 0x40, + 0x18, 0x2d, 0x44, 0x54, 0xfb, 0x21, 0x09, 0x40, + 0x18, 0x2d, 0x44, 0x54, 0xfb, 0x21, 0x09, 0x40, + 0x18, 0x2d, 0x44, 0x54, 0xfb, 0x21, 0x09, 0x40, + 0x18, 0x2d, 0x44, 0x54, 0xfb, 0x21, 0x09, 0x40} + if g.String() != "[3.3702806e+12 2.142699 3.3702806e+12 2.142699 3.3702806e+12 2.142699 3.3702806e+12 2.142699 3.3702806e+12 2.142699...3.3702806e+12 2.142699 3.3702806e+12 2.142699 3.3702806e+12 2.142699 3.3702806e+12 2.142699 3.3702806e+12 2.142699]" { + t.Fatal("get float data error!", g.String()) + } + fmt.Println(g) +} diff --git a/proto/CMakeLists.txt b/proto/CMakeLists.txt index 5d898d860cfc6dc26eaf5a81d8aed6d757ed5831..556bcd1d7e60c27fece43de666e9531ab4203414 100644 --- a/proto/CMakeLists.txt +++ b/proto/CMakeLists.txt @@ -27,3 +27,30 @@ foreach(filename ${proto_filenames}) endforeach() add_custom_target(gen_proto_py ALL DEPENDS ${PROTO_GEN_PY}) + + +if (WITH_GOLANG) + add_custom_target(protoc-gen-go) + add_custom_command(TARGET protoc-gen-go + COMMAND go + ARGS "get" "-u" "github.com/golang/protobuf/protoc-gen-go") + + set(PROTO_GEN_GO) + file(GLOB proto_filenames . OptimizerConfig.proto) + foreach(filename ${proto_filenames}) + message(STATUS ${filename}) + get_filename_component(ABS_FIL ${filename} ABSOLUTE) + get_filename_component(FIL_WE ${filename} NAME_WE) + set(CUR_PROTO_GEN_GO + ${PADDLE_SOURCE_DIR}/paddle/go/proto/${FIL_WE}.pb.go) + set(PROTO_GEN_GO + ${CUR_PROTO_GEN_GO} + ${PROTO_GEN_GO}) + add_custom_command(OUTPUT ${CUR_PROTO_GEN_GO} + COMMAND ${PROTOBUF_PROTOC_EXECUTABLE} + ARGS "--go_out=${PADDLE_SOURCE_DIR}/go/proto" + "-I" ${CMAKE_CURRENT_SOURCE_DIR} ${ABS_FIL} + DEPENDS ${ABS_FIL} protoc protoc-gen-go) + endforeach() + add_custom_target(gen_proto_go ALL DEPENDS ${PROTO_GEN_GO}) +endif() diff --git a/python/paddle/v2/trainer.py b/python/paddle/v2/trainer.py index b68fd0d5a97a7993ddd0a1d947304fa5428c01b8..db01ab7374eca18b6063dc634da5ef83c4bc9adc 100644 --- a/python/paddle/v2/trainer.py +++ b/python/paddle/v2/trainer.py @@ -205,7 +205,8 @@ class SGD(object): """ Testing method. Will test input data. - :param reader: A reader that reads and yeilds data items. + :param reader: A batch reader that reads and yeilds data items, + it should be a paddle.v2.batch. :type reader: collections.Iterable :param feeding: Feeding is a map of neural network input name and array index that reader returns.