提交 4973e07b 编写于 作者: T tangwei12

Optimize the sampling_id op

上级 3206970b
......@@ -57,9 +57,11 @@ SamplingId Operator.
} // namespace paddle
// Short alias for paddle::operators, used by the registration macros below.
namespace ops = paddle::operators;
// NOTE(review): this listing looks like a rendered diff without +/- markers,
// so the CUDA registration and the operator/CPU registrations shown together
// here may belong to different revisions of the file — confirm against the
// repository.
//
// Passes sampling_id to REGISTER_OP_CUDA_KERNEL together with
// SamplingIdKernel instantiated for CUDADeviceContext and the element types
// float, double, int and int64_t.
REGISTER_OP_CUDA_KERNEL(
sampling_id,
ops::SamplingIdKernel<paddle::platform::CUDADeviceContext, float>,
ops::SamplingIdKernel<paddle::platform::CUDADeviceContext, double>,
ops::SamplingIdKernel<paddle::platform::CUDADeviceContext, int>,
ops::SamplingIdKernel<paddle::platform::CUDADeviceContext, int64_t>);
// Passes sampling_id to REGISTER_OPERATOR with its op class, its op maker and
// EmptyGradOpMaker (presumably meaning the op has no gradient — verify).
REGISTER_OPERATOR(sampling_id, ops::SamplingIdOp, ops::SamplingIdOpMaker,
paddle::framework::EmptyGradOpMaker);
// Passes sampling_id to REGISTER_OP_CPU_KERNEL together with SamplingIdKernel
// instantiated for CPUDeviceContext and int, int64_t, float and double.
REGISTER_OP_CPU_KERNEL(
sampling_id, ops::SamplingIdKernel<paddle::platform::CPUDeviceContext, int>,
ops::SamplingIdKernel<paddle::platform::CPUDeviceContext, int64_t>,
ops::SamplingIdKernel<paddle::platform::CPUDeviceContext, float>,
ops::SamplingIdKernel<paddle::platform::CPUDeviceContext, double>);
......@@ -30,11 +30,9 @@ class SamplingIdOp : public framework::OperatorWithKernel {
} // namespace paddle
// Short alias for paddle::operators, used by the registration macros below.
namespace ops = paddle::operators;
// NOTE(review): this listing looks like a rendered diff without +/- markers,
// so the operator/CPU registrations and the CUDA registration shown together
// here may belong to different revisions of the file — confirm against the
// repository.
//
// Passes sampling_id to REGISTER_OPERATOR with its op class, its op maker and
// EmptyGradOpMaker (presumably meaning the op has no gradient — verify).
REGISTER_OPERATOR(sampling_id, ops::SamplingIdOp, ops::SamplingIdOpMaker,
paddle::framework::EmptyGradOpMaker);
// Passes sampling_id to REGISTER_OP_CPU_KERNEL together with SamplingIdKernel
// instantiated for CPUDeviceContext and int, int64_t, float and double.
REGISTER_OP_CPU_KERNEL(
sampling_id, ops::SamplingIdKernel<paddle::platform::CPUDeviceContext, int>,
ops::SamplingIdKernel<paddle::platform::CPUDeviceContext, int64_t>,
ops::SamplingIdKernel<paddle::platform::CPUDeviceContext, float>,
ops::SamplingIdKernel<paddle::platform::CPUDeviceContext, double>);
// Passes sampling_id to REGISTER_OP_CUDA_KERNEL together with
// SamplingIdKernel instantiated for CUDADeviceContext and the element types
// float, double, int and int64_t.
REGISTER_OP_CUDA_KERNEL(
sampling_id,
ops::SamplingIdKernel<paddle::platform::CUDADeviceContext, float>,
ops::SamplingIdKernel<paddle::platform::CUDADeviceContext, double>,
ops::SamplingIdKernel<paddle::platform::CUDADeviceContext, int>,
ops::SamplingIdKernel<paddle::platform::CUDADeviceContext, int64_t>);
......@@ -15,30 +15,31 @@ limitations under the License. */
#include <random>
#include <vector>
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/op_registry.h"
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
// Kernel class for the sampling_id op, parameterized on the device context
// type and on the element type T of the input tensor "X".
//
// NOTE(review): this listing looks like a rendered diff without +/- markers.
// Two variants of the sampling loop are interleaved below (one using
// `batchSize`/`dim`/`buf`/`rand1_`, one using `batch_size`/`width`/
// `ins_vector`/`get_rand()`), and the body is cut off by a hunk header, so
// this text is not compilable as shown — confirm against the repository.
template <typename DeviceContext, typename T>
class SamplingIdKernel : public framework::OpKernel<T> {
/// Produces random floating-point values, uniformly distributed on [0, 1).
std::uniform_real_distribution<double> rand1_;
public:
// For every row of input "X", draws a random number r and selects a column
// index: entries of the row are subtracted from r one by one and the first
// column at which r becomes negative is chosen; if that never happens the
// last column is used.
void Compute(const framework::ExecutionContext& context) const override {
const Tensor* input = context.Input<Tensor>("X");
// dims()[0] is used as the row count and dims()[1] as the row length.
const int batch_size = static_cast<int>(input->dims()[0]);
const int width = static_cast<int>(input->dims()[1]);
// NOTE(review): `batchSize` is not declared in the visible code (only
// `batch_size` is); this line and `get()` below look like the pre-change
// variant — confirm.
std::vector<int> ids(batchSize);
// NOTE(review): get() is declared non-const further down while Compute is
// const.
auto& reng = get();
// Flat copy of the input values; presumably TensorToVector copies the tensor
// contents (possibly from device memory) into ins_vector — verify.
std::vector<T> ins_vector;
framework::TensorToVector(*input, context.device_context(), &ins_vector);
// NOTE(review): `dim` and `buf` used in this loop variant are not declared in
// the visible code; the loop is left unclosed before the next variant starts.
for (size_t i = 0; i < batchSize; ++i) {
double r = rand1_(reng);
int id = dim - 1;
for (int j = 0; j < dim; ++j) {
if ((r -= buf[i * dim + j]) < 0) {
// One selected column index per row.
std::vector<int> ids(batch_size);
for (size_t i = 0; i < batch_size; ++i) {
// Fresh random number for this row.
double r = this->get_rand();
// Fallback: last column, used when r never drops below zero.
int id = width - 1;
for (int j = 0; j < width; ++j) {
// Row i occupies ins_vector[i * width, (i + 1) * width).
if ((r -= ins_vector[i * width + j]) < 0) {
id = j;
break;
}
......@@ -50,19 +51,22 @@ class SamplingIdKernel : public framework::OpKernel<T> {
// Tail of Compute: shapes the "Output" tensor and writes the selected ids.
// `out_dim` is declared outside the visible lines; batch_size is appended to
// it as an extent.
out_dim.push_back(static_cast<int64_t>(batch_size));
Tensor* output = context.Output<Tensor>("Output");
// NOTE(review): two consecutive Resize calls. `in_dim` is not declared in the
// visible code, and the second call presumably supersedes the first — this
// looks like old/new diff lines shown together; confirm.
output->Resize(framework::make_ddim(in_dim));
output->Resize(framework::make_ddim(out_dim));
// Requests a buffer of element type T at the context's place.
output->mutable_data<T>(context.GetPlace());
// Presumably copies the host-side ids vector into the output tensor — verify.
framework::TensorFromVector(ids, context.device_context(), output);
}
// Allocates a std::default_random_engine with `new`, seeds it with
// defaultSeed and returns a reference to it.
// NOTE(review): no matching `delete` is visible, so every call appears to leak
// one engine, and since the seed is the same member value each call would
// yield an identically seeded engine. The function's closing brace is not
// present before the next definition starts — likely residue of a diff
// rendered without +/- markers; confirm against the repository.
std::default_random_engine& get() {
auto engine = new std::default_random_engine;
engine->seed(defaultSeed);
return *engine;
/// Returns a pseudo-random double uniformly distributed on [0, 1).
///
/// The Mersenne Twister engine is seeded from std::random_device once per
/// thread and then reused. Constructing a new random_device and a new engine
/// on every call (one call per sampled row) costs an entropy-source query and
/// a full engine state initialization each time.
double get_rand() const {
  // thread_local: each thread owns its engine, so the mutable engine state is
  // never shared between threads even though this member function is const.
  thread_local std::mt19937 gen{std::random_device{}()};
  // The distribution object is cheap to construct.
  std::uniform_real_distribution<double> dis(0.0, 1.0);
  return dis(gen);
}
private:
// Seed value handed to engine->seed() in get(); initialized to 0.
unsigned int defaultSeed = 0;
// NOTE(review): the extra closing brace below is probably residue of the diff
// rendering (old and new lines shown together) — confirm.
}
};
} // namespace operators
} // namespace paddle
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册