Unverified commit 4cab812e, authored by fwenguang, committed by GitHub

[MLU] add transpose2 mlu kernel (#39994)

Parent 4e00d2bb
......@@ -1157,19 +1157,22 @@ inline void TransposeFromMLUTensor(const ExecutionContext& ctx,
                                    const std::vector<int> perm,
                                    const Tensor* transformed_input,
                                    Tensor* transformed_output,
                                    bool need_reshape_or_alloc) {
-  auto in_dims_vec = phi::vectorize(transformed_input->dims());
+  const int dim_size = perm.size();
   if (need_reshape_or_alloc) {
+    std::vector<int> output_shape;
+    auto input_dims = transformed_input->dims();
+    for (int i = 0; i < dim_size; ++i) {
+      output_shape.push_back(input_dims[perm[i]]);
+    }
     transformed_output->mutable_data<T>(
-        {in_dims_vec[perm[0]], in_dims_vec[perm[1]], in_dims_vec[perm[2]],
-         in_dims_vec[perm[3]]},
-        ctx.GetPlace());
+        framework::DDim(output_shape.data(), dim_size), ctx.GetPlace());
   }
   MLUCnnlTensorDesc trans_in_desc(*transformed_input, CNNL_LAYOUT_ARRAY,
                                   ToCnnlDataType<T>());
   MLUCnnlTensorDesc trans_out_desc(*transformed_output, CNNL_LAYOUT_ARRAY,
                                    ToCnnlDataType<T>());
-  MLUCnnl::Transpose(ctx, perm, in_dims_vec.size(), trans_in_desc.get(),
+  MLUCnnl::Transpose(ctx, perm, dim_size, trans_in_desc.get(),
                      GetBasePtr(transformed_input), trans_out_desc.get(),
                      GetBasePtr(transformed_output));
 }
......
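The rewritten shape logic above is no longer limited to 4-D tensors: the transposed output shape is built by permuting the input dimensions, whatever the rank. A minimal standalone sketch of that idea (the name PermuteShape is illustrative only, not part of the Paddle API):

#include <cstdint>
#include <iostream>
#include <vector>

// Permute a shape by `perm`, the same way the MLU helper now builds the
// transposed output shape: out_shape[i] = in_shape[perm[i]].
std::vector<int64_t> PermuteShape(const std::vector<int64_t>& in_shape,
                                  const std::vector<int>& perm) {
  std::vector<int64_t> out_shape;
  out_shape.reserve(perm.size());
  for (int axis : perm) {
    out_shape.push_back(in_shape[axis]);
  }
  return out_shape;
}

int main() {
  // A 5-D shape, which the old hard-coded 4-D path could not handle.
  std::vector<int64_t> shape = {2, 3, 4, 5, 6};
  std::vector<int> perm = {4, 2, 3, 1, 0};
  for (int64_t d : PermuteShape(shape, perm)) std::cout << d << ' ';
  std::cout << '\n';  // prints: 6 4 5 3 2
  return 0;
}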
......@@ -27,11 +27,11 @@ class ReduceMaxMLUKernel : public framework::OpKernel<T> {
     int out_dtype = context.Attr<int>("out_dtype");
     bool reduce_all = context.Attr<bool>("reduce_all");
     auto dims = context.Attr<std::vector<int>>("dim");
-    auto input_dims = framework::vectorize(input->dims());
+    auto input_dims = input->dims();
     const auto& input_dim_size = input->dims().size();
     std::vector<int> reduce_dims;
     if (reduce_all) {
-      for (size_t i = 0; i < input_dims.size(); i++) {
+      for (int i = 0; i < input_dims.size(); i++) {
         reduce_dims.push_back(static_cast<int>(i));
       }
     } else {
......
......@@ -27,11 +27,11 @@ class ReduceMinMLUKernel : public framework::OpKernel<T> {
     int out_dtype = context.Attr<int>("out_dtype");
     bool reduce_all = context.Attr<bool>("reduce_all");
     auto dims = context.Attr<std::vector<int>>("dim");
-    auto input_dims = framework::vectorize(input->dims());
+    auto input_dims = input->dims();
     const auto& input_dim_size = input->dims().size();
     std::vector<int> reduce_dims;
     if (reduce_all) {
-      for (size_t i = 0; i < input_dims.size(); i++) {
+      for (int i = 0; i < input_dims.size(); i++) {
         reduce_dims.push_back(static_cast<int>(i));
       }
     } else {
......
......@@ -37,7 +37,7 @@ class SoftmaxWithCrossEntropyMLUKernel : public framework::OpKernel<T> {
"the mlu kernel of softmax_with_cross_entropy."));
const int rank = logits->dims().size();
const int axis = CanonicalAxis(ctx.Attr<int>("axis"), rank);
const int axis = phi::funcs::CanonicalAxis(ctx.Attr<int>("axis"), rank);
loss->mutable_data<T>(ctx.GetPlace());
backprop->mutable_data<T>(ctx.GetPlace());
......@@ -45,10 +45,10 @@ class SoftmaxWithCrossEntropyMLUKernel : public framework::OpKernel<T> {
     // cnnl softmax only supports 3-dims, regard all shapes as [d1, d2, d3]
     const int cnnl_softmax_dims = 3;
-    const int d1 = SizeToAxis(axis, logits->dims());
+    const int d1 = phi::funcs::SizeToAxis(axis, logits->dims());
     const int d2_logits = logits->dims()[axis];
     const int d2_labels = labels->dims()[axis];
-    const int d3 = SizeOutAxis(axis, logits->dims());
+    const int d3 = phi::funcs::SizeOutAxis(axis, logits->dims());
     // CNNL_SOFTMAX_MODE_LOW_DIMENSION has better performance, use it as much
     // as possible.
......
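The [d1, d2, d3] flattening above collapses an arbitrary-rank logits tensor into three dimensions around `axis`: everything before the axis, the axis itself, and everything after it. A simplified standalone illustration of what SizeToAxis and SizeOutAxis compute (this is a sketch of the usual definition, not the actual phi source):

#include <cassert>
#include <cstdint>
#include <vector>

// d1: product of the dimensions before `axis`.
int64_t SizeToAxis(int axis, const std::vector<int64_t>& dims) {
  int64_t size = 1;
  for (int i = 0; i < axis; ++i) size *= dims[i];
  return size;
}

// d3: product of the dimensions after `axis`; the axis dimension itself is d2.
int64_t SizeOutAxis(int axis, const std::vector<int64_t>& dims) {
  int64_t size = 1;
  for (int i = axis + 1; i < static_cast<int>(dims.size()); ++i) size *= dims[i];
  return size;
}

int main() {
  std::vector<int64_t> dims = {2, 3, 4, 5};
  int axis = 2;
  assert(SizeToAxis(axis, dims) == 6);   // d1 = 2 * 3
  assert(dims[axis] == 4);               // d2 = 4
  assert(SizeOutAxis(axis, dims) == 5);  // d3 = 5
  return 0;
}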
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/transpose_op.h"
#include "paddle/fluid/operators/mlu/mlu_baseop.h"
namespace paddle {
namespace operators {
template <typename T>
class TransposeMLUKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    auto* x = ctx.Input<framework::LoDTensor>("X");
    auto* out = ctx.Output<framework::LoDTensor>("Out");
    std::vector<int> axis = ctx.Attr<std::vector<int>>("axis");
    out->mutable_data<T>(ctx.device_context().GetPlace());
    TransposeFromMLUTensor<T>(ctx, axis, x, out,
                              false /*need_reshape_or_alloc*/);
  }
};
template <typename T>
class TransposeGradMLUKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    auto* out_grad =
        ctx.Input<framework::LoDTensor>(framework::GradVarName("Out"));
    auto* x_grad =
        ctx.Output<framework::LoDTensor>(framework::GradVarName("X"));
    std::vector<int> axis = ctx.Attr<std::vector<int>>("axis");
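    // The forward op places input dimension axis[i] at output position i, so
    // the gradient is itself a transpose that uses the inverse permutation:
    // reversed_axis[axis[i]] = i.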
    std::vector<int> reversed_axis(axis);
    for (size_t i = 0; i < axis.size(); i++) {
      reversed_axis[axis[i]] = i;
    }
    x_grad->mutable_data<T>(ctx.GetPlace());
    TransposeFromMLUTensor<T>(ctx, reversed_axis, out_grad, x_grad,
                              false /*need_reshape_or_alloc*/);
  }
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP_MLU_KERNEL(transpose2, ops::TransposeMLUKernel<float>,
                       ops::TransposeMLUKernel<paddle::platform::float16>,
                       ops::TransposeMLUKernel<int>,
                       ops::TransposeMLUKernel<int16_t>,
                       ops::TransposeMLUKernel<uint8_t>,
                       ops::TransposeMLUKernel<int8_t>,
                       ops::TransposeMLUKernel<bool>);

REGISTER_OP_MLU_KERNEL(transpose2_grad, ops::TransposeGradMLUKernel<float>,
                       ops::TransposeGradMLUKernel<paddle::platform::float16>,
                       ops::TransposeGradMLUKernel<int>,
                       ops::TransposeGradMLUKernel<int16_t>,
                       ops::TransposeGradMLUKernel<uint8_t>,
                       ops::TransposeGradMLUKernel<int8_t>,
                       ops::TransposeGradMLUKernel<bool>);
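For intuition on the reversed_axis computation in the grad kernel above, here is a small self-contained check, independent of Paddle, that a permutation composed with the inverse built this way is the identity:

#include <cassert>
#include <vector>

int main() {
  // A forward permutation as passed to transpose2, e.g. axis = (0, 2, 3, 1).
  std::vector<int> axis = {0, 2, 3, 1};

  // Inverse permutation, built the same way TransposeGradMLUKernel does.
  std::vector<int> reversed_axis(axis.size());
  for (size_t i = 0; i < axis.size(); ++i) {
    reversed_axis[axis[i]] = static_cast<int>(i);
  }

  // Applying axis and then reversed_axis returns every dimension to its
  // original position, which is why the gradient is just another transpose.
  for (size_t i = 0; i < axis.size(); ++i) {
    assert(reversed_axis[axis[i]] == static_cast<int>(i));
    assert(axis[reversed_axis[i]] == static_cast<int>(i));
  }
  return 0;
}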
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import unittest
import numpy as np
import sys
sys.path.append('..')
from op_test import OpTest, convert_float_to_uint16
import paddle
import paddle.fluid as fluid
from paddle.fluid import Program, program_guard
import paddle.fluid.core as core
paddle.enable_static()
class TestTransposeOp(OpTest):
    def setUp(self):
        self.init_op_type()
        self.initKernelType()
        self.initTestCase()
        self.inputs = {'X': np.random.random(self.shape).astype("float32")}
        self.attrs = {'axis': list(self.axis), }
        self.outputs = {'Out': self.inputs['X'].transpose(self.axis)}

    def init_op_type(self):
        self.op_type = "transpose2"

    def test_check_output(self):
        self.check_output_with_place(self.place)

    def test_check_grad(self):
        self.check_grad_with_place(self.place, ['X'], 'Out')

    def initTestCase(self):
        self.shape = (3, 40)
        self.axis = (1, 0)

    def initKernelType(self):
        self.__class__.use_mlu = True
        self.place = paddle.device.MLUPlace(0)
class TestCase0(TestTransposeOp):
    def initTestCase(self):
        self.shape = (100, )
        self.axis = (0, )

class TestCase1(TestTransposeOp):
    def initTestCase(self):
        self.shape = (3, 4, 10)
        self.axis = (0, 2, 1)

class TestCase2(TestTransposeOp):
    def initTestCase(self):
        self.shape = (2, 3, 4, 5)
        self.axis = (0, 2, 3, 1)

class TestCase3(TestTransposeOp):
    def initTestCase(self):
        self.shape = (2, 3, 4, 5, 6)
        self.axis = (4, 2, 3, 1, 0)

class TestCase4(TestTransposeOp):
    def initTestCase(self):
        self.shape = (2, 3, 4, 5, 6, 1)
        self.axis = (4, 2, 3, 1, 0, 5)

class TestCase5(TestTransposeOp):
    def initTestCase(self):
        self.shape = (2, 16, 96)
        self.axis = (0, 2, 1)

class TestCase6(TestTransposeOp):
    def initTestCase(self):
        self.shape = (2, 10, 12, 16)
        self.axis = (3, 1, 2, 0)

class TestCase7(TestTransposeOp):
    def initTestCase(self):
        self.shape = (2, 10, 2, 16)
        self.axis = (0, 1, 3, 2)

class TestCase8(TestTransposeOp):
    def initTestCase(self):
        self.shape = (2, 3, 2, 3, 2, 4, 3, 3)
        self.axis = (0, 1, 3, 2, 4, 5, 6, 7)

class TestCase9(TestTransposeOp):
    def initTestCase(self):
        self.shape = (2, 3, 2, 3, 2, 4, 3, 3)
        self.axis = (6, 1, 3, 5, 0, 2, 4, 7)
class TestTransposeOpBool(TestTransposeOp):
    def test_check_grad(self):
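        # The grad check is skipped for bool inputs, since no gradient is
        # defined for bool data.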
        pass

class TestTransposeOpBool1D(TestTransposeOpBool):
    def initTestCase(self):
        self.shape = (100, )
        self.axis = (0, )
        self.inputs = {'X': np.random.random(self.shape).astype("bool")}
        self.outputs = {'Out': self.inputs['X'].transpose(self.axis)}

class TestTransposeOpBool2D(TestTransposeOpBool):
    def initTestCase(self):
        self.shape = (3, 40)
        self.axis = (1, 0)
        self.inputs = {'X': np.random.random(self.shape).astype("bool")}
        self.outputs = {'Out': self.inputs['X'].transpose(self.axis)}

class TestTransposeOpBool3D(TestTransposeOpBool):
    def initTestCase(self):
        self.shape = (3, 4, 10)
        self.axis = (0, 2, 1)
        self.inputs = {'X': np.random.random(self.shape).astype("bool")}
        self.outputs = {'Out': self.inputs['X'].transpose(self.axis)}

class TestTransposeOpBool4D(TestTransposeOpBool):
    def initTestCase(self):
        self.shape = (2, 3, 4, 5)
        self.axis = (0, 2, 3, 1)
        self.inputs = {'X': np.random.random(self.shape).astype("bool")}
        self.outputs = {'Out': self.inputs['X'].transpose(self.axis)}

class TestTransposeOpBool5D(TestTransposeOpBool):
    def initTestCase(self):
        self.shape = (2, 3, 4, 5, 6)
        self.axis = (4, 2, 3, 1, 0)
        self.inputs = {'X': np.random.random(self.shape).astype("bool")}
        self.outputs = {'Out': self.inputs['X'].transpose(self.axis)}

class TestTransposeOpBool6D(TestTransposeOpBool):
    def initTestCase(self):
        self.shape = (2, 3, 4, 5, 6, 1)
        self.axis = (4, 2, 3, 1, 0, 5)
        self.inputs = {'X': np.random.random(self.shape).astype("bool")}
        self.outputs = {'Out': self.inputs['X'].transpose(self.axis)}

class TestTransposeOpBool7D(TestTransposeOpBool):
    def initTestCase(self):
        self.shape = (2, 3, 2, 3, 2, 4, 3)
        self.axis = (0, 1, 3, 2, 4, 5, 6)
        self.inputs = {'X': np.random.random(self.shape).astype("bool")}
        self.outputs = {'Out': self.inputs['X'].transpose(self.axis)}

class TestTransposeOpBool8D(TestTransposeOpBool):
    def initTestCase(self):
        self.shape = (2, 3, 2, 3, 2, 4, 3, 3)
        self.axis = (6, 1, 3, 5, 0, 2, 4, 7)
        self.inputs = {'X': np.random.random(self.shape).astype("bool")}
        self.outputs = {'Out': self.inputs['X'].transpose(self.axis)}
class TestTransposeOpError(unittest.TestCase):
    def test_errors(self):
        paddle.enable_static()
        with program_guard(Program(), Program()):
            x = fluid.layers.data(name='x', shape=[10, 5, 3], dtype='float32')

            def test_x_Variable_check():
                # the Input(x)'s type must be Variable
                fluid.layers.transpose("not_variable", perm=[1, 0, 2])

            self.assertRaises(TypeError, test_x_Variable_check)

            def test_perm_list_check():
                # Input(perm)'s type must be list
                fluid.layers.transpose(x, perm="[1, 0, 2]")

            self.assertRaises(TypeError, test_perm_list_check)

            def test_perm_length_and_x_dim_check():
                # Input(perm) is a permutation of the dimensions of Input(input),
                # so its length must equal the number of dimensions of Input(input)
                fluid.layers.transpose(x, perm=[1, 0, 2, 3, 4])

            self.assertRaises(ValueError, test_perm_length_and_x_dim_check)

            def test_each_elem_value_check():
                # Each element in Input(perm) should be less than Input(x)'s dimension
                fluid.layers.transpose(x, perm=[3, 5, 7])

            self.assertRaises(ValueError, test_each_elem_value_check)
class TestTransposeApi(unittest.TestCase):
    def test_static_out(self):
        paddle.enable_static()
        with paddle.static.program_guard(paddle.static.Program()):
            x = paddle.static.data(name='x', shape=[2, 3, 4], dtype='float32')
            x_trans1 = paddle.transpose(x, perm=[1, 0, 2])
            x_trans2 = paddle.transpose(x, perm=(2, 1, 0))
            place = paddle.MLUPlace(0)
            exe = paddle.static.Executor(place)
            x_np = np.random.random([2, 3, 4]).astype("float32")
            result1, result2 = exe.run(feed={"x": x_np},
                                       fetch_list=[x_trans1, x_trans2])
            expected_result1 = np.transpose(x_np, [1, 0, 2])
            expected_result2 = np.transpose(x_np, (2, 1, 0))
            np.testing.assert_array_equal(result1, expected_result1)
            np.testing.assert_array_equal(result2, expected_result2)

    def test_dygraph_out(self):
        # This is an old test before 2.0 API so we need to disable static
        # to trigger dygraph
        paddle.disable_static()
        x = paddle.randn([2, 3, 4])
        x_trans1 = paddle.transpose(x, perm=[1, 0, 2])
        x_trans2 = paddle.transpose(x, perm=(2, 1, 0))
        x_np = x.numpy()
        expected_result1 = np.transpose(x_np, [1, 0, 2])
        expected_result2 = np.transpose(x_np, (2, 1, 0))
        np.testing.assert_array_equal(x_trans1.numpy(), expected_result1)
        np.testing.assert_array_equal(x_trans2.numpy(), expected_result2)
        # This is an old test before 2.0 API so we enable static again after
        # dygraph test
        paddle.enable_static()
class TestTAPI(unittest.TestCase):
    def test_out(self):
        with fluid.program_guard(fluid.Program()):
            data = fluid.data(shape=[10], dtype="float32", name="data")
            data_t = paddle.t(data)
            place = fluid.MLUPlace(0)
            exe = fluid.Executor(place)
            data_np = np.random.random([10]).astype("float32")
            result, = exe.run(feed={"data": data_np}, fetch_list=[data_t])
            expected_result = np.transpose(data_np)
            self.assertEqual((result == expected_result).all(), True)

        with fluid.program_guard(fluid.Program()):
            data = fluid.data(shape=[10, 5], dtype="float32", name="data")
            data_t = paddle.t(data)
            place = fluid.MLUPlace(0)
            exe = fluid.Executor(place)
            data_np = np.random.random([10, 5]).astype("float32")
            result, = exe.run(feed={"data": data_np}, fetch_list=[data_t])
            expected_result = np.transpose(data_np)
            self.assertEqual((result == expected_result).all(), True)

        with fluid.program_guard(fluid.Program()):
            data = fluid.data(shape=[1, 5], dtype="float32", name="data")
            data_t = paddle.t(data)
            place = fluid.MLUPlace(0)
            exe = fluid.Executor(place)
            data_np = np.random.random([1, 5]).astype("float32")
            result, = exe.run(feed={"data": data_np}, fetch_list=[data_t])
            expected_result = np.transpose(data_np)
            self.assertEqual((result == expected_result).all(), True)

        with fluid.dygraph.guard():
            np_x = np.random.random([10]).astype("float32")
            data = fluid.dygraph.to_variable(np_x)
            z = paddle.t(data)
            np_z = z.numpy()
            z_expected = np.array(np.transpose(np_x))
            self.assertEqual((np_z == z_expected).all(), True)

        with fluid.dygraph.guard():
            np_x = np.random.random([10, 5]).astype("float32")
            data = fluid.dygraph.to_variable(np_x)
            z = paddle.t(data)
            np_z = z.numpy()
            z_expected = np.array(np.transpose(np_x))
            self.assertEqual((np_z == z_expected).all(), True)

        with fluid.dygraph.guard():
            np_x = np.random.random([1, 5]).astype("float32")
            data = fluid.dygraph.to_variable(np_x)
            z = paddle.t(data)
            np_z = z.numpy()
            z_expected = np.array(np.transpose(np_x))
            self.assertEqual((np_z == z_expected).all(), True)

    def test_errors(self):
        with fluid.program_guard(fluid.Program()):
            x = fluid.data(name='x', shape=[10, 5, 3], dtype='float32')

            def test_x_dimension_check():
                paddle.t(x)

            self.assertRaises(ValueError, test_x_dimension_check)
class TestMoveAxis(unittest.TestCase):
    def test_moveaxis1(self):
        x_np = np.random.randn(2, 3, 4, 5, 7).astype('float32')
        expected = np.moveaxis(x_np, [0, 4, 3, 2], [1, 3, 2, 0])
        paddle.enable_static()
        with paddle.static.program_guard(fluid.Program()):
            x = paddle.static.data("x", shape=[2, 3, 4, 5, 7], dtype='float32')
            out = paddle.moveaxis(x, [0, 4, 3, 2], [1, 3, 2, 0])
            exe = paddle.static.Executor()
            out_np = exe.run(feed={"x": x_np}, fetch_list=[out])[0]
            self.assertEqual(np.array_equal(out_np, expected), True)

        paddle.disable_static()
        x = paddle.to_tensor(x_np)
        out = paddle.moveaxis(x, [0, 4, 3, 2], [1, 3, 2, 0])
        self.assertEqual(out.shape, [4, 2, 5, 7, 3])
        self.assertEqual(np.array_equal(out.numpy(), expected), True)
        paddle.enable_static()

    def test_moveaxis2(self):
        x_np = np.random.randn(2, 3, 5).astype('float32')
        expected = np.moveaxis(x_np, -2, -1)
        paddle.enable_static()
        with paddle.static.program_guard(fluid.Program()):
            x = paddle.static.data("x", shape=[2, 3, 5], dtype='float32')
            out = x.moveaxis(-2, -1)
            exe = paddle.static.Executor()
            out_np = exe.run(feed={"x": x_np}, fetch_list=[out])[0]
            self.assertEqual(np.array_equal(out_np, expected), True)

        paddle.disable_static()
        x = paddle.to_tensor(x_np)
        out = x.moveaxis(-2, -1)
        self.assertEqual(out.shape, [2, 5, 3])
        self.assertEqual(np.array_equal(out.numpy(), expected), True)
        paddle.enable_static()

    def test_error(self):
        x = paddle.randn([2, 3, 4, 5])
        # src must have the same number of elements as dst
        with self.assertRaises(AssertionError):
            paddle.moveaxis(x, [1, 0], [2])
        # each element of src must be unique
        with self.assertRaises(ValueError):
            paddle.moveaxis(x, [1, 1], [0, 2])
        # each element of dst must be unique
        with self.assertRaises(ValueError):
            paddle.moveaxis(x, [0, 1], [2, 2])
        # each element of src must be an integer
        with self.assertRaises(AssertionError):
            paddle.moveaxis(x, [0.5], [1])
        # each element of dst must be an integer
        with self.assertRaises(AssertionError):
            paddle.moveaxis(x, [0], [1.5])
        # each element of src must be in the range [-4, 3)
        with self.assertRaises(AssertionError):
            paddle.moveaxis(x, [-10, 1], [2, 3])
        # each element of dst must be in the range [-4, 3)
        with self.assertRaises(AssertionError):
            paddle.moveaxis(x, [2, 1], [10, 3])
if __name__ == '__main__':
    unittest.main()