提交 f5807670 编写于 作者: Y yangyaming

Fix typos and use HOSTDEVICE instead.

上级 b7776e66
...@@ -15,6 +15,7 @@ ...@@ -15,6 +15,7 @@
#pragma once #pragma once
#include "paddle/framework/eigen.h" #include "paddle/framework/eigen.h"
#include "paddle/framework/op_registry.h" #include "paddle/framework/op_registry.h"
#include "paddle/platform/hostdevice.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
...@@ -28,10 +29,10 @@ template <typename T, int MajorType = Eigen::RowMajor, ...@@ -28,10 +29,10 @@ template <typename T, int MajorType = Eigen::RowMajor,
using EigenMatrix = framework::EigenMatrix<T, MajorType, IndexType>; using EigenMatrix = framework::EigenMatrix<T, MajorType, IndexType>;
template <typename T> template <typename T>
struct SmoothL1LossFoward { struct SmoothL1LossForward {
__host__ __device__ SmoothL1LossFoward(const T& sigma2) : sigma2(sigma2) {} HOSTDEVICE SmoothL1LossForward(const T& sigma2) : sigma2(sigma2) {}
__host__ __device__ T operator()(const T& val) const { HOSTDEVICE T operator()(const T& val) const {
T abs_val = std::abs(val); T abs_val = std::abs(val);
if (abs_val < 1.0 / sigma2) { if (abs_val < 1.0 / sigma2) {
return 0.5 * val * val * sigma2; return 0.5 * val * val * sigma2;
...@@ -80,7 +81,7 @@ class SmoothL1LossKernel : public framework::OpKernel { ...@@ -80,7 +81,7 @@ class SmoothL1LossKernel : public framework::OpKernel {
context.GetPlace()); context.GetPlace());
auto errors = EigenVector<T>::Flatten(paddle_errors); auto errors = EigenVector<T>::Flatten(paddle_errors);
// apply smooth l1 forward // apply smooth l1 forward
errors.device(place) = diff.unaryExpr(SmoothL1LossFoward<T>(sigma2)); errors.device(place) = diff.unaryExpr(SmoothL1LossForward<T>(sigma2));
// multiply outside weight // multiply outside weight
if (has_weight) { if (has_weight) {
...@@ -99,9 +100,9 @@ class SmoothL1LossKernel : public framework::OpKernel { ...@@ -99,9 +100,9 @@ class SmoothL1LossKernel : public framework::OpKernel {
template <typename T> template <typename T>
struct SmoothL1LossBackward { struct SmoothL1LossBackward {
__host__ __device__ SmoothL1LossBackward(const T& sigma2) : sigma2(sigma2) {} HOSTDEVICE SmoothL1LossBackward(const T& sigma2) : sigma2(sigma2) {}
__host__ __device__ T operator()(const T& val) const { HOSTDEVICE T operator()(const T& val) const {
T abs_val = std::abs(val); T abs_val = std::abs(val);
if (abs_val < 1.0 / sigma2) { if (abs_val < 1.0 / sigma2) {
return sigma2 * val; return sigma2 * val;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册