未验证 提交 15fac5e7 编写于 作者: L liuyuhui 提交者: GitHub

fix assign_op_xpu concat_op_xpu warining (#30120)

上级 f5428eca
......@@ -276,7 +276,7 @@ class FuseAllReduceOpPass : public ir::Pass {
ir::Node::Type::kOperation),
local_scopes, places, num_of_all_reduce, multi_nccl_ctxs);
#elif defined(PADDLE_WITH_XPU_BKCL)
auto *op_handle = new details::FusedAllReduceOpHandle(
op_handle = new details::FusedAllReduceOpHandle(
result->CreateEmptyNode("fused_all_reduce",
ir::Node::Type::kOperation),
local_scopes, places, num_of_all_reduce, multi_bkcl_ctxs);
......
......@@ -522,7 +522,7 @@ void MultiDevSSAGraphBuilderBase::CreateAllReduceOp(ir::Graph *result,
scopes, places, grad_merge_cond_name, multi_nccl_ctxs_));
#elif defined(PADDLE_WITH_XPU_BKCL)
result->Get<GraphOps>(kGraphOps).emplace_back(
new datails::GradMergeAllReduceOpHandle(
new details::GradMergeAllReduceOpHandle(
result->CreateEmptyNode("allreduce", ir::Node::Type::kOperation),
scopes, places, grad_merge_cond_name, multi_bkcl_ctxs_));
#else
......
......@@ -36,11 +36,16 @@ class ConcatXPUKernel : public framework::OpKernel<T> {
"XPU donot surpport AxisTensor for now"));
axis = ComputeAxis(static_cast<int64_t>(axis),
static_cast<int64_t>(ins[0]->dims().size()));
PADDLE_ENFORCE_GE(
axis, 0, platform::errors::InvalidArgument("concat: axis shoud >= 0!"));
PADDLE_ENFORCE_GE(axis, 0, platform::errors::InvalidArgument(
"concat: axis should be larger than or "
"equal to 0, but received axis is %d.",
axis));
PADDLE_ENFORCE_LT(axis, ins[0]->dims().size(),
platform::errors::InvalidArgument(
"concat: axis shoud < ins[0]->dims()!"));
"concat: axis should be less than ins[0]->dims()!"
"But received axis is %d, while ins[0]->dims()"
"size is %d.",
axis, ins[0]->dims().size()));
auto place = ctx.GetPlace();
out->mutable_data<T>(place);
......@@ -151,10 +156,16 @@ class ConcatGradXPUKernel : public framework::OpKernel<T> {
}
}
PADDLE_ENFORCE_GE(axis, 0, platform::errors::InvalidArgument(
"concat_grad: axis shoud >= 0!"));
PADDLE_ENFORCE_LT(axis, out_grad->dims().size(),
"concat_grad: axis should be larger than or "
"equal to 0, but received axis is %d.",
axis));
PADDLE_ENFORCE_LT(
axis, out_grad->dims().size(),
platform::errors::InvalidArgument(
"concat_grad: axis shoud < ins[0]->dims()!"));
"concat_grad: axis should be less than ins[0]->dims()!"
"But received axis is %d, while ins[0]->dims()"
"size is %d.",
axis, out_grad->dims().size()));
auto input_dims = ins[0]->dims();
std::vector<int> split_list(n);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册