提交 12e9bf6c 编写于 作者: X Xin Pan

clean up

上级 ab72d28a
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. Licensed under
the Apache License, Version 2.0 (the "License"); you may not use this file
except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
syntax = "proto3";
package sendrecv;
option cc_generic_services = false;
service SendRecvService {
// For parameter server round-robin like hashing, do not split tensors.
// Send and recv only one tensor
// TODO(typhoonzero): add streaming API
rpc SendVariable(VariableMessage) returns (VoidMessage) {}
// Argument VariableMessage for GetVariable should only contain varname.
rpc GetVariable(VariableMessage) returns (VariableMessage) {}
// pre-fetch variable by given variable name and Ids
rpc PrefetchVariable(VariableMessage) returns (VariableMessage) {}
rpc CheckpointNotify(VariableMessage) returns (VoidMessage) {}
}
// VariableMessage is serialized paddle variable message.
// It can be:
// LoDTensor
// SelectedRows
enum VarType {
LOD_TENSOR = 0;
SELECTED_ROWS = 1;
NCCL_ID = 2;
}
// NOTICE(gongwb):don't modify this proto if you are not
// not familar with how we serialize in sendrecvop_utils.h
// and deserilize it in variable_response.h.
message VariableMessage {
enum Type {
// Pod Types
BOOL = 0;
INT16 = 1;
INT32 = 2;
INT64 = 3;
FP16 = 4;
FP32 = 5;
FP64 = 6;
}
message LodData { repeated int64 lod_data = 1; }
string varname = 1;
// TODO(Yancey1989): reference framework::proto::VarDesc::VarType
VarType type = 2;
// bool persistable is not needed for sending.
// tensor info:
Type data_type = 3;
repeated int64 dims = 4;
// lod details:
int64 lod_level = 5;
repeated LodData lod = 6;
// selected_rows height, aka. original dim0
int64 slr_height = 7;
// tensor data
bytes serialized = 8;
// selected_rows data
bytes rows = 9;
// Look up table block execution output variable name.
string out_varname = 10;
// If 1, the ps server will start profiling, the ps
// server stops profiling and generates a profile to /tmp/profile_ps_*
// when profile switches from 1 to 2.
int64 profile = 11;
}
message VoidMessage {}
......@@ -71,7 +71,6 @@ class TestParallelExecutorBase(unittest.TestCase):
exec_strategy.allow_op_delay = allow_op_delay
build_strategy = fluid.BuildStrategy()
build_strategy.debug_graphviz_path = "/tmp/graphviz"
build_strategy.reduce_strategy = fluid.BuildStrategy.ReduceStrategy.Reduce \
if use_reduce else fluid.BuildStrategy.ReduceStrategy.AllReduce
......
......@@ -152,6 +152,16 @@ class TestMNIST(TestParallelExecutorBase):
use_cuda=use_cuda,
use_reduce=use_reduce)
def test_simple_fc(self):
# use_cuda
self.check_simple_fc_convergence(True)
self.check_simple_fc_convergence(False)
def test_simple_fc_with_new_strategy(self):
# use_cuda, use_reduce
self._compare_reduce_and_allreduce(simple_fc_net, True)
self._compare_reduce_and_allreduce(simple_fc_net, False)
def check_simple_fc_parallel_accuracy(self, use_cuda):
if use_cuda and not core.is_compiled_with_cuda():
return
......@@ -178,6 +188,10 @@ class TestMNIST(TestParallelExecutorBase):
for p_l in parallel_last_loss:
self.assertAlmostEquals(p_l, single_last_loss[0], delta=1e-6)
def test_simple_fc_parallel_accuracy(self):
self.check_simple_fc_parallel_accuracy(True)
self.check_simple_fc_parallel_accuracy(False)
def check_batchnorm_fc_convergence(self, use_cuda):
if use_cuda and not core.is_compiled_with_cuda():
return
......@@ -192,31 +206,13 @@ class TestMNIST(TestParallelExecutorBase):
"label": label},
use_cuda=use_cuda)
def check_batchnorm_fc_convergence_use_reduce(self, use_cuda):
if use_cuda and not core.is_compiled_with_cuda():
return
self.check_network_convergence(
fc_with_batchnorm, use_cuda=use_cuda, use_reduce=False)
"""
img, label = self._init_data()
all_reduce_first_loss, all_reduce_last_loss = self.check_network_convergence(
fc_with_batchnorm,
feed_dict={"image": img,
"label": label},
use_cuda=use_cuda,
use_reduce=False)
reduce_first_loss, reduce_last_loss = self.check_network_convergence(
fc_with_batchnorm,
feed_dict={"image": img,
"label": label},
use_cuda=use_cuda,
use_reduce=True)
"""
def test_batchnorm_fc(self):
self.check_batchnorm_fc_convergence(True)
self.check_batchnorm_fc_convergence(False)
def test_batchnorm_fc_with_new_strategy(self):
self.check_batchnorm_fc_convergence_use_reduce(True)
# self.check_batchnorm_fc_convergence_use_reduce(False)
self._compare_reduce_and_allreduce(fc_with_batchnorm, True)
self._compare_reduce_and_allreduce(fc_with_batchnorm, False)
if __name__ == '__main__':
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册
新手
引导
客服 返回
顶部