未验证 提交 caf3680b 编写于 作者: T taixiurong 提交者: GitHub

fix bugs in transformer predict in xpu place (#30730)

* transformer predict

* trans bug fix
上级 a87d78f1
...@@ -20,6 +20,7 @@ limitations under the License. */ ...@@ -20,6 +20,7 @@ limitations under the License. */
namespace paddle { namespace paddle {
namespace operators { namespace operators {
class ArrayOp : public framework::OperatorBase { class ArrayOp : public framework::OperatorBase {
public: public:
ArrayOp(const std::string &type, const framework::VariableNameMap &inputs, ArrayOp(const std::string &type, const framework::VariableNameMap &inputs,
...@@ -45,7 +46,8 @@ class ArrayOp : public framework::OperatorBase { ...@@ -45,7 +46,8 @@ class ArrayOp : public framework::OperatorBase {
auto &dev_ctx = *pool.Get(place); auto &dev_ctx = *pool.Get(place);
size_t offset; size_t offset;
if (platform::is_gpu_place(i_tensor.place())) { if (platform::is_gpu_place(i_tensor.place()) ||
platform::is_xpu_place(i_tensor.place())) {
// FIXME: Avoid copy from GPU to CPU // FIXME: Avoid copy from GPU to CPU
framework::Tensor t; framework::Tensor t;
framework::TensorCopy(i_tensor, platform::CPUPlace(), dev_ctx, &t); framework::TensorCopy(i_tensor, platform::CPUPlace(), dev_ctx, &t);
......
...@@ -47,19 +47,6 @@ class ConcatXPUKernel : public framework::OpKernel<T> { ...@@ -47,19 +47,6 @@ class ConcatXPUKernel : public framework::OpKernel<T> {
"size is %d.", "size is %d.",
axis, ins[0]->dims().size())); axis, ins[0]->dims().size()));
auto place = ctx.GetPlace();
out->mutable_data<T>(place);
std::vector<int> choose_idx;
int n = 0;
for (unsigned int i = 0; i < ins.size(); ++i) {
if (ins[i] && ins[i]->numel() > 0) {
choose_idx.push_back(i);
n++;
}
}
PADDLE_ENFORCE_GT(
n, 0, platform::errors::InvalidArgument("No tensor need concat?"));
// If axis is 0, the lod of the output is not the same as inputs. // If axis is 0, the lod of the output is not the same as inputs.
if (axis == 0 && ins[0]->lod().size() > 0) { if (axis == 0 && ins[0]->lod().size() > 0) {
size_t lod_size_0 = ins[0]->lod().size(); size_t lod_size_0 = ins[0]->lod().size();
...@@ -87,30 +74,32 @@ class ConcatXPUKernel : public framework::OpKernel<T> { ...@@ -87,30 +74,32 @@ class ConcatXPUKernel : public framework::OpKernel<T> {
} }
} }
} }
auto place = ctx.GetPlace();
auto input_dims = ins[0]->dims(); out->mutable_data<T>(place);
std::vector<std::vector<int>> xdims_list(n); std::vector<std::vector<int>> xdims_list;
for (int i = 0; i < n; ++i) { std::vector<const T*> ptrs;
std::vector<int> tmp_dims(input_dims.size()); for (unsigned int i = 0; i < ins.size(); ++i) {
for (int j = 0; j < input_dims.size(); ++j) { if (ins[i] && ins[i]->numel() > 0) {
ptrs.push_back(ins[i]->data<T>());
int size = ins[i]->dims().size();
std::vector<int> tmp_dims(size);
for (int j = 0; j < size; ++j) {
tmp_dims[j] = ins[i]->dims()[j]; tmp_dims[j] = ins[i]->dims()[j];
} }
xdims_list[i] = tmp_dims; xdims_list.push_back(tmp_dims);
}
} }
PADDLE_ENFORCE_GT(xdims_list.size(), 0, platform::errors::InvalidArgument(
"No tensor need concat"));
auto& dev_ctx = ctx.template device_context<DeviceContext>(); auto& dev_ctx = ctx.template device_context<DeviceContext>();
std::vector<const T*> ptrs;
for (int i = 0; i < n; ++i) {
ptrs.push_back(ins[choose_idx[i]]->data<T>());
}
int r = xpu::concat<T>(dev_ctx.x_context(), ptrs, out->data<T>(), int r = xpu::concat<T>(dev_ctx.x_context(), ptrs, out->data<T>(),
xdims_list, axis); xdims_list, axis);
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(r, XPU_SUCCESS,
r, XPU_SUCCESS,
platform::errors::External( platform::errors::External(
"XPU API return wrong value[%d], please check whether " "XPU concat kernel return wrong value[%d %s]", r,
"Baidu Kunlun Card is properly installed.", XPUAPIErrorMsg[r]));
r));
} }
}; };
......
...@@ -380,11 +380,20 @@ class ReshapeKernel { ...@@ -380,11 +380,20 @@ class ReshapeKernel {
#ifdef PADDLE_WITH_XPU #ifdef PADDLE_WITH_XPU
if (platform::is_xpu_place(ctx.GetPlace())) { if (platform::is_xpu_place(ctx.GetPlace())) {
void *out_ptr = out->data<void>();
const void *in_ptr = in->data<void>();
if ((out_ptr != nullptr) && (in_ptr != nullptr) &&
(paddle::framework::SizeOfType(in->type()) > 0)) {
auto &dev_ctx = auto &dev_ctx =
ctx.template device_context<paddle::platform::XPUDeviceContext>(); ctx.template device_context<paddle::platform::XPUDeviceContext>();
xpu::memcpy_device( int r = xpu::memcpy_device(
dev_ctx.x_context(), out->data<void>(), in->data<void>(), dev_ctx.x_context(), out_ptr, in_ptr,
in->numel() * paddle::framework::SizeOfType(in->type())); in->numel() * paddle::framework::SizeOfType(in->type()));
PADDLE_ENFORCE_EQ(r, XPU_SUCCESS,
platform::errors::External(
"XPU memcpy_device return wrong value[%d %s]", r,
XPUAPIErrorMsg[r]));
}
} else { } else {
#endif #endif
framework::TensorCopy( framework::TensorCopy(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册