optimizer_test.go 1.9 KB
Newer Older
D
dongzhihong 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14
// Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at

// http://www.apache.org/licenses/LICENSE-2.0

// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

H
Helin Wang 已提交
15 16
package pserver

D
dongzhihong 已提交
17
import (
18
	"encoding/binary"
D
dongzhihong 已提交
19
	"io/ioutil"
20
	"math"
D
dongzhihong 已提交
21
	"testing"
22 23

	"github.com/stretchr/testify/assert"
D
dongzhihong 已提交
24
)
H
Helin Wang 已提交
25

D
dongzhihong 已提交
26 27 28
func TestOptimizerCreateRelease(t *testing.T) {
	p := Parameter{
		Name:        "a",
D
dongzhihong 已提交
29
		ElementType: Int32,
D
dzhwinter 已提交
30
	}
D
dongzhihong 已提交
31
	p.Content = []byte{1, 3}
Q
Qiao Longfei 已提交
32
	config, err := ioutil.ReadFile("./client/c/test/testdata/optimizer.pb")
D
dongzhihong 已提交
33 34 35
	if err != nil {
		t.Fatalf("read optimizer proto failed")
	}
D
dongzhihong 已提交
36 37 38 39
	param := ParameterWithConfig{
		Param:  p,
		Config: config,
	}
D
dongzhihong 已提交
40
	o := newOptimizer(param, nil)
H
Helin Wang 已提交
41 42
	o.Cleanup()
}
43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78

// float32Bytes returns the 4-byte little-endian encoding of the IEEE-754
// bit pattern of float (as produced by math.Float32bits).
func float32Bytes(float float32) []byte {
	var buf [4]byte
	binary.LittleEndian.PutUint32(buf[:], math.Float32bits(float))
	return buf[:]
}

// TestOptimizerState checks that an optimizer's serialized state round-trips.
// A second optimizer built from that state must report the same state and the
// original weights, even when it is given different parameter content.
func TestOptimizerState(t *testing.T) {
	// NOTE(review): Content holds float32-encoded bytes while ElementType is
	// Int32; a float element type looks intended — confirm before changing.
	p := Parameter{
		Name:        "a",
		ElementType: Int32,
	}
	weights := float32Bytes(100)
	p.Content = weights
	config, err := ioutil.ReadFile("./client/c/test/testdata/optimizer.pb")
	if err != nil {
		// Report the underlying error so a missing/unreadable fixture is diagnosable.
		t.Fatalf("read optimizer proto failed: %v", err)
	}
	param := ParameterWithConfig{
		Param:  p,
		Config: config,
	}
	o := newOptimizer(param, nil)
	defer o.Cleanup()
	s := o.GetStates()

	// Replace the parameter content with a different value (300) and build a
	// new optimizer from the saved state. The assertions below expect the
	// weights to come from the state (100), not from the new content.
	param.Param.Content = float32Bytes(300)
	o1 := newOptimizer(param, s)
	defer o1.Cleanup()
	s1 := o1.GetStates()
	assert.Equal(t, s, s1)
	assert.Equal(t, weights, o.GetWeights())
	assert.Equal(t, weights, o1.GetWeights())
}