diff --git a/cmake/neuware.cmake b/cmake/neuware.cmake index 811c8d664a097ebc4f6feb83aab50c21a8498011..a371a0032d991afd397ff6ff1018733a3b9ad7f4 100644 --- a/cmake/neuware.cmake +++ b/cmake/neuware.cmake @@ -17,13 +17,16 @@ INCLUDE_DIRECTORIES(${NEUWARE_INCLUDE_DIR}) set(CNNL_LIB ${NEUWARE_LIB_DIR}/libcnnl.so) set(CNRT_LIB ${NEUWARE_LIB_DIR}/libcnrt.so) set(CNDRV_LIB ${NEUWARE_LIB_DIR}/libcndrv.so) +set(CNPAPI_LIB ${NEUWARE_LIB_DIR}/libcnpapi.so) generate_dummy_static_lib(LIB_NAME "neuware_lib" GENERATOR "neuware.cmake") +set(NEUWARE_LIB_DEPS ${CNNL_LIB} ${CNRT_LIB} ${CNDRV_LIB} ${CNPAPI_LIB}) + if(WITH_CNCL) MESSAGE(STATUS "Compile with CNCL!") ADD_DEFINITIONS(-DPADDLE_WITH_CNCL) set(CNCL_LIB ${NEUWARE_LIB_DIR}/libcncl.so) - TARGET_LINK_LIBRARIES(neuware_lib ${CNCL_LIB} ${CNNL_LIB} ${CNRT_LIB} ${CNDRV_LIB}) -else() - TARGET_LINK_LIBRARIES(neuware_lib ${CNNL_LIB} ${CNRT_LIB} ${CNDRV_LIB}) + list(APPEND NEUWARE_LIB_DEPS ${CNCL_LIB}) endif() + +TARGET_LINK_LIBRARIES(neuware_lib ${NEUWARE_LIB_DEPS}) diff --git a/paddle/fluid/operators/gather_op_mlu.cc b/paddle/fluid/operators/gather_op_mlu.cc new file mode 100644 index 0000000000000000000000000000000000000000..220d045952643fb97fc6d7807fd2c4c79d628320 --- /dev/null +++ b/paddle/fluid/operators/gather_op_mlu.cc @@ -0,0 +1,75 @@ +/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/framework/tensor_util.h" +#include "paddle/fluid/operators/mlu/mlu_baseop.h" + +namespace paddle { +namespace operators { + +template +class GatherOpMLUKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext &ctx) const override { + auto *x = ctx.Input("X"); + auto *index = ctx.Input("Index"); + auto axis = ctx.Attr("axis"); + + auto *out = ctx.Output("Out"); + out->mutable_data(ctx.GetPlace()); + + MLUCnnlTensorDesc x_desc(*x); + MLUCnnlTensorDesc index_desc(*index); + MLUCnnlTensorDesc out_desc(*out); + MLUCnnl::GatherFunctor(ctx, axis, 0 /*batch_dims*/, x_desc.get(), + GetBasePtr(x), index_desc.get(), GetBasePtr(index), + out_desc.get(), GetBasePtr(out)); + } +}; + +template +class GatherGradOpMLUKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext &ctx) const override { + auto *index = ctx.Input("Index"); + auto *dout = ctx.Input(framework::GradVarName("Out")); + auto *dx = ctx.Output(framework::GradVarName("X")); + dx->mutable_data(ctx.GetPlace()); + + MLUCnnlTensorDesc dx_desc(*dx); + auto value = static_cast(0); + MLUCnnl::Fill(ctx, CNNL_POINTER_MODE_HOST, &value, dx_desc.get(), + GetBasePtr(dx)); + + MLUCnnlTensorDesc index_desc(*index); + MLUCnnlTensorDesc dout_desc(*dout); + const cnnlScatterRefMode_t mode = CNNL_SCATTERREF_UPDATE; + MLUCnnl::ScatterFunctor(ctx, dx_desc.get(), GetBasePtr(dx), dout_desc.get(), + GetBasePtr(dout), index_desc.get(), + GetBasePtr(index), mode); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +REGISTER_OP_MLU_KERNEL(gather, ops::GatherOpMLUKernel, + ops::GatherOpMLUKernel, + ops::GatherOpMLUKernel); + +REGISTER_OP_MLU_KERNEL(gather_grad, ops::GatherGradOpMLUKernel, + ops::GatherGradOpMLUKernel, + ops::GatherGradOpMLUKernel); diff --git a/paddle/fluid/operators/mlu/mlu_baseop.cc b/paddle/fluid/operators/mlu/mlu_baseop.cc index ecde4db3f334eb1e6bc9d65094c2b130d40c7b92..793aa2644b54810ea0b27b6880cb76d5ac71811d 100644 --- a/paddle/fluid/operators/mlu/mlu_baseop.cc +++ b/paddle/fluid/operators/mlu/mlu_baseop.cc @@ -934,9 +934,8 @@ MLUCnnlTrigonDesc::~MLUCnnlTrigonDesc() { beta_ptr = static_cast(&beta_int); } - PADDLE_ENFORCE_MLU_SUCCESS(cnnlGetOpTensorWorkspaceSize_v2( - handle, op_tensor_desc, alpha1_ptr, a_desc, a, alpha2_ptr, b_desc, b, - beta_ptr, output_desc, output, &workspace_size)); + PADDLE_ENFORCE_MLU_SUCCESS(cnnlGetOpTensorWorkspaceSize( + handle, a_desc, b_desc, output_desc, &workspace_size)); auto& dev_ctx = GetDevCtxFromCTX(ctx); Tensor workspace = ctx.AllocateTmpTensor( diff --git a/paddle/fluid/operators/optimizers/merged_momentum_op_mlu.cc b/paddle/fluid/operators/optimizers/merged_momentum_op_mlu.cc index b84a2bc579d3e7b9a1c9d594f9316c2ff38aff72..54ead6d3df7f056d1da41661348a30bf260dce49 100644 --- a/paddle/fluid/operators/optimizers/merged_momentum_op_mlu.cc +++ b/paddle/fluid/operators/optimizers/merged_momentum_op_mlu.cc @@ -118,11 +118,11 @@ class MLUMergedMomentumOpKernel : public framework::OpKernel { GetBasePtr(&mu_tensor)); for (size_t idx = 0; idx < n; ++idx) { - RegularizationType regularization_flag = + phi::RegularizationType regularization_flag = regularization_methods.size() > 0 && regularization_methods[idx] == "l2_decay" - ? RegularizationType::kL2DECAY - : RegularizationType::kNONE; + ? phi::RegularizationType::kL2DECAY + : phi::RegularizationType::kNONE; T regularization_coeff = static_cast(0.0); if (regularization_coeffs.size() != 0) { regularization_coeff = static_cast(regularization_coeffs[idx]); @@ -135,7 +135,7 @@ class MLUMergedMomentumOpKernel : public framework::OpKernel { auto grad = grads[idx]; Tensor regularized_grad; MLUCnnlTensorDesc param_desc(*param_out); - if (regularization_flag == RegularizationType::kL2DECAY) { + if (regularization_flag == phi::RegularizationType::kL2DECAY) { regularized_grad = ctx.AllocateTmpTensor( param_out->dims(), dev_ctx); MLUCnnlOpTensorDesc op_tensor_desc( diff --git a/paddle/fluid/operators/optimizers/momentum_op_mlu.cc b/paddle/fluid/operators/optimizers/momentum_op_mlu.cc index 71af14fd91c8c5f75f469840581052bdc068b2bd..b8fa81b2e71237f30a80ff81709e93e8a53e0951 100644 --- a/paddle/fluid/operators/optimizers/momentum_op_mlu.cc +++ b/paddle/fluid/operators/optimizers/momentum_op_mlu.cc @@ -14,6 +14,7 @@ limitations under the License. */ #include "paddle/fluid/operators/optimizers/momentum_op.h" #include "paddle/fluid/operators/mlu/mlu_baseop.h" +#include "paddle/phi/kernels/impl/momentum_kernel_impl.h" namespace paddle { namespace operators { @@ -27,10 +28,10 @@ class MLUMomentumOpKernel : public framework::OpKernel { std::string regularization_method = ctx.Attr("regularization_method"); auto regularization_coeff = ctx.Attr("regularization_coeff"); - RegularizationType regularization_flag{ - RegularizationType::kNONE}; // disable regularization + phi::RegularizationType regularization_flag{ + phi::RegularizationType::kNONE}; // disable regularization if (regularization_method == "l2_decay") { - regularization_flag = RegularizationType::kL2DECAY; + regularization_flag = phi::RegularizationType::kL2DECAY; } T mu = static_cast(ctx.Attr("mu")); @@ -57,7 +58,7 @@ class MLUMomentumOpKernel : public framework::OpKernel { Tensor regularized_grad; MLUCnnlTensorDesc param_desc(*param); - if (regularization_flag == RegularizationType::kL2DECAY) { + if (regularization_flag == phi::RegularizationType::kL2DECAY) { regularized_grad = ctx.AllocateTmpTensor(param->dims(), dev_ctx); MLUCnnlOpTensorDesc op_tensor_desc( diff --git a/paddle/fluid/platform/profiler_helper.h b/paddle/fluid/platform/profiler_helper.h index c9e6f13f50524c64d3691a822295b6938ad81493..24c515f5b495682da998caf152bfdec4f38a4ad6 100644 --- a/paddle/fluid/platform/profiler_helper.h +++ b/paddle/fluid/platform/profiler_helper.h @@ -34,6 +34,10 @@ limitations under the License. */ #ifdef PADDLE_WITH_HIP #include #endif +#ifdef PADDLE_WITH_MLU +#include "paddle/fluid/platform/device/mlu/enforce.h" +#include "paddle/fluid/platform/device/mlu/mlu_info.h" +#endif #include "paddle/fluid/memory/memory.h" #include "paddle/fluid/platform/device/gpu/gpu_info.h" @@ -135,6 +139,13 @@ void SynchronizeAllDevice() { PADDLE_ENFORCE_GPU_SUCCESS(hipDeviceSynchronize()); } #endif +#ifdef PADDLE_WITH_MLU + int count = GetMLUDeviceCount(); + for (int i = 0; i < count; i++) { + SetMLUDeviceId(i); + PADDLE_ENFORCE_MLU_SUCCESS(cnrtSyncDevice()); + } +#endif } static double ToMegaBytes(size_t bytes) { diff --git a/python/paddle/fluid/tests/unittests/mlu/test_gather_op_mlu.py b/python/paddle/fluid/tests/unittests/mlu/test_gather_op_mlu.py new file mode 100644 index 0000000000000000000000000000000000000000..f0aff986fa1ff1b73d6905083dba79472f6f066e --- /dev/null +++ b/python/paddle/fluid/tests/unittests/mlu/test_gather_op_mlu.py @@ -0,0 +1,179 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +import unittest +import numpy as np +import sys +sys.path.append('..') +from op_test import OpTest, convert_float_to_uint16 +import paddle +import paddle.fluid as fluid +from paddle.framework import core +from paddle.fluid.dygraph.base import switch_to_static_graph + +paddle.enable_static() + + +def gather_numpy(x, index, axis): + x_transpose = np.swapaxes(x, 0, axis) + tmp_gather = x_transpose[index, ...] + gather = np.swapaxes(tmp_gather, 0, axis) + return gather + + +class TestGatherOp(OpTest): + def setUp(self): + self.op_type = "gather" + self.place = paddle.MLUPlace(0) + self.__class__.use_mlu = True + self.python_api = paddle.gather + self.config() + xnp = np.random.random(self.x_shape).astype(self.x_type) + self.inputs = { + 'X': xnp, + 'Index': np.array(self.index).astype(self.index_type) + } + self.outputs = {'Out': self.inputs["X"][self.inputs["Index"]]} + + def test_check_output(self): + self.check_output_with_place(self.place) + + def test_check_grad(self): + self.check_grad_with_place(self.place, ['X'], 'Out') + + def config(self): + """ + For multi-dimension input + """ + self.x_shape = (10, 20) + self.x_type = "float32" + self.index = [1, 3, 5] + self.index_type = "int32" + + +class TestCase1(TestGatherOp): + def config(self): + """ + For one dimension input + """ + self.x_shape = (100) + self.x_type = "float32" + self.index = [1, 3, 5] + self.index_type = "int32" + + +class TestCase2(TestGatherOp): + def config(self): + """ + For int64_t index type + """ + self.x_shape = (100) + self.x_type = "float32" + self.index = [1, 3, 5] + self.index_type = "int64" + + +class API_TestDygraphGather(unittest.TestCase): + def test_out1(self): + paddle.disable_static() + input_1 = np.array([[1, 2], [3, 4], [5, 6]]).astype('int32') + index_1 = np.array([1, 2]) + input = paddle.to_tensor(input_1) + index = paddle.to_tensor(index_1) + output = paddle.fluid.layers.gather(input, index) + output_np = output.numpy() + expected_output = np.array([[3, 4], [5, 6]]).astype('int32') + self.assertTrue(np.allclose(output_np, expected_output)) + paddle.enable_static() + + def test_out12(self): + paddle.disable_static() + input_1 = np.array([[1, 2], [3, 4], [5, 6]]).astype('int32') + index_1 = np.array([1, 2]) + x = paddle.to_tensor(input_1) + index = paddle.to_tensor(index_1) + output = paddle.gather(x, index, axis=0) + output_np = output.numpy() + expected_output = gather_numpy(input_1, index_1, axis=0) + self.assertTrue(np.allclose(output_np, expected_output)) + paddle.enable_static() + + def test_zero_index(self): + paddle.disable_static() + x = paddle.to_tensor([[1, 2], [3, 4]]).astype('int32') + index = paddle.to_tensor(np.array([]).astype('int64')) + for axis in range(len(x.shape)): + out = paddle.gather(x, index, axis) + expected_shape = list(x.shape) + expected_shape[axis] = 0 + self.assertEqual(list(out.shape), expected_shape) + paddle.enable_static() + + +class TestGathertError(unittest.TestCase): + def test_error1(self): + with paddle.static.program_guard(paddle.static.Program(), + paddle.static.Program()): + + shape = [8, 9, 6] + x = paddle.fluid.data(shape=shape, dtype='int8', name='x') + axis = paddle.fluid.data(shape=[1], dtype='float32', name='axis') + index = paddle.fluid.data(shape=shape, dtype='int32', name='index') + index_float = paddle.fluid.data( + shape=shape, dtype='float32', name='index_float') + + def test_x_type(): + paddle.gather(x, index) + + self.assertRaises(TypeError, test_x_type) + + def test_index_type(): + paddle.gather(x, index_float) + + self.assertRaises(TypeError, test_index_type) + + def test_axis_dtype(): + paddle.gather(x, index, axis=1.11) + + self.assertRaises(TypeError, test_axis_dtype) + + def test_axis_dtype1(): + paddle.gather(x, index, axis=axis) + + self.assertRaises(TypeError, test_axis_dtype1) + + def test_error2(self): + with fluid.program_guard(fluid.Program(), fluid.Program()): + + shape = [8, 9, 6] + x = fluid.data(shape=shape, dtype='int8', name='x') + index = fluid.data(shape=shape, dtype='int32', name='mask') + index_float = fluid.data( + shape=shape, dtype='float32', name='index_float') + + def test_x_type(): + paddle.fluid.layers.gather(x, index) + + self.assertRaises(TypeError, test_x_type) + + def test_index_type(): + paddle.fluid.layers.gather(x, index_float) + + self.assertRaises(TypeError, test_index_type) + + +if __name__ == "__main__": + unittest.main() diff --git a/python/paddle/tensor/creation.py b/python/paddle/tensor/creation.py index f4f1e7a3d5067109f20ac50f672e151f327f898e..f2dc16071c2c8a85497d2478b7264cfa55c84df9 100644 --- a/python/paddle/tensor/creation.py +++ b/python/paddle/tensor/creation.py @@ -216,9 +216,9 @@ def to_tensor(data, dtype=None, place=None, stop_gradient=True): place = _current_expected_place() elif not isinstance(place, (core.Place, core.CPUPlace, core.CUDAPinnedPlace, core.CUDAPlace, core.NPUPlace, core.XPUPlace, - core.CustomPlace)): + core.MLUPlace, core.CustomPlace)): raise ValueError( - "'place' must be any of paddle.Place, paddle.CPUPlace, paddle.CUDAPinnedPlace, paddle.CUDAPlace, paddle.NPUPlace, paddle.XPUPlace, paddle.CustomPlace" + "'place' must be any of paddle.Place, paddle.CPUPlace, paddle.CUDAPinnedPlace, paddle.CUDAPlace, paddle.NPUPlace, paddle.XPUPlace, paddle.MLUPlace, paddle.CustomPlace" ) if not isinstance(data, np.ndarray):