未验证 提交 e3e50ea8 编写于 作者: F fwenguang 提交者: GitHub

[MLU]add mlu kernel for cast and scale op (#38961)

上级 f1143f0c
......@@ -16,6 +16,9 @@ limitations under the License. */
#include <memory>
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/platform/float16.h"
#ifdef PADDLE_WITH_MLU
#include "paddle/fluid/operators/mlu/mlu_baseop.h"
#endif
namespace paddle {
namespace operators {
......@@ -102,6 +105,19 @@ class CastOp : public framework::OperatorWithKernel {
framework::DataLayout::kMKLDNN,
framework::LibraryType::kMKLDNN);
}
#endif
#ifdef PADDLE_WITH_MLU
auto src_type = static_cast<VT::Type>(ctx.Attr<int>("in_dtype"));
auto dst_type = static_cast<VT::Type>(ctx.Attr<int>("out_dtype"));
if (src_type == dst_type || MLUSupportsCast(src_type, dst_type)) {
return framework::OpKernelType(tensor->type(), tensor_place);
} else {
VLOG(3) << "MLU not support cast type: "
<< framework::DataTypeToString(src_type)
<< " to type: " << framework::DataTypeToString(dst_type)
<< ", fallbacking to CPU one!";
return framework::OpKernelType(tensor->type(), platform::CPUPlace());
}
#endif
return framework::OpKernelType(tensor->type(), tensor_place);
}
......
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/cast_op.h"
#include "paddle/fluid/operators/mlu/mlu_baseop.h"
#include "paddle/fluid/platform/device/mlu/device_context.h"
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
template <typename T>
class CastMLUKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto* input = ctx.Input<Tensor>("X");
auto* output = ctx.Output<Tensor>("Out");
auto src_type = static_cast<VT::Type>(ctx.Attr<int>("in_dtype"));
auto dst_type = static_cast<VT::Type>(ctx.Attr<int>("out_dtype"));
auto place = ctx.GetPlace();
if (src_type == dst_type) {
auto& dev_ctx = ctx.template device_context<platform::MLUDeviceContext>();
output->mutable_data<T>(place);
framework::TensorCopy(*input, place, dev_ctx, output);
return;
}
PADDLE_ENFORCE_EQ(MLUSupportsCast(src_type, dst_type), true,
platform::errors::InvalidArgument(
"MLU not support cast [%d] to [%d]",
framework::DataTypeToString(src_type),
framework::DataTypeToString(dst_type)));
switch (dst_type) {
case VT::FP32:
output->mutable_data<float>(place);
break;
case VT::FP16:
output->mutable_data<paddle::platform::float16>(place);
break;
case VT::INT32:
output->mutable_data<int32_t>(place);
break;
case VT::INT16:
output->mutable_data<int16_t>(place);
break;
case VT::INT8:
output->mutable_data<int8_t>(place);
break;
case VT::UINT8:
output->mutable_data<uint8_t>(place);
break;
case VT::BOOL:
output->mutable_data<bool>(place);
break;
case VT::INT64:
output->mutable_data<int64_t>(place);
break;
default:
PADDLE_THROW(platform::errors::Unavailable(
"Not supported cast %d -> %d",
framework::DataTypeToString(src_type),
framework::DataTypeToString(dst_type)));
}
MLUCnnlTensorDesc input_desc(*input);
MLUCnnlTensorDesc output_desc(*output);
cnnlCastDataType_t cast_type = GetCastDataType(src_type, dst_type);
MLUCnnl::Cast(ctx, cast_type, input_desc.get(), GetBasePtr(input),
output_desc.get(), GetBasePtr(output));
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP_MLU_KERNEL(cast, ops::CastMLUKernel<float>, ops::CastMLUKernel<int>,
ops::CastMLUKernel<int16_t>, ops::CastMLUKernel<uint8_t>,
ops::CastMLUKernel<bool>, ops::CastMLUKernel<int64_t>,
ops::CastMLUKernel<paddle::platform::float16>);
......@@ -20,6 +20,29 @@ limitations under the License. */
namespace paddle {
namespace operators {
cnnlCastDataType_t GetCastDataType(const VT::Type& src_type,
const VT::Type& dst_type) {
cnnlCastDataType_t cast_type = CNNL_CAST_FLOAT_TO_HALF;
for (auto it = MLU_SUPPORTED_CAST_TYPE.begin();
it != MLU_SUPPORTED_CAST_TYPE.end(); ++it) {
if (it->first.first == src_type && it->first.second == dst_type) {
cast_type = it->second;
break;
}
}
return cast_type;
}
bool MLUSupportsCast(const VT::Type& src_type, const VT::Type& dst_type) {
for (auto it = MLU_SUPPORTED_CAST_TYPE.begin();
it != MLU_SUPPORTED_CAST_TYPE.end(); ++it) {
if (it->first.first == src_type && it->first.second == dst_type) {
return true;
}
}
return false;
}
class MLUCnnlTensorDescPool {
public:
cnnlTensorDescriptor_t Pop() {
......@@ -153,6 +176,10 @@ MLUCnnlTensorDesc::MLUCnnlTensorDesc(const Tensor& tensor,
}
}
MLUCnnlTensorDesc::MLUCnnlTensorDesc(const Tensor& tensor)
: MLUCnnlTensorDesc(tensor, CNNL_LAYOUT_ARRAY,
ToCnnlDataType(tensor.type())) {}
MLUCnnlTensorDesc::MLUCnnlTensorDesc(const Tensor& tensor,
cnnlTensorLayout_t layout,
const cnnlDataType_t tensor_dtype,
......@@ -1848,7 +1875,7 @@ MLUCnnlTrigonDesc::~MLUCnnlTrigonDesc() {
if (is_training) {
/*
* If in Paddle, running_mean_output = momentum * runnning_mean_input +
* In Paddle, running_mean_output = momentum * runnning_mean_input +
* (1 - momentum) * batch_mean. However, In CNNL,
* running_mean_output = (1 - momentum) * running_mean_input +
* momentum * batch_mean. So we pass (1.0 - momentum) to momentum param.
......
......@@ -74,6 +74,9 @@ inline cnnlDataType_t ToCnnlDataType(const framework::proto::VarType::Type& t) {
case framework::proto::VarType::INT8:
type = CNNL_DTYPE_INT8;
break;
case framework::proto::VarType::INT16:
type = CNNL_DTYPE_INT16;
break;
case framework::proto::VarType::INT32:
type = CNNL_DTYPE_INT32;
break;
......@@ -83,6 +86,9 @@ inline cnnlDataType_t ToCnnlDataType(const framework::proto::VarType::Type& t) {
case framework::proto::VarType::BOOL:
type = CNNL_DTYPE_BOOL;
break;
case framework::proto::VarType::UINT8:
type = CNNL_DTYPE_UINT8;
break;
default:
break;
}
......@@ -108,6 +114,47 @@ inline static const MLUDeviceContext& GetDevCtxFromCTX(
return ctx.template device_context<MLUDeviceContext>();
}
using VT = framework::proto::VarType;
const std::map<std::pair<VT::Type, VT::Type>, cnnlCastDataType_t>
MLU_SUPPORTED_CAST_TYPE = {
{{VT::FP32, /*cast to*/ VT::FP16}, CNNL_CAST_FLOAT_TO_HALF},
{{VT::FP32, /*cast to*/ VT::INT32}, CNNL_CAST_FLOAT_TO_INT32},
{{VT::FP32, /*cast to*/ VT::INT16}, CNNL_CAST_FLOAT_TO_INT16},
{{VT::FP32, /*cast to*/ VT::INT8}, CNNL_CAST_FLOAT_TO_INT8},
{{VT::FP32, /*cast to*/ VT::UINT8}, CNNL_CAST_FLOAT_TO_UINT8},
{{VT::FP32, /*cast to*/ VT::BOOL}, CNNL_CAST_FLOAT_TO_BOOL},
{{VT::FP16, /*cast to*/ VT::FP32}, CNNL_CAST_HALF_TO_FLOAT},
{{VT::FP16, /*cast to*/ VT::INT32}, CNNL_CAST_HALF_TO_INT32},
{{VT::FP16, /*cast to*/ VT::INT16}, CNNL_CAST_HALF_TO_INT16},
{{VT::FP16, /*cast to*/ VT::INT8}, CNNL_CAST_HALF_TO_INT8},
{{VT::FP16, /*cast to*/ VT::UINT8}, CNNL_CAST_HALF_TO_UINT8},
{{VT::FP16, /*cast to*/ VT::BOOL}, CNNL_CAST_HALF_TO_BOOL},
{{VT::INT32, /*cast to*/ VT::FP32}, CNNL_CAST_INT32_TO_FLOAT},
{{VT::INT32, /*cast to*/ VT::FP16}, CNNL_CAST_INT32_TO_HALF},
{{VT::INT32, /*cast to*/ VT::INT64}, CNNL_CAST_INT32_TO_INT64},
{{VT::INT32, /*cast to*/ VT::INT16}, CNNL_CAST_INT32_TO_INT16},
{{VT::INT32, /*cast to*/ VT::INT8}, CNNL_CAST_INT32_TO_INT8},
{{VT::INT32, /*cast to*/ VT::BOOL}, CNNL_CAST_INT32_TO_BOOL},
{{VT::INT16, /*cast to*/ VT::FP32}, CNNL_CAST_INT16_TO_FLOAT},
{{VT::INT16, /*cast to*/ VT::FP16}, CNNL_CAST_INT16_TO_HALF},
{{VT::INT16, /*cast to*/ VT::INT32}, CNNL_CAST_INT16_TO_INT32},
{{VT::INT8, /*cast to*/ VT::FP32}, CNNL_CAST_INT8_TO_FLOAT},
{{VT::INT8, /*cast to*/ VT::FP16}, CNNL_CAST_INT8_TO_HALF},
{{VT::INT8, /*cast to*/ VT::INT32}, CNNL_CAST_INT8_TO_INT32},
{{VT::UINT8, /*cast to*/ VT::FP32}, CNNL_CAST_UINT8_TO_FLOAT},
{{VT::UINT8, /*cast to*/ VT::FP16}, CNNL_CAST_UINT8_TO_HALF},
{{VT::UINT8, /*cast to*/ VT::INT64}, CNNL_CAST_UINT8_TO_INT64},
{{VT::UINT8, /*cast to*/ VT::INT32}, CNNL_CAST_UINT8_TO_INT32},
{{VT::BOOL, /*cast to*/ VT::FP32}, CNNL_CAST_BOOL_TO_FLOAT},
{{VT::BOOL, /*cast to*/ VT::FP16}, CNNL_CAST_BOOL_TO_HALF},
{{VT::BOOL, /*cast to*/ VT::INT32}, CNNL_CAST_BOOL_TO_INT32},
{{VT::INT64, /*cast to*/ VT::INT32}, CNNL_CAST_INT64_TO_INT32},
};
cnnlCastDataType_t GetCastDataType(const VT::Type& src_type,
const VT::Type& dst_type);
bool MLUSupportsCast(const VT::Type& src_type, const VT::Type& dst_type);
cnnlDeviceType_t GetCnnlDev(int dev_ordinal);
using CnnlTensorDesc = cnnlTensorDescriptor_t;
......@@ -150,6 +197,8 @@ class MLUCnnlTensorDesc {
MLUCnnlTensorDesc(const Tensor& tensor, const cnnlTensorLayout_t layout,
const cnnlDataType_t tensor_dtype);
explicit MLUCnnlTensorDesc(const Tensor& tensor);
MLUCnnlTensorDesc(const Tensor& tensor, cnnlTensorLayout_t layout,
const cnnlDataType_t tensor_dtype, int position);
......
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/scale_op.h"
#include "paddle/fluid/operators/mlu/mlu_baseop.h"
namespace paddle {
namespace operators {
template <typename T>
class ScaleMLUKernel : public framework::OpKernel<T> {
public:
virtual void Compute(const framework::ExecutionContext& ctx) const {
auto& dev_ctx = GetDevCtxFromCTX(ctx);
auto* in_var = ctx.InputVar("X");
auto* in = framework::GetLoDTensorOrSelectedRowsValueFromVar(*in_var);
// cnnl require input, scale, bias with same type. And all in device side.
auto& scale = ctx.Attr<float>("scale");
framework::Tensor scale_tensor;
if (ctx.HasInput("ScaleTensor")) {
framework::Tensor float_scale_tensor =
*ctx.Input<framework::Tensor>("ScaleTensor");
if (float_scale_tensor.type() != in->type()) {
scale_tensor = ctx.AllocateTmpTensor<T, MLUDeviceContext>({1}, dev_ctx);
MLUCnnlTensorDesc float_scale_desc(float_scale_tensor);
MLUCnnlTensorDesc final_scale_desc(scale_tensor);
cnnlCastDataType_t cast_type =
GetCastDataType(float_scale_tensor.type(), scale_tensor.type());
MLUCnnl::Cast(ctx, cast_type, float_scale_desc.get(),
GetBasePtr(&float_scale_tensor), final_scale_desc.get(),
GetBasePtr(&scale_tensor));
} else {
scale_tensor = float_scale_tensor;
}
} else {
scale_tensor = ctx.AllocateTmpTensor<T, MLUDeviceContext>({1}, dev_ctx);
MLUCnnlTensorDesc scale_desc(scale_tensor);
MLUCnnl::Fill(ctx, scale, scale_desc.get(), GetBasePtr(&scale_tensor));
}
auto& bias = ctx.Attr<float>("bias");
framework::Tensor bias_tensor =
ctx.AllocateTmpTensor<T, MLUDeviceContext>({1}, dev_ctx);
MLUCnnlTensorDesc bias_desc(bias_tensor);
MLUCnnl::Fill(ctx, bias, bias_desc.get(), GetBasePtr(&bias_tensor));
auto* out_var = ctx.OutputVar("Out");
if (in_var->IsType<framework::SelectedRows>() && in_var != out_var) {
auto& in_slr = in_var->Get<framework::SelectedRows>();
auto* out_slr = out_var->GetMutable<framework::SelectedRows>();
out_slr->set_rows(in_slr.rows());
out_slr->set_height(in_slr.height());
}
auto* out =
framework::GetMutableLoDTensorOrSelectedRowsValueFromVar(out_var);
out->mutable_data<T>(in->place());
MLUCnnlTensorDesc input_desc(*in);
MLUCnnlTensorDesc scale_desc(scale_tensor);
MLUCnnlTensorDesc output_desc(*out);
const int axis = std::max(in->dims().size() - 1, 0);
auto bias_after_scale = ctx.Attr<bool>("bias_after_scale");
if (bias_after_scale) {
MLUCnnl::Scale(ctx, axis, input_desc.get(), GetBasePtr(in),
scale_desc.get(), GetBasePtr(&scale_tensor),
bias_desc.get(), GetBasePtr(&bias_tensor),
output_desc.get(), GetBasePtr(out));
} else {
framework::Tensor new_bias_tensor =
ctx.AllocateTmpTensor<T, MLUDeviceContext>({1}, dev_ctx);
MLUCnnlTensorDesc new_bias_desc(new_bias_tensor);
MLUCnnlOpTensorDesc mul_op_desc(CNNL_OP_TENSOR_MUL,
ToCnnlDataType(in->type()),
CNNL_NOT_PROPAGATE_NAN);
MLUCnnl::OpTensor(
ctx, mul_op_desc.get(), scale_desc.get(), GetBasePtr(&scale_tensor),
bias_desc.get(), GetBasePtr(&bias_tensor), new_bias_desc.get(),
GetBasePtr(&new_bias_tensor), ToCnnlDataType(in->type()));
MLUCnnl::Scale(ctx, axis, input_desc.get(), GetBasePtr(in),
scale_desc.get(), GetBasePtr(&scale_tensor),
new_bias_desc.get(), GetBasePtr(&new_bias_tensor),
output_desc.get(), GetBasePtr(out));
}
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP_MLU_KERNEL(scale, ops::ScaleMLUKernel<float>,
ops::ScaleMLUKernel<paddle::platform::float16>);
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import unittest
import numpy as np
import sys
sys.path.append("..")
from op_test import OpTest
import paddle
import paddle.fluid.core as core
import paddle.fluid as fluid
from paddle.fluid import compiler, Program, program_guard
class TestCastOpFp32ToFp16(OpTest):
def setUp(self):
ipt = np.random.random(size=[10, 10])
self.inputs = {'X': ipt.astype('float32')}
self.outputs = {'Out': ipt.astype('float16')}
self.attrs = {
'in_dtype': int(core.VarDesc.VarType.FP32),
'out_dtype': int(core.VarDesc.VarType.FP16)
}
self.op_type = 'cast'
self.place = paddle.device.MLUPlace(0)
self.__class__.use_mlu = True
self.__class__.no_need_check_grad = True
def test_check_output(self):
self.check_output_with_place(self.place, atol=1e-3)
class TestCastOpFp16ToFp32(OpTest):
def setUp(self):
ipt = np.random.random(size=[10, 10])
self.inputs = {'X': ipt.astype('float16')}
self.outputs = {'Out': ipt.astype('float32')}
self.attrs = {
'in_dtype': int(core.VarDesc.VarType.FP16),
'out_dtype': int(core.VarDesc.VarType.FP32)
}
self.op_type = 'cast'
self.place = paddle.device.MLUPlace(0)
self.__class__.use_mlu = True
def test_check_output(self):
self.check_output_with_place(self.place, atol=1e-3)
class TestCastOpInt32ToInt32(OpTest):
def setUp(self):
ipt = np.random.randint(1000, size=(10, 10))
self.inputs = {'X': ipt.astype('int32')}
self.outputs = {'Out': ipt.astype('int32')}
self.attrs = {
'in_dtype': int(core.VarDesc.VarType.INT32),
'out_dtype': int(core.VarDesc.VarType.INT32)
}
self.op_type = 'cast'
self.place = paddle.device.MLUPlace(0)
self.__class__.use_mlu = True
def test_check_output(self):
self.check_output_with_place(self.place, atol=1e-3)
class TestCastOpInt32ToFp32(OpTest):
def setUp(self):
ipt = np.random.randint(1000, size=[10, 10])
self.inputs = {'X': ipt.astype('int32')}
self.outputs = {'Out': ipt.astype('float32')}
self.attrs = {
'in_dtype': int(core.VarDesc.VarType.INT32),
'out_dtype': int(core.VarDesc.VarType.FP32)
}
self.op_type = 'cast'
self.place = paddle.device.MLUPlace(0)
self.__class__.use_mlu = True
def test_check_output(self):
self.check_output_with_place(self.place, atol=1e-3)
class TestCastOpInt16ToFp64(OpTest):
def setUp(self):
ipt = np.random.randint(1000, size=[10, 10])
self.inputs = {'X': ipt.astype('int16')}
self.outputs = {'Out': ipt.astype('int64')}
self.attrs = {
'in_dtype': int(core.VarDesc.VarType.INT16),
'out_dtype': int(core.VarDesc.VarType.INT64)
}
self.op_type = 'cast'
self.place = paddle.device.MLUPlace(0)
self.__class__.use_mlu = True
def test_check_output(self):
self.check_output_with_place(self.place, atol=1e-3)
class TestCastOpError(unittest.TestCase):
def test_errors(self):
with program_guard(Program(), Program()):
# The input type of cast_op must be Variable.
x1 = fluid.create_lod_tensor(
np.array([[-1]]), [[1]], fluid.MLUPlace(0))
self.assertRaises(TypeError, fluid.layers.cast, x1, 'int32')
# The input dtype of cast_op must be bool, float16, float32, float64, int32, int64, uint8.
x2 = fluid.layers.data(name='x2', shape=[4], dtype='int16')
self.assertRaises(TypeError, fluid.layers.cast, x2, 'int32')
def test_dtype_type():
x4 = fluid.layers.data(name='x4', shape=[4], dtype='int32')
output = fluid.layers.cast(x=x4, dtype='int16')
self.assertRaises(TypeError, test_dtype_type)
if __name__ == '__main__':
paddle.enable_static()
unittest.main()
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import unittest
import numpy as np
import sys
sys.path.append('..')
from op_test import OpTest
import paddle
import paddle.fluid as fluid
import paddle.fluid.core as core
from paddle.fluid.op import Operator
from paddle.static import Program, program_guard
class TestScaleOp(OpTest):
def setUp(self):
self.op_type = "scale"
self.place = paddle.device.MLUPlace(0)
self.__class__.use_mlu = True
self.dtype = np.float32
self.init_dtype_type()
self.inputs = {'X': np.random.random((10, 10)).astype(self.dtype)}
self.attrs = {'scale': -2.3}
self.outputs = {
'Out': self.inputs['X'] * self.dtype(self.attrs['scale'])
}
def init_dtype_type(self):
pass
def test_check_output(self):
self.check_output_with_place(self.place)
class TestScaleOpScaleVariable(OpTest):
def setUp(self):
self.op_type = "scale"
self.place = paddle.device.MLUPlace(0)
self.__class__.use_mlu = True
self.dtype = np.float32
self.init_dtype_type()
self.scale = -2.3
self.inputs = {
'X': np.random.random((10, 10)).astype(self.dtype),
'ScaleTensor': np.array([self.scale]).astype('float32')
}
self.attrs = {}
self.outputs = {'Out': self.inputs['X'] * self.dtype(self.scale)}
def init_dtype_type(self):
pass
def test_check_output(self):
self.check_output_with_place(self.place)
class TestScaleOpSelectedRows(unittest.TestCase):
def init_dtype_type(self):
pass
def check_with_place(self, place, in_name, out_name):
scope = core.Scope()
self.dtype = np.float32
self.init_dtype_type()
# create and initialize Grad Variable
in_height = 10
in_rows = [0, 4, 7]
in_row_numel = 12
scale = 2.0
in_selected_rows = scope.var(in_name).get_selected_rows()
in_selected_rows.set_height(in_height)
in_selected_rows.set_rows(in_rows)
in_array = np.random.random(
(len(in_rows), in_row_numel)).astype(self.dtype)
in_tensor = in_selected_rows.get_tensor()
in_tensor.set(in_array, place)
# create and initialize Param Variable
out_selected_rows = scope.var(out_name).get_selected_rows()
out_tensor = out_selected_rows.get_tensor()
out_tensor._set_dims(in_tensor._get_dims())
# create and run sgd operator
scale_op = Operator("scale", X=in_name, Out=out_name, scale=scale)
scale_op.run(scope, place)
# get and compare result
out_height = out_selected_rows.height()
out_rows = out_selected_rows.rows()
result_array = np.array(out_tensor)
assert (in_array * scale == result_array).all()
assert in_height == out_height
assert in_rows == out_rows
def test_scale_selected_rows(self):
places = [core.CPUPlace()]
if core.is_compiled_with_mlu():
places.append(core.MLUPlace(0))
for place in places:
self.check_with_place(place, 'in', 'out')
def test_scale_selected_rows_inplace(self):
places = [core.CPUPlace()]
if core.is_compiled_with_mlu():
places.append(core.MLUPlace(0))
for place in places:
self.check_with_place(place, 'in', 'in')
class TestScaleRaiseError(unittest.TestCase):
def test_errors(self):
def test_type():
fluid.layers.scale([10])
self.assertRaises(TypeError, test_type)
# Add FP16 test
@unittest.skipIf(not core.is_compiled_with_mlu(),
"core is not compiled with MLU")
class TestScaleFp16Op(TestScaleOp):
def init_dtype_type(self):
self.dtype = np.float16
def test_check_output(self):
self.check_output_with_place(self.place, atol=0.002)
@unittest.skipIf(not core.is_compiled_with_mlu(),
"core is not compiled with MLU")
class TestScaleFp16OpSelectedRows(TestScaleOpSelectedRows):
def init_dtype_type(self):
self.dtype = np.float16
def test_scale_selected_rows(self):
place = core.MLUPlace(0)
self.check_with_place(place, 'in', 'out')
def test_scale_selected_rows_inplace(self):
place = core.MLUPlace(0)
self.check_with_place(place, 'in', 'in')
class TestScaleApiStatic(unittest.TestCase):
def _executed_api(self, x, scale=1.0, bias=0.0):
return paddle.scale(x, scale, bias)
def test_api(self):
paddle.enable_static()
input = np.random.random([2, 25]).astype("float32")
main_prog = Program()
with program_guard(main_prog, Program()):
x = paddle.static.data(name="x", shape=[2, 25], dtype="float32")
out = self._executed_api(x, scale=2.0, bias=3.0)
exe = paddle.static.Executor(place=paddle.CPUPlace())
out = exe.run(main_prog, feed={"x": input}, fetch_list=[out])
self.assertEqual(np.array_equal(out[0], input * 2.0 + 3.0), True)
class TestScaleInplaceApiStatic(TestScaleApiStatic):
def _executed_api(self, x, scale=1.0, bias=0.0):
return x.scale_(scale, bias)
class TestScaleApiDygraph(unittest.TestCase):
def _executed_api(self, x, scale=1.0, bias=0.0):
return paddle.scale(x, scale, bias)
def test_api(self):
paddle.disable_static()
input = np.random.random([2, 25]).astype("float32")
x = paddle.to_tensor(input)
out = self._executed_api(x, scale=2.0, bias=3.0)
self.assertEqual(np.array_equal(out.numpy(), input * 2.0 + 3.0), True)
paddle.enable_static()
class TestScaleInplaceApiDygraph(TestScaleApiDygraph):
def _executed_api(self, x, scale=1.0, bias=0.0):
return x.scale_(scale, bias)
if __name__ == "__main__":
paddle.enable_static()
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册