未验证 提交 7b860a23 编写于 作者: T taixiurong 提交者: GitHub

1.fix elementwise_add_grad bug. 2. add dropout kernel in kl2 (#38726)

上级 066a8063
...@@ -488,6 +488,14 @@ void TensorCopySync(const Tensor& src, const platform::Place& dst_place, ...@@ -488,6 +488,14 @@ void TensorCopySync(const Tensor& src, const platform::Place& dst_place,
} }
memory::Copy(BOOST_GET_CONST(platform::XPUPlace, dst_place), dst_ptr, memory::Copy(BOOST_GET_CONST(platform::XPUPlace, dst_place), dst_ptr,
BOOST_GET_CONST(platform::XPUPlace, src_place), src_ptr, size); BOOST_GET_CONST(platform::XPUPlace, src_place), src_ptr, size);
platform::XPUPlace xpu_dst_place =
BOOST_GET_CONST(platform::XPUPlace, dst_place);
platform::XPUPlace xpu_src_place =
BOOST_GET_CONST(platform::XPUPlace, src_place);
if (xpu_dst_place.device == xpu_src_place.device) {
auto xpu_ctx = platform::DeviceContextPool::Instance().Get(xpu_dst_place);
xpu_ctx->Wait();
}
} }
else { // NOLINT else { // NOLINT
PADDLE_THROW(platform::errors::Unimplemented( PADDLE_THROW(platform::errors::Unimplemented(
......
...@@ -66,7 +66,7 @@ void Copy<platform::XPUPlace, platform::CPUPlace>(platform::XPUPlace dst_place, ...@@ -66,7 +66,7 @@ void Copy<platform::XPUPlace, platform::CPUPlace>(platform::XPUPlace dst_place,
VLOG(1) << "memcpy XPU_HOST_TO_DEVICE size <= 0 (" << num << ")"; VLOG(1) << "memcpy XPU_HOST_TO_DEVICE size <= 0 (" << num << ")";
return; return;
} }
platform::MemcpySyncH2D(dst, src, num, dst_place.device); platform::MemcpySyncH2D(dst, src, num, dst_place);
} }
template <> template <>
...@@ -78,7 +78,7 @@ void Copy<platform::CPUPlace, platform::XPUPlace>(platform::CPUPlace dst_place, ...@@ -78,7 +78,7 @@ void Copy<platform::CPUPlace, platform::XPUPlace>(platform::CPUPlace dst_place,
VLOG(1) << "memcpy XPU_DEVICE_TO_HOST size <= 0 (" << num << ")"; VLOG(1) << "memcpy XPU_DEVICE_TO_HOST size <= 0 (" << num << ")";
return; return;
} }
platform::MemcpySyncD2H(dst, src, num, src_place.device); platform::MemcpySyncD2H(dst, src, num, src_place);
} }
template <> template <>
...@@ -90,7 +90,7 @@ void Copy<platform::XPUPlace, platform::XPUPlace>(platform::XPUPlace dst_place, ...@@ -90,7 +90,7 @@ void Copy<platform::XPUPlace, platform::XPUPlace>(platform::XPUPlace dst_place,
VLOG(1) << "memcpy XPU_DEVICE_TO_DEVICE size <= 0 (" << num << ")"; VLOG(1) << "memcpy XPU_DEVICE_TO_DEVICE size <= 0 (" << num << ")";
return; return;
} }
platform::MemcpySyncD2D(dst, dst_place.device, src, src_place.device, num); platform::MemcpySyncD2D(dst, dst_place, src, src_place, num);
} }
#endif #endif
......
...@@ -11,7 +11,7 @@ limitations under the License. */ ...@@ -11,7 +11,7 @@ limitations under the License. */
#include "paddle/fluid/operators/dropout_op.h" #include "paddle/fluid/operators/dropout_op.h"
#include <memory> #include <memory>
#include <string> #include <string>
#include "paddle/fluid/platform/device/xpu/xpu_header.h" #include "paddle/fluid/platform/device/device_wrapper.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
...@@ -55,17 +55,11 @@ class DropoutXPUKernel : public framework::OpKernel<T> { ...@@ -55,17 +55,11 @@ class DropoutXPUKernel : public framework::OpKernel<T> {
int r = xpu::constant(dev_ctx.x_context(), int r = xpu::constant(dev_ctx.x_context(),
reinterpret_cast<XPUTyp*>(y_data), y->numel(), reinterpret_cast<XPUTyp*>(y_data), y->numel(),
XPUTyp(0)); XPUTyp(0));
PADDLE_ENFORCE_EQ(r, XPU_SUCCESS, platform::errors::External( PADDLE_ENFORCE_XDNN_SUCCESS(r, "constant ");
"XPU API(constant) return wrong "
"value[%d %s]",
r, XPUAPIErrorMsg[r]));
r = xpu::constant(dev_ctx.x_context(), r = xpu::constant(dev_ctx.x_context(),
reinterpret_cast<XPUTyp*>(mask_data), mask->numel(), reinterpret_cast<XPUTyp*>(mask_data), mask->numel(),
XPUTyp(0)); XPUTyp(0));
PADDLE_ENFORCE_EQ(r, XPU_SUCCESS, platform::errors::External( PADDLE_ENFORCE_XDNN_SUCCESS(r, "constant ");
"XPU API(constant) return wrong "
"value[%d %s]",
r, XPUAPIErrorMsg[r]));
return; return;
} }
int r = xpu::dropout(dev_ctx.x_context(), int r = xpu::dropout(dev_ctx.x_context(),
...@@ -73,26 +67,20 @@ class DropoutXPUKernel : public framework::OpKernel<T> { ...@@ -73,26 +67,20 @@ class DropoutXPUKernel : public framework::OpKernel<T> {
reinterpret_cast<XPUTyp*>(y->data<T>()), reinterpret_cast<XPUTyp*>(y->data<T>()),
reinterpret_cast<XPUTyp*>(mask_data), seed, reinterpret_cast<XPUTyp*>(mask_data), seed,
mask->numel(), is_upscale, dropout_prob); mask->numel(), is_upscale, dropout_prob);
PADDLE_ENFORCE_EQ(r, XPU_SUCCESS, platform::errors::External( PADDLE_ENFORCE_XDNN_SUCCESS(r, "dropout ");
"XPU API(dropout) return wrong "
"value[%d %s]",
r, XPUAPIErrorMsg[r]));
} else { } else {
float scale = float scale =
(is_upscale) ? (1.0) : (static_cast<float>(1.0f - dropout_prob)); (is_upscale) ? (1.0) : (static_cast<float>(1.0f - dropout_prob));
int r = xpu::scale( int r = xpu::scale(
dev_ctx.x_context(), reinterpret_cast<const XPUTyp*>(x_data), dev_ctx.x_context(), reinterpret_cast<const XPUTyp*>(x_data),
reinterpret_cast<XPUTyp*>(y_data), x->numel(), false, scale, 0.0f); reinterpret_cast<XPUTyp*>(y_data), x->numel(), false, scale, 0.0f);
PADDLE_ENFORCE_EQ(r, XPU_SUCCESS, platform::errors::External( PADDLE_ENFORCE_XDNN_SUCCESS(r, "scale ");
"XPU API(scale) return wrong "
"value[%d %s]",
r, XPUAPIErrorMsg[r]));
} }
} }
}; };
template <typename DeviceContext, typename T> template <typename DeviceContext, typename T>
class DropoutGradXPUKernel : public framework::OpKernel<T> { class DropoutGradXPUKernel : public framework::OpKernel<T> {
using XPUTyp = typename XPUTypeTrait<T>::Type; using XPUType = typename XPUTypeTrait<T>::Type;
public: public:
void Compute(const framework::ExecutionContext& context) const override { void Compute(const framework::ExecutionContext& context) const override {
...@@ -108,31 +96,43 @@ class DropoutGradXPUKernel : public framework::OpKernel<T> { ...@@ -108,31 +96,43 @@ class DropoutGradXPUKernel : public framework::OpKernel<T> {
context.Attr<std::string>("dropout_implementation"); context.Attr<std::string>("dropout_implementation");
float dropout_prob = context.Attr<float>("dropout_prob"); float dropout_prob = context.Attr<float>("dropout_prob");
const T* mask_data = mask->data<T>(); const T* mask_data = mask->data<T>();
framework::Tensor mask_new;
if (dropout_implementation == "upscale_in_train") { if (dropout_implementation != "upscale_in_train") {
mask_new = context.AllocateTmpTensor<T, platform::XPUDeviceContext>( int r = xpu::mul(dev_ctx.x_context(),
mask->dims(), dev_ctx); reinterpret_cast<const XPUType*>(grad_y->data<T>()),
reinterpret_cast<const XPUType*>(mask_data),
reinterpret_cast<XPUType*>(grad_x->data<T>()),
grad_y->numel());
PADDLE_ENFORCE_XDNN_SUCCESS(r, "mul ");
return;
}
paddle::platform::XPUVersion version = dev_ctx.xpu_version();
if (version == paddle::platform::XPUVersion::XPU1) {
xpu::ctx_guard RAII_GUARD(dev_ctx.x_context());
XPUType* mask_new = RAII_GUARD.alloc_l3_or_gm<XPUType>(mask->numel());
float scale = float scale =
(dropout_prob == 1.0f) ? (1.0f) : (1.0f / (1.0f - dropout_prob)); (dropout_prob == 1.0f) ? (1.0f) : (1.0f / (1.0f - dropout_prob));
int r = xpu::scale(dev_ctx.x_context(), int r = xpu::scale(dev_ctx.x_context(),
reinterpret_cast<const XPUTyp*>(mask->data<T>()), reinterpret_cast<const XPUType*>(mask->data<T>()),
reinterpret_cast<XPUTyp*>(mask_new.data<T>()), reinterpret_cast<XPUType*>(mask_new), mask->numel(),
mask->numel(), false, scale, 0.0f); false, scale, 0.0f);
PADDLE_ENFORCE_EQ(r, XPU_SUCCESS, platform::errors::External( PADDLE_ENFORCE_XDNN_SUCCESS(r, "scale ");
"XPU API(scale) return wrong " r = xpu::mul(dev_ctx.x_context(),
"value[%d %s]", reinterpret_cast<const XPUType*>(grad_y->data<T>()),
r, XPUAPIErrorMsg[r])); reinterpret_cast<const XPUType*>(mask_new),
mask_data = mask_new.data<T>(); reinterpret_cast<XPUType*>(grad_x->data<T>()),
grad_y->numel());
PADDLE_ENFORCE_XDNN_SUCCESS(r, "mul ");
} else {
int r =
xpu::dropout_grad(dev_ctx.x_context(),
reinterpret_cast<const XPUType*>(mask->data<T>()),
reinterpret_cast<const XPUType*>(grad_y->data<T>()),
reinterpret_cast<XPUType*>(grad_x->data<T>()),
dropout_prob, grad_y->numel());
PADDLE_ENFORCE_XDNN_SUCCESS(r, "dropout_grad ");
} }
int r = xpu::mul(
dev_ctx.x_context(), reinterpret_cast<const XPUTyp*>(grad_y->data<T>()),
reinterpret_cast<const XPUTyp*>(mask_data),
reinterpret_cast<XPUTyp*>(grad_x->data<T>()), grad_y->numel());
PADDLE_ENFORCE_EQ(r, XPU_SUCCESS,
platform::errors::External("XPU API(mul) return wrong "
"value[%d %s]",
r, XPUAPIErrorMsg[r]));
} }
}; };
} // namespace operators } // namespace operators
......
...@@ -19,6 +19,7 @@ limitations under the License. */ ...@@ -19,6 +19,7 @@ limitations under the License. */
#include "paddle/fluid/operators/elementwise/elementwise_op.h" #include "paddle/fluid/operators/elementwise/elementwise_op.h"
#include "paddle/fluid/operators/elementwise/elementwise_xpu.h" #include "paddle/fluid/operators/elementwise/elementwise_xpu.h"
#include "paddle/fluid/platform/device/device_wrapper.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
...@@ -106,39 +107,43 @@ class ElementwiseAddGradXPUKernel : public ElemwiseGradKernel<T> { ...@@ -106,39 +107,43 @@ class ElementwiseAddGradXPUKernel : public ElemwiseGradKernel<T> {
const T* dz_data = dz->data<T>(); const T* dz_data = dz->data<T>();
auto& dev_ctx = auto& dev_ctx =
ctx.template device_context<paddle::platform::XPUDeviceContext>(); ctx.template device_context<paddle::platform::XPUDeviceContext>();
if (dx != nullptr) { if (dx != nullptr) {
T* dx_data = dx->mutable_data<T>(ctx.GetPlace());
if (rdims_for_x.size() == 0) { if (rdims_for_x.size() == 0) {
if (dx_data != dz_data) {
framework::TensorCopy( framework::TensorCopy(
*dz, ctx.GetPlace(), *dz, ctx.GetPlace(),
ctx.template device_context<platform::DeviceContext>(), dx); ctx.template device_context<platform::DeviceContext>(), dx);
}
} else { } else {
T* dx_data = dx->mutable_data<T>(ctx.GetPlace()); // For inplace strategy, dx will be stored in addr of dz, which makes
// the result of dy wrong.
if (dx->IsSharedBufferWith(*dz)) {
dx->clear();
dx->mutable_data<T>(x->dims(), ctx.GetPlace());
}
int ret = xpu::reduce_sum<XPUType>( int ret = xpu::reduce_sum<XPUType>(
dev_ctx.x_context(), reinterpret_cast<const XPUType*>(dz_data), dev_ctx.x_context(), reinterpret_cast<const XPUType*>(dz_data),
reinterpret_cast<XPUType*>(dx_data), z_dims_vec, rdims_for_x); reinterpret_cast<XPUType*>(dx_data), z_dims_vec, rdims_for_x);
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_XDNN_SUCCESS(ret, "reduce_sum ");
ret, xpu::SUCCESS,
platform::errors::External("XPU kernel reduce_sum occur error in "
"XPUElementwise error code ",
ret, XPUAPIErrorMsg[ret]));
} }
} }
if (dy != nullptr) { if (dy != nullptr) {
T* dy_data = dy->mutable_data<T>(ctx.GetPlace());
if (rdims_for_y.size() == 0) { if (rdims_for_y.size() == 0) {
if (dy_data != dz_data) {
framework::TensorCopy( framework::TensorCopy(
*dz, ctx.GetPlace(), *dz, ctx.GetPlace(),
ctx.template device_context<platform::DeviceContext>(), dy); ctx.template device_context<platform::DeviceContext>(), dy);
}
} else { } else {
T* dy_data = dy->mutable_data<T>(ctx.GetPlace());
int ret = xpu::reduce_sum<XPUType>( int ret = xpu::reduce_sum<XPUType>(
dev_ctx.x_context(), reinterpret_cast<const XPUType*>(dz_data), dev_ctx.x_context(), reinterpret_cast<const XPUType*>(dz_data),
reinterpret_cast<XPUType*>(dy_data), z_dims_vec, rdims_for_y); reinterpret_cast<XPUType*>(dy_data), z_dims_vec, rdims_for_y);
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_XDNN_SUCCESS(ret, "reduce_sum ");
ret, xpu::SUCCESS,
platform::errors::External("XPU kernel reduce_sum occur error in "
"XPUElementwise error code ",
ret, XPUAPIErrorMsg[ret]));
} }
} }
} }
......
...@@ -42,8 +42,10 @@ class MaskedSelectXPUKernel : public framework::OpKernel<T> { ...@@ -42,8 +42,10 @@ class MaskedSelectXPUKernel : public framework::OpKernel<T> {
int* out_size = RAII_GUARD.alloc_l3_or_gm<int32_t>(1); int* out_size = RAII_GUARD.alloc_l3_or_gm<int32_t>(1);
int out_size_cpu; int out_size_cpu;
PADDLE_ENFORCE_XPU_SUCCESS(xpu::nonzero_count( PADDLE_ENFORCE_XDNN_SUCCESS(
dev_ctx.x_context(), mask_data, out_size, mask->numel())); xpu::nonzero_count(dev_ctx.x_context(), mask_data, out_size,
mask->numel()),
"nonzero_count ");
memory::Copy(platform::CPUPlace(), static_cast<void*>(&out_size_cpu), memory::Copy(platform::CPUPlace(), static_cast<void*>(&out_size_cpu),
BOOST_GET_CONST(platform::XPUPlace, mask->place()), BOOST_GET_CONST(platform::XPUPlace, mask->place()),
static_cast<void*>(out_size), sizeof(int32_t)); static_cast<void*>(out_size), sizeof(int32_t));
...@@ -55,9 +57,10 @@ class MaskedSelectXPUKernel : public framework::OpKernel<T> { ...@@ -55,9 +57,10 @@ class MaskedSelectXPUKernel : public framework::OpKernel<T> {
auto input_shape = framework::vectorize<int>(input_dim); auto input_shape = framework::vectorize<int>(input_dim);
auto mask_shape = framework::vectorize<int>(mask_dim); auto mask_shape = framework::vectorize<int>(mask_dim);
PADDLE_ENFORCE_XPU_SUCCESS( PADDLE_ENFORCE_XDNN_SUCCESS(
xpu::masked_select(dev_ctx.x_context(), input_data, mask_data, out_data, xpu::masked_select(dev_ctx.x_context(), input_data, mask_data, out_data,
input_shape, mask_shape, out_size_cpu)); input_shape, mask_shape, out_size_cpu),
"masked_select");
} }
}; };
......
...@@ -4,7 +4,7 @@ endif() ...@@ -4,7 +4,7 @@ endif()
set(XPU_CTX_DEPS xpulib ssl crypto rt z resolv dl) set(XPU_CTX_DEPS xpulib ssl crypto rt z resolv dl)
cc_library(xpu_info SRCS xpu_info.cc DEPS gflags glog enforce xpulib) cc_library(xpu_info SRCS xpu_info.cc DEPS gflags glog enforce xpulib device_context place)
cc_library(xpu_op_list SRCS xpu_op_list.cc DEPS gflags glog enforce xpulib device_context) cc_library(xpu_op_list SRCS xpu_op_list.cc DEPS gflags glog enforce xpulib device_context)
add_subdirectory(tests) add_subdirectory(tests)
...@@ -113,6 +113,23 @@ inline const char* bkclGetErrorString(BKCLResult_t stat) { ...@@ -113,6 +113,23 @@ inline const char* bkclGetErrorString(BKCLResult_t stat) {
} }
} }
inline const char* xdnnGetErrorString(int stat) {
switch (stat) {
case xpu::Error_t::SUCCESS:
return "XDNN_SUCCESS";
case xpu::Error_t::INVALID_PARAM:
return "XDNN_INVALID_PARAM";
case xpu::Error_t::RUNTIME_ERROR:
return "XDNN_RUNTIME_ERROR";
case xpu::Error_t::NO_ENOUGH_WORKSPACE:
return "XDNN_NO_ENOUGH_WORKSPACE";
case xpu::Error_t::NOT_IMPLEMENT:
return "XDNN_NOT_IMPLEMENT";
default:
return "Unknown XDNN status";
}
}
inline std::string build_xpu_error_msg(int stat) { inline std::string build_xpu_error_msg(int stat) {
std::string msg("XPU Error <" + std::to_string(stat) + ">, "); std::string msg("XPU Error <" + std::to_string(stat) + ">, ");
return msg + xpuGetErrorString(stat) + " "; return msg + xpuGetErrorString(stat) + " ";
...@@ -123,6 +140,10 @@ inline std::string build_xpu_error_msg(BKCLResult_t stat) { ...@@ -123,6 +140,10 @@ inline std::string build_xpu_error_msg(BKCLResult_t stat) {
return msg + bkclGetErrorString(stat) + " "; return msg + bkclGetErrorString(stat) + " ";
} }
inline std::string build_xpu_xdnn_error_msg(int stat, std::string msg) {
return msg + " XDNN Error, " + xdnnGetErrorString(stat) + " ";
}
namespace details { namespace details {
template <typename T> template <typename T>
...@@ -156,5 +177,15 @@ DEFINE_EXTERNAL_API_TYPE(BKCLResult_t, BKCL_SUCCESS); ...@@ -156,5 +177,15 @@ DEFINE_EXTERNAL_API_TYPE(BKCLResult_t, BKCL_SUCCESS);
} \ } \
} while (0) } while (0)
#define PADDLE_ENFORCE_XDNN_SUCCESS(COND, MSG) \
do { \
auto __cond__ = (COND); \
if (UNLIKELY(__cond__ != xpu::Error_t::SUCCESS)) { \
auto __summary__ = paddle::platform::errors::External( \
::paddle::platform::build_xpu_xdnn_error_msg(__cond__, MSG)); \
__THROW_ERROR_INTERNAL__(__summary__); \
} \
} while (0)
} // namespace platform } // namespace platform
} // namespace paddle } // namespace paddle
...@@ -33,6 +33,24 @@ bool CheckXPUStatusFailure(T value, const std::string& msg) { ...@@ -33,6 +33,24 @@ bool CheckXPUStatusFailure(T value, const std::string& msg) {
} }
} }
template <typename T>
bool CheckXDNNStatusSuccess(T value, const std::string& msg = "success") {
PADDLE_ENFORCE_XDNN_SUCCESS(value, "XDNN Error ");
return true;
}
template <typename T>
bool CheckXDNNStatusFailure(T value, const std::string& msg) {
try {
PADDLE_ENFORCE_XDNN_SUCCESS(value, "XDNN Error ");
return false;
} catch (paddle::platform::EnforceNotMet& error) {
std::string ex_msg = error.what();
std::cout << ex_msg << std::endl;
return ex_msg.find(msg) != std::string::npos;
}
}
TEST(enforce, xpu_status) { TEST(enforce, xpu_status) {
EXPECT_TRUE(CheckXPUStatusSuccess(static_cast<int>(XPU_SUCCESS))); EXPECT_TRUE(CheckXPUStatusSuccess(static_cast<int>(XPU_SUCCESS)));
EXPECT_TRUE(CheckXPUStatusFailure(static_cast<int>(XPUERR_INVALID_DEVICE), EXPECT_TRUE(CheckXPUStatusFailure(static_cast<int>(XPUERR_INVALID_DEVICE),
...@@ -114,3 +132,15 @@ TEST(enforce, bkcl_status) { ...@@ -114,3 +132,15 @@ TEST(enforce, bkcl_status) {
EXPECT_TRUE( EXPECT_TRUE(
CheckXPUStatusFailure(BKCL_INTERNAL_ERROR, "BKCL_INTERNAL_ERROR")); CheckXPUStatusFailure(BKCL_INTERNAL_ERROR, "BKCL_INTERNAL_ERROR"));
} }
TEST(enforce, xdnn_status) {
EXPECT_TRUE(CheckXDNNStatusSuccess(xpu::Error_t::SUCCESS));
EXPECT_TRUE(CheckXDNNStatusFailure(xpu::Error_t::INVALID_PARAM,
"XDNN_INVALID_PARAM"));
EXPECT_TRUE(CheckXDNNStatusFailure(xpu::Error_t::RUNTIME_ERROR,
"XDNN_RUNTIME_ERROR"));
EXPECT_TRUE(CheckXDNNStatusFailure(xpu::Error_t::NO_ENOUGH_WORKSPACE,
"XDNN_NO_ENOUGH_WORKSPACE"));
EXPECT_TRUE(CheckXDNNStatusFailure(xpu::Error_t::NOT_IMPLEMENT,
"XDNN_NOT_IMPLEMENT"));
}
...@@ -14,8 +14,11 @@ limitations under the License. */ ...@@ -14,8 +14,11 @@ limitations under the License. */
#include <cstdlib> #include <cstdlib>
#include <string> #include <string>
#include "gflags/gflags.h" #include "gflags/gflags.h"
#include "paddle/fluid/platform/device/device_wrapper.h"
#include "paddle/fluid/platform/device/xpu/enforce_xpu.h" #include "paddle/fluid/platform/device/xpu/enforce_xpu.h"
#include "paddle/fluid/platform/device/xpu/xpu_header.h" #include "paddle/fluid/platform/device/xpu/xpu_header.h"
#include "paddle/fluid/platform/device_context.h"
#include "paddle/fluid/platform/place.h"
#include "paddle/fluid/string/split.h" #include "paddle/fluid/string/split.h"
PADDLE_DEFINE_EXPORTED_string( PADDLE_DEFINE_EXPORTED_string(
...@@ -56,7 +59,7 @@ int GetRuntimeVersion() { ...@@ -56,7 +59,7 @@ int GetRuntimeVersion() {
/**************************** Device Management **************************/ /**************************** Device Management **************************/
static int GetDeviceCountImpl() { static int GetDeviceCountImpl() {
const auto *xpu_visible_devices = std::getenv("XPU_VISIBLE_DEVICES"); const auto* xpu_visible_devices = std::getenv("XPU_VISIBLE_DEVICES");
if (xpu_visible_devices != nullptr) { if (xpu_visible_devices != nullptr) {
std::string xpu_visible_devices_str(xpu_visible_devices); std::string xpu_visible_devices_str(xpu_visible_devices);
if (std::all_of(xpu_visible_devices_str.begin(), if (std::all_of(xpu_visible_devices_str.begin(),
...@@ -114,28 +117,39 @@ std::vector<int> GetXPUSelectedDevices() { ...@@ -114,28 +117,39 @@ std::vector<int> GetXPUSelectedDevices() {
/**************************** Memory Management **************************/ /**************************** Memory Management **************************/
void MemcpySyncH2D(void *dst, const void *src, size_t count, int dev_id) { void MemcpySyncH2D(void* dst, const void* src, size_t count,
platform::XPUDeviceGuard guard(dev_id); const platform::XPUPlace& dst_place) {
platform::XPUDeviceGuard guard(dst_place.device);
PADDLE_ENFORCE_XPU_SUCCESS( PADDLE_ENFORCE_XPU_SUCCESS(
xpu_memcpy(dst, src, count, XPUMemcpyKind::XPU_HOST_TO_DEVICE)); xpu_memcpy(dst, src, count, XPUMemcpyKind::XPU_HOST_TO_DEVICE));
} }
void MemcpySyncD2H(void *dst, const void *src, size_t count, int dev_id) { void MemcpySyncD2H(void* dst, const void* src, size_t count,
platform::XPUDeviceGuard guard(dev_id); const platform::XPUPlace& src_place) {
platform::XPUDeviceGuard guard(src_place.device);
platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
auto* dev_ctx = pool.GetByPlace(src_place);
dev_ctx->Wait();
PADDLE_ENFORCE_XPU_SUCCESS( PADDLE_ENFORCE_XPU_SUCCESS(
xpu_memcpy(dst, src, count, XPUMemcpyKind::XPU_DEVICE_TO_HOST)); xpu_memcpy(dst, src, count, XPUMemcpyKind::XPU_DEVICE_TO_HOST));
} }
void MemcpySyncD2D(void *dst, int dst_id, const void *src, int src_id, // if src.device == dst.device and you need sync , after call this function,
// need to call xpu_wait()
void MemcpySyncD2D(void* dst, const platform::XPUPlace& dst_place,
const void* src, const platform::XPUPlace& src_place,
size_t count) { size_t count) {
int dev_id = GetXPUCurrentDeviceId(); int dev_id = GetXPUCurrentDeviceId();
if (dst_id == dev_id && src_id == dev_id) { if (dst_place.device == dev_id && src_place.device == dev_id) {
platform::XPUDeviceGuard guard(dev_id); platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
PADDLE_ENFORCE_XPU_SUCCESS( auto* dev_ctx = pool.GetByPlace(src_place);
xpu_memcpy(dst, src, count, XPUMemcpyKind::XPU_DEVICE_TO_DEVICE)); PADDLE_ENFORCE_XDNN_SUCCESS(
xpu::copy(dev_ctx->x_context(), static_cast<const int8_t*>(src),
static_cast<int8_t*>(dst), count),
"copy ");
} else { } else {
PADDLE_ENFORCE_XPU_SUCCESS( PADDLE_ENFORCE_XPU_SUCCESS(
xpu_memcpy_peer(dst_id, dst, src_id, src, count)); xpu_memcpy_peer(dst_place.device, dst, src_place.device, src, count));
} }
} }
......
...@@ -16,6 +16,7 @@ limitations under the License. */ ...@@ -16,6 +16,7 @@ limitations under the License. */
namespace paddle { namespace paddle {
namespace platform { namespace platform {
class XPUPlace;
/***** Version Management *****/ /***** Version Management *****/
//! Get the version of XPU Driver //! Get the version of XPU Driver
...@@ -41,9 +42,12 @@ std::vector<int> GetXPUSelectedDevices(); ...@@ -41,9 +42,12 @@ std::vector<int> GetXPUSelectedDevices();
/***** Memory Management *****/ /***** Memory Management *****/
//! Copy memory from address src to dst synchronously. //! Copy memory from address src to dst synchronously.
void MemcpySyncH2D(void *dst, const void *src, size_t count, int dev_id); void MemcpySyncH2D(void *dst, const void *src, size_t count,
void MemcpySyncD2H(void *dst, const void *src, size_t count, int dev_id); const platform::XPUPlace &dst_place);
void MemcpySyncD2D(void *dst, int dst_id, const void *src, int src_id, void MemcpySyncD2H(void *dst, const void *src, size_t count,
const platform::XPUPlace &src_place);
void MemcpySyncD2D(void *dst, const platform::XPUPlace &dst_place,
const void *src, const platform::XPUPlace &src_place,
size_t count); size_t count);
class XPUDeviceGuard { class XPUDeviceGuard {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册