提交 feac69aa 编写于 作者: J Jiabin Yang 提交者: liuwei1031

test=release/1.4, fix hsigmoid dereference nullptr (#16770)

* test=release/1.4, fix hsigmoid dereference nullptr

* test=release/1.4, refine condition

* test=release/1.4, refine comments
上级 af53eb6a
...@@ -234,6 +234,8 @@ class HierarchicalSigmoidGradOpKernel : public framework::OpKernel<T> { ...@@ -234,6 +234,8 @@ class HierarchicalSigmoidGradOpKernel : public framework::OpKernel<T> {
zero(dev_ctx, w_grad, static_cast<T>(0.0)); zero(dev_ctx, w_grad, static_cast<T>(0.0));
bit_code->MulGradWeight(pre_out_grad, w_grad, in); bit_code->MulGradWeight(pre_out_grad, w_grad, in);
} else { } else {
PADDLE_ENFORCE(path != nullptr,
"Sparse mode should not be used without custom tree!");
framework::Vector<int64_t> real_rows = PathToRows(*path); framework::Vector<int64_t> real_rows = PathToRows(*path);
auto* w_grad = auto* w_grad =
ctx.Output<framework::SelectedRows>(framework::GradVarName("W")); ctx.Output<framework::SelectedRows>(framework::GradVarName("W"));
......
...@@ -5589,12 +5589,21 @@ def hsigmoid(input, ...@@ -5589,12 +5589,21 @@ def hsigmoid(input,
raise ValueError( raise ValueError(
"num_classes must not be less than 2 with default tree") "num_classes must not be less than 2 with default tree")
if (not is_custom) and (is_sparse):
print("Sparse mode should not be used without custom tree")
is_sparse = False
if (not is_custom) and ((path_table is not None) or
(path_code is not None)):
raise ValueError(
"only num_classes should be passed without custom tree")
if (is_custom) and (path_code is None): if (is_custom) and (path_code is None):
raise ValueError("path_code should not be None with costum tree") raise ValueError("path_code should not be None with custom tree")
elif (is_custom) and (path_table is None): elif (is_custom) and (path_table is None):
raise ValueError("path_table should not be None with costum tree") raise ValueError("path_table should not be None with custom tree")
elif (is_custom) and (num_classes is None): elif (is_custom) and (num_classes is None):
raise ValueError("num_classes should not be None with costum tree") raise ValueError("num_classes should not be None with custom tree")
else: else:
pass pass
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册