未验证 提交 795d7121 编写于 作者: S sneaxiy 提交者: GitHub

fix some ops (#41577)

上级 c1394c6a
......@@ -22,6 +22,7 @@ PD_REGISTER_KERNEL(size,
CPU,
ALL_LAYOUT,
phi::SizeKernel,
int16_t,
int,
int64_t,
phi::dtype::float16,
......
......@@ -222,25 +222,28 @@ void CumsumKernel(const Context& dev_ctx,
// Use thrust for parallel acceleration when the input size is equal to the
// length of the 'axis' dimension.
if (size == out_dims[axis]) {
#ifdef __HIPCC__
const auto& policy = thrust::hip::par.on(dev_ctx.stream());
#else
const auto& policy = thrust::cuda::par.on(dev_ctx.stream());
#endif
if (reverse) {
thrust::device_ptr<const T> dev_ptr =
thrust::device_pointer_cast(in_data);
thrust::device_vector<T> vec(dev_ptr, dev_ptr + size);
thrust::reverse_iterator<thrust::device_ptr<const T>> reversed_in(
thrust::device_pointer_cast(in_data) + size);
thrust::reverse_iterator<thrust::device_ptr<T>> reversed_out(
thrust::device_pointer_cast(out_data) + size);
if (exclusive) {
thrust::exclusive_scan(
thrust::device, vec.rbegin(), vec.rend(), out_data);
policy, reversed_in, reversed_in + size, reversed_out);
} else {
thrust::inclusive_scan(
thrust::device, vec.rbegin(), vec.rend(), out_data);
policy, reversed_in, reversed_in + size, reversed_out);
}
thrust::reverse(thrust::device, out_data, out_data + size);
} else {
if (exclusive) {
thrust::exclusive_scan(
thrust::device, in_data, in_data + size, out_data);
thrust::exclusive_scan(policy, in_data, in_data + size, out_data);
} else {
thrust::inclusive_scan(
thrust::device, in_data, in_data + size, out_data);
thrust::inclusive_scan(policy, in_data, in_data + size, out_data);
}
}
return;
......
......@@ -22,6 +22,7 @@ PD_REGISTER_KERNEL(size,
GPU,
ALL_LAYOUT,
phi::SizeKernel,
int16_t,
int,
int64_t,
phi::dtype::float16,
......
......@@ -1795,7 +1795,7 @@ def cross_entropy(input,
# 2. else
# numerator: loss's weighted sum
# denominator: the sum of the weights over the samples whose class_index != ignore_index
if ignore_index != -100:
if ignore_index >= 0:
out_sum = _C_ops.reduce_sum(out, 'reduce_all', True)
# for each label[i], set the mask to 1 or 0 according to ignore_index:
# mask[i] = 0 if label[i] == ignore_index
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册