提交 4973e07b 编写于 作者: T tangwei12

sampling op optimize

上级 3206970b
...@@ -57,9 +57,11 @@ SamplingId Operator. ...@@ -57,9 +57,11 @@ SamplingId Operator.
} // namespace paddle } // namespace paddle
namespace ops = paddle::operators; namespace ops = paddle::operators;
// NOTE(review): this span was a side-by-side diff rendering with the old and
// new columns fused onto each line; below is the reconstructed post-commit
// ("new") column for sampling_id_op.cc.
//
// Register the sampling_id operator (no gradient: EmptyGradOpMaker) and its
// CPU kernels for all supported element types.
REGISTER_OPERATOR(sampling_id, ops::SamplingIdOp, ops::SamplingIdOpMaker,
                  paddle::framework::EmptyGradOpMaker);

REGISTER_OP_CPU_KERNEL(
    sampling_id, ops::SamplingIdKernel<paddle::platform::CPUDeviceContext, int>,
    ops::SamplingIdKernel<paddle::platform::CPUDeviceContext, int64_t>,
    ops::SamplingIdKernel<paddle::platform::CPUDeviceContext, float>,
    ops::SamplingIdKernel<paddle::platform::CPUDeviceContext, double>);
...@@ -30,11 +30,9 @@ class SamplingIdOp : public framework::OperatorWithKernel { ...@@ -30,11 +30,9 @@ class SamplingIdOp : public framework::OperatorWithKernel {
} // namespace paddle } // namespace paddle
namespace ops = paddle::operators; namespace ops = paddle::operators;
// NOTE(review): reconstructed from a fused two-column diff view; this is the
// post-commit ("new") column for sampling_id_op.cu.
//
// Register the CUDA kernels of sampling_id for all supported element types.
// The kernel body is device-agnostic (it copies the input to a host vector),
// so the same SamplingIdKernel template serves both CPU and CUDA contexts.
REGISTER_OP_CUDA_KERNEL(
    sampling_id,
    ops::SamplingIdKernel<paddle::platform::CUDADeviceContext, float>,
    ops::SamplingIdKernel<paddle::platform::CUDADeviceContext, double>,
    ops::SamplingIdKernel<paddle::platform::CUDADeviceContext, int>,
    ops::SamplingIdKernel<paddle::platform::CUDADeviceContext, int64_t>);
...@@ -15,30 +15,31 @@ limitations under the License. */ ...@@ -15,30 +15,31 @@ limitations under the License. */
#include <random> #include <random>
#include <vector> #include <vector>
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_registry.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
using Tensor = framework::Tensor;
template <typename DeviceContext, typename T> template <typename DeviceContext, typename T>
class SamplingIdKernel : public framework::OpKernel<T> { class SamplingIdKernel : public framework::OpKernel<T> {
/// Produces random floating-point values, uniformly distributed on [0, 1).
std::uniform_real_distribution<double> rand1_;
public: public:
void Compute(const framework::ExecutionContext& context) const override { void Compute(const framework::ExecutionContext& context) const override {
const Tensor* input = context.Input<Tensor>("X"); const Tensor* input = context.Input<Tensor>("X");
const int batch_size = static_cast<int>(input->dims()[0]); const int batch_size = static_cast<int>(input->dims()[0]);
const int width = static_cast<int>(input->dims()[1]); const int width = static_cast<int>(input->dims()[1]);
std::vector<int> ids(batchSize); std::vector<T> ins_vector;
auto& reng = get(); framework::TensorToVector(*input, context.device_context(), &ins_vector);
for (size_t i = 0; i < batchSize; ++i) { std::vector<int> ids(batch_size);
double r = rand1_(reng); for (size_t i = 0; i < batch_size; ++i) {
int id = dim - 1; double r = this->get_rand();
for (int j = 0; j < dim; ++j) { int id = width - 1;
if ((r -= buf[i * dim + j]) < 0) { for (int j = 0; j < width; ++j) {
if ((r -= ins_vector[i * width + j]) < 0) {
id = j; id = j;
break; break;
} }
...@@ -50,19 +51,22 @@ class SamplingIdKernel : public framework::OpKernel<T> { ...@@ -50,19 +51,22 @@ class SamplingIdKernel : public framework::OpKernel<T> {
out_dim.push_back(static_cast<int64_t>(batch_size)); out_dim.push_back(static_cast<int64_t>(batch_size));
Tensor* output = context.Output<Tensor>("Output"); Tensor* output = context.Output<Tensor>("Output");
output->Resize(framework::make_ddim(in_dim)); output->Resize(framework::make_ddim(out_dim));
output->mutable_data<T>(context.GetPlace()); output->mutable_data<T>(context.GetPlace());
framework::TensorFromVector(ids, context.device_context(), output); framework::TensorFromVector(ids, context.device_context(), output);
} }
std::default_random_engine& get() { double get_rand() const {
auto engine = new std::default_random_engine; // Will be used to obtain a seed for the random number engine
engine->seed(defaultSeed); std::random_device rd;
return *engine; // Standard mersenne_twister_engine seeded with rd()
std::mt19937 gen(rd());
std::uniform_real_distribution<> dis(0, 1);
return dis(gen);
} }
private: private:
unsigned int defaultSeed = 0; unsigned int defaultSeed = 0;
} };
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册