diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/oneslike_gpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/oneslike_gpu_kernel.cc
new file mode 100644
index 0000000000000000000000000000000000000000..1400f623b7efe34761344bcd145b6e2b07b46779
--- /dev/null
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/oneslike_gpu_kernel.cc
@@ -0,0 +1,26 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "backend/kernel_compiler/gpu/arrays/oneslike_gpu_kernel.h"
+namespace mindspore {
+namespace kernel {
+MS_REG_GPU_KERNEL_ONE(OnesLike, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
+                      OnesLikeGpuKernel, float)
+MS_REG_GPU_KERNEL_ONE(OnesLike, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
+                      OnesLikeGpuKernel, half)
+MS_REG_GPU_KERNEL_ONE(OnesLike, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
+                      OnesLikeGpuKernel, int)
+}  // namespace kernel
+}  // namespace mindspore
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/oneslike_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/oneslike_gpu_kernel.h
new file mode 100644
index 0000000000000000000000000000000000000000..355ed9e05f1a880199c81a34115e7dc692bc2151
--- /dev/null
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/oneslike_gpu_kernel.h
@@ -0,0 +1,85 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef MINDSPORE_CCSRC_KERNEL_GPU_ONESLIKE_H_
+#define MINDSPORE_CCSRC_KERNEL_GPU_ONESLIKE_H_
+
+#include <vector>
+#include "backend/kernel_compiler/gpu/gpu_kernel.h"
+#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h"
+#include "backend/kernel_compiler/gpu/cuda_impl/oneslike_impl.cuh"
+namespace mindspore {
+namespace kernel {
+template <typename T>
+class OnesLikeGpuKernel : public GpuKernel {
+ public:
+  OnesLikeGpuKernel() : input_size_(0), output_size_(0) {}
+  ~OnesLikeGpuKernel() override = default;
+
+  const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; }
+  const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; }
+  const std::vector<size_t> &GetWorkspaceSizeList() const override { return workspace_size_list_; }
+
+  bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
+              const std::vector<AddressPtr> &outputs, void *stream_ptr) override {
+    T *input = GetDeviceAddress<T>(inputs, 0);
+    T *output = GetDeviceAddress<T>(outputs, 0);
+    int size = SizeToInt(input_size_ / sizeof(T));
+    CalOnesLike(size, input, output, reinterpret_cast<cudaStream_t>(stream_ptr));
+    return true;
+  }
+
+  bool Init(const CNodePtr &kernel_node) override {
+    size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
+    if (input_num != 1) {
+      MS_LOG(ERROR) << "Input number is " << input_num << ", but oneslike needs 1 input.";
+      return false;
+    }
+    size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node);
+    if (output_num != 1) {
+      MS_LOG(ERROR) << "Output number is " << output_num << ", but oneslike needs 1 output.";
+      return false;
+    }
+    auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
+    size_t shape_size = input_shape.size();
+
+    input_size_ = 1;
+    for (size_t i = 0; i < shape_size; i++) {
+      input_size_ *= input_shape[i];
+    }
+    input_size_ *= sizeof(T);
+    output_size_ = input_size_;
+    InitSizeLists();
+    return true;
+  }
+
+ protected:
+  void InitSizeLists() override {
+    input_size_list_.push_back(input_size_);
+    output_size_list_.push_back(output_size_);
+    return;
+  }
+
+ private:
+  std::vector<size_t> input_size_list_;
+  std::vector<size_t> output_size_list_;
+  std::vector<size_t> workspace_size_list_;
+  size_t input_size_;
+  size_t output_size_;
+};
+}  // namespace kernel
+}  // namespace mindspore
+
+#endif  // MINDSPORE_CCSRC_KERNEL_GPU_ONESLIKE_H_
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/oneslike_impl.cu b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/oneslike_impl.cu
new file mode 100644
index 0000000000000000000000000000000000000000..dc1d9cd206e179dfe65cc8e1cadc51cfd9fa1656
--- /dev/null
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/oneslike_impl.cu
@@ -0,0 +1,37 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include <cuda_runtime.h>
+#include "oneslike_impl.cuh"
+#include "runtime/device/gpu/cuda_common.h"
+template <typename T>
+__global__ void OnesLike(const int size, const T* input,  T* output) {
+  int one = 1;
+  T val = static_cast<T>(one);
+  for (int pos = blockIdx.x * blockDim.x + threadIdx.x; pos < size; pos += blockDim.x * gridDim.x) {
+    output[pos] = val;
+  }
+  return;
+}
+template <typename T>
+void CalOnesLike(const int size, const T* input, T* output, cudaStream_t cuda_stream) {
+  OnesLike<<<GET_BLOCKS(size), GET_THREADS, 0, cuda_stream>>>(size, input, output);
+  return;
+}
+
+template void CalOnesLike<float>(const int size, const float* input, float* output, cudaStream_t cuda_stream);
+template void CalOnesLike<half>(const int size, const half* input, half* output, cudaStream_t cuda_stream);
+template void CalOnesLike<int>(const int size, const int* input, int* output, cudaStream_t cuda_stream);
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/oneslike_impl.cuh b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/oneslike_impl.cuh
new file mode 100644
index 0000000000000000000000000000000000000000..81e92c1d099f78282c67e8c61db2a86db1c10324
--- /dev/null
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/oneslike_impl.cuh
@@ -0,0 +1,23 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_ONESLIKE_H_
+#define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_ONESLIKE_H_
+
+template <typename T>
+void CalOnesLike(const int size, const T* input, T* output, cudaStream_t cuda_stream);
+
+#endif  // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_ONESLIKE_H_
diff --git a/tests/st/ops/gpu/test_oneslike_op.py b/tests/st/ops/gpu/test_oneslike_op.py
new file mode 100644
index 0000000000000000000000000000000000000000..e721d157295974ca4333e971b8a575119c830a74
--- /dev/null
+++ b/tests/st/ops/gpu/test_oneslike_op.py
@@ -0,0 +1,85 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+import numpy as np
+import pytest
+
+import mindspore.context as context
+import mindspore.nn as nn
+from mindspore import Tensor
+from mindspore.ops import operations as P
+
+context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")
+
+
+class NetOnesLike(nn.Cell):
+    def __init__(self):
+        super(NetOnesLike, self).__init__()
+        self.ones_like = P.OnesLike()
+
+    def construct(self, x):
+        return self.ones_like(x)
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_OnesLike():
+    x0_np = np.random.uniform(-2, 2, (2, 3, 4, 4)).astype(np.float32)
+    x1_np = np.random.uniform(-2, 2, 1).astype(np.float16)
+    x2_np = np.zeros([3, 3, 3], dtype=np.int32)
+
+    x0 = Tensor(x0_np)
+    x1 = Tensor(x1_np)
+    x2 = Tensor(x2_np)
+
+    context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")
+    ones_like = NetOnesLike()
+    output0 = ones_like(x0)
+    expect0 = np.ones_like(x0_np)
+    diff0 = output0.asnumpy() - expect0
+    error0 = np.ones(shape=expect0.shape) * 1.0e-5
+    assert np.all(diff0 < error0)
+    assert output0.shape == expect0.shape
+
+    output1 = ones_like(x1)
+    expect1 = np.ones_like(x1_np)
+    diff1 = output1.asnumpy() - expect1
+    error1 = np.ones(shape=expect1.shape) * 1.0e-5
+    assert np.all(diff1 < error1)
+    assert output1.shape == expect1.shape
+
+    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
+    ones_like = NetOnesLike()
+    output0 = ones_like(x0)
+    expect0 = np.ones_like(x0_np)
+    diff0 = output0.asnumpy() - expect0
+    error0 = np.ones(shape=expect0.shape) * 1.0e-5
+    assert np.all(diff0 < error0)
+    assert output0.shape == expect0.shape
+
+    output1 = ones_like(x1)
+    expect1 = np.ones_like(x1_np)
+    diff1 = output1.asnumpy() - expect1
+    error1 = np.ones(shape=expect1.shape) * 1.0e-5
+    assert np.all(diff1 < error1)
+    assert output1.shape == expect1.shape
+
+    output2 = ones_like(x2)
+    expect2 = np.ones_like(x2_np)
+    diff2 = output2.asnumpy() - expect2
+    error2 = np.ones(shape=expect2.shape) * 1.0e-5
+    assert np.all(diff2 < error2)
+    assert output2.shape == expect2.shape