未验证 提交 4440d7ce 编写于 作者: W wangchaochaohu 提交者: GitHub

test=develop cuda realization of label smooth op (#19175)

上级 31c5a5ee
...@@ -12,15 +12,101 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ...@@ -12,15 +12,101 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/operators/label_smooth_op.h" #include "paddle/fluid/operators/label_smooth_op.h"
namespace paddle {
namespace operators {
template <typename T>
__global__ void LabelSmoothRunOriginKernel(const int N, const float epsilon,
const int label_dim, const T* src,
T* dst) {
int idx = blockDim.x * blockIdx.x + threadIdx.x;
for (; idx < N; idx += blockDim.x * gridDim.x) {
dst[idx] = static_cast<T>(1 - epsilon) * src[idx] +
static_cast<T>(epsilon / label_dim);
}
}
template <typename T>
__global__ void LabelSmoothRunDistKernel(const int N, const float epsilon,
const int dist_numel, const T* src,
const T* dist_data, T* dst) {
int idx = blockDim.x * blockIdx.x + threadIdx.x;
for (; idx < N; idx += blockDim.x * gridDim.x) {
int dist_idx = idx - (idx / dist_numel) * dist_numel;
dst[idx] = static_cast<T>(1 - epsilon) * src[idx] +
static_cast<T>(epsilon) * dist_data[dist_idx];
}
}
template <typename T>
__global__ void LabelSmoothGradRunKernel(const int N, const float epsilon,
const T* src, T* dst) {
int idx = blockDim.x * blockIdx.x + threadIdx.x;
for (; idx < N; idx += blockDim.x * gridDim.x) {
dst[idx] = static_cast<T>(1 - epsilon) * src[idx];
}
}
template <typename DeviceContext, typename T>
class LabelSmoothGPUKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const {
auto* out_t = ctx.Output<framework::LoDTensor>("Out");
auto* in_t = ctx.Input<framework::LoDTensor>("X");
auto* dist_t = ctx.Input<framework::Tensor>("PriorDist");
auto label_dim = in_t->dims()[1];
auto epsilon = ctx.Attr<float>("epsilon");
auto& dev = *ctx.template device_context<DeviceContext>().eigen_device();
auto size_prob = in_t->numel();
const T* in_data = in_t->data<T>();
T* out_data = out_t->mutable_data<T>(ctx.GetPlace());
int threads = 512;
int grid = (size_prob + threads - 1) / threads;
auto stream = ctx.cuda_device_context().stream();
if (dist_t) {
auto dist_numel = dist_t->numel();
const T* dist_data = dist_t->data<T>();
LabelSmoothRunDistKernel<T><<<grid, threads, 0, stream>>>(
size_prob, epsilon, dist_numel, in_data, dist_data, out_data);
} else {
LabelSmoothRunOriginKernel<T><<<grid, threads, 0, stream>>>(
size_prob, epsilon, label_dim, in_data, out_data);
}
}
};
template <typename DeviceContext, typename T>
class LabelSmoothGradGPUKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const {
auto* d_out_t = ctx.Input<framework::Tensor>(framework::GradVarName("Out"));
auto* d_in_t = ctx.Output<framework::Tensor>(framework::GradVarName("X"));
d_in_t->mutable_data<T>(ctx.GetPlace());
auto epsilon = ctx.Attr<float>("epsilon");
auto& dev = *ctx.template device_context<DeviceContext>().eigen_device();
const T* in_data = d_out_t->data<T>();
auto size_prob = d_out_t->numel();
T* out_data = d_in_t->mutable_data<T>(ctx.GetPlace());
int threads = 512;
int grid = (size_prob + threads - 1) / threads;
auto stream = ctx.cuda_device_context().stream();
LabelSmoothGradRunKernel<T><<<grid, threads, 0, stream>>>(
size_prob, epsilon, in_data, out_data);
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators; namespace ops = paddle::operators;
REGISTER_OP_CUDA_KERNEL( REGISTER_OP_CUDA_KERNEL(
label_smooth, label_smooth,
ops::LabelSmoothKernel<paddle::platform::CUDADeviceContext, float>, ops::LabelSmoothGPUKernel<paddle::platform::CUDADeviceContext, float>,
ops::LabelSmoothKernel<paddle::platform::CUDADeviceContext, double>); ops::LabelSmoothGPUKernel<paddle::platform::CUDADeviceContext, double>);
REGISTER_OP_CUDA_KERNEL( REGISTER_OP_CUDA_KERNEL(
label_smooth_grad, label_smooth_grad,
ops::LabelSmoothGradKernel<paddle::platform::CUDADeviceContext, float>, ops::LabelSmoothGradGPUKernel<paddle::platform::CUDADeviceContext, float>,
ops::LabelSmoothGradKernel<paddle::platform::CUDADeviceContext, double>); ops::LabelSmoothGradGPUKernel<paddle::platform::CUDADeviceContext, double>);
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册