未验证 提交 c5fcc96d 编写于 作者: W wangchaochaohu 提交者: GitHub

xpu support for fill_constant Op (#27675)

上级 a8208716
......@@ -66,7 +66,9 @@ class FillConstantKernel : public framework::OpKernel<T> {
value_tensor->numel()));
const T *tensor_data = value_tensor->data<T>();
framework::Tensor cpu_tensor;
if (platform::is_gpu_place(value_tensor->place())) {
auto tmp_place = value_tensor->place();
if (platform::is_gpu_place(tmp_place) ||
platform::is_xpu_place(tmp_place)) {
TensorCopySync(*value_tensor, platform::CPUPlace(), &cpu_tensor);
tensor_data = cpu_tensor.data<T>();
}
......@@ -102,6 +104,14 @@ class FillConstantKernel : public framework::OpKernel<T> {
functor(reinterpret_cast<const platform::CUDADeviceContext &>(dev_ctx),
tensor, static_cast<T>(value));
}
#endif
#ifdef PADDLE_WITH_XPU
if (!cpu_place) {
tensor->mutable_data(ctx.GetPlace(), data_type);
math::SetConstant<platform::XPUDeviceContext, T> functor;
functor(reinterpret_cast<const platform::XPUDeviceContext &>(dev_ctx),
tensor, static_cast<T>(value));
}
#endif
}
};
......
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/fill_constant_op.h"
namespace ops = paddle::operators;
#ifdef PADDLE_WITH_XPU
REGISTER_OP_XPU_KERNEL(fill_constant, ops::FillConstantKernel<float>,
ops::FillConstantKernel<int64_t>,
ops::FillConstantKernel<double>,
ops::FillConstantKernel<bool>,
ops::FillConstantKernel<int>);
#endif
......@@ -22,6 +22,7 @@ limitations under the License. */
#include <cblas.h>
#endif
#include <memory>
#include <utility>
#include <vector>
#include "paddle/fluid/framework/data_type.h"
......@@ -44,6 +45,15 @@ template struct SetConstant<platform::CPUDeviceContext, int64_t>;
template struct SetConstant<platform::CPUDeviceContext, bool>;
template struct SetConstant<platform::CPUDeviceContext, uint8_t>;
#ifdef PADDLE_WITH_XPU
template struct SetConstant<platform::XPUDeviceContext, platform::float16>;
template struct SetConstant<platform::XPUDeviceContext, float>;
template struct SetConstant<platform::XPUDeviceContext, double>;
template struct SetConstant<platform::XPUDeviceContext, int>;
template struct SetConstant<platform::XPUDeviceContext, int64_t>;
template struct SetConstant<platform::XPUDeviceContext, bool>;
#endif
#define DEFINE_CPU_TRANS(RANK) \
template struct Transpose<platform::CPUDeviceContext, platform::float16, \
RANK>; \
......
......@@ -14,6 +14,7 @@ limitations under the License. */
#pragma once
#include <cmath>
#include <memory>
#include <vector>
#include "paddle/fluid/framework/eigen.h"
......@@ -84,6 +85,33 @@ struct RowwiseMean {
framework::Tensor* vec);
};
#ifdef PADDLE_WITH_XPU
template <typename U>
struct TensorSetConstantXPU {
TensorSetConstantXPU(framework::Tensor* tensor, U value)
: tensor_(tensor), value_(value) {}
template <typename T>
void apply() const {
int dev_id = -1;
xpu_current_device(&dev_id);
if (dev_id >= 64) {
// if dev_id >= 64, the device is a simulator device, -64 to get real
// dev_id
dev_id -= 64;
}
auto xpu = platform::XPUPlace(dev_id);
auto* begin = tensor_->mutable_data<T>(xpu);
int numel = tensor_->numel();
std::unique_ptr<T[]> data_cpu(new T[numel]);
std::fill(data_cpu.get(), data_cpu.get() + numel, static_cast<T>(value_));
memory::Copy(xpu, begin, platform::CPUPlace(),
static_cast<void*>(data_cpu.get()), numel * sizeof(T));
}
framework::Tensor* tensor_;
U value_;
};
#endif
} // namespace math
} // namespace operators
} // namespace paddle
......@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <memory>
#include <vector>
#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/operators/math/math_function.h"
......@@ -27,8 +28,18 @@ template <typename DeviceContext, typename T>
void SetConstant<DeviceContext, T>::operator()(const DeviceContext& context,
framework::Tensor* tensor,
T num) {
auto t = framework::EigenVector<T>::Flatten(*tensor);
t.device(*context.eigen_device()) = t.constant(static_cast<T>(num));
bool xpu_place = false;
#ifdef PADDLE_WITH_XPU
if (context.GetPlace() == platform::XPUPlace()) {
xpu_place = true;
framework::VisitDataType(tensor->type(),
TensorSetConstantXPU<T>(tensor, num));
}
#endif
if (!xpu_place) {
auto t = framework::EigenVector<T>::Flatten(*tensor);
t.device(*context.eigen_device()) = t.constant(static_cast<T>(num));
}
}
template <typename DeviceContext, typename T, int Rank>
......
......@@ -26,7 +26,7 @@ inline std::vector<T> GetDataFromTensor(const framework::Tensor* x) {
if (x->type() == framework::proto::VarType::INT32) {
auto* data = x->data<int>();
framework::Tensor cpu_attr_tensor;
if (platform::is_gpu_place(x->place())) {
if (!platform::is_cpu_place(x->place())) {
TensorCopySync(*x, platform::CPUPlace(), &cpu_attr_tensor);
data = cpu_attr_tensor.data<int>();
}
......@@ -34,7 +34,7 @@ inline std::vector<T> GetDataFromTensor(const framework::Tensor* x) {
} else if (x->type() == framework::proto::VarType::INT64) {
auto* data = x->data<int64_t>();
framework::Tensor cpu_attr_tensor;
if (platform::is_gpu_place(x->place())) {
if (!platform::is_cpu_place(x->place())) {
TensorCopySync(*x, platform::CPUPlace(), &cpu_attr_tensor);
data = cpu_attr_tensor.data<int64_t>();
}
......@@ -62,7 +62,7 @@ inline std::vector<T> GetDataFromTensorList(
tensor->dims()));
if (tensor->type() == framework::proto::VarType::INT32) {
if (platform::is_gpu_place(tensor->place())) {
if (!platform::is_cpu_place(tensor->place())) {
framework::Tensor temp;
TensorCopySync(*tensor, platform::CPUPlace(), &temp);
vec_new_data.push_back(static_cast<T>(*temp.data<int>()));
......@@ -70,7 +70,7 @@ inline std::vector<T> GetDataFromTensorList(
vec_new_data.push_back(static_cast<T>(*tensor->data<int>()));
}
} else if (tensor->type() == framework::proto::VarType::INT64) {
if (platform::is_gpu_place(tensor->place())) {
if (!platform::is_cpu_place(tensor->place())) {
framework::Tensor temp;
TensorCopySync(*tensor, platform::CPUPlace(), &temp);
// NOTE: Converting int64 to int32 may cause data overflow.
......
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import sys
sys.path.append("..")
import unittest
from op_test import OpTest
import paddle
import numpy as np
# Situation 1: Attr(shape) is a list(without tensor)
@unittest.skipIf(not paddle.is_compiled_with_xpu(),
"core is not compiled with XPU")
class TestFillConstantOp1(OpTest):
def setUp(self):
'''Test fill_constant op with specified value'''
self.op_type = "fill_constant"
self.inputs = {}
self.attrs = {'shape': [123, 92], 'dtype': 5, 'value': 3.8}
self.outputs = {'Out': np.full((123, 92), 3.8)}
def test_check_output(self):
place = paddle.XPUPlace(0)
self.check_output_with_place(place)
@unittest.skipIf(not paddle.is_compiled_with_xpu(),
"core is not compiled with XPU")
class TestFillConstantOp2(OpTest):
def setUp(self):
'''Test fill_constant op with default value'''
self.op_type = "fill_constant"
self.inputs = {}
self.attrs = {'shape': [123, 92], 'dtype': 5}
self.outputs = {'Out': np.full((123, 92), 0.0)}
def test_check_output(self):
place = paddle.XPUPlace(0)
self.check_output_with_place(place)
@unittest.skipIf(not paddle.is_compiled_with_xpu(),
"core is not compiled with XPU")
class TestFillConstantOp3(OpTest):
def setUp(self):
'''Test fill_constant op with specified int64 value'''
self.op_type = "fill_constant"
self.inputs = {}
self.attrs = {'shape': [123, 92], 'dtype': 3, 'value': 10000000000}
self.outputs = {'Out': np.full((123, 92), 10000000000)}
def test_check_output(self):
place = paddle.XPUPlace(0)
self.check_output_with_place(place)
@unittest.skipIf(not paddle.is_compiled_with_xpu(),
"core is not compiled with XPU")
class TestFillConstantOp4(OpTest):
def setUp(self):
'''Test fill_constant op with specified int value'''
self.op_type = "fill_constant"
self.inputs = {}
self.attrs = {'shape': [123, 92], 'dtype': 2, 'value': 3}
self.outputs = {'Out': np.full((123, 92), 3)}
def test_check_output(self):
place = paddle.XPUPlace(0)
self.check_output_with_place(place)
# Situation 2: Attr(shape) is a list(with tensor)
@unittest.skipIf(not paddle.is_compiled_with_xpu(),
"core is not compiled with XPU")
class TestFillConstantOp1_ShapeTensorList(OpTest):
def setUp(self):
'''Test fill_constant op with specified value'''
self.op_type = "fill_constant"
self.init_data()
shape_tensor_list = []
for index, ele in enumerate(self.shape):
shape_tensor_list.append(("x" + str(index), np.ones(
(1)).astype('int32') * ele))
self.inputs = {"ShapeTensorList": shape_tensor_list}
self.attrs = {
'shape': self.infer_shape,
'dtype': 5,
'value': self.value
}
self.outputs = {'Out': np.full(self.shape, self.value)}
def init_data(self):
self.shape = [123, 92]
self.infer_shape = [-1, 92]
self.value = 3.8
def test_check_output(self):
place = paddle.XPUPlace(0)
self.check_output_with_place(place)
@unittest.skipIf(not paddle.is_compiled_with_xpu(),
"core is not compiled with XPU")
class TestFillConstantOp2_ShapeTensorList(OpTest):
def setUp(self):
'''Test fill_constant op with default value'''
self.op_type = "fill_constant"
self.init_data()
shape_tensor_list = []
for index, ele in enumerate(self.shape):
shape_tensor_list.append(("x" + str(index), np.ones(
(1)).astype('int32') * ele))
self.inputs = {"ShapeTensorList": shape_tensor_list}
self.attrs = {'shape': self.infer_shape, 'dtype': 5}
self.outputs = {'Out': np.full(self.shape, 0.0)}
def init_data(self):
self.shape = [123, 92]
self.infer_shape = [-1, -1]
def test_check_output(self):
place = paddle.XPUPlace(0)
self.check_output_with_place(place)
@unittest.skipIf(not paddle.is_compiled_with_xpu(),
"core is not compiled with XPU")
class TestFillConstantOp3_ShapeTensorList(TestFillConstantOp1_ShapeTensorList):
def init_data(self):
self.shape = [123, 92]
self.infer_shape = [123, -1]
self.value = 10000000000
@unittest.skipIf(not paddle.is_compiled_with_xpu(),
"core is not compiled with XPU")
class TestFillConstantOp4_ShapeTensorList(TestFillConstantOp1_ShapeTensorList):
def init_data(self):
self.shape = [123, 92]
self.infer_shape = [123, -1]
self.value = 3
# Situation 3: shape is a tensor
@unittest.skipIf(not paddle.is_compiled_with_xpu(),
"core is not compiled with XPU")
class TestFillConstantOp1_ShapeTensor(OpTest):
def setUp(self):
'''Test fill_constant op with specified value'''
self.op_type = "fill_constant"
self.init_data()
self.inputs = {"ShapeTensor": np.array(self.shape).astype("int32")}
self.attrs = {'value': self.value, 'dtype': 5}
self.outputs = {'Out': np.full(self.shape, self.value)}
def init_data(self):
self.shape = [123, 92]
self.value = 3.8
def test_check_output(self):
place = paddle.XPUPlace(0)
self.check_output_with_place(place)
# Situation 4: value is a tensor
@unittest.skipIf(not paddle.is_compiled_with_xpu(),
"core is not compiled with XPU")
class TestFillConstantOp1_ValueTensor(OpTest):
def setUp(self):
'''Test fill_constant op with specified value'''
self.op_type = "fill_constant"
self.init_data()
self.inputs = {
"ShapeTensor": np.array(self.shape).astype("int32"),
'ValueTensor': np.array([self.value]).astype("float32")
}
self.attrs = {'value': self.value + 1.0, 'dtype': 5}
self.outputs = {'Out': np.full(self.shape, self.value)}
def init_data(self):
self.shape = [123, 92]
self.value = 3.8
self.dtype = np.float32
def test_check_output(self):
place = paddle.XPUPlace(0)
self.check_output_with_place(place)
# Situation 5: value is a tensor
@unittest.skipIf(not paddle.is_compiled_with_xpu(),
"core is not compiled with XPU")
class TestFillConstantOp2_ValueTensor(OpTest):
def setUp(self):
'''Test fill_constant op with specified value'''
self.op_type = "fill_constant"
self.init_data()
self.inputs = {
"ShapeTensor": np.array(self.shape).astype("int32"),
'ValueTensor': np.array([self.value]).astype("int32")
}
self.attrs = {'value': self.value, 'dtype': 2}
self.outputs = {'Out': np.full(self.shape, self.value)}
def init_data(self):
self.shape = [123, 92]
self.value = 3
self.dtype = np.int32
def test_check_output(self):
place = paddle.XPUPlace(0)
self.check_output_with_place(place)
if __name__ == "__main__":
paddle.enable_static()
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册