提交 470fb7c5 编写于 作者: T tangwei12

bug fix

上级 60dda7bf
...@@ -36,15 +36,15 @@ class SamplingIdKernel : public framework::OpKernel<T> { ...@@ -36,15 +36,15 @@ class SamplingIdKernel : public framework::OpKernel<T> {
std::vector<T> ins_vector; std::vector<T> ins_vector;
framework::TensorToVector(*input, context.device_context(), &ins_vector); framework::TensorToVector(*input, context.device_context(), &ins_vector);
unsigned int seed = static_cast<unsigned int>(ctx.Attr<int>("seed")); unsigned int seed = static_cast<unsigned int>(context.Attr<int>("seed"));
std::minstd_rand engine; std::minstd_rand engine;
if (seed == 0) { if (seed == 0) {
seed = std::random_device()(); seed = std::random_device()();
} }
engine.seed(seed); engine.seed(seed);
std::uniform_real_distribution<T> dist( std::uniform_real_distribution<T> dist(
static_cast<T>(ctx.Attr<float>("min")), static_cast<T>(context.Attr<float>("min")),
static_cast<T>(ctx.Attr<float>("max"))); static_cast<T>(context.Attr<float>("max")));
std::vector<T> ids(batch_size); std::vector<T> ids(batch_size);
for (size_t i = 0; i < batch_size; ++i) { for (size_t i = 0; i < batch_size; ++i) {
......
...@@ -39,7 +39,7 @@ namespace operators { ...@@ -39,7 +39,7 @@ namespace operators {
using Tensor = framework::Tensor; using Tensor = framework::Tensor;
template <typename T> template <typename T>
class SamplingIdKernel : public framework::OpKernel<T> { class SamplingIdGPUKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& context) const override { void Compute(const framework::ExecutionContext& context) const override {
const Tensor* input = context.Input<Tensor>("X"); const Tensor* input = context.Input<Tensor>("X");
...@@ -83,5 +83,6 @@ class SamplingIdKernel : public framework::OpKernel<T> { ...@@ -83,5 +83,6 @@ class SamplingIdKernel : public framework::OpKernel<T> {
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle
REGISTER_OP_CPU_KERNEL(sampling_id, paddle::operators::SamplingIdKernel<float>, REGISTER_OP_CUDA_KERNEL(sampling_id,
paddle::operators::SamplingIdKernel<double>); paddle::operators::SamplingIdGPUKernel<float>,
paddle::operators::SamplingIdGPUKernel<double>);
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册