提交 d3574301 编写于 作者: C chengduozh

polish code

上级 950fa5c6
......@@ -195,27 +195,33 @@ class ResNet(fluid.dygraph.Layer):
return y, acc1, acc5
def init_data(batch_size=32, img_shape=[3, 224, 224], label_range=9):
assert isinstance(img_shape, list)
np.random.seed(5)
input_shape = [batch_size] + img_shape
img = np.random.random(size=input_shape).astype(np.float32)
label = np.array([1 for _ in range(batch_size)]).reshape(
(-1, 1)).astype("int64")
#label = np.array(
# [np.random.randint(0, label_range) for _ in range(batch_size)]).reshape(
# (-1, 1)).astype("int64")
return img, label
def train_resnet():
seed = 90
place = fluid.CUDAPlace(dygraph.parallel.Env().dev_id)
with fluid.dygraph.guard(place):
fluid.default_startup_program().random_seed = seed
fluid.default_main_program().random_seed = seed
np.random.seed(seed)
import random
random.seed = seed
strategy = dygraph.parallel.prepare_context()
resnet = ResNet("dist_resnet", class_dim=1000)
strategy = dygraph.parallel.ParallelStrategy()
strategy.nranks = dygraph.parallel.Env().nranks
strategy.local_rank = dygraph.parallel.Env().local_rank
strategy.trainer_endpoints = dygraph.parallel.Env().trainer_endpoints
strategy.current_endpoint = dygraph.parallel.Env().current_endpoint
resnet = dygraph.parallel.DataParallel(resnet, strategy)
#if strategy.nranks > 1:
# dygraph.parallel.prepare_context(strategy)
optimizer = optimizer_setting()
train_reader = paddle.batch(
train(
data_dir="/imagenet/ImageNet_resize/",
......@@ -225,6 +231,7 @@ def train_resnet():
drop_last=True)
steps_per_epoch = int(total_images / strategy.nranks / batch_size)
print("steps per eoch: %d" % steps_per_epoch)
for eop in range(epoch):
for step_id, data in enumerate(train_reader()):
if step_id == steps_per_epoch:
......@@ -234,10 +241,13 @@ def train_resnet():
continue
s_time = time.time()
dy_x_data = np.array(
[x[0].reshape(3, 224, 224) for x in data]).astype('float32')
y_data = np.array([x[1] for x in data]).astype('int64').reshape(
batch_size, 1)
#dy_x_data = np.array(
# [x[0].reshape(3, 224, 224) for x in data]).astype('float32')
#y_data = np.array([x[1] for x in data]).astype('int64').reshape(
# batch_size, 1)
dy_x_data, y_data = init_data()
print(np.sum(dy_x_data), np.sum(y_data))
img = to_variable(dy_x_data)
label = to_variable(y_data)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册