diff --git a/paddle/fluid/operators/cast_op.cc b/paddle/fluid/operators/cast_op.cc index 4853e5324c30f56e90cec9a7c75a48686b58a4b8..90d665eb93bcbad0e7d7e6da25c0f2d645fea00b 100644 --- a/paddle/fluid/operators/cast_op.cc +++ b/paddle/fluid/operators/cast_op.cc @@ -16,6 +16,9 @@ limitations under the License. */ #include #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/platform/float16.h" +#ifdef PADDLE_WITH_MLU +#include "paddle/fluid/operators/mlu/mlu_baseop.h" +#endif namespace paddle { namespace operators { @@ -102,6 +105,19 @@ class CastOp : public framework::OperatorWithKernel { framework::DataLayout::kMKLDNN, framework::LibraryType::kMKLDNN); } +#endif +#ifdef PADDLE_WITH_MLU + auto src_type = static_cast(ctx.Attr("in_dtype")); + auto dst_type = static_cast(ctx.Attr("out_dtype")); + if (src_type == dst_type || MLUSupportsCast(src_type, dst_type)) { + return framework::OpKernelType(tensor->type(), tensor_place); + } else { + VLOG(3) << "MLU not support cast type: " + << framework::DataTypeToString(src_type) + << " to type: " << framework::DataTypeToString(dst_type) + << ", fallbacking to CPU one!"; + return framework::OpKernelType(tensor->type(), platform::CPUPlace()); + } #endif return framework::OpKernelType(tensor->type(), tensor_place); } diff --git a/paddle/fluid/operators/cast_op_mlu.cc b/paddle/fluid/operators/cast_op_mlu.cc new file mode 100644 index 0000000000000000000000000000000000000000..f28889e7acf8773e2c55044037eb6bbde71ce12f --- /dev/null +++ b/paddle/fluid/operators/cast_op_mlu.cc @@ -0,0 +1,94 @@ +/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/operators/cast_op.h" +#include "paddle/fluid/operators/mlu/mlu_baseop.h" +#include "paddle/fluid/platform/device/mlu/device_context.h" + +namespace paddle { +namespace operators { + +using Tensor = framework::Tensor; + +template +class CastMLUKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + auto* input = ctx.Input("X"); + auto* output = ctx.Output("Out"); + auto src_type = static_cast(ctx.Attr("in_dtype")); + auto dst_type = static_cast(ctx.Attr("out_dtype")); + auto place = ctx.GetPlace(); + + if (src_type == dst_type) { + auto& dev_ctx = ctx.template device_context(); + output->mutable_data(place); + framework::TensorCopy(*input, place, dev_ctx, output); + return; + } + + PADDLE_ENFORCE_EQ(MLUSupportsCast(src_type, dst_type), true, + platform::errors::InvalidArgument( + "MLU not support cast [%d] to [%d]", + framework::DataTypeToString(src_type), + framework::DataTypeToString(dst_type))); + + switch (dst_type) { + case VT::FP32: + output->mutable_data(place); + break; + case VT::FP16: + output->mutable_data(place); + break; + case VT::INT32: + output->mutable_data(place); + break; + case VT::INT16: + output->mutable_data(place); + break; + case VT::INT8: + output->mutable_data(place); + break; + case VT::UINT8: + output->mutable_data(place); + break; + case VT::BOOL: + output->mutable_data(place); + break; + case VT::INT64: + output->mutable_data(place); + break; + default: + PADDLE_THROW(platform::errors::Unavailable( + "Not supported cast %d -> %d", + framework::DataTypeToString(src_type), + framework::DataTypeToString(dst_type))); + } + + MLUCnnlTensorDesc input_desc(*input); + MLUCnnlTensorDesc output_desc(*output); + cnnlCastDataType_t cast_type = GetCastDataType(src_type, dst_type); + + MLUCnnl::Cast(ctx, cast_type, input_desc.get(), GetBasePtr(input), + output_desc.get(), GetBasePtr(output)); + } +}; +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +REGISTER_OP_MLU_KERNEL(cast, ops::CastMLUKernel, ops::CastMLUKernel, + ops::CastMLUKernel, ops::CastMLUKernel, + ops::CastMLUKernel, ops::CastMLUKernel, + ops::CastMLUKernel); diff --git a/paddle/fluid/operators/mlu/mlu_baseop.cc b/paddle/fluid/operators/mlu/mlu_baseop.cc index c877b7130c55c728a635a4170405447115551244..e93ec32e2bd5fc808861516db7c75c48ff77ff1a 100644 --- a/paddle/fluid/operators/mlu/mlu_baseop.cc +++ b/paddle/fluid/operators/mlu/mlu_baseop.cc @@ -20,6 +20,29 @@ limitations under the License. */ namespace paddle { namespace operators { +cnnlCastDataType_t GetCastDataType(const VT::Type& src_type, + const VT::Type& dst_type) { + cnnlCastDataType_t cast_type = CNNL_CAST_FLOAT_TO_HALF; + for (auto it = MLU_SUPPORTED_CAST_TYPE.begin(); + it != MLU_SUPPORTED_CAST_TYPE.end(); ++it) { + if (it->first.first == src_type && it->first.second == dst_type) { + cast_type = it->second; + break; + } + } + return cast_type; +} + +bool MLUSupportsCast(const VT::Type& src_type, const VT::Type& dst_type) { + for (auto it = MLU_SUPPORTED_CAST_TYPE.begin(); + it != MLU_SUPPORTED_CAST_TYPE.end(); ++it) { + if (it->first.first == src_type && it->first.second == dst_type) { + return true; + } + } + return false; +} + class MLUCnnlTensorDescPool { public: cnnlTensorDescriptor_t Pop() { @@ -153,6 +176,10 @@ MLUCnnlTensorDesc::MLUCnnlTensorDesc(const Tensor& tensor, } } +MLUCnnlTensorDesc::MLUCnnlTensorDesc(const Tensor& tensor) + : MLUCnnlTensorDesc(tensor, CNNL_LAYOUT_ARRAY, + ToCnnlDataType(tensor.type())) {} + MLUCnnlTensorDesc::MLUCnnlTensorDesc(const Tensor& tensor, cnnlTensorLayout_t layout, const cnnlDataType_t tensor_dtype, @@ -1848,7 +1875,7 @@ MLUCnnlTrigonDesc::~MLUCnnlTrigonDesc() { if (is_training) { /* - * If in Paddle, running_mean_output = momentum * runnning_mean_input + + * In Paddle, running_mean_output = momentum * runnning_mean_input + * (1 - momentum) * batch_mean. However, In CNNL, * running_mean_output = (1 - momentum) * running_mean_input + * momentum * batch_mean. So we pass (1.0 - momentum) to momentum param. diff --git a/paddle/fluid/operators/mlu/mlu_baseop.h b/paddle/fluid/operators/mlu/mlu_baseop.h index 8082c45d14b95dc1359641e1e13e751bd543dce2..67b6b3ec1614dd51adc62cf418d9eadadf276ca9 100644 --- a/paddle/fluid/operators/mlu/mlu_baseop.h +++ b/paddle/fluid/operators/mlu/mlu_baseop.h @@ -74,6 +74,9 @@ inline cnnlDataType_t ToCnnlDataType(const framework::proto::VarType::Type& t) { case framework::proto::VarType::INT8: type = CNNL_DTYPE_INT8; break; + case framework::proto::VarType::INT16: + type = CNNL_DTYPE_INT16; + break; case framework::proto::VarType::INT32: type = CNNL_DTYPE_INT32; break; @@ -83,6 +86,9 @@ inline cnnlDataType_t ToCnnlDataType(const framework::proto::VarType::Type& t) { case framework::proto::VarType::BOOL: type = CNNL_DTYPE_BOOL; break; + case framework::proto::VarType::UINT8: + type = CNNL_DTYPE_UINT8; + break; default: break; } @@ -108,6 +114,47 @@ inline static const MLUDeviceContext& GetDevCtxFromCTX( return ctx.template device_context(); } +using VT = framework::proto::VarType; +const std::map, cnnlCastDataType_t> + MLU_SUPPORTED_CAST_TYPE = { + {{VT::FP32, /*cast to*/ VT::FP16}, CNNL_CAST_FLOAT_TO_HALF}, + {{VT::FP32, /*cast to*/ VT::INT32}, CNNL_CAST_FLOAT_TO_INT32}, + {{VT::FP32, /*cast to*/ VT::INT16}, CNNL_CAST_FLOAT_TO_INT16}, + {{VT::FP32, /*cast to*/ VT::INT8}, CNNL_CAST_FLOAT_TO_INT8}, + {{VT::FP32, /*cast to*/ VT::UINT8}, CNNL_CAST_FLOAT_TO_UINT8}, + {{VT::FP32, /*cast to*/ VT::BOOL}, CNNL_CAST_FLOAT_TO_BOOL}, + {{VT::FP16, /*cast to*/ VT::FP32}, CNNL_CAST_HALF_TO_FLOAT}, + {{VT::FP16, /*cast to*/ VT::INT32}, CNNL_CAST_HALF_TO_INT32}, + {{VT::FP16, /*cast to*/ VT::INT16}, CNNL_CAST_HALF_TO_INT16}, + {{VT::FP16, /*cast to*/ VT::INT8}, CNNL_CAST_HALF_TO_INT8}, + {{VT::FP16, /*cast to*/ VT::UINT8}, CNNL_CAST_HALF_TO_UINT8}, + {{VT::FP16, /*cast to*/ VT::BOOL}, CNNL_CAST_HALF_TO_BOOL}, + {{VT::INT32, /*cast to*/ VT::FP32}, CNNL_CAST_INT32_TO_FLOAT}, + {{VT::INT32, /*cast to*/ VT::FP16}, CNNL_CAST_INT32_TO_HALF}, + {{VT::INT32, /*cast to*/ VT::INT64}, CNNL_CAST_INT32_TO_INT64}, + {{VT::INT32, /*cast to*/ VT::INT16}, CNNL_CAST_INT32_TO_INT16}, + {{VT::INT32, /*cast to*/ VT::INT8}, CNNL_CAST_INT32_TO_INT8}, + {{VT::INT32, /*cast to*/ VT::BOOL}, CNNL_CAST_INT32_TO_BOOL}, + {{VT::INT16, /*cast to*/ VT::FP32}, CNNL_CAST_INT16_TO_FLOAT}, + {{VT::INT16, /*cast to*/ VT::FP16}, CNNL_CAST_INT16_TO_HALF}, + {{VT::INT16, /*cast to*/ VT::INT32}, CNNL_CAST_INT16_TO_INT32}, + {{VT::INT8, /*cast to*/ VT::FP32}, CNNL_CAST_INT8_TO_FLOAT}, + {{VT::INT8, /*cast to*/ VT::FP16}, CNNL_CAST_INT8_TO_HALF}, + {{VT::INT8, /*cast to*/ VT::INT32}, CNNL_CAST_INT8_TO_INT32}, + {{VT::UINT8, /*cast to*/ VT::FP32}, CNNL_CAST_UINT8_TO_FLOAT}, + {{VT::UINT8, /*cast to*/ VT::FP16}, CNNL_CAST_UINT8_TO_HALF}, + {{VT::UINT8, /*cast to*/ VT::INT64}, CNNL_CAST_UINT8_TO_INT64}, + {{VT::UINT8, /*cast to*/ VT::INT32}, CNNL_CAST_UINT8_TO_INT32}, + {{VT::BOOL, /*cast to*/ VT::FP32}, CNNL_CAST_BOOL_TO_FLOAT}, + {{VT::BOOL, /*cast to*/ VT::FP16}, CNNL_CAST_BOOL_TO_HALF}, + {{VT::BOOL, /*cast to*/ VT::INT32}, CNNL_CAST_BOOL_TO_INT32}, + {{VT::INT64, /*cast to*/ VT::INT32}, CNNL_CAST_INT64_TO_INT32}, +}; + +cnnlCastDataType_t GetCastDataType(const VT::Type& src_type, + const VT::Type& dst_type); +bool MLUSupportsCast(const VT::Type& src_type, const VT::Type& dst_type); + cnnlDeviceType_t GetCnnlDev(int dev_ordinal); using CnnlTensorDesc = cnnlTensorDescriptor_t; @@ -150,6 +197,8 @@ class MLUCnnlTensorDesc { MLUCnnlTensorDesc(const Tensor& tensor, const cnnlTensorLayout_t layout, const cnnlDataType_t tensor_dtype); + explicit MLUCnnlTensorDesc(const Tensor& tensor); + MLUCnnlTensorDesc(const Tensor& tensor, cnnlTensorLayout_t layout, const cnnlDataType_t tensor_dtype, int position); diff --git a/paddle/fluid/operators/scale_op_mlu.cc b/paddle/fluid/operators/scale_op_mlu.cc new file mode 100644 index 0000000000000000000000000000000000000000..8d9690a866ae26abe0817d98636a45d58735aefd --- /dev/null +++ b/paddle/fluid/operators/scale_op_mlu.cc @@ -0,0 +1,106 @@ +/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/operators/scale_op.h" +#include "paddle/fluid/operators/mlu/mlu_baseop.h" + +namespace paddle { +namespace operators { + +template +class ScaleMLUKernel : public framework::OpKernel { + public: + virtual void Compute(const framework::ExecutionContext& ctx) const { + auto& dev_ctx = GetDevCtxFromCTX(ctx); + auto* in_var = ctx.InputVar("X"); + auto* in = framework::GetLoDTensorOrSelectedRowsValueFromVar(*in_var); + + // cnnl require input, scale, bias with same type. And all in device side. + auto& scale = ctx.Attr("scale"); + framework::Tensor scale_tensor; + if (ctx.HasInput("ScaleTensor")) { + framework::Tensor float_scale_tensor = + *ctx.Input("ScaleTensor"); + if (float_scale_tensor.type() != in->type()) { + scale_tensor = ctx.AllocateTmpTensor({1}, dev_ctx); + MLUCnnlTensorDesc float_scale_desc(float_scale_tensor); + MLUCnnlTensorDesc final_scale_desc(scale_tensor); + cnnlCastDataType_t cast_type = + GetCastDataType(float_scale_tensor.type(), scale_tensor.type()); + MLUCnnl::Cast(ctx, cast_type, float_scale_desc.get(), + GetBasePtr(&float_scale_tensor), final_scale_desc.get(), + GetBasePtr(&scale_tensor)); + } else { + scale_tensor = float_scale_tensor; + } + } else { + scale_tensor = ctx.AllocateTmpTensor({1}, dev_ctx); + MLUCnnlTensorDesc scale_desc(scale_tensor); + MLUCnnl::Fill(ctx, scale, scale_desc.get(), GetBasePtr(&scale_tensor)); + } + + auto& bias = ctx.Attr("bias"); + framework::Tensor bias_tensor = + ctx.AllocateTmpTensor({1}, dev_ctx); + MLUCnnlTensorDesc bias_desc(bias_tensor); + MLUCnnl::Fill(ctx, bias, bias_desc.get(), GetBasePtr(&bias_tensor)); + + auto* out_var = ctx.OutputVar("Out"); + if (in_var->IsType() && in_var != out_var) { + auto& in_slr = in_var->Get(); + auto* out_slr = out_var->GetMutable(); + out_slr->set_rows(in_slr.rows()); + out_slr->set_height(in_slr.height()); + } + auto* out = + framework::GetMutableLoDTensorOrSelectedRowsValueFromVar(out_var); + out->mutable_data(in->place()); + + MLUCnnlTensorDesc input_desc(*in); + MLUCnnlTensorDesc scale_desc(scale_tensor); + MLUCnnlTensorDesc output_desc(*out); + + const int axis = std::max(in->dims().size() - 1, 0); + auto bias_after_scale = ctx.Attr("bias_after_scale"); + if (bias_after_scale) { + MLUCnnl::Scale(ctx, axis, input_desc.get(), GetBasePtr(in), + scale_desc.get(), GetBasePtr(&scale_tensor), + bias_desc.get(), GetBasePtr(&bias_tensor), + output_desc.get(), GetBasePtr(out)); + } else { + framework::Tensor new_bias_tensor = + ctx.AllocateTmpTensor({1}, dev_ctx); + MLUCnnlTensorDesc new_bias_desc(new_bias_tensor); + + MLUCnnlOpTensorDesc mul_op_desc(CNNL_OP_TENSOR_MUL, + ToCnnlDataType(in->type()), + CNNL_NOT_PROPAGATE_NAN); + MLUCnnl::OpTensor( + ctx, mul_op_desc.get(), scale_desc.get(), GetBasePtr(&scale_tensor), + bias_desc.get(), GetBasePtr(&bias_tensor), new_bias_desc.get(), + GetBasePtr(&new_bias_tensor), ToCnnlDataType(in->type())); + MLUCnnl::Scale(ctx, axis, input_desc.get(), GetBasePtr(in), + scale_desc.get(), GetBasePtr(&scale_tensor), + new_bias_desc.get(), GetBasePtr(&new_bias_tensor), + output_desc.get(), GetBasePtr(out)); + } + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +REGISTER_OP_MLU_KERNEL(scale, ops::ScaleMLUKernel, + ops::ScaleMLUKernel); diff --git a/python/paddle/fluid/tests/unittests/mlu/test_cast_op_mlu.py b/python/paddle/fluid/tests/unittests/mlu/test_cast_op_mlu.py new file mode 100644 index 0000000000000000000000000000000000000000..71f79c34d2312eda3952447a5748e66b44d1ab82 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/mlu/test_cast_op_mlu.py @@ -0,0 +1,135 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +import unittest +import numpy as np +import sys +sys.path.append("..") +from op_test import OpTest + +import paddle +import paddle.fluid.core as core +import paddle.fluid as fluid +from paddle.fluid import compiler, Program, program_guard + + +class TestCastOpFp32ToFp16(OpTest): + def setUp(self): + ipt = np.random.random(size=[10, 10]) + self.inputs = {'X': ipt.astype('float32')} + self.outputs = {'Out': ipt.astype('float16')} + self.attrs = { + 'in_dtype': int(core.VarDesc.VarType.FP32), + 'out_dtype': int(core.VarDesc.VarType.FP16) + } + self.op_type = 'cast' + self.place = paddle.device.MLUPlace(0) + self.__class__.use_mlu = True + self.__class__.no_need_check_grad = True + + def test_check_output(self): + self.check_output_with_place(self.place, atol=1e-3) + + +class TestCastOpFp16ToFp32(OpTest): + def setUp(self): + ipt = np.random.random(size=[10, 10]) + self.inputs = {'X': ipt.astype('float16')} + self.outputs = {'Out': ipt.astype('float32')} + self.attrs = { + 'in_dtype': int(core.VarDesc.VarType.FP16), + 'out_dtype': int(core.VarDesc.VarType.FP32) + } + self.op_type = 'cast' + self.place = paddle.device.MLUPlace(0) + self.__class__.use_mlu = True + + def test_check_output(self): + self.check_output_with_place(self.place, atol=1e-3) + + +class TestCastOpInt32ToInt32(OpTest): + def setUp(self): + ipt = np.random.randint(1000, size=(10, 10)) + self.inputs = {'X': ipt.astype('int32')} + self.outputs = {'Out': ipt.astype('int32')} + self.attrs = { + 'in_dtype': int(core.VarDesc.VarType.INT32), + 'out_dtype': int(core.VarDesc.VarType.INT32) + } + self.op_type = 'cast' + self.place = paddle.device.MLUPlace(0) + self.__class__.use_mlu = True + + def test_check_output(self): + self.check_output_with_place(self.place, atol=1e-3) + + +class TestCastOpInt32ToFp32(OpTest): + def setUp(self): + ipt = np.random.randint(1000, size=[10, 10]) + self.inputs = {'X': ipt.astype('int32')} + self.outputs = {'Out': ipt.astype('float32')} + self.attrs = { + 'in_dtype': int(core.VarDesc.VarType.INT32), + 'out_dtype': int(core.VarDesc.VarType.FP32) + } + self.op_type = 'cast' + self.place = paddle.device.MLUPlace(0) + self.__class__.use_mlu = True + + def test_check_output(self): + self.check_output_with_place(self.place, atol=1e-3) + + +class TestCastOpInt16ToFp64(OpTest): + def setUp(self): + ipt = np.random.randint(1000, size=[10, 10]) + self.inputs = {'X': ipt.astype('int16')} + self.outputs = {'Out': ipt.astype('int64')} + self.attrs = { + 'in_dtype': int(core.VarDesc.VarType.INT16), + 'out_dtype': int(core.VarDesc.VarType.INT64) + } + self.op_type = 'cast' + self.place = paddle.device.MLUPlace(0) + self.__class__.use_mlu = True + + def test_check_output(self): + self.check_output_with_place(self.place, atol=1e-3) + + +class TestCastOpError(unittest.TestCase): + def test_errors(self): + with program_guard(Program(), Program()): + # The input type of cast_op must be Variable. + x1 = fluid.create_lod_tensor( + np.array([[-1]]), [[1]], fluid.MLUPlace(0)) + self.assertRaises(TypeError, fluid.layers.cast, x1, 'int32') + # The input dtype of cast_op must be bool, float16, float32, float64, int32, int64, uint8. + x2 = fluid.layers.data(name='x2', shape=[4], dtype='int16') + self.assertRaises(TypeError, fluid.layers.cast, x2, 'int32') + + def test_dtype_type(): + x4 = fluid.layers.data(name='x4', shape=[4], dtype='int32') + output = fluid.layers.cast(x=x4, dtype='int16') + + self.assertRaises(TypeError, test_dtype_type) + + +if __name__ == '__main__': + paddle.enable_static() + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/mlu/test_scale_op_mlu.py b/python/paddle/fluid/tests/unittests/mlu/test_scale_op_mlu.py new file mode 100644 index 0000000000000000000000000000000000000000..bb7f438c4ab2b5eaae7ac2df2f7c1978186b1a05 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/mlu/test_scale_op_mlu.py @@ -0,0 +1,205 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +import unittest +import numpy as np +import sys +sys.path.append('..') +from op_test import OpTest +import paddle +import paddle.fluid as fluid +import paddle.fluid.core as core +from paddle.fluid.op import Operator +from paddle.static import Program, program_guard + + +class TestScaleOp(OpTest): + def setUp(self): + self.op_type = "scale" + self.place = paddle.device.MLUPlace(0) + self.__class__.use_mlu = True + self.dtype = np.float32 + self.init_dtype_type() + self.inputs = {'X': np.random.random((10, 10)).astype(self.dtype)} + self.attrs = {'scale': -2.3} + self.outputs = { + 'Out': self.inputs['X'] * self.dtype(self.attrs['scale']) + } + + def init_dtype_type(self): + pass + + def test_check_output(self): + self.check_output_with_place(self.place) + + +class TestScaleOpScaleVariable(OpTest): + def setUp(self): + self.op_type = "scale" + self.place = paddle.device.MLUPlace(0) + self.__class__.use_mlu = True + self.dtype = np.float32 + self.init_dtype_type() + self.scale = -2.3 + self.inputs = { + 'X': np.random.random((10, 10)).astype(self.dtype), + 'ScaleTensor': np.array([self.scale]).astype('float32') + } + self.attrs = {} + self.outputs = {'Out': self.inputs['X'] * self.dtype(self.scale)} + + def init_dtype_type(self): + pass + + def test_check_output(self): + self.check_output_with_place(self.place) + + +class TestScaleOpSelectedRows(unittest.TestCase): + def init_dtype_type(self): + pass + + def check_with_place(self, place, in_name, out_name): + scope = core.Scope() + + self.dtype = np.float32 + self.init_dtype_type() + + # create and initialize Grad Variable + in_height = 10 + in_rows = [0, 4, 7] + in_row_numel = 12 + scale = 2.0 + + in_selected_rows = scope.var(in_name).get_selected_rows() + in_selected_rows.set_height(in_height) + in_selected_rows.set_rows(in_rows) + in_array = np.random.random( + (len(in_rows), in_row_numel)).astype(self.dtype) + + in_tensor = in_selected_rows.get_tensor() + in_tensor.set(in_array, place) + + # create and initialize Param Variable + out_selected_rows = scope.var(out_name).get_selected_rows() + out_tensor = out_selected_rows.get_tensor() + out_tensor._set_dims(in_tensor._get_dims()) + + # create and run sgd operator + scale_op = Operator("scale", X=in_name, Out=out_name, scale=scale) + scale_op.run(scope, place) + + # get and compare result + out_height = out_selected_rows.height() + out_rows = out_selected_rows.rows() + result_array = np.array(out_tensor) + + assert (in_array * scale == result_array).all() + assert in_height == out_height + assert in_rows == out_rows + + def test_scale_selected_rows(self): + places = [core.CPUPlace()] + if core.is_compiled_with_mlu(): + places.append(core.MLUPlace(0)) + for place in places: + self.check_with_place(place, 'in', 'out') + + def test_scale_selected_rows_inplace(self): + places = [core.CPUPlace()] + if core.is_compiled_with_mlu(): + places.append(core.MLUPlace(0)) + for place in places: + self.check_with_place(place, 'in', 'in') + + +class TestScaleRaiseError(unittest.TestCase): + def test_errors(self): + def test_type(): + fluid.layers.scale([10]) + + self.assertRaises(TypeError, test_type) + + +# Add FP16 test +@unittest.skipIf(not core.is_compiled_with_mlu(), + "core is not compiled with MLU") +class TestScaleFp16Op(TestScaleOp): + def init_dtype_type(self): + self.dtype = np.float16 + + def test_check_output(self): + self.check_output_with_place(self.place, atol=0.002) + + +@unittest.skipIf(not core.is_compiled_with_mlu(), + "core is not compiled with MLU") +class TestScaleFp16OpSelectedRows(TestScaleOpSelectedRows): + def init_dtype_type(self): + self.dtype = np.float16 + + def test_scale_selected_rows(self): + place = core.MLUPlace(0) + self.check_with_place(place, 'in', 'out') + + def test_scale_selected_rows_inplace(self): + place = core.MLUPlace(0) + self.check_with_place(place, 'in', 'in') + + +class TestScaleApiStatic(unittest.TestCase): + def _executed_api(self, x, scale=1.0, bias=0.0): + return paddle.scale(x, scale, bias) + + def test_api(self): + paddle.enable_static() + input = np.random.random([2, 25]).astype("float32") + main_prog = Program() + with program_guard(main_prog, Program()): + x = paddle.static.data(name="x", shape=[2, 25], dtype="float32") + out = self._executed_api(x, scale=2.0, bias=3.0) + + exe = paddle.static.Executor(place=paddle.CPUPlace()) + out = exe.run(main_prog, feed={"x": input}, fetch_list=[out]) + self.assertEqual(np.array_equal(out[0], input * 2.0 + 3.0), True) + + +class TestScaleInplaceApiStatic(TestScaleApiStatic): + def _executed_api(self, x, scale=1.0, bias=0.0): + return x.scale_(scale, bias) + + +class TestScaleApiDygraph(unittest.TestCase): + def _executed_api(self, x, scale=1.0, bias=0.0): + return paddle.scale(x, scale, bias) + + def test_api(self): + paddle.disable_static() + input = np.random.random([2, 25]).astype("float32") + x = paddle.to_tensor(input) + out = self._executed_api(x, scale=2.0, bias=3.0) + self.assertEqual(np.array_equal(out.numpy(), input * 2.0 + 3.0), True) + paddle.enable_static() + + +class TestScaleInplaceApiDygraph(TestScaleApiDygraph): + def _executed_api(self, x, scale=1.0, bias=0.0): + return x.scale_(scale, bias) + + +if __name__ == "__main__": + paddle.enable_static() + unittest.main()