Unverified · Commit 9e776f62, authored by Chen Weihang, committed by GitHub

[Cherry-pick] Fix incompatible error for place type (#43830)

* Create Tensor by paddle::empty  in custom operator (#41840)

* create tensor by empty in custom op

* fix some bug

* update relu custom op demo (#43173)

* Fix incompatible error for custom op Placetype (#43749)

* fix incompatible error

* remove default constructor

* add macro

* fix cpu make error

* add DefaultGPUPlace api
Co-authored-by: zyfncg <zhangyunfei07@baidu.com>
Parent 51240331
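Background for the hunks that follow (a sketch, not part of the diff): the custom-operator demos move from the deprecated `paddle::Tensor(paddle::PlaceType::k*, shape)` constructor plus `mutable_data<T>(place)` to allocating outputs with `paddle::empty` / `paddle::empty_like` and accessing buffers through `data<T>()`. A minimal CPU example of the new pattern, assuming only the extension API that appears in the diff (`PD_CHECK`, `PD_DISPATCH_FLOATING_TYPES`, `empty_like`, `data<T>()`, `numel()`); the op name and helper below are hypothetical:

```cpp
#include "paddle/extension.h"

// Hypothetical identity op, used only to illustrate the allocation pattern.
std::vector<paddle::Tensor> IdentityCpuForward(const paddle::Tensor& x) {
  PD_CHECK(x.is_cpu(), "x must be a CPU Tensor.");
  // Allocate the output with the same shape/dtype/place as the input,
  // instead of the deprecated PlaceType-based Tensor constructor.
  auto out = paddle::empty_like(x);
  PD_DISPATCH_FLOATING_TYPES(x.type(), "identity_cpu_forward", ([&] {
                               const data_t* src = x.data<data_t>();
                               data_t* dst = out.data<data_t>();
                               for (int64_t i = 0; i < x.numel(); ++i) {
                                 dst[i] = src[i];
                               }
                             }));
  return {out};
}
```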
@@ -45,7 +45,7 @@ yaml_types_mapping = {
     'int' : 'int', 'int32_t' : 'int32_t', 'int64_t' : 'int64_t', 'size_t' : 'size_t', \
     'float' : 'float', 'double' : 'double', 'bool' : 'bool', \
     'str' : 'std::string', \
-    'Place' : 'paddle::experimental::Place', 'DataLayout' : 'paddle::experimental::DataLayout', 'DataType' : 'paddle::experimental::DataType', \
+    'Place' : 'paddle::Place', 'DataLayout' : 'paddle::experimental::DataLayout', 'DataType' : 'paddle::experimental::DataType', \
     'int64_t[]' : 'std::vector<int64_t>', 'int[]' : 'std::vector<int>',
     'Tensor' : 'Tensor',
     'Tensor[]' : 'std::vector<Tensor>',
......
@@ -46,7 +46,7 @@ atype_to_parsing_function = {
     "std::vector<std::string>": "CastPyArg2Strings",
     "paddle::experimental::Scalar": "CastPyArg2Scalar",
     "paddle::experimental::IntArray": "CastPyArg2IntArray",
-    "paddle::experimental::Place": "CastPyArg2Place",
+    "paddle::Place": "CastPyArg2Place",
     "paddle::experimental::DataType": "CastPyArg2DataType",
 }
......
@@ -1194,15 +1194,13 @@ std::vector<paddle::framework::Scope*> GetScopePtrListFromArgs(
   return result;
 }
 
-paddle::experimental::Place CastPyArg2Place(PyObject* obj,
-                                            const std::string& op_type,
-                                            ssize_t arg_pos) {
+paddle::Place CastPyArg2Place(PyObject* obj, const std::string& op_type,
+                              ssize_t arg_pos) {
   return CastPyArg2Place(obj, arg_pos);
 }
 
-paddle::experimental::DataType CastPyArg2DataType(PyObject* obj,
-                                                  const std::string& op_type,
-                                                  ssize_t arg_pos) {
+paddle::DataType CastPyArg2DataType(PyObject* obj, const std::string& op_type,
+                                    ssize_t arg_pos) {
   if (obj == Py_None) {
     return paddle::experimental::DataType::UNDEFINED;
   }
......
@@ -171,13 +171,11 @@ paddle::experimental::IntArray CastPyArg2IntArray(PyObject* obj,
                                                   const std::string& op_type,
                                                   ssize_t arg_pos);
 
-paddle::experimental::Place CastPyArg2Place(PyObject* obj,
-                                            const std::string& op_type,
-                                            ssize_t arg_pos);
+paddle::Place CastPyArg2Place(PyObject* obj, const std::string& op_type,
+                              ssize_t arg_pos);
 
-paddle::experimental::DataType CastPyArg2DataType(PyObject* obj,
-                                                  const std::string& op_type,
-                                                  ssize_t arg_pos);
+paddle::DataType CastPyArg2DataType(PyObject* obj, const std::string& op_type,
+                                    ssize_t arg_pos);
 
 paddle::optional<const paddle::experimental::Tensor&> GetOptionalTensorFromArgs(
     const std::string& op_type, const std::string& arg_name, PyObject* args,
......
@@ -37,24 +37,6 @@ limitations under the License. */
 namespace paddle {
 namespace experimental {
 
-namespace detail {
-static Place GetCorrectPlaceByPlaceType(const Place &place_type) {
-  auto alloc_type = place_type.GetType();
-  switch (alloc_type) {
-    case AllocationType::CPU:
-      return place_type;
-#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
-    case AllocationType::GPU:
-      return phi::Place(AllocationType::GPU,
-                        phi::backends::gpu::GetCurrentDeviceId());
-#endif
-    default:
-      PADDLE_THROW(phi::errors::Unavailable(
-          "The PlaceType is a legacy design, only supports CPU and GPU, "
-          "and will not support other place types in the future."));
-  }
-}
-}  // namespace detail
 
 /////// Tensor Methods ////////
@@ -76,7 +58,7 @@ Tensor::Tensor(const Place &place) {
          "Reason: A legal tensor cannot be constructed only based on "
          "the `place`, and datatype, shape, layout, etc. is also "
          "required.";
-  DefaultAllocator alloc(detail::GetCorrectPlaceByPlaceType(place));
+  DefaultAllocator alloc(place);
   impl_ = std::move(std::make_shared<phi::DenseTensor>(
       &alloc,
       std::move(phi::DenseTensorMeta(
@@ -92,7 +74,7 @@ Tensor::Tensor(const Place &place, const std::vector<int64_t> &shape) {
          "Reason: A legal tensor cannot be constructed only based on "
          "the `place` and `shape`, and datatype, layout, etc. is also "
          "required.";
-  DefaultAllocator alloc(detail::GetCorrectPlaceByPlaceType(place));
+  DefaultAllocator alloc(place);
   impl_ = std::move(std::make_shared<phi::DenseTensor>(
       &alloc,
       std::move(phi::DenseTensorMeta(phi::DataType::FLOAT32,
......
-cc_library(phi_place SRCS place.cc)
-cc_library(scalar SRCS scalar.cc DEPS phi_enforce tensor)
+if(WITH_GPU)
+  nv_library(
+    phi_place
+    SRCS place.cc
+    DEPS phi_gpu_info)
+elseif(WITH_ROCM)
+  hip_library(
+    phi_place
+    SRCS place.cc
+    DEPS phi_gpu_info)
+else()
+  cc_library(phi_place SRCS place.cc)
+endif()
+
+cc_library(
+  scalar
+  SRCS scalar.cc
+  DEPS phi_enforce tensor)
@@ -21,6 +21,7 @@ limitations under the License. */
 #include "glog/logging.h"
 #include "paddle/phi/api/ext/exception.h"
+#include "paddle/phi/backends/gpu/gpu_info.h"
 
 namespace phi {
 
@@ -110,14 +111,32 @@ uint32_t Place::Hash::operator()(const Place &place) const {
   return hash_value;
 }
 
+namespace detail {
+static int8_t GetCorrectDeviceIdByPlaceType(
+    const paddle::PlaceType &place_type) {
+  switch (place_type) {
+    case paddle::PlaceType::kCPU:
+      return 0;
+#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
+    case paddle::PlaceType::kGPU:
+      return phi::backends::gpu::GetCurrentDeviceId();
+#endif
+    default:
+      PD_THROW(
+          "The PlaceType is a legacy design, only supports CPU and GPU, "
+          "and will not support other place types in the future.");
+  }
+}
+}  // namespace detail
+
 Place::Place(paddle::PlaceType type)
-    : device(0),
+    : device(detail::GetCorrectDeviceIdByPlaceType(type)),
       alloc_type_(static_cast<AllocationType>(type)),
       device_type_id_(GetOrRegisterGlobalDeviceTypeId("")) {
   LOG_FIRST_N(WARNING, 1)
       << "The `paddle::PlaceType::kCPU/kGPU` is deprecated since version "
          "2.3, and will be removed in version 2.4! Please use "
-         "`paddle::CPUPlace()/GPUPlace()` to represent the place type.";
+         "`paddle::CPUPlace()/DefaultGPUPlace()` to represent the place type.";
 }
 
 }  // namespace phi
 
@@ -140,4 +159,13 @@ bool operator==(PlaceType place_type, const Place &place) {
   return static_cast<AllocationType>(place_type) == place.GetType();
 }
 
+GPUPlace DefaultGPUPlace() {
+  return GPUPlace(
+#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
+      phi::backends::gpu::GetCurrentDeviceId());
+#else
+      0);
+#endif
+}
+
 }  // namespace paddle
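Usage note (a sketch, not taken from this diff): with `DefaultGPUPlace()` exported above, code that previously relied on `paddle::PlaceType::kGPU` can name the currently active GPU explicitly; the shape and dtype below are illustrative only:

```cpp
// Assumes a CUDA/HIP build; DefaultGPUPlace() carries the current device id.
auto gpu_out = paddle::empty({3, 4}, paddle::DataType::FLOAT32,
                             paddle::DefaultGPUPlace());
PD_CHECK(gpu_out.place() == paddle::DefaultGPUPlace());
```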
@@ -213,9 +213,6 @@ std::ostream& operator<<(std::ostream&, const Place&);
 namespace paddle {
 namespace experimental {
 using AllocationType = phi::AllocationType;
-using Place = phi::Place;
-using CPUPlace = phi::CPUPlace;
-using GPUPlace = phi::GPUPlace;
 using GPUPinnedPlace = phi::GPUPinnedPlace;
 using XPUPlace = phi::XPUPlace;
 using NPUPlace = phi::NPUPlace;
@@ -259,4 +256,6 @@ enum class PlaceType {
 PADDLE_API bool operator==(const Place& place, PlaceType place_type);
 PADDLE_API bool operator==(PlaceType place_type, const Place& place);
 
+PADDLE_API GPUPlace DefaultGPUPlace();
+
 }  // namespace paddle
@@ -37,13 +37,11 @@ namespace tests {
 // TODO(chenweihang): Remove this test after the API is used in the dygraph
 TEST(API, data_transform_same_place) {
   // 1. create tensor
-  auto x = paddle::experimental::full({3, 3},
-                                      1.0,
-                                      experimental::DataType::COMPLEX128,
-                                      experimental::CPUPlace());
+  auto x =
+      paddle::experimental::full({3, 3}, 1.0, DataType::COMPLEX128, CPUPlace());
 
-  auto y = paddle::experimental::full(
-      {3, 3}, 2.0, experimental::DataType::FLOAT32, experimental::CPUPlace());
+  auto y =
+      paddle::experimental::full({3, 3}, 2.0, DataType::FLOAT32, CPUPlace());
 
   std::vector<phi::dtype::complex<double>> sum(9, 6.0);
 
@@ -75,10 +73,10 @@ TEST(API, data_transform_same_place) {
 TEST(Tensor, data_transform_diff_place) {
   // 1. create tensor
   auto x = paddle::experimental::full(
-      {3, 3}, 1.0, experimental::DataType::FLOAT64, experimental::CPUPlace());
+      {3, 3}, 1.0, experimental::DataType::FLOAT64, CPUPlace());
 
   auto y = paddle::experimental::full(
-      {3, 3}, 2.0, experimental::DataType::FLOAT64, experimental::GPUPlace());
+      {3, 3}, 2.0, experimental::DataType::FLOAT64, GPUPlace());
 
   std::vector<float> sum(9, 6.0);
 
@@ -93,10 +91,9 @@ TEST(Tensor, data_transform_diff_place) {
   ASSERT_EQ(out.dtype(), phi::DataType::FLOAT64);
   ASSERT_EQ(out.layout(), phi::DataLayout::NCHW);
   ASSERT_EQ(out.initialized(), true);
-  ASSERT_EQ(out.impl()->place(),
-            phi::TransToPhiPlace(experimental::Backend::GPU));
+  ASSERT_EQ(out.impl()->place(), phi::TransToPhiPlace(phi::Backend::GPU));
 
-  auto ref_out = experimental::copy_to(out, experimental::CPUPlace(), true);
+  auto ref_out = experimental::copy_to(out, CPUPlace(), true);
 
   auto dense_out = std::dynamic_pointer_cast<phi::DenseTensor>(ref_out.impl());
   for (size_t i = 0; i < 9; i++) {
......
@@ -30,7 +30,7 @@ namespace tests {
 
 TEST(API, scale) {
   auto x = experimental::full(
-      {3, 4}, 1.0, experimental::DataType::FLOAT32, experimental::CPUPlace());
+      {3, 4}, 1.0, experimental::DataType::FLOAT32, CPUPlace());
 
   const size_t cycles = 300;
   phi::tests::Timer timer;
......
@@ -22,8 +22,7 @@
 
 std::vector<paddle::Tensor> ContextPoolTest(const paddle::Tensor& x) {
   // 1. test cpu context
-  paddle::experimental::Place cpu_place(
-      paddle::experimental::AllocationType::CPU);
+  paddle::Place cpu_place(paddle::experimental::AllocationType::CPU);
   auto* cpu_ctx =
       paddle::experimental::DeviceContextPool::Instance()
           .Get<paddle::experimental::AllocationType::CPU>(cpu_place);
@@ -34,8 +33,7 @@ std::vector<paddle::Tensor> ContextPoolTest(const paddle::Tensor& x) {
 
 #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
   // 2. test gpu context
-  paddle::experimental::Place gpu_place(
-      paddle::experimental::AllocationType::GPU);
+  paddle::Place gpu_place(paddle::experimental::AllocationType::GPU);
   auto* gpu_ctx =
       paddle::experimental::DeviceContextPool::Instance()
           .Get<paddle::experimental::AllocationType::GPU>(gpu_place);
......
@@ -75,7 +75,7 @@ std::vector<paddle::Tensor> ConcatForwardDynamicAxis(
   auto out_shape = ComputeOutShape(in_shapes, axis);
 
   // create output
-  auto out = paddle::Tensor(paddle::PlaceType::kCPU, out_shape);
+  auto out = paddle::empty(out_shape, inputs[0].type(), paddle::CPUPlace());
 
   // calc
   PD_DISPATCH_FLOATING_AND_INTEGRAL_TYPES(
@@ -106,7 +106,7 @@ std::vector<paddle::Tensor> ConcatBackwardDynamicAxis(
   // create outputs
   std::vector<paddle::Tensor> grad_inputs;
   for (auto& t : inputs) {
-    auto grad = paddle::Tensor(paddle::PlaceType::kCPU, t.shape());
+    auto grad = paddle::empty(t.shape(), t.dtype(), t.place());
     grad_inputs.emplace_back(grad);
   }
 
@@ -161,7 +161,7 @@ std::vector<paddle::Tensor> ConcatForwardStaticAxis(
   auto out_shape = ComputeOutShape(in_shapes, final_axis);
 
   // create output
-  auto out = paddle::Tensor(paddle::PlaceType::kCPU, out_shape);
+  auto out = paddle::empty(out_shape, inputs[0].type(), paddle::CPUPlace());
 
   // calc
   PD_DISPATCH_FLOATING_AND_INTEGRAL_TYPES(
@@ -190,7 +190,7 @@ std::vector<paddle::Tensor> ConcatBackwardStaticAxis(
   // create outputs
   std::vector<paddle::Tensor> grad_inputs;
   for (auto& t : inputs) {
-    auto grad = paddle::Tensor(paddle::PlaceType::kCPU, t.shape());
+    auto grad = paddle::empty(t.shape(), t.dtype(), t.place());
     grad_inputs.emplace_back(grad);
   }
......
@@ -71,7 +71,7 @@ void ConjCPUKernel(const data_t* x_data, int64_t numel, data_t* out_data) {
 std::vector<paddle::Tensor> ConjFunction(const paddle::Tensor& x) {
   CHECK_INPUT(x);
 
-  paddle::Tensor out(x.place(), x.shape());
+  paddle::Tensor out = paddle::empty(x.shape(), x.dtype(), x.place());
 
   PD_DISPATCH_FLOATING_AND_COMPLEX_TYPES(
       x.type(), "ConjCPUKernel", ([&] {
......
@@ -17,8 +17,7 @@
 
 #include "paddle/extension.h"
 
-#define CHECK_CPU_INPUT(x) \
-  PD_CHECK(x.place() == paddle::PlaceType::kCPU, #x " must be a CPU Tensor.")
+#define CHECK_CPU_INPUT(x) PD_CHECK(x.is_cpu(), #x " must be a CPU Tensor.")
 
 template <typename data_t>
 void relu_cpu_forward_kernel(const data_t* x_data,
@@ -26,7 +25,7 @@ void relu_cpu_forward_kernel(const data_t* x_data,
                              int64_t x_numel) {
   PD_CHECK(x_data != nullptr, "x_data is nullptr.");
   PD_CHECK(out_data != nullptr, "out_data is nullptr.");
-  for (int i = 0; i < x_numel; ++i) {
+  for (int64_t i = 0; i < x_numel; ++i) {
     out_data[i] = std::max(static_cast<data_t>(0.), x_data[i]);
   }
 }
@@ -36,7 +35,7 @@ void relu_cpu_backward_kernel(const data_t* grad_out_data,
                               const data_t* out_data,
                               data_t* grad_x_data,
                               int64_t out_numel) {
-  for (int i = 0; i < out_numel; ++i) {
+  for (int64_t i = 0; i < out_numel; ++i) {
     grad_x_data[i] =
         grad_out_data[i] * (out_data[i] > static_cast<data_t>(0) ? 1. : 0.);
   }
@@ -54,12 +53,12 @@ void relu_cpu_double_backward_kernel(const data_t* out_data,
 }
 
 std::vector<paddle::Tensor> relu_cpu_forward(const paddle::Tensor& x) {
-  auto out = paddle::Tensor(paddle::PlaceType::kCPU, x.shape());
+  auto out = paddle::empty_like(x);
 
   PD_DISPATCH_FLOATING_TYPES(
       x.type(), "relu_cpu_forward", ([&] {
         relu_cpu_forward_kernel<data_t>(
-            x.data<data_t>(), out.mutable_data<data_t>(x.place()), x.size());
+            x.data<data_t>(), out.data<data_t>(), x.numel());
       }));
 
   return {out};
@@ -68,13 +67,13 @@ std::vector<paddle::Tensor> relu_cpu_forward(const paddle::Tensor& x) {
 std::vector<paddle::Tensor> relu_cpu_backward(const paddle::Tensor& x,
                                               const paddle::Tensor& out,
                                               const paddle::Tensor& grad_out) {
-  auto grad_x = paddle::Tensor(paddle::PlaceType::kCPU, x.shape());
+  auto grad_x = paddle::empty_like(x);
 
   PD_DISPATCH_FLOATING_TYPES(out.type(), "relu_cpu_backward", ([&] {
                                relu_cpu_backward_kernel<data_t>(
                                    grad_out.data<data_t>(),
                                    out.data<data_t>(),
-                                   grad_x.mutable_data<data_t>(x.place()),
+                                   grad_x.data<data_t>(),
                                    out.size());
                              }));
@@ -85,7 +84,7 @@ std::vector<paddle::Tensor> relu_cpu_double_backward(
     const paddle::Tensor& out, const paddle::Tensor& ddx) {
   CHECK_CPU_INPUT(out);
   CHECK_CPU_INPUT(ddx);
-  auto ddout = paddle::Tensor(paddle::PlaceType::kCPU, out.shape());
+  auto ddout = paddle::empty(out.shape(), out.dtype(), out.place());
 
   PD_DISPATCH_FLOATING_TYPES(out.type(), "relu_cpu_double_backward", ([&] {
                                relu_cpu_double_backward_kernel<data_t>(
@@ -108,9 +107,9 @@ std::vector<paddle::Tensor> relu_cuda_double_backward(
     const paddle::Tensor& out, const paddle::Tensor& ddx);
 
 std::vector<paddle::Tensor> ReluForward(const paddle::Tensor& x) {
-  if (x.place() == paddle::PlaceType::kCPU) {
+  if (x.is_cpu()) {
     return relu_cpu_forward(x);
-  } else if (x.place() == paddle::PlaceType::kGPU) {
+  } else if (x.is_gpu()) {
     return relu_cuda_forward(x);
   } else {
     PD_THROW("Not implemented.");
@@ -120,10 +119,9 @@ std::vector<paddle::Tensor> ReluForward(const paddle::Tensor& x) {
 std::vector<paddle::Tensor> ReluBackward(const paddle::Tensor& x,
                                          const paddle::Tensor& out,
                                          const paddle::Tensor& grad_out) {
-  // TODO(chenweihang): Check Input
-  if (x.place() == paddle::PlaceType::kCPU) {
+  if (x.is_cpu()) {
     return relu_cpu_backward(x, out, grad_out);
-  } else if (x.place() == paddle::PlaceType::kGPU) {
+  } else if (x.is_gpu()) {
     return relu_cuda_backward(x, out, grad_out);
   } else {
     PD_THROW("Not implemented.");
@@ -165,7 +163,7 @@ PD_BUILD_DOUBLE_GRAD_OP(custom_relu)
 
 std::vector<paddle::Tensor> relu_cpu_backward_without_x(
     const paddle::Tensor& out, const paddle::Tensor& grad_out) {
-  auto grad_x = paddle::Tensor(paddle::PlaceType::kCPU, out.shape());
+  auto grad_x = paddle::empty(out.shape(), out.dtype(), out.place());
 
   PD_DISPATCH_FLOATING_TYPES(out.type(), "relu_cpu_backward", ([&] {
                                relu_cpu_backward_kernel<data_t>(
@@ -214,7 +212,7 @@ void relu_cpu_forward_out(const paddle::Tensor& x, paddle::Tensor* out) {
   PD_DISPATCH_FLOATING_TYPES(
       x.type(), "relu_cpu_forward", ([&] {
         relu_cpu_forward_kernel<data_t>(
-            x.data<data_t>(), out->mutable_data<data_t>(x.place()), x.size());
+            x.data<data_t>(), out->mutable_data<data_t>(x.place()), x.numel());
       }));
 }
......
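For completeness, a hedged sketch of how the updated ReluForward/ReluBackward entry points plug into the custom-op registration macros. The registration block itself is unchanged by this commit, so it does not appear in the hunks above; the lines below follow the standard `PD_BUILD_OP` pattern rather than being copied from the diff:

```cpp
PD_BUILD_OP(custom_relu)
    .Inputs({"X"})
    .Outputs({"Out"})
    .SetKernelFn(PD_KERNEL(ReluForward));

PD_BUILD_GRAD_OP(custom_relu)
    .Inputs({"X", "Out", paddle::Grad("Out")})
    .Outputs({paddle::Grad("X")})
    .SetKernelFn(PD_KERNEL(ReluBackward));
```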
@@ -14,15 +14,14 @@
 
 #include "paddle/extension.h"
 
-#define CHECK_GPU_INPUT(x) \
-  PD_CHECK(x.place() == paddle::PlaceType::kGPU, #x " must be a GPU Tensor.")
+#define CHECK_GPU_INPUT(x) PD_CHECK(x.is_gpu(), #x " must be a GPU Tensor.")
 
 template <typename data_t>
 __global__ void relu_cuda_forward_kernel(const data_t* x,
                                          data_t* y,
-                                         const int num) {
-  int gid = blockIdx.x * blockDim.x + threadIdx.x;
-  for (int i = gid; i < num; i += blockDim.x * gridDim.x) {
+                                         int64_t num) {
+  int64_t gid = blockIdx.x * blockDim.x + threadIdx.x;
+  for (int64_t i = gid; i < num; i += blockDim.x * gridDim.x) {
     y[i] = x[i] > static_cast<data_t>(0.) ? x[i] : static_cast<data_t>(0.);
   }
 }
@@ -31,9 +30,9 @@ template <typename data_t>
 __global__ void relu_cuda_backward_kernel(const data_t* dy,
                                           const data_t* y,
                                           data_t* dx,
-                                          const int num) {
-  int gid = blockIdx.x * blockDim.x + threadIdx.x;
-  for (int i = gid; i < num; i += blockDim.x * gridDim.x) {
+                                          int64_t num) {
+  int64_t gid = blockIdx.x * blockDim.x + threadIdx.x;
+  for (int64_t i = gid; i < num; i += blockDim.x * gridDim.x) {
     dx[i] = dy[i] * (y[i] > static_cast<data_t>(0.) ? static_cast<data_t>(1.)
                                                     : static_cast<data_t>(0.));
   }
@@ -54,15 +53,17 @@ __global__ void relu_cuda_double_backward_kernel(const data_t* out_data,
 
 std::vector<paddle::Tensor> relu_cuda_forward(const paddle::Tensor& x) {
   CHECK_GPU_INPUT(x);
-  auto out = paddle::Tensor(paddle::PlaceType::kGPU, x.shape());
+  auto out = paddle::empty_like(x);
 
-  int numel = x.size();
-  int block = 512;
-  int grid = (numel + block - 1) / block;
+  PD_CHECK(x.place() == paddle::DefaultGPUPlace());
+
+  int64_t numel = x.numel();
+  int64_t block = 512;
+  int64_t grid = (numel + block - 1) / block;
   PD_DISPATCH_FLOATING_AND_HALF_TYPES(
       x.type(), "relu_cuda_forward_kernel", ([&] {
         relu_cuda_forward_kernel<data_t><<<grid, block, 0, x.stream()>>>(
-            x.data<data_t>(), out.mutable_data<data_t>(x.place()), numel);
+            x.data<data_t>(), out.data<data_t>(), numel);
       }));
 
   return {out};
@@ -74,11 +75,13 @@ std::vector<paddle::Tensor> relu_cuda_backward(const paddle::Tensor& x,
   CHECK_GPU_INPUT(x);
   CHECK_GPU_INPUT(out);
   CHECK_GPU_INPUT(grad_out);
-  auto grad_x = paddle::Tensor(paddle::PlaceType::kGPU, x.shape());
+  auto grad_x = paddle::empty_like(x);
 
-  int numel = out.size();
-  int block = 512;
-  int grid = (numel + block - 1) / block;
+  PD_CHECK(x.place() == paddle::DefaultGPUPlace());
+
+  int64_t numel = out.numel();
+  int64_t block = 512;
+  int64_t grid = (numel + block - 1) / block;
   PD_DISPATCH_FLOATING_AND_HALF_TYPES(
       out.type(), "relu_cuda_backward_kernel", ([&] {
         relu_cuda_backward_kernel<data_t><<<grid, block, 0, x.stream()>>>(
@@ -95,19 +98,19 @@ std::vector<paddle::Tensor> relu_cuda_double_backward(
     const paddle::Tensor& out, const paddle::Tensor& ddx) {
   CHECK_GPU_INPUT(out);
   CHECK_GPU_INPUT(ddx);
-  auto ddout = paddle::Tensor(paddle::PlaceType::kGPU, out.shape());
+  auto ddout = paddle::empty(out.shape(), out.dtype(), out.place());
 
-  int64_t numel = out.size();
+  int64_t numel = out.numel();
   int64_t block = 512;
   int64_t grid = (numel + block - 1) / block;
   PD_DISPATCH_FLOATING_AND_HALF_TYPES(
       out.type(), "relu_cuda_double_backward_kernel", ([&] {
-        relu_cuda_double_backward_kernel<
-            data_t><<<grid, block, 0, out.stream()>>>(
+        relu_cuda_double_backward_kernel<data_t>
+            <<<grid, block, 0, out.stream()>>>(
             out.data<data_t>(),
             ddx.data<data_t>(),
             ddout.mutable_data<data_t>(out.place()),
             numel);
       }));
 
   std::cout << "Debug info: run relu gpu double backward success." << std::endl;
@@ -117,9 +120,9 @@ std::vector<paddle::Tensor> relu_cuda_double_backward(
 
 std::vector<paddle::Tensor> relu_cuda_backward_without_x(
     const paddle::Tensor& out, const paddle::Tensor& grad_out) {
-  auto grad_x = paddle::Tensor(paddle::PlaceType::kGPU, out.shape());
+  auto grad_x = paddle::empty(out.shape(), out.dtype(), out.place());
 
-  int numel = out.size();
+  int numel = out.numel();
   int block = 512;
   int grid = (numel + block - 1) / block;
   PD_DISPATCH_FLOATING_AND_HALF_TYPES(
@@ -135,7 +138,7 @@ std::vector<paddle::Tensor> relu_cuda_backward_without_x(
 }
 
 void relu_cuda_forward_out(const paddle::Tensor& x, paddle::Tensor* out) {
-  int numel = x.size();
+  int numel = x.numel();
   int block = 512;
   int grid = (numel + block - 1) / block;
   out->reshape(x.shape());
@@ -150,7 +153,7 @@ void relu_cuda_backward_out(const paddle::Tensor& x,
                             const paddle::Tensor& out,
                             const paddle::Tensor& grad_out,
                             paddle::Tensor* grad_x) {
-  int numel = out.size();
+  int numel = out.numel();
   int block = 512;
   int grid = (numel + block - 1) / block;
   grad_x->reshape(x.shape());
......
@@ -68,7 +68,7 @@ void tanh_cpu_double_backward_kernel(const data_t* out_data,
 
 std::vector<paddle::Tensor> TanhForward(const paddle::Tensor& x) {
   CHECK_CPU_INPUT(x);
-  auto out = paddle::Tensor(paddle::PlaceType::kCPU, x.shape());
+  auto out = paddle::empty(x.shape(), x.dtype(), x.place());
 
   PD_DISPATCH_FLOATING_TYPES(
       x.dtype(), "tanh_cpu_forward", ([&] {
@@ -82,7 +82,7 @@ std::vector<paddle::Tensor> TanhForward(const paddle::Tensor& x) {
 std::vector<paddle::Tensor> TanhBackward(const paddle::Tensor& out,
                                          const paddle::Tensor& grad_out) {
   CHECK_CPU_INPUT(out);
-  auto grad_x = paddle::Tensor(paddle::PlaceType::kCPU, out.shape());
+  auto grad_x = paddle::empty(out.shape(), out.dtype(), out.place());
 
   PD_DISPATCH_FLOATING_TYPES(out.dtype(), "tanh_cpu_backward", ([&] {
                                tanh_cpu_backward_kernel<data_t>(
@@ -101,8 +101,8 @@ std::vector<paddle::Tensor> TanhDoubleBackward(const paddle::Tensor& out,
   CHECK_CPU_INPUT(out);
   CHECK_CPU_INPUT(ddx);
   CHECK_CPU_INPUT(dout);
-  auto dout_new = paddle::Tensor(paddle::PlaceType::kCPU, out.shape());
-  auto ddout = paddle::Tensor(paddle::PlaceType::kCPU, out.shape());
+  auto dout_new = paddle::empty(out.shape(), out.dtype(), out.place());
+  auto ddout = paddle::empty(out.shape(), out.dtype(), out.place());
 
   PD_DISPATCH_FLOATING_TYPES(out.dtype(), "tanh_cpu_double_backward", ([&] {
                                tanh_cpu_double_backward_kernel<data_t>(
......