未验证 提交 23ad2166 编写于 作者: F fwenguang 提交者: GitHub

[MLU] add gather mlu kernel (#41969)

上级 30d8d114
......@@ -17,13 +17,16 @@ INCLUDE_DIRECTORIES(${NEUWARE_INCLUDE_DIR})
set(CNNL_LIB ${NEUWARE_LIB_DIR}/libcnnl.so)
set(CNRT_LIB ${NEUWARE_LIB_DIR}/libcnrt.so)
set(CNDRV_LIB ${NEUWARE_LIB_DIR}/libcndrv.so)
set(CNPAPI_LIB ${NEUWARE_LIB_DIR}/libcnpapi.so)
generate_dummy_static_lib(LIB_NAME "neuware_lib" GENERATOR "neuware.cmake")
set(NEUWARE_LIB_DEPS ${CNNL_LIB} ${CNRT_LIB} ${CNDRV_LIB} ${CNPAPI_LIB})
if(WITH_CNCL)
MESSAGE(STATUS "Compile with CNCL!")
ADD_DEFINITIONS(-DPADDLE_WITH_CNCL)
set(CNCL_LIB ${NEUWARE_LIB_DIR}/libcncl.so)
TARGET_LINK_LIBRARIES(neuware_lib ${CNCL_LIB} ${CNNL_LIB} ${CNRT_LIB} ${CNDRV_LIB})
else()
TARGET_LINK_LIBRARIES(neuware_lib ${CNNL_LIB} ${CNRT_LIB} ${CNDRV_LIB})
list(APPEND NEUWARE_LIB_DEPS ${CNCL_LIB})
endif()
TARGET_LINK_LIBRARIES(neuware_lib ${NEUWARE_LIB_DEPS})
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/tensor_util.h"
#include "paddle/fluid/operators/mlu/mlu_baseop.h"
namespace paddle {
namespace operators {
template <typename T>
class GatherOpMLUKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext &ctx) const override {
auto *x = ctx.Input<Tensor>("X");
auto *index = ctx.Input<Tensor>("Index");
auto axis = ctx.Attr<int>("axis");
auto *out = ctx.Output<Tensor>("Out");
out->mutable_data<T>(ctx.GetPlace());
MLUCnnlTensorDesc x_desc(*x);
MLUCnnlTensorDesc index_desc(*index);
MLUCnnlTensorDesc out_desc(*out);
MLUCnnl::GatherFunctor(ctx, axis, 0 /*batch_dims*/, x_desc.get(),
GetBasePtr(x), index_desc.get(), GetBasePtr(index),
out_desc.get(), GetBasePtr(out));
}
};
template <typename T>
class GatherGradOpMLUKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext &ctx) const override {
auto *index = ctx.Input<Tensor>("Index");
auto *dout = ctx.Input<Tensor>(framework::GradVarName("Out"));
auto *dx = ctx.Output<Tensor>(framework::GradVarName("X"));
dx->mutable_data<T>(ctx.GetPlace());
MLUCnnlTensorDesc dx_desc(*dx);
auto value = static_cast<T>(0);
MLUCnnl::Fill(ctx, CNNL_POINTER_MODE_HOST, &value, dx_desc.get(),
GetBasePtr(dx));
MLUCnnlTensorDesc index_desc(*index);
MLUCnnlTensorDesc dout_desc(*dout);
const cnnlScatterRefMode_t mode = CNNL_SCATTERREF_UPDATE;
MLUCnnl::ScatterFunctor(ctx, dx_desc.get(), GetBasePtr(dx), dout_desc.get(),
GetBasePtr(dout), index_desc.get(),
GetBasePtr(index), mode);
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP_MLU_KERNEL(gather, ops::GatherOpMLUKernel<float>,
ops::GatherOpMLUKernel<paddle::platform::float16>,
ops::GatherOpMLUKernel<int>);
REGISTER_OP_MLU_KERNEL(gather_grad, ops::GatherGradOpMLUKernel<float>,
ops::GatherGradOpMLUKernel<paddle::platform::float16>,
ops::GatherGradOpMLUKernel<int>);
......@@ -934,9 +934,8 @@ MLUCnnlTrigonDesc::~MLUCnnlTrigonDesc() {
beta_ptr = static_cast<const void*>(&beta_int);
}
PADDLE_ENFORCE_MLU_SUCCESS(cnnlGetOpTensorWorkspaceSize_v2(
handle, op_tensor_desc, alpha1_ptr, a_desc, a, alpha2_ptr, b_desc, b,
beta_ptr, output_desc, output, &workspace_size));
PADDLE_ENFORCE_MLU_SUCCESS(cnnlGetOpTensorWorkspaceSize(
handle, a_desc, b_desc, output_desc, &workspace_size));
auto& dev_ctx = GetDevCtxFromCTX(ctx);
Tensor workspace = ctx.AllocateTmpTensor<int8_t, MLUDeviceContext>(
......
......@@ -118,11 +118,11 @@ class MLUMergedMomentumOpKernel : public framework::OpKernel<T> {
GetBasePtr(&mu_tensor));
for (size_t idx = 0; idx < n; ++idx) {
RegularizationType regularization_flag =
phi::RegularizationType regularization_flag =
regularization_methods.size() > 0 &&
regularization_methods[idx] == "l2_decay"
? RegularizationType::kL2DECAY
: RegularizationType::kNONE;
? phi::RegularizationType::kL2DECAY
: phi::RegularizationType::kNONE;
T regularization_coeff = static_cast<T>(0.0);
if (regularization_coeffs.size() != 0) {
regularization_coeff = static_cast<T>(regularization_coeffs[idx]);
......@@ -135,7 +135,7 @@ class MLUMergedMomentumOpKernel : public framework::OpKernel<T> {
auto grad = grads[idx];
Tensor regularized_grad;
MLUCnnlTensorDesc param_desc(*param_out);
if (regularization_flag == RegularizationType::kL2DECAY) {
if (regularization_flag == phi::RegularizationType::kL2DECAY) {
regularized_grad = ctx.AllocateTmpTensor<T, MLUDeviceContext>(
param_out->dims(), dev_ctx);
MLUCnnlOpTensorDesc op_tensor_desc(
......
......@@ -14,6 +14,7 @@ limitations under the License. */
#include "paddle/fluid/operators/optimizers/momentum_op.h"
#include "paddle/fluid/operators/mlu/mlu_baseop.h"
#include "paddle/phi/kernels/impl/momentum_kernel_impl.h"
namespace paddle {
namespace operators {
......@@ -27,10 +28,10 @@ class MLUMomentumOpKernel : public framework::OpKernel<T> {
std::string regularization_method =
ctx.Attr<std::string>("regularization_method");
auto regularization_coeff = ctx.Attr<float>("regularization_coeff");
RegularizationType regularization_flag{
RegularizationType::kNONE}; // disable regularization
phi::RegularizationType regularization_flag{
phi::RegularizationType::kNONE}; // disable regularization
if (regularization_method == "l2_decay") {
regularization_flag = RegularizationType::kL2DECAY;
regularization_flag = phi::RegularizationType::kL2DECAY;
}
T mu = static_cast<T>(ctx.Attr<float>("mu"));
......@@ -57,7 +58,7 @@ class MLUMomentumOpKernel : public framework::OpKernel<T> {
Tensor regularized_grad;
MLUCnnlTensorDesc param_desc(*param);
if (regularization_flag == RegularizationType::kL2DECAY) {
if (regularization_flag == phi::RegularizationType::kL2DECAY) {
regularized_grad =
ctx.AllocateTmpTensor<T, MLUDeviceContext>(param->dims(), dev_ctx);
MLUCnnlOpTensorDesc op_tensor_desc(
......
......@@ -34,6 +34,10 @@ limitations under the License. */
#ifdef PADDLE_WITH_HIP
#include <hip/hip_runtime.h>
#endif
#ifdef PADDLE_WITH_MLU
#include "paddle/fluid/platform/device/mlu/enforce.h"
#include "paddle/fluid/platform/device/mlu/mlu_info.h"
#endif
#include "paddle/fluid/memory/memory.h"
#include "paddle/fluid/platform/device/gpu/gpu_info.h"
......@@ -135,6 +139,13 @@ void SynchronizeAllDevice() {
PADDLE_ENFORCE_GPU_SUCCESS(hipDeviceSynchronize());
}
#endif
#ifdef PADDLE_WITH_MLU
int count = GetMLUDeviceCount();
for (int i = 0; i < count; i++) {
SetMLUDeviceId(i);
PADDLE_ENFORCE_MLU_SUCCESS(cnrtSyncDevice());
}
#endif
}
static double ToMegaBytes(size_t bytes) {
......
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import unittest
import numpy as np
import sys
sys.path.append('..')
from op_test import OpTest, convert_float_to_uint16
import paddle
import paddle.fluid as fluid
from paddle.framework import core
from paddle.fluid.dygraph.base import switch_to_static_graph
paddle.enable_static()
def gather_numpy(x, index, axis):
x_transpose = np.swapaxes(x, 0, axis)
tmp_gather = x_transpose[index, ...]
gather = np.swapaxes(tmp_gather, 0, axis)
return gather
class TestGatherOp(OpTest):
def setUp(self):
self.op_type = "gather"
self.place = paddle.MLUPlace(0)
self.__class__.use_mlu = True
self.python_api = paddle.gather
self.config()
xnp = np.random.random(self.x_shape).astype(self.x_type)
self.inputs = {
'X': xnp,
'Index': np.array(self.index).astype(self.index_type)
}
self.outputs = {'Out': self.inputs["X"][self.inputs["Index"]]}
def test_check_output(self):
self.check_output_with_place(self.place)
def test_check_grad(self):
self.check_grad_with_place(self.place, ['X'], 'Out')
def config(self):
"""
For multi-dimension input
"""
self.x_shape = (10, 20)
self.x_type = "float32"
self.index = [1, 3, 5]
self.index_type = "int32"
class TestCase1(TestGatherOp):
def config(self):
"""
For one dimension input
"""
self.x_shape = (100)
self.x_type = "float32"
self.index = [1, 3, 5]
self.index_type = "int32"
class TestCase2(TestGatherOp):
def config(self):
"""
For int64_t index type
"""
self.x_shape = (100)
self.x_type = "float32"
self.index = [1, 3, 5]
self.index_type = "int64"
class API_TestDygraphGather(unittest.TestCase):
def test_out1(self):
paddle.disable_static()
input_1 = np.array([[1, 2], [3, 4], [5, 6]]).astype('int32')
index_1 = np.array([1, 2])
input = paddle.to_tensor(input_1)
index = paddle.to_tensor(index_1)
output = paddle.fluid.layers.gather(input, index)
output_np = output.numpy()
expected_output = np.array([[3, 4], [5, 6]]).astype('int32')
self.assertTrue(np.allclose(output_np, expected_output))
paddle.enable_static()
def test_out12(self):
paddle.disable_static()
input_1 = np.array([[1, 2], [3, 4], [5, 6]]).astype('int32')
index_1 = np.array([1, 2])
x = paddle.to_tensor(input_1)
index = paddle.to_tensor(index_1)
output = paddle.gather(x, index, axis=0)
output_np = output.numpy()
expected_output = gather_numpy(input_1, index_1, axis=0)
self.assertTrue(np.allclose(output_np, expected_output))
paddle.enable_static()
def test_zero_index(self):
paddle.disable_static()
x = paddle.to_tensor([[1, 2], [3, 4]]).astype('int32')
index = paddle.to_tensor(np.array([]).astype('int64'))
for axis in range(len(x.shape)):
out = paddle.gather(x, index, axis)
expected_shape = list(x.shape)
expected_shape[axis] = 0
self.assertEqual(list(out.shape), expected_shape)
paddle.enable_static()
class TestGathertError(unittest.TestCase):
def test_error1(self):
with paddle.static.program_guard(paddle.static.Program(),
paddle.static.Program()):
shape = [8, 9, 6]
x = paddle.fluid.data(shape=shape, dtype='int8', name='x')
axis = paddle.fluid.data(shape=[1], dtype='float32', name='axis')
index = paddle.fluid.data(shape=shape, dtype='int32', name='index')
index_float = paddle.fluid.data(
shape=shape, dtype='float32', name='index_float')
def test_x_type():
paddle.gather(x, index)
self.assertRaises(TypeError, test_x_type)
def test_index_type():
paddle.gather(x, index_float)
self.assertRaises(TypeError, test_index_type)
def test_axis_dtype():
paddle.gather(x, index, axis=1.11)
self.assertRaises(TypeError, test_axis_dtype)
def test_axis_dtype1():
paddle.gather(x, index, axis=axis)
self.assertRaises(TypeError, test_axis_dtype1)
def test_error2(self):
with fluid.program_guard(fluid.Program(), fluid.Program()):
shape = [8, 9, 6]
x = fluid.data(shape=shape, dtype='int8', name='x')
index = fluid.data(shape=shape, dtype='int32', name='mask')
index_float = fluid.data(
shape=shape, dtype='float32', name='index_float')
def test_x_type():
paddle.fluid.layers.gather(x, index)
self.assertRaises(TypeError, test_x_type)
def test_index_type():
paddle.fluid.layers.gather(x, index_float)
self.assertRaises(TypeError, test_index_type)
if __name__ == "__main__":
unittest.main()
......@@ -216,9 +216,9 @@ def to_tensor(data, dtype=None, place=None, stop_gradient=True):
place = _current_expected_place()
elif not isinstance(place, (core.Place, core.CPUPlace, core.CUDAPinnedPlace,
core.CUDAPlace, core.NPUPlace, core.XPUPlace,
core.CustomPlace)):
core.MLUPlace, core.CustomPlace)):
raise ValueError(
"'place' must be any of paddle.Place, paddle.CPUPlace, paddle.CUDAPinnedPlace, paddle.CUDAPlace, paddle.NPUPlace, paddle.XPUPlace, paddle.CustomPlace"
"'place' must be any of paddle.Place, paddle.CPUPlace, paddle.CUDAPinnedPlace, paddle.CUDAPlace, paddle.NPUPlace, paddle.XPUPlace, paddle.MLUPlace, paddle.CustomPlace"
)
if not isinstance(data, np.ndarray):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册