Unverified commit 9bf00cd5, authored by B Baibaifan, committed by GitHub

repair npu matmul_grad and comm_init_hccl (#33719)

Parent affddfaa
@@ -22,7 +22,11 @@ class Scope;
} // namespace framework
} // namespace paddle
#if defined(PADDLE_WITH_ASCEND_CL)
#include "acl/acl.h"
#include "hccl/hccl.h"
#include "hccl/hccl_types.h"
#include "paddle/fluid/platform/collective_helper.h"
#include "paddle/fluid/platform/hccl_helper.h"
#endif
namespace paddle {
@@ -57,6 +61,31 @@ class CCommInitOpAscend : public framework::OperatorBase {
}
platform::HCCLCommContext::Instance().CreateHCCLComm(
hccl_id, rank_ids, rank_id, device_id, rid);
// Build comm
float* buff;
int32_t size = 20;
std::vector<float> input(size, 0);
for (int32_t idx = 0; idx < size; idx++) {
input[idx] = 1.0;
}
aclrtMalloc(reinterpret_cast<void**>(&buff), size * sizeof(float),
ACL_MEM_MALLOC_HUGE_FIRST);
aclrtMemcpy(reinterpret_cast<void*>(buff), size * sizeof(float),
input.data(), size * sizeof(float), ACL_MEMCPY_HOST_TO_DEVICE);
VLOG(3) << "Build buff data successful.";
aclrtStream stream = nullptr;
auto comm = paddle::platform::HCCLCommContext::Instance().Get(rid, place);
if (rank_id == 0) {
stream = comm->stream();
} else {
auto dev_ctx = platform::DeviceContextPool::Instance().Get(place);
stream = static_cast<platform::NPUDeviceContext*>(dev_ctx)->stream();
}
platform::dynload::HcclBroadcast(buff, size, HCCL_DATA_TYPE_FP32, 0,
comm->comm(), stream);
VLOG(3) << "Build connection successful.";
#else
PADDLE_THROW(platform::errors::PreconditionNotMet(
"PaddlePaddle should compile with NPU."));
......
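The block added above warms up the newly created communicator: it fills a small device buffer with ones and issues an HcclBroadcast from rank 0, so the HCCL connection is actually established (and validated) at init time instead of on the first real collective. Below is a minimal sketch, not part of the commit, of how the same warm-up pattern could be reproduced from Python user code; it assumes a multi-device launch (e.g. via paddle.distributed.launch) and a Paddle build with the matching device support.

import numpy as np
import paddle
import paddle.distributed as dist

def warm_up_comm(size=20):
    # init_parallel_env() sets up the communicator, analogous to CreateHCCLComm above.
    dist.init_parallel_env()
    # Broadcast a small all-ones tensor from rank 0, mirroring the HcclBroadcast warm-up.
    buff = paddle.to_tensor(np.ones(size, dtype=np.float32))
    dist.broadcast(buff, src=0)
    return buff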
@@ -138,10 +138,34 @@ class MatMulV2GradNPUKernel : public framework::OpKernel<T> {
}
if (dy) {
dy->mutable_data<T>(ctx.GetPlace());
const auto& runner_dy =
NpuOpRunner("BatchMatMul", {*x, *dout}, {*dy},
{{"adj_x1", true}, {"adj_x2", false}});
runner_dy.Run(stream);
if ((x->dims().size() == 3) && (dout->dims().size() == 3) &&
(dy->dims().size() == 2)) {
framework::Tensor dout_;
TensorCopy(*dout, ctx.GetPlace(), &dout_);
ctx.template device_context<paddle::platform::NPUDeviceContext>()
.Wait();
std::vector<int> vec_dim = framework::vectorize<int>(dout_.dims());
std::vector<int> vec_dim_v{vec_dim[0] * vec_dim[1], vec_dim[2]};
dout_.Resize(framework::make_ddim(vec_dim_v));
framework::Tensor x_;
TensorCopy(*x, ctx.GetPlace(), &x_);
ctx.template device_context<paddle::platform::NPUDeviceContext>()
.Wait();
std::vector<int> vec_dim_x = framework::vectorize<int>(x_.dims());
std::vector<int> vec_dim_x_v{vec_dim_x[0] * vec_dim_x[1],
vec_dim_x[2]};
x_.Resize(framework::make_ddim(vec_dim_x_v));
const auto& runner_dy =
NpuOpRunner("MatMul", {x_, dout_}, {*dy},
{{"transpose_x1", true}, {"transpose_x2", false}});
runner_dy.Run(stream);
} else {
const auto& runner_dy =
NpuOpRunner("BatchMatMul", {*x, *dout}, {*dy},
{{"adj_x1", true}, {"adj_x2", false}});
runner_dy.Run(stream);
}
}
}
}
......
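The new dy branch covers the case where x and dout are 3-D but dy is 2-D, i.e. y was a plain [K, N] matrix broadcast over the batch. Its gradient is the sum of the per-batch products x_b^T @ dout_b, which equals a single 2-D MatMul on the batch-flattened operands with the first input transposed; the previous unconditional BatchMatMul produced a 3-D result that could not be written into the 2-D dy. The following numpy check (illustration only, not part of the commit) verifies this identity with the same shapes as the new unit test.

import numpy as np

B, M, K, N = 2, 1, 3, 2
x = np.random.rand(B, M, K).astype(np.float32)
dout = np.random.rand(B, M, N).astype(np.float32)

# Reference: accumulate the per-batch x_b^T @ dout_b contributions.
dy_ref = sum(x[b].T @ dout[b] for b in range(B))

# What the new branch computes: flatten [B, M, *] to [B*M, *] and run a single
# MatMul with the first operand transposed (transpose_x1=true in the kernel).
dy_flat = x.reshape(B * M, K).T @ dout.reshape(B * M, N)

assert np.allclose(dy_ref, dy_flat, atol=1e-6)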
@@ -206,5 +206,85 @@ class TestMatMulNet(unittest.TestCase):
self.assertTrue(np.allclose(npu_loss, cpu_loss))
# The precision is aligned on NPU and GPU separately; this case only demonstrates the usage.
@unittest.skipIf(not paddle.is_compiled_with_npu(),
"core is not compiled with NPU")
class TestMatMulNet3_2(unittest.TestCase):
def _test(self, run_npu=True):
main_prog = paddle.static.Program()
startup_prog = paddle.static.Program()
main_prog.random_seed = SEED
startup_prog.random_seed = SEED
np.random.seed(SEED)
self._dtype = "float32"
a_np = np.random.random(size=(2, 1, 3)).astype(self._dtype)
b_np = np.random.random(size=(2, 1, 3)).astype(self._dtype)
c_np = np.random.random(size=(3, 2)).astype(self._dtype)
d_np = np.random.random(size=(3, 2)).astype(self._dtype)
label_np = np.random.randint(2, size=(2, 1)).astype('int64')
with paddle.static.program_guard(main_prog, startup_prog):
a = paddle.static.data(name="a", shape=[2, 1, 3], dtype=self._dtype)
b = paddle.static.data(name="b", shape=[2, 1, 3], dtype=self._dtype)
c = paddle.static.data(name="c", shape=[3, 2], dtype=self._dtype)
d = paddle.static.data(name="d", shape=[3, 2], dtype=self._dtype)
label = paddle.static.data(
name="label", shape=[2, 1], dtype='int64')
sum_1 = paddle.add(a, b)
sum_2 = paddle.add(c, d)
sum_1 = paddle.cast(sum_1, 'float16')
sum_2 = paddle.cast(sum_2, 'float16')
if not run_npu:
sum_1 = paddle.cast(sum_1, 'float32')
sum_2 = paddle.cast(sum_2, 'float32')
result = paddle.matmul(sum_1, sum_2)
if run_npu:
result = paddle.cast(result, 'float32')
result = paddle.reshape(result, shape=[2, 2])
fc_1 = fluid.layers.fc(input=result, size=8)
prediction = fluid.layers.fc(input=fc_1, size=2, act='softmax')
cost = fluid.layers.cross_entropy(input=prediction, label=label)
loss = fluid.layers.reduce_mean(cost)
sgd = fluid.optimizer.SGD(learning_rate=0.01)
sgd.minimize(loss)
if run_npu:
place = paddle.NPUPlace(0)
else:
place = paddle.CPUPlace()
exe = paddle.static.Executor(place)
exe.run(startup_prog)
print("Start run on {}".format(place))
for epoch in range(100):
pred_res, loss_res = exe.run(main_prog,
feed={
"a": a_np,
"b": b_np,
"c": c_np,
"d": d_np,
"label": label_np
},
fetch_list=[prediction, loss])
if epoch % 10 == 0:
print("Epoch {} | Prediction[0]: {}, Loss: {}".format(
epoch, pred_res[0], loss_res))
return pred_res, loss_res
def test_npu(self):
cpu_pred, cpu_loss = self._test(False)
npu_pred, npu_loss = self._test(True)
self.assertTrue(np.allclose(npu_pred, cpu_pred, atol=1e-4))
self.assertTrue(np.allclose(npu_loss, cpu_loss, atol=1e-4))
if __name__ == '__main__':
unittest.main()
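TestMatMulNet3_2 drives the new kernel branch end to end through a static-graph network: the [2, 1, 3] operand is matmul-ed with a broadcast [3, 2] operand, so the backward pass for the 2-D side must take the flattened MatMul path. For completeness, a minimal dynamic-graph sketch of the same case is shown below; paddle.set_device("npu:0") is assumed to be available in an NPU build (use "cpu" otherwise).

import paddle

paddle.set_device("npu:0")   # assumes an NPU build; use "cpu" to run the same check elsewhere
x = paddle.rand([2, 1, 3])
y = paddle.rand([3, 2])
y.stop_gradient = False
out = paddle.matmul(x, y)    # y is broadcast; output shape is [2, 1, 2]
out.sum().backward()
print(y.grad.shape)          # [3, 2], produced by the new 3-D/3-D/2-D branch on NPU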