Unverified commit a2387ef2, authored by TTerror, committed by GitHub

fix concat_grad on kunlun (#32151)

* fix concat_grad on kunlun

* fix concat_grad on kunlun
Parent f8bab5b0
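In short, the commit does two things, both visible in the diff below: it bumps the Kunlun XPU dependency package from xpu_2021_03_30 to xpu_2021_04_09, and it simplifies ConcatGradXPUKernel. The old kernel compacted only the outputs that actually need a gradient into a dense `outputs` vector (tracked via `choose_idx` and a counter `n`) and sized `split_list` by `n`; the new kernel keeps one pointer slot per output in `ptrs`, writing `nullptr` for skipped gradients, and sizes `split_list` by `ins.size()` so every input's extent along the concat axis is passed to `xpu::split`.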
@@ -13,7 +13,7 @@ if(NOT XPU_SDK_ROOT)
 elseif(WITH_SUNWAY)
   SET(XPU_URL "https://baidu-kunlun-public.su.bcebos.com/paddle_depence/sunway/xpu_2021_01_13.tar.gz" CACHE STRING "" FORCE)
 else()
-  SET(XPU_URL "https://baidu-kunlun-public.su.bcebos.com/paddle_depence/xpu_2021_03_30.tar.gz" CACHE STRING "" FORCE)
+  SET(XPU_URL "https://baidu-kunlun-public.su.bcebos.com/paddle_depence/xpu_2021_04_09.tar.gz" CACHE STRING "" FORCE)
 endif()
 SET(XPU_SOURCE_DIR "${THIRD_PARTY_PATH}/xpu")
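A plausible reading, not stated in the commit message: the SDK bump picks up a toolkit build whose `xpu::split` tolerates `nullptr` entries in the output-pointer list, which the kernel change below depends on.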
@@ -132,16 +132,14 @@ class ConcatGradXPUKernel : public framework::OpKernel<T> {
     axis = ComputeAxis(static_cast<int64_t>(axis),
                        static_cast<int64_t>(ins[0]->dims().size()));
     // get output tensor that the name is not kEmptyVarName
-    std::vector<framework::Tensor*> outputs;
-    std::vector<int> choose_idx;
-    int n = 0;
+    std::vector<T*> ptrs(outs.size());
     for (size_t j = 0; j < outs.size(); ++j) {
       if (out_var_names[j] != framework::kEmptyVarName &&
           outs[j]->numel() != 0UL) {
         outs[j]->mutable_data<T>(ctx.GetPlace());
-        outputs.push_back(outs[j]);
-        choose_idx.push_back(j);
-        n++;
+        ptrs[j] = outs[j]->data<T>();
+      } else {
+        ptrs[j] = nullptr;
       }
     }
     PADDLE_ENFORCE_GE(axis, 0, platform::errors::InvalidArgument(
@@ -157,10 +155,10 @@ class ConcatGradXPUKernel : public framework::OpKernel<T> {
                           axis, out_grad->dims().size()));
     auto input_dims = ins[0]->dims();
-    std::vector<int> split_list(n);
+    std::vector<int> split_list(ins.size());
     std::vector<int> xdims_list(input_dims.size());
     int total_length = 0;
-    for (int i = 0; i < n; ++i) {
+    for (size_t i = 0; i < ins.size(); ++i) {
       split_list[i] = ins[i]->dims()[axis];
       total_length += ins[i]->dims()[axis];
     }
@@ -172,11 +170,6 @@ class ConcatGradXPUKernel : public framework::OpKernel<T> {
     }
     xdims_list[axis] = total_length;
 
-    std::vector<T*> ptrs(n);
-    for (int i = 0; i < n; ++i) {
-      ptrs[i] = outputs[i]->data<T>();
-    }
-
     auto& dev_ctx = ctx.template device_context<DeviceContext>();
     int r = xpu::split<T>(dev_ctx.x_context(), out_grad->data<T>(), ptrs,
                           xdims_list, split_list, axis);
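For readers unfamiliar with the pattern, here is a minimal CPU analogue of what the reworked kernel expects from the backend: split a tensor along an axis into per-output chunks, where a nullptr slot means "this gradient is not needed, skip it." The function name split_with_skips and the nullptr-skipping contract are illustrative assumptions for this sketch; on Kunlun the actual work is done by xpu::split inside the SDK.

// Hypothetical CPU analogue of the contract the new kernel relies on:
// split x (shape xdims) along `axis` into chunks of sizes split_list,
// skipping any destination whose pointer is nullptr.
#include <cstring>
#include <iostream>
#include <numeric>
#include <vector>

template <typename T>
void split_with_skips(const T* x, const std::vector<T*>& outs,
                      const std::vector<int>& xdims,
                      const std::vector<int>& split_list, int axis) {
  // Elements in one "row" before the split axis and one "column" after it.
  int outer = 1, inner = 1;
  for (int i = 0; i < axis; ++i) outer *= xdims[i];
  for (size_t i = axis + 1; i < xdims.size(); ++i) inner *= xdims[i];
  const int total = xdims[axis];  // sum of split_list

  int offset = 0;  // running offset along the split axis
  for (size_t k = 0; k < outs.size(); ++k) {
    const int len = split_list[k];
    if (outs[k] != nullptr) {  // nullptr marks a skipped gradient
      for (int o = 0; o < outer; ++o) {
        const T* src = x + (static_cast<size_t>(o) * total + offset) * inner;
        T* dst = outs[k] + static_cast<size_t>(o) * len * inner;
        std::memcpy(dst, src, sizeof(T) * len * inner);
      }
    }
    offset += len;  // advance even for skipped outputs
  }
}

int main() {
  // out_grad of shape [2, 5], split along axis 1 into sizes {2, 3};
  // the second input needs no gradient, so its slot is nullptr.
  std::vector<float> out_grad(10);
  std::iota(out_grad.begin(), out_grad.end(), 0.f);
  std::vector<float> g0(4);
  std::vector<float*> ptrs = {g0.data(), nullptr};
  split_with_skips(out_grad.data(), ptrs, {2, 5}, {2, 3}, 1);
  for (float v : g0) std::cout << v << ' ';  // prints: 0 1 5 6
  std::cout << '\n';
}

Built with any C++11 compiler, the example prints 0 1 5 6: the first two columns of each row of the 2x5 gradient, while the second output slot is skipped entirely. This mirrors why the kernel can drop the outputs/choose_idx compaction: the per-slot nullptr carries the "skip" information directly.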