Unverified commit 11c785a4 authored by kuizhiqing, committed by GitHub

fix ndiv for npu (#37998)

Parent 515d3562
@@ -211,70 +211,6 @@ void SplitTensorsWithType<platform::XPUDeviceContext>(
}
#endif
-// NOTE(liubo48): Only implement operators::math::SplitFunctor for npu now.
-// If later the operators::StridedMemcpyWithAxis0 is supported,
-// then this specific SplitTensorsForAllReduce can be removed.
-#ifdef PADDLE_WITH_ASCEND_CL
-template <>
-void SplitTensorsForAllReduce<platform::NPUDeviceContext, float>(
-    const platform::NPUDeviceContext &context,
-    framework::Variable *p_dense_contents,
-    std::vector<framework::Tensor> *p_dense_tensors) {
-  auto *in = p_dense_contents->GetMutable<framework::LoDTensor>();
-  std::vector<framework::Tensor *> outs;
-  std::vector<const framework::Tensor *> shape_refer;
-  outs.reserve(p_dense_tensors->size());
-  shape_refer.reserve(p_dense_tensors->size());
-  for (auto &tensor : *p_dense_tensors) {
-    outs.emplace_back(&tensor);
-    shape_refer.emplace_back(&tensor);
-  }
-  operators::math::SplitFunctor<platform::NPUDeviceContext, float>
-      split_functor_;
-  split_functor_(context, *in, shape_refer, 0, &outs);
-}
-template <>
-void ConcatTensorsWithType<platform::NPUDeviceContext>(
-    const platform::NPUDeviceContext &context,
-    const std::vector<framework::Tensor> &dense_tensors_,
-    framework::Variable *p_dense_contents,
-    framework::proto::VarType::Type type) {
-  switch (type) {
-    case framework::proto::VarType::FP32:
-      ConcatTensorsForAllReduce<platform::NPUDeviceContext, float>(
-          context, dense_tensors_, p_dense_contents);
-      break;
-    default:
-      PADDLE_THROW(platform::errors::Unimplemented(
-          "Data type (%s) is not supported when it concats tensors for "
-          "allreduce.",
-          framework::DataTypeToString(type)));
-  }
-}
-template <>
-void SplitTensorsWithType<platform::NPUDeviceContext>(
-    const platform::NPUDeviceContext &context,
-    framework::Variable *p_dense_contents,
-    std::vector<framework::Tensor> *p_dense_tensors,
-    framework::proto::VarType::Type type) {
-  switch (type) {
-    case framework::proto::VarType::FP32:
-      SplitTensorsForAllReduce<platform::NPUDeviceContext, float>(
-          context, p_dense_contents, p_dense_tensors);
-      break;
-    default:
-      PADDLE_THROW(platform::errors::Unimplemented(
-          "Data type (%s) is not supported when it splits tensors for "
-          "allreduce.",
-          framework::DataTypeToString(type)));
-  }
-}
-#endif
void Group::ConcatTensors(const platform::DeviceContext &context) {
  auto place = context.GetPlace();
  if (platform::is_gpu_place(place)) {
......
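For context: the removed specialization splits the fused, contiguous buffer used for allreduce back into per-parameter tensors along axis 0, routing through operators::math::SplitFunctor because operators::StridedMemcpyWithAxis0 was not available on NPU (per the NOTE above). A minimal NumPy sketch of that split idea, as an illustration only and not Paddle's implementation:

```python
import numpy as np

# Illustration only: recover per-parameter tensors from a fused flat
# buffer, which is what SplitTensorsForAllReduce does after allreduce.
def split_fused_buffer(fused, shapes):
    outs, offset = [], 0
    for shape in shapes:
        size = int(np.prod(shape))
        outs.append(fused[offset:offset + size].reshape(shape))
        offset += size
    return outs

fused = np.arange(10, dtype=np.float32)  # fused gradients of two params
grads = split_fused_buffer(fused, [(2, 2), (6,)])
assert grads[0].shape == (2, 2) and grads[1].shape == (6,)
```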
@@ -238,7 +238,7 @@ def monkey_patch_varbase():
"Tensor shape not match, Tensor of grad_tensor [ {} ] with shape {} mismatch Tensor [ {} ] with shape {}".format(
grad_tensor.name, grad_tensor.shape, self.name, self.shape)
if paddle.is_compiled_with_xpu():
if paddle.is_compiled_with_xpu() or paddle.is_compiled_with_npu():
# TODO(liuyuhui): Currently only for xpu. Will be removed in the future.
scaled_loss = scale_loss(self)
core.dygraph_run_backward([scaled_loss], [grad_tensor],
......
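The Python-side change extends the XPU-only branch to NPU builds: the loss is scaled on the host before core.dygraph_run_backward starts the backward pass, instead of dividing gradients on the device after allreduce. A hedged sketch of the idea (the 1/nranks factor and the helper name here are assumptions for illustration; the real scale_loss is defined elsewhere in Paddle):

```python
import numpy as np

# Hypothetical stand-in for scale_loss: pre-divide the loss by the
# trainer count so a plain sum-allreduce of gradients yields the mean
# without a separate per-device divide (presumably the nranks division,
# "ndiv", that the commit title refers to).
def scale_loss_sketch(loss, nranks):
    return loss / nranks if nranks > 1 else loss

print(scale_loss_sketch(np.float32(8.0), nranks=4))  # -> 2.0
```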