How do I batch-sample from a normal distribution in dynamic graph mode?
Created by: dbsxdbsx
PaddlePaddle version: 1.8.3, CUDA 10, cuDNN 7.6. While porting the SAC algorithm from PARL's reinforcement-learning code to dynamic graph mode, I came across the following:
def sample(self, obs):
    mean, log_std = self.actor(obs)  # mean: [1, 1], obs: [1, 3]
    std = layers.exp(log_std)  # TODO: why is exp needed here?
    # mean = layers.squeeze(mean, axes=[1])
    # std = layers.squeeze(std, axes=[1])
    normal = self.distribution(mean, std)
    # _size = mean.shape[0]
    output = normal.sample([1])  # output shape: [inputSize, 1, 1]
    x_t = output[0]
    y_t = layers.tanh(x_t)
    action = y_t * self.act_lim
    log_prob = normal.log_prob(x_t)
    log_prob -= layers.log(self.act_lim * (1 - layers.pow(y_t, 2)) + epsilon)
    # log_prob = fluid.layers.unsqueeze(input=log_prob, axes=[1])
    log_prob = layers.reduce_sum(log_prob, dim=1, keep_dim=True)
    log_prob = layers.squeeze(log_prob, axes=[1])
    return action, log_prob
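On the `TODO` above: a common convention (used by SAC implementations generally, though I cannot confirm PARL's exact reasoning) is that the actor network outputs log σ rather than σ, because an unconstrained network output can be negative while σ must be positive; `exp` maps it back. The tanh squashing then requires a change-of-variables correction to the log-probability. Here is a plain NumPy sketch of those two steps; the names `batch`, `act_dim`, `act_lim`, and `epsilon` are illustrative assumptions, and the Gaussian log-density is written out by hand rather than taken from any Paddle API:

```python
import numpy as np

# Illustrative shapes: batch of 4 observations, 1-dimensional action.
batch, act_dim = 4, 1
act_lim, epsilon = 1.0, 1e-6

mean = np.zeros((batch, act_dim))
log_std = np.full((batch, act_dim), -0.5)  # network outputs log(sigma)
std = np.exp(log_std)                      # exp recovers sigma > 0

# One sample per distribution via reparameterization: x = mu + sigma * eps.
x_t = mean + std * np.random.randn(batch, act_dim)
y_t = np.tanh(x_t)
action = y_t * act_lim                     # squashed into [-act_lim, act_lim]

# Elementwise Gaussian log-density log N(x_t | mean, std).
log_prob = -0.5 * (((x_t - mean) / std) ** 2 + 2 * log_std + np.log(2 * np.pi))
# tanh change-of-variables correction, mirroring the line in the PARL code.
log_prob -= np.log(act_lim * (1 - y_t ** 2) + epsilon)
log_prob = log_prob.sum(axis=1, keepdims=True)  # shape (batch, 1)
```

Note that `action` and `log_prob` keep the leading batch dimension throughout, which is exactly the shape behavior the batched version needs.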
The code above, from PARL (https://github.com/PaddlePaddle/PARL/blob/develop/parl/algorithms/fluid/sac.py), can sample in batch through the normal distribution. But in my dynamic graph model, when mean and log_std have shape [batchsize, 1], the call to normal.sample([1]) raises an error. When debugging PARL's SAC model, I found that once execution reaches the learning step, agent.learn(batch_obs, batch_action, batch_reward, batch_next_obs, batch_terminal), no print output appears, so I still don't understand how to batch-sample from a normal distribution in dynamic graph mode.
To clarify: by "batch sample" I mean that when mean and log_std have shape [batchsize, 1], I need to draw one sample from each of the batchsize normal distributions.
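One way to sidestep `normal.sample([1])` entirely is the reparameterization trick: draw standard-normal noise of shape [batchsize, 1] (e.g. with NumPy, then wrap it as a dygraph variable) and compute mean + std * eps, which yields exactly one sample per row-wise distribution. The sketch below demonstrates the idea in plain NumPy under assumed values; it is not PARL's or Paddle's implementation:

```python
import numpy as np

# Assumed batch of per-row Gaussians: mean/std of shape (batchsize, 1),
# matching the shapes described in the question.
batchsize = 5
mean = np.arange(batchsize, dtype=np.float64).reshape(batchsize, 1)  # means 0..4
std = np.exp(np.full((batchsize, 1), -2.0))  # small sigma = exp(-2) per row

# Reparameterization: x_i = mean_i + std_i * eps_i with eps_i ~ N(0, 1).
rng = np.random.default_rng(0)  # fixed seed for reproducibility
eps = rng.standard_normal((batchsize, 1))
samples = mean + std * eps  # shape (batchsize, 1): one draw per distribution

# With sigma = exp(-2) ~ 0.135, each sample stays near its own row's mean.
```

Because the noise is generated outside the distribution object, this formulation avoids whatever shape restriction `sample` imposes in dygraph mode, and it keeps the sample differentiable with respect to mean and std, which SAC's actor update relies on.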