未验证 提交 a2387ef2 编写于 作者: T TTerror 提交者: GitHub

fix concat_grad on kunlun (#32151)

* fix concat_grad on kunlun

* fix concat_grad on kunlun
上级 f8bab5b0
......@@ -13,7 +13,7 @@ if(NOT XPU_SDK_ROOT)
elseif(WITH_SUNWAY)
SET(XPU_URL "https://baidu-kunlun-public.su.bcebos.com/paddle_depence/sunway/xpu_2021_01_13.tar.gz" CACHE STRING "" FORCE)
else()
SET(XPU_URL "https://baidu-kunlun-public.su.bcebos.com/paddle_depence/xpu_2021_03_30.tar.gz" CACHE STRING "" FORCE)
SET(XPU_URL "https://baidu-kunlun-public.su.bcebos.com/paddle_depence/xpu_2021_04_09.tar.gz" CACHE STRING "" FORCE)
endif()
SET(XPU_SOURCE_DIR "${THIRD_PARTY_PATH}/xpu")
......
......@@ -132,16 +132,14 @@ class ConcatGradXPUKernel : public framework::OpKernel<T> {
axis = ComputeAxis(static_cast<int64_t>(axis),
static_cast<int64_t>(ins[0]->dims().size()));
// collect the output tensors whose name is not kEmptyVarName
std::vector<framework::Tensor*> outputs;
std::vector<int> choose_idx;
int n = 0;
std::vector<T*> ptrs(outs.size());
for (size_t j = 0; j < outs.size(); ++j) {
if (out_var_names[j] != framework::kEmptyVarName &&
outs[j]->numel() != 0UL) {
outs[j]->mutable_data<T>(ctx.GetPlace());
outputs.push_back(outs[j]);
choose_idx.push_back(j);
n++;
ptrs[j] = outs[j]->data<T>();
} else {
ptrs[j] = nullptr;
}
}
PADDLE_ENFORCE_GE(axis, 0, platform::errors::InvalidArgument(
......@@ -157,10 +155,10 @@ class ConcatGradXPUKernel : public framework::OpKernel<T> {
axis, out_grad->dims().size()));
auto input_dims = ins[0]->dims();
std::vector<int> split_list(n);
std::vector<int> split_list(ins.size());
std::vector<int> xdims_list(input_dims.size());
int total_length = 0;
for (int i = 0; i < n; ++i) {
for (size_t i = 0; i < ins.size(); ++i) {
split_list[i] = ins[i]->dims()[axis];
total_length += ins[i]->dims()[axis];
}
......@@ -172,11 +170,6 @@ class ConcatGradXPUKernel : public framework::OpKernel<T> {
}
xdims_list[axis] = total_length;
std::vector<T*> ptrs(n);
for (int i = 0; i < n; ++i) {
ptrs[i] = outputs[i]->data<T>();
}
auto& dev_ctx = ctx.template device_context<DeviceContext>();
int r = xpu::split<T>(dev_ctx.x_context(), out_grad->data<T>(), ptrs,
xdims_list, split_list, axis);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册