Unverified commit d19a9b39, authored by taixiurong, committed by GitHub

[XPU AMP] 1. xpu support gradient acc 2. xpu support create tensor in dygraph 3. xpu support update weight params in amp (#36439)

Parent: d3c93942
@@ -87,9 +87,17 @@ class TensorAddFunctor : public boost::static_visitor<> {
 #ifdef PADDLE_WITH_XPU
   void operator()(const platform::XPUPlace& place) {
+    using XPUType = typename XPUTypeTrait<T>::Type;
     platform::XPUDeviceContext* ctx = dynamic_cast<platform::XPUDeviceContext*>(
         platform::DeviceContextPool::Instance().Get(place));
-    xpu::add<T>(ctx->x_context(), x_, y_, y_, static_cast<int>(numel_));
+    int r = xpu::add<XPUType>(
+        ctx->x_context(), reinterpret_cast<const XPUType*>(x_),
+        reinterpret_cast<const XPUType*>(y_), reinterpret_cast<XPUType*>(y_),
+        static_cast<int>(numel_));
+    PADDLE_ENFORCE_EQ(
+        r, XPU_SUCCESS,
+        platform::errors::External("XPU add kernel return wrong value[%d %s]",
+                                   r, XPUAPIErrorMsg[r]));
   }
 #else
   void operator()(const platform::XPUPlace& place) {
@@ -154,6 +162,24 @@ class TensorAddFunctor : public boost::static_visitor<> {
   T* y_;
 };
 
+#ifdef PADDLE_WITH_XPU
+template <typename T>
+void XPUTensorAddFunctor(const platform::Place& place,
+                         const framework::Tensor& src, framework::Tensor* dst) {
+  using XPUType = typename XPUTypeTrait<T>::Type;
+  platform::XPUDeviceContext* ctx = dynamic_cast<platform::XPUDeviceContext*>(
+      platform::DeviceContextPool::Instance().Get(place));
+  const XPUType* x = reinterpret_cast<const XPUType*>(src.data<T>());
+  XPUType* y = reinterpret_cast<XPUType*>(dst->mutable_data<T>(place));
+  int r = xpu::add<XPUType>(ctx->x_context(), x, y, y,
+                            static_cast<int>(src.numel()));
+  PADDLE_ENFORCE_EQ(
+      r, XPU_SUCCESS,
+      platform::errors::External("XPU add kernel return wrong value[%d %s]", r,
+                                 XPUAPIErrorMsg[r]));
+}
+#endif
+
 template <typename DeviceContext, typename T>
 void TensorAddImpl(const framework::Tensor& src, framework::Tensor* dst,
                    const platform::Place& place) {
@@ -226,7 +252,26 @@ void TensorAdd(const framework::Variable& src, framework::Variable* dst) {
     return;
   }
 #endif
+
+#ifdef PADDLE_WITH_XPU
+  if (platform::is_xpu_place(place)) {
+    if (data_type == framework::DataTypeTrait<float>::DataType()) {
+      XPUTensorAddFunctor<float>(place, src_tensor, dst_tensor);
+    } else if (data_type ==
+               framework::DataTypeTrait<platform::float16>::DataType()) {
+      XPUTensorAddFunctor<platform::float16>(place, src_tensor, dst_tensor);
+    } else {
+      PADDLE_THROW(platform::errors::Unimplemented(
+          "Gradient accumulation of data type (%s) on place (%s) is not "
+          "supported in imperative mode",
+          framework::DataTypeToString(data_type), place));
+    }
+    return;
+  }
+#endif
+
   PADDLE_TENSOR_ADD(float);
 #ifndef PADDLE_WITH_XPU
   // NOTE(phlrain): xpu only support float
   PADDLE_TENSOR_ADD(double);
...
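Note: the hunks above route dygraph gradient accumulation on XPU through xpu::add, for both float and float16 via XPUTypeTrait. The following is a minimal Python sketch of the behaviour this enables, assuming a Paddle build compiled with XPU support and at least one visible XPU device; it is an illustration, not part of this patch.

    # Minimal sketch: dygraph gradient accumulation on an XPU device.
    import paddle

    paddle.set_device('xpu')  # select the first XPU on an XPU-enabled build
    x = paddle.to_tensor([1.0, 2.0, 3.0], stop_gradient=False)

    y = (x * 2.0).sum()
    z = (x * 3.0).sum()
    y.backward(retain_graph=True)   # x.grad becomes [2., 2., 2.]
    z.backward()                    # accumulated into the same buffer: [5., 5., 5.]
    print(x.grad)

The second backward pass adds into the existing gradient buffer, which is exactly the code path TensorAddFunctor / XPUTensorAddFunctor now covers on XPU.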
@@ -23,30 +23,103 @@ namespace paddle {
 namespace operators {
 
 template <typename DeviceContext, typename T>
 class ReduceMeanXPUKernel : public framework::OpKernel<T> {
+  using XPUType = typename XPUTypeTrait<T>::Type;
+
  public:
   void Compute(const framework::ExecutionContext& context) const override {
     PADDLE_ENFORCE_EQ(
         platform::is_xpu_place(context.GetPlace()), true,
         platform::errors::Unavailable("This kernel only runs on XPU."));
-    // bool reduce_all = context.Attr<bool>("reduce_all");
+    bool reduce_all = context.Attr<bool>("reduce_all");
     auto* input = context.Input<Tensor>("X");
     auto* output = context.Output<Tensor>("Out");
     output->mutable_data<T>(context.GetPlace());
     auto& dev_ctx = context.template device_context<DeviceContext>();
-    int ndim = input->dims().size();
-    std::vector<int> idims;
+    std::vector<int> xdims;
     for (int i = 0; i < input->dims().size(); i++) {
-      idims.push_back(input->dims()[i]);
+      xdims.push_back(input->dims()[i]);
     }
-    auto dims = context.Attr<std::vector<int>>("dim");
-    int rdim = dims.size();
-    int r =
-        xpu::reduce(dev_ctx.x_context(), input->data<T>(), output->data<T>(),
-                    idims.data(), ndim, dims.data(), rdim, xpu::REDUCE_MEAN);
-    PADDLE_ENFORCE_EQ(r == xpu::Error_t::SUCCESS, true,
-                      platform::errors::External("XPU kernel error!"));
+    auto rdims = context.Attr<std::vector<int>>("dim");
+    if (reduce_all) {
+      rdims.clear();
+      for (size_t i = 0; i < xdims.size(); i++) {
+        rdims.push_back(static_cast<int>(i));
+      }
+    }
+    int r = xpu::reduce_mean(
+        dev_ctx.x_context(), reinterpret_cast<const XPUType*>(input->data<T>()),
+        reinterpret_cast<XPUType*>(output->data<T>()), xdims, rdims);
+    PADDLE_ENFORCE_EQ(r, XPU_SUCCESS,
+                      platform::errors::External(
+                          "XPU reduce_mean kernel return wrong value[%d %s]", r,
+                          XPUAPIErrorMsg[r]));
   }
 };
+
+template <typename DeviceContext, typename T>
+class ReduceMeanGradXPUKernel : public framework::OpKernel<T> {
+  using XPUType = typename XPUTypeTrait<T>::Type;
+
+ public:
+  void Compute(const framework::ExecutionContext& ctx) const override {
+    auto* input = ctx.Input<Tensor>("X");
+    auto* output_grad = ctx.Input<Tensor>(framework::GradVarName("Out"));
+    auto* input_grad = ctx.Output<Tensor>(framework::GradVarName("X"));
+
+    XPUType* x_data =
+        reinterpret_cast<XPUType*>(input_grad->mutable_data<T>(ctx.GetPlace()));
+    const XPUType* dy_data =
+        reinterpret_cast<const XPUType*>(output_grad->data<T>());
+
+    bool reduce_all = ctx.Attr<bool>("reduce_all");
+    auto reduce_dims = ctx.Attr<std::vector<int>>("dim");
+
+    std::vector<int> xdims;
+    for (int i = 0; i < input->dims().size(); i++) {
+      xdims.push_back(input->dims()[i]);
+    }
+    std::vector<int> ydims;
+    for (int i = 0; i < output_grad->dims().size(); i++) {
+      ydims.push_back(output_grad->dims()[i]);
+    }
+
+    int reduce_numel = 1;
+    if (reduce_all) {
+      reduce_dims.clear();
+      for (size_t d = 0; d < xdims.size(); ++d) {
+        reduce_dims.push_back(static_cast<int>(d));
+      }
+    }
+    for (auto& d : reduce_dims) {
+      if (d < 0) {
+        d = d + xdims.size();
+      }
+      reduce_numel *= xdims[d];
+    }
+
+    float val = 1.0f / static_cast<float>(reduce_numel);
+
+    auto& dev_ctx = ctx.template device_context<DeviceContext>();
+
+    int r = xpu::constant(dev_ctx.x_context(), x_data, input->numel(),
+                          static_cast<XPUType>(val));
+    PADDLE_ENFORCE_EQ(r, XPU_SUCCESS,
+                      platform::errors::External(
+                          "XPU constant kernel return wrong value[%d %s]", r,
+                          XPUAPIErrorMsg[r]));
+
+    r = xpu::broadcast_mul(dev_ctx.x_context(), x_data, dy_data, x_data, xdims,
+                           ydims);
+    PADDLE_ENFORCE_EQ(r, XPU_SUCCESS,
+                      platform::errors::External(
+                          "XPU broadcast_mul kernel return wrong value[%d %s]",
+                          r, XPUAPIErrorMsg[r]));
+  }
+};
 
 }  // namespace operators
 }  // namespace paddle
@@ -54,4 +127,8 @@ REGISTER_OP_XPU_KERNEL(
     reduce_mean,
     ops::ReduceMeanXPUKernel<paddle::platform::XPUDeviceContext, float>);
+
+REGISTER_OP_XPU_KERNEL(
+    reduce_mean_grad,
+    ops::ReduceMeanGradXPUKernel<paddle::platform::XPUDeviceContext, float>);
 #endif
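Note: the new ReduceMeanGradXPUKernel fills the input gradient with 1/N via xpu::constant and then multiplies it by the broadcast upstream gradient via xpu::broadcast_mul. For reference (standard mean-reduction backward rule, not quoted from the patch), this computes

\[
\frac{\partial L}{\partial x_{i}} = \frac{1}{N}\,\frac{\partial L}{\partial y_{r(i)}},
\qquad N = \prod_{d \in \text{reduce\_dims}} \text{xdims}[d],
\]

where r(i) maps each input element to the output element it was averaged into; the reduce_all branch simply treats every axis as reduced.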
@@ -27,6 +27,8 @@ using Tensor = framework::Tensor;
 
 template <typename DeviceContext, typename T>
 class SliceXPUKernel : public framework::OpKernel<T> {
+  using XPUType = typename XPUTypeTrait<T>::Type;
+
  public:
   void Compute(const framework::ExecutionContext& ctx) const override {
     auto in = ctx.Input<framework::Tensor>("Input");
@@ -83,114 +85,93 @@ class SliceXPUKernel : public framework::OpKernel<T> {
     }
 
     auto& dev_ctx = ctx.template device_context<DeviceContext>();
-    auto* in_data = in->data<T>();
-    auto* out_data = out->mutable_data<T>(ctx.GetPlace());
-    int r = xpu::slice<T>(dev_ctx.x_context(), in_data, out_data, shape,
-                          starts_extension, ends_extension);
-    PADDLE_ENFORCE_EQ(r, XPU_SUCCESS,
-                      platform::errors::External("XPU slice kernel error!"));
+    const XPUType* in_data = reinterpret_cast<const XPUType*>(in->data<T>());
+    XPUType* out_data =
+        reinterpret_cast<XPUType*>(out->mutable_data<T>(ctx.GetPlace()));
+    int r = xpu::slice<XPUType>(dev_ctx.x_context(), in_data, out_data, shape,
+                                starts_extension, ends_extension);
+    PADDLE_ENFORCE_EQ(
+        r, XPU_SUCCESS,
+        platform::errors::External("XPU slice kernel return wrong value[%d %s]",
+                                   r, XPUAPIErrorMsg[r]));
   }
 };
 
 template <typename DeviceContext, typename T>
 class SliceGradXPUKernel : public framework::OpKernel<T> {
+  using XPUType = typename XPUTypeTrait<T>::Type;
+
  public:
   void Compute(const framework::ExecutionContext& ctx) const override {
-    auto* d_out = ctx.Input<framework::Tensor>(framework::GradVarName("Out"));
-    auto* d_in = ctx.Output<framework::Tensor>(framework::GradVarName("Input"));
-    d_in->mutable_data<T>(ctx.GetPlace());
-
-    auto in_dims = d_in->dims();
-    auto axes = ctx.Attr<std::vector<int>>("axes");
-    auto starts = ctx.Attr<std::vector<int>>("starts");
-    auto ends = ctx.Attr<std::vector<int>>("ends");
-
-    // prepare starts, ends on XPU
-    int dim_value = 0, start = 0, end = 0;
-    // If a negative value is passed for any of the start or end indices,
-    // it represents number of elements before the end of that dimension.
-    // If the value passed to start or end is larger than the n
-    // (the number of elements in this dimension), it represents n.
-    for (size_t i = 0; i < axes.size(); ++i) {
-      dim_value = in_dims[axes[i]];
-      start = starts[i];
-      end = ends[i];
-      start = start < 0 ? (start + dim_value) : start;
-      end = end < 0 ? (end + dim_value) : end;
-      start = std::max(start, 0);
-      end = std::max(end, 0);
-      end = std::min(end, dim_value);
-      PADDLE_ENFORCE_GT(end, start, platform::errors::InvalidArgument(
-                                        "end should greater than start"));
-      starts[i] = start;
-      ends[i] = end;
-    }
-    size_t shape_size = in_dims.size();
-    // the slice XPU kernel require that the length of `start`, `end` must be
-    // equal
-    // to the dims size of input tensor, therefore, if shape_size > axes.size(),
-    // the `starts_extension` and `ends_extension` is necessary.
-    std::vector<int> starts_extension(shape_size, 0);
-    std::vector<int> ends_extension(shape_size, 0);
-    if (shape_size > axes.size()) {
-      for (size_t i = 0; i < shape_size; ++i) {
-        ends_extension[i] = in_dims[i];
-      }
-      for (size_t i = 0; i < axes.size(); ++i) {
-        starts_extension[axes[i]] = starts[i];
-        ends_extension[axes[i]] = ends[i];
-      }
-    }
-    int* starts_device = nullptr;
-    int* ends_device = nullptr;
-    int* starts_host =
-        shape_size > axes.size() ? starts_extension.data() : starts.data();
-    int* ends_host =
-        shape_size > axes.size() ? ends_extension.data() : ends.data();
-    PADDLE_ENFORCE_EQ(xpu_malloc(reinterpret_cast<void**>(&starts_device),
-                                 shape_size * sizeof(int)),
-                      XPU_SUCCESS,
-                      platform::errors::External("XPU has no enough memory"));
-    PADDLE_ENFORCE_EQ(xpu_malloc(reinterpret_cast<void**>(&ends_device),
-                                 shape_size * sizeof(int)),
-                      XPU_SUCCESS,
-                      platform::errors::External("XPU has no enough memory"));
-    memory::Copy(BOOST_GET_CONST(platform::XPUPlace, ctx.GetPlace()),
-                 starts_device, platform::CPUPlace(), starts_host,
-                 shape_size * sizeof(int));
-    memory::Copy(BOOST_GET_CONST(platform::XPUPlace, ctx.GetPlace()),
-                 ends_device, platform::CPUPlace(), ends_host,
-                 shape_size * sizeof(int));
-
-    // prepare shape on XPU
-    std::vector<int> shape(shape_size, 0);
-    for (size_t i = 0; i < shape_size; ++i) {
-      shape[i] = in_dims[i];
-    }
-    int* shape_device = nullptr;
-    PADDLE_ENFORCE_EQ(xpu_malloc(reinterpret_cast<void**>(&shape_device),
-                                 shape_size * sizeof(int)),
-                      XPU_SUCCESS,
-                      platform::errors::External("XPU has no enough memory"));
-    memory::Copy(BOOST_GET_CONST(platform::XPUPlace, ctx.GetPlace()),
-                 shape_device, platform::CPUPlace(), shape.data(),
-                 shape_size * sizeof(int));
+    auto* input = ctx.Input<Tensor>("Input");
+    auto* dout = ctx.Input<Tensor>(framework::GradVarName("Out"));
+    auto* dinput = ctx.Output<Tensor>(framework::GradVarName("Input"));
+
+    auto axes_int = ctx.Attr<std::vector<int>>("axes");
+    auto starts_int = ctx.Attr<std::vector<int>>("starts");
+    auto ends_int = ctx.Attr<std::vector<int>>("ends");
+    std::vector<int> axes(axes_int.begin(), axes_int.end());
+    std::vector<int> starts(starts_int.begin(), starts_int.end());
+    std::vector<int> ends(ends_int.begin(), ends_int.end());
+
+    // Get the accurate attribute value of starts and ends
+    auto starts_tensor_list = ctx.MultiInput<Tensor>("StartsTensorList");
+    if (ctx.HasInput("StartsTensor")) {
+      starts = GetDataFromTensor<int>(ctx.Input<Tensor>("StartsTensor"));
+    } else if (starts_tensor_list.size() > 0) {
+      starts = GetDataFromTensorList<int>(starts_tensor_list);
+    }
+
+    auto ends_tensor_list = ctx.MultiInput<Tensor>("EndsTensorList");
+    if (ctx.HasInput("EndsTensor")) {
+      ends = GetDataFromTensor<int>(ctx.Input<Tensor>("EndsTensor"));
+    } else if (ends_tensor_list.size() > 0) {
+      ends = GetDataFromTensorList<int>(ends_tensor_list);
+    }
+
+    const auto& in_dims = input->dims();
+    int rank = in_dims.size();
+
+    std::vector<int> pad_left(rank);
+    std::vector<int> out_dims(rank);
+    std::vector<int> pad_right(rank);
+    int cnt = 0;
+    for (int i = 0; i < in_dims.size(); ++i) {
+      int start = 0;
+      int end = in_dims[i];
+      int axis = cnt < static_cast<int>(axes.size()) ? axes[cnt] : -1;
+      if (axis == i) {
+        start = starts[cnt];
+        if (start < 0) {
+          start = (start + in_dims[i]);
+        }
+        start = std::max(start, static_cast<int>(0));
+        end = ends[cnt];
+        if (end < 0) {
+          end = (end + in_dims[i]);
+        }
+        end = std::min(end, static_cast<int>(in_dims[i]));
+        cnt++;
+      }
+
+      pad_left[i] = start;
+      out_dims[i] = end - start;
+      pad_right[i] = in_dims[i] - out_dims[i] - pad_left[i];
+    }
 
     auto& dev_ctx = ctx.template device_context<DeviceContext>();
-    int r =
-        xpu::slice_backward(dev_ctx.x_context(), shape_device, starts_device,
-                            ends_device, shape_size, d_out->data<T>(),
-                            d_in->data<T>(), d_in->numel(), d_out->numel());
-    PADDLE_ENFORCE_EQ(r, XPU_SUCCESS,
-                      platform::errors::External("xpu slice kernel error"));
-    dev_ctx.Wait();
-    // free device data
-    xpu_free(shape_device);
-    xpu_free(starts_device);
-    xpu_free(ends_device);
+    const XPUType* dout_data =
+        reinterpret_cast<const XPUType*>(dout->data<T>());
+    XPUType* din_data =
+        reinterpret_cast<XPUType*>(dinput->mutable_data<T>(ctx.GetPlace()));
+    int r = xpu::pad<XPUType>(dev_ctx.x_context(), dout_data, din_data,
+                              out_dims, pad_left, pad_right, XPUType(0));
+    PADDLE_ENFORCE_EQ(
+        r, XPU_SUCCESS,
+        platform::errors::External("XPU pad kernel return wrong value[%d %s]",
+                                   r, XPUAPIErrorMsg[r]));
   }
 };
 
 }  // namespace operators
 }  // namespace paddle
@@ -198,8 +179,13 @@ namespace ops = paddle::operators;
 
 REGISTER_OP_XPU_KERNEL(
     slice, ops::SliceXPUKernel<paddle::platform::XPUDeviceContext, float>,
-    ops::SliceXPUKernel<paddle::platform::XPUDeviceContext, int>);
+    ops::SliceXPUKernel<paddle::platform::XPUDeviceContext, int>,
+    ops::SliceXPUKernel<paddle::platform::XPUDeviceContext,
+                        paddle::platform::float16>);
 REGISTER_OP_XPU_KERNEL(
     slice_grad,
-    ops::SliceGradXPUKernel<paddle::platform::XPUDeviceContext, float>);
+    ops::SliceGradXPUKernel<paddle::platform::XPUDeviceContext, float>,
+    ops::SliceGradXPUKernel<paddle::platform::XPUDeviceContext, int>,
+    ops::SliceGradXPUKernel<paddle::platform::XPUDeviceContext,
+                            paddle::platform::float16>);
 #endif
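Note: the rewritten SliceGradXPUKernel expresses the backward of slice as zero-padding the upstream gradient back to the input shape (pad_left/pad_right per dimension), replacing the previous xpu_malloc/memory::Copy/slice_backward path. A small NumPy sketch of the same computation, with illustrative shapes and names, not part of the patch:

    # Illustrative NumPy equivalent of the pad-based slice backward.
    import numpy as np

    in_shape = (4, 5)
    axes, starts, ends = [0], [1], [3]        # forward: out = x[1:3, :]

    dout = np.ones((2, 5), dtype=np.float32)  # upstream gradient of the slice

    pad_left = [1, 0]                         # normalized start per dim
    pad_right = [in_shape[0] - 3, 0]          # in_dim - normalized end per dim
    din = np.pad(dout, list(zip(pad_left, pad_right)))  # what xpu::pad produces

    assert din.shape == in_shape              # rows 0 and 3 stay zero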
@@ -109,7 +109,16 @@ XPUOpMap& get_kl2_ops() {
                     pOpKernelType(vartype::FP16, XPUPlace())})},
     {"iou_similarity",
      XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
-    {"arg_max", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}
+    {"arg_max", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
+    {"reduce_mean", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
+    {"reduce_mean_grad",
+     XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
+    {"slice", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()),
+                            pOpKernelType(vartype::FP16, XPUPlace()),
+                            pOpKernelType(vartype::INT32, XPUPlace())})},
+    {"slice_grad", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()),
+                                 pOpKernelType(vartype::FP16, XPUPlace()),
+                                 pOpKernelType(vartype::INT32, XPUPlace())})},
 
     // AddMore
 };
...
@@ -313,6 +313,18 @@ def _current_expected_place():
                 "You are using GPU version Paddle, but your CUDA device is not set properly. CPU device will be used by default."
             )
             _global_expected_place_ = core.CPUPlace()
+    elif core.is_compiled_with_xpu():
+        try:
+            device_count = core.get_xpu_device_count()
+        except Exception as e:
+            device_count = 0
+        if device_count > 0:
+            _global_expected_place_ = core.XPUPlace(0)
+        else:
+            warnings.warn(
+                "You are using XPU version Paddle, but your XPU device is not set properly. CPU device will be used by default."
+            )
+            _global_expected_place_ = core.CPUPlace()
     else:
         _global_expected_place_ = core.CPUPlace()
...
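Note: with this change, _current_expected_place() falls back to XPUPlace(0) on XPU builds, so dygraph tensors land on the XPU by default. A hedged usage sketch, assuming an XPU build with at least one visible device:

    import paddle

    # No explicit device selection: on an XPU-compiled build with a visible
    # device, the default expected place is now the first XPU.
    t = paddle.to_tensor([1, 2, 3])
    print(t.place)   # expected to report the first XPU, e.g. XPUPlace(0)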
@@ -190,9 +190,6 @@ class AdamW(Adam):
         self.type = "adamw"
 
-        if core.is_compiled_with_xpu():
-            self.type = "adam"
-
         # Use _auxiliary_vars together with _set_auxiliary_var/_get_auxiliary_var to achieve that.
         self._auxiliary_vars = dict()
 
@@ -259,10 +256,6 @@ class AdamW(Adam):
             paddle.fluid.layers.assign(input=scaled_param, output=param)
 
     def _append_optimize_op(self, block, param_and_grad):
-        if paddle.is_compiled_with_xpu():
-            self._append_decoupled_weight_decay(block, param_and_grad)
-            return super(AdamW, self)._append_optimize_op(block, param_and_grad)
-
         assert isinstance(block, framework.Block)
         if isinstance(param_and_grad, dict):
             param_and_grad = self._update_param_group(param_and_grad)
...
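Note: removing the XPU special-casing means AdamW on XPU now follows the same adamw op path as other devices, instead of falling back to plain adam plus a hand-written decoupled decay step. For reference, the standard decoupled weight-decay update AdamW applies is (not quoted from the patch):

\[
\theta_t = \theta_{t-1} - \eta\left(\frac{\hat m_t}{\sqrt{\hat v_t} + \epsilon} + \lambda\,\theta_{t-1}\right),
\]

with learning rate \(\eta\), bias-corrected moments \(\hat m_t, \hat v_t\), and weight-decay coefficient \(\lambda\).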
@@ -104,9 +104,9 @@ def to_tensor(data, dtype=None, place=None, stop_gradient=True):
     if place is None:
         place = _current_expected_place()
     elif not isinstance(place, (core.Place, core.CPUPlace, core.CUDAPinnedPlace,
-                                core.CUDAPlace, core.NPUPlace)):
+                                core.CUDAPlace, core.NPUPlace, core.XPUPlace)):
         raise ValueError(
-            "'place' must be any of paddle.Place, paddle.CPUPlace, paddle.CUDAPinnedPlace, paddle.CUDAPlace, paddle.NPUPlace"
+            "'place' must be any of paddle.Place, paddle.CPUPlace, paddle.CUDAPinnedPlace, paddle.CUDAPlace, paddle.NPUPlace, paddle.XPUPlace"
         )
 
     #Todo(zhouwei): Support allocate tensor on any other specified card
...
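Note: since core.XPUPlace is now an accepted place type, a tensor can be created directly on an XPU device. A minimal sketch, assuming an XPU build:

    import paddle

    x = paddle.to_tensor([1.0, 2.0], place=paddle.XPUPlace(0), stop_gradient=False)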