提交 a6ce2306 编写于 作者: H hutuxian 提交者: Yi Liu

support cvm_op run in gpu (#21300)

Previously, CVM OP was only able to run in CPU. This PR implements its GPU kernel.
What's more, we improve the UTs about CVM OP.
上级 b085ecc2
...@@ -54,7 +54,7 @@ class CVMOp : public framework::OperatorWithKernel { ...@@ -54,7 +54,7 @@ class CVMOp : public framework::OperatorWithKernel {
const framework::ExecutionContext& ctx) const override { const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType( return framework::OpKernelType(
OperatorWithKernel::IndicateVarDataType(ctx, "X"), OperatorWithKernel::IndicateVarDataType(ctx, "X"),
platform::CPUPlace()); ctx.device_context());
} }
}; };
...@@ -96,7 +96,7 @@ class CVMGradientOp : public framework::OperatorWithKernel { ...@@ -96,7 +96,7 @@ class CVMGradientOp : public framework::OperatorWithKernel {
const framework::ExecutionContext& ctx) const override { const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType( return framework::OpKernelType(
OperatorWithKernel::IndicateVarDataType(ctx, "X"), OperatorWithKernel::IndicateVarDataType(ctx, "X"),
platform::CPUPlace()); ctx.device_context());
} }
}; };
......
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/cvm_op.h"
#include "paddle/fluid/platform/cuda_primitives.h"
namespace paddle {
namespace operators {
using platform::PADDLE_CUDA_NUM_THREADS;
using Tensor = framework::Tensor;
using LoDTensor = framework::LoDTensor;
#define CUDA_KERNEL_LOOP(i, n) \
for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < (n); \
i += blockDim.x * gridDim.x)
template <typename T>
__global__ void CvmComputeKernel(const bool use_cvm, const int64_t item_width,
const T* X, T* Y, int64_t numel) {
CUDA_KERNEL_LOOP(i, numel) {
if (use_cvm) {
if (i % item_width == 0) {
Y[i] = log(X[i] + 1);
} else if (i % item_width == 1) {
Y[i] = log(X[i] + 1) - log(X[i - 1] + 1);
} else {
Y[i] = X[i];
}
} else {
Y[i] = X[i / (item_width - 2) * item_width + i % (item_width - 2) + 2];
}
}
}
template <typename T>
__global__ void CvmGradComputeKernel(const bool use_cvm,
const int64_t item_width, const T* CVM,
const T* DY, T* DX, bool has_lod,
const size_t* lod, size_t lod_size,
int64_t numel) {
CUDA_KERNEL_LOOP(i, numel) {
int offset = i % item_width;
if (offset <= 1) {
int cvm_id = i / item_width;
if (has_lod) {
int low = 1;
int high = lod_size - 1;
while (low < high) {
int mid = (low + high) / 2;
if (cvm_id < lod[mid])
high = mid;
else
low = mid + 1;
}
cvm_id = low - 1;
}
DX[i] = CVM[2 * cvm_id + offset];
} else {
if (use_cvm) {
DX[i] = DY[i];
} else {
DX[i] = DY[i / item_width * (item_width - 2) + i % item_width - 2];
}
}
}
}
template <typename T>
class CVMCUDAKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
const auto* x = context.Input<LoDTensor>("X");
const T* x_data = x->data<T>();
auto batch_size = x->dims()[0];
auto numel = x->numel();
auto item_size = numel / batch_size;
auto use_cvm = context.Attr<bool>("use_cvm");
auto* y = context.Output<LoDTensor>("Y");
T* y_data = y->mutable_data<T>(context.GetPlace());
// for Input X do not have Lod Information.
auto stream =
context.template device_context<platform::CUDADeviceContext>().stream();
if (x->NumLevels() == 0) {
CvmComputeKernel<<<(numel + PADDLE_CUDA_NUM_THREADS - 1) /
PADDLE_CUDA_NUM_THREADS,
PADDLE_CUDA_NUM_THREADS, 0, stream>>>(
use_cvm, item_size, x_data, y_data, y->numel());
} else {
auto lod = x->lod()[0];
PADDLE_ENFORCE_EQ(
batch_size, lod[lod.size() - 1],
platform::errors::PreconditionNotMet(
"Input(X)'s dim[0] must be equal to last element of lod"));
CvmComputeKernel<<<(numel + PADDLE_CUDA_NUM_THREADS - 1) /
PADDLE_CUDA_NUM_THREADS,
PADDLE_CUDA_NUM_THREADS, 0, stream>>>(
use_cvm, item_size, x_data, y_data, y->numel());
}
}
};
template <typename T>
class CVMGradCUDAKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto* dx = context.Output<LoDTensor>(framework::GradVarName("X"));
T* dx_data = dx->mutable_data<T>(context.GetPlace());
const Tensor* cvm = context.Input<Tensor>("CVM");
const T* cvm_data = cvm->data<T>();
const auto* dOut =
context.Input<framework::LoDTensor>(framework::GradVarName("Y"));
const T* dout_data = dOut->data<T>();
auto use_cvm = context.Attr<bool>("use_cvm");
auto offset = 2;
auto batch_size = dx->dims()[0];
auto dx_numel = dx->numel();
auto item_size = dx_numel / batch_size;
// for Input X do not have Lod Information.
auto stream =
context.template device_context<platform::CUDADeviceContext>().stream();
if (dx->NumLevels() == 0) {
CvmGradComputeKernel<<<(dx_numel + PADDLE_CUDA_NUM_THREADS - 1) /
PADDLE_CUDA_NUM_THREADS,
PADDLE_CUDA_NUM_THREADS, 0, stream>>>(
use_cvm, item_size, cvm_data, dout_data, dx_data, false, NULL, 0,
dx_numel);
} else {
auto lod = dx->lod()[0];
PADDLE_ENFORCE_EQ(
batch_size, lod[lod.size() - 1],
platform::errors::PreconditionNotMet(
"Output(X@GRAD)'s dim[0] must be equal to last element of lod"));
CvmGradComputeKernel<<<(dx_numel + PADDLE_CUDA_NUM_THREADS - 1) /
PADDLE_CUDA_NUM_THREADS,
PADDLE_CUDA_NUM_THREADS, 0, stream>>>(
use_cvm, item_size, cvm_data, dout_data, dx_data, true,
lod.CUDAData(context.GetPlace()), lod.size(), dx_numel);
}
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP_CUDA_KERNEL(cvm, ops::CVMCUDAKernel<float>,
ops::CVMCUDAKernel<double>);
REGISTER_OP_CUDA_KERNEL(cvm_grad, ops::CVMGradCUDAKernel<float>,
ops::CVMGradCUDAKernel<double>);
...@@ -60,12 +60,14 @@ class TestCVMOpWithLodTensor(OpTest): ...@@ -60,12 +60,14 @@ class TestCVMOpWithLodTensor(OpTest):
self.op_type = "cvm" self.op_type = "cvm"
self.use_cvm = True self.use_cvm = True
batch_size = 8 self.batch_size = 1
dims = 11 self.item_width = 11
lod = [[1]] lod = [[1]]
self.inputs = { self.inputs = {
'X': (np.random.uniform(0, 1, [1, dims]).astype("float32"), lod), 'X': (np.random.uniform(
0, 1, [self.batch_size, self.item_width]).astype("float32"),
lod),
'CVM': np.array([[0.6, 0.4]]).astype("float32"), 'CVM': np.array([[0.6, 0.4]]).astype("float32"),
} }
self.attrs = {'use_cvm': False} self.attrs = {'use_cvm': False}
...@@ -77,6 +79,14 @@ class TestCVMOpWithLodTensor(OpTest): ...@@ -77,6 +79,14 @@ class TestCVMOpWithLodTensor(OpTest):
def test_check_output(self): def test_check_output(self):
self.check_output() self.check_output()
def test_check_grad(self):
user_grads = np.array(
[1.0 / (self.item_width - 2)] * self.item_width).reshape(
(self.batch_size, self.item_width)).astype("float32")
user_grads[:, :2] = self.inputs['CVM'].reshape(self.batch_size, 2)
user_grads = [user_grads]
self.check_grad(['X'], 'Y', user_defined_grads=user_grads)
class TestCVMOpWithOutLodTensor1(OpTest): class TestCVMOpWithOutLodTensor1(OpTest):
""" """
...@@ -87,13 +97,14 @@ class TestCVMOpWithOutLodTensor1(OpTest): ...@@ -87,13 +97,14 @@ class TestCVMOpWithOutLodTensor1(OpTest):
self.op_type = "cvm" self.op_type = "cvm"
self.use_cvm = True self.use_cvm = True
batch_size = 2 self.batch_size = 2
item_width = 11 self.item_width = 11
input = np.random.uniform(0, 1, input = np.random.uniform(
(batch_size, item_width)).astype('float32') 0, 1, (self.batch_size, self.item_width)).astype('float32')
output = cvm_compute(input, item_width, self.use_cvm) output = cvm_compute(input, self.item_width, self.use_cvm)
cvm = np.array([[0.6, 0.4]]).astype("float32") cvm = np.array([[0.6, 0.4] * self.batch_size]).reshape(
(self.batch_size, 2)).astype("float32")
self.inputs = {'X': input, 'CVM': cvm} self.inputs = {'X': input, 'CVM': cvm}
self.attrs = {'use_cvm': self.use_cvm} self.attrs = {'use_cvm': self.use_cvm}
...@@ -102,6 +113,14 @@ class TestCVMOpWithOutLodTensor1(OpTest): ...@@ -102,6 +113,14 @@ class TestCVMOpWithOutLodTensor1(OpTest):
def test_check_output(self): def test_check_output(self):
self.check_output() self.check_output()
def test_check_grad(self):
numel = self.batch_size * self.item_width
user_grads = np.array([1.0 / numel] * numel).reshape(
(self.batch_size, self.item_width)).astype("float32")
user_grads[:, :2] = self.inputs['CVM'].reshape(self.batch_size, 2)
user_grads = [user_grads]
self.check_grad(['X'], 'Y', user_defined_grads=user_grads)
class TestCVMOpWithOutLodTensor2(OpTest): class TestCVMOpWithOutLodTensor2(OpTest):
""" """
...@@ -112,13 +131,14 @@ class TestCVMOpWithOutLodTensor2(OpTest): ...@@ -112,13 +131,14 @@ class TestCVMOpWithOutLodTensor2(OpTest):
self.op_type = "cvm" self.op_type = "cvm"
self.use_cvm = False self.use_cvm = False
batch_size = 2 self.batch_size = 2
item_width = 11 self.item_width = 11
input = np.random.uniform(0, 1, input = np.random.uniform(
(batch_size, item_width)).astype('float32') 0, 1, (self.batch_size, self.item_width)).astype('float32')
output = cvm_compute(input, item_width, self.use_cvm) output = cvm_compute(input, self.item_width, self.use_cvm)
cvm = np.array([[0.6, 0.4]]).astype("float32") cvm = np.array([[0.6, 0.4] * self.batch_size]).reshape(
(self.batch_size, 2)).astype("float32")
self.inputs = {'X': input, 'CVM': cvm} self.inputs = {'X': input, 'CVM': cvm}
self.attrs = {'use_cvm': self.use_cvm} self.attrs = {'use_cvm': self.use_cvm}
...@@ -127,6 +147,15 @@ class TestCVMOpWithOutLodTensor2(OpTest): ...@@ -127,6 +147,15 @@ class TestCVMOpWithOutLodTensor2(OpTest):
def test_check_output(self): def test_check_output(self):
self.check_output() self.check_output()
def test_check_grad(self):
numel = self.batch_size * self.item_width
user_grads = np.array(
[1.0 / (self.batch_size * (self.item_width - 2))] * numel).reshape(
(self.batch_size, self.item_width)).astype("float32")
user_grads[:, :2] = self.inputs['CVM'].reshape(self.batch_size, 2)
user_grads = [user_grads]
self.check_grad(['X'], 'Y', user_defined_grads=user_grads)
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册