提交 470fb7c5 编写于 作者: T tangwei12

bug fix

上级 60dda7bf
......@@ -36,15 +36,15 @@ class SamplingIdKernel : public framework::OpKernel<T> {
std::vector<T> ins_vector;
framework::TensorToVector(*input, context.device_context(), &ins_vector);
unsigned int seed = static_cast<unsigned int>(ctx.Attr<int>("seed"));
unsigned int seed = static_cast<unsigned int>(context.Attr<int>("seed"));
std::minstd_rand engine;
if (seed == 0) {
seed = std::random_device()();
}
engine.seed(seed);
std::uniform_real_distribution<T> dist(
static_cast<T>(ctx.Attr<float>("min")),
static_cast<T>(ctx.Attr<float>("max")));
static_cast<T>(context.Attr<float>("min")),
static_cast<T>(context.Attr<float>("max")));
std::vector<T> ids(batch_size);
for (size_t i = 0; i < batch_size; ++i) {
......
......@@ -39,7 +39,7 @@ namespace operators {
using Tensor = framework::Tensor;
template <typename T>
class SamplingIdKernel : public framework::OpKernel<T> {
class SamplingIdGPUKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
const Tensor* input = context.Input<Tensor>("X");
......@@ -83,5 +83,6 @@ class SamplingIdKernel : public framework::OpKernel<T> {
} // namespace operators
} // namespace paddle
REGISTER_OP_CPU_KERNEL(sampling_id, paddle::operators::SamplingIdKernel<float>,
paddle::operators::SamplingIdKernel<double>);
REGISTER_OP_CUDA_KERNEL(sampling_id,
paddle::operators::SamplingIdGPUKernel<float>,
paddle::operators::SamplingIdGPUKernel<double>);
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册