未验证 提交 4e71cac6 编写于 作者: G guo ran 提交者: GitHub

fix unsorted_segment_sum (#3898)

上级 20348c38
......@@ -82,10 +82,11 @@ class UnsortedSegmentSumKernel final : public user_op::OpKernel {
CHECK_EQ(out->shape().At(axis), sum_state->upper() - sum_state->lower());
offset = sum_state->lower();
}
UnsortedSegmentSumKernelUtil<device_type, T, K>::UnsortedSegmentSum(
ctx->device_ctx(), segment_ids->dptr<K>(), data->dptr<T>(), num_segment_ids, num_segments,
outer_dim_size, inner_dim_size, offset, out->mut_dptr<T>());
if (num_segment_ids != 0) {
UnsortedSegmentSumKernelUtil<device_type, T, K>::UnsortedSegmentSum(
ctx->device_ctx(), segment_ids->dptr<K>(), data->dptr<T>(), num_segment_ids, num_segments,
outer_dim_size, inner_dim_size, offset, out->mut_dptr<T>());
}
}
bool AlwaysComputeWhenAllOutputsEmpty() const override { return true; }
};
......
......@@ -130,7 +130,7 @@ REGISTER_USER_OP("unsorted_segment_sum_like")
const Shape* segment_ids_shape = ctx->Shape4ArgNameAndIndex("segment_ids", 0);
CHECK_OR_RETURN(IsIndexDataType(*ctx->Dtype4ArgNameAndIndex("segment_ids", 0)));
const int64_t axis = ctx->Attr<int64_t>("axis");
Shape* out_shape = ctx->Shape4ArgNameAndIndex("out", 0);
user_op::TensorDesc* out = ctx->TensorDesc4ArgNameAndIndex("out", 0);
CHECK_EQ_OR_RETURN(data->data_type(), like->data_type());
CHECK_GE_OR_RETURN(axis, 0);
CHECK_LE_OR_RETURN(axis, like_shape->NumAxes());
......@@ -140,9 +140,7 @@ REGISTER_USER_OP("unsorted_segment_sum_like")
FOR_RANGE(int64_t, i, axis + 1, like_shape->NumAxes()) {
CHECK_EQ_OR_RETURN(like_shape->At(i), data_shape->At(i + segment_ids_shape->NumAxes() - 1));
}
*out_shape = *like_shape;
*ctx->Dtype4ArgNameAndIndex("out", 0) = *ctx->Dtype4ArgNameAndIndex("like", 0);
*out = *like;
return Maybe<void>::Ok();
})
.SetBatchAxisInferFn([](user_op::BatchAxisContext* ctx) -> Maybe<void> {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册