未验证 提交 16817c70 编写于 作者: Y yaoxuefeng 提交者: GitHub

OP(datanorm lookupsparsetable lookuptable) error message enhancement (#24506)

* OP(datanorm lookupsparsetable lookuptable) error message enhancement

* fix test=develop

* fix test=develop

* fix test=develop

* fix test=develop

* fix test=develop

* fix test=develop

* fix test=develop
上级 fff9faae
...@@ -44,13 +44,15 @@ class DataNormOp : public framework::OperatorWithKernel { ...@@ -44,13 +44,15 @@ class DataNormOp : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel; using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext *ctx) const override { void InferShape(framework::InferShapeContext *ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("X"), ""); OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "DataNorm");
PADDLE_ENFORCE(ctx->HasInput("BatchSize"), ""); OP_INOUT_CHECK(ctx->HasInput("BatchSize"), "Input", "BatchSize",
PADDLE_ENFORCE(ctx->HasInput("BatchSum"), ""); "DataNorm");
PADDLE_ENFORCE(ctx->HasInput("BatchSquareSum"), ""); OP_INOUT_CHECK(ctx->HasInput("BatchSum"), "Input", "BatchSum", "DataNorm");
PADDLE_ENFORCE(ctx->HasOutput("Means"), ""); OP_INOUT_CHECK(ctx->HasInput("BatchSquareSum"), "Input", "BatchSquareSum",
PADDLE_ENFORCE(ctx->HasOutput("Scales"), ""); "DataNorm");
PADDLE_ENFORCE(ctx->HasOutput("Y"), ""); OP_INOUT_CHECK(ctx->HasOutput("Means"), "Output", "Means", "DataNorm");
OP_INOUT_CHECK(ctx->HasOutput("Scales"), "Output", "Scales", "DataNorm");
OP_INOUT_CHECK(ctx->HasOutput("Y"), "Output", "Y", "DataNorm");
bool enable_scale_and_shift = bool enable_scale_and_shift =
ctx->Attrs().Get<bool>("enable_scale_and_shift"); ctx->Attrs().Get<bool>("enable_scale_and_shift");
if (enable_scale_and_shift) { if (enable_scale_and_shift) {
...@@ -67,20 +69,33 @@ class DataNormOp : public framework::OperatorWithKernel { ...@@ -67,20 +69,33 @@ class DataNormOp : public framework::OperatorWithKernel {
const DataLayout data_layout = framework::StringToDataLayout( const DataLayout data_layout = framework::StringToDataLayout(
ctx->Attrs().Get<std::string>("data_layout")); ctx->Attrs().Get<std::string>("data_layout"));
PADDLE_ENFORCE(x_dims.size() >= 2 && x_dims.size() <= 5, PADDLE_ENFORCE_EQ(x_dims.size() >= 2 && x_dims.size() <= 5, true,
"Input X must have 2 to 5 dimensions."); platform::errors::InvalidArgument(
"Input X must have 2 to 5 dimensions."));
const int64_t C = const int64_t C =
(data_layout == DataLayout::kNCHW ? x_dims[1] (data_layout == DataLayout::kNCHW ? x_dims[1]
: x_dims[x_dims.size() - 1]); : x_dims[x_dims.size() - 1]);
PADDLE_ENFORCE_EQ(ctx->GetInputDim("BatchSize").size(), 1UL); PADDLE_ENFORCE_EQ(ctx->GetInputDim("BatchSize").size(), 1UL,
PADDLE_ENFORCE_EQ(ctx->GetInputDim("BatchSum").size(), 1UL); platform::errors::InvalidArgument(
PADDLE_ENFORCE_EQ(ctx->GetInputDim("BatchSquareSum").size(), 1UL); "The input dim of BatchSize shouold be 1"));
PADDLE_ENFORCE_EQ(ctx->GetInputDim("BatchSum").size(), 1UL,
platform::errors::InvalidArgument(
"The input dim of BatchSum shouold be 1"));
PADDLE_ENFORCE_EQ(ctx->GetInputDim("BatchSquareSum").size(), 1UL,
platform::errors::InvalidArgument(
"The input dim of BatchSquareSum shouold be 1"));
if (ctx->IsRuntime()) { if (ctx->IsRuntime()) {
PADDLE_ENFORCE_EQ(ctx->GetInputDim("BatchSize")[0], C); PADDLE_ENFORCE_EQ(ctx->GetInputDim("BatchSize")[0], C,
PADDLE_ENFORCE_EQ(ctx->GetInputDim("BatchSum")[0], C); platform::errors::InvalidArgument(
PADDLE_ENFORCE_EQ(ctx->GetInputDim("BatchSquareSum")[0], C); "The input dim[0] of BatchSize shouold be C"));
PADDLE_ENFORCE_EQ(ctx->GetInputDim("BatchSum")[0], C,
platform::errors::InvalidArgument(
"The input dim[0] of BatchSum shouold be C"));
PADDLE_ENFORCE_EQ(ctx->GetInputDim("BatchSquareSum")[0], C,
platform::errors::InvalidArgument(
"The input dim[0] of BatchSqureSum shouold be C"));
} }
if (enable_scale_and_shift) { if (enable_scale_and_shift) {
...@@ -141,13 +156,16 @@ class DataNormOp : public framework::OperatorWithKernel { ...@@ -141,13 +156,16 @@ class DataNormOp : public framework::OperatorWithKernel {
} }
PADDLE_ENFORCE_EQ(dn_param_type, PADDLE_ENFORCE_EQ(dn_param_type,
OperatorWithKernel::IndicateVarDataType(ctx, "BatchSize"), OperatorWithKernel::IndicateVarDataType(ctx, "BatchSize"),
"BatchSize input should be of float type"); platform::errors::InvalidArgument(
"BatchSize input should be of float type"));
PADDLE_ENFORCE_EQ(dn_param_type, PADDLE_ENFORCE_EQ(dn_param_type,
OperatorWithKernel::IndicateVarDataType(ctx, "BatchSum"), OperatorWithKernel::IndicateVarDataType(ctx, "BatchSum"),
"BatchSum input should be of float type"); platform::errors::InvalidArgument(
"BatchSum input should be of float type"));
PADDLE_ENFORCE_EQ(dn_param_type, OperatorWithKernel::IndicateVarDataType( PADDLE_ENFORCE_EQ(dn_param_type, OperatorWithKernel::IndicateVarDataType(
ctx, "BatchSquareSum"), ctx, "BatchSquareSum"),
"BatchSquareSum input should be of float type"); platform::errors::InvalidArgument(
"BatchSquareSum input should be of float type"));
bool enable_scale_and_shift = ctx.Attr<bool>("enable_scale_and_shift"); bool enable_scale_and_shift = ctx.Attr<bool>("enable_scale_and_shift");
if (enable_scale_and_shift) { if (enable_scale_and_shift) {
...@@ -183,8 +201,9 @@ class DataNormOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -183,8 +201,9 @@ class DataNormOpMaker : public framework::OpProtoAndCheckerMaker {
AddAttr<float>("epsilon", "") AddAttr<float>("epsilon", "")
.SetDefault(1e-4) .SetDefault(1e-4)
.AddCustomChecker([](const float &epsilon) { .AddCustomChecker([](const float &epsilon) {
PADDLE_ENFORCE(epsilon >= 0.0f && epsilon <= 0.001f, PADDLE_ENFORCE_EQ(epsilon >= 0.0f && epsilon <= 0.001f, true,
"'epsilon' should be between 0.0 and 0.001."); platform::errors::InvalidArgument(
"'epsilon' should be between 0.0 and 0.001."));
}); });
AddAttr<int>("slot_dim", AddAttr<int>("slot_dim",
"(int, default -1) Dimension of one slot if set, " "(int, default -1) Dimension of one slot if set, "
...@@ -256,7 +275,8 @@ class DataNormKernel<platform::CPUDeviceContext, T> ...@@ -256,7 +275,8 @@ class DataNormKernel<platform::CPUDeviceContext, T>
const auto *x = ctx.Input<Tensor>("X"); const auto *x = ctx.Input<Tensor>("X");
const auto &x_dims = x->dims(); const auto &x_dims = x->dims();
PADDLE_ENFORCE(x_dims.size() == 2, "The Input dim size should be 2"); PADDLE_ENFORCE_EQ(x_dims.size(), 2, platform::errors::InvalidArgument(
"The Input dim size should be 2"));
const int N = x_dims[0]; const int N = x_dims[0];
const int C = const int C =
(data_layout == DataLayout::kNCHW ? x_dims[1] (data_layout == DataLayout::kNCHW ? x_dims[1]
...@@ -379,8 +399,9 @@ class DataNormGradOp : public framework::OperatorWithKernel { ...@@ -379,8 +399,9 @@ class DataNormGradOp : public framework::OperatorWithKernel {
void InferShape(framework::InferShapeContext *ctx) const override { void InferShape(framework::InferShapeContext *ctx) const override {
// check input // check input
PADDLE_ENFORCE(ctx->HasInput("X")); OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "DataNormGrad");
PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Y")), ""); OP_INOUT_CHECK(ctx->HasInput(framework::GradVarName("Y")), "Input",
framework::GradVarName("Y"), "DataNormGrad");
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
ctx->HasOutput("BatchSize"), true, ctx->HasOutput("BatchSize"), true,
platform::errors::NotFound( platform::errors::NotFound(
...@@ -393,15 +414,19 @@ class DataNormGradOp : public framework::OperatorWithKernel { ...@@ -393,15 +414,19 @@ class DataNormGradOp : public framework::OperatorWithKernel {
ctx->HasOutput("BatchSquareSum"), true, ctx->HasOutput("BatchSquareSum"), true,
platform::errors::NotFound( platform::errors::NotFound(
"Output(BatchSquareSum) of DataNormGradOp should not be null.")); "Output(BatchSquareSum) of DataNormGradOp should not be null."));
PADDLE_ENFORCE(ctx->HasInput("Means"), ""); OP_INOUT_CHECK(ctx->HasInput("Means"), "Input", "Means", "DataNormGrad");
PADDLE_ENFORCE(ctx->HasInput("Scales"), ""); OP_INOUT_CHECK(ctx->HasInput("Scales"), "Input", "Scales", "DataNormGrad");
bool enable_scale_and_shift = bool enable_scale_and_shift =
ctx->Attrs().Get<bool>("enable_scale_and_shift"); ctx->Attrs().Get<bool>("enable_scale_and_shift");
// check output // check output
PADDLE_ENFORCE(ctx->HasOutput(framework::GradVarName("BatchSize")), ""); OP_INOUT_CHECK(ctx->HasOutput(framework::GradVarName("BatchSize")),
PADDLE_ENFORCE(ctx->HasOutput(framework::GradVarName("BatchSum")), ""); "Output", framework::GradVarName("BatchSize"),
PADDLE_ENFORCE(ctx->HasOutput(framework::GradVarName("BatchSquareSum")), "DataNormGrad");
""); OP_INOUT_CHECK(ctx->HasOutput(framework::GradVarName("BatchSum")), "Output",
framework::GradVarName("BatchSum"), "DataNormGrad");
OP_INOUT_CHECK(ctx->HasOutput(framework::GradVarName("BatchSquareSum")),
"Output", framework::GradVarName("BatchSquareSum"),
"DataNormGrad");
const auto x_dims = ctx->GetInputDim("X"); const auto x_dims = ctx->GetInputDim("X");
const DataLayout data_layout = framework::StringToDataLayout( const DataLayout data_layout = framework::StringToDataLayout(
...@@ -486,7 +511,8 @@ class DataNormGradKernel<platform::CPUDeviceContext, T> ...@@ -486,7 +511,8 @@ class DataNormGradKernel<platform::CPUDeviceContext, T>
// Get the size for each dimension. // Get the size for each dimension.
// NCHW [batch_size, in_channels, in_height, in_width] // NCHW [batch_size, in_channels, in_height, in_width]
const auto &x_dims = x->dims(); const auto &x_dims = x->dims();
PADDLE_ENFORCE(x_dims.size() == 2, "The Input dim size should be 2"); PADDLE_ENFORCE_EQ(x_dims.size(), 2, platform::errors::InvalidArgument(
"The Input dim size should be 2"));
const int N = x_dims[0]; const int N = x_dims[0];
const int C = const int C =
(data_layout == DataLayout::kNCHW ? x_dims[1] (data_layout == DataLayout::kNCHW ? x_dims[1]
......
...@@ -82,24 +82,28 @@ class MergeIdsOp : public framework::OperatorWithKernel { ...@@ -82,24 +82,28 @@ class MergeIdsOp : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel; using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext *ctx) const override { void InferShape(framework::InferShapeContext *ctx) const override {
PADDLE_ENFORCE(ctx->HasInputs("Ids"), OP_INOUT_CHECK(ctx->HasInputs("Ids"), "Input", "Ids", "MergeIds");
"MergeIdsOp must have multi input Ids."); OP_INOUT_CHECK(ctx->HasInputs("Rows"), "Input", "Rows", "MergeIds");
PADDLE_ENFORCE(ctx->HasInputs("Rows"), OP_INOUT_CHECK(ctx->HasInputs("X"), "Input", "X", "MergeIds");
"MergeIdsOp must have multi input Rows."); OP_INOUT_CHECK(ctx->HasOutputs("Out"), "Output", "Out", "MergeIds");
PADDLE_ENFORCE(ctx->HasInputs("X"), "MergeIdsOp must have multi input X.");
PADDLE_ENFORCE(ctx->HasOutputs("Out"),
"MergeIdsOp must have multi output Out.");
auto ids_var_type = ctx->GetInputsVarType("Ids").front(); auto ids_var_type = ctx->GetInputsVarType("Ids").front();
auto ids_dims = ctx->GetInputsDim("Ids"); auto ids_dims = ctx->GetInputsDim("Ids");
if (ids_var_type == framework::proto::VarType::LOD_TENSOR) { if (ids_var_type == framework::proto::VarType::LOD_TENSOR) {
PADDLE_ENFORCE_EQ(ids_dims[0].size(), 2); PADDLE_ENFORCE_EQ(
PADDLE_ENFORCE_EQ(ids_dims[0][1], 1); ids_dims[0].size(), 2,
platform::errors::InvalidArgument(
"the ids size must be 2, but received %d", ids_dims[0].size()));
PADDLE_ENFORCE_EQ(
ids_dims[0][1], 1,
platform::errors::InvalidArgument(
"the ids dim must be 1, but received %d", ids_dims[0][1]));
} }
auto x_var_type = ctx->GetInputsVarType("X"); auto x_var_type = ctx->GetInputsVarType("X");
for (auto &var_type : x_var_type) { for (auto &var_type : x_var_type) {
PADDLE_ENFORCE_EQ(var_type, framework::proto::VarType::LOD_TENSOR, PADDLE_ENFORCE_EQ(var_type, framework::proto::VarType::LOD_TENSOR,
"input X only support lod tensors"); platform::errors::InvalidArgument(
"input X only support lod tensors"));
} }
ctx->ShareLoD("Ids", "Out"); ctx->ShareLoD("Ids", "Out");
} }
......
...@@ -39,9 +39,11 @@ class MergeIdsOpKernel : public framework::OpKernel<T> { ...@@ -39,9 +39,11 @@ class MergeIdsOpKernel : public framework::OpKernel<T> {
auto outs = ctx.MultiOutput<framework::LoDTensor>("Out"); auto outs = ctx.MultiOutput<framework::LoDTensor>("Out");
PADDLE_ENFORCE_EQ(row_ids.size(), x_tensors.size(), PADDLE_ENFORCE_EQ(row_ids.size(), x_tensors.size(),
"the number of Rows and X should be the same"); platform::errors::InvalidArgument(
"the number of Rows and X should be the same"));
PADDLE_ENFORCE_EQ(ids.size(), outs.size(), PADDLE_ENFORCE_EQ(ids.size(), outs.size(),
"the number of Ids and Out should be the same"); platform::errors::InvalidArgument(
"the number of Ids and Out should be the same"));
int64_t row_ids_size = 0; int64_t row_ids_size = 0;
int64_t row_size = 0; int64_t row_size = 0;
...@@ -55,14 +57,16 @@ class MergeIdsOpKernel : public framework::OpKernel<T> { ...@@ -55,14 +57,16 @@ class MergeIdsOpKernel : public framework::OpKernel<T> {
embedding_size = x_tensor->dims()[1]; embedding_size = x_tensor->dims()[1];
} }
PADDLE_ENFORCE_EQ(embedding_size, x_tensor->dims()[1], PADDLE_ENFORCE_EQ(embedding_size, x_tensor->dims()[1],
"embedding size of all input should be the same"); platform::errors::InvalidArgument(
"embedding size of all input should be the same"));
row_size += x_tensor->dims()[0]; row_size += x_tensor->dims()[0];
row_ids_size += row_id->dims()[0]; row_ids_size += row_id->dims()[0];
} }
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
row_size, row_ids_size, row_size, row_ids_size,
"the merged X dim[0] and merged Rows dim[0] should be the same"); platform::errors::InvalidArgument(
"the merged X dim[0] and merged Rows dim[0] should be the same"));
std::unordered_map<int64_t, std::tuple<int64_t, int64_t>> std::unordered_map<int64_t, std::tuple<int64_t, int64_t>>
selected_rows_idx_map; selected_rows_idx_map;
...@@ -76,7 +80,8 @@ class MergeIdsOpKernel : public framework::OpKernel<T> { ...@@ -76,7 +80,8 @@ class MergeIdsOpKernel : public framework::OpKernel<T> {
} }
} }
PADDLE_ENFORCE_EQ(row_ids_size, selected_rows_idx_map.size(), PADDLE_ENFORCE_EQ(row_ids_size, selected_rows_idx_map.size(),
"the rows and tensor map size should be the same"); platform::errors::InvalidArgument(
"the rows and tensor map size should be the same"));
for (size_t i = 0; i < outs.size(); ++i) { for (size_t i = 0; i < outs.size(); ++i) {
auto *out_ids = ids[i]; auto *out_ids = ids[i];
......
...@@ -26,8 +26,7 @@ constexpr int64_t kNoPadding = -1; ...@@ -26,8 +26,7 @@ constexpr int64_t kNoPadding = -1;
class LookupSparseTableInferShape : public framework::InferShapeBase { class LookupSparseTableInferShape : public framework::InferShapeBase {
public: public:
void operator()(framework::InferShapeContext *ctx) const override { void operator()(framework::InferShapeContext *ctx) const override {
PADDLE_ENFORCE(ctx->HasOutput("Out"), OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "LookupSparseTable");
"Output(Out) of LookupSparseTableOp should not be null.");
auto shape_w = ctx->GetInputDim("W"); auto shape_w = ctx->GetInputDim("W");
auto shape_ids = ctx->GetInputDim("Ids"); auto shape_ids = ctx->GetInputDim("Ids");
shape_w[0] = shape_ids.size(); shape_w[0] = shape_ids.size();
...@@ -47,12 +46,15 @@ class LookupSparseTableOp : public framework::OperatorBase { ...@@ -47,12 +46,15 @@ class LookupSparseTableOp : public framework::OperatorBase {
auto ids_var = scope.FindVar(Input("Ids")); auto ids_var = scope.FindVar(Input("Ids"));
auto is_test = Attr<bool>("is_test"); auto is_test = Attr<bool>("is_test");
PADDLE_ENFORCE(out_var->IsType<framework::LoDTensor>(), PADDLE_ENFORCE_EQ(out_var->IsType<framework::LoDTensor>(), true,
"The type of Out var should be LodTensor."); platform::errors::InvalidArgument(
PADDLE_ENFORCE(w_var->IsType<framework::SelectedRows>(), "The type of Out var should be LodTensor."));
"The type of W var should be SelectedRows."); PADDLE_ENFORCE_EQ(w_var->IsType<framework::SelectedRows>(), true,
PADDLE_ENFORCE(ids_var->IsType<framework::LoDTensor>(), platform::errors::InvalidArgument(
"The type of Ids var should be LoDTensor."); "The type of W var should be SelectedRows."));
PADDLE_ENFORCE_EQ(ids_var->IsType<framework::LoDTensor>(), true,
platform::errors::InvalidArgument(
"The type of Ids var should be LoDTensor."));
auto &ids_t = ids_var->Get<framework::LoDTensor>(); auto &ids_t = ids_var->Get<framework::LoDTensor>();
auto out_t = out_var->GetMutable<framework::LoDTensor>(); auto out_t = out_var->GetMutable<framework::LoDTensor>();
auto w_t = w_var->GetMutable<framework::SelectedRows>(); auto w_t = w_var->GetMutable<framework::SelectedRows>();
...@@ -64,7 +66,8 @@ class LookupSparseTableOp : public framework::OperatorBase { ...@@ -64,7 +66,8 @@ class LookupSparseTableOp : public framework::OperatorBase {
out_t->Resize(out_shape); out_t->Resize(out_shape);
out_t->mutable_data(cpu, w_t->value().type()); out_t->mutable_data(cpu, w_t->value().type());
PADDLE_ENFORCE_EQ(w_t->value().type(), framework::proto::VarType::FP32, PADDLE_ENFORCE_EQ(w_t->value().type(), framework::proto::VarType::FP32,
"The sparse table only support FP32"); platform::errors::InvalidArgument(
"The sparse table only support FP32"));
w_t->Get(ids_t, out_t, true, is_test); w_t->Get(ids_t, out_t, true, is_test);
out_t->set_lod(ids_t.lod()); out_t->set_lod(ids_t.lod());
} }
......
...@@ -27,12 +27,9 @@ class LookupTableOp : public framework::OperatorWithKernel { ...@@ -27,12 +27,9 @@ class LookupTableOp : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel; using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override { void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE_EQ(ctx->HasInput("W"), true, OP_INOUT_CHECK(ctx->HasInput("W"), "Input", "W", "LookupTable");
"Input(W) of LookupTableOp should not be null."); OP_INOUT_CHECK(ctx->HasInput("Ids"), "Input", "Ids", "LookupTable");
PADDLE_ENFORCE_EQ(ctx->HasInput("Ids"), true, OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "LookupTable");
"Input(Ids) of LookupTableOp should not be null.");
PADDLE_ENFORCE_EQ(ctx->HasOutput("Out"), true,
"Output(Out) of LookupTableOp should not be null.");
auto table_dims = ctx->GetInputDim("W"); auto table_dims = ctx->GetInputDim("W");
auto ids_dims = ctx->GetInputDim("Ids"); auto ids_dims = ctx->GetInputDim("Ids");
...@@ -40,15 +37,17 @@ class LookupTableOp : public framework::OperatorWithKernel { ...@@ -40,15 +37,17 @@ class LookupTableOp : public framework::OperatorWithKernel {
VLOG(5) << "ids rank is " << ids_rank << std::endl; VLOG(5) << "ids rank is " << ids_rank << std::endl;
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
table_dims.size(), 2, table_dims.size(), 2,
platform::errors::InvalidArgument(
"ShapeError: The dimensions of the 'lookup table' must be 2. " "ShapeError: The dimensions of the 'lookup table' must be 2. "
"But received lookup table's dimensions = %d, " "But received lookup table's dimensions = %d, "
"lookup table's shape = [%s].", "lookup table's shape = [%s].",
table_dims.size(), table_dims); table_dims.size(), table_dims));
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
ids_dims[ids_rank - 1], 1, ids_dims[ids_rank - 1], 1,
platform::errors::InvalidArgument(
"ShapeError: The last dimensions of the 'Ids' tensor must be 1. " "ShapeError: The last dimensions of the 'Ids' tensor must be 1. "
"But received Ids's last dimensions = %d, Ids's shape = [%s].", "But received Ids's last dimensions = %d, Ids's shape = [%s].",
ids_dims[ids_rank - 1], ids_dims); ids_dims[ids_rank - 1], ids_dims));
auto output_dims = auto output_dims =
framework::vectorize(framework::slice_ddim(ids_dims, 0, ids_rank - 1)); framework::vectorize(framework::slice_ddim(ids_dims, 0, ids_rank - 1));
......
...@@ -88,16 +88,18 @@ class LookupTableKernel : public framework::OpKernel<T> { ...@@ -88,16 +88,18 @@ class LookupTableKernel : public framework::OpKernel<T> {
} else { } else {
PADDLE_ENFORCE_LT( PADDLE_ENFORCE_LT(
ids[i], row_number, ids[i], row_number,
platform::errors::InvalidArgument(
"Variable value (input) of OP(fluid.layers.embedding) " "Variable value (input) of OP(fluid.layers.embedding) "
"expected >= 0 and < %ld, but got %ld. Please check input " "expected >= 0 and < %ld, but got %ld. Please check input "
"value.", "value.",
row_number, ids[i]); row_number, ids[i]));
PADDLE_ENFORCE_GE( PADDLE_ENFORCE_GE(
ids[i], 0, ids[i], 0,
platform::errors::InvalidArgument(
"Variable value (input) of OP(fluid.layers.embedding) " "Variable value (input) of OP(fluid.layers.embedding) "
"expected >= 0 and < %ld, but got %ld. Please check input " "expected >= 0 and < %ld, but got %ld. Please check input "
"value.", "value.",
row_number, ids[i]); row_number, ids[i]));
memcpy(output + i * row_width, table + ids[i] * row_width, memcpy(output + i * row_width, table + ids[i] * row_width,
row_width * sizeof(T)); row_width * sizeof(T));
} }
...@@ -114,13 +116,16 @@ class LookupTableKernel : public framework::OpKernel<T> { ...@@ -114,13 +116,16 @@ class LookupTableKernel : public framework::OpKernel<T> {
} else { } else {
PADDLE_ENFORCE_GE( PADDLE_ENFORCE_GE(
ids[i], 0, ids[i], 0,
platform::errors::InvalidArgument(
"Variable value (input) of OP(fluid.layers.embedding) " "Variable value (input) of OP(fluid.layers.embedding) "
"expected >= 0. But received %ld", "expected >= 0. But received %ld",
ids[i]); ids[i]));
auto id_index = table_t.Index(ids[i]); auto id_index = table_t.Index(ids[i]);
PADDLE_ENFORCE_GE( PADDLE_ENFORCE_GE(
id_index, 0, "the input key should be exists. But received %d.", id_index, 0,
id_index); platform::errors::InvalidArgument(
"the input key should be exists. But received %d.",
id_index));
if (input_data_type == framework::proto::VarType::INT8) { if (input_data_type == framework::proto::VarType::INT8) {
memcpy(output + i * row_width, table + id_index * row_width, memcpy(output + i * row_width, table + id_index * row_width,
row_width * sizeof(T)); row_width * sizeof(T));
...@@ -194,11 +199,12 @@ class LookupTableGradKernel : public framework::OpKernel<T> { ...@@ -194,11 +199,12 @@ class LookupTableGradKernel : public framework::OpKernel<T> {
auto d_output_dims_2d = auto d_output_dims_2d =
framework::flatten_to_2d(d_output_dims, d_output_dims.size() - 1); framework::flatten_to_2d(d_output_dims, d_output_dims.size() - 1);
PADDLE_ENFORCE_EQ(d_table_value->dims(), d_output_dims_2d, PADDLE_ENFORCE_EQ(d_table_value->dims(), d_output_dims_2d,
platform::errors::InvalidArgument(
"ShapeError: The shape of lookup_table@Grad and " "ShapeError: The shape of lookup_table@Grad and "
"output@Grad should be same. " "output@Grad should be same. "
"But received lookup_table@Grad's shape = [%s], " "But received lookup_table@Grad's shape = [%s], "
"output@Grad's shape = [%s].", "output@Grad's shape = [%s].",
d_table_value->dims(), d_output_dims_2d); d_table_value->dims(), d_output_dims_2d));
memcpy(d_table_data, d_output_data, sizeof(T) * d_output->numel()); memcpy(d_table_data, d_output_data, sizeof(T) * d_output->numel());
} }
} else { } else {
...@@ -223,14 +229,18 @@ class LookupTableGradKernel : public framework::OpKernel<T> { ...@@ -223,14 +229,18 @@ class LookupTableGradKernel : public framework::OpKernel<T> {
} else { } else {
PADDLE_ENFORCE_LT( PADDLE_ENFORCE_LT(
ids_data[i], N, ids_data[i], N,
platform::errors::InvalidArgument(
"Variable value (input) of OP(fluid.layers.embedding) " "Variable value (input) of OP(fluid.layers.embedding) "
"expected >= 0 and < %ld, but got %ld. Please check input value.", "expected >= 0 and < %ld, but got %ld. Please check input "
N, ids_data[i]); "value.",
N, ids_data[i]));
PADDLE_ENFORCE_GE( PADDLE_ENFORCE_GE(
ids_data[i], 0, ids_data[i], 0,
platform::errors::InvalidArgument(
"Variable value (input) of OP(fluid.layers.embedding) " "Variable value (input) of OP(fluid.layers.embedding) "
"expected >= 0 and < %ld, but got %ld. Please check input value.", "expected >= 0 and < %ld, but got %ld. Please check input"
N, ids_data[i]); "value.",
N, ids_data[i]));
for (int j = 0; j < D; ++j) { for (int j = 0; j < D; ++j) {
d_table_data[ids_data[i] * D + j] += d_output_data[i * D + j]; d_table_data[ids_data[i] * D + j] += d_output_data[i * D + j];
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册