未验证 提交 22823df2 编写于 作者: A Aurelius84 提交者: GitHub

enhance embedding error message test=develop (#20246)

* enhance embedding error message test=develop

* enforce .h error test=develop

* fix unittest code test=develop

* Fix fp16 dtype in embedding test=develop

* add import warnings test=develop
上级 9707ded3
......@@ -27,20 +27,28 @@ class LookupTableOp : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("W"),
"Input(W) of LookupTableOp should not be null.");
PADDLE_ENFORCE(ctx->HasInput("Ids"),
"Input(Ids) of LookupTableOp should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("Out"),
"Output(Out) of LookupTableOp should not be null.");
PADDLE_ENFORCE_EQ(ctx->HasInput("W"), true,
"Input(W) of LookupTableOp should not be null.");
PADDLE_ENFORCE_EQ(ctx->HasInput("Ids"), true,
"Input(Ids) of LookupTableOp should not be null.");
PADDLE_ENFORCE_EQ(ctx->HasOutput("Out"), true,
"Output(Out) of LookupTableOp should not be null.");
auto table_dims = ctx->GetInputDim("W");
auto ids_dims = ctx->GetInputDim("Ids");
int ids_rank = ids_dims.size();
VLOG(5) << "ids rank is " << ids_rank << std::endl;
PADDLE_ENFORCE_EQ(table_dims.size(), 2);
PADDLE_ENFORCE_EQ(ids_dims[ids_rank - 1], 1,
"The last dimension of the 'Ids' tensor must be 1.");
PADDLE_ENFORCE_EQ(
table_dims.size(), 2,
"ShapeError: The dimensions of the 'lookup table' must be 2. "
"But received lookup table's dimensions = %d, "
"lookup table's shape = [%s].",
table_dims.size(), table_dims);
PADDLE_ENFORCE_EQ(
ids_dims[ids_rank - 1], 1,
"ShapeError: The last dimensions of the 'Ids' tensor must be 1. "
"But received Ids's last dimensions = %d, Ids's shape = [%s].",
ids_dims[ids_rank - 1], ids_dims);
auto output_dims =
framework::vectorize(framework::slice_ddim(ids_dims, 0, ids_rank - 1));
......
......@@ -158,9 +158,14 @@ class LookupTableGradCUDAKernel : public framework::OpKernel<T> {
auto *d_table_data = d_table_value->data<T>();
auto *d_output_data = d_output->data<T>();
auto d_output_dims = d_output->dims();
PADDLE_ENFORCE_EQ(
d_table_value->dims(),
framework::flatten_to_2d(d_output_dims, d_output_dims.size() - 1));
auto d_output_dims_2d =
framework::flatten_to_2d(d_output_dims, d_output_dims.size() - 1);
PADDLE_ENFORCE_EQ(d_table_value->dims(), d_output_dims_2d,
"ShapeError: The shape of lookup_table@Grad and "
"output@Grad should be same. "
"But received lookup_table@Grad's shape = [%s], "
"output@Grad's shape = [%s].",
d_table_value->dims(), d_output_dims_2d);
memory::Copy(gpu_place, d_table_data, gpu_place, d_output_data,
d_output->numel() * sizeof(T), stream);
......
......@@ -113,9 +113,15 @@ class LookupTableKernel : public framework::OpKernel<T> {
if (padding_idx != kNoPadding && ids[i] == padding_idx) {
memset(output + i * row_width, 0, row_width * sizeof(T));
} else {
PADDLE_ENFORCE_GE(ids[i], 0);
PADDLE_ENFORCE_GE(
ids[i], 0,
"Variable value (input) of OP(fluid.layers.embedding) "
"expected >= 0. But received %ld",
ids[i]);
auto id_index = table_t.Index(ids[i]);
PADDLE_ENFORCE_GE(id_index, 0, "the input key should be exists.");
PADDLE_ENFORCE_GE(
id_index, 0, "the input key should be exists. But received %d.",
id_index);
blas.VCOPY(row_width, table + id_index * row_width,
output + i * row_width);
}
......@@ -180,9 +186,14 @@ class LookupTableGradKernel : public framework::OpKernel<T> {
auto *d_table_data = d_table_value->data<T>();
auto d_output_dims = d_output->dims();
PADDLE_ENFORCE_EQ(
d_table_value->dims(),
framework::flatten_to_2d(d_output_dims, d_output_dims.size() - 1));
auto d_output_dims_2d =
framework::flatten_to_2d(d_output_dims, d_output_dims.size() - 1);
PADDLE_ENFORCE_EQ(d_table_value->dims(), d_output_dims_2d,
"ShapeError: The shape of lookup_table@Grad and "
"output@Grad should be same. "
"But received lookup_table@Grad's shape = [%s], "
"output@Grad's shape = [%s].",
d_table_value->dims(), d_output_dims_2d);
memcpy(d_table_data, d_output_data, sizeof(T) * d_output->numel());
}
} else {
......
......@@ -38,7 +38,12 @@ class LookupTableV2Op : public framework::OperatorWithKernel {
auto ids_dims = ctx->GetInputDim("Ids");
int ids_rank = ids_dims.size();
VLOG(5) << "ids rank is " << ids_rank << std::endl;
PADDLE_ENFORCE_EQ(table_dims.size(), 2);
PADDLE_ENFORCE_EQ(
table_dims.size(), 2,
"ShapeError: The dimensions of the 'lookup table' must be 2. "
"But received lookup table's dimensions = %d, "
"lookup table's shape = [%s].",
table_dims.size(), table_dims);
auto output_dims = framework::vectorize(ids_dims);
output_dims.push_back(table_dims[1]);
......
......@@ -158,9 +158,14 @@ class LookupTableV2GradCUDAKernel : public framework::OpKernel<T> {
auto *d_table_data = d_table_value->data<T>();
auto *d_output_data = d_output->data<T>();
auto d_output_dims = d_output->dims();
PADDLE_ENFORCE_EQ(
d_table_value->dims(),
framework::flatten_to_2d(d_output_dims, d_output_dims.size() - 1));
auto d_output_dims_2d =
framework::flatten_to_2d(d_output_dims, d_output_dims.size() - 1);
PADDLE_ENFORCE_EQ(d_table_value->dims(), d_output_dims_2d,
"ShapeError: The shape of lookup_table@Grad and "
"output@Grad should be same. "
"But received lookup_table@Grad's shape = [%s], "
"output@Grad's shape = [%s].",
d_table_value->dims(), d_output_dims_2d);
memory::Copy(gpu_place, d_table_data, gpu_place, d_output_data,
d_output->numel() * sizeof(T), stream);
......
......@@ -113,9 +113,15 @@ class LookupTableV2Kernel : public framework::OpKernel<T> {
if (padding_idx != kNoPadding && ids[i] == padding_idx) {
memset(output + i * row_width, 0, row_width * sizeof(T));
} else {
PADDLE_ENFORCE_GE(ids[i], 0);
PADDLE_ENFORCE_GE(
ids[i], 0,
"Variable value (input) of OP(fluid.layers.embedding) "
"expected >= 0. But received %ld",
ids[i]);
auto id_index = table_t.Index(ids[i]);
PADDLE_ENFORCE_GE(id_index, 0, "the input key should be exists.");
PADDLE_ENFORCE_GE(
id_index, 0, "the input key should be exists. But received %d.",
id_index);
blas.VCOPY(row_width, table + id_index * row_width,
output + i * row_width);
}
......@@ -170,9 +176,14 @@ class LookupTableV2GradKernel : public framework::OpKernel<T> {
auto *d_table_data = d_table_value->data<T>();
auto d_output_dims = d_output->dims();
PADDLE_ENFORCE_EQ(
d_table_value->dims(),
framework::flatten_to_2d(d_output_dims, d_output_dims.size() - 1));
auto d_output_dims_2d =
framework::flatten_to_2d(d_output_dims, d_output_dims.size() - 1);
PADDLE_ENFORCE_EQ(d_table_value->dims(), d_output_dims_2d,
"ShapeError: The shape of lookup_table@Grad and "
"output@Grad should be same. "
"But received lookup_table@Grad's shape = [%s], "
"output@Grad's shape = [%s].",
d_table_value->dims(), d_output_dims_2d);
memcpy(d_table_data, d_output_data, sizeof(T) * d_output->numel());
} else {
......
......@@ -13,8 +13,10 @@
# limitations under the License.
from __future__ import print_function
import warnings
from .framework import Variable, in_dygraph_mode
from .layer_helper import LayerHelper
from .data_feeder import convert_dtype
__all__ = ['one_hot', 'embedding']
......@@ -231,6 +233,21 @@ def embedding(input,
"""
helper = LayerHelper('embedding', **locals())
if not isinstance(input, Variable):
raise TypeError(
"The type of 'input' in fluid.embedding must be Variable, but received %s"
% (type(input)))
if convert_dtype(input.dtype) not in ['int64']:
raise TypeError(
"The data type of 'input' in fluid.embedding must be int64, but received %s."
% (convert_dtype(input.dtype)))
if convert_dtype(dtype) in ['float16']:
warnings.warn(
"The 'dtype' of fluid.embedding only support float16 in GPU now.")
if convert_dtype(dtype) not in ['float16', 'float32', 'float64']:
raise TypeError(
"The 'dtype' of fluid.embedding must be float16, float32 or float64, but received %s."
% (convert_dtype(dtype)))
remote_prefetch = is_sparse and (not is_distributed)
if remote_prefetch:
assert is_sparse is True and is_distributed is False
......
......@@ -601,6 +601,21 @@ def embedding(input,
"""
helper = LayerHelper('embedding', **locals())
if not isinstance(input, Variable):
raise TypeError(
"The type of 'input' in layers.embedding must be Variable, but received %s"
% (type(input)))
if convert_dtype(input.dtype) not in ['int64']:
raise TypeError(
"The data type of 'input' in layers.embedding must be int64, but received %s."
% (convert_dtype(input.dtype)))
if convert_dtype(dtype) in ['float16']:
warnings.warn(
"The 'dtype' of layers.embedding only support float16 in GPU now.")
if convert_dtype(dtype) not in ['float16', 'float32', 'float64']:
raise TypeError(
"The 'dtype' of layers.embedding must be float16, float32 or float64, but received %s."
% (convert_dtype(dtype)))
remote_prefetch = is_sparse and (not is_distributed)
if remote_prefetch:
assert is_sparse is True and is_distributed is False
......
......@@ -40,7 +40,7 @@ class TestListenAndServOp(OpTest):
print(sys.platform)
cmd = "wget --no-check-certificate https://pslib.bj.bcebos.com/fleet_desc.prototxt"
os.system(cmd)
x = fluid.layers.data(name='x', shape=[1], dtype='float32')
x = fluid.layers.data(name='x', shape=[1], dtype='int64')
x_emb = fluid.layers.embedding(
input=x, size=[1, 2], is_distributed=True)
y_predict = fluid.layers.fc(input=x_emb, size=1, act=None)
......@@ -96,7 +96,7 @@ class TestListenAndServOp(OpTest):
print(sys.platform)
cmd = "wget --no-check-certificate https://pslib.bj.bcebos.com/fleet_desc.prototxt"
os.system(cmd)
x = fluid.layers.data(name='x', shape=[1], dtype='float32')
x = fluid.layers.data(name='x', shape=[1], dtype='int64')
x_emb = fluid.layers.embedding(
input=x, size=[1, 2], is_distributed=True)
y_predict = fluid.layers.fc(input=x_emb, size=1, act=None)
......
......@@ -20,6 +20,8 @@ from op_test import OpTest
import paddle.fluid.core as core
from paddle.fluid.op import Operator
import paddle.compat as cpt
import paddle.fluid as fluid
from paddle.fluid import Program, program_guard
class TestLookupTableOp(OpTest):
......@@ -150,5 +152,35 @@ class TestLookupTableWithTensorIdsWIsSelectedRows(
assert (row == result_array[idx]).all()
class TestEmbedOpError(OpTest):
def test_errors(self):
with program_guard(Program(), Program()):
input_data = np.random.randint(0, 10, (4, 1)).astype("int64")
def test_Variable():
# the input type must be Variable
fluid.layers.embedding(input=input_data, size=(10, 64))
self.assertRaises(TypeError, test_Variable)
def test_input_dtype():
# the input dtype must be int64
input = fluid.data(name='x', shape=[4, 1], dtype='float32')
fluid.layers.embedding(input=input, size=(10, 64))
self.assertRaises(TypeError, test_input_dtype)
def test_param_dtype():
# dtype must be float32 or float64
input2 = fluid.data(name='x2', shape=[4, 1], dtype='int64')
fluid.layers.embedding(
input=input2, size=(10, 64), dtype='int64')
self.assertRaises(TypeError, test_param_dtype)
input3 = fluid.data(name='x3', shape=[4, 1], dtype='int64')
fluid.layers.embedding(input=input3, size=(10, 64), dtype='float16')
if __name__ == "__main__":
unittest.main()
......@@ -21,6 +21,8 @@ import paddle.fluid.core as core
import paddle.fluid as fluid
from paddle.fluid.op import Operator
import paddle.compat as cpt
import paddle.fluid as fluid
from paddle.fluid import Program, program_guard
class TestLookupTableOp(OpTest):
......@@ -212,5 +214,33 @@ class TestLookupTableApi(unittest.TestCase):
return_numpy=False)
class TestEmbedOpError(OpTest):
def test_errors(self):
with program_guard(Program(), Program()):
input_data = np.random.randint(0, 10, (4, 6)).astype("int64")
def test_Variable():
# the input type must be Variable
fluid.embedding(input=input_data, size=(10, 64))
self.assertRaises(TypeError, test_Variable)
def test_input_dtype():
# the input dtype must be int64
input = fluid.data(name='x1', shape=[4, 6], dtype='float32')
fluid.embedding(input=input, size=(10, 64))
self.assertRaises(TypeError, test_input_dtype)
def test_param_dtype():
# dtype must be float32 or float64
input2 = fluid.data(name='x2', shape=[4, 6], dtype='int64')
fluid.embedding(input=input2, size=(10, 64), dtype='int64')
self.assertRaises(TypeError, test_param_dtype)
input3 = fluid.data(name='x3', shape=[4, 6], dtype='int64')
fluid.embedding(input=input3, size=(10, 64), dtype='float16')
if __name__ == "__main__":
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册