// linspace_op.cu (last committed by zhoukunsheng)
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/framework/data_type_transform.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/tensor_util.h"
#include "paddle/fluid/operators/linspace_op.h"
#include "paddle/fluid/platform/cuda_primitives.h"

namespace paddle {
namespace operators {

// Local shorthand for framework::Tensor.
using Tensor = framework::Tensor;

// Writes out[i] = static_cast<T>(start + step * i) for each i in [0, size).
//
// Iteration over [0, size) is delegated to Paddle's CUDA_KERNEL_LOOP macro
// (presumably a grid-stride loop -- confirm against its definition). No
// shared memory is used. `step` is a double, so `start + step * i` is
// evaluated in double and then converted to T; for integer T the
// conversion truncates toward zero.
//
// Params:
//   start - value of out[0].
//   step  - spacing between consecutive elements.
//   size  - number of elements to write.
//   out   - device buffer with room for at least `size` elements.
template <typename T>
__global__ void LinspaceKernel(T start, double step, int64_t size, T* out) {
  CUDA_KERNEL_LOOP(index, size) {
    out[index] = static_cast<T>(start + step * index);
  }
}

// Stores a single value: out[0] = start.
//
// Every thread performs the same store, so this is meant for a <<<1, 1>>>
// launch. `out` must point to at least one writable element of device
// memory.
template <typename T>
__global__ void LinspaceSpecialKernel(T start, T* out) {
  out[0] = start;
}

// GPU implementation of the linspace op.
//
// Reads the scalar inputs "Start" and "Stop" (cast to the "dtype" attribute)
// and "Num" (read as int32). It then fills "Out", resized to {num}, with
// `num` evenly spaced values from start to stop inclusive. When num == 1
// the single element is `start`.
template <typename T>
class CUDALinspaceKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& context) const override {
    auto* pre_start = context.Input<framework::Tensor>("Start");
    auto* pre_stop = context.Input<framework::Tensor>("Stop");
    auto* num_t = context.Input<framework::Tensor>("Num");
    auto* out = context.Output<framework::Tensor>("Out");
    auto dtype = static_cast<framework::proto::VarType::Type>(
        context.Attr<int>("dtype"));

    // Convert Start/Stop to the requested output dtype so that they can be
    // read back as T.
    Tensor start_t;
    Tensor stop_t;
    auto start_dtype =
        framework::OpKernelType(pre_start->type(), context.GetPlace());
    auto stop_dtype =
        framework::OpKernelType(pre_stop->type(), context.GetPlace());
    auto out_dtype = framework::OpKernelType(dtype, context.GetPlace());
    framework::TransDataType(start_dtype, out_dtype, *pre_start, &start_t);
    framework::TransDataType(stop_dtype, out_dtype, *pre_stop, &stop_t);

    // The three scalars are needed on the host to size the output and to
    // compute the step. framework::TensorCopy may enqueue a device-to-host
    // copy asynchronously. TensorCopySync guarantees the data has landed
    // before the host reads it. Each value gets its own host tensor rather
    // than reusing one buffer.
    framework::Tensor start_cpu;
    framework::Tensor stop_cpu;
    framework::Tensor num_cpu;
    framework::TensorCopySync(start_t, platform::CPUPlace(), &start_cpu);
    framework::TensorCopySync(stop_t, platform::CPUPlace(), &stop_cpu);
    framework::TensorCopySync(*num_t, platform::CPUPlace(), &num_cpu);
    T start = start_cpu.data<T>()[0];
    T stop = stop_cpu.data<T>()[0];
    int32_t num = num_cpu.data<int32_t>()[0];

    PADDLE_ENFORCE(num > 0, "The num of linspace op should be larger than 0.");

    out->Resize(framework::make_ddim({num}));
    T* out_data = out->mutable_data<T>(context.GetPlace());

    // Widen to double *before* subtracting. For integer T, `stop - start`
    // can overflow (UB for signed types), e.g. INT32_MIN .. INT32_MAX.
    double step = 0;
    if (num != 1) {
      step = (static_cast<double>(stop) - static_cast<double>(start)) /
             (num - 1);
    }

    auto stream = context.cuda_device_context().stream();
    int block = 512;
    // Ceil-divide without forming `num + block - 1`, which can overflow int
    // for num close to INT32_MAX. The num > 0 check above makes this valid.
    int grid = (num - 1) / block + 1;
    LinspaceKernel<T><<<grid, block, 0, stream>>>(start, step, num, out_data);
    if (num > 1) {
      // start + step * (num - 1) does not always round back to `stop`.
      // For example, (1.0 / 49) * 49 == 0.9999999999999999, which an
      // integer T truncates to 0. Pin the endpoint so the last element is
      // exactly `stop`, as numpy.linspace does. The launch uses the same
      // stream, so it is ordered after the fill above.
      LinspaceSpecialKernel<T><<<1, 1, 0, stream>>>(stop, out_data + num - 1);
    }
  }
};

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
// Register CUDALinspaceKernel instantiations for float, int32_t, int64_t
// and double as the CUDA kernels of the "linspace" op.
REGISTER_OP_CUDA_KERNEL(linspace, ops::CUDALinspaceKernel<float>,
                        ops::CUDALinspaceKernel<int32_t>,
                        ops::CUDALinspaceKernel<int64_t>,
                        ops::CUDALinspaceKernel<double>);