未验证 提交 57fe4fc9 编写于 作者: S Siming Dai 提交者: GitHub

[BugFix] Add error hint for one_hot gpu version (#41335) (#41495)

* add one_hot gpu hint

* move allow_out_of_range judgement

* delete useless unittest
上级 7d143a48
...@@ -25,18 +25,12 @@ struct OneHotV2OpFunctor { ...@@ -25,18 +25,12 @@ struct OneHotV2OpFunctor {
DenseTensor* out_; DenseTensor* out_;
int depth_; int depth_;
const DeviceContext& ctx_; const DeviceContext& ctx_;
bool allow_out_of_range_;
OneHotV2OpFunctor(const DenseTensor* in, OneHotV2OpFunctor(const DenseTensor* in,
DenseTensor* out, DenseTensor* out,
int depth, int depth,
const DeviceContext& ctx, const DeviceContext& ctx)
bool allow_out_of_range = false) : in_(in), out_(out), depth_(depth), ctx_(ctx) {}
: in_(in),
out_(out),
depth_(depth),
ctx_(ctx),
allow_out_of_range_(allow_out_of_range) {}
template <typename OutT> template <typename OutT>
void apply() const { void apply() const {
...@@ -45,32 +39,24 @@ struct OneHotV2OpFunctor { ...@@ -45,32 +39,24 @@ struct OneHotV2OpFunctor {
auto* p_out_data = ctx_.template Alloc<OutT>(out_); auto* p_out_data = ctx_.template Alloc<OutT>(out_);
funcs::set_constant(ctx_, out_, 0.0); funcs::set_constant(ctx_, out_, 0.0);
if (allow_out_of_range_) { for (int i = 0; i < numel; ++i) {
for (int i = 0; i < numel; ++i) { PADDLE_ENFORCE_GE(
if (p_in_data[i] >= 0 && p_in_data[i] < depth_) { p_in_data[i],
*(p_out_data + i * depth_ + p_in_data[i]) = 1.0; 0,
} phi::errors::InvalidArgument(
} "Illegal index value, Input(input) value should be at least 0, "
} else { "but received input (%d) less than 0",
for (int i = 0; i < numel; ++i) { p_in_data[i]));
PADDLE_ENFORCE_GE( PADDLE_ENFORCE_LT(
p_in_data[i], p_in_data[i],
0, depth_,
phi::errors::InvalidArgument( phi::errors::InvalidArgument(
"Illegal index value, Input(input) value should be at least 0, " "Illegal index value, Input(input) value should be less than "
"but received input (%d) less than 0", "Input(depth), "
p_in_data[i])); "but received input (%d) not less than depth (%d)",
PADDLE_ENFORCE_LT( p_in_data[i],
p_in_data[i], depth_));
depth_, *(p_out_data + i * depth_ + p_in_data[i]) = 1.0;
phi::errors::InvalidArgument(
"Illegal index value, Input(input) value should be less than "
"Input(depth), "
"but received input (%d) not less than depth (%d)",
p_in_data[i],
depth_));
*(p_out_data + i * depth_ + p_in_data[i]) = 1.0;
}
} }
} }
}; };
...@@ -89,8 +75,7 @@ void OneHotRawKernel(const Context& dev_ctx, ...@@ -89,8 +75,7 @@ void OneHotRawKernel(const Context& dev_ctx,
} }
phi::VisitDataType(dtype, phi::VisitDataType(dtype,
OneHotV2OpFunctor<Context, T>( OneHotV2OpFunctor<Context, T>(&x, out, depth, dev_ctx));
&x, out, depth, dev_ctx, allow_out_of_range));
} }
} // namespace phi } // namespace phi
......
...@@ -29,7 +29,14 @@ __global__ void FillOutputKernel(const InT* p_in_data, ...@@ -29,7 +29,14 @@ __global__ void FillOutputKernel(const InT* p_in_data,
const int64_t numel, const int64_t numel,
const int depth) { const int depth) {
int idx = blockIdx.x * blockDim.x + threadIdx.x; int idx = blockIdx.x * blockDim.x + threadIdx.x;
if (idx < numel && p_in_data[idx] >= 0 && p_in_data[idx] < depth) { if (idx < numel) {
PADDLE_ENFORCE(p_in_data[idx] >= 0 && p_in_data[idx] < depth,
"Illegal index value, Input(input) value should be "
"greater than or equal to 0, and less than depth [%d], "
"but received [%lld].",
depth,
p_in_data[idx]);
*(p_out_data + (idx * depth) + p_in_data[idx]) = 1.0; *(p_out_data + (idx * depth) + p_in_data[idx]) = 1.0;
} }
} }
......
...@@ -117,24 +117,6 @@ class TestOneHotOp_default_dtype_attr(OpTest): ...@@ -117,24 +117,6 @@ class TestOneHotOp_default_dtype_attr(OpTest):
self.check_output() self.check_output()
class TestOneHotOp_out_of_range(OpTest):
def setUp(self):
self.op_type = 'one_hot_v2'
depth = 10
x_lod = [[4, 1, 3, 3]]
x = [np.random.choice([-1, depth]) for i in range(sum(x_lod[0]))]
x = np.array(x).astype('int32').reshape([sum(x_lod[0])])
out = np.zeros(shape=(np.product(x.shape), depth)).astype('float32')
self.inputs = {'X': (x, x_lod)}
self.attrs = {'depth': depth, 'allow_out_of_range': True}
self.outputs = {'Out': (out, x_lod)}
def test_check_output(self):
self.check_output()
class TestOneHotOp_exception(unittest.TestCase): class TestOneHotOp_exception(unittest.TestCase):
def setUp(self): def setUp(self):
self.op_type = 'one_hot_v2' self.op_type = 'one_hot_v2'
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册