diff --git a/paddle/fluid/operators/abs_op_npu.cc b/paddle/fluid/operators/abs_op_npu.cc
new file mode 100644
index 0000000000000000000000000000000000000000..7bfe35ef6e02145714209452fadd9182b58659e7
--- /dev/null
+++ b/paddle/fluid/operators/abs_op_npu.cc
@@ -0,0 +1,76 @@
+/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the Licnse. */
+
+#include "paddle/fluid/operators/abs_op.h"
+#include "paddle/fluid/operators/npu_op_runner.h"
+
+namespace paddle {
+namespace operators {
+
+using Tensor = framework::Tensor;
+
+template <typename DeviceContext, typename T>
+class AbsNPUKernel : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext& ctx) const override {
+    auto* x = ctx.Input<Tensor>("X");
+    auto* out = ctx.Output<Tensor>("Out");
+
+    out->mutable_data<T>(ctx.GetPlace());
+
+    const auto& runner = NpuOpRunner("Abs",
+                                     {
+                                         *x,
+                                     },
+                                     {*out}, {});
+
+    auto stream =
+        ctx.template device_context<paddle::platform::NPUDeviceContext>()
+            .stream();
+    runner.Run(stream);
+  }
+};
+
+template <typename DeviceContext, typename T>
+class AbsGradNPUKernel : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext& ctx) const override {
+    auto* x = ctx.Input<Tensor>("X");
+    auto* dout = ctx.Input<Tensor>(framework::GradVarName("Out"));
+    auto* dx = ctx.Output<Tensor>(framework::GradVarName("X"));
+
+    dx->mutable_data<T>(ctx.GetPlace());
+
+    const auto& runner = NpuOpRunner("AbsGrad", {*x, *dout}, {*dx}, {});
+
+    auto stream =
+        ctx.template device_context<paddle::platform::NPUDeviceContext>()
+            .stream();
+    runner.Run(stream);
+  }
+};
+
+}  // namespace operators
+}  // namespace paddle
+
+namespace ops = paddle::operators;
+namespace plat = paddle::platform;
+
+REGISTER_OP_NPU_KERNEL(
+    abs, ops::AbsNPUKernel<plat::NPUDeviceContext, float>,
+    ops::AbsNPUKernel<plat::NPUDeviceContext, plat::float16>);
+
+REGISTER_OP_NPU_KERNEL(
+    abs_grad, ops::AbsGradNPUKernel<plat::NPUDeviceContext, float>,
+    ops::AbsGradNPUKernel<plat::NPUDeviceContext, plat::float16>);
diff --git a/paddle/fluid/operators/uniform_random_op_npu.cc b/paddle/fluid/operators/uniform_random_op_npu.cc
new file mode 100644
index 0000000000000000000000000000000000000000..efd9d844fcb40b3714d55311c600262a0db868a5
--- /dev/null
+++ b/paddle/fluid/operators/uniform_random_op_npu.cc
@@ -0,0 +1,106 @@
+/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "paddle/fluid/operators/uniform_random_op.h"
+#include <string>
+#include "paddle/fluid/framework/generator.h"
+#include "paddle/fluid/framework/op_registry.h"
+#include "paddle/fluid/framework/operator.h"
+
+namespace paddle {
+namespace operators {
+
+template <typename T>
+class NPUUniformRandomKernel : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext &ctx) const override {
+    framework::Tensor *tensor = nullptr;
+    auto out_var = ctx.OutputVar("Out");
+    std::vector<int64_t> new_shape;
+    auto list_new_shape_tensor =
+        ctx.MultiInput<framework::Tensor>("ShapeTensorList");
+    if (list_new_shape_tensor.size() > 0 || ctx.HasInput("ShapeTensor")) {
+      if (ctx.HasInput("ShapeTensor")) {
+        auto *shape_tensor = ctx.Input<framework::Tensor>("ShapeTensor");
+        new_shape = GetNewDataFromShapeTensor(shape_tensor);
+      } else if (list_new_shape_tensor.size() > 0) {
+        new_shape = GetNewDataFromShapeTensorList(list_new_shape_tensor);
+      }
+    }
+
+    if (out_var->IsType<framework::SelectedRows>()) {
+      auto *selected_rows = out_var->GetMutable<framework::SelectedRows>();
+      tensor = selected_rows->mutable_value();
+      auto shape = ctx.Attr<std::vector<int64_t>>("shape");
+      if (!new_shape.empty()) shape = new_shape;
+      tensor->Resize(framework::make_ddim(shape));
+      selected_rows->mutable_rows()->reserve(shape[0]);
+    } else if (out_var->IsType<framework::LoDTensor>()) {
+      tensor = out_var->GetMutable<framework::LoDTensor>();
+      if (!new_shape.empty()) tensor->Resize(framework::make_ddim(new_shape));
+    } else {
+      PADDLE_THROW(platform::errors::InvalidArgument(
+          "Expected type of Output(out) in uniform_random_op must be Tensor, "
+          "SelectedRows. But got "
+          "unsupport type: %s.",
+          framework::ToTypeName(out_var->Type())));
+    }
+    T *data = tensor->mutable_data<T>(ctx.GetPlace());
+
+    int64_t size = tensor->numel();
+    std::unique_ptr<T[]> data_cpu(new T[size]);
+    std::uniform_real_distribution<T> dist(
+        static_cast<T>(ctx.Attr<float>("min")),
+        static_cast<T>(ctx.Attr<float>("max")));
+    unsigned int seed = static_cast<unsigned int>(ctx.Attr<int>("seed"));
+    auto engine = framework::GetCPURandomEngine(seed);
+
+    for (int64_t i = 0; i < size; ++i) {
+      data_cpu[i] = dist(*engine);
+    }
+
+    unsigned int diag_num =
+        static_cast<unsigned int>(ctx.Attr<int>("diag_num"));
+    unsigned int diag_step =
+        static_cast<unsigned int>(ctx.Attr<int>("diag_step"));
+    auto diag_val = static_cast<T>(ctx.Attr<float>("diag_val"));
+    if (diag_num > 0) {
+      PADDLE_ENFORCE_GT(
+          size, (diag_num - 1) * (diag_step + 1),
+          platform::errors::InvalidArgument(
+              "ShapeInvalid: the diagonal's elements is equal (num-1) "
+              "* (step-1) with num %d, step %d,"
+              "It should be smaller than %d, but received %d",
+              diag_num, diag_step, (diag_num - 1) * (diag_step + 1), size));
+      for (int64_t i = 0; i < diag_num; ++i) {
+        int64_t pos = i * diag_step + i;
+        data_cpu[pos] = diag_val;
+      }
+    }
+
+    // copy to NPU
+    auto stream =
+        ctx.template device_context<paddle::platform::NPUDeviceContext>()
+            .stream();
+    memory::Copy(BOOST_GET_CONST(platform::NPUPlace, ctx.GetPlace()), data,
+                 platform::CPUPlace(), reinterpret_cast<void *>(data_cpu.get()),
+                 size * sizeof(T), stream);
+  }
+};
+
+}  // namespace operators
+}  // namespace paddle
+
+REGISTER_OP_NPU_KERNEL(uniform_random,
+                       paddle::operators::NPUUniformRandomKernel<float>);
diff --git a/python/paddle/fluid/tests/unittests/CMakeLists.txt b/python/paddle/fluid/tests/unittests/CMakeLists.txt
index 0356aead2e0bf655f4d011c5fda7ae42eba4028f..644f25db9e77cee18017a78fd32b95f443669146 100644
--- a/python/paddle/fluid/tests/unittests/CMakeLists.txt
+++ b/python/paddle/fluid/tests/unittests/CMakeLists.txt
@@ -270,6 +270,12 @@ function(py_test_modules TARGET_NAME)
                 COVERAGE_FILE=${PADDLE_BINARY_DIR}/python-coverage.data
                 ${PYTHON_EXECUTABLE} -m coverage run --branch -p ${PADDLE_SOURCE_DIR}/tools/test_runner.py ${py_test_modules_MODULES}
                 WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR})
