From 2ae37a4ea2f4b02ffe6b773590ed05c77675e6f5 Mon Sep 17 00:00:00 2001 From: chengduoZH Date: Thu, 31 Aug 2017 00:28:01 +0800 Subject: [PATCH] fix data_layer for 3D data --- python/paddle/trainer_config_helpers/layers.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/python/paddle/trainer_config_helpers/layers.py b/python/paddle/trainer_config_helpers/layers.py index 42bf1c19d1..2aa86850d1 100644 --- a/python/paddle/trainer_config_helpers/layers.py +++ b/python/paddle/trainer_config_helpers/layers.py @@ -926,16 +926,18 @@ def data_layer(name, size, height=None, width=None, depth=None, type=LayerType.DATA, name=name, size=size, + depth=depth, height=height, width=width, - depth=depth, **ExtraLayerAttribute.to_kwargs(layer_attr)) + if depth is None: + depth = 1 num_filters = None if height is not None and width is not None: - num_filters = size / (width * height) - assert num_filters * width * height == size, \ - "size=%s width=%s height=%s" % (size, width, height) + num_filters = size / (width * height * depth) + assert num_filters * width * height*depth == size, \ + "size=%s width=%s height=%s depth=%s" % (size, width, height, depth) return LayerOutput(name, LayerType.DATA, size=size, num_filters=num_filters) -- GitLab