未验证 提交 91266b96 编写于 作者: S Siming Dai 提交者: GitHub

[BugFix] Add error hint for one_hot gpu version (#41335)

* add one_hot gpu hint

* move allow_out_of_range judgement

* delete useless unittest
上级 dbd6e2df
......@@ -25,18 +25,12 @@ struct OneHotV2OpFunctor {
DenseTensor* out_;
int depth_;
const DeviceContext& ctx_;
bool allow_out_of_range_;
OneHotV2OpFunctor(const DenseTensor* in,
DenseTensor* out,
int depth,
const DeviceContext& ctx,
bool allow_out_of_range = false)
: in_(in),
out_(out),
depth_(depth),
ctx_(ctx),
allow_out_of_range_(allow_out_of_range) {}
const DeviceContext& ctx)
: in_(in), out_(out), depth_(depth), ctx_(ctx) {}
template <typename OutT>
void apply() const {
......@@ -45,32 +39,24 @@ struct OneHotV2OpFunctor {
auto* p_out_data = ctx_.template Alloc<OutT>(out_);
funcs::set_constant(ctx_, out_, 0.0);
if (allow_out_of_range_) {
for (int i = 0; i < numel; ++i) {
if (p_in_data[i] >= 0 && p_in_data[i] < depth_) {
*(p_out_data + i * depth_ + p_in_data[i]) = 1.0;
}
}
} else {
for (int i = 0; i < numel; ++i) {
PADDLE_ENFORCE_GE(
p_in_data[i],
0,
phi::errors::InvalidArgument(
"Illegal index value, Input(input) value should be at least 0, "
"but received input (%d) less than 0",
p_in_data[i]));
PADDLE_ENFORCE_LT(
p_in_data[i],
depth_,
phi::errors::InvalidArgument(
"Illegal index value, Input(input) value should be less than "
"Input(depth), "
"but received input (%d) not less than depth (%d)",
p_in_data[i],
depth_));
*(p_out_data + i * depth_ + p_in_data[i]) = 1.0;
}
for (int i = 0; i < numel; ++i) {
PADDLE_ENFORCE_GE(
p_in_data[i],
0,
phi::errors::InvalidArgument(
"Illegal index value, Input(input) value should be at least 0, "
"but received input (%d) less than 0",
p_in_data[i]));
PADDLE_ENFORCE_LT(
p_in_data[i],
depth_,
phi::errors::InvalidArgument(
"Illegal index value, Input(input) value should be less than "
"Input(depth), "
"but received input (%d) not less than depth (%d)",
p_in_data[i],
depth_));
*(p_out_data + i * depth_ + p_in_data[i]) = 1.0;
}
}
};
......@@ -89,8 +75,7 @@ void OneHotRawKernel(const Context& dev_ctx,
}
phi::VisitDataType(dtype,
OneHotV2OpFunctor<Context, T>(
&x, out, depth, dev_ctx, allow_out_of_range));
OneHotV2OpFunctor<Context, T>(&x, out, depth, dev_ctx));
}
} // namespace phi
......
......@@ -29,7 +29,14 @@ __global__ void FillOutputKernel(const InT* p_in_data,
const int64_t numel,
const int depth) {
int idx = blockIdx.x * blockDim.x + threadIdx.x;
if (idx < numel && p_in_data[idx] >= 0 && p_in_data[idx] < depth) {
if (idx < numel) {
PADDLE_ENFORCE(p_in_data[idx] >= 0 && p_in_data[idx] < depth,
"Illegal index value, Input(input) value should be "
"greater than or equal to 0, and less than depth [%d], "
"but received [%lld].",
depth,
p_in_data[idx]);
*(p_out_data + (idx * depth) + p_in_data[idx]) = 1.0;
}
}
......
......@@ -117,24 +117,6 @@ class TestOneHotOp_default_dtype_attr(OpTest):
self.check_output()
class TestOneHotOp_out_of_range(OpTest):
def setUp(self):
self.op_type = 'one_hot_v2'
depth = 10
x_lod = [[4, 1, 3, 3]]
x = [np.random.choice([-1, depth]) for i in range(sum(x_lod[0]))]
x = np.array(x).astype('int32').reshape([sum(x_lod[0])])
out = np.zeros(shape=(np.product(x.shape), depth)).astype('float32')
self.inputs = {'X': (x, x_lod)}
self.attrs = {'depth': depth, 'allow_out_of_range': True}
self.outputs = {'Out': (out, x_lod)}
def test_check_output(self):
self.check_output()
class TestOneHotOp_exception(unittest.TestCase):
def setUp(self):
self.op_type = 'one_hot_v2'
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册