Unverified · Commit 93d01787 authored by huangjiyi, committed by GitHub

register fluid kernels to phi [part 1] (#52014)

* update assign_pos

* update attention_lstm

* update barrier

* update batch_fc

* update beam_search

* update beam_search_decode

* update bilateral_slice

* fix bug

* Handle Structure kernel for InterpreterCore::RunOperator

* fix bug

* fix rocm compile

* fix rocm compile

* Revert "fix rocm compile"

* test

* revert test and update cmake

---------
Co-authored-by: chenruibiao <chenruibiao@baidu.com>
Parent 70ebef81
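
The pattern repeated throughout the hunks below, sketched as a rough before/after using assign_pos as the example (this summary block is not itself part of the diff; kernel bodies are elided):

    // Before: fluid registration, dtype-only kernel template.
    template <typename T>
    class AssignPosOpCPUKernel : public framework::OpKernel<T> { /* ... */ };

    REGISTER_OP_CPU_KERNEL(assign_pos,
                           ops::AssignPosOpCPUKernel<int>,
                           ops::AssignPosOpCPUKernel<int64_t>);

    // After: the kernel template gains a trailing DeviceContext parameter and
    // is registered with phi as a "struct kernel": name, backend, layout,
    // kernel class, then the list of data types.
    template <typename T, typename DeviceContext>
    class AssignPosOpCPUKernel : public framework::OpKernel<T> { /* ... */ };

    PD_REGISTER_STRUCT_KERNEL(
        assign_pos, CPU, ALL_LAYOUT, ops::AssignPosOpCPUKernel, int, int64_t) {}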
@@ -475,6 +475,9 @@ function(op_library TARGET)
   foreach(hip_src ${hip_srcs})
     set(op_name "")
     find_register(${hip_src} "REGISTER_OP_CUDA_KERNEL" op_name)
+    find_phi_register(${hip_src} ${pybind_file} "PD_REGISTER_KERNEL")
+    find_phi_register(${hip_src} ${pybind_file} "PD_REGISTER_STRUCT_KERNEL")
+    find_phi_register(${hip_src} ${pybind_file} "PD_REGISTER_GENERAL_KERNEL")
     if(NOT ${op_name} EQUAL "")
       file(APPEND ${pybind_file} "USE_OP_DEVICE_KERNEL(${op_name}, CUDA);\n")
       set(pybind_flag 1)
......
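
For context, a hedged sketch of what the cmake change buys (find_phi_register's definition is not in this diff; its exact output is an assumption): the helper scans the HIP source for the named phi registration macro and, when it matches, appends a kernel declaration to the generated pybind file, next to the USE_OP_DEVICE_KERNEL line the fluid path already writes. The generated file would then contain lines roughly like:

    // Assumed excerpt of the generated pybind file for a HIP source that
    // registers both a fluid CUDA kernel and a phi struct kernel.
    USE_OP_DEVICE_KERNEL(assign_pos, CUDA);          // written by find_register
    PD_DECLARE_KERNEL(assign_pos, GPU, ALL_LAYOUT);  // assumed output of find_phi_register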
@@ -947,12 +947,14 @@ void InterpreterCore::RunOperator(const Instruction& instr_node) {
                                      platform::TracerEventType::OperatorInner,
                                      1,
                                      platform::EventRole::kInnerOp);
-    if (op_with_kernel == nullptr) {
+    if (op_with_kernel == nullptr) {  // operator base
       instr_node.OpBase()->Run(*local_scope, place_);
     } else {
-      // fit for phi
-      if (instr_node.PhiKernel() && instr_node.PhiKernel()->IsValid()) {
-        VLOG(4) << "Run phi kernel: " << op->Type();
+      phi::Kernel* kernel = instr_node.PhiKernel();
+      if (kernel && kernel->IsValid()) {  // phi kernel
+        if (kernel->GetKernelRegisteredType() ==
+            phi::KernelRegisteredType::FUNCTION) {
+          VLOG(4) << "Run function kernel: " << op->Type();
         VLOG(4) << instr_node.InnerRuntimeContext().get() << " "
                 << &instr_node.DeviceContext();
         phi::KernelContext phi_kernel_context;
@@ -961,9 +963,12 @@ void InterpreterCore::RunOperator(const Instruction& instr_node) {
             const_cast<platform::DeviceContext*>(&instr_node.DeviceContext()),
             &phi_kernel_context);
-        (*instr_node.PhiKernel())(&phi_kernel_context);
-      } else {
+          (*kernel)(&phi_kernel_context);
+        } else {
+          VLOG(4) << "Run structure kernel: " << op->Type();
+          (*kernel)(instr_node.InnerExecutionContext().get());
+        }
+      } else {  // fluid kernel
         instr_node.KernelFunc()(*instr_node.InnerExecutionContext().get());
       }
     }
......
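
Condensed, the dispatch above now covers three kernel flavors (a simplified sketch of the control flow, with names shortened; not the literal code):

    // Simplified view of InterpreterCore::RunOperator after this change.
    if (op_with_kernel == nullptr) {
      op_base->Run(scope, place);            // plain OperatorBase, no kernel
    } else if (kernel && kernel->IsValid()) {
      if (kernel->GetKernelRegisteredType() ==
          phi::KernelRegisteredType::FUNCTION) {
        (*kernel)(&phi_kernel_context);      // function kernel: phi KernelContext
      } else {
        (*kernel)(execution_ctx);            // struct kernel: fluid ExecutionContext
      }
    } else {
      fluid_kernel_func(*execution_ctx);     // unmigrated fluid kernel
    }

Structure kernels keep the fluid ExecutionContext calling convention, which is what lets the OpKernel classes below be registered with phi without rewriting their Compute bodies.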
@@ -79,6 +79,5 @@ REGISTER_OP_WITHOUT_GRADIENT(assign_pos,
                              ops::AssignPosOp,
                              ops::AssignPosOpMaker);
-REGISTER_OP_CPU_KERNEL(assign_pos,
-                       ops::AssignPosOpCPUKernel<int>,
-                       ops::AssignPosOpCPUKernel<int64_t>);
+PD_REGISTER_STRUCT_KERNEL(
+    assign_pos, CPU, ALL_LAYOUT, ops::AssignPosOpCPUKernel, int, int64_t) {}
@@ -53,7 +53,7 @@ __global__ void AssignPos(T* cum_count,
   }
 }
-template <typename T>
+template <typename T, typename DeviceContext>
 class AssignPosCUDAKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext& context) const override {
@@ -102,4 +102,6 @@ class AssignPosCUDAKernel : public framework::OpKernel<T> {
 namespace ops = paddle::operators;
 namespace plat = paddle::platform;
-REGISTER_OP_CUDA_KERNEL(assign_pos, ops::AssignPosCUDAKernel<int64_t>);
+PD_REGISTER_STRUCT_KERNEL(
+    assign_pos, GPU, ALL_LAYOUT, ops::AssignPosCUDAKernel, int64_t) {}
@@ -20,7 +20,7 @@ limitations under the License. */
 namespace paddle {
 namespace operators {
-template <typename T>
+template <typename T, typename DeviceContext>
 class AssignPosOpCPUKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext& ctx) const override {
......
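
Presumably (an assumption; the expansion of PD_REGISTER_STRUCT_KERNEL is not part of this diff), the macro instantiates the class once per listed data type, pairing it with the backend's device context, and records each instantiation in the phi registry with the STRUCTURE registered type that RunOperator checks:

    // Assumed instantiations produced by the assign_pos CPU registration above.
    ops::AssignPosOpCPUKernel<int, phi::CPUContext>
    ops::AssignPosOpCPUKernel<int64_t, phi::CPUContext>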
@@ -340,12 +340,10 @@ inline void vec_softmax(const int n, const T* x, T* y) {
   phi::funcs::vec_scal<T>(n, static_cast<T>(1) / scalar, y);  // scale
 }
-template <typename T>
+template <typename T, typename DeviceContext>
 class AttentionLSTMKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext& ctx) const override {
-    using DeviceContext = phi::CPUContext;
     auto* x = ctx.Input<phi::DenseTensor>("X");
     auto* h0 = ctx.Input<phi::DenseTensor>("H0");
     auto* c0 = ctx.Input<phi::DenseTensor>("C0");
@@ -525,6 +523,5 @@ REGISTER_OPERATOR(attention_lstm,
                   ops::AttentionLSTMOp,
                   ops::AttentionLSTMOpMaker);
-REGISTER_OP_CPU_KERNEL(attention_lstm,
-                       ops::AttentionLSTMKernel<float>,
-                       ops::AttentionLSTMKernel<double>);
+PD_REGISTER_STRUCT_KERNEL(
+    attention_lstm, CPU, ALL_LAYOUT, ops::AttentionLSTMKernel, float, double) {}
@@ -165,6 +165,5 @@ REGISTER_OPERATOR(batch_fc_grad,
                   ops::BatchFCGradOp,
                   ops::BatchFCGradOpNoNeedBufferVarsInferer);
-REGISTER_OP_CPU_KERNEL(batch_fc,
-                       ops::BatchFCKernel<phi::CPUContext, float>,
-                       ops::BatchFCKernel<phi::CPUContext, double>);
+PD_REGISTER_STRUCT_KERNEL(
+    batch_fc, CPU, ALL_LAYOUT, ops::BatchFCKernel, float, double) {}
@@ -85,7 +85,7 @@ void add_bias_grad(gpuStream_t stream,
       dout_data, slot_pairs_num, ins_num, out_dim, db_data);
 }
-template <typename DeviceContext, typename T>
+template <typename T, typename DeviceContext>
 class BatchFCCUDAKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext& ctx) const override {
@@ -149,7 +149,7 @@ class BatchFCCUDAKernel : public framework::OpKernel<T> {
 };
-template <typename DeviceContext, typename T>
+template <typename T, typename DeviceContext>
 class BatchFCGradOpCUDAKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext& ctx) const override {
@@ -238,10 +238,13 @@ class BatchFCGradOpCUDAKernel : public framework::OpKernel<T> {
 namespace ops = paddle::operators;
 using GPUCtx = phi::GPUContext;
-REGISTER_OP_CUDA_KERNEL(batch_fc,
-                        ops::BatchFCCUDAKernel<GPUCtx, float>,
-                        ops::BatchFCCUDAKernel<GPUCtx, double>);
-REGISTER_OP_CUDA_KERNEL(batch_fc_grad,
-                        ops::BatchFCGradOpCUDAKernel<GPUCtx, float>,
-                        ops::BatchFCGradOpCUDAKernel<GPUCtx, double>);
+PD_REGISTER_STRUCT_KERNEL(
+    batch_fc, GPU, ALL_LAYOUT, ops::BatchFCCUDAKernel, float, double) {}
+PD_REGISTER_STRUCT_KERNEL(batch_fc_grad,
+                          GPU,
+                          ALL_LAYOUT,
+                          ops::BatchFCGradOpCUDAKernel,
+                          float,
+                          double) {}
@@ -19,7 +19,7 @@ limitations under the License. */
 namespace paddle {
 namespace operators {
-template <typename DeviceContext, typename T>
+template <typename T, typename DeviceContext>
 class BatchFCKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext& ctx) const override {
......
@@ -108,10 +108,12 @@ REGISTER_OPERATOR(beam_search_decode,
                   paddle::operators::BeamSearchDecodeOpProtoMaker,
                   paddle::operators::BeamSearchDecodeInferVarType);
-REGISTER_OP_CPU_KERNEL(
-    beam_search_decode,
-    ops::BeamSearchDecodeOpKernel<phi::CPUContext, float>,
-    ops::BeamSearchDecodeOpKernel<phi::CPUContext, double>,
-    ops::BeamSearchDecodeOpKernel<phi::CPUContext, paddle::platform::float16>,
-    ops::BeamSearchDecodeOpKernel<phi::CPUContext, int>,
-    ops::BeamSearchDecodeOpKernel<phi::CPUContext, int64_t>);
+PD_REGISTER_STRUCT_KERNEL(beam_search_decode,
+                          CPU,
+                          ALL_LAYOUT,
+                          ops::BeamSearchDecodeOpKernel,
+                          float,
+                          double,
+                          paddle::platform::float16,
+                          int,
+                          int64_t) {}
@@ -16,10 +16,13 @@ limitations under the License. */
 #include "paddle/fluid/framework/op_registry.h"
 namespace ops = paddle::operators;
-REGISTER_OP_CUDA_KERNEL(
-    beam_search_decode,
-    ops::BeamSearchDecodeOpKernel<phi::GPUContext, float>,
-    ops::BeamSearchDecodeOpKernel<phi::GPUContext, double>,
-    ops::BeamSearchDecodeOpKernel<phi::GPUContext, paddle::platform::float16>,
-    ops::BeamSearchDecodeOpKernel<phi::GPUContext, int>,
-    ops::BeamSearchDecodeOpKernel<phi::GPUContext, int64_t>);
+PD_REGISTER_STRUCT_KERNEL(beam_search_decode,
+                          GPU,
+                          ALL_LAYOUT,
+                          ops::BeamSearchDecodeOpKernel,
+                          float,
+                          double,
+                          paddle::platform::float16,
+                          int,
+                          int64_t) {}
@@ -123,7 +123,7 @@ struct BeamSearchDecodeFunctor {
   phi::DenseTensor* score_tensor_;
 };
-template <typename DeviceContext, typename T>
+template <typename T, typename DeviceContext>
 class BeamSearchDecodeOpKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext& context) const override {
......
@@ -148,8 +148,12 @@ REGISTER_OPERATOR(beam_search,
                   ops::BeamSearchOp,
                   ops::BeamSearchOpMaker,
                   ops::BeamSearchInferVarType);
-REGISTER_OP_CPU_KERNEL(beam_search,
-                       ops::BeamSearchOpKernel<phi::CPUContext, float>,
-                       ops::BeamSearchOpKernel<phi::CPUContext, double>,
-                       ops::BeamSearchOpKernel<phi::CPUContext, int>,
-                       ops::BeamSearchOpKernel<phi::CPUContext, int64_t>);
+PD_REGISTER_STRUCT_KERNEL(beam_search,
+                          CPU,
+                          ALL_LAYOUT,
+                          ops::BeamSearchOpKernel,
+                          float,
+                          double,
+                          int,
+                          int64_t) {}
@@ -17,8 +17,12 @@ limitations under the License. */
 #include "paddle/fluid/framework/op_registry.h"
 namespace ops = paddle::operators;
-REGISTER_OP_CUDA_KERNEL(beam_search,
-                        ops::BeamSearchOpKernel<phi::GPUContext, float>,
-                        ops::BeamSearchOpKernel<phi::GPUContext, double>,
-                        ops::BeamSearchOpKernel<phi::GPUContext, int>,
-                        ops::BeamSearchOpKernel<phi::GPUContext, int64_t>);
+PD_REGISTER_STRUCT_KERNEL(beam_search,
+                          GPU,
+                          ALL_LAYOUT,
+                          ops::BeamSearchOpKernel,
+                          float,
+                          double,
+                          int,
+                          int64_t) {}
@@ -20,7 +20,7 @@ limitations under the License. */
 namespace paddle {
 namespace operators {
-template <typename DeviceContext, typename T>
+template <typename T, typename DeviceContext>
 class BeamSearchOpKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext& context) const override {
......
@@ -16,9 +16,10 @@ limitations under the License. */
 #include "paddle/fluid/operators/beam_search_op.h"
 namespace ops = paddle::operators;
-REGISTER_OP_NPU_KERNEL(
-    beam_search,
-    ops::BeamSearchOpKernel<paddle::platform::NPUDeviceContext, float>,
-    ops::BeamSearchOpKernel<paddle::platform::NPUDeviceContext, double>,
-    ops::BeamSearchOpKernel<paddle::platform::NPUDeviceContext, int>,
-    ops::BeamSearchOpKernel<paddle::platform::NPUDeviceContext, int64_t>);
+using NPUCtx = paddle::platform::NPUDeviceContext;
+
+REGISTER_OP_NPU_KERNEL(beam_search,
+                       ops::BeamSearchOpKernel<float, NPUCtx>,
+                       ops::BeamSearchOpKernel<double, NPUCtx>,
+                       ops::BeamSearchOpKernel<int, NPUCtx>,
+                       ops::BeamSearchOpKernel<int64_t, NPUCtx>);
@@ -18,10 +18,11 @@ limitations under the License. */
 #include "paddle/fluid/operators/beam_search_op.h"
 namespace ops = paddle::operators;
-REGISTER_OP_XPU_KERNEL(
-    beam_search,
-    ops::BeamSearchOpKernel<paddle::platform::XPUDeviceContext, float>,
-    ops::BeamSearchOpKernel<paddle::platform::XPUDeviceContext, double>,
-    ops::BeamSearchOpKernel<paddle::platform::XPUDeviceContext, int>,
-    ops::BeamSearchOpKernel<paddle::platform::XPUDeviceContext, int64_t>);
+using XPUCtx = paddle::platform::XPUDeviceContext;
+
+REGISTER_OP_XPU_KERNEL(beam_search,
+                       ops::BeamSearchOpKernel<float, XPUCtx>,
+                       ops::BeamSearchOpKernel<double, XPUCtx>,
+                       ops::BeamSearchOpKernel<int, XPUCtx>,
+                       ops::BeamSearchOpKernel<int64_t, XPUCtx>);
 #endif
@@ -175,7 +175,7 @@ class BilateralSliceGradMaker : public framework::SingleGradOpMaker<T> {
   }
 };
-template <typename T>
+template <typename T, typename DeviceContext>
 class BilateralSliceKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext& ctx) const override {
@@ -196,6 +196,10 @@ REGISTER_OPERATOR(bilateral_slice,
                   ops::BilateralSliceGradMaker<paddle::framework::OpDesc>,
                   ops::BilateralSliceGradMaker<paddle::imperative::OpBase>);
 REGISTER_OPERATOR(bilateral_slice_grad, ops::BilateralSliceOpGrad);
-REGISTER_OP_CPU_KERNEL(bilateral_slice,
-                       ops::BilateralSliceKernel<float>,
-                       ops::BilateralSliceKernel<double>);
+PD_REGISTER_STRUCT_KERNEL(bilateral_slice,
+                          CPU,
+                          ALL_LAYOUT,
+                          ops::BilateralSliceKernel,
+                          float,
+                          double) {}
@@ -126,7 +126,7 @@ __global__ void BilateralSliceCudaForwardKernel(T* output,
   }
 }
-template <typename T>
+template <typename T, typename DeviceContext>
 class BilateralSliceOpCUDAKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext& ctx) const override {
@@ -442,7 +442,7 @@ __global__ void BilateralSliceCudaInputGradKernel(T* out_input_grad,
   }
 }
-template <typename T>
+template <typename T, typename DeviceContext>
 class BilateralSliceGradOpCUDAKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext& ctx) const override {
@@ -557,9 +557,16 @@ class BilateralSliceGradOpCUDAKernel : public framework::OpKernel<T> {
 }  // namespace paddle
 namespace ops = paddle::operators;
-REGISTER_OP_CUDA_KERNEL(bilateral_slice,
-                        ops::BilateralSliceOpCUDAKernel<float>,
-                        ops::BilateralSliceOpCUDAKernel<double>);
-REGISTER_OP_CUDA_KERNEL(bilateral_slice_grad,
-                        ops::BilateralSliceGradOpCUDAKernel<float>,
-                        ops::BilateralSliceGradOpCUDAKernel<double>);
+PD_REGISTER_STRUCT_KERNEL(bilateral_slice,
+                          GPU,
+                          ALL_LAYOUT,
+                          ops::BilateralSliceOpCUDAKernel,
+                          float,
+                          double) {}
+PD_REGISTER_STRUCT_KERNEL(bilateral_slice_grad,
+                          GPU,
+                          ALL_LAYOUT,
+                          ops::BilateralSliceGradOpCUDAKernel,
+                          float,
+                          double) {}
@@ -44,4 +44,6 @@ namespace ops = paddle::operators;
 namespace plat = paddle::platform;
 REGISTER_OP_WITHOUT_GRADIENT(barrier, ops::BarrierOp, ops::BarrierOpMaker);
-REGISTER_OP_CPU_KERNEL(barrier, ops::BarrierOpCPUKernel<int>);
+PD_REGISTER_STRUCT_KERNEL(
+    barrier, CPU, ALL_LAYOUT, ops::BarrierOpCPUKernel, int) {}
@@ -22,7 +22,7 @@ limitations under the License. */
 namespace paddle {
 namespace operators {
-template <typename T>
+template <typename T, typename DeviceContext>
 class BarrierOpCUDAKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext& ctx) const override {
@@ -58,4 +58,5 @@ class BarrierOpCUDAKernel : public framework::OpKernel<T> {
 namespace ops = paddle::operators;
 namespace plat = paddle::platform;
-REGISTER_OP_CUDA_KERNEL(barrier, ops::BarrierOpCUDAKernel<int>);
+PD_REGISTER_STRUCT_KERNEL(
+    barrier, GPU, ALL_LAYOUT, ops::BarrierOpCUDAKernel, int) {}
@@ -32,7 +32,7 @@ limitations under the License. */
 namespace paddle {
 namespace operators {
-template <typename T>
+template <typename T, typename DeviceContext>
 class BarrierOpCPUKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext& ctx) const override {
......