未验证 提交 795d7121 编写于 作者: S sneaxiy 提交者: GitHub

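fix some ops (#41577): register int16_t for the CPU and GPU size kernels, run the cumsum thrust scans with an execution policy bound to dev_ctx.stream(), and change the cross_entropy ignore_index check from `!= -100` to `>= 0`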

上级 c1394c6a
...@@ -22,6 +22,7 @@ PD_REGISTER_KERNEL(size, ...@@ -22,6 +22,7 @@ PD_REGISTER_KERNEL(size,
CPU, CPU,
ALL_LAYOUT, ALL_LAYOUT,
phi::SizeKernel, phi::SizeKernel,
int16_t,
int, int,
int64_t, int64_t,
phi::dtype::float16, phi::dtype::float16,
......
...@@ -222,25 +222,28 @@ void CumsumKernel(const Context& dev_ctx, ...@@ -222,25 +222,28 @@ void CumsumKernel(const Context& dev_ctx,
// Use thrust for parallel acceleration when the input size is equal to the // Use thrust for parallel acceleration when the input size is equal to the
// length of the ‘axis’ dimension. // length of the ‘axis’ dimension.
if (size == out_dims[axis]) { if (size == out_dims[axis]) {
#ifdef __HIPCC__
const auto& policy = thrust::hip::par.on(dev_ctx.stream());
#else
const auto& policy = thrust::cuda::par.on(dev_ctx.stream());
#endif
if (reverse) { if (reverse) {
thrust::device_ptr<const T> dev_ptr = thrust::reverse_iterator<thrust::device_ptr<const T>> reversed_in(
thrust::device_pointer_cast(in_data); thrust::device_pointer_cast(in_data) + size);
thrust::device_vector<T> vec(dev_ptr, dev_ptr + size); thrust::reverse_iterator<thrust::device_ptr<T>> reversed_out(
thrust::device_pointer_cast(out_data) + size);
if (exclusive) { if (exclusive) {
thrust::exclusive_scan( thrust::exclusive_scan(
thrust::device, vec.rbegin(), vec.rend(), out_data); policy, reversed_in, reversed_in + size, reversed_out);
} else { } else {
thrust::inclusive_scan( thrust::inclusive_scan(
thrust::device, vec.rbegin(), vec.rend(), out_data); policy, reversed_in, reversed_in + size, reversed_out);
} }
thrust::reverse(thrust::device, out_data, out_data + size);
} else { } else {
if (exclusive) { if (exclusive) {
thrust::exclusive_scan( thrust::exclusive_scan(policy, in_data, in_data + size, out_data);
thrust::device, in_data, in_data + size, out_data);
} else { } else {
thrust::inclusive_scan( thrust::inclusive_scan(policy, in_data, in_data + size, out_data);
thrust::device, in_data, in_data + size, out_data);
} }
} }
return; return;
......
...@@ -22,6 +22,7 @@ PD_REGISTER_KERNEL(size, ...@@ -22,6 +22,7 @@ PD_REGISTER_KERNEL(size,
GPU, GPU,
ALL_LAYOUT, ALL_LAYOUT,
phi::SizeKernel, phi::SizeKernel,
int16_t,
int, int,
int64_t, int64_t,
phi::dtype::float16, phi::dtype::float16,
......
...@@ -1795,7 +1795,7 @@ def cross_entropy(input, ...@@ -1795,7 +1795,7 @@ def cross_entropy(input,
# 2. else # 2. else
# numerator: loss's weighted sum # numerator: loss's weighted sum
# denominator: cal the sum of weight where the sample's class_index!=ignore_index # denominator: cal the sum of weight where the sample's class_index!=ignore_index
if ignore_index != -100: if ignore_index >= 0:
out_sum = _C_ops.reduce_sum(out, 'reduce_all', True) out_sum = _C_ops.reduce_sum(out, 'reduce_all', True)
# for each label[i],set 1 or 0, according to ignore_index # for each label[i],set 1 or 0, according to ignore_index
# mask[i]=0, if label[i]==ignore_index # mask[i]=0, if label[i]==ignore_index
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册