Unverified · Commit 7b860a23 authored by taixiurong, committed by GitHub

1.fix elementwise_add_grad bug. 2. add dropout kernel in kl2 (#38726)

Parent 066a8063
@@ -488,6 +488,14 @@ void TensorCopySync(const Tensor& src, const platform::Place& dst_place,
     }
     memory::Copy(BOOST_GET_CONST(platform::XPUPlace, dst_place), dst_ptr,
                  BOOST_GET_CONST(platform::XPUPlace, src_place), src_ptr, size);
+    platform::XPUPlace xpu_dst_place =
+        BOOST_GET_CONST(platform::XPUPlace, dst_place);
+    platform::XPUPlace xpu_src_place =
+        BOOST_GET_CONST(platform::XPUPlace, src_place);
+    if (xpu_dst_place.device == xpu_src_place.device) {
+      auto xpu_ctx = platform::DeviceContextPool::Instance().Get(xpu_dst_place);
+      xpu_ctx->Wait();
+    }
   }
   else {  // NOLINT
     PADDLE_THROW(platform::errors::Unimplemented(
......
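Why the added wait: with this change a same-device XPU-to-XPU copy goes through xpu::copy on the device context's stream (see the MemcpySyncD2D hunk further down) and is therefore asynchronous, so TensorCopySync must block before returning to keep its synchronous contract. A minimal caller-side sketch, assuming a Paddle build with XPU support and an available device 0 (the helper name and device id are illustrative, not from the commit):

#include "paddle/fluid/framework/tensor_util.h"

// Sketch: TensorCopySync between two tensors on the same XPU device.
// After this change the function itself waits on the device context,
// so the destination is safe to consume on return.
void SameDeviceCopySketch() {
  paddle::platform::XPUPlace place(0);  // hypothetical device id
  paddle::framework::Tensor src;
  paddle::framework::Tensor dst;
  src.mutable_data<float>(paddle::framework::make_ddim({8}), place);
  paddle::framework::TensorCopySync(src, place, &dst);
  // No extra xpu_wait() is needed here: the same-device copy has completed.
}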
@@ -66,7 +66,7 @@ void Copy<platform::XPUPlace, platform::CPUPlace>(platform::XPUPlace dst_place,
     VLOG(1) << "memcpy XPU_HOST_TO_DEVICE size <= 0 (" << num << ")";
     return;
   }
-  platform::MemcpySyncH2D(dst, src, num, dst_place.device);
+  platform::MemcpySyncH2D(dst, src, num, dst_place);
 }

 template <>
@@ -78,7 +78,7 @@ void Copy<platform::CPUPlace, platform::XPUPlace>(platform::CPUPlace dst_place,
     VLOG(1) << "memcpy XPU_DEVICE_TO_HOST size <= 0 (" << num << ")";
     return;
   }
-  platform::MemcpySyncD2H(dst, src, num, src_place.device);
+  platform::MemcpySyncD2H(dst, src, num, src_place);
 }

 template <>
@@ -90,7 +90,7 @@ void Copy<platform::XPUPlace, platform::XPUPlace>(platform::XPUPlace dst_place,
     VLOG(1) << "memcpy XPU_DEVICE_TO_DEVICE size <= 0 (" << num << ")";
     return;
   }
-  platform::MemcpySyncD2D(dst, dst_place.device, src, src_place.device, num);
+  platform::MemcpySyncD2D(dst, dst_place, src, src_place, num);
 }
 #endif
......
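These three specializations now forward the full XPUPlace instead of a bare device id, which is what lets the platform layer (further down) look up the matching device context for the D2H wait and the same-device D2D copy. The public memory::Copy API is unchanged; a hedged sketch of a host-to-device call (the function name and pointer provenance are illustrative):

#include <vector>

#include "paddle/fluid/memory/memcpy.h"

// Sketch: host-to-XPU copy through the unchanged memory::Copy entry point.
// The place arguments carry the device id that MemcpySyncH2D now receives
// as part of a whole XPUPlace.
void HostToXpuSketch(float* dev_ptr /* assumed allocated on XPU 0 */) {
  paddle::platform::XPUPlace xpu(0);
  paddle::platform::CPUPlace cpu;
  std::vector<float> host(4, 1.0f);
  paddle::memory::Copy(xpu, dev_ptr, cpu, host.data(),
                       host.size() * sizeof(float));
}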
@@ -11,7 +11,7 @@ limitations under the License. */
 #include "paddle/fluid/operators/dropout_op.h"
 #include <memory>
 #include <string>
-#include "paddle/fluid/platform/device/xpu/xpu_header.h"
+#include "paddle/fluid/platform/device/device_wrapper.h"

 namespace paddle {
 namespace operators {
@@ -55,17 +55,11 @@ class DropoutXPUKernel : public framework::OpKernel<T> {
         int r = xpu::constant(dev_ctx.x_context(),
                               reinterpret_cast<XPUTyp*>(y_data), y->numel(),
                               XPUTyp(0));
-        PADDLE_ENFORCE_EQ(r, XPU_SUCCESS, platform::errors::External(
-                                              "XPU API(constant) return wrong "
-                                              "value[%d %s]",
-                                              r, XPUAPIErrorMsg[r]));
+        PADDLE_ENFORCE_XDNN_SUCCESS(r, "constant ");
         r = xpu::constant(dev_ctx.x_context(),
                           reinterpret_cast<XPUTyp*>(mask_data), mask->numel(),
                           XPUTyp(0));
-        PADDLE_ENFORCE_EQ(r, XPU_SUCCESS, platform::errors::External(
-                                              "XPU API(constant) return wrong "
-                                              "value[%d %s]",
-                                              r, XPUAPIErrorMsg[r]));
+        PADDLE_ENFORCE_XDNN_SUCCESS(r, "constant ");
         return;
       }
       int r = xpu::dropout(dev_ctx.x_context(),
@@ -73,26 +67,20 @@ class DropoutXPUKernel : public framework::OpKernel<T> {
                            reinterpret_cast<XPUTyp*>(y->data<T>()),
                            reinterpret_cast<XPUTyp*>(mask_data), seed,
                            mask->numel(), is_upscale, dropout_prob);
-      PADDLE_ENFORCE_EQ(r, XPU_SUCCESS, platform::errors::External(
-                                            "XPU API(dropout) return wrong "
-                                            "value[%d %s]",
-                                            r, XPUAPIErrorMsg[r]));
+      PADDLE_ENFORCE_XDNN_SUCCESS(r, "dropout ");
     } else {
       float scale =
           (is_upscale) ? (1.0) : (static_cast<float>(1.0f - dropout_prob));
       int r = xpu::scale(
           dev_ctx.x_context(), reinterpret_cast<const XPUTyp*>(x_data),
           reinterpret_cast<XPUTyp*>(y_data), x->numel(), false, scale, 0.0f);
-      PADDLE_ENFORCE_EQ(r, XPU_SUCCESS, platform::errors::External(
-                                            "XPU API(scale) return wrong "
-                                            "value[%d %s]",
-                                            r, XPUAPIErrorMsg[r]));
+      PADDLE_ENFORCE_XDNN_SUCCESS(r, "scale ");
     }
   }
 };

 template <typename DeviceContext, typename T>
 class DropoutGradXPUKernel : public framework::OpKernel<T> {
-  using XPUTyp = typename XPUTypeTrait<T>::Type;
+  using XPUType = typename XPUTypeTrait<T>::Type;

  public:
   void Compute(const framework::ExecutionContext& context) const override {
@@ -108,31 +96,43 @@ class DropoutGradXPUKernel : public framework::OpKernel<T> {
         context.Attr<std::string>("dropout_implementation");
     float dropout_prob = context.Attr<float>("dropout_prob");
     const T* mask_data = mask->data<T>();
-    framework::Tensor mask_new;
-    if (dropout_implementation == "upscale_in_train") {
-      mask_new = context.AllocateTmpTensor<T, platform::XPUDeviceContext>(
-          mask->dims(), dev_ctx);
+    if (dropout_implementation != "upscale_in_train") {
+      int r = xpu::mul(dev_ctx.x_context(),
+                       reinterpret_cast<const XPUType*>(grad_y->data<T>()),
+                       reinterpret_cast<const XPUType*>(mask_data),
+                       reinterpret_cast<XPUType*>(grad_x->data<T>()),
+                       grad_y->numel());
+      PADDLE_ENFORCE_XDNN_SUCCESS(r, "mul ");
+      return;
+    }
+    paddle::platform::XPUVersion version = dev_ctx.xpu_version();
+    if (version == paddle::platform::XPUVersion::XPU1) {
+      xpu::ctx_guard RAII_GUARD(dev_ctx.x_context());
+      XPUType* mask_new = RAII_GUARD.alloc_l3_or_gm<XPUType>(mask->numel());
       float scale =
           (dropout_prob == 1.0f) ? (1.0f) : (1.0f / (1.0f - dropout_prob));
       int r = xpu::scale(dev_ctx.x_context(),
-                         reinterpret_cast<const XPUTyp*>(mask->data<T>()),
-                         reinterpret_cast<XPUTyp*>(mask_new.data<T>()),
-                         mask->numel(), false, scale, 0.0f);
-      PADDLE_ENFORCE_EQ(r, XPU_SUCCESS, platform::errors::External(
-                                            "XPU API(scale) return wrong "
-                                            "value[%d %s]",
-                                            r, XPUAPIErrorMsg[r]));
-      mask_data = mask_new.data<T>();
-    }
-    int r = xpu::mul(
-        dev_ctx.x_context(), reinterpret_cast<const XPUTyp*>(grad_y->data<T>()),
-        reinterpret_cast<const XPUTyp*>(mask_data),
-        reinterpret_cast<XPUTyp*>(grad_x->data<T>()), grad_y->numel());
-    PADDLE_ENFORCE_EQ(r, XPU_SUCCESS,
-                      platform::errors::External("XPU API(mul) return wrong "
-                                                 "value[%d %s]",
-                                                 r, XPUAPIErrorMsg[r]));
+                         reinterpret_cast<const XPUType*>(mask->data<T>()),
+                         reinterpret_cast<XPUType*>(mask_new), mask->numel(),
+                         false, scale, 0.0f);
+      PADDLE_ENFORCE_XDNN_SUCCESS(r, "scale ");
+      r = xpu::mul(dev_ctx.x_context(),
+                   reinterpret_cast<const XPUType*>(grad_y->data<T>()),
+                   reinterpret_cast<const XPUType*>(mask_new),
+                   reinterpret_cast<XPUType*>(grad_x->data<T>()),
+                   grad_y->numel());
+      PADDLE_ENFORCE_XDNN_SUCCESS(r, "mul ");
+    } else {
+      int r =
+          xpu::dropout_grad(dev_ctx.x_context(),
+                            reinterpret_cast<const XPUType*>(mask->data<T>()),
+                            reinterpret_cast<const XPUType*>(grad_y->data<T>()),
+                            reinterpret_cast<XPUType*>(grad_x->data<T>()),
+                            dropout_prob, grad_y->numel());
+      PADDLE_ENFORCE_XDNN_SUCCESS(r, "dropout_grad ");
+    }
   }
 };

 }  // namespace operators
......
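The grad kernel now dispatches on the device generation: XPU1 has no fused dropout_grad, so the mask is first rescaled with xpu::scale into scratch memory and then multiplied into the upstream gradient with xpu::mul, while KL2 (XPU2) makes a single xpu::dropout_grad call. For upscale_in_train both branches compute dx = dy * mask / (1 - p). A plain-C++ reference of that math (not Paddle code), handy for checking either device path:

#include <cstddef>

// Reference for the upscale_in_train dropout gradient that both device
// branches implement:
//   dx[i] = dy[i] * mask[i] * (1 / (1 - p))
// with p == 1 mapped to scale 1, mirroring the (dropout_prob == 1.0f)
// guard in the XPU1 branch above.
void DropoutGradReference(const float* dy, const float* mask, float* dx,
                          std::size_t n, float p) {
  const float scale = (p == 1.0f) ? 1.0f : 1.0f / (1.0f - p);
  for (std::size_t i = 0; i < n; ++i) {
    dx[i] = dy[i] * mask[i] * scale;
  }
}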
@@ -19,6 +19,7 @@ limitations under the License. */
 #include "paddle/fluid/operators/elementwise/elementwise_op.h"
 #include "paddle/fluid/operators/elementwise/elementwise_xpu.h"
+#include "paddle/fluid/platform/device/device_wrapper.h"

 namespace paddle {
 namespace operators {
@@ -106,39 +107,43 @@ class ElementwiseAddGradXPUKernel : public ElemwiseGradKernel<T> {
     const T* dz_data = dz->data<T>();
     auto& dev_ctx =
         ctx.template device_context<paddle::platform::XPUDeviceContext>();

     if (dx != nullptr) {
+      T* dx_data = dx->mutable_data<T>(ctx.GetPlace());
       if (rdims_for_x.size() == 0) {
-        framework::TensorCopy(
-            *dz, ctx.GetPlace(),
-            ctx.template device_context<platform::DeviceContext>(), dx);
+        if (dx_data != dz_data) {
+          framework::TensorCopy(
+              *dz, ctx.GetPlace(),
+              ctx.template device_context<platform::DeviceContext>(), dx);
+        }
       } else {
-        T* dx_data = dx->mutable_data<T>(ctx.GetPlace());
+        // For inplace strategy, dx will be stored in addr of dz, which makes
+        // the result of dy wrong.
+        if (dx->IsSharedBufferWith(*dz)) {
+          dx->clear();
+          dx->mutable_data<T>(x->dims(), ctx.GetPlace());
+        }
         int ret = xpu::reduce_sum<XPUType>(
             dev_ctx.x_context(), reinterpret_cast<const XPUType*>(dz_data),
             reinterpret_cast<XPUType*>(dx_data), z_dims_vec, rdims_for_x);
-        PADDLE_ENFORCE_EQ(
-            ret, xpu::SUCCESS,
-            platform::errors::External("XPU kernel reduce_sum occur error in "
-                                       "XPUElementwise error code ",
-                                       ret, XPUAPIErrorMsg[ret]));
+        PADDLE_ENFORCE_XDNN_SUCCESS(ret, "reduce_sum ");
       }
     }

     if (dy != nullptr) {
+      T* dy_data = dy->mutable_data<T>(ctx.GetPlace());
       if (rdims_for_y.size() == 0) {
-        framework::TensorCopy(
-            *dz, ctx.GetPlace(),
-            ctx.template device_context<platform::DeviceContext>(), dy);
+        if (dy_data != dz_data) {
+          framework::TensorCopy(
+              *dz, ctx.GetPlace(),
+              ctx.template device_context<platform::DeviceContext>(), dy);
+        }
       } else {
-        T* dy_data = dy->mutable_data<T>(ctx.GetPlace());
         int ret = xpu::reduce_sum<XPUType>(
             dev_ctx.x_context(), reinterpret_cast<const XPUType*>(dz_data),
             reinterpret_cast<XPUType*>(dy_data), z_dims_vec, rdims_for_y);
-        PADDLE_ENFORCE_EQ(
-            ret, xpu::SUCCESS,
-            platform::errors::External("XPU kernel reduce_sum occur error in "
-                                       "XPUElementwise error code ",
-                                       ret, XPUAPIErrorMsg[ret]));
+        PADDLE_ENFORCE_XDNN_SUCCESS(ret, "reduce_sum ");
       }
     }
   }
......
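The elementwise_add_grad bug this commit fixes is an aliasing one: with the inplace strategy dx shares dz's buffer, so the old unconditional TensorCopy was wasted work in the no-reduce path, and in the reduce path writing the reduced dx into dz's storage corrupted dz before dy was computed from it (hence the added comment). The fix skips the copy when the pointers already match and gives dx fresh storage when it shares dz's buffer. A standalone illustration of the hazard (not Paddle code):

#include <vector>

// Illustration only: if dx aliases dz, reducing dz into dx overwrites dz
// before dy is computed from it -- exactly the "result of dy wrong" case
// named in the comment above. Fresh storage for dx avoids this.
void ReduceAliasHazard() {
  std::vector<float> dz = {1.0f, 2.0f, 3.0f, 4.0f};
  float* dx = dz.data();                   // inplace: dx shares dz's buffer
  dx[0] = dz[0] + dz[1] + dz[2] + dz[3];   // dx[0] = 10; dz[0] is destroyed
  // A later reduction over dz (to form dy) would now read 10, 2, 3, 4.
}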
@@ -42,8 +42,10 @@ class MaskedSelectXPUKernel : public framework::OpKernel<T> {
     int* out_size = RAII_GUARD.alloc_l3_or_gm<int32_t>(1);
     int out_size_cpu;

-    PADDLE_ENFORCE_XPU_SUCCESS(xpu::nonzero_count(
-        dev_ctx.x_context(), mask_data, out_size, mask->numel()));
+    PADDLE_ENFORCE_XDNN_SUCCESS(
+        xpu::nonzero_count(dev_ctx.x_context(), mask_data, out_size,
+                           mask->numel()),
+        "nonzero_count ");
     memory::Copy(platform::CPUPlace(), static_cast<void*>(&out_size_cpu),
                  BOOST_GET_CONST(platform::XPUPlace, mask->place()),
                  static_cast<void*>(out_size), sizeof(int32_t));
@@ -55,9 +57,10 @@ class MaskedSelectXPUKernel : public framework::OpKernel<T> {
     auto input_shape = framework::vectorize<int>(input_dim);
     auto mask_shape = framework::vectorize<int>(mask_dim);

-    PADDLE_ENFORCE_XPU_SUCCESS(
-        xpu::masked_select(dev_ctx.x_context(), input_data, mask_data, out_data,
-                           input_shape, mask_shape, out_size_cpu));
+    PADDLE_ENFORCE_XDNN_SUCCESS(
+        xpu::masked_select(dev_ctx.x_context(), input_data, mask_data, out_data,
+                           input_shape, mask_shape, out_size_cpu),
+        "masked_select");
   }
 };
......
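Note the checker swap here: PADDLE_ENFORCE_XPU_SUCCESS matches XPU runtime calls, whose success code is XPU_SUCCESS, while XDNN operators such as xpu::nonzero_count and xpu::masked_select return xpu::Error_t, which the new macro (defined in the next file) decodes into a readable name. A hedged sketch of the pattern, with an illustrative wrapper function:

#include "paddle/fluid/platform/device/xpu/enforce_xpu.h"

// Sketch: XDNN calls return xpu::Error_t, so the XDNN-aware check reports
// a decoded status name (e.g. "nonzero_count  XDNN Error,
// XDNN_INVALID_PARAM") instead of a raw integer code.
void CountNonzeroChecked(xpu::Context* ctx, const bool* mask, int* out_size,
                         int64_t len) {
  int r = xpu::nonzero_count(ctx, mask, out_size, len);
  PADDLE_ENFORCE_XDNN_SUCCESS(r, "nonzero_count ");
}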
@@ -4,7 +4,7 @@ endif()
 set(XPU_CTX_DEPS xpulib ssl crypto rt z resolv dl)

-cc_library(xpu_info SRCS xpu_info.cc DEPS gflags glog enforce xpulib)
+cc_library(xpu_info SRCS xpu_info.cc DEPS gflags glog enforce xpulib device_context place)
 cc_library(xpu_op_list SRCS xpu_op_list.cc DEPS gflags glog enforce xpulib device_context)

 add_subdirectory(tests)
@@ -113,6 +113,23 @@ inline const char* bkclGetErrorString(BKCLResult_t stat) {
   }
 }

+inline const char* xdnnGetErrorString(int stat) {
+  switch (stat) {
+    case xpu::Error_t::SUCCESS:
+      return "XDNN_SUCCESS";
+    case xpu::Error_t::INVALID_PARAM:
+      return "XDNN_INVALID_PARAM";
+    case xpu::Error_t::RUNTIME_ERROR:
+      return "XDNN_RUNTIME_ERROR";
+    case xpu::Error_t::NO_ENOUGH_WORKSPACE:
+      return "XDNN_NO_ENOUGH_WORKSPACE";
+    case xpu::Error_t::NOT_IMPLEMENT:
+      return "XDNN_NOT_IMPLEMENT";
+    default:
+      return "Unknown XDNN status";
+  }
+}
+
 inline std::string build_xpu_error_msg(int stat) {
   std::string msg("XPU Error <" + std::to_string(stat) + ">, ");
   return msg + xpuGetErrorString(stat) + " ";
@@ -123,6 +140,10 @@ inline std::string build_xpu_error_msg(BKCLResult_t stat) {
   return msg + bkclGetErrorString(stat) + " ";
 }

+inline std::string build_xpu_xdnn_error_msg(int stat, std::string msg) {
+  return msg + " XDNN Error, " + xdnnGetErrorString(stat) + " ";
+}
+
 namespace details {

 template <typename T>
@@ -156,5 +177,15 @@ DEFINE_EXTERNAL_API_TYPE(BKCLResult_t, BKCL_SUCCESS);
     }                                                                   \
   } while (0)

+#define PADDLE_ENFORCE_XDNN_SUCCESS(COND, MSG)                          \
+  do {                                                                  \
+    auto __cond__ = (COND);                                             \
+    if (UNLIKELY(__cond__ != xpu::Error_t::SUCCESS)) {                  \
+      auto __summary__ = paddle::platform::errors::External(            \
+          ::paddle::platform::build_xpu_xdnn_error_msg(__cond__, MSG)); \
+      __THROW_ERROR_INTERNAL__(__summary__);                            \
+    }                                                                   \
+  } while (0)
+
 }  // namespace platform
 }  // namespace paddle
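Usage mirrors the kernels above: pass the raw XDNN return code plus a short operator tag, and on failure the macro throws an External error whose text comes from build_xpu_xdnn_error_msg, e.g. "constant XDNN Error, XDNN_INVALID_PARAM". A minimal sketch (the wrapper name is illustrative):

#include "paddle/fluid/platform/device/xpu/enforce_xpu.h"

// Sketch: guard any XDNN return code with the new macro. The tag string
// is prefixed to the decoded status in the thrown error message.
void FillZeroChecked(xpu::Context* ctx, float* y, int64_t len) {
  int r = xpu::constant(ctx, y, len, 0.0f);
  PADDLE_ENFORCE_XDNN_SUCCESS(r, "constant");
}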
@@ -33,6 +33,24 @@ bool CheckXPUStatusFailure(T value, const std::string& msg) {
   }
 }

+template <typename T>
+bool CheckXDNNStatusSuccess(T value, const std::string& msg = "success") {
+  PADDLE_ENFORCE_XDNN_SUCCESS(value, "XDNN Error ");
+  return true;
+}
+
+template <typename T>
+bool CheckXDNNStatusFailure(T value, const std::string& msg) {
+  try {
+    PADDLE_ENFORCE_XDNN_SUCCESS(value, "XDNN Error ");
+    return false;
+  } catch (paddle::platform::EnforceNotMet& error) {
+    std::string ex_msg = error.what();
+    std::cout << ex_msg << std::endl;
+    return ex_msg.find(msg) != std::string::npos;
+  }
+}
+
 TEST(enforce, xpu_status) {
   EXPECT_TRUE(CheckXPUStatusSuccess(static_cast<int>(XPU_SUCCESS)));
   EXPECT_TRUE(CheckXPUStatusFailure(static_cast<int>(XPUERR_INVALID_DEVICE),
@@ -114,3 +132,15 @@ TEST(enforce, bkcl_status) {
   EXPECT_TRUE(
       CheckXPUStatusFailure(BKCL_INTERNAL_ERROR, "BKCL_INTERNAL_ERROR"));
 }
+
+TEST(enforce, xdnn_status) {
+  EXPECT_TRUE(CheckXDNNStatusSuccess(xpu::Error_t::SUCCESS));
+  EXPECT_TRUE(CheckXDNNStatusFailure(xpu::Error_t::INVALID_PARAM,
+                                     "XDNN_INVALID_PARAM"));
+  EXPECT_TRUE(CheckXDNNStatusFailure(xpu::Error_t::RUNTIME_ERROR,
+                                     "XDNN_RUNTIME_ERROR"));
+  EXPECT_TRUE(CheckXDNNStatusFailure(xpu::Error_t::NO_ENOUGH_WORKSPACE,
+                                     "XDNN_NO_ENOUGH_WORKSPACE"));
+  EXPECT_TRUE(CheckXDNNStatusFailure(xpu::Error_t::NOT_IMPLEMENT,
+                                     "XDNN_NOT_IMPLEMENT"));
+}
@@ -14,8 +14,11 @@ limitations under the License. */
 #include <cstdlib>
 #include <string>

 #include "gflags/gflags.h"
+#include "paddle/fluid/platform/device/device_wrapper.h"
 #include "paddle/fluid/platform/device/xpu/enforce_xpu.h"
 #include "paddle/fluid/platform/device/xpu/xpu_header.h"
+#include "paddle/fluid/platform/device_context.h"
+#include "paddle/fluid/platform/place.h"
 #include "paddle/fluid/string/split.h"

 PADDLE_DEFINE_EXPORTED_string(
@@ -56,7 +59,7 @@ int GetRuntimeVersion() {
 /**************************** Device Management **************************/

 static int GetDeviceCountImpl() {
-  const auto *xpu_visible_devices = std::getenv("XPU_VISIBLE_DEVICES");
+  const auto* xpu_visible_devices = std::getenv("XPU_VISIBLE_DEVICES");
   if (xpu_visible_devices != nullptr) {
     std::string xpu_visible_devices_str(xpu_visible_devices);
     if (std::all_of(xpu_visible_devices_str.begin(),
@@ -114,28 +117,39 @@ std::vector<int> GetXPUSelectedDevices() {
 /**************************** Memory Management **************************/

-void MemcpySyncH2D(void *dst, const void *src, size_t count, int dev_id) {
-  platform::XPUDeviceGuard guard(dev_id);
+void MemcpySyncH2D(void* dst, const void* src, size_t count,
+                   const platform::XPUPlace& dst_place) {
+  platform::XPUDeviceGuard guard(dst_place.device);
   PADDLE_ENFORCE_XPU_SUCCESS(
       xpu_memcpy(dst, src, count, XPUMemcpyKind::XPU_HOST_TO_DEVICE));
 }

-void MemcpySyncD2H(void *dst, const void *src, size_t count, int dev_id) {
-  platform::XPUDeviceGuard guard(dev_id);
+void MemcpySyncD2H(void* dst, const void* src, size_t count,
+                   const platform::XPUPlace& src_place) {
+  platform::XPUDeviceGuard guard(src_place.device);
+  platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
+  auto* dev_ctx = pool.GetByPlace(src_place);
+  dev_ctx->Wait();
   PADDLE_ENFORCE_XPU_SUCCESS(
       xpu_memcpy(dst, src, count, XPUMemcpyKind::XPU_DEVICE_TO_HOST));
 }

-void MemcpySyncD2D(void *dst, int dst_id, const void *src, int src_id,
+// if src.device == dst.device and you need sync , after call this function,
+// need to call xpu_wait()
+void MemcpySyncD2D(void* dst, const platform::XPUPlace& dst_place,
+                   const void* src, const platform::XPUPlace& src_place,
                    size_t count) {
   int dev_id = GetXPUCurrentDeviceId();
-  if (dst_id == dev_id && src_id == dev_id) {
+  if (dst_place.device == dev_id && src_place.device == dev_id) {
     platform::XPUDeviceGuard guard(dev_id);
-    PADDLE_ENFORCE_XPU_SUCCESS(
-        xpu_memcpy(dst, src, count, XPUMemcpyKind::XPU_DEVICE_TO_DEVICE));
+    platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
+    auto* dev_ctx = pool.GetByPlace(src_place);
+    PADDLE_ENFORCE_XDNN_SUCCESS(
+        xpu::copy(dev_ctx->x_context(), static_cast<const int8_t*>(src),
+                  static_cast<int8_t*>(dst), count),
+        "copy ");
   } else {
     PADDLE_ENFORCE_XPU_SUCCESS(
-        xpu_memcpy_peer(dst_id, dst, src_id, src, count));
+        xpu_memcpy_peer(dst_place.device, dst, src_place.device, src, count));
   }
 }
......
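Two behavioral consequences of this hunk: MemcpySyncD2H now drains the device context before the read-back, so kernels still producing src have finished; and same-device MemcpySyncD2D is no longer a blocking xpu_memcpy but an xpu::copy enqueued on the context's stream, so, as the added comment warns, a caller that needs completion must wait explicitly, which is exactly what the TensorCopySync hunk at the top now does. A hedged caller-side sketch (function name and device id illustrative):

#include <cstddef>

#include "paddle/fluid/platform/device/xpu/xpu_info.h"
#include "paddle/fluid/platform/device_context.h"

// Sketch: same-device D2D is now an asynchronous xpu::copy on the
// context's stream; wait before consuming dst. The cross-device path
// still goes through xpu_memcpy_peer, which per the comment above does
// not require this wait.
void CopyOnDeviceAndWait(void* dst, const void* src, size_t count) {
  paddle::platform::XPUPlace place(0);  // hypothetical device id
  paddle::platform::MemcpySyncD2D(dst, place, src, place, count);
  auto* ctx = paddle::platform::DeviceContextPool::Instance().Get(place);
  ctx->Wait();
}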
@@ -16,6 +16,7 @@ limitations under the License. */
 namespace paddle {
 namespace platform {

+class XPUPlace;

 /***** Version Management *****/

 //! Get the version of XPU Driver
@@ -41,9 +42,12 @@ std::vector<int> GetXPUSelectedDevices();
 /***** Memory Management *****/

 //! Copy memory from address src to dst synchronously.
-void MemcpySyncH2D(void *dst, const void *src, size_t count, int dev_id);
-void MemcpySyncD2H(void *dst, const void *src, size_t count, int dev_id);
-void MemcpySyncD2D(void *dst, int dst_id, const void *src, int src_id,
+void MemcpySyncH2D(void *dst, const void *src, size_t count,
+                   const platform::XPUPlace &dst_place);
+void MemcpySyncD2H(void *dst, const void *src, size_t count,
+                   const platform::XPUPlace &src_place);
+void MemcpySyncD2D(void *dst, const platform::XPUPlace &dst_place,
+                   const void *src, const platform::XPUPlace &src_place,
                    size_t count);

 class XPUDeviceGuard {
......