测试一个函数出错
Created by: sguuaa
写了一个函数,想测试一下这个函数,程序如下。
import numpy as np
import paddle
import paddle.fluid as fluid
def fun(prediction, gt):
valid_mask = (gt >=0).astype('float32')
gt_mask = []
for i in range(0, 3):
temp1 = np.array((gt == i).astype('float32'))
# print(valid_mask.shape)
temp2 = np.array(valid_mask)
# print(temp2.shape)
temp2 = temp2[:, i, :, :]
temp = fluid.layers.elementwise_mul(temp1, temp2)
gt_mask.append(temp)
prediction = fluid.layers.elementwise_mul(prediction, valid_mask)
loss = fluid.layers.elementwise_mul(prediction.clamp(min=1e-7).log(), gt_mask).sum()
return loss
if __name__ == '__main__':
prediction = fluid.layers.data(name='prediction', shape=[3, 10, 10], dtype='float32')
gt = fluid.layers.data(name = 'gt', shape = [1, 10, 10], dtype='float32')
cost = fun(prediction, gt)
def create_reader():
def reader():
prediction = np.random.rand(3, 10, 10)
gt = np.random.rand(1, 10, 10) * 3
yield prediction, gt
return reader
db = paddle.batch(create_reader(), batch_size = 2)
place = fluid.CPUPlace()
feeder = fluid.DataFeeder(place = place, feed_list = [prediction, gt])
fetch_list = [cost]
exe = fluid.Executor(place)
for idx, data in enumerate(db()):
print('data:',data)
loss = exe.run(feed = feeder.feed(data), fetch_list = fetch_list)
print(loss)
break