Created by: pangyoki
PR types
Bug fixes
PR changes
APIs
Describe
Fix the sample
method of Uniform
and Normal
class .
For Uniform
class:
The shape of inputs of elementwise_add
op in sample
method is not matching. In detail, in the following code:
output = uniform_random_tmp * (zero_tmp + self.high - self.low) + self.low
the shape of zero_tmp is [batch_shape, sample_shape],
the shape of self.high is [batch_shape].
When zero_tmp and self.high doing elementwise_add
operation, it will cause "Broadcast dimension dismatch".
To solve this problem, before doing elementwise_add
operation, reshape
zero_tmp. That is, change the shape of zero_tmp from [batch_shape, sample_shape] to [sample_shape, batch_shape].
At the same time, reshape uniform_random_tmp to [sample_shape, batch_shape].
For Normal
class, it has the same problem, using the reshape
method too.