“f6c9f56394838021af5db26d046f3c90606a17fc”上不存在“release/0.10.0/doc/howto/usage/cmd_parameter/index_en.html”
未验证 提交 fff270ea 编写于 作者: Z Zeng Jinle 提交者: GitHub

follow comments,test=develop (#17273)

上级 7a3bb061
...@@ -10,6 +10,7 @@ See the License for the specific language governing permissions and ...@@ -10,6 +10,7 @@ See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include <cub/cub.cuh> #include <cub/cub.cuh>
#include "paddle/fluid/operators/math/cross_entropy.h" #include "paddle/fluid/operators/math/cross_entropy.h"
#include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/operators/softmax_with_cross_entropy_op.h" #include "paddle/fluid/operators/softmax_with_cross_entropy_op.h"
#include "paddle/fluid/platform/for_range.h" #include "paddle/fluid/platform/for_range.h"
...@@ -309,12 +310,6 @@ struct HardLabelSoftmaxWithCrossEntropyFunctorWithIgnoreIdx { ...@@ -309,12 +310,6 @@ struct HardLabelSoftmaxWithCrossEntropyFunctorWithIgnoreIdx {
int ignore_idx_; int ignore_idx_;
}; };
template <typename T>
static __global__ void SetSoftmaxToOneWhenFeatureSizeIsOne(T* out, int n) {
auto idx = threadIdx.x + blockIdx.x * blockDim.x;
if (idx < n) out[idx] = static_cast<T>(1);
}
template <typename T> template <typename T>
static void HardLabelSoftmaxWithCrossEntropy( static void HardLabelSoftmaxWithCrossEntropy(
const platform::CUDADeviceContext& ctx, const T* logits_data, const platform::CUDADeviceContext& ctx, const T* logits_data,
...@@ -354,13 +349,6 @@ static void HardLabelSoftmaxWithCrossEntropy( ...@@ -354,13 +349,6 @@ static void HardLabelSoftmaxWithCrossEntropy(
CALL_HARD_LABEL_SOFTMAX_WITH_CROSS_ENTROPY_FUSED_KERNEL(8); CALL_HARD_LABEL_SOFTMAX_WITH_CROSS_ENTROPY_FUSED_KERNEL(8);
CALL_HARD_LABEL_SOFTMAX_WITH_CROSS_ENTROPY_FUSED_KERNEL(4); CALL_HARD_LABEL_SOFTMAX_WITH_CROSS_ENTROPY_FUSED_KERNEL(4);
CALL_HARD_LABEL_SOFTMAX_WITH_CROSS_ENTROPY_FUSED_KERNEL(2); CALL_HARD_LABEL_SOFTMAX_WITH_CROSS_ENTROPY_FUSED_KERNEL(2);
case 1:
SetSoftmaxToOneWhenFeatureSizeIsOne<<<(grid_dim + kMaxBlockDim - 1) /
kMaxBlockDim,
kMaxBlockDim, 0, stream>>>(
softmax_data, grid_dim);
cudaMemsetAsync(loss_data, 0, grid_dim * sizeof(T), stream);
break;
default: default:
PADDLE_THROW("BlockDim must be 2^n in softmax_with_cross_entropy_op"); PADDLE_THROW("BlockDim must be 2^n in softmax_with_cross_entropy_op");
break; break;
...@@ -401,13 +389,6 @@ static void SoftmaxWithCrossEntropyFusedKernel(const T* logits_data, ...@@ -401,13 +389,6 @@ static void SoftmaxWithCrossEntropyFusedKernel(const T* logits_data,
CALL_SOFTMAX_WITH_CROSS_ENTROPY_FUSED_KERNEL(8); CALL_SOFTMAX_WITH_CROSS_ENTROPY_FUSED_KERNEL(8);
CALL_SOFTMAX_WITH_CROSS_ENTROPY_FUSED_KERNEL(4); CALL_SOFTMAX_WITH_CROSS_ENTROPY_FUSED_KERNEL(4);
CALL_SOFTMAX_WITH_CROSS_ENTROPY_FUSED_KERNEL(2); CALL_SOFTMAX_WITH_CROSS_ENTROPY_FUSED_KERNEL(2);
case 1:
SetSoftmaxToOneWhenFeatureSizeIsOne<<<(grid_dim + kMaxBlockDim - 1) /
kMaxBlockDim,
kMaxBlockDim, 0, stream>>>(
softmax_data, n);
cudaMemsetAsync(loss_data, 0, grid_dim * sizeof(T), stream);
break;
default: default:
PADDLE_THROW("BlockDim must be 2^n in softmax_with_cross_entropy_op"); PADDLE_THROW("BlockDim must be 2^n in softmax_with_cross_entropy_op");
break; break;
...@@ -431,6 +412,13 @@ class SoftmaxWithCrossEntropyCUDAKernel : public framework::OpKernel<T> { ...@@ -431,6 +412,13 @@ class SoftmaxWithCrossEntropyCUDAKernel : public framework::OpKernel<T> {
const int axis = CanonicalAxis(context.Attr<int>("axis"), rank); const int axis = CanonicalAxis(context.Attr<int>("axis"), rank);
int axis_dim = logits->dims()[axis]; int axis_dim = logits->dims()[axis];
if (axis_dim == 1) {
math::SetConstant<platform::CUDADeviceContext, T> set_constant;
set_constant(context.cuda_device_context(), softmax, static_cast<T>(1));
set_constant(context.cuda_device_context(), loss, static_cast<T>(0));
return;
}
const int n = SizeToAxis(axis, logits->dims()); const int n = SizeToAxis(axis, logits->dims());
const int d = SizeFromAxis(axis, logits->dims()); const int d = SizeFromAxis(axis, logits->dims());
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册