+    elseif(WITH_ASCEND_CL)
+        # AscendCL need to include ascend toolkit python path, or ACL error will be thrown when running ctest
+        add_test(NAME ${TARGET_NAME}
+                COMMAND ${CMAKE_COMMAND} -E env PYTHONPATH=${PADDLE_BINARY_DIR}/python:$ENV{PYTHONPATH} ${py_test_modules_ENVS}
+                ${PYTHON_EXECUTABLE} ${PADDLE_SOURCE_DIR}/tools/test_runner.py ${py_test_modules_MODULES}
+                WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR})
     else()
         add_test(NAME ${TARGET_NAME}
                 COMMAND ${CMAKE_COMMAND} -E env PYTHONPATH=${PADDLE_BINARY_DIR}/python ${py_test_modules_ENVS}
diff --git a/python/paddle/fluid/tests/unittests/npu/test_abs_op_npu.py b/python/paddle/fluid/tests/unittests/npu/test_abs_op_npu.py
new file mode 100644
index 0000000000000000000000000000000000000000..9382cf2162ef2598192e0a8e0f1bd630cbb9a6a4
--- /dev/null
+++ b/python/paddle/fluid/tests/unittests/npu/test_abs_op_npu.py
@@ -0,0 +1,69 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+# 
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+# 
+#     http://www.apache.org/licenses/LICENSE-2.0
+# 
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import print_function, division
+
+import numpy as np
+import unittest
+import sys
+sys.path.append("..")
+from op_test import OpTest
+import paddle
+import paddle.fluid as fluid
+
+paddle.enable_static()
+
+
+@unittest.skipIf(not paddle.is_compiled_with_npu(),
+                 "core is not compiled with NPU")
+class TestNPUAbs(OpTest):
+    def setUp(self):
+        self.op_type = "abs"
+        self.set_npu()
+        self.init_dtype()
+
+        np.random.seed(1024)
+        x = np.random.uniform(-1, 1, [4, 25]).astype(self.dtype)
+        # Because we set delta = 0.005 in calculating numeric gradient,
+        # if x is too small, such as 0.002, x_neg will be -0.003
+        # x_pos will be 0.007, so the numeric gradient is inaccurate.
+        # we should avoid this
+        x[np.abs(x) < 0.005] = 0.02
+        out = np.abs(x)
+
+        self.inputs = {'X': x}
+        self.outputs = {'Out': out}
+
+    def set_npu(self):
+        self.__class__.use_npu = True
+        self.place = paddle.NPUPlace(0)
+
+    def init_dtype(self):
+        self.dtype = np.float32
+
+    def test_check_output(self):
+        self.check_output_with_place(self.place)
+
+    def test_check_grad(self):
+        self.check_grad_with_place(self.place, ['X'], 'Out')
+
+
+# To-do(qili93): numeric_place will use CPUPlace in op_test.py and abs do not have CPUKernel for float16, to be uncommented after numeric_place fixed
+# @unittest.skipIf(not paddle.is_compiled_with_npu(), "core is not compiled with NPU")
+# class TestNPUAbsFP16(TestNPUAbs):
+#     def init_dtype(self):
+#         self.dtype = np.float16
+
+if __name__ == '__main__':
+    unittest.main()
diff --git a/python/paddle/fluid/tests/unittests/npu/test_uniform_random_op_npu.py b/python/paddle/fluid/tests/unittests/npu/test_uniform_random_op_npu.py
new file mode 100644
index 0000000000000000000000000000000000000000..8c37f0a32ac801112476547805cd2ceeb50f2aef
--- /dev/null
+++ b/python/paddle/fluid/tests/unittests/npu/test_uniform_random_op_npu.py
@@ -0,0 +1,112 @@
+#   Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import print_function
+
+import sys
+import subprocess
+import unittest
+import numpy as np
+sys.path.append("..")
+from op_test import OpTest
+import paddle
+import paddle.fluid.core as core
+import paddle
+from paddle.fluid.op import Operator
+import paddle.fluid as fluid
+from paddle.fluid import Program, program_guard
+from test_uniform_random_op import TestUniformRandomOp, TestUniformRandomOpSelectedRows
+
+paddle.enable_static()
+
+
+def output_hist(out):
+    hist, _ = np.histogram(out, range=(-5, 10))
+    hist = hist.astype("float32")
+    hist /= float(out.size)
+    prob = 0.1 * np.ones((10))
+    return hist, prob
+
+
+@unittest.skipIf(not paddle.is_compiled_with_npu(),
+                 "core is not compiled with NPU")
+class TestNPUUniformRandomOp(OpTest):
+    def setUp(self):
+        self.set_npu()
+        self.op_type = "uniform_random"
+        self.init_dtype()
+        self.inputs = {}
+        self.init_attrs()
+        self.outputs = {"Out": np.zeros((1000, 784)).astype(self.dtype)}
+
+    def init_attrs(self):
+        self.attrs = {
+            "shape": [1000, 784],
+            "min": -5.0,
+            "max": 10.0,
+            "seed": 10
+        }
+        self.output_hist = output_hist
+
+    def set_npu(self):
+        self.__class__.use_npu = True
+        self.place = paddle.NPUPlace(0)
+
+    def init_dtype(self):
+        self.dtype = np.float32
+
+    def test_check_output(self):
+        self.check_output_customized(self.verify_output)
+
+    def verify_output(self, outs):
+        hist, prob = self.output_hist(np.array(outs[0]))
+        self.assertTrue(
+            np.allclose(
+                hist, prob, rtol=0, atol=0.01), "hist: " + str(hist))
+
+
+@unittest.skipIf(not paddle.is_compiled_with_npu(),
+                 "core is not compiled with NPU")
+class TestNPUUniformRandomOpSelectedRows(unittest.TestCase):
+    def get_places(self):
+        places = [core.CPUPlace()]
+        if core.is_compiled_with_npu():
+            places.append(core.NPUPlace(0))
+        return places
+
+    def test_check_output(self):
+        for place in self.get_places():
+            self.check_with_place(place)
+
+    def check_with_place(self, place):
+        scope = core.Scope()
+        out = scope.var("X").get_selected_rows()
+        paddle.seed(10)
+        op = Operator(
+            "uniform_random",
+            Out="X",
+            shape=[1000, 784],
+            min=-5.0,
+            max=10.0,
+            seed=10)
+        op.run(scope, place)
+        self.assertEqual(out.get_tensor().shape(), [1000, 784])
+        hist, prob = output_hist(np.array(out.get_tensor()))
+        self.assertTrue(
+            np.allclose(
+                hist, prob, rtol=0, atol=0.01), "hist: " + str(hist))
+
+
+if __name__ == "__main__":
+    unittest.main()
diff --git a/tools/dockerfile/Dockerfile.npu_aarch64 b/tools/dockerfile/Dockerfile.npu_aarch64
new file mode 100644
index 0000000000000000000000000000000000000000..e3cd162edc1547940741bf79e46dd4b8723deeac
--- /dev/null
+++ b/tools/dockerfile/Dockerfile.npu_aarch64
@@ -0,0 +1,176 @@
+# A image for building paddle binaries
+# Use cann 5.0.2.alpha003 and aarch64 for A300t-9000
+# When you modify it, please be aware of cann version
+#
+# Build: CANN 5.0.2.alpha003
+# cd Paddle/tools/dockerfile
+# docker build -f Dockerfile.npu_aarch64  \
+# -t paddlepaddle/paddle:latest-cann5.0.2-gcc82-aarch64-dev .
+#
+# docker run -it --pids-limit 409600 \
+# -v /usr/local/Ascend/driver:/usr/local/Ascend/driver \
+# -v /usr/local/bin/npu-smi:/usr/local/bin/npu-smi \
+# -v /usr/local/dcmi:/usr/local/dcmi \
+# paddlepaddle/paddle:latest-cann5.0.2-gcc82-aarch64-dev /bin/bash
+
+FROM ubuntu:18.04
+MAINTAINER PaddlePaddle Authors <paddle-dev@baidu.com>
+
+RUN apt-get update && apt-get install -y apt-utils
+RUN ln -snf /usr/share/zoneinfo/Asia/Shanghai /etc/localtime
+RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends tzdata
+RUN apt-get update && apt-get install -y software-properties-common && add-apt-repository ppa:deadsnakes/ppa && add-apt-repository ppa:ubuntu-toolchain-r/test
+RUN apt-get update && apt-get install -y curl wget vim git unzip unrar tar xz-utils libssl-dev bzip2 gzip make libgcc-s1 sudo openssh-server \
+            coreutils ntp language-pack-zh-hans python-qt4 libsm6 libxext6 libxrender-dev libgl1-mesa-glx libsqlite3-dev libopenblas-dev \
+            bison graphviz libjpeg-dev zlib1g zlib1g-dev automake locales swig net-tools libtool module-init-tools numactl libnuma-dev \
+            openssl libffi-dev pciutils libblas-dev gfortran libblas3 liblapack-dev liblapack3 default-jre screen tmux gdb lldb gcc g++
+
+# GCC 8.2
+WORKDIR /opt
+RUN wget -q https://paddle-ci.gz.bcebos.com/gcc-8.2.0.tar.xz && \
+    tar -xvf gcc-8.2.0.tar.xz && cd gcc-8.2.0 && \
+    unset LIBRARY_PATH CPATH C_INCLUDE_PATH PKG_CONFIG_PATH CPLUS_INCLUDE_PATH INCLUDE && \
+    ./contrib/download_prerequisites && \
+    cd .. && mkdir temp_gcc82 && cd temp_gcc82 && \
+    ../gcc-8.2.0/configure --prefix=/opt/compiler/gcc-8.2 --enable-threads=posix --disable-checking --disable-multilib && \
+    make -j8 && make install && \
+    cd .. && rm -rf temp_gcc82 && rm -rf gcc-8.2.0* && \
+    cd /usr/lib/aarch64-linux-gnu && \
+    mv libstdc++.so.6 libstdc++.so.6.bak && mv libstdc++.so.6.0.25 libstdc++.so.6.0.25.bak && \
+    ln -s /opt/compiler/gcc-8.2/lib64/libgfortran.so.5 /usr/lib/aarch64-linux-gnu/libstdc++.so.5 && \
+    ln -s /opt/compiler/gcc-8.2/lib64/libstdc++.so.6   /usr/lib/aarch64-linux-gnu/libstdc++.so.6 && \
+    cp /opt/compiler/gcc-8.2/lib64/libstdc++.so.6.0.25 /usr/lib/aarch64-linux-gnu && \
+    cd /usr/bin && mv gcc gcc.bak && mv g++ g++.bak && \
+    ln -s /opt/compiler/gcc-8.2/bin/gcc /usr/bin/gcc && \
+    ln -s /opt/compiler/gcc-8.2/bin/g++ /usr/bin/g++
+ENV PATH=/opt/compiler/gcc-8.2/bin:$PATH
+ENV LD_LIBRARY_PATH=/opt/compiler/gcc-8.2/lib:/opt/compiler/gcc-8.2/lib64:$LD_LIBRARY_PATH
+
+# cmake 3.19
+WORKDIR /opt
+RUN wget -q https://cmake.org/files/v3.19/cmake-3.19.8-Linux-aarch64.tar.gz && \
+    tar -zxvf cmake-3.19.8-Linux-aarch64.tar.gz && rm cmake-3.19.8-Linux-aarch64.tar.gz && \
+    mv cmake-3.19.8-Linux-aarch64 cmake-3.19
+ENV PATH=/opt/cmake-3.19/bin:${PATH}
+
+# conda 4.9.2
+WORKDIR /opt
+ARG CONDA_FILE=Miniconda3-py37_4.9.2-Linux-aarch64.sh
+RUN cd /opt && wget -q https://repo.anaconda.com/miniconda/${CONDA_FILE} && chmod +x ${CONDA_FILE}
+RUN mkdir /opt/conda && ./${CONDA_FILE} -b -f -p "/opt/conda" && rm -rf ${CONDA_FILE}
+ENV PATH=/opt/conda/bin:${PATH}
+RUN conda init bash && conda install -n base jupyter jupyterlab
+
+# install pylint and pre-commit
+RUN /opt/conda/bin/pip install pre-commit pylint pylint pytest astroid isort coverage qtconsole 
+# install CANN 5.0.2 requirement
+RUN /opt/conda/bin/pip install 'numpy<1.20,>=1.13.3' 'decorator>=4.4.0' 'sympy>=1.4' 'cffi>=1.12.3' 'protobuf>=3.11.3'
+RUN /opt/conda/bin/pip install attrs pyyaml pathlib2 scipy requests psutil
+
+# install Paddle requirement
+RUN wget https://raw.githubusercontent.com/PaddlePaddle/Paddle/develop/python/requirements.txt -O /root/requirements.txt
+RUN /opt/conda/bin/pip install -r /root/requirements.txt && rm -rf /root/requirements.txt
+RUN wget https://raw.githubusercontent.com/PaddlePaddle/Paddle/develop/python/unittest_py/requirements.txt -O /root/requirements.txt
+RUN /opt/conda/bin/pip install -r /root/requirements.txt && rm -rf /root/requirements.txt
+
+# Install Go and glide
+RUN wget -qO- https://golang.org/dl/go1.16.5.linux-arm64.tar.gz | \
+    tar -xz -C /usr/local && \
+    mkdir /root/gopath && \
+    mkdir /root/gopath/bin && \
+    mkdir /root/gopath/src
+ENV GOROOT=/usr/local/go GOPATH=/root/gopath
+# should not be in the same line with GOROOT definition, otherwise docker build could not find GOROOT.
+ENV PATH=${PATH}:${GOROOT}/bin:${GOPATH}/bin
+# install glide
+RUN curl -s -q https://glide.sh/get | sh
+
+# git credential to skip password typing
+RUN git config --global credential.helper store
+
+# Fix locales to en_US.UTF-8
+RUN localedef -i en_US -f UTF-8 en_US.UTF-8
+
+RUN apt-get install libprotobuf-dev -y
+
+# Older versions of patchelf limited the size of the files being processed and were fixed in this pr.
+# https://github.com/NixOS/patchelf/commit/ba2695a8110abbc8cc6baf0eea819922ee5007fa
+# So install a newer version here.
+RUN wget -q http://ports.ubuntu.com/pool/universe/p/patchelf/patchelf_0.10-2build1_arm64.deb && \
+    dpkg -i patchelf_0.10-2build1_arm64.deb && rm -rf patchelf_0.10-2build1_arm64.deb
+
+# Configure OpenSSH server. c.f. https://docs.docker.com/engine/examples/running_ssh_service
+RUN mkdir /var/run/sshd && echo 'root:root' | chpasswd && sed -ri 's/^PermitRootLogin\s+.*/PermitRootLogin yes/' /etc/ssh/sshd_config && sed -ri 's/UsePAM yes/#UsePAM yes/g' /etc/ssh/sshd_config
+CMD source ~/.bashrc
+
+# ccache 3.7.9
+RUN wget https://paddle-ci.gz.bcebos.com/ccache-3.7.9.tar.gz && \
+    tar xf ccache-3.7.9.tar.gz && mkdir /usr/local/ccache-3.7.9 && cd ccache-3.7.9 && \
+    ./configure -prefix=/usr/local/ccache-3.7.9 && \
+    make -j8 && make install && cd .. && rm -rf ccache-3.7.9* && \
+    ln -s /usr/local/ccache-3.7.9/bin/ccache /usr/local/bin/ccache
+
+# clang-form 3.8.0
+RUN wget https://releases.llvm.org/3.8.0/clang+llvm-3.8.0-aarch64-linux-gnu.tar.xz && \ 
+    tar xf clang+llvm-3.8.0-aarch64-linux-gnu.tar.xz && cd clang+llvm-3.8.0-aarch64-linux-gnu && \
+    cp -r * /usr/local && cd .. && rm -rf clang+llvm-3.8.0-aarch64-linux-gnu*
+
+# HwHiAiUser
+RUN groupadd HwHiAiUser && \
+    useradd -g HwHiAiUser -m -d /home/HwHiAiUser HwHiAiUser
+
+# copy /etc/ascend_install.info to current dir fist
+COPY ascend_install.info /etc/ascend_install.info
+
+# copy /usr/local/Ascend/driver/version.info to current dir fist
+RUN mkdir -p /usr/local/Ascend/driver
+COPY version.info /usr/local/Ascend/driver/version.info
+
+# Packages from https://www.hiascend.com/software/cann/community
+WORKDIR /usr/local/Ascend
+# update envs for driver
+ENV LD_LIBRARY_PATH=/usr/local/Ascend/driver/lib64:$LD_LIBRARY_PATH
+ENV LD_LIBRARY_PATH=/usr/local/Ascend/driver/lib64/common:$LD_LIBRARY_PATH
+ENV LD_LIBRARY_PATH=/usr/local/Ascend/driver/lib64/driver:$LD_LIBRARY_PATH
+
+# Install Ascend toolkit
+COPY Ascend-cann-toolkit_5.0.2.alpha003_linux-aarch64.run /usr/local/Ascend/
+RUN ./Ascend-cann-toolkit_5.0.2.alpha003_linux-aarch64.run --install --quiet
+RUN rm -rf Ascend-cann-toolkit_5.0.2.alpha003_linux-aarch64.run
+# udpate envs for model transformation and operator develop
+ENV PATH=/usr/local/Ascend/ascend-toolkit/latest/atc/bin:$PATH
+ENV LD_LIBRARY_PATH=/usr/local/Ascend/ascend-toolkit/latest/atc/lib64:$LD_LIBRARY_PATH
+ENV PYTHONPATH=/usr/local/Ascend/ascend-toolkit/latest/pyACL/python/site-packages/acl:$PYTHONPATH
+ENV PYTHONPATH=/usr/local/Ascend/ascend-toolkit/latest/atc/python/site-packages:$PYTHONPATH
+ENV PYTHONPATH=/usr/local/Ascend/ascend-toolkit/latest/toolkit/python/site-packages:$PYTHONPATH
+ENV TOOLCHAIN_HOME=/usr/local/Ascend/ascend-toolkit/latest/toolkit
+
+# Install Ascend NNAE
+COPY Ascend-cann-nnae_5.0.2.alpha003_linux-aarch64.run /usr/local/Ascend/
+RUN ./Ascend-cann-nnae_5.0.2.alpha003_linux-aarch64.run --install --quiet
+RUN rm -rf Ascend-cann-nnae_5.0.2.alpha003_linux-aarch64.run
+
+# update envs for third party AI framework develop
+ENV PATH=/usr/local/Ascend/nnae/latest/fwkacllib/bin:$PATH
+ENV PATH=/usr/local/Ascend/nnae/latest/fwkacllib/ccec_compiler/bin:$PATH
+ENV LD_LIBRARY_PATH=/usr/local/Ascend/nnae/latest/fwkacllib/lib64:$LD_LIBRARY_PATH
+ENV PYTHONPATH=/usr/local/Ascend/nnae/latest/fwkacllib/python/site-packages:$PYTHONPATH
+ENV ASCEND_AICPU_PATH=/usr/local/Ascend/nnae/latest
+ENV ASCEND_OPP_PATH=/usr/local/Ascend/nnae/latest/opp
+
+# DEV image should open error level log
+# 0 debug; 1 info; 2 warning; 3 error; 4 null
+ENV ASCEND_GLOBAL_LOG_LEVEL=3
+RUN rm -rf /usr/local/Ascend/driver
+
+# Create /lib64/ld-linux-aarch64.so.1
+RUN umask 0022 && \
+    if [ ! -d "/lib64" ]; \
+    then \
+        mkdir /lib64 && ln -sf /lib/ld-linux-aarch64.so.1 /lib64/ld-linux-aarch64.so.1; \
+    fi
+
+# Clean
+RUN apt-get clean -y
+
+EXPOSE 22