提交 6acca6f7 编写于 作者: B barrierye

add var shape

上级 0f67439c
......@@ -410,7 +410,11 @@ class GClient(object):
if self.feed_types_[var.alias_name] == 'float':
self.feed_types_[var.alias_name] = 'float32'
if var.is_lod_tensor:
self.lod_tensor_set_.add(var.alias_name)
self.lod_tensor_set.add(var.alias_name)
else:
counter = 1
for dim in self.feed_shapes_[var.alias_name]:
counter *= dim
for i, var in enumerate(model_conf.fetch_var):
self.fetch_types_[var.alias_name] = var.fetch_type
if self.fetch_types_[var.alias_name] == 'float':
......@@ -435,6 +439,14 @@ class GClient(object):
itype = self.type_map_[self.feed_types_[name]]
data = np.array(var, dtype=itype)
inst.data.append(data.tobytes())
if isinstance(var, np.ndarray):
inst.shape.append(
np.array(
list(var.shape), dtype="int32").tobytes())
else:
inst.shape.append(
np.array(
self.feed_shapes_[name], dtype="int32").tobytes())
req.feed_insts.append(inst)
return req
......
......@@ -481,8 +481,10 @@ class GServerService(
feed_dict = {}
for idx, name in enumerate(feed_inst.names):
data = feed_inst.data[idx]
shape = feed_inst.shape[idx]
itype = self.type_map_[self.feed_types_[name]]
feed_dict[name] = np.frombuffer(data, dtype=itype)
feed_dict[name].shape = np.frombuffer(shape, dtype="int32")
feed_batch.append(feed_dict)
return feed_batch, fetch_names
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册