Unverified commit 008857be, authored by ShenLiang, committed by GitHub

fix error message for scatter and scatter_nd (#24514)

Parent 14376486
@@ -95,11 +95,17 @@ void GPUScatterAssign(const framework::ExecutionContext& context,
   const auto& ctx = context.device_context();
   if (index.dims().size() == 2) {
     PADDLE_ENFORCE_EQ(index.dims()[1], 1,
-                      "index.dims()[1] should be 1 when index.dims().size() == "
-                      "2 in scatter_op.");
+                      platform::errors::InvalidArgument(
+                          "index.dims()[1] should be 1 when "
+                          "index.dims().size() == 2 in scatter_op. "
+                          "But received value is [%d]",
+                          index.dims()[1]));
   } else {
     PADDLE_ENFORCE_EQ(index.dims().size(), 1,
-                      "index.dims().size() should be 1 or 2 in scatter_op.");
+                      platform::errors::InvalidArgument(
+                          "index.dims().size() should be 1 or 2 in scatter_op. "
+                          "But received value is [%d]",
+                          index.dims().size()));
   }
   int index_size = index.dims()[0];
...
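The pattern this commit applies throughout is the same: replace a bare message string with a typed `platform::errors::*` object so the error category and the offending value are reported together. Below is a minimal standalone sketch of that pattern; `errors::InvalidArgument` and `ENFORCE_EQ` here are simplified stand-ins, not Paddle's actual error classes or macro.

#include <cstdio>
#include <stdexcept>
#include <string>

namespace errors {
// Hypothetical stand-in for platform::errors::InvalidArgument: formats the
// message and tags it with an error category prefix.
inline std::string InvalidArgument(const char* fmt, long long value) {
  char buf[256];
  std::snprintf(buf, sizeof(buf), fmt, value);
  return std::string("InvalidArgumentError: ") + buf;
}
}  // namespace errors

// Simplified stand-in for PADDLE_ENFORCE_EQ: check, else throw the
// pre-formatted, typed message.
#define ENFORCE_EQ(a, b, msg)                      \
  do {                                             \
    if ((a) != (b)) throw std::runtime_error(msg); \
  } while (0)

int main() {
  long long index_dim1 = 3;  // pretend index.dims()[1] == 3
  try {
    ENFORCE_EQ(index_dim1, 1,
               errors::InvalidArgument(
                   "index.dims()[1] should be 1 when index.dims().size() == 2 "
                   "in scatter_op. But received value is [%lld]",
                   index_dim1));
  } catch (const std::exception& e) {
    std::printf("%s\n", e.what());  // typed prefix + formatted value
  }
  return 0;
}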
@@ -73,15 +73,23 @@ elementwise_inner_add(const framework::ExecutionContext& ctx,
 template <typename T, typename IndexT = int>
 void ScatterAssign(const platform::DeviceContext& ctx, const Tensor& src,
                    const Tensor& index, Tensor* output) {
-  PADDLE_ENFORCE_EQ(platform::is_cpu_place(ctx.GetPlace()), true);
+  PADDLE_ENFORCE_EQ(
+      platform::is_cpu_place(ctx.GetPlace()), true,
+      platform::errors::PreconditionNotMet("This kernel only runs on CPU."));
   // check index of shape 1-D
   if (index.dims().size() == 2) {
     PADDLE_ENFORCE_EQ(index.dims()[1], 1,
-                      "index.dims()[1] should be 1 when index.dims().size() == "
-                      "2 in scatter_op.");
+                      platform::errors::InvalidArgument(
+                          "index.dims()[1] should be 1 when "
+                          "index.dims().size() == 2 in scatter_op. "
+                          "But received value is [%d]",
+                          index.dims()[1]));
   } else {
     PADDLE_ENFORCE_EQ(index.dims().size(), 1,
-                      "index.dims().size() should be 1 or 2 in scatter_op.");
+                      platform::errors::InvalidArgument(
+                          "index.dims().size() should be 1 or 2 in scatter_op. "
+                          "But received value is [%d]",
+                          index.dims().size()));
   }
   int index_size = index.dims()[0];
@@ -94,7 +102,9 @@ void ScatterAssign(const platform::DeviceContext& ctx, const Tensor& src,
   // check src shape and dst shape should match
   for (int i = 1; i < src_dims.size(); i++)
-    PADDLE_ENFORCE_EQ(src_dims[i], dst_dims[i]);
+    PADDLE_ENFORCE_EQ(src_dims[i], dst_dims[i],
+                      platform::errors::InvalidArgument(
+                          "src shape and dst shape should match"));
   // slice size
   size_t slice_size = 1;
@@ -111,12 +121,14 @@ void ScatterAssign(const platform::DeviceContext& ctx, const Tensor& src,
 template <typename T, typename IndexT = int>
 void ScatterAssignAdd(const framework::ExecutionContext& ctx, const Tensor& src,
                       const Tensor& index, Tensor* output) {
-  PADDLE_ENFORCE_EQ(platform::is_cpu_place(ctx.device_context().GetPlace()),
-                    true);
+  PADDLE_ENFORCE_EQ(
+      platform::is_cpu_place(ctx.device_context().GetPlace()), true,
+      platform::errors::PreconditionNotMet("This kernel only runs on CPU."));
   // check index of shape 1-D
-  PADDLE_ENFORCE(index.dims().size() == 1 ||
-                     (index.dims().size() == 2 && index.dims()[1] == 1),
-                 "");
+  PADDLE_ENFORCE_EQ(
+      index.dims().size() == 1 ||
+          (index.dims().size() == 2 && index.dims()[1] == 1),
+      true, platform::errors::InvalidArgument("index's shape is invalid."));
   int index_size = index.dims()[0];
   auto src_dims = src.dims();
@@ -130,7 +142,9 @@ void ScatterAssignAdd(const framework::ExecutionContext& ctx, const Tensor& src,
   // check src shape and dst shape should match
   for (int i = 1; i < src_dims.size(); i++)
-    PADDLE_ENFORCE_EQ(src_dims[i], dst_dims[i]);
+    PADDLE_ENFORCE_EQ(src_dims[i], dst_dims[i],
+                      platform::errors::InvalidArgument(
+                          "src shape and dst shape should match"));
   // slice size
   size_t slice_size = 1;
@@ -156,8 +170,9 @@ void ScatterAssignAdd(const framework::ExecutionContext& ctx, const Tensor& src,
 template <typename T, typename IndexT = int>
 void ScatterNdAdd(const framework::ExecutionContext& ctx, const Tensor& update,
                   const Tensor& index, Tensor* output) {
-  PADDLE_ENFORCE_EQ(platform::is_cpu_place(ctx.device_context().GetPlace()),
-                    true, "It should be running on the CPU");
+  PADDLE_ENFORCE_EQ(
+      platform::is_cpu_place(ctx.device_context().GetPlace()), true,
+      platform::errors::PreconditionNotMet("It should be running on the CPU."));
   // update.shape = index.shape[:-1] + output.shape[index.shape[-1]:]
   auto index_dims = index.dims();
...
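For reference, the row-wise semantics that the src/dst shape check above protects: `src` is `[index_size, slice]`, `output` is `[rows, slice]`, and every trailing dimension must agree so a whole slice can be copied per id. A minimal sketch with plain `std::vector` (not the Paddle Tensor API):

#include <cstring>
#include <vector>

// Overwrite-mode scatter: row index[i] of output is replaced by row i of
// src. Trailing dims are flattened into slice_size, which is why
// ScatterAssign only needs to compare dims 1..n-1 of src and dst.
void ScatterAssignRows(const std::vector<float>& src,
                       const std::vector<int>& index,
                       size_t slice_size, std::vector<float>* output) {
  for (size_t i = 0; i < index.size(); ++i) {
    std::memcpy(output->data() + index[i] * slice_size,
                src.data() + i * slice_size, slice_size * sizeof(float));
  }
}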
@@ -26,13 +26,19 @@ class ScatterNdAddOp : public framework::OperatorWithKernel {
   void InferShape(framework::InferShapeContext* ctx) const override {
     PADDLE_ENFORCE_EQ(ctx->HasInput("X"), true,
-                      "Input(X) of ScatterNdAddOp should not be null.");
-    PADDLE_ENFORCE_EQ(ctx->HasInput("Index"), true,
-                      "Input(Index) of ScatterNdAddOp should not be null.");
-    PADDLE_ENFORCE_EQ(ctx->HasInput("Updates"), true,
-                      "Input(Updates) of ScatterNdAddOp should not be null.");
+                      platform::errors::InvalidArgument(
+                          "Input(X) of ScatterNdAddOp should not be null."));
+    PADDLE_ENFORCE_EQ(
+        ctx->HasInput("Index"), true,
+        platform::errors::InvalidArgument(
+            "Input(Index) of ScatterNdAddOp should not be null."));
+    PADDLE_ENFORCE_EQ(
+        ctx->HasInput("Updates"), true,
+        platform::errors::InvalidArgument(
+            "Input(Updates) of ScatterNdAddOp should not be null."));
     PADDLE_ENFORCE_EQ(ctx->HasOutput("Out"), true,
-                      "Output(Out) of ScatterNdAddOp should not be null.");
+                      platform::errors::InvalidArgument(
+                          "Output(Out) of ScatterNdAddOp should not be null."));
     auto ref_dims = ctx->GetInputDim("X");
     auto ref_dims_size = ref_dims.size();
@@ -43,9 +49,11 @@ class ScatterNdAddOp : public framework::OperatorWithKernel {
     PADDLE_ENFORCE_LE(
         index_dims[index_dims_size - 1], ref_dims_size,
-        "Input(Index).shape[-1] should be no greater than Input(X).rank");
+        platform::errors::InvalidArgument(
+            "Input(Index).shape[-1] should be no greater than Input(X).rank"));
     PADDLE_ENFORCE_GE(index_dims_size, 2UL,
-                      "The rank of Input(Index) should be greater than 1");
+                      platform::errors::InvalidArgument(
+                          "The rank of Input(Index) should be greater than 1"));
     // update.shape = index.shape[:-1] + output.shape[index.shape[-1]:]
     std::vector<int64_t> r_updates_dims;
@@ -56,12 +64,14 @@ class ScatterNdAddOp : public framework::OperatorWithKernel {
       r_updates_dims.emplace_back(ref_dims[i]);
     }
-    PADDLE_ENFORCE_EQ(r_updates_dims.size(), updates_dims_size,
-                      "Updates has wrong shape");
+    PADDLE_ENFORCE_EQ(
+        r_updates_dims.size(), updates_dims_size,
+        platform::errors::InvalidArgument("Updates has wrong shape"));
     for (int64_t i = 0; i < updates_dims_size; ++i) {
-      PADDLE_ENFORCE_EQ(r_updates_dims[i], updates_dims[i],
-                        "Updates has wrong shape");
+      PADDLE_ENFORCE_EQ(
+          r_updates_dims[i], updates_dims[i],
+          platform::errors::InvalidArgument("Updates has wrong shape"));
     }
     ctx->SetOutputDim("Out", ref_dims);
     ctx->ShareLoD("X", /*->*/ "Out");
@@ -72,7 +82,8 @@ class ScatterNdAddOp : public framework::OperatorWithKernel {
       const framework::ExecutionContext& ctx) const override {
     PADDLE_ENFORCE_EQ(OperatorWithKernel::IndicateVarDataType(ctx, "X"),
                       OperatorWithKernel::IndicateVarDataType(ctx, "Updates"),
-                      "Ref and Updates must have same type");
+                      platform::errors::InvalidArgument(
+                          "Ref and Updates must have same type"));
     return framework::OpKernelType(ctx.Input<Tensor>("X")->type(),
                                    ctx.device_context());
   }
...
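The shape rule in the comment above, `update.shape = index.shape[:-1] + output.shape[index.shape[-1]:]`, is what `r_updates_dims` encodes. A standalone sketch of the computation, assuming plain `int64_t` dim vectors rather than Paddle's `DDim`:

#include <cstdint>
#include <vector>

std::vector<int64_t> ExpectedUpdatesDims(
    const std::vector<int64_t>& x_dims,
    const std::vector<int64_t>& index_dims) {
  std::vector<int64_t> r;
  // index.shape[:-1] -- every index dim except the innermost one.
  for (size_t i = 0; i + 1 < index_dims.size(); ++i) r.push_back(index_dims[i]);
  // output.shape[index.shape[-1]:] -- the X dims the index does not address.
  for (size_t i = static_cast<size_t>(index_dims.back()); i < x_dims.size(); ++i)
    r.push_back(x_dims[i]);
  return r;
}

// e.g. X: [6, 9, 10], Index: [3, 2]  ->  Updates must be [3, 10].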
@@ -25,7 +25,8 @@ class ScatterNdAddOpCUDAKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext &ctx) const override {
     PADDLE_ENFORCE_EQ(platform::is_gpu_place(ctx.GetPlace()), true,
-                      "This kernel only runs on GPU device.");
+                      platform::errors::PreconditionNotMet(
+                          "This kernel only runs on GPU device."));
     auto *X = ctx.Input<Tensor>("X");
     auto *Ids = ctx.Input<Tensor>("Index");
     auto *Updates = ctx.Input<Tensor>("Updates");
@@ -35,12 +36,15 @@ class ScatterNdAddOpCUDAKernel : public framework::OpKernel<T> {
     const auto &index_type = Ids->type();
     bool index_type_match = index_type == framework::proto::VarType::INT32 ||
                             index_type == framework::proto::VarType::INT64;
-    PADDLE_ENFORCE_EQ(
-        index_type_match, true,
-        "Index holds the wrong type, it holds %s, but desires to be %s or %s",
-        paddle::framework::DataTypeToString(index_type),
-        paddle::framework::DataTypeToString(framework::proto::VarType::INT32),
-        paddle::framework::DataTypeToString(framework::proto::VarType::INT64));
+    PADDLE_ENFORCE_EQ(index_type_match, true,
+                      platform::errors::InvalidArgument(
+                          "Index holds the wrong type, it holds [%s], but "
+                          "desires to be [%s] or [%s].",
+                          paddle::framework::DataTypeToString(index_type),
+                          paddle::framework::DataTypeToString(
+                              framework::proto::VarType::INT32),
+                          paddle::framework::DataTypeToString(
+                              framework::proto::VarType::INT64)));
     if (index_type == framework::proto::VarType::INT32) {
       GPUScatterNdAdd<DeviceContext, T, int32_t>(ctx, *Updates, *Ids, Out);
     } else {
@@ -54,7 +58,8 @@ class ScatterNdAddGradOpCUDAKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext &ctx) const override {
     PADDLE_ENFORCE_EQ(platform::is_gpu_place(ctx.GetPlace()), true,
-                      "This kernel only runs on GPU device.");
+                      platform::errors::PreconditionNotMet(
+                          "This kernel only runs on GPU device."));
     auto *dX = ctx.Output<Tensor>(framework::GradVarName("X"));
     auto *dUpdates = ctx.Output<Tensor>(framework::GradVarName("Updates"));
     auto *Ids = ctx.Input<Tensor>("Index");
...
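The INT32/INT64 branch above is a runtime-dtype-to-template dispatch: validate the dtype first, then pick the matching instantiation. A compact standalone sketch of the idiom, with the tensor plumbing elided:

#include <cstdint>
#include <stdexcept>

enum class IndexType { INT32, INT64, OTHER };

template <typename IndexT>
void ScatterNdAddImpl() { /* kernel body elided */ }

void Dispatch(IndexType t) {
  // Validate first, so an unsupported dtype fails loudly instead of
  // silently falling into the "else" instantiation.
  if (t != IndexType::INT32 && t != IndexType::INT64)
    throw std::invalid_argument("Index must be INT32 or INT64.");
  if (t == IndexType::INT32)
    ScatterNdAddImpl<int32_t>();
  else
    ScatterNdAddImpl<int64_t>();
}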
@@ -27,8 +27,9 @@ template <typename T>
 class ScatterNdAddOpKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext &ctx) const override {
-    PADDLE_ENFORCE_EQ(platform::is_cpu_place(ctx.GetPlace()), true,
-                      "This kernel only runs on CPU.");
+    PADDLE_ENFORCE_EQ(
+        platform::is_cpu_place(ctx.GetPlace()), true,
+        platform::errors::PreconditionNotMet("This kernel only runs on CPU."));
     auto *X = ctx.Input<Tensor>("X");
     auto *Ids = ctx.Input<Tensor>("Index");
     auto *Updates = ctx.Input<Tensor>("Updates");
@@ -39,12 +40,15 @@ class ScatterNdAddOpKernel : public framework::OpKernel<T> {
     const auto &index_type = Ids->type();
     bool index_type_match = index_type == framework::proto::VarType::INT32 ||
                             index_type == framework::proto::VarType::INT64;
-    PADDLE_ENFORCE_EQ(
-        index_type_match, true,
-        "Index holds the wrong type, it holds %s, but desires to be %s or %s",
-        paddle::framework::DataTypeToString(index_type),
-        paddle::framework::DataTypeToString(framework::proto::VarType::INT32),
-        paddle::framework::DataTypeToString(framework::proto::VarType::INT64));
+    PADDLE_ENFORCE_EQ(index_type_match, true,
+                      platform::errors::InvalidArgument(
+                          "Index holds the wrong type, it holds [%s], but "
+                          "desires to be [%s] or [%s].",
+                          paddle::framework::DataTypeToString(index_type),
+                          paddle::framework::DataTypeToString(
+                              framework::proto::VarType::INT32),
+                          paddle::framework::DataTypeToString(
+                              framework::proto::VarType::INT64)));
     if (index_type == framework::proto::VarType::INT32) {
       ScatterNdAdd<T, int32_t>(ctx, *Updates, *Ids, Out);
@@ -58,8 +62,9 @@ template <typename T>
 class ScatterNdAddGradientOpKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext &ctx) const override {
-    PADDLE_ENFORCE_EQ(platform::is_cpu_place(ctx.GetPlace()), true,
-                      "This kernel only runs on CPU.");
+    PADDLE_ENFORCE_EQ(
+        platform::is_cpu_place(ctx.GetPlace()), true,
+        platform::errors::PreconditionNotMet("This kernel only runs on CPU."));
     auto *dX = ctx.Output<Tensor>(framework::GradVarName("X"));
     auto *dUpdates = ctx.Output<Tensor>(framework::GradVarName("Updates"));
     auto *Ids = ctx.Input<Tensor>("Index");
...
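The gradient kernels touched above rely on the derivative of an additive scatter: `dX` is `dOut` passed through unchanged, and `dUpdates[i] = dOut[Ids[i]]`, i.e. a gather, because each update row contributed additively to exactly the row it indexed. A plain-vector sketch of that gather step (row-sized slices, Paddle's gather helpers elided):

#include <vector>

// dUpdates[i] = dOut[ids[i]] -- the "gradient by gather" step.
void GatherRows(const std::vector<float>& d_out, const std::vector<int>& ids,
                size_t slice_size, std::vector<float>* d_updates) {
  for (size_t i = 0; i < ids.size(); ++i)
    for (size_t j = 0; j < slice_size; ++j)
      (*d_updates)[i * slice_size + j] = d_out[ids[i] * slice_size + j];
}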
@@ -24,24 +24,32 @@ class ScatterOp : public framework::OperatorWithKernel {
   using framework::OperatorWithKernel::OperatorWithKernel;
   void InferShape(framework::InferShapeContext* ctx) const override {
-    PADDLE_ENFORCE(ctx->HasInput("X"),
-                   "Input(X) of ScatterOp should not be null.");
-    PADDLE_ENFORCE(ctx->HasInput("Ids"),
-                   "Input(Ids) of ScatterOp should not be null.");
-    PADDLE_ENFORCE(ctx->HasInput("Updates"),
-                   "Input(Updates) of ScatterOp should not be null.");
-    PADDLE_ENFORCE(ctx->HasOutput("Out"),
-                   "Output(Out) of ScatterOp should not be null.");
+    PADDLE_ENFORCE_EQ(ctx->HasInput("X"), true,
+                      platform::errors::InvalidArgument(
+                          "Input(X) of ScatterOp should not be null."));
+    PADDLE_ENFORCE_EQ(ctx->HasInput("Ids"), true,
+                      platform::errors::InvalidArgument(
+                          "Input(Ids) of ScatterOp should not be null."));
+    PADDLE_ENFORCE_EQ(ctx->HasInput("Updates"), true,
+                      platform::errors::InvalidArgument(
+                          "Input(Updates) of ScatterOp should not be null."));
+    PADDLE_ENFORCE_EQ(ctx->HasOutput("Out"), true,
+                      platform::errors::InvalidArgument(
+                          "Output(Out) of ScatterOp should not be null."));
     auto updates_dims = ctx->GetInputDim("Updates");
     auto ref_dims = ctx->GetInputDim("X");
-    PADDLE_ENFORCE_EQ(ctx->GetInputDim("Ids").size(), 1,
-                      "Update Ids should be 1-D.");
-    PADDLE_ENFORCE_EQ(ref_dims.size(), updates_dims.size(),
-                      "Xerence and Updates should have the same shape size");
+    PADDLE_ENFORCE_EQ(
+        ctx->GetInputDim("Ids").size(), 1,
+        platform::errors::InvalidArgument("Update Ids should be 1-D."));
+    PADDLE_ENFORCE_EQ(
+        ref_dims.size(), updates_dims.size(),
+        platform::errors::InvalidArgument(
+            "Reference and Updates should have the same shape size."));
     PADDLE_ENFORCE_EQ(ctx->GetInputDim("Updates")[0],
                       ctx->GetInputDim("Ids")[0],
-                      "Updates and Ids should have same batch-size.");
+                      platform::errors::InvalidArgument(
+                          "Updates and Ids should have same batch-size."));
     ctx->SetOutputDim("Out", ref_dims);
   }
...
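A quick worked example of the contract `ScatterOp::InferShape` now reports on: `Ids` must be 1-D, `Updates` must have `X`'s rank, and the two batch dims must agree. The shapes below are hypothetical, checked with plain asserts rather than Paddle's enforce macros:

#include <cassert>
#include <cstdint>
#include <vector>

int main() {
  std::vector<int64_t> x_dims{3, 4}, ids_dims{2}, updates_dims{2, 4};
  assert(ids_dims.size() == 1);                  // "Update Ids should be 1-D."
  assert(x_dims.size() == updates_dims.size());  // same rank as X
  assert(updates_dims[0] == ids_dims[0]);        // same batch-size as Ids
  // Out inherits X's shape: [3, 4].
  return 0;
}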
@@ -24,8 +24,9 @@ template <typename T>
 class ScatterOpCUDAKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext &ctx) const override {
-    PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()),
-                   "This kernel only runs on GPU device.");
+    PADDLE_ENFORCE_EQ(platform::is_gpu_place(ctx.GetPlace()), true,
+                      platform::errors::PreconditionNotMet(
+                          "This kernel only runs on GPU device."));
     auto *X = ctx.Input<Tensor>("X");
     auto *Ids = ctx.Input<Tensor>("Ids");
     auto *Updates = ctx.Input<Tensor>("Updates");
@@ -39,11 +40,14 @@ class ScatterOpCUDAKernel : public framework::OpKernel<T> {
                             index_type == framework::proto::VarType::INT64;
     PADDLE_ENFORCE_EQ(
         index_type_match, true,
-        "scatter_op Index holds the wrong type, it holds %s, but desires to be "
-        "%s or %s",
-        paddle::framework::DataTypeToString(index_type),
-        paddle::framework::DataTypeToString(framework::proto::VarType::INT32),
-        paddle::framework::DataTypeToString(framework::proto::VarType::INT64));
+        platform::errors::InvalidArgument(
+            "scatter_op Index holds the wrong type, it holds [%s], "
+            "but desires to be [%s] or [%s].",
+            paddle::framework::DataTypeToString(index_type),
+            paddle::framework::DataTypeToString(
+                framework::proto::VarType::INT32),
+            paddle::framework::DataTypeToString(
+                framework::proto::VarType::INT64)));
     if (index_type == framework::proto::VarType::INT32) {
       GPUScatterAssign<T, int32_t>(ctx, *Updates, *Ids, Out, overwrite);
     } else {
@@ -56,8 +60,9 @@ template <typename T>
 class ScatterGradOpCUDAKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext &ctx) const override {
-    PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()),
-                   "This kernel only runs on GPU device.");
+    PADDLE_ENFORCE_EQ(platform::is_gpu_place(ctx.GetPlace()), true,
+                      platform::errors::PreconditionNotMet(
+                          "This kernel only runs on GPU device."));
     auto *dX = ctx.Output<Tensor>(framework::GradVarName("X"));
     auto *dUpdates = ctx.Output<Tensor>(framework::GradVarName("Updates"));
     auto *Ids = ctx.Input<Tensor>("Ids");
@@ -74,12 +79,14 @@ class ScatterGradOpCUDAKernel : public framework::OpKernel<T> {
                             index_type == framework::proto::VarType::INT64;
     PADDLE_ENFORCE_EQ(
         index_type_match, true,
-        "scatter_op Index holds the wrong type, it holds %s, but desires to "
-        "be %s or %s",
-        paddle::framework::DataTypeToString(index_type),
-        paddle::framework::DataTypeToString(framework::proto::VarType::INT32),
-        paddle::framework::DataTypeToString(
-            framework::proto::VarType::INT64));
+        platform::errors::InvalidArgument(
+            "scatter_op Index holds the wrong type, it holds [%s], "
+            "but desires to be [%s] or [%s].",
+            paddle::framework::DataTypeToString(index_type),
+            paddle::framework::DataTypeToString(
+                framework::proto::VarType::INT32),
+            paddle::framework::DataTypeToString(
+                framework::proto::VarType::INT64)));
     // Gradient by Gather: dUpdates = dO[Ids]
     if (index_type == framework::proto::VarType::INT32) {
       GPUGather<T, int32_t>(ctx.device_context(), *dOut, *Ids, dUpdates);
...
@@ -27,8 +27,9 @@ template <typename T>
 class ScatterOpKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext &ctx) const override {
-    PADDLE_ENFORCE(platform::is_cpu_place(ctx.GetPlace()),
-                   "This kernel only runs on CPU.");
+    PADDLE_ENFORCE_EQ(
+        platform::is_cpu_place(ctx.GetPlace()), true,
+        platform::errors::PreconditionNotMet("This kernel only runs on CPU."));
     auto *X = ctx.Input<Tensor>("X");
     auto *Ids = ctx.Input<Tensor>("Ids");
     auto *Updates = ctx.Input<Tensor>("Updates");
@@ -41,12 +42,15 @@ class ScatterOpKernel : public framework::OpKernel<T> {
     const auto &index_type = Ids->type();
     bool index_type_match = index_type == framework::proto::VarType::INT32 ||
                             index_type == framework::proto::VarType::INT64;
-    PADDLE_ENFORCE(
-        index_type_match,
-        "Index holds the wrong type, it holds %s, but desires to be %s or %s",
-        paddle::framework::DataTypeToString(index_type),
-        paddle::framework::DataTypeToString(framework::proto::VarType::INT32),
-        paddle::framework::DataTypeToString(framework::proto::VarType::INT64));
+    PADDLE_ENFORCE_EQ(index_type_match, true,
+                      platform::errors::InvalidArgument(
+                          "Index holds the wrong type, it holds [%s], "
+                          "but desires to be [%s] or [%s].",
+                          paddle::framework::DataTypeToString(index_type),
+                          paddle::framework::DataTypeToString(
+                              framework::proto::VarType::INT32),
+                          paddle::framework::DataTypeToString(
+                              framework::proto::VarType::INT64)));
     if (overwrite) {
       if (index_type == framework::proto::VarType::INT32) {
         ScatterAssign<T, int32_t>(ctx.device_context(), *Updates, *Ids, Out);
@@ -67,8 +71,9 @@ template <typename T>
 class ScatterGradientOpKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext &ctx) const override {
-    PADDLE_ENFORCE(platform::is_cpu_place(ctx.GetPlace()),
-                   "This kernel only runs on CPU.");
+    PADDLE_ENFORCE_EQ(
+        platform::is_cpu_place(ctx.GetPlace()), true,
+        platform::errors::PreconditionNotMet("This kernel only runs on CPU."));
     auto *dX = ctx.Output<Tensor>(framework::GradVarName("X"));
     auto *dUpdates = ctx.Output<Tensor>(framework::GradVarName("Updates"));
     auto *Ids = ctx.Input<Tensor>("Ids");
@@ -86,12 +91,14 @@ class ScatterGradientOpKernel : public framework::OpKernel<T> {
                             index_type == framework::proto::VarType::INT64;
     PADDLE_ENFORCE_EQ(
         index_type_match, true,
-        "scatter_op index holds the wrong type, it holds %s, but desires to "
-        "be %s or %s",
-        paddle::framework::DataTypeToString(index_type),
-        paddle::framework::DataTypeToString(framework::proto::VarType::INT32),
-        paddle::framework::DataTypeToString(
-            framework::proto::VarType::INT64));
+        platform::errors::InvalidArgument(
+            "scatter_op index holds the wrong type, it holds [%s], "
+            "but desires to be [%s] or [%s].",
+            paddle::framework::DataTypeToString(index_type),
+            paddle::framework::DataTypeToString(
+                framework::proto::VarType::INT32),
+            paddle::framework::DataTypeToString(
+                framework::proto::VarType::INT64)));
     if (index_type == framework::proto::VarType::INT32) {
       CPUGather<T, int32_t>(ctx.device_context(), *dOut, *Ids, dUpdates);
     } else {
...
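Finally, the `overwrite` attribute the CPU kernel branches on above selects between assign and accumulate semantics. A simplified sketch of the distinction; it only models how duplicate ids behave and ignores how the real kernel pre-initializes the target rows in add mode, which is elided here:

#include <vector>

// overwrite == true : out[ids[i]] = updates[i]   (last write wins)
// overwrite == false: out[ids[i]] += updates[i]  (duplicate ids sum)
void ScatterRows(const std::vector<float>& updates, const std::vector<int>& ids,
                 size_t slice, bool overwrite, std::vector<float>* out) {
  for (size_t i = 0; i < ids.size(); ++i)
    for (size_t j = 0; j < slice; ++j) {
      float& dst = (*out)[ids[i] * slice + j];
      const float v = updates[i * slice + j];
      dst = overwrite ? v : dst + v;
    }
}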