Unverified · Commit 53e3c534 · authored by ShenLiang · committed by GitHub

fix error message, test=develop (#24425)

Parent ea2c4987
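Every change below follows one pattern: bare `PADDLE_ENFORCE(cond, msg)` checks are replaced by explicit comparison macros (`PADDLE_ENFORCE_EQ`, `PADDLE_ENFORCE_NOT_NULL`, `PADDLE_ENFORCE_CUDA_SUCCESS`) whose messages are wrapped in a `platform::errors` category such as `InvalidArgument` or `PreconditionNotMet`. A minimal sketch of the before/after shape, assembled from the hunks below rather than copied from any single file:

```cpp
// Before: untyped check with a bare string message.
PADDLE_ENFORCE(platform::is_gpu_place(place),
               "This kernel only runs on GPU device.");

// After: explicit boolean comparison plus a typed error category, so the
// failure carries a structured error type in addition to the text.
PADDLE_ENFORCE_EQ(platform::is_gpu_place(place), true,
                  platform::errors::PreconditionNotMet(
                      "This kernel only runs on GPU device."));
```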
@@ -33,8 +33,9 @@ class AllReduceOpKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext& ctx) const override {
     auto place = ctx.GetPlace();
-    PADDLE_ENFORCE(is_gpu_place(place),
-                   "AllReduce op can run on gpu place only for now.");
+    PADDLE_ENFORCE_EQ(is_gpu_place(place), true,
+                      platform::errors::PreconditionNotMet(
+                          "AllReduce op can run on gpu place only for now."));
 #if defined(PADDLE_WITH_NCCL)
     auto& dev_ctx = ctx.template device_context<platform::CUDADeviceContext>();
     auto in = ctx.Input<framework::Tensor>("X");
@@ -49,7 +50,8 @@ class AllReduceOpKernel : public framework::OpKernel<T> {
     auto* comm = dev_ctx.nccl_comm();
     // FIXME(typhoonzero): should use nccl stream here.
     auto stream = dev_ctx.stream();
-    PADDLE_ENFORCE_NOT_NULL(stream, "Should initialize NCCL firstly.");
+    PADDLE_ENFORCE_NOT_NULL(
+        stream, platform::errors::NotFound("Should initialize NCCL firstly."));
     int reduce_type = ctx.Attr<int>("reduce_type");
     ncclRedOp_t red_type = ncclSum;
@@ -67,7 +69,7 @@ class AllReduceOpKernel : public framework::OpKernel<T> {
         red_type = ncclMin;
         break;
     }
-    PADDLE_ENFORCE(platform::dynload::ncclAllReduce(
+    PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::ncclAllReduce(
         sendbuff, recvbuff, numel, static_cast<ncclDataType_t>(dtype), red_type,
         comm, stream));
     if (ctx.Attr<bool>("sync_mode")) {
...
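The `ncclAllReduce` call above is now checked with `PADDLE_ENFORCE_CUDA_SUCCESS`, which verifies that a CUDA/NCCL status code equals the success value and raises a structured external error otherwise. A roughly equivalent explicit form, as a sketch only (the real macro also decodes the status into a readable message):

```cpp
// Sketch: check the NCCL status code explicitly instead of via the macro.
// ncclSuccess is NCCL's "no error" status; any other value aborts with a
// structured External error.
ncclResult_t status = platform::dynload::ncclAllReduce(
    sendbuff, recvbuff, numel, static_cast<ncclDataType_t>(dtype), red_type,
    comm, stream);
PADDLE_ENFORCE_EQ(status, ncclSuccess,
                  platform::errors::External("ncclAllReduce failed."));
```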
@@ -26,10 +26,12 @@ class BroadcastOp : public framework::OperatorWithKernel {
   using framework::OperatorWithKernel::OperatorWithKernel;
   void InferShape(framework::InferShapeContext* ctx) const override {
-    PADDLE_ENFORCE(ctx->HasInput("X"),
-                   "Input(X) of BroadcastOp should not be null.");
-    PADDLE_ENFORCE(ctx->HasOutput("Out"),
-                   "Output(Output) of ConvOp should not be null.");
+    PADDLE_ENFORCE_EQ(ctx->HasInput("X"), true,
+                      platform::errors::InvalidArgument(
+                          "Input(X) of BroadcastOp should not be null."));
+    PADDLE_ENFORCE_EQ(ctx->HasOutput("Out"), true,
+                      platform::errors::InvalidArgument(
+                          "Output(Out) of BroadcastOp should not be null."));
   }
 };
...
@@ -34,8 +34,10 @@ template <typename T>
 class NCCLBroadcastOpKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext& ctx) const override {
-    PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()),
-                   "The place of ExecutionContext should be CUDAPlace.");
+    PADDLE_ENFORCE_EQ(
+        platform::is_gpu_place(ctx.GetPlace()), true,
+        platform::errors::PreconditionNotMet(
+            "The place of ExecutionContext should be CUDAPlace."));
 #if defined(PADDLE_WITH_NCCL)
     int dev_id = BOOST_GET_CONST(platform::CUDAPlace, ctx.GetPlace()).device;
@@ -43,19 +45,22 @@ class NCCLBroadcastOpKernel : public framework::OpKernel<T> {
     auto in = ctx.Input<framework::Tensor>("X");
     auto out = ctx.Output<framework::Tensor>("Out");
-    PADDLE_ENFORCE(out->IsInitialized(),
-                   "Currently, the output of broadcast op must be initialized, "
-                   "because this op can only be an In-Place operation.");
+    PADDLE_ENFORCE_EQ(
+        out->IsInitialized(), true,
+        platform::errors::PreconditionNotMet(
+            "Currently, the output of broadcast op must be initialized, "
+            "because this op can only be an In-Place operation."));
     void* send_recv_buffer = out->mutable_data<T>(ctx.GetPlace());
     PADDLE_ENFORCE_EQ(
         send_recv_buffer, in->data<void>(),
-        "Currently, the broadcast op can only be an In-Place operation.");
+        platform::errors::PreconditionNotMet("Currently, the broadcast op can "
+                                             "only be an In-Place operation."));
     auto& dev_ctx = ctx.template device_context<platform::CUDADeviceContext>();
     auto comm = dev_ctx.nccl_comm();
     auto stream = dev_ctx.stream();
-    PADDLE_ENFORCE(platform::dynload::ncclBcast(
+    PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::ncclBcast(
         send_recv_buffer, static_cast<size_t>(in->numel()),
         platform::ToNCCLDataType(in->type()), root_dev_id, comm, stream));
...
@@ -22,16 +22,20 @@ class EyeOp : public framework::OperatorWithKernel {
   using framework::OperatorWithKernel::OperatorWithKernel;
   void InferShape(framework::InferShapeContext* ctx) const override {
-    PADDLE_ENFORCE(ctx->HasOutput("Out"),
-                   "Output(Out) of EyeOP should not be null.");
+    PADDLE_ENFORCE_EQ(ctx->HasOutput("Out"), true,
+                      platform::errors::InvalidArgument(
+                          "Output(Out) of EyeOP should not be null."));
     auto num_rows = ctx->Attrs().Get<int64_t>("num_rows");
-    PADDLE_ENFORCE(num_rows >= 0,
-                   "The value of Input(num_rows) should be non-negative int.");
+    PADDLE_ENFORCE_EQ(
+        num_rows >= 0, true,
+        platform::errors::InvalidArgument(
+            "The value of Input(num_rows) should be non-negative int."));
     auto num_columns = ctx->Attrs().Get<int64_t>("num_columns");
     if (num_columns == -1) num_columns = num_rows;
-    PADDLE_ENFORCE(
-        num_columns >= 0,
-        "The value of Input(num_columns) should be non-negative int.");
+    PADDLE_ENFORCE_EQ(
+        num_columns >= 0, true,
+        platform::errors::InvalidArgument(
+            "The value of Input(num_columns) should be non-negative int."));
     ctx->SetOutputDim("Out", {num_rows, num_columns});
   }
...
@@ -78,12 +78,14 @@ void GPUGather(const platform::DeviceContext& ctx, const Tensor& src,
   // check index of shape 1-D
   if (index.dims().size() == 1) {
-    PADDLE_ENFORCE_GT(index.dims()[0], 0,
-                      "The index of gather_op should not be empty when the "
-                      "index's rank is 1.");
+    PADDLE_ENFORCE_GT(index.dims()[0], 0,
+                      platform::errors::InvalidArgument(
+                          "The index of gather_op should not be empty "
+                          "when the index's rank is 1."));
   } else if (index.dims().size() == 2) {
-    PADDLE_ENFORCE_EQ(index.dims()[1], 1,
-                      " If the index's rank of gather_op is 2, the second "
-                      "dimension should be 1.");
+    PADDLE_ENFORCE_EQ(index.dims()[1], 1,
+                      platform::errors::InvalidArgument(
+                          "If the index's rank of gather_op is 2,"
+                          " the second dimension should be 1."));
   }
   int index_size = index.dims()[0];
...
@@ -36,15 +36,23 @@ using framework::Tensor;
 template <typename T, typename IndexT = int>
 void CPUGather(const platform::DeviceContext& ctx, const Tensor& src,
                const Tensor& index, Tensor* output) {
-  PADDLE_ENFORCE_EQ(platform::is_cpu_place(ctx.GetPlace()), true);
+  PADDLE_ENFORCE_EQ(
+      platform::is_cpu_place(ctx.GetPlace()), true,
+      platform::errors::PreconditionNotMet("It should be running on the CPU."));
   // check index of shape 1-D
   if (index.dims().size() == 2) {
-    PADDLE_ENFORCE_EQ(index.dims()[1], 1,
-                      "index.dims()[1] should be 1 when index.dims().size() == "
-                      "2 in gather_op.");
+    PADDLE_ENFORCE_EQ(
+        index.dims()[1], 1,
+        platform::errors::InvalidArgument(
+            "index.dims()[1] should be 1 when index.dims().size() == 2 "
+            "in gather_op, but received value is [%d].",
+            index.dims()[1]));
   } else {
-    PADDLE_ENFORCE_EQ(index.dims().size(), 1,
-                      "index.dims().size() should be 1 or 2 in gather_op.");
+    PADDLE_ENFORCE_EQ(index.dims().size(), 1,
+                      platform::errors::InvalidArgument(
+                          "index.dims().size() should be 1 or 2 in gather_op, "
+                          "but received shape's size is [%d].",
+                          index.dims().size()));
   }
   int64_t index_size = index.dims()[0];
@@ -69,8 +77,9 @@ void CPUGather(const platform::DeviceContext& ctx, const Tensor& src,
 template <typename T, typename IndexT = int>
 void CPUGatherNd(const platform::DeviceContext& ctx, const Tensor& input,
                  const Tensor& index, Tensor* output) {
-  PADDLE_ENFORCE_EQ(platform::is_cpu_place(ctx.GetPlace()), true,
-                    "It should be running on the CPU");
+  PADDLE_ENFORCE_EQ(
+      platform::is_cpu_place(ctx.GetPlace()), true,
+      platform::errors::PreconditionNotMet("It should be running on the CPU."));
   auto index_dims = index.dims();
   auto index_dims_size = index_dims.size();
@@ -98,11 +107,14 @@ void CPUGatherNd(const platform::DeviceContext& ctx, const Tensor& input,
       int64_t temp = 1;
       for (int64_t j = end_size - 1; j >= 0; --j) {
         IndexT index_value = p_index[i * end_size + j];
-        PADDLE_ENFORCE_LT(index_value, input_dims[j],
-                          "Input(index[-1)] has wrong value, it is %d",
-                          index_value);
-        PADDLE_ENFORCE_GE(index_value, 0UL,
-                          "The value of Input(index) must be no less than 0");
+        PADDLE_ENFORCE_LT(
+            index_value, input_dims[j],
+            platform::errors::InvalidArgument(
+                "Input(index[-1]) has wrong value, it is [%d]", index_value));
+        PADDLE_ENFORCE_GE(
+            index_value, 0UL,
+            platform::errors::InvalidArgument(
+                "The value of Input(index) must be no less than 0"));
         index_ += (index_value * temp);
         temp *= input_dims[j];
...
@@ -27,11 +27,14 @@ class GatherNdOp : public framework::OperatorWithKernel {
   void InferShape(framework::InferShapeContext* ctx) const override {
-    PADDLE_ENFORCE_EQ(ctx->HasInput("X"), true,
-                      "Input(X) of GatherNdOp should not be null.");
-    PADDLE_ENFORCE_EQ(ctx->HasInput("Index"), true,
-                      "Input(Index) of GatherNdOp should not be null.");
-    PADDLE_ENFORCE_EQ(ctx->HasOutput("Out"), true,
-                      "Output(Out) of GatherNdOp should not be null.");
+    PADDLE_ENFORCE_EQ(ctx->HasInput("X"), true,
+                      platform::errors::InvalidArgument(
+                          "Input(X) of GatherNdOp should not be null."));
+    PADDLE_ENFORCE_EQ(ctx->HasInput("Index"), true,
+                      platform::errors::InvalidArgument(
+                          "Input(Index) of GatherNdOp should not be null."));
+    PADDLE_ENFORCE_EQ(ctx->HasOutput("Out"), true,
+                      platform::errors::InvalidArgument(
+                          "Output(Out) of GatherNdOp should not be null."));
     auto x_dims = ctx->GetInputDim("X");
     auto x_dims_size = x_dims.size();
@@ -40,9 +43,11 @@ class GatherNdOp : public framework::OperatorWithKernel {
     PADDLE_ENFORCE_LE(
         index_dims[index_dims_size - 1], x_dims_size,
-        "Input(Index).shape[-1] should be no greater than Input(X).rank");
+        platform::errors::InvalidArgument(
+            "Input(Index).shape[-1] should be no greater than Input(X).rank"));
     PADDLE_ENFORCE_GE(index_dims_size, 2UL,
-                      "The rank of Input(Index) should be greater than 1");
+                      platform::errors::InvalidArgument(
+                          "The rank of Input(Index) should be greater than 1"));
     std::vector<int64_t> result_dims;
     // The result dims is
...
@@ -25,7 +25,8 @@ class GatherNdOpCUDAKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext &ctx) const override {
     PADDLE_ENFORCE_EQ(platform::is_gpu_place(ctx.GetPlace()), true,
-                      "This kernel only runs on GPU device.");
+                      platform::errors::PreconditionNotMet(
+                          "This kernel only runs on GPU device."));
     auto *x = ctx.Input<Tensor>("X");
     auto *index = ctx.Input<Tensor>("Index");
     auto *output = ctx.Output<Tensor>("Out");
@@ -35,12 +36,15 @@ class GatherNdOpCUDAKernel : public framework::OpKernel<T> {
     const auto &index_type = index->type();
     bool index_type_match = index_type == framework::proto::VarType::INT32 ||
                             index_type == framework::proto::VarType::INT64;
-    PADDLE_ENFORCE_EQ(
-        index_type_match, true,
-        "Index holds the wrong type, it holds %s, but desires to be %s or %s",
-        paddle::framework::DataTypeToString(index_type),
-        paddle::framework::DataTypeToString(framework::proto::VarType::INT32),
-        paddle::framework::DataTypeToString(framework::proto::VarType::INT64));
+    PADDLE_ENFORCE_EQ(index_type_match, true,
+                      platform::errors::InvalidArgument(
+                          "Index holds the wrong type, it holds [%s], but "
+                          "desires to be [%s] or [%s].",
+                          paddle::framework::DataTypeToString(index_type),
+                          paddle::framework::DataTypeToString(
+                              framework::proto::VarType::INT32),
+                          paddle::framework::DataTypeToString(
+                              framework::proto::VarType::INT64)));
     if (index_type == framework::proto::VarType::INT32) {
       GPUGatherNd<DeviceContext, T, int>(ctx, *x, *index, output);
     } else if (index_type == framework::proto::VarType::INT64) {
@@ -54,7 +58,8 @@ class GatherNdGradOpCUDAKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext &ctx) const override {
     PADDLE_ENFORCE_EQ(platform::is_gpu_place(ctx.GetPlace()), true,
-                      "This kernel only runs on GPU device.");
+                      platform::errors::PreconditionNotMet(
+                          "This kernel only runs on GPU device."));
     auto *index = ctx.Input<Tensor>("Index");
     auto *dX = ctx.Output<Tensor>(framework::GradVarName("X"));
     auto *dO = ctx.Input<Tensor>(framework::GradVarName("Out"));
@@ -70,12 +75,15 @@ class GatherNdGradOpCUDAKernel : public framework::OpKernel<T> {
     bool index_type_match = index_type == framework::proto::VarType::INT32 ||
                             index_type == framework::proto::VarType::INT64;
-    PADDLE_ENFORCE_EQ(
-        index_type_match, true,
-        "Index holds the wrong type, it holds %s, but desires to be %s or %s",
-        paddle::framework::DataTypeToString(index_type),
-        paddle::framework::DataTypeToString(framework::proto::VarType::INT32),
-        paddle::framework::DataTypeToString(framework::proto::VarType::INT64));
+    PADDLE_ENFORCE_EQ(index_type_match, true,
+                      platform::errors::InvalidArgument(
+                          "Index holds the wrong type, it holds [%s], "
+                          "but desires to be [%s] or [%s].",
+                          paddle::framework::DataTypeToString(index_type),
+                          paddle::framework::DataTypeToString(
+                              framework::proto::VarType::INT32),
+                          paddle::framework::DataTypeToString(
+                              framework::proto::VarType::INT64)));
     if (index_type == framework::proto::VarType::INT32) {
       GPUScatterNdAdd<DeviceContext, T, int>(ctx, *dO, *index, dX);
...
@@ -27,8 +27,9 @@ template <typename T>
 class GatherNdOpKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext &ctx) const override {
-    PADDLE_ENFORCE_EQ(platform::is_cpu_place(ctx.GetPlace()), true,
-                      "This kernel only runs on CPU.");
+    PADDLE_ENFORCE_EQ(
+        platform::is_cpu_place(ctx.GetPlace()), true,
+        platform::errors::PreconditionNotMet("This kernel only runs on CPU."));
     auto *x = ctx.Input<Tensor>("X");
     auto *index = ctx.Input<Tensor>("Index");
@@ -40,12 +41,15 @@ class GatherNdOpKernel : public framework::OpKernel<T> {
     const auto &index_type = index->type();
     bool index_type_match = index_type == framework::proto::VarType::INT32 ||
                             index_type == framework::proto::VarType::INT64;
-    PADDLE_ENFORCE_EQ(
-        index_type_match, true,
-        "Index holds the wrong type, it holds %s, but desires to be %s or %s",
-        paddle::framework::DataTypeToString(index_type),
-        paddle::framework::DataTypeToString(framework::proto::VarType::INT32),
-        paddle::framework::DataTypeToString(framework::proto::VarType::INT64));
+    PADDLE_ENFORCE_EQ(index_type_match, true,
+                      platform::errors::InvalidArgument(
+                          "Index holds the wrong type, it holds [%s], "
+                          "but desires to be [%s] or [%s]",
+                          paddle::framework::DataTypeToString(index_type),
+                          paddle::framework::DataTypeToString(
+                              framework::proto::VarType::INT32),
+                          paddle::framework::DataTypeToString(
+                              framework::proto::VarType::INT64)));
     if (index_type == framework::proto::VarType::INT32) {
       CPUGatherNd<T, int>(ctx.device_context(), *x, *index, output);
     } else if (index_type == framework::proto::VarType::INT64) {
@@ -58,8 +62,9 @@ template <typename T>
 class GatherNdGradOpKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext &ctx) const override {
-    PADDLE_ENFORCE_EQ(platform::is_cpu_place(ctx.GetPlace()), true,
-                      "This kernel only runs on CPU.");
+    PADDLE_ENFORCE_EQ(
+        platform::is_cpu_place(ctx.GetPlace()), true,
+        platform::errors::PreconditionNotMet("This kernel only runs on CPU."));
     auto *index = ctx.Input<Tensor>("Index");
     auto *dX = ctx.Output<Tensor>(framework::GradVarName("X"));
     auto *dO = ctx.Input<Tensor>(framework::GradVarName("Out"));
@@ -73,12 +78,15 @@ class GatherNdGradOpKernel : public framework::OpKernel<T> {
     const auto &index_type = index->type();
     bool index_type_match = index_type == framework::proto::VarType::INT32 ||
                             index_type == framework::proto::VarType::INT64;
-    PADDLE_ENFORCE_EQ(
-        index_type_match, true,
-        "Index holds the wrong type, it holds %s, but desires to be %s or %s",
-        paddle::framework::DataTypeToString(index_type),
-        paddle::framework::DataTypeToString(framework::proto::VarType::INT32),
-        paddle::framework::DataTypeToString(framework::proto::VarType::INT64));
+    PADDLE_ENFORCE_EQ(index_type_match, true,
+                      platform::errors::InvalidArgument(
+                          "Index holds the wrong type, it holds [%s], "
+                          "but desires to be [%s] or [%s]",
+                          paddle::framework::DataTypeToString(index_type),
+                          paddle::framework::DataTypeToString(
+                              framework::proto::VarType::INT32),
+                          paddle::framework::DataTypeToString(
+                              framework::proto::VarType::INT64)));
     if (index_type == framework::proto::VarType::INT32) {
       ScatterNdAdd<T, int32_t>(ctx, *dO, *index, dX);
     } else if (index_type == framework::proto::VarType::INT64) {
...
@@ -26,12 +26,15 @@ class GatherOp : public framework::OperatorWithKernel {
   using framework::OperatorWithKernel::OperatorWithKernel;
   void InferShape(framework::InferShapeContext* ctx) const override {
-    PADDLE_ENFORCE(ctx->HasInput("X"),
-                   "Input(X) of GatherOp should not be null.");
-    PADDLE_ENFORCE(ctx->HasInput("Index"),
-                   "Input(Index) of GatherOp should not be null.");
-    PADDLE_ENFORCE(ctx->HasOutput("Out"),
-                   "Output(Out) of GatherOp should not be null.");
+    PADDLE_ENFORCE_EQ(ctx->HasInput("X"), true,
+                      platform::errors::InvalidArgument(
+                          "Input(X) of GatherOp should not be null."));
+    PADDLE_ENFORCE_EQ(ctx->HasInput("Index"), true,
+                      platform::errors::InvalidArgument(
+                          "Input(Index) of GatherOp should not be null."));
+    PADDLE_ENFORCE_EQ(ctx->HasOutput("Out"), true,
+                      platform::errors::InvalidArgument(
+                          "Output(Out) of GatherOp should not be null."));
     auto index_dims = ctx->GetInputDim("Index");
     PADDLE_ENFORCE(index_dims.size() == 1 ||
...
@@ -24,8 +24,9 @@ template <typename T>
 class GatherOpCUDAKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext &ctx) const override {
-    PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()),
-                   "This kernel only runs on GPU device.");
+    PADDLE_ENFORCE_EQ(platform::is_gpu_place(ctx.GetPlace()), true,
+                      platform::errors::PreconditionNotMet(
+                          "This kernel only runs on GPU device."));
     auto *x = ctx.Input<Tensor>("X");
     auto *index = ctx.Input<Tensor>("Index");
     auto *output = ctx.Output<Tensor>("Out");
@@ -35,12 +36,15 @@ class GatherOpCUDAKernel : public framework::OpKernel<T> {
     const auto &index_type = index->type();
     bool index_type_match = index_type == framework::proto::VarType::INT32 ||
                             index_type == framework::proto::VarType::INT64;
-    PADDLE_ENFORCE(
-        index_type_match,
-        "Index holds the wrong type, it holds %s, but desires to be %s or %s",
-        paddle::framework::DataTypeToString(index_type),
-        paddle::framework::DataTypeToString(framework::proto::VarType::INT32),
-        paddle::framework::DataTypeToString(framework::proto::VarType::INT64));
+    PADDLE_ENFORCE_EQ(index_type_match, true,
+                      platform::errors::InvalidArgument(
+                          "Index holds the wrong type, it holds [%s], "
+                          "but desires to be [%s] or [%s].",
+                          paddle::framework::DataTypeToString(index_type),
+                          paddle::framework::DataTypeToString(
+                              framework::proto::VarType::INT32),
+                          paddle::framework::DataTypeToString(
+                              framework::proto::VarType::INT64)));
     if (index_type == framework::proto::VarType::INT32) {
       GPUGather<T, int>(ctx.device_context(), *x, *index, output);
     } else if (index_type == framework::proto::VarType::INT64) {
@@ -53,8 +57,9 @@ template <typename T>
 class GatherGradOpCUDAKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext &ctx) const override {
-    PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()),
-                   "This kernel only runs on GPU device.");
+    PADDLE_ENFORCE_EQ(platform::is_gpu_place(ctx.GetPlace()), true,
+                      platform::errors::PreconditionNotMet(
+                          "This kernel only runs on GPU device."));
     auto *index = ctx.Input<Tensor>("Index");
     auto *dX = ctx.Output<Tensor>(framework::GradVarName("X"));
     auto *dO = ctx.Input<Tensor>(framework::GradVarName("Out"));
@@ -69,12 +74,15 @@ class GatherGradOpCUDAKernel : public framework::OpKernel<T> {
     const auto &index_type = index->type();
     bool index_type_match = index_type == framework::proto::VarType::INT32 ||
                             index_type == framework::proto::VarType::INT64;
-    PADDLE_ENFORCE(
-        index_type_match,
-        "Index holds the wrong type, it holds %s, but desires to be %s or %s",
-        paddle::framework::DataTypeToString(index_type),
-        paddle::framework::DataTypeToString(framework::proto::VarType::INT32),
-        paddle::framework::DataTypeToString(framework::proto::VarType::INT64));
+    PADDLE_ENFORCE_EQ(index_type_match, true,
+                      platform::errors::InvalidArgument(
+                          "Index holds the wrong type, it holds [%s], "
+                          "but desires to be [%s] or [%s].",
+                          paddle::framework::DataTypeToString(index_type),
+                          paddle::framework::DataTypeToString(
+                              framework::proto::VarType::INT32),
+                          paddle::framework::DataTypeToString(
+                              framework::proto::VarType::INT64)));
     if (index_type == framework::proto::VarType::INT32) {
       GPUScatterAssign<T, int>(ctx, *dO, *index, dX,
                                ctx.Attr<bool>("overwrite"));
...
@@ -27,8 +27,9 @@ template <typename T>
 class GatherOpKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext &ctx) const override {
-    PADDLE_ENFORCE(platform::is_cpu_place(ctx.GetPlace()),
-                   "This kernel only runs on CPU.");
+    PADDLE_ENFORCE_EQ(
+        platform::is_cpu_place(ctx.GetPlace()), true,
+        platform::errors::PreconditionNotMet("This kernel only runs on CPU."));
     auto *x = ctx.Input<Tensor>("X");
     auto *index = ctx.Input<Tensor>("Index");
@@ -40,12 +41,15 @@ class GatherOpKernel : public framework::OpKernel<T> {
     const auto &index_type = index->type();
     bool index_type_match = index_type == framework::proto::VarType::INT32 ||
                             index_type == framework::proto::VarType::INT64;
-    PADDLE_ENFORCE(
-        index_type_match,
-        "Index holds the wrong type, it holds %s, but desires to be %s or %s",
-        paddle::framework::DataTypeToString(index_type),
-        paddle::framework::DataTypeToString(framework::proto::VarType::INT32),
-        paddle::framework::DataTypeToString(framework::proto::VarType::INT64));
+    PADDLE_ENFORCE_EQ(index_type_match, true,
+                      platform::errors::InvalidArgument(
+                          "Index holds the wrong type, it holds [%s], "
+                          "but desires to be [%s] or [%s].",
+                          paddle::framework::DataTypeToString(index_type),
+                          paddle::framework::DataTypeToString(
+                              framework::proto::VarType::INT32),
+                          paddle::framework::DataTypeToString(
+                              framework::proto::VarType::INT64)));
     if (index_type == framework::proto::VarType::INT32) {
       CPUGather<T, int>(ctx.device_context(), *x, *index, output);
     } else if (index_type == framework::proto::VarType::INT64) {
@@ -58,8 +62,9 @@ template <typename T>
 class GatherGradientOpKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext &ctx) const override {
-    PADDLE_ENFORCE(platform::is_cpu_place(ctx.GetPlace()),
-                   "This kernel only runs on CPU.");
+    PADDLE_ENFORCE_EQ(
+        platform::is_cpu_place(ctx.GetPlace()), true,
+        platform::errors::PreconditionNotMet("This kernel only runs on CPU."));
     auto *index = ctx.Input<Tensor>("Index");
     auto *dX = ctx.Output<Tensor>(framework::GradVarName("X"));
@@ -76,12 +81,15 @@ class GatherGradientOpKernel : public framework::OpKernel<T> {
     const auto &index_type = index->type();
     bool index_type_match = index_type == framework::proto::VarType::INT32 ||
                             index_type == framework::proto::VarType::INT64;
-    PADDLE_ENFORCE(
-        index_type_match,
-        "Index holds the wrong type, it holds %s, but desires to be %s or %s",
-        paddle::framework::DataTypeToString(index_type),
-        paddle::framework::DataTypeToString(framework::proto::VarType::INT32),
-        paddle::framework::DataTypeToString(framework::proto::VarType::INT64));
+    PADDLE_ENFORCE_EQ(index_type_match, true,
+                      platform::errors::InvalidArgument(
+                          "Index holds the wrong type, it holds [%s], "
+                          "but desires to be [%s] or [%s].",
+                          paddle::framework::DataTypeToString(index_type),
+                          paddle::framework::DataTypeToString(
+                              framework::proto::VarType::INT32),
+                          paddle::framework::DataTypeToString(
+                              framework::proto::VarType::INT64)));
     if (index_type == framework::proto::VarType::INT32) {
       if (overwrite) {
         ScatterAssign<T, int32_t>(ctx.device_context(), *dO, *index, dX);
...