未验证 提交 038ce70d 编写于 作者: J Jiabin Yang 提交者: GitHub

[Custom OP] Support stream set on Custom Op (#31257)

上级 1dd40870
......@@ -12,6 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <cstdint>
#include <stdexcept>
#include <string>
......
......@@ -19,12 +19,32 @@ limitations under the License. */
#include "paddle/fluid/extension/include/dll_decl.h"
#include "paddle/fluid/extension/include/dtype.h"
#include "paddle/fluid/extension/include/place.h"
#ifdef PADDLE_WITH_CUDA
#include <cuda_runtime.h>
#endif
namespace paddle {
namespace framework {
class CustomTensorUtils;
} // namespace framework
// Lightweight holder for an (optional) device stream pointer. The stream is
// stored type-erased as void* so this public header does not hard-depend on
// CUDA headers; callers cast it back to cudaStream_t (see Tensor::stream()).
class StreamWrapper {
 public:
  StreamWrapper() : stream_(nullptr), is_stream_set_(false) {}

  /// \brief Record the device stream associated with a tensor.
  /// The wrapper borrows the stream; it never owns or destroys it.
  void SetStream(void* stream) {
    stream_ = stream;
    is_stream_set_ = true;
  }

  /// \brief Raw (type-erased) stream pointer; nullptr until SetStream().
  void* GetStream() const { return stream_; }

  /// \brief Whether the framework has set a stream on this wrapper.
  bool IsStreamSet() const { return is_stream_set_; }

 private:
  void* stream_;        // borrowed, not owned
  bool is_stream_set_;  // distinguishes "never set" from an explicit nullptr
};
class PD_DLL_DECL Tensor {
public:
/// \brief Construct a Tensor on target Place for CustomOp.
......@@ -88,10 +108,16 @@ class PD_DLL_DECL Tensor {
/// \brief Cast datatype from one to another
Tensor cast(const DataType& target_type) const;
#ifdef PADDLE_WITH_CUDA
/// \brief Get current stream of Tensor
cudaStream_t stream() const;
#endif
private:
friend class framework::CustomTensorUtils;
mutable std::shared_ptr<void> tensor_;
mutable PlaceType place_;
StreamWrapper stream_;
};
} // namespace paddle
......@@ -101,8 +101,9 @@ void Tensor::reshape(const std::vector<int> &shape) {
}
Tensor::Tensor(const PlaceType &place)
: tensor_(std::make_shared<framework::LoDTensor>()), place_(place) {}
: tensor_(std::make_shared<framework::LoDTensor>()),
place_(place),
stream_(StreamWrapper()) {}
template <typename T>
T *Tensor::mutable_data(const PlaceType &place) {
place_ = place;
......@@ -323,6 +324,18 @@ int64_t Tensor::size() const {
return tensor->numel();
}
#ifdef PADDLE_WITH_CUDA
// Return the CUDA stream the framework attached to this tensor.
// Only op-input tensors get a stream (set via
// CustomTensorUtils::SetTensorCurrentStream); asking a tensor that was never
// given one is a precondition violation.
cudaStream_t Tensor::stream() const {
  if (!stream_.IsStreamSet()) {
    PADDLE_THROW(platform::errors::PreconditionNotMet(
        "Stream is not Set, only input tensor will have "
        "stream which is set by framework "));
  }
  // The wrapper stores the stream type-erased as void*; cast it back.
  return reinterpret_cast<cudaStream_t>(stream_.GetStream());
}
#endif
namespace framework {
void CustomTensorUtils::ShareDataTo(const paddle::Tensor &src, void *dst) {
......
......@@ -114,6 +114,7 @@ static void RunKernelFunc(const framework::ExecutionContext& ctx,
auto custom_in = paddle::Tensor(
CustomTensorUtils::ConvertInnerPlaceToEnumPlace(x->place()));
CustomTensorUtils::ShareDataFrom(static_cast<const void*>(x), custom_in);
CustomTensorUtils::SetTensorCurrentStream(&custom_in, ctx.GetPlace());
custom_ins.emplace_back(custom_in);
}
......
......@@ -20,6 +20,9 @@ limitations under the License. */
#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/platform/gpu_info.h"
#include "paddle/fluid/platform/place.h"
#ifdef PADDLE_WITH_CUDA
#endif
#include "paddle/fluid/platform/device_context.h"
namespace paddle {
namespace framework {
......@@ -123,6 +126,19 @@ class CustomTensorUtils {
}
return PlaceType::kUNK;
}
static void SetTensorCurrentStream(paddle::Tensor* src,
const platform::Place& pc) {
if (platform::is_gpu_place(pc)) {
#ifdef PADDLE_WITH_CUDA
auto* dev_ctx = static_cast<platform::CUDADeviceContext*>(
platform::DeviceContextPool::Instance().Get(pc));
src->stream_.SetStream(reinterpret_cast<void*>(dev_ctx->stream()));
#endif
} else {
return;
}
}
};
} // namespace framework
......
......@@ -91,6 +91,7 @@ PreparedOp PrepareImpl(const NameVarMap<VarType>& ins,
const framework::AttributeMap& attrs) {
platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
auto* dev_ctx = pool.Get(place);
framework::RuntimeContext ctx({}, {});
#ifdef PADDLE_WITH_MKLDNN
......
......@@ -39,8 +39,8 @@ void relu_cpu_backward_kernel(const data_t* grad_out_data,
std::vector<paddle::Tensor> relu_cpu_forward(const paddle::Tensor& x) {
auto out = paddle::Tensor(paddle::PlaceType::kCPU);
out.reshape(x.shape());
out.reshape(x.shape());
PD_DISPATCH_FLOATING_TYPES(
x.type(), "relu_cpu_forward", ([&] {
relu_cpu_forward_kernel<data_t>(
......
......@@ -37,14 +37,14 @@ __global__ void relu_cuda_backward_kernel(const data_t* dy,
std::vector<paddle::Tensor> relu_cuda_forward(const paddle::Tensor& x) {
auto out = paddle::Tensor(paddle::PlaceType::kGPU);
out.reshape(x.shape());
out.reshape(x.shape());
int numel = x.size();
int block = 512;
int grid = (numel + block - 1) / block;
PD_DISPATCH_FLOATING_TYPES(
x.type(), "relu_cuda_forward_kernel", ([&] {
relu_cuda_forward_kernel<data_t><<<grid, block>>>(
relu_cuda_forward_kernel<data_t><<<grid, block, 0, x.stream()>>>(
x.data<data_t>(), out.mutable_data<data_t>(x.place()), numel);
}));
......@@ -62,7 +62,7 @@ std::vector<paddle::Tensor> relu_cuda_backward(const paddle::Tensor& x,
int grid = (numel + block - 1) / block;
PD_DISPATCH_FLOATING_TYPES(
out.type(), "relu_cuda_backward_kernel", ([&] {
relu_cuda_backward_kernel<data_t><<<grid, block>>>(
relu_cuda_backward_kernel<data_t><<<grid, block, 0, x.stream()>>>(
grad_out.data<data_t>(),
out.data<data_t>(),
grad_x.mutable_data<data_t>(x.place()),
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册