diff --git a/paddle/fluid/framework/tensor_util.cc b/paddle/fluid/framework/tensor_util.cc index 5fd581220097b8a690546e7b6a6e7d01a9ba490b..724e3cc1e2ee80441a533dbd35d2acaf7afbda7b 100644 --- a/paddle/fluid/framework/tensor_util.cc +++ b/paddle/fluid/framework/tensor_util.cc @@ -396,7 +396,8 @@ void TensorCopyImpl(const TENSOR& src, const platform::Place& dst_place, TENSOR* dst) { platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance(); const platform::DeviceContext* dev_ctx; - if (platform::is_gpu_place(dst_place) || platform::is_npu_place(dst_place)) { + if (platform::is_gpu_place(dst_place) || platform::is_npu_place(dst_place) || + platform::is_mlu_place(dst_place)) { dev_ctx = pool.Get(dst_place); } else { dev_ctx = pool.Get(src.place()); @@ -1048,6 +1049,29 @@ void TensorToStream(std::ostream& os, const Tensor& tensor, #else PADDLE_THROW(platform::errors::Unimplemented( "XPUPlace is not supported when not compiled with XPU")); +#endif + } else if (platform::is_mlu_place(tensor.place())) { +#ifdef PADDLE_WITH_MLU + constexpr size_t kBufSize = 1024 * 1024 * 64; // 64MB + std::unique_ptr buf(new char[kBufSize]); + auto& mlu_dev_ctx = + static_cast(dev_ctx); + platform::CPUPlace cpu; + uintptr_t data = reinterpret_cast(data_ptr); + while (size != 0) { + size_t size_to_write = std::min(kBufSize, static_cast(size)); + memory::Copy(cpu, buf.get(), + BOOST_GET_CONST(platform::MLUPlace, tensor.place()), + reinterpret_cast(data), size_to_write, + mlu_dev_ctx.stream()); + mlu_dev_ctx.Wait(); + os.write(buf.get(), size_to_write); + data += size_to_write; + size -= size_to_write; + } +#else + PADDLE_THROW(platform::errors::Unimplemented( + "MLUPlace is not supported when not compiled with MLU")); #endif } else if (platform::is_npu_place(tensor.place())) { #ifdef PADDLE_WITH_ASCEND_CL @@ -1127,9 +1151,11 @@ void TensorFromStream(std::istream& is, Tensor* tensor, size_t size = tensor->numel() * framework::SizeOfType(desc.data_type()); if (platform::is_gpu_place(dev_ctx.GetPlace()) || platform::is_xpu_place(dev_ctx.GetPlace()) || + platform::is_mlu_place(dev_ctx.GetPlace()) || platform::is_npu_place(dev_ctx.GetPlace())) { #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) || \ - defined(PADDLE_WITH_XPU) || defined(PADDLE_WITH_ASCEND_CL) + defined(PADDLE_WITH_XPU) || defined(PADDLE_WITH_MLU) || \ + defined(PADDLE_WITH_ASCEND_CL) Tensor cpu_tensor; cpu_tensor.Resize(framework::make_ddim(shape)); framework::VisitDataType( @@ -1148,6 +1174,9 @@ void TensorFromStream(std::istream& is, Tensor* tensor, } else if (platform::is_xpu_place(dev_ctx.GetPlace())) { PADDLE_THROW(platform::errors::Unimplemented( "XPUPlace is not supported when not compiled with XPU")); + } else if (platform::is_mlu_place(dev_ctx.GetPlace())) { + PADDLE_THROW(platform::errors::Unimplemented( + "MLUPlace is not supported when not compiled with MLU")); } else { PADDLE_THROW(platform::errors::Unimplemented( "NPUPlace is not supported when not compiled with NPU")); @@ -1192,9 +1221,11 @@ void TensorFromStream(std::istream& is, Tensor* tensor, size_t size = tensor->numel() * framework::SizeOfType(desc.data_type()); if (platform::is_gpu_place(dev_ctx.GetPlace()) || platform::is_xpu_place(dev_ctx.GetPlace()) || + platform::is_mlu_place(dev_ctx.GetPlace()) || platform::is_npu_place(dev_ctx.GetPlace())) { #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) || \ - defined(PADDLE_WITH_XPU) || defined(PADDLE_WITH_ASCEND_CL) + defined(PADDLE_WITH_XPU) || defined(PADDLE_WITH_MLU) || \ + defined(PADDLE_WITH_ASCEND_CL) Tensor cpu_tensor; cpu_tensor.Resize(framework::make_ddim(dims)); framework::VisitDataType( @@ -1213,6 +1244,9 @@ void TensorFromStream(std::istream& is, Tensor* tensor, } else if (platform::is_xpu_place(dev_ctx.GetPlace())) { PADDLE_THROW(platform::errors::Unimplemented( "XPUPlace is not supported when not compiled with XPU")); + } else if (platform::is_mlu_place(dev_ctx.GetPlace())) { + PADDLE_THROW(platform::errors::Unimplemented( + "MLUPlace is not supported when not compiled with MLU")); } else { PADDLE_THROW(platform::errors::Unimplemented( "NPUPlace is not supported when not compiled with NPU")); diff --git a/paddle/fluid/memory/detail/buddy_allocator.cc b/paddle/fluid/memory/detail/buddy_allocator.cc index 96fcd6254d88589c728f26eb7f187d7290f4442b..b02fb6642be3fd4ade7dc1b4ed7642be28cc7757 100644 --- a/paddle/fluid/memory/detail/buddy_allocator.cc +++ b/paddle/fluid/memory/detail/buddy_allocator.cc @@ -231,9 +231,8 @@ BuddyAllocator::PoolSet::iterator BuddyAllocator::RefillPool( allocate_bytes = DeviceAllocateSize(&platform::NPUInitAllocSize, &platform::NPUReallocSize, request_bytes); #elif defined(PADDLE_WITH_MLU) - allocate_bytes = - DeviceAllocateSize(&platform::MLUInitAllocSize(), - &platform::MLUReallocSize(), request_bytes); + allocate_bytes = DeviceAllocateSize(&platform::MLUInitAllocSize, + &platform::MLUReallocSize, request_bytes); #endif // Allocate a new block diff --git a/paddle/fluid/memory/memcpy.cc b/paddle/fluid/memory/memcpy.cc index e6aed2c90daced15b6f4dc79e68c7dfd8cd811d5..153e19a9f1450b87c7d46d7627e781f4c93845f9 100644 --- a/paddle/fluid/memory/memcpy.cc +++ b/paddle/fluid/memory/memcpy.cc @@ -508,6 +508,9 @@ void Copy(platform::CPUPlace dst_place, platform::RecordEvent record_event("MLUMemcpyD2HAsync:MLU->CPU"); platform::MLUMemcpyD2HAsync(dst, src, num, stream); } else { + platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance(); + static_cast(pool.Get(src_place))->Wait(); + VLOG(4) << "Sync memory::Copy " << num << " Bytes from " << src_place << " to " << dst_place; platform::RecordEvent record_event("MLUMemcpyD2HSync:MLU->CPU"); @@ -530,6 +533,9 @@ void Copy(platform::MLUPlace dst_place, platform::RecordEvent record_event("MLUMemcpyH2DAsync:CPU->MLU"); platform::MLUMemcpyH2DAsync(dst, src, num, stream); } else { + platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance(); + static_cast(pool.Get(src_place))->Wait(); + VLOG(4) << "Sync memory::Copy " << num << " Bytes from " << src_place << " to " << dst_place; platform::RecordEvent record_event("MLUMemcpyH2DSync:CPU->MLU"); @@ -554,6 +560,10 @@ void Copy(platform::MLUPlace dst_place, "MLUMemcpyD2DAsync(same_mlu):MLU->MLU"); platform::MLUMemcpyD2DAsync(dst, src, num, stream); } else { + platform::DeviceContextPool& pool = + platform::DeviceContextPool::Instance(); + static_cast(pool.Get(src_place))->Wait(); + VLOG(4) << "Sync memory::Copy " << num << " Bytes from " << src_place << " to " << dst_place; platform::RecordEvent record_event("MLUMemcpyD2DSync(same_mlu):MLU->MLU"); diff --git a/paddle/fluid/operators/mean_op_mlu.cc b/paddle/fluid/operators/mean_op_mlu.cc new file mode 100644 index 0000000000000000000000000000000000000000..9862c2bd95256541ebeed20b202991f7bf6d3bc8 --- /dev/null +++ b/paddle/fluid/operators/mean_op_mlu.cc @@ -0,0 +1,127 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/operators/mean_op.h" +#include "paddle/fluid/operators/mlu/mlu_baseop.h" +#include "paddle/fluid/platform/device/mlu/device_context.h" +#include "paddle/fluid/platform/float16.h" + +namespace paddle { +namespace operators { + +template +class MeanMLUKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& context) const override { + auto* input = context.Input("X"); + auto* output = context.Output("Out"); + + const T* in_data = input->data(); + T* out_data = output->mutable_data(context.GetPlace()); + auto numel = input->numel(); + auto rank = input->dims().size(); + auto place = context.GetPlace(); + auto stream = context.template device_context().stream(); + + if (rank == 0) { // scalar + auto mlu_place = BOOST_GET(platform::MLUPlace, place); + memory::Copy(mlu_place, out_data, mlu_place, in_data, numel * sizeof(T), + stream); + return; + } + + std::vector reduce_dims; + reduce_dims.reserve(rank); + for (decltype(rank) i = 0; i < rank; ++i) { + reduce_dims.push_back(i); + } + + MLUCnnlTensorDesc input_desc(*input, CNNL_LAYOUT_ARRAY, + ToCnnlDataType(input->type())); + MLUCnnlTensorDesc output_desc(*output, CNNL_LAYOUT_ARRAY, + ToCnnlDataType(output->type())); + + MLUCnnlReduceDesc reduction_desc( + reduce_dims, CNNL_REDUCE_AVG, ToCnnlDataType(), + CNNL_NOT_PROPAGATE_NAN, CNNL_REDUCE_NO_INDICES, CNNL_32BIT_INDICES); + + MLUCnnl::Reduce(context, true /*need_workspace*/, reduction_desc.get(), + nullptr, input_desc.get(), + reinterpret_cast(in_data), 0 /*indices_size*/, + nullptr, nullptr, output_desc.get(), + reinterpret_cast(out_data)); + } +}; + +template +class MeanMLUGradKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& context) const override { + auto output_grad = context.Input(framework::GradVarName("Out")); + PADDLE_ENFORCE_EQ(output_grad->numel(), 1, + platform::errors::InvalidArgument( + "Mean Gradient Input Tensor len should be 1. But " + "received Out@Grad's elements num is %d.", + output_grad->numel())); + auto input_grad = context.Output(framework::GradVarName("X")); + input_grad->mutable_data(context.GetPlace()); + + auto in_data = output_grad->data(); + auto numel = input_grad->numel(); + auto rank = input_grad->dims().size(); + auto out_data = input_grad->data(); + auto place = context.GetPlace(); + auto stream = context.template device_context().stream(); + + if (rank == 0) { // scalar + auto mlu_place = BOOST_GET(platform::MLUPlace, place); + memory::Copy(mlu_place, out_data, mlu_place, in_data, numel * sizeof(T), + stream); + return; + } + + // means + Tensor mean_var(output_grad->type()); + mean_var.mutable_data(input_grad->dims(), context.GetPlace()); + MLUCnnlTensorDesc mean_var_desc(mean_var, CNNL_LAYOUT_ARRAY, + ToCnnlDataType(mean_var.type())); + auto value = static_cast(1.0 / static_cast(input_grad->numel())); + MLUCnnl::Fill(context, value, mean_var_desc.get(), GetBasePtr(&mean_var)); + + // means mul output_grad + MLUCnnlTensorDesc in_desc(*output_grad, CNNL_LAYOUT_ARRAY, + ToCnnlDataType(output_grad->type())); + MLUCnnlTensorDesc out_desc(*input_grad, CNNL_LAYOUT_ARRAY, + ToCnnlDataType(input_grad->type())); + + MLUCnnlOpTensorDesc op_tensor_desc(CNNL_OP_TENSOR_MUL, ToCnnlDataType(), + CNNL_NOT_PROPAGATE_NAN); + + MLUCnnl::OpTensor(context, op_tensor_desc.get(), in_desc.get(), + reinterpret_cast(in_data), + mean_var_desc.get(), GetBasePtr(&mean_var), + out_desc.get(), reinterpret_cast(out_data), + ToCnnlDataType()); + } +}; +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +namespace plat = paddle::platform; + +REGISTER_OP_MLU_KERNEL(mean, ops::MeanMLUKernel, + ops::MeanMLUKernel); +REGISTER_OP_MLU_KERNEL(mean_grad, ops::MeanMLUGradKernel, + ops::MeanMLUGradKernel); diff --git a/paddle/fluid/operators/mlu/mlu_baseop.h b/paddle/fluid/operators/mlu/mlu_baseop.h index ab398a92c2972986766f58d29f8c98cb27258655..8082c45d14b95dc1359641e1e13e751bd543dce2 100644 --- a/paddle/fluid/operators/mlu/mlu_baseop.h +++ b/paddle/fluid/operators/mlu/mlu_baseop.h @@ -45,12 +45,22 @@ enum MLULogicMethod { CNNL_LOGIC_OP_OR = 7, }; +inline const void* GetBasePtr(const Tensor* t) { return t->data(); } + +inline void* GetBasePtr(Tensor* t) { return t->data(); } + template inline cnnlDataType_t ToCnnlDataType(const T& t) { auto type = framework::ToDataType(t); return ToCnnlDataType(type); } +template +inline cnnlDataType_t ToCnnlDataType() { + auto type = framework::ToDataType(std::type_index(typeid(T))); + return ToCnnlDataType(type); +} + template <> inline cnnlDataType_t ToCnnlDataType(const framework::proto::VarType::Type& t) { cnnlDataType_t type = CNNL_DTYPE_FLOAT; @@ -89,11 +99,12 @@ NarrowT CheckedNarrowing(const WideT& wide) { return narrow; } -static cnnlHandle_t GetHandleFromCTX(const ExecutionContext& ctx) { +inline static cnnlHandle_t GetHandleFromCTX(const ExecutionContext& ctx) { return ctx.template device_context().cnnl_handle(); } -static const MLUDeviceContext& GetDevCtxFromCTX(const ExecutionContext& ctx) { +inline static const MLUDeviceContext& GetDevCtxFromCTX( + const ExecutionContext& ctx) { return ctx.template device_context(); } diff --git a/paddle/fluid/operators/reduce_ops/reduce_mean_op_mlu.cc b/paddle/fluid/operators/reduce_ops/reduce_mean_op_mlu.cc new file mode 100644 index 0000000000000000000000000000000000000000..ef7e9940f0590b2d3db9d14dee12098dbed51550 --- /dev/null +++ b/paddle/fluid/operators/reduce_ops/reduce_mean_op_mlu.cc @@ -0,0 +1,127 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/operators/reduce_ops/reduce_mean_op.h" +#include "paddle/fluid/operators/mlu/mlu_baseop.h" +#include "paddle/fluid/platform/device/mlu/device_context.h" + +namespace paddle { +namespace operators { + +template +class ReduceMeanMLUKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& context) const override { + auto* input = context.Input("X"); + auto* output = context.Output("Out"); + output->mutable_data(context.GetPlace()); + + bool reduce_all = context.Attr("reduce_all"); + auto dims = context.Attr>("dim"); + auto input_dims = framework::vectorize(input->dims()); + const auto& input_dim_size = input->dims().size(); + std::vector reduce_dims; + if (reduce_all) { + for (size_t i = 0; i < input_dims.size(); i++) { + reduce_dims.push_back(static_cast(i)); + } + } else { + for (size_t i = 0; i < dims.size(); ++i) { + if (dims[i] < 0) { + reduce_dims.push_back(dims[i] + input_dim_size); + } else { + reduce_dims.push_back(dims[i]); + } + } + } + + MLUCnnlTensorDesc input_desc(*input, CNNL_LAYOUT_ARRAY, + ToCnnlDataType(input->type())); + MLUCnnlTensorDesc output_desc(*output, CNNL_LAYOUT_ARRAY, + ToCnnlDataType(output->type())); + + MLUCnnlReduceDesc reduction_desc( + reduce_dims, CNNL_REDUCE_AVG, ToCnnlDataType(), + CNNL_NOT_PROPAGATE_NAN, CNNL_REDUCE_NO_INDICES, CNNL_32BIT_INDICES); + + MLUCnnl::Reduce(context, true /*need_workspace*/, reduction_desc.get(), + nullptr, input_desc.get(), GetBasePtr(input), + 0 /*indices_size*/, nullptr, nullptr, output_desc.get(), + GetBasePtr(output)); + } +}; + +template +class ReduceMeanGradMLUKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& context) const override { + auto* input = context.Input("X"); + auto* output_grad = context.Input(framework::GradVarName("Out")); + auto* input_grad = context.Output(framework::GradVarName("X")); + input_grad->mutable_data(context.GetPlace()); + + bool reduce_all = context.Attr("reduce_all"); + auto reduce_dims = context.Attr>("dim"); + auto input_dims = framework::vectorize(input->dims()); + + int reduce_numel = 1; + if (reduce_all) { + reduce_dims.clear(); + for (size_t d = 0; d < input_dims.size(); ++d) { + reduce_dims.push_back(static_cast(d)); + } + } + for (auto& d : reduce_dims) { + if (d < 0) { + d = d + input_dims.size(); + } + reduce_numel *= input_dims[d]; + } + + Tensor tmp_output_grad(output_grad->type()); + auto tmp_output_dims = input_dims; + for (auto d : reduce_dims) { + tmp_output_dims[d] = 1; + } + tmp_output_grad.ShareDataWith(*output_grad); + tmp_output_grad.Resize(framework::make_ddim(tmp_output_dims)); + + MLUCnnlTensorDesc output_grad_desc(tmp_output_grad, CNNL_LAYOUT_ARRAY, + ToCnnlDataType(tmp_output_grad.type())); + MLUCnnlTensorDesc input_grad_desc(*input_grad, CNNL_LAYOUT_ARRAY, + ToCnnlDataType(input_grad->type())); + + auto value = static_cast(1.0 / static_cast(reduce_numel)); + MLUCnnl::Fill(context, value, input_grad_desc.get(), + GetBasePtr(input_grad)); + + MLUCnnlOpTensorDesc op_tensor_desc(CNNL_OP_TENSOR_MUL, ToCnnlDataType(), + CNNL_NOT_PROPAGATE_NAN); + + MLUCnnl::OpTensor(context, op_tensor_desc.get(), output_grad_desc.get(), + GetBasePtr(&tmp_output_grad), input_grad_desc.get(), + GetBasePtr(input_grad), input_grad_desc.get(), + GetBasePtr(input_grad), ToCnnlDataType()); + } +}; +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +namespace plat = paddle::platform; + +REGISTER_OP_MLU_KERNEL(reduce_mean, ops::ReduceMeanMLUKernel, + ops::ReduceMeanMLUKernel); +REGISTER_OP_MLU_KERNEL(reduce_mean_grad, ops::ReduceMeanGradMLUKernel, + ops::ReduceMeanGradMLUKernel); diff --git a/paddle/fluid/pybind/tensor_py.h b/paddle/fluid/pybind/tensor_py.h index b31b7456ebca73d6e73774f05268038ab756530f..1fe66869194531602b3e0b552727f5e6b0a491d4 100644 --- a/paddle/fluid/pybind/tensor_py.h +++ b/paddle/fluid/pybind/tensor_py.h @@ -232,6 +232,13 @@ T TensorGetElement(const framework::Tensor &self, size_t offset) { auto p = BOOST_GET_CONST(platform::CUDAPlace, self.place()); paddle::memory::Copy(platform::CPUPlace(), &b, p, a + offset, sizeof(T), nullptr); +#endif + } else if (platform::is_mlu_place(self.place())) { +#ifdef PADDLE_WITH_MLU + const T *a = self.data(); + auto p = BOOST_GET_CONST(platform::MLUPlace, self.place()); + paddle::memory::Copy(platform::CPUPlace(), &b, p, a + offset, sizeof(T), + nullptr); #endif } else if (platform::is_npu_place(self.place())) { #if defined(PADDLE_WITH_ASCEND_CL) @@ -267,6 +274,13 @@ void TensorSetElement(framework::Tensor *self, size_t offset, T elem) { T *a = self->mutable_data(p); paddle::memory::Copy(p, a + offset, platform::CPUPlace(), &elem, sizeof(T), nullptr); +#endif + } else if (platform::is_mlu_place(self->place())) { +#ifdef PADDLE_WITH_MLU + auto p = BOOST_GET_CONST(platform::MLUPlace, self->place()); + T *a = self->mutable_data(p); + paddle::memory::Copy(p, a + offset, platform::CPUPlace(), &elem, sizeof(T), + nullptr); #endif } else if (platform::is_npu_place(self->place())) { #if defined(PADDLE_WITH_ASCEND_CL) @@ -543,6 +557,11 @@ inline framework::Tensor *_getTensor(const framework::Tensor &self, #ifdef PADDLE_WITH_XPU output->mutable_data(BOOST_GET_CONST(platform::XPUPlace, place), self.type()); +#endif + } else if (platform::is_mlu_place(place)) { +#ifdef PADDLE_WITH_MLU + output->mutable_data(BOOST_GET_CONST(platform::MLUPlace, place), + self.type()); #endif } else { #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) @@ -845,8 +864,13 @@ inline py::array TensorToPyArray(const framework::Tensor &tensor, size_t copy_bytes = sizeof_dtype * numel; auto p = BOOST_GET_CONST(platform::MLUPlace, tensor.place()); - paddle::memory::Copy(platform::CPUPlace(), py_arr.mutable_data(), p, - tensor_buf_ptr, copy_bytes, nullptr); + platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance(); + auto &ctx = *pool.Get(tensor.place()); + paddle::memory::Copy( + platform::CPUPlace(), py_arr.mutable_data(), p, tensor_buf_ptr, + copy_bytes, + reinterpret_cast(ctx).stream()); + ctx.Wait(); return py_arr; #else PADDLE_THROW(platform::errors::PermissionDenied( diff --git a/python/paddle/fluid/tests/unittests/CMakeLists.txt b/python/paddle/fluid/tests/unittests/CMakeLists.txt index b46a10c8c79d895235fcc24fd10aab64fcd90241..67697fcfd839871c4795f334ee4ff1fe0e178332 100644 --- a/python/paddle/fluid/tests/unittests/CMakeLists.txt +++ b/python/paddle/fluid/tests/unittests/CMakeLists.txt @@ -803,6 +803,10 @@ if (WITH_MKLDNN) add_subdirectory(mkldnn) endif() +if (WITH_MLU) + add_subdirectory(mlu) +endif() + add_subdirectory(asp) add_subdirectory(ir) diff --git a/python/paddle/fluid/tests/unittests/mlu/CMakeLists.txt b/python/paddle/fluid/tests/unittests/mlu/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..8fcd3f196dc1943d61bd64dad977dab0bd9922af --- /dev/null +++ b/python/paddle/fluid/tests/unittests/mlu/CMakeLists.txt @@ -0,0 +1,9 @@ +file(GLOB TEST_OPS RELATIVE "${CMAKE_CURRENT_SOURCE_DIR}" "test_*.py") +string(REPLACE ".py" "" TEST_OPS "${TEST_OPS}") + +if (WITH_MLU) + foreach(TEST_OP ${TEST_OPS}) + py_test_modules(${TEST_OP} MODULES ${TEST_OP}) + endforeach(TEST_OP) + +endif() diff --git a/python/paddle/fluid/tests/unittests/mlu/test_mean_op_mlu.py b/python/paddle/fluid/tests/unittests/mlu/test_mean_op_mlu.py new file mode 100644 index 0000000000000000000000000000000000000000..36419327db6b0535f9a6836e3716926517b0df02 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/mlu/test_mean_op_mlu.py @@ -0,0 +1,83 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +import numpy as np +import unittest +import sys +sys.path.append("..") +from op_test import OpTest +import paddle +import paddle.fluid as fluid +from paddle.fluid import core + +paddle.enable_static() +SEED = 2021 + + +class TestMean(OpTest): + def setUp(self): + self.set_mlu() + self.place = paddle.device.MLUPlace(0) + self.op_type = "mean" + self.init_dtype() + + x = np.random.random([1, 100]).astype(self.dtype) + self.inputs = {'X': x} + + self.attrs = {} + np_out = np.mean(x) + self.outputs = {'Out': np_out} + + def set_mlu(self): + self.__class__.use_mlu = True + + def init_dtype(self): + self.dtype = np.float32 + + def test_check_output(self): + self.check_output_with_place(self.place) + + def test_check_grad(self): + self.check_grad_with_place(self.place, ['X'], 'Out') + + +class TestMeanFP16(OpTest): + def setUp(self): + self.set_mlu() + self.place = paddle.MLUPlace(0) + self.op_type = "mean" + self.init_dtype() + + x = np.random.random([3, 200]).astype(self.dtype) + self.inputs = {'X': x} + + self.attrs = {} + np_out = np.mean(x) + self.outputs = {'Out': np_out} + + def set_mlu(self): + self.__class__.use_mlu = True + self.__class__.no_need_check_grad = True + + def init_dtype(self): + self.dtype = np.float16 + + def test_check_output(self): + self.check_output_with_place(self.place) + + +if __name__ == '__main__': + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/mlu/test_reduce_mean_op_mlu.py b/python/paddle/fluid/tests/unittests/mlu/test_reduce_mean_op_mlu.py new file mode 100644 index 0000000000000000000000000000000000000000..c0be644c791154bcf6b366d9c7397fdd7b515678 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/mlu/test_reduce_mean_op_mlu.py @@ -0,0 +1,185 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +import numpy as np +import unittest +import sys +sys.path.append("..") +from op_test import OpTest +import paddle +import paddle.fluid as fluid + +paddle.enable_static() + + +class TestMeanOp(OpTest): + def set_mlu(self): + self.__class__.use_mlu = True + self.place = paddle.device.MLUPlace(0) + + def setUp(self): + self.set_mlu() + self.op_type = "reduce_mean" + self.inputs = {'X': np.random.random((5, 6, 10)).astype("float32")} + self.outputs = {'Out': self.inputs['X'].mean(axis=0)} + + def test_check_output(self): + self.check_output_with_place(self.place) + + def test_check_grad(self): + self.check_grad_with_place(self.place, ['X'], 'Out') + + +class TestMeanOp5D(TestMeanOp): + def setUp(self): + self.set_mlu() + self.op_type = "reduce_mean" + self.inputs = { + 'X': np.random.random((1, 2, 5, 6, 10)).astype("float32") + } + self.outputs = {'Out': self.inputs['X'].mean(axis=0)} + + +class TestMeanOp6D(TestMeanOp): + def setUp(self): + self.set_mlu() + self.op_type = "reduce_mean" + self.inputs = { + 'X': np.random.random((1, 1, 2, 5, 6, 10)).astype("float32") + } + self.outputs = {'Out': self.inputs['X'].mean(axis=0)} + + +class TestMeanOp8D(TestMeanOp): + def setUp(self): + self.set_mlu() + self.op_type = "reduce_mean" + self.inputs = { + 'X': np.random.random((1, 3, 1, 2, 1, 4, 3, 10)).astype("float32") + } + self.attrs = {'dim': (0, 3)} + self.outputs = {'Out': self.inputs['X'].mean(axis=(0, 3))} + + +class Test1DReduce(TestMeanOp): + def setUp(self): + self.set_mlu() + self.op_type = "reduce_mean" + self.inputs = {'X': np.random.random(120).astype("float32")} + self.outputs = {'Out': self.inputs['X'].mean(axis=0)} + + +class Test2DReduce0(Test1DReduce): + def setUp(self): + self.set_mlu() + self.op_type = "reduce_mean" + self.attrs = {'dim': [0]} + self.inputs = {'X': np.random.random((20, 10)).astype("float32")} + self.outputs = {'Out': self.inputs['X'].mean(axis=0)} + + +class Test2DReduce1(Test1DReduce): + def setUp(self): + self.set_mlu() + self.op_type = "reduce_mean" + self.attrs = {'dim': [1]} + self.inputs = {'X': np.random.random((20, 10)).astype("float32")} + self.outputs = { + 'Out': self.inputs['X'].mean(axis=tuple(self.attrs['dim'])) + } + + +class Test3DReduce0(Test1DReduce): + def setUp(self): + self.set_mlu() + self.op_type = "reduce_mean" + self.attrs = {'dim': [1]} + self.inputs = {'X': np.random.random((5, 6, 7)).astype("float32")} + self.outputs = { + 'Out': self.inputs['X'].mean(axis=tuple(self.attrs['dim'])) + } + + +class Test3DReduce1(Test1DReduce): + def setUp(self): + self.set_mlu() + self.op_type = "reduce_mean" + self.attrs = {'dim': [2]} + self.inputs = {'X': np.random.random((5, 6, 7)).astype("float32")} + self.outputs = { + 'Out': self.inputs['X'].mean(axis=tuple(self.attrs['dim'])) + } + + +class Test3DReduce2(Test1DReduce): + def setUp(self): + self.set_mlu() + self.op_type = "reduce_mean" + self.attrs = {'dim': [-2]} + self.inputs = {'X': np.random.random((5, 6, 7)).astype("float32")} + self.outputs = { + 'Out': self.inputs['X'].mean(axis=tuple(self.attrs['dim'])) + } + + +class Test3DReduce3(Test1DReduce): + def setUp(self): + self.set_mlu() + self.op_type = "reduce_mean" + self.attrs = {'dim': [1, 2]} + self.inputs = {'X': np.random.random((5, 6, 7)).astype("float32")} + self.outputs = { + 'Out': self.inputs['X'].mean(axis=tuple(self.attrs['dim'])) + } + + +class TestKeepDimReduce(Test1DReduce): + def setUp(self): + self.set_mlu() + self.op_type = "reduce_mean" + self.inputs = {'X': np.random.random((5, 6, 10)).astype("float32")} + self.attrs = {'dim': [1], 'keep_dim': True} + self.outputs = { + 'Out': self.inputs['X'].mean( + axis=tuple(self.attrs['dim']), keepdims=self.attrs['keep_dim']) + } + + +class TestKeepDim8DReduce(Test1DReduce): + def setUp(self): + self.set_mlu() + self.op_type = "reduce_mean" + self.inputs = { + 'X': np.random.random((2, 5, 3, 2, 2, 3, 4, 2)).astype("float32") + } + self.attrs = {'dim': (3, 4, 5), 'keep_dim': True} + self.outputs = { + 'Out': self.inputs['X'].mean( + axis=tuple(self.attrs['dim']), keepdims=self.attrs['keep_dim']) + } + + +class TestReduceAll(Test1DReduce): + def setUp(self): + self.set_mlu() + self.op_type = "reduce_mean" + self.inputs = {'X': np.random.random((5, 6, 2, 10)).astype("float32")} + self.attrs = {'reduce_all': True} + self.outputs = {'Out': self.inputs['X'].mean()} + + +if __name__ == '__main__': + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/mlu/test_relu_op_mlu.py b/python/paddle/fluid/tests/unittests/mlu/test_relu_op_mlu.py new file mode 100644 index 0000000000000000000000000000000000000000..25c50f67949e7ce565e62f2a6a021ccf65375e16 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/mlu/test_relu_op_mlu.py @@ -0,0 +1,166 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +import numpy as np +import unittest +import sys +sys.path.append("..") +from op_test import OpTest +import paddle +import paddle.fluid as fluid + +paddle.enable_static() +SEED = 2021 + + +class TestRelu(OpTest): + def setUp(self): + self.set_mlu() + self.op_type = "relu" + self.place = paddle.MLUPlace(0) + + self.init_dtype() + np.random.seed(SEED) + x = np.random.rand(3, 2).astype(self.dtype) + out = x + + self.inputs = {'X': OpTest.np_dtype_to_fluid_dtype(x)} + self.attrs = {} + self.outputs = {'Out': out} + + def set_mlu(self): + self.__class__.use_mlu = True + + def init_dtype(self): + self.dtype = np.float32 + + def test_check_output(self): + self.check_output_with_place(self.place) + + +class TestReluFp16(OpTest): + def setUp(self): + self.set_mlu() + self.op_type = "relu" + self.place = paddle.MLUPlace(0) + + self.init_dtype() + np.random.seed(SEED) + x = np.random.rand(3, 2).astype(self.dtype) + out = x + + self.inputs = {'X': OpTest.np_dtype_to_fluid_dtype(x)} + self.attrs = {} + self.outputs = {'Out': out} + + def set_mlu(self): + self.__class__.use_mlu = True + self.__class__.no_need_check_grad = True + + def init_dtype(self): + self.dtype = np.float16 + + def test_check_output(self): + self.check_output_with_place(self.place, atol=1e-5) + + +class TestReluNeg(OpTest): + def setUp(self): + self.set_mlu() + self.op_type = "relu" + self.place = paddle.MLUPlace(0) + + self.init_dtype() + np.random.seed(SEED) + x = np.array([0.1, -0.1, -1.0]).astype(self.dtype) + out = np.array([0.1, 0.0, 0.0]).astype(self.dtype) + + self.inputs = {'X': OpTest.np_dtype_to_fluid_dtype(x)} + self.attrs = {} + self.outputs = {'Out': out} + + def set_mlu(self): + self.__class__.use_mlu = True + + def init_dtype(self): + self.dtype = np.float32 + + def test_check_output(self): + self.check_output_with_place(self.place) + + +class TestReluNet(unittest.TestCase): + def _test(self, run_mlu=True): + main_prog = paddle.static.Program() + startup_prog = paddle.static.Program() + main_prog.random_seed = SEED + startup_prog.random_seed = SEED + np.random.seed(SEED) + + a_np = np.random.random(size=(32, 32)).astype('float32') + b_np = np.random.random(size=(32, 32)).astype('float32') + label_np = np.random.randint(2, size=(32, 1)).astype('int64') + + with paddle.static.program_guard(main_prog, startup_prog): + a = paddle.static.data(name="a", shape=[32, 32], dtype='float32') + b = paddle.static.data(name="b", shape=[32, 32], dtype='float32') + label = paddle.static.data( + name="label", shape=[32, 1], dtype='int64') + + sum = paddle.add(a, b) + z = paddle.nn.functional.relu(sum) + + fc_1 = fluid.layers.fc(input=z, size=128) + prediction = fluid.layers.fc(input=fc_1, size=2, act='softmax') + + cost = fluid.layers.cross_entropy(input=prediction, label=label) + loss = fluid.layers.reduce_mean(cost) + sgd = fluid.optimizer.SGD(learning_rate=0.01) + sgd.minimize(loss) + + if run_mlu: + place = paddle.MLUPlace(0) + else: + place = paddle.CPUPlace() + + exe = paddle.static.Executor(place) + exe.run(startup_prog) + + print("Start run on {}".format(place)) + for epoch in range(100): + + pred_res, loss_res = exe.run( + main_prog, + feed={"a": a_np, + "b": b_np, + "label": label_np}, + fetch_list=[prediction, loss]) + if epoch % 10 == 0: + print("Epoch {} | Prediction[0]: {}, Loss: {}".format( + epoch, pred_res[0], loss_res)) + + return pred_res, loss_res + + def test_mlu(self): + cpu_pred, cpu_loss = self._test(False) + mlu_pred, mlu_loss = self._test(True) + + self.assertTrue(np.allclose(mlu_pred, cpu_pred)) + self.assertTrue(np.allclose(mlu_loss, cpu_loss)) + + +if __name__ == '__main__': + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/op_test.py b/python/paddle/fluid/tests/unittests/op_test.py index ec59c27558332a06ad1e18a5bd887e8e22c67021..01d851469a8d192a2aea72eda0f7b09e18f7fdc9 100644 --- a/python/paddle/fluid/tests/unittests/op_test.py +++ b/python/paddle/fluid/tests/unittests/op_test.py @@ -326,6 +326,9 @@ class OpTest(unittest.TestCase): def is_npu_op_test(): return hasattr(cls, "use_npu") and cls.use_npu == True + def is_mlu_op_test(): + return hasattr(cls, "use_mlu") and cls.use_mlu == True + if not hasattr(cls, "op_type"): raise AssertionError( "This test do not have op_type in class attrs, " @@ -348,7 +351,8 @@ class OpTest(unittest.TestCase): and not is_xpu_op_test() \ and not is_mkldnn_op_test() \ and not is_rocm_op_test() \ - and not is_npu_op_test(): + and not is_npu_op_test() \ + and not is_mlu_op_test(): raise AssertionError( "This test of %s op needs check_grad with fp64 precision." % cls.op_type) @@ -1297,7 +1301,8 @@ class OpTest(unittest.TestCase): # No effect on original OpTest # Currently not support ParallelExecutor on XPUPlace. if not paddle.is_compiled_with_xpu( - ) and not paddle.is_compiled_with_npu(): + ) and not paddle.is_compiled_with_npu( + ) and not paddle.is_compiled_with_mlu(): self.check_inplace_output_with_place( place, no_check_set=no_check_set, inplace_atol=inplace_atol) @@ -1547,11 +1552,9 @@ class OpTest(unittest.TestCase): delta=numeric_grad_delta, in_place=in_place) for input_to_check in inputs_to_check ] - analytic_grads = self._get_gradient(inputs_to_check, place, output_names, no_grad_set, user_defined_grad_outputs) - # comparison of bf16 results will happen as fp32 # loop over list of grads and convert bf16 to fp32 fp32_analytic_grads = []