# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import paddle.fluid as fluid
import numpy as np


def simple_fc_net(use_feed=None):
    """Build a small fully-connected classifier and return its mean loss.

    The network declares an ``image`` data layer (shape ``[784]``,
    ``float32``) and a ``label`` data layer (shape ``[1]``, ``int64``),
    stacks four ReLU ``fc`` layers of size 200 whose biases are initialized
    to the constant 1.0, and finishes with a size-10 softmax ``fc`` layer.

    Args:
        use_feed: Accepted but not read by this function.

    Returns:
        The result of ``fluid.layers.mean`` applied to the cross-entropy
        between the softmax prediction and ``label``.
    """
    img = fluid.layers.data(name='image', shape=[784], dtype='float32')
    label = fluid.layers.data(name='label', shape=[1], dtype='int64')

    hidden = img
    for _ in range(4):
        # A fresh ParamAttr is built for every layer, as in the original.
        hidden = fluid.layers.fc(
            hidden,
            size=200,
            act='relu',
            bias_attr=fluid.ParamAttr(
                initializer=fluid.initializer.Constant(value=1.0)))

    prediction = fluid.layers.fc(hidden, size=10, act='softmax')
    loss = fluid.layers.cross_entropy(input=prediction, label=label)
    loss = fluid.layers.mean(loss)
    return loss


def fc_with_batchnorm(use_feed=None):
    """Build a fully-connected classifier with batch norm; return mean loss.

    The network declares the same ``image`` / ``label`` data layers as
    ``simple_fc_net``, then applies two rounds of (ReLU ``fc`` of size 200
    with bias initialized to 1.0, followed by ``batch_norm``), and finishes
    with a size-10 softmax ``fc`` layer.

    Args:
        use_feed: Accepted but not read by this function.

    Returns:
        The result of ``fluid.layers.mean`` applied to the cross-entropy
        between the softmax prediction and ``label``.
    """
    img = fluid.layers.data(name='image', shape=[784], dtype='float32')
    label = fluid.layers.data(name='label', shape=[1], dtype='int64')

    hidden = img
    for _ in range(2):
        hidden = fluid.layers.fc(
            hidden,
            size=200,
            act='relu',
            bias_attr=fluid.ParamAttr(
                initializer=fluid.initializer.Constant(value=1.0)))
        # Normalization is applied to the output of each hidden fc layer.
        hidden = fluid.layers.batch_norm(input=hidden)

    prediction = fluid.layers.fc(hidden, size=10, act='softmax')
    loss = fluid.layers.cross_entropy(input=prediction, label=label)
    loss = fluid.layers.mean(loss)
    return loss


def init_data(batch_size=32, img_shape=None, label_range=9):
    """Generate a reproducible random (image, label) batch.

    Note that this reseeds NumPy's global random state with the fixed seed
    5 on every call, so repeated calls with the same arguments return the
    same data.

    Args:
        batch_size: Number of samples in the batch.
        img_shape: Per-sample image shape as a ``list``. ``None`` (the
            default) means ``[784]``.
        label_range: Exclusive upper bound passed to ``np.random.randint``;
            labels are drawn from ``[0, label_range)``.

    Returns:
        A tuple ``(img, label)`` where ``img`` is a ``float32`` array of
        shape ``[batch_size] + img_shape`` filled by ``np.random.random``
        and ``label`` is an ``int64`` array of shape ``(batch_size, 1)``.

    Raises:
        AssertionError: If ``img_shape`` is not a ``list``.
    """
    # Resolve the default here instead of using a mutable list as the
    # default argument value, which would be shared across calls.
    if img_shape is None:
        img_shape = [784]

    # Seed before validating to keep the original order of side effects.
    np.random.seed(5)
    # Kept as an assert so callers still see AssertionError on bad input.
    assert isinstance(img_shape, list), "img_shape must be a list"

    input_shape = [batch_size] + img_shape
    img = np.random.random(size=input_shape).astype(np.float32)
    # Labels are drawn one at a time so the sequence of random draws (and
    # therefore the generated data) matches the original implementation.
    label = np.array(
        [np.random.randint(0, label_range)
         for _ in range(batch_size)]).reshape((-1, 1)).astype("int64")
    return img, label