diff --git a/paddle/fluid/extension/include/dtype.h b/paddle/fluid/extension/include/dtype.h
index 38c836c6fc7c09280e2cd66a2f63eb53cedb36d0..2fbeaf9262046de40b3193482bfb64acfed0cbc4 100644
--- a/paddle/fluid/extension/include/dtype.h
+++ b/paddle/fluid/extension/include/dtype.h
@@ -12,6 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */
 #pragma once
+
 #include
 #include
 #include
diff --git a/paddle/fluid/extension/include/tensor.h b/paddle/fluid/extension/include/tensor.h
index 061dc3ded2fc67761a5c376ffbdc1442a669e4f2..e6066b42322b0c41955f5669c0ca468ee2d3573b 100644
--- a/paddle/fluid/extension/include/tensor.h
+++ b/paddle/fluid/extension/include/tensor.h
@@ -19,12 +19,32 @@ limitations under the License. */
 #include "paddle/fluid/extension/include/dll_decl.h"
 #include "paddle/fluid/extension/include/dtype.h"
 #include "paddle/fluid/extension/include/place.h"
-
+#ifdef PADDLE_WITH_CUDA
+#include <cuda_runtime.h>
+#endif
 namespace paddle {
 namespace framework {
 class CustomTensorUtils;
 }  // namespace framework
 
+class StreamWrapper {
+ public:
+  StreamWrapper() : stream_(nullptr), is_stream_set_(false) {}
+  void SetStream(void* stream) {
+    stream_ = stream;
+    is_stream_set_ = true;
+  }
+
+  void* GetStream() const { return stream_; }
+
+  bool IsStreamSet() const { return is_stream_set_; }
+
+ private:
+  // Stored as void* so this header stays free of CUDA types.
+  void* stream_;
+  bool is_stream_set_;
+};
+
 class PD_DLL_DECL Tensor {
  public:
   /// \brief Construct a Tensor on target Place for CustomOp.
@@ -88,10 +108,16 @@ class PD_DLL_DECL Tensor {
   /// \brief Cast datatype from one to another
   Tensor cast(const DataType& target_type) const;
 
+#ifdef PADDLE_WITH_CUDA
+  /// \brief Get the current stream of the Tensor
+  cudaStream_t stream() const;
+#endif
+
  private:
   friend class framework::CustomTensorUtils;
   mutable std::shared_ptr<void> tensor_;
   mutable PlaceType place_;
+  StreamWrapper stream_;
 };
 
 }  // namespace paddle
diff --git a/paddle/fluid/extension/src/tensor.cc b/paddle/fluid/extension/src/tensor.cc
index dc7e3607bdfa8f22987e1118fd9ade54621d8c46..fa8c3c4f090f0b0ac7f099e1db2be2d18bc48201 100644
--- a/paddle/fluid/extension/src/tensor.cc
+++ b/paddle/fluid/extension/src/tensor.cc
@@ -101,8 +101,9 @@ void Tensor::reshape(const std::vector &shape) {
 }
 
 Tensor::Tensor(const PlaceType &place)
-    : tensor_(std::make_shared<framework::LoDTensor>()), place_(place) {}
-
+    : tensor_(std::make_shared<framework::LoDTensor>()),
+      place_(place),
+      stream_(StreamWrapper()) {}
 template <typename T>
 T *Tensor::mutable_data(const PlaceType &place) {
   place_ = place;
@@ -323,6 +324,18 @@ int64_t Tensor::size() const {
   return tensor->numel();
 }
 
+#ifdef PADDLE_WITH_CUDA
+cudaStream_t Tensor::stream() const {
+  if (!stream_.IsStreamSet()) {
+    PADDLE_THROW(platform::errors::PreconditionNotMet(
+        "The stream is not set. Only an input tensor of a custom operator "
+        "has its stream set by the framework."));
+  } else {
+    return reinterpret_cast<cudaStream_t>(stream_.GetStream());
+  }
+}
+#endif
+
 namespace framework {
 
 void CustomTensorUtils::ShareDataTo(const paddle::Tensor &src, void *dst) {
diff --git a/paddle/fluid/framework/custom_operator.cc b/paddle/fluid/framework/custom_operator.cc
index 90831afc9ba89185dbe85dbf54bb38ea3ffbace6..582e328dcfdfcb04d68b821fed47863c37a4066e 100644
--- a/paddle/fluid/framework/custom_operator.cc
+++ b/paddle/fluid/framework/custom_operator.cc
@@ -114,6 +114,7 @@ static void RunKernelFunc(const framework::ExecutionContext& ctx,
     auto custom_in = paddle::Tensor(
         CustomTensorUtils::ConvertInnerPlaceToEnumPlace(x->place()));
     CustomTensorUtils::ShareDataFrom(static_cast<const void*>(x), custom_in);
+    CustomTensorUtils::SetTensorCurrentStream(&custom_in, ctx.GetPlace());
     custom_ins.emplace_back(custom_in);
   }
 
diff --git a/paddle/fluid/framework/custom_tensor_utils.h b/paddle/fluid/framework/custom_tensor_utils.h
index 1dc4e06e572c12791bcb0e6efa9fc40bcd760966..f481d2881dd6707c1d50b58e0eb13265e8b62c13 100644
--- a/paddle/fluid/framework/custom_tensor_utils.h
+++ b/paddle/fluid/framework/custom_tensor_utils.h
@@ -20,6 +20,9 @@ limitations under the License. */
 #include "paddle/fluid/framework/data_type.h"
 #include "paddle/fluid/platform/gpu_info.h"
 #include "paddle/fluid/platform/place.h"
+#ifdef PADDLE_WITH_CUDA
+#endif
+#include "paddle/fluid/platform/device_context.h"
 
 namespace paddle {
 namespace framework {
@@ -123,6 +126,19 @@ class CustomTensorUtils {
     }
     return PlaceType::kUNK;
   }
+
+  static void SetTensorCurrentStream(paddle::Tensor* src,
+                                     const platform::Place& pc) {
+    if (platform::is_gpu_place(pc)) {
+#ifdef PADDLE_WITH_CUDA
+      auto* dev_ctx = static_cast<platform::CUDADeviceContext*>(
+          platform::DeviceContextPool::Instance().Get(pc));
+      src->stream_.SetStream(reinterpret_cast<void*>(dev_ctx->stream()));
+#endif
+    } else {
+      return;
+    }
+  }
 };
 
 }  // namespace framework
diff --git a/paddle/fluid/imperative/prepared_operator.cc b/paddle/fluid/imperative/prepared_operator.cc
index e6e5135316aba052f2e8668768bed85d0398db79..2a3b6424d4a14e1cd6345cf24594582bd19f51d4 100644
--- a/paddle/fluid/imperative/prepared_operator.cc
+++ b/paddle/fluid/imperative/prepared_operator.cc
@@ -91,6 +91,7 @@ PreparedOp PrepareImpl(const NameVarMap<VarType>& ins,
                        const framework::AttributeMap& attrs) {
   platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
   auto* dev_ctx = pool.Get(place);
+  framework::RuntimeContext ctx({}, {});
 
 #ifdef PADDLE_WITH_MKLDNN
diff --git a/python/paddle/fluid/tests/custom_op/custom_relu_op.cc b/python/paddle/fluid/tests/custom_op/custom_relu_op.cc
index 4b8d3bca63695c0abfb5dafc0423e292fd157c07..e70c1b39707e1b306ec1949e8a7598dca55f14cb 100644
--- a/python/paddle/fluid/tests/custom_op/custom_relu_op.cc
+++ b/python/paddle/fluid/tests/custom_op/custom_relu_op.cc
@@ -39,8 +39,8 @@ void relu_cpu_backward_kernel(const data_t* grad_out_data,
 
 std::vector<paddle::Tensor> relu_cpu_forward(const paddle::Tensor& x) {
   auto out = paddle::Tensor(paddle::PlaceType::kCPU);
-  out.reshape(x.shape());
 
+  out.reshape(x.shape());
   PD_DISPATCH_FLOATING_TYPES(
       x.type(), "relu_cpu_forward", ([&] {
         relu_cpu_forward_kernel<data_t>(
diff --git a/python/paddle/fluid/tests/custom_op/custom_relu_op.cu b/python/paddle/fluid/tests/custom_op/custom_relu_op.cu
index a9ce5176070939be24a8e6d965faa60b6f391bff..be3309d84f57d6f4f000f920339b06dc370c85a8 100644
--- a/python/paddle/fluid/tests/custom_op/custom_relu_op.cu
+++ b/python/paddle/fluid/tests/custom_op/custom_relu_op.cu
@@ -37,14 +37,14 @@ __global__ void relu_cuda_backward_kernel(const data_t* dy,
 
 std::vector<paddle::Tensor> relu_cuda_forward(const paddle::Tensor& x) {
   auto out = paddle::Tensor(paddle::PlaceType::kGPU);
-  out.reshape(x.shape());
 
+  out.reshape(x.shape());
   int numel = x.size();
   int block = 512;
   int grid = (numel + block - 1) / block;
   PD_DISPATCH_FLOATING_TYPES(
       x.type(), "relu_cuda_forward_kernel", ([&] {
-        relu_cuda_forward_kernel<data_t><<<grid, block>>>(
+        relu_cuda_forward_kernel<data_t><<<grid, block, 0, x.stream()>>>(
             x.data<data_t>(), out.mutable_data<data_t>(x.place()), numel);
       }));
 
@@ -62,7 +62,7 @@ std::vector<paddle::Tensor> relu_cuda_backward(const paddle::Tensor& x,
   int grid = (numel + block - 1) / block;
   PD_DISPATCH_FLOATING_TYPES(
       out.type(), "relu_cuda_backward_kernel", ([&] {
-        relu_cuda_backward_kernel<data_t><<<grid, block>>>(
+        relu_cuda_backward_kernel<data_t><<<grid, block, 0, x.stream()>>>(
             grad_out.data<data_t>(),
             out.data<data_t>(),
             grad_x.mutable_data<data_t>(x.place()),