未验证 提交 15fac5e7 编写于 作者: L liuyuhui 提交者: GitHub

fix assign_op_xpu concat_op_xpu warning (#30120)

上级 f5428eca
...@@ -276,7 +276,7 @@ class FuseAllReduceOpPass : public ir::Pass { ...@@ -276,7 +276,7 @@ class FuseAllReduceOpPass : public ir::Pass {
ir::Node::Type::kOperation), ir::Node::Type::kOperation),
local_scopes, places, num_of_all_reduce, multi_nccl_ctxs); local_scopes, places, num_of_all_reduce, multi_nccl_ctxs);
#elif defined(PADDLE_WITH_XPU_BKCL) #elif defined(PADDLE_WITH_XPU_BKCL)
auto *op_handle = new details::FusedAllReduceOpHandle( op_handle = new details::FusedAllReduceOpHandle(
result->CreateEmptyNode("fused_all_reduce", result->CreateEmptyNode("fused_all_reduce",
ir::Node::Type::kOperation), ir::Node::Type::kOperation),
local_scopes, places, num_of_all_reduce, multi_bkcl_ctxs); local_scopes, places, num_of_all_reduce, multi_bkcl_ctxs);
......
...@@ -522,7 +522,7 @@ void MultiDevSSAGraphBuilderBase::CreateAllReduceOp(ir::Graph *result, ...@@ -522,7 +522,7 @@ void MultiDevSSAGraphBuilderBase::CreateAllReduceOp(ir::Graph *result,
scopes, places, grad_merge_cond_name, multi_nccl_ctxs_)); scopes, places, grad_merge_cond_name, multi_nccl_ctxs_));
#elif defined(PADDLE_WITH_XPU_BKCL) #elif defined(PADDLE_WITH_XPU_BKCL)
result->Get<GraphOps>(kGraphOps).emplace_back( result->Get<GraphOps>(kGraphOps).emplace_back(
new datails::GradMergeAllReduceOpHandle( new details::GradMergeAllReduceOpHandle(
result->CreateEmptyNode("allreduce", ir::Node::Type::kOperation), result->CreateEmptyNode("allreduce", ir::Node::Type::kOperation),
scopes, places, grad_merge_cond_name, multi_bkcl_ctxs_)); scopes, places, grad_merge_cond_name, multi_bkcl_ctxs_));
#else #else
......
...@@ -36,11 +36,16 @@ class ConcatXPUKernel : public framework::OpKernel<T> { ...@@ -36,11 +36,16 @@ class ConcatXPUKernel : public framework::OpKernel<T> {
"XPU donot surpport AxisTensor for now")); "XPU donot surpport AxisTensor for now"));
axis = ComputeAxis(static_cast<int64_t>(axis), axis = ComputeAxis(static_cast<int64_t>(axis),
static_cast<int64_t>(ins[0]->dims().size())); static_cast<int64_t>(ins[0]->dims().size()));
PADDLE_ENFORCE_GE( PADDLE_ENFORCE_GE(axis, 0, platform::errors::InvalidArgument(
axis, 0, platform::errors::InvalidArgument("concat: axis shoud >= 0!")); "concat: axis should be larger than or "
"equal to 0, but received axis is %d.",
axis));
PADDLE_ENFORCE_LT(axis, ins[0]->dims().size(), PADDLE_ENFORCE_LT(axis, ins[0]->dims().size(),
platform::errors::InvalidArgument( platform::errors::InvalidArgument(
"concat: axis shoud < ins[0]->dims()!")); "concat: axis should be less than ins[0]->dims()!"
"But received axis is %d, while ins[0]->dims()"
"size is %d.",
axis, ins[0]->dims().size()));
auto place = ctx.GetPlace(); auto place = ctx.GetPlace();
out->mutable_data<T>(place); out->mutable_data<T>(place);
...@@ -151,10 +156,16 @@ class ConcatGradXPUKernel : public framework::OpKernel<T> { ...@@ -151,10 +156,16 @@ class ConcatGradXPUKernel : public framework::OpKernel<T> {
} }
} }
PADDLE_ENFORCE_GE(axis, 0, platform::errors::InvalidArgument( PADDLE_ENFORCE_GE(axis, 0, platform::errors::InvalidArgument(
"concat_grad: axis shoud >= 0!")); "concat_grad: axis should be larger than or "
PADDLE_ENFORCE_LT(axis, out_grad->dims().size(), "equal to 0, but received axis is %d.",
axis));
PADDLE_ENFORCE_LT(
axis, out_grad->dims().size(),
platform::errors::InvalidArgument( platform::errors::InvalidArgument(
"concat_grad: axis shoud < ins[0]->dims()!")); "concat_grad: axis should be less than ins[0]->dims()!"
"But received axis is %d, while ins[0]->dims()"
"size is %d.",
axis, out_grad->dims().size()));
auto input_dims = ins[0]->dims(); auto input_dims = ins[0]->dims();
std::vector<int> split_list(n); std::vector<int> split_list(n);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册