未验证 提交 6e71b007 编写于 作者: G guo ran 提交者: GitHub

indexed_slices_model_update handle empty tensor (#3933)

* indexed_slices_model_update handle empty tensor

* indexed_slices_sgd
Co-authored-by: Noneflow-ci-bot <69100618+oneflow-ci-bot@users.noreply.github.com>
上级 8fb3048c
......@@ -27,6 +27,8 @@ class TmpBufferManager final {
public:
OF_DISALLOW_COPY_AND_MOVE(TmpBufferManager);
TmpBufferManager(void* ptr, const int64_t num_indices, const int64_t num_values) : ptr_(ptr) {
CHECK_NE(num_indices, 0);
CHECK_NE(num_values, 0);
const size_t unique_diff_indices_bytes = GetCudaAlignedSize(num_indices * sizeof(K));
const size_t unique_diff_values_bytes = GetCudaAlignedSize(num_values * sizeof(T));
const size_t num_unique_diff_indices_bytes = GetCudaAlignedSize(1 * sizeof(int32_t));
......@@ -176,6 +178,14 @@ class IndexedSlicesSGDUpdateKernel final : public user_op::OpKernel {
ctx->Tensor4ArgNameAndIndex("model_diff_indices", 0);
const user_op::Tensor* model_diff_values = ctx->Tensor4ArgNameAndIndex("model_diff_values", 0);
user_op::Tensor* model = ctx->Tensor4ArgNameAndIndex("model", 0);
const int64_t num_indices = model_diff_indices->shape().elem_cnt();
const int64_t num_values = model_diff_values->shape().elem_cnt();
if (num_indices == 0) {
CHECK_EQ(num_values, 0);
return;
}
CHECK_NE(num_values, 0);
CHECK_EQ(num_values % num_indices, 0);
auto* kernel_state = dynamic_cast<IndexedSlicesUpdateOpKernelState*>(state);
CHECK_NOTNULL(kernel_state);
CHECK_EQ(model->shape().At(0), kernel_state->upper() - kernel_state->lower());
......@@ -303,6 +313,11 @@ class IndexedSlicesMomentumUpdateKernel final : public user_op::OpKernel {
const auto beta = ctx->Attr<float>("beta");
const int64_t num_indices = model_diff_indices->shape().elem_cnt();
const int64_t num_values = model_diff_values->shape().elem_cnt();
if (num_indices == 0) {
CHECK_EQ(num_values, 0);
return;
}
CHECK_NE(num_values, 0);
CHECK_EQ(num_values % num_indices, 0);
const int64_t feature_size = num_values / num_indices;
CHECK_EQ(feature_size, model_diff_values->shape().Count(model_diff_indices->shape().NumAxes()));
......@@ -312,7 +327,7 @@ class IndexedSlicesMomentumUpdateKernel final : public user_op::OpKernel {
user_op::Tensor* tmp_buffer = ctx->Tensor4ArgNameAndIndex("tmp_buffer", 0);
TmpBufferManager<device_type, T, K> buffer_manager(tmp_buffer->mut_dptr(), num_indices,
num_values);
CHECK_EQ(tmp_buffer->shape().elem_cnt(), buffer_manager.GetTotalBufferSize());
CHECK_GE(tmp_buffer->shape().elem_cnt(), buffer_manager.GetTotalBufferSize());
ReduceSumUtilT::ReduceSum(
ctx->device_ctx(), num_indices, feature_size, model_diff_indices->dptr<K>(),
model_diff_values->dptr<T>(), buffer_manager.NumUniqueDiffIndicesPtr(),
......@@ -429,13 +444,18 @@ class IndexedSlicesAdamUpdateKernel final : public user_op::OpKernel {
CHECK_EQ(model->shape().At(0), kernel_state->upper() - kernel_state->lower());
const int64_t num_indices = model_diff_indices->shape().elem_cnt();
const int64_t num_values = model_diff_values->shape().elem_cnt();
if (num_indices == 0) {
CHECK_EQ(num_values, 0);
return;
}
CHECK_NE(num_values, 0);
CHECK_EQ(num_values % num_indices, 0);
const int64_t feature_size = num_values / num_indices;
CHECK_EQ(feature_size, model_diff_values->shape().Count(model_diff_indices->shape().NumAxes()));
user_op::Tensor* tmp_buffer = ctx->Tensor4ArgNameAndIndex("tmp_buffer", 0);
TmpBufferManager<device_type, T, K> buffer_manager(tmp_buffer->mut_dptr(), num_indices,
num_values);
CHECK_EQ(tmp_buffer->shape().elem_cnt(), buffer_manager.GetTotalBufferSize());
CHECK_GE(tmp_buffer->shape().elem_cnt(), buffer_manager.GetTotalBufferSize());
ReduceSumUtilT::ReduceSum(
ctx->device_ctx(), num_indices, feature_size, model_diff_indices->dptr<K>(),
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册