where_op.cu 4.4 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/operators/where_op.h"
#include "paddle/fluid/platform/gpu_launch_param_config.h"

// File-scope alias for paddle::platform. The kernel registrations at the
// bottom of this file are written at global scope (outside namespace paddle)
// and refer to platform::CUDADeviceContext through this alias.
namespace platform = paddle::platform;

namespace paddle {
namespace operators {

// Element-wise select on the device: out[i] = cond[i] ? x[i] : y[i] for every
// i in [0, N).
//
// Each thread starts at its global 1-D index and advances by the total number
// of launched threads (blockDim.x * gridDim.x), so any 1-D launch
// configuration covers all N elements.
//
//   N    - number of elements to process; cond, x, y and out are all indexed
//          up to N - 1.
//   cond - selection mask.
//   x, y - values taken where cond is true / false respectively.
//   out  - destination buffer.
//
// The index, the stride and N are 64-bit on purpose. With a 32-bit `int`
// index the `idx += stride` step can overflow when N is close to INT_MAX, and
// an `int` N would truncate any larger element count passed from the host.
template <typename T>
__global__ void WhereCUDAKernel(const int64_t N, const bool* cond, const T* x,
                                const T* y, T* out) {
  // Widen before multiplying: blockDim/blockIdx/gridDim components are 32-bit.
  const int64_t stride = static_cast<int64_t>(blockDim.x) * gridDim.x;
  int64_t idx = static_cast<int64_t>(blockDim.x) * blockIdx.x + threadIdx.x;
  for (; idx < N; idx += stride) {
    out[idx] = cond[idx] ? x[idx] : y[idx];
  }
}

// Backward pass of the element-wise select. For every i in [0, N):
//   x[i] = cond[i] ? out[i] : 0     (skipped when x is nullptr)
//   y[i] = cond[i] ? 0 : out[i]     (skipped when y is nullptr)
//
//   N    - number of elements to process.
//   out  - incoming gradient values (the caller in this file passes the
//          gradient of the op's output here).
//   cond - selection mask used in the forward pass.
//   x, y - gradient destinations; either may be nullptr, in which case that
//          gradient is not written.
//
// The mask is applied with a select rather than by multiplying with 1. / 0.:
//   * the double literals forced a round trip through double for every T,
//     which loses precision for int64_t values above 2^53 and does needless
//     double-precision arithmetic for float;
//   * NaN/Inf times zero is NaN, so a non-finite incoming gradient leaked
//     into the branch that was NOT selected instead of yielding zero.
//
// Indexing is 64-bit so that `idx += stride` cannot overflow and element
// counts above INT_MAX are not truncated.
template <typename T>
__global__ void WhereGradCUDAKernel(const int64_t N, const T* out,
                                    const bool* cond, T* x, T* y) {
  const T zero = static_cast<T>(0);
  // Widen before multiplying: blockDim/blockIdx/gridDim components are 32-bit.
  const int64_t stride = static_cast<int64_t>(blockDim.x) * gridDim.x;
  int64_t idx = static_cast<int64_t>(blockDim.x) * blockIdx.x + threadIdx.x;
  for (; idx < N; idx += stride) {
    // Load both inputs once, before either store.
    const T grad = out[idx];
    const bool selected = cond[idx];
    if (x != nullptr) {
      x[idx] = selected ? grad : zero;
    }
    if (y != nullptr) {
      y[idx] = selected ? zero : grad;
    }
  }
}

// CUDA specialization of the forward `where` op kernel.
template <typename T>
class WhereKernel<platform::CUDADeviceContext, T>
    : public framework::OpKernel<T> {
 public:
  // Fetches inputs "Condition", "X" and "Y" and output "Out" from the
  // execution context, then launches WhereCUDAKernel over as many elements as
  // "Condition" holds, on the context's CUDA stream.
  void Compute(const framework::ExecutionContext& context) const override {
    auto* cond_tensor = context.Input<framework::Tensor>("Condition");
    auto* x_tensor = context.Input<framework::Tensor>("X");
    auto* y_tensor = context.Input<framework::Tensor>("Y");
    auto* out_tensor = context.Output<framework::Tensor>("Out");

    // The element count is taken from the condition tensor only.
    auto element_count = cond_tensor->numel();

    // TODO(GaaoWei8): Input of where can be broadcast
    const bool* mask_ptr = cond_tensor->data<bool>();
    const T* lhs_ptr = x_tensor->data<T>();
    const T* rhs_ptr = y_tensor->data<T>();
    T* result_ptr = out_tensor->mutable_data<T>(context.GetPlace());

    auto cuda_stream = context.cuda_device_context().stream();
    auto& device_ctx =
        context.template device_context<platform::CUDADeviceContext>();

    // 1-D launch shape derived from the element count.
    auto launch_cfg = GetGpuLaunchConfig1D(device_ctx, element_count);
    const auto grid_size = launch_cfg.block_per_grid.x;
    const auto block_size = launch_cfg.thread_per_block.x;

    WhereCUDAKernel<T><<<grid_size, block_size, 0, cuda_stream>>>(
        element_count, mask_ptr, lhs_ptr, rhs_ptr, result_ptr);
  }
};

// CUDA specialization of the backward `where` op kernel.
template <typename T>
class WhereGradKernel<platform::CUDADeviceContext, T>
    : public framework::OpKernel<T> {
 public:
  // Reads "Condition" and the gradient of "Out", then launches
  // WhereGradCUDAKernel to fill the gradients of "X" and "Y". Either gradient
  // output may be absent; a null pointer is then handed to the kernel for it.
  void Compute(const framework::ExecutionContext& context) const override {
    auto* cond_tensor = context.Input<framework::Tensor>("Condition");
    const bool* mask_ptr = cond_tensor->data<bool>();
    auto element_count = cond_tensor->numel();

    auto* out_grad_tensor =
        context.Input<framework::Tensor>(framework::GradVarName("Out"));
    auto* x_grad_tensor =
        context.Output<framework::Tensor>(framework::GradVarName("X"));
    auto* y_grad_tensor =
        context.Output<framework::Tensor>(framework::GradVarName("Y"));

    auto* out_grad_ptr = out_grad_tensor->data<T>();

    // Only request a buffer for a gradient output that is actually present.
    T* x_grad_ptr = nullptr;
    if (x_grad_tensor != nullptr) {
      x_grad_ptr = x_grad_tensor->mutable_data<T>(context.GetPlace());
    }
    T* y_grad_ptr = nullptr;
    if (y_grad_tensor != nullptr) {
      y_grad_ptr = y_grad_tensor->mutable_data<T>(context.GetPlace());
    }

    auto cuda_stream = context.cuda_device_context().stream();
    auto& device_ctx =
        context.template device_context<platform::CUDADeviceContext>();

    // 1-D launch shape derived from the condition's element count.
    auto launch_cfg = GetGpuLaunchConfig1D(device_ctx, element_count);
    const auto grid_size = launch_cfg.block_per_grid.x;
    const auto block_size = launch_cfg.thread_per_block.x;

    WhereGradCUDAKernel<T><<<grid_size, block_size, 0, cuda_stream>>>(
        element_count, out_grad_ptr, mask_ptr, x_grad_ptr, y_grad_ptr);
  }
};

}  // namespace operators
}  // namespace paddle

// Registers the CUDA forward kernels of the `where` op, one WhereKernel
// instantiation per supported element type: float, double, int and int64_t.
// `platform` here is the file-scope alias for paddle::platform.
REGISTER_OP_CUDA_KERNEL(
    where, paddle::operators::WhereKernel<platform::CUDADeviceContext, float>,
    paddle::operators::WhereKernel<platform::CUDADeviceContext, double>,
    paddle::operators::WhereKernel<platform::CUDADeviceContext, int>,
    paddle::operators::WhereKernel<platform::CUDADeviceContext, int64_t>);
// Registers the CUDA backward kernels of the `where_grad` op, one
// WhereGradKernel instantiation per element type: float, double, int and
// int64_t (the same set as the forward registration).
REGISTER_OP_CUDA_KERNEL(
    where_grad,
    paddle::operators::WhereGradKernel<platform::CUDADeviceContext, float>,
    paddle::operators::WhereGradKernel<platform::CUDADeviceContext, double>,
    paddle::operators::WhereGradKernel<platform::CUDADeviceContext, int>,
    paddle::operators::WhereGradKernel<platform::CUDADeviceContext, int64_t>